Package ai.djl.pytorch.engine
Class PtModel
java.lang.Object
ai.djl.BaseModel
ai.djl.pytorch.engine.PtModel
- All Implemented Interfaces:
ai.djl.Model, AutoCloseable
public class PtModel
extends ai.djl.BaseModel
PtModel is the PyTorch implementation of Model.
PtModel implements the Model methods for loading and processing a model. In addition, it provides PyTorch-specific functionality, such as loading a model from an InputStream with an explicit mapLocation flag.
-
Field Summary
Fields inherited from class ai.djl.BaseModel
artifacts, block, dataType, inputData, manager, modelDir, modelName, properties, wasLoaded
-
Method Summary
Modifier and Type | Method | Description
String[] | getArtifactNames() |
void | load(InputStream modelStream, boolean mapLocation) | Loads a PyTorch model from an InputStream.
void | load(InputStream modelStream, Map<String, ?> options) |
void | load(Path modelPath, String prefix, Map<String, ?> options) |
ai.djl.training.Trainer | newTrainer(ai.djl.training.TrainingConfig trainingConfig) |
Methods inherited from class ai.djl.BaseModel
close, describeInput, describeOutput, finalize, getArtifact, getArtifact, getArtifactAsStream, getBlock, getDataType, getModelPath, getName, getNDManager, getProperties, getProperty, newPredictor, paramPathResolver, readParameters, save, setBlock, setDataType, setModelDir, setProperty, toString
Methods inherited from class java.lang.Object
clone, equals, getClass, hashCode, notify, notifyAll, wait, wait, wait
Methods inherited from interface ai.djl.Model
cast, getProperty, load, load, load, newPredictor, quantize
-
Method Details
-
load
public void load(Path modelPath, String prefix, Map<String, ?> options) throws IOException, ai.djl.MalformedModelException
- Throws:
IOException
ai.djl.MalformedModelException
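The following is a minimal sketch of loading a model from a directory through the Model interface, assuming DJL's `Model.newInstance(String, String)` factory and that the engine name `"PyTorch"` selects this engine so the returned Model is a PtModel. The class `PtPathLoadSketch` and its method are hypothetical names for illustration. No engine-specific options are passed because this page does not list any option keys, and the exact file layout the loader expects under `modelPath` is not described here either.

```java
import ai.djl.MalformedModelException;
import ai.djl.Model;

import java.io.IOException;
import java.nio.file.Path;
import java.util.Collections;

/** Hypothetical sketch: load a PyTorch model from a directory via the Model interface. */
public final class PtPathLoadSketch {

    private PtPathLoadSketch() {}

    /**
     * Loads the model identified by {@code prefix} from {@code modelDir}
     * and returns the names of the model's artifacts.
     */
    public static String[] loadAndListArtifacts(Path modelDir, String prefix)
            throws IOException, MalformedModelException {
        // Assumption: "PyTorch" selects the PyTorch engine, so this Model is a PtModel.
        try (Model model = Model.newInstance(prefix, "PyTorch")) {
            // Empty options map: no PyTorch-specific option keys are documented on this page.
            model.load(modelDir, prefix, Collections.<String, Object>emptyMap());
            return model.getArtifactNames();
        }
    }
}
```

Because Model is AutoCloseable, the try-with-resources block releases the native model when the method returns or when `load` throws.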
-
load
public void load(InputStream modelStream, Map<String, ?> options) throws IOException
- Specified by:
load in interface ai.djl.Model
- Overrides:
load in class ai.djl.BaseModel
- Throws:
IOException
-
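load
public void load(InputStream modelStream, boolean mapLocation) throws IOException
Loads a PyTorch model from an InputStream.
- Parameters:
modelStream - the stream of the model file
mapLocation - if true, forces the model to be loaded onto the specified device
- Throws:
IOException - if the model cannot be loaded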
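This overload is declared on PtModel rather than on Model, so callers holding a `Model` need a cast to reach it. A minimal sketch, assuming `Model.newInstance(name, "PyTorch")` returns a PtModel; the class `PtStreamLoadSketch`, its method, and the model name `"streamed"` are hypothetical. The caller opens and closes the stream itself.

```java
import ai.djl.Model;
import ai.djl.pytorch.engine.PtModel;

import java.io.IOException;
import java.io.InputStream;
import java.nio.file.Files;
import java.nio.file.Path;

/** Hypothetical sketch: load a PyTorch model file through an InputStream. */
public final class PtStreamLoadSketch {

    private PtStreamLoadSketch() {}

    /** Loads the model in {@code modelFile} and returns the model's name. */
    public static String loadFromStream(Path modelFile, boolean mapLocation) throws IOException {
        // Resources open in order: a missing file fails here, before any model is created.
        try (InputStream is = Files.newInputStream(modelFile);
             Model model = Model.newInstance("streamed", "PyTorch")) {
            // Assumption: the PyTorch engine returns a PtModel, so the cast exposes
            // the PyTorch-specific load(InputStream, boolean) overload.
            ((PtModel) model).load(is, mapLocation);
            return model.getName();
        }
    }
}
```

Passing `true` for `mapLocation` requests the device-forcing behavior described in the parameter list above.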
-
newTrainer
public ai.djl.training.Trainer newTrainer(ai.djl.training.TrainingConfig trainingConfig)
- Specified by:
newTrainer in interface ai.djl.Model
- Overrides:
newTrainer in class ai.djl.BaseModel
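A minimal sketch of creating a trainer, assuming DJL's training API (`DefaultTrainingConfig`, `Loss.l2Loss()`, `Linear.builder().setUnits(...)`) and an available PyTorch engine. The class `PtTrainerSketch` and its method are hypothetical. The sketch sets a block on the model first, since a trainer has nothing to train without one. It then initializes the parameters for a single input of shape `(1, inputFeatures)` and reports the block's output shape.

```java
import ai.djl.Model;
import ai.djl.ndarray.types.Shape;
import ai.djl.nn.Block;
import ai.djl.nn.core.Linear;
import ai.djl.training.DefaultTrainingConfig;
import ai.djl.training.Trainer;
import ai.djl.training.loss.Loss;

/** Hypothetical sketch: create a Trainer from a PyTorch-backed Model. */
public final class PtTrainerSketch {

    private PtTrainerSketch() {}

    /** Builds a one-unit dense layer, initializes it, and returns its output shape. */
    public static Shape outputShape(long inputFeatures) {
        try (Model model = Model.newInstance("linear", "PyTorch")) {
            // A single dense layer with one output unit keeps the example small.
            Block block = Linear.builder().setUnits(1).build();
            model.setBlock(block);

            // L2 loss with the config's default initializer.
            DefaultTrainingConfig config = new DefaultTrainingConfig(Loss.l2Loss());
            try (Trainer trainer = model.newTrainer(config)) {
                Shape input = new Shape(1, inputFeatures);
                trainer.initialize(input); // allocates and initializes parameters
                return block.getOutputShapes(new Shape[] {input})[0];
            }
        }
    }
}
```

For an input of shape `(1, 4)`, a one-unit dense layer produces an output of shape `(1, 1)`.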
-
getArtifactNames
public String[] getArtifactNames()
- Specified by:
getArtifactNames in interface ai.djl.Model
- Overrides:
getArtifactNames in class ai.djl.BaseModel
-