Class ScaledDotProductAttentionBlock

  • All Implemented Interfaces:
    Block

    public final class ScaledDotProductAttentionBlock
    extends AbstractBlock
    A Block implementing scaled dot-product attention according to Vaswani et al.
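    For reference, scaled dot-product attention computes softmax(Q·Kᵀ / √H)·V for each head. The plain-Java code is a minimal sketch under stated assumptions, not DJL's implementation. `AttentionSketch` and its methods are hypothetical names, and `double[][]` arrays stand in for `NDArray`. It covers one head of one batch element and omits the key, query, value and result projections.

```java
import java.util.Arrays;

/**
 * Hypothetical plain-Java sketch of single-head scaled dot-product attention.
 * Not part of the DJL API: arrays stand in for NDArray, and batching,
 * multi-head splitting and the linear projections are omitted.
 */
final class AttentionSketch {

    /**
     * Computes softmax(Q K^T / sqrt(H)) V for one head and one batch element.
     *
     * @param queries shape (T, H)
     * @param keys    shape (F, H)
     * @param values  shape (F, H)
     * @return shape (T, H)
     */
    static double[][] attend(double[][] queries, double[][] keys, double[][] values) {
        int t = queries.length;
        int f = keys.length;
        int h = queries[0].length;
        int dv = values[0].length;
        double scale = 1.0 / Math.sqrt(h);
        double[][] out = new double[t][dv];
        for (int i = 0; i < t; i++) {
            // Scaled dot product of query i with every key j.
            double[] scores = new double[f];
            for (int j = 0; j < f; j++) {
                double dot = 0.0;
                for (int k = 0; k < h; k++) {
                    dot += queries[i][k] * keys[j][k];
                }
                scores[j] = dot * scale;
            }
            // Attention weights over the F key positions sum to 1.
            double[] weights = softmax(scores);
            // Output row i is the weighted sum of the value rows.
            for (int j = 0; j < f; j++) {
                for (int k = 0; k < dv; k++) {
                    out[i][k] += weights[j] * values[j][k];
                }
            }
        }
        return out;
    }

    /** Softmax with the maximum subtracted first for numerical stability. */
    static double[] softmax(double[] x) {
        double max = Arrays.stream(x).max().orElse(0.0);
        double[] e = new double[x.length];
        double sum = 0.0;
        for (int i = 0; i < x.length; i++) {
            e[i] = Math.exp(x[i] - max);
            sum += e[i];
        }
        for (int i = 0; i < x.length; i++) {
            e[i] /= sum;
        }
        return e;
    }

    public static void main(String[] args) {
        double[][] q = {{1.0, 0.0}};
        double[][] k = {{1.0, 0.0}, {0.0, 1.0}};
        double[][] v = {{1.0, 2.0}, {3.0, 4.0}};
        // Query 0 matches key 0 more closely, so the output leans toward value row 0.
        System.out.println(Arrays.toString(attend(q, k, v)[0]));
    }
}
```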

    Abbreviations used:

    • E = embedding size
    • B = batch size
    • N = number of attention heads
    • F = "from" sequence length (key/value sequence), the input sequence
    • T = "to" sequence length (query sequence), e.g. the length of the output sequence
    • S = a sequence length, either F or T
    • H = attention head size (= E / N)

    In many use cases F = T. For self-attention, keys, queries and values all come from the same input sequence, so F = T always holds.
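    The abbreviations are related by simple shape arithmetic. The helper is a hypothetical sketch, not part of DJL. It assumes that E must be divisible by N and uses a per-head layout of (B, N, S, H). That layout is a common choice but not necessarily the one this block uses internally.

```java
import java.util.Arrays;

/**
 * Hypothetical helper showing the shape arithmetic behind E, N, H, F and T.
 * Not part of the DJL API; the (B, N, S, H) per-head layout is an assumed, common choice.
 */
final class AttentionShapes {

    /** H = E / N. This sketch assumes E must be divisible by N. */
    static long headSize(long embeddingSize, long headCount) {
        if (headCount <= 0) {
            throw new IllegalArgumentException("head count must be positive, got " + headCount);
        }
        if (embeddingSize % headCount != 0) {
            throw new IllegalArgumentException(
                    "embedding size " + embeddingSize + " is not divisible by " + headCount + " heads");
        }
        return embeddingSize / headCount;
    }

    /** A (B, S, E) sequence split into N heads: (B, N, S, H). */
    static long[] perHeadShape(long b, long s, long e, long n) {
        return new long[] {b, n, s, headSize(e, n)};
    }

    /** One score per head and per (query, key) pair: (B, N, T, F). */
    static long[] scoreShape(long b, long n, long t, long f) {
        return new long[] {b, n, t, f};
    }

    public static void main(String[] args) {
        // BERT-base sizes: E = 768, N = 12, hence H = 64.
        System.out.println(headSize(768, 12)); // prints 64
        System.out.println(Arrays.toString(perHeadShape(8, 128, 768, 12))); // prints [8, 12, 128, 64]
        // Self-attention: the same sequence is "from" and "to", so T = F.
        System.out.println(Arrays.toString(scoreShape(8, 12, 128, 128))); // prints [8, 12, 128, 128]
    }
}
```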

    This block can process input in four forms:

    • Input size one: [Values] = [(B, F, E)]; the single input is used as key, query and value (unmasked self-attention), e.g. BERT
    • Input size two: [Values, Mask] = [(B, F, E), (B, F, F)]; the first input is used as key, query and value, and the second is the attention mask (masked self-attention)
    • Input size three: [Keys, Queries, Values] = [(B, F, E), (B, T, E), (B, F, E)]; the inputs are interpreted as keys, queries and values (unmasked attention)
    • Input size four: [Keys, Queries, Values, Mask] = [(B, F, E), (B, T, E), (B, F, E), (B, T, F)]; the inputs are interpreted as keys, queries, values and an attention mask (fully masked attention)
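    The four input forms amount to a dispatch on the number of inputs. This is a minimal sketch, not DJL's code. `AttentionInputs` and `resolve` are hypothetical names, and the generic element type stands in for `NDArray`.

```java
import java.util.List;

/**
 * Hypothetical sketch of how the four accepted input forms map onto keys,
 * queries, values and an optional mask. Not part of the DJL API; the element
 * type A stands in for NDArray.
 */
final class AttentionInputs<A> {
    final A keys;
    final A queries;
    final A values;
    final A mask; // null for the unmasked forms (sizes one and three)

    private AttentionInputs(A keys, A queries, A values, A mask) {
        this.keys = keys;
        this.queries = queries;
        this.values = values;
        this.mask = mask;
    }

    static <X> AttentionInputs<X> resolve(List<X> inputs) {
        switch (inputs.size()) {
            case 1: // [Values]: unmasked self-attention
                return new AttentionInputs<>(inputs.get(0), inputs.get(0), inputs.get(0), null);
            case 2: // [Values, Mask]: masked self-attention, mask is (B, F, F)
                return new AttentionInputs<>(inputs.get(0), inputs.get(0), inputs.get(0), inputs.get(1));
            case 3: // [Keys, Queries, Values]: unmasked attention
                return new AttentionInputs<>(inputs.get(0), inputs.get(1), inputs.get(2), null);
            case 4: // [Keys, Queries, Values, Mask]: masked attention, mask is (B, T, F)
                return new AttentionInputs<>(inputs.get(0), inputs.get(1), inputs.get(2), inputs.get(3));
            default:
                throw new IllegalArgumentException("expected 1 to 4 inputs, got " + inputs.size());
        }
    }

    public static void main(String[] args) {
        AttentionInputs<String> cross = resolve(List.of("keys", "queries", "values"));
        System.out.println(cross.queries + ", mask=" + cross.mask); // prints queries, mask=null
    }
}
```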

    Attention masks must contain a 1 for positions to keep and a 0 for positions to mask.
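    The text states the mask convention (1 = keep, 0 = mask) but not how the mask is applied. One common technique adds a large negative bias to masked scores before the softmax, so masked positions get zero weight. That technique is assumed here, not taken from DJL. `MaskSketch` and `MASKED_BIAS` are hypothetical names, and the bias value is an assumption.

```java
import java.util.Arrays;

/**
 * Hypothetical sketch of one common way to apply a 1/0 keep mask: add a large
 * negative bias to masked scores before the softmax. Not part of the DJL API;
 * the bias value below is an assumption.
 */
final class MaskSketch {
    // Assumed stand-in for minus infinity; exp of such a shifted score underflows to 0.0.
    static final double MASKED_BIAS = -1e9;

    /** scores and mask are (T, F); mask is 1 to keep a position and 0 to mask it. */
    static double[][] applyMask(double[][] scores, double[][] mask) {
        double[][] out = new double[scores.length][];
        for (int i = 0; i < scores.length; i++) {
            out[i] = new double[scores[i].length];
            for (int j = 0; j < scores[i].length; j++) {
                // keep (1) adds 0; mask (0) adds MASKED_BIAS
                out[i][j] = scores[i][j] + (1.0 - mask[i][j]) * MASKED_BIAS;
            }
        }
        return out;
    }

    /** Softmax with the maximum subtracted first for numerical stability. */
    static double[] softmax(double[] x) {
        double max = Arrays.stream(x).max().orElse(0.0);
        double[] e = new double[x.length];
        double sum = 0.0;
        for (int i = 0; i < x.length; i++) {
            e[i] = Math.exp(x[i] - max);
            sum += e[i];
        }
        for (int i = 0; i < x.length; i++) {
            e[i] /= sum;
        }
        return e;
    }

    public static void main(String[] args) {
        double[][] scores = {{0.3, 1.2, 2.5}};
        double[][] mask = {{1, 1, 0}}; // hide the third key position
        double[] weights = softmax(applyMask(scores, mask)[0]);
        System.out.println(Arrays.toString(weights)); // third weight is 0.0
    }
}
```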

    • Method Detail

      • getKeyProjection

        public Linear getKeyProjection()
        Pointwise Linear projection of the keys.
        Returns:
        Pointwise Linear projection of the keys.
      • getQueryProjection

        public Linear getQueryProjection()
        Pointwise Linear projection of the queries.
        Returns:
        Pointwise Linear projection of the queries.
      • getValueProjection

        public Linear getValueProjection()
        Pointwise Linear projection of the values.
        Returns:
        Pointwise Linear projection of the values.
      • getResultProjection

        public Linear getResultProjection()
        Pointwise Linear projection of the results.
        Returns:
        Pointwise Linear projection of the results.
      • getOutputShapes

        public Shape[] getOutputShapes(Shape[] inputShapes)
        Returns the expected output shapes of the block for the specified input shapes.
        Parameters:
        inputShapes - the shapes of the inputs
        Returns:
        the expected output shapes of the block
      • initializeChildBlocks

        public void initializeChildBlocks(NDManager manager,
                                          DataType dataType,
                                          Shape... inputShapes)
        Initializes the child blocks of this block. Override this method if a subclass has child blocks. It determines the correct input shapes for the child blocks from the requested input shapes of this block.
        Overrides:
        initializeChildBlocks in class AbstractBaseBlock
        Parameters:
        manager - the manager to use for initialization
        dataType - the requested data type
        inputShapes - the expected input shapes for this block
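    The projection getters and getOutputShapes can be illustrated with a small sketch. It assumes that "pointwise" means the same weight matrix and bias are applied independently at every sequence position. It also assumes the block's output has the query shape (B, T, E). Both are assumptions for illustration, and `ProjectionSketch` is a hypothetical class, not part of DJL.

```java
import java.util.Arrays;

/**
 * Hypothetical sketch (not the DJL API) of a pointwise linear projection, i.e. the
 * same weights applied independently at every sequence position, and of an
 * assumed output-shape rule for the four input forms.
 */
final class ProjectionSketch {

    /**
     * Applies y = W x + b to each row of x.
     *
     * @param x      shape (S, E_in), one row per sequence position
     * @param weight shape (E_out, E_in)
     * @param bias   shape (E_out)
     * @return shape (S, E_out)
     */
    static double[][] projectPointwise(double[][] x, double[][] weight, double[] bias) {
        double[][] y = new double[x.length][weight.length];
        for (int s = 0; s < x.length; s++) {
            for (int o = 0; o < weight.length; o++) {
                double acc = bias[o];
                for (int i = 0; i < x[s].length; i++) {
                    acc += weight[o][i] * x[s][i];
                }
                y[s][o] = acc;
            }
        }
        return y;
    }

    /**
     * Assumed rule: the result has the query shape (B, T, E), which is the first
     * input for the self-attention forms and the second input otherwise.
     */
    static long[] outputShape(long[][] inputShapes) {
        switch (inputShapes.length) {
            case 1:
            case 2:
                return inputShapes[0].clone();
            case 3:
            case 4:
                return inputShapes[1].clone();
            default:
                throw new IllegalArgumentException("expected 1 to 4 input shapes, got " + inputShapes.length);
        }
    }

    public static void main(String[] args) {
        // keys (B, F, E), queries (B, T, E), values (B, F, E)
        long[][] cross = {{2, 5, 8}, {2, 3, 8}, {2, 5, 8}};
        System.out.println(Arrays.toString(outputShape(cross))); // prints [2, 3, 8]
    }
}
```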