Package ai.djl.modality.nlp.generate
Class TextGenerator
java.lang.Object
ai.djl.modality.nlp.generate.TextGenerator
TextGenerator is an LMSearch (language model search) which contains multiple autoregressive search methods.
It has a Predictor from NDList to CausalLMOutput, which is called inside an autoregressive inference loop.
Constructor Summary
Constructors
  TextGenerator(Predictor<NDList, CausalLMOutput> predictor, String searchName, SearchConfig searchConfig)
    Constructs a new TextGenerator instance.
Method Summary
  beamSearch(NDArray inputIds)
    Generates text using beam search.
  contrastiveSearch(NDArray inputIds)
    Generates text using contrastive search.
  generate(NDArray inputIds)
    Generate function call to generate text.
  long[] getEndPosition()
    Returns the end position of each sentence induced by EOS tokenId or reaching maxSeqLength.
  getPositionOffset()
    Returns the value of the positionOffset.
  greedySearch(NDArray inputIds)
    Executes greedy search.
Constructor Details
TextGenerator
public TextGenerator(Predictor<NDList, CausalLMOutput> predictor, String searchName, SearchConfig searchConfig)
Constructs a new TextGenerator instance.
Parameters:
  predictor - the language model
  searchName - the autoregressive search name
  searchConfig - the autoregressive search configuration
Method Details
greedySearch
public NDArray greedySearch(NDArray inputIds) throws TranslateException
Executes greedy search.
Parameters:
  inputIds - the input token ids
Returns:
  the output token ids stored as NDArray and the endPosition of each sentence
Throws:
  TranslateException - if forward fails
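As a rough illustration of what a greedy autoregressive loop does, the sketch below repeatedly appends the highest-scoring next token until an EOS token appears or a maximum length is reached. It is plain Java and does not use DJL: the class GreedySearchSketch, the toyScores model, and the eosTokenId and maxSeqLength parameters are hypothetical stand-ins chosen for this example, not the library's implementation.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.function.Function;

/** Illustrative only: greedy decoding over a toy scoring function, not DJL code. */
final class GreedySearchSketch {

    /** Returns the index of the largest score (the first one wins on ties). */
    static int argmax(double[] scores) {
        int best = 0;
        for (int i = 1; i < scores.length; i++) {
            if (scores[i] > scores[best]) {
                best = i;
            }
        }
        return best;
    }

    /**
     * Appends the highest-scoring next token until EOS is produced or the
     * sequence holds maxSeqLength tokens (prompt included).
     */
    static List<Integer> greedySearch(
            List<Integer> inputIds,
            Function<List<Integer>, double[]> nextTokenScores,
            int eosTokenId,
            int maxSeqLength) {
        List<Integer> ids = new ArrayList<>(inputIds);
        while (ids.size() < maxSeqLength) {
            int next = argmax(nextTokenScores.apply(ids));
            ids.add(next);
            if (next == eosTokenId) {
                break;
            }
        }
        return ids;
    }

    /** Hypothetical toy model over a 4-token vocabulary: prefers (last + 1) % 4. */
    static double[] toyScores(List<Integer> ids) {
        double[] scores = new double[4];
        scores[(ids.get(ids.size() - 1) + 1) % 4] = 1.0;
        return scores;
    }

    public static void main(String[] args) {
        // Token 3 plays the role of EOS, so generation stops once it is produced.
        // Prints [0, 1, 2, 3]
        System.out.println(greedySearch(List.of(0), GreedySearchSketch::toyScores, 3, 10));
    }
}
```

In a real setting the scoring function would be a forward pass of the language model, which is the step that can fail with a TranslateException.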
beamSearch
public NDArray beamSearch(NDArray inputIds) throws TranslateException
Generates text using beam search.
Parameters:
  inputIds - the input token ids
Returns:
  the output token ids stored as NDArray and the endPosition of each sentence
Throws:
  TranslateException - if the forward pass fails
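The general idea of beam search can be sketched in plain Java as follows: keep the best few partial sequences ranked by cumulative log-probability, extend each by every vocabulary token, and prune back to the beam width. This is a simplified, DJL-free illustration; BeamSearchSketch, Hypothesis, toyLogProbs, and the beam, eosTokenId, and maxSeqLength parameters are hypothetical names for this example and do not describe how TextGenerator implements the search.

```java
import java.util.ArrayList;
import java.util.Comparator;
import java.util.List;
import java.util.function.Function;

/** Illustrative only: beam search over a toy log-probability function, not DJL code. */
final class BeamSearchSketch {

    /** A partial sequence together with its cumulative log-probability. */
    static final class Hypothesis {
        final List<Integer> ids;
        final double logProb;

        Hypothesis(List<Integer> ids, double logProb) {
            this.ids = ids;
            this.logProb = logProb;
        }
    }

    static List<Integer> beamSearch(
            List<Integer> inputIds,
            Function<List<Integer>, double[]> nextTokenLogProbs,
            int beam,
            int eosTokenId,
            int maxSeqLength) {
        List<Hypothesis> beams = new ArrayList<>();
        beams.add(new Hypothesis(new ArrayList<>(inputIds), 0.0));
        for (int len = inputIds.size(); len < maxSeqLength; len++) {
            List<Hypothesis> candidates = new ArrayList<>();
            boolean expanded = false;
            for (Hypothesis h : beams) {
                // A hypothesis is finished once it has generated an EOS token.
                boolean finished = h.ids.size() > inputIds.size()
                        && h.ids.get(h.ids.size() - 1) == eosTokenId;
                if (finished) {
                    candidates.add(h);
                    continue;
                }
                expanded = true;
                double[] logProbs = nextTokenLogProbs.apply(h.ids);
                for (int token = 0; token < logProbs.length; token++) {
                    List<Integer> extended = new ArrayList<>(h.ids);
                    extended.add(token);
                    candidates.add(new Hypothesis(extended, h.logProb + logProbs[token]));
                }
            }
            if (!expanded) {
                break; // every surviving hypothesis already ended with EOS
            }
            // Keep only the `beam` highest-scoring candidates.
            candidates.sort(Comparator.comparingDouble((Hypothesis h) -> h.logProb).reversed());
            beams = new ArrayList<>(candidates.subList(0, Math.min(beam, candidates.size())));
        }
        return beams.get(0).ids;
    }

    /** Hypothetical toy model over the vocabulary {0, 1, 2}, where 2 acts as EOS. */
    static double[] toyLogProbs(List<Integer> ids) {
        double[] probs;
        if (ids.size() == 1) {
            probs = new double[] {0.40, 0.55, 0.05};
        } else if (ids.get(ids.size() - 1) == 0) {
            probs = new double[] {0.05, 0.05, 0.90};
        } else {
            probs = new double[] {0.30, 0.30, 0.40};
        }
        double[] logProbs = new double[probs.length];
        for (int i = 0; i < probs.length; i++) {
            logProbs[i] = Math.log(probs[i]);
        }
        return logProbs;
    }

    public static void main(String[] args) {
        // With beam = 1 the search is greedy and commits to token 1 first: prints [0, 1, 2]
        System.out.println(beamSearch(List.of(0), BeamSearchSketch::toyLogProbs, 1, 2, 5));
        // With beam = 2 the less likely first token leads to a better total: prints [0, 0, 2]
        System.out.println(beamSearch(List.of(0), BeamSearchSketch::toyLogProbs, 2, 2, 5));
    }
}
```

The toy model is built so that the locally best first token leads to a lower overall probability (0.55 × 0.40 = 0.22) than the second-best one (0.40 × 0.90 = 0.36), which is the case where a wider beam pays off.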
contrastiveSearch
public NDArray contrastiveSearch(NDArray inputIds) throws TranslateException
Generates text using contrastive search.
Parameters:
  inputIds - the input token ids
Returns:
  the output token ids stored as NDArray
Throws:
  TranslateException - if the forward pass fails
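Contrastive search, as usually described, picks the next token among the k most probable candidates by trading off model confidence against similarity to the existing context: score = (1 - alpha) * p(token) - alpha * (maximum cosine similarity to any context token). The plain-Java sketch below shows only that selection step, under a strong simplification: fixed per-token embedding vectors stand in for the model's hidden states. ContrastiveSearchSketch, selectNext, the embeddings table, and the k and alpha parameters are hypothetical names for this example, not DJL's implementation.

```java
import java.util.Arrays;
import java.util.List;

/** Illustrative only: one selection step of contrastive search, not DJL code. */
final class ContrastiveSearchSketch {

    /** Cosine similarity of two non-zero vectors of equal length. */
    static double cosine(double[] a, double[] b) {
        double dot = 0.0;
        double normA = 0.0;
        double normB = 0.0;
        for (int i = 0; i < a.length; i++) {
            dot += a[i] * b[i];
            normA += a[i] * a[i];
            normB += b[i] * b[i];
        }
        return dot / (Math.sqrt(normA) * Math.sqrt(normB));
    }

    /**
     * Among the k most probable tokens, returns the one maximizing
     * (1 - alpha) * probability - alpha * max similarity to the context.
     * Assumption: embeddings[token] stands in for the model's hidden state.
     */
    static int selectNext(
            double[] probs, double[][] embeddings, List<Integer> context, int k, double alpha) {
        // Token ids ordered by descending probability.
        Integer[] order = new Integer[probs.length];
        for (int i = 0; i < order.length; i++) {
            order[i] = i;
        }
        Arrays.sort(order, (x, y) -> Double.compare(probs[y], probs[x]));

        int best = -1;
        double bestScore = Double.NEGATIVE_INFINITY;
        for (int c = 0; c < Math.min(k, order.length); c++) {
            int token = order[c];
            // Degeneration penalty: similarity to the closest context token.
            double maxSim = context.isEmpty() ? 0.0 : -1.0;
            for (int prev : context) {
                maxSim = Math.max(maxSim, cosine(embeddings[token], embeddings[prev]));
            }
            double score = (1 - alpha) * probs[token] - alpha * maxSim;
            if (score > bestScore) {
                bestScore = score;
                best = token;
            }
        }
        return best;
    }

    public static void main(String[] args) {
        double[] probs = {0.5, 0.4, 0.1};
        double[][] embeddings = {{1, 0}, {0, 1}, {1, 1}};
        List<Integer> context = List.of(0);
        // Token 0 is most probable but identical to the context, so it is penalized: prints 1
        System.out.println(selectNext(probs, embeddings, context, 2, 0.6));
        // With alpha = 0 the penalty vanishes and the choice is purely by probability: prints 0
        System.out.println(selectNext(probs, embeddings, context, 2, 0.0));
    }
}
```

In an actual model the candidate's representation would come from running the model on the context plus the candidate token, which is why the search needs extra forward passes.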
generate
Generates text from the given input token ids.
Parameters:
  inputIds - the input token ids
Returns:
  the generated token ids
Throws:
  TranslateException - if prediction fails
getPositionOffset
Returns the value of the positionOffset.
Returns:
  the value of positionOffset
getEndPosition
public long[] getEndPosition()
Returns the end position of each sentence, determined either by an EOS token id or by reaching maxSeqLength.
Returns:
  the end position of each sentence
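To make the notion of an end position concrete, the plain-Java sketch below computes one for each row of a batch of token ids: the position just after the first EOS token, or the maximum length if no EOS occurs. Whether the library reports the index of the EOS token itself or the position after it is not stated here, so the exclusive-end convention is an assumption, as are the names EndPositionSketch, endPositions, eosTokenId, and maxSeqLength.

```java
import java.util.Arrays;

/** Illustrative only: per-sentence end positions from a batch of token ids, not DJL code. */
final class EndPositionSketch {

    /**
     * For each row, returns the exclusive end index: one past the first EOS
     * token, or min(row length, maxSeqLength) when the row contains no EOS.
     */
    static long[] endPositions(long[][] outputIds, long eosTokenId, int maxSeqLength) {
        long[] end = new long[outputIds.length];
        for (int row = 0; row < outputIds.length; row++) {
            int limit = Math.min(outputIds[row].length, maxSeqLength);
            end[row] = limit;
            for (int i = 0; i < limit; i++) {
                if (outputIds[row][i] == eosTokenId) {
                    end[row] = i + 1;
                    break;
                }
            }
        }
        return end;
    }

    public static void main(String[] args) {
        long[][] batch = {
            {5, 6, 2, 0, 0}, // EOS (id 2) at index 2, followed by padding
            {5, 6, 7, 8, 9}  // no EOS: ends at the maximum length
        };
        // Prints [3, 5]
        System.out.println(Arrays.toString(endPositions(batch, 2, 5)));
    }
}
```

A caller could use such positions to trim padding from each row before decoding the token ids back into text.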