public abstract class RecurrentCell extends ParameterBlock
Modifier and Type | Class | Description |
---|---|---|
static class | RecurrentCell.BaseBuilder<T extends RecurrentCell.BaseBuilder> | The Builder to construct a RecurrentCell type of Block. |
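
The bound on RecurrentCell.BaseBuilder<T extends RecurrentCell.BaseBuilder> follows the self-typed builder idiom: shared setters return T, so a subclass builder keeps its own type through a chain of calls. The sketch below shows that idiom in plain Java. CellSketch, LstmSketch, setStateSize, setNumStackedLayers and self() are hypothetical names, not DJL API, and the sketch uses the parameterized bound BaseBuilder<T> where the documented class uses the raw type.

```java
/** Hypothetical base cell: stores the configuration copied from its builder. */
abstract class CellSketch {
    final long stateSize;
    final int numStackedLayers;

    protected CellSketch(BaseBuilder<?> builder) {
        this.stateSize = builder.stateSize;
        this.numStackedLayers = builder.numStackedLayers;
    }

    /** Shared builder; T is the concrete builder type, so setters can return it. */
    abstract static class BaseBuilder<T extends BaseBuilder<T>> {
        long stateSize;
        int numStackedLayers = 1;

        T setStateSize(long stateSize) {
            this.stateSize = stateSize;
            return self();
        }

        T setNumStackedLayers(int numStackedLayers) {
            this.numStackedLayers = numStackedLayers;
            return self();
        }

        /** Returns this builder typed as the concrete subclass. */
        protected abstract T self();
    }
}

/** Hypothetical concrete cell with its own builder. */
final class LstmSketch extends CellSketch {
    private LstmSketch(Builder builder) {
        super(builder);
    }

    static Builder builder() {
        return new Builder();
    }

    static final class Builder extends CellSketch.BaseBuilder<Builder> {
        @Override
        protected Builder self() {
            return this;
        }

        LstmSketch build() {
            return new LstmSketch(this);
        }
    }
}
```

With this shape, LstmSketch.builder().setStateSize(64).setNumStackedLayers(2).build() compiles, because each inherited setter returns LstmSketch.Builder rather than the base builder type.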
Modifier and Type | Field |
---|---|
protected float | dropRate |
protected int | gates |
protected java.lang.String | mode |
protected int | numDirections |
protected int | numStackedLayers |
protected java.util.List<Parameter> | parameters |
protected boolean | stateOutputs |
protected Shape | stateShape |
protected long | stateSize |
protected boolean | useSequenceLength |
Fields inherited from class AbstractBlock: inputNames, inputShapes
Constructor | Description |
---|---|
RecurrentCell(RecurrentCell.BaseBuilder<?> builder) | Creates a RecurrentCell object. |
Modifier and Type | Method | Description |
---|---|---|
void | beforeInitialize(Shape[] inputs) | Performs any action necessary before initialization. |
NDList | forward(ParameterStore parameterStore, NDList inputs, ai.djl.util.PairList<java.lang.String,java.lang.Object> params) | Applies the operating function of the block once. |
java.util.List<Parameter> | getDirectParameters() | Returns a list of all the direct parameters of the block. |
Shape[] | getOutputShapes(NDManager manager, Shape[] inputs) | Returns the expected output shapes of the block for the specified input shapes. |
Shape | getParameterShape(java.lang.String name, Shape[] inputShapes) | Returns the shape of the specified direct parameter of this block given the shapes of the input to the block. |
protected boolean | isBidirectional() | |
void | loadParameters(NDManager manager, java.io.DataInputStream is) | Loads the parameters from the given input stream. |
protected NDList | opInputs(ParameterStore parameterStore, NDList inputs) | |
void | saveParameters(java.io.DataOutputStream os) | Writes the parameters of the block to the given output stream. |
protected NDList | updateInputLayoutToTNC(NDList inputs) | |
protected void | validateInputSize(NDList inputs) | |
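
The name updateInputLayoutToTNC suggests a conversion from batch-major input (N, T, C: batch, time, channels) to the time-major layout (T, N, C) that fused RNN kernels commonly expect. The plain-Java sketch below shows only that index mapping on nested arrays; it is an assumption about what the method does, and LayoutSketch.ntcToTnc is a hypothetical name. On an NDArray the same change would be a transpose of the first two axes rather than an element-wise copy.

```java
/** Hypothetical illustration of an NTC -> TNC layout change on plain Java arrays. */
final class LayoutSketch {
    private LayoutSketch() {}

    /** Maps input[n][t][c] to output[t][n][c]; assumes a non-ragged array. */
    static float[][][] ntcToTnc(float[][][] input) {
        int batch = input.length;
        int time = batch == 0 ? 0 : input[0].length;
        int channels = time == 0 ? 0 : input[0][0].length;
        float[][][] output = new float[time][batch][channels];
        for (int n = 0; n < batch; n++) {
            for (int t = 0; t < time; t++) {
                // Copy the feature vector of sample n at step t into its time-major slot.
                System.arraycopy(input[n][t], 0, output[t][n], 0, channels);
            }
        }
        return output;
    }
}
```

After the conversion, output[t] holds the whole batch for time step t, which lets a recurrent kernel walk the sequence one step at a time.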
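Methods inherited from class ParameterBlock: getChildren, initialize, toString
Methods inherited from class AbstractBlock: cast, clear, describeInput, getParameters, isInitialized, readInputShapes, saveInputShapes, setInitializer, setInitializer
Methods inherited from class java.lang.Object: clone, equals, finalize, getClass, hashCode, notify, notifyAll, wait, wait, wait
Methods inherited from interface Block: forward, validateLayout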
protected long stateSize
protected float dropRate
protected int numStackedLayers
protected java.lang.String mode
protected boolean useSequenceLength
protected int numDirections
protected int gates
protected boolean stateOutputs
protected Shape stateShape
protected java.util.List<Parameter> parameters
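
The gates field presumably records how many gate blocks the cell's weight matrices hold, which differs by cell type. The sketch below maps a mode string to the standard gate count of each architecture. The mode strings ("lstm", "gru", "rnn_relu", "rnn_tanh", as used by MXNet's fused RNN operator) and the GateSketch class are assumptions, not values read from DJL.

```java
import java.util.Locale;

/** Hypothetical mapping from a recurrent mode string to its gate count. */
final class GateSketch {
    private GateSketch() {}

    static int gatesFor(String mode) {
        switch (mode.toLowerCase(Locale.ROOT)) {
            case "lstm":
                return 4; // input, forget, candidate cell, output
            case "gru":
                return 3; // reset, update, candidate
            case "rnn_relu":
            case "rnn_tanh":
                return 1; // one weight block, no gating
            default:
                throw new IllegalArgumentException("Unknown mode: " + mode);
        }
    }
}
```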
public RecurrentCell(RecurrentCell.BaseBuilder<?> builder)
Creates a RecurrentCell object.
Parameters:
builder - the Builder that has the necessary configurations
protected void validateInputSize(NDList inputs)
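
validateInputSize has no description here. One plausible reading, given the useSequenceLength field, is that the cell expects the data array alone, or the data array plus a sequence-length array when useSequenceLength is true. The sketch below encodes that assumption over a generic List; InputSizeSketch is hypothetical and is not taken from DJL source.

```java
import java.util.List;

/** Hypothetical input-count check keyed on whether sequence lengths are supplied. */
final class InputSizeSketch {
    private InputSizeSketch() {}

    static void validateInputSize(List<?> inputs, boolean useSequenceLength) {
        // Assumed contract: data only, or data plus a sequence-length array.
        int required = useSequenceLength ? 2 : 1;
        if (inputs.size() != required) {
            throw new IllegalArgumentException(
                    "Expected " + required + " input(s) but got " + inputs.size());
        }
    }
}
```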
public NDList forward(ParameterStore parameterStore, NDList inputs, ai.djl.util.PairList<java.lang.String,java.lang.Object> params)
Parameters:
parameterStore - the parameter store
inputs - the input NDList
params - optional parameters
public Shape[] getOutputShapes(NDManager manager, Shape[] inputs)
Parameters:
manager - an NDManager
inputs - the shapes of the inputs
public java.util.List<Parameter> getDirectParameters()
Returns: the list of Parameter objects
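
getOutputShapes maps input shapes to output shapes without running the cell. Assuming batch-major (N, T, C) input and the usual recurrent-layer conventions, the output sequence has stateSize * numDirections features per step (the directions are concatenated), and, when stateOutputs is set, the final hidden state has shape (numStackedLayers * numDirections, N, stateSize). The sketch below computes those shapes as long[] arrays; OutputShapeSketch and the exact layouts are assumptions, and an LSTM would additionally carry a cell state of the same shape as the hidden state.

```java
import java.util.ArrayList;
import java.util.List;

/** Hypothetical output-shape computation for a (possibly bidirectional, stacked) RNN. */
final class OutputShapeSketch {
    private OutputShapeSketch() {}

    /** inputShape is assumed to be {batch, time, features}. */
    static List<long[]> outputShapes(long[] inputShape, long stateSize, int numStackedLayers,
                                     int numDirections, boolean stateOutputs) {
        long batch = inputShape[0];
        long time = inputShape[1];
        List<long[]> shapes = new ArrayList<>();
        // Each time step concatenates the hidden vectors of all directions.
        shapes.add(new long[] {batch, time, stateSize * numDirections});
        if (stateOutputs) {
            // One final hidden state per stacked layer and direction.
            shapes.add(new long[] {(long) numStackedLayers * numDirections, batch, stateSize});
        }
        return shapes;
    }
}
```

For an input of shape (8, 20, 50) with stateSize 64, two stacked layers and two directions, the sketch yields (8, 20, 128) for the sequence output and (4, 8, 64) for the state.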
public void beforeInitialize(Shape[] inputs)
Overrides: beforeInitialize in class AbstractBlock
Parameters:
inputs - the expected shapes of the input
public Shape getParameterShape(java.lang.String name, Shape[] inputShapes)
Parameters:
name - the name of the parameter
inputShapes - the shapes of the input to the block
public void saveParameters(java.io.DataOutputStream os) throws java.io.IOException
Parameters:
os - the output stream to save the parameters to
Throws:
java.io.IOException - if an I/O error occurs
public void loadParameters(NDManager manager, java.io.DataInputStream is) throws java.io.IOException, MalformedModelException
Parameters:
manager - an NDManager to create the parameter arrays
is - the input stream that streams the parameter values
Throws:
java.io.IOException - if an I/O error occurs
MalformedModelException - if the model file is corrupted or unsupported
protected boolean isBidirectional()
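
saveParameters and loadParameters write to and read from java.io data streams, and loadParameters is documented to reject corrupted or unsupported input. The sketch below shows a versioned round trip of a flat float array with DataOutputStream and DataInputStream. The layout (a version byte, an element count, then the values) and the ParamIoSketch class are assumptions, not DJL's on-disk format; the documented method reports unsupported data with MalformedModelException, whereas the sketch throws a plain IOException to stay dependency-free.

```java
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.io.IOException;

/** Hypothetical versioned serialization of a flat parameter array. */
final class ParamIoSketch {
    static final byte VERSION = 1;

    private ParamIoSketch() {}

    static void save(float[] values, DataOutputStream os) throws IOException {
        os.writeByte(VERSION);       // lets a reader reject formats it does not know
        os.writeInt(values.length);  // element count precedes the data
        for (float v : values) {
            os.writeFloat(v);
        }
    }

    static float[] load(DataInputStream is) throws IOException {
        byte version = is.readByte();
        if (version != VERSION) {
            throw new IOException("Unsupported parameter version: " + version);
        }
        int n = is.readInt();
        if (n < 0) {
            throw new IOException("Negative element count: " + n);
        }
        float[] values = new float[n];
        for (int i = 0; i < n; i++) {
            values[i] = is.readFloat();
        }
        return values;
    }

    /** Writes the values to an in-memory buffer and reads them back. */
    static float[] roundTrip(float[] values) throws IOException {
        ByteArrayOutputStream bytes = new ByteArrayOutputStream();
        try (DataOutputStream os = new DataOutputStream(bytes)) {
            save(values, os);
        }
        try (DataInputStream is = new DataInputStream(new ByteArrayInputStream(bytes.toByteArray()))) {
            return load(is);
        }
    }
}
```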
protected NDList opInputs(ParameterStore parameterStore, NDList inputs)
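
getParameterShape returns the shape of a named direct parameter given the input shapes. Under the common fused-RNN convention, each layer and direction has an input-to-hidden weight, a hidden-to-hidden weight and two biases, with the gates stacked along the first axis, so every first dimension is gates * stateSize. The sketch below follows that convention; the parameter names i2h_weight, h2h_weight, i2h_bias and h2h_bias follow MXNet's naming and are assumptions here, as is the ParameterShapeSketch class.

```java
/** Hypothetical per-layer, per-direction parameter shapes for a gated RNN. */
final class ParameterShapeSketch {
    private ParameterShapeSketch() {}

    static long[] parameterShape(String name, int gates, long stateSize, long layerInputSize) {
        long rows = gates * stateSize; // all gates stacked along the first axis
        switch (name) {
            case "i2h_weight":
                return new long[] {rows, layerInputSize};
            case "h2h_weight":
                return new long[] {rows, stateSize};
            case "i2h_bias":
            case "h2h_bias":
                return new long[] {rows};
            default:
                throw new IllegalArgumentException("Unknown parameter: " + name);
        }
    }

    /** The first layer sees the input features; stacked layers see the concatenated directions. */
    static long layerInputSize(int layer, long featureSize, long stateSize, int numDirections) {
        return layer == 0 ? featureSize : stateSize * numDirections;
    }
}
```

For a two-layer bidirectional LSTM (4 gates) with 50 input features and 64 units, this gives an i2h_weight of (256, 50) for the first layer and (256, 128) for the second, while h2h_weight stays (256, 64) in both.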