public abstract class AbstractCompositeLoss extends Loss
AbstractCompositeLoss is a Loss class that can combine other Losses together to make a larger loss.
The AbstractCompositeLoss is designed to be extended for more complicated composite losses. For simpler use cases, consider using the SimpleCompositeLoss.
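Because the class is meant to be subclassed, the shape of a subclass can be sketched without any DJL dependency. The sketch below is a plain-Java analogue, not DJL code. Its assumptions: double[] stands in for NDList, the hypothetical ComponentLoss interface stands in for Loss, a two-element double[][] stands in for the Pair returned by inputForComponent, and evaluate is assumed to combine the components by plain summation.

```java
import java.util.ArrayList;
import java.util.List;

/**
 * Plain-Java analogue of extending a composite loss. Hypothetical types:
 * double[] stands in for NDList and ComponentLoss stands in for Loss.
 * This is not DJL's API.
 */
final class CompositeLossPattern {

    /** Hypothetical stand-in for a loss: maps (labels, predictions) to a scalar. */
    interface ComponentLoss {
        double evaluate(double[] labels, double[] predictions);
    }

    /** Mean absolute error over all elements. */
    static final ComponentLoss MEAN_ABS_ERROR = (labels, predictions) -> {
        double sum = 0;
        for (int i = 0; i < labels.length; i++) {
            sum += Math.abs(labels[i] - predictions[i]);
        }
        return sum / labels.length;
    };

    /** Mean squared error over all elements. */
    static final ComponentLoss MEAN_SQUARED_ERROR = (labels, predictions) -> {
        double sum = 0;
        for (int i = 0; i < labels.length; i++) {
            double d = labels[i] - predictions[i];
            sum += d * d;
        }
        return sum / labels.length;
    };

    /** Abstract composite: subclasses fill components and decide what each one sees. */
    abstract static class Composite implements ComponentLoss {
        protected final String name;
        protected final List<ComponentLoss> components = new ArrayList<>();

        protected Composite(String name) {
            this.name = name;
        }

        /** Returns {labels, predictions} for one component (stand-in for a Pair). */
        protected abstract double[][] inputForComponent(
                int componentIndex, double[] labels, double[] predictions);

        /** Assumed combination rule: the sum of the component losses. */
        @Override
        public double evaluate(double[] labels, double[] predictions) {
            double total = 0;
            for (int i = 0; i < components.size(); i++) {
                double[][] inputs = inputForComponent(i, labels, predictions);
                total += components.get(i).evaluate(inputs[0], inputs[1]);
            }
            return total;
        }
    }

    /** Example subclass: both components receive the full, unchanged inputs. */
    static final class AbsPlusSquared extends Composite {
        AbsPlusSquared() {
            super("AbsPlusSquared");
            components.add(MEAN_ABS_ERROR);
            components.add(MEAN_SQUARED_ERROR);
        }

        @Override
        protected double[][] inputForComponent(
                int componentIndex, double[] labels, double[] predictions) {
            return new double[][] {labels, predictions};
        }
    }

    public static void main(String[] args) {
        double[] labels = {1.0, 2.0};
        double[] predictions = {2.0, 4.0};
        // Mean absolute error 1.5 plus mean squared error 2.5.
        System.out.println(new AbsPlusSquared().evaluate(labels, predictions)); // 4.0
    }
}
```

Here both components see the full inputs; the inputForComponent hook is where a subclass would instead give each component its own slice.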
Modifier and Type | Field and Description |
---|---|
protected java.util.List<Loss> | components |

Inherited fields: totalInstances
Constructor and Description |
---|
AbstractCompositeLoss(java.lang.String name): Constructs a composite loss with the given name. |
Modifier and Type | Method and Description |
---|---|
void | addAccumulator(java.lang.String key): Adds an accumulator for the results of the evaluation with the given key. |
NDArray | evaluate(NDList labels, NDList predictions): Calculates the composite loss between the labels and the predictions. |
float | getAccumulator(java.lang.String key): Returns the accumulated value for the given key. |
java.util.List<Loss> | getComponents(): Returns the component losses that make up the composite loss. |
protected abstract ai.djl.util.Pair<NDList,NDList> | inputForComponent(int componentIndex, NDList labels, NDList predictions): Returns the inputs used to compute the loss of a single component. |
void | resetAccumulator(java.lang.String key): Resets the accumulator with the given key. |
void | updateAccumulator(java.lang.String key, NDList labels, NDList predictions): Updates the accumulator with the given key using an NDList of labels and an NDList of predictions. |
Methods inherited from class Loss: elasticNetWeightedDecay, hingeLoss, l1Loss, l1WeightedDecay, l2Loss, l2WeightedDecay, maskedSoftmaxCrossEntropyLoss, sigmoidBinaryCrossEntropyLoss, softmaxCrossEntropyLoss (elasticNetWeightedDecay has four overloads; each of the others has three)
Other inherited methods: checkLabelShapes (two overloads), getName
protected java.util.List<Loss> components
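
public AbstractCompositeLoss(java.lang.String name)
Constructs a composite loss with the given name.
Parameters:
name - the display name of the loss

protected abstract ai.djl.util.Pair<NDList,NDList> inputForComponent(int componentIndex, NDList labels, NDList predictions)
Returns the inputs (labels and predictions) used to compute the loss of the given component.
Parameters:
componentIndex - the index of the component loss
labels - the label input to the composite loss
predictions - the predictions input to the composite loss

public java.util.List<Loss> getComponents()
Returns the component losses that make up the composite loss.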
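One common routing policy for inputForComponent, when a model has several output heads, is to give component i only the i-th label array and the i-th prediction array. The sketch below is dependency-free and hypothetical: List<double[]> stands in for NDList and Map.Entry stands in for ai.djl.util.Pair, with the key holding the labels and the value holding the predictions. With DJL itself, the equivalent would be built from NDList and Pair.

```java
import java.util.AbstractMap.SimpleImmutableEntry;
import java.util.List;
import java.util.Map;

/** Hypothetical index-aligned routing; List<double[]> stands in for NDList. */
final class IndexAlignedRouting {

    /**
     * Gives component i only the i-th label array and the i-th prediction array.
     * The entry's key holds the labels and its value the predictions
     * (a stand-in for a Pair of NDLists).
     */
    static Map.Entry<List<double[]>, List<double[]>> inputForComponent(
            int componentIndex, List<double[]> labels, List<double[]> predictions) {
        if (componentIndex < 0
                || componentIndex >= labels.size()
                || componentIndex >= predictions.size()) {
            throw new IndexOutOfBoundsException("No inputs for component " + componentIndex);
        }
        return new SimpleImmutableEntry<>(
                List.of(labels.get(componentIndex)),
                List.of(predictions.get(componentIndex)));
    }

    public static void main(String[] args) {
        // Head 0: class probabilities; head 1: a single regression value.
        List<double[]> labels = List.of(new double[] {1.0, 0.0}, new double[] {3.5});
        List<double[]> predictions = List.of(new double[] {0.8, 0.2}, new double[] {3.0});

        Map.Entry<List<double[]>, List<double[]>> head1 = inputForComponent(1, labels, predictions);
        System.out.println(head1.getKey().get(0)[0]);   // 3.5
        System.out.println(head1.getValue().get(0)[0]); // 3.0
    }
}
```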

public NDArray evaluate(NDList labels, NDList predictions)
Calculates the composite loss between the labels and the predictions.
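Under the assumption that the component results are combined by plain summation, evaluate can be summarized as below. The symbols are illustrative: $n$ is the number of components, $\ell_i$ is the $i$-th component loss, $y$ and $\hat{y}$ are the labels and predictions passed to the composite, and $\pi_i$ stands for inputForComponent(i, labels, predictions), which yields the pair of inputs for component $i$.

```latex
L(y, \hat{y}) = \sum_{i=0}^{n-1} \ell_i\bigl(y^{(i)}, \hat{y}^{(i)}\bigr),
\qquad \bigl(y^{(i)}, \hat{y}^{(i)}\bigr) = \pi_i(y, \hat{y})
```

For example, if two components return 1.5 and 2.5 on their respective inputs, this reading gives a composite value of 4.0.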

public void addAccumulator(java.lang.String key)
Adds an accumulator for the results of the evaluation with the given key.
Overrides:
addAccumulator in class Loss
Parameters:
key - the key for the new accumulator

public void updateAccumulator(java.lang.String key, NDList labels, NDList predictions)
Updates the accumulator with the given key using an NDList of labels and an NDList of predictions.
This is a synchronized operation. You should only call it at the end of a batch or epoch.
Overrides:
updateAccumulator in class Loss
Parameters:
key - the key of the accumulator to update
labels - an NDList of labels
predictions - an NDList of predictions

public void resetAccumulator(java.lang.String key)
Resets the accumulator with the given key.
Overrides:
resetAccumulator in class Loss
Parameters:
key - the key of the accumulator to reset

public float getAccumulator(java.lang.String key)
Returns the accumulated value for the given key.
Overrides:
getAccumulator in class Loss
Parameters:
key - the key of the accumulator to get
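The accumulator methods follow a keyed lifecycle: add a key, update it once per batch, read it, and reset it between epochs. The dependency-free sketch below illustrates one plausible way a composite can implement them by forwarding each call to its components. Everything in it is hypothetical: WeightedAbsError is a made-up component that keeps a running mean per key, and reporting the sum of the component values from getAccumulator is an assumed combination rule, not a documented one.

```java
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

/** Hypothetical sketch of keyed accumulators in a composite loss (not DJL code). */
final class AccumulatorSketch {

    /** Made-up component: running mean of weight * |label - prediction| per key. */
    static final class WeightedAbsError {
        private final double weight;
        private final Map<String, Double> totals = new HashMap<>();
        private final Map<String, Long> counts = new HashMap<>();

        WeightedAbsError(double weight) {
            this.weight = weight;
        }

        void addAccumulator(String key) {
            totals.put(key, 0.0);
            counts.put(key, 0L);
        }

        void updateAccumulator(String key, double[] labels, double[] predictions) {
            double sum = 0;
            for (int i = 0; i < labels.length; i++) {
                sum += weight * Math.abs(labels[i] - predictions[i]);
            }
            totals.merge(key, sum, Double::sum);
            counts.merge(key, (long) labels.length, Long::sum);
        }

        void resetAccumulator(String key) {
            addAccumulator(key); // start a fresh running mean
        }

        float getAccumulator(String key) {
            Long n = counts.get(key);
            if (n == null) {
                throw new IllegalArgumentException("No accumulator for key " + key);
            }
            return n == 0 ? 0f : (float) (totals.get(key) / n);
        }
    }

    /**
     * Composite that forwards every call to each component. All methods are
     * synchronized in this sketch; the documented class states this for updateAccumulator.
     */
    static final class Composite {
        final List<WeightedAbsError> components = new ArrayList<>();

        synchronized void addAccumulator(String key) {
            components.forEach(c -> c.addAccumulator(key));
        }

        synchronized void updateAccumulator(String key, double[] labels, double[] predictions) {
            components.forEach(c -> c.updateAccumulator(key, labels, predictions));
        }

        synchronized void resetAccumulator(String key) {
            components.forEach(c -> c.resetAccumulator(key));
        }

        /** Assumed rule: the composite value is the sum of the component values. */
        synchronized float getAccumulator(String key) {
            float total = 0f;
            for (WeightedAbsError c : components) {
                total += c.getAccumulator(key);
            }
            return total;
        }
    }

    public static void main(String[] args) {
        Composite loss = new Composite();
        loss.components.add(new WeightedAbsError(1.0));
        loss.components.add(new WeightedAbsError(0.5));

        loss.addAccumulator("train");
        loss.updateAccumulator("train", new double[] {1, 2}, new double[] {2, 4}); // batch 1
        loss.updateAccumulator("train", new double[] {0}, new double[] {3});       // batch 2
        System.out.println(loss.getAccumulator("train")); // 3.0 (means 2.0 + 1.0)

        loss.resetAccumulator("train"); // e.g. before the next epoch
        System.out.println(loss.getAccumulator("train")); // 0.0
    }
}
```

Forwarding keeps each component's running state separate, so per-component values remain available if a caller wants to report them individually.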