Package ai.djl.modality.nlp.generate
Class BatchTensorList
java.lang.Object
ai.djl.modality.nlp.generate.BatchTensorList
BatchTensorList represents a search state. The NDArrays it holds are updated in each iteration of the autoregressive loop.
It is a struct of NDArrays whose first dimension is the batch dimension. Each NDArray also has a sequence dimension, whose position in that tensor's shape is specified by seqDimOrder. The SeqBatcher batch operations operate on these two dimensions.
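To make the two dimensions concrete, here is a minimal sketch in plain Java that uses shape arrays as stand-ins for NDArrays. The class name SeqDimSketch, its helper methods, and the example shapes are hypothetical and not part of DJL. The sketch also assumes that seqDimOrder holds one entry per tensor, in the same order as the tensors.

```java
import java.util.Arrays;

// Hypothetical illustration only: shape arrays stand in for NDArrays.
final class SeqDimSketch {

    // The first dimension of every tensor is the batch dimension.
    static long batchSize(long[] shape) {
        return shape[0];
    }

    // seqDim is the index of the sequence dimension within this tensor's shape.
    static long seqLength(long[] shape, long seqDim) {
        return shape[(int) seqDim];
    }

    public static void main(String[] args) {
        // Assumed example shapes, chosen for illustration.
        long[][] shapes = {
            {2, 7},         // pastOutputIds: [batch, seq]
            {2, 7},         // pastAttentionMask: [batch, seq]
            {2, 12, 7, 64}  // one kv cache entry: [batch, heads, seq, feature]
        };
        // Assumed: one entry per tensor above.
        long[] seqDimOrder = {1, 1, 2};

        for (int i = 0; i < shapes.length; i++) {
            System.out.println(Arrays.toString(shapes[i])
                    + " batch=" + batchSize(shapes[i])
                    + " seq=" + seqLength(shapes[i], seqDimOrder[i]));
        }
    }
}
```

In this sketch every tensor reports a batch size of 2 and a sequence length of 7, even though the sequence dimension sits at a different index in the kv cache entry.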
-
Method Summary
Modifier and Type | Method | Description
abstract BatchTensorList | fromList(NDList inputList, long[] seqDimOrder) | Constructs a new BatchTensorList instance from the serialized version of the batch tensors.
abstract NDList | getList() | Returns the serialized version of the BatchTensorList.
NDArray | getPastAttentionMask() | Returns the value of the pastAttentionMask.
NDList | getPastKeyValues() | Returns the value of the pastKeyValues.
NDArray | getPastOutputIds() | Returns the value of the pastOutputIds.
long[] | getSeqDimOrder() | Returns the sequence dimension order which specifies where the sequence dimension is in a tensor's shape.
void | setPastAttentionMask(NDArray pastAttentionMask) | Sets the attention mask.
void | setPastKeyValues(NDList pastKeyValues) | Sets the kv cache.
void | setPastOutputIds(NDArray pastOutputIds) | Sets the past output token ids.
void | setSeqDimOrder(long[] seqDimOrder) | Sets the sequence dimension order which specifies where the sequence dimension is in a tensor's shape.
-
Method Details
-
fromList
public abstract BatchTensorList fromList(NDList inputList, long[] seqDimOrder)
Constructs a new BatchTensorList instance from the serialized version of the batch tensors. The pastOutputIds has to be the first in the output list.
Parameters:
inputList - the serialized version of the batch tensors
seqDimOrder - the sequence dimension order that specifies where the sequence dimension is in a tensor's shape
Returns:
BatchTensorList
-
getList
public abstract NDList getList()
Returns the serialized version of the BatchTensorList. The pastOutputIds has to be the first in the output list.
Returns:
the NDList that contains the serialized BatchTensorList
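The round trip between getList and fromList can be sketched in plain Java. This is an illustration under stated assumptions, not DJL's implementation: RoundTripSketch is a hypothetical stand-in for a concrete subclass, long[] shape arrays replace NDArrays, and a List replaces NDList. Only the rule that pastOutputIds comes first is taken from the documentation; the order of the remaining entries is assumed.

```java
import java.util.ArrayList;
import java.util.List;

// Hypothetical stand-in for a concrete BatchTensorList subclass.
final class RoundTripSketch {
    final long[] pastOutputIds;
    final long[] pastAttentionMask;
    final List<long[]> pastKeyValues;
    final long[] seqDimOrder;

    RoundTripSketch(long[] ids, long[] mask, List<long[]> kv, long[] seqDimOrder) {
        this.pastOutputIds = ids;
        this.pastAttentionMask = mask;
        this.pastKeyValues = kv;
        this.seqDimOrder = seqDimOrder;
    }

    // Mirrors getList(): pastOutputIds goes first; the rest of the order is assumed.
    List<long[]> getList() {
        List<long[]> list = new ArrayList<>();
        list.add(pastOutputIds);
        list.add(pastAttentionMask);
        list.addAll(pastKeyValues);
        return list;
    }

    // Mirrors fromList(): rebuilds a new state from the flat list.
    RoundTripSketch fromList(List<long[]> list, long[] seqDimOrder) {
        List<long[]> kv = new ArrayList<>(list.subList(2, list.size()));
        return new RoundTripSketch(list.get(0), list.get(1), kv, seqDimOrder);
    }

    public static void main(String[] args) {
        List<long[]> kv = new ArrayList<>();
        kv.add(new long[] {2, 12, 7, 64}); // assumed key shape
        kv.add(new long[] {2, 12, 7, 64}); // assumed value shape
        RoundTripSketch state = new RoundTripSketch(
                new long[] {2, 7}, new long[] {2, 7}, kv, new long[] {1, 1, 2, 2});

        List<long[]> flat = state.getList();
        RoundTripSketch rebuilt = state.fromList(flat, state.seqDimOrder);
        System.out.println("flat size = " + flat.size());
        System.out.println("kv entries after rebuild = " + rebuilt.pastKeyValues.size());
    }
}
```

Keeping pastOutputIds at a fixed position lets code that only sees the flat list find the token ids without knowing which subclass produced it.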
-
getSeqDimOrder
public long[] getSeqDimOrder()
Returns the sequence dimension order which specifies where the sequence dimension is in a tensor's shape.
Returns:
the sequence dimension order which specifies where the sequence dimension is in a tensor's shape
-
getPastOutputIds
public NDArray getPastOutputIds()
Returns the value of the pastOutputIds.
Returns:
the value of pastOutputIds
-
setPastOutputIds
public void setPastOutputIds(NDArray pastOutputIds)
Sets the past output token ids.
Parameters:
pastOutputIds - the past output token ids
-
getPastAttentionMask
public NDArray getPastAttentionMask()
Returns the value of the pastAttentionMask.
Returns:
the value of pastAttentionMask
-
setPastAttentionMask
public void setPastAttentionMask(NDArray pastAttentionMask)
Sets the attention mask.
Parameters:
pastAttentionMask - the attention mask
-
getPastKeyValues
public NDList getPastKeyValues()
Returns the value of the pastKeyValues.
Returns:
the value of pastKeyValues
-
setPastKeyValues
public void setPastKeyValues(NDList pastKeyValues)
Sets the kv cache.
Parameters:
pastKeyValues - the kv cache
-
setSeqDimOrder
public void setSeqDimOrder(long[] seqDimOrder)
Sets the sequence dimension order which specifies where the sequence dimension is in a tensor's shape.
Parameters:
seqDimOrder - the sequence dimension order which specifies where the sequence dimension is in a tensor's shape
-
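Taken together, the setters support the per-iteration update described in the class overview. The following is a rough sketch of one such update, assuming that each iteration appends one token per batch row and extends the attention mask by one position. LoopStepSketch, appendColumn, and step are hypothetical names; long[][] arrays of shape [batch, seq] stand in for NDArrays, and the kv cache is left out for brevity.

```java
import java.util.Arrays;

// Hypothetical stand-in for the search state; long[][] replaces NDArray ([batch, seq]).
final class LoopStepSketch {
    long[][] pastOutputIds;
    long[][] pastAttentionMask;

    LoopStepSketch(long[][] ids, long[][] mask) {
        this.pastOutputIds = ids;
        this.pastAttentionMask = mask;
    }

    void setPastOutputIds(long[][] ids) {
        this.pastOutputIds = ids;
    }

    void setPastAttentionMask(long[][] mask) {
        this.pastAttentionMask = mask;
    }

    // Appends one value per batch row along the sequence dimension (dimension 1).
    static long[][] appendColumn(long[][] rows, long[] column) {
        long[][] out = new long[rows.length][];
        for (int b = 0; b < rows.length; b++) {
            out[b] = Arrays.copyOf(rows[b], rows[b].length + 1);
            out[b][rows[b].length] = column[b];
        }
        return out;
    }

    // One assumed loop iteration: store the new tokens and mark them as attended.
    void step(long[] nextTokens) {
        long[] ones = new long[nextTokens.length];
        Arrays.fill(ones, 1L);
        setPastOutputIds(appendColumn(pastOutputIds, nextTokens));
        setPastAttentionMask(appendColumn(pastAttentionMask, ones));
    }

    public static void main(String[] args) {
        LoopStepSketch state = new LoopStepSketch(
                new long[][] {{5, 6}, {7, 8}},
                new long[][] {{1, 1}, {1, 1}});
        state.step(new long[] {9, 10});
        System.out.println(Arrays.deepToString(state.pastOutputIds));     // [[5, 6, 9], [7, 8, 10]]
        System.out.println(Arrays.deepToString(state.pastAttentionMask)); // [[1, 1, 1], [1, 1, 1]]
    }
}
```

The batch dimension stays fixed at 2 while the sequence dimension grows from 2 to 3.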