Class BertMaskedLanguageModelLoss


  • public class BertMaskedLanguageModelLoss
    extends Loss
    The loss for the BERT masked language model (MLM) task.
    • Constructor Detail

      • BertMaskedLanguageModelLoss

        public BertMaskedLanguageModelLoss(int labelIdx,
                                           int maskIdx,
                                           int logProbsIdx)
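        Creates a masked language model (MLM) loss.
        Parameters:
        labelIdx - index of the expected token ids within the labels NDList
        maskIdx - index of the mask within the labels NDList
        logProbsIdx - index of the log probabilities within the predictions NDList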
    • Method Detail

      • evaluate

        public NDArray evaluate(NDList labels,
                                NDList predictions)
        Calculates the masked language model loss between the labels and the predictions.
        Specified by:
        evaluate in class Evaluator
        Parameters:
        labels - the correct values
        predictions - the predicted values
        Returns:
        the evaluation result
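
To make the role of the mask concrete, here is a minimal plain-Java sketch of a masked negative log-likelihood. It is not the DJL implementation: the class `MaskedLmLossSketch`, the method `maskedLoss`, and the array layout are hypothetical, and it assumes the loss is the mean of `-logProbs[i][labels[i]]` over the positions whose mask entry is non-zero.

```java
final class MaskedLmLossSketch {

    // Hypothetical helper, not part of DJL.
    // logProbs[i][v]: log probability of vocabulary entry v at prediction slot i
    // labels[i]:      expected token id at slot i
    // mask[i]:        1 for a real prediction slot, 0 for padding
    static double maskedLoss(double[][] logProbs, int[] labels, int[] mask) {
        double negLogLikelihood = 0.0;
        int counted = 0;
        for (int i = 0; i < labels.length; i++) {
            if (mask[i] == 0) {
                continue; // padding slot, contributes nothing to the loss
            }
            negLogLikelihood -= logProbs[i][labels[i]];
            counted++;
        }
        // Assumption: an all-padding input yields 0 instead of dividing by zero.
        return counted == 0 ? 0.0 : negLogLikelihood / counted;
    }

    public static void main(String[] args) {
        double[][] logProbs = {
            {Math.log(0.5), Math.log(0.5)},
            {Math.log(0.25), Math.log(0.75)},
            {Math.log(0.9), Math.log(0.1)}
        };
        int[] labels = {0, 1, 1};
        int[] mask = {1, 1, 0}; // the third slot is padding
        System.out.println(maskedLoss(logProbs, labels, mask));
    }
}
```

Only the first two slots enter the average in this example, so the low probability assigned to the padded third slot does not affect the result.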
      • accuracy

        public NDArray accuracy(NDList labels,
                                NDList predictions)
        Calculates the percentage of correctly predicted masked tokens.
        Parameters:
        labels - the expected tokens and the mask
        predictions - the predictions of a BERT model
        Returns:
        the percentage of correctly predicted masked tokens
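
The accuracy over masked tokens can be illustrated the same way. The following plain-Java sketch is an assumption about the computation, not DJL code: `MaskedLmAccuracySketch` and `maskedAccuracy` are hypothetical names, the predicted token is taken to be the arg max of the log probabilities, and the result is returned as a fraction in [0, 1] (whether the DJL method scales it differently is not stated here).

```java
final class MaskedLmAccuracySketch {

    // Hypothetical helper, not part of DJL.
    // Returns the fraction of non-padding slots whose arg max equals the label.
    static double maskedAccuracy(double[][] logProbs, int[] labels, int[] mask) {
        int correct = 0;
        int counted = 0;
        for (int i = 0; i < labels.length; i++) {
            if (mask[i] == 0) {
                continue; // padding slot, ignored
            }
            // Find the vocabulary entry with the highest log probability.
            int best = 0;
            for (int v = 1; v < logProbs[i].length; v++) {
                if (logProbs[i][v] > logProbs[i][best]) {
                    best = v;
                }
            }
            if (best == labels[i]) {
                correct++;
            }
            counted++;
        }
        // Assumption: an all-padding input yields 0 instead of dividing by zero.
        return counted == 0 ? 0.0 : (double) correct / counted;
    }

    public static void main(String[] args) {
        double[][] logProbs = {
            {-0.1, -2.3},
            {-1.6, -0.2},
            {-0.3, -1.2},
            {-2.0, -0.1}
        };
        int[] labels = {0, 0, 1, 1};
        int[] mask = {1, 1, 1, 0}; // the last slot is padding
        System.out.println(maskedAccuracy(logProbs, labels, mask));
    }
}
```

In this example the last slot would count as a correct prediction, but it is padding, so only the three unmasked slots are scored.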