Class SeqBatchScheduler

  • Direct Known Subclasses:
    ContrastiveSeqBatchScheduler

    public abstract class SeqBatchScheduler
    extends java.lang.Object
    This is a scheduler that serves as the API for consumers of the system. It supports four major actions: initForward, addRequest (adding a batch), incrementForward (fast-forwarding the decoding), and collectResults. Ideally, the order of these actions would be chosen by solving for an optimal control sequence that accounts for the time cost of each action and for the batch size and sequence length of the queued requests. Such an optimal control solver requires additional effort; the primitive policy is to set several thresholds.
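The description leaves the threshold policy unspecified. Below is a minimal sketch of one possible policy, in plain Java so it runs without DJL. The names `SchedulerAction`, `ThresholdPolicy`, `maxBatchSize`, and `minQueuedToAdd` are hypothetical and are not part of the DJL API. Under these assumptions, the policy admits queued requests (addRequest) once enough are waiting and the running batch has room; otherwise it keeps decoding the running batch (incrementForward).

```java
/** Hypothetical actions a driver loop could take at each scheduling step. */
enum SchedulerAction { ADD_REQUEST, INCREMENT_FORWARD, IDLE }

/**
 * A minimal threshold policy (hypothetical, not part of DJL): admit queued
 * requests when enough are waiting and the running batch has room,
 * otherwise keep decoding the running batch.
 */
final class ThresholdPolicy {
    private final int maxBatchSize;   // upper bound on concurrently decoded sequences
    private final int minQueuedToAdd; // wait until at least this many requests are queued

    ThresholdPolicy(int maxBatchSize, int minQueuedToAdd) {
        this.maxBatchSize = maxBatchSize;
        this.minQueuedToAdd = minQueuedToAdd;
    }

    SchedulerAction decide(int runningBatchSize, int queuedRequests) {
        boolean hasRoom = runningBatchSize + queuedRequests <= maxBatchSize;
        if (queuedRequests >= minQueuedToAdd && hasRoom) {
            return SchedulerAction.ADD_REQUEST;
        }
        if (runningBatchSize > 0) {
            return SchedulerAction.INCREMENT_FORWARD;
        }
        // Nothing is running: admit queued work even below the threshold,
        // rather than leaving the model idle.
        return queuedRequests > 0 ? SchedulerAction.ADD_REQUEST : SchedulerAction.IDLE;
    }
}
```

A real policy would tune the thresholds against measured costs of each action; fixed thresholds are simply the cheapest stand-in for the optimal control solver mentioned above.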
    • Constructor Detail

      • SeqBatchScheduler

        public SeqBatchScheduler(Predictor<NDList, CausalLMOutput> lmBlock,
                                 SearchConfig config)
        Constructs a new SeqBatchScheduler instance.
        Parameters:
        lmBlock - the predictor that runs the causal language model, mapping an NDList of inputs to a CausalLMOutput
        config - the search parameter configuration
    • Method Detail

      • initForward

        public abstract SeqBatcher initForward(NDArray inputIds,
                                               NDArray batchUids)
                                        throws TranslateException
        Initializes the iteration and SeqBatcher.
        Parameters:
        inputIds - the input token ids.
        batchUids - the request uids, one identifying each sequence in the batch
        Returns:
        the SeqBatcher, which stores the search state and operates on the BatchTensorList
        Throws:
        TranslateException - if forward fails
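To make the initForward contract concrete, here is a toy sketch in plain Java. It swaps NDArray for `long[][]`, the SeqBatcher for a hypothetical `ToySeqBatcher` that keeps one token list per request uid, and the language model for a hypothetical next-token function of the last token. It does not reflect how DJL computes logits or runs the search; it only shows the shape of the call: one uid per input row, one forward step, and a returned search state.

```java
import java.util.ArrayList;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.function.LongUnaryOperator;

/** Hypothetical stand-in for SeqBatcher: token ids so far, per request uid. */
final class ToySeqBatcher {
    final Map<Long, List<Long>> sequences = new LinkedHashMap<>();
}

final class ToyInit {
    /**
     * Shaped like initForward: prompt token ids plus one uid per row in,
     * search state out. The model is a hypothetical next-token function.
     */
    static ToySeqBatcher initForward(long[][] inputIds, long[] batchUids, LongUnaryOperator model) {
        if (inputIds.length != batchUids.length) {
            throw new IllegalArgumentException("expected one uid per input row");
        }
        ToySeqBatcher batcher = new ToySeqBatcher();
        for (int row = 0; row < inputIds.length; row++) {
            if (inputIds[row].length == 0) {
                throw new IllegalArgumentException("empty prompt for uid " + batchUids[row]);
            }
            List<Long> tokens = new ArrayList<>();
            for (long id : inputIds[row]) {
                tokens.add(id);
            }
            // First forward pass: predict one token from the last prompt token.
            tokens.add(model.applyAsLong(tokens.get(tokens.size() - 1)));
            batcher.sequences.put(batchUids[row], tokens);
        }
        return batcher;
    }
}
```

For example, prompts `{1, 2}` and `{5}` with uids `10` and `11` and the toy model `t -> t + 1` yield `[1, 2, 3]` and `[5, 6]`.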
      • incrementForward

        public boolean incrementForward(int count)
                                 throws TranslateException
        Executes forward for a given number of iterations.
        Parameters:
        count - the number of forward calls to execute
        Returns:
        whether the batch is empty
        Throws:
        TranslateException - if forward fails
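The following plain-Java sketch illustrates the loop behind incrementForward under toy assumptions: each forward call appends one token to every running sequence, and a sequence is retired once it emits a hypothetical `eosToken`. The class `ToyIncrementalScheduler` and its fields are hypothetical, not DJL types. Following the documented return value, it returns `true` when no sequence is left running.

```java
import java.util.Iterator;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.function.LongUnaryOperator;

/**
 * Hypothetical toy scheduler: each forward call appends one token to every
 * running sequence and retires a sequence once it emits eosToken.
 */
final class ToyIncrementalScheduler {
    final Map<Long, List<Long>> running = new LinkedHashMap<>();
    final Map<Long, List<Long>> finished = new LinkedHashMap<>();
    private final LongUnaryOperator model; // hypothetical next-token function
    private final long eosToken;

    ToyIncrementalScheduler(LongUnaryOperator model, long eosToken) {
        this.model = model;
        this.eosToken = eosToken;
    }

    /**
     * Runs up to count forward calls, stopping early if the batch empties.
     * Returns true if no sequence is still running afterwards.
     */
    boolean incrementForward(int count) {
        for (int i = 0; i < count && !running.isEmpty(); i++) {
            Iterator<Map.Entry<Long, List<Long>>> it = running.entrySet().iterator();
            while (it.hasNext()) {
                Map.Entry<Long, List<Long>> e = it.next();
                List<Long> tokens = e.getValue();
                long next = model.applyAsLong(tokens.get(tokens.size() - 1));
                tokens.add(next);
                if (next == eosToken) {
                    finished.put(e.getKey(), tokens);
                    it.remove(); // drop the finished sequence from the running batch
                }
            }
        }
        return running.isEmpty();
    }
}
```

Retiring finished sequences inside the loop keeps later forward calls from spending compute on rows that are already done.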
      • addRequest

        public void addRequest(NDArray inputIds,
                               NDArray batchUids)
                        throws TranslateException
        Adds a new batch of requests.
        Parameters:
        inputIds - the input token ids.
        batchUids - the request uids, one identifying each sequence in the batch
        Throws:
        TranslateException - if forward fails
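Adding requests to a running batch means combining sequences of different lengths into one rectangular batch. The sketch below assumes one common approach: left-pad the shorter side with a pad token and record a per-row offset. The class `ToyBatch`, its fields, and `padToken` are hypothetical and cover token ids only. A real batcher that caches model state (for example past key/values) would also have to merge that state, which this sketch omits.

```java
import java.util.Arrays;

/**
 * Hypothetical token-id batch: rows share one length and are left-padded;
 * offsets[i] counts the pad tokens in front of row i.
 */
final class ToyBatch {
    final long[][] tokens;
    final int[] offsets;
    final long[] uids;

    ToyBatch(long[][] tokens, int[] offsets, long[] uids) {
        this.tokens = tokens;
        this.offsets = offsets;
        this.uids = uids;
    }

    int seqLength() {
        return tokens.length == 0 ? 0 : tokens[0].length;
    }

    /** Concatenates two batches, left-padding the shorter one to the longer length. */
    static ToyBatch merge(ToyBatch running, ToyBatch incoming, long padToken) {
        int len = Math.max(running.seqLength(), incoming.seqLength());
        int rows = running.tokens.length + incoming.tokens.length;
        long[][] tokens = new long[rows][];
        int[] offsets = new int[rows];
        long[] uids = new long[rows];
        int r = 0;
        for (ToyBatch part : new ToyBatch[] {running, incoming}) {
            int shift = len - part.seqLength(); // extra left padding for this part
            for (int i = 0; i < part.tokens.length; i++, r++) {
                long[] row = new long[len];
                Arrays.fill(row, 0, shift, padToken);
                System.arraycopy(part.tokens[i], 0, row, shift, part.seqLength());
                tokens[r] = row;
                offsets[r] = part.offsets[i] + shift;
                uids[r] = part.uids[i];
            }
        }
        return new ToyBatch(tokens, offsets, uids);
    }
}
```

Left padding keeps every row's most recent token in the last column, so one forward step can read the next-token position at the same index for all rows.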
      • collectResults

        public java.util.Map<java.lang.Long, NDArray> collectResults()
        Collects finished results.
        Returns:
        the outputs stored as a map from requestUid to output token ids
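A minimal plain-Java sketch of the collectResults contract, with `long[]` standing in for NDArray. The class `ToyResultCollector` and its `markFinished` method are hypothetical. As a design choice of this sketch (not confirmed by the documentation above), collecting also clears the stored results, so each finished request is handed back exactly once.

```java
import java.util.HashMap;
import java.util.Map;

/**
 * Hypothetical result store: finished outputs keyed by request uid,
 * cleared on collection so each result is returned only once.
 */
final class ToyResultCollector {
    private final Map<Long, long[]> finished = new HashMap<>();

    /** Records the output token ids of a request that stopped decoding. */
    void markFinished(long uid, long[] outputTokenIds) {
        finished.put(uid, outputTokenIds);
    }

    /** Returns a snapshot of finished outputs and empties the store. */
    Map<Long, long[]> collectResults() {
        Map<Long, long[]> out = new HashMap<>(finished);
        finished.clear();
        return out;
    }
}
```

Returning a copy rather than the internal map lets the caller keep the results while the scheduler continues to record new completions.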