Class TrainableWordEmbedding

    • Constructor Detail

      • TrainableWordEmbedding

        public TrainableWordEmbedding(Vocabulary vocabulary,
                                      int embeddingSize)
        Constructs a new instance of TrainableWordEmbedding from a Vocabulary (such as a DefaultVocabulary) and a given embedding size.
        Parameters:
        vocabulary - a Vocabulary to get tokens from
        embeddingSize - the required embedding size
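As a rough mental model (plain Java, not DJL code), a trainable embedding built from a vocabulary can be pictured as a weight matrix with roughly one learnable row per vocabulary token and embeddingSize columns. The class DenseEmbeddingTable, its seed parameter and the small uniform initialization below are hypothetical choices made only for illustration.

```java
import java.util.Random;

// Hypothetical illustration, not DJL code: a trainable embedding viewed as a
// (vocabularySize x embeddingSize) weight matrix with one learnable row per token.
class DenseEmbeddingTable {
    final float[][] weights;

    DenseEmbeddingTable(int vocabularySize, int embeddingSize, long seed) {
        if (vocabularySize <= 0 || embeddingSize <= 0) {
            throw new IllegalArgumentException("sizes must be positive");
        }
        Random rng = new Random(seed);
        weights = new float[vocabularySize][embeddingSize];
        // Assumed initialization: small uniform values in roughly [-0.05, 0.05).
        for (float[] row : weights) {
            for (int j = 0; j < row.length; j++) {
                row[j] = (rng.nextFloat() - 0.5f) * 0.1f;
            }
        }
    }

    // Forward pass for a single token: return a copy of the row at its index.
    float[] lookup(int index) {
        return weights[index].clone();
    }

    public static void main(String[] args) {
        DenseEmbeddingTable table = new DenseEmbeddingTable(5, 3, 42L);
        System.out.println(table.weights.length + " x " + table.weights[0].length); // 5 x 3
    }
}
```

Training then updates these rows in place; the lookup itself is just row selection by index.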
      • TrainableWordEmbedding

        public TrainableWordEmbedding(NDArray embedding,
                                      java.util.List<java.lang.String> items)
        Constructs a pretrained embedding.
        Parameters:
        embedding - the embedding array
        items - the items in the embedding (in matching order to the embedding array)
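The phrase "in matching order" means row i of the pretrained array belongs to items.get(i). A minimal plain-Java sketch of that contract follows; the class PretrainedIndex and its row-count check are hypothetical and only mirror the documented ordering rule, not DJL's internals.

```java
import java.util.HashMap;
import java.util.List;
import java.util.Map;

// Hypothetical sketch: row i of the pretrained matrix is the vector for items.get(i).
class PretrainedIndex {
    final float[][] matrix;
    final Map<String, Integer> indexOf = new HashMap<>();

    PretrainedIndex(float[][] matrix, List<String> items) {
        // Matching order only makes sense if there is exactly one row per item.
        if (matrix.length != items.size()) {
            throw new IllegalArgumentException(
                    "expected " + items.size() + " rows but got " + matrix.length);
        }
        this.matrix = matrix;
        for (int i = 0; i < items.size(); i++) {
            indexOf.put(items.get(i), i);
        }
    }

    // Returns the pretrained vector for an item, or null if the item is unknown.
    float[] vectorFor(String item) {
        Integer i = indexOf.get(item);
        return i == null ? null : matrix[i];
    }

    public static void main(String[] args) {
        float[][] m = {{1f, 0f}, {0f, 1f}};
        PretrainedIndex idx = new PretrainedIndex(m, List.of("cat", "dog"));
        System.out.println(idx.vectorFor("dog")[1]); // 1.0
    }
}
```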
      • TrainableWordEmbedding

        public TrainableWordEmbedding(NDArray embedding,
                                      java.util.List<java.lang.String> items,
                                      SparseFormat sparseFormat)
        Constructs a pretrained embedding.
        Parameters:
        embedding - the embedding array
        items - the items in the embedding (in matching order to the embedding array)
        sparseFormat - whether to compute a row-sparse gradient in the backward pass
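The idea behind a row-sparse gradient can be sketched in plain Java, independent of DJL's SparseFormat and NDArray types: an embedding lookup only touches the rows whose indices appear in the batch, so the gradient can be stored as a map from row index to row instead of a full matrix that is mostly zeros. The class RowSparseGradient and both methods are hypothetical; they compare the two representations on the same input.

```java
import java.util.Map;
import java.util.TreeMap;

// Hypothetical illustration of dense vs. row-sparse gradients for an embedding lookup.
class RowSparseGradient {
    // indices[k] is the row looked up for token k; upstream[k] is dLoss/dOutput for token k.
    static Map<Integer, float[]> rowSparse(int[] indices, float[][] upstream) {
        Map<Integer, float[]> grad = new TreeMap<>();
        for (int k = 0; k < indices.length; k++) {
            float[] row = grad.computeIfAbsent(indices[k], i -> new float[upstream[0].length]);
            for (int j = 0; j < row.length; j++) {
                row[j] += upstream[k][j]; // repeated indices accumulate into one row
            }
        }
        return grad;
    }

    // Same gradient materialized as a full (vocabularySize x dim) matrix.
    static float[][] dense(int vocabularySize, int[] indices, float[][] upstream) {
        float[][] grad = new float[vocabularySize][upstream[0].length];
        for (int k = 0; k < indices.length; k++) {
            for (int j = 0; j < grad[0].length; j++) {
                grad[indices[k]][j] += upstream[k][j];
            }
        }
        return grad;
    }

    public static void main(String[] args) {
        int[] indices = {3, 1, 3};
        float[][] upstream = {{1f, 1f}, {2f, 2f}, {0.5f, 0.5f}};
        Map<Integer, float[]> sparse = rowSparse(indices, upstream);
        System.out.println(sparse.keySet());  // [1, 3]
        System.out.println(sparse.get(3)[0]); // 1.5
    }
}
```

With a large vocabulary and small batches, the sparse form stores only a handful of rows, which is the usual motivation for enabling it.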
    • Method Detail

      • vocabularyContains

        public boolean vocabularyContains(java.lang.String word)
        Returns whether an embedding exists for a word.
        Specified by:
        vocabularyContains in interface WordEmbedding
        Parameters:
        word - the word to check
        Returns:
        true if an embedding exists
      • preprocessWordToEmbed

        public long preprocessWordToEmbed(java.lang.String word)
        Pre-processes the word to embed into an index to pass into the model.

        Make sure to call WordEmbedding.embedWord(NDManager, long) after this.

        Specified by:
        preprocessWordToEmbed in interface WordEmbedding
        Parameters:
        word - the word to embed
        Returns:
        the index of the word, ready to be passed to WordEmbedding.embedWord(NDManager, long)
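A minimal sketch of how vocabularyContains and preprocessWordToEmbed typically fit together, assuming the common pattern of mapping out-of-vocabulary words to a reserved unknown token. The class WordPreprocessor, the "<unk>" token and the index assignment are hypothetical; DJL's actual fallback behavior may differ.

```java
import java.util.HashMap;
import java.util.List;
import java.util.Map;

// Hypothetical sketch: turn a word into the long index that embedWord would consume,
// falling back to an assumed unknown-token index for out-of-vocabulary words.
class WordPreprocessor {
    private final Map<String, Long> index = new HashMap<>();
    private final long unknownIndex;

    WordPreprocessor(List<String> vocabulary, String unknownToken) {
        for (String w : vocabulary) {
            index.putIfAbsent(w, (long) index.size()); // indices assigned in list order
        }
        index.putIfAbsent(unknownToken, (long) index.size());
        unknownIndex = index.get(unknownToken);
    }

    boolean vocabularyContains(String word) {
        return index.containsKey(word);
    }

    long preprocessWordToEmbed(String word) {
        return vocabularyContains(word) ? index.get(word) : unknownIndex;
    }

    public static void main(String[] args) {
        WordPreprocessor p = new WordPreprocessor(List.of("the", "cat"), "<unk>");
        System.out.println(p.preprocessWordToEmbed("cat")); // 1
        System.out.println(p.preprocessWordToEmbed("dog")); // 2 (the unknown token)
    }
}
```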
      • unembedWord

        public java.lang.String unembedWord(NDArray word)
        Returns the closest matching word for the given index.
        Specified by:
        unembedWord in interface WordEmbedding
        Parameters:
        word - the word embedding to find the matching string word for.
        Returns:
        a word similar to the passed in embedding
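The entry above does not say how "closest" is measured, and DJL's implementation may simply resolve a scalar index (the summary line speaks of an index). Purely as an illustration of nearest-neighbor unembedding, this hypothetical NearestWord class picks the vocabulary word whose vector has the highest cosine similarity to the query; the choice of cosine similarity is an assumption.

```java
// Hypothetical sketch of "closest matching word" as a cosine-similarity nearest neighbor.
class NearestWord {
    static String closest(float[] query, float[][] table, String[] words) {
        String best = null;
        double bestScore = Double.NEGATIVE_INFINITY;
        for (int i = 0; i < table.length; i++) {
            double score = cosine(query, table[i]);
            if (score > bestScore) {
                bestScore = score;
                best = words[i];
            }
        }
        return best;
    }

    static double cosine(float[] a, float[] b) {
        double dot = 0, na = 0, nb = 0;
        for (int j = 0; j < a.length; j++) {
            dot += (double) a[j] * b[j];
            na += (double) a[j] * a[j];
            nb += (double) b[j] * b[j];
        }
        // Treat zero vectors as dissimilar instead of dividing by zero.
        return (na == 0 || nb == 0) ? 0 : dot / (Math.sqrt(na) * Math.sqrt(nb));
    }

    public static void main(String[] args) {
        float[][] table = {{1f, 0f}, {0f, 1f}, {0.7f, 0.7f}};
        String[] words = {"north", "east", "northeast"};
        System.out.println(closest(new float[] {0.9f, 0.1f}, table, words)); // north
    }
}
```

This is a linear scan over the whole vocabulary; large vocabularies usually need an approximate nearest-neighbor index instead.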
      • encode

        public byte[] encode(java.lang.String input)
        Encodes an object of input type into a byte array. This is used when saving and loading embedding objects.
        Specified by:
        encode in interface AbstractIndexedEmbedding<java.lang.String>
        Parameters:
        input - the input object to be encoded
        Returns:
        the encoded byte array.
      • decode

        public java.lang.String decode(byte[] byteArray)
        Decodes the given byte array into an object of input parameter type.
        Specified by:
        decode in interface AbstractIndexedEmbedding<java.lang.String>
        Parameters:
        byteArray - the byte array to be decoded
        Returns:
        the decoded object of input parameter type
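The property encode and decode must satisfy together is that decode(encode(s)) returns s, so that a saved embedding can be reloaded with the same items. A minimal sketch for String items, assuming UTF-8 as the byte encoding (the entries above do not name one); the class StringCodec is hypothetical.

```java
import java.nio.charset.StandardCharsets;

// Hypothetical encode/decode pair for String items, assuming UTF-8.
class StringCodec {
    static byte[] encode(String input) {
        return input.getBytes(StandardCharsets.UTF_8);
    }

    static String decode(byte[] byteArray) {
        return new String(byteArray, StandardCharsets.UTF_8);
    }

    public static void main(String[] args) {
        String word = "caf\u00e9"; // "café"; the accented letter takes two bytes in UTF-8
        byte[] bytes = encode(word);
        System.out.println(bytes.length);              // 5
        System.out.println(decode(bytes).equals(word)); // true
    }
}
```

Fixing the charset explicitly avoids depending on the platform default, which could otherwise make saved files unreadable on another machine.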
      • embed

        public long embed(java.lang.String item)
        Embeds an item.
        Specified by:
        embed in interface AbstractIndexedEmbedding<java.lang.String>
        Parameters:
        item - the item to embed
        Returns:
        the index of the item in the embedding
      • unembed

        public java.util.Optional<java.lang.String> unembed(long index)
        Returns the item corresponding to the given index.
        Specified by:
        unembed in interface AbstractIndexedEmbedding<java.lang.String>
        Parameters:
        index - the index
        Returns:
        the item corresponding to the given index, wrapped in an Optional (empty if no item matches)
      • hasItem

        public boolean hasItem(java.lang.String item)
        Returns whether an item is in the embedding.
        Specified by:
        hasItem in interface AbstractEmbedding<java.lang.String>
        Parameters:
        item - the item to test
        Returns:
        true if the item is in the embedding
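The embed, unembed and hasItem entries describe a two-way mapping between items and indices. The hypothetical ToyIndexedEmbedding below sketches that contract in plain Java: embed returns an item's index, unembed returns an Optional that is empty for an unused index, and hasItem reports membership. Throwing on an unknown item in embed is an arbitrary choice for this sketch; the real class may fall back to an unknown token instead.

```java
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;

// Hypothetical toy index mirroring the embed / unembed / hasItem contract.
class ToyIndexedEmbedding {
    private final List<String> items = new ArrayList<>();
    private final Map<String, Long> indices = new HashMap<>();

    ToyIndexedEmbedding(List<String> items) {
        for (String item : items) {
            if (!indices.containsKey(item)) { // skip duplicates, keep first position
                indices.put(item, (long) this.items.size());
                this.items.add(item);
            }
        }
    }

    boolean hasItem(String item) {
        return indices.containsKey(item);
    }

    long embed(String item) {
        Long index = indices.get(item);
        if (index == null) {
            throw new IllegalArgumentException("unknown item: " + item);
        }
        return index;
    }

    Optional<String> unembed(long index) {
        if (index < 0 || index >= items.size()) {
            return Optional.empty();
        }
        return Optional.of(items.get((int) index));
    }

    public static void main(String[] args) {
        ToyIndexedEmbedding e = new ToyIndexedEmbedding(List.of("a", "b"));
        System.out.println(e.embed("b"));  // 1
        System.out.println(e.unembed(1));  // Optional[b]
        System.out.println(e.unembed(9));  // Optional.empty
    }
}
```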