public class GRU extends RecurrentBlock
GRU is an implementation of a recurrent neural network that applies a GRU (Gated Recurrent Unit) recurrent layer to its input.
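The current implementation follows the [paper](http://arxiv.org/abs/1406.1078) that introduced the Gated Recurrent Unit. The definition of GRU here differs slightly from the paper but is compatible with cuDNN. The reset gate is applied to the output of the recurrent weight multiplication (including its bias). In the paper, it is applied to the previous hidden state before that multiplication.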
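The two orderings are easiest to compare on a small example. The following plain-Java sketch is not DJL code. It computes the candidate state both ways and omits biases for brevity.

In the paper's ordering, the reset gate scales the previous hidden state before the recurrent weight matrix is applied. In the cuDNN-compatible ordering, the gate scales the product of the matrix and the hidden state. Element-wise gating does not commute with matrix multiplication, so the two results generally differ. The class and method names are hypothetical.

```java
import java.util.Arrays;

/** Hypothetical sketch: two orderings of the reset gate in the GRU candidate state. */
public class GruCandidateVariants {

    // Returns W v for an (H x H) matrix W and a length-H vector v.
    static double[] matVec(double[][] w, double[] v) {
        double[] y = new double[w.length];
        for (int i = 0; i < w.length; i++) {
            for (int j = 0; j < v.length; j++) {
                y[i] += w[i][j] * v[j];
            }
        }
        return y;
    }

    // Paper ordering: tanh(wx + W_hn (r * hPrev)); the gate is applied before the matrix.
    static double[] paperCandidate(double[] wx, double[][] wHn, double[] r, double[] hPrev) {
        double[] gated = new double[hPrev.length];
        for (int i = 0; i < hPrev.length; i++) {
            gated[i] = r[i] * hPrev[i];
        }
        double[] rec = matVec(wHn, gated);
        double[] n = new double[wx.length];
        for (int i = 0; i < n.length; i++) {
            n[i] = Math.tanh(wx[i] + rec[i]);
        }
        return n;
    }

    // cuDNN ordering: tanh(wx + r * (W_hn hPrev)); the gate is applied after the matrix.
    static double[] cudnnCandidate(double[] wx, double[][] wHn, double[] r, double[] hPrev) {
        double[] rec = matVec(wHn, hPrev);
        double[] n = new double[wx.length];
        for (int i = 0; i < n.length; i++) {
            n[i] = Math.tanh(wx[i] + r[i] * rec[i]);
        }
        return n;
    }

    public static void main(String[] args) {
        double[] wx = {0.0, 0.0};                  // input contribution W_in x_t, zero here
        double[][] wHn = {{0.0, 1.0}, {1.0, 0.0}}; // swaps the two hidden units
        double[] r = {1.0, 0.0};                   // keep unit 0, reset unit 1
        double[] hPrev = {0.5, 0.8};
        System.out.println("paper: " + Arrays.toString(paperCandidate(wx, wHn, r, hPrev)));
        System.out.println("cuDNN: " + Arrays.toString(cudnnCandidate(wx, wHn, r, hPrev)));
    }
}
```

In this example, the matrix swaps the two hidden units and the reset gate keeps only unit 0. With the paper ordering, the nonzero candidate ends up in unit 1 (tanh(0.5)). With the cuDNN ordering, it ends up in unit 0 (tanh(0.8)).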
The GRU operator is formulated as follows, where * denotes element-wise multiplication:
$$ \begin{split}\begin{array}{ll} r_t = \mathrm{sigmoid}(W_{ir} x_t + b_{ir} + W_{hr} h_{(t-1)} + b_{hr}) \\ z_t = \mathrm{sigmoid}(W_{iz} x_t + b_{iz} + W_{hz} h_{(t-1)} + b_{hz}) \\ n_t = \tanh(W_{in} x_t + b_{in} + r_t * (W_{hn} h_{(t-1)}+ b_{hn})) \\ h_t = (1 - z_t) * n_t + z_t * h_{(t-1)} \\ \end{array}\end{split} $$
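The following minimal plain-Java sketch computes one GRU time step by evaluating the four formulas above literally, for a single example with no batching. It illustrates the math only and is not DJL's implementation. The GruCell class, its step method, and its gate-indexed weight layout (index 0 for r, 1 for z, 2 for n) are hypothetical.

```java
import java.util.Arrays;

/** Hypothetical sketch of one GRU step: (x_t, h_{t-1}) -> h_t. */
public class GruCell {
    static final int R = 0, Z = 1, N = 2; // gate order used by this sketch only

    final double[][][] wI; // wI[g]: input weights W_i{r,z,n}, shape H x I
    final double[][][] wH; // wH[g]: recurrent weights W_h{r,z,n}, shape H x H
    final double[][] bI;   // bI[g]: input biases b_i{r,z,n}, length H
    final double[][] bH;   // bH[g]: recurrent biases b_h{r,z,n}, length H

    GruCell(double[][][] wI, double[][][] wH, double[][] bI, double[][] bH) {
        this.wI = wI;
        this.wH = wH;
        this.bI = bI;
        this.bH = bH;
    }

    // Returns W v + b.
    static double[] affine(double[][] w, double[] v, double[] b) {
        double[] y = b.clone();
        for (int i = 0; i < w.length; i++) {
            for (int j = 0; j < v.length; j++) {
                y[i] += w[i][j] * v[j];
            }
        }
        return y;
    }

    static double sigmoid(double a) {
        return 1.0 / (1.0 + Math.exp(-a));
    }

    double[] step(double[] x, double[] hPrev) {
        double[][] xs = new double[3][], hs = new double[3][];
        for (int g = 0; g < 3; g++) {
            xs[g] = affine(wI[g], x, bI[g]);     // W_ig x_t + b_ig
            hs[g] = affine(wH[g], hPrev, bH[g]); // W_hg h_{t-1} + b_hg
        }
        double[] h = new double[hPrev.length];
        for (int k = 0; k < h.length; k++) {
            double r = sigmoid(xs[R][k] + hs[R][k]);
            double z = sigmoid(xs[Z][k] + hs[Z][k]);
            // The reset gate scales the whole recurrent term W_hn h_{t-1} + b_hn.
            double n = Math.tanh(xs[N][k] + r * hs[N][k]);
            h[k] = (1 - z) * n + z * hPrev[k];
        }
        return h;
    }

    public static void main(String[] args) {
        // Input size 1, hidden size 1: all weights 1, all biases 0.
        double[][][] w = {{{1.0}}, {{1.0}}, {{1.0}}};
        double[][] b = {{0.0}, {0.0}, {0.0}};
        GruCell cell = new GruCell(w, w, b, b);
        double[] h1 = cell.step(new double[] {1.0}, new double[] {0.0});
        System.out.println("h_1 = " + Arrays.toString(h1));
    }
}
```

The sketch keeps separate input and recurrent biases, mirroring the formulas. For the r and z gates, the two biases could be merged into one. For n they cannot, because b_{hn} sits inside the product with r_t.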
Modifier and Type | Class and Description |
---|---|
static class | GRU.Builder |

Nested classes inherited from RecurrentBlock: RecurrentBlock.BaseBuilder<T extends RecurrentBlock.BaseBuilder>

Fields inherited from RecurrentBlock: batchFirst, bidirectional, dropRate, gates, hasBiases, numLayers, returnState, stateSize

Fields inherited from AbstractBlock: children, inputNames, inputShapes, parameters, version

Modifier and Type | Method and Description |
---|---|
static GRU.Builder | builder(): Creates a builder to build a GRU. |
protected NDList | forwardInternal(ParameterStore parameterStore, NDList inputs, boolean training, ai.djl.util.PairList<java.lang.String,java.lang.Object> params): A helper for Block.forward(ParameterStore, NDList, boolean, PairList) after initialization. |
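
Per its summary, forwardInternal is the helper that Block.forward uses once the block is initialized. To illustrate what a forward pass through a single-layer, unidirectional GRU computes, the following plain-Java sketch unrolls a scalar GRU cell (input size 1, hidden size 1) over one sequence. It collects every hidden state as the output. This is a conceptual sketch, not DJL's implementation. The class, methods, and 12-element parameter packing are hypothetical, and the initial state h0 is passed in explicitly.

```java
import java.util.Arrays;

/** Hypothetical sketch: unrolling a scalar GRU cell over a sequence. */
public class GruSequence {

    static double sigmoid(double a) {
        return 1.0 / (1.0 + Math.exp(-a));
    }

    // Scalar GRU step. Hypothetical packing used only here:
    // p = {w_ir, w_iz, w_in, w_hr, w_hz, w_hn, b_ir, b_iz, b_in, b_hr, b_hz, b_hn}.
    static double step(double[] p, double x, double hPrev) {
        double r = sigmoid(p[0] * x + p[6] + p[3] * hPrev + p[9]);
        double z = sigmoid(p[1] * x + p[7] + p[4] * hPrev + p[10]);
        double n = Math.tanh(p[2] * x + p[8] + r * (p[5] * hPrev + p[11]));
        return (1 - z) * n + z * hPrev;
    }

    // Runs the cell over x_1..x_T starting from h0 and returns h_1..h_T.
    // The last entry is the final hidden state.
    static double[] unroll(double[] p, double[] xs, double h0) {
        double[] out = new double[xs.length];
        double h = h0;
        for (int t = 0; t < xs.length; t++) {
            h = step(p, xs[t], h);
            out[t] = h;
        }
        return out;
    }

    public static void main(String[] args) {
        double[] p = {1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0};
        double[] outputs = unroll(p, new double[] {1.0, -0.5, 0.25}, 0.0);
        System.out.println("outputs     = " + Arrays.toString(outputs));
        System.out.println("final state = " + outputs[outputs.length - 1]);
    }
}
```

In the usual formulation, stacking layers (the numLayers field) feeds each layer's output sequence to the next layer as input. A bidirectional layer also runs a second cell over the reversed sequence and concatenates the two outputs per time step.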

Methods inherited from RecurrentBlock: beforeInitialize, getNumDirections, getOutputShapes, loadMetadata, prepare

Methods inherited from AbstractBlock: addChildBlock, addParameter, cast, clear, describeInput, forward, forward, forwardInternal, getChildren, getDirectParameters, getParameters, initialize, initializeChildBlocks, isInitialized, loadParameters, readInputShapes, saveInputShapes, saveMetadata, saveParameters, setInitializer, setInitializer, setInitializer, toString

Methods inherited from java.lang.Object: clone, equals, finalize, getClass, hashCode, notify, notifyAll, wait, wait, wait

Methods inherited from interface Block: forward, validateLayout

protected NDList forwardInternal(ParameterStore parameterStore, NDList inputs, boolean training, ai.djl.util.PairList<java.lang.String,java.lang.Object> params)

A helper for Block.forward(ParameterStore, NDList, boolean, PairList) after initialization.

Specified by: forwardInternal in class AbstractBlock

Parameters:
- parameterStore - the parameter store
- inputs - the input NDList
- training - true for a training forward pass
- params - optional parameters

public static GRU.Builder builder()

Creates a builder to build a GRU.