public class BatchNorm extends ParameterBlock
When the distribution of input data varies from batch to batch, training a deep network must compensate for each new distribution: parameter values change as each batch is processed, which in turn changes the distribution of the activations of the network and of each of its layers. This condition is termed internal covariate shift, and it prevents the network from learning quickly and from generalizing well to unseen data.
Batch normalization speeds up learning because it permits a larger learning rate without causing gradient problems during backpropagation: all inputs are normalized, which reduces the scale of the weight updates. In some cases a batch normalization layer also regularizes the network and reduces, or even eliminates, the need for dropout, which speeds up training further since dropout slows training down by a factor of 2-3. However, it has been reported that batch normalization may not be beneficial when small batch sizes are used.
Formally, batch normalization is represented below:
\(\hat{x} \:=\: \frac{x \:-\: \mu_{batch}}{\sqrt{\sigma^2_{batch} \:+\: \epsilon}}\),
where \(\hat{x}\) is the normalized input, \(\mu_{batch}\) and \(\sigma^2_{batch}\) respectively
denote the mean and variance of a batch, and \(\epsilon\) (epsilon) is a constant for numerical
stability. The scale and shift operation can be formally defined as follows:
\(y \:=\: \gamma\hat{x} \:+\: \beta\),
where \(\gamma\) is the scale factor and \(\beta\) is the shift factor.
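The two formulas above can be traced by hand on a single feature. The following is a minimal plain-Java sketch, not this class's implementation: the class name `BatchNormSketch` and the method `batchNorm` are hypothetical, and the use of the biased (population) variance over the batch is an assumption that the text above does not state.

```java
import java.util.Arrays;

// Hypothetical helper for illustration only; not part of the library.
final class BatchNormSketch {

    /**
     * Normalizes one feature across a batch, then applies scale (gamma) and shift (beta).
     * Assumes the biased (population) variance of the batch.
     */
    static double[] batchNorm(double[] x, double gamma, double beta, double epsilon) {
        int n = x.length;

        // mu_batch: mean of the batch
        double mean = 0.0;
        for (double v : x) {
            mean += v;
        }
        mean /= n;

        // sigma^2_batch: variance of the batch (assumption: divide by n, not n - 1)
        double variance = 0.0;
        for (double v : x) {
            variance += (v - mean) * (v - mean);
        }
        variance /= n;

        // x_hat = (x - mu) / sqrt(sigma^2 + epsilon), then y = gamma * x_hat + beta
        double denominator = Math.sqrt(variance + epsilon);
        double[] y = new double[n];
        for (int i = 0; i < n; i++) {
            double xHat = (x[i] - mean) / denominator;
            y[i] = gamma * xHat + beta;
        }
        return y;
    }

    public static void main(String[] args) {
        double[] batch = {1.0, 2.0, 3.0, 4.0};
        // gamma = 1 and beta = 0 leave the normalized values unchanged.
        System.out.println(Arrays.toString(batchNorm(batch, 1.0, 0.0, 1e-5)));
        // gamma = 2 and beta = 3 rescale and shift the normalized values.
        System.out.println(Arrays.toString(batchNorm(batch, 2.0, 3.0, 1e-5)));
    }
}
```

With \(\gamma = 1\) and \(\beta = 0\) the output has a mean of approximately zero and a variance of approximately one; other values of \(\gamma\) and \(\beta\) move the output to a standard deviation of about \(|\gamma|\) and a mean of \(\beta\).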
Modifier and Type | Class and Description |
---|---|
static class | BatchNorm.Builder - The Builder to construct a BatchNorm. |
inputNames, inputShapes
Modifier and Type | Method and Description |
---|---|
void | beforeInitialize(Shape[] inputShapes) - Performs any action necessary before initialization. |
static BatchNorm.Builder | builder() - Creates a builder to build a BatchNorm. |
NDList | forward(ParameterStore parameterStore, NDList inputs, boolean training, ai.djl.util.PairList<java.lang.String,java.lang.Object> params) - Applies the operating function of the block once. |
java.util.List<Parameter> | getDirectParameters() - Returns a list of all the direct parameters of the block. |
Shape[] | getOutputShapes(NDManager manager, Shape[] inputShapes) - Returns the expected output shapes of the block for the specified input shapes. |
Shape | getParameterShape(java.lang.String name, Shape[] inputShapes) - Returns the shape of the specified direct parameter of this block given the shapes of the input to the block. |
void | loadParameters(NDManager manager, java.io.DataInputStream is) - Loads the parameters from the given input stream. |
void | saveParameters(java.io.DataOutputStream os) - Writes the parameters of the block to the given outputStream. |
getChildren, initialize, toString
cast, clear, describeInput, getParameters, isInitialized, readInputShapes, saveInputShapes, setInitializer, setInitializer
clone, equals, finalize, getClass, hashCode, notify, notifyAll, wait, wait, wait
forward, validateLayout
public NDList forward(ParameterStore parameterStore, NDList inputs, boolean training, ai.djl.util.PairList<java.lang.String,java.lang.Object> params)
parameterStore - the parameter store
inputs - the input NDList
training - true for a training forward pass
params - optional parameters

public Shape[] getOutputShapes(NDManager manager, Shape[] inputShapes)
manager - an NDManager
inputShapes - the shapes of the inputs

public java.util.List<Parameter> getDirectParameters()
Returns a list of all the direct parameters of the block, as Parameter objects.

public void beforeInitialize(Shape[] inputShapes)
Overrides beforeInitialize in class AbstractBlock
inputShapes - the expected shapes of the input

public Shape getParameterShape(java.lang.String name, Shape[] inputShapes)
name - the name of the parameter
inputShapes - the shapes of the input to the block

public void saveParameters(java.io.DataOutputStream os) throws java.io.IOException
os - the output stream to save the parameters to
java.io.IOException - if an I/O error occurs

public void loadParameters(NDManager manager, java.io.DataInputStream is) throws java.io.IOException, MalformedModelException
manager - an NDManager to create the parameter arrays
is - the input stream that supplies the parameter values
java.io.IOException - if an I/O error occurs
MalformedModelException - if the model file is corrupted or unsupported

public static BatchNorm.Builder builder()
Creates a builder to build a BatchNorm.
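
The stream contract of saveParameters and loadParameters (write to a DataOutputStream, read back from a DataInputStream, fail on corrupted or unsupported data) can be illustrated with a plain-Java round trip. This is only a sketch under stated assumptions: the class `ParameterStreamSketch`, the version byte, and the byte layout are all hypothetical and are not the format this library actually writes, and a plain IOException stands in for MalformedModelException.

```java
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.io.IOException;
import java.util.Arrays;

// Hypothetical helper for illustration only; the layout below is an assumption.
final class ParameterStreamSketch {

    // Hypothetical format version, checked on load.
    static final byte VERSION = 1;

    /** Writes a version byte, the parameter length, then the gamma and beta values. */
    static void save(DataOutputStream os, float[] gamma, float[] beta) throws IOException {
        if (gamma.length != beta.length) {
            throw new IllegalArgumentException("gamma and beta must have the same length");
        }
        os.writeByte(VERSION);
        os.writeInt(gamma.length);
        for (float g : gamma) {
            os.writeFloat(g);
        }
        for (float b : beta) {
            os.writeFloat(b);
        }
    }

    /** Reads back {gamma, beta}; rejects an unknown version or a negative length. */
    static float[][] load(DataInputStream is) throws IOException {
        byte version = is.readByte();
        if (version != VERSION) {
            // Stand-in for a "corrupted or unsupported" model error.
            throw new IOException("Unsupported version: " + version);
        }
        int n = is.readInt();
        if (n < 0) {
            throw new IOException("Corrupted length: " + n);
        }
        float[] gamma = new float[n];
        float[] beta = new float[n];
        for (int i = 0; i < n; i++) {
            gamma[i] = is.readFloat();
        }
        for (int i = 0; i < n; i++) {
            beta[i] = is.readFloat();
        }
        return new float[][] {gamma, beta};
    }

    public static void main(String[] args) throws IOException {
        ByteArrayOutputStream buffer = new ByteArrayOutputStream();
        try (DataOutputStream os = new DataOutputStream(buffer)) {
            save(os, new float[] {1f, 2f}, new float[] {0f, 0.5f});
        }
        try (DataInputStream is =
                new DataInputStream(new ByteArrayInputStream(buffer.toByteArray()))) {
            float[][] loaded = load(is);
            System.out.println(Arrays.toString(loaded[0]) + " " + Arrays.toString(loaded[1]));
        }
    }
}
```

Writing a version marker first is one common way to let the loader reject data it does not understand before it reads any values, which mirrors the distinction the method signature draws between an I/O error and a malformed model.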