Package ai.djl.modality.nlp.generate
Class ContrastiveSeqBatchScheduler
java.lang.Object
ai.djl.modality.nlp.generate.SeqBatchScheduler
ai.djl.modality.nlp.generate.ContrastiveSeqBatchScheduler
ContrastiveSeqBatchScheduler is a class that implements the contrastive search algorithm
used in SeqBatchScheduler.
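To make the selection rule concrete, the following is a minimal sketch of one contrastive search step for a single sequence, written with plain Java arrays. It is not the DJL implementation, which works on batched NDArrays. The class ContrastiveSearchSketch and its methods cosine and select are hypothetical names introduced only for this illustration. The sketch assumes the usual formulation of contrastive search: among the top-k candidate tokens, pick the one that maximizes (1 - alpha) * probability - alpha * (maximum cosine similarity between the candidate's hidden state and the hidden states of the tokens generated so far).

```java
/** Hypothetical illustration of one contrastive search step; not part of DJL. */
final class ContrastiveSearchSketch {

    private ContrastiveSearchSketch() {}

    /** Cosine similarity of two equal-length, non-zero vectors (assumed by this sketch). */
    static double cosine(double[] a, double[] b) {
        double dot = 0.0;
        double normA = 0.0;
        double normB = 0.0;
        for (int i = 0; i < a.length; i++) {
            dot += a[i] * b[i];
            normA += a[i] * a[i];
            normB += b[i] * b[i];
        }
        return dot / (Math.sqrt(normA) * Math.sqrt(normB));
    }

    /**
     * Returns the index of the candidate with the highest contrastive score.
     *
     * @param probs model probability of each top-k candidate
     * @param candidateHidden hidden state of each candidate, one row per candidate
     * @param contextHidden hidden states of the tokens generated so far
     * @param alpha weight of the degeneration penalty, in [0, 1]
     */
    static int select(
            double[] probs, double[][] candidateHidden, double[][] contextHidden, double alpha) {
        int best = -1;
        double bestScore = Double.NEGATIVE_INFINITY;
        for (int c = 0; c < probs.length; c++) {
            // Degeneration penalty: similarity to the closest context token.
            // With an empty context there is nothing to penalize.
            double penalty = contextHidden.length == 0 ? 0.0 : Double.NEGATIVE_INFINITY;
            for (double[] h : contextHidden) {
                penalty = Math.max(penalty, cosine(candidateHidden[c], h));
            }
            double score = (1.0 - alpha) * probs[c] - alpha * penalty;
            if (score > bestScore) {
                bestScore = score;
                best = c;
            }
        }
        return best;
    }
}
```

For example, take two candidates with probabilities 0.6 and 0.4, where the first candidate's hidden state equals the only context state (similarity 1) and the second is orthogonal to it (similarity 0). With alpha = 0.6 the scores are -0.36 and 0.16, so select returns index 1. With alpha = 0 the penalty vanishes and the rule reduces to greedy selection, returning index 0.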
Constructor Summary
Constructors
ContrastiveSeqBatchScheduler(Predictor<NDList, CausalLMOutput> lmBlock, SearchConfig config)
    Constructs a new ContrastiveSeqBatchScheduler instance.
Method Summary
Methods
inferenceCall()
    An inference call in an iteration.
initForward(NDArray inputIds, NDArray batchUids)
    Initializes the iteration and SeqBatcher.
Methods inherited from class ai.djl.modality.nlp.generate.SeqBatchScheduler
addRequest, collectResults, incrementForward
Constructor Details
ContrastiveSeqBatchScheduler
Constructs a new ContrastiveSeqBatchScheduler instance.
Parameters:
lmBlock - the predictor containing the language model
config - the autoregressive search configuration
Method Details
initForward
Initializes the iteration and SeqBatcher.
Specified by:
initForward in class SeqBatchScheduler
Parameters:
inputIds - the input token ids
batchUids - the request uid identifying a sequence
Returns:
a SeqBatcher, which stores the search state and operates on the BatchTensorList
Throws:
TranslateException - if forward fails
inferenceCall
An inference call in an iteration.
Specified by:
inferenceCall in class SeqBatchScheduler
Returns:
the output token ids
Throws:
TranslateException - if forward fails