Class PtModel

  • All Implemented Interfaces:
    ai.djl.Model, java.lang.AutoCloseable

    public class PtModel
    extends ai.djl.BaseModel
    PtModel is the PyTorch implementation of Model.

    PtModel implements all the methods of Model for loading and processing a model. In addition, it provides PyTorch-specific functionality.

    • Field Summary

      • Fields inherited from class ai.djl.BaseModel

        artifacts, block, dataType, inputData, manager, modelDir, modelName, properties, wasLoaded
    • Method Summary

      Modifier and Type          Method and Description
      java.lang.String[]         getArtifactNames()
      void                       load(java.io.InputStream modelStream, boolean mapLocation)
                                   Loads a PyTorch model from an InputStream.
      void                       load(java.io.InputStream modelStream, java.util.Map<java.lang.String,?> options)
      void                       load(java.nio.file.Path modelPath, java.lang.String prefix, java.util.Map<java.lang.String,?> options)
      ai.djl.training.Trainer    newTrainer(ai.djl.training.TrainingConfig trainingConfig)
      • Methods inherited from class ai.djl.BaseModel

        close, describeInput, describeOutput, finalize, getArtifact, getArtifact, getArtifactAsStream, getBlock, getDataType, getModelPath, getName, getNDManager, getProperties, getProperty, newPredictor, paramPathResolver, readParameters, save, setBlock, setDataType, setModelDir, setProperty, toString
      • Methods inherited from class java.lang.Object

        clone, equals, getClass, hashCode, notify, notifyAll, wait, wait, wait
      • Methods inherited from interface ai.djl.Model

        cast, getProperty, load, load, load, newPredictor, quantize
    • Method Detail

      • load

        public void load(java.nio.file.Path modelPath,
                         java.lang.String prefix,
                         java.util.Map<java.lang.String,?> options)
                  throws java.io.IOException,
                         ai.djl.MalformedModelException
        Throws:
        java.io.IOException
        ai.djl.MalformedModelException
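        A caller may want to fail fast on a missing model path before delegating to this overload. The following is a minimal, stdlib-only sketch of such a wrapper, not part of DJL. PathLoader and DirectoryLoading are hypothetical names, and PathLoader only mirrors the parameter list of load(Path, String, Map). It is declared to throw Exception so that a lambda wrapping the real method, which throws both IOException and MalformedModelException, would fit. Substituting an empty map for null options is a defensive assumption, not documented behavior.

```java
import java.nio.file.Files;
import java.nio.file.NoSuchFileException;
import java.nio.file.Path;
import java.util.Collections;
import java.util.Map;

/** Hypothetical stand-in with the same parameters as PtModel.load(Path, String, Map); not a DJL type. */
@FunctionalInterface
interface PathLoader {
    // Broad "throws Exception" so a lambda around the real method (IOException,
    // MalformedModelException) can be passed in.
    void load(Path modelPath, String prefix, Map<String, ?> options) throws Exception;
}

final class DirectoryLoading {
    private DirectoryLoading() {}

    /** Rejects a missing path, replaces null options with an empty map, then delegates. */
    static void loadChecked(PathLoader loader, Path modelPath, String prefix, Map<String, ?> options)
            throws Exception {
        if (!Files.exists(modelPath)) {
            throw new NoSuchFileException(modelPath.toString());
        }
        Map<String, ?> effective = options;
        if (effective == null) {
            effective = Collections.emptyMap();
        }
        loader.load(modelPath, prefix, effective);
    }
}
```

        With a real instance this would be called as, for example, DirectoryLoading.loadChecked((p, pre, opts) -> ptModel.load(p, pre, opts), dir, prefix, null), where ptModel, dir, and prefix are placeholders.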
      • load

        public void load(java.io.InputStream modelStream,
                         java.util.Map<java.lang.String,?> options)
                  throws java.io.IOException
        Specified by:
        load in interface ai.djl.Model
        Overrides:
        load in class ai.djl.BaseModel
        Throws:
        java.io.IOException
      • load

        public void load(java.io.InputStream modelStream,
                         boolean mapLocation)
                  throws java.io.IOException
        Loads a PyTorch model from an InputStream.
        Parameters:
        modelStream - the stream of the model file
        mapLocation - if true, force the model to load onto the specified device
        Throws:
        java.io.IOException - if an error occurs while loading the model
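        The text above does not say whether load closes the stream it receives, so a cautious caller opens and closes it. The following is a minimal, stdlib-only sketch of that pattern using try-with-resources. StreamLoader and StreamLoading are hypothetical names, not DJL types. Because this overload declares only IOException, a method reference such as ptModel::load should fit StreamLoader, but that is an assumption not verified here.

```java
import java.io.IOException;
import java.io.InputStream;
import java.nio.file.Files;
import java.nio.file.Path;

/** Hypothetical stand-in with the same parameters as PtModel.load(InputStream, boolean); not a DJL type. */
@FunctionalInterface
interface StreamLoader {
    void load(InputStream modelStream, boolean mapLocation) throws IOException;
}

final class StreamLoading {
    private StreamLoading() {}

    /** Opens modelFile, hands the stream to the loader, and closes it even if loading fails. */
    static void loadFile(StreamLoader loader, Path modelFile, boolean mapLocation) throws IOException {
        try (InputStream in = Files.newInputStream(modelFile)) {
            loader.load(in, mapLocation);
        }
    }
}
```

        For example, StreamLoading.loadFile(ptModel::load, Paths.get("model.pt"), true), where ptModel and model.pt are placeholders.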
      • newTrainer

        public ai.djl.training.Trainer newTrainer(ai.djl.training.TrainingConfig trainingConfig)
        Specified by:
        newTrainer in interface ai.djl.Model
        Overrides:
        newTrainer in class ai.djl.BaseModel
      • getArtifactNames

        public java.lang.String[] getArtifactNames()
        Specified by:
        getArtifactNames in interface ai.djl.Model
        Overrides:
        getArtifactNames in class ai.djl.BaseModel
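        getArtifactNames() returns a plain String[]. As an illustrative sketch, a small stdlib helper can report which expected artifacts are absent before inference starts. ArtifactCheck and the example name synset.txt are hypothetical and are not part of DJL.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashSet;
import java.util.List;
import java.util.Set;

final class ArtifactCheck {
    private ArtifactCheck() {}

    /** Returns the required names that do not appear in the given artifact names, in order. */
    static List<String> missingArtifacts(String[] artifactNames, String... required) {
        Set<String> present = new HashSet<>(Arrays.asList(artifactNames));
        List<String> missing = new ArrayList<>();
        for (String name : required) {
            if (!present.contains(name)) {
                missing.add(name);
            }
        }
        return missing;
    }
}
```

        For example, ArtifactCheck.missingArtifacts(ptModel.getArtifactNames(), "synset.txt") would list synset.txt if it is absent, where ptModel is a placeholder for a loaded model.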