Package ai.djl.nn.transformer
Class BertNextSentenceLoss
java.lang.Object
ai.djl.training.evaluator.Evaluator
ai.djl.training.loss.Loss
ai.djl.nn.transformer.BertNextSentenceLoss
Calculates the loss for the BERT next sentence prediction (NSP) pretraining task.
Field Summary
Fields inherited from class ai.djl.training.evaluator.Evaluator
totalInstances
Constructor Summary
Constructors
Constructor | Description
BertNextSentenceLoss(int labelIdx, int nextSentencePredictionIdx) | Creates a new BERT next sentence loss.
Method Summary
Methods inherited from class ai.djl.training.loss.Loss
addAccumulator, elasticNetWeightedDecay, elasticNetWeightedDecay, elasticNetWeightedDecay, elasticNetWeightedDecay, getAccumulator, hingeLoss, hingeLoss, hingeLoss, l1Loss, l1Loss, l1Loss, l1WeightedDecay, l1WeightedDecay, l1WeightedDecay, l2Loss, l2Loss, l2Loss, l2WeightedDecay, l2WeightedDecay, l2WeightedDecay, maskedSoftmaxCrossEntropyLoss, maskedSoftmaxCrossEntropyLoss, maskedSoftmaxCrossEntropyLoss, quantileL1Loss, quantileL1Loss, resetAccumulator, sigmoidBinaryCrossEntropyLoss, sigmoidBinaryCrossEntropyLoss, sigmoidBinaryCrossEntropyLoss, softmaxCrossEntropyLoss, softmaxCrossEntropyLoss, softmaxCrossEntropyLoss, updateAccumulator, updateAccumulators
Methods inherited from class ai.djl.training.evaluator.Evaluator
checkLabelShapes, checkLabelShapes, getName
Constructor Details
BertNextSentenceLoss
public BertNextSentenceLoss(int labelIdx, int nextSentencePredictionIdx)
Creates a new BERT next sentence loss.
Parameters:
labelIdx - index of the next sentence labels
nextSentencePredictionIdx - index of the next sentence prediction in the BERT output
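The two constructor indices imply that both the labels and the BERT output arrive as lists with several entries, and that this loss reads only the next sentence entries. The plain-Java sketch below illustrates that selection with strings standing in for tensors. The class, the `select` helper, and the two-entry layout (masked language model data first, next sentence data second) are hypothetical assumptions for illustration, not DJL's API or its actual output order.

```java
import java.util.List;

/**
 * Hypothetical sketch (not DJL code): how two indices pick the next sentence
 * entries out of multi-entry label and prediction lists.
 */
class NspIndexSketch {

    /** Returns the entry at idx, failing fast when the index is out of range. */
    static <T> T select(List<T> entries, int idx) {
        if (idx < 0 || idx >= entries.size()) {
            throw new IllegalArgumentException(
                    "index " + idx + " out of range for " + entries.size() + " entries");
        }
        return entries.get(idx);
    }

    public static void main(String[] args) {
        // Assumed layout: entry 0 holds masked language model data, entry 1 holds NSP data.
        List<String> labels = List.of("mlmLabels", "nspLabels");
        List<String> predictions = List.of("mlmLogProbs", "nspLogProbs");
        int labelIdx = 1;
        int nextSentencePredictionIdx = 1;
        System.out.println("label entry: " + select(labels, labelIdx));
        System.out.println("prediction entry: " + select(predictions, nextSentencePredictionIdx));
    }
}
```

The range check only catches indices past the end of a list; an in-range but wrong index would still select the wrong entry, so the indices must match the model's actual output layout.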
Method Details
evaluate
Calculates the next sentence prediction loss between the labels and the predictions.
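A common way to score a two-class prediction such as next sentence prediction is the mean negative log-likelihood of the true class. The plain-Java sketch below shows that computation, assuming the next sentence output holds log-softmax scores of shape [batch][2] and the labels are class ids 0 or 1. These input conventions, the class name, and the method names are assumptions for illustration; this is not DJL's implementation of `evaluate`.

```java
/**
 * Hypothetical plain-Java sketch (not DJL code) of a next sentence prediction loss.
 * Assumes predictions are log-softmax scores of shape [batch][2] and labels are class ids.
 */
class NspLossSketch {

    /** Numerically stable log-softmax: scores[j] - log(sum_k exp(scores[k])). */
    static double[] logSoftmax(double[] scores) {
        double max = Double.NEGATIVE_INFINITY;
        for (double s : scores) {
            max = Math.max(max, s);
        }
        double sumExp = 0.0;
        for (double s : scores) {
            sumExp += Math.exp(s - max);
        }
        double logSumExp = max + Math.log(sumExp);
        double[] out = new double[scores.length];
        for (int j = 0; j < scores.length; j++) {
            out[j] = scores[j] - logSumExp;
        }
        return out;
    }

    /** Mean negative log-likelihood of the true class over the batch. */
    static double meanNegativeLogLikelihood(int[] labels, double[][] logProbs) {
        if (labels.length == 0 || labels.length != logProbs.length) {
            throw new IllegalArgumentException("labels and predictions must be non-empty and equally long");
        }
        double total = 0.0;
        for (int i = 0; i < labels.length; i++) {
            // Reading the true class's log-probability is equivalent to multiplying
            // by a one-hot label vector and summing over the classes.
            total += -logProbs[i][labels[i]];
        }
        return total / labels.length;
    }

    public static void main(String[] args) {
        // Raw scores for two sentence pairs (illustrative values), converted to log-softmax.
        double[][] logProbs = {
            logSoftmax(new double[] {2.0, -1.0}),
            logSoftmax(new double[] {0.5, 0.5})
        };
        int[] labels = {0, 1};
        System.out.println("mean NSP loss: " + meanNegativeLogLikelihood(labels, logProbs));
    }
}
```

Because log-softmax values are at most 0, each per-example term is non-negative: it approaches 0 when the model is confident in the true class and grows without bound as that class's probability approaches 0.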
accuracy
Calculates the fraction of correct predictions.
Parameters:
labels - the labels with the correct predictions
predictions - the BERT pretraining model output
Returns:
the fraction of correct predictions
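The fraction of correct predictions can be computed by taking, for each example, the class with the highest score and comparing it with the label. The plain-Java sketch below assumes per-class scores of shape [batch][numClasses] and integer class-id labels; the class and method names are hypothetical, and this is not DJL's implementation of `accuracy`.

```java
/**
 * Hypothetical plain-Java sketch (not DJL code) of next sentence prediction accuracy.
 * Assumes predictions are per-class scores of shape [batch][numClasses] and labels are class ids.
 */
class NspAccuracySketch {

    /** Index of the largest score; ties go to the lowest index. */
    static int argMax(double[] scores) {
        int best = 0;
        for (int j = 1; j < scores.length; j++) {
            if (scores[j] > scores[best]) {
                best = j;
            }
        }
        return best;
    }

    /** Fraction of examples whose highest-scoring class equals the label. */
    static double accuracy(int[] labels, double[][] predictions) {
        if (labels.length == 0 || labels.length != predictions.length) {
            throw new IllegalArgumentException("labels and predictions must be non-empty and equally long");
        }
        int correct = 0;
        for (int i = 0; i < labels.length; i++) {
            if (argMax(predictions[i]) == labels[i]) {
                correct++;
            }
        }
        return (double) correct / labels.length;
    }

    public static void main(String[] args) {
        // Log-softmax scores for three sentence pairs (illustrative values).
        double[][] predictions = {{-0.1, -2.3}, {-0.5, -0.9}, {-3.0, -0.05}};
        int[] labels = {0, 1, 1};
        System.out.println("accuracy: " + accuracy(labels, predictions));
    }
}
```

Since softmax and log are both monotonic, the argmax is the same whether the scores are raw logits, probabilities, or log-softmax values, so the sketch does not depend on which of these the model emits.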