Class BertBlock

  • All Implemented Interfaces:
    Block

    public final class BertBlock
    extends AbstractBlock
    Implements the core BERT model, without the next-sentence-prediction and masked-language-model task heads.

    This closely follows the original paper by Devlin et al. and its reference implementation.

    • Method Detail

      • getTokenEmbedding

        public IdEmbedding getTokenEmbedding()
        Returns the token embedding used by this BERT model.
        Returns:
        the token embedding used by this BERT model
      • getEmbeddingSize

        public int getEmbeddingSize()
        Returns the embedding size used for tokens.
        Returns:
        the embedding size used for tokens
      • getTokenDictionarySize

        public int getTokenDictionarySize()
        Returns the size of the token dictionary.
        Returns:
        the size of the token dictionary
      • getTypeDictionarySize

        public int getTypeDictionarySize()
        Returns the size of the type dictionary.
        Returns:
        the size of the type dictionary
      • getOutputShapes

        public Shape[] getOutputShapes(Shape[] inputShapes)
        Returns the expected output shapes of the block for the specified input shapes.
        Parameters:
        inputShapes - the shapes of the inputs
        Returns:
        the expected output shapes of the block
      • initializeChildBlocks

        public void initializeChildBlocks(NDManager manager,
                                          DataType dataType,
                                          Shape... inputShapes)
        Initializes the child blocks of this block. Subclasses that have child blocks must override this method. It determines the correct input shapes for the child blocks from the input shapes requested for this block.
        Overrides:
        initializeChildBlocks in class AbstractBaseBlock
        Parameters:
        manager - the manager to use for initialization
        dataType - the requested data type
        inputShapes - the expected input shapes for this block
      • createAttentionMaskFromInputMask

        public static NDArray createAttentionMaskFromInputMask(NDArray ids,
                                                               NDArray mask)
        Creates a 3D attention mask from a 2D tensor mask.
        Parameters:
        ids - 2D Tensor of shape (B, F)
        mask - 2D Tensor of shape (B, T)
        Returns:
        float tensor of shape (B, F, T)
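
        The shape contract above can be illustrated with a minimal sketch in plain Java arrays. It does not use this library's NDArray API, and it is not the library's implementation. It assumes the broadcasting behaviour of the BERT reference implementation: B is the batch size, F the number of "from" (query) positions, T the number of "to" (key) positions, and every "from" position receives the same copy of the padding mask, so out[b][f][t] = mask[b][t]. The class name AttentionMaskSketch and the method createAttentionMask are hypothetical.

```java
import java.util.Arrays;

// Illustrative only: plain-array stand-in for a (B, T) -> (B, F, T) mask broadcast.
class AttentionMaskSketch {

    /**
     * Broadcasts a padding mask of shape (B, T) to an attention mask of shape (B, F, T).
     * Assumption: out[b][f][t] = mask[b][t] for every "from" position f.
     *
     * @param fromLength F, the number of "from" positions (taken from the ids shape)
     * @param mask       padding mask of shape (B, T); 1 = real token, 0 = padding
     */
    public static float[][][] createAttentionMask(int fromLength, float[][] mask) {
        int batch = mask.length;
        float[][][] out = new float[batch][fromLength][];
        for (int b = 0; b < batch; b++) {
            for (int f = 0; f < fromLength; f++) {
                // Each "from" position gets its own copy of the "to" mask row.
                out[b][f] = mask[b].clone();
            }
        }
        return out;
    }

    public static void main(String[] args) {
        // B = 2, T = 3; the second sequence has two padded positions.
        float[][] mask = {{1f, 1f, 0f}, {1f, 0f, 0f}};
        float[][][] attention = createAttentionMask(3, mask); // F = 3
        System.out.println(Arrays.deepToString(attention));
    }
}
```

        In this sketch only F is read from the ids argument, which is why the helper takes fromLength instead of the ids themselves.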
      • builder

        public static BertBlock.Builder builder()
        Returns a new BertBlock builder.
        Returns:
        a new BertBlock builder