Package ai.djl.training
Class EasyTrain
java.lang.Object
ai.djl.training.EasyTrain
Helper for easy training of a whole model, a training batch, or a validation batch, and for evaluating a whole dataset.
Method Summary
Modifier and Type | Method | Description
static void | evaluateDataset(Trainer trainer, Dataset testDataset) | Evaluates the test dataset.
static void | fit(Trainer trainer, int numEpoch, Dataset trainingDataset, Dataset validateDataset) | Runs a basic epoch-based training loop with the given trainer.
static void | trainBatch(Trainer trainer, Batch batch) | Trains the model with one iteration of the given Batch of data.
static void | validateBatch(Trainer trainer, Batch batch) | Validates the given batch of data.
Method Details
fit
public static void fit(Trainer trainer, int numEpoch, Dataset trainingDataset, Dataset validateDataset) throws IOException, TranslateException
Runs a basic epoch-based training loop with the given trainer.
Parameters:
trainer - the trainer to train with
numEpoch - the number of epochs to train
trainingDataset - the dataset to train on
validateDataset - the dataset to validate against, or null for no validation
Throws:
IOException - for various exceptions depending on the dataset
TranslateException - if there is an error while processing input
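The loop described for fit can be sketched in plain Java to show the order of operations. This is a minimal illustration under stated assumptions, not DJL's implementation: RecordingTrainer, its step method, and the integer batch ids are hypothetical stand-ins for DJL's Trainer and Batch, and the sketch assumes the parameter update is applied by a separate step after each training batch. It only demonstrates the ordering: every training batch once per epoch, then validation, which is skipped when the validation set is null.

```java
import java.util.ArrayList;
import java.util.List;

/**
 * Hypothetical sketch of an epoch loop in the style of EasyTrain.fit.
 * Batches are plain integer ids and the "trainer" only records calls;
 * none of these types are DJL classes.
 */
final class EpochLoopSketch {

    /** Records the order in which training and validation calls happen. */
    static final class RecordingTrainer {
        final List<String> calls = new ArrayList<>();

        void trainBatch(int batchId) { calls.add("train:" + batchId); }

        // Hypothetical: assumed to apply the parameter update after trainBatch.
        void step() { calls.add("step"); }

        void validateBatch(int batchId) { calls.add("validate:" + batchId); }
    }

    /** Trains on every batch once per epoch, then validates if a validation set is given. */
    static void fit(RecordingTrainer trainer, int numEpoch,
                    List<Integer> trainingBatches, List<Integer> validateBatches) {
        for (int epoch = 0; epoch < numEpoch; epoch++) {
            for (int batchId : trainingBatches) {
                trainer.trainBatch(batchId); // forward pass, loss, gradients
                trainer.step();              // apply the parameter update
            }
            // A null validation set means "no validation", as documented for fit.
            if (validateBatches != null) {
                for (int batchId : validateBatches) {
                    trainer.validateBatch(batchId);
                }
            }
        }
    }
}
```

With two epochs, training batches [1, 2], and validation batches [9], the recorded order is train:1, step, train:2, step, validate:9, repeated twice; passing null as the validation set leaves only the train and step entries.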
trainBatch
public static void trainBatch(Trainer trainer, Batch batch)
Trains the model with one iteration of the given Batch of data.
Parameters:
trainer - the trainer to train the batch with
batch - a Batch that contains data and its respective labels
Throws:
IllegalArgumentException - if the batch engine does not match the trainer engine
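What one training iteration involves can be illustrated with a toy one-parameter model y = w * x and mean squared error. This is a hypothetical sketch, not DJL code: ToyTrainer, the string engine names, and the separate step method are assumptions. The sketch assumes, as in a typical manual training loop, that trainBatch computes the loss and gradient while a separate step applies the update. It also mirrors the documented IllegalArgumentException when the batch engine does not match the trainer engine.

```java
/**
 * Hypothetical sketch of one training iteration for a one-parameter model
 * y = w * x with mean squared error. Not DJL code.
 */
final class TrainBatchSketch {

    static final class ToyTrainer {
        final String engine;       // hypothetical engine name the trainer was created with
        final double learningRate;
        double w;                  // the single model parameter
        double grad;               // gradient computed by the last trainBatch

        ToyTrainer(String engine, double learningRate, double initialW) {
            this.engine = engine;
            this.learningRate = learningRate;
            this.w = initialW;
        }

        /** Forward pass, loss, and gradient for one batch; returns the batch MSE. */
        double trainBatch(String batchEngine, double[] x, double[] y) {
            if (!engine.equals(batchEngine)) {
                throw new IllegalArgumentException(
                        "batch engine " + batchEngine + " does not match trainer engine " + engine);
            }
            int n = x.length;
            double loss = 0.0;
            double g = 0.0;
            for (int i = 0; i < n; i++) {
                double err = w * x[i] - y[i]; // prediction minus label
                loss += err * err;
                g += 2.0 * err * x[i];        // d(err^2)/dw
            }
            grad = g / n;                     // dL/dw for L = mean(err^2)
            return loss / n;
        }

        /** Applies plain SGD with the stored gradient, then clears it. */
        void step() {
            w -= learningRate * grad;
            grad = 0.0;
        }
    }
}
```

For example, with w = 0, a learning rate of 0.1, x = [1, 2] and y = [2, 4], trainBatch returns a loss of 10 and stores a gradient of -10, leaving w untouched; step then moves w to 1.0.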
validateBatch
public static void validateBatch(Trainer trainer, Batch batch)
Validates the given batch of data. During validation, the evaluators and losses are computed, but gradients are not computed and parameters are not updated.
Parameters:
trainer - the trainer to validate the batch with
batch - a Batch of data
Throws:
IllegalArgumentException - if the batch engine does not match the trainer engine
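The difference between validation and training can be shown with the same kind of toy model y = w * x. In this hypothetical sketch (not DJL code; ToyTrainer and the engine strings are assumptions), validateBatch computes a loss (mean squared error) and one evaluator (mean absolute error), but computes no gradient and only reads w. The engine check mirrors the documented IllegalArgumentException.

```java
/**
 * Hypothetical sketch of validating one batch for a one-parameter model
 * y = w * x: the loss and an evaluator are computed, but no gradient is
 * computed and w is never written. Not DJL code.
 */
final class ValidateBatchSketch {

    static final class ToyTrainer {
        final String engine; // hypothetical engine name
        double w;            // model parameter; validation only reads it

        ToyTrainer(String engine, double w) {
            this.engine = engine;
            this.w = w;
        }

        /** Returns {mean squared error, mean absolute error} for the batch. */
        double[] validateBatch(String batchEngine, double[] x, double[] y) {
            if (!engine.equals(batchEngine)) {
                throw new IllegalArgumentException(
                        "batch engine " + batchEngine + " does not match trainer engine " + engine);
            }
            double squared = 0.0;
            double absolute = 0.0;
            for (int i = 0; i < x.length; i++) {
                double err = w * x[i] - y[i]; // prediction minus label
                squared += err * err;         // loss term
                absolute += Math.abs(err);    // evaluator term
            }
            // No gradient and no update: w is left exactly as it was.
            return new double[] {squared / x.length, absolute / x.length};
        }
    }
}
```

With w = 1.5, x = [1, 2] and y = [2, 4], it returns an MSE of 0.625 and an MAE of 0.75, and w is still 1.5 afterwards.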
evaluateDataset
public static void evaluateDataset(Trainer trainer, Dataset testDataset) throws IOException, TranslateException
Evaluates the test dataset.
Parameters:
trainer - the trainer to evaluate with
testDataset - the test dataset to evaluate
Throws:
IOException - for various exceptions depending on the dataset
TranslateException - if there is an error while processing input
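Evaluating a whole dataset can be sketched as validating each batch in turn and combining the per-batch results. This is a hypothetical plain-Java illustration, not DJL's implementation: ToyBatch, the model y = w * x, and the choice of a sample-weighted mean squared error as the dataset-level metric are all assumptions made for the example.

```java
import java.util.List;

/**
 * Hypothetical sketch of evaluating a whole test set for a one-parameter
 * model y = w * x by validating each batch and combining the results.
 * Not DJL code.
 */
final class EvaluateDatasetSketch {

    /** One batch of inputs and labels (hypothetical stand-in for a Batch). */
    static final class ToyBatch {
        final double[] x;
        final double[] y;

        ToyBatch(double[] x, double[] y) {
            this.x = x;
            this.y = y;
        }
    }

    /** Mean squared error of a single batch; reads w, never updates it. */
    static double validateBatch(double w, ToyBatch batch) {
        double sum = 0.0;
        for (int i = 0; i < batch.x.length; i++) {
            double err = w * batch.x[i] - batch.y[i];
            sum += err * err;
        }
        return sum / batch.x.length;
    }

    /** Dataset-level MSE: each batch loss weighted by its number of samples. */
    static double evaluateDataset(double w, List<ToyBatch> testBatches) {
        double weightedSum = 0.0;
        int samples = 0;
        for (ToyBatch batch : testBatches) {
            weightedSum += validateBatch(w, batch) * batch.x.length;
            samples += batch.x.length;
        }
        return weightedSum / samples;
    }
}
```

Weighting by batch size keeps a smaller final batch from counting as much as a full one. With w = 1, a batch with x = [1, 2], y = [1, 3] has an MSE of 0.5 and a one-sample batch with x = [3], y = [5] has an MSE of 4, so the weighted result is 5/3, whereas a plain mean of the two batch losses would be 2.25.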