Type Parameters:
- T: Node output type, for example INDArray, shape, etc., depending on what we're calculating
- O: Op type

public abstract class AbstractSession<T,O> extends Object
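The following hypothetical plain-Java sketch shows the same template-method shape, which makes the roles of T and O more concrete. MiniSession, ShapeOp, ShapeSession and exec(...) are illustrative names and are not part of ND4J. The three abstract methods only loosely mirror getConstantOrVariable, getAndParameterizeOp and getOutputs, with simplified signatures that drop frames, iterations, listeners and batches. Following the "shape" example above, T is a shape (long[]) rather than an INDArray.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

/** Hypothetical, heavily simplified model of a session generic in T (node output) and O (op). */
abstract class MiniSession<T, O> {
    // Cached outputs by variable name (a stand-in for nodeOutputs, without frames/iterations)
    protected final Map<String, T> nodeOutputs = new HashMap<>();

    /** Value of a constant/variable node (loosely mirrors getConstantOrVariable). */
    protected abstract T getConstantOrVariable(String name);

    /** Build an op with its inputs set (loosely mirrors getAndParameterizeOp). */
    protected abstract O getAndParameterizeOp(String opName, List<T> inputs);

    /** Execute the op (loosely mirrors getOutputs; single output here). */
    protected abstract T getOutputs(O op);

    /** Resolve inputs, parameterize, execute, then cache the result under outputName. */
    public T exec(String opName, String outputName, List<String> inputNames) {
        List<T> inputs = new ArrayList<>();
        for (String in : inputNames) {
            inputs.add(nodeOutputs.containsKey(in) ? nodeOutputs.get(in) : getConstantOrVariable(in));
        }
        T out = getOutputs(getAndParameterizeOp(opName, inputs));
        nodeOutputs.put(outputName, out);
        return out;
    }
}

/** Hypothetical op type for shape inference: op name plus input shapes. */
record ShapeOp(String name, List<long[]> inputShapes) {}

/** Shape-only session: T = long[] (a shape), O = ShapeOp. */
class ShapeSession extends MiniSession<long[], ShapeOp> {
    private final Map<String, long[]> constants;

    ShapeSession(Map<String, long[]> constants) {
        this.constants = constants;
    }

    @Override
    protected long[] getConstantOrVariable(String name) {
        long[] shape = constants.get(name);
        if (shape == null) throw new IllegalStateException("No constant: " + name);
        return shape;
    }

    @Override
    protected ShapeOp getAndParameterizeOp(String opName, List<long[]> inputs) {
        return new ShapeOp(opName, inputs);
    }

    @Override
    protected long[] getOutputs(ShapeOp op) {
        List<long[]> in = op.inputShapes();
        switch (op.name()) {
            case "add":  // elementwise; assumes both inputs have the same shape
                return in.get(0).clone();
            case "mmul": // [a,b] x [b,c] -> [a,c]; assumes compatible inner dimensions
                return new long[]{in.get(0)[0], in.get(1)[1]};
            default:
                throw new UnsupportedOperationException("Unknown op: " + op.name());
        }
    }

    public static void main(String[] args) {
        ShapeSession s = new ShapeSession(Map.of("x", new long[]{2, 3}, "w", new long[]{3, 4}));
        System.out.println(Arrays.toString(s.exec("mmul", "y", List.of("x", "w")))); // [2, 4]
    }
}
```

The base class owns the bookkeeping (cached outputs, input resolution). The subclass decides what a "value" is and how an op computes it, which matches how T and O are described above.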
Modifier and Type | Class and Description |
---|---|
static class | AbstractSession.FrameIter |
static class | AbstractSession.VarId |
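This is a rough model of the two nested classes as value types. It assumes, based on the newVarId(String variable, String frame, int iteration, AbstractSession.FrameIter parentFrameIter) signature, that a VarId is a variable name plus a frame, an iteration and a parent frame/iteration. The records below are hypothetical stand-ins rather than the ND4J classes, and the OUTER_FRAME value is a placeholder. The sketch shows that value-based equals/hashCode lets a freshly built key find an existing map entry, which is what VarId-keyed maps such as nodeOutputs depend on.

```java
import java.util.HashMap;
import java.util.Map;

/** Hypothetical value types modelled on the newVarId(...) parameters; not the ND4J classes. */
final class FrameKeys {
    /** Illustrative value only; the real OUTER_FRAME constant may differ. */
    static final String OUTER_FRAME = "outer";

    /** A (frame, iteration) position plus the parent position for nested frames (null at top level). */
    record FrameIter(String frame, int iteration, FrameIter parentFrame) {
        VarId toVarId(String variable) {
            return new VarId(variable, frame, iteration, parentFrame);
        }
    }

    /** One concrete instance of a variable: same name but different frame/iteration means a different key. */
    record VarId(String variable, String frame, int iteration, FrameIter parentFrame) {
        FrameIter toFrameIter() {
            return new FrameIter(frame, iteration, parentFrame);
        }
    }

    public static void main(String[] args) {
        FrameIter outer = new FrameIter(OUTER_FRAME, 0, null);
        Map<VarId, Double> outputs = new HashMap<>();
        // The same variable "x" inside a loop frame at iterations 0 and 1: two separate entries
        outputs.put(new VarId("x", "loop", 0, outer), 1.0);
        outputs.put(new VarId("x", "loop", 1, outer), 2.0);
        // Records have value-based equals/hashCode, so a newly built key finds the stored value
        System.out.println(outputs.get(new FrameIter("loop", 1, outer).toVarId("x"))); // 2.0
        System.out.println(outputs.size()); // 2
    }
}
```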
Modifier and Type | Field and Description |
---|---|
protected Queue<AbstractSession.VarId> | availableForExec |
protected Set<AbstractSession.VarId> | availableForExecSet |
protected Map<String,Set<String>> | execConstInputs: Contains the set of constant and placeholder inputs. Essentially the same as the execInputs map, but the constants and placeholders are used for calculating all instances of a variable, i.e., the input (constant/placeholder) applies to all frames and iterations. |
protected Map<AbstractSession.VarId,Set<AbstractSession.VarId>> | execInputs: Stores which variables are required to calculate the specific variable. |
protected Map<AbstractSession.VarId,Set<AbstractSession.VarId>> | execInputsAllIter: As per the execInputs map, with the difference that the iteration number should be ignored (i.e., always 0). Reason: Enter nodes are executed only once. Example: EnterOp(x) -> LoopCondition(less(x,y)): the less op requires "x" on all iterations; "x" is the output of the Enter op, which is only executed for iteration 0 in a frame. |
protected Map<String,AbstractSession.FrameIter> | frameParents: Map for exit ops. |
protected Map<AbstractSession.VarId,T> | nodeOutputs |
static String | OUTER_FRAME |
protected SameDiff | sameDiff |
protected Set<String> | subgraph: Contains variables we *might* need to execute in the process of getting the outputs we want. |
protected Map<AbstractSession.VarId,List<T>> | tensorArrays |
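The execInputsAllIter example can be sketched as follows. This is a hypothetical single-level model with no parent frames, and it assumes the set of Enter-op outputs is known up front. AllIterInputs, inputKey and read are illustrative names. An input produced by an Enter op is looked up at iteration 0 whatever the consumer's iteration is. Every other input uses the consumer's own iteration.

```java
import java.util.HashMap;
import java.util.Map;
import java.util.Set;

/**
 * Hypothetical sketch of why iteration-independent inputs are needed: an Enter op runs once
 * (iteration 0) in a frame, but ops in later iterations still read its output.
 */
final class AllIterInputs {
    record VarId(String variable, String frame, int iteration) {}

    private final Set<String> enterOutputs;            // variables produced by Enter ops (assumed known)
    private final Map<VarId, Integer> nodeOutputs = new HashMap<>();

    AllIterInputs(Set<String> enterOutputs) {
        this.enterOutputs = enterOutputs;
    }

    void put(VarId id, int value) {
        nodeOutputs.put(id, value);
    }

    /** Key under which an input to an op at (frame, iteration) is stored. */
    VarId inputKey(String input, String frame, int iteration) {
        // Enter outputs only exist for iteration 0, so the consumer's iteration is ignored for them
        int iter = enterOutputs.contains(input) ? 0 : iteration;
        return new VarId(input, frame, iter);
    }

    Integer read(String input, String frame, int iteration) {
        return nodeOutputs.get(inputKey(input, frame, iteration));
    }

    public static void main(String[] args) {
        AllIterInputs s = new AllIterInputs(Set.of("x_enter"));
        s.put(new VarId("x_enter", "loop", 0), 10); // Enter(x) executed once, at iteration 0
        s.put(new VarId("i", "loop", 3), 3);         // a loop variable at iteration 3
        // less(x, y) at iteration 3 still finds the Enter output from iteration 0
        System.out.println(s.read("x_enter", "loop", 3)); // 10
        System.out.println(s.read("i", "loop", 3));       // 3
    }
}
```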
Constructor and Description |
---|
AbstractSession(SameDiff sameDiff) |
Modifier and Type | Method and Description |
---|---|
protected void | addToExecInputs(boolean isConstOrPh, AbstractSession.VarId inputVar, AbstractSession.VarId forVariable): This method is used to record that the specified input is required for calculating the specified output. |
protected boolean | allInputsAvailable(int execStep, String[] inputsThisOp, AbstractSession.VarId executedVar) |
boolean | contains(String variable, String frame, int iteration, AbstractSession.FrameIter parentFrameIter) |
T | get(String variable, String frame, int iteration, AbstractSession.FrameIter parentFrameIter): Get a previously calculated output; throws an exception if the output does not exist. |
T | get(String variable, String frame, int iteration, AbstractSession.FrameIter parentFrameIter, boolean enforceExistence): Get a previously calculated output. |
abstract O | getAndParameterizeOp(String opName, AbstractSession.FrameIter frameIter, Set<AbstractSession.VarId> inputs, Set<AbstractSession.VarId> allIterInputs, Set<String> constAndPhInputs, Map<String,T> placeholderValues): Get the parameterized op to execute, for example the op/DifferentialFunction with all inputs set. |
abstract T | getConstantOrVariable(String variableName): Get the constant or variable output, for example a constant array or constant shape. |
abstract T[] | getOutputs(O op, AbstractSession.FrameIter outputFrameIter, Set<AbstractSession.VarId> inputs, Set<AbstractSession.VarId> allIterInputs, Set<String> constAndPhInputs, List<Listener> listeners, At at, MultiDataSet batch): Execute the op, calculating INDArrays, shape info, etc. |
protected void | initSubgraph(List<String> variables) |
protected static AbstractSession.VarId | lookup(String name, Collection<AbstractSession.VarId> varIds, boolean exceptionOnNotFound) |
AbstractSession.VarId | newVarId(String variable, AbstractSession.FrameIter frameIter) |
AbstractSession.VarId | newVarId(String variable, String frame, int iteration, AbstractSession.FrameIter parentFrameIter) |
Map<String,T> | output(List<String> variables, Map<String,T> placeholderValues, MultiDataSet batch, Collection<String> requiredActivations, boolean training, At at): Deprecated. |
Map<String,T> | output(List<String> variables, Map<String,T> placeholderValues, MultiDataSet batch, Collection<String> requiredActivations, List<Listener> listeners, At at): Get the output of the session, i.e., perform inference/forward pass. |
protected Map<String,T> | preprocessPlaceholders(Map<String,T> placeholders): Preprocess the placeholder values, if required. |
protected void | updateDescendentsForExec(int execStep, AbstractSession.VarId executedVar): This method should be called for a variable once its array is ready for use. |
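The subgraph field and initSubgraph(List<String> variables) suggest a backwards pass from the requested outputs. The sketch below shows that idea under two assumptions: a plain map from each variable to the variables it is computed from, and no control-flow ops at all. It starts from the requested variables plus any requiredActivations, follows inputs backwards, and collects everything reachable. SubgraphSketch and subgraph(...) are hypothetical names, and the real method may differ in detail.

```java
import java.util.ArrayDeque;
import java.util.Collection;
import java.util.Deque;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;

/** Hypothetical sketch: collect every variable that might be needed for the requested outputs. */
final class SubgraphSketch {
    static Set<String> subgraph(Map<String, List<String>> inputsOf,
                                Collection<String> variables,
                                Collection<String> requiredActivations) {
        Set<String> seen = new LinkedHashSet<>();
        Deque<String> stack = new ArrayDeque<>(variables);
        if (requiredActivations != null) {    // may be null, as documented for output(...)
            stack.addAll(requiredActivations);
        }
        while (!stack.isEmpty()) {
            String v = stack.pop();
            if (!seen.add(v)) continue;       // already visited
            for (String in : inputsOf.getOrDefault(v, List.of())) {
                stack.push(in);               // walk backwards to this variable's inputs
            }
        }
        return seen;
    }

    public static void main(String[] args) {
        Map<String, List<String>> inputsOf = Map.of(
                "y", List.of("x", "w"),   // y = mmul(x, w)
                "z", List.of("y"),        // z = relu(y)
                "unused", List.of("w"));  // not needed to compute z
        System.out.println(subgraph(inputsOf, List.of("z"), null)); // [z, y, w, x]
    }
}
```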
public static final String OUTER_FRAME
protected final SameDiff sameDiff
protected final Map<AbstractSession.VarId,T> nodeOutputs
protected final Map<AbstractSession.VarId,List<T>> tensorArrays
protected final Queue<AbstractSession.VarId> availableForExec
protected final Set<AbstractSession.VarId> availableForExecSet
protected final Set<String> subgraph
protected final Map<AbstractSession.VarId,Set<AbstractSession.VarId>> execInputs
protected final Map<AbstractSession.VarId,Set<AbstractSession.VarId>> execInputsAllIter
protected final Map<String,Set<String>> execConstInputs
protected final Map<String,AbstractSession.FrameIter> frameParents
public AbstractSession(@NonNull SameDiff sameDiff)
public boolean contains(String variable, String frame, int iteration, AbstractSession.FrameIter parentFrameIter)
public T get(String variable, String frame, int iteration, AbstractSession.FrameIter parentFrameIter)
public T get(String variable, String frame, int iteration, AbstractSession.FrameIter parentFrameIter, boolean enforceExistence)
Parameters:
- enforceExistence: If true, throw an exception if the array does not exist
public AbstractSession.VarId newVarId(String variable, String frame, int iteration, AbstractSession.FrameIter parentFrameIter)
public AbstractSession.VarId newVarId(String variable, AbstractSession.FrameIter frameIter)
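This is a simplified, hypothetical model of contains(...), the two get(...) overloads and both newVarId(...) overloads, built on a nodeOutputs-style map. OutputStore and its nested records are illustrative. Throwing IllegalStateException is an assumption, because the documentation only says that the four-argument get throws an exception when the output does not exist.

```java
import java.util.HashMap;
import java.util.Map;

/** Hypothetical sketch of contains/get(enforceExistence) over a VarId-keyed map; not the ND4J code. */
final class OutputStore<T> {
    record FrameIter(String frame, int iteration, FrameIter parentFrame) {}
    record VarId(String variable, String frame, int iteration, FrameIter parentFrame) {}

    private final Map<VarId, T> nodeOutputs = new HashMap<>();

    VarId newVarId(String variable, String frame, int iteration, FrameIter parent) {
        return new VarId(variable, frame, iteration, parent);
    }

    VarId newVarId(String variable, FrameIter frameIter) {
        return newVarId(variable, frameIter.frame(), frameIter.iteration(), frameIter.parentFrame());
    }

    void put(VarId id, T value) {
        nodeOutputs.put(id, value);
    }

    boolean contains(String variable, String frame, int iteration, FrameIter parent) {
        return nodeOutputs.containsKey(newVarId(variable, frame, iteration, parent));
    }

    /** Four-argument form: always enforces existence. */
    T get(String variable, String frame, int iteration, FrameIter parent) {
        return get(variable, frame, iteration, parent, true);
    }

    T get(String variable, String frame, int iteration, FrameIter parent, boolean enforceExistence) {
        VarId id = newVarId(variable, frame, iteration, parent);
        T out = nodeOutputs.get(id);
        if (out == null && enforceExistence) {
            throw new IllegalStateException("No output calculated for " + id); // assumed exception type
        }
        return out; // null when absent and enforceExistence is false
    }
}
```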
@Deprecated public Map<String,T> output(@NonNull List<String> variables, Map<String,T> placeholderValues, MultiDataSet batch, Collection<String> requiredActivations, boolean training, At at)
Deprecated. Use output(List, Map, MultiDataSet, Collection, List, At).
Parameters:
- training: Uses Operation.TRAINING if true, otherwise Operation.INFERENCE
public Map<String,T> output(@NonNull List<String> variables, Map<String,T> placeholderValues, MultiDataSet batch, Collection<String> requiredActivations, List<Listener> listeners, At at)
Parameters:
- variables: Names of the variables we want the arrays/activations for
- placeholderValues: The placeholder values (if any)
- batch: The batch data, used to call Listener.opExecution
- requiredActivations: Additional activations that are required. They won't be output, but opExecution will be called for them. May be null.
protected void updateDescendentsForExec(int execStep, AbstractSession.VarId executedVar)
Parameters:
- execStep: Current execution step (mainly for debugging)
- executedVar: Variable that was just executed
protected boolean allInputsAvailable(int execStep, String[] inputsThisOp, AbstractSession.VarId executedVar)
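The way updateDescendentsForExec, allInputsAvailable, availableForExec and availableForExecSet work together can be sketched as a dependency-driven ready queue. This is a hypothetical single-frame model with no loops, frames or iterations, in which each op produces one variable named after the op. ReadyQueueSketch and run(...) are illustrative names, and the String-based signatures simplify the VarId-based ones documented here.

```java
import java.util.ArrayDeque;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Queue;
import java.util.Set;

/** Hypothetical sketch: enqueue an op once all of its inputs have been computed. */
final class ReadyQueueSketch {
    private final Map<String, List<String>> inputsOf;                 // op -> input variables
    private final Map<String, List<String>> consumersOf = new HashMap<>();
    private final Set<String> available = new HashSet<>();            // variables already computed
    private final Queue<String> availableForExec = new ArrayDeque<>();
    private final Set<String> availableForExecSet = new HashSet<>();  // prevents duplicate enqueues

    ReadyQueueSketch(Map<String, List<String>> inputsOf) {
        this.inputsOf = inputsOf;
        inputsOf.forEach((op, ins) -> ins.forEach(
                in -> consumersOf.computeIfAbsent(in, k -> new ArrayList<>()).add(op)));
    }

    boolean allInputsAvailable(String op) {
        return available.containsAll(inputsOf.getOrDefault(op, List.of()));
    }

    /** Call once executedVar's value is ready; enqueue consumers that just became runnable. */
    void updateDescendentsForExec(String executedVar) {
        available.add(executedVar);
        for (String op : consumersOf.getOrDefault(executedVar, List.of())) {
            if (!available.contains(op) && allInputsAvailable(op) && availableForExecSet.add(op)) {
                availableForExec.add(op);
            }
        }
    }

    /** Returns the execution order, given the initially available constants/placeholders. */
    List<String> run(List<String> initiallyAvailable) {
        initiallyAvailable.forEach(this::updateDescendentsForExec);
        List<String> order = new ArrayList<>();
        while (!availableForExec.isEmpty()) {
            String op = availableForExec.poll();
            availableForExecSet.remove(op);
            order.add(op);                 // a real session would execute the op here
            updateDescendentsForExec(op);
        }
        return order;
    }

    public static void main(String[] args) {
        Map<String, List<String>> g = Map.of(
                "y", List.of("x", "w"),    // y = mmul(x, w)
                "z", List.of("y", "b"),    // z = add(y, b)
                "out", List.of("z"));      // out = relu(z)
        System.out.println(new ReadyQueueSketch(g).run(List.of("x", "w", "b"))); // [y, z, out]
    }
}
```

In this sketch the queue preserves execution order, and the companion set turns the "already queued?" check into a constant-time lookup. That is one plausible reason for pairing the two fields.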
protected Map<String,T> preprocessPlaceholders(Map<String,T> placeholders)
Parameters:
- placeholders: Placeholders to preprocess
public abstract T getConstantOrVariable(String variableName)
Parameters:
- variableName: The name of the variable to get the constant for
public abstract O getAndParameterizeOp(String opName, AbstractSession.FrameIter frameIter, Set<AbstractSession.VarId> inputs, Set<AbstractSession.VarId> allIterInputs, Set<String> constAndPhInputs, Map<String,T> placeholderValues)
Parameters:
- opName: Name of the op
- frameIter: The frame and iteration of the op outputs
- inputs: The inputs to the op (excluding constants/placeholders) for the specific frame and iteration
- allIterInputs: The inputs that are not iteration-specific (mainly Enter op variables, which might be used in all iterations but are only executed once, on iteration 0)
- constAndPhInputs: The constant and placeholder inputs, used for all frames/iterations
public abstract T[] getOutputs(O op, AbstractSession.FrameIter outputFrameIter, Set<AbstractSession.VarId> inputs, Set<AbstractSession.VarId> allIterInputs, Set<String> constAndPhInputs, List<Listener> listeners, At at, MultiDataSet batch)
Parameters:
- op: Operation to execute. This should be parameterized (i.e., all inputs set)
- outputFrameIter: The frame and iteration of the outputs
- inputs: The specific input arrays for the op
protected void addToExecInputs(boolean isConstOrPh, AbstractSession.VarId inputVar, AbstractSession.VarId forVariable)
This method is used to store the information needed to parameterize ops for execution later.
Parameters:
- isConstOrPh: If true, inputVar is either a constant or a placeholder
- inputVar: Input variable (i.e., the X in (X, ...) -> op -> (forVariable, ...))
- forVariable: Output variable (i.e., the Y in (inputVar, ...) -> op -> (Y, ...))
protected static AbstractSession.VarId lookup(String name, Collection<AbstractSession.VarId> varIds, boolean exceptionOnNotFound)
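This hypothetical sketch shows the bookkeeping that addToExecInputs and lookup describe, and it relies only on the field types above. Frame/iteration-specific inputs are recorded per VarId, as in execInputs. Constant and placeholder inputs are recorded per variable name, as in execConstInputs, because they apply to all frames and iterations. ExecInputsSketch, its three-field VarId and the exception type are assumptions made for illustration.

```java
import java.util.Collection;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Map;
import java.util.Set;

/** Hypothetical sketch of addToExecInputs and lookup; not the ND4J implementation. */
final class ExecInputsSketch {
    record VarId(String variable, String frame, int iteration) {}

    // Frame/iteration-specific inputs (cf. execInputs)
    final Map<VarId, Set<VarId>> execInputs = new HashMap<>();
    // Constants/placeholders apply to every frame and iteration, so they are keyed by name (cf. execConstInputs)
    final Map<String, Set<String>> execConstInputs = new HashMap<>();

    void addToExecInputs(boolean isConstOrPh, VarId inputVar, VarId forVariable) {
        if (isConstOrPh) {
            execConstInputs.computeIfAbsent(forVariable.variable(), k -> new HashSet<>())
                    .add(inputVar.variable());
        } else {
            execInputs.computeIfAbsent(forVariable, k -> new HashSet<>()).add(inputVar);
        }
    }

    /** Find the VarId with the given variable name; throw or return null if absent. */
    static VarId lookup(String name, Collection<VarId> varIds, boolean exceptionOnNotFound) {
        for (VarId id : varIds) {
            if (id.variable().equals(name)) return id;
        }
        if (exceptionOnNotFound) throw new IllegalStateException("No VarId found for " + name);
        return null;
    }

    public static void main(String[] args) {
        ExecInputsSketch s = new ExecInputsSketch();
        VarId y0 = new VarId("y", "loop", 0);
        s.addToExecInputs(false, new VarId("x", "loop", 0), y0); // x is computed in the loop frame
        s.addToExecInputs(true, new VarId("w", "outer", 0), y0); // w is a constant
        System.out.println(s.execConstInputs); // {y=[w]}
        System.out.println(lookup("x", s.execInputs.get(y0), true)); // VarId[variable=x, frame=loop, iteration=0]
    }
}
```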
Copyright © 2019. All rights reserved.