Class RmsProp
- java.lang.Object
  - ai.djl.training.optimizer.Optimizer
    - ai.djl.training.optimizer.RmsProp
public class RmsProp extends Optimizer
The RMSProp `Optimizer`. Two versions of RMSProp are implemented.
If `centered = false`, the algorithm described by Tieleman and Hinton (2012) in http://www.cs.toronto.edu/~tijmen/csc321/slides/lecture_slides_lec6.pdf is used.
If `centered = true`, the algorithm described by Alex Graves (2013), equations (38)-(45) of http://arxiv.org/pdf/1308.0850v5.pdf, is used instead.
The default is `centered = false`.
If `centered = false`, RMSProp updates the weights using:
\( var = rho * var + (1 - rho) * grad^2 \)
\( weight -= learning_rate * grad / (sqrt(var) + epsilon) \)
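To make the `centered = false` update concrete, here is a minimal sketch in plain Java of the Tieleman and Hinton step, var = rho * var + (1 - rho) * grad^2 followed by weight -= learning_rate * grad / (sqrt(var) + epsilon), applied element by element. It is an illustration only: it works on `double[]` arrays instead of DJL's `NDArray`, and the class name `RmsPropPlainSketch`, its `step` method, and the hyperparameter values in `main` are hypothetical, not part of the `RmsProp` API.

```java
import java.util.Arrays;

/** Hypothetical sketch: the centered = false RMSProp step on plain arrays. */
final class RmsPropPlainSketch {

    /**
     * Applies one RMSProp step in place.
     * var holds the running average of squared gradients and is updated in place too.
     */
    static void step(double[] weight, double[] grad, double[] var,
                     double learningRate, double rho, double epsilon) {
        for (int i = 0; i < weight.length; i++) {
            // var = rho * var + (1 - rho) * grad^2
            var[i] = rho * var[i] + (1 - rho) * grad[i] * grad[i];
            // weight -= learning_rate * grad / (sqrt(var) + epsilon)
            weight[i] -= learningRate * grad[i] / (Math.sqrt(var[i]) + epsilon);
        }
    }

    public static void main(String[] args) {
        double[] weight = {1.0, -2.0};
        double[] grad = {0.5, -0.5};
        double[] var = new double[weight.length]; // state starts at zero
        step(weight, grad, var, 0.01, 0.9, 1e-8); // illustrative hyperparameters
        System.out.println(Arrays.toString(weight));
    }
}
```

Because the running average starts at zero and is not bias-corrected, the first step divides by sqrt((1 - rho) * grad^2), so each weight moves by roughly learning_rate / sqrt(1 - rho), about 3.16 times learning_rate for rho = 0.9.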
If `centered = true`, RMSProp also tracks the mean gradient and a momentum term mom, decayed by the coefficient momentum:
\( mean = rho * mean + (1 - rho) * grad \)
\( var = rho * var + (1 - rho) * grad^2 \)
\( mom = momentum * mom - learning_rate * grad / sqrt(var - mean^2 + epsilon) \)
\( weight += mom \)
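The centered variant can be sketched the same way. This minimal Java sketch follows Graves's formulation: mean = rho * mean + (1 - rho) * grad, var = rho * var + (1 - rho) * grad^2, mom = momentum * mom - learning_rate * grad / sqrt(var - mean^2 + epsilon), then weight += mom. As before, `double[]` arrays stand in for `NDArray`, and the class `CenteredRmsPropSketch` and its hyperparameter values are hypothetical.

```java
import java.util.Arrays;

/** Hypothetical sketch: the centered = true RMSProp step (Graves, 2013) on plain arrays. */
final class CenteredRmsPropSketch {

    /** Applies one centered RMSProp step; weight, mean, var and mom are all updated in place. */
    static void step(double[] weight, double[] grad, double[] mean, double[] var, double[] mom,
                     double learningRate, double rho, double momentum, double epsilon) {
        for (int i = 0; i < weight.length; i++) {
            // Running first moment: mean = rho * mean + (1 - rho) * grad
            mean[i] = rho * mean[i] + (1 - rho) * grad[i];
            // Running second moment: var = rho * var + (1 - rho) * grad^2
            var[i] = rho * var[i] + (1 - rho) * grad[i] * grad[i];
            // var - mean^2 estimates the gradient variance
            double variance = var[i] - mean[i] * mean[i];
            // mom = momentum * mom - learning_rate * grad / sqrt(var - mean^2 + epsilon)
            mom[i] = momentum * mom[i] - learningRate * grad[i] / Math.sqrt(variance + epsilon);
            // weight += mom
            weight[i] += mom[i];
        }
    }

    public static void main(String[] args) {
        double[] weight = {1.0};
        double[] mean = new double[1];
        double[] var = new double[1];
        double[] mom = new double[1];
        // Two steps with the same gradient and illustrative hyperparameters.
        for (int t = 0; t < 2; t++) {
            step(weight, new double[] {0.5}, mean, var, mom, 0.01, 0.9, 0.9, 1e-8);
        }
        System.out.println(Arrays.toString(weight) + " " + Arrays.toString(mom));
    }
}
```

When mean and var both start at zero and share rho, var - mean^2 is non-negative in exact arithmetic; epsilon keeps the square root away from zero and absorbs any tiny negative rounding error.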
Here grad is the gradient; mean and var are running estimates of the first and second (uncentered) moments of the gradient, so var - mean^2 estimates its variance; and mom is the momentum term.
See Also:
- The D2L chapter on RMSProp
Nested Class Summary
Nested Classes
Modifier and Type | Class | Description
static class | RmsProp.Builder | The Builder to construct an RmsProp object.
Nested classes/interfaces inherited from class ai.djl.training.optimizer.Optimizer
Optimizer.OptimizerBuilder<T extends Optimizer.OptimizerBuilder>
Field Summary
Fields inherited from class ai.djl.training.optimizer.Optimizer
clipGrad, rescaleGrad
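The inherited fields `clipGrad` and `rescaleGrad` govern how the raw gradient is preprocessed before the RMSProp formulas are applied. The sketch below assumes the MXNet-style convention: multiply by `rescaleGrad` first, then clip to `[-clipGrad, clipGrad]`, with clipping turned off when `clipGrad` is not positive. That ordering and the off-switch are assumptions to verify against the engine in use, and `GradPreprocessSketch` is a hypothetical name.

```java
import java.util.Arrays;

/** Hypothetical sketch of gradient preprocessing with rescaleGrad and clipGrad. */
final class GradPreprocessSketch {

    /**
     * Returns a new array: each entry is multiplied by rescaleGrad and then, only when
     * clipGrad > 0, clipped to [-clipGrad, clipGrad]. The ordering is an assumption.
     */
    static double[] preprocess(double[] grad, double rescaleGrad, double clipGrad) {
        double[] out = new double[grad.length];
        for (int i = 0; i < grad.length; i++) {
            double g = rescaleGrad * grad[i];
            if (clipGrad > 0) {
                g = Math.max(-clipGrad, Math.min(clipGrad, g));
            }
            out[i] = g;
        }
        return out;
    }

    public static void main(String[] args) {
        double[] grad = {4.0, -0.1, -6.0};
        System.out.println(Arrays.toString(preprocess(grad, 0.5, 1.0)));  // clipping on
        System.out.println(Arrays.toString(preprocess(grad, 0.5, -1.0))); // clipping off
    }
}
```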
Constructor Summary
Constructors
Modifier | Constructor | Description
protected | RmsProp(RmsProp.Builder builder) | Creates a new instance of the RMSProp optimizer.
Method Summary
Modifier and Type | Method | Description
static RmsProp.Builder | builder() | Creates a builder to build an RMSProp optimizer.
void | update(java.lang.String parameterId, NDArray weight, NDArray grad) | Updates the parameters according to the gradients.
Methods inherited from class ai.djl.training.optimizer.Optimizer
adadelta, adagrad, adam, adamW, getWeightDecay, nag, rmsprop, sgd, updateCount, withDefaultState
Constructor Detail
RmsProp
protected RmsProp(RmsProp.Builder builder)
Creates a new instance of the RMSProp optimizer.
Parameters:
builder - the builder to create a new instance of the RMSProp optimizer
Method Detail
update
public void update(java.lang.String parameterId, NDArray weight, NDArray grad)
Updates the parameters according to the gradients.
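`update` takes a `parameterId` because an optimizer of this kind keeps separate running state (var, plus mean and mom in the centered version) for every parameter it trains. The sketch below shows that bookkeeping under stated assumptions: a `HashMap` keyed by the ID, `double[]` arrays in place of `NDArray`, and only the `centered = false` update. The class `PerParameterRmsPropSketch` and its `stateFor` helper are hypothetical and do not reflect DJL's internal data structures.

```java
import java.util.HashMap;
import java.util.Map;

/** Hypothetical sketch: RMSProp state kept separately for each parameter ID. */
final class PerParameterRmsPropSketch {
    private final double learningRate;
    private final double rho;
    private final double epsilon;
    // Running average of squared gradients, one array per parameter ID.
    private final Map<String, double[]> varStates = new HashMap<>();

    PerParameterRmsPropSketch(double learningRate, double rho, double epsilon) {
        this.learningRate = learningRate;
        this.rho = rho;
        this.epsilon = epsilon;
    }

    /** Same shape as update(parameterId, weight, grad): the ID selects which state to use. */
    void update(String parameterId, double[] weight, double[] grad) {
        // The first call for an ID creates zero-initialized state of the weight's size.
        double[] var = varStates.computeIfAbsent(parameterId, id -> new double[weight.length]);
        for (int i = 0; i < weight.length; i++) {
            var[i] = rho * var[i] + (1 - rho) * grad[i] * grad[i];
            weight[i] -= learningRate * grad[i] / (Math.sqrt(var[i]) + epsilon);
        }
    }

    /** Exposes the stored state for inspection (for this example only). */
    double[] stateFor(String parameterId) {
        return varStates.get(parameterId);
    }

    public static void main(String[] args) {
        PerParameterRmsPropSketch opt = new PerParameterRmsPropSketch(0.01, 0.9, 1e-8);
        double[] w1 = {1.0};
        double[] w2 = {1.0, 2.0};
        opt.update("layer1_weight", w1, new double[] {1.0});
        opt.update("layer1_weight", w1, new double[] {1.0});
        opt.update("layer2_weight", w2, new double[] {2.0, 0.0});
        System.out.println(opt.stateFor("layer1_weight")[0] + " " + opt.stateFor("layer2_weight")[0]);
    }
}
```

Keying state by ID means two parameters never share a running average, even when they have the same shape.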
builder
public static RmsProp.Builder builder()
Creates a builder to build an RMSProp optimizer.
Returns:
a new builder
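`RmsProp` follows the builder pattern: its constructor is `protected`, and instances are obtained through `builder()`. The sketch below reproduces that pattern on a small self-contained settings class so it compiles without DJL. `RmsPropSettings`, its `set`/`opt` method names (patterned on DJL's naming convention for required versus optional settings), and the default values of rho, momentum, and epsilon are illustrative assumptions, not the actual `RmsProp.Builder` API; only the `centered = false` default comes from this page.

```java
/** Hypothetical settings class reproducing the builder pattern used by RmsProp. */
class RmsPropSettings {
    private final double learningRate;
    private final double rho;
    private final double momentum;
    private final double epsilon;
    private final boolean centered;

    /** Protected, like RmsProp(RmsProp.Builder): callers go through builder(). */
    protected RmsPropSettings(Builder builder) {
        this.learningRate = builder.learningRate;
        this.rho = builder.rho;
        this.momentum = builder.momentum;
        this.epsilon = builder.epsilon;
        this.centered = builder.centered;
    }

    /** Counterpart of RmsProp.builder(): returns a new builder. */
    static Builder builder() {
        return new Builder();
    }

    double getLearningRate() { return learningRate; }
    double getRho() { return rho; }
    double getMomentum() { return momentum; }
    double getEpsilon() { return epsilon; }
    boolean isCentered() { return centered; }

    static final class Builder {
        private double learningRate = Double.NaN; // required: no default
        private double rho = 0.9;                 // illustrative default
        private double momentum = 0.9;            // illustrative default
        private double epsilon = 1e-8;            // illustrative default
        private boolean centered = false;         // documented default

        private Builder() {}

        Builder setLearningRate(double learningRate) { this.learningRate = learningRate; return this; }
        Builder optRho(double rho) { this.rho = rho; return this; }
        Builder optMomentum(double momentum) { this.momentum = momentum; return this; }
        Builder optEpsilon(double epsilon) { this.epsilon = epsilon; return this; }
        Builder optCentered(boolean centered) { this.centered = centered; return this; }

        /** Validates required settings and creates the immutable settings object. */
        RmsPropSettings build() {
            if (Double.isNaN(learningRate)) {
                throw new IllegalStateException("learning rate must be set");
            }
            return new RmsPropSettings(this);
        }
    }

    public static void main(String[] args) {
        RmsPropSettings settings = RmsPropSettings.builder()
                .setLearningRate(0.001)
                .optCentered(true)
                .build();
        System.out.println("centered=" + settings.isCentered() + ", rho=" + settings.getRho());
    }
}
```

Keeping the constructor non-public forces every instance through `build()`, which is the one place where required settings can be validated.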