Class MultiBoxTarget


  • public class MultiBoxTarget
    extends java.lang.Object
    MultiBoxTarget computes the training targets for Single Shot Detection (SSD) models.

    The output of a Single Shot Detection (SSD) network consists of class probabilities, box offset predictions, and the generated anchor boxes. The labels contain a class label and a bounding box for each object in the image. Each generated anchor box is a prior, and a loss must be computed for each one. This requires assigning a ground-truth box to each of them.

    MultiBoxTarget takes an NDList containing (anchor boxes, labels, class predictions), in that order. It computes the Intersection-over-Union (IoU) of each anchor box against every ground-truth box. To each anchor box it assigns the ground-truth box with the highest IoU, provided that IoU is greater than a given threshold. Once the ground-truth boxes have been assigned, it computes the offset of each anchor box with respect to its assigned ground-truth box.
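
    The matching procedure described above can be sketched in plain Java. This is an illustration only, not the code MultiBoxTarget runs: the class AnchorAssignmentSketch and its methods are hypothetical, boxes are assumed to be stored as {xMin, yMin, xMax, yMax}, and the offset uses a common center-size encoding that is an assumption rather than something this page specifies.

```java
import java.util.Arrays;

/** Hypothetical helper for illustration; not part of the library. */
final class AnchorAssignmentSketch {

    /** Area of a box given as {xMin, yMin, xMax, yMax} (assumed layout). */
    static double area(double[] box) {
        return Math.max(0.0, box[2] - box[0]) * Math.max(0.0, box[3] - box[1]);
    }

    /** Intersection-over-Union of two corner-format boxes. */
    static double iou(double[] a, double[] b) {
        double interW = Math.max(0.0, Math.min(a[2], b[2]) - Math.max(a[0], b[0]));
        double interH = Math.max(0.0, Math.min(a[3], b[3]) - Math.max(a[1], b[1]));
        double inter = interW * interH;
        double union = area(a) + area(b) - inter;
        return union <= 0.0 ? 0.0 : inter / union;
    }

    /**
     * For each anchor, returns the index of the ground-truth box with the
     * highest IoU, or -1 when no IoU is greater than the threshold.
     */
    static int[] assign(double[][] anchors, double[][] groundTruth, double threshold) {
        int[] assigned = new int[anchors.length];
        Arrays.fill(assigned, -1);
        for (int i = 0; i < anchors.length; i++) {
            double best = threshold; // an IoU must be strictly greater than this
            for (int j = 0; j < groundTruth.length; j++) {
                double overlap = iou(anchors[i], groundTruth[j]);
                if (overlap > best) {
                    best = overlap;
                    assigned[i] = j;
                }
            }
        }
        return assigned;
    }

    /**
     * Offset {dx, dy, dw, dh} of an anchor relative to its ground-truth box.
     * The center-size encoding below is an assumed convention.
     */
    static double[] offset(double[] anchor, double[] gt) {
        double aw = anchor[2] - anchor[0];
        double ah = anchor[3] - anchor[1];
        double gw = gt[2] - gt[0];
        double gh = gt[3] - gt[1];
        double dx = ((gt[0] + gt[2]) / 2 - (anchor[0] + anchor[2]) / 2) / aw;
        double dy = ((gt[1] + gt[3]) / 2 - (anchor[1] + anchor[3]) / 2) / ah;
        return new double[] {dx, dy, Math.log(gw / aw), Math.log(gh / ah)};
    }

    public static void main(String[] args) {
        double[][] anchors = {{0, 0, 1, 1}, {5, 5, 6, 6}};
        double[][] groundTruth = {{0, 0, 1, 1}};
        System.out.println(Arrays.toString(assign(anchors, groundTruth, 0.5)));
    }
}
```

    In this sketch an anchor that overlaps no ground-truth box strongly enough keeps the sentinel -1, which plays the role of "no assignment"; the real class reports that state through the bounding box masks instead.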

    MultiBoxTarget handles these tasks and returns an NDList containing (Bounding box offsets, bounding box masks, class labels). The bounding box offsets and class labels are computed as described above. The bounding box masks array contains only 0s and 1s, with the 0s corresponding to the anchor boxes whose IoUs with the ground-truth boxes were less than the given threshold.
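
    As a rough illustration of how masks and class labels relate to the assignment, the plain-Java sketch below derives both from a per-anchor assignment array. Everything here is hypothetical: the class TargetMaskSketch, the use of -1 for "no ground-truth box assigned", one mask value per anchor, and the convention that label 0 means background with object classes shifted up by one are assumptions, not details stated on this page.

```java
/** Hypothetical helper for illustration; not part of the library. */
final class TargetMaskSketch {

    /** Returns 1 for anchors that were assigned a ground-truth box, 0 otherwise. */
    static int[] masks(int[] assigned) {
        int[] mask = new int[assigned.length];
        for (int i = 0; i < assigned.length; i++) {
            mask[i] = assigned[i] >= 0 ? 1 : 0; // -1 marks an unassigned anchor (assumed)
        }
        return mask;
    }

    /**
     * Returns a class label per anchor. Unassigned anchors get 0 (background);
     * assigned anchors get the ground-truth class shifted by one (assumed convention).
     */
    static int[] classLabels(int[] assigned, int[] groundTruthClasses) {
        int[] labels = new int[assigned.length];
        for (int i = 0; i < assigned.length; i++) {
            labels[i] = assigned[i] >= 0 ? groundTruthClasses[assigned[i]] + 1 : 0;
        }
        return labels;
    }

    public static void main(String[] args) {
        int[] assigned = {0, -1, 1};
        int[] groundTruthClasses = {3, 7};
        System.out.println(java.util.Arrays.toString(masks(assigned)));
        System.out.println(java.util.Arrays.toString(classLabels(assigned, groundTruthClasses)));
    }
}
```

    A mask of this kind is typically multiplied into the box-offset loss so that anchors without an assigned ground-truth box contribute nothing to it.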

    • Method Detail

      • target

        public NDList target(NDList inputs)
        Computes multi-box training targets.
        Parameters:
        inputs - an NDList of (anchors, labels, class predictions), in that order
        Returns:
        an NDList containing (Bounding box offsets, bounding box masks, class labels)
      • builder

        public static MultiBoxTarget.Builder builder()
        Creates a builder to build a MultiBoxTarget.
        Returns:
        a new builder