Package ai.djl.training.evaluator
Class TopKAccuracy
- java.lang.Object
  - ai.djl.training.evaluator.Evaluator
    - ai.djl.training.evaluator.AbstractAccuracy
      - ai.djl.training.evaluator.TopKAccuracy
public class TopKAccuracy extends AbstractAccuracy
TopKAccuracy is an Evaluator that computes the accuracy of the top K predictions. TopKAccuracy differs from AbstractAccuracy in that it counts a prediction as correct (`True`) as long as the ground-truth label is among the top K predicted labels. If `topK = 1`, TopKAccuracy is identical to Accuracy.
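The rule above can be illustrated without DJL. The following is a minimal plain-Java sketch, not DJL's implementation: the class TopKRule and its methods are hypothetical, the scores are made up, and ties are resolved in the label's favour, which an argsort-based implementation may not do. It also shows why K = 1 reduces to ordinary argmax accuracy.

```java
/** Hypothetical sketch of the top-K accuracy rule in plain Java (no DJL). */
final class TopKRule {

    private TopKRule() {}

    /** Returns true if the true label's score ranks among the k highest scores. */
    static boolean inTopK(float[] scores, int label, int k) {
        int strictlyHigher = 0;
        for (float s : scores) {
            if (s > scores[label]) {
                strictlyHigher++;
            }
        }
        // Ties count in the label's favour here; an argsort-based
        // implementation may break ties differently.
        return strictlyHigher < k;
    }

    /** Index of the highest score (the first one wins on ties). */
    static int argMax(float[] scores) {
        int best = 0;
        for (int i = 1; i < scores.length; i++) {
            if (scores[i] > scores[best]) {
                best = i;
            }
        }
        return best;
    }

    /** Fraction of samples whose label is among the k highest-scoring classes. */
    static double topKAccuracy(float[][] batch, int[] labels, int k) {
        int correct = 0;
        for (int i = 0; i < batch.length; i++) {
            if (inTopK(batch[i], labels[i], k)) {
                correct++;
            }
        }
        return (double) correct / batch.length;
    }

    public static void main(String[] args) {
        float[][] batch = {
            {0.1f, 0.6f, 0.3f}, // label 2 has the second-highest score
            {0.7f, 0.2f, 0.1f}, // label 0 has the highest score
            {0.2f, 0.3f, 0.5f}, // label 1 has the second-highest score
            {0.5f, 0.4f, 0.1f}  // label 2 has the lowest score
        };
        int[] labels = {2, 0, 1, 2};
        System.out.println(topKAccuracy(batch, labels, 1)); // 0.25
        System.out.println(topKAccuracy(batch, labels, 2)); // 0.75
    }
}
```

With K = 1 only the second sample counts as correct, which matches comparing argMax against the label; with K = 2 the first and third samples also count, because their labels have the second-highest score.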
Field Summary
Fields inherited from class ai.djl.training.evaluator.AbstractAccuracy
axis, correctInstances
Fields inherited from class ai.djl.training.evaluator.Evaluator
totalInstances
Constructor Summary
Constructor | Description
TopKAccuracy(int topK) | Creates an instance of the TopKAccuracy evaluator that computes top-K accuracy across axis 1 along the 0th index.
TopKAccuracy(java.lang.String name, int topK) | Creates a TopKAccuracy instance.
Method Summary
Modifier and Type | Method | Description
protected ai.djl.util.Pair<java.lang.Long,NDArray> | accuracyHelper(NDList labels, NDList predictions) | A helper for classes extending AbstractAccuracy.
Methods inherited from class ai.djl.training.evaluator.AbstractAccuracy
addAccumulator, evaluate, getAccumulator, resetAccumulator, updateAccumulator
Methods inherited from class ai.djl.training.evaluator.Evaluator
checkLabelShapes, checkLabelShapes, getName
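Conceptually, the inherited accumulator methods keep running counts per key, in line with the correctInstances and totalInstances fields listed above. The following is a minimal sketch, assuming accuracy is reported as correct / total for each key (for example a training or a validation key). The class AccuracyCounterSketch and its method names are hypothetical and deliberately differ from DJL's, whose signatures are not reproduced here.

```java
import java.util.HashMap;
import java.util.Map;

/** Hypothetical sketch of per-key accuracy bookkeeping; not the DJL implementation. */
final class AccuracyCounterSketch {

    private final Map<String, Long> totalInstances = new HashMap<>();
    private final Map<String, Long> correctInstances = new HashMap<>();

    /** Registers a key with zeroed counters (compare addAccumulator). */
    void register(String key) {
        reset(key);
    }

    /** Adds one batch: its sample count and a 0/1 correctness flag per sample (compare updateAccumulator). */
    void update(String key, long batchTotal, int[] correctFlags) {
        long batchCorrect = 0;
        for (int flag : correctFlags) {
            batchCorrect += flag;
        }
        totalInstances.merge(key, batchTotal, Long::sum);
        correctInstances.merge(key, batchCorrect, Long::sum);
    }

    /** Running accuracy for the key, correct / total; NaN before any update (compare getAccumulator). */
    float accuracy(String key) {
        long total = totalInstances.get(key);
        return total == 0 ? Float.NaN : (float) correctInstances.get(key) / total;
    }

    /** Zeroes the counters, for example at the start of an epoch (compare resetAccumulator). */
    void reset(String key) {
        totalInstances.put(key, 0L);
        correctInstances.put(key, 0L);
    }
}
```

Keeping raw counts rather than averaging per-batch accuracies means batches of different sizes are weighted correctly.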
Constructor Detail
TopKAccuracy
public TopKAccuracy(java.lang.String name, int topK)
Creates a TopKAccuracy instance.
Parameters:
name - the accuracy name, default "Top_K_Accuracy"
topK - the value of K
TopKAccuracy
public TopKAccuracy(int topK)
Creates an instance of the TopKAccuracy evaluator that computes top-K accuracy across axis 1 along the 0th index.
Parameters:
topK - the value of K
Method Detail
accuracyHelper
protected ai.djl.util.Pair<java.lang.Long,NDArray> accuracyHelper(NDList labels, NDList predictions)
A helper for classes extending AbstractAccuracy.
Specified by:
accuracyHelper in class AbstractAccuracy
Parameters:
labels - the labels to get accuracy for
predictions - the predictions to get accuracy for
Returns:
a pair (number of total values, integer NDArray of correct values)
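The helper's contract can be mirrored in plain Java for a (batch, numClasses) score matrix. This sketch is not the DJL source and assumes one integer class label per sample. The hypothetical record Result stands in for Pair<Long, NDArray>: total is the number of labels, and correct holds a 1 for each sample whose label is among the K highest scores along axis 1. Sorting class indices ascending and reading from the end mimics an ascending argsort; clamping K to the number of classes is a choice made here.

```java
import java.util.Arrays;
import java.util.Comparator;
import java.util.stream.IntStream;

/** Hypothetical plain-Java analogue of accuracyHelper for top-K accuracy; not the DJL code. */
final class TopKHelperSketch {

    /** Stand-in for Pair<Long, NDArray>: (number of total values, 0/1 flag per sample). */
    record Result(long total, int[] correct) {}

    static Result topKHelper(int[] labels, float[][] predictions, int topK) {
        if (labels.length != predictions.length) {
            throw new IllegalArgumentException("labels and predictions must have the same batch size");
        }
        int[] correct = new int[labels.length];
        for (int i = 0; i < predictions.length; i++) {
            float[] row = predictions[i];
            // Ascending argsort of this sample's class scores (axis 1 of the batch).
            Integer[] order = IntStream.range(0, row.length).boxed().toArray(Integer[]::new);
            Arrays.sort(order, Comparator.comparingDouble(c -> row[c]));
            // The K best classes sit at the end of the ascending order.
            int k = Math.min(topK, row.length);
            for (int j = 0; j < k; j++) {
                if (order[row.length - 1 - j] == labels[i]) {
                    correct[i] = 1;
                    break;
                }
            }
        }
        return new Result(labels.length, correct);
    }
}
```

Returning per-sample flags instead of a single count lets the caller sum them into a running correct-instance total.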