public abstract class RecurrentBlock extends AbstractBlock
RecurrentBlock is an abstract implementation of recurrent neural networks.
Recurrent neural networks are neural networks with hidden states. They are widely used for natural language processing and other tasks that involve sequential data.
This [article](http://karpathy.github.io/2015/05/21/rnn-effectiveness/) written by Andrej Karpathy provides a detailed explanation of recurrent neural networks.
Currently, vanilla RNN, LSTM and GRU are implemented, with both multi-layer and bidirectional support.
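To make the idea of a hidden state concrete, the following is a minimal plain-Java sketch of a single-layer vanilla RNN. It is not this library's implementation: the class name `VanillaRnnSketch`, its methods, and the weight layout are all assumptions chosen for illustration. Each step computes h_t = tanh(W_x x_t + W_h h_{t-1} + b), and the same hidden state is carried from one element of the sequence to the next.

```java
import java.util.Arrays;

// Hypothetical illustration only; not part of the documented API.
class VanillaRnnSketch {

    // One time step: h_t = tanh(wx * x_t + wh * h_{t-1} + b).
    static double[] step(double[][] wx, double[][] wh, double[] b, double[] x, double[] hPrev) {
        int stateSize = b.length;
        double[] h = new double[stateSize];
        for (int i = 0; i < stateSize; i++) {
            double sum = b[i];
            for (int j = 0; j < x.length; j++) {
                sum += wx[i][j] * x[j];
            }
            for (int j = 0; j < hPrev.length; j++) {
                sum += wh[i][j] * hPrev[j];
            }
            h[i] = Math.tanh(sum);
        }
        return h;
    }

    // Runs the whole sequence, starting from an all-zero hidden state,
    // and returns the hidden state after the last element.
    static double[] run(double[][] wx, double[][] wh, double[] b, double[][] sequence) {
        double[] h = new double[b.length];
        for (double[] x : sequence) {
            h = step(wx, wh, b, x, h);
        }
        return h;
    }

    public static void main(String[] args) {
        double[][] wx = {{0.5, -0.3}}; // stateSize = 1, input size = 2
        double[][] wh = {{0.8}};
        double[] b = {0.1};
        double[][] sequence = {{1.0, 2.0}, {0.0, 1.0}, {-1.0, 0.5}};
        System.out.println(Arrays.toString(run(wx, wh, b, sequence)));
    }
}
```

LSTM and GRU cells replace the single `tanh` update with several gated updates, which is presumably what the `gates` field counts, but the principle of threading a state through the sequence is the same.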
Modifier and Type | Class and Description |
---|---|
static class | RecurrentBlock.BaseBuilder<T extends RecurrentBlock.BaseBuilder> The Builder to construct a RecurrentBlock type of Block. |
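A type parameter of the form `T extends BaseBuilder` usually signals a self-typed builder, where setters return the concrete builder type so that calls can be chained on subclasses. The sketch below shows that pattern in isolation. Every name in it (`CellSketch`, `GruSketch`, `setStateSize`, `setNumLayers`, `optBidirectional`, `self`) is hypothetical and is not taken from this page, and the bound is written as `BaseBuilder<T>` rather than the raw bound shown in the table.

```java
// Hypothetical illustration of a self-typed builder; not the documented API.
abstract class CellSketch {
    final long stateSize;
    final int numLayers;
    final boolean bidirectional;

    // The constructor copies the configuration out of the builder.
    CellSketch(BaseBuilder<?> builder) {
        this.stateSize = builder.stateSize;
        this.numLayers = builder.numLayers;
        this.bidirectional = builder.bidirectional;
    }

    abstract static class BaseBuilder<T extends BaseBuilder<T>> {
        long stateSize;
        int numLayers = 1;
        boolean bidirectional;

        T setStateSize(long stateSize) {
            this.stateSize = stateSize;
            return self();
        }

        T setNumLayers(int numLayers) {
            this.numLayers = numLayers;
            return self();
        }

        T optBidirectional(boolean bidirectional) {
            this.bidirectional = bidirectional;
            return self();
        }

        // Returns the concrete builder so chained calls keep their type.
        protected abstract T self();

        abstract CellSketch build();
    }
}

class GruSketch extends CellSketch {
    GruSketch(Builder builder) {
        super(builder);
    }

    static Builder builder() {
        return new Builder();
    }

    static final class Builder extends CellSketch.BaseBuilder<Builder> {
        @Override
        protected Builder self() {
            return this;
        }

        @Override
        GruSketch build() {
            return new GruSketch(this);
        }
    }
}
```

With this shape, `GruSketch.builder().setStateSize(16).setNumLayers(2).build()` returns a `GruSketch` without any casts, because each setter returns `Builder` rather than the abstract base type.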
Modifier and Type | Field and Description |
---|---|
protected boolean | batchFirst |
protected boolean | bidirectional |
protected float | dropRate |
protected int | gates |
protected boolean | hasBiases |
protected int | numLayers |
protected boolean | returnState |
protected long | stateSize |
Fields inherited from class AbstractBlock
children, inputNames, inputShapes, parameters, parameterShapeCallbacks, version
Constructor and Description |
---|
RecurrentBlock(RecurrentBlock.BaseBuilder<?> builder) Creates a RecurrentBlock object. |
Modifier and Type | Method and Description |
---|---|
void | beforeInitialize(Shape[] inputs) Performs any action necessary before initialization. |
protected int | getNumDirections() |
Shape[] | getOutputShapes(NDManager manager, Shape[] inputs) Returns the expected output shapes of the block for the specified input shapes. |
Shape | getParameterShape(java.lang.String name, Shape[] inputShapes) Returns the shape of the specified direct parameter of this block given the shapes of the input to the block. |
void | loadMetadata(byte version, java.io.DataInputStream is) Override this to load additional metadata with the parameter values. |
Methods inherited from class AbstractBlock
addChildBlock, addParameter, addParameter, addParameter, cast, clear, describeInput, forward, forwardInternal, getChildren, getDirectParameters, getParameters, initialize, initializeChildBlocks, isInitialized, loadParameters, readInputShapes, saveInputShapes, saveMetadata, saveParameters, setInitializer, setInitializer, toString
Methods inherited from class java.lang.Object
clone, equals, finalize, getClass, hashCode, notify, notifyAll, wait, wait, wait
Methods inherited from interface Block
forward, forward, validateLayout
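The `loadMetadata(byte version, DataInputStream is)` hook in the method summary receives a version byte alongside the stream. A versioned read of that kind might look like the sketch below, which uses only `java.io`. The class `MetadataVersionSketch`, the version constant, and the on-disk layout (version byte, rank, then dimensions) are assumptions made for illustration, and `IOException` stands in for the library's `MalformedModelException`.

```java
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.io.IOException;

// Hypothetical illustration of versioned metadata; not the documented API.
class MetadataVersionSketch {
    static final byte VERSION = 2; // assumed current format version

    // Writes the version byte, then the rank and dimensions of one input shape.
    static byte[] save(long[] inputShape) throws IOException {
        ByteArrayOutputStream bytes = new ByteArrayOutputStream();
        try (DataOutputStream os = new DataOutputStream(bytes)) {
            os.writeByte(VERSION);
            os.writeInt(inputShape.length);
            for (long dim : inputShape) {
                os.writeLong(dim);
            }
        }
        return bytes.toByteArray();
    }

    // Mirrors the (version, stream) parameter pair: reject unknown versions,
    // then restore the input shape from the remaining bytes.
    static long[] loadMetadata(byte version, DataInputStream is) throws IOException {
        if (version != VERSION) {
            // Stand-in for MalformedModelException in this sketch.
            throw new IOException("Unsupported metadata version: " + version);
        }
        int rank = is.readInt();
        long[] shape = new long[rank];
        for (int i = 0; i < rank; i++) {
            shape[i] = is.readLong();
        }
        return shape;
    }

    // Reads the version byte first, then delegates to loadMetadata.
    static long[] load(byte[] data) throws IOException {
        try (DataInputStream is = new DataInputStream(new ByteArrayInputStream(data))) {
            byte version = is.readByte();
            return loadMetadata(version, is);
        }
    }
}
```

Keeping the version check at the top of the load path means a subclass that adds extra fields can bump the version and still branch on older values to stay backward compatible.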
protected long stateSize
protected float dropRate
protected int numLayers
protected int gates
protected boolean batchFirst
protected boolean hasBiases
protected boolean bidirectional
protected boolean returnState
public RecurrentBlock(RecurrentBlock.BaseBuilder<?> builder)
Creates a RecurrentBlock object.
Parameters:
builder - the Builder that has the necessary configurations
public Shape[] getOutputShapes(NDManager manager, Shape[] inputs)
Returns the expected output shapes of the block for the specified input shapes.
Parameters:
manager - an NDManager
inputs - the shapes of the inputs
public void beforeInitialize(Shape[] inputs)
Performs any action necessary before initialization.
Overrides:
beforeInitialize in class AbstractBlock
Parameters:
inputs - the expected shapes of the input
public Shape getParameterShape(java.lang.String name, Shape[] inputShapes)
Returns the shape of the specified direct parameter of this block given the shapes of the input to the block.
Specified by:
getParameterShape in interface Block
Overrides:
getParameterShape in class AbstractBlock
Parameters:
name - the name of the parameter
inputShapes - the shapes of the input to the block
public void loadMetadata(byte version, java.io.DataInputStream is) throws java.io.IOException, MalformedModelException
Override this to load additional metadata with the parameter values.
If you override AbstractBlock.saveMetadata(DataOutputStream) or need to provide backward compatibility with older binary formats, you probably need to override this method. The default implementation checks whether the version number matches and, if it does not, throws a MalformedModelException. After that, it restores the input shapes.
Overrides:
loadMetadata in class AbstractBlock
Parameters:
version - the version used for loading this metadata
is - the input stream we are loading from
Throws:
java.io.IOException - loading failed
MalformedModelException - data can be loaded but has wrong format
protected int getNumDirections()
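This page does not spell out what `getNumDirections()` returns or what shapes `getOutputShapes` produces. The sketch below shows the convention commonly used by recurrent layers, under stated assumptions: two directions when `bidirectional` is set and one otherwise, an output whose last dimension is `stateSize` times the number of directions, and a state of shape (`numLayers` × directions, batch, `stateSize`). The class `RecurrentShapeSketch` and its method names are hypothetical, and shapes are plain `long[]` arrays rather than `Shape` objects.

```java
// Hypothetical illustration of common RNN shape arithmetic; not the documented API.
class RecurrentShapeSketch {

    // Assumed convention: a bidirectional block runs forward and backward passes.
    static int numDirections(boolean bidirectional) {
        return bidirectional ? 2 : 1;
    }

    // Output keeps the first two input dimensions (batch and time, in either
    // order) and replaces the feature dimension with stateSize * directions.
    static long[] outputShape(long[] input, long stateSize, boolean bidirectional) {
        return new long[] {input[0], input[1], stateSize * numDirections(bidirectional)};
    }

    // State holds one vector per layer and direction for every batch element.
    // The batch dimension is first when batchFirst is set, second otherwise.
    static long[] stateShape(
            long[] input, long stateSize, int numLayers, boolean bidirectional, boolean batchFirst) {
        long batch = batchFirst ? input[0] : input[1];
        return new long[] {(long) numLayers * numDirections(bidirectional), batch, stateSize};
    }
}
```

For example, under these assumptions a batch-first input of shape (32, 10, 8) passed through a two-layer bidirectional block with `stateSize` 16 would give an output of shape (32, 10, 32) and a state of shape (4, 32, 16).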