Class DistributionStrategy

org.platanios.tensorflow.api.ops.training.distribute.strategies

abstract class DistributionStrategy extends AnyRef

Represents a list of devices with a state and a compute distribution policy.

The intent is that you can write an algorithm in a stylized way and it will be usable with a variety of different DistributionStrategy implementations. Each descendant will implement a different strategy for distributing the algorithm across multiple devices/machines. Furthermore, these changes can be hidden inside the specific layers and other library classes that need special treatment to run in a distributed setting, so that most users' model definition code can run unchanged.

First let's introduce a few high-level concepts:

To distribute an algorithm, we might use some of these ingredients:

We have a few approaches we want to support:

d.scope {
  val iterator = d.distributeDataset(dataset)
  val towerTrainOps = d.forEachTower(towerFn, iterator.next())
  val trainOp = tf.group(d.unwrap(towerTrainOps))
}

This takes an ordinary dataset and a tower function towerFn, and runs them in a distributed fashion using the particular DistributionStrategy in d. Any variables created in towerFn are created using d's policy, and library functions called by towerFn can use the currentTowerContext API to get enhanced behavior in this case. Note that support for initializable dataset iterators will be added in the future, at which point this example code will change.

Lower-level concepts:

When using a DistributionStrategy, we have a new type dimension called locality that says what values are compatible with which APIs:

Rules for methods with respect to locality and single-tower vs. cross-tower context:

The standard pattern for updating variables is to:

  1. Wrap your input dataset in d.distributeDataset().
  2. Define each tower via d.forEachTower(), up to the point of getting a sequence of (gradient, variable) pairs.
  3. Call d.reduce("sum", t, v) or d.batchReduce() to sum the gradients (with locality T) into values with locality V(v).
  4. Call d.update(v) for each variable to update its value.

Steps 3 and 4 are done automatically by the Optimizer class if you call its applyGradients method from within a tower context. Otherwise, you can manually call its distributedApply method in a cross-tower context.
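
For concreteness, the pattern above might be written roughly as follows. This is only a sketch: dataset, variable, computeGradients, applyGradient, and the reduction value are placeholders, and the implicit cross-tower context is assumed to be provided by d.scope.

d.scope {
  // Step 1: distribute the input pipeline.
  val iterator = d.distributeDataset(dataset)
  // Step 2: run the model once per tower; `computeGradients` yields a per-tower gradient for `variable`.
  val perTowerGradient = d.forEachTower(computeGradients, iterator.next())
  // Step 3: sum the per-tower gradients into a value mirrored on the variable's devices.
  val reducedGradient = d.reduce(reduction, perTowerGradient, destination = Some(variable))
  // Step 4: apply the reduced gradient on every device that holds the variable.
  val trainOp = d.group(d.update(variable, applyGradient, Seq(reducedGradient)))
}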

Another thing you might want to do in the middle of your tower function is an all-reduce of some intermediate value, using d.reduce() or d.batchReduce() without supplying a variable as the destination.

Layers should expect to be called in a tower context, and can use the currentTowerContext function to get a TowerContext object. The TowerContext object has a mergeCall() method for entering cross-tower context where you can use reduce() (or batchReduce()) and then optionally update() to update state.

You may use this API whether or not a DistributionStrategy is being used, since there is a default implementation of TowerContext and DistributionStrategy. Alternatively, you can use the currentTowerContext.isSingleTower property to run different code in the distributed and single-tower cases.
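
For example, a layer might branch on the current context as in the following sketch, which mirrors the forEachTower() example further below (localValue is a placeholder per-tower value, and merging it across towers via mergeCall() is assumed to behave as in that example):

val towerContext = tf.currentTowerContext
val combined =
  if (towerContext.isSingleTower) {
    // Single tower: there is nothing to combine.
    localValue
  } else {
    // Multiple towers: re-enter the cross-tower context and sum the value across towers.
    towerContext.mergeCall(strategy => tf.addN(strategy.unwrap(localValue)))
  }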

Linear Supertypes

AnyRef, Any

Instance Constructors

  1. new DistributionStrategy()

Abstract Value Members

  1. abstract def broadcast[O <: OutputLike](value: O, devices: Seq[DeviceSpecification] = Seq.empty)(implicit context: CrossTowerContext): MirroredValue[O]

    Mirrors value to all worker devices.

    value

    Value to broadcast.

    devices

    Destination devices.

    returns

    Mirrored value.
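
    For example, the following sketch mirrors a host-computed globalStep output (a placeholder) onto all worker devices, assuming the cross-tower context is provided by d.scope:

    d.scope {
      // With no explicit destinations, the value is mirrored to all worker devices.
      val mirroredGlobalStep = d.broadcast(globalStep)
    }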

  2. abstract def createVariable: ColocatedVariableGetter

    Attributes
    protected
  3. abstract def fetch(variable: DistributedVariable, destination: String = "/device:CPU:0", fn: (Output) ⇒ Output = (o: Output) => o)(implicit context: CrossTowerContext): Output

    Returns a copy of fn(variable.value) on destination. This is useful for getting a mirrored variable value onto a device. The method will attempt to avoid a copy by checking if the value is already on the destination device.

    variable

    Variable (which may be mirrored) to copy and fetch.

    destination

    Device to copy the variable value to.

    fn

    Optional function to apply to the value on the source device, before copying.

    returns

    Fetched value on the destination device.

    Annotations
    @throws( ... )
    Exceptions thrown

    InvalidArgumentException If there is an issue with the provided variable.
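
    For example, the following sketch copies the value of a (possibly mirrored) placeholder variable counter onto the default "/device:CPU:0" destination, assuming the cross-tower context is provided by d.scope:

    d.scope {
      // Avoids a copy if the value already lives on the destination device.
      val counterOnCpu = d.fetch(counter)
    }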

  4. abstract def forEachTower[T, R](fn: (Seq[T]) ⇒ R, values: Seq[DistributedValue[T]])(implicit arg0: Distributable[T], context: CrossTowerContext): R

    Runs fn once per tower.

    fn may call tf.currentTowerContext to access fields and methods such as towerID and mergeCall(). mergeCall() is used to communicate between the towers and to re-enter the cross-tower context. All towers pause their execution when they encounter a mergeCall() call. The mergeFn function is then executed, its results are unwrapped and handed back to each tower call, and execution resumes until fn is complete or another mergeCall() is encountered.

    For example:

    // Called once in "cross-tower" context.
    def mergeFn(distributionStrategy: DistributionStrategy, threePlusTowerID: Int): tf.Output = {
      // Sum the values across towers.
      tf.addN(distributionStrategy.unwrap(threePlusTowerID))
    }
    
    // Called once per tower in `distributionStrategy`, in a "tower" context.
    def fn(three: Int): Output = {
      val towerContext = tf.currentTowerContext
      val v = three + towerContext.towerID
      // Computes the sum of the `v` values across all towers.
      val s = towerContext.mergeCall(mergeFn(_, v))
      s + v
    }
    
    distributionStrategy.scope {
      // In "cross-tower" context
      ...
      val mergedResults = distributionStrategy.forEachTower(() => fn(3))
      // `mergedResults` has the values from every tower execution of `fn`.
      val resultsList = distributionStrategy.unwrap(mergedResults)
    }
    fn

    Function that will be run once per tower.

    values

    Wrapped values that will be unwrapped when invoking fn on each tower.

    returns

    Merged return value of fn across all towers.

  5. abstract def isSingleTower: Boolean

    Returns true if there is only a single tower, and false otherwise.

    If true, forEachTower(fn) will only call fn once. If false, forEachTower(fn) may call fn multiple times.
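
    For example (a sketch; learningRate is a placeholder scalar), a quantity might only be rescaled when more than one tower is in use:

    // Scale the learning rate by the number of towers only when actually distributed.
    val scaledLearningRate =
      if (d.isSingleTower) learningRate else learningRate * d.numTowers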

  6. abstract def nonSlotDevices(variables: Seq[variables.Variable]): Set[DeviceSpecification]

    Returns the devices used for non-slot variables.

    Create variables on these devices inside a colocateVariablesWith(nonSlotDevices(...)) block, and then update them using updateNonSlot().

    variables

    Variables being optimized.

    returns

    The set of devices on which non-slot variables are colocated.

  7. abstract def numTowers: Int

    Returns the number of towers, for purposes of averaging across towers.

  8. abstract def parameterDevices: Set[String]

    Returns the devices used for variable and update placement.

  9. abstract def reduce[D](reduction: Reduction, value: PerDeviceValue[OutputLike], destination: Option[D] = None)(implicit arg0: Destination[D], context: CrossTowerContext): MirroredValue[OutputLike]

    Combines values across towers into one value.

    reduction

    Reduction method to use.

    value

    Value to reduce.

    destination

    Optional destination on which to copy the reduced value.

    returns

    Reduced value.
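
    For example, the following sketch sums per-tower losses (towerLossFn, the reduction value, and the destination variable v are placeholders; the cross-tower context comes from d.scope):

    d.scope {
      val perTowerLoss = d.forEachTower(towerLossFn, iterator.next())
      // Without a destination, the reduced value is mirrored on the default devices.
      val totalLoss = d.reduce(reduction, perTowerLoss)
      // With a destination, the reduced value is copied next to `v`.
      val totalLossAtV = d.reduce(reduction, perTowerLoss, destination = Some(v))
    }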

  10. abstract def unwrap[T](value: DistributedValue[T])(implicit arg0: Distributable[T], context: CrossTowerContext): Seq[T]

    Returns the list of all per-device values contained in value.

    value

    A value returned by forEachTower(), or a variable created in scope.

    returns

    Sequence of values contained in value.

  11. abstract def update[T, R](variable: MirroredVariable, fn: (variables.Variable, Seq[T]) ⇒ R, arguments: Seq[MirroredValue[T]])(implicit arg0: Distributable[T], arg1: Distributable[R], context: CrossTowerContext): MirroredValue[R]

    Runs fn to update variable using inputs mirrored to the same devices.

    If variable is mirrored across multiple devices, then this method implements logic like:

    val results = variable.index.map {
      case (deviceSpec, v) => tf.createWith(device = deviceSpec.toString) {
        fn(v, arguments)
      }
    }
    merged(results) // Merges the per-device results into a single mirrored value.

    Otherwise, this returns fn(variable, arguments), colocated with variable.

    variable

    Variable to update.

    fn

    Update function to use.

    arguments

    Mirrored arguments that should be passed to fn.

    returns

    Merged return value of fn across all towers.
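
    For example, a manual update might look like the following sketch (reducedGradient is a placeholder mirrored value, and assignSub is assumed to be available on variables; the Optimizer normally performs this step for you):

    d.scope {
      // Apply a reduced gradient to `variable` on every device that holds a copy of it.
      val updateOp = d.update(
        variable,
        (v: Variable, args: Seq[Output]) => v.assignSub(args.head),
        Seq(reducedGradient))
    }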

  12. abstract def updateNonSlot[D, T, R](colocateWith: D, fn: (Seq[T]) ⇒ R, arguments: Seq[MirroredValue[T]])(implicit arg0: Destination[D], arg1: Distributable[T], arg2: Distributable[R], context: CrossTowerContext): MirroredValue[R]

    Runs fn on the devices specified by colocateWith, with the provided arguments.

    colocateWith

    Destination on which to execute fn.

    fn

    Function to use for the update.

    arguments

    Mirrored arguments that should be passed to fn.

    returns

    Merged return value of fn across all towers.

    Annotations
    @throws( ... )
    Exceptions thrown

    InvalidArgumentException If the provided colocateWith argument is invalid (e.g., too many devices).
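
    For example (a sketch in which nonSlotDestination stands for the result of nonSlotDevices(...), and incrementStep and mirroredOne are hypothetical placeholders):

    d.scope {
      // Run the update function on the non-slot devices, passing it a mirrored argument.
      val incrementOp = d.updateNonSlot(nonSlotDestination, incrementStep, Seq(mirroredOne))
    }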

  13. abstract def workerDeviceIndex(implicit context: CrossTowerContext): Map[DeviceSpecification, Int]

    Returns a map from worker devices to indices.

    TODO: [DISTRIBUTE] Settle on the interface of forEachTower() first. This map might be passed as an argument to forEachTower(), as in:

    distributionStrategy.scope {
      def fn(deviceIndex: Int): Unit = {
        // `fn` is being executed on device `distributionStrategy.workerDevices(deviceIndex)`.
      }
      distributionStrategy.forEachTower(fn, distributionStrategy.workerDeviceIndex)
    }
  14. abstract def workerDevices: Set[String]

    Returns the devices used to run forEachTower() calls.

Concrete Value Members

  1. final def !=(arg0: Any): Boolean

    Definition Classes
    AnyRef → Any
  2. final def ##(): Int

    Definition Classes
    AnyRef → Any
  3. final def ==(arg0: Any): Boolean

    Definition Classes
    AnyRef → Any
  4. final def asInstanceOf[T0]: T0

    Definition Classes
    Any
  5. def batchReduce[D](reduction: Reduction, valueDestinationPairs: Seq[(PerDeviceValue[OutputLike], Option[D])])(implicit arg0: Destination[D], context: CrossTowerContext): Seq[DistributedValue[OutputLike]]

    Combines multiple reduce calls into one for faster execution.

    reduction

    Reduction method to use.

    valueDestinationPairs

    Sequence of pairs, each consisting of a value to reduce and a destination to copy the reduced value to.

    returns

    Reduced values.
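
    For example, the following sketch reduces a whole batch of gradients at once (gradientsAndVariables and the reduction value are placeholders; the cross-tower context comes from d.scope):

    d.scope {
      // Reduce all gradients in one batched call, each onto its corresponding variable's devices.
      val reducedGradients = d.batchReduce(
        reduction,
        gradientsAndVariables.map { case (gradient, variable) => (gradient, Some(variable)) })
    }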

  6. def clone(): AnyRef

    Attributes
    protected[java.lang]
    Definition Classes
    AnyRef
    Annotations
    @throws( ... )
  7. def colocateVariablesWith[R](colocationOps: Set[Op])(block: ⇒ R)(implicit context: DistributionContext): R

    Executes block within a scope that controls which devices variables will be created on.

    No operations should be added to the graph inside this scope; it should only be used when creating variables (some implementations work by changing variable creation and others work by using a colocateWith scope). This may only be used inside DistributionStrategy.scope.

    For example:

    distributionStrategy.scope {
      val variable1 = tf.variable(...)
      distributionStrategy.colocateVariablesWith(Set(variable1.op)) {
        // `variable2` and `variable3` will be created on the same device(s) as `variable1`.
        val variable2 = tf.variable(...)
        val variable3 = tf.variable(...)
      }
    
      def fn(v1: Variable, otherVariables: Seq[Variable]): Unit = {
        // Operates on `v1` from `variable1`, and on `variable2` and `variable3` via `otherVariables`.
      }

      // `fn` runs on every device `v1` is on, and `variable2` and `variable3` will be there too.
      distributionStrategy.update(variable1, fn, Seq(variable2, variable3))
    }
    colocationOps

    Variables created in block will be on the same set of devices as these ops.

    block

    Code block to execute in this scope.

    returns

    Value returned by block.

  8. def configure(sessionConfig: SessionConfig): Unit

    Finds and sets the best configuration for the provided TensorFlow session configuration.

  9. final def eq(arg0: AnyRef): Boolean

    Definition Classes
    AnyRef
  10. def equals(arg0: Any): Boolean

    Definition Classes
    AnyRef → Any
  11. def finalize(): Unit

    Attributes
    protected[java.lang]
    Definition Classes
    AnyRef
    Annotations
    @throws( classOf[java.lang.Throwable] )
  12. final def getClass(): Class[_]

    Definition Classes
    AnyRef → Any
  13. def group[T](value: DistributedValue[T], name: String = "Group")(implicit arg0: Distributable[T], context: CrossTowerContext): Op

    Acts as a shortcut for tf.group(distributionStrategy.unwrap(value)).

    value

    A value returned by forEachTower(), or a variable created in scope.

    name

    Name for the created op.

    returns

    Grouped unwrapped value.
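
    For example, the first code sample on this page could equivalently be written as the following sketch:

    d.scope {
      val iterator = d.distributeDataset(dataset)
      val towerTrainOps = d.forEachTower(towerFn, iterator.next())
      // Equivalent to `tf.group(d.unwrap(towerTrainOps))`.
      val trainOp = d.group(towerTrainOps, name = "TrainOp")
    }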

  14. def hashCode(): Int

    Definition Classes
    AnyRef → Any
  15. final def isInstanceOf[T0]: Boolean

    Definition Classes
    Any
  16. def mergeCall[R](mergeFn: (DistributionStrategy) ⇒ R)(implicit context: InTowerContext): R

    Merges arguments across towers and runs mergeFn in a cross-tower context.

    This allows communication and coordination when there are multiple calls to a model function triggered by a call to forEachTower(modelFn, ...). See MirroredDistribution.forEachTower() for an explanation.

    Otherwise, this is equivalent to:

    val strategy = tf.distribute.currentStrategy
    strategy.scope {
      mergeFn(strategy)
    }
    mergeFn

    Merge function to invoke from within a cross-tower context.

    returns

    Result of the mergeFn call, except for per-device values which are unpacked.

  17. final def ne(arg0: AnyRef): Boolean

    Definition Classes
    AnyRef
  18. final def notify(): Unit

    Definition Classes
    AnyRef
  19. final def notifyAll(): Unit

    Definition Classes
    AnyRef
  20. def scope[R](block: ⇒ R): R

    Executes block within this distribution strategy's scope, establishing a cross-tower context, and returns the value returned by block.

  21. final def synchronized[T0](arg0: ⇒ T0): T0

    Definition Classes
    AnyRef
  22. def toString(): String

    Definition Classes
    AnyRef → Any
  23. def towerLocalVariableScope[R](reduction: Reduction)(block: ⇒ R)(implicit context: DistributionContext): R

    Executes block within a scope where new variables will not be mirrored.

    There will still be one component variable per tower, but there is no requirement that they stay in sync. Instead, when saving them or calling fetch(), we use the value that results when calling reduce() on all the towers' variables. Note that tower-local implies not trainable. Instead, it is expected that each tower will directly update (e.g., using assignAdd()) its local variable instance but only the aggregated value (accessible using fetch()) will be exported from the model. When it is acceptable to only aggregate on export, we greatly reduce communication overhead by using tower-local variables.

    Note that all component variables will be initialized to the same value, using the initialization expression from the first tower. The values will match even if the initialization expression uses random numbers.

    reduction

    Reduction method used to get the value to save when creating checkpoints.

    block

    Code block to execute in this scope.

    returns

    Value returned by block.
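
    For example, a per-tower accumulator might be created as in the following sketch (the reduction value, the variable creation arguments, and towerLoss are placeholders):

    d.scope {
      d.towerLocalVariableScope(reduction) {
        // Each tower gets its own, unsynchronized copy of this accumulator.
        val lossAccumulator = tf.variable(...)
        // Each tower updates only its local copy; the aggregated value is obtained via `fetch()`.
        lossAccumulator.assignAdd(towerLoss)
      }
    }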

  24. final def wait(): Unit

    Definition Classes
    AnyRef
    Annotations
    @throws( ... )
  25. final def wait(arg0: Long, arg1: Int): Unit

    Definition Classes
    AnyRef
    Annotations
    @throws( ... )
  26. final def wait(arg0: Long): Unit

    Definition Classes
    AnyRef
    Annotations
    @throws( ... )
