Class TrainableTextEmbedding

java.lang.Object
ai.djl.nn.AbstractBaseBlock
ai.djl.nn.AbstractBlock
ai.djl.modality.nlp.embedding.TrainableTextEmbedding
All Implemented Interfaces:
TextEmbedding, Block

public class TrainableTextEmbedding extends AbstractBlock implements TextEmbedding
TrainableTextEmbedding is an implementation of TextEmbedding based on the TrainableWordEmbedding block. This TextEmbedding is well suited to cases where no pre-trained embeddings are available, or where a pre-trained embedding needs further training.
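What makes such an embedding "trainable" is that its weight table is an ordinary model parameter that gradient updates modify, so vectors can be learned from scratch or copied in from a pre-trained embedding and then fine-tuned. The plain-Java sketch below illustrates that idea only; it is not DJL code. TinyTrainableTable and sgdStep are hypothetical names, and plain SGD stands in for whatever optimizer a real training setup would use.

```java
import java.util.Arrays;

/** Hypothetical minimal trainable embedding table: one row of floats per vocabulary index. */
class TinyTrainableTable {
    // Zero-initialized here for simplicity; real tables are usually randomly initialized.
    final float[][] weights;

    TinyTrainableTable(int vocabSize, int embeddingSize) {
        weights = new float[vocabSize][embeddingSize];
    }

    /** Plain SGD step on one row: only the row that was looked up receives a gradient. */
    void sgdStep(int index, float[] grad, float learningRate) {
        for (int j = 0; j < grad.length; j++) {
            weights[index][j] -= learningRate * grad[j];
        }
    }

    public static void main(String[] args) {
        TinyTrainableTable table = new TinyTrainableTable(3, 2);
        // Pretend row 1 was copied from a pre-trained embedding, then fine-tune it.
        table.weights[1] = new float[] {0.5f, -0.5f};
        table.sgdStep(1, new float[] {1.0f, -1.0f}, 0.1f);
        System.out.println(Arrays.toString(table.weights[1]));
    }
}
```

Rows that are never looked up keep their values, which is why an embedding initialized from pre-trained vectors can be refined only where the training data actually touches it.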
  • Constructor Details

  • Method Details

    • preprocessTextToEmbed

      public long[] preprocessTextToEmbed(List<String> text)
      Preprocesses the text to embed into an array to pass into the model.

      Make sure to call TextEmbedding.embedText(NDManager, long[]) after this.

      Specified by:
      preprocessTextToEmbed in interface TextEmbedding
      Parameters:
      text - the text to embed
      Returns:
      the indices of the text, ready to embed
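To make the contract of preprocessTextToEmbed concrete, the sketch below maps each token to an index through a toy vocabulary and sends unseen tokens to a reserved unknown index. ToyVocabulary, preprocess, and the choice of index 0 for unknown tokens are assumptions for illustration; the source does not describe how the real vocabulary assigns indices.

```java
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

/** Hypothetical vocabulary: maps tokens to indices, reserving index 0 for unknown tokens. */
class ToyVocabulary {
    static final long UNKNOWN_INDEX = 0;
    private final Map<String, Long> tokenToIndex = new HashMap<>();

    ToyVocabulary(List<String> tokens) {
        long next = 1; // 0 is reserved for unknown tokens
        for (String token : tokens) {
            if (!tokenToIndex.containsKey(token)) {
                tokenToIndex.put(token, next++);
            }
        }
    }

    /** Plays the role of preprocessTextToEmbed: tokens in, indices out. */
    long[] preprocess(List<String> text) {
        long[] indices = new long[text.size()];
        for (int i = 0; i < text.size(); i++) {
            indices[i] = tokenToIndex.getOrDefault(text.get(i), UNKNOWN_INDEX);
        }
        return indices;
    }

    public static void main(String[] args) {
        ToyVocabulary vocab = new ToyVocabulary(List.of("the", "cat", "sat"));
        System.out.println(Arrays.toString(vocab.preprocess(List.of("the", "dog", "sat"))));
        // prints [1, 0, 3]
    }
}
```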
    • embedText

      public NDArray embedText(NDArray textIndices) throws EmbeddingException
      Embeds the text after it has been preprocessed using TextEmbedding.preprocessTextToEmbed(List).
      Specified by:
      embedText in interface TextEmbedding
      Parameters:
      textIndices - the indices of text to embed
      Returns:
      the embedded text
      Throws:
      EmbeddingException - if there is an error while trying to embed
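The lookup behind embedText can be pictured as selecting one row of a weight table per index. This is a plain-Java sketch with hypothetical names (ToyLookup, embed): it uses float[][] in place of NDArray, and an IllegalArgumentException stands in for the EmbeddingException that the documented method declares.

```java
import java.util.Arrays;

/** Hypothetical embedding lookup: row i of the weight table is the vector for index i. */
class ToyLookup {
    static float[][] embed(float[][] weights, long[] indices) {
        float[][] out = new float[indices.length][];
        for (int i = 0; i < indices.length; i++) {
            long idx = indices[i];
            if (idx < 0 || idx >= weights.length) {
                // Stand-in for the EmbeddingException declared by the real method.
                throw new IllegalArgumentException("index out of range: " + idx);
            }
            out[i] = weights[(int) idx].clone(); // copy so callers cannot mutate the table
        }
        return out;
    }

    public static void main(String[] args) {
        float[][] weights = {{0f, 0f}, {1f, 2f}, {3f, 4f}};
        System.out.println(Arrays.deepToString(embed(weights, new long[] {2, 1})));
        // prints [[3.0, 4.0], [1.0, 2.0]]
    }
}
```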
    • unembedText

      public List<String> unembedText(NDArray textEmbedding)
      Returns the closest matching text for a given embedding.
      Specified by:
      unembedText in interface TextEmbedding
      Parameters:
      textEmbedding - the text embedding for which to find the matching text
      Returns:
      text similar to the passed in embedding
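The source does not say how "closest" is measured. The sketch below assumes cosine similarity between each input vector and every row of the embedding table, and returns the token of the best-scoring row. Treat it as one plausible strategy rather than DJL's implementation; ToyUnembed, closestToken, and unembed are hypothetical names.

```java
import java.util.ArrayList;
import java.util.List;

/** Hypothetical "unembed": nearest vocabulary token by cosine similarity. */
class ToyUnembed {
    static double cosine(float[] a, float[] b) {
        double dot = 0, na = 0, nb = 0;
        for (int i = 0; i < a.length; i++) {
            dot += a[i] * b[i];
            na += a[i] * a[i];
            nb += b[i] * b[i];
        }
        // Treat zero vectors as dissimilar to everything.
        return (na == 0 || nb == 0) ? 0 : dot / (Math.sqrt(na) * Math.sqrt(nb));
    }

    static String closestToken(float[] query, String[] tokens, float[][] weights) {
        String best = null;
        double bestScore = Double.NEGATIVE_INFINITY;
        for (int i = 0; i < tokens.length; i++) {
            double score = cosine(query, weights[i]);
            if (score > bestScore) {
                bestScore = score;
                best = tokens[i];
            }
        }
        return best;
    }

    /** One token per row of the input embedding (one row per position in the text). */
    static List<String> unembed(float[][] textEmbedding, String[] tokens, float[][] weights) {
        List<String> result = new ArrayList<>(textEmbedding.length);
        for (float[] vector : textEmbedding) {
            result.add(closestToken(vector, tokens, weights));
        }
        return result;
    }

    public static void main(String[] args) {
        String[] tokens = {"cat", "dog", "car"};
        float[][] weights = {{1f, 0f}, {0.9f, 0.1f}, {0f, 1f}};
        System.out.println(unembed(new float[][] {{1f, 0f}, {0f, 2f}}, tokens, weights));
        // prints [cat, car]
    }
}
```

Cosine similarity ignores vector length, so {0, 2} still maps to the unit row for "car"; a Euclidean-distance search would be an equally reasonable choice with different behavior on unnormalized vectors.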
    • forwardInternal

      protected NDList forwardInternal(ParameterStore parameterStore, NDList inputs, boolean training, ai.djl.util.PairList<String,Object> params)
      Specified by:
      forwardInternal in class AbstractBaseBlock
      Parameters:
      parameterStore - the parameter store
      inputs - the input NDList
      training - true for a training forward pass
      params - optional parameters
      Returns:
      the output of the forward pass
    • initializeChildBlocks

      public void initializeChildBlocks(NDManager manager, DataType dataType, Shape... inputShapes)
      Initializes the child blocks of this block. Override this method if your subclass has child blocks. It is used to determine the correct input shapes for the child blocks from the requested input shape of this block.
      Overrides:
      initializeChildBlocks in class AbstractBaseBlock
      Parameters:
      manager - the manager to use for initialization
      dataType - the requested data type
      inputShapes - the expected input shapes for this block
    • getOutputShapes

      public Shape[] getOutputShapes(Shape[] inputShapes)
      Returns the expected output shapes of the block for the specified input shapes.
      Specified by:
      getOutputShapes in interface Block
      Parameters:
      inputShapes - the shapes of the inputs
      Returns:
      the expected output shapes of the block
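As a rough illustration of what getOutputShapes computes, the sketch below assumes the usual embedding convention that each index in the input becomes a vector of length embeddingSize, so the embedding size is appended to the input shape. It uses plain long[] arrays instead of DJL's Shape, and ToyShapes.embeddingOutputShape is a hypothetical helper.

```java
import java.util.Arrays;

/** Hypothetical shape rule: an embedding appends embeddingSize to the shape of its index input. */
class ToyShapes {
    static long[] embeddingOutputShape(long[] inputShape, long embeddingSize) {
        long[] out = Arrays.copyOf(inputShape, inputShape.length + 1);
        out[inputShape.length] = embeddingSize; // one vector per input index
        return out;
    }

    public static void main(String[] args) {
        // A batch of 32 sequences, each 10 token indices long, with 50-dimensional embeddings.
        System.out.println(Arrays.toString(embeddingOutputShape(new long[] {32, 10}, 50)));
        // prints [32, 10, 50]
    }
}
```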