Class LayerNorm
- All Implemented Interfaces:
Block
Citing the abstract of the paper "Layer Normalization" (Ba, Kiros, and Hinton, 2016): "Training state-of-the-art, deep neural networks is computationally expensive. One way to reduce the training time is to normalize the activities of the neurons. A recently introduced technique called batch normalization uses the distribution of the summed input to a neuron over a mini-batch of training cases to compute a mean and variance which are then used to normalize the summed input to that neuron on each training case. This significantly reduces the training time in feed-forward neural networks. However, the effect of batch normalization is dependent on the mini-batch size and it is not obvious how to apply it to recurrent neural networks. In this paper, we transpose batch normalization into layer normalization by computing the mean and variance used for normalization from all of the summed inputs to the neurons in a layer on a single training case. Like batch normalization, we also give each neuron its own adaptive bias and gain which are applied after the normalization but before the non-linearity. Unlike batch normalization, layer normalization performs exactly the same computation at training and test times. It is also straightforward to apply to recurrent neural networks by computing the normalization statistics separately at each time step. Layer normalization is very effective at stabilizing the hidden state dynamics in recurrent networks. Empirically, we show that layer normalization can substantially reduce the training time compared with previously published techniques."
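The computation the abstract describes can be written out directly. The following is a minimal sketch, not part of this class's API, assuming only DJL's generic NDArray operations: per-sample normalization over the feature axis, with the adaptive gain fixed to 1 and the bias to 0.

    import ai.djl.ndarray.NDArray;
    import ai.djl.ndarray.NDManager;
    import ai.djl.ndarray.types.Shape;

    try (NDManager manager = NDManager.newBaseManager()) {
        NDArray x = manager.randomNormal(new Shape(2, 4)); // 2 samples, 4 features
        // Statistics are computed per sample, across the feature axis.
        NDArray mean = x.mean(new int[] {1}, true);
        NDArray variance = x.sub(mean).pow(2).mean(new int[] {1}, true);
        // y = (x - mean) / sqrt(variance + eps); gain and bias omitted (1 and 0).
        NDArray y = x.sub(mean).div(variance.add(1e-5f).sqrt());
    }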
-
Nested Class Summary
Nested Classes
static class LayerNorm.Builder
    The Builder to construct a LayerNorm type of Block.
Field Summary
FieldsModifier and TypeFieldDescriptionprotected int[]
protected Parameter
protected boolean
protected float
protected Parameter
protected Shape
protected boolean
Fields inherited from class ai.djl.nn.AbstractBlock
children, parameters
Fields inherited from class ai.djl.nn.AbstractBaseBlock
inputNames, inputShapes, outputDataTypes, version
-
Constructor Summary
Constructors
protected LayerNorm(LayerNorm.Builder builder)
Method Summary
protected void beforeInitialize(Shape... inputShapes)
    Performs any action necessary before initialization.
static LayerNorm.Builder builder()
    Creates a builder to build a LayerNorm.
protected NDList forwardInternal(ParameterStore parameterStore, NDList inputs, boolean training, ai.djl.util.PairList<String,Object> params)
    A helper for Block.forward(ParameterStore, NDList, boolean, PairList) after initialization.
Shape[] getOutputShapes(Shape[] inputShapes)
    Returns the expected output shapes of the block for the specified input shapes.
static NDList layerNorm(NDArray input, Shape normalizedShape, NDArray gamma, NDArray beta, float eps)
    Applies Layer Normalization with average and variance for each input sample across the axis dimensions.
void loadMetadata(byte loadVersion, DataInputStream is)
    Overwrite this to load additional metadata with the parameter values.
void prepare(Shape[] inputShapes)
    Sets the shape of Parameters.
protected void saveMetadata(DataOutputStream os)
    Override this method to save additional data apart from parameter values.
Methods inherited from class ai.djl.nn.AbstractBlock
addChildBlock, addChildBlock, addChildBlockSingleton, addParameter, getChildren, getDirectParameters
Methods inherited from class ai.djl.nn.AbstractBaseBlock
cast, clear, describeInput, forward, forward, forwardInternal, getInputShapes, getOutputDataTypes, getParameters, initialize, initializeChildBlocks, isInitialized, loadParameters, readInputShapes, saveInputShapes, saveParameters, setInitializer, setInitializer, setInitializer, toString
Methods inherited from class java.lang.Object
clone, equals, finalize, getClass, hashCode, notify, notifyAll, wait, wait, wait
Methods inherited from interface ai.djl.nn.Block
forward, freezeParameters, freezeParameters, getOutputShapes
-
Field Details
-
epsilon
protected float epsilon
-
normalizedShape
protected Shape normalizedShape
-
center
protected boolean center
-
scale
protected boolean scale
-
axis
protected int[] axis
-
gamma
protected Parameter gamma
-
beta
protected Parameter beta
-
Constructor Details
-
LayerNorm
protected LayerNorm(LayerNorm.Builder builder)
-
Method Details
-
layerNorm
public static NDList layerNorm(NDArray input, Shape normalizedShape, NDArray gamma, NDArray beta, float eps)
Applies Layer Normalization with average and variance for each input sample across the axis dimensions.
- Parameters:
input - the input NDArray of shape (batchSize, inputChannel, *), where * could be empty, width, (height, width), or (depth, height, width)
normalizedShape - the dimensions to calculate the average and variance from
gamma - the gamma weight NDArray
beta - the beta weight NDArray
eps - a value added to the denominator for numerical stability
- Returns:
the output NDArray of shape (batchSize, inputChannel, *), where * could be empty, width, (height, width), or (depth, height, width)
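A minimal usage sketch of this static helper (the shapes and epsilon value are illustrative choices, not requirements of the API):

    import ai.djl.ndarray.NDArray;
    import ai.djl.ndarray.NDList;
    import ai.djl.ndarray.NDManager;
    import ai.djl.ndarray.types.Shape;
    import ai.djl.nn.norm.LayerNorm;

    try (NDManager manager = NDManager.newBaseManager()) {
        NDArray input = manager.randomNormal(new Shape(2, 4)); // (batchSize, inputChannel)
        Shape normalizedShape = new Shape(4);                  // normalize over the last axis
        NDArray gamma = manager.ones(normalizedShape);         // gain = 1 (identity scale)
        NDArray beta = manager.zeros(normalizedShape);         // bias = 0 (no shift)
        NDList out = LayerNorm.layerNorm(input, normalizedShape, gamma, beta, 1e-5f);
        // Each sample of the result has approximately zero mean and unit variance.
        System.out.println(out.singletonOrThrow());
    }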
-
builder
public static LayerNorm.Builder builder()
Creates a builder to build a LayerNorm.
- Returns:
a new builder
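A sketch of the typical block workflow built around this builder. It assumes default builder settings, under which (an assumption; consult LayerNorm.Builder for the available options) the normalized shape is derived from the input shape at initialization:

    import ai.djl.ndarray.NDArray;
    import ai.djl.ndarray.NDList;
    import ai.djl.ndarray.NDManager;
    import ai.djl.ndarray.types.DataType;
    import ai.djl.ndarray.types.Shape;
    import ai.djl.nn.norm.LayerNorm;
    import ai.djl.training.ParameterStore;

    try (NDManager manager = NDManager.newBaseManager()) {
        LayerNorm block = LayerNorm.builder().build();
        // Initialization creates the gamma and beta parameters from the input shape.
        block.initialize(manager, DataType.FLOAT32, new Shape(2, 4));
        NDArray input = manager.randomNormal(new Shape(2, 4));
        // Inference pass (training = false).
        NDList out = block.forward(new ParameterStore(manager, false), new NDList(input), false);
        System.out.println(out.singletonOrThrow());
    }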
-
forwardInternal
protected NDList forwardInternal(ParameterStore parameterStore, NDList inputs, boolean training, ai.djl.util.PairList<String,Object> params)
A helper for Block.forward(ParameterStore, NDList, boolean, PairList) after initialization.
- Specified by:
forwardInternal in class AbstractBaseBlock
- Parameters:
parameterStore - the parameter store
inputs - the input NDList
training - true for a training forward pass
params - optional parameters
- Returns:
the output of the forward pass
-
getOutputShapes
public Shape[] getOutputShapes(Shape[] inputShapes)
Returns the expected output shapes of the block for the specified input shapes.
- Parameters:
inputShapes - the shapes of the inputs
- Returns:
the expected output shapes of the block
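For instance, continuing the builder sketch above: layer normalization preserves the input shape, so the expected output shape equals the input shape.

    Shape[] outShapes = block.getOutputShapes(new Shape[] {new Shape(2, 4)});
    // outShapes[0] is (2, 4): normalization changes values, not dimensions.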
-
beforeInitialize
protected void beforeInitialize(Shape... inputShapes)
Performs any action necessary before initialization. For example, keep the input information or verify the layout.
- Overrides:
beforeInitialize in class AbstractBaseBlock
- Parameters:
inputShapes - the expected shapes of the input
-
prepare
public void prepare(Shape[] inputShapes)
Sets the shape of Parameters.
- Overrides:
prepare in class AbstractBaseBlock
- Parameters:
inputShapes - the shapes of the inputs
-
saveMetadata
protected void saveMetadata(DataOutputStream os) throws IOException
Override this method to save additional data apart from parameter values. This default implementation saves the currently set input shapes.
- Overrides:
saveMetadata in class AbstractBaseBlock
- Parameters:
os - the non-null output stream the parameter values and metadata are written to
- Throws:
IOException - saving failed
-
loadMetadata
public void loadMetadata(byte loadVersion, DataInputStream is) throws IOException, MalformedModelException
Overwrite this to load additional metadata with the parameter values. If you overwrite AbstractBaseBlock.saveMetadata(DataOutputStream) or need to provide backward compatibility to older binary formats, you probably need to overwrite this. This default implementation checks that the version number fits, and throws a MalformedModelException if it does not. After that it restores the input shapes.
- Overrides:
loadMetadata in class AbstractBaseBlock
- Parameters:
loadVersion - the version used for loading this metadata
is - the input stream we are loading from
- Throws:
IOException - loading failed
MalformedModelException - data can be loaded but has wrong format
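These hooks are exercised indirectly through the inherited saveParameters and loadParameters methods. A hypothetical round-trip sketch, with exception handling omitted and `block` and `manager` as in the builder sketch above:

    import java.io.ByteArrayInputStream;
    import java.io.ByteArrayOutputStream;
    import java.io.DataInputStream;
    import java.io.DataOutputStream;

    // Serialize the initialized block; saveParameters writes metadata via saveMetadata.
    ByteArrayOutputStream bos = new ByteArrayOutputStream();
    block.saveParameters(new DataOutputStream(bos));

    // Restore into a fresh block; loadParameters reads metadata back via loadMetadata.
    LayerNorm restored = LayerNorm.builder().build();
    restored.loadParameters(manager, new DataInputStream(new ByteArrayInputStream(bos.toByteArray())));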
-