Package ai.djl

Interface Model

All Superinterfaces:
AutoCloseable
All Known Implementing Classes:
BaseModel, ZooModel

public interface Model extends AutoCloseable
A model is a collection of artifacts that is created by the training process.

A deep learning model usually contains the following parts:

  • the Block of operations to run
  • the Parameters that are trained
  • Input/Output information: input and output parameter names, shape, etc.
  • Other artifacts such as a synset for classification that would be used during pre-processing and post-processing

For loading a pre-trained model, see load(Path, String).

For training a model, see Trainer.

For running inference with a model, see Predictor.

  • Method Details

    • newInstance

      static Model newInstance(String name)
      Creates an empty model instance.
      Parameters:
      name - the model name
      Returns:
      a new Model instance
    • newInstance

      static Model newInstance(String name, Device device)
      Creates an empty model instance on the specified Device.
      Parameters:
      name - the model name
      device - the device to load the model onto
      Returns:
      a new model instance
    • newInstance

      static Model newInstance(String name, String engineName)
      Creates an empty model instance using the specified engine.
      Parameters:
      name - the model name
      engineName - the name of the engine
      Returns:
      a new model instance
    • newInstance

      static Model newInstance(String name, Device device, String engineName)
      Creates an empty model instance on the specified Device and engine.
      Parameters:
      name - the model name
      device - the device to load the model onto
      engineName - the name of the engine
      Returns:
      a new model instance
    • load

      default void load(Path modelPath) throws IOException, MalformedModelException
      Loads the model from the modelPath.
      Parameters:
      modelPath - the directory or file path of the model location
      Throws:
      IOException - when an IO operation fails while loading a resource
      MalformedModelException - if the model file is corrupted
    • load

      default void load(Path modelPath, String prefix) throws IOException, MalformedModelException
      Loads the model from the modelPath using the given model file name or path prefix.
      Parameters:
      modelPath - the directory or file path of the model location
      prefix - the model file name or path prefix
      Throws:
      IOException - when an IO operation fails while loading a resource
      MalformedModelException - if the model file is corrupted
    • load

      void load(Path modelPath, String prefix, Map<String,?> options) throws IOException, MalformedModelException
      Loads the model from the modelPath with the name and options provided.
      Parameters:
      modelPath - the directory or file path of the model location
      prefix - the model file name or path prefix
      options - engine specific load model options, see documentation for each engine
      Throws:
      IOException - when an IO operation fails while loading a resource
      MalformedModelException - if the model file is corrupted
    • load

      default void load(InputStream is) throws IOException, MalformedModelException
      Loads the model from the InputStream.
      Parameters:
      is - the InputStream to load the model from
      Throws:
      IOException - when an IO operation fails while loading a resource
      MalformedModelException - if the model file is corrupted
    • load

      void load(InputStream is, Map<String,?> options) throws IOException, MalformedModelException
      Loads the model from the InputStream with the options provided.
      Parameters:
      is - the InputStream to load the model from
      options - engine specific load model options, see documentation for each engine
      Throws:
      IOException - when an IO operation fails while loading a resource
      MalformedModelException - if the model file is corrupted
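Three of the five load overloads are default methods, so an implementation only has to supply the two abstract ones. The stdlib-only sketch below shows how that forwarding can work, assuming the shorter overloads pass null for the omitted prefix and options. ModelLoader, RecordingLoader, and the lastCall field are hypothetical names, and the checked IOException/MalformedModelException are omitted to keep the example dependency-free.

```java
import java.io.ByteArrayInputStream;
import java.io.InputStream;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.Map;

class LoadOverloads {

    // Hypothetical stand-in for the load methods of Model (checked exceptions omitted).
    interface ModelLoader {
        // Abstract: the fully specified path-based overload.
        void load(Path modelPath, String prefix, Map<String, ?> options);

        // Abstract: the fully specified stream-based overload.
        void load(InputStream is, Map<String, ?> options);

        // Assumption: omitted arguments are forwarded as null.
        default void load(Path modelPath) {
            load(modelPath, null, null);
        }

        default void load(Path modelPath, String prefix) {
            load(modelPath, prefix, null);
        }

        default void load(InputStream is) {
            load(is, null);
        }
    }

    // Records which abstract overload was reached and with what arguments.
    static class RecordingLoader implements ModelLoader {
        String lastCall;

        @Override
        public void load(Path modelPath, String prefix, Map<String, ?> options) {
            lastCall = "path=" + modelPath + ", prefix=" + prefix + ", options=" + options;
        }

        @Override
        public void load(InputStream is, Map<String, ?> options) {
            lastCall = "stream, options=" + options;
        }
    }

    public static void main(String[] args) {
        RecordingLoader loader = new RecordingLoader();
        loader.load(Paths.get("mlp_model"));
        System.out.println(loader.lastCall);
        loader.load(Paths.get("mlp_model"), "mlp");
        System.out.println(loader.lastCall);
        loader.load(new ByteArrayInputStream(new byte[0]));
        System.out.println(loader.lastCall);
    }
}
```

Every convenience overload ends up in one of the two abstract methods, which is where an engine would interpret the engine-specific options.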
    • save

      void save(Path modelPath, String newModelName) throws IOException
      Saves the model to the specified modelPath with the name provided.
      Parameters:
      modelPath - the directory or file path of the model location
      newModelName - the name to save the model under; use null to keep the original model name
      Throws:
      IOException - when an IO operation fails while saving the model
    • getModelPath

      Path getModelPath()
      Returns the directory from which the model was loaded.
      Returns:
      the directory of the model location
    • getBlock

      Block getBlock()
      Gets the block from the Model.
      Returns:
      the Block
    • setBlock

      void setBlock(Block block)
      Sets the block for the Model for training and inference.
      Parameters:
      block - the Block used in Model
    • getName

      String getName()
      Gets the model name.
      Returns:
      name of the model
    • getProperties

      Map<String,String> getProperties()
      Returns the model's properties.
      Returns:
      the model's properties
    • getProperty

      String getProperty(String key)
      Returns the property of the model based on property name.
      Parameters:
      key - the name of the property
      Returns:
      the value of the property
    • getProperty

      default String getProperty(String key, String defValue)
      Returns the property of the model based on property name.
      Parameters:
      key - the name of the property
      defValue - the default value if the key is not found
      Returns:
      the value of the property
    • setProperty

      void setProperty(String key, String value)
      Sets a property to the model.

      Properties are saved and loaded with the model, so users can store information about the model here.

      Parameters:
      key - the name of the property
      value - the value of the property
    • intProperty

      default int intProperty(String key, int defValue)
      Returns the property of the model based on property name.
      Parameters:
      key - the name of the property
      defValue - the default value if the key is not found
      Returns:
      the value of the property
    • longProperty

      default long longProperty(String key, long defValue)
      Returns the property of the model based on property name.
      Parameters:
      key - the name of the property
      defValue - the default value if the key is not found
      Returns:
      the value of the property
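The property accessors above share one contract: a string-valued map where the typed getters fall back to a default. A minimal stdlib sketch of that behavior, assuming a missing key (a null lookup) returns the default and a present value is parsed with Integer.parseInt or Long.parseLong, so a malformed value throws NumberFormatException. PropertyStore and the "epoch" and "seed" keys are hypothetical; this is an illustration, not DJL's implementation.

```java
import java.util.HashMap;
import java.util.Map;

class PropertyDefaults {

    // Hypothetical property store mirroring the getProperty/intProperty/longProperty contract.
    static class PropertyStore {
        private final Map<String, String> properties = new HashMap<>();

        void setProperty(String key, String value) {
            properties.put(key, value);
        }

        String getProperty(String key) {
            return properties.get(key); // null when the key is absent
        }

        String getProperty(String key, String defValue) {
            String value = getProperty(key);
            return value == null ? defValue : value;
        }

        int intProperty(String key, int defValue) {
            String value = getProperty(key);
            // Assumption: absent keys use the default; present values must parse as an int.
            return value == null ? defValue : Integer.parseInt(value);
        }

        long longProperty(String key, long defValue) {
            String value = getProperty(key);
            return value == null ? defValue : Long.parseLong(value);
        }
    }

    public static void main(String[] args) {
        PropertyStore props = new PropertyStore();
        props.setProperty("epoch", "10");
        props.setProperty("seed", "1234567890123");
        System.out.println(props.intProperty("epoch", 0));       // 10
        System.out.println(props.intProperty("batchSize", 32));  // 32
        System.out.println(props.longProperty("seed", 0L));      // 1234567890123
        System.out.println(props.getProperty("engine", "none")); // none
    }
}
```

Storing everything as strings keeps the saved form simple; the typed accessors only do the conversion at read time.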
    • getNDManager

      NDManager getNDManager()
      Gets the NDManager from the model.
      Returns:
      the NDManager
    • newTrainer

      Trainer newTrainer(TrainingConfig trainingConfig)
      Creates a new Trainer instance for a Model.
      Parameters:
      trainingConfig - training configuration settings
      Returns:
      the Trainer instance
    • newPredictor

      default <I, O> Predictor<I,O> newPredictor(Translator<I,O> translator)
      Creates a new Predictor based on the model on the current device.
      Type Parameters:
      I - the input object for pre-processing
      O - the output object from post-processing
      Parameters:
      translator - the object used for pre-processing and post-processing
      Returns:
      an instance of Predictor
    • newPredictor

      <I, O> Predictor<I,O> newPredictor(Translator<I,O> translator, Device device)
      Creates a new Predictor based on the model.
      Type Parameters:
      I - the input object for pre-processing
      O - the output object from post-processing
      Parameters:
      translator - the object used for pre-processing and post-processing
      device - the device to use for prediction
      Returns:
      an instance of Predictor
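A Predictor wraps the model with a Translator: the translator pre-processes the user's input I into model input, the model runs, and the translator post-processes the raw output into O. The sketch below models that pipeline with plain Java types. SimpleTranslator, CsvArgmaxTranslator, and SimplePredictor are hypothetical, simplified stand-ins (real DJL translators work on NDList inside a TranslatorContext), and the "model" is an identity function.

```java
import java.util.List;
import java.util.function.Function;

class TranslatorPipeline {

    // Hypothetical, simplified translator; real DJL translators operate on NDList inside a TranslatorContext.
    interface SimpleTranslator<I, O> {
        float[] processInput(I input);   // pre-processing: user object -> raw model input

        O processOutput(float[] output); // post-processing: raw model output -> user object
    }

    // Parses comma-separated scores and maps the highest score to a label from the synset.
    static class CsvArgmaxTranslator implements SimpleTranslator<String, String> {
        private final List<String> synset;

        CsvArgmaxTranslator(List<String> synset) {
            this.synset = synset;
        }

        @Override
        public float[] processInput(String csv) {
            String[] parts = csv.split(",");
            float[] scores = new float[parts.length];
            for (int i = 0; i < parts.length; i++) {
                scores[i] = Float.parseFloat(parts[i].trim());
            }
            return scores;
        }

        @Override
        public String processOutput(float[] scores) {
            int best = 0;
            for (int i = 1; i < scores.length; i++) {
                if (scores[i] > scores[best]) {
                    best = i;
                }
            }
            return synset.get(best);
        }
    }

    // Hypothetical predictor: pre-process, run the model function, post-process.
    static class SimplePredictor<I, O> {
        private final Function<float[], float[]> model;
        private final SimpleTranslator<I, O> translator;

        SimplePredictor(Function<float[], float[]> model, SimpleTranslator<I, O> translator) {
            this.model = model;
            this.translator = translator;
        }

        O predict(I input) {
            float[] x = translator.processInput(input);
            float[] y = model.apply(x);
            return translator.processOutput(y);
        }
    }

    public static void main(String[] args) {
        // Stand-in "model" that passes its input through unchanged.
        Function<float[], float[]> identity = x -> x;
        SimplePredictor<String, String> predictor = new SimplePredictor<>(
                identity, new CsvArgmaxTranslator(List.of("cat", "dog", "bird")));
        System.out.println(predictor.predict("0.1, 0.7, 0.2")); // dog
    }
}
```

The synset lookup in post-processing corresponds to the "synset for classification" artifact mentioned in the interface overview.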
    • describeInput

      ai.djl.util.PairList<String,Shape> describeInput()
      Returns the input descriptor of the model.

      It contains the information that can be extracted from the model, usually name, shape, layout and DataType.

      Returns:
      a PairList of String and Shape
    • describeOutput

      ai.djl.util.PairList<String,Shape> describeOutput()
      Returns the output descriptor of the model.

      It contains the output information that can be obtained from the model.

      Returns:
      a PairList of String and Shape
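An input descriptor pairs each input name with a shape, which lets a caller check data before running inference. The sketch below uses a List of Map.Entry as a stand-in for PairList and long[] as a stand-in for Shape, and assumes that -1 marks an unknown (dynamic) dimension such as the batch size. The "data" input name, its shape, and the matches helper are hypothetical.

```java
import java.util.Arrays;
import java.util.List;
import java.util.Map;

class DescriptorCheck {

    // Returns true if the actual shape fits the declared one; -1 in the declared shape means "any size".
    static boolean matches(long[] declared, long[] actual) {
        if (declared.length != actual.length) {
            return false;
        }
        for (int i = 0; i < declared.length; i++) {
            if (declared[i] != -1 && declared[i] != actual[i]) {
                return false;
            }
        }
        return true;
    }

    public static void main(String[] args) {
        // Hypothetical descriptor: one input named "data" with a dynamic batch dimension.
        List<Map.Entry<String, long[]>> inputs =
                List.of(Map.entry("data", new long[] {-1, 3, 224, 224}));

        long[] batch = {8, 3, 224, 224};
        for (Map.Entry<String, long[]> entry : inputs) {
            System.out.println(entry.getKey() + " " + Arrays.toString(entry.getValue())
                    + " accepts " + Arrays.toString(batch) + ": " + matches(entry.getValue(), batch));
        }
    }
}
```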
    • getArtifactNames

      String[] getArtifactNames()
      Returns the artifact names associated with the model.
      Returns:
      an array of artifact names
    • getArtifact

      <T> T getArtifact(String name, Function<InputStream,T> function) throws IOException
      Attempts to load the artifact using the given function and cache it if the specified artifact is not already cached.

      The model caches loaded artifacts, so the user does not need to keep track of them.

      
       String synset = model.getArtifact("synset.txt", is -> new BufferedReader(
               new InputStreamReader(is, StandardCharsets.UTF_8)).lines().collect(Collectors.joining("\n")));
       
      Type Parameters:
      T - the type of the returned artifact object
      Parameters:
      name - the name of the desired artifact
      function - the function to load the artifact
      Returns:
      the current (existing or computed) artifact associated with the specified name, or null if the computed value is null
      Throws:
      IOException - when an IO operation fails while loading a resource
      ClassCastException - if the cached artifact cannot be cast to the target class
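The caching contract above (load once, return the cached value afterwards, return null without caching when the function yields null, and fail with ClassCastException on a type mismatch) maps naturally onto ConcurrentHashMap.computeIfAbsent. The sketch below is a hypothetical in-memory model, not DJL's BaseModel: InMemoryModel, readLines, and the in-memory byte map standing in for the model directory are all assumptions. Because java.util.function.Function cannot throw checked exceptions, the IOException from opening or closing the stream is tunneled through UncheckedIOException and rethrown.

```java
import java.io.BufferedReader;
import java.io.ByteArrayInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.io.UncheckedIOException;
import java.nio.charset.StandardCharsets;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.function.Function;
import java.util.stream.Collectors;

class ArtifactCacheDemo {

    // Hypothetical model that serves artifacts from in-memory bytes instead of a model directory.
    static class InMemoryModel {
        private final Map<String, byte[]> files;
        private final Map<String, Object> cache = new ConcurrentHashMap<>();

        InMemoryModel(Map<String, byte[]> files) {
            this.files = files;
        }

        InputStream getArtifactAsStream(String name) throws IOException {
            byte[] data = files.get(name);
            return data == null ? null : new ByteArrayInputStream(data);
        }

        @SuppressWarnings("unchecked")
        <T> T getArtifact(String name, Function<InputStream, T> function) throws IOException {
            try {
                // computeIfAbsent runs the loader only on a cache miss and does not cache a null result.
                Object artifact = cache.computeIfAbsent(name, key -> {
                    try (InputStream is = getArtifactAsStream(key)) {
                        return function.apply(is);
                    } catch (IOException e) {
                        throw new UncheckedIOException(e);
                    }
                });
                // Unchecked cast: a type mismatch surfaces as ClassCastException at the caller.
                return (T) artifact;
            } catch (UncheckedIOException e) {
                throw e.getCause();
            }
        }
    }

    // Reads a stream as UTF-8 lines without throwing checked exceptions.
    static List<String> readLines(InputStream is) {
        return new BufferedReader(new InputStreamReader(is, StandardCharsets.UTF_8))
                .lines()
                .collect(Collectors.toList());
    }

    public static void main(String[] args) throws IOException {
        InMemoryModel model = new InMemoryModel(
                Map.of("synset.txt", "cat\ndog\nbird".getBytes(StandardCharsets.UTF_8)));
        AtomicInteger loads = new AtomicInteger();
        Function<InputStream, List<String>> loader = is -> {
            loads.incrementAndGet();
            return readLines(is);
        };

        List<String> first = model.getArtifact("synset.txt", loader);
        List<String> second = model.getArtifact("synset.txt", loader);
        // Prints: [cat, dog, bird] loaded 1 time(s), same object: true
        System.out.println(first + " loaded " + loads.get() + " time(s), same object: " + (first == second));
    }
}
```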
    • getArtifact

      URL getArtifact(String name) throws IOException
      Finds an artifact resource with a given name in the model.
      Parameters:
      name - the name of the desired artifact
      Returns:
      a URL object or null if no artifact with this name is found
      Throws:
      IOException - when an IO operation fails while loading a resource
    • getArtifactAsStream

      InputStream getArtifactAsStream(String name) throws IOException
      Finds an artifact resource with a given name in the model.
      Parameters:
      name - the name of the desired artifact
      Returns:
      an InputStream object or null if no resource with this name is found
      Throws:
      IOException - when an IO operation fails while loading a resource
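For a model stored as a plain directory, the two lookup methods above can be pictured as resolving the artifact name against that directory and returning null when nothing is there. This is a hypothetical sketch: DirectoryArtifacts and the temporary-directory setup are assumptions, and real models may also be archives or classpath resources, which this does not cover.

```java
import java.io.IOException;
import java.io.InputStream;
import java.net.URL;
import java.nio.charset.StandardCharsets;
import java.nio.file.Files;
import java.nio.file.Path;

class DirectoryArtifacts {
    private final Path modelDir;

    DirectoryArtifacts(Path modelDir) {
        this.modelDir = modelDir;
    }

    // Returns a URL for the named file in the model directory, or null if there is no such file.
    URL getArtifact(String name) throws IOException {
        Path file = modelDir.resolve(name);
        return Files.isRegularFile(file) ? file.toUri().toURL() : null;
    }

    // Opens the named artifact as a stream, or returns null if it does not exist.
    InputStream getArtifactAsStream(String name) throws IOException {
        URL url = getArtifact(name);
        return url == null ? null : url.openStream();
    }

    public static void main(String[] args) throws IOException {
        Path dir = Files.createTempDirectory("model");
        Path synset = dir.resolve("synset.txt");
        Files.write(synset, "cat\ndog\n".getBytes(StandardCharsets.UTF_8));
        // Registered in this order so the file is deleted before its directory on exit.
        dir.toFile().deleteOnExit();
        synset.toFile().deleteOnExit();

        DirectoryArtifacts artifacts = new DirectoryArtifacts(dir);
        System.out.println(artifacts.getArtifact("synset.txt") != null); // true
        System.out.println(artifacts.getArtifact("missing.txt"));        // null
        try (InputStream is = artifacts.getArtifactAsStream("synset.txt")) {
            System.out.println(new String(is.readAllBytes(), StandardCharsets.UTF_8).trim());
        }
    }
}
```

Returning null rather than throwing for a missing artifact matches the documented contract, so callers should null-check before reading.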
    • setDataType

      void setDataType(DataType dataType)
      Sets the standard data type used within the model.
      Parameters:
      dataType - the standard data type to use
    • getDataType

      DataType getDataType()
      Returns the standard data type used within the model.
      Returns:
      the standard data type used within the model
    • cast

      default void cast(DataType dataType)
      Casts the model to support a different precision level.

      For example, you can cast the precision from Float to Int.

      Parameters:
      dataType - the target dataType you would like to cast to
    • quantize

      default void quantize()
      Converts the model to use a lower precision quantized network.

      Quantization converts the network to use the int8 data type where possible, giving a smaller model and faster computation without too large a drop in accuracy. See original paper.
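The exact scheme used by quantize() is engine-specific and not described in this section. To illustrate the general idea only, the sketch below applies symmetric per-tensor int8 quantization, a common technique and not necessarily the one an engine uses: pick a scale so the largest weight magnitude maps to 127, round each weight to the nearest step, and store one byte per value instead of four. The class and method names are hypothetical.

```java
class Int8Quantization {

    // Symmetric per-tensor scale: the largest magnitude maps to 127.
    static float scaleFor(float[] weights) {
        float maxAbs = 0f;
        for (float w : weights) {
            maxAbs = Math.max(maxAbs, Math.abs(w));
        }
        return maxAbs == 0f ? 1f : maxAbs / 127f;
    }

    static byte[] quantize(float[] weights, float scale) {
        byte[] q = new byte[weights.length];
        for (int i = 0; i < weights.length; i++) {
            int v = Math.round(weights[i] / scale);
            q[i] = (byte) Math.max(-127, Math.min(127, v)); // clamp to the int8 range used here
        }
        return q;
    }

    static float[] dequantize(byte[] q, float scale) {
        float[] w = new float[q.length];
        for (int i = 0; i < q.length; i++) {
            w[i] = q[i] * scale;
        }
        return w;
    }

    public static void main(String[] args) {
        float[] weights = {0.5f, -1.27f, 0.003f, 1.0f};
        float scale = scaleFor(weights);
        byte[] q = quantize(weights, scale);
        float[] restored = dequantize(q, scale);
        for (int i = 0; i < weights.length; i++) {
            System.out.printf("%8.4f -> %4d -> %8.4f%n", weights[i], q[i], restored[i]);
        }
    }
}
```

Each restored value differs from the original by at most half a quantization step (scale / 2), which is the source of the small accuracy drop mentioned above.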

    • close

      void close()
      Specified by:
      close in interface AutoCloseable
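Because Model extends AutoCloseable, the usual pattern is to open it in a try-with-resources statement, for example try (Model model = Model.newInstance("mlp")) { ... }, so that close() runs even when training or inference throws. The stdlib sketch below shows that guarantee with a hypothetical FakeModel in place of a real model; what close() releases in practice (presumably native resources such as those held by the model's NDManager) is an assumption here.

```java
import java.util.ArrayList;
import java.util.List;

class CloseDemo {

    // Hypothetical model holding a resource that must be released, standing in for native memory.
    static class FakeModel implements AutoCloseable {
        final List<String> events;

        FakeModel(List<String> events) {
            this.events = events;
            events.add("open");
        }

        void predict() {
            events.add("predict");
            throw new IllegalStateException("inference failed");
        }

        @Override
        public void close() {
            events.add("close"); // a real model would release its resources here
        }
    }

    public static void main(String[] args) {
        List<String> events = new ArrayList<>();
        try (FakeModel model = new FakeModel(events)) {
            model.predict();
        } catch (IllegalStateException e) {
            events.add("caught: " + e.getMessage());
        }
        System.out.println(events); // [open, predict, close, caught: inference failed]
    }
}
```

The resource is closed before the catch block runs, so cleanup never depends on how the failure is handled.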