public class TrainingSession extends InferenceSession
Nested classes/interfaces inherited from class InferenceSession: InferenceSession.ConstantDep, InferenceSession.Dep, InferenceSession.ExecDoneDep, InferenceSession.OpDep, InferenceSession.PlaceholderDep, InferenceSession.ReqOutputDep, InferenceSession.VariableDep
Nested classes/interfaces inherited from class AbstractSession: AbstractSession.ExecStep, AbstractSession.ExecStepPredicate, AbstractSession.ExecType, AbstractSession.VarId
Modifier and Type | Field and Description |
---|---|
protected TrainingConfig | config |
protected double[] | currIterLoss |
protected Map<Class<?>,AtomicDouble> | currIterRegLoss |
protected Map<String,String> | gradVarToVarMap |
protected List<Listener> | listeners |
protected Map<String,Integer> | lossVarsToLossIdx |
protected Map<String,GradientUpdater> | updaters |
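The loss-related fields suggest how per-iteration loss may be tracked: lossVarsToLossIdx maps each loss variable name to an index, currIterLoss holds one value per loss variable, and currIterRegLoss groups regularization loss by class. The plain-Java sketch below illustrates that bookkeeping under those assumptions only. The class LossBookkeepingSketch, its methods and the L2Reg marker type are hypothetical, and plain Double values stand in for the AtomicDouble used by the real field; this is not the TrainingSession implementation.

```java
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;

/**
 * Hypothetical sketch of per-iteration loss bookkeeping, loosely modelled on the
 * lossVarsToLossIdx / currIterLoss / currIterRegLoss field names.
 * Not the actual TrainingSession code.
 */
public class LossBookkeepingSketch {
    // Assumption: each loss variable name maps to a slot in currIterLoss
    private final Map<String, Integer> lossVarsToLossIdx = new HashMap<>();
    // One accumulated loss value per loss variable for the current iteration
    private final double[] currIterLoss;
    // Assumption: regularization loss is grouped by regularization class
    private final Map<Class<?>, Double> currIterRegLoss = new LinkedHashMap<>();

    public LossBookkeepingSketch(List<String> lossVariables) {
        for (int i = 0; i < lossVariables.size(); i++) {
            lossVarsToLossIdx.put(lossVariables.get(i), i);
        }
        currIterLoss = new double[lossVariables.size()];
    }

    /** Adds the scalar value computed for one loss variable to its slot. */
    public void recordLoss(String lossVar, double value) {
        Integer idx = lossVarsToLossIdx.get(lossVar);
        if (idx == null) {
            throw new IllegalArgumentException("Not a loss variable: " + lossVar);
        }
        currIterLoss[idx] += value;
    }

    /** Accumulates regularization loss for one regularization type. */
    public void recordRegLoss(Class<?> regType, double value) {
        currIterRegLoss.merge(regType, value, Double::sum);
    }

    /** Returns the accumulated loss for a single loss variable. */
    public double lossFor(String lossVar) {
        return currIterLoss[lossVarsToLossIdx.get(lossVar)];
    }

    /** Total = sum over all loss variables plus all regularization terms. */
    public double totalLoss() {
        double total = 0.0;
        for (double l : currIterLoss) total += l;
        for (double r : currIterRegLoss.values()) total += r;
        return total;
    }

    // Hypothetical marker type used only to key the regularization map in this demo
    static final class L2Reg {}

    public static void main(String[] args) {
        LossBookkeepingSketch s = new LossBookkeepingSketch(List.of("mse", "xent"));
        s.recordLoss("mse", 0.5);
        s.recordLoss("xent", 1.25);
        s.recordRegLoss(L2Reg.class, 0.125);
        s.recordRegLoss(L2Reg.class, 0.125);
        System.out.println(s.totalLoss()); // prints 2.0
    }
}
```

Keeping losses in an index-addressed array makes it cheap to report each loss variable separately (for example to listeners) while still summing them into one total.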
Fields inherited from class InferenceSession: KERAS_TRAIN_TEST
Fields inherited from class AbstractSession: dt, nodeOutputs, OUTER_FRAME, sameDiff, subgraph, subgraphOps, tensorArrays, zeroInputOpsInSubgraph
Constructor and Description |
---|
TrainingSession(SameDiff sameDiff) |
Modifier and Type | Method and Description |
---|---|
INDArray[] | getOutputs(Pair<SameDiffOp,OpContext> opPair, FrameIter outputFrameIter, Set<AbstractSession.VarId> opInputs, Set<AbstractSession.VarId> allIterInputs, Set<String> constAndPhInputs, List<Listener> listeners, At at, MultiDataSet batch, Set<String> allReqVariables): Execute the op, calculating INDArrays, shape information, etc. |
Loss | trainingIteration(TrainingConfig config, Map<String,INDArray> placeholders, Set<String> paramsToTrain, Map<String,GradientUpdater> updaters, MultiDataSet batch, List<String> lossVariables, List<Listener> listeners, At at): Perform one iteration of training, i.e., run the forward and backward passes and update the parameters. |
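trainingIteration is summarized as a forward pass, a backward pass and a parameter update. As a rough illustration of those three steps only, the plain-Java sketch below trains y = w*x + b with a mean-squared-error loss and one updater per parameter name, loosely mirroring the Map<String,GradientUpdater> updaters argument. The Updater interface, the Sgd class and all other names are hypothetical; SameDiff itself operates on INDArrays, derives gradients from its graph, and uses ND4J GradientUpdater instances.

```java
import java.util.HashMap;
import java.util.Map;

/**
 * Hypothetical sketch of one training iteration for y = w * x + b with an MSE loss:
 * forward pass, backward pass, then one updater step per trainable parameter.
 * Plain Java; not the SameDiff implementation.
 */
public class TrainingIterationSketch {

    /** Hypothetical stand-in for a per-parameter gradient updater. */
    interface Updater {
        double apply(double gradient); // returns the amount to subtract from the parameter
    }

    /** Plain stochastic gradient descent: update = learningRate * gradient. */
    static final class Sgd implements Updater {
        private final double learningRate;
        Sgd(double learningRate) { this.learningRate = learningRate; }
        public double apply(double gradient) { return learningRate * gradient; }
    }

    final Map<String, Double> params = new HashMap<>();
    final Map<String, Updater> updaters = new HashMap<>();

    TrainingIterationSketch(double w, double b, double learningRate) {
        params.put("w", w);
        params.put("b", b);
        updaters.put("w", new Sgd(learningRate));
        updaters.put("b", new Sgd(learningRate));
    }

    /** Runs forward + backward + update on one batch; returns the loss before the update. */
    double trainingIteration(double[] x, double[] y) {
        double w = params.get("w");
        double b = params.get("b");
        int n = x.length;
        double loss = 0.0, gradW = 0.0, gradB = 0.0;
        for (int i = 0; i < n; i++) {
            double pred = w * x[i] + b;      // forward pass
            double err = pred - y[i];
            loss += err * err / n;           // mean squared error
            gradW += 2.0 * err * x[i] / n;   // backward pass: dL/dw
            gradB += 2.0 * err / n;          // backward pass: dL/db
        }
        Map<String, Double> grads = Map.of("w", gradW, "b", gradB);
        // Parameter update: each parameter is updated by its own updater
        for (Map.Entry<String, Double> g : grads.entrySet()) {
            String name = g.getKey();
            double update = updaters.get(name).apply(g.getValue());
            params.put(name, params.get(name) - update);
        }
        return loss;
    }

    public static void main(String[] args) {
        double[] x = {0, 1, 2, 3};
        double[] y = {1, 3, 5, 7}; // generated by y = 2x + 1
        TrainingIterationSketch s = new TrainingIterationSketch(0.0, 0.0, 0.05);
        double first = s.trainingIteration(x, y);
        double last = first;
        for (int i = 0; i < 500; i++) last = s.trainingIteration(x, y);
        System.out.printf("first loss %.3f, last loss %.6f%n", first, last);
    }
}
```

Starting from w = 0 and b = 0, the first call returns a loss of 21.0, and repeated iterations drive w and b toward 2 and 1. Keying updaters by parameter name lets stateful updaters (momentum, Adam and similar) keep separate state for each parameter.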
Methods inherited from class InferenceSession: doExec, getAndParameterizeOp, getArray, getConstantOrVariable, getOutputsHelperTensorArrayOps, postProcessOutput, preprocessPlaceholders
Methods inherited from class AbstractSession: addDependenciesForOp, addVarControlDeps, contains, execFailed, get, get, getExecStepForVar, initSubgraph, lookup, lookup, output, updateDescendantDeps
protected TrainingConfig config
protected Map<String,GradientUpdater> updaters
protected double[] currIterLoss
protected Map<Class<?>,AtomicDouble> currIterRegLoss
public TrainingSession(SameDiff sameDiff)
public Loss trainingIteration(TrainingConfig config, Map<String,INDArray> placeholders, Set<String> paramsToTrain, Map<String,GradientUpdater> updaters, MultiDataSet batch, List<String> lossVariables, List<Listener> listeners, At at)
Parameters:
config - Training configuration
placeholders - Current placeholders
paramsToTrain - Set of parameters that will be trained
updaters - Current updater state
batch - Current data/batch (mainly for listeners; should already have been converted to the placeholders map)
lossVariables - Loss variables (names)
listeners - Listeners (if any)
at - Current epoch, iteration, etc.
public INDArray[] getOutputs(Pair<SameDiffOp,OpContext> opPair, FrameIter outputFrameIter, Set<AbstractSession.VarId> opInputs, Set<AbstractSession.VarId> allIterInputs, Set<String> constAndPhInputs, List<Listener> listeners, At at, MultiDataSet batch, Set<String> allReqVariables)
Description copied from class: AbstractSession
Overrides: getOutputs in class InferenceSession
Parameters:
opPair - Operation to execute. This should be parameterized (i.e., all inputs set)
outputFrameIter - The frame and iteration of the outputs
opInputs - The specific input arrays for the op
allReqVariables - All required variables requested for the current session execution (not just the current op outputs)
Copyright © 2020. All rights reserved.