public class ROC extends BaseEvaluation<ROC>
Modifier and Type | Class and Description |
---|---|
static class | ROC.CountsForThreshold |
static class | ROC.Metric - AUROC: Area under ROC curve; AUPRC: Area under Precision-Recall Curve |
Modifier and Type | Field and Description |
---|---|
protected int | axis |
Constructor and Description |
---|
ROC() |
ROC(int thresholdSteps) |
ROC(int thresholdSteps, boolean rocRemoveRedundantPts) |
ROC(int thresholdSteps, boolean rocRemoveRedundantPts, int exactAllocBlockSize) |
ROC(int thresholdSteps, boolean rocRemoveRedundantPts, int exactAllocBlockSize, int axis) |
Modifier and Type | Method and Description |
---|---|
double | calculateAUC() - Calculate the AUROC (Area Under ROC Curve). Utilizes trapezoidal integration internally |
double | calculateAUCPR() - Calculate the area under the precision/recall curve, aka AUCPR |
void | eval(INDArray labels, INDArray predictions, INDArray mask, List<? extends Serializable> recordMetaData) - Evaluate (collect statistics for) the given minibatch of data. |
static ROC | fromJson(String json) |
int | getAxis() - Get the axis; see setAxis(int) for details |
PrecisionRecallCurve | getPrecisionRecallCurve() - Get the precision recall curve as array. |
protected INDArray | getProbAndLabelUsed() |
RocCurve | getRocCurve() - Get the ROC curve, as a set of (threshold, falsePositive, truePositive) points |
double | getValue(IMetric metric) - Get the value of a given metric for this evaluation. |
void | merge(ROC other) - Merge this ROC instance with another. |
ROC | newInstance() - Get a new instance of this evaluation, with the same configuration but no data. |
void | reset() |
double | scoreForMetric(ROC.Metric metric) |
void | setAxis(int axis) - Set the axis for evaluation; this should be a size 1 dimension. For DL4J, this can be left as the default setting (axis = 1). Axis should be set as follows: 2D (OutputLayer), shape [minibatch, numClasses]: axis = 1; 3D, RNNs/CNN1D (DL4J RnnOutputLayer), NCW format, shape [minibatch, numClasses, sequenceLength]: axis = 1; 3D, RNNs/CNN1D (DL4J RnnOutputLayer), NWC format, shape [minibatch, sequenceLength, numClasses]: axis = 2; 4D, CNN2D (DL4J CnnLossLayer), NCHW format, shape [minibatch, channels, height, width]: axis = 1; 4D, CNN2D, NHWC format, shape [minibatch, height, width, channels]: axis = 3 |
String | stats() |
String | toString() |
Methods inherited from class BaseEvaluation: attempFromLegacyFromJson, eval, eval, eval, evalTimeSeries, evalTimeSeries, fromJson, fromYaml, reshapeAndExtractNotMasked, toJson, toYaml
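The axis rules given for setAxis(int) can be summarised as a small lookup. The sketch below is plain Java and does not call the library; the `Layout` enum and the methods `classAxis` and `hasSizeOneClassAxis` are hypothetical names introduced only for illustration.

```java
/** Illustrative only: not part of the ROC API. */
final class RocAxisSketch {

    /** Hypothetical layout names, following the formats listed for setAxis(int). */
    enum Layout { NC, NCW, NWC, NCHW, NHWC }

    /** Returns the axis that holds the class/channel dimension for a layout. */
    static int classAxis(Layout layout) {
        switch (layout) {
            case NC:   // 2D: [minibatch, numClasses]
            case NCW:  // 3D: [minibatch, numClasses, sequenceLength]
            case NCHW: // 4D: [minibatch, channels, height, width]
                return 1;
            case NWC:  // 3D: [minibatch, sequenceLength, numClasses]
                return 2;
            case NHWC: // 4D: [minibatch, height, width, channels]
                return 3;
            default:
                throw new IllegalArgumentException("Unknown layout: " + layout);
        }
    }

    /** Checks the "size 1 dimension" condition stated for setAxis(int). */
    static boolean hasSizeOneClassAxis(long[] shape, Layout layout) {
        int axis = classAxis(layout);
        return axis < shape.length && shape[axis] == 1;
    }

    public static void main(String[] args) {
        long[] nwcShape = {32, 100, 1}; // [minibatch, sequenceLength, numClasses]
        System.out.println(classAxis(Layout.NWC));
        System.out.println(hasSizeOneClassAxis(nwcShape, Layout.NWC));
    }
}
```

With an actual ROC instance, the value returned by such a lookup would be the argument passed to setAxis(int) or to the four-argument constructor.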
public ROC(int thresholdSteps, boolean rocRemoveRedundantPts, int exactAllocBlockSize, int axis)
public ROC()
public ROC(int thresholdSteps)
thresholdSteps
- Number of threshold steps to use for the ROC calculation. If set to 0: use exact calculation
public ROC(int thresholdSteps, boolean rocRemoveRedundantPts)
thresholdSteps
- Number of threshold steps to use for the ROC calculation. If set to 0: use exact calculation
rocRemoveRedundantPts
- Usually set to true. If true, remove any redundant points from ROC and P-R curves
public ROC(int thresholdSteps, boolean rocRemoveRedundantPts, int exactAllocBlockSize)
thresholdSteps
- Number of threshold steps to use for the ROC calculation. If set to 0: use exact calculation
rocRemoveRedundantPts
- Usually set to true. If true, remove any redundant points from ROC and P-R curves
exactAllocBlockSize
- If using exact mode, the block size for reallocation. Users can likely use the default setting in almost all cases
public void setAxis(int axis)
axis
- Axis to use for evaluation
public int getAxis()
Get the axis - see setAxis(int) for details
public double calculateAUC()
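The method summary notes that calculateAUC() uses trapezoidal integration. As a rough illustration of that technique only (not the library's code), the sketch below integrates a curve given as false positive rate / true positive rate points sorted by ascending false positive rate. The class and method names are hypothetical.

```java
/** Illustrative only: trapezoidal integration over curve points. */
final class TrapezoidAucSketch {

    /** Area under the curve through (x[i], y[i]), with x sorted ascending. */
    static double trapezoidArea(double[] x, double[] y) {
        if (x.length != y.length) {
            throw new IllegalArgumentException("x and y must have the same length");
        }
        double area = 0.0;
        for (int i = 1; i < x.length; i++) {
            // Width of the strip times the mean of its two heights
            area += (x[i] - x[i - 1]) * (y[i] + y[i - 1]) / 2.0;
        }
        return area;
    }

    public static void main(String[] args) {
        double[] fpr = {0.0, 0.0, 0.5, 1.0};
        double[] tpr = {0.0, 0.5, 1.0, 1.0};
        System.out.println(trapezoidArea(fpr, tpr)); // 0.875
    }
}
```

A diagonal curve from (0, 0) to (1, 1) gives 0.5 under the same rule, which matches the usual reading of 0.5 as chance-level AUROC.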
public RocCurve getRocCurve()
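The method summary describes the ROC curve as a set of (threshold, falsePositive, truePositive) points, and the constructors take a thresholdSteps count. The sketch below shows one way such points could be computed in plain Java. It assumes evenly spaced thresholds i / thresholdSteps and a "predict positive when probability >= threshold" rule; both are assumptions for illustration, and the names are hypothetical rather than taken from the library.

```java
/** Illustrative only: thresholded ROC points from labels and probabilities. */
final class ThresholdedRocSketch {

    /** Returns rows of {threshold, falsePositiveRate, truePositiveRate}. */
    static double[][] rocPoints(int[] labels, double[] probs, int thresholdSteps) {
        if (thresholdSteps <= 0 || labels.length != probs.length) {
            throw new IllegalArgumentException("need thresholdSteps > 0 and equal-length inputs");
        }
        long positives = 0;
        for (int label : labels) {
            if (label == 1) positives++;
        }
        long negatives = labels.length - positives;

        double[][] points = new double[thresholdSteps + 1][];
        for (int step = 0; step <= thresholdSteps; step++) {
            double threshold = (double) step / thresholdSteps; // assumed spacing
            long tp = 0;
            long fp = 0;
            for (int i = 0; i < labels.length; i++) {
                if (probs[i] >= threshold) { // assumed comparison rule
                    if (labels[i] == 1) tp++; else fp++;
                }
            }
            double fpr = negatives == 0 ? 0.0 : (double) fp / negatives;
            double tpr = positives == 0 ? 0.0 : (double) tp / positives;
            points[step] = new double[] {threshold, fpr, tpr};
        }
        return points;
    }

    public static void main(String[] args) {
        int[] labels = {1, 0, 1, 0};
        double[] probs = {0.9, 0.1, 0.6, 0.4};
        for (double[] p : rocPoints(labels, probs, 2)) {
            System.out.println(p[0] + " " + p[1] + " " + p[2]);
        }
    }
}
```

More threshold steps give a finer curve at the cost of more counters; the exact mode selected by thresholdSteps = 0 is not covered by this sketch.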
protected INDArray getProbAndLabelUsed()
public double calculateAUCPR()
public PrecisionRecallCurve getPrecisionRecallCurve()
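For the precision/recall curve, each point comes from the standard definitions precision = TP / (TP + FP) and recall = TP / (TP + FN) at one threshold. The plain-Java sketch below computes a single such point. The names are hypothetical, and the conventions for empty denominators (precision 1.0 when nothing is predicted positive, recall 0.0 when there are no positive labels) are assumptions, not documented library behaviour.

```java
/** Illustrative only: precision and recall at a single threshold. */
final class PrecisionRecallSketch {

    /** Returns {precision, recall}, predicting positive when prob >= threshold. */
    static double[] precisionRecall(int[] labels, double[] probs, double threshold) {
        if (labels.length != probs.length) {
            throw new IllegalArgumentException("labels and probs must have the same length");
        }
        long tp = 0;
        long fp = 0;
        long fn = 0;
        for (int i = 0; i < labels.length; i++) {
            boolean predictedPositive = probs[i] >= threshold;
            boolean actualPositive = labels[i] == 1;
            if (predictedPositive && actualPositive) {
                tp++;
            } else if (predictedPositive) {
                fp++;
            } else if (actualPositive) {
                fn++;
            }
        }
        // Assumed conventions for the undefined 0/0 cases
        double precision = (tp + fp == 0) ? 1.0 : (double) tp / (tp + fp);
        double recall = (tp + fn == 0) ? 0.0 : (double) tp / (tp + fn);
        return new double[] {precision, recall};
    }

    public static void main(String[] args) {
        int[] labels = {1, 0, 1, 0};
        double[] probs = {0.9, 0.1, 0.6, 0.4};
        double[] pr = precisionRecall(labels, probs, 0.7);
        System.out.println(pr[0] + " " + pr[1]); // 1.0 0.5
    }
}
```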
public void eval(INDArray labels, INDArray predictions, INDArray mask, List<? extends Serializable> recordMetaData)
Evaluate (collect statistics for) the given minibatch of data. For time series, see BaseEvaluation.evalTimeSeries(INDArray, INDArray) or BaseEvaluation.evalTimeSeries(INDArray, INDArray, INDArray)
labels
- Labels / true outcomes
predictions
- Predictions
public void merge(ROC other)
Merge this ROC instance with another.
other
- ROC instance to combine with this one
public void reset()
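To give an intuition for what merge(ROC) could involve when statistics are kept as per-threshold counters, the sketch below adds one set of counts into another. This is an assumption-laden illustration: the `Counts` type and its fields are hypothetical, and the document does not state how the library combines instances internally.

```java
/** Illustrative only: merging per-threshold counters by summation. */
final class MergeCountsSketch {

    /** Hypothetical counter for one threshold; not the library's type. */
    static final class Counts {
        final double threshold;
        long truePositives;
        long falsePositives;

        Counts(double threshold, long truePositives, long falsePositives) {
            this.threshold = threshold;
            this.truePositives = truePositives;
            this.falsePositives = falsePositives;
        }
    }

    /** Adds the counts of other into target; both must share the same thresholds. */
    static void mergeInto(Counts[] target, Counts[] other) {
        if (target.length != other.length) {
            throw new IllegalArgumentException("different number of thresholds");
        }
        for (int i = 0; i < target.length; i++) {
            if (Double.compare(target[i].threshold, other[i].threshold) != 0) {
                throw new IllegalArgumentException("threshold mismatch at index " + i);
            }
            target[i].truePositives += other[i].truePositives;
            target[i].falsePositives += other[i].falsePositives;
        }
    }

    public static void main(String[] args) {
        Counts[] a = {new Counts(0.5, 3, 1)};
        Counts[] b = {new Counts(0.5, 2, 4)};
        mergeInto(a, b);
        System.out.println(a[0].truePositives + " " + a[0].falsePositives); // 5 5
    }
}
```

Summation only makes sense when both sides were configured with the same thresholds, which is why the sketch rejects mismatched inputs.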
public String stats()
public String toString()
Overrides: toString in class BaseEvaluation<ROC>
public double scoreForMetric(ROC.Metric metric)
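public double getValue(IMetric metric)
Description copied from interface: IEvaluation
Get the value of a given metric for this evaluation.
public ROC newInstance()
Description copied from interface: IEvaluation
Get a new instance of this evaluation, with the same configuration but no data.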
Copyright © 2021. All rights reserved.