Package ai.djl
Class BaseModel
java.lang.Object
    ai.djl.BaseModel
Field Summary
Fields
protected java.util.Map<java.lang.String,java.lang.Object> artifacts
protected Block block
protected DataType dataType
protected ai.djl.util.PairList<java.lang.String,Shape> inputData
protected NDManager manager
protected java.nio.file.Path modelDir
protected java.lang.String modelName
protected java.util.Map<java.lang.String,java.lang.String> properties
protected boolean wasLoaded
-
Constructor Summary
Constructors
protected BaseModel(java.lang.String modelName)
-
Method Summary
All Methods, Instance Methods, Concrete Methods
void close()
ai.djl.util.PairList<java.lang.String,Shape> describeInput()
  Returns the input descriptor of the model.
ai.djl.util.PairList<java.lang.String,Shape> describeOutput()
  Returns the output descriptor of the model.
protected void finalize()
java.net.URL getArtifact(java.lang.String artifactName)
  Finds an artifact resource with a given name in the model.
<T> T getArtifact(java.lang.String name, java.util.function.Function<java.io.InputStream,T> function)
  Attempts to load the artifact using the given function and caches it if the specified artifact is not already cached.
java.io.InputStream getArtifactAsStream(java.lang.String name)
  Finds an artifact resource with a given name in the model.
java.lang.String[] getArtifactNames()
  Returns the artifact names associated with the model.
Block getBlock()
  Gets the block from the Model.
DataType getDataType()
  Returns the standard data type used within the model.
java.nio.file.Path getModelPath()
  Returns the directory from where the model is loaded.
java.lang.String getName()
  Gets the model name.
NDManager getNDManager()
  Gets the NDManager from the model.
java.util.Map<java.lang.String,java.lang.String> getProperties()
  Returns the model's properties.
java.lang.String getProperty(java.lang.String key)
  Returns the property of the model based on property name.
void load(java.io.InputStream is, java.util.Map<java.lang.String,?> options)
  Loads the model from the InputStream with the options provided.
<I,O> Predictor<I,O> newPredictor(Translator<I,O> translator, Device device)
  Creates a new Predictor based on the model.
Trainer newTrainer(TrainingConfig trainingConfig)
  Creates a new Trainer instance for a Model.
protected java.nio.file.Path paramPathResolver(java.lang.String prefix, java.util.Map<java.lang.String,?> options)
protected boolean readParameters(java.nio.file.Path paramFile, java.util.Map<java.lang.String,?> options)
void save(java.nio.file.Path modelPath, java.lang.String newModelName)
  Saves the model to the specified modelPath with the name provided.
void setBlock(Block block)
  Sets the block for the Model for training and inference.
void setDataType(DataType dataType)
  Sets the standard data type used within the model.
protected void setModelDir(java.nio.file.Path modelDir)
void setProperty(java.lang.String key, java.lang.String value)
  Sets a property on the model.
java.lang.String toString()
-
Methods inherited from class java.lang.Object
clone, equals, getClass, hashCode, notify, notifyAll, wait, wait, wait
-
Methods inherited from interface ai.djl.Model
cast, getProperty, load, load, load, load, newPredictor, quantize
-
-
-
-
Field Detail
-
modelDir
protected java.nio.file.Path modelDir
-
block
protected Block block
-
modelName
protected java.lang.String modelName
-
manager
protected NDManager manager
-
dataType
protected DataType dataType
-
wasLoaded
protected boolean wasLoaded
-
inputData
protected ai.djl.util.PairList<java.lang.String,Shape> inputData
-
artifacts
protected java.util.Map<java.lang.String,java.lang.Object> artifacts
-
properties
protected java.util.Map<java.lang.String,java.lang.String> properties
-
-
Method Detail
-
getBlock
public Block getBlock()
Gets the block from the Model.
-
setBlock
public void setBlock(Block block)
Sets the block for the Model for training and inference.
-
getName
public java.lang.String getName()
Gets the model name.
-
getNDManager
public NDManager getNDManager()
Gets the NDManager from the model.
- Specified by: getNDManager in interface Model
- Returns: the NDManager
-
newTrainer
public Trainer newTrainer(TrainingConfig trainingConfig)
Creates a new Trainer instance for a Model.
- Specified by: newTrainer in interface Model
- Parameters:
  trainingConfig - training configuration settings
- Returns: the Trainer instance
-
newPredictor
public <I,O> Predictor<I,O> newPredictor(Translator<I,O> translator, Device device)
Creates a new Predictor based on the model.
- Specified by: newPredictor in interface Model
- Type Parameters:
  I - the input object for pre-processing
  O - the output object from post-processing
- Parameters:
  translator - the object used for pre-processing and post-processing
  device - the device to use for prediction
- Returns: an instance of Predictor
-
setDataType
public void setDataType(DataType dataType)
Sets the standard data type used within the model.
- Specified by: setDataType in interface Model
- Parameters:
  dataType - the standard data type to use
-
getDataType
public DataType getDataType()
Returns the standard data type used within the model.
- Specified by: getDataType in interface Model
- Returns: the standard data type used within the model
-
load
public void load(java.io.InputStream is, java.util.Map<java.lang.String,?> options) throws java.io.IOException, MalformedModelException
Loads the model from the InputStream with the options provided.
- Specified by: load in interface Model
- Parameters:
  is - the InputStream to load the model from
  options - engine-specific model loading options; see the documentation for each engine
- Throws:
  java.io.IOException - when an I/O operation fails while loading a resource
  MalformedModelException - if the model file is corrupted
-
close
public void close()
-
describeInput
public ai.djl.util.PairList<java.lang.String,Shape> describeInput()
Returns the input descriptor of the model. It contains the information that can be extracted from the model, usually name, shape, layout and DataType.
- Specified by: describeInput in interface Model
- Returns: a PairList of String and Shape
-
describeOutput
public ai.djl.util.PairList<java.lang.String,Shape> describeOutput()
Returns the output descriptor of the model. It contains the output information that can be obtained from the model.
- Specified by: describeOutput in interface Model
- Returns: a PairList of String and Shape
-
getArtifactNames
public java.lang.String[] getArtifactNames()
Returns the artifact names associated with the model.
- Specified by: getArtifactNames in interface Model
- Returns: an array of artifact names
-
getArtifact
public <T> T getArtifact(java.lang.String name, java.util.function.Function<java.io.InputStream,T> function) throws java.io.IOException
Attempts to load the artifact using the given function and caches it if the specified artifact is not already cached. The model caches the loaded artifact, so the caller does not need to keep track of it.
String synset = model.getArtifact("synset.txt", k -> IOUtils.toString(k));
- Specified by: getArtifact in interface Model
- Type Parameters:
  T - the type of the returned artifact object
- Parameters:
  name - the name of the desired artifact
  function - the function to load the artifact
- Returns: the current (existing or computed) artifact associated with the specified name, or null if the computed value is null
- Throws:
  java.io.IOException - when an I/O operation fails while loading a resource
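The load-once-then-cache behavior described above can be illustrated with a minimal, standard-library-only sketch. Everything in it is an assumption made for illustration: `ArtifactCacheSketch`, `putFile`, `readLines`, and `loadCount` are hypothetical names, the in-memory `files` map stands in for a model directory, and failures are reported as unchecked exceptions to keep the sketch short. It is not how BaseModel is implemented.

```java
import java.io.BufferedReader;
import java.io.ByteArrayInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.io.UncheckedIOException;
import java.nio.charset.StandardCharsets;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.function.Function;
import java.util.stream.Collectors;

// Hypothetical stand-in for a model's artifact cache; this is not DJL code.
class ArtifactCacheSketch {
    // Artifacts that have already been loaded, keyed by artifact name.
    private final Map<String, Object> artifacts = new ConcurrentHashMap<>();
    // In-memory "files" standing in for resources in a model directory.
    private final Map<String, byte[]> files = new ConcurrentHashMap<>();
    // Counts how often a loader function actually ran (demonstration only).
    int loadCount;

    void putFile(String name, String content) {
        files.put(name, content.getBytes(StandardCharsets.UTF_8));
    }

    // Runs the loader only when the name is not cached yet, then reuses the result.
    @SuppressWarnings("unchecked")
    <T> T getArtifact(String name, Function<InputStream, T> function) {
        return (T) artifacts.computeIfAbsent(name, key -> {
            byte[] data = files.get(key);
            if (data == null) {
                throw new UncheckedIOException(new IOException("No such artifact: " + key));
            }
            loadCount++;
            // A null result is not cached; computeIfAbsent then returns null.
            return function.apply(new ByteArrayInputStream(data));
        });
    }

    // Example loader: reads the stream as UTF-8 text, one list entry per line.
    static List<String> readLines(InputStream is) {
        BufferedReader reader =
                new BufferedReader(new InputStreamReader(is, StandardCharsets.UTF_8));
        return reader.lines().collect(Collectors.toList());
    }
}
```

Calling `getArtifact("synset.txt", ArtifactCacheSketch::readLines)` twice on the same instance runs the loader once and returns the same list object both times, which is the property the description above relies on.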
-
getArtifact
public java.net.URL getArtifact(java.lang.String artifactName) throws java.io.IOException
Finds an artifact resource with a given name in the model.
- Specified by: getArtifact in interface Model
- Parameters:
  artifactName - the name of the desired artifact
- Returns: a URL object or null if no artifact with this name is found
- Throws:
  java.io.IOException - when an I/O operation fails while loading a resource
-
getArtifactAsStream
public java.io.InputStream getArtifactAsStream(java.lang.String name) throws java.io.IOException
Finds an artifact resource with a given name in the model.
- Specified by: getArtifactAsStream in interface Model
- Parameters:
  name - the name of the desired artifact
- Returns: an InputStream object or null if no resource with this name is found
- Throws:
  java.io.IOException - when an I/O operation fails while loading a resource
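Because the returned stream may be null, callers typically check for null before reading and close the stream when done. The following is a minimal sketch using only the standard library; `ArtifactStreamSketch` and `readOrDefault` are hypothetical names introduced for illustration and are not part of DJL.

```java
import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.io.UncheckedIOException;
import java.nio.charset.StandardCharsets;
import java.util.stream.Collectors;

// Hypothetical helper for consuming a possibly-null artifact stream; not DJL code.
class ArtifactStreamSketch {
    // Reads the whole stream as UTF-8 text, or returns the fallback when the stream is null.
    static String readOrDefault(InputStream stream, String fallback) {
        if (stream == null) {
            return fallback; // no resource with the requested name was found
        }
        // try-with-resources closes the reader (and the wrapped stream) on exit.
        try (BufferedReader reader =
                new BufferedReader(new InputStreamReader(stream, StandardCharsets.UTF_8))) {
            return reader.lines().collect(Collectors.joining("\n"));
        } catch (IOException e) {
            throw new UncheckedIOException(e);
        }
    }
}
```

With a model instance at hand, the result of `getArtifactAsStream(name)` could be passed directly as the first argument.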
-
setProperty
public void setProperty(java.lang.String key, java.lang.String value)
Sets a property on the model. Properties are saved and loaded with the model, so users can store information about the model here.
- Specified by: setProperty in interface Model
- Parameters:
  key - the name of the property
  value - the value of the property
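Since property keys and values are both strings, numeric metadata has to be converted when it is stored and read back. The sketch below shows one way to do that on a plain map. `ModelPropertiesSketch`, `getIntProperty`, and the "Epoch" key are hypothetical names chosen for illustration, and the sketch assumes a missing key reads back as null, which this page does not state.

```java
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;

// Hypothetical stand-in for a model's string properties; this is not DJL code.
class ModelPropertiesSketch {
    private final Map<String, String> properties = new ConcurrentHashMap<>();

    void setProperty(String key, String value) {
        properties.put(key, value);
    }

    // Assumption of this sketch: an absent key yields null.
    String getProperty(String key) {
        return properties.get(key);
    }

    // Parses a stored value as an int, falling back to a default when the key is absent.
    int getIntProperty(String key, int defaultValue) {
        String value = getProperty(key);
        return value == null ? defaultValue : Integer.parseInt(value);
    }
}
```

For example, a training loop could store `String.valueOf(epoch)` under "Epoch" and later recover it with `getIntProperty("Epoch", 0)`.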
-
getProperty
public java.lang.String getProperty(java.lang.String key)
Returns the property of the model based on property name.
- Specified by: getProperty in interface Model
- Parameters:
  key - the name of the property
- Returns: the value of the property
-
getProperties
public java.util.Map<java.lang.String,java.lang.String> getProperties()
Returns the model's properties.
- Specified by: getProperties in interface Model
- Returns: the model's properties
-
setModelDir
protected void setModelDir(java.nio.file.Path modelDir)
-
save
public void save(java.nio.file.Path modelPath, java.lang.String newModelName) throws java.io.IOException
Saves the model to the specified modelPath with the name provided.
-
getModelPath
public java.nio.file.Path getModelPath()
Returns the directory from which the model is loaded.
- Specified by: getModelPath in interface Model
- Returns: the directory of the model location
-
toString
public java.lang.String toString()
- Overrides: toString in class java.lang.Object
-
finalize
protected void finalize() throws java.lang.Throwable
- Overrides: finalize in class java.lang.Object
- Throws:
  java.lang.Throwable
-
paramPathResolver
protected java.nio.file.Path paramPathResolver(java.lang.String prefix, java.util.Map<java.lang.String,?> options) throws java.io.IOException
- Throws:
  java.io.IOException
-
readParameters
protected boolean readParameters(java.nio.file.Path paramFile, java.util.Map<java.lang.String,?> options) throws java.io.IOException, MalformedModelException
- Throws:
  java.io.IOException
  MalformedModelException
-
-