public abstract class AbstractBlock extends java.lang.Object implements Block
AbstractBlock is an abstract implementation of Block.

It is recommended that all Block classes that have children extend AbstractBlock.
To create your own blocks, you need to do the following:
1. Use addParameter(Parameter) to add parameters to your block in the constructor, if necessary.
2. Use addChildBlock(String, Block) to add child blocks, if necessary.
3. Override Block.getOutputShapes(Shape[]) to determine the shape of your custom block's output based on the input it will receive.
4. If you added child blocks, override initializeChildBlocks(NDManager, DataType, Shape...) to initialize them based on the input shape your block will receive. You can skip this if your block does not contain child blocks.
5. Implement forwardInternal(ParameterStore, NDList, boolean, PairList) to define the computation of your block. In this class, forward(ParameterStore, NDList, boolean, PairList) is final and relies on forwardInternal after initialization.
6. Override saveMetadata(DataOutputStream) and loadMetadata(byte, DataInputStream) if your block has state to save apart from its parameter values. If you do not need to save or load any state other than parameters in your block, you can skip this.
If you use addParameter(Parameter) to add parameters, you have to take care of parameter initialization yourself. In this case, call setShape on your parameters if you already know their shapes, or override prepare(Shape[]) to call setShape once the input shapes are known.
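The steps above can be illustrated without DJL by a small self-contained model. This is a hypothetical sketch, not DJL's API: SketchParameter, SketchBlock and ScaleBlock are invented names, double[] stands in for NDList, a constant fill of 0.5 stands in for an Initializer, and having the final forward initialize the block lazily from the input shape is an assumption of the sketch. It shows a parameter registered in the constructor, its shape set in prepare once the input shape is known, and the computation placed in forwardInternal.

```java
import java.util.Arrays;
import java.util.LinkedHashMap;
import java.util.Map;

// Hypothetical stand-in for Parameter: a name, a shape, and values filled in at initialization.
class SketchParameter {
    final String name;
    int[] shape;      // set by the owning block, e.g. in prepare(...)
    double[] values;  // null until initialize(...) runs

    SketchParameter(String name) {
        this.name = name;
    }

    void setShape(int... shape) {
        this.shape = shape;
    }

    void initialize(double fill) {
        if (shape == null) {
            throw new IllegalStateException("shape of " + name + " was never set");
        }
        int size = 1;
        for (int d : shape) {
            size *= d;
        }
        values = new double[size];
        Arrays.fill(values, fill);  // constant fill stands in for an Initializer
    }
}

// Hypothetical, heavily simplified analogue of AbstractBlock; double[] replaces NDList.
abstract class SketchBlock {
    protected final Map<String, SketchParameter> parameters = new LinkedHashMap<>();
    protected int[][] inputShapes;

    protected final <P extends SketchParameter> P addParameter(P parameter) {
        parameters.put(parameter.name, parameter);
        return parameter;
    }

    public void initialize(int[]... inputShapes) {
        this.inputShapes = inputShapes;  // record the input shapes first
        prepare(inputShapes);            // subclass sets parameter shapes here
        for (SketchParameter p : parameters.values()) {
            p.initialize(0.5);
        }
    }

    public boolean isInitialized() {
        for (SketchParameter p : parameters.values()) {
            if (p.values == null) {
                return false;
            }
        }
        return true;
    }

    // Default: nothing to prepare. Override to call setShape once input shapes are known.
    protected void prepare(int[][] inputShapes) {}

    // final, so subclasses cannot bypass initialization; they implement forwardInternal instead.
    public final double[] forward(double[] input, boolean training) {
        if (!isInitialized()) {
            initialize(new int[] {input.length});  // lazy initialization: an assumption of this sketch
        }
        return forwardInternal(input, training);
    }

    protected abstract double[] forwardInternal(double[] input, boolean training);
}

// Hypothetical custom block: multiplies each input element by a per-element weight.
class ScaleBlock extends SketchBlock {
    private final SketchParameter weight;

    ScaleBlock() {
        weight = addParameter(new SketchParameter("weight"));  // step 1: register in the constructor
    }

    @Override
    protected void prepare(int[][] inputShapes) {
        weight.setShape(inputShapes[0]);  // weight shape follows the first input's shape
    }

    @Override
    protected double[] forwardInternal(double[] input, boolean training) {
        double[] out = new double[input.length];
        for (int i = 0; i < input.length; i++) {
            out[i] = input[i] * weight.values[i];
        }
        return out;
    }
}
```

For example, new ScaleBlock().forward(new double[] {2.0, 3.0}, false) sets the weight shape to [2] in prepare, fills the weight with 0.5, and returns [1.0, 1.5]. Keeping forward final means every subclass goes through the same initialization check, while forwardInternal only has to describe the computation.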
Modifier and Type | Field | Description |
---|---|---|
protected BlockList | children | All direct children of this Block. |
protected java.util.List<java.lang.String> | inputNames | List of names for the input. Named inputs should be set manually in subclasses. |
protected Shape[] | inputShapes | The shape of the input for this block, set by the initialization process. |
protected java.util.LinkedHashMap<java.lang.String,Parameter> | parameters | All direct parameters of this Block. |
protected byte | version | The model version of this block, used for checking whether parameters are still valid during parameter loading. |
Constructor | Description |
---|---|
AbstractBlock() | Constructs a new AbstractBlock instance. |
AbstractBlock(byte version) | Builds an empty block with the given version for parameter serialization. |
Modifier and Type | Method | Description |
---|---|---|
protected <B extends Block> B | addChildBlock(java.lang.String name, B block) | Use this to add a child block to this block. |
protected <P extends Parameter> P | addParameter(P parameter) | Adds a parameter to this block. |
protected void | beforeInitialize(Shape... inputShapes) | Performs any action necessary before initialization. |
void | cast(DataType dataType) | Guaranteed to throw an exception. |
void | clear() | Closes all the parameters of the block. |
ai.djl.util.PairList<java.lang.String,Shape> | describeInput() | Returns a PairList of input names and shapes. |
NDList | forward(ParameterStore parameterStore, NDList inputs, boolean training, ai.djl.util.PairList<java.lang.String,java.lang.Object> params) | Applies the operating function of the block once. |
NDList | forward(ParameterStore parameterStore, NDList data, NDList labels, ai.djl.util.PairList<java.lang.String,java.lang.Object> params) | A forward call using both training data and labels. |
protected abstract NDList | forwardInternal(ParameterStore parameterStore, NDList inputs, boolean training, ai.djl.util.PairList<java.lang.String,java.lang.Object> params) | A helper for Block.forward(ParameterStore, NDList, boolean, PairList) after initialization. |
protected NDList | forwardInternal(ParameterStore parameterStore, NDList data, NDList labels, ai.djl.util.PairList<java.lang.String,java.lang.Object> params) | A helper for Block.forward(ParameterStore, NDList, NDList, PairList) after initialization. |
BlockList | getChildren() | Returns a list of all the children of the block. |
ParameterList | getDirectParameters() | Returns a list of all the direct parameters of the block. |
ParameterList | getParameters() | Returns a list of all the parameters of the block, including the parameters of its children, fetched recursively. |
void | initialize(NDManager manager, DataType dataType, Shape... inputShapes) | Initializes the parameters of the block. |
protected void | initializeChildBlocks(NDManager manager, DataType dataType, Shape... inputShapes) | Initializes the child blocks of this block. |
boolean | isInitialized() | Returns whether the block is initialized. |
protected void | loadMetadata(byte loadVersion, java.io.DataInputStream is) | Override this to load additional metadata with the parameter values. |
void | loadParameters(NDManager manager, java.io.DataInputStream is) | Loads the parameters from the given input stream. |
protected void | prepare(Shape[] inputShapes) | Sets the shape of Parameters. |
protected void | readInputShapes(java.io.DataInputStream is) | |
protected void | saveInputShapes(java.io.DataOutputStream os) | |
protected void | saveMetadata(java.io.DataOutputStream os) | Override this method to save additional data apart from parameter values. |
void | saveParameters(java.io.DataOutputStream os) | Writes the parameters of the block to the given output stream. |
void | setInitializer(Initializer initializer, Parameter.Type params) | Sets an Initializer for all the parameters that match the parameter type in the block. |
void | setInitializer(Initializer initializer, java.util.function.Predicate<Parameter> predicate) | Sets an Initializer for all the parameters that match the Predicate in the block. |
void | setInitializer(Initializer initializer, java.lang.String paramName) | Sets an Initializer for the specified direct parameter of the block, overriding the initializer of the parameter, if already set. |
java.lang.String | toString() | |
Methods inherited from class java.lang.Object: clone, equals, finalize, getClass, hashCode, notify, notifyAll, wait, wait, wait
Methods inherited from interface Block: forward, getOutputShapes, validateLayout
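protected Shape[] inputShapes
The shape of the input for this block, set by the initialization process.

protected java.util.List<java.lang.String> inputNames
List of names for the input. Named inputs should be set manually in subclasses.

protected byte version
The model version of this block, used for checking whether parameters are still valid during parameter loading.

protected BlockList children
All direct children of this Block. Use the addChildBlock(String, Block) method to add children. All children in this list are loaded and saved automatically.

protected java.util.LinkedHashMap<java.lang.String,Parameter> parameters
All direct parameters of this Block. Use the addParameter(Parameter) method to add parameters. All parameters in this map are loaded and saved automatically.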
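public AbstractBlock()
Constructs a new AbstractBlock instance.

public AbstractBlock(byte version)
Builds an empty block with the given version for parameter serialization.
Parameters:
version - the version to use for parameter serialization

public final NDList forward(ParameterStore parameterStore, NDList inputs, boolean training, ai.djl.util.PairList<java.lang.String,java.lang.Object> params)
Applies the operating function of the block once. This method is final; subclasses implement forwardInternal(ParameterStore, NDList, boolean, PairList) instead.

public NDList forward(ParameterStore parameterStore, NDList data, NDList labels, ai.djl.util.PairList<java.lang.String,java.lang.Object> params)
A forward call using both training data and labels. Within this forward call, it can be assumed that training is true.
Specified by: forward in interface Block
Parameters:
parameterStore - the parameter store
data - the input data NDList
labels - the input labels NDList
params - optional parameters
See Also: Block.forward(ParameterStore, NDList, boolean, PairList)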
protected abstract NDList forwardInternal(ParameterStore parameterStore, NDList inputs, boolean training, ai.djl.util.PairList<java.lang.String,java.lang.Object> params)
A helper for Block.forward(ParameterStore, NDList, boolean, PairList) after initialization.
Parameters:
parameterStore - the parameter store
inputs - the input NDList
training - true for a training forward pass
params - optional parameters

protected NDList forwardInternal(ParameterStore parameterStore, NDList data, NDList labels, ai.djl.util.PairList<java.lang.String,java.lang.Object> params)
A helper for Block.forward(ParameterStore, NDList, NDList, PairList) after initialization.
Parameters:
parameterStore - the parameter store
data - the input data NDList
labels - the input labels NDList
params - optional parameters
See Also: forward(ParameterStore, NDList, boolean, PairList)
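The relationship between the two forward overloads and their forwardInternal helpers can be sketched with a simplified model. LabeledSketchBlock and TrainingAwareBlock are hypothetical names, double[] replaces NDList, and the default body of the labels helper (ignore the labels and run the ordinary pass with training set to true) is an assumption based on the statement that training can be assumed true in the labels overload.

```java
// Hypothetical simplified block; double[] stands in for NDList.
abstract class LabeledSketchBlock {
    // Mirrors forward(ParameterStore, NDList, boolean, PairList).
    public final double[] forward(double[] inputs, boolean training) {
        return forwardInternal(inputs, training);
    }

    // Mirrors forward(ParameterStore, NDList, NDList, PairList): always a training-time call.
    public double[] forward(double[] data, double[] labels) {
        return forwardInternal(data, labels);
    }

    protected abstract double[] forwardInternal(double[] inputs, boolean training);

    // Assumed default: ignore the labels and run the ordinary pass with training = true.
    protected double[] forwardInternal(double[] data, double[] labels) {
        return forwardInternal(data, true);
    }
}

// Hypothetical block whose output depends on the training flag: it doubles values while training.
class TrainingAwareBlock extends LabeledSketchBlock {
    @Override
    protected double[] forwardInternal(double[] inputs, boolean training) {
        double[] out = inputs.clone();
        if (training) {
            for (int i = 0; i < out.length; i++) {
                out[i] *= 2;
            }
        }
        return out;
    }
}
```

With this default, a subclass that only implements the boolean variant of forwardInternal still supports calls with labels; a subclass that needs the labels, for example to compute a loss inside the block, would override the labels variant instead.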
protected final <B extends Block> B addChildBlock(java.lang.String name, B block)
Use this to add a child block to this block.
Type Parameters:
B - the type of block
Parameters:
name - the name of the block; must not be null. Names must be unique: any existing child with the same name is removed.
block - the block; must not be null

protected final <P extends Parameter> P addParameter(P parameter)
Adds a parameter to this block.
Type Parameters:
P - the specific parameter subclass
Parameters:
parameter - the parameter to add; must not be null

public BlockList getChildren()
Returns a list of all the children of the block.
Specified by: getChildren in interface Block
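The naming rule documented for addChildBlock (non-null arguments, and a new child replaces any existing child with the same name) can be sketched as follows. ChildRegistry and addChild are hypothetical names, and appending the replacement at the end of the list is an assumption of this sketch; this section does not say where the real BlockList places it.

```java
import java.util.AbstractMap.SimpleEntry;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.Objects;

// Hypothetical stand-in for the children field: an ordered list of (name, block) pairs.
// A plain pair list would accept duplicate names, so the replacement rule is enforced explicitly.
class ChildRegistry {
    private final List<Map.Entry<String, Object>> children = new ArrayList<>();

    // Mirrors the documented addChildBlock contract: non-null arguments, unique names,
    // and the block itself is returned so the caller can keep a typed reference to it.
    <B> B addChild(String name, B block) {
        Objects.requireNonNull(name, "name must not be null");
        Objects.requireNonNull(block, "block must not be null");
        children.removeIf(entry -> entry.getKey().equals(name));       // drop any child with this name
        children.add(new SimpleEntry<String, Object>(name, block));    // appended at the end (assumption)
        return block;
    }

    // Read-only view, loosely analogous to getChildren().
    List<Map.Entry<String, Object>> getChildren() {
        return Collections.unmodifiableList(children);
    }
}
```

Returning the block from the add method mirrors the generic signature of addChildBlock, which lets a constructor register a child and store it in a typed field in one statement.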
public ai.djl.util.PairList<java.lang.String,Shape> describeInput()
Returns a PairList of input names and shapes.
Specified by: describeInput in interface Block
Returns:
a PairList of input names and shapes

public void setInitializer(Initializer initializer, Parameter.Type params)
Sets an Initializer for all the parameters that match the parameter type in the block.
Specified by: setInitializer in interface Block
Parameters:
initializer - the initializer to set
params - the Parameter.Type of the parameters that should receive the initializer

public void setInitializer(Initializer initializer, java.lang.String paramName)
Sets an Initializer for the specified direct parameter of the block, overriding the initializer of the parameter, if already set.
Specified by: setInitializer in interface Block
Parameters:
initializer - the initializer to set
paramName - the name of the parameter

public void setInitializer(Initializer initializer, java.util.function.Predicate<Parameter> predicate)
Sets an Initializer for all the parameters that match the Predicate in the block.
Specified by: setInitializer in interface Block
Parameters:
initializer - the initializer to set
predicate - a predicate that selects the parameters to set

public void initialize(NDManager manager, DataType dataType, Shape... inputShapes)
Initializes the parameters of the block.
Specified by: initialize in interface Block
Parameters:
manager - the NDManager to initialize the parameters
dataType - the data type of the parameters
inputShapes - the shapes of the inputs to the block

protected void beforeInitialize(Shape... inputShapes)
Performs any action necessary before initialization.
Parameters:
inputShapes - the expected shapes of the input

protected void initializeChildBlocks(NDManager manager, DataType dataType, Shape... inputShapes)
Initializes the child blocks of this block.
Parameters:
manager - the manager to use for initialization
dataType - the requested data type
inputShapes - the expected input shapes for this block

public ParameterList getParameters()
Returns a list of all the parameters of the block, including the parameters of its children, fetched recursively.
Specified by: getParameters in interface Block
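The three setInitializer overloads differ only in how they select parameters: by type, by name, or by an arbitrary predicate. The sketch below, under stated assumptions, expresses the type-based form through the predicate form. SketchParamType, SketchParam and InitializerSketch are hypothetical names, an initializer is represented only by its name, the parameters are kept in a flat list (whether the real type and predicate variants also reach child blocks is not spelled out in this section, beyond "all the parameters"), and rejecting an unknown name with IllegalArgumentException is an assumption.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.function.Predicate;

// Hypothetical stand-in for Parameter.Type.
enum SketchParamType { WEIGHT, BIAS }

// Hypothetical stand-in for Parameter; the initializer is represented by its name only.
class SketchParam {
    final String name;
    final SketchParamType type;
    String initializer;  // null until an initializer is assigned

    SketchParam(String name, SketchParamType type) {
        this.name = name;
        this.type = type;
    }
}

class InitializerSketch {
    final List<SketchParam> params = new ArrayList<>();

    // General form: every matching parameter gets the initializer; an earlier assignment is overwritten.
    void setInitializer(String initializer, Predicate<SketchParam> predicate) {
        for (SketchParam p : params) {
            if (predicate.test(p)) {
                p.initializer = initializer;
            }
        }
    }

    // Type-based overload, expressed through the predicate form.
    void setInitializer(String initializer, SketchParamType type) {
        setInitializer(initializer, p -> p.type == type);
    }

    // Name-based overload for a single direct parameter; rejecting unknown names is an assumption.
    void setInitializer(String initializer, String paramName) {
        SketchParam target = params.stream()
                .filter(p -> p.name.equals(paramName))
                .findFirst()
                .orElseThrow(() -> new IllegalArgumentException("no parameter named " + paramName));
        target.initializer = initializer;
    }
}
```

Because each call overwrites earlier assignments, a common pattern is to set a broad default by type first and then refine individual parameters by name or predicate.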
public ParameterList getDirectParameters()
Returns a list of all the direct parameters of the block.
Specified by: getDirectParameters in interface Block
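The difference between getDirectParameters() and getParameters() is the difference between a block's own parameters and a recursive walk over its children. In the sketch below, TreeBlock is a hypothetical name, each parameter is represented only by its name, and the ordering (own parameters first, then each child's in insertion order) is an assumption of the sketch.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

// Hypothetical tree of blocks in which each parameter is represented by its name only.
class TreeBlock {
    private final List<String> directParameters;
    private final List<TreeBlock> children = new ArrayList<>();

    TreeBlock(String... directParameters) {
        this.directParameters = new ArrayList<>(Arrays.asList(directParameters));
    }

    TreeBlock addChild(TreeBlock child) {
        children.add(child);
        return this;
    }

    // Analogue of getDirectParameters(): only this block's own parameters.
    List<String> getDirectParameters() {
        return new ArrayList<>(directParameters);
    }

    // Analogue of getParameters(): own parameters, then each child's parameters, recursively.
    // The "own first, then children in insertion order" ordering is an assumption of this sketch.
    List<String> getParameters() {
        List<String> all = getDirectParameters();
        for (TreeBlock child : children) {
            all.addAll(child.getParameters());
        }
        return all;
    }
}
```

A root holding "scale" with a child that holds "linear_weight" and "linear_bias" and itself contains a block holding "conv_weight" reports only [scale] as direct parameters, but all four names from getParameters().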
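protected void prepare(Shape[] inputShapes)
Sets the shape of Parameters.
Parameters:
inputShapes - the shapes of inputs

public boolean isInitialized()
Returns whether the block is initialized.
Specified by: isInitialized in interface Block

public void clear()
Closes all the parameters of the block.

public void cast(DataType dataType)
Guaranteed to throw an exception.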
public void saveParameters(java.io.DataOutputStream os) throws java.io.IOException
Writes the parameters of the block to the given output stream.
Specified by: saveParameters in interface Block
Parameters:
os - the output stream to save the parameters to
Throws:
java.io.IOException - if an I/O error occurs

public void loadParameters(NDManager manager, java.io.DataInputStream is) throws java.io.IOException, MalformedModelException
Loads the parameters from the given input stream.
Specified by: loadParameters in interface Block
Parameters:
manager - an NDManager to create the parameter arrays
is - the input stream that streams the parameter values
Throws:
java.io.IOException - if an I/O error occurs
MalformedModelException - if the model file is corrupted or unsupported

protected void saveMetadata(java.io.DataOutputStream os) throws java.io.IOException
Override this method to save additional data apart from parameter values. This default implementation saves the currently set input shapes.
Parameters:
os - the non-null output stream the parameter values and metadata are written to
Throws:
java.io.IOException - if saving fails

protected void loadMetadata(byte loadVersion, java.io.DataInputStream is) throws java.io.IOException, MalformedModelException
Override this to load additional metadata with the parameter values. If you override saveMetadata(DataOutputStream) or need to provide backward compatibility with older binary formats, you probably need to override this method as well. This default implementation checks whether the version number matches; if it does not, it throws a MalformedModelException. After that, it restores the input shapes.
Parameters:
loadVersion - the version used for loading this metadata
is - the input stream we are loading from
Throws:
java.io.IOException - if loading fails
MalformedModelException - if the data can be loaded but has the wrong format

protected void saveInputShapes(java.io.DataOutputStream os) throws java.io.IOException
Throws:
java.io.IOException

protected void readInputShapes(java.io.DataInputStream is) throws java.io.IOException
Throws:
java.io.IOException
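The default metadata behaviour described above (save the input shapes; on load, reject a mismatched version and then restore the shapes) can be sketched with the standard java.io data streams. MetadataSketch and SketchMalformedModelException are hypothetical names, the latter standing in for MalformedModelException; the byte layout of the shapes is invented; and having saveParameters write the version byte before the metadata is an assumption, suggested by loadMetadata receiving the version as its loadVersion argument.

```java
import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.io.IOException;

// Hypothetical checked exception standing in for MalformedModelException.
class SketchMalformedModelException extends Exception {
    SketchMalformedModelException(String message) {
        super(message);
    }
}

// Hypothetical block that persists only its version byte and its input shapes.
class MetadataSketch {
    private final byte version;
    long[][] inputShapes = new long[0][];

    MetadataSketch(byte version) {  // analogue of AbstractBlock(byte version)
        this.version = version;
    }

    // Assumed framing: version byte, then metadata; parameter values would follow.
    void saveParameters(DataOutputStream os) throws IOException {
        os.writeByte(version);
        saveMetadata(os);
    }

    void loadParameters(DataInputStream is) throws IOException, SketchMalformedModelException {
        byte loadVersion = is.readByte();
        loadMetadata(loadVersion, is);
    }

    protected void saveMetadata(DataOutputStream os) throws IOException {
        saveInputShapes(os);
    }

    // Mirrors the documented default: reject a mismatched version, then restore the input shapes.
    protected void loadMetadata(byte loadVersion, DataInputStream is)
            throws IOException, SketchMalformedModelException {
        if (loadVersion != version) {
            throw new SketchMalformedModelException(
                    "Unsupported encoding version: " + loadVersion + ", expected " + version);
        }
        readInputShapes(is);
    }

    // Invented layout: shape count, then for each shape its rank followed by its dimensions.
    protected void saveInputShapes(DataOutputStream os) throws IOException {
        os.writeInt(inputShapes.length);
        for (long[] shape : inputShapes) {
            os.writeInt(shape.length);
            for (long dim : shape) {
                os.writeLong(dim);
            }
        }
    }

    protected void readInputShapes(DataInputStream is) throws IOException {
        long[][] shapes = new long[is.readInt()][];
        for (int i = 0; i < shapes.length; i++) {
            shapes[i] = new long[is.readInt()];
            for (int j = 0; j < shapes[i].length; j++) {
                shapes[i][j] = is.readLong();
            }
        }
        inputShapes = shapes;
    }
}
```

A block created with version 1 can reload bytes it wrote itself, while a block created with version 2 rejects the same bytes. A subclass that changes its format would bump the version and, if it must still read old files, override loadMetadata to branch on loadVersion instead of rejecting it.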
public java.lang.String toString()
Overrides: toString in class java.lang.Object