public interface Block
A Block is a composable function that forms a neural network.
Blocks serve a purpose similar to functions that convert an input NDList to an output NDList. They can represent single operations, parts of a neural network, and even the whole neural network. What makes blocks special is that they contain a number of parameters that are used in their function and are trained during deep learning. As these parameters are trained, the functions represented by the blocks get more and more accurate. Each block consists of the following components: a forward function, parameters, and child blocks.
The core purpose of a Block is to perform an operation on the inputs and return an output. This operation is defined in the forward method. The forward function can be defined explicitly in terms of parameters, or implicitly as a combination of the functions of the child blocks.
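The following is a minimal sketch, not DJL code, of what "a function from inputs to outputs" means. It assumes a hypothetical ToyBlock interface in which double[] stands in for NDList. ScaleBlock defines its forward function explicitly in terms of a parameter (a scale factor), while ComposedBlock defines its forward function implicitly by chaining its children. All class names here are illustrative assumptions.

```java
import java.util.Arrays;
import java.util.List;

// Hypothetical, DJL-independent sketch: double[] stands in for NDList.
public class ForwardSketch {

    // A block is a function from an input array to an output array.
    interface ToyBlock {
        double[] forward(double[] inputs);
    }

    // Forward function defined explicitly in terms of a parameter (the scale factor).
    static final class ScaleBlock implements ToyBlock {
        private final double scale;

        ScaleBlock(double scale) {
            this.scale = scale;
        }

        @Override
        public double[] forward(double[] inputs) {
            double[] out = new double[inputs.length];
            for (int i = 0; i < inputs.length; i++) {
                out[i] = inputs[i] * scale;
            }
            return out;
        }
    }

    // Forward function defined implicitly: it only combines its children's functions.
    static final class ComposedBlock implements ToyBlock {
        private final List<ToyBlock> children;

        ComposedBlock(List<ToyBlock> children) {
            this.children = children;
        }

        @Override
        public double[] forward(double[] inputs) {
            double[] x = inputs;
            for (ToyBlock child : children) {
                x = child.forward(x);
            }
            return x;
        }
    }

    public static void main(String[] args) {
        ToyBlock net = new ComposedBlock(List.of(new ScaleBlock(2.0), new ScaleBlock(3.0)));
        System.out.println(Arrays.toString(net.forward(new double[] {1.0, 2.0})));
        // prints [6.0, 12.0]
    }
}
```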
The parameters of a Block are instances of Parameter, which are required for the operation in the forward function. For example, in a Conv2D block, the parameters are weight and bias. During training, these parameters are updated to reflect the training data, and that forms the crux of learning.
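A minimal sketch of a block whose forward function depends on named parameters. For brevity, the convolution of Conv2D is replaced here by a simpler linear operation y = w · x + b. ToyParameter and ToyLinear are hypothetical stand-ins for Parameter and a real layer, and the getDirectParameters method below only borrows the name of the interface method, not its real return type.

```java
import java.util.List;

// Hypothetical sketch (not DJL's API): a block whose forward function uses two
// named parameters, "weight" and "bias", like a single-output linear layer.
public class ParameterSketch {

    // A named array of trainable values (a stand-in for DJL's Parameter).
    static final class ToyParameter {
        final String name;
        final double[] values;

        ToyParameter(String name, double[] values) {
            this.name = name;
            this.values = values;
        }
    }

    // Computes y = weight . x + bias for a single output unit.
    static final class ToyLinear {
        final ToyParameter weight;
        final ToyParameter bias;

        ToyLinear(double[] w, double b) {
            this.weight = new ToyParameter("weight", w);
            this.bias = new ToyParameter("bias", new double[] {b});
        }

        // The parameters owned directly by this block.
        List<ToyParameter> getDirectParameters() {
            return List.of(weight, bias);
        }

        double forward(double[] x) {
            double sum = bias.values[0];
            for (int i = 0; i < x.length; i++) {
                sum += weight.values[i] * x[i];
            }
            return sum;
        }
    }

    public static void main(String[] args) {
        ToyLinear layer = new ToyLinear(new double[] {0.5, -1.0}, 2.0);
        System.out.println(layer.forward(new double[] {4.0, 1.0})); // prints 3.0
    }
}
```

Training would change the values stored in weight and bias, which is what changes the function the block computes.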
When building these block functions, the easiest way is to use composition. Similar to how functions are built by calling other functions, blocks can be built by combining other blocks. We refer to the containing block as the parent and the sub-blocks as the children.
We provide helpers for creating two common structures of blocks. For blocks that call their children in a chain, use SequentialBlock. If a block calls all of its children in parallel and then combines their results, use ParallelBlock. For blocks that do not fit these structures, extend the AbstractBlock class directly.
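The difference between the two structures can be sketched in a few lines of plain Java, assuming a hypothetical scalar ToyBlock interface. ToySequential and ToyParallel are illustrative only and are not the implementations of SequentialBlock or ParallelBlock; in particular, ToyParallel assumes the caller supplies the function that merges the children's outputs.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.function.Function;

// Hypothetical sketch (not DJL's SequentialBlock/ParallelBlock) using scalars for brevity.
public class CompositionSketch {

    interface ToyBlock {
        double forward(double x);
    }

    // Calls its children in a chain: the output of one child is the input of the next.
    static final class ToySequential implements ToyBlock {
        private final List<ToyBlock> children = new ArrayList<>();

        ToySequential add(ToyBlock child) {
            children.add(child);
            return this;
        }

        @Override
        public double forward(double x) {
            for (ToyBlock child : children) {
                x = child.forward(x);
            }
            return x;
        }
    }

    // Calls every child on the same input, then merges the results with a combiner.
    static final class ToyParallel implements ToyBlock {
        private final List<ToyBlock> children = new ArrayList<>();
        private final Function<List<Double>, Double> combiner;

        ToyParallel(Function<List<Double>, Double> combiner) {
            this.combiner = combiner;
        }

        ToyParallel add(ToyBlock child) {
            children.add(child);
            return this;
        }

        @Override
        public double forward(double x) {
            List<Double> outputs = new ArrayList<>();
            for (ToyBlock child : children) {
                outputs.add(child.forward(x));
            }
            return combiner.apply(outputs);
        }
    }

    public static void main(String[] args) {
        ToyBlock addOne = x -> x + 1;
        ToyBlock triple = x -> x * 3;

        ToyBlock chain = new ToySequential().add(addOne).add(triple);
        ToyBlock branches = new ToyParallel(outs -> outs.stream().mapToDouble(Double::doubleValue).sum())
                .add(addOne)
                .add(triple);

        System.out.println(chain.forward(4.0));    // (4 + 1) * 3, prints 15.0
        System.out.println(branches.forward(4.0)); // (4 + 1) + (4 * 3), prints 17.0
    }
}
```

The same two children produce different networks depending only on how the parent calls them, which is why the structure is captured by the parent block rather than by the children.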
A block does not necessarily have both children and parameters. For example, SequentialBlock and ParallelBlock don't have any parameters, but do have child blocks. Similarly, Conv2D does not have children, but has parameters. We recommend extending ParameterBlock to create blocks that don't have children. In special cases, a block has neither parameters nor children. One such example is LambdaBlock, which takes in a function and applies that function to its input in the forward method.
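A sketch of the LambdaBlock idea, a block with no parameters and no children, using a hypothetical ToyLambdaBlock that wraps a java.util.function.UnaryOperator. It is not DJL's LambdaBlock, which operates on NDList rather than double[]; the ReLU function is just an example choice.

```java
import java.util.Arrays;
import java.util.function.UnaryOperator;

// Hypothetical sketch (not DJL's LambdaBlock): a block with neither parameters
// nor children that only applies a stored function to its input.
public class LambdaSketch {

    static final class ToyLambdaBlock {
        private final UnaryOperator<double[]> function;

        ToyLambdaBlock(UnaryOperator<double[]> function) {
            this.function = function;
        }

        // The forward function is exactly the wrapped function; nothing is trained.
        double[] forward(double[] inputs) {
            return function.apply(inputs);
        }
    }

    public static void main(String[] args) {
        // ReLU: replace negative values with zero.
        ToyLambdaBlock relu = new ToyLambdaBlock(
                in -> Arrays.stream(in).map(v -> Math.max(0.0, v)).toArray());
        System.out.println(Arrays.toString(relu.forward(new double[] {-1.5, 0.0, 2.0})));
        // prints [0.0, 0.0, 2.0]
    }
}
```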
Now that we understand the components of the block, we can explore what the block really represents. A block combined with the recursive, hierarchical structure of its children forms a network. It takes in the input to the network, performs its operation, and returns the output of the network. When a block is added as a child of another block, it becomes a sub-network of that block.
The life-cycle of a block has 3 stages:
Construction is the process of building the network. During this stage, blocks are created with appropriate arguments, and the desired network is built by creating a hierarchy of parent and child blocks. At this stage, it is a bare-bones network: the parameter values are not created and the shapes of the inputs are not known. The block is then ready for initialization.
Initialization is the process of initializing all the parameters of the block and its children according to the expected inputs. It involves setting an Initializer and deciding the DataType and the shapes of the input. The parameter arrays are NDArrays that are initialized according to the Initializer that was set. At this stage, the block expects a specific type of input and is ready to be trained.
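The split between construction and initialization can be sketched as follows. ToyDense is hypothetical: its constructor records only the number of outputs, and its initialize method allocates the weight array once the input size is known, loosely mirroring how initialize receives input shapes and returns output shapes. A java.util.function.DoubleSupplier plays the role of the Initializer, and the DataType decision is omitted (everything is double).

```java
import java.util.Arrays;
import java.util.Random;
import java.util.function.DoubleSupplier;

// Hypothetical sketch (not DJL's API) of deferred parameter creation.
public class LifecycleSketch {

    // Dense block computing `units` outputs: y = weight * x + bias.
    static final class ToyDense {
        private final int units;
        private double[][] weight; // shape (units, inputSize); unknown until initialize()
        private double[] bias;     // shape (units)

        // Construction: only the hyperparameter `units` is known; no arrays are allocated.
        ToyDense(int units) {
            this.units = units;
        }

        boolean isInitialized() {
            return weight != null;
        }

        // Initialization: the input size is now known, so the parameter shapes can be
        // derived and the weights filled by the initializer. Returns the output shape.
        int[] initialize(DoubleSupplier initializer, int inputSize) {
            weight = new double[units][inputSize];
            bias = new double[units]; // biases start at zero
            for (double[] row : weight) {
                for (int i = 0; i < row.length; i++) {
                    row[i] = initializer.getAsDouble();
                }
            }
            return new int[] {units};
        }

        int[] weightShape() {
            return new int[] {weight.length, weight[0].length};
        }
    }

    public static void main(String[] args) {
        ToyDense dense = new ToyDense(2);          // construction
        System.out.println(dense.isInitialized()); // prints false

        Random rng = new Random(42);
        int[] outShape = dense.initialize(rng::nextGaussian, 3); // initialization
        System.out.println(dense.isInitialized()); // prints true
        System.out.println(Arrays.toString(dense.weightShape()) + " -> " + Arrays.toString(outShape));
        // prints [2, 3] -> [2]
    }
}
```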
Training is when we start feeding the training data as input to the block, get the output, and update the parameters so that the block learns. For more information about training, refer to the javadoc for Trainer. At the end of training, a block represents a fully trained model.
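To show what updating parameters means, here is a sketch of plain gradient descent over a tiny full dataset, applied to a hypothetical two-parameter linear block with hand-derived gradients of the mean squared error. It is a toy under those assumptions and does not reflect how Trainer works internally; the data, learning rate, and step count are arbitrary choices.

```java
// Hypothetical sketch (not DJL's Trainer): gradient descent on y = w * x + b.
public class TrainingSketch {

    // A block with two scalar parameters.
    static final class ToyLinear {
        double w; // starts at 0.0
        double b; // starts at 0.0

        double forward(double x) {
            return w * x + b;
        }
    }

    // One gradient descent step on the mean squared error; returns the loss before the update.
    static double trainStep(ToyLinear block, double[] xs, double[] ys, double learningRate) {
        int n = xs.length;
        double loss = 0.0;
        double gradW = 0.0;
        double gradB = 0.0;
        for (int i = 0; i < n; i++) {
            double error = block.forward(xs[i]) - ys[i];
            loss += error * error / n;
            // d(error^2 / n)/dw = 2 * error * x / n and d(error^2 / n)/db = 2 * error / n
            gradW += 2.0 * error * xs[i] / n;
            gradB += 2.0 * error / n;
        }
        block.w -= learningRate * gradW;
        block.b -= learningRate * gradB;
        return loss;
    }

    public static void main(String[] args) {
        // Training data generated by y = 2x + 1.
        double[] xs = {0.0, 1.0, 2.0, 3.0};
        double[] ys = {1.0, 3.0, 5.0, 7.0};
        ToyLinear block = new ToyLinear();
        for (int epoch = 0; epoch < 500; epoch++) {
            trainStep(block, xs, ys, 0.05);
        }
        // The parameters approach w = 2 and b = 1.
        System.out.printf("w=%.3f b=%.3f%n", block.w, block.b);
    }
}
```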
See this tutorial on creating your first network.
Modifier and Type | Method | Description |
---|---|---|
`void` | `cast(DataType dataType)` | Guaranteed to throw an exception. |
`void` | `clear()` | Closes all the parameters of the block. |
`ai.djl.util.PairList<java.lang.String,Shape>` | `describeInput()` | Returns a PairList of input names and shapes. |
`default NDList` | `forward(ParameterStore parameterStore, NDList inputs)` | Applies the operating function of the block once. |
`NDList` | `forward(ParameterStore parameterStore, NDList inputs, ai.djl.util.PairList<java.lang.String,java.lang.Object> params)` | Applies the operating function of the block once. |
`BlockList` | `getChildren()` | Returns a list of all the children of the block. |
`java.util.List<Parameter>` | `getDirectParameters()` | Returns a list of all the direct parameters of the block. |
`Shape[]` | `getOutputShapes(NDManager manager, Shape[] inputShapes)` | Returns the expected output shapes of the block for the specified input shapes. |
`ParameterList` | `getParameters()` | Returns a list of all the parameters of the block, including the parameters of its children fetched recursively. |
`Shape` | `getParameterShape(java.lang.String name, Shape[] inputShapes)` | Returns the shape of the specified direct parameter of this block given the shapes of the input to the block. |
`Shape[]` | `initialize(NDManager manager, DataType dataType, Shape... inputShapes)` | Initializes the parameters of the block. |
`boolean` | `isInitialized()` | Returns whether the block is initialized. |
`void` | `loadParameters(NDManager manager, java.io.DataInputStream is)` | Loads the parameters from the given input stream. |
`void` | `saveParameters(java.io.DataOutputStream os)` | Writes the parameters of the block to the given output stream. |
`void` | `setInitializer(Initializer initializer)` | Sets an Initializer to the block. |
`void` | `setInitializer(Initializer initializer, java.lang.String paramName)` | Sets an Initializer to the specified direct parameter of the block, overriding the initializer of the parameter, if already set. |
`static void` | `validateLayout(LayoutType[] expectedLayout, LayoutType[] actualLayout)` | Validates that the actual layout matches the expected layout. |
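The difference between getDirectParameters and getParameters is that the latter walks the block hierarchy recursively. A minimal sketch with a hypothetical ToyNode class, in which parameters are represented only by their names (all parameter names are made up, and the depth-first order is an assumption of this sketch, not a documented guarantee of ParameterList):

```java
import java.util.ArrayList;
import java.util.List;

// Hypothetical sketch (not DJL's API): direct vs. recursive parameter listing.
public class ParameterTreeSketch {

    static final class ToyNode {
        final List<String> directParameters;
        final List<ToyNode> children = new ArrayList<>();

        ToyNode(String... directParameters) {
            this.directParameters = List.of(directParameters);
        }

        ToyNode add(ToyNode child) {
            children.add(child);
            return this;
        }

        // Only the parameters owned by this node.
        List<String> getDirectParameters() {
            return directParameters;
        }

        // This node's own parameters first, then each child's, depth-first.
        List<String> getParameters() {
            List<String> all = new ArrayList<>(directParameters);
            for (ToyNode child : children) {
                all.addAll(child.getParameters());
            }
            return all;
        }
    }

    public static void main(String[] args) {
        ToyNode net = new ToyNode()
                .add(new ToyNode("conv_weight", "conv_bias"))
                .add(new ToyNode().add(new ToyNode("dense_weight", "dense_bias")));
        System.out.println(net.getDirectParameters()); // prints []
        System.out.println(net.getParameters());
        // prints [conv_weight, conv_bias, dense_weight, dense_bias]
    }
}
```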
default NDList forward(ParameterStore parameterStore, NDList inputs)
Applies the operating function of the block once.
Parameters:
- parameterStore: the parameter store
- inputs: the input NDList

NDList forward(ParameterStore parameterStore, NDList inputs, ai.djl.util.PairList<java.lang.String,java.lang.Object> params)
Applies the operating function of the block once.
Parameters:
- parameterStore: the parameter store
- inputs: the input NDList
- params: optional parameters

void setInitializer(Initializer initializer)
Sets an Initializer to the block.
Parameters:
- initializer: the initializer to set

void setInitializer(Initializer initializer, java.lang.String paramName)
Sets an Initializer to the specified direct parameter of the block, overriding the initializer of the parameter, if already set.
Parameters:
- initializer: the initializer to be set
- paramName: the name of the parameter

Shape[] initialize(NDManager manager, DataType dataType, Shape... inputShapes)
Initializes the parameters of the block.
Parameters:
- manager: the NDManager to initialize the parameters
- dataType: the data type of the parameters
- inputShapes: the shapes of the inputs to the block

boolean isInitialized()
Returns whether the block is initialized.

void cast(DataType dataType)
Guaranteed to throw an exception.
Parameters:
- dataType: the data type to cast to
Throws:
- java.lang.UnsupportedOperationException: always

void clear()
Closes all the parameters of the block.

ai.djl.util.PairList<java.lang.String,Shape> describeInput()
Returns a PairList of input names and shapes.
Returns:
- a PairList of input names and shapes

BlockList getChildren()
Returns a list of all the children of the block.

java.util.List<Parameter> getDirectParameters()
Returns a list of all the direct parameters of the block.

ParameterList getParameters()
Returns a list of all the parameters of the block, including the parameters of its children fetched recursively.

Shape getParameterShape(java.lang.String name, Shape[] inputShapes)
Returns the shape of the specified direct parameter of this block given the shapes of the input to the block.
Parameters:
- name: the name of the parameter
- inputShapes: the shapes of the input to the block
Throws:
- java.lang.IllegalArgumentException: if the parameter name specified is invalid

Shape[] getOutputShapes(NDManager manager, Shape[] inputShapes)
Returns the expected output shapes of the block for the specified input shapes.
Parameters:
- manager: an NDManager
- inputShapes: the shapes of the inputs

void saveParameters(java.io.DataOutputStream os) throws java.io.IOException
Writes the parameters of the block to the given output stream.
Parameters:
- os: the output stream to save the parameters to
Throws:
- java.io.IOException: if an I/O error occurs

void loadParameters(NDManager manager, java.io.DataInputStream is) throws java.io.IOException, MalformedModelException
Loads the parameters from the given input stream.
Parameters:
- manager: an NDManager to create the parameter arrays
- is: the input stream that streams the parameter values
Throws:
- java.io.IOException: if an I/O error occurs
- MalformedModelException: if the model file is corrupted or unsupported

static void validateLayout(LayoutType[] expectedLayout, LayoutType[] actualLayout)
Validates that the actual layout matches the expected layout.
Parameters:
- expectedLayout: the expected layout
- actualLayout: the actual layout
Throws:
- java.lang.UnsupportedOperationException: if the actual layout does not match the expected layout
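saveParameters and loadParameters pair a java.io.DataOutputStream with a java.io.DataInputStream. The sketch below shows a round trip through those two streams using a hypothetical layout (a version number, a parameter count, then a name, a length, and the values for each parameter); it is not DJL's actual parameter file format. Where the real loadParameters reports a corrupted or unsupported file with MalformedModelException, this sketch simply throws an IOException.

```java
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.io.IOException;
import java.util.LinkedHashMap;
import java.util.Map;

// Hypothetical sketch (not DJL's file format): saving and loading named parameter arrays.
public class SerializationSketch {

    private static final int FORMAT_VERSION = 1; // hypothetical version tag

    // Layout: version, parameter count, then (name, length, values...) per parameter.
    static void saveParameters(Map<String, double[]> params, DataOutputStream os) throws IOException {
        os.writeInt(FORMAT_VERSION);
        os.writeInt(params.size());
        for (Map.Entry<String, double[]> entry : params.entrySet()) {
            os.writeUTF(entry.getKey());
            os.writeInt(entry.getValue().length);
            for (double v : entry.getValue()) {
                os.writeDouble(v);
            }
        }
    }

    // Reads the same layout back and rejects streams with an unknown version.
    static Map<String, double[]> loadParameters(DataInputStream is) throws IOException {
        int version = is.readInt();
        if (version != FORMAT_VERSION) {
            throw new IOException("Unsupported parameter format version: " + version);
        }
        int count = is.readInt();
        Map<String, double[]> params = new LinkedHashMap<>();
        for (int p = 0; p < count; p++) {
            String name = is.readUTF();
            double[] values = new double[is.readInt()];
            for (int i = 0; i < values.length; i++) {
                values[i] = is.readDouble();
            }
            params.put(name, values);
        }
        return params;
    }

    public static void main(String[] args) throws IOException {
        Map<String, double[]> params = new LinkedHashMap<>();
        params.put("weight", new double[] {0.5, -1.0});
        params.put("bias", new double[] {2.0});

        ByteArrayOutputStream buffer = new ByteArrayOutputStream();
        try (DataOutputStream os = new DataOutputStream(buffer)) {
            saveParameters(params, os);
        }
        try (DataInputStream is = new DataInputStream(new ByteArrayInputStream(buffer.toByteArray()))) {
            Map<String, double[]> restored = loadParameters(is);
            System.out.println(restored.keySet()); // prints [weight, bias]
        }
    }
}
```

Writing a version tag first lets the loader reject a stream it does not understand before reading any parameter values.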