Package ai.djl

Class BaseModel

  • All Implemented Interfaces:
    Model, java.lang.AutoCloseable

    public abstract class BaseModel
    extends java.lang.Object
    implements Model
    BaseModel is the basic implementation of Model.
    • Constructor Summary

      Constructors 
      Modifier Constructor Description
      protected BaseModel​(java.lang.String modelName)  
    • Method Summary

      Modifier and Type Method Description
      void close()
      ai.djl.util.PairList<java.lang.String,​Shape> describeInput()
      Returns the input descriptor of the model.
      ai.djl.util.PairList<java.lang.String,​Shape> describeOutput()
      Returns the output descriptor of the model.
      protected void finalize()
      java.net.URL getArtifact​(java.lang.String artifactName)
      Finds an artifact resource with a given name in the model.
      <T> T getArtifact​(java.lang.String name, java.util.function.Function<java.io.InputStream,​T> function)
      Attempts to load the artifact using the given function and cache it if the specified artifact is not already cached.
      java.io.InputStream getArtifactAsStream​(java.lang.String name)
      Finds an artifact resource with a given name in the model.
      java.lang.String[] getArtifactNames()
      Returns the artifact names associated with the model.
      Block getBlock()
      Gets the block from the Model.
      DataType getDataType()
      Returns the standard data type used within the model.
      java.nio.file.Path getModelPath()
      Returns the directory from where the model is loaded.
      java.lang.String getName()
      Gets the model name.
      NDManager getNDManager()
      Gets the NDManager from the model.
      java.util.Map<java.lang.String,​java.lang.String> getProperties()
      Returns the model's properties.
      java.lang.String getProperty​(java.lang.String key)
      Returns the value of the model property with the given name.
      void load​(java.io.InputStream is, java.util.Map<java.lang.String,​?> options)
      Loads the model from the InputStream with the options provided.
      <I,​O>
      Predictor<I,​O>
      newPredictor​(Translator<I,​O> translator, Device device)
      Creates a new Predictor based on the model.
      Trainer newTrainer​(TrainingConfig trainingConfig)
      Creates a new Trainer instance for a Model.
      protected java.nio.file.Path paramPathResolver​(java.lang.String prefix, java.util.Map<java.lang.String,​?> options)  
      protected boolean readParameters​(java.nio.file.Path paramFile, java.util.Map<java.lang.String,​?> options)  
      void save​(java.nio.file.Path modelPath, java.lang.String newModelName)
      Saves the model to the specified modelPath with the name provided.
      void setBlock​(Block block)
      Sets the block for the Model for training and inference.
      void setDataType​(DataType dataType)
      Sets the standard data type used within the model.
      protected void setModelDir​(java.nio.file.Path modelDir)  
      void setProperty​(java.lang.String key, java.lang.String value)
      Sets a property on the model.
      java.lang.String toString()
      • Methods inherited from class java.lang.Object

        clone, equals, getClass, hashCode, notify, notifyAll, wait, wait, wait
    • Field Detail

      • modelDir

        protected java.nio.file.Path modelDir
      • block

        protected Block block
      • modelName

        protected java.lang.String modelName
      • wasLoaded

        protected boolean wasLoaded
      • inputData

        protected ai.djl.util.PairList<java.lang.String,​Shape> inputData
      • artifacts

        protected java.util.Map<java.lang.String,​java.lang.Object> artifacts
      • properties

        protected java.util.Map<java.lang.String,​java.lang.String> properties
    • Constructor Detail

      • BaseModel

        protected BaseModel​(java.lang.String modelName)
    • Method Detail

      • getBlock

        public Block getBlock()
        Gets the block from the Model.
        Specified by:
        getBlock in interface Model
        Returns:
        the Block
      • setBlock

        public void setBlock​(Block block)
        Sets the block for the Model for training and inference.
        Specified by:
        setBlock in interface Model
        Parameters:
        block - the Block used in Model
      • getName

        public java.lang.String getName()
        Gets the model name.
        Specified by:
        getName in interface Model
        Returns:
        name of the model
      • newTrainer

        public Trainer newTrainer​(TrainingConfig trainingConfig)
        Creates a new Trainer instance for a Model.
        Specified by:
        newTrainer in interface Model
        Parameters:
        trainingConfig - training configuration settings
        Returns:
        the Trainer instance
      • newPredictor

        public <I,​O> Predictor<I,​O> newPredictor​(Translator<I,​O> translator,
                                                             Device device)
        Creates a new Predictor based on the model.
        Specified by:
        newPredictor in interface Model
        Type Parameters:
        I - the input object for pre-processing
        O - the output object from post-processing
        Parameters:
        translator - the object used for pre-processing and post-processing
        device - the device to use for prediction
        Returns:
        an instance of Predictor
      • setDataType

        public void setDataType​(DataType dataType)
        Sets the standard data type used within the model.
        Specified by:
        setDataType in interface Model
        Parameters:
        dataType - the standard data type to use
      • getDataType

        public DataType getDataType()
        Returns the standard data type used within the model.
        Specified by:
        getDataType in interface Model
        Returns:
        the standard data type used within the model
      • load

        public void load​(java.io.InputStream is,
                         java.util.Map<java.lang.String,​?> options)
                  throws java.io.IOException,
                         MalformedModelException
        Loads the model from the InputStream with the options provided.
        Specified by:
        load in interface Model
        Parameters:
        is - the InputStream to load the model from
        options - engine-specific options for loading the model; see the documentation for each engine
        Throws:
        java.io.IOException - when an I/O operation fails while loading a resource
        MalformedModelException - if the model file is corrupted
      • close

        public void close()
        Specified by:
        close in interface java.lang.AutoCloseable
        Specified by:
        close in interface Model
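
        Because the class implements java.lang.AutoCloseable and close() declares no checked exception, instances can be released with try-with-resources. The following is a minimal JDK-only sketch of that pattern. FakeModel and CloseSketch are hypothetical stand-ins, not DJL classes, and the sketch assumes nothing about what BaseModel.close() actually releases.

```java
/** Illustrates try-with-resources on an AutoCloseable; FakeModel is hypothetical, not a DJL class. */
final class CloseSketch {

    /** Hypothetical stand-in for a model that holds resources. */
    static final class FakeModel implements AutoCloseable {
        private boolean closed;

        boolean isClosed() {
            return closed;
        }

        String getName() {
            if (closed) {
                throw new IllegalStateException("model is closed");
            }
            return "demo";
        }

        // Declares no checked exception, matching the close() signature documented above.
        @Override
        public void close() {
            closed = true;
        }
    }

    /** Uses a model inside try-with-resources and returns it so callers can inspect its state. */
    static FakeModel useAndRelease() {
        // The outer reference exists only so the sketch can check the state afterwards;
        // ordinary code would create the resource directly in the try header.
        FakeModel model = new FakeModel();
        try (FakeModel m = model) {
            m.getName();
        }
        return model; // close() has already run at this point
    }

    public static void main(String[] args) {
        System.out.println(useAndRelease().isClosed());
    }
}
```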
      • describeInput

        public ai.djl.util.PairList<java.lang.String,​Shape> describeInput()
        Returns the input descriptor of the model.

        It contains the information that can be extracted from the model, usually name, shape, layout and DataType.

        Specified by:
        describeInput in interface Model
        Returns:
        a PairList of String and Shape
      • describeOutput

        public ai.djl.util.PairList<java.lang.String,​Shape> describeOutput()
        Returns the output descriptor of the model.

        It contains the output information that can be obtained from the model.

        Specified by:
        describeOutput in interface Model
        Returns:
        a PairList of String and Shape
      • getArtifactNames

        public java.lang.String[] getArtifactNames()
        Returns the artifact names associated with the model.
        Specified by:
        getArtifactNames in interface Model
        Returns:
        an array of artifact names
      • getArtifact

        public <T> T getArtifact​(java.lang.String name,
                                 java.util.function.Function<java.io.InputStream,​T> function)
                          throws java.io.IOException
        Attempts to load the artifact using the given function and cache it if the specified artifact is not already cached.

        The model caches the loaded artifact, so the caller does not need to keep track of it.

        
         String synset = model.getArtifact("synset.txt", k -> IOUtils.toString(k));
         
        Specified by:
        getArtifact in interface Model
        Type Parameters:
        T - the type of the returned artifact object
        Parameters:
        name - the name of the desired artifact
        function - the function to load the artifact
        Returns:
        the current (existing or computed) artifact associated with the specified name, or null if the computed value is null
        Throws:
        java.io.IOException - when an I/O operation fails while loading a resource
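
        The load-once-then-cache behaviour described for this method can be illustrated with a minimal JDK-only sketch. ArtifactCacheSketch, putSource, readUtf8 and loadCount are hypothetical names, and the in-memory sources map stands in for files in the model directory; this is not BaseModel's actual implementation. Returning null for an unknown name is also an assumption of the sketch. Note that a java.util.function.Function cannot throw checked exceptions, so a loader that reads the stream has to wrap any IOException.

```java
import java.io.ByteArrayInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.UncheckedIOException;
import java.nio.charset.StandardCharsets;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.function.Function;

/** Hypothetical stand-in for a model's artifact cache; not part of DJL. */
final class ArtifactCacheSketch {
    // Raw artifact bytes keyed by name (stands in for files in the model directory).
    private final Map<String, byte[]> sources = new ConcurrentHashMap<>();
    // Loaded artifacts keyed by name, using the same types as the protected 'artifacts' field.
    private final Map<String, Object> artifacts = new ConcurrentHashMap<>();
    // Counts how often a loader function actually ran.
    int loadCount;

    void putSource(String name, String content) {
        sources.put(name, content.getBytes(StandardCharsets.UTF_8));
    }

    /** Applies the loader on the first request for a name and reuses the result afterwards. */
    @SuppressWarnings("unchecked")
    <T> T getArtifact(String name, Function<InputStream, T> function) {
        return (T) artifacts.computeIfAbsent(name, key -> {
            byte[] raw = sources.get(key);
            if (raw == null) {
                return null; // computeIfAbsent stores no mapping for a null result
            }
            loadCount++;
            return function.apply(new ByteArrayInputStream(raw));
        });
    }

    /** Example loader: reads the whole stream as UTF-8, wrapping the checked IOException. */
    static String readUtf8(InputStream in) {
        try {
            return new String(in.readAllBytes(), StandardCharsets.UTF_8);
        } catch (IOException e) {
            throw new UncheckedIOException(e);
        }
    }

    public static void main(String[] args) {
        ArtifactCacheSketch model = new ArtifactCacheSketch();
        model.putSource("synset.txt", "cat\ndog");
        String first = model.getArtifact("synset.txt", ArtifactCacheSketch::readUtf8);
        String second = model.getArtifact("synset.txt", ArtifactCacheSketch::readUtf8);
        System.out.println(first == second); // same cached instance
        System.out.println(model.loadCount);
    }
}
```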
      • getArtifact

        public java.net.URL getArtifact​(java.lang.String artifactName)
                                 throws java.io.IOException
        Finds an artifact resource with a given name in the model.
        Specified by:
        getArtifact in interface Model
        Parameters:
        artifactName - the name of the desired artifact
        Returns:
        a URL object or null if no artifact with this name is found
        Throws:
        java.io.IOException - when an I/O operation fails while loading a resource
      • getArtifactAsStream

        public java.io.InputStream getArtifactAsStream​(java.lang.String name)
                                                throws java.io.IOException
        Finds an artifact resource with a given name in the model.
        Specified by:
        getArtifactAsStream in interface Model
        Parameters:
        name - the name of the desired artifact
        Returns:
        an InputStream object, or null if no resource with this name is found
        Throws:
        java.io.IOException - when an I/O operation fails while loading a resource
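
        Since this method may return null, callers need a null check before reading, and the stream should be closed once consumed. A minimal JDK-only sketch of such a guard follows. ArtifactStreamSketch and readText are hypothetical names, and the choice of UTF-8 and of Optional as the return type are assumptions of the sketch, not part of the DJL API.

```java
import java.io.ByteArrayInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.UncheckedIOException;
import java.nio.charset.StandardCharsets;
import java.util.Optional;

/** Hypothetical helper for consuming a stream that may be null; not part of DJL. */
final class ArtifactStreamSketch {

    /** Reads the stream fully as UTF-8 text and closes it; returns empty for a null stream. */
    static Optional<String> readText(InputStream stream) {
        if (stream == null) {
            return Optional.empty(); // no artifact with the requested name
        }
        try (InputStream in = stream) {
            return Optional.of(new String(in.readAllBytes(), StandardCharsets.UTF_8));
        } catch (IOException e) {
            throw new UncheckedIOException(e);
        }
    }

    public static void main(String[] args) {
        InputStream found = new ByteArrayInputStream("cat\ndog".getBytes(StandardCharsets.UTF_8));
        System.out.println(readText(found).isPresent());
        System.out.println(readText(null).isPresent());
    }
}
```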
      • setProperty

        public void setProperty​(java.lang.String key,
                                java.lang.String value)
        Sets a property on the model.

        Properties are saved and loaded together with the model, so they can be used to store information about the model.

        Specified by:
        setProperty in interface Model
        Parameters:
        key - the name of the property
        value - the value of the property
      • getProperty

        public java.lang.String getProperty​(java.lang.String key)
        Returns the value of the model property with the given name.
        Specified by:
        getProperty in interface Model
        Parameters:
        key - the name of the property
        Returns:
        the value of the property
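
        Both keys and values are Strings, so non-string metadata has to be converted when it is stored and parsed when it is read back. The JDK-only sketch below illustrates this. PropertySketch and getIntProperty are hypothetical names, the "Epoch" key is only an example, and returning null for an unknown key is an assumption of the sketch rather than documented behaviour.

```java
import java.util.HashMap;
import java.util.Map;

/** Hypothetical String-to-String property store; not part of DJL. */
final class PropertySketch {
    // Same key and value types as the protected 'properties' field.
    private final Map<String, String> properties = new HashMap<>();

    void setProperty(String key, String value) {
        properties.put(key, value);
    }

    /** Returns the stored value, or null when the key is unknown (assumption of this sketch). */
    String getProperty(String key) {
        return properties.get(key);
    }

    /** Hypothetical convenience: parses an integer property, falling back when it is absent. */
    int getIntProperty(String key, int defaultValue) {
        String value = getProperty(key);
        return value == null ? defaultValue : Integer.parseInt(value);
    }

    public static void main(String[] args) {
        PropertySketch model = new PropertySketch();
        model.setProperty("Epoch", String.valueOf(12)); // store a number as text
        System.out.println(model.getIntProperty("Epoch", 0));
    }
}
```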
      • getProperties

        public java.util.Map<java.lang.String,​java.lang.String> getProperties()
        Returns the model's properties.
        Specified by:
        getProperties in interface Model
        Returns:
        the model's properties
      • setModelDir

        protected void setModelDir​(java.nio.file.Path modelDir)
      • save

        public void save​(java.nio.file.Path modelPath,
                         java.lang.String newModelName)
                  throws java.io.IOException
        Saves the model to the specified modelPath with the name provided.
        Specified by:
        save in interface Model
        Parameters:
        modelPath - the directory or file path of the model location
        newModelName - the name to save the model under; use null to keep the original model name
        Throws:
        java.io.IOException - when an I/O operation fails while saving the model
      • getModelPath

        public java.nio.file.Path getModelPath()
        Returns the directory from where the model is loaded.
        Specified by:
        getModelPath in interface Model
        Returns:
        the directory of the model location
      • toString

        public java.lang.String toString()
        Overrides:
        toString in class java.lang.Object
      • finalize

        protected void finalize()
                         throws java.lang.Throwable
        Overrides:
        finalize in class java.lang.Object
        Throws:
        java.lang.Throwable
      • paramPathResolver

        protected java.nio.file.Path paramPathResolver​(java.lang.String prefix,
                                                       java.util.Map<java.lang.String,​?> options)
                                                throws java.io.IOException
        Throws:
        java.io.IOException
      • readParameters

        protected boolean readParameters​(java.nio.file.Path paramFile,
                                         java.util.Map<java.lang.String,​?> options)
                                  throws java.io.IOException,
                                         MalformedModelException
        Throws:
        java.io.IOException
        MalformedModelException