public class TestCase extends Object
Modifier and Type | Class and Description |
---|---|
static class | TestCase.TestSerialization |
Modifier and Type | Field and Description |
---|---|
static boolean | GC_DEFAULT_DEBUG_MODE |
static double | GC_DEFAULT_EPS |
static boolean | GC_DEFAULT_EXIT_FIRST_FAILURE |
static double | GC_DEFAULT_MAX_REL_ERROR |
static double | GC_DEFAULT_MIN_ABS_ERROR |
static boolean | GC_DEFAULT_PRINT |
Modifier and Type | Method and Description |
---|---|
void | assertConfigValid() |
TestCase | expected(SDVariable var, Function<INDArray,String> validationFn) |
TestCase | expected(@NonNull SDVariable var, @NonNull INDArray output): Validate the output (forward pass) for a single variable using INDArray.equals(INDArray). |
TestCase | expected(String name, Function<INDArray,String> validationFn) |
TestCase | expected(@NonNull String name, @NonNull INDArray output): Validate the output (forward pass) for a single variable using INDArray.equals(INDArray). |
TestCase | expectedOutput(@NonNull String name, @NonNull INDArray expected): Validate the output (forward pass) for a single variable using INDArray.equals(INDArray). |
TestCase | expectedOutputRelError(@NonNull String name, @NonNull INDArray expected, double maxRelError, double minAbsError): Validate the output (forward pass) for a single variable using element-wise relative error: relError = abs(x-y)/(abs(x)+abs(y)), with the x=y=0 case defined to be 0.0. |
Map<String,INDArray> | gradCheckMask() |
Set<String> | gradCheckSkipVariables() |
TestCase | gradCheckSkipVariables(String... toSkip): Specify the input variables that should NOT be gradient checked. |
TestCase | placeholderValue(String variable, INDArray value) |
TestCase | placeholderValues(Map<String,INDArray> placeholderValues) |
String | testNameErrMsg() |
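The element-wise relative-error rule summarized above for expectedOutputRelError can be sketched in plain Java. This is not ND4J code: `double[]` stands in for `INDArray`, `RelErrorSketch` and `failingIndices` are hypothetical names, and the way `maxRelError` and `minAbsError` are combined (an element fails only if both thresholds are exceeded) is an assumption based on the minAbsError parameter description.

```java
import java.util.ArrayList;
import java.util.List;

/**
 * Plain-Java sketch of the element-wise relative-error check described for
 * expectedOutputRelError. double[] stands in for INDArray; the class and
 * method names are hypothetical and not part of ND4J.
 */
public class RelErrorSketch {

    /** relError = |x - y| / (|x| + |y|), defined as 0.0 when x == y == 0. */
    static double relError(double x, double y) {
        if (x == 0.0 && y == 0.0) {
            return 0.0;
        }
        return Math.abs(x - y) / (Math.abs(x) + Math.abs(y));
    }

    /**
     * Returns the indices of failing elements. Assumption: an element fails only
     * when its relative error exceeds maxRelError AND its absolute error exceeds
     * minAbsError, so tiny absolute differences near zero are tolerated.
     */
    static List<Integer> failingIndices(double[] expected, double[] actual,
                                        double maxRelError, double minAbsError) {
        if (expected.length != actual.length) {
            throw new IllegalArgumentException(
                    "Length mismatch: " + expected.length + " vs " + actual.length);
        }
        List<Integer> failures = new ArrayList<>();
        for (int i = 0; i < expected.length; i++) {
            double absError = Math.abs(expected[i] - actual[i]);
            if (relError(expected[i], actual[i]) > maxRelError && absError > minAbsError) {
                failures.add(i);
            }
        }
        return failures;
    }

    public static void main(String[] args) {
        double[] expected = {1.0, 0.0, 1e-8, 2.0};
        double[] actual = {1.0001, 0.0, 2e-8, 2.5};
        System.out.println(failingIndices(expected, actual, 1e-3, 1e-6));
    }
}
```

In the demo, index 2 (1e-8 vs 2e-8) has a relative error of about 0.33 but an absolute error below `minAbsError`, so under the stated assumption only index 3 is reported.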
public static final boolean GC_DEFAULT_PRINT
public static final boolean GC_DEFAULT_EXIT_FIRST_FAILURE
public static final boolean GC_DEFAULT_DEBUG_MODE
public static final double GC_DEFAULT_EPS
public static final double GC_DEFAULT_MAX_REL_ERROR
public static final double GC_DEFAULT_MIN_ABS_ERROR
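The GC_DEFAULT_EPS, GC_DEFAULT_MAX_REL_ERROR and GC_DEFAULT_MIN_ABS_ERROR fields, together with gradCheckSkipVariables(String...), point to a numerical gradient check. This page does not document how TestCase performs that check or what the default values are, so the following is only a plain-Java sketch of a common technique, a central-difference gradient check, under these assumptions: `double[]` stands in for `INDArray`, a map of named arrays stands in for the SameDiff graph's inputs, `GradCheckSketch`, `gradCheck` and `runDemo` are hypothetical names, and an element fails only when both its relative error exceeds `maxRelError` and its absolute error exceeds `minAbsError`. The eps and tolerance values in the demo are illustrative, not the GC_DEFAULT_* values.

```java
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.function.Function;

public class GradCheckSketch {

    // Same element-wise relative error as described for expectedOutputRelError.
    static double relError(double x, double y) {
        if (x == 0.0 && y == 0.0) {
            return 0.0;
        }
        return Math.abs(x - y) / (Math.abs(x) + Math.abs(y));
    }

    // Compares central-difference gradients with analytic ones for every input
    // not listed in skip; returns one message per failing element.
    static List<String> gradCheck(Map<String, double[]> inputs,
                                  Function<Map<String, double[]>, Double> loss,
                                  Map<String, double[]> analyticGrads,
                                  Set<String> skip,
                                  double eps, double maxRelError, double minAbsError) {
        List<String> failures = new ArrayList<>();
        for (Map.Entry<String, double[]> entry : inputs.entrySet()) {
            String name = entry.getKey();
            if (skip.contains(name)) {
                continue; // skipped inputs are never perturbed or compared
            }
            double[] values = entry.getValue();
            double[] analytic = analyticGrads.get(name);
            for (int i = 0; i < values.length; i++) {
                double original = values[i];
                values[i] = original + eps;
                double lossPlus = loss.apply(inputs);
                values[i] = original - eps;
                double lossMinus = loss.apply(inputs);
                values[i] = original; // restore before moving on
                double numeric = (lossPlus - lossMinus) / (2 * eps);
                double absError = Math.abs(numeric - analytic[i]);
                if (relError(numeric, analytic[i]) > maxRelError && absError > minAbsError) {
                    failures.add(name + "[" + i + "]: numeric=" + numeric + ", analytic=" + analytic[i]);
                }
            }
        }
        return failures;
    }

    // Demo: loss = sum(a[i] * b[i]), so dL/da = b and dL/db = a. "idx" has no
    // gradient (think integer indices) and must be skipped.
    static List<String> runDemo(boolean correctGradients) {
        Map<String, double[]> inputs = new HashMap<>();
        inputs.put("a", new double[]{1.0, 2.0, 3.0});
        inputs.put("b", new double[]{4.0, 5.0, 6.0});
        inputs.put("idx", new double[]{0.0, 2.0});

        Function<Map<String, double[]>, Double> loss = m -> {
            double[] a = m.get("a");
            double[] b = m.get("b");
            double sum = 0.0;
            for (int i = 0; i < a.length; i++) {
                sum += a[i] * b[i];
            }
            return sum;
        };

        Map<String, double[]> grads = new HashMap<>();
        grads.put("a", inputs.get("b").clone());
        grads.put("b", correctGradients ? inputs.get("a").clone() : new double[3]);

        // Illustrative tolerances only; not the GC_DEFAULT_* values.
        return gradCheck(inputs, loss, grads, Collections.singleton("idx"), 1e-6, 1e-5, 1e-8);
    }

    public static void main(String[] args) {
        System.out.println("correct gradients: " + runDemo(true));
        System.out.println("wrong gradients:   " + runDemo(false));
    }
}
```

The skip set is consulted before any perturbation, so an input with no defined gradient (the demo's `idx`) never needs an analytic-gradient entry.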
public TestCase(SameDiff sameDiff)
sameDiff
- SameDiff instance to test. Note: All of the required inputs should already be set.

public TestCase expectedOutput(@NonNull String name, @NonNull INDArray expected)
name
- Name of the variable to check
expected
- Expected INDArray

public TestCase expectedOutputRelError(@NonNull String name, @NonNull INDArray expected, double maxRelError, double minAbsError)
name
- Name of the variable to check
expected
- Expected INDArray
maxRelError
- Maximum allowable relative error
minAbsError
- Minimum absolute error for a failure to be considered legitimate

public TestCase expected(@NonNull SDVariable var, @NonNull INDArray output)
var
- Variable to check
output
- Expected INDArray

public TestCase expected(@NonNull String name, @NonNull INDArray output)
name
- Name of the variable to check
output
- Expected INDArray

public TestCase expected(SDVariable var, Function<INDArray,String> validationFn)
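Both validationFn overloads take a function that, per the parameter description of expected(String, Function<INDArray,String>), returns null when validation passes and an error message otherwise. The following plain-Java sketch illustrates that contract with a hypothetical softmax-style check; `Function<double[], String>` stands in for `Function<INDArray,String>`, and `ValidationFnSketch`, `softmaxLike` and `check` are illustrative names, not ND4J API.

```java
import java.util.function.Function;

/**
 * Plain-Java sketch of the validationFn contract: return null when validation
 * passes, or an error message when it fails. double[] stands in for INDArray;
 * all names here are hypothetical, not ND4J API.
 */
public class ValidationFnSketch {

    /** Example check suited to a softmax-like output: values in [0, 1] summing to 1. */
    static Function<double[], String> softmaxLike(double tolerance) {
        return out -> {
            double sum = 0.0;
            for (int i = 0; i < out.length; i++) {
                if (out[i] < 0.0 || out[i] > 1.0) {
                    return "Value out of [0,1] at index " + i + ": " + out[i];
                }
                sum += out[i];
            }
            if (Math.abs(sum - 1.0) > tolerance) {
                return "Values sum to " + sum + ", expected 1.0";
            }
            return null; // null signals that validation passed
        };
    }

    /** Applies a validation function and turns a non-null message into a failure. */
    static void check(String varName, double[] output, Function<double[], String> validationFn) {
        String err = validationFn.apply(output);
        if (err != null) {
            throw new IllegalStateException("Validation failed for \"" + varName + "\": " + err);
        }
    }

    public static void main(String[] args) {
        Function<double[], String> fn = softmaxLike(1e-6);
        check("out", new double[]{0.25, 0.25, 0.5}, fn); // passes silently
        System.out.println(fn.apply(new double[]{0.5, 0.6}));
    }
}
```

Returning a message rather than a boolean lets the caller report why the op output is wrong, which is the reason the contract uses `String` with null as the success value.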
public TestCase expected(String name, Function<INDArray,String> validationFn)
name
- The name of the variable to check
validationFn
- Function to use to validate the correctness of the specific op. Should return null if validation passes, or an error message if the op validation fails.

public TestCase gradCheckSkipVariables(String... toSkip)
toSkip
- Names of the input variables to skip the gradient check for

public void assertConfigValid()

public String testNameErrMsg()
Copyright © 2021. All rights reserved.