Package ai.djl.training.loss
Class AbstractCompositeLoss
java.lang.Object
  ai.djl.training.evaluator.Evaluator
    ai.djl.training.loss.Loss
      ai.djl.training.loss.AbstractCompositeLoss
- Direct Known Subclasses:
  BertPretrainingLoss, SimpleCompositeLoss, SingleShotDetectionLoss
AbstractCompositeLoss is a Loss class that can combine other Losses together to make a larger loss.
The AbstractCompositeLoss is designed to be extended for more complicated composite losses. For simpler use cases, consider using SimpleCompositeLoss.
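The composite idea can be sketched without DJL. The plain-Java analogue below is hypothetical: `ScalarLoss`, `CompositeLossSketch` and `L1PlusL2Loss` are illustrative names, `float[]` stands in for `NDList`, and combining the components by plain summation is an assumption, not a statement about DJL's implementation. It mirrors the documented shape: a protected `components` list, an abstract `inputForComponent`, and a `getComponents` accessor.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.Map;

// Hypothetical stand-in for a loss: reduces (labels, predictions) to one scalar.
interface ScalarLoss {
    float evaluate(float[] labels, float[] predictions);
}

// Hypothetical analogue of AbstractCompositeLoss (not the DJL class).
abstract class CompositeLossSketch implements ScalarLoss {
    // Subclasses register their component losses here.
    protected final List<ScalarLoss> components = new ArrayList<>();

    // Analogue of inputForComponent: picks the (labels, predictions) for component i.
    protected abstract Map.Entry<float[], float[]> inputForComponent(
            int componentIndex, float[] labels, float[] predictions);

    public List<ScalarLoss> getComponents() {
        return components;
    }

    // Assumption: the composite value is the plain sum of the component values.
    @Override
    public float evaluate(float[] labels, float[] predictions) {
        float total = 0f;
        for (int i = 0; i < components.size(); i++) {
            Map.Entry<float[], float[]> in = inputForComponent(i, labels, predictions);
            total += components.get(i).evaluate(in.getKey(), in.getValue());
        }
        return total;
    }
}

// Example subclass: mean absolute error plus mean squared error on the same inputs.
class L1PlusL2Loss extends CompositeLossSketch {
    L1PlusL2Loss() {
        components.add(L1PlusL2Loss::meanAbsoluteError);
        components.add(L1PlusL2Loss::meanSquaredError);
    }

    @Override
    protected Map.Entry<float[], float[]> inputForComponent(
            int componentIndex, float[] labels, float[] predictions) {
        // Every component sees the full inputs in this example.
        return Map.entry(labels, predictions);
    }

    static float meanAbsoluteError(float[] y, float[] p) {
        float sum = 0f;
        for (int i = 0; i < y.length; i++) {
            sum += Math.abs(y[i] - p[i]);
        }
        return sum / y.length;
    }

    static float meanSquaredError(float[] y, float[] p) {
        float sum = 0f;
        for (int i = 0; i < y.length; i++) {
            float d = y[i] - p[i];
            sum += d * d;
        }
        return sum / y.length;
    }
}
```

With labels `{1, 2, 3}` and predictions `{1, 2, 5}`, the two terms are 2/3 and 4/3, so the sketch's composite value is 2. Registering components in the subclass constructor and keeping routing in one abstract method means the base class never needs to know what each component consumes.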
Field Summary
Fields inherited from class ai.djl.training.evaluator.Evaluator
totalInstances
Constructor Summary
- AbstractCompositeLoss(String name)
  Constructs a composite loss with the given name.
Method Summary
- void addAccumulator(String key)
  Adds an accumulator for the results of the evaluation with the given key.
- NDArray evaluate(NDList labels, NDList predictions)
  Calculates the evaluation between the labels and the predictions.
- float getAccumulator(String key)
  Returns the accumulated evaluator value.
- List<Loss> getComponents()
  Returns the component losses that make up the composite loss.
- protected abstract ai.djl.util.Pair<NDList,NDList> inputForComponent(int componentIndex, NDList labels, NDList predictions)
  Returns the inputs for computing the loss of a component loss.
- void resetAccumulator(String key)
  Resets the evaluator value with the given key.
- void updateAccumulators(String[] keys, NDList labels, NDList predictions)
  Updates the evaluator with the given keys based on an NDList of labels and predictions.
Methods inherited from class ai.djl.training.loss.Loss
elasticNetWeightedDecay, elasticNetWeightedDecay, elasticNetWeightedDecay, elasticNetWeightedDecay, hingeLoss, hingeLoss, hingeLoss, l1Loss, l1Loss, l1Loss, l1WeightedDecay, l1WeightedDecay, l1WeightedDecay, l2Loss, l2Loss, l2Loss, l2WeightedDecay, l2WeightedDecay, l2WeightedDecay, maskedSoftmaxCrossEntropyLoss, maskedSoftmaxCrossEntropyLoss, maskedSoftmaxCrossEntropyLoss, quantileL1Loss, quantileL1Loss, sigmoidBinaryCrossEntropyLoss, sigmoidBinaryCrossEntropyLoss, sigmoidBinaryCrossEntropyLoss, softmaxCrossEntropyLoss, softmaxCrossEntropyLoss, softmaxCrossEntropyLoss, updateAccumulator
Methods inherited from class ai.djl.training.evaluator.Evaluator
checkLabelShapes, checkLabelShapes, getName
Field Details
components
Constructor Details
AbstractCompositeLoss(String name)
Constructs a composite loss with the given name.
- Parameters:
  name - the display name of the loss
Method Details
inputForComponent
protected abstract ai.djl.util.Pair<NDList,NDList> inputForComponent(int componentIndex, NDList labels, NDList predictions)
Returns the inputs for computing the loss of a component loss.
- Parameters:
  componentIndex - the index of the component loss
  labels - the label input to the composite loss
  predictions - the predictions input to the composite loss
- Returns:
  a pair of the (labels, predictions) inputs to the component loss
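A common way to implement this routing is by position: component `i` receives only the `i`-th label and the `i`-th prediction, so that, for example, one output head feeds a classification term and another feeds a regression term. The sketch below is hypothetical: `IndexRoutingSketch` is an illustrative name, `List<float[]>` stands in for `NDList`, and `Map.Entry` stands in for `ai.djl.util.Pair`.

```java
import java.util.List;
import java.util.Map;

// Hypothetical sketch: List<float[]> stands in for NDList, Map.Entry for Pair.
class IndexRoutingSketch {
    // Component i sees only the i-th label array and the i-th prediction array.
    // List.get throws IndexOutOfBoundsException if an index has no matching entry.
    static Map.Entry<List<float[]>, List<float[]>> inputForComponent(
            int componentIndex, List<float[]> labels, List<float[]> predictions) {
        return Map.entry(
                List.of(labels.get(componentIndex)),
                List.of(predictions.get(componentIndex)));
    }
}
```

Position-based routing keeps each component loss unaware of the others; a subclass that needs several outputs for one component would return a longer list instead.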
getComponents
Returns the component losses that make up the composite loss.
- Returns:
  the component losses that make up the composite loss
evaluate
Calculates the evaluation between the labels and the predictions.
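If the components are combined by plain summation (an assumption; the class description only says they are combined into a larger loss), the evaluation for labels $y$ and predictions $\hat{y}$ with $n$ components can be written as:

$$
\mathcal{L}(y, \hat{y}) = \sum_{i=0}^{n-1} \mathcal{L}_i\left(y^{(i)}, \hat{y}^{(i)}\right),
\qquad
\left(y^{(i)}, \hat{y}^{(i)}\right) = \mathrm{inputForComponent}(i, y, \hat{y})
$$

Here $\mathcal{L}_i$ is the $i$-th entry of getComponents(), and each component is evaluated only on the pair that inputForComponent returns for it.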
addAccumulator
Adds an accumulator for the results of the evaluation with the given key.
- Overrides:
  addAccumulator in class Loss
- Parameters:
  key - the key for the new accumulator
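Because this method overrides the one in Loss, a composite can do more than register the key on itself. One plausible design, assumed here rather than taken from DJL, is to register the same key on every component so each component's contribution can later be read under that key. `AccumulatingLoss` and `CompositeAccumulatingLoss` are hypothetical classes.

```java
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

// Hypothetical keyed-accumulator base class; not DJL's Loss.
class AccumulatingLoss {
    // key -> running total of the loss
    protected final Map<String, Float> totals = new HashMap<>();

    void addAccumulator(String key) {
        totals.put(key, 0f);
    }

    boolean hasAccumulator(String key) {
        return totals.containsKey(key);
    }
}

// Assumption: the composite registers the key on itself and on every component.
class CompositeAccumulatingLoss extends AccumulatingLoss {
    final List<AccumulatingLoss> components = new ArrayList<>();

    @Override
    void addAccumulator(String key) {
        super.addAccumulator(key);
        for (AccumulatingLoss component : components) {
            component.addAccumulator(key);
        }
    }
}
```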
updateAccumulators
Updates the evaluator with the given keys based on an NDList of labels and predictions.
This is a synchronized operation. You should only call it at the end of a batch or epoch.
This is an alternative to Evaluator.updateAccumulator(String, NDList, NDList) that may be more efficient when updating multiple accumulators at once.
- Overrides:
  updateAccumulators in class Loss
- Parameters:
  keys - the keys of all the accumulators to update
  labels - an NDList of labels
  predictions - an NDList of predictions
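The efficiency note comes from computing the loss once per batch and adding the result to every requested key, instead of recomputing it once per key. The hypothetical `MultiKeyAccumulator` below illustrates this; its squared-error `batchLossSum` is an illustrative stand-in for `evaluate`, and the `evaluations` counter exists only to make the single computation observable.

```java
import java.util.HashMap;
import java.util.Map;

// Hypothetical sketch: one loss computation feeds several accumulators.
class MultiKeyAccumulator {
    private final Map<String, Float> totals = new HashMap<>();
    private final Map<String, Long> instances = new HashMap<>();
    int evaluations = 0; // how many times the (possibly expensive) loss was computed

    void addAccumulator(String key) {
        totals.put(key, 0f);
        instances.put(key, 0L);
    }

    // Stand-in for evaluate(): summed squared error over the batch.
    private float batchLossSum(float[] labels, float[] predictions) {
        evaluations++;
        float sum = 0f;
        for (int i = 0; i < labels.length; i++) {
            float d = labels[i] - predictions[i];
            sum += d * d;
        }
        return sum;
    }

    // Computes the batch loss once, then adds it to every requested key.
    void updateAccumulators(String[] keys, float[] labels, float[] predictions) {
        float sum = batchLossSum(labels, predictions);
        for (String key : keys) {
            totals.merge(key, sum, Float::sum);
            instances.merge(key, (long) labels.length, Long::sum);
        }
    }

    // Mean loss per instance for the given key.
    float getAccumulator(String key) {
        return totals.get(key) / instances.get(key);
    }
}
```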
resetAccumulator
Resets the evaluator value with the given key.
- Overrides:
  resetAccumulator in class Loss
- Parameters:
  key - the key of the accumulator to reset
getAccumulator
Returns the accumulated evaluator value.
- Overrides:
  getAccumulator in class Loss
- Parameters:
  key - the key of the accumulator to get
- Returns:
  the accumulated value
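A typical lifecycle is to reset an accumulator at the start of an epoch, update it after each batch, and read it at the end. The sketch below assumes the accumulated value is the mean loss per instance since the last reset; `RunningMeanAccumulator` and its `update` method are hypothetical, not part of DJL.

```java
import java.util.HashMap;
import java.util.Map;

// Hypothetical running-mean accumulator illustrating reset/update/get across epochs.
class RunningMeanAccumulator {
    private final Map<String, Float> totals = new HashMap<>();
    private final Map<String, Long> instances = new HashMap<>();

    void addAccumulator(String key) {
        resetAccumulator(key);
    }

    // Clears the running total and instance count, e.g. at the start of an epoch.
    void resetAccumulator(String key) {
        totals.put(key, 0f);
        instances.put(key, 0L);
    }

    // Adds one batch, given the loss of each instance in it.
    void update(String key, float[] perInstanceLoss) {
        float sum = 0f;
        for (float v : perInstanceLoss) {
            sum += v;
        }
        totals.merge(key, sum, Float::sum);
        instances.merge(key, (long) perInstanceLoss.length, Long::sum);
    }

    // Assumption: the accumulated value is the mean loss per instance since the last reset.
    float getAccumulator(String key) {
        long n = instances.get(key);
        if (n == 0) {
            throw new IllegalStateException("no instances accumulated for key " + key);
        }
        return totals.get(key) / n;
    }
}
```

Throwing on an empty accumulator is a choice made for this sketch so that a read before any update fails loudly instead of returning NaN.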