public abstract class Optimizer
extends java.lang.Object

An Optimizer updates the weight parameters to minimize the loss function. Optimizer is an abstract class that provides the base implementation for optimizers.

Modifier and Type | Class | Description |
---|---|---|
static class | Optimizer.OptimizerBuilder<T extends Optimizer.OptimizerBuilder> | The Builder to construct an Optimizer. |
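The constructor takes an `Optimizer.OptimizerBuilder<?>`, and each static factory returns a subclass-specific builder. Below is a minimal plain-Java sketch of that self-typed builder pattern. It is hypothetical and not DJL's source. The names `SketchOptimizer`, `PlainSgd`, the setters `rescaleGrad(...)` and `learningRate(...)`, and the use of `float[]` in place of `NDArray` are all assumptions, chosen so that the example compiles on its own.

```java
/** Hypothetical mirror of Optimizer (not DJL code); float[] stands in for NDArray. */
abstract class SketchOptimizer {

    protected final float rescaleGrad;

    /** Like Optimizer(OptimizerBuilder<?> builder): copy shared settings from the builder. */
    protected SketchOptimizer(OptimizerBuilder<?> builder) {
        this.rescaleGrad = builder.rescaleGrad;
    }

    /** Each concrete optimizer supplies its own update rule. */
    public abstract void update(String parameterId, float[] weight, float[] grad);

    /** Self-typed builder: shared setters return the concrete builder type T. */
    abstract static class OptimizerBuilder<T extends OptimizerBuilder<T>> {

        private float rescaleGrad = 1.0f;

        public T rescaleGrad(float value) {
            this.rescaleGrad = value;
            return self();
        }

        protected abstract T self();
    }
}

/** Hypothetical concrete optimizer: gradient descent with a fixed learning rate. */
final class PlainSgd extends SketchOptimizer {

    private final float learningRate;

    private PlainSgd(Builder builder) {
        super(builder);
        this.learningRate = builder.learningRate;
    }

    /** Static factory in the spirit of Optimizer.sgd(). */
    static Builder builder() {
        return new Builder();
    }

    @Override
    public void update(String parameterId, float[] weight, float[] grad) {
        for (int i = 0; i < weight.length; i++) {
            weight[i] -= learningRate * rescaleGrad * grad[i];
        }
    }

    static final class Builder extends OptimizerBuilder<Builder> {

        private float learningRate = 0.01f;

        Builder learningRate(float value) {
            this.learningRate = value;
            return this;
        }

        @Override
        protected Builder self() {
            return this;
        }

        PlainSgd build() {
            return new PlainSgd(this);
        }
    }
}
```

With this shape, `PlainSgd.builder().rescaleGrad(0.5f).learningRate(0.1f).build()` type-checks. The inherited `rescaleGrad(...)` setter returns `PlainSgd.Builder` rather than the base builder type. This is presumably why `OptimizerBuilder` carries a type parameter bounded by itself.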
Modifier and Type | Field | Description |
---|---|---|
protected float | clipGrad | |
protected float | rescaleGrad | |
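This page lists `clipGrad` and `rescaleGrad` without descriptions. Suppose they follow the MXNet optimizer convention. Under that convention, the gradient is first multiplied by `rescaleGrad`; a common choice is 1/batchSize, which turns a summed gradient into a mean. Each element is then clipped to `[-clipGrad, clipGrad]`, and a non-positive `clipGrad` disables clipping. Under that assumption, the preprocessing that precedes any update rule might look like this. `GradPreprocess` and `apply` are hypothetical names, and `float[]` stands in for `NDArray`.

```java
/** Hypothetical helper: how rescaleGrad and clipGrad could be applied to a gradient. */
final class GradPreprocess {

    private GradPreprocess() {}

    /**
     * Returns rescaleGrad * grad, clipped element-wise to [-clipGrad, clipGrad].
     * Assumption: clipGrad <= 0 means "no clipping".
     */
    static float[] apply(float[] grad, float rescaleGrad, float clipGrad) {
        float[] out = new float[grad.length];
        for (int i = 0; i < grad.length; i++) {
            float g = rescaleGrad * grad[i]; // rescale first
            if (clipGrad > 0) {
                g = Math.max(-clipGrad, Math.min(clipGrad, g)); // then clip
            }
            out[i] = g;
        }
        return out;
    }
}
```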
Constructor | Description |
---|---|
Optimizer(Optimizer.OptimizerBuilder<?> builder) | Creates a new instance of Optimizer. |
Modifier and Type | Method | Description |
---|---|---|
static Adadelta.Builder | adadelta() | Returns a new instance of Adadelta.Builder that can build an Adadelta optimizer. |
static Adagrad.Builder | adagrad() | Returns a new instance of Adagrad.Builder that can build an Adagrad optimizer. |
static Adam.Builder | adam() | Returns a new instance of Adam.Builder that can build an Adam optimizer. |
protected float | getWeightDecay() | Gets the value of weight decay. |
static Nag.Builder | nag() | Returns a new instance of Nag.Builder that can build a Nag optimizer. |
static RmsProp.Builder | rmsprop() | Returns a new instance of RmsProp.Builder that can build an RmsProp optimizer. |
static Sgd.Builder | sgd() | Returns a new instance of Sgd.Builder that can build an Sgd optimizer. |
abstract void | update(java.lang.String parameterId, NDArray weight, NDArray grad) | Updates the parameters according to the gradients. |
protected int | updateCount(java.lang.String parameterId) | |
protected NDArray | withDefaultState(java.util.Map<java.lang.String,java.util.Map<Device,NDArray>> state, java.lang.String key, Device device, java.util.function.Function<java.lang.String,NDArray> defaultFunction) | |
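`withDefaultState` is listed without a description. Its signature suggests that it looks up optimizer state by parameter key and device, and creates that state with `defaultFunction` on first use. Each device holding a copy of a parameter would then get its own state. The sketch below implements that reading with nested `Map.computeIfAbsent` calls. It is an assumption, not the documented behavior. `String` stands in for `Device`, `float[]` stands in for `NDArray`, and `StateSketch` is a hypothetical name.

```java
import java.util.HashMap;
import java.util.Map;
import java.util.function.Function;

/** Hypothetical lazily initialized per-parameter, per-device state (not DJL code). */
final class StateSketch {

    private StateSketch() {}

    /**
     * Returns state.get(key).get(device), creating it with defaultFunction on first access.
     * Assumption: String stands in for Device and float[] for NDArray.
     */
    static float[] withDefaultState(
            Map<String, Map<String, float[]>> state,
            String key,
            String device,
            Function<String, float[]> defaultFunction) {
        return state.computeIfAbsent(key, k -> new HashMap<>())
                .computeIfAbsent(device, d -> defaultFunction.apply(key));
    }
}
```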
public Optimizer(Optimizer.OptimizerBuilder<?> builder)
Creates a new instance of Optimizer.
Parameters:
builder - the builder used to create an instance of Optimizer
public static Sgd.Builder sgd()
Returns a new instance of Sgd.Builder that can build an Sgd optimizer.
Returns:
the Sgd.Builder
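To make concrete the kind of rule the optimizer built by `sgd()` applies, here is a textbook stochastic gradient descent step with L2 weight decay, w ← w − η(g + λw). It also shows one way the value returned by `getWeightDecay()` could enter an update. This is a hedged sketch, not DJL's `Sgd`. The name `SgdSketch`, the fixed learning rate, the absence of momentum, and `float[]` in place of `NDArray` are assumptions.

```java
/** Hypothetical plain-SGD step with L2 weight decay (not DJL code); float[] stands in for NDArray. */
final class SgdSketch {

    private final float learningRate;
    private final float weightDecay;

    SgdSketch(float learningRate, float weightDecay) {
        this.learningRate = learningRate;
        this.weightDecay = weightDecay;
    }

    /** In place: w <- w - lr * (g + wd * w). */
    void update(float[] weight, float[] grad) {
        for (int i = 0; i < weight.length; i++) {
            weight[i] -= learningRate * (grad[i] + weightDecay * weight[i]);
        }
    }

    public static void main(String[] args) {
        // Minimize f(w) = w^2, whose gradient is 2w, starting from w = 1.
        SgdSketch sgd = new SgdSketch(0.25f, 0f);
        float[] w = {1f};
        for (int step = 0; step < 10; step++) {
            sgd.update(w, new float[] {2 * w[0]});
        }
        System.out.println(w[0]); // each step halves w, leaving 2^-10
    }
}
```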
public static Nag.Builder nag()
Returns a new instance of Nag.Builder that can build a Nag optimizer.
Returns:
the Nag.Builder
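Nesterov accelerated gradient keeps a velocity buffer per parameter and applies a look-ahead correction to the momentum step. The sketch below uses one common reformulation, the one PyTorch's `SGD` applies with `nesterov=True`: v ← μv + g, then w ← w − η(g + μv). This page does not state whether DJL's `Nag` uses exactly this form. `NagSketch`, its parameter names, and the state held for a single parameter are illustrative assumptions. A real optimizer would keep one buffer per `parameterId`.

```java
/** Hypothetical Nesterov momentum step (PyTorch-style form, not DJL code); float[] stands in for NDArray. */
final class NagSketch {

    private final float learningRate;
    private final float momentum;
    private final float[] velocity; // momentum buffer for one parameter

    NagSketch(float learningRate, float momentum, int size) {
        this.learningRate = learningRate;
        this.momentum = momentum;
        this.velocity = new float[size];
    }

    /** In place: v <- mu * v + g, then w <- w - lr * (g + mu * v). */
    void update(float[] weight, float[] grad) {
        for (int i = 0; i < weight.length; i++) {
            velocity[i] = momentum * velocity[i] + grad[i];
            weight[i] -= learningRate * (grad[i] + momentum * velocity[i]);
        }
    }
}
```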
public static Adam.Builder adam()
Returns a new instance of Adam.Builder that can build an Adam optimizer.
Returns:
the Adam.Builder
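Adam keeps running averages of the gradient and of its square, corrects both for their zero initialization, and scales each step by their ratio. The sketch below implements the standard bias-corrected form for one parameter, with `float[]` in place of `NDArray`. `AdamSketch` and its constructor are hypothetical. The step counter `t` plays the role that a per-parameter count, such as the one `updateCount` might provide, would play; that is an assumption about that method. Commonly cited defaults are a learning rate of 0.001, beta1 = 0.9, beta2 = 0.999 and epsilon = 1e-8. DJL's own defaults are not given on this page.

```java
/** Hypothetical Adam step for one parameter (not DJL code); float[] stands in for NDArray. */
final class AdamSketch {

    private final float learningRate;
    private final float beta1;   // decay rate of the first-moment estimate
    private final float beta2;   // decay rate of the second-moment estimate
    private final float epsilon; // avoids division by zero
    private final float[] m;     // running mean of gradients
    private final float[] v;     // running mean of squared gradients
    private int t;               // updates applied so far

    AdamSketch(float learningRate, float beta1, float beta2, float epsilon, int size) {
        this.learningRate = learningRate;
        this.beta1 = beta1;
        this.beta2 = beta2;
        this.epsilon = epsilon;
        this.m = new float[size];
        this.v = new float[size];
    }

    void update(float[] weight, float[] grad) {
        t++;
        // Bias corrections compensate for m and v starting at zero.
        double correction1 = 1 - Math.pow(beta1, t);
        double correction2 = 1 - Math.pow(beta2, t);
        for (int i = 0; i < weight.length; i++) {
            m[i] = beta1 * m[i] + (1 - beta1) * grad[i];
            v[i] = beta2 * v[i] + (1 - beta2) * grad[i] * grad[i];
            double mHat = m[i] / correction1;
            double vHat = v[i] / correction2;
            weight[i] -= (float) (learningRate * mHat / (Math.sqrt(vHat) + epsilon));
        }
    }
}
```

On the first step the bias-corrected estimates reduce to g and g², so each element moves by roughly the learning rate against the sign of its gradient, whatever the gradient's magnitude.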
public static RmsProp.Builder rmsprop()
Returns a new instance of RmsProp.Builder that can build an RmsProp optimizer.
Returns:
the RmsProp.Builder
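RMSProp divides each step by the root of an exponentially decaying average of squared gradients. Parameters with consistently large gradients therefore take smaller steps. Below is a sketch of the uncentered variant, using the hypothetical names `RmsPropSketch`, `rho` and `epsilon`, with `float[]` standing in for `NDArray`. Implementations differ on whether epsilon is added inside or outside the square root; this sketch adds it outside. DJL's `RmsProp` may differ in both respects.

```java
/** Hypothetical uncentered RMSProp step for one parameter (not DJL code); float[] stands in for NDArray. */
final class RmsPropSketch {

    private final float learningRate;
    private final float rho;          // decay rate of the squared-gradient average
    private final float epsilon;      // keeps the denominator away from zero
    private final float[] meanSquare; // running average of g^2

    RmsPropSketch(float learningRate, float rho, float epsilon, int size) {
        this.learningRate = learningRate;
        this.rho = rho;
        this.epsilon = epsilon;
        this.meanSquare = new float[size];
    }

    /** In place: s <- rho * s + (1 - rho) * g^2, then w <- w - lr * g / (sqrt(s) + eps). */
    void update(float[] weight, float[] grad) {
        for (int i = 0; i < weight.length; i++) {
            meanSquare[i] = rho * meanSquare[i] + (1 - rho) * grad[i] * grad[i];
            weight[i] -= learningRate * grad[i] / ((float) Math.sqrt(meanSquare[i]) + epsilon);
        }
    }
}
```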
public static Adagrad.Builder adagrad()
Returns a new instance of Adagrad.Builder that can build an Adagrad optimizer.
Returns:
the Adagrad.Builder
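Adagrad accumulates the sum of every squared gradient seen so far. As a result, the effective step size for each element can only shrink over time. Below is a minimal sketch of that rule: s ← s + g², then w ← w − ηg / (√s + ε). `AdagradSketch` and its parameters are hypothetical names, `float[]` stands in for `NDArray`, and the state covers a single parameter.

```java
/** Hypothetical Adagrad step for one parameter (not DJL code); float[] stands in for NDArray. */
final class AdagradSketch {

    private final float learningRate;
    private final float epsilon;      // keeps the denominator away from zero
    private final float[] sumSquares; // running sum of g^2, never decays

    AdagradSketch(float learningRate, float epsilon, int size) {
        this.learningRate = learningRate;
        this.epsilon = epsilon;
        this.sumSquares = new float[size];
    }

    /** In place: s <- s + g^2, then w <- w - lr * g / (sqrt(s) + eps). */
    void update(float[] weight, float[] grad) {
        for (int i = 0; i < weight.length; i++) {
            sumSquares[i] += grad[i] * grad[i];
            weight[i] -= learningRate * grad[i] / ((float) Math.sqrt(sumSquares[i]) + epsilon);
        }
    }
}
```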
public static Adadelta.Builder adadelta()
Returns a new instance of Adadelta.Builder that can build an Adadelta optimizer.
Returns:
the Adadelta.Builder
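Adadelta replaces Adagrad's ever-growing sum with decaying averages of both squared gradients and squared steps. It uses their ratio as a per-element step size, so the original formulation needs no learning rate. The sketch below follows that formulation:

- E[g²] ← ρE[g²] + (1−ρ)g²
- step = √(E[step²]+ε) / √(E[g²]+ε) · g
- E[step²] ← ρE[step²] + (1−ρ)step²
- w ← w − step

`AdadeltaSketch`, `rho` and `epsilon` are hypothetical names. Some libraries also multiply the step by a learning rate, and DJL's `Adadelta` may differ.

```java
/** Hypothetical Adadelta step for one parameter, no learning rate (not DJL code). */
final class AdadeltaSketch {

    private final float rho;         // decay rate for both running averages
    private final float epsilon;     // smoothing term inside both square roots
    private final float[] avgSqGrad; // E[g^2]
    private final float[] avgSqStep; // E[step^2]

    AdadeltaSketch(float rho, float epsilon, int size) {
        this.rho = rho;
        this.epsilon = epsilon;
        this.avgSqGrad = new float[size];
        this.avgSqStep = new float[size];
    }

    void update(float[] weight, float[] grad) {
        for (int i = 0; i < weight.length; i++) {
            avgSqGrad[i] = rho * avgSqGrad[i] + (1 - rho) * grad[i] * grad[i];
            // RMS of past steps over RMS of gradients sets the per-element step size.
            double scale = Math.sqrt(avgSqStep[i] + epsilon) / Math.sqrt(avgSqGrad[i] + epsilon);
            float step = (float) scale * grad[i];
            avgSqStep[i] = rho * avgSqStep[i] + (1 - rho) * step * step;
            weight[i] -= step;
        }
    }
}
```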
protected float getWeightDecay()
Gets the value of weight decay.
protected int updateCount(java.lang.String parameterId)
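`updateCount(java.lang.String parameterId)` has no description here. It takes a parameter id and returns an `int`, so one plausible reading is a per-parameter step counter of the kind that bias-corrected optimizers such as Adam need. The sketch below assumes exactly that. `UpdateCounter` and its increment-then-return behavior are hypothetical, not DJL's implementation.

```java
import java.util.HashMap;
import java.util.Map;

/** Hypothetical per-parameter update counter (not DJL code). */
final class UpdateCounter {

    private final Map<String, Integer> counts = new HashMap<>();

    /** Increments and returns the number of updates seen for parameterId (1 on the first call). */
    int updateCount(String parameterId) {
        // merge stores 1 for a new key, otherwise adds 1, and returns the new value.
        return counts.merge(parameterId, 1, Integer::sum);
    }
}
```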
public abstract void update(java.lang.String parameterId, NDArray weight, NDArray grad)
Updates the parameters according to the gradients.
Parameters:
parameterId - the parameter to be updated
weight - the weights of the parameter
grad - the gradients