Class CheckpointListener
- java.lang.Object
  - org.deeplearning4j.optimize.api.BaseTrainingListener
    - org.deeplearning4j.optimize.listeners.CheckpointListener
- All Implemented Interfaces:
Serializable, TrainingListener
public class CheckpointListener extends BaseTrainingListener implements Serializable
- See Also:
Serialized Form
Nested Class Summary
- static class CheckpointListener.Builder
-
Method Summary
- List<Checkpoint> availableCheckpoints(): List all available checkpoints.
- static List<Checkpoint> availableCheckpoints(File directory): List all available checkpoints.
- protected static int getEpoch(Model model)
- File getFileForCheckpoint(int checkpointNum): Get the model file for the given checkpoint number.
- static File getFileForCheckpoint(File rootDir, int checkpointNum)
- File getFileForCheckpoint(Checkpoint checkpoint): Get the model file for the given checkpoint.
- protected static int getIter(Model model)
- protected static String getModelType(Model model)
- void iterationDone(Model model, int iteration, int epoch): Event listener for each iteration.
- Checkpoint lastCheckpoint(): Return the most recent checkpoint, if one exists; otherwise return null.
- static Checkpoint lastCheckpoint(File rootDir): Return the most recent checkpoint, if one exists; otherwise return null.
- ComputationGraph loadCheckpointCG(int checkpointNum): Load a ComputationGraph for the given checkpoint number.
- static ComputationGraph loadCheckpointCG(File rootDir, int checkpointNum): Load a ComputationGraph for the given checkpoint that resides in the specified root directory.
- static ComputationGraph loadCheckpointCG(File rootDir, Checkpoint checkpoint): Load a ComputationGraph for the given checkpoint from the specified root directory.
- ComputationGraph loadCheckpointCG(Checkpoint checkpoint): Load a ComputationGraph for the given checkpoint.
- MultiLayerNetwork loadCheckpointMLN(int checkpointNum): Load a MultiLayerNetwork for the given checkpoint number.
- static MultiLayerNetwork loadCheckpointMLN(File rootDir, int checkpointNum): Load a MultiLayerNetwork for the given checkpoint number.
- static MultiLayerNetwork loadCheckpointMLN(File rootDir, Checkpoint checkpoint): Load a MultiLayerNetwork for the given checkpoint that resides in the specified root directory.
- MultiLayerNetwork loadCheckpointMLN(Checkpoint checkpoint): Load a MultiLayerNetwork for the given checkpoint.
- static ComputationGraph loadLastCheckpointCG(File rootDir): Load the last (most recent) checkpoint from the specified root directory.
- static MultiLayerNetwork loadLastCheckpointMLN(File rootDir): Load the last (most recent) checkpoint from the specified root directory.
- void onEpochEnd(Model model): Called once at the end of each epoch, when using methods such as MultiLayerNetwork.fit(DataSetIterator), ComputationGraph.fit(DataSetIterator) or ComputationGraph.fit(MultiDataSetIterator).
-
Methods inherited from class org.deeplearning4j.optimize.api.BaseTrainingListener
onBackwardPass, onEpochStart, onForwardPass, onForwardPass, onGradientCalculation
Method Detail
-
onEpochEnd
public void onEpochEnd(Model model)
Description copied from interface: TrainingListener
Called once at the end of each epoch, when using methods such as MultiLayerNetwork.fit(DataSetIterator), ComputationGraph.fit(DataSetIterator) or ComputationGraph.fit(MultiDataSetIterator).
- Specified by:
onEpochEnd in interface TrainingListener
- Overrides:
onEpochEnd in class BaseTrainingListener
-
iterationDone
public void iterationDone(Model model, int iteration, int epoch)
Description copied from interface: TrainingListener
Event listener for each iteration. Called once, after each parameter update has occurred while training the network.
- Specified by:
iterationDone in interface TrainingListener
- Overrides:
iterationDone in class BaseTrainingListener
- Parameters:
model - the model iterating
iteration - the iteration
epoch - the current epoch
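onEpochEnd and iterationDone both give a listener a chance to act as training progresses. The sketch below shows how those callbacks could decide whether to write a checkpoint. It assumes a hypothetical policy of saving every N iterations and every M epochs, and it assumes 0-based iteration and epoch counters. The class SavePolicy and its methods are illustrative only and are not part of the DL4J API:

```java
// Hypothetical save policy: NOT the DL4J implementation, only an illustration
// of how iterationDone/onEpochEnd callbacks could trigger a checkpoint.
// Counters are assumed to be 0-based.
final class SavePolicy {
    private final int everyNIterations; // <= 0 disables iteration-based saving
    private final int everyNEpochs;     // <= 0 disables epoch-based saving

    SavePolicy(int everyNIterations, int everyNEpochs) {
        this.everyNIterations = everyNIterations;
        this.everyNEpochs = everyNEpochs;
    }

    // Mirrors iterationDone(model, iteration, epoch): called after each parameter update.
    boolean saveAfterIteration(int iteration) {
        // With a 0-based counter, iteration + 1 updates have completed
        return everyNIterations > 0 && (iteration + 1) % everyNIterations == 0;
    }

    // Mirrors onEpochEnd(model): called once at the end of each epoch.
    boolean saveAfterEpoch(int epoch) {
        return everyNEpochs > 0 && (epoch + 1) % everyNEpochs == 0;
    }

    public static void main(String[] args) {
        SavePolicy policy = new SavePolicy(100, 2);
        System.out.println(policy.saveAfterIteration(99));  // true: 100 updates done
        System.out.println(policy.saveAfterIteration(100)); // false
        System.out.println(policy.saveAfterEpoch(1));       // true: 2 epochs done
    }
}
```

Keeping the decision in a small, pure class makes the schedule testable without running a training loop.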
-
getIter
protected static int getIter(Model model)
-
getEpoch
protected static int getEpoch(Model model)
-
availableCheckpoints
public List<Checkpoint> availableCheckpoints()
List all available checkpoints. A checkpoint is 'available' if its file can be loaded. Checkpoint files that have been automatically deleted (according to the configuration) are not returned.
- Returns:
List of checkpoint files that can be loaded
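Automatically deleted checkpoints are excluded from this list, so the set of available checkpoints depends on the retention configuration. The sketch below shows one plausible rule: keep only the N most recent checkpoints. It assumes checkpoints are identified by increasing integer numbers. KeepLastRetention and toDelete are hypothetical names, not DL4J API:

```java
import java.util.ArrayList;
import java.util.List;
import java.util.TreeSet;

// Hypothetical "keep last N" retention rule: given the numbers of all saved
// checkpoints, return the (oldest) ones that would be deleted.
final class KeepLastRetention {

    static List<Integer> toDelete(List<Integer> checkpointNums, int keepLast) {
        if (keepLast < 1) {
            throw new IllegalArgumentException("keepLast must be >= 1");
        }
        TreeSet<Integer> sorted = new TreeSet<>(checkpointNums); // ascending, deduplicated
        int toRemove = Math.max(0, sorted.size() - keepLast);
        List<Integer> result = new ArrayList<>();
        for (int num : sorted) {
            if (result.size() == toRemove) break; // everything left is kept
            result.add(num);
        }
        return result;
    }

    public static void main(String[] args) {
        System.out.println(toDelete(List.of(3, 0, 4, 1, 2), 2)); // [0, 1, 2]
    }
}
```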
-
availableCheckpoints
public static List<Checkpoint> availableCheckpoints(File directory)
List all available checkpoints. A checkpoint is 'available' if its file can be loaded. Checkpoint files that have been automatically deleted (according to the configuration) are not returned. Note that the checkpointInfo.txt file must exist, as it stores the checkpoint information.
- Returns:
List of checkpoint files that can be loaded from the specified directory
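The static variant relies on checkpointInfo.txt and reports only checkpoints whose files can still be loaded. The sketch below assumes a hypothetical one-line-per-checkpoint format, `checkpointNum,fileName`. The real layout of checkpointInfo.txt is not described here and may differ, and CheckpointIndex is an illustrative name. Parsing is kept separate from file-system access so the filtering logic can be checked in isolation:

```java
import java.io.IOException;
import java.io.UncheckedIOException;
import java.nio.file.Files;
import java.nio.file.Path;
import java.util.ArrayList;
import java.util.List;
import java.util.function.Predicate;

// Hypothetical sketch: the "checkpointNum,fileName" index format is an assumption,
// not the documented layout of DL4J's checkpointInfo.txt.
final class CheckpointIndex {

    // Pure filtering step: keep checkpoint numbers whose model file still exists.
    static List<Integer> parseAvailable(List<String> indexLines, Predicate<String> fileExists) {
        List<Integer> result = new ArrayList<>();
        for (String line : indexLines) {
            String[] parts = line.split(",", 2);
            if (parts.length < 2) continue; // skip blank or malformed lines
            if (fileExists.test(parts[1].trim())) {
                result.add(Integer.parseInt(parts[0].trim()));
            }
        }
        return result;
    }

    // I/O wrapper: the info file must exist, so this sketch treats its absence as an error.
    static List<Integer> available(Path directory) {
        Path info = directory.resolve("checkpointInfo.txt");
        if (!Files.exists(info)) {
            throw new IllegalStateException("Checkpoint info file not found: " + info);
        }
        try {
            return parseAvailable(Files.readAllLines(info),
                    name -> Files.exists(directory.resolve(name)));
        } catch (IOException e) {
            throw new UncheckedIOException(e);
        }
    }

    public static void main(String[] args) throws IOException {
        Path dir = Files.createTempDirectory("checkpoints");
        Files.write(dir.resolve("checkpointInfo.txt"),
                List.of("0,checkpoint_0.zip", "1,checkpoint_1.zip"));
        Files.createFile(dir.resolve("checkpoint_1.zip")); // checkpoint 0's file is gone
        System.out.println(available(dir)); // [1]
    }
}
```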
-
lastCheckpoint
public Checkpoint lastCheckpoint()
Return the most recent checkpoint, if one exists; otherwise return null.
- Returns:
The most recent Checkpoint, or null if there is none
-
lastCheckpoint
public static Checkpoint lastCheckpoint(File rootDir)
Return the most recent checkpoint, if one exists; otherwise return null.
- Parameters:
rootDir - Root directory for the checkpoint files
- Returns:
The most recent Checkpoint, or null if there is none
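lastCheckpoint returns null rather than throwing when no checkpoint exists, so callers need a null check. The sketch below shows the selection step. It assumes 'most recent' means the highest checkpoint number, and it models each checkpoint only by that number. LastCheckpoint is a hypothetical helper, not DL4J code:

```java
import java.util.List;

// Hypothetical helper: checkpoints are modelled by their numbers only, and
// "most recent" is assumed to mean "highest checkpoint number".
final class LastCheckpoint {

    // Returns the highest checkpoint number, or null when the list is empty,
    // matching the documented "otherwise return null" contract.
    static Integer last(List<Integer> availableNums) {
        return availableNums.stream().max(Integer::compare).orElse(null);
    }

    public static void main(String[] args) {
        System.out.println(last(List.of(2, 7, 5))); // 7
        System.out.println(last(List.of()));        // null
    }
}
```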
-
getFileForCheckpoint
public File getFileForCheckpoint(Checkpoint checkpoint)
Get the model file for the given checkpoint. The checkpoint model file must exist.
- Parameters:
checkpoint - Checkpoint to get the model file for
- Returns:
Model file for the checkpoint
-
getFileForCheckpoint
public File getFileForCheckpoint(int checkpointNum)
Get the model file for the given checkpoint number. The checkpoint model file must exist.
- Parameters:
checkpointNum - Checkpoint number to get the model file for
- Returns:
Model file for the checkpoint
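Resolving a checkpoint number to its model file amounts to a lookup that must fail loudly when the file is missing. The sketch below assumes a hypothetical naming scheme, `checkpoint_<num>_<modelType>.zip`. The actual file names written by CheckpointListener are not documented here, and CheckpointFiles is an illustrative name:

```java
import java.util.List;

// Hypothetical lookup: assumes model files are named "checkpoint_<num>_<modelType>.zip".
final class CheckpointFiles {

    static String findFile(List<String> fileNames, int checkpointNum) {
        // The trailing underscore stops checkpoint 1 from matching "checkpoint_10_..."
        String prefix = "checkpoint_" + checkpointNum + "_";
        for (String name : fileNames) {
            if (name.startsWith(prefix)) {
                return name;
            }
        }
        // The model file is required to exist, so a missing file is an error
        throw new IllegalStateException("No model file for checkpoint " + checkpointNum);
    }

    public static void main(String[] args) {
        List<String> files = List.of(
                "checkpoint_10_MultiLayerNetwork.zip",
                "checkpoint_1_MultiLayerNetwork.zip",
                "checkpointInfo.txt");
        System.out.println(findFile(files, 1)); // checkpoint_1_MultiLayerNetwork.zip
    }
}
```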
-
loadCheckpointMLN
public MultiLayerNetwork loadCheckpointMLN(Checkpoint checkpoint)
Load a MultiLayerNetwork for the given checkpoint.
- Parameters:
checkpoint - Checkpoint model to load
- Returns:
The loaded model
-
loadCheckpointMLN
public MultiLayerNetwork loadCheckpointMLN(int checkpointNum)
Load a MultiLayerNetwork for the given checkpoint number.
- Parameters:
checkpointNum - Number of the checkpoint model to load
- Returns:
The loaded model
-
loadCheckpointMLN
public static MultiLayerNetwork loadCheckpointMLN(File rootDir, Checkpoint checkpoint)
Load a MultiLayerNetwork for the given checkpoint that resides in the specified root directory.
- Parameters:
rootDir - Root directory for the checkpoint
checkpoint - Checkpoint model to load
- Returns:
The loaded model
-
loadCheckpointMLN
public static MultiLayerNetwork loadCheckpointMLN(File rootDir, int checkpointNum)
Load a MultiLayerNetwork for the given checkpoint number.
- Parameters:
rootDir - The directory that the checkpoint resides in
checkpointNum - Number of the checkpoint model to load
- Returns:
The loaded model
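Most methods here come in two forms. An instance method uses the listener's own root directory, and a static method takes rootDir explicitly. The static form is presumably convenient for loading checkpoints in a separate program that never created the listener. The sketch below shows that overload pattern, assuming the instance method simply delegates to the static one. CheckpointStore, fileFor and the file-name pattern are hypothetical:

```java
import java.io.File;

// Hypothetical illustration of the overload pattern: the instance method
// binds this object's root directory and delegates to the static variant.
final class CheckpointStore {
    private final File rootDir;

    CheckpointStore(File rootDir) {
        this.rootDir = rootDir;
    }

    // Static variant: works on any directory, no instance needed
    static File fileFor(File rootDir, int checkpointNum) {
        return new File(rootDir, "checkpoint_" + checkpointNum + ".zip"); // hypothetical name
    }

    // Instance variant: same lookup, using the directory this instance was built with
    File fileFor(int checkpointNum) {
        return fileFor(rootDir, checkpointNum);
    }

    public static void main(String[] args) {
        File root = new File("checkpoints");
        System.out.println(new CheckpointStore(root).fileFor(4).equals(fileFor(root, 4))); // true
    }
}
```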
-
loadLastCheckpointMLN
public static MultiLayerNetwork loadLastCheckpointMLN(File rootDir)
Load the last (most recent) checkpoint from the specified root directory.
- Parameters:
rootDir - Root directory to load the checkpoint from
- Returns:
MultiLayerNetwork for the last checkpoint
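A common use of these loaders is resuming training after an interruption. lastCheckpoint(File rootDir) is documented to return null when no checkpoint exists, so a caller can choose between loading and starting fresh. The generic sketch below shows that decision; ResumeTraining and resumeOrCreate are hypothetical names. With DL4J, the checkpoint could come from CheckpointListener.lastCheckpoint(rootDir) and the loader could be `c -> CheckpointListener.loadCheckpointMLN(rootDir, c)`, but that wiring is an assumption:

```java
import java.util.function.Function;
import java.util.function.Supplier;

// Hypothetical helper: resume from the latest checkpoint if there is one,
// otherwise build a new model. C is the checkpoint type, M the model type.
final class ResumeTraining {

    static <C, M> M resumeOrCreate(C lastCheckpoint, Function<C, M> load, Supplier<M> create) {
        // A null checkpoint means nothing has been saved yet, so start from scratch
        return lastCheckpoint == null ? create.get() : load.apply(lastCheckpoint);
    }

    public static void main(String[] args) {
        // Strings stand in for real checkpoints and models in this demo
        String fresh = resumeOrCreate(null, c -> "loaded " + c, () -> "new model");
        String resumed = resumeOrCreate(3, c -> "loaded checkpoint " + c, () -> "new model");
        System.out.println(fresh);   // new model
        System.out.println(resumed); // loaded checkpoint 3
    }
}
```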
-
loadCheckpointCG
public ComputationGraph loadCheckpointCG(Checkpoint checkpoint)
Load a ComputationGraph for the given checkpoint.
- Parameters:
checkpoint - Checkpoint model to load
- Returns:
The loaded model
-
loadCheckpointCG
public static ComputationGraph loadCheckpointCG(File rootDir, Checkpoint checkpoint)
Load a ComputationGraph for the given checkpoint from the specified root directory.
- Parameters:
rootDir - Root directory for the checkpoint
checkpoint - Checkpoint model to load
- Returns:
The loaded model
-
loadCheckpointCG
public ComputationGraph loadCheckpointCG(int checkpointNum)
Load a ComputationGraph for the given checkpoint number.
- Parameters:
checkpointNum - Number of the checkpoint model to load
- Returns:
The loaded model
-
loadCheckpointCG
public static ComputationGraph loadCheckpointCG(File rootDir, int checkpointNum)
Load a ComputationGraph for the given checkpoint that resides in the specified root directory.
- Parameters:
rootDir - Directory that the checkpoint resides in
checkpointNum - Number of the checkpoint model to load
- Returns:
The loaded model
-
loadLastCheckpointCG
public static ComputationGraph loadLastCheckpointCG(File rootDir)
Load the last (most recent) checkpoint from the specified root directory.
- Parameters:
rootDir - Root directory to load the checkpoint from
- Returns:
ComputationGraph for the last checkpoint
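The MultiLayerNetwork and ComputationGraph loaders are separate methods, so code that handles both kinds of network has to choose one. The sketch below keys that choice on a model-type string such as the one returned by getModelType(Model). The exact strings are an assumption (simple class names are used here), and ModelTypeDispatch is a hypothetical helper:

```java
// Hypothetical dispatch: assumes the stored model type is the simple class name,
// which this documentation does not confirm.
final class ModelTypeDispatch {

    enum Kind { MULTI_LAYER_NETWORK, COMPUTATION_GRAPH }

    static Kind kindOf(String modelType) {
        switch (modelType) {
            case "MultiLayerNetwork":
                return Kind.MULTI_LAYER_NETWORK; // caller would use loadLastCheckpointMLN(rootDir)
            case "ComputationGraph":
                return Kind.COMPUTATION_GRAPH;   // caller would use loadLastCheckpointCG(rootDir)
            default:
                throw new IllegalArgumentException("Unsupported model type: " + modelType);
        }
    }

    public static void main(String[] args) {
        System.out.println(kindOf("ComputationGraph")); // COMPUTATION_GRAPH
    }
}
```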