Package ai.djl.training.loss
Class SingleShotDetectionLoss
- java.lang.Object
  - ai.djl.training.evaluator.Evaluator
    - ai.djl.training.loss.Loss
      - ai.djl.training.loss.AbstractCompositeLoss
        - ai.djl.training.loss.SingleShotDetectionLoss
public class SingleShotDetectionLoss extends AbstractCompositeLoss
SingleShotDetectionLoss is an implementation of Loss. It computes the loss for training a Single Shot MultiBox Detector (SSD) model for object detection. It first computes the targets from the generated anchors, the labels, and the predictions. It then sums the class-prediction loss and the bounding-box-prediction loss.
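The target-computation step can be pictured with the plain-Java sketch below. It assigns each anchor a class target by intersection-over-union (IoU) matching against the ground-truth boxes, and it derives a mask that marks which anchors contribute to the bounding-box loss. This is an illustration under simplifying assumptions, not DJL's MultiBoxTarget. The class AnchorMatchingSketch, its methods, the {xmin, ymin, xmax, ymax} box layout, the 0.5 threshold in the example, and the convention that class 0 means background are all hypothetical. The sketch also ignores the predictions, which a full implementation may use (for example, for hard negative mining), and it does not compute the box-offset targets.

```java
import java.util.Arrays;

/**
 * Hypothetical sketch of SSD-style target assignment by IoU matching.
 * Not DJL's MultiBoxTarget: it ignores predictions, skips offset targets,
 * and does not force every ground-truth box to claim its best anchor.
 */
public class AnchorMatchingSketch {

    /** Intersection-over-union of two boxes laid out as {xmin, ymin, xmax, ymax} (assumed layout). */
    static double iou(double[] a, double[] b) {
        double iw = Math.max(0.0, Math.min(a[2], b[2]) - Math.max(a[0], b[0]));
        double ih = Math.max(0.0, Math.min(a[3], b[3]) - Math.max(a[1], b[1]));
        double inter = iw * ih;
        double union = (a[2] - a[0]) * (a[3] - a[1]) + (b[2] - b[0]) * (b[3] - b[1]) - inter;
        return union <= 0.0 ? 0.0 : inter / union;
    }

    /**
     * Class target per anchor: 0 = background (assumed convention); otherwise
     * 1 + the class id of the best-overlapping ground-truth box, if its IoU >= threshold.
     */
    static int[] assignClassTargets(double[][] anchors, double[][] gtBoxes, int[] gtClasses, double threshold) {
        int[] targets = new int[anchors.length];
        for (int i = 0; i < anchors.length; i++) {
            double bestIou = 0.0;
            int bestBox = -1;
            for (int j = 0; j < gtBoxes.length; j++) {
                double v = iou(anchors[i], gtBoxes[j]);
                if (v > bestIou) {
                    bestIou = v;
                    bestBox = j;
                }
            }
            targets[i] = (bestBox >= 0 && bestIou >= threshold) ? gtClasses[bestBox] + 1 : 0;
        }
        return targets;
    }

    /** Mask per anchor: 1.0 if matched to an object (contributes to the box loss), else 0.0. */
    static double[] boxMask(int[] classTargets) {
        double[] mask = new double[classTargets.length];
        for (int i = 0; i < classTargets.length; i++) {
            mask[i] = classTargets[i] > 0 ? 1.0 : 0.0;
        }
        return mask;
    }

    public static void main(String[] args) {
        double[][] anchors = {{0, 0, 2, 2}, {5, 5, 7, 7}};
        double[][] gtBoxes = {{0, 0, 2, 1.8}};
        int[] gtClasses = {3};
        int[] targets = assignClassTargets(anchors, gtBoxes, gtClasses, 0.5);
        System.out.println(Arrays.toString(targets) + " " + Arrays.toString(boxMask(targets)));
    }
}
```

With the values in main, the first anchor overlaps the ground-truth box with an IoU of 0.9 and receives target 4 (class 3, shifted by one to reserve 0 for background). The second anchor has no overlap and is labeled background. The program therefore prints `[4, 0] [1.0, 0.0]`.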
Field Summary
Fields inherited from class ai.djl.training.loss.AbstractCompositeLoss
components
Fields inherited from class ai.djl.training.evaluator.Evaluator
totalInstances
Constructor Summary
SingleShotDetectionLoss()
Creates a new instance of SingleShotDetectionLoss.
Method Summary
protected ai.djl.util.Pair<NDList,NDList> inputForComponent(int componentIndex, NDList labels, NDList predictions)
Gets the (labels, predictions) inputs for the component loss at componentIndex.
Methods inherited from class ai.djl.training.loss.AbstractCompositeLoss
addAccumulator, evaluate, getAccumulator, getComponents, resetAccumulator, updateAccumulator
Methods inherited from class ai.djl.training.loss.Loss
elasticNetWeightedDecay, elasticNetWeightedDecay, elasticNetWeightedDecay, elasticNetWeightedDecay, hingeLoss, hingeLoss, hingeLoss, l1Loss, l1Loss, l1Loss, l1WeightedDecay, l1WeightedDecay, l1WeightedDecay, l2Loss, l2Loss, l2Loss, l2WeightedDecay, l2WeightedDecay, l2WeightedDecay, maskedSoftmaxCrossEntropyLoss, maskedSoftmaxCrossEntropyLoss, maskedSoftmaxCrossEntropyLoss, quantileL1Loss, quantileL1Loss, sigmoidBinaryCrossEntropyLoss, sigmoidBinaryCrossEntropyLoss, sigmoidBinaryCrossEntropyLoss, softmaxCrossEntropyLoss, softmaxCrossEntropyLoss, softmaxCrossEntropyLoss
Methods inherited from class ai.djl.training.evaluator.Evaluator
checkLabelShapes, checkLabelShapes, getName
Method Detail
inputForComponent
protected ai.djl.util.Pair<NDList,NDList> inputForComponent(int componentIndex, NDList labels, NDList predictions)
Gets the (labels, predictions) inputs for the component loss at componentIndex.
Specified by:
inputForComponent in class AbstractCompositeLoss
Parameters:
componentIndex - the index of the component loss
labels - the target labels. Must contain (offsetLabels, masks, classlabels), as returned by the MultiBoxTarget function
predictions - the predicted labels (class prediction, offset prediction)
Returns:
a pair of the (labels, predictions) inputs for the component loss
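The plain-Java sketch below shows, under stated assumptions, the role this hook plays in a composite loss. Given a component index, it selects the (labels, predictions) pair for that component, and the composite loss is then the sum of the component losses. The class CompositeLossSketch and all of its methods are hypothetical. Component 0 is assumed to be a softmax cross-entropy over per-anchor class scores, and component 1 an L1 loss over box offsets masked to matched anchors. The targets (class ids, offsets, mask) are assumed to have been computed already by a MultiBoxTarget-style step. Java arrays stand in for DJL NDArrays, so this is not the DJL implementation.

```java
/**
 * Hypothetical plain-Java sketch of a two-component SSD loss.
 * Not DJL's implementation: arrays stand in for NDArrays, and the component
 * choices (softmax cross-entropy for classes, L1 for box offsets) are assumptions.
 */
public class CompositeLossSketch {

    /** The (labels, predictions) pair handed to one component loss. */
    static final class Inputs {
        final double[][] labels;
        final double[][] predictions;

        Inputs(double[][] labels, double[][] predictions) {
            this.labels = labels;
            this.predictions = predictions;
        }
    }

    /** Selects the inputs for one component, in the spirit of inputForComponent. */
    static Inputs inputForComponent(int componentIndex, int[] classTargets, double[][] offsetTargets,
                                    double[] mask, double[][] classLogits, double[][] offsetPreds) {
        switch (componentIndex) {
            case 0: { // class component: one-hot class targets vs. raw class scores
                double[][] oneHot = new double[classTargets.length][classLogits[0].length];
                for (int i = 0; i < classTargets.length; i++) {
                    oneHot[i][classTargets[i]] = 1.0;
                }
                return new Inputs(oneHot, classLogits);
            }
            case 1: // box component: zero out background anchors in both labels and predictions
                return new Inputs(applyMask(offsetTargets, mask), applyMask(offsetPreds, mask));
            default:
                throw new IllegalArgumentException("Unknown component index: " + componentIndex);
        }
    }

    /** Multiplies every value in row i by mask[i]. */
    static double[][] applyMask(double[][] x, double[] mask) {
        double[][] out = new double[x.length][];
        for (int i = 0; i < x.length; i++) {
            out[i] = new double[x[i].length];
            for (int j = 0; j < x[i].length; j++) {
                out[i][j] = x[i][j] * mask[i];
            }
        }
        return out;
    }

    /** Mean softmax cross-entropy over anchors (rows); labels are one-hot rows. */
    static double softmaxCrossEntropy(Inputs in) {
        double total = 0.0;
        for (int i = 0; i < in.predictions.length; i++) {
            double[] z = in.predictions[i];
            double max = Double.NEGATIVE_INFINITY;
            for (double v : z) {
                max = Math.max(max, v);
            }
            double sumExp = 0.0;
            for (double v : z) {
                sumExp += Math.exp(v - max);
            }
            double logSumExp = max + Math.log(sumExp); // numerically stable log-sum-exp
            for (int c = 0; c < z.length; c++) {
                total -= in.labels[i][c] * (z[c] - logSumExp);
            }
        }
        return total / in.predictions.length;
    }

    /** Mean absolute error over all entries, including masked (zeroed) rows. */
    static double l1(Inputs in) {
        double total = 0.0;
        int count = 0;
        for (int i = 0; i < in.labels.length; i++) {
            for (int j = 0; j < in.labels[i].length; j++) {
                total += Math.abs(in.labels[i][j] - in.predictions[i][j]);
                count++;
            }
        }
        return count == 0 ? 0.0 : total / count;
    }

    /** Composite loss: the sum of the class loss and the bounding-box loss. */
    static double compositeLoss(int[] classTargets, double[][] offsetTargets, double[] mask,
                                double[][] classLogits, double[][] offsetPreds) {
        double classLoss = softmaxCrossEntropy(
                inputForComponent(0, classTargets, offsetTargets, mask, classLogits, offsetPreds));
        double boxLoss = l1(
                inputForComponent(1, classTargets, offsetTargets, mask, classLogits, offsetPreds));
        return classLoss + boxLoss;
    }
}
```

Consider two anchors with all-zero class scores over two classes. The class term is ln 2 ≈ 0.693. Now add offset targets {{1, 1, 1, 1}, {5, 5, 5, 5}}, all-zero offset predictions, and mask {1, 0}. The masked L1 term is 4 / 8 = 0.5, so compositeLoss returns about 1.193. Averaging the L1 term over all entries, including the zeroed background rows, is only one of several possible normalizations.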