Class GhostBatchNorm

All Implemented Interfaces:
Block

public class GhostBatchNorm extends BatchNorm
GhostBatchNorm is similar to BatchNorm, except that it splits a batch into smaller sub-batches, also known as ghost batches, normalizes each of them individually to have a mean of 0 and a variance of 1, and finally concatenates them back into a single batch. Each sub-batch contains virtualBatchSize samples (the last one may contain fewer).
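To make the split, normalize, and concatenate steps concrete, here is a minimal sketch on a plain 2-D array (rows are samples, columns are features). It is not DJL's implementation: the class and method names are hypothetical, and it assumes no learned scale/shift (gamma/beta), no running statistics, biased (population) variance, and a caller-supplied epsilon.

```java
/**
 * Hypothetical sketch of ghost batch normalization on double[][] data.
 * Assumptions: no gamma/beta, no running mean/variance, biased variance.
 */
final class GhostBatchNormSketch {

    static double[][] ghostNormalize(double[][] batch, int virtualBatchSize, double epsilon) {
        if (virtualBatchSize <= 0) {
            throw new IllegalArgumentException("virtualBatchSize must be positive");
        }
        int batchSize = batch.length;
        double[][] out = new double[batchSize][];
        // Walk the batch in chunks ("ghost batches") of virtualBatchSize rows;
        // the last chunk holds the remainder when batchSize is not divisible.
        for (int start = 0; start < batchSize; start += virtualBatchSize) {
            int end = Math.min(start + virtualBatchSize, batchSize);
            normalizeChunk(batch, out, start, end, epsilon);
        }
        // Each chunk is written back at its original offset, which amounts to
        // concatenating the ghost batches into one batch of the original size.
        return out;
    }

    // Normalizes rows [start, end) per feature using only that chunk's statistics.
    private static void normalizeChunk(double[][] in, double[][] out, int start, int end, double eps) {
        int n = end - start;
        int features = in[start].length;
        for (int i = start; i < end; i++) {
            out[i] = new double[features];
        }
        for (int f = 0; f < features; f++) {
            double mean = 0;
            for (int i = start; i < end; i++) {
                mean += in[i][f];
            }
            mean /= n;
            double var = 0;
            for (int i = start; i < end; i++) {
                double d = in[i][f] - mean;
                var += d * d;
            }
            var /= n; // biased (population) variance
            double std = Math.sqrt(var + eps);
            for (int i = start; i < end; i++) {
                out[i][f] = (in[i][f] - mean) / std;
            }
        }
    }
}
```

With a single feature, input {1, 3, 10, 20}, virtualBatchSize 2 and epsilon 0, each pair is normalized on its own and the result is {-1, 1, -1, 1}, whereas ordinary batch normalization over all four samples would use one shared mean and variance.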
  • Method Details

    • forwardInternal

      protected NDList forwardInternal(ParameterStore parameterStore, NDList inputs, boolean training, ai.djl.util.PairList<String,Object> params)
      Overrides:
      forwardInternal in class BatchNorm
      Parameters:
      parameterStore - the parameter store
      inputs - the input NDList
      training - true for a training forward pass
      params - optional parameters
      Returns:
      the output of the forward pass
    • split

      protected NDList[] split(NDList list)
      Splits an NDList into sub-batches of size virtualBatchSize.

      This function unbatchifies the input NDList into mini-batches of virtualBatchSize samples each. If the batch size is divisible by the virtual batch size, all returned sub-batches have the same size; otherwise, all sub-batches have the same size except the last one, which holds the remaining samples.

      Parameters:
      list - the NDList that needs to be split
      Returns:
      an array of NDList that contains all the mini-batches
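The sizing rule described for split can be sketched with plain integers. This hypothetical helper (not part of DJL) only computes how many samples each sub-batch would hold under the documented behavior; it does not touch NDList.

```java
import java.util.ArrayList;
import java.util.List;

/** Hypothetical helper: sub-batch sizes under the documented split rule. */
final class GhostBatchSplitSketch {

    static List<Integer> subBatchSizes(int batchSize, int virtualBatchSize) {
        if (batchSize < 0 || virtualBatchSize <= 0) {
            throw new IllegalArgumentException("invalid batch or virtual batch size");
        }
        List<Integer> sizes = new ArrayList<>();
        // Take full chunks of virtualBatchSize; the final chunk gets whatever is left.
        for (int remaining = batchSize; remaining > 0; remaining -= virtualBatchSize) {
            sizes.add(Math.min(remaining, virtualBatchSize));
        }
        return sizes;
    }
}
```

For example, a batch of 10 with a virtual batch size of 4 yields sizes [4, 4, 2], while a batch of 8 yields [4, 4]; the number of sub-batches is the ceiling of batchSize / virtualBatchSize.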
    • batchify

      protected NDList batchify(NDList[] subBatches)
      Converts an array of NDList into a single NDList using StackBatchifier and squeezes the first dimension that stacking creates, so that the final NDArray has the same size as the one that was split.
      Parameters:
      subBatches - the input array of NDList
      Returns:
      the batchified NDList
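The net effect described for batchify, reassembling the sub-batches into one batch whose size matches the original, can be modeled with plain arrays. This is a hypothetical sketch that concatenates rows in order; it does not reproduce the StackBatchifier stack-then-squeeze mechanics.

```java
/** Hypothetical sketch: joins sub-batches (rows = samples) back into one batch. */
final class GhostBatchConcatSketch {

    static double[][] concatenate(double[][]... subBatches) {
        int total = 0;
        for (double[][] sub : subBatches) {
            total += sub.length;
        }
        double[][] batch = new double[total][];
        int offset = 0;
        // Copy row references in order; rows are shared with the inputs, not deep-copied.
        for (double[][] sub : subBatches) {
            System.arraycopy(sub, 0, batch, offset, sub.length);
            offset += sub.length;
        }
        return batch;
    }
}
```

Joining sub-batches of 2 and 1 samples gives a batch of 3 samples in the original order.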
    • squeezeExtraDimensions

      protected NDList squeezeExtraDimensions(NDList batch)
      Squeezes the first axis of the arrays in an NDList.
      Parameters:
      batch - the input NDList
      Returns:
      the squeezed NDList
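Squeezing an axis removes it from the shape, which is only valid when that axis has size 1. The following hypothetical helper (not DJL API) shows the effect on a shape expressed as a long array.

```java
import java.util.Arrays;

/** Hypothetical sketch: the shape-level effect of squeezing the first axis. */
final class SqueezeSketch {

    static long[] squeezeFirstAxis(long[] shape) {
        // Only a leading axis of size 1 can be removed without losing data.
        if (shape.length == 0 || shape[0] != 1) {
            throw new IllegalArgumentException(
                    "first axis must have size 1, got " + Arrays.toString(shape));
        }
        return Arrays.copyOfRange(shape, 1, shape.length);
    }
}
```

A shape of (1, 32, 10) becomes (32, 10); a shape of (2, 3) is rejected.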
    • builder

      public static GhostBatchNorm.Builder builder()
      Creates a builder to build a GhostBatchNorm.
      Returns:
      a new builder