Class RmsProp

java.lang.Object
ai.djl.training.optimizer.Optimizer
ai.djl.training.optimizer.RmsProp

public class RmsProp extends Optimizer
The RMSProp Optimizer.

Two versions of RMSProp are implemented.

If `centered = False`, the algorithm described in http://www.cs.toronto.edu/~tijmen/csc321/slides/lecture_slides_lec6.pdf by Tieleman and Hinton, 2012 is used.

If `centered = True`, the algorithm described in http://arxiv.org/pdf/1308.0850v5.pdf (38)-(45) by Alex Graves, 2013 is used instead.

Default version is `centered = False`.

If `centered = False`:

RMSProp updates the weights using:

\( var = rho * var + (1 - rho) * grad^2 \)
\( weight -= learning_rate * grad / (sqrt(var) + epsilon) \)

If `centered = True`:

RMSProp updates the weights using:

\( mean = rho * mean + (1 - rho) * grad \)
\( var = rho * var + (1 - rho) * grad^2 \)
\( mom = momentum * mom - learning_rate * grad / sqrt(var - mean^2 + epsilon) \)
\( weight += mom \)

Here grad is the gradient, mean and var are the first and second order moment estimates (the running mean and uncentered variance of the gradient), and mom is the momentum term, decayed by the momentum factor.
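As a concrete illustration, the two update rules above can be sketched for a single scalar weight. This is plain Java, independent of DJL; the method and parameter names mirror the formulas and are not part of the DJL API:

```java
// Scalar sketch of the two RMSProp variants described above.
// Illustrative only; the real optimizer operates on NDArrays.
public class RmsPropSketch {

    // Non-centered update (Tieleman and Hinton, 2012):
    //   var = rho * var + (1 - rho) * grad^2
    //   weight -= learning_rate * grad / (sqrt(var) + epsilon)
    // Returns {weight, var} after one step.
    static double[] updatePlain(double weight, double var, double grad,
                                double lr, double rho, double eps) {
        var = rho * var + (1 - rho) * grad * grad;
        weight -= lr * grad / (Math.sqrt(var) + eps);
        return new double[] {weight, var};
    }

    // Centered update (Graves, 2013, eqs. 38-45):
    //   mean = rho * mean + (1 - rho) * grad
    //   var  = rho * var  + (1 - rho) * grad^2
    //   mom  = momentum * mom - learning_rate * grad / sqrt(var - mean^2 + epsilon)
    //   weight += mom
    // Returns {weight, mean, var, mom} after one step.
    static double[] updateCentered(double weight, double mean, double var,
                                   double mom, double grad, double lr,
                                   double rho, double momentum, double eps) {
        mean = rho * mean + (1 - rho) * grad;
        var = rho * var + (1 - rho) * grad * grad;
        mom = momentum * mom - lr * grad / Math.sqrt(var - mean * mean + eps);
        weight += mom;
        return new double[] {weight, mean, var, mom};
    }

    public static void main(String[] args) {
        // One step on weight = 1.0 with gradient 0.5 and typical defaults.
        double[] plain = updatePlain(1.0, 0.0, 0.5, 0.001, 0.9, 1e-8);
        System.out.printf("plain:    weight=%.6f var=%.6f%n", plain[0], plain[1]);

        double[] centered =
                updateCentered(1.0, 0.0, 0.0, 0.0, 0.5, 0.001, 0.9, 0.9, 1e-8);
        System.out.printf("centered: weight=%.6f%n", centered[0]);
    }
}
```

Note that because var subtracts mean^2 in the centered variant, the effective denominator is smaller than in the plain variant, so the centered step for the same gradient is slightly larger here.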

  • Constructor Details

    • RmsProp

      protected RmsProp(RmsProp.Builder builder)
      Creates a new instance of RMSProp optimizer.
      Parameters:
      builder - the builder to create a new instance of RMSProp optimizer
  • Method Details

    • update

      public void update(String parameterId, NDArray weight, NDArray grad)
      Updates the parameters according to the gradients.
      Specified by:
      update in class Optimizer
      Parameters:
      parameterId - the parameter to be updated
      weight - the weights of the parameter
      grad - the gradients
    • builder

      public static RmsProp.Builder builder()
      Creates a builder to build an RmsProp optimizer.
      Returns:
      a new builder