Class SeqBatchScheduler

java.lang.Object
ai.djl.modality.nlp.generate.SeqBatchScheduler
Direct Known Subclasses:
ContrastiveSeqBatchScheduler

public abstract class SeqBatchScheduler extends Object
This scheduler is the API exposed to consumers of the system. It supports four major actions: initForward, addRequest, incrementForward, and collectResults. Ideally, an optimal control sequence would be solved for, taking into account the time cost of each action as well as the batch size and sequence length of the queued requests. Such an optimal control solver requires additional effort; a primitive policy is to set several thresholds.
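As a rough illustration of the "several thresholds" policy mentioned above, the sketch below picks the next scheduler action from three counters. It is not part of DJL: the class ThresholdPolicy, the Action enum, and both thresholds (maxBatchSize, minQueued) are hypothetical names chosen for this example.

```java
// Hypothetical sketch, not DJL code: every name and threshold here is an assumption.
final class ThresholdPolicy {
    enum Action { ADD_BATCH, FORWARD, COLLECT, IDLE }

    private final int maxBatchSize; // assumed cap on sequences decoded together
    private final int minQueued;    // assumed minimum queue length worth merging

    ThresholdPolicy(int maxBatchSize, int minQueued) {
        this.maxBatchSize = maxBatchSize;
        this.minQueued = minQueued;
    }

    /**
     * Chooses the next action.
     *
     * @param running sequences currently being decoded
     * @param queued requests waiting to be added
     * @param finished completed sequences not yet collected
     */
    Action decide(int running, int queued, int finished) {
        // Hand back finished results before doing more work.
        if (finished > 0) {
            return Action.COLLECT;
        }
        // Number of queued requests that still fit into the running batch.
        int admit = Math.min(queued, Math.max(0, maxBatchSize - running));
        // Merge when enough requests are waiting, or when nothing is running at all.
        if (admit >= minQueued || (running == 0 && admit > 0)) {
            return Action.ADD_BATCH;
        }
        // Otherwise keep decoding the sequences already in flight.
        return running > 0 ? Action.FORWARD : Action.IDLE;
    }
}
```

In this sketch ADD_BATCH would map to addRequest, FORWARD to incrementForward, and COLLECT to collectResults. A real policy would presumably also weigh sequence length and the measured cost of each action, as the class description suggests.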
  • Constructor Details

    • SeqBatchScheduler

      public SeqBatchScheduler(Predictor<NDList,CausalLMOutput> lmBlock, SearchConfig config)
      Constructs a new SeqBatchScheduler instance.
      Parameters:
      lmBlock - the language model predictor, which maps an NDList input to a CausalLMOutput
      config - the search parameter configuration
  • Method Details

    • initForward

      public abstract SeqBatcher initForward(NDArray inputIds, NDArray batchUids) throws TranslateException
      Initializes the iteration and SeqBatcher.
      Parameters:
      inputIds - the input token ids.
      batchUids - the request uids, one identifying each sequence
      Returns:
      a SeqBatcher that stores the search state and operates on the BatchTensorList
      Throws:
      TranslateException - if forward fails
    • incrementForward

      public boolean incrementForward(int count) throws TranslateException
      Executes forward for a given number of iterations.
      Parameters:
      count - the number of forward calls to execute
      Returns:
      a boolean indicating whether the batch is empty
      Throws:
      TranslateException - if forward fails
    • inferenceCall

      protected abstract NDArray inferenceCall() throws TranslateException
      An inference call in an iteration.
      Returns:
      the output token ids
      Throws:
      TranslateException - if forward fails
    • addRequest

      public void addRequest(NDArray inputIds, NDArray batchUids) throws TranslateException
      Adds a new batch of requests.
      Parameters:
      inputIds - the input token ids.
      batchUids - the request uids, one identifying each sequence
      Throws:
      TranslateException - if forward fails
    • collectResults

      public Map<Long,NDArray> collectResults()
      Collects finished results.
      Returns:
      the outputs stored as a map from requestUid to output token ids
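
The call order of the four public methods can be sketched with a toy stand-in. This is not the DJL implementation: plain Java collections replace NDArray and SeqBatcher, the "inference" step just appends the current sequence length as a fake token, and the MAX_LEN stop criterion is an assumption made for the example. Clearing the stored results on collection is also an assumption of the toy.

```java
import java.util.ArrayList;
import java.util.HashMap;
import java.util.Iterator;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;

// Toy stand-in, NOT DJL code: it only mirrors the method names documented above.
final class ToySeqBatchScheduler {
    private static final int MAX_LEN = 4; // hypothetical stop criterion
    private final Map<Long, List<Integer>> active = new LinkedHashMap<>();
    private final Map<Long, List<Integer>> results = new HashMap<>();

    // Mirrors initForward: resets the state and loads the first batch.
    void initForward(int[][] inputIds, long[] batchUids) {
        active.clear();
        results.clear();
        addRequest(inputIds, batchUids);
    }

    // Mirrors addRequest: merges new sequences into the running batch.
    void addRequest(int[][] inputIds, long[] batchUids) {
        for (int i = 0; i < batchUids.length; i++) {
            List<Integer> seq = new ArrayList<>();
            for (int id : inputIds[i]) {
                seq.add(id);
            }
            active.put(batchUids[i], seq);
        }
    }

    // Mirrors incrementForward: runs up to count steps, reports whether the batch is empty.
    boolean incrementForward(int count) {
        for (int step = 0; step < count && !active.isEmpty(); step++) {
            Iterator<Map.Entry<Long, List<Integer>>> it = active.entrySet().iterator();
            while (it.hasNext()) {
                Map.Entry<Long, List<Integer>> entry = it.next();
                List<Integer> seq = entry.getValue();
                seq.add(seq.size()); // fake "next token"
                if (seq.size() >= MAX_LEN) {
                    results.put(entry.getKey(), seq); // finished: move out of the batch
                    it.remove();
                }
            }
        }
        return active.isEmpty();
    }

    // Mirrors collectResults: returns finished sequences keyed by request uid.
    Map<Long, List<Integer>> collectResults() {
        Map<Long, List<Integer>> out = new HashMap<>(results);
        results.clear(); // toy assumption: results are handed over once
        return out;
    }
}
```

A caller would typically invoke initForward once, then alternate incrementForward with addRequest as new requests arrive, calling collectResults whenever finished sequences should be returned.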