Class BatchTensorList

java.lang.Object
ai.djl.modality.nlp.generate.BatchTensorList

public abstract class BatchTensorList extends Object
BatchTensorList represents the state of a search; the NDArrays it holds are updated in every iteration of the autoregressive generation loop.

It is a struct of NDArrays whose first dimension is the batch dimension. Each NDArray also has a sequence dimension, whose position in the tensor's shape is specified by seqDimOrder. The SeqBatcher batch operations operate on these two dimensions.
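To make this layout concrete, the following plain-Java sketch models each tensor only by its shape (a `long[]`), reads the batch size from dimension 0, and reads the sequence length from the dimension named in `seqDimOrder`. It assumes that `seqDimOrder` holds one entry per tensor, in the same order as the serialized list. The class `ShapeLayoutSketch` and its methods are hypothetical and are not part of DJL.

```java
import java.util.List;

/**
 * Hypothetical sketch: batch is always dimension 0, and seqDimOrder[i] is assumed to give
 * the sequence dimension of the i-th tensor.
 */
final class ShapeLayoutSketch {

    /** Reads the batch size from dimension 0 and checks that every tensor agrees on it. */
    static long batchSize(List<long[]> shapes) {
        long batch = shapes.get(0)[0];
        for (long[] shape : shapes) {
            if (shape[0] != batch) {
                throw new IllegalArgumentException("all tensors must share the batch dimension");
            }
        }
        return batch;
    }

    /** Looks up the sequence length of tensor i through seqDimOrder. */
    static long seqLength(List<long[]> shapes, long[] seqDimOrder, int i) {
        long[] shape = shapes.get(i);
        return shape[(int) seqDimOrder[i]];
    }

    public static void main(String[] args) {
        // Assumed example: pastOutputIds (batch, seq) and one KV tensor (batch, heads, seq, headDim).
        List<long[]> shapes = List.of(new long[] {2, 5}, new long[] {2, 12, 5, 64});
        long[] seqDimOrder = {1, 2};
        System.out.println(batchSize(shapes)); // 2
        System.out.println(seqLength(shapes, seqDimOrder, 0)); // 5
        System.out.println(seqLength(shapes, seqDimOrder, 1)); // 5
    }
}
```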

  • Method Details

    • fromList

      public abstract BatchTensorList fromList(NDList inputList, long[] seqDimOrder)
      Constructs a new BatchTensorList instance from the serialized version of the batch tensors.

      The pastOutputIds must be the first element of inputList.

      Parameters:
      inputList - the serialized version of the batch tensors
      seqDimOrder - the sequence dimension order that specifies where the sequence dimension is in a tensor's shape
      Returns:
      the new BatchTensorList reconstructed from inputList
    • getList

      public abstract NDList getList()
      Returns the serialized version of the BatchTensorList. The pastOutputIds must be the first element of the returned list.
      Returns:
      the NDList that contains the serialized BatchTensorList
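      fromList and getList are meant to be inverses: fromList reads a flat list whose first element is pastOutputIds, and getList writes one back in the same order. The sketch below is a hypothetical, DJL-free stand-in (`SerializedStateSketch`, with tensors modeled by a generic type) that shows such a round trip. Only the rule that pastOutputIds comes first is taken from this documentation; placing the attention mask at index 1 and the KV cache after it is an assumption, since each concrete subclass defines its own layout.

```java
import java.util.ArrayList;
import java.util.List;

/**
 * Hypothetical stand-in for a concrete BatchTensorList subclass. Only "pastOutputIds is
 * element 0" comes from the docs; mask at index 1 and KV cache afterwards are assumptions.
 */
final class SerializedStateSketch<T> {
    private final T pastOutputIds;
    private final T pastAttentionMask;
    private final List<T> pastKeyValues;
    private final long[] seqDimOrder;

    private SerializedStateSketch(
            T pastOutputIds, T pastAttentionMask, List<T> pastKeyValues, long[] seqDimOrder) {
        this.pastOutputIds = pastOutputIds;
        this.pastAttentionMask = pastAttentionMask;
        this.pastKeyValues = pastKeyValues;
        this.seqDimOrder = seqDimOrder;
    }

    /** Rebuilds the state from a flat list; index 0 must hold pastOutputIds. */
    static <E> SerializedStateSketch<E> fromList(List<E> inputList, long[] seqDimOrder) {
        if (inputList.size() < 2) {
            throw new IllegalArgumentException("expected pastOutputIds and an attention mask");
        }
        List<E> kv = new ArrayList<>(inputList.subList(2, inputList.size()));
        return new SerializedStateSketch<>(
                inputList.get(0), inputList.get(1), kv, seqDimOrder.clone());
    }

    /** Serializes the state, again with pastOutputIds first. */
    List<T> getList() {
        List<T> out = new ArrayList<>();
        out.add(pastOutputIds);
        out.add(pastAttentionMask);
        out.addAll(pastKeyValues);
        return out;
    }

    long[] getSeqDimOrder() {
        return seqDimOrder.clone();
    }

    public static void main(String[] args) {
        List<String> serialized = List.of("pastOutputIds", "mask", "key0", "value0");
        SerializedStateSketch<String> state = fromList(serialized, new long[] {1, 1, 2, 2});
        System.out.println(state.getList()); // [pastOutputIds, mask, key0, value0]
    }
}
```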
    • getSeqDimOrder

      public long[] getSeqDimOrder()
      Returns the sequence dimension order which specifies where the sequence dimension is in a tensor's shape.
      Returns:
      the sequence dimension order which specifies where the sequence dimension is in a tensor's shape
    • getPastOutputIds

      public NDArray getPastOutputIds()
      Returns the past output token ids (pastOutputIds).
      Returns:
      the past output token ids
    • setPastOutputIds

      public void setPastOutputIds(NDArray pastOutputIds)
      Sets the past output token ids.
      Parameters:
      pastOutputIds - the past output token ids
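      Because the state is updated on every iteration of the autoregressive loop, a search step would plausibly replace pastOutputIds with the previous ids extended by one token along the sequence dimension. The sketch below shows that update for greedy (argmax) selection, using plain Java arrays in place of NDArray. `GreedyStepSketch` and `appendGreedyTokens` are hypothetical names, and greedy selection is only one possible rule.

```java
import java.util.Arrays;

/** Hypothetical greedy step: appends the argmax token of each batch row to pastOutputIds. */
final class GreedyStepSketch {

    /**
     * @param pastOutputIds token ids, shape (batch, seq)
     * @param logits next-token scores, shape (batch, vocab)
     * @return new ids of shape (batch, seq + 1)
     */
    static long[][] appendGreedyTokens(long[][] pastOutputIds, float[][] logits) {
        if (pastOutputIds.length != logits.length) {
            throw new IllegalArgumentException("batch sizes differ");
        }
        long[][] next = new long[pastOutputIds.length][];
        for (int b = 0; b < pastOutputIds.length; b++) {
            // Pick the highest-scoring token id for this batch row.
            int best = 0;
            for (int v = 1; v < logits[b].length; v++) {
                if (logits[b][v] > logits[b][best]) {
                    best = v;
                }
            }
            // Copy the old row and place the chosen token in the new last position.
            next[b] = Arrays.copyOf(pastOutputIds[b], pastOutputIds[b].length + 1);
            next[b][pastOutputIds[b].length] = best;
        }
        return next;
    }

    public static void main(String[] args) {
        long[][] ids = {{101, 7}, {101, 9}};
        float[][] logits = {{0.1f, 2.0f, 0.3f}, {1.5f, 0.2f, 0.9f}};
        System.out.println(Arrays.deepToString(appendGreedyTokens(ids, logits)));
        // [[101, 7, 1], [101, 9, 0]]
    }
}
```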
    • getPastAttentionMask

      public NDArray getPastAttentionMask()
      Returns the attention mask (pastAttentionMask).
      Returns:
      the attention mask
    • setPastAttentionMask

      public void setPastAttentionMask(NDArray pastAttentionMask)
      Sets the attention mask.
      Parameters:
      pastAttentionMask - the attention mask
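      The attention mask tells the model which positions to attend to. The sketch below assumes the common convention of 1 for real tokens and 0 for padding, and assumes that prompts of different lengths are left-padded into one batch; it builds such a mask and then extends it by one attended position per generated token. `AttentionMaskSketch` is hypothetical and uses plain arrays in place of NDArray.

```java
import java.util.Arrays;

/** Hypothetical attention-mask helpers; 1 = attend, 0 = padding (assumed convention). */
final class AttentionMaskSketch {

    /** Builds a left-padded mask of shape (batch, maxLen) from per-prompt lengths. */
    static long[][] leftPaddedMask(int[] promptLengths) {
        int maxLen = Arrays.stream(promptLengths).max().orElse(0);
        long[][] mask = new long[promptLengths.length][maxLen];
        for (int b = 0; b < promptLengths.length; b++) {
            // Padding occupies the first (maxLen - length) positions; real tokens follow.
            Arrays.fill(mask[b], maxLen - promptLengths[b], maxLen, 1L);
        }
        return mask;
    }

    /** Extends every row by one attended position after a new token is generated. */
    static long[][] extendByOne(long[][] pastAttentionMask) {
        long[][] next = new long[pastAttentionMask.length][];
        for (int b = 0; b < pastAttentionMask.length; b++) {
            next[b] = Arrays.copyOf(pastAttentionMask[b], pastAttentionMask[b].length + 1);
            next[b][pastAttentionMask[b].length] = 1L;
        }
        return next;
    }

    public static void main(String[] args) {
        long[][] mask = leftPaddedMask(new int[] {3, 1});
        System.out.println(Arrays.deepToString(mask)); // [[1, 1, 1], [0, 0, 1]]
        System.out.println(Arrays.deepToString(extendByOne(mask))); // [[1, 1, 1, 1], [0, 0, 1, 1]]
    }
}
```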
    • getPastKeyValues

      public NDList getPastKeyValues()
      Returns the key-value (KV) cache (pastKeyValues).
      Returns:
      the KV cache
    • setPastKeyValues

      public void setPastKeyValues(NDList pastKeyValues)
      Sets the key-value (KV) cache.
      Parameters:
      pastKeyValues - the KV cache
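      The KV cache stores the attention keys and values of the tokens processed so far, so each of its tensors grows by one position along its sequence dimension on every decoding step. This shape-only sketch applies that growth using a per-tensor array of sequence-dimension indices, in the spirit of seqDimOrder. The (batch, heads, seq, headDim) layout in the example is an assumption (it matches GPT-2-style caches), and `KvCacheShapeSketch` is a hypothetical name.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

/**
 * Hypothetical shape-only sketch: after one decoding step, every KV-cache tensor gains one
 * position along its sequence dimension, whose index is taken from kvSeqDims.
 */
final class KvCacheShapeSketch {

    static List<long[]> growByOneStep(List<long[]> kvShapes, long[] kvSeqDims) {
        if (kvShapes.size() != kvSeqDims.length) {
            throw new IllegalArgumentException("one sequence dimension per KV tensor expected");
        }
        List<long[]> grown = new ArrayList<>();
        for (int i = 0; i < kvShapes.size(); i++) {
            long[] shape = kvShapes.get(i).clone();
            shape[(int) kvSeqDims[i]] += 1; // only the sequence axis changes
            grown.add(shape);
        }
        return grown;
    }

    public static void main(String[] args) {
        // Assumed GPT-2-style layout for one layer: key and value of shape (batch, heads, seq, headDim).
        List<long[]> kv = List.of(new long[] {2, 12, 5, 64}, new long[] {2, 12, 5, 64});
        for (long[] s : growByOneStep(kv, new long[] {2, 2})) {
            System.out.println(Arrays.toString(s)); // [2, 12, 6, 64]
        }
    }
}
```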
    • setSeqDimOrder

      public void setSeqDimOrder(long[] seqDimOrder)
      Sets the sequence dimension order which specifies where the sequence dimension is in a tensor's shape.
      Parameters:
      seqDimOrder - the sequence dimension order which specifies where the sequence dimension is in a tensor's shape