public abstract class BaseReductionBp extends DynamicCustomOp
Nested classes inherited from class DynamicCustomOp: DynamicCustomOp.DynamicCustomOpsBuilder
Modifier and Type | Field and Description |
---|---|
protected int[] | dimensions |
protected boolean | keepDims |
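The keepDims field records whether the forward reduction kept its reduced dimensions. As a minimal plain-Java sketch (the helper ReductionShapes.reducedShape is hypothetical and not part of ND4J), the forward output shape with and without keepDims can be computed as follows. It assumes non-negative dimension indices and, as an illustrative assumption only, treats null or empty dimensions as a reduction over every dimension.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

// Hypothetical helper, not part of ND4J: computes the shape a reduction
// produces, to show what the keepDims flag records.
final class ReductionShapes {
    private ReductionShapes() {}

    static int[] reducedShape(int[] inputShape, boolean keepDims, int... dimensions) {
        // Assumption for this sketch: null or empty dimensions reduce over every dimension.
        boolean reduceAll = dimensions == null || dimensions.length == 0;
        List<Integer> out = new ArrayList<>();
        for (int d = 0; d < inputShape.length; d++) {
            boolean reduced = reduceAll || contains(dimensions, d);
            if (!reduced) {
                out.add(inputShape[d]);   // untouched dimension keeps its size
            } else if (keepDims) {
                out.add(1);               // reduced dimension kept with size 1
            }                             // otherwise the reduced dimension is dropped
        }
        int[] result = new int[out.size()];
        for (int i = 0; i < result.length; i++) {
            result[i] = out.get(i);
        }
        return result;
    }

    private static boolean contains(int[] dims, int d) {
        for (int x : dims) {
            if (x == d) {
                return true;
            }
        }
        return false;
    }

    public static void main(String[] args) {
        int[] shape = {2, 3, 4};
        System.out.println(Arrays.toString(reducedShape(shape, true, 1)));  // [2, 1, 4]
        System.out.println(Arrays.toString(reducedShape(shape, false, 1))); // [2, 4]
    }
}
```

A backprop op needs this flag because the gradient at the output has the reduced shape, and it must be mapped back onto the full input shape.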
Fields inherited from class DynamicCustomOp: axis, bArguments, dArguments, iArguments, inplaceCall, inputArguments, outputArguments, outputVariables, tArguments
Fields inherited from class DifferentialFunction: extraArgs, inPlace, ownName, ownNameSetWithDefault, sameDiff, scalarValue
Constructor and Description |
---|
BaseReductionBp(INDArray origInput, INDArray gradAtOutput, INDArray output, boolean keepDims, int... dimensions) |
BaseReductionBp(INDArray origInput1, INDArray origInput2, INDArray gradAtOutput, INDArray output, boolean keepDims, int... dimensions) |
BaseReductionBp(INDArray origInput1, INDArray origInput2, INDArray gradAtOutput, INDArray output1, INDArray output2, boolean keepDims, int... dimensions) |
BaseReductionBp(SameDiff sameDiff, SDVariable origInput, SDVariable gradAtOutput, boolean keepDims, int... dimensions) |
BaseReductionBp(SameDiff sameDiff, SDVariable origInput1, SDVariable origInput2, SDVariable gradAtOutput, boolean keepDims, int... dimensions) |
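The INDArray constructors take the gradient at the reduction's output (gradAtOutput) and write the gradient at the reduction's input into output (or output1 and output2 for two-input reductions). The plain-Java sketch below does not use ND4J, and its class and method names are hypothetical. It illustrates the arithmetic for a 2D input reduced along dimension 1 with keepDims = false: a sum reduction broadcasts the output gradient back over the reduced dimension, and a dot-product style two-input reduction yields one gradient per input.

```java
import java.util.Arrays;

// Plain-Java illustration (not ND4J code) of what a reduction backprop op computes
// for a 2D input reduced along dimension 1 with keepDims = false.
final class ReductionBpSketch {
    private ReductionBpSketch() {}

    // Forward: out[i] = sum_j x[i][j]. Backward: dL/dx[i][j] = gradAtOutput[i],
    // i.e. the output gradient is broadcast back over the reduced dimension.
    static double[][] sumBp(double[][] origInput, double[] gradAtOutput) {
        double[][] gradAtInput = new double[origInput.length][];
        for (int i = 0; i < origInput.length; i++) {
            gradAtInput[i] = new double[origInput[i].length];
            Arrays.fill(gradAtInput[i], gradAtOutput[i]);
        }
        return gradAtInput;
    }

    // Two-input reduction, forward: out[i] = sum_j x[i][j] * y[i][j].
    // Backward produces one gradient per input:
    // dL/dx[i][j] = g[i] * y[i][j] and dL/dy[i][j] = g[i] * x[i][j].
    static double[][][] dotBp(double[][] x, double[][] y, double[] gradAtOutput) {
        double[][] gx = new double[x.length][];
        double[][] gy = new double[y.length][];
        for (int i = 0; i < x.length; i++) {
            gx[i] = new double[x[i].length];
            gy[i] = new double[y[i].length];
            for (int j = 0; j < x[i].length; j++) {
                gx[i][j] = gradAtOutput[i] * y[i][j];
                gy[i][j] = gradAtOutput[i] * x[i][j];
            }
        }
        return new double[][][] {gx, gy};
    }

    public static void main(String[] args) {
        double[][] x = {{1, 2, 3}, {4, 5, 6}};
        double[][] y = {{1, 0, 2}, {0, 1, 1}};
        double[] g = {10, 20};
        System.out.println(Arrays.deepToString(sumBp(x, g)));
        System.out.println(Arrays.deepToString(dotBp(x, y, g)));
    }
}
```

With keepDims = true the output gradient would have shape [rows, 1] instead of [rows]; the values are the same and only the shape bookkeeping differs.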
Modifier and Type | Method and Description |
---|---|
protected void | addArgs() |
List<DataType> | calculateOutputDataTypes(List<DataType> dataTypes) Calculate the data types for the output arrays. |
abstract String | opName() This method returns the op name as a string. |
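calculateOutputDataTypes works on data types alone, so types can be propagated through a graph before any array exists. The sketch below uses hypothetical stand-ins (DType, TypedOp, GreedyTypeInference) rather than ND4J's DataType and op classes. It assumes, purely for illustration, that a reduction backprop returns the data type of the original input, and it simplifies the graph to a linear chain of ops.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;

// Hypothetical stand-ins, not the ND4J types.
enum DType { FLOAT, DOUBLE, INT }

interface TypedOp {
    // Maps input data types to output data types without touching any arrays.
    List<DType> calculateOutputDataTypes(List<DType> inputTypes);
}

final class GreedyTypeInference {
    private GreedyTypeInference() {}

    // Assumed rule: inputs are (original input, gradient at output); the single
    // output (gradient at the input) takes the original input's data type.
    static final TypedOp REDUCTION_BP = in -> Collections.singletonList(in.get(0));

    // A second illustrative op that casts every input to DOUBLE.
    static final TypedOp CAST_TO_DOUBLE = in -> {
        List<DType> out = new ArrayList<>();
        for (int i = 0; i < in.size(); i++) {
            out.add(DType.DOUBLE);
        }
        return out;
    };

    // Feeds each op's output types into the next op. No array data is needed.
    static List<DType> inferChain(List<DType> inputTypes, List<TypedOp> chain) {
        List<DType> current = new ArrayList<>(inputTypes);
        for (TypedOp op : chain) {
            current = op.calculateOutputDataTypes(current);
        }
        return current;
    }

    public static void main(String[] args) {
        List<DType> types = inferChain(
                Arrays.asList(DType.FLOAT, DType.FLOAT),
                Arrays.asList(REDUCTION_BP, CAST_TO_DOUBLE));
        System.out.println(types); // [DOUBLE]
    }
}
```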
Methods inherited from class DynamicCustomOp: addBArgument, addDArgument, addIArgument, addIArgument, addInputArgument, addOutputArgument, addTArgument, assertValidForExecution, bArgs, builder, calculateOutputShape, calculateOutputShape, clearArrays, dArgs, doDiff, getBArgument, getDescriptor, getIArgument, getInputArgument, getOutputArgument, getTArgument, iArgs, initFromOnnx, initFromTensorFlow, inputArguments, numBArguments, numDArguments, numIArguments, numInputArguments, numOutputArguments, numTArguments, onnxName, opHash, opNum, opType, outputArguments, outputVariables, outputVariables, removeIArgument, removeInputArgument, removeOutputArgument, removeTArgument, setInputArgument, setInputArguments, setOutputArgument, tArgs, tensorflowName, toString, wrapFilterNull, wrapOrNull, wrapOrNull
Methods inherited from class DifferentialFunction: arg, arg, argNames, args, attributeAdaptersForFunction, configFieldName, diff, dup, equals, getNumOutputs, getValue, hashCode, isConfigProperties, larg, mappingsForFunction, onnxNames, outputs, outputVariable, outputVariablesNames, propertiesForFunction, rarg, replaceArg, setInstanceId, setPropertiesForFunction, setValueFor, tensorflowNames
Methods inherited from class java.lang.Object: clone, finalize, getClass, notify, notifyAll, wait, wait, wait
Methods inherited from interface CustomOp: isInplaceCall
public BaseReductionBp(SameDiff sameDiff, SDVariable origInput, SDVariable gradAtOutput, boolean keepDims, int... dimensions)
Parameters:
origInput - Pre-reduced input, i.e. the input to the forward reduction
gradAtOutput - Gradient at the output of the reduction
keepDims - If true, the reduced dimensions were kept (with size 1) in the forward output
dimensions - Dimensions to reduce. May be null

public BaseReductionBp(SameDiff sameDiff, SDVariable origInput1, SDVariable origInput2, SDVariable gradAtOutput, boolean keepDims, int... dimensions)
Parameters:
origInput1 - Pre-reduced input 1
origInput2 - Pre-reduced input 2
gradAtOutput - Gradient at the output of the reduction
keepDims - If true, the reduced dimensions were kept (with size 1) in the forward output
dimensions - Dimensions to reduce. May be null

public BaseReductionBp(INDArray origInput, INDArray gradAtOutput, INDArray output, boolean keepDims, int... dimensions)
Parameters:
origInput - Pre-reduced input, i.e. the input to the forward reduction
gradAtOutput - Gradient at the output of the reduction
output - Output array, i.e. the gradient at the input to the reduction function
keepDims - If true, the reduced dimensions were kept (with size 1) in the forward output
dimensions - Dimensions to reduce. May be null

public BaseReductionBp(INDArray origInput1, INDArray origInput2, INDArray gradAtOutput, INDArray output, boolean keepDims, int... dimensions)
Parameters:
origInput1 - Pre-reduced input 1
origInput2 - Pre-reduced input 2
gradAtOutput - Gradient at the output of the reduction
output - Output array, i.e. the gradient at the input to the reduction function
keepDims - If true, the reduced dimensions were kept (with size 1) in the forward output
dimensions - Dimensions to reduce. May be null

protected void addArgs()
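This page does not document how addArgs encodes the op's configuration. DynamicCustomOp does expose addTArgument, addIArgument and addBArgument (listed among the inherited methods), so one plausible shape is sketched below with hypothetical plain-Java lists. Encoding keepDims as a 1.0/0.0 floating-point flag and dimensions as integer arguments is an assumption for illustration, not the confirmed behavior of BaseReductionBp.

```java
import java.util.ArrayList;
import java.util.List;

// Hypothetical sketch, not the ND4J implementation: packs the two protected
// fields of a reduction backprop op into flat argument lists.
final class ArgPackingSketch {
    final List<Double> tArgs = new ArrayList<>(); // assumed: floating-point arguments
    final List<Long> iArgs = new ArrayList<>();   // assumed: integer arguments

    void addArgs(boolean keepDims, int... dimensions) {
        // Assumption: keepDims is encoded as a single 1.0 / 0.0 flag.
        tArgs.add(keepDims ? 1.0 : 0.0);
        // Assumption: null dimensions add no integer arguments.
        if (dimensions != null) {
            for (int d : dimensions) {
                iArgs.add((long) d);
            }
        }
    }

    public static void main(String[] args) {
        ArgPackingSketch sketch = new ArgPackingSketch();
        sketch.addArgs(true, 0, 2);
        System.out.println(sketch.tArgs + " " + sketch.iArgs); // [1.0] [0, 2]
    }
}
```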
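
public abstract String opName()
Description copied from class: DynamicCustomOp
This method returns the op name as a string.
Specified by:
opName in interface CustomOp
Overrides:
opName in class DynamicCustomOp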
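Because opName() is abstract here, each concrete reduction backprop subclass supplies the name of its native op. A minimal plain-Java sketch of that pattern follows; ReductionBpBase and ExampleSumBp are hypothetical stand-ins for BaseReductionBp and one of its subclasses, and the returned name "reduce_sum_bp" is an illustrative assumption.

```java
// Hypothetical mini-hierarchy mirroring the pattern on this page: the base class
// holds the shared reduction fields and leaves opName() abstract.
abstract class ReductionBpBase {
    protected final boolean keepDims;
    protected final int[] dimensions;

    protected ReductionBpBase(boolean keepDims, int... dimensions) {
        this.keepDims = keepDims;
        this.dimensions = dimensions;
    }

    // Each concrete backprop op returns the name of the native op it maps to.
    public abstract String opName();
}

final class ExampleSumBp extends ReductionBpBase {
    ExampleSumBp(boolean keepDims, int... dimensions) {
        super(keepDims, dimensions);
    }

    @Override
    public String opName() {
        return "reduce_sum_bp"; // assumed op name, for illustration only
    }

    public static void main(String[] args) {
        System.out.println(new ExampleSumBp(false, 1).opName()); // reduce_sum_bp
    }
}
```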
public List<DataType> calculateOutputDataTypes(List<DataType> dataTypes)
Description copied from class: DifferentialFunction
Calculate the data types for the output arrays. Unlike DifferentialFunction.calculateOutputShape(), this method does not require the input arrays to be populated. This is important because it allows greedy data-type inference for the entire network, even when the arrays are not available.
Overrides:
calculateOutputDataTypes in class DifferentialFunction
Parameters:
dataTypes - The data types of the inputs

Copyright © 2021. All rights reserved.