Class AbstractBlock
- java.lang.Object
  - ai.djl.nn.AbstractBlock
- All Implemented Interfaces:
Block
- Direct Known Subclasses:
AbstractSymbolBlock, BatchNorm, BertBlock, BertMaskedLanguageModelBlock, BertNextSentenceBlock, BertPretrainingBlock, ConstantEmbedding, Convolution, Decoder, Deconvolution, Dropout, Embedding, Encoder, EncoderDecoder, IdEmbedding, LambdaBlock, LayerNorm, Linear, ParallelBlock, Prelu, RecurrentBlock, ScaledDotProductAttentionBlock, SequentialBlock, TrainableTextEmbedding, TransformerEncoderBlock
public abstract class AbstractBlock extends java.lang.Object implements Block
AbstractBlock is an abstract implementation of Block. It is recommended that all Block classes that have children extend AbstractBlock. To create your own blocks, do the following:
- Define a version for serializing parameters and metadata, and pass it to the parent constructor.
- Use addParameter(Parameter) in the constructor to add parameters to your block, if necessary.
- Use addChildBlock(String, Block) to add child blocks, if necessary.
- Override Block.getOutputShapes(Shape[]) to determine the shape of your custom block's output based on the input it will receive.
- If you added child blocks, override initializeChildBlocks(NDManager, DataType, Shape...) to initialize them based on the input shape your block will receive. You can skip this if your block does not contain child blocks.
- Override forwardInternal(ParameterStore, NDList, boolean, PairList) to implement the computation of your block. (forward(ParameterStore, NDList, boolean, PairList) is final in this class and delegates to forwardInternal.)
- Only if you need to save data apart from the parameter values of your block, override saveMetadata(DataOutputStream) and loadMetadata(byte, DataInputStream). If you do not need to save or load any state other than parameters, you can skip this.
If you use addParameter(Parameter) to add parameters, you are responsible for giving each parameter its shape: call setShape on the parameter if you know its shape in advance, or override prepare(Shape[]) to set the shape once the input shape is known.
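The first steps of the checklist above can be illustrated without DJL on the classpath. The sketch below is a dependency-free toy, not DJL code: ToyParameter, ToyAbstractBlock, and ToyScaleBlock are hypothetical names, shapes are plain int arrays, and tensors are double arrays. It shows passing a version to the parent constructor, registering a parameter in the constructor, reporting the output shape, and implementing the computation.

```java
import java.util.Arrays;
import java.util.LinkedHashMap;
import java.util.Map;

// Hypothetical stand-in for a parameter: a name, a shape, and flat values (not DJL's Parameter).
class ToyParameter {
    final String name;
    final int[] shape;
    final double[] values;

    ToyParameter(String name, int... shape) {
        this.name = name;
        this.shape = shape.clone();
        int size = 1;
        for (int d : shape) {
            size *= d;
        }
        this.values = new double[size];
    }
}

// Hypothetical, dependency-free analogue of AbstractBlock (not the DJL class).
abstract class ToyAbstractBlock {
    protected final byte version;
    protected final Map<String, ToyParameter> parameters = new LinkedHashMap<>();

    protected ToyAbstractBlock(byte version) {
        this.version = version;
    }

    // Registers a parameter and returns it, so it can be assigned to a field in one line.
    protected final <P extends ToyParameter> P addParameter(P parameter) {
        parameters.put(parameter.name, parameter);
        return parameter;
    }

    abstract int[] getOutputShape(int[] inputShape);

    abstract double[] forwardInternal(double[] input);
}

// Multiplies each feature by a learned scale.
class ToyScaleBlock extends ToyAbstractBlock {
    private static final byte VERSION = 1;
    private final ToyParameter scale;

    ToyScaleBlock(int features) {
        super(VERSION); // step 1: pass the serialization version to the parent
        scale = addParameter(new ToyParameter("scale", features)); // step 2: register parameters
        Arrays.fill(scale.values, 2.0); // stands in for an Initializer
    }

    // Step 3: element-wise scaling leaves the shape unchanged.
    @Override
    int[] getOutputShape(int[] inputShape) {
        return inputShape.clone();
    }

    // Step 4: the computation itself (input is one row of `features` values).
    @Override
    double[] forwardInternal(double[] input) {
        double[] out = new double[input.length];
        for (int i = 0; i < input.length; i++) {
            out[i] = input[i] * scale.values[i];
        }
        return out;
    }
}
```

In real DJL code these roles belong to AbstractBlock, Parameter, getOutputShapes, and forwardInternal, and parameter values are filled by an Initializer rather than by the Arrays.fill call used here.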
-
Field Summary
- protected BlockList children: All direct children of this Block.
- protected java.util.List<java.lang.String> inputNames: List of names for the inputs. Named inputs should be set manually in subclasses.
- protected Shape[] inputShapes: The shape of the input for this block, set by the initialization process.
- protected java.util.LinkedHashMap<java.lang.String,Parameter> parameters: All direct parameters of this Block.
- protected byte version: The model version of this block, used for checking if parameters are still valid during parameter loading.
-
Constructor Summary
- AbstractBlock(): Constructs a new AbstractBlock instance.
- AbstractBlock(byte version): Builds an empty block with the given version for parameter serialization.
-
Method Summary
- protected <B extends Block> B addChildBlock(java.lang.String name, B block): Use this to add a child block to this block.
- protected <P extends Parameter> P addParameter(P parameter): Adds a parameter to this block.
- protected void beforeInitialize(Shape... inputShapes): Performs any action necessary before initialization.
- void cast(DataType dataType): Guaranteed to throw an exception.
- void clear(): Closes all the parameters of the block.
- ai.djl.util.PairList<java.lang.String,Shape> describeInput(): Returns a PairList of input names and shapes.
- NDList forward(ParameterStore parameterStore, NDList inputs, boolean training, ai.djl.util.PairList<java.lang.String,java.lang.Object> params): Applies the operating function of the block once.
- NDList forward(ParameterStore parameterStore, NDList data, NDList labels, ai.djl.util.PairList<java.lang.String,java.lang.Object> params): A forward call using both training data and labels.
- protected abstract NDList forwardInternal(ParameterStore parameterStore, NDList inputs, boolean training, ai.djl.util.PairList<java.lang.String,java.lang.Object> params): A helper for Block.forward(ParameterStore, NDList, boolean, PairList) after initialization.
- protected NDList forwardInternal(ParameterStore parameterStore, NDList data, NDList labels, ai.djl.util.PairList<java.lang.String,java.lang.Object> params): A helper for Block.forward(ParameterStore, NDList, NDList, PairList) after initialization.
- BlockList getChildren(): Returns a list of all the children of the block.
- ParameterList getDirectParameters(): Returns a list of all the direct parameters of the block.
- ParameterList getParameters(): Returns a list of all the parameters of the block, including the parameters of its children fetched recursively.
- void initialize(NDManager manager, DataType dataType, Shape... inputShapes): Initializes the parameters of the block.
- protected void initializeChildBlocks(NDManager manager, DataType dataType, Shape... inputShapes): Initializes the child blocks of this block.
- boolean isInitialized(): Returns whether the block is initialized.
- protected void loadMetadata(byte loadVersion, java.io.DataInputStream is): Override this to load additional metadata with the parameter values.
- void loadParameters(NDManager manager, java.io.DataInputStream is): Loads the parameters from the given input stream.
- protected void prepare(Shape[] inputShapes): Sets the shapes of the Parameters.
- protected void readInputShapes(java.io.DataInputStream is)
- protected void saveInputShapes(java.io.DataOutputStream os)
- protected void saveMetadata(java.io.DataOutputStream os): Override this method to save additional data apart from parameter values.
- void saveParameters(java.io.DataOutputStream os): Writes the parameters of the block to the given output stream.
- void setInitializer(Initializer initializer, Parameter.Type params): Sets an Initializer for all the parameters in the block that match the given parameter type.
- void setInitializer(Initializer initializer, java.lang.String paramName): Sets an Initializer for the specified direct parameter of the block, overriding the parameter's initializer if one is already set.
- void setInitializer(Initializer initializer, java.util.function.Predicate<Parameter> predicate): Sets an Initializer for all the parameters in the block that match the predicate.
- java.lang.String toString()
-
Methods inherited from class java.lang.Object
clone, equals, finalize, getClass, hashCode, notify, notifyAll, wait, wait, wait
-
Methods inherited from interface ai.djl.nn.Block
forward, getOutputShapes
-
Field Detail
-
inputShapes
protected Shape[] inputShapes
The shape of the input for this block, set by the initialization process.
-
inputNames
protected java.util.List<java.lang.String> inputNames
List of names for the inputs. Named inputs should be set manually in subclasses.
-
version
protected byte version
The model version of this block, used for checking if parameters are still valid during parameter loading.
-
children
protected BlockList children
All direct children of this Block. Keys are the names of the blocks. Use the addChildBlock(String, Block) method to add children. All children in this map are automatically loaded and saved.
-
parameters
protected java.util.LinkedHashMap<java.lang.String,Parameter> parameters
All direct parameters of this Block. Keys are the names of the parameters. Use the addParameter(Parameter) method to add parameters. All parameters in this map are automatically loaded and saved.
-
Method Detail
-
forward
public final NDList forward(ParameterStore parameterStore, NDList inputs, boolean training, ai.djl.util.PairList<java.lang.String,java.lang.Object> params)
Applies the operating function of the block once. This method should be called only on blocks that are initialized.
-
forward
public NDList forward(ParameterStore parameterStore, NDList data, NDList labels, ai.djl.util.PairList<java.lang.String,java.lang.Object> params)
A forward call using both training data and labels. Within this forward call, it can be assumed that training is true.
- Specified by:
forward in interface Block
- Parameters:
parameterStore - the parameter store
data - the input data NDList
labels - the input labels NDList
params - optional parameters
- Returns:
the output of the forward pass
- See Also:
Block.forward(ParameterStore, NDList, boolean, PairList)
-
forwardInternal
protected abstract NDList forwardInternal(ParameterStore parameterStore, NDList inputs, boolean training, ai.djl.util.PairList<java.lang.String,java.lang.Object> params)
A helper for Block.forward(ParameterStore, NDList, boolean, PairList) after initialization.
- Parameters:
parameterStore - the parameter store
inputs - the input NDList
training - true for a training forward pass
params - optional parameters
- Returns:
the output of the forward pass
-
forwardInternal
protected NDList forwardInternal(ParameterStore parameterStore, NDList data, NDList labels, ai.djl.util.PairList<java.lang.String,java.lang.Object> params)
A helper for Block.forward(ParameterStore, NDList, NDList, PairList) after initialization.
- Parameters:
parameterStore - the parameter store
data - the input data NDList
labels - the input labels NDList
params - optional parameters
- Returns:
the output of the forward pass
- See Also:
forward(ParameterStore, NDList, boolean, PairList)
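Because forward(ParameterStore, NDList, boolean, PairList) is final, a subclass supplies its computation through forwardInternal instead. A minimal sketch of that template-method split, using hypothetical Toy* classes and double arrays in place of DJL's NDList and ParameterStore: the labels variant defaults to a training pass (an assumption based on the note that training can be assumed true there), and throwing when the block is uninitialized is a choice of this sketch, not necessarily what DJL does.

```java
import java.util.Arrays;

// Hypothetical analogue of AbstractBlock's forward/forwardInternal split (not DJL code).
abstract class ToyBlock {
    private boolean initialized;

    void initialize() {
        initialized = true;
    }

    // Final entry point: subclasses cannot bypass the precondition check.
    final double[] forward(double[] inputs, boolean training) {
        requireInitialized();
        return forwardInternal(inputs, training);
    }

    // Entry point that also receives labels; training is implied.
    double[] forward(double[] data, double[] labels) {
        requireInitialized();
        return forwardInternal(data, labels);
    }

    // The one method every subclass must implement.
    protected abstract double[] forwardInternal(double[] inputs, boolean training);

    // Default for the labels variant: ignore the labels and run a training pass.
    protected double[] forwardInternal(double[] data, double[] labels) {
        return forwardInternal(data, true);
    }

    private void requireInitialized() {
        if (!initialized) {
            throw new IllegalStateException("Call initialize() before forward()");
        }
    }
}

// Toy block: outputs zeros while training and copies its input otherwise.
class ToyTrainingAwareBlock extends ToyBlock {
    @Override
    protected double[] forwardInternal(double[] inputs, boolean training) {
        return training ? new double[inputs.length] : Arrays.copyOf(inputs, inputs.length);
    }
}
```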
-
addChildBlock
protected final <B extends Block> B addChildBlock(java.lang.String name, B block)
Use this to add a child block to this block.
- Type Parameters:
B - the type of block
- Parameters:
name - the name of the block; must not be null, and must be unique, otherwise any existing child with this name is removed
block - the block; must not be null
- Returns:
the block given as a parameter, so that the block can be created and assigned to a member variable in one line
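The contract above (non-null arguments, unique names, and the block returned for one-line assignment) can be sketched with a plain map. ToyChild and ToyParent are hypothetical; replacing the old child on a duplicate name follows the wording of this page and is not a claim about how DJL's BlockList stores children internally.

```java
import java.util.LinkedHashMap;
import java.util.Map;
import java.util.Objects;

// Hypothetical child block type (not DJL's Block).
class ToyChild {
    final String label;

    ToyChild(String label) {
        this.label = label;
    }
}

// Hypothetical parent that registers children by name.
class ToyParent {
    final Map<String, ToyChild> children = new LinkedHashMap<>();
    final ToyChild encoder;

    ToyParent() {
        // Create, register, and keep a reference in a single statement.
        encoder = addChildBlock("encoder", new ToyChild("first"));
    }

    // Non-null arguments; a child registered under an existing name replaces the old one.
    final <B extends ToyChild> B addChildBlock(String name, B block) {
        Objects.requireNonNull(name, "name must not be null");
        Objects.requireNonNull(block, "block must not be null");
        children.put(name, block);
        return block;
    }
}
```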
-
addParameter
protected final <P extends Parameter> P addParameter(P parameter)
Adds a parameter to this block. If parameters are added with this method, initialization of the parameter works out of the box.
- Type Parameters:
P - the specific parameter subclass
- Parameters:
parameter - the parameter to add; must not be null
- Returns:
the parameter passed as an argument, to make it easier to create and assign parameters in one line
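Because addParameter is generic in P, the returned value keeps the caller's subclass type, so a field of that subclass type can be assigned without a cast. A hypothetical, dependency-free sketch (ToyParam, ToyConstantParam, and ToyLayer are not DJL classes):

```java
import java.util.LinkedHashMap;
import java.util.Map;

// Hypothetical parameter types (not DJL's Parameter hierarchy).
class ToyParam {
    final String name;

    ToyParam(String name) {
        this.name = name;
    }
}

class ToyConstantParam extends ToyParam {
    private final double value;

    ToyConstantParam(String name, double value) {
        super(name);
        this.value = value;
    }

    double getValue() {
        return value;
    }
}

class ToyLayer {
    final Map<String, ToyParam> parameters = new LinkedHashMap<>();
    // Declared with the subclass type: addParameter returns P, so no cast is needed.
    final ToyConstantParam bias;

    ToyLayer() {
        bias = addParameter(new ToyConstantParam("bias", 0.5));
    }

    final <P extends ToyParam> P addParameter(P parameter) {
        parameters.put(parameter.name, parameter);
        return parameter;
    }
}
```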
-
getChildren
public BlockList getChildren()
Returns a list of all the children of the block.
- Specified by:
getChildren in interface Block
- Returns:
the list of child blocks
-
describeInput
public ai.djl.util.PairList<java.lang.String,Shape> describeInput()
Returns a PairList of input names and shapes.
- Specified by:
describeInput in interface Block
- Returns:
the PairList of input names and shapes
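describeInput presumably pairs the inputNames field with the inputShapes recorded during initialization, which is why subclasses that want named inputs have to set inputNames themselves. A minimal sketch of that pairing with standard collections; ToyInputDescriber is hypothetical, and rejecting mismatched lengths is a choice of the sketch.

```java
import java.util.AbstractMap.SimpleImmutableEntry;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;

// Hypothetical helper mirroring describeInput(): not DJL's PairList or Shape.
final class ToyInputDescriber {
    private ToyInputDescriber() {}

    // Pairs the i-th input name with the i-th input shape.
    static List<Map.Entry<String, int[]>> describeInput(List<String> inputNames, int[][] inputShapes) {
        if (inputNames.size() != inputShapes.length) {
            throw new IllegalStateException(
                    "Got " + inputNames.size() + " names for " + inputShapes.length + " shapes");
        }
        List<Map.Entry<String, int[]>> pairs = new ArrayList<>();
        for (int i = 0; i < inputShapes.length; i++) {
            pairs.add(new SimpleImmutableEntry<>(inputNames.get(i), inputShapes[i]));
        }
        return pairs;
    }
}
```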
-
setInitializer
public void setInitializer(Initializer initializer, Parameter.Type params)
Sets an Initializer for all the parameters in the block that match the given parameter type.
- Specified by:
setInitializer in interface Block
- Parameters:
initializer - the initializer to set
params - the parameter type to set the Initializer for
-
setInitializer
public void setInitializer(Initializer initializer, java.lang.String paramName)
Sets an Initializer for the specified direct parameter of the block, overriding the parameter's initializer if one is already set.
- Specified by:
setInitializer in interface Block
- Parameters:
initializer - the initializer to set
paramName - the name of the parameter
-
setInitializer
public void setInitializer(Initializer initializer, java.util.function.Predicate<Parameter> predicate)
Sets an Initializer for all the parameters in the block that match the predicate.
- Specified by:
setInitializer in interface Block
- Parameters:
initializer - the initializer to set
predicate - a predicate that selects the parameters to set
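The three setInitializer overloads differ only in how they select parameters: by type, by name, or by an arbitrary Predicate. The sketch below expresses the type-based variant as a predicate over a flat block with no children; the string initializer names, ToyParamType, ToyParameter, and ToyInitBlock are hypothetical stand-ins for DJL's Initializer, Parameter.Type, and Parameter.

```java
import java.util.LinkedHashMap;
import java.util.Map;
import java.util.function.Predicate;

// Hypothetical stand-ins; DJL's Initializer and Parameter.Type are not used here.
enum ToyParamType { WEIGHT, BIAS }

class ToyParameter {
    final String name;
    final ToyParamType type;
    String initializer; // stands in for an Initializer object

    ToyParameter(String name, ToyParamType type) {
        this.name = name;
        this.type = type;
    }
}

class ToyInitBlock {
    final Map<String, ToyParameter> parameters = new LinkedHashMap<>();

    ToyInitBlock() {
        add(new ToyParameter("weight", ToyParamType.WEIGHT));
        add(new ToyParameter("bias", ToyParamType.BIAS));
        add(new ToyParameter("gamma", ToyParamType.WEIGHT));
    }

    private void add(ToyParameter p) {
        parameters.put(p.name, p);
    }

    // Most general form: assign to every parameter matching the predicate.
    void setInitializer(String initializer, Predicate<ToyParameter> predicate) {
        for (ToyParameter p : parameters.values()) {
            if (predicate.test(p)) {
                p.initializer = initializer;
            }
        }
    }

    // By type: just a predicate on the parameter type.
    void setInitializer(String initializer, ToyParamType type) {
        setInitializer(initializer, p -> p.type == type);
    }

    // By name: targets one direct parameter, overriding any earlier choice.
    void setInitializer(String initializer, String paramName) {
        ToyParameter p = parameters.get(paramName);
        if (p == null) {
            throw new IllegalArgumentException("No parameter named " + paramName);
        }
        p.initializer = initializer;
    }
}
```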
-
initialize
public void initialize(NDManager manager, DataType dataType, Shape... inputShapes)
Initializes the parameters of the block. This method must be called before calling forward.
- Specified by:
initialize in interface Block
- Parameters:
manager - the NDManager to initialize the parameters
dataType - the data type of the parameters
inputShapes - the shapes of the inputs to the block
-
beforeInitialize
protected void beforeInitialize(Shape... inputShapes)
Performs any action necessary before initialization, for example, keeping the input information or verifying the layout.
- Parameters:
inputShapes - the expected shapes of the input
-
initializeChildBlocks
protected void initializeChildBlocks(NDManager manager, DataType dataType, Shape... inputShapes)
Initializes the child blocks of this block. Override this method if your subclass has child blocks; it determines the correct input shapes for the child blocks based on the requested input shape for this block.
- Parameters:
manager - the manager to use for initialization
dataType - the requested data type
inputShapes - the expected input shapes for this block
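For a container block, initializeChildBlocks is where the input shape is threaded through the children, each child's output shape becoming the next child's input shape. A dependency-free sketch with hypothetical ToyBlock, ToyDense, and ToySequential classes; the beforeInitialize, prepare, initializeChildBlocks order inside initialize is an assumption of the sketch, not a statement about DJL's implementation.

```java
import java.util.ArrayList;
import java.util.List;

// Hypothetical model of initialize() for leaf and container blocks (not DJL code).
abstract class ToyBlock {
    protected int[][] inputShapes;

    // The order of the three hooks is an assumption of this sketch.
    final void initialize(int[]... inputShapes) {
        beforeInitialize(inputShapes);
        prepare(inputShapes);
        initializeChildBlocks(inputShapes);
    }

    protected void beforeInitialize(int[]... inputShapes) {
        this.inputShapes = inputShapes; // keep the input information
    }

    protected void prepare(int[]... inputShapes) {
        // A block with shape-dependent parameters would set their shapes here.
    }

    protected void initializeChildBlocks(int[]... inputShapes) {
        // Leaf blocks have no children to initialize.
    }

    abstract int[] getOutputShape(int[] inputShape);
}

// Leaf: maps [batch, inFeatures] to [batch, units].
class ToyDense extends ToyBlock {
    private final int units;

    ToyDense(int units) {
        this.units = units;
    }

    @Override
    int[] getOutputShape(int[] inputShape) {
        return new int[] {inputShape[0], units};
    }
}

// Container: each child's output shape becomes the next child's input shape.
class ToySequential extends ToyBlock {
    final List<ToyBlock> children = new ArrayList<>();

    ToySequential add(ToyBlock child) {
        children.add(child);
        return this;
    }

    @Override
    protected void initializeChildBlocks(int[]... inputShapes) {
        int[] shape = inputShapes[0];
        for (ToyBlock child : children) {
            child.initialize(shape);
            shape = child.getOutputShape(shape);
        }
    }

    @Override
    int[] getOutputShape(int[] inputShape) {
        int[] shape = inputShape;
        for (ToyBlock child : children) {
            shape = child.getOutputShape(shape);
        }
        return shape;
    }
}
```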
-
getParameters
public ParameterList getParameters()
Returns a list of all the parameters of the block, including the parameters of its children fetched recursively.
- Specified by:
getParameters in interface Block
- Returns:
the list of all parameters of the block
-
getDirectParameters
public ParameterList getDirectParameters()
Returns a list of all the direct parameters of the block.
- Specified by:
getDirectParameters in interface Block
- Returns:
the list of Parameters
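The difference between getDirectParameters and getParameters is one level versus the whole subtree. A sketch with a hypothetical ToyNode class whose parameters are just names mapped to sizes; joining the child name and parameter name with an underscore to keep keys unique is a choice of this sketch, not a documented DJL naming rule.

```java
import java.util.LinkedHashMap;
import java.util.Map;

// Hypothetical block tree; parameter "values" are just sizes here (not DJL code).
class ToyNode {
    final Map<String, Integer> parameters = new LinkedHashMap<>();
    final Map<String, ToyNode> children = new LinkedHashMap<>();

    // Only this node's own parameters.
    Map<String, Integer> getDirectParameters() {
        return new LinkedHashMap<>(parameters);
    }

    // Own parameters first, then each child's parameters, fetched recursively.
    Map<String, Integer> getParameters() {
        Map<String, Integer> all = new LinkedHashMap<>(parameters);
        for (Map.Entry<String, ToyNode> child : children.entrySet()) {
            for (Map.Entry<String, Integer> p : child.getValue().getParameters().entrySet()) {
                // Prefix with the child name so keys from different children cannot collide.
                all.put(child.getKey() + "_" + p.getKey(), p.getValue());
            }
        }
        return all;
    }
}
```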
-
prepare
protected void prepare(Shape[] inputShapes)
Sets the shapes of the Parameters.
- Parameters:
inputShapes - the shapes of the inputs
-
isInitialized
public boolean isInitialized()
Returns whether the block is initialized.
- Specified by:
isInitialized in interface Block
- Returns:
whether the block is initialized
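prepare is the hook for parameters whose shape depends on the input, as described in the class overview. The sketch below uses hypothetical ToyParameter and ToyLinear classes: the weight shape is unknown until prepare sees the input's last dimension, and the block reports itself initialized only once every parameter has values (this toy has no children, so direct and recursive parameters coincide).

```java
import java.util.LinkedHashMap;
import java.util.Map;

// Hypothetical parameter whose shape may be unknown until the input is seen (not DJL code).
class ToyParameter {
    int[] shape;     // null until set
    double[] values; // null until initialized

    void setShape(int... shape) {
        this.shape = shape.clone();
    }

    boolean isInitialized() {
        return values != null;
    }

    void initialize() {
        if (shape == null) {
            throw new IllegalStateException("shape must be set before initialization");
        }
        int size = 1;
        for (int d : shape) {
            size *= d;
        }
        values = new double[size];
    }
}

// Toy fully connected layer: the weight shape depends on the input's last dimension.
class ToyLinear {
    final int units;
    final Map<String, ToyParameter> parameters = new LinkedHashMap<>();
    final ToyParameter weight = new ToyParameter();

    ToyLinear(int units) {
        this.units = units;
        parameters.put("weight", weight); // registered now, shaped later
    }

    // Analogue of prepare(Shape[]): derive parameter shapes from the input shapes.
    void prepare(int[][] inputShapes) {
        int[] input = inputShapes[0];
        weight.setShape(units, input[input.length - 1]);
    }

    void initialize(int[]... inputShapes) {
        prepare(inputShapes);
        for (ToyParameter p : parameters.values()) {
            p.initialize();
        }
    }

    // Initialized only when every parameter is.
    boolean isInitialized() {
        for (ToyParameter p : parameters.values()) {
            if (!p.isInitialized()) {
                return false;
            }
        }
        return true;
    }
}
```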
-
clear
public void clear()
Closes all the parameters of the block. All the updates made during training will be lost.
-
cast
public void cast(DataType dataType)
Guaranteed to throw an exception; not yet implemented.
-
saveParameters
public void saveParameters(java.io.DataOutputStream os) throws java.io.IOException
Writes the parameters of the block to the given output stream.
- Specified by:
saveParameters in interface Block
- Parameters:
os - the output stream to save the parameters to
- Throws:
java.io.IOException - if an I/O error occurs
-
loadParameters
public void loadParameters(NDManager manager, java.io.DataInputStream is) throws java.io.IOException, MalformedModelException
Loads the parameters from the given input stream.
- Specified by:
loadParameters in interface Block
- Parameters:
manager - an NDManager to create the parameter arrays
is - the input stream that streams the parameter values
- Throws:
java.io.IOException - if an I/O error occurs
MalformedModelException - if the model file is corrupted or unsupported
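saveParameters and loadParameters work on plain DataOutputStream and DataInputStream objects, so a round trip can be tested entirely in memory. The sketch below is not DJL's file format: ToyPersistentBlock is hypothetical and writes values positionally, which only works because an insertion-ordered LinkedHashMap makes save and load visit parameters in the same order (presumably one reason the parameters field is a LinkedHashMap).

```java
import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.io.IOException;
import java.util.LinkedHashMap;
import java.util.Map;

// Hypothetical sketch, not DJL's format: values are written positionally, so loading
// relies on the same parameters being registered in the same order.
class ToyPersistentBlock {
    // Insertion-ordered, like the parameters field of AbstractBlock.
    final Map<String, double[]> parameters = new LinkedHashMap<>();

    void saveParameters(DataOutputStream os) throws IOException {
        os.writeInt(parameters.size());
        for (double[] values : parameters.values()) {
            os.writeInt(values.length);
            for (double v : values) {
                os.writeDouble(v);
            }
        }
        os.flush();
    }

    void loadParameters(DataInputStream is) throws IOException {
        int count = is.readInt();
        if (count != parameters.size()) {
            throw new IOException("Expected " + parameters.size() + " parameters but found " + count);
        }
        // Same iteration order as saveParameters, so each value lands in the right slot.
        for (Map.Entry<String, double[]> entry : parameters.entrySet()) {
            double[] values = new double[is.readInt()];
            for (int i = 0; i < values.length; i++) {
                values[i] = is.readDouble();
            }
            entry.setValue(values);
        }
    }
}
```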
-
saveMetadata
protected void saveMetadata(java.io.DataOutputStream os) throws java.io.IOException
Override this method to save additional data apart from parameter values. This default implementation saves the currently set input shapes.
- Parameters:
os - the non-null output stream the parameter values and metadata are written to
- Throws:
java.io.IOException - saving failed
-
loadMetadata
protected void loadMetadata(byte loadVersion, java.io.DataInputStream is) throws java.io.IOException, MalformedModelException
Override this to load additional metadata with the parameter values. If you override saveMetadata(DataOutputStream) or need to provide backward compatibility to older binary formats, you probably need to override this. This default implementation checks whether the version number matches; if not, it throws a MalformedModelException. After that, it restores the input shapes.
- Parameters:
loadVersion - the version used for loading this metadata
is - the input stream we are loading from
- Throws:
java.io.IOException - loading failed
MalformedModelException - data can be loaded but has the wrong format
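The sketch below shows the backward-compatibility pattern described for loadMetadata: the default implementation rejects any version other than its own, and a subclass that added a field in version 2 still accepts version 1 data. ToyBlock and ToyVersionedBlock are hypothetical; an IOException stands in for DJL's MalformedModelException, and the layout (a version byte followed by metadata) is an assumption of the sketch.

```java
import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.io.IOException;

// Hypothetical base with the two metadata hooks (not DJL's AbstractBlock).
abstract class ToyBlock {
    protected final byte version;

    protected ToyBlock(byte version) {
        this.version = version;
    }

    // Layout assumed here: one version byte, then whatever saveMetadata writes.
    final void save(DataOutputStream os) throws IOException {
        os.writeByte(version);
        saveMetadata(os);
        os.flush();
    }

    final void load(DataInputStream is) throws IOException {
        loadMetadata(is.readByte(), is);
    }

    protected void saveMetadata(DataOutputStream os) throws IOException {
        // Nothing extra by default in this toy.
    }

    // Default: accept only the exact version this block writes.
    protected void loadMetadata(byte loadVersion, DataInputStream is) throws IOException {
        if (loadVersion != version) {
            throw new IOException("Unsupported version " + loadVersion + ", expected " + version);
        }
    }
}

// Version 2 added a "temperature" value; version 1 data is still accepted.
class ToyVersionedBlock extends ToyBlock {
    static final byte VERSION = 2;
    float temperature = 1.0f;

    ToyVersionedBlock() {
        super(VERSION);
    }

    @Override
    protected void saveMetadata(DataOutputStream os) throws IOException {
        os.writeFloat(temperature);
    }

    @Override
    protected void loadMetadata(byte loadVersion, DataInputStream is) throws IOException {
        if (loadVersion == VERSION) {
            temperature = is.readFloat();
        } else if (loadVersion == 1) {
            temperature = 1.0f; // field did not exist in version 1: use the default
        } else {
            super.loadMetadata(loadVersion, is); // rejects every other version
        }
    }
}
```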
-
saveInputShapes
protected void saveInputShapes(java.io.DataOutputStream os) throws java.io.IOException
- Throws:
java.io.IOException
-
readInputShapes
protected void readInputShapes(java.io.DataInputStream is) throws java.io.IOException
- Throws:
java.io.IOException
-
toString
public java.lang.String toString()
- Overrides:
toString
in class java.lang.Object