package org.platanios.tensorflow.api.learn


Type Members

  1. case class BuiltSessionScaffold extends Product with Serializable

    Built session scaffold.

  2. case class ChiefSessionCreator(master: String = "", sessionScaffold: SessionScaffold = SessionScaffold(), sessionConfig: Option[SessionConfig] = None, checkpointPath: Option[Path] = None) extends SessionCreator with Product with Serializable

    Session factory for CHIEFs.

    master

    TensorFlow master to use.

    sessionScaffold

    Session scaffold used for gathering and/or building supporting ops. If not specified, a default one is created. The session scaffold is used to finalize the graph.

    sessionConfig

    Session configuration to be used for the new sessions.

    checkpointPath

    Path to either a checkpoint file to restore the model from, or a directory containing multiple checkpoint files, in which case the latest checkpoint in the directory will be used.

  3. trait ClipGradients extends AnyRef

    Represents a gradient-clipping method that can be used while training.

  4. case class ClipGradientsByAverageNorm(clipNorm: Float) extends ClipGradients with Product with Serializable

    Clips the gradients using the clipByAverageNorm op.

    clipNorm

    Maximum average norm clipping value (must be > 0).
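
    The following is a minimal sketch of the underlying computation. It assumes the op follows the definition of TensorFlow's clip_by_average_norm: the L2 norm is divided by the number of elements, and the tensor is rescaled only when that average exceeds clipNorm. A plain Array[Float] stands in for a gradient tensor. clipByAverageNorm below is a hypothetical helper, not this package's API.

```scala
// Hypothetical helper (not the library's API): rescales `t` so that its average
// L2 norm, i.e., its L2 norm divided by its number of elements, is at most `clipNorm`.
def clipByAverageNorm(t: Array[Float], clipNorm: Float): Array[Float] = {
  require(clipNorm > 0, "clipNorm must be > 0")
  val l2Norm = math.sqrt(t.map(x => x.toDouble * x).sum)
  val averageNorm = l2Norm / t.length
  // Dividing by max(averageNorm, clipNorm) leaves `t` unchanged when its
  // average norm is already at most `clipNorm`.
  val scale = clipNorm / math.max(averageNorm, clipNorm.toDouble)
  t.map(x => (x * scale).toFloat)
}
```

    For example, Array(3f, 4f) has L2 norm 5 and average norm 2.5. With clipNorm = 1 it is scaled by 0.4, giving roughly (1.2, 1.6).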

  5. case class ClipGradientsByGlobalNorm(clipNorm: Float) extends ClipGradients with Product with Serializable

    Clips the gradients using the clipByGlobalNorm op.

    clipNorm

    Maximum norm clipping value (must be > 0).
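
    Here is a sketch of global-norm clipping. It assumes the op mirrors TensorFlow's clip_by_global_norm: the global norm is the square root of the sum of the squared L2 norms of all gradients, and every gradient is multiplied by the same factor clipNorm / max(globalNorm, clipNorm). Arrays stand in for gradient tensors, and clipByGlobalNorm is a hypothetical helper.

```scala
// Hypothetical helper: jointly rescales a list of gradients so that their
// global norm sqrt(sum_i ||g_i||^2) is at most `clipNorm`.
// Returns the clipped gradients together with the original global norm.
def clipByGlobalNorm(
    gradients: Seq[Array[Float]],
    clipNorm: Float): (Seq[Array[Float]], Double) = {
  require(clipNorm > 0, "clipNorm must be > 0")
  val globalNorm = math.sqrt(gradients.map(g => g.map(x => x.toDouble * x).sum).sum)
  // Every gradient is scaled by the same factor.
  val scale = clipNorm / math.max(globalNorm, clipNorm.toDouble)
  (gradients.map(g => g.map(x => (x * scale).toFloat)), globalNorm)
}
```

    The scale factor is shared, so the relative magnitudes of the gradients are preserved. This is one reason to prefer global-norm clipping over clipping each tensor separately.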

  6. case class ClipGradientsByNorm(clipNorm: Float) extends ClipGradients with Product with Serializable

    Clips the gradients using the clipByNorm op.

    clipNorm

    Maximum norm clipping value (must be > 0).
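
    A per-tensor sketch follows. It assumes the op follows TensorFlow's clip_by_norm, where each gradient is rescaled independently so that its L2 norm does not exceed clipNorm. The helper name and the Array[Float] representation are illustrative assumptions.

```scala
// Hypothetical helper: rescales `t` so that its L2 norm is at most `clipNorm`.
def clipByNorm(t: Array[Float], clipNorm: Float): Array[Float] = {
  require(clipNorm > 0, "clipNorm must be > 0")
  val l2Norm = math.sqrt(t.map(x => x.toDouble * x).sum)
  // Tensors whose norm is already at most `clipNorm` are returned unchanged (scale = 1).
  val scale = clipNorm / math.max(l2Norm, clipNorm.toDouble)
  t.map(x => (x * scale).toFloat)
}
```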

  7. case class ClipGradientsByValue(clipValueMin: Float, clipValueMax: Float) extends ClipGradients with Product with Serializable

    Clips the gradients using the clipByValue op.

    clipValueMin

    Minimum value to clip by.

    clipValueMax

    Maximum value to clip by.
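
    An element-wise sketch follows. It assumes that clipping by value clamps each gradient entry into [clipValueMin, clipValueMax]; the helper is hypothetical.

```scala
// Hypothetical helper: clamps every element of `t` to [clipValueMin, clipValueMax].
def clipByValue(t: Array[Float], clipValueMin: Float, clipValueMax: Float): Array[Float] = {
  require(clipValueMin <= clipValueMax, "clipValueMin must not exceed clipValueMax")
  t.map(x => math.min(math.max(x, clipValueMin), clipValueMax))
}
```

    Unlike the norm-based variants, clamping individual entries can change the direction of the gradient vector.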

  8. case class Configuration(workingDir: Option[Path] = None, sessionConfig: Option[SessionConfig] = None, checkpointConfig: CheckpointConfig = TimeBasedCheckpoints(600, 5, 10000), randomSeed: Option[Int] = None) extends Product with Serializable

    Configuration for models in the learn API, to be used by estimators.

    If clusterConfig is not provided, then all distributed-training-related properties are set based on the TF_CONFIG environment variable, if the pertinent information is present. The TF_CONFIG environment variable is a JSON object with two attributes: cluster and task.

    cluster is a JSON serialized version of ClusterConfig, mapping task types (usually one of the instances of TaskType) to a list of task addresses.

    task has two attributes: type and index, where type can be any of the task types in cluster. When TF_CONFIG contains this information, the following properties are set on this class:

    • clusterConfig is parsed from TF_CONFIG['cluster']. Defaults to None. If present, it must have one and only one node for the chief job (i.e., CHIEF task type).
    • taskType is set to TF_CONFIG['task']['type']. Must be set if clusterConfig is present; must be worker (the default value) if it is not.
    • taskIndex is set to TF_CONFIG['task']['index']. Must be set if clusterConfig is present; must be 0 (the default value) if it is not.
    • master is determined by looking up taskType and taskIndex in the clusterConfig. Defaults to "".
    • numParameterServers is set by counting the number of nodes listed in the ps job (i.e., PARAMETER_SERVER task type) of clusterConfig. Defaults to 0.
    • numWorkers is set by counting the number of nodes listed in the worker and chief jobs (i.e., WORKER and CHIEF task types) of clusterConfig. Defaults to 1.
    • isChief is determined based on taskType and TF_CONFIG['cluster'].
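
    The derivation rules in the list above can be sketched for an already-parsed cluster. This is a hypothetical model, not the Configuration implementation. The cluster is a Map from task type to addresses, JSON parsing is omitted, and isChief is simplified to a check on the task type.

```scala
// Hypothetical container for the derived properties.
final case class DerivedTaskProperties(
    master: String,
    numParameterServers: Int,
    numWorkers: Int,
    isChief: Boolean)

// Hypothetical derivation. `cluster` maps task types to lists of task addresses.
def deriveTaskProperties(
    cluster: Map[String, Seq[String]],
    taskType: String,
    taskIndex: Int): DerivedTaskProperties = {
  val chiefs = cluster.getOrElse("chief", Seq.empty)
  require(chiefs.size == 1, "Exactly one chief node is required.")
  // master: look up (taskType, taskIndex) in the cluster, defaulting to "".
  val master = cluster.get(taskType).flatMap(_.lift(taskIndex)).getOrElse("")
  // numParameterServers: number of nodes in the ps job.
  val numParameterServers = cluster.getOrElse("ps", Seq.empty).size
  // numWorkers: worker and chief nodes both count.
  val numWorkers = cluster.getOrElse("worker", Seq.empty).size + chiefs.size
  DerivedTaskProperties(master, numParameterServers, numWorkers, isChief = (taskType == "chief"))
}
```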

    There is a special node with taskType set as EVALUATOR, which is not part of the (training) clusterConfig. It handles the distributed evaluation job.

    Example for a non-chief node:

    // The TF_CONFIG environment variable contains:
    // {
    //   "cluster": {
    //     "chief": ["host0:2222"],
    //     "ps": ["host1:2222", "host2:2222"],
    //     "worker": ["host3:2222", "host4:2222", "host5:2222"]},
    //   "task": {
    //     "type": "worker",
    //     "index": 1}
    // }
    val config = Configuration()
    assert(config.clusterConfig == Some(ClusterConfig(Map(
      "chief" -> JobConfig.fromAddresses("host0:2222"),
      "ps" -> JobConfig.fromAddresses("host1:2222", "host2:2222"),
      "worker" -> JobConfig.fromAddresses("host3:2222", "host4:2222", "host5:2222")))))
    assert(config.taskType == "worker")
    assert(config.taskIndex == 1)
    assert(config.master == "host4:2222")
    assert(config.numParameterServers == 2)
    assert(config.numWorkers == 4)
    assert(!config.isChief)

    Example for a chief node:

    // The TF_CONFIG environment variable contains:
    // {
    //   "cluster": {
    //     "chief": ["host0:2222"],
    //     "ps": ["host1:2222", "host2:2222"],
    //     "worker": ["host3:2222", "host4:2222", "host5:2222"]},
    //   "task": {
    //     "type": "chief",
    //     "index": 0}
    // }
    val config = Configuration()
    assert(config.clusterConfig == Some(ClusterConfig(Map(
      "chief" -> JobConfig.fromAddresses("host0:2222"),
      "ps" -> JobConfig.fromAddresses("host1:2222", "host2:2222"),
      "worker" -> JobConfig.fromAddresses("host3:2222", "host4:2222", "host5:2222")))))
    assert(config.taskType == "chief")
    assert(config.taskIndex == 0)
    assert(config.master == "host0:2222")
    assert(config.numParameterServers == 2)
    assert(config.numWorkers == 4)
    assert(config.isChief)

    Example for an evaluator node (an evaluator is not part of the training cluster):

    // The TF_CONFIG environment variable contains:
    // {
    //   "cluster": {
    //     "chief": ["host0:2222"],
    //     "ps": ["host1:2222", "host2:2222"],
    //     "worker": ["host3:2222", "host4:2222", "host5:2222"]},
    //   "task": {
    //     "type": "evaluator",
    //     "index": 0}
    // }
    val config = Configuration()
    assert(config.clusterConfig == None)
    assert(config.taskType == "evaluator")
    assert(config.taskIndex == 0)
    assert(config.master == "")
    assert(config.numParameterServers == 0)
    assert(config.numWorkers == 0)
    assert(!config.isChief)

    NOTE: If a checkpointConfig is set, maxCheckpointsToKeep might need to be adjusted accordingly, especially in distributed training. For example, using TimeBasedCheckpoints(60) without adjusting maxCheckpointsToKeep (which defaults to 5) means that checkpoints are garbage collected after about 5 minutes. In distributed training, the evaluation job starts asynchronously and might then fail to load or find the checkpoints due to a race condition.
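
    The arithmetic behind this note can be written as a small sketch. Suppose one checkpoint is saved every 60 seconds and only the newest 5 are kept. Then a checkpoint survives for roughly 60 * 5 = 300 seconds, i.e., 5 minutes. The function names and parameters below are hypothetical and are not the signature of TimeBasedCheckpoints.

```scala
// Approximate retention window of time-based checkpointing. With one checkpoint
// saved every `saveIntervalSeconds` and only the newest `maxToKeep` retained,
// the oldest surviving checkpoint is at most about saveIntervalSeconds * maxToKeep old.
def retentionWindowSeconds(saveIntervalSeconds: Int, maxToKeep: Int): Int =
  saveIntervalSeconds * maxToKeep

// Hypothetical helper: the smallest maxToKeep that keeps checkpoints around for
// at least `requiredSeconds` (e.g., the time an asynchronous evaluator may need).
def minCheckpointsToKeep(saveIntervalSeconds: Int, requiredSeconds: Int): Int =
  math.ceil(requiredSeconds.toDouble / saveIntervalSeconds).toInt
```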

    workingDir

    Directory used to save model parameters, graph, etc. It can also be used to load checkpoints for a previously saved model. If None, a temporary directory will be used.

    sessionConfig

    Configuration to use for the created sessions.

    checkpointConfig

    Configuration specifying when to save checkpoints.

    randomSeed

    Random seed value to be used by the TensorFlow initializers. Setting this value allows consistency between re-runs.

  9. trait InferenceModel[IT, IO, ID, IS, I] extends Model

  10. sealed trait Mode extends AnyRef

    Represents the mode that a model is in while being used by a learner (e.g., training mode, evaluation mode, or prediction mode).
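
    The Value Members list TRAINING, EVALUATION, and INFERENCE as objects extending Mode with Product with Serializable, which is what Scala case objects produce. Below is a minimal illustrative re-creation of that pattern, not the library's source. It shows why a sealed trait is useful: the compiler can check matches over modes for exhaustiveness. needsLabels is a hypothetical example function.

```scala
// Illustrative re-creation of the pattern; not the library's source.
sealed trait Mode
case object TRAINING extends Mode
case object EVALUATION extends Mode
case object INFERENCE extends Mode

// Hypothetical example: a supervised model needs labels when training or evaluating,
// but not when inferring. Because Mode is sealed, a missing case triggers a warning.
def needsLabels(mode: Mode): Boolean = mode match {
  case TRAINING | EVALUATION => true
  case INFERENCE => false
}
```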

  11. trait Model extends AnyRef

  12. case class ModelInstance[IT, IO, ID, IS, I, TT, TO, TD, TS, EI](model: TrainableModel[IT, IO, ID, IS, I, TT, TO, TD, TS, EI], configuration: Configuration, trainInputIterator: Option[Iterator[TT, TO, TD, TS]] = None, trainInput: Option[TO] = None, output: Option[I] = None, loss: Option[ops.Output] = None, gradientsAndVariables: Option[Seq[(ops.OutputLike, ops.variables.Variable)]] = None, trainOp: Option[ops.Op] = None) extends Product with Serializable

    Represents an instance of a constructed model. Such instances are constructed by estimators and passed on to model-dependent hooks.

  13. class MonitoredSession extends SessionWrapper

    Session wrapper that handles initialization, recovery, and hooks.

    Example usage:

    val stopAtStepHook = StopAtStepHook(5, true)
    val session = MonitoredSession(ChiefSessionCreator(...), Seq(stopAtStepHook))
    while (!session.shouldStop) {
      session.run(...)
    }

    Initialization: At creation time, the monitored session does the following things, in the presented order:

    • Invoke Hook.begin() for each hook.
    • Add any scaffolding ops and freeze the graph using SessionScaffold.build().
    • Create a session.
    • Initialize the model using the initialization ops provided by the session scaffold.
    • Restore variable values, if a checkpoint exists.
    • Invoke Hook.afterSessionCreation() for each hook.

    Run: When MonitoredSession.run() is called, the monitored session does the following things, in the presented order:

    • Invoke Hook.beforeSessionRun() for each hook.
    • Invoke Session.run() with the combined feeds, fetches, and targets (i.e., user-provided and hook-provided).
    • Invoke Hook.afterSessionRun() for each hook.
    • Return the result of the run call that the user requested.
    • For certain types of recoverable exceptions (e.g., aborted or unavailable), recover or reinitialize the session and then invoke the Session.run() call again.

    Exit: When MonitoredSession.close() is called, the monitored session does the following things, in the presented order:

    • Invoke Hook.end() for each hook, if no exception has been thrown (other than AbortedException or UnavailableException).

    How to Create MonitoredSessions

    In most cases you can set the constructor arguments as follows:

    MonitoredSession(ChiefSessionCreator(master = ..., sessionConfig = ...))

    In a distributed setting for a non-chief worker, you can use the following:

    MonitoredSession(WorkerSessionCreator(master = ..., sessionConfig = ...))

    See MonitoredTrainingSession for an example usage based on chief or worker.
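
    The initialization, run, and exit sequences listed above can be traced with a toy sketch. Hook, ToyMonitoredSession, and LoggingHook are hypothetical stand-ins that only record the order of calls. They omit real sessions, feeds and fetches, and exception handling.

```scala
import scala.collection.mutable.ArrayBuffer

// Hypothetical hook interface with the callbacks named above.
trait Hook {
  def begin(): Unit = ()
  def afterSessionCreation(): Unit = ()
  def beforeSessionRun(): Unit = ()
  def afterSessionRun(): Unit = ()
  def end(): Unit = ()
}

// Toy stand-in for a monitored session that only records the call order.
final class ToyMonitoredSession(hooks: Seq[Hook], log: ArrayBuffer[String]) {
  // Initialization.
  hooks.foreach(_.begin())
  log += "build scaffold and freeze graph"
  log += "create session, initialize, restore checkpoint"
  hooks.foreach(_.afterSessionCreation())

  // Run: hooks are invoked around the underlying Session.run() call.
  def run(): Unit = {
    hooks.foreach(_.beforeSessionRun())
    log += "Session.run"
    hooks.foreach(_.afterSessionRun())
  }

  // Exit.
  def close(): Unit = hooks.foreach(_.end())
}

// Hypothetical hook that logs each callback it receives.
final class LoggingHook(log: ArrayBuffer[String]) extends Hook {
  override def begin(): Unit = log += "begin"
  override def afterSessionCreation(): Unit = log += "afterSessionCreation"
  override def beforeSessionRun(): Unit = log += "beforeSessionRun"
  override def afterSessionRun(): Unit = log += "afterSessionRun"
  override def end(): Unit = log += "end"
}
```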

  14. case class RecoverableSession extends SessionWrapper with Product with Serializable

    Session wrapper that recreates a session upon certain kinds of errors.

    The constructor is passed a SessionCreator object, not a Session.

    Calls to run() are delegated to the wrapped session. If a call throws an AbortedException or an UnavailableException, the wrapped session is closed, and a new one is created by invoking the session creator.
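
    A minimal sketch of the recovery loop described above. It uses hypothetical stand-ins for the session interface and for AbortedException and UnavailableException, and it retries until a call succeeds. A production version would likely bound the number of retries.

```scala
// Hypothetical stand-ins for the recoverable errors named above.
final class AbortedException(message: String) extends RuntimeException(message)
final class UnavailableException(message: String) extends RuntimeException(message)

// Hypothetical minimal session interface.
trait ToySession {
  def run(): String
  def close(): Unit
}

// Wraps sessions produced by `createSession` (the session creator) and recreates
// the wrapped session whenever run() fails with a recoverable error.
final class ToyRecoverableSession(createSession: () => ToySession) {
  private var session: ToySession = createSession()

  def run(): String = {
    var result: Option[String] = None
    while (result.isEmpty) {
      try {
        result = Some(session.run())
      } catch {
        case _: AbortedException | _: UnavailableException =>
          // Close the failed session and ask the creator for a fresh one.
          session.close()
          session = createSession()
      }
    }
    result.get
  }
}
```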

  15. trait SessionCreator extends AnyRef

    Factory for Sessions.

  16. case class SessionScaffold(readyOp: Option[ops.Output] = None, readyForLocalInitOp: Option[ops.Output] = None, initOp: Option[ops.Op] = None, initFeedMap: FeedMap = FeedMap.empty, initFunction: Option[(core.client.Session, BuiltSessionScaffold) ⇒ Unit] = None, localInitOp: Option[ops.Op] = None, localInitFunction: Option[(core.client.Session, BuiltSessionScaffold) ⇒ Unit] = None, summaryOp: Option[ops.Output] = None, saver: Option[Saver] = None) extends Product with Serializable

    Structure used to create or gather pieces commonly needed to train a model.

    When you build a model for training you usually need ops to initialize variables, a Saver to checkpoint them, an op to collect summaries for the visualizer, and so on.

    Various libraries built on top of the core TensorFlow library take care of creating some or all of these pieces and storing them in well-known collections in the graph. The SessionScaffold class helps pick these pieces from graph collections, create them, and/or add them to graph collections if needed.

    If you call the scaffold constructor without any arguments, it will pick pieces from the graph collections, creating default ones if needed, when SessionScaffold.build() is called. You can pass arguments to the constructor to provide your own pieces. Pieces that you pass to the constructor are not added to the graph collections.
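
    The picking rules above can be sketched with a hypothetical model in which graph collections are a mutable map from collection key to value, and strings stand in for ops. The collection keys in the usage mirror names such as INIT_OP and SAVERS from the parameter descriptions below. resolvePiece is an illustration, not SessionScaffold.build().

```scala
import scala.collection.mutable

// Hypothetical model of graph collections: collection key -> stored value.
final class ToyGraph {
  val collections: mutable.Map[String, String] = mutable.Map.empty
}

// Resolves one scaffold piece following the rules described above. A user-provided
// piece wins and is NOT added to the collection. Otherwise the piece is picked from
// the collection, or a default is created (only when missing) and stored there.
def resolvePiece(
    graph: ToyGraph,
    collectionKey: String,
    provided: Option[String],
    createDefault: () => String): String = provided match {
  case Some(piece) => piece
  case None => graph.collections.getOrElseUpdate(collectionKey, createDefault())
}
```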

    readyOp

    Output used to verify that the variables are initialized. Picked from and stored into the READY_OP graph collection by default.

    readyForLocalInitOp

    Output used to verify that global state has been initialized and it is fine to execute localInitOp. Picked from and stored into the READY_FOR_LOCAL_INIT_OP graph collection by default.

    initOp

    Op used to initialize the variables. Picked from and stored into the INIT_OP graph collection by default.

    initFeedMap

    Feed map that will be used when executing initOp.

    initFunction

    Function to run after the init op to perform additional initializations.

    localInitOp

    Op used to initialize the local variables. Picked from and stored into the LOCAL_INIT_OP graph collection by default.

    localInitFunction

    Function to run after the local init op to perform additional initializations.

    summaryOp

    Output used to merge the summaries in the graph. Picked from and stored into the SUMMARY_OP graph collection by default.

    saver

    Saver object taking care of saving the variables. Picked from and stored into the SAVERS graph collection by default.

  17. class SessionWrapper extends core.client.Session

    Wrapper around a Session that invokes Hook callbacks before and after calls to Session.run().

    This wrapper is used as a base class for various session wrappers that provide additional functionality such as monitoring and recovery.

    In addition to the methods provided by Session, the wrapper provides a method to check for requested stops, and it never propagates exceptions thrown by calls to Session.close.

    The list of hooks to call is passed in the constructor. Before each call to Session.run() the session calls the Hook.beforeSessionRun() method of each hook, which can return additional ops or tensors to run. These are added to the arguments of the call to Session.run().

    When the Session.run() call finishes, the session invokes the Hook.afterSessionRun() method of each hook, passing the values returned by the Session.run() call corresponding to the ops and tensors that each hook requested.

    If any call to the hooks requests a stop via the runContext, the session will be marked as needing to stop and its shouldStop() method will then return true.
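
    A toy sketch of the fetch-merging and stop-request flow described above. RunContext, FetchingHook, and ToySessionWrapper are hypothetical. runFetches stands in for Session.run(), fetches are strings, and fetched values are doubles.

```scala
// Hypothetical stand-in for the run context passed to hooks.
final class RunContext {
  private var stopRequested = false
  def requestStop(): Unit = stopRequested = true
  def isStopRequested: Boolean = stopRequested
}

// Hypothetical hook: names the extra fetches it wants and receives their values.
trait FetchingHook {
  def beforeSessionRun(context: RunContext): Seq[String]
  def afterSessionRun(context: RunContext, values: Seq[Double]): Unit
}

// Toy wrapper where `runFetches` plays the role of Session.run().
final class ToySessionWrapper(hooks: Seq[FetchingHook], runFetches: Seq[String] => Seq[Double]) {
  private var stop = false
  def shouldStop: Boolean = stop

  def run(userFetches: Seq[String]): Seq[Double] = {
    val context = new RunContext
    val hookFetches = hooks.map(_.beforeSessionRun(context))
    // One combined call: user fetches first, then each hook's fetches in order.
    val values = runFetches(userFetches ++ hookFetches.flatten)
    var offset = userFetches.size
    hooks.zip(hookFetches).foreach { case (hook, fetches) =>
      // Each hook only sees the values for the fetches it requested.
      hook.afterSessionRun(context, values.slice(offset, offset + fetches.size))
      offset += fetches.size
    }
    if (context.isStopRequested) stop = true
    // The caller only gets back the values it asked for.
    values.take(userFetches.size)
  }
}
```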

  18. class StopCriteria extends AnyRef

    Criteria used to stop the training loop.

  19. trait SupervisedTrainableModel[IT, IO, ID, IS, I, TT, TO, TD, TS, T] extends TrainableModel[IT, IO, ID, IS, I, (IT, TT), (IO, TO), (ID, TD), (IS, TS), (I, T)]

  20. trait TrainableModel[IT, IO, ID, IS, I, TT, TO, TD, TS, EI] extends InferenceModel[IT, IO, ID, IS, I]

  21. trait UnsupervisedTrainableModel[IT, IO, ID, IS, I] extends TrainableModel[IT, IO, ID, IS, I, IT, IO, ID, IS, I]

  22. case class WorkerSessionCreator(master: String = "", sessionScaffold: SessionScaffold = SessionScaffold(), sessionConfig: Option[SessionConfig] = None) extends SessionCreator with Product with Serializable

    Session factory for WORKERs.

    master

    TensorFlow master to use.

    sessionScaffold

    Session scaffold used for gathering and/or building supporting ops. If not specified, a default one is created. The session scaffold is used to finalize the graph.

    sessionConfig

    Session configuration to be used for the new sessions.

Value Members

  1. object Configuration extends Serializable

    Contains helper methods for dealing with Configurations.

  2. object Counter

    Contains helper methods for creating and obtaining counter variables (e.g., epoch or global iteration).

  3. object EVALUATION extends Mode with Product with Serializable

  4. object INFERENCE extends Mode with Product with Serializable

  5. object Model

  6. object MonitoredSession

    Contains helper methods for creating monitored sessions.

  7. object NoClipGradients extends ClipGradients with Product with Serializable

    Represents no clipping of the gradients (i.e., identity operation).

  8. object RecoverableSession extends Serializable

    Contains helper methods used internally by recoverable sessions.

  9. object StopCriteria

  10. object TRAINING extends Mode with Product with Serializable

  11. package estimators

  12. package hooks

  13. package layers

  14. package models
