Package ai.djl.nn.transformer
Class BertMaskedLanguageModelLoss
java.lang.Object
  ai.djl.training.evaluator.Evaluator
    ai.djl.training.loss.Loss
      ai.djl.nn.transformer.BertMaskedLanguageModelLoss
The loss for the BERT masked language model (MLM) task.
Field Summary
Fields inherited from class ai.djl.training.evaluator.Evaluator
totalInstances
Constructor Summary
Constructors
BertMaskedLanguageModelLoss(int labelIdx, int maskIdx, int logProbsIdx)
    Creates an MLM loss.
Method Summary
Methods inherited from class ai.djl.training.loss.Loss
addAccumulator, elasticNetWeightedDecay, elasticNetWeightedDecay, elasticNetWeightedDecay, elasticNetWeightedDecay, getAccumulator, hingeLoss, hingeLoss, hingeLoss, l1Loss, l1Loss, l1Loss, l1WeightedDecay, l1WeightedDecay, l1WeightedDecay, l2Loss, l2Loss, l2Loss, l2WeightedDecay, l2WeightedDecay, l2WeightedDecay, maskedSoftmaxCrossEntropyLoss, maskedSoftmaxCrossEntropyLoss, maskedSoftmaxCrossEntropyLoss, quantileL1Loss, quantileL1Loss, resetAccumulator, sigmoidBinaryCrossEntropyLoss, sigmoidBinaryCrossEntropyLoss, sigmoidBinaryCrossEntropyLoss, softmaxCrossEntropyLoss, softmaxCrossEntropyLoss, softmaxCrossEntropyLoss, updateAccumulator, updateAccumulators
Methods inherited from class ai.djl.training.evaluator.Evaluator
checkLabelShapes, checkLabelShapes, getName
Constructor Details
BertMaskedLanguageModelLoss
public BertMaskedLanguageModelLoss(int labelIdx, int maskIdx, int logProbsIdx)
Creates an MLM loss.
Parameters:
labelIdx - index of labels
maskIdx - index of mask
logProbsIdx - index of log probs
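As a rough illustration of what the three indices refer to, the sketch below assumes that `labelIdx` and `maskIdx` select entries from the label list (which, per the `accuracy` documentation on this page, holds the expected tokens and the mask) and that `logProbsIdx` selects the log-probability array from the prediction list. Plain Java arrays in a `List<Object>` stand in for DJL's `NDList`/`NDArray`; the class `MlmIndexRouting`, its nested `MlmInputs`, and the method `select` are hypothetical names, not part of DJL.

```java
import java.util.ArrayList;
import java.util.List;

// Hypothetical sketch: shows how three integer indices could pick the MLM
// inputs out of a label list and a prediction list. Plain arrays stand in for
// DJL NDArrays; this is NOT the DJL implementation.
final class MlmIndexRouting {

    /** The three inputs an MLM loss needs, after selection by index. */
    static final class MlmInputs {
        final int[] labels;        // expected token id per position
        final int[] mask;          // 1 where the position was masked, else 0
        final double[][] logProbs; // log-probabilities, [position][vocabulary id]

        MlmInputs(int[] labels, int[] mask, double[][] logProbs) {
            this.labels = labels;
            this.mask = mask;
            this.logProbs = logProbs;
        }
    }

    /** Assumption: labelIdx and maskIdx index the label list, logProbsIdx the prediction list. */
    static MlmInputs select(List<Object> labelList, List<Object> predictionList,
                            int labelIdx, int maskIdx, int logProbsIdx) {
        int[] labels = (int[]) labelList.get(labelIdx);
        int[] mask = (int[]) labelList.get(maskIdx);
        double[][] logProbs = (double[][]) predictionList.get(logProbsIdx);
        return new MlmInputs(labels, mask, logProbs);
    }

    public static void main(String[] args) {
        // Label list laid out as [tokens, mask]; prediction list holds the log probs.
        List<Object> labelList = new ArrayList<>();
        labelList.add(new int[] {3, 1});
        labelList.add(new int[] {1, 0});
        List<Object> predictionList = new ArrayList<>();
        predictionList.add(new double[][] {{-1.0, -2.0}, {-0.5, -0.7}});

        MlmInputs in = select(labelList, predictionList, 0, 1, 0);
        System.out.println("first label = " + in.labels[0] + ", second mask = " + in.mask[1]);
    }
}
```

Under these assumptions, a label list laid out as `[tokens, mask]` with the log probabilities as the first prediction would correspond to `new BertMaskedLanguageModelLoss(0, 1, 0)`.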
Method Details
evaluate
Calculates the evaluation between the labels and the predictions.
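This page does not spell out the formula behind `evaluate`. A common choice for an MLM loss is the negative log-likelihood of the expected token, averaged over the masked positions only; the sketch below assumes that reading. It also assumes `logProbs[i][v]` is the log-probability of vocabulary id `v` at position `i`, `labels[i]` is the expected token id, and `mask[i]` is 1 for scored (masked) positions and 0 otherwise. `MaskedNll` and `loss` are hypothetical names; plain arrays stand in for DJL tensors.

```java
// Hypothetical sketch of a masked negative log-likelihood (a common MLM loss).
// Assumed layout: logProbs[i][v] = log-probability of vocabulary id v at
// position i; labels[i] = expected token id; mask[i] = 1 if position i was
// masked (and is scored), else 0. NOT the DJL implementation.
final class MaskedNll {

    private MaskedNll() {}

    static double loss(double[][] logProbs, int[] labels, int[] mask) {
        double logLikelihood = 0.0;
        int masked = 0;
        for (int i = 0; i < labels.length; i++) {
            if (mask[i] == 0) {
                continue; // unmasked positions do not contribute
            }
            logLikelihood += logProbs[i][labels[i]]; // log p(expected token)
            masked++;
        }
        // Average over masked positions; return 0 if nothing was masked.
        return masked == 0 ? 0.0 : -logLikelihood / masked;
    }

    public static void main(String[] args) {
        double[][] logProbs = {
            {Math.log(0.5), Math.log(0.5)},
            {Math.log(0.25), Math.log(0.75)},
            {Math.log(0.9), Math.log(0.1)},
        };
        int[] labels = {0, 1, 0};
        int[] mask = {1, 1, 0}; // third position is not scored
        System.out.println(loss(logProbs, labels, mask));
    }
}
```

The example prints approximately 0.49, the mean of -ln 0.5 and -ln 0.75; the third position is ignored because its mask entry is 0.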
accuracy
Calculates the percentage of correctly predicted masked tokens.
Parameters:
labels - expected tokens and mask
predictions - predictions of a BERT model
Returns:
the percentage of correctly predicted masked tokens
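A minimal sketch of masked-token accuracy follows, assuming the predicted token at each position is the index of the largest log-probability (argmax) and that only positions whose mask entry is 1 are counted. It returns a fraction between 0 and 1; whether the library reports this as a fraction or scaled to 0 to 100 is not stated on this page. `MaskedAccuracy`, `argmax`, and `accuracy` here are hypothetical plain-Java stand-ins, not DJL calls.

```java
// Hypothetical sketch of masked-token accuracy. Assumed layout:
// logProbs[i][v] = log-probability of vocabulary id v at position i;
// labels[i] = expected token id; mask[i] = 1 if position i is scored.
// NOT the DJL implementation.
final class MaskedAccuracy {

    private MaskedAccuracy() {}

    /** Index of the largest entry in a row, i.e. the predicted token id. */
    static int argmax(double[] row) {
        int best = 0;
        for (int v = 1; v < row.length; v++) {
            if (row[v] > row[best]) {
                best = v;
            }
        }
        return best;
    }

    /** Fraction of masked positions whose predicted token equals the label. */
    static double accuracy(double[][] logProbs, int[] labels, int[] mask) {
        int correct = 0;
        int total = 0;
        for (int i = 0; i < labels.length; i++) {
            if (mask[i] == 0) {
                continue; // only masked positions are scored
            }
            total++;
            if (argmax(logProbs[i]) == labels[i]) {
                correct++;
            }
        }
        return total == 0 ? 0.0 : (double) correct / total;
    }

    public static void main(String[] args) {
        double[][] logProbs = {{-0.1, -2.0}, {-3.0, -0.2}, {-0.5, -1.0}};
        int[] labels = {0, 0, 1};
        int[] mask = {1, 1, 0};
        System.out.println(accuracy(logProbs, labels, mask));
    }
}
```

In the example, the first masked position is predicted correctly and the second is not, while the third is unmasked and ignored, so the result is 0.5.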