Class TextGenerator

java.lang.Object
ai.djl.modality.nlp.generate.TextGenerator

public class TextGenerator extends Object
TextGenerator is an LMSearch (language model search) that provides several autoregressive search methods.

It has a Predictor from NDList to CausalLMOutput, which is called inside an autoregressive inference loop.
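
The autoregressive inference loop mentioned above can be illustrated with a minimal, library-free sketch of greedy decoding. This is not the DJL implementation: the class GreedyLoopSketch, the methods argMax and greedy, and the toy model are hypothetical names introduced only for illustration. A plain Function<int[], double[]> stands in for the Predictor, and int[] stands in for NDArray.

```java
import java.util.Arrays;
import java.util.function.Function;

/** Hypothetical, library-free sketch of an autoregressive greedy loop (not DJL code). */
final class GreedyLoopSketch {

    /** Returns the index of the largest logit. */
    static int argMax(double[] logits) {
        int best = 0;
        for (int i = 1; i < logits.length; i++) {
            if (logits[i] > logits[best]) {
                best = i;
            }
        }
        return best;
    }

    /**
     * Repeatedly asks the model for next-token logits and appends the arg-max token,
     * stopping once the EOS token is produced or maxSeqLength is reached.
     */
    static int[] greedy(
            Function<int[], double[]> model, int[] inputIds, int eosTokenId, int maxSeqLength) {
        int[] ids = Arrays.copyOf(inputIds, inputIds.length);
        while (ids.length < maxSeqLength) {
            int next = argMax(model.apply(ids));
            ids = Arrays.copyOf(ids, ids.length + 1);
            ids[ids.length - 1] = next;
            if (next == eosTokenId) {
                break;
            }
        }
        return ids;
    }

    public static void main(String[] args) {
        // Toy stand-in for a language model: 5-token vocabulary,
        // always predicts (last token + 1) % 5.
        Function<int[], double[]> toyModel = ids -> {
            double[] logits = new double[5];
            logits[(ids[ids.length - 1] + 1) % 5] = 1.0;
            return logits;
        };
        int[] out = greedy(toyModel, new int[] {0}, 3, 10);
        System.out.println(Arrays.toString(out));
    }
}
```

In this sketch every step feeds the whole sequence back into the model. A real implementation would typically reuse cached state between steps rather than recompute it.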

  • Constructor Details

    • TextGenerator

      public TextGenerator(Predictor<NDList,CausalLMOutput> predictor, String searchName, SearchConfig searchConfig)
      Constructs a new TextGenerator instance.
      Parameters:
      predictor - the language model
      searchName - the autoregressive search name
      searchConfig - the autoregressive search configuration
  • Method Details

    • greedySearch

      public NDArray greedySearch(NDArray inputIds) throws TranslateException
      Executes greedy search.
      Parameters:
      inputIds - the input token ids
      Returns:
      the output token ids stored as NDArray and the endPosition of each sentence
      Throws:
      TranslateException - if the forward pass fails
    • beamSearch

      public NDArray beamSearch(NDArray inputIds) throws TranslateException
      Generates text using beam search.
      Parameters:
      inputIds - the input token ids
      Returns:
      the output token ids stored as NDArray and the endPosition of each sentence
      Throws:
      TranslateException - if the forward pass fails
    • contrastiveSearch

      public NDArray contrastiveSearch(NDArray inputIds) throws TranslateException
      Generates text using contrastive search.
      Parameters:
      inputIds - the input token ids
      Returns:
      the output token ids stored as NDArray
      Throws:
      TranslateException - if the forward pass fails
    • generate

      public NDArray generate(NDArray inputIds) throws TranslateException
      Generates text from the given input token ids.
      Parameters:
      inputIds - the input token ids
      Returns:
      the generated token ids
      Throws:
      TranslateException - if prediction fails
    • getPositionOffset

      public NDArray getPositionOffset()
      Returns the value of the positionOffset.
      Returns:
      the value of positionOffset
    • getEndPosition

      public long[] getEndPosition()
      Returns the end position of each sentence, determined by the EOS token id or by reaching maxSeqLength.
      Returns:
      the end position of each sentence
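
As an illustration of what an end position per sentence could mean, the following library-free sketch computes one end position per row of a batch. It is not DJL code. The class EndPositionSketch and the method endPositions are hypothetical, and the convention used here (exclusive index just past the first EOS token, otherwise the row length capped at maxSeqLength) is an assumption that may differ from what getEndPosition actually reports.

```java
import java.util.Arrays;

/** Hypothetical sketch of computing per-sentence end positions (not DJL code). */
final class EndPositionSketch {

    /**
     * For each row, returns the index just past the first EOS token, or
     * min(row length, maxSeqLength) when no EOS token occurs within that limit.
     * This convention is assumed for illustration only.
     */
    static long[] endPositions(int[][] tokenIds, int eosTokenId, int maxSeqLength) {
        long[] end = new long[tokenIds.length];
        for (int row = 0; row < tokenIds.length; row++) {
            int limit = Math.min(tokenIds[row].length, maxSeqLength);
            end[row] = limit;
            for (int i = 0; i < limit; i++) {
                if (tokenIds[row][i] == eosTokenId) {
                    end[row] = i + 1;
                    break;
                }
            }
        }
        return end;
    }

    public static void main(String[] args) {
        // Two sentences; token id 2 plays the role of EOS in this toy example.
        int[][] batch = {
            {5, 6, 2, 0, 0}, // ends at the EOS token
            {5, 6, 7, 8, 9}  // never emits EOS, so it ends at maxSeqLength
        };
        System.out.println(Arrays.toString(endPositions(batch, 2, 5)));
    }
}
```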