Package ai.djl.modality.nlp.generate
Class SeqBatchScheduler
java.lang.Object
    ai.djl.modality.nlp.generate.SeqBatchScheduler
Direct Known Subclasses:
ContrastiveSeqBatchScheduler
public abstract class SeqBatchScheduler extends java.lang.Object
This scheduler is the API exposed to consumers of the system. It allows for four major actions: initForward, addRequest, incrementForward, and collectResults. Ideally, an optimal control sequence would be solved for, taking into account the time cost of each action as well as the batch size and sequence length of the queued requests. Such an optimal control solver requires additional effort; a primitive policy is to set several thresholds.
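To make the threshold idea concrete, here is a minimal sketch of what such a primitive policy might look like. It is not part of DJL: the class `ThresholdPolicySketch`, its `Action` enum, and both thresholds (`maxRunningBatch`, `minWaitingToAdd`) are hypothetical names and values chosen only for illustration.

```java
/** Hypothetical threshold policy; illustrative only and not part of DJL. */
final class ThresholdPolicySketch {

    /** The action a caller could take next on the scheduler. */
    enum Action { ADD_REQUEST, INCREMENT_FORWARD, COLLECT_RESULTS }

    private final int maxRunningBatch; // assumed cap on sequences decoded together
    private final int minWaitingToAdd; // assumed queue size that justifies a merge

    ThresholdPolicySketch(int maxRunningBatch, int minWaitingToAdd) {
        this.maxRunningBatch = maxRunningBatch;
        this.minWaitingToAdd = minWaitingToAdd;
    }

    /**
     * Picks the next action from the number of running and waiting sequences.
     *
     * @param running sequences currently being decoded
     * @param waiting sequences queued but not yet added
     * @return the suggested next action
     */
    Action next(int running, int waiting) {
        boolean enoughWaiting = waiting >= minWaitingToAdd;
        boolean fits = running + waiting <= maxRunningBatch;
        if (enoughWaiting && fits) {
            return Action.ADD_REQUEST; // the merge is worth its cost
        }
        if (running > 0) {
            return Action.INCREMENT_FORWARD; // keep decoding the current batch
        }
        if (waiting > 0) {
            return Action.ADD_REQUEST; // nothing is running, so admit what is queued
        }
        return Action.COLLECT_RESULTS; // idle: only finished results remain
    }

    public static void main(String[] args) {
        ThresholdPolicySketch policy = new ThresholdPolicySketch(8, 2);
        System.out.println(policy.next(4, 3));
        System.out.println(policy.next(7, 3));
        System.out.println(policy.next(0, 0));
    }
}
```

A real policy would presumably also weigh sequence lengths and the measured cost of each action, as the description above suggests; this sketch looks only at counts.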
Constructor Summary
Constructors
Constructor Description
SeqBatchScheduler(Predictor<NDList,CausalLMOutput> lmBlock, SearchConfig config)
Constructs a new SeqBatchScheduler instance.
-
Method Summary
Modifier and Type Method Description
void addRequest(NDArray inputIds, NDArray batchUids)
Adds a new batch.
java.util.Map<java.lang.Long,NDArray> collectResults()
Collects finished results.
boolean incrementForward(int count)
Executes forward for a given number of iterations.
abstract SeqBatcher initForward(NDArray inputIds, NDArray batchUids)
Initializes the iteration and SeqBatcher.
Constructor Detail
-
SeqBatchScheduler
public SeqBatchScheduler(Predictor<NDList,CausalLMOutput> lmBlock, SearchConfig config)
Constructs a new SeqBatchScheduler instance.
Parameters:
lmBlock - the predictor that cont
config - the search parameter configuration
Method Detail
-
initForward
public abstract SeqBatcher initForward(NDArray inputIds, NDArray batchUids) throws TranslateException
Initializes the iteration and SeqBatcher.
Parameters:
inputIds - the input token ids
batchUids - the request uid identifying a sequence
Returns:
a SeqBatcher that stores the search state and operates on the BatchTensorList
Throws:
TranslateException - if forward fails
-
incrementForward
public boolean incrementForward(int count) throws TranslateException
Executes forward for a given number of iterations.
Parameters:
count - the number of forward calls
Returns:
a boolean indicating whether the batch is empty
Throws:
TranslateException - if forward fails
-
addRequest
public void addRequest(NDArray inputIds, NDArray batchUids) throws TranslateException
Adds a new batch.
Parameters:
inputIds - the input token ids
batchUids - the request uid identifying a sequence
Throws:
TranslateException - if forward fails
-
collectResults
public java.util.Map<java.lang.Long,NDArray> collectResults()
Collects finished results.
Returns:
the outputs stored as a map from requestUid to output token ids
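The four methods above are meant to be called in sequence, and a toy stand-in can illustrate one plausible order. The sketch below uses no DJL classes. `ToySchedulerSketch`, its fake "model" (emit the previous token minus one), the end-of-sequence id `0`, and the `maxLength` cap are all assumptions made for illustration. Plain `int[][]` and `long[]` arrays stand in for the `NDArray` arguments, and prompts are assumed to be non-empty.

```java
import java.util.ArrayList;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;

/** Toy stand-in that mimics the call sequence only; not the DJL implementation. */
final class ToySchedulerSketch {
    private static final int EOS = 0; // assumed end-of-sequence token id
    private final int maxLength;      // assumed cap on sequence length
    private final Map<Long, List<Integer>> running = new HashMap<>();
    private final Map<Long, List<Integer>> finished = new HashMap<>();

    ToySchedulerSketch(int maxLength) {
        this.maxLength = maxLength;
    }

    /** Starts decoding with a first batch (stands in for initForward). */
    void initForward(int[][] inputIds, long[] batchUids) {
        running.clear();
        addRequest(inputIds, batchUids);
    }

    /** Merges more sequences into the running batch (stands in for addRequest). */
    void addRequest(int[][] inputIds, long[] batchUids) {
        for (int i = 0; i < batchUids.length; i++) {
            List<Integer> seq = new ArrayList<>();
            for (int id : inputIds[i]) {
                seq.add(id);
            }
            running.put(batchUids[i], seq);
        }
    }

    /** Runs up to count steps; returns true once the running batch is empty. */
    boolean incrementForward(int count) {
        for (int step = 0; step < count && !running.isEmpty(); step++) {
            Iterator<Map.Entry<Long, List<Integer>>> it = running.entrySet().iterator();
            while (it.hasNext()) {
                Map.Entry<Long, List<Integer>> entry = it.next();
                List<Integer> seq = entry.getValue();
                // Fake "model": previous token minus one, floored at EOS.
                int next = Math.max(EOS, seq.get(seq.size() - 1) - 1);
                seq.add(next);
                if (next == EOS || seq.size() >= maxLength) {
                    finished.put(entry.getKey(), seq); // sequence is done
                    it.remove();
                }
            }
        }
        return running.isEmpty();
    }

    /** Returns finished sequences keyed by request uid; this toy then clears them. */
    Map<Long, List<Integer>> collectResults() {
        Map<Long, List<Integer>> out = new HashMap<>(finished);
        finished.clear();
        return out;
    }

    public static void main(String[] args) {
        ToySchedulerSketch scheduler = new ToySchedulerSketch(16);
        scheduler.initForward(new int[][] {{5, 3}}, new long[] {1L});
        scheduler.incrementForward(1);                            // one decoding step
        scheduler.addRequest(new int[][] {{2}}, new long[] {2L}); // late arrival
        while (!scheduler.incrementForward(2)) {
            // keep stepping until the batch is empty
        }
        System.out.println(scheduler.collectResults());
    }
}
```

The loop on `incrementForward` relies on the documented return value, which signals whether the batch is empty. Whether the real `collectResults` clears its buffer the way this toy does is not stated in this page.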