public abstract class RecurrentBlock extends AbstractBlock
RecurrentBlock is an abstract implementation of recurrent neural networks.
Recurrent neural networks are neural networks with hidden states. They are widely used for natural language processing and other tasks that involve sequential data.
This [article](http://karpathy.github.io/2015/05/21/rnn-effectiveness/) written by Andrej Karpathy provides a detailed explanation of recurrent neural networks.
Currently, vanilla RNN, LSTM and GRU are implemented, with both multi-layer and bidirectional support.
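To make "hidden state" concrete, here is a minimal plain-Java sketch of one vanilla (tanh) RNN layer, computing h_t = tanh(W_ih·x_t + b_ih + W_hh·h_{t-1} + b_hh) one time step at a time. It does not use DJL; the class name VanillaRnnCell and its methods are hypothetical and exist only for illustration.

```java
/** Hypothetical vanilla RNN cell: h_t = tanh(W_ih x_t + b_ih + W_hh h_{t-1} + b_hh). Illustration only. */
final class VanillaRnnCell {
    private final double[][] wIh; // input-to-hidden weights, shape [stateSize][inputSize]
    private final double[][] wHh; // hidden-to-hidden weights, shape [stateSize][stateSize]
    private final double[] bIh;   // input-to-hidden bias, length stateSize
    private final double[] bHh;   // hidden-to-hidden bias, length stateSize

    VanillaRnnCell(double[][] wIh, double[][] wHh, double[] bIh, double[] bHh) {
        this.wIh = wIh;
        this.wHh = wHh;
        this.bIh = bIh;
        this.bHh = bHh;
    }

    /** Computes the next hidden state from the current input x and the previous hidden state h. */
    double[] step(double[] x, double[] h) {
        int stateSize = bIh.length;
        double[] next = new double[stateSize];
        for (int i = 0; i < stateSize; i++) {
            double sum = bIh[i] + bHh[i];
            for (int j = 0; j < x.length; j++) {
                sum += wIh[i][j] * x[j];
            }
            for (int j = 0; j < h.length; j++) {
                sum += wHh[i][j] * h[j];
            }
            next[i] = Math.tanh(sum);
        }
        return next;
    }

    /** Runs the cell over a sequence starting from a zero state and returns the final hidden state. */
    double[] run(double[][] sequence) {
        double[] h = new double[bIh.length];
        for (double[] x : sequence) {
            h = step(x, h); // the hidden state carries information from earlier steps forward
        }
        return h;
    }
}
```

LSTM and GRU cells replace the single tanh update with 4 and 3 gated computations respectively, which is presumably what the `gates` field counts.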
Modifier and Type | Class and Description |
---|---|
static class | `RecurrentBlock.BaseBuilder<T extends RecurrentBlock.BaseBuilder>` The Builder to construct a RecurrentBlock type of Block. |
Modifier and Type | Field and Description |
---|---|
protected boolean | `batchFirst` |
protected boolean | `bidirectional` |
protected float | `dropRate` |
protected int | `gates` |
protected boolean | `hasBiases` |
protected int | `numLayers` |
protected boolean | `returnState` |
protected long | `stateSize` |
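The fields above are enough to size a block's weights. As a hedged illustration, assuming the common cuDNN-style layout (per layer and direction: an input-to-hidden matrix with gates * stateSize rows, a hidden-to-hidden matrix, and, when hasBiases is set, two bias vectors), this plain-Java helper counts the trainable parameters. The class RnnParamCount is hypothetical, and DJL's actual parameter names and packing may differ.

```java
/** Hypothetical parameter counter for a multi-layer, optionally bidirectional recurrent block. */
final class RnnParamCount {
    /**
     * Counts parameters assuming separate input-to-hidden and hidden-to-hidden weights and biases.
     * gates is 1 for a vanilla RNN, 3 for a GRU and 4 for an LSTM.
     */
    static long count(long inputSize, long stateSize, int numLayers, int gates,
                      boolean bidirectional, boolean hasBiases) {
        int dirs = bidirectional ? 2 : 1;
        long total = 0;
        for (int layer = 0; layer < numLayers; layer++) {
            // Layers after the first consume the concatenated outputs of all directions.
            long layerInput = layer == 0 ? inputSize : dirs * stateSize;
            // W_ih is [gates * stateSize, layerInput]; W_hh is [gates * stateSize, stateSize].
            long perDirection = gates * stateSize * (layerInput + stateSize);
            if (hasBiases) {
                perDirection += 2L * gates * stateSize; // b_ih and b_hh
            }
            total += dirs * perDirection;
        }
        return total;
    }
}
```

For example, under these assumptions a single-layer, unidirectional LSTM (gates = 4) with 10 input features and stateSize = 20 has 4·20·(10+20) + 2·4·20 = 2560 parameters.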
Fields inherited from class AbstractBlock
children, inputNames, inputShapes, parameters, version
Constructor and Description |
---|
`RecurrentBlock(RecurrentBlock.BaseBuilder<?> builder)` Creates a RecurrentBlock object. |
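The generic bound on `BaseBuilder<T extends RecurrentBlock.BaseBuilder>` follows the usual self-typed builder idiom: shared setters return T so that chained calls keep the concrete subclass type, and the block's constructor copies the finished configuration once. The plain-Java sketch below shows that idiom; every class and method name in it (SketchBuilder, SketchRecurrentBlock, SketchLstm, setStateSize, optBidirectional, and so on) is hypothetical and only echoes DJL's set/opt naming style, not its actual classes.

```java
/** Hypothetical self-typed builder holding settings shared by all recurrent blocks. */
abstract class SketchBuilder<T extends SketchBuilder<T>> {
    long stateSize = -1; // -1 marks "not set"
    int numLayers = -1;
    float dropRate;
    boolean bidirectional;

    T setStateSize(long stateSize) {
        this.stateSize = stateSize;
        return self();
    }

    T setNumLayers(int numLayers) {
        this.numLayers = numLayers;
        return self();
    }

    T optDropRate(float dropRate) {
        this.dropRate = dropRate;
        return self();
    }

    T optBidirectional(boolean bidirectional) {
        this.bidirectional = bidirectional;
        return self();
    }

    /** Returns this builder as its concrete subclass type so chained calls keep that type. */
    protected abstract T self();

    /** Rejects configurations whose required settings are missing. */
    void validate() {
        if (stateSize <= 0 || numLayers <= 0) {
            throw new IllegalArgumentException("stateSize and numLayers must be set to positive values");
        }
    }
}

/** Hypothetical block whose constructor copies its configuration from a builder. */
abstract class SketchRecurrentBlock {
    protected final long stateSize;
    protected final int numLayers;
    protected final float dropRate;
    protected final boolean bidirectional;
    protected final int gates;

    protected SketchRecurrentBlock(SketchBuilder<?> builder, int gates) {
        this.stateSize = builder.stateSize;
        this.numLayers = builder.numLayers;
        this.dropRate = builder.dropRate;
        this.bidirectional = builder.bidirectional;
        this.gates = gates;
    }
}

/** Hypothetical LSTM variant: an LSTM cell has 4 gates. */
final class SketchLstm extends SketchRecurrentBlock {
    private SketchLstm(Builder builder) {
        super(builder, 4);
    }

    static Builder builder() {
        return new Builder();
    }

    static final class Builder extends SketchBuilder<Builder> {
        @Override
        protected Builder self() {
            return this;
        }

        SketchLstm build() {
            validate();
            return new SketchLstm(this);
        }
    }
}
```

Because setStateSize returns Builder rather than SketchBuilder, build() stays reachable at the end of a chain such as `SketchLstm.builder().setStateSize(64).setNumLayers(2).optBidirectional(true).build()`.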
Modifier and Type | Method and Description |
---|---|
protected void | `beforeInitialize(Shape... inputShapes)` Performs any action necessary before initialization. |
protected int | `getNumDirections()` |
Shape[] | `getOutputShapes(Shape[] inputs)` Returns the expected output shapes of the block for the specified input shapes. |
void | `loadMetadata(byte version, java.io.DataInputStream is)` Override this to load additional metadata with the parameter values. |
void | `prepare(Shape[] inputs)` Sets the shape of Parameters. |
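getOutputShapes and getNumDirections are listed without formulas here. The following plain-Java sketch shows one plausible shape rule, assuming a PyTorch-style convention: the input is [batch, time, features] when batchFirst is set (otherwise [time, batch, features]), the output sequence has numDirections * stateSize features, and, when returnState is set, each state tensor has shape [numLayers * numDirections, batch, stateSize]. These conventions, the class RnnShapeSketch, and its numStates parameter (1 for RNN/GRU, 2 for an LSTM's hidden and cell states) are assumptions, not taken from this page.

```java
import java.util.ArrayList;
import java.util.List;

/** Hypothetical output-shape rule for a recurrent block (assumed PyTorch-style layout). */
final class RnnShapeSketch {
    /** A bidirectional block runs one pass forward and one backward over the sequence. */
    static int numDirections(boolean bidirectional) {
        return bidirectional ? 2 : 1;
    }

    /** Returns the output sequence shape, followed by numStates state shapes if returnState is set. */
    static List<long[]> outputShapes(long[] inputShape, long stateSize, int numLayers,
                                     boolean bidirectional, boolean batchFirst,
                                     boolean returnState, int numStates) {
        if (inputShape.length != 3) {
            throw new IllegalArgumentException("expected a 3-D input shape");
        }
        int dirs = numDirections(bidirectional);
        List<long[]> shapes = new ArrayList<>();
        // The batch and time axes keep their order; the feature axis becomes dirs * stateSize.
        shapes.add(new long[] {inputShape[0], inputShape[1], dirs * stateSize});
        if (returnState) {
            long batch = batchFirst ? inputShape[0] : inputShape[1];
            for (int i = 0; i < numStates; i++) {
                // One final state per layer and direction.
                shapes.add(new long[] {(long) numLayers * dirs, batch, stateSize});
            }
        }
        return shapes;
    }
}
```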
Methods inherited from class AbstractBlock
addChildBlock, addParameter, cast, clear, describeInput, forward, forward, forwardInternal, forwardInternal, getChildren, getDirectParameters, getParameters, initialize, initializeChildBlocks, isInitialized, loadParameters, readInputShapes, saveInputShapes, saveMetadata, saveParameters, setInitializer, setInitializer, setInitializer, toString
Methods inherited from class java.lang.Object
clone, equals, finalize, getClass, hashCode, notify, notifyAll, wait, wait, wait
Methods inherited from interface Block
forward, validateLayout
protected long stateSize
protected float dropRate
protected int numLayers
protected int gates
protected boolean batchFirst
protected boolean hasBiases
protected boolean bidirectional
protected boolean returnState
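public RecurrentBlock(RecurrentBlock.BaseBuilder<?> builder)
Creates a RecurrentBlock object.
Parameters:
builder - the Builder that has the necessary configurations

public Shape[] getOutputShapes(Shape[] inputs)
Returns the expected output shapes of the block for the specified input shapes.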
Parameters:
inputs - the shapes of the inputs

protected void beforeInitialize(Shape... inputShapes)
Performs any action necessary before initialization.
Overrides:
beforeInitialize in class AbstractBlock
Parameters:
inputShapes - the expected shapes of the input

public void prepare(Shape[] inputs)
Sets the shape of Parameters.
Overrides:
prepare in class AbstractBlock
Parameters:
inputs - the shapes of inputs

public void loadMetadata(byte version, java.io.DataInputStream is) throws java.io.IOException, MalformedModelException
Override this to load additional metadata with the parameter values.

If you override AbstractBlock.saveMetadata(DataOutputStream) or need to provide backward compatibility with older binary formats, you probably need to override this method as well. The default implementation checks whether the version number matches; if it does not, it throws a MalformedModelException. It then restores the input shapes.
Overrides:
loadMetadata in class AbstractBlock
Parameters:
version - the version used for loading this metadata
is - the input stream to load from
Throws:
java.io.IOException - if loading fails
MalformedModelException - if the data can be loaded but has the wrong format

protected int getNumDirections()
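The loadMetadata contract described above (reject unknown versions, stay compatible with older formats, then restore state) can be illustrated with plain java.io streams. This is a hypothetical sketch, not DJL's implementation: VersionedMetadata, its fields, the version numbers, and MalformedMetadataException (a stand-in for DJL's MalformedModelException) are all invented for illustration. As in the signature above, the caller reads the version byte and passes it in.

```java
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.io.IOException;

/** Hypothetical block metadata with a versioned binary format (illustration only). */
final class VersionedMetadata {
    static final byte CURRENT_VERSION = 2; // hypothetical: version 1 lacked the bidirectional flag

    /** Hypothetical stand-in for DJL's MalformedModelException. */
    static final class MalformedMetadataException extends Exception {
        MalformedMetadataException(String message) {
            super(message);
        }
    }

    long stateSize;
    int numLayers;
    boolean bidirectional;

    /** Writes the version byte followed by the fields of the current format. */
    void saveMetadata(DataOutputStream os) throws IOException {
        os.writeByte(CURRENT_VERSION);
        os.writeLong(stateSize);
        os.writeInt(numLayers);
        os.writeBoolean(bidirectional);
    }

    /** Reads the fields for an already-consumed version byte, rejecting unknown versions. */
    void loadMetadata(byte version, DataInputStream is) throws IOException, MalformedMetadataException {
        if (version != 1 && version != CURRENT_VERSION) {
            throw new MalformedMetadataException("Unsupported metadata version: " + version);
        }
        stateSize = is.readLong();
        numLayers = is.readInt();
        // Backward compatibility: version 1 data has no flag, so fall back to unidirectional.
        bidirectional = version >= 2 && is.readBoolean();
    }

    /** Saves a copy to an in-memory buffer and loads it back. */
    static VersionedMetadata roundTrip(VersionedMetadata original) {
        try {
            ByteArrayOutputStream buffer = new ByteArrayOutputStream();
            try (DataOutputStream os = new DataOutputStream(buffer)) {
                original.saveMetadata(os);
            }
            VersionedMetadata copy = new VersionedMetadata();
            try (DataInputStream is = new DataInputStream(new ByteArrayInputStream(buffer.toByteArray()))) {
                copy.loadMetadata(is.readByte(), is);
            }
            return copy;
        } catch (IOException | MalformedMetadataException e) {
            throw new IllegalStateException(e);
        }
    }
}
```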