Package ai.djl.nn.transformer
Class ScaledDotProductAttentionBlock
java.lang.Object
  ai.djl.nn.AbstractBaseBlock
    ai.djl.nn.AbstractBlock
      ai.djl.nn.transformer.ScaledDotProductAttentionBlock
All Implemented Interfaces:
Block
public final class ScaledDotProductAttentionBlock extends AbstractBlock
A Block implementing scaled dot-product attention as described by Vaswani et al. in "Attention Is All You Need" (2017).
Abbreviations used:
- E = embedding size
- B = batch size
- N = number of attention heads
- F = "from" sequence length (key/value sequence), the input sequence
- T = "to" sequence length (query sequence), e.g. the length of the output sequence
- S = a sequence length, either F or T
- H = Attention head size (= E / N)
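The relation H = E / N means the embedding size has to be a multiple of the head count. The Java sketch below is a hypothetical helper class `HeadSplitSketch`, not part of DJL. It shows the head-size computation and one common way of slicing an embedding vector into heads. The contiguous-slice layout is an assumption.

```java
import java.util.Arrays;

/**
 * Hypothetical sketch (not DJL code): splitting one embedding vector of size E into
 * N attention heads of size H = E / N. Assumes head n owns the contiguous slice
 * [n * H, (n + 1) * H); DJL's internal layout may differ.
 */
final class HeadSplitSketch {

    /** Returns H = E / N, rejecting head counts that are not positive or do not divide E. */
    static int headSize(int embeddingSize, int headCount) {
        if (headCount <= 0 || embeddingSize % headCount != 0) {
            throw new IllegalArgumentException(
                    "head count " + headCount + " must be positive and divide embedding size " + embeddingSize);
        }
        return embeddingSize / headCount;
    }

    /** Splits a vector of length E into N vectors of length H. */
    static float[][] splitHeads(float[] embedding, int headCount) {
        int h = headSize(embedding.length, headCount);
        float[][] heads = new float[headCount][h];
        for (int n = 0; n < headCount; n++) {
            System.arraycopy(embedding, n * h, heads[n], 0, h);
        }
        return heads;
    }

    public static void main(String[] args) {
        float[] embedding = {0, 1, 2, 3, 4, 5, 6, 7}; // E = 8
        float[][] heads = splitHeads(embedding, 2); // N = 2, so H = 4
        System.out.println(Arrays.deepToString(heads)); // [[0.0, 1.0, 2.0, 3.0], [4.0, 5.0, 6.0, 7.0]]
    }
}
```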
In many use cases F = T. In self-attention, the keys, queries, and values are all derived from the same input sequence, so F = T.
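For reference, here is the attention for a single head and a single batch element, written with the abbreviations above. Q is the T × H query matrix, and K and V are the F × H key and value matrices. This follows the formulation in Vaswani et al. It does not show the mask handling or the projections used by this block.

```latex
\operatorname{Attention}(Q, K, V) = \operatorname{softmax}\!\left(\frac{Q K^{\top}}{\sqrt{H}}\right) V,
\qquad Q \in \mathbb{R}^{T \times H},\; K, V \in \mathbb{R}^{F \times H}
```

The softmax is taken over each row of the T × F score matrix, that is, over the F key positions, so the result is T × H. In the multi-head form of the paper, the N head outputs are concatenated back to width N · H = E.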
This block can process input in four forms:
- One input: [Values] = [(B, F, E)]. The only input is used as key, query, and value (unmasked self-attention), e.g. BERT.
- Two inputs: [Values, Mask] = [(B, F, E), (B, F, F)]. The first input is used as key, query, and value (masked self-attention).
- Three inputs: [Keys, Queries, Values] = [(B, F, E), (B, T, E), (B, F, E)]. The inputs are interpreted as keys, queries, and values (unmasked attention).
- Four inputs: [Keys, Queries, Values, Mask] = [(B, F, E), (B, T, E), (B, F, E), (B, T, F)]. The inputs are interpreted as keys, queries, values, and an attention mask (fully masked attention).
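The four input forms differ only in which list entry plays which role. As an illustration, the hypothetical helper below (not a DJL API) encodes the list above as indices into the input NDList. It uses -1 for an absent mask.

```java
import java.util.Arrays;

/**
 * Hypothetical helper (not part of DJL): maps the number of inputs to
 * {keyIndex, queryIndex, valueIndex, maskIndex}, following the documented input forms.
 */
final class InputRoles {
    static final int NONE = -1;

    static int[] resolve(int inputCount) {
        switch (inputCount) {
            case 1: // [Values]: one tensor serves as key, query and value
                return new int[] {0, 0, 0, NONE};
            case 2: // [Values, Mask]: masked self-attention
                return new int[] {0, 0, 0, 1};
            case 3: // [Keys, Queries, Values]: unmasked attention
                return new int[] {0, 1, 2, NONE};
            case 4: // [Keys, Queries, Values, Mask]: fully masked attention
                return new int[] {0, 1, 2, 3};
            default:
                throw new IllegalArgumentException("expected 1 to 4 inputs, got " + inputCount);
        }
    }

    public static void main(String[] args) {
        for (int n = 1; n <= 4; n++) {
            System.out.println(n + " inputs -> " + Arrays.toString(resolve(n)));
        }
    }
}
```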
Attention masks must contain a 1 for positions to keep and a 0 for positions to mask.
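The plain-Java sketch below shows one way to honour this mask convention inside the softmax. The class MaskedSoftmax is hypothetical. It zeroes masked probabilities exactly, which is an assumed technique and not a description of DJL's code. Many implementations instead add a large negative bias to masked scores.

```java
/**
 * Hypothetical sketch (not DJL code): softmax over one row of attention scores that honours
 * a 1/0 mask (1 = keep, 0 = mask). Masked positions get probability exactly 0; adding a large
 * negative constant to masked scores before the softmax is a common alternative.
 */
final class MaskedSoftmax {

    static double[] apply(double[] scores, int[] mask) {
        // Largest kept score, subtracted before exp() so that exp() cannot overflow.
        double max = Double.NEGATIVE_INFINITY;
        for (int j = 0; j < scores.length; j++) {
            if (mask[j] != 0) {
                max = Math.max(max, scores[j]);
            }
        }
        if (max == Double.NEGATIVE_INFINITY) {
            throw new IllegalArgumentException("the mask hides every position");
        }
        double[] probs = new double[scores.length];
        double sum = 0;
        for (int j = 0; j < scores.length; j++) {
            probs[j] = mask[j] != 0 ? Math.exp(scores[j] - max) : 0.0;
            sum += probs[j];
        }
        for (int j = 0; j < probs.length; j++) {
            probs[j] /= sum; // normalise the kept positions so they sum to 1
        }
        return probs;
    }
}
```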
Nested Class Summary
- static class ScaledDotProductAttentionBlock.Builder: A builder for ScaledDotProductAttentionBlocks.
Field Summary
Fields inherited from class ai.djl.nn.AbstractBlock
children, parameters
Fields inherited from class ai.djl.nn.AbstractBaseBlock
inputNames, inputShapes, version
Method Summary
- static ScaledDotProductAttentionBlock.Builder builder(): Creates a new Builder for building an attention block.
- protected NDList forwardInternal(ParameterStore parameterStore, NDList inputs, boolean training, ai.djl.util.PairList<java.lang.String,java.lang.Object> params): A helper for Block.forward(ParameterStore, NDList, boolean, PairList) after initialization.
- Linear getKeyProjection(): Pointwise Linear projection of the keys.
- Shape[] getOutputShapes(Shape[] inputShapes): Returns the expected output shapes of the block for the specified input shapes.
- Linear getQueryProjection(): Pointwise Linear projection of the queries.
- Linear getResultProjection(): Pointwise Linear projection of the results.
- Linear getValueProjection(): Pointwise Linear projection of the values.
- void initializeChildBlocks(NDManager manager, DataType dataType, Shape... inputShapes): Initializes the child blocks of this block.
Methods inherited from class ai.djl.nn.AbstractBlock
addChildBlock, addChildBlock, addChildBlockSingleton, addParameter, getChildren, getDirectParameters
Methods inherited from class ai.djl.nn.AbstractBaseBlock
beforeInitialize, cast, clear, describeInput, forward, forward, forwardInternal, getInputShapes, getParameters, initialize, isInitialized, loadMetadata, loadParameters, prepare, readInputShapes, saveInputShapes, saveMetadata, saveParameters, setInitializer, setInitializer, setInitializer, toString
Methods inherited from class java.lang.Object
clone, equals, finalize, getClass, hashCode, notify, notifyAll, wait, wait, wait
Methods inherited from interface ai.djl.nn.Block
forward, freezeParameters
Method Detail
getKeyProjection
public Linear getKeyProjection()
Pointwise Linear projection of the keys.
Returns:
the pointwise Linear projection of the keys
getQueryProjection
public Linear getQueryProjection()
Pointwise Linear projection of the queries.
Returns:
the pointwise Linear projection of the queries
getValueProjection
public Linear getValueProjection()
Pointwise Linear projection of the values.
Returns:
the pointwise Linear projection of the values
getResultProjection
public Linear getResultProjection()
Pointwise Linear projection of the results.
Returns:
the pointwise Linear projection of the results
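"Pointwise" means the same Linear layer (one weight matrix and one bias) is applied independently at every batch and sequence position of a (B, S, E) input. The plain-array sketch below illustrates that idea. PointwiseProjection is a hypothetical stand-in, not DJL's Linear block, and it does not follow DJL's actual parameter layout.

```java
/**
 * Hypothetical stand-in for a pointwise linear projection (not DJL's Linear block):
 * y[b][s] = W * x[b][s] + bias, with the same W (outDim x inDim) and bias at every (b, s).
 */
final class PointwiseProjection {

    static float[][][] project(float[][][] x, float[][] w, float[] bias) {
        int batch = x.length;
        int seq = x[0].length;
        int outDim = w.length;
        float[][][] y = new float[batch][seq][outDim];
        for (int b = 0; b < batch; b++) {
            for (int s = 0; s < seq; s++) {
                // Each position is projected on its own; positions never mix here.
                for (int o = 0; o < outDim; o++) {
                    float acc = bias[o];
                    for (int e = 0; e < w[o].length; e++) {
                        acc += w[o][e] * x[b][s][e];
                    }
                    y[b][s][o] = acc;
                }
            }
        }
        return y;
    }
}
```

Mixing information across positions is left to the attention weights themselves. The projections only change each position's representation.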
getOutputShapes
public Shape[] getOutputShapes(Shape[] inputShapes)
Returns the expected output shapes of the block for the specified input shapes.
Parameters:
inputShapes - the shapes of the inputs
Returns:
the expected output shapes of the block
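This page does not state the output shape explicitly. Under the standard formulation, each query position yields one output vector, and the result projection maps it back to width E. The output then has shape (B, T, E), which is (B, F, E) for the self-attention forms. The hypothetical helper below, which is not DJL's getOutputShapes, expresses that assumption for the documented input forms. It also checks the four-input mask shape (B, T, F).

```java
/**
 * Hypothetical shape helper (not DJL's getOutputShapes): returns {B, T, E} for the
 * documented input forms, assuming one E-wide output vector per query position.
 */
final class AttentionShapes {

    static int[] outputShape(int[][] inputShapes) {
        int count = inputShapes.length;
        if (count < 1 || count > 4) {
            throw new IllegalArgumentException("expected 1 to 4 inputs, got " + count);
        }
        // Queries are input 0 in the one- and two-input forms and input 1 otherwise.
        int[] queries = count <= 2 ? inputShapes[0] : inputShapes[1];
        if (queries.length != 3) {
            throw new IllegalArgumentException("queries must have shape (B, T, E)");
        }
        if (count == 4) {
            // The mask must be (B, T, F), where F is the key sequence length.
            int[] keys = inputShapes[0];
            int[] mask = inputShapes[3];
            if (mask.length != 3 || mask[1] != queries[1] || mask[2] != keys[1]) {
                throw new IllegalArgumentException("mask must have shape (B, T, F)");
            }
        }
        return new int[] {queries[0], queries[1], queries[2]};
    }
}
```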
initializeChildBlocks
public void initializeChildBlocks(NDManager manager, DataType dataType, Shape... inputShapes)
Initializes the child blocks of this block. Subclasses that have child blocks must override this method; it is used to determine the correct input shapes for the child blocks based on the requested input shapes for this block.
Overrides:
initializeChildBlocks in class AbstractBaseBlock
Parameters:
manager - the manager to use for initialization
dataType - the requested data type
inputShapes - the expected input shapes for this block
forwardInternal
protected NDList forwardInternal(ParameterStore parameterStore, NDList inputs, boolean training, ai.djl.util.PairList<java.lang.String,java.lang.Object> params)
A helper for Block.forward(ParameterStore, NDList, boolean, PairList) after initialization.
Specified by:
forwardInternal in class AbstractBaseBlock
Parameters:
parameterStore - the parameter store
inputs - the input NDList
training - true for a training forward pass
params - optional parameters
Returns:
the output of the forward pass
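The sketch below makes the forward pass concrete. It computes single-head scaled dot-product attention for one batch element with plain Java arrays and applies the 1/0 mask convention described at the top of this page. AttentionSketch is hypothetical. A full block would also apply the key, query, value, and result projections and split the embedding into N heads. It may also apply dropout during training. None of these steps is shown here.

```java
/**
 * Hypothetical reference sketch (not DJL code) of single-head scaled dot-product attention
 * for one batch element. Shapes: q is T x H, k is F x H, v is F x H,
 * mask is T x F (1 = keep, 0 = mask) or null for unmasked attention.
 */
final class AttentionSketch {

    static double[][] attend(double[][] q, double[][] k, double[][] v, int[][] mask) {
        int t = q.length;
        int f = k.length;
        int h = q[0].length;
        int dv = v[0].length;
        double scale = 1.0 / Math.sqrt(h);
        double[][] out = new double[t][dv];
        for (int i = 0; i < t; i++) {
            // Scaled scores s_j = (q_i . k_j) / sqrt(H) for every kept key position j.
            double[] weights = new double[f];
            double max = Double.NEGATIVE_INFINITY;
            for (int j = 0; j < f; j++) {
                if (mask != null && mask[i][j] == 0) {
                    continue;
                }
                double dot = 0;
                for (int d = 0; d < h; d++) {
                    dot += q[i][d] * k[j][d];
                }
                weights[j] = dot * scale;
                max = Math.max(max, weights[j]);
            }
            if (max == Double.NEGATIVE_INFINITY) {
                throw new IllegalArgumentException("query " + i + " has every key position masked");
            }
            // Numerically stable softmax over the F key positions; masked positions get weight 0.
            double sum = 0;
            for (int j = 0; j < f; j++) {
                boolean keep = mask == null || mask[i][j] != 0;
                weights[j] = keep ? Math.exp(weights[j] - max) : 0.0;
                sum += weights[j];
            }
            // Output row i is the probability-weighted sum of the value rows.
            for (int j = 0; j < f; j++) {
                double p = weights[j] / sum;
                for (int d = 0; d < dv; d++) {
                    out[i][d] += p * v[j][d];
                }
            }
        }
        return out;
    }
}
```

A null mask corresponds to the one- and three-input forms. Passing a T × F mask corresponds to the two- and four-input forms.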
builder
public static ScaledDotProductAttentionBlock.Builder builder()
Creates a new Builder for building an attention block.
Returns:
a new Builder for building an attention block