org.platanios.tensorflow.api.learn.estimators

Estimator

Related Docs: object Estimator | package estimators

abstract class Estimator[IT, IO, ID, IS, I, TT, TO, TD, TS, EI] extends AnyRef

Abstract class for estimators which are used to train, use, and evaluate TensorFlow models.

The Estimator class wraps a model specified by a modelFunction which, given inputs and a number of other parameters, creates the ops necessary to perform training, evaluation, or prediction. The Estimator then provides an interface for performing each of these tasks.

All outputs (checkpoints, event files, etc.) are written to a working directory, provided by configurationBase, or a subdirectory thereof. If a working directory is not set in configurationBase, a temporary directory is used.
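The working-directory fallback can be sketched as follows. This is a minimal illustration under stated assumptions: the helpers resolveWorkingDir and outputDir are hypothetical (they are not part of the library), and the temporary directory is created with java.nio.file.Files.createTempDirectory, which may not be what the library actually uses.

```scala
import java.nio.file.{Files, Path, Paths}

// Hypothetical helper (not part of the library): use the configured working
// directory if one is set, otherwise fall back to a new temporary directory.
def resolveWorkingDir(configured: Option[Path]): Path =
  configured.getOrElse(Files.createTempDirectory("estimator"))

// Hypothetical helper: an output location is either the working directory
// itself or a named subdirectory of it (e.g., for evaluation summaries).
def outputDir(workingDir: Path, subdirectory: Option[String]): Path =
  subdirectory.map(name => workingDir.resolve(name)).getOrElse(workingDir)

val explicitDir = resolveWorkingDir(Some(Paths.get("/tmp/my-model")))
val fallbackDir = resolveWorkingDir(None)

println(outputDir(explicitDir, Some("eval"))) // /tmp/my-model/eval
println(Files.isDirectory(fallbackDir))       // true
```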

A Configuration object containing information about the execution environment can be passed through the configurationBase argument. It is passed on to the modelFunction if the modelFunction has an argument of type Configuration (and to input functions in the same manner). If configurationBase is not provided, the Estimator instantiates a configuration itself, using defaults suited to local execution. The Estimator makes the configuration available to the model (for instance, to allow specialization based on the number of available workers), and also uses some of its fields to control its internals, especially checkpoint saving during training.
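One way to picture this behaviour is sketched below. All types and functions here (Configuration, ModelFn, SimpleModelFn, ConfiguredModelFn, buildModel) are hypothetical stand-ins; the library's actual ModelFunction type and the way it decides whether to forward the configuration may differ.

```scala
// Hypothetical stand-in for the library's Configuration (illustration only).
final case class Configuration(numWorkers: Int = 1)

// Hypothetical model-function shapes: one that ignores the configuration and
// one that receives it and can specialize on the execution environment.
sealed trait ModelFn
final case class SimpleModelFn(build: () => String) extends ModelFn
final case class ConfiguredModelFn(build: Configuration => String) extends ModelFn

// Sketch of the dispatch: the configuration is forwarded only when the model
// function declares a parameter for it. Without a base configuration, a
// default (local-execution) configuration is used.
def buildModel(fn: ModelFn, configurationBase: Option[Configuration]): String = {
  val configuration = configurationBase.getOrElse(Configuration())
  fn match {
    case SimpleModelFn(build)     => build()
    case ConfiguredModelFn(build) => build(configuration)
  }
}

val plain = SimpleModelFn(() => "single-replica model")
val aware = ConfiguredModelFn(c => s"model for ${c.numWorkers} worker(s)")

println(buildModel(plain, Some(Configuration(numWorkers = 4)))) // single-replica model
println(buildModel(aware, Some(Configuration(numWorkers = 4)))) // model for 4 worker(s)
println(buildModel(aware, None))                                // model for 1 worker(s)
```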

For models that have hyper-parameters, it is recommended to incorporate them into modelFunction before instantiating an estimator. This differs from the TensorFlow Python API, where hyper-parameters are passed to the estimator through its params argument; the divergence is motivated by the fact that the estimator class itself never uses the provided hyper-parameters. The recommended way to handle hyper-parameters in the Scala API is to create a model function with two parameter lists: the first one containing the hyper-parameters, and the second one containing the arguments supported by the model-generating function (i.e., optionally a Mode and a Configuration).
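The two-parameter-list pattern can be sketched in plain Scala as follows. Mode, Configuration, HyperParameters, and ModelSpec are hypothetical stand-ins defined only for this illustration (the real library wraps model functions in its own ModelFunction type), and the mode-dependent dropout string merely shows that the model can react to the Mode argument.

```scala
// Hypothetical stand-ins for the library's Mode and Configuration types.
sealed trait Mode
case object Train extends Mode
case object Evaluate extends Mode
case object Infer extends Mode
final case class Configuration(numWorkers: Int = 1)

// Hyper-parameters are plain data, fixed before the estimator is created.
final case class HyperParameters(learningRate: Double, hiddenUnits: Int)

// Hypothetical description of what the model function produces; in the real
// library this would be the ops for training, evaluation, or inference.
final case class ModelSpec(description: String)

// First parameter list: hyper-parameters. Second parameter list: the
// arguments supplied by the estimator (here, a Mode and a Configuration).
def model(hp: HyperParameters)(mode: Mode, config: Configuration): ModelSpec = {
  val dropout = if (mode == Train) "with dropout" else "without dropout"
  ModelSpec(s"${hp.hiddenUnits} units, lr=${hp.learningRate}, $dropout, ${config.numWorkers} worker(s)")
}

// Applying only the first list yields the function an estimator would hold;
// the hyper-parameters are baked in and never seen by the estimator.
val modelFunction: (Mode, Configuration) => ModelSpec =
  model(HyperParameters(learningRate = 0.01, hiddenUnits = 128))

println(modelFunction(Train, Configuration()).description)
// 128 units, lr=0.01, with dropout, 1 worker(s)
```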

Abstract Value Members

  1. abstract def evaluate(data: () ⇒ Dataset[TT, TO, TD, TS], metrics: Seq[Metric[EI, ops.Output]], maxSteps: Long = 1L, saveSummaries: Boolean = true, name: String = null): Seq[tensors.Tensor[types.FLOAT32]]

    Evaluates the model managed by this estimator given the provided evaluation data, data.

    The evaluation process is iterative. In each step, a data batch is obtained from data and internal metric value accumulators are updated. The number of steps to perform is controlled through the maxSteps argument. If set to -1, then all batches from data will be processed.

    data

    Evaluation dataset. Each element is a tuple containing the model input and the training input (i.e., supervision labels).

    metrics

    Evaluation metrics to use.

    maxSteps

    Maximum number of evaluation steps to perform. If -1, the evaluation process will run until data is exhausted.

    saveSummaries

    Boolean indicator specifying whether to save the evaluation results as summaries in the working directory of this estimator.

    name

    Name for this evaluation. If provided, it is used to generate an appropriate directory name for the resulting summaries. If saveSummaries is false, this argument has no effect. This is useful when running multiple evaluations on different datasets, such as training data vs. test data: metrics for different evaluations are saved in separate folders and appear separately in TensorBoard.

    returns

    Evaluation metric values at the end of the evaluation process. The return sequence matches the ordering of metrics.

    Annotations
    @throws( ... )
    Exceptions thrown

    InvalidArgumentException If saveSummaries is true, but the estimator has no working directory specified.
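
    The iterative evaluation described above can be sketched with plain Scala collections. This is not the library's implementation: batches are Seq[Double] values from an Iterator instead of a TensorFlow Dataset, and MeanAccumulator is a hypothetical stand-in for a Metric with internal accumulators. The sketch shows how maxSteps bounds the loop and how -1 means "run until data is exhausted".

```scala
// Hypothetical metric accumulator (not the library's Metric): running mean.
final class MeanAccumulator {
  private var sum = 0.0
  private var count = 0L
  def update(batch: Seq[Double]): Unit = { sum += batch.sum; count += batch.size }
  def value: Double = if (count == 0L) 0.0 else sum / count
}

// Sketch of the evaluation loop: process at most maxSteps batches, or every
// batch when maxSteps == -1, updating all accumulators at each step.
def evaluateSketch(
    data: () => Iterator[Seq[Double]],
    metrics: Seq[MeanAccumulator],
    maxSteps: Long): Seq[Double] = {
  val batches = data()
  var step = 0L
  while (batches.hasNext && (maxSteps == -1L || step < maxSteps)) {
    val batch = batches.next()
    metrics.foreach(_.update(batch))
    step += 1
  }
  metrics.map(_.value) // same ordering as metrics
}

val data = () => Iterator(Seq(1.0, 2.0), Seq(3.0), Seq(10.0, 20.0))
println(evaluateSketch(data, Seq(new MeanAccumulator), maxSteps = 2L))  // List(2.0)
println(evaluateSketch(data, Seq(new MeanAccumulator), maxSteps = -1L)) // List(7.2)
```

    Passing data as a function (mirroring the () ⇒ Dataset parameter) means every evaluation call gets a fresh iterator, so the same data can be evaluated repeatedly.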

  2. abstract def infer[InferInput, InferOutput, ModelInferenceOutput](input: () ⇒ InferInput)(implicit evFetchableIO: Aux[IO, IT], evFetchableI: Aux[I, ModelInferenceOutput], evFetchableIIO: Aux[(IO, I), (IT, ModelInferenceOutput)], ev: SupportedInferInput[InferInput, InferOutput, IT, IO, ID, IS, ModelInferenceOutput]): InferOutput

    Infers output (i.e., computes predictions) for input using the model managed by this estimator.

    input can be of one of the following types:

    • A Dataset, in which case this method returns an iterator over (input, output) tuples corresponding to each element in the dataset. Note that the predictions are computed lazily in this case, whenever an element is requested from the returned iterator.
    • A single input of type IT, in which case this method returns a single prediction of type ModelInferenceOutput.

    Note that ModelInferenceOutput refers to the tensor type that corresponds to the symbolic type I. For example, if I is (Output, Output), then ModelInferenceOutput will be (Tensor, Tensor).

    input

    Input for the predictions.

    returns

    Either an iterator over (IT, ModelInferenceOutput) tuples or a single prediction of type ModelInferenceOutput, depending on the type of input.
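
    The difference between the lazy dataset case and the single-input case can be sketched as follows. This is only an illustration: predict is a hypothetical pure function standing in for a TensorFlow session run, and the two separately named functions inferAll and inferOne replace the single infer method, which in the library dispatches on the input type through the implicit SupportedInferInput evidence.

```scala
// Hypothetical model standing in for a TensorFlow session run; the counter
// records how many predictions have actually been computed.
var computed = 0
def predict(x: Double): Double = { computed += 1; 2.0 * x + 1.0 }

// Dataset-like input: return (input, prediction) pairs lazily, so each
// prediction is computed only when its element is requested.
def inferAll(inputs: Iterator[Double]): Iterator[(Double, Double)] =
  inputs.map(x => (x, predict(x)))

// Single input: compute and return one prediction directly.
def inferOne(x: Double): Double = predict(x)

val results = inferAll(Iterator(1.0, 2.0, 3.0))
println(computed)       // 0
println(results.next()) // (1.0,3.0)
println(computed)       // 1
println(inferOne(4.0))  // 9.0
```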

  3. abstract def train(data: () ⇒ Dataset[TT, TO, TD, TS], stopCriteria: StopCriteria = StopCriteria()): Unit

    Trains the model managed by this estimator.

    data

    Training dataset. Each element is a tuple containing the model input and the training input (i.e., supervision labels).

    stopCriteria

    Criteria used to decide when to stop the training loop. For the default criteria, please refer to the documentation of StopCriteria.
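
    The sketch below illustrates how stop criteria can bound a training loop. SimpleStopCriteria is hypothetical and is not the library's StopCriteria (whose fields and defaults are documented separately); the loop runs gradient descent on the toy objective (w - 3)^2 instead of TensorFlow ops.

```scala
// Hypothetical stop criteria (not the library's StopCriteria): stop after
// maxSteps steps, or earlier if the absolute loss change drops below tol.
final case class SimpleStopCriteria(maxSteps: Long = 1000L, tol: Option[Double] = None) {
  def shouldStop(step: Long, previousLoss: Double, loss: Double): Boolean =
    step >= maxSteps || tol.exists(t => math.abs(previousLoss - loss) < t)
}

// Toy training loop: gradient descent on f(w) = (w - 3)^2.
// Returns the final parameter value and the number of steps performed.
def train(stopCriteria: SimpleStopCriteria = SimpleStopCriteria()): (Double, Long) = {
  var w = 0.0
  var step = 0L
  var previousLoss = Double.MaxValue
  var loss = (w - 3.0) * (w - 3.0)
  while (!stopCriteria.shouldStop(step, previousLoss, loss)) {
    w -= 0.1 * 2.0 * (w - 3.0) // gradient step with learning rate 0.1
    previousLoss = loss
    loss = (w - 3.0) * (w - 3.0)
    step += 1
  }
  (w, step)
}

val (w, steps) = train(SimpleStopCriteria(maxSteps = 5L))
println(steps) // 5
```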

Concrete Value Members

  1. final def !=(arg0: Any): Boolean

    Definition Classes
    AnyRef → Any
  2. final def ##(): Int

    Definition Classes
    AnyRef → Any
  3. final def ==(arg0: Any): Boolean

    Definition Classes
    AnyRef → Any
  4. final def asInstanceOf[T0]: T0

    Definition Classes
    Any
  5. def checkpointConfig: CheckpointConfig

    Checkpoint configuration used by this estimator.

  6. def clone(): AnyRef

    Attributes
    protected[java.lang]
    Definition Classes
    AnyRef
    Annotations
    @throws( ... )
  7. val configuration: Configuration

    Run configuration used for this estimator.

  8. val configurationBase: Configuration

    Configuration base for this estimator. This allows for setting up distributed training environments, for example. Note that this is a *base* for a configuration because the estimator may modify it, setting missing fields to appropriate default values, in order to obtain its final configuration, which can be accessed through its configuration field.

    Attributes
    protected
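
    Deriving a final configuration from a base by filling in missing fields can be sketched as follows. The Configuration class, its fields, and the 1000-step checkpoint default are all hypothetical and chosen only for illustration; the library's Configuration and its defaults may differ.

```scala
import java.nio.file.{Files, Path}

// Hypothetical stand-in for Configuration with optional fields (illustration only).
final case class Configuration(
    workingDir: Option[Path] = None,
    randomSeed: Option[Int] = None,
    checkpointEverySteps: Option[Long] = None)

// Sketch: derive the final configuration from the base by filling in missing
// fields with defaults; values the user provided are left untouched.
// The default of 1000 steps is a made-up value for illustration.
def finalizeConfiguration(base: Configuration): Configuration =
  base.copy(
    workingDir = base.workingDir.orElse(Some(Files.createTempDirectory("estimator"))),
    checkpointEverySteps = base.checkpointEverySteps.orElse(Some(1000L)))

val configurationBase = Configuration(randomSeed = Some(42))
val configuration = finalizeConfiguration(configurationBase)
println(configuration.randomSeed)           // Some(42)
println(configuration.checkpointEverySteps) // Some(1000)
```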
  9. val deviceFunction: Option[(OpSpecification) ⇒ String]

    Device function used by this estimator for managing replica device placement when using distributed training.
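
    This page does not say how placement decisions are made. The sketch below assumes a common parameter-server strategy: variable ops are assigned round-robin to "/job:ps/task:N" devices and all other ops stay on a worker. OpSpec is a hypothetical stand-in for OpSpecification, and treating "VariableV2" and "VarHandleOp" as the variable op types is an assumption.

```scala
// Hypothetical stand-in for OpSpecification with only the fields used here.
final case class OpSpec(name: String, opType: String)

// Assumption: these TensorFlow op types identify variables.
val variableOpTypes = Set("VariableV2", "VarHandleOp")

// Sketch of a replica device function: variable ops are placed round-robin on
// parameter-server tasks; all other ops stay on the given worker device.
def replicaDeviceFunction(numPsTasks: Int, workerDevice: String): OpSpec => String = {
  var nextPsTask = 0
  (op: OpSpec) => {
    if (variableOpTypes.contains(op.opType)) {
      val device = s"/job:ps/task:$nextPsTask"
      nextPsTask = (nextPsTask + 1) % numPsTasks
      device
    } else workerDevice
  }
}

val deviceFunction: Option[OpSpec => String] =
  Some(replicaDeviceFunction(numPsTasks = 2, workerDevice = "/job:worker/task:0"))

val ops = Seq(OpSpec("w1", "VarHandleOp"), OpSpec("matmul", "MatMul"), OpSpec("w2", "VarHandleOp"))
deviceFunction.foreach(f => ops.foreach(op => println(s"${op.name} -> ${f(op)}")))
```

    Because the function is stateful (it remembers the next parameter-server task), each op should be passed to it exactly once, in graph-construction order.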

  10. final def eq(arg0: AnyRef): Boolean

    Definition Classes
    AnyRef
  11. def equals(arg0: Any): Boolean

    Definition Classes
    AnyRef → Any
  12. def finalize(): Unit

    Attributes
    protected[java.lang]
    Definition Classes
    AnyRef
    Annotations
    @throws( classOf[java.lang.Throwable] )
  13. final def getClass(): Class[_]

    Definition Classes
    AnyRef → Any
  14. def getOrCreateSaver(): Option[Saver]

    Gets an existing saver from the current graph, or creates a new one if none exists.

    Attributes
    protected
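
    The get-or-create pattern can be sketched with hypothetical stand-ins for a graph and a saver. The page does not state when None is returned; this sketch assumes it happens when the graph has nothing to save, which is an assumption rather than documented behaviour.

```scala
import scala.collection.mutable

// Hypothetical stand-ins for a saver and a graph with a named collection.
final class Saver(val id: Int)
final class Graph(val hasVariables: Boolean) {
  val savers: mutable.Map[String, Saver] = mutable.Map.empty
}

// Counts how many savers have been created, to show that reuse happens.
var saversCreated = 0

// Sketch of get-or-create: reuse the saver registered in the graph, or create
// and register one. Assumption: None is returned when nothing can be saved.
def getOrCreateSaver(graph: Graph): Option[Saver] =
  if (!graph.hasVariables) None
  else Some(graph.savers.getOrElseUpdate("saver", {
    saversCreated += 1
    new Saver(saversCreated)
  }))

val graph = new Graph(hasVariables = true)
val first = getOrCreateSaver(graph)
val second = getOrCreateSaver(graph)
println(first.get eq second.get)                           // true
println(getOrCreateSaver(new Graph(hasVariables = false))) // None
```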
  15. def hashCode(): Int

    Definition Classes
    AnyRef → Any
  16. final def isInstanceOf[T0]: Boolean

    Definition Classes
    Any
  17. val modelFunction: ModelFunction[IT, IO, ID, IS, I, TT, TO, TD, TS, EI]

    Model-generating function, which can optionally take a Configuration argument. Through this argument, the estimator's configuration is passed to the model, allowing the model to be customized based on the execution environment.

    Attributes
    protected
  18. final def ne(arg0: AnyRef): Boolean

    Definition Classes
    AnyRef
  19. final def notify(): Unit

    Definition Classes
    AnyRef
  20. final def notifyAll(): Unit

    Definition Classes
    AnyRef
  21. def randomSeed: Option[Int]

    Random seed value to be used by the TensorFlow initializers in this estimator.

  22. def saveEvaluationSummaries(step: Long, metrics: Seq[Metric[EI, ops.Output]], metricValues: Seq[tensors.Tensor[types.FLOAT32]], name: String = null): Unit

    Attributes
    protected
  23. def sessionConfig: Option[SessionConfig]

    Session configuration used by this estimator.

  24. final def synchronized[T0](arg0: ⇒ T0): T0

    Definition Classes
    AnyRef
  25. def toString(): String

    Definition Classes
    AnyRef → Any
  26. final def wait(): Unit

    Definition Classes
    AnyRef
    Annotations
    @throws( ... )
  27. final def wait(arg0: Long, arg1: Int): Unit

    Definition Classes
    AnyRef
    Annotations
    @throws( ... )
  28. final def wait(arg0: Long): Unit

    Definition Classes
    AnyRef
    Annotations
    @throws( ... )
  29. def workingDir: Option[Path]

    Working directory used by this estimator, used to save model parameters, graph, etc. It can also be used to load checkpoints for a previously saved model.
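
    Locating a previously saved checkpoint in the working directory can be sketched as follows. The file-naming convention ("model.ckpt-<step>.index") is an assumption modelled on common TensorFlow practice and is not taken from this page; the library may instead track checkpoints through a state file. The helper latestCheckpoint is hypothetical.

```scala
import java.nio.file.{Files, Path}

// Assumed TensorFlow-style checkpoint naming: "model.ckpt-<step>.index".
val CheckpointIndex = """model\.ckpt-(\d+)\.index""".r

// Sketch: locate the checkpoint with the highest step in a working directory,
// returning the step and the checkpoint prefix path.
def latestCheckpoint(workingDir: Path): Option[(Long, Path)] = {
  val files = Option(workingDir.toFile.listFiles()).map(_.toSeq).getOrElse(Seq.empty)
  files
    .flatMap { file =>
      file.getName match {
        case CheckpointIndex(step) => Some((step.toLong, workingDir.resolve(s"model.ckpt-$step")))
        case _                     => None
      }
    }
    .maxByOption(_._1)
}

val dir = Files.createTempDirectory("estimator")
Seq("model.ckpt-100.index", "model.ckpt-2000.index", "events.out.tfevents.1")
  .foreach(name => Files.createFile(dir.resolve(name)))
println(latestCheckpoint(dir).map(_._1)) // Some(2000)
```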
