org.platanios.tensorflow.api.ops.training.distribute.strategies
Combines multiple `reduce` calls into one for faster execution.
Reduction method to use.
Sequence of pairs of values to reduce and destinations to copy the corresponding reduced values to.
Reduced values.
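A minimal usage sketch (the method name `batchReduce`, the parameter encoding, and the `Reduction.Mean` value are assumptions; the description above only fixes the semantics):

```scala
// Sketch only: one batched call instead of two separate `reduce` calls.
// `perTowerLoss` and `perTowerGrads` stand for per-tower values produced
// by `forEachTower`; the destination device strings are hypothetical.
val Seq(meanLoss, meanGrads) = distributionStrategy.batchReduce(
  Reduction.Mean,
  Seq(
    perTowerLoss  -> Set("/device:CPU:0"),
    perTowerGrads -> Set("/device:GPU:0", "/device:GPU:1")))
```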
Mirrors `value` to all worker devices.
Value to broadcast.
Destination devices.
Mirrored value.
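For illustration (assuming the method is named `broadcast`; only its behavior is documented above), mirroring a value onto a set of devices might look like:

```scala
// Sketch: copy `globalStep` onto both worker GPUs so that every tower
// can read the same value locally. The device strings are hypothetical.
val mirroredStep = distributionStrategy.broadcast(
  globalStep, Set("/device:GPU:0", "/device:GPU:1"))
```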
Executes `block` within a scope that controls which devices variables will be created on.
No operations should be added to the graph inside this scope; it should only be used when creating variables (some implementations work by changing variable creation and others work by using a `colocateWith` scope). This may only be used inside `DistributionStrategy.scope`.
For example:
```scala
distributionStrategy.scope {
  val variable1 = tf.variable(...)
  distributionStrategy.colocateVariablesWith(Set(variable1.op)) {
    // `variable2` and `variable3` will be created on the same device(s) as `variable1`.
    val variable2 = tf.variable(...)
    val variable3 = tf.variable(...)
  }

  def fn(v1: Variable, v2: Variable, v3: Variable): Unit = {
    // Operates on `v1` from `variable1`, `v2` from `variable2`, and `v3` from `variable3`.
  }

  // `fn` runs on every device `v1` is on, and `v2` and `v3` will be there too.
  distributionStrategy.update(variable1, fn, variable2, variable3)
}
```
Variables created in `block` will be on the same set of devices as these ops.
Code block to execute in this scope.
Value returned by `block`.
Finds and sets the best configuration for the provided TensorFlow session configuration.
Returns a copy of `fn(variable.value)` on `destination`. This is useful for getting a mirrored variable value onto a device. The method will attempt to avoid a copy by checking if the value is already on the destination device.
Variable (which may be mirrored) to copy and fetch.
Device to copy the variable value to.
Optional function to apply to the value on the source device, before copying.
Fetched value on the destination device.
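For example, a sketch of fetching a mirrored variable's value onto the CPU (the argument form is an assumption; the optional pre-copy function is omitted here):

```scala
// Sketch: bring the (possibly mirrored) `accuracy` variable's value onto
// the CPU. If the value already lives on the destination device, the
// implementation avoids the copy, per the description above.
val hostAccuracy = distributionStrategy.fetch(accuracy, "/device:CPU:0")
```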
Runs `fn` once per tower.
`fn` may call `tf.currentTowerContext` to access fields and methods such as `towerID` and `mergeCall()`.
`mergeCall()` is used to communicate between the towers and re-enter the cross-tower context. All towers pause their execution when they encounter a `mergeCall()` call. After that, the `mergeFn` function is executed. Its results are then unwrapped and given back to each tower call, and execution resumes until `fn` is complete or another `mergeCall()` is encountered.
For example:
```scala
// Called once in "cross-tower" context.
def mergeFn(distributionStrategy: DistributionStrategy, threePlusTowerID: Int): tf.Output = {
  // Sum the values across towers.
  tf.addN(distributionStrategy.unwrap(threePlusTowerID))
}

// Called once per tower in `distributionStrategy`, in a "tower" context.
def fn(three: Int): Output = {
  val towerContext = tf.currentTowerContext
  val v = three + towerContext.towerID
  // Computes the sum of the `v` values across all towers.
  val s = towerContext.mergeCall(mergeFn(_, v))
  s + v
}

distributionStrategy.scope {
  // In "cross-tower" context ...
  val mergedResults = distributionStrategy.forEachTower(() => fn(3))
  // `mergedResults` has the values from every tower execution of `fn`.
  val resultsList = distributionStrategy.unwrap(mergedResults)
}
```
Function that will be run once per tower.
Wrapped values that will be unwrapped when invoking `fn` on each tower.
Merged return value of `fn` across all towers.
Acts as a shortcut for `tf.group(distributionStrategy.unwrap(value))`.
A value returned by `forEachTower()`, or a variable created in `scope`.
Name for the created op.
Grouped unwrapped `value`.
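As a sketch (assuming the method is named `group`, with the op name passed as shown), this is handy for turning per-device update results into a single op to run:

```scala
// Sketch: `updates` is a per-device value returned by `forEachTower`.
// Grouping bundles all of its unwrapped component ops into one op,
// equivalent to `tf.group(distributionStrategy.unwrap(updates))`.
val trainOp = distributionStrategy.group(updates, name = "TrainOp")
```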
Returns `true` if there is only a single tower, and `false` otherwise.
If `true`, `forEachTower(fn)` will only call `fn` once. If `false`, `forEachTower(fn)` may call `fn` multiple times.
Merges arguments across towers and runs `mergeFn` in a cross-tower context.
This allows communication and coordination when there are multiple calls to a model function triggered by a call to `forEachTower(modelFn, ...)`. See `MirroredDistribution.forEachTower()` for an explanation.
Otherwise, this is equivalent to:
```scala
val strategy = tf.distribute.currentStrategy
strategy.scope {
  mergeFn(strategy)
}
```
Merge function to invoke from within a cross-tower context.
Result of the `mergeFn` call, except for per-device values which are unpacked.
Returns the devices used for non-slot variables.
Create variables on these devices in a `colocateVariablesWith(nonSlotDevices(...))` block, and then update them using `updateNonSlot()`.
Variables being optimized.
Colocation ops for non-slot variables.
Returns the number of towers, for purposes of averaging across towers.
Returns the devices used for variable and updates placement.
Combines values across towers into one value.
Reduction method to use.
Value to reduce.
Optional destination on which to copy the reduced value.
Reduced value.
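A sketch of reducing a per-tower loss (the `Reduction.Mean` value shown is an assumption; the destination argument is omitted since it is optional per the description above):

```scala
// Sketch: average the per-tower losses into a single value. Since no
// destination is given, the result stays wherever the implementation
// places it.
val meanLoss = distributionStrategy.reduce(Reduction.Mean, perTowerLoss)
```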
Executes `block` within a scope where new variables will not be mirrored.
There will still be one component variable per tower, but there is no requirement that they stay in sync. Instead, when saving them or calling `fetch()`, we use the value that results from calling `reduce()` on all the towers' variables. Note that tower-local implies not trainable. Instead, it is expected that each tower will directly update (e.g., using `assignAdd()`) its local variable instance, but only the aggregated value (accessible using `fetch()`) will be exported from the model. When it is acceptable to only aggregate on export, we greatly reduce communication overhead by using tower-local variables.
Note that all component variables will be initialized to the same value, using the initialization expression from the first tower. The values will match even if the initialization expression uses random numbers.
Reduction method used to get the value to save when creating checkpoints.
Code block to execute in this scope.
Value returned by `block`.
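A sketch of the pattern described above (the scope method's name, `towerLocalVariableScope`, and the `Reduction.Sum` value are assumptions; `assignAdd()` and `fetch()` are named in the text):

```scala
// Sketch: a tower-local counter. Each tower bumps its own un-synced copy
// with `assignAdd`, and only the summed value is read out via `fetch`.
distributionStrategy.towerLocalVariableScope(Reduction.Sum) {
  val examplesSeen = tf.variable(...)  // one component variable per tower
  examplesSeen.assignAdd(batchSize)    // direct, per-tower update
}
```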
Returns the list of all per-device values contained in `value`.
A value returned by `forEachTower()`, or a variable created in `scope`.
Sequence of values contained in `value`.
Runs `fn` to update `variable` using inputs mirrored to the same devices.
If `variable` is mirrored across multiple devices, then this method implements logic like:
```scala
val results = variable.index.map { case (deviceSpec, variable) =>
  tf.createWith(device = deviceSpec.toString) {
    fn(variable)
  }
}
merged(results)
```
Otherwise, this returns `fn(variable)` colocated with `variable`.
Variable to update.
Update function to use.
Mirrored arguments that should be passed to `fn`.
Merged return value of `fn` across all towers.
Runs `fn` on the devices specified by `colocateWith`, with the provided arguments.
Destination on which to execute `fn`.
Function to use for the update.
Mirrored arguments that should be passed to `fn`.
Merged return value of `fn` across all towers.
`InvalidArgumentException`: If the provided `colocateWith` argument is invalid (e.g., too many devices).
Returns a map from worker devices to indices.
TODO: [DISTRIBUTE] Settle on the interface of `forEachTower()` first.
This map might be passed as an argument to `forEachTower()`, as in:
```scala
distributionStrategy.scope {
  def fn(deviceIndex: Int): Unit = {
    // `fn` is being executed on device `distributionStrategy.workerDevices(deviceIndex)`.
  }

  distributionStrategy.forEachTower(fn, distributionStrategy.workerDeviceIndex)
}
```
Returns the devices used to run `forEachTower()` calls.