public class PtSymbolBlock
extends ai.djl.util.NativeResource<java.lang.Long>
implements ai.djl.nn.SymbolBlock
PtSymbolBlock is the PyTorch implementation of SymbolBlock.
You can create a PtSymbolBlock using Model.load(java.nio.file.Path, String).
Constructor and Description |
---|
PtSymbolBlock(PtNDManager manager, long handle) | Constructs a PtSymbolBlock. |
Modifier and Type | Method and Description |
---|---|
void | cast(ai.djl.ndarray.types.DataType dataType) |
void | clear() |
void | close() |
ai.djl.util.PairList<java.lang.String,ai.djl.ndarray.types.Shape> | describeInput() |
ai.djl.util.PairList<java.lang.String,ai.djl.ndarray.types.Shape> | describeOutput() |
ai.djl.ndarray.NDList | forward(ai.djl.training.ParameterStore parameterStore, ai.djl.ndarray.NDList inputs, boolean training, ai.djl.util.PairList<java.lang.String,java.lang.Object> params) |
ai.djl.nn.BlockList | getChildren() |
ai.djl.nn.ParameterList | getDirectParameters() |
ai.djl.ndarray.types.Shape[] | getOutputShapes(ai.djl.ndarray.NDManager manager, ai.djl.ndarray.types.Shape[] inputShapes) |
ai.djl.nn.ParameterList | getParameters() |
ai.djl.ndarray.types.Shape | getParameterShape(java.lang.String name, ai.djl.ndarray.types.Shape[] inputShapes) |
ai.djl.ndarray.types.Shape[] | initialize(ai.djl.ndarray.NDManager manager, ai.djl.ndarray.types.DataType dataType, ai.djl.ndarray.types.Shape... inputShapes) |
boolean | isInitialized() |
void | loadParameters(ai.djl.ndarray.NDManager manager, java.io.DataInputStream is) |
void | removeLastBlock() |
void | saveParameters(java.io.DataOutputStream os) |
void | setInitializer(ai.djl.training.initializer.Initializer initializer) |
void | setInitializer(ai.djl.training.initializer.Initializer initializer, java.lang.String paramName) |
public PtSymbolBlock(PtNDManager manager, long handle)
Constructs a PtSymbolBlock.
You can create a PtSymbolBlock using Model.load(java.nio.file.Path, String).
Parameters:
manager - the manager to use for the block
handle - the module handle
public void close()
Specified by: close in interface java.lang.AutoCloseable
Overrides: close in class ai.djl.util.NativeResource<java.lang.Long>
public void removeLastBlock()
Specified by: removeLastBlock in interface ai.djl.nn.SymbolBlock
public ai.djl.ndarray.NDList forward(ai.djl.training.ParameterStore parameterStore, ai.djl.ndarray.NDList inputs, boolean training, ai.djl.util.PairList<java.lang.String,java.lang.Object> params)
Specified by: forward in interface ai.djl.nn.Block
public void setInitializer(ai.djl.training.initializer.Initializer initializer)
Specified by: setInitializer in interface ai.djl.nn.Block
public void setInitializer(ai.djl.training.initializer.Initializer initializer, java.lang.String paramName)
Specified by: setInitializer in interface ai.djl.nn.Block
public ai.djl.ndarray.types.Shape[] initialize(ai.djl.ndarray.NDManager manager, ai.djl.ndarray.types.DataType dataType, ai.djl.ndarray.types.Shape... inputShapes)
Specified by: initialize in interface ai.djl.nn.Block
public boolean isInitialized()
Specified by: isInitialized in interface ai.djl.nn.Block
public void cast(ai.djl.ndarray.types.DataType dataType)
Specified by: cast in interface ai.djl.nn.Block
public void clear()
Specified by: clear in interface ai.djl.nn.Block
public ai.djl.util.PairList<java.lang.String,ai.djl.ndarray.types.Shape> describeInput()
Specified by: describeInput in interface ai.djl.nn.Block
public ai.djl.util.PairList<java.lang.String,ai.djl.ndarray.types.Shape> describeOutput()
Specified by: describeOutput in interface ai.djl.nn.SymbolBlock
public ai.djl.nn.BlockList getChildren()
Specified by: getChildren in interface ai.djl.nn.Block
public ai.djl.nn.ParameterList getDirectParameters()
Specified by: getDirectParameters in interface ai.djl.nn.Block
public ai.djl.nn.ParameterList getParameters()
Specified by: getParameters in interface ai.djl.nn.Block
public ai.djl.ndarray.types.Shape getParameterShape(java.lang.String name, ai.djl.ndarray.types.Shape[] inputShapes)
Specified by: getParameterShape in interface ai.djl.nn.Block
public ai.djl.ndarray.types.Shape[] getOutputShapes(ai.djl.ndarray.NDManager manager, ai.djl.ndarray.types.Shape[] inputShapes)
Specified by: getOutputShapes in interface ai.djl.nn.Block
public void saveParameters(java.io.DataOutputStream os)
Specified by: saveParameters in interface ai.djl.nn.Block
public void loadParameters(ai.djl.ndarray.NDManager manager, java.io.DataInputStream is)
Specified by: loadParameters in interface ai.djl.nn.Block