Class RecurrentBlock
- java.lang.Object
- ai.djl.nn.AbstractBaseBlock
- ai.djl.nn.AbstractBlock
- ai.djl.nn.recurrent.RecurrentBlock
- All Implemented Interfaces: Block
public abstract class RecurrentBlock extends AbstractBlock
RecurrentBlock is an abstract implementation of recurrent neural networks. Recurrent neural networks are neural networks with hidden states. They are very popular for natural language processing tasks and for other tasks that involve sequential data.
This [article](http://karpathy.github.io/2015/05/21/rnn-effectiveness/) written by Andrej Karpathy provides a detailed explanation of recurrent neural networks.
Currently, vanilla RNN, LSTM and GRU are implemented, with both multi-layer and bidirectional support.
-
-
Nested Class Summary
Nested Classes

| Modifier and Type | Class | Description |
| --- | --- | --- |
| static class | RecurrentBlock.BaseBuilder<T extends RecurrentBlock.BaseBuilder> | The Builder to construct a RecurrentBlock type of Block. |
-
Field Summary
Fields

| Modifier and Type | Field |
| --- | --- |
| protected boolean | batchFirst |
| protected boolean | bidirectional |
| protected float | dropRate |
| protected int | gates |
| protected boolean | hasBiases |
| protected int | numLayers |
| protected boolean | returnState |
| protected long | stateSize |
-
Fields inherited from class ai.djl.nn.AbstractBlock
children, parameters
-
Fields inherited from class ai.djl.nn.AbstractBaseBlock
inputNames, inputShapes, version
-
-
Constructor Summary
Constructors

| Constructor | Description |
| --- | --- |
| RecurrentBlock(RecurrentBlock.BaseBuilder<?> builder) | Creates a RecurrentBlock object. |
-
Method Summary
Methods

| Modifier and Type | Method | Description |
| --- | --- | --- |
| protected void | beforeInitialize(Shape... inputShapes) | Performs any action necessary before initialization. |
| protected int | getNumDirections() | |
| Shape[] | getOutputShapes(Shape[] inputs) | Returns the expected output shapes of the block for the specified input shapes. |
| void | loadMetadata(byte loadVersion, java.io.DataInputStream is) | Override this to load additional metadata with the parameter values. |
| void | prepare(Shape[] inputs) | Sets the shape of Parameters. |

-
Methods inherited from class ai.djl.nn.AbstractBlock
addChildBlock, addChildBlock, addChildBlockSingleton, addParameter, getChildren, getDirectParameters
-
Methods inherited from class ai.djl.nn.AbstractBaseBlock
cast, clear, describeInput, forward, forward, forwardInternal, forwardInternal, getInputShapes, getParameters, initialize, initializeChildBlocks, isInitialized, loadParameters, readInputShapes, saveInputShapes, saveMetadata, saveParameters, setInitializer, setInitializer, setInitializer, toString
-
Methods inherited from class java.lang.Object
clone, equals, finalize, getClass, hashCode, notify, notifyAll, wait, wait, wait
-
Methods inherited from interface ai.djl.nn.Block
forward, freezeParameters
-
-
-
-
Field Detail
-
stateSize
protected long stateSize
-
dropRate
protected float dropRate
-
numLayers
protected int numLayers
-
gates
protected int gates
-
batchFirst
protected boolean batchFirst
-
hasBiases
protected boolean hasBiases
-
bidirectional
protected boolean bidirectional
-
returnState
protected boolean returnState
-
-
Constructor Detail
-
RecurrentBlock
public RecurrentBlock(RecurrentBlock.BaseBuilder<?> builder)
Creates a RecurrentBlock object.
- Parameters:
builder - the Builder that has the necessary configurations
-
-
Method Detail
-
getOutputShapes
public Shape[] getOutputShapes(Shape[] inputs)
Returns the expected output shapes of the block for the specified input shapes.
- Parameters:
inputs - the shapes of the inputs
- Returns:
the expected output shapes of the block
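The shape arithmetic behind a method like getOutputShapes can be sketched in plain Java. This is a hedged illustration, not the library's implementation: it assumes a 3-D input, an output that keeps the first two input dimensions and ends in stateSize times the number of directions, and (when returnState is set) a state shaped (numLayers * directions, batch, stateSize), which is the common convention for recurrent layers. The class RecurrentShapeSketch and its methods are hypothetical, and long[] arrays stand in for Shape.

```java
import java.util.Arrays;

/** Illustrative only: plain long[] arrays stand in for DJL's Shape class. */
final class RecurrentShapeSketch {

    /** Assumed meaning of getNumDirections(): 2 when bidirectional, otherwise 1. */
    static int numDirections(boolean bidirectional) {
        return bidirectional ? 2 : 1;
    }

    /** Computes the assumed output shapes for a 3-D recurrent input. */
    static long[][] outputShapes(long[] input, long stateSize, int numLayers,
            boolean bidirectional, boolean batchFirst, boolean returnState) {
        if (input.length != 3) {
            throw new IllegalArgumentException("Expected a 3-D input shape");
        }
        int dirs = numDirections(bidirectional);
        // The first two dimensions pass through; the last one becomes the hidden width.
        long[] output = {input[0], input[1], stateSize * dirs};
        if (!returnState) {
            return new long[][] {output};
        }
        // The batch axis is dimension 0 when batchFirst is set, otherwise dimension 1.
        long batch = batchFirst ? input[0] : input[1];
        long[] state = {(long) numLayers * dirs, batch, stateSize};
        return new long[][] {output, state};
    }

    public static void main(String[] args) {
        long[][] shapes = outputShapes(new long[] {32, 10, 128}, 64, 2, true, true, true);
        System.out.println(Arrays.deepToString(shapes)); // [[32, 10, 128], [4, 32, 64]]
    }
}
```

With batchFirst set, the input above is read as (batch=32, time=10, features=128); a bidirectional block with stateSize 64 then yields 128 output features per time step.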
-
beforeInitialize
protected void beforeInitialize(Shape... inputShapes)
Performs any action necessary before initialization. For example, keep the input information or verify the layout.
- Overrides:
beforeInitialize in class AbstractBaseBlock
- Parameters:
inputShapes - the expected shapes of the input
-
prepare
public void prepare(Shape[] inputs)
Sets the shape of Parameters.
- Overrides:
prepare in class AbstractBaseBlock
- Parameters:
inputs - the shapes of inputs
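How the fields stateSize, numLayers, gates, hasBiases and bidirectional could determine parameter sizes is sketched below. This page does not document the actual parameter layout, so the sketch assumes the widely used convention: per layer and direction, an input-to-hidden weight of shape (gates * stateSize, inputSize), a hidden-to-hidden weight of shape (gates * stateSize, stateSize), and two bias vectors of length gates * stateSize. It also assumes gates is 1 for a vanilla RNN, 3 for GRU and 4 for LSTM. The class RecurrentParameterSketch is hypothetical.

```java
/** Illustrative only: counts parameters under an assumed, conventional layout. */
final class RecurrentParameterSketch {

    static long parameterCount(long inputSize, long stateSize, int numLayers,
            int gates, boolean bidirectional, boolean hasBiases) {
        int dirs = bidirectional ? 2 : 1;
        long total = 0;
        for (int layer = 0; layer < numLayers; layer++) {
            // Layers after the first consume the concatenated output of all directions.
            long layerInput = layer == 0 ? inputSize : stateSize * dirs;
            long rows = gates * stateSize;
            // Input-to-hidden weight plus hidden-to-hidden weight.
            long perDirection = rows * layerInput + rows * stateSize;
            if (hasBiases) {
                // Assumed: one bias vector for each of the two weight matrices.
                perDirection += 2 * rows;
            }
            total += perDirection * dirs;
        }
        return total;
    }

    public static void main(String[] args) {
        // Single-layer, unidirectional LSTM (gates = 4), 10 inputs, 20 hidden units.
        System.out.println(parameterCount(10, 20, 1, 4, false, true)); // 2560
    }
}
```

The same function gives 7296 for a two-layer bidirectional GRU (gates = 3) with 8 inputs and 16 hidden units, since the second layer sees 32 input features.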
-
loadMetadata
public void loadMetadata(byte loadVersion, java.io.DataInputStream is) throws java.io.IOException, MalformedModelException
Override this to load additional metadata with the parameter values. If you override AbstractBaseBlock.saveMetadata(DataOutputStream) or need to provide backward compatibility with older binary formats, you probably need to override this. This default implementation checks whether the version number matches; if it does not, it throws a MalformedModelException. After that it restores the input shapes.
- Overrides:
loadMetadata in class AbstractBaseBlock
- Parameters:
loadVersion - the version used for loading this metadata
is - the input stream we are loading from
- Throws:
java.io.IOException - loading failed
MalformedModelException - data can be loaded but has wrong format
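The check-then-restore pattern described for loadMetadata can be illustrated with the standard java.io streams alone. Everything specific here is an assumption made for the sketch: the version constant, the binary layout (a shape count, then a rank and the dimensions for each shape), and the MalformedMetadataException class, which only stands in for MalformedModelException. It does not reproduce the library's on-disk format.

```java
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.io.IOException;

/** Illustrative only: a version check followed by restoring input shapes. */
final class MetadataVersionSketch {

    /** Hypothetical version number, not taken from the library. */
    static final byte CURRENT_VERSION = 2;

    /** Stand-in for MalformedModelException. */
    static final class MalformedMetadataException extends Exception {
        MalformedMetadataException(String message) {
            super(message);
        }
    }

    /** Writes shapes in the hypothetical layout: count, then rank and dims per shape. */
    static void writeShapes(DataOutputStream os, long[][] shapes) throws IOException {
        os.writeInt(shapes.length);
        for (long[] shape : shapes) {
            os.writeInt(shape.length);
            for (long dim : shape) {
                os.writeLong(dim);
            }
        }
    }

    /** Rejects unknown versions first, then reads the shapes back. */
    static long[][] loadMetadata(byte loadVersion, DataInputStream is)
            throws IOException, MalformedMetadataException {
        if (loadVersion != CURRENT_VERSION) {
            throw new MalformedMetadataException("Unsupported version: " + loadVersion);
        }
        int count = is.readInt();
        if (count < 0) {
            throw new MalformedMetadataException("Negative shape count: " + count);
        }
        long[][] shapes = new long[count][];
        for (int i = 0; i < count; i++) {
            int rank = is.readInt();
            if (rank < 0) {
                throw new MalformedMetadataException("Negative rank: " + rank);
            }
            shapes[i] = new long[rank];
            for (int j = 0; j < rank; j++) {
                shapes[i][j] = is.readLong();
            }
        }
        return shapes;
    }

    public static void main(String[] args) throws Exception {
        ByteArrayOutputStream buffer = new ByteArrayOutputStream();
        writeShapes(new DataOutputStream(buffer), new long[][] {{32, 10, 128}});
        DataInputStream in = new DataInputStream(new ByteArrayInputStream(buffer.toByteArray()));
        long[][] restored = loadMetadata(CURRENT_VERSION, in);
        System.out.println(restored[0].length); // 3
    }
}
```

A subclass that needs backward compatibility would typically branch on loadVersion before rejecting it, reading the older layout where it is recognized.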
-
getNumDirections
protected int getNumDirections()
-
-