public class DefaultTrainingConfig extends java.lang.Object implements TrainingConfig
DefaultTrainingConfig is an implementation of the TrainingConfig interface.

Constructor and Description |
---|
DefaultTrainingConfig(Loss loss): Creates an instance of DefaultTrainingConfig with the given Loss. |

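
The sketch below shows the constructor's documented defaults and the chainable `opt*`/`add*` style. The defaults are the given Loss, Adam as the optimizer, and empty evaluator and listener lists. This is not DJL code. `SketchConfig` and `ConstructorDemo` are hypothetical stand-ins, and plain strings take the place of the Loss, Optimizer, Evaluator, and TrainingListener types.

```java
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;

// Hypothetical stand-in for DefaultTrainingConfig; strings replace DJL's Loss/Optimizer/etc. types.
final class SketchConfig {
    private final String loss;
    private String optimizer = "Adam"; // the documented default optimizer
    private final List<String> evaluators = new ArrayList<>(); // empty by default
    private final List<String> listeners = new ArrayList<>();  // empty by default

    SketchConfig(String loss) {
        this.loss = loss;
    }

    // Each setter returns `this`, so calls can be chained like the opt*/add* methods.
    SketchConfig optOptimizer(String optimizer) {
        this.optimizer = optimizer;
        return this;
    }

    SketchConfig addEvaluator(String evaluator) {
        evaluators.add(evaluator);
        return this;
    }

    SketchConfig addTrainingListeners(String... newListeners) {
        Collections.addAll(listeners, newListeners);
        return this;
    }

    String getLossFunction() { return loss; }
    String getOptimizer() { return optimizer; }
    // In this sketch the getters expose read-only views of the lists.
    List<String> getEvaluators() { return Collections.unmodifiableList(evaluators); }
    List<String> getTrainingListeners() { return Collections.unmodifiableList(listeners); }
}

class ConstructorDemo {
    public static void main(String[] args) {
        SketchConfig defaults = new SketchConfig("softmaxCrossEntropy");
        System.out.println(defaults.getOptimizer());         // Adam
        System.out.println(defaults.getEvaluators().size()); // 0

        SketchConfig tuned = new SketchConfig("softmaxCrossEntropy")
                .optOptimizer("SGD")
                .addEvaluator("accuracy")
                .addTrainingListeners("logging", "checkpoint");
        System.out.println(tuned.getOptimizer() + " " + tuned.getTrainingListeners()); // SGD [logging, checkpoint]
    }
}
```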
Modifier and Type | Method and Description |
---|---|
DefaultTrainingConfig | addEvaluator(Evaluator evaluator): Adds an Evaluator that needs to be computed during training. |
DefaultTrainingConfig | addTrainingListeners(TrainingListener... listeners): Adds TrainingListeners for training. |
Device[] | getDevices(): Gets the Devices that are available for computation. |
java.util.List<Evaluator> | getEvaluators(): Returns the list of Evaluators that should be computed during training. |
java.util.concurrent.ExecutorService | getExecutorService(): Gets the ExecutorService for parallelization. |
ai.djl.util.PairList<Initializer,java.util.function.Predicate<Parameter>> | getInitializers(): Gets a list of Initializer and Predicate pairs to initialize the parameters of the model. |
Loss | getLossFunction(): Gets the Loss function to compute the loss against. |
Optimizer | getOptimizer(): Gets the Optimizer to use during training. |
java.util.List<TrainingListener> | getTrainingListeners(): Returns the list of TrainingListeners that should be used during training. |
DefaultTrainingConfig | optDevices(Device[] devices): Sets the array of Devices available for training. |
DefaultTrainingConfig | optExecutorService(): Sets the ExecutorService to the global ForkJoinPool.commonPool(). |
DefaultTrainingConfig | optExecutorService(java.util.concurrent.ExecutorService executorService): Sets the ExecutorService to train with multiple threads. |
DefaultTrainingConfig | optInitializer(Initializer initializer, Parameter.Type type): Sets the Initializer to use for the parameters of the given type (default from paper). |
DefaultTrainingConfig | optInitializer(Initializer initializer, java.util.function.Predicate<Parameter> predicate): Sets the Initializer to use for the parameters that match the predicate (default from paper). |
DefaultTrainingConfig | optInitializer(Initializer initializer, java.lang.String name): Sets the Initializer to use for the parameter with the given name (default from paper). |
DefaultTrainingConfig | optOptimizer(Optimizer optimizer): Sets the Optimizer to use during training. |

public DefaultTrainingConfig(Loss loss)
Creates an instance of DefaultTrainingConfig with the given Loss.
DefaultTrainingConfig creates a default TrainingConfig with Adam as the optimizer and the given Loss. The evaluators and listeners are left to the user's discretion.
Parameters:
loss - the loss to use for training

public DefaultTrainingConfig optInitializer(Initializer initializer, Parameter.Type type)
Sets the Initializer to use for the parameters (default from paper).
Parameters:
initializer - the initializer to use for the parameters
type - the Parameter.Type of the parameters
Returns:
this DefaultTrainingConfig

public DefaultTrainingConfig optInitializer(Initializer initializer, java.lang.String name)
Sets the Initializer to use for the parameters (default from paper).
Parameters:
initializer - the initializer to use for the parameters
name - the name of the parameter
Returns:
this DefaultTrainingConfig

public DefaultTrainingConfig optInitializer(Initializer initializer, java.util.function.Predicate<Parameter> predicate)
Sets the Initializer to use for the parameters (default from paper).
Parameters:
initializer - the initializer to use for the parameters
predicate - the predicate that identifies the parameters
Returns:
this DefaultTrainingConfig

public DefaultTrainingConfig optDevices(Device[] devices)
Sets the array of Devices available for training.
Parameters:
devices - an array of devices to be set
Returns:
this DefaultTrainingConfig

public DefaultTrainingConfig optOptimizer(Optimizer optimizer)
Sets the Optimizer to use during training.
Parameters:
optimizer - the optimizer to be set
Returns:
this DefaultTrainingConfig

public DefaultTrainingConfig optExecutorService()
Sets the ExecutorService to the global ForkJoinPool.commonPool().
Returns:
this DefaultTrainingConfig

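Setting the executor to `ForkJoinPool.commonPool()` means that parallel work shares the JVM-wide pool. The stdlib-only sketch below runs one task per device index on that pool and sums the results. `CommonPoolSketch`, `runOnDevices`, and the per-device work are hypothetical illustrations, not part of DJL.

```java
import java.util.List;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.ForkJoinPool;
import java.util.function.IntUnaryOperator;
import java.util.stream.Collectors;
import java.util.stream.IntStream;

final class CommonPoolSketch {
    // Hypothetical helper: runs `work` once per device index on the executor and sums the results.
    static int runOnDevices(ExecutorService executor, int deviceCount, IntUnaryOperator work) {
        List<CompletableFuture<Integer>> futures = IntStream.range(0, deviceCount)
                .mapToObj(i -> CompletableFuture.supplyAsync(() -> work.applyAsInt(i), executor))
                .collect(Collectors.toList());
        // join() waits for each task; the common pool is never shut down by the caller.
        return futures.stream().mapToInt(CompletableFuture::join).sum();
    }

    public static void main(String[] args) {
        ExecutorService pool = ForkJoinPool.commonPool();
        // Pretend each of 4 devices processes a shard and reports a count of 8.
        System.out.println(runOnDevices(pool, 4, device -> 8)); // 32
    }
}
```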
public DefaultTrainingConfig optExecutorService(java.util.concurrent.ExecutorService executorService)
Sets the ExecutorService to train with multiple threads.
Parameters:
executorService - the executor service
Returns:
this DefaultTrainingConfig

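Passing your own ExecutorService instead of the common pool lets you bound the number of threads. This page does not say whether DJL shuts down a caller-supplied executor. The sketch below therefore assumes the caller owns the pool and closes it when done. `FixedPoolSketch`, `trainShards`, and the doubling "work" are hypothetical placeholders.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;

final class FixedPoolSketch {
    // Hypothetical helper: submits one task per shard and collects results in shard order.
    static List<Integer> trainShards(ExecutorService executor, List<Integer> shardSizes)
            throws InterruptedException, ExecutionException {
        List<Callable<Integer>> tasks = new ArrayList<>();
        for (int size : shardSizes) {
            tasks.add(() -> size * 2); // placeholder for real per-shard work
        }
        List<Integer> results = new ArrayList<>();
        // invokeAll returns futures in the same order as the submitted tasks.
        for (Future<Integer> f : executor.invokeAll(tasks)) {
            results.add(f.get());
        }
        return results;
    }

    public static void main(String[] args) throws Exception {
        ExecutorService pool = Executors.newFixedThreadPool(2); // at most two worker threads
        try {
            System.out.println(trainShards(pool, List.of(1, 2, 3))); // [2, 4, 6]
        } finally {
            pool.shutdown(); // in this sketch the caller owns the pool and shuts it down
        }
    }
}
```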
public DefaultTrainingConfig addEvaluator(Evaluator evaluator)
Adds an Evaluator that needs to be computed during training.
Parameters:
evaluator - the evaluator to be added
Returns:
this DefaultTrainingConfig

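An evaluator added here tracks a metric as batches are processed during training. The sketch below shows the kind of running metric an Evaluator might accumulate. `SketchAccuracy`, with its `update`, `getAccuracy`, and `reset` methods, is invented for illustration and is not DJL's Evaluator interface or its Accuracy class.

```java
// Hypothetical running-accuracy metric; not DJL's Evaluator API.
final class SketchAccuracy {
    private long correct;
    private long total;

    // Adds one batch of predicted and true class labels.
    void update(int[] predicted, int[] labels) {
        if (predicted.length != labels.length) {
            throw new IllegalArgumentException("batch sizes differ");
        }
        for (int i = 0; i < labels.length; i++) {
            if (predicted[i] == labels[i]) {
                correct++;
            }
        }
        total += labels.length;
    }

    // Fraction of correct predictions so far (0.0 before any batch).
    double getAccuracy() {
        return total == 0 ? 0.0 : (double) correct / total;
    }

    // Clears the counts, for example at the start of a new epoch.
    void reset() {
        correct = 0;
        total = 0;
    }

    public static void main(String[] args) {
        SketchAccuracy acc = new SketchAccuracy();
        acc.update(new int[] {1, 0, 2, 2}, new int[] {1, 1, 2, 0}); // 2 of 4 correct
        acc.update(new int[] {3, 3, 1, 0}, new int[] {3, 3, 1, 0}); // 4 of 4 correct
        System.out.println(acc.getAccuracy()); // 0.75
    }
}
```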
public DefaultTrainingConfig addTrainingListeners(TrainingListener... listeners)
Adds TrainingListeners for training.
Parameters:
listeners - the TrainingListeners to add
Returns:
this DefaultTrainingConfig

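Because `addTrainingListeners` takes varargs, several listeners can be registered in one call, and each of them is then notified of training events. The sketch below assumes a hypothetical single-callback `SketchListener` interface with an `onEpoch` method, plus a hypothetical `ListenerRegistry`. DJL's TrainingListener defines its own callbacks, which this page does not list.

```java
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;

// Hypothetical listener with a single callback; not DJL's TrainingListener interface.
interface SketchListener {
    void onEpoch(int epoch);
}

final class ListenerRegistry {
    private final List<SketchListener> listeners = new ArrayList<>();

    // Varargs registration, mirroring addTrainingListeners(TrainingListener... listeners).
    ListenerRegistry addTrainingListeners(SketchListener... newListeners) {
        Collections.addAll(listeners, newListeners);
        return this;
    }

    // Notifies every registered listener in registration order.
    void fireEpoch(int epoch) {
        for (SketchListener listener : listeners) {
            listener.onEpoch(epoch);
        }
    }

    public static void main(String[] args) {
        List<String> log = new ArrayList<>();
        new ListenerRegistry()
                .addTrainingListeners(e -> log.add("logger:" + e), e -> log.add("checkpoint:" + e))
                .fireEpoch(1);
        System.out.println(log); // [logger:1, checkpoint:1]
    }
}
```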
public Device[] getDevices()
Gets the Devices that are available for computation.
A Trainer needs this to know what kind of device it is running on and how many devices are available.
Specified by:
getDevices in interface TrainingConfig
Returns:
an array of Devices

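A Trainer uses this array both to choose where to run and to decide how many devices share the work. The sketch below assumes that, when `optDevices` was never called, the config falls back to an engine-provided default. This page does not confirm that behavior. It also shows one way a trainer might split a batch evenly across the returned devices. `SketchDevices`, `shardSizes`, the string device names, and the default supplier are all hypothetical.

```java
import java.util.Arrays;
import java.util.function.Supplier;

final class SketchDevices {
    private String[] devices; // null until optDevices is called
    private final Supplier<String[]> engineDefault; // hypothetical stand-in for an engine's device list

    SketchDevices(Supplier<String[]> engineDefault) {
        this.engineDefault = engineDefault;
    }

    SketchDevices optDevices(String[] devices) {
        this.devices = devices;
        return this;
    }

    // Returns the explicitly set devices, or the default if none were set (assumed behavior).
    String[] getDevices() {
        return devices != null ? devices : engineDefault.get();
    }

    // Splits a batch as evenly as possible; the first (batchSize % deviceCount) devices get one extra item.
    static int[] shardSizes(int batchSize, int deviceCount) {
        int[] sizes = new int[deviceCount];
        for (int i = 0; i < deviceCount; i++) {
            sizes[i] = batchSize / deviceCount + (i < batchSize % deviceCount ? 1 : 0);
        }
        return sizes;
    }

    public static void main(String[] args) {
        SketchDevices config = new SketchDevices(() -> new String[] {"cpu()"});
        System.out.println(Arrays.toString(config.getDevices())); // [cpu()]

        config.optDevices(new String[] {"gpu(0)", "gpu(1)", "gpu(2)"});
        int n = config.getDevices().length;
        System.out.println(Arrays.toString(shardSizes(10, n))); // [4, 3, 3]
    }
}
```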
public ai.djl.util.PairList<Initializer,java.util.function.Predicate<Parameter>> getInitializers()
Gets a list of Initializer and Predicate pairs to initialize the parameters of the model.
Specified by:
getInitializers in interface TrainingConfig
Returns:
a list of Initializer and Predicate pairs

public Optimizer getOptimizer()
Gets the Optimizer to use during training.
Specified by:
getOptimizer in interface TrainingConfig
Returns:
an Optimizer

public Loss getLossFunction()
Gets the Loss function to compute the loss against.
Specified by:
getLossFunction in interface TrainingConfig
Returns:
the Loss function

public java.util.concurrent.ExecutorService getExecutorService()
Description copied from interface: TrainingConfig
Gets the ExecutorService for parallelization.
Specified by:
getExecutorService in interface TrainingConfig
Returns:
an ExecutorService

public java.util.List<Evaluator> getEvaluators()
Returns the list of Evaluators that should be computed during training.
Specified by:
getEvaluators in interface TrainingConfig
Returns:
a list of Evaluators

public java.util.List<TrainingListener> getTrainingListeners()
Returns the list of TrainingListeners that should be used during training.
Specified by:
getTrainingListeners in interface TrainingConfig
Returns:
a list of TrainingListeners