Package ai.djl.nn.recurrent
Class LSTM
java.lang.Object
  ai.djl.nn.AbstractBaseBlock
    ai.djl.nn.AbstractBlock
      ai.djl.nn.recurrent.RecurrentBlock
        ai.djl.nn.recurrent.LSTM
All Implemented Interfaces:
Block
public class LSTM extends RecurrentBlock
LSTM is an implementation of recurrent neural networks that applies a Long Short-Term Memory recurrent layer to the input.
Reference paper: Long Short-Term Memory, Hochreiter and Schmidhuber, 1997. http://www.bioinf.jku.at/publications/older/2604.pdf
The LSTM operator is formulated as below:
$$ \begin{split}\begin{array}{ll} i_t = \mathrm{sigmoid}(W_{ii} x_t + b_{ii} + W_{hi} h_{(t-1)} + b_{hi}) \\ f_t = \mathrm{sigmoid}(W_{if} x_t + b_{if} + W_{hf} h_{(t-1)} + b_{hf}) \\ g_t = \tanh(W_{ig} x_t + b_{ig} + W_{hg} h_{(t-1)} + b_{hg}) \\ o_t = \mathrm{sigmoid}(W_{io} x_t + b_{io} + W_{ho} h_{(t-1)} + b_{ho}) \\ c_t = f_t * c_{(t-1)} + i_t * g_t \\ h_t = o_t * \tanh(c_t) \end{array}\end{split} $$
Here x_t is the input at time t, h_t and c_t are the hidden state and cell state at time t, i_t, f_t, g_t and o_t are the input, forget, cell and output gates, and * denotes the element-wise product.
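The equations above can be traced by hand with a small plain-Java sketch of a single time step. This is an illustration only, not DJL's implementation: the class name LstmCellSketch, the gate ordering (i, f, g, o) and the array layout of the weights are assumptions made for this example, and no DJL classes are used.

```java
// Illustrative sketch only: a single LSTM time step on plain double arrays.
// The class name, gate order and weight layout are assumptions of this example,
// not DJL's internal parameter layout.
class LstmCellSketch {
    // Index of each gate in the leading dimension of the weight and bias arrays.
    static final int I = 0, F = 1, G = 2, O = 3;

    public static double sigmoid(double z) {
        return 1.0 / (1.0 + Math.exp(-z));
    }

    // Pre-activation of one gate: W_in x_t + b_in + W_hid h_(t-1) + b_hid.
    public static double[] affine(double[][] wIn, double[] bIn,
                                  double[][] wHid, double[] bHid,
                                  double[] x, double[] hPrev) {
        int hidden = bIn.length;
        double[] z = new double[hidden];
        for (int r = 0; r < hidden; r++) {
            double sum = bIn[r] + bHid[r];
            for (int k = 0; k < x.length; k++) {
                sum += wIn[r][k] * x[k];
            }
            for (int k = 0; k < hPrev.length; k++) {
                sum += wHid[r][k] * hPrev[k];
            }
            z[r] = sum;
        }
        return z;
    }

    // Computes one time step and returns {h_t, c_t}.
    // wIn:  [4][hidden][input], wHid: [4][hidden][hidden], bIn and bHid: [4][hidden].
    public static double[][] step(double[][][] wIn, double[][] bIn,
                                  double[][][] wHid, double[][] bHid,
                                  double[] x, double[] hPrev, double[] cPrev) {
        int hidden = cPrev.length;
        double[] zi = affine(wIn[I], bIn[I], wHid[I], bHid[I], x, hPrev);
        double[] zf = affine(wIn[F], bIn[F], wHid[F], bHid[F], x, hPrev);
        double[] zg = affine(wIn[G], bIn[G], wHid[G], bHid[G], x, hPrev);
        double[] zo = affine(wIn[O], bIn[O], wHid[O], bHid[O], x, hPrev);

        double[] h = new double[hidden];
        double[] c = new double[hidden];
        for (int r = 0; r < hidden; r++) {
            double i = sigmoid(zi[r]);       // input gate
            double f = sigmoid(zf[r]);       // forget gate
            double g = Math.tanh(zg[r]);     // candidate cell value
            double o = sigmoid(zo[r]);       // output gate
            c[r] = f * cPrev[r] + i * g;     // c_t = f_t * c_(t-1) + i_t * g_t
            h[r] = o * Math.tanh(c[r]);      // h_t = o_t * tanh(c_t)
        }
        return new double[][] {h, c};
    }
}
```

With all weights and biases set to zero, every sigmoid gate evaluates to 0.5 and g_t to 0, so the step reduces to c_t = 0.5 * c_(t-1) and h_t = 0.5 * tanh(c_t), which makes a convenient sanity check for the sketch.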
-
-
Nested Class Summary
Nested Classes
Modifier and Type    Class    Description
static class    LSTM.Builder
-
Nested classes/interfaces inherited from class ai.djl.nn.recurrent.RecurrentBlock
RecurrentBlock.BaseBuilder<T extends RecurrentBlock.BaseBuilder>
-
-
Field Summary
-
Fields inherited from class ai.djl.nn.recurrent.RecurrentBlock
batchFirst, bidirectional, dropRate, gates, hasBiases, numLayers, returnState, stateSize
-
Fields inherited from class ai.djl.nn.AbstractBlock
children, parameters
-
Fields inherited from class ai.djl.nn.AbstractBaseBlock
inputNames, inputShapes, outputDataTypes, version
-
-
Method Summary
Modifier and Type    Method    Description
static LSTM.Builder    builder()    Creates a builder to build an LSTM.
protected NDList    forwardInternal(ParameterStore parameterStore, NDList inputs, boolean training, ai.djl.util.PairList<java.lang.String,java.lang.Object> params)    A helper for Block.forward(ParameterStore, NDList, boolean, PairList) after initialization.
-
Methods inherited from class ai.djl.nn.recurrent.RecurrentBlock
beforeInitialize, getNumDirections, getOutputShapes, loadMetadata, prepare
-
Methods inherited from class ai.djl.nn.AbstractBlock
addChildBlock, addChildBlock, addChildBlockSingleton, addParameter, getChildren, getDirectParameters
-
Methods inherited from class ai.djl.nn.AbstractBaseBlock
cast, clear, describeInput, forward, forward, forwardInternal, getInputShapes, getOutputDataTypes, getParameters, initialize, initializeChildBlocks, isInitialized, loadParameters, readInputShapes, saveInputShapes, saveMetadata, saveParameters, setInitializer, setInitializer, setInitializer, toString
-
Methods inherited from class java.lang.Object
clone, equals, finalize, getClass, hashCode, notify, notifyAll, wait, wait, wait
-
Methods inherited from interface ai.djl.nn.Block
forward, freezeParameters, freezeParameters, getOutputShapes
-
-
-
-
Method Detail
-
forwardInternal
protected NDList forwardInternal(ParameterStore parameterStore, NDList inputs, boolean training, ai.djl.util.PairList<java.lang.String,java.lang.Object> params)
A helper for Block.forward(ParameterStore, NDList, boolean, PairList) after initialization.
Specified by:
forwardInternal in class AbstractBaseBlock
Parameters:
parameterStore - the parameter store
inputs - the input NDList
training - true for a training forward pass
params - optional parameters
Returns:
the output of the forward pass
-
builder
public static LSTM.Builder builder()
Creates a builder to build an LSTM.
Returns:
a new builder
-
-