Class StepGeneration

java.lang.Object
ai.djl.modality.nlp.generate.StepGeneration

public final class StepGeneration extends Object
StepGeneration is a utility class containing the step-generation functions used in autoregressive search. Each function turns the language model's outputs for the current step into the next output token ids.
  • Method Details

    • constrastiveStepGeneration

      public static NDList constrastiveStepGeneration(NDArray topKIds, NDArray logits, NDArray contextHiddenStates, NDArray topkHiddenStates, NDArray offSets, float alpha)
      Generates the output token ids and the selection indices used in contrastive search.
      Parameters:
      topKIds - the token ids of the top-k candidates
      logits - the logits from the language model
      contextHiddenStates - the hidden states (embeddings) of the previously generated tokens
      topkHiddenStates - the hidden states (embeddings) of the top-k candidate tokens
      offSets - the offsets
      alpha - the weight of the degeneration (repetition) penalty
      Returns:
      the output token ids and the selection indices
    • greedyStepGen

      public static NDArray greedyStepGen(NDArray logits)
      Generates the output token ids for greedy search, choosing the highest-scoring token at each step.
      Parameters:
      logits - the logits from the language model
      Returns:
      the output token ids
    • beamStepGeneration

      public static NDList beamStepGeneration(NDArray lastProbs, NDArray logits, long numBatch, long numBeam)
      Generates the output token ids and the selection indices used in beam search.
      Parameters:
      lastProbs - the probabilities of the past prefix sequences
      logits - the logits from the language model
      numBatch - the batch size
      numBeam - the number of beams
      Returns:
      the output token ids and the selection indices
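To make the contrastive scoring behind constrastiveStepGeneration concrete, the following is a minimal, array-based Java sketch of the selection rule commonly used in contrastive search. Each top-k candidate v is scored as (1 - alpha) * p(v) - alpha * max_j cos(h_v, h_j), where p(v) is the softmax probability from the logits and h_j ranges over the context hidden states. The best-scoring candidate is then chosen. This is not DJL's implementation: the class ContrastiveStepSketch and its methods are hypothetical, the sketch handles a single sequence rather than a batch, and it ignores offSets.

```java
/** Hypothetical, array-based sketch of one contrastive-search step (not the DJL API). */
final class ContrastiveStepSketch {

    /** Cosine similarity between two vectors of equal length (small epsilon avoids 0/0). */
    static double cosine(float[] a, float[] b) {
        double dot = 0, na = 0, nb = 0;
        for (int i = 0; i < a.length; i++) {
            dot += a[i] * b[i];
            na += a[i] * a[i];
            nb += b[i] * b[i];
        }
        return dot / (Math.sqrt(na) * Math.sqrt(nb) + 1e-12);
    }

    /** Numerically stable softmax over one row of logits. */
    static double[] softmax(float[] logits) {
        double max = Double.NEGATIVE_INFINITY;
        for (float v : logits) {
            max = Math.max(max, v);
        }
        double[] out = new double[logits.length];
        double sum = 0;
        for (int i = 0; i < logits.length; i++) {
            out[i] = Math.exp(logits[i] - max);
            sum += out[i];
        }
        for (int i = 0; i < out.length; i++) {
            out[i] /= sum;
        }
        return out;
    }

    /**
     * Returns the position (within topKIds) of the candidate maximizing
     * (1 - alpha) * p(candidate) - alpha * max_j cos(h_candidate, h_context_j).
     */
    static int selectCandidate(int[] topKIds, float[] logits, float[][] contextHidden,
                               float[][] topKHidden, float alpha) {
        double[] probs = softmax(logits);
        int best = 0;
        double bestScore = Double.NEGATIVE_INFINITY;
        for (int k = 0; k < topKIds.length; k++) {
            // Degeneration penalty: similarity to the closest context token.
            double maxSim = Double.NEGATIVE_INFINITY;
            for (float[] ctx : contextHidden) {
                maxSim = Math.max(maxSim, cosine(topKHidden[k], ctx));
            }
            double score = (1 - alpha) * probs[topKIds[k]] - alpha * maxSim;
            if (score > bestScore) {
                bestScore = score;
                best = k;
            }
        }
        return best;
    }
}
```

With alpha = 0 the rule reduces to picking the most probable candidate; larger values increasingly penalize candidates whose hidden state resembles the existing context.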
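The selection performed by greedyStepGen amounts to an argmax over the vocabulary. A minimal sketch on plain Java arrays, assuming the logits of the last step are already given as a [batch][vocab] array (the class GreedyStepSketch is hypothetical and not part of DJL):

```java
/** Hypothetical, array-based sketch of one greedy step (not the DJL API). */
final class GreedyStepSketch {

    /** Returns, for each row of logits [batch][vocab], the index of the largest logit. */
    static long[] greedyStep(float[][] logits) {
        long[] ids = new long[logits.length];
        for (int b = 0; b < logits.length; b++) {
            int best = 0;
            for (int v = 1; v < logits[b].length; v++) {
                // Strict '>' keeps the lowest index on ties.
                if (logits[b][v] > logits[b][best]) {
                    best = v;
                }
            }
            ids[b] = best;
        }
        return ids;
    }
}
```

Because softmax is monotonic, the argmax of the raw logits selects the same token as the argmax of the probabilities, so no normalization is needed for greedy search.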
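For beamStepGeneration, one step of beam search can be sketched as follows, assuming lastProbs is a [numBatch][numBeam] array and logits holds one [vocab] row per (batch, beam) pair. Each candidate continuation is scored by the probability of its prefix times the softmax probability of the new token, and the numBeam best (source beam, token) pairs are kept for each batch entry. The class BeamStepSketch, its Result type, and the use of the source beam index as the "selection index" are hypothetical illustrations, not DJL's API or return layout.

```java
import java.util.ArrayList;
import java.util.Comparator;
import java.util.List;

/** Hypothetical, array-based sketch of one beam-search step (not the DJL API). */
final class BeamStepSketch {

    /** Per batch entry and new beam: chosen token, parent (source) beam, and joint probability. */
    static final class Result {
        final long[][] tokenIds;
        final int[][] sourceBeams;
        final double[][] probs;

        Result(long[][] tokenIds, int[][] sourceBeams, double[][] probs) {
            this.tokenIds = tokenIds;
            this.sourceBeams = sourceBeams;
            this.probs = probs;
        }
    }

    static double[] softmax(float[] row) {
        double max = Double.NEGATIVE_INFINITY;
        for (float v : row) {
            max = Math.max(max, v);
        }
        double[] out = new double[row.length];
        double sum = 0;
        for (int i = 0; i < row.length; i++) {
            out[i] = Math.exp(row[i] - max);
            sum += out[i];
        }
        for (int i = 0; i < out.length; i++) {
            out[i] /= sum;
        }
        return out;
    }

    static Result step(double[][] lastProbs, float[][] logits, int numBatch, int numBeam) {
        int vocab = logits[0].length;
        long[][] tokenIds = new long[numBatch][numBeam];
        int[][] sourceBeams = new int[numBatch][numBeam];
        double[][] probs = new double[numBatch][numBeam];
        for (int b = 0; b < numBatch; b++) {
            // Score every (source beam, token) pair: P(prefix) * P(token | prefix).
            double[] joint = new double[numBeam * vocab];
            for (int s = 0; s < numBeam; s++) {
                double[] p = softmax(logits[b * numBeam + s]);
                for (int v = 0; v < vocab; v++) {
                    joint[s * vocab + v] = lastProbs[b][s] * p[v];
                }
            }
            // Keep the numBeam best pairs, sorted by descending joint probability.
            List<Integer> order = new ArrayList<>();
            for (int i = 0; i < joint.length; i++) {
                order.add(i);
            }
            order.sort(Comparator.comparingDouble((Integer i) -> joint[i]).reversed());
            for (int k = 0; k < numBeam; k++) {
                int idx = order.get(k);
                sourceBeams[b][k] = idx / vocab; // which prefix this beam extends
                tokenIds[b][k] = idx % vocab;    // the newly generated token
                probs[b][k] = joint[idx];
            }
        }
        return new Result(tokenIds, sourceBeams, probs);
    }
}
```

Practical implementations usually accumulate log-probabilities rather than multiplying raw probabilities, which avoids numerical underflow on long sequences.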