Class PaddingStackBatchifier

  • All Implemented Interfaces:
    Batchifier, java.io.Serializable

    public final class PaddingStackBatchifier
    extends java.lang.Object
    implements Batchifier
    The padding stack batchifier is a StackBatchifier that also pads elements to reach the same length.
    See Also:
    Serialized Form
    • Method Detail

      • batchify

        public NDList batchify(NDList[] inputs)
        Converts an array of Record NDLists into a combined Batch NDList.

        The size of the input array is the batch size. The NDLists are assumed to have the same structure, and they are batched together to form one batched NDList.

        Specified by:
        batchify in interface Batchifier
        Parameters:
        inputs - the input array of NDList where each element is a Record NDList
        Returns:
        the batchified NDList
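A minimal plain-Java sketch of the pad-then-stack idea behind this batchifier. It assumes each record is a single variable-length `float[]` sequence and uses a hypothetical `PadStackSketch.batchify` helper in place of `NDList`/`NDArray`; it is not the DJL implementation. Shorter records are padded with a constant value up to the longest length, then stacked into one `[batch][maxLen]` matrix.

```java
import java.util.Arrays;

public class PadStackSketch {

    // Hypothetical stand-in for batchify: pads every record to the longest
    // length with padValue, then stacks the records into one 2-D batch.
    public static float[][] batchify(float[][] records, float padValue) {
        int maxLen = 0;
        for (float[] r : records) {
            maxLen = Math.max(maxLen, r.length);
        }
        float[][] batch = new float[records.length][maxLen];
        for (int i = 0; i < records.length; i++) {
            Arrays.fill(batch[i], padValue); // start the row as all padding
            // overwrite the front of the row with the record's valid data
            System.arraycopy(records[i], 0, batch[i], 0, records[i].length);
        }
        return batch;
    }

    public static void main(String[] args) {
        float[][] records = { {1f, 2f, 3f}, {4f}, {5f, 6f} };
        float[][] batch = batchify(records, 0f);
        System.out.println(Arrays.deepToString(batch));
        // prints [[1.0, 2.0, 3.0], [4.0, 0.0, 0.0], [5.0, 6.0, 0.0]]
    }
}
```

Padding first is what makes stacking possible at all: a stacked batch needs every record to share one shape, so the batch size becomes the new leading dimension.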
      • unbatchify

        public NDList[] unbatchify(NDList inputs)
        Splits a combined Batch NDList into its constituent Record NDLists.

        This reverses the batchify operation.

        Specified by:
        unbatchify in interface Batchifier
        Parameters:
        inputs - the NDList that needs to be 'unbatchified'
        Returns:
        an array of NDLists, of size equal to batch size, where each NDList is one element from the batch of inputs
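The reversal can be pictured as a transpose of the list structure: a batch NDList holds one batched array per field, while each record holds one row of every field. The sketch below assumes `batchList[j]` stands in for batched array `j` with shape `[batchSize][n_j]`; `UnbatchifySketch` is a hypothetical helper using plain arrays, not the DJL code, and it does not attempt to strip any padding.

```java
import java.util.Arrays;

public class UnbatchifySketch {

    // Hypothetical stand-in for unbatchify. batchList[j] is batched array j with
    // shape [batchSize][n_j]; record i receives row i of every batched array.
    public static float[][][] unbatchify(float[][][] batchList) {
        int numArrays = batchList.length;
        int batchSize = numArrays == 0 ? 0 : batchList[0].length;
        float[][][] records = new float[batchSize][numArrays][];
        for (int i = 0; i < batchSize; i++) {
            for (int j = 0; j < numArrays; j++) {
                // copy the row so records do not share storage with the batch
                records[i][j] = batchList[j][i].clone();
            }
        }
        return records;
    }

    public static void main(String[] args) {
        float[][][] batchList = {
            { {1f, 2f}, {3f, 4f}, {5f, 6f} }, // data, batch size 3
            { {0f}, {1f}, {0f} },             // labels
        };
        float[][][] records = unbatchify(batchList);
        System.out.println(records.length);                  // prints 3
        System.out.println(Arrays.deepToString(records[1])); // prints [[3.0, 4.0], [1.0]]
    }
}
```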
      • split

        public NDList[] split(NDList list,
                              int numOfSlices,
                              boolean evenSplit)
        Partitions the given Batch NDList into multiple Batch lists with smaller batch size.

        As an example, this function might be used for multi-GPU training where it takes the main batch and splits it into sub-batches that can be run on each GPU.

        This function unbatchifies the input NDList, redistributes the resulting records into the given number of slices, and then batchifies each slice to form an array of NDLists.

        Specified by:
        split in interface Batchifier
        Parameters:
        list - the NDList that needs to be split
        numOfSlices - the number of slices to split the list into
        evenSplit - whether each slice must have the same shape
        Returns:
        an array of NDList that contains all the slices
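A sketch of one way to partition a batch of rows into contiguous sub-batches, for example one per GPU. It assumes (not confirmed by the text above) that `evenSplit == true` rejects batch sizes that are not divisible by `numOfSlices`, and that an uneven split gives the first `batchSize % numOfSlices` slices one extra row. `SplitSketch.split` is a hypothetical helper over a single `float[][]` batch, not the DJL method.

```java
import java.util.Arrays;

public class SplitSketch {

    // Hypothetical stand-in for split on a single batched array of rows.
    // Assumption: evenSplit == true requires batchSize % numOfSlices == 0;
    // otherwise the first (batchSize % numOfSlices) slices get one extra row.
    public static float[][][] split(float[][] batch, int numOfSlices, boolean evenSplit) {
        if (numOfSlices <= 0) {
            throw new IllegalArgumentException("numOfSlices must be positive");
        }
        int batchSize = batch.length;
        if (evenSplit && batchSize % numOfSlices != 0) {
            throw new IllegalArgumentException(
                    "batch size " + batchSize + " is not divisible by " + numOfSlices);
        }
        int base = batchSize / numOfSlices;
        int extra = batchSize % numOfSlices;
        float[][][] slices = new float[numOfSlices][][];
        int start = 0;
        for (int i = 0; i < numOfSlices; i++) {
            int size = base + (i < extra ? 1 : 0);
            // shallow copy: each sub-batch references the original row arrays
            slices[i] = Arrays.copyOfRange(batch, start, start + size);
            start += size;
        }
        return slices;
    }

    public static void main(String[] args) {
        float[][] batch = new float[10][2]; // 10 rows of 2 features each
        float[][][] slices = split(batch, 4, false);
        StringBuilder sizes = new StringBuilder();
        for (float[][] s : slices) {
            sizes.append(s.length).append(' ');
        }
        System.out.println(sizes.toString().trim()); // prints 3 3 2 2
    }
}
```

Spreading the remainder one row at a time keeps slice sizes within one of each other, which tends to balance work across devices better than putting all leftover rows in the last slice.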
      • findMaxSize

        public static long findMaxSize(NDList[] inputs,
                                       int arrayIndex,
                                       int dimIndex)
        Finds the maximum size for a particular array/dimension in a batch of inputs (which can be padded to equalize their sizes).
        Parameters:
        inputs - the batch of inputs
        arrayIndex - the index of the array within each NDList in the batch
        dimIndex - the index of the dimension of that array whose size is measured
        Returns:
        the maximum size
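A minimal sketch of what findMaxSize computes, assuming the batch is described only by shapes: `shapes[i][j]` stands in for the shape of array `j` in record `i`. `FindMaxSizeSketch` is a hypothetical helper working on `long[]` shapes rather than `NDArray`s.

```java
public class FindMaxSizeSketch {

    // Hypothetical stand-in for findMaxSize: returns the largest extent of
    // dimension dimIndex among the arrays at position arrayIndex in each record.
    public static long findMaxSize(long[][][] shapes, int arrayIndex, int dimIndex) {
        long max = 0;
        for (long[][] record : shapes) {
            max = Math.max(max, record[arrayIndex][dimIndex]);
        }
        return max;
    }

    public static void main(String[] args) {
        long[][][] shapes = {
            { {5, 3}, {1} }, // record 0: data Shape(5, 3), label Shape(1)
            { {8, 3}, {1} }, // record 1: data Shape(8, 3)
            { {2, 3}, {1} }, // record 2: data Shape(2, 3)
        };
        System.out.println(findMaxSize(shapes, 0, 0)); // prints 8
    }
}
```

The result is the natural `maxSize` to pass to the padding step, since every array in the batch can then be padded up to it.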
      • padArrays

        public static long[] padArrays(NDList[] inputs,
                                       int arrayIndex,
                                       int dimIndex,
                                       NDArray padding,
                                       long maxSize)
        Pads the arrays at a particular dimension to all have the same size (updating inputs in place).
        Parameters:
        inputs - the batch of inputs
        arrayIndex - the index of the array within each NDList in the batch
        dimIndex - the index of the dimension of that array to pad
        padding - the padding to use. Say you have a batch of arrays of Shape(10, ?, 3) and you are padding the "?" dimension. There are two padding modes:
        • If you give a padding of Shape(1, 3) (the same dimensionality as required), it is repeated with NDArray.repeat(long) as necessary
        • If you give a padding of Shape(3) or Shape(0) (lower dimensionality than required), it is broadcast with NDArray.broadcast(Shape) to reach the full required Shape(?, 3)
        maxSize - the size that each array will be padded to in that dimension. In the example above, this is the target size of the "?" dimension.
        Returns:
        the original valid length of each input along the padded dimension (an array of length inputs.length). The inputs are updated in place.
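A plain-Java sketch of the padding step for the Shape(?, 3) example. It assumes each input is a `float[rows][3]` array padded along dimension 0 by repeating a single padding row, which mirrors the repeat mode described above. `PadArraysSketch.padRows` is a hypothetical helper, not the DJL method; in this 2-D case, broadcasting a lower-dimensional padding row would fill the missing positions with the same values.

```java
import java.util.Arrays;

public class PadArraysSketch {

    // Hypothetical stand-in for padArrays, restricted to padding dimension 0 of
    // [rows][cols] arrays by repeating one padding row. Each inputs[i] is
    // replaced in place, and the original row counts are returned.
    public static long[] padRows(float[][][] inputs, float[] paddingRow, long maxSize) {
        long[] validLengths = new long[inputs.length];
        for (int i = 0; i < inputs.length; i++) {
            float[][] original = inputs[i];
            if (original.length > maxSize) {
                throw new IllegalArgumentException("input " + i + " is longer than maxSize");
            }
            validLengths[i] = original.length;
            // grow to maxSize rows; the new trailing slots start out null
            float[][] padded = Arrays.copyOf(original, (int) maxSize);
            for (int r = original.length; r < maxSize; r++) {
                padded[r] = paddingRow.clone(); // fill each missing row with the padding
            }
            inputs[i] = padded; // update the batch in place
        }
        return validLengths;
    }

    public static void main(String[] args) {
        float[][][] inputs = {
            { {1, 2, 3} },
            { {4, 5, 6}, {7, 8, 9} },
        };
        long[] valid = padRows(inputs, new float[] {0, 0, 0}, 2);
        System.out.println(Arrays.toString(valid));         // prints [1, 2]
        System.out.println(Arrays.deepToString(inputs[0])); // prints [[1.0, 2.0, 3.0], [0.0, 0.0, 0.0]]
    }
}
```

Returning the valid lengths matters because, once padded, the rows of padding are indistinguishable from data unless a later step (such as a mask or a sequence length argument) is told where each input really ends.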