Package ai.djl.pytorch.engine
Class PtSymbolBlock
java.lang.Object
ai.djl.nn.AbstractBaseBlock
ai.djl.nn.AbstractSymbolBlock
ai.djl.pytorch.engine.PtSymbolBlock
- All Implemented Interfaces:
ai.djl.nn.Block, ai.djl.nn.SymbolBlock, AutoCloseable
PtSymbolBlock is the PyTorch implementation of SymbolBlock.
You can create a PtSymbolBlock using Model.load(java.nio.file.Path, String).
Field Summary
Fields inherited from class ai.djl.nn.AbstractBaseBlock
inputNames, inputShapes, outputDataTypes, version
Constructor Summary
Constructors
Constructor | Description
PtSymbolBlock(PtNDManager manager) | Constructs an empty PtSymbolBlock.
PtSymbolBlock(PtNDManager manager, long handle) | Constructs a PtSymbolBlock.
Method Summary
Modifier and Type | Method | Description
void | close() |
ai.djl.util.PairList<String, ai.djl.ndarray.types.Shape> | describeInput |
ai.djl.util.PairList<String, ai.djl.ndarray.types.Shape> | describeOutput |
 | forward | Runs the forward pass of this PyTorch module.
protected ai.djl.ndarray.NDList | forwardInternal(ai.djl.training.ParameterStore parameterStore, ai.djl.ndarray.NDList inputs, boolean training, ai.djl.util.PairList<String, Object> params) |
ai.djl.nn.ParameterList | getDirectParameters() |
 | getHandle | Gets the native PyTorch model pointer.
ai.djl.ndarray.types.Shape[] | getOutputShapes(ai.djl.ndarray.types.Shape[] inputShapes) |
ai.djl.ndarray.types.Shape[] | getOutputShapes(ai.djl.ndarray.types.Shape[] inputShapes, ai.djl.ndarray.types.DataType[] dataTypes) |
void | loadParameters(ai.djl.ndarray.NDManager manager, DataInputStream is) |
void | saveParameters |
Methods inherited from class ai.djl.nn.AbstractSymbolBlock
getChildren
Methods inherited from class ai.djl.nn.AbstractBaseBlock
beforeInitialize, cast, clear, forward, forward, forwardInternal, getInputShapes, getOutputDataTypes, getParameters, initialize, initializeChildBlocks, isInitialized, loadMetadata, prepare, readInputShapes, saveInputShapes, saveMetadata, setInitializer, setInitializer, setInitializer, toString
Methods inherited from class java.lang.Object
clone, equals, finalize, getClass, hashCode, notify, notifyAll, wait, wait, wait
Methods inherited from interface ai.djl.nn.Block
cast, clear, forward, forward, forward, freezeParameters, freezeParameters, getCustomMetadata, getInputShapes, getOutputDataTypes, getParameters, initialize, isInitialized, setInitializer, setInitializer, setInitializer
Methods inherited from interface ai.djl.nn.SymbolBlock
removeLastBlock
Constructor Details
PtSymbolBlock
Constructs a PtSymbolBlock.
You can create a PtSymbolBlock using Model.load(java.nio.file.Path, String).
- Parameters:
manager - the manager to use for the block
handle - the module handle
PtSymbolBlock
Constructs an empty PtSymbolBlock.
- Parameters:
manager - the manager to use for the block
Method Details
close
public void close()
- Specified by:
close in interface AutoCloseable
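Because PtSymbolBlock implements AutoCloseable and wraps a native module handle, try-with-resources is the natural way to make sure that handle is released even if inference throws. The stdlib-only sketch below uses a hypothetical stand-in class, NativeModuleSketch, rather than the DJL API: a static counter simulates native allocations, the no-argument constructor mirrors an empty block, and close() is written to be safe to call more than once. That last property belongs to this sketch and is not something the PtSymbolBlock documentation states.

```java
import java.util.concurrent.atomic.AtomicLong;

// Hypothetical stand-in for a block that owns a native module handle; this is not DJL code.
class NativeModuleSketch implements AutoCloseable {
    // Simulated number of native modules that are currently allocated.
    static final AtomicLong LIVE = new AtomicLong();

    // 0 means "no native module", as for an empty block.
    private final AtomicLong handle = new AtomicLong();

    // Empty block: nothing is allocated yet.
    NativeModuleSketch() {}

    // Wraps an existing (simulated) native handle and takes ownership of it.
    NativeModuleSketch(long handle) {
        if (handle != 0) {
            this.handle.set(handle);
            LIVE.incrementAndGet();
        }
    }

    long getHandle() {
        return handle.get();
    }

    @Override
    public void close() {
        // getAndSet(0) releases the handle at most once, even if close() is called repeatedly.
        if (handle.getAndSet(0) != 0) {
            LIVE.decrementAndGet();
        }
    }
}

class CloseDemo {
    public static void main(String[] args) {
        try (NativeModuleSketch block = new NativeModuleSketch(42L)) {
            System.out.println("handle " + block.getHandle() + ", live " + NativeModuleSketch.LIVE.get());
        }
        // The resource was closed automatically when the try block exited.
        System.out.println("live after try: " + NativeModuleSketch.LIVE.get());
    }
}
```

Running CloseDemo prints the live count as 1 inside the try block and 0 after it, showing that the release happens without an explicit close() call.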
forward
Runs the forward pass of this PyTorch module.
forwardInternal
protected ai.djl.ndarray.NDList forwardInternal(ai.djl.training.ParameterStore parameterStore, ai.djl.ndarray.NDList inputs, boolean training, ai.djl.util.PairList<String, Object> params)
- Specified by:
forwardInternal in class ai.djl.nn.AbstractBaseBlock
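The signatures on this page suggest a template-method split: the public forward methods inherited from AbstractBaseBlock are what callers invoke, while the protected forwardInternal, which AbstractBaseBlock specifies, is the hook this class supplies. The stdlib-only sketch below illustrates that split with hypothetical classes (BlockSketch, DoublingBlock) and double[] in place of NDList; it does not reproduce DJL's actual bookkeeping.

```java
import java.util.Arrays;

// Hypothetical, simplified illustration of a public forward() that delegates to a
// protected forwardInternal() hook; this is not DJL's AbstractBaseBlock.
abstract class BlockSketch {
    private int calls;

    // Public entry point: shared checks and bookkeeping live here, then the subclass hook runs.
    public final double[] forward(double[] inputs, boolean training) {
        if (inputs == null) {
            throw new IllegalArgumentException("inputs must not be null");
        }
        calls++;
        return forwardInternal(inputs, training);
    }

    int callCount() {
        return calls;
    }

    // Subclasses put the actual computation here.
    protected abstract double[] forwardInternal(double[] inputs, boolean training);
}

// Toy block that doubles every element; stands in for a real module.
class DoublingBlock extends BlockSketch {
    @Override
    protected double[] forwardInternal(double[] inputs, boolean training) {
        double[] out = new double[inputs.length];
        for (int i = 0; i < inputs.length; i++) {
            out[i] = 2 * inputs[i];
        }
        return out;
    }
}

class ForwardDemo {
    public static void main(String[] args) {
        BlockSketch block = new DoublingBlock();
        System.out.println(Arrays.toString(block.forward(new double[] {1, 2, 3}, false)));
    }
}
```

Making forward final keeps the shared checks in one place, so a subclass only has to describe the computation itself.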
describeInput
- Specified by:
describeInput in interface ai.djl.nn.Block
- Overrides:
describeInput in class ai.djl.nn.AbstractBaseBlock
getDirectParameters
public ai.djl.nn.ParameterList getDirectParameters()
- Specified by:
getDirectParameters in interface ai.djl.nn.Block
describeOutput
- Specified by:
describeOutput in interface ai.djl.nn.SymbolBlock
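describeInput and describeOutput return an ai.djl.util.PairList<String, Shape>, that is, an ordered list of (name, shape) pairs. One possible use, when such a description is available, is to check inputs before running the model so that a mismatch surfaces as a readable message rather than as an error from inside the model. The stdlib-only sketch below substitutes List<Map.Entry<String, long[]>> for PairList<String, Shape>; the class InputCheckSketch, the input name "input0", and the use of -1 as an "any size" wildcard are all hypothetical.

```java
import java.util.Arrays;
import java.util.List;
import java.util.Map;

// Hypothetical sketch: check actual input shapes against an ordered list of (name, shape)
// pairs, the kind of information describeInput() returns as a PairList<String, Shape>.
class InputCheckSketch {
    // In this sketch, -1 marks a dimension of any size (an assumption, not a DJL convention).
    static String check(List<Map.Entry<String, long[]>> expected, List<long[]> actual) {
        if (expected.size() != actual.size()) {
            return "expected " + expected.size() + " inputs, got " + actual.size();
        }
        for (int i = 0; i < expected.size(); i++) {
            long[] want = expected.get(i).getValue();
            long[] got = actual.get(i);
            boolean ok = want.length == got.length;
            for (int d = 0; ok && d < want.length; d++) {
                ok = want[d] == -1 || want[d] == got[d];
            }
            if (!ok) {
                return expected.get(i).getKey() + ": expected " + Arrays.toString(want)
                        + ", got " + Arrays.toString(got);
            }
        }
        return "ok";
    }

    public static void main(String[] args) {
        // Hypothetical description with a single input named "input0".
        List<Map.Entry<String, long[]>> desc =
                List.of(Map.entry("input0", new long[] {-1, 3, 224, 224}));
        System.out.println(check(desc, List.of(new long[] {1, 3, 224, 224})));
        System.out.println(check(desc, List.of(new long[] {1, 1, 28, 28})));
    }
}
```

Comparing position by position relies on the pairs being ordered, which is why an ordered pair list fits this job better than an unordered map.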
getOutputShapes
public ai.djl.ndarray.types.Shape[] getOutputShapes(ai.djl.ndarray.types.Shape[] inputShapes)
- Specified by:
getOutputShapes in interface ai.djl.nn.Block
- Overrides:
getOutputShapes in class ai.djl.nn.AbstractSymbolBlock
getOutputShapes
public ai.djl.ndarray.types.Shape[] getOutputShapes(ai.djl.ndarray.types.Shape[] inputShapes, ai.djl.ndarray.types.DataType[] dataTypes)
- Specified by:
getOutputShapes in interface ai.djl.nn.Block
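A loaded TorchScript module generally does not declare its output shapes the way a hand-built layer does. One general way to answer a getOutputShapes query for such an opaque model is to run it once on a dummy input of the requested shapes and read the shape of the result. This page does not say whether PtSymbolBlock does exactly this, so treat the sketch below as an illustration of the technique only; TensorSketch, ShapeProbeSketch, and the row-sum module are hypothetical stand-ins built on the Java standard library.

```java
import java.util.Arrays;
import java.util.function.Function;

// Minimal hypothetical tensor: a shape plus row-major data.
final class TensorSketch {
    final long[] shape;
    final double[] data;

    TensorSketch(long[] shape, double[] data) {
        this.shape = shape;
        this.data = data;
    }

    // Creates a tensor of the given shape filled with 1.0.
    static TensorSketch ones(long[] shape) {
        long size = 1;
        for (long d : shape) {
            size *= d;
        }
        double[] data = new double[Math.toIntExact(size)];
        Arrays.fill(data, 1.0);
        return new TensorSketch(shape.clone(), data);
    }
}

class ShapeProbeSketch {
    // Infers an opaque module's output shape by running it once on a dummy input of ones.
    static long[] probeOutputShape(Function<TensorSketch, TensorSketch> module, long[] inputShape) {
        return module.apply(TensorSketch.ones(inputShape)).shape;
    }

    public static void main(String[] args) {
        // Hypothetical module: sums each row of a 2-D input, mapping (rows, cols) to (rows).
        Function<TensorSketch, TensorSketch> rowSum = t -> {
            int rows = Math.toIntExact(t.shape[0]);
            int cols = Math.toIntExact(t.shape[1]);
            double[] out = new double[rows];
            for (int r = 0; r < rows; r++) {
                for (int c = 0; c < cols; c++) {
                    out[r] += t.data[r * cols + c];
                }
            }
            return new TensorSketch(new long[] {rows}, out);
        };
        System.out.println(Arrays.toString(probeOutputShape(rowSum, new long[] {4, 3})));
    }
}
```

A probe like this costs a full forward pass, and the dummy input must have the element type the model expects as well as the right shape; the element type is the extra information the DataType[] overload carries.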
saveParameters
- Specified by:
saveParameters in interface ai.djl.nn.Block
- Overrides:
saveParameters in class ai.djl.nn.AbstractBaseBlock
- Throws:
IOException
loadParameters
public void loadParameters(ai.djl.ndarray.NDManager manager, DataInputStream is) throws IOException, ai.djl.MalformedModelException
- Specified by:
loadParameters in interface ai.djl.nn.Block
- Overrides:
loadParameters in class ai.djl.nn.AbstractBaseBlock
- Throws:
IOException
ai.djl.MalformedModelException
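loadParameters reads from a DataInputStream and can throw ai.djl.MalformedModelException when the data is not what it expects. The stdlib-only sketch below shows the general shape of such a round trip: a version byte followed by a length-prefixed payload, with unrecognized data rejected on read. The layout, the version number, ParamStreamSketch, and MalformedSketchException are all hypothetical and do not describe the format DJL actually writes.

```java
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.io.IOException;
import java.util.Arrays;

// Hypothetical stand-in for a "malformed model" error; not ai.djl.MalformedModelException.
class MalformedSketchException extends Exception {
    MalformedSketchException(String message) {
        super(message);
    }
}

class ParamStreamSketch {
    static final byte VERSION = 1; // hypothetical format version

    // Writes a version byte followed by a length-prefixed blob of module bytes.
    static void save(DataOutputStream os, byte[] moduleBytes) throws IOException {
        os.writeByte(VERSION);
        os.writeInt(moduleBytes.length);
        os.write(moduleBytes);
    }

    // Reads the blob back, rejecting unknown versions and impossible lengths.
    static byte[] load(DataInputStream is) throws IOException, MalformedSketchException {
        byte version = is.readByte();
        if (version != VERSION) {
            throw new MalformedSketchException("Unsupported version: " + version);
        }
        int length = is.readInt();
        if (length < 0) {
            throw new MalformedSketchException("Negative payload length: " + length);
        }
        byte[] bytes = new byte[length];
        is.readFully(bytes);
        return bytes;
    }

    public static void main(String[] args) throws Exception {
        ByteArrayOutputStream buffer = new ByteArrayOutputStream();
        try (DataOutputStream os = new DataOutputStream(buffer)) {
            save(os, new byte[] {10, 20, 30});
        }
        try (DataInputStream is = new DataInputStream(new ByteArrayInputStream(buffer.toByteArray()))) {
            System.out.println(Arrays.toString(load(is)));
        }
    }
}
```

Writing the version first lets a reader fail fast with a specific error instead of misinterpreting bytes from an incompatible writer.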
getHandle
Gets the native PyTorch model pointer.
- Returns:
the pointer