public class SameDiff extends Object
SameDiff lets you define a graph symbolically; the graph accumulates operations. To execute the graph, call:
- exec() to execute the graph and get all the operations
- exec(List) for an already created set of ops
- execAndEndResult() for the end result only
- execAndEndResult(List) for the end result of a cached set of ops
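The define-then-execute model described above can be illustrated with a small, self-contained sketch. It does not use SameDiff. The class `ToyGraph` and its `exec`/`execAndEndResult` methods are hypothetical stand-ins that only mimic the idea: ops are recorded first and evaluated later, and the caller gets back either every op's result or only the last one.

```java
import java.util.ArrayList;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.function.Function;

// Hypothetical toy graph (not the SameDiff API): ops are recorded symbolically
// and only evaluated when exec() or execAndEndResult() is called.
class ToyGraph {
    private final Map<String, Double> inputs = new LinkedHashMap<>();
    private final List<String> opNames = new ArrayList<>();
    private final List<Function<Map<String, Double>, Double>> ops = new ArrayList<>();

    void placeholder(String name, double value) {
        inputs.put(name, value);
    }

    // Records an op; nothing is computed yet.
    void op(String name, Function<Map<String, Double>, Double> fn) {
        opNames.add(name);
        ops.add(fn);
    }

    // Evaluates every op in insertion order and returns all results by name.
    Map<String, Double> exec() {
        Map<String, Double> values = new LinkedHashMap<>(inputs);
        for (int i = 0; i < ops.size(); i++) {
            values.put(opNames.get(i), ops.get(i).apply(values));
        }
        return values;
    }

    // Evaluates the whole graph but returns only the last op's result.
    double execAndEndResult() {
        return exec().get(opNames.get(opNames.size() - 1));
    }
}
```

For example, after `placeholder("x", 3.0)`, `op("y", v -> v.get("x") * 2)` and `op("z", v -> v.get("y") + 1)`, `exec()` yields x = 3, y = 6 and z = 7, while `execAndEndResult()` yields only 7.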

Modifier and Type | Class and Description |
---|---|
static class | SameDiff.DefaultSameDiffConditional |
static interface | SameDiff.SameDiffConditional: an interface for representing a conditional statement |
static interface | SameDiff.SameDiffFunctionDefinition: a function definition for SameDiff |

public static com.rits.cloning.Cloner newCloner()
public void updateVariableName(String varName, String withName)
varName
- the vertex id to update
withName
- the new name
public SameDiff disableDebugging()
public SameDiff enableDebugMode()
public DifferentialFunctionFactory f()
Returns the DifferentialFunctionFactory for this SameDiff instance.
public SDVariable invokeGraphOn(SameDiff sameDiff)
sameDiff
- the SameDiff instance to invoke the graph on
public boolean functionExists(String id)
Returns true if a function with the given id exists.
id
- the function id to test for
public DifferentialFunction getFunctionById(String id)
Returns the function whose DifferentialFunction#getOwnName() matches the given id.
id
- the id of the function
public void putFunctionForId(String id, DifferentialFunction function)
id
- the id
function
- the function
public String[] getInputsForFunction(DifferentialFunction function)
function
- the function to get the inputs for
public String[] getOutputsForFunction(DifferentialFunction function)
function
- the function to get the outputs for
public SDVariable[] getOutputVariablesForFunction(DifferentialFunction function)
Returns the output variables for the given function; see getOutputsForFunction(DifferentialFunction).
function
- the function to get the output variables for
public SDVariable[] getInputVariablesForFunction(DifferentialFunction function)
Returns the input variables for the given function; see getInputsForFunction(DifferentialFunction).
function
- the function to get the input variables for
public void updateArrayForVarName(String varName, INDArray arr)
Updates the array associated with the given variable name.
varName
- the variable name
arr
- the new array
Throws ND4JIllegalStateException when the array does not exist.
public void putArrayForVarName(String varName, INDArray arr)
Adds an array for the given variable name. Use updateArrayForVarName(String, INDArray) if the array already exists.
varName
- the vertex id to add
arr
- the array to add
Throws ND4JIllegalStateException when the array already exists.
public int[] getShapeForVarName(String varName)
Returns the shape for the given vertex id. A shape *and* an array should not be defined at the same time, because this wastes memory. The internal map used to track shapes for particular vertex ids should also delete redundant stored shapes, so that there is only one source of information.
varName
- the vertex id to get the shape for
public void updateShapeForVarName(String varName, int[] shape)
Use putShapeForVarName(String, int[]) if you want to add a new shape. Update is meant to be an in-place replacement of the shape for the vertex id *only*.
varName
- the vertex id to associate
shape
- the shape to associate with
public void putShapeForVarName(String varName, int[] shape)
varName
- the vertex id to associate
shape
- the shape to associate with
public boolean shapeAlreadyExistsForVarName(String varName)
varName
- the vertex id
public boolean arrayAlreadyExistsForVarName(String varName)
Returns true if an INDArray already exists for the given vertex id.
varName
- the vertex id
public INDArray getArrForVarName(String varName)
Returns the INDArray for a given vertex id.
varName
- the vertex id
public void associateArrayWithVariable(INDArray arr, SDVariable variable)
arr
- the array to associate
variable
- the variable to associate the array with
public void putSubFunction(String name, SameDiff nameSpace)
Adds a SameDiff namespace as a sub function.
name
- the name of the function
nameSpace
- the namespace
public Map<String,SDVariable> variableMap()
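The put/update contract for arrays described above (put fails if an entry already exists, update fails if it does not), together with the rule that a shape and an array should not both be stored, can be sketched with plain maps. This is a hypothetical illustration, not SameDiff's implementation. `ArrayRegistry` is an invented name, `double[]` stands in for `INDArray`, and `IllegalStateException` stands in for `ND4JIllegalStateException`.

```java
import java.util.HashMap;
import java.util.Map;

// Hypothetical sketch of the put/update contract; double[] stands in for
// INDArray and IllegalStateException for ND4JIllegalStateException.
class ArrayRegistry {
    private final Map<String, double[]> arrays = new HashMap<>();
    private final Map<String, int[]> shapes = new HashMap<>();

    // "put" refuses to overwrite an existing array.
    void putArrayForVarName(String varName, double[] arr) {
        if (arrays.containsKey(varName)) {
            throw new IllegalStateException("Array already exists for " + varName);
        }
        arrays.put(varName, arr);
        // Keep one source of truth: drop any now-redundant stored shape.
        shapes.remove(varName);
    }

    // "update" only replaces an array that is already registered.
    void updateArrayForVarName(String varName, double[] arr) {
        if (!arrays.containsKey(varName)) {
            throw new IllegalStateException("No array exists for " + varName);
        }
        arrays.put(varName, arr);
    }

    void putShapeForVarName(String varName, int[] shape) {
        shapes.put(varName, shape);
    }

    boolean arrayAlreadyExistsForVarName(String varName) {
        return arrays.containsKey(varName);
    }

    boolean shapeAlreadyExistsForVarName(String varName) {
        return shapes.containsKey(varName);
    }
}
```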
public SDVariable invoke(Op op, SDVariable x, SDVariable y)
op
- the op
x
- the first input
y
- the second input
public SDVariable getVariableForArray(INDArray arr)
Returns the SDVariable for an array reference. Internally, SameDiff associates array references with variables; this is typically a shortcut for the array associated with SDVariable.getArr().
arr
- the array reference
public Collection<String> definedFunctionNames()
public long memoryForGraph()
public SDVariable invoke(Op op, SDVariable x)
op
- the op
x
- the first input
public void addPropertyToResolve(DifferentialFunction forFunction, String arrayName)
Adds a property that must be resolved for the given function before execution. This is very common for model import.
forFunction
- the function to add the property to resolve for
arrayName
- the array name
public List<String> propertiesToResolveForFunction(DifferentialFunction function)
Returns the properties to resolve for the given function; see DifferentialFunction.resolvePropertiesFromSameDiffBeforeExecution().
function
- the function to get the properties to resolve for
public boolean hasPropertiesToResolve(DifferentialFunction function)
function
- the function to check
public <T> T getPropertyForFunction(DifferentialFunction functionInstance, String propertyName)
T
- the inferred return type
functionInstance
- the function to get the property for
propertyName
- the name of the property to get
public void addPropertyForFunction(DifferentialFunction functionFor, String propertyName, INDArray property)
functionFor
- the function to add a property for
propertyName
- the property name
property
- the property value
public void addPropertyForFunction(DifferentialFunction functionFor, String propertyName, long property)
functionFor
- the function to add the property for
propertyName
- the name of the property to add the value for
property
- the property value to add
public void addVariableMappingForField(DifferentialFunction function, String fieldName, String varName)
Maps a function's field name to a variable name, for use during model import (see GraphMapper.importGraph(File)). This data structure is typically accessed during DifferentialFunction.resolvePropertiesFromSameDiffBeforeExecution(). When a function attempts to resolve variables right before execution, there needs to be a way of knowing which variable in a SameDiff graph should map to a particular field name of the function.
function
- the function to map
fieldName
- the field name for the function to map
varName
- the variable name of the array to get from SameDiff
public String getVarNameForFieldAndFunction(DifferentialFunction function, String fieldName)
Returns the variable name mapped to the given function field; see DifferentialFunction.resolvePropertiesFromSameDiffBeforeExecution().
function
- the function to get the variable name for
fieldName
- the field name to resolve for
public boolean isImportVariable(String variableName)
variableName
- the imported variable name
public void addVarNameForImport(String varName)
varName
- the variable name to add
public void setBaseNameForFunctionInstanceId(String baseName, DifferentialFunction function)
Sets a base name for a function instance. The base name is used in generateOutputVariableForOp(DifferentialFunction, String) to ensure that original names from model import map to current SameDiff names when names are generated.
baseName
- the base name to add
function
- the function to declare a base name for
public String getBaseNameForFunction(DifferentialFunction function)
function
- the function to get the base name for
public <X extends SDVariable> X setupFunction(X function)
Attempts to insert the DifferentialFunction reference into this SameDiff instance. If the given array field with the given index already exists, a reference check ensures that the two array fields are the same; if not, an exception is thrown. If the instances are the same (by semantics, not reference), the original instance is returned. This ensures that created instances are unique and reference checked.
function
- the array field to attempt to create
public void addOutgoingFor(SDVariable[] variables, DifferentialFunction function)
variables
- the output variables of the function
function
- the function to add the outputs for
public void addOutgoingFor(String[] varNames, DifferentialFunction function)
varNames
- the names of the output variables of the function
function
- the function to add the outputs for
public void addArgsFor(String[] variables, DifferentialFunction function)
variables
- the names of the argument variables
function
- the function to add the arguments for
public void addArgsFor(SDVariable[] variables, DifferentialFunction function)
variables
- the argument variables
function
- the function to add the arguments for
public boolean hasArgs(int[] function)
function
- the function to check
public boolean hasArgs(DifferentialFunction function)
function
- the function to check
public DifferentialFunction[] functions()
public static SameDiff create(SameDiff originalSameDiff)
originalSameDiff
- the SameDiff instance to create from
public static SameDiff create()
public INDArray[] eval(Map<String,INDArray> inputs)
inputs
- the inputs to evaluate
public SameDiff dup()
public long numElements()
public List<SDVariable> variables()
public SDVariable one(String name, int[] shape)
name
- the name of the variable
shape
- the shape of the array to be created
public SDVariable onesLike(SDVariable input)
input
- the input variable
public SDVariable onesLike(String name, SDVariable input)
input
- the input variable
public SDVariable zero(String name, int[] shape)
name
- the name of the variable
shape
- the shape of the array to be created
public SDVariable zerosLike(SDVariable input)
input
- the input variable
public SDVariable zerosLike(String name, SDVariable input)
input
- the input variable
public SDVariable var(String name, int[] shape, WeightInitScheme weightInitScheme)
Creates a variable with the given shape, initialized using the given WeightInitScheme.
name
- the name of the variable
shape
- the shape of the array to be created
weightInitScheme
- the weight init scheme
public SDVariable var(String name, int[] shape)
Creates an SDVariable with the given shape and a depth of 0.
name
- the name of the variable
shape
- the shape of the variable
public SDVariable var(SDVariable arr)
Initializes an SDVariable reference, tying this variable to this SameDiff instance. NDArraySupplierInitScheme is used to ensure that, if the array is already allocated anywhere, the variable exists in this SameDiff instance as a copy of it.
arr
- the variable to add
public void removeArgFromFunction(String varName, DifferentialFunction function)
varName
- the variable name to remove
function
- the function to remove the argument from
public SDVariable var(String name, INDArray arr)
name
- the name of the variable
arr
- the array to associate with the variable
public SDVariable getVariable(String name)
name
- the name of the variable
public SDVariable getGradForVariable(String varName)
varName
- the vertex id
public void setGradientForVariableName(String variableName, SDVariable variable)
variableName
- the vertex id to assign
variable
- the variable
public SDVariable getForwardVariableForVertexId(int vertexId)
vertexId
- the vertex id
public void setForwardVariableForVarName(String varName, SDVariable forwardVariable)
varName
- the variable name
forwardVariable
- the forward variable to associate with the variable name
public SDVariable grad(String varName)
Returns the gradient for the given variable. execBackwards() must be executed first; all gradient functions are obtained at that time.
varName
- the name of the variable to get the gradient for
public SDVariable avgPooling2d(SDVariable[] inputs, Pooling2DConfig pooling2DConfig)
inputs
- the inputs to average pooling 2d
pooling2DConfig
- the configuration
public SDVariable avgPooling2d(String name, SDVariable[] inputs, Pooling2DConfig pooling2DConfig)
name
- name of the operation in SameDiff
inputs
- the inputs to average pooling 2d
pooling2DConfig
- the configuration
public SDVariable maxPooling2d(SDVariable[] inputs, Pooling2DConfig pooling2DConfig)
inputs
- the inputs to max pooling 2d
pooling2DConfig
- the configuration
public SDVariable maxPooling2d(String name, SDVariable[] inputs, Pooling2DConfig pooling2DConfig)
name
- name of the operation in SameDiff
inputs
- the inputs to max pooling 2d
pooling2DConfig
- the configuration
public SDVariable avgPooling3d(SDVariable[] inputs, Pooling3DConfig pooling3DConfig)
inputs
- the inputs to average pooling 3d
pooling3DConfig
- the configuration
public SDVariable avgPooling3d(String name, SDVariable[] inputs, Pooling3DConfig pooling3DConfig)
name
- name of the operation in SameDiff
inputs
- the inputs to average pooling 3d
pooling3DConfig
- the configuration
public SDVariable maxPooling3d(SDVariable[] inputs, Pooling3DConfig pooling3DConfig)
inputs
- the inputs to max pooling 3d
pooling3DConfig
- the configuration
public SDVariable maxPooling3d(String name, SDVariable[] inputs, Pooling3DConfig pooling3DConfig)
name
- name of the operation in SameDiff
inputs
- the inputs to max pooling 3d
pooling3DConfig
- the configuration
public SDVariable conv1d(SDVariable[] inputs, Conv1DConfig conv1DConfig)
inputs
- the inputs to conv1d
conv1DConfig
- the configuration
public SDVariable conv1d(String name, SDVariable[] inputs, Conv1DConfig conv1DConfig)
name
- name of the operation in SameDiff
inputs
- the inputs to conv1d
conv1DConfig
- the configuration
public SDVariable localResponseNormalization(SDVariable inputs, LocalResponseNormalizationConfig lrnConfig)
inputs
- the inputs to lrn
lrnConfig
- the configuration
public SDVariable localResponseNormalization(String name, SDVariable inputs, LocalResponseNormalizationConfig lrnConfig)
name
- name of the operation in SameDiff
inputs
- the inputs to lrn
lrnConfig
- the configuration
public SDVariable conv2d(SDVariable[] inputs, Conv2DConfig conv2DConfig)
inputs
- the inputs to conv2d
conv2DConfig
- the configuration
public SDVariable conv2d(String name, SDVariable[] inputs, Conv2DConfig conv2DConfig)
name
- name of the operation in SameDiff
inputs
- the inputs to conv2d
conv2DConfig
- the configuration
public SDVariable depthWiseConv2d(SDVariable[] inputs, Conv2DConfig depthConv2DConfig)
inputs
- the inputs to depthwise conv2d
depthConv2DConfig
- the configuration
public SDVariable depthWiseConv2d(String name, SDVariable[] inputs, Conv2DConfig depthConv2DConfig)
name
- name of the operation in SameDiff
inputs
- the inputs to depthwise conv2d
depthConv2DConfig
- the configuration
public SDVariable sconv2d(SDVariable[] inputs, Conv2DConfig conv2DConfig)
inputs
- the inputs to sconv2d
conv2DConfig
- the configuration
public SDVariable sconv2d(String name, SDVariable[] inputs, Conv2DConfig conv2DConfig)
name
- name of the operation in SameDiff
inputs
- the inputs to sconv2d
conv2DConfig
- the configuration
public SDVariable deconv2d(SDVariable[] inputs, DeConv2DConfig deconv2DConfig)
inputs
- the inputs to deconv2d
deconv2DConfig
- the configuration
public SDVariable deconv2d(String name, SDVariable[] inputs, DeConv2DConfig deconv2DConfig)
name
- name of the operation in SameDiff
inputs
- the inputs to deconv2d
deconv2DConfig
- the configuration
public SDVariable conv3d(SDVariable[] inputs, Conv3DConfig conv3DConfig)
inputs
- the inputs to conv3d
conv3DConfig
- the configuration
public SDVariable conv3d(String name, SDVariable[] inputs, Conv3DConfig conv3DConfig)
name
- name of the operation in SameDiff
inputs
- the inputs to conv3d
conv3DConfig
- the configuration
public SDVariable batchNorm(SDVariable input, SDVariable mean, SDVariable variance, SDVariable gamma, SDVariable beta, boolean applyGamma, boolean applyBeta, double epsilon)
public SDVariable batchNorm(String name, SDVariable input, SDVariable mean, SDVariable variance, SDVariable gamma, SDVariable beta, boolean applyGamma, boolean applyBeta, double epsilon)
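The batchNorm signatures above take mean, variance, gamma, beta, two flags and epsilon. The sketch below assumes the standard batch-normalization formula y = gamma * (x - mean) / sqrt(variance + epsilon) + beta, with applyGamma and applyBeta switching the scale and shift terms on or off. It works on plain 1-D arrays with one statistic per element. It is not the ND4J kernel, and `BatchNormSketch` is a hypothetical name.

```java
// Sketch of the standard batch-norm formula, assuming applyGamma/applyBeta
// simply enable or disable the scale and shift terms.
class BatchNormSketch {
    static double[] batchNorm(double[] x, double[] mean, double[] variance,
                              double[] gamma, double[] beta,
                              boolean applyGamma, boolean applyBeta, double epsilon) {
        double[] out = new double[x.length];
        for (int i = 0; i < x.length; i++) {
            double normalized = (x[i] - mean[i]) / Math.sqrt(variance[i] + epsilon);
            if (applyGamma) normalized *= gamma[i];
            if (applyBeta) normalized += beta[i];
            out[i] = normalized;
        }
        return out;
    }
}
```

With x = 3, mean = 1, variance = 4 and epsilon = 0, the normalized value is 1. Enabling gamma = 2 and beta = 0.5 turns that into 2.5.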
public SDVariable scalar(String name, double value)
name
- the name of the variable
value
- the scalar value
public SDVariable gte(SDVariable iX, double iy)
iX
- the input variable
public SDVariable lte(SDVariable iX, double iy)
iX
- the input variable
public SDVariable gt(SDVariable iX, double iy)
iX
- the input variable
public SDVariable lt(SDVariable iX, double iy)
iX
- the input variable
public SDVariable neq(SDVariable iX, double iy)
iX
- the input variable
public SDVariable eq(SDVariable iX, double iy)
iX
- the input variable
public SDVariable gte(SDVariable iX, SDVariable iy)
iX
- the first input variable
public SDVariable lte(SDVariable iX, SDVariable iy)
iX
- the first input variable
public SDVariable gt(SDVariable iX, SDVariable iy)
iX
- the first input variable
public SDVariable lt(SDVariable iX, SDVariable iy)
iX
- the first input variable
public SDVariable neq(SDVariable iX, SDVariable iy)
iX
- the first input variable
public SDVariable eq(SDVariable iX, SDVariable iy)
iX
- the first input variable
public SDVariable or(SDVariable iX, SDVariable iy)
iX
- the first input variable
public SDVariable and(SDVariable iX, SDVariable iY)
public SDVariable and(String name, SDVariable ix, SDVariable iy)
public SDVariable xor(SDVariable ix, SDVariable iy)
public SDVariable xor(String name, SDVariable ix, SDVariable iy)
public SDVariable abs(SDVariable ix)
public SDVariable abs(String name, SDVariable ix)
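The comparison methods above (gte, lte, gt, lt, neq, eq) and the boolean methods (or, and, xor) operate elementwise. The sketch below rests on an assumption: results are masks holding 1.0 where the condition is true and 0.0 elsewhere, and the boolean ops treat any non-zero value as true. It uses plain `double[]` arrays instead of SDVariables, and `MaskOps` is a hypothetical name.

```java
import java.util.Arrays;
import java.util.function.BiPredicate;

// Elementwise comparison and boolean ops producing 1.0/0.0 masks
// (an assumed convention, shown on plain arrays).
class MaskOps {
    static double[] compare(double[] x, double[] y, BiPredicate<Double, Double> cond) {
        double[] out = new double[x.length];
        for (int i = 0; i < x.length; i++) {
            out[i] = cond.test(x[i], y[i]) ? 1.0 : 0.0;
        }
        return out;
    }

    // Scalar overload: compares every element of x against the same value.
    static double[] gte(double[] x, double y) {
        double[] ys = new double[x.length];
        Arrays.fill(ys, y);
        return compare(x, ys, (a, b) -> a >= b);
    }

    static double[] and(double[] x, double[] y) {
        return compare(x, y, (a, b) -> a != 0 && b != 0);
    }

    static double[] or(double[] x, double[] y) {
        return compare(x, y, (a, b) -> a != 0 || b != 0);
    }

    static double[] xor(double[] x, double[] y) {
        return compare(x, y, (a, b) -> (a != 0) != (b != 0));
    }
}
```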
public SDVariable neg(SDVariable iX)
iX
- the input variable
public SDVariable cos(SDVariable iX)
iX
- the input variable
public SDVariable sin(SDVariable iX)
iX
- the input variable
public SDVariable tan(SDVariable iX)
iX
- the input variable
public SDVariable invertPermutation(SDVariable input)
public SDVariable invertPermutation(String name, SDVariable input)
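invertPermutation computes the inverse of a permutation vector. If position i maps to perm[i], the inverse maps perm[i] back to i, so inverse[perm[i]] == i for every i. The plain-Java sketch below illustrates that definition; `InvertPermutationSketch` is a hypothetical name and not part of SameDiff.

```java
// invertPermutation: inverse[perm[i]] == i for every i.
class InvertPermutationSketch {
    static int[] invertPermutation(int[] perm) {
        int[] inverse = new int[perm.length];
        for (int i = 0; i < perm.length; i++) {
            inverse[perm[i]] = i;
        }
        return inverse;
    }
}
```

For example, the permutation [3, 4, 0, 2, 1] inverts to [2, 4, 3, 0, 1].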
public SDVariable acos(SDVariable iX)
iX
- the input variable
public SDVariable asin(SDVariable iX)
iX
- the input variable
public SDVariable atan(SDVariable iX)
iX
- the input variable
public SDVariable atan2(SDVariable y, SDVariable x)
public SDVariable atan2(String name, SDVariable y, SDVariable x)
public SDVariable cosh(SDVariable iX)
iX
- the input variable
public SDVariable sinh(SDVariable iX)
iX
- the input variable
public SDVariable tanh(SDVariable iX)
iX
- the input variable
public SDVariable acosh(SDVariable iX)
iX
- the input variable
public SDVariable asinh(SDVariable iX)
iX
- the input variable
public SDVariable atanh(SDVariable iX)
iX
- the input variable
public SDVariable exp(SDVariable iX)
iX
- the input variable
public SDVariable rsqrt(SDVariable iX)
iX
- the input variable
public SDVariable expm1(SDVariable iX)
iX
- the input variable
public SDVariable log1p(SDVariable iX)
iX
- the input variable
public SDVariable isInfinite(SDVariable iX)
iX
- the input variable
public SDVariable isNaN(SDVariable iX)
iX
- the input variable
public SDVariable round(SDVariable iX)
iX
- the input variable
public SDVariable isFinite(SDVariable iX)
iX
- the input variable
public SDVariable log(SDVariable iX)
iX
- the input variable
public SDVariable cube(SDVariable iX)
iX
- the input variable
public SDVariable pow(SDVariable iX, double value)
iX
- the input variable
value
- the exponent
public SDVariable sqrt(SDVariable iX)
iX
- the input variable
public SDVariable square(SDVariable iX)
iX
- the input variable
public SDVariable floor(SDVariable iX)
iX
- the input variable
public SDVariable ceil(SDVariable x)
public SDVariable ceil(String name, SDVariable x)
public SDVariable clipByValue(SDVariable x, double clipValueMin, double clipValueMax)
public SDVariable clipByValue(String name, SDVariable x, double clipValueMin, double clipValueMax)
public SDVariable clipByNorm(SDVariable x, double clipValue)
public SDVariable clipByNorm(String name, SDVariable x, double clipValue)
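clipByValue and clipByNorm differ in what they bound. The first clamps each element into [clipValueMin, clipValueMax]. The second is assumed here to follow the usual definition: compute the L2 norm of the whole input and, if it exceeds clipValue, rescale the input so its norm equals clipValue. Inputs whose norm is already small enough are left unchanged. This is a plain-array sketch; `ClipSketch` is a hypothetical name.

```java
// clipByValue clamps elementwise; clipByNorm (usual L2 definition assumed)
// rescales the whole array so its L2 norm is at most clipValue.
class ClipSketch {
    static double[] clipByValue(double[] x, double min, double max) {
        double[] out = new double[x.length];
        for (int i = 0; i < x.length; i++) {
            out[i] = Math.min(max, Math.max(min, x[i]));
        }
        return out;
    }

    static double[] clipByNorm(double[] x, double clipValue) {
        double sumSq = 0;
        for (double v : x) sumSq += v * v;
        double norm = Math.sqrt(sumSq);
        // Only shrink; never scale up an input whose norm is below the limit.
        double scale = norm > clipValue ? clipValue / norm : 1.0;
        double[] out = new double[x.length];
        for (int i = 0; i < x.length; i++) out[i] = x[i] * scale;
        return out;
    }
}
```

Clipping the vector (3, 4) by norm 1 gives roughly (0.6, 0.8), because its L2 norm is 5.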
public SDVariable relu(SDVariable iX, double cutoff)
iX
- the input variable
public SDVariable relu6(SDVariable iX, double cutoff)
iX
- the input variable
public SDVariable softmax(SDVariable iX)
iX
- the input variable
public SDVariable logSoftmax(SDVariable iX)
public SDVariable logSoftmax(String name, SDVariable iX)
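softmax and logSoftmax are related by logSoftmax(x) = log(softmax(x)). A common way to compute both without overflow is to subtract the maximum before exponentiating, which leaves the result unchanged. The sketch normalizes over a single 1-D vector and is an illustration, not the ND4J kernel. `SoftmaxSketch` is a hypothetical name.

```java
// Numerically stable softmax/logSoftmax over a 1-D vector.
class SoftmaxSketch {
    static double[] logSoftmax(double[] x) {
        double max = Double.NEGATIVE_INFINITY;
        for (double v : x) max = Math.max(max, v);
        double sum = 0;
        for (double v : x) sum += Math.exp(v - max); // exponents are <= 0, so no overflow
        double logSum = max + Math.log(sum);
        double[] out = new double[x.length];
        for (int i = 0; i < x.length; i++) out[i] = x[i] - logSum;
        return out;
    }

    static double[] softmax(double[] x) {
        double[] logs = logSoftmax(x);
        double[] out = new double[x.length];
        for (int i = 0; i < x.length; i++) out[i] = Math.exp(logs[i]);
        return out;
    }
}
```

A naive `Math.exp(1000)` overflows to infinity. With the shift, softmax of (1000, 1000) still comes out as (0.5, 0.5).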
public SDVariable selu(SDVariable iX)
public SDVariable selu(String name, SDVariable iX)
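selu is the scaled exponential linear unit. Its standard definition is lambda * x for x > 0 and lambda * alpha * (exp(x) - 1) otherwise, with fixed constants alpha ≈ 1.6733 and lambda ≈ 1.0507. The sketch assumes SameDiff's selu uses these standard constants; `SeluSketch` is a hypothetical name.

```java
// Standard SELU definition with its published constants (assumed to match SameDiff's selu).
class SeluSketch {
    static final double ALPHA = 1.6732632423543772;
    static final double LAMBDA = 1.0507009873554805;

    static double selu(double x) {
        return x > 0 ? LAMBDA * x : LAMBDA * ALPHA * (Math.exp(x) - 1);
    }
}
```

For large negative inputs, the output saturates at -lambda * alpha, roughly -1.7581.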
public SDVariable mergeAdd(SDVariable... iX)
public SDVariable mergeAdd(String name, SDVariable[] iX)
public SDVariable batchToSpace(SDVariable iX, int[] blocks, int[][] crops)
public SDVariable batchToSpace(String name, SDVariable iX, int[] blocks, int[][] crops)
public SDVariable depthToSpace(SDVariable iX, int blockSize, String dataFormat)
public SDVariable depthToSpace(String name, SDVariable iX, int blockSize, String dataFormat)
public SDVariable spaceToBatch(SDVariable iX, int[] blocks, int[][] padding)
public SDVariable spaceToBatch(String name, SDVariable iX, int[] blocks, int[][] padding)
public SDVariable spaceToDepth(SDVariable iX, int blockSize, String dataFormat)
public SDVariable spaceToDepth(String name, SDVariable iX, int blockSize, String dataFormat)
public SDVariable[] dynamicPartition(SDVariable iX, SDVariable partitions, int numPartitions)
public SDVariable[] dynamicPartition(String[] name, SDVariable iX, SDVariable partitions, int numPartitions)
public SDVariable dynamicStitch(SDVariable[] indices, SDVariable[] iX)
public SDVariable dynamicStitch(String name, SDVariable[] indices, SDVariable[] iX)
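dynamicPartition and dynamicStitch are roughly inverse operations. The sketch below assumes TensorFlow-style semantics and shows them on 1-D data. dynamicPartition routes element i into partition partitions[i]. dynamicStitch writes data[m][j] to position indices[m][j] of the merged output. `PartitionStitchSketch` is a hypothetical name, and Java lists and arrays stand in for SDVariables.

```java
import java.util.ArrayList;
import java.util.List;

// 1-D sketches of partition/stitch, assuming TensorFlow-style semantics.
class PartitionStitchSketch {
    static List<List<Double>> dynamicPartition(double[] x, int[] partitions, int numPartitions) {
        List<List<Double>> out = new ArrayList<>();
        for (int p = 0; p < numPartitions; p++) out.add(new ArrayList<>());
        for (int i = 0; i < x.length; i++) out.get(partitions[i]).add(x[i]);
        return out;
    }

    static double[] dynamicStitch(int[][] indices, double[][] data) {
        int size = 0;
        for (int[] idx : indices) for (int k : idx) size = Math.max(size, k + 1);
        double[] merged = new double[size];
        for (int m = 0; m < indices.length; m++) {
            for (int j = 0; j < indices[m].length; j++) {
                merged[indices[m][j]] = data[m][j];
            }
        }
        return merged;
    }
}
```

Partitioning [10, 20, 30, 40] with [0, 1, 0, 1] gives [10, 30] and [20, 40]. Stitching those back with indices [0, 2] and [1, 3] restores the original order.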
public SDVariable dilation2D(SDVariable df, SDVariable weights, int[] strides, int[] rates, boolean isSameMode)
public SDVariable dilation2D(String name, SDVariable df, SDVariable weights, int[] strides, int[] rates, boolean isSameMode)
public SDVariable shape(SDVariable df)
public SDVariable shape(String name, SDVariable df)
public SDVariable cross(SDVariable a, SDVariable b)
public SDVariable cross(String name, SDVariable a, SDVariable b)
public SDVariable gather(SDVariable df, int axis, int[] broadcast)
public SDVariable gather(String name, SDVariable df, int axis, int[] broadcast)
public SDVariable gatherNd(SDVariable df, SDVariable indices)
public SDVariable gatherNd(String name, SDVariable df, SDVariable indices)
public SDVariable repeat(SDVariable df, int axis)
public SDVariable repeat(String name, SDVariable df, int axis)
public SDVariable stack(SDVariable[] values, int axis)
public SDVariable stack(String name, SDVariable[] values, int axis)
public SDVariable parallel_stack(SDVariable[] values)
public SDVariable parallel_stack(String name, SDVariable[] values)
public SDVariable[] unstack(SDVariable value, int axis)
public SDVariable[] unstack(String[] names, SDVariable value, int axis)
public SDVariable erf(SDVariable iX)
public SDVariable erf(String name, SDVariable iX)
public SDVariable erfc(SDVariable iX)
public SDVariable erfc(String name, SDVariable iX)
public SDVariable diag(SDVariable iX)
public SDVariable diag(String name, SDVariable iX)
public SDVariable diagPart(SDVariable iX)
public SDVariable diagPart(String name, SDVariable iX)
public SDVariable oneHot(SDVariable indices, int depth)
public SDVariable oneHot(SDVariable indices, int depth, int axis, double on, double off)
public SDVariable oneHot(String name, SDVariable indices, int depth)
public SDVariable oneHot(String name, SDVariable indices, int depth, int axis, double on, double off)
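The oneHot overloads take a depth and, optionally, an axis plus on/off values. The sketch covers only a 1-D index vector expanded along the last axis (axis = -1). Row i holds `on` at column indices[i] and `off` elsewhere, and an out-of-range index yields an all-`off` row. That last behaviour and the defaults of the shorter overloads (presumably on = 1, off = 0) are assumptions. `OneHotSketch` is a hypothetical name.

```java
// oneHot for a 1-D index vector along the last axis.
class OneHotSketch {
    static double[][] oneHot(int[] indices, int depth, double on, double off) {
        double[][] out = new double[indices.length][depth];
        for (int i = 0; i < indices.length; i++) {
            for (int j = 0; j < depth; j++) {
                out[i][j] = (indices[i] == j) ? on : off;
            }
        }
        return out;
    }
}
```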
public SDVariable reciprocal(SDVariable a)
public SDVariable reciprocal(String name, SDVariable a)
public SDVariable gradientBackwardsMarker(SDVariable iX)
iX
- the input variable
public SDVariable hardTanh(SDVariable iX)
iX
- the input variable
public SDVariable hardTanhDerivative(SDVariable iX)
iX
- the input variable
public SDVariable sigmoid(SDVariable iX)
iX
- the input variable
public SDVariable sigmoidDerivative(SDVariable iX, SDVariable wrt)
iX
- the input variable
public SDVariable logSigmoid(SDVariable iX)
public SDVariable logSigmoid(String name, SDVariable iX)
public SDVariable sign(SDVariable iX)
iX
- the input variable
public SDVariable softsign(SDVariable iX)
iX
- the input variable
public SDVariable softsignDerivative(SDVariable iX)
iX
- the input variable
public SDVariable softplus(SDVariable iX)
iX
- the input variable
public SDVariable swish(SDVariable iX)
public SDVariable swish(String name, SDVariable iX)
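sigmoid, logSigmoid and swish are closely related. The sketch uses the standard definitions: sigmoid(x) = 1 / (1 + exp(-x)), logSigmoid(x) = log(sigmoid(x)), and swish(x) = x * sigmoid(x). That these match SameDiff's ops exactly is an assumption. logSigmoid is computed with `Math.log1p` in a branch-stable form so that large negative inputs do not produce log(0). `SigmoidFamilySketch` is a hypothetical name.

```java
// Standard sigmoid-family definitions on a single value.
class SigmoidFamilySketch {
    static double sigmoid(double x) {
        return 1.0 / (1.0 + Math.exp(-x));
    }

    // log(sigmoid(x)): -log(1 + e^-x) for x >= 0, x - log(1 + e^x) for x < 0.
    static double logSigmoid(double x) {
        return x >= 0 ? -Math.log1p(Math.exp(-x)) : x - Math.log1p(Math.exp(x));
    }

    // swish(x) = x * sigmoid(x)
    static double swish(double x) {
        return x * sigmoid(x);
    }
}
```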
public SDVariable elu(SDVariable iX)
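iX
- the input variable
public SDVariable eluDerivative(SDVariable iX)
iX
- the input variable
public SDVariable leakyRelu(SDVariable iX, double cutoff)
iX
- the input variable
cutoff
- the leak factor (alpha) applied to negative inputs
public SDVariable mean(SDVariable iX)
iX
- the input variable
public SDVariable mean(SDVariable iX, int... dimension)
iX
- the input variable
dimension
- the dimensions to reduce along
public SDVariable standardDeviation(SDVariable iX, boolean biasCorrected, int... dimensions)
iX
- the input variable
biasCorrected
- whether to apply bias correction (divide by N - 1 instead of N)
dimensions
- the dimensions to reduce along
public SDVariable variance(SDVariable iX, boolean biasCorrected, int... dimensions)
iX
- the input variable
biasCorrected
- whether to apply bias correction (divide by N - 1 instead of N)
dimensions
- the dimensions to reduce along
public SDVariable sum(SDVariable iX, int... dimensions)
iX
- the input variable
dimensions
- the dimensions to reduce along
public SDVariable prod(SDVariable iX, int... dimensions)
iX
- the input variable
dimensions
- the dimensions to reduce along
public SDVariable max(SDVariable iX, int... dimensions)
iX
- the input variable
dimensions
- the dimensions to reduce along
public SDVariable max(SDVariable first, SDVariable second)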
public SDVariable max(String name, SDVariable first, SDVariable second)
public SDVariable countZero(SDVariable input)
public SDVariable countZero(String name, SDVariable input)
public SDVariable zeroFraction(SDVariable input)
public SDVariable zeroFraction(String name, SDVariable input)
public SDVariable countNonZero(SDVariable input)
public SDVariable countNonZero(String name, SDVariable input)
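countZero, countNonZero and zeroFraction are simple reductions. The first two count elements equal and not equal to zero. zeroFraction is assumed to be countZero divided by the total number of elements. The sketch works over a whole 1-D array; `ZeroCountSketch` is a hypothetical name. Returning 0.0 for an empty input is a choice made only for this sketch.

```java
// Zero-counting reductions over a whole 1-D array.
class ZeroCountSketch {
    static long countZero(double[] x) {
        long n = 0;
        for (double v : x) if (v == 0.0) n++;
        return n;
    }

    static long countNonZero(double[] x) {
        return x.length - countZero(x);
    }

    // Fraction of elements that are zero; 0.0 for an empty array (sketch choice).
    static double zeroFraction(double[] x) {
        return x.length == 0 ? 0.0 : (double) countZero(x) / x.length;
    }
}
```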
public SDVariable min(SDVariable iX, int... dimensions)
iX
- the input variable
dimensions
- the dimensions to reduce along
public SDVariable min(SDVariable first, SDVariable second)
public SDVariable min(String name, SDVariable first, SDVariable second)
public SDVariable argmax(SDVariable in, int... dimensions)
public SDVariable argmax(String name, SDVariable in, int... dimensions)
public SDVariable argmin(SDVariable in, int... dimensions)
public SDVariable argmin(String name, SDVariable in, int... dimensions)
public SDVariable cumsum(SDVariable in, boolean exclusive, boolean reverse, int... dimensions)
public SDVariable cumsum(String name, SDVariable in, boolean exclusive, boolean reverse, int... dimensions)
public SDVariable cumprod(SDVariable in, boolean exclusive, boolean reverse, int... dimensions)
public SDVariable cumprod(String name, SDVariable in, boolean exclusive, boolean reverse, int... dimensions)
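The exclusive and reverse flags of cumsum and cumprod are easiest to see on a short vector. The sketch assumes TensorFlow-style semantics. With exclusive set, output element i excludes x[i], so the output starts from the identity (0 for sums, 1 for products). With reverse set, accumulation runs from the end. For x = [1, 2, 3], cumsum gives [1, 3, 6], exclusive gives [0, 1, 3], reverse gives [6, 5, 3], and both flags together give [5, 3, 0]. `CumulativeSketch` is a hypothetical name.

```java
import java.util.function.DoubleBinaryOperator;

// 1-D sketch of cumulative ops with exclusive/reverse flags
// (TensorFlow-style semantics assumed).
class CumulativeSketch {
    static double[] scan(double[] x, boolean exclusive, boolean reverse,
                         double identity, DoubleBinaryOperator op) {
        int n = x.length;
        double[] out = new double[n];
        double acc = identity;
        for (int k = 0; k < n; k++) {
            int i = reverse ? n - 1 - k : k; // walk backwards when reverse is set
            if (exclusive) {
                out[i] = acc;                       // total before including x[i]
                acc = op.applyAsDouble(acc, x[i]);
            } else {
                acc = op.applyAsDouble(acc, x[i]);
                out[i] = acc;                       // total including x[i]
            }
        }
        return out;
    }

    static double[] cumsum(double[] x, boolean exclusive, boolean reverse) {
        return scan(x, exclusive, reverse, 0.0, Double::sum);
    }

    static double[] cumprod(double[] x, boolean exclusive, boolean reverse) {
        return scan(x, exclusive, reverse, 1.0, (a, b) -> a * b);
    }
}
```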
public SDVariable biasAdd(SDVariable input, SDVariable bias)
public SDVariable biasAdd(String name, SDVariable input, SDVariable bias)
public SDVariable reshape(SDVariable iX, int... shape)
iX
- the input variable
shape
- the new shape
public SDVariable reverse(SDVariable x, int... dimensions)
x
- the input variable
dimensions
- the dimensions to reverse along
public SDVariable reverse(String name, SDVariable x, int... dimensions)
x
- the input variable
dimensions
- the dimensions to reverse along
public SDVariable reverseSequence(String name, SDVariable x, SDVariable seq_lengths, int seqDim, int batchDim)
public SDVariable reverseSequence(String name, SDVariable x, SDVariable seq_lengths)
public SDVariable reverseSequence(SDVariable x, SDVariable seq_lengths, int seqDim, int batchDim)
public SDVariable reverseSequence(SDVariable x, SDVariable seq_lengths)
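reverseSequence reverses only the valid prefix of each sequence in a batch. The sketch assumes semantics like TensorFlow's reverse_sequence and fixes the layout to [batch][time], which corresponds to batchDim = 0 and seqDim = 1. For each row b, the first seq_lengths[b] entries are reversed and the rest (for example, padding) stay in place. `ReverseSequenceSketch` is a hypothetical name.

```java
// reverseSequence for x laid out as [batch][time] (batchDim = 0, seqDim = 1).
class ReverseSequenceSketch {
    static double[][] reverseSequence(double[][] x, int[] seqLengths) {
        double[][] out = new double[x.length][];
        for (int b = 0; b < x.length; b++) {
            out[b] = x[b].clone(); // entries past seqLengths[b] are kept as-is
            int len = seqLengths[b];
            for (int t = 0; t < len; t++) {
                out[b][t] = x[b][len - 1 - t];
            }
        }
        return out;
    }
}
```

With rows [1, 2, 3, 4] and [5, 6, 7, 8] and lengths [2, 4], the result is [2, 1, 3, 4] and [8, 7, 6, 5].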
public SDVariable sequenceMask(String name, SDVariable lengths, SDVariable maxLen)
public SDVariable sequenceMask(SDVariable lengths, SDVariable maxLen)
public SDVariable sequenceMask(String name, SDVariable lengths, int maxLen)
public SDVariable sequenceMask(SDVariable lengths, int maxLen)
public SDVariable sequenceMask(String name, SDVariable lengths)
public SDVariable sequenceMask(SDVariable lengths)
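sequenceMask turns a vector of lengths into a 2-D mask. Entry [b][j] is "on" when j < lengths[b], for j in [0, maxLen). The sketch makes three assumptions: the mask is represented as 1.0/0.0, the int and SDVariable maxLen overloads behave the same way, and the overloads without maxLen use the largest length in the input. `SequenceMaskSketch` is a hypothetical name.

```java
// sequenceMask: mask[b][j] = 1.0 if j < lengths[b], else 0.0.
class SequenceMaskSketch {
    static double[][] sequenceMask(int[] lengths, int maxLen) {
        double[][] mask = new double[lengths.length][maxLen];
        for (int b = 0; b < lengths.length; b++) {
            for (int j = 0; j < maxLen; j++) {
                mask[b][j] = j < lengths[b] ? 1.0 : 0.0;
            }
        }
        return mask;
    }

    // Assumed default: maxLen is the largest length present.
    static double[][] sequenceMask(int[] lengths) {
        int maxLen = 0;
        for (int len : lengths) maxLen = Math.max(maxLen, len);
        return sequenceMask(lengths, maxLen);
    }
}
```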
public SDVariable assign(SDVariable x, SDVariable y)
public SDVariable assign(String name, SDVariable x, SDVariable y)
public SDVariable transpose(SDVariable iX)
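iX
- the input variable
public SDVariable permute(SDVariable iX, int... dimensions)
iX
- the input variable
dimensions
- the new order of the dimensions
public SDVariable rollAxis(SDVariable x, int axis)
x
- the input variable
axis
- the axis to roll
public SDVariable concat(int dimension, SDVariable... inputs)
dimension
- the dimension to concatenate along
inputs
- the variables to concatenate
public SDVariable[] moments(SDVariable input, int... axes)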
public SDVariable[] moments(String[] name, SDVariable input, int... axes)
public SDVariable[] normalizeMoments(SDVariable counts, SDVariable means, SDVariable variances, double shift)
public SDVariable[] normalizeMoments(String[] name, SDVariable counts, SDVariable means, SDVariable variances, double shift)
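moments is assumed here to return the mean and the population variance of its input, as tf.nn.moments does. The biasCorrected flag on variance and standardDeviation switches the divisor from N to N - 1 (Bessel's correction). The sketch reduces a whole 1-D array; `MomentsSketch` is a hypothetical name. For [2, 4, 4, 4, 5, 5, 7, 9], the mean is 5, the population variance is 4 (standard deviation 2), and the bias-corrected variance is 32 / 7.

```java
// moments = [mean, population variance]; biasCorrected divides by N - 1 instead of N.
class MomentsSketch {
    static double[] moments(double[] x) {
        double mean = 0;
        for (double v : x) mean += v;
        mean /= x.length;
        double sumSq = 0;
        for (double v : x) sumSq += (v - mean) * (v - mean);
        return new double[] {mean, sumSq / x.length};
    }

    static double variance(double[] x, boolean biasCorrected) {
        double populationVariance = moments(x)[1];
        int n = x.length;
        return biasCorrected ? populationVariance * n / (n - 1) : populationVariance;
    }

    static double standardDeviation(double[] x, boolean biasCorrected) {
        return Math.sqrt(variance(x, biasCorrected));
    }
}
```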
public SDVariable tile(SDVariable iX, int[] repeat)
iX
- the input variable
repeat
- the number of repetitions along each dimension
public SDVariable fill(SDVariable shape, double value)
public SDVariable dropout(SDVariable input, double p)
public SDVariable dropout(String name, SDVariable input, double p)
public SDVariable xwPlusB(SDVariable input, SDVariable weights, SDVariable bias)
public SDVariable xwPlusB(String name, SDVariable input, SDVariable weights, SDVariable bias)
public SDVariable reluLayer(SDVariable input, SDVariable weights, SDVariable bias)
public SDVariable reluLayer(String name, SDVariable input, SDVariable weights, SDVariable bias)
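Going by their names, xwPlusB and reluLayer describe a dense layer. The sketch assumes xwPlusB computes x·W + b, with x of shape [batch][in], W of shape [in][out] and b of shape [out], and that reluLayer applies max(0, ·) to that result. It uses plain nested arrays; `DenseLayerSketch` is a hypothetical name.

```java
// Dense layer: xwPlusB = x·W + b; reluLayer = max(0, x·W + b).
class DenseLayerSketch {
    static double[][] xwPlusB(double[][] x, double[][] w, double[] b) {
        int batch = x.length, in = w.length, out = b.length;
        double[][] y = new double[batch][out];
        for (int i = 0; i < batch; i++) {
            for (int j = 0; j < out; j++) {
                double s = b[j]; // start from the bias, then accumulate the dot product
                for (int k = 0; k < in; k++) s += x[i][k] * w[k][j];
                y[i][j] = s;
            }
        }
        return y;
    }

    static double[][] reluLayer(double[][] x, double[][] w, double[] b) {
        double[][] y = xwPlusB(x, w, b);
        for (double[] row : y) {
            for (int j = 0; j < row.length; j++) row[j] = Math.max(0.0, row[j]);
        }
        return y;
    }
}
```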
public SDVariable mmul(SDVariable x, SDVariable y, MMulTranspose transpose)
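x
- the first matrix
y
- the second matrix
transpose
- the transpose configuration for the inputs
public SDVariable mmul(SDVariable x, SDVariable y)
x
- the first matrix
y
- the second matrix
public SDVariable tensorMmul(SDVariable x, SDVariable y, int[][] dimensions)
x
- the first input
y
- the second input
dimensions
- the dimensions to contract over, for x and y respectively
public SDVariable cosineSimilarity(SDVariable iX, SDVariable i_y, int... dimensions)
iX
- the first input
i_y
- the second input
dimensions
- the dimensions to reduce along
public SDVariable euclideanDistance(SDVariable iX, SDVariable i_y, int... dimensions)
iX
- the first input
i_y
- the second input
dimensions
- the dimensions to reduce along
public SDVariable manhattanDistance(SDVariable iX, SDVariable i_y, int... dimensions)
iX
- the first input
i_y
- the second input
dimensions
- the dimensions to reduce along
public SDVariable cosineDistance(SDVariable ix, SDVariable iy, int... dimensions)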
public SDVariable cosineDistance(String name, SDVariable ix, SDVariable iy, int... dimensions)
public SDVariable hammingDistance(SDVariable ix, SDVariable iy, int... dimensions)
public SDVariable hammingDistance(String name, SDVariable ix, SDVariable iy, int... dimensions)
public SDVariable jaccardDistance(SDVariable ix, SDVariable iy, int... dimensions)
public SDVariable jaccardDistance(String name, SDVariable ix, SDVariable iy, int... dimensions)
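The similarity and distance reductions above compare two inputs. The sketch reduces over a whole vector (no dimensions argument) and covers four of them. Cosine similarity is a·b / (|a||b|). Cosine distance is assumed to be 1 minus cosine similarity. Euclidean distance is the L2 norm of a - b, and Manhattan distance is the L1 norm of a - b. `DistanceSketch` is a hypothetical name; Hamming and Jaccard distances are left out.

```java
// Whole-vector similarity/distance reductions.
class DistanceSketch {
    static double dot(double[] a, double[] b) {
        double s = 0;
        for (int i = 0; i < a.length; i++) s += a[i] * b[i];
        return s;
    }

    static double cosineSimilarity(double[] a, double[] b) {
        return dot(a, b) / (Math.sqrt(dot(a, a)) * Math.sqrt(dot(b, b)));
    }

    // Assumed definition: 1 - cosine similarity.
    static double cosineDistance(double[] a, double[] b) {
        return 1.0 - cosineSimilarity(a, b);
    }

    static double euclideanDistance(double[] a, double[] b) {
        double s = 0;
        for (int i = 0; i < a.length; i++) s += (a[i] - b[i]) * (a[i] - b[i]);
        return Math.sqrt(s);
    }

    static double manhattanDistance(double[] a, double[] b) {
        double s = 0;
        for (int i = 0; i < a.length; i++) s += Math.abs(a[i] - b[i]);
        return s;
    }
}
```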
public SDVariable lossBinaryXENT(SDVariable iX, SDVariable i_y, int... dimensions)
iX
- the first input
i_y
- the second input
dimensions
- the dimensions to reduce along
public SDVariable lossCosineSimilarity(SDVariable iX, SDVariable i_y, int... dimensions)
iX
- the first input
i_y
- the second input
dimensions
- the dimensions to reduce along
public SDVariable lossHinge(SDVariable iX, SDVariable i_y, int... dimensions)
iX
- the first input
i_y
- the second input
dimensions
- the dimensions to reduce along
public SDVariable lossKLD(SDVariable iX, SDVariable i_y, int... dimensions)
iX
- the first input
i_y
- the second input
dimensions
- the dimensions to reduce along
public SDVariable lossL1(SDVariable iX, SDVariable i_y, int... dimensions)
iX
- the first input
i_y
- the second input
dimensions
- the dimensions to reduce along
public SDVariable lossL2(SDVariable iX, SDVariable i_y, int... dimensions)
iX
- the first input
i_y
- the second input
dimensions
- the dimensions to reduce along
public SDVariable lossMAE(SDVariable iX, SDVariable i_y, int... dimensions)
iX
- the first input
i_y
- the second input
dimensions
- the dimensions to reduce along
public SDVariable lossMSE(SDVariable iX, SDVariable i_y, int... dimensions)
iX
- the first input
i_y
- the second input
dimensions
- the dimensions to reduce along
public SDVariable lossMCXENT(SDVariable iX, SDVariable i_y, int... dimensions)
iX
- the first input
i_y
- the second input
dimensions
- the dimensions to reduce along
public SDVariable lossMSLE(SDVariable iX, SDVariable i_y, int... dimensions)
iX
- the first input
i_y
- the second input
dimensions
- the dimensions to reduce along
public SDVariable lossNegativeLogLikelihood(SDVariable iX, SDVariable i_y, int... dimensions)
iX
- the first input
i_y
- the second input
dimensions
- the dimensions to reduce along
public SDVariable lossPoisson(SDVariable iX, SDVariable i_y, int... dimensions)
iX
- the first input
i_y
- the second input
dimensions
- the dimensions to reduce along
public SDVariable lossSquaredHinge(SDVariable iX, SDVariable i_y, int... dimensions)
iX
- the first input
i_y
- the second input
dimensions
- the dimensions to reduce along
public SDVariable gradientBackwardsMarker(String name, SDVariable iX)
name
- name of the operation in SameDiff
iX
- the input variable
public SDVariable neq(String name, SDVariable iX, double iy)
iX
- the input variable
public SDVariable eq(String name, SDVariable iX, double iy)
iX
- the input variable
public SDVariable gte(String name, SDVariable iX, double iy)
iX
- the input variable
public SDVariable lte(String name, SDVariable iX, double iy)
iX
- the input variable
public SDVariable gt(String name, SDVariable iX, double iy)
iX
- the input variable
public SDVariable lt(String name, SDVariable iX, double iy)
iX
- the input variable
public SDVariable neq(String name, SDVariable iX, SDVariable iy)
iX
- the first input variable
public SDVariable eq(String name, SDVariable iX, SDVariable iy)
iX
- the first input variable
public SDVariable gte(String name, SDVariable iX, SDVariable iy)
iX
- the first input variable
public SDVariable lte(String name, SDVariable iX, SDVariable iy)
iX
- the first input variable
public SDVariable gt(String name, SDVariable iX, SDVariable iy)
iX
- the first input variable
public SDVariable lt(String name, SDVariable iX, SDVariable iy)
iX
- the first input variable
public SDVariable or(String name, SDVariable iX, SDVariable iy)
iX
- the first input variable
public SDVariable neg(String name, SDVariable iX)
iX
- the input variable
public SDVariable isNonDecreasing(SDVariable iX)
iX
- the input variable
public SDVariable isNonDecreasing(String name, SDVariable iX)
iX
- the input variable
public SDVariable isStrictlyIncreasing(SDVariable iX)
iX
- the input variable
public SDVariable isStrictlyIncreasing(String name, SDVariable iX)
iX
- the input variable
public SDVariable isNumericTensor(SDVariable iX)
public SDVariable isNumericTensor(String name, SDVariable iX)
public SDVariable cos(String name, SDVariable iX)
iX
- public SDVariable sin(String name, SDVariable iX)
iX
- public SDVariable tan(String name, SDVariable iX)
iX
- public SDVariable acos(String name, SDVariable iX)
iX
- public SDVariable asin(String name, SDVariable iX)
iX
- public SDVariable atan(String name, SDVariable iX)
iX
- public SDVariable cosh(String name, SDVariable iX)
iX
- public SDVariable sinh(String name, SDVariable iX)
iX
- public SDVariable tanh(String name, SDVariable iX)
iX
- public SDVariable acosh(String name, SDVariable iX)
iX
- public SDVariable asinh(String name, SDVariable iX)
iX
- public SDVariable atanh(String name, SDVariable iX)
iX
- public SDVariable exp(String name, SDVariable iX)
iX
- public SDVariable expm1(String name, SDVariable iX)
iX
- public SDVariable rsqrt(String name, SDVariable iX)
iX
- public SDVariable log(String name, SDVariable iX)
iX
- public SDVariable log1p(String name, SDVariable iX)
iX
- public SDVariable isFinite(String name, SDVariable iX)
iX
- public SDVariable isInfinite(String name, SDVariable iX)
iX
- public SDVariable isNaN(String name, SDVariable iX)
iX
- public SDVariable round(String name, SDVariable iX)
iX
- public SDVariable pow(String name, SDVariable iX, double value)
iX
- value
- public SDVariable cube(String name, SDVariable iX)
iX
- public SDVariable sqrt(String name, SDVariable iX)
iX
- public SDVariable square(String name, SDVariable iX)
iX
- public SDVariable floor(String name, SDVariable iX)
iX
- public SDVariable relu(String name, SDVariable iX, double cutoff)
iX
- public SDVariable relu6(String name, SDVariable iX, double cutoff)
iX
- public SDVariable softmax(String name, SDVariable iX)
iX
- public SDVariable softmaxDerivative(String name, SDVariable iX, SDVariable wrt)
iX
public SDVariable hardTanh(String name, SDVariable iX)
iX - the input variable
public SDVariable hardTanhDerivative(String name, SDVariable iX)
iX - the input variable
public SDVariable sigmoid(String name, SDVariable iX)
iX - the input variable
public SDVariable sigmoidDerivative(String name, SDVariable iX, SDVariable wrt)
iX - the input variable
public SDVariable sign(String name, SDVariable iX)
iX - the input variable
public SDVariable softsign(String name, SDVariable iX)
iX - the input variable
public SDVariable softsignDerivative(String name, SDVariable iX)
iX - the input variable
public SDVariable softplus(String name, SDVariable iX)
iX - the input variable
public SDVariable elu(String name, SDVariable iX)
iX - the input variable
public SDVariable eluDerivative(String name, SDVariable iX)
iX - the input variable
public SDVariable leakyRelu(String name, SDVariable iX, double alpha)
iX - the input variable
alpha - the slope applied to negative inputs
public SDVariable leakyReluDerivative(String name, SDVariable iX, double alpha)
iX - the input variable
alpha - the slope applied to negative inputs
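The activation methods above are elementwise. As a rough reference for what each one computes on a single value, the sketch below implements the standard textbook definitions in plain Java. It assumes elu uses alpha = 1 and hardTanh clips to [-1, 1]; `Activations` is a hypothetical helper, not ND4J code, and relu/relu6 are left out because the exact role of their `cutoff` argument is not documented here.

```java
/** Scalar reference versions of common activations (textbook definitions, not ND4J code). */
final class Activations {
    private Activations() {}

    static double sigmoid(double x) {
        return 1.0 / (1.0 + Math.exp(-x));
    }

    // Derivative written in terms of the forward value: s * (1 - s)
    static double sigmoidDerivative(double x) {
        double s = sigmoid(x);
        return s * (1.0 - s);
    }

    static double softsign(double x) {
        return x / (1.0 + Math.abs(x));
    }

    // log(1 + e^x), rearranged so that large positive x does not overflow
    static double softplus(double x) {
        return Math.max(x, 0.0) + Math.log1p(Math.exp(-Math.abs(x)));
    }

    // ELU with alpha = 1 (assumed): e^x - 1 for non-positive x
    static double elu(double x) {
        return x > 0 ? x : Math.expm1(x);
    }

    // Negative inputs are scaled by alpha instead of being zeroed
    static double leakyRelu(double x, double alpha) {
        return x > 0 ? x : alpha * x;
    }

    static double leakyReluDerivative(double x, double alpha) {
        return x > 0 ? 1.0 : alpha;
    }

    // Clips into [-1, 1]
    static double hardTanh(double x) {
        return Math.max(-1.0, Math.min(1.0, x));
    }

    // Slope is 1 strictly inside (-1, 1) and 0 outside
    static double hardTanhDerivative(double x) {
        return (x > -1.0 && x < 1.0) ? 1.0 : 0.0;
    }
}
```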
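public SDVariable mean(String name, SDVariable iX)
iX - the input variable
public SDVariable mean(String name, SDVariable iX, int... dimension)
public SDVariable standardDeviation(String name, SDVariable iX, boolean biasCorrected, int... dimensions)
iX - the input variable
biasCorrected - if true, divide by N - 1 (bias-corrected sample estimate); if false, divide by N
dimensions - the dimensions to reduce along
public SDVariable variance(String name, SDVariable iX, boolean biasCorrected, int... dimensions)
iX - the input variable
biasCorrected - if true, divide by N - 1 (bias-corrected sample estimate); if false, divide by N
dimensions - the dimensions to reduce along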
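The biasCorrected flag on standardDeviation and variance selects between the population estimate (divide by N) and the Bessel-corrected sample estimate (divide by N - 1). A minimal plain-Java sketch of that difference over a single 1-D array follows; `MomentsSketch` is a hypothetical name and the sketch ignores the `dimensions` argument.

```java
/** Mean-centered variance and standard deviation of a 1-D array (illustrative sketch). */
final class MomentsSketch {
    private MomentsSketch() {}

    static double variance(double[] x, boolean biasCorrected) {
        double mean = 0.0;
        for (double v : x) mean += v;
        mean /= x.length;
        double sumSq = 0.0;
        for (double v : x) sumSq += (v - mean) * (v - mean);
        // Bessel's correction divides by n - 1 instead of n
        return sumSq / (biasCorrected ? x.length - 1 : x.length);
    }

    static double standardDeviation(double[] x, boolean biasCorrected) {
        return Math.sqrt(variance(x, biasCorrected));
    }
}
```

For `{2, 4, 4, 4, 5, 5, 7, 9}` the squared deviations from the mean (5) sum to 32, so the population variance is 32 / 8 = 4 (standard deviation 2) and the corrected variance is 32 / 7.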
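public SDVariable sum(String name, SDVariable iX, int... dimensions)
iX - the input variable
dimensions - the dimensions to reduce along
public SDVariable prod(String name, SDVariable iX, int... dimensions)
iX - the input variable
dimensions - the dimensions to reduce along
public SDVariable max(String name, SDVariable iX, int... dimensions)
iX - the input variable
dimensions - the dimensions to reduce along
public SDVariable min(String name, SDVariable iX, int... dimensions)
iX - the input variable
dimensions - the dimensions to reduce along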
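The `dimensions` argument of sum, prod, max and min names the axes that are collapsed. Assuming the usual ND4J convention, reducing a 2-D input along dimension 0 leaves one value per column, and along dimension 1 one value per row. The hypothetical `ReduceSketch` below illustrates this for a single dimension on a rectangular 2-D plain-Java array.

```java
import java.util.Arrays;
import java.util.function.DoubleBinaryOperator;

/** Single-dimension reductions over a rectangular 2-D array (illustrative sketch, not ND4J code). */
final class ReduceSketch {
    private ReduceSketch() {}

    // Folds x along `dimension` with `op`, starting each accumulator at `identity`
    static double[] reduce(double[][] x, int dimension, double identity, DoubleBinaryOperator op) {
        if (dimension != 0 && dimension != 1) {
            throw new IllegalArgumentException("dimension must be 0 or 1 for a 2-D array");
        }
        int rows = x.length;
        int cols = x[0].length;
        double[] out = new double[dimension == 0 ? cols : rows];
        Arrays.fill(out, identity);
        for (int i = 0; i < rows; i++) {
            for (int j = 0; j < cols; j++) {
                // Collapsing dimension 0 keeps the column index; collapsing dimension 1 keeps the row index
                int k = (dimension == 0) ? j : i;
                out[k] = op.applyAsDouble(out[k], x[i][j]);
            }
        }
        return out;
    }

    static double[] sum(double[][] x, int dimension) {
        return reduce(x, dimension, 0.0, Double::sum);
    }

    static double[] prod(double[][] x, int dimension) {
        return reduce(x, dimension, 1.0, (a, b) -> a * b);
    }

    static double[] max(double[][] x, int dimension) {
        return reduce(x, dimension, Double.NEGATIVE_INFINITY, Math::max);
    }

    static double[] min(double[][] x, int dimension) {
        return reduce(x, dimension, Double.POSITIVE_INFINITY, Math::min);
    }
}
```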
public SDVariable norm1(String name, SDVariable ix, int... dimensions)
public SDVariable norm2(String name, SDVariable ix, int... dimensions)
public SDVariable normmax(String name, SDVariable ix, int... dimensions)
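norm1, norm2 and normmax reduce their input to the L1 norm (sum of absolute values), the L2 norm (Euclidean length) and the max-abs norm (largest absolute value). Assuming these are the standard definitions, the hypothetical `NormSketch` below computes each over a whole 1-D array and ignores the `dimensions` argument.

```java
/** L1, L2 and max-abs norms of a 1-D array (illustrative sketch, not ND4J code). */
final class NormSketch {
    private NormSketch() {}

    // Sum of absolute values
    static double norm1(double[] x) {
        double s = 0.0;
        for (double v : x) s += Math.abs(v);
        return s;
    }

    // Square root of the sum of squares
    static double norm2(double[] x) {
        double s = 0.0;
        for (double v : x) s += v * v;
        return Math.sqrt(s);
    }

    // Largest absolute value
    static double normmax(double[] x) {
        double m = 0.0;
        for (double v : x) m = Math.max(m, Math.abs(v));
        return m;
    }
}
```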
public SDVariable reshape(String name, SDVariable iX, int... shape)
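iX - the input variable
shape - the new shape
public SDVariable transpose(String name, SDVariable iX)
iX - the input variable
public SDVariable permute(String name, SDVariable iX, int... dimensions)
iX - the input variable
dimensions - the new order of the dimensions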
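reshape must keep the total number of elements, and permute reorders dimensions so that output dimension i has the size of input dimension `dimensions[i]`; a 2-D transpose is the permutation {1, 0}. The shape-only sketch below checks those rules on plain `int[]` shapes. `ShapeOpsSketch` and `transpose2d` are hypothetical names, and the sketch requires fully specified shapes.

```java
/** Shape bookkeeping for reshape, permute and transpose (illustrative; not ND4J code). */
final class ShapeOpsSketch {
    private ShapeOpsSketch() {}

    static long length(int[] shape) {
        long n = 1;
        for (int d : shape) n *= d;
        return n;
    }

    // A reshape is valid only if it keeps the total number of elements
    static int[] reshape(int[] shape, int[] newShape) {
        if (length(shape) != length(newShape)) {
            throw new IllegalArgumentException("reshape must preserve the number of elements");
        }
        return newShape.clone();
    }

    // Output dimension i takes its size from input dimension dimensions[i]
    static int[] permute(int[] shape, int[] dimensions) {
        if (dimensions.length != shape.length) {
            throw new IllegalArgumentException("need one entry per dimension");
        }
        boolean[] seen = new boolean[shape.length];
        int[] out = new int[shape.length];
        for (int i = 0; i < dimensions.length; i++) {
            int d = dimensions[i];
            if (d < 0 || d >= shape.length || seen[d]) {
                throw new IllegalArgumentException("dimensions must be a permutation of 0..rank-1");
            }
            seen[d] = true;
            out[i] = shape[d];
        }
        return out;
    }

    // For a 2-D input, transpose is the permutation {1, 0}
    static int[] transpose2d(int[] shape) {
        return permute(shape, new int[]{1, 0});
    }
}
```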
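public SDVariable rollAxis(String name, SDVariable x, int axis)
x - the input variable
axis - the axis to roll
public SDVariable fill(String name, SDVariable shape, double value)
shape - a variable holding the shape of the output
value - the value to fill the output with
public SDVariable concat(String name, int dimension, SDVariable... inputs)
dimension - the dimension to concatenate along
inputs - the variables to concatenate
public SDVariable tile(String name, SDVariable iX, int[] repeat)
iX - the input variable
repeat - the number of repetitions along each dimension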
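concat joins its inputs along one dimension, and tile repeats its input `repeat[d]` times along each dimension d. The plain-Java sketch below shows both on 2-D arrays only; `ConcatTileSketch` is a hypothetical name, and the tile semantics are assumed to match the usual NumPy/TensorFlow behavior.

```java
/** 2-D concat and tile over plain arrays (illustrative sketch, not ND4J code). */
final class ConcatTileSketch {
    private ConcatTileSketch() {}

    // Joins inputs along dimension 0 (stack rows) or 1 (append columns)
    static double[][] concat(int dimension, double[][]... inputs) {
        if (dimension == 0) {
            int total = 0;
            for (double[][] a : inputs) total += a.length;
            double[][] out = new double[total][];
            int r = 0;
            for (double[][] a : inputs) {
                for (double[] row : a) out[r++] = row.clone();
            }
            return out;
        }
        if (dimension != 1) {
            throw new IllegalArgumentException("dimension must be 0 or 1");
        }
        // Along dimension 1 every input must have the same number of rows
        int rows = inputs[0].length;
        double[][] out = new double[rows][];
        for (int i = 0; i < rows; i++) {
            int width = 0;
            for (double[][] a : inputs) width += a[i].length;
            double[] row = new double[width];
            int c = 0;
            for (double[][] a : inputs) {
                System.arraycopy(a[i], 0, row, c, a[i].length);
                c += a[i].length;
            }
            out[i] = row;
        }
        return out;
    }

    // Repeats the input repeat[0] times down and repeat[1] times across
    static double[][] tile(double[][] in, int[] repeat) {
        int rows = in.length;
        int cols = in[0].length;
        double[][] out = new double[rows * repeat[0]][cols * repeat[1]];
        for (int i = 0; i < out.length; i++) {
            for (int j = 0; j < out[i].length; j++) {
                out[i][j] = in[i % rows][j % cols];
            }
        }
        return out;
    }
}
```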
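public SDVariable mmul(String name, SDVariable x, SDVariable y, MMulTranspose transpose)
x - the left-hand operand
y - the right-hand operand
transpose - the transpose configuration to apply
public SDVariable mmul(String name, SDVariable x, SDVariable y)
x - the left-hand operand
y - the right-hand operand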
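With the four-argument mmul, MMulTranspose controls whether operands are transposed before multiplying. The plain-Java sketch below models only the operand transposes with two booleans (`transposeX` and `transposeY`, hypothetical names); it does not reproduce the fields of the real MMulTranspose class, and `MatMulSketch` is not ND4J code.

```java
/** Dense matrix product with optional operand transposes (illustrative sketch). */
final class MatMulSketch {
    private MatMulSketch() {}

    static double[][] transpose(double[][] m) {
        double[][] t = new double[m[0].length][m.length];
        for (int i = 0; i < m.length; i++) {
            for (int j = 0; j < m[0].length; j++) {
                t[j][i] = m[i][j];
            }
        }
        return t;
    }

    // Computes op(x) * op(y), where op transposes its argument when the matching flag is set
    static double[][] mmul(double[][] x, double[][] y, boolean transposeX, boolean transposeY) {
        double[][] a = transposeX ? transpose(x) : x;
        double[][] b = transposeY ? transpose(y) : y;
        int n = a.length, k = a[0].length, m = b[0].length;
        if (b.length != k) {
            throw new IllegalArgumentException("inner dimensions do not match");
        }
        double[][] out = new double[n][m];
        for (int i = 0; i < n; i++) {
            for (int p = 0; p < k; p++) {
                for (int j = 0; j < m; j++) {
                    out[i][j] += a[i][p] * b[p][j];
                }
            }
        }
        return out;
    }
}
```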
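public SDVariable tensorMmul(String name, SDVariable x, SDVariable y, int[][] dimensions)
x - the first input
y - the second input
dimensions - the dimensions of x and y to contract over
public SDVariable cosineSimilarity(String name, SDVariable iX, SDVariable i_y, int... dimensions)
iX - the first input
i_y - the second input
dimensions - the dimensions to compute the similarity along
public SDVariable euclideanDistance(String name, SDVariable iX, SDVariable i_y, int... dimensions)
iX - the first input
i_y - the second input
dimensions - the dimensions to compute the distance along
public SDVariable manhattanDistance(String name, SDVariable iX, SDVariable i_y, int... dimensions)
iX - the first input
i_y - the second input
dimensions - the dimensions to compute the distance along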
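cosineSimilarity, euclideanDistance and manhattanDistance compare two inputs along the given dimensions. For a single pair of 1-D vectors the standard definitions are sketched below in plain Java; `DistanceSketch` is hypothetical, and reducing over the whole vector stands in for the `dimensions` argument.

```java
/** Vector similarity and distance measures over whole 1-D arrays (illustrative sketch). */
final class DistanceSketch {
    private DistanceSketch() {}

    // dot(x, y) / (|x| * |y|); NaN if either vector is all zeros
    static double cosineSimilarity(double[] x, double[] y) {
        double dot = 0.0, nx = 0.0, ny = 0.0;
        for (int i = 0; i < x.length; i++) {
            dot += x[i] * y[i];
            nx += x[i] * x[i];
            ny += y[i] * y[i];
        }
        return dot / (Math.sqrt(nx) * Math.sqrt(ny));
    }

    // Square root of the summed squared differences
    static double euclideanDistance(double[] x, double[] y) {
        double s = 0.0;
        for (int i = 0; i < x.length; i++) {
            double d = x[i] - y[i];
            s += d * d;
        }
        return Math.sqrt(s);
    }

    // Sum of absolute differences
    static double manhattanDistance(double[] x, double[] y) {
        double s = 0.0;
        for (int i = 0; i < x.length; i++) {
            s += Math.abs(x[i] - y[i]);
        }
        return s;
    }
}
```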
public SDVariable sigmoidCrossEntropyWithLogits(SDVariable logits, SDVariable weights, SDVariable labels, int reductionMode, double labelSmoothing)
public SDVariable sigmoidCrossEntropyWithLogits(String name, SDVariable logits, SDVariable weights, SDVariable labels, int reductionMode, double labelSmoothing)
public SDVariable softmaxCrossEntropyWithLogits(SDVariable logits, SDVariable weights, SDVariable labels, int reductionMode, double labelSmoothing)
public SDVariable softmaxCrossEntropyWithLogits(String name, SDVariable logits, SDVariable weights, SDVariable labels, int reductionMode, double labelSmoothing)
public SDVariable weightedCrossEntropyWithLogits(SDVariable targets, SDVariable inputs, SDVariable weights)
public SDVariable weightedCrossEntropyWithLogits(String name, SDVariable targets, SDVariable inputs, SDVariable weights)
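sigmoidCrossEntropyWithLogits takes raw logits rather than probabilities. A common numerically stable form of the per-element loss, used by TensorFlow's sigmoid_cross_entropy_with_logits, is max(x, 0) - x*z + log(1 + exp(-|x|)) for logit x and label z. The sketch below assumes the ND4J op follows that formula, with labelSmoothing s moving each label toward 0.5 as z*(1 - s) + 0.5*s, and shows a weighted mean as one possible reduction. How the integer reductionMode values map to reductions is not documented here, so the sketch does not model it; `SigmoidXentSketch` is hypothetical.

```java
/** Stable sigmoid cross-entropy on logits, TensorFlow-style (illustrative; assumed to match the ND4J op). */
final class SigmoidXentSketch {
    private SigmoidXentSketch() {}

    // max(x, 0) - x*z + log(1 + exp(-|x|)) equals -z*log(sigmoid(x)) - (1-z)*log(1 - sigmoid(x)),
    // but never evaluates exp() of a large positive number
    static double loss(double x, double z) {
        return Math.max(x, 0.0) - x * z + Math.log1p(Math.exp(-Math.abs(x)));
    }

    // One possible reduction: weighted mean over all elements, with labels smoothed toward 0.5
    static double weightedMean(double[] logits, double[] weights, double[] labels, double labelSmoothing) {
        double total = 0.0;
        double weightSum = 0.0;
        for (int i = 0; i < logits.length; i++) {
            double z = labels[i] * (1.0 - labelSmoothing) + 0.5 * labelSmoothing;
            total += weights[i] * loss(logits[i], z);
            weightSum += weights[i];
        }
        return total / weightSum;
    }
}
```

For a logit of -1000 with label 1 the stable form returns 1000, whereas computing -log(sigmoid(-1000)) directly gives infinity because the sigmoid underflows to 0.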
public SDVariable lossBinaryXENT(String name, SDVariable iX, SDVariable i_y, int... dimensions)
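iX - the first input
i_y - the second input
dimensions - the dimensions to compute the loss along
public SDVariable lossCosineSimilarity(String name, SDVariable iX, SDVariable i_y, int... dimensions)
iX - the first input
i_y - the second input
dimensions - the dimensions to compute the loss along
public SDVariable lossHinge(String name, SDVariable iX, SDVariable i_y, int... dimensions)
iX - the first input
i_y - the second input
dimensions - the dimensions to compute the loss along
public SDVariable lossKLD(String name, SDVariable iX, SDVariable i_y, int... dimensions)
iX - the first input
i_y - the second input
dimensions - the dimensions to compute the loss along
public SDVariable lossL1(String name, SDVariable iX, SDVariable i_y, int... dimensions)
iX - the first input
i_y - the second input
dimensions - the dimensions to compute the loss along
public SDVariable lossL2(String name, SDVariable iX, SDVariable i_y, int... dimensions)
iX - the first input
i_y - the second input
dimensions - the dimensions to compute the loss along
public SDVariable lossMAE(String name, SDVariable iX, SDVariable i_y, int... dimensions)
iX - the first input
i_y - the second input
dimensions - the dimensions to compute the loss along
public SDVariable lossMSE(String name, SDVariable iX, SDVariable i_y, int... dimensions)
iX - the first input
i_y - the second input
dimensions - the dimensions to compute the loss along
public SDVariable lossMCXENT(String name, SDVariable iX, SDVariable i_y, int... dimensions)
iX - the first input
i_y - the second input
dimensions - the dimensions to compute the loss along
public SDVariable lossMSLE(String name, SDVariable iX, SDVariable i_y, int... dimensions)
iX - the first input
i_y - the second input
dimensions - the dimensions to compute the loss along
public SDVariable lossNegativeLogLikelihood(String name, SDVariable iX, SDVariable i_y, int... dimensions)
iX - the first input
i_y - the second input
dimensions - the dimensions to compute the loss along
public SDVariable lossPoisson(String name, SDVariable iX, SDVariable i_y, int... dimensions)
iX - the first input
i_y - the second input
dimensions - the dimensions to compute the loss along
public SDVariable lossSquaredHinge(String name, SDVariable iX, SDVariable i_y, int... dimensions)
iX - the first input
i_y - the second input
dimensions - the dimensions to compute the loss along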
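The loss methods above take two inputs (iX and i_y) plus reduction dimensions, but this page does not say which input holds predictions and which holds labels. The hypothetical `LossSketch` below names them explicitly and shows textbook definitions of four of the losses, averaged over a 1-D array; it assumes ND4J's versions match these formulas, which this page does not confirm.

```java
/** Textbook losses averaged over a 1-D array (illustrative; not ND4J's reduction rules). */
final class LossSketch {
    private LossSketch() {}

    // Mean squared error
    static double mse(double[] pred, double[] labels) {
        double s = 0.0;
        for (int i = 0; i < pred.length; i++) {
            double d = pred[i] - labels[i];
            s += d * d;
        }
        return s / pred.length;
    }

    // Mean absolute error
    static double mae(double[] pred, double[] labels) {
        double s = 0.0;
        for (int i = 0; i < pred.length; i++) {
            s += Math.abs(pred[i] - labels[i]);
        }
        return s / pred.length;
    }

    // Hinge loss; labels are expected to be -1 or +1
    static double hinge(double[] pred, double[] labels) {
        double s = 0.0;
        for (int i = 0; i < pred.length; i++) {
            s += Math.max(0.0, 1.0 - labels[i] * pred[i]);
        }
        return s / pred.length;
    }

    // Squared hinge loss; labels are expected to be -1 or +1
    static double squaredHinge(double[] pred, double[] labels) {
        double s = 0.0;
        for (int i = 0; i < pred.length; i++) {
            double h = Math.max(0.0, 1.0 - labels[i] * pred[i]);
            s += h * h;
        }
        return s / pred.length;
    }
}
```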
public SDVariable expandDims(SDVariable ix, int axis)
public SDVariable expandDims(String name, SDVariable ix, int axis)
public SDVariable squeeze(SDVariable ix, int axis)
public SDVariable squeeze(String name, SDVariable ix, int axis)
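expandDims inserts a size-1 dimension at `axis`, and squeeze removes a size-1 dimension at `axis`. The shape-only sketch below assumes those standard semantics and handles only non-negative axes; `DimsSketch` is a hypothetical name, and whether the SameDiff methods accept negative axes is not stated here.

```java
/** Shape effect of expandDims and squeeze for a non-negative axis (illustrative sketch). */
final class DimsSketch {
    private DimsSketch() {}

    // Inserts a size-1 dimension at position axis (0 <= axis <= rank)
    static int[] expandDims(int[] shape, int axis) {
        if (axis < 0 || axis > shape.length) {
            throw new IllegalArgumentException("axis out of range");
        }
        int[] out = new int[shape.length + 1];
        for (int i = 0, j = 0; i < out.length; i++) {
            out[i] = (i == axis) ? 1 : shape[j++];
        }
        return out;
    }

    // Removes dimension axis, which must have size 1
    static int[] squeeze(int[] shape, int axis) {
        if (axis < 0 || axis >= shape.length || shape[axis] != 1) {
            throw new IllegalArgumentException("can only squeeze a size-1 dimension");
        }
        int[] out = new int[shape.length - 1];
        for (int i = 0, j = 0; i < shape.length; i++) {
            if (i != axis) out[j++] = shape[i];
        }
        return out;
    }
}
```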
public SDVariable confusionMatrix(SDVariable labels, SDVariable predictions)
public SDVariable confusionMatrix(String name, SDVariable labels, SDVariable pred)
public SDVariable confusionMatrix(SDVariable labels, SDVariable pred, Integer numClasses)
public SDVariable confusionMatrix(String name, SDVariable labels, SDVariable pred, Integer numClasses)
public SDVariable confusionMatrix(SDVariable labels, SDVariable pred, SDVariable weights)
public SDVariable confusionMatrix(String name, SDVariable labels, SDVariable pred, SDVariable weights)
public SDVariable confusionMatrix(SDVariable labels, SDVariable pred, Integer numClasses, SDVariable weights)
public SDVariable confusionMatrix(String name, SDVariable labels, SDVariable pred, Integer numClasses, SDVariable weights)
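confusionMatrix tallies how often each true label was predicted as each class. The sketch below assumes the TensorFlow convention (rows are true labels, columns are predictions), infers the class count as the largest id plus one when numClasses is null, and adds the per-example weight instead of 1 when weights are given. `ConfusionSketch` is hypothetical and works on plain int arrays rather than SDVariables.

```java
/** Confusion-matrix tally over integer class ids (illustrative; row/column convention assumed). */
final class ConfusionSketch {
    private ConfusionSketch() {}

    // Rows are true labels, columns are predictions; each pair adds its weight, or 1 if weights is null
    static double[][] confusionMatrix(int[] labels, int[] pred, Integer numClasses, double[] weights) {
        if (labels.length != pred.length) {
            throw new IllegalArgumentException("labels and pred must have the same length");
        }
        int n;
        if (numClasses != null) {
            n = numClasses;
        } else {
            // Infer the class count from the largest id seen in either array
            int maxId = -1;
            for (int i = 0; i < labels.length; i++) {
                maxId = Math.max(maxId, Math.max(labels[i], pred[i]));
            }
            n = maxId + 1;
        }
        double[][] m = new double[n][n];
        for (int i = 0; i < labels.length; i++) {
            m[labels[i]][pred[i]] += (weights == null) ? 1.0 : weights[i];
        }
        return m;
    }
}
```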
public void addVariable(SDVariable variable)
variable - the variable to add
public String generateNewVarName(String baseName, int argIndex)
baseName - the base name to use (use function.opName(), where function is a DifferentialFunction)
argIndex - the argument index
public SDVariable lstm(String baseName, LSTMCellConfiguration configuration)
baseName - the base name for the outputs
configuration - the configuration to use
public SDVariable sruCell(SRUCellConfiguration configuration)
configuration - the configuration for the SRU cell
public SDVariable sru(SRUConfiguration configuration)
configuration - the configuration for the SRU
public SDVariable gru(GRUCellConfiguration configuration)
configuration - the configuration to use
public SDVariable sruCell(String baseName, SRUCellConfiguration configuration)
baseName - the base name to use for the output variables
configuration - the configuration for the SRU cell
public SDVariable sru(String baseName, SRUConfiguration configuration)
baseName - the base name to use for the output variables
configuration - the configuration for the SRU
public SDVariable gru(String baseName, GRUCellConfiguration configuration)
baseName - the base name for the GRU cell
configuration - the configuration to use
public SDVariable slice(SDVariable input, int[] begin, int[] size)
public SDVariable slice(String name, SDVariable input, int[] begin, int[] size)
public SDVariable stridedSlice(SDVariable input, int[] begin, int[] end, int[] strides)
public SDVariable stridedSlice(String name, SDVariable input, int[] begin, int[] end, int[] strides)
public SDVariable stridedSlice(SDVariable in, int[] begin, int[] end, int[] strides, int beginMask, int endMask, int ellipsisMask, int newAxisMask, int shrinkAxisMask)
public SDVariable stridedSlice(String name, SDVariable in, int[] begin, int[] end, int[] strides, int beginMask, int endMask, int ellipsisMask, int newAxisMask, int shrinkAxisMask)
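slice takes `size[d]` elements starting at `begin[d]`, while stridedSlice steps from `begin` toward an exclusive `end` by `strides`, with bit masks that can override begin and end per axis. The sketch below assumes TensorFlow-style semantics and works on a single 1-D axis with positive strides; `SliceSketch` is hypothetical, and ellipsisMask, newAxisMask and shrinkAxisMask are not modeled.

```java
import java.util.Arrays;

/** 1-D slice and strided slice (illustrative sketch of TensorFlow-style semantics, positive strides only). */
final class SliceSketch {
    private SliceSketch() {}

    // Takes size consecutive elements starting at begin
    static double[] slice(double[] in, int begin, int size) {
        if (begin < 0 || size < 0 || begin + size > in.length) {
            throw new IllegalArgumentException("slice out of bounds");
        }
        return Arrays.copyOfRange(in, begin, begin + size);
    }

    // Bit 0 of beginMask/endMask means "ignore begin/end and use the full range" for this axis
    static double[] stridedSlice(double[] in, int begin, int end, int stride, int beginMask, int endMask) {
        if (stride <= 0) {
            throw new IllegalArgumentException("this sketch supports positive strides only");
        }
        int n = in.length;
        int b = ((beginMask & 1) != 0) ? 0 : begin;
        int e = ((endMask & 1) != 0) ? n : end;
        // Negative indices count from the end
        if (b < 0) b += n;
        if (e < 0) e += n;
        b = Math.max(0, Math.min(b, n));
        e = Math.max(0, Math.min(e, n));
        // end is exclusive; an empty range yields zero elements
        int count = Math.max(0, (e - b + stride - 1) / stride);
        double[] out = new double[count];
        for (int i = 0; i < count; i++) {
            out[i] = in[b + i * stride];
        }
        return out;
    }
}
```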
public SDVariable scatterAdd(String name, SDVariable ref, SDVariable indices, SDVariable updates)
public SDVariable scatterMul(String name, SDVariable ref, SDVariable indices, SDVariable updates)
public SDVariable scatterSub(String name, SDVariable ref, SDVariable indices, SDVariable updates)
public SDVariable scatterDiv(String name, SDVariable ref, SDVariable indices, SDVariable updates)
public SDVariable scatterAdd(SDVariable ref, SDVariable indices, SDVariable updates)
public SDVariable scatterMul(SDVariable ref, SDVariable indices, SDVariable updates)
public SDVariable scatterSub(SDVariable ref, SDVariable indices, SDVariable updates)
public SDVariable scatterDiv(SDVariable ref, SDVariable indices, SDVariable updates)
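The scatter methods combine `updates` into `ref` at the positions given by `indices`. The plain-Java sketch below assumes TensorFlow-style semantics for a 1-D ref with one scalar update per index, where repeated indices are applied in order; unlike the real ops, which may update ref in place, it returns a modified copy. `ScatterSketch` is a hypothetical name.

```java
import java.util.function.DoubleBinaryOperator;

/** Scatter updates on a 1-D ref with one scalar update per index (illustrative; returns a copy). */
final class ScatterSketch {
    private ScatterSketch() {}

    // Combines updates[i] into position indices[i]; repeated indices are applied in order
    static double[] scatter(double[] ref, int[] indices, double[] updates, DoubleBinaryOperator op) {
        if (indices.length != updates.length) {
            throw new IllegalArgumentException("indices and updates must have the same length");
        }
        double[] out = ref.clone();
        for (int i = 0; i < indices.length; i++) {
            int k = indices[i];
            out[k] = op.applyAsDouble(out[k], updates[i]);
        }
        return out;
    }

    static double[] scatterAdd(double[] ref, int[] indices, double[] updates) {
        return scatter(ref, indices, updates, (a, b) -> a + b);
    }

    static double[] scatterSub(double[] ref, int[] indices, double[] updates) {
        return scatter(ref, indices, updates, (a, b) -> a - b);
    }

    static double[] scatterMul(double[] ref, int[] indices, double[] updates) {
        return scatter(ref, indices, updates, (a, b) -> a * b);
    }

    static double[] scatterDiv(double[] ref, int[] indices, double[] updates) {
        return scatter(ref, indices, updates, (a, b) -> a / b);
    }
}
```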
public SDVariable[] generateOutputVariableForOp(DifferentialFunction function, String baseName)
function - the function to generate the output variable names for
public SDVariable[] generateOutputVariableForOp(DifferentialFunction function)
function - the function to generate the output variable names for
public SameDiff getFunction(String functionName)
functionName - the name of the function
public INDArray execAndEndResult(List<DifferentialFunction> ops)
public INDArray execAndEndResult()
public List<DifferentialFunction> exec(List<DifferentialFunction> ops)
ops - the list of already created ops
public While whileStatement(SameDiff.SameDiffConditional sameDiffConditional, SameDiff.SameDiffFunctionDefinition conditionBody, SameDiff.SameDiffFunctionDefinition loopBody, SDVariable[] inputVars)
sameDiffConditional - the conditional to evaluate
loopBody - the body to execute on each iteration
public If ifStatement(SameDiff.SameDiffConditional conditional, SameDiff.SameDiffFunctionDefinition conditionBody, SameDiff.SameDiffFunctionDefinition trueBody, SameDiff.SameDiffFunctionDefinition falseBody, SDVariable[] inputVars)
conditional - the conditional to evaluate
trueBody - the body to execute when the condition is true
falseBody - the body to execute when the condition is false
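whileStatement and ifStatement add control flow to the graph: a SameDiffConditional decides which path runs, and SameDiffFunctionDefinition objects supply the loop body or the two branches. The plain-Java sketch below only mirrors the eager meaning of those constructs with `java.util.function` types; `ControlFlowSketch`, `whileLoop` and `ifElse` are hypothetical names, the separate conditionBody argument is folded into a single predicate, and nothing here describes how SameDiff builds its While and If ops.

```java
import java.util.function.Predicate;
import java.util.function.UnaryOperator;

/** Eager analogue of graph-level while/if control flow (illustrative; not the SameDiff API). */
final class ControlFlowSketch {
    private ControlFlowSketch() {}

    // Re-evaluates the condition on the current loop variables before every iteration
    static double[] whileLoop(Predicate<double[]> condition, UnaryOperator<double[]> body, double[] inputVars) {
        double[] vars = inputVars.clone();
        while (condition.test(vars)) {
            vars = body.apply(vars);
        }
        return vars;
    }

    // Runs exactly one of the two branches on the same inputs
    static double[] ifElse(Predicate<double[]> condition, UnaryOperator<double[]> trueBody,
                           UnaryOperator<double[]> falseBody, double[] inputVars) {
        return condition.test(inputVars) ? trueBody.apply(inputVars) : falseBody.apply(inputVars);
    }
}
```

For example, doubling `{1}` while the value is below 10 passes through 2, 4 and 8 and stops at 16.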
public SDVariable invokeFunctionOn(String functionName, SameDiff with)
functionName - the name of the function to invoke
with - the SameDiff instance to use
public SameDiff defineFunction(String function, SameDiff.SameDiffFunctionDefinition functionDefinition, SDVariable[] variables)
function - the name of the function
public void defineFunction(String function, SameDiff.SameDiffFunctionDefinition functionDefinition)
function - the name of the function
public void defineFunction(String function, SameDiff.SameDiffFunctionDefinition functionDefinition, Map<String,INDArray> inputs)
function - the name of the function
functionDefinition - the function definition
inputs - the input arrays
public INDArray execAndEndResult(String functionName)
functionName - the name of the function to invoke
public org.nd4j.linalg.primitives.Pair<Map<SDVariable,DifferentialFunction>,List<DifferentialFunction>> exec(String functionName)
functionName - the name of the function to invoke
public List<DifferentialFunction> exec(String functionName, List<DifferentialFunction> cachedOps)
functionName - the name of the function to execute
cachedOps - the cached operations
public org.nd4j.linalg.primitives.Pair<Map<SDVariable,DifferentialFunction>,List<DifferentialFunction>> execBackwards()
public INDArray execBackwardAndEndResult()
public INDArray execWithPlaceHolderAndEndResult(Map<String,INDArray> inputs)
public void setOriginalPlaceHolderShape(String variableName, int[] shape)
Note that if isPlaceHolder(String) returns false for the passed-in vertex id, an ND4JIllegalStateException is thrown. The vertex id must be added first, which you can do with addAsPlaceHolder(String).
variableName - the vertex id to set the original shape for
shape - the shape of the placeholder
public int[] getOriginalShapeForPlaceHolder(String varName)
Intended for use with resolveVariablesWith(Map), usually when executing with execWithPlaceHolder(Map).
varName - the vertex id to get the original shape for
public boolean isPlaceHolder(String varName)
varName - the vertex id to test
public void addAsPlaceHolder(String varName)
varName - the vertex id to add
public void resolveVariablesWith(Map<String,INDArray> arrays)
An IllegalStateException is thrown if not all arrays are specified for resolution.
arrays - the arrays to resolve
public boolean allPlaceHolderVariablesResolved()
A placeholder is considered resolved when getVariable(String).getArr() does not return null and its shape is properly resolved.
public void putPlaceHolderForVariable(String varName, String... placeHolderVariables)
Note that if a vertex id in placeHolderVariables isn't present in this SameDiff instance, an ND4JIllegalStateException is thrown.
varName - the vertex id to add placeholders for
placeHolderVariables - the placeholder variables
public boolean hasPlaceHolderVariables(String vertexId)
vertexId - the vertex id to check for
public List<String[]> getPlaceHoldersFor(String varName)
Consider using hasPlaceHolderVariables(String) first.
varName - the vertex id to get the placeholders for
public org.nd4j.linalg.primitives.Pair<Map<SDVariable,DifferentialFunction>,List<DifferentialFunction>> execWithPlaceHolder(Map<String,INDArray> inputs)
resolveVariablesWith(Map) is called with the given inputs.
public List<SDVariable> getVariablesAssociatedWithFunctions(List<DifferentialFunction> functions)
Returns the SDVariable associated with each function, based on DifferentialFunction.outputVariables().
functions - the functions to get the variables for
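The placeholder methods documented above follow a register-then-resolve lifecycle: a name is registered with addAsPlaceHolder, may be given an original shape, and must receive an array before execution. The hypothetical `PlaceholderRegistry` below models only that bookkeeping in plain Java, with `double[]` standing in for INDArray and IllegalStateException standing in for ND4JIllegalStateException; it is not how SameDiff stores placeholders.

```java
import java.util.HashMap;
import java.util.Map;

/** Hypothetical model of the placeholder lifecycle described above; not SameDiff's implementation. */
final class PlaceholderRegistry {
    // Registered placeholder names mapped to their original shape (null until set)
    private final Map<String, int[]> originalShapes = new HashMap<>();
    // Arrays supplied for resolution; double[] stands in for INDArray
    private final Map<String, double[]> resolved = new HashMap<>();

    void addAsPlaceHolder(String varName) {
        originalShapes.put(varName, null);
    }

    boolean isPlaceHolder(String varName) {
        return originalShapes.containsKey(varName);
    }

    // Mirrors the documented rule: the name must be registered as a placeholder first
    void setOriginalPlaceHolderShape(String varName, int[] shape) {
        if (!isPlaceHolder(varName)) {
            throw new IllegalStateException(varName + " is not a placeholder");
        }
        originalShapes.put(varName, shape.clone());
    }

    int[] getOriginalShapeForPlaceHolder(String varName) {
        return originalShapes.get(varName);
    }

    // Mirrors the documented rule: every placeholder must receive an array
    void resolveVariablesWith(Map<String, double[]> arrays) {
        for (String name : originalShapes.keySet()) {
            if (!arrays.containsKey(name)) {
                throw new IllegalStateException("no array supplied for placeholder " + name);
            }
        }
        resolved.putAll(arrays);
    }

    boolean allPlaceHolderVariablesResolved() {
        return resolved.keySet().containsAll(originalShapes.keySet());
    }
}
```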
public SDVariable updateVariableNameAndReference(SDVariable varToUpdate, String newVarName)
Note that if null is passed in as the new variable name, the original input variable is returned.
varToUpdate - the variable to update
newVarName - the new variable name
public SDVariable[] updateVariableNamesAndReferences(SDVariable[] variablesToUpdate, String[] newVariableNames)
variablesToUpdate - the variables to update
newVariableNames - the new variable names
public org.nd4j.linalg.primitives.Pair<Map<SDVariable,DifferentialFunction>,List<DifferentialFunction>> exec()
public void printFunction(DifferentialFunction differentialFunction)
differentialFunction - the function to print
public static int[] permuteDataFormatForSameDiff(String dataFormat, boolean weights)
dataFormat - the data format to permute
public void updateVariable(String variableName, INDArray arr)
Updates the INDArray for the given variable name.
variableName - the variable to update
arr - the array to update with
protected int asFlatNode(String name, @NonNull SameDiff scope, @NonNull com.google.flatbuffers.FlatBufferBuilder bufferBuilder)
public static org.nd4j.linalg.primitives.Pair<String,Integer> parseVariable(@NonNull String varName)
varName - the variable name to parse
protected int asFlatNode(@NonNull DifferentialFunction node, @NonNull com.google.flatbuffers.FlatBufferBuilder bufferBuilder, List<SDVariable> variables, Map<String,Integer> reverseMap, Map<String,Integer> forwardMap, Map<String,Integer> framesMap, AtomicInteger idCounter)
public ByteBuffer asFlatBuffers(@NonNull ExecutorConfiguration configuration)
configuration - the ExecutorConfiguration to embed in the serialized graph
public ByteBuffer asFlatBuffers()
public static ByteOrder getOrderFromByte(byte val)
val - the byte to convert
public static byte getOrderAsByte()
public void asFlatFile(@NonNull File file) throws IOException
file - the file to write to
Throws: IOException
public void asFlatFile(@NonNull File file, @NonNull ExecutorConfiguration configuration) throws IOException
file - the file to write to
Throws: IOException
public String asFlatPrint()
public static DataBuffer.Type getDataTypeFromByte(byte val)
val - the byte to convert
public static byte getDataTypeAsByte(DataBuffer.Type type)
type - the data type to convert
public static long getOpNum(String name, Op.Type type)
name - the op name
type - the op type
public static Op.Type getTypeFromByte(byte type)
type - the byte to convert
public static byte getFlatOpType(Op.Type type)
type - the op type to convert
Copyright © 2018. All rights reserved.