Package ai.djl.modality.nlp.generate
Class StepGeneration
- java.lang.Object
  - ai.djl.modality.nlp.generate.StepGeneration
public final class StepGeneration extends java.lang.Object
StepGeneration is a utility class containing the step generation functions used in autoregressive search.
Method Summary
- static NDList beamStepGeneration(NDArray lastProbs, NDArray logits, long numBatch, long numBeam)
  Generates the output token ids and selecting indices used in beam search.
- static NDList constrastiveStepGeneration(NDArray topKIds, NDArray logits, NDArray contextHiddenStates, NDArray topkHiddenStates, NDArray offSets, float alpha)
  Generates the output token ids and selecting indices used in contrastive search.
- static NDArray greedyStepGen(NDArray logits)
  Generates the output token ids for greedy search.
Method Detail
constrastiveStepGeneration
public static NDList constrastiveStepGeneration(NDArray topKIds, NDArray logits, NDArray contextHiddenStates, NDArray topkHiddenStates, NDArray offSets, float alpha)
Generates the output token ids and selecting indices used in contrastive search.
Parameters:
- topKIds - the top-k candidate token ids
- logits - the logits from the language model
- contextHiddenStates - the embeddings of the previously generated token ids
- topkHiddenStates - the embeddings of the top-k candidate token ids
- offSets - the offsets
- alpha - the repetition penalty
Returns:
- the output token ids and selecting indices
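The scoring idea behind contrastive search can be sketched in plain Java for a single sequence. This sketch assumes the commonly used contrastive search rule, score = (1 - alpha) * model probability - alpha * (maximum cosine similarity between a candidate's hidden state and any context hidden state), and picks the highest-scoring top-k candidate. Whether constrastiveStepGeneration computes exactly this is an assumption; ContrastiveStepSketch and select are hypothetical names, the arrays stand in for NDArrays, and the handling of offSets is omitted.

```java
/** Hypothetical plain-Java illustration of one contrastive-search step (not DJL's API). */
final class ContrastiveStepSketch {

    private ContrastiveStepSketch() {}

    /** Numerically stable softmax over one row of logits. */
    static double[] softmax(float[] logits) {
        double max = Double.NEGATIVE_INFINITY;
        for (float x : logits) {
            max = Math.max(max, x);
        }
        double sum = 0.0;
        double[] out = new double[logits.length];
        for (int i = 0; i < logits.length; i++) {
            out[i] = Math.exp(logits[i] - max);
            sum += out[i];
        }
        for (int i = 0; i < out.length; i++) {
            out[i] /= sum;
        }
        return out;
    }

    /** Cosine similarity; the small epsilon guards against zero vectors. */
    static double cosine(float[] a, float[] b) {
        double dot = 0.0, normA = 0.0, normB = 0.0;
        for (int i = 0; i < a.length; i++) {
            dot += (double) a[i] * b[i];
            normA += (double) a[i] * a[i];
            normB += (double) b[i] * b[i];
        }
        return dot / (Math.sqrt(normA) * Math.sqrt(normB) + 1e-12);
    }

    /**
     * Returns the index (into topKIds) of the selected candidate.
     * Assumes contextHidden has at least one row; offsets/padding are ignored.
     */
    static int select(long[] topKIds, float[] logits, float[][] contextHidden,
                      float[][] topKHidden, float alpha) {
        double[] probs = softmax(logits);
        int best = 0;
        double bestScore = Double.NEGATIVE_INFINITY;
        for (int k = 0; k < topKIds.length; k++) {
            // Degeneration penalty: highest similarity to any earlier hidden state.
            double maxSim = Double.NEGATIVE_INFINITY;
            for (float[] h : contextHidden) {
                maxSim = Math.max(maxSim, cosine(topKHidden[k], h));
            }
            // Trade model confidence against repetition of the context.
            double score = (1 - alpha) * probs[(int) topKIds[k]] - alpha * maxSim;
            if (score > bestScore) {
                bestScore = score;
                best = k;
            }
        }
        return best;
    }

    public static void main(String[] args) {
        long[] topKIds = {1, 2};
        float[] logits = {0f, 3f, 2f, 0f};
        float[][] context = {{1f, 0f}};
        float[][] candidates = {{1f, 0f}, {0f, 1f}};
        int idx = select(topKIds, logits, context, candidates, 0.6f);
        System.out.println(idx + " -> token " + topKIds[idx]); // prints 1 -> token 2
    }
}
```

In the example, token 1 has the higher probability, but its hidden state duplicates the context, so with alpha = 0.6 the penalty selects token 2 instead. With alpha = 0 the rule reduces to greedy selection among the top-k candidates.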
greedyStepGen
public static NDArray greedyStepGen(NDArray logits)
Generates the output token ids for greedy search.
Parameters:
- logits - the logits from the language model
Returns:
- the output token ids
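Greedy search takes the highest-scoring token at each step. A minimal plain-Java sketch, assuming the logits for the final position have already been extracted into a [batch][vocab] array; GreedyStepSketch and greedyStep are hypothetical names, not part of DJL.

```java
import java.util.Arrays;

/** Hypothetical plain-Java illustration of greedy step generation (not DJL's API). */
final class GreedyStepSketch {

    private GreedyStepSketch() {}

    /**
     * Picks the highest-scoring token for each sequence.
     *
     * @param logits last-position logits, assumed shape [batch][vocab]
     * @return one token id per batch row
     */
    static long[] greedyStep(float[][] logits) {
        long[] tokenIds = new long[logits.length];
        for (int b = 0; b < logits.length; b++) {
            int best = 0;
            for (int v = 1; v < logits[b].length; v++) {
                // Strict '>' keeps the lowest index when two logits tie.
                if (logits[b][v] > logits[b][best]) {
                    best = v;
                }
            }
            tokenIds[b] = best;
        }
        return tokenIds;
    }

    public static void main(String[] args) {
        float[][] logits = {
            {0.1f, 2.5f, -1.0f},
            {3.0f, 0.0f, 3.0f}
        };
        System.out.println(Arrays.toString(greedyStep(logits))); // prints [1, 0]
    }
}
```

No softmax is needed here: softmax is monotonic, so the arg-max of the logits equals the arg-max of the probabilities.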
beamStepGeneration
public static NDList beamStepGeneration(NDArray lastProbs, NDArray logits, long numBatch, long numBeam)
Generates the output token ids and selecting indices used in beam search.
Parameters:
- lastProbs - the probabilities of the past prefix sequences
- logits - the logits from the language model
- numBatch - the batch size
- numBeam - the number of beams
Returns:
- the output token ids and selecting indices
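One beam-search step can be sketched in plain Java, assuming lastProbs holds the per-beam prefix probabilities (as the parameter description states) and that the final-position logits are given as a [numBatch][numBeam][vocab] array. Each candidate's joint probability is the prefix probability times the softmax of its logits; the top numBeam entries over the flattened beam x vocabulary axis give the new token ids (index % vocab) and the source beams they extend (index / vocab). BeamStepSketch, beamStep, and Step are hypothetical names, and the exact shapes and return order of beamStepGeneration are not confirmed here.

```java
import java.util.Arrays;
import java.util.Comparator;
import java.util.stream.IntStream;

/** Hypothetical plain-Java illustration of one beam-search step (not DJL's API). */
final class BeamStepSketch {

    private BeamStepSketch() {}

    /** Holder for the per-batch results; each inner array has numBeam entries. */
    static final class Step {
        final long[][] tokenIds;    // new token appended to each surviving beam
        final long[][] sourceBeams; // which previous beam each survivor extends
        final double[][] probs;     // updated prefix probabilities

        Step(long[][] tokenIds, long[][] sourceBeams, double[][] probs) {
            this.tokenIds = tokenIds;
            this.sourceBeams = sourceBeams;
            this.probs = probs;
        }
    }

    /** Numerically stable softmax over one row of logits. */
    static double[] softmax(float[] logits) {
        double max = Double.NEGATIVE_INFINITY;
        for (float x : logits) {
            max = Math.max(max, x);
        }
        double sum = 0.0;
        double[] out = new double[logits.length];
        for (int i = 0; i < logits.length; i++) {
            out[i] = Math.exp(logits[i] - max);
            sum += out[i];
        }
        for (int i = 0; i < out.length; i++) {
            out[i] /= sum;
        }
        return out;
    }

    static Step beamStep(double[][] lastProbs, float[][][] logits, int numBatch, int numBeam) {
        int vocab = logits[0][0].length;
        long[][] tokenIds = new long[numBatch][numBeam];
        long[][] sourceBeams = new long[numBatch][numBeam];
        double[][] probs = new double[numBatch][numBeam];
        for (int b = 0; b < numBatch; b++) {
            // Joint probability of each (source beam, next token) pair, flattened.
            double[] joint = new double[numBeam * vocab];
            for (int k = 0; k < numBeam; k++) {
                double[] p = softmax(logits[b][k]);
                for (int v = 0; v < vocab; v++) {
                    joint[k * vocab + v] = lastProbs[b][k] * p[v];
                }
            }
            // Keep the numBeam largest entries (stable sort: ties keep the lower index).
            int[] top = IntStream.range(0, joint.length)
                    .boxed()
                    .sorted(Comparator.comparingDouble((Integer i) -> joint[i]).reversed())
                    .limit(numBeam)
                    .mapToInt(Integer::intValue)
                    .toArray();
            for (int j = 0; j < numBeam; j++) {
                tokenIds[b][j] = top[j] % vocab;
                sourceBeams[b][j] = top[j] / vocab;
                probs[b][j] = joint[top[j]];
            }
        }
        return new Step(tokenIds, sourceBeams, probs);
    }

    public static void main(String[] args) {
        double[][] lastProbs = {{0.5, 0.5}};
        float[][][] logits = {{{2f, 0f, 0f}, {0f, 0f, 5f}}};
        Step step = beamStep(lastProbs, logits, 1, 2);
        System.out.println(Arrays.deepToString(step.tokenIds));    // prints [[2, 0]]
        System.out.println(Arrays.deepToString(step.sourceBeams)); // prints [[1, 0]]
    }
}
```

Multiplying raw probabilities keeps the sketch close to the lastProbs description, but products of many probabilities can underflow on long sequences; accumulating log-probabilities and adding them is a common alternative.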