Package ai.djl.modality.nlp.generate
Class StepGeneration
java.lang.Object
  ai.djl.modality.nlp.generate.StepGeneration
StepGeneration is a utility class containing the step-generation functions used in autoregressive search.
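To show where a step function fits, here is a minimal, framework-free sketch of an autoregressive loop. It does not use DJL: the toy model, the `StepFunction` interface and every other name in it are hypothetical. At each iteration the model produces next-token logits, a step function (here a greedy argmax) picks the next token id, and that id is appended to the sequence until an end-of-sequence id or a length limit is reached.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.function.Function;

/** Hypothetical, framework-free sketch of an autoregressive decoding loop. */
final class AutoregressiveLoopSketch {

    /** Maps the next-token logits of one sequence to the chosen token id. */
    interface StepFunction {
        int nextToken(double[] logits);
    }

    /** Greedy step: returns the index of the largest logit. */
    static int argMax(double[] logits) {
        int best = 0;
        for (int i = 1; i < logits.length; i++) {
            if (logits[i] > logits[best]) {
                best = i;
            }
        }
        return best;
    }

    /**
     * Queries the model, applies the step function and appends the chosen id,
     * stopping after eosId is generated or once maxLength tokens are reached.
     */
    static List<Integer> generate(
            List<Integer> prompt,
            Function<List<Integer>, double[]> model,
            StepFunction step,
            int eosId,
            int maxLength) {
        List<Integer> tokens = new ArrayList<>(prompt);
        while (tokens.size() < maxLength) {
            double[] logits = model.apply(tokens); // scores for the next token
            int next = step.nextToken(logits);
            tokens.add(next);
            if (next == eosId) {
                break;
            }
        }
        return tokens;
    }
}
```

Swapping `AutoregressiveLoopSketch::argMax` for a different step function changes the search strategy without touching the loop, which is the separation the step-generation functions in this class provide.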
Method Summary
| Modifier and Type | Method | Description |
|---|---|---|
| static NDList | beamStepGeneration(NDArray lastProbs, NDArray logits, long numBatch, long numBeam) | Generates the output token ids and selection indices used in beam search. |
| static NDList | constrastiveStepGeneration(NDArray topKIds, NDArray logits, NDArray contextHiddenStates, NDArray topkHiddenStates, NDArray offSets, float alpha) | Generates the output token ids and selection indices used in contrastive search. |
| static NDArray | greedyStepGen(NDArray logits) | Generates the output token ids for greedy search. |
Method Details
constrastiveStepGeneration
public static NDList constrastiveStepGeneration(NDArray topKIds, NDArray logits, NDArray contextHiddenStates, NDArray topkHiddenStates, NDArray offSets, float alpha)
Generates the output token ids and selection indices used in contrastive search.
Parameters:
topKIds - the top-k candidate token ids
logits - the logits from the language model
contextHiddenStates - the embeddings of the previously generated token ids
topkHiddenStates - the embeddings of the top-k candidate token ids
offSets - the offsets
alpha - the repetition penalty
Returns:
the output token ids and selection indices
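The idea behind a contrastive step can be sketched on plain arrays for a single sequence. This is a hypothetical illustration, not DJL's implementation: it assumes the scoring rule introduced for contrastive search by Su et al. (2022), where each top-k candidate is scored by its model probability minus a penalty, weighted by `alpha`, for its maximum cosine similarity to the hidden states of earlier tokens. The padding offsets that the DJL method accepts through `offSets` are not modelled, and all class and method names are invented for the example.

```java
/**
 * Hypothetical sketch of one contrastive-search step for a single sequence:
 *   score(v) = (1 - alpha) * p(v) - alpha * max_j cos(h_v, h_j)
 * Padding offsets are ignored.
 */
final class ContrastiveStepSketch {

    /** Index of the chosen candidate within topKIds, and its token id. */
    record Choice(int selectIndex, int tokenId) {}

    /** Numerically stable softmax over the whole vocabulary. */
    static double[] softmax(double[] logits) {
        double max = Double.NEGATIVE_INFINITY;
        for (double l : logits) {
            max = Math.max(max, l);
        }
        double[] out = new double[logits.length];
        double sum = 0.0;
        for (int i = 0; i < logits.length; i++) {
            out[i] = Math.exp(logits[i] - max);
            sum += out[i];
        }
        for (int i = 0; i < out.length; i++) {
            out[i] /= sum;
        }
        return out;
    }

    static double cosine(double[] a, double[] b) {
        double dot = 0.0, na = 0.0, nb = 0.0;
        for (int i = 0; i < a.length; i++) {
            dot += a[i] * b[i];
            na += a[i] * a[i];
            nb += b[i] * b[i];
        }
        return (na == 0.0 || nb == 0.0) ? 0.0 : dot / Math.sqrt(na * nb);
    }

    /** Degeneration penalty: highest similarity to any earlier hidden state (0 if none). */
    static double maxSimilarity(double[] candidate, double[][] context) {
        double max = context.length == 0 ? 0.0 : Double.NEGATIVE_INFINITY;
        for (double[] h : context) {
            max = Math.max(max, cosine(candidate, h));
        }
        return max;
    }

    static Choice step(int[] topKIds, double[] logits, double[][] contextHiddenStates,
            double[][] topKHiddenStates, double alpha) {
        double[] probs = softmax(logits); // model confidence over the vocabulary
        int best = 0;
        double bestScore = Double.NEGATIVE_INFINITY;
        for (int k = 0; k < topKIds.length; k++) {
            double score = (1.0 - alpha) * probs[topKIds[k]]
                    - alpha * maxSimilarity(topKHiddenStates[k], contextHiddenStates);
            if (score > bestScore) {
                bestScore = score;
                best = k;
            }
        }
        return new Choice(best, topKIds[best]);
    }
}
```

With `alpha = 0` the rule reduces to picking the most probable top-k candidate; larger values increasingly favour candidates whose representations differ from the existing context, which is how contrastive search discourages repetition.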
greedyStepGen
public static NDArray greedyStepGen(NDArray logits)
Generates the output token ids for greedy search.
Parameters:
logits - the logits from the language model
Returns:
the output token ids
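A greedy step simply takes the highest-scoring token at the last position of each sequence. The sketch below works on plain arrays rather than `NDArray`, and its shapes are assumptions: logits laid out as `[batch][seqLength][vocabSize]`, and token ids returned as `[batch][1]` so they can be appended to a `[batch][seqLength]` id sequence. The class and method names are hypothetical.

```java
/**
 * Hypothetical sketch of a greedy step on plain arrays.
 * Assumes logits shaped [batch][seqLength][vocabSize]; only the last position is used.
 */
final class GreedyStepSketch {

    /** Returns the argmax token id of the last position, shaped [batch][1]. */
    static long[][] greedyStep(double[][][] logits) {
        long[][] ids = new long[logits.length][1];
        for (int b = 0; b < logits.length; b++) {
            double[] last = logits[b][logits[b].length - 1]; // next-token scores
            int best = 0;
            for (int v = 1; v < last.length; v++) {
                if (last[v] > last[best]) {
                    best = v;
                }
            }
            ids[b][0] = best;
        }
        return ids;
    }
}
```

Earlier positions are ignored because, in autoregressive decoding, only the final position's logits predict the next token.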
beamStepGeneration
public static NDList beamStepGeneration(NDArray lastProbs, NDArray logits, long numBatch, long numBeam)
Generates the output token ids and selection indices used in beam search.
Parameters:
lastProbs - the probabilities of the previous prefix sequences
logits - the logits from the language model
numBatch - the batch size
numBeam - the number of beams
Returns:
the output token ids and selection indices
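One beam-search step can be sketched on plain arrays as follows. This is a hypothetical illustration rather than DJL's code. It assumes `lastProbs[b][k]` is the cumulative probability of prefix `k` in batch item `b`, that `logits[b][k]` holds the last-position logits for that prefix, and that scores are chained by multiplying probabilities (a log-space variant would add log-probabilities instead). It uses `int` sizes instead of the `long` parameters of the real method so they can index Java arrays. All names are invented for the example.

```java
import java.util.Arrays;
import java.util.Comparator;

/**
 * Hypothetical sketch of one beam-search step on plain arrays.
 * lastProbs: [numBatch][numBeam], logits: [numBatch][numBeam][vocabSize].
 */
final class BeamStepSketch {

    /** Per batch item: new token ids, source-beam (selection) indices, new prefix probabilities. */
    record BeamStep(int[][] tokenIds, int[][] sourceBeams, double[][] probs) {}

    static double[] softmax(double[] logits) {
        double max = Arrays.stream(logits).max().orElse(0.0);
        double[] out = new double[logits.length];
        double sum = 0.0;
        for (int i = 0; i < logits.length; i++) {
            out[i] = Math.exp(logits[i] - max);
            sum += out[i];
        }
        for (int i = 0; i < out.length; i++) {
            out[i] /= sum;
        }
        return out;
    }

    static BeamStep step(double[][] lastProbs, double[][][] logits, int numBatch, int numBeam) {
        int[][] tokenIds = new int[numBatch][numBeam];
        int[][] sourceBeams = new int[numBatch][numBeam];
        double[][] probs = new double[numBatch][numBeam];
        for (int b = 0; b < numBatch; b++) {
            int vocab = logits[b][0].length;
            // Score every (source beam, token) pair: numBeam * vocab candidates.
            double[] scores = new double[numBeam * vocab];
            for (int k = 0; k < numBeam; k++) {
                double[] p = softmax(logits[b][k]);
                for (int v = 0; v < vocab; v++) {
                    scores[k * vocab + v] = lastProbs[b][k] * p[v];
                }
            }
            // Keep the numBeam best candidates (stable sort, highest score first).
            Integer[] order = new Integer[scores.length];
            for (int i = 0; i < order.length; i++) {
                order[i] = i;
            }
            Arrays.sort(order, Comparator.comparingDouble((Integer i) -> scores[i]).reversed());
            for (int j = 0; j < numBeam; j++) {
                int c = order[j];
                sourceBeams[b][j] = c / vocab; // which prefix this token extends
                tokenIds[b][j] = c % vocab;
                probs[b][j] = scores[c];
            }
        }
        return new BeamStep(tokenIds, sourceBeams, probs);
    }
}
```

Because the selection indices record which prefix each new token extends, a caller would reorder its past token ids (and any cached model state) by `sourceBeams` before appending `tokenIds`. On the first step all beams usually hold the same prefix; initialising `lastProbs` to `{1, 0, ..., 0}` for each batch item keeps the sketch from selecting duplicate beams.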