Class StepGeneration


  • public final class StepGeneration
    extends java.lang.Object
    StepGeneration is a utility class containing the per-step token selection functions used in autoregressive search (contrastive, greedy, and beam search).
    • Method Detail

      • constrastiveStepGeneration

        public static NDList constrastiveStepGeneration(NDArray topKIds,
                                                        NDArray logits,
                                                        NDArray contextHiddenStates,
                                                        NDArray topkHiddenStates,
                                                        NDArray offSets,
                                                        float alpha)
        Generates the output token ids and the selecting indices used in contrastive search.
        Parameters:
        topKIds - the top-k candidate token ids
        logits - the logits from the language model
        contextHiddenStates - the hidden states (embeddings) of the previously generated tokens
        topkHiddenStates - the hidden states (embeddings) of the top-k candidate tokens
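        offSets - the per-sequence offsets of the leading padded positions in the context, which are excluded from the similarity comparison
        alpha - the weight of the degeneration (repetition) penalty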
        Returns:
        the output token ids and selecting indices
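        Contrastive search trades model confidence against a degeneration penalty: each candidate is scored as (1 - alpha) * p(candidate) - alpha * (maximum cosine similarity between the candidate's hidden state and the context hidden states), and the best-scoring candidate wins. The sketch below illustrates that scoring in plain Java over arrays for a single sequence. It is not DJL's NDArray implementation; the class ContrastiveStepSketch, its methods, and the reading of offSets as a count of leading padded positions to skip are assumptions made for illustration.

```java
/**
 * Hypothetical plain-Java sketch of one contrastive-search step for a single
 * sequence (batch size 1). Arrays stand in for NDArrays; this is not DJL code.
 */
public final class ContrastiveStepSketch {

    private ContrastiveStepSketch() {}

    /** Numerically stable softmax over a logits vector. */
    public static double[] softmax(double[] logits) {
        double max = Double.NEGATIVE_INFINITY;
        for (double l : logits) {
            max = Math.max(max, l);
        }
        double sum = 0.0;
        double[] p = new double[logits.length];
        for (int i = 0; i < logits.length; i++) {
            p[i] = Math.exp(logits[i] - max);
            sum += p[i];
        }
        for (int i = 0; i < p.length; i++) {
            p[i] /= sum;
        }
        return p;
    }

    /** Cosine similarity between two vectors of equal length. */
    public static double cosine(double[] a, double[] b) {
        double dot = 0.0, na = 0.0, nb = 0.0;
        for (int i = 0; i < a.length; i++) {
            dot += a[i] * b[i];
            na += a[i] * a[i];
            nb += b[i] * b[i];
        }
        return dot / (Math.sqrt(na) * Math.sqrt(nb));
    }

    /**
     * Scores each candidate as (1 - alpha) * p - alpha * maxSim and returns the
     * index of the best candidate within topKIds.
     *
     * @param topKIds    candidate token ids (length k)
     * @param logits     next-token logits over the vocabulary
     * @param context    hidden states of the past tokens, [pastSeq][dim]
     * @param candidates hidden state of each candidate, [k][dim]
     * @param offset     assumed: number of leading (padded) context positions to skip
     * @param alpha      weight of the degeneration penalty, in [0, 1]
     */
    public static int select(int[] topKIds, double[] logits, double[][] context,
                             double[][] candidates, int offset, double alpha) {
        double[] probs = softmax(logits);
        int best = -1;
        double bestScore = Double.NEGATIVE_INFINITY;
        for (int j = 0; j < topKIds.length; j++) {
            // Degeneration penalty: highest similarity to any non-padded context token.
            // Skipping padded positions is equivalent to masking them to -1,
            // the minimum cosine value, which is where maxSim starts.
            double maxSim = -1.0;
            for (int t = offset; t < context.length; t++) {
                maxSim = Math.max(maxSim, cosine(candidates[j], context[t]));
            }
            double score = (1.0 - alpha) * probs[topKIds[j]] - alpha * maxSim;
            if (score > bestScore) {
                bestScore = score;
                best = j;
            }
        }
        return best;
    }
}
```

        With alpha = 0 the step reduces to picking the most probable candidate. As alpha grows, a candidate whose hidden state closely repeats an earlier context token is pushed down, which is how contrastive search discourages repetitive output.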
      • greedyStepGen

        public static NDArray greedyStepGen(NDArray logits)
        Generates the output token id for greedy search.
        Parameters:
        logits - the logits from the language model
        Returns:
        the output token ids
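        Greedy search keeps, at each step, the single highest-scoring token. Because softmax is monotonic, taking the argmax of the logits gives the same token as taking the argmax of the probabilities, so no softmax is needed. The plain-Java sketch below assumes the logits have already been reduced to the last position, shaped [batch][vocab]; the class GreedyStepSketch is hypothetical and not part of DJL.

```java
/**
 * Hypothetical plain-Java sketch of a greedy step: for each sequence in the
 * batch, pick the vocabulary index with the largest last-position logit.
 */
public final class GreedyStepSketch {

    private GreedyStepSketch() {}

    /**
     * @param lastLogits logits for the last position, [batch][vocab]
     * @return the chosen token id for each sequence, [batch]
     */
    public static long[] greedyStep(float[][] lastLogits) {
        long[] ids = new long[lastLogits.length];
        for (int b = 0; b < lastLogits.length; b++) {
            int best = 0;
            for (int v = 1; v < lastLogits[b].length; v++) {
                // Strict '>' keeps the lowest index on ties.
                if (lastLogits[b][v] > lastLogits[b][best]) {
                    best = v;
                }
            }
            ids[b] = best;
        }
        return ids;
    }
}
```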
      • beamStepGeneration

        public static NDList beamStepGeneration(NDArray lastProbs,
                                                NDArray logits,
                                                long numBatch,
                                                long numBeam)
        Generates the output token id and selecting indices used in beam search.
        Parameters:
        lastProbs - the probabilities of the past prefix sequences
        logits - the logits from the language model
        numBatch - the batch size
        numBeam - the number of beams (beam width)
        Returns:
        the output token ids and selecting indices
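        One beam step chains probabilities: every (source beam s, token v) pair is scored as lastProbs[s] * p(v | beam s), and only the numBeam best pairs survive. The selecting index records which old beam each survivor extends. The sketch below shows this for a single batch item in plain Java. It is not DJL's implementation; the class BeamStepSketch, its Step result type, and the choice to pass already-softmaxed probabilities instead of logits are assumptions for illustration.

```java
import java.util.Arrays;
import java.util.Comparator;

/**
 * Hypothetical plain-Java sketch of one beam-search step for one batch item.
 * Candidate (s, v) is scored as lastProbs[s] * stepProbs[s][v].
 */
public final class BeamStepSketch {

    private BeamStepSketch() {}

    /** Chosen token ids, the source beam each one extends, and chained probabilities. */
    public static final class Step {
        public final int[] tokenIds;
        public final int[] sourceBeams;
        public final double[] probs;

        Step(int[] tokenIds, int[] sourceBeams, double[] probs) {
            this.tokenIds = tokenIds;
            this.sourceBeams = sourceBeams;
            this.probs = probs;
        }
    }

    /**
     * @param lastProbs probability of each current prefix, [numBeam]
     * @param stepProbs next-token probabilities per beam, [numBeam][vocab]
     */
    public static Step beamStep(double[] lastProbs, double[][] stepProbs) {
        int numBeam = lastProbs.length;
        int vocab = stepProbs[0].length;
        // Flattened candidate index c encodes source beam c / vocab and token c % vocab.
        double[] chained = new double[numBeam * vocab];
        Integer[] order = new Integer[numBeam * vocab];
        for (int s = 0; s < numBeam; s++) {
            for (int v = 0; v < vocab; v++) {
                int c = s * vocab + v;
                chained[c] = lastProbs[s] * stepProbs[s][v];
                order[c] = c;
            }
        }
        // Sort candidates by chained probability, highest first (stable on ties).
        Arrays.sort(order, Comparator.comparingDouble((Integer idx) -> chained[idx]).reversed());

        int[] tokenIds = new int[numBeam];
        int[] sourceBeams = new int[numBeam];
        double[] probs = new double[numBeam];
        for (int i = 0; i < numBeam; i++) {
            int c = order[i];
            tokenIds[i] = c % vocab;
            sourceBeams[i] = c / vocab; // selecting index into the previous beams
            probs[i] = chained[c];
        }
        return new Step(tokenIds, sourceBeams, probs);
    }
}
```

        With lastProbs = {0.6, 0.4} and stepProbs = {{0.5, 0.3, 0.2}, {0.9, 0.05, 0.05}}, the survivors are token 0 extending beam 1 (0.36) and token 0 extending beam 0 (0.30): the weaker prefix wins the top slot because its next token is far more likely. Since survivors can come from any source beam, the selecting indices are needed to reorder per-beam state such as the generated prefixes. A production version would loop over numBatch and typically work with log-probabilities to avoid underflow on long sequences.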