public abstract class RecurrentBlock extends AbstractBlock
RecurrentBlock is an abstract implementation of recurrent neural networks.
Recurrent neural networks are neural networks with hidden states, which carry information from one step of a sequence to the next. They are widely used for natural language processing and for other tasks that involve sequential data.
This [article](http://karpathy.github.io/2015/05/21/rnn-effectiveness/) written by Andrej Karpathy provides a detailed explanation of recurrent neural networks.
Currently, vanilla RNN, LSTM and GRU are implemented, with both multi-layer and bidirectional support.
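To make the idea of a hidden state concrete, the following is a minimal sketch of a single-layer vanilla RNN in plain Java. It is an illustration only and does not reflect how RecurrentBlock is implemented internally; the class name VanillaRnnSketch, its methods, and the weights are all hypothetical. It assumes the common update rule h_t = tanh(W_ih x_t + W_hh h_(t-1) + b).

```java
import java.util.Arrays;

/** Illustrative only: a single-layer vanilla RNN cell in plain Java (hypothetical, not DJL code). */
final class VanillaRnnSketch {

    /** Computes h_t = tanh(wIh * x_t + wHh * h_{t-1} + bias) for one time step. */
    static double[] step(double[][] wIh, double[][] wHh, double[] bias, double[] x, double[] hPrev) {
        int stateSize = bias.length;
        double[] h = new double[stateSize];
        for (int i = 0; i < stateSize; i++) {
            double sum = bias[i];
            for (int j = 0; j < x.length; j++) {
                sum += wIh[i][j] * x[j];
            }
            for (int j = 0; j < hPrev.length; j++) {
                sum += wHh[i][j] * hPrev[j];
            }
            h[i] = Math.tanh(sum);
        }
        return h;
    }

    /** Runs the cell over a whole sequence, starting from beginState, and returns the final hidden state. */
    static double[] run(double[][] wIh, double[][] wHh, double[] bias, double[][] sequence, double[] beginState) {
        double[] h = beginState.clone();
        for (double[] x : sequence) {
            h = step(wIh, wHh, bias, x, h);
        }
        return h;
    }

    public static void main(String[] args) {
        // Arbitrary example weights: 2 input features, hidden state of size 2.
        double[][] wIh = {{0.5, -0.3}, {0.1, 0.8}};
        double[][] wHh = {{0.2, 0.0}, {-0.4, 0.6}};
        double[] bias = {0.0, 0.1};
        double[][] sequence = {{1.0, 0.0}, {0.0, 1.0}};
        double[] finalState = run(wIh, wHh, bias, sequence, new double[2]);
        System.out.println(Arrays.toString(finalState));
    }
}
```

Because each step feeds the previous hidden state back in, presenting the same inputs in a different order generally yields a different final state. That order sensitivity is what makes recurrent networks suitable for sequential data.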
Modifier and Type | Class and Description |
---|---|
static class | RecurrentBlock.BaseBuilder<T extends RecurrentBlock.BaseBuilder> The Builder to construct a RecurrentBlock type of Block. |
Modifier and Type | Field and Description |
---|---|
protected NDArray | beginState |
protected float | dropRate |
protected int | gates |
protected java.lang.String | mode |
protected int | numDirections |
protected int | numStackedLayers |
protected boolean | stateOutputs |
protected long | stateSize |
protected boolean | useSequenceLength |
Fields inherited from class AbstractBlock: children, inputNames, inputShapes, parameters, parameterShapeCallbacks, version
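The fields gates, mode, stateSize, numStackedLayers and numDirections together determine how many trainable values a recurrent network holds. The sketch below shows one common way to count them. It is not taken from the RecurrentBlock source: the class RecurrentParamSketch, the mode strings ("rnn_tanh", "rnn_relu", "gru", "lstm"), and the use of two bias vectors per gate block (a cuDNN-style layout) are assumptions made for illustration.

```java
/** Illustrative only: parameter counting for stacked recurrent layers (hypothetical, not DJL code). */
final class RecurrentParamSketch {

    /** Number of gate blocks per cell type: 1 for a vanilla RNN, 3 for GRU, 4 for LSTM. */
    static int gatesFor(String mode) {
        switch (mode) {
            case "rnn_tanh":
            case "rnn_relu":
                return 1;
            case "gru":
                return 3;
            case "lstm":
                return 4;
            default:
                throw new IllegalArgumentException("Unknown mode: " + mode);
        }
    }

    /**
     * Counts weights and biases, assuming each direction of each layer holds an
     * input-to-hidden matrix, a hidden-to-hidden matrix, and two bias vectors.
     */
    static long parameterCount(
            String mode, long inputSize, long stateSize, int numStackedLayers, int numDirections) {
        long gates = gatesFor(mode);
        long total = 0;
        for (int layer = 0; layer < numStackedLayers; layer++) {
            // Layers above the first consume the concatenated outputs of all directions.
            long layerInput = layer == 0 ? inputSize : numDirections * stateSize;
            long weights = gates * stateSize * (layerInput + stateSize);
            long biases = 2 * gates * stateSize;
            total += numDirections * (weights + biases);
        }
        return total;
    }

    public static void main(String[] args) {
        // Single-layer, unidirectional LSTM with 10 input features and a state size of 20.
        System.out.println(parameterCount("lstm", 10, 20, 1, 1)); // prints 2560
    }
}
```

Under these assumptions, doubling numDirections doubles the parameters of the first layer and also widens the input of every layer stacked on top of it.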
Constructor and Description |
---|
RecurrentBlock(RecurrentBlock.BaseBuilder<?> builder) Creates a RecurrentBlock object. |
Modifier and Type | Method and Description |
---|---|
void | beforeInitialize(Shape[] inputs) Performs any action necessary before initialization. |
NDList | forward(ParameterStore parameterStore, NDList inputs, boolean training, ai.djl.util.PairList<java.lang.String,java.lang.Object> params) Applies the operating function of the block once. |
Shape[] | getOutputShapes(NDManager manager, Shape[] inputs) Returns the expected output shapes of the block for the specified input shapes. |
Shape | getParameterShape(java.lang.String name, Shape[] inputShapes) Returns the shape of the specified direct parameter of this block given the shapes of the input to the block. |
protected boolean | isBidirectional() |
void | loadMetadata(byte version, java.io.DataInputStream is) Override this to load additional metadata with the parameter values. |
protected NDList | opInputs(ParameterStore parameterStore, NDList inputs) |
protected void | resetBeginStates() |
void | setBeginStates(NDList beginStates) Sets the initial NDArray value for the hidden states. |
void | setStateOutputs(boolean stateOutputs) Sets the parameter that indicates whether the output must include the hidden states. |
protected NDList | updateInputLayoutToTNC(NDList inputs) |
protected void | validateInputSize(NDList inputs) |
Methods inherited from class AbstractBlock: addChildBlock, addParameter, addParameter, addParameter, cast, clear, describeInput, getChildren, getDirectParameters, getParameters, initialize, initializeChildBlocks, isInitialized, loadParameters, readInputShapes, saveInputShapes, saveMetadata, saveParameters, setInitializer, setInitializer, toString
Methods inherited from class java.lang.Object: clone, equals, finalize, getClass, hashCode, notify, notifyAll, wait, wait, wait
Methods inherited from interface Block: forward, forward, validateLayout
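As a rough illustration of what getOutputShapes and setStateOutputs imply for callers, the sketch below derives output shapes from an input shape in plain Java. It is not the RecurrentBlock implementation. It assumes a time-major input of shape (T, N, C), an output of shape (T, N, numDirections * stateSize), and, when hidden states are requested, an extra state of shape (numStackedLayers * numDirections, N, stateSize). The class OutputShapeSketch is hypothetical.

```java
import java.util.ArrayList;
import java.util.List;

/** Illustrative only: output shape arithmetic for a recurrent block (hypothetical, not DJL code). */
final class OutputShapeSketch {

    /**
     * Returns the sequence output shape, followed by the hidden state shape when
     * stateOutputs is true. The input is assumed to be time-major: (T, N, C).
     */
    static List<long[]> outputShapes(
            long[] inputTnc, long stateSize, int numStackedLayers, int numDirections, boolean stateOutputs) {
        if (inputTnc.length != 3) {
            throw new IllegalArgumentException("Expected a 3-dimensional (T, N, C) input shape");
        }
        long steps = inputTnc[0];
        long batch = inputTnc[1];
        List<long[]> shapes = new ArrayList<>();
        // Each direction contributes stateSize features per time step.
        shapes.add(new long[] {steps, batch, numDirections * stateSize});
        if (stateOutputs) {
            // One final hidden state per layer and direction.
            shapes.add(new long[] {(long) numStackedLayers * numDirections, batch, stateSize});
        }
        return shapes;
    }
}
```

For example, under these assumptions a two-layer bidirectional block with a state size of 20 maps an input of shape (5, 3, 10) to an output of shape (5, 3, 40), plus a state of shape (4, 3, 20) when state outputs are enabled.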
protected long stateSize
protected float dropRate
protected int numStackedLayers
protected java.lang.String mode
protected boolean useSequenceLength
protected int numDirections
protected int gates
protected boolean stateOutputs
protected NDArray beginState
public RecurrentBlock(RecurrentBlock.BaseBuilder<?> builder)
Creates a RecurrentBlock object.
Parameters:
builder - the Builder that has the necessary configurations
protected void validateInputSize(NDList inputs)
public final void setStateOutputs(boolean stateOutputs)
Sets the parameter that indicates whether the output must include the hidden states.
Parameters:
stateOutputs - whether the output must include the hidden states.
public NDList forward(ParameterStore parameterStore, NDList inputs, boolean training, ai.djl.util.PairList<java.lang.String,java.lang.Object> params)
Applies the operating function of the block once.
Parameters:
parameterStore - the parameter store
inputs - the input NDList
training - true for a training forward pass
params - optional parameters
public void setBeginStates(NDList beginStates)
Sets the initial NDArray value for the hidden states.
Parameters:
beginStates - the NDArray value for the hidden states
protected void resetBeginStates()
public Shape[] getOutputShapes(NDManager manager, Shape[] inputs)
Returns the expected output shapes of the block for the specified input shapes.
Parameters:
manager - an NDManager
inputs - the shapes of the inputs
public void beforeInitialize(Shape[] inputs)
Performs any action necessary before initialization.
Overrides:
beforeInitialize in class AbstractBlock
Parameters:
inputs - the expected shapes of the input
public Shape getParameterShape(java.lang.String name, Shape[] inputShapes)
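Returns the shape of the specified direct parameter of this block given the shapes of the input to the block.
Specified by:
getParameterShape in interface Block
Overrides:
getParameterShape in class AbstractBlock
Parameters:
name - the name of the parameter
inputShapes - the shapes of the input to the block
public void loadMetadata(byte version, java.io.DataInputStream is) throws java.io.IOException, MalformedModelException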
Override this to load additional metadata with the parameter values.
If you override AbstractBlock.saveMetadata(DataOutputStream) or need to provide backward compatibility with older binary formats, you probably need to override this method. The default implementation checks whether the version number matches and throws a MalformedModelException if it does not. It then restores the input shapes.
Overrides:
loadMetadata in class AbstractBlock
Parameters:
version - the version used for loading this metadata.
is - the input stream we are loading from
Throws:
java.io.IOException - loading failed
MalformedModelException - data can be loaded but has wrong format
protected boolean isBidirectional()
protected NDList opInputs(ParameterStore parameterStore, NDList inputs)
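The name updateInputLayoutToTNC suggests that inputs are rearranged into a time-major (T, N, C) layout before the recurrent operator runs. The sketch below shows what such a conversion amounts to, assuming the input arrives batch-major as (N, T, C). It uses plain Java arrays instead of NDList, and the class LayoutSketch and its method are hypothetical; the actual method may behave differently.

```java
/** Illustrative only: swapping the batch and time axes of a 3-D array (hypothetical, not DJL code). */
final class LayoutSketch {

    /** Converts a batch-major (N, T, C) array into a time-major (T, N, C) array. */
    static float[][][] batchMajorToTimeMajor(float[][][] ntc) {
        int batch = ntc.length;
        if (batch == 0) {
            return new float[0][][];
        }
        int steps = ntc[0].length;
        float[][][] tnc = new float[steps][batch][];
        for (int n = 0; n < batch; n++) {
            if (ntc[n].length != steps) {
                throw new IllegalArgumentException("All sequences must have the same length");
            }
            for (int t = 0; t < steps; t++) {
                // Copy the feature vector so the result does not alias the input.
                tnc[t][n] = ntc[n][t].clone();
            }
        }
        return tnc;
    }
}
```

With a time-major layout, element [t][n] holds the features of sequence n at step t, so a recurrent operator can process one time step for the whole batch at once.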