Class EncoderDecoder

  • All Implemented Interfaces:
    Block

    public class EncoderDecoder
    extends AbstractBlock
    EncoderDecoder is a general implementation of the widely used encoder-decoder architecture. This class relies on implementations of Encoder and Decoder to provide an encoder-decoder architecture for different tasks and inputs, such as machine translation (text-text) and image captioning (image-text).
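As a rough illustration of the composition described above, the following plain-Java sketch wires a stand-in encoder to a stand-in decoder. It does not use the library's classes: `SketchEncoder`, `SketchDecoder`, `EncoderDecoderSketch`, and `toy()` are hypothetical names invented for this example, and plain `double[]` arrays stand in for NDList inputs.

```java
import java.util.Arrays;

/** Plain-Java stand-in for the encoder-decoder pattern; not the library's classes. */
final class EncoderDecoderSketch {

    /** Hypothetical encoder: maps an input sequence to a state vector. */
    interface SketchEncoder {
        double[] encode(double[] input);
    }

    /** Hypothetical decoder: combines the encoder state with its own input. */
    interface SketchDecoder {
        double[] decode(double[] state, double[] decoderInput);
    }

    private final SketchEncoder encoder;
    private final SketchDecoder decoder;

    EncoderDecoderSketch(SketchEncoder encoder, SketchDecoder decoder) {
        this.encoder = encoder;
        this.decoder = decoder;
    }

    /** Runs the encoder first, then hands its state to the decoder. */
    double[] forward(double[] data, double[] decoderInput) {
        double[] state = encoder.encode(data);
        return decoder.decode(state, decoderInput);
    }

    /** Toy instance: the encoder averages its input, the decoder adds that mean to each element. */
    static EncoderDecoderSketch toy() {
        SketchEncoder mean = input -> new double[] {Arrays.stream(input).average().orElse(0.0)};
        SketchDecoder shift = (state, in) -> Arrays.stream(in).map(v -> v + state[0]).toArray();
        return new EncoderDecoderSketch(mean, shift);
    }

    public static void main(String[] args) {
        double[] out = toy().forward(new double[] {1.0, 2.0, 3.0}, new double[] {10.0, 20.0});
        System.out.println(Arrays.toString(out)); // prints [12.0, 22.0]
    }
}
```

The point of the sketch is only the delegation order: the wrapper owns no task-specific logic, so swapping in a different encoder or decoder changes the task without touching the wrapper.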
    • Constructor Detail

      • EncoderDecoder

        public EncoderDecoder(Encoder encoder,
                              Decoder decoder)
        Constructs a new EncoderDecoder instance with the given Encoder and Decoder.
        Parameters:
        encoder - the Encoder
        decoder - the Decoder
    • Method Detail

      • describeInput

        public ai.djl.util.PairList<java.lang.String,Shape> describeInput()
        Returns a PairList of input names and shapes.
        Specified by:
        describeInput in interface Block
        Overrides:
        describeInput in class AbstractBaseBlock
        Returns:
        the PairList of input names and shapes
      • forward

        public NDList forward(ParameterStore parameterStore,
                              NDList data,
                              NDList labels,
                              ai.djl.util.PairList<java.lang.String,java.lang.Object> params)
        A forward call using both training data and labels.

        Within this forward call, it can be assumed that training is true.

        Specified by:
        forward in interface Block
        Overrides:
        forward in class AbstractBaseBlock
        Parameters:
        parameterStore - the parameter store
        data - the input data NDList
        labels - the input labels NDList
        params - optional parameters
        Returns:
        the output of the forward pass
        See Also:
        Block.forward(ParameterStore, NDList, boolean, PairList)
      • initialize

        public void initialize(NDManager manager,
                               DataType dataType,
                               Shape... inputShapes)
        Initializes the parameters of the block. This method must be called before calling `forward`.

        This method assumes that inputShapes contains the encoder and decoder input shapes at indices 0 and 1, respectively.

        Specified by:
        initialize in interface Block
        Overrides:
        initialize in class AbstractBaseBlock
        Parameters:
        manager - the NDManager to initialize the parameters
        dataType - the data type of the parameters
        inputShapes - the shapes of the inputs to the block
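The index convention for inputShapes can be illustrated with a small plain-Java sketch. It is not the library's implementation: `InputShapeOrderSketch` and its helpers are hypothetical names, and `long[]` arrays stand in for Shape objects. The sketch only shows how a varargs shape list might be split, with the encoder shape read from index 0 and the decoder shape from index 1.

```java
import java.util.Arrays;

/** Plain-Java stand-in for the shape-ordering convention; not the library's classes. */
final class InputShapeOrderSketch {

    /** Returns the encoder input shape, assumed to sit at index 0. */
    static long[] encoderShape(long[]... inputShapes) {
        requireTwoShapes(inputShapes);
        return inputShapes[0];
    }

    /** Returns the decoder input shape, assumed to sit at index 1. */
    static long[] decoderShape(long[]... inputShapes) {
        requireTwoShapes(inputShapes);
        return inputShapes[1];
    }

    /** Rejects calls that do not supply both shapes. */
    private static void requireTwoShapes(long[][] inputShapes) {
        if (inputShapes == null || inputShapes.length < 2) {
            throw new IllegalArgumentException(
                    "Expected encoder and decoder shapes at indices 0 and 1");
        }
    }

    public static void main(String[] args) {
        long[] enc = {32, 10}; // illustrative values: batch of 32, source length 10
        long[] dec = {32, 12}; // illustrative values: batch of 32, target length 12
        System.out.println(Arrays.toString(encoderShape(enc, dec))); // prints [32, 10]
        System.out.println(Arrays.toString(decoderShape(enc, dec))); // prints [32, 12]
    }
}
```

Passing the shapes in the opposite order would silently initialize each half with the other's shape, so a caller may want an explicit check like the one above before initializing.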
      • getOutputShapes

        public Shape[] getOutputShapes(Shape[] inputShapes)
        Returns the expected output shapes of the block for the specified input shapes.
        Parameters:
        inputShapes - the shapes of the inputs
        Returns:
        the expected output shapes of the block
      • saveParameters

        public void saveParameters(java.io.DataOutputStream os)
                            throws java.io.IOException
        Writes the parameters of the block to the given output stream.
        Specified by:
        saveParameters in interface Block
        Overrides:
        saveParameters in class AbstractBaseBlock
        Parameters:
        os - the output stream to save the parameters to
        Throws:
        java.io.IOException - if an I/O error occurs
      • loadParameters

        public void loadParameters(NDManager manager,
                                   java.io.DataInputStream is)
                            throws java.io.IOException,
                                   MalformedModelException
        Loads the parameters from the given input stream.
        Specified by:
        loadParameters in interface Block
        Overrides:
        loadParameters in class AbstractBaseBlock
        Parameters:
        manager - an NDManager to create the parameter arrays
        is - the input stream that streams the parameter values
        Throws:
        java.io.IOException - if an I/O error occurs
        MalformedModelException - if the model file is corrupted or unsupported
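The save and load methods above form a round trip over java.io.DataOutputStream and java.io.DataInputStream. The sketch below shows that pattern with standard-library streams only. The byte layout (a version byte, a count, then float values) is an assumption made up for this example and is not the library's on-disk format; `ParameterStreamSketch` and `MalformedSketchException` are hypothetical stand-ins, the latter playing the role of MalformedModelException.

```java
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.io.IOException;
import java.util.Arrays;

/** Round-trips float parameters through data streams; the layout is invented for this sketch. */
final class ParameterStreamSketch {

    static final byte VERSION = 1; // assumed format version, not the library's

    /** Hypothetical stand-in for a "malformed model" error. */
    static class MalformedSketchException extends Exception {
        MalformedSketchException(String message) {
            super(message);
        }
    }

    /** Writes a version byte, the parameter count, then each value. */
    static void save(DataOutputStream os, float[] params) throws IOException {
        os.writeByte(VERSION);
        os.writeInt(params.length);
        for (float p : params) {
            os.writeFloat(p);
        }
    }

    /** Reads the layout written by save, rejecting unknown versions and bad counts. */
    static float[] load(DataInputStream is) throws IOException, MalformedSketchException {
        byte version = is.readByte();
        if (version != VERSION) {
            throw new MalformedSketchException("Unsupported version: " + version);
        }
        int count = is.readInt();
        if (count < 0) {
            throw new MalformedSketchException("Negative parameter count: " + count);
        }
        float[] params = new float[count];
        for (int i = 0; i < count; i++) {
            params[i] = is.readFloat();
        }
        return params;
    }

    /** Saves to an in-memory buffer and loads the result back. */
    static float[] roundTrip(float[] params) throws IOException, MalformedSketchException {
        ByteArrayOutputStream buffer = new ByteArrayOutputStream();
        try (DataOutputStream os = new DataOutputStream(buffer)) {
            save(os, params);
        }
        try (DataInputStream is =
                new DataInputStream(new ByteArrayInputStream(buffer.toByteArray()))) {
            return load(is);
        }
    }

    public static void main(String[] args) throws Exception {
        float[] restored = roundTrip(new float[] {0.5f, -1.25f, 3.0f});
        System.out.println(Arrays.toString(restored)); // prints [0.5, -1.25, 3.0]
    }
}
```

Separating the two failure modes mirrors the Throws clauses above: a truncated or unreadable stream surfaces as an IOException, while readable but invalid content is reported as a format error.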