Package ai.djl.training.optimizer
Class Optimizer
- java.lang.Object
  - ai.djl.training.optimizer.Optimizer

public abstract class Optimizer extends java.lang.Object

An Optimizer updates the weight parameters to minimize the loss function. Optimizer is an abstract class that provides the base implementation for optimizers.
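The abstract-class-plus-builder design described above can be sketched in miniature. The following self-contained reduction is illustrative only (the `MiniOptimizer`, `MiniSgd`, and method names are invented for the sketch and are not DJL's real implementation): a base optimizer takes its configuration from a self-typed builder, and a concrete SGD subclass supplies `update`.

```java
// Illustrative sketch of the Optimizer/OptimizerBuilder pattern, not DJL's code.
abstract class MiniOptimizer {
    protected final float rescaleGrad; // factor applied to every gradient

    protected MiniOptimizer(Builder<?> builder) {
        this.rescaleGrad = builder.rescaleGrad;
    }

    // Mirrors Optimizer.update(parameterId, weight, grad), on plain arrays.
    abstract void update(String parameterId, float[] weight, float[] grad);

    // Self-typed builder, mirroring OptimizerBuilder<T extends OptimizerBuilder>.
    abstract static class Builder<T extends Builder<T>> {
        float rescaleGrad = 1f;

        @SuppressWarnings("unchecked")
        T optRescaleGrad(float rescaleGrad) {
            this.rescaleGrad = rescaleGrad;
            return (T) this; // return the concrete builder type so chaining keeps working
        }

        abstract MiniOptimizer build();
    }
}

class MiniSgd extends MiniOptimizer {
    private final float learningRate;

    MiniSgd(Builder builder) {
        super(builder);
        this.learningRate = builder.learningRate;
    }

    @Override
    void update(String parameterId, float[] weight, float[] grad) {
        for (int i = 0; i < weight.length; i++) {
            weight[i] -= learningRate * rescaleGrad * grad[i]; // w <- w - lr * g
        }
    }

    static Builder builder() {
        return new Builder();
    }

    static class Builder extends MiniOptimizer.Builder<Builder> {
        float learningRate = 0.01f;

        Builder setLearningRate(float learningRate) {
            this.learningRate = learningRate;
            return this;
        }

        @Override
        MiniSgd build() {
            return new MiniSgd(this);
        }
    }
}
```

The self-referential type parameter on the builder is what lets subclass builders chain base-class setters such as `optRescaleGrad` without losing their concrete type, which is the same reason `OptimizerBuilder` is declared as `OptimizerBuilder<T extends OptimizerBuilder>`.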
-
-
Nested Class Summary
- static class Optimizer.OptimizerBuilder<T extends Optimizer.OptimizerBuilder>
  The Builder to construct an Optimizer.
-
Field Summary
Fields Modifier and Type Field Description protected float
clipGrad
protected float
rescaleGrad
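These two protected fields configure gradient preprocessing. A common convention for such fields, sketched below on plain arrays, is to multiply each gradient by `rescaleGrad` (for example `1f / batchSize`) and then, when `clipGrad > 0`, clamp each element into `[-clipGrad, clipGrad]`. This is an assumption about intent for illustration, not DJL's exact implementation.

```java
// Illustrative helper: applies rescaleGrad, then element-wise clipping.
// A sketch of the usual convention; not DJL's actual code.
class GradPrep {
    static float[] preprocess(float[] grad, float rescaleGrad, float clipGrad) {
        float[] out = new float[grad.length];
        for (int i = 0; i < grad.length; i++) {
            float g = grad[i] * rescaleGrad;      // e.g. rescaleGrad = 1f / batchSize
            if (clipGrad > 0) {                   // clipGrad <= 0 disables clipping
                g = Math.max(-clipGrad, Math.min(clipGrad, g));
            }
            out[i] = g;
        }
        return out;
    }
}
```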
-
Constructor Summary
Constructors Constructor Description Optimizer(Optimizer.OptimizerBuilder<?> builder)
Creates a new instance ofOptimizer
.
-
Method Summary
- static Adadelta.Builder adadelta()
  Returns a new instance of Adadelta.Builder that can build an Adadelta optimizer.
- static Adagrad.Builder adagrad()
  Returns a new instance of Adagrad.Builder that can build an Adagrad optimizer.
- static Adam.Builder adam()
  Returns a new instance of Adam.Builder that can build an Adam optimizer.
- static AdamW.Builder adamW()
  Returns a new instance of AdamW.Builder that can build an AdamW optimizer.
- protected float getWeightDecay()
  Gets the value of weight decay.
- static Nag.Builder nag()
  Returns a new instance of Nag.Builder that can build a Nag optimizer.
- static RmsProp.Builder rmsprop()
  Returns a new instance of RmsProp.Builder that can build an RmsProp optimizer.
- static Sgd.Builder sgd()
  Returns a new instance of Sgd.Builder that can build an Sgd optimizer.
- abstract void update(java.lang.String parameterId, NDArray weight, NDArray grad)
  Updates the parameters according to the gradients.
- protected int updateCount(java.lang.String parameterId)
- protected NDArray withDefaultState(java.util.Map<java.lang.String,java.util.Map<Device,NDArray>> state, java.lang.String key, Device device, java.util.function.Function<java.lang.String,NDArray> defaultFunction)
-
-
-
Constructor Detail
-
Optimizer
public Optimizer(Optimizer.OptimizerBuilder<?> builder)
Creates a new instance of Optimizer.
- Parameters:
  builder - the builder used to create an instance of Optimizer
-
-
Method Detail
-
sgd
public static Sgd.Builder sgd()
Returns a new instance of Sgd.Builder that can build an Sgd optimizer.
- Returns:
  the Sgd.Builder
-
nag
public static Nag.Builder nag()
Returns a new instance of Nag.Builder that can build a Nag optimizer.
- Returns:
  the Nag.Builder
-
adam
public static Adam.Builder adam()
Returns a new instance of Adam.Builder that can build an Adam optimizer.
- Returns:
  the Adam.Builder
-
adamW
public static AdamW.Builder adamW()
Returns a new instance of AdamW.Builder that can build an AdamW optimizer.
- Returns:
  the AdamW.Builder
-
rmsprop
public static RmsProp.Builder rmsprop()
Returns a new instance of RmsProp.Builder that can build an RmsProp optimizer.
- Returns:
  the RmsProp.Builder
-
adagrad
public static Adagrad.Builder adagrad()
Returns a new instance of Adagrad.Builder that can build an Adagrad optimizer.
- Returns:
  the Adagrad.Builder
-
adadelta
public static Adadelta.Builder adadelta()
Returns a new instance of Adadelta.Builder that can build an Adadelta optimizer.
- Returns:
  the Adadelta.Builder
-
getWeightDecay
protected float getWeightDecay()
Gets the value of weight decay.
- Returns:
  the value of weight decay
-
updateCount
protected int updateCount(java.lang.String parameterId)
-
update
public abstract void update(java.lang.String parameterId, NDArray weight, NDArray grad)
Updates the parameters according to the gradients.
- Parameters:
  parameterId - the parameter to be updated
  weight - the weights of the parameter
  grad - the gradients
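Concrete subclasses implement this per-parameter rule. As a point of reference, a plain SGD-style rule with the same shape, folding in the class's rescaleGrad, clipGrad, and weight-decay hooks, might look like the sketch below on float arrays. This is illustrative only: DJL's actual Sgd operates on NDArray and also supports momentum, and the ordering of clipping versus weight decay here is an assumption.

```java
// Illustrative stand-in for a concrete update(parameterId, weight, grad):
// plain SGD with rescaling, clipping, and L2 weight decay. Not DJL's real code.
class SgdRule {
    static void update(float[] weight, float[] grad,
                       float lr, float rescaleGrad, float clipGrad, float weightDecay) {
        for (int i = 0; i < weight.length; i++) {
            float g = grad[i] * rescaleGrad;
            if (clipGrad > 0) {                 // clipGrad <= 0 disables clipping
                g = Math.max(-clipGrad, Math.min(clipGrad, g));
            }
            g += weightDecay * weight[i];       // L2 weight decay (see getWeightDecay())
            weight[i] -= lr * g;                // w <- w - lr * g
        }
    }
}
```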
-
-