Package ai.djl.training
Class DefaultTrainingConfig
java.lang.Object
ai.djl.training.DefaultTrainingConfig

All Implemented Interfaces:
- TrainingConfig

DefaultTrainingConfig is an implementation of the TrainingConfig interface.

Constructor Summary
- DefaultTrainingConfig(Loss loss): Creates an instance of DefaultTrainingConfig with the given Loss.

Method Summary
- DefaultTrainingConfig addEvaluator(Evaluator evaluator): Adds an Evaluator that needs to be computed during training.
- <T extends Evaluator> DefaultTrainingConfig addEvaluators(Collection<T> evaluators): Adds multiple Evaluators that need to be computed during training.
- DefaultTrainingConfig addTrainingListeners(TrainingListener... listeners): Adds TrainingListeners for training.
- Device[] getDevices(): Gets the Devices that are available for computation.
- List<Evaluator> getEvaluators(): Returns the list of Evaluators that should be computed during training.
- ExecutorService getExecutorService(): Gets the ExecutorService for parallelization.
- ai.djl.util.PairList<Initializer, Predicate<Parameter>> getInitializers(): Gets a list of Initializer and Predicate pairs used to initialize the parameters of the model.
- Loss getLossFunction(): Gets the Loss function used to compute the loss.
- Optimizer getOptimizer(): Gets the Optimizer to use during training.
- List<TrainingListener> getTrainingListeners(): Returns the list of TrainingListeners that should be used during training.
- DefaultTrainingConfig optDevices(Device[] devices): Sets the array of Devices available for training.
- DefaultTrainingConfig optExecutorService(): Sets the ExecutorService to the global ForkJoinPool.commonPool().
- DefaultTrainingConfig optExecutorService(ExecutorService executorService): Sets the ExecutorService to train with multiple threads.
- DefaultTrainingConfig optInitializer(Initializer initializer, Parameter.Type type): Sets the Initializer to use for the parameters of the given type (default from paper).
- DefaultTrainingConfig optInitializer(Initializer initializer, String name): Sets the Initializer to use for the parameter with the given name (default from paper).
- DefaultTrainingConfig optInitializer(Initializer initializer, Predicate<Parameter> predicate): Sets the Initializer to use for the parameters matching the given predicate (default from paper).
- DefaultTrainingConfig optOptimizer(Optimizer optimizer): Sets the Optimizer to use during training.

Constructor Details

DefaultTrainingConfig
Creates an instance of DefaultTrainingConfig with the given Loss.

DefaultTrainingConfig creates a default TrainingConfig that uses Adam as the optimizer together with the given Loss. Evaluators and listeners are left to the user's discretion.
Parameters:
- loss: the loss to use for training

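The constructor fixes the loss, defaults the optimizer to Adam, and leaves evaluators and listeners empty; the opt* and add* methods documented below each return this DefaultTrainingConfig, so calls can be chained. The following stdlib-only Java sketch models that pattern. SketchConfig, SketchConfigDemo, and the use of plain strings for the loss, optimizer, and evaluators are hypothetical stand-ins chosen so the example runs without DJL; they are not DJL classes.

```java
import java.util.ArrayList;
import java.util.List;

// Hypothetical stand-in for DefaultTrainingConfig: strings replace DJL's
// Loss, Optimizer and Evaluator types so the sketch runs without DJL.
class SketchConfig {
    private final String loss;
    private String optimizer = "adam"; // default mirrors "Adam as optimizer"
    private final List<String> evaluators = new ArrayList<>(); // empty until the user adds some

    SketchConfig(String loss) {
        this.loss = loss;
    }

    // Like optOptimizer: overrides the default and returns this for chaining.
    SketchConfig optOptimizer(String optimizer) {
        this.optimizer = optimizer;
        return this;
    }

    // Like addEvaluator: appends one evaluator and returns this for chaining.
    SketchConfig addEvaluator(String evaluator) {
        evaluators.add(evaluator);
        return this;
    }

    String getLossFunction() { return loss; }
    String getOptimizer() { return optimizer; }
    List<String> getEvaluators() { return evaluators; }
}

class SketchConfigDemo {
    public static void main(String[] args) {
        SketchConfig config = new SketchConfig("softmaxCrossEntropy")
                .addEvaluator("accuracy");
        // prints: adam [accuracy]
        System.out.println(config.getOptimizer() + " " + config.getEvaluators());
    }
}
```

Returning this from every setter is what makes a single chained expression possible; nothing has to be reassigned between calls.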
Method Details

optInitializer
Sets the Initializer to use for all parameters of the given Parameter.Type (default from paper).
Parameters:
- initializer: the initializer to use for the parameters
- type: the Parameter.Type of the parameters
Returns:
- this DefaultTrainingConfig

optInitializer
Sets the Initializer to use for the parameter with the given name (default from paper).
Parameters:
- initializer: the initializer to use for the parameters
- name: the name of the parameter
Returns:
- this DefaultTrainingConfig

optInitializer
public DefaultTrainingConfig optInitializer(Initializer initializer, Predicate<Parameter> predicate)
Sets the Initializer to use for the parameters that match the given predicate (default from paper).
Parameters:
- initializer: the initializer to use for the parameters
- predicate: the predicate that identifies the parameters to initialize
Returns:
- this DefaultTrainingConfig
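The three optInitializer overloads differ only in how they select parameters: by Parameter.Type, by name, or by an arbitrary Predicate<Parameter>. A natural reading, assumed here rather than taken from the DJL source, is that the type and name overloads are shorthands for particular predicates, and that each (initializer, predicate) pair is stored in the list getInitializers returns. The sketch below models that with hypothetical ParamType, Param, and InitializerSketch classes, and with initializer names as strings in place of DJL's Initializer objects.

```java
import java.util.AbstractMap.SimpleEntry;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.function.Predicate;

// Hypothetical stand-ins for DJL's Parameter.Type and Parameter.
enum ParamType { WEIGHT, BIAS }

final class Param {
    final String name;
    final ParamType type;

    Param(String name, ParamType type) {
        this.name = name;
        this.type = type;
    }
}

// Models the list of (initializer, predicate) pairs behind getInitializers.
class InitializerSketch {
    private final List<Map.Entry<String, Predicate<Param>>> initializers = new ArrayList<>();

    // General form: the caller supplies the selection predicate.
    InitializerSketch optInitializer(String initializer, Predicate<Param> predicate) {
        initializers.add(new SimpleEntry<>(initializer, predicate));
        return this;
    }

    // Assumed shorthand: select every parameter of the given type.
    InitializerSketch optInitializer(String initializer, ParamType type) {
        return optInitializer(initializer, p -> p.type == type);
    }

    // Assumed shorthand: select the parameter with the given name.
    InitializerSketch optInitializer(String initializer, String name) {
        return optInitializer(initializer, p -> p.name.equals(name));
    }

    // Names of every initializer whose predicate accepts the parameter, in insertion order.
    List<String> matching(Param param) {
        List<String> result = new ArrayList<>();
        for (Map.Entry<String, Predicate<Param>> pair : initializers) {
            if (pair.getValue().test(param)) {
                result.add(pair.getKey());
            }
        }
        return result;
    }
}

class InitializerSketchDemo {
    public static void main(String[] args) {
        InitializerSketch config = new InitializerSketch()
                .optInitializer("xavier", ParamType.WEIGHT)
                .optInitializer("zeros", "fc_bias");
        // prints: [zeros]
        System.out.println(config.matching(new Param("fc_bias", ParamType.BIAS)));
    }
}
```

A parameter can satisfy more than one pair, for example a weight whose name also matches a custom predicate. This page does not say how DJL resolves such overlaps, so the sketch only reports the matches in insertion order.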

optDevices
Sets the array of Devices available for training.
Parameters:
- devices: an array of devices to be set
Returns:
- this DefaultTrainingConfig

optOptimizer
Sets the Optimizer to use during training.
Parameters:
- optimizer: the optimizer to be set
Returns:
- this DefaultTrainingConfig

optExecutorService
Sets the ExecutorService to the global ForkJoinPool.commonPool().
Returns:
- this DefaultTrainingConfig

optExecutorService
Sets the ExecutorService to train with multiple threads.
Parameters:
- executorService: the executor service
Returns:
- this DefaultTrainingConfig
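The no-argument optExecutorService uses the shared ForkJoinPool.commonPool(), while the one-argument overload accepts any ExecutorService, for example a fixed-size pool that bounds how many threads the work may use. The stdlib-only sketch below shows the two choices with a hypothetical sumSquaresInParallel helper standing in for parallel training work; it illustrates the JDK executor types only and is not how DJL's Trainer actually dispatches tasks.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.ForkJoinPool;
import java.util.concurrent.Future;

class ExecutorSketch {
    // Hypothetical helper: squares each value as a separate task on the given
    // executor, then waits for every task and sums the results.
    static long sumSquaresInParallel(ExecutorService executor, int[] values) {
        List<Future<Long>> futures = new ArrayList<>();
        for (int v : values) {
            futures.add(executor.submit(() -> (long) v * v));
        }
        long total = 0;
        try {
            for (Future<Long> f : futures) {
                total += f.get();
            }
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt(); // preserve the interrupt status
            throw new IllegalStateException("interrupted while waiting for tasks", e);
        } catch (ExecutionException e) {
            throw new IllegalStateException(e.getCause());
        }
        return total;
    }

    public static void main(String[] args) {
        int[] values = {1, 2, 3, 4};

        // Counterpart of the no-argument overload: the JVM-wide common pool.
        // It is shared and is not affected by shutdown(), so it is left alone here.
        System.out.println(sumSquaresInParallel(ForkJoinPool.commonPool(), values));

        // Counterpart of passing your own executor: a dedicated pool that the
        // caller owns and must shut down when the work is finished.
        ExecutorService pool = Executors.newFixedThreadPool(2);
        try {
            System.out.println(sumSquaresInParallel(pool, values));
        } finally {
            pool.shutdown();
        }
    }
}
```

The trade-off is ownership: the common pool needs no lifecycle management but is shared with everything else in the JVM, whereas a dedicated pool isolates the work at the cost of shutting it down yourself.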

addEvaluators
Adds multiple Evaluators that need to be computed during training.
Type Parameters:
- T: the type of evaluator to be added
Parameters:
- evaluators: the evaluators to be added
Returns:
- this DefaultTrainingConfig

addEvaluator
Adds an Evaluator that needs to be computed during training.
Parameters:
- evaluator: the evaluator to be added
Returns:
- this DefaultTrainingConfig
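addEvaluators is declared with a type parameter, <T extends Evaluator>, and takes a Collection<T> rather than a Collection<Evaluator>. Because Java generics are invariant, that bound is what lets a caller pass a list of one specific evaluator subtype directly, while addEvaluator covers the single-element case. The sketch below demonstrates the rule with hypothetical Metric, AccuracyMetric, and MetricRegistry classes standing in for Evaluator, a concrete evaluator, and the config; it is not DJL code.

```java
import java.util.ArrayList;
import java.util.Collection;
import java.util.List;

// Hypothetical stand-ins for Evaluator and one concrete subtype.
abstract class Metric {
    abstract String name();
}

class AccuracyMetric extends Metric {
    @Override
    String name() { return "accuracy"; }
}

class MetricRegistry {
    private final List<Metric> metrics = new ArrayList<>();

    // Same shape as addEvaluators: the bound accepts Collection<AccuracyMetric>,
    // which a parameter declared as Collection<Metric> would reject at compile time.
    <T extends Metric> MetricRegistry addMetrics(Collection<T> toAdd) {
        metrics.addAll(toAdd);
        return this;
    }

    // Same shape as addEvaluator: adds a single element.
    MetricRegistry addMetric(Metric metric) {
        metrics.add(metric);
        return this;
    }

    int size() { return metrics.size(); }
}

class MetricRegistryDemo {
    public static void main(String[] args) {
        List<AccuracyMetric> accuracies = List.of(new AccuracyMetric(), new AccuracyMetric());
        MetricRegistry registry = new MetricRegistry().addMetrics(accuracies);
        System.out.println(registry.size()); // 2
    }
}
```

A wildcard parameter, Collection<? extends Metric>, would accept the same arguments; the named type parameter is an equivalent way to express the bound.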

addTrainingListeners
Adds TrainingListeners for training.
Parameters:
- listeners: the TrainingListeners to add
Returns:
- this DefaultTrainingConfig
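addTrainingListeners takes a varargs parameter (TrainingListener... listeners), so one call can register any number of listeners, which getTrainingListeners later returns as a list. The stdlib-only sketch below models that with a hypothetical Listener interface and a single per-epoch callback chosen for brevity; DJL's real TrainingListener interface has a richer set of callbacks, and ListenerConfig is not a DJL class.

```java
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;

// Hypothetical stand-in for TrainingListener, reduced to one callback.
interface Listener {
    void onEpoch(int epoch);
}

class ListenerConfig {
    private final List<Listener> listeners = new ArrayList<>();

    // Varargs like addTrainingListeners: zero or more listeners per call.
    ListenerConfig addListeners(Listener... toAdd) {
        Collections.addAll(listeners, toAdd);
        return this;
    }

    // Read-only view, in the order the listeners were added.
    List<Listener> getListeners() {
        return Collections.unmodifiableList(listeners);
    }
}

class ListenerDemo {
    public static void main(String[] args) {
        List<String> log = new ArrayList<>();
        ListenerConfig config = new ListenerConfig()
                .addListeners(e -> log.add("logger saw epoch " + e),
                              e -> log.add("checkpoint after epoch " + e));
        // A hypothetical training loop that notifies every registered listener.
        for (int epoch = 1; epoch <= 2; epoch++) {
            for (Listener l : config.getListeners()) {
                l.onEpoch(epoch);
            }
        }
        System.out.println(log);
    }
}
```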

getDevices
Gets the Devices that are available for computation.

This is necessary for a Trainer, as it needs to know what kind of device it is running on and how many devices it is running on.
Specified by:
- getDevices in interface TrainingConfig
Returns:
- an array of Device

getInitializers
Gets a list of Initializer and Predicate pairs used to initialize the parameters of the model.
Specified by:
- getInitializers in interface TrainingConfig
Returns:
- a PairList of Initializer and Predicate<Parameter> pairs

getOptimizer
Gets the Optimizer to use during training.
Specified by:
- getOptimizer in interface TrainingConfig
Returns:
- an Optimizer

getLossFunction
Gets the Loss function used to compute the loss.
Specified by:
- getLossFunction in interface TrainingConfig
Returns:
- a Loss function

getExecutorService
Gets the ExecutorService for parallelization.
Specified by:
- getExecutorService in interface TrainingConfig
Returns:
- an ExecutorService

getEvaluators
Returns the list of Evaluators that should be computed during training.
Specified by:
- getEvaluators in interface TrainingConfig
Returns:
- a list of Evaluators

getTrainingListeners
Returns the list of TrainingListeners that should be used during training.
Specified by:
- getTrainingListeners in interface TrainingConfig
Returns:
- a list of TrainingListeners