public interface Model extends java.lang.AutoCloseable
A deep learning model usually contains the following parts:
- Block of operations to run
- Parameters that are trained

For loading a pre-trained model, see load(Path, String).
For training a model, see Trainer.
For running inference with a model, see Predictor.
Modifier and Type | Method and Description |
---|---|
void | cast(DataType dataType): Casts the model to support a different precision level. |
void | close() |
ai.djl.util.PairList<java.lang.String,Shape> | describeInput(): Returns the input descriptor of the model. |
ai.djl.util.PairList<java.lang.String,Shape> | describeOutput(): Returns the output descriptor of the model. |
java.net.URL | getArtifact(java.lang.String name): Finds an artifact resource with a given name in the model. |
<T> T | getArtifact(java.lang.String name, java.util.function.Function<java.io.InputStream,T> function): Attempts to load the artifact using the given function and cache it if the specified artifact is not already cached. |
java.io.InputStream | getArtifactAsStream(java.lang.String name): Finds an artifact resource with a given name in the model. |
java.lang.String[] | getArtifactNames(): Returns the artifact names associated with the model. |
Block | getBlock(): Gets the block from the Model. |
DataType | getDataType(): Returns the standard data type used within the model. |
java.lang.String | getName(): Gets the model name. |
NDManager | getNDManager(): Gets the NDManager from the model. |
java.lang.String | getProperty(java.lang.String key): Gets the property of the model based on property name. |
default void | load(java.nio.file.Path modelPath): Loads the model from the modelPath. |
default void | load(java.nio.file.Path modelPath, java.lang.String modelName): Loads the model from the modelPath and the given name. |
void | load(java.nio.file.Path modelPath, java.lang.String modelName, java.util.Map<java.lang.String,java.lang.String> options): Loads the model from the modelPath with the name and options provided. |
static Model | newInstance(): Creates an empty model instance. |
static Model | newInstance(Device device): Creates an empty model instance on the specified Device. |
static Model | newInstance(Device device, java.lang.String engineName): Creates an empty model instance on the specified Device and engine. |
<I,O> Predictor<I,O> | newPredictor(Translator<I,O> translator): Creates a new Predictor based on the model. |
Trainer | newTrainer(TrainingConfig trainingConfig): Creates a new Trainer instance for a Model. |
default void | quantize(): Converts the model to use a lower precision quantized network. |
void | save(java.nio.file.Path modelPath, java.lang.String modelName): Saves the model to the specified modelPath with the name provided. |
void | setBlock(Block block): Sets the block for the Model for training and inference. |
void | setDataType(DataType dataType): Sets the standard data type used within the model. |
void | setProperty(java.lang.String key, java.lang.String value): Sets a property to the model. |

static Model newInstance()
Creates an empty model instance.

static Model newInstance(Device device)
Creates an empty model instance on the specified Device.
Parameters:
device - the device to load the model onto

static Model newInstance(Device device, java.lang.String engineName)
Creates an empty model instance on the specified Device and engine.
Parameters:
device - the device to load the model onto
engineName - the name of the engine

default void load(java.nio.file.Path modelPath) throws java.io.IOException, MalformedModelException
Loads the model from the modelPath.
Parameters:
modelPath - the directory or file path of the model location
Throws:
java.io.IOException - when IO operation fails in loading a resource
MalformedModelException - if model file is corrupted

default void load(java.nio.file.Path modelPath, java.lang.String modelName) throws java.io.IOException, MalformedModelException
Loads the model from the modelPath and the given name.
Parameters:
modelPath - the directory or file path of the model location
modelName - the model file name
Throws:
java.io.IOException - when IO operation fails in loading a resource
MalformedModelException - if model file is corrupted

void load(java.nio.file.Path modelPath, java.lang.String modelName, java.util.Map<java.lang.String,java.lang.String> options) throws java.io.IOException, MalformedModelException
Loads the model from the modelPath with the name and options provided.
Parameters:
modelPath - the directory or file path of the model location
modelName - the model file name
options - engine specific load model options, see documentation for each engine
Throws:
java.io.IOException - when IO operation fails in loading a resource
MalformedModelException - if model file is corrupted

void save(java.nio.file.Path modelPath, java.lang.String modelName) throws java.io.IOException
Saves the model to the specified modelPath with the name provided.
Parameters:
modelPath - the directory or file path of the model location
modelName - the model file name
Throws:
java.io.IOException - when IO operation fails in loading a resource

void setBlock(Block block)
Sets the block for the Model for training and inference.
Parameters:
block - the Block used in Model

java.lang.String getName()
Gets the model name.

java.lang.String getProperty(java.lang.String key)
Gets the property of the model based on property name.
Parameters:
key - the name of the property

void setProperty(java.lang.String key, java.lang.String value)
Sets a property to the model.
Properties are saved and loaded with the model, so users can store information about the model here.
Parameters:
key - the name of the property
value - the value of the property

Trainer newTrainer(TrainingConfig trainingConfig)
Creates a new Trainer instance for a Model.
Parameters:
trainingConfig - training configuration settings
Returns:
the Trainer instance

<I,O> Predictor<I,O> newPredictor(Translator<I,O> translator)
Creates a new Predictor based on the model.
Type Parameters:
I - the input object for pre-processing
O - the output object from postprocessing
Parameters:
translator - the object used for pre-processing and postprocessing
Returns:
an instance of Predictor

ai.djl.util.PairList<java.lang.String,Shape> describeInput()
Returns the input descriptor of the model.
It contains the information that can be extracted from the model, usually name, shape, layout and DataType.

ai.djl.util.PairList<java.lang.String,Shape> describeOutput()
Returns the output descriptor of the model.
It contains the output information that can be obtained from the model.

java.lang.String[] getArtifactNames()
Returns the artifact names associated with the model.

<T> T getArtifact(java.lang.String name, java.util.function.Function<java.io.InputStream,T> function) throws java.io.IOException
Attempts to load the artifact using the given function and cache it if the specified artifact is not already cached.
The model caches the loaded artifact, so the user does not need to keep track of it.
String synset = model.getArtifact("synset.txt", k -> IOUtils.toString(k));
Type Parameters:
T - the type of the returned artifact object
Parameters:
name - the name of the desired artifact
function - the function to load the artifact
Throws:
java.io.IOException - when IO operation fails in loading a resource
java.lang.ClassCastException - if the cached artifact cannot be cast to the target class

java.net.URL getArtifact(java.lang.String name) throws java.io.IOException
Finds an artifact resource with a given name in the model.
Parameters:
name - the name of the desired artifact
Returns:
a URL object or null if no artifact with this name is found
Throws:
java.io.IOException - when IO operation fails in loading a resource

java.io.InputStream getArtifactAsStream(java.lang.String name) throws java.io.IOException
Finds an artifact resource with a given name in the model.
Parameters:
name - the name of the desired artifact
Returns:
an InputStream object or null if no resource with this name is found
Throws:
java.io.IOException - when IO operation fails in loading a resource

void setDataType(DataType dataType)
Sets the standard data type used within the model.
Parameters:
dataType - the standard data type to use

DataType getDataType()
Returns the standard data type used within the model.

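The caching behavior described for getArtifact(name, function) can be illustrated with a small stand-alone sketch. This is not DJL's implementation: ArtifactCacheSketch, its in-memory file map, and the loadCount counter are hypothetical stand-ins chosen so the example runs without the DJL library, and throwing IOException for a missing artifact is an assumption of the sketch.

```java
import java.io.BufferedReader;
import java.io.ByteArrayInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.nio.charset.StandardCharsets;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.function.Function;
import java.util.stream.Collectors;

// Hypothetical stand-in for a model's artifact store; not part of DJL.
class ArtifactCacheSketch {
    private final Map<String, byte[]> files;                 // artifact name -> raw bytes
    private final Map<String, Object> cache = new HashMap<>(); // artifact name -> loaded object
    int loadCount = 0; // how many times a loader function actually ran (demo only)

    ArtifactCacheSketch(Map<String, byte[]> files) {
        this.files = files;
    }

    // Loads the artifact with the given function on first use, then serves it from the cache.
    @SuppressWarnings("unchecked")
    <T> T getArtifact(String name, Function<InputStream, T> function) throws IOException {
        if (cache.containsKey(name)) {
            // Unchecked cast: because of type erasure, a mismatch surfaces as a
            // ClassCastException at the caller, not here.
            return (T) cache.get(name);
        }
        byte[] data = files.get(name);
        if (data == null) {
            // Assumption of this sketch: a missing artifact is reported as an IOException.
            throw new IOException("artifact not found: " + name);
        }
        try (InputStream is = new ByteArrayInputStream(data)) {
            T value = function.apply(is);
            loadCount++;
            cache.put(name, value);
            return value;
        }
    }

    public static void main(String[] args) throws IOException {
        Map<String, byte[]> files = new HashMap<>();
        files.put("synset.txt", "cat\ndog\n".getBytes(StandardCharsets.UTF_8));
        ArtifactCacheSketch model = new ArtifactCacheSketch(files);

        // Loader that reads the stream into a list of lines.
        Function<InputStream, List<String>> readLines = is ->
                new BufferedReader(new InputStreamReader(is, StandardCharsets.UTF_8))
                        .lines()
                        .collect(Collectors.toList());

        List<String> first = model.getArtifact("synset.txt", readLines);
        List<String> second = model.getArtifact("synset.txt", readLines);
        // The second call returns the cached object, so the loader ran only once.
        System.out.println(first + " same=" + (first == second) + " loads=" + model.loadCount);
    }
}
```

The loader function is only invoked on a cache miss, which is why callers can request the same artifact repeatedly without tracking the loaded object themselves.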
void cast(DataType dataType)
Casts the model to support a different precision level.
For example, you can cast the precision from Float to Int.
Parameters:
dataType - the target dataType you would like to cast to

default void quantize()
Converts the model to use a lower precision quantized network.
Quantization converts the network to use int8 data type where possible for smaller model size and faster computation without too large a drop in accuracy. See original paper.

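To make the int8 idea concrete, the following sketch applies a generic symmetric, per-array quantization scheme to a handful of float weights. It is an illustration only and is not claimed to be the algorithm used by quantize() or by any particular engine; the class and method names are hypothetical.

```java
import java.util.Arrays;

// Generic symmetric int8 quantization; an illustration, not DJL's or any engine's algorithm.
class Int8QuantizationSketch {

    // One scale for the whole array: the largest magnitude maps to 127.
    static float scaleFor(float[] values) {
        float maxAbs = 0f;
        for (float v : values) {
            maxAbs = Math.max(maxAbs, Math.abs(v));
        }
        return maxAbs == 0f ? 1f : maxAbs / 127f;
    }

    // Rounds each value to the nearest multiple of the scale and stores it in one byte.
    static byte[] quantize(float[] values, float scale) {
        byte[] q = new byte[values.length];
        for (int i = 0; i < values.length; i++) {
            int rounded = Math.round(values[i] / scale);
            q[i] = (byte) Math.max(-127, Math.min(127, rounded)); // clamp to the int8 range used here
        }
        return q;
    }

    // Recovers approximate float values from the int8 representation.
    static float[] dequantize(byte[] q, float scale) {
        float[] out = new float[q.length];
        for (int i = 0; i < q.length; i++) {
            out[i] = q[i] * scale;
        }
        return out;
    }

    public static void main(String[] args) {
        float[] weights = {-1.0f, -0.4f, 0.0f, 0.25f, 1.0f};
        float scale = scaleFor(weights);
        byte[] quantized = quantize(weights, scale);
        float[] restored = dequantize(quantized, scale);

        System.out.println("scale     = " + scale);
        System.out.println("int8      = " + Arrays.toString(quantized));
        System.out.println("restored  = " + Arrays.toString(restored));
        // Storage drops from 4 bytes per float to 1 byte per value.
        System.out.println("bytes     = " + (weights.length * 4) + " -> " + quantized.length);
    }
}
```

In this scheme each restored value differs from the original by at most about half the scale, which is the kind of bounded error that lets a quantized network trade a small amount of accuracy for a smaller size.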
void close()
Specified by:
close in interface java.lang.AutoCloseable
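Because Model extends java.lang.AutoCloseable, an instance can be managed with a try-with-resources statement. The sketch below shows that pattern with a hypothetical stand-in class (CloseableResourceSketch, not part of DJL) so that it runs without the library; like the close() signature above, its close() declares no checked exception.

```java
import java.util.ArrayList;
import java.util.List;

// Hypothetical stand-in for a Model-like resource; not part of DJL.
class CloseableResourceSketch implements AutoCloseable {
    static final List<String> EVENTS = new ArrayList<>(); // records lifecycle events for the demo

    private final String name;
    private boolean closed;

    CloseableResourceSketch(String name) {
        this.name = name;
        EVENTS.add("open " + name);
    }

    // Any operation on the resource is rejected once it has been closed.
    void use() {
        if (closed) {
            throw new IllegalStateException(name + " is closed");
        }
        EVENTS.add("use " + name);
    }

    boolean isClosed() {
        return closed;
    }

    // Declares no checked exception, so callers need no catch block just for close().
    @Override
    public void close() {
        if (!closed) {
            closed = true;
            EVENTS.add("close " + name);
        }
    }

    public static void main(String[] args) {
        // close() runs automatically when the try block exits, even on an exception.
        try (CloseableResourceSketch model = new CloseableResourceSketch("model")) {
            model.use();
        }
        System.out.println(EVENTS); // prints [open model, use model, close model]
    }
}
```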