Package ai.djl.nn.recurrent
Class LSTM
java.lang.Object
ai.djl.nn.AbstractBaseBlock
ai.djl.nn.AbstractBlock
ai.djl.nn.recurrent.RecurrentBlock
ai.djl.nn.recurrent.LSTM
- All Implemented Interfaces:
Block
LSTM is an implementation of recurrent neural networks that applies a Long Short-Term Memory (LSTM) recurrent layer to its input.
Reference paper: Long Short-Term Memory, Hochreiter and Schmidhuber, 1997. http://www.bioinf.jku.at/publications/older/2604.pdf
The LSTM operator is defined as follows. Here i_t, f_t and o_t are the input, forget and output gates, g_t is the candidate cell state, c_t and h_t are the cell and hidden states at time t, and * denotes element-wise multiplication:
$$ \begin{array}{ll} i_t = \mathrm{sigmoid}(W_{ii} x_t + b_{ii} + W_{hi} h_{(t-1)} + b_{hi}) \\ f_t = \mathrm{sigmoid}(W_{if} x_t + b_{if} + W_{hf} h_{(t-1)} + b_{hf}) \\ g_t = \tanh(W_{ig} x_t + b_{ig} + W_{hg} h_{(t-1)} + b_{hg}) \\ o_t = \mathrm{sigmoid}(W_{io} x_t + b_{io} + W_{ho} h_{(t-1)} + b_{ho}) \\ c_t = f_t * c_{(t-1)} + i_t * g_t \\ h_t = o_t * \tanh(c_t) \end{array} $$
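The equations above can be traced by hand with a minimal plain-Java sketch of a single LSTM time step. This is not DJL code: the class `LstmCellSketch`, its `step` method, and the weight layout (one matrix and one bias per gate, in the order i, f, g, o) are hypothetical and chosen only for illustration. With all weights and biases set to zero, every sigmoid gate evaluates to 0.5 and g_t = tanh(0) = 0, so c_t = 0.5 * c_{t-1}, which makes the example easy to check.

```java
/**
 * Hypothetical sketch of one LSTM time step (not part of DJL).
 * Assumed layout: wi[k], wh[k], bi[k], bh[k] belong to gate k in the order i, f, g, o.
 */
final class LstmCellSketch {

    static double sigmoid(double z) {
        return 1.0 / (1.0 + Math.exp(-z));
    }

    // Computes W x + b for a row-major matrix W of shape [rows][cols].
    static double[] affine(double[][] w, double[] x, double[] b) {
        double[] out = new double[w.length];
        for (int r = 0; r < w.length; r++) {
            double acc = b[r];
            for (int c = 0; c < x.length; c++) {
                acc += w[r][c] * x[c];
            }
            out[r] = acc;
        }
        return out;
    }

    /**
     * One step of the recurrence; returns {h_t, c_t}.
     * Shapes: wi[k] is [stateSize][inputSize], wh[k] is [stateSize][stateSize],
     * bi[k] and bh[k] are [stateSize].
     */
    static double[][] step(double[] x, double[] hPrev, double[] cPrev,
                           double[][][] wi, double[][][] wh, double[][] bi, double[][] bh) {
        int n = hPrev.length;
        // Pre-activations: W_i* x_t + b_i* + W_h* h_(t-1) + b_h* for each gate.
        double[][] pre = new double[4][n];
        for (int k = 0; k < 4; k++) {
            double[] fromInput = affine(wi[k], x, bi[k]);
            double[] fromHidden = affine(wh[k], hPrev, bh[k]);
            for (int j = 0; j < n; j++) {
                pre[k][j] = fromInput[j] + fromHidden[j];
            }
        }
        double[] h = new double[n];
        double[] c = new double[n];
        for (int j = 0; j < n; j++) {
            double i = sigmoid(pre[0][j]);
            double f = sigmoid(pre[1][j]);
            double g = Math.tanh(pre[2][j]);
            double o = sigmoid(pre[3][j]);
            c[j] = f * cPrev[j] + i * g;   // c_t = f_t * c_(t-1) + i_t * g_t
            h[j] = o * Math.tanh(c[j]);    // h_t = o_t * tanh(c_t)
        }
        return new double[][] {h, c};
    }

    public static void main(String[] args) {
        int inputSize = 2, stateSize = 1;
        // All-zero parameters: every gate is 0.5 and g_t is 0.
        double[][][] wi = new double[4][stateSize][inputSize];
        double[][][] wh = new double[4][stateSize][stateSize];
        double[][] bi = new double[4][stateSize];
        double[][] bh = new double[4][stateSize];
        double[][] hc = step(new double[] {1.0, -1.0}, new double[] {0.0}, new double[] {1.0},
                wi, wh, bi, bh);
        System.out.println("h_t = " + hc[0][0] + ", c_t = " + hc[1][0]);
    }
}
```

Keeping each gate's two bias vectors separate mirrors the b_i* and b_h* terms in the equations; they could be folded into one vector without changing the result.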
Nested Class Summary
Nested Classes
static final class LSTM.Builder
Nested classes/interfaces inherited from class ai.djl.nn.recurrent.RecurrentBlock
RecurrentBlock.BaseBuilder<T extends RecurrentBlock.BaseBuilder>
Field Summary
Fields inherited from class ai.djl.nn.recurrent.RecurrentBlock
batchFirst, bidirectional, dropRate, gates, hasBiases, numLayers, returnState, stateSize
Fields inherited from class ai.djl.nn.AbstractBlock
children, parameters
Fields inherited from class ai.djl.nn.AbstractBaseBlock
inputNames, inputShapes, outputDataTypes, version
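The inherited fields gates, stateSize, numLayers, bidirectional and hasBiases together determine how many trainable parameters the block holds. The sketch below counts them under a stated assumption: the common cuDNN-style layout with separate input-to-hidden and hidden-to-hidden weights and biases per layer and direction, with gates presumed to be 4 for an LSTM (one per gate i, f, g, o). `LstmParamCount` and its `inputSize` argument are hypothetical names, and DJL's own parameter layout is not confirmed here.

```java
/**
 * Hypothetical helper, not a DJL API: counts the trainable parameters of a
 * stacked LSTM, assuming separate input-to-hidden (W_i*, b_i*) and
 * hidden-to-hidden (W_h*, b_h*) tensors per layer and direction.
 */
final class LstmParamCount {

    static final int GATES = 4; // i, f, g, o

    static long count(int inputSize, int stateSize, int numLayers,
                      boolean bidirectional, boolean hasBiases) {
        int numDirections = bidirectional ? 2 : 1;
        long total = 0;
        for (int layer = 0; layer < numLayers; layer++) {
            // Layers above the first read the concatenated outputs of all directions.
            long layerInput = layer == 0 ? inputSize : (long) stateSize * numDirections;
            long perDirection = GATES * stateSize * layerInput   // input-to-hidden weights
                    + (long) GATES * stateSize * stateSize;      // hidden-to-hidden weights
            if (hasBiases) {
                perDirection += 2L * GATES * stateSize;          // b_i* and b_h*
            }
            total += perDirection * numDirections;
        }
        return total;
    }

    public static void main(String[] args) {
        // inputSize = 10, stateSize = 20, one unidirectional layer with biases
        System.out.println(count(10, 20, 1, false, true));
    }
}
```

For the values in `main` this prints 2560, that is 4*20*10 + 4*20*20 + 2*4*20.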
Method Summary
static LSTM.Builder builder()
Creates a builder to build an LSTM.
protected NDList forwardInternal(ParameterStore parameterStore, NDList inputs, boolean training, ai.djl.util.PairList<String, Object> params)
A helper for Block.forward(ParameterStore, NDList, boolean, PairList) after initialization.
Methods inherited from class ai.djl.nn.recurrent.RecurrentBlock
beforeInitialize, getNumDirections, getOutputShapes, loadMetadata, prepare
Methods inherited from class ai.djl.nn.AbstractBlock
addChildBlock, addChildBlock, addChildBlockSingleton, addParameter, getChildren, getDirectParameters
Methods inherited from class ai.djl.nn.AbstractBaseBlock
cast, clear, describeInput, forward, forward, forwardInternal, getInputShapes, getOutputDataTypes, getParameters, initialize, initializeChildBlocks, isInitialized, loadParameters, readInputShapes, saveInputShapes, saveMetadata, saveParameters, setInitializer, setInitializer, setInitializer, toString
Methods inherited from class java.lang.Object
clone, equals, finalize, getClass, hashCode, notify, notifyAll, wait, wait, wait
Methods inherited from interface ai.djl.nn.Block
forward, freezeParameters, freezeParameters, getOutputShapes
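getNumDirections and getOutputShapes are inherited from RecurrentBlock. As a rough guide to the shapes such a layer usually produces, the sketch below follows the convention of cuDNN-style RNN layers: the output keeps the input's batch and time axes (in the order selected by batchFirst) and has numDirections * stateSize features, and when returnState is set, the hidden and cell states each have shape (numLayers * numDirections, batch, stateSize). This is an assumption, not a transcription of DJL's getOutputShapes; `LstmShapes` and its parameters are hypothetical.

```java
import java.util.Arrays;

/**
 * Hypothetical sketch, not DJL's getOutputShapes: output and state shapes
 * under an assumed cuDNN-style RNN convention.
 */
final class LstmShapes {

    static int numDirections(boolean bidirectional) {
        return bidirectional ? 2 : 1;
    }

    /** Returns {output} or, when returnState is true, {output, hidden state, cell state}. */
    static long[][] outputShapes(long batch, long seqLen, int stateSize, int numLayers,
                                 boolean bidirectional, boolean batchFirst, boolean returnState) {
        long d = numDirections(bidirectional);
        // The output keeps the input's (batch, time) ordering.
        long[] output = batchFirst
                ? new long[] {batch, seqLen, d * stateSize}
                : new long[] {seqLen, batch, d * stateSize};
        if (!returnState) {
            return new long[][] {output};
        }
        // Hidden and cell states: one slice per (layer, direction) pair.
        long[] hidden = {numLayers * d, batch, stateSize};
        long[] cell = {numLayers * d, batch, stateSize};
        return new long[][] {output, hidden, cell};
    }

    public static void main(String[] args) {
        // batch 8, 5 time steps, stateSize 20, 2 layers, bidirectional, batch-first, with states
        for (long[] shape : outputShapes(8, 5, 20, 2, true, true, true)) {
            System.out.println(Arrays.toString(shape));
        }
    }
}
```

Under these assumptions `main` prints [8, 5, 40] followed by [4, 8, 20] twice.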
Method Details
forwardInternal
protected NDList forwardInternal(ParameterStore parameterStore, NDList inputs, boolean training, ai.djl.util.PairList<String, Object> params)
A helper for Block.forward(ParameterStore, NDList, boolean, PairList) after initialization.
- Specified by:
forwardInternal in class AbstractBaseBlock
- Parameters:
parameterStore - the parameter store
inputs - the input NDList
training - true for a training forward pass
params - optional parameters
- Returns:
the output of the forward pass
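Conceptually, a forward pass applies the LSTM recurrence once per time step, feeding h_(t-1) and c_(t-1) back into the next step. The sketch below unrolls a single-layer, unidirectional LSTM with input size 1 and state size 1 over a sequence, starting from zero states, and returns both the per-step outputs and the final states. It does not reproduce DJL's forwardInternal, which operates on NDList inputs through the engine. `LstmUnrollSketch` is hypothetical, and it folds each gate's two biases into one value, which is equivalent because the biases are simply added.

```java
import java.util.Arrays;

/**
 * Hypothetical sketch (not DJL's forwardInternal): a single-layer, unidirectional
 * LSTM with input size 1 and state size 1, unrolled over one sequence.
 * Gate parameters are ordered i, f, g, o; b[k] stands for b_i* + b_h*.
 */
final class LstmUnrollSketch {

    static double sigmoid(double z) {
        return 1.0 / (1.0 + Math.exp(-z));
    }

    /**
     * Runs the recurrence over xs starting from h_0 = c_0 = 0.
     * Returns {outputs h_1..h_T, {h_T, c_T}}.
     */
    static double[][] forward(double[] xs, double[] wx, double[] wh, double[] b) {
        double h = 0.0;
        double c = 0.0;
        double[] outputs = new double[xs.length];
        for (int t = 0; t < xs.length; t++) {
            double i = sigmoid(wx[0] * xs[t] + wh[0] * h + b[0]);
            double f = sigmoid(wx[1] * xs[t] + wh[1] * h + b[1]);
            double g = Math.tanh(wx[2] * xs[t] + wh[2] * h + b[2]);
            double o = sigmoid(wx[3] * xs[t] + wh[3] * h + b[3]);
            c = f * c + i * g;        // new cell state
            h = o * Math.tanh(c);     // new hidden state, also this step's output
            outputs[t] = h;
        }
        return new double[][] {outputs, {h, c}};
    }

    public static void main(String[] args) {
        // Arbitrary illustrative parameters.
        double[] wx = {0.5, 0.5, 1.0, 0.5};
        double[] wh = {0.1, 0.1, 0.1, 0.1};
        double[] b = {0.0, 1.0, 0.0, 0.0};
        double[][] result = forward(new double[] {1.0, 0.0, -1.0}, wx, wh, b);
        System.out.println("outputs = " + Arrays.toString(result[0]));
        System.out.println("final h = " + result[1][0] + ", final c = " + result[1][1]);
    }
}
```

The last entry of the outputs always equals the final hidden state; returning the states separately corresponds to the returnState option listed among the inherited fields.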
builder
static LSTM.Builder builder()
Creates a builder to build an LSTM.
- Returns:
a new builder