Package ai.djl.modality.nlp.generate
Class SeqBatchScheduler
- java.lang.Object
  - ai.djl.modality.nlp.generate.SeqBatchScheduler

Direct Known Subclasses:
ContrastiveSeqBatchScheduler
public abstract class SeqBatchScheduler extends java.lang.Object

This is a scheduler that serves as the API for consumers of the system. It supports four major actions: initForward, addRequest, incrementForward, and collectResults. Ideally, the order in which these actions are applied would come from solving an optimal control problem that accounts for the time cost of each action and for the batch size and sequence length of the queued requests. Such a solver requires additional effort, so the primitive policy is to set several thresholds instead.
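The description leaves the threshold policy itself unspecified. The following is a minimal sketch of how such a policy could map scheduler state to the next action. `ThresholdPolicy`, `Action`, and both threshold parameters are hypothetical illustrations, not part of DJL, and the code has no DJL dependency.

```java
import java.util.Collections;
import java.util.List;

/** Hypothetical next step for a driver loop around a scheduler (not a DJL type). */
enum Action { INIT_FORWARD, ADD_REQUEST, INCREMENT_FORWARD, IDLE }

/** Hypothetical threshold-based policy; both limits are illustrative assumptions. */
final class ThresholdPolicy {
    private final int maxBatchSize;    // cap on sequences decoded together
    private final int maxPromptLength; // cap on prompt length admitted into a running batch

    ThresholdPolicy(int maxBatchSize, int maxPromptLength) {
        this.maxBatchSize = maxBatchSize;
        this.maxPromptLength = maxPromptLength;
    }

    /**
     * @param runningBatchSize sequences currently being decoded
     * @param queuedPromptLengths prompt lengths of requests waiting in the queue
     */
    Action decide(int runningBatchSize, List<Integer> queuedPromptLengths) {
        if (runningBatchSize == 0) {
            // Nothing is decoding: start a fresh batch if anything is waiting.
            return queuedPromptLengths.isEmpty() ? Action.IDLE : Action.INIT_FORWARD;
        }
        if (!queuedPromptLengths.isEmpty()) {
            int mergedSize = runningBatchSize + queuedPromptLengths.size();
            int longest = Collections.max(queuedPromptLengths);
            // Admit the whole queue only while both thresholds hold.
            if (mergedSize <= maxBatchSize && longest <= maxPromptLength) {
                return Action.ADD_REQUEST;
            }
        }
        // Otherwise keep decoding the running batch.
        return Action.INCREMENT_FORWARD;
    }
}
```

Keeping the thresholds as plain constructor arguments makes them easy to tune per deployment, at the cost of ignoring the per-action time costs that an optimal control solver would weigh.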
Constructor Summary
Constructor and Description:
SeqBatchScheduler(Predictor<NDList,CausalLMOutput> lmBlock, SearchConfig config)
    Constructs a new SeqBatchScheduler instance.
Method Summary
Modifier and Type, Method, and Description:
void addRequest(NDArray inputIds, NDArray batchUids)
    Adds a new batch.
java.util.Map<java.lang.Long,NDArray> collectResults()
    Collects finished results.
boolean incrementForward(int count)
    Executes forward for a given number of iterations.
abstract SeqBatcher initForward(NDArray inputIds, NDArray batchUids)
    Initializes the iteration and SeqBatcher.
Constructor Detail
SeqBatchScheduler
public SeqBatchScheduler(Predictor<NDList,CausalLMOutput> lmBlock, SearchConfig config)
Constructs a new SeqBatchScheduler instance.

Parameters:
lmBlock - the language-model predictor, mapping an NDList input to a CausalLMOutput
config - the search parameter configuration
Method Detail
initForward
public abstract SeqBatcher initForward(NDArray inputIds, NDArray batchUids) throws TranslateException
Initializes the iteration and SeqBatcher.

Parameters:
inputIds - the input token ids
batchUids - the request uids identifying each sequence
Returns:
SeqBatcher - stores the search state and operates on the BatchTensorList
Throws:
TranslateException - if forward fails
incrementForward
public boolean incrementForward(int count) throws TranslateException

Executes forward for a given number of iterations.

Parameters:
count - the number of forward calls to run
Returns:
boolean - whether the batch is empty
Throws:
TranslateException - if forward fails
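The return-value contract can be illustrated with a toy, DJL-free simulation. It assumes every sequence stops after a fixed number of generated tokens and uses a stub "model" (next token = last token + 1) in place of the `lmBlock` predictor. `ToySequence`, `ToyScheduler`, and `maxNewTokens` are hypothetical; the real class works on NDArrays through a `SeqBatcher`.

```java
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;

/** Toy stand-in for one in-flight sequence (hypothetical, not a DJL type). */
final class ToySequence {
    final long uid;
    final List<Long> tokens = new ArrayList<>();
    int generated; // tokens produced so far

    ToySequence(long uid, long[] prompt) {
        this.uid = uid;
        for (long t : prompt) {
            tokens.add(t);
        }
    }
}

/** Toy scheduler that mimics the documented incrementForward contract. */
final class ToyScheduler {
    private final List<ToySequence> batch = new ArrayList<>();
    private final List<Long> finishedUids = new ArrayList<>();
    private final int maxNewTokens; // assumed stopping rule

    ToyScheduler(int maxNewTokens) {
        this.maxNewTokens = maxNewTokens;
    }

    /** Row i of inputIds belongs to the request uids[i]. */
    void add(long[][] inputIds, long[] uids) {
        for (int i = 0; i < uids.length; i++) {
            batch.add(new ToySequence(uids[i], inputIds[i]));
        }
    }

    /** Runs up to count decoding steps; returns true if the batch is empty. */
    boolean incrementForward(int count) {
        for (int step = 0; step < count; step++) {
            if (batch.isEmpty()) {
                return true; // nothing left to decode
            }
            Iterator<ToySequence> it = batch.iterator();
            while (it.hasNext()) {
                ToySequence s = it.next();
                // Stub "model" standing in for the lmBlock forward pass.
                s.tokens.add(s.tokens.get(s.tokens.size() - 1) + 1);
                s.generated++;
                if (s.generated >= maxNewTokens) {
                    finishedUids.add(s.uid); // finished sequences leave the batch
                    it.remove();
                }
            }
        }
        return batch.isEmpty();
    }

    List<Long> finishedUids() {
        return finishedUids;
    }
}
```

A caller can therefore loop on `incrementForward` in small chunks and use the `true` result as the signal to admit new work or stop.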
addRequest
public void addRequest(NDArray inputIds, NDArray batchUids) throws TranslateException
Adds a new batch of requests.

Parameters:
inputIds - the input token ids
batchUids - the request uids identifying each sequence
Throws:
TranslateException - if forward fails
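Adding requests to a batch that is already decoding means rows of different lengths must share one tensor. One common approach, assumed here rather than taken from the DJL source, is to left-pad shorter rows with a pad token so every row ends at the same column. The plain-Java sketch below works on `long[][]` token ids; `LeftPad.merge` and `padTokenId` are hypothetical, and a real implementation would also have to align the attention mask and key-value cache, which this sketch ignores.

```java
import java.util.Arrays;

/** Hypothetical helper: merge two token-id batches by left-padding to a common width. */
final class LeftPad {
    private LeftPad() {}

    static long[][] merge(long[][] running, long[][] incoming, long padTokenId) {
        // The merged width is the longest row across both batches.
        int width = 0;
        for (long[] row : running) {
            width = Math.max(width, row.length);
        }
        for (long[] row : incoming) {
            width = Math.max(width, row.length);
        }

        long[][] merged = new long[running.length + incoming.length][];
        int r = 0;
        for (long[] row : running) {
            merged[r++] = padLeft(row, width, padTokenId);
        }
        for (long[] row : incoming) {
            merged[r++] = padLeft(row, width, padTokenId);
        }
        return merged;
    }

    /** Prepends pad tokens so the real tokens end at the last column. */
    private static long[] padLeft(long[] row, int width, long pad) {
        long[] out = new long[width];
        int offset = width - row.length;
        Arrays.fill(out, 0, offset, pad);
        System.arraycopy(row, 0, out, offset, row.length);
        return out;
    }
}
```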
collectResults
public java.util.Map<java.lang.Long,NDArray> collectResults()
Collects finished results.

Returns:
the outputs, stored as a map from request uid to output token ids
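The returned map is keyed by request uid. A minimal sketch of a result store with that shape follows, using `long[]` in place of `NDArray`. The drain-and-reset behavior (each call hands back what finished since the previous call) is an assumption about how such a store would typically be used, not a documented guarantee of this class; `ResultStore` is hypothetical.

```java
import java.util.HashMap;
import java.util.Map;

/** Hypothetical result store (not a DJL class) keyed by request uid. */
final class ResultStore {
    private Map<Long, long[]> results = new HashMap<>();

    /** Called when a sequence finishes; records its output token ids. */
    synchronized void put(long requestUid, long[] outputTokenIds) {
        results.put(requestUid, outputTokenIds);
    }

    /** Returns everything finished so far and starts an empty map for later results. */
    synchronized Map<Long, long[]> collectResults() {
        Map<Long, long[]> output = results;
        results = new HashMap<>();
        return output;
    }
}
```

Swapping in a fresh map, rather than clearing the old one, lets the caller keep using the returned map safely after later sequences finish.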