Class PtEngine

java.lang.Object
ai.djl.engine.Engine
ai.djl.pytorch.engine.PtEngine

public final class PtEngine extends ai.djl.engine.Engine
PtEngine is an implementation of Engine based on the PyTorch deep learning framework.

To get an instance of the PtEngine when it is not the default Engine, call Engine.getEngine(String) with the Engine name "PyTorch".
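The lookup described above can be sketched as follows. This is a minimal example, assuming the DJL `api` and `pytorch-engine` artifacts are on the classpath and the native PyTorch library can be loaded. The class name `PtEngineLookup` is hypothetical. It uses only `Engine.getEngine(String)` and the accessors listed under Method Details.

```java
import ai.djl.engine.Engine;
import ai.djl.engine.StandardCapabilities;

// Hypothetical helper class for illustration only.
public class PtEngineLookup {

    /** Returns the engine registered under the name "PyTorch" (a PtEngine). */
    public static Engine loadPyTorch() {
        // Looks the engine up by name, so this works even when PyTorch is not the default engine.
        return Engine.getEngine("PyTorch");
    }

    public static void main(String[] args) {
        Engine engine = loadPyTorch();
        System.out.println("Name:    " + engine.getEngineName());
        System.out.println("Version: " + engine.getVersion());
        System.out.println("Rank:    " + engine.getRank());
        // Capability names are plain strings; StandardCapabilities holds DJL's common ones.
        System.out.println("CUDA:    " + engine.hasCapability(StandardCapabilities.CUDA));
    }
}
```

The printed version and capabilities depend on the PyTorch native build that DJL loads at runtime. They are not fixed by this class.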


  • Method Details

    • getAlternativeEngine

      public ai.djl.engine.Engine getAlternativeEngine()
      Specified by:
      getAlternativeEngine in class ai.djl.engine.Engine
    • getEngineName

      public String getEngineName()
      Specified by:
      getEngineName in class ai.djl.engine.Engine
    • getRank

      public int getRank()
      Specified by:
      getRank in class ai.djl.engine.Engine
    • getVersion

      public String getVersion()
      Specified by:
      getVersion in class ai.djl.engine.Engine
    • hasCapability

      public boolean hasCapability(String capability)
      Specified by:
      hasCapability in class ai.djl.engine.Engine
    • newSymbolBlock

      public ai.djl.nn.SymbolBlock newSymbolBlock(ai.djl.ndarray.NDManager manager)
      Overrides:
      newSymbolBlock in class ai.djl.engine.Engine
    • newModel

      public ai.djl.Model newModel(String name, ai.djl.Device device)
      Specified by:
      newModel in class ai.djl.engine.Engine
    • newBaseManager

      public ai.djl.ndarray.NDManager newBaseManager()
      Specified by:
      newBaseManager in class ai.djl.engine.Engine
    • newBaseManager

      public ai.djl.ndarray.NDManager newBaseManager(ai.djl.Device device)
      Specified by:
      newBaseManager in class ai.djl.engine.Engine
    • newGradientCollector

      public ai.djl.training.GradientCollector newGradientCollector()
      Overrides:
      newGradientCollector in class ai.djl.engine.Engine
    • setRandomSeed

      public void setRandomSeed(int seed)
      Overrides:
      setRandomSeed in class ai.djl.engine.Engine
    • toString

      public String toString()
      Overrides:
      toString in class ai.djl.engine.Engine
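Several of the factory methods above can be combined in a short sketch. It seeds the engine with `setRandomSeed`, opens a CPU-bound manager via `newBaseManager(Device)`, and records a gradient with `newGradientCollector()`. It also creates and closes a model with `newModel(String, Device)`. This assumes the DJL `api` and `pytorch-engine` artifacts are on the classpath. The class name `PtEngineUsage` and the model name `"demo-model"` are hypothetical.

```java
import ai.djl.Device;
import ai.djl.Model;
import ai.djl.engine.Engine;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDManager;
import ai.djl.training.GradientCollector;

// Hypothetical helper class for illustration only.
public class PtEngineUsage {

    /** Computes d(x^2)/dx at the given point using the PyTorch engine's autograd. */
    public static float gradientOfSquare(float x) {
        Engine engine = Engine.getEngine("PyTorch");
        // Fix the random seed; this computation is deterministic, so it is shown only for completeness.
        engine.setRandomSeed(42);
        try (NDManager manager = engine.newBaseManager(Device.cpu())) {
            NDArray a = manager.create(x);
            a.setRequiresGradient(true);
            // Operations inside the collector are recorded for backpropagation.
            try (GradientCollector collector = engine.newGradientCollector()) {
                NDArray y = a.mul(a);
                collector.backward(y);
            }
            return a.getGradient().getFloat();
        } // closing the manager frees the native memory held by its arrays
    }

    /** Creates an empty model on the CPU and returns its name. */
    public static String newModelName() {
        Engine engine = Engine.getEngine("PyTorch");
        try (Model model = engine.newModel("demo-model", Device.cpu())) {
            return model.getName();
        }
    }

    public static void main(String[] args) {
        System.out.println("grad at 3: " + gradientOfSquare(3f));
        System.out.println("model:     " + newModelName());
    }
}
```

Since d(x²)/dx = 2x, the gradient at x = 3 should be 6. Try-with-resources is used because `NDManager`, `GradientCollector`, and `Model` are all `AutoCloseable` and hold native resources.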