Package ai.djl.pytorch.engine
Class PtModel
- java.lang.Object
  - ai.djl.BaseModel
    - ai.djl.pytorch.engine.PtModel
- All Implemented Interfaces:
ai.djl.Model, java.lang.AutoCloseable
public class PtModel extends ai.djl.BaseModel

PtModel is the PyTorch implementation of Model. It contains all the methods in Model for loading and processing a model. In addition, it provides PyTorch-specific functionality, such as loading a model directly from an InputStream with a mapLocation flag.
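Because PtModel implements java.lang.AutoCloseable, it is usually opened in a try-with-resources statement so that its resources are released when the block ends. The sketch below uses hypothetical stand-in classes (FakeModel, FakeTrainer), not DJL types, to illustrate one property of that pattern: resources declared in the same try are closed in reverse order, so an object created from the model (such as a trainer) is closed before the model itself.

```java
import java.util.ArrayList;
import java.util.List;

public class CloseOrderSketch {

    // Records the order in which the stand-in resources are closed.
    static final List<String> CLOSED = new ArrayList<>();

    // Hypothetical stand-in for a closeable model such as PtModel (not a DJL type).
    static class FakeModel implements AutoCloseable {
        FakeTrainer newTrainer() {
            return new FakeTrainer();
        }

        @Override
        public void close() {
            CLOSED.add("model");
        }
    }

    // Hypothetical stand-in for a trainer created from the model.
    static class FakeTrainer implements AutoCloseable {
        @Override
        public void close() {
            CLOSED.add("trainer");
        }
    }

    static List<String> run() {
        CLOSED.clear();
        // Resources declared later in the header are closed first.
        try (FakeModel model = new FakeModel();
                FakeTrainer trainer = model.newTrainer()) {
            // Training or inference would happen here.
        }
        return new ArrayList<>(CLOSED);
    }

    public static void main(String[] args) {
        System.out.println(run()); // prints [trainer, model]
    }
}
```

Declaring the model first and the derived object second means the derived object never outlives the model it depends on.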
Method Summary
- java.lang.String[] getArtifactNames()
- void load(java.io.InputStream modelStream, boolean mapLocation): Load PyTorch model from InputStream.
- void load(java.io.InputStream modelStream, java.util.Map<java.lang.String,?> options)
- void load(java.nio.file.Path modelPath, java.lang.String prefix, java.util.Map<java.lang.String,?> options)
- ai.djl.training.Trainer newTrainer(ai.djl.training.TrainingConfig trainingConfig)
Methods inherited from class ai.djl.BaseModel
close, describeInput, describeOutput, finalize, getArtifact, getArtifact, getArtifactAsStream, getBlock, getDataType, getModelPath, getName, getNDManager, getProperty, newPredictor, paramPathResolver, readParameters, save, setBlock, setDataType, setModelDir, setProperty, toString
Method Detail
load
public void load(java.nio.file.Path modelPath, java.lang.String prefix, java.util.Map<java.lang.String,?> options) throws java.io.IOException, ai.djl.MalformedModelException
- Throws:
  java.io.IOException
  ai.djl.MalformedModelException
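This overload takes a model directory, a prefix and an options map, but this page does not document how the prefix is used. A common DJL convention for PyTorch (TorchScript) models is a file named `<prefix>.pt` inside the model directory. Treating that convention as an assumption, the hypothetical helper findModelFile below shows how a caller could check for such a file with java.nio.file before calling load; it is not PtModel's own lookup logic.

```java
import java.io.IOException;
import java.nio.file.Files;
import java.nio.file.Path;
import java.util.Optional;

public class ModelFileLookup {

    // Hypothetical helper: ASSUMES the model is stored as "<prefix>.pt"
    // inside modelDir. PtModel's real resolution rules may differ.
    static Optional<Path> findModelFile(Path modelDir, String prefix) {
        Path candidate = modelDir.resolve(prefix + ".pt");
        return Files.isRegularFile(candidate) ? Optional.of(candidate) : Optional.empty();
    }

    public static void main(String[] args) throws IOException {
        // Build a throwaway directory containing one model file.
        Path dir = Files.createTempDirectory("model");
        Files.createFile(dir.resolve("resnet18.pt"));

        System.out.println(findModelFile(dir, "resnet18").isPresent()); // prints true
        System.out.println(findModelFile(dir, "missing").isPresent()); // prints false
    }
}
```

Because the options parameter is declared as `Map<String,?>`, any map with String keys, such as a `Map<String, String>` or `Map<String, Object>`, can be passed to it without casting.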

load
public void load(java.io.InputStream modelStream, java.util.Map<java.lang.String,?> options) throws java.io.IOException
- Specified by:
  load in interface ai.djl.Model
- Overrides:
  load in class ai.djl.BaseModel
- Throws:
  java.io.IOException

load
public void load(java.io.InputStream modelStream, boolean mapLocation) throws java.io.IOException
Loads a PyTorch model from an InputStream.
- Parameters:
  modelStream - the stream of the model file
  mapLocation - if true, forces the model to be loaded onto the specified device
- Throws:
  java.io.IOException - if an error occurs while loading the model
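This page does not state whether load closes the stream it receives, so a caller that opens the stream should close it after loading. The sketch below uses a hypothetical StreamLoader interface with the same shape as load(InputStream, boolean) so that it runs without DJL; with a real PtModel, a method reference such as model::load could be passed in its place, assuming the model instance was created beforehand.

```java
import java.io.IOException;
import java.io.InputStream;
import java.nio.file.Files;
import java.nio.file.Path;

public class StreamLoadSketch {

    // Hypothetical interface mirroring load(InputStream modelStream, boolean mapLocation).
    @FunctionalInterface
    interface StreamLoader {
        void load(InputStream modelStream, boolean mapLocation) throws IOException;
    }

    // Opens the model file, hands the stream to the loader, and always closes it,
    // even if loading throws.
    static void loadFromFile(Path modelFile, StreamLoader loader, boolean mapLocation)
            throws IOException {
        try (InputStream is = Files.newInputStream(modelFile)) {
            loader.load(is, mapLocation);
        }
    }

    public static void main(String[] args) throws IOException {
        Path file = Files.createTempFile("model", ".pt");
        Files.write(file, new byte[] {1, 2, 3});
        // Stand-in loader that only reports how many bytes it received.
        loadFromFile(file, (in, map) -> System.out.println(in.readAllBytes().length), true); // prints 3
    }
}
```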

newTrainer
public ai.djl.training.Trainer newTrainer(ai.djl.training.TrainingConfig trainingConfig)
- Specified by:
  newTrainer in interface ai.djl.Model
- Overrides:
  newTrainer in class ai.djl.BaseModel

getArtifactNames
public java.lang.String[] getArtifactNames()
- Specified by:
  getArtifactNames in interface ai.djl.Model
- Overrides:
  getArtifactNames in class ai.djl.BaseModel