public class SimpleCompositeLoss extends AbstractCompositeLoss
SimpleCompositeLoss is an implementation of the Loss abstract class that combines different Loss functions by adding the individual losses together.
This class can be used when each component loss accepts either a single index of the labels and predictions or the entire lists. For more complicated composite losses, extend AbstractCompositeLoss.
For an example of using this loss, see the captcha training example.
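The summing behavior described above can be modeled without the library. The following is a minimal sketch, not DJL code: `SummingLossSketch`, `ListLoss`, `meanAbsoluteError`, and `demo` are hypothetical names, and plain `double[]` arrays stand in for the arrays held by an NDList. It assumes that a component added with an index receives only that entry of the labels and predictions, while a component added without one receives the entire lists.

```java
import java.util.ArrayList;
import java.util.List;

/** Library-free model of a summing composite loss. Hypothetical; not a DJL class. */
final class SummingLossSketch {

    /** Hypothetical stand-in for a loss over lists of label and prediction arrays. */
    interface ListLoss {
        double evaluate(List<double[]> labels, List<double[]> predictions);
    }

    private final List<ListLoss> components = new ArrayList<>();
    // Parallel to components; null means "pass the entire lists" (an assumption of this sketch).
    private final List<Integer> indices = new ArrayList<>();

    /** Adds a component that sees all labels and predictions. */
    SummingLossSketch addLoss(ListLoss loss) {
        components.add(loss);
        indices.add(null);
        return this;
    }

    /** Adds a component that sees only the entry at the given index. */
    SummingLossSketch addLoss(ListLoss loss, int index) {
        components.add(loss);
        indices.add(index);
        return this;
    }

    /** Evaluates every component on its inputs and adds the results together. */
    double evaluate(List<double[]> labels, List<double[]> predictions) {
        double total = 0.0;
        for (int i = 0; i < components.size(); i++) {
            Integer index = indices.get(i);
            List<double[]> l = index == null ? labels : List.of(labels.get(index));
            List<double[]> p = index == null ? predictions : List.of(predictions.get(index));
            total += components.get(i).evaluate(l, p);
        }
        return total;
    }

    /** Mean absolute error over the first array of each list. */
    static double meanAbsoluteError(List<double[]> labels, List<double[]> predictions) {
        double[] y = labels.get(0);
        double[] yHat = predictions.get(0);
        double sum = 0.0;
        for (int i = 0; i < y.length; i++) {
            sum += Math.abs(y[i] - yHat[i]);
        }
        return sum / y.length;
    }

    /** Builds a two-head example and returns its composite loss. */
    static double demo() {
        List<double[]> labels = List.of(new double[] {1.0, 2.0}, new double[] {4.0});
        List<double[]> predictions = List.of(new double[] {1.5, 2.5}, new double[] {1.0});
        SummingLossSketch loss = new SummingLossSketch()
                .addLoss(SummingLossSketch::meanAbsoluteError, 0)  // head 0 contributes 0.5
                .addLoss(SummingLossSketch::meanAbsoluteError, 1); // head 1 contributes 3.0
        return loss.evaluate(labels, predictions);
    }

    public static void main(String[] args) {
        System.out.println(demo()); // prints 3.5
    }
}
```

The chained calls mirror the `addLoss(Loss loss)` and `addLoss(Loss loss, int index)` signatures documented on this page, which both return the composite loss itself.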
Inherited fields: components, totalInstances
Constructor | Description |
---|---|
SimpleCompositeLoss() | Creates a new empty SimpleCompositeLoss; add Loss components with addLoss. |
SimpleCompositeLoss(java.lang.String name) | Creates a new empty SimpleCompositeLoss with the given display name. |
Modifier and Type | Method and Description |
---|---|
SimpleCompositeLoss | addLoss(Loss loss): Adds a Loss that applies to all labels and predictions to this composite loss. |
SimpleCompositeLoss | addLoss(Loss loss, int index): Adds a Loss that applies to a single index of the labels and predictions to this composite loss. |
protected ai.djl.util.Pair<NDList,NDList> | inputForComponent(int componentIndex, NDList labels, NDList predictions): Returns the inputs for computing the loss of a component loss. |
Inherited methods:
addAccumulator, evaluate, getAccumulator, getComponents, resetAccumulator, updateAccumulator
hingeLoss, hingeLoss, hingeLoss, l1Loss, l1Loss, l1Loss, l2Loss, l2Loss, l2Loss, maskedSoftmaxCrossEntropyLoss, maskedSoftmaxCrossEntropyLoss, maskedSoftmaxCrossEntropyLoss, sigmoidBinaryCrossEntropyLoss, sigmoidBinaryCrossEntropyLoss, sigmoidBinaryCrossEntropyLoss, softmaxCrossEntropyLoss, softmaxCrossEntropyLoss, softmaxCrossEntropyLoss
checkLabelShapes, checkLabelShapes, getName
public SimpleCompositeLoss()
Creates a new empty SimpleCompositeLoss; add Loss components with addLoss.

public SimpleCompositeLoss(java.lang.String name)
Creates a new empty SimpleCompositeLoss with the given display name.
Parameters:
name - the display name of the loss

public SimpleCompositeLoss addLoss(Loss loss)
Adds a Loss that applies to all labels and predictions to this composite loss.
Parameters:
loss - the loss to add

public SimpleCompositeLoss addLoss(Loss loss, int index)
Adds a Loss that applies to a single index of the labels and predictions to this composite loss.
Parameters:
loss - the loss to add
index - the index in the label and predictions NDLists this loss applies to

protected ai.djl.util.Pair<NDList,NDList> inputForComponent(int componentIndex, NDList labels, NDList predictions)
Returns the inputs for computing the loss of a component loss.
inputForComponent in class AbstractCompositeLoss
Parameters:
componentIndex - the index of the component loss
labels - the label input to the composite loss
predictions - the predictions input to the composite loss
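
When a component needs inputs that a single shared index cannot express, the page recommends extending AbstractCompositeLoss, and inputForComponent is the method that decides what each component receives. The sketch below models that hook without the library. Every name in it (`RoutedLossSketch`, `ListLoss`, `Inputs`, `SharedLabelLoss`, `squaredError`, `demo`) is hypothetical, `double[]` arrays stand in for NDList entries, and the evaluate loop is an assumption about how a composite loss might call the hook, not DJL's actual implementation.

```java
import java.util.ArrayList;
import java.util.List;

/** Library-free model of a composite loss with an overridable input hook. Not DJL code. */
abstract class RoutedLossSketch {

    /** Hypothetical stand-in for a component loss. */
    interface ListLoss {
        double evaluate(List<double[]> labels, List<double[]> predictions);
    }

    /** Hypothetical stand-in for a (labels, predictions) pair. */
    static final class Inputs {
        final List<double[]> labels;
        final List<double[]> predictions;

        Inputs(List<double[]> labels, List<double[]> predictions) {
            this.labels = labels;
            this.predictions = predictions;
        }
    }

    protected final List<ListLoss> components = new ArrayList<>();

    /** Hook: chooses the inputs that the component at componentIndex receives. */
    protected abstract Inputs inputForComponent(
            int componentIndex, List<double[]> labels, List<double[]> predictions);

    /** Asks the hook for each component's inputs and sums the component losses. */
    double evaluate(List<double[]> labels, List<double[]> predictions) {
        double total = 0.0;
        for (int i = 0; i < components.size(); i++) {
            Inputs in = inputForComponent(i, labels, predictions);
            total += components.get(i).evaluate(in.labels, in.predictions);
        }
        return total;
    }
}

/** Every component shares label 0 but reads its own prediction head. */
final class SharedLabelLoss extends RoutedLossSketch {

    SharedLabelLoss addLoss(ListLoss loss) {
        components.add(loss);
        return this;
    }

    @Override
    protected Inputs inputForComponent(
            int componentIndex, List<double[]> labels, List<double[]> predictions) {
        return new Inputs(List.of(labels.get(0)), List.of(predictions.get(componentIndex)));
    }

    /** Sum of squared differences over the first array of each list. */
    static double squaredError(List<double[]> labels, List<double[]> predictions) {
        double[] y = labels.get(0);
        double[] yHat = predictions.get(0);
        double sum = 0.0;
        for (int i = 0; i < y.length; i++) {
            double d = y[i] - yHat[i];
            sum += d * d;
        }
        return sum;
    }

    /** One label, two prediction heads: (1 - 2)^2 + (1 - 4)^2. */
    static double demo() {
        List<double[]> labels = List.of(new double[] {1.0});
        List<double[]> predictions = List.of(new double[] {2.0}, new double[] {4.0});
        SharedLabelLoss loss = new SharedLabelLoss()
                .addLoss(SharedLabelLoss::squaredError)
                .addLoss(SharedLabelLoss::squaredError);
        return loss.evaluate(labels, predictions);
    }

    public static void main(String[] args) {
        System.out.println(demo()); // prints 10.0
    }
}
```

Keeping the routing in one overridable method means the summation loop stays untouched when only the choice of inputs changes.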