Package ai.djl.modality.nlp.generate
Class StepGeneration
java.lang.Object
ai.djl.modality.nlp.generate.StepGeneration
StepGeneration is a utility class containing the per-step token selection functions used
in autoregressive search: greedy search, beam search, and contrastive search.
Method Summary
Modifier and Type | Method | Description
static NDList | beamStepGeneration(NDArray lastProbs, NDArray logits, long numBatch, long numBeam) | Generates the output token ids and selection indices used in beam search.
static NDList | constrastiveStepGeneration(NDArray topKIds, NDArray logits, NDArray contextHiddenStates, NDArray topkHiddenStates, NDArray offSets, float alpha) | Generates the output token ids and selection indices used in contrastive search.
static NDArray | greedyStepGen(NDArray logits) | Generates the output token ids for greedy search.
Method Details
constrastiveStepGeneration
public static NDList constrastiveStepGeneration(NDArray topKIds, NDArray logits, NDArray contextHiddenStates, NDArray topkHiddenStates, NDArray offSets, float alpha)
Generates the output token ids and selection indices used in contrastive search.
Parameters:
topKIds - the top-k candidate token ids
logits - the logits from the language model
contextHiddenStates - the embeddings of the previously generated token ids
topkHiddenStates - the embeddings of the top-k candidate token ids
offSets - the offsets
alpha - the repetition penalty
Returns:
the output token ids and selection indices
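The selection rule behind contrastive search can be sketched in plain Java. This is a hypothetical illustration, not the DJL implementation: it uses arrays instead of NDArray, handles a single sequence, omits offSets, and assumes alpha weights the usual contrastive-search score, (1 - alpha) * p(v) - alpha * (max cosine similarity between the candidate's hidden state and the context hidden states). The class and method names (ContrastiveStepSketch, contrastiveSelect) are made up for this example.

```java
/** Hypothetical plain-Java sketch of one contrastive-search step (not the DJL implementation). */
public class ContrastiveStepSketch {

    /** Cosine similarity between two hidden-state vectors; the epsilon avoids division by zero. */
    static double cosine(float[] a, float[] b) {
        double dot = 0, na = 0, nb = 0;
        for (int i = 0; i < a.length; i++) {
            dot += a[i] * b[i];
            na += a[i] * a[i];
            nb += b[i] * b[i];
        }
        return dot / (Math.sqrt(na) * Math.sqrt(nb) + 1e-12);
    }

    /** Numerically stable softmax over the vocabulary logits. */
    static double[] softmax(float[] logits) {
        double max = Double.NEGATIVE_INFINITY;
        for (float x : logits) max = Math.max(max, x);
        double sum = 0;
        double[] p = new double[logits.length];
        for (int i = 0; i < logits.length; i++) {
            p[i] = Math.exp(logits[i] - max);
            sum += p[i];
        }
        for (int i = 0; i < p.length; i++) p[i] /= sum;
        return p;
    }

    /**
     * Scores each top-k candidate by model confidence minus a degeneration penalty.
     *
     * @param topKIds       candidate token ids (assumed to index into logits)
     * @param logits        last-position logits over the vocabulary
     * @param contextHidden hidden states of the tokens generated so far, [seq][hidden]
     * @param topKHidden    hidden state of each candidate, [topK][hidden]
     * @param alpha         weight of the similarity penalty
     * @return the index into topKIds of the selected candidate
     */
    static int contrastiveSelect(long[] topKIds, float[] logits, float[][] contextHidden,
                                 float[][] topKHidden, float alpha) {
        double[] probs = softmax(logits);
        int best = 0;
        double bestScore = Double.NEGATIVE_INFINITY;
        for (int k = 0; k < topKIds.length; k++) {
            // Largest similarity between this candidate and any previous token.
            double maxSim = Double.NEGATIVE_INFINITY;
            for (float[] h : contextHidden) {
                maxSim = Math.max(maxSim, cosine(topKHidden[k], h));
            }
            double score = (1 - alpha) * probs[(int) topKIds[k]] - alpha * maxSim;
            if (score > bestScore) {
                bestScore = score;
                best = k;
            }
        }
        return best;
    }

    public static void main(String[] args) {
        long[] topKIds = {0, 1};
        float[] logits = {2.0f, 1.9f, 0f, 0f};      // token 0 is slightly more likely
        float[][] context = {{1f, 0f}};
        float[][] candHidden = {{1f, 0f}, {0f, 1f}}; // token 0 repeats the context direction
        int k = contrastiveSelect(topKIds, logits, context, candHidden, 0.6f);
        System.out.println("selected token id: " + topKIds[k]);
    }
}
```

In this toy input, greedy selection would pick token 0, but its hidden state matches the context exactly, so the penalty makes token 1 win; with alpha set to 0 the rule reduces to picking the most probable candidate.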
greedyStepGen
public static NDArray greedyStepGen(NDArray logits)
Generates the output token ids for greedy search.
Parameters:
logits - the logits from the language model
Returns:
the output token ids
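Greedy search simply takes the highest-scoring token at each step. The following is a hypothetical plain-Java sketch of that rule, not the DJL code: it assumes the logits have already been reduced to the last sequence position, shaped [batch][vocab], and the names GreedyStepSketch and greedyStep are invented for illustration.

```java
import java.util.Arrays;

/** Hypothetical plain-Java sketch of one greedy decoding step (not the DJL implementation). */
public class GreedyStepSketch {

    /**
     * Picks the highest-scoring token for each batch row.
     *
     * @param logits logits for the last position, shaped [batch][vocab] (assumption)
     * @return the selected token id for each batch row
     */
    static long[] greedyStep(float[][] logits) {
        long[] ids = new long[logits.length];
        for (int b = 0; b < logits.length; b++) {
            int best = 0;
            for (int v = 1; v < logits[b].length; v++) {
                if (logits[b][v] > logits[b][best]) {
                    best = v; // strictly greater keeps the first maximum on ties
                }
            }
            ids[b] = best;
        }
        return ids;
    }

    public static void main(String[] args) {
        float[][] logits = {
            {0.1f, 2.5f, -1.0f, 0.3f},
            {3.0f, 0.0f, 3.0f, 1.0f}
        };
        System.out.println(Arrays.toString(greedyStep(logits))); // [1, 0]
    }
}
```

Taking the argmax of the raw logits gives the same answer as taking it after a softmax, since softmax is monotonic, so no normalization is needed for this step.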
beamStepGeneration
public static NDList beamStepGeneration(NDArray lastProbs, NDArray logits, long numBatch, long numBeam)
Generates the output token ids and selection indices used in beam search.
Parameters:
lastProbs - the probabilities of the past prefix sequences
logits - the logits from the language model
numBatch - the batch size
numBeam - the number of beams
Returns:
the output token ids and selection indices
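One beam-search step extends every prefix with every token, scores the combinations, and keeps the best numBeam per batch entry; the selection indices record which prefix each survivor extends. The sketch below is a hypothetical plain-Java illustration, not the DJL implementation: it assumes cumulative log-probabilities shaped [batch * beam], last-position logits shaped [batch * beam][vocab], and a single top-k over all numBeam * vocab candidates. BeamStepSketch, beamStep, and Step are invented names.

```java
import java.util.Arrays;

/** Hypothetical plain-Java sketch of one beam-search step (not the DJL implementation). */
public class BeamStepSketch {

    /** Result of one step; every array is shaped [batch][beam]. */
    static final class Step {
        final long[][] tokenIds;   // token appended to each surviving beam
        final int[][] sourceBeams; // which previous beam each survivor extends
        final double[][] scores;   // updated cumulative log-probabilities

        Step(long[][] tokenIds, int[][] sourceBeams, double[][] scores) {
            this.tokenIds = tokenIds;
            this.sourceBeams = sourceBeams;
            this.scores = scores;
        }
    }

    /** Numerically stable log-softmax of one row of logits. */
    static double[] logSoftmax(float[] row) {
        double max = Double.NEGATIVE_INFINITY;
        for (float x : row) max = Math.max(max, x);
        double sum = 0;
        for (float x : row) sum += Math.exp(x - max);
        double logZ = max + Math.log(sum);
        double[] out = new double[row.length];
        for (int i = 0; i < row.length; i++) out[i] = row[i] - logZ;
        return out;
    }

    static Step beamStep(double[] lastLogProbs, float[][] logits, int numBatch, int numBeam) {
        int vocab = logits[0].length;
        long[][] ids = new long[numBatch][numBeam];
        int[][] src = new int[numBatch][numBeam];
        double[][] scores = new double[numBatch][numBeam];
        for (int b = 0; b < numBatch; b++) {
            // Score every (source beam, token) pair for this batch entry.
            double[] cand = new double[numBeam * vocab];
            for (int k = 0; k < numBeam; k++) {
                int row = b * numBeam + k;
                double[] lp = logSoftmax(logits[row]);
                for (int v = 0; v < vocab; v++) {
                    cand[k * vocab + v] = lastLogProbs[row] + lp[v];
                }
            }
            // Keep the numBeam best candidates (full sort; fine for small vocabularies).
            Integer[] order = new Integer[cand.length];
            for (int i = 0; i < order.length; i++) order[i] = i;
            Arrays.sort(order, (x, y) -> Double.compare(cand[y], cand[x]));
            for (int j = 0; j < numBeam; j++) {
                int flat = order[j];
                src[b][j] = flat / vocab; // prefix this candidate extends
                ids[b][j] = flat % vocab; // token it appends
                scores[b][j] = cand[flat];
            }
        }
        return new Step(ids, src, scores);
    }

    public static void main(String[] args) {
        double[] last = {0.0, -1.0};                      // two prefixes, one batch entry
        float[][] logits = {{2f, 1f, 0f}, {0f, 3f, 0f}}; // vocabulary of 3 tokens
        Step s = beamStep(last, logits, 1, 2);
        System.out.println(Arrays.deepToString(s.tokenIds));    // [[0, 1]]
        System.out.println(Arrays.deepToString(s.sourceBeams)); // [[0, 1]]
    }
}
```

Working in log space turns the product of step probabilities into a sum, which avoids underflow on long sequences; the caller would then reorder its cached prefixes with sourceBeams before the next step.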