Class ScaledDotProductAttentionBlock.Builder

java.lang.Object
ai.djl.nn.transformer.ScaledDotProductAttentionBlock.Builder
Enclosing class:
ScaledDotProductAttentionBlock

public static final class ScaledDotProductAttentionBlock.Builder extends Object
  • Method Details

    • setEmbeddingSize

      public ScaledDotProductAttentionBlock.Builder setEmbeddingSize(int embeddingSize)
      Sets the embedding size to be used for the internal token representation.
      Parameters:
      embeddingSize - the embedding size to be used for the internal token representation
      Returns:
      this builder
    • setHeadCount

      public ScaledDotProductAttentionBlock.Builder setHeadCount(int headCount)
      Sets the number of attention heads. The head count must divide the embedding size without remainder. For example, with embeddingSize = 10, a headCount of 3 is not valid, while a headCount of 1, 2 or 5 is.
      Parameters:
      headCount - the number of attention heads
      Returns:
      this builder
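
      The divisibility rule for headCount can be checked before configuring the builder. The following is a minimal sketch in plain Java, not part of the DJL API: the class HeadCountCheck and its methods are hypothetical names introduced only for illustration, and the per-head size of embeddingSize / headCount assumes the usual multi-head attention layout.

```java
import java.util.ArrayList;
import java.util.List;

// Hypothetical helper, not part of DJL: illustrates the headCount constraint.
public class HeadCountCheck {

    /** Returns true when headCount divides embeddingSize without remainder. */
    public static boolean isValidHeadCount(int embeddingSize, int headCount) {
        return headCount > 0 && embeddingSize % headCount == 0;
    }

    /** Lists every head count that divides embeddingSize without remainder. */
    public static List<Integer> validHeadCounts(int embeddingSize) {
        List<Integer> valid = new ArrayList<>();
        for (int h = 1; h <= embeddingSize; h++) {
            if (isValidHeadCount(embeddingSize, h)) {
                valid.add(h);
            }
        }
        return valid;
    }

    public static void main(String[] args) {
        int embeddingSize = 10;
        System.out.println(validHeadCounts(embeddingSize));     // prints [1, 2, 5, 10]
        System.out.println(isValidHeadCount(embeddingSize, 3)); // prints false
        // Assumption: each head works on embeddingSize / headCount dimensions.
        System.out.println(embeddingSize / 5);                  // prints 2
    }
}
```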
    • optAttentionProbsDropoutProb

      public ScaledDotProductAttentionBlock.Builder optAttentionProbsDropoutProb(float attentionProbsDropoutProb)
      Sets the probability of applying dropout to the attention probability distribution. Because this dropout acts on individual attention weights, it can randomly remove a token's entire contribution to the result at a given position.
      Parameters:
      attentionProbsDropoutProb - the probability of applying dropout to the attention probability distribution
      Returns:
      this builder
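
      To make the effect of this setting concrete, the following plain-Java sketch applies dropout to a single row of attention probabilities. It is an illustration under stated assumptions, not DJL's implementation: the class AttentionDropoutSketch is a hypothetical name, and the rescaling of surviving weights by 1 / (1 - p) (so-called inverted dropout) is assumed rather than taken from this page.

```java
import java.util.Arrays;
import java.util.Random;

// Hypothetical sketch, not DJL code: dropout over one row of attention weights.
public class AttentionDropoutSketch {

    /**
     * Zeroes each attention weight with probability dropProb and scales the
     * survivors by 1 / (1 - dropProb) (assumed inverted-dropout convention).
     */
    public static float[] dropout(float[] probs, float dropProb, Random rng) {
        if (dropProb < 0f || dropProb >= 1f) {
            throw new IllegalArgumentException("dropProb must be in [0, 1)");
        }
        float keepScale = 1f / (1f - dropProb);
        float[] out = new float[probs.length];
        for (int i = 0; i < probs.length; i++) {
            // A zeroed weight means the token at index i contributes nothing
            // to the attention result for this query position.
            out[i] = rng.nextFloat() < dropProb ? 0f : probs[i] * keepScale;
        }
        return out;
    }

    public static void main(String[] args) {
        // Attention weights of one query position over four tokens.
        float[] probs = {0.1f, 0.2f, 0.3f, 0.4f};
        float[] dropped = dropout(probs, 0.5f, new Random());
        System.out.println(Arrays.toString(dropped));
    }
}
```

      A probability of 0 leaves the attention weights unchanged, which corresponds to disabling this dropout.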
    • build

      public ScaledDotProductAttentionBlock build()
      Creates a new ScaledDotProductAttentionBlock with the current configuration.
      Returns:
      a new ScaledDotProductAttentionBlock with the current configuration.