T
- Node output type: for example, INDArray, shape, etc., depending on what we're calculating
O
- Op type

public abstract class AbstractSession<T,O> extends Object
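The sketch below shows the template-method shape that the two type parameters imply. `MiniSession`, `BoundAdd` and `ScalarAddSession` are hypothetical names that are not part of SameDiff. The three abstract methods are heavily trimmed stand-ins for `getConstantOrVariable`, `getAndParameterizeOp` and `getOutputs`. T is a plain `Double` rather than an INDArray or a shape, and O is a record that holds an op name plus its bound inputs.

```java
import java.util.List;
import java.util.Map;

// Hypothetical, heavily simplified stand-in for AbstractSession<T,O>.
// T = per-node output type (here: Double); O = parameterized op type (here: BoundAdd).
abstract class MiniSession<T, O> {
    // Analogous to getConstantOrVariable(String): fetch a constant/variable value as a T
    abstract T getConstantOrVariable(String variableName);

    // Analogous to getAndParameterizeOp(...): bind the inputs to the op, producing an O
    abstract O getAndParameterizeOp(String opName, List<T> inputs);

    // Analogous to getOutputs(O, ...): execute the parameterized op, producing T values
    abstract T[] getOutputs(O op);
}

// A parameterized "add" op: the op name plus its bound scalar inputs
record BoundAdd(String opName, List<Double> inputs) {}

// Toy subclass: values are plain scalars and the only supported op is addition
final class ScalarAddSession extends MiniSession<Double, BoundAdd> {
    private final Map<String, Double> constants;

    ScalarAddSession(Map<String, Double> constants) {
        this.constants = constants;
    }

    @Override
    Double getConstantOrVariable(String variableName) {
        Double value = constants.get(variableName);
        if (value == null) {
            throw new IllegalArgumentException("No constant or variable named " + variableName);
        }
        return value;
    }

    @Override
    BoundAdd getAndParameterizeOp(String opName, List<Double> inputs) {
        return new BoundAdd(opName, List.copyOf(inputs));
    }

    @Override
    Double[] getOutputs(BoundAdd op) {
        double sum = 0.0;
        for (double x : op.inputs()) {
            sum += x;
        }
        return new Double[] {sum};
    }
}

class MiniSessionDemo {
    public static void main(String[] args) {
        ScalarAddSession session = new ScalarAddSession(Map.of("a", 1.5, "b", 2.5));
        BoundAdd op = session.getAndParameterizeOp("add",
                List.of(session.getConstantOrVariable("a"), session.getConstantOrVariable("b")));
        System.out.println(session.getOutputs(op)[0]); // prints 4.0
    }
}
```

A real subclass would use INDArray or shape information for T and a DifferentialFunction-style op for O. The base class's execution logic can stay generic over both.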
Modifier and Type | Class and Description |
---|---|
protected static class | AbstractSession.ExecStep: ExecStep represents a single execution step, for a single op (or variable/constant, etc.) at a specific frame/iteration |
protected class | AbstractSession.ExecStepPredicate: Used in getting the next ExecStep that matches the specified (current) frame/iteration |
protected static class | AbstractSession.ExecType: Execution type, as used in ExecStep. OP: operation execution. VARIABLE: variable "execution", mainly used to trigger ops that depend on the variable. CONSTANT: as per VARIABLE. PLACEHOLDER: as per VARIABLE. SWITCH_L and SWITCH_R: this is a bit of a hack to account for the fact that only one of the switch branches (left or right) will ever be available; without this, once the switch op is executed, we'd (incorrectly) conclude that *both* branches can be executed. EXEC_START: start of execution. CONTROL_DEP: control dependency for an op. |
static class | AbstractSession.FrameIter: Identifies a frame + iteration (but not a specific op or variable). Note that frames can be nested, which generally represents nested loop situations. |
static class | AbstractSession.VarId: Identifies the value of a variable in a specific frame and frame iteration. Note that frames can be nested, which generally represents nested loop situations. Used in two places: (a) to identify variables that are available for execution, and (b) to store results. |
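Nested frames are what make `FrameIter` and `VarId` necessary. The same variable can have a different value in every iteration of every enclosing loop. The sketch below models them as Java records whose names mirror the nested classes above. The fields, the `child`/`depth`/`of` helpers and the `"outer"` frame name are all assumptions made for illustration. They are not the real implementation.

```java
import java.util.HashMap;
import java.util.Map;

// Hypothetical simplified model of AbstractSession.FrameIter: a frame name, an iteration
// number, and the (possibly null) enclosing frame/iteration for nested loops.
record FrameIter(String frame, int iteration, FrameIter parentFrame) {
    // Enter a nested frame (e.g. an inner loop) from this frame/iteration
    FrameIter child(String childFrame, int childIteration) {
        return new FrameIter(childFrame, childIteration, this);
    }

    // Nesting depth: 0 for the outermost frame
    int depth() {
        return parentFrame == null ? 0 : 1 + parentFrame.depth();
    }
}

// Hypothetical simplified model of AbstractSession.VarId: a variable's value at one frame/iteration
record VarId(String variable, String frame, int iteration, FrameIter parentFrame) {
    static VarId of(String variable, FrameIter at) {
        return new VarId(variable, at.frame(), at.iteration(), at.parentFrame());
    }
}

class FrameIterDemo {
    public static void main(String[] args) {
        FrameIter outer = new FrameIter("outer", 0, null); // placeholder name for the outer frame
        FrameIter innerOfOuter0 = outer.child("loop", 0).child("inner", 0);
        FrameIter innerOfOuter1 = outer.child("loop", 1).child("inner", 0);

        // Same variable, same inner frame and iteration, but different enclosing loop iterations:
        // record equality compares the whole parent chain, so these are distinct keys.
        Map<VarId, Double> nodeOutputs = new HashMap<>();
        nodeOutputs.put(VarId.of("x", innerOfOuter0), 1.0);
        nodeOutputs.put(VarId.of("x", innerOfOuter1), 2.0);
        System.out.println(nodeOutputs.size()); // prints 2
    }
}
```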
Modifier and Type | Field and Description |
---|---|
protected DependencyTracker<AbstractSession.ExecStep,AbstractSession.ExecStep> | dt |
protected Map<AbstractSession.VarId,T> | nodeOutputs |
static String | OUTER_FRAME: All execution in SameDiff happens in a frame... |
protected SameDiff | sameDiff |
protected Set<String> | subgraph: Contains variables we *might* need to execute in the process of getting the outputs we want. |
protected Set<String> | subgraphOps: As per the subgraph set, but for ops instead. |
protected Map<AbstractSession.VarId,List<T>> | tensorArrays |
protected Set<String> | zeroInputOpsInSubgraph: Contains the names of ops that don't have any inputs. |
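The `initSubgraph` method initializes the `subgraph` and `subgraphOps` sets by working out what might be needed for the requested outputs. One plausible approach is a backward walk from those outputs. The sketch below uses a toy graph made of two plain maps and a hypothetical `SubgraphSketch` class. It also collects `zeroInputOpsInSubgraph` during the walk, which is an assumption about where that set gets filled. Frames, control dependencies and control-flow ops are ignored.

```java
import java.util.ArrayDeque;
import java.util.Deque;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;

// Hypothetical backward walk from the requested outputs. Assumed graph model:
// producerOf maps a variable to the op that outputs it (absent for placeholders,
// constants and variables); opInputs maps an op to its input variable names.
final class SubgraphSketch {
    final Set<String> subgraph = new HashSet<>();
    final Set<String> subgraphOps = new HashSet<>();
    final Set<String> zeroInputOpsInSubgraph = new HashSet<>();

    SubgraphSketch(Set<String> requestedOutputs,
                   Map<String, String> producerOf,
                   Map<String, List<String>> opInputs) {
        Deque<String> toVisit = new ArrayDeque<>(requestedOutputs);
        while (!toVisit.isEmpty()) {
            String variable = toVisit.pop();
            if (!subgraph.add(variable)) {
                continue; // variable already visited
            }
            String op = producerOf.get(variable);
            if (op == null || !subgraphOps.add(op)) {
                continue; // no producing op (e.g. a placeholder), or op already visited
            }
            List<String> inputs = opInputs.getOrDefault(op, List.of());
            if (inputs.isEmpty()) {
                zeroInputOpsInSubgraph.add(op); // op can run as soon as execution starts
            }
            toVisit.addAll(inputs); // we might need everything this op consumes
        }
    }
}
```

Consider ops `add(a, b) -> c`, `mul(c, w) -> d`, `const() -> k`, `neg(k) -> e` and an unrelated `unused(a) -> f`. Requesting `{d, e}` gives a subgraph of `{a, b, c, d, e, k, w}` and ops `{add, mul, const, neg}`, with `{const}` as the only zero-input op. The walk never visits `unused` or `f`.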
Constructor and Description |
---|
AbstractSession(SameDiff sameDiff) |
Modifier and Type | Method and Description |
---|---|
protected void | addDependenciesForOp(String opName, AbstractSession.FrameIter depFrameIter): Suppose operation X has just been executed. |
protected void | addVarControlDeps(AbstractSession.ExecStep es, Variable v): Add the control dependency from Op -> variable |
boolean | contains(String variable, String frame, int iteration, AbstractSession.FrameIter parentFrameIter) |
protected void | execFailed(Set<String> userRequestedUnique, Map<String,T> out, int step): Execution failed; can't calculate all requested outputs, and there's nothing left to calculate. |
T | get(String variable, String frame, int iteration, AbstractSession.FrameIter parentFrameIter): Get a previously calculated output; throws an exception if the output does not exist |
T | get(String variable, String frame, int iteration, AbstractSession.FrameIter parentFrameIter, boolean enforceExistence): Get a previously calculated output |
abstract O | getAndParameterizeOp(String opName, AbstractSession.FrameIter frameIter, Set<AbstractSession.VarId> inputs, Set<AbstractSession.VarId> allIterInputs, Set<String> constAndPhInputs, Map<String,T> placeholderValues, Set<String> allReqVariables): Get the parameterized op to execute; for example, the op/DifferentialFunction with all inputs set |
abstract T | getConstantOrVariable(String variableName): Get the constant or variable output; for example, a constant array or constant shape. |
protected AbstractSession.ExecStep | getExecStepForVar(String varName, AbstractSession.FrameIter frameIter): Get the ExecStep for the given variable, given that execution is happening at the specified frame/iteration |
abstract T[] | getOutputs(O op, AbstractSession.FrameIter outputFrameIter, Set<AbstractSession.VarId> inputs, Set<AbstractSession.VarId> allIterInputs, Set<String> constAndPhInputs, List<Listener> listeners, At at, MultiDataSet batch, Set<String> allReqVariables): Execute the op; calculate INDArrays, or shape info, etc. |
protected void | initSubgraph(Set<String> variables): Initialize the subgraph (the subgraph and subgraphOps sets). This works out what ops and variables we might need to execute to get the requested outputs. |
protected static AbstractSession.VarId | lookup(String name, Collection<AbstractSession.VarId> varIds, boolean exceptionOnNotFound): Get the VarId from the specified name. |
protected static AbstractSession.VarId | lookup(String name, Collection<AbstractSession.VarId> varIds, Collection<AbstractSession.VarId> varIds2, boolean exceptionOnNotFound): Get the VarId from the specified name. |
Map<String,T> | output(List<String> variables, Map<String,T> placeholderValues, MultiDataSet batch, Collection<String> requiredActivations, List<Listener> listeners, At at): Get the output of the session, i.e., perform inference/forward pass and return the outputs for the specified variables |
protected Map<String,T> | postProcessOutput(Map<String,T> output): Post-process the session output values, if required. |
protected Map<String,T> | preprocessPlaceholders(Map<String,T> placeholders, At at): Preprocess the placeholder values, if required. |
protected void | updateDescendantDeps(AbstractSession.ExecStep justExecuted, AbstractSession.FrameIter outFrameIter): Update the descendant dependencies. So if the graph structure is X -> A, then add all (X,Y,Z,...) -> A to the dependency tracker. This is for a specific frame and iteration, for both sides of the dependency (in and out). |
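These methods work together so that only what is needed gets executed. `output` drives execution. `updateDescendantDeps`, and apparently `addDependenciesForOp`, add dependencies to the tracker `dt`. `execFailed` handles the case where requested outputs can't be calculated and nothing is left to run. The sketch below shows the general idea on a frame-free toy graph of scalar ops: a ready queue driven by per-op counts of unsatisfied inputs. It is not the real algorithm. `ExecLoopSketch` and its map-based graph representation are hypothetical, and frames, iterations, control dependencies and listeners are omitted.

```java
import java.util.ArrayDeque;
import java.util.ArrayList;
import java.util.Deque;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.TreeSet;
import java.util.function.Function;

// Hypothetical, frame-free sketch of dependency-driven execution in the spirit of output():
// an op becomes ready once all of its inputs are available, and each computed output
// "unblocks" its consumers (compare updateDescendantDeps). If requested outputs are still
// missing when nothing is ready, execution fails (compare execFailed).
final class ExecLoopSketch {
    static Map<String, Double> output(Map<String, List<String>> opInputs,
                                      Map<String, List<String>> opOutputs,
                                      Map<String, Function<double[], double[]>> opFns,
                                      Map<String, Double> placeholderValues,
                                      Set<String> requested) {
        Map<String, Double> values = new HashMap<>(placeholderValues);
        Map<String, Integer> unsatisfied = new HashMap<>();    // op -> inputs not yet available
        Map<String, List<String>> consumers = new HashMap<>(); // variable -> ops that read it
        Deque<String> ready = new ArrayDeque<>();

        for (Map.Entry<String, List<String>> e : opInputs.entrySet()) {
            int missing = 0;
            for (String in : e.getValue()) {
                consumers.computeIfAbsent(in, k -> new ArrayList<>()).add(e.getKey());
                if (!values.containsKey(in)) {
                    missing++;
                }
            }
            unsatisfied.put(e.getKey(), missing);
            if (missing == 0) {
                ready.add(e.getKey()); // includes zero-input ops
            }
        }

        // Stop as soon as every requested output exists, or when nothing more can run
        while (!ready.isEmpty() && !values.keySet().containsAll(requested)) {
            String op = ready.poll();
            List<String> ins = opInputs.get(op);
            double[] args = new double[ins.size()];
            for (int i = 0; i < args.length; i++) {
                args[i] = values.get(ins.get(i));
            }
            double[] outs = opFns.get(op).apply(args);
            List<String> outNames = opOutputs.get(op);
            for (int i = 0; i < outNames.size(); i++) {
                String v = outNames.get(i);
                values.put(v, outs[i]);
                for (String consumer : consumers.getOrDefault(v, List.of())) {
                    if (unsatisfied.merge(consumer, -1, Integer::sum) == 0) {
                        ready.add(consumer);
                    }
                }
            }
        }

        if (!values.keySet().containsAll(requested)) {
            Set<String> missing = new TreeSet<>(requested);
            missing.removeAll(values.keySet());
            throw new IllegalStateException("Could not calculate requested outputs: " + missing);
        }
        Map<String, Double> result = new HashMap<>();
        for (String name : requested) {
            result.put(name, values.get(name));
        }
        return result;
    }
}
```

Take a toy graph of `add(a, b) -> c` and `square(c) -> d`. With placeholders a = 1.0 and b = 2.0, requesting d returns 9.0. If b is omitted, no op is ever ready, and the sketch throws instead of returning a partial result.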
public static final String OUTER_FRAME
protected final SameDiff sameDiff
protected final Map<AbstractSession.VarId,T> nodeOutputs
protected final Map<AbstractSession.VarId,List<T>> tensorArrays
protected final DependencyTracker<AbstractSession.ExecStep,AbstractSession.ExecStep> dt
protected final Set<String> subgraph
public AbstractSession(@NonNull SameDiff sameDiff)
public boolean contains(String variable, String frame, int iteration, AbstractSession.FrameIter parentFrameIter)
public T get(String variable, String frame, int iteration, AbstractSession.FrameIter parentFrameIter)
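`contains` and the two `get` overloads (including the `enforceExistence` variant that follows) look up previously calculated outputs. Each lookup is for a variable at a given frame, iteration and parent frame/iteration. The sketch below assumes outputs live in a `Map` keyed by a `VarId`-like record, as the `nodeOutputs` field suggests. The records, the `OutputStore` class and its `put` method are hypothetical. Returning `null` when `enforceExistence` is false is also an assumption rather than documented behavior.

```java
import java.util.HashMap;
import java.util.Map;

// Hypothetical simplified records (not the real nested classes)
record FrameIter(String frame, int iteration, FrameIter parentFrame) {}
record VarId(String variable, String frame, int iteration, FrameIter parentFrame) {}

// Sketch of contains/get over a nodeOutputs-style map keyed by VarId
final class OutputStore<T> {
    private final Map<VarId, T> nodeOutputs = new HashMap<>();

    // Hypothetical helper for storing a result; not part of the documented API
    void put(String variable, String frame, int iteration, FrameIter parentFrameIter, T value) {
        nodeOutputs.put(new VarId(variable, frame, iteration, parentFrameIter), value);
    }

    boolean contains(String variable, String frame, int iteration, FrameIter parentFrameIter) {
        return nodeOutputs.containsKey(new VarId(variable, frame, iteration, parentFrameIter));
    }

    // Strict lookup: a missing output is an error
    T get(String variable, String frame, int iteration, FrameIter parentFrameIter) {
        return get(variable, frame, iteration, parentFrameIter, true);
    }

    // With enforceExistence == false, a missing output returns null (assumption)
    T get(String variable, String frame, int iteration, FrameIter parentFrameIter, boolean enforceExistence) {
        VarId id = new VarId(variable, frame, iteration, parentFrameIter);
        T out = nodeOutputs.get(id);
        if (out == null && enforceExistence) {
            throw new IllegalStateException("No output calculated for " + id);
        }
        return out;
    }
}
```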
public T get(String variable, String frame, int iteration, AbstractSession.FrameIter parentFrameIter, boolean enforceExistence)
enforceExistence
- If true: throw an exception if the array does not exist
public Map<String,T> output(@NonNull List<String> variables, Map<String,T> placeholderValues, MultiDataSet batch, Collection<String> requiredActivations, List<Listener> listeners, At at)
variables
- Names of the variables we want the arrays/activations for
placeholderValues
- The placeholder values (if any). May be null.
batch
- The batch data, used to call Listener.opExecution
requiredActivations
- Additional activations that are required. These won't be returned as outputs, but opExecution will be called for them. May be null.
protected void addVarControlDeps(AbstractSession.ExecStep es, Variable v)
es
- Execution step for the variable
v
- Variable
protected void execFailed(Set<String> userRequestedUnique, Map<String,T> out, int step)
userRequestedUnique
- All outputs that the user requested
out
- Current outputs
step
- Execution step
protected void updateDescendantDeps(AbstractSession.ExecStep justExecuted, AbstractSession.FrameIter outFrameIter)
justExecuted
- The execution step that has just completed
outFrameIter
- The frame/iteration of the output
protected void addDependenciesForOp(String opName, AbstractSession.FrameIter depFrameIter)
opName
- Name of the op
depFrameIter
- Frame/iteration of the op instance to be executed
protected AbstractSession.ExecStep getExecStepForVar(String varName, AbstractSession.FrameIter frameIter)
protected void initSubgraph(Set<String> variables)
variables
- Set of output variables we need
protected Map<String,T> preprocessPlaceholders(Map<String,T> placeholders, At at)
placeholders
- Placeholders to preprocess.
protected Map<String,T> postProcessOutput(Map<String,T> output)
output
- Output to be returned to the user
public abstract T getConstantOrVariable(String variableName)
variableName
- The name of the variable to get the constant for
public abstract O getAndParameterizeOp(String opName, AbstractSession.FrameIter frameIter, Set<AbstractSession.VarId> inputs, Set<AbstractSession.VarId> allIterInputs, Set<String> constAndPhInputs, Map<String,T> placeholderValues, Set<String> allReqVariables)
opName
- Name of the op
frameIter
- The frame and iteration of the op outputs
inputs
- The inputs to the op (excluding constants/placeholders) for the specific frame + iteration
allIterInputs
- The inputs that are not iteration-specific (mainly Enter op vars, which might be used in all iterations but are only executed once, on iteration 0)
constAndPhInputs
- The constant and placeholder inputs, used for all frames/iterations
allReqVariables
- All required variables requested for the current session execution (not just the current op outputs)
public abstract T[] getOutputs(O op, AbstractSession.FrameIter outputFrameIter, Set<AbstractSession.VarId> inputs, Set<AbstractSession.VarId> allIterInputs, Set<String> constAndPhInputs, List<Listener> listeners, At at, MultiDataSet batch, Set<String> allReqVariables)
op
- Operation to execute. This should be parameterized (i.e., all inputs set)
outputFrameIter
- The frame and iteration of the outputs
inputs
- The specific input arrays for the op
allReqVariables
- All required variables requested for the current session execution (not just the current op outputs)
protected static AbstractSession.VarId lookup(String name, Collection<AbstractSession.VarId> varIds, Collection<AbstractSession.VarId> varIds2, boolean exceptionOnNotFound)
protected static AbstractSession.VarId lookup(String name, Collection<AbstractSession.VarId> varIds, boolean exceptionOnNotFound)
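Both `lookup` overloads resolve a variable name to a `VarId`. The sketch below assumes a linear scan that returns the first match by variable name. The two-collection form checks the second collection only when the first has no match. The records and `LookupSketch` are hypothetical stand-ins, and the real method may handle ambiguous matches differently.

```java
import java.util.Collection;
import java.util.List;

// Hypothetical simplified records (not the real nested classes)
record FrameIter(String frame, int iteration, FrameIter parentFrame) {}
record VarId(String variable, String frame, int iteration, FrameIter parentFrame) {}

final class LookupSketch {
    // Return the first VarId whose variable name matches; otherwise throw or return null
    static VarId lookup(String name, Collection<VarId> varIds, boolean exceptionOnNotFound) {
        for (VarId id : varIds) {
            if (id.variable().equals(name)) {
                return id;
            }
        }
        if (exceptionOnNotFound) {
            throw new IllegalStateException("Could not find VarId for variable \"" + name + "\"");
        }
        return null;
    }

    // Search the first collection, then fall back to the second
    static VarId lookup(String name, Collection<VarId> varIds, Collection<VarId> varIds2,
                        boolean exceptionOnNotFound) {
        VarId found = lookup(name, varIds, false);
        return found != null ? found : lookup(name, varIds2, exceptionOnNotFound);
    }
}

class LookupDemo {
    public static void main(String[] args) {
        VarId a = new VarId("a", "outer", 0, null);
        VarId b = new VarId("b", "loop", 1, new FrameIter("outer", 0, null));
        System.out.println(LookupSketch.lookup("b", List.of(a), List.of(b), true)); // found in the second collection
        System.out.println(LookupSketch.lookup("c", List.of(a), false));             // prints null
    }
}
```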
Copyright © 2019. All rights reserved.