Class BertNextSentenceLoss


public class BertNextSentenceLoss extends Loss
Calculates the loss for the next sentence prediction task.
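The formula below is a minimal sketch of what this loss measures. It assumes, though this page does not state it, that the prediction selected by nextSentencePredictionIdx holds log-softmax scores over the two next-sentence classes. It also assumes that the labels are class indices y_i in {0, 1}. Under those assumptions, the loss over a batch of N sentence pairs is the mean negative log-likelihood of the correct class:

```latex
\mathcal{L} = -\frac{1}{N} \sum_{i=1}^{N} \log p_{i,\,y_i}
```

Here \log p_{i,y_i} is the log-probability the model assigns to the true label of pair i. A confident correct prediction contributes a value close to 0. A confident wrong prediction contributes a large positive value.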
  • Constructor Details

    • BertNextSentenceLoss

      public BertNextSentenceLoss(int labelIdx, int nextSentencePredictionIdx)
      Creates a new BERT next sentence loss.
      Parameters:
      labelIdx - index of the next sentence labels in the labels NDList
      nextSentencePredictionIdx - index of the next sentence prediction in the BERT output
  • Method Details

    • evaluate

      public NDArray evaluate(NDList labels, NDList predictions)
      Calculates the next sentence prediction loss between the labels and the predictions.
      Specified by:
      evaluate in class Evaluator
      Parameters:
      labels - the correct values
      predictions - the predicted values
      Returns:
      the computed loss
    • accuracy

      public NDArray accuracy(NDList labels, NDList predictions)
      Calculates the fraction of correct predictions.
      Parameters:
      labels - the labels with the correct predictions
      predictions - the BERT pretraining model output
      Returns:
      the fraction of correct predictions
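The following plain-Java sketch shows what evaluate and accuracy compute. It does not use DJL, and it is not the library's implementation, which operates on NDArray. It makes several assumptions:

* The class name NextSentenceLossSketch is hypothetical.
* NDList is replaced by java.util.List of plain arrays.
* The prediction at nextSentencePredictionIdx is assumed to hold log-softmax scores of shape [batch][2].
* The labels at labelIdx are assumed to be class indices 0 or 1.

```java
import java.util.List;

/**
 * Hypothetical plain-Java sketch of next sentence prediction loss and accuracy.
 * Assumption: predictions.get(nextSentencePredictionIdx) holds log-softmax scores
 * of shape [batch][2], and labels.get(labelIdx) holds class ids 0 or 1.
 */
final class NextSentenceLossSketch {
    private final int labelIdx;
    private final int nextSentencePredictionIdx;

    NextSentenceLossSketch(int labelIdx, int nextSentencePredictionIdx) {
        this.labelIdx = labelIdx;
        this.nextSentencePredictionIdx = nextSentencePredictionIdx;
    }

    /** Mean negative log-likelihood of the true class over the batch. */
    double evaluate(List<int[]> labels, List<double[][]> predictions) {
        int[] y = labels.get(labelIdx);
        double[][] logProbs = predictions.get(nextSentencePredictionIdx);
        double sum = 0.0;
        for (int i = 0; i < y.length; i++) {
            // Take the log-probability of the true class and negate it.
            sum -= logProbs[i][y[i]];
        }
        return sum / y.length;
    }

    /** Fraction of rows whose highest-scoring class equals the label. */
    double accuracy(List<int[]> labels, List<double[][]> predictions) {
        int[] y = labels.get(labelIdx);
        double[][] logProbs = predictions.get(nextSentencePredictionIdx);
        int correct = 0;
        for (int i = 0; i < y.length; i++) {
            if (argMax(logProbs[i]) == y[i]) {
                correct++;
            }
        }
        return (double) correct / y.length;
    }

    // Index of the largest score in a row (first one wins on ties).
    private static int argMax(double[] row) {
        int best = 0;
        for (int k = 1; k < row.length; k++) {
            if (row[k] > row[best]) {
                best = k;
            }
        }
        return best;
    }
}
```

Consider a batch with log-probability rows {ln 0.9, ln 0.1} and {ln 0.2, ln 0.8} and labels {0, 0}. For that batch, evaluate returns -(ln 0.9 + ln 0.2) / 2, approximately 0.857. accuracy returns 0.5, because only the first row's argmax matches its label. Taking the log-probability directly by index is equivalent to multiplying by a one-hot label vector and summing over the class axis.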