Package ai.djl.modality.nlp.generate
Class TextGenerator
java.lang.Object
    ai.djl.modality.nlp.generate.TextGenerator
public class TextGenerator extends java.lang.Object
TextGenerator is an LMSearch (language model search) that provides several autoregressive search methods. It holds a Predictor from NDList to CausalLMOutput, which is called inside an autoregressive inference loop.
Constructor Summary
TextGenerator(Predictor<NDList,CausalLMOutput> predictor, java.lang.String searchName, SearchConfig searchConfig)
    Constructs a new TextGenerator instance.
Method Summary
NDArray  beamSearch(NDArray inputIds)
    Generates text using beam search.
NDArray  contrastiveSearch(NDArray inputIds)
    Generates text using contrastive search.
NDArray  generate(NDArray inputIds)
    Generates text using the configured autoregressive search method.
long[]   getEndPosition()
    Returns the end position of each sentence, determined by the EOS token id or by reaching maxSeqLength.
NDArray  getPositionOffset()
    Returns the value of positionOffset.
NDArray  greedySearch(NDArray inputIds)
    Executes greedy search.
Constructor Detail
TextGenerator
public TextGenerator(Predictor<NDList,CausalLMOutput> predictor, java.lang.String searchName, SearchConfig searchConfig)
Constructs a new TextGenerator instance.
Parameters:
predictor - the language model
searchName - the autoregressive search name
searchConfig - the autoregressive search configuration
Method Detail
greedySearch
public NDArray greedySearch(NDArray inputIds) throws TranslateException
Executes greedy search.
Parameters:
inputIds - the input token ids
Returns:
the output token ids stored as an NDArray (the end position of each sentence is available from getEndPosition())
Throws:
TranslateException - if the forward pass fails
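Greedy search appends the single most likely next token at every step until an EOS token is produced or the sequence reaches its maximum length. The sketch below illustrates that loop in plain Java under simplifying assumptions: it works on one int[] sequence rather than a batched NDArray, and the Function standing in for the Predictor (returning next-token logits) is hypothetical. It is not DJL's implementation.

```java
import java.util.Arrays;
import java.util.function.Function;

/** Greedy decoding over a toy, single-sequence "language model" (illustrative, not DJL). */
final class GreedySearchSketch {

    /**
     * Appends the highest-scoring next token until EOS is emitted or the
     * sequence reaches maxSeqLength. The model maps a token sequence to
     * next-token logits (hypothetical stand-in for the Predictor).
     */
    static int[] greedySearch(int[] inputIds, Function<int[], double[]> model,
                              int eosTokenId, int maxSeqLength) {
        int[] seq = inputIds.clone();
        while (seq.length < maxSeqLength) {
            int next = argmax(model.apply(seq));
            seq = Arrays.copyOf(seq, seq.length + 1); // grow by one slot
            seq[seq.length - 1] = next;
            if (next == eosTokenId) {
                break; // sentence finished
            }
        }
        return seq;
    }

    /** Index of the largest value; the first maximum wins on ties. */
    static int argmax(double[] xs) {
        int best = 0;
        for (int i = 1; i < xs.length; i++) {
            if (xs[i] > xs[best]) {
                best = i;
            }
        }
        return best;
    }

    public static void main(String[] args) {
        // Toy model: always favours (last token + 1) mod 5; token 0 plays the role of EOS.
        Function<int[], double[]> model = seq -> {
            double[] logits = new double[5];
            logits[(seq[seq.length - 1] + 1) % 5] = 1.0;
            return logits;
        };
        System.out.println(Arrays.toString(greedySearch(new int[] {1}, model, 0, 10))); // [1, 2, 3, 4, 0]
    }
}
```

The two stopping conditions mirror the ones named for getEndPosition(): the loop ends on EOS or when maxSeqLength is reached.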
beamSearch
public NDArray beamSearch(NDArray inputIds) throws TranslateException
Generates text using beam search.
Parameters:
inputIds - the input token ids
Returns:
the output token ids stored as an NDArray (the end position of each sentence is available from getEndPosition())
Throws:
TranslateException - if the forward pass fails
See Also:
Beam Search
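Beam search keeps the beamSize most probable partial sequences at each step instead of committing to one token, which lets it recover a continuation that greedy search would discard. The following is a minimal sketch of the idea, assuming a single input sequence, a fixed number of steps, and a hypothetical Function that returns next-token probabilities; it omits EOS handling and length normalization and does not reproduce DJL's batched implementation.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.function.Function;

/** Beam search over a toy next-token distribution (illustrative, not DJL). */
final class BeamSearchSketch {

    /** A partial hypothesis: its tokens and cumulative log-probability. */
    static final class Beam {
        final int[] tokens;
        final double logProb;

        Beam(int[] tokens, double logProb) {
            this.tokens = tokens;
            this.logProb = logProb;
        }
    }

    /** Runs a fixed number of steps and returns the highest-scoring sequence. */
    static int[] beamSearch(int[] inputIds, Function<int[], double[]> nextTokenProbs,
                            int beamSize, int steps) {
        List<Beam> beams = new ArrayList<>();
        beams.add(new Beam(inputIds.clone(), 0.0));
        for (int step = 0; step < steps; step++) {
            List<Beam> candidates = new ArrayList<>();
            for (Beam beam : beams) {
                double[] probs = nextTokenProbs.apply(beam.tokens);
                // Expand every live beam by every vocabulary token.
                for (int token = 0; token < probs.length; token++) {
                    int[] extended = Arrays.copyOf(beam.tokens, beam.tokens.length + 1);
                    extended[extended.length - 1] = token;
                    candidates.add(new Beam(extended, beam.logProb + Math.log(probs[token])));
                }
            }
            // Keep only the beamSize most probable hypotheses.
            candidates.sort((a, b) -> Double.compare(b.logProb, a.logProb));
            beams = new ArrayList<>(candidates.subList(0, Math.min(beamSize, candidates.size())));
        }
        return beams.get(0).tokens;
    }

    public static void main(String[] args) {
        // Toy model: token 0 looks best first, but token 1 leads to a likelier pair.
        Function<int[], double[]> model = seq -> {
            if (seq.length == 1) {
                return new double[] {0.5, 0.4, 0.1};
            }
            int last = seq[seq.length - 1];
            if (last == 0) {
                return new double[] {0.4, 0.3, 0.3};
            }
            if (last == 1) {
                return new double[] {0.05, 0.9, 0.05};
            }
            return new double[] {1.0 / 3, 1.0 / 3, 1.0 / 3};
        };
        System.out.println(Arrays.toString(beamSearch(new int[] {0}, model, 1, 2))); // [0, 0, 0]
        System.out.println(Arrays.toString(beamSearch(new int[] {0}, model, 2, 2))); // [0, 1, 1]
    }
}
```

With beamSize 1 the search degenerates to greedy decoding and ends with probability 0.5 × 0.4 = 0.2; with beamSize 2 it keeps the runner-up and finds the sequence with probability 0.4 × 0.9 = 0.36.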
contrastiveSearch
public NDArray contrastiveSearch(NDArray inputIds) throws TranslateException
Generates text using contrastive search.
Parameters:
inputIds - the input token ids
Returns:
the output token ids stored as an NDArray
Throws:
TranslateException - if the forward pass fails
See Also:
Contrastive Search
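Contrastive search restricts each step to the topK most probable tokens and then picks the candidate v that maximizes (1 - alpha) * p(v) - alpha * max_j sim(h_v, h_j), where the penalty term discourages tokens whose representation is too similar to anything already in the context. The sketch below is an assumption-laden illustration of that scoring rule: it substitutes a fixed, hypothetical embedding table for the model's hidden states, uses cosine similarity, and runs a fixed number of steps on a single sequence. DJL's contrastiveSearch and its SearchConfig parameters may differ.

```java
import java.util.Arrays;
import java.util.function.Function;

/** Contrastive search over a toy model with fixed token embeddings (illustrative, not DJL). */
final class ContrastiveSearchSketch {

    static double cosine(double[] u, double[] v) {
        double dot = 0, nu = 0, nv = 0;
        for (int i = 0; i < u.length; i++) {
            dot += u[i] * v[i];
            nu += u[i] * u[i];
            nv += v[i] * v[i];
        }
        return dot / (Math.sqrt(nu) * Math.sqrt(nv));
    }

    /**
     * At each step, scores the topK most probable tokens v by
     * (1 - alpha) * p(v) - alpha * max_j cos(emb[v], emb[x_j]) over context tokens x_j,
     * then appends the best-scoring token.
     */
    static int[] contrastiveSearch(int[] inputIds, Function<int[], double[]> nextTokenProbs,
                                   double[][] embeddings, int topK, double alpha, int steps) {
        int[] seq = inputIds.clone();
        for (int step = 0; step < steps; step++) {
            double[] probs = nextTokenProbs.apply(seq);
            // Candidate token ids sorted by model probability, highest first.
            Integer[] order = new Integer[probs.length];
            for (int i = 0; i < order.length; i++) {
                order[i] = i;
            }
            Arrays.sort(order, (a, b) -> Double.compare(probs[b], probs[a]));

            int best = -1;
            double bestScore = Double.NEGATIVE_INFINITY;
            for (int k = 0; k < Math.min(topK, order.length); k++) {
                int v = order[k];
                // Degeneration penalty: similarity to the closest context token.
                double penalty = Double.NEGATIVE_INFINITY;
                for (int t : seq) {
                    penalty = Math.max(penalty, cosine(embeddings[v], embeddings[t]));
                }
                double score = (1 - alpha) * probs[v] - alpha * penalty;
                if (score > bestScore) {
                    bestScore = score;
                    best = v;
                }
            }
            seq = Arrays.copyOf(seq, seq.length + 1);
            seq[seq.length - 1] = best;
        }
        return seq;
    }

    public static void main(String[] args) {
        // Tokens 0 and 1 share an embedding (near-duplicates); token 2 is orthogonal.
        double[][] embeddings = {{1, 0}, {1, 0}, {0, 1}};
        Function<int[], double[]> model = seq -> new double[] {0.1, 0.5, 0.4};
        System.out.println(Arrays.toString(
                contrastiveSearch(new int[] {0}, model, embeddings, 2, 0.0, 2))); // [0, 1, 1]
        System.out.println(Arrays.toString(
                contrastiveSearch(new int[] {0}, model, embeddings, 2, 0.6, 2))); // [0, 2, 1]
    }
}
```

With alpha = 0 the rule reduces to greedy selection within the top-k set; raising alpha to 0.6 makes the search prefer token 2 first because token 1 duplicates the context.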
generate
public NDArray generate(NDArray inputIds) throws TranslateException
Generates text using the configured autoregressive search method.
Parameters:
inputIds - the input token ids
Returns:
the generated token ids
Throws:
TranslateException - if prediction fails
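Because the constructor receives a searchName, one plausible reading is that generate selects a search method by that name. The sketch below shows such name-based dispatch, assuming hypothetical strategy names ("greedy", "beam") and stand-in strategies over int[] sequences; the names DJL actually accepts are not stated in this page. It resolves the strategy in the constructor so that an unknown name fails fast, which is a design choice of the sketch, not a documented DJL behavior.

```java
import java.util.Arrays;
import java.util.Map;
import java.util.function.UnaryOperator;

/** Name-based dispatch to a decoding strategy (illustrative; names are hypothetical). */
final class GenerateDispatchSketch {

    private final UnaryOperator<int[]> strategy;

    GenerateDispatchSketch(String searchName, Map<String, UnaryOperator<int[]>> strategies) {
        UnaryOperator<int[]> s = strategies.get(searchName);
        if (s == null) {
            // Reject unknown names up front instead of on the first generate() call.
            throw new IllegalArgumentException("Unsupported search name: " + searchName);
        }
        this.strategy = s;
    }

    /** Delegates to whichever search strategy was selected by name. */
    int[] generate(int[] inputIds) {
        return strategy.apply(inputIds);
    }

    public static void main(String[] args) {
        // Stand-in strategies; real ones would run the greedy or beam search loops.
        Map<String, UnaryOperator<int[]>> strategies = Map.of(
                "greedy", ids -> new int[] {ids[0], 1},
                "beam", ids -> new int[] {ids[0], 2});
        GenerateDispatchSketch generator = new GenerateDispatchSketch("beam", strategies);
        System.out.println(Arrays.toString(generator.generate(new int[] {7}))); // [7, 2]
    }
}
```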
getPositionOffset
public NDArray getPositionOffset()
Returns the value of positionOffset.
Returns:
the value of positionOffset
getEndPosition
public long[] getEndPosition()
Returns the end position of each sentence, determined by the EOS token id or by reaching maxSeqLength.
Returns:
the end position of each sentence
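An end position determined by "the EOS token id or reaching maxSeqLength" can be computed by scanning each generated row for the first EOS token. The sketch below assumes a particular convention that this page does not confirm: the end position is the index of the first EOS at or after the prompt, and a row with no EOS ends at its full length (taken to equal maxSeqLength). The long types mirror the long[] return type of getEndPosition(); the helper itself is hypothetical.

```java
import java.util.Arrays;

/** Derives per-sentence end positions from a batch of generated ids (illustrative). */
final class EndPositionSketch {

    /**
     * For each row, returns the index of the first EOS token at or after promptLength,
     * or the row length when no EOS appears (the row was cut off at maxSeqLength).
     */
    static long[] endPositions(long[][] outputIds, long eosTokenId, int promptLength) {
        long[] ends = new long[outputIds.length];
        for (int b = 0; b < outputIds.length; b++) {
            long[] row = outputIds[b];
            ends[b] = row.length; // default: sequence ran to its maximum length
            for (int t = promptLength; t < row.length; t++) {
                if (row[t] == eosTokenId) {
                    ends[b] = t;
                    break;
                }
            }
        }
        return ends;
    }

    public static void main(String[] args) {
        // Row 0 emits EOS (id 2) at index 2; row 1 never does.
        long[][] batch = {{5, 7, 2, 0, 0}, {5, 8, 9, 4, 6}};
        System.out.println(Arrays.toString(endPositions(batch, 2, 1))); // [2, 5]
    }
}
```

Starting the scan at promptLength keeps an EOS that happens to appear in the input prompt from being mistaken for the end of the generated text.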