Package ai.djl.modality.nlp.generate
Class StepGeneration
- java.lang.Object
  - ai.djl.modality.nlp.generate.StepGeneration
public final class StepGeneration extends java.lang.Object
StepGeneration is a utility class containing the step generation utility functions used in autoregressive search.
Method Summary
- static NDList beamStepGeneration(NDArray lastProbs, NDArray logits, long numBatch, long numBeam): Generates the output token id and selecting indices used in beam search.
- static NDList constrastiveStepGeneration(NDArray topKIds, NDArray logits, NDArray contextHiddenStates, NDArray topkHiddenStates, NDArray offSets, float alpha): Generates the output token id and selecting indices used in contrastive search.
- static NDArray greedyStepGen(NDArray logits): Generates the output token id for greedy search.
Method Detail
-
constrastiveStepGeneration
public static NDList constrastiveStepGeneration(NDArray topKIds, NDArray logits, NDArray contextHiddenStates, NDArray topkHiddenStates, NDArray offSets, float alpha)
Generates the output token id and selecting indices used in contrastive search.
Parameters:
- topKIds - the top-k candidate token ids
- logits - the logits from the language model
- contextHiddenStates - the embedding of the past generated token ids
- topkHiddenStates - the embedding of the top-k candidate token ids
- offSets - the offsets
- alpha - the repetition penalty
Returns:
- the output token ids and selecting indices
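To make the role of alpha and the two hidden-state arguments concrete, here is a minimal sketch of a contrastive selection step in plain Java. It is not the DJL implementation: it assumes the commonly described contrastive search score, (1 - alpha) * probability - alpha * (maximum cosine similarity to the context), works on a single batch entry with plain arrays instead of NDArray, and ignores offSets. The class and method names (ContrastiveStepSketch, select, cosine) are hypothetical.

```java
final class ContrastiveStepSketch {

    /** Cosine similarity of two non-zero vectors of equal length. */
    static double cosine(double[] a, double[] b) {
        double dot = 0.0, normA = 0.0, normB = 0.0;
        for (int i = 0; i < a.length; i++) {
            dot += a[i] * b[i];
            normA += a[i] * a[i];
            normB += b[i] * b[i];
        }
        return dot / (Math.sqrt(normA) * Math.sqrt(normB));
    }

    /**
     * Returns the index (within the top-k candidates) of the selected candidate.
     * Assumed score: (1 - alpha) * prob - alpha * max similarity to any context state.
     */
    static int select(double[] candProbs, double[][] candHidden,
                      double[][] contextHidden, double alpha) {
        int best = 0;
        double bestScore = Double.NEGATIVE_INFINITY;
        for (int k = 0; k < candProbs.length; k++) {
            double maxSim = Double.NEGATIVE_INFINITY;
            for (double[] h : contextHidden) {
                maxSim = Math.max(maxSim, cosine(candHidden[k], h));
            }
            double score = (1.0 - alpha) * candProbs[k] - alpha * maxSim;
            if (score > bestScore) {
                bestScore = score;
                best = k;
            }
        }
        return best;
    }

    public static void main(String[] args) {
        double[] probs = {0.6, 0.4};
        double[][] candHidden = {{1.0, 0.0}, {0.0, 1.0}};
        double[][] context = {{1.0, 0.0}};
        // Candidate 0 is more probable but identical to the context state.
        System.out.println(select(probs, candHidden, context, 0.5)); // prints 1
    }
}
```

With alpha set to 0 the penalty term vanishes and the sketch reduces to picking the most probable candidate, which is why alpha is described as a repetition penalty.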
-
greedyStepGen
public static NDArray greedyStepGen(NDArray logits)
Generates the output token id for greedy search.
Parameters:
- logits - the logits from the language model
Returns:
- the output token ids
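As an illustration of what a greedy step computes, the following plain-Java sketch picks the highest-scoring token for each batch row. It is a simplified stand-in, not the DJL code: it assumes the logits have already been reduced to the last sequence position and laid out as [batch][vocab], and it uses float arrays rather than NDArray. The names GreedyStepSketch and greedyStep are hypothetical.

```java
import java.util.Arrays;

final class GreedyStepSketch {

    /** Returns, for each batch row, the index of the largest logit (argmax). */
    static long[] greedyStep(float[][] lastLogits) {
        long[] ids = new long[lastLogits.length];
        for (int b = 0; b < lastLogits.length; b++) {
            int best = 0;
            for (int v = 1; v < lastLogits[b].length; v++) {
                if (lastLogits[b][v] > lastLogits[b][best]) {
                    best = v;
                }
            }
            ids[b] = best;
        }
        return ids;
    }

    public static void main(String[] args) {
        float[][] logits = {{0.1f, 2.5f, -1.0f}, {3.0f, 0.2f, 0.7f}};
        System.out.println(Arrays.toString(greedyStep(logits))); // prints [1, 0]
    }
}
```

Because softmax is monotonic, taking the argmax of the raw logits selects the same token as taking the argmax of the probabilities, so no normalization is needed for this step.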
-
beamStepGeneration
public static NDList beamStepGeneration(NDArray lastProbs, NDArray logits, long numBatch, long numBeam)
Generates the output token id and selecting indices used in beam search.
Parameters:
- lastProbs - the probabilities of the past prefix sequences
- logits - the logits
- numBatch - the number of batches
- numBeam - the number of beams
Returns:
- the output token ids and selecting indices
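The relationship between the output token ids and the selecting indices can be sketched in plain Java as follows. This is an illustration under assumptions, not the DJL implementation: it handles a single batch entry, works in probability space (the prefix probability times the softmax of the logits), and flattens the beam-by-vocabulary grid before taking the top numBeam entries. The names BeamStepSketch, softmax, and beamStep are hypothetical.

```java
import java.util.Arrays;
import java.util.Comparator;

final class BeamStepSketch {

    /** Numerically stable softmax over one logit vector. */
    static double[] softmax(double[] logits) {
        double max = Arrays.stream(logits).max().orElse(0.0);
        double[] out = new double[logits.length];
        double sum = 0.0;
        for (int i = 0; i < logits.length; i++) {
            out[i] = Math.exp(logits[i] - max);
            sum += out[i];
        }
        for (int i = 0; i < out.length; i++) {
            out[i] /= sum;
        }
        return out;
    }

    /**
     * One beam step for a single batch entry.
     * lastProbs has one entry per beam; logits is [beam][vocab].
     * Returns numBeam rows of {tokenId, selectedBeamIndex}, best first.
     */
    static int[][] beamStep(double[] lastProbs, double[][] logits) {
        int numBeam = lastProbs.length;
        int vocab = logits[0].length;
        double[] joint = new double[numBeam * vocab];
        for (int k = 0; k < numBeam; k++) {
            double[] p = softmax(logits[k]);
            for (int v = 0; v < vocab; v++) {
                joint[k * vocab + v] = lastProbs[k] * p[v];
            }
        }
        // Sort flat indices by descending joint probability.
        Integer[] order = new Integer[joint.length];
        for (int i = 0; i < order.length; i++) {
            order[i] = i;
        }
        Arrays.sort(order, Comparator.comparingDouble((Integer i) -> -joint[i]));
        int[][] result = new int[numBeam][2];
        for (int k = 0; k < numBeam; k++) {
            result[k][0] = order[k] % vocab; // token id
            result[k][1] = order[k] / vocab; // beam the token extends
        }
        return result;
    }

    public static void main(String[] args) {
        double[][] logits = {
            {Math.log(0.5), Math.log(0.3), Math.log(0.2)},
            {Math.log(0.1), Math.log(0.8), Math.log(0.1)}
        };
        int[][] picked = beamStep(new double[] {0.6, 0.4}, logits);
        System.out.println(Arrays.deepToString(picked)); // prints [[1, 1], [0, 0]]
    }
}
```

The selecting index tells the caller which previous beam each new token extends, so cached state for that beam can be reused in the next step.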