public class EncoderDecoder extends AbstractBlock
Modifier and Type | Field and Description |
---|---|
protected Decoder | decoder |
protected Encoder | encoder |

Fields inherited from class AbstractBlock: inputNames, inputShapes
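The sketch below shows one way `describeInput()` could combine the inherited `inputNames` and `inputShapes` fields into (name, shape) pairs. It assumes the two are matched by position; this page lists both fields but does not say how they are combined. Plain `java.util` types stand in for DJL's: `long[]` replaces `Shape` and `Map.Entry` replaces `ai.djl.util.PairList`. `DescribeInputSketch` is a hypothetical name.

```java
import java.util.AbstractMap;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;

/** Hypothetical sketch of pairing input names with input shapes; not DJL's PairList. */
final class DescribeInputSketch {

    /** Zips parallel name and shape lists into ordered (name, shape) pairs. */
    static List<Map.Entry<String, long[]>> describeInput(List<String> inputNames, List<long[]> inputShapes) {
        // Positional pairing only makes sense if both lists have the same length.
        if (inputNames.size() != inputShapes.size()) {
            throw new IllegalStateException(
                    inputNames.size() + " names but " + inputShapes.size() + " shapes");
        }
        List<Map.Entry<String, long[]>> pairs = new ArrayList<>();
        for (int i = 0; i < inputNames.size(); i++) {
            pairs.add(new AbstractMap.SimpleImmutableEntry<>(inputNames.get(i), inputShapes.get(i)));
        }
        return pairs;
    }
}
```

Given the hypothetical names `encoderInput` and `decoderInput` with shapes {32, 10} and {32, 12}, the second pair is ("decoderInput", {32, 12}). If the two lists have different lengths, the method throws immediately instead of returning a partial description.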
Constructor and Description |
---|
EncoderDecoder(Encoder encoder, Decoder decoder) |
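The two protected fields and the constructor describe a composite block that wraps one encoder and one decoder. The sketch below models that composition in plain Java rather than DJL. `SeqEncoder`, `SeqDecoder`, and `EncoderDecoderSketch` are hypothetical. Passing a state from the encoder to the decoder is an assumption about the wiring; this page does not document it.

```java
/** Hypothetical stand-ins; not DJL's Encoder/Decoder API. */
final class EncoderDecoderSketch {

    /** Maps an encoder input sequence to a state vector (hypothetical). */
    interface SeqEncoder {
        double[] encode(double[] source);
    }

    /** Produces outputs from a decoder input, given an initial state (hypothetical). */
    interface SeqDecoder {
        double[] decode(double[] target, double[] initialState);
    }

    // Mirrors the two protected fields documented above.
    private final SeqEncoder encoder;
    private final SeqDecoder decoder;

    EncoderDecoderSketch(SeqEncoder encoder, SeqDecoder decoder) {
        this.encoder = encoder;
        this.decoder = decoder;
    }

    /** Runs the encoder first, then the decoder seeded with the encoder's state (assumed wiring). */
    double[] forward(double[] encoderInput, double[] decoderInput) {
        double[] state = encoder.encode(encoderInput);
        return decoder.decode(decoderInput, state);
    }
}
```

As an example, suppose the encoder sums its input, so {1, 2, 3} produces the state {6}. If the decoder adds that state to each element of {10, 20}, the result is {16.0, 26.0}.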
Modifier and Type | Method and Description |
---|---|
ai.djl.util.PairList<java.lang.String,Shape> | describeInput(): Returns a PairList of input names and shapes. |
NDList | forward(ParameterStore parameterStore, NDList inputs, boolean training): Applies the forward function of the encoder and the decoder. |
NDList | forward(ParameterStore parameterStore, NDList inputs, boolean training, ai.djl.util.PairList<java.lang.String,java.lang.Object> params): Applies the forward function of the encoder and the decoder. |
NDList | forward(ParameterStore parameterStore, NDList encoderInputs, NDList decoderInputs, boolean training, ai.djl.util.PairList<java.lang.String,java.lang.Object> params): Applies the forward function of the encoder and the decoder, taking the encoder and decoder inputs as separate lists. |
BlockList | getChildren(): Returns a list of all the children of the block. |
java.util.List<Parameter> | getDirectParameters(): Returns a list of all the direct parameters of the block. |
Shape[] | getOutputShapes(NDManager manager, Shape[] inputShapes): Returns the expected output shapes of the block for the specified input shapes. |
Shape | getParameterShape(java.lang.String name, Shape[] inputShapes): Returns the shape of the specified direct parameter of this block, given the shapes of the input to the block. |
Shape[] | initialize(NDManager manager, DataType dataType, Shape... inputShapes): Initializes the parameters of the block. |
void | loadParameters(NDManager manager, java.io.DataInputStream is): Loads the parameters from the given input stream. |
void | saveParameters(java.io.DataOutputStream os): Writes the parameters of the block to the given output stream. |
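The single-list forward overloads above take the encoder and decoder inputs together. The method details for this class put the encoder input at index 0 and the decoder input at index 1. The sketch below shows that split under stated assumptions. A generic `java.util.List` stands in for `NDList`, and `ForwardSplitSketch` and `Split` are hypothetical. Ignoring entries past index 1 is a choice made for this sketch, not documented behavior.

```java
import java.util.List;

/** Hypothetical illustration of the index convention; not DJL's NDList API. */
final class ForwardSplitSketch {

    /** Holds the two halves of a combined input list. */
    static final class Split<T> {
        final T encoderInput;
        final T decoderInput;

        Split(T encoderInput, T decoderInput) {
            this.encoderInput = encoderInput;
            this.decoderInput = decoderInput;
        }
    }

    /**
     * Splits a combined input list: index 0 goes to the encoder, index 1 to the decoder.
     * Lists shorter than two are rejected; later entries are ignored in this sketch.
     */
    static <T> Split<T> split(List<T> inputs) {
        if (inputs.size() < 2) {
            throw new IllegalArgumentException(
                    "expected encoder input at index 0 and decoder input at index 1, got "
                            + inputs.size() + " element(s)");
        }
        return new Split<>(inputs.get(0), inputs.get(1));
    }
}
```

Rejecting a list with fewer than two entries up front produces a clearer error than an index-out-of-bounds exception raised deep inside the encoder.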
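Methods inherited from class AbstractBlock: beforeInitialize, cast, clear, getParameters, isInitialized, readInputShapes, saveInputShapes, setInitializer, setInitializer

Methods inherited from class java.lang.Object: clone, equals, finalize, getClass, hashCode, notify, notifyAll, toString, wait, wait, wait

Methods inherited from interface Block: validateLayout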
public ai.djl.util.PairList<java.lang.String,Shape> describeInput()
Returns a PairList of input names and shapes.
Specified by: describeInput in interface Block
Overrides: describeInput in class AbstractBlock
Returns: the PairList of input names and shapes

public NDList forward(ParameterStore parameterStore, NDList encoderInputs, NDList decoderInputs, boolean training, ai.djl.util.PairList<java.lang.String,java.lang.Object> params)
Applies the forward function of the encoder and the decoder.
Parameters:
parameterStore - the parameter store
encoderInputs - the input for the encoder
decoderInputs - the input for the decoder
training - true to run a training forward pass
params - optional parameters

public NDList forward(ParameterStore parameterStore, NDList inputs, boolean training, ai.djl.util.PairList<java.lang.String,java.lang.Object> params)
This forward function in the EncoderDecoder class assumes that the input NDList contains both the encoder and decoder inputs. It further assumes that the encoder input is at index 0 and the decoder input is at index 1.
Parameters:
parameterStore - the parameter store
inputs - the input NDList
training - true if running a forward pass for training
params - optional parameters

public NDList forward(ParameterStore parameterStore, NDList inputs, boolean training)
This forward function in the EncoderDecoder class assumes that the input NDList contains both the encoder and decoder inputs. It further assumes that the encoder input is at index 0 and the decoder input is at index 1.
Parameters:
parameterStore - the parameter store
inputs - the input NDList
training - true if running a forward pass for training

public Shape[] initialize(NDManager manager, DataType dataType, Shape... inputShapes)
This method assumes that inputShapes contains the shapes of the encoder and decoder inputs at index 0 and index 1, respectively.
Parameters:
manager - the NDManager used to initialize the parameters
dataType - the data type of the parameters
inputShapes - the shapes of the inputs to the block

public BlockList getChildren()
Returns a list of all the children of the block.
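initialize follows the same positional convention for shapes that forward follows for inputs. The sketch below assumes three things: the composite initializes the encoder with inputShapes[0], initializes the decoder with inputShapes[1], and reports the decoder's output shape as its own. The last point is an assumption, and so is the identity "output shape" returned by the hypothetical `Child` class. `long[]` stands in for `Shape`, and `getChildren()` here returns the two children in order.

```java
import java.util.ArrayList;
import java.util.List;

/** Hypothetical sketch of the initialize() index convention; not DJL's Shape or Block API. */
final class InitializeSketch {

    /** A child block that records the input shape it was initialized with (hypothetical). */
    static final class Child {
        final String name;
        long[] initializedWith;

        Child(String name) {
            this.name = name;
        }

        /** Records the shape and, for illustration only, returns it unchanged as the output shape. */
        long[] initialize(long[] inputShape) {
            initializedWith = inputShape.clone();
            return initializedWith.clone();
        }
    }

    final Child encoder = new Child("encoder");
    final Child decoder = new Child("decoder");

    /**
     * Initializes the encoder with inputShapes[0] and the decoder with inputShapes[1].
     * Returning only the decoder's output shape is an assumption of this sketch.
     */
    long[][] initialize(long[]... inputShapes) {
        if (inputShapes.length < 2) {
            throw new IllegalArgumentException("need encoder shape at index 0 and decoder shape at index 1");
        }
        encoder.initialize(inputShapes[0]);
        return new long[][] {decoder.initialize(inputShapes[1])};
    }

    /** Children in a fixed order, loosely mirroring getChildren(). */
    List<Child> getChildren() {
        List<Child> children = new ArrayList<>();
        children.add(encoder);
        children.add(decoder);
        return children;
    }
}
```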

public java.util.List<Parameter> getDirectParameters()
Returns a list of all the direct parameters of the block.
Returns: the list of Parameter
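The word "direct" suggests parameters that the block owns itself, as opposed to parameters reached through its children (presumably what the inherited getParameters() collects). The hypothetical tree below illustrates that distinction. In this sketch the composite node owns no parameters; this page does not say what EncoderDecoder.getDirectParameters() actually returns. `getAllParameters()` walks the children recursively. Prefixing each parameter with the child's name and an underscore is an invented naming convention, not DJL's.

```java
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;

/** Hypothetical block tree illustrating "direct" vs. collected parameters; not DJL's API. */
final class ParameterTreeSketch {

    static final class Node {
        final String name;
        final List<String> directParameters;
        final List<Node> children;

        Node(String name, List<String> directParameters, List<Node> children) {
            this.name = name;
            this.directParameters = directParameters;
            this.children = children;
        }

        /** Parameters owned by this node only, analogous to getDirectParameters(). */
        List<String> getDirectParameters() {
            return Collections.unmodifiableList(directParameters);
        }

        /** Direct parameters plus those of all descendants, each prefixed with the child's name. */
        List<String> getAllParameters() {
            List<String> all = new ArrayList<>(directParameters);
            for (Node child : children) {
                for (String p : child.getAllParameters()) {
                    all.add(child.name + "_" + p);
                }
            }
            return all;
        }
    }
}
```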
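
public Shape getParameterShape(java.lang.String name, Shape[] inputShapes)
Returns the shape of the specified direct parameter of this block, given the shapes of the input to the block.
Parameters:
name - the name of the parameter
inputShapes - the shapes of the input to the block

public Shape[] getOutputShapes(NDManager manager, Shape[] inputShapes)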
Returns the expected output shapes of the block for the specified input shapes.
Parameters:
manager - an NDManager
inputShapes - the shapes of the inputs

public void saveParameters(java.io.DataOutputStream os) throws java.io.IOException
Writes the parameters of the block to the given output stream.
Parameters:
os - the output stream to save the parameters to
Throws:
java.io.IOException - if an I/O error occurs

public void loadParameters(NDManager manager, java.io.DataInputStream is) throws java.io.IOException, MalformedModelException
Loads the parameters from the given input stream.
Parameters:
manager - an NDManager to create the parameter arrays
is - the input stream that provides the parameter values
Throws:
java.io.IOException - if an I/O error occurs
MalformedModelException - if the model file is corrupted or unsupported
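saveParameters and loadParameters have to agree on the order and layout of the bytes they exchange. The round-trip sketch below uses `java.io` streams and assumes the encoder's parameters are written before the decoder's. The version byte, length prefix, and float encoding form a hypothetical format, not DJL's. `IOException` stands in for `MalformedModelException` so the example needs no DJL dependency.

```java
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.io.IOException;

/** Hypothetical save/load sketch using java.io streams; not DJL's serialization format. */
final class ParamIoSketch {

    private static final byte VERSION = 1; // hypothetical format version

    /** Writes one float array as a version byte, a length, then the values. */
    static void writeArray(DataOutputStream os, float[] values) throws IOException {
        os.writeByte(VERSION);
        os.writeInt(values.length);
        for (float v : values) {
            os.writeFloat(v);
        }
    }

    /** Reads an array written by writeArray; rejects an unknown version or a negative length. */
    static float[] readArray(DataInputStream is) throws IOException {
        byte version = is.readByte();
        if (version != VERSION) {
            // Stand-in for MalformedModelException, which this sketch does not depend on.
            throw new IOException("unsupported parameter version: " + version);
        }
        int length = is.readInt();
        if (length < 0) {
            throw new IOException("corrupted length: " + length);
        }
        float[] values = new float[length];
        for (int i = 0; i < length; i++) {
            values[i] = is.readFloat(); // throws EOFException if the data is truncated
        }
        return values;
    }

    /** Saves encoder parameters first, then decoder parameters (assumed order). */
    static byte[] save(float[] encoderParams, float[] decoderParams) throws IOException {
        ByteArrayOutputStream bytes = new ByteArrayOutputStream();
        try (DataOutputStream os = new DataOutputStream(bytes)) {
            writeArray(os, encoderParams);
            writeArray(os, decoderParams);
        }
        return bytes.toByteArray();
    }

    /** Loads in the same order that save() wrote. */
    static float[][] load(byte[] data) throws IOException {
        try (DataInputStream is = new DataInputStream(new ByteArrayInputStream(data))) {
            float[] encoderParams = readArray(is);
            float[] decoderParams = readArray(is);
            return new float[][] {encoderParams, decoderParams};
        }
    }
}
```

If the first byte of the saved data is changed, load() rejects the data. In the real API, this kind of corrupted or unsupported input is what MalformedModelException reports.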