Package ai.djl.training
Class Trainer
- java.lang.Object
  - ai.djl.training.Trainer
- All Implemented Interfaces:
  java.lang.AutoCloseable

public class Trainer extends java.lang.Object implements java.lang.AutoCloseable

The Trainer class provides a session for model training. Trainer provides an easy and manageable interface for training. Trainer is not thread-safe.

See the tutorials on:

See Also:
- The guide on memory management

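Trainer implements java.lang.AutoCloseable, so a try-with-resources block is a natural way to make sure its close() method runs. The sketch below shows that pattern in plain Java. It is not DJL code: ToySession, train and EVENTS are hypothetical names used only for illustration.

```java
import java.util.ArrayList;
import java.util.List;

/** Plain-Java sketch of the try-with-resources pattern; ToySession is hypothetical, not DJL's Trainer. */
public class SessionLifecycleSketch {

    static final List<String> EVENTS = new ArrayList<>();

    /** Hypothetical resource-owning session that must be closed after use. */
    static class ToySession implements AutoCloseable {
        private boolean closed = false;

        void train(int epochs) {
            if (closed) {
                throw new IllegalStateException("session already closed");
            }
            for (int epoch = 0; epoch < epochs; epoch++) {
                EVENTS.add("epoch " + epoch);
            }
        }

        boolean isClosed() {
            return closed;
        }

        @Override
        public void close() {
            // Idempotent: a second close() call does nothing.
            if (!closed) {
                closed = true;
                EVENTS.add("closed");
            }
        }
    }

    public static void main(String[] args) {
        // close() runs automatically when the block exits, even if train() throws.
        try (ToySession session = new ToySession()) {
            session.train(2);
        }
        System.out.println(EVENTS); // [epoch 0, epoch 1, closed]
    }
}
```

The sketch adds no synchronization. Because Trainer is documented as not thread-safe, sharing one instance across threads would need external coordination.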
Constructor Summary
- Trainer(Model model, TrainingConfig trainingConfig)

Method Summary
- void addMetric(java.lang.String metricName, long begin): Helper to add a metric for a time difference.
- void close()
- NDList evaluate(NDList input): Evaluates the function of the model once on the given input NDList.
- protected void finalize()
- NDList forward(NDList input): Applies the forward function of the model once on the given input NDList.
- NDList forward(NDList data, NDList labels): Applies the forward function of the model once with both data and labels.
- Device[] getDevices(): Returns the devices used for training.
- java.util.List<Evaluator> getEvaluators(): Gets all Evaluators.
- java.util.Optional<java.util.concurrent.ExecutorService> getExecutorService(): Returns the ExecutorService.
- Loss getLoss(): Gets the training Loss function of the trainer.
- NDManager getManager(): Gets the NDManager from the model.
- Metrics getMetrics(): Returns the Metrics param used for benchmarking.
- Model getModel(): Returns the model used to create this trainer.
- TrainingResult getTrainingResult(): Returns the TrainingResult.
- void initialize(Shape... shapes): Initializes the Model that the Trainer is going to train.
- java.lang.Iterable<Batch> iterateDataset(Dataset dataset): Fetches an iterator that can iterate through the given Dataset.
- GradientCollector newGradientCollector(): Returns a new instance of GradientCollector.
- void notifyListeners(java.util.function.Consumer<TrainingListener> listenerConsumer): Executes a method on each of the TrainingListeners.
- void setMetrics(Metrics metrics): Attaches a Metrics param to use for benchmarking.
- void step(): Updates all of the parameters of the model once.

Constructor Detail

Trainer
public Trainer(Model model, TrainingConfig trainingConfig)
- Parameters:
  model - the model the trainer will train on
  trainingConfig - the configuration used by the trainer

Method Detail

initialize
public void initialize(Shape... shapes)
Initializes the Model that the Trainer is going to train.
- Parameters:
  shapes - an array of Shapes of the inputs

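A trainer may ask for input shapes up front because parameter sizes often depend on them. The plain-Java sketch below assumes that situation: a hypothetical ToyDense layer cannot allocate its weight matrix until it learns the input feature count. It only illustrates the idea and does not reproduce what DJL's initialize does internally.

```java
/** Plain-Java sketch of shape-driven parameter allocation; ToyDense is hypothetical. */
public class ShapeInitSketch {

    /** Hypothetical dense layer whose weight size depends on the input shape. */
    static class ToyDense {
        final int units;
        double[][] weight; // stays null until initialize() is called

        ToyDense(int units) {
            this.units = units;
        }

        // Assumes the last dimension of the input shape is the number of features.
        void initialize(long... inputShape) {
            int inFeatures = (int) inputShape[inputShape.length - 1];
            weight = new double[units][inFeatures];
        }
    }

    public static void main(String[] args) {
        ToyDense layer = new ToyDense(10);
        layer.initialize(32, 784); // batch size 32, 784 features per sample
        System.out.println(layer.weight.length + " x " + layer.weight[0].length); // 10 x 784
    }
}
```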
iterateDataset
public java.lang.Iterable<Batch> iterateDataset(Dataset dataset) throws java.io.IOException, TranslateException
Fetches an iterator that can iterate through the given Dataset.
- Parameters:
  dataset - the dataset to iterate through
- Returns:
  an Iterable of Batch that contains batches of data from the dataset
- Throws:
  java.io.IOException - for various exceptions depending on the dataset
  TranslateException - if there is an error while processing input

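iterateDataset returns an Iterable, so the result can be consumed with an enhanced for loop. To illustrate that contract, the sketch below builds a hypothetical BatchIterable in plain Java that splits a list of integers into fixed-size batches. It stands in for Dataset and Batch and is not how DJL implements them.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Iterator;
import java.util.List;
import java.util.NoSuchElementException;

/** Plain-Java sketch of an Iterable that yields fixed-size batches; not DJL's Dataset/Batch. */
public class BatchIterationSketch {

    /** Hypothetical stand-in for the Iterable<Batch> returned by iterateDataset. */
    static class BatchIterable implements Iterable<List<Integer>> {
        private final List<Integer> samples;
        private final int batchSize;

        BatchIterable(List<Integer> samples, int batchSize) {
            this.samples = samples;
            this.batchSize = batchSize;
        }

        @Override
        public Iterator<List<Integer>> iterator() {
            // Each call returns a fresh iterator, so the Iterable can be looped over repeatedly.
            return new Iterator<List<Integer>>() {
                private int cursor = 0;

                @Override
                public boolean hasNext() {
                    return cursor < samples.size();
                }

                @Override
                public List<Integer> next() {
                    if (!hasNext()) {
                        throw new NoSuchElementException();
                    }
                    int end = Math.min(cursor + batchSize, samples.size());
                    List<Integer> batch = new ArrayList<>(samples.subList(cursor, end));
                    cursor = end;
                    return batch;
                }
            };
        }
    }

    public static void main(String[] args) {
        List<Integer> samples = Arrays.asList(0, 1, 2, 3, 4, 5, 6);
        for (List<Integer> batch : new BatchIterable(samples, 3)) {
            System.out.println(batch); // [0, 1, 2], then [3, 4, 5], then [6]
        }
    }
}
```

This sketch keeps the final partial batch ([6] for seven samples with a batch size of 3). Whether a real dataset keeps or drops such a batch may depend on how it is configured.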
newGradientCollector
public GradientCollector newGradientCollector()
Returns a new instance of GradientCollector.
- Returns:
  a new instance of GradientCollector

forward
public NDList forward(NDList input)
Applies the forward function of the model once on the given input NDList.
- Parameters:
  input - the input NDList
- Returns:
  the output of the forward function

forward
public NDList forward(NDList data, NDList labels)
Applies the forward function of the model once with both data and labels.

evaluate
public NDList evaluate(NDList input)
Evaluates the function of the model once on the given input NDList.
- Parameters:
  input - the input NDList
- Returns:
  the output of the predict function

step
public void step()
Updates all of the parameters of the model once.

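step() applies one parameter update. The plain-Java sketch below shows where that update sits in a training iteration. It runs the forward, gradient and step sequence for a one-parameter model y = w * x, trained with mean squared error and plain gradient descent. The names forward, gradient, step and LEARNING_RATE are hypothetical. In DJL, gradients are recorded through a GradientCollector, and this sketch does not model the optimizer that performs the real update.

```java
/** Plain-Java sketch of a forward -> gradient -> step loop for y = w * x (not DJL code). */
public class StepSketch {

    static double w = 0.0;                     // the single trainable parameter
    static final double LEARNING_RATE = 0.05;  // hypothetical step size

    // Forward pass: apply the model to every input.
    static double[] forward(double[] xs) {
        double[] preds = new double[xs.length];
        for (int i = 0; i < xs.length; i++) {
            preds[i] = w * xs[i];
        }
        return preds;
    }

    // Gradient of the mean squared error (1/n) * sum((w*x - y)^2) with respect to w.
    static double gradient(double[] xs, double[] preds, double[] ys) {
        double sum = 0.0;
        for (int i = 0; i < xs.length; i++) {
            sum += 2.0 * (preds[i] - ys[i]) * xs[i];
        }
        return sum / xs.length;
    }

    // One update of every parameter (here only w), using plain gradient descent.
    static void step(double grad) {
        w -= LEARNING_RATE * grad;
    }

    public static void main(String[] args) {
        double[] xs = {1, 2, 3, 4};
        double[] ys = {2, 4, 6, 8}; // generated with w = 2
        for (int iteration = 0; iteration < 100; iteration++) {
            double[] preds = forward(xs);
            step(gradient(xs, preds, ys));
        }
        System.out.printf("w = %.4f%n", w);
    }
}
```

With this data the gradient works out to 15 * (w - 2). Each update therefore shrinks the error w - 2 by a factor of 4, and w converges to 2 well within the 100 iterations.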
getMetrics
public Metrics getMetrics()
Returns the Metrics param used for benchmarking.
- Returns:
  the Metrics param used for benchmarking

setMetrics
public void setMetrics(Metrics metrics)
Attaches a Metrics param to use for benchmarking.
- Parameters:
  metrics - the Metrics instance

getDevices
public Device[] getDevices()
Returns the devices used for training.
- Returns:
  the devices used for training

getLoss
public Loss getLoss()
Gets the training Loss function of the trainer.
- Returns:
  the Loss function

getModel
public Model getModel()
Returns the model used to create this trainer.
- Returns:
  the model associated with this trainer

getExecutorService
public java.util.Optional<java.util.concurrent.ExecutorService> getExecutorService()
Returns the ExecutorService.
- Returns:
  the ExecutorService

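getExecutorService returns an Optional, which suggests the executor may be absent, so callers have to handle both cases. The plain-Java sketch below shows one way to do that with Optional.ifPresentOrElse (Java 9 and later). It submits the task when an executor is present and otherwise runs it on the calling thread. runMaybeAsync and LOG are hypothetical names, not part of DJL.

```java
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Optional;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.TimeUnit;

/** Plain-Java sketch of handling an Optional<ExecutorService>; runMaybeAsync is hypothetical. */
public class ExecutorOptionalSketch {

    // Synchronized because the pooled task writes to it from another thread.
    static final List<String> LOG = Collections.synchronizedList(new ArrayList<>());

    // Submit the task if an executor is configured; otherwise run it on the calling thread.
    static void runMaybeAsync(Optional<ExecutorService> executor, Runnable task) {
        executor.ifPresentOrElse(service -> service.submit(task), task);
    }

    public static void main(String[] args) throws InterruptedException {
        runMaybeAsync(Optional.empty(), () -> LOG.add("ran inline"));

        ExecutorService pool = Executors.newFixedThreadPool(2);
        runMaybeAsync(Optional.of(pool), () -> LOG.add("ran on pool"));
        pool.shutdown();                            // accept no new tasks
        pool.awaitTermination(5, TimeUnit.SECONDS); // wait for the submitted task to finish
        System.out.println(LOG);
    }
}
```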
getEvaluators
public java.util.List<Evaluator> getEvaluators()
Gets all Evaluators.
- Returns:
  the evaluators used during training

notifyListeners
public void notifyListeners(java.util.function.Consumer<TrainingListener> listenerConsumer)
Executes a method on each of the TrainingListeners.
- Parameters:
  listenerConsumer - a consumer that executes the method

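notifyListeners takes a Consumer, so the caller decides which listener callback to invoke and the trainer applies it to every registered listener. The sketch below reproduces that shape in plain Java. A hypothetical ToyListener interface stands in for TrainingListener.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.function.Consumer;

/** Plain-Java sketch of the notifyListeners pattern; ToyListener is hypothetical, not TrainingListener. */
public class NotifyListenersSketch {

    /** Hypothetical listener with more than one callback. */
    interface ToyListener {
        void onEpoch(int epoch);

        void onBatch(int batch);
    }

    static final List<String> LOG = new ArrayList<>();
    static final List<ToyListener> LISTENERS = new ArrayList<>();

    // Same shape as notifyListeners: apply one caller-chosen callback to every listener.
    static void notifyListeners(Consumer<ToyListener> listenerConsumer) {
        LISTENERS.forEach(listenerConsumer);
    }

    public static void main(String[] args) {
        for (String name : new String[] {"logger", "checkpointer"}) {
            LISTENERS.add(new ToyListener() {
                @Override
                public void onEpoch(int epoch) {
                    LOG.add(name + ": epoch " + epoch);
                }

                @Override
                public void onBatch(int batch) {
                    LOG.add(name + ": batch " + batch);
                }
            });
        }
        // The lambda selects which callback runs; here only onEpoch is invoked.
        notifyListeners(listener -> listener.onEpoch(0));
        System.out.println(LOG); // [logger: epoch 0, checkpointer: epoch 0]
    }
}
```

In DJL, a call would pass a lambda such as listener -> listener.onEpoch(trainer). That callback name is an assumption about the TrainingListener interface and is not documented on this page.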
getTrainingResult
public TrainingResult getTrainingResult()
Returns the TrainingResult.
- Returns:
  the TrainingResult

finalize
protected void finalize() throws java.lang.Throwable
- Overrides:
  finalize in class java.lang.Object
- Throws:
  java.lang.Throwable

close
public void close()
- Specified by:
  close in interface java.lang.AutoCloseable

addMetric
public void addMetric(java.lang.String metricName, long begin)
Helper to add a metric for a time difference.
- Parameters:
  metricName - the metric name
  begin - the start of the time difference (this method is called at the end of the time difference)
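addMetric is documented as being called at the end of a timed interval, with the interval's start time passed in. The plain-Java sketch below follows that contract, assuming begin comes from System.nanoTime(). This page does not state which clock or unit DJL uses, and METRICS is a hypothetical map rather than the Metrics class.

```java
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

/** Plain-Java sketch of recording a time difference; METRICS is hypothetical, not DJL's Metrics. */
public class TimeMetricSketch {

    // Metric name -> recorded durations in nanoseconds.
    static final Map<String, List<Long>> METRICS = new HashMap<>();

    // Called at the end of the interval; `begin` is assumed to come from System.nanoTime().
    static void addMetric(String metricName, long begin) {
        long end = System.nanoTime();
        METRICS.computeIfAbsent(metricName, k -> new ArrayList<>()).add(end - begin);
    }

    public static void main(String[] args) throws InterruptedException {
        long begin = System.nanoTime();
        Thread.sleep(10); // stand-in for the work being timed
        addMetric("forward", begin);
        System.out.println("forward took " + METRICS.get("forward").get(0) / 1_000_000 + " ms");
    }
}
```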