Package ai.djl.pytorch.engine
Class PtSymbolBlock
- java.lang.Object
  - ai.djl.nn.AbstractBaseBlock
    - ai.djl.nn.AbstractSymbolBlock
      - ai.djl.pytorch.engine.PtSymbolBlock

All Implemented Interfaces:
ai.djl.nn.Block, ai.djl.nn.SymbolBlock, java.lang.AutoCloseable
public class PtSymbolBlock extends ai.djl.nn.AbstractSymbolBlock implements java.lang.AutoCloseable
PtSymbolBlock is the PyTorch implementation of SymbolBlock. You can create a PtSymbolBlock using Model.load(java.nio.file.Path, String).
Constructor Summary
- PtSymbolBlock(PtNDManager manager): Constructs an empty PtSymbolBlock.
- PtSymbolBlock(PtNDManager manager, long handle): Constructs a PtSymbolBlock.
Method Summary
Modifier and Type, Method, Description:
- void close()
- ai.djl.util.PairList<java.lang.String,ai.djl.ndarray.types.Shape> describeInput()
- ai.djl.util.PairList<java.lang.String,ai.djl.ndarray.types.Shape> describeOutput()
- IValue forward(IValue... inputs): Runs the forward pass of this PyTorch module.
- protected ai.djl.ndarray.NDList forwardInternal(ai.djl.training.ParameterStore parameterStore, ai.djl.ndarray.NDList inputs, boolean training, ai.djl.util.PairList<java.lang.String,java.lang.Object> params)
- ai.djl.nn.ParameterList getDirectParameters()
- java.lang.Long getHandle(): Gets the native PyTorch model pointer.
- ai.djl.ndarray.types.Shape[] getOutputShapes(ai.djl.ndarray.types.Shape[] inputShapes)
- ai.djl.ndarray.types.Shape[] getOutputShapes(ai.djl.ndarray.types.Shape[] inputShapes, ai.djl.ndarray.types.DataType[] dataTypes)
- void loadParameters(ai.djl.ndarray.NDManager manager, java.io.DataInputStream is)
- void saveParameters(java.io.DataOutputStream os)
Methods inherited from class ai.djl.nn.AbstractBaseBlock
beforeInitialize, cast, clear, forward, forward, forwardInternal, getInputShapes, getOutputDataTypes, getParameters, initialize, initializeChildBlocks, isInitialized, loadMetadata, prepare, readInputShapes, saveInputShapes, saveMetadata, setInitializer, setInitializer, setInitializer, toString
Methods inherited from class java.lang.Object
clone, equals, finalize, getClass, hashCode, notify, notifyAll, wait, wait, wait
Constructor Detail
PtSymbolBlock
public PtSymbolBlock(PtNDManager manager, long handle)
Constructs a PtSymbolBlock. You can create a PtSymbolBlock using Model.load(java.nio.file.Path, String).
- Parameters:
  manager - the manager to use for the block
  handle - the module handle
PtSymbolBlock
public PtSymbolBlock(PtNDManager manager)
Constructs an empty PtSymbolBlock.
- Parameters:
  manager - the manager to use for the block
Method Detail
close
public void close()
- Specified by: close in interface java.lang.AutoCloseable
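Because PtSymbolBlock implements java.lang.AutoCloseable, it can be released with try-with-resources. The JDK-only sketch below illustrates that pattern for an object that owns a native pointer, like the one getHandle() exposes. NativeHandleOwner and its LongConsumer releaser are hypothetical stand-ins rather than DJL classes, and the idempotent close() and the IllegalStateException thrown after release are assumptions, not documented PtSymbolBlock behavior.

```java
import java.util.concurrent.atomic.AtomicReference;
import java.util.function.LongConsumer;

/** Hypothetical stand-in for a block that owns a native pointer (not a DJL class). */
final class NativeHandleOwner implements AutoCloseable {
    // Holds the native pointer; becomes null once the resource has been released.
    private final AtomicReference<Long> handle;
    // Simulates the native "free" call; a real engine would call into JNI here.
    private final LongConsumer releaser;

    NativeHandleOwner(long handle, LongConsumer releaser) {
        this.handle = new AtomicReference<>(handle);
        this.releaser = releaser;
    }

    /** Returns the native pointer, failing if the object has already been closed. */
    Long getHandle() {
        Long h = handle.get();
        if (h == null) {
            throw new IllegalStateException("native handle has been released");
        }
        return h;
    }

    /** Frees the native pointer at most once, even if close() is called repeatedly. */
    @Override
    public void close() {
        Long h = handle.getAndSet(null);
        if (h != null) {
            releaser.accept(h);
        }
    }
}
```

Typical use, assuming some pointer `ptr` and a `releaser` callback: `try (NativeHandleOwner owner = new NativeHandleOwner(ptr, releaser)) { ... }`. Swapping the reference to null with `getAndSet` turns a second `close()`, or one racing with another thread, into a no-op instead of a double free.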
forwardInternal
protected ai.djl.ndarray.NDList forwardInternal(ai.djl.training.ParameterStore parameterStore, ai.djl.ndarray.NDList inputs, boolean training, ai.djl.util.PairList<java.lang.String,java.lang.Object> params)
- Specified by: forwardInternal in class ai.djl.nn.AbstractBaseBlock
describeInput
public ai.djl.util.PairList<java.lang.String,ai.djl.ndarray.types.Shape> describeInput()
- Specified by: describeInput in interface ai.djl.nn.Block
- Overrides: describeInput in class ai.djl.nn.AbstractBaseBlock
getDirectParameters
public ai.djl.nn.ParameterList getDirectParameters()
- Specified by: getDirectParameters in interface ai.djl.nn.Block
describeOutput
public ai.djl.util.PairList<java.lang.String,ai.djl.ndarray.types.Shape> describeOutput()
- Specified by: describeOutput in interface ai.djl.nn.SymbolBlock
getOutputShapes
public ai.djl.ndarray.types.Shape[] getOutputShapes(ai.djl.ndarray.types.Shape[] inputShapes)
- Specified by: getOutputShapes in interface ai.djl.nn.Block
- Overrides: getOutputShapes in class ai.djl.nn.AbstractSymbolBlock
getOutputShapes
public ai.djl.ndarray.types.Shape[] getOutputShapes(ai.djl.ndarray.types.Shape[] inputShapes, ai.djl.ndarray.types.DataType[] dataTypes)
- Specified by: getOutputShapes in interface ai.djl.nn.Block
saveParameters
public void saveParameters(java.io.DataOutputStream os) throws java.io.IOException
- Specified by: saveParameters in interface ai.djl.nn.Block
- Overrides: saveParameters in class ai.djl.nn.AbstractBaseBlock
- Throws: java.io.IOException
loadParameters
public void loadParameters(ai.djl.ndarray.NDManager manager, java.io.DataInputStream is) throws java.io.IOException, ai.djl.MalformedModelException
- Specified by: loadParameters in interface ai.djl.nn.Block
- Overrides: loadParameters in class ai.djl.nn.AbstractBaseBlock
- Throws: java.io.IOException, ai.djl.MalformedModelException
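saveParameters and loadParameters serialize the block through java.io.DataOutputStream and java.io.DataInputStream. The byte layout PtSymbolBlock actually writes is not described on this page, so the sketch below is purely hypothetical: it writes a one-byte version header followed by a length-prefixed byte array (standing in for a serialized module) and reads it back, rejecting unknown versions. ParamCodec, FORMAT_VERSION, writeBlob, readBlob and roundTrip are illustrative names, and a plain IOException stands in for ai.djl.MalformedModelException so the code runs on the JDK alone.

```java
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.io.IOException;
import java.io.UncheckedIOException;

/** Hypothetical parameter codec; NOT the byte layout PtSymbolBlock actually uses. */
final class ParamCodec {
    // Illustrative version tag written before the payload.
    static final byte FORMAT_VERSION = 1;

    private ParamCodec() {}

    /** Writes a version byte, the payload length, then the payload bytes. */
    static void writeBlob(DataOutputStream os, byte[] payload) throws IOException {
        os.writeByte(FORMAT_VERSION);
        os.writeInt(payload.length);
        os.write(payload);
        os.flush();
    }

    /** Reads what writeBlob wrote, rejecting versions and lengths it cannot handle. */
    static byte[] readBlob(DataInputStream is) throws IOException {
        byte version = is.readByte();
        if (version != FORMAT_VERSION) {
            // A real block would report this as ai.djl.MalformedModelException.
            throw new IOException("Unsupported encoding version: " + version);
        }
        int length = is.readInt();
        if (length < 0) {
            throw new IOException("Negative payload length: " + length);
        }
        byte[] payload = new byte[length];
        is.readFully(payload); // throws EOFException if the stream is truncated
        return payload;
    }

    /** Writes a payload to an in-memory stream and reads it back. */
    static byte[] roundTrip(byte[] payload) {
        ByteArrayOutputStream buffer = new ByteArrayOutputStream();
        try (DataOutputStream os = new DataOutputStream(buffer)) {
            writeBlob(os, payload);
        } catch (IOException e) {
            throw new UncheckedIOException(e);
        }
        try (DataInputStream is =
                new DataInputStream(new ByteArrayInputStream(buffer.toByteArray()))) {
            return readBlob(is);
        } catch (IOException e) {
            throw new UncheckedIOException(e);
        }
    }
}
```

A version header lets the reader reject data written by an incompatible writer up front, instead of misinterpreting the bytes that follow.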
getHandle
public java.lang.Long getHandle()
Gets the native PyTorch model pointer.
- Returns:
  the pointer