Package ai.djl.nn.norm
Class GhostBatchNorm
- java.lang.Object
  - ai.djl.nn.AbstractBaseBlock
    - ai.djl.nn.AbstractBlock
      - ai.djl.nn.norm.BatchNorm
        - ai.djl.nn.norm.GhostBatchNorm
- All Implemented Interfaces:
Block
public class GhostBatchNorm extends BatchNorm
GhostBatchNorm is similar to BatchNorm, except that it splits a batch into smaller sub-batches (also called ghost batches), normalizes each of them individually to have a mean of 0 and a variance of 1, and finally concatenates them back into a single batch. Each mini-batch contains virtualBatchSize samples (the last one may be smaller when the batch size is not divisible by virtualBatchSize).

- See Also:
  - Ghost Normalization Paper
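For reference, the per-ghost-batch normalization described above can be written as follows for a single feature. The symbols are introduced here for illustration: the batch of $N$ samples is assumed to be split into $G = \lceil N / V \rceil$ ghost batches $B_1, \dots, B_G$, where $V$ stands for virtualBatchSize, and a small constant $\epsilon$ is assumed to be added to the variance for numerical stability, as in standard batch normalization.

```latex
\mu_g = \frac{1}{|B_g|} \sum_{i \in B_g} x_i, \qquad
\sigma_g^2 = \frac{1}{|B_g|} \sum_{i \in B_g} (x_i - \mu_g)^2, \qquad
\hat{x}_i = \frac{x_i - \mu_g}{\sqrt{\sigma_g^2 + \epsilon}} \quad \text{for } i \in B_g
```

The only difference from ordinary batch normalization is that $\mu_g$ and $\sigma_g^2$ are computed over each ghost batch $B_g$ rather than over all $N$ samples.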

Nested Class Summary
- Nested Classes
  - static class GhostBatchNorm.Builder: The Builder to construct a GhostBatchNorm.
- Nested classes/interfaces inherited from class ai.djl.nn.norm.BatchNorm
  - BatchNorm.BaseBuilder<T extends BatchNorm.BaseBuilder<T>>

Field Summary
- Fields inherited from class ai.djl.nn.AbstractBlock: children, parameters
- Fields inherited from class ai.djl.nn.AbstractBaseBlock: inputNames, inputShapes, version

Constructor Summary
- protected GhostBatchNorm(GhostBatchNorm.Builder builder)

Method Summary
- protected NDList batchify(NDList[] subBatches)
  Converts an array of NDList into a single NDList using StackBatchifier and squeezes the first dimension created by it.
- static GhostBatchNorm.Builder builder()
  Creates a builder to build a GhostBatchNorm.
- protected NDList forwardInternal(ParameterStore parameterStore, NDList inputs, boolean training, ai.djl.util.PairList<java.lang.String,java.lang.Object> params)
  A helper for Block.forward(ParameterStore, NDList, boolean, PairList) after initialization.
- protected NDList[] split(NDList list)
  Splits an NDList into sub-batches of the given size.
- protected NDList squeezeExtraDimensions(NDList batch)
  Squeezes the first axis of the given NDList.

- Methods inherited from class ai.djl.nn.norm.BatchNorm: batchNorm, batchNorm, batchNorm, batchNorm, beforeInitialize, getOutputShapes, loadMetadata, prepare, saveMetadata
- Methods inherited from class ai.djl.nn.AbstractBlock: addChildBlock, addChildBlock, addChildBlockSingleton, addParameter, getChildren, getDirectParameters
- Methods inherited from class ai.djl.nn.AbstractBaseBlock: cast, clear, describeInput, forward, forward, forwardInternal, getInputShapes, getParameters, initialize, initializeChildBlocks, isInitialized, loadParameters, readInputShapes, saveInputShapes, saveParameters, setInitializer, setInitializer, setInitializer, toString
- Methods inherited from class java.lang.Object: clone, equals, finalize, getClass, hashCode, notify, notifyAll, wait, wait, wait
- Methods inherited from interface ai.djl.nn.Block: forward, freezeParameters

Constructor Detail

GhostBatchNorm
protected GhostBatchNorm(GhostBatchNorm.Builder builder)

Method Detail

forwardInternal
protected NDList forwardInternal(ParameterStore parameterStore, NDList inputs, boolean training, ai.djl.util.PairList<java.lang.String,java.lang.Object> params)
A helper for Block.forward(ParameterStore, NDList, boolean, PairList) after initialization.
- Overrides:
  forwardInternal in class BatchNorm
- Parameters:
  parameterStore - the parameter store
  inputs - the input NDList
  training - true for a training forward pass
  params - optional parameters
- Returns:
  the output of the forward pass
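The training-time flow implied by the class description (split into ghost batches, normalize each one independently, join the results) can be sketched in plain Java for a single feature. This is a simplified illustration under stated assumptions, not DJL's forwardInternal: double arrays stand in for NDList and StackBatchifier, BatchNorm's learnable scale and shift, running statistics, and inference-mode behavior are omitted, and the class name, method names, and epsilon value are hypothetical.

```java
import java.util.Arrays;

/**
 * Hypothetical plain-Java sketch of ghost batch normalization for a single feature.
 * Assumptions: double arrays stand in for NDList/NDArray; learnable scale/shift,
 * running statistics, and inference-mode behavior are deliberately left out.
 */
public class GhostBatchNormSketch {

    /** Normalizes one ghost batch to mean 0 and (approximately) variance 1. */
    static double[] normalize(double[] x, double epsilon) {
        double mean = 0.0;
        for (double v : x) {
            mean += v;
        }
        mean /= x.length;

        // Biased (population) variance of this ghost batch only.
        double variance = 0.0;
        for (double v : x) {
            variance += (v - mean) * (v - mean);
        }
        variance /= x.length;

        double std = Math.sqrt(variance + epsilon);
        double[] out = new double[x.length];
        for (int i = 0; i < x.length; i++) {
            out[i] = (x[i] - mean) / std;
        }
        return out;
    }

    /** Splits the batch into ghost batches, normalizes each independently, and joins them back in order. */
    static double[] ghostBatchNorm(double[] batch, int virtualBatchSize, double epsilon) {
        if (virtualBatchSize <= 0) {
            throw new IllegalArgumentException("virtualBatchSize must be positive");
        }
        double[] out = new double[batch.length];
        for (int start = 0; start < batch.length; start += virtualBatchSize) {
            // The last ghost batch may be shorter than virtualBatchSize.
            int end = Math.min(start + virtualBatchSize, batch.length);
            double[] normalized = normalize(Arrays.copyOfRange(batch, start, end), epsilon);
            System.arraycopy(normalized, 0, out, start, normalized.length);
        }
        return out;
    }

    public static void main(String[] args) {
        double[] batch = {1, 2, 3, 4, 10, 20, 30, 40};
        // Two ghost batches of 4; each is centered and scaled using only its own statistics.
        System.out.println(Arrays.toString(ghostBatchNorm(batch, 4, 1e-5)));
    }
}
```

Because each ghost batch uses its own statistics, the two halves in `main` map to nearly identical normalized values even though the second half is ten times larger; normalizing all eight samples together would not produce that result.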

split
protected NDList[] split(NDList list)
Splits an NDList into sub-batches of the given size. This function unbatchifies the input NDList into mini-batches, each of size virtualBatchSize. If the batch size is divisible by the virtual batch size, all returned sub-batches have the same size. Otherwise, all returned sub-batches have the same size except the last one, which holds the remaining samples.
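The size rule above can be illustrated with a plain-Java helper that only computes sub-batch sizes. The class and method names are hypothetical, and the sketch follows the documented behavior rather than reproducing DJL's actual splitting code.

```java
import java.util.Arrays;

/** Hypothetical sketch of the sub-batch sizes produced when a batch is split by virtualBatchSize. */
public class GhostSplitSizes {

    /**
     * Returns the size of each sub-batch: every sub-batch holds virtualBatchSize samples,
     * except the last one, which holds the remainder when batchSize is not divisible.
     */
    static int[] subBatchSizes(int batchSize, int virtualBatchSize) {
        if (batchSize <= 0 || virtualBatchSize <= 0) {
            throw new IllegalArgumentException("sizes must be positive");
        }
        // Ceiling division: number of sub-batches needed to cover the whole batch.
        int count = (batchSize + virtualBatchSize - 1) / virtualBatchSize;
        int[] sizes = new int[count];
        Arrays.fill(sizes, virtualBatchSize);
        int remainder = batchSize % virtualBatchSize;
        if (remainder != 0) {
            sizes[count - 1] = remainder;
        }
        return sizes;
    }

    public static void main(String[] args) {
        System.out.println(Arrays.toString(subBatchSizes(12, 4))); // [4, 4, 4]
        System.out.println(Arrays.toString(subBatchSizes(10, 4))); // [4, 4, 2]
    }
}
```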

batchify
protected NDList batchify(NDList[] subBatches)
Converts an array of NDList into a single NDList using StackBatchifier and squeezes the first dimension created by it. This makes the final NDArray the same size as the batch that was split.
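The net effect described above, a single batch with as many samples as the batch that was split, can be illustrated without DJL by concatenating the sub-batches along the sample axis. This is a hypothetical plain-Java stand-in: it does not use StackBatchifier and does not model the intermediate stacked dimension that batchify squeezes away.

```java
/**
 * Hypothetical stand-in for the net effect of batchify: sub-batches are joined
 * along the sample axis, so the result has as many samples as the original batch.
 * It does not use StackBatchifier or model the intermediate stacked dimension.
 */
public class GhostRejoin {

    /** Concatenates sub-batches (arrays of samples) into one batch, preserving order. */
    static double[][] rejoin(double[][][] subBatches) {
        int total = 0;
        for (double[][] sub : subBatches) {
            total += sub.length;
        }
        double[][] batch = new double[total][];
        int offset = 0;
        for (double[][] sub : subBatches) {
            for (double[] sample : sub) {
                batch[offset++] = sample;
            }
        }
        return batch;
    }

    public static void main(String[] args) {
        // Sub-batches of 2 and 1 samples, each sample having 2 features.
        double[][][] subBatches = {{{1, 2}, {3, 4}}, {{5, 6}}};
        double[][] batch = rejoin(subBatches);
        System.out.println(batch.length + " samples"); // 3 samples
    }
}
```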

squeezeExtraDimensions
protected NDList squeezeExtraDimensions(NDList batch)
Squeezes the first axis of the given NDList.
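Squeezing an axis removes it only when its length is 1. The following shape-only sketch shows that operation on axis 0 using plain long arrays instead of DJL's Shape or NDArray; the class and method names are hypothetical, and the sketch illustrates squeezing in general rather than the exact shapes this method receives.

```java
import java.util.Arrays;

/** Hypothetical shape-only sketch of squeezing a leading axis of length 1. */
public class SqueezeFirstAxis {

    /** Drops axis 0 when it has length 1; any other leading length cannot be squeezed. */
    static long[] squeezeFirst(long[] shape) {
        if (shape.length == 0 || shape[0] != 1) {
            throw new IllegalArgumentException("axis 0 must have length 1: " + Arrays.toString(shape));
        }
        return Arrays.copyOfRange(shape, 1, shape.length);
    }

    public static void main(String[] args) {
        System.out.println(Arrays.toString(squeezeFirst(new long[] {1, 32, 16}))); // [32, 16]
    }
}
```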

builder
public static GhostBatchNorm.Builder builder()
Creates a builder to build a GhostBatchNorm.
- Returns:
  a new builder