Class BertBlock

All Implemented Interfaces:
Block

public final class BertBlock extends AbstractBlock
Implements the core BERT model (without the next-sentence-prediction and masked-language-model tasks).

This closely follows the original paper by Devlin et al. and its reference implementation.

  • Method Details

    • getTokenEmbedding

      public IdEmbedding getTokenEmbedding()
      Returns the token embedding used by this BERT model.
      Returns:
      the token embedding used by this BERT model
    • getEmbeddingSize

      public int getEmbeddingSize()
      Returns the embedding size used for tokens.
      Returns:
      the embedding size used for tokens
    • getTokenDictionarySize

      public int getTokenDictionarySize()
      Returns the size of the token dictionary.
      Returns:
      the size of the token dictionary
    • getTypeDictionarySize

      public int getTypeDictionarySize()
      Returns the size of the type dictionary.
      Returns:
      the size of the type dictionary
    • getOutputShapes

      public Shape[] getOutputShapes(Shape[] inputShapes)
      Returns the expected output shapes of the block for the specified input shapes.
      Parameters:
      inputShapes - the shapes of the inputs
      Returns:
      the expected output shapes of the block
    • initializeChildBlocks

      public void initializeChildBlocks(NDManager manager, DataType dataType, Shape... inputShapes)
      Initializes the child blocks of this block. Override this method if your subclass has child blocks. It is used to determine the correct input shapes for the child blocks based on the requested input shapes for this block.
      Overrides:
      initializeChildBlocks in class AbstractBaseBlock
      Parameters:
      manager - the manager to use for initialization
      dataType - the requested data type
      inputShapes - the expected input shapes for this block
    • createAttentionMaskFromInputMask

      public static NDArray createAttentionMaskFromInputMask(NDArray ids, NDArray mask)
      Creates a 3D attention mask from a 2D tensor mask.
      Parameters:
      ids - 2D tensor of shape (B, F), where B is the batch size and F is the from-sequence length
      mask - 2D tensor of shape (B, T), where T is the to-sequence length
      Returns:
      float tensor of shape (B, F, T)
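
The shape contract above can be illustrated with a minimal sketch in plain Java arrays. This is not the DJL implementation: the class name AttentionMaskSketch and the method createAttentionMask are hypothetical, and the sketch assumes (as in the BERT reference implementation) that the contents of ids are ignored and only its shape is used, so that every from-position receives a copy of the (B, T) mask row.

```java
import java.util.Arrays;

/** Hypothetical illustration; uses plain arrays instead of NDArray. */
final class AttentionMaskSketch {

    private AttentionMaskSketch() {}

    /**
     * Expands a (B, T) input mask into a (B, F, T) float attention mask.
     * Assumption: only the shape of ids matters, not its values.
     */
    static float[][][] createAttentionMask(int[][] ids, int[][] mask) {
        if (ids.length != mask.length) {
            throw new IllegalArgumentException("ids and mask must share the batch size");
        }
        int batchSize = ids.length;
        float[][][] attention = new float[batchSize][][];
        for (int b = 0; b < batchSize; b++) {
            int fromLength = ids[b].length; // F
            int toLength = mask[b].length;  // T
            attention[b] = new float[fromLength][toLength];
            for (int f = 0; f < fromLength; f++) {
                for (int t = 0; t < toLength; t++) {
                    // Each from-position may attend to exactly the unmasked to-positions.
                    attention[b][f][t] = mask[b][t];
                }
            }
        }
        return attention;
    }

    public static void main(String[] args) {
        int[][] ids = {{101, 7592, 102, 0}}; // made-up token ids, last one is padding
        int[][] mask = {{1, 1, 1, 0}};       // 1 = real token, 0 = padding
        System.out.println(Arrays.deepToString(createAttentionMask(ids, mask)));
    }
}
```

With the real API, the equivalent call would take two NDArray arguments and return an NDArray; the padded to-positions end up as 0 in every row of the resulting mask.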
    • forwardInternal

      protected NDList forwardInternal(ParameterStore ps, NDList inputs, boolean training, ai.djl.util.PairList<String,Object> params)
      Specified by:
      forwardInternal in class AbstractBaseBlock
      Parameters:
      ps - the parameter store
      inputs - the input NDList
      training - true for a training forward pass
      params - optional parameters
      Returns:
      the output of the forward pass
    • builder

      public static BertBlock.Builder builder()
      Returns a new BertBlock builder.
      Returns:
      a new BertBlock builder