Class Loss

  • Direct Known Subclasses:
    AbstractCompositeLoss, BertMaskedLanguageModelLoss, BertNextSentenceLoss, ElasticNetWeightDecay, HingeLoss, IndexLoss, L1Loss, L1WeightDecay, L2Loss, L2WeightDecay, MaskedSoftmaxCrossEntropyLoss, SigmoidBinaryCrossEntropyLoss, SoftmaxCrossEntropyLoss

    public abstract class Loss
    extends Evaluator
    Loss functions (or Cost functions) are used to evaluate the model predictions against true labels for optimization.

    Although all evaluators can be used to measure the performance of a model, not all of them are suited to being used by an optimizer. Loss functions are usually non-negative, and a larger loss represents worse performance. They are also real-valued, so that models can be compared accurately.

    When creating a loss function, avoid making the loss depend on the batch size. For example, if you compute a loss per item in a batch and sum those losses, the result is numItemsInBatch*avgLoss. Instead, take the mean of the per-item losses to remove the batch-size factor. Otherwise, the learning rate becomes difficult to tune, because any change in the batch size changes the scale of the loss. A variable batch size makes this even more difficult.
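
    The effect of the reduction can be seen in a minimal plain-Java sketch. It does not use this library; the class BatchReduction and its methods are hypothetical names introduced only for illustration.

```java
final class BatchReduction {
    /** Sums the per-item losses; the result scales with the batch size. */
    public static double sumLoss(double[] perItem) {
        double total = 0.0;
        for (double v : perItem) {
            total += v;
        }
        return total;
    }

    /** Averages the per-item losses; the batch-size factor is divided out. */
    public static double meanLoss(double[] perItem) {
        return sumLoss(perItem) / perItem.length;
    }

    public static void main(String[] args) {
        double[] small = {0.5, 0.5};
        double[] large = {0.5, 0.5, 0.5, 0.5};
        // Same per-item loss, but the summed loss doubles with the batch size.
        System.out.println(sumLoss(small) + " vs " + sumLoss(large));   // 1.0 vs 2.0
        // The mean stays the same, so the learning rate does not need retuning.
        System.out.println(meanLoss(small) + " vs " + meanLoss(large)); // 0.5 vs 0.5
    }
}
```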

    For more details about the class internals, see Evaluator.

    • Constructor Detail

      • Loss

        public Loss​(java.lang.String name)
        Base class for loss functions, with abstract update methods.
        Parameters:
        name - The display name of the Loss
    • Method Detail

      • l1Loss

        public static L1Loss l1Loss()
        Returns a new instance of L1Loss with default weight and batch axis.
        Returns:
        a new instance of L1Loss
      • l1Loss

        public static L1Loss l1Loss​(java.lang.String name)
        Returns a new instance of L1Loss with default weight and batch axis.
        Parameters:
        name - the name of the loss
        Returns:
        a new instance of L1Loss
      • l1Loss

        public static L1Loss l1Loss​(java.lang.String name,
                                    float weight)
        Returns a new instance of L1Loss with the given name and weight.
        Parameters:
        name - the name of the loss
        weight - the weight to apply on loss value, default 1
        Returns:
        a new instance of L1Loss
      • l2Loss

        public static L2Loss l2Loss()
        Returns a new instance of L2Loss with default weight and batch axis.
        Returns:
        a new instance of L2Loss
      • l2Loss

        public static L2Loss l2Loss​(java.lang.String name)
        Returns a new instance of L2Loss with default weight and batch axis.
        Parameters:
        name - the name of the loss
        Returns:
        a new instance of L2Loss
      • l2Loss

        public static L2Loss l2Loss​(java.lang.String name,
                                    float weight)
        Returns a new instance of L2Loss with the given name and weight.
        Parameters:
        name - the name of the loss
        weight - the weight to apply on loss value, default 1
        Returns:
        a new instance of L2Loss
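
        As a rough illustration of what the L1 and L2 factories above configure, the following plain-Java sketch computes the conventional mean absolute error and mean squared error, each scaled by a weight. It is not this library's implementation: the class RegressionLossSketch is hypothetical, and whether the library applies an extra factor (such as 1/2 for L2) is not stated on this page.

```java
final class RegressionLossSketch {
    /** Mean absolute difference between labels and predictions, scaled by weight. */
    public static double l1(double[] labels, double[] predictions, double weight) {
        checkSameLength(labels, predictions);
        double total = 0.0;
        for (int i = 0; i < labels.length; i++) {
            total += Math.abs(labels[i] - predictions[i]);
        }
        return weight * total / labels.length;
    }

    /** Mean squared difference, scaled by weight (no 1/2 factor; an assumption). */
    public static double l2(double[] labels, double[] predictions, double weight) {
        checkSameLength(labels, predictions);
        double total = 0.0;
        for (int i = 0; i < labels.length; i++) {
            double diff = labels[i] - predictions[i];
            total += diff * diff;
        }
        return weight * total / labels.length;
    }

    private static void checkSameLength(double[] a, double[] b) {
        if (a.length != b.length) {
            throw new IllegalArgumentException("labels and predictions differ in length");
        }
    }
}
```

        Both methods take the mean over the items, so the value does not depend on the batch size.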
      • sigmoidBinaryCrossEntropyLoss

        public static SigmoidBinaryCrossEntropyLoss sigmoidBinaryCrossEntropyLoss​(java.lang.String name,
                                                                                  float weight,
                                                                                  boolean fromSigmoid)
        Returns a new instance of SigmoidBinaryCrossEntropyLoss with the given arguments.
        Parameters:
        name - the name of the loss
        weight - the weight to apply on the loss value, default 1
        fromSigmoid - whether the input is from the output of sigmoid, default false
        Returns:
        a new instance of SigmoidBinaryCrossEntropyLoss
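
        The role of fromSigmoid can be sketched in plain Java for a single item. This is an illustration under stated assumptions, not the library's code: the class SigmoidBceSketch, the EPS constant, and the numerically stable logit formula are choices made for this sketch.

```java
final class SigmoidBceSketch {
    // Small constant that keeps log() finite at probabilities of exactly 0 or 1 (assumption).
    private static final double EPS = 1e-12;

    /**
     * Binary cross-entropy for one item.
     * If fromSigmoid is true, input is already a probability in [0, 1];
     * otherwise input is a raw score and the sigmoid is folded into the formula.
     */
    public static double loss(double input, double label, boolean fromSigmoid) {
        if (fromSigmoid) {
            return -(label * Math.log(input + EPS)
                    + (1.0 - label) * Math.log(1.0 - input + EPS));
        }
        // Stable form of -[y*log(sigmoid(x)) + (1-y)*log(1-sigmoid(x))].
        return Math.max(input, 0.0) - input * label
                + Math.log1p(Math.exp(-Math.abs(input)));
    }

    /** Logistic sigmoid, used here only to compare the two code paths. */
    public static double sigmoid(double x) {
        return 1.0 / (1.0 + Math.exp(-x));
    }
}
```

        Passing a raw score with fromSigmoid set to false and passing sigmoid(score) with fromSigmoid set to true give the same value up to rounding.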
      • softmaxCrossEntropyLoss

        public static SoftmaxCrossEntropyLoss softmaxCrossEntropyLoss​(java.lang.String name,
                                                                      float weight,
                                                                      int classAxis,
                                                                      boolean sparseLabel,
                                                                      boolean fromLogit)
        Returns a new instance of SoftmaxCrossEntropyLoss with the given arguments.
        Parameters:
        name - the name of the loss
        weight - the weight to apply on the loss value, default 1
        classAxis - the axis that represents the class probabilities, default -1
        sparseLabel - whether labels are an integer array of class indices or probabilities, default true
        fromLogit - whether predictions are log probabilities or un-normalized numbers
        Returns:
        a new instance of SoftmaxCrossEntropyLoss
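
        The sparse-label case can be sketched in plain Java as follows. The sketch assumes the predictions are un-normalized scores with the class axis last and the labels are integer class indices; the class SoftmaxCrossEntropySketch is hypothetical and is not the library's implementation.

```java
final class SoftmaxCrossEntropySketch {
    /** Log-softmax over one row of scores, shifted by the max for numerical stability. */
    public static double[] logSoftmax(double[] scores) {
        double max = Double.NEGATIVE_INFINITY;
        for (double s : scores) {
            max = Math.max(max, s);
        }
        double sumExp = 0.0;
        for (double s : scores) {
            sumExp += Math.exp(s - max);
        }
        double logSumExp = max + Math.log(sumExp);
        double[] out = new double[scores.length];
        for (int i = 0; i < scores.length; i++) {
            out[i] = scores[i] - logSumExp;
        }
        return out;
    }

    /** Mean negative log-likelihood of the labelled class, scaled by weight. */
    public static double loss(double[][] scores, int[] labels, double weight) {
        double total = 0.0;
        for (int n = 0; n < scores.length; n++) {
            total -= logSoftmax(scores[n])[labels[n]];
        }
        return weight * total / scores.length;
    }
}
```

        With two equal scores per row, each class has probability 0.5, so the loss is ln 2 regardless of how many rows the batch holds.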
      • maskedSoftmaxCrossEntropyLoss

        public static MaskedSoftmaxCrossEntropyLoss maskedSoftmaxCrossEntropyLoss​(java.lang.String name,
                                                                                  float weight,
                                                                                  int classAxis,
                                                                                  boolean sparseLabel,
                                                                                  boolean fromLogit)
        Returns a new instance of MaskedSoftmaxCrossEntropyLoss with the given arguments.
        Parameters:
        name - the name of the loss
        weight - the weight to apply on the loss value, default 1
        classAxis - the axis that represents the class probabilities, default -1
        sparseLabel - whether labels are an integer array of class indices or probabilities, default true
        fromLogit - whether predictions are log probabilities or un-normalized numbers
        Returns:
        a new instance of MaskedSoftmaxCrossEntropyLoss
      • hingeLoss

        public static HingeLoss hingeLoss()
        Returns a new instance of HingeLoss with default arguments.
        Returns:
        a new instance of HingeLoss
      • hingeLoss

        public static HingeLoss hingeLoss​(java.lang.String name)
        Returns a new instance of HingeLoss with default arguments.
        Parameters:
        name - the name of the loss
        Returns:
        a new instance of HingeLoss
      • hingeLoss

        public static HingeLoss hingeLoss​(java.lang.String name,
                                          int margin,
                                          float weight)
        Returns a new instance of HingeLoss with the given arguments.
        Parameters:
        name - the name of the loss
        margin - the margin in hinge loss, default 1
        weight - the weight to apply on loss value, default 1
        Returns:
        a new instance of HingeLoss
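
        To make the margin and weight parameters concrete, here is a plain-Java sketch of the standard hinge loss, max(0, margin - prediction * label), averaged over the batch. It assumes labels of -1 or +1; the class HingeLossSketch is hypothetical and may differ from the library's implementation in its details.

```java
final class HingeLossSketch {
    /**
     * Mean hinge loss over a batch, scaled by weight.
     * Labels are assumed to be -1 or +1; predictions are raw scores.
     */
    public static double loss(double[] labels, double[] predictions, int margin, double weight) {
        if (labels.length != predictions.length) {
            throw new IllegalArgumentException("labels and predictions differ in length");
        }
        double total = 0.0;
        for (int i = 0; i < labels.length; i++) {
            // Items classified correctly with at least the margin contribute zero.
            total += Math.max(0.0, margin - predictions[i] * labels[i]);
        }
        return weight * total / labels.length;
    }
}
```

        For labels {1, -1}, predictions {2, 0.5} and margin 1, the per-item losses are 0 and 1.5, so the mean is 0.75.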
      • l1WeightedDecay

        public static L1WeightDecay l1WeightedDecay​(NDList parameters)
        Returns a new instance of L1WeightDecay with default weight and name.
        Parameters:
        parameters - holds the model weights that will be penalized
        Returns:
        a new instance of L1WeightDecay
      • l1WeightedDecay

        public static L1WeightDecay l1WeightedDecay​(java.lang.String name,
                                                    NDList parameters)
        Returns a new instance of L1WeightDecay with default weight.
        Parameters:
        name - the name of the weight decay
        parameters - holds the model weights that will be penalized
        Returns:
        a new instance of L1WeightDecay
      • l1WeightedDecay

        public static L1WeightDecay l1WeightedDecay​(java.lang.String name,
                                                    float weight,
                                                    NDList parameters)
        Returns a new instance of L1WeightDecay.
        Parameters:
        name - the name of the weight decay
        weight - the weight to apply on weight decay value, default 1
        parameters - holds the model weights that will be penalized
        Returns:
        a new instance of L1WeightDecay
      • l2WeightedDecay

        public static L2WeightDecay l2WeightedDecay​(NDList parameters)
        Returns a new instance of L2WeightDecay with default weight and name.
        Parameters:
        parameters - holds the model weights that will be penalized
        Returns:
        a new instance of L2WeightDecay
      • l2WeightedDecay

        public static L2WeightDecay l2WeightedDecay​(java.lang.String name,
                                                    NDList parameters)
        Returns a new instance of L2WeightDecay with default weight.
        Parameters:
        name - the name of the weight decay
        parameters - holds the model weights that will be penalized
        Returns:
        a new instance of L2WeightDecay
      • l2WeightedDecay

        public static L2WeightDecay l2WeightedDecay​(java.lang.String name,
                                                    float weight,
                                                    NDList parameters)
        Returns a new instance of L2WeightDecay.
        Parameters:
        name - the name of the weight decay
        weight - the weight to apply on weight decay value, default 1
        parameters - holds the model weights that will be penalized
        Returns:
        a new instance of L2WeightDecay
      • elasticNetWeightedDecay

        public static ElasticNetWeightDecay elasticNetWeightedDecay​(java.lang.String name,
                                                                    NDList parameters)
        Returns a new instance of ElasticNetWeightDecay with default weight.
        Parameters:
        name - the name of the weight decay
        parameters - holds the model weights that will be penalized
        Returns:
        a new instance of ElasticNetWeightDecay
      • elasticNetWeightedDecay

        public static ElasticNetWeightDecay elasticNetWeightedDecay​(java.lang.String name,
                                                                    float weight,
                                                                    NDList parameters)
        Returns a new instance of ElasticNetWeightDecay.
        Parameters:
        name - the name of the weight decay
        weight - the weight to apply on weight decay values, default 1
        parameters - holds the model weights that will be penalized
        Returns:
        a new instance of ElasticNetWeightDecay
      • elasticNetWeightedDecay

        public static ElasticNetWeightDecay elasticNetWeightedDecay​(java.lang.String name,
                                                                    float weight1,
                                                                    float weight2,
                                                                    NDList parameters)
        Returns a new instance of ElasticNetWeightDecay.
        Parameters:
        name - the name of the weight decay
        weight1 - the weight to apply on weight decay L1 value, default 1
        weight2 - the weight to apply on weight decay L2 value, default 1
        parameters - holds the model weights that will be penalized
        Returns:
        a new instance of ElasticNetWeightDecay
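
        The weight-decay factories above penalize the model weights rather than comparing labels with predictions. A plain-Java sketch of the conventional elastic-net penalty, weight1 * sum|w| + weight2 * sum(w^2), follows. The class ElasticNetSketch is hypothetical, the double[][] argument merely stands in for a list of parameter arrays, and the exact reduction used by the library is not stated on this page.

```java
final class ElasticNetSketch {
    /** Sum of absolute values over all parameter arrays (L1 term). */
    public static double l1Penalty(double[][] parameters) {
        double total = 0.0;
        for (double[] array : parameters) {
            for (double w : array) {
                total += Math.abs(w);
            }
        }
        return total;
    }

    /** Sum of squares over all parameter arrays (L2 term). */
    public static double l2Penalty(double[][] parameters) {
        double total = 0.0;
        for (double[] array : parameters) {
            for (double w : array) {
                total += w * w;
            }
        }
        return total;
    }

    /** Elastic-net penalty: weighted combination of the L1 and L2 terms. */
    public static double elasticNet(double[][] parameters, double weight1, double weight2) {
        return weight1 * l1Penalty(parameters) + weight2 * l2Penalty(parameters);
    }
}
```

        Setting weight2 to 0 leaves only the L1 term, and setting weight1 to 0 leaves only the L2 term.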
      • addAccumulator

        public void addAccumulator​(java.lang.String key)
        Adds an accumulator for the results of the evaluation with the given key.
        Specified by:
        addAccumulator in class Evaluator
        Parameters:
        key - the key for the new accumulator
      • updateAccumulator

        public void updateAccumulator​(java.lang.String key,
                                      NDList labels,
                                      NDList predictions)
        Updates the evaluator with the given key based on an NDList of labels and predictions.

        This is a synchronized operation. You should only call it at the end of a batch or epoch.

        Specified by:
        updateAccumulator in class Evaluator
        Parameters:
        key - the key of the accumulator to update
        labels - an NDList of labels
        predictions - an NDList of predictions
      • resetAccumulator

        public void resetAccumulator​(java.lang.String key)
        Resets the evaluator value with the given key.
        Specified by:
        resetAccumulator in class Evaluator
        Parameters:
        key - the key of the accumulator to reset
      • getAccumulator

        public float getAccumulator​(java.lang.String key)
        Returns the accumulated evaluator value.
        Specified by:
        getAccumulator in class Evaluator
        Parameters:
        key - the key of the accumulator to get
        Returns:
        the accumulated value
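
        The four accumulator methods above follow an add, update, reset, get life cycle keyed by a string. The plain-Java sketch below mimics that life cycle with a running mean weighted by batch size. It is only an illustration: the class KeyedLossAccumulator is hypothetical, its updateAccumulator takes a precomputed mean batch loss and a batch size instead of NDList arguments, and returning NaN before any update is a choice of this sketch.

```java
import java.util.HashMap;
import java.util.Map;

final class KeyedLossAccumulator {
    /** Running totals for one key. */
    private static final class Totals {
        double weightedSum;
        long instances;
    }

    private final Map<String, Totals> totals = new HashMap<>();

    /** Registers a fresh accumulator under the given key. */
    public synchronized void addAccumulator(String key) {
        totals.put(key, new Totals());
    }

    /** Folds one batch into the running totals, weighting the mean loss by batch size. */
    public synchronized void updateAccumulator(String key, double meanBatchLoss, long batchSize) {
        Totals t = require(key);
        t.weightedSum += meanBatchLoss * batchSize;
        t.instances += batchSize;
    }

    /** Clears the totals for the key, for example at the start of an epoch. */
    public synchronized void resetAccumulator(String key) {
        require(key);
        totals.put(key, new Totals());
    }

    /** Returns the mean loss per instance, or NaN if nothing has been accumulated. */
    public synchronized double getAccumulator(String key) {
        Totals t = require(key);
        return t.instances == 0 ? Double.NaN : t.weightedSum / t.instances;
    }

    private Totals require(String key) {
        Totals t = totals.get(key);
        if (t == null) {
            throw new IllegalArgumentException("No accumulator for key: " + key);
        }
        return t;
    }
}
```

        Separate keys, such as one for training and one for validation, keep their totals independent of each other.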