Package ai.djl.pytorch.engine
Class PtModel
- java.lang.Object
  - ai.djl.BaseModel
    - ai.djl.pytorch.engine.PtModel
All Implemented Interfaces:
ai.djl.Model, java.lang.AutoCloseable
public class PtModel extends ai.djl.BaseModel
PtModel is the PyTorch implementation of Model. PtModel contains all the methods in Model to load and process a model. In addition, it provides PyTorch-specific functionality.
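Because PtModel implements java.lang.AutoCloseable, a caller can scope it in a try-with-resources block so that close() runs when the block exits, even if inference throws. The sketch below is a minimal illustration under that assumption: it uses a hypothetical NativeBackedModel class instead of DJL so that it runs on the JDK alone, and it only demonstrates the close-on-exit behavior, not what PtModel.close() actually releases.

```java
import java.util.ArrayList;
import java.util.List;

final class ModelLifecycleSketch {

    // Hypothetical stand-in for a model that owns resources.
    // It is NOT ai.djl.pytorch.engine.PtModel; it only mirrors the fact that
    // PtModel implements java.lang.AutoCloseable.
    static final class NativeBackedModel implements AutoCloseable {
        private final List<String> events;
        private boolean closed;

        NativeBackedModel(List<String> events) {
            this.events = events;
            events.add("opened");
        }

        void predict() {
            if (closed) {
                throw new IllegalStateException("model already closed");
            }
            events.add("predict");
        }

        // Narrows AutoCloseable.close() so callers need not catch Exception.
        // Idempotent: a second close() call does nothing.
        @Override
        public void close() {
            if (!closed) {
                closed = true;
                events.add("closed");
            }
        }
    }

    // Uses the model inside try-with-resources so close() runs even if predict() throws.
    static List<String> runOnce() {
        List<String> events = new ArrayList<>();
        try (NativeBackedModel model = new NativeBackedModel(events)) {
            model.predict();
        }
        return events;
    }

    public static void main(String[] args) {
        System.out.println(runOnce()); // prints [opened, predict, closed]
    }
}
```

With DJL on the classpath, the same shape would roughly be `try (Model model = Model.newInstance("my-model")) { ... }`, where "my-model" is a placeholder name.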
Method Summary
Modifier and Type | Method | Description
java.lang.String[] | getArtifactNames() |
void | load(java.io.InputStream modelStream, boolean mapLocation) | Load PyTorch model from InputStream.
void | load(java.io.InputStream modelStream, java.util.Map<java.lang.String,?> options) |
void | load(java.nio.file.Path modelPath, java.lang.String prefix, java.util.Map<java.lang.String,?> options) |
ai.djl.training.Trainer | newTrainer(ai.djl.training.TrainingConfig trainingConfig) |
Methods inherited from class ai.djl.BaseModel
close, describeInput, describeOutput, finalize, getArtifact, getArtifact, getArtifactAsStream, getBlock, getDataType, getModelPath, getName, getNDManager, getProperties, getProperty, newPredictor, paramPathResolver, readParameters, save, setBlock, setDataType, setModelDir, setProperty, toString
Method Detail
load
public void load(java.nio.file.Path modelPath, java.lang.String prefix, java.util.Map<java.lang.String,?> options) throws java.io.IOException, ai.djl.MalformedModelException
- Throws:
java.io.IOException
ai.djl.MalformedModelException
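The options parameter is typed java.util.Map<java.lang.String,?>, so a caller may pass a map with any value type (Map<String, Object>, Map<String, String>, and so on). This page does not list which keys PtModel reads. The sketch below assumes a boolean flag named "mapLocation" (borrowed from the parameter name of the InputStream overload) and shows one defensive way to read such a flag whether it arrives as a Boolean or as a String. booleanOption is a hypothetical helper, not part of DJL.

```java
import java.util.HashMap;
import java.util.Map;

final class OptionsSketch {

    // Hypothetical helper: reads a boolean flag from an options map typed like the
    // `options` parameter of PtModel.load, i.e. Map<String, ?>. The value may be a
    // Boolean or a String such as "true"; anything else falls back to the default.
    static boolean booleanOption(Map<String, ?> options, String key, boolean defaultValue) {
        if (options == null) {
            return defaultValue;
        }
        Object value = options.get(key);
        if (value instanceof Boolean) {
            return (Boolean) value;
        }
        if (value instanceof String) {
            return Boolean.parseBoolean((String) value);
        }
        return defaultValue;
    }

    public static void main(String[] args) {
        // Map<String, ?> accepts maps with any value type.
        Map<String, Object> mixed = new HashMap<>();
        mixed.put("mapLocation", Boolean.TRUE); // "mapLocation" key is an assumption
        Map<String, String> strings = Map.of("mapLocation", "false");

        System.out.println(booleanOption(mixed, "mapLocation", false));  // true
        System.out.println(booleanOption(strings, "mapLocation", true)); // false
        System.out.println(booleanOption(null, "mapLocation", true));    // true
    }
}
```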
load
public void load(java.io.InputStream modelStream, java.util.Map<java.lang.String,?> options) throws java.io.IOException
- Specified by:
load in interface ai.djl.Model
- Overrides:
load in class ai.djl.BaseModel
- Throws:
java.io.IOException
load
public void load(java.io.InputStream modelStream, boolean mapLocation) throws java.io.IOException
Load PyTorch model from InputStream.
- Parameters:
modelStream - the stream of the model file
mapLocation - if true, force the model to load onto the specified device
- Throws:
java.io.IOException - model loading error
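The InputStream overload lets a model be loaded from any byte source rather than only from a file path. This page does not state whether load() closes the stream, so the sketch below assumes the caller owns it and closes it with try-with-resources. StreamLoader is a hypothetical interface with the same parameter list as load(java.io.InputStream, boolean), and the lambda in main stands in for a real model.

```java
import java.io.IOException;
import java.io.InputStream;
import java.nio.file.Files;
import java.nio.file.Path;

final class StreamLoadSketch {

    // Hypothetical interface mirroring the shape of
    // PtModel.load(java.io.InputStream modelStream, boolean mapLocation).
    @FunctionalInterface
    interface StreamLoader {
        void load(InputStream modelStream, boolean mapLocation) throws IOException;
    }

    // Opens the model file, hands the stream to the loader, and closes the stream
    // afterwards. That the caller owns the stream is an assumption; the Javadoc
    // does not say whether load() closes it.
    static void loadFromFile(Path modelFile, boolean mapLocation, StreamLoader loader)
            throws IOException {
        try (InputStream in = Files.newInputStream(modelFile)) {
            loader.load(in, mapLocation);
        }
    }

    public static void main(String[] args) throws IOException {
        Path tmp = Files.createTempFile("model", ".pt");
        try {
            Files.write(tmp, new byte[] {1, 2, 3, 4});
            long[] bytesSeen = new long[1];
            // Stand-in loader: just counts the bytes it receives.
            loadFromFile(tmp, true, (stream, mapLoc) -> bytesSeen[0] = stream.readAllBytes().length);
            System.out.println(bytesSeen[0]); // prints 4
        } finally {
            Files.delete(tmp);
        }
    }
}
```

Assuming ptModel is a PtModel instance, the method reference ptModel::load should fit StreamLoader, because load(java.io.InputStream modelStream, boolean mapLocation) throws java.io.IOException has the same shape.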
newTrainer
public ai.djl.training.Trainer newTrainer(ai.djl.training.TrainingConfig trainingConfig)
- Specified by:
newTrainer in interface ai.djl.Model
- Overrides:
newTrainer in class ai.djl.BaseModel
getArtifactNames
public java.lang.String[] getArtifactNames()
- Specified by:
getArtifactNames in interface ai.djl.Model
- Overrides:
getArtifactNames in class ai.djl.BaseModel
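getArtifactNames() returns the model's artifact names as a java.lang.String[]. This page does not say how the names are collected, so the following is only a hypothetical illustration: it walks a model directory with java.nio.file.Files.walk and reports each regular file relative to that directory, sorted, with '/' as the separator. The helper artifactNames and the file names in main are invented for the example and are not DJL's implementation.

```java
import java.io.IOException;
import java.nio.file.Files;
import java.nio.file.Path;
import java.util.stream.Stream;

final class ArtifactNamesSketch {

    // Hypothetical: list every regular file under modelDir as a path relative to
    // modelDir, sorted, using '/' as the separator on every platform.
    static String[] artifactNames(Path modelDir) throws IOException {
        // Files.walk keeps directory handles open, so close the stream when done.
        try (Stream<Path> paths = Files.walk(modelDir)) {
            return paths.filter(Files::isRegularFile)
                    .map(p -> modelDir.relativize(p).toString().replace('\\', '/'))
                    .sorted()
                    .toArray(String[]::new);
        }
    }

    public static void main(String[] args) throws IOException {
        Path dir = Files.createTempDirectory("model");
        Files.createDirectories(dir.resolve("extra"));
        Files.write(dir.resolve("resnet.pt"), new byte[] {0});
        Files.write(dir.resolve("extra").resolve("synset.txt"), new byte[] {0});
        System.out.println(String.join(", ", artifactNames(dir))); // prints extra/synset.txt, resnet.pt
    }
}
```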