public class LSTM extends RecurrentCell
Reference paper: Long Short-Term Memory (Hochreiter and Schmidhuber, 1997). http://www.bioinf.jku.at/publications/older/2604.pdf
$$ \begin{split}\begin{array}{ll}
i_t = \mathrm{sigmoid}(W_{ii} x_t + b_{ii} + W_{hi} h_{(t-1)} + b_{hi}) \\
f_t = \mathrm{sigmoid}(W_{if} x_t + b_{if} + W_{hf} h_{(t-1)} + b_{hf}) \\
g_t = \tanh(W_{ig} x_t + b_{ig} + W_{hg} h_{(t-1)} + b_{hg}) \\
o_t = \mathrm{sigmoid}(W_{io} x_t + b_{io} + W_{ho} h_{(t-1)} + b_{ho}) \\
c_t = f_t * c_{(t-1)} + i_t * g_t \\
h_t = o_t * \tanh(c_t)
\end{array}\end{split} $$
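To make the recurrence concrete, the following is a minimal sketch of one time step on plain `double` arrays. It is not the DJL implementation: the class `LstmStepSketch`, its `State` holder, and the parameter layout (one weight matrix and bias vector per gate, indexed in the order i, f, g, o) are assumptions made purely for illustration, and `*` in the equations is read as element-wise multiplication.

```java
/** Illustrative single-step LSTM on plain arrays; hypothetical helper, not part of DJL. */
public final class LstmStepSketch {

    /** Holds the new hidden state h_t and the new cell state c_t. */
    public static final class State {
        public final double[] h;
        public final double[] c;

        State(double[] h, double[] c) {
            this.h = h;
            this.c = c;
        }
    }

    static double sigmoid(double z) {
        return 1.0 / (1.0 + Math.exp(-z));
    }

    /** Computes W_x x + b_x + W_h h + b_h for a single gate. */
    static double[] affine(double[][] wx, double[] bx, double[][] wh, double[] bh,
                           double[] x, double[] h) {
        double[] out = new double[bx.length];
        for (int r = 0; r < out.length; r++) {
            double sum = bx[r] + bh[r];
            for (int k = 0; k < x.length; k++) {
                sum += wx[r][k] * x[k];
            }
            for (int k = 0; k < h.length; k++) {
                sum += wh[r][k] * h[k];
            }
            out[r] = sum;
        }
        return out;
    }

    /**
     * Applies the six equations once.
     * Assumed layout: wx[n], bx[n], wh[n], bh[n] belong to gate n, with n = 0..3 meaning i, f, g, o.
     */
    public static State step(double[][][] wx, double[][] bx, double[][][] wh, double[][] bh,
                             double[] x, double[] hPrev, double[] cPrev) {
        double[] i = affine(wx[0], bx[0], wh[0], bh[0], x, hPrev);
        double[] f = affine(wx[1], bx[1], wh[1], bh[1], x, hPrev);
        double[] g = affine(wx[2], bx[2], wh[2], bh[2], x, hPrev);
        double[] o = affine(wx[3], bx[3], wh[3], bh[3], x, hPrev);

        int n = hPrev.length;
        double[] c = new double[n];
        double[] h = new double[n];
        for (int j = 0; j < n; j++) {
            double it = sigmoid(i[j]);
            double ft = sigmoid(f[j]);
            double gt = Math.tanh(g[j]);
            double ot = sigmoid(o[j]);
            c[j] = ft * cPrev[j] + it * gt; // c_t = f_t * c_(t-1) + i_t * g_t
            h[j] = ot * Math.tanh(c[j]);    // h_t = o_t * tanh(c_t)
        }
        return new State(h, c);
    }

    public static void main(String[] args) {
        // One input feature, one hidden unit, all weights and biases zero.
        double[][][] wx = new double[4][1][1];
        double[][][] wh = new double[4][1][1];
        double[][] bx = new double[4][1];
        double[][] bh = new double[4][1];
        State s = step(wx, bx, wh, bh,
                new double[] {1.0}, new double[] {0.0}, new double[] {2.0});
        System.out.println("c_t = " + s.c[0]); // prints "c_t = 1.0"
        System.out.println("h_t = " + s.h[0]);
    }
}
```

With all parameters zero, every sigmoid gate evaluates to 0.5 and g_t to 0, so the cell state is simply halved (2.0 becomes 1.0) and h_t equals 0.5 * tanh(1.0).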
Modifier and Type | Class and Description |
---|---|
static class | LSTM.Builder |
RecurrentCell.BaseBuilder<T extends RecurrentCell.BaseBuilder>
dropRate, gates, mode, numDirections, numStackedLayers, parameters, stateOutputs, stateShape, stateSize, useSequenceLength
inputNames, inputShapes
Modifier and Type | Method and Description |
---|---|
static LSTM.Builder | builder() - Creates a builder to build an LSTM. |
NDList | forward(ParameterStore parameterStore, NDList inputs, ai.djl.util.PairList<java.lang.String,java.lang.Object> params) - Applies the operating function of the block once. |
void | loadParameters(NDManager manager, java.io.DataInputStream is) - Loads the parameters from the given input stream. |
protected NDList | opInputs(ParameterStore parameterStore, NDList inputs) |
void | saveParameters(java.io.DataOutputStream os) - Writes the parameters of the block to the given outputStream. |
beforeInitialize, getDirectParameters, getOutputShapes, getParameterShape, isBidirectional, updateInputLayoutToTNC, validateInputSize
getChildren, initialize, toString
cast, clear, describeInput, getParameters, isInitialized, readInputShapes, saveInputShapes, setInitializer, setInitializer
clone, equals, finalize, getClass, hashCode, notify, notifyAll, wait, wait, wait
forward, validateLayout
public NDList forward(ParameterStore parameterStore, NDList inputs, ai.djl.util.PairList<java.lang.String,java.lang.Object> params)
forward in interface Block
forward in class RecurrentCell
Parameters:
parameterStore - the parameter store
inputs - the input NDList
params - optional parameters

protected NDList opInputs(ParameterStore parameterStore, NDList inputs)
opInputs in class RecurrentCell

public void saveParameters(java.io.DataOutputStream os) throws java.io.IOException
saveParameters in interface Block
saveParameters in class RecurrentCell
Parameters:
os - the output stream to save the parameters to
Throws:
java.io.IOException - if an I/O error occurs

public void loadParameters(NDManager manager, java.io.DataInputStream is) throws java.io.IOException, MalformedModelException
loadParameters in interface Block
loadParameters in class RecurrentCell
Parameters:
manager - an NDManager to create the parameter arrays
is - the input stream that supplies the parameter values
Throws:
java.io.IOException - if an I/O error occurs
MalformedModelException - if the model file is corrupted or unsupported

public static LSTM.Builder builder()
Creates a builder to build an LSTM.
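
The saveParameters and loadParameters methods work on `java.io.DataOutputStream` and `java.io.DataInputStream`. The sketch below shows only the general save-then-load round trip over such streams using the JDK alone. The byte layout (a version byte, a length, then float values) and the class `ParamStreamSketch` are hypothetical and do not describe the format DJL actually writes; a plain `IOException` stands in for MalformedModelException.

```java
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.io.IOException;

/** Illustrative parameter round trip over data streams; hypothetical layout, not DJL's format. */
public final class ParamStreamSketch {

    /** Hypothetical format version written as the first byte. */
    static final byte VERSION = 1;

    /** Writes the version, the number of values, and then each value. */
    public static void save(DataOutputStream os, float[] values) throws IOException {
        os.writeByte(VERSION);
        os.writeInt(values.length);
        for (float v : values) {
            os.writeFloat(v);
        }
    }

    /** Reads back what save wrote; rejects unknown versions and negative lengths. */
    public static float[] load(DataInputStream is) throws IOException {
        byte version = is.readByte();
        if (version != VERSION) {
            // Stand-in for a "corrupted or unsupported" model file.
            throw new IOException("Unsupported version: " + version);
        }
        int n = is.readInt();
        if (n < 0) {
            throw new IOException("Negative parameter count: " + n);
        }
        float[] values = new float[n];
        for (int i = 0; i < n; i++) {
            values[i] = is.readFloat();
        }
        return values;
    }

    /** Saves to an in-memory buffer and loads the values back. */
    public static float[] roundTrip(float[] values) throws IOException {
        ByteArrayOutputStream buffer = new ByteArrayOutputStream();
        try (DataOutputStream os = new DataOutputStream(buffer)) {
            save(os, values);
        }
        try (DataInputStream is = new DataInputStream(new ByteArrayInputStream(buffer.toByteArray()))) {
            return load(is);
        }
    }
}
```

Checking a version marker before reading any values is one common way to fail early on a file the loader does not understand, which mirrors the two failure modes the method documentation lists (an I/O error versus a malformed model).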