Class ScaledDotProductAttentionBlock

java.lang.Object
ai.djl.nn.AbstractBaseBlock
ai.djl.nn.AbstractBlock
ai.djl.nn.transformer.ScaledDotProductAttentionBlock
All Implemented Interfaces:
Block

public final class ScaledDotProductAttentionBlock extends AbstractBlock
A Block implementing scaled dot-product attention according to Vaswani et al.

Abbreviations used:

  • E = embedding size
  • B = batch size
  • N = number of attention heads
  • F = "from" sequence length (key/value sequence), the input sequence
  • T = "to" sequence length (query sequence), e.g. the length of the output sequence
  • S = a sequence length, either F or T
  • H = Attention head size (= E / N)

In many use cases F = T. For self-attention, the same sequence supplies the keys, queries and values, so the "from" and "to" sequences are identical.
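
For reference, the per-head computation from Vaswani et al. can be written with the abbreviations above. This is a sketch under two assumptions: the paper's key dimension d_k is taken to equal the head size H, and the symbols Q, K and V (the projected queries, keys and values for one batch element and one head) are introduced here only for illustration.

```latex
\mathrm{Attention}(Q, K, V) = \mathrm{softmax}\!\left(\frac{Q K^{\top}}{\sqrt{H}}\right) V,
\qquad Q \in \mathbb{R}^{T \times H},\quad K, V \in \mathbb{R}^{F \times H}
```

The softmax is taken over the F axis, so each of the T query positions receives a weighted average of the F value vectors, and the result has shape (T, H).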

This block can process input in four forms:

  • Input size one: [Values] = [(B, F, E)], the only input is used as key, query and value (unmasked self-attention), e.g. BERT
  • Input size two: [Values, Mask] = [(B, F, E), (B, F, F)], the first input is used as key, query and value (masked self-attention)
  • Input size three: [Keys, Queries, Values] = [(B, F, E), (B, T, E), (B, F, E)], inputs are interpreted as keys, queries and values (unmasked attention)
  • Input size four: [Keys, Queries, Values, Mask] = [(B, F, E), (B, T, E), (B, F, E), (B, T, F)], inputs are interpreted as keys, queries, values and an attention mask (fully masked attention)

Attention masks must contain a 1 for positions to keep and a 0 for positions to mask.
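
The mask convention can be illustrated with a minimal plain-Java sketch of single-head attention for one batch element. It is not this block's implementation and uses no DJL classes: the class name MaskedAttentionSketch, the method attend, and the constant MASKED_SCORE are hypothetical, and the projections and dropout that a full attention block applies are left out.

```java
final class MaskedAttentionSketch {

    // Hypothetical choice: a large negative finite score for masked positions,
    // so that a fully masked row yields uniform weights instead of NaN.
    private static final double MASKED_SCORE = -1e9;

    /**
     * Single-head scaled dot-product attention for one batch element.
     *
     * @param queries shape (T, H)
     * @param keys    shape (F, H)
     * @param values  shape (F, H)
     * @param mask    shape (T, F), 1 = keep, 0 = mask; null means unmasked
     * @return attention output of shape (T, H)
     */
    static double[][] attend(double[][] queries, double[][] keys, double[][] values, int[][] mask) {
        int t = queries.length;
        int f = keys.length;
        int h = queries[0].length;
        double scale = 1.0 / Math.sqrt(h);
        double[][] out = new double[t][values[0].length];

        for (int i = 0; i < t; i++) {
            // Scaled dot product of query i with every key, with masked positions suppressed.
            double[] scores = new double[f];
            double max = Double.NEGATIVE_INFINITY;
            for (int j = 0; j < f; j++) {
                double dot = 0.0;
                for (int k = 0; k < h; k++) {
                    dot += queries[i][k] * keys[j][k];
                }
                boolean keep = mask == null || mask[i][j] == 1;
                scores[j] = keep ? dot * scale : MASKED_SCORE;
                max = Math.max(max, scores[j]);
            }

            // Numerically stable softmax over the F axis.
            double sum = 0.0;
            for (int j = 0; j < f; j++) {
                scores[j] = Math.exp(scores[j] - max);
                sum += scores[j];
            }

            // Weighted average of the value vectors.
            for (int j = 0; j < f; j++) {
                double weight = scores[j] / sum;
                for (int k = 0; k < out[i].length; k++) {
                    out[i][k] += weight * values[j][k];
                }
            }
        }
        return out;
    }
}
```

With a mask row of {1, 0}, the second key/value position receives zero weight, so the output for that query is the first value vector alone.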

  • Method Details

    • getKeyProjection

      public Linear getKeyProjection()
      Pointwise Linear projection of the keys.
      Returns:
      Pointwise Linear projection of the keys.
    • getQueryProjection

      public Linear getQueryProjection()
      Pointwise Linear projection of the queries.
      Returns:
      Pointwise Linear projection of the queries.
    • getValueProjection

      public Linear getValueProjection()
      Pointwise Linear projection of the values.
      Returns:
      Pointwise Linear projection of the values.
    • getResultProjection

      public Linear getResultProjection()
      Pointwise Linear projection of the results.
      Returns:
      Pointwise Linear projection of the results.
    • getOutputShapes

      public Shape[] getOutputShapes(Shape[] inputShapes)
      Returns the expected output shapes of the block for the specified input shapes.
      Parameters:
      inputShapes - the shapes of the inputs
      Returns:
      the expected output shapes of the block
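
      As a rough illustration of how the output shape relates to the four input forms described in the class overview, the following plain-Java sketch works on raw long[] shapes instead of Shape objects. It assumes that the output follows the query shape (B, T, E), which is what the attention computation implies; the class AttentionShapeSketch and its outputShape method are hypothetical and not part of this API.

```java
final class AttentionShapeSketch {

    /**
     * Derives the expected output shape from the input shapes, assuming the
     * output has the shape of the queries.
     *
     * @param inputShapes one to four shapes, in the order documented for this block
     * @return the assumed output shape (B, T, E)
     */
    static long[] outputShape(long[][] inputShapes) {
        switch (inputShapes.length) {
            case 1: // [Values]
            case 2: // [Values, Mask]
                // Self-attention: the single sequence also acts as the queries.
                return inputShapes[0].clone();
            case 3: // [Keys, Queries, Values]
            case 4: // [Keys, Queries, Values, Mask]
                // Dedicated queries: the output follows the query shape (B, T, E).
                return inputShapes[1].clone();
            default:
                throw new IllegalArgumentException(
                        "expected 1 to 4 input shapes, got " + inputShapes.length);
        }
    }
}
```

      For example, keys and values of shape (2, 5, 8) combined with queries of shape (2, 3, 8) give an output of shape (2, 3, 8) under this assumption.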
    • initializeChildBlocks

      public void initializeChildBlocks(NDManager manager, DataType dataType, Shape... inputShapes)
      Initializes the child blocks of this block. You need to override this method if your subclass has child blocks. Used to determine the correct input shapes for child blocks based on the requested input shape for this block.
      Overrides:
      initializeChildBlocks in class AbstractBaseBlock
      Parameters:
      manager - the manager to use for initialization
      dataType - the requested data type
      inputShapes - the expected input shapes for this block
    • forwardInternal

      protected NDList forwardInternal(ParameterStore parameterStore, NDList inputs, boolean training, ai.djl.util.PairList<String,Object> params)
      Specified by:
      forwardInternal in class AbstractBaseBlock
      Parameters:
      parameterStore - the parameter store
      inputs - the input NDList
      training - true for a training forward pass
      params - optional parameters
      Returns:
      the output of the forward pass
    • builder

      public static ScaledDotProductAttentionBlock.Builder builder()
      Creates a new Builder to build an Attention Block with.
      Returns:
      a new Builder to build an Attention Block with.