public interface TrainingConfig

An interface that holds the configuration required by a Trainer.

A trainer requires different information to facilitate the training process. This information is passed by using this configuration.

The options for the configuration are:
Loss - A loss function is used to measure how well a model matches the dataset. Because a lower value of the function is better, it is called the "loss" function. This is the only required option.
Evaluator - An evaluator is also used to measure how well a model matches the dataset. Unlike the loss, evaluators are only for people to look at and are not used for optimization. Since many losses are not as intuitive, adding other evaluators can help you understand how the model is doing. We recommend adding as many as possible.
Device - The device is the hardware your model should be trained on. Typically, this is either CPU or GPU. The default is to use a single GPU if one is available, or the CPU if not.
Initializer - The initializer is used to set the initial values of the model's parameters before training. This can usually be left as the default initializer.
Optimizer - The optimizer is the algorithm that updates the model parameters to minimize the loss function. There are a variety of optimizers, most of which are variants of stochastic gradient descent. When you are just starting, you can use the default optimizer. Later on, customizing the optimizer can result in faster training.
ExecutorService - The executorService is used for parallelization when training batches on multiple GPUs or loading data from the dataset. If none is provided, all operations will be sequential.
TrainingListener - The training listeners add additional functionality to the training process through a listener interface. This can include showing training progress, stopping early if training fails, or recording performance metrics. We offer several convenient sets of defaults in TrainingListener.Defaults.
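To make the builder-style pattern behind such a configuration concrete, here is a minimal, self-contained sketch. The class SimpleTrainingConfig and its methods are hypothetical stand-ins written for illustration, not DJL's actual API; a scalar function stands in for the Loss and Evaluator types.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.function.DoubleBinaryOperator;

// Hypothetical stand-ins for the option types described above; NOT DJL's real API.
// Only the loss is required; every other option gets a sensible default.
class SimpleTrainingConfig {
    private final DoubleBinaryOperator loss;            // required: loss(prediction, label)
    private final List<DoubleBinaryOperator> evaluators = new ArrayList<>();
    private String[] devices = {"cpu"};                 // default: a single CPU
    private double learningRate = 0.01;                 // stands in for an Optimizer setting

    SimpleTrainingConfig(DoubleBinaryOperator loss) {   // loss is the only required option
        this.loss = loss;
    }

    SimpleTrainingConfig addEvaluator(DoubleBinaryOperator evaluator) {
        evaluators.add(evaluator);                      // evaluators are for reporting only
        return this;                                    // builder style: calls can be chained
    }

    SimpleTrainingConfig optDevices(String... d) {
        devices = d;                                    // e.g. train on several GPUs
        return this;
    }

    DoubleBinaryOperator getLossFunction() { return loss; }
    List<DoubleBinaryOperator> getEvaluators() { return evaluators; }
    String[] getDevices() { return devices; }
    double getLearningRate() { return learningRate; }
}
```

A configuration would then be assembled by chaining the optional calls onto the required loss, for example `new SimpleTrainingConfig((p, y) -> (p - y) * (p - y)).addEvaluator((p, y) -> Math.abs(p - y)).optDevices("gpu0", "gpu1")`.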
Modifier and Type | Method and Description |
---|---|
Device[] | getDevices() - Gets the Devices that are available for computation. |
java.util.List&lt;Evaluator&gt; | getEvaluators() - Returns the list of Evaluators that should be computed during training. |
java.util.concurrent.ExecutorService | getExecutorService() - Gets the ExecutorService for parallelization. |
ai.djl.util.PairList&lt;Initializer,java.util.function.Predicate&lt;Parameter&gt;&gt; | getInitializers() - Gets a list of Initializer and Predicate pairs to initialize the parameters of the model. |
Loss | getLossFunction() - Gets the Loss function to compute the loss against. |
Optimizer | getOptimizer() - Gets the Optimizer to use during training. |
java.util.List&lt;TrainingListener&gt; | getTrainingListeners() - Returns the list of TrainingListeners that should be used during training. |
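Note that getInitializers() pairs each Initializer with a Predicate that selects which parameters it applies to. The sketch below illustrates that selection pattern using plain java.util types (a String parameter name stands in for DJL's Parameter, and a String label stands in for an Initializer; the class InitializerSelection is hypothetical):

```java
import java.util.List;
import java.util.Map;
import java.util.function.Predicate;

// Hypothetical illustration of initializer/predicate pairing; not DJL code.
class InitializerSelection {
    // Return the first initializer whose predicate accepts the parameter name,
    // falling back to a default when no predicate matches.
    static String pick(List<Map.Entry<String, Predicate<String>>> initializers,
                       String parameterName) {
        for (Map.Entry<String, Predicate<String>> entry : initializers) {
            if (entry.getValue().test(parameterName)) {
                return entry.getKey();
            }
        }
        return "default";
    }
}
```

With a list that maps, say, `n -> n.endsWith("weight")` to one initializer and `n -> n.endsWith("bias")` to another, weight and bias parameters of the model receive different initial values.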
Device[] getDevices()

Gets the Devices that are available for computation. This is necessary for a Trainer, as it needs to know what kind of device it is running on, and how many devices it is running on.

ai.djl.util.PairList&lt;Initializer,java.util.function.Predicate&lt;Parameter&gt;&gt; getInitializers()

Gets a list of Initializer and Predicate pairs to initialize the parameters of the model.

Optimizer getOptimizer()

Gets the Optimizer to use during training.

Loss getLossFunction()

Gets the Loss function to compute the loss against.

java.util.concurrent.ExecutorService getExecutorService()

Gets the ExecutorService for parallelization.

java.util.List&lt;Evaluator&gt; getEvaluators()

Returns the list of Evaluators that should be computed during training.

java.util.List&lt;TrainingListener&gt; getTrainingListeners()

Returns the list of TrainingListeners that should be used during training.
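The split between the loss and the evaluators is the key design point of this interface: only the loss drives optimization, while evaluators are computed purely for reporting. A minimal, self-contained sketch of one such step (the class TrainingStep is hypothetical, not part of DJL; scalar functions again stand in for Loss and Evaluator):

```java
import java.util.List;
import java.util.function.DoubleBinaryOperator;

// Hypothetical sketch of how a trainer might consume its configuration for
// one prediction: the loss value is returned for optimization, while each
// evaluator is computed and printed for the user only.
class TrainingStep {
    static double step(DoubleBinaryOperator loss,
                       List<DoubleBinaryOperator> evaluators,
                       double prediction, double label) {
        for (DoubleBinaryOperator evaluator : evaluators) {
            double metric = evaluator.applyAsDouble(prediction, label);
            System.out.println("metric = " + metric);   // reporting only
        }
        return loss.applyAsDouble(prediction, label);   // fed to the optimizer
    }
}
```

Because the evaluators never influence the returned loss, adding more of them changes what is reported but not how the model trains.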