Class OrtGptTranslator
- java.lang.Object
  - ai.djl.onnxruntime.zoo.nlp.textgeneration.OrtGptTranslator
- All Implemented Interfaces:
  ai.djl.translate.NoBatchifyTranslator<ai.djl.ndarray.NDList,ai.djl.modality.nlp.generate.CausalLMOutput>,
  ai.djl.translate.PostProcessor<ai.djl.modality.nlp.generate.CausalLMOutput>,
  ai.djl.translate.PreProcessor<ai.djl.ndarray.NDList>,
  ai.djl.translate.Translator<ai.djl.ndarray.NDList,ai.djl.modality.nlp.generate.CausalLMOutput>
public class OrtGptTranslator extends java.lang.Object implements ai.djl.translate.NoBatchifyTranslator<ai.djl.ndarray.NDList,ai.djl.modality.nlp.generate.CausalLMOutput>
The Translator for a GPT-2 model running on ONNX Runtime.
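A Predictor applies a Translator in three steps: processInput converts the caller's input into an NDList, the model runs on that NDList, and processOutput converts the model's NDList into the result type, here CausalLMOutput. Because OrtGptTranslator is a NoBatchifyTranslator, no batchifier is applied between those steps. The plain-Java sketch below only mimics that flow; SimpleTranslator, predict, and the toy model are hypothetical stand-ins, not DJL classes, and lists of float arrays play the role of NDList.

```java
import java.util.List;
import java.util.function.Function;

final class TranslatorPipeline {

    // Hypothetical stand-in for DJL's PreProcessor/PostProcessor pair.
    // Lists of float[] play the role of NDList; this is NOT the DJL API.
    interface SimpleTranslator<I, O> {
        List<float[]> processInput(I input);

        O processOutput(List<float[]> output);
    }

    // Runs pre-processing, then the model, then post-processing, in that order.
    static <I, O> O predict(
            SimpleTranslator<I, O> translator,
            Function<List<float[]>, List<float[]>> model,
            I input) {
        List<float[]> modelInput = translator.processInput(input);
        List<float[]> modelOutput = model.apply(modelInput);
        return translator.processOutput(modelOutput);
    }

    public static void main(String[] args) {
        // Toy translator: wraps an int in a one-element "tensor" and
        // reads the first value of the first output back as a float.
        SimpleTranslator<Integer, Float> toy =
                new SimpleTranslator<Integer, Float>() {
                    @Override
                    public List<float[]> processInput(Integer input) {
                        return List.of(new float[] {input});
                    }

                    @Override
                    public Float processOutput(List<float[]> output) {
                        return output.get(0)[0];
                    }
                };

        // Toy "model": doubles every element of the first input tensor.
        Function<List<float[]>, List<float[]>> doubler =
                in -> {
                    float[] x = in.get(0).clone();
                    for (int i = 0; i < x.length; i++) {
                        x[i] *= 2;
                    }
                    return List.of(x);
                };

        System.out.println(predict(toy, doubler, 21)); // prints 42.0
    }
}
```

Keeping the model separate from the translator is the point of the design: the same model can be reused with different pre- and post-processing.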
Constructor Summary
- OrtGptTranslator(long kvDim, int numAttentionHeads, int numLayers)
  Constructs a new instance of OrtGptTranslator.
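The three constructor arguments describe the model's attention geometry, which determines the size of the key/value cache passed between generation steps. The sketch below assumes, as in GPT-style decoders, that kvDim is the per-head key/value dimension and that the cache holds one key tensor and one value tensor per layer, each shaped [batch, numAttentionHeads, pastLength, kvDim]. That layout and the helper KvCacheShapes are assumptions for illustration, not documented behavior of OrtGptTranslator. For GPT-2 small (hidden size 768, 12 heads, 12 layers), the per-head dimension is 768 / 12 = 64, so under this assumption the matching call would be new OrtGptTranslator(64, 12, 12).

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

final class KvCacheShapes {

    // Hypothetical helper: one key shape and one value shape per layer, each
    // laid out as [batch, numAttentionHeads, pastLength, kvDim].
    // This layout is an assumption, not taken from the DJL source.
    static List<long[]> shapes(long batch, long pastLength,
                               long kvDim, int numAttentionHeads, int numLayers) {
        List<long[]> result = new ArrayList<>(2 * numLayers);
        for (int layer = 0; layer < numLayers; layer++) {
            long[] shape = {batch, numAttentionHeads, pastLength, kvDim};
            result.add(shape.clone()); // key tensor for this layer
            result.add(shape.clone()); // value tensor for this layer
        }
        return result;
    }

    // Total number of elements across all cached tensors.
    static long totalElements(List<long[]> shapes) {
        long total = 0;
        for (long[] s : shapes) {
            total += Arrays.stream(s).reduce(1L, (a, b) -> a * b);
        }
        return total;
    }

    public static void main(String[] args) {
        // GPT-2 small: kvDim = 768 / 12 = 64, 12 heads, 12 layers, 10 cached tokens.
        List<long[]> cache = shapes(1, 10, 64, 12, 12);
        System.out.println(cache.size());                  // 24
        System.out.println(Arrays.toString(cache.get(0))); // [1, 12, 10, 64]
        System.out.println(totalElements(cache));          // 184320
    }
}
```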
Method Summary
- ai.djl.ndarray.NDList processInput(ai.djl.translate.TranslatorContext ctx, ai.djl.ndarray.NDList input)
- ai.djl.modality.nlp.generate.CausalLMOutput processOutput(ai.djl.translate.TranslatorContext ctx, ai.djl.ndarray.NDList output)
Method Detail
processInput
public ai.djl.ndarray.NDList processInput(ai.djl.translate.TranslatorContext ctx, ai.djl.ndarray.NDList input) throws java.lang.Exception
- Specified by:
  processInput in interface ai.djl.translate.PreProcessor<ai.djl.ndarray.NDList>
- Throws:
  java.lang.Exception
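processInput prepares the NDList that is handed to the ONNX Runtime session, and ONNX Runtime sessions identify their inputs by name. The sketch below lists the ordered input names a cached GPT-2 export might expect. The naming scheme (input_ids, attention_mask, then past_key_values.<layer>.key and past_key_values.<layer>.value) is an assumption based on a common export convention, and GptInputNames is a hypothetical helper; the names in a particular model file may differ.

```java
import java.util.ArrayList;
import java.util.List;

final class GptInputNames {

    // Hypothetical helper: ordered input names for a GPT-2 ONNX export with a
    // key/value cache. The naming convention is an assumption and may not
    // match every exported model.
    static List<String> inputNames(int numLayers) {
        List<String> names = new ArrayList<>(2 + 2 * numLayers);
        names.add("input_ids");
        names.add("attention_mask");
        for (int layer = 0; layer < numLayers; layer++) {
            names.add("past_key_values." + layer + ".key");
            names.add("past_key_values." + layer + ".value");
        }
        return names;
    }

    public static void main(String[] args) {
        List<String> names = inputNames(12);
        System.out.println(names.size());                // 26
        System.out.println(names.get(2));                // past_key_values.0.key
        System.out.println(names.get(names.size() - 1)); // past_key_values.11.value
    }
}
```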
processOutput
public ai.djl.modality.nlp.generate.CausalLMOutput processOutput(ai.djl.translate.TranslatorContext ctx, ai.djl.ndarray.NDList output) throws java.lang.Exception
- Specified by:
  processOutput in interface ai.djl.translate.PostProcessor<ai.djl.modality.nlp.generate.CausalLMOutput>
- Throws:
  java.lang.Exception
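processOutput turns the raw model outputs into a CausalLMOutput, which is expected to carry the next-token logits together with the updated key/value cache for the following step. The sketch below assumes the outputs are ordered as the logits followed by one key and one value per layer, and splits them accordingly; any trailing outputs are ignored here. CausalOutputSplit and Split are hypothetical stand-ins, not the DJL classes, and strings stand in for tensors in the example.

```java
import java.util.List;

final class CausalOutputSplit {

    // Hypothetical container mirroring the idea of CausalLMOutput:
    // the logits plus the updated key/value cache. Not the DJL class.
    static final class Split<T> {
        final T logits;
        final List<T> pastKeyValues;

        Split(T logits, List<T> pastKeyValues) {
            this.logits = logits;
            this.pastKeyValues = pastKeyValues;
        }
    }

    // Assumes outputs are ordered as
    // [logits, layer0.key, layer0.value, ..., layerN-1.key, layerN-1.value, ...].
    static <T> Split<T> split(List<T> outputs, int numLayers) {
        int expected = 1 + 2 * numLayers;
        if (outputs.size() < expected) {
            throw new IllegalArgumentException(
                    "expected at least " + expected + " outputs, got " + outputs.size());
        }
        // Copy the cache slice so the result does not depend on the source list.
        return new Split<>(outputs.get(0), List.copyOf(outputs.subList(1, expected)));
    }

    public static void main(String[] args) {
        List<String> outputs = List.of("logits", "key0", "value0", "key1", "value1");
        Split<String> s = split(outputs, 2);
        System.out.println(s.logits);               // logits
        System.out.println(s.pastKeyValues.size()); // 4
    }
}
```

Failing fast on a short output list makes a mismatch between numLayers and the exported model visible immediately instead of surfacing later as a shape error.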