Package ai.djl.pytorch.engine
Class PtSymbolBlock
- java.lang.Object
  - ai.djl.nn.AbstractBaseBlock
    - ai.djl.nn.AbstractSymbolBlock
      - ai.djl.pytorch.engine.PtSymbolBlock
All Implemented Interfaces:
ai.djl.nn.Block, ai.djl.nn.SymbolBlock, java.lang.AutoCloseable
public class PtSymbolBlock extends ai.djl.nn.AbstractSymbolBlock implements java.lang.AutoCloseable
PtSymbolBlock is the PyTorch implementation of SymbolBlock. You can create a PtSymbolBlock using Model.load(java.nio.file.Path, String).
Constructor Summary
- PtSymbolBlock(PtNDManager manager): Constructs an empty PtSymbolBlock.
- PtSymbolBlock(PtNDManager manager, long handle): Constructs a PtSymbolBlock.
Method Summary
- void close()
- ai.djl.util.PairList<java.lang.String,ai.djl.ndarray.types.Shape> describeInput()
- ai.djl.util.PairList<java.lang.String,ai.djl.ndarray.types.Shape> describeOutput()
- IValue forward(IValue... inputs): Runs the forward pass of this PyTorch module.
- protected ai.djl.ndarray.NDList forwardInternal(ai.djl.training.ParameterStore parameterStore, ai.djl.ndarray.NDList inputs, boolean training, ai.djl.util.PairList<java.lang.String,java.lang.Object> params)
- ai.djl.nn.ParameterList getDirectParameters()
- java.lang.Long getHandle(): Gets the native PyTorch model pointer.
- ai.djl.ndarray.types.Shape[] getOutputShapes(ai.djl.ndarray.types.Shape[] inputShapes)
- ai.djl.ndarray.types.Shape[] getOutputShapes(ai.djl.ndarray.types.Shape[] inputShapes, ai.djl.ndarray.types.DataType[] dataTypes)
- void loadParameters(ai.djl.ndarray.NDManager manager, java.io.DataInputStream is)
- void saveParameters(java.io.DataOutputStream os)
Methods inherited from class ai.djl.nn.AbstractBaseBlock
beforeInitialize, cast, clear, forward, forward, forwardInternal, getInputShapes, getOutputDataTypes, getParameters, initialize, initializeChildBlocks, isInitialized, loadMetadata, prepare, readInputShapes, saveInputShapes, saveMetadata, setInitializer, setInitializer, setInitializer, toString
Methods inherited from class java.lang.Object
clone, equals, finalize, getClass, hashCode, notify, notifyAll, wait, wait, wait
Constructor Detail
PtSymbolBlock
public PtSymbolBlock(PtNDManager manager, long handle)
Constructs a PtSymbolBlock. You can create a PtSymbolBlock using Model.load(java.nio.file.Path, String).
- Parameters:
  manager - the manager to use for the block
  handle - the module handle
PtSymbolBlock
public PtSymbolBlock(PtNDManager manager)
Constructs an empty PtSymbolBlock.
- Parameters:
  manager - the manager to use for the block
Method Detail
close
public void close()
- Specified by: close in interface java.lang.AutoCloseable
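Because PtSymbolBlock implements java.lang.AutoCloseable, a try-with-resources statement is a natural way to guarantee that close() runs. This page does not say what close() releases, so the sketch below uses only the JDK and a hypothetical stand-in class, NativeBlockStandIn, in place of a real PtSymbolBlock. The AtomicReference-based, idempotent close() is an assumed pattern for illustration, not DJL's documented implementation.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.atomic.AtomicReference;

public class CloseSketch {

    // Records the fake pointers that were "released"; illustration only.
    static final List<Long> RELEASED = new ArrayList<>();

    // Hypothetical stand-in for a block that owns one native module pointer (not a DJL class).
    static final class NativeBlockStandIn implements AutoCloseable {
        private final AtomicReference<Long> handle;

        NativeBlockStandIn(long pointer) {
            this.handle = new AtomicReference<>(pointer);
        }

        @Override
        public void close() {
            // getAndSet(null) makes close() idempotent: only the first call sees the pointer.
            Long pointer = handle.getAndSet(null);
            if (pointer != null) {
                RELEASED.add(pointer); // stands in for freeing the native module
            }
        }
    }

    public static void main(String[] args) {
        try (NativeBlockStandIn block = new NativeBlockStandIn(42L)) {
            // ... run inference with the block here ...
            block.close(); // an early explicit close is harmless
        } // close() runs again automatically here and does nothing
        System.out.println(RELEASED); // prints [42]
    }
}
```

Making close() safe to call twice matters when more than one owner might try to release the same object.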
forwardInternal
protected ai.djl.ndarray.NDList forwardInternal(ai.djl.training.ParameterStore parameterStore, ai.djl.ndarray.NDList inputs, boolean training, ai.djl.util.PairList<java.lang.String,java.lang.Object> params)
- Specified by: forwardInternal in class ai.djl.nn.AbstractBaseBlock
describeInput
public ai.djl.util.PairList<java.lang.String,ai.djl.ndarray.types.Shape> describeInput()
- Specified by: describeInput in interface ai.djl.nn.Block
- Overrides: describeInput in class ai.djl.nn.AbstractBaseBlock
getDirectParameters
public ai.djl.nn.ParameterList getDirectParameters()
- Specified by: getDirectParameters in interface ai.djl.nn.Block
describeOutput
public ai.djl.util.PairList<java.lang.String,ai.djl.ndarray.types.Shape> describeOutput()
- Specified by: describeOutput in interface ai.djl.nn.SymbolBlock
getOutputShapes
public ai.djl.ndarray.types.Shape[] getOutputShapes(ai.djl.ndarray.types.Shape[] inputShapes)
- Specified by: getOutputShapes in interface ai.djl.nn.Block
- Overrides: getOutputShapes in class ai.djl.nn.AbstractSymbolBlock
getOutputShapes
public ai.djl.ndarray.types.Shape[] getOutputShapes(ai.djl.ndarray.types.Shape[] inputShapes, ai.djl.ndarray.types.DataType[] dataTypes)
- Specified by: getOutputShapes in interface ai.djl.nn.Block
saveParameters
public void saveParameters(java.io.DataOutputStream os) throws java.io.IOException
- Specified by: saveParameters in interface ai.djl.nn.Block
- Overrides: saveParameters in class ai.djl.nn.AbstractBaseBlock
- Throws: java.io.IOException
loadParameters
public void loadParameters(ai.djl.ndarray.NDManager manager, java.io.DataInputStream is) throws java.io.IOException, ai.djl.MalformedModelException
- Specified by: loadParameters in interface ai.djl.nn.Block
- Overrides: loadParameters in class ai.djl.nn.AbstractBaseBlock
- Throws: java.io.IOException, ai.djl.MalformedModelException
getHandle
public java.lang.Long getHandle()
Gets the native PyTorch model pointer.
- Returns: the pointer
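getHandle() returns a boxed java.lang.Long rather than a primitive long. This page does not say how the method behaves after close(); the sketch below assumes one defensive convention, in which a released handle is stored as null and the accessor throws IllegalStateException instead of handing a dangling pointer to native code. HandleOwner is a hypothetical JDK-only class, not part of DJL.

```java
import java.util.concurrent.atomic.AtomicReference;

public class HandleSketch {

    // Hypothetical stand-in that exposes a native pointer as a boxed Long, like getHandle().
    static final class HandleOwner implements AutoCloseable {
        private final AtomicReference<Long> handle;

        HandleOwner(long pointer) {
            this.handle = new AtomicReference<>(pointer);
        }

        // Returns the pointer, refusing to hand out one that has already been released.
        Long getHandle() {
            Long pointer = handle.get();
            if (pointer == null) {
                throw new IllegalStateException("native handle has been released");
            }
            return pointer;
        }

        @Override
        public void close() {
            // Marks the handle as released; a real implementation would also free native memory.
            handle.set(null);
        }
    }

    public static void main(String[] args) {
        HandleOwner owner = new HandleOwner(0x1234L);
        System.out.println(Long.toHexString(owner.getHandle())); // prints 1234
        owner.close();
        try {
            owner.getHandle();
        } catch (IllegalStateException e) {
            System.out.println("after close: " + e.getMessage());
        }
    }
}
```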