public class EvaluationBinary extends BaseEvaluation<EvaluationBinary>
ROCBinary is also used internally to calculate AUC for each output, but only when using an appropriate constructor, EvaluationBinary(int, Integer).
Note that EvaluationBinary supports both per-example and per-output masking.
EvaluationBinary by default uses a decision threshold of 0.5; however, decision thresholds can be set on a per-output basis using EvaluationBinary(INDArray).
The most common use case is multi-task networks, where each output is a binary value. This differs from Evaluation, which evaluates a single classification task (binary or multi-class) rather than many independent binary outputs.
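As a rough illustration of what "each output is a binary value" means for evaluation, the sketch below thresholds every output column at the default 0.5 and keeps a separate set of confusion counts per column. It is plain Java with no ND4J dependency and is not the DL4J implementation. The class name PerOutputCounts and the {tp, fp, tn, fn} layout are hypothetical. Treating a prediction as positive when it is strictly greater than the threshold is an assumption.

```java
// Hypothetical sketch, not the DL4J implementation: each output column of a
// multi-task network is treated as an independent binary classifier.
class PerOutputCounts {
    static final double DEFAULT_THRESHOLD = 0.5;

    // Returns counts[o] = {tp, fp, tn, fn} for each output column o.
    // labels[i][o] is assumed to be 0 or 1; probs[i][o] is a probability in [0, 1].
    static int[][] count(double[][] labels, double[][] probs) {
        int numOutputs = labels[0].length;
        int[][] counts = new int[numOutputs][4];
        for (int i = 0; i < labels.length; i++) {          // each example (row)
            for (int o = 0; o < numOutputs; o++) {         // each output (column)
                boolean actual = labels[i][o] > 0.5;
                // Assumption: "positive" means strictly above the threshold
                boolean predicted = probs[i][o] > DEFAULT_THRESHOLD;
                if (predicted && actual) {
                    counts[o][0]++;                        // true positive
                } else if (predicted) {
                    counts[o][1]++;                        // false positive
                } else if (!actual) {
                    counts[o][2]++;                        // true negative
                } else {
                    counts[o][3]++;                        // false negative
                }
            }
        }
        return counts;
    }

    public static void main(String[] args) {
        double[][] labels = {{1, 0}, {0, 1}, {1, 1}};
        double[][] probs = {{0.9, 0.2}, {0.7, 0.4}, {0.3, 0.8}};
        int[][] c = count(labels, probs);
        System.out.println(java.util.Arrays.toString(c[0])); // [1, 1, 0, 1]
        System.out.println(java.util.Arrays.toString(c[1])); // [1, 0, 1, 1]
    }
}
```

Every column gets its own counts. A network with n binary outputs therefore yields n independent sets of binary metrics, not the single confusion matrix that a multi-class evaluation would build.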
Modifier and Type | Class and Description |
---|---|
static class | EvaluationBinary.Metric |
Modifier and Type | Field and Description |
---|---|
protected int | axis |
static double | DEFAULT_EDGE_VALUE |
static int | DEFAULT_PRECISION |
Modifier | Constructor and Description |
---|---|
public | EvaluationBinary(INDArray decisionThreshold): Create an EvaluationBinary instance with an optional decision threshold array. |
public | EvaluationBinary(int size, Integer rocBinarySteps): This constructor allows for ROC to be calculated in addition to the standard evaluation metrics, when the rocBinarySteps arg is non-null. |
protected | EvaluationBinary(int axis, ROCBinary rocBinary, List<String> labels, INDArray decisionThreshold) |
Modifier and Type | Method and Description |
---|---|
double | accuracy(int outputNum): Get the accuracy for the specified output. |
double | averageAccuracy() |
double | averageF1() |
double | averageFalseAlarmRate(): Average False Alarm Rate (FAR) (see falseAlarmRate(int)) for all labels. |
double | averageGMeasure(): Average G-measure (see gMeasure(int)) for all labels. |
double | averageMatthewsCorrelation(): Macro average of the Matthews correlation coefficient (MCC) (see matthewsCorrelation(int)) for all labels. |
double | averagePrecision() |
double | averageRecall() |
void | eval(INDArray labels, INDArray networkPredictions) |
void | eval(INDArray labelsArr, INDArray predictionsArr, INDArray maskArr) |
void | eval(INDArray labels, INDArray networkPredictions, INDArray maskArray, List<? extends Serializable> recordMetaData) |
double | f1(int outputNum): Get the F1 score for the specified output. |
double | falseAlarmRate(int outputNum): False Alarm Rate (FAR) reflects the rate of misclassified to classified records; see http://ro.ecu.edu.au/cgi/viewcontent.cgi?article=1058&context=isw |
double | falseNegativeRate(Integer classLabel): Returns the false negative rate for a given label. |
double | falseNegativeRate(Integer classLabel, double edgeCase): Returns the false negative rate for a given label. |
int | falseNegatives(int outputNum): Get the false negatives count for the specified output. |
double | falsePositiveRate(int classLabel): Returns the false positive rate for a given label. |
double | falsePositiveRate(int classLabel, double edgeCase): Returns the false positive rate for a given label. |
int | falsePositives(int outputNum): Get the false positives count for the specified output. |
double | fBeta(double beta, int outputNum): Calculate the F-beta value for the given output. |
static EvaluationBinary | fromJson(String json) |
static EvaluationBinary | fromYaml(String yaml) |
int | getAxis(): Get the axis; see setAxis(int) for details. |
ROCBinary | getROCBinary(): Returns the ROCBinary instance, if present. |
double | getValue(IMetric metric): Get the value of a given metric for this evaluation. |
double | gMeasure(int output): Calculate the G-measure for the given output. |
double | matthewsCorrelation(int outputNum): Calculate the Matthews correlation coefficient for the specified output. |
void | merge(EvaluationBinary other): Merge the other evaluation object into this one. |
EvaluationBinary | newInstance(): Get a new instance of this evaluation, with the same configuration but no data. |
int | numLabels(): Returns the number of labels (i.e., the size of the prediction/labels arrays), if known. |
double | precision(int outputNum): Get the precision (tp / (tp + fp)) for the specified output. |
double | recall(int outputNum): Get the recall (tp / (tp + fn)) for the specified output. |
void | reset() |
double | scoreForMetric(EvaluationBinary.Metric metric, int outputNum): Calculate a specific metric (see EvaluationBinary.Metric) for a given label. |
void | setAxis(int axis): Set the axis for evaluation, i.e. the dimension along which the probabilities (and label classes) are present. For DL4J this can be left at the default (axis = 1). Set it as follows: 2D (OutputLayer), shape [minibatch, numClasses]: axis = 1; 3D RNNs/CNN1D (DL4J RnnOutputLayer), NCW format, shape [minibatch, numClasses, sequenceLength]: axis = 1; 3D RNNs/CNN1D (DL4J RnnOutputLayer), NWC format, shape [minibatch, sequenceLength, numClasses]: axis = 2; 4D CNN2D (DL4J CnnLossLayer), NCHW format, shape [minibatch, channels, height, width]: axis = 1; 4D CNN2D, NHWC format, shape [minibatch, height, width, channels]: axis = 3. |
void | setLabelNames(List<String> labels): Set the label names, for printing via stats(). |
String | stats(): Get a String representation of the EvaluationBinary class, using the default precision. |
String | stats(int printPrecision): Get a String representation of the EvaluationBinary class, using the specified precision. |
int | totalCount(int outputNum): Get the total number of values for the specified output (column), accounting for any masking. |
int | trueNegatives(int outputNum): Get the true negatives count for the specified output. |
int | truePositives(int outputNum): Get the true positives count for the specified output. |
Methods inherited from class BaseEvaluation: attempFromLegacyFromJson, eval, evalTimeSeries, evalTimeSeries, fromJson, fromYaml, reshapeAndExtractNotMasked, toJson, toString, toYaml
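The summary lists matthewsCorrelation(int) and gMeasure(int) without formulas. The sketch below computes both from one output's confusion counts. It uses the textbook MCC formula, (tp·tn − fp·fn) / sqrt((tp+fp)(tp+fn)(tn+fp)(tn+fn)), and takes the G-measure to be the geometric mean of precision and recall, which is one common definition. That DL4J uses exactly these definitions, and returns 0.0 when a denominator is zero, are assumptions. CorrelationSketch is a hypothetical class.

```java
// Hypothetical sketch using standard formulas; not the DL4J code.
class CorrelationSketch {
    // Matthews correlation coefficient from one output's confusion counts.
    // Assumption: returns 0.0 when the denominator is 0 (a common convention).
    static double matthewsCorrelation(long tp, long fp, long tn, long fn) {
        double numerator = (double) tp * tn - (double) fp * fn;
        double denominator = Math.sqrt((double) (tp + fp) * (tp + fn) * (tn + fp) * (tn + fn));
        return denominator == 0.0 ? 0.0 : numerator / denominator;
    }

    // G-measure taken as the geometric mean of precision and recall.
    static double gMeasure(long tp, long fp, long fn) {
        double precision = (tp + fp) == 0 ? 0.0 : (double) tp / (tp + fp);
        double recall = (tp + fn) == 0 ? 0.0 : (double) tp / (tp + fn);
        return Math.sqrt(precision * recall);
    }
}
```

The casts to double happen before multiplying, so the product of the four marginal counts cannot overflow a long for large evaluations.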
public static final int DEFAULT_PRECISION
public static final double DEFAULT_EDGE_VALUE
protected int axis
protected EvaluationBinary(int axis, ROCBinary rocBinary, List<String> labels, INDArray decisionThreshold)
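public EvaluationBinary(INDArray decisionThreshold)
Create an EvaluationBinary instance with an optional decision threshold array.
Parameters:
decisionThreshold - Decision threshold for each output; may be null. Should be a row vector with length equal to the number of outputs, with values in range 0 to 1. An array of 0.5 values is equivalent to the default (no manually specified decision threshold).

public EvaluationBinary(int size, Integer rocBinarySteps)
This constructor allows for ROC to be calculated in addition to the standard evaluation metrics, when the rocBinarySteps arg is non-null. See ROCBinary for more details.
Parameters:
size - Number of outputs
rocBinarySteps - Constructor arg for ROCBinary.ROCBinary(int)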
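The decisionThreshold contract above has three parts: null means the default, otherwise one value per output, each in the range 0 to 1. This plain-Java sketch models that contract with double[] in place of INDArray. DecisionThresholds, resolve and classify are hypothetical names. Treating a prediction as positive when it is strictly above its threshold is an assumption.

```java
import java.util.Arrays;

// Hypothetical sketch of the documented decisionThreshold contract:
// null means 0.5 for every output; otherwise one value in [0, 1] per output.
class DecisionThresholds {
    static final double DEFAULT_THRESHOLD = 0.5;

    static double[] resolve(double[] thresholds, int numOutputs) {
        if (thresholds == null) {
            double[] defaults = new double[numOutputs];
            Arrays.fill(defaults, DEFAULT_THRESHOLD);
            return defaults;
        }
        if (thresholds.length != numOutputs) {
            throw new IllegalArgumentException(
                "Expected " + numOutputs + " thresholds, got " + thresholds.length);
        }
        for (double t : thresholds) {
            if (t < 0.0 || t > 1.0) {
                throw new IllegalArgumentException("Threshold outside [0, 1]: " + t);
            }
        }
        return thresholds.clone();
    }

    // Binarize one row of network outputs, output by output.
    // Assumption: a prediction is positive when strictly above its threshold.
    static boolean[] classify(double[] probs, double[] thresholds) {
        boolean[] predicted = new boolean[probs.length];
        for (int o = 0; o < probs.length; o++) {
            predicted[o] = probs[o] > thresholds[o];
        }
        return predicted;
    }
}
```

For example, a prediction of 0.6 counts as positive for an output whose threshold is 0.5 but as negative for one whose threshold is 0.7. Raising the threshold on a single output is how you trade that output's recall for precision.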
public void setAxis(int axis)
Parameters:
axis - Axis to use for evaluation

public int getAxis()
See setAxis(int) for details.

public void eval(INDArray labels, INDArray networkPredictions)
Specified by: eval in interface IEvaluation<EvaluationBinary>
Overrides: eval in class BaseEvaluation<EvaluationBinary>
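The setAxis(int) rules in the method summary amount to a lookup from array layout to axis. This hypothetical helper encodes that table. It is not part of DL4J, and the Layout enum names are invented for illustration. It also shows that the number of binary outputs is the array's size along the chosen axis.

```java
// Hypothetical helper, not part of DL4J: encodes the axis table from the
// setAxis(int) documentation and reads the output count off an array shape.
class AxisForLayout {
    enum Layout { FEED_FORWARD_2D, RNN_NCW, RNN_NWC, CNN_NCHW, CNN_NHWC }

    static int axisFor(Layout layout) {
        switch (layout) {
            case FEED_FORWARD_2D: return 1; // [minibatch, numClasses]
            case RNN_NCW:         return 1; // [minibatch, numClasses, sequenceLength]
            case RNN_NWC:         return 2; // [minibatch, sequenceLength, numClasses]
            case CNN_NCHW:        return 1; // [minibatch, channels, height, width]
            case CNN_NHWC:        return 3; // [minibatch, height, width, channels]
            default: throw new IllegalArgumentException("Unknown layout: " + layout);
        }
    }

    // Number of binary outputs = size of the array along the evaluation axis
    static long numOutputs(long[] shape, Layout layout) {
        return shape[axisFor(layout)];
    }
}
```

A shape of [32, 100, 5] in NWC format (axis 2) and a shape of [32, 5, 100] in NCW format (axis 1) both describe 5 binary outputs.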
public void eval(INDArray labels, INDArray networkPredictions, INDArray maskArray, List<? extends Serializable> recordMetaData)
public void eval(INDArray labelsArr, INDArray predictionsArr, INDArray maskArr)
Specified by: eval in interface IEvaluation<EvaluationBinary>
Overrides: eval in class BaseEvaluation<EvaluationBinary>
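EvaluationBinary supports both per-example and per-output masking through the mask argument of eval. The sketch below shows how the two kinds of mask would restrict which entries are counted, in the sense reported by totalCount(int). The mask layouts are assumptions: a per-example mask has one column that applies to all outputs of that example, and a per-output mask has one column per output. MaskedCounts is a hypothetical class and is not the DL4J code.

```java
// Hypothetical sketch of how masking restricts what is counted; not the DL4J code.
// A mask entry of 0 means "ignore"; any other value means "count".
class MaskedCounts {
    // mask may be null (no masking), have one column (per-example mask,
    // applied to every output of that example), or have one column per output
    // (per-output mask). Returns the number of counted entries per output.
    static int[] totalCounts(int numExamples, int numOutputs, double[][] mask) {
        int[] totals = new int[numOutputs];
        for (int i = 0; i < numExamples; i++) {
            for (int o = 0; o < numOutputs; o++) {
                double m;
                if (mask == null) {
                    m = 1.0;
                } else if (mask[i].length == 1) {
                    m = mask[i][0];                 // per-example mask
                } else {
                    m = mask[i][o];                 // per-output mask
                }
                if (m != 0.0) {
                    totals[o]++;
                }
            }
        }
        return totals;
    }
}
```

A per-output mask lets each output have a different number of counted examples, for instance when some tasks in a multi-task dataset lack labels for some records.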
public void merge(EvaluationBinary other)
Merge the other evaluation object into this one. The result is that this EvaluationBinary instance contains the counts etc. from both.
Parameters:
other - EvaluationBinary object to merge into this one.

public void reset()
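merge(EvaluationBinary) is documented as leaving this instance with the counts from both evaluations. If those counts are the per-output {tp, fp, tn, fn} tallies, merging reduces to element-wise addition, as in this hypothetical sketch. MergeCounts is not a DL4J class, and the count layout is assumed.

```java
// Hypothetical sketch: merging per-output counts {tp, fp, tn, fn} by
// element-wise addition; not the DL4J implementation.
class MergeCounts {
    static int[][] merge(int[][] a, int[][] b) {
        if (a.length != b.length) {
            throw new IllegalArgumentException(
                "Cannot merge: " + a.length + " vs " + b.length + " outputs");
        }
        int[][] merged = new int[a.length][];
        for (int o = 0; o < a.length; o++) {
            merged[o] = new int[a[o].length];
            for (int k = 0; k < a[o].length; k++) {
                merged[o][k] = a[o][k] + b[o][k];
            }
        }
        return merged;
    }
}
```

Under this assumption, metrics computed after merging equal those of a single evaluation over all the data. Averaging per-shard metrics would not give that result, because summed counts weight each shard by its size.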
public int numLabels()
public void setLabelNames(List<String> labels)
Set the label names, for printing via stats().
public int totalCount(int outputNum)
public int truePositives(int outputNum)
public int trueNegatives(int outputNum)
public int falsePositives(int outputNum)
public int falseNegatives(int outputNum)
public double averageAccuracy()
public double accuracy(int outputNum)
public double averagePrecision()
public double precision(int outputNum)
public double averageRecall()
public double recall(int outputNum)
public double averageF1()
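The average*() methods above take no output index. The sketch below assumes they are unweighted (macro) means of the per-output values, which is how averageMatthewsCorrelation() is described. The sketch illustrates this for accuracy, where per-output accuracy is (tp + tn) / total. MacroAverage is a hypothetical class, and the macro-averaging behaviour of averageAccuracy() is an assumption.

```java
// Hypothetical sketch: per-output accuracy from {tp, fp, tn, fn}, and an
// unweighted (macro) mean across outputs. Assumption: averageAccuracy()
// behaves like this; the library source is not reproduced here.
class MacroAverage {
    // (tp + tn) / total; yields NaN if an output has no counted entries
    static double accuracy(int[] c) {
        int total = c[0] + c[1] + c[2] + c[3];
        return (double) (c[0] + c[2]) / total;
    }

    static double averageAccuracy(int[][] counts) {
        double sum = 0.0;
        for (int[] c : counts) {
            sum += accuracy(c);
        }
        return sum / counts.length;
    }
}
```

With one output at 14 correct out of 20 (0.7) and another at 10 out of 10 (1.0), the macro average is 0.85, even though the second output contributed fewer examples.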
public double fBeta(double beta, int outputNum)
Calculate the F-beta value for the given output.
Parameters:
beta - Beta value to use
outputNum - Output number

public double f1(int outputNum)
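fBeta(double, int) and f1(int) build on the precision (tp / (tp + fp)) and recall (tp / (tp + fn)) given in the method summary. This sketch applies the standard F-beta formula, (1 + β²)·P·R / (β²·P + R), to one output's counts. FBetaSketch is hypothetical. Returning 0.0 for 0/0 cases is an assumption, and the library may handle those differently.

```java
// Hypothetical sketch of the standard F-beta formula from one output's counts.
// Assumption: 0/0 cases return 0.0; the library may handle them differently.
class FBetaSketch {
    static double precision(long tp, long fp) {
        return (tp + fp) == 0 ? 0.0 : (double) tp / (tp + fp);   // tp / (tp + fp)
    }

    static double recall(long tp, long fn) {
        return (tp + fn) == 0 ? 0.0 : (double) tp / (tp + fn);   // tp / (tp + fn)
    }

    // F_beta = (1 + beta^2) * P * R / (beta^2 * P + R); beta = 1 gives F1
    static double fBeta(double beta, long tp, long fp, long fn) {
        double p = precision(tp, fp);
        double r = recall(tp, fn);
        double beta2 = beta * beta;
        double denominator = beta2 * p + r;
        return denominator == 0.0 ? 0.0 : (1 + beta2) * p * r / denominator;
    }
}
```

With tp = 6, fp = 2, fn = 4, precision is 0.75 and recall is 0.6. F1 is then 2/3. F2, which weights recall more heavily, is 0.625.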
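public double matthewsCorrelation(int outputNum)
Calculate the Matthews correlation coefficient for the specified output.
Parameters:
outputNum - Output number

public double averageMatthewsCorrelation()
Macro average of the Matthews correlation coefficient (MCC) (see matthewsCorrelation(int)) for all labels.

public double gMeasure(int output)
Calculate the G-measure for the given output.
Parameters:
output - The specified output

public double averageGMeasure()
Average G-measure (see gMeasure(int)) for all labels.

public double falsePositiveRate(int classLabel)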
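Returns the false positive rate for a given label.
Parameters:
classLabel - the label

public double falsePositiveRate(int classLabel, double edgeCase)
Returns the false positive rate for a given label.
Parameters:
classLabel - the label
edgeCase - What to output in case of 0/0

public double falseNegativeRate(Integer classLabel)
Returns the false negative rate for a given label.
Parameters:
classLabel - the label

public double falseNegativeRate(Integer classLabel, double edgeCase)
Returns the false negative rate for a given label.
Parameters:
classLabel - the label
edgeCase - What to output in case of 0/0

public double averageFalseAlarmRate()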
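Average False Alarm Rate (FAR) (see falseAlarmRate(int)) for all labels.

public double falseAlarmRate(int outputNum)
False Alarm Rate (FAR) reflects the rate of misclassified to classified records.
Parameters:
outputNum - Class index to calculate False Alarm Rate (FAR)

public String stats()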
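falsePositiveRate, falseNegativeRate and falseAlarmRate all derive from one output's counts. The edgeCase overloads let the caller choose the value returned when a denominator is zero. In the sketch below, FPR is fp / (fp + tn) and FNR is fn / (fn + tp), the standard definitions. FAR is computed as the mean of FPR and FNR, which is an assumption about how the library combines them. RateSketch is hypothetical. The 0.0 used for the 0/0 case inside falseAlarmRate is also an assumption, and the library's DEFAULT_EDGE_VALUE may differ.

```java
// Hypothetical sketch: false positive/negative rates with an explicit value
// for the 0/0 case (mirroring the edgeCase parameter), and FAR as their mean.
class RateSketch {
    static double falsePositiveRate(long fp, long tn, double edgeCase) {
        long negatives = fp + tn;                 // all actual negatives
        return negatives == 0 ? edgeCase : (double) fp / negatives;
    }

    static double falseNegativeRate(long fn, long tp, double edgeCase) {
        long positives = fn + tp;                 // all actual positives
        return positives == 0 ? edgeCase : (double) fn / positives;
    }

    // Assumption: FAR = (FPR + FNR) / 2, using 0.0 for any 0/0 rate
    static double falseAlarmRate(long tp, long fp, long tn, long fn) {
        return (falsePositiveRate(fp, tn, 0.0) + falseNegativeRate(fn, tp, 0.0)) / 2.0;
    }
}
```

With tp = 6, fp = 2, tn = 8, fn = 4, FPR is 0.2 and FNR is 0.4, giving a FAR of 0.3 under this assumption.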
public String stats(int printPrecision)
Get a String representation of the EvaluationBinary class, using the specified precision.
Parameters:
printPrecision - The precision (number of decimal places) for the accuracy, f1, etc.

public double scoreForMetric(EvaluationBinary.Metric metric, int outputNum)
Calculate a specific metric (see EvaluationBinary.Metric) for a given label.
Parameters:
metric - The Metric to calculate.
outputNum - Class index to calculate.

public static EvaluationBinary fromJson(String json)
public static EvaluationBinary fromYaml(String yaml)
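public double getValue(IMetric metric)
Description copied from interface: IEvaluation
Get the value of a given metric for this evaluation.

public EvaluationBinary newInstance()
Description copied from interface: IEvaluation
Get a new instance of this evaluation, with the same configuration but no data.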
Copyright © 2020. All rights reserved.