Class PtGptTranslator

  • All Implemented Interfaces:
    ai.djl.translate.NoBatchifyTranslator<ai.djl.ndarray.NDList, ai.djl.modality.nlp.generate.CausalLMOutput>, ai.djl.translate.PostProcessor<ai.djl.modality.nlp.generate.CausalLMOutput>, ai.djl.translate.PreProcessor<ai.djl.ndarray.NDList>, ai.djl.translate.Translator<ai.djl.ndarray.NDList, ai.djl.modality.nlp.generate.CausalLMOutput>

    public class PtGptTranslator
    extends java.lang.Object
    implements ai.djl.translate.NoBatchifyTranslator<ai.djl.ndarray.NDList, ai.djl.modality.nlp.generate.CausalLMOutput>
    The Translator for the PyTorch GPT-2 model.
    • Constructor Summary

      Constructors
      PtGptTranslator(long kvDim, int numAttentionHeads, int numLayers)
          Constructs a new instance of PtGptTranslator.
    • Method Summary

      Methods
      ai.djl.ndarray.NDList processInput(ai.djl.translate.TranslatorContext ctx, ai.djl.ndarray.NDList input)
      ai.djl.modality.nlp.generate.CausalLMOutput processOutput(ai.djl.translate.TranslatorContext ctx, ai.djl.ndarray.NDList output)
      • Methods inherited from class java.lang.Object

        clone, equals, finalize, getClass, hashCode, notify, notifyAll, toString, wait, wait, wait
      • Methods inherited from interface ai.djl.translate.NoBatchifyTranslator

        getBatchifier
      • Methods inherited from interface ai.djl.translate.Translator

        getExpansions, prepare, toBatchTranslator, toBatchTranslator
    • Constructor Detail

      • PtGptTranslator

        public PtGptTranslator(long kvDim,
                               int numAttentionHeads,
                               int numLayers)
        Constructs a new instance of PtGptTranslator.
        Parameters:
        kvDim - the key/value (kv) dimension
        numAttentionHeads - the number of attention heads
        numLayers - the number of layers
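        The three constructor arguments describe the model's key/value cache. As a hedged sketch, assuming kvDim is the per-head size of each key and value tensor (64 for GPT-2 small, whose 768 hidden units are split over 12 heads) and that the cache uses the usual GPT-2 layout of one key and one value tensor per layer, each shaped [batch, numAttentionHeads, pastLength, kvDim], the hypothetical helper below lists the tensors those three values imply. It uses plain Java arrays rather than DJL's NDArray and is not the library's implementation.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

/**
 * Hypothetical helper, not part of DJL. ASSUMPTIONS: kvDim is the per-head size of each
 * key/value tensor, and the cache holds one key and one value tensor per layer, each shaped
 * [batch, numAttentionHeads, pastLength, kvDim].
 */
class KvCacheShapes {

    /** Returns the cached tensor shapes in order: key_0, value_0, key_1, value_1, ... */
    static List<long[]> pastKeyValueShapes(
            long batchSize, long pastLength, long kvDim, int numAttentionHeads, int numLayers) {
        List<long[]> shapes = new ArrayList<>();
        for (int layer = 0; layer < numLayers; layer++) {
            long[] shape = {batchSize, numAttentionHeads, pastLength, kvDim};
            shapes.add(shape.clone()); // key tensor of this layer
            shapes.add(shape.clone()); // value tensor of this layer
        }
        return shapes;
    }

    public static void main(String[] args) {
        // GPT-2 small: 12 layers, 12 heads, 768 hidden units, so 768 / 12 = 64 per head.
        long kvDim = 64;
        int numAttentionHeads = 12;
        int numLayers = 12;
        List<long[]> shapes = pastKeyValueShapes(1, 5, kvDim, numAttentionHeads, numLayers);
        System.out.println(shapes.size() + " tensors, each " + Arrays.toString(shapes.get(0)));
    }
}
```

        For a single sequence with five cached tokens, main prints "24 tensors, each [1, 12, 5, 64]", that is 2 × numLayers tensors. Under the same assumption, these GPT-2 small values would be passed as new PtGptTranslator(64, 12, 12).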
    • Method Detail

      • processInput

        public ai.djl.ndarray.NDList processInput(ai.djl.translate.TranslatorContext ctx,
                                                  ai.djl.ndarray.NDList input)
                                           throws java.lang.Exception
        Specified by:
        processInput in interface ai.djl.translate.PreProcessor<ai.djl.ndarray.NDList>
        Throws:
        java.lang.Exception - if an error occurs while processing the input
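        This page does not document what processInput expects in input. For orientation only: in incremental GPT-2 decoding, a common convention is to pass just the newly generated token ids, an attention mask spanning both cached and new tokens (0 marks padding), and position ids that continue counting from the real tokens already in the cache. The hypothetical helper below derives such position ids from a 0/1 mask using plain Java arrays; whether PtGptTranslator consumes or produces position ids in this form is an assumption, not something stated here.

```java
import java.util.Arrays;

/** Hypothetical sketch, not DJL code: position ids for the newest tokens from a 0/1 mask. */
class PositionIds {

    /**
     * Returns position ids for the last {@code newTokens} entries of {@code attentionMask}.
     * Each real token (mask 1) gets the count of real tokens before it; padding (mask 0) gets 0.
     */
    static long[] forNewTokens(int[] attentionMask, int newTokens) {
        long[] all = new long[attentionMask.length];
        long seen = 0; // number of real tokens encountered so far
        for (int i = 0; i < attentionMask.length; i++) {
            all[i] = attentionMask[i] == 1 ? seen : 0;
            seen += attentionMask[i];
        }
        // Only the trailing (newly fed) positions are needed for this decoding step.
        return Arrays.copyOfRange(all, attentionMask.length - newTokens, attentionMask.length);
    }

    public static void main(String[] args) {
        // Two left-padding slots, four cached real tokens, then one new token.
        int[] mask = {0, 0, 1, 1, 1, 1, 1};
        System.out.println(Arrays.toString(forNewTokens(mask, 1))); // [4]
    }
}
```

        Deriving positions from the mask rather than from the raw array index keeps left-padded sequences in a batch aligned: padding slots do not advance the position counter.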
      • processOutput

        public ai.djl.modality.nlp.generate.CausalLMOutput processOutput(ai.djl.translate.TranslatorContext ctx,
                                                                         ai.djl.ndarray.NDList output)
                                                                  throws java.lang.Exception
        Specified by:
        processOutput in interface ai.djl.translate.PostProcessor<ai.djl.modality.nlp.generate.CausalLMOutput>
        Throws:
        java.lang.Exception - if an error occurs while processing the output
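        The output side can be sketched the same way. The layout of output is not documented on this page, so the hypothetical helper below assumes a flat order of logits, then a key and a value tensor for each of the numLayers layers, then optionally hidden states, and groups them the way a CausalLMOutput-like holder might. The type parameter T stands in for NDArray. It also shows a greedy argmax over the last position's logits, one common way (not necessarily the one DJL uses) to choose the next token.

```java
import java.util.ArrayList;
import java.util.List;

/**
 * Hypothetical sketch, not DJL code. ASSUMED output layout:
 * [logits, key_0, value_0, ..., key_{L-1}, value_{L-1}, (optional) hidden states].
 */
class OutputLayout {

    /** Groups the flat outputs; T stands in for a tensor type such as NDArray. */
    static final class Parts<T> {
        final T logits;
        final List<T> pastKeyValues;
        final T hiddenStates; // null when the model returns no hidden states

        Parts(T logits, List<T> pastKeyValues, T hiddenStates) {
            this.logits = logits;
            this.pastKeyValues = pastKeyValues;
            this.hiddenStates = hiddenStates;
        }
    }

    static <T> Parts<T> split(List<T> output, int numLayers) {
        int kvEnd = 1 + 2 * numLayers; // index just past the last cached key/value tensor
        if (output.size() < kvEnd) {
            throw new IllegalArgumentException("expected at least " + kvEnd + " outputs");
        }
        T hidden = output.size() > kvEnd ? output.get(kvEnd) : null;
        return new Parts<>(output.get(0), new ArrayList<>(output.subList(1, kvEnd)), hidden);
    }

    /** Greedy decoding step: index of the largest logit over the vocabulary. */
    static int argmax(float[] lastPositionLogits) {
        int best = 0;
        for (int i = 1; i < lastPositionLogits.length; i++) {
            if (lastPositionLogits[i] > lastPositionLogits[best]) {
                best = i;
            }
        }
        return best;
    }

    public static void main(String[] args) {
        // Strings stand in for tensors: 2 layers give 1 + 4 entries, plus hidden states.
        List<String> out = List.of("logits", "k0", "v0", "k1", "v1", "hidden");
        Parts<String> parts = split(out, 2);
        System.out.println(parts.pastKeyValues + " " + parts.hiddenStates);
        System.out.println(argmax(new float[] {0.1f, 2.5f, -1.0f}));
    }
}
```

        Keeping the key/value tensors as one list of 2 × numLayers entries mirrors how they would be fed back as the cache on the next step; the index arithmetic depends entirely on the assumed layout above.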