Package ai.djl.training.evaluator
Class AbstractAccuracy
- java.lang.Object
  - ai.djl.training.evaluator.Evaluator
    - ai.djl.training.evaluator.AbstractAccuracy
Direct Known Subclasses:
Accuracy, BinaryAccuracy, SingleShotDetectionAccuracy, TopKAccuracy
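public abstract class AbstractAccuracy extends Evaluator

Accuracy is an Evaluator that computes the accuracy score. The accuracy score is defined as

\(\mathrm{accuracy}(y, \hat{y}) = \frac{1}{n}\sum_{i=0}^{n-1} \mathbb{1}(\hat{y}_i = y_i)\)

where n is the number of values compared and \(\mathbb{1}(\cdot)\) is the indicator function: it is 1 when the prediction \(\hat{y}_i\) equals the label \(y_i\) and 0 otherwise.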
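To make the definition concrete, here is a minimal sketch that evaluates the formula on plain Java arrays of class indices. `AccuracyFormula` and `accuracy` are hypothetical names used only for illustration; DJL itself works on `NDList`/`NDArray` inputs, not `int[]`.

```java
/**
 * Hypothetical helper (not part of DJL) that evaluates the accuracy formula
 * on plain int arrays of class indices.
 */
final class AccuracyFormula {
    private AccuracyFormula() {}

    /** Returns (1/n) * sum_i 1(yHat[i] == y[i]) for n = y.length. */
    static double accuracy(int[] y, int[] yHat) {
        if (y.length != yHat.length) {
            throw new IllegalArgumentException("labels and predictions differ in length");
        }
        if (y.length == 0) {
            throw new IllegalArgumentException("accuracy is undefined for n = 0");
        }
        int correct = 0;
        for (int i = 0; i < y.length; i++) {
            if (yHat[i] == y[i]) {
                correct++; // indicator term 1(yHat_i == y_i) contributes 1
            }
        }
        return (double) correct / y.length; // divide by n
    }

    public static void main(String[] args) {
        int[] y = {0, 1, 2, 1};
        int[] yHat = {0, 2, 2, 1};
        System.out.println(accuracy(y, yHat)); // prints 0.75 (3 of 4 match)
    }
}
```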
Field Summary
- protected int axis: the axis that represents classes in the prediction
- protected java.util.Map<java.lang.String,java.lang.Long> correctInstances: the number of correct instances recorded for each accumulator key
Fields inherited from class ai.djl.training.evaluator.Evaluator: totalInstances
Constructor Summary
- AbstractAccuracy(java.lang.String name): Creates an accuracy evaluator that computes accuracy across axis 1.
- AbstractAccuracy(java.lang.String name, int axis): Creates an accuracy evaluator.
Method Summary
- protected abstract ai.djl.util.Pair<java.lang.Long,NDArray> accuracyHelper(NDList labels, NDList predictions): A helper for classes extending AbstractAccuracy.
- void addAccumulator(java.lang.String key): Adds an accumulator for the results of the evaluation with the given key.
- NDArray evaluate(NDList labels, NDList predictions): Calculates the evaluation between the labels and the predictions.
- float getAccumulator(java.lang.String key): Returns the accumulated evaluator value.
- void resetAccumulator(java.lang.String key): Resets the evaluator value with the given key.
- void updateAccumulator(java.lang.String key, NDList labels, NDList predictions): Updates the evaluator with the given key based on an NDList of labels and predictions.

Methods inherited from class ai.djl.training.evaluator.Evaluator: checkLabelShapes, checkLabelShapes, getName
Constructor Detail
AbstractAccuracy
public AbstractAccuracy(java.lang.String name)
Creates an accuracy evaluator that computes accuracy across axis 1.
Parameters:
- name - the name of the evaluator, default is "Accuracy"

AbstractAccuracy
public AbstractAccuracy(java.lang.String name, int axis)
Creates an accuracy evaluator.
Parameters:
- name - the name of the evaluator, default is "Accuracy"
- axis - the axis that represents classes in the prediction, default 1
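The `axis` parameter names the dimension of the prediction that indexes classes. Assuming the common classification layout of a `[batch, classes]` score matrix (an assumption, not something this page states), reducing along axis 1 with an argmax turns each row of scores into one predicted class index. The plain-Java sketch below, with the hypothetical name `ClassAxisSketch`, illustrates that reduction; it is not DJL code.

```java
/**
 * Hypothetical illustration (not DJL code): picks the predicted class for each
 * sample of a [batch][classes] score matrix by taking the argmax along axis 1.
 */
final class ClassAxisSketch {
    private ClassAxisSketch() {}

    static int[] argmaxAlongAxis1(double[][] scores) {
        int[] predicted = new int[scores.length];
        for (int row = 0; row < scores.length; row++) { // axis 0: one row per sample
            int best = 0;
            for (int c = 1; c < scores[row].length; c++) { // axis 1: one column per class
                if (scores[row][c] > scores[row][best]) { // strict '>' keeps the lowest index on ties
                    best = c;
                }
            }
            predicted[row] = best;
        }
        return predicted;
    }

    public static void main(String[] args) {
        double[][] scores = {
            {0.1, 0.7, 0.2}, // sample 0 -> class 1
            {0.6, 0.3, 0.1}, // sample 1 -> class 0
        };
        System.out.println(java.util.Arrays.toString(argmaxAlongAxis1(scores))); // prints [1, 0]
    }
}
```

The resulting class indices can then be compared element by element with the labels, as in the accuracy formula.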
Method Detail
accuracyHelper
protected abstract ai.djl.util.Pair<java.lang.Long,NDArray> accuracyHelper(NDList labels, NDList predictions)
A helper for classes extending AbstractAccuracy.
Parameters:
- labels - the labels to get accuracy for
- predictions - the predictions to get accuracy for
Returns:
- a pair of (total number of values, int NDArray of correct values)
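A subclass implements `accuracyHelper` to report how many values it compared together with which of them were correct. The sketch below mirrors that contract without DJL, using `int[]` in place of `NDList`/`NDArray` and `Map.Entry` in place of `ai.djl.util.Pair`. The class name and the exact-match rule are assumptions for illustration; real subclasses such as `TopKAccuracy` apply their own matching rule. In this sketch, summing the returned flags gives the number of correct values.

```java
import java.util.AbstractMap;
import java.util.Arrays;
import java.util.Map;

/**
 * Hypothetical, DJL-free analogue of accuracyHelper: int arrays stand in for
 * NDList, and Map.Entry stands in for ai.djl.util.Pair.
 */
final class AccuracyHelperSketch {
    private AccuracyHelperSketch() {}

    /** Returns (total number of values, per-value 1/0 flags marking correct predictions). */
    static Map.Entry<Long, int[]> accuracyHelper(int[] labels, int[] predictedClasses) {
        if (labels.length != predictedClasses.length) {
            throw new IllegalArgumentException("shape mismatch");
        }
        int[] correct = new int[labels.length];
        for (int i = 0; i < labels.length; i++) {
            correct[i] = labels[i] == predictedClasses[i] ? 1 : 0; // 1 = correct, 0 = wrong
        }
        return new AbstractMap.SimpleImmutableEntry<>((long) labels.length, correct);
    }

    public static void main(String[] args) {
        Map.Entry<Long, int[]> result = accuracyHelper(new int[] {1, 0, 1}, new int[] {1, 1, 1});
        System.out.println(result.getKey() + " " + Arrays.toString(result.getValue())); // prints 3 [1, 0, 1]
    }
}
```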
evaluate
public NDArray evaluate(NDList labels, NDList predictions)
Calculates the evaluation between the labels and the predictions.
addAccumulator
public void addAccumulator(java.lang.String key)
Adds an accumulator for the results of the evaluation with the given key.
Specified by:
addAccumulator in class Evaluator
Parameters:
- key - the key for the new accumulator
updateAccumulator
public void updateAccumulator(java.lang.String key, NDList labels, NDList predictions)
Updates the evaluator with the given key based on an NDList of labels and predictions.
This is a synchronized operation. You should only call it at the end of a batch or epoch.
Specified by:
updateAccumulator in class Evaluator
Parameters:
- key - the key of the accumulator to update
- labels - an NDList of labels
- predictions - an NDList of predictions
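Because the accumulator is updated batch by batch, one natural design (assumed here, not spelled out on this page) is to keep running totals of correct and total instances, matching the `correctInstances` and `totalInstances` fields, rather than averaging per-batch accuracies. The two results differ when batch sizes differ, as this hypothetical plain-Java sketch shows.

```java
/**
 * Hypothetical comparison (not DJL code): pooled accuracy from running counts
 * versus the unweighted mean of per-batch accuracies.
 */
final class PooledAccuracySketch {
    private PooledAccuracySketch() {}

    /** sum(correct) / sum(total): every instance carries the same weight. */
    static double pooled(long[] correct, long[] total) {
        long c = 0;
        long t = 0;
        for (int b = 0; b < correct.length; b++) {
            c += correct[b];
            t += total[b];
        }
        return (double) c / t;
    }

    /** Mean over batches of correct[b] / total[b]: every batch carries the same weight. */
    static double meanOfBatches(long[] correct, long[] total) {
        double sum = 0;
        for (int b = 0; b < correct.length; b++) {
            sum += (double) correct[b] / total[b];
        }
        return sum / correct.length;
    }

    public static void main(String[] args) {
        long[] correct = {3, 1}; // batch 0: 3 of 4 correct, batch 1: 1 of 2 correct
        long[] total = {4, 2};
        System.out.println(pooled(correct, total));        // 4 / 6
        System.out.println(meanOfBatches(correct, total)); // (0.75 + 0.5) / 2 = 0.625
    }
}
```

Pooling counts gives the accuracy over all instances seen so far, which is what the formula at the top of this page defines.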
resetAccumulator
public void resetAccumulator(java.lang.String key)
Resets the evaluator value with the given key.
Specified by:
resetAccumulator in class Evaluator
Parameters:
- key - the key of the accumulator to reset
getAccumulator
public float getAccumulator(java.lang.String key)
Returns the accumulated evaluator value.
Specified by:
getAccumulator in class Evaluator
Parameters:
- key - the key of the accumulator to get
Returns:
- the accumulated value
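The four accumulator methods work together per key: addAccumulator creates the counters, updateAccumulator adds a batch, getAccumulator reads the running value, and resetAccumulator zeroes it, for example between epochs. The sketch below models that lifecycle with two `ConcurrentHashMap`s standing in for `totalInstances` and `correctInstances`. It is a hypothetical illustration, not DJL's implementation: it takes precomputed counts instead of `NDList`s, the key "train" is arbitrary, and returning `Float.NaN` for an empty accumulator is this sketch's own choice.

```java
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;

/** Hypothetical model of the per-key accumulator lifecycle (not DJL code). */
final class AccumulatorSketch {
    private final Map<String, Long> totalInstances = new ConcurrentHashMap<>();
    private final Map<String, Long> correctInstances = new ConcurrentHashMap<>();

    void addAccumulator(String key) {
        totalInstances.put(key, 0L);
        correctInstances.put(key, 0L);
    }

    /** Adds one batch's counts; merge() updates each map entry atomically (the two maps are not updated as one step). */
    void updateAccumulator(String key, long total, long correct) {
        totalInstances.merge(key, total, Long::sum);
        correctInstances.merge(key, correct, Long::sum);
    }

    void resetAccumulator(String key) {
        totalInstances.put(key, 0L);
        correctInstances.put(key, 0L);
    }

    /** Correct over total; NaN when nothing has been accumulated (a choice of this sketch). */
    float getAccumulator(String key) {
        long total = totalInstances.getOrDefault(key, 0L);
        if (total == 0) {
            return Float.NaN;
        }
        long correct = correctInstances.getOrDefault(key, 0L);
        return (float) correct / total;
    }

    public static void main(String[] args) {
        AccumulatorSketch acc = new AccumulatorSketch();
        acc.addAccumulator("train");
        acc.updateAccumulator("train", 4, 3); // batch 1: 3 of 4 correct
        acc.updateAccumulator("train", 2, 1); // batch 2: 1 of 2 correct
        System.out.println(acc.getAccumulator("train")); // 4 / 6 as a float
        acc.resetAccumulator("train");
        System.out.println(acc.getAccumulator("train")); // prints NaN after reset
    }
}
```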