public class ROCMultiClass extends BaseEvaluation<ROCMultiClass>
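ROC (Receiver Operating Characteristic) evaluation for multi-class classifiers.
As with ROC, ROCMultiClass supports both exact (thresholdSteps == 0) and thresholded calculation; see ROC for details.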
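The exact/thresholded distinction is easiest to see in which thresholds get evaluated. The sketch below assumes that exact mode uses every distinct predicted probability as a threshold and that thresholded mode uses thresholdSteps + 1 evenly spaced values in [0, 1]. This page does not state either detail, so treat both as assumptions. ThresholdSketch and thresholds(...) are hypothetical names and are not part of the DL4J API.

```java
import java.util.Arrays;
import java.util.TreeSet;

/** Hypothetical illustration of exact vs thresholded ROC; not the DL4J implementation. */
final class ThresholdSketch {

    /** Thresholds evaluated for one binary (one-vs-all) problem. */
    static double[] thresholds(double[] scores, int thresholdSteps) {
        if (thresholdSteps == 0) {
            // Exact mode (assumed): every distinct predicted probability is a threshold
            TreeSet<Double> distinct = new TreeSet<>();
            for (double s : scores) distinct.add(s);
            return distinct.stream().mapToDouble(Double::doubleValue).toArray();
        }
        // Thresholded mode (assumed): thresholdSteps + 1 evenly spaced values in [0, 1]
        double[] t = new double[thresholdSteps + 1];
        for (int i = 0; i <= thresholdSteps; i++) {
            t[i] = i / (double) thresholdSteps;
        }
        return t;
    }

    public static void main(String[] args) {
        double[] scores = {0.9, 0.2, 0.65, 0.2};
        System.out.println(Arrays.toString(thresholds(scores, 0))); // [0.2, 0.65, 0.9]
        System.out.println(Arrays.toString(thresholds(scores, 4))); // [0.0, 0.25, 0.5, 0.75, 1.0]
    }
}
```

Exact mode grows with the number of distinct scores. Thresholded mode has a fixed cost per class, at the price of a coarser curve.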
The ROC curves are produced by treating the predictions as a set of one-vs-all classifiers, and then calculating ROC curves for each. In practice, this means for N classes, we get N ROC curves.
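The one-vs-all idea can be sketched in plain Java without ND4J. The sketch assumes one-hot labels and predicted probabilities laid out as [minibatch, numClasses] (the axis = 1 case). Column c is treated as a binary classifier for "class c vs. everything else". The exact AUC for that column is computed with the pairwise Mann-Whitney formulation, which gives the same value as trapezoidal integration of the exact ROC curve. OneVsAllAuc and its methods are hypothetical names, and this is not the DL4J implementation.

```java
import java.util.Arrays;

/** Hypothetical sketch: one exact AUC per class via one-vs-all (not the DL4J code). */
final class OneVsAllAuc {

    /**
     * Exact AUC for class c: the fraction of (positive, negative) pairs in which
     * the positive example gets the higher score for column c; ties count 1/2.
     */
    static double aucForClass(int[][] labels, double[][] predictions, int c) {
        long pairs = 0;
        double wins = 0.0;
        for (int i = 0; i < labels.length; i++) {
            if (labels[i][c] != 1) continue;          // i must be a positive for class c
            for (int j = 0; j < labels.length; j++) {
                if (labels[j][c] == 1) continue;      // j must be a negative for class c
                pairs++;
                if (predictions[i][c] > predictions[j][c]) wins += 1.0;
                else if (predictions[i][c] == predictions[j][c]) wins += 0.5;
            }
        }
        // Undefined when the class has no positives or no negatives
        return pairs == 0 ? Double.NaN : wins / pairs;
    }

    /** N classes -> N one-vs-all AUC values. */
    static double[] aucPerClass(int[][] labels, double[][] predictions) {
        int numClasses = labels[0].length;
        double[] out = new double[numClasses];
        for (int c = 0; c < numClasses; c++) {
            out[c] = aucForClass(labels, predictions, c);
        }
        return out;
    }

    public static void main(String[] args) {
        int[][] labels = {{1, 0, 0}, {0, 1, 0}, {0, 0, 1}, {1, 0, 0}};
        double[][] preds = {
            {0.7, 0.2, 0.1},
            {0.3, 0.5, 0.2},
            {0.2, 0.3, 0.5},
            {0.25, 0.6, 0.15}
        };
        System.out.println(Arrays.toString(aucPerClass(labels, preds)));
    }
}
```

For this toy data, the three classes yield three separate AUC values (0.75, 2/3 and 1.0), one per one-vs-all curve.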
Modifier and Type | Class and Description |
---|---|
static class | ROCMultiClass.Metric: AUROC (area under ROC curve), AUPRC (area under precision-recall curve) |
Modifier and Type | Field and Description |
---|---|
protected int | axis |
static int | DEFAULT_STATS_PRECISION |
Modifier | Constructor and Description |
---|---|
public | ROCMultiClass() |
public | ROCMultiClass(int thresholdSteps) |
public | ROCMultiClass(int thresholdSteps, boolean rocRemoveRedundantPts) |
protected | ROCMultiClass(int axis, int thresholdSteps, boolean rocRemoveRedundantPts, List<String> labels) |
Modifier and Type | Method and Description |
---|---|
double | calculateAUC(int classIdx): Calculate the AUC (area under the ROC curve) for the specified class. Uses trapezoidal integration internally. |
double | calculateAUCPR(int classIdx): Calculate the AUPRC (area under the precision-recall curve) for the specified class. Uses trapezoidal integration internally. |
double | calculateAverageAUC(): Calculate the macro-average (one-vs-all) AUC over all classes. |
double | calculateAverageAUCPR(): Calculate the macro-average (one-vs-all) AUPRC (area under the precision-recall curve) over all classes. |
void | eval(INDArray labels, INDArray predictions, INDArray mask, List<? extends Serializable> recordMetaData): Evaluate the network, with optional metadata. |
static ROCMultiClass | fromJson(String json) |
int | getAxis(): Get the axis; see setAxis(int) for details. |
long | getCountActualNegative(int outputNum): Get the actual negative count (accounting for any masking) for the specified output/column. |
long | getCountActualPositive(int outputNum): Get the actual positive count (accounting for any masking) for the specified class. |
int | getNumClasses() |
PrecisionRecallCurve | getPrecisionRecallCurve(int classIdx): Get the (one vs. all) precision-recall curve for the specified class. |
RocCurve | getRocCurve(int classIdx): Get the (one vs. all) ROC curve for the specified class. |
double | getValue(IMetric metric): Get the value of a given metric for this evaluation. |
void | merge(ROCMultiClass other): Merge this ROCMultiClass instance with another. |
ROCMultiClass | newInstance(): Get a new instance of this evaluation, with the same configuration but no data. |
void | reset() |
double | scoreForMetric(ROCMultiClass.Metric metric, int idx) |
void | setAxis(int axis): Set the axis for evaluation, i.e. the dimension along which the probabilities (and label classes) are present. For DL4J this can be left at the default (axis = 1). Set it as follows: 2D (OutputLayer), shape [minibatch, numClasses]: axis = 1; 3D RNN/CNN1D (DL4J RnnOutputLayer), NCW format, shape [minibatch, numClasses, sequenceLength]: axis = 1; 3D RNN/CNN1D (DL4J RnnOutputLayer), NWC format, shape [minibatch, sequenceLength, numClasses]: axis = 2; 4D CNN2D (DL4J CnnLossLayer), NCHW format, shape [minibatch, channels, height, width]: axis = 1; 4D CNN2D, NHWC format, shape [minibatch, height, width, channels]: axis = 3. |
String | stats() |
String | stats(int printPrecision) |
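The setAxis(int) rules are a small lookup from array rank and data layout to the class dimension. The helper below encodes exactly the cases listed for setAxis(int) and rejects everything else. AxisChooser and its Layout enum are hypothetical names and are not part of DL4J.

```java
/** Hypothetical helper mirroring the setAxis(int) guidance; not part of DL4J. */
final class AxisChooser {

    /** Data layouts named in the setAxis(int) documentation (NONE for plain 2D output). */
    enum Layout { NONE, NCW, NWC, NCHW, NHWC }

    /** Returns the axis holding the class probabilities for a given rank and layout. */
    static int classAxis(int rank, Layout layout) {
        switch (rank) {
            case 2: // [minibatch, numClasses]
                return 1;
            case 3: // RNN / CNN1D output
                if (layout == Layout.NCW) return 1; // [minibatch, numClasses, sequenceLength]
                if (layout == Layout.NWC) return 2; // [minibatch, sequenceLength, numClasses]
                break;
            case 4: // CNN2D output
                if (layout == Layout.NCHW) return 1; // [minibatch, channels, height, width]
                if (layout == Layout.NHWC) return 3; // [minibatch, height, width, channels]
                break;
            default:
                break;
        }
        throw new IllegalArgumentException("Unsupported rank/layout: " + rank + "/" + layout);
    }

    public static void main(String[] args) {
        System.out.println(classAxis(3, Layout.NWC));  // 2
        System.out.println(classAxis(4, Layout.NHWC)); // 3
    }
}
```

The returned value is what would be passed to setAxis(int) before calling eval; for the default 2D and channels-first cases it is simply 1.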
Methods inherited from BaseEvaluation: attempFromLegacyFromJson, eval, eval, eval, evalTimeSeries, evalTimeSeries, fromJson, fromYaml, reshapeAndExtractNotMasked, toJson, toString, toYaml
public static final int DEFAULT_STATS_PRECISION
protected int axis
protected ROCMultiClass(int axis, int thresholdSteps, boolean rocRemoveRedundantPts, List<String> labels)
public ROCMultiClass()
public ROCMultiClass(int thresholdSteps)
thresholdSteps - Number of threshold steps to use for the ROC calculation. Set to 0 for exact ROC calculation.
public ROCMultiClass(int thresholdSteps, boolean rocRemoveRedundantPts)
thresholdSteps - Number of threshold steps to use for the ROC calculation. If set to 0, exact calculation is used.
rocRemoveRedundantPts - Usually set to true. If true, remove any redundant points from the ROC and P-R curves.
public void setAxis(int axis)
axis - Axis to use for evaluation
public int getAxis()
See setAxis(int) for details.
public void reset()
public String stats()
public String stats(int printPrecision)
public void eval(INDArray labels, INDArray predictions, INDArray mask, List<? extends Serializable> recordMetaData)
labels - Data labels
predictions - Network predictions
recordMetaData - Optional; may be null. If not null, it should have size equal to the number of outcomes/guesses.
public RocCurve getRocCurve(int classIdx)
classIdx - Class index to get the ROC curve for
public PrecisionRecallCurve getPrecisionRecallCurve(int classIdx)
classIdx - Class to get the P-R curve for
public double calculateAUC(int classIdx)
public double calculateAUCPR(int classIdx)
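Both calculateAUC and calculateAUCPR are documented as using trapezoidal integration. The sketch below shows that technique for one binary (one-vs-all) problem. It builds exact ROC points (FPR, TPR), using every distinct score as a threshold with the rule "predict positive if score >= threshold", starts from (0, 0), and sums trapezoids. The same trapezoid helper applied to (recall, precision) points would give an AUPRC estimate, although this page does not say how DL4J orders or interpolates P-R points. TrapezoidAuc and its methods are hypothetical names.

```java
import java.util.Arrays;

/** Hypothetical sketch of trapezoidal AUC; not the DL4J source. */
final class TrapezoidAuc {

    /** Area under a piecewise-linear curve whose points are sorted by x. */
    static double trapezoid(double[] x, double[] y) {
        double area = 0.0;
        for (int i = 1; i < x.length; i++) {
            area += (x[i] - x[i - 1]) * (y[i] + y[i - 1]) / 2.0;
        }
        return area;
    }

    /** Exact ROC AUC; assumes at least one positive and one negative example. */
    static double rocAuc(boolean[] positive, double[] score) {
        double[] thr = Arrays.stream(score).distinct().sorted().toArray();
        int p = 0, n = 0;
        for (boolean b : positive) { if (b) p++; else n++; }
        int m = thr.length;
        double[] fpr = new double[m + 1]; // index 0 stays (0, 0)
        double[] tpr = new double[m + 1];
        // Walk thresholds from highest to lowest so FPR and TPR never decrease
        for (int k = 0; k < m; k++) {
            double t = thr[m - 1 - k];
            int tp = 0, fp = 0;
            for (int i = 0; i < score.length; i++) {
                if (score[i] >= t) { if (positive[i]) tp++; else fp++; }
            }
            fpr[k + 1] = fp / (double) n;
            tpr[k + 1] = tp / (double) p;
        }
        return trapezoid(fpr, tpr);
    }

    public static void main(String[] args) {
        boolean[] positive = {true, false, false, true};
        double[] score = {0.7, 0.3, 0.2, 0.25};
        System.out.println(rocAuc(positive, score)); // 0.75
    }
}
```

Tied scores cross a threshold together and produce a diagonal segment, so the trapezoid gives them half credit. This matches the pairwise definition of AUC.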
public double calculateAverageAUC()
public double calculateAverageAUCPR()
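calculateAverageAUC() and calculateAverageAUCPR() are described as macro-averages over the one-vs-all results. The minimal sketch below shows what macro-averaging means: an unweighted mean of the per-class values, so a rare class counts as much as a frequent one. This page does not say how a class with an undefined AUC (no positives or no negatives) is handled, so the sketch does not attempt to handle it. MacroAverage is a hypothetical name.

```java
/** Hypothetical sketch of macro-averaging per-class scores; not the DL4J source. */
final class MacroAverage {

    /** Unweighted mean of per-class (one-vs-all) values: every class has equal weight. */
    static double macroAverage(double[] perClass) {
        double sum = 0.0;
        for (double v : perClass) sum += v;
        return sum / perClass.length;
    }

    public static void main(String[] args) {
        double[] perClassAuc = {0.75, 2.0 / 3.0, 1.0};
        System.out.println(macroAverage(perClassAuc));
    }
}
```

A frequency-weighted average would instead let the most common classes dominate the result. The macro form surfaces poor performance on minority classes.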
public long getCountActualPositive(int outputNum)
outputNum - Index of the class
public long getCountActualNegative(int outputNum)
outputNum - Index of the class
public void merge(ROCMultiClass other)
other - ROCMultiClass instance to combine with this one
public int getNumClasses()
public static ROCMultiClass fromJson(String json)
public double scoreForMetric(ROCMultiClass.Metric metric, int idx)
public double getValue(IMetric metric)
Get the value of a given metric for this evaluation. Specified by interface IEvaluation.
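public ROCMultiClass newInstance()
Get a new instance of this evaluation, with the same configuration but no data. Specified by interface IEvaluation.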
Copyright © 2019. All rights reserved.