Package ai.djl.training
Class Trainer
java.lang.Object
ai.djl.training.Trainer
- All Implemented Interfaces:
AutoCloseable
The Trainer class provides a session for model training through an easy, manageable interface. Trainer is not thread-safe.
-
Constructor Summary
Constructors
-
Method Summary
Modifier and Type | Method | Description
void | addMetric | Helper to add a metric for a time difference.
void | close() |
 | evaluate(NDList input) | Evaluates the function of the model once on the given input NDList.
protected void | finalize() |
 | forward(NDList input) | Applies the forward function of the model once on the given input NDList.
 | forward | Applies the forward function of the model once with both data and labels.
Device[] | getDevices() | Returns the devices used for training.
 | getEvaluators() | Gets all Evaluators.
 | getExecutorService() | Returns the ExecutorService.
 | getLoss() | Gets the training Loss function of the trainer.
 | getManager() | Gets the NDManager from the model.
 | getMetrics() | Returns the Metrics param used for benchmarking.
 | getModel() | Returns the model used to create this trainer.
 | getTrainingResult() | Returns the TrainingResult.
void | initialize(Shape... shapes) | Initializes the Model that the Trainer is going to train.
 | iterateDataset(Dataset dataset) | Fetches an iterator that can iterate through the given Dataset.
 | newGradientCollector() | Returns a new instance of GradientCollector.
final void | notifyListeners(Consumer<TrainingListener> listenerConsumer) | Executes a method on each of the TrainingListeners.
void | setMetrics(Metrics metrics) | Attaches a Metrics param to use for benchmarking.
void | step() | Updates all of the parameters of the model once.
-
Constructor Details
-
Trainer
- Parameters:
model - the model the trainer will train on
trainingConfig - the configuration used by the trainer
-
-
Method Details
-
initialize
Initializes the Model that the Trainer is going to train.
- Parameters:
shapes - an array of Shape of the inputs
-
iterateDataset
Fetches an iterator that can iterate through the given Dataset.
- Parameters:
dataset - the dataset to iterate through
- Returns:
an Iterable of Batch that contains batches of data from the dataset
- Throws:
IOException - for various exceptions depending on the dataset
TranslateException - if there is an error while processing input
-
newGradientCollector
Returns a new instance of GradientCollector.
- Returns:
a new instance of GradientCollector
-
forward
Applies the forward function of the model once on the given input NDList.
- Parameters:
input - the input NDList
- Returns:
the output of the forward function
-
forward
Applies the forward function of the model once with both data and labels.
-
evaluate
Evaluates the function of the model once on the given input NDList.
- Parameters:
input - the input NDList
- Returns:
the output of the predict function
-
step
public void step()
Updates all of the parameters of the model once.
-
getMetrics
Returns the Metrics param used for benchmarking.
- Returns:
the Metrics param used for benchmarking
-
setMetrics
Attaches a Metrics param to use for benchmarking.
- Parameters:
metrics - the Metrics class
-
getDevices
Returns the devices used for training.
- Returns:
the devices used for training
-
getLoss
Gets the training Loss function of the trainer.
- Returns:
the Loss function
-
getModel
Returns the model used to create this trainer.
- Returns:
the model associated with this trainer
-
getExecutorService
Returns the ExecutorService.
- Returns:
the ExecutorService
-
getEvaluators
Gets all Evaluators.
- Returns:
the evaluators used during training
-
notifyListeners
Executes a method on each of the TrainingListeners.
- Parameters:
listenerConsumer - a consumer that executes the method
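The consumer-based notification pattern described for notifyListeners can be illustrated with a minimal, self-contained sketch. It does not use DJL: EpochListener, ListenerSketch, addListener, and demo are hypothetical stand-ins invented for this illustration, and the real TrainingListener interface defines its own callbacks.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.function.Consumer;

public class ListenerSketch {
    /** Hypothetical stand-in for a listener type; not the DJL TrainingListener. */
    interface EpochListener {
        void onEpoch(int epoch);
    }

    private final List<EpochListener> listeners = new ArrayList<>();

    void addListener(EpochListener listener) {
        listeners.add(listener);
    }

    /** Applies the consumer to every registered listener, in registration order. */
    final void notifyListeners(Consumer<EpochListener> listenerConsumer) {
        listeners.forEach(listenerConsumer);
    }

    /** Registers two recording listeners, notifies them once, and returns what they recorded. */
    static List<String> demo() {
        List<String> seen = new ArrayList<>();
        ListenerSketch sketch = new ListenerSketch();
        sketch.addListener(epoch -> seen.add("first:" + epoch));
        sketch.addListener(epoch -> seen.add("second:" + epoch));
        // The caller picks which listener method to invoke by passing a lambda.
        sketch.notifyListeners(listener -> listener.onEpoch(3));
        return seen;
    }
}
```

Passing a Consumer lets one notification method serve every callback on the listener type, since the caller decides which method runs on each listener.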
-
getTrainingResult
Returns the TrainingResult.
- Returns:
the TrainingResult
-
getManager
Gets the NDManager from the model.
- Returns:
the NDManager
-
finalize
-
close
public void close()
- Specified by:
close in interface AutoCloseable
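Because close is specified by AutoCloseable, a try-with-resources statement can manage the object's lifetime. The sketch below shows only that general Java mechanism. Session, CloseSketch, work, and demo are hypothetical names, and the stand-in class is not the DJL Trainer.

```java
import java.util.ArrayList;
import java.util.List;

public class CloseSketch {
    /** Hypothetical stand-in for a closeable training session; not the DJL Trainer. */
    static final class Session implements AutoCloseable {
        private final List<String> log;

        Session(List<String> log) {
            this.log = log;
            log.add("open");
        }

        void work() {
            log.add("work");
        }

        @Override
        public void close() {
            log.add("close");
        }
    }

    /** Runs one session inside try-with-resources and returns the recorded event order. */
    static List<String> demo() {
        List<String> log = new ArrayList<>();
        try (Session session = new Session(log)) {
            session.work();
        } // close() is invoked here, even if work() had thrown
        log.add("after");
        return log;
    }
}
```

The same shape applies to any AutoCloseable: the resource is declared in the try header and released when the block exits.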
-
addMetric
Helper to add a metric for a time difference.
- Parameters:
metricName - the metric name
begin - the time difference start (this method is called at the time difference end)
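The "record the start, call the helper at the end" convention can be sketched in plain Java. This sketch assumes the start value is a System.nanoTime() reading, which the text above does not state. ElapsedSketch, elapsedNanos, and timeIt are hypothetical names, and nothing here calls DJL.

```java
public class ElapsedSketch {
    /** Returns the difference between two System.nanoTime() readings (assumed time source). */
    static long elapsedNanos(long begin, long end) {
        return end - begin;
    }

    /** Captures the start before the task and computes the difference once it finishes. */
    static long timeIt(Runnable task) {
        long begin = System.nanoTime(); // the "time difference start"
        task.run();
        // The helper is invoked at the end, so only the start needs to be passed around.
        return elapsedNanos(begin, System.nanoTime());
    }
}
```

A caller would store the start value before the timed region and hand it to the helper afterwards, which is the calling convention the begin parameter describes.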
-