Package ai.djl.modality.nlp.generate
Class TextGenerator
java.lang.Object
    ai.djl.modality.nlp.generate.TextGenerator
public class TextGenerator extends java.lang.Object
TextGenerator is an LMSearch (language model search) that contains multiple autoregressive search methods. It holds a Predictor from NDList to CausalLMOutput, which is called inside an autoregressive inference loop.
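The autoregressive inference loop can be sketched in plain Java. This is a minimal illustration, not DJL code: the hypothetical `model` function stands in for the `Predictor<NDList, CausalLMOutput>` (it maps the ids generated so far to next-token logits), `int[]` stands in for `NDArray`, and the hypothetical `selector` is a simple per-step decision rule such as greedy argmax.

```java
import java.util.Arrays;
import java.util.function.Function;

/** Hypothetical sketch of an autoregressive inference loop (not DJL's implementation). */
class AutoregressiveLoopSketch {

    /**
     * Calls the model once per step and appends one token until maxLength is reached
     * or the end-of-sequence token is produced.
     *
     * @param model      hypothetical stand-in for the predictor: token ids so far -> next-token logits
     * @param selector   per-step decision rule that picks the next token from the logits
     * @param prompt     the input token ids
     * @param maxLength  maximum total sequence length, prompt included
     * @param eosTokenId id of the end-of-sequence token
     */
    static int[] generate(
            Function<int[], float[]> model,
            Function<float[], Integer> selector,
            int[] prompt,
            int maxLength,
            int eosTokenId) {
        int[] tokens = Arrays.copyOf(prompt, prompt.length);
        while (tokens.length < maxLength) {
            float[] logits = model.apply(tokens); // one forward pass per generated token
            int next = selector.apply(logits);
            tokens = Arrays.copyOf(tokens, tokens.length + 1);
            tokens[tokens.length - 1] = next;
            if (next == eosTokenId) {
                break; // stop as soon as end-of-sequence is generated
            }
        }
        return tokens;
    }
}
```

Greedy, beam, and contrastive search differ in how they use the logits at each step (and, for beam search, in how many partial sequences they track), while the repeated model call is common to all of them.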
Constructor Summary
Constructors
TextGenerator(Predictor<NDList,CausalLMOutput> predictor, java.lang.String searchName, SearchConfig searchConfig)
    Constructs a new TextGenerator instance.
Method Summary
Modifier and Type    Method and Description
NDArray    beamSearch(NDArray inputIds)
    Generates text using beam search.
NDArray    contrastiveSearch(NDArray inputIds)
    Generates text using contrastive search.
NDArray    forward(NDArray inputIds)
    Forward function call to generate text.
NDArray    greedySearch(NDArray inputIds)
    Executes greedy search.
Constructor Detail
TextGenerator
public TextGenerator(Predictor<NDList,CausalLMOutput> predictor, java.lang.String searchName, SearchConfig searchConfig)
Constructs a new TextGenerator instance.
Parameters:
predictor - the language model
searchName - the autoregressive search name
searchConfig - the autoregressive search configuration
Method Detail
greedySearch
public NDArray greedySearch(NDArray inputIds) throws TranslateException
Executes greedy search.
Parameters:
inputIds - the input token ids
Returns:
the output token ids stored as an NDArray
Throws:
TranslateException - if the forward pass fails
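Greedy search keeps only the single most likely token at each step. The sketch below is a hypothetical plain-Java illustration, not DJL's implementation: it assumes a batch of equal-length prompts, a hypothetical `model` function standing in for the predictor, and hypothetical `eosTokenId`/`padTokenId` parameters. Rows that have already emitted the end-of-sequence token are padded so the batch stays rectangular, like a 2-D array of token ids.

```java
import java.util.Arrays;
import java.util.function.Function;

/** Hypothetical sketch of batched greedy search (not DJL's implementation). */
class BatchedGreedySketch {

    /** Returns the index of the largest logit; ties go to the lowest index. */
    static int argmax(float[] logits) {
        int best = 0;
        for (int i = 1; i < logits.length; i++) {
            if (logits[i] > logits[best]) {
                best = i;
            }
        }
        return best;
    }

    static int[][] greedySearch(
            Function<int[], float[]> model, // hypothetical: token ids so far -> next-token logits
            int[][] inputIds,               // one prompt per row, all rows the same length
            int maxNewTokens,
            int eosTokenId,
            int padTokenId) {
        int batch = inputIds.length;
        int[][] tokens = new int[batch][];
        boolean[] finished = new boolean[batch];
        for (int b = 0; b < batch; b++) {
            tokens[b] = inputIds[b].clone();
        }
        for (int step = 0; step < maxNewTokens; step++) {
            boolean allFinished = true;
            for (int b = 0; b < batch; b++) {
                // Finished rows receive padding instead of a new model call.
                int next = finished[b] ? padTokenId : argmax(model.apply(tokens[b]));
                tokens[b] = Arrays.copyOf(tokens[b], tokens[b].length + 1);
                tokens[b][tokens[b].length - 1] = next;
                if (next == eosTokenId) {
                    finished[b] = true;
                }
                allFinished &= finished[b];
            }
            if (allFinished) {
                break; // every row has produced end-of-sequence
            }
        }
        return tokens;
    }
}
```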
beamSearch
public NDArray beamSearch(NDArray inputIds) throws TranslateException
Generates text using beam search.
Parameters:
inputIds - the input token ids
Returns:
the output tensor
Throws:
TranslateException - if running the forward pass fails
See Also:
Beam Search
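Beam search keeps the `beamWidth` highest-scoring partial sequences at every step instead of only one, scoring each by its summed log-probability. The sketch below is a simplified, single-prompt illustration under stated assumptions: `model` is a hypothetical stand-in for the predictor, there is no end-of-sequence handling or length normalization, and it always runs for `maxNewTokens` steps. With `beamWidth = 1` it reduces to greedy search.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Comparator;
import java.util.List;
import java.util.function.Function;

/** Hypothetical sketch of beam search (not DJL's implementation). */
class BeamSearchSketch {

    /** A partial hypothesis: its tokens and cumulative log-probability. */
    static final class Beam {
        final int[] tokens;
        final double score; // sum of log-probabilities of the generated tokens

        Beam(int[] tokens, double score) {
            this.tokens = tokens;
            this.score = score;
        }
    }

    /** Converts raw logits to log-probabilities (numerically stable log-softmax). */
    static double[] logSoftmax(float[] logits) {
        double max = Double.NEGATIVE_INFINITY;
        for (float v : logits) {
            max = Math.max(max, v);
        }
        double sum = 0;
        for (float v : logits) {
            sum += Math.exp(v - max);
        }
        double logSum = max + Math.log(sum);
        double[] out = new double[logits.length];
        for (int i = 0; i < logits.length; i++) {
            out[i] = logits[i] - logSum;
        }
        return out;
    }

    static int[] beamSearch(
            Function<int[], float[]> model, // hypothetical: token ids so far -> next-token logits
            int[] inputIds,
            int beamWidth,
            int maxNewTokens) {
        List<Beam> beams = new ArrayList<>();
        beams.add(new Beam(inputIds.clone(), 0.0));
        for (int step = 0; step < maxNewTokens; step++) {
            // Extend every surviving beam by every vocabulary token.
            List<Beam> candidates = new ArrayList<>();
            for (Beam beam : beams) {
                double[] logProbs = logSoftmax(model.apply(beam.tokens));
                for (int token = 0; token < logProbs.length; token++) {
                    int[] extended = Arrays.copyOf(beam.tokens, beam.tokens.length + 1);
                    extended[extended.length - 1] = token;
                    candidates.add(new Beam(extended, beam.score + logProbs[token]));
                }
            }
            // Keep only the beamWidth highest-scoring hypotheses.
            candidates.sort(Comparator.comparingDouble((Beam b) -> b.score).reversed());
            beams = new ArrayList<>(candidates.subList(0, Math.min(beamWidth, candidates.size())));
        }
        return beams.get(0).tokens; // best-scoring hypothesis
    }
}
```

Keeping the runner-up alive lets beam search recover when the locally best first token leads to a poor continuation, at the cost of `beamWidth` model calls per step.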
contrastiveSearch
public NDArray contrastiveSearch(NDArray inputIds) throws TranslateException
Generates text using contrastive search.
Parameters:
inputIds - the input token ids
Returns:
the generated NDArray
Throws:
TranslateException - if the forward pass fails
See Also:
Contrastive Search
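Contrastive search restricts each step to the `k` most probable tokens and picks the candidate `v` maximizing `(1 - alpha) * p(v) - alpha * maxSim(v)`, where `maxSim(v)` is the highest cosine similarity between the candidate's representation and those of the tokens already in the context. The sketch below is a simplification: the hypothetical `hidden` function returns a fixed vector per token id, whereas the actual method uses the model's hidden states. `model`, `k`, and `alpha` are illustrative parameters, not DJL's `SearchConfig` API.

```java
import java.util.Arrays;
import java.util.function.Function;

/** Hypothetical sketch of contrastive search (not DJL's implementation). */
class ContrastiveSearchSketch {

    /** Converts logits to probabilities (numerically stable softmax). */
    static double[] softmax(float[] logits) {
        double max = Double.NEGATIVE_INFINITY;
        for (float v : logits) {
            max = Math.max(max, v);
        }
        double[] out = new double[logits.length];
        double sum = 0;
        for (int i = 0; i < logits.length; i++) {
            out[i] = Math.exp(logits[i] - max);
            sum += out[i];
        }
        for (int i = 0; i < out.length; i++) {
            out[i] /= sum;
        }
        return out;
    }

    static double cosine(double[] a, double[] b) {
        double dot = 0, normA = 0, normB = 0;
        for (int i = 0; i < a.length; i++) {
            dot += a[i] * b[i];
            normA += a[i] * a[i];
            normB += b[i] * b[i];
        }
        return dot / (Math.sqrt(normA) * Math.sqrt(normB));
    }

    static int[] contrastiveSearch(
            Function<int[], float[]> model,     // hypothetical: token ids so far -> next-token logits
            Function<Integer, double[]> hidden, // hypothetical: token id -> representation vector
            int[] inputIds,                     // must be non-empty
            int k,                              // number of top-probability candidates per step
            double alpha,                       // degeneration penalty weight in [0, 1]
            int maxNewTokens) {
        int[] tokens = inputIds.clone();
        for (int step = 0; step < maxNewTokens; step++) {
            double[] probs = softmax(model.apply(tokens));
            // Sort token ids by descending probability to find the top-k candidates.
            Integer[] order = new Integer[probs.length];
            for (int i = 0; i < order.length; i++) {
                order[i] = i;
            }
            Arrays.sort(order, (x, y) -> Double.compare(probs[y], probs[x]));

            int best = order[0];
            double bestScore = Double.NEGATIVE_INFINITY;
            for (int c = 0; c < Math.min(k, order.length); c++) {
                int candidate = order[c];
                double[] h = hidden.apply(candidate);
                // Degeneration penalty: highest similarity to any token already in the context.
                double maxSim = -1.0; // lowest possible cosine similarity
                for (int prev : tokens) {
                    maxSim = Math.max(maxSim, cosine(h, hidden.apply(prev)));
                }
                double score = (1 - alpha) * probs[candidate] - alpha * maxSim;
                if (score > bestScore) {
                    bestScore = score;
                    best = candidate;
                }
            }
            tokens = Arrays.copyOf(tokens, tokens.length + 1);
            tokens[tokens.length - 1] = best;
        }
        return tokens;
    }
}
```

With `alpha = 0` the score is just the model probability, so the sketch behaves like greedy search restricted to the top `k`; larger `alpha` pushes it away from tokens that repeat what is already in the context.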
forward
public NDArray forward(NDArray inputIds) throws TranslateException
Forward function call to generate text.
Parameters:
inputIds - the input token ids
Returns:
the generated token ids
Throws:
TranslateException - if prediction fails
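The constructor fixes a `searchName`, and `forward` is the generic entry point, which suggests that `forward` runs whichever search that name selects. The sketch below illustrates this dispatch pattern under assumptions: the accepted names (`"greedy"`, `"beam"`, `"contrastive"`) and the rejection of unknown names with an `IllegalArgumentException` are assumed here, and the three search functions are hypothetical placeholders rather than DJL's methods.

```java
import java.util.function.UnaryOperator;

/** Hypothetical sketch: a search name chosen at construction selects the search forward() runs. */
class SearchDispatchSketch {
    private final String searchName;
    private final UnaryOperator<int[]> greedy;      // placeholder for a greedy search
    private final UnaryOperator<int[]> beam;        // placeholder for a beam search
    private final UnaryOperator<int[]> contrastive; // placeholder for a contrastive search

    SearchDispatchSketch(
            String searchName,
            UnaryOperator<int[]> greedy,
            UnaryOperator<int[]> beam,
            UnaryOperator<int[]> contrastive) {
        this.searchName = searchName;
        this.greedy = greedy;
        this.beam = beam;
        this.contrastive = contrastive;
    }

    /** Runs the search selected by the name given at construction time. */
    int[] forward(int[] inputIds) {
        switch (searchName) {
            case "greedy":
                return greedy.apply(inputIds);
            case "beam":
                return beam.apply(inputIds);
            case "contrastive":
                return contrastive.apply(inputIds);
            default:
                throw new IllegalArgumentException(
                        "Unknown search name: " + searchName
                                + " (expected greedy, beam, or contrastive)");
        }
    }
}
```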