Package ai.djl.training
Interface TrainingConfig
All Known Implementing Classes: DefaultTrainingConfig
public interface TrainingConfig
An interface that is responsible for holding the configuration required by Trainer.
A trainer needs several pieces of information to carry out the training process, and this configuration is how that information is passed to it.
The options for the configuration are listed below. Only the loss is required; every other option can be omitted or left at its default.
- Loss (required): A loss function is used to measure how well a model matches the dataset. Because a lower value of the function is better, it is called the "loss" function. This is the only required configuration.
- Evaluator: An evaluator also measures how well a model matches the dataset. Unlike the loss, evaluators are only there for people to look at and are not used for optimization. Since many losses are not intuitive, adding other evaluators can help you understand how the model is doing. We recommend adding as many as possible.
- Device: The device is the hardware your model is trained on. Typically, this is either a CPU or a GPU. The default is to use a single GPU if one is available, or the CPU if not.
- Initializer: The initializer sets the initial values of the model's parameters before training. This can usually be left as the default initializer.
- Optimizer: The optimizer is the algorithm that updates the model parameters to minimize the loss function. There are a variety of optimizers, most of which are variants of stochastic gradient descent. When you are just starting, you can use the default optimizer. Later on, customizing the optimizer can result in faster training.
- ExecutorService: The ExecutorService is used for parallelization when training batches on multiple GPUs or loading data from the dataset. If none is provided, all operations will be sequential.
- TrainingListener: Training listeners add functionality to the training process through a listener interface. This can include showing training progress, stopping early if training fails, or recording performance metrics. We offer several ready-made sets of listeners in TrainingListener.Defaults.
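As a rough illustration of which options are mandatory, the following self-contained sketch models a configuration in which only the loss is required and every other option is optional. All of the types here (`ConfigSketch`, `LossFn`, `SimpleConfig`, and the `with...` methods) are hypothetical stand-ins written for this example. They are not DJL classes, and devices and evaluators are simplified to strings.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.Objects;
import java.util.concurrent.ExecutorService;

// Hypothetical stand-ins for illustration only; these are NOT the DJL types.
class ConfigSketch {

    /** A loss maps a prediction and a label to an error value; lower is better. */
    interface LossFn {
        double evaluate(double prediction, double label);
    }

    /** A minimal configuration: the loss is required, every other option is optional. */
    static final class SimpleConfig {
        private final LossFn loss;
        private final List<String> evaluators = new ArrayList<>();
        private String[] devices = {"cpu()"}; // simplified default: assumes no GPU is present
        private ExecutorService executor;     // null means operations run sequentially

        SimpleConfig(LossFn loss) {
            // The constructor enforces that the loss is the one required option.
            this.loss = Objects.requireNonNull(loss, "a loss function is required");
        }

        SimpleConfig withEvaluator(String evaluatorName) {
            evaluators.add(evaluatorName);
            return this;
        }

        SimpleConfig withDevices(String... deviceNames) {
            devices = deviceNames.clone();
            return this;
        }

        SimpleConfig withExecutor(ExecutorService executorService) {
            executor = executorService;
            return this;
        }

        LossFn getLossFunction() { return loss; }
        List<String> getEvaluators() { return List.copyOf(evaluators); }
        String[] getDevices() { return devices.clone(); }
        ExecutorService getExecutorService() { return executor; }
    }
}
```

For example, `new ConfigSketch.SimpleConfig((p, y) -> (p - y) * (p - y)).withEvaluator("accuracy")` yields a configuration with a squared-error loss, one evaluator, the CPU device, and no executor. In DJL itself, the implementing class DefaultTrainingConfig fills this role.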
Method Summary
Modifier and Type | Method | Description
Device[] | getDevices() | Gets the Device objects that are available for computation.
java.util.List<Evaluator> | getEvaluators() | Returns the list of Evaluators that should be computed during training.
java.util.concurrent.ExecutorService | getExecutorService() | Gets the ExecutorService for parallelization.
ai.djl.util.PairList<Initializer,java.util.function.Predicate<Parameter>> | getInitializers() | Gets a list of Initializer and Predicate pairs used to initialize the parameters of the model.
Loss | getLossFunction() | Gets the Loss function to compute the loss against.
Optimizer | getOptimizer() | Gets the Optimizer to use during training.
java.util.List<TrainingListener> | getTrainingListeners() | Returns the list of TrainingListeners that should be used during training.
Method Detail
getDevices
Device[] getDevices()
Gets the Device objects that are available for computation.
This is necessary for a Trainer, as it needs to know what kind of device it is running on and how many devices it is running on.
Returns: an array of Device
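To illustrate why a trainer needs both the kind and the number of devices, here is a small sketch in which devices are plain strings, a hypothetical simplification rather than DJL's Device class. It encodes the default described earlier (one GPU if available, otherwise the CPU) and shows one common data-parallel use of the device count: splitting a batch into one slice per device. The helper names `defaultDevices` and `sliceSizes` are made up for this example.

```java
import java.util.Arrays;

// Hypothetical sketch: NOT the DJL Device API. Devices are modeled as plain strings.
class DeviceSketch {

    /** The default described above: one GPU if any is available, otherwise the CPU. */
    static String[] defaultDevices(int availableGpus) {
        return availableGpus > 0 ? new String[] {"gpu(0)"} : new String[] {"cpu()"};
    }

    /**
     * Splits a batch into one slice per device, as data-parallel training commonly does.
     * Leftover examples go to the first devices, so slice sizes differ by at most one.
     */
    static int[] sliceSizes(int batchSize, int numDevices) {
        int[] sizes = new int[numDevices];
        for (int i = 0; i < numDevices; i++) {
            sizes[i] = batchSize / numDevices + (i < batchSize % numDevices ? 1 : 0);
        }
        return sizes;
    }

    public static void main(String[] args) {
        System.out.println(Arrays.toString(defaultDevices(0))); // [cpu()]
        System.out.println(Arrays.toString(sliceSizes(10, 4))); // [3, 3, 2, 2]
    }
}
```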
getInitializers
ai.djl.util.PairList<Initializer,java.util.function.Predicate<Parameter>> getInitializers()
Gets a list of Initializer and Predicate pairs used to initialize the parameters of the model.
Returns: a PairList of Initializer and Predicate<Parameter> pairs
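The return type pairs each Initializer with a Predicate<Parameter>, which suggests that the predicate selects which parameters an initializer applies to. The sketch below illustrates that idea with hypothetical stand-in types (`Param`, `Init`), not DJL classes. The rule that the first matching pair wins is an assumption made for this example and is not taken from the interface.

```java
import java.util.Arrays;
import java.util.List;
import java.util.Map;
import java.util.function.Predicate;

// Hypothetical stand-ins, NOT DJL classes: a parameter is a name plus values,
// and an initializer fills those values.
class InitSketch {

    static final class Param {
        final String name;
        final double[] values;

        Param(String name, int size) {
            this.name = name;
            this.values = new double[size];
        }
    }

    interface Init {
        void fill(double[] values);
    }

    /** Applies the first initializer whose predicate matches each parameter (assumed rule). */
    static void initialize(List<Param> params, List<Map.Entry<Init, Predicate<Param>>> rules) {
        for (Param p : params) {
            for (Map.Entry<Init, Predicate<Param>> rule : rules) {
                if (rule.getValue().test(p)) {
                    rule.getKey().fill(p.values);
                    break;
                }
            }
        }
    }

    public static void main(String[] args) {
        Param weight = new Param("dense_weight", 3);
        Param bias = new Param("dense_bias", 2);
        Init ones = v -> Arrays.fill(v, 1.0);
        Init zeros = v -> Arrays.fill(v, 0.0);
        Predicate<Param> isBias = p -> p.name.endsWith("bias");
        Predicate<Param> any = p -> true; // fallback for every other parameter
        List<Map.Entry<Init, Predicate<Param>>> rules =
                List.of(Map.entry(zeros, isBias), Map.entry(ones, any));
        initialize(List.of(weight, bias), rules);
        System.out.println(Arrays.toString(weight.values)); // [1.0, 1.0, 1.0]
        System.out.println(Arrays.toString(bias.values));   // [0.0, 0.0]
    }
}
```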
getOptimizer
Optimizer getOptimizer()
Gets the Optimizer to use during training.
Returns: an Optimizer
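Most optimizers are variants of stochastic gradient descent, whose basic update moves each parameter against its gradient: `w = w - learningRate * gradient`. The sketch below implements only that plain update behind a hypothetical `SimpleOptimizer` interface (not DJL's Optimizer class) and uses it to minimize the toy loss (w - 3)^2, whose gradient is 2 * (w - 3).

```java
// Hypothetical stand-in, NOT ai.djl.training.optimizer.Optimizer.
class SgdSketch {

    interface SimpleOptimizer {
        void update(double[] params, double[] grads);
    }

    /** Plain SGD: each parameter moves against its gradient, scaled by the learning rate. */
    static SimpleOptimizer sgd(double learningRate) {
        return (params, grads) -> {
            for (int i = 0; i < params.length; i++) {
                params[i] -= learningRate * grads[i];
            }
        };
    }

    public static void main(String[] args) {
        // Minimize loss(w) = (w - 3)^2, whose gradient is 2 * (w - 3).
        double[] w = {0.0};
        SimpleOptimizer opt = sgd(0.1);
        for (int step = 0; step < 100; step++) {
            double[] grad = {2 * (w[0] - 3)};
            opt.update(w, grad);
        }
        System.out.println(w[0]); // very close to 3.0
    }
}
```

With a learning rate of 0.1, each step shrinks the distance to the minimum by a factor of 0.8. Customized optimizers typically change this update, for example by adding momentum or per-parameter step sizes.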
getLossFunction
Loss getLossFunction()
Gets the Loss function to compute the loss against.
Returns: a Loss function
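As a concrete example of "lower is better", the following sketch computes the mean squared error between predictions and labels. It is a standalone helper written for this example, not DJL's Loss class. A perfect prediction gives 0, and the value grows as predictions drift from the labels.

```java
// Hypothetical sketch of a loss function; NOT ai.djl.training.loss.Loss.
class LossSketch {

    /** Mean squared error: the average of (prediction - label)^2. Lower is better. */
    static double meanSquaredError(double[] predictions, double[] labels) {
        if (predictions.length != labels.length || predictions.length == 0) {
            throw new IllegalArgumentException("need equal-length, non-empty arrays");
        }
        double sum = 0.0;
        for (int i = 0; i < predictions.length; i++) {
            double diff = predictions[i] - labels[i];
            sum += diff * diff;
        }
        return sum / predictions.length;
    }

    public static void main(String[] args) {
        double[] labels = {1.0, 2.0, 3.0};
        System.out.println(meanSquaredError(new double[] {1.0, 2.0, 3.0}, labels)); // 0.0
        System.out.println(meanSquaredError(new double[] {2.0, 2.0, 1.0}, labels)); // 1.6666666666666667
    }
}
```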
getExecutorService
java.util.concurrent.ExecutorService getExecutorService()
Gets the ExecutorService for parallelization.
Returns: an ExecutorService
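The configuration notes that operations run sequentially when no ExecutorService is provided. The sketch below shows one way such a switch could look. The hypothetical helper `runPerDevice` runs one task per device, either in a plain loop when the executor is null or through `CompletableFuture.supplyAsync` on the supplied executor. It illustrates the pattern only and is not DJL's internal code.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.function.IntFunction;

// Hypothetical sketch, NOT DJL internals: an optional ExecutorService decides
// whether per-device work runs in parallel or sequentially.
class ExecutorSketch {

    static <T> List<T> runPerDevice(int numDevices, IntFunction<T> work, ExecutorService executor) {
        List<T> results = new ArrayList<>();
        if (executor == null) {
            // No executor supplied: run each device's work one after another.
            for (int d = 0; d < numDevices; d++) {
                results.add(work.apply(d));
            }
            return results;
        }
        // Executor supplied: submit all tasks first, then collect results in device order.
        List<CompletableFuture<T>> futures = new ArrayList<>();
        for (int d = 0; d < numDevices; d++) {
            final int device = d;
            futures.add(CompletableFuture.supplyAsync(() -> work.apply(device), executor));
        }
        for (CompletableFuture<T> f : futures) {
            results.add(f.join());
        }
        return results;
    }

    public static void main(String[] args) {
        IntFunction<String> work = d -> "batch slice on device " + d;
        System.out.println(runPerDevice(2, work, null));
        ExecutorService pool = Executors.newFixedThreadPool(2);
        try {
            System.out.println(runPerDevice(2, work, pool));
        } finally {
            pool.shutdown(); // the caller owns the pool and must shut it down
        }
    }
}
```

Both calls produce the same results in the same order. Only the scheduling differs, which is why leaving the executor unset is a safe default.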
getEvaluators
java.util.List<Evaluator> getEvaluators()
Returns the list of Evaluators that should be computed during training.
Returns: a list of Evaluators
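Evaluators are computed for people to read and are not used for optimization. The sketch below shows a hypothetical `SimpleEvaluator` interface (not DJL's Evaluator) with an accuracy implementation, plus a `report` helper that gathers every evaluator's value into a map for display. Nothing in the report is fed back into parameter updates.

```java
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;

// Hypothetical stand-ins, NOT ai.djl.training.evaluator.Evaluator.
class EvaluatorSketch {

    interface SimpleEvaluator {
        String name();
        double evaluate(int[] predictedClasses, int[] labels);
    }

    /** Accuracy: the fraction of examples whose predicted class equals the label. */
    static final SimpleEvaluator ACCURACY = new SimpleEvaluator() {
        public String name() {
            return "Accuracy";
        }

        public double evaluate(int[] predicted, int[] labels) {
            int correct = 0;
            for (int i = 0; i < labels.length; i++) {
                if (predicted[i] == labels[i]) {
                    correct++;
                }
            }
            return (double) correct / labels.length;
        }
    };

    /** Computes every evaluator for reporting only; nothing here feeds the optimizer. */
    static Map<String, Double> report(List<SimpleEvaluator> evaluators, int[] predicted, int[] labels) {
        Map<String, Double> out = new LinkedHashMap<>();
        for (SimpleEvaluator e : evaluators) {
            out.put(e.name(), e.evaluate(predicted, labels));
        }
        return out;
    }

    public static void main(String[] args) {
        int[] predicted = {0, 1, 1, 2};
        int[] labels = {0, 1, 2, 2};
        System.out.println(report(List.of(ACCURACY), predicted, labels)); // {Accuracy=0.75}
    }
}
```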
getTrainingListeners
java.util.List<TrainingListener> getTrainingListeners()
Returns the list of TrainingListeners that should be used during training.
Returns: a list of TrainingListeners
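To show how a listener interface can add behavior without changing the training loop, the sketch below defines a hypothetical per-epoch callback (`EpochListener`, not DJL's TrainingListener) with two implementations. One records metrics, and the other requests early stopping when the loss stops improving. The toy `train` method replays precomputed epoch losses instead of training a real model, and the patience-based stopping rule is an assumption chosen for illustration.

```java
import java.util.ArrayList;
import java.util.List;

// Hypothetical listener sketch, NOT ai.djl.training.listener.TrainingListener.
class ListenerSketch {

    interface EpochListener {
        void onEpoch(int epoch, double loss);
    }

    /** Records every epoch's loss, e.g. for progress display or metric logging. */
    static final class LossRecorder implements EpochListener {
        final List<Double> history = new ArrayList<>();

        public void onEpoch(int epoch, double loss) {
            history.add(loss);
        }
    }

    /** Requests a stop once the loss has not improved for `patience` epochs in a row. */
    static final class EarlyStopper implements EpochListener {
        private final int patience;
        private double best = Double.POSITIVE_INFINITY;
        private int epochsWithoutImprovement;
        boolean shouldStop;

        EarlyStopper(int patience) {
            this.patience = patience;
        }

        public void onEpoch(int epoch, double loss) {
            if (loss < best) {
                best = loss;
                epochsWithoutImprovement = 0;
            } else if (++epochsWithoutImprovement >= patience) {
                shouldStop = true;
            }
        }
    }

    /** Replays precomputed epoch losses to all listeners; returns the number of epochs run. */
    static int train(double[] epochLosses, List<EpochListener> listeners, EarlyStopper stopper) {
        int epoch = 0;
        while (epoch < epochLosses.length && !stopper.shouldStop) {
            for (EpochListener l : listeners) {
                l.onEpoch(epoch, epochLosses[epoch]);
            }
            epoch++;
        }
        return epoch;
    }
}
```

For example, with losses `{1.0, 0.5, 0.6, 0.7, 0.4}` and a patience of 2, training stops after four epochs, because epochs 2 and 3 fail to beat the best loss of 0.5.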