Built session scaffold.
Session factory for CHIEFs.
TensorFlow master to use.
Session scaffold used for gathering and/or building supportive ops. If not specified, a default one is created. The session scaffold is used to finalize the graph.
Session configuration to be used for the new sessions.
Path to either a checkpoint file to restore the model from, or a directory containing multiple checkpoint files, in which case the latest checkpoint in the directory will be used.
Represents a gradient-clipping method that can be used while training.
Clips the gradients using the clipByAverageNorm op.
Maximum average norm clipping value (must be > 0).
Clips the gradients using the clipByGlobalNorm op.
Maximum norm clipping value (must be > 0).
Clips the gradients using the clipByNorm op.
Maximum norm clipping value (must be > 0).
Clips the gradients using the clipByValue op.
Minimum value to clip by.
Maximum value to clip by.
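The arithmetic behind these clipping methods can be sketched on plain Scala collections. This is an illustrative sketch only: the object and method names below are hypothetical, and the real library implements these as TensorFlow ops, not Scala collection code.

```scala
// Illustrative sketch of the clipping semantics on plain Double sequences.
object ClippingSketch {
  // clipByValue: clamp every gradient element into [minValue, maxValue].
  def clipByValue(g: Seq[Double], minValue: Double, maxValue: Double): Seq[Double] =
    g.map(x => math.min(math.max(x, minValue), maxValue))

  // clipByNorm: if the L2 norm of the gradient exceeds clipNorm, rescale the
  // gradient so that its norm equals clipNorm; otherwise leave it unchanged.
  def clipByNorm(g: Seq[Double], clipNorm: Double): Seq[Double] = {
    val norm = math.sqrt(g.map(x => x * x).sum)
    if (norm > clipNorm) g.map(_ * clipNorm / norm) else g
  }

  // clipByAverageNorm: same idea, but using the L2 norm divided by the number
  // of elements.
  def clipByAverageNorm(g: Seq[Double], clipNorm: Double): Seq[Double] = {
    val averageNorm = math.sqrt(g.map(x => x * x).sum) / g.size
    if (averageNorm > clipNorm) g.map(_ * clipNorm / averageNorm) else g
  }

  // clipByGlobalNorm: the norm is computed jointly over all gradient tensors,
  // and all of them are rescaled by the same factor.
  def clipByGlobalNorm(gs: Seq[Seq[Double]], clipNorm: Double): Seq[Seq[Double]] = {
    val globalNorm = math.sqrt(gs.map(_.map(x => x * x).sum).sum)
    if (globalNorm > clipNorm) gs.map(_.map(_ * clipNorm / globalNorm)) else gs
  }
}
```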
Configuration for models in the learn API, to be used by estimators.
If clusterConfig is not provided, then all distributed training related properties are set based on the TF_CONFIG environment variable, if the pertinent information is present. The TF_CONFIG environment variable is a JSON object with attributes: cluster and task.

cluster is a JSON-serialized version of ClusterConfig, mapping task types (usually one of the instances of TaskType) to a list of task addresses.

task has two attributes: type and index, where type can be any of the task types in cluster. When TF_CONFIG contains said information, the following properties are set on this class:

- clusterConfig is parsed from TF_CONFIG['cluster']. Defaults to None. If present, it must have one and only one node for the chief job (i.e., CHIEF task type).
- taskType is set to TF_CONFIG['task']['type']. Must be set if clusterConfig is present; must be worker (the default value), if it is not.
- taskIndex is set to TF_CONFIG['task']['index']. Must be set if clusterConfig is present; must be 0 (the default value), if it is not.
- master is determined by looking up taskType and taskIndex in the clusterConfig. Defaults to "".
- numParameterServers is set by counting the number of nodes listed in the ps job (i.e., PARAMETER_SERVER task type) of clusterConfig. Defaults to 0.
- numWorkers is set by counting the number of nodes listed in the worker and chief jobs (i.e., WORKER and CHIEF task types) of clusterConfig. Defaults to 1.
- isChief is determined based on taskType and TF_CONFIG['cluster'].

There is a special node with taskType set as EVALUATOR, which is not part of the (training) clusterConfig. It handles the distributed evaluation job.
Example for a non-chief node:
// The TF_CONFIG environment variable contains:
// {
//   "cluster": {
//     "chief": ["host0:2222"],
//     "ps": ["host1:2222", "host2:2222"],
//     "worker": ["host3:2222", "host4:2222", "host5:2222"]},
//   "task": {
//     "type": "worker",
//     "index": 1}
// }
val config = Configuration()
assert(config.clusterConfig == Some(ClusterConfig(Map(
  "chief" -> JobConfig.fromAddresses("host0:2222"),
  "ps" -> JobConfig.fromAddresses("host1:2222", "host2:2222"),
  "worker" -> JobConfig.fromAddresses("host3:2222", "host4:2222", "host5:2222")))))
assert(config.taskType == "worker")
assert(config.taskIndex == 1)
assert(config.master == "host4:2222")
assert(config.numParameterServers == 2)
assert(config.numWorkers == 4)
assert(!config.isChief)
Example for a chief node:
// The TF_CONFIG environment variable contains:
// {
//   "cluster": {
//     "chief": ["host0:2222"],
//     "ps": ["host1:2222", "host2:2222"],
//     "worker": ["host3:2222", "host4:2222", "host5:2222"]},
//   "task": {
//     "type": "chief",
//     "index": 0}
// }
val config = Configuration()
assert(config.clusterConfig == Some(ClusterConfig(Map(
  "chief" -> JobConfig.fromAddresses("host0:2222"),
  "ps" -> JobConfig.fromAddresses("host1:2222", "host2:2222"),
  "worker" -> JobConfig.fromAddresses("host3:2222", "host4:2222", "host5:2222")))))
assert(config.taskType == "chief")
assert(config.taskIndex == 0)
assert(config.master == "host0:2222")
assert(config.numParameterServers == 2)
assert(config.numWorkers == 4)
assert(config.isChief)
Example for an evaluator node (an evaluator is not part of the training cluster):
// The TF_CONFIG environment variable contains:
// {
//   "cluster": {
//     "chief": ["host0:2222"],
//     "ps": ["host1:2222", "host2:2222"],
//     "worker": ["host3:2222", "host4:2222", "host5:2222"]},
//   "task": {
//     "type": "evaluator",
//     "index": 0}
// }
val config = Configuration()
assert(config.clusterConfig == None)
assert(config.taskType == "evaluator")
assert(config.taskIndex == 0)
assert(config.master == "")
assert(config.numParameterServers == 0)
assert(config.numWorkers == 0)
assert(!config.isChief)
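The derivation of these properties from a parsed cluster map can be sketched as follows. DerivedConfig and derive are hypothetical stand-ins, not the library's Configuration internals; the real class also handles the missing-TF_CONFIG and evaluator cases, where no cluster map exists at all.

```scala
// Illustrative sketch of how the derived properties are computed once
// TF_CONFIG['cluster'] and TF_CONFIG['task'] have been parsed.
final case class DerivedConfig(
    master: String,
    numParameterServers: Int,
    numWorkers: Int,
    isChief: Boolean)

def derive(
    cluster: Map[String, Seq[String]],
    taskType: String,
    taskIndex: Int
): DerivedConfig = {
  // The master is found by looking up the task type and index in the cluster.
  val master = cluster.get(taskType).flatMap(_.lift(taskIndex)).getOrElse("")
  DerivedConfig(
    master = master,
    // Count the nodes of the "ps" job.
    numParameterServers = cluster.getOrElse("ps", Seq.empty).size,
    // Count the nodes of the "chief" and "worker" jobs.
    numWorkers = cluster.getOrElse("chief", Seq.empty).size +
        cluster.getOrElse("worker", Seq.empty).size,
    isChief = taskType == "chief")
}
```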
NOTE: If a checkpointConfig is set, maxCheckpointsToKeep might need to be adjusted accordingly, especially in distributed training. For example, using TimeBasedCheckpoints(60) without adjusting maxCheckpointsToKeep (which defaults to 5) leads to a situation where checkpoints would be garbage collected after 5 minutes. In distributed training, the evaluation job starts asynchronously and might fail to load or find the checkpoints due to a race condition.
Directory used to save model parameters, graph, etc. It can also be used to load checkpoints for a previously saved model. If null, a temporary directory will be used.
Configuration to use for the created sessions.
Configuration specifying when to save checkpoints.
Random seed value to be used by the TensorFlow initializers. Setting this value allows consistency between re-runs.
Represents the mode that a model is in while being used by a learner (e.g., training mode, evaluation mode, or prediction mode).
Represents an instance of a constructed model.
Represents an instance of a constructed model. Such instances are constructed by estimators and passed on to model-dependent hooks.
Session wrapper that handles initialization, recovery, and hooks.
Example usage:
val stopAtStepHook = StopAtStepHook(5, true)
val session = MonitoredSession(ChiefSessionCreator(...), Seq(stopAtStepHook))
while (!session.shouldStop)
  session.run(...)
Initialization: At creation time, the monitored session does the following things, in the presented order:

- Hook.begin() for each hook.
- SessionScaffold.build().
- Hook.afterSessionCreation() for each hook.

Run: When MonitoredSession.run() is called, the monitored session does the following things, in the presented order:

- Hook.beforeSessionRun() for each hook.
- Session.run() with the combined feeds, fetches, and targets (i.e., user-provided and hook-provided).
- Hook.afterSessionRun() for each hook, with the values returned by the Session.run() call.

Exit: When MonitoredSession.close() is called, the monitored session does the following things, in the presented order:

- Hook.end() for each hook, if no exception has been thrown (other than AbortedException, or UnavailableException).

Creating MonitoredSessions: In most cases you can set the constructor arguments as follows:
MonitoredSession(ChiefSessionCreator(master = ..., sessionConfig = ...))
In a distributed setting for a non-chief worker, you can use the following:
MonitoredSession(WorkerSessionCreator(master = ..., sessionConfig = ...))
See MonitoredTrainingSession for an example usage based on chief or worker.
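The hook call ordering described above can be sketched with toy classes. ToyHook and ToySession below are hypothetical and only demonstrate the ordering; the real Hook and MonitoredSession APIs carry run contexts, feeds, and fetches.

```scala
// Toy hook with the same lifecycle callbacks, all defaulting to no-ops.
trait ToyHook {
  def begin(): Unit = ()
  def afterSessionCreation(): Unit = ()
  def beforeSessionRun(): Unit = ()
  def afterSessionRun(): Unit = ()
  def end(): Unit = ()
}

// Toy session demonstrating when each callback fires.
final class ToySession(hooks: Seq[ToyHook]) {
  // Initialization: begin() for each hook, then afterSessionCreation().
  hooks.foreach(_.begin())
  hooks.foreach(_.afterSessionCreation())

  def run(body: => Unit): Unit = {
    hooks.foreach(_.beforeSessionRun())  // Before the underlying run call.
    body                                 // The combined Session.run() call.
    hooks.foreach(_.afterSessionRun())   // After the underlying run call.
  }

  def close(): Unit = hooks.foreach(_.end())
}
```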
Session wrapper that recreates a session upon certain kinds of errors.
The constructor is passed a SessionCreator object, not a Session.
Calls to run() are delegated to the wrapped session. If a call throws an AbortedException or an UnavailableException, the wrapped session is closed, and a new one is created by invoking the session creator.
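The recovery loop can be sketched as follows, with toy stand-ins: RetriableError plays the role of AbortedException/UnavailableException, and ToyRecoverableSession is hypothetical, not the library's actual class.

```scala
// Toy stand-in for the retriable TensorFlow errors.
class RetriableError extends RuntimeException

// Sketch of the recovery behavior: run() retries with a freshly created
// session whenever a retriable error occurs.
final class ToyRecoverableSession[S](create: () => S, runOnce: S => Int) {
  private var session: S = create()

  def run(): Int = {
    while (true) {
      try {
        return runOnce(session)  // Delegate to the wrapped session.
      } catch {
        case _: RetriableError =>
          session = create()  // Discard the old session; create a new one.
      }
    }
    throw new IllegalStateException("unreachable")
  }
}
```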
Factory for Sessions.
Structure used to create or gather pieces commonly needed to train a model.
When you build a model for training you usually need ops to initialize variables, a Saver to checkpoint them, an op to collect summaries for the visualizer, and so on.
Various libraries built on top of the core TensorFlow library take care of creating some or all of these pieces and storing them in well-known collections in the graph. The SessionScaffold class helps pick these pieces from the graph collections, create them, and/or add them to the graph collections if needed.
If you call the scaffold constructor without any arguments, it will pick pieces from the graph collections, creating default ones if needed, when SessionScaffold.build() is called. You can pass arguments to the constructor to provide your own pieces. Pieces that you pass to the constructor are not added to the graph collections.
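The pick-or-create behavior can be sketched with a mutable map standing in for graph collections. The resolve helper below is hypothetical, not library API; it only illustrates that provided pieces bypass the collections, while missing pieces are built from defaults and stored.

```scala
import scala.collection.mutable

// If a piece was provided explicitly, use it without touching the collection.
// Otherwise, pick it from the collection, creating and storing a default one
// if it is missing.
def resolve[T](
    provided: Option[T],
    collection: mutable.Map[String, T],
    key: String
)(default: => T): T =
  provided.getOrElse(collection.getOrElseUpdate(key, default))
```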
Output used to verify that the variables are initialized. Picked from and stored into the READY_OP graph collection by default.
Output used to verify that global state has been initialized and it is fine to execute localInitOp. Picked from and stored into the READY_FOR_LOCAL_INIT_OP graph collection by default.
Op used to initialize the variables. Picked from and stored into the INIT_OP graph collection by default.
Feed map that will be used when executing initOp.
Function to run after the init op to perform additional initializations.
Op used to initialize the local variables. Picked from and stored into the LOCAL_INIT_OP graph collection by default.
Function to run after the local init op to perform additional initializations.
Output used to merge the summaries in the graph. Picked from and stored into the SUMMARY_OP graph collection by default.
Saver object taking care of saving the variables. Picked from and stored into the SAVERS graph collection by default.
Wrapper around a Session that invokes Hook callbacks before and after calls to Session.run().
This wrapper is used as a base class for various session wrappers that provide additional functionality such as monitoring and recovery.
In addition to the methods provided by Session, the wrapper provides a method to check for requested stops, and it never rethrows exceptions thrown by calls to Session.close().
The list of hooks to call is passed in the constructor. Before each call to Session.run(), the session calls the Hook.beforeSessionRun() method of each hook, which can return additional ops or tensors to run. These are added to the arguments of the call to Session.run().
When the Session.run() call finishes, the session invokes the Hook.afterSessionRun() method of each hook, passing the values returned by the Session.run() call corresponding to the ops and tensors that each hook requested.
If any call to the hooks requests a stop via the runContext, the session will be marked as needing to stop and its shouldStop() method will then return true.
Criteria used to stop the training process iteration.
Session factory for WORKERs.
TensorFlow master to use.
Session scaffold used for gathering and/or building supportive ops. If not specified, a default one is created. The session scaffold is used to finalize the graph.
Session configuration to be used for the new sessions.
Contains helper methods for dealing with Configurations.
Contains helper methods for creating and obtaining counter variables (e.g., epoch or global iteration).
Contains helper methods for creating monitored sessions.
Represents no clipping of the gradients (i.e., identity operation).
Contains helper methods used internally by recoverable sessions.