public abstract class AbstractAccuracy extends Evaluator
AbstractAccuracy is an Evaluator that computes the accuracy score. Concrete subclasses implement the abstract accuracyHelper method.
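The accuracy score is defined as \(\mathrm{accuracy}(y, \hat{y}) = \frac{1}{n}\sum_{i=0}^{n-1} \mathbb{1}(\hat{y}_i = y_i)\), where \(n\) is the number of instances, \(y_i\) is the label and \(\hat{y}_i\) is the prediction for instance \(i\), and \(\mathbb{1}(\cdot)\) is the indicator function (1 when its argument holds, 0 otherwise).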
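As a quick check of the formula, the sketch below evaluates it directly on plain int arrays of class ids. It does not use DJL or NDArray; the class AccuracyFormulaExample and its accuracy method are hypothetical names chosen for illustration.

```java
public class AccuracyFormulaExample {
    // Computes (1/n) * sum over i of 1(yHat_i == y_i) for integer class labels.
    static double accuracy(int[] labels, int[] predicted) {
        if (labels.length == 0 || labels.length != predicted.length) {
            throw new IllegalArgumentException("labels and predictions must be non-empty and of equal length");
        }
        long correct = 0;
        for (int i = 0; i < labels.length; i++) {
            if (predicted[i] == labels[i]) { // indicator 1(yHat_i == y_i)
                correct++;
            }
        }
        return (double) correct / labels.length; // divide by n
    }

    public static void main(String[] args) {
        int[] y = {0, 2, 1, 1};
        int[] yHat = {0, 1, 1, 1};
        System.out.println(accuracy(y, yHat)); // 3 of 4 match, prints 0.75
    }
}
```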
| Modifier and Type | Field and Description |
|---|---|
| protected int | axis |
| protected java.util.Map<java.lang.String,java.lang.Long> | correctInstances |
| protected int | index |

Fields inherited from class Evaluator: totalInstances
| Constructor and Description |
|---|
| AbstractAccuracy(java.lang.String name, int index) Creates an accuracy evaluator that computes accuracy across axis 1 along the given index. |
| AbstractAccuracy(java.lang.String name, int index, int axis) Creates an accuracy evaluator with an explicit class axis. |
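The index and axis parameters can be pictured with plain Java arrays. This sketch assumes that each entry of a List stands in for one NDArray of an NDList, that index selects the same entry in both labels and predictions (the page only says it applies to labels), and that the predicted class is the argmax along axis. AxisIndexExample and argmax are hypothetical names, not DJL API.

```java
import java.util.Arrays;
import java.util.List;

public class AxisIndexExample {
    // Argmax of a 2-D score matrix along the given axis.
    // axis = 1: rows are instances, columns are class scores (the documented default).
    // axis = 0: rows are class scores, columns are instances.
    static int[] argmax(float[][] scores, int axis) {
        int rows = scores.length;
        int cols = scores[0].length;
        if (axis == 1) {
            int[] out = new int[rows];
            for (int r = 0; r < rows; r++) {
                int best = 0;
                for (int c = 1; c < cols; c++) {
                    if (scores[r][c] > scores[r][best]) {
                        best = c;
                    }
                }
                out[r] = best;
            }
            return out;
        }
        if (axis == 0) {
            int[] out = new int[cols];
            for (int c = 0; c < cols; c++) {
                int best = 0;
                for (int r = 1; r < rows; r++) {
                    if (scores[r][c] > scores[best][c]) {
                        best = r;
                    }
                }
                out[c] = best;
            }
            return out;
        }
        throw new IllegalArgumentException("axis must be 0 or 1 for a 2-D array");
    }

    public static void main(String[] args) {
        // Hypothetical two-output model: each list entry plays the role of one NDArray in an NDList.
        List<int[]> labels = List.of(new int[] {1, 2}, new int[] {0});
        List<float[][]> predictions = List.of(
                new float[][] {{0.1f, 0.7f, 0.2f}, {0.8f, 0.1f, 0.1f}}, // 2 instances x 3 classes
                new float[][] {{0.4f, 0.6f}});                         // second output, unused here
        int index = 0; // which output to score
        int axis = 1;  // class dimension
        int[] predicted = argmax(predictions.get(index), axis);
        System.out.println(Arrays.toString(predicted) + " vs " + Arrays.toString(labels.get(index)));
        // prints "[1, 0] vs [1, 2]"
    }
}
```

With the default axis of 1, each row is one instance, so the sketch finds one correct prediction out of two for output 0.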
| Modifier and Type | Method and Description |
|---|---|
| protected abstract ai.djl.util.Pair<java.lang.Long,NDArray> | accuracyHelper(NDList labels, NDList predictions) A helper for classes extending AbstractAccuracy. |
| void | addAccumulator(java.lang.String key) Adds an accumulator for the results of the evaluation with the given key. |
| NDArray | evaluate(NDList labels, NDList predictions) Calculates the evaluation between the labels and the predictions. |
| float | getAccumulator(java.lang.String key) Returns the accumulated evaluator value. |
| void | resetAccumulator(java.lang.String key) Resets the evaluator value with the given key. |
| void | updateAccumulator(java.lang.String key, NDList labels, NDList predictions) Updates the evaluator with the given key based on an NDList of labels and predictions. |
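Pairing an abstract accuracyHelper with concrete evaluation methods suggests a template-method design: the base class owns the scoring, and subclasses only decide which instances count as correct. The plain-Java sketch below illustrates that pattern under stated assumptions. Map.Entry stands in for ai.djl.util.Pair, int arrays stand in for NDList, and a boolean mask stands in for the NDArray. It also assumes the Long is the number of instances scored and the NDArray marks which of them were correct, which this page does not state. BaseAccuracy, ExactMatchAccuracy, and score are hypothetical.

```java
import java.util.AbstractMap;
import java.util.Map;

public class HelperPatternSketch {
    // Hypothetical base class: subclasses decide which instances are correct,
    // and the base class turns that decision into a score.
    abstract static class BaseAccuracy {
        // Assumed contract: (number of instances scored, per-instance correctness mask).
        protected abstract Map.Entry<Long, boolean[]> accuracyHelper(int[] labels, int[] predictions);

        float score(int[] labels, int[] predictions) {
            Map.Entry<Long, boolean[]> result = accuracyHelper(labels, predictions);
            long correct = 0;
            for (boolean hit : result.getValue()) {
                if (hit) {
                    correct++;
                }
            }
            long total = result.getKey();
            return total == 0 ? Float.NaN : (float) correct / total;
        }
    }

    // Hypothetical subclass: exact match between the predicted class id and the label.
    static class ExactMatchAccuracy extends BaseAccuracy {
        @Override
        protected Map.Entry<Long, boolean[]> accuracyHelper(int[] labels, int[] predictions) {
            if (labels.length != predictions.length) {
                throw new IllegalArgumentException("labels and predictions differ in length");
            }
            boolean[] correct = new boolean[labels.length];
            for (int i = 0; i < labels.length; i++) {
                correct[i] = labels[i] == predictions[i];
            }
            return new AbstractMap.SimpleImmutableEntry<>((long) labels.length, correct);
        }
    }

    public static void main(String[] args) {
        BaseAccuracy accuracy = new ExactMatchAccuracy();
        System.out.println(accuracy.score(new int[] {1, 0, 2, 2}, new int[] {1, 0, 0, 2})); // prints 0.75
    }
}
```

Returning the raw count alongside a per-instance mask, rather than a ready-made ratio, lets a caller add counts across batches before dividing.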
Methods inherited from class Evaluator: checkLabelShapes, checkLabelShapes, getName
protected java.util.Map<java.lang.String,java.lang.Long> correctInstances
protected int axis
protected int index

public AbstractAccuracy(java.lang.String name, int index)
Creates an accuracy evaluator that computes accuracy across axis 1 along the given index.
Parameters:
name - the name of the evaluator, default is "Accuracy"
index - the index of the NDArray in labels to compute accuracy for

public AbstractAccuracy(java.lang.String name, int index, int axis)
Creates an accuracy evaluator with an explicit class axis.
Parameters:
name - the name of the evaluator, default is "Accuracy"
index - the index of the NDArray in labels to compute accuracy for
axis - the axis that represents classes in the prediction, default 1

protected abstract ai.djl.util.Pair<java.lang.Long,NDArray> accuracyHelper(NDList labels, NDList predictions)
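A helper for classes extending AbstractAccuracy.
Parameters:
labels - the labels to get accuracy for
predictions - the predictions to get accuracy for

public NDArray evaluate(NDList labels, NDList predictions)
Calculates the evaluation between the labels and the predictions.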

public void addAccumulator(java.lang.String key)
Adds an accumulator for the results of the evaluation with the given key.
Specified by: addAccumulator in class Evaluator
Parameters:
key - the key for the new accumulator

public void updateAccumulator(java.lang.String key, NDList labels, NDList predictions)
Updates the evaluator with the given key based on an NDList of labels and predictions.
This is a synchronized operation, so call it only at the end of a batch or epoch.
Specified by: updateAccumulator in class Evaluator
Parameters:
key - the key of the accumulator to update
labels - an NDList of labels
predictions - an NDList of predictions

public void resetAccumulator(java.lang.String key)
Resets the evaluator value with the given key.
Specified by: resetAccumulator in class Evaluator
Parameters:
key - the key of the accumulator to reset

public float getAccumulator(java.lang.String key)
Returns the accumulated evaluator value.
Specified by: getAccumulator in class Evaluator
Parameters:
key - the key of the accumulator to get
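The four accumulator methods, together with the correctInstances field (a Map<String, Long>) and the inherited totalInstances field, suggest a running tally kept separately for each key. The following is a minimal sketch of that lifecycle. It assumes totalInstances is also a per-key map, that the accumulated value is correct instances divided by total instances, and that an empty accumulator yields NaN; none of these is stated on this page. AccumulatorSketch is hypothetical, and its updateAccumulator takes counts instead of NDList arguments so the example stays self-contained.

```java
import java.util.HashMap;
import java.util.Map;

public class AccumulatorSketch {
    // Hypothetical stand-ins for the totalInstances and correctInstances fields.
    private final Map<String, Long> totalInstances = new HashMap<>();
    private final Map<String, Long> correctInstances = new HashMap<>();

    void addAccumulator(String key) {
        totalInstances.put(key, 0L);
        correctInstances.put(key, 0L);
    }

    // Simplified: the caller passes per-batch counts instead of labels and predictions.
    void updateAccumulator(String key, long batchTotal, long batchCorrect) {
        totalInstances.merge(key, batchTotal, Long::sum);
        correctInstances.merge(key, batchCorrect, Long::sum);
    }

    void resetAccumulator(String key) {
        totalInstances.put(key, 0L);
        correctInstances.put(key, 0L);
    }

    float getAccumulator(String key) {
        long total = totalInstances.getOrDefault(key, 0L);
        if (total == 0L) {
            return Float.NaN; // assumed behavior when nothing has been accumulated
        }
        long correct = correctInstances.getOrDefault(key, 0L);
        return (float) correct / total;
    }

    public static void main(String[] args) {
        AccumulatorSketch tally = new AccumulatorSketch();
        tally.addAccumulator("train");
        tally.updateAccumulator("train", 4, 3); // batch 1: 3 of 4 correct
        tally.updateAccumulator("train", 4, 1); // batch 2: 1 of 4 correct
        System.out.println(tally.getAccumulator("train")); // 4 of 8, prints 0.5
        tally.resetAccumulator("train");
        System.out.println(tally.getAccumulator("train")); // prints NaN after reset
    }
}
```

Because the counts are summed before dividing, the result is accuracy over all instances seen so far rather than an average of per-batch accuracies. Using one key per phase, such as "train" and "validate", keeps those tallies independent.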