public class PtSymbolBlock extends NativeResource implements ai.djl.nn.SymbolBlock
`PtSymbolBlock` is the PyTorch implementation of `SymbolBlock`.

You can create a `PtSymbolBlock` using `Model.load(java.nio.file.Path, String)`.

Fields inherited from class `NativeResource`: `handle`
| Constructor and Description |
|---|
| `PtSymbolBlock(PtNDManager manager, Pointer handle)` — Constructs a `PtSymbolBlock`. |
| Modifier and Type | Method and Description |
|---|---|
| `void` | `cast(ai.djl.ndarray.types.DataType dataType)` |
| `void` | `clear()` |
| `void` | `close()` |
| `ai.djl.util.PairList<java.lang.String,ai.djl.ndarray.types.Shape>` | `describeInput()` |
| `ai.djl.ndarray.NDList` | `forward(ai.djl.training.ParameterStore parameterStore, ai.djl.ndarray.NDList inputs, ai.djl.util.PairList<java.lang.String,java.lang.Object> params)` |
| `ai.djl.nn.BlockList` | `getChildren()` |
| `java.util.List<ai.djl.nn.Parameter>` | `getDirectParameters()` |
| `ai.djl.ndarray.types.Shape[]` | `getOutputShapes(ai.djl.ndarray.NDManager manager, ai.djl.ndarray.types.Shape[] inputShapes)` |
| `ai.djl.nn.ParameterList` | `getParameters()` |
| `ai.djl.ndarray.types.Shape` | `getParameterShape(java.lang.String name, ai.djl.ndarray.types.Shape[] inputShapes)` |
| `ai.djl.ndarray.types.Shape[]` | `initialize(ai.djl.ndarray.NDManager manager, ai.djl.ndarray.types.DataType dataType, ai.djl.ndarray.types.Shape... inputShapes)` |
| `boolean` | `isInitialized()` |
| `void` | `loadParameters(ai.djl.ndarray.NDManager manager, java.io.DataInputStream is)` |
| `void` | `removeLastBlock()` |
| `void` | `saveParameters(java.io.DataOutputStream os)` |
| `void` | `setInitializer(ai.djl.training.initializer.Initializer initializer)` |
| `void` | `setInitializer(ai.djl.training.initializer.Initializer initializer, java.lang.String paramName)` |
Inherited methods: `finalize`, `getHandle`, `getUid`, `isReleased`
public PtSymbolBlock(PtNDManager manager, Pointer handle)
Constructs a `PtSymbolBlock`.

You can create a `PtSymbolBlock` using `Model.load(java.nio.file.Path, String)`.

Parameters:
- `manager` - the manager to use for the block
- `handle` - the module handle

public void close()
Description copied from class: `NativeResource`

Specified by: `close` in interface `java.lang.AutoCloseable`
Overrides: `close` in class `NativeResource`
public void removeLastBlock()
Specified by: `removeLastBlock` in interface `ai.djl.nn.SymbolBlock`
public ai.djl.ndarray.NDList forward(ai.djl.training.ParameterStore parameterStore, ai.djl.ndarray.NDList inputs, ai.djl.util.PairList<java.lang.String,java.lang.Object> params)
Specified by: `forward` in interface `ai.djl.nn.Block`
public void setInitializer(ai.djl.training.initializer.Initializer initializer)
Specified by: `setInitializer` in interface `ai.djl.nn.Block`
public void setInitializer(ai.djl.training.initializer.Initializer initializer, java.lang.String paramName)
Specified by: `setInitializer` in interface `ai.djl.nn.Block`
public ai.djl.ndarray.types.Shape[] initialize(ai.djl.ndarray.NDManager manager, ai.djl.ndarray.types.DataType dataType, ai.djl.ndarray.types.Shape... inputShapes)
Specified by: `initialize` in interface `ai.djl.nn.Block`
public boolean isInitialized()
Specified by: `isInitialized` in interface `ai.djl.nn.Block`
public void cast(ai.djl.ndarray.types.DataType dataType)
Specified by: `cast` in interface `ai.djl.nn.Block`
public void clear()
Specified by: `clear` in interface `ai.djl.nn.Block`
public ai.djl.util.PairList<java.lang.String,ai.djl.ndarray.types.Shape> describeInput()
Specified by: `describeInput` in interface `ai.djl.nn.Block`
public ai.djl.nn.BlockList getChildren()
Specified by: `getChildren` in interface `ai.djl.nn.Block`
public java.util.List<ai.djl.nn.Parameter> getDirectParameters()
Specified by: `getDirectParameters` in interface `ai.djl.nn.Block`
public ai.djl.nn.ParameterList getParameters()
Specified by: `getParameters` in interface `ai.djl.nn.Block`
public ai.djl.ndarray.types.Shape getParameterShape(java.lang.String name, ai.djl.ndarray.types.Shape[] inputShapes)
Specified by: `getParameterShape` in interface `ai.djl.nn.Block`
public ai.djl.ndarray.types.Shape[] getOutputShapes(ai.djl.ndarray.NDManager manager, ai.djl.ndarray.types.Shape[] inputShapes)
Specified by: `getOutputShapes` in interface `ai.djl.nn.Block`
public void saveParameters(java.io.DataOutputStream os) throws java.io.IOException
Specified by: `saveParameters` in interface `ai.djl.nn.Block`

Throws:
- `java.io.IOException`
public void loadParameters(ai.djl.ndarray.NDManager manager, java.io.DataInputStream is) throws java.io.IOException, ai.djl.MalformedModelException
Specified by: `loadParameters` in interface `ai.djl.nn.Block`

Throws:
- `java.io.IOException`
- `ai.djl.MalformedModelException`