public class SameDiff extends SDBaseOps

You define a graph symbolically, and that graph accumulates operations as you add them. To execute the graph, run one of the execution methods, such as output(Map, String...).
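As a minimal sketch of this define-then-execute flow (the variable names "input", "w" and "out" are illustrative, not part of the API):

```java
import java.util.Arrays;
import java.util.Collections;
import java.util.Map;

import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;

public class SameDiffIntroExample {
    public static void main(String[] args) {
        SameDiff sd = SameDiff.create();

        // Symbolic definition: no computation happens here
        SDVariable in  = sd.placeHolder("input", DataType.FLOAT, -1, 4);
        SDVariable w   = sd.var("w", DataType.FLOAT, 4, 2);
        SDVariable out = sd.math().tanh("out", in.mmul(w));

        // Execution: feed the placeholder and request the "out" variable
        Map<String, INDArray> results = sd.output(
                Collections.singletonMap("input", Nd4j.rand(DataType.FLOAT, 3, 4)),
                "out");
        System.out.println(Arrays.toString(results.get("out").shape()));
    }
}
```

The graph is only evaluated when output(...) is called; the definition lines merely record operations.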
Modifier and Type | Field and Description |
---|---|
SDBitwise | bitwise: Op creator object for bitwise operations |
SDCNN | cnn: Op creator object for convolutional neural network operations |
protected static String | GRAD_FN_KEY |
SDImage | image: Op creator object for image operations |
SDLoss | loss: Op creator object for loss function operations |
SDMath | math: Op creator object for math operations |
SDNN | nn: Op creator object for general neural network operations |
SDRandom | random: Op creator object for random number generation operations |
SDRNN | rnn: Op creator object for recurrent neural network operations |
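Each op-creator field is also exposed through a method of the same name (bitwise(), cnn(), math(), nn(), etc., listed below). A brief sketch of how the namespaces are used, with illustrative variable names:

```java
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.linalg.api.buffer.DataType;

public class OpCreatorExample {
    public static void main(String[] args) {
        SameDiff sd = SameDiff.create();
        SDVariable x = sd.placeHolder("x", DataType.FLOAT, -1, 10);

        // Each namespace groups related op creators
        SDVariable sq  = sd.math().square(x);   // elementwise math ops
        SDVariable act = sd.nn().relu(sq, 0.0); // neural network ops

        System.out.println(sd.hasVariable("x"));
    }
}
```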
Modifier and Type | Method and Description |
---|---|
void | addArgsFor(SDVariable[] variables, DifferentialFunction function): Adds incoming arguments for the specified differential function to the graph. |
void | addArgsFor(String[] variables, DifferentialFunction function): Adds incoming arguments for the specified differential function to the graph. |
void | addArgumentInterceptor(ArgumentInterceptor interceptor): Add a new argument interceptor to the interceptor stack. |
void | addListeners(Collection<? extends Listener> listeners) |
void | addListeners(Listener... listeners): Add SameDiff-wide Listener instances. |
void | addLossVariable(SDVariable variable) |
void | addLossVariable(String variableName): Mark the specified variable as a loss function variable. |
void | addOutgoingFor(SDVariable[] variables, DifferentialFunction function): Adds outgoing arguments to the graph for the specified DifferentialFunction. Also checks for input arguments and updates the graph, adding an appropriate edge when the full graph is declared. |
void | addOutgoingFor(String[] varNames, DifferentialFunction function): Adds outgoing arguments to the graph for the specified DifferentialFunction. Also checks for input arguments and updates the graph, adding an appropriate edge when the full graph is declared. |
void | addPropertyToResolve(DifferentialFunction forFunction, String arrayName): Adds a property that needs to be resolved later. |
SDVariable | addVariable(SDVariable variable): Add the specified variable to this SameDiff instance. |
void | addVariableMappingForField(DifferentialFunction function, String fieldName, String varName): Adds a field name -> variable name mapping for a given function. Used for model import where a variable is unresolved at the time of calling GraphMapper.importGraph(File). |
boolean | arrayAlreadyExistsForVarName(String varName): Returns true if the given vertex id and INDArray already exist. |
ByteBuffer | asFlatBuffers(boolean includeUpdaterState): Exports the current SameDiff instance to FlatBuffers format, returning the ops and all arrays as a ByteBuffer containing the FlatBuffers data. Uses the default ExecutorConfiguration with output mode OutputMode.VARIABLE_SPACE, execution mode ExecutionMode.SEQUENTIAL, profiling disabled and gather timings enabled. |
ByteBuffer | asFlatBuffers(ExecutorConfiguration configuration, boolean includeUpdaterState): Exports the current SameDiff instance to FlatBuffers format, returning the ops and all arrays as a ByteBuffer containing the FlatBuffers data. |
ByteBuffer | asFlatBuffers(long graphId, ExecutorConfiguration configuration, boolean includeUpdaterState): Exports the current SameDiff instance to FlatBuffers format, returning the ops and all arrays as a ByteBuffer containing the FlatBuffers data. |
void | asFlatFile(File file): Converts this SameDiff instance to FlatBuffers and saves it to a file that can be restored later. Includes the updater state, if applicable. |
void | asFlatFile(File file, boolean withUpdaterState) |
void | asFlatFile(File file, ExecutorConfiguration configuration, boolean includeUpdaterState): Converts this SameDiff instance to FlatBuffers and saves it to a file that can be restored later. |
FlatGraph | asFlatGraph(boolean includeUpdaterState) |
FlatGraph | asFlatGraph(long graphId, ExecutorConfiguration configuration, boolean includeUpdaterState): Returns the FlatGraph structure. |
protected int | asFlatNode(String name, SameDiff scope, com.google.flatbuffers.FlatBufferBuilder bufferBuilder) |
String | asFlatPrint(): Returns a text representation of the "flattened" graph. |
void | assignArray(INDArray arr, SDVariable variable): Update the constant or variable type SDVariable with the values from the specified array. |
void | associateArrayWithVariable(INDArray arr, SDVariable variable): Associate the array with the given variable. |
void | associateArrayWithVariable(INDArray arr, String variable): Associate the array with the given variable. |
protected void | associateSameDiffWithOpsAndVariables(): Associate the current SameDiff instance with all ops and variables. |
BatchOutputConfig | batchOutput(): Set up a single-batch inference operation using OutputConfig. |
protected Map<String,INDArray> | batchOutputHelper(Map<String,INDArray> placeholders, List<Listener> listeners, String... outputs) |
SDBitwise | bitwise(): Op creator object for bitwise operations. |
double | calcRegularizationScore(): Calculate the regularization (L1, L2 and/or WeightDecay) component of the loss function for the current parameters. |
Map<String,DataType> | calculateOutputDataTypes(): Calculate data types for the variables in the graph. |
Map<String,DataType> | calculateOutputDataTypes(boolean dynamicUpdate): Calculate data types for the variables in the graph. |
void | clearOpInputs(): Clear the input arrays to each op. |
void | clearPlaceholders(boolean allThreads): Clear the placeholder arrays from the SameDiff instance. |
SDCNN | cnn(): Op creator object for convolutional neural network operations. |
SDVariable | constant(double value): Create a new double scalar constant (rank 0) with the specified value. Constants are not modified by training/backprop. |
SDVariable | constant(float value): Create a new float scalar constant (rank 0) with the specified value. Constants are not modified by training/backprop. |
SDVariable | constant(INDArray constant): Create an SDVariable with a fixed/constant value and a generated name. Constants are not modified by training/backprop. |
SDVariable | constant(int value): Create a new integer scalar constant (rank 0) with the specified value. |
SDVariable | constant(long value): Create a new long scalar constant (rank 0) with the specified value. |
SDVariable | constant(SDVariable value, long... shape): Deprecated. |
SDVariable | constant(String name, DataType dataType, Number value): Create a new scalar constant (rank 0) with the specified value and datatype. |
SDVariable | constant(String name, double value): Create a new double scalar constant (rank 0) with the specified value. |
SDVariable | constant(String name, float value): Create a new float scalar constant (rank 0) with the specified value. |
SDVariable | constant(String name, INDArray constant): Create an SDVariable with a fixed/constant value. Constants are not modified by training/backprop. |
SDVariable | constant(String name, int value): Create a new integer scalar constant (rank 0) with the specified value. |
SDVariable | constant(String name, long value): Create a new long scalar constant (rank 0) with the specified value. |
SDVariable | constant(String name, SDVariable value, long... shape): Deprecated. |
void | convertDataTypes(Map<String,DataType> dataTypeMap): Convert the datatypes of the specified constants, placeholders and variables. After conversion, the downstream datatypes are changed. |
SDVariable | convertToConstant(SDVariable variable): Convert the specified variable to a constant. |
void | convertToConstants(List<SDVariable> variables): Convert all of the specified variables to constants. |
SDVariable | convertToVariable(SDVariable constant): Convert the specified variable to a VARIABLE type SDVariable. This can only be done for constants and placeholders, not ARRAY type variables (which are usually network activations). |
void | convertToVariables(List<SDVariable> constants): Convert the specified variables to VARIABLE type SDVariables. This can only be done for constants and placeholders, not ARRAY type variables (which are usually network activations). |
static SameDiff | create(): Create a new (empty) SameDiff instance without any functions or variables. |
static SameDiff | create(SameDiff originalSameDiff): Create a new SameDiff instance from an existing instance. |
void | createGradFunction(): Create the gradient function (for calculating gradients via execBackwards(Map, Operation, String[])) if it is not already defined. |
void | createGradFunction(String... variablesRequiringGradients): As per createGradFunction(), but allows a set of variables requiring gradients to be specified. |
String | currentNameScope() |
Collection<String> | definedFunctionNames(): The set of defined SameDiff function names. |
void | defineFunction(String function, SameDiffFunctionDefinition functionDefinition) |
void | defineFunction(String function, SameDiffFunctionDefinition functionDefinition, Map<String,INDArray> inputs) |
SameDiff | defineFunction(String function, SameDiffFunctionDefinition functionDefinition, SDVariable[] variables) |
protected Map<String,INDArray> | directExecHelper(Map<String,INDArray> placeholders, At at, MultiDataSet batch, Collection<String> requiredActivations, List<Listener> activeListeners, String... outputs): Do inference for the given variables for a single batch, with training information. |
SameDiff | disableDebugging(): Clears debugging state and disables debug mode. |
SameDiff | dup(): Clone/duplicate the SameDiff instance, including arrays etc. |
SameDiff | enableDebugMode(): Enables automatic tracing of graphs. |
boolean | equals(Object o) |
EvaluationConfig | evaluate(): Set up an evaluation operation using EvaluationConfig. |
void | evaluate(DataSetIterator iterator, Map<String,IEvaluation> variableEvals, Listener... listeners): Evaluation for multiple-output networks. See evaluate(MultiDataSetIterator, Map, Map, Listener[]). |
void | evaluate(DataSetIterator iterator, String outputVariable, IEvaluation... evaluations) |
void | evaluate(DataSetIterator iterator, String outputVariable, List<Listener> listeners, IEvaluation... evaluations): Evaluate the performance of a single variable's prediction. For example, a variable named "softmax" is evaluated by passing "softmax" as the output variable. |
void | evaluate(MultiDataSetIterator iterator, Map<String,List<IEvaluation>> variableEvals, Map<String,Integer> predictionLabelMapping, Listener... listeners): Perform evaluation using classes such as Evaluation for classifier outputs and RegressionEvaluation for regression outputs. |
void | evaluate(MultiDataSetIterator iterator, String outputVariable, int labelIndex, IEvaluation... evaluations) |
void | evaluate(MultiDataSetIterator iterator, String outputVariable, int labelIndex, List<Listener> listeners, IEvaluation... evaluations): Evaluate the performance of a single variable's prediction. For example, a variable named "softmax" is evaluated by passing "softmax" as the output variable. |
void | evaluateMultiple(DataSetIterator iterator, Map<String,List<IEvaluation>> variableEvals, Listener... listeners): Evaluation for networks with one or more outputs. |
Map<String,INDArray> | exec(Map<String,INDArray> placeholders, List<String> outputs): Deprecated. See output(Map, List) and batchOutput(). |
Map<String,INDArray> | exec(Map<String,INDArray> placeholders, String... outputs): Deprecated. See output(Map, String...) and batchOutput(). |
Map<String,INDArray> | execAll(Map<String,INDArray> placeholders): Deprecated. See outputAll(Map) and batchOutput(). |
INDArray | execAndEndResult(): Deprecated. |
void | execBackwards(Map<String,INDArray> placeholders) |
Map<String,INDArray> | execBackwards(Map<String,INDArray> placeholders, List<String> variableGradNamesList) |
Map<String,INDArray> | execBackwards(Map<String,INDArray> placeholders, List<String> variableGradNamesList, Operation operation): As per execBackwards(Map, Operation, MultiDataSet, Collection, List), but the set of gradients to calculate can be specified manually. For example, to calculate the gradient for placeholder variable "myPlaceholder", use execBackwards(placeholders, Arrays.asList(myPlaceholder.gradient().getVarName())). |
protected Map<String,INDArray> | execBackwards(Map<String,INDArray> placeholders, List<String> variableGradNamesList, Operation operation, MultiDataSet batch, Collection<String> requiredActivations, List<Listener> activeListeners) |
void | execBackwards(Map<String,INDArray> placeholders, Operation op): Create (if required) and then calculate the variable gradients (backward pass) for this graph. After execution, the gradient arrays can be accessed using myVariable.getGradient().getArr(). Note: by default this method calculates VARIABLE type SDVariable gradients only (as well as any other gradients needed to calculate them). |
protected void | execBackwards(Map<String,INDArray> placeholders, Operation op, MultiDataSet batch, Collection<String> requiredActivations, List<Listener> activeListeners) |
Map<String,INDArray> | execBackwards(Map<String,INDArray> placeholders, Operation op, String... variableGradNamesList) |
Map<String,INDArray> | execBackwards(Map<String,INDArray> placeholders, String... variableGradNamesList) |
INDArray | execSingle(Map<String,INDArray> placeholders, String output): Deprecated. |
DifferentialFunctionFactory | f(): Returns this SameDiff instance's DifferentialFunctionFactory. |
FitConfig | fit(): Set up a fit operation using FitConfig. |
History | fit(DataSetIterator iter, int numEpochs, DataSetIterator validationIter, int validationFrequency, Listener... listeners): Fit the SameDiff instance using a DataSetIterator for the specified number of epochs. This method can only be used for single-input, single-output SameDiff instances, as DataSet only supports a single input and a single output. Note that a TrainingConfig must be set via setTrainingConfig(TrainingConfig) before training can be performed. |
History | fit(DataSetIterator iter, int numEpochs, Listener... listeners): See fit(DataSetIterator, int, DataSetIterator, int, Listener...); does not perform validation. |
History | fit(DataSet dataSet, Listener... listeners): Fit the SameDiff instance on a single DataSet (i.e., a single minibatch for one iteration). This method can only be used for single-input, single-output SameDiff instances, as DataSet only supports a single input and a single output. Note that a TrainingConfig must be set via setTrainingConfig(TrainingConfig) before training can be performed. |
protected History | fit(MultiDataSetIterator iter, int numEpochs, boolean incrementEpochCount, MultiDataSetIterator validationData, int validationFrequency, Listener... listeners) |
History | fit(MultiDataSetIterator iter, int numEpochs, Listener... listeners): See fit(MultiDataSetIterator, int, MultiDataSetIterator, int, Listener...); does not perform validation. |
History | fit(MultiDataSetIterator iter, int numEpochs, MultiDataSetIterator validationIter, int validationFrequency, Listener... listeners): Fit the SameDiff instance using a MultiDataSetIterator for the specified number of epochs. This method supports both single-input, single-output and multi-input, multi-output SameDiff instances. Note that a TrainingConfig must be set via setTrainingConfig(TrainingConfig) before training can be performed. |
History | fit(MultiDataSet dataSet, Listener... listeners): Fit the SameDiff instance on a single MultiDataSet (i.e., a single minibatch for one iteration). Note that a TrainingConfig must be set via setTrainingConfig(TrainingConfig) before training can be performed. |
protected History | fitHelper(MultiDataSetIterator iter, int numEpochs, boolean incrementEpochCount, MultiDataSetIterator validationData, int validationFrequency, List<Listener> listeners) |
static SameDiff | fromFlatBuffers(ByteBuffer bbIn): Create a SameDiff instance from a FlatBuffers ByteBuffer. |
static SameDiff | fromFlatBuffers(ByteBuffer bbIn, boolean loadUpdaterState): Create a SameDiff instance from a FlatBuffers ByteBuffer. |
static SameDiff | fromFlatFile(File file): Create a SameDiff instance from a file, including the updater state. The corresponding save method is save(File, boolean). |
static SameDiff | fromFlatFile(File file, boolean loadUpdaterState): Create a SameDiff instance from a file, optionally also loading the updater state. The corresponding save method is save(File, boolean). |
String | generateDistinctCustomVariableName(String base): Returns an unused variable name of the format <base>_#. |
String | generateNewVarName(String base, int argIndex): See generateNewVarName(String, int, boolean), with existingOp set to true. |
String | generateNewVarName(String base, int argIndex, boolean existingOp): Generate a new, distinct variable name of the form <base>_#[:#]. |
SDVariable[] | generateOutputVariableForOp(DifferentialFunction function): Generate the variables based on the given input op and return the output variable names. |
SDVariable[] | generateOutputVariableForOp(DifferentialFunction function, String baseName, boolean isImport): Generate the variables based on the given input op and return the output variable names. |
INDArray | getArrForVarName(String varName): Get an INDArray for a given vertex id, or null if none exists. |
String | getBaseNameForFunction(DifferentialFunction function): Returns the base name for the given function, if any (may return null). |
SameDiff | getFunction(String functionName): Get a SameDiff function instance given the name of the function. |
SDVariable | getGradForVariable(String varName): Get the gradient for the variable with the specified name. The gradient variable represents the derivative of the loss function with respect to the output of this variable. |
String[] | getInputsForOp(DifferentialFunction function): Returns the name(s) of the inputs for the given function. |
SDVariable[] | getInputVariablesForOp(DifferentialFunction function): Get the input variable(s) for the specified differential function. |
List<Listener> | getListeners(): Gets the current SameDiff-wide listeners. |
List<String> | getLossVariables(): Get the names of variables (if any) that have been marked as loss variables to be minimized. Variables can be marked as loss variables in a few different ways: (a) losses are automatically added when creating loss functions via sd(); (b) via setLossVariables(String...), addLossVariable(String) or SDVariable.markAsLoss(); (c) via TrainingConfig#setLossVariables(List). |
DifferentialFunction | getOpById(String id): Get the function by DifferentialFunction#getOwnName(). |
String | getOpName(String base): See getOpName(String, boolean), with force set to false. |
String | getOpName(String base, boolean force): Generate a new, distinct op name of the form <base>_#. |
List<SameDiffOp> | getOpsInScope(NameScope scope): Gets all operations in a given name scope. |
List<SameDiffOp> | getOpsInScope(String scope) |
long[] | getOriginalShapeForPlaceHolder(String varName): Deprecated. |
String[] | getOutputsForOp(DifferentialFunction function): Returns the name(s) of the outputs for the given function. |
SDVariable[] | getOutputVariablesForOp(DifferentialFunction function): Get the output variable(s) for the specified differential function. |
LongShapeDescriptor | getShapeDescriptorForVarName(String varName): See getShapeForVarName(String), but returns the shape descriptor. |
long[] | getShapeForVarName(String varName): Get the shape for the given vertex id. |
SDVariable | getVariable(String name): Get the variable with the given name. |
DifferentialFunction | getVariableOutputOp(String variableName): Get the differential function (if any) that this variable is the output for. |
List<SDVariable> | getVariablesInScope(NameScope scope): Gets all variables in a given name scope. |
List<SDVariable> | getVariablesInScope(String scope) |
String | getVarNameForFieldAndFunction(DifferentialFunction function, String fieldName): Get the variable name to use for resolving a given field for a given function during import time. |
SDVariable | grad(String varName): Get the gradient for the variable with the specified variable name. |
boolean | hasArgs(DifferentialFunction function): Returns true if this function already has defined arguments. |
boolean | hasGradientFunction(): Returns true if the gradient function has been created, i.e., createGradFunction() or createGradFunction(String...) has been called. |
int | hashCode() |
boolean | hasVariable(String name) |
If | ifStatement(SameDiffConditional conditional, SameDiffFunctionDefinition conditionBody, SameDiffFunctionDefinition trueBody, SameDiffFunctionDefinition falseBody, SDVariable[] inputVars) |
SDImage | image(): Op creator object for image operations. |
static SameDiff | importFrozenTF(File graphFile): Import a frozen TensorFlow graph to a new SameDiff graph. |
static SameDiff | importFrozenTF(GraphDef graphDef) |
static SameDiff | importFrozenTF(InputStream graph) |
protected void | initializeTraining(): Perform setup for training. |
List<String> | inputs(): Returns the inputs (placeholders) for the SameDiff graph. |
SDVariable | invoke(Op op, SDVariable x): Invoke an op by name. |
SDVariable | invoke(Op op, SDVariable x, SDVariable y): Deprecated. |
SDVariable | invokeFunctionOn(String functionName, SameDiff with) |
SDVariable | invokeGraphOn(SameDiff sameDiff) |
boolean | isPlaceHolder(String varName): Returns true if this vertex id is a placeholder variable. A placeholder variable is one where the array shape(s) are not currently known and can't yet be calculated. |
static SameDiff | load(File file, boolean loadUpdaterState): Load the SameDiff instance previously saved with save(File, boolean). |
static SameDiff | load(InputStream is, boolean loadUpdaterState): As per load(File, boolean), but the SameDiff instance is read from the input stream instead. |
SDLoss | loss(): Op creator object for loss function operations. |
SDMath | math(): Op creator object for math operations. |
protected String | nameWithScope(String name) |
String | newBlockName(String baseName): For internal use only. |
SDNN | nn(): Op creator object for general neural network operations. |
long | numElements(): Count the number of elements in all arrays, according to SDVariable.getShape(). |
SDVariable | one(String name, DataType dataType, int... shape): Create a new variable with the specified shape, with all values initialized to 1.0. |
SDVariable | one(String name, DataType dataType, long... shape): Create a new variable with the specified shape, with all values initialized to 1.0. |
SDVariable | one(String name, int... shape) |
SDVariable | one(String name, long... shape) |
boolean | opExists(String id): Returns true if the given function id exists. |
DifferentialFunction[] | ops(): Get an array of differential functions that have been defined for this SameDiff instance. |
OutputConfig | output(): Set up an inference operation using OutputConfig. |
Map<String,INDArray> | output(DataSetIterator iterator, List<Listener> listeners, String... outputs): Do inference on a network with a single input. For example, a variable named "softmax" is requested by passing "softmax" as an output name. |
Map<String,INDArray> | output(DataSetIterator dataSet, String... outputs) |
Map<String,INDArray> | output(DataSet dataSet, String... outputs): Do a single-batch inference on a network with a single input. |
Map<String,INDArray> | output(Map<String,INDArray> placeholders, List<Listener> listeners, String... outputs): Do inference for the given variables for a single batch. |
Map<String,INDArray> | output(Map<String,INDArray> placeholders, List<String> outputs): Do inference for the given variables for a single batch. |
Map<String,INDArray> | output(Map<String,INDArray> placeholders, String... outputs): Do inference for the given variables for a single batch. |
Map<String,INDArray> | output(MultiDataSetIterator iterator, List<Listener> listeners, String... outputs): Perform inference. |
Map<String,INDArray> | output(MultiDataSetIterator dataSet, String... outputs) |
Map<String,INDArray> | output(MultiDataSet dataSet, String... outputs): Do a single-batch inference on a network. |
Map<String,INDArray> | outputAll(Map<String,INDArray> placeholders): Do inference for all variables for a single batch. |
List<Map<String,INDArray>> | outputBatches(DataSetIterator iterator, List<Listener> listeners, String... outputs): See output(DataSetIterator, List, String...), but without the concatenation of batches. |
List<Map<String,INDArray>> | outputBatches(DataSetIterator iterator, String... outputs): See output(DataSetIterator, String...), but without the concatenation of batches. |
List<Map<String,INDArray>> | outputBatches(MultiDataSetIterator iterator, List<Listener> listeners, String... outputs): Perform inference, returning one result map per batch. |
List<Map<String,INDArray>> | outputBatches(MultiDataSetIterator iterator, String... outputs) |
List<String> | outputs(): Outputs are those variables (not placeholders, constants, etc.) that are the output of a function and aren't the input to any other ops. |
INDArray | outputSingle(Map<String,INDArray> placeholders, String output): Do inference for a single variable for a single batch. |
static org.nd4j.linalg.primitives.Pair<String,Integer> | parseVariable(String varName): Note: intended for developer use. Extracts the base variable name and output index (if present) from a raw variable name. |
void | pauseArgumentInterceptor(): Pause the top (most recently added) argument interceptor. |
void | pauseArgumentInterceptor(ArgumentInterceptor interceptor): Pause the given argument interceptor. |
SDVariable | placeHolder(String name, DataType dataType, long... shape): Create a placeholder variable. |
List<String> | propertiesToResolveForFunction(DifferentialFunction function): Return the properties to resolve for the given function. |
void | putOpForId(String id, DifferentialFunction function): Put the function for the given id. |
void | putOrUpdateShapeForVarName(String varName, long[] shape, boolean clearArrayOnShapeMismatch): Deprecated. |
void | putShapeForVarName(String varName, long[] shape): Deprecated. |
void | putShapeForVarName(String varName, LongShapeDescriptor shape): Sets the shape descriptor for a variable. |
void | putSubFunction(String name, SameDiff nameSpace): Associate a SameDiff namespace as a sub-function. |
SDRandom | random(): Op creator object for random number generation operations. |
void | removeArgFromOp(String varName, DifferentialFunction function): Remove an argument for a function. |
void | removeArgumentInterceptor(): Remove the top (most recently added) argument interceptor. |
void | removePropertyToResolve(DifferentialFunction forFunction, String arrayName): Remove a property to resolve added with addPropertyToResolve(DifferentialFunction, String). |
void | renameVariable(String from, String to): Rename the specified variable to the new name. |
void | replaceArgFor(int i, SDVariable newArg, DifferentialFunction function): Replaces the argument at index i with newArg for the given function. Does not use (or remove) the ArgumentInterceptor stack. |
void | resolveVariablesWith(Map<String,INDArray> arrays): Resolve all ndarrays by updating the variables for each array specified in the given map. |
SDRNN | rnn(): Op creator object for recurrent neural network operations. |
void | save(File file, boolean saveUpdaterState): Save the SameDiff instance to a file. |
void | save(OutputStream outputStream, boolean saveUpdater): As per save(File, boolean), but the serialized SameDiff instance is written to the output stream instead. |
SDVariable | scalar(String name, DataType dataType, Number value): Create a new scalar (rank 0) SDVariable with the specified value and datatype. |
SDVariable | scalar(String name, double value): Create a new double scalar (rank 0) SDVariable with the specified value. |
SDVariable | scalar(String name, float value): Create a new float scalar (rank 0) SDVariable with the specified value. |
SDVariable | scalar(String name, int value): Create a new integer scalar (rank 0) SDVariable with the specified value. |
SDVariable | scalar(String name, long value): Create a new long scalar (rank 0) SDVariable with the specified value. |
protected SameDiff | sd() |
void | setArrayForVariable(String varName, INDArray arr): Set the stored INDArray for a variable. |
void | setBaseNameForFunctionInstanceId(String baseName, DifferentialFunction function): Sets a base name for the function id. |
void | setForwardVariableForVarName(String varName, SDVariable forwardVariable) |
void | setGradientForVariableName(String variableName, SDVariable variable): Assign an SDVariable to represent the gradient of the SDVariable with the specified name. |
void | setListeners(Collection<? extends Listener> listeners) |
void | setListeners(Listener... listeners): Set the current SameDiff-wide Listener instances. |
void | setLossVariables(SDVariable... lossVariables) |
void | setLossVariables(String... lossVariableNames): Clear/remove any existing loss variables, and set the loss variables to the specified variable names. See addLossVariable(String) for more details. |
void | setOriginalPlaceHolderShape(String variableName, long[] shape): Set the original shape for a given placeholder. This is used to track original shapes of placeholder variables; the reason for tracking original shapes is to validate possible candidate arrays coming in (especially with -1 as the expected shapes). |
void | setTrainingConfig(TrainingConfig trainingConfig): Set the training configuration (TrainingConfig) for the SameDiff instance. |
<X extends SDVariable> X | setupFunction(X function): Attempts to insert the DifferentialFunction reference into this SameDiff instance. |
boolean | shapeAlreadyExistsForVarName(String varName): Returns true if the given vertex id and shape already exist. |
String | summary(): Generate and return a String representation of the current SameDiff instance. Reports variables, ops, SameDiff function instances, and (where possible) array shapes. For ops, the input and output variables are reported; for variables, the ops that they are inputs to or outputs of are also reported. |
TensorArray |
tensorArray(DataType dataType)
Create a new TensorArray.
|
void |
unpauseArgumentInterceptor()
Unpause the top (most recently added) argument interceptor
|
void |
unpauseArgumentInterceptor(ArgumentInterceptor interceptor)
Unpause the top given argument interceptor
|
void |
updateVariableName(String varName,
String withName)
Update the opName for the variable with the given vertex id
|
SDVariable |
updateVariableNameAndReference(SDVariable varToUpdate,
String newVarName)
Updates the variable name property on the passed in variable, the reference in samediff, and returns the variable.
|
SDVariable[] |
updateVariableNamesAndReferences(SDVariable[] variablesToUpdate,
String[] newVariableNames)
Updates the variable name property on the passed in variables, its reference in samediff, and returns the variable.
|
SDVariable |
var(DataType dataType,
int... shape)
Creates a
SDVariable with the specified shape and a generated nameAny array will be generated with all zeros for the values This method creates a VARIABLE type SDVariable - i.e., must be floating point, and is a trainable parameter. |
SDVariable |
var(DataType dataType,
long... shape)
Creates a
SDVariable with the specified shape and a generated name. Any array will be generated with all zeros for the values. This method creates a VARIABLE type SDVariable - i.e., must be floating point, and is a trainable parameter. |
SDVariable |
var(INDArray arr)
Create an
SDVariable with a generated name, and associate the specified array with it. This is a VARIABLE type SDVariable - i.e., must be floating point, and is a trainable parameter. |
SDVariable |
var(SDVariable v)
Initialize a
SDVariable reference tying this variable to this samediff instance. |
SDVariable |
var(String name,
DataType dataType,
int... shape)
Creates a
SDVariable with the given shape and name. Any array will be generated with all zeros for the values. |
SDVariable |
var(String name,
DataType dataType,
long... shape)
Creates a
SDVariable with the given shape and name. Any array will be generated with all zeros for the values. This is a VARIABLE type SDVariable - i.e., must be floating point, and is a trainable parameter. |
SDVariable |
var(String name,
INDArray arr)
Create an
SDVariable with the specified name, and associate the specified array with it. This is a VARIABLE type SDVariable - i.e., must be floating point, and is a trainable parameter. |
SDVariable |
var(String name,
int... shape)
Creates a
SDVariable with the given shape and name. Any array will be generated with all zeros for the values. |
SDVariable |
var(String name,
long... shape)
Creates a
SDVariable with the given shape and name. Any array will be generated with all zeros for the values. |
SDVariable |
var(String name,
LongShapeDescriptor shapeDesc)
Creates a
SDVariable with the given shape and name. Any array will be generated with all zeros for the values. This is a VARIABLE type SDVariable - i.e., must be floating point, and is a trainable parameter. |
SDVariable |
var(String name,
LongShapeDescriptor shape,
WeightInitScheme weightInitScheme)
Creates a
SDVariable with the given shape and name. The underlying array will be initialized using the specified weight initialization scheme. This is a VARIABLE type SDVariable - i.e., must be floating point, and is a trainable parameter. |
SDVariable |
var(String name,
VariableType variableType,
WeightInitScheme weightInitScheme,
DataType dataType,
long... shape)
Variable initialization with a specified
WeightInitScheme
This method creates a VARIABLE type SDVariable - i.e., must be floating point, and is a trainable parameter. |
SDVariable |
var(String name,
WeightInitScheme weightInitScheme,
DataType dataType,
long... shape)
Variable initialization with a specified
WeightInitScheme
This method creates a VARIABLE type SDVariable - i.e., must be floating point, and is a trainable parameter. |
SDVariable |
var(WeightInitScheme weightInitScheme,
DataType dataType,
long... shape)
Creates a
SDVariable with the specified shape and a generated name. |
boolean |
variableHasGradient(String varName)
Determine if the specified variable has a gradient with respect to the current loss.
|
Map<String,SDVariable> |
variableMap()
Return a copy of the internal variable map
|
List<SDVariable> |
variables()
The list of all variables in the graph
|
While |
whileStatement(SameDiffConditional sameDiffConditional,
SameDiffFunctionDefinition conditionBody,
SameDiffFunctionDefinition loopBody,
SDVariable[] inputVars)
|
NameScope |
withNameScope(String nameScope)
Create a name scope.
|
SDVariable |
zero(String name,
DataType dataType,
int... shape)
Create a new variable with the specified shape, with all values initialized to 0
|
SDVariable |
zero(String name,
DataType dataType,
long... shape)
Create a new variable with the specified shape, with all values initialized to 0
|
SDVariable |
zero(String name,
int... shape)
|
SDVariable |
zero(String name,
long... shape)
|
all, all, any, any, argmax, argmax, argmax, argmax, argmin, argmin, argmin, argmin, assign, assign, assign, assign, batchMmul, batchMmul, batchMmul, castTo, castTo, concat, concat, cumprod, cumprod, cumsum, cumsum, dot, dot, dynamicPartition, dynamicPartition, dynamicStitch, dynamicStitch, eq, eq, eq, eq, expandDims, expandDims, fill, fill, gather, gather, gather, gather, gatherNd, gatherNd, gradientBackwardsMarker, gradientBackwardsMarker, gt, gt, gt, gt, gte, gte, gte, gte, identity, identity, ifCond, ifCond, ifCond, invertPermutation, invertPermutation, isNumericTensor, isNumericTensor, linspace, linspace, linspace, lt, lt, lt, lt, lte, lte, lte, lte, matchCondition, matchCondition, matchConditionCount, matchConditionCount, matchConditionCount, max, max, max, max, max, mean, mean, mean, mean, min, min, min, min, min, mmul, mmul, mmul, mmul, neq, neq, neq, neq, norm1, norm1, norm2, norm2, normmax, normmax, oneHot, oneHot, oneHot, oneHot, oneHot, oneHot, onesLike, onesLike, onesLike, parallel_stack, parallel_stack, permute, permute, permute, prod, prod, prod, range, range, range, rank, rank, repeat, repeat, replaceWhere, replaceWhere, replaceWhere, replaceWhere, reshape, reshape, reshape, reshape, reshape, reshape, reverse, reverse, reverseSequence, reverseSequence, reverseSequence, reverseSequence, scalarFloorMod, scalarFloorMod, scalarMax, scalarMax, scalarMin, scalarMin, scalarSet, scalarSet, scatterAdd, scatterAdd, scatterDiv, scatterDiv, scatterMax, scatterMax, scatterMin, scatterMin, scatterMul, scatterMul, scatterSub, scatterSub, scatterUpdate, scatterUpdate, segmentMax, segmentMax, segmentMean, segmentMean, segmentMin, segmentMin, segmentProd, segmentProd, segmentSum, segmentSum, sequenceMask, sequenceMask, sequenceMask, sequenceMask, sequenceMask, sequenceMask, shape, shape, size, size, sizeAt, sizeAt, slice, slice, slice, slice, squaredNorm, squaredNorm, squaredNorm, squaredNorm, squeeze, squeeze, stack, stack, standardDeviation, standardDeviation, 
standardDeviation, stridedSlice, stridedSlice, stridedSlice, stridedSlice, stridedSlice, stridedSlice, stridedSlice, stridedSlice, sum, sum, sum, sum, tensorMmul, tensorMmul, tile, tile, tile, tile, transpose, transpose, unsortedSegmentMax, unsortedSegmentMax, unsortedSegmentMean, unsortedSegmentMean, unsortedSegmentMin, unsortedSegmentMin, unsortedSegmentProd, unsortedSegmentProd, unsortedSegmentSqrtN, unsortedSegmentSqrtN, unsortedSegmentSum, unsortedSegmentSum, unstack, unstack, unstack, unstack, variance, variance, variance, whileLoop, whileLoop, whileLoop, zerosLike, zerosLike
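As the class overview states, the graph is defined symbolically and only computed when an execution method such as output(Map, String...) is called. A minimal illustrative sketch (not taken from this Javadoc; the names "in", "w" and "out" are arbitrary):
SameDiff sd = SameDiff.create();
SDVariable in = sd.placeHolder("in", DataType.FLOAT, -1, 4);  //PLACEHOLDER: value supplied at execution time
SDVariable w = sd.var("w", DataType.FLOAT, 4, 2);             //VARIABLE: trainable, initialized to zeros
SDVariable out = sd.nn.softmax("out", in.mmul(w));            //Ops only accumulate in the graph; nothing runs yet
Map<String,INDArray> result = sd.output(Collections.singletonMap("in", Nd4j.rand(DataType.FLOAT, 3, 4)), "out");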
protected static final String GRAD_FN_KEY
public final SDMath math
public final SDRandom random
public final SDNN nn
public final SDCNN cnn
public final SDRNN rnn
public final SDLoss loss
public final SDImage image
public final SDBitwise bitwise
public SDMath math()
public SDRandom random()
public SDNN nn()
public SDCNN cnn()
public SDRNN rnn()
public SDLoss loss()
public SDImage image()
public SDBitwise bitwise()
public void updateVariableName(String varName, String withName)
varName - the vertex id to update
withName - the new opName
public SameDiff disableDebugging()
public SameDiff enableDebugMode()
public DifferentialFunctionFactory f()
DifferentialFunctionFactory
public void setListeners(Listener... listeners)
Listener
instances.
Note that this will overwrite the current listener list.
If you want to use additional listeners for a single operation,
use the listener arguments in those methods (e.g. fit()
and FitConfig.listeners(Listener...)
).
listeners - Listeners
public void setListeners(Collection<? extends Listener> listeners)
public void addListeners(Listener... listeners)
Listener
instances.
If you want to use additional listeners for a single operation,
use the listener arguments in those methods (e.g. fit()
and FitConfig.listeners(Listener...)
).
listeners - Listeners
public void addListeners(Collection<? extends Listener> listeners)
public String currentNameScope()
withNameScope(String)
for more details.
protected String nameWithScope(String name)
withNameScope(String)
public NameScope withNameScope(String nameScope)
SameDiff sd = SameDiff.create();
SDVariable x = sd.var("x", DataType.FLOAT, 5);
SDVariable y;
try(NameScope ns = sd.withNameScope("myScope")){
y = sd.var("y", DataType.FLOAT, 5);
}
SDVariable z = sd.var("z", DataType.FLOAT, 5);
String xName = x.getVarName(); //RESULT: "x"
String yName = y.getVarName(); //RESULT: "myScope/y"
String zName = z.getVarName(); //RESULT: "z"
Note that name scopes can also be nested:
SameDiff sd = SameDiff.create();
SDVariable x;
try(NameScope ns = sd.withNameScope("first")){
try(NameScope ns2 = sd.withNameScope("second")){
x = sd.var("x", DataType.FLOAT, 5);
}
}
String xName = x.getVarName(); //RESULT: "first/second/x"
nameScope - Name of the name scope to open/create
public List<SameDiffOp> getOpsInScope(NameScope scope)
public List<SameDiffOp> getOpsInScope(String scope)
public List<SDVariable> getVariablesInScope(NameScope scope)
public List<SDVariable> getVariablesInScope(String scope)
public SDVariable invokeGraphOn(SameDiff sameDiff)
sameDiff -
public boolean opExists(String id)
id - the function id to test for
public DifferentialFunction getVariableOutputOp(String variableName)
variableName - Name of the variable
public DifferentialFunction getOpById(@NonNull String id)
DifferentialFunction#getOwnName()
id - the id of the function
public void putOpForId(String id, DifferentialFunction function)
id - the id of the function
function - the function
public String[] getInputsForOp(DifferentialFunction function)
function - the function to get the inputs for
public String[] getOutputsForOp(DifferentialFunction function)
function - the function to get the outputs for
public SDVariable[] getOutputVariablesForOp(DifferentialFunction function)
function - the function reference to get the output variable(s) for
public SDVariable[] getInputVariablesForOp(DifferentialFunction function)
function - the function reference to get the input variable(s) for
public void setArrayForVariable(@NonNull String varName, @NonNull INDArray arr)
INDArray
for a variable. Only works if the variable is of type
VariableType.CONSTANT
, VariableType.PLACEHOLDER
, or VariableType.VARIABLE
.
public long[] getShapeForVarName(String varName)
A shape *and* an array should not be defined at the same time. This wastes memory. The internal map used for tracking shapes for particular vertex ids should also delete redundant shapes stored to avoid redundant sources of information.
varName - the vertex id to get the shape for
public LongShapeDescriptor getShapeDescriptorForVarName(String varName)
getShapeForVarName(String)
, but returns the shape descriptor.
@Deprecated public void putShapeForVarName(String varName, long[] shape)
varName - the vertex id to associate
shape - the shape to associate with
putShapeForVarName(String, long[])
,
putOrUpdateShapeForVarName(String, long[], boolean)
public void putShapeForVarName(String varName, LongShapeDescriptor shape)
@Deprecated public void putOrUpdateShapeForVarName(String varName, long[] shape, boolean clearArrayOnShapeMismatch)
varName - Variable name
shape - Shape to put
clearArrayOnShapeMismatch - If false: no change to arrays. If true: if an INDArray is defined for the specified
variable name, it will be removed from the graph (to be later re-generated) if
its shape does not match the specified shape
public boolean shapeAlreadyExistsForVarName(String varName)
varName - the vertex id
public boolean arrayAlreadyExistsForVarName(String varName)
INDArray
already exists.
varName - the vertex id
public INDArray getArrForVarName(@NonNull String varName)
INDArray
for a given vertex id, or null if none exists
varName - Variable name to get the array for
public void associateArrayWithVariable(INDArray arr, @NonNull String variable)
arr - the array to get the variable for
variable - the name of the variable to associate the array with
public void associateArrayWithVariable(INDArray arr, SDVariable variable)
arr - the array to get the variable for
variable - the variable to associate the array with
public void assignArray(@NonNull INDArray arr, @NonNull SDVariable variable)
associateArrayWithVariable(INDArray, String)
this method will take the
values from the argument array and assign them to the current array.
The actual array (INDArray object) will not be stored or otherwise used within the SameDiff instance.
arr - Array values to set
variable - Variable to update the array of. Must be CONSTANT or VARIABLE type SDVariable
public void putSubFunction(String name, SameDiff nameSpace)
SameDiff
namespace as a sub function.
name - the opName of the function
nameSpace - the namespace
public Map<String,SDVariable> variableMap()
@Deprecated public SDVariable invoke(Op op, SDVariable x, SDVariable y)
op - the op
x - the first input
y - the second input
public Collection<String> definedFunctionNames()
public SDVariable invoke(Op op, SDVariable x)
op - the op
x - the first input
public void addPropertyToResolve(DifferentialFunction forFunction, String arrayName)
This is very common for model import.
forFunction - the function to add the property to resolve for
arrayName - the array name
public void removePropertyToResolve(DifferentialFunction forFunction, String arrayName)
addPropertyToResolve(DifferentialFunction, String)
forFunction - the function to remove the property to resolve for
arrayName - the array name
public List<String> propertiesToResolveForFunction(DifferentialFunction function)
DifferentialFunction.resolvePropertiesFromSameDiffBeforeExecution()
function - the function to get the properties to resolve for
public void addVariableMappingForField(DifferentialFunction function, String fieldName, String varName)
GraphMapper.importGraph(File)
.
This data structure is typically accessed during DifferentialFunction.resolvePropertiesFromSameDiffBeforeExecution()
When a function attempts to resolve variables right before execution, there needs to be a way of knowing which variable in a samediff graph should map to a function's particular field name
function - the function to map
fieldName - the field name for the function to map
varName - the variable name of the array to get from samediff
public String getVarNameForFieldAndFunction(DifferentialFunction function, String fieldName)
DifferentialFunction.resolvePropertiesFromSameDiffBeforeExecution()
function - the function to get the variable name for
fieldName - the field name to resolve for
public void setBaseNameForFunctionInstanceId(String baseName, DifferentialFunction function)
#generateOutputVariableForOp(DifferentialFunction, String)
for ensuring original names for model import map to current samediff names
when names are generated.
baseName - the base name to add
function - the function to declare a base name for.
public String getBaseNameForFunction(DifferentialFunction function)
function - the function to get the base name for
public <X extends SDVariable> X setupFunction(X function)
DifferentialFunction
reference into this SameDiff
instance.
If the given array field with the given index already exists, it will do a reference check to ensure that the two
array fields are the same. If not, an exception is thrown.
function - the array field to attempt to create
public void addOutgoingFor(SDVariable[] variables, DifferentialFunction function)
variables - Variables - arguments for the specified differential function
function - Differential function
public void addOutgoingFor(String[] varNames, DifferentialFunction function)
varNames - Names of the variables that are outputs of the specified differential function
function - Differential function
public void addArgumentInterceptor(@NonNull ArgumentInterceptor interceptor)
For internal use only.
When an op is added with arguments, the most recent argument interceptor is called on it. If ops are added in that interceptor, the next most recent will be called on their args, and so on.
interceptor - the argument interceptor to add
public void removeArgumentInterceptor()
For internal use only.
public void pauseArgumentInterceptor()
For internal use only.
public void pauseArgumentInterceptor(@NonNull ArgumentInterceptor interceptor)
For internal use only.
interceptor - the argument interceptor to pause
public void unpauseArgumentInterceptor()
For internal use only.
public void unpauseArgumentInterceptor(@NonNull ArgumentInterceptor interceptor)
For internal use only.
interceptor - the argument interceptor to unpause
public void addArgsFor(String[] variables, DifferentialFunction function)
variables - Names of the variables that are arguments (inputs) to the specified function
function - Function
public void addArgsFor(SDVariable[] variables, DifferentialFunction function)
variables - variables that are arguments (inputs) to the specified function
function - Function
public void replaceArgFor(int i, @NonNull SDVariable newArg, @NonNull DifferentialFunction function)
public boolean hasArgs(DifferentialFunction function)
function - the function to check
public void clearPlaceholders(boolean allThreads)
allThreads - If true: clear the placeholders for all threads. False: clear only for current thread
public void clearOpInputs()
public DifferentialFunction[] ops()
public static SameDiff create(SameDiff originalSameDiff)
originalSameDiff - Original SameDiff instance
public static SameDiff create()
public SameDiff dup()
public long numElements()
SDVariable.getShape()
public List<String> inputs()
public List<String> outputs()
public List<SDVariable> variables()
public List<String> getLossVariables()
sd()
setLossVariables(String...)
, addLossVariable(String) or SDVariable.markAsLoss()
TrainingConfig#setLossVariables(List)
public void setLossVariables(@NonNull String... lossVariableNames)
addLossVariable(String)
for more details.
lossVariableNames - Names of variables to be loss function variables
public void setLossVariables(@NonNull SDVariable... lossVariables)
public void addLossVariable(@NonNull String variableName)
public void addLossVariable(@NonNull SDVariable variable)
public void setTrainingConfig(TrainingConfig trainingConfig)
TrainingConfig
) for the SameDiff instance.
A TrainingConfig must be set before the SameDiff instance can be trained via the fit methods.
trainingConfig - Training configuration
public History fit(@NonNull DataSet dataSet, @NonNull Listener... listeners)
TrainingConfig
must be set via setTrainingConfig(TrainingConfig)
before training can
be performed.
dataSet - The DataSet (single minibatch) to perform training on
listeners - Additional listeners to use during this operation
Returns: a History object containing the history information for this training operation
(evaluations specified in the TrainingConfig, loss values, and timing information).
public History fit(@NonNull MultiDataSet dataSet, @NonNull Listener... listeners)
TrainingConfig
must be set via setTrainingConfig(TrainingConfig)
before training can
be performed.
dataSet - The MultiDataSet (single minibatch) to perform training on
listeners - Additional listeners to use during this operation
Returns: a History object containing the history information for this training operation
(evaluations specified in the TrainingConfig, loss values, and timing information).
public History fit(@NonNull DataSetIterator iter, int numEpochs, DataSetIterator validationIter, int validationFrequency, @NonNull Listener... listeners)
TrainingConfig
must be set via setTrainingConfig(TrainingConfig)
before training can
be performed.
A special case of fit()
.
iter - The iterator to train the SameDiff instance with
numEpochs - The number of epochs for training. Must be > 0
validationIter - The DataSetIterator to use for validation (null to skip validation)
validationFrequency - The frequency with which to run validation. 1 is every epoch, 2 is every other, etc.
listeners - Additional listeners to use during this operation
Returns: a History object containing the history information for this training operation
(evaluations specified in the TrainingConfig, loss values, and timing information).
public History fit(@NonNull DataSetIterator iter, int numEpochs, @NonNull Listener... listeners)
fit(DataSetIterator, int, DataSetIterator, int, Listener...)
, does not perform validation.
A special case of fit()
.
iter - The iterator to train the SameDiff instance with
numEpochs - The number of epochs for training. Must be > 0
listeners - Additional listeners to use during this operation
Returns: a History object containing the history information for this training operation
(evaluations specified in the TrainingConfig, loss values, and timing information).
public History fit(@NonNull MultiDataSetIterator iter, int numEpochs, MultiDataSetIterator validationIter, int validationFrequency, @NonNull Listener... listeners)
TrainingConfig
must be set via setTrainingConfig(TrainingConfig)
before training can
be performed.
A special case of fit()
.
iter - The iterator to train the SameDiff instance with
numEpochs - The number of epochs for training. Must be > 0
validationIter - The MultiDataSetIterator to use for validation (null to skip validation)
validationFrequency - The frequency with which to run validation. 1 is every epoch, 2 is every other, etc.
listeners - Additional listeners to use during this operation
Returns: a History object containing the history information for this training operation
(evaluations specified in the TrainingConfig, loss values, and timing information).
public History fit(@NonNull MultiDataSetIterator iter, int numEpochs, @NonNull Listener... listeners)
fit(MultiDataSetIterator, int, MultiDataSetIterator, int, Listener...)
, does not perform validation.
A special case of fit()
.
iter - The iterator to train the SameDiff instance with
numEpochs - The number of epochs for training. Must be > 0
listeners - Additional listeners to use during this operation
Returns: a History object containing the history information for this training operation
(evaluations specified in the TrainingConfig, loss values, and timing information).
public FitConfig fit()
FitConfig
.
Supports the setting of training data (MultiDataSetIterator
or DataSetIterator
), number of epochs,
validation data (MultiDataSetIterator
or DataSetIterator
), validation frequency, and additional listeners.
Example: train on data for 5 epochs, validating on valData every 2nd epoch
SameDiff sd = ...;
MultiDataSet data = ...;
MultiDataSet valData = ...;
History hist = sd.fit()
.train(data, 5)
.validate(valData, 2)
.exec();
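Note that all of the fit methods, including the fit() configuration shown above, require a TrainingConfig to have been set first via setTrainingConfig(TrainingConfig). A hedged sketch; the placeholder names "in" and "label" and the Adam learning rate are illustrative only:
SameDiff sd = ...;
DataSetIterator trainData = ...;
TrainingConfig tc = new TrainingConfig.Builder()
    .updater(new Adam(1e-3))           //Updater applied to the trainable (VARIABLE type) parameters
    .dataSetFeatureMapping("in")       //Placeholder that receives the DataSet features
    .dataSetLabelMapping("label")      //Placeholder that receives the DataSet labels
    .build();
sd.setTrainingConfig(tc);
History hist = sd.fit(trainData, 5);   //Training is now permitted: 5 epochs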
protected History fit(@NonNull MultiDataSetIterator iter, int numEpochs, boolean incrementEpochCount, MultiDataSetIterator validationData, int validationFrequency, @NonNull Listener... listeners)
protected History fitHelper(@NonNull MultiDataSetIterator iter, int numEpochs, boolean incrementEpochCount, MultiDataSetIterator validationData, int validationFrequency, @NonNull List<Listener> listeners)
public double calcRegularizationScore()
setTrainingConfig(TrainingConfig)
) before this
method can be called
protected void initializeTraining()
public void evaluate(@NonNull DataSetIterator iterator, @NonNull String outputVariable, @NonNull List<Listener> listeners, @NonNull IEvaluation... evaluations)
Evaluation e = new Evaluation();
sameDiff.evaluate(iterator, "softmax", e);
A special case of evaluate()
.
iterator - Iterator as source of data to evaluate
outputVariable - The variable to evaluate
listeners - Additional listeners to use during this operation.
evaluations - The evaluations to perform
public void evaluate(@NonNull DataSetIterator iterator, @NonNull String outputVariable, @NonNull IEvaluation... evaluations)
evaluate(DataSetIterator, String, List, IEvaluation[])
.
A special case of evaluate()
.
public void evaluate(@NonNull DataSetIterator iterator, @NonNull Map<String,IEvaluation> variableEvals, @NonNull Listener... listeners)
evaluate(MultiDataSetIterator, Map, Map, Listener[])
.
A special case of evaluate()
.
public void evaluateMultiple(DataSetIterator iterator, Map<String,List<IEvaluation>> variableEvals, @NonNull Listener... listeners)
evaluate(MultiDataSetIterator, Map, Map, Listener[])
.
A special case of evaluate()
.
public void evaluate(@NonNull MultiDataSetIterator iterator, @NonNull String outputVariable, int labelIndex, @NonNull List<Listener> listeners, @NonNull IEvaluation... evaluations)
Evaluation e = new Evaluation();
sameDiff.evaluate(iterator, "softmax", e);
A special case of evaluate()
.
iterator - Iterator as source of data to evaluate
outputVariable - The variable to evaluate
labelIndex - The index of the target variable's labels in the iterator
listeners - Additional listeners to use during this operation.
evaluations - The evaluations to perform
public void evaluate(@NonNull MultiDataSetIterator iterator, @NonNull String outputVariable, int labelIndex, @NonNull IEvaluation... evaluations)
evaluate(MultiDataSetIterator, String, int, List, IEvaluation[])
.
A special case of evaluate()
.
public void evaluate(MultiDataSetIterator iterator, Map<String,List<IEvaluation>> variableEvals, Map<String,Integer> predictionLabelMapping, Listener... listeners)
Evaluation
for classifier outputs
and RegressionEvaluation
for regression outputs.Evaluation
MultiDataSetIterator data = ...
Map<String,List<IEvaluation>> evals = Collections.singletonMap("softmaxOutput",Collections.singletonList(new Evaluation()));
Map<String,Integer> labelMapping = Collections.singletonMap("softmaxOutput",0); //Compare: "softmaxOutput" vs. MultiDataSet.getLabels(0)
sameDiff.evaluate(data, evals, labelMapping);
A special case of evaluate()
.
iterator - The iterator - the source of the data for evaluation
variableEvals - The evaluations to perform. Key: the name of the variable. Value: the evaluations to perform
predictionLabelMapping - The output/label mapping. Key: the name of the variable.
listeners - Additional listeners to use during this operation.
public EvaluationConfig evaluate()
Supports the setting of the data (MultiDataSetIterator
or DataSetIterator
),
adding evaluations for variables (with optional label index setting), setting label indices,
and setting additional listeners.
Does not require setting label indices when using a DataSetIterator
.
Also supports using SDVariable
instances instead of variable names.
Example: evaluate "pred" with Evaluation
and ROC
, using label 0.
SameDiff sd = ...;
MultiDataSetIterator data = ...;
EvaluationRecord results = sd.evaluate()
.data(data)
.evaluate("pred", 0, new Evaluation(), new ROC())
.exec();
Example: evaluate "pred" with Evaluation
, using the only label from a DataSetIterator.
SameDiff sd = ...;
DataSetIterator singleData = ...;
EvaluationRecord results = sd.evaluate()
.data(singleData)
.evaluate("pred", new Evaluation())
.exec();
public Map<String,INDArray> output(@NonNull DataSet dataSet, @NonNull String... outputs)
sameDiff.output(iterator, "softmax");
dataSet - The data to evaluate
outputs - The variables to evaluate
public Map<String,INDArray> output(@NonNull MultiDataSet dataSet, @NonNull String... outputs)
sameDiff.output(iterator, "softmax");
dataSet - The data to evaluate
outputs - The variables to evaluate
public Map<String,INDArray> output(@NonNull DataSetIterator iterator, @NonNull List<Listener> listeners, @NonNull String... outputs)
sameDiff.output(iterator, "softmax");
Uses concatenation on the outputs of outputBatches(DataSetIterator, String...)
which may cause issues with some inputs.
RNNs with variable time series length and CNNs with variable image sizes will most likely have issues.
Special case of output()
.
iterator - Iterator as source of data to evaluate
listeners - Additional listeners to use during this operation.
outputs - The variables to evaluate
public Map<String,INDArray> output(@NonNull DataSetIterator dataSet, @NonNull String... outputs)
output(DataSetIterator, List, String...)
. No additional listeners.
Special case of output()
.
public List<Map<String,INDArray>> outputBatches(DataSetIterator iterator, List<Listener> listeners, String... outputs)
output(DataSetIterator, List, String...)
, but without the concatenation of batches.
Special case of output()
.
public List<Map<String,INDArray>> outputBatches(DataSetIterator iterator, String... outputs)
output(DataSetIterator, String...)
, but without the concatenation of batches.
Special case of output()
.
public Map<String,INDArray> output(@NonNull MultiDataSetIterator iterator, @NonNull List<Listener> listeners, @NonNull String... outputs)
Evaluation
MultiDataSetIterator data = ...
sameDiff.output(iterator, "softmaxOutput");
Special case of output()
.
iterator - The iterator - the source of the data for inference
listeners - Additional listeners to use during this operation.
outputs - The set of outputs to report. If null, defaults to all outputs of this SameDiff.
public Map<String,INDArray> output(@NonNull MultiDataSetIterator dataSet, @NonNull String... outputs)
output(MultiDataSetIterator, List, String...)
. No additional listeners.
Special case of output()
.
public List<Map<String,INDArray>> outputBatches(MultiDataSetIterator iterator, List<Listener> listeners, String... outputs)
Evaluation
MultiDataSetIterator data = ...
sameDiff.output(iterator, "softmaxOutput");
Uses concatenation on the outputs of outputBatches(MultiDataSetIterator, List, String...)
which may cause issues with some inputs.
RNNs with variable time series length and CNNs with variable image sizes will most likely have issues.
Special case of output()
.
iterator - The iterator - the source of the data for inference
listeners - Additional listeners to use during this operation.
outputs - The set of outputs to report. If null, defaults to all outputs of this SameDiff.
public List<Map<String,INDArray>> outputBatches(MultiDataSetIterator iterator, String... outputs)
outputBatches(MultiDataSetIterator, List, String...)
. No additional listeners.
Special case of output()
.
public OutputConfig output()
MultiDataSetIterator
or DataSetIterator
),
and additional listeners.
Has exec methods to get results in batches or concatenated, or to get results when there is only
a single output (again in batches or concatenated).
Also supports using SDVariable
instances instead of variable names.
Example: get the output of pred, with batches concatenated together
SameDiff sd = ...;
MultiDataSet data = ...;
INDArray out = sd.output()
.data(data)
.output("pred")
.execSingle();
public BatchOutputConfig batchOutput()
Also supports using SDVariable
instances instead of variable names.
Example: get the value of "out" with placeholders x and y
SameDiff sd = ...;
INDArray xValue = ...;
INDArray yValue = ...;
SDVariable y = ...;
INDArray outValue = sd.batchOutput()
.output("out")
.input("x", xValue)
.input(y, yValue)
.execSingle();
@Deprecated public Map<String,INDArray> execAll(Map<String,INDArray> placeholders)
outputAll(Map)
and batchOutput()
public Map<String,INDArray> outputAll(Map<String,INDArray> placeholders)
See output(Map, List, String...)
.
Special case of batchOutput()
.
@Deprecated public INDArray execSingle(Map<String,INDArray> placeholders, String output)
outputSingle(Map, String)
and batchOutput()
public INDArray outputSingle(Map<String,INDArray> placeholders, String output)
See output(Map, List, String...)
.
Special case of batchOutput()
.
@Deprecated public Map<String,INDArray> exec(Map<String,INDArray> placeholders, List<String> outputs)
output(Map, List)
and batchOutput()
public Map<String,INDArray> output(Map<String,INDArray> placeholders, List<String> outputs)
See output(Map, List, String...)
.
Special case of batchOutput()
.
@Deprecated public Map<String,INDArray> exec(Map<String,INDArray> placeholders, String... outputs)
output(Map, String...)
and batchOutput()
public Map<String,INDArray> output(Map<String,INDArray> placeholders, String... outputs)
See output(Map, List, String...)
.
Special case of batchOutput()
.
public Map<String,INDArray> output(Map<String,INDArray> placeholders, @NonNull List<Listener> listeners, String... outputs)
Special case of batchOutput()
.
placeholders - The values to use for placeholders.
listeners - Additional listeners to use during this operation.
outputs - The variables to output and return.
protected Map<String,INDArray> batchOutputHelper(Map<String,INDArray> placeholders, @NonNull List<Listener> listeners, String... outputs)
protected Map<String,INDArray> directExecHelper(Map<String,INDArray> placeholders, At at, MultiDataSet batch, Collection<String> requiredActivations, List<Listener> activeListeners, String... outputs)
public SDVariable one(String name, int... shape)
See one(String, DataType, int...). Uses the Nd4j default floating point data type (Nd4j.defaultFloatingPointType()).
public SDVariable one(String name, long... shape)
See one(String, DataType, long...). Uses the Nd4j default floating point data type (Nd4j.defaultFloatingPointType()).
public SDVariable one(String name, DataType dataType, int... shape)
name - the name of the variable to create
shape - the shape of the array to be created
public SDVariable one(String name, DataType dataType, long... shape)
name - the name of the variable to create
shape - the shape of the array to be created
public SDVariable zero(String name, long... shape)
See zero(String, DataType, long...). Uses the Nd4j default floating point data type (Nd4j.defaultFloatingPointType()).
public SDVariable zero(String name, int... shape)
See zero(String, DataType, int...). Uses the Nd4j default floating point data type (Nd4j.defaultFloatingPointType()).
public SDVariable zero(String name, DataType dataType, long... shape)
name - the name of the variable to create
shape - the shape of the array to be created
public SDVariable zero(String name, DataType dataType, int... shape)
name - the name of the variable to create
shape - the shape of the array to be created
public SDVariable constant(@NonNull INDArray constant)
See VariableType for more details.
constant - Value for the constant SDVariable
public SDVariable constant(String name, @NonNull INDArray constant)
See VariableType for more details.
name - Name of the constant SDVariable
constant - Value for the constant SDVariable
@Deprecated public SDVariable constant(SDVariable value, long... shape)
value - constant to set for each value
shape - shape of the variable as long array
@Deprecated public SDVariable constant(String name, SDVariable value, long... shape)
name - Name of the new SDVariable
value - constant to set for each value
shape - shape of the variable as long array
public SDVariable placeHolder(@NonNull String name, DataType dataType, long... shape)
See VariableType for more details.
name - the name of the variable
dataType - Data type of the new placeholder
shape - the shape of the variable, if any
public SDVariable var(@NonNull String name, @NonNull WeightInitScheme weightInitScheme, @NonNull DataType dataType, @NonNull long... shape)
Creates a variable with the given name and shape, initialized via the given WeightInitScheme. This method creates a VARIABLE type SDVariable - i.e., it must be floating point, and is a trainable parameter. See VariableType for more details.
name - the name of the variable
shape - the shape of the array to be created
weightInitScheme - the weight initialization scheme
public SDVariable var(@NonNull String name, @NonNull VariableType variableType, WeightInitScheme weightInitScheme, DataType dataType, long... shape)
Creates a variable with the given name and shape, initialized via the given WeightInitScheme. See VariableType for more details.
name - the name of the variable
variableType - the SameDiff variable type of the variable (e.g. CONSTANT, PLACEHOLDER, etc.)
weightInitScheme - the weight initialization scheme
dataType - the data type of the variable (float, int, etc.)
shape - the shape of the array to be created
public SDVariable var(@NonNull String name, @NonNull LongShapeDescriptor shape, WeightInitScheme weightInitScheme)
Creates an SDVariable with the given shape and name. See VariableType for more details.
name - the name of the variable
shape - the shape of the variable
weightInitScheme - Weight initialization scheme to use to initialize the underlying array
public SDVariable var(String name, DataType dataType, long... shape)
Creates an SDVariable with the given shape and name. See VariableType for more details.
name - the name of the variable
shape - the shape of the variable
public SDVariable var(String name, LongShapeDescriptor shapeDesc)
Creates an SDVariable with the given shape and name. See VariableType for more details.
name - the name of the variable
shapeDesc - the shape of the variable
public SDVariable var(String name, int... shape)
Creates an SDVariable with the given shape and name, using the data type Nd4j.defaultFloatingPointType(). See VariableType for more details.
name - the name of the variable
shape - the shape of the variable
public SDVariable var(String name, long... shape)
Creates an SDVariable with the given shape and name, using the data type Nd4j.defaultFloatingPointType(). See VariableType for more details.
name - the name of the variable
shape - the shape of the variable
public SDVariable var(String name, DataType dataType, int... shape)
Creates an SDVariable with the given shape and name.
name - the name of the variable
shape - the shape of the variable
public SDVariable var(@NonNull SDVariable v)
Creates an SDVariable reference tying the given variable to this SameDiff instance. An NDArraySupplierInitScheme is used so that, wherever the underlying array is allocated, it exists in this SameDiff instance as a copy of the variable's array.
v - Variable
public SDVariable var(DataType dataType, int... shape)
Creates an SDVariable with the specified shape and a generated name. See VariableType for more details.
shape - the shape of the variable
public SDVariable var(DataType dataType, long... shape)
Creates an SDVariable with the specified shape and a generated name. See VariableType for more details.
shape - the shape of the variable
public SDVariable var(WeightInitScheme weightInitScheme, DataType dataType, long... shape)
Creates an SDVariable with the specified shape and a generated name. The associated array will then be generated using the specified weight initialization scheme.
weightInitScheme - The weight initialization scheme to use when generating an INDArray
shape - the shape of the variable
public SDVariable var(INDArray arr)
Creates an SDVariable with a generated name, and associates the specified array with it. See VariableType for more details.
arr - Array to associate with the new variable
See also: var(String, INDArray)
public SDVariable var(String name, @NonNull INDArray arr)
Creates an SDVariable with the specified name, and associates the specified array with it. See VariableType for more details.
arr - Array to associate with the new variable
public SDVariable convertToConstant(@NonNull SDVariable variable)
See VariableType for more details.
variable - Variable to convert to a constant
public void convertToConstants(List<SDVariable> variables)
See VariableType for more details.
variables - Variables to convert to constants
public SDVariable convertToVariable(@NonNull SDVariable constant)
See VariableType for more details.
public void convertToVariables(@NonNull List<SDVariable> constants)
See VariableType for more details.
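The conversion methods above change a variable's VariableType in place. A minimal sketch (the graph and names below are illustrative, not from this Javadoc; assumes the nd4j dependency is on the classpath):

```java
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.linalg.factory.Nd4j;

public class ConvertExample {
    static SameDiff build() {
        SameDiff sd = SameDiff.create();
        // "w" starts as a trainable VARIABLE with an associated array
        SDVariable w = sd.var("w", Nd4j.rand(2, 2));
        // Freeze it: "w" becomes a CONSTANT and will not be modified by training
        sd.convertToConstant(w);
        // And back: "w" is a trainable VARIABLE again
        sd.convertToVariable(sd.getVariable("w"));
        return sd;
    }

    public static void main(String[] args) {
        System.out.println(build().summary());
    }
}
```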
public void convertDataTypes(@NonNull Map<String,DataType> dataTypeMap)
For example, if z(float) = x(float)+y(float), changing both x and y to double results in z(double) = x(double)+y(double), without doing anything to change z's datatype directly (z's datatype is inferred from x, y and the add op). Note that some datatype changes, such as op(x(float),y(float)) -> op(x(double),y(float)), may not be supported by all ops.
dataTypeMap - Map of SDVariables to change the datatype for. Key = SDVariable name, Value = new datatype
public void renameVariable(String from, String to)
from - The variable to rename - this variable must exist
to - The new name for the variable - no variable with this name may already exist
public void removeArgFromOp(String varName, DifferentialFunction function)
varName - the variable name to remove
function - the function to remove the argument from
public SDVariable getVariable(String name)
name - the name of the variable
public boolean hasVariable(String name)
public SDVariable getGradForVariable(String varName)
To use this method, the loss variables must first be set via setLossVariables(String...), and the gradient functions created using createGradFunction(). Alternatively, the gradient function will be created automatically when training is performed.
varName - the variable name
public boolean variableHasGradient(String varName)
See createGradFunction() and setLossVariables(String...).
varName - Name of the variable to check the existence of a gradient variable for
public void setGradientForVariableName(String variableName, SDVariable variable)
variableName - the variable name to assign the gradient variable for
variable - the gradient variable
public void setForwardVariableForVarName(String varName, SDVariable forwardVariable)
varName -
forwardVariable -
public SDVariable grad(String varName)
execBackwards(Map, Operation, MultiDataSet, Collection, List) must be executed first. All gradient functions are obtained from the results of the execBackwards call.
varName - the variable name to get the gradient variable for
public SDVariable scalar(String name, double value)
name - Name of the SDVariable
value - Value to initialize the variable with
public SDVariable scalar(String name, float value)
name - Name of the SDVariable
value - Value to initialize the variable with
public SDVariable scalar(String name, int value)
name - Name of the SDVariable
value - Value to initialize the variable with
public SDVariable scalar(String name, long value)
name - Name of the SDVariable
value - Value to initialize the variable with
public SDVariable scalar(String name, DataType dataType, Number value)
name - Name of the SDVariable
dataType - Data type of the scalar
value - Value to initialize the variable with
public SDVariable constant(double value)
See VariableType for more details.
value - Value to initialize the constant with
public SDVariable constant(String name, double value)
name - Name of the SDVariable
value - Value to initialize the constant with
public SDVariable constant(float value)
See VariableType for more details.
value - Value to initialize the constant with
public SDVariable constant(String name, float value)
name - Name of the SDVariable
value - Value to initialize the constant with
public SDVariable constant(int value)
value - Value to initialize the constant with
public SDVariable constant(String name, int value)
name - Name of the SDVariable
value - Value to initialize the constant with
public SDVariable constant(long value)
value - Value to initialize the constant with
public SDVariable constant(String name, long value)
name - Name of the SDVariable
value - Value to initialize the constant with
public SDVariable constant(String name, DataType dataType, Number value)
name - Name of the SDVariable
dataType - Data type of the scalar constant
value - Value to initialize the constant with
public SDVariable addVariable(SDVariable variable)
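The scalar(...) and constant(...) creators above combine naturally with ops. A minimal sketch (the names and the element-wise multiply are illustrative; assumes the nd4j dependency is on the classpath):

```java
import java.util.Collections;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;

public class ScalarConstantExample {
    static INDArray run() {
        SameDiff sd = SameDiff.create();
        // Float scalar, so its datatype matches the float constant below
        SDVariable two = sd.scalar("two", 2.0f);
        SDVariable c = sd.constant("c", Nd4j.ones(2, 2));
        // out = c * two, element-wise
        SDVariable out = c.mul("out", two);
        // No placeholders in this graph, so pass an empty map
        return sd.output(Collections.<String, INDArray>emptyMap(), "out").get("out");
    }

    public static void main(String[] args) {
        System.out.println(run());
    }
}
```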
variable - Variable to add
public SDVariable[] generateOutputVariableForOp(DifferentialFunction function, String baseName, boolean isImport)
function - the function to generate the output variable names for
public SDVariable[] generateOutputVariableForOp(DifferentialFunction function)
function - the function to generate the output variable names for
public SameDiff getFunction(String functionName)
functionName - the name of the function
@Deprecated public While whileStatement(SameDiffConditional sameDiffConditional, SameDiffFunctionDefinition conditionBody, SameDiffFunctionDefinition loopBody, SDVariable[] inputVars)
@Deprecated public If ifStatement(SameDiffConditional conditional, SameDiffFunctionDefinition conditionBody, SameDiffFunctionDefinition trueBody, SameDiffFunctionDefinition falseBody, SDVariable[] inputVars)
public TensorArray tensorArray(DataType dataType)
public SDVariable invokeFunctionOn(String functionName, SameDiff with)
functionName -
with -
public SameDiff defineFunction(String function, SameDiffFunctionDefinition functionDefinition, SDVariable[] variables)
function -
public void defineFunction(String function, SameDiffFunctionDefinition functionDefinition)
function -
public void defineFunction(String function, SameDiffFunctionDefinition functionDefinition, Map<String,INDArray> inputs)
function -
functionDefinition -
inputs -
@Deprecated public INDArray execAndEndResult()
public void execBackwards(Map<String,INDArray> placeholders, Operation op)
After execution, the gradient arrays can be accessed via myVariable.getGradient().getArr(). Note: to calculate only a specific subset of gradients, use execBackwards(Map, List, Operation, MultiDataSet, Collection, List) instead, which allows specifying the set of SDVariables to calculate the gradients for - for example, execBackwards(placeholders, Arrays.asList(myPlaceholder.gradient().getVarName())). In some cases, createGradFunction() may need to be called first.
placeholders - Values for the placeholder variables in the graph. For graphs without placeholders, use null or an empty map
public void execBackwards(Map<String,INDArray> placeholders)
See execBackwards(Map, Operation). Uses Operation.INFERENCE.
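A minimal backprop sketch using execBackwards (the graph below - a sum-of-products loss over a placeholder and a weight variable - is illustrative, not from this Javadoc; assumes the nd4j dependency is on the classpath):

```java
import java.util.Collections;
import java.util.Map;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;

public class BackpropExample {
    static INDArray run() {
        SameDiff sd = SameDiff.create();
        SDVariable in = sd.placeHolder("in", DataType.FLOAT, 2, 3);
        SDVariable w = sd.var("w", Nd4j.ones(3, 1));
        // Scalar loss: sum of all entries of in.mmul(w)
        SDVariable loss = sd.sum("loss", in.mmul(w));
        sd.setLossVariables("loss");

        Map<String, INDArray> ph = Collections.singletonMap("in", Nd4j.ones(2, 3));
        sd.execBackwards(ph);
        // dLoss/dw - each entry is the column sum of "in", here 2.0
        return w.gradient().getArr();
    }

    public static void main(String[] args) {
        System.out.println(run());
    }
}
```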
protected void execBackwards(Map<String,INDArray> placeholders, Operation op, MultiDataSet batch, Collection<String> requiredActivations, List<Listener> activeListeners)
public Map<String,INDArray> execBackwards(Map<String,INDArray> placeholders, Operation op, String... variableGradNamesList)
public Map<String,INDArray> execBackwards(Map<String,INDArray> placeholders, String... variableGradNamesList)
public Map<String,INDArray> execBackwards(Map<String,INDArray> placeholders, List<String> variableGradNamesList, Operation operation)
As per execBackwards(Map, Operation, MultiDataSet, Collection, List), but the set of gradients to calculate can be specified manually - for example, execBackwards(placeholders, Arrays.asList(myPlaceholder.gradient().getVarName())).
placeholders - Values for the placeholder variables in the graph. For graphs without placeholders, use null or an empty map
variableGradNamesList - Names of the gradient variables to calculate
public Map<String,INDArray> execBackwards(Map<String,INDArray> placeholders, List<String> variableGradNamesList)
protected Map<String,INDArray> execBackwards(Map<String,INDArray> placeholders, List<String> variableGradNamesList, Operation operation, MultiDataSet batch, Collection<String> requiredActivations, List<Listener> activeListeners)
public boolean hasGradientFunction()
Returns true if createGradFunction() or createGradFunction(String...) has been called at all.
public void createGradFunction()
Creates (and defines) the gradient function (for use in execBackwards(Map, Operation, String[])) if it is not already defined. Users do not usually need to call this function manually, as it is called as required in the aforementioned method. The gradient function can be accessed via getFunction(String) with name "grad" as the argument; individual gradient arrays are available via SDVariable.gradient().getArr().
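A sketch of calling createGradFunction() explicitly and then querying gradient variables (the tiny graph below is illustrative; assumes the nd4j dependency is on the classpath):

```java
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.factory.Nd4j;

public class GradFunctionExample {
    static boolean run() {
        SameDiff sd = SameDiff.create();
        SDVariable in = sd.placeHolder("in", DataType.FLOAT, 1, 2);
        SDVariable w = sd.var("w", Nd4j.rand(2, 1));
        SDVariable loss = sd.sum("loss", in.mmul(w));

        // Loss variables must be set before the gradient function can be built
        sd.setLossVariables("loss");
        sd.createGradFunction();

        // After creation, gradient variables exist for trainable parameters
        return sd.variableHasGradient("w");
    }

    public static void main(String[] args) {
        System.out.println(run());
    }
}
```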
public void createGradFunction(String... variablesRequiringGradients)
As per createGradFunction(), but this method allows a set of variables requiring gradients to be specified. By default, only parameter gradients will be calculated; placeholder gradients may not be defined (unless they happen to be calculated in the same op as a parameter gradient). This method allows you to override this behaviour by passing the names of the placeholders you want gradients for. The specified gradient variables must still be floating point variables.
variablesRequiringGradients - May be null. If non-null: the gradients for the variables with these names will be calculated and available after backprop has been done
public void setOriginalPlaceHolderShape(String variableName, long[] shape)
Note that if isPlaceHolder(String) returns false for the given variable name, an ND4JIllegalStateException is thrown.
variableName - the variable name for the original shape
shape - the shape of the placeholder
@Deprecated public long[] getOriginalShapeForPlaceHolder(String varName)
Used in conjunction with resolveVariablesWith(Map), usually when executing using execAll(Map).
varName - the variable name to get the original shape for
public boolean isPlaceHolder(String varName)
varName - the variable name to test
public void resolveVariablesWith(Map<String,INDArray> arrays)
An IllegalStateException will be thrown if not all arrays are specified for resolution.
arrays - the arrays to resolve
public SDVariable updateVariableNameAndReference(SDVariable varToUpdate, String newVarName)
Note that if null is passed in for the new variable name, the original input variable is returned unchanged.
Overrides: updateVariableNameAndReference in class SDBaseOps
varToUpdate - the variable to update
newVarName - the new variable name
public SDVariable[] updateVariableNamesAndReferences(SDVariable[] variablesToUpdate, String[] newVariableNames)
Overrides: updateVariableNamesAndReferences in class SDBaseOps
variablesToUpdate - the variables to update
newVariableNames - the new variable names
protected void associateSameDiffWithOpsAndVariables()
protected int asFlatNode(String name, @NonNull SameDiff scope, @NonNull com.google.flatbuffers.FlatBufferBuilder bufferBuilder)
public static org.nd4j.linalg.primitives.Pair<String,Integer> parseVariable(@NonNull String varName)
varName -
public ByteBuffer asFlatBuffers(@NonNull ExecutorConfiguration configuration, boolean includeUpdaterState)
configuration - ExecutorConfiguration to be embedded into the serialized graph
includeUpdaterState - If true: include the updater state (state for updaters such as Adam, Nesterov, AdaGrad etc.)
public ByteBuffer asFlatBuffers(long graphId, @NonNull ExecutorConfiguration configuration, boolean includeUpdaterState)
configuration - ExecutorConfiguration to be embedded into the serialized graph
includeUpdaterState - If true: include the updater state (state for updaters such as Adam, Nesterov, AdaGrad etc.)
public FlatGraph asFlatGraph(boolean includeUpdaterState)
See asFlatGraph(long, ExecutorConfiguration, boolean). Uses the default ExecutorConfiguration with output mode OutputMode.VARIABLE_SPACE, execution mode ExecutionMode.SEQUENTIAL, profiling disabled and gather timings enabled.
public FlatGraph asFlatGraph(long graphId, ExecutorConfiguration configuration, boolean includeUpdaterState)
configuration -
includeUpdaterState - If true: include the updater state (state for updaters such as Adam, Nesterov, AdaGrad etc.)
public ByteBuffer asFlatBuffers(boolean includeUpdaterState)
Uses the default ExecutorConfiguration with output mode OutputMode.VARIABLE_SPACE, execution mode ExecutionMode.SEQUENTIAL, profiling disabled and gather timings enabled.
includeUpdaterState - If true: include the updater state (state for updaters such as Adam, Nesterov, AdaGrad etc.)
public void save(@NonNull File file, boolean saveUpdaterState)
See load(File, boolean).
file - File to save to
saveUpdaterState - If true: save the updater state (arrays etc. for Adam, Nesterov, RMSProp etc.). If false: don't save the updater state. If you want to continue training after loading your model, this should be true, though it may increase the file size significantly. If the network is to be used for inference only, set this to false to save space.
public void save(@NonNull OutputStream outputStream, boolean saveUpdater)
As per save(File, boolean), but the serialized SameDiff instance is written to the output stream instead. Note that this temporarily saves to disk (using ND4JFileUtils.createTempFile(String, String)), then copies all file bytes to the stream.
outputStream - Stream to write the serialized SameDiff instance to
saveUpdater - If true: save the updater state (arrays etc. for Adam, Nesterov, RMSProp etc.). If false: don't save the updater state. If you want to continue training after loading your model, this should be true, though it may increase the file size significantly. If the network is to be used for inference only, set this to false to save space.
public static SameDiff load(@NonNull File file, boolean loadUpdaterState)
See save(File, boolean).
file - The file to load the network from
loadUpdaterState - If true: load the updater state (history etc. for updaters such as Adam, Nesterov momentum, RMSProp etc.). For inference only, this should be false, as the updater state will take more memory but is not required for training. If the network is to be trained further, this should be true. The updater state can only be loaded if it was saved with the network.
public static SameDiff load(@NonNull InputStream is, boolean loadUpdaterState)
As per load(File, boolean), but the SameDiff instance is read from an input stream.
is - Input stream to load the saved network from
loadUpdaterState - If true: load the updater state (history etc. for updaters such as Adam, Nesterov momentum, RMSProp etc.). For inference only, this should be false, as the updater state will take more memory but is not required for training. If the network is to be trained further, this should be true. The updater state can only be loaded if it was saved with the network.
public void asFlatFile(@NonNull File file) throws IOException
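A save/load round trip might look like the following sketch (the temp-file name and graph contents are illustrative; assumes the nd4j dependency is on the classpath):

```java
import java.io.File;
import java.io.IOException;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.linalg.factory.Nd4j;

public class SaveLoadExample {
    static boolean run() throws IOException {
        SameDiff sd = SameDiff.create();
        sd.var("w", Nd4j.rand(2, 2));

        File f = File.createTempFile("samediff-example", ".fb");
        f.deleteOnExit();
        // Inference only here, so skip the updater state to save space
        sd.save(f, false);

        SameDiff restored = SameDiff.load(f, false);
        return restored.hasVariable("w");
    }

    public static void main(String[] args) throws IOException {
        System.out.println(run());
    }
}
```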
Uses the default ExecutorConfiguration with output mode OutputMode.VARIABLE_SPACE, execution mode ExecutionMode.SEQUENTIAL, profiling disabled and gather timings enabled.
file - File to save the FlatBuffers serialized graph (including arrays) to
Throws: IOException
public void asFlatFile(@NonNull File file, boolean withUpdaterState) throws IOException
See asFlatFile(File, ExecutorConfiguration, boolean). Uses the default ExecutorConfiguration with output mode OutputMode.VARIABLE_SPACE, execution mode ExecutionMode.SEQUENTIAL, profiling disabled and gather timings enabled.
Throws: IOException
public void asFlatFile(@NonNull File file, @NonNull ExecutorConfiguration configuration, boolean includeUpdaterState) throws IOException
file - File to save the FlatBuffers serialized graph (including arrays) to
includeUpdaterState - If true: include the updater state (state for updaters such as Adam, Nesterov, AdaGrad etc.)
Throws: IOException
public static SameDiff fromFlatFile(@NonNull File file) throws IOException
Creates a SameDiff instance from a file, including the updater state. The method to save the file is save(File, boolean).
file - the file to load from
Throws: IOException
public static SameDiff fromFlatFile(@NonNull File file, boolean loadUpdaterState) throws IOException
Creates a SameDiff instance from a file, optionally also loading the updater state. The method to save the file is save(File, boolean).
file - the file to load from
loadUpdaterState - If true, load the updater state (Adam etc. state). For training, use true. For inference, use false
Throws: IOException
public static SameDiff fromFlatBuffers(ByteBuffer bbIn) throws IOException
Creates a SameDiff instance from a ByteBuffer. See fromFlatBuffers(ByteBuffer, boolean). Loads the updater state (loadUpdaterState is true).
bbIn - the input byte buffer
Throws: IOException
public static SameDiff fromFlatBuffers(ByteBuffer bbIn, boolean loadUpdaterState) throws IOException
Creates a SameDiff instance from a ByteBuffer.
bbIn - the input byte buffer
loadUpdaterState - If true, load the updater state (Adam etc. state). For training, use true. For inference, use false
Throws: IOException
public String asFlatPrint()
See summary().
public String summary()
public Map<String,DataType> calculateOutputDataTypes()
public Map<String,DataType> calculateOutputDataTypes(boolean dynamicUpdate)
public String newBlockName(String baseName)
public static SameDiff importFrozenTF(File graphFile)
graphFile - The text or binary file containing the graph
public static SameDiff importFrozenTF(InputStream graph)
See importFrozenTF(File). Again, the input can be text or binary.
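A sketch of importing and running a frozen TensorFlow graph. The file path "frozen_model.pb" and the node names "input" and "output" below are hypothetical placeholders - substitute the actual names from your own frozen graph:

```java
import java.io.File;
import java.util.Collections;
import java.util.Map;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;

public class ImportExample {
    public static void main(String[] args) {
        // Hypothetical frozen TensorFlow model file
        SameDiff sd = SameDiff.importFrozenTF(new File("frozen_model.pb"));
        // Inspect the imported variables and ops to find the real node names
        System.out.println(sd.summary());

        // Hypothetical input/output node names and input shape
        Map<String, INDArray> placeholders =
                Collections.singletonMap("input", Nd4j.rand(1, 784));
        Map<String, INDArray> out = sd.output(placeholders, "output");
        System.out.println(out.get("output"));
    }
}
```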
public String getOpName(String base, boolean force)
Applies the name scope if active.
base - The base name to use
force - Whether to force the result name to be the same as base
public String getOpName(String base)
As per getOpName(String, boolean), with force set to false.
public String generateNewVarName(String base, int argIndex, boolean existingOp)
Applies name scopes if active.
base - The base of the name
argIndex - The argument index, used in the ":#" suffix. A value of 0 (or negative) does not include the ":#" part
existingOp - Whether to generate a distinct operation name from base (if false), or just use base (if true)
public String generateNewVarName(String base, int argIndex)
As per generateNewVarName(String, int, boolean), with existingOp set to true.
Overrides: generateNewVarName in class SDBaseOps
public String generateDistinctCustomVariableName(String base)
See generateNewVarName(String, int).
Copyright © 2019. All rights reserved.