Class Embedding<T>

    • Field Detail

      • numEmbeddings

        protected int numEmbeddings
      • embeddingSize

        protected int embeddingSize
    • Constructor Detail

      • Embedding

        public Embedding(NDArray embedding)
        Constructs a pretrained embedding.
        Parameters:
        embedding - the embedding array
      • Embedding

        public Embedding(NDArray embedding,
                         SparseFormat format)
        Constructs a pretrained embedding.
        Parameters:
        embedding - the embedding array
        format - whether to compute a row-sparse gradient in the backward calculation
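A row-sparse gradient is worthwhile because an embedding lookup reads only the rows named by the input indices, so the gradient with respect to the weight matrix is zero everywhere else. The plain-Java sketch below does not use DJL, and its class and method names are hypothetical. It illustrates the idea by accumulating gradient rows only for the indices that were looked up, keyed by row index, with repeated indices summed.

```java
import java.util.Map;
import java.util.TreeMap;

/** Hypothetical illustration of a row-sparse embedding gradient; not the DJL implementation. */
final class RowSparseGradientSketch {

    private RowSparseGradientSketch() {}

    /**
     * Accumulates the gradient of an embedding lookup with respect to its weight matrix.
     *
     * @param indices    the row indices looked up in the forward pass
     * @param outputGrad gradient of the loss w.r.t. each looked-up vector, one row per index
     * @return map from weight-row index to its accumulated gradient; absent rows are all zeros
     */
    static Map<Integer, float[]> rowSparseGradient(int[] indices, float[][] outputGrad) {
        if (indices.length != outputGrad.length) {
            throw new IllegalArgumentException("need one gradient row per index");
        }
        Map<Integer, float[]> grad = new TreeMap<>();
        for (int i = 0; i < indices.length; i++) {
            // Create a zero row the first time an index is seen, then add into it.
            float[] row = grad.computeIfAbsent(indices[i], k -> new float[outputGrad[0].length]);
            for (int j = 0; j < row.length; j++) {
                row[j] += outputGrad[i][j];
            }
        }
        return grad;
    }
}
```

For indices {2, 0, 2} into a table with 1000 rows, only rows 0 and 2 are stored, and row 2 holds the sum of two gradient rows. A dense gradient would materialize all 1000 rows.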
    • Method Detail

      • prepare

        public void prepare(Shape[] inputShapes)
        Sets the shapes of the parameters.
        Overrides:
        prepare in class AbstractBlock
        Parameters:
        inputShapes - the shapes of inputs
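The `weight` parameter of the static `embedding` method has one row per possible index and one column per embedding dimension. From that, the parameter shape set here is presumably (numEmbeddings, embeddingSize), independent of the input shapes. The sketch below is a hypothetical plain-Java stand-in, with `long[]` in place of DJL's `Shape`. It also assumes that a pretrained matrix supplies both sizes, as the constructors' `embedding` argument suggests.

```java
/** Hypothetical stand-in for the parameter-shape bookkeeping; long[] replaces DJL's Shape. */
final class WeightShapeSketch {
    final int numEmbeddings;
    final int embeddingSize;

    WeightShapeSketch(int numEmbeddings, int embeddingSize) {
        if (numEmbeddings <= 0 || embeddingSize <= 0) {
            throw new IllegalArgumentException("sizes must be positive");
        }
        this.numEmbeddings = numEmbeddings;
        this.embeddingSize = embeddingSize;
    }

    /** Assumed analogue of the pretrained constructor: both sizes come from the matrix itself. */
    static WeightShapeSketch fromPretrained(float[][] matrix) {
        int cols = matrix.length == 0 ? 0 : matrix[0].length;
        for (float[] row : matrix) {
            if (row.length != cols) {
                throw new IllegalArgumentException("ragged matrix");
            }
        }
        return new WeightShapeSketch(matrix.length, cols);
    }

    /** One row per index, one column per embedding dimension. */
    long[] weightShape() {
        return new long[] {numEmbeddings, embeddingSize};
    }
}
```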
      • getOutputShapes

        public Shape[] getOutputShapes(Shape[] inputShapes)
        Returns the expected output shapes of the block for the specified input shapes.
        Specified by:
        getOutputShapes in interface Block
        Parameters:
        inputShapes - the shapes of the inputs
        Returns:
        the expected output shapes of the block
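If each input index is replaced by a vector of length embeddingSize, the output shape is the input shape with embeddingSize appended. For a 1-D input of n items, this gives (n, embeddingSize), which agrees with the shape documented for `embed`. The sketch below is hypothetical and covers a single input shape, with `long[]` standing in for `Shape`.

```java
import java.util.Arrays;

/** Hypothetical illustration of the output-shape rule; long[] stands in for DJL's Shape. */
final class OutputShapeSketch {
    private OutputShapeSketch() {}

    /** Appends embeddingSize to the input shape, e.g. (batch, seqLen) -> (batch, seqLen, embeddingSize). */
    static long[] outputShape(long[] inputShape, long embeddingSize) {
        // Copy the input dimensions into an array one slot longer, then fill the last slot.
        long[] out = Arrays.copyOf(inputShape, inputShape.length + 1);
        out[inputShape.length] = embeddingSize;
        return out;
    }
}
```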
      • saveParameters

        public void saveParameters(java.io.DataOutputStream os)
                            throws java.io.IOException
        Writes the parameters of the block to the given outputStream.
        Specified by:
        saveParameters in interface Block
        Overrides:
        saveParameters in class AbstractBlock
        Parameters:
        os - the output stream to save the parameters to
        Throws:
        java.io.IOException - if an I/O error occurs
      • loadParameters

        public void loadParameters(NDManager manager,
                                   java.io.DataInputStream is)
                            throws java.io.IOException,
                                   MalformedModelException
        Loads the parameters from the given input stream.
        Specified by:
        loadParameters in interface Block
        Overrides:
        loadParameters in class AbstractBlock
        Parameters:
        manager - an NDManager to create the parameter arrays
        is - the input stream to read the parameter values from
        Throws:
        java.io.IOException - if an I/O error occurs
        MalformedModelException - if the model file is corrupted or unsupported
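The byte layout DJL actually writes is not described here. The sketch below therefore uses a hypothetical format to illustrate the save/load contract with `java.io.DataOutputStream` and `java.io.DataInputStream`: a version byte, the two dimensions, then the weights row by row. Loading rejects an unknown version or a non-positive shape. The hypothetical `MalformedTableException` stands in for `MalformedModelException`, and a truncated stream surfaces as an `IOException`.

```java
import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.io.IOException;

/** Hypothetical stand-in for MalformedModelException. */
class MalformedTableException extends Exception {
    MalformedTableException(String message) {
        super(message);
    }
}

/** Hypothetical serialization format for an embedding table; not DJL's actual format. */
final class TableIoSketch {
    static final byte VERSION = 1;

    private TableIoSketch() {}

    static void save(float[][] weight, DataOutputStream os) throws IOException {
        int cols = weight.length == 0 ? 0 : weight[0].length;
        for (float[] row : weight) {
            if (row.length != cols) {
                throw new IllegalArgumentException("ragged weight matrix");
            }
        }
        os.writeByte(VERSION);
        os.writeInt(weight.length); // numEmbeddings
        os.writeInt(cols);          // embeddingSize
        for (float[] row : weight) {
            for (float v : row) {
                os.writeFloat(v);
            }
        }
        os.flush();
    }

    static float[][] load(DataInputStream is) throws IOException, MalformedTableException {
        byte version = is.readByte();
        if (version != VERSION) {
            throw new MalformedTableException("unsupported version: " + version);
        }
        int rows = is.readInt();
        int cols = is.readInt();
        if (rows <= 0 || cols <= 0) {
            throw new MalformedTableException("bad shape: " + rows + "x" + cols);
        }
        float[][] weight = new float[rows][cols];
        for (float[] row : weight) {
            for (int j = 0; j < cols; j++) {
                row[j] = is.readFloat(); // EOFException (an IOException) if the stream is truncated
            }
        }
        return weight;
    }
}
```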
      • embed

        public NDArray embed(NDManager manager,
                             T[] items)
        Embeds an array of items.
        Specified by:
        embed in interface AbstractEmbedding<T>
        Parameters:
        manager - the manager for the new embeddings
        items - the items to embed
        Returns:
        the embedding NDArray of Shape(items.length, embeddingSize)
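This section does not say how `Embedding<T>` maps an item to a row. The sketch below assumes a hypothetical vocabulary map from item to row index and returns a plain `float[items.length][embeddingSize]` in place of an NDArray, matching the documented Shape(items.length, embeddingSize). Throwing on unknown items is also an assumption, not documented behavior.

```java
import java.util.HashMap;
import java.util.Map;

/** Hypothetical item-to-vector lookup; the Map vocabulary and float[][] output are assumptions. */
final class ItemEmbedderSketch<T> {
    private final Map<T, Integer> vocabulary; // item -> row index in the table
    private final float[][] table;            // shape (numEmbeddings, embeddingSize)

    ItemEmbedderSketch(Map<T, Integer> vocabulary, float[][] table) {
        this.vocabulary = new HashMap<>(vocabulary);
        this.table = table;
    }

    /** Returns one row per item, i.e. shape (items.length, embeddingSize). */
    float[][] embed(T[] items) {
        float[][] out = new float[items.length][];
        for (int i = 0; i < items.length; i++) {
            Integer row = vocabulary.get(items[i]);
            if (row == null) {
                throw new IllegalArgumentException("unknown item: " + items[i]);
            }
            out[i] = table[row].clone(); // copy so callers cannot mutate the table
        }
        return out;
    }
}
```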
      • embedding

        public static NDList embedding(NDArray input,
                                       NDArray weight,
                                       SparseFormat sparse)
        A simple lookup table that looks up embeddings in a dictionary of fixed size.
        Parameters:
        input - NDArray containing indices into the embedding matrix
        weight - the embedding matrix, with a number of rows equal to the maximum possible index + 1 and a number of columns equal to the embedding size
        sparse - SparseFormat of the gradient
        Returns:
        the output NDArray, wrapped in an NDList
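The forward lookup amounts to gathering rows of `weight` by index. The plain-Java sketch below is hypothetical: `int[][]` and `float` arrays stand in for NDArray, and the `sparse` gradient option is not modeled. Because `weight` has maximum possible index + 1 rows, valid indices run from 0 to weight.length - 1, and the sketch rejects anything outside that range.

```java
/** Hypothetical forward pass of an embedding lookup; arrays stand in for NDArray. */
final class EmbeddingLookupSketch {
    private EmbeddingLookupSketch() {}

    /**
     * Gathers rows of weight: output[b][t] = weight[input[b][t]].
     * Output shape is (batch, seqLen, embeddingSize).
     */
    static float[][][] lookup(int[][] input, float[][] weight) {
        float[][][] output = new float[input.length][][];
        for (int b = 0; b < input.length; b++) {
            output[b] = new float[input[b].length][];
            for (int t = 0; t < input[b].length; t++) {
                int idx = input[b][t];
                // weight has (max possible index + 1) rows, so valid indices are 0..weight.length - 1
                if (idx < 0 || idx >= weight.length) {
                    throw new IndexOutOfBoundsException(
                            "index " + idx + " outside [0, " + weight.length + ")");
                }
                output[b][t] = weight[idx].clone();
            }
        }
        return output;
    }
}
```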