Class PtModel

java.lang.Object
ai.djl.BaseModel
ai.djl.pytorch.engine.PtModel
All Implemented Interfaces:
ai.djl.Model, AutoCloseable

public class PtModel extends ai.djl.BaseModel
PtModel is the PyTorch implementation of Model.

PtModel contains all the methods in Model to load and process a model. In addition, it provides PyTorch-specific functionality.
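Because PtModel implements AutoCloseable, callers usually scope it with try-with-resources so that whatever resources it holds are released deterministically. The sketch below is a standard-library-only illustration of that caller-side pattern. It uses a hypothetical StubModel (not the DJL class) whose load(InputStream, boolean) only mirrors the shape of PtModel's stream overload, and it counts bytes instead of deserializing anything. In real DJL code the model is typically obtained through Model.newInstance(...) rather than a constructor.

```java
import java.io.ByteArrayInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.util.concurrent.atomic.AtomicInteger;

public class ModelLifecycleSketch {
    // Number of stub models currently holding a (pretend) resource.
    static final AtomicInteger OPEN_HANDLES = new AtomicInteger();

    // Hypothetical stand-in that mirrors the shape of PtModel.load(InputStream, boolean);
    // it is NOT the DJL class and does not parse real model files.
    static final class StubModel implements AutoCloseable {
        private boolean closed;
        private int loadedBytes = -1;

        StubModel() {
            OPEN_HANDLES.incrementAndGet(); // acquire the pretend resource
        }

        void load(InputStream modelStream, boolean mapLocation) throws IOException {
            if (closed) {
                throw new IllegalStateException("model already closed");
            }
            // A real engine would deserialize the model (and honor mapLocation) here;
            // the stub ignores mapLocation and only counts the bytes it reads.
            loadedBytes = modelStream.readAllBytes().length;
        }

        int loadedBytes() {
            return loadedBytes;
        }

        @Override
        public void close() {
            if (!closed) { // release exactly once, even if close() is called twice
                closed = true;
                OPEN_HANDLES.decrementAndGet();
            }
        }
    }

    public static void main(String[] args) throws IOException {
        byte[] fakeModel = {1, 2, 3, 4};
        // Both resources are closed automatically: the model first, then the stream
        // (try-with-resources closes in reverse declaration order).
        try (InputStream in = new ByteArrayInputStream(fakeModel);
                StubModel model = new StubModel()) {
            model.load(in, false);
            System.out.println("bytes read: " + model.loadedBytes());
            System.out.println("open handles in scope: " + OPEN_HANDLES.get());
        }
        System.out.println("open handles after scope: " + OPEN_HANDLES.get());
    }
}
```

The caller, not the stub, owns the InputStream here; whether PtModel itself closes a stream passed to it is not stated on this page, so keeping the stream in the caller's resource header is the conservative choice.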

  • Field Summary

    Fields inherited from class ai.djl.BaseModel

    artifacts, block, dataType, inputData, manager, modelDir, modelName, properties, wasLoaded
  • Method Summary

    Modifier and Type          Method                                                       Description
    String[]                   getArtifactNames()
    void                       load(InputStream modelStream, boolean mapLocation)           Loads a PyTorch model from an InputStream.
    void                       load(InputStream modelStream, Map<String,?> options)
    void                       load(Path modelPath, String prefix, Map<String,?> options)
    ai.djl.training.Trainer    newTrainer(ai.djl.training.TrainingConfig trainingConfig)

    Methods inherited from class ai.djl.BaseModel

    close, describeInput, describeOutput, finalize, getArtifact, getArtifact, getArtifactAsStream, getBlock, getDataType, getModelPath, getName, getNDManager, getProperties, getProperty, newPredictor, paramPathResolver, readParameters, readParameters, save, setBlock, setDataType, setModelDir, setProperty, toString

    Methods inherited from class java.lang.Object

    clone, equals, getClass, hashCode, notify, notifyAll, wait, wait, wait

    Methods inherited from interface ai.djl.Model

    cast, getProperty, intProperty, load, load, load, longProperty, newPredictor, quantize
  • Method Details

    • load

      public void load(Path modelPath, String prefix, Map<String,?> options) throws IOException, ai.djl.MalformedModelException
      Throws:
      IOException
      ai.djl.MalformedModelException
    • load

      public void load(InputStream modelStream, Map<String,?> options) throws IOException, ai.djl.MalformedModelException
      Specified by:
      load in interface ai.djl.Model
      Overrides:
      load in class ai.djl.BaseModel
      Throws:
      IOException
      ai.djl.MalformedModelException
    • load

      public void load(InputStream modelStream, boolean mapLocation) throws IOException, ai.djl.MalformedModelException
      Loads a PyTorch model from an InputStream.
      Parameters:
      modelStream - the stream of the model file
      mapLocation - if true, forces the model to be loaded onto the specified device
      Throws:
      IOException - if an error occurs while loading the model
      ai.djl.MalformedModelException - if the model file is corrupted
    • newTrainer

      public ai.djl.training.Trainer newTrainer(ai.djl.training.TrainingConfig trainingConfig)
      Specified by:
      newTrainer in interface ai.djl.Model
      Overrides:
      newTrainer in class ai.djl.BaseModel
    • getArtifactNames

      public String[] getArtifactNames()
      Specified by:
      getArtifactNames in interface ai.djl.Model
      Overrides:
      getArtifactNames in class ai.djl.BaseModel
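Judging from its name and return type, getArtifactNames() reports the names of the files (artifacts) that accompany a model. As a rough, standard-library-only illustration, the sketch below assumes that artifacts are the regular files under a model directory, reported as paths relative to that directory. This is an assumption for illustration and not necessarily the filtering PtModel or BaseModel applies; the helper name artifactNames and the example file names are hypothetical.

```java
import java.io.IOException;
import java.nio.file.Files;
import java.nio.file.Path;
import java.util.stream.Stream;

public class ArtifactNamesSketch {
    // Hypothetical helper: lists every regular file under modelDir as a path relative
    // to modelDir, sorted so the result is deterministic. Treating all regular files
    // as artifacts is an assumption, not PtModel's documented behavior.
    static String[] artifactNames(Path modelDir) throws IOException {
        // Files.walk holds open directory handles, so close the stream when done.
        try (Stream<Path> paths = Files.walk(modelDir)) {
            return paths.filter(Files::isRegularFile)
                    .map(p -> modelDir.relativize(p).toString())
                    .sorted()
                    .toArray(String[]::new);
        }
    }

    public static void main(String[] args) throws IOException {
        // Build a small throwaway directory tree to list.
        Path dir = Files.createTempDirectory("model");
        Files.createFile(dir.resolve("synset.txt"));
        Files.createDirectories(dir.resolve("extra"));
        Files.createFile(dir.resolve("extra").resolve("config.json"));
        for (String name : artifactNames(dir)) {
            System.out.println(name);
        }
    }
}
```

Sorting is there because Files.walk does not guarantee the order of entries within a directory, so an unsorted result could differ between file systems.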