Package ai.djl.training.loss
Class TabNetClassificationLoss
java.lang.Object
    ai.djl.training.evaluator.Evaluator
        ai.djl.training.loss.Loss
            ai.djl.training.loss.TabNetClassificationLoss
public final class TabNetClassificationLoss extends Loss
Calculates the loss for TabNet in classification tasks. TabNet is not limited to supervised learning; it is also widely used in unsupervised learning. For unsupervised learning, the loss should instead come from the decoder (also known as the attentionTransformer of TabNet).
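This page does not spell out the formula behind the loss. As an illustration only, the sketch below computes a mean softmax cross-entropy over class logits in plain Java, which is the usual loss for classification and matches the softmaxCrossEntropyLoss helper inherited from Loss. That TabNetClassificationLoss computes exactly this is an assumption; the class SoftmaxCrossEntropySketch and its loss method are hypothetical names, not part of DJL, and no DJL calls are used.

```java
/**
 * Illustrative sketch only: mean softmax cross-entropy over a batch.
 * Hypothetical helper, not part of DJL; the exact formula used by
 * TabNetClassificationLoss is an assumption here.
 */
public final class SoftmaxCrossEntropySketch {

    private SoftmaxCrossEntropySketch() {}

    /**
     * @param logits unnormalized class scores, one row per sample
     * @param labels index of the true class for each sample
     * @return the cross-entropy averaged over the batch
     */
    public static double loss(double[][] logits, int[] labels) {
        if (logits.length == 0 || logits.length != labels.length) {
            throw new IllegalArgumentException("logits and labels must be non-empty and the same length");
        }
        double total = 0.0;
        for (int i = 0; i < logits.length; i++) {
            double[] row = logits[i];
            // Subtract the row maximum before exponentiating for numerical stability.
            double max = Double.NEGATIVE_INFINITY;
            for (double v : row) {
                max = Math.max(max, v);
            }
            double sumExp = 0.0;
            for (double v : row) {
                sumExp += Math.exp(v - max);
            }
            double logSumExp = max + Math.log(sumExp);
            // -log(softmax(row)[label]) = logSumExp - row[label]
            total += logSumExp - row[labels[i]];
        }
        return total / logits.length;
    }

    public static void main(String[] args) {
        double[][] logits = {{2.0, 0.5, -1.0}, {0.0, 0.0, 0.0}};
        int[] labels = {0, 2};
        System.out.println(loss(logits, labels));
    }
}
```

With uniform logits the per-sample loss equals the natural log of the number of classes, which is a convenient sanity check for any implementation of this kind.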
Field Summary
-
Fields inherited from class ai.djl.training.evaluator.Evaluator
totalInstances
Constructor Summary
Constructor                                        Description
TabNetClassificationLoss()                         Calculates the loss of a TabNet instance for classification tasks.
TabNetClassificationLoss(java.lang.String name)    Calculates the loss of a TabNet instance for classification tasks.
Method Summary
Modifier and Type    Method                                         Description
NDArray              evaluate(NDList labels, NDList predictions)    Calculates the evaluation between the labels and the predictions.
Methods inherited from class ai.djl.training.loss.Loss
addAccumulator, elasticNetWeightedDecay, elasticNetWeightedDecay, elasticNetWeightedDecay, elasticNetWeightedDecay, getAccumulator, hingeLoss, hingeLoss, hingeLoss, l1Loss, l1Loss, l1Loss, l1WeightedDecay, l1WeightedDecay, l1WeightedDecay, l2Loss, l2Loss, l2Loss, l2WeightedDecay, l2WeightedDecay, l2WeightedDecay, maskedSoftmaxCrossEntropyLoss, maskedSoftmaxCrossEntropyLoss, maskedSoftmaxCrossEntropyLoss, quantileL1Loss, quantileL1Loss, resetAccumulator, sigmoidBinaryCrossEntropyLoss, sigmoidBinaryCrossEntropyLoss, sigmoidBinaryCrossEntropyLoss, softmaxCrossEntropyLoss, softmaxCrossEntropyLoss, softmaxCrossEntropyLoss, updateAccumulator
Methods inherited from class ai.djl.training.evaluator.Evaluator
checkLabelShapes, checkLabelShapes, getName
Constructor Detail
TabNetClassificationLoss
public TabNetClassificationLoss()
Calculates the loss of a TabNet instance for classification tasks.
-
TabNetClassificationLoss
public TabNetClassificationLoss(java.lang.String name)
Calculates the loss of a TabNet instance for classification tasks.
Parameters:
name - the name of the loss function