Class TabNetClassificationLoss


public final class TabNetClassificationLoss extends Loss
Calculates the loss for TabNet in classification tasks.

TabNet is not only used for supervised learning; it is also widely used for unsupervised learning. For unsupervised learning, the loss should instead come from the decoder (aka the attentionTransformer of TabNet).

  • Constructor Details

    • TabNetClassificationLoss

      public TabNetClassificationLoss()
      Calculates the loss of a TabNet instance for classification tasks.
    • TabNetClassificationLoss

      public TabNetClassificationLoss(String name)
      Calculates the loss of a TabNet instance for classification tasks.
      Parameters:
      name - the name of the loss function
  • Method Details

    • evaluate

      public NDArray evaluate(NDList labels, NDList predictions)
      Calculates the evaluation between the labels and the predictions.
      Specified by:
      evaluate in class Evaluator
      Parameters:
      labels - the correct values
      predictions - the predicted values
      Returns:
      the evaluation result
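This page does not spell out the formula that evaluate applies. Purely as a rough, dependency-free illustration, the sketch below assumes a TabNet-style classification objective: mean softmax cross-entropy over the class logits, plus a sparsity-regularization scalar produced by the model (the original TabNet paper adds such a term to the task loss). The class name TabNetClassificationLossSketch, the evaluate(int[], double[][], double) signature, and the split of the predictions into "logits" and "sparseLoss" are all hypothetical. They are not DJL's API and may not match this class's actual implementation.

```java
/**
 * Dependency-free sketch of a TabNet-style classification loss.
 * ASSUMPTION: this is not DJL's implementation. It assumes the loss is the
 * mean softmax cross-entropy over the batch plus a sparsity-regularization
 * scalar that the model produces alongside its logits (hypothetical layout).
 */
public class TabNetClassificationLossSketch {

    /** Numerically stable softmax cross-entropy for a single sample. */
    static double crossEntropy(double[] logits, int label) {
        // Subtract the max logit before exponentiating to avoid overflow.
        double max = Double.NEGATIVE_INFINITY;
        for (double z : logits) {
            max = Math.max(max, z);
        }
        double sumExp = 0.0;
        for (double z : logits) {
            sumExp += Math.exp(z - max);
        }
        double logSumExp = max + Math.log(sumExp);
        // -log(softmax(logits)[label]) = logSumExp - logits[label]
        return logSumExp - logits[label];
    }

    /**
     * Hypothetical analogue of evaluate(labels, predictions):
     * labels holds class indices, logits holds one row per sample,
     * and sparseLoss is the (assumed) regularization term from the model.
     */
    static double evaluate(int[] labels, double[][] logits, double sparseLoss) {
        if (labels.length != logits.length) {
            throw new IllegalArgumentException("batch size mismatch");
        }
        double total = 0.0;
        for (int i = 0; i < labels.length; i++) {
            total += crossEntropy(logits[i], labels[i]);
        }
        // Average the task loss over the batch, then add the sparsity term.
        return total / labels.length + sparseLoss;
    }

    public static void main(String[] args) {
        int[] labels = {0, 1};
        double[][] logits = {{2.0, 0.0}, {0.0, 0.0}};
        double loss = evaluate(labels, logits, 0.1);
        System.out.printf("loss = %.4f%n", loss);
    }
}
```

Keeping the sparsity term as a separate input mirrors the idea that it comes from the model's attention masks rather than from the labels; weighting that term (TabNet's sparsity coefficient) is assumed here to have already been applied by the caller.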