Package ai.djl.nn.transformer
Class BertMaskedLanguageModelLoss
- java.lang.Object
  - ai.djl.training.evaluator.Evaluator
    - ai.djl.training.loss.Loss
      - ai.djl.nn.transformer.BertMaskedLanguageModelLoss
public class BertMaskedLanguageModelLoss extends Loss
The loss for the BERT masked language model (MLM) task.
Field Summary
Fields inherited from class ai.djl.training.evaluator.Evaluator
totalInstances
Constructor Summary
BertMaskedLanguageModelLoss(int labelIdx, int maskIdx, int logProbsIdx)
Creates an MLM loss.
Method Summary
NDArray accuracy(NDList labels, NDList predictions)
Calculates the percentage of correctly predicted masked tokens.
NDArray evaluate(NDList labels, NDList predictions)
Calculates the evaluation between the labels and the predictions.
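To make the two methods above more concrete, here is a minimal plain-Java sketch of what a masked language model loss and its accuracy typically compute. It does not use DJL and is not taken from this class's source. It assumes the loss is the mean negative log-probability of the target token over masked positions, and that accuracy is the fraction of masked positions where the highest-scoring token equals the target. The class name `MaskedLmSketch` and the array-based signatures are hypothetical stand-ins for the `NDList` entries selected by `logProbsIdx`, `labelIdx`, and `maskIdx`.

```java
/** Plain-Java illustration of a mask-weighted MLM loss and accuracy (hypothetical, not DJL code). */
final class MaskedLmSketch {

    private MaskedLmSketch() {}

    /**
     * Mean negative log-likelihood of the target tokens, counting masked positions only.
     * logProbs[i][v] is the log-probability of vocabulary id v at position i;
     * mask[i] is non-zero where position i was masked and must be predicted.
     */
    static double loss(double[][] logProbs, int[] targetIds, int[] mask) {
        double sum = 0.0;
        int count = 0;
        for (int i = 0; i < targetIds.length; i++) {
            if (mask[i] != 0) {
                sum -= logProbs[i][targetIds[i]];
                count++;
            }
        }
        // Assumption: return 0 when nothing is masked, to avoid dividing by zero.
        return count == 0 ? 0.0 : sum / count;
    }

    /** Fraction of masked positions whose arg-max prediction equals the target id. */
    static double accuracy(double[][] logProbs, int[] targetIds, int[] mask) {
        int correct = 0;
        int count = 0;
        for (int i = 0; i < targetIds.length; i++) {
            if (mask[i] == 0) {
                continue; // unmasked positions do not contribute
            }
            int best = 0;
            for (int v = 1; v < logProbs[i].length; v++) {
                if (logProbs[i][v] > logProbs[i][best]) {
                    best = v;
                }
            }
            if (best == targetIds[i]) {
                correct++;
            }
            count++;
        }
        return count == 0 ? 0.0 : (double) correct / count;
    }

    public static void main(String[] args) {
        // Three positions, vocabulary of size 3; only the first two positions are masked.
        double[][] logProbs = {
            {Math.log(0.7), Math.log(0.2), Math.log(0.1)},
            {Math.log(0.1), Math.log(0.8), Math.log(0.1)},
            {Math.log(0.3), Math.log(0.3), Math.log(0.4)},
        };
        int[] targetIds = {0, 2, 1};
        int[] mask = {1, 1, 0};

        System.out.println("loss     = " + loss(logProbs, targetIds, mask));
        System.out.println("accuracy = " + accuracy(logProbs, targetIds, mask));
    }
}
```

In this sketch the third position is ignored by both methods because its mask value is 0, which is the point of passing a mask alongside the labels. The actual class works on `NDArray` values and may differ in details such as normalization.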
Methods inherited from class ai.djl.training.loss.Loss
addAccumulator, elasticNetWeightedDecay, elasticNetWeightedDecay, elasticNetWeightedDecay, elasticNetWeightedDecay, getAccumulator, hingeLoss, hingeLoss, hingeLoss, l1Loss, l1Loss, l1Loss, l1WeightedDecay, l1WeightedDecay, l1WeightedDecay, l2Loss, l2Loss, l2Loss, l2WeightedDecay, l2WeightedDecay, l2WeightedDecay, maskedSoftmaxCrossEntropyLoss, maskedSoftmaxCrossEntropyLoss, maskedSoftmaxCrossEntropyLoss, resetAccumulator, sigmoidBinaryCrossEntropyLoss, sigmoidBinaryCrossEntropyLoss, sigmoidBinaryCrossEntropyLoss, softmaxCrossEntropyLoss, softmaxCrossEntropyLoss, softmaxCrossEntropyLoss, updateAccumulator
Methods inherited from class ai.djl.training.evaluator.Evaluator
checkLabelShapes, checkLabelShapes, getName