public abstract class BaseTrainingListener extends Object implements TrainingListener
A no-op implementation of TrainingListener, to be used as a starting point for custom training callbacks.
Extend this class and selectively override only the methods you will actually use.

Constructor and Description |
---|
BaseTrainingListener() |
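
As a minimal sketch of the pattern described above, the example below uses hypothetical, simplified stand-ins: Listener, BaseListener, IterationCounter and countIterations are illustrative names, not DL4J API, and the real callbacks take a Model (and INDArray activations) rather than plain ints. The base class gives every hook an empty body, so the subclass overrides only iterationDone and inherits no-ops for everything else.

```java
// Hypothetical, simplified stand-ins illustrating a no-op base listener.
// These are NOT the DL4J types; the real callbacks take a Model argument.
public class NoOpListenerSketch {

    // Hypothetical callback interface with several hooks.
    interface Listener {
        void onEpochStart(int epoch);
        void onEpochEnd(int epoch);
        void iterationDone(int iteration, int epoch);
    }

    // No-op base class: every hook has an empty body, so subclasses
    // override only the hooks they actually need.
    abstract static class BaseListener implements Listener {
        @Override public void onEpochStart(int epoch) { }
        @Override public void onEpochEnd(int epoch) { }
        @Override public void iterationDone(int iteration, int epoch) { }
    }

    // Overrides a single hook; the epoch hooks stay inherited no-ops.
    static class IterationCounter extends BaseListener {
        int count = 0;

        @Override
        public void iterationDone(int iteration, int epoch) {
            count++;
        }
    }

    // Drives a listener through a fixed number of epochs and iterations
    // and returns how many iterationDone calls the counter observed.
    static int countIterations(int epochs, int itersPerEpoch) {
        IterationCounter counter = new IterationCounter();
        Listener listener = counter;
        int iteration = 0;
        for (int e = 0; e < epochs; e++) {
            listener.onEpochStart(e); // inherited no-op
            for (int i = 0; i < itersPerEpoch; i++) {
                listener.iterationDone(iteration++, e);
            }
            listener.onEpochEnd(e);   // inherited no-op
        }
        return counter.count;
    }

    public static void main(String[] args) {
        System.out.println(countIterations(3, 4));
    }
}
```

With the real library, the equivalent subclass would extend BaseTrainingListener and override iterationDone(Model, int, int) in the same way.
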

Modifier and Type | Method and Description |
---|---|
void | iterationDone(Model model, int iteration, int epoch): Event listener for each iteration. |
void | onBackwardPass(Model model): Called once per iteration (backward pass) after gradients have been calculated and updated. Gradients are available via Model.gradient(). |
void | onEpochEnd(Model model): Called once at the end of each epoch, when using methods such as MultiLayerNetwork.fit(DataSetIterator), ComputationGraph.fit(DataSetIterator) or ComputationGraph.fit(MultiDataSetIterator). |
void | onEpochStart(Model model): Called once at the start of each epoch, when using methods such as MultiLayerNetwork.fit(DataSetIterator), ComputationGraph.fit(DataSetIterator) or ComputationGraph.fit(MultiDataSetIterator). |
void | onForwardPass(Model model, List<INDArray> activations): Called once per iteration (forward pass) for activations (usually for a MultiLayerNetwork), only at training time. |
void | onForwardPass(Model model, Map<String,INDArray> activations): Called once per iteration (forward pass) for activations (usually for a ComputationGraph), only at training time. |
void | onGradientCalculation(Model model): Called once per iteration (backward pass) before the gradients are updated. Gradients are available via Model.gradient(). |

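
The summary above implies an ordering within each iteration: the forward pass, then the gradient calculation (pre-update), then the backward-pass callback (post-update). The sketch below models that ordering with a hypothetical trainer that applies a plain-SGD update in place to a reused gradient buffer, and calls iterationDone last. The types, the update rule, and the exact call sequence are assumptions drawn from the table's descriptions, not DL4J's internal implementation. It also shows why the gradient seen in the pre-update hook should be copied: a stored reference later reflects the in-place update. For brevity the hypothetical Listener uses Java 8 default methods with empty bodies instead of an abstract base class.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

// Hypothetical model of the callback order; NOT DL4J's training loop.
public class CallbackOrderSketch {

    // Simplified listener; the real TrainingListener methods take a Model.
    interface Listener {
        default void onEpochStart(int epoch) { }
        default void onForwardPass(double[] activations) { }
        default void onGradientCalculation(double[] gradient) { }
        default void onBackwardPass(double[] gradient) { }
        default void iterationDone(int iteration, int epoch) { }
        default void onEpochEnd(int epoch) { }
    }

    // Hypothetical trainer: the gradient buffer is reused across iterations
    // and modified in place by the update step (plain SGD scaling).
    static void train(Listener l, int epochs, int itersPerEpoch, double learningRate) {
        double[] gradient = new double[2];
        int iteration = 0;
        for (int e = 0; e < epochs; e++) {
            l.onEpochStart(e);
            for (int i = 0; i < itersPerEpoch; i++) {
                l.onForwardPass(new double[] {1.0, 2.0});
                gradient[0] = 0.5;                 // "raw" gradient for this step
                gradient[1] = -0.5;
                l.onGradientCalculation(gradient); // pre-update
                for (int k = 0; k < gradient.length; k++) {
                    gradient[k] *= learningRate;   // in-place update
                }
                l.onBackwardPass(gradient);        // post-update
                l.iterationDone(iteration++, e);
            }
            l.onEpochEnd(e);
        }
    }

    // Records the event sequence and compares copy vs reference semantics.
    static class Recorder implements Listener {
        final List<String> events = new ArrayList<>();
        double[] rawCopy;      // defensive copy taken in the pre-update hook
        double[] rawReference; // bare reference: sees later in-place changes
        double[] updated;      // copy taken in the post-update hook

        @Override public void onEpochStart(int epoch) { events.add("epochStart"); }
        @Override public void onForwardPass(double[] a) { events.add("forward"); }
        @Override public void onGradientCalculation(double[] g) {
            events.add("gradient");
            rawCopy = g.clone();
            rawReference = g;
        }
        @Override public void onBackwardPass(double[] g) {
            events.add("backward");
            updated = g.clone();
        }
        @Override public void iterationDone(int it, int epoch) { events.add("iterationDone"); }
        @Override public void onEpochEnd(int epoch) { events.add("epochEnd"); }
    }

    public static void main(String[] args) {
        Recorder r = new Recorder();
        train(r, 1, 2, 0.5);
        System.out.println(r.events);
        System.out.println(Arrays.toString(r.rawCopy) + " vs " + Arrays.toString(r.rawReference));
    }
}
```

After training, rawCopy still holds the pre-update values, while rawReference has silently changed to the post-update values, which is the hazard the onGradientCalculation note warns about.
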
public void onEpochStart(Model model)
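Description copied from interface: TrainingListener
Called once at the start of each epoch, when using methods such as MultiLayerNetwork.fit(DataSetIterator), ComputationGraph.fit(DataSetIterator) or ComputationGraph.fit(MultiDataSetIterator).
Specified by:
onEpochStart in interface TrainingListener
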
public void onEpochEnd(Model model)
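Description copied from interface: TrainingListener
Called once at the end of each epoch, when using methods such as MultiLayerNetwork.fit(DataSetIterator), ComputationGraph.fit(DataSetIterator) or ComputationGraph.fit(MultiDataSetIterator).
Specified by:
onEpochEnd in interface TrainingListener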
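
A sketch, under assumptions, of how the two epoch hooks can bracket per-epoch state: the hypothetical PerEpochCounter resets a counter in onEpochStart, increments it in iterationDone, and records it in onEpochEnd. The fit and batches helpers are illustrative stand-ins for a fit(DataSetIterator)-style loop in which one epoch is taken to be one full pass over the mini-batches; none of these names are DL4J API.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.stream.IntStream;

// Hypothetical illustration of epoch-scoped listener state; NOT DL4J API.
public class EpochListenerSketch {

    // Simplified hooks mirroring onEpochStart / iterationDone / onEpochEnd.
    interface Listener {
        void onEpochStart(int epoch);
        void iterationDone(int iteration, int epoch);
        void onEpochEnd(int epoch);
    }

    // Counts iterations within each epoch: reset at start, record at end.
    static class PerEpochCounter implements Listener {
        private int current;
        final List<Integer> itersPerEpoch = new ArrayList<>();

        @Override public void onEpochStart(int epoch) { current = 0; }
        @Override public void iterationDone(int iteration, int epoch) { current++; }
        @Override public void onEpochEnd(int epoch) { itersPerEpoch.add(current); }
    }

    // Hypothetical fit(): one epoch is one full pass over the mini-batches,
    // so each epoch hook fires exactly once per pass.
    static void fit(List<int[]> miniBatches, int epochs, Listener l) {
        int iteration = 0; // global iteration counter, not reset per epoch
        for (int e = 0; e < epochs; e++) {
            l.onEpochStart(e);
            for (int[] batch : miniBatches) {
                // forward pass, backward pass and update on `batch` would go here
                l.iterationDone(iteration++, e);
            }
            l.onEpochEnd(e);
        }
    }

    // Splits example indices 0..n-1 into batches of at most batchSize.
    static List<int[]> batches(int n, int batchSize) {
        List<int[]> out = new ArrayList<>();
        for (int start = 0; start < n; start += batchSize) {
            out.add(IntStream.range(start, Math.min(n, start + batchSize)).toArray());
        }
        return out;
    }

    public static void main(String[] args) {
        PerEpochCounter counter = new PerEpochCounter();
        fit(batches(10, 4), 3, counter);
        System.out.println(counter.itersPerEpoch);
    }
}
```

With 10 examples in batches of 4, each epoch consists of three iterations (4 + 4 + 2 examples), so the counter records 3 for each of the three epochs.
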
public void onForwardPass(Model model, List<INDArray> activations)
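Description copied from interface: TrainingListener
Called once per iteration (forward pass) for activations (usually for a MultiLayerNetwork), only at training time.
Specified by:
onForwardPass in interface TrainingListener
Parameters:
model - Model
activations - Layer activations (including input)

public void onForwardPass(Model model, Map<String,INDArray> activations)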
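Description copied from interface: TrainingListener
Called once per iteration (forward pass) for activations (usually for a ComputationGraph), only at training time.
Specified by:
onForwardPass in interface TrainingListener
Parameters:
model - Model
activations - Layer activations (including input)

public void onGradientCalculation(Model model)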
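Description copied from interface: TrainingListener
Called once per iteration (backward pass) before the gradients are updated. Gradients are available via Model.gradient().
Note that gradients will likely be updated in place, so they should be copied or processed synchronously in this method.
For updates (gradients after learning rate, momentum, RMSProp, etc. have been applied), see TrainingListener.onBackwardPass(Model).
Specified by:
onGradientCalculation in interface TrainingListener
Parameters:
model - Model

public void onBackwardPass(Model model)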
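Description copied from interface: TrainingListener
Called once per iteration (backward pass) after gradients have been calculated and updated. Gradients are available via Model.gradient().
Unlike TrainingListener.onGradientCalculation(Model), the gradients at this point are post-update, rather than the raw (pre-update) gradients available at that method call.
Specified by:
onBackwardPass in interface TrainingListener
Parameters:
model - Model

public void iterationDone(Model model, int iteration, int epoch)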
Description copied from interface: TrainingListener
Event listener for each iteration.
Specified by:
iterationDone in interface TrainingListener
Parameters:
model - the model iterating
iteration - the iteration

Copyright © 2019. All rights reserved.