Package ai.djl.training.loss
Class L2WeightDecay
- java.lang.Object
  - ai.djl.training.evaluator.Evaluator
    - ai.djl.training.loss.Loss
      - ai.djl.training.loss.L2WeightDecay
public class L2WeightDecay extends Loss
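L2WeightDecay calculates the L2 penalty of a set of parameters and is used for regularization. The L2 loss is defined by \(L2 = \lambda \sum_i {W_i}^2\), where \(W_i\) ranges over the parameter values and \(\lambda\) is the weight applied to the penalty.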
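To make the formula concrete, here is a minimal plain-Java sketch of \(\lambda \sum_i {W_i}^2\). It does not use DJL: the class name `L2PenaltySketch`, the method `l2Penalty`, and the use of `List<float[]>` in place of `NDList` are hypothetical stand-ins chosen so the example runs without dependencies.

```java
import java.util.List;

/**
 * Hypothetical plain-Java sketch of L2 = lambda * sum_i W_i^2 (not the DJL source).
 * Each float[] stands in for one parameter array; DJL itself works on NDArray/NDList.
 */
final class L2PenaltySketch {

    /** Sums the squares of every element of every parameter array, then scales by lambda. */
    static float l2Penalty(List<float[]> parameters, float lambda) {
        float sum = 0f;
        for (float[] w : parameters) {
            for (float wi : w) {
                sum += wi * wi;
            }
        }
        return lambda * sum;
    }

    public static void main(String[] args) {
        // Two "parameter arrays": 1^2 + 2^2 + 3^2 = 14, times lambda 0.5 = 7.0
        List<float[]> params = List.of(new float[] {1f, 2f}, new float[] {3f});
        System.out.println(l2Penalty(params, 0.5f)); // prints 7.0
    }
}
```

For example, with parameter arrays {1, 2} and {3} and \(\lambda = 0.5\), the penalty is 0.5 × (1 + 4 + 9) = 7.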
Field Summary
Fields inherited from class ai.djl.training.evaluator.Evaluator
totalInstances
Constructor Summary
Constructors
- L2WeightDecay(NDList parameters): Calculates L2 weight decay for regularization.
- L2WeightDecay(java.lang.String name, NDList parameters): Calculates L2 weight decay for regularization.
- L2WeightDecay(java.lang.String name, NDList parameters, float lambda): Calculates L2 weight decay for regularization.
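The three constructors differ only in which arguments fall back to defaults. The sketch below mirrors that overload chain in plain Java; it is not the DJL source. `List<float[]>` stands in for `NDList`, the class name `L2WeightDecaySketch` is hypothetical, and the default name "L2WeightDecay" is an assumption, while the default \(\lambda = 1\) is taken from the constructor documentation.

```java
import java.util.List;

/**
 * Hypothetical mirror of the three L2WeightDecay constructors (not the DJL source).
 * List<float[]> replaces NDList; the default name "L2WeightDecay" is assumed,
 * and the default lambda of 1 follows the constructor documentation.
 */
final class L2WeightDecaySketch {
    private final String name;
    private final List<float[]> parameters;
    private final float lambda;

    L2WeightDecaySketch(List<float[]> parameters) {
        this("L2WeightDecay", parameters); // assumed default name
    }

    L2WeightDecaySketch(String name, List<float[]> parameters) {
        this(name, parameters, 1f); // documented default lambda
    }

    L2WeightDecaySketch(String name, List<float[]> parameters, float lambda) {
        this.name = name;
        this.parameters = parameters;
        this.lambda = lambda;
    }

    String getName() {
        return name;
    }

    float getLambda() {
        return lambda;
    }

    int parameterCount() {
        return parameters.size();
    }

    public static void main(String[] args) {
        List<float[]> params = List.of(new float[] {1f, 2f});
        L2WeightDecaySketch a = new L2WeightDecaySketch(params);
        L2WeightDecaySketch b = new L2WeightDecaySketch("decay", params, 0.01f);
        System.out.println(a.getName() + " " + a.getLambda()); // prints L2WeightDecay 1.0
        System.out.println(b.getName() + " " + b.getLambda()); // prints decay 0.01
    }
}
```

Chaining each shorter overload to the longest one keeps every default in a single place, so changing a default touches only one line.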
Method Summary
Modifier and Type | Method | Description
- NDArray evaluate(NDList label, NDList prediction): Calculates the evaluation between the labels and the predictions. For this class, the returned value is the L2 penalty of the parameters supplied at construction.
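Although `evaluate` takes labels and predictions like any other `Loss`, a weight-decay term only needs the parameters supplied at construction. The following is a minimal sketch of that behavior, assuming plain arrays in place of `NDList`/`NDArray`; the class name `WeightDecayEvaluateSketch` is hypothetical.

```java
import java.util.List;

/**
 * Hypothetical sketch of evaluate(label, prediction) for a weight-decay term
 * (not the DJL source). The label and prediction are accepted to match a
 * loss-style signature, but the result depends only on the stored parameters.
 */
final class WeightDecayEvaluateSketch {
    private final List<float[]> parameters;
    private final float lambda;

    WeightDecayEvaluateSketch(List<float[]> parameters, float lambda) {
        this.parameters = parameters;
        this.lambda = lambda;
    }

    /** label and prediction are intentionally unused: returns lambda * sum of squared weights. */
    float evaluate(float[] label, float[] prediction) {
        float sum = 0f;
        for (float[] w : parameters) {
            for (float wi : w) {
                sum += wi * wi;
            }
        }
        return lambda * sum;
    }

    public static void main(String[] args) {
        WeightDecayEvaluateSketch decay =
                new WeightDecayEvaluateSketch(List.of(new float[] {1f, -1f}), 2f);
        // Different labels and predictions, same penalty: 2 * (1 + 1) = 4.0
        System.out.println(decay.evaluate(new float[] {0f}, new float[] {1f})); // prints 4.0
        System.out.println(decay.evaluate(new float[] {5f}, new float[] {-3f})); // prints 4.0
    }
}
```

In training, a term like this is normally added to a data loss (for example `softmaxCrossEntropyLoss`) rather than used on its own.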
Methods inherited from class ai.djl.training.loss.Loss
addAccumulator, elasticNetWeightedDecay, elasticNetWeightedDecay, elasticNetWeightedDecay, elasticNetWeightedDecay, getAccumulator, hingeLoss, hingeLoss, hingeLoss, l1Loss, l1Loss, l1Loss, l1WeightedDecay, l1WeightedDecay, l1WeightedDecay, l2Loss, l2Loss, l2Loss, l2WeightedDecay, l2WeightedDecay, l2WeightedDecay, maskedSoftmaxCrossEntropyLoss, maskedSoftmaxCrossEntropyLoss, maskedSoftmaxCrossEntropyLoss, resetAccumulator, sigmoidBinaryCrossEntropyLoss, sigmoidBinaryCrossEntropyLoss, sigmoidBinaryCrossEntropyLoss, softmaxCrossEntropyLoss, softmaxCrossEntropyLoss, softmaxCrossEntropyLoss, updateAccumulator
Methods inherited from class ai.djl.training.evaluator.Evaluator
checkLabelShapes, checkLabelShapes, getName
Constructor Detail
L2WeightDecay
public L2WeightDecay(NDList parameters)
Calculates L2 weight decay for regularization.
Parameters:
parameters - holds the model weights that will be penalized
L2WeightDecay
public L2WeightDecay(java.lang.String name, NDList parameters)
Calculates L2 weight decay for regularization.
Parameters:
name - the name of the penalty
parameters - holds the model weights that will be penalized
L2WeightDecay
public L2WeightDecay(java.lang.String name, NDList parameters, float lambda)
Calculates L2 weight decay for regularization.
Parameters:
name - the name of the penalty
parameters - holds the model weights that will be penalized
lambda - the weight \(\lambda\) applied to the penalty value; defaults to 1
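To see what `lambda` does during training, note that the gradient of \(\lambda {W_i}^2\) is \(2 \lambda W_i\). Under plain gradient descent with learning rate \(\eta\), the update \(W_i - \eta(g_i + 2\lambda W_i) = (1 - 2\eta\lambda) W_i - \eta g_i\) shrinks each weight by the factor \(1 - 2\eta\lambda\) before the data gradient \(g_i\) is applied, which is where the name "weight decay" comes from. The sketch below assumes a hypothetical single-weight SGD step (`WeightDecayStepSketch.sgdStep`); it illustrates the effect and is not how DJL optimizers are implemented.

```java
/**
 * Hypothetical sketch (not DJL code): one plain SGD step on a single weight with
 * the L2 penalty lambda * w^2 added to the loss. The penalty's gradient is
 * 2 * lambda * w, so the weight shrinks toward zero even with no data gradient.
 */
final class WeightDecayStepSketch {

    /** Returns w - learningRate * (dataGradient + 2 * lambda * w). */
    static double sgdStep(double w, double dataGradient, double learningRate, double lambda) {
        double penaltyGradient = 2.0 * lambda * w; // d/dw of lambda * w^2
        return w - learningRate * (dataGradient + penaltyGradient);
    }

    public static void main(String[] args) {
        double w = 1.0;
        // With no data gradient, each step multiplies w by (1 - 2 * 0.1 * 0.5) = 0.9
        for (int step = 1; step <= 3; step++) {
            w = sgdStep(w, 0.0, 0.1, 0.5);
            System.out.printf("step %d: w = %.4f%n", step, w);
        }
    }
}
```

Setting \(\lambda = 0\) removes the shrinkage entirely, and larger values pull weights toward zero faster.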