package org.platanios.tensorflow.api.learn.hooks

Type Members

  1. class CheckpointSaver extends TriggeredHook

    Saves checkpoints to files based on a HookTrigger. Checkpoints include the current graph, as well as the values of all variables trained so far.

  2. class Evaluator[IT, IO, ID, IS, I, TT, TO, TD, TS, EI] extends TriggeredHook with ModelDependentHook[IT, IO, ID, IS, I, TT, TO, TD, TS, EI]

    Hook that can be used to evaluate the performance of an estimator on a separate dataset while training. Whenever it is invoked, this hook creates a new session that loads the latest saved checkpoint and evaluates performance using the provided set of evaluation metrics.

  3. abstract class Hook extends AnyRef

    Hook to extend calls to MonitoredSession.run().

    Hooks are useful for tracking training, reporting progress, requesting early stopping, and more. They follow the observer pattern and are notified at the following points:

    • When a session starts being used,
    • Before a call to Session.run(),
    • After a call to Session.run(),
    • When the session stops being used.

    A Hook encapsulates a piece of reusable, composable computation that can piggyback on a call to MonitoredSession.run(). A hook can add any feeds, fetches, or targets to the run call, and when the run call finishes successfully, the hook receives the fetches it requested. Hooks may add ops to the graph in the begin() method; the graph is finalized after begin() is called.

    There are a few pre-defined hooks that can be used without modification:

    • Stopper: Requests to stop iterating based on the provided stopping criteria.
    • StepRateLogger: Logs and/or saves summaries with the number of steps executed per second.
    • TensorLogger: Logs the values of one or more tensors.
    • SummarySaver: Saves summaries to the provided summary writer.
    • CheckpointSaver: Saves checkpoints (i.e., a copy of the graph along with the trained variable values).
    • NaNChecker: Requests to stop iterating if any of the provided tensors contains NaN values.

    For more specific needs you can create custom hooks. For example:

    // Assumed imports (scala-logging, SLF4J, TensorFlow Scala); exact paths may vary by version.
    import com.typesafe.scalalogging.Logger
    import org.platanios.tensorflow.api._
    import org.platanios.tensorflow.api.learn.hooks.Hook
    import org.slf4j.LoggerFactory

    class ExampleHook extends Hook[Output, Unit, Tensor] {
      private[this] val logger: Logger = Logger(LoggerFactory.getLogger("Example Hook"))
      private[this] var exampleTensor: Output = _
      // Hypothetical stopping flag; set it from your own stopping logic.
      private[this] var needToStop: Boolean = false

      override def begin(): Unit = {
        // You can add ops to the graph here.
        logger.info("Starting the session.")
        exampleTensor = ...
      }

      override def afterSessionCreation(session: Session): Unit = {
        // When this is called, the graph is finalized and ops can no longer be added to it.
        logger.info("Session created.")
      }

      override def beforeSessionRun[F, E, R](runContext: SessionRunContext[F, E, R])(implicit
        executableEv: Executable[E],
        fetchableEv: Fetchable.Aux[F, R]
      ): Hook.SessionRunArgs[Output, Unit, Tensor] = {
        logger.info("Before calling `Session.run()`.")
        // Ask for `exampleTensor` to be fetched as part of this run call.
        Hook.SessionRunArgs(fetches = exampleTensor)
      }

      override def afterSessionRun[F, E, R](runContext: SessionRunContext[F, E, R], runValues: Tensor)(implicit
        executableEv: Executable[E],
        fetchableEv: Fetchable.Aux[F, R]
      ): Unit = {
        // `runValues` holds the value of the fetch requested in `beforeSessionRun`.
        logger.info(s"Done running one step. The value of the tensor is: ${runValues.summarize()}")
        if (needToStop)
          runContext.requestStop()
      }

      override def end(session: Session): Unit = {
        logger.info("Done with the session.")
      }
    }

    To understand how hooks interact with calls to MonitoredSession.run(), consider the following code:

    val session = Estimator.monitoredTrainingSession(hooks = someHook, ...)
    while (!session.shouldStop)
      session.run(...)
    session.close()

    The above user code loosely leads to the following execution:

    someHook.begin()
    val session = tf.Session()
    someHook.afterSessionCreation(session)
    var stopRequested = false
    while (!stopRequested) {
      someHook.beforeSessionRun(...)
      try {
        val result = session.run(mergedSessionRunArgs)
        someHook.afterSessionRun(..., result)
      } catch {
        case _: OutOfRangeException => stopRequested = true
      }
    }
    someHook.end(session)
    session.close()

    Note that if session.run() throws an OutOfRangeException, then someHook.afterSessionRun() will not be called, but someHook.end() will still be called. On the other hand, if session.run() throws any other exception, then neither someHook.afterSessionRun() nor someHook.end() will be called.
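    The call order and the exception behaviour described above can be checked with a small self-contained simulation. SimpleHook, FakeSession, RecordingHook, runLoop, and OutOfRange are hypothetical stand-ins written only for this sketch; they are not TensorFlow Scala APIs, and the real MonitoredSession does more (for example, merging the run arguments of several hooks).

```scala
import scala.collection.mutable.ListBuffer

// Hypothetical stand-ins for illustration only; these are not TensorFlow Scala classes.
final class OutOfRange extends RuntimeException("input exhausted")

trait SimpleHook {
  def begin(): Unit = ()
  def afterSessionCreation(): Unit = ()
  def beforeRun(): Unit = ()
  def afterRun(result: Int): Unit = ()
  def end(): Unit = ()
}

// Fake session: run() returns the step number and throws OutOfRange once `maxSteps` runs are done.
final class FakeSession(maxSteps: Int) {
  private var step = 0
  def run(): Int = {
    if (step >= maxSteps) throw new OutOfRange
    step += 1
    step
  }
}

// Records every callback so that the call order can be inspected afterwards.
final class RecordingHook extends SimpleHook {
  val calls: ListBuffer[String] = ListBuffer.empty[String]
  override def begin(): Unit = calls += "begin"
  override def afterSessionCreation(): Unit = calls += "afterSessionCreation"
  override def beforeRun(): Unit = calls += "beforeRun"
  override def afterRun(result: Int): Unit = calls += s"afterRun($result)"
  override def end(): Unit = calls += "end"
}

// Simplified driver loop mirroring the pseudocode above (one hook, no merged run arguments).
def runLoop(hook: SimpleHook, session: FakeSession): Unit = {
  hook.begin()
  hook.afterSessionCreation()
  var stopRequested = false
  while (!stopRequested) {
    hook.beforeRun()
    try {
      val result = session.run()
      hook.afterRun(result)
    } catch {
      // End of input: stop looping without calling afterRun for this step.
      case _: OutOfRange => stopRequested = true
    }
  }
  // Reached only when no other exception escaped the loop.
  hook.end()
}

val hook = new RecordingHook
runLoop(hook, new FakeSession(maxSteps = 2))
println(hook.calls.mkString(", "))
// begin, afterSessionCreation, beforeRun, afterRun(1), beforeRun, afterRun(2), beforeRun, end
```

    The final beforeRun has no matching afterRun because the fake session signalled the end of input, yet end still runs; any other exception thrown by run() would propagate out of runLoop and skip end entirely.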

  4. trait HookTrigger extends AnyRef

    Determines when hooks should be triggered.

  5. class LossLogger extends TriggeredHook with ModelDependentHook[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any] with SummaryWriterHookAddOn

    Hook that logs the loss function value.

  6. trait ModelDependentHook[IT, IO, ID, IS, I, TT, TO, TD, TS, EI] extends Hook

    Represents hooks that may depend on the constructed model.

    This trait offers the modelInstance field, which subclasses can access and which contains information specific to the created model. It is only updated when the model graph is constructed (i.e., it is not updated while recovering failed sessions).

    For example, a hook that logs the loss function value depends on the created loss op, and an evaluation hook may depend on multiple ops created as part of the model.
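    The idea of a hook that depends on the constructed model can be sketched without TensorFlow: the hook holds an optional model description that is filled in once, when the model graph is built. ModelInfo, ModelAwareHook, onModelConstructed, and LossLoggingSketch are hypothetical names; the real trait's modelInstance carries model-specific information whose type is not shown here.

```scala
// Hypothetical names for illustration; not the TensorFlow Scala API.
final case class ModelInfo(lossOpName: String)

trait ModelAwareHook {
  private var currentModel: Option[ModelInfo] = None

  /** Information about the constructed model, if the model graph has been built. */
  protected def modelInstance: Option[ModelInfo] = currentModel

  /** Called when the model graph is constructed; in this sketch, session recovery never calls it. */
  final def onModelConstructed(info: ModelInfo): Unit = currentModel = Some(info)
}

// A loss-logging hook only needs to know which op holds the loss value.
final class LossLoggingSketch extends ModelAwareHook {
  def describe(): String = modelInstance match {
    case Some(info) => s"logging loss op '${info.lossOpName}'"
    case None       => "model not constructed yet"
  }
}

val hook = new LossLoggingSketch
val before = hook.describe()                      // "model not constructed yet"
hook.onModelConstructed(ModelInfo("loss/value"))  // "loss/value" is a made-up op name
val after = hook.describe()                       // "logging loss op 'loss/value'"
```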

  7. class NaNChecker extends Hook

    Monitors the provided tensors and stops training if, at any point, any of them contains any NaN values.

    This hook can either fail with an exception or just stop training.
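    The check itself is simple. A minimal sketch, assuming tensor values have already been fetched as plain Float arrays keyed by name, shows both behaviours (failing versus requesting a stop). NaNAction, FailOnNaN, StopOnNaN, NaNFoundException, and checkForNaNs are hypothetical names, not part of NaNChecker's API.

```scala
// Hypothetical names for illustration; this is not NaNChecker's implementation.
sealed trait NaNAction
case object FailOnNaN extends NaNAction
case object StopOnNaN extends NaNAction

final class NaNFoundException(name: String)
    extends RuntimeException(s"Found NaN values in '$name'.")

/** Returns true if training should stop. Throws when `action` is FailOnNaN and a NaN is found. */
def checkForNaNs(values: Map[String, Array[Float]], action: NaNAction): Boolean =
  values.collectFirst { case (name, v) if v.exists(_.isNaN) => name } match {
    case Some(name) if action == FailOnNaN => throw new NaNFoundException(name)
    case Some(_)                           => true  // StopOnNaN: request a stop instead of failing
    case None                              => false // no NaN values: keep training
  }

val healthy = Map("loss" -> Array(0.5f, 0.25f))
val broken  = Map("loss" -> Array(0.5f, Float.NaN))
println(checkForNaNs(healthy, StopOnNaN)) // false
println(checkForNaNs(broken, StopOnNaN))  // true
```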

  8. case class StepHookTrigger(numSteps: Int, startStep: Int = 0) extends HookTrigger with Product with Serializable

    Hook trigger that triggers at most once every numSteps steps.

    Parameters:

    • numSteps: Triggering step frequency.
    • startStep: Step after which to start triggering.
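    The documented semantics (at most once every numSteps steps, starting from startStep) can be sketched as a small stateful class. SimpleStepTrigger and shouldTrigger are hypothetical, and this sketch assumes that startStep itself is eligible to trigger; the actual StepHookTrigger may treat that boundary differently.

```scala
// Hypothetical sketch of StepHookTrigger-like behaviour; not the library implementation.
final class SimpleStepTrigger(numSteps: Int, startStep: Int = 0) {
  require(numSteps > 0, "numSteps must be positive")
  private var lastTriggeredStep: Option[Int] = None

  /** Returns true (and records the step) if the hook should run at `step`. */
  def shouldTrigger(step: Int): Boolean = {
    // Assumption: `startStep` itself may trigger.
    val due = step >= startStep && lastTriggeredStep.forall(last => step - last >= numSteps)
    if (due) lastTriggeredStep = Some(step)
    due
  }
}

val trigger = new SimpleStepTrigger(numSteps = 3, startStep = 2)
val fired = (0 until 10).filter(step => trigger.shouldTrigger(step))
println(fired.mkString(", ")) // 2, 5, 8
```

    Tracking the last triggered step, rather than testing step % numSteps == 0, keeps the "at most once" guarantee even when the observed step counter skips values.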

  9. class StepRateLogger extends TriggeredHook with SummaryWriterHookAddOn

    Logs and/or saves summaries with the number of steps executed per second, based on a HookTrigger.
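    The reported rate is the number of steps completed divided by the elapsed time between two observations. A minimal sketch of that arithmetic follows; StepRateTracker and update are hypothetical names, and timestamps are passed in explicitly (rather than read from a clock) to keep the example deterministic.

```scala
// Hypothetical sketch of the steps-per-second computation; not StepRateLogger's implementation.
final class StepRateTracker {
  private var last: Option[(Long, Double)] = None // (step, time in seconds)

  /** Returns steps/second since the previous call, or None on the first call. */
  def update(step: Long, timeSeconds: Double): Option[Double] = {
    val rate = last.collect {
      case (prevStep, prevTime) if timeSeconds > prevTime =>
        (step - prevStep) / (timeSeconds - prevTime)
    }
    last = Some((step, timeSeconds))
    rate
  }
}

val tracker = new StepRateTracker
tracker.update(step = 100, timeSeconds = 10.0)             // None: no previous observation yet
val rate = tracker.update(step = 300, timeSeconds = 14.0)  // 200 steps in 4 seconds
println(rate) // Some(50.0)
```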

  10. class SummarySaver extends TriggeredHook

    Saves summaries to files based on a HookTrigger.

  11. trait SummaryWriterHookAddOn extends Hook

    Add-on trait for hooks that provides convenience methods for using a summary writer.

  12. class TensorLogger extends TriggeredHook

    Logs the values of the provided tensors based on a HookTrigger, or at the end of a run (i.e., at the end of a Session's usage). The tensors are printed using the INFO logging level. If you are not seeing the logs, you may need to change the logging level in your logging configuration file.

    Note that if logAtEnd is true, tensors should not include any tensor whose evaluation produces a side effect, such as consuming additional inputs.

  13. case class TimeHookTrigger(numSeconds: Double, startStep: Int = 0) extends HookTrigger with Product with Serializable

    Hook trigger that triggers at most once every numSeconds seconds.

    Parameters:

    • numSeconds: Triggering time frequency.
    • startStep: Step after which to start triggering.
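    A time-based trigger follows the same pattern as a step-based one, except that it compares elapsed time rather than elapsed steps. In this sketch the clock is injected (in milliseconds) so the behaviour is deterministic; SimpleTimeTrigger, shouldTrigger, and clockMillis are hypothetical, and the real TimeHookTrigger may use a different time source and boundary handling.

```scala
// Hypothetical sketch of TimeHookTrigger-like behaviour; not the library implementation.
final class SimpleTimeTrigger(numSeconds: Double, startStep: Int = 0, clockMillis: () => Long) {
  require(numSeconds > 0, "numSeconds must be positive")
  private var lastTriggerMillis: Option[Long] = None

  /** Returns true (and records the time) if the hook should run at `step`. */
  def shouldTrigger(step: Int): Boolean = {
    val now = clockMillis()
    val due = step >= startStep &&
      lastTriggerMillis.forall(last => (now - last) >= numSeconds * 1000)
    if (due) lastTriggerMillis = Some(now)
    due
  }
}

// Simulated clock: each step takes 400 ms.
var fakeNowMillis = 0L
val trigger = new SimpleTimeTrigger(numSeconds = 1.0, clockMillis = () => fakeNowMillis)
val fired = (0 until 8).filter { step =>
  fakeNowMillis = step * 400L
  trigger.shouldTrigger(step)
}
println(fired.mkString(", ")) // 0, 3, 6
```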

  14. class TimelineHook extends TriggeredHook

    Hook that saves Chrome trace files for visualizing execution timelines of TensorFlow steps.

  15. abstract class TriggeredHook extends Hook

    Hook that may be triggered at certain steps or time points.
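    The split between deciding when to run and what to run can be sketched as a template method: a base class consults a trigger on every step and calls an abstract onTrigger only when the trigger fires. Trigger, EveryNSteps, SimpleTriggeredHook, and CountingHook are hypothetical names, and EveryNSteps uses a plain modulo rule for brevity, which is simpler than the "at most once every numSteps steps" rule documented for StepHookTrigger.

```scala
import scala.collection.mutable.ArrayBuffer

// Hypothetical names for illustration; not the TriggeredHook API.
trait Trigger {
  def shouldTrigger(step: Long): Boolean
}

// Simplified rule: fire on every step that is a multiple of n.
final case class EveryNSteps(n: Long) extends Trigger {
  require(n > 0, "n must be positive")
  def shouldTrigger(step: Long): Boolean = step % n == 0
}

abstract class SimpleTriggeredHook(trigger: Trigger) {
  /** Called on every step; delegates to onTrigger only when the trigger fires. */
  final def afterStep(step: Long): Unit =
    if (trigger.shouldTrigger(step)) onTrigger(step)

  protected def onTrigger(step: Long): Unit
}

// Records the steps at which the (possibly expensive) work would have run.
final class CountingHook(trigger: Trigger) extends SimpleTriggeredHook(trigger) {
  val triggeredAt: ArrayBuffer[Long] = ArrayBuffer.empty[Long]
  protected def onTrigger(step: Long): Unit = triggeredAt += step
}

val hook = new CountingHook(EveryNSteps(4))
(1L to 10L).foreach(step => hook.afterStep(step))
println(hook.triggeredAt.mkString(", ")) // 4, 8
```

    Keeping the trigger separate from the hook lets the same hook body be reused with step-based, time-based, or never-firing triggers.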

Value Members

  1. object CheckpointSaver

  2. object Evaluator

  3. object Hook

    Contains helper classes for the Hook class.

  4. object HookTrigger

  5. object LossLogger

  6. object NaNChecker

  7. object NoHookTrigger extends HookTrigger with Product with Serializable

    Hook trigger that never actually triggers.

  8. object StepRateLogger

  9. object Stopper

  10. object SummarySaver

  11. object TensorLogger

  12. object TimelineHook