Class TabNetClassificationLoss


  • public final class TabNetClassificationLoss
    extends Loss
    Calculates the loss for TabNet in classification tasks.

    TabNet is used not only for supervised learning but also, widely, for unsupervised learning. For unsupervised learning, the loss should instead come from the decoder (aka the attentionTransformer of TabNet).
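    For reference, the standard softmax cross-entropy objective for supervised classification is shown below. This page does not state the exact formula this class uses, so treat it as an assumption: N is the batch size, C the number of classes, z_{i,k} the logit of sample i for class k, and y_i its true class. TabNet models are often trained with an additional sparsity regularization term on the attention masks; whether and how this class adds one is not stated here.

```latex
\mathcal{L}_{\text{CE}} = -\frac{1}{N}\sum_{i=1}^{N} \log \frac{\exp(z_{i,y_i})}{\sum_{k=1}^{C} \exp(z_{i,k})}
```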

    • Constructor Detail

      • TabNetClassificationLoss

        public TabNetClassificationLoss()
        Calculates the loss of a TabNet instance for classification tasks.
      • TabNetClassificationLoss

        public TabNetClassificationLoss(java.lang.String name)
        Calculates the loss of a TabNet instance for classification tasks.
        Parameters:
        name - the name of the loss function
    • Method Detail

      • evaluate

        public NDArray evaluate(NDList labels,
                                NDList predictions)
        Calculates the evaluation (for a loss, the loss value) between the labels and the predictions.
        Specified by:
        evaluate in class Evaluator
        Parameters:
        labels - the correct values
        predictions - the predicted values
        Returns:
        the evaluation result
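        The following is a minimal, dependency-free sketch of what an evaluate-style computation could look like, assuming the loss is the mean softmax cross-entropy of the logits plus an additive sparsity term. It is not DJL's implementation: NDList and NDArray are replaced by plain Java arrays, and the class TabNetLossSketch and its parameters (labels as class indices, logits, sparsityTerm) are hypothetical.

```java
/**
 * Hypothetical, dependency-free sketch of a classification loss evaluation.
 * Plain arrays stand in for DJL's NDList/NDArray; this is NOT the DJL source.
 */
final class TabNetLossSketch {

    private TabNetLossSketch() {}

    /**
     * Mean softmax cross-entropy over a batch, plus an assumed additive sparsity term.
     *
     * @param labels       true class index for each sample (length = batch size)
     * @param logits       raw scores; logits[i][k] is sample i's score for class k
     * @param sparsityTerm hypothetical extra penalty added to the mean loss (0.0 for none)
     * @return the scalar loss value
     */
    public static double evaluate(int[] labels, double[][] logits, double sparsityTerm) {
        if (labels.length == 0 || labels.length != logits.length) {
            throw new IllegalArgumentException(
                    "labels and logits must be non-empty and have the same batch size");
        }
        double total = 0.0;
        for (int i = 0; i < labels.length; i++) {
            double[] row = logits[i];
            // Subtract the row maximum before exponentiating (log-sum-exp trick).
            double max = Double.NEGATIVE_INFINITY;
            for (double v : row) {
                max = Math.max(max, v);
            }
            double sumExp = 0.0;
            for (double v : row) {
                sumExp += Math.exp(v - max);
            }
            double logSumExp = max + Math.log(sumExp);
            // Negative log-probability of the true class: logSumExp - z[i][y_i].
            total += logSumExp - row[labels[i]];
        }
        // Average over the batch, then add the (assumed) sparsity penalty.
        return total / labels.length + sparsityTerm;
    }

    public static void main(String[] args) {
        int[] labels = {0, 1};
        double[][] logits = {{0.0, 0.0}, {0.0, 0.0}};
        // Uniform logits over two classes: each sample contributes ln(2).
        System.out.println(evaluate(labels, logits, 0.0));
    }
}
```

        Running main prints roughly 0.6931 (ln 2), the loss for uniform logits over two classes. Subtracting the row maximum keeps large logits from overflowing Math.exp without changing the result.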