Package ai.djl.training.loss
Class L2Loss
java.lang.Object
  ai.djl.training.evaluator.Evaluator
    ai.djl.training.loss.Loss
      ai.djl.training.loss.L2Loss
public class L2Loss extends Loss
Calculates the L2 loss between label and prediction, a.k.a. MSE (mean squared error). The L2 loss is defined by \(L = \frac{1}{2} \sum_i \vert \text{label}_i - \text{prediction}_i \vert^2\)
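As a worked illustration of the formula, take the made-up values label = (1, 2, 3) and prediction = (1.5, 2, 2). These numbers are illustrative only and do not come from the library:

```latex
L = \frac{1}{2}\left[(1 - 1.5)^2 + (2 - 2)^2 + (3 - 2)^2\right]
  = \frac{1}{2}\,(0.25 + 0 + 1)
  = 0.625
```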
Field Summary
Fields inherited from class ai.djl.training.evaluator.Evaluator
totalInstances
Constructor Summary
Constructor | Description
L2Loss() | Calculates L2Loss between the label and prediction, a.k.a. MSE (mean squared error).
L2Loss(java.lang.String name) | Calculates L2Loss between the label and prediction, a.k.a. MSE (mean squared error).
L2Loss(java.lang.String name, float weight) | Calculates L2Loss between the label and prediction, a.k.a. MSE (mean squared error).
Method Summary
Modifier and Type | Method | Description
NDArray | evaluate(NDList label, NDList prediction) | Calculates the L2 loss between the labels and the predictions.
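The following is a minimal stand-alone sketch of the arithmetic an L2 `evaluate` step performs. It assumes plain `double[]` arrays in place of DJL's `NDList`/`NDArray`, and the class name `L2LossEvaluateSketch` is hypothetical; this is not the library's implementation. It sums the squared differences and multiplies by a weight, so a weight of 0.5 reproduces the class-level formula. This page does not say whether the library reduces over a batch by sum or by mean, so treat the summing reduction as an assumption.

```java
/**
 * Hypothetical plain-Java sketch of an L2 loss evaluation on flat arrays.
 * It is not DJL's implementation and uses no DJL types.
 */
final class L2LossEvaluateSketch {

    private L2LossEvaluateSketch() {}

    /**
     * Returns weight * sum_i (label_i - prediction_i)^2.
     * With weight = 0.5 this matches the documented formula.
     */
    static double evaluate(double[] label, double[] prediction, double weight) {
        // Analogous in spirit to a label-shape check: both inputs must have the same length.
        if (label.length != prediction.length) {
            throw new IllegalArgumentException(
                    "label length " + label.length
                            + " != prediction length " + prediction.length);
        }
        double sum = 0.0;
        for (int i = 0; i < label.length; i++) {
            double diff = label[i] - prediction[i];
            sum += diff * diff; // squared difference for element i
        }
        return weight * sum;
    }

    public static void main(String[] args) {
        double[] label = {1.0, 2.0, 3.0};
        double[] prediction = {1.5, 2.0, 2.0};
        System.out.println(evaluate(label, prediction, 0.5)); // prints 0.625
    }
}
```

Rejecting mismatched lengths up front keeps the sketch from silently pairing the wrong elements. The real method works on tensors, where the equivalent concern is matching shapes.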
Methods inherited from class ai.djl.training.loss.Loss
addAccumulator, elasticNetWeightedDecay, elasticNetWeightedDecay, elasticNetWeightedDecay, elasticNetWeightedDecay, getAccumulator, hingeLoss, hingeLoss, hingeLoss, l1Loss, l1Loss, l1Loss, l1WeightedDecay, l1WeightedDecay, l1WeightedDecay, l2Loss, l2Loss, l2Loss, l2WeightedDecay, l2WeightedDecay, l2WeightedDecay, maskedSoftmaxCrossEntropyLoss, maskedSoftmaxCrossEntropyLoss, maskedSoftmaxCrossEntropyLoss, resetAccumulator, sigmoidBinaryCrossEntropyLoss, sigmoidBinaryCrossEntropyLoss, sigmoidBinaryCrossEntropyLoss, softmaxCrossEntropyLoss, softmaxCrossEntropyLoss, softmaxCrossEntropyLoss, updateAccumulator
Methods inherited from class ai.djl.training.evaluator.Evaluator
checkLabelShapes, checkLabelShapes, getName
Constructor Detail
L2Loss
public L2Loss()
Calculates L2Loss between the label and prediction, a.k.a. MSE (mean squared error).
L2Loss
public L2Loss(java.lang.String name)
Calculates L2Loss between the label and prediction, a.k.a. MSE (mean squared error).
Parameters:
name - the name of the loss
L2Loss
public L2Loss(java.lang.String name, float weight)
Calculates L2Loss between the label and prediction, a.k.a. MSE (mean squared error).
Parameters:
name - the name of the loss
weight - the weight to apply to the loss value; defaults to 1/2
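The three constructors follow a common Java pattern where shorter constructors delegate to the full one and fill in defaults. The sketch below illustrates that pattern in plain Java. The class `L2LossConfigSketch` and its methods are hypothetical, and the default name `"L2Loss"` is an assumption. Only the default weight of 1/2 comes from the parameter description above.

```java
/**
 * Hypothetical sketch of constructor chaining with defaults.
 * Not DJL code: the default name is an assumption, the 1/2 weight is documented.
 */
final class L2LossConfigSketch {

    static final String DEFAULT_NAME = "L2Loss";  // assumed default name
    static final float DEFAULT_WEIGHT = 1.0f / 2; // documented default weight of 1/2

    private final String name;
    private final float weight;

    L2LossConfigSketch() {
        this(DEFAULT_NAME); // delegate, supplying the (assumed) default name
    }

    L2LossConfigSketch(String name) {
        this(name, DEFAULT_WEIGHT); // delegate, supplying the default weight
    }

    L2LossConfigSketch(String name, float weight) {
        this.name = name;
        this.weight = weight;
    }

    String getName() {
        return name;
    }

    float getWeight() {
        return weight;
    }

    /** Scales a precomputed sum of squared differences by the configured weight. */
    double applyWeight(double sumOfSquares) {
        return weight * sumOfSquares;
    }

    public static void main(String[] args) {
        L2LossConfigSketch halved = new L2LossConfigSketch();
        L2LossConfigSketch unscaled = new L2LossConfigSketch("sse", 1.0f);
        System.out.println(halved.getName() + " " + halved.getWeight()); // prints L2Loss 0.5
        System.out.println(halved.applyWeight(1.25));   // prints 0.625
        System.out.println(unscaled.applyWeight(1.25)); // prints 1.25
    }
}
```

Passing a weight of 1.0 drops the factor of 1/2 and yields the plain sum of squared differences. The default of 1/2 matches the factor in the class-level formula.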