Class BertNextSentenceLoss


  • public class BertNextSentenceLoss
    extends Loss
    Calculates the loss for the next sentence prediction task.
    • Constructor Detail

      • BertNextSentenceLoss

        public BertNextSentenceLoss(int labelIdx,
                                    int nextSentencePredictionIdx)
        Creates a new BERT next-sentence-prediction loss.
        Parameters:
        labelIdx - index of the next-sentence labels within the labels NDList
        nextSentencePredictionIdx - index of the next-sentence prediction within the BERT output
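Both constructor arguments are positions, not names: the loss receives one NDList of labels and one NDList of predictions and uses these indices to pick the array it needs from each. The plain-Java sketch below illustrates only that selection step. `NspIndexSelection` is a hypothetical class, `List<int[]>` and `List<double[][]>` stand in for DJL's `NDList`, and the example assumes a pretraining model whose outputs contain other arrays (for instance a masked-language-model output) alongside the next-sentence prediction.

```java
import java.util.List;

// Hypothetical, plain-Java stand-in for the NDList inputs the loss receives.
// Each list element plays the role of one NDArray; this is NOT the DJL API.
final class NspIndexSelection {
    private final int labelIdx;
    private final int nextSentencePredictionIdx;

    NspIndexSelection(int labelIdx, int nextSentencePredictionIdx) {
        this.labelIdx = labelIdx;
        this.nextSentencePredictionIdx = nextSentencePredictionIdx;
    }

    // Returns the per-example next-sentence labels (assumed to be 0 or 1) stored at labelIdx.
    int[] nextSentenceLabels(List<int[]> labels) {
        return labels.get(labelIdx);
    }

    // Returns the [batch][2] next-sentence scores stored at nextSentencePredictionIdx.
    double[][] nextSentencePredictions(List<double[][]> predictions) {
        return predictions.get(nextSentencePredictionIdx);
    }
}
```

With labels laid out as `[maskedTokenLabels, nextSentenceLabels]` and predictions as `[maskedTokenOutput, nextSentenceOutput]` (an assumed layout), the matching construction would be `new NspIndexSelection(1, 1)`.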
    • Method Detail

      • evaluate

        public NDArray evaluate(NDList labels,
                                NDList predictions)
        Calculates the evaluation (for this class, the next-sentence loss) between the labels and the predictions.
        Specified by:
        evaluate in class Evaluator
        Parameters:
        labels - the correct values
        predictions - the predicted values
        Returns:
        the evaluation result
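A common way to compute a next-sentence loss is the mean negative log-likelihood of the true class over a batch. The sketch below assumes (this is not stated in the documentation above) that the array at `nextSentencePredictionIdx` holds per-example log-probabilities for two classes and that each label is 0 or 1. It uses plain Java arrays instead of DJL's `NDArray` and is not DJL's implementation; `NspLossSketch`, `meanNegativeLogLikelihood`, and `logSoftmax` are hypothetical names.

```java
// Sketch of a binary next-sentence loss: mean negative log-likelihood of the true class.
// ASSUMPTION: logProbs[i] holds log-probabilities for classes 0 and 1; labels[i] is 0 or 1.
final class NspLossSketch {

    static double meanNegativeLogLikelihood(int[] labels, double[][] logProbs) {
        if (labels.length == 0 || labels.length != logProbs.length) {
            throw new IllegalArgumentException("labels and predictions need the same non-zero batch size");
        }
        double total = 0.0;
        for (int i = 0; i < labels.length; i++) {
            int label = labels[i];
            if (label < 0 || label >= logProbs[i].length) {
                throw new IllegalArgumentException("label out of range at example " + i);
            }
            // Same as summing logProbs[i] * oneHot(label): only the true class contributes.
            total -= logProbs[i][label];
        }
        return total / labels.length;
    }

    // Numerically stable log-softmax, for models that emit raw scores (logits).
    static double[] logSoftmax(double[] logits) {
        double max = Double.NEGATIVE_INFINITY;
        for (double x : logits) {
            max = Math.max(max, x);
        }
        double sumExp = 0.0;
        for (double x : logits) {
            sumExp += Math.exp(x - max);
        }
        double logSum = max + Math.log(sumExp);
        double[] out = new double[logits.length];
        for (int i = 0; i < logits.length; i++) {
            out[i] = logits[i] - logSum;
        }
        return out;
    }
}
```

If a model emits raw scores rather than log-probabilities, each row can be passed through `logSoftmax` first; subtracting the row maximum keeps `Math.exp` from overflowing on large logits.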
      • accuracy

        public NDArray accuracy(NDList labels,
                                NDList predictions)
        Calculates the fraction of correct predictions.
        Parameters:
        labels - the labels with the correct predictions
        predictions - the BERT pretraining model output
        Returns:
        the fraction of correct predictions.
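Accuracy for next-sentence prediction is the fraction of examples whose highest-scoring class equals the label. A minimal plain-Java sketch, assuming a `[batch][2]` score array and 0/1 labels; `NspAccuracySketch`, `accuracy`, and `argMax` are hypothetical names, not DJL API.

```java
// Sketch: fraction of examples whose argmax class matches the label.
// ASSUMPTION: scores[i] holds one score per class; labels[i] is the true class index.
final class NspAccuracySketch {

    static double accuracy(int[] labels, double[][] scores) {
        if (labels.length == 0 || labels.length != scores.length) {
            throw new IllegalArgumentException("labels and predictions need the same non-zero batch size");
        }
        int correct = 0;
        for (int i = 0; i < labels.length; i++) {
            if (argMax(scores[i]) == labels[i]) {
                correct++;
            }
        }
        return (double) correct / labels.length;
    }

    // Index of the largest value; ties resolve to the lower index.
    static int argMax(double[] row) {
        int best = 0;
        for (int j = 1; j < row.length; j++) {
            if (row[j] > row[best]) {
                best = j;
            }
        }
        return best;
    }
}
```

Because log-softmax is monotonic within a row, the argmax of log-probabilities and the argmax of raw logits pick the same class, so this sketch works with either kind of output.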