Class BertMaskedLanguageModelLoss

java.lang.Object
ai.djl.training.evaluator.Evaluator
ai.djl.training.loss.Loss
ai.djl.nn.transformer.BertMaskedLanguageModelLoss

public class BertMaskedLanguageModelLoss extends Loss
The loss for the BERT masked language model (MLM) task.
  • Constructor Details

    • BertMaskedLanguageModelLoss

      public BertMaskedLanguageModelLoss(int labelIdx, int maskIdx, int logProbsIdx)
      Creates an MLM loss.
      Parameters:
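      labelIdx - index of the target token IDs within the labels NDList
      maskIdx - index of the mask within the labels NDList
      logProbsIdx - index of the log probabilities within the predictions NDList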
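The three constructor arguments are positions: two select arrays from the labels list and one selects an array from the predictions list. The sketch below illustrates that selection. It is not DJL code: plain Java lists stand in for NDList, and the layout (target IDs at 0 and mask at 1 in the labels, log probabilities at 0 in the predictions) is a hypothetical example.

```java
import java.util.ArrayList;
import java.util.List;

/**
 * Hypothetical illustration of how labelIdx, maskIdx and logProbsIdx pick inputs.
 * Plain Java lists stand in for DJL's NDList; the layout below is an assumption.
 */
final class MlmIndexSketch {
    // Values one might pass as new BertMaskedLanguageModelLoss(LABEL_IDX, MASK_IDX, LOG_PROBS_IDX).
    static final int LABEL_IDX = 0;     // target token IDs inside the labels list
    static final int MASK_IDX = 1;      // 0/1 mask inside the labels list
    static final int LOG_PROBS_IDX = 0; // log probabilities inside the predictions list

    /** Returns the entry stored at the given position of the list. */
    static Object select(List<Object> list, int idx) {
        return list.get(idx);
    }

    public static void main(String[] args) {
        List<Object> labels = new ArrayList<>();
        labels.add(new int[] {5, 7, 9});      // target token IDs for three positions
        labels.add(new float[] {1f, 0f, 1f}); // only positions 0 and 2 were masked

        List<Object> predictions = new ArrayList<>();
        predictions.add(new float[3][10]);    // placeholder log probabilities (all zeros), 10-token vocabulary

        int[] targetIds = (int[]) select(labels, LABEL_IDX);
        float[] mask = (float[]) select(labels, MASK_IDX);
        float[][] logProbs = (float[][]) select(predictions, LOG_PROBS_IDX);
        System.out.println(targetIds.length + " targets, " + mask.length + " mask entries, "
                + logProbs[0].length + " vocabulary entries per position");
    }
}
```

Passing indices rather than fixed positions lets the same loss be combined with other losses that read different entries from the same label and prediction lists.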
  • Method Details

    • evaluate

      public NDArray evaluate(NDList labels, NDList predictions)
      Calculates the masked language model loss between the labels and the predictions.
      Specified by:
      evaluate in class Evaluator
      Parameters:
      labels - the correct values
      predictions - the predicted values
      Returns:
      the evaluation result
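The following is a minimal sketch of a masked language model loss under a common formulation, assumed here rather than taken from DJL: the negative log probability of each target token, averaged over the masked positions only. Plain float arrays stand in for NDArray, the inputs are assumed to be already flattened to one row per token position, and the small epsilon guarding against an all-zero mask is an assumption of this sketch.

```java
/**
 * Sketch of a masked LM loss on flattened inputs (assumed shapes):
 * logProbs[N][D] log probabilities, targetIds[N] token IDs, mask[N] 0/1 weights.
 * Not the DJL implementation; plain arrays stand in for NDArray.
 */
final class MaskedLmLossSketch {
    private static final float EPSILON = 1e-5f; // assumed guard against division by zero

    static float maskedLmLoss(float[][] logProbs, int[] targetIds, float[] mask) {
        float numerator = 0f;
        float maskSum = 0f;
        for (int i = 0; i < targetIds.length; i++) {
            // Negative log probability the model assigned to the correct token at position i.
            float perPositionLoss = -logProbs[i][targetIds[i]];
            // Positions with mask == 0 contribute nothing to the sum.
            numerator += mask[i] * perPositionLoss;
            maskSum += mask[i];
        }
        // Normalize by the number of masked positions the model actually had to predict.
        return numerator / (maskSum + EPSILON);
    }

    public static void main(String[] args) {
        float[][] logProbs = {
            {(float) Math.log(0.7), (float) Math.log(0.2), (float) Math.log(0.1)},
            {(float) Math.log(0.1), (float) Math.log(0.8), (float) Math.log(0.1)},
            {(float) Math.log(0.3), (float) Math.log(0.3), (float) Math.log(0.4)}
        };
        int[] targetIds = {0, 2, 2};
        float[] mask = {1f, 0f, 1f}; // position 1 is not masked, so its wrong prediction is ignored
        System.out.println("loss = " + maskedLmLoss(logProbs, targetIds, mask));
    }
}
```

Averaging over masked positions, rather than over all positions, keeps the loss scale independent of how many tokens happen to be masked in a batch.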
    • accuracy

      public NDArray accuracy(NDList labels, NDList predictions)
      Calculates the percentage of correctly predicted masked tokens.
      Parameters:
      labels - the expected tokens and the mask
      predictions - the predictions of a BERT model
      Returns:
      the percentage of correctly predicted masked tokens
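A minimal sketch of masked-token accuracy, assuming the predicted token at each position is the arg max of its log probabilities and that only masked positions are counted. Plain arrays stand in for NDArray and this is not DJL's implementation; the result is a fraction in [0, 1] (multiply by 100 for a percentage), and returning 0 when nothing is masked is a choice made for this sketch.

```java
/**
 * Sketch of accuracy over masked positions on flattened inputs (assumed shapes):
 * logProbs[N][D], targetIds[N], mask[N] with 0/1 entries. Not the DJL implementation.
 */
final class MaskedLmAccuracySketch {
    /** Index of the largest entry in a row, i.e. the most likely token. */
    static int argMax(float[] row) {
        int best = 0;
        for (int j = 1; j < row.length; j++) {
            if (row[j] > row[best]) {
                best = j;
            }
        }
        return best;
    }

    static float maskedAccuracy(float[][] logProbs, int[] targetIds, float[] mask) {
        float correct = 0f;
        float maskSum = 0f;
        for (int i = 0; i < targetIds.length; i++) {
            // A hit only counts when the position is masked (mask[i] == 1).
            if (argMax(logProbs[i]) == targetIds[i]) {
                correct += mask[i];
            }
            maskSum += mask[i];
        }
        // Fraction of masked positions predicted correctly; 0 if no position is masked.
        return maskSum == 0f ? 0f : correct / maskSum;
    }

    public static void main(String[] args) {
        float[][] logProbs = {{-0.1f, -2f, -3f}, {-2f, -0.1f, -3f}, {-3f, -2f, -0.1f}};
        int[] targetIds = {0, 0, 2};
        float[] mask = {1f, 1f, 1f};
        System.out.println("accuracy = " + maskedAccuracy(logProbs, targetIds, mask));
    }
}
```

In the usage above, positions 0 and 2 are predicted correctly and position 1 is not, so two of the three masked tokens count as correct.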