Class AbstractCompositeLoss

Direct Known Subclasses:
BertPretrainingLoss, SimpleCompositeLoss, SingleShotDetectionLoss

public abstract class AbstractCompositeLoss extends Loss
AbstractCompositeLoss is a Loss that combines other Losses into a larger loss.

The AbstractCompositeLoss is designed to be extended for more complicated composite losses. For simpler use cases, consider using the SimpleCompositeLoss.

  • Field Details

    • components

      protected List<Loss> components
      The component losses that make up the composite loss.
  • Constructor Details

    • AbstractCompositeLoss

      public AbstractCompositeLoss(String name)
      Constructs a composite loss with the given name.
      Parameters:
      name - the display name of the loss
  • Method Details

    • inputForComponent

      protected abstract ai.djl.util.Pair<NDList,NDList> inputForComponent(int componentIndex, NDList labels, NDList predictions)
      Returns the inputs used to compute the loss of the given component.
      Parameters:
      componentIndex - the index of the component loss
      labels - the label input to the composite loss
      predictions - the predictions input to the composite loss
      Returns:
      a pair of the (labels, predictions) inputs to the component loss
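The routing contract of inputForComponent can be illustrated without DJL. This is a minimal sketch under stated assumptions: `List<float[]>` stands in for `NDList`, a small hypothetical `Pair` class stands in for `ai.djl.util.Pair`, and the rule itself (a component receives only the entry at a configured index, or the full lists when its index is `null`) is a hypothetical example of what a subclass might do, not DJL's implementation.

```java
import java.util.Arrays;
import java.util.List;

// Hypothetical stand-in for ai.djl.util.Pair<NDList, NDList>: (labels, predictions).
final class Pair<K, V> {
    private final K key;
    private final V value;

    Pair(K key, V value) {
        this.key = key;
        this.value = value;
    }

    K getKey() {
        return key;
    }

    V getValue() {
        return value;
    }
}

// Hypothetical routing rule: one configured index per component, null means "pass everything".
class IndexRouting {
    private final List<Integer> indices;

    IndexRouting(List<Integer> indices) {
        this.indices = indices;
    }

    // Mirrors inputForComponent(int, NDList, NDList): decide what component i should see.
    Pair<List<float[]>, List<float[]>> inputForComponent(
            int componentIndex, List<float[]> labels, List<float[]> predictions) {
        Integer index = indices.get(componentIndex);
        if (index == null) {
            // No index configured: the component sees the full label and prediction lists.
            return new Pair<>(labels, predictions);
        }
        // Otherwise the component sees only the single entry at its configured index.
        return new Pair<>(List.of(labels.get(index)), List.of(predictions.get(index)));
    }
}
```

Keeping the routing in one overridable method means the loop that evaluates components never needs to know how a particular composite slices its inputs.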
    • getComponents

      public List<Loss> getComponents()
      Returns the component losses that make up the composite loss.
      Returns:
      the component losses that make up the composite loss
    • evaluate

      public NDArray evaluate(NDList labels, NDList predictions)
      Calculates the loss between the labels and the predictions.
      Specified by:
      evaluate in class Evaluator
      Parameters:
      labels - the correct values
      predictions - the predicted values
      Returns:
      the evaluation result
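One way a composite evaluate can combine its components is sketched below, assuming the composite result is the plain sum of the component losses; check the DJL source for the exact reduction it uses. `ComponentLoss`, `CompositeLossSketch`, `TwoHeadLoss`, and the two error classes are hypothetical stand-ins that return a scalar `float` instead of an `NDArray`, and the `{labels, predictions}` array plays the role of the pair returned by inputForComponent.

```java
import java.util.ArrayList;
import java.util.List;

// Hypothetical stand-in for a component Loss: maps (labels, predictions) to a scalar.
interface ComponentLoss {
    float evaluate(float[] labels, float[] predictions);
}

// Mean absolute error over one output.
class MeanAbsoluteError implements ComponentLoss {
    @Override
    public float evaluate(float[] labels, float[] predictions) {
        float sum = 0f;
        for (int i = 0; i < labels.length; i++) {
            sum += Math.abs(labels[i] - predictions[i]);
        }
        return sum / labels.length;
    }
}

// Mean squared error over one output.
class MeanSquaredError implements ComponentLoss {
    @Override
    public float evaluate(float[] labels, float[] predictions) {
        float sum = 0f;
        for (int i = 0; i < labels.length; i++) {
            float diff = labels[i] - predictions[i];
            sum += diff * diff;
        }
        return sum / labels.length;
    }
}

abstract class CompositeLossSketch {
    protected final List<ComponentLoss> components = new ArrayList<>();

    // Stand-in for inputForComponent: returns {labels, predictions} for component i.
    protected abstract float[][] inputForComponent(
            int componentIndex, List<float[]> labels, List<float[]> predictions);

    // Evaluates every component on its routed inputs and sums the results (assumed reduction).
    float evaluate(List<float[]> labels, List<float[]> predictions) {
        float total = 0f;
        for (int i = 0; i < components.size(); i++) {
            float[][] inputs = inputForComponent(i, labels, predictions);
            total += components.get(i).evaluate(inputs[0], inputs[1]);
        }
        return total;
    }
}

// Hypothetical two-headed model: component i scores output head i.
class TwoHeadLoss extends CompositeLossSketch {
    TwoHeadLoss() {
        components.add(new MeanAbsoluteError());
        components.add(new MeanSquaredError());
    }

    @Override
    protected float[][] inputForComponent(
            int componentIndex, List<float[]> labels, List<float[]> predictions) {
        return new float[][] {labels.get(componentIndex), predictions.get(componentIndex)};
    }
}
```

Because the evaluation loop lives in the abstract class, a subclass such as `TwoHeadLoss` only chooses its components and its routing.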
    • addAccumulator

      public void addAccumulator(String key)
      Adds an accumulator for the results of the evaluation with the given key.
      Overrides:
      addAccumulator in class Loss
      Parameters:
      key - the key for the new accumulator
    • updateAccumulators

      public void updateAccumulators(String[] keys, NDList labels, NDList predictions)
      Updates the evaluator's accumulators for the given keys based on an NDList of labels and predictions.

      This is a synchronized operation. You should only call it at the end of a batch or epoch.

      This is an alternative to Evaluator.updateAccumulator(String, NDList, NDList) that may be more efficient when updating multiple accumulators at once.

      Overrides:
      updateAccumulators in class Loss
      Parameters:
      keys - the keys of all the accumulators to update
      labels - an NDList of labels
      predictions - an NDList of predictions
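A likely source of the efficiency gain is that the loss for a batch only has to be computed once and can then be added to every requested key, whereas one call per key repeats the computation. The sketch below models this with a hypothetical `AccumulatorSketch`; its per-key bookkeeping (a running loss sum and instance count, reported as a mean) and the keys `"train"` and `"validate"` are assumptions for illustration, not DJL's internals.

```java
import java.util.HashMap;
import java.util.Map;

// Hypothetical keyed accumulator; counts how often the loss is actually computed.
class AccumulatorSketch {
    private final Map<String, Float> totalLoss = new HashMap<>();
    private final Map<String, Long> totalInstances = new HashMap<>();
    int evaluations = 0;

    void addAccumulator(String key) {
        totalLoss.put(key, 0f);
        totalInstances.put(key, 0L);
    }

    // Stand-in for the loss computation: sum of absolute errors over the batch.
    private float evaluate(float[] labels, float[] predictions) {
        evaluations++;
        float sum = 0f;
        for (int i = 0; i < labels.length; i++) {
            sum += Math.abs(labels[i] - predictions[i]);
        }
        return sum;
    }

    // One key per call: the loss is recomputed every time.
    void updateAccumulator(String key, float[] labels, float[] predictions) {
        updateAccumulators(new String[] {key}, labels, predictions);
    }

    // Several keys per call: the loss is computed once and added to each key.
    void updateAccumulators(String[] keys, float[] labels, float[] predictions) {
        float batchLoss = evaluate(labels, predictions);
        for (String key : keys) {
            totalLoss.merge(key, batchLoss, Float::sum);
            totalInstances.merge(key, (long) labels.length, Long::sum);
        }
    }

    // Mean loss per instance seen under this key (assumed definition).
    float getAccumulator(String key) {
        return totalLoss.get(key) / totalInstances.get(key);
    }
}
```

Both update paths leave the accumulators in the same state; only the number of loss computations differs.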
    • resetAccumulator

      public void resetAccumulator(String key)
      Resets the evaluator value with the given key.
      Overrides:
      resetAccumulator in class Loss
      Parameters:
      key - the key of the accumulator to reset
    • getAccumulator

      public float getAccumulator(String key)
      Returns the accumulated evaluator value.
      Overrides:
      getAccumulator in class Loss
      Parameters:
      key - the key of the accumulator to get
      Returns:
      the accumulated value
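AbstractCompositeLoss overrides addAccumulator, updateAccumulators, resetAccumulator, and getAccumulator from Loss. One plausible reason, assumed here rather than taken from the DJL source, is that the composite keeps its own running value per key and also forwards each call to its components, so per-component losses can be inspected alongside the total. The sketch models that fan-out with hypothetical `TrackedLoss` and `CompositeTracker` classes over plain `float` values; for simplicity each key's value is the mean over updates, and `NaN` is returned before any update.

```java
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

// Hypothetical loss that tracks a running mean per accumulator key.
class TrackedLoss {
    private final Map<String, Float> sum = new HashMap<>();
    private final Map<String, Integer> count = new HashMap<>();

    void addAccumulator(String key) {
        sum.put(key, 0f);
        count.put(key, 0);
    }

    void resetAccumulator(String key) {
        sum.put(key, 0f);
        count.put(key, 0);
    }

    void recordValue(String key, float value) {
        sum.merge(key, value, Float::sum);
        count.merge(key, 1, Integer::sum);
    }

    float getAccumulator(String key) {
        int n = count.get(key);
        return n == 0 ? Float.NaN : sum.get(key) / n;
    }
}

// Hypothetical composite: tracks the summed loss and forwards every call to its components.
class CompositeTracker extends TrackedLoss {
    final List<TrackedLoss> components = new ArrayList<>();

    @Override
    void addAccumulator(String key) {
        super.addAccumulator(key);
        components.forEach(c -> c.addAccumulator(key));
    }

    @Override
    void resetAccumulator(String key) {
        super.resetAccumulator(key);
        components.forEach(c -> c.resetAccumulator(key));
    }

    // Records one batch; componentValues[i] is component i's loss for that batch.
    void update(String[] keys, float[] componentValues) {
        float total = 0f;
        for (float v : componentValues) {
            total += v;
        }
        for (String key : keys) {
            recordValue(key, total);
            for (int i = 0; i < components.size(); i++) {
                components.get(i).recordValue(key, componentValues[i]);
            }
        }
    }
}
```

With this layout, resetting a key at the start of an epoch clears the total and every component in one call.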