Class Trainer

java.lang.Object
ai.djl.training.Trainer
All Implemented Interfaces:
AutoCloseable

public class Trainer extends Object implements AutoCloseable
The Trainer class provides a session for model training.

Trainer provides an easy-to-use, manageable interface for training. Trainer is not thread-safe.

  • Constructor Details

    • Trainer

      public Trainer(Model model, TrainingConfig trainingConfig)
      Creates an instance of Trainer with the given Model and TrainingConfig.
      Parameters:
      model - the model the trainer will train on
      trainingConfig - the configuration used by the trainer
  • Method Details

    • initialize

      public void initialize(Shape... shapes)
      Initializes the Model that the Trainer is going to train.
      Parameters:
      shapes - an array of Shape of the inputs
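Initialization takes input shapes because parameter shapes usually depend on them: a dense layer, for instance, cannot size its weight matrix until it knows how many input features it receives. The sketch below is a hypothetical, DJL-free illustration of that idea. `ToyDenseLayer` is invented for this example, and the Xavier/Glorot uniform rule is an assumption chosen for the sketch, not necessarily the initializer DJL uses.

```java
import java.util.Arrays;
import java.util.Random;

/** Hypothetical dense layer (not part of DJL) showing shape-driven parameter initialization. */
final class ToyDenseLayer {
    private final int outFeatures;
    private double[][] weight; // shape [outFeatures][inFeatures], allocated in initialize()

    ToyDenseLayer(int outFeatures) {
        this.outFeatures = outFeatures;
    }

    /** Allocates and fills the weights once the input shape (batch, inFeatures) is known. */
    void initialize(long... inputShape) {
        if (inputShape.length != 2) {
            throw new IllegalArgumentException(
                    "expected (batch, inFeatures), got " + Arrays.toString(inputShape));
        }
        int inFeatures = Math.toIntExact(inputShape[1]);
        // Assumed rule for this sketch: Xavier/Glorot uniform bound sqrt(6 / (fanIn + fanOut)).
        double limit = Math.sqrt(6.0 / (inFeatures + outFeatures));
        Random rng = new Random(42); // fixed seed keeps the sketch reproducible
        weight = new double[outFeatures][inFeatures];
        for (double[] row : weight) {
            for (int j = 0; j < row.length; j++) {
                row[j] = (rng.nextDouble() * 2.0 - 1.0) * limit; // uniform in [-limit, limit)
            }
        }
    }

    double[][] weight() {
        return weight;
    }
}
```

Note that the batch dimension does not affect the weight shape; only the feature dimension does.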
    • iterateDataset

      public Iterable<Batch> iterateDataset(Dataset dataset) throws IOException, TranslateException
      Returns an Iterable over the batches of the given Dataset.
      Parameters:
      dataset - the dataset to iterate through
      Returns:
      an Iterable of Batch that contains batches of data from the dataset
      Throws:
      IOException - if reading the dataset fails; the exact causes depend on the dataset
      TranslateException - if there is an error while processing input
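Because the method returns an Iterable rather than an Iterator, each for-each loop (for example, each epoch) starts a fresh pass over the data. The following is a hypothetical, stdlib-only sketch of that batching contract; `ToyBatchIterable` is invented, uses plain lists instead of DJL's Batch, and assumes the final batch may be smaller than the others.

```java
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import java.util.NoSuchElementException;

/** Hypothetical stand-in for iterateDataset: yields consecutive fixed-size batches. */
final class ToyBatchIterable<T> implements Iterable<List<T>> {
    private final List<T> samples;
    private final int batchSize;

    ToyBatchIterable(List<T> samples, int batchSize) {
        if (batchSize <= 0) {
            throw new IllegalArgumentException("batchSize must be positive");
        }
        this.samples = samples;
        this.batchSize = batchSize;
    }

    @Override
    public Iterator<List<T>> iterator() {
        // A new iterator per call, so every loop over this Iterable starts from the beginning.
        return new Iterator<List<T>>() {
            private int cursor = 0; // index of the first sample of the next batch

            @Override
            public boolean hasNext() {
                return cursor < samples.size();
            }

            @Override
            public List<T> next() {
                if (!hasNext()) {
                    throw new NoSuchElementException();
                }
                int end = Math.min(cursor + batchSize, samples.size()); // last batch may be short
                List<T> batch = new ArrayList<>(samples.subList(cursor, end));
                cursor = end;
                return batch;
            }
        };
    }
}
```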
    • newGradientCollector

      public GradientCollector newGradientCollector()
      Returns a new instance of GradientCollector.
      Returns:
      a new instance of GradientCollector
    • forward

      public NDList forward(NDList input)
      Applies the forward function of the model once on the given input NDList.
      Parameters:
      input - the input NDList
      Returns:
      the output of the forward function
    • forward

      public NDList forward(NDList data, NDList labels)
      Applies the forward function of the model once with both data and labels.
      Parameters:
      data - the input data NDList
      labels - the input labels NDList
      Returns:
      the output of the forward function
    • evaluate

      public NDList evaluate(NDList input)
      Applies the model's evaluation function once on the given input NDList.
      Parameters:
      input - the input NDList
      Returns:
      the output of the predict function
    • step

      public void step()
      Updates all of the parameters of the model once.
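In DJL, a training iteration typically opens `trainer.newGradientCollector()` in a try-with-resources block, calls `forward`, computes the loss, calls `backward` on the collector, and then calls `step()` once the block closes. The sketch below is a hypothetical, DJL-free analogue of that sequence for a one-weight linear model with squared-error loss; `ToyLinearTrainer` and `GradientScope` are invented names, and the update rule is plain per-sample SGD chosen for the example.

```java
/**
 * Hypothetical analogue of one training iteration:
 * forward -> loss -> backward inside a gradient scope -> step.
 * Model: prediction = weight * x; loss = (prediction - label)^2.
 */
final class ToyLinearTrainer {
    private double weight;
    private double grad; // gradient accumulated since the last step()
    private final double learningRate;

    ToyLinearTrainer(double initialWeight, double learningRate) {
        this.weight = initialWeight;
        this.learningRate = learningRate;
    }

    /** Analogue of newGradientCollector(): gradients are recorded while the scope is open. */
    GradientScope newGradientCollector() {
        return new GradientScope();
    }

    /** Analogue of forward(NDList). */
    double forward(double x) {
        return weight * x;
    }

    /** Analogue of step(): SGD update, then reset the accumulated gradient. */
    void step() {
        weight -= learningRate * grad;
        grad = 0.0;
    }

    double weight() {
        return weight;
    }

    /** Runs per-sample SGD over (xs, ys) for the given number of epochs. */
    void fit(double[] xs, double[] ys, int epochs) {
        for (int epoch = 0; epoch < epochs; epoch++) {
            for (int i = 0; i < xs.length; i++) {
                try (GradientScope gc = newGradientCollector()) {
                    double prediction = forward(xs[i]);
                    gc.backward(xs[i], prediction, ys[i]);
                }
                step(); // update parameters after the gradient scope is closed
            }
        }
    }

    /** Hypothetical stand-in for GradientCollector. */
    final class GradientScope implements AutoCloseable {
        /** d/dw (w*x - y)^2 = 2 * (w*x - y) * x, accumulated into the outer trainer. */
        void backward(double x, double prediction, double label) {
            grad += 2.0 * (prediction - label) * x;
        }

        @Override
        public void close() {
            // Nothing to release in this toy; a real collector would stop recording here.
        }
    }
}
```

For data generated by y = 3x, repeated calls drive the weight toward 3 when the learning rate is small enough.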
    • getMetrics

      public Metrics getMetrics()
      Returns the Metrics instance used for benchmarking.
      Returns:
      the Metrics instance used for benchmarking
    • setMetrics

      public void setMetrics(Metrics metrics)
      Attaches a Metrics instance to use for benchmarking.
      Parameters:
      metrics - the Metrics instance to attach
    • getDevices

      public Device[] getDevices()
      Returns the devices used for training.
      Returns:
      the devices used for training
    • getLoss

      public Loss getLoss()
      Gets the training Loss function of the trainer.
      Returns:
      the Loss function
    • getModel

      public Model getModel()
      Returns the model used to create this trainer.
      Returns:
      the model associated with this trainer
    • getExecutorService

      public Optional<ExecutorService> getExecutorService()
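      Returns the ExecutorService used by this trainer, if one is configured.
      Returns:
      an Optional containing the ExecutorService, or an empty Optional if none is configured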
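Since the executor is returned as an Optional, callers have to handle the case where none is configured. The helper below is a hypothetical, stdlib-only sketch of one way to consume such a value: run the task asynchronously on the executor when present, otherwise run it inline. `ExecutorDispatch` is an invented name and this is not a description of how DJL itself uses the executor.

```java
import java.util.Optional;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ExecutorService;
import java.util.function.Supplier;

/** Hypothetical helper: use the executor if one is configured, otherwise run inline. */
final class ExecutorDispatch {
    private ExecutorDispatch() {}

    static <T> CompletableFuture<T> submit(Optional<ExecutorService> executor, Supplier<T> task) {
        return executor
                .map(ex -> CompletableFuture.supplyAsync(task, ex)) // async on the executor
                .orElseGet(() -> CompletableFuture.completedFuture(task.get())); // inline fallback
    }
}
```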
    • getEvaluators

      public List<Evaluator> getEvaluators()
      Gets all Evaluators.
      Returns:
      the evaluators used during training
    • notifyListeners

      public final void notifyListeners(Consumer<TrainingListener> listenerConsumer)
      Invokes the given callback on each of the trainer's TrainingListeners.
      Parameters:
      listenerConsumer - a consumer that calls the desired method on each listener
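Taking a Consumer lets one method dispatch any event, so there is no need for a separate notify method per callback; DJL's training utilities make calls of the form `trainer.notifyListeners(listener -> listener.onEpoch(trainer))`. The sketch below reproduces the pattern with hypothetical stand-ins (`ToyListener`, `ToyListenerHost`) using only the standard library.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.function.Consumer;

/** Hypothetical listener interface standing in for TrainingListener. */
interface ToyListener {
    void onEpoch(int epoch);
}

/** Hypothetical trainer fragment showing the notifyListeners pattern. */
class ToyListenerHost {
    private final List<ToyListener> listeners = new ArrayList<>();

    void addListener(ToyListener listener) {
        listeners.add(listener);
    }

    /** Applies the callback to every listener in registration order; final so subclasses cannot change dispatch. */
    final void notifyListeners(Consumer<ToyListener> listenerConsumer) {
        listeners.forEach(listenerConsumer);
    }
}
```

Usage: `host.notifyListeners(l -> l.onEpoch(1))` calls `onEpoch(1)` on each registered listener.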
    • getTrainingResult

      public TrainingResult getTrainingResult()
      Returns the TrainingResult.
      Returns:
      the TrainingResult
    • getManager

      public NDManager getManager()
      Gets the NDManager from the model.
      Returns:
      the NDManager
    • finalize

      protected void finalize() throws Throwable
      Overrides:
      finalize in class Object
      Throws:
      Throwable
    • close

      public void close()
      Specified by:
      close in interface AutoCloseable
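Because Trainer implements AutoCloseable, the usual pattern is try-with-resources, for example `try (Trainer trainer = model.newTrainer(config)) { ... }`, so resources are released even if training throws. The sketch below shows the same lifecycle with a hypothetical `ToySession` class; its rejection of use after close is a design choice of this sketch, not documented Trainer behavior.

```java
/** Hypothetical resource-holding session standing in for a Trainer. */
final class ToySession implements AutoCloseable {
    private boolean closed;

    void train() {
        if (closed) {
            throw new IllegalStateException("session already closed");
        }
        // Training work would happen here.
    }

    boolean isClosed() {
        return closed;
    }

    @Override
    public void close() {
        closed = true; // a real trainer releases its resources here
    }
}
```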
    • addMetric

      public void addMetric(String metricName, long begin)
      Helper that records a metric for an elapsed time interval.
      Parameters:
      metricName - the metric name
      begin - the start time of the interval; the interval ends when this method is called
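The caller captures a start timestamp, does some work, and then passes the timestamp in; the helper computes the elapsed time itself. The sketch below is hypothetical and stdlib-only: `ToyMetrics` and `ToyTimedTrainer` are invented, and it assumes `begin` comes from `System.nanoTime()` and that nothing is recorded until a metrics store has been attached.

```java
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

/** Hypothetical minimal metrics store standing in for a Metrics instance. */
final class ToyMetrics {
    private final Map<String, List<Long>> values = new HashMap<>();

    void add(String name, long value) {
        values.computeIfAbsent(name, k -> new ArrayList<>()).add(value);
    }

    List<Long> get(String name) {
        return values.getOrDefault(name, new ArrayList<>());
    }
}

/** Hypothetical trainer fragment showing the addMetric(name, begin) helper. */
final class ToyTimedTrainer {
    private ToyMetrics metrics; // null until setMetrics is called

    void setMetrics(ToyMetrics metrics) {
        this.metrics = metrics;
    }

    /** Records System.nanoTime() - begin, assuming begin was also taken from System.nanoTime(). */
    void addMetric(String metricName, long begin) {
        if (metrics != null) { // benchmarking is optional: skip when no metrics are attached
            metrics.add(metricName, System.nanoTime() - begin);
        }
    }
}
```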