Package ai.djl.training.loss
Class TabNetClassificationLoss
java.lang.Object
ai.djl.training.evaluator.Evaluator
ai.djl.training.loss.Loss
ai.djl.training.loss.TabNetClassificationLoss
Calculates the loss for tabNet in Classification tasks.
TabNet is not limited to supervised learning; it is also widely used for unsupervised learning. In the unsupervised setting, the loss should instead come from the decoder (also called the attentionTransformer of TabNet).
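To make the classification case concrete, here is a minimal plain-Java sketch (no DJL dependency) of the kind of quantity such a loss computes. It assumes, and this is not stated on this page, that the network's first output is a batch of class logits and its second output is a per-sample sparsity penalty, and that the loss is the mean softmax cross-entropy plus the mean of that penalty. The class and method names are hypothetical; this is an illustration, not DJL's implementation.

```java
import java.util.Arrays;

/**
 * Hypothetical plain-Java sketch of a TabNet-style classification loss (not DJL code).
 * Assumption: loss = mean softmax cross-entropy over class logits
 *                    + mean of a per-sample sparsity penalty emitted by the network.
 */
final class TabNetClassificationLossSketch {

    /** Softmax cross-entropy for one sample, i.e. -log(softmax(logits)[label]). */
    static double softmaxCrossEntropy(double[] logits, int label) {
        // Subtract the max logit before exponentiating so exp() cannot overflow.
        double max = Arrays.stream(logits).max().orElse(0.0);
        double sumExp = 0.0;
        for (double z : logits) {
            sumExp += Math.exp(z - max);
        }
        double logSumExp = max + Math.log(sumExp);
        return logSumExp - logits[label];
    }

    /** Batch loss under the assumed combination: mean cross-entropy + mean sparsity penalty. */
    static double loss(double[][] logits, int[] labels, double[] sparsityPenalty) {
        double ce = 0.0;
        for (int i = 0; i < logits.length; i++) {
            ce += softmaxCrossEntropy(logits[i], labels[i]);
        }
        ce /= logits.length;
        double sparse = Arrays.stream(sparsityPenalty).average().orElse(0.0);
        return ce + sparse;
    }

    public static void main(String[] args) {
        double[][] logits = {{2.0, 0.5, -1.0}, {0.1, 0.2, 3.0}};
        int[] labels = {0, 2};
        double[] sparsity = {0.01, 0.03};
        System.out.println(loss(logits, labels, sparsity));
    }
}
```

Writing the cross-entropy as logSumExp(z) minus z[label], with the largest logit subtracted first, keeps the computation stable even for very large logits, which is why it avoids forming probabilities explicitly.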
Field Summary
Fields inherited from class ai.djl.training.evaluator.Evaluator
totalInstances
Constructor Summary
Constructors
Constructor                              Description
TabNetClassificationLoss()               Calculates the loss of a TabNet instance for classification tasks.
TabNetClassificationLoss(String name)    Calculates the loss of a TabNet instance for classification tasks.
Method Summary
Methods inherited from class ai.djl.training.loss.Loss
addAccumulator, elasticNetWeightedDecay, elasticNetWeightedDecay, elasticNetWeightedDecay, elasticNetWeightedDecay, getAccumulator, hingeLoss, hingeLoss, hingeLoss, l1Loss, l1Loss, l1Loss, l1WeightedDecay, l1WeightedDecay, l1WeightedDecay, l2Loss, l2Loss, l2Loss, l2WeightedDecay, l2WeightedDecay, l2WeightedDecay, maskedSoftmaxCrossEntropyLoss, maskedSoftmaxCrossEntropyLoss, maskedSoftmaxCrossEntropyLoss, quantileL1Loss, quantileL1Loss, resetAccumulator, sigmoidBinaryCrossEntropyLoss, sigmoidBinaryCrossEntropyLoss, sigmoidBinaryCrossEntropyLoss, softmaxCrossEntropyLoss, softmaxCrossEntropyLoss, softmaxCrossEntropyLoss, updateAccumulator, updateAccumulators
Methods inherited from class ai.djl.training.evaluator.Evaluator
checkLabelShapes, checkLabelShapes, getName
Constructor Details
TabNetClassificationLoss
public TabNetClassificationLoss()
Calculates the loss of a TabNet instance for classification tasks.
TabNetClassificationLoss
public TabNetClassificationLoss(String name)
Calculates the loss of a TabNet instance for classification tasks.
Parameters:
name - the name of the loss function
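The two constructors follow a common pattern: the no-argument form supplies a default name, and the String form lets the caller choose one, which getName() (inherited from Evaluator) then reports. The sketch below mirrors that pattern with hypothetical plain-Java classes rather than DJL's own; the default name "TabNetClassificationLoss" is an assumption, since this page does not state it.

```java
/** Hypothetical stand-in for a named evaluator/loss base class (not DJL code). */
abstract class NamedLossSketch {
    private final String name;

    NamedLossSketch(String name) {
        this.name = name;
    }

    /** Returns the name given at construction time. */
    String getName() {
        return name;
    }
}

/** Hypothetical loss showing the default-name vs. explicit-name constructor chaining. */
final class TabNetClassificationLossNamingSketch extends NamedLossSketch {
    // Assumed default name; the real default is not documented on this page.
    static final String DEFAULT_NAME = "TabNetClassificationLoss";

    /** No-arg constructor delegates to the named one with the default name. */
    TabNetClassificationLossNamingSketch() {
        this(DEFAULT_NAME);
    }

    TabNetClassificationLossNamingSketch(String name) {
        super(name);
    }

    public static void main(String[] args) {
        System.out.println(new TabNetClassificationLossNamingSketch().getName());
        System.out.println(new TabNetClassificationLossNamingSketch("validationLoss").getName());
    }
}
```

An explicit name is mainly useful when several instances of the same loss are attached to one training run and need to be told apart in logs or collected metrics.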
Method Details