Class EasyTrain

java.lang.Object
ai.djl.training.EasyTrain

public final class EasyTrain extends Object
Helper for easily training a whole model, a single training batch, or a validation batch.
  • Method Details

    • fit

      public static void fit(Trainer trainer, int numEpoch, Dataset trainingDataset, Dataset validateDataset) throws IOException, TranslateException
      Runs a basic epoch-based training loop with the given trainer.
      Parameters:
      trainer - the trainer to train with
      numEpoch - the number of epochs to train
      trainingDataset - the dataset to train on
      validateDataset - the dataset to validate against; can be null to skip validation
      Throws:
      IOException - if reading the dataset fails; the possible causes depend on the dataset implementation
      TranslateException - if there is an error while processing input
    • trainBatch

      public static void trainBatch(Trainer trainer, Batch batch)
      Trains the model with one iteration of the given Batch of data.
      Parameters:
      trainer - the trainer to train the batch with
      batch - a Batch that contains data and its respective labels
      Throws:
      IllegalArgumentException - if the batch engine does not match the trainer engine
    • validateBatch

      public static void validateBatch(Trainer trainer, Batch batch)
      Validates the given batch of data.

      During validation, the evaluators and losses are computed, but gradients aren't computed, and parameters aren't updated.

      Parameters:
      trainer - the trainer to validate the batch with
      batch - a Batch of data
      Throws:
      IllegalArgumentException - if the batch engine does not match the trainer engine
    • evaluateDataset

      public static void evaluateDataset(Trainer trainer, Dataset testDataset) throws IOException, TranslateException
      Evaluates the test dataset.
      Parameters:
      trainer - the trainer to evaluate with
      testDataset - the test dataset to evaluate
      Throws:
      IOException - if reading the dataset fails; the possible causes depend on the dataset implementation
      TranslateException - if there is an error while processing input
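The epoch loop that fit runs can be illustrated with a plain-Java toy. This is not DJL code: ToyTrainer, Batch, the one-weight model y = w * x, the squared-error loss, and the plain gradient-descent update are all hypothetical stand-ins chosen for this sketch. It only mirrors the documented contract: train on every batch for numEpoch epochs, and skip validation when the validation dataset is null.

```java
import java.util.ArrayList;
import java.util.List;

public class FitLoopSketch {

    /** Hypothetical stand-in for a batch: inputs x and their labels y. */
    static final class Batch {
        final double[] x;
        final double[] y;

        Batch(double[] x, double[] y) {
            this.x = x;
            this.y = y;
        }
    }

    /** Hypothetical trainer for the one-weight model y = w * x. */
    static final class ToyTrainer {
        final double learningRate;
        final List<Double> validationLosses = new ArrayList<>();
        double w = 0.0;
        double grad = 0.0;

        ToyTrainer(double learningRate) {
            this.learningRate = learningRate;
        }

        /** Forward pass plus gradient of the mean squared error; w is not changed here. */
        void trainBatch(Batch b) {
            double g = 0.0;
            for (int i = 0; i < b.x.length; i++) {
                g += 2.0 * (w * b.x[i] - b.y[i]) * b.x[i];
            }
            grad += g / b.x.length;
        }

        /** Applies the accumulated gradient, then clears it. */
        void step() {
            w -= learningRate * grad;
            grad = 0.0;
        }

        /** Mean squared error only: no gradient, no parameter update. */
        double validateBatch(Batch b) {
            double loss = 0.0;
            for (int i = 0; i < b.x.length; i++) {
                double err = w * b.x[i] - b.y[i];
                loss += err * err;
            }
            return loss / b.x.length;
        }
    }

    /** Same shape as fit: each epoch trains on all batches, then validates if a dataset is given. */
    static void fit(ToyTrainer trainer, int numEpoch, List<Batch> training, List<Batch> validation) {
        for (int epoch = 0; epoch < numEpoch; epoch++) {
            for (Batch batch : training) {
                trainer.trainBatch(batch);
                trainer.step();
            }
            if (validation != null) { // null means "no validation"
                double total = 0.0;
                for (Batch batch : validation) {
                    total += trainer.validateBatch(batch);
                }
                trainer.validationLosses.add(total / validation.size());
            }
        }
    }

    public static void main(String[] args) {
        List<Batch> training = List.of(
                new Batch(new double[] {1, 2}, new double[] {2, 4}),
                new Batch(new double[] {3, 4}, new double[] {6, 8}));
        List<Batch> validation = List.of(new Batch(new double[] {5}, new double[] {10}));

        ToyTrainer trainer = new ToyTrainer(0.02);
        fit(trainer, 50, training, validation);
        System.out.printf("w = %.4f after %d epochs%n", trainer.w, trainer.validationLosses.size());
    }
}
```

Because the toy data follows y = 2x, w converges toward 2; passing null as the validation list leaves validationLosses empty, which matches the note that validateDataset can be null.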
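What "one iteration" of trainBatch involves can be sketched the same way: check that the batch and trainer belong to the same engine, run the forward pass, and compute the loss gradient. The sketch assumes, as in commonly published custom DJL loops that call trainer.step() right after EasyTrain.trainBatch, that the parameter update is a separate step. ToyTrainer, Batch, the engine-name strings, and the one-weight model are hypothetical and are not the DJL types of the same name.

```java
public class TrainBatchSketch {

    /** Hypothetical batch tagged with the name of the engine that created its data. */
    static final class Batch {
        final String engine;
        final double[] x;
        final double[] y;

        Batch(String engine, double[] x, double[] y) {
            this.engine = engine;
            this.x = x;
            this.y = y;
        }
    }

    /** Hypothetical trainer for the one-weight model y = w * x. */
    static final class ToyTrainer {
        final String engine;
        final double learningRate;
        double w;
        double grad = 0.0;

        ToyTrainer(String engine, double w, double learningRate) {
            this.engine = engine;
            this.w = w;
            this.learningRate = learningRate;
        }

        /** One iteration: engine check, forward pass, loss gradient. The weight is not touched. */
        void trainBatch(Batch batch) {
            if (!engine.equals(batch.engine)) {
                throw new IllegalArgumentException(
                        "Batch engine " + batch.engine + " does not match trainer engine " + engine);
            }
            double g = 0.0;
            for (int i = 0; i < batch.x.length; i++) {
                // d/dw of (w * x - y)^2 is 2 * (w * x - y) * x
                g += 2.0 * (w * batch.x[i] - batch.y[i]) * batch.x[i];
            }
            grad += g / batch.x.length;
        }

        /** Separate optimizer step: apply the gradient, then clear it. */
        void step() {
            w -= learningRate * grad;
            grad = 0.0;
        }
    }

    public static void main(String[] args) {
        ToyTrainer trainer = new ToyTrainer("EngineA", 0.0, 0.1);
        trainer.trainBatch(new Batch("EngineA", new double[] {1, 2}, new double[] {2, 4}));
        System.out.println("after trainBatch: w = " + trainer.w + ", grad = " + trainer.grad);
        trainer.step();
        System.out.println("after step:       w = " + trainer.w);

        try {
            trainer.trainBatch(new Batch("EngineB", new double[] {1}, new double[] {2}));
        } catch (IllegalArgumentException e) {
            System.out.println("rejected: " + e.getMessage());
        }
    }
}
```

Keeping the gradient computation and the update apart is what lets a caller decide when to step, for example after every batch as in the fit loop.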
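The validateBatch contract (losses and evaluators are computed, but no gradients and no parameter updates) can be checked on a toy model. In this hypothetical sketch, lossSum and lossCount play the role of an evaluator's running state; none of these names come from DJL.

```java
public class ValidateBatchSketch {

    /** Hypothetical batch of inputs and labels. */
    static final class Batch {
        final double[] x;
        final double[] y;

        Batch(double[] x, double[] y) {
            this.x = x;
            this.y = y;
        }
    }

    /** Hypothetical trainer for the one-weight model y = w * x. */
    static final class ToyTrainer {
        double w;
        double grad = 0.0;  // left untouched by validation
        double lossSum = 0.0; // running "evaluator" state
        int lossCount = 0;

        ToyTrainer(double w) {
            this.w = w;
        }

        /** Forward pass and loss only: no gradient, no parameter update. */
        void validateBatch(Batch batch) {
            for (int i = 0; i < batch.x.length; i++) {
                double err = w * batch.x[i] - batch.y[i];
                lossSum += err * err;
                lossCount++;
            }
        }

        double meanLoss() {
            return lossCount == 0 ? 0.0 : lossSum / lossCount;
        }
    }

    public static void main(String[] args) {
        ToyTrainer trainer = new ToyTrainer(1.5);
        // With w = 1.5 the errors are -0.5 and -1.0, so the mean squared error is 0.625
        trainer.validateBatch(new Batch(new double[] {1, 2}, new double[] {2, 4}));
        System.out.println("w = " + trainer.w + ", grad = " + trainer.grad + ", loss = " + trainer.meanLoss());
    }
}
```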
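Evaluating a whole test dataset amounts to running the validation step on each of its batches and reading the accumulated metric at the end. The sketch below assumes a mean-squared-error metric averaged over all samples. MseAccumulator, Batch, and the fixed weight w are hypothetical stand-ins, not DJL's evaluator classes.

```java
import java.util.List;

public class EvaluateDatasetSketch {

    /** Hypothetical batch of inputs and labels. */
    static final class Batch {
        final double[] x;
        final double[] y;

        Batch(double[] x, double[] y) {
            this.x = x;
            this.y = y;
        }
    }

    /** Hypothetical evaluator: accumulates squared error per sample across batches. */
    static final class MseAccumulator {
        double sum = 0.0;
        long count = 0;

        /** Validation step for the fixed model y = w * x: loss only, no update. */
        void update(double w, Batch batch) {
            for (int i = 0; i < batch.x.length; i++) {
                double err = w * batch.x[i] - batch.y[i];
                sum += err * err;
                count++;
            }
        }

        /** Mean over all samples; NaN if nothing was evaluated. */
        double mean() {
            return count == 0 ? Double.NaN : sum / count;
        }
    }

    /** Same shape as evaluateDataset: validate every batch of the test set. */
    static double evaluateDataset(double w, List<Batch> testDataset) {
        MseAccumulator evaluator = new MseAccumulator();
        for (Batch batch : testDataset) {
            evaluator.update(w, batch);
        }
        return evaluator.mean();
    }

    public static void main(String[] args) {
        List<Batch> test = List.of(
                new Batch(new double[] {1, 2}, new double[] {2, 4}),
                new Batch(new double[] {3}, new double[] {7}));
        // With w = 2 the squared errors are 0, 0 and 1, so the per-sample mean is 1/3
        System.out.println("test MSE = " + evaluateDataset(2.0, test));
    }
}
```

The sketch accumulates per sample instead of averaging per-batch means, so a short final batch is not over-weighted. Averaging the two batch means here would give 0.5 instead of 1/3.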