Package ai.djl.nn

Interface Block

  • All Known Subinterfaces:
    SymbolBlock
    All Known Implementing Classes:
    AbstractBaseBlock, AbstractBlock, AbstractSymbolBlock, BatchNorm, BertBlock, BertMaskedLanguageModelBlock, BertNextSentenceBlock, BertPretrainingBlock, ConstantEmbedding, Conv1d, Conv1dTranspose, Conv2d, Conv2dTranspose, Conv3d, Convolution, Decoder, Deconvolution, Dropout, Embedding, Encoder, EncoderDecoder, GhostBatchNorm, GRU, IdEmbedding, LambdaBlock, LayerNorm, Linear, LinearCollection, LSTM, ParallelBlock, PointwiseFeedForwardBlock, Prelu, RecurrentBlock, RNN, ScaledDotProductAttentionBlock, SequentialBlock, TrainableTextEmbedding, TrainableWordEmbedding, TransformerEncoderBlock

    public interface Block
    A Block is a composable function that forms a neural network.

    Blocks serve a purpose similar to functions that convert an input NDList to an output NDList. They can represent single operations, parts of a neural network, and even the whole neural network. What makes blocks special is that they contain a number of parameters that are used in their function and are trained during deep learning. As these parameters are trained, the functions represented by the blocks get more and more accurate. Each block consists of the following components:

    • Forward function
    • Parameters
    • Child blocks

    The core purpose of a Block is to perform an operation on the inputs and return an output. This operation is defined in the forward method. The forward function can be defined explicitly in terms of the block's parameters, or implicitly as a combination of the functions of its child blocks.

    The parameters of a Block are instances of Parameter which are required for the operation in the forward function. For example, in a Conv2d block, the parameters are weight and bias. During training, these parameters are updated to reflect the training data, and that forms the crux of learning.

    When building these block functions, the easiest way is to use composition. Similar to how functions are built by calling other functions, blocks can be built by combining other blocks. We refer to the containing block as the parent and the sub-blocks as the children.

    We provide helpers for creating two common structures of blocks. For blocks that call children in a chain, use SequentialBlock. If a block calls all of the children in parallel and then combines their results, use ParallelBlock. For blocks that do not fit these structures, you should directly extend the AbstractBlock class.

    A block does not necessarily have both children and parameters. For example, SequentialBlock and ParallelBlock don't have any parameters, but do have child blocks. Similarly, Conv2d does not have children, but has parameters. There can be special cases where blocks have neither parameters nor children. One such example is LambdaBlock. LambdaBlock takes in a function and applies that function to its input in the forward method.
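    The composition patterns described above can be sketched with plain Java functions. This is a toy model only: ToyBlock, ToyBlocks, and the helper names are hypothetical, they operate on double[] rather than NDList, and they are not part of DJL. The sketch also assumes that every child of the parallel helper returns an output of the same length as its input.

```java
import java.util.Arrays;
import java.util.List;
import java.util.function.UnaryOperator;

// Toy stand-in for a block: a function from double[] to double[].
// NOT DJL's Block interface; it only mirrors the composition idea.
interface ToyBlock {
    double[] forward(double[] input);
}

final class ToyBlocks {
    private ToyBlocks() {}

    // Like a lambda block: no parameters, no children, just a wrapped function.
    static ToyBlock lambda(UnaryOperator<double[]> fn) {
        return fn::apply;
    }

    // Like a sequential block: children are called in a chain,
    // each one receiving the output of the previous one.
    static ToyBlock sequential(List<ToyBlock> children) {
        return input -> {
            double[] current = input;
            for (ToyBlock child : children) {
                current = child.forward(current);
            }
            return current;
        };
    }

    // Like a parallel block: every child sees the same input and the
    // per-child outputs are combined (here, by element-wise sum).
    static ToyBlock parallelSum(List<ToyBlock> children) {
        return input -> {
            double[] sum = new double[input.length];
            for (ToyBlock child : children) {
                double[] out = child.forward(input);
                for (int i = 0; i < sum.length; i++) {
                    sum[i] += out[i];
                }
            }
            return sum;
        };
    }

    // Convenience leaf block that multiplies every element by a constant.
    static ToyBlock scale(double factor) {
        return lambda(x -> Arrays.stream(x).map(v -> v * factor).toArray());
    }
}
```

    Chaining scale(2.0) and scale(3.0) with sequential multiplies the input by 6, while parallelSum over the same two children multiplies it by 5. The parent only decides how the children's outputs are wired together.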

    Now that we understand the components of the block, we can explore what the block really represents. A block combined with the recursive, hierarchical structure of its children forms a network. It takes in the input to the network, performs its operation, and returns the output of the network. When a block is added as a child of another block, it becomes a sub-network of that block.

    The life-cycle of a block has 3 stages:

    • Construction
    • Initialization
    • Training

    Construction is the process of building the network. During this stage, blocks are created with appropriate arguments and the desired network is built by creating a hierarchy of parent and child blocks. At this stage, it is a bare-bones network. The parameter values are not created and the shapes of the inputs are not known. The block is ready for initialization.

    Initialization is the process of initializing all the parameters of the block and its children, according to the inputs expected. It involves setting an Initializer and deciding the DataType and the shapes of the input. The parameter arrays are NDArrays that are initialized according to the Initializer that was set. At this stage, the block is expecting a specific type of input, and is ready to be trained.

    Training is when we start feeding the training data as input to the block, get the output, and try to update the parameters to learn. For more information about training, please refer to the javadoc for Trainer. At the end of training, a block represents a fully-trained model.
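    The three life-cycle stages can be illustrated with a toy block that computes y = w * x. This is a minimal sketch, not DJL code: ToyScaleBlock and its methods are hypothetical, the parameter is a plain double[] instead of an NDArray, and throwing IllegalStateException from forward before initialization is a choice made for this sketch.

```java
// Hypothetical toy block computing y = w * x; not part of DJL.
final class ToyScaleBlock {
    private double[] weight; // stays null until initialize() runs

    // Construction: the block exists, but no parameter array is allocated yet.
    ToyScaleBlock() {}

    boolean isInitialized() {
        return weight != null;
    }

    // Initialization: allocate the parameter and fill it with an initial value.
    void initialize(double initialWeight) {
        weight = new double[] {initialWeight};
    }

    double forward(double x) {
        if (!isInitialized()) {
            throw new IllegalStateException("initialize() must be called before forward()");
        }
        return weight[0] * x;
    }

    // Training: one gradient-descent step on the loss (w * x - y)^2 / 2.
    void trainStep(double x, double y, double learningRate) {
        double error = forward(x) - y;
        weight[0] -= learningRate * error * x; // the gradient of the loss w.r.t. w is error * x
    }

    double weight() {
        return weight[0];
    }
}
```

    Repeating trainStep(1.0, 2.0, 0.5) after initialize(0.0) halves the error on every step, so the weight approaches 2.0. The same order applies to real blocks: construct, initialize, then train.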

    See Also:
    this tutorial on creating your first network, The D2L chapter on blocks and blocks with direct parameters
    • Method Detail

      • forward

        default NDList forward​(ParameterStore parameterStore,
                               NDList inputs,
                               boolean training)
        Applies the operating function of the block once. This method should be called only on blocks that are initialized.
        Parameters:
        parameterStore - the parameter store
        inputs - the input NDList
        training - true for a training forward pass
        Returns:
        the output of the forward pass
      • forward

        NDList forward​(ParameterStore parameterStore,
                       NDList inputs,
                       boolean training,
                       ai.djl.util.PairList<java.lang.String,​java.lang.Object> params)
        Applies the operating function of the block once. This method should be called only on blocks that are initialized.
        Parameters:
        parameterStore - the parameter store
        inputs - the input NDList
        training - true for a training forward pass
        params - optional parameters
        Returns:
        the output of the forward pass
      • forward

        default NDList forward​(ParameterStore parameterStore,
                               NDList data,
                               NDList labels,
                               ai.djl.util.PairList<java.lang.String,​java.lang.Object> params)
        A forward call using both training data and labels.

        Within this forward call, it can be assumed that training is true.

        Parameters:
        parameterStore - the parameter store
        data - the input data NDList
        labels - the input labels NDList
        params - optional parameters
        Returns:
        the output of the forward pass
        See Also:
        forward(ParameterStore, NDList, boolean, PairList)
      • setInitializer

        void setInitializer​(Initializer initializer,
                            Parameter.Type type)
        Sets an Initializer on all the parameters in the block that match the given parameter type.
        Parameters:
        initializer - the initializer to set
        type - the Parameter.Type of the parameters to set the initializer on
      • setInitializer

        void setInitializer​(Initializer initializer,
                            java.lang.String paramName)
        Sets an Initializer to the specified direct parameter of the block, overriding the initializer of the parameter, if already set.
        Parameters:
        initializer - the initializer to be set
        paramName - the name of the parameter
      • setInitializer

        void setInitializer​(Initializer initializer,
                            java.util.function.Predicate<Parameter> predicate)
        Sets an Initializer on all the parameters in the block that match the given Predicate.
        Parameters:
        initializer - the initializer to be set
        predicate - a predicate that selects the parameters to set the initializer on
      • initialize

        void initialize​(NDManager manager,
                        DataType dataType,
                        Shape... inputShapes)
        Initializes the parameters of the block. This method must be called before calling forward.
        Parameters:
        manager - the NDManager to initialize the parameters
        dataType - the datatype of the parameters
        inputShapes - the shapes of the inputs to the block
      • isInitialized

        boolean isInitialized()
        Returns whether the block is initialized.
        Returns:
        whether the block is initialized
      • cast

        void cast​(DataType dataType)
        Guaranteed to throw an exception. Not yet implemented.
        Parameters:
        dataType - the data type to cast to
        Throws:
        java.lang.UnsupportedOperationException - always
      • clear

        void clear()
        Closes all the parameters of the block. All the updates made during training will be lost.
      • describeInput

        ai.djl.util.PairList<java.lang.String,​Shape> describeInput()
        Returns a PairList of input names and shapes.
        Returns:
        the PairList of input names and shapes
      • getChildren

        BlockList getChildren()
        Returns a list of all the children of the block.
        Returns:
        the list of child blocks
      • getDirectParameters

        ParameterList getDirectParameters()
        Returns a list of all the direct parameters of the block.
        Returns:
        the list of Parameter
      • getParameters

        ParameterList getParameters()
        Returns a list of all the parameters of the block, including the parameters of its children fetched recursively.
        Returns:
        the list of all parameters of the block
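        The difference between direct and recursive parameter lookup can be sketched with a small tree. ToyNode is a hypothetical class that stores parameter names as strings. It is not a DJL type, and the order in which it returns the collected parameters is an assumption of this sketch.

```java
import java.util.ArrayList;
import java.util.List;

// Hypothetical toy tree node; the names are illustrative, not DJL classes.
final class ToyNode {
    private final List<String> directParameters = new ArrayList<>();
    private final List<ToyNode> children = new ArrayList<>();

    ToyNode addParameter(String name) {
        directParameters.add(name);
        return this;
    }

    ToyNode addChild(ToyNode child) {
        children.add(child);
        return this;
    }

    // Only the parameters owned by this node itself.
    List<String> getDirectParameters() {
        return new ArrayList<>(directParameters);
    }

    // This node's parameters plus those of all descendants, gathered recursively.
    List<String> getParameters() {
        List<String> all = new ArrayList<>(directParameters);
        for (ToyNode child : children) {
            all.addAll(child.getParameters());
        }
        return all;
    }
}
```

        A container node with no parameters of its own returns an empty list from getDirectParameters, yet getParameters still reports every parameter held by its children.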
      • getOutputShapes

        Shape[] getOutputShapes​(Shape[] inputShapes)
        Returns the expected output shapes of the block for the specified input shapes.
        Parameters:
        inputShapes - the shapes of the inputs
        Returns:
        the expected output shapes of the block
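        As an illustration of shape inference, consider a fully connected layer that maps the last input dimension to a fixed number of units. The sketch below is hypothetical: ToyLinearShape is not a DJL class, shapes are plain long[] arrays rather than Shape objects, and the rule shown is only the usual dense-layer convention.

```java
// Hypothetical helper; shapes are plain long[] arrays, not DJL Shape objects.
final class ToyLinearShape {
    private ToyLinearShape() {}

    // A dense layer keeps all leading dimensions (such as the batch size)
    // and replaces the last dimension with the number of output units.
    static long[] outputShape(long[] inputShape, long units) {
        if (inputShape.length < 1) {
            throw new IllegalArgumentException("input shape must have at least one dimension");
        }
        long[] out = inputShape.clone();
        out[out.length - 1] = units;
        return out;
    }
}
```

        For an input of shape (32, 784) and 10 units, this returns (32, 10) without allocating any parameter or data arrays, so shapes can be checked before any data is fed through the network.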
      • getInputShapes

        Shape[] getInputShapes()
        Returns the input shapes of the block. The input shapes are only available after the block is initialized; otherwise, an IllegalStateException is thrown.
        Returns:
        the input shapes of the block
      • saveParameters

        void saveParameters​(java.io.DataOutputStream os)
                     throws java.io.IOException
        Writes the parameters of the block to the given output stream.
        Parameters:
        os - the output stream to save the parameters to
        Throws:
        java.io.IOException - if an I/O error occurs
      • loadParameters

        void loadParameters​(NDManager manager,
                            java.io.DataInputStream is)
                     throws java.io.IOException,
                            MalformedModelException
        Loads the parameters from the given input stream.
        Parameters:
        manager - an NDManager to create the parameter arrays
        is - the input stream that streams the parameter values
        Throws:
        java.io.IOException - if an I/O error occurs
        MalformedModelException - if the model file is corrupted or unsupported
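        The save and load pair can be pictured as a round trip through java.io.DataOutputStream and java.io.DataInputStream. The sketch below uses a made-up format (a count followed by float values) purely for illustration. ToyParameterIo is hypothetical, and this is not the format DJL writes.

```java
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.io.IOException;
import java.io.UncheckedIOException;

// Hypothetical toy format: an int count followed by that many floats.
// This is NOT the parameter format used by DJL.
final class ToyParameterIo {
    private ToyParameterIo() {}

    static void save(DataOutputStream os, float[] values) throws IOException {
        os.writeInt(values.length);
        for (float v : values) {
            os.writeFloat(v);
        }
    }

    static float[] load(DataInputStream is) throws IOException {
        int count = is.readInt();
        if (count < 0) {
            // Plays the role of a "malformed model" check in this toy format.
            throw new IOException("negative parameter count: " + count);
        }
        float[] values = new float[count];
        for (int i = 0; i < count; i++) {
            values[i] = is.readFloat();
        }
        return values;
    }

    // Saves to an in-memory buffer and loads the values back.
    static float[] roundTrip(float[] values) {
        try {
            ByteArrayOutputStream buffer = new ByteArrayOutputStream();
            try (DataOutputStream os = new DataOutputStream(buffer)) {
                save(os, values);
            }
            try (DataInputStream is =
                    new DataInputStream(new ByteArrayInputStream(buffer.toByteArray()))) {
                return load(is);
            }
        } catch (IOException e) {
            throw new UncheckedIOException(e);
        }
    }
}
```

        Loading must read the values in exactly the order and with exactly the types in which they were written, which is why a corrupted or mismatched stream has to be reported as an error rather than silently accepted.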
      • freezeParameters

        default void freezeParameters​(boolean freeze)
        Freezes or unfreezes all parameters inside the block for training.
        Parameters:
        freeze - true if the parameters should be frozen, false if they should be unfrozen
        See Also:
        Parameter.freeze(boolean)
      • validateLayout

        static void validateLayout​(LayoutType[] expectedLayout,
                                   LayoutType[] actualLayout)
        Validates that the actual layout matches the expected layout.
        Parameters:
        expectedLayout - the expected layout
        actualLayout - the actual layout
        Throws:
        java.lang.UnsupportedOperationException - if the actual layout does not match the expected layout
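        A layout check of this kind can be sketched as a comparison of two arrays. The ToyLayout enum and ToyLayoutCheck class below are hypothetical stand-ins for LayoutType and this method. The sketch requires strict element-by-element equality, which is a simplifying assumption, so the real method may apply different matching rules.

```java
import java.util.Arrays;

// Hypothetical stand-in for a layout type enum; not DJL's LayoutType.
enum ToyLayout { BATCH, CHANNEL, HEIGHT, WIDTH }

final class ToyLayoutCheck {
    private ToyLayoutCheck() {}

    // Assumption of this sketch: layouts match only if they have the same
    // length and the same entry at every position.
    static void validate(ToyLayout[] expected, ToyLayout[] actual) {
        if (!Arrays.equals(expected, actual)) {
            throw new UnsupportedOperationException(
                    "Expected layout " + Arrays.toString(expected)
                            + " but got " + Arrays.toString(actual));
        }
    }
}
```

        Failing fast with a message that names both layouts makes a channel-first versus channel-last mix-up visible at the call site instead of surfacing later as a shape error.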