Package ai.djl.training
Class DefaultTrainingConfig
java.lang.Object
    ai.djl.training.DefaultTrainingConfig
All Implemented Interfaces:
TrainingConfig
public class DefaultTrainingConfig extends java.lang.Object implements TrainingConfig
DefaultTrainingConfig is an implementation of the TrainingConfig interface.

Constructor Summary
Constructor | Description
DefaultTrainingConfig(Loss loss) | Creates an instance of DefaultTrainingConfig with the given Loss.

Method Summary
Modifier and Type | Method | Description
DefaultTrainingConfig | addEvaluator(Evaluator evaluator) | Adds an Evaluator that needs to be computed during training.
<T extends Evaluator> DefaultTrainingConfig | addEvaluators(java.util.Collection<T> evaluators) | Adds multiple Evaluators that need to be computed during training.
DefaultTrainingConfig | addTrainingListeners(TrainingListener... listeners) | Adds TrainingListeners for training.
Device[] | getDevices() | Gets the Devices that are available for computation.
java.util.List<Evaluator> | getEvaluators() | Returns the list of Evaluators that should be computed during training.
java.util.concurrent.ExecutorService | getExecutorService() | Gets the ExecutorService for parallelization.
ai.djl.util.PairList<Initializer,java.util.function.Predicate<Parameter>> | getInitializers() | Gets a list of Initializer and Predicate pairs used to initialize the parameters of the model.
Loss | getLossFunction() | Gets the Loss function used to compute the loss.
Optimizer | getOptimizer() | Gets the Optimizer to use during training.
java.util.List<TrainingListener> | getTrainingListeners() | Returns the list of TrainingListeners that should be used during training.
DefaultTrainingConfig | optDevices(Device[] devices) | Sets the array of Devices available for training.
DefaultTrainingConfig | optExecutorService() | Sets the ExecutorService to the global ForkJoinPool.commonPool().
DefaultTrainingConfig | optExecutorService(java.util.concurrent.ExecutorService executorService) | Sets the ExecutorService to train with multiple threads.
DefaultTrainingConfig | optInitializer(Initializer initializer, Parameter.Type type) | Sets the Initializer to use for parameters of the given Parameter.Type (default from paper).
DefaultTrainingConfig | optInitializer(Initializer initializer, java.lang.String name) | Sets the Initializer to use for the parameter with the given name (default from paper).
DefaultTrainingConfig | optInitializer(Initializer initializer, java.util.function.Predicate<Parameter> predicate) | Sets the Initializer to use for parameters matching the given predicate (default from paper).
DefaultTrainingConfig | optOptimizer(Optimizer optimizer) | Sets the Optimizer to use during training.

Constructor Detail
DefaultTrainingConfig
public DefaultTrainingConfig(Loss loss)
Creates an instance of DefaultTrainingConfig with the given Loss.
DefaultTrainingConfig creates a default TrainingConfig that uses Adam as the optimizer and the given Loss. Evaluators and listeners are left to the user's discretion.
Parameters:
loss - the loss to use for training
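Only the Loss is required. Everything else either has a default (Adam as the optimizer) or starts empty (evaluators, listeners), and every opt*/add* method returns this DefaultTrainingConfig so that calls can be chained. The following is a minimal sketch of that fluent pattern. It uses a hypothetical stand-in class, MiniTrainingConfig, with plain strings in place of the real Loss, Optimizer, Evaluator and TrainingListener types, so it runs with only the JDK. It is not DJL's implementation.

```java
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Objects;

// Hypothetical stand-in for DefaultTrainingConfig: strings replace the real
// Loss, Optimizer, Evaluator and TrainingListener types.
final class MiniTrainingConfig {
    private final String loss;                 // required constructor argument
    private String optimizer = "adam";         // default optimizer, as documented
    private final List<String> evaluators = new ArrayList<>();
    private final List<String> listeners = new ArrayList<>();

    MiniTrainingConfig(String loss) {
        this.loss = Objects.requireNonNull(loss, "loss");
    }

    // opt* methods override a default and return this for chaining.
    MiniTrainingConfig optOptimizer(String optimizer) {
        this.optimizer = optimizer;
        return this;
    }

    // add* methods append to lists that start out empty.
    MiniTrainingConfig addEvaluator(String evaluator) {
        evaluators.add(evaluator);
        return this;
    }

    MiniTrainingConfig addTrainingListeners(String... newListeners) {
        Collections.addAll(listeners, newListeners);
        return this;
    }

    String getLossFunction() { return loss; }
    String getOptimizer() { return optimizer; }
    // Read-only views: a choice made for this sketch only.
    List<String> getEvaluators() { return Collections.unmodifiableList(evaluators); }
    List<String> getTrainingListeners() { return Collections.unmodifiableList(listeners); }
}

class FluentConfigDemo {
    public static void main(String[] args) {
        MiniTrainingConfig defaults = new MiniTrainingConfig("softmaxCrossEntropy");
        System.out.println(defaults.getOptimizer());   // adam
        System.out.println(defaults.getEvaluators());  // []

        MiniTrainingConfig custom = new MiniTrainingConfig("softmaxCrossEntropy")
                .optOptimizer("sgd")
                .addEvaluator("accuracy")
                .addTrainingListeners("logging", "timing");
        // prints: sgd [accuracy] [logging, timing]
        System.out.println(custom.getOptimizer() + " " + custom.getEvaluators()
                + " " + custom.getTrainingListeners());
    }
}
```

With the real class, the same chain starts from new DefaultTrainingConfig(loss) with actual Loss, Optimizer and Evaluator instances.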

Method Detail

optInitializer
public DefaultTrainingConfig optInitializer(Initializer initializer, Parameter.Type type)
Sets the Initializer to use for parameters of the given Parameter.Type (default from paper).
Parameters:
initializer - the initializer to use for the parameters
type - the Parameter.Type of the parameters
Returns:
this DefaultTrainingConfig

optInitializer
public DefaultTrainingConfig optInitializer(Initializer initializer, java.lang.String name)
Sets the Initializer to use for the parameter with the given name (default from paper).
Parameters:
initializer - the initializer to use for the parameters
name - the name of the parameter
Returns:
this DefaultTrainingConfig

optInitializer
public DefaultTrainingConfig optInitializer(Initializer initializer, java.util.function.Predicate<Parameter> predicate)
Sets the Initializer to use for parameters matching the given predicate (default from paper).
Parameters:
initializer - the initializer to use for the parameters
predicate - the predicate that identifies the parameters
Returns:
this DefaultTrainingConfig
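All three optInitializer overloads feed the list that getInitializers() returns, which is a list of (Initializer, Predicate<Parameter>) pairs. A natural reading is that the type and name overloads are shorthands for particular predicates. The sketch below models that reading with hypothetical stand-ins: Param, ParamType and InitializerRegistry, with initializers represented only by their names. It is an assumption, not DJL's source. The sketch reports every initializer whose predicate matches a parameter. This page does not describe which initializer takes precedence when several match, so the sketch does not decide that either.

```java
import java.util.AbstractMap.SimpleImmutableEntry;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.function.Predicate;

// Hypothetical stand-ins for ai.djl.nn.Parameter and Parameter.Type.
enum ParamType { WEIGHT, BIAS }

final class Param {
    final String name;
    final ParamType type;
    Param(String name, ParamType type) { this.name = name; this.type = type; }
}

final class InitializerRegistry {
    // Each entry pairs an initializer (here just a name) with the predicate
    // that selects the parameters it applies to, similar in spirit to a PairList.
    private final List<Map.Entry<String, Predicate<Param>>> initializers = new ArrayList<>();

    // Most general form: any predicate over parameters.
    InitializerRegistry optInitializer(String initializer, Predicate<Param> predicate) {
        initializers.add(new SimpleImmutableEntry<>(initializer, predicate));
        return this;
    }

    // By type: assumed shorthand for a predicate comparing parameter types.
    InitializerRegistry optInitializer(String initializer, ParamType type) {
        return optInitializer(initializer, p -> p.type == type);
    }

    // By name: assumed shorthand for a predicate comparing parameter names.
    InitializerRegistry optInitializer(String initializer, String name) {
        return optInitializer(initializer, p -> p.name.equals(name));
    }

    // Names of all registered initializers whose predicate accepts the parameter.
    List<String> matching(Param param) {
        List<String> result = new ArrayList<>();
        for (Map.Entry<String, Predicate<Param>> entry : initializers) {
            if (entry.getValue().test(param)) {
                result.add(entry.getKey());
            }
        }
        return result;
    }
}

class InitializerDemo {
    public static void main(String[] args) {
        InitializerRegistry registry = new InitializerRegistry()
                .optInitializer("xavier", ParamType.WEIGHT)
                .optInitializer("zeros", ParamType.BIAS)
                .optInitializer("ones", "conv1_weight")
                .optInitializer("normal", p -> p.name.startsWith("fc"));

        System.out.println(registry.matching(new Param("conv1_weight", ParamType.WEIGHT))); // [xavier, ones]
        System.out.println(registry.matching(new Param("fc1_bias", ParamType.BIAS)));       // [zeros, normal]
    }
}
```

The predicate overload is the most flexible of the three. It can express rules, such as a name prefix, that neither the type nor the name overload covers.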

optDevices
public DefaultTrainingConfig optDevices(Device[] devices)
Sets the array of Devices available for training.
Parameters:
devices - an array of devices to be set
Returns:
this DefaultTrainingConfig

optOptimizer
public DefaultTrainingConfig optOptimizer(Optimizer optimizer)
Sets the Optimizer to use during training.
Parameters:
optimizer - the optimizer to be set
Returns:
this DefaultTrainingConfig

optExecutorService
public DefaultTrainingConfig optExecutorService()
Sets the ExecutorService to the global ForkJoinPool.commonPool().
Returns:
this DefaultTrainingConfig

optExecutorService
public DefaultTrainingConfig optExecutorService(java.util.concurrent.ExecutorService executorService)
Sets the ExecutorService to train with multiple threads.
Parameters:
executorService - the executor service
Returns:
this DefaultTrainingConfig
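The two optExecutorService overloads differ in who owns the thread pool. The no-argument form uses the JVM-wide ForkJoinPool.commonPool(). The one-argument form accepts any ExecutorService, for example a fixed-size pool that you create. The sketch below shows the difference using only java.util.concurrent. The helper parallelSquares is hypothetical and stands in for whatever work a trainer might dispatch to the executor. One practical point follows from the JDK itself: a pool you create must be shut down by you, whereas calling shutdown() on the common pool has no effect. This page does not say whether DefaultTrainingConfig shuts down an executor passed to it, so the sketch closes its own pool explicitly.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.ForkJoinPool;
import java.util.concurrent.Future;

class ExecutorChoiceDemo {
    // Hypothetical workload: square each input on the given executor and
    // return the results in input order.
    static List<Integer> parallelSquares(ExecutorService executor, List<Integer> inputs)
            throws InterruptedException, ExecutionException {
        List<Callable<Integer>> tasks = new ArrayList<>();
        for (int x : inputs) {
            tasks.add(() -> x * x);
        }
        List<Integer> results = new ArrayList<>();
        // invokeAll blocks until every task completes; futures keep task order.
        for (Future<Integer> f : executor.invokeAll(tasks)) {
            results.add(f.get());
        }
        return results;
    }

    public static void main(String[] args) throws Exception {
        List<Integer> inputs = Arrays.asList(1, 2, 3, 4);

        // Analogous to optExecutorService(): the shared common pool, never shut down here.
        ExecutorService common = ForkJoinPool.commonPool();
        System.out.println(parallelSquares(common, inputs)); // [1, 4, 9, 16]

        // Analogous to optExecutorService(executorService): a pool this code owns.
        ExecutorService fixed = Executors.newFixedThreadPool(2);
        try {
            System.out.println(parallelSquares(fixed, inputs)); // [1, 4, 9, 16]
        } finally {
            fixed.shutdown();
        }
    }
}
```

A dedicated pool lets you cap the number of training threads independently of other code in the JVM that also relies on the common pool.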

addEvaluators
public <T extends Evaluator> DefaultTrainingConfig addEvaluators(java.util.Collection<T> evaluators)
Adds multiple Evaluators that need to be computed during training.
Type Parameters:
T - the type of evaluator to be added
Parameters:
evaluators - the evaluators to be added
Returns:
this DefaultTrainingConfig

addEvaluator
public DefaultTrainingConfig addEvaluator(Evaluator evaluator)
Adds an Evaluator that needs to be computed during training.
Parameters:
evaluator - the evaluator to be added
Returns:
this DefaultTrainingConfig
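The addEvaluators signature declares a bounded type parameter, <T extends Evaluator>, on a Collection<T>. Because Java generics are invariant, a parameter typed plainly as Collection<Evaluator> would reject a List of one concrete evaluator subclass. The bounded type parameter accepts it. The sketch below isolates that typing point with hypothetical stand-ins (Evaluator, SimpleEvaluator, EvaluatorConfig) rather than the real ai.djl.training.evaluator classes.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.List;

// Hypothetical stand-in for the real Evaluator type.
interface Evaluator {
    String name();
}

// Hypothetical concrete evaluator used only for this illustration.
final class SimpleEvaluator implements Evaluator {
    private final String label;
    SimpleEvaluator(String label) { this.label = label; }
    public String name() { return label; }
}

final class EvaluatorConfig {
    private final List<Evaluator> evaluators = new ArrayList<>();

    // Same shape as the documented signature: accepts Collection<SimpleEvaluator>,
    // Collection<Evaluator>, or a collection of any other Evaluator subtype.
    <T extends Evaluator> EvaluatorConfig addEvaluators(Collection<T> toAdd) {
        evaluators.addAll(toAdd);
        return this;
    }

    EvaluatorConfig addEvaluator(Evaluator evaluator) {
        evaluators.add(evaluator);
        return this;
    }

    List<Evaluator> getEvaluators() { return evaluators; }

    // By contrast, a method declared as addEvaluatorsInvariant(Collection<Evaluator> c)
    // would not compile when called with a List<SimpleEvaluator>.
}

class EvaluatorDemo {
    public static void main(String[] args) {
        List<SimpleEvaluator> batch = Arrays.asList(new SimpleEvaluator("accuracy"), new SimpleEvaluator("loss"));
        EvaluatorConfig config = new EvaluatorConfig()
                .addEvaluators(batch)                        // List<SimpleEvaluator> is accepted
                .addEvaluator(new SimpleEvaluator("top5"));  // single evaluator
        for (Evaluator e : config.getEvaluators()) {
            System.out.println(e.name()); // accuracy, loss, top5 (one per line)
        }
    }
}
```

Declaring the parameter as Collection<? extends Evaluator> would accept the same arguments. The named type parameter T is simply the form this API uses.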

addTrainingListeners
public DefaultTrainingConfig addTrainingListeners(TrainingListener... listeners)
Adds TrainingListeners for training.
Parameters:
listeners - the TrainingListeners to add
Returns:
this DefaultTrainingConfig
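addTrainingListeners is declared with varargs (TrainingListener...). Callers can therefore pass listeners one at a time or hand over an existing array in a single call, for example an array built by a helper. The sketch below shows both call styles with hypothetical stand-ins. The TrainingListener interface here has an invented single method, epochFinished(int), that does not mirror DJL's real listener hooks. ListenerConfig and defaultListeners are illustrative names only.

```java
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;

// Hypothetical stand-in listener with one invented callback.
interface TrainingListener {
    void epochFinished(int epoch);
}

final class ListenerConfig {
    private final List<TrainingListener> listeners = new ArrayList<>();

    // Varargs: separate arguments are packed into an array by the compiler,
    // and an existing TrainingListener[] is passed through as-is.
    ListenerConfig addTrainingListeners(TrainingListener... toAdd) {
        Collections.addAll(listeners, toAdd);
        return this;
    }

    List<TrainingListener> getTrainingListeners() { return listeners; }

    // Notify listeners in registration order (a choice made for this sketch).
    void finishEpoch(int epoch) {
        for (TrainingListener l : listeners) {
            l.epochFinished(epoch);
        }
    }
}

class ListenerDemo {
    // Hypothetical helper returning a ready-made array of listeners.
    static TrainingListener[] defaultListeners(List<String> log) {
        return new TrainingListener[] {
            epoch -> log.add("logging epoch " + epoch),
            epoch -> log.add("timing epoch " + epoch)
        };
    }

    public static void main(String[] args) {
        List<String> log = new ArrayList<>();
        ListenerConfig config = new ListenerConfig()
                .addTrainingListeners(defaultListeners(log))                 // whole array at once
                .addTrainingListeners(epoch -> log.add("custom " + epoch));  // single listener
        config.finishEpoch(1);
        System.out.println(log); // [logging epoch 1, timing epoch 1, custom 1]
    }
}
```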

getDevices
public Device[] getDevices()
Gets the Devices that are available for computation.
A Trainer needs this to know what kind of devices it is running on and how many there are.
Specified by:
getDevices in interface TrainingConfig
Returns:
an array of Device
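The getDevices description notes that a Trainer needs to know how many devices it runs on. One common reason, assumed here rather than stated on this page, is data parallelism: each batch is split into one slice per device. The sketch below shows only that arithmetic, using a hypothetical splitSizes helper that spreads the remainder over the first slices so that slice sizes differ by at most one. It does not describe how DJL actually splits batches.

```java
import java.util.Arrays;

class BatchSplitDemo {
    // Hypothetical helper: per-device slice sizes for a batch, giving the
    // remainder to the first slices so sizes differ by at most one.
    static int[] splitSizes(int batchSize, int deviceCount) {
        if (deviceCount <= 0) {
            throw new IllegalArgumentException("need at least one device");
        }
        int[] sizes = new int[deviceCount];
        int base = batchSize / deviceCount;
        int remainder = batchSize % deviceCount;
        for (int i = 0; i < deviceCount; i++) {
            sizes[i] = base + (i < remainder ? 1 : 0);
        }
        return sizes;
    }

    public static void main(String[] args) {
        System.out.println(Arrays.toString(splitSizes(10, 3))); // [4, 3, 3]
        System.out.println(Arrays.toString(splitSizes(8, 1)));  // [8]
    }
}
```

When there are more devices than samples, some slices come out empty (for example, splitSizes(2, 4) gives [1, 1, 0, 0]). A real trainer would have to handle that case.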

getInitializers
public ai.djl.util.PairList<Initializer,java.util.function.Predicate<Parameter>> getInitializers()
Gets a list of Initializer and Predicate pairs used to initialize the parameters of the model.
Specified by:
getInitializers in interface TrainingConfig
Returns:
a PairList of Initializer and Predicate<Parameter> pairs

getOptimizer
public Optimizer getOptimizer()
Gets the Optimizer to use during training.
Specified by:
getOptimizer in interface TrainingConfig
Returns:
an Optimizer

getLossFunction
public Loss getLossFunction()
Gets the Loss function used to compute the loss.
Specified by:
getLossFunction in interface TrainingConfig
Returns:
a Loss function

getExecutorService
public java.util.concurrent.ExecutorService getExecutorService()
Gets the ExecutorService for parallelization.
Specified by:
getExecutorService in interface TrainingConfig
Returns:
an ExecutorService

getEvaluators
public java.util.List<Evaluator> getEvaluators()
Returns the list of Evaluators that should be computed during training.
Specified by:
getEvaluators in interface TrainingConfig
Returns:
a list of Evaluators

getTrainingListeners
public java.util.List<TrainingListener> getTrainingListeners()
Returns the list of TrainingListeners that should be used during training.
Specified by:
getTrainingListeners in interface TrainingConfig
Returns:
a list of TrainingListeners