org.platanios.tensorflow.api.ops.training.distribute.strategies
Mirrors value to all worker devices.
Value to broadcast.
Destination devices.
Mirrored value.
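For illustration, a minimal usage sketch, assuming broadcast is called from a cross-tower context with a value and optional destination devices; the exact signature may differ:

distributionStrategy.scope {
  // Mirror a constant onto all worker devices.
  val mirroredValue = distributionStrategy.broadcast(tf.constant(3.0f))
}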
Returns a copy of fn(variable.value) on destination. This is useful for getting a mirrored variable value onto a device. The method will attempt to avoid a copy by checking if the value is already on the destination device.
Variable (which may be mirrored) to copy and fetch.
Device to copy the variable value to.
Optional function to apply to the value on the source device, before copying.
Fetched value on the destination device.
Throws InvalidArgumentException if there is an issue with the provided variable.
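For illustration, a hedged sketch of fetching a (possibly mirrored) variable value onto the client's CPU device; the variable name, device string, identity mapping function, and argument order are illustrative:

// Copy the current value of `counter` to the CPU without transforming it.
val fetched = distributionStrategy.fetch(counter, "/device:CPU:0", v => v)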
Runs fn once per tower.

fn may call tf.currentTowerContext to access fields and methods such as towerID and mergeCall(). mergeCall() is used to communicate between the towers and re-enter the cross-tower context. All towers pause their execution when they encounter a mergeCall() call. After that, the mergeFn function is executed. Its results are then unwrapped and given back to each tower call, and execution resumes until fn is complete or another mergeCall() is encountered.

For example:
// Called once in "cross-tower" context.
def mergeFn(distributionStrategy: DistributionStrategy, threePlusTowerID: Int): tf.Output = {
  // Sum the values across towers.
  tf.addN(distributionStrategy.unwrap(threePlusTowerID))
}

// Called once per tower in `distributionStrategy`, in a "tower" context.
def fn(three: Int): Output = {
  val towerContext = tf.currentTowerContext
  val v = three + towerContext.towerID
  // Computes the sum of the `v` values across all towers.
  val s = towerContext.mergeCall(mergeFn(_, v))
  s + v
}

distributionStrategy.scope {
  // In "cross-tower" context...
  val mergedResults = distributionStrategy.forEachTower(() => fn(3))
  // `mergedResults` has the values from every tower execution of `fn`.
  val resultsList = distributionStrategy.unwrap(mergedResults)
}
Function that will be run once per tower.
Wrapped values that will be unwrapped when invoking fn on each tower.
Merged return value of fn across all towers.
Returns true if there is only a single tower, and false otherwise.

If true, forEachTower(fn) will only call fn once. If false, forEachTower(fn) may call fn multiple times.
Returns the devices used for non-slot variables.

Create variables on these devices in a colocateVariablesWith(nonSlotDevices(...)) block, and then update them using updateNonSlot().
Variables being optimized.
Colocation ops for non-slot variables.
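For illustration, a hedged sketch of creating a non-slot variable (e.g., a global step) on these devices; the variable and the set of optimized variables are illustrative:

distributionStrategy.scope {
  // Create the global step on the non-slot devices, so that it can later be
  // updated with `updateNonSlot()` on the same set of devices.
  distributionStrategy.colocateVariablesWith(distributionStrategy.nonSlotDevices(trainableVariables)) {
    val globalStep = tf.variable(...)
  }
}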
Returns the number of towers, for purposes of averaging across towers.
Returns the devices used for variable and update placement.
Combines values across towers into one value.
Reduction method to use.
Value to reduce.
Optional destination on which to copy the reduced value.
Reduced value.
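For illustration, a hedged sketch of summing a per-tower loss into a single value, using the "sum" reduction-method literal shown elsewhere in this documentation; computeLoss is illustrative and exact signatures may differ:

distributionStrategy.scope {
  // In cross-tower context: `perTowerLoss` holds one value per tower (locality T).
  val perTowerLoss = distributionStrategy.forEachTower(() => computeLoss())
  // Reduce it into a single mirrored value (locality M).
  val totalLoss = distributionStrategy.reduce("sum", perTowerLoss)
}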
Returns the list of all per-device values contained in value.

A value returned by forEachTower(), or a variable created in scope.
Sequence of values contained in value.
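For illustration, a small hedged sketch; fn is illustrative:

distributionStrategy.scope {
  val mergedResults = distributionStrategy.forEachTower(() => fn(3))
  // One entry per device on which `fn` was executed.
  val perDeviceResults = distributionStrategy.unwrap(mergedResults)
}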
Runs fn to update variable using inputs mirrored to the same devices.

If variable is mirrored across multiple devices, then this method implements logic like:
val results = variable.index.map { case (deviceSpec, variable) =>
  tf.createWith(device = deviceSpec.toString) {
    fn(variable)
  }
}
merged(results)
Otherwise, this returns fn(variable), colocated with variable.
Variable to update.
Update function to use.
Mirrored arguments that should be passed to fn.
Merged return value of fn across all towers.
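For illustration, a hedged sketch of applying an already-reduced gradient to a (possibly mirrored) variable; the gradient value and the plain gradient-descent step are illustrative, and exact signatures may differ:

// In cross-tower context: run the update function on every device holding a copy
// of `variable`, with `summedGradient` mirrored to those same devices.
val updateOp = distributionStrategy.update(
  variable,
  (v: Variable, g: Output) => v.assignAdd(-g),
  summedGradient)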
Runs fn on the devices specified by colocateWith, with the provided arguments.

Destination on which to execute fn.
Function to use for the update.
Mirrored arguments that should be passed to fn.
Merged return value of fn across all towers.
Throws InvalidArgumentException if the provided colocateWith argument is invalid (e.g., too many devices).
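For illustration, a hedged sketch of incrementing a non-slot variable (e.g., a global step) on the non-slot devices; the variable, the set of optimized variables, and the zero-argument update function are illustrative:

// In cross-tower context: run the increment once on each non-slot device.
val incrementGlobalStep = distributionStrategy.updateNonSlot(
  distributionStrategy.nonSlotDevices(trainableVariables),
  () => globalStep.assignAdd(1))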
Returns a map from worker devices to indices.

TODO: [DISTRIBUTE] Settle on the interface of forEachTower() first.

This map might be passed as an argument to forEachTower(), as in:
distributionStrategy.scope {
  def fn(deviceIndex: Int): Unit = {
    // `fn` is being executed on device `distributionStrategy.workerDevices(deviceIndex)`.
  }

  distributionStrategy.forEachTower(fn, distributionStrategy.workerDeviceIndex)
}
Returns the devices used to run forEachTower() calls.
Combines multiple reduce calls into one for faster execution.
Reduction method to use.
Sequence of pairs, each consisting of a value to reduce and a destination to copy the reduced value to.
Reduced values.
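For illustration, a hedged sketch that fuses two reductions into one call, matching the d.batchReduce(m, Seq((t, v))) form used elsewhere in this documentation; the gradient and variable names are illustrative:

// In cross-tower context: sum each per-tower gradient, placing each result on the
// devices of its corresponding variable.
val reducedGradients = distributionStrategy.batchReduce(
  "sum",
  Seq((gradient1, variable1), (gradient2, variable2)))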
Executes block within a scope that controls which devices variables will be created on.

No operations should be added to the graph inside this scope; it should only be used when creating variables (some implementations work by changing variable creation and others work by using a colocateWith scope). This may only be used inside DistributionStrategy.scope.
For example:
distributionStrategy.scope {
  val variable1 = tf.variable(...)

  distributionStrategy.colocateVariablesWith(Set(variable1.op)) {
    // `variable2` and `variable3` will be created on the same device(s) as `variable1`.
    val variable2 = tf.variable(...)
    val variable3 = tf.variable(...)
  }

  def fn(v1: Variable, v2: Variable, v3: Variable): Unit = {
    // Operates on `v1` from `variable1`, `v2` from `variable2`, and `v3` from `variable3`.
  }

  // `fn` runs on every device `v1` is on, and `v2` and `v3` will be there too.
  distributionStrategy.update(variable1, fn, variable2, variable3)
}
Variables created in block will be on the same set of devices as these ops.
Code block to execute in this scope.
Value returned by block.
Finds and sets the best configuration for the provided TensorFlow session configuration.
Acts as a shortcut for tf.group(distributionStrategy.unwrap(value)).

A value returned by forEachTower(), or a variable created in scope.
Name for the created op.
Grouped unwrapped value.
Merges arguments across towers and runs mergeFn in a cross-tower context.

This allows communication and coordination when there are multiple calls to a model function triggered by a call to forEachTower(modelFn, ...). See MirroredDistribution.forEachTower() for an explanation.
Otherwise, this is equivalent to:
val strategy = tf.distribute.currentStrategy
strategy.scope {
  mergeFn(strategy)
}
Merge function to invoke from within a cross-tower context.
Result of the mergeFn call, except for per-device values which are unpacked.
Executes block within a scope where new variables will not be mirrored.

There will still be one component variable per tower, but there is no requirement that they stay in sync. Instead, when saving them or calling fetch(), we use the value that results when calling reduce() on all the towers' variables. Note that tower-local implies not trainable. Instead, it is expected that each tower will directly update (e.g., using assignAdd()) its local variable instance, but only the aggregated value (accessible using fetch()) will be exported from the model. When it is acceptable to only aggregate on export, we greatly reduce communication overhead by using tower-local variables.
Note that all component variables will be initialized to the same value, using the initialization expression from the first tower. The values will match even if the initialization expression uses random numbers.
Reduction method used to get the value to save when creating checkpoints.
Code block to execute in this scope.
Value returned by block.
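For illustration, a hedged sketch of keeping a per-tower counter that is only aggregated on export; the method name towerLocalVariableScope, the "sum" reduction literal, and the variable are illustrative, since this entry does not name the method explicitly:

distributionStrategy.scope {
  distributionStrategy.towerLocalVariableScope("sum") {
    // One independent component per tower. Each tower calls
    // `numExamples.assignAdd(...)` on its own copy; only the summed value
    // (e.g., obtained via `fetch()`) is exported.
    val numExamples = tf.variable(...)
  }
}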
Represents a list of devices with a state and a compute distribution policy.
The intent is that you can write an algorithm in a stylized way and it will be usable with a variety of different DistributionStrategy implementations. Each descendant will implement a different strategy for distributing the algorithm across multiple devices/machines. Furthermore, these changes can be hidden inside the specific layers and other library classes that need special treatment to run in a distributed setting, so that most users' model definition code can run unchanged.

First, let's introduce a few high-level concepts:

- A tower is one copy of the model, running on one slice of the input data.

To distribute an algorithm, we might use some of these ingredients:

- Allreduce: an algorithm for performing a reduction on values from multiple devices and making the result available on all of those devices.

We have a few approaches we want to support:

- Code written with no knowledge of DistributionStrategy. This code should work as before, even if some of the layers, etc., used by that code are written to be distribution-aware. This is done by having a default DistributionStrategy that gives ordinary behavior, and by default being in a single tower context.
- Ordinary model code that you want to run using a specific DistributionStrategy. This can be as simple as the sketch shown after this list. It takes an ordinary dataset and towerFn and runs it distributed using a particular DistributionStrategy d. Any variables created in towerFn are created using d's policy, and library functions called by towerFn can use the currentTowerContext API to get enhanced behavior in this case. Note that in the future we will add support for initializable dataset iterators, at which point this example code will change.
- Code that uses the DistributionStrategy APIs directly, inside a d.scope block of code.
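For instance, a minimal hedged sketch of such a run (assuming a distributeDataset method and an iterator-style next() accessor; exact signatures may differ):

d.scope {
  // Distribute the input pipeline and build one training op per tower.
  val iterator = d.distributeDataset(dataset)
  val towerTrainOps = d.forEachTower(towerFn, iterator.next())
  val trainOp = tf.group(d.unwrap(towerTrainOps))
}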
Lower-level concepts:

- A PerDevice or Mirrored object contains a map from device to values. PerDevice is used when the value may be different across devices, and Mirrored is used when the value is the same across devices.
- Consider calling a function fn on multiple devices, like forEachTower(fn, w) with an argument w that is a wrapped value. This means that w will have a map taking tower device d0 to w0, tower device d1 to w1, etc. forEachTower() unwraps w before calling fn, and so it calls fn(w0) on d0, fn(w1) on d1, etc. It then merges the return values from fn(), which can possibly result in wrapped values. For example, let's say fn() returns a tuple with three components: (x, a, v0) from tower 0, (x, b, v1) from tower 1, etc. If the first component is the same object x for every tower, then the first component of the merged result will also be x. If the second component is different (a, b, ...) for each tower, then the merged value will have a wrapped map from tower device to the different values. If the third component is the members of a mirrored variable (v maps d0 to v0, d1 to v1, etc.), then the merged result will be that mirrored variable (i.e., v).
- The cross-tower context is useful for calling DistributionStrategy methods which operate across towers (like reduce()). By default you start in a tower context (the default "single tower context") and then some methods can switch you back and forth, as described below.
- You can use colocateVariablesWith() to get the remaining non-slot variables on the same device. Otherwise, you can use nonSlotDevices() to pick a consistent set of devices to pass to both colocateVariablesWith() and updateNonSlot().
When using a DistributionStrategy, we have a new type dimension called locality that says what values are compatible with which APIs:

- T: Different value for each tower (e.g., a PerDevice-wrapped value).
- M: Value is "mirrored" across towers. That is, there are copies with the same value on each tower (e.g., a Mirrored-wrapped value).
- V(v): Value is "mirrored" across all the devices which have a copy of variable v (also a Mirrored-wrapped value, but over parameter devices instead of worker devices).
- N: Value is "mirrored" across all the "non-slot" devices.

Rules for methods with respect to locality and single-tower vs. cross-tower context:
- d.scope(): Default single-tower context -> cross-tower context for d.
- d.colocateVariablesWith(v): In tower/cross-tower context, variables will be created with locality V(v). That is, if we write d.colocateVariablesWith(v1) { val v2 = tf.variable(...) }, then v2 will have locality V(v1) (i.e., locality V(v2) will equal V(v1)).
- d.colocateVariablesWith(d.nonSlotDevices(...)): In tower/cross-tower context, variables will be created with locality N.
- v = tf.variable(...): In tower/cross-tower context, creates a variable (which by definition will have locality V(v), though will match another locality if inside a colocateVariablesWith() scope).
- d.distributeDataset(dataset): In cross-tower context, produces an iterator with locality T.
- d.broadcast(t): In cross-tower context, produces a value with locality M.
- d.broadcast(t, v): In cross-tower context, produces a value with locality V(v).
- d.forEachTower(fn, ...): In cross-tower context, runs fn() in a tower context (and so may call currentTowerContext and use its API, including mergeCall() to get back to cross-tower context), once for each tower. May use values with locality T or M, and any variable.
- d.reduce(m, t): In cross-tower context, accepts t with locality T and produces a value with locality M.
- d.reduce(m, t, v): In cross-tower context, accepts t with locality T and produces a value with locality V(v).
- d.batchReduce(m, Seq((t, v))): See d.reduce().
- d.update(v, fn, ...): In cross-tower context, runs fn() once for each device v is copied to. All inputs should have locality V(v), and the output will have locality V(v) as well.
- d.updateNonSlot(d.nonSlotDevices(), fn): In cross-tower context, like d.update() except with locality N.
- d.fetch(t): Copy t with any locality to the client's CPU device.
The standard pattern for updating variables is to:

1. Wrap the input dataset with d.distributeDataset().
2. Define the per-tower computation with d.forEachTower(), up to the point of getting a list of (gradient, variable) pairs.
3. Call d.reduce("sum", t, v) or d.batchReduce() to sum the gradients (with locality T) into values with locality V(v).
4. Call d.update(v) for each variable to update its value (a sketch of this pattern appears below).

Steps 3 and 4 are done automatically by the Optimizer class if you call its applyGradients method from within a tower context. Otherwise, you can manually call its distributedApply method in a cross-tower context.
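For illustration, a hedged sketch of steps 1-4 written out by hand; the per-tower gradient function and the extraction of the gradient and variable from the per-tower results are elided/illustrative, and exact signatures may differ:

d.scope {
  // 1. Wrap the input dataset.
  val iterator = d.distributeDataset(dataset)

  // 2. Per-tower computation up to the (gradient, variable) pairs.
  val perTowerGrads = d.forEachTower(towerGradientsFn, iterator.next())

  // 3. Sum the per-tower gradient (locality T) into a value with locality V(variable).
  val summedGradient = d.reduce("sum", gradient, variable)

  // 4. Update the variable on every device that holds a copy of it.
  val trainOp = d.update(variable, (v: Variable, g: Output) => v.assignAdd(-g), summedGradient)
}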
Another thing you might want to do in the middle of your tower function is an all-reduce of some intermediate value, using d.reduce() or d.batchReduce() without supplying a variable as the destination.

Layers should expect to be called in a tower context, and can use the currentTowerContext function to get a TowerContext object. The TowerContext object has a mergeCall() method for entering cross-tower context, where you can use reduce() (or batchReduce()) and then optionally update() to update state.

You may use this API whether or not a DistributionStrategy is being used, since there is a default implementation of TowerContext and DistributionStrategy. Or you can use the currentTowerContext.isSingleTower property to run different code in the distributed vs. single tower cases.