Class BertMaskedLanguageModelBlock

java.lang.Object
ai.djl.nn.AbstractBaseBlock
ai.djl.nn.AbstractBlock
ai.djl.nn.transformer.BertMaskedLanguageModelBlock
All Implemented Interfaces:
Block

public class BertMaskedLanguageModelBlock extends AbstractBlock
Block for the BERT masked language model task.
  • Constructor Details

    • BertMaskedLanguageModelBlock

      public BertMaskedLanguageModelBlock(BertBlock bertBlock, Function<NDArray,NDArray> hiddenActivation)
      Creates a new block that applies the masked language model task.
      Parameters:
      bertBlock - the BERT block to create the task for
      hiddenActivation - the activation to use for the hidden layer
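The hiddenActivation parameter accepts any Function<NDArray,NDArray>. BERT conventionally uses GELU for this hidden layer. The plain-Java sketch below does not use DJL. It assumes double[] as a stand-in for NDArray, and shows an elementwise function of the same Function shape using the common tanh approximation of GELU.

```java
import java.util.Arrays;
import java.util.function.Function;

class GeluSketch {
    // double[] stands in for NDArray here; this is NOT DJL's Activation class.
    static final Function<double[], double[]> GELU = x -> {
        double c = Math.sqrt(2.0 / Math.PI);
        double[] out = new double[x.length];
        for (int i = 0; i < x.length; i++) {
            double v = x[i];
            // tanh approximation of GELU, applied elementwise
            out[i] = 0.5 * v * (1.0 + Math.tanh(c * (v + 0.044715 * v * v * v)));
        }
        return out;
    };

    public static void main(String[] args) {
        System.out.println(Arrays.toString(GELU.apply(new double[] {-1.0, 0.0, 1.0})));
    }
}
```

With DJL itself you would pass an NDArray function instead, for example a method reference such as Activation::gelu if your DJL version's ai.djl.nn.Activation class provides it.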
  • Method Details

    • gatherFromIndices

      public static NDArray gatherFromIndices(NDArray sequences, NDArray indices)
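      Given a 3D array of shape (B, S, E) and a 2D array of shape (B, I), returns the flattened lookup result of shape (B * I, E), where B is the batch size, S the sequence length, E the embedding size and I the number of indices per sequence.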
      Parameters:
      sequences - the sequences of embeddings, of shape (B, S, E)
      indices - the indices into the sequences, of shape (B, I). The indices are relative to each sequence, i.e. [[0, 1], [0, 1]] returns the first two elements of each of two sequences.
      Returns:
      the flattened result of gathering elements from the sequences
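The key step is turning per-sequence indices into absolute positions: sequence b starts at offset b * S once the batch is flattened. The plain-Java sketch below restates that arithmetic with nested arrays. It is an illustration under that assumption, not DJL's NDArray-based implementation, and it returns B * I rows of E values each (B * I * E values in total), ordered by batch, then by index position.

```java
import java.util.Arrays;

class GatherSketch {
    // sequences: [B][S][E]; indices: [B][I], relative to each sequence.
    // Returns [B * I][E].
    static double[][] gatherFromIndices(double[][][] sequences, int[][] indices) {
        int batchSize = sequences.length;
        int seqLen = sequences[0].length;
        int width = sequences[0][0].length;
        int perSeq = indices[0].length;

        // Flatten (B, S, E) into (B * S, E) so every token has one absolute row.
        double[][] flat = new double[batchSize * seqLen][];
        for (int b = 0; b < batchSize; b++) {
            for (int s = 0; s < seqLen; s++) {
                flat[b * seqLen + s] = sequences[b][s];
            }
        }

        double[][] out = new double[batchSize * perSeq][];
        for (int b = 0; b < batchSize; b++) {
            for (int i = 0; i < perSeq; i++) {
                // Relative index plus the start offset of sequence b.
                int absolute = b * seqLen + indices[b][i];
                out[b * perSeq + i] = Arrays.copyOf(flat[absolute], width);
            }
        }
        return out;
    }

    public static void main(String[] args) {
        double[][][] seqs = {{{0, 1}, {2, 3}, {4, 5}}, {{6, 7}, {8, 9}, {10, 11}}};
        int[][] idx = {{0, 1}, {0, 2}};
        // prints [[0.0, 1.0], [2.0, 3.0], [6.0, 7.0], [10.0, 11.0]]
        System.out.println(Arrays.deepToString(gatherFromIndices(seqs, idx)));
    }
}
```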
    • initializeChildBlocks

      public void initializeChildBlocks(NDManager manager, DataType dataType, Shape... inputShapes)
      Initializes the child blocks of this block. You need to override this method if your subclass has child blocks. It is used to determine the correct input shapes for the child blocks based on the requested input shapes for this block.
      Overrides:
      initializeChildBlocks in class AbstractBaseBlock
      Parameters:
      manager - the manager to use for initialization
      dataType - the requested data type
      inputShapes - the expected input shapes for this block
    • forwardInternal

      protected NDList forwardInternal(ParameterStore ps, NDList inputs, boolean training, ai.djl.util.PairList<String,Object> params)
      Specified by:
      forwardInternal in class AbstractBaseBlock
      Parameters:
      ps - the parameter store
      inputs - the input NDList
      training - true for a training forward pass
      params - optional parameters
      Returns:
      the output of the forward pass
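This page does not say what the forward pass computes. A masked language model head conventionally ends by converting, for each masked position, a row of vocabulary logits into log-probabilities with a log-softmax. Assuming (not confirmed here) that this block does the same, that final step can be sketched in plain Java, with double[][] standing in for NDArray. The class and method names are hypothetical.

```java
import java.util.Arrays;

class LogSoftmaxSketch {
    // Row-wise log-softmax: out[r][c] = x[r][c] - log(sum_k exp(x[r][k])).
    // Subtracting the row maximum first keeps exp() from overflowing.
    static double[][] logSoftmaxRows(double[][] logits) {
        double[][] out = new double[logits.length][];
        for (int r = 0; r < logits.length; r++) {
            double[] row = logits[r];
            double max = Double.NEGATIVE_INFINITY;
            for (double v : row) {
                max = Math.max(max, v);
            }
            double sum = 0.0;
            for (double v : row) {
                sum += Math.exp(v - max);
            }
            double logSum = max + Math.log(sum);
            out[r] = new double[row.length];
            for (int c = 0; c < row.length; c++) {
                out[r][c] = row[c] - logSum;
            }
        }
        return out;
    }

    public static void main(String[] args) {
        double[][] logits = {{0.0, 0.0}, {1.0, 2.0, 3.0}};
        System.out.println(Arrays.deepToString(logSoftmaxRows(logits)));
    }
}
```

Each output row exponentiates and sums to 1, so it can be read as log-probabilities over the vocabulary for one masked position.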
    • getOutputShapes

      public Shape[] getOutputShapes(Shape[] inputShapes)
      Returns the expected output shapes of the block for the specified input shapes.
      Parameters:
      inputShapes - the shapes of the inputs
      Returns:
      the expected output shapes of the block
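The input layout is not listed on this page. A plausible rule, assuming three inputs (the sequence output (B, S, E), the masked indices (B, I) and an embedding table of shape (D, E) for a dictionary of D tokens), is one output of shape (B * I, D): a row of vocabulary scores per masked position. The plain-Java sketch below encodes that assumed rule with long[] standing in for Shape; the numbers in main are illustrative BERT-base-like sizes.

```java
import java.util.Arrays;

class OutputShapeSketch {
    // Assumed inputs: [0] = (B, S, E), [1] = (B, I), [2] = (D, E).
    // Returns a single shape (B * I, D).
    static long[][] outputShapes(long[][] inputShapes) {
        long batchSize = inputShapes[0][0];
        long indexCount = inputShapes[1][1];
        long dictionarySize = inputShapes[2][0];
        return new long[][] {{batchSize * indexCount, dictionarySize}};
    }

    public static void main(String[] args) {
        long[][] inputs = {{8, 128, 768}, {8, 20}, {30522, 768}};
        // prints [[160, 30522]]
        System.out.println(Arrays.deepToString(outputShapes(inputs)));
    }
}
```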