Package ai.djl.modality.nlp.generate
Class TextGenerator
java.lang.Object
ai.djl.modality.nlp.generate.TextGenerator
TextGenerator is an LMSearch (language model search) that provides several autoregressive search methods: greedy search, beam search, and contrastive search.
It holds a Predictor from NDList to CausalLMOutput, which it calls inside the autoregressive inference loop.
Constructor Summary
Constructors
Constructor: TextGenerator(Predictor<NDList, CausalLMOutput> predictor, String searchName, SearchConfig searchConfig)
Description: Constructs a new TextGenerator instance.
Method Summary
Method - Description
beamSearch(NDArray inputIds) - Generates text using beam search.
contrastiveSearch(NDArray inputIds) - Generates text using contrastive search.
generate(NDArray inputIds) - Generates text from the input token ids.
long[] getEndPosition() - Returns the end position of each sentence, determined by the EOS token id or by reaching maxSeqLength.
getPositionOffset() - Returns the value of positionOffset.
greedySearch(NDArray inputIds) - Executes greedy search.
Constructor Details
TextGenerator
public TextGenerator(Predictor<NDList, CausalLMOutput> predictor, String searchName, SearchConfig searchConfig)
Constructs a new TextGenerator instance.
Parameters:
predictor - the language model
searchName - the autoregressive search name
searchConfig - the autoregressive search configuration
Method Details
greedySearch
Executes greedy search.
Parameters:
inputIds - the input token ids
Returns:
the output token ids, stored as an NDArray, and the endPosition of each sentence
Throws:
TranslateException - if the forward pass fails
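To illustrate the idea behind greedy search, here is a minimal plain-Java sketch. It is not the DJL implementation: the class GreedySearchSketch, the Function-based toy model, and the parameters eosTokenId and maxSeqLength are assumptions made for this example, standing in for the Predictor and the search configuration.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.function.Function;

// Hypothetical sketch, independent of DJL: the "model" is any function
// that maps the token ids so far to one score (logit) per vocabulary entry.
public class GreedySearchSketch {

    /** Returns the index of the largest value in logits. */
    public static int argMax(double[] logits) {
        int best = 0;
        for (int i = 1; i < logits.length; i++) {
            if (logits[i] > logits[best]) {
                best = i;
            }
        }
        return best;
    }

    /**
     * Extends the prompt one token at a time, always taking the highest-scoring
     * token, until eosTokenId is produced or the sequence reaches maxSeqLength.
     */
    public static List<Integer> greedySearch(
            Function<List<Integer>, double[]> model,
            List<Integer> inputIds,
            int eosTokenId,
            int maxSeqLength) {
        List<Integer> ids = new ArrayList<>(inputIds);
        while (ids.size() < maxSeqLength) {
            double[] logits = model.apply(ids); // one forward pass per new token
            int next = argMax(logits);
            ids.add(next);
            if (next == eosTokenId) {
                break;
            }
        }
        return ids;
    }

    public static void main(String[] args) {
        // Toy model over a 4-token vocabulary: always prefers (lastToken + 1) % 4.
        Function<List<Integer>, double[]> toyModel = ids -> {
            double[] logits = new double[4];
            logits[(ids.get(ids.size() - 1) + 1) % 4] = 1.0;
            return logits;
        };
        // Token 3 plays the role of EOS here.
        System.out.println(greedySearch(toyModel, List.of(0), 3, 10)); // prints [0, 1, 2, 3]
    }
}
```

The loop performs one model call per generated token, which is the autoregressive pattern the class description refers to.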
beamSearch
Generates text using beam search.
Parameters:
inputIds - the input token ids
Returns:
the output token ids, stored as an NDArray, and the endPosition of each sentence
Throws:
TranslateException - if the forward pass fails
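As a rough illustration of beam search in general, the following plain-Java sketch keeps the highest-scoring partial sequences at every step, ranked by cumulative log-probability. It is not the DJL implementation: BeamSearchSketch, Hypothesis, toyLogProbs, and the parameters beam, eosTokenId, and maxSeqLength are hypothetical names chosen for this example.

```java
import java.util.ArrayList;
import java.util.Comparator;
import java.util.List;
import java.util.function.Function;

// Hypothetical sketch, independent of DJL.
public class BeamSearchSketch {

    /** A partial sequence with its cumulative log-probability. */
    static final class Hypothesis {
        final List<Integer> ids;
        final double logProb;
        final boolean finished; // true once EOS has been generated

        Hypothesis(List<Integer> ids, double logProb, boolean finished) {
            this.ids = ids;
            this.logProb = logProb;
            this.finished = finished;
        }
    }

    /** Returns the best sequence found while keeping at most `beam` hypotheses per step. */
    public static List<Integer> beamSearch(
            Function<List<Integer>, double[]> logProbModel,
            List<Integer> inputIds,
            int beam,
            int eosTokenId,
            int maxSeqLength) {
        List<Hypothesis> beams = new ArrayList<>();
        beams.add(new Hypothesis(new ArrayList<>(inputIds), 0.0, false));

        for (int len = inputIds.size(); len < maxSeqLength; len++) {
            List<Hypothesis> candidates = new ArrayList<>();
            boolean anyOpen = false;
            for (Hypothesis h : beams) {
                if (h.finished) {
                    candidates.add(h); // finished hypotheses compete unchanged
                    continue;
                }
                anyOpen = true;
                double[] logProbs = logProbModel.apply(h.ids);
                for (int token = 0; token < logProbs.length; token++) {
                    List<Integer> extended = new ArrayList<>(h.ids);
                    extended.add(token);
                    candidates.add(
                            new Hypothesis(extended, h.logProb + logProbs[token], token == eosTokenId));
                }
            }
            if (!anyOpen) {
                break; // every hypothesis has reached EOS
            }
            // Keep only the `beam` most probable candidates.
            candidates.sort(Comparator.comparingDouble((Hypothesis c) -> c.logProb).reversed());
            beams = new ArrayList<>(candidates.subList(0, Math.min(beam, candidates.size())));
        }
        return beams.get(0).ids;
    }

    /** Toy model over tokens {0, 1, 2}: token 1 looks worse at first but leads to a likely EOS (2). */
    public static double[] toyLogProbs(List<Integer> ids) {
        double[] probs;
        if (ids.size() == 1) {
            probs = new double[] {0.5, 0.4, 0.1};
        } else if (ids.get(ids.size() - 1) == 1) {
            probs = new double[] {0.05, 0.05, 0.9};
        } else {
            probs = new double[] {0.4, 0.3, 0.3};
        }
        double[] logProbs = new double[probs.length];
        for (int i = 0; i < probs.length; i++) {
            logProbs[i] = Math.log(probs[i]);
        }
        return logProbs;
    }

    public static void main(String[] args) {
        System.out.println(beamSearch(BeamSearchSketch::toyLogProbs, List.of(0), 2, 2, 3)); // prints [0, 1, 2]
        System.out.println(beamSearch(BeamSearchSketch::toyLogProbs, List.of(0), 1, 2, 3)); // prints [0, 0, 0]
    }
}
```

With a beam of 1 the procedure reduces to greedy search and returns the sequence with probability 0.2, while a beam of 2 recovers the sequence with probability 0.36.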
contrastiveSearch
Generates text using contrastive search.
Parameters:
inputIds - the input token ids
Returns:
the output token ids, stored as an NDArray
Throws:
TranslateException - if the forward pass fails
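Contrastive search, as commonly described, picks among the top candidate tokens the one that balances model confidence against similarity to what has already been generated. The plain-Java sketch below shows only that selection step and is not the DJL implementation: ContrastiveSearchSketch, select, and the weight alpha are names assumed for this example, and the hidden-state vectors are toy values.

```java
// Hypothetical sketch, independent of DJL: scores each candidate token as
//   (1 - alpha) * probability - alpha * (max cosine similarity to the context)
// and returns the index of the best-scoring candidate.
public class ContrastiveSearchSketch {

    /** Cosine similarity of two vectors; returns 0 if either vector has zero length. */
    public static double cosine(double[] a, double[] b) {
        double dot = 0.0;
        double normA = 0.0;
        double normB = 0.0;
        for (int i = 0; i < a.length; i++) {
            dot += a[i] * b[i];
            normA += a[i] * a[i];
            normB += b[i] * b[i];
        }
        if (normA == 0.0 || normB == 0.0) {
            return 0.0;
        }
        return dot / (Math.sqrt(normA) * Math.sqrt(normB));
    }

    /**
     * @param candidateProbs model probability of each candidate token
     * @param candidateHidden hidden-state vector the model would produce for each candidate
     * @param contextHidden hidden-state vectors of the tokens generated so far
     * @param alpha weight of the similarity penalty, between 0 and 1
     */
    public static int select(
            double[] candidateProbs,
            double[][] candidateHidden,
            double[][] contextHidden,
            double alpha) {
        int best = -1;
        double bestScore = Double.NEGATIVE_INFINITY;
        for (int i = 0; i < candidateProbs.length; i++) {
            double maxSimilarity = -1.0; // lowest possible cosine similarity
            for (double[] h : contextHidden) {
                maxSimilarity = Math.max(maxSimilarity, cosine(candidateHidden[i], h));
            }
            double score = (1 - alpha) * candidateProbs[i] - alpha * maxSimilarity;
            if (score > bestScore) {
                bestScore = score;
                best = i;
            }
        }
        return best;
    }

    public static void main(String[] args) {
        double[] probs = {0.6, 0.4};
        double[][] candidates = {{1, 0}, {0, 1}}; // candidate 0 repeats the context direction
        double[][] context = {{1, 0}};
        System.out.println(select(probs, candidates, context, 0.6)); // prints 1
        System.out.println(select(probs, candidates, context, 0.0)); // prints 0
    }
}
```

With alpha set to 0 the penalty vanishes and the choice is the same as in greedy search; larger values of alpha push the choice away from tokens whose representation resembles the existing context.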
generate
Generates text from the input token ids.
Parameters:
inputIds - the input token ids
Returns:
the generated token ids
Throws:
TranslateException - if prediction fails
getPositionOffset
Returns the value of positionOffset.
Returns:
the value of positionOffset
getEndPosition
public long[] getEndPosition()
Returns the end position of each sentence, determined by the EOS token id or by reaching maxSeqLength.
Returns:
the end position of each sentence
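One way to picture an end position per sentence is sketched below in plain Java. This is an assumption-laden illustration, not the DJL implementation: EndPositionSketch, endPositions, and promptLength are hypothetical names, and the sketch assumes the end position is the index of the first EOS token after the prompt, or the length limit when no EOS was generated.

```java
// Hypothetical sketch, independent of DJL.
public class EndPositionSketch {

    /**
     * For each sequence in the batch, returns the index of the first EOS token
     * at or after promptLength, or the length limit if no EOS is found.
     */
    public static long[] endPositions(
            long[][] outputIds, int promptLength, long eosTokenId, int maxSeqLength) {
        long[] end = new long[outputIds.length];
        for (int b = 0; b < outputIds.length; b++) {
            int limit = Math.min(outputIds[b].length, maxSeqLength);
            end[b] = limit; // default: generation ran to the length limit
            for (int t = promptLength; t < limit; t++) {
                if (outputIds[b][t] == eosTokenId) {
                    end[b] = t; // stop at the first EOS after the prompt
                    break;
                }
            }
        }
        return end;
    }

    public static void main(String[] args) {
        long[][] outputIds = {
            {5, 6, 7, 2, 0, 0}, // EOS (2) at index 3, followed by padding
            {5, 6, 7, 8, 9, 10} // no EOS: runs to the length limit
        };
        long[] end = endPositions(outputIds, 2, 2, 6);
        System.out.println(java.util.Arrays.toString(end)); // prints [3, 6]
    }
}
```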