Class BatchTensorList


  • public abstract class BatchTensorList
    extends java.lang.Object
    BatchTensorList represents a search state whose NDArrays are updated in each iteration of the autoregressive loop.

    It is a struct of NDArrays whose first dimension is the batch dimension. Each NDArray also has a sequence dimension, whose position in the tensor's shape is specified by seqDimOrder. The SeqBatcher batch operations operate on these two dimensions.
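
To make the role of seqDimOrder concrete, here is a minimal plain-Java sketch. It does not use the library: long[] shapes stand in for NDArray shapes, and the class name SeqDimOrderSketch, its helper methods, the example shapes, and the assumption that seqDimOrder holds one entry per tensor are all hypothetical, chosen only for illustration.

```java
import java.util.Arrays;

/** Plain-Java illustration; long[] shapes stand in for NDArray shapes. */
public final class SeqDimOrderSketch {

    /** Returns the length of the sequence dimension of a tensor with the given shape. */
    public static long seqLength(long[] shape, long seqDim) {
        return shape[(int) seqDim];
    }

    /** Returns a copy of shape with the sequence dimension grown by extra steps. */
    public static long[] extendSeq(long[] shape, long seqDim, long extra) {
        long[] out = shape.clone();
        out[(int) seqDim] += extra;
        return out;
    }

    public static void main(String[] args) {
        // Hypothetical shapes with batch = 2 and sequence length = 5.
        long[][] shapes = {
            {2, 5},         // pastOutputIds: [batch, seq]
            {2, 5},         // pastAttentionMask: [batch, seq]
            {2, 12, 5, 64}  // one key/value tensor: [batch, heads, seq, headDim]
        };
        // Assumed layout: one entry per tensor, giving the index of its sequence dimension.
        long[] seqDimOrder = {1, 1, 2};

        // One autoregressive step grows every tensor by one along its sequence dimension.
        for (int i = 0; i < shapes.length; i++) {
            long[] next = extendSeq(shapes[i], seqDimOrder[i], 1);
            System.out.println(Arrays.toString(shapes[i]) + " -> " + Arrays.toString(next));
        }
    }
}
```

The batch dimension stays at index 0 in every shape, while the sequence dimension is looked up per tensor. That lookup is why the class carries seqDimOrder alongside the tensors.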

    • Method Detail

      • fromList

        public abstract BatchTensorList fromList(NDList inputList,
                                                 long[] seqDimOrder)
        Constructs a new BatchTensorList instance from the serialized version of the batch tensors.

        The pastOutputIds must be the first element of inputList.

        Parameters:
        inputList - the serialized version of the batch tensors
        seqDimOrder - the sequence dimension order that specifies where the sequence dimension is in a tensor's shape
        Returns:
        a new BatchTensorList instance
      • getList

        public abstract NDList getList()
        Returns the serialized version of the BatchTensorList. The pastOutputIds must be the first element of the output list.
        Returns:
        the NDList that contains the serialized BatchTensorList
      • getSeqDimOrder

        public long[] getSeqDimOrder()
        Returns the sequence dimension order which specifies where the sequence dimension is in a tensor's shape.
        Returns:
        the sequence dimension order which specifies where the sequence dimension is in a tensor's shape
      • getPastOutputIds

        public NDArray getPastOutputIds()
        Returns the value of pastOutputIds (the past output token ids).
        Returns:
        the value of pastOutputIds
      • setPastOutputIds

        public void setPastOutputIds(NDArray pastOutputIds)
        Sets the past output token ids.
        Parameters:
        pastOutputIds - the past output token ids
      • getPastAttentionMask

        public NDArray getPastAttentionMask()
        Returns the value of pastAttentionMask (the attention mask).
        Returns:
        the value of pastAttentionMask
      • setPastAttentionMask

        public void setPastAttentionMask(NDArray pastAttentionMask)
        Sets the attention mask.
        Parameters:
        pastAttentionMask - the attention mask
      • getPastKeyValues

        public NDList getPastKeyValues()
        Returns the value of pastKeyValues (the key-value cache).
        Returns:
        the value of pastKeyValues
      • setPastKeyValues

        public void setPastKeyValues(NDList pastKeyValues)
        Sets the key-value (KV) cache.
        Parameters:
        pastKeyValues - the key-value (KV) cache
      • setSeqDimOrder

        public void setSeqDimOrder(long[] seqDimOrder)
        Sets the sequence dimension order which specifies where the sequence dimension is in a tensor's shape.
        Parameters:
        seqDimOrder - the sequence dimension order which specifies where the sequence dimension is in a tensor's shape
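
The serialization contract shared by fromList and getList (pastOutputIds comes first in the list) can be sketched without the library. In the toy class below, long[] stands in for NDArray and List<long[]> for NDList. The name ToySearchState is hypothetical, and the order of the elements after the first (attention mask, then key-value tensors) is an assumption, not something the documentation above states.

```java
import java.util.ArrayList;
import java.util.List;

/** Toy stand-in for a search state; long[] plays the role of NDArray. */
public final class ToySearchState {
    public final long[] pastOutputIds;
    public final long[] pastAttentionMask;
    public final List<long[]> pastKeyValues;

    public ToySearchState(long[] pastOutputIds, long[] pastAttentionMask, List<long[]> pastKeyValues) {
        this.pastOutputIds = pastOutputIds;
        this.pastAttentionMask = pastAttentionMask;
        this.pastKeyValues = pastKeyValues;
    }

    /** Serializes the state into a flat list with pastOutputIds as the first element. */
    public List<long[]> getList() {
        List<long[]> list = new ArrayList<>();
        list.add(pastOutputIds);      // required to be first
        list.add(pastAttentionMask);  // assumed position
        list.addAll(pastKeyValues);   // assumed to fill the rest of the list
        return list;
    }

    /** Rebuilds a state from a list laid out the way getList() produces it. */
    public static ToySearchState fromList(List<long[]> list) {
        if (list.size() < 2) {
            throw new IllegalArgumentException("expected at least pastOutputIds and pastAttentionMask");
        }
        List<long[]> kv = new ArrayList<>(list.subList(2, list.size()));
        return new ToySearchState(list.get(0), list.get(1), kv);
    }
}
```

Because both directions agree on the position of each element, a state that is flattened with getList() and rebuilt with fromList() ends up with the same fields it started with.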