Package ai.djl.nn
Class AbstractBlock
- java.lang.Object
  - ai.djl.nn.AbstractBaseBlock
    - ai.djl.nn.AbstractBlock
All Implemented Interfaces:
Block
Direct Known Subclasses:
BatchNorm, BertBlock, BertMaskedLanguageModelBlock, BertNextSentenceBlock, BertPretrainingBlock, ConstantEmbedding, Convolution, Decoder, Deconvolution, Dropout, Embedding, Encoder, EncoderDecoder, IdEmbedding, LambdaBlock, LayerNorm, Linear, LinearCollection, Multiplication, ParallelBlock, Prelu, RecurrentBlock, ScaledDotProductAttentionBlock, SequentialBlock, SparseMax, TrainableTextEmbedding, TransformerEncoderBlock
public abstract class AbstractBlock extends AbstractBaseBlock
AbstractBlock is an abstract implementation of Block. It is recommended that all Block classes that have children extend AbstractBlock. To create your own block, do the following:
- Define a version for serializing parameters and metadata, and pass it to the parent constructor.
- Use addParameter(Parameter) in the constructor to add parameters to your block, if necessary.
- Use addChildBlock(String, Block) to add child blocks, if necessary.
- Override Block.getOutputShapes(Shape[]) to determine the shape of your custom block's output based on the input it will receive.
- If you added child blocks, override AbstractBaseBlock.initializeChildBlocks(NDManager, DataType, Shape...) to initialize them based on the input shape your block will receive. You can skip this if your block does not contain child blocks.
- Override AbstractBaseBlock.forward(ParameterStore, NDList, boolean, PairList) to implement the computation of your block.
- If, and only if, you need to save data apart from the parameter values of your block, override AbstractBaseBlock.saveMetadata(DataOutputStream) and AbstractBaseBlock.loadMetadata(byte, java.io.DataInputStream). If you do not need to save or load any state other than parameters, you can skip this.
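The versioning and metadata steps can be illustrated without DJL on the classpath. The sketch below is a stdlib-only model, not DJL code: the class VersionedMetadataSketch, its numHeads field, and the VERSION value are hypothetical. It assumes that the framework writes the version byte before calling saveMetadata and passes the version it reads back as the byte argument of loadMetadata(byte, DataInputStream), as that signature suggests.

```java
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.io.IOException;
import java.io.UncheckedIOException;

// Hypothetical stdlib-only model of a block with versioned, non-parameter state.
class VersionedMetadataSketch {
    // Hypothetical version that a real block would pass to its parent constructor.
    static final byte VERSION = 2;

    // Hypothetical state that is not a parameter, so it must be saved as metadata.
    private int numHeads;

    VersionedMetadataSketch(int numHeads) {
        this.numHeads = numHeads;
    }

    int getNumHeads() {
        return numHeads;
    }

    // Counterpart of saveMetadata(DataOutputStream): writes only this block's own state.
    void saveMetadata(DataOutputStream os) throws IOException {
        os.writeInt(numHeads);
    }

    // Counterpart of loadMetadata(byte, DataInputStream): the version has already been read.
    void loadMetadata(byte loadVersion, DataInputStream is) throws IOException {
        if (loadVersion != VERSION) {
            throw new IllegalStateException("Unsupported version: " + loadVersion);
        }
        numHeads = is.readInt();
    }

    // Assumed framework side of saving: version byte first, then the metadata.
    byte[] save() {
        ByteArrayOutputStream bytes = new ByteArrayOutputStream();
        try (DataOutputStream os = new DataOutputStream(bytes)) {
            os.writeByte(VERSION);
            saveMetadata(os);
        } catch (IOException e) {
            throw new UncheckedIOException(e);
        }
        return bytes.toByteArray();
    }

    // Assumed framework side of loading: read the version, then delegate to loadMetadata.
    static VersionedMetadataSketch load(byte[] data) {
        VersionedMetadataSketch block = new VersionedMetadataSketch(0);
        try (DataInputStream is = new DataInputStream(new ByteArrayInputStream(data))) {
            byte loadVersion = is.readByte();
            block.loadMetadata(loadVersion, is);
        } catch (IOException e) {
            throw new UncheckedIOException(e);
        }
        return block;
    }
}
```

Rejecting an unknown version early, as this sketch does, keeps data written in an older layout from being silently misread; a block that must stay backward compatible could instead branch on loadVersion.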
If you use addParameter(Parameter) to add parameters, you have to take care of their shapes yourself so that they can be initialized: call setShape on a parameter if you know its shape in advance, or implement prepare to call setShape once you see the input shape.
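A minimal sketch of the second option, setting a shape lazily once the input shape is known. It uses hypothetical stand-in classes (SketchParameter, SketchDense) rather than DJL's Parameter and Block types; the (units, inputFeatures) weight layout and the use of the last input dimension are assumptions chosen for illustration.

```java
// Hypothetical stand-in for a parameter whose shape may be unknown at construction.
class SketchParameter {
    private final String name;
    private long[] shape; // stays null until the shape is known

    SketchParameter(String name) {
        this.name = name;
    }

    String getName() {
        return name;
    }

    void setShape(long... shape) {
        this.shape = shape.clone();
    }

    long[] getShape() {
        return shape == null ? null : shape.clone();
    }

    boolean isShapeKnown() {
        return shape != null;
    }
}

// Hypothetical dense layer: the bias shape is known up front, the weight shape is not.
class SketchDense {
    private final long units;
    private final SketchParameter weight = new SketchParameter("weight");
    private final SketchParameter bias = new SketchParameter("bias");

    SketchDense(long units) {
        this.units = units;
        bias.setShape(units); // option 1: shape known in the constructor
    }

    // Option 2: set the weight shape once the input shape is seen.
    void prepare(long[] inputShape) {
        long inFeatures = inputShape[inputShape.length - 1];
        weight.setShape(units, inFeatures);
    }

    SketchParameter getWeight() {
        return weight;
    }

    SketchParameter getBias() {
        return bias;
    }
}
```

Deferring the weight shape means the same block can be reused for inputs of different widths without the caller repeating that width in the constructor.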
Field Summary
Fields
- protected BlockList children: All direct children of this Block.
- protected java.util.LinkedHashMap<java.lang.String,Parameter> parameters: All direct parameters of this Block.
Fields inherited from class ai.djl.nn.AbstractBaseBlock
inputNames, inputShapes, version
Constructor Summary
Constructors
- AbstractBlock(): Constructs a new AbstractBlock instance.
- AbstractBlock(byte version): Builds an empty block with the given version for parameter serialization.
Method Summary
- protected <B extends Block> B addChildBlock(java.lang.String name, B block): Use this to add a child block to this block.
- protected LambdaBlock addChildBlock(java.lang.String name, java.util.function.Function<NDList,NDList> f): Adds a LambdaBlock as a child block to this block.
- protected LambdaBlock addChildBlockSingleton(java.lang.String name, java.util.function.Function<NDArray,NDArray> f): Adds a LambdaBlock.singleton(Function) as a child block to this block.
- protected <P extends Parameter> P addParameter(P parameter): Adds a parameter to this block.
- BlockList getChildren(): Returns a list of all the children of the block.
- ParameterList getDirectParameters(): Returns a list of all the direct parameters of the block.
Methods inherited from class ai.djl.nn.AbstractBaseBlock
beforeInitialize, cast, clear, describeInput, forward, forward, forwardInternal, forwardInternal, getInputShapes, getParameters, initialize, initializeChildBlocks, isInitialized, loadMetadata, loadParameters, prepare, readInputShapes, saveInputShapes, saveMetadata, saveParameters, setInitializer, setInitializer, setInitializer, toString
Methods inherited from class java.lang.Object
clone, equals, finalize, getClass, hashCode, notify, notifyAll, wait, wait, wait
Methods inherited from interface ai.djl.nn.Block
forward, freezeParameters, getOutputShapes
Field Detail
children
protected BlockList children
All direct children of this Block. Keys are the names of the blocks. Use the addChildBlock(String, Block) method to add children. All children in this map are automatically loaded and saved.
parameters
protected java.util.LinkedHashMap<java.lang.String,Parameter> parameters
All direct parameters of this Block. Keys are the names of the parameters. Use the addParameter(Parameter) method to add parameters. All parameters in this map are automatically loaded and saved.
Method Detail
addChildBlock
protected final <B extends Block> B addChildBlock(java.lang.String name, B block)
Use this to add a child block to this block.
Type Parameters:
B - the type of block
Parameters:
name - the name of the block; must not be null
block - the block; must not be null
Returns:
the block given as a parameter, so that the block can be created and assigned to a member variable in one step
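Because the method is generic in B and returns its argument, a subclass can register a child and keep it in a field of its concrete type without a cast. A stdlib-only sketch of that idiom follows; SketchBlock, SketchLinear, SketchParentBlock, and SketchMlp are hypothetical, and the sketch uses the name verbatim as the map key, which is a simplification that may differ from how the real method builds its keys.

```java
import java.util.LinkedHashMap;
import java.util.Map;
import java.util.Objects;

// Hypothetical marker interface standing in for Block.
interface SketchBlock {}

// Hypothetical leaf block.
class SketchLinear implements SketchBlock {
    final int units;

    SketchLinear(int units) {
        this.units = units;
    }
}

// Hypothetical parent block with insertion-ordered, name-keyed children.
abstract class SketchParentBlock implements SketchBlock {
    protected final Map<String, SketchBlock> children = new LinkedHashMap<>();

    protected final <B extends SketchBlock> B addChildBlock(String name, B block) {
        Objects.requireNonNull(name, "name must not be null");
        Objects.requireNonNull(block, "block must not be null");
        children.put(name, block);
        return block; // the same object, still typed as B
    }
}

// Register and assign in one line per child; no cast back to SketchLinear is needed.
class SketchMlp extends SketchParentBlock {
    final SketchLinear hidden = addChildBlock("hidden", new SketchLinear(64));
    final SketchLinear output = addChildBlock("output", new SketchLinear(10));
}
```

The superclass map is created before the subclass field initializers run, so calling addChildBlock from a field initializer is safe here.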
addChildBlock
protected LambdaBlock addChildBlock(java.lang.String name, java.util.function.Function<NDList,NDList> f)
Adds a LambdaBlock as a child block to this block.
Parameters:
name - the name of the block; must not be null
f - the function that forms the LambdaBlock
Returns:
the child block
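A stdlib-only model of what this overload does: wrap a list-to-list function in a block and register it as a named child. ToyLambdaBlock and ToyLambdaParent are hypothetical, List<double[]> stands in for NDList, and this is not DJL's LambdaBlock implementation.

```java
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.function.Function;

// Hypothetical stand-in for LambdaBlock; List<double[]> plays the role of NDList.
class ToyLambdaBlock {
    private final Function<List<double[]>, List<double[]>> fn;

    ToyLambdaBlock(Function<List<double[]>, List<double[]>> fn) {
        this.fn = fn;
    }

    // Forwarding simply applies the wrapped function.
    List<double[]> forward(List<double[]> inputs) {
        return fn.apply(inputs);
    }
}

// Hypothetical parent block that wraps functions as named children.
class ToyLambdaParent {
    final Map<String, ToyLambdaBlock> children = new LinkedHashMap<>();

    ToyLambdaBlock addChildBlock(String name, Function<List<double[]>, List<double[]>> f) {
        ToyLambdaBlock block = new ToyLambdaBlock(f);
        children.put(name, block);
        return block;
    }

    // Example list-to-list function: element-wise sum of the first two arrays.
    static List<double[]> elementwiseSum(List<double[]> inputs) {
        double[] a = inputs.get(0);
        double[] b = inputs.get(1);
        double[] out = new double[a.length];
        for (int i = 0; i < a.length; i++) {
            out[i] = a[i] + b[i];
        }
        return List.of(out);
    }
}
```

Usage: `parent.addChildBlock("add", ToyLambdaParent::elementwiseSum)` gives a stateless operation a name in the children map, next to children that do carry parameters.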
addChildBlockSingleton
protected final LambdaBlock addChildBlockSingleton(java.lang.String name, java.util.function.Function<NDArray,NDArray> f)
Adds a LambdaBlock.singleton(Function) as a child block to this block.
Parameters:
name - the name of the block; must not be null
f - the function that forms the LambdaBlock
Returns:
the child block
See Also:
LambdaBlock.singleton(Function)
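The singleton variant adapts a function on one array into a function on a list of arrays. The sketch below models that adaptation with plain Java types (double[] for NDArray, List<double[]> for NDList); ToySingleton and relu are hypothetical, and rejecting inputs that do not contain exactly one array is an assumption about the wrapper's behavior.

```java
import java.util.List;
import java.util.function.Function;

// Hypothetical helper modelling the singleton adaptation with plain Java types.
final class ToySingleton {
    private ToySingleton() {}

    // Lifts a one-array function to a list-to-list function.
    // Assumption: a list that does not hold exactly one array is rejected.
    static Function<List<double[]>, List<double[]>> singleton(Function<double[], double[]> f) {
        return inputs -> {
            if (inputs.size() != 1) {
                throw new IllegalArgumentException("Expected exactly one input, got " + inputs.size());
            }
            return List.of(f.apply(inputs.get(0)));
        };
    }

    // Example one-array function: element-wise ReLU.
    static double[] relu(double[] x) {
        double[] out = new double[x.length];
        for (int i = 0; i < x.length; i++) {
            out[i] = Math.max(0.0, x[i]);
        }
        return out;
    }
}
```

This lets element-wise operations be written against a single array while still fitting the list-in, list-out shape that child blocks share.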
addParameter
protected final <P extends Parameter> P addParameter(P parameter)
Adds a parameter to this block. If parameters are added with this method, initialization of the parameter works out of the box.
Type Parameters:
P - the specific parameter subclass
Parameters:
parameter - the parameter to add; must not be null
Returns:
the parameter passed as an argument, which makes it easier to create and assign parameters in one line
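A stdlib-only sketch of the register-and-return pattern, assuming (as the description of the parameters field states) that each parameter is stored under its own name in an insertion-ordered map. ToyParam, ToyParamBlock, and ToyAffine are hypothetical names, not DJL classes.

```java
import java.util.ArrayList;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Objects;

// Hypothetical stand-in for Parameter: only the name matters here.
class ToyParam {
    private final String name;

    ToyParam(String name) {
        this.name = name;
    }

    String getName() {
        return name;
    }
}

// Hypothetical base block holding its direct parameters by name, in insertion order.
class ToyParamBlock {
    protected final LinkedHashMap<String, ToyParam> parameters = new LinkedHashMap<>();

    // Stores the parameter under its own name and returns the same object.
    protected final <P extends ToyParam> P addParameter(P parameter) {
        Objects.requireNonNull(parameter, "parameter must not be null");
        parameters.put(parameter.getName(), parameter);
        return parameter;
    }

    List<String> parameterNames() {
        return new ArrayList<>(parameters.keySet());
    }
}

// Registration and field assignment happen in one line per parameter.
class ToyAffine extends ToyParamBlock {
    final ToyParam weight = addParameter(new ToyParam("weight"));
    final ToyParam bias = addParameter(new ToyParam("bias"));
}
```

Using a LinkedHashMap keeps the registration order stable, which gives a deterministic order whenever the parameters are iterated, for example during saving.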
getChildren
public BlockList getChildren()
Returns a list of all the children of the block.
Returns:
the list of child blocks
getDirectParameters
public ParameterList getDirectParameters()
Returns a list of all the direct parameters of the block.
Returns:
the list of Parameter objects
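getDirectParameters() covers only the parameters registered on this block itself. The sketch below assumes the inherited getParameters() additionally collects the parameters of child blocks, and models that distinction with plain maps; ToyNode is hypothetical, the Integer values are placeholders for Parameter objects, and the childName_paramName key format is this sketch's own convention.

```java
import java.util.LinkedHashMap;
import java.util.Map;

// Hypothetical block node; Integer values are placeholders for Parameter objects.
class ToyNode {
    final Map<String, Integer> directParameters = new LinkedHashMap<>();
    final Map<String, ToyNode> children = new LinkedHashMap<>();

    // Only the parameters registered on this node, returned as a copy.
    Map<String, Integer> getDirectParameters() {
        return new LinkedHashMap<>(directParameters);
    }

    // Direct parameters first, then every descendant's, keyed as childName_paramName.
    Map<String, Integer> getAllParameters() {
        Map<String, Integer> all = getDirectParameters();
        for (Map.Entry<String, ToyNode> child : children.entrySet()) {
            for (Map.Entry<String, Integer> p : child.getValue().getAllParameters().entrySet()) {
                all.put(child.getKey() + "_" + p.getKey(), p.getValue());
            }
        }
        return all;
    }
}
```

Returning copies keeps callers from mutating the block's own maps; prefixing with the child name keeps keys unique when two children own parameters with the same name.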