public class ROCBinary extends BaseEvaluation<ROCBinary>
As with ROC, ROCBinary supports both exact (thresholdSteps == 0) and thresholded calculation; see ROC for details.
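The difference between the two modes can be illustrated in plain Java, without ND4J. This is only a sketch under stated assumptions, not the library's implementation: it assumes that thresholded mode uses evenly spaced thresholds i / thresholdSteps for i = 0..thresholdSteps, and that exact mode uses every distinct predicted probability as a threshold. The class and method names (ThresholdSketch, fixedThresholds, exactThresholds, rocPoint) are hypothetical.

```java
import java.util.TreeSet;

final class ThresholdSketch {
    /** Assumed thresholded mode: evenly spaced thresholds 0, 1/steps, ..., 1 (steps > 0). */
    static double[] fixedThresholds(int steps) {
        double[] t = new double[steps + 1];
        for (int i = 0; i <= steps; i++) {
            t[i] = (double) i / steps;
        }
        return t;
    }

    /** Assumed exact mode: every distinct predicted probability, in ascending order. */
    static double[] exactThresholds(double[] probs) {
        TreeSet<Double> distinct = new TreeSet<>();
        for (double p : probs) {
            distinct.add(p);
        }
        double[] t = new double[distinct.size()];
        int i = 0;
        for (double p : distinct) {
            t[i++] = p;
        }
        return t;
    }

    /** Returns {falsePositiveRate, truePositiveRate}, predicting positive when prob >= threshold. */
    static double[] rocPoint(int[] labels, double[] probs, double threshold) {
        int tp = 0, fp = 0, pos = 0, neg = 0;
        for (int i = 0; i < labels.length; i++) {
            boolean predictedPositive = probs[i] >= threshold;
            if (labels[i] == 1) {
                pos++;
                if (predictedPositive) tp++;
            } else {
                neg++;
                if (predictedPositive) fp++;
            }
        }
        double fpr = neg == 0 ? 0.0 : (double) fp / neg;
        double tpr = pos == 0 ? 0.0 : (double) tp / pos;
        return new double[] {fpr, tpr};
    }
}
```

Under these assumptions, the fixed grid keeps the number of ROC points bounded regardless of data size, while the exact variant produces one point per distinct probability.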
Unlike ROC, which supports a single binary label (as a single-column probability, or a 2-column 'softmax' probability distribution), ROCBinary assumes that all outputs are independent binary variables. It also differs from ROCMultiClass, which should be used for multi-class (single non-binary) cases.
ROCBinary supports per-example and per-output masking: for per-output masking, any particular output may be absent (mask value 0) and hence won't be included in the calculated ROC.
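Per-output masking can be pictured with a small plain-Java sketch. This is an illustration under assumptions, not ROCBinary's implementation: labels and mask are modelled as int[example][output] arrays (the library itself works on INDArray), a mask value of 0 means "absent", and the class and method names are hypothetical. It counts actual positives and negatives for one output column while skipping masked entries, in the spirit of getCountActualPositive(int) and getCountActualNegative(int).

```java
final class MaskedCountsSketch {
    /**
     * Returns {positives, negatives} for one output column.
     * Entries whose mask value is 0 are skipped; a null mask means nothing is masked.
     */
    static long[] countsForOutput(int[][] labels, int[][] mask, int outputNum) {
        long positives = 0;
        long negatives = 0;
        for (int ex = 0; ex < labels.length; ex++) {
            if (mask != null && mask[ex][outputNum] == 0) {
                continue; // this output is absent for this example
            }
            if (labels[ex][outputNum] == 1) {
                positives++;
            } else {
                negatives++;
            }
        }
        return new long[] {positives, negatives};
    }
}
```

Because each column is handled on its own, masking one output for an example leaves that example's other outputs untouched.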
Modifier and Type | Class and Description |
---|---|
static class | ROCBinary.Metric: AUROC: Area under ROC curve; AUPRC: Area under Precision-Recall Curve |
Modifier and Type | Field and Description |
---|---|
protected int | axis |
static int | DEFAULT_STATS_PRECISION |
Modifier | Constructor and Description |
---|---|
public | ROCBinary() |
public | ROCBinary(int thresholdSteps) |
public | ROCBinary(int thresholdSteps, boolean rocRemoveRedundantPts) |
protected | ROCBinary(int axis, int thresholdSteps, boolean rocRemoveRedundantPts, List<String> labels) |
Modifier and Type | Method and Description |
---|---|
double | calculateAUC(int outputNum): Calculate the AUC (area under the ROC curve). Uses trapezoidal integration internally |
double | calculateAUCPR(int outputNum): Calculate the AUCPR (area under the precision-recall curve). Uses trapezoidal integration internally |
double | calculateAverageAuc(): Macro-average AUC for all outcomes |
double | calculateAverageAUCPR() |
void | eval(INDArray labels, INDArray predictions, INDArray mask, List<? extends Serializable> recordMetaData) |
static ROCBinary | fromJson(String json) |
int | getAxis(): Get the axis; see setAxis(int) for details |
long | getCountActualNegative(int outputNum): Get the actual negative count (accounting for any masking) for the specified output/column |
long | getCountActualPositive(int outputNum): Get the actual positive count (accounting for any masking) for the specified output/column |
PrecisionRecallCurve | getPrecisionRecallCurve(int outputNum): Get the Precision-Recall curve for the specified output |
ROC | getROC(int outputNum): Get the ROC object for the specified column |
RocCurve | getRocCurve(int outputNum): Get the ROC curve for the specified output |
double | getValue(IMetric metric): Get the value of a given metric for this evaluation. |
void | merge(ROCBinary other) |
ROCBinary | newInstance(): Get a new instance of this evaluation, with the same configuration but no data. |
int | numLabels(): Returns the number of labels (i.e., size of the prediction/labels arrays), if known. |
void | reset() |
double | scoreForMetric(ROCBinary.Metric metric, int idx) |
void | setAxis(int axis): Set the axis for evaluation. This is the dimension along which the probabilities (and the labels for the independent binary classes) are present. For DL4J, this can be left as the default setting (axis = 1). Axis should be set as follows: for 2D (OutputLayer), shape [minibatch, numClasses]: axis = 1; for 3D, RNNs/CNN1D (DL4J RnnOutputLayer), NCW format, shape [minibatch, numClasses, sequenceLength]: axis = 1; for 3D, RNNs/CNN1D (DL4J RnnOutputLayer), NWC format, shape [minibatch, sequenceLength, numClasses]: axis = 2; for 4D, CNN2D (DL4J CnnLossLayer), NCHW format, shape [minibatch, channels, height, width]: axis = 1; for 4D, CNN2D, NHWC format, shape [minibatch, height, width, channels]: axis = 3 |
void | setLabelNames(List<String> labels): Set the label names, for printing via stats() |
String | stats() |
String | stats(int printPrecision) |
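The axis settings listed for setAxis(int) amount to choosing which dimension holds the outputs. A minimal sketch in plain Java, with no ND4J dependency: given an array shape and an axis, the number of binary outputs is the size of that dimension, and the remaining dimensions are assumed to contribute independent examples (for instance, each time step of each sequence). The class and method names are hypothetical, and this illustrates the convention rather than the library's reshaping code.

```java
final class AxisSketch {
    /** Number of independent binary outputs: the size of the chosen axis. */
    static long numOutputs(long[] shape, int axis) {
        return shape[axis];
    }

    /** Number of examples: the product of all dimensions other than the chosen axis. */
    static long numExamples(long[] shape, int axis) {
        long n = 1;
        for (int d = 0; d < shape.length; d++) {
            if (d != axis) {
                n *= shape[d];
            }
        }
        return n;
    }
}
```

For example, an NCW array of shape [32, 5, 10] with axis = 1 and an NWC array of shape [32, 10, 5] with axis = 2 both describe 5 outputs over 320 examples.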
Methods inherited from class BaseEvaluation: attempFromLegacyFromJson, eval, eval, eval, evalTimeSeries, evalTimeSeries, fromJson, fromYaml, reshapeAndExtractNotMasked, toJson, toString, toYaml
public static final int DEFAULT_STATS_PRECISION
protected int axis
protected ROCBinary(int axis, int thresholdSteps, boolean rocRemoveRedundantPts, List<String> labels)
public ROCBinary()
public ROCBinary(int thresholdSteps)
thresholdSteps - Number of threshold steps to use for the ROC calculation. Set to 0 for exact ROC calculation
public ROCBinary(int thresholdSteps, boolean rocRemoveRedundantPts)
thresholdSteps - Number of threshold steps to use for the ROC calculation. If set to 0: use exact calculation
rocRemoveRedundantPts - Usually set to true. If true, remove any redundant points from ROC and P-R curves
public void setAxis(int axis)
axis - Axis to use for evaluation
public int getAxis()
Get the axis; see setAxis(int) for details
public void reset()
public void eval(INDArray labels, INDArray predictions, INDArray mask, List<? extends Serializable> recordMetaData)
public void merge(ROCBinary other)
public int numLabels()
public long getCountActualPositive(int outputNum)
outputNum - Index of the output (0 to numLabels()-1)
public long getCountActualNegative(int outputNum)
outputNum - Index of the output (0 to numLabels()-1)
public ROC getROC(int outputNum)
outputNum - Column (output number)
public RocCurve getRocCurve(int outputNum)
outputNum - Number of the output to get the ROC curve for
public PrecisionRecallCurve getPrecisionRecallCurve(int outputNum)
outputNum - Number of the output to get the P-R curve for
public double calculateAverageAuc()
public double calculateAverageAUCPR()
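Macro-averaging means an unweighted mean of the per-output values, so each output counts equally regardless of how many examples it has. A minimal sketch, assuming the per-output AUCs have already been computed (for example via calculateAUC(int) for each output); the class and method names are hypothetical and this is not the library's code.

```java
final class MacroAverageSketch {
    /** Unweighted mean of per-output AUC values. */
    static double macroAverage(double[] perOutputAuc) {
        if (perOutputAuc.length == 0) {
            throw new IllegalArgumentException("At least one output is required");
        }
        double sum = 0.0;
        for (double auc : perOutputAuc) {
            sum += auc;
        }
        return sum / perOutputAuc.length;
    }
}
```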
public double calculateAUC(int outputNum)
outputNum - Output number to calculate AUC for
public double calculateAUCPR(int outputNum)
outputNum - Output number to calculate AUCPR for
public void setLabelNames(List<String> labels)
Set the label names, for printing via stats()
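The trapezoidal integration mentioned for calculateAUC(int) and calculateAUCPR(int) can be sketched in plain Java. This is an illustration, not the library's code: it assumes the curve is supplied as x and y arrays sorted by ascending x (false positive rate and true positive rate for a ROC curve; recall and precision for a P-R curve), and the class and method names are hypothetical.

```java
final class TrapezoidSketch {
    /** Area under a piecewise-linear curve through (x[i], y[i]), with x ascending. */
    static double areaUnderCurve(double[] x, double[] y) {
        if (x.length != y.length) {
            throw new IllegalArgumentException("x and y must have the same length");
        }
        double area = 0.0;
        for (int i = 1; i < x.length; i++) {
            double width = x[i] - x[i - 1];
            double meanHeight = (y[i] + y[i - 1]) / 2.0;
            area += width * meanHeight; // area of one trapezoid
        }
        return area;
    }
}
```

With this sketch, the diagonal ROC curve through (0, 0) and (1, 1) has an area of 0.5, which corresponds to chance-level ranking.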
public String stats()
public String stats(int printPrecision)
public double scoreForMetric(ROCBinary.Metric metric, int idx)
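public double getValue(IMetric metric)
Get the value of a given metric for this evaluation (see IEvaluation).
public ROCBinary newInstance()
Get a new instance of this evaluation, with the same configuration but no data (see IEvaluation).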
Copyright © 2020. All rights reserved.