Interface Block
- All Known Subinterfaces: SymbolBlock
- All Known Implementing Classes: AbstractBaseBlock, AbstractBlock, AbstractSymbolBlock, BatchNorm, BertBlock, BertMaskedLanguageModelBlock, BertNextSentenceBlock, BertPretrainingBlock, ConstantEmbedding, Conv1d, Conv1dTranspose, Conv2d, Conv2dTranspose, Conv3d, Convolution, Decoder, Deconvolution, Dropout, Embedding, Encoder, EncoderDecoder, GhostBatchNorm, GRU, IdEmbedding, LambdaBlock, LayerNorm, Linear, LinearCollection, LSTM, Multiplication, ParallelBlock, PointwiseFeedForwardBlock, Prelu, RecurrentBlock, RNN, ScaledDotProductAttentionBlock, SequentialBlock, SparseMax, TrainableTextEmbedding, TrainableWordEmbedding, TransformerEncoderBlock
public interface Block
A Block is a composable function that forms a neural network.
Blocks serve a purpose similar to functions that convert an input NDList to an output NDList. They can represent single operations, parts of a neural network, or even the whole neural network. What makes blocks special is that they contain parameters that are used in their function and are trained during deep learning. As these parameters are trained, the functions represented by the blocks become more accurate. Each block consists of the following components:
- Forward function
- Parameters
- Child blocks
The core purpose of a Block is to perform an operation on the inputs and return an output. This operation is defined in the forward method. The forward function can be defined explicitly in terms of parameters or implicitly, and can be a combination of the functions of the child blocks.
The parameters of a Block are instances of Parameter that are required for the operation in the forward function. For example, in a Conv2d block, the parameters are weight and bias. During training, these parameters are updated to reflect the training data, and that forms the crux of learning.
When building these block functions, the easiest way is to use composition. Similar to how functions are built by calling other functions, blocks can be built by combining other blocks. We refer to the containing block as the parent and the sub-blocks as the children.
We provide helpers for creating two common structures of blocks. For blocks that call children in a chain, use SequentialBlock. If a block calls all of the children in parallel and then combines their results, use ParallelBlock. For blocks that do not fit these structures, you should directly extend the AbstractBlock class.
A block does not necessarily have to have children and parameters. For example, SequentialBlock and ParallelBlock don't have any parameters, but do have child blocks. Similarly, Conv2d does not have children, but has parameters. There can be special cases where blocks have neither parameters nor children. One such example is LambdaBlock, which takes in a function and applies that function to its input in the forward method.
Now that we understand the components of the block, we can explore what the block really represents. A block combined with the recursive, hierarchical structure of its children forms a network. It takes in the input to the network, performs its operation, and returns the output of the network. When a block is added as a child of another block, it becomes a sub-network of that block.
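The parent/child composition described above can be illustrated with a small sketch in plain Java. It is a toy model only: the class ToyComposition and its helpers sequential, parallel, and sum are hypothetical names invented for this illustration, a double[] stands in for an NDList, and none of this is the library's actual API or implementation.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.function.Function;
import java.util.function.UnaryOperator;

/** Toy model of block composition; none of these types are DJL classes. */
public class ToyComposition {

    /** Chains children: the output of one child is the input of the next. */
    public static UnaryOperator<double[]> sequential(List<UnaryOperator<double[]>> children) {
        return input -> {
            double[] current = input;
            for (UnaryOperator<double[]> child : children) {
                current = child.apply(current);
            }
            return current;
        };
    }

    /** Feeds the same input to every child, then merges the outputs with combine. */
    public static UnaryOperator<double[]> parallel(
            Function<List<double[]>, double[]> combine,
            List<UnaryOperator<double[]>> children) {
        return input -> {
            List<double[]> outputs = new ArrayList<>();
            for (UnaryOperator<double[]> child : children) {
                outputs.add(child.apply(input));
            }
            return combine.apply(outputs);
        };
    }

    /** Element-wise sum of equally sized arrays, used here as the combine step. */
    public static double[] sum(List<double[]> arrays) {
        double[] result = new double[arrays.get(0).length];
        for (double[] array : arrays) {
            for (int i = 0; i < result.length; i++) {
                result[i] += array[i];
            }
        }
        return result;
    }

    /** Builds a small two-level network and runs it on the given input. */
    public static double[] run(double[] input) {
        // Two parameter-free leaf functions, comparable in spirit to lambda-style blocks.
        UnaryOperator<double[]> doubleIt = x -> Arrays.stream(x).map(v -> v * 2).toArray();
        UnaryOperator<double[]> addOne = x -> Arrays.stream(x).map(v -> v + 1).toArray();

        // Parent: first y = x * 2, then (y * 2) + (y + 1) via a parallel child.
        UnaryOperator<double[]> network = sequential(Arrays.asList(
                doubleIt,
                parallel(ToyComposition::sum, Arrays.asList(doubleIt, addOne))));
        return network.apply(input);
    }

    public static void main(String[] args) {
        System.out.println(Arrays.toString(run(new double[] {1, 2}))); // prints [7.0, 13.0]
    }
}
```

In this sketch the parallel child is itself a child of the sequential parent, which mirrors how a block added to another block becomes a sub-network of it.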
The life-cycle of a block has 3 stages:
- Construction
- Initialization
- Training
Construction is the process of building the network. During this stage, blocks are created with appropriate arguments and the desired network is built by creating a hierarchy of parent and child blocks. At this stage, it is a bare-bones network. The parameter values are not created and the shapes of the inputs are not known. The block is ready for initialization.
Initialization is the process of initializing all the parameters of the block and its children, according to the inputs expected. It involves setting an Initializer and deciding the DataType and the shapes of the input. The parameter arrays are NDArray instances that are initialized according to the Initializer set. At this stage, the block expects a specific type of input and is ready to be trained.
Training is when we start feeding the training data as input to the block, get the output, and try to update the parameters to learn. For more information about training, please refer to the Javadoc for Trainer. At the end of training, a block represents a fully-trained model.
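The three life-cycle stages can be sketched with a self-contained toy in plain Java. Everything here is an assumption made for illustration: ToyLifeCycle, its constant-value initialization, and its trainStep method are hypothetical and do not reproduce how the library or its Trainer work. The point is only that parameter arrays do not exist after construction, are allocated once the input size is known, and are then updated during training.

```java
import java.util.Arrays;

/** Toy model of the three life-cycle stages; this is not a DJL class. */
public class ToyLifeCycle {

    private final int units;    // known at construction
    private double[][] weight;  // created only at initialization

    /** Construction: only the structure is fixed; no parameter values exist yet. */
    public ToyLifeCycle(int units) {
        this.units = units;
    }

    /** Initialization: the input size is now known, so the parameter array is allocated. */
    public void initialize(int inputSize, double initialValue) {
        weight = new double[units][inputSize];
        for (double[] row : weight) {
            Arrays.fill(row, initialValue);
        }
    }

    public boolean isInitialized() {
        return weight != null;
    }

    /** Forward function: a plain matrix-vector product. */
    public double[] forward(double[] input) {
        if (!isInitialized()) {
            throw new IllegalStateException("initialize must be called before forward");
        }
        double[] output = new double[units];
        for (int u = 0; u < units; u++) {
            for (int i = 0; i < input.length; i++) {
                output[u] += weight[u][i] * input[i];
            }
        }
        return output;
    }

    /** Training: one gradient-descent step on half the squared error of one example. */
    public void trainStep(double[] input, double[] target, double learningRate) {
        double[] output = forward(input);
        for (int u = 0; u < units; u++) {
            double error = output[u] - target[u];
            for (int i = 0; i < input.length; i++) {
                weight[u][i] -= learningRate * error * input[i];
            }
        }
    }

    public static void main(String[] args) {
        ToyLifeCycle block = new ToyLifeCycle(1);   // construction
        System.out.println(block.isInitialized());  // prints false
        block.initialize(2, 0.0);                   // initialization
        double[] x = {1.0, 2.0};
        double[] y = {3.0};
        for (int step = 0; step < 100; step++) {    // training
            block.trainStep(x, y, 0.1);
        }
        System.out.println(block.forward(x)[0]);    // close to the target 3.0
    }
}
```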
-
-
Method Summary
| Modifier and Type | Method | Description |
| --- | --- | --- |
| void | cast(DataType dataType) | Guaranteed to throw an exception. |
| void | clear() | Closes all the parameters of the block. |
| ai.djl.util.PairList<java.lang.String,Shape> | describeInput() | Returns a PairList of input names, and shapes. |
| default NDList | forward(ParameterStore parameterStore, NDList inputs, boolean training) | Applies the operating function of the block once. |
| NDList | forward(ParameterStore parameterStore, NDList inputs, boolean training, ai.djl.util.PairList<java.lang.String,java.lang.Object> params) | Applies the operating function of the block once. |
| default NDList | forward(ParameterStore parameterStore, NDList data, NDList labels, ai.djl.util.PairList<java.lang.String,java.lang.Object> params) | A forward call using both training data and labels. |
| default void | freezeParameters(boolean freeze) | Freezes or unfreezes all parameters inside the block for training. |
| BlockList | getChildren() | Returns a list of all the children of the block. |
| ParameterList | getDirectParameters() | Returns a list of all the direct parameters of the block. |
| Shape[] | getInputShapes() | Returns the input shapes of the block. |
| Shape[] | getOutputShapes(Shape[] inputShapes) | Returns the expected output shapes of the block for the specified input shapes. |
| ParameterList | getParameters() | Returns a list of all the parameters of the block, including the parameters of its children fetched recursively. |
| void | initialize(NDManager manager, DataType dataType, Shape... inputShapes) | Initializes the parameters of the block. |
| boolean | isInitialized() | Returns a boolean whether the block is initialized (block has inputShape and params have nonNull array). |
| void | loadParameters(NDManager manager, java.io.DataInputStream is) | Loads the parameters from the given input stream. |
| void | saveParameters(java.io.DataOutputStream os) | Writes the parameters of the block to the given outputStream. |
| void | setInitializer(Initializer initializer, Parameter.Type type) | Sets an Initializer to all the parameters that match parameter type in the block. |
| void | setInitializer(Initializer initializer, java.lang.String paramName) | Sets an Initializer to the specified direct parameter of the block, overriding the initializer of the parameter, if already set. |
| void | setInitializer(Initializer initializer, java.util.function.Predicate<Parameter> predicate) | Sets an Initializer to all the parameters that match Predicate in the block. |
| static void | validateLayout(LayoutType[] expectedLayout, LayoutType[] actualLayout) | Validates that actual layout matches the expected layout. |
-
-
-
Method Detail
-
forward
default NDList forward(ParameterStore parameterStore, NDList inputs, boolean training)
Applies the operating function of the block once. This method should be called only on blocks that are initialized.
- Parameters:
parameterStore - the parameter store
inputs - the input NDList
training - true for a training forward pass
- Returns:
the output of the forward pass
-
forward
NDList forward(ParameterStore parameterStore, NDList inputs, boolean training, ai.djl.util.PairList<java.lang.String,java.lang.Object> params)
Applies the operating function of the block once. This method should be called only on blocks that are initialized.
- Parameters:
parameterStore - the parameter store
inputs - the input NDList
training - true for a training forward pass
params - optional parameters
- Returns:
the output of the forward pass
-
forward
default NDList forward(ParameterStore parameterStore, NDList data, NDList labels, ai.djl.util.PairList<java.lang.String,java.lang.Object> params)
A forward call using both training data and labels. Within this forward call, it can be assumed that training is true.
- Parameters:
parameterStore - the parameter store
data - the input data NDList
labels - the input labels NDList
params - optional parameters
- Returns:
the output of the forward pass
- See Also:
forward(ParameterStore, NDList, boolean, PairList)
-
setInitializer
void setInitializer(Initializer initializer, Parameter.Type type)
Sets an Initializer on all the parameters in the block that match the given parameter type.
- Parameters:
initializer - the initializer to set
type - the Parameter.Type of the parameters to set the initializer on
-
setInitializer
void setInitializer(Initializer initializer, java.lang.String paramName)
Sets an Initializer to the specified direct parameter of the block, overriding the initializer of the parameter, if already set.
- Parameters:
initializer - the initializer to be set
paramName - the name of the parameter
-
setInitializer
void setInitializer(Initializer initializer, java.util.function.Predicate<Parameter> predicate)
Sets an Initializer on all the parameters in the block that match the given Predicate.
- Parameters:
initializer - the initializer to be set
predicate - a predicate that selects the parameters to set the initializer on
-
initialize
void initialize(NDManager manager, DataType dataType, Shape... inputShapes)
Initializes the parameters of the block. This method must be called before calling `forward`.
- Parameters:
manager - the NDManager to initialize the parameters
dataType - the datatype of the parameters
inputShapes - the shapes of the inputs to the block
-
isInitialized
boolean isInitialized()
Returns whether the block is initialized (the block has input shapes and its parameters have non-null arrays).
- Returns:
whether the block is initialized
-
cast
void cast(DataType dataType)
Guaranteed to throw an exception. Not yet implemented.
- Parameters:
dataType - the data type to cast to
- Throws:
java.lang.UnsupportedOperationException - always
-
clear
void clear()
Closes all the parameters of the block. All the updates made during training will be lost.
-
describeInput
ai.djl.util.PairList<java.lang.String,Shape> describeInput()
Returns a PairList of input names and shapes.
- Returns:
the PairList of input names and shapes
-
getChildren
BlockList getChildren()
Returns a list of all the children of the block.
- Returns:
the list of child blocks
-
getDirectParameters
ParameterList getDirectParameters()
Returns a list of all the direct parameters of the block.
- Returns:
the list of Parameter
-
getParameters
ParameterList getParameters()
Returns a list of all the parameters of the block, including the parameters of its children fetched recursively.
- Returns:
the list of all parameters of the block
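The difference between direct parameters and the recursively collected parameters can be shown with a toy block tree in plain Java. This is a sketch under stated assumptions: ToyBlockTree is a hypothetical class, parameters are represented by plain strings, and the depth-first order (a block's own parameters first, then each child's) is chosen for the illustration and is not confirmed by this page as the order the library uses.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

/** Toy block tree; the class and its fields are illustrative, not DJL types. */
public class ToyBlockTree {

    private final List<String> directParameters = new ArrayList<>();
    private final List<ToyBlockTree> children = new ArrayList<>();

    public ToyBlockTree(String... parameterNames) {
        directParameters.addAll(Arrays.asList(parameterNames));
    }

    /** Adds a child block and returns this block so calls can be chained. */
    public ToyBlockTree add(ToyBlockTree child) {
        children.add(child);
        return this;
    }

    /** Parameters owned by this block only. */
    public List<String> getDirectParameters() {
        return new ArrayList<>(directParameters);
    }

    /** This block's own parameters first, then those of each child, depth-first. */
    public List<String> getParameters() {
        List<String> all = getDirectParameters();
        for (ToyBlockTree child : children) {
            all.addAll(child.getParameters());
        }
        return all;
    }

    /** Builds a parent with two children; the parameter names are made up. */
    public static ToyBlockTree example() {
        ToyBlockTree conv = new ToyBlockTree("conv.weight", "conv.bias");
        ToyBlockTree linear = new ToyBlockTree("linear.weight", "linear.bias");
        // The parent has no parameters of its own, like a purely structural container.
        return new ToyBlockTree().add(conv).add(linear);
    }

    public static void main(String[] args) {
        ToyBlockTree root = example();
        System.out.println(root.getDirectParameters()); // prints []
        System.out.println(root.getParameters());
        // prints [conv.weight, conv.bias, linear.weight, linear.bias]
    }
}
```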
-
getOutputShapes
Shape[] getOutputShapes(Shape[] inputShapes)
Returns the expected output shapes of the block for the specified input shapes.
- Parameters:
inputShapes - the shapes of the inputs
- Returns:
the expected output shapes of the block
-
getInputShapes
Shape[] getInputShapes()
Returns the input shapes of the block. The input shapes are only available after the block is initialized; otherwise an IllegalStateException is thrown.
- Returns:
the input shapes of the block
-
saveParameters
void saveParameters(java.io.DataOutputStream os) throws java.io.IOException
Writes the parameters of the block to the given output stream.
- Parameters:
os - the output stream to save the parameters to
- Throws:
java.io.IOException - if an I/O error occurs
-
loadParameters
void loadParameters(NDManager manager, java.io.DataInputStream is) throws java.io.IOException, MalformedModelException
Loads the parameters from the given input stream.
- Parameters:
manager - an NDManager to create the parameter arrays
is - the input stream from which the parameter values are read
- Throws:
java.io.IOException - if an I/O error occurs
MalformedModelException - if the model file is corrupted or unsupported
-
freezeParameters
default void freezeParameters(boolean freeze)
Freezes or unfreezes all parameters inside the block for training.
- Parameters:
freeze - true if the parameters should be frozen
- See Also:
Parameter.freeze(boolean)
-
validateLayout
static void validateLayout(LayoutType[] expectedLayout, LayoutType[] actualLayout)
Validates that the actual layout matches the expected layout.
- Parameters:
expectedLayout - the expected layout
actualLayout - the actual layout
- Throws:
java.lang.UnsupportedOperationException - if the actual layout does not match the expected layout
-
-