Package ai.djl.modality.nlp.generate
Class ContrastiveSeqBatchScheduler
java.lang.Object
  ai.djl.modality.nlp.generate.SeqBatchScheduler
    ai.djl.modality.nlp.generate.ContrastiveSeqBatchScheduler
ContrastiveSeqBatchScheduler is a SeqBatchScheduler subclass that implements the contrastive search algorithm.
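As usually formulated, contrastive search keeps the k most probable next tokens and, for each candidate v, computes score(v) = (1 - alpha) * p(v | context) - alpha * max_j cos(h_v, h_j), where h_v is the candidate's hidden state and h_j ranges over the hidden states of the tokens already in the sequence. The highest-scoring candidate is appended. The sketch below shows only that selection step on plain Java arrays. It is a hypothetical illustration, not DJL's code: the class and method names are invented, and how ContrastiveSeqBatchScheduler obtains hidden states, or which SearchConfig settings supply k and alpha, is not covered here.

```java
/**
 * Toy, stdlib-only sketch of one contrastive search selection step.
 * Not DJL's implementation: names and signatures here are hypothetical.
 */
public final class ContrastiveStepSketch {

    private ContrastiveStepSketch() {}

    /** Cosine similarity of two equal-length vectors (0 if either is all zeros). */
    static double cosine(double[] a, double[] b) {
        double dot = 0.0;
        double normA = 0.0;
        double normB = 0.0;
        for (int i = 0; i < a.length; i++) {
            dot += a[i] * b[i];
            normA += a[i] * a[i];
            normB += b[i] * b[i];
        }
        double denom = Math.sqrt(normA) * Math.sqrt(normB);
        return denom == 0.0 ? 0.0 : dot / denom;
    }

    /**
     * Picks the winning candidate among the top-k.
     *
     * @param probs model probability of each of the k candidate tokens
     * @param candidateHidden hidden state of each candidate, shape [k][d]
     * @param contextHidden hidden states of tokens already in the sequence, shape [t][d]
     * @param alpha weight of the degeneration penalty, in [0, 1]
     * @return index into probs of the selected candidate
     */
    static int selectNext(double[] probs, double[][] candidateHidden,
                          double[][] contextHidden, double alpha) {
        int best = -1;
        double bestScore = Double.NEGATIVE_INFINITY;
        for (int v = 0; v < probs.length; v++) {
            // Degeneration penalty: highest similarity to any earlier token (0 if none).
            double maxSim = contextHidden.length == 0 ? 0.0 : Double.NEGATIVE_INFINITY;
            for (double[] h : contextHidden) {
                maxSim = Math.max(maxSim, cosine(candidateHidden[v], h));
            }
            double score = (1.0 - alpha) * probs[v] - alpha * maxSim;
            if (score > bestScore) {
                bestScore = score;
                best = v;
            }
        }
        return best;
    }

    public static void main(String[] args) {
        double[] probs = {0.6, 0.3, 0.1};
        double[][] candidates = {{1, 0}, {0, 1}, {1, 1}};
        double[][] context = {{1, 0}};
        // alpha = 0 reduces to a greedy choice by probability.
        System.out.println(selectNext(probs, candidates, context, 0.0)); // prints 0
        // A larger alpha penalizes candidate 0, which repeats the context direction.
        System.out.println(selectNext(probs, candidates, context, 0.6)); // prints 1
    }
}
```

With alpha = 0 the rule is plain greedy decoding; raising alpha trades model confidence for dissimilarity to the existing context, which is what discourages repetitive output.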
Constructor Summary
Constructors
ContrastiveSeqBatchScheduler(Predictor<NDList, CausalLMOutput> lmBlock, SearchConfig config)
  Constructs a new ContrastiveSeqBatchScheduler instance.
Method Summary
Methods
inferenceCall()
  An inference call in an iteration.
initForward(NDArray inputIds, NDArray batchUids)
  Initializes the iteration and SeqBatcher.
Methods inherited from class ai.djl.modality.nlp.generate.SeqBatchScheduler
  addRequest, collectResults, incrementForward
Constructor Details
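ContrastiveSeqBatchScheduler
ContrastiveSeqBatchScheduler(Predictor<NDList, CausalLMOutput> lmBlock, SearchConfig config)
Constructs a new ContrastiveSeqBatchScheduler instance.
- Parameters:
lmBlock - the predictor containing the language model
config - the autoregressive search configuration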
Method Details
initForward
initForward(NDArray inputIds, NDArray batchUids)
Initializes the iteration and SeqBatcher.
- Specified by:
initForward in class SeqBatchScheduler
- Parameters:
inputIds - the input token ids
batchUids - the request uids, each identifying a sequence
- Returns:
the SeqBatcher, which stores the search state and operates on the BatchTensorList
- Throws:
TranslateException - if the forward pass fails
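What initializing the SeqBatcher involves internally is not spelled out above. The sketch below is a hypothetical, stdlib-only picture of the bookkeeping side only. It assumes, as is common for batched causal-LM generation, that shorter prompts in inputIds are left-padded with a pad token id, and it records each row's real tokens under the matching uid from batchUids, together with the number of leading pad tokens per row. ToyBatchState, offsets, and padTokenId are invented names. The real method also runs a model forward pass (hence the TranslateException), which is omitted here.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;

/**
 * Toy illustration of per-request search state built from a left-padded
 * batch. Hypothetical names; this is not DJL's SeqBatcher.
 */
public final class ToyBatchState {

    /** uid -> real (unpadded) tokens of that request, in batch-row order. */
    final Map<Long, List<Long>> tokensByUid = new LinkedHashMap<>();

    ToyBatchState(long[][] inputIds, long[] batchUids, long padTokenId) {
        if (inputIds.length != batchUids.length) {
            throw new IllegalArgumentException("need exactly one uid per batch row");
        }
        int[] off = offsets(inputIds, padTokenId);
        for (int row = 0; row < inputIds.length; row++) {
            List<Long> tokens = new ArrayList<>();
            // Drop only the leading (left) padding of this row.
            for (int t = off[row]; t < inputIds[row].length; t++) {
                tokens.add(inputIds[row][t]);
            }
            tokensByUid.put(batchUids[row], tokens);
        }
    }

    /** Number of leading pad tokens in each row, i.e. where the real input starts. */
    static int[] offsets(long[][] inputIds, long padTokenId) {
        int[] off = new int[inputIds.length];
        for (int row = 0; row < inputIds.length; row++) {
            int n = 0;
            while (n < inputIds[row].length && inputIds[row][n] == padTokenId) {
                n++;
            }
            off[row] = n;
        }
        return off;
    }

    public static void main(String[] args) {
        long pad = 0L;
        long[][] inputIds = {{0, 0, 5, 6}, {7, 8, 9, 10}};
        long[] uids = {42L, 43L};
        ToyBatchState state = new ToyBatchState(inputIds, uids, pad);
        System.out.println(state.tokensByUid); // prints {42=[5, 6], 43=[7, 8, 9, 10]}
        System.out.println(Arrays.toString(offsets(inputIds, pad))); // prints [2, 0]
    }
}
```

Only leading pads are stripped, so a token equal to the pad id that appears later in a prompt is kept as real input.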
inferenceCall
inferenceCall()
Performs one inference call in an iteration.
- Specified by:
inferenceCall in class SeqBatchScheduler
- Returns:
the output token ids
- Throws:
TranslateException - if the forward pass fails
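Each inferenceCall produces the next token ids for the sequences currently in the batch, and the inherited addRequest, incrementForward, and collectResults methods suggest that requests join and leave between iterations. The toy loop below illustrates that iteration-level pattern under loud assumptions: a stand-in function replaces the language model, plain greedy choice replaces contrastive search, and all names (ToyIterationLoop, admit, step, eosTokenId, maxLength) are hypothetical. It does not reproduce DJL's behavior.

```java
import java.util.ArrayList;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.function.LongUnaryOperator;

/**
 * Toy iteration-level decoding loop. Hypothetical names; the "model" is a
 * stand-in function that maps the last token id to the next one.
 */
public final class ToyIterationLoop {

    final Map<Long, List<Long>> active = new LinkedHashMap<>();
    final Map<Long, List<Long>> finished = new LinkedHashMap<>();
    private final LongUnaryOperator model;
    private final long eosTokenId;
    private final int maxLength;

    ToyIterationLoop(LongUnaryOperator model, long eosTokenId, int maxLength) {
        this.model = model;
        this.eosTokenId = eosTokenId;
        this.maxLength = maxLength;
    }

    /** Adds a new sequence, identified by uid, to the running batch. */
    void admit(long uid, List<Long> prompt) {
        active.put(uid, new ArrayList<>(prompt));
    }

    /** One iteration: appends and returns one new token id per active sequence. */
    long[] step() {
        long[] next = new long[active.size()];
        int i = 0;
        for (List<Long> tokens : active.values()) {
            next[i] = model.applyAsLong(tokens.get(tokens.size() - 1));
            tokens.add(next[i]);
            i++;
        }
        // Retire sequences that emitted EOS or reached the length limit.
        active.entrySet().removeIf(e -> {
            List<Long> tokens = e.getValue();
            boolean done = tokens.get(tokens.size() - 1) == eosTokenId
                    || tokens.size() >= maxLength;
            if (done) {
                finished.put(e.getKey(), tokens);
            }
            return done;
        });
        return next;
    }

    public static void main(String[] args) {
        // Stand-in model: next token = last token + 1; token 4 acts as EOS.
        ToyIterationLoop loop = new ToyIterationLoop(id -> id + 1, 4L, 10);
        loop.admit(1L, List.of(1L));
        loop.admit(2L, List.of(3L));
        while (!loop.active.isEmpty()) {
            loop.step();
        }
        System.out.println(loop.finished); // prints {2=[3, 4], 1=[1, 2, 3, 4]}
    }
}
```

Because finished sequences leave the batch after each step, a new request could be admitted in their place without waiting for the longest sequence to complete.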