Represents a distribution context (e.g., in-tower or cross-tower).
Distribution contexts are used to constrain what actions are allowed at each point in the code, and to enforce those constraints at compile time, using implicits. More specifically, the current distribution context is always available implicitly and can be checked (e.g., some methods require an implicit cross-tower context and will not compile if the current context is an in-tower context).
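For instance, a method that must run in a cross-tower context can take that context as an implicit parameter, so that calling it where only an in-tower context is available fails to compile. A minimal sketch (reduceAcrossTowers and its body are made up for illustration, not part of the API):

    // Compiles only at call sites where an implicit CrossTowerContext is in scope.
    def reduceAcrossTowers(value: Output)(implicit context: CrossTowerContext): Output = {
      // A real implementation would perform a cross-tower reduction here.
      value
    }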
As an example of how contexts switch, consider the following sequence of execution steps:

1. You start in the default in-tower context (i.e., the default "single-tower context").
2. distributionStrategy.scope switches to a cross-tower context.
3. forEachTower(fn, ...) switches to an in-tower context (i.e., the code in fn will have an in-tower implicit context available).
4. If fn calls currentTowerContext.mergeCall(mergeFn, ...), then inside mergeFn, a cross-tower context will again be implicitly available.

Note that you can also go directly from step 1 to step 4, to switch to a cross-tower context for the default distribution strategy. You may also switch from the cross-tower context of step 4 to an in-tower context by calling forEachTower(), jumping back to step 3.
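Put together, these switches look roughly as follows in code (a sketch only: computePerTowerLoss and logLoss are made-up helpers, and the lambda shapes are illustrative rather than the exact signatures):

    // Step 1: We start in the default in-tower context.
    d.scope {
      // Step 2: A cross-tower context is now implicitly available.
      d.forEachTower(() => {
        // Step 3: An in-tower context is implicitly available within the tower function.
        val loss = computePerTowerLoss()
        currentTowerContext.mergeCall(() => {
          // Step 4: A cross-tower context is again implicitly available here.
          logLoss(loss)
        })
      })
    }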
Most distribution API methods may only be executed in cross-tower contexts.
Represents a list of devices with a state and a compute distribution policy.
The intent is that you can write an algorithm in a stylized way and it will be usable with a variety of different DistributionStrategy implementations. Each descendant will implement a different strategy for distributing the algorithm across multiple devices/machines. Furthermore, these changes can be hidden inside the specific layers and other library classes that need special treatment to run in a distributed setting, so that most users' model definition code can run unchanged.
First, let's introduce a few high-level concepts. To distribute an algorithm, we might use some of these ingredients:

- Allreduce: An algorithm for performing a reduction on values from multiple devices and making the result available on all of those devices.

We have a few approaches that we want to support:
- Code that is written without any knowledge of DistributionStrategy. This code should work as before, even if some of the layers, etc., used by that code are written to be distribution-aware. This is done by having a default DistributionStrategy that gives ordinary behavior, and by default being in a single-tower context.
- Ordinary model code that you want to run using a specific DistributionStrategy. This can be as simple as:

      d.scope {
        val iterator = d.distributeDataset(dataset)
        val towerTrainOps = d.forEachTower(towerFn, iterator.next())
        val trainOp = tf.group(d.unwrap(towerTrainOps))
      }
  This takes an ordinary dataset and tower function towerFn, and runs them distributed using a particular DistributionStrategy in d. Any variables created in towerFn are created using d's policy, and library functions called by towerFn can use the currentTowerContext API to get enhanced behavior in this case (see the sketch after this list). Note that, in the future, we will add support for initializable dataset iterators, at which point this example code will change.
- If you want to write a distributed algorithm, you may use any of the DistributionStrategy APIs inside a d.scope block of code.
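For instance, a tower function might look roughly like the following (a sketch only: the tf.variable arguments, computeLoss, and the optimizer are assumptions for illustration):

    // A hypothetical tower function: it receives one (unwrapped) input batch and
    // returns this tower's training op. Variables created inside it are created
    // according to d's policy (e.g., as mirrored variables).
    def towerFn(batch: Output): Op = {
      val w = tf.variable("w", FLOAT32, Shape(10, 1))  // Created using d's policy.
      val loss = computeLoss(batch, w)                 // Made-up per-tower loss computation.
      optimizer.minimize(loss)                         // This tower's train op.
    }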
Lower-level concepts:

- Wrapped values: In order to represent values parallel across devices, we wrap them in a PerDevice or Mirrored object that contains a map from device to values. PerDevice is used when the value may be different across devices, and Mirrored is used when the value is the same across devices.
- Unwrapping and merging: Consider calling a function fn on multiple devices, like forEachTower(fn, w), with an argument w that is a wrapped value. This means that w will have a map taking tower device d0 to w0, tower device d1 to w1, etc. forEachTower() unwraps w before calling fn, and so it calls fn(w0) on d0, fn(w1) on d1, etc. It then merges the return values from fn(), which can possibly result in wrapped values. For example, let's say fn() returns a tuple with three components: (x, a, v0) from tower 0, (x, b, v1) from tower 1, etc. If the first component is the same object x for every tower, then the first component of the merged result will also be x. If the second component is different (a, b, ...) for each tower, then the merged value will have a wrapped map from tower device to the different values. If the third component consists of the members of a mirrored variable (v maps d0 to v0, d1 to v1, etc.), then the merged result will be that mirrored variable (i.e., v). (See the sketch after this list.)
- In-tower contexts vs. cross-tower contexts: An in-tower context applies when we are in some function that is being called once for each tower. Otherwise, we are in a cross-tower context, which is useful for calling DistributionStrategy methods that operate across towers (like reduce()). By default, you start in a tower context (the default "single-tower context") and then some methods can switch you back and forth, as described below.
- Non-slot devices: If you have a variable that you want to colocate all of the non-slot variables with, you can use colocateVariablesWith() to get the remaining non-slot variables on the same device. Otherwise, you can use nonSlotDevices() to pick a consistent set of devices to pass to both colocateVariablesWith() and updateNonSlot().
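The following sketch restates the unwrapping and merging example from the list above in code form (fn, w, and the return values are hypothetical):

    // w is a PerDevice value mapping tower device d0 to w0 and d1 to w1.
    val merged = d.forEachTower(fn, w)
    // forEachTower unwraps w, calling fn(w0) on d0 and fn(w1) on d1. If fn returns
    // (x, a, v0) on tower 0 and (x, b, v1) on tower 1, the merged result is:
    //   - First component:  x, because it is the same object on every tower.
    //   - Second component: a PerDevice map from d0 to a and d1 to b.
    //   - Third component:  the mirrored variable v, whose members are v0 and v1.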
When using a DistributionStrategy, we have a new type dimension called locality, which says what values are compatible with which APIs:

- T: A different value for each tower (e.g., a PerDevice-wrapped value).
- M: A value that is "mirrored" across towers. That is, there are copies with the same value on each tower (e.g., a Mirrored-wrapped value).
- V(v): A value that is "mirrored" across all the devices which have a copy of variable v (also a Mirrored-wrapped value, but over parameter devices instead of worker devices).
- N: A value that is "mirrored" across all of the "non-slot" devices.
Rules for methods with respect to locality and single-tower vs. cross-tower context:

- d.scope(): Default single-tower context -> cross-tower context for d.
- d.colocateVariablesWith(v): In a tower or cross-tower context, variables will be created with locality V(v). That is, if we write d.colocateVariablesWith(v1) { val v2 = tf.variable(...) }, then v2 will have locality V(v1) (i.e., locality V(v2) will equal V(v1)). (See the sketch after this list.)
- d.colocateVariablesWith(d.nonSlotDevices(...)): In a tower or cross-tower context, variables will be created with locality N.
- v = tf.variable(...): In a tower or cross-tower context, creates a variable (which by definition will have locality V(v), though it will match another locality if created inside a colocateVariablesWith() scope).
- d.distributeDataset(dataset): In a cross-tower context, produces an iterator with locality T.
- d.broadcast(t): In a cross-tower context, produces a value with locality M.
- d.broadcast(t, v): In a cross-tower context, produces a value with locality V(v).
- d.forEachTower(fn, ...): In a cross-tower context, runs fn() in a tower context (and so fn may call currentTowerContext and use its API, including mergeCall() to get back to a cross-tower context), once for each tower. May use values with locality T or M, and any variable.
- d.reduce(m, t): In a cross-tower context, accepts t with locality T and produces a value with locality M.
- d.reduce(m, t, v): In a cross-tower context, accepts t with locality T and produces a value with locality V(v).
- d.batchReduce(m, Seq((t, v))): See d.reduce().
- d.update(v, fn, ...): In a cross-tower context, runs fn() once for each device that v is copied to. All inputs should have locality V(v), and the output will have locality V(v) as well.
- d.updateNonSlot(d.nonSlotDevices(), fn): In a cross-tower context, like d.update() except with locality N.
- d.fetch(t): Copies t, with any locality, to the client's CPU device.
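As a hedged illustration of the variable-creation rules above (the variable names, data types, and shapes are made up):

    d.scope {
      val v1 = tf.variable("v1", FLOAT32, Shape(4))    // Locality V(v1).
      d.colocateVariablesWith(v1) {
        val v2 = tf.variable("v2", FLOAT32, Shape(4))  // Locality V(v2) == V(v1).
      }
      d.colocateVariablesWith(d.nonSlotDevices()) {
        val v3 = tf.variable("v3", FLOAT32, Shape(4))  // Locality N.
      }
    }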
The standard pattern for updating variables is to:

1. Wrap your input dataset with d.distributeDataset().
2. Define the computation of each tower with d.forEachTower(), up to the point of obtaining a list of (gradient, variable) pairs.
3. Call d.reduce("sum", t, v) or d.batchReduce() to sum the gradients (with locality T) into values with locality V(v).
4. Call d.update(v) for each variable to update its value.

Steps 3 and 4 are done automatically by the Optimizer class if you call its applyGradients method from within a tower context. Otherwise, you can manually call its distributedApply method in a cross-tower context.
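A sketch of this pattern, letting Optimizer.applyGradients perform steps 3 and 4 (computeLoss, the optimizer construction, and the computeGradients call are assumptions for illustration):

    d.scope {
      val iterator = d.distributeDataset(dataset)                    // Step 1.
      def towerFn(batch: Output): Op = {
        val loss = computeLoss(batch)                                // Hypothetical loss.
        // Called from a tower context, applyGradients performs steps 3 and 4.
        optimizer.applyGradients(optimizer.computeGradients(loss))
      }
      val towerTrainOps = d.forEachTower(towerFn, iterator.next())   // Step 2.
      val trainOp = tf.group(d.unwrap(towerTrainOps))
    }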
Another thing you might want to do in the middle of your tower function is an all-reduce of some intermediate value, using d.reduce() or d.batchReduce() without supplying a variable as the destination.
Layers should expect to be called in a tower context, and can use the currentTowerContext function to get a TowerContext object. The TowerContext object has a mergeCall() method for entering a cross-tower context, where you can use reduce() (or batchReduce()) and then optionally update() to update state.
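For example, a layer might aggregate an intermediate value across towers roughly as follows (a sketch: the mergeCall argument shapes, the availability of d, and the "mean" reduction are all assumptions):

    // Runs in a tower context, e.g., inside a layer's forward pass:
    val context = currentTowerContext
    val aggregated = context.mergeCall((value: PerDevice[Output]) => {
      // We are now in a cross-tower context and may reduce across towers.
      d.reduce("mean", value)
    }, perTowerValue)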
You may use this API whether or not a DistributionStrategy is being used, since there is a default implementation of TowerContext and DistributionStrategy. Alternatively, you can use the currentTowerContext.isSingleTower property to run different code in the distributed and single-tower cases.
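For instance (a minimal sketch; both branch bodies are made-up placeholders):

    if (currentTowerContext.isSingleTower) {
      buildSimpleSummaries()    // Cheaper code path for the non-distributed case.
    } else {
      buildPerTowerSummaries()  // Distribution-aware variant.
    }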
Represents an in-tower context (as opposed to a cross-tower context).
This context is only present during a forEachTower() call (except during a mergeCall() call), and in such a scope it will be implicitly available.
Distribution strategy.
ID of the tower that is being defined, which is a number in [0, numTowers - 1].
Represents a cross-tower context (as opposed to an in-tower context).
This context typically becomes available during a distributionStrategy.scope call. That call also sets up a new variable scope that changes variable creation calls (e.g., to make mirrored variables). This is intended as an outer scope that users enter once, around their model creation and graph definition.

Distribution strategy.