public abstract class Optimizer
extends java.lang.Object
An Optimizer updates the weight parameters of a model to minimize the loss function. Optimizer is an abstract class that provides the base implementation for concrete optimizers such as Sgd, Nag, Adam, RmsProp, Adagrad, and Adadelta.

| Modifier and Type | Class and Description |
|---|---|
| static class | `Optimizer.OptimizerBuilder<T extends Optimizer.OptimizerBuilder>` The builder used to construct an Optimizer. |
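The generic bound `T extends Optimizer.OptimizerBuilder` suggests a self-typed builder: shared setters on the base builder return the concrete builder type, so chained calls keep access to subclass-specific options. The sketch below illustrates that pattern in plain Java under that assumption. Every name in it (`SketchOptimizer`, `SketchSgd`, `setRescaleGradient`, `optClipGrad`, `setLearningRate`, `self()`) is hypothetical and is not DJL's API.

```java
// Hypothetical names throughout; this is not DJL's OptimizerBuilder API.
class BuilderDemo {
    public static void main(String[] args) {
        // setRescaleGradient and optClipGrad are declared on the shared base builder,
        // but they return SketchSgd.Builder, so setLearningRate can still be chained.
        SketchSgd sgd = SketchSgd.builder()
                .setRescaleGradient(0.5f)
                .optClipGrad(1.0f)
                .setLearningRate(0.1f)
                .build();
        System.out.println(sgd.learningRate + " " + sgd.rescaleGrad + " " + sgd.clipGrad); // prints "0.1 0.5 1.0"
    }
}

abstract class SketchOptimizer {
    protected final float rescaleGrad;
    protected final float clipGrad;

    protected SketchOptimizer(SketchBuilder<?> builder) {
        this.rescaleGrad = builder.rescaleGrad;
        this.clipGrad = builder.clipGrad;
    }

    // Self-typed builder: T is the concrete builder subclass.
    abstract static class SketchBuilder<T extends SketchBuilder<T>> {
        float rescaleGrad = 1.0f;
        float clipGrad = -1.0f; // non-positive means "no clipping" in this sketch

        public T setRescaleGradient(float rescaleGrad) {
            this.rescaleGrad = rescaleGrad;
            return self();
        }

        public T optClipGrad(float clipGrad) {
            this.clipGrad = clipGrad;
            return self();
        }

        // Each concrete builder returns itself with its own static type.
        protected abstract T self();
    }
}

final class SketchSgd extends SketchOptimizer {
    final float learningRate;

    private SketchSgd(Builder builder) {
        super(builder);
        this.learningRate = builder.learningRate;
    }

    // Static factory returning a fresh builder, similar in spirit to Optimizer.sgd().
    static Builder builder() {
        return new Builder();
    }

    static final class Builder extends SketchOptimizer.SketchBuilder<Builder> {
        private float learningRate = 0.01f;

        Builder setLearningRate(float learningRate) {
            this.learningRate = learningRate;
            return this;
        }

        @Override
        protected Builder self() {
            return this;
        }

        SketchSgd build() {
            return new SketchSgd(this);
        }
    }
}
```

The payoff of the self type is in `main`: without it, `setRescaleGradient` would return the base builder type and the call to `setLearningRate` would not compile.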
| Modifier and Type | Field and Description |
|---|---|
| protected float | `clipGrad` |
| protected float | `rescaleGrad` |
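This page lists `clipGrad` and `rescaleGrad` without descriptions. Assuming they follow the convention common in MXNet-style optimizers (multiply the gradient by `rescaleGrad`, then clip each element to `[-clipGrad, clipGrad]` when `clipGrad` is positive), a gradient would be preprocessed roughly as in this sketch. `GradPreprocess` and `rescaleAndClip` are hypothetical names, and `float[]` stands in for `NDArray`.

```java
import java.util.Arrays;

// Hypothetical sketch of gradient preprocessing; the rescale-then-clip order is an
// assumption based on common practice, not taken from this page.
final class GradPreprocess {
    private GradPreprocess() {}

    /** Returns a new array: grad * rescaleGrad, clipped to [-clipGrad, clipGrad] if clipGrad > 0. */
    static float[] rescaleAndClip(float[] grad, float rescaleGrad, float clipGrad) {
        float[] out = new float[grad.length];
        for (int i = 0; i < grad.length; i++) {
            float g = grad[i] * rescaleGrad; // e.g. rescaleGrad = 1 / batchSize
            if (clipGrad > 0) {
                g = Math.max(-clipGrad, Math.min(clipGrad, g)); // element-wise clipping
            }
            out[i] = g;
        }
        return out;
    }

    public static void main(String[] args) {
        float[] grad = {4.0f, -0.5f, 10.0f};
        float[] g = rescaleAndClip(grad, 0.5f, 1.0f);
        System.out.println(Arrays.toString(g)); // prints "[1.0, -0.25, 1.0]"
    }
}
```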
| Constructor and Description |
|---|
| `Optimizer(Optimizer.OptimizerBuilder<?> builder)` Creates a new instance of Optimizer. |
| Modifier and Type | Method and Description |
|---|---|
| static Adadelta.Builder | `adadelta()` Returns a new instance of Adadelta.Builder that can build an Adadelta optimizer. |
| static Adagrad.Builder | `adagrad()` Returns a new instance of Adagrad.Builder that can build an Adagrad optimizer. |
| static Adam.Builder | `adam()` Returns a new instance of Adam.Builder that can build an Adam optimizer. |
| protected float | `getWeightDecay()` Gets the value of weight decay. |
| static Nag.Builder | `nag()` Returns a new instance of Nag.Builder that can build a Nag optimizer. |
| static RmsProp.Builder | `rmsprop()` Returns a new instance of RmsProp.Builder that can build an RmsProp optimizer. |
| static Sgd.Builder | `sgd()` Returns a new instance of Sgd.Builder that can build an Sgd optimizer. |
| abstract void | `update(java.lang.String parameterId, NDArray weight, NDArray grad)` Updates the parameters according to the gradients. |
| protected int | `updateCount(java.lang.String parameterId)` |
| protected NDArray | `withDefaultState(java.util.Map<java.lang.String,java.util.Map<Device,NDArray>> state, java.lang.String key, Device device, java.util.function.Function<java.lang.String,NDArray> defaultFunction)` |
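`withDefaultState` has no description on this page. Its signature (state keyed first by a `String`, then by `Device`, plus a `Function<String, NDArray>` fallback) suggests lazily created optimizer state, such as a momentum buffer that exists once per parameter and per device. The sketch below shows that reading with `Map.computeIfAbsent`. It substitutes `String` for `Device` and `float[]` for `NDArray`, and `StateSketch` is a hypothetical name, so treat it as an illustration rather than DJL's actual behavior.

```java
import java.util.HashMap;
import java.util.Map;
import java.util.function.Function;

// Hypothetical sketch of per-key, per-device lazily initialized state.
// String stands in for Device and float[] for NDArray.
final class StateSketch {
    private StateSketch() {}

    static float[] withDefaultState(
            Map<String, Map<String, float[]>> state,
            String key,
            String device,
            Function<String, float[]> defaultFunction) {
        // Create the inner per-device map the first time this key is seen,
        // then create the state for this device on its first access only.
        return state.computeIfAbsent(key, k -> new HashMap<>())
                .computeIfAbsent(device, d -> defaultFunction.apply(key));
    }

    public static void main(String[] args) {
        Map<String, Map<String, float[]>> momentum = new HashMap<>();
        float[] first = withDefaultState(momentum, "fc1_weight", "cpu()", k -> new float[3]);
        first[0] = 0.5f; // mutating the returned array mutates the stored state
        float[] second = withDefaultState(momentum, "fc1_weight", "cpu()", k -> new float[3]);
        System.out.println(first == second); // prints "true": the default was created once
    }
}
```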
public Optimizer(Optimizer.OptimizerBuilder<?> builder)
Creates a new instance of Optimizer.
Parameters:
builder - the builder used to create an instance of Optimizer

public static Sgd.Builder sgd()
Returns a new instance of Sgd.Builder that can build an Sgd optimizer.
Returns: the Sgd.Builder

public static Nag.Builder nag()
Returns a new instance of Nag.Builder that can build a Nag optimizer.
Returns: the Nag.Builder

public static Adam.Builder adam()
Returns a new instance of Adam.Builder that can build an Adam optimizer.
Returns: the Adam.Builder

public static RmsProp.Builder rmsprop()
Returns a new instance of RmsProp.Builder that can build an RmsProp optimizer.
Returns: the RmsProp.Builder

public static Adagrad.Builder adagrad()
Returns a new instance of Adagrad.Builder that can build an Adagrad optimizer.
Returns: the Adagrad.Builder

public static Adadelta.Builder adadelta()
Returns a new instance of Adadelta.Builder that can build an Adadelta optimizer.
Returns: the Adadelta.Builder

protected float getWeightDecay()
Gets the value of weight decay.

protected int updateCount(java.lang.String parameterId)
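`updateCount` is also undocumented here. A plausible role, assumed for this sketch, is tracking how many times each parameter has been updated; optimizers such as Adam need that step number for bias correction, and different parameters can be updated independently. `UpdateCounter` is a hypothetical per-parameter counter built on `Map.merge`, not DJL's implementation.

```java
import java.util.HashMap;
import java.util.Map;

// Hypothetical per-parameter update counter (not DJL's implementation).
final class UpdateCounter {
    private final Map<String, Integer> counts = new HashMap<>();

    /** Increments and returns the number of updates seen for parameterId. */
    int updateCount(String parameterId) {
        // merge inserts 1 for a new key, otherwise adds 1 to the existing count
        return counts.merge(parameterId, 1, Integer::sum);
    }

    public static void main(String[] args) {
        UpdateCounter counter = new UpdateCounter();
        counter.updateCount("conv0_weight");
        counter.updateCount("conv0_weight");
        System.out.println(counter.updateCount("conv0_weight")); // prints "3"
        System.out.println(counter.updateCount("fc1_bias"));     // prints "1"
    }
}
```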

public abstract void update(java.lang.String parameterId, NDArray weight, NDArray grad)
Updates the parameters according to the gradients.
Parameters:
parameterId - the parameter to be updated
weight - the weights of the parameter
grad - the gradients
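Concrete optimizers implement `update`, modifying `weight` in place based on `grad`. The sketch below mirrors that contract with `float[]` in place of `NDArray` and uses plain SGD with L2 weight decay, `w <- w - lr * (g + wd * w)`, as one common formulation. It is not necessarily what DJL's `Sgd` computes, and `MiniOptimizer` and `MiniSgd` are hypothetical names.

```java
import java.util.Arrays;

// Hypothetical SGD with L2 weight decay implementing an update(parameterId, weight, grad)
// contract; float[] stands in for NDArray.
final class MiniSgd extends MiniOptimizer {
    private final float learningRate;

    MiniSgd(float learningRate, float weightDecay) {
        super(weightDecay);
        this.learningRate = learningRate;
    }

    @Override
    void update(String parameterId, float[] weight, float[] grad) {
        if (weight.length != grad.length) {
            throw new IllegalArgumentException("shape mismatch for " + parameterId);
        }
        float wd = getWeightDecay();
        for (int i = 0; i < weight.length; i++) {
            // w <- w - lr * (g + wd * w), applied in place
            weight[i] -= learningRate * (grad[i] + wd * weight[i]);
        }
    }

    public static void main(String[] args) {
        float[] w = {1.0f, -2.0f};
        float[] g = {0.5f, 0.5f};
        new MiniSgd(0.1f, 0.0f).update("dense_weight", w, g);
        System.out.println(Arrays.toString(w)); // approximately [0.95, -2.05]
    }
}

abstract class MiniOptimizer {
    private final float weightDecay;

    protected MiniOptimizer(float weightDecay) {
        this.weightDecay = weightDecay;
    }

    protected float getWeightDecay() {
        return weightDecay;
    }

    /** Updates weight in place according to grad. */
    abstract void update(String parameterId, float[] weight, float[] grad);
}
```

Keeping the rule inside `update` and the hyperparameters in the base class lets each subclass swap the update formula while sharing settings such as weight decay.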