org.platanios.tensorflow.api.learn.hooks

Hook

abstract class Hook extends AnyRef

Hook to extend calls to MonitoredSession.run().

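Hooks are useful to track training, report progress, request early stopping, and more. They use the observer pattern and notify at the following points:

• when the session is about to be created and right after it is created (begin() and afterSessionCreation()),
• before each call to Session.run() (beforeSessionRun()),
• after each call to Session.run() (afterSessionRun()),
• when the session stops being used (end()).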

A Hook encapsulates a piece of reusable/composable computation that can piggyback on a call to MonitoredSession.run(). A hook can add any feeds/fetches/targets to the run call, and when the run call finishes successfully, the hook gets the fetches it requested. Hooks are allowed to add ops to the graph in the begin() method. The graph is finalized after the begin() method is called.
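As a rough, dependency-free picture of this piggybacking, the sketch below merges the caller's fetches with the optional extra feeds and fetches requested by several hooks, evaluates everything in one "run", and hands each hook back only the values it asked for. It does not use TensorFlow: `RunArgsMergeModel`, `ToyRunArgs`, `mergeAndRun`, and the string-named tensors are hypothetical, and the actual merging logic of MonitoredSession may differ.

```scala
object RunArgsMergeModel {
  // Hypothetical, simplified stand-in for Hook.SessionRunArgs: feeds and named fetches only.
  final case class ToyRunArgs(
      feeds: Map[String, Double] = Map.empty,
      fetches: Seq[String] = Seq.empty)

  // Merges the caller's args with each hook's optional args, evaluates everything in a
  // single "run", and splits the flat result into (caller's values, one slice per hook).
  def mergeAndRun(
      original: ToyRunArgs,
      hookArgs: Seq[Option[ToyRunArgs]],
      evaluate: (Map[String, Double], Seq[String]) => Seq[Double]
  ): (Seq[Double], Seq[Seq[Double]]) = {
    val requested = hookArgs.map(_.getOrElse(ToyRunArgs())) // None contributes nothing.
    // Later feeds win on name clashes in this model.
    val allFeeds = requested.foldLeft(original.feeds)(_ ++ _.feeds)
    val allFetches = original.fetches ++ requested.flatMap(_.fetches)
    val values = evaluate(allFeeds, allFetches)
    // Hook i owns the slice of `values` that follows the caller's and earlier hooks' fetches.
    val sizes = requested.map(_.fetches.size)
    val starts = sizes.scanLeft(original.fetches.size)(_ + _)
    val perHook = sizes.zip(starts).map { case (n, start) => values.slice(start, start + n) }
    (values.take(original.fetches.size), perHook)
  }
}
```

For instance, if the caller fetches `loss`, one hook fetches `step`, a second returns None, and a third feeds `acc` and fetches `acc` and `loss`, the evaluator sees the single fetch list `[loss, step, acc, loss]`, and the three hooks receive one, zero, and two values respectively.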

There are a few pre-defined hooks that can be used without modification:

For more specific needs you can create custom hooks. For example:

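import com.typesafe.scalalogging.Logger
import org.platanios.tensorflow.api.core.client.{Executable, Fetchable, Session}
import org.platanios.tensorflow.api.learn.hooks.Hook
import org.platanios.tensorflow.api.ops.{Op, Output}
import org.platanios.tensorflow.api.tensors.Tensor
import org.platanios.tensorflow.api.types.DataType
import org.slf4j.LoggerFactory

class ExampleHook extends Hook {
  private[this] val logger: Logger = Logger(LoggerFactory.getLogger("Example Hook"))
  private[this] var exampleTensor: Output = _

  override protected def begin(): Unit = {
    // You can add ops to the graph here.
    logger.info("Starting the session.")
    exampleTensor = ... // Create the tensor whose value this hook should report.
  }

  override protected def afterSessionCreation(session: Session): Unit = {
    // When this is called, the graph is finalized and ops can no longer be added to it.
    logger.info("Session created.")
  }

  override protected def beforeSessionRun[F, E, R](runContext: Hook.SessionRunContext[F, E, R])(implicit
      executableEv: Executable[E],
      fetchableEv: Fetchable.Aux[F, R]
  ): Option[Hook.SessionRunArgs[Seq[Output], Traversable[Op], Seq[Tensor[DataType]]]] = {
    logger.info("Before calling `Session.run()`.")
    // Ask for `exampleTensor` to be fetched together with the original fetches.
    Some(Hook.SessionRunArgs(fetches = Seq(exampleTensor)))
  }

  override protected def afterSessionRun[F, E, R](
      runContext: Hook.SessionRunContext[F, E, R],
      runResult: Hook.SessionRunResult[Seq[Output], Seq[Tensor[DataType]]]
  )(implicit
      executableEv: Executable[E],
      fetchableEv: Fetchable.Aux[F, R]
  ): Unit = {
    // `runResult.values` holds the tensors requested in `beforeSessionRun`, in the same order.
    logger.info(s"Done running one step. The value of the tensor is: ${runResult.values.head.summarize()}")
    if (needToStop) // Replace `needToStop` with your own stopping condition.
      runContext.requestStop()
  }

  override protected def end(session: Session): Unit = {
    logger.info("Done with the session.")
  }
}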
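The `needToStop` condition in the example is left to the reader. One common choice is to stop once a fetched value, such as a loss, falls below a threshold. The dependency-free sketch below models only that decision: `EarlyStopModel`, `RunContext`, `EarlyStoppingHook`, and `train` are hypothetical names, and `RunContext` mimics just the `requestStop()` part of Hook.SessionRunContext.

```scala
object EarlyStopModel {
  // Hypothetical stand-in for Hook.SessionRunContext: only the stop request is modelled.
  final class RunContext {
    private var stopRequested = false
    def requestStop(): Unit = { stopRequested = true }
    def shouldStop: Boolean = stopRequested
  }

  // Hypothetical hook: inspects the value it fetched and asks the loop to stop.
  final class EarlyStoppingHook(threshold: Double) {
    def afterSessionRun(runContext: RunContext, fetchedLoss: Double): Unit =
      if (fetchedLoss < threshold) runContext.requestStop()
  }

  // Runs one "step" per loss value until a stop is requested or the values run out,
  // and returns the number of steps that were run.
  def train(losses: Seq[Double], hook: EarlyStoppingHook): Int = {
    val runContext = new RunContext
    val remaining = losses.iterator
    var steps = 0
    while (!runContext.shouldStop && remaining.hasNext) {
      val loss = remaining.next() // Plays the role of the value fetched by Session.run().
      steps += 1
      hook.afterSessionRun(runContext, loss)
    }
    steps
  }
}
```

With losses 1.0, 0.5, 0.2, 0.1 and a threshold of 0.3, the loop stops after the third step: the stop is requested inside afterSessionRun() and checked before the next run.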

To understand how hooks interact with calls to MonitoredSession.run(), look at the following code:

val session = Estimator.monitoredTrainingSession(hooks = someHook, ...)
while (!session.shouldStop)
  session.run(...)
session.close()

The above user code loosely leads to the following execution:

someHook.begin()
val session = tf.Session()
someHook.afterSessionCreation(session)
var stopRequested = false // Also set when a hook calls runContext.requestStop().
while (!stopRequested) {
  someHook.beforeSessionRun(...)
  try {
    val result = session.run(mergedSessionRunArgs)
    someHook.afterSessionRun(..., result)
  } catch {
    case _: OutOfRangeException => stopRequested = true
  }
}
someHook.end(session)
session.close()

Note that if session.run() throws an OutOfRangeException then someHook.afterSessionRun() will not be called, but someHook.end() will still be called. On the other hand, if session.run() throws any other exception, then neither someHook.afterSessionRun() nor someHook.end() will be called.
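The ordering and exception rules above can be checked with a small model. The sketch below does not use TensorFlow: `HookLifecycleModel`, `ToyHook`, `RecordingHook`, `OutOfRange`, and `drive` are hypothetical names, and the "session" is just a function that either returns a step result or throws. It mirrors only the callback order and the way OutOfRangeException is treated differently from other exceptions.

```scala
import scala.collection.mutable.ArrayBuffer

object HookLifecycleModel {
  // Hypothetical stand-in for TensorFlow's OutOfRangeException.
  final class OutOfRange extends RuntimeException("out of range")

  // Hypothetical hook interface with the same callbacks as Hook (arguments simplified).
  trait ToyHook {
    def begin(): Unit = ()
    def afterSessionCreation(): Unit = ()
    def beforeSessionRun(step: Int): Unit = ()
    def afterSessionRun(step: Int, result: Int): Unit = ()
    def end(): Unit = ()
  }

  // Records every callback so that the call order can be inspected.
  final class RecordingHook extends ToyHook {
    val calls: ArrayBuffer[String] = ArrayBuffer.empty[String]
    override def begin(): Unit = calls += "begin"
    override def afterSessionCreation(): Unit = calls += "afterSessionCreation"
    override def beforeSessionRun(step: Int): Unit = calls += s"beforeSessionRun($step)"
    override def afterSessionRun(step: Int, result: Int): Unit = calls += s"afterSessionRun($step)"
    override def end(): Unit = calls += "end"
  }

  // Loosely mirrors the execution described above; `run` plays the role of session.run().
  def drive(hook: ToyHook, run: Int => Int): Unit = {
    hook.begin()
    hook.afterSessionCreation()
    var stopRequested = false
    var step = 0
    while (!stopRequested) {
      hook.beforeSessionRun(step)
      try {
        val result = run(step)
        hook.afterSessionRun(step, result)
      } catch {
        // Only OutOfRange ends the loop quietly. Any other exception propagates out of
        // drive(), so neither afterSessionRun() nor end() is called for it.
        case _: OutOfRange => stopRequested = true
      }
      step += 1
    }
    hook.end()
  }
}
```

For example, with `run = step => if (step < 2) step * 10 else throw new OutOfRange`, a RecordingHook sees begin and afterSessionCreation, then beforeSessionRun/afterSessionRun for steps 0 and 1, then beforeSessionRun(2) without a matching afterSessionRun(2), and finally end.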


Instance Constructors

  1. new Hook()

Value Members

  1. final def !=(arg0: Any): Boolean

    Definition Classes
    AnyRef → Any
  2. final def ##(): Int

    Definition Classes
    AnyRef → Any
  3. final def ==(arg0: Any): Boolean

    Definition Classes
    AnyRef → Any
  4. def afterSessionCreation(session: core.client.Session): Unit

    Called after a new session is created, to signal to the hooks that a new session exists. This callback differs from begin() in two essential ways:

    • When this is called, the graph is finalized and ops can no longer be added to it.
    • This method is also called when a wrapped session is recovered (i.e., not only at the beginning of the overall session).

    session

    The session that has been created.

    Attributes
    protected
  5. def afterSessionRun[F, E, R](runContext: SessionRunContext[F, E, R], runResult: SessionRunResult[Seq[ops.Output], Seq[tensors.Tensor[types.DataType]]])(implicit executableEv: Executable[E], fetchableEv: Aux[F, R]): Unit

    Called after each call to Session.run().

    The runContext argument is the same one passed to beforeSessionRun(). runContext.requestStop() can be called to stop the iteration.

    The runResult argument contains fetched values for the tensors requested by beforeSessionRun().

    If Session.run() throws any exception, then afterSessionRun() will not be called. Note the difference between the end() and the afterSessionRun() behavior when Session.run() throws an OutOfRangeException. In that case, end() is called but afterSessionRun() is not called.

    runContext

    Provides information about the run call (i.e., the originally requested ops/tensors, the session, etc.). Same value as that passed to beforeSessionRun.

    runResult

    Result of the Session.run() call that includes the fetched values for the tensors requested by beforeSessionRun().

    Attributes
    protected
  6. final def asInstanceOf[T0]: T0

    Definition Classes
    Any
  7. def beforeSessionRun[F, E, R](runContext: SessionRunContext[F, E, R])(implicit executableEv: Executable[E], fetchableEv: Aux[F, R]): Option[SessionRunArgs[Seq[ops.Output], Traversable[ops.Op], Seq[tensors.Tensor[types.DataType]]]]

    Called before each call to Session.run(). You can return from this call a Hook.SessionRunArgs object (wrapped in Some) indicating ops or tensors to add to the upcoming run call, or None to add nothing. These ops/tensors will be run together with the ops/tensors passed to the original run call. The returned run arguments can also contain feeds to be added to the run call.

    The runContext argument is a Hook.SessionRunContext that provides information about the upcoming run call (i.e., the originally requested ops/tensors, the session, etc.).

    At this point the graph is finalized and you should not add any new ops.

    runContext

    Provides information about the upcoming run call (i.e., the originally requested ops/tensors, the session, etc.).

    Attributes
    protected
  8. def begin(): Unit

    Called once before creating the session. When it is called, the default graph is the one that will be launched in the session, and the hook can modify that graph by adding new operations to it. After begin() returns, the graph is finalized and the other callbacks can no longer modify it. A second begin() call on the same graph should not change that graph.

    Attributes
    protected
  9. def clone(): AnyRef

    Attributes
    protected[java.lang]
    Definition Classes
    AnyRef
    Annotations
    @throws( ... )
  10. def end(session: core.client.Session): Unit

    Called at the end of the session usage (i.e., Session.run() will not be invoked again after this call).

    The session argument can be used in case the hook wants to execute any final ops, such as saving a last checkpoint.

    If Session.run() throws any exception other than OutOfRangeException then end() will not be called. Note the difference between the end() and the afterSessionRun() behavior when Session.run() throws an OutOfRangeException. In that case, end() is called but afterSessionRun() is not called.

    session

    Session that will not be used again after this call.

    Attributes
    protected
  11. final def eq(arg0: AnyRef): Boolean

    Definition Classes
    AnyRef
  12. def equals(arg0: Any): Boolean

    Definition Classes
    AnyRef → Any
  13. def finalize(): Unit

    Attributes
    protected[java.lang]
    Definition Classes
    AnyRef
    Annotations
    @throws( classOf[java.lang.Throwable] )
  14. final def getClass(): Class[_]

    Definition Classes
    AnyRef → Any
  15. def hashCode(): Int

    Definition Classes
    AnyRef → Any
  16. final def isInstanceOf[T0]: Boolean

    Definition Classes
    Any
  17. final def ne(arg0: AnyRef): Boolean

    Definition Classes
    AnyRef
  18. final def notify(): Unit

    Definition Classes
    AnyRef
  19. final def notifyAll(): Unit

    Definition Classes
    AnyRef
  20. final def synchronized[T0](arg0: ⇒ T0): T0

    Definition Classes
    AnyRef
  21. def toString(): String

    Definition Classes
    AnyRef → Any
  22. final def wait(): Unit

    Definition Classes
    AnyRef
    Annotations
    @throws( ... )
  23. final def wait(arg0: Long, arg1: Int): Unit

    Definition Classes
    AnyRef
    Annotations
    @throws( ... )
  24. final def wait(arg0: Long): Unit

    Definition Classes
    AnyRef
    Annotations
    @throws( ... )
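The differences between begin() and afterSessionCreation() described above can be modelled without TensorFlow. In the sketch below, `SetupCallbacksModel`, `ToyGraph`, `StepCounterHook`, and `drive` are hypothetical: graph changes happen in an idempotent begin() before the graph is finalized, while per-session state is reset in afterSessionCreation(), which runs again whenever the driver "recovers" from a failed step. How and when MonitoredSession actually recovers a wrapped session is not modelled.

```scala
import scala.collection.mutable

object SetupCallbacksModel {
  // Hypothetical toy graph: named ops that can no longer be added once finalized.
  final class ToyGraph {
    private val ops = mutable.LinkedHashSet.empty[String]
    private var finalized = false
    def finalizeGraph(): Unit = { finalized = true }
    def contains(name: String): Boolean = ops.contains(name)
    def size: Int = ops.size
    def addOp(name: String): Unit = {
      if (finalized) throw new IllegalStateException(s"Graph is finalized; cannot add '$name'.")
      ops += name
    }
  }

  // Hypothetical hook: graph changes go in begin(), written so that a second call on the
  // same graph changes nothing; per-session state is reset in afterSessionCreation().
  final class StepCounterHook {
    var sessionsSeen = 0
    var stepsInCurrentSession = 0
    def begin(graph: ToyGraph): Unit =
      if (!graph.contains("step_counter")) graph.addOp("step_counter")
    def afterSessionCreation(): Unit = {
      sessionsSeen += 1
      stepsInCurrentSession = 0
    }
    def afterSessionRun(): Unit = stepsInCurrentSession += 1
  }

  // Hypothetical driver: begin() once, finalize the graph, then create a session. A failed
  // step (`run` returns false) "recovers" by creating a new session, which calls
  // afterSessionCreation() again but never begin().
  def drive(graph: ToyGraph, hook: StepCounterHook, steps: Int, run: Int => Boolean): Unit = {
    hook.begin(graph)
    graph.finalizeGraph()
    hook.afterSessionCreation()
    var step = 0
    while (step < steps) {
      if (run(step)) {
        hook.afterSessionRun()
        step += 1
      } else {
        hook.afterSessionCreation()
      }
    }
  }
}
```

With five steps and a single failure at step 2, begin() runs once, afterSessionCreation() runs twice, and the per-session counter ends at 3, reflecting only the steps run in the recovered session.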
