Class TransformerEncoderBlock

java.lang.Object
  ai.djl.nn.AbstractBaseBlock
    ai.djl.nn.AbstractBlock
      ai.djl.nn.transformer.TransformerEncoderBlock
All Implemented Interfaces:
Block

public class TransformerEncoderBlock extends AbstractBlock
A self-attention-based transformer encoder block.
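The core operation of a self-attention encoder is scaled dot-product attention: every token attends to every token in the same sequence, and its output is a softmax-weighted average of the sequence. The sketch below shows a single head on plain Java arrays. It is a hypothetical illustration, not DJL's implementation. It assumes no learned query, key, or value projections, so the input itself serves as queries, keys, and values. `SelfAttentionSketch` and `selfAttention` are illustrative names.

```java
import java.util.Arrays;

// Hypothetical illustration of single-head scaled dot-product self-attention on
// plain arrays. Learned query/key/value projections are omitted, so the input
// itself serves as queries, keys and values. This is not DJL's implementation.
class SelfAttentionSketch {

    // Computes softmax(X * X^T / sqrt(d)) * X for a sequence X of shape [seqLength][d].
    static double[][] selfAttention(double[][] x) {
        int seqLength = x.length;
        int d = x[0].length;
        double scale = 1.0 / Math.sqrt(d);
        double[][] out = new double[seqLength][d];
        for (int i = 0; i < seqLength; i++) {
            // Scaled dot-product score of token i against every token j.
            double[] weights = new double[seqLength];
            double max = Double.NEGATIVE_INFINITY;
            for (int j = 0; j < seqLength; j++) {
                double dot = 0.0;
                for (int k = 0; k < d; k++) {
                    dot += x[i][k] * x[j][k];
                }
                weights[j] = dot * scale;
                max = Math.max(max, weights[j]);
            }
            // Softmax over j; subtracting the maximum keeps exp() from overflowing.
            double sum = 0.0;
            for (int j = 0; j < seqLength; j++) {
                weights[j] = Math.exp(weights[j] - max);
                sum += weights[j];
            }
            // Output row i is the attention-weighted average of all token vectors.
            for (int j = 0; j < seqLength; j++) {
                double w = weights[j] / sum;
                for (int k = 0; k < d; k++) {
                    out[i][k] += w * x[j][k];
                }
            }
        }
        return out;
    }

    public static void main(String[] args) {
        double[][] tokens = {{1.0, 0.0}, {0.0, 1.0}, {1.0, 1.0}};
        // Each output row has the same width as the corresponding input row.
        System.out.println(Arrays.deepToString(selfAttention(tokens)));
    }
}
```

Because each output row is a convex combination of the input rows, a sequence of identical tokens maps to itself.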
  • Constructor Details

    • TransformerEncoderBlock

      public TransformerEncoderBlock(int embeddingSize, int headCount, int hiddenSize, float dropoutProbability, Function<NDList,NDList> activationFunction)
      Creates a transformer encoder block.
      Parameters:
      embeddingSize - the embedding size for tokens
      headCount - the number of attention heads
      hiddenSize - the hidden size for fully connected networks
      dropoutProbability - the dropout probability
      activationFunction - the activation function
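To make the relationship between the constructor arguments concrete, the sketch below models them in a hypothetical `EncoderConfigSketch` class that is not part of DJL. It assumes two things the page does not state. First, it assumes the usual multi-head layout, which splits each token embedding evenly across `headCount` heads, so `embeddingSize` must be divisible by `headCount`. Second, it assumes a two-layer feed-forward network `embeddingSize -> hiddenSize -> embeddingSize` with biases. The BERT-base-like sizes in `main` are an arbitrary example.

```java
// Hypothetical stand-in for the TransformerEncoderBlock constructor arguments,
// used only to show how they relate. Not part of DJL.
class EncoderConfigSketch {
    final int embeddingSize;
    final int headCount;
    final int hiddenSize;
    final float dropoutProbability;

    EncoderConfigSketch(int embeddingSize, int headCount, int hiddenSize, float dropoutProbability) {
        // Assumption: the usual multi-head layout splits each embedding evenly across heads.
        if (headCount <= 0 || embeddingSize % headCount != 0) {
            throw new IllegalArgumentException(
                    "headCount must be positive and divide embeddingSize " + embeddingSize);
        }
        // A probability of 1 would drop every activation, so only [0, 1) is accepted here.
        if (dropoutProbability < 0f || dropoutProbability >= 1f) {
            throw new IllegalArgumentException("dropoutProbability must be in [0, 1)");
        }
        this.embeddingSize = embeddingSize;
        this.headCount = headCount;
        this.hiddenSize = hiddenSize;
        this.dropoutProbability = dropoutProbability;
    }

    // Width of each attention head's slice of the embedding.
    int headSize() {
        return embeddingSize / headCount;
    }

    // Weights plus biases of an assumed two-layer feed-forward network
    // embeddingSize -> hiddenSize -> embeddingSize.
    long feedForwardParameterCount() {
        return (long) embeddingSize * hiddenSize + hiddenSize
                + (long) hiddenSize * embeddingSize + embeddingSize;
    }

    public static void main(String[] args) {
        // BERT-base-like sizes, chosen purely as an example.
        EncoderConfigSketch config = new EncoderConfigSketch(768, 12, 3072, 0.1f);
        System.out.println("head size: " + config.headSize());
        System.out.println("feed-forward parameters: " + config.feedForwardParameterCount());
    }
}
```

With these sizes, each of the 12 heads works on a 64-wide slice, and the assumed feed-forward network has 4,722,432 parameters.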
  • Method Details

    • getOutputShapes

      public Shape[] getOutputShapes(Shape[] inputShapes)
      Returns the expected output shapes of the block for the specified input shapes.
      Parameters:
      inputShapes - the shapes of the inputs
      Returns:
      the expected output shapes of the block
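An encoder block conventionally returns a tensor with the same shape as its token input. Both residual connections add a sublayer's output back to that sublayer's input, which only type-checks if the shape is preserved. The hypothetical helper below shows that reasoning for a single input of shape `(batch, seqLength, embeddingSize)`. It is an illustration under that assumption, not DJL's `getOutputShapes`, and it uses `long[]` in place of DJL's `Shape`.

```java
import java.util.Arrays;

// Hypothetical helper: output shape of an encoder block for a single input of
// shape (batch, seqLength, embeddingSize). Assumes the block preserves that shape.
class OutputShapeSketch {

    static long[] encoderOutputShape(long[] inputShape) {
        if (inputShape.length != 3) {
            throw new IllegalArgumentException(
                    "expected (batch, seqLength, embeddingSize), got " + Arrays.toString(inputShape));
        }
        // Both residual connections add a sublayer's output to its input, which only
        // works if every sublayer returns (batch, seqLength, embeddingSize) again.
        return inputShape.clone();
    }

    public static void main(String[] args) {
        long[] input = {32, 128, 768};
        System.out.println(Arrays.toString(encoderOutputShape(input))); // [32, 128, 768]
    }
}
```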
    • initializeChildBlocks

      public void initializeChildBlocks(NDManager manager, DataType dataType, Shape... inputShapes)
      Initializes the child blocks of this block. Subclasses that have child blocks must override this method. It determines the correct input shapes for the child blocks from the requested input shapes of this block.
      Overrides:
      initializeChildBlocks in class AbstractBaseBlock
      Parameters:
      manager - the manager to use for initialization
      dataType - the requested data type
      inputShapes - the expected input shapes for this block
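Child-block initialization means deriving each child's input shape from the block's own input shape. The sketch below traces that derivation for a typical encoder layout: attention, normalization, a two-layer feed-forward network, and a final normalization. The stage names are hypothetical and are not DJL's child block names. The only shape that changes is the input to the second feed-forward layer, which receives `hiddenSize`-wide activations.

```java
import java.util.Arrays;
import java.util.LinkedHashMap;
import java.util.Map;

// Hypothetical sketch of shape propagation during child-block initialization.
// Stage names are illustrative, not DJL's child block names.
class ChildShapeSketch {

    // Maps each child stage to the input shape it would be initialized with,
    // given this block's input shape (batch, seqLength, embeddingSize).
    static Map<String, long[]> childInputShapes(long[] inputShape, long hiddenSize) {
        long batch = inputShape[0];
        long seqLength = inputShape[1];
        long embeddingSize = inputShape[2];
        long[] tokens = {batch, seqLength, embeddingSize};
        Map<String, long[]> shapes = new LinkedHashMap<>();
        shapes.put("selfAttention", tokens.clone());
        shapes.put("attentionNorm", tokens.clone());
        // First dense layer of the feed-forward network: embeddingSize -> hiddenSize.
        shapes.put("feedForwardHidden", tokens.clone());
        // Second dense layer consumes the hidden activations: hiddenSize -> embeddingSize.
        shapes.put("feedForwardOutput", new long[] {batch, seqLength, hiddenSize});
        shapes.put("outputNorm", tokens.clone());
        return shapes;
    }

    public static void main(String[] args) {
        Map<String, long[]> shapes = childInputShapes(new long[] {2, 10, 16}, 64);
        shapes.forEach((name, shape) -> System.out.println(name + " <- " + Arrays.toString(shape)));
    }
}
```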
    • forwardInternal

      protected NDList forwardInternal(ParameterStore ps, NDList inputs, boolean training, ai.djl.util.PairList<String,Object> params)
      Runs the forward pass of this block after it has been initialized.
      Specified by:
      forwardInternal in class AbstractBaseBlock
      Parameters:
      ps - the parameter store
      inputs - the input NDList
      training - true for a training forward pass
      params - optional parameters
      Returns:
      the output of the forward pass
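The `training` flag matters mainly for dropout, which is active during training and an identity at inference. The sketch below shows a post-layer-norm encoder forward pass in the layout of the original Transformer: attention, dropout, residual add and layer norm, then feed-forward, dropout, residual add and layer norm. It is a hypothetical illustration on plain arrays, not DJL's `forwardInternal`. The sublayers are passed in as functions, layer norm has no learned scale or shift, and dropout uses the common "inverted" scaling by `1 / (1 - p)`.

```java
import java.util.Arrays;
import java.util.Random;
import java.util.function.UnaryOperator;

// Hypothetical sketch of a post-layer-norm encoder forward pass (the layout of the
// original Transformer). Sublayers are passed in as functions; layer norm has no
// learned scale or shift. This is not DJL's forwardInternal.
class EncoderForwardSketch {

    // Inverted dropout: during training, zero each element with probability p and
    // scale the survivors by 1 / (1 - p); at inference it returns x unchanged.
    static double[][] dropout(double[][] x, double p, boolean training, Random rng) {
        if (!training || p == 0.0) {
            return x;
        }
        double[][] out = new double[x.length][];
        for (int i = 0; i < x.length; i++) {
            out[i] = new double[x[i].length];
            for (int k = 0; k < x[i].length; k++) {
                out[i][k] = rng.nextDouble() < p ? 0.0 : x[i][k] / (1.0 - p);
            }
        }
        return out;
    }

    // Residual connection followed by per-token layer normalization.
    static double[][] addAndNorm(double[][] residual, double[][] sublayerOut) {
        final double eps = 1e-5;
        double[][] out = new double[residual.length][];
        for (int i = 0; i < residual.length; i++) {
            int d = residual[i].length;
            double[] row = new double[d];
            double mean = 0.0;
            for (int k = 0; k < d; k++) {
                row[k] = residual[i][k] + sublayerOut[i][k];
                mean += row[k];
            }
            mean /= d;
            double variance = 0.0;
            for (int k = 0; k < d; k++) {
                variance += (row[k] - mean) * (row[k] - mean);
            }
            variance /= d;
            double invStd = 1.0 / Math.sqrt(variance + eps);
            for (int k = 0; k < d; k++) {
                row[k] = (row[k] - mean) * invStd;
            }
            out[i] = row;
        }
        return out;
    }

    // x -> attention -> dropout -> add & norm -> feed-forward -> dropout -> add & norm
    static double[][] forward(double[][] x, UnaryOperator<double[][]> attention,
            UnaryOperator<double[][]> feedForward, double p, boolean training, Random rng) {
        double[][] h = addAndNorm(x, dropout(attention.apply(x), p, training, rng));
        return addAndNorm(h, dropout(feedForward.apply(h), p, training, rng));
    }

    public static void main(String[] args) {
        double[][] x = {{1.0, 2.0, 3.0}, {0.5, -1.0, 2.0}};
        // Stand-ins: identity "attention" and an element-wise ReLU "feed-forward network".
        UnaryOperator<double[][]> attention = UnaryOperator.identity();
        UnaryOperator<double[][]> feedForward = a -> {
            double[][] r = new double[a.length][];
            for (int i = 0; i < a.length; i++) {
                r[i] = a[i].clone();
                for (int k = 0; k < r[i].length; k++) {
                    r[i][k] = Math.max(0.0, r[i][k]);
                }
            }
            return r;
        };
        Random rng = new Random(42);
        System.out.println(Arrays.deepToString(forward(x, attention, feedForward, 0.1, true, rng)));
        System.out.println(Arrays.deepToString(forward(x, attention, feedForward, 0.1, false, rng)));
    }
}
```

Running the same input with `training` set to `true` and then `false` shows the difference. The training pass is stochastic. The inference pass is deterministic because dropout is skipped.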