Package ai.djl.modality.nlp.generate
Class SeqBatchScheduler
java.lang.Object
ai.djl.modality.nlp.generate.SeqBatchScheduler
- Direct Known Subclasses:
ContrastiveSeqBatchScheduler
This is a scheduler, serving as an API to the consumer of the system, allowing for four major
actions: initForward, addRequest, incrementForward, collectResults. An optimal control sequence
should be solved for after considering the time cost of each action and the batch size and
sequence length of the queued requests. Such an optimal control solver requires additional
effort; the primitive policy is to set several thresholds.
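As a rough illustration of what a threshold-based policy could look like, the sketch below picks the next action from a few counters. It is plain Java and not part of DJL: the `Action` enum, the `ThresholdPolicy` class, and all three thresholds are hypothetical names chosen for this example, and the description above does not say which thresholds the library actually uses.

```java
/** The four scheduler actions, named after the methods documented on this page. */
enum Action { INIT_FORWARD, ADD_REQUEST, INCREMENT_FORWARD, COLLECT_RESULTS }

/** Hypothetical threshold policy; the thresholds and their meaning are assumptions. */
final class ThresholdPolicy {
    private final int maxBatchSize;         // assumed cap on sequences in the running batch
    private final int minQueuedToAdmit;     // assumed: admit only when this many requests wait
    private final int minFinishedToCollect; // assumed: collect once this many sequences finish

    ThresholdPolicy(int maxBatchSize, int minQueuedToAdmit, int minFinishedToCollect) {
        this.maxBatchSize = maxBatchSize;
        this.minQueuedToAdmit = minQueuedToAdmit;
        this.minFinishedToCollect = minFinishedToCollect;
    }

    /** Chooses the next action from the current counters. */
    Action choose(boolean started, int active, int queued, int finished) {
        if (!started) {
            return Action.INIT_FORWARD; // nothing is running yet
        }
        if (finished >= minFinishedToCollect) {
            return Action.COLLECT_RESULTS; // enough finished sequences to hand back
        }
        boolean roomInBatch = active + queued <= maxBatchSize;
        if (queued >= minQueuedToAdmit && roomInBatch) {
            return Action.ADD_REQUEST; // merge waiting requests into the batch
        }
        return Action.INCREMENT_FORWARD; // otherwise keep decoding
    }
}
```

A policy like this trades optimality for simplicity: it ignores the measured cost of each action, which is what the optimal control formulation would take into account.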
-
Constructor Summary
Constructors
Constructor: SeqBatchScheduler(Predictor<NDList, CausalLMOutput> lmBlock, SearchConfig config)
Description: Constructs a new SeqBatchScheduler instance.
-
Method Summary
Modifier, type, and method: description
void addRequest(NDArray inputIds, NDArray batchUids): Adds new batch.
collectResults(): Collects finished results.
boolean incrementForward(int count): Executes forward for a given number of iterations.
protected abstract NDArray inferenceCall(): An inference call in an iteration.
abstract SeqBatcher initForward(NDArray inputIds, NDArray batchUids): Initializes the iteration and SeqBatcher.
-
Constructor Details
-
SeqBatchScheduler
Constructs a new SeqBatchScheduler instance.
- Parameters:
lmBlock
- the predictor that cont
config
- the search parameter configuration
-
-
Method Details
-
initForward
public abstract SeqBatcher initForward(NDArray inputIds, NDArray batchUids) throws TranslateException
Initializes the iteration and SeqBatcher.
- Parameters:
inputIds
- the input token ids
batchUids
- the request uid identifying a sequence
- Returns:
- the SeqBatcher, which stores the search state and operates on the BatchTensorList
- Throws:
TranslateException
- if forward fails
-
incrementForward
Executes forward for a given number of iterations.
- Parameters:
count
- the number of forward calls
- Returns:
- a boolean indicating whether the batch is empty
- Throws:
TranslateException
- if forward fails
-
inferenceCall
An inference call in an iteration.
- Returns:
- the output token ids
- Throws:
TranslateException
- if forward fails
-
addRequest
Adds new batch.
- Parameters:
inputIds
- the input token ids
batchUids
- the request uid identifying a sequence
- Throws:
TranslateException
- if forward fails
-
collectResults
Collects finished results.
- Returns:
- the outputs stored as a map from requestUid to output token ids
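To make the intended call order easier to picture, here is a toy model of the lifecycle described above: start with one batch, step forward, merge a second batch, keep stepping until the batch is empty, then collect. It is a minimal sketch in plain Java, not the DJL class. `ToyScheduler` and `ToySchedulerDemo` are hypothetical, int arrays stand in for NDArray, and the "model" simply counts the last token down to an assumed end-of-sequence token 0. Clearing the results inside `collectResults` is also an assumption of this sketch.

```java
import java.util.ArrayList;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;

/** Toy stand-in for a sequence batch scheduler; plain Java, not the DJL class. */
final class ToyScheduler {
    private static final int EOS = 0; // assumed end-of-sequence token
    private final int maxLength;      // assumed cap on sequence length
    private final Map<Long, List<Integer>> active = new HashMap<>();
    private final Map<Long, List<Integer>> finished = new HashMap<>();

    ToyScheduler(int maxLength) {
        this.maxLength = maxLength;
    }

    /** Starts the iteration with the first batch (mirrors initForward). */
    void initForward(int[][] inputIds, long[] batchUids) {
        active.clear();
        addRequest(inputIds, batchUids);
    }

    /** Merges a new batch into the running one; prompts must be non-empty. */
    void addRequest(int[][] inputIds, long[] batchUids) {
        for (int i = 0; i < batchUids.length; i++) {
            List<Integer> seq = new ArrayList<>();
            for (int id : inputIds[i]) {
                seq.add(id);
            }
            active.put(batchUids[i], seq);
        }
    }

    /** Fake model step: counts the last token down toward EOS. */
    private int inferenceCall(List<Integer> seq) {
        return Math.max(EOS, seq.get(seq.size() - 1) - 1);
    }

    /** Runs up to count steps; returns true when the batch is empty. */
    boolean incrementForward(int count) {
        for (int step = 0; step < count && !active.isEmpty(); step++) {
            Iterator<Map.Entry<Long, List<Integer>>> it = active.entrySet().iterator();
            while (it.hasNext()) {
                Map.Entry<Long, List<Integer>> entry = it.next();
                List<Integer> seq = entry.getValue();
                int next = inferenceCall(seq);
                seq.add(next);
                if (next == EOS || seq.size() >= maxLength) {
                    finished.put(entry.getKey(), seq); // sequence is done
                    it.remove();                       // drop it from the batch
                }
            }
        }
        return active.isEmpty();
    }

    /** Returns finished sequences keyed by request uid and clears them. */
    Map<Long, List<Integer>> collectResults() {
        Map<Long, List<Integer>> out = new HashMap<>(finished);
        finished.clear();
        return out;
    }
}

final class ToySchedulerDemo {
    public static void main(String[] args) {
        ToyScheduler scheduler = new ToyScheduler(10);
        scheduler.initForward(new int[][] {{5, 3}}, new long[] {1L});
        scheduler.incrementForward(1);                              // one decoding step
        scheduler.addRequest(new int[][] {{2}}, new long[] {2L});   // late arrival joins
        while (!scheduler.incrementForward(2)) {
            // keep stepping until every sequence has finished
        }
        System.out.println(scheduler.collectResults());
    }
}
```

The point of the sketch is the shape of the loop: `incrementForward` reports whether the batch has drained, so a caller can interleave `addRequest` calls between steps and only collect once results are available.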
-