Package ai.djl.training.loss
Class L1WeightDecay
- java.lang.Object
  - ai.djl.training.evaluator.Evaluator
    - ai.djl.training.loss.Loss
      - ai.djl.training.loss.L1WeightDecay
public class L1WeightDecay extends Loss
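L1WeightDecay calculates the L1 penalty of a set of parameters and is used for regularization. The L1 penalty is defined as \(L1 = \lambda \sum_i \vert W_i \vert\), where the sum runs over every element \(W_i\) of the penalized parameters and \(\lambda\) is the weight applied to the penalty.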
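To make the formula concrete, here is a minimal plain-Java sketch of \(\lambda \sum_i \vert W_i \vert\). It is not the DJL implementation: it assumes each parameter tensor has been flattened into a `float[]` (standing in for an `NDArray` inside an `NDList`), and the class and method names `L1PenaltySketch` and `l1Penalty` are hypothetical.

```java
import java.util.List;

/**
 * Hypothetical plain-Java sketch of L1 = lambda * sum_i |W_i|.
 * Not the DJL implementation: float[] arrays stand in for NDArray parameters.
 */
public class L1PenaltySketch {

    /** Sums |w| over every element of every parameter array, then scales by lambda. */
    static float l1Penalty(List<float[]> parameters, float lambda) {
        float sum = 0f;
        for (float[] w : parameters) {
            for (float wi : w) {
                sum += Math.abs(wi);
            }
        }
        return lambda * sum;
    }

    public static void main(String[] args) {
        // Two hypothetical parameter tensors, flattened to 1-D arrays.
        List<float[]> params = List.of(
                new float[] {1.0f, -2.0f},
                new float[] {0.5f, -0.5f, 3.0f});
        // |1| + |-2| + |0.5| + |-0.5| + |3| = 7
        System.out.println(l1Penalty(params, 1.0f)); // prints 7.0
        System.out.println(l1Penalty(params, 0.5f)); // prints 3.5
    }
}
```

Because every element contributes its absolute value, the penalty grows linearly with weight magnitude regardless of sign, and \(\lambda\) simply rescales the total.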
Field Summary
Fields inherited from class ai.djl.training.evaluator.Evaluator
- totalInstances
Constructor Summary
Constructors
- L1WeightDecay(NDList parameters): Calculates L1 weight decay for regularization.
- L1WeightDecay(java.lang.String name, NDList parameters): Calculates L1 weight decay for regularization.
- L1WeightDecay(java.lang.String name, NDList parameters, float lambda): Calculates L1 weight decay for regularization.
Method Summary
- NDArray evaluate(NDList label, NDList prediction): Calculates the evaluation between the labels and the predictions. For L1WeightDecay, this evaluates the L1 penalty of the parameters supplied to the constructor.
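Since the penalty is used for regularization, it is normally added to a data loss to form the training objective. The plain-Java sketch below shows that combination under stated assumptions: it is not DJL code, it picks mean squared error as a hypothetical data term, and it assumes the L1 term depends only on the parameters (the labels and predictions feed only the data term). All names (`RegularizedLossSketch`, `meanSquaredError`, `l1Penalty`, `totalLoss`) are hypothetical.

```java
import java.util.List;

/**
 * Hypothetical plain-Java sketch (not DJL): a regularized objective that adds an
 * L1 weight penalty to a data loss. The penalty term uses only the parameters.
 */
public class RegularizedLossSketch {

    /** Data term: mean squared error between labels and predictions (assumed choice). */
    static float meanSquaredError(float[] label, float[] prediction) {
        float sum = 0f;
        for (int i = 0; i < label.length; i++) {
            float diff = label[i] - prediction[i];
            sum += diff * diff;
        }
        return sum / label.length;
    }

    /** Penalty term: lambda * sum_i |W_i| over all parameter elements. */
    static float l1Penalty(List<float[]> parameters, float lambda) {
        float sum = 0f;
        for (float[] w : parameters) {
            for (float wi : w) {
                sum += Math.abs(wi);
            }
        }
        return lambda * sum;
    }

    /** Regularized objective: data loss plus L1 penalty. */
    static float totalLoss(float[] label, float[] prediction,
                           List<float[]> parameters, float lambda) {
        return meanSquaredError(label, prediction) + l1Penalty(parameters, lambda);
    }

    public static void main(String[] args) {
        float[] label = {1f, 2f};
        float[] prediction = {1f, 3f};
        List<float[]> params = List.of(new float[] {0.5f, -1.5f});
        // MSE = (0 + 1) / 2 = 0.5; penalty = 1.0 * (0.5 + 1.5) = 2.0; total = 2.5
        System.out.println(totalLoss(label, prediction, params, 1.0f)); // prints 2.5
    }
}
```

Keeping the penalty as a separate term makes it easy to tune \(\lambda\) independently of the data loss; setting it to 0 recovers the unregularized objective.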
Methods inherited from class ai.djl.training.loss.Loss
addAccumulator, elasticNetWeightedDecay, elasticNetWeightedDecay, elasticNetWeightedDecay, elasticNetWeightedDecay, getAccumulator, hingeLoss, hingeLoss, hingeLoss, l1Loss, l1Loss, l1Loss, l1WeightedDecay, l1WeightedDecay, l1WeightedDecay, l2Loss, l2Loss, l2Loss, l2WeightedDecay, l2WeightedDecay, l2WeightedDecay, maskedSoftmaxCrossEntropyLoss, maskedSoftmaxCrossEntropyLoss, maskedSoftmaxCrossEntropyLoss, quantileL1Loss, quantileL1Loss, resetAccumulator, sigmoidBinaryCrossEntropyLoss, sigmoidBinaryCrossEntropyLoss, sigmoidBinaryCrossEntropyLoss, softmaxCrossEntropyLoss, softmaxCrossEntropyLoss, softmaxCrossEntropyLoss, updateAccumulator
Methods inherited from class ai.djl.training.evaluator.Evaluator
checkLabelShapes, checkLabelShapes, getName
Constructor Detail
L1WeightDecay
public L1WeightDecay(NDList parameters)
Calculates L1 weight decay for regularization.
Parameters:
- parameters - holds the model weights that will be penalized
L1WeightDecay
public L1WeightDecay(java.lang.String name, NDList parameters)
Calculates L1 weight decay for regularization.
Parameters:
- name - the name of the penalty
- parameters - holds the model weights that will be penalized
L1WeightDecay
public L1WeightDecay(java.lang.String name, NDList parameters, float lambda)
Calculates L1 weight decay for regularization.
Parameters:
- name - the name of the penalty
- parameters - holds the model weights that will be penalized
- lambda - the weight applied to the penalty value; the constructors without a lambda argument use a default of 1
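The three constructors form a telescoping chain in which the shorter overloads fill in defaults. The plain-Java sketch below mirrors that pattern; it is not DJL source. `float[]` arrays stand in for `NDList`, the default name `"L1WeightDecay"` is an assumed placeholder (the documentation above does not state it), and only the lambda default of 1 comes from the parameter description. The class `L1WeightDecaySketch` and its accessors are hypothetical.

```java
import java.util.List;

/**
 * Hypothetical plain-Java mirror of the three L1WeightDecay constructor overloads.
 * Not DJL code: float[] arrays stand in for NDList, and DEFAULT_NAME is an assumed placeholder.
 */
public class L1WeightDecaySketch {

    static final String DEFAULT_NAME = "L1WeightDecay"; // assumption, not taken from DJL
    static final float DEFAULT_LAMBDA = 1f;             // documented default for lambda

    private final String name;
    private final List<float[]> parameters;
    private final float lambda;

    /** Uses the assumed default name and lambda = 1. */
    public L1WeightDecaySketch(List<float[]> parameters) {
        this(DEFAULT_NAME, parameters);
    }

    /** Uses lambda = 1. */
    public L1WeightDecaySketch(String name, List<float[]> parameters) {
        this(name, parameters, DEFAULT_LAMBDA);
    }

    /** Full constructor: every other overload delegates here. */
    public L1WeightDecaySketch(String name, List<float[]> parameters, float lambda) {
        this.name = name;
        this.parameters = parameters;
        this.lambda = lambda;
    }

    String getName() {
        return name;
    }

    float getLambda() {
        return lambda;
    }

    /** lambda * sum_i |W_i| over the stored parameters. */
    float penalty() {
        float sum = 0f;
        for (float[] w : parameters) {
            for (float wi : w) {
                sum += Math.abs(wi);
            }
        }
        return lambda * sum;
    }

    public static void main(String[] args) {
        List<float[]> params = List.of(new float[] {2f, -4f});
        System.out.println(new L1WeightDecaySketch(params).penalty());              // prints 6.0
        System.out.println(new L1WeightDecaySketch("l1", params, 0.25f).penalty()); // prints 1.5
    }
}
```

Routing every overload through the full constructor keeps the defaults in one place, so changing the default lambda touches a single constant.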