Package ai.djl.nn.transformer
Class ScaledDotProductAttentionBlock.Builder
java.lang.Object
ai.djl.nn.transformer.ScaledDotProductAttentionBlock.Builder
- Enclosing class:
- ScaledDotProductAttentionBlock
A builder for ScaledDotProductAttentionBlock instances.
Method Summary
- ScaledDotProductAttentionBlock build(): Creates a new ScaledDotProductAttentionBlock with the current configuration.
- ScaledDotProductAttentionBlock.Builder optAttentionProbsDropoutProb(float attentionProbsDropoutProb): Sets the probability of applying dropout to the attention probability distribution.
- ScaledDotProductAttentionBlock.Builder setEmbeddingSize(int embeddingSize): Sets the embedding size to be used for the internal token representation.
- ScaledDotProductAttentionBlock.Builder setHeadCount(int headCount): Sets the number of attention heads, which must divide the embedding size without remainder.
Method Details
setEmbeddingSize
Sets the embedding size to be used for the internal token representation.
- Parameters:
- embeddingSize - the embedding size to be used for the internal token representation
- Returns:
- this builder
setHeadCount
Sets the number of attention heads, which must divide the embedding size without remainder. For example, if embeddingSize = 10, a headCount of 3 is not valid, but a headCount of 1, 2, or 5 is.
- Parameters:
- headCount - the number of attention heads
- Returns:
- this builder
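The divisibility rule for headCount can be illustrated with a small standalone check. This is only a sketch: HeadCountCheck, isValidHeadCount, and perHeadSize are hypothetical names that are not part of DJL, and the even split of the embedding across heads is an assumption about how the block uses the two values, not something this page states.

```java
// Hypothetical helper, NOT part of DJL: illustrates the rule that
// headCount must divide embeddingSize without remainder.
final class HeadCountCheck {

    private HeadCountCheck() {}

    /** Returns true if both values are positive and headCount divides embeddingSize exactly. */
    static boolean isValidHeadCount(int embeddingSize, int headCount) {
        return embeddingSize > 0 && headCount > 0 && embeddingSize % headCount == 0;
    }

    /**
     * Returns the number of embedding dimensions per head, assuming the
     * embedding is split evenly across heads (an assumption for illustration).
     */
    static int perHeadSize(int embeddingSize, int headCount) {
        if (!isValidHeadCount(embeddingSize, headCount)) {
            throw new IllegalArgumentException(
                    "headCount " + headCount + " does not divide embeddingSize " + embeddingSize);
        }
        return embeddingSize / headCount;
    }

    public static void main(String[] args) {
        int embeddingSize = 10; // the value used in the example above
        for (int headCount : new int[] {1, 2, 3, 5}) {
            System.out.println(
                    "headCount=" + headCount
                            + " valid=" + isValidHeadCount(embeddingSize, headCount));
        }
    }
}
```

Running such a check before calling setHeadCount lets a configuration error surface with a clear message at the call site.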
optAttentionProbsDropoutProb
public ScaledDotProductAttentionBlock.Builder optAttentionProbsDropoutProb(float attentionProbsDropoutProb)
Sets the probability of applying dropout to the attention probability distribution. This dropout can randomly remove a complete token from the result at a position.
- Parameters:
- attentionProbsDropoutProb - the probability of applying dropout to the attention probability distribution
- Returns:
- this builder
build
Creates a new ScaledDotProductAttentionBlock with the current configuration.
- Returns:
- a new ScaledDotProductAttentionBlock with the current configuration