Package ai.djl.modality.nlp.generate
Class ContrastiveSeqBatchScheduler
- java.lang.Object
  - ai.djl.modality.nlp.generate.SeqBatchScheduler
    - ai.djl.modality.nlp.generate.ContrastiveSeqBatchScheduler
public class ContrastiveSeqBatchScheduler extends SeqBatchScheduler
ContrastiveSeqBatchScheduler is a SeqBatchScheduler that implements the contrastive search algorithm.
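This page does not spell out the selection rule. In the usual formulation of contrastive search (Su et al., 2022), each step keeps the k most probable next tokens and picks the candidate v that maximizes (1 - alpha) * p(v) - alpha * max_j cos(h_v, h_j). The second term is a degeneration penalty: it is large when the candidate's hidden state h_v resembles the hidden state h_j of some earlier token. Below is a minimal plain-Java sketch of one such step. It assumes the probabilities and hidden states have already been computed. The names ContrastiveStep, selectNext, k and alpha are hypothetical. The sketch does not use DJL and is not the code inside ContrastiveSeqBatchScheduler.

```java
import java.util.Arrays;

// Hypothetical, self-contained sketch of one contrastive search step.
// It does not use DJL and is not ContrastiveSeqBatchScheduler's implementation.
public class ContrastiveStep {

    // Cosine similarity between two equal-length vectors.
    public static double cosine(double[] a, double[] b) {
        double dot = 0, na = 0, nb = 0;
        for (int i = 0; i < a.length; i++) {
            dot += a[i] * b[i];
            na += a[i] * a[i];
            nb += b[i] * b[i];
        }
        return dot / (Math.sqrt(na) * Math.sqrt(nb));
    }

    // Chooses among the top-k tokens by (1 - alpha) * p(v) - alpha * max_j cos(h_v, h_j).
    // Assumption: candHidden[v] is the hidden state the model would produce if token v
    // were appended; a real model computes it with a forward pass per candidate.
    public static int selectNext(double[] probs, double[][] candHidden,
                                 double[][] contextHidden, int k, double alpha) {
        Integer[] order = new Integer[probs.length];
        for (int i = 0; i < order.length; i++) {
            order[i] = i;
        }
        // Sort token indices by descending probability.
        Arrays.sort(order, (x, y) -> Double.compare(probs[y], probs[x]));

        int best = -1;
        double bestScore = Double.NEGATIVE_INFINITY;
        for (int c = 0; c < Math.min(k, order.length); c++) {
            int v = order[c];
            // Degeneration penalty: highest similarity to any earlier token (0 if none).
            double maxSim = contextHidden.length == 0 ? 0.0 : Double.NEGATIVE_INFINITY;
            for (double[] h : contextHidden) {
                maxSim = Math.max(maxSim, cosine(candHidden[v], h));
            }
            double score = (1 - alpha) * probs[v] - alpha * maxSim;
            if (score > bestScore) {
                bestScore = score;
                best = v;
            }
        }
        return best;
    }

    public static void main(String[] args) {
        double[] probs = {0.5, 0.3, 0.2};
        double[][] candHidden = {{1, 0}, {0, 1}, {1, 1}};
        double[][] context = {{1, 0}};
        System.out.println(selectNext(probs, candHidden, context, 2, 0.0)); // 0
        System.out.println(selectNext(probs, candHidden, context, 2, 0.6)); // 1
    }
}
```

With alpha = 0 the rule reduces to greedy decoding over the top k, so token 0 wins. With alpha = 0.6, token 0 is penalized for pointing in the same direction as the context (score 0.2 - 0.6 = -0.4). Token 1 is orthogonal to the context (score 0.12), so it is chosen instead.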
Constructor Summary
- ContrastiveSeqBatchScheduler(Predictor<NDList,CausalLMOutput> lmBlock, SearchConfig config): Constructs a new ContrastiveSeqBatchScheduler instance.
Method Summary
- NDArray inferenceCall()
- SeqBatcher initForward(NDArray inputIds, NDArray batchUids): Initializes the iteration and the SeqBatcher.
Methods inherited from class ai.djl.modality.nlp.generate.SeqBatchScheduler
addRequest, collectResults, incrementForward
Constructor Detail
ContrastiveSeqBatchScheduler
public ContrastiveSeqBatchScheduler(Predictor<NDList,CausalLMOutput> lmBlock, SearchConfig config)
Constructs a new ContrastiveSeqBatchScheduler instance.
Parameters:
- lmBlock: the predictor containing the language model
- config: the autoregressive search configuration
Method Detail
initForward
public SeqBatcher initForward(NDArray inputIds, NDArray batchUids) throws TranslateException
Initializes the iteration and the SeqBatcher.
Specified by:
initForward in class SeqBatchScheduler
Parameters:
- inputIds: the input token ids
- batchUids: the request uids, each identifying one sequence in the batch
Returns:
- a SeqBatcher that stores the search state and operates on the BatchTensorList
Throws:
- TranslateException: if the forward pass fails
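The batchUids parameter ties each row of inputIds to a request. The toy class below is a hypothetical illustration of that bookkeeping only. It keeps token lists keyed by uid and appends one generated token per active row at each step. Rows that emit an end-of-sequence id are removed and returned. The names ToyBatchState, step and eosId are assumptions made for illustration. DJL's SeqBatcher works on NDArrays and a BatchTensorList, not on Java lists.

```java
import java.util.ArrayList;
import java.util.Iterator;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;

// Hypothetical toy model of per-request search state keyed by uid; not DJL's SeqBatcher.
public class ToyBatchState {
    // Insertion order doubles as batch row order.
    private final Map<Long, List<Long>> tokensByUid = new LinkedHashMap<>();

    // Loosely analogous to initForward: one row of input ids per request uid.
    public ToyBatchState(long[][] inputIds, long[] batchUids) {
        if (inputIds.length != batchUids.length) {
            throw new IllegalArgumentException("exactly one uid per input row is required");
        }
        for (int i = 0; i < inputIds.length; i++) {
            List<Long> row = new ArrayList<>();
            for (long id : inputIds[i]) {
                row.add(id);
            }
            tokensByUid.put(batchUids[i], row);
        }
    }

    // Appends nextTokens[i] to the i-th active row; rows that emit eosId are
    // removed from the batch and returned, keyed by their uid.
    public Map<Long, List<Long>> step(long[] nextTokens, long eosId) {
        if (nextTokens.length != tokensByUid.size()) {
            throw new IllegalArgumentException("one token per active sequence is required");
        }
        Map<Long, List<Long>> finished = new LinkedHashMap<>();
        Iterator<Map.Entry<Long, List<Long>>> it = tokensByUid.entrySet().iterator();
        int i = 0;
        while (it.hasNext()) {
            Map.Entry<Long, List<Long>> entry = it.next();
            long token = nextTokens[i++];
            entry.getValue().add(token);
            if (token == eosId) {
                finished.put(entry.getKey(), entry.getValue());
                it.remove();
            }
        }
        return finished;
    }

    // Number of sequences still being generated.
    public int size() {
        return tokensByUid.size();
    }
}
```

For example, two requests with uids 10 and 20 and inputs {1, 2} and {3, 4} could be stepped with next tokens {5, 99} and eosId 99. Request 20 would then finish as [3, 4, 99], leaving one active sequence.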
inferenceCall
public NDArray inferenceCall() throws TranslateException
Throws:
- TranslateException