Package ai.djl.nn.recurrent
Class GRU
- java.lang.Object
  - ai.djl.nn.AbstractBlock
    - ai.djl.nn.recurrent.RecurrentBlock
      - ai.djl.nn.recurrent.GRU
- All Implemented Interfaces:
Block
public class GRU extends RecurrentBlock
GRU is an implementation of a recurrent neural network block that applies a GRU (Gated Recurrent Unit) recurrent layer to the input. The current implementation refers to the [paper](http://arxiv.org/abs/1406.1078) on the Gated Recurrent Unit. The definition of GRU used here differs slightly from the one in the paper but is compatible with cuDNN.
The GRU operator is formulated as below:
$$ \begin{split}\begin{array}{ll} r_t = \mathrm{sigmoid}(W_{ir} x_t + b_{ir} + W_{hr} h_{(t-1)} + b_{hr}) \\ z_t = \mathrm{sigmoid}(W_{iz} x_t + b_{iz} + W_{hz} h_{(t-1)} + b_{hz}) \\ n_t = \tanh(W_{in} x_t + b_{in} + r_t * (W_{hn} h_{(t-1)}+ b_{hn})) \\ h_t = (1 - z_t) * n_t + z_t * h_{(t-1)} \\ \end{array}\end{split} $$
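The four equations above can be sketched as a single time step in plain Java. This is an illustrative sketch only, not the DJL implementation: the class name `GruCellSketch`, its helper methods, and the gate-indexed weight layout (0 = r, 1 = z, 2 = n) are assumptions made for this example.

```java
/** Hypothetical, dependency-free sketch of one GRU time step (not DJL code). */
final class GruCellSketch {

    private GruCellSketch() {}

    /** Logistic sigmoid. */
    public static double sigmoid(double v) {
        return 1.0 / (1.0 + Math.exp(-v));
    }

    /** Computes W x + b, where w[i] is row i of W. */
    public static double[] affine(double[][] w, double[] x, double[] b) {
        double[] out = new double[w.length];
        for (int i = 0; i < w.length; i++) {
            double sum = b[i];
            for (int j = 0; j < x.length; j++) {
                sum += w[i][j] * x[j];
            }
            out[i] = sum;
        }
        return out;
    }

    /**
     * Computes h_t from x_t and h_(t-1).
     * wi/bi are the input-side weights and biases, wh/bh the hidden-side ones,
     * each indexed by gate: 0 = r (reset), 1 = z (update), 2 = n (candidate).
     */
    public static double[] step(double[] x, double[] hPrev,
            double[][][] wi, double[][] bi, double[][][] wh, double[][] bh) {
        double[] inR = affine(wi[0], x, bi[0]);
        double[] hidR = affine(wh[0], hPrev, bh[0]);
        double[] inZ = affine(wi[1], x, bi[1]);
        double[] hidZ = affine(wh[1], hPrev, bh[1]);
        double[] inN = affine(wi[2], x, bi[2]);
        double[] hidN = affine(wh[2], hPrev, bh[2]);

        double[] h = new double[hPrev.length];
        for (int k = 0; k < h.length; k++) {
            double r = sigmoid(inR[k] + hidR[k]);
            double z = sigmoid(inZ[k] + hidZ[k]);
            // The reset gate scales the hidden-side term, bias included.
            double n = Math.tanh(inN[k] + r * hidN[k]);
            h[k] = (1.0 - z) * n + z * hPrev[k];
        }
        return h;
    }
}
```

With all weights and biases set to zero, both gates evaluate to 0.5 and the candidate state to 0, so the step returns half of the previous hidden state. Note that the reset gate multiplies the hidden-side term after its bias is added, as in the formula for n_t above.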
-
-
Nested Class Summary

| Modifier and Type | Class | Description |
| --- | --- | --- |
| static class | GRU.Builder | |
-
Nested classes/interfaces inherited from class ai.djl.nn.recurrent.RecurrentBlock
RecurrentBlock.BaseBuilder<T extends RecurrentBlock.BaseBuilder>
-
-
Field Summary
-
Fields inherited from class ai.djl.nn.recurrent.RecurrentBlock
batchFirst, bidirectional, dropRate, gates, hasBiases, numLayers, returnState, stateSize
-
Fields inherited from class ai.djl.nn.AbstractBlock
children, inputNames, inputShapes, parameters, version
-
-
Method Summary

| Modifier and Type | Method | Description |
| --- | --- | --- |
| static GRU.Builder | builder() | Creates a builder to build a GRU. |
| protected NDList | forwardInternal(ParameterStore parameterStore, NDList inputs, boolean training, ai.djl.util.PairList<java.lang.String,java.lang.Object> params) | A helper for Block.forward(ParameterStore, NDList, boolean, PairList) after initialization. |

-
Methods inherited from class ai.djl.nn.recurrent.RecurrentBlock
beforeInitialize, getNumDirections, getOutputShapes, loadMetadata, prepare
-
Methods inherited from class ai.djl.nn.AbstractBlock
addChildBlock, addParameter, cast, clear, describeInput, forward, forward, forwardInternal, getChildren, getDirectParameters, getParameters, initialize, initializeChildBlocks, isInitialized, loadParameters, readInputShapes, saveInputShapes, saveMetadata, saveParameters, setInitializer, setInitializer, setInitializer, toString
-
-
-
-
Method Detail
-
forwardInternal
protected NDList forwardInternal(ParameterStore parameterStore, NDList inputs, boolean training, ai.djl.util.PairList<java.lang.String,java.lang.Object> params)
A helper for Block.forward(ParameterStore, NDList, boolean, PairList) after initialization.
- Specified by:
forwardInternal in class AbstractBlock
- Parameters:
parameterStore - the parameter store
inputs - the input NDList
training - true for a training forward pass
params - optional parameters
- Returns:
the output of the forward pass
-
builder
public static GRU.Builder builder()
Creates a builder to build a GRU.
- Returns:
a new builder