public final class EasyTrain
extends java.lang.Object
Modifier and Type | Method and Description |
---|---|
static void | evaluateDataset(Trainer trainer, Dataset testDataset): Evaluates the test dataset. |
static void | fit(Trainer trainer, int numEpoch, Dataset trainingDataset, Dataset validateDataset): Runs a basic epoch training experience with a given trainer. |
static void | trainBatch(Trainer trainer, Batch batch): Trains the model with one iteration of the given Batch of data. |
static void | validateBatch(Trainer trainer, Batch batch): Validates the given batch of data. |
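
The summary above suggests a simple structure: fit repeats a pass over the training data for numEpoch epochs and, when a validation dataset is supplied, also validates. The following is a minimal sketch of that control flow only. It uses plain strings as hypothetical stand-ins for batches and a hypothetical class name, EpochLoopSketch; it does not call the library and should not be read as its implementation. Running validation once at the end of every epoch is an assumption of the sketch.

```java
import java.util.ArrayList;
import java.util.List;

/** Hypothetical stand-in for an epoch-based fit loop; not library code. */
public class EpochLoopSketch {

    /** Returns a log of the steps taken so the order of operations can be inspected. */
    static List<String> fit(int numEpoch, List<String> trainingBatches, List<String> validateBatches) {
        List<String> log = new ArrayList<>();
        for (int epoch = 0; epoch < numEpoch; epoch++) {
            // Training pass: one iteration per batch.
            for (String batch : trainingBatches) {
                log.add("epoch " + epoch + " train " + batch);
            }
            // Assumption: validation runs once per epoch and is skipped when the dataset is null.
            if (validateBatches != null) {
                for (String batch : validateBatches) {
                    log.add("epoch " + epoch + " validate " + batch);
                }
            }
        }
        return log;
    }

    public static void main(String[] args) {
        // Two epochs over two training batches, with no validation data.
        fit(2, List.of("b0", "b1"), null).forEach(System.out::println);
    }
}
```

Passing null for the validation batches mirrors the documented option of training without validation: the log then contains training steps only.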

public static void fit(Trainer trainer, int numEpoch, Dataset trainingDataset, Dataset validateDataset) throws java.io.IOException, TranslateException
Runs a basic epoch training experience with a given trainer.
Parameters:
trainer - the trainer to train for
numEpoch - the number of epochs to train
trainingDataset - the dataset to train on
validateDataset - the dataset to validate against. Can be null for no validation
Throws:
java.io.IOException - for various exceptions depending on the dataset
TranslateException - if there is an error while processing input

public static void trainBatch(Trainer trainer, Batch batch)
Trains the model with one iteration of the given Batch of data.
Parameters:
trainer - the trainer to train the batch with
batch - a Batch that contains data, and its respective labels
Throws:
java.lang.IllegalArgumentException - if the batch engine does not match the trainer engine

public static void validateBatch(Trainer trainer, Batch batch)
Validates the given batch of data. During validation, the evaluators and losses are computed, but gradients aren't computed, and parameters aren't updated.
Parameters:
trainer - the trainer to validate the batch with
batch - a Batch of data
Throws:
java.lang.IllegalArgumentException - if the batch engine does not match the trainer engine

public static void evaluateDataset(Trainer trainer, Dataset testDataset) throws java.io.IOException, TranslateException
Evaluates the test dataset.
Parameters:
trainer - the trainer to evaluate on
testDataset - the test dataset to evaluate
Throws:
java.io.IOException - for various exceptions depending on the dataset
TranslateException - if there is an error while processing input
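
The difference between trainBatch and validateBatch described above (losses are computed in both, but validation computes no gradients and updates no parameters) can be illustrated with a toy example. This is a sketch under stated assumptions, not library code: the class name TrainVsValidateSketch, the one-parameter model y = weight * x, the squared-error loss, and the fixed learning rate are all hypothetical. Folding the parameter update into the training step is a simplification for brevity and says nothing about how the library divides that work.

```java
/** Hypothetical toy model contrasting a training step with a validation step; not library code. */
public class TrainVsValidateSketch {
    // Single parameter of the toy model y = weight * x.
    static double weight = 0.0;
    static final double LEARNING_RATE = 0.1;

    /** Squared-error loss for one (x, label) pair. */
    static double loss(double x, double label) {
        double error = weight * x - label;
        return error * error;
    }

    /** Training step: computes the loss and its gradient, then updates the parameter. */
    static double trainBatch(double x, double label) {
        double lossBefore = loss(x, label);
        double gradient = 2 * (weight * x - label) * x;
        weight -= LEARNING_RATE * gradient; // simplification: update folded into the step
        return lossBefore;
    }

    /** Validation step: computes the loss only; no gradient, no parameter update. */
    static double validateBatch(double x, double label) {
        return loss(x, label);
    }

    public static void main(String[] args) {
        System.out.println("validation loss: " + validateBatch(1.0, 2.0));
        System.out.println("weight after validation: " + weight);
        trainBatch(1.0, 2.0);
        System.out.println("weight after training: " + weight);
    }
}
```

In this toy, calling validateBatch any number of times leaves weight unchanged, while a single trainBatch call moves it toward the label.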