Class StackBatchifier

java.lang.Object
ai.djl.translate.StackBatchifier
All Implemented Interfaces:
Batchifier, Serializable

public class StackBatchifier extends Object implements Batchifier
StackBatchifier is used to merge a list of samples to form a mini-batch of NDArray(s). This is the default Batchifier for data loading.
  • Constructor Details

    • StackBatchifier

      public StackBatchifier()
  • Method Details

    • batchify

      public NDList batchify(NDList[] inputs)
      Converts an array of Record NDLists into a combined Batch NDList.

      The size of the input array is the batch size. The NDLists are assumed to share the same structure (the same number of NDArrays, with matching shapes at each position), and are batched together to form one batched NDList.

      Specified by:
      batchify in interface Batchifier
      Parameters:
      inputs - the input array of NDList where each element is a Record NDList
      Returns:
      the batchified NDList
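The stacking described for batchify can be illustrated without the DJL runtime. The sketch below is a conceptual model only, not the library's implementation: plain float arrays stand in for NDArray and NDList, and the class name StackSketch and its method are hypothetical. It assumes every record holds the same number of arrays and that arrays at the same position have the same length.

```java
/** Conceptual model only: float[] stands in for an NDArray, float[][] for an NDList. */
final class StackSketch {

    /**
     * Stacks the k-th array of every record along a new leading (batch) axis.
     *
     * @param records indexed as [batchSize][arraysPerRecord][elementsPerArray]
     * @return the batch, indexed as [arraysPerRecord][batchSize][elementsPerArray]
     */
    static float[][][] batchify(float[][][] records) {
        int batchSize = records.length;
        if (batchSize == 0) {
            throw new IllegalArgumentException("at least one record is required");
        }
        int arraysPerRecord = records[0].length;
        float[][][] batch = new float[arraysPerRecord][batchSize][];
        for (int i = 0; i < batchSize; i++) {
            if (records[i].length != arraysPerRecord) {
                throw new IllegalArgumentException("record " + i + " has a different number of arrays");
            }
            for (int k = 0; k < arraysPerRecord; k++) {
                // Stacking only makes sense when shapes match at each position.
                if (records[i][k].length != records[0][k].length) {
                    throw new IllegalArgumentException("shape mismatch at record " + i + ", array " + k);
                }
                batch[k][i] = records[i][k].clone();
            }
        }
        return batch;
    }
}
```

In this model, three records that each hold a 2-element data array and a 1-element label array become a batch of two arrays with leading dimension 3.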
    • unbatchify

      public NDList[] unbatchify(NDList inputs)
      Splits a combined Batch NDList into its constituent Record NDLists.

      This reverses the batchify operation.

      Specified by:
      unbatchify in interface Batchifier
      Parameters:
      inputs - the NDList that needs to be 'unbatchified'
      Returns:
      an array of NDLists, of size equal to batch size, where each NDList is one element from the batch of inputs
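To make the reversal concrete, here is a minimal sketch of unbatchifying in plain Java. It is a conceptual model rather than DJL code: float arrays stand in for NDArray and NDList, the class name UnstackSketch is hypothetical, and the batch size is assumed to be the leading dimension shared by every array in the batch.

```java
/** Conceptual model only: float[] stands in for an NDArray, float[][] for an NDList. */
final class UnstackSketch {

    /**
     * Slices every batched array along its leading (batch) axis and regroups the slices per record.
     *
     * @param batch indexed as [arraysPerRecord][batchSize][elementsPerArray]
     * @return the records, indexed as [batchSize][arraysPerRecord][elementsPerArray]
     */
    static float[][][] unbatchify(float[][][] batch) {
        if (batch.length == 0) {
            throw new IllegalArgumentException("the batch must contain at least one array");
        }
        int batchSize = batch[0].length;
        float[][][] records = new float[batchSize][batch.length][];
        for (int k = 0; k < batch.length; k++) {
            // Every batched array must agree on the batch size.
            if (batch[k].length != batchSize) {
                throw new IllegalArgumentException("array " + k + " has a different batch size");
            }
            for (int i = 0; i < batchSize; i++) {
                records[i][k] = batch[k][i].clone();
            }
        }
        return records;
    }
}
```

The returned array has one entry per sample, and entry i holds the i-th slice of every batched array.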
    • split

      public NDList[] split(NDList list, int numOfSlices, boolean evenSplit)
      Partitions the given Batch NDList into multiple Batch lists with smaller batch size.

      As an example, this function might be used for multi-GPU training where it takes the main batch and splits it into sub-batches that can be run on each GPU.

      This function unbatchifies the input NDList, redistributes the resulting records into the given number of slices, and then batchifies each slice to form an array of NDLists.

      Specified by:
      split in interface Batchifier
      Parameters:
      list - the NDList that needs to be split
      numOfSlices - the number of slices the list must be sliced into
      evenSplit - whether each slice must have the same shape
      Returns:
      an array of NDList that contains all the slices
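The redistribution step of split can be sketched as partitioning a list of records into contiguous slices. This is an illustration under stated assumptions, not DJL's implementation: the class name SplitSketch is hypothetical, slice sizes are assumed to differ by at most one when evenSplit is false, the slice count is assumed to be capped at the number of records, and an uneven division is assumed to be rejected when evenSplit is true. The library may distribute records differently.

```java
import java.util.ArrayList;
import java.util.List;

/** Conceptual model only: a generic list of records stands in for an unbatchified NDList array. */
final class SplitSketch {

    /** Partitions records into contiguous slices; each slice would then be batchified on its own. */
    static <T> List<List<T>> split(List<T> records, int numOfSlices, boolean evenSplit) {
        if (numOfSlices <= 0) {
            throw new IllegalArgumentException("numOfSlices must be positive");
        }
        int n = records.size();
        List<List<T>> slices = new ArrayList<>();
        if (n == 0) {
            return slices;
        }
        // Assumption: never produce empty slices, so cap the count at the number of records.
        int count = Math.min(numOfSlices, n);
        if (evenSplit && n % count != 0) {
            throw new IllegalArgumentException(n + " records cannot be split evenly into " + count + " slices");
        }
        int base = n / count;   // minimum slice size
        int extra = n % count;  // the first 'extra' slices receive one additional record
        int start = 0;
        for (int s = 0; s < count; s++) {
            int size = base + (s < extra ? 1 : 0);
            slices.add(new ArrayList<>(records.subList(start, start + size)));
            start += size;
        }
        return slices;
    }
}
```

With five records and two slices, this scheme yields slices of three and two records when evenSplit is false, and throws when evenSplit is true.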