Package ai.djl.modality.nlp.generate
Class ContrastiveSeqBatchScheduler
- java.lang.Object
  - ai.djl.modality.nlp.generate.SeqBatchScheduler
    - ai.djl.modality.nlp.generate.ContrastiveSeqBatchScheduler
public class ContrastiveSeqBatchScheduler extends SeqBatchScheduler
ContrastiveSeqBatchScheduler is a SeqBatchScheduler that implements the contrastive search algorithm.
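Contrastive search (Su et al., 2022) picks each next token from the top-k candidates by trading model confidence against a degeneration penalty: score(v) = (1 − α)·p(v | x) − α·max_j cos(h_v, h_j), where h_v is the hidden state after appending candidate v and h_j are the hidden states of the tokens already in the sequence. The stdlib-only sketch below shows one selection step. It is not DJL's implementation: the class and method names (ContrastiveStepSketch, selectToken) and the plain double[] inputs are hypothetical, and the idea that k and α come from the SearchConfig passed to the constructor is an assumption.

```java
// Hypothetical sketch of one contrastive search step; NOT DJL's implementation.
class ContrastiveStepSketch {

    // Cosine similarity of two equal-length vectors (0 if either is all zeros).
    static double cosine(double[] a, double[] b) {
        double dot = 0, normA = 0, normB = 0;
        for (int i = 0; i < a.length; i++) {
            dot += a[i] * b[i];
            normA += a[i] * a[i];
            normB += b[i] * b[i];
        }
        if (normA == 0 || normB == 0) {
            return 0;
        }
        return dot / (Math.sqrt(normA) * Math.sqrt(normB));
    }

    /**
     * Chooses the next token from the top-k candidates.
     *
     * candidateIds[c]    token id of candidate c
     * candidateProbs[c]  model probability p(v_c | x)
     * candidateHidden[c] hidden state obtained after appending candidate c
     * contextHidden[j]   hidden state of the j-th token already in the sequence
     * alpha              degeneration-penalty weight in [0, 1]; 0 reduces to greedy
     */
    static int selectToken(int[] candidateIds, double[] candidateProbs,
                           double[][] candidateHidden, double[][] contextHidden,
                           double alpha) {
        int bestId = -1;
        double bestScore = Double.NEGATIVE_INFINITY;
        for (int c = 0; c < candidateIds.length; c++) {
            // Degeneration penalty: highest similarity to any context token.
            double maxSim = contextHidden.length == 0 ? 0 : Double.NEGATIVE_INFINITY;
            for (double[] h : contextHidden) {
                maxSim = Math.max(maxSim, cosine(candidateHidden[c], h));
            }
            double score = (1 - alpha) * candidateProbs[c] - alpha * maxSim;
            if (score > bestScore) {
                bestScore = score;
                bestId = candidateIds[c];
            }
        }
        return bestId;
    }

    public static void main(String[] args) {
        double[][] context = {{1.0, 0.0}};
        int[] ids = {7, 3};
        double[] probs = {0.6, 0.4};
        // Candidate 7 repeats the context direction; candidate 3 is orthogonal to it.
        double[][] hidden = {{1.0, 0.0}, {0.0, 1.0}};
        System.out.println(selectToken(ids, probs, hidden, context, 0.0)); // 7
        System.out.println(selectToken(ids, probs, hidden, context, 0.6)); // 3
    }
}
```

With α = 0 the step picks the most probable candidate (7). With α = 0.6 the repetitive candidate scores 0.4·0.6 − 0.6·1 = −0.36 against 0.4·0.4 − 0 = 0.16 for the orthogonal one, so token 3 wins.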
Constructor Summary
Constructor | Description
ContrastiveSeqBatchScheduler(Predictor<NDList,CausalLMOutput> lmBlock, SearchConfig config) | Constructs a new ContrastiveSeqBatchScheduler instance.
Method Summary
Modifier and Type | Method | Description
NDArray | inferenceCall() |
SeqBatcher | initForward(NDArray inputIds, NDArray batchUids) | Initializes the iteration and SeqBatcher.
Methods inherited from class ai.djl.modality.nlp.generate.SeqBatchScheduler
addRequest, collectResults, incrementForward
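The inherited methods addRequest, incrementForward and collectResults suggest a request lifecycle in which new sequences join a running batch, the batch is advanced step by step, and finished sequences are collected per request uid. The toy sketch below illustrates that sequence-level batching pattern with the standard library only. It is hypothetical: ToySeqScheduler, its signatures, the boolean returned by incrementForward, and the stand-in "model" (next token = last token + 1) are assumptions for illustration, not DJL's API. In ContrastiveSeqBatchScheduler the token-selection step would presumably be contrastive search rather than this stand-in.

```java
import java.util.ArrayList;
import java.util.HashMap;
import java.util.Iterator;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;

// Hypothetical toy scheduler; NOT DJL's SeqBatchScheduler API.
class ToySeqScheduler {
    private final int eosToken;
    private final int maxLength;
    // In-flight sequences keyed by request uid (insertion order kept).
    private final Map<Long, List<Integer>> active = new LinkedHashMap<>();
    // Finished sequences waiting to be collected.
    private final Map<Long, List<Integer>> finished = new HashMap<>();

    ToySeqScheduler(int eosToken, int maxLength) {
        this.eosToken = eosToken;
        this.maxLength = maxLength;
    }

    // Adds a new request to the running batch, even mid-generation.
    void addRequest(long uid, List<Integer> prompt) {
        active.put(uid, new ArrayList<>(prompt));
    }

    // Stand-in for the language model: next token = last token + 1.
    private int nextToken(List<Integer> seq) {
        return seq.get(seq.size() - 1) + 1;
    }

    // Advances every active sequence by up to `steps` tokens.
    // Returns true when no sequences remain in flight (assumed semantics).
    boolean incrementForward(int steps) {
        for (int s = 0; s < steps && !active.isEmpty(); s++) {
            Iterator<Map.Entry<Long, List<Integer>>> it = active.entrySet().iterator();
            while (it.hasNext()) {
                Map.Entry<Long, List<Integer>> entry = it.next();
                List<Integer> seq = entry.getValue();
                int token = nextToken(seq);
                seq.add(token);
                // Retire the sequence on EOS or when it hits the length cap.
                if (token == eosToken || seq.size() >= maxLength) {
                    finished.put(entry.getKey(), seq);
                    it.remove();
                }
            }
        }
        return active.isEmpty();
    }

    // Hands back finished sequences and forgets them.
    Map<Long, List<Integer>> collectResults() {
        Map<Long, List<Integer>> out = new HashMap<>(finished);
        finished.clear();
        return out;
    }

    public static void main(String[] args) {
        ToySeqScheduler scheduler = new ToySeqScheduler(5, 10);
        scheduler.addRequest(1L, List.of(3));
        scheduler.addRequest(2L, List.of(1));
        scheduler.incrementForward(2);
        System.out.println(scheduler.collectResults()); // {1=[3, 4, 5]}
        scheduler.addRequest(3L, List.of(4)); // joins while request 2 is in flight
        scheduler.incrementForward(2);
        System.out.println(scheduler.collectResults());
    }
}
```

Request 3 joins while request 2 is still being generated, which is the benefit of scheduling at the sequence level instead of waiting for an entire batch to finish.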
Constructor Detail
ContrastiveSeqBatchScheduler
public ContrastiveSeqBatchScheduler(Predictor<NDList,CausalLMOutput> lmBlock, SearchConfig config)
Constructs a new ContrastiveSeqBatchScheduler instance.
- Parameters:
  lmBlock - the predictor containing the language model
  config - the autoregressive search configuration
Method Detail
initForward
public SeqBatcher initForward(NDArray inputIds, NDArray batchUids) throws TranslateException
Initializes the iteration and the SeqBatcher.
- Specified by:
  initForward in class SeqBatchScheduler
- Parameters:
  inputIds - the input token ids
  batchUids - the request uids, each identifying a sequence
- Returns:
  the SeqBatcher, which stores the search state and operates on the BatchTensorList
- Throws:
  TranslateException - if the forward pass fails
inferenceCall
public NDArray inferenceCall() throws TranslateException
- Throws:
  TranslateException