public class EvaluatorTrainingListener extends TrainingListenerAdapter
A TrainingListener that records evaluator results.
Results are recorded for the following stages:
- TRAIN_EPOCH: accumulates over the whole training epoch and is recorded to a metric at the end of the epoch.
- TRAIN_PROGRESS: accumulates over progressUpdateFrequency batches and is recorded to a metric at the end of those batches.
- TRAIN_ALL: does not accumulate; every training batch is recorded to a metric.
- VALIDATE_EPOCH: accumulates over the whole validation epoch and is recorded to a metric at the end of the epoch.
The training and validation evaluators are saved as metrics whose names can be found using metricName(Evaluator, String). The validation evaluators are also saved as model properties under the evaluator name.
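The difference between the stages is easiest to see on a concrete run. The sketch below is a stdlib-only simulation, not DJL code: a hypothetical accuracy-style accumulator (`correct`/`total` counts) stands in for an evaluator, and `trainEpoch` returns what would be recorded for TRAIN_ALL (one value per batch), TRAIN_PROGRESS (one value per window of `progressUpdateFrequency` batches, assuming the window is reset after each report and a trailing partial window is not reported) and TRAIN_EPOCH (one value per epoch). VALIDATE_EPOCH would behave like TRAIN_EPOCH over the validation batches.

```java
import java.util.ArrayList;
import java.util.List;

/** Hypothetical simulation of the per-stage accumulation; none of these names are DJL API. */
class StageAccumulationSketch {

    /** Accumulates correct/total counts and reports their ratio (an accuracy-style value). */
    static final class Accumulator {
        private long correct;
        private long total;

        void add(long batchCorrect, long batchTotal) {
            correct += batchCorrect;
            total += batchTotal;
        }

        float value() {
            return total == 0 ? 0f : (float) correct / total;
        }

        void reset() {
            correct = 0;
            total = 0;
        }
    }

    /**
     * Each batch is {correct, total}. Returns [TRAIN_ALL values, TRAIN_PROGRESS values, TRAIN_EPOCH value].
     */
    static List<List<Float>> trainEpoch(long[][] batches, int progressUpdateFrequency) {
        Accumulator epoch = new Accumulator();
        Accumulator progress = new Accumulator();
        List<Float> all = new ArrayList<>();
        List<Float> progressValues = new ArrayList<>();
        int seen = 0;
        for (long[] batch : batches) {
            long correct = batch[0];
            long total = batch[1];
            // TRAIN_ALL: no accumulation, the batch is recorded on its own
            all.add(total == 0 ? 0f : (float) correct / total);
            // TRAIN_EPOCH and TRAIN_PROGRESS both accumulate
            epoch.add(correct, total);
            progress.add(correct, total);
            seen++;
            if (seen % progressUpdateFrequency == 0) {
                progressValues.add(progress.value()); // record the finished window
                progress.reset(); // assumed: start a fresh window
            }
        }
        return List.of(all, progressValues, List.of(epoch.value()));
    }

    public static void main(String[] args) {
        long[][] batches = {{8, 10}, {6, 10}, {9, 10}, {10, 10}};
        List<List<Float>> recorded = trainEpoch(batches, 2);
        System.out.println("TRAIN_ALL      " + recorded.get(0));
        System.out.println("TRAIN_PROGRESS " + recorded.get(1));
        System.out.println("TRAIN_EPOCH    " + recorded.get(2));
    }
}
```

With four batches and a window of 2, TRAIN_ALL gets four values, TRAIN_PROGRESS two (14/20 and 19/20), and TRAIN_EPOCH one (33/40).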
Nested classes/interfaces inherited from interface TrainingListener: TrainingListener.BatchData, TrainingListener.Defaults
Modifier and Type | Field and Description |
---|---|
static java.lang.String | TRAIN_ALL |
static java.lang.String | TRAIN_EPOCH |
static java.lang.String | TRAIN_PROGRESS |
static java.lang.String | VALIDATE_EPOCH |
Constructor and Description |
---|
EvaluatorTrainingListener(): Constructs an EvaluatorTrainingListener that updates the training progress at the default frequency. |
EvaluatorTrainingListener(int progressUpdateFrequency): Constructs an EvaluatorTrainingListener that updates the training progress at the given frequency. |
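The two constructors follow the usual delegation pattern: the no-argument form supplies a default (documented below as every 5 batches) and the other takes an explicit `progressUpdateFrequency`. The stdlib-only sketch below mirrors that shape; the class, the constant name `DEFAULT_PROGRESS_UPDATE_FREQUENCY`, and the positive-argument check are hypothetical additions for illustration, not part of the documented API.

```java
/** Hypothetical sketch of the constructor pair and the progress-update schedule. */
class ProgressScheduleSketch {
    // Hypothetical constant; the value matches the documented default of 5 batches.
    static final int DEFAULT_PROGRESS_UPDATE_FREQUENCY = 5;

    private final int progressUpdateFrequency;
    private int batchesSinceUpdate;

    ProgressScheduleSketch() {
        this(DEFAULT_PROGRESS_UPDATE_FREQUENCY); // delegate to the explicit constructor
    }

    ProgressScheduleSketch(int progressUpdateFrequency) {
        // Assumed validation, added for the sketch only.
        if (progressUpdateFrequency <= 0) {
            throw new IllegalArgumentException("progressUpdateFrequency must be positive");
        }
        this.progressUpdateFrequency = progressUpdateFrequency;
    }

    /** Call once per training batch; returns true when a progress metric would be recorded. */
    boolean onBatch() {
        batchesSinceUpdate++;
        if (batchesSinceUpdate == progressUpdateFrequency) {
            batchesSinceUpdate = 0;
            return true;
        }
        return false;
    }

    public static void main(String[] args) {
        ProgressScheduleSketch listener = new ProgressScheduleSketch();
        for (int batch = 1; batch <= 12; batch++) {
            if (listener.onBatch()) {
                System.out.println("progress recorded after batch " + batch);
            }
        }
    }
}
```

With the default frequency, progress is reported after batches 5 and 10 of a 12-batch run; a smaller value gives more frequent but noisier progress values.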
Modifier and Type | Method and Description |
---|---|
java.util.Map<java.lang.String,java.lang.Float> | getLatestEvaluations(): Returns the latest evaluations. |
static java.lang.String | metricName(Evaluator evaluator, java.lang.String stage): Returns the name of the metric created with the evaluator for the given stage. |
void | onEpoch(Trainer trainer): Listens to the end of an epoch during training. |
void | onTrainingBatch(Trainer trainer, TrainingListener.BatchData batchData): Listens to the end of training one batch of data during training. |
void | onTrainingBegin(Trainer trainer): Listens to the beginning of training. |
void | onValidationBatch(Trainer trainer, TrainingListener.BatchData batchData): Listens to the end of validating one batch of data during validation. |
Methods inherited from class TrainingListenerAdapter: onTrainingEnd
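The callbacks above are invoked by the training loop, not by user code. The sketch below uses a hypothetical, trimmed-down listener interface (no `Trainer` or `BatchData` arguments) and an assumed fit loop to show a typical invocation order: begin once, then per epoch the training batches, the validation batches, and `onEpoch`, and finally `onTrainingEnd`, which this listener does not override. The exact ordering inside a real trainer is an assumption here.

```java
import java.util.ArrayList;
import java.util.List;

/** Hypothetical illustration of listener callback order; not the DJL training loop. */
class CallbackOrderSketch {

    /** Trimmed-down stand-in for the TrainingListener callbacks. */
    interface SketchListener {
        void onTrainingBegin();
        void onTrainingBatch(int batch);
        void onValidationBatch(int batch);
        void onEpoch(int epoch);
        void onTrainingEnd();
    }

    /** Records every callback so the invocation order can be inspected. */
    static final class RecordingListener implements SketchListener {
        final List<String> events = new ArrayList<>();

        public void onTrainingBegin() { events.add("begin"); }
        public void onTrainingBatch(int batch) { events.add("train" + batch); }
        public void onValidationBatch(int batch) { events.add("validate" + batch); }
        public void onEpoch(int epoch) { events.add("epoch" + epoch); }
        public void onTrainingEnd() { events.add("end"); }
    }

    /** Assumed loop: training batches, then validation batches, then onEpoch, for each epoch. */
    static void fit(SketchListener listener, int epochs, int trainBatches, int validationBatches) {
        listener.onTrainingBegin();
        for (int epoch = 0; epoch < epochs; epoch++) {
            for (int b = 0; b < trainBatches; b++) {
                listener.onTrainingBatch(b);
            }
            for (int b = 0; b < validationBatches; b++) {
                listener.onValidationBatch(b);
            }
            listener.onEpoch(epoch);
        }
        listener.onTrainingEnd();
    }

    public static void main(String[] args) {
        RecordingListener recorder = new RecordingListener();
        fit(recorder, 1, 2, 1);
        System.out.println(recorder.events); // [begin, train0, train1, validate0, epoch0, end]
    }
}
```

Because `onEpoch` follows the validation batches in this ordering, an epoch-level listener can record both TRAIN_EPOCH and VALIDATE_EPOCH results at that point.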
public static final java.lang.String TRAIN_EPOCH
public static final java.lang.String TRAIN_PROGRESS
public static final java.lang.String TRAIN_ALL
public static final java.lang.String VALIDATE_EPOCH
public EvaluatorTrainingListener()
Constructs an EvaluatorTrainingListener that updates the training progress at the default frequency.
The current default frequency is every 5 batches.
public EvaluatorTrainingListener(int progressUpdateFrequency)
Constructs an EvaluatorTrainingListener that updates the training progress at the given frequency.
Parameters:
progressUpdateFrequency - the number of batches over which to accumulate an evaluator before it is stable enough to output
public void onEpoch(Trainer trainer)
Listens to the end of an epoch during training.
Specified by: onEpoch in interface TrainingListener
Overrides: onEpoch in class TrainingListenerAdapter
Parameters:
trainer - the trainer the listener is attached to
public void onTrainingBatch(Trainer trainer, TrainingListener.BatchData batchData)
Listens to the end of training one batch of data during training.
Specified by: onTrainingBatch in interface TrainingListener
Overrides: onTrainingBatch in class TrainingListenerAdapter
Parameters:
trainer - the trainer the listener is attached to
batchData - the data from the batch
public void onValidationBatch(Trainer trainer, TrainingListener.BatchData batchData)
Listens to the end of validating one batch of data during validation.
Specified by: onValidationBatch in interface TrainingListener
Overrides: onValidationBatch in class TrainingListenerAdapter
Parameters:
trainer - the trainer the listener is attached to
batchData - the data from the batch
public void onTrainingBegin(Trainer trainer)
Listens to the beginning of training.
Specified by: onTrainingBegin in interface TrainingListener
Overrides: onTrainingBegin in class TrainingListenerAdapter
Parameters:
trainer - the trainer the listener is attached to
public static java.lang.String metricName(Evaluator evaluator, java.lang.String stage)
Returns the name of the metric created with the evaluator for the given stage.
Parameters:
evaluator - the evaluator to read the metric from
stage - one of TRAIN_EPOCH, TRAIN_PROGRESS, or VALIDATE_EPOCH
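Since one evaluator is recorded in several stages, each (evaluator, stage) pair needs its own metric name. The sketch below shows one way such a lookup could work; it takes the evaluator name as a plain `String` instead of an `Evaluator`, and both the stage constant values and the `"<stage>_<evaluator name>"` format are assumptions made for illustration, not documented behavior. The parameter list above names three stages; the sketch also maps TRAIN_ALL because that stage is recorded to a metric as well.

```java
/** Hypothetical stand-in for metricName; constant values and name format are assumptions. */
class MetricNameSketch {
    // Assumed values; the documentation only gives the constant names.
    static final String TRAIN_EPOCH = "train/epoch";
    static final String TRAIN_PROGRESS = "train/progress";
    static final String TRAIN_ALL = "train/all";
    static final String VALIDATE_EPOCH = "validate/epoch";

    /** Maps (evaluator name, stage) to a distinct metric name; rejects unknown stages. */
    static String metricName(String evaluatorName, String stage) {
        switch (stage) {
            case TRAIN_EPOCH:
                return "train_epoch_" + evaluatorName;
            case TRAIN_PROGRESS:
                return "train_progress_" + evaluatorName;
            case TRAIN_ALL:
                return "train_all_" + evaluatorName;
            case VALIDATE_EPOCH:
                return "validate_epoch_" + evaluatorName;
            default:
                throw new IllegalArgumentException("Invalid metric stage: " + stage);
        }
    }

    public static void main(String[] args) {
        System.out.println(metricName("Accuracy", VALIDATE_EPOCH)); // validate_epoch_Accuracy
    }
}
```

Passing a stage string other than the known constants fails fast instead of silently producing a metric name nobody reads.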
public java.util.Map<java.lang.String,java.lang.Float> getLatestEvaluations()
Returns the latest evaluations. They are updated at the end of each epoch.
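Because the map is refreshed once per epoch, a caller can poll it after each epoch, for example to drive early stopping. The sketch below applies simple patience-based early stopping (a technique swapped in for illustration, not something this class provides) to a `Map<String, Float>` of the shape returned by `getLatestEvaluations()`. The key `"validate_Accuracy"` is a hypothetical name, and the sketch assumes higher values are better, so a loss-style metric would need the comparison inverted.

```java
import java.util.List;
import java.util.Map;

/** Hypothetical consumer of a getLatestEvaluations()-style map; the key name is assumed. */
class LatestEvaluationsSketch {
    private final String key;
    private final int patience;
    private float best = Float.NEGATIVE_INFINITY;
    private int epochsWithoutImprovement;

    LatestEvaluationsSketch(String key, int patience) {
        this.key = key;
        this.patience = patience;
    }

    /** Call once per epoch; returns true when the value has not improved for `patience` epochs. */
    boolean shouldStop(Map<String, Float> latestEvaluations) {
        Float value = latestEvaluations.get(key);
        if (value == null) {
            return false; // metric not recorded (yet), e.g. no validation data
        }
        if (value > best) { // assumes higher is better
            best = value;
            epochsWithoutImprovement = 0;
        } else {
            epochsWithoutImprovement++;
        }
        return epochsWithoutImprovement >= patience;
    }

    public static void main(String[] args) {
        LatestEvaluationsSketch stopper = new LatestEvaluationsSketch("validate_Accuracy", 2);
        List<Map<String, Float>> perEpoch = List.of(
                Map.of("validate_Accuracy", 0.80f),
                Map.of("validate_Accuracy", 0.85f),
                Map.of("validate_Accuracy", 0.84f),
                Map.of("validate_Accuracy", 0.83f));
        for (int epoch = 0; epoch < perEpoch.size(); epoch++) {
            if (stopper.shouldStop(perEpoch.get(epoch))) {
                System.out.println("stopping after epoch " + epoch); // epoch 3
                break;
            }
        }
    }
}
```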