public abstract class Loss extends Evaluator
Although all evaluators can measure a model's performance, not all of them are suited for use by an optimizer. Loss functions are usually non-negative, with a larger loss representing worse performance. They are also real-valued so that models can be compared accurately.
When creating a loss function, you should avoid having the loss depend on the batch size. For example, if you compute a loss per item in a batch and sum those losses, your loss would be numItemsInBatch*avgLoss. Instead, take the mean of those losses to cancel out the batch-size factor. Otherwise, tuning the learning rate becomes difficult, since any change in batch size also rescales the loss; with a variable batch size it is even harder.
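To see why the mean keeps the loss scale stable, here is a plain-Java sketch (not DJL code) comparing a summed loss with a mean loss over a batch of per-item errors:

```java
// Sketch (not the DJL API): why a per-batch mean keeps the loss scale
// independent of batch size, while a sum scales linearly with it.
public class BatchLossSketch {

    // Summed squared-error loss: grows with the number of items in the batch.
    static float sumLoss(float[] errors) {
        float total = 0f;
        for (float e : errors) {
            total += e * e;
        }
        return total;
    }

    // Mean squared-error loss: stays on the same scale for any batch size.
    static float meanLoss(float[] errors) {
        return sumLoss(errors) / errors.length;
    }
}
```

Doubling the batch doubles the summed loss but leaves the mean unchanged, so a learning rate tuned against the mean keeps working when the batch size changes.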
For more details about the class internals, see Evaluator.
Fields inherited from class Evaluator: totalInstances
| Constructor and Description |
| --- |
| `Loss(java.lang.String name)` Creates a loss with the given display name. |
| Modifier and Type | Method and Description |
| --- | --- |
| `void` | `addAccumulator(java.lang.String key)` Adds an accumulator for the results of the evaluation with the given key. |
| `static ElasticNetWeightDecay` | `elasticNetWeightedDecay(NDList parameters)` Returns a new instance of `ElasticNetWeightDecay` with default weight and name. |
| `static ElasticNetWeightDecay` | `elasticNetWeightedDecay(java.lang.String name, float weight1, float weight2, NDList parameters)` Returns a new instance of `ElasticNetWeightDecay`. |
| `static ElasticNetWeightDecay` | `elasticNetWeightedDecay(java.lang.String name, float weight, NDList parameters)` Returns a new instance of `ElasticNetWeightDecay`. |
| `static ElasticNetWeightDecay` | `elasticNetWeightedDecay(java.lang.String name, NDList parameters)` Returns a new instance of `ElasticNetWeightDecay` with default weight. |
| `float` | `getAccumulator(java.lang.String key)` Returns the accumulated evaluator value. |
| `static HingeLoss` | `hingeLoss()` Returns a new instance of `HingeLoss` with default arguments. |
| `static HingeLoss` | `hingeLoss(java.lang.String name)` Returns a new instance of `HingeLoss` with default arguments. |
| `static HingeLoss` | `hingeLoss(java.lang.String name, int margin, float weight)` Returns a new instance of `HingeLoss` with the given arguments. |
| `static L1Loss` | `l1Loss()` Returns a new instance of `L1Loss` with default weight and batch axis. |
| `static L1Loss` | `l1Loss(java.lang.String name)` Returns a new instance of `L1Loss` with default weight and batch axis. |
| `static L1Loss` | `l1Loss(java.lang.String name, float weight)` Returns a new instance of `L1Loss` with the given weight and batch axis. |
| `static L1WeightDecay` | `l1WeightedDecay(NDList parameters)` Returns a new instance of `L1WeightDecay` with default weight and name. |
| `static L1WeightDecay` | `l1WeightedDecay(java.lang.String name, float weight, NDList parameters)` Returns a new instance of `L1WeightDecay`. |
| `static L1WeightDecay` | `l1WeightedDecay(java.lang.String name, NDList parameters)` Returns a new instance of `L1WeightDecay` with default weight. |
| `static L2Loss` | `l2Loss()` Returns a new instance of `L2Loss` with default weight and batch axis. |
| `static L2Loss` | `l2Loss(java.lang.String name)` Returns a new instance of `L2Loss` with default weight and batch axis. |
| `static L2Loss` | `l2Loss(java.lang.String name, float weight)` Returns a new instance of `L2Loss` with the given weight and batch axis. |
| `static L2WeightDecay` | `l2WeightedDecay(NDList parameters)` Returns a new instance of `L2WeightDecay` with default weight and name. |
| `static L2WeightDecay` | `l2WeightedDecay(java.lang.String name, float weight, NDList parameters)` Returns a new instance of `L2WeightDecay`. |
| `static L2WeightDecay` | `l2WeightedDecay(java.lang.String name, NDList parameters)` Returns a new instance of `L2WeightDecay` with default weight. |
| `static MaskedSoftmaxCrossEntropyLoss` | `maskedSoftmaxCrossEntropyLoss()` Returns a new instance of `MaskedSoftmaxCrossEntropyLoss` with default arguments. |
| `static MaskedSoftmaxCrossEntropyLoss` | `maskedSoftmaxCrossEntropyLoss(java.lang.String name)` Returns a new instance of `MaskedSoftmaxCrossEntropyLoss` with default arguments. |
| `static MaskedSoftmaxCrossEntropyLoss` | `maskedSoftmaxCrossEntropyLoss(java.lang.String name, float weight, int classAxis, boolean sparseLabel, boolean fromLogit)` Returns a new instance of `MaskedSoftmaxCrossEntropyLoss` with the given arguments. |
| `void` | `resetAccumulator(java.lang.String key)` Resets the evaluator value with the given key. |
| `static SigmoidBinaryCrossEntropyLoss` | `sigmoidBinaryCrossEntropyLoss()` Returns a new instance of `SigmoidBinaryCrossEntropyLoss` with default arguments. |
| `static SigmoidBinaryCrossEntropyLoss` | `sigmoidBinaryCrossEntropyLoss(java.lang.String name)` Returns a new instance of `SigmoidBinaryCrossEntropyLoss` with default arguments. |
| `static SigmoidBinaryCrossEntropyLoss` | `sigmoidBinaryCrossEntropyLoss(java.lang.String name, float weight, boolean fromSigmoid)` Returns a new instance of `SigmoidBinaryCrossEntropyLoss` with the given arguments. |
| `static SoftmaxCrossEntropyLoss` | `softmaxCrossEntropyLoss()` Returns a new instance of `SoftmaxCrossEntropyLoss` with default arguments. |
| `static SoftmaxCrossEntropyLoss` | `softmaxCrossEntropyLoss(java.lang.String name)` Returns a new instance of `SoftmaxCrossEntropyLoss` with default arguments. |
| `static SoftmaxCrossEntropyLoss` | `softmaxCrossEntropyLoss(java.lang.String name, float weight, int classAxis, boolean sparseLabel, boolean fromLogit)` Returns a new instance of `SoftmaxCrossEntropyLoss` with the given arguments. |
| `void` | `updateAccumulator(java.lang.String key, NDList labels, NDList predictions)` Updates the evaluator with the given key based on an `NDList` of labels and predictions. |

Methods inherited from class Evaluator: checkLabelShapes, checkLabelShapes, evaluate, getName
public Loss(java.lang.String name)
Parameters:
name - the display name of the Loss

public static L1Loss l1Loss()
Returns a new instance of L1Loss with default weight and batch axis.
Returns:
a new instance of L1Loss

public static L1Loss l1Loss(java.lang.String name)
Returns a new instance of L1Loss with default weight and batch axis.
Parameters:
name - the name of the loss
Returns:
a new instance of L1Loss

public static L1Loss l1Loss(java.lang.String name, float weight)
Returns a new instance of L1Loss with the given weight and batch axis.
Parameters:
name - the name of the loss
weight - the weight to apply on the loss value, default 1
Returns:
a new instance of L1Loss

public static L2Loss l2Loss()
Returns a new instance of L2Loss with default weight and batch axis.
Returns:
a new instance of L2Loss

public static L2Loss l2Loss(java.lang.String name)
Returns a new instance of L2Loss with default weight and batch axis.
Parameters:
name - the name of the loss
Returns:
a new instance of L2Loss

public static L2Loss l2Loss(java.lang.String name, float weight)
Returns a new instance of L2Loss with the given weight and batch axis.
Parameters:
name - the name of the loss
weight - the weight to apply on the loss value, default 1
Returns:
a new instance of L2Loss
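As a rough sketch of what L1Loss and L2Loss measure (plain Java, not the DJL implementation; DJL's exact reduction and scaling may differ), the elementwise forms are weight * mean(|label - pred|) and weight * mean((label - pred)^2):

```java
// Sketch (plain Java, not the DJL implementation): the elementwise
// quantities behind L1Loss and L2Loss, averaged over a batch.
// The weight factor mirrors the `weight` parameter documented above.
public class L1L2Sketch {

    // L1: weight * mean(|label - pred|)
    static float l1(float[] labels, float[] preds, float weight) {
        float total = 0f;
        for (int i = 0; i < labels.length; i++) {
            total += Math.abs(labels[i] - preds[i]);
        }
        return weight * total / labels.length;
    }

    // L2: weight * mean((label - pred)^2)
    static float l2(float[] labels, float[] preds, float weight) {
        float total = 0f;
        for (int i = 0; i < labels.length; i++) {
            float d = labels[i] - preds[i];
            total += d * d;
        }
        return weight * total / labels.length;
    }
}
```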
public static SigmoidBinaryCrossEntropyLoss sigmoidBinaryCrossEntropyLoss()
Returns a new instance of SigmoidBinaryCrossEntropyLoss with default arguments.
Returns:
a new instance of SigmoidBinaryCrossEntropyLoss

public static SigmoidBinaryCrossEntropyLoss sigmoidBinaryCrossEntropyLoss(java.lang.String name)
Returns a new instance of SigmoidBinaryCrossEntropyLoss with default arguments.
Parameters:
name - the name of the loss
Returns:
a new instance of SigmoidBinaryCrossEntropyLoss

public static SigmoidBinaryCrossEntropyLoss sigmoidBinaryCrossEntropyLoss(java.lang.String name, float weight, boolean fromSigmoid)
Returns a new instance of SigmoidBinaryCrossEntropyLoss with the given arguments.
Parameters:
name - the name of the loss
weight - the weight to apply on the loss value, default 1
fromSigmoid - whether the input is from the output of sigmoid, default false
Returns:
a new instance of SigmoidBinaryCrossEntropyLoss
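The fromSigmoid flag decides whether the sigmoid is applied inside the loss. A plain-Java sketch of the idea (not DJL's implementation, which uses a numerically stable formulation for raw logits):

```java
// Sketch (plain Java): what the fromSigmoid flag changes. With
// fromSigmoid=false the input is a raw logit and the sigmoid is applied
// inside the loss; with fromSigmoid=true the input is already a probability.
public class BceSketch {

    static double sigmoid(double x) {
        return 1.0 / (1.0 + Math.exp(-x));
    }

    // Binary cross-entropy for one (label, input) pair, label in {0, 1}.
    static double bce(double label, double input, boolean fromSigmoid) {
        double p = fromSigmoid ? input : sigmoid(input);
        return -(label * Math.log(p) + (1 - label) * Math.log(1 - p));
    }
}
```

Feeding a logit with fromSigmoid=false gives the same value as feeding its sigmoid with fromSigmoid=true; mixing the two conventions silently produces a wrong loss.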
public static SoftmaxCrossEntropyLoss softmaxCrossEntropyLoss()
Returns a new instance of SoftmaxCrossEntropyLoss with default arguments.
Returns:
a new instance of SoftmaxCrossEntropyLoss

public static SoftmaxCrossEntropyLoss softmaxCrossEntropyLoss(java.lang.String name)
Returns a new instance of SoftmaxCrossEntropyLoss with default arguments.
Parameters:
name - the name of the loss
Returns:
a new instance of SoftmaxCrossEntropyLoss

public static SoftmaxCrossEntropyLoss softmaxCrossEntropyLoss(java.lang.String name, float weight, int classAxis, boolean sparseLabel, boolean fromLogit)
Returns a new instance of SoftmaxCrossEntropyLoss with the given arguments.
Parameters:
name - the name of the loss
weight - the weight to apply on the loss value, default 1
classAxis - the axis that represents the class probabilities, default -1
sparseLabel - whether labels are an integer array or probabilities, default true
fromLogit - whether predictions are log probabilities or un-normalized numbers
Returns:
a new instance of SoftmaxCrossEntropyLoss
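With sparseLabel set to true, the label is a class index and the loss is the negative log of the softmax probability at that index. A plain-Java sketch for a single example (not the DJL implementation):

```java
// Sketch (plain Java): softmax cross-entropy with a sparse integer label.
public class SoftmaxCeSketch {

    // Softmax over one row of class scores (logits), shifted by the max
    // for numerical stability.
    static double[] softmax(double[] logits) {
        double max = Double.NEGATIVE_INFINITY;
        for (double v : logits) {
            max = Math.max(max, v);
        }
        double sum = 0.0;
        double[] out = new double[logits.length];
        for (int i = 0; i < logits.length; i++) {
            out[i] = Math.exp(logits[i] - max);
            sum += out[i];
        }
        for (int i = 0; i < out.length; i++) {
            out[i] /= sum;
        }
        return out;
    }

    // Cross-entropy with a sparse (integer) label: -log p[label].
    static double sparseCe(double[] logits, int label) {
        return -Math.log(softmax(logits)[label]);
    }
}
```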
public static MaskedSoftmaxCrossEntropyLoss maskedSoftmaxCrossEntropyLoss()
Returns a new instance of MaskedSoftmaxCrossEntropyLoss with default arguments.
Returns:
a new instance of MaskedSoftmaxCrossEntropyLoss

public static MaskedSoftmaxCrossEntropyLoss maskedSoftmaxCrossEntropyLoss(java.lang.String name)
Returns a new instance of MaskedSoftmaxCrossEntropyLoss with default arguments.
Parameters:
name - the name of the loss
Returns:
a new instance of MaskedSoftmaxCrossEntropyLoss

public static MaskedSoftmaxCrossEntropyLoss maskedSoftmaxCrossEntropyLoss(java.lang.String name, float weight, int classAxis, boolean sparseLabel, boolean fromLogit)
Returns a new instance of MaskedSoftmaxCrossEntropyLoss with the given arguments.
Parameters:
name - the name of the loss
weight - the weight to apply on the loss value, default 1
classAxis - the axis that represents the class probabilities, default -1
sparseLabel - whether labels are an integer array or probabilities, default true
fromLogit - whether predictions are log probabilities or un-normalized numbers
Returns:
a new instance of MaskedSoftmaxCrossEntropyLoss
public static HingeLoss hingeLoss()
Returns a new instance of HingeLoss with default arguments.
Returns:
a new instance of HingeLoss

public static HingeLoss hingeLoss(java.lang.String name)
Returns a new instance of HingeLoss with default arguments.
Parameters:
name - the name of the loss
Returns:
a new instance of HingeLoss

public static HingeLoss hingeLoss(java.lang.String name, int margin, float weight)
Returns a new instance of HingeLoss with the given arguments.
Parameters:
name - the name of the loss
margin - the margin in hinge loss, default 1.0
weight - the weight to apply on the loss value, default 1
Returns:
a new instance of HingeLoss
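Hinge loss conventionally assumes labels in {-1, +1}. A common formulation, sketched here in plain Java (DJL's implementation may differ in reduction details), is weight * max(0, margin - label * prediction):

```java
// Sketch (plain Java, not the DJL implementation): hinge loss for one
// example. The margin and weight mirror the parameters documented above.
public class HingeSketch {

    // weight * max(0, margin - label * prediction), label in {-1, +1}.
    static float hinge(float label, float prediction, float margin, float weight) {
        return weight * Math.max(0f, margin - label * prediction);
    }
}
```

A confident, correct prediction (label * prediction >= margin) incurs zero loss; predictions inside the margin or on the wrong side are penalized linearly.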
public static L1WeightDecay l1WeightedDecay(NDList parameters)
Returns a new instance of L1WeightDecay with default weight and name.
Parameters:
parameters - holds the model weights that will be penalized
Returns:
a new instance of L1WeightDecay

public static L1WeightDecay l1WeightedDecay(java.lang.String name, NDList parameters)
Returns a new instance of L1WeightDecay with default weight.
Parameters:
name - the name of the weight decay
parameters - holds the model weights that will be penalized
Returns:
a new instance of L1WeightDecay

public static L1WeightDecay l1WeightedDecay(java.lang.String name, float weight, NDList parameters)
Returns a new instance of L1WeightDecay.
Parameters:
name - the name of the weight decay
weight - the weight to apply on the weight decay value, default 1
parameters - holds the model weights that will be penalized
Returns:
a new instance of L1WeightDecay
public static L2WeightDecay l2WeightedDecay(NDList parameters)
Returns a new instance of L2WeightDecay with default weight and name.
Parameters:
parameters - holds the model weights that will be penalized
Returns:
a new instance of L2WeightDecay

public static L2WeightDecay l2WeightedDecay(java.lang.String name, NDList parameters)
Returns a new instance of L2WeightDecay with default weight.
Parameters:
name - the name of the weight decay
parameters - holds the model weights that will be penalized
Returns:
a new instance of L2WeightDecay

public static L2WeightDecay l2WeightedDecay(java.lang.String name, float weight, NDList parameters)
Returns a new instance of L2WeightDecay.
Parameters:
name - the name of the weight decay
weight - the weight to apply on the weight decay value, default 1
parameters - holds the model weights that will be penalized
Returns:
a new instance of L2WeightDecay
public static ElasticNetWeightDecay elasticNetWeightedDecay(NDList parameters)
Returns a new instance of ElasticNetWeightDecay with default weight and name.
Parameters:
parameters - holds the model weights that will be penalized
Returns:
a new instance of ElasticNetWeightDecay

public static ElasticNetWeightDecay elasticNetWeightedDecay(java.lang.String name, NDList parameters)
Returns a new instance of ElasticNetWeightDecay with default weight.
Parameters:
name - the name of the weight decay
parameters - holds the model weights that will be penalized
Returns:
a new instance of ElasticNetWeightDecay

public static ElasticNetWeightDecay elasticNetWeightedDecay(java.lang.String name, float weight, NDList parameters)
Returns a new instance of ElasticNetWeightDecay.
Parameters:
name - the name of the weight decay
weight - the weight to apply on the weight decay values, default 1
parameters - holds the model weights that will be penalized
Returns:
a new instance of ElasticNetWeightDecay

public static ElasticNetWeightDecay elasticNetWeightedDecay(java.lang.String name, float weight1, float weight2, NDList parameters)
Returns a new instance of ElasticNetWeightDecay.
Parameters:
name - the name of the weight decay
weight1 - the weight to apply on the weight decay L1 value, default 1
weight2 - the weight to apply on the weight decay L2 value, default 1
parameters - holds the model weights that will be penalized
Returns:
a new instance of ElasticNetWeightDecay
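Elastic net combines the L1 and L2 penalties over the model weights. A plain-Java sketch of the penalty term for one parameter array (not the DJL implementation; the library applies this across every tensor in the NDList, and reduction details may differ):

```java
// Sketch (plain Java, not the DJL implementation): the elastic net
// penalty weight1 * sum(|w|) + weight2 * sum(w^2), mirroring the
// weight1/weight2 parameters documented above.
public class ElasticNetSketch {

    static float penalty(float[] weights, float weight1, float weight2) {
        float l1 = 0f;
        float l2 = 0f;
        for (float w : weights) {
            l1 += Math.abs(w); // L1 term: sum of absolute values
            l2 += w * w;       // L2 term: sum of squares
        }
        return weight1 * l1 + weight2 * l2;
    }
}
```

Setting weight2 to 0 recovers a pure L1 weight decay; setting weight1 to 0 recovers a pure L2 weight decay.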
public void addAccumulator(java.lang.String key)
Specified by:
addAccumulator in class Evaluator
Parameters:
key - the key for the new accumulator

public void updateAccumulator(java.lang.String key, NDList labels, NDList predictions)
Updates the evaluator with the given key based on an NDList of labels and predictions.
This is a synchronized operation. You should only call it at the end of a batch or epoch.
Specified by:
updateAccumulator in class Evaluator
Parameters:
key - the key of the accumulator to update
labels - an NDList of labels
predictions - an NDList of predictions

public void resetAccumulator(java.lang.String key)
Specified by:
resetAccumulator in class Evaluator
Parameters:
key - the key of the accumulator to reset

public float getAccumulator(java.lang.String key)
Specified by:
getAccumulator in class Evaluator
Parameters:
key - the key of the accumulator to get