public class LayerNorm extends AbstractBlock
Citing the abstract of the Layer Normalization paper (Ba, Kiros, and Hinton, 2016): "Training state-of-the-art, deep neural networks is computationally expensive. One way to reduce the training time is to normalize the activities of the neurons. A recently introduced technique called batch normalization uses the distribution of the summed input to a neuron over a mini-batch of training cases to compute a mean and variance which are then used to normalize the summed input to that neuron on each training case. This significantly reduces the training time in feed-forward neural networks. However, the effect of batch normalization is dependent on the mini-batch size and it is not obvious how to apply it to recurrent neural networks. In this paper, we transpose batch normalization into layer normalization by computing the mean and variance used for normalization from all of the summed inputs to the neurons in a layer on a single training case. Like batch normalization, we also give each neuron its own adaptive bias and gain which are applied after the normalization but before the non-linearity. Unlike batch normalization, layer normalization performs exactly the same computation at training and test times. It is also straightforward to apply to recurrent neural networks by computing the normalization statistics separately at each time step. Layer normalization is very effective at stabilizing the hidden state dynamics in recurrent networks. Empirically, we show that layer normalization can substantially reduce the training time compared with previously published techniques."
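The quoted description reduces to a short computation: per sample, subtract the mean of the layer's summed inputs, divide by the standard deviation, then apply the learned gain (gamma) and bias (beta). A minimal plain-Java sketch, illustrative only and not DJL's implementation:

```java
// Layer normalization for a single training case (illustrative sketch).
// x holds the summed inputs to the layer's neurons; gamma and beta are the
// per-neuron adaptive gain and bias applied after normalization.
static float[] layerNormSample(float[] x, float[] gamma, float[] beta, float eps) {
    float mean = 0f;
    for (float v : x) {
        mean += v;
    }
    mean /= x.length;

    float variance = 0f;
    for (float v : x) {
        variance += (v - mean) * (v - mean);
    }
    variance /= x.length;

    float std = (float) Math.sqrt(variance + eps); // eps keeps the denominator non-zero
    float[] y = new float[x.length];
    for (int i = 0; i < x.length; i++) {
        y[i] = gamma[i] * ((x[i] - mean) / std) + beta[i];
    }
    return y;
}
```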
Modifier and Type | Class and Description
---|---
static class | LayerNorm.Builder: The Builder to construct a LayerNorm.
Fields inherited from class AbstractBlock: children, inputNames, inputShapes, parameters, version
Modifier and Type | Method and Description
---|---
protected void | beforeInitialize(Shape... inputShapes): Performs any action necessary before initialization.
static LayerNorm.Builder | builder(): Creates a builder to build a LayerNorm.
protected NDList | forwardInternal(ParameterStore parameterStore, NDList inputs, boolean training, ai.djl.util.PairList<java.lang.String,java.lang.Object> params): A helper for Block.forward(ParameterStore, NDList, boolean, PairList) after initialization.
Shape[] | getOutputShapes(Shape[] inputShapes): Returns the expected output shapes of the block for the specified input shapes.
static NDList | layerNorm(NDArray input, Shape normalizedShape, NDArray gamma, NDArray beta, float eps): Applies Layer Normalization with average and variance for each input sample across the axis dimensions.
void | loadMetadata(byte loadVersion, java.io.DataInputStream is): Overwrite this to load additional metadata with the parameter values.
void | prepare(Shape[] inputShapes): Sets the shape of Parameters.
protected void | saveMetadata(java.io.DataOutputStream os): Override this method to save additional data apart from parameter values.
Methods inherited from class AbstractBlock: addChildBlock, addParameter, cast, clear, describeInput, forward, forward, forwardInternal, getChildren, getDirectParameters, getParameters, initialize, initializeChildBlocks, isInitialized, loadParameters, readInputShapes, saveInputShapes, saveParameters, setInitializer, setInitializer, setInitializer, toString

Methods inherited from class java.lang.Object: clone, equals, finalize, getClass, hashCode, notify, notifyAll, wait, wait, wait

Methods inherited from interface Block: forward, validateLayout
public static NDList layerNorm(NDArray input, Shape normalizedShape, NDArray gamma, NDArray beta, float eps)

Applies Layer Normalization with average and variance for each input sample across the axis dimensions.

Parameters:
    input - the input NDArray of shape (batchSize, inputChannel, *), where * could be empty, width, (height, width), or (depth, height, width)
    normalizedShape - dimensions to calculate average and variance from
    gamma - gamma weight NDArray
    beta - beta weight NDArray
    eps - a value added to the denominator for numerical stability

Returns:
    the output NDArray of shape (batchSize, inputChannel, *), where * could be empty, width, (height, width), or (depth, height, width)
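As a usage illustration of this static helper, here is a hedged sketch assuming the standard DJL entry points NDManager.newBaseManager() and NDList.singletonOrThrow(); the shapes and values are invented for the example:

```java
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.NDManager;
import ai.djl.ndarray.types.Shape;
import ai.djl.nn.norm.LayerNorm;

public class LayerNormExample {
    public static void main(String[] args) {
        try (NDManager manager = NDManager.newBaseManager()) {
            NDArray input = manager.randomNormal(new Shape(2, 4)); // (batchSize, inputChannel)
            NDArray gamma = manager.ones(new Shape(4));            // gain, initialized to 1
            NDArray beta = manager.zeros(new Shape(4));            // bias, initialized to 0
            // Normalize each sample across the trailing dimension of size 4.
            NDList out = LayerNorm.layerNorm(input, new Shape(4), gamma, beta, 1e-5f);
            NDArray normalized = out.singletonOrThrow();           // same shape as input
        }
    }
}
```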
public static LayerNorm.Builder builder()

Creates a builder to build a LayerNorm.
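A hedged sketch of constructing the block through the builder and running a forward pass; it relies on the inherited initialize and forward methods listed above plus ai.djl.training.ParameterStore, and reuses a manager and input like those in the previous sketch. Only builder() and build() are documented on this page, so no other Builder setters are assumed:

```java
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.types.DataType;
import ai.djl.ndarray.types.Shape;
import ai.djl.nn.norm.LayerNorm;
import ai.djl.training.ParameterStore;

// Build with defaults, initialize for a (2, 4) input, then run a forward pass.
LayerNorm layerNorm = LayerNorm.builder().build();
layerNorm.initialize(manager, DataType.FLOAT32, new Shape(2, 4));
NDList output = layerNorm.forward(new ParameterStore(manager, false), new NDList(input), false);
```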
protected NDList forwardInternal(ParameterStore parameterStore, NDList inputs, boolean training, ai.djl.util.PairList<java.lang.String,java.lang.Object> params)

A helper for Block.forward(ParameterStore, NDList, boolean, PairList) after initialization.

Specified by:
    forwardInternal in class AbstractBlock

Parameters:
    parameterStore - the parameter store
    inputs - the input NDList
    training - true for a training forward pass
    params - optional parameters
public Shape[] getOutputShapes(Shape[] inputShapes)

Returns the expected output shapes of the block for the specified input shapes.

Parameters:
    inputShapes - the shapes of the inputs
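Because layer normalization returns an output with the same shape as its input (see layerNorm above), the expected output shapes simply mirror the input shapes. An illustrative check, using the layerNorm block from the earlier sketch:

```java
// For a LayerNorm block, the expected output shape equals the input shape.
Shape[] outputShapes = layerNorm.getOutputShapes(new Shape[] {new Shape(2, 4)});
// outputShapes[0] is (2, 4)
```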
protected void beforeInitialize(Shape... inputShapes)

Performs any action necessary before initialization.

Overrides:
    beforeInitialize in class AbstractBlock

Parameters:
    inputShapes - the expected shapes of the input
public void prepare(Shape[] inputShapes)

Sets the shape of Parameters.

Overrides:
    prepare in class AbstractBlock

Parameters:
    inputShapes - the shapes of inputs
protected void saveMetadata(java.io.DataOutputStream os) throws java.io.IOException

Override this method to save additional data apart from parameter values. This default implementation saves the currently set input shapes.

Overrides:
    saveMetadata in class AbstractBlock

Parameters:
    os - the non-null output stream the parameter values and metadata are written to

Throws:
    java.io.IOException - saving failed
public void loadMetadata(byte loadVersion, java.io.DataInputStream is) throws java.io.IOException, MalformedModelException

Overwrite this to load additional metadata with the parameter values. If you overwrite AbstractBlock.saveMetadata(DataOutputStream) or need to provide backward compatibility to older binary formats, you probably need to overwrite this. This default implementation checks whether the version number fits and throws a MalformedModelException if it does not. After that it restores the input shapes.

Overrides:
    loadMetadata in class AbstractBlock

Parameters:
    loadVersion - the version used for loading this metadata
    is - the input stream we are loading from

Throws:
    java.io.IOException - loading failed
    MalformedModelException - data can be loaded but has wrong format
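To make the save/load contract concrete, here is a speculative sketch of a custom AbstractBlock subclass persisting one extra field alongside the parameter values. The field customEps is hypothetical; saveInputShapes, readInputShapes, and the version field are the inherited members listed earlier on this page, and the needed imports are java.io.DataInputStream, java.io.DataOutputStream, java.io.IOException, and ai.djl.MalformedModelException:

```java
private float customEps = 1e-5f; // hypothetical extra state to persist

@Override
protected void saveMetadata(DataOutputStream os) throws IOException {
    saveInputShapes(os);      // mirror the default behavior: store the input shapes
    os.writeFloat(customEps); // then append block-specific metadata
}

@Override
public void loadMetadata(byte loadVersion, DataInputStream is)
        throws IOException, MalformedModelException {
    if (loadVersion != version) { // version is the field inherited from AbstractBlock
        throw new MalformedModelException("Unsupported encoding version: " + loadVersion);
    }
    readInputShapes(is);          // restore the input shapes first
    customEps = is.readFloat();   // then read metadata back in the order it was written
}
```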