Package ai.djl.training.listener
Class SaveModelTrainingListener
- java.lang.Object
  - ai.djl.training.listener.TrainingListenerAdapter
    - ai.djl.training.listener.SaveModelTrainingListener
- All Implemented Interfaces:
TrainingListener
public class SaveModelTrainingListener extends TrainingListenerAdapter
A TrainingListener that saves a model and can save checkpoints.

Nested Class Summary

Nested classes/interfaces inherited from interface ai.djl.training.listener.TrainingListener:
TrainingListener.BatchData, TrainingListener.Defaults

Constructor Summary
- SaveModelTrainingListener(java.lang.String outputDir)
  Constructs a SaveModelTrainingListener using the model's name.
- SaveModelTrainingListener(java.lang.String outputDir, java.lang.String overrideModelName)
  Constructs a SaveModelTrainingListener with an override model name.
- SaveModelTrainingListener(java.lang.String outputDir, java.lang.String overrideModelName, int checkpoint)
  Constructs a SaveModelTrainingListener with an override model name and a checkpoint frequency.

Method Summary
- int getCheckpoint()
  Returns the checkpoint frequency (or -1 for no checkpointing) in SaveModelTrainingListener.
- java.lang.String getOverrideModelName()
  Returns the override model name to save checkpoints with.
- void onEpoch(Trainer trainer)
  Listens to the end of an epoch during training.
- void onTrainingEnd(Trainer trainer)
  Listens to the end of training.
- protected void saveModel(Trainer trainer)
- void setCheckpoint(int checkpoint)
  Sets the checkpoint frequency in SaveModelTrainingListener.
- void setOverrideModelName(java.lang.String overrideModelName)
  Sets the override model name to save checkpoints with.
- void setSaveModelCallback(java.util.function.Consumer<Trainer> onSaveModel)
  Sets the callback function that is invoked when the model is saved.

Methods inherited from class ai.djl.training.listener.TrainingListenerAdapter:
onTrainingBatch, onTrainingBegin, onValidationBatch
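
SaveModelTrainingListener overrides only onEpoch and onTrainingEnd and inherits the remaining callbacks from TrainingListenerAdapter, which presumably supplies no-op implementations. The plain-Java sketch below illustrates that adapter pattern with hypothetical, simplified types: SimpleListener, SimpleListenerAdapter and RecordingSaveListener are not DJL classes, and their callbacks take no Trainer argument. It only shows why a subclass can ignore events it does not need.

```java
import java.util.ArrayList;
import java.util.List;

// Hypothetical, simplified listener interface (not DJL's TrainingListener).
interface SimpleListener {
    void onTrainingBegin();
    void onEpoch();
    void onTrainingEnd();
}

// Adapter: every callback is a no-op, so subclasses override only what they need.
class SimpleListenerAdapter implements SimpleListener {
    @Override public void onTrainingBegin() {}
    @Override public void onEpoch() {}
    @Override public void onTrainingEnd() {}
}

// Overrides just the two callbacks that matter for saving, recording each call.
class RecordingSaveListener extends SimpleListenerAdapter {
    final List<String> events = new ArrayList<>();

    @Override public void onEpoch() { events.add("epoch"); }
    @Override public void onTrainingEnd() { events.add("end"); }
}
```

Calling onTrainingBegin on a RecordingSaveListener falls through to the adapter's empty method, so only the epoch and end events are recorded.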

Constructor Detail

SaveModelTrainingListener
public SaveModelTrainingListener(java.lang.String outputDir)
Constructs a SaveModelTrainingListener using the model's name.
- Parameters:
  outputDir - the directory to output the checkpointed models in

SaveModelTrainingListener
public SaveModelTrainingListener(java.lang.String outputDir, java.lang.String overrideModelName)
Constructs a SaveModelTrainingListener with an override model name.
- Parameters:
  outputDir - the directory to output the checkpointed models in
  overrideModelName - an override model name to save checkpoints with

SaveModelTrainingListener
public SaveModelTrainingListener(java.lang.String outputDir, java.lang.String overrideModelName, int checkpoint)
Constructs a SaveModelTrainingListener with an override model name and a checkpoint frequency.
- Parameters:
  outputDir - the directory to output the checkpointed models in
  overrideModelName - an override model name to save checkpoints with
  checkpoint - save a checkpoint every n epochs, where n is this value (or -1 for no checkpoints)
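
The three constructors differ only in which settings they take explicitly. The sketch below is a hypothetical plain-Java model (ListenerSettings is not a DJL class) that assumes the shorter constructors fall back to no override name (null) and a checkpoint value of -1; those defaults are an assumption, so check the DJL source before relying on them.

```java
// Hypothetical stand-in mirroring the three documented constructors.
// ASSUMPTION: the shorter constructors delegate with overrideModelName = null
// and checkpoint = -1 (no checkpoints).
final class ListenerSettings {
    final String outputDir;          // where the model and checkpoints are written
    final String overrideModelName;  // null means "use the model's own name"
    final int checkpoint;            // save every `checkpoint` epochs; -1 disables

    ListenerSettings(String outputDir) {
        this(outputDir, null, -1);
    }

    ListenerSettings(String outputDir, String overrideModelName) {
        this(outputDir, overrideModelName, -1);
    }

    ListenerSettings(String outputDir, String overrideModelName, int checkpoint) {
        this.outputDir = outputDir;
        this.overrideModelName = overrideModelName;
        this.checkpoint = checkpoint;
    }
}
```

With the real class, new SaveModelTrainingListener("build/model", "mlp", 2) uses the third constructor: output goes to build/model under the name mlp, with a checkpoint every 2 epochs.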

Method Detail

onEpoch
public void onEpoch(Trainer trainer)
Listens to the end of an epoch during training.
- Specified by:
  onEpoch in interface TrainingListener
- Overrides:
  onEpoch in class TrainingListenerAdapter
- Parameters:
  trainer - the trainer the listener is attached to
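
The page does not say what onEpoch does beyond listening for the end of an epoch. Given the checkpoint setting, a plausible reading is that it saves a checkpoint whenever the number of completed epochs is a multiple of the checkpoint frequency. EpochCheckpointer below is a hypothetical plain-Java class encoding that assumption; it records epoch numbers instead of writing files.

```java
import java.util.ArrayList;
import java.util.List;

// Hypothetical model of per-epoch checkpointing (not DJL's source).
// ASSUMPTION: the listener counts completed epochs itself and saves when
// checkpoint > 0 and the count is a multiple of checkpoint.
final class EpochCheckpointer {
    private final int checkpoint;  // -1 (or any value <= 0) disables checkpoints
    private int epoch;             // number of epochs completed so far
    final List<Integer> savedAt = new ArrayList<>();

    EpochCheckpointer(int checkpoint) {
        this.checkpoint = checkpoint;
    }

    // Called once at the end of every epoch, like onEpoch(Trainer).
    void onEpoch() {
        epoch++;
        if (checkpoint > 0 && epoch % checkpoint == 0) {
            savedAt.add(epoch); // stands in for saving a checkpoint
        }
    }
}
```

With a frequency of 2 and five epochs, this sketch checkpoints after epochs 2 and 4; with -1 it never checkpoints.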

onTrainingEnd
public void onTrainingEnd(Trainer trainer)
Listens to the end of training.
- Specified by:
  onTrainingEnd in interface TrainingListener
- Overrides:
  onTrainingEnd in class TrainingListenerAdapter
- Parameters:
  trainer - the trainer the listener is attached to
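
The page only says that onTrainingEnd listens to the end of training. Since the class "saves a model", a reasonable guess is that the final model is written here, possibly skipping the write when the last epoch was just checkpointed so the same weights are not saved twice. FinalSavePolicy is a hypothetical predicate encoding that guess; treat the skip rule as unverified.

```java
// Hypothetical model of the end-of-training save decision (not DJL's source).
// ASSUMPTION: the final save is skipped only when the last completed epoch
// already produced a checkpoint.
final class FinalSavePolicy {
    static boolean shouldSaveAtEnd(int checkpoint, int epochsCompleted) {
        if (checkpoint <= 0) {
            return true; // checkpointing disabled: save once at the end
        }
        // Save unless the final epoch was itself a checkpoint epoch.
        return epochsCompleted % checkpoint != 0;
    }
}
```

Under this assumption, checkpoint = 3 over 10 epochs checkpoints after epochs 3, 6 and 9, and the end-of-training save then captures the weights after epoch 10.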

getOverrideModelName
public java.lang.String getOverrideModelName()
Returns the override model name to save checkpoints with.
- Returns:
  the override model name to save checkpoints with

setOverrideModelName
public void setOverrideModelName(java.lang.String overrideModelName)
Sets the override model name to save checkpoints with.
- Parameters:
  overrideModelName - the override model name to save checkpoints with
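
The one-argument constructor saves under the model's own name, while the override name replaces it when set. SaveNameResolver is a hypothetical helper showing that precedence, assuming a null override means "fall back to the model's name". In DJL the chosen name would then presumably be passed to Model.save(Path, String), but that wiring is not stated on this page.

```java
// Hypothetical helper (not part of DJL): picks the name used when saving.
// ASSUMPTION: a null override falls back to the model's own name.
final class SaveNameResolver {
    static String resolve(String modelName, String overrideModelName) {
        return overrideModelName != null ? overrideModelName : modelName;
    }
}
```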

getCheckpoint
public int getCheckpoint()
Returns the checkpoint frequency (or -1 for no checkpointing) in SaveModelTrainingListener.
- Returns:
  the checkpoint frequency (or -1 for no checkpointing)

setCheckpoint
public void setCheckpoint(int checkpoint)
Sets the checkpoint frequency in SaveModelTrainingListener.
- Parameters:
  checkpoint - how many epochs between checkpoints (or -1 for no checkpoints)

setSaveModelCallback
public void setSaveModelCallback(java.util.function.Consumer<Trainer> onSaveModel)
Sets the callback function that is invoked when the model is saved. This allows users to set custom properties in the model metadata.
- Parameters:
  onSaveModel - the callback function to invoke on model saving
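
The callback receives the Trainer, so in DJL a typical body would call trainer.getModel().setProperty(key, value) to attach metadata before the model is written. The sketch below models that flow with hypothetical stand-ins (FakeModel, FakeTrainer and CallbackSaver are not DJL classes, and the "Accuracy" key and value are illustrative only); it assumes the callback runs before the save itself.

```java
import java.util.HashMap;
import java.util.Map;
import java.util.function.Consumer;

// Hypothetical stand-in for a model with string metadata.
final class FakeModel {
    final Map<String, String> properties = new HashMap<>();

    void setProperty(String key, String value) { properties.put(key, value); }
}

// Hypothetical stand-in for a trainer that exposes its model.
final class FakeTrainer {
    private final FakeModel model = new FakeModel();

    FakeModel getModel() { return model; }
}

// ASSUMPTION: the user's callback runs just before the model is written.
final class CallbackSaver {
    private Consumer<FakeTrainer> onSaveModel; // null means "no callback"

    void setSaveModelCallback(Consumer<FakeTrainer> onSaveModel) {
        this.onSaveModel = onSaveModel;
    }

    // Returns a snapshot of the metadata that would be saved with the model.
    Map<String, String> save(FakeTrainer trainer) {
        if (onSaveModel != null) {
            onSaveModel.accept(trainer); // let the user add custom properties
        }
        return new HashMap<>(trainer.getModel().properties);
    }
}
```

Usage: saver.setSaveModelCallback(t -> t.getModel().setProperty("Accuracy", "0.97")) makes every subsequent save carry that property.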

saveModel
protected void saveModel(Trainer trainer)
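
saveModel has no description on this page; its protected visibility suggests it is an extension hook for subclasses. Assuming onEpoch and onTrainingEnd route their saves through it, a subclass could override saveModel(Trainer) to add behaviour while keeping the default. The template-method sketch below uses hypothetical classes (BaseSaver, AuditingSaver) and a String name in place of a Trainer.

```java
import java.util.ArrayList;
import java.util.List;

// Hypothetical template-method sketch (not DJL's source): the base class
// decides when to save, and subclasses can override the protected hook.
class BaseSaver {
    // ASSUMPTION: lifecycle callbacks delegate to saveModel when a save is due.
    void onTrainingEnd(String modelName) {
        saveModel(modelName);
    }

    protected void saveModel(String modelName) {
        // default behaviour would write the model to disk here
    }
}

class AuditingSaver extends BaseSaver {
    final List<String> saved = new ArrayList<>();

    @Override
    protected void saveModel(String modelName) {
        saved.add(modelName);       // record the save, e.g. for logging
        super.saveModel(modelName); // keep the default behaviour
    }
}
```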