Class TrainableWordEmbedding

All Implemented Interfaces:
WordEmbedding, Block, AbstractEmbedding<String>, AbstractIndexedEmbedding<String>

public class TrainableWordEmbedding extends Embedding<String> implements WordEmbedding
TrainableWordEmbedding is an implementation of WordEmbedding and Embedding based on a DefaultVocabulary. This WordEmbedding is ideal when there are no pre-trained embeddings available.
  • Constructor Details

  • Method Details

    • fromPretrained

      public static TrainableWordEmbedding fromPretrained(NDArray embedding, List<String> items)
      Constructs a pretrained embedding.

      Because it is initialized with pre-trained data, the block is created frozen. If you wish to update its parameters during training, call Block.freezeParameters(boolean) with false.

      Parameters:
      embedding - the embedding array
      items - the items in the embedding (in matching order to the embedding array)
      Returns:
      the created embedding
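      The items list must line up row for row with the embedding array: row i holds the vector for items.get(i). The plain-Java sketch below models that contract, with a float[][] standing in for the NDArray. The class name PretrainedLookup is hypothetical, and this is an illustration of the ordering requirement, not DJL's implementation.

```java
import java.util.HashMap;
import java.util.List;
import java.util.Map;

/** Hypothetical stand-in for a pretrained table: row i of weights belongs to items.get(i). */
final class PretrainedLookup {
    private final float[][] weights;
    private final Map<String, Integer> rowOf = new HashMap<>();

    PretrainedLookup(float[][] weights, List<String> items) {
        // Both inputs must describe the same number of entries, in the same order.
        if (weights.length != items.size()) {
            throw new IllegalArgumentException(
                    "expected " + weights.length + " items but got " + items.size());
        }
        this.weights = weights;
        for (int i = 0; i < items.size(); i++) {
            // Record which row holds each item's vector.
            rowOf.put(items.get(i), i);
        }
    }

    /** Returns the pretrained vector for an item, or null if the item is unknown. */
    float[] vectorFor(String item) {
        Integer row = rowOf.get(item);
        return row == null ? null : weights[row];
    }
}
```

      Passing the items in a different order than the rows would silently attach the wrong vectors to each word, which is why the size check alone cannot catch every mistake.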
    • fromPretrained

      public static TrainableWordEmbedding fromPretrained(NDArray embedding, List<String> items, SparseFormat sparseFormat)
      Constructs a pretrained embedding.

      Because it is initialized with pre-trained data, the block is created frozen. If you wish to update its parameters during training, call Block.freezeParameters(boolean) with false.

      Parameters:
      embedding - the embedding array
      items - the items in the embedding (in matching order to the embedding array)
      sparseFormat - the gradient format to use in the backward pass, for example whether to compute a row-sparse gradient
      Returns:
      the created embedding
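      With a row-sparse gradient, the backward pass produces gradient rows only for the indices that actually appeared in the batch, instead of a dense gradient the size of the whole table. The sketch below illustrates that idea in plain Java. RowSparseGradient and accumulate are hypothetical names, and this is not how DJL or its engines store row-sparse gradients internally.

```java
import java.util.Map;
import java.util.TreeMap;

/** Hypothetical illustration of a row-sparse embedding gradient. */
final class RowSparseGradient {
    private RowSparseGradient() {}

    /**
     * Accumulates the gradient of an embedding lookup.
     *
     * @param indices the row looked up at each batch position
     * @param upstream the upstream gradient for each batch position, one row per index
     * @return only the touched rows, keyed by row index; all other rows are implicitly zero
     */
    static Map<Integer, float[]> accumulate(int[] indices, float[][] upstream) {
        Map<Integer, float[]> grad = new TreeMap<>();
        for (int b = 0; b < indices.length; b++) {
            int width = upstream[b].length;
            float[] row = grad.computeIfAbsent(indices[b], k -> new float[width]);
            // A row looked up several times receives the sum of its upstream gradients.
            for (int d = 0; d < width; d++) {
                row[d] += upstream[b][d];
            }
        }
        return grad;
    }
}
```

      For a batch that looks up rows 3, 7 and 3, only two gradient rows are stored, however large the vocabulary is.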
    • vocabularyContains

      public boolean vocabularyContains(String word)
      Returns whether an embedding exists for a word.
      Specified by:
      vocabularyContains in interface WordEmbedding
      Parameters:
      word - the word to check
      Returns:
      true if an embedding exists
    • preprocessWordToEmbed

      public long preprocessWordToEmbed(String word)
      Pre-processes the word to embed into an index to pass into the model.

      Make sure to call WordEmbedding.embedWord(NDManager, long) after this.

      Specified by:
      preprocessWordToEmbed in interface WordEmbedding
      Parameters:
      word - the word to embed
      Returns:
      the index of the word, ready to embed
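      The preprocessing step maps a word to the long index that embedWord later consumes, and vocabularyContains reports whether the word has an entry of its own. A minimal sketch of both, assuming a vocabulary that assigns indices in insertion order and sends unknown words to a reserved index. The class WordIndex, the "<unk>" token and its index 0 are assumptions for illustration, not necessarily what DefaultVocabulary does.

```java
import java.util.HashMap;
import java.util.List;
import java.util.Map;

/** Hypothetical word-to-index step; index 0 is reserved for unknown words (an assumption). */
final class WordIndex {
    static final String UNKNOWN = "<unk>";
    private final Map<String, Long> index = new HashMap<>();

    WordIndex(List<String> words) {
        index.put(UNKNOWN, 0L);
        for (String w : words) {
            // Give each new word the next free index; a repeated word keeps its first index.
            index.putIfAbsent(w, (long) index.size());
        }
    }

    /** Plays the role of vocabularyContains. */
    boolean contains(String word) {
        return index.containsKey(word);
    }

    /** Plays the role of preprocessWordToEmbed: the row index to look up. */
    long toIndex(String word) {
        return index.getOrDefault(word, 0L);
    }
}
```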
    • embedWord

      public NDArray embedWord(NDArray index) throws EmbeddingException
      Embeds the word after it has been preprocessed using WordEmbedding.preprocessWordToEmbed(String).
      Specified by:
      embedWord in interface WordEmbedding
      Parameters:
      index - the index of the word to embed
      Returns:
      the embedded word
      Throws:
      EmbeddingException - if there is an error while trying to embed
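      Conceptually, embedding a word is a row lookup: each index selects one row of the weight matrix, and a batch of indices yields one row per index. The stdlib-only sketch below shows that gather on a float[][]. RowGather is a hypothetical name, and throwing IllegalArgumentException for an out-of-range index only stands in for whatever error handling (such as EmbeddingException) the real class performs.

```java
/** Hypothetical illustration of an embedding lookup as a row gather. */
final class RowGather {
    private RowGather() {}

    /** Returns a copy of weights[i] for each entry of indices, in the same order. */
    static float[][] gather(float[][] weights, long[] indices) {
        float[][] out = new float[indices.length][];
        for (int b = 0; b < indices.length; b++) {
            long i = indices[b];
            if (i < 0 || i >= weights.length) {
                throw new IllegalArgumentException("index out of range: " + i);
            }
            // Copy so callers cannot modify the weights through the result.
            out[b] = weights[(int) i].clone();
        }
        return out;
    }
}
```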
    • unembedWord

      public String unembedWord(NDArray word)
      Returns the closest matching word for the given index.
      Specified by:
      unembedWord in interface WordEmbedding
      Parameters:
      word - the word embedding to find the matching word for
      Returns:
      a word similar to the passed in embedding
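      One common way to find "a word similar to the passed in embedding" is a nearest-neighbor search by cosine similarity over the rows of the weight matrix. The sketch below assumes that interpretation; it is a generic technique, not necessarily how TrainableWordEmbedding computes its result, and NearestWord is a hypothetical name.

```java
import java.util.List;

/** Hypothetical nearest-word search by cosine similarity. */
final class NearestWord {
    private NearestWord() {}

    static double cosine(float[] a, float[] b) {
        double dot = 0, na = 0, nb = 0;
        for (int d = 0; d < a.length; d++) {
            dot += (double) a[d] * b[d];
            na += (double) a[d] * a[d];
            nb += (double) b[d] * b[d];
        }
        // A zero vector has no direction, so give it the lowest possible similarity.
        return (na == 0 || nb == 0) ? -1.0 : dot / Math.sqrt(na * nb);
    }

    /** Returns the word whose row is most similar to query; words.get(i) owns weights[i]. */
    static String closest(float[] query, float[][] weights, List<String> words) {
        String best = null;
        double bestScore = Double.NEGATIVE_INFINITY;
        for (int i = 0; i < weights.length; i++) {
            double score = cosine(query, weights[i]);
            if (score > bestScore) {
                bestScore = score;
                best = words.get(i);
            }
        }
        return best;
    }
}
```

      Cosine similarity ignores vector length, so a query such as {0, 2} still matches a unit-length row pointing the same way.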
    • encode

      public byte[] encode(String input)
      Encodes an object of input type into a byte array. This is used when saving and loading Embedding objects.
      Specified by:
      encode in interface AbstractIndexedEmbedding<String>
      Parameters:
      input - the input object to be encoded
      Returns:
      the encoded byte array
    • decode

      public String decode(byte[] byteArray)
      Decodes the given byte array into an object of input parameter type.
      Specified by:
      decode in interface AbstractIndexedEmbedding<String>
      Parameters:
      byteArray - the byte array to be decoded
      Returns:
      the decoded object of input parameter type
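      For String items, a natural encode/decode pair is the string's UTF-8 bytes and their reverse. A minimal sketch, assuming UTF-8 (this page does not name the charset TrainableWordEmbedding uses); StringCodec is a hypothetical name.

```java
import java.nio.charset.StandardCharsets;

/** Hypothetical String codec; UTF-8 is an assumption, not confirmed by this page. */
final class StringCodec {
    private StringCodec() {}

    static byte[] encode(String input) {
        // UTF-8 can represent any well-formed String, including non-ASCII words.
        return input.getBytes(StandardCharsets.UTF_8);
    }

    static String decode(byte[] byteArray) {
        return new String(byteArray, StandardCharsets.UTF_8);
    }
}
```

      Whatever charset is chosen, encode and decode must agree on it, or saved embeddings will reload with corrupted words.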
    • embed

      public long embed(String item)
      Embeds an item.
      Specified by:
      embed in interface AbstractIndexedEmbedding<String>
      Parameters:
      item - the item to embed
      Returns:
      the index of the item in the embedding
    • unembed

      public Optional<String> unembed(long index)
      Returns the item corresponding to the given index.
      Specified by:
      unembed in interface AbstractIndexedEmbedding<String>
      Parameters:
      index - the index
      Returns:
      the item corresponding to the given index
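      embed and unembed are inverse lookups between items and indices; unembed returns an Optional because an index need not correspond to any item. A plain-Java sketch of that pairing follows. IndexedItems is hypothetical, and both the rejection of unknown items in embed and the Optional.empty() result for out-of-range indices are assumptions, not behavior documented on this page.

```java
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;

/** Hypothetical bidirectional item/index table. */
final class IndexedItems {
    private final List<String> byIndex = new ArrayList<>();
    private final Map<String, Long> byItem = new HashMap<>();

    IndexedItems(List<String> items) {
        for (String item : items) {
            // Skip duplicates so each item owns exactly one index.
            if (!byItem.containsKey(item)) {
                byItem.put(item, (long) byIndex.size());
                byIndex.add(item);
            }
        }
    }

    /** Like embed: item to index. Rejecting unknown items is an assumption. */
    long embed(String item) {
        Long index = byItem.get(item);
        if (index == null) {
            throw new IllegalArgumentException("unknown item: " + item);
        }
        return index;
    }

    /** Like unembed: index to item, empty when no item owns the index. */
    Optional<String> unembed(long index) {
        if (index < 0 || index >= byIndex.size()) {
            return Optional.empty();
        }
        return Optional.of(byIndex.get((int) index));
    }
}
```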
    • builder

      public static TrainableWordEmbedding.Builder builder()
      Creates a builder to build an Embedding.
      Returns:
      a new builder
    • hasItem

      public boolean hasItem(String item)
      Returns whether an item is in the embedding.
      Specified by:
      hasItem in interface AbstractEmbedding<String>
      Parameters:
      item - the item to test
      Returns:
      true if the item is in the embedding