Package ai.djl.modality.nlp.generate
Class TextGenerator
java.lang.Object
    ai.djl.modality.nlp.generate.TextGenerator
public class TextGenerator extends java.lang.Object
TextGenerator is an LMSearch (language model search) that provides several autoregressive search methods. It holds a Predictor from NDList to CausalLMOutput, which is called inside the autoregressive inference loop.
-
Constructor Summary
Constructor                                                                                                          Description
TextGenerator(Predictor<NDList,CausalLMOutput> predictor, java.lang.String searchName, SearchConfig searchConfig)    Constructs a new TextGenerator instance.
-
Method Summary
Modifier and Type    Method                                 Description
NDArray              beamSearch(NDArray inputIds)           Generates text using beam search.
NDArray              contrastiveSearch(NDArray inputIds)    Generates text using contrastive search.
NDArray              generate(NDArray inputIds)             Generates text from the input token ids.
long[]               getEndPosition()                       Returns the end position of each sentence induced by EOS tokenId or reaching maxSeqLength.
NDArray              getPositionOffset()                    Returns the value of the positionOffset.
NDArray              greedySearch(NDArray inputIds)         Executes greedy search.
-
Constructor Detail
-
TextGenerator
public TextGenerator(Predictor<NDList,CausalLMOutput> predictor, java.lang.String searchName, SearchConfig searchConfig)
Constructs a new TextGenerator instance.
Parameters:
    predictor - the language model
    searchName - the autoregressive search name
    searchConfig - the autoregressive search configuration
-
Method Detail
-
greedySearch
public NDArray greedySearch(NDArray inputIds) throws TranslateException
Executes greedy search.
Parameters:
    inputIds - the input token ids
Returns:
    the output token ids stored as an NDArray and the endPosition of each sentence
Throws:
    TranslateException - if forward fails
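To make the autoregressive loop behind greedy search concrete, here is a minimal plain-Java sketch. It is not DJL's implementation: the NextTokenScorer interface, the COUNT_UP toy model, and the Result class are hypothetical stand-ins for the Predictor and NDArray types, and treating the end position as the number of valid tokens (EOS included) is an assumption of this sketch.

```java
import java.util.Arrays;

final class GreedySearchSketch {
    /** Hypothetical stand-in for the language model: one score per vocabulary token. */
    interface NextTokenScorer {
        double[] score(int[] ids, int length);
    }

    /** Generated ids plus the end position (assumed here: count of valid tokens, EOS included). */
    static final class Result {
        final int[] tokenIds;
        final int endPosition;

        Result(int[] tokenIds, int endPosition) {
            this.tokenIds = tokenIds;
            this.endPosition = endPosition;
        }
    }

    /** Toy model over a 5-token vocabulary: always prefers (last token + 1) mod 5. */
    static final NextTokenScorer COUNT_UP = (ids, length) -> {
        double[] scores = new double[5];
        scores[(ids[length - 1] + 1) % 5] = 1.0;
        return scores;
    };

    static Result greedySearch(int[] inputIds, NextTokenScorer model, int eosTokenId, int maxSeqLength) {
        if (inputIds.length == 0 || inputIds.length > maxSeqLength) {
            throw new IllegalArgumentException("input length must be in [1, maxSeqLength]");
        }
        int[] ids = Arrays.copyOf(inputIds, maxSeqLength);
        int length = inputIds.length;
        while (length < maxSeqLength) {
            double[] scores = model.score(ids, length);
            // Greedy step: take the single highest-scoring token.
            int best = 0;
            for (int t = 1; t < scores.length; t++) {
                if (scores[t] > scores[best]) {
                    best = t;
                }
            }
            ids[length++] = best;
            if (best == eosTokenId) {
                break; // the sentence ends at EOS
            }
        }
        return new Result(Arrays.copyOf(ids, length), length);
    }

    public static void main(String[] args) {
        Result r = greedySearch(new int[] {1, 2}, COUNT_UP, 0, 10);
        System.out.println(Arrays.toString(r.tokenIds) + " end=" + r.endPosition);
    }
}
```

The loop stops on whichever comes first, the EOS token or maxSeqLength, which mirrors the two stopping conditions named in the getEndPosition description.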
-
beamSearch
public NDArray beamSearch(NDArray inputIds) throws TranslateException
Generates text using beam search.
Parameters:
    inputIds - the input token ids
Returns:
    the output token ids stored as an NDArray and the endPosition of each sentence
Throws:
    TranslateException - if forward fails
See Also:
    Beam Search
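As an illustration of how beam search differs from greedy search, the following plain-Java sketch keeps the best few partial sequences at each step instead of only one. It is a toy, not DJL's implementation: the PROBS transition table, the Hypothesis class, and the beamWidth and newTokens parameters are all hypothetical, and EOS handling is left out for brevity.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Comparator;
import java.util.List;

final class BeamSearchSketch {
    /** Hypothetical toy model: PROBS[last][next] is the probability of next given the last token. */
    static final double[][] PROBS = {
        {0.0, 0.0, 0.0, 0.0}, // token 0 is unused in this toy
        {0.0, 0.0, 0.6, 0.4}, // after token 1
        {0.0, 0.5, 0.0, 0.5}, // after token 2
        {0.0, 0.9, 0.1, 0.0}, // after token 3
    };

    /** A partial sequence with its accumulated log-probability. */
    static final class Hypothesis {
        final int[] tokens;
        final double logProb;

        Hypothesis(int[] tokens, double logProb) {
            this.tokens = tokens;
            this.logProb = logProb;
        }
    }

    static int[] beamSearch(int[] inputIds, int beamWidth, int newTokens) {
        List<Hypothesis> beams = new ArrayList<>();
        beams.add(new Hypothesis(inputIds.clone(), 0.0));
        for (int step = 0; step < newTokens; step++) {
            // Expand every surviving hypothesis by every token with nonzero probability.
            List<Hypothesis> candidates = new ArrayList<>();
            for (Hypothesis h : beams) {
                double[] probs = PROBS[h.tokens[h.tokens.length - 1]];
                for (int next = 0; next < probs.length; next++) {
                    if (probs[next] <= 0.0) {
                        continue;
                    }
                    int[] extended = Arrays.copyOf(h.tokens, h.tokens.length + 1);
                    extended[extended.length - 1] = next;
                    candidates.add(new Hypothesis(extended, h.logProb + Math.log(probs[next])));
                }
            }
            if (candidates.isEmpty()) {
                break; // nothing can be extended
            }
            // Keep only the beamWidth highest-scoring candidates (stable sort, best first).
            candidates.sort(Comparator.comparingDouble((Hypothesis c) -> -c.logProb));
            beams = new ArrayList<>(candidates.subList(0, Math.min(beamWidth, candidates.size())));
        }
        return beams.get(0).tokens;
    }

    public static void main(String[] args) {
        System.out.println(Arrays.toString(beamSearch(new int[] {1}, 2, 2)));
        System.out.println(Arrays.toString(beamSearch(new int[] {1}, 1, 2)));
    }
}
```

With a beam width of 2 the sketch finds the sequence 1, 3, 1 (probability 0.4 × 0.9 = 0.36), whereas a width of 1, which is equivalent to greedy search, commits to token 2 first and can reach at most 0.6 × 0.5 = 0.30.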
-
contrastiveSearch
public NDArray contrastiveSearch(NDArray inputIds) throws TranslateException
Generates text using contrastive search.
Parameters:
    inputIds - the input token ids
Returns:
    the output token ids stored as an NDArray
Throws:
    TranslateException - if forward fails
See Also:
    Contrastive Search
-
generate
public NDArray generate(NDArray inputIds) throws TranslateException
Generates text from the input token ids.
Parameters:
    inputIds - the input token ids
Returns:
    the generated token ids
Throws:
    TranslateException - if prediction fails
-
getPositionOffset
public NDArray getPositionOffset()
Returns the value of the positionOffset.
Returns:
    the value of positionOffset
-
getEndPosition
public long[] getEndPosition()
Returns the end position of each sentence induced by EOS tokenId or reaching maxSeqLength.
Returns:
    the end position of each sentence
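One way to read "end position induced by EOS tokenId or reaching maxSeqLength" for a batch is sketched below in plain Java. The endPositions helper is hypothetical, and the convention that the end position is an exclusive index that includes the EOS token is an assumption of this sketch, not something the documentation above confirms.

```java
import java.util.Arrays;

final class EndPositionSketch {
    /**
     * For each row of token ids, returns the index just past the first EOS token,
     * or the row length capped at maxSeqLength when no EOS occurs within that limit.
     */
    static long[] endPositions(long[][] tokenIds, long eosTokenId, int maxSeqLength) {
        long[] end = new long[tokenIds.length];
        for (int row = 0; row < tokenIds.length; row++) {
            int limit = Math.min(tokenIds[row].length, maxSeqLength);
            long position = limit; // default: the sentence ran to the length limit
            for (int i = 0; i < limit; i++) {
                if (tokenIds[row][i] == eosTokenId) {
                    position = i + 1; // assumed convention: EOS is counted
                    break;
                }
            }
            end[row] = position;
        }
        return end;
    }

    public static void main(String[] args) {
        long[][] batch = {
            {5, 6, 2, 0, 0}, // EOS (id 2) at index 2
            {5, 6, 7, 8, 9}, // no EOS within the first 4 tokens
            {2},             // EOS immediately
        };
        System.out.println(Arrays.toString(endPositions(batch, 2, 4)));
    }
}
```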