public class WeightDecay extends Object implements Regularization

WeightDecay regularization: the weight decay term is applied directly to the parameters, separately from the gradient/updater. This differs from L2Regularization, which adds a penalty term to the loss function: L = loss + coeff * 0.5 * sum_i w[i]^2.

For all cases, w -= update.

If applyLR == true, we have:
update = updater(gradient) + lr * coeff * w
where lr is the learning rate for the current iteration/epoch (accounting for LR schedules if present).

If applyLR == false, we have:
update = updater(gradient) + coeff * w
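The update rule above can be sketched in plain Java. This is a simplified illustration, not the ND4J implementation: `updatedGrad` stands in for the updater's output, and plain `double[]` arrays replace `INDArray`; all names here are hypothetical.

```java
import java.util.Arrays;

public class WeightDecaySketch {

    // Sketch of the weight-decay update: w -= updatedGrad + (lr *) coeff * w
    static double[] applyWeightDecay(double[] w, double[] updatedGrad,
                                     double lr, double coeff, boolean applyLR) {
        // If applyLR is true, the decay coefficient is scaled by the learning rate
        double scale = applyLR ? lr * coeff : coeff;
        double[] next = new double[w.length];
        for (int i = 0; i < w.length; i++) {
            next[i] = w[i] - (updatedGrad[i] + scale * w[i]);
        }
        return next;
    }

    public static void main(String[] args) {
        double[] w = {1.0, -2.0};
        double[] g = {0.1, 0.1};   // pretend this is updater(gradient)
        System.out.println(Arrays.toString(applyWeightDecay(w, g, 0.5, 0.01, true)));
    }
}
```

Note that the decay term is computed from the current weights and subtracted after the updater has transformed the gradient, which is the key difference from adding coeff * w to the gradient before the updater runs.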
Nested classes/interfaces inherited from interface Regularization:
Regularization.ApplyStep
Modifier and Type | Field and Description
---|---
protected boolean | applyLR
protected ISchedule | coeff
Constructor and Description |
---|
WeightDecay(double coeff, boolean applyLR) |
WeightDecay(@NonNull ISchedule coeff, boolean applyLR) |
Modifier and Type | Method and Description
---|---
void | apply(INDArray param, INDArray gradView, double lr, int iteration, int epoch) Apply the regularization by modifying the gradient array in-place
Regularization.ApplyStep | applyStep()
Regularization | clone()
double | score(INDArray param, int iteration, int epoch) Calculate the loss function score component for the regularization. For example, in L2 regularization, this would return L = 0.5 * sum_i param[i]^2. For regularization types that don't have a score component, this method can return 0.
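The L2-style score formula mentioned in the summary above can be illustrated with plain arrays. This is a hedged sketch of the math only, not the ND4J implementation; the class and method names are hypothetical.

```java
public class ScoreSketch {

    // Illustrative L2-style loss component: L = coeff * 0.5 * sum_i param[i]^2
    static double l2Score(double[] param, double coeff) {
        double sumSq = 0.0;
        for (double p : param) {
            sumSq += p * p;
        }
        return coeff * 0.5 * sumSq;
    }

    public static void main(String[] args) {
        System.out.println(l2Score(new double[]{3.0, 4.0}, 0.1));
    }
}
```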
protected final ISchedule coeff
protected final boolean applyLR
public WeightDecay(double coeff, boolean applyLR)

Parameters:
coeff - Weight decay regularization coefficient
applyLR - If true, multiply the regularization coefficient by the current learning rate. If false, do not multiply by LR.

public WeightDecay(@NonNull ISchedule coeff, boolean applyLR)

Parameters:
coeff - Weight decay regularization coefficient (schedule)
applyLR - If true, multiply the regularization coefficient by the current learning rate. If false, do not multiply by LR.

public Regularization.ApplyStep applyStep()

Specified by:
applyStep in interface Regularization
Returns:
Regularization.ApplyStep
public void apply(INDArray param, INDArray gradView, double lr, int iteration, int epoch)

Description copied from interface: Regularization
Apply the regularization by modifying the gradient array in-place.

Specified by:
apply in interface Regularization
Parameters:
param - Input array (usually parameters)
gradView - Gradient view array (should be modified/updated). Same shape and type as the input array.
lr - Current learning rate
iteration - Current network training iteration
epoch - Current network training epoch

public double score(INDArray param, int iteration, int epoch)
Description copied from interface: Regularization
Calculate the loss function score component for the regularization. For example, in L2 regularization, this would return L = 0.5 * sum_i param[i]^2. For regularization types that don't have a score component, this method can return 0.

Specified by:
score in interface Regularization
Parameters:
param - Input array (usually parameters)
iteration - Current network training iteration
epoch - Current network training epoch

public Regularization clone()
Specified by:
clone in interface Regularization
Overrides:
clone in class Object
Copyright © 2020. All rights reserved.