public class Evaluation extends BaseEvaluation<Evaluation>
Evaluation metrics for classification: precision, recall, F1 score, accuracy, and related statistics.

Usage notes:
- Top N accuracy is available via the Evaluation(List, int) constructor.
- A custom binary decision threshold can be set via Evaluation(double) (default if not set is argmax / 0.5).
- A custom cost array can be applied via Evaluation(INDArray) or Evaluation(List, INDArray), including for multi-class evaluation.

For binary classification:
- By default, class 1 is treated as the positive class, so f1(), precision(), recall() etc. will report the binary metric for class 1 only.
- To treat class 0 as the positive class instead, use Evaluation(int, Integer), Evaluation(double, Integer) or setBinaryPositiveClass(Integer). Then f1(), precision(), recall() etc. will report the binary metric for class 0 only.

For multi-class classification (or when the binary positive class is set to null), f1(), precision() and recall() will report the macro-average of the one-vs-all binary metrics. Note that you can specify micro vs. macro averaging using f1(EvaluationAveraging) and similar methods.
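A minimal usage sketch (assuming the org.nd4j.evaluation.classification and org.nd4j.linalg package locations of recent ND4J/DL4J releases; older releases expose this class as org.deeplearning4j.eval.Evaluation):

```java
import org.nd4j.evaluation.classification.Evaluation;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;

public class EvaluationExample {
    public static void main(String[] args) {
        // Three examples, two classes: rows are examples, columns are class probabilities
        INDArray labels = Nd4j.create(new double[][] {
                {1, 0},    // actual class 0
                {0, 1},    // actual class 1
                {0, 1}});  // actual class 1
        INDArray predictions = Nd4j.create(new double[][] {
                {0.9, 0.1},    // predicted class 0 (correct)
                {0.2, 0.8},    // predicted class 1 (correct)
                {0.6, 0.4}});  // predicted class 0 (wrong)

        Evaluation eval = new Evaluation(2);   // 2 classes
        eval.eval(labels, predictions);        // accumulate statistics

        System.out.println(eval.stats());                    // full classification report
        System.out.println("Accuracy: " + eval.accuracy());  // (TP + TN) / (P + N) = 2/3
        System.out.println("F1 (class 1): " + eval.f1());    // binary F1 for the default positive class
    }
}
```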
Modifier and Type | Class and Description |
---|---|
static class | Evaluation.Metric |
Modifier and Type | Field and Description |
---|---|
protected int | axis |
protected Double | binaryDecisionThreshold |
protected Integer | binaryPositiveClass |
protected ConfusionMatrix<Integer> | confusion |
protected static int | CONFUSION_PRINT_MAX_CLASSES |
protected Map<Pair<Integer,Integer>,List<Object>> | confusionMatrixMetaData |
protected INDArray | costArray |
protected static double | DEFAULT_EDGE_VALUE |
protected Counter<Integer> | falseNegatives |
protected Counter<Integer> | falsePositives |
protected List<String> | labelsList |
protected int | maxWarningClassesToPrint - For stats(): when classes are excluded from precision/recall, the maximum number to print; if set to a high value, the output (potentially thousands of classes) can become unreadable |
protected int | numRowCounter |
protected int | topN |
protected int | topNCorrectCount |
protected int | topNTotalCount |
protected Counter<Integer> | trueNegatives |
protected Counter<Integer> | truePositives |
Modifier | Constructor and Description |
---|---|
 | Evaluation() |
 | Evaluation(double binaryDecisionThreshold) - Create an evaluation instance with a custom binary decision threshold |
 | Evaluation(double binaryDecisionThreshold, @NonNull Integer binaryPositiveClass) - Create an evaluation instance with a custom binary decision threshold and binary positive class |
 | Evaluation(INDArray costArray) - Create an evaluation instance with the specified cost array |
 | Evaluation(int numClasses) - The number of classes to account for in the evaluation |
 | Evaluation(int numClasses, Integer binaryPositiveClass) - Specify the number of classes, and optionally the positive class for binary classification |
protected | Evaluation(int axis, Integer binaryPositiveClass, int topN, List<String> labelsList, Double binaryDecisionThreshold, INDArray costArray, int maxWarningClassesToPrint) |
 | Evaluation(List<String> labels) - The labels to include with the evaluation |
 | Evaluation(List<String> labels, INDArray costArray) - Create an evaluation instance with the specified cost array |
 | Evaluation(List<String> labels, int topN) - Constructor to use for top N accuracy |
 | Evaluation(Map<Integer,String> labels) - Use a map to generate labels; pass in a label index with the actual label you want to use for output |
Modifier and Type | Method and Description |
---|---|
double | accuracy() - Accuracy: (TP + TN) / (P + N) |
void | addToConfusion(Integer real, Integer guess) - Adds to the confusion matrix |
int | averageF1NumClassesExcluded() - When calculating the (macro) average F1, the number of classes excluded from the average due to no predictions (i.e., F1 would be calculated from a precision or recall of 0/0) |
int | averageFBetaNumClassesExcluded() - When calculating the (macro) average FBeta, the number of classes excluded from the average due to no predictions (i.e., FBeta would be calculated from a precision or recall of 0/0) |
int | averagePrecisionNumClassesExcluded() - When calculating the (macro) average precision, the number of classes excluded from the average due to no predictions (i.e., precision would be the edge case of 0/0) |
int | averageRecallNumClassesExcluded() - When calculating the (macro) average recall, the number of classes excluded from the average due to no predictions (i.e., recall would be the edge case of 0/0) |
int | classCount(Integer clazz) - Returns the number of times the given label has actually occurred |
String | confusionMatrix() - Get the confusion matrix as a String |
String | confusionToString() - Get a String representation of the confusion matrix |
void | eval(INDArray realOutcomes, INDArray guesses) - Collects statistics on the real outcomes vs. the guesses |
void | eval(INDArray labels, INDArray predictions, INDArray mask, List<? extends Serializable> recordMetaData) - Evaluate the network, with optional metadata |
void | eval(int predictedIdx, int actualIdx) - Evaluate a single prediction (one prediction at a time) |
double | f1() - Calculate the F1 score: 2 * TP / (2TP + FP + FN), where TP = true positives, FP = false positives, FN = false negatives. Note: the value returned differs depending on the number of classes and settings; see the method detail below |
double | f1(EvaluationAveraging averaging) - Calculate the average F1 score across all classes, using macro or micro averaging |
double | f1(int classLabel) - Calculate the F1 score for a given class |
double | falseAlarmRate() - False Alarm Rate (FAR): the rate of misclassified to classified records (http://ro.ecu.edu.au/cgi/viewcontent.cgi?article=1058&context=isw). Note: the value returned differs depending on the number of classes and settings; see the method detail below |
double | falseNegativeRate() - False negative rate based on guesses so far. Note: the value returned differs depending on the number of classes and settings; see the method detail below |
double | falseNegativeRate(EvaluationAveraging averaging) - Calculate the average false negative rate for all classes, using macro or micro averaging |
double | falseNegativeRate(Integer classLabel) - Returns the false negative rate for a given label |
double | falseNegativeRate(Integer classLabel, double edgeCase) - Returns the false negative rate for a given label, with a specified edge-case value |
Map<Integer,Integer> | falseNegatives() - False negatives: incorrectly rejected |
double | falsePositiveRate() - False positive rate based on guesses so far. Note: the value returned differs depending on the number of classes and settings; see the method detail below |
double | falsePositiveRate(EvaluationAveraging averaging) - Calculate the average false positive rate across all classes |
double | falsePositiveRate(int classLabel) - Returns the false positive rate for a given label |
double | falsePositiveRate(int classLabel, double edgeCase) - Returns the false positive rate for a given label, with a specified edge-case value |
Map<Integer,Integer> | falsePositives() - False positives: wrong guesses |
double | fBeta(double beta, EvaluationAveraging averaging) - Calculate the average F_beta score across all classes, using macro or micro averaging |
double | fBeta(double beta, int classLabel) - Calculate F_beta for a given class, where F_beta = (1 + beta^2) * (precision * recall) / (beta^2 * precision + recall); F1 is the special case of F_beta with beta = 1.0 |
double | fBeta(double beta, int classLabel, double defaultValue) - Calculate F_beta for a given class as above, with a default value for the undefined (0/0) case |
static Evaluation | fromJson(String json) |
static Evaluation | fromYaml(String yaml) |
int | getAxis() - Get the axis - see setAxis(int) for details |
String | getClassLabel(Integer clazz) |
ConfusionMatrix<Integer> | getConfusionMatrix() - Returns the confusion matrix variable |
int | getNumRowCounter() |
List<Prediction> | getPredictionByPredictedClass(int predictedClass) - Get a list of predictions, for all data with the specified predicted class, regardless of the actual data class |
List<Prediction> | getPredictionErrors() - Get a list of prediction errors, on a per-record basis |
List<Prediction> | getPredictions(int actualClass, int predictedClass) - Get a list of predictions in the specified confusion matrix entry (i.e., for the given actual/predicted class pair) |
List<Prediction> | getPredictionsByActualClass(int actualClass) - Get a list of predictions, for all data with the specified actual class, regardless of the predicted class |
int | getTopNCorrectCount() - Return the number of correct predictions according to the top N value |
int | getTopNTotalCount() - Return the total number of top N evaluations |
double | getValue(IMetric metric) - Get the value of a given metric for this evaluation |
double | gMeasure(EvaluationAveraging averaging) - Calculates the average G measure for all outputs, using micro or macro averaging |
double | gMeasure(int output) - Calculate the G-measure for the given output |
void | incrementFalseNegatives(Integer classLabel) |
void | incrementFalsePositives(Integer classLabel) |
void | incrementTrueNegatives(Integer classLabel) |
void | incrementTruePositives(Integer classLabel) |
double | matthewsCorrelation(EvaluationAveraging averaging) - Calculate the average binary Matthews correlation coefficient, using macro or micro averaging; MCC = (TP*TN - FP*FN) / sqrt((TP+FP)(TP+FN)(TN+FP)(TN+FN)). Note: this is NOT the same as the multi-class Matthews correlation coefficient |
double | matthewsCorrelation(int classIdx) - Calculate the binary Matthews correlation coefficient for the specified class: MCC = (TP*TN - FP*FN) / sqrt((TP+FP)(TP+FN)(TN+FP)(TN+FN)) |
void | merge(Evaluation other) - Merge the other evaluation object into this one |
Map<Integer,Integer> | negative() - Total negatives: true negatives + false negatives |
Evaluation | newInstance() - Get a new instance of this evaluation, with the same configuration but no data |
protected int | numClasses() |
Map<Integer,Integer> | positive() - Total actual positives: true positives + false negatives |
double | precision() - Precision based on guesses so far. Note: the value returned differs depending on the number of classes and settings; see the method detail below |
double | precision(EvaluationAveraging averaging) - Calculate the average precision for all classes |
double | precision(Integer classLabel) - Returns the precision for a given class label |
double | precision(Integer classLabel, double edgeCase) - Returns the precision for a given label, with a specified edge-case value |
double | recall() - Recall based on guesses so far. Note: the value returned differs depending on the number of classes and settings; see the method detail below |
double | recall(EvaluationAveraging averaging) - Calculate the average recall for all classes, using macro or micro averaging. NOTE: any classes with tp=0 and fn=0 (recall = 0/0) are excluded from the (macro) average |
double | recall(int classLabel) - Returns the recall for a given label |
double | recall(int classLabel, double edgeCase) - Returns the recall for a given label, with a specified edge-case value |
void | reset() |
double | scoreForMetric(Evaluation.Metric metric) |
void | setAxis(int axis) - Set the axis for evaluation - the dimension along which the probability (and label classes) are present. For DL4J, this can be left as the default (axis = 1). Axis should be set as follows: 2D (OutputLayer), shape [minibatch, numClasses]: axis = 1; 3D RNNs/CNN1D (DL4J RnnOutputLayer), NCW format, shape [minibatch, numClasses, sequenceLength]: axis = 1; 3D, NWC format, shape [minibatch, sequenceLength, numClasses]: axis = 2; 4D CNN2D (DL4J CnnLossLayer), NCHW format, shape [minibatch, channels, height, width]: axis = 1; 4D CNN2D, NHWC format, shape [minibatch, height, width, channels]: axis = 3 |
String | stats() - Report the classification statistics as a String |
String | stats(boolean suppressWarnings) - Obtain the classification report as a String, optionally suppressing warnings |
String | stats(boolean suppressWarnings, boolean includeConfusion) - Obtain the classification report as a String, optionally including the confusion matrix |
double | topNAccuracy() - Top N accuracy of the predictions so far |
Map<Integer,Integer> | trueNegatives() - True negatives: correctly rejected |
Map<Integer,Integer> | truePositives() - True positives: correctly identified |
Methods inherited from class BaseEvaluation: attempFromLegacyFromJson, eval, eval, evalTimeSeries, evalTimeSeries, fromJson, fromYaml, reshapeAndExtractNotMasked, toJson, toString, toYaml
protected static final double DEFAULT_EDGE_VALUE
protected static final int CONFUSION_PRINT_MAX_CLASSES
protected int axis
protected Integer binaryPositiveClass
protected final int topN
protected int topNCorrectCount
protected int topNTotalCount
protected ConfusionMatrix<Integer> confusion
protected int numRowCounter
protected Double binaryDecisionThreshold
protected INDArray costArray
protected int maxWarningClassesToPrint
protected Evaluation(int axis, Integer binaryPositiveClass, int topN, List<String> labelsList, Double binaryDecisionThreshold, INDArray costArray, int maxWarningClassesToPrint)
public Evaluation()
public Evaluation(int numClasses)
numClasses - the number of classes to account for in the evaluation

public Evaluation(int numClasses, Integer binaryPositiveClass)
numClasses - The number of classes for the evaluation. Must be 2 if binaryPositiveClass is non-null
binaryPositiveClass - If non-null, the positive class (0 or 1)

public Evaluation(List<String> labels)
labels - the labels to use for the output

public Evaluation(Map<Integer,String> labels)
labels - a map of label index to label value

public Evaluation(List<String> labels, int topN)
labels - Labels for the classes (may be null)
topN - Value to use for top N accuracy calculation (<=1: standard accuracy). Note that with top N accuracy, an example is considered 'correct' if the probability for the true class is one of the highest N values

public Evaluation(double binaryDecisionThreshold)
Note: class 1 is used as the binary positive class by default; use Evaluation(double, Integer) to change this.
binaryDecisionThreshold - Decision threshold to use for binary predictions

public Evaluation(double binaryDecisionThreshold, @NonNull Integer binaryPositiveClass)
binaryDecisionThreshold - Decision threshold to use for binary predictions

public Evaluation(INDArray costArray)
costArray - Row vector cost array. May be null

public Evaluation(List<String> labels, INDArray costArray)
labels - Labels for the output classes. May be null
costArray - Row vector cost array. May be null
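A sketch of the constructor variants above (class and package names as in recent ND4J releases; label names and threshold values are illustrative):

```java
import java.util.Arrays;
import java.util.List;
import org.nd4j.evaluation.classification.Evaluation;

public class ConstructorExamples {
    public static void main(String[] args) {
        List<String> binaryLabels = Arrays.asList("negative", "positive");
        List<String> multiLabels = Arrays.asList("cat", "dog", "bird");

        Evaluation basic = new Evaluation(2);              // fixed number of classes
        Evaluation named = new Evaluation(binaryLabels);   // human-readable names in stats()
        Evaluation top2 = new Evaluation(multiLabels, 2);  // top-2 accuracy over 3 classes
        Evaluation threshold = new Evaluation(0.4);        // predict class 1 when p(class 1) >= 0.4
        Evaluation class0Positive = new Evaluation(2, 0);  // treat class 0 as the positive class
    }
}
```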
protected int numClasses()

public void reset()

public void setAxis(int axis)
Set the axis for evaluation - the dimension along which the probability (and label classes) are present. For DL4J, this can usually be left as the default (axis = 1).
axis - Axis to use for evaluation

public int getAxis()
Get the axis - see setAxis(int) for details
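A configuration sketch for non-default layouts (the shapes follow the setAxis(int) description above; the class count is illustrative):

```java
import org.nd4j.evaluation.classification.Evaluation;

public class AxisConfig {
    public static void main(String[] args) {
        Evaluation eval = new Evaluation(10);

        // Default axis = 1 covers 2D [minibatch, numClasses], NCW and NCHW outputs.
        // For NWC-format sequence output, shape [minibatch, sequenceLength, numClasses]:
        eval.setAxis(2);
        // For NHWC-format output, shape [minibatch, height, width, channels], use axis 3:
        // eval.setAxis(3);

        System.out.println("Evaluation axis: " + eval.getAxis());
    }
}
```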
public void eval(INDArray realOutcomes, INDArray guesses)
Note that an IllegalArgumentException is thrown if the two passed-in matrices aren't the same length.
Specified by: eval in interface IEvaluation<Evaluation>
Specified by: eval in class BaseEvaluation<Evaluation>
realOutcomes - the real outcomes (labels - usually binary)
guesses - the guesses/predictions (usually a probability vector)

public void eval(INDArray labels, INDArray predictions, INDArray mask, List<? extends Serializable> recordMetaData)
labels - Data labels
predictions - Network predictions
recordMetaData - Optional; may be null. If not null, should have size equal to the number of outcomes/guesses

public void eval(int predictedIdx, int actualIdx)
predictedIdx - Index of class predicted by the network
actualIdx - Index of actual class
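When only predicted/actual class indices are available (for example, from a non-DL4J model), statistics can be accumulated one prediction at a time; note that top N accuracy cannot be computed this way (see getTopNTotalCount()). A sketch:

```java
import org.nd4j.evaluation.classification.Evaluation;

public class StreamingEval {
    public static void main(String[] args) {
        // (predicted, actual) index pairs, e.g., collected from an external classifier
        int[][] pairs = { {0, 0}, {1, 1}, {0, 1}, {1, 1} };

        Evaluation eval = new Evaluation(2);
        for (int[] p : pairs) {
            eval.eval(p[0], p[1]);   // eval(predictedIdx, actualIdx)
        }
        System.out.println(eval.stats());
    }
}
```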
public String stats()

public String stats(boolean suppressWarnings)
suppressWarnings - whether or not to output warnings related to the evaluation results

public String stats(boolean suppressWarnings, boolean includeConfusion)
suppressWarnings - whether or not to output warnings related to the evaluation results
includeConfusion - whether the confusion matrix should be included in the returned stats or not

public String confusionMatrix()
public double precision(Integer classLabel)
classLabel - the label

public double precision(Integer classLabel, double edgeCase)
classLabel - the label
edgeCase - What to output in case of 0/0

public double precision()
Precision based on guesses so far. Note: the value returned will differ depending on the number of classes and settings:
1. For binary classification, if the positive class is set (via setBinaryPositiveClass(Integer)), the returned value will be for the specified positive class only.
2. For the multi-class case, or if getBinaryPositiveClass() is null, the returned value is macro-averaged across all classes, i.e., macro-averaged precision, equivalent to precision(EvaluationAveraging.Macro)
public double precision(EvaluationAveraging averaging)
averaging - Averaging method - macro or micro

public int averagePrecisionNumClassesExcluded()

public int averageRecallNumClassesExcluded()

public int averageF1NumClassesExcluded()

public int averageFBetaNumClassesExcluded()
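A sketch contrasting the two averaging modes (EvaluationAveraging is assumed to live at org.nd4j.evaluation.EvaluationAveraging in recent releases; counts are illustrative):

```java
import org.nd4j.evaluation.EvaluationAveraging;
import org.nd4j.evaluation.classification.Evaluation;

public class AveragingExample {
    public static void main(String[] args) {
        Evaluation eval = new Evaluation(3);
        // eval(predictedIdx, actualIdx) - accumulate a few illustrative predictions
        eval.eval(0, 0);
        eval.eval(0, 0);
        eval.eval(1, 0);
        eval.eval(1, 1);
        eval.eval(2, 2);
        eval.eval(0, 2);

        // Macro: compute the per-class (one-vs-all) metric, then average over classes
        double macro = eval.precision(EvaluationAveraging.Macro);   // (2/3 + 1/2 + 1) / 3 ≈ 0.722
        // Micro: pool TP/FP counts over all classes, then compute the metric once
        double micro = eval.precision(EvaluationAveraging.Micro);   // 4 / 6 ≈ 0.667
        System.out.println(macro + " " + micro);
    }
}
```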
public double recall(int classLabel)
classLabel - the label

public double recall(int classLabel, double edgeCase)
classLabel - the label
edgeCase - What to output in case of 0/0

public double recall()
Recall based on guesses so far. Note: the value returned will differ depending on the number of classes and settings:
1. For binary classification, if the positive class is set (via setBinaryPositiveClass(Integer)), the returned value will be for the specified positive class only.
2. For the multi-class case, or if getBinaryPositiveClass() is null, the returned value is macro-averaged across all classes, i.e., macro-averaged recall, equivalent to recall(EvaluationAveraging.Macro)
public double recall(EvaluationAveraging averaging)
averaging - Averaging method - macro or micro

public double falsePositiveRate(int classLabel)
classLabel - the label

public double falsePositiveRate(int classLabel, double edgeCase)
classLabel - the label
edgeCase - What to output in case of 0/0

public double falsePositiveRate()
False positive rate based on guesses so far. Note: the value returned will differ depending on the number of classes and settings:
1. For binary classification, if the positive class is set (via setBinaryPositiveClass(Integer)), the returned value will be for the specified positive class only.
2. For the multi-class case, or if getBinaryPositiveClass() is null, the returned value is macro-averaged across all classes, i.e., macro-averaged false positive rate, equivalent to falsePositiveRate(EvaluationAveraging.Macro)
public double falsePositiveRate(EvaluationAveraging averaging)
averaging - Averaging method - macro or micro

public double falseNegativeRate(Integer classLabel)
classLabel - the label

public double falseNegativeRate(Integer classLabel, double edgeCase)
classLabel - the label
edgeCase - What to output in case of 0/0

public double falseNegativeRate()
False negative rate based on guesses so far. Note: the value returned will differ depending on the number of classes and settings:
1. For binary classification, if the positive class is set (via setBinaryPositiveClass(Integer)), the returned value will be for the specified positive class only.
2. For the multi-class case, or if getBinaryPositiveClass() is null, the returned value is macro-averaged across all classes, i.e., macro-averaged false negative rate, equivalent to falseNegativeRate(EvaluationAveraging.Macro)
public double falseNegativeRate(EvaluationAveraging averaging)
averaging - Averaging method - macro or micro

public double falseAlarmRate()
False Alarm Rate (FAR) reflects the rate of misclassified to classified records (http://ro.ecu.edu.au/cgi/viewcontent.cgi?article=1058&context=isw). Note: the value returned will differ depending on the number of classes and settings:
1. For binary classification, if the positive class is set (via setBinaryPositiveClass(Integer)), the returned value will be for the specified positive class only.
2. For the multi-class case, or if getBinaryPositiveClass() is null, the returned value is macro-averaged across all classes (i.e., macro-averaged false alarm rate)

public double f1(int classLabel)
classLabel - the label to calculate F1 for

public double fBeta(double beta, int classLabel)
beta - Beta value to use
classLabel - Class label

public double fBeta(double beta, int classLabel, double defaultValue)
beta - Beta value to use
classLabel - Class label
defaultValue - Default value to use when precision or recall is undefined (0/0 for precision or recall)
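A worked sketch of the F_beta formula from the summary above (beta = 2 weights recall more heavily than precision; counts are illustrative):

```java
import org.nd4j.evaluation.classification.Evaluation;

public class FBetaExample {
    public static void main(String[] args) {
        Evaluation eval = new Evaluation(2);
        eval.eval(1, 1);   // predicted 1, actual 1: TP for class 1
        eval.eval(1, 0);   // predicted 1, actual 0: FP for class 1
        eval.eval(0, 1);   // predicted 0, actual 1: FN for class 1
        eval.eval(0, 0);   // predicted 0, actual 0: TN for class 1

        // Class 1: precision = 1/2, recall = 1/2
        // f_beta = (1 + beta^2) * (p * r) / (beta^2 * p + r)
        // beta = 2: (1 + 4) * 0.25 / (4 * 0.5 + 0.5) = 1.25 / 2.5 = 0.5
        System.out.println(eval.fBeta(2.0, 1));
    }
}
```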
public double f1()
Calculate the F1 score: F1 = 2 * TP / (2TP + FP + FN). Note: the value returned will differ depending on the number of classes and settings:
1. For binary classification, if the positive class is set (via setBinaryPositiveClass(Integer)), the returned value will be for the specified positive class only.
2. For the multi-class case, or if getBinaryPositiveClass() is null, the returned value is macro-averaged across all classes, i.e., macro-averaged F1, equivalent to f1(EvaluationAveraging.Macro)
public double f1(EvaluationAveraging averaging)
averaging - Averaging method to use

public double fBeta(double beta, EvaluationAveraging averaging)
beta - Beta value to use
averaging - Averaging method to use

public double gMeasure(int output)
output - The specified output

public double gMeasure(EvaluationAveraging averaging)
averaging - Averaging method to use

public double accuracy()

public double topNAccuracy()
See also: accuracy()
public double matthewsCorrelation(int classIdx)
classIdx - Class index to calculate the Matthews correlation coefficient for

public double matthewsCorrelation(EvaluationAveraging averaging)
averaging - Averaging approach
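A worked binary sketch of the MCC formula from the summary above (counts are illustrative):

```java
import org.nd4j.evaluation.classification.Evaluation;

public class MccExample {
    public static void main(String[] args) {
        Evaluation eval = new Evaluation(2);
        eval.eval(1, 1);   // TP for class 1
        eval.eval(1, 1);   // TP for class 1
        eval.eval(0, 0);   // TN for class 1
        eval.eval(1, 0);   // predicted 1, actually 0: FP for class 1

        // For class 1: TP=2, TN=1, FP=1, FN=0
        // MCC = (TP*TN - FP*FN) / sqrt((TP+FP)(TP+FN)(TN+FP)(TN+FN))
        //     = (2*1 - 1*0) / sqrt(3 * 2 * 2 * 1) = 2 / sqrt(12) ≈ 0.577
        System.out.println(eval.matthewsCorrelation(1));
    }
}
```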
public Map<Integer,Integer> truePositives()

public Map<Integer,Integer> trueNegatives()
public Map<Integer,Integer> falsePositives()
public Map<Integer,Integer> falseNegatives()
public Map<Integer,Integer> negative()
public Map<Integer,Integer> positive()
public void incrementTruePositives(Integer classLabel)
public void incrementTrueNegatives(Integer classLabel)
public void incrementFalseNegatives(Integer classLabel)
public void incrementFalsePositives(Integer classLabel)
public void addToConfusion(Integer real, Integer guess)
real - the actual (ground truth) class
guess - the system's guess (prediction)

public int classCount(Integer clazz)
clazz - the label

public int getNumRowCounter()

public int getTopNCorrectCount()

public int getTopNTotalCount()
Usually equal to getNumRowCounter(), but may differ when using eval(int, int), as top N accuracy cannot be calculated in that case (i.e., it requires the full probability distribution, not just the predicted/actual indices)

public ConfusionMatrix<Integer> getConfusionMatrix()
public void merge(Evaluation other)
other - Evaluation object to merge into this one

public String confusionToString()

public List<Prediction> getPredictionErrors()
Note: Prediction errors are ONLY available if the "evaluate with metadata" method is used: BaseEvaluation.eval(INDArray, INDArray, List). Otherwise (if the metadata hasn't been recorded via that eval method), there is no value in splitting each prediction out into a separate Prediction object - instead, use the confusion matrix to get the counts, via getConfusionMatrix()
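A sketch of the metadata workflow (the Prediction accessors getActualClass(), getPredictedClass() and getRecordMetaData() are assumed from the org.nd4j.evaluation.meta package; any Serializable metadata works, here plain Strings):

```java
import java.util.Arrays;
import java.util.List;
import org.nd4j.evaluation.classification.Evaluation;
import org.nd4j.evaluation.meta.Prediction;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;

public class MetadataExample {
    public static void main(String[] args) {
        INDArray labels = Nd4j.create(new double[][] {{1, 0}, {0, 1}});
        INDArray predictions = Nd4j.create(new double[][] {{0.3, 0.7}, {0.1, 0.9}});
        // One metadata entry per example - here simple Strings identifying the source records
        List<String> meta = Arrays.asList("record-0", "record-1");

        Evaluation eval = new Evaluation(2);
        eval.eval(labels, predictions, meta);   // the "evaluate with metadata" overload

        for (Prediction p : eval.getPredictionErrors()) {
            System.out.println(p.getActualClass() + " predicted as "
                    + p.getPredictedClass() + ": " + p.getRecordMetaData());
        }
    }
}
```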
public List<Prediction> getPredictionsByActualClass(int actualClass)
Note: Predictions are ONLY available if the "evaluate with metadata" method is used: BaseEvaluation.eval(INDArray, INDArray, List). Otherwise, use the confusion matrix to get the counts, via getConfusionMatrix()
actualClass - Actual class to get predictions for

public List<Prediction> getPredictionByPredictedClass(int predictedClass)
Note: Predictions are ONLY available if the "evaluate with metadata" method is used: BaseEvaluation.eval(INDArray, INDArray, List). Otherwise, use the confusion matrix to get the counts, via getConfusionMatrix()
predictedClass - Predicted class to get predictions for

public List<Prediction> getPredictions(int actualClass, int predictedClass)
actualClass - Actual class
predictedClass - Predicted class

public double scoreForMetric(Evaluation.Metric metric)
public static Evaluation fromJson(String json)
public static Evaluation fromYaml(String yaml)
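A serialization round-trip sketch (toJson() is inherited from BaseEvaluation, as listed above):

```java
import org.nd4j.evaluation.classification.Evaluation;

public class SerializationExample {
    public static void main(String[] args) {
        Evaluation eval = new Evaluation(2);
        eval.eval(0, 0);
        eval.eval(1, 1);

        String json = eval.toJson();                     // inherited from BaseEvaluation
        Evaluation restored = Evaluation.fromJson(json);

        // The restored instance carries the same accumulated statistics
        System.out.println(restored.accuracy());         // 1.0
    }
}
```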
public double getValue(IMetric metric)
Specified by: getValue in interface IEvaluation<Evaluation>

public Evaluation newInstance()
Specified by: newInstance in interface IEvaluation<Evaluation>