Class EarlyStoppingListener

java.lang.Object
ai.djl.training.listener.EarlyStoppingListener
All Implemented Interfaces:
TrainingListener

public final class EarlyStoppingListener extends Object implements TrainingListener
Listener that allows training to be stopped early if the validation loss is not improving, or if the maximum training duration has elapsed.

Usage: Add this listener to the training config as the last training listener.

  new DefaultTrainingConfig(...)
        .addTrainingListeners(EarlyStoppingListener.builder()
                .setEpochPatience(1)
                .setEarlyStopPctImprovement(1)
                .setMaxDuration(Duration.ofMinutes(42))
                .setMinEpochs(1)
                .build()
        );
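The four builder settings in the example above combine into a single stopping rule. The class below is a minimal, stdlib-only sketch of one plausible reading of that rule, not DJL's implementation. It assumes that setEarlyStopPctImprovement is a percentage measured against the previous epoch's validation loss, that setEpochPatience counts consecutive epochs that miss that target, that setMaxDuration is a wall-clock budget for the whole run, and that no check runs before setMinEpochs epochs. The class name EarlyStopRule and its check method are hypothetical.

```java
import java.time.Duration;
import java.util.Optional;

/** Hypothetical sketch of an early-stopping rule; NOT DJL's EarlyStoppingListener code. */
final class EarlyStopRule {
    private final int epochPatience;              // epochs allowed to miss the target in a row
    private final double earlyStopPctImprovement; // required improvement, in percent
    private final Duration maxDuration;           // wall-clock budget for the whole run
    private final int minEpochs;                  // no stopping checks before this epoch

    private double prevLoss = Double.NaN;         // last finite validation loss seen
    private int epochsWithoutProgress = 0;

    EarlyStopRule(int epochPatience, double earlyStopPctImprovement,
                  Duration maxDuration, int minEpochs) {
        this.epochPatience = epochPatience;
        this.earlyStopPctImprovement = earlyStopPctImprovement;
        this.maxDuration = maxDuration;
        this.minEpochs = minEpochs;
    }

    /** Called once per epoch (counted from 1); returns a stop reason, or empty to continue. */
    Optional<String> check(int epoch, double validationLoss, Duration elapsed) {
        Optional<String> reason = Optional.empty();
        if (epoch >= minEpochs) {
            if (elapsed.compareTo(maxDuration) >= 0) {
                reason = Optional.of("time budget of " + maxDuration + " exhausted");
            } else if (Double.isFinite(prevLoss)) {
                // Assumed target: loss must fall to at most (100 - pct)% of the previous loss.
                double goal = prevLoss * (100.0 - earlyStopPctImprovement) / 100.0;
                if (validationLoss <= goal) {
                    epochsWithoutProgress = 0; // improved enough, reset patience
                } else if (++epochsWithoutProgress >= epochPatience) {
                    reason = Optional.of(epochsWithoutProgress + " epoch(s) without "
                            + earlyStopPctImprovement + "% improvement");
                }
            }
        }
        if (Double.isFinite(validationLoss)) {
            prevLoss = validationLoss; // NaN or infinite losses never become the baseline
        }
        return reason;
    }

    public static void main(String[] args) {
        // Same settings as the builder example: patience 1, 1%, 42 minutes, min 1 epoch.
        EarlyStopRule rule = new EarlyStopRule(1, 1, Duration.ofMinutes(42), 1);
        double[] losses = {1.0, 0.8, 0.795};
        for (int epoch = 1; epoch <= losses.length; epoch++) {
            Optional<String> reason =
                    rule.check(epoch, losses[epoch - 1], Duration.ofMinutes(epoch));
            System.out.println("epoch " + epoch + ": " + reason.orElse("continue"));
        }
    }
}
```

With the losses in main, the rule keeps training after epoch 2 (0.8 is more than 1% below 1.0) and stops after epoch 3 (the 1% target below 0.8 is 0.792, and 0.795 misses it).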
 

Then surround the call to fit with a try-catch block that catches EarlyStoppingListener.EarlyStoppedException.
Example:

 try {
   EasyTrain.fit(trainer, 5, trainDataset, testDataset);
 } catch (EarlyStoppingListener.EarlyStoppedException e) {
   // handle early stopping; log is assumed to be an SLF4J Logger
   log.info("Stopped early at epoch {} because: {}", e.getStopEpoch(), e.getMessage());
 }
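Stopping through an exception lets the listener unwind out of fit from inside the per-epoch callback, so the training loop needs no extra return path for it. The sketch below mimics that control flow in plain Java. ExceptionStopDemo, its nested StopSignal exception, and the fit loop are hypothetical stand-ins rather than DJL classes; StopSignal is deliberately unchecked so the loop's signature does not have to declare it.

```java
import java.util.function.IntConsumer;

/** Hypothetical demo of exception-based early stopping; not DJL's EasyTrain. */
final class ExceptionStopDemo {

    /** Hypothetical unchecked exception carrying the epoch at which training stopped. */
    static final class StopSignal extends RuntimeException {
        private static final long serialVersionUID = 1L;
        private final int stopEpoch;

        StopSignal(int stopEpoch, String reason) {
            super(reason);
            this.stopEpoch = stopEpoch;
        }

        int getStopEpoch() {
            return stopEpoch;
        }
    }

    /** Hypothetical training loop: calls the end-of-epoch callback after each epoch. */
    static int fit(int numEpochs, IntConsumer onEpoch) {
        for (int epoch = 1; epoch <= numEpochs; epoch++) {
            // ... train and validate one epoch here ...
            onEpoch.accept(epoch); // a listener may throw StopSignal here
        }
        return numEpochs;
    }

    /** Runs fit and returns the epoch training ended at, whether stopped early or not. */
    static int runWithEarlyStop(int numEpochs, int stopAt) {
        try {
            return fit(numEpochs, epoch -> {
                if (epoch == stopAt) {
                    throw new StopSignal(epoch, "validation loss stopped improving");
                }
            });
        } catch (StopSignal e) {
            // handle early stopping, mirroring the catch block above
            System.out.println("Stopped early at epoch " + e.getStopEpoch()
                    + " because: " + e.getMessage());
            return e.getStopEpoch();
        }
    }

    public static void main(String[] args) {
        System.out.println(runWithEarlyStop(5, 3));  // stops during epoch 3
        System.out.println(runWithEarlyStop(5, 99)); // never triggers, runs all 5 epochs
    }
}
```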
 

Note: Ensure that Metrics are set on the trainer (for example, with trainer.setMetrics(new Metrics())).
  • Method Details

    • onEpoch

      public void onEpoch(Trainer trainer)
      Listens to the end of an epoch during training.
      Specified by:
      onEpoch in interface TrainingListener
      Parameters:
      trainer - the trainer the listener is attached to
    • onTrainingBatch

      public void onTrainingBatch(Trainer trainer, TrainingListener.BatchData batchData)
      Listens to the end of training one batch of data during training.
      Specified by:
      onTrainingBatch in interface TrainingListener
      Parameters:
      trainer - the trainer the listener is attached to
      batchData - the data from the batch
    • onValidationBatch

      public void onValidationBatch(Trainer trainer, TrainingListener.BatchData batchData)
      Listens to the end of validating one batch of data during validation.
      Specified by:
      onValidationBatch in interface TrainingListener
      Parameters:
      trainer - the trainer the listener is attached to
      batchData - the data from the batch
    • onTrainingBegin

      public void onTrainingBegin(Trainer trainer)
      Listens to the beginning of training.
      Specified by:
      onTrainingBegin in interface TrainingListener
      Parameters:
      trainer - the trainer the listener is attached to
    • onTrainingEnd

      public void onTrainingEnd(Trainer trainer)
      Listens to the end of training.
      Specified by:
      onTrainingEnd in interface TrainingListener
      Parameters:
      trainer - the trainer the listener is attached to
    • builder

      public static EarlyStoppingListener.Builder builder()
      Creates a builder to build an EarlyStoppingListener.
      Returns:
      a new builder
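The lifecycle hooks listed above can be pictured with the following sketch, in which onTrainingBegin starts a clock and onEpoch compares the elapsed time with a budget after every epoch. Everything in it is hypothetical (the Listener interface, TimeBudgetListener, StopTraining, and the train loop), and DJL's hooks receive a Trainer rather than an epoch number. Calling onTrainingEnd from a finally block is a choice made for this sketch, not documented DJL behavior.

```java
import java.time.Duration;
import java.util.function.LongSupplier;

/** Hypothetical sketch of listener lifecycle hooks; not DJL's Trainer or EasyTrain. */
final class ListenerLifecycleDemo {

    /** Trimmed-down listener interface; DJL's TrainingListener methods take a Trainer. */
    interface Listener {
        default void onTrainingBegin() {}

        default void onEpoch(int epoch) {}

        default void onTrainingEnd() {}
    }

    /** Hypothetical unchecked exception carrying the epoch at which the run stopped. */
    static final class StopTraining extends RuntimeException {
        private static final long serialVersionUID = 1L;
        final int stopEpoch;

        StopTraining(int stopEpoch, String reason) {
            super(reason);
            this.stopEpoch = stopEpoch;
        }
    }

    /** Stops the run once the time since onTrainingBegin reaches maxDuration. */
    static final class TimeBudgetListener implements Listener {
        private final Duration maxDuration;
        private final LongSupplier clockMillis; // injectable clock keeps the demo deterministic
        private long startMillis;

        TimeBudgetListener(Duration maxDuration, LongSupplier clockMillis) {
            this.maxDuration = maxDuration;
            this.clockMillis = clockMillis;
        }

        @Override
        public void onTrainingBegin() {
            startMillis = clockMillis.getAsLong(); // the budget starts when training starts
        }

        @Override
        public void onEpoch(int epoch) {
            long elapsed = clockMillis.getAsLong() - startMillis;
            if (elapsed >= maxDuration.toMillis()) {
                throw new StopTraining(epoch,
                        elapsed + " ms elapsed >= " + maxDuration.toMillis() + " ms");
            }
        }
    }

    /** Hypothetical loop: begin hook, per-epoch work and hook, end hook (even on early stop). */
    static int train(int numEpochs, Listener listener, Runnable epochWork) {
        listener.onTrainingBegin();
        try {
            for (int epoch = 1; epoch <= numEpochs; epoch++) {
                epochWork.run();         // train and validate one epoch
                listener.onEpoch(epoch); // may throw StopTraining
            }
            return numEpochs;
        } finally {
            listener.onTrainingEnd();
        }
    }

    /** Simulates a job in which every epoch takes 10 minutes; returns the last epoch run. */
    static int simulate(int numEpochs, Duration budget) {
        long[] now = {0L}; // fake clock in milliseconds
        Listener listener = new TimeBudgetListener(budget, () -> now[0]);
        try {
            return train(numEpochs, listener, () -> now[0] += Duration.ofMinutes(10).toMillis());
        } catch (StopTraining e) {
            System.out.println("Stopped at epoch " + e.stopEpoch + ": " + e.getMessage());
            return e.stopEpoch;
        }
    }

    public static void main(String[] args) {
        System.out.println(simulate(10, Duration.ofMinutes(42)));
        System.out.println(simulate(3, Duration.ofMinutes(42)));
    }
}
```

With 10-minute epochs and a 42-minute budget, simulate(10, ...) stops at epoch 5 (50 minutes elapsed), while simulate(3, ...) finishes all 3 epochs within the budget.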