Package ai.djl.nn

Interface Block

All Known Subinterfaces:
StreamingBlock, SymbolBlock
All Known Implementing Classes:
AbstractBaseBlock, AbstractBlock, AbstractSymbolBlock, BatchNorm, BertBlock, BertMaskedLanguageModelBlock, BertNextSentenceBlock, BertPretrainingBlock, ConstantEmbedding, Conv1d, Conv1dTranspose, Conv2d, Conv2dTranspose, Conv3d, Convolution, Decoder, Deconvolution, Dropout, Embedding, Encoder, EncoderDecoder, GhostBatchNorm, GRU, IdEmbedding, LambdaBlock, LayerNorm, Linear, LinearCollection, LSTM, Multiplication, ParallelBlock, PointwiseFeedForwardBlock, Prelu, RecurrentBlock, RNN, ScaledDotProductAttentionBlock, SequentialBlock, SparseMax, TrainableTextEmbedding, TrainableWordEmbedding, TransformerEncoderBlock

public interface Block
A Block is a composable function that forms a neural network.

Blocks serve a purpose similar to functions that convert an input NDList to an output NDList. They can represent single operations, parts of a neural network, and even the whole neural network. What makes blocks special is that they contain a number of parameters that are used in their function and are trained during deep learning. As these parameters are trained, the functions represented by the blocks get more and more accurate. Each block consists of the following components:

  • Forward function
  • Parameters
  • Child blocks

The core purpose of a Block is to perform an operation on the inputs and return an output. This operation is defined in the forward method. The forward function can be defined explicitly in terms of the block's own parameters, or implicitly as a combination of the functions of its child blocks.

The parameters of a Block are instances of Parameter which are required for the operation in the forward function. For example, in a Conv2d block, the parameters are weight and bias. During training, these parameters are updated to reflect the training data, and that forms the crux of learning.
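
For instance, the following sketch (assuming the DJL Conv2d builder methods setKernelShape and setFilters) lists the direct parameters of a freshly constructed Conv2d block:

    import ai.djl.ndarray.types.Shape;
    import ai.djl.nn.Block;
    import ai.djl.nn.Parameter;
    import ai.djl.nn.convolutional.Conv2d;
    import ai.djl.util.Pair;

    Block conv = Conv2d.builder()
            .setKernelShape(new Shape(3, 3))
            .setFilters(32)
            .build();

    // Prints "weight" and "bias"; the parameter arrays themselves are not created until initialization.
    for (Pair<String, Parameter> pair : conv.getDirectParameters()) {
        System.out.println(pair.getKey());
    }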

When building these block functions, the easiest way is to use composition. Similar to how functions are built by calling other functions, blocks can be built by combining other blocks. We refer to the containing block as the parent and the sub-blocks as the children.

We provide helpers for creating two common structures of blocks. For blocks that call their children in a chain, use SequentialBlock. For blocks that call all of their children in parallel and then combine their results, use ParallelBlock. For blocks that do not fit either structure, you should directly extend the AbstractBlock class.
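
For example, a small multi-layer perceptron can be composed from existing blocks as follows (a minimal sketch; the layer sizes are arbitrary, and Activation.reluBlock() is a DJL helper that returns a parameter-less activation block):

    import ai.djl.nn.Activation;
    import ai.djl.nn.Block;
    import ai.djl.nn.SequentialBlock;
    import ai.djl.nn.core.Linear;

    Block mlp = new SequentialBlock()
            .add(Linear.builder().setUnits(128).build())  // child block with parameters
            .add(Activation.reluBlock())                  // parameter-less child block
            .add(Linear.builder().setUnits(10).build());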

A block does not necessarily have to have both children and parameters. For example, SequentialBlock and ParallelBlock have no parameters of their own, but do have child blocks. Similarly, Conv2d has no children, but does have parameters. There are also special cases where blocks have neither parameters nor children. One such example is LambdaBlock, which takes in a function and applies that function to its input in its forward method.
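
For instance, a block with neither parameters nor children that simply doubles its input can be written with a LambdaBlock (a minimal sketch):

    import ai.djl.ndarray.NDList;
    import ai.djl.nn.Block;
    import ai.djl.nn.LambdaBlock;

    // The given function is applied to the input NDList in the forward method.
    Block doubler = new LambdaBlock(list -> new NDList(list.singletonOrThrow().mul(2)));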

Now that we understand the components of the block, we can explore what the block really represents. A block combined with the recursive, hierarchical structure of its children forms a network. It takes in the input to the network, performs its operation, and returns the output of the network. When a block is added as a child of another block, it becomes a sub-network of that block.

The life-cycle of a block has 3 stages:

  • Construction
  • Initialization
  • Training

Construction is the process of building the network. During this stage, blocks are created with appropriate arguments and the desired network is built by creating a hierarchy of parent and child blocks. At this stage, it is a bare-bones network: the parameter values are not created and the shapes of the inputs are not known. The block is then ready for initialization.

Initialization is the process of initializing all the parameters of the block and its children, according to the inputs expected. It involves setting an Initializer, and deciding the DataType and the shapes of the inputs. The parameter arrays are NDArrays that are initialized according to the Initializer that was set. At this stage, the block expects a specific type of input and is ready to be trained.
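
Put together, initialization of a simple block might look like the following sketch (the Linear block, XavierInitializer, and shapes are arbitrary illustrative choices):

    import ai.djl.ndarray.NDManager;
    import ai.djl.ndarray.types.DataType;
    import ai.djl.ndarray.types.Shape;
    import ai.djl.nn.Block;
    import ai.djl.nn.Parameter;
    import ai.djl.nn.core.Linear;
    import ai.djl.training.initializer.XavierInitializer;

    try (NDManager manager = NDManager.newBaseManager()) {
        Block block = Linear.builder().setUnits(10).build();            // construction
        block.setInitializer(new XavierInitializer(), Parameter.Type.WEIGHT);
        // Creates the parameter arrays for inputs of shape (batch=32, features=20).
        block.initialize(manager, DataType.FLOAT32, new Shape(32, 20));
        // block.isInitialized() now returns true, and forward can be called.
    }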

Training is when we start feeding the training data as input to the block, get the output, and try to update the parameters to learn. For more information about training, please refer to the javadoc of Trainer. At the end of training, a block represents a fully-trained model.

It is also possible to freeze parameters and blocks so that they are not trained. When models are loaded or blocks are built with pre-trained data, their parameters default to being frozen. If you wish to further refine these elements, use freezeParameters(boolean) to unfreeze them.
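
For example, to make a pre-trained block trainable again during fine-tuning (a minimal sketch, assuming block is an already loaded Block):

    import ai.djl.nn.Parameter;

    // Unfreeze every parameter so that it can be updated during training.
    block.freezeParameters(false);

    // Alternatively, freeze everything except the bias parameters.
    block.freezeParameters(true, p -> p.getType() != Parameter.Type.BIAS);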

  • Method Details

    • forward

      default NDList forward(ParameterStore parameterStore, NDList inputs, boolean training)
      Applies the operating function of the block once. This method should be called only on blocks that are initialized.
      Parameters:
      parameterStore - the parameter store
      inputs - the input NDList
      training - true for a training forward pass (enables training-only behavior such as dropout)
      Returns:
      the output of the forward pass
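
      A minimal sketch of an inference call on an initialized block (assuming block and manager have been set up as in the initialization sketch in the class description, and using the DJL ParameterStore class):

        import ai.djl.ndarray.NDList;
        import ai.djl.ndarray.types.Shape;
        import ai.djl.training.ParameterStore;

        ParameterStore ps = new ParameterStore(manager, false);
        NDList input = new NDList(manager.randomNormal(new Shape(32, 20)));
        NDList output = block.forward(ps, input, false);  // training = false for inference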
    • forward

      NDList forward(ParameterStore parameterStore, NDList inputs, boolean training, ai.djl.util.PairList<String,Object> params)
      Applies the operating function of the block once. This method should be called only on blocks that are initialized.
      Parameters:
      parameterStore - the parameter store
      inputs - the input NDList
      training - true for a training forward pass (enables training-only behavior such as dropout)
      params - optional parameters
      Returns:
      the output of the forward pass
    • forward

      default NDList forward(ParameterStore parameterStore, NDList data, NDList labels, ai.djl.util.PairList<String,Object> params)
      A forward call using both training data and labels.

      Within this forward call, it can be assumed that training is true.

      Parameters:
      parameterStore - the parameter store
      data - the input data NDList
      labels - the input labels NDList
      params - optional parameters
      Returns:
      the output of the forward pass
    • setInitializer

      void setInitializer(Initializer initializer, Parameter.Type type)
      Sets an Initializer to all the parameters in the block that match the given parameter type.
      Parameters:
      initializer - the initializer to set
      type - the Parameter.Type for which to set the initializer
    • setInitializer

      void setInitializer(Initializer initializer, String paramName)
      Sets an Initializer to the specified direct parameter of the block, overriding the initializer of the parameter, if already set.
      Parameters:
      initializer - the initializer to be set
      paramName - the name of the parameter
    • setInitializer

      void setInitializer(Initializer initializer, Predicate<Parameter> predicate)
      Sets an Initializer to all the parameters in the block that match the given Predicate.
      Parameters:
      initializer - the initializer to be set
      predicate - the predicate that selects the parameters whose initializer should be set
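
      For example, a minimal sketch that zero-initializes every bias parameter (assuming block is a constructed Block; Initializer.ZEROS and Parameter.getType() are DJL API):

        import ai.djl.nn.Parameter;
        import ai.djl.training.initializer.Initializer;

        // Applies only to parameters whose type is BIAS; other parameters keep their initializer.
        block.setInitializer(Initializer.ZEROS, p -> p.getType() == Parameter.Type.BIAS);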
    • initialize

      void initialize(NDManager manager, DataType dataType, Shape... inputShapes)
      Initializes the parameters of the block, marks them as requiring gradients where needed, and infers the block's input shapes. This method must be called before calling forward.
      Parameters:
      manager - the NDManager to initialize the parameters
      dataType - the datatype of the parameters
      inputShapes - the shapes of the inputs to the block
    • isInitialized

      boolean isInitialized()
      Returns whether the block is initialized (the block has input shapes and its parameters have non-null arrays).
      Returns:
      whether the block is initialized
    • cast

      void cast(DataType dataType)
      Guaranteed to throw an exception; not yet implemented.
      Parameters:
      dataType - the data type to cast to
      Throws:
      UnsupportedOperationException - always
    • clear

      void clear()
      Closes all the parameters of the block. All the updates made during training will be lost.
    • describeInput

      ai.djl.util.PairList<String,Shape> describeInput()
      Returns a PairList of input names and shapes.
      Returns:
      the PairList of input names and shapes
    • getChildren

      BlockList getChildren()
      Returns a list of all the children of the block.
      Returns:
      the list of child blocks
    • getDirectParameters

      ParameterList getDirectParameters()
      Returns a list of all the direct parameters of the block.
      Returns:
      the list of Parameter
    • getParameters

      ParameterList getParameters()
      Returns a list of all the parameters of the block, including the parameters of its children fetched recursively.
      Returns:
      the list of all parameters of the block
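
      For example, a minimal sketch that prints the name and shape of every parameter (assuming block is an initialized Block):

        import ai.djl.nn.Parameter;
        import ai.djl.util.Pair;

        for (Pair<String, Parameter> pair : block.getParameters()) {
            System.out.println(pair.getKey() + " " + pair.getValue().getArray().getShape());
        }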
    • getOutputShapes

      Shape[] getOutputShapes(Shape[] inputShapes)
      Returns the expected output shapes of the block for the specified input shapes.
      Parameters:
      inputShapes - the shapes of the inputs
      Returns:
      the expected output shapes of the block
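
      For example, a minimal sketch that queries the output shape of the Linear block with 10 units from the earlier initialization sketch (some block types need to be initialized before their output shapes can be computed):

        import ai.djl.ndarray.types.Shape;

        Shape[] outputs = block.getOutputShapes(new Shape[] {new Shape(32, 20)});
        // For a Linear block with 10 units this reports (32, 10).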
    • getOutputShapes

      default Shape[] getOutputShapes(Shape[] inputShapes, DataType[] inputDataTypes)
      Returns the expected output shapes of the block for the specified input shapes.
      Parameters:
      inputShapes - the shapes of the inputs
      inputDataTypes - the datatypes of the inputs
      Returns:
      the expected output shapes of the block
    • getInputShapes

      Shape[] getInputShapes()
      Returns the input shapes of the block. The input shapes are only available after the block is initialized, otherwise an IllegalStateException is thrown.
      Returns:
      the input shapes of the block
    • getOutputDataTypes

      DataType[] getOutputDataTypes()
      Returns the output dataTypes of the block.
      Returns:
      the output dataTypes of the block
    • saveParameters

      void saveParameters(DataOutputStream os) throws IOException
      Writes the parameters of the block to the given outputStream.
      Parameters:
      os - the output stream to save the parameters to
      Throws:
      IOException - if an I/O error occurs
    • loadParameters

      void loadParameters(NDManager manager, DataInputStream is) throws IOException, MalformedModelException
      Loads the parameters from the given input stream.
      Parameters:
      manager - an NDManager to create the parameter arrays
      is - the input stream that streams the parameter values
      Throws:
      IOException - if an I/O error occurs
      MalformedModelException - if the model file is corrupted or unsupported
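
      For example, a minimal sketch that round-trips a block's parameters through a file (block and manager are assumed to be set up as in the earlier initialization sketch; the enclosing method is assumed to declare the checked IOException and MalformedModelException; the file name is arbitrary):

        import java.io.DataInputStream;
        import java.io.DataOutputStream;
        import java.nio.file.Files;
        import java.nio.file.Path;
        import java.nio.file.Paths;

        Path file = Paths.get("block.param");
        try (DataOutputStream os = new DataOutputStream(Files.newOutputStream(file))) {
            block.saveParameters(os);
        }
        try (DataInputStream is = new DataInputStream(Files.newInputStream(file))) {
            block.loadParameters(manager, is);
        }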
    • freezeParameters

      default void freezeParameters(boolean freeze)
      Freezes or unfreezes all parameters inside the block for training.
      Parameters:
      freeze - true if the parameters should be frozen
    • freezeParameters

      default void freezeParameters(boolean freeze, Predicate<Parameter> pred)
      Freezes or unfreezes all parameters inside the block that pass the predicate.
      Parameters:
      freeze - true to mark as frozen rather than unfrozen
      pred - the predicate selecting which parameters to freeze or unfreeze
    • validateLayout

      static void validateLayout(LayoutType[] expectedLayout, LayoutType[] actualLayout)
      Validates that the actual layout matches the expected layout.
      Parameters:
      expectedLayout - the expected layout
      actualLayout - the actual Layout
      Throws:
      UnsupportedOperationException - if the actual layout does not match the expected layout
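
      For example, a minimal sketch that validates an NCHW layout (the LayoutType constants are part of ai.djl.ndarray.types):

        import ai.djl.ndarray.types.LayoutType;
        import ai.djl.nn.Block;

        LayoutType[] nchw = {LayoutType.BATCH, LayoutType.CHANNEL, LayoutType.HEIGHT, LayoutType.WIDTH};
        // Passes silently; a mismatching actual layout would throw UnsupportedOperationException.
        Block.validateLayout(nchw, nchw);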