public class If extends DifferentialFunction implements CustomOp
Runs one of two SameDiff.SameDiffFunctionDefinition bodies
depending on a predicate org.nd4j.autodiff.samediff.SameDiff.SameDiffConditional
Modifier and Type | Field and Description |
---|---|
protected String |
blockName |
protected SDVariable |
dummyResult |
protected SameDiffFunctionDefinition |
falseBody |
protected SameDiff |
falseBodyExecution |
protected String |
falseBodyName |
protected SDVariable[] |
inputVars |
protected SameDiff |
loopBodyExecution |
protected SDVariable[] |
outputVars |
protected SameDiffConditional |
predicate |
protected SameDiff |
predicateExecution |
protected SDVariable |
targetBoolean |
protected SameDiffFunctionDefinition |
trueBody |
protected Boolean |
trueBodyExecuted |
protected String |
trueBodyName |
dimensions, extraArgs, inPlace, sameDiff, scalarValue
Constructor and Description |
---|
If(If ifStatement) |
If(String blockName,
SameDiff parent,
SDVariable[] inputVars,
SameDiffFunctionDefinition conditionBody,
SameDiffConditional predicate,
SameDiffFunctionDefinition trueBody,
SameDiffFunctionDefinition falseBody) |
Modifier and Type | Method and Description |
---|---|
void |
addBArgument(boolean... arg) |
void |
addIArgument(int... arg) |
void |
addIArgument(long... arg) |
void |
addInputArgument(INDArray... arg) |
void |
addOutputArgument(INDArray... arg) |
void |
addTArgument(double... arg) |
void |
assertValidForExecution()
Asserts a valid state for execution,
otherwise throws an
ND4JIllegalStateException |
boolean[] |
bArgs() |
List<LongShapeDescriptor> |
calculateOutputShape()
Calculate the output shape for this op
|
List<SDVariable> |
doDiff(List<SDVariable> f1)
The actual implementation for automatic differentiation.
|
void |
exectedTrueOrFalse(boolean trueBodyExecuted)
Toggles whether the true body or the false body
was executed
|
Boolean |
getBArgument(int index) |
CustomOpDescriptor |
getDescriptor()
Get the custom op descriptor if one is available.
|
Long |
getIArgument(int index) |
INDArray |
getInputArgument(int index) |
INDArray |
getOutputArgument(int index) |
Double |
getTArgument(int index) |
long[] |
iArgs() |
void |
initFromOnnx(OnnxProto3.NodeProto node,
SameDiff initWith,
Map<String,OnnxProto3.AttributeProto> attributesForNode,
OnnxProto3.GraphProto graph)
Initialize the function from the given
OnnxProto3.NodeProto |
void |
initFromTensorFlow(NodeDef nodeDef,
SameDiff initWith,
Map<String,AttrValue> attributesForNode,
GraphDef graph)
Initialize the function from the given
NodeDef |
INDArray[] |
inputArguments() |
boolean |
isInplaceCall()
This method returns true if op is supposed to be executed inplace
|
int |
numBArguments() |
int |
numIArguments() |
int |
numInputArguments() |
int |
numOutputArguments() |
int |
numTArguments() |
String |
onnxName()
The opName of this function in ONNX
|
long |
opHash()
This method returns the LongHash of opName()
|
String |
opName()
The name of the op
|
Op.Type |
opType()
The type of the op
|
INDArray[] |
outputArguments() |
SDVariable[] |
outputVariables(String baseName)
Return the output functions for this differential function.
|
void |
removeIArgument(Integer arg) |
void |
removeInputArgument(INDArray arg) |
void |
removeOutputArgument(INDArray arg) |
void |
removeTArgument(Double arg) |
double[] |
tArgs() |
String |
tensorflowName()
The opName of this function in TensorFlow
|
String |
toString() |
arg, arg, argNames, args, attributeAdaptersForFunction, calculateOutputDataTypes, configFieldName, diff, dup, equals, f, getNumOutputs, getValue, hashCode, isConfigProperties, larg, mappingsForFunction, onnxNames, opNum, outputVariable, outputVariables, outputVariablesNames, propertiesForFunction, rarg, resolvePropertiesFromSameDiffBeforeExecution, setInstanceId, setPropertiesForFunction, setValueFor, tensorflowNames
protected SameDiff loopBodyExecution
protected SameDiff predicateExecution
protected SameDiff falseBodyExecution
protected SameDiffConditional predicate
protected SameDiffFunctionDefinition trueBody
protected SameDiffFunctionDefinition falseBody
protected String blockName
protected String trueBodyName
protected String falseBodyName
protected SDVariable[] inputVars
protected Boolean trueBodyExecuted
protected SDVariable targetBoolean
protected SDVariable dummyResult
protected SDVariable[] outputVars
public If(If ifStatement)
public If(String blockName, SameDiff parent, SDVariable[] inputVars, SameDiffFunctionDefinition conditionBody, SameDiffConditional predicate, SameDiffFunctionDefinition trueBody, SameDiffFunctionDefinition falseBody)
public void exectedTrueOrFalse(boolean trueBodyExecuted)
trueBodyExecuted
public SDVariable[] outputVariables(String baseName)
DifferentialFunction
outputVariables
in class DifferentialFunction
public List<SDVariable> doDiff(List<SDVariable> f1)
DifferentialFunction
doDiff
in class DifferentialFunction
public String opName()
DifferentialFunction
opName
in interface CustomOp
opName
in class DifferentialFunction
public long opHash()
CustomOp
public boolean isInplaceCall()
CustomOp
isInplaceCall
in interface CustomOp
public INDArray[] outputArguments()
outputArguments
in interface CustomOp
public INDArray[] inputArguments()
inputArguments
in interface CustomOp
public void addIArgument(int... arg)
addIArgument
in interface CustomOp
public void addIArgument(long... arg)
addIArgument
in interface CustomOp
public void addBArgument(boolean... arg)
addBArgument
in interface CustomOp
public void removeIArgument(Integer arg)
removeIArgument
in interface CustomOp
public Boolean getBArgument(int index)
getBArgument
in interface CustomOp
public Long getIArgument(int index)
getIArgument
in interface CustomOp
public int numIArguments()
numIArguments
in interface CustomOp
public void addTArgument(double... arg)
addTArgument
in interface CustomOp
public void removeTArgument(Double arg)
removeTArgument
in interface CustomOp
public Double getTArgument(int index)
getTArgument
in interface CustomOp
public int numTArguments()
numTArguments
in interface CustomOp
public int numBArguments()
numBArguments
in interface CustomOp
public void addInputArgument(INDArray... arg)
addInputArgument
in interface CustomOp
public void removeInputArgument(INDArray arg)
removeInputArgument
in interface CustomOp
public INDArray getInputArgument(int index)
getInputArgument
in interface CustomOp
public int numInputArguments()
numInputArguments
in interface CustomOp
public void addOutputArgument(INDArray... arg)
addOutputArgument
in interface CustomOp
public void removeOutputArgument(INDArray arg)
removeOutputArgument
in interface CustomOp
public INDArray getOutputArgument(int index)
getOutputArgument
in interface CustomOp
public int numOutputArguments()
numOutputArguments
in interface CustomOp
public Op.Type opType()
DifferentialFunction
opType
in class DifferentialFunction
public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map<String,AttrValue> attributesForNode, GraphDef graph)
DifferentialFunction
NodeDef
initFromTensorFlow
in class DifferentialFunction
public void initFromOnnx(OnnxProto3.NodeProto node, SameDiff initWith, Map<String,OnnxProto3.AttributeProto> attributesForNode, OnnxProto3.GraphProto graph)
DifferentialFunction
OnnxProto3.NodeProto
initFromOnnx
in class DifferentialFunction
public List<LongShapeDescriptor> calculateOutputShape()
DifferentialFunction
calculateOutputShape
in interface CustomOp
calculateOutputShape
in class DifferentialFunction
public CustomOpDescriptor getDescriptor()
CustomOp
getDescriptor
in interface CustomOp
public void assertValidForExecution()
CustomOp
ND4JIllegalStateException
assertValidForExecution
in interface CustomOp
public String onnxName()
DifferentialFunction
onnxName
in class DifferentialFunction
public String tensorflowName()
DifferentialFunction
tensorflowName
in class DifferentialFunction
Copyright © 2019. All rights reserved.