Package ai.djl.modality.nlp.generate
Class BatchTensorList
java.lang.Object
    ai.djl.modality.nlp.generate.BatchTensorList
public abstract class BatchTensorList extends java.lang.Object
BatchTensorList represents a search state, and the NDArrays inside it are updated in each iteration of the autoregressive loop. It is a struct of NDArrays whose first dimension is the batch dimension and which also contain a sequence dimension (its position in each tensor's shape is specified by seqDimOrder). The SeqBatcher batch operations operate on these two dimensions.
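The two dimensions that SeqBatcher operates on can be illustrated with plain shape arithmetic. The following is a minimal sketch, not DJL code: the class MergeShapeSketch and its mergedShape method are hypothetical, shapes are plain long[] arrays rather than NDArrays, and it assumes that merging two batches concatenates along the batch dimension (dimension 0) and pads the shorter sequence to the longer one along the sequence dimension (the index that seqDimOrder would hold for that tensor).

```java
import java.util.Arrays;

/** Hypothetical helper: shape arithmetic for merging two batches of one kind of tensor. */
final class MergeShapeSketch {

    private MergeShapeSketch() {}

    /**
     * Returns the shape after merging tensors a and b. Assumptions of this sketch:
     * dim 0 is the batch (sizes add up), dim seqDim is the sequence (padded to the max),
     * and every other dimension must already match.
     */
    static long[] mergedShape(long[] a, long[] b, int seqDim) {
        if (a.length != b.length) {
            throw new IllegalArgumentException("rank mismatch");
        }
        if (seqDim <= 0 || seqDim >= a.length) {
            throw new IllegalArgumentException("seqDim must be a non-batch dimension");
        }
        long[] out = new long[a.length];
        for (int d = 0; d < a.length; d++) {
            if (d == 0) {
                out[d] = a[d] + b[d]; // batch dimension: concatenate
            } else if (d == seqDim) {
                out[d] = Math.max(a[d], b[d]); // sequence dimension: pad to the longer one
            } else if (a[d] == b[d]) {
                out[d] = a[d]; // other dimensions carry over unchanged
            } else {
                throw new IllegalArgumentException("dimension " + d + " differs");
            }
        }
        return out;
    }

    public static void main(String[] args) {
        // Example key tensors laid out as [batch, heads, seq, feature]; the sequence axis is 2.
        long[] running = {3, 12, 10, 64};
        long[] incoming = {2, 12, 4, 64};
        System.out.println(Arrays.toString(mergedShape(running, incoming, 2)));
        // prints [5, 12, 10, 64]
    }
}
```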
Method Summary
Modifier and Type | Method | Description
abstract BatchTensorList | fromList(NDList inputList, long[] seqDimOrder) | Constructs a new BatchTensorList instance from the serialized version of the batch tensors.
abstract NDList | getList() | Returns the serialized version of the BatchTensorList.
NDArray | getPastAttentionMask() | Returns the value of pastAttentionMask.
NDList | getPastKeyValues() | Returns the value of pastKeyValues.
NDArray | getPastOutputIds() | Returns the value of pastOutputIds.
long[] | getSeqDimOrder() | Returns the sequence dimension order, which specifies where the sequence dimension is in a tensor's shape.
void | setPastAttentionMask(NDArray pastAttentionMask) | Sets the attention mask.
void | setPastKeyValues(NDList pastKeyValues) | Sets the kv cache.
void | setPastOutputIds(NDArray pastOutputIds) | Sets the past output token ids.
void | setSeqDimOrder(long[] seqDimOrder) | Sets the sequence dimension order, which specifies where the sequence dimension is in a tensor's shape.
Method Detail
fromList
public abstract BatchTensorList fromList(NDList inputList, long[] seqDimOrder)
Constructs a new BatchTensorList instance from the serialized version of the batch tensors. The pastOutputIds must be the first element of inputList.
Parameters:
inputList - the serialized version of the batch tensors
seqDimOrder - the sequence dimension order that specifies where the sequence dimension is in a tensor's shape
Returns:
a new BatchTensorList
getList
public abstract NDList getList()
Returns the serialized version of the BatchTensorList. The pastOutputIds must be the first element of the returned list.
Returns:
the NDList that contains the serialized BatchTensorList
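The contract shared by fromList and getList is that the first element of the serialized list is pastOutputIds. The sketch below mirrors that contract with a hypothetical ToyBatchTensorList that stores generic placeholders instead of NDArrays, so it runs without DJL. The layout after the first element ([pastOutputIds, pastAttentionMask, key/value tensors...]) is an assumption of this sketch; concrete subclasses define their own.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

/**
 * Hypothetical stand-in for a BatchTensorList subclass. T stands in for NDArray so the
 * sketch runs without DJL. Assumed layout: [pastOutputIds, pastAttentionMask, kv...].
 */
final class ToyBatchTensorList<T> {
    private final T pastOutputIds;
    private final T pastAttentionMask;
    private final List<T> pastKeyValues;
    private final long[] seqDimOrder;

    ToyBatchTensorList(T pastOutputIds, T pastAttentionMask, List<T> pastKeyValues, long[] seqDimOrder) {
        this.pastOutputIds = pastOutputIds;
        this.pastAttentionMask = pastAttentionMask;
        this.pastKeyValues = new ArrayList<>(pastKeyValues); // copy, so subList views are not kept
        this.seqDimOrder = seqDimOrder.clone();
    }

    /** Mirrors getList(): pastOutputIds is always written first. */
    List<T> getList() {
        List<T> out = new ArrayList<>();
        out.add(pastOutputIds);
        out.add(pastAttentionMask);
        out.addAll(pastKeyValues);
        return out;
    }

    /** Mirrors fromList(): reads the same layout back, so element 0 must be pastOutputIds. */
    ToyBatchTensorList<T> fromList(List<T> inputList, long[] seqDimOrder) {
        if (inputList.size() < 2) {
            throw new IllegalArgumentException("need at least pastOutputIds and pastAttentionMask");
        }
        return new ToyBatchTensorList<>(
                inputList.get(0), inputList.get(1), inputList.subList(2, inputList.size()), seqDimOrder);
    }

    T getPastOutputIds() {
        return pastOutputIds;
    }

    long[] getSeqDimOrder() {
        return seqDimOrder.clone();
    }

    public static void main(String[] args) {
        ToyBatchTensorList<String> state = new ToyBatchTensorList<>(
                "ids", "mask", Arrays.asList("k0", "v0"), new long[] {1, 1, 2, 2});
        List<String> flat = state.getList();
        System.out.println(flat); // prints [ids, mask, k0, v0]
        ToyBatchTensorList<String> restored = state.fromList(flat, state.getSeqDimOrder());
        System.out.println(restored.getList().equals(flat)); // prints true
    }
}
```

Putting pastOutputIds at a fixed index means code that only needs the generated ids can read element 0 without knowing the rest of a subclass's layout.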
getSeqDimOrder
public long[] getSeqDimOrder()
Returns the sequence dimension order, which specifies where the sequence dimension is in a tensor's shape.
Returns:
the sequence dimension order
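As an illustration of how the array is read, assume it holds one entry per tensor in the serialized list, so that entry i is the axis of tensor i that carries the sequence. The hypothetical helper below (plain Java, shapes as long[], no DJL types) looks up each tensor's sequence length that way; the shapes in main are made-up example values.

```java
import java.util.Arrays;

/** Hypothetical helper: looks up each tensor's sequence length from a seqDimOrder-style array. */
final class SeqDimSketch {
    private SeqDimSketch() {}

    /**
     * shapes[i] is the shape of the i-th tensor in the serialized list and seqDimOrder[i]
     * is the axis of that tensor holding the sequence (assumed: one entry per tensor).
     */
    static long[] sequenceLengths(long[][] shapes, long[] seqDimOrder) {
        if (shapes.length != seqDimOrder.length) {
            throw new IllegalArgumentException("need one seqDimOrder entry per tensor");
        }
        long[] lengths = new long[shapes.length];
        for (int i = 0; i < shapes.length; i++) {
            lengths[i] = shapes[i][Math.toIntExact(seqDimOrder[i])];
        }
        return lengths;
    }

    public static void main(String[] args) {
        long[][] shapes = {
            {2, 7},         // example pastOutputIds: [batch, seq]
            {2, 7},         // example pastAttentionMask: [batch, seq]
            {2, 12, 7, 64}, // example key of one layer: [batch, heads, seq, headDim]
            {2, 12, 7, 64}  // example value of the same layer
        };
        long[] seqDimOrder = {1, 1, 2, 2};
        System.out.println(Arrays.toString(sequenceLengths(shapes, seqDimOrder))); // prints [7, 7, 7, 7]
    }
}
```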
getPastOutputIds
public NDArray getPastOutputIds()
Returns the value of pastOutputIds.
Returns:
the value of pastOutputIds
setPastOutputIds
public void setPastOutputIds(NDArray pastOutputIds)
Sets the past output token ids.
Parameters:
pastOutputIds - the past output token ids
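Because pastOutputIds grows by one position along its sequence dimension in each iteration of the autoregressive loop, a caller builds the longer array and passes it back through setPastOutputIds. The sketch below shows only that growth step, on plain long[][] arrays assumed to be laid out as [batch, seq]; appendTokens is a hypothetical helper, and with NDArrays the same step would be a concatenation along the sequence axis.

```java
import java.util.Arrays;

/** Hypothetical sketch of one autoregressive step on a [batch, seq] array of token ids. */
final class AppendTokensSketch {
    private AppendTokensSketch() {}

    /** Returns a new array with nextTokens[b] appended to row b; the input is not modified. */
    static long[][] appendTokens(long[][] pastOutputIds, long[] nextTokens) {
        if (pastOutputIds.length != nextTokens.length) {
            throw new IllegalArgumentException("batch size mismatch");
        }
        long[][] out = new long[pastOutputIds.length][];
        for (int b = 0; b < pastOutputIds.length; b++) {
            // Copy the row with one extra slot, then write the new token at the end.
            out[b] = Arrays.copyOf(pastOutputIds[b], pastOutputIds[b].length + 1);
            out[b][pastOutputIds[b].length] = nextTokens[b];
        }
        return out;
    }

    public static void main(String[] args) {
        long[][] ids = {{5, 17}, {5, 42}};
        long[] next = {8, 3};
        System.out.println(Arrays.deepToString(appendTokens(ids, next)));
        // prints [[5, 17, 8], [5, 42, 3]]
    }
}
```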
getPastAttentionMask
public NDArray getPastAttentionMask()
Returns the value of pastAttentionMask.
Returns:
the value of pastAttentionMask
setPastAttentionMask
public void setPastAttentionMask(NDArray pastAttentionMask)
Sets the attention mask.
Parameters:
pastAttentionMask - the attention mask
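A common convention, assumed here rather than stated by this page, is that the attention mask has the same [batch, seq] layout as pastOutputIds, with 1 for real tokens and 0 for padding. The hypothetical helper below builds such a mask for a batch of prompts of different lengths, assuming left padding so that every row's latest token sits in the last column.

```java
import java.util.Arrays;

/** Hypothetical sketch: a [batch, seq] attention mask for a left-padded batch of prompts. */
final class AttentionMaskSketch {
    private AttentionMaskSketch() {}

    /** Builds a mask where 1 marks a real token and the leading 0s mark left padding. */
    static long[][] leftPaddedMask(int[] promptLengths) {
        int maxLen = 0;
        for (int len : promptLengths) {
            maxLen = Math.max(maxLen, len);
        }
        long[][] mask = new long[promptLengths.length][maxLen];
        for (int b = 0; b < promptLengths.length; b++) {
            // Real tokens occupy the last promptLengths[b] positions of the row.
            for (int s = maxLen - promptLengths[b]; s < maxLen; s++) {
                mask[b][s] = 1;
            }
        }
        return mask;
    }

    public static void main(String[] args) {
        // Two prompts of length 3 and 1; the second row is padded on the left.
        System.out.println(Arrays.deepToString(leftPaddedMask(new int[] {3, 1})));
        // prints [[1, 1, 1], [0, 0, 1]]
    }
}
```

In this sketch's model, each generated token would then add one more column of 1s, since generated tokens are never padding.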
getPastKeyValues
public NDList getPastKeyValues()
Returns the value of pastKeyValues.
Returns:
the value of pastKeyValues
setPastKeyValues
public void setPastKeyValues(NDList pastKeyValues)
Sets the key-value (kv) cache.
Parameters:
pastKeyValues - the kv cache
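Because the kv cache grows with every generated token, a rough size estimate helps when choosing batch size and sequence length. The sketch below assumes the NDList holds one key and one value tensor per layer, each shaped [batch, heads, seq, headDim]; both that layout and the KvCacheSketch class are assumptions of this illustration, not guarantees about a particular model. Under that layout the cache holds 2 × layers × batch × heads × seq × headDim elements.

```java
/** Hypothetical back-of-the-envelope size estimate for a kv cache stored as an NDList. */
final class KvCacheSketch {
    private KvCacheSketch() {}

    /** Number of tensors in the list under the assumed layout: one key plus one value per layer. */
    static int tensorCount(int numLayers) {
        return 2 * numLayers;
    }

    /** Total bytes across all key/value tensors, each assumed to be [batch, heads, seq, headDim]. */
    static long cacheBytes(int numLayers, long batch, long heads, long seq, long headDim, int bytesPerElement) {
        long perTensor = Math.multiplyExact(Math.multiplyExact(batch, heads), Math.multiplyExact(seq, headDim));
        long elements = Math.multiplyExact(perTensor, (long) tensorCount(numLayers));
        return Math.multiplyExact(elements, (long) bytesPerElement);
    }

    public static void main(String[] args) {
        // 12 layers, 12 heads of size 64 (the shape of GPT-2 small), float32, batch 4, 1024 tokens.
        System.out.println(cacheBytes(12, 4, 12, 1024, 64, 4)); // prints 301989888 (288 MiB)
    }
}
```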
setSeqDimOrder
public void setSeqDimOrder(long[] seqDimOrder)
Sets the sequence dimension order, which specifies where the sequence dimension is in a tensor's shape.
Parameters:
seqDimOrder - the sequence dimension order, which specifies where the sequence dimension is in a tensor's shape
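A caller may want to check a seqDimOrder array before passing it to setSeqDimOrder. The hypothetical validator below assumes one entry per tensor in the serialized list; since the first dimension of every tensor is the batch dimension, it rejects an entry of 0 as well as any index at or beyond the tensor's rank.

```java
/** Hypothetical pre-check for a seqDimOrder array; ranks[i] is the rank of tensor i. */
final class SeqDimOrderCheck {
    private SeqDimOrderCheck() {}

    static void validate(long[] seqDimOrder, int[] ranks) {
        if (seqDimOrder.length != ranks.length) {
            throw new IllegalArgumentException(
                    "expected " + ranks.length + " entries, got " + seqDimOrder.length);
        }
        for (int i = 0; i < ranks.length; i++) {
            long d = seqDimOrder[i];
            // Dimension 0 is the batch dimension, so the sequence dimension must come after it.
            if (d < 1 || d >= ranks[i]) {
                throw new IllegalArgumentException(
                        "entry " + i + " (" + d + ") is not a valid sequence dimension for rank " + ranks[i]);
            }
        }
    }

    static boolean isValid(long[] seqDimOrder, int[] ranks) {
        try {
            validate(seqDimOrder, ranks);
            return true;
        } catch (IllegalArgumentException e) {
            return false;
        }
    }

    public static void main(String[] args) {
        // ids and mask are rank 2, one layer's key and value are rank 4.
        System.out.println(isValid(new long[] {1, 1, 2, 2}, new int[] {2, 2, 4, 4})); // prints true
        System.out.println(isValid(new long[] {0, 1}, new int[] {2, 2})); // prints false
    }
}
```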