Package ai.djl.nn.transformer
Class ScaledDotProductAttentionBlock
java.lang.Object
ai.djl.nn.AbstractBaseBlock
ai.djl.nn.AbstractBlock
ai.djl.nn.transformer.ScaledDotProductAttentionBlock
- All Implemented Interfaces:
Block
A Block implementing scaled dot-product attention according to Vaswani et al.
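For reference, the attention function from Vaswani et al. can be written as below. The symbols Q, K, V (queries, keys, values) and d_k (the per-head key dimension, which corresponds to the head size H in the abbreviations on this page) follow the paper's notation and are not identifiers from this class.

```latex
\[
\mathrm{Attention}(Q, K, V) = \mathrm{softmax}\!\left(\frac{Q K^{\top}}{\sqrt{d_k}}\right) V
\]
```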
Abbreviations used:
- E = embedding size
- B = batch size
- N = number of attention heads
- F = "from" sequence length (key/value sequence), the input sequence
- T = "to" sequence length (query sequence), e.g. the length of the output sequence
- S = a sequence length, either F or T
- H = Attention head size (= E / N)
In many use cases F = T. For self attention, keys, queries and values all come from the same sequence, so F = T.
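The relation H = E / N can be illustrated with a small, self-contained sketch. `HeadShapes` is a hypothetical helper and not part of DJL; the divisibility check and the (B, N, S, H) layout after splitting heads are assumptions based on common transformer implementations, not something this page states.

```java
import java.util.Arrays;

/** Hypothetical helper (not part of DJL): derives per-head sizes from E and N. */
final class HeadShapes {

    /** Returns H = E / N, rejecting sizes that do not divide evenly (assumed requirement). */
    static int headSize(int embeddingSize, int headCount) {
        if (embeddingSize <= 0 || headCount <= 0) {
            throw new IllegalArgumentException("E and N must be positive");
        }
        if (embeddingSize % headCount != 0) {
            throw new IllegalArgumentException(
                    "E (" + embeddingSize + ") must be divisible by N (" + headCount + ")");
        }
        return embeddingSize / headCount;
    }

    /** Assumed layout of a (B, S, E) tensor after splitting it into heads: (B, N, S, H). */
    static long[] splitHeadsShape(long b, long s, int e, int n) {
        return new long[] {b, n, s, headSize(e, n)};
    }

    public static void main(String[] args) {
        System.out.println(headSize(768, 12)); // prints 64
        // prints [2, 12, 128, 64]
        System.out.println(Arrays.toString(splitHeadsShape(2, 128, 768, 12)));
    }
}
```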
This block can process input in four forms:
- Input size one: [Values] = [(B, F, E)], the single input is used as key, query and value (unmasked self attention), e.g. BERT
- Input size two: [Values, Mask] = [(B, F, E), (B, F, F)], first input is used as key, query and value, masked self attention
- Input size three: [Keys, Queries, Values] = [(B, F, E), (B, T, E), (B, F, E)], inputs are interpreted as keys, queries and values, unmasked attention
- Input size four: [Keys, Queries, Values, Mask] = [(B, F, E), (B, T, E), (B, F, E), (B, T, F)], inputs are interpreted as keys, queries, values and an attention mask, full masked attention.
Attention masks must contain a 1 for positions to keep and a 0 for positions to mask.
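The effect of the mask convention (1 = keep, 0 = mask) can be sketched for a single head on plain arrays. This is an illustration only, not DJL's implementation: `AttentionSketch` is a hypothetical class, and replacing masked scores with a large negative constant before the softmax is an assumed, commonly used technique.

```java
import java.util.Arrays;

/** Illustrative single-head attention on plain arrays (not DJL code). */
final class AttentionSketch {

    /**
     * @param q queries, shape (T, H)
     * @param k keys, shape (F, H)
     * @param v values, shape (F, H)
     * @param mask shape (T, F) with 1 = keep and 0 = mask, or null for unmasked attention
     * @return attended values, shape (T, H)
     */
    static double[][] attend(double[][] q, double[][] k, double[][] v, int[][] mask) {
        int t = q.length;
        int f = k.length;
        int h = q[0].length;
        double scale = 1.0 / Math.sqrt(h);
        double[][] out = new double[t][v[0].length];
        for (int i = 0; i < t; i++) {
            double[] scores = new double[f];
            double max = Double.NEGATIVE_INFINITY;
            for (int j = 0; j < f; j++) {
                double dot = 0;
                for (int d = 0; d < h; d++) {
                    dot += q[i][d] * k[j][d];
                }
                scores[j] = dot * scale;
                if (mask != null && mask[i][j] == 0) {
                    scores[j] = -1e9; // assumption: masked positions get a large negative score
                }
                max = Math.max(max, scores[j]);
            }
            // Numerically stable softmax over the "from" positions.
            double sum = 0;
            for (int j = 0; j < f; j++) {
                scores[j] = Math.exp(scores[j] - max);
                sum += scores[j];
            }
            // Weighted sum of the value rows.
            for (int j = 0; j < f; j++) {
                double w = scores[j] / sum;
                for (int d = 0; d < out[i].length; d++) {
                    out[i][d] += w * v[j][d];
                }
            }
        }
        return out;
    }

    public static void main(String[] args) {
        double[][] q = {{1, 0}};
        double[][] k = {{1, 0}, {0, 1}};
        double[][] v = {{10, 20}, {30, 40}};
        int[][] mask = {{0, 1}}; // keep only the second "from" position
        System.out.println(Arrays.toString(attend(q, k, v, mask)[0])); // prints [30.0, 40.0]
    }
}
```

With the first position masked, all attention weight falls on the second value row, so the output equals that row.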
-
Nested Class Summary
Nested Classes
- static final class ScaledDotProductAttentionBlock.Builder: A builder for ScaledDotProductAttentionBlocks.
-
Field Summary
Fields inherited from class ai.djl.nn.AbstractBlock
children, parameters
Fields inherited from class ai.djl.nn.AbstractBaseBlock
inputNames, inputShapes, outputDataTypes, version
-
Method Summary
Modifier and Type, Method, Description:
- builder(): Creates a new Builder to build an Attention Block with.
- protected NDList forwardInternal(ParameterStore parameterStore, NDList inputs, boolean training, ai.djl.util.PairList<String, Object> params): A helper for Block.forward(ParameterStore, NDList, boolean, PairList) after initialization.
- getKeyProjection(): Pointwise Linear projection of the keys.
- Shape[] getOutputShapes(Shape[] inputShapes): Returns the expected output shapes of the block for the specified input shapes.
- getQueryProjection(): Pointwise Linear projection of the queries.
- getResultProjection(): Pointwise Linear projection of the results.
- getValueProjection(): Pointwise Linear projection of the values.
- void initializeChildBlocks(NDManager manager, DataType dataType, Shape... inputShapes): Initializes the child blocks of this block.
Methods inherited from class ai.djl.nn.AbstractBlock
addChildBlock, addChildBlock, addChildBlockSingleton, addParameter, getChildren, getDirectParameters
Methods inherited from class ai.djl.nn.AbstractBaseBlock
beforeInitialize, cast, clear, describeInput, forward, forward, forwardInternal, getInputShapes, getOutputDataTypes, getParameters, initialize, isInitialized, loadMetadata, loadParameters, prepare, readInputShapes, saveInputShapes, saveMetadata, saveParameters, setInitializer, setInitializer, setInitializer, toString
Methods inherited from class java.lang.Object
clone, equals, finalize, getClass, hashCode, notify, notifyAll, wait, wait, wait
Methods inherited from interface ai.djl.nn.Block
forward, freezeParameters, freezeParameters, getCustomMetadata, getOutputShapes
-
Method Details
-
getKeyProjection
Pointwise Linear projection of the keys.
- Returns:
- Pointwise Linear projection of the keys.
-
getQueryProjection
Pointwise Linear projection of the queries.
- Returns:
- Pointwise Linear projection of the queries.
-
getValueProjection
Pointwise Linear projection of the values.
- Returns:
- Pointwise Linear projection of the values.
-
getResultProjection
Pointwise Linear projection of the results.
- Returns:
- Pointwise Linear projection of the results.
-
getOutputShapes
Returns the expected output shapes of the block for the specified input shapes.
- Parameters:
inputShapes
- the shapes of the inputs
- Returns:
- the expected output shapes of the block
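How the four documented input forms might map to an output shape can be sketched as follows. `InputForms` is a hypothetical helper, not DJL code. It assumes the output takes the shape of the queries, that is (B, F, E) for the self attention forms and (B, T, E) when queries are passed separately; this follows from the attention formula but is not stated on this page.

```java
import java.util.Arrays;

/** Hypothetical helper (not part of DJL): mirrors the four documented input forms. */
final class InputForms {

    /** Describes how an input list of the given size is interpreted. */
    static String roles(int inputCount) {
        switch (inputCount) {
            case 1:
                return "[Values]";
            case 2:
                return "[Values, Mask]";
            case 3:
                return "[Keys, Queries, Values]";
            case 4:
                return "[Keys, Queries, Values, Mask]";
            default:
                throw new IllegalArgumentException(
                        "expected 1 to 4 inputs, got " + inputCount);
        }
    }

    /** Assumed output shape: the shape of whichever input acts as the queries. */
    static long[] outputShape(long[][] inputShapes) {
        roles(inputShapes.length); // validates the number of inputs
        // Forms one and two are self attention: the first input also serves as the queries.
        long[] query = inputShapes.length <= 2 ? inputShapes[0] : inputShapes[1];
        return query.clone();
    }

    public static void main(String[] args) {
        long[][] crossAttention = {{2, 5, 16}, {2, 3, 16}, {2, 5, 16}};
        System.out.println(roles(crossAttention.length)); // prints [Keys, Queries, Values]
        System.out.println(Arrays.toString(outputShape(crossAttention))); // prints [2, 3, 16]
    }
}
```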
-
initializeChildBlocks
Initializes the child blocks of this block. You need to override this method if your subclass has child blocks. Used to determine the correct input shapes for child blocks based on the requested input shape for this block.
- Overrides:
initializeChildBlocks
in class AbstractBaseBlock
- Parameters:
manager
- the manager to use for initialization
dataType
- the requested data type
inputShapes
- the expected input shapes for this block
-
forwardInternal
protected NDList forwardInternal(ParameterStore parameterStore, NDList inputs, boolean training, ai.djl.util.PairList<String, Object> params)
A helper for Block.forward(ParameterStore, NDList, boolean, PairList) after initialization.
- Specified by:
forwardInternal
in class AbstractBaseBlock
- Parameters:
parameterStore
- the parameter store
inputs
- the input NDList
training
- true for a training forward pass
params
- optional parameters
- Returns:
- the output of the forward pass
-
builder
Creates a new Builder to build an Attention Block with.
- Returns:
- a new Builder to build an Attention Block with.
-