Package ai.djl.training.loss
Class ElasticNetWeightDecay
java.lang.Object
  ai.djl.training.evaluator.Evaluator
    ai.djl.training.loss.Loss
      ai.djl.training.loss.ElasticNetWeightDecay
-
public class ElasticNetWeightDecay extends Loss
ElasticNetWeightDecay calculates the L1+L2 penalty of a set of parameters. Used for regularization.
The loss L is defined as \(L = \lambda_1 \sum_i \vert W_i\vert + \lambda_2 \sum_i {W_i}^2\).
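To make the formula concrete, here is a minimal plain-Java sketch that evaluates the same penalty on a flat array of weights. It does not use DJL or reproduce this class's implementation; the class name `ElasticNetPenaltySketch` and its `penalty` methods are hypothetical and exist only for illustration.

```java
/** Hypothetical helper, not part of DJL: evaluates the elastic net penalty on a plain array. */
final class ElasticNetPenaltySketch {

    private ElasticNetPenaltySketch() {}

    /** Returns lambda1 * sum(|w_i|) + lambda2 * sum(w_i^2). */
    static double penalty(double[] weights, double lambda1, double lambda2) {
        double l1 = 0.0; // running sum of absolute values (L1 term)
        double l2 = 0.0; // running sum of squares (L2 term)
        for (double w : weights) {
            l1 += Math.abs(w);
            l2 += w * w;
        }
        return lambda1 * l1 + lambda2 * l2;
    }

    /** Applies a single weight to both terms (assumed analogue of a single-lambda setting). */
    static double penalty(double[] weights, double lambda) {
        return penalty(weights, lambda, lambda);
    }

    public static void main(String[] args) {
        double[] weights = {1.0, -2.0, 3.0};
        // Here sum |w_i| = 6 and sum w_i^2 = 14.
        System.out.println(penalty(weights, 1.0, 1.0));
        System.out.println(penalty(weights, 0.5, 0.1));
    }
}
```

For the weights 1, -2, 3 with both lambdas set to 1, the penalty is 6 + 14 = 20. Raising \(\lambda_1\) pushes weights toward exact zeros, while raising \(\lambda_2\) shrinks large weights more strongly.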
-
-
Field Summary
-
Fields inherited from class ai.djl.training.evaluator.Evaluator
totalInstances
-
-
Constructor Summary
Constructors
Constructor Description
ElasticNetWeightDecay(NDList parameters)
Calculates Elastic Net weight decay for regularization.
ElasticNetWeightDecay(java.lang.String name, NDList parameters)
Calculates Elastic Net weight decay for regularization.
ElasticNetWeightDecay(java.lang.String name, NDList parameters, float lambda)
Calculates Elastic Net weight decay for regularization.
ElasticNetWeightDecay(java.lang.String name, NDList parameters, float lambda1, float lambda2)
Calculates Elastic Net weight decay for regularization.
-
Method Summary
Modifier and Type Method Description
NDArray evaluate(NDList label, NDList prediction)
Calculates the evaluation between the labels and the predictions.
-
Methods inherited from class ai.djl.training.loss.Loss
addAccumulator, elasticNetWeightedDecay, elasticNetWeightedDecay, elasticNetWeightedDecay, elasticNetWeightedDecay, getAccumulator, hingeLoss, hingeLoss, hingeLoss, l1Loss, l1Loss, l1Loss, l1WeightedDecay, l1WeightedDecay, l1WeightedDecay, l2Loss, l2Loss, l2Loss, l2WeightedDecay, l2WeightedDecay, l2WeightedDecay, maskedSoftmaxCrossEntropyLoss, maskedSoftmaxCrossEntropyLoss, maskedSoftmaxCrossEntropyLoss, resetAccumulator, sigmoidBinaryCrossEntropyLoss, sigmoidBinaryCrossEntropyLoss, sigmoidBinaryCrossEntropyLoss, softmaxCrossEntropyLoss, softmaxCrossEntropyLoss, softmaxCrossEntropyLoss, updateAccumulator
-
Methods inherited from class ai.djl.training.evaluator.Evaluator
checkLabelShapes, checkLabelShapes, getName
-
-
-
-
Constructor Detail
-
ElasticNetWeightDecay
public ElasticNetWeightDecay(NDList parameters)
Calculates Elastic Net weight decay for regularization.
Parameters:
parameters - holds the model weights that will be penalized
-
ElasticNetWeightDecay
public ElasticNetWeightDecay(java.lang.String name, NDList parameters)
Calculates Elastic Net weight decay for regularization.
Parameters:
name - the name of the penalty
parameters - holds the model weights that will be penalized
-
ElasticNetWeightDecay
public ElasticNetWeightDecay(java.lang.String name, NDList parameters, float lambda)
Calculates Elastic Net weight decay for regularization.
Parameters:
name - the name of the penalty
parameters - holds the model weights that will be penalized
lambda - the weight to apply to the penalty value, default 1 (both L1 and L2)
-
ElasticNetWeightDecay
public ElasticNetWeightDecay(java.lang.String name, NDList parameters, float lambda1, float lambda2)
Calculates Elastic Net weight decay for regularization.
Parameters:
name - the name of the penalty
parameters - holds the model weights that will be penalized
lambda1 - the weight to apply to the L1 penalty value, default 1
lambda2 - the weight to apply to the L2 penalty value, default 1
-
-