public class PtSymbolBlock
extends ai.djl.nn.AbstractSymbolBlock
implements java.lang.AutoCloseable
PtSymbolBlock is the PyTorch implementation of SymbolBlock. You can create a PtSymbolBlock using Model.load(java.nio.file.Path, String).
Constructor | Description |
---|---|
PtSymbolBlock(PtNDManager manager) | Constructs an empty PtSymbolBlock. |
PtSymbolBlock(PtNDManager manager, long handle) | Constructs a PtSymbolBlock. |
Modifier and Type | Method and Description |
---|---|
void | close() |
ai.djl.util.PairList<java.lang.String,ai.djl.ndarray.types.Shape> | describeInput() |
ai.djl.util.PairList<java.lang.String,ai.djl.ndarray.types.Shape> | describeOutput() |
IValue | forward(IValue... inputs): Runs the forward pass of this PyTorch module. |
protected ai.djl.ndarray.NDList | forwardInternal(ai.djl.training.ParameterStore parameterStore, ai.djl.ndarray.NDList inputs, boolean training, ai.djl.util.PairList<java.lang.String,java.lang.Object> params) |
java.lang.Long | getHandle(): Gets the native PyTorch model pointer. |
ai.djl.ndarray.types.Shape[] | getOutputShapes(ai.djl.ndarray.types.Shape[] inputShapes) |
void | loadParameters(ai.djl.ndarray.NDManager manager, java.io.DataInputStream is) |
void | saveParameters(java.io.DataOutputStream os) |
Inherited methods: addChildBlock, addParameter, beforeInitialize, cast, clear, forward, forward, forwardInternal, getChildren, getDirectParameters, getParameters, initialize, initializeChildBlocks, isInitialized, loadMetadata, prepare, readInputShapes, saveInputShapes, saveMetadata, setInitializer, setInitializer, setInitializer, toString
public PtSymbolBlock(PtNDManager manager, long handle)
Constructs a PtSymbolBlock.
You can create a PtSymbolBlock using Model.load(java.nio.file.Path, String).
Parameters:
manager - the manager to use for the block
handle - the module handle
public PtSymbolBlock(PtNDManager manager)
Constructs an empty PtSymbolBlock.
Parameters:
manager - the manager to use for the block
public void close()
Specified by: close in interface java.lang.AutoCloseable
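Because PtSymbolBlock implements java.lang.AutoCloseable and exposes its native pointer through getHandle(), the natural pattern is to close the block once the underlying module is no longer needed. The following is a minimal, dependency-free sketch of that ownership pattern, not DJL's implementation: the NativeModuleBlock class, its AtomicReference field, the idempotent close(), and the freeNative placeholder are all hypothetical.

```java
import java.util.concurrent.atomic.AtomicReference;

// Hypothetical stand-in for a block that owns a native module handle (not DJL API).
class NativeModuleBlock implements AutoCloseable {
    // Holds the native pointer; becomes null once the block has been closed.
    private final AtomicReference<Long> handle;

    NativeModuleBlock(long handle) {
        this.handle = new AtomicReference<>(handle);
    }

    // Returns the native pointer, or throws if it has already been released.
    Long getHandle() {
        Long pointer = handle.get();
        if (pointer == null) {
            throw new IllegalStateException("native handle has been released");
        }
        return pointer;
    }

    boolean isClosed() {
        return handle.get() == null;
    }

    // Releases the native module at most once, even if close() is called repeatedly.
    @Override
    public void close() {
        Long pointer = handle.getAndSet(null);
        if (pointer != null) {
            freeNative(pointer);
        }
    }

    // Hypothetical placeholder for a JNI call that would free the native module.
    private static void freeNative(long pointer) {
        // no-op in this sketch
    }
}
```

In calling code, try-with-resources releases the handle automatically, for example try (NativeModuleBlock block = new NativeModuleBlock(pointer)) { ... }. Clearing the reference with getAndSet(null) turns a second close() into a no-op rather than a double free.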
protected ai.djl.ndarray.NDList forwardInternal(ai.djl.training.ParameterStore parameterStore, ai.djl.ndarray.NDList inputs, boolean training, ai.djl.util.PairList<java.lang.String,java.lang.Object> params)
Specified by: forwardInternal in class ai.djl.nn.AbstractBlock
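forwardInternal is a protected hook, while the inherited forward overloads are the public entry point. A minimal sketch of that template-method split follows, assuming (as the signatures suggest, not as documented here) that the inherited forward performs any setup and then delegates to forwardInternal. SketchBlock and DoublingBlock are hypothetical, and List<Double> stands in for ParameterStore, NDList, and PairList.

```java
import java.util.List;
import java.util.stream.Collectors;

// Hypothetical base class: the public forward() does setup, subclasses supply forwardInternal().
abstract class SketchBlock {
    private boolean initialized;

    // Public entry point: "initialize" lazily once, then delegate to the subclass hook.
    public final List<Double> forward(List<Double> inputs) {
        if (!initialized) {
            initialized = true; // stands in for real parameter initialization
        }
        return forwardInternal(inputs);
    }

    public boolean isInitialized() {
        return initialized;
    }

    // Subclasses implement the actual computation here.
    protected abstract List<Double> forwardInternal(List<Double> inputs);
}

// Hypothetical block whose "module" simply doubles every input value.
class DoublingBlock extends SketchBlock {
    @Override
    protected List<Double> forwardInternal(List<Double> inputs) {
        return inputs.stream().map(x -> x * 2).collect(Collectors.toList());
    }
}
```

Making forward final in the base class gives every subclass the same setup path, so a block backed by a pre-built module only has to implement the computation itself.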
public ai.djl.util.PairList<java.lang.String,ai.djl.ndarray.types.Shape> describeInput()
Specified by: describeInput in interface ai.djl.nn.Block
Overrides: describeInput in class ai.djl.nn.AbstractBlock
public ai.djl.util.PairList<java.lang.String,ai.djl.ndarray.types.Shape> describeOutput()
Specified by: describeOutput in interface ai.djl.nn.SymbolBlock
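describeInput() and describeOutput() each return a PairList that pairs a name with a Shape. The following sketch mimics that structure using only the standard library: List<Map.Entry<String, long[]>> stands in for ai.djl.util.PairList<String, Shape>, and ShapeDescriptions is a hypothetical helper that renders one "name: (d0, d1, ...)" line per entry.

```java
import java.util.List;
import java.util.Map;
import java.util.StringJoiner;

// Hypothetical helper: List<Map.Entry<String, long[]>> stands in for PairList<String, Shape>.
final class ShapeDescriptions {
    private ShapeDescriptions() {}

    // Returns one "name: (d0, d1, ...)" line per (name, shape) pair, preserving order.
    static String describe(List<Map.Entry<String, long[]>> pairs) {
        StringJoiner lines = new StringJoiner("\n");
        for (Map.Entry<String, long[]> pair : pairs) {
            StringJoiner dims = new StringJoiner(", ", "(", ")");
            for (long dim : pair.getValue()) {
                dims.add(Long.toString(dim));
            }
            lines.add(pair.getKey() + ": " + dims);
        }
        return lines.toString();
    }
}
```

For example, an entry named "data" with shape {1, 3, 224, 224} renders as data: (1, 3, 224, 224); the name and dimensions are made-up values, not taken from a real model.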
public ai.djl.ndarray.types.Shape[] getOutputShapes(ai.djl.ndarray.types.Shape[] inputShapes)
Specified by: getOutputShapes in interface ai.djl.nn.Block
Overrides: getOutputShapes in class ai.djl.nn.AbstractSymbolBlock
public void saveParameters(java.io.DataOutputStream os) throws java.io.IOException
Specified by: saveParameters in interface ai.djl.nn.Block
Overrides: saveParameters in class ai.djl.nn.AbstractBlock
Throws: java.io.IOException
public void loadParameters(ai.djl.ndarray.NDManager manager, java.io.DataInputStream is) throws java.io.IOException, ai.djl.MalformedModelException
Specified by: loadParameters in interface ai.djl.nn.Block
Overrides: loadParameters in class ai.djl.nn.AbstractBlock
Throws: java.io.IOException, ai.djl.MalformedModelException
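saveParameters writes to a java.io.DataOutputStream and loadParameters reads from a java.io.DataInputStream, so parameters round-trip through a binary stream. The following sketch shows such a round trip with a hypothetical layout (a version byte, an int length, then the float values). ParamCodec and its layout are assumptions for illustration and are not the format DJL writes. The documented loadParameters can also throw ai.djl.MalformedModelException; the sketch reports a version mismatch as a plain IOException to stay dependency-free.

```java
import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.io.IOException;

// Hypothetical parameter format: version byte, int length, then float values (not DJL's format).
final class ParamCodec {
    static final byte VERSION = 1;

    private ParamCodec() {}

    // Writes the version header followed by a length-prefixed float array.
    static void save(DataOutputStream os, float[] params) throws IOException {
        os.writeByte(VERSION);
        os.writeInt(params.length);
        for (float p : params) {
            os.writeFloat(p);
        }
    }

    // Reads data written by save(), rejecting unknown versions and negative lengths.
    static float[] load(DataInputStream is) throws IOException {
        byte version = is.readByte();
        if (version != VERSION) {
            throw new IOException("Unsupported parameter version: " + version);
        }
        int n = is.readInt();
        if (n < 0) {
            throw new IOException("Negative parameter count: " + n);
        }
        float[] params = new float[n];
        for (int i = 0; i < n; i++) {
            params[i] = is.readFloat();
        }
        return params;
    }
}
```

Checking a version byte before reading the payload lets a loader reject incompatible data early instead of silently misreading it.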
public java.lang.Long getHandle()
Gets the native PyTorch model pointer.