public class PtModel
extends ai.djl.BaseModel
`PtModel` is the PyTorch implementation of `Model`.

`PtModel` contains all the methods in `Model` to load and process a model. In addition, it provides PyTorch-specific functionality.
| Modifier and Type | Method and Description |
|---|---|
| `void` | `cast(ai.djl.ndarray.types.DataType dataType)` |
| `void` | `close()` |
| `java.lang.String[]` | `getArtifactNames()` |
| `void` | `load(java.nio.file.Path modelPath, java.lang.String prefix, java.util.Map<java.lang.String,java.lang.Object> options)` |
| `<I,O> ai.djl.inference.Predictor<I,O>` | `newPredictor(ai.djl.translate.Translator<I,O> translator)` |
| `ai.djl.training.Trainer` | `newTrainer(ai.djl.training.TrainingConfig trainingConfig)` |
Methods inherited from class `ai.djl.BaseModel`: describeInput, describeOutput, finalize, getArtifact, getArtifact, getArtifactAsStream, getBlock, getDataType, getName, getNDManager, getProperty, paramPathResolver, readParameters, save, setBlock, setDataType, setModelDir, setProperty
public void load(java.nio.file.Path modelPath, java.lang.String prefix, java.util.Map<java.lang.String,java.lang.Object> options) throws java.io.IOException, ai.djl.MalformedModelException

Throws:
- java.io.IOException
- ai.djl.MalformedModelException
public ai.djl.training.Trainer newTrainer(ai.djl.training.TrainingConfig trainingConfig)
public <I,O> ai.djl.inference.Predictor<I,O> newPredictor(ai.djl.translate.Translator<I,O> translator)
public java.lang.String[] getArtifactNames()
public void cast(ai.djl.ndarray.types.DataType dataType)
public void close()