Package ai.djl.pytorch.engine
Class PtSymbolBlock
java.lang.Object
ai.djl.nn.AbstractBaseBlock
ai.djl.nn.AbstractSymbolBlock
ai.djl.pytorch.engine.PtSymbolBlock
- All Implemented Interfaces:
ai.djl.nn.Block, ai.djl.nn.SymbolBlock, AutoCloseable
PtSymbolBlock is the PyTorch implementation of SymbolBlock.
You can create a PtSymbolBlock using Model.load(java.nio.file.Path, String).
-
Field Summary
Fields inherited from class ai.djl.nn.AbstractBaseBlock
inputNames, inputShapes, outputDataTypes, version
-
Constructor Summary
Constructors
PtSymbolBlock(PtNDManager manager): Constructs an empty PtSymbolBlock.
PtSymbolBlock(PtNDManager manager, long handle): Constructs a PtSymbolBlock.
-
Method Summary
Modifier and Type, Method: Description
void close()
ai.djl.util.PairList<String,ai.djl.ndarray.types.Shape> describeInput
ai.djl.util.PairList<String,ai.djl.ndarray.types.Shape> describeOutput
forward: Runs the forward pass of this PyTorch module.
protected ai.djl.ndarray.NDList forwardInternal(ai.djl.training.ParameterStore parameterStore, ai.djl.ndarray.NDList inputs, boolean training, ai.djl.util.PairList<String, Object> params)
ai.djl.nn.ParameterList getDirectParameters()
getHandle: Gets the native PyTorch model pointer.
ai.djl.ndarray.types.Shape[] getOutputShapes(ai.djl.ndarray.types.Shape[] inputShapes)
ai.djl.ndarray.types.Shape[] getOutputShapes(ai.djl.ndarray.types.Shape[] inputShapes, ai.djl.ndarray.types.DataType[] dataTypes)
void loadParameters(ai.djl.ndarray.NDManager manager, DataInputStream is)
void saveParameters
Methods inherited from class ai.djl.nn.AbstractSymbolBlock
getChildren
Methods inherited from class ai.djl.nn.AbstractBaseBlock
beforeInitialize, cast, clear, forward, forward, forwardInternal, getInputShapes, getOutputDataTypes, getParameters, initialize, initializeChildBlocks, isInitialized, loadMetadata, prepare, readInputShapes, saveInputShapes, saveMetadata, setInitializer, setInitializer, setInitializer, toString
Methods inherited from class java.lang.Object
clone, equals, finalize, getClass, hashCode, notify, notifyAll, wait, wait, wait
Methods inherited from interface ai.djl.nn.Block
cast, clear, forward, forward, forward, freezeParameters, freezeParameters, getCustomMetadata, getInputShapes, getOutputDataTypes, getParameters, initialize, isInitialized, setInitializer, setInitializer, setInitializer
Methods inherited from interface ai.djl.nn.SymbolBlock
removeLastBlock
-
Constructor Details
-
PtSymbolBlock
Constructs a PtSymbolBlock.
You can create a PtSymbolBlock using Model.load(java.nio.file.Path, String).
- Parameters:
manager - the manager to use for the block
handle - the module handle
-
PtSymbolBlock
Constructs an empty PtSymbolBlock.
- Parameters:
manager - the manager to use for the block
-
-
Method Details
-
close
public void close()
- Specified by:
close in interface AutoCloseable
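The AutoCloseable contract, together with the handle-taking constructor and getHandle, suggests an ownership pattern along the following lines. This is a minimal sketch using only the JDK: HandleOwner, freeCount, and the exception message are hypothetical stand-ins chosen for illustration, not DJL classes and not the actual PtSymbolBlock implementation.

```java
import java.util.concurrent.atomic.AtomicReference;

/** Hypothetical stand-in for a block that owns a native handle; not DJL code. */
final class HandleOwner implements AutoCloseable {
    // Counts how often the simulated native release ran (illustration only).
    static int freeCount = 0;

    // null means "no native module attached": an empty block or one already closed.
    private final AtomicReference<Long> handle;

    /** Mirrors the idea of the empty constructor: no native module yet. */
    HandleOwner() {
        this.handle = new AtomicReference<>();
    }

    /** Mirrors the idea of the handle-taking constructor. */
    HandleOwner(long handle) {
        this.handle = new AtomicReference<>(Long.valueOf(handle));
    }

    /** Returns the handle, or fails if none is attached. */
    long getHandle() {
        Long h = handle.get();
        if (h == null) {
            throw new IllegalStateException("no native handle (empty or closed)");
        }
        return h;
    }

    @Override
    public void close() {
        // getAndSet hands the non-null handle to exactly one caller.
        Long h = handle.getAndSet(null);
        if (h != null) {
            freeCount++; // a real implementation would release native memory here
        }
    }
}

final class HandleOwnerDemo {
    public static void main(String[] args) {
        // try-with-resources calls close() automatically at the end of the block.
        try (HandleOwner block = new HandleOwner(42L)) {
            System.out.println("handle = " + block.getHandle());
        }
        System.out.println("released " + HandleOwner.freeCount + " time(s)");
    }
}
```

In this sketch a second call to close() is harmless, because the handle has already been swapped out for null and the release step is skipped.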
-
forward
Runs the forward pass of this PyTorch module.
-
forwardInternal
protected ai.djl.ndarray.NDList forwardInternal(ai.djl.training.ParameterStore parameterStore, ai.djl.ndarray.NDList inputs, boolean training, ai.djl.util.PairList<String, Object> params)
- Specified by:
forwardInternal in class ai.djl.nn.AbstractBaseBlock
-
describeInput
- Specified by:
describeInput in interface ai.djl.nn.Block
- Overrides:
describeInput in class ai.djl.nn.AbstractBaseBlock
-
getDirectParameters
public ai.djl.nn.ParameterList getDirectParameters()
- Specified by:
getDirectParameters in interface ai.djl.nn.Block
-
describeOutput
- Specified by:
describeOutput in interface ai.djl.nn.SymbolBlock
-
getOutputShapes
public ai.djl.ndarray.types.Shape[] getOutputShapes(ai.djl.ndarray.types.Shape[] inputShapes)
- Specified by:
getOutputShapes in interface ai.djl.nn.Block
- Overrides:
getOutputShapes in class ai.djl.nn.AbstractSymbolBlock
-
getOutputShapes
public ai.djl.ndarray.types.Shape[] getOutputShapes(ai.djl.ndarray.types.Shape[] inputShapes, ai.djl.ndarray.types.DataType[] dataTypes)
- Specified by:
getOutputShapes in interface ai.djl.nn.Block
-
saveParameters
- Specified by:
saveParameters in interface ai.djl.nn.Block
- Overrides:
saveParameters in class ai.djl.nn.AbstractBaseBlock
- Throws:
IOException
-
loadParameters
public void loadParameters(ai.djl.ndarray.NDManager manager, DataInputStream is) throws IOException, ai.djl.MalformedModelException
- Specified by:
loadParameters in interface ai.djl.nn.Block
- Overrides:
loadParameters in class ai.djl.nn.AbstractBaseBlock
- Throws:
IOException
ai.djl.MalformedModelException
-
getHandle
Gets the native PyTorch model pointer.
- Returns:
the pointer
-