Class Loss
- java.lang.Object
  - ai.djl.training.evaluator.Evaluator
    - ai.djl.training.loss.Loss
Direct Known Subclasses:
AbstractCompositeLoss, BertMaskedLanguageModelLoss, BertNextSentenceLoss, ElasticNetWeightDecay, HingeLoss, IndexLoss, L1Loss, L1WeightDecay, L2Loss, L2WeightDecay, MaskedSoftmaxCrossEntropyLoss, QuantileL1Loss, SigmoidBinaryCrossEntropyLoss, SoftmaxCrossEntropyLoss
public abstract class Loss extends Evaluator
Loss functions (or cost functions) evaluate model predictions against the true labels so that an optimizer can minimize the result. Although every evaluator can measure the performance of a model, not every evaluator is suited to being used by an optimizer. Loss functions are usually non-negative, with a larger loss representing worse performance, and they are real-valued so that models can be compared precisely.
When creating a loss function, avoid making the loss depend on the batch size. For example, if you compute a loss per item in a batch and sum those losses, the result is numItemsInBatch * avgLoss. Take the mean of the per-item losses instead, which removes the batch-size factor. Otherwise the learning rate becomes hard to tune, because any change in the batch size rescales the loss; with a variable batch size this is even harder. For more details about the class internals, see Evaluator.
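To make the batch-size point concrete, here is a minimal plain-Java sketch (not DJL code; the class BatchReductionSketch and its methods are hypothetical) that reduces the same per-item losses by sum and by mean. Doubling the batch doubles the summed loss, while the mean stays put.

```java
/** Plain-Java illustration of sum vs. mean reduction of per-item losses. */
final class BatchReductionSketch {

    /** Sums the per-item losses; the result grows with the batch size. */
    static float sumReduce(float[] perItemLoss) {
        float total = 0f;
        for (float loss : perItemLoss) {
            total += loss;
        }
        return total;
    }

    /** Averages the per-item losses; the result does not depend on the batch size. */
    static float meanReduce(float[] perItemLoss) {
        return sumReduce(perItemLoss) / perItemLoss.length;
    }

    public static void main(String[] args) {
        float[] small = {0.5f, 0.5f};             // batch of 2
        float[] large = {0.5f, 0.5f, 0.5f, 0.5f}; // batch of 4, same per-item loss
        System.out.println(sumReduce(small) + " vs " + sumReduce(large));   // 1.0 vs 2.0
        System.out.println(meanReduce(small) + " vs " + meanReduce(large)); // 0.5 vs 0.5
    }
}
```

With the mean reduction, a learning rate tuned at one batch size still sees losses on the same scale after the batch size changes.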
Field Summary
Fields inherited from class ai.djl.training.evaluator.Evaluator: totalInstances
Constructor Summary
Loss(java.lang.String name): Base class for metric with abstract update methods.
Method Summary
- void addAccumulator(java.lang.String key): Adds an accumulator for the results of the evaluation with the given key.
- static ElasticNetWeightDecay elasticNetWeightedDecay(NDList parameters): Returns a new instance of ElasticNetWeightDecay with default weight and name.
- static ElasticNetWeightDecay elasticNetWeightedDecay(java.lang.String name, float weight1, float weight2, NDList parameters): Returns a new instance of ElasticNetWeightDecay.
- static ElasticNetWeightDecay elasticNetWeightedDecay(java.lang.String name, float weight, NDList parameters): Returns a new instance of ElasticNetWeightDecay.
- static ElasticNetWeightDecay elasticNetWeightedDecay(java.lang.String name, NDList parameters): Returns a new instance of ElasticNetWeightDecay with default weight.
- float getAccumulator(java.lang.String key): Returns the accumulated evaluator value.
- static HingeLoss hingeLoss(): Returns a new instance of HingeLoss with default arguments.
- static HingeLoss hingeLoss(java.lang.String name): Returns a new instance of HingeLoss with default arguments.
- static HingeLoss hingeLoss(java.lang.String name, int margin, float weight): Returns a new instance of HingeLoss with the given arguments.
- static L1Loss l1Loss(): Returns a new instance of L1Loss with default weight and batch axis.
- static L1Loss l1Loss(java.lang.String name): Returns a new instance of L1Loss with default weight and batch axis.
- static L1Loss l1Loss(java.lang.String name, float weight): Returns a new instance of L1Loss with the given weight.
- static L1WeightDecay l1WeightedDecay(NDList parameters): Returns a new instance of L1WeightDecay with default weight and name.
- static L1WeightDecay l1WeightedDecay(java.lang.String name, float weight, NDList parameters): Returns a new instance of L1WeightDecay.
- static L1WeightDecay l1WeightedDecay(java.lang.String name, NDList parameters): Returns a new instance of L1WeightDecay with default weight.
- static L2Loss l2Loss(): Returns a new instance of L2Loss with default weight and batch axis.
- static L2Loss l2Loss(java.lang.String name): Returns a new instance of L2Loss with default weight and batch axis.
- static L2Loss l2Loss(java.lang.String name, float weight): Returns a new instance of L2Loss with the given weight.
- static L2WeightDecay l2WeightedDecay(NDList parameters): Returns a new instance of L2WeightDecay with default weight and name.
- static L2WeightDecay l2WeightedDecay(java.lang.String name, float weight, NDList parameters): Returns a new instance of L2WeightDecay.
- static L2WeightDecay l2WeightedDecay(java.lang.String name, NDList parameters): Returns a new instance of L2WeightDecay with default weight.
- static MaskedSoftmaxCrossEntropyLoss maskedSoftmaxCrossEntropyLoss(): Returns a new instance of MaskedSoftmaxCrossEntropyLoss with default arguments.
- static MaskedSoftmaxCrossEntropyLoss maskedSoftmaxCrossEntropyLoss(java.lang.String name): Returns a new instance of MaskedSoftmaxCrossEntropyLoss with default arguments.
- static MaskedSoftmaxCrossEntropyLoss maskedSoftmaxCrossEntropyLoss(java.lang.String name, float weight, int classAxis, boolean sparseLabel, boolean fromLogit): Returns a new instance of MaskedSoftmaxCrossEntropyLoss with the given arguments.
- static QuantileL1Loss quantileL1Loss(float quantile): Returns a new instance of QuantileL1Loss with the given quantile.
- static QuantileL1Loss quantileL1Loss(java.lang.String name, float quantile): Returns a new instance of QuantileL1Loss with the given quantile.
- void resetAccumulator(java.lang.String key): Resets the evaluator value with the given key.
- static SigmoidBinaryCrossEntropyLoss sigmoidBinaryCrossEntropyLoss(): Returns a new instance of SigmoidBinaryCrossEntropyLoss with default arguments.
- static SigmoidBinaryCrossEntropyLoss sigmoidBinaryCrossEntropyLoss(java.lang.String name): Returns a new instance of SigmoidBinaryCrossEntropyLoss with default arguments.
- static SigmoidBinaryCrossEntropyLoss sigmoidBinaryCrossEntropyLoss(java.lang.String name, float weight, boolean fromSigmoid): Returns a new instance of SigmoidBinaryCrossEntropyLoss with the given arguments.
- static SoftmaxCrossEntropyLoss softmaxCrossEntropyLoss(): Returns a new instance of SoftmaxCrossEntropyLoss with default arguments.
- static SoftmaxCrossEntropyLoss softmaxCrossEntropyLoss(java.lang.String name): Returns a new instance of SoftmaxCrossEntropyLoss with default arguments.
- static SoftmaxCrossEntropyLoss softmaxCrossEntropyLoss(java.lang.String name, float weight, int classAxis, boolean sparseLabel, boolean fromLogit): Returns a new instance of SoftmaxCrossEntropyLoss with the given arguments.
- void updateAccumulator(java.lang.String key, NDList labels, NDList predictions): Updates the evaluator with the given key based on an NDList of labels and predictions.
Methods inherited from class ai.djl.training.evaluator.Evaluator
checkLabelShapes, checkLabelShapes, evaluate, getName
Method Detail
-
l1Loss
public static L1Loss l1Loss()
Returns a new instance of L1Loss with default weight and batch axis.
- Returns:
  a new instance of L1Loss
-
l1Loss
public static L1Loss l1Loss(java.lang.String name)
Returns a new instance of L1Loss with default weight and batch axis.
- Parameters:
  name - the name of the loss
- Returns:
  a new instance of L1Loss
-
l1Loss
public static L1Loss l1Loss(java.lang.String name, float weight)
Returns a new instance of L1Loss with the given weight.
- Parameters:
  name - the name of the loss
  weight - the weight to apply on the loss value, default 1
- Returns:
  a new instance of L1Loss
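The l1Loss factories only configure the loss. As a rough illustration of what such a loss computes, the following plain-Java sketch assumes the common definition weight * mean(|label - prediction|) over a flat array. It is not DJL's implementation, which operates on NDArray batches, and L1LossSketch is a hypothetical name.

```java
/** Plain-Java sketch of a weighted mean absolute error (assumed L1 definition). */
final class L1LossSketch {

    /** Assumed definition: weight * mean(|label - prediction|). */
    static float l1(float[] labels, float[] predictions, float weight) {
        if (labels.length != predictions.length || labels.length == 0) {
            throw new IllegalArgumentException("labels and predictions must be non-empty and equal in length");
        }
        float sum = 0f;
        for (int i = 0; i < labels.length; i++) {
            sum += Math.abs(labels[i] - predictions[i]);
        }
        // Dividing by the item count keeps the loss independent of batch size.
        return weight * sum / labels.length;
    }

    public static void main(String[] args) {
        float[] labels = {1f, 2f, 3f};
        float[] predictions = {1.5f, 2f, 2f};
        System.out.println(l1(labels, predictions, 1f)); // 0.5
    }
}
```

In DJL itself you would obtain the configured loss from a factory such as l1Loss(name, weight) rather than computing it by hand.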
-
quantileL1Loss
public static QuantileL1Loss quantileL1Loss(float quantile)
Returns a new instance of QuantileL1Loss with the given quantile.
- Parameters:
  quantile - the quantile position of the data to focus on
- Returns:
  a new instance of QuantileL1Loss
-
quantileL1Loss
public static QuantileL1Loss quantileL1Loss(java.lang.String name, float quantile)
Returns a new instance of QuantileL1Loss with the given quantile.
- Parameters:
  name - the name of the loss
  quantile - the quantile position of the data to focus on
- Returns:
  a new instance of QuantileL1Loss
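The quantile parameter controls how asymmetrically under- and over-predictions are penalized. A common formulation is the pinball loss, sketched below in plain Java under the assumption that QuantileL1Loss follows it: for e = label - prediction, each item contributes max(quantile * e, (quantile - 1) * e). DJL's exact scaling may differ, and QuantileLossSketch is a hypothetical name.

```java
/** Plain-Java sketch of the pinball (quantile) loss; assumed, not DJL's code. */
final class QuantileLossSketch {

    /** Pinball loss averaged over items; quantile is expected to lie in (0, 1). */
    static float pinball(float[] labels, float[] predictions, float quantile) {
        if (labels.length != predictions.length || labels.length == 0) {
            throw new IllegalArgumentException("labels and predictions must be non-empty and equal in length");
        }
        float sum = 0f;
        for (int i = 0; i < labels.length; i++) {
            float e = labels[i] - predictions[i];
            // Under-prediction (e > 0) costs quantile * e; over-prediction costs (1 - quantile) * |e|.
            sum += Math.max(quantile * e, (quantile - 1f) * e);
        }
        return sum / labels.length;
    }

    public static void main(String[] args) {
        // With quantile 0.9, predicting 2 too low costs more than predicting 2 too high.
        System.out.println(pinball(new float[] {10f}, new float[] {8f}, 0.9f));
        System.out.println(pinball(new float[] {8f}, new float[] {10f}, 0.9f));
    }
}
```

At quantile = 0.5 the penalty is symmetric and equals half the mean absolute error.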
-
l2Loss
public static L2Loss l2Loss()
Returns a new instance of L2Loss with default weight and batch axis.
- Returns:
  a new instance of L2Loss
-
l2Loss
public static L2Loss l2Loss(java.lang.String name)
Returns a new instance of L2Loss with default weight and batch axis.
- Parameters:
  name - the name of the loss
- Returns:
  a new instance of L2Loss
-
l2Loss
public static L2Loss l2Loss(java.lang.String name, float weight)
Returns a new instance of L2Loss with the given weight.
- Parameters:
  name - the name of the loss
  weight - the weight to apply on the loss value, default 1
- Returns:
  a new instance of L2Loss
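A comparable plain-Java sketch for L2Loss, assuming it computes weight * mean((label - prediction)^2). Some libraries define L2 loss with an extra factor of 1/2, so check the L2Loss class documentation before relying on a specific scale; L2LossSketch is a hypothetical name.

```java
/** Plain-Java sketch of a weighted mean squared error (assumed L2 definition). */
final class L2LossSketch {

    /** Assumed definition: weight * mean((label - prediction)^2). */
    static float l2(float[] labels, float[] predictions, float weight) {
        if (labels.length != predictions.length || labels.length == 0) {
            throw new IllegalArgumentException("labels and predictions must be non-empty and equal in length");
        }
        float sum = 0f;
        for (int i = 0; i < labels.length; i++) {
            float diff = labels[i] - predictions[i];
            sum += diff * diff; // squaring penalizes large errors more than L1 does
        }
        return weight * sum / labels.length;
    }

    public static void main(String[] args) {
        float[] labels = {0f, 0f};
        float[] predictions = {1f, 3f};
        System.out.println(l2(labels, predictions, 1f));   // 5.0
        System.out.println(l2(labels, predictions, 0.5f)); // 2.5
    }
}
```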
-
sigmoidBinaryCrossEntropyLoss
public static SigmoidBinaryCrossEntropyLoss sigmoidBinaryCrossEntropyLoss()
Returns a new instance of SigmoidBinaryCrossEntropyLoss with default arguments.
- Returns:
  a new instance of SigmoidBinaryCrossEntropyLoss
-
sigmoidBinaryCrossEntropyLoss
public static SigmoidBinaryCrossEntropyLoss sigmoidBinaryCrossEntropyLoss(java.lang.String name)
Returns a new instance of SigmoidBinaryCrossEntropyLoss with default arguments.
- Parameters:
  name - the name of the loss
- Returns:
  a new instance of SigmoidBinaryCrossEntropyLoss
-
sigmoidBinaryCrossEntropyLoss
public static SigmoidBinaryCrossEntropyLoss sigmoidBinaryCrossEntropyLoss(java.lang.String name, float weight, boolean fromSigmoid)
Returns a new instance of SigmoidBinaryCrossEntropyLoss with the given arguments.
- Parameters:
  name - the name of the loss
  weight - the weight to apply on the loss value, default 1
  fromSigmoid - whether the input is from the output of sigmoid, default false
- Returns:
  a new instance of SigmoidBinaryCrossEntropyLoss
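The fromSigmoid flag decides whether predictions are raw scores (logits) or probabilities that have already passed through a sigmoid. The plain-Java sketch below shows both paths, assuming the loss is the weighted mean binary cross entropy; the logit path uses the numerically stable identity max(x, 0) - x * y + log(1 + exp(-|x|)). It is not DJL's code (which may, for example, add a small epsilon inside the logarithm), and SigmoidBceSketch is a hypothetical name.

```java
/** Plain-Java sketch of binary cross entropy from logits or from probabilities. */
final class SigmoidBceSketch {

    /**
     * Weighted mean binary cross entropy. If fromSigmoid is true, predictions must be
     * probabilities strictly inside (0, 1); otherwise they are raw logits.
     */
    static float bce(float[] labels, float[] predictions, float weight, boolean fromSigmoid) {
        if (labels.length != predictions.length || labels.length == 0) {
            throw new IllegalArgumentException("labels and predictions must be non-empty and equal in length");
        }
        double sum = 0.0;
        for (int i = 0; i < labels.length; i++) {
            double y = labels[i];
            double x = predictions[i];
            if (fromSigmoid) {
                // -(y * log(p) + (1 - y) * log(1 - p)) with p = x
                sum -= y * Math.log(x) + (1 - y) * Math.log(1 - x);
            } else {
                // Equals BCE(sigmoid(x), y) but avoids overflow in exp for large |x|.
                sum += Math.max(x, 0.0) - x * y + Math.log1p(Math.exp(-Math.abs(x)));
            }
        }
        return (float) (weight * sum / labels.length);
    }

    public static void main(String[] args) {
        // A logit of 0 and a probability of 0.5 describe the same prediction,
        // so both lines print about 0.693 (ln 2) for a positive label.
        System.out.println(bce(new float[] {1f}, new float[] {0f}, 1f, false));
        System.out.println(bce(new float[] {1f}, new float[] {0.5f}, 1f, true));
    }
}
```

Feeding logits with fromSigmoid = false is generally preferable to applying a sigmoid first, because the stable form never takes the logarithm of a value that has rounded to 0 or 1.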
-
softmaxCrossEntropyLoss
public static SoftmaxCrossEntropyLoss softmaxCrossEntropyLoss()
Returns a new instance of SoftmaxCrossEntropyLoss with default arguments.
- Returns:
  a new instance of SoftmaxCrossEntropyLoss
-
softmaxCrossEntropyLoss
public static SoftmaxCrossEntropyLoss softmaxCrossEntropyLoss(java.lang.String name)
Returns a new instance of SoftmaxCrossEntropyLoss with default arguments.
- Parameters:
  name - the name of the loss
- Returns:
  a new instance of SoftmaxCrossEntropyLoss
-
softmaxCrossEntropyLoss
public static SoftmaxCrossEntropyLoss softmaxCrossEntropyLoss(java.lang.String name, float weight, int classAxis, boolean sparseLabel, boolean fromLogit)
Returns a new instance of SoftmaxCrossEntropyLoss with the given arguments.
- Parameters:
  name - the name of the loss
  weight - the weight to apply on the loss value, default 1
  classAxis - the axis that represents the class probabilities, default -1
  sparseLabel - whether labels are an integer array or probabilities, default true
  fromLogit - whether the predictions are log probabilities or un-normalized numbers
- Returns:
  a new instance of SoftmaxCrossEntropyLoss
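To show what sparseLabel changes, here is a plain-Java sketch that assumes the class axis is the last axis and that predictions are un-normalized scores. It applies a numerically stable log-softmax, then either picks the log-probability of an integer class label (sparse labels) or takes the dot product with a probability vector (dense labels). It is not DJL's implementation, and SoftmaxCeSketch is a hypothetical name.

```java
/** Plain-Java sketch of softmax cross entropy with sparse or dense labels. */
final class SoftmaxCeSketch {

    /** Numerically stable log-softmax of one row of un-normalized scores. */
    static double[] logSoftmax(float[] scores) {
        double max = Double.NEGATIVE_INFINITY;
        for (float s : scores) {
            max = Math.max(max, s);
        }
        double sumExp = 0.0;
        for (float s : scores) {
            sumExp += Math.exp(s - max); // subtracting max avoids overflow
        }
        double logNormalizer = max + Math.log(sumExp);
        double[] out = new double[scores.length];
        for (int j = 0; j < scores.length; j++) {
            out[j] = scores[j] - logNormalizer;
        }
        return out;
    }

    /** Sparse labels: one integer class index per row. */
    static float sparseCe(float[][] scores, int[] labels, float weight) {
        double sum = 0.0;
        for (int i = 0; i < scores.length; i++) {
            sum -= logSoftmax(scores[i])[labels[i]];
        }
        return (float) (weight * sum / scores.length);
    }

    /** Dense labels: one probability (or one-hot) vector per row. */
    static float denseCe(float[][] scores, float[][] labelProbs, float weight) {
        double sum = 0.0;
        for (int i = 0; i < scores.length; i++) {
            double[] logP = logSoftmax(scores[i]);
            for (int j = 0; j < logP.length; j++) {
                sum -= labelProbs[i][j] * logP[j];
            }
        }
        return (float) (weight * sum / scores.length);
    }
}
```

A one-hot dense label gives the same loss as the equivalent sparse index; with four equal scores, either form yields ln 4.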
-
maskedSoftmaxCrossEntropyLoss
public static MaskedSoftmaxCrossEntropyLoss maskedSoftmaxCrossEntropyLoss()
Returns a new instance of MaskedSoftmaxCrossEntropyLoss with default arguments.
- Returns:
  a new instance of MaskedSoftmaxCrossEntropyLoss
-
maskedSoftmaxCrossEntropyLoss
public static MaskedSoftmaxCrossEntropyLoss maskedSoftmaxCrossEntropyLoss(java.lang.String name)
Returns a new instance of MaskedSoftmaxCrossEntropyLoss with default arguments.
- Parameters:
  name - the name of the loss
- Returns:
  a new instance of MaskedSoftmaxCrossEntropyLoss
-
maskedSoftmaxCrossEntropyLoss
public static MaskedSoftmaxCrossEntropyLoss maskedSoftmaxCrossEntropyLoss(java.lang.String name, float weight, int classAxis, boolean sparseLabel, boolean fromLogit)
Returns a new instance of MaskedSoftmaxCrossEntropyLoss with the given arguments.
- Parameters:
  name - the name of the loss
  weight - the weight to apply on the loss value, default 1
  classAxis - the axis that represents the class probabilities, default -1
  sparseLabel - whether labels are an integer array or probabilities, default true
  fromLogit - whether the predictions are log probabilities or un-normalized numbers
- Returns:
  a new instance of MaskedSoftmaxCrossEntropyLoss
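The masked variant is aimed at padded sequences, where padding steps should not contribute to the loss. The plain-Java sketch below assumes each sequence comes with a valid length, skips positions at or beyond it, and averages over the remaining positions. The library may normalize differently (for example, over all positions), weight and classAxis are omitted for brevity, and MaskedSoftmaxCeSketch and its array layout are hypothetical.

```java
/** Plain-Java sketch of softmax cross entropy that ignores padded time steps. */
final class MaskedSoftmaxCeSketch {

    /** Log-probability of class {@code index} under a stable softmax of {@code scores}. */
    static double logSoftmaxAt(float[] scores, int index) {
        double max = Double.NEGATIVE_INFINITY;
        for (float s : scores) {
            max = Math.max(max, s);
        }
        double sumExp = 0.0;
        for (float s : scores) {
            sumExp += Math.exp(s - max);
        }
        return scores[index] - (max + Math.log(sumExp));
    }

    /**
     * scores[b][t][c]: batch, time step, class; labels[b][t]: class index;
     * validLengths[b]: number of real (unpadded) steps in sequence b.
     */
    static float maskedCe(float[][][] scores, int[][] labels, int[] validLengths) {
        double sum = 0.0;
        int count = 0;
        for (int b = 0; b < scores.length; b++) {
            // Steps t >= validLengths[b] are padding and are skipped (masked out).
            for (int t = 0; t < validLengths[b]; t++) {
                sum -= logSoftmaxAt(scores[b][t], labels[b][t]);
                count++;
            }
        }
        return count == 0 ? 0f : (float) (sum / count);
    }
}
```

In the test data, the padded steps would each add a loss of about 200 if they were counted, so masking them visibly changes the result.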
-
hingeLoss
public static HingeLoss hingeLoss()
Returns a new instance of HingeLoss with default arguments.
- Returns:
  a new instance of HingeLoss
-
hingeLoss
public static HingeLoss hingeLoss(java.lang.String name)
Returns a new instance of HingeLoss with default arguments.
- Parameters:
  name - the name of the loss
- Returns:
  a new instance of HingeLoss
-
hingeLoss
public static HingeLoss hingeLoss(java.lang.String name, int margin, float weight)
Returns a new instance of HingeLoss with the given arguments.
- Parameters:
  name - the name of the loss
  margin - the margin in hinge loss, default 1
  weight - the weight to apply on the loss value, default 1
- Returns:
  a new instance of HingeLoss
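The usual hinge loss is max(0, margin - label * prediction), with labels encoded as -1 or +1. The plain-Java sketch below assumes that encoding and a weighted mean over items; check the HingeLoss class documentation for the label convention DJL actually expects. HingeLossSketch is a hypothetical name.

```java
/** Plain-Java sketch of a weighted mean hinge loss with labels in {-1, +1}. */
final class HingeLossSketch {

    /** Assumed definition: weight * mean(max(0, margin - label * prediction)). */
    static float hinge(float[] labels, float[] predictions, int margin, float weight) {
        if (labels.length != predictions.length || labels.length == 0) {
            throw new IllegalArgumentException("labels and predictions must be non-empty and equal in length");
        }
        float sum = 0f;
        for (int i = 0; i < labels.length; i++) {
            // Zero loss once the prediction has the right sign with at least `margin` to spare.
            sum += Math.max(0f, margin - labels[i] * predictions[i]);
        }
        return weight * sum / labels.length;
    }

    public static void main(String[] args) {
        float[] labels = {1f, -1f};
        float[] predictions = {2f, 0.5f}; // first is confidently right, second is on the wrong side
        System.out.println(hinge(labels, predictions, 1, 1f)); // 0.75
    }
}
```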
-
l1WeightedDecay
public static L1WeightDecay l1WeightedDecay(NDList parameters)
Returns a new instance of L1WeightDecay with default weight and name.
- Parameters:
  parameters - holds the model weights that will be penalized
- Returns:
  a new instance of L1WeightDecay
-
l1WeightedDecay
public static L1WeightDecay l1WeightedDecay(java.lang.String name, NDList parameters)
Returns a new instance of L1WeightDecay with default weight.
- Parameters:
  name - the name of the weight decay
  parameters - holds the model weights that will be penalized
- Returns:
  a new instance of L1WeightDecay
-
l1WeightedDecay
public static L1WeightDecay l1WeightedDecay(java.lang.String name, float weight, NDList parameters)
Returns a new instance of L1WeightDecay.
- Parameters:
  name - the name of the weight decay
  weight - the weight to apply on the weight decay value, default 1
  parameters - holds the model weights that will be penalized
- Returns:
  a new instance of L1WeightDecay
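Weight-decay losses penalize the model parameters themselves rather than the prediction error, which is why these factories take an NDList of parameters. Here is a plain-Java sketch of an L1 penalty, assuming it is weight times the sum of absolute parameter values; flat float arrays stand in for the NDList, and L1DecaySketch is a hypothetical name.

```java
/** Plain-Java sketch of an L1 weight penalty over several parameter arrays. */
final class L1DecaySketch {

    /** Assumed L1 penalty: weight * sum over all parameter arrays of |w|. */
    static float l1Penalty(float[][] parameters, float weight) {
        float sum = 0f;
        for (float[] tensor : parameters) {
            for (float w : tensor) {
                sum += Math.abs(w);
            }
        }
        return weight * sum;
    }

    public static void main(String[] args) {
        float[][] params = {{1f, -2f}, {0.5f}};
        System.out.println(l1Penalty(params, 1f)); // 3.5
    }
}
```

Because the L1 term grows linearly with each weight's magnitude, it tends to drive small weights to exactly zero, which is why it is associated with sparse models.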
-
l2WeightedDecay
public static L2WeightDecay l2WeightedDecay(NDList parameters)
Returns a new instance of L2WeightDecay with default weight and name.
- Parameters:
  parameters - holds the model weights that will be penalized
- Returns:
  a new instance of L2WeightDecay
-
l2WeightedDecay
public static L2WeightDecay l2WeightedDecay(java.lang.String name, NDList parameters)
Returns a new instance of L2WeightDecay with default weight.
- Parameters:
  name - the name of the weight decay
  parameters - holds the model weights that will be penalized
- Returns:
  a new instance of L2WeightDecay
-
l2WeightedDecay
public static L2WeightDecay l2WeightedDecay(java.lang.String name, float weight, NDList parameters)
Returns a new instance of L2WeightDecay.
- Parameters:
  name - the name of the weight decay
  weight - the weight to apply on the weight decay value, default 1
  parameters - holds the model weights that will be penalized
- Returns:
  a new instance of L2WeightDecay
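The L2 counterpart can be sketched the same way, assuming the penalty is weight times the sum of squared parameter values (some formulations include an extra factor of 1/2). Flat float arrays stand in for the NDList, and L2DecaySketch is a hypothetical name.

```java
/** Plain-Java sketch of an L2 weight penalty over several parameter arrays. */
final class L2DecaySketch {

    /** Assumed L2 penalty: weight * sum over all parameter arrays of w^2. */
    static float l2Penalty(float[][] parameters, float weight) {
        float sum = 0f;
        for (float[] tensor : parameters) {
            for (float w : tensor) {
                sum += w * w; // large weights dominate; small ones shrink but rarely reach zero
            }
        }
        return weight * sum;
    }

    public static void main(String[] args) {
        float[][] params = {{1f, -2f}, {0.5f}};
        System.out.println(l2Penalty(params, 1f)); // 5.25
    }
}
```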
-
elasticNetWeightedDecay
public static ElasticNetWeightDecay elasticNetWeightedDecay(NDList parameters)
Returns a new instance of ElasticNetWeightDecay with default weight and name.
- Parameters:
  parameters - holds the model weights that will be penalized
- Returns:
  a new instance of ElasticNetWeightDecay
-
elasticNetWeightedDecay
public static ElasticNetWeightDecay elasticNetWeightedDecay(java.lang.String name, NDList parameters)
Returns a new instance of ElasticNetWeightDecay with default weight.
- Parameters:
  name - the name of the weight decay
  parameters - holds the model weights that will be penalized
- Returns:
  a new instance of ElasticNetWeightDecay
-
elasticNetWeightedDecay
public static ElasticNetWeightDecay elasticNetWeightedDecay(java.lang.String name, float weight, NDList parameters)
Returns a new instance of ElasticNetWeightDecay.
- Parameters:
  name - the name of the weight decay
  weight - the weight to apply on the weight decay values, default 1
  parameters - holds the model weights that will be penalized
- Returns:
  a new instance of ElasticNetWeightDecay
-
elasticNetWeightedDecay
public static ElasticNetWeightDecay elasticNetWeightedDecay(java.lang.String name, float weight1, float weight2, NDList parameters)
Returns a new instance of ElasticNetWeightDecay.
- Parameters:
  name - the name of the weight decay
  weight1 - the weight to apply on the L1 weight decay value, default 1
  weight2 - the weight to apply on the L2 weight decay value, default 1
  parameters - holds the model weights that will be penalized
- Returns:
  a new instance of ElasticNetWeightDecay
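Elastic net combines both penalties. Assuming the four-argument overload computes weight1 * sum(|w|) + weight2 * sum(w^2), and that the single-weight overload applies the same weight to both terms, a plain-Java sketch looks like this (ElasticNetSketch is a hypothetical name, and flat float arrays stand in for the NDList):

```java
/** Plain-Java sketch of an elastic-net penalty (L1 plus L2 terms). */
final class ElasticNetSketch {

    /** Assumed penalty: weight1 * sum(|w|) + weight2 * sum(w^2). */
    static float elasticNet(float[][] parameters, float weight1, float weight2) {
        float l1 = 0f;
        float l2 = 0f;
        for (float[] tensor : parameters) {
            for (float w : tensor) {
                l1 += Math.abs(w);
                l2 += w * w;
            }
        }
        return weight1 * l1 + weight2 * l2;
    }

    /** Single-weight variant, assumed to apply one weight to both terms. */
    static float elasticNet(float[][] parameters, float weight) {
        return elasticNet(parameters, weight, weight);
    }

    public static void main(String[] args) {
        float[][] params = {{1f, -2f}, {0.5f}};
        System.out.println(elasticNet(params, 1f));       // 8.75
        System.out.println(elasticNet(params, 0.5f, 2f)); // 12.25
    }
}
```

Separate weights let you trade the sparsity pressure of the L1 term against the smoother shrinkage of the L2 term.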
-
addAccumulator
public void addAccumulator(java.lang.String key)
Adds an accumulator for the results of the evaluation with the given key.
- Specified by:
  addAccumulator in class Evaluator
- Parameters:
  key - the key for the new accumulator
-
updateAccumulator
public void updateAccumulator(java.lang.String key, NDList labels, NDList predictions)
Updates the evaluator with the given key based on an NDList of labels and predictions.
This is a synchronized operation. You should only call it at the end of a batch or epoch.
- Specified by:
  updateAccumulator in class Evaluator
- Parameters:
  key - the key of the accumulator to update
  labels - an NDList of labels
  predictions - an NDList of predictions
-
resetAccumulator
public void resetAccumulator(java.lang.String key)
Resets the evaluator value with the given key.
- Specified by:
  resetAccumulator in class Evaluator
- Parameters:
  key - the key of the accumulator to reset
-
getAccumulator
public float getAccumulator(java.lang.String key)
Returns the accumulated evaluator value.
- Specified by:
  getAccumulator in class Evaluator
- Parameters:
  key - the key of the accumulator to get
- Returns:
  the accumulated value
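The four accumulator methods follow an add, update, get, and reset lifecycle keyed by a string such as a training or validation tag. The sketch below models that lifecycle in plain Java, assuming getAccumulator returns the running average of the losses recorded since the last reset. It takes an already computed batch loss instead of NDList labels and predictions, and AccumulatorSketch is a hypothetical name, not DJL's class.

```java
import java.util.HashMap;
import java.util.Map;

/** Plain-Java sketch of a keyed loss accumulator (running average per key). */
final class AccumulatorSketch {
    private final Map<String, Float> totalLoss = new HashMap<>();
    private final Map<String, Long> totalInstances = new HashMap<>();

    /** Starts tracking a new key with an empty total. */
    synchronized void addAccumulator(String key) {
        totalLoss.put(key, 0f);
        totalInstances.put(key, 0L);
    }

    /** Records one batch loss; synchronized to mirror the note on updateAccumulator. */
    synchronized void updateAccumulator(String key, float batchLoss) {
        if (!totalLoss.containsKey(key)) {
            throw new IllegalArgumentException("No accumulator for key " + key);
        }
        totalLoss.merge(key, batchLoss, Float::sum);
        totalInstances.merge(key, 1L, Long::sum);
    }

    /** Clears the total for a key, e.g. at the start of a new epoch. */
    synchronized void resetAccumulator(String key) {
        totalLoss.put(key, 0f);
        totalInstances.put(key, 0L);
    }

    /** Returns the average recorded loss, or NaN if nothing was recorded since the last reset. */
    synchronized float getAccumulator(String key) {
        Long count = totalInstances.get(key);
        if (count == null) {
            throw new IllegalArgumentException("No accumulator for key " + key);
        }
        return count == 0 ? Float.NaN : totalLoss.get(key) / count;
    }

    public static void main(String[] args) {
        AccumulatorSketch acc = new AccumulatorSketch();
        acc.addAccumulator("train");
        acc.updateAccumulator("train", 1.0f);
        acc.updateAccumulator("train", 3.0f);
        System.out.println(acc.getAccumulator("train")); // 2.0
        acc.resetAccumulator("train");
    }
}
```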