Class Trainer

    • Constructor Detail

      • Trainer

        public Trainer(Model model,
                       TrainingConfig trainingConfig)
        Creates an instance of Trainer with the given Model and TrainingConfig.
        Parameters:
        model - the model that the trainer will train
        trainingConfig - the configuration used by the trainer
    • Method Detail

      • initialize

        public void initialize(Shape... shapes)
        Initializes the Model that the Trainer is going to train.
        Parameters:
        shapes - the Shapes of the inputs, one per input
      • iterateDataset

        public java.lang.Iterable<Batch> iterateDataset(Dataset dataset)
                                                 throws java.io.IOException,
                                                        TranslateException
        Returns an Iterable that iterates through the given Dataset in batches.
        Parameters:
        dataset - the dataset to iterate through
        Returns:
        an Iterable of Batch that contains batches of data from the dataset
        Throws:
        java.io.IOException - if reading the dataset fails; the exact cause depends on the dataset
        TranslateException - if there is an error while processing input
      • forward

        public NDList forward(NDList input)
        Applies the forward function of the model once on the given input NDList.
        Parameters:
        input - the input NDList
        Returns:
        the output of the forward function
      • forward

        public NDList forward(NDList data,
                              NDList labels)
        Applies the forward function of the model once with both data and labels.
        Parameters:
        data - the input data NDList
        labels - the input labels NDList
        Returns:
        the output of the forward function
      • evaluate

        public NDList evaluate(NDList input)
        Applies the predict function of the model once on the given input NDList.
        Parameters:
        input - the input NDList
        Returns:
        the output of the predict function
      • step

        public void step()
        Updates all of the parameters of the model once.
      • getMetrics

        public Metrics getMetrics()
        Returns the Metrics object used for benchmarking.
        Returns:
        the Metrics object used for benchmarking
      • setMetrics

        public void setMetrics(Metrics metrics)
        Attaches a Metrics object to use for benchmarking.
        Parameters:
        metrics - the Metrics instance to attach
      • getDevices

        public Device[] getDevices()
        Returns the devices used for training.
        Returns:
        the devices used for training
      • getLoss

        public Loss getLoss()
        Gets the training Loss function of the trainer.
        Returns:
        the Loss function
      • getModel

        public Model getModel()
        Returns the model used to create this trainer.
        Returns:
        the model associated with this trainer
      • getExecutorService

        public java.util.Optional<java.util.concurrent.ExecutorService> getExecutorService()
        Returns the ExecutorService, if one is available.
        Returns:
        an Optional containing the ExecutorService, or an empty Optional if there is none
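
Because getExecutorService is declared to return java.util.Optional, callers have to handle the case where no executor is present. The following is a minimal sketch of that pattern using only the JDK. The class OptionalExecutorDemo and its runTask helper are hypothetical names introduced for illustration and are not part of Trainer; the fallback of running the task on the calling thread is an assumption, not documented behavior.

```java
import java.util.Optional;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.TimeUnit;

// Hypothetical helper class; not part of the Trainer API.
final class OptionalExecutorDemo {

    // Hands the task to the executor when one is present,
    // otherwise runs it directly on the caller's thread.
    static String runTask(Optional<ExecutorService> executor, Runnable task) {
        if (executor.isPresent()) {
            executor.get().execute(task);
            return "submitted";
        }
        task.run();
        return "ran inline";
    }

    public static void main(String[] args) throws InterruptedException {
        ExecutorService pool = Executors.newSingleThreadExecutor();
        try {
            // With an executor: the task is handed off to the pool.
            System.out.println(runTask(Optional.of(pool), () -> { }));
            // Without an executor: the task runs on the current thread.
            System.out.println(runTask(Optional.empty(), () -> { }));
        } finally {
            pool.shutdown();
            pool.awaitTermination(1, TimeUnit.SECONDS);
        }
    }
}
```

With a real trainer, the first argument would presumably be the value returned by getExecutorService().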
      • getEvaluators

        public java.util.List<Evaluator> getEvaluators()
        Gets all Evaluators.
        Returns:
        the evaluators used during training
      • notifyListeners

        public void notifyListeners(java.util.function.Consumer<TrainingListener> listenerConsumer)
        Executes a method on each of the TrainingListeners.
        Parameters:
        listenerConsumer - a consumer that executes the method
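
The consumer-per-listener pattern described for notifyListeners can be sketched with plain JDK types. ToyListener, ToyNotifier, and the onEpoch method are hypothetical stand-ins for TrainingListener and Trainer; they are assumptions made so the example compiles on its own, and they do not reflect the real listener interface.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.function.Consumer;

// Hypothetical stand-in for TrainingListener; the method name is an assumption.
interface ToyListener {
    void onEpoch(int epoch);
}

// Hypothetical stand-in for Trainer that models only listener notification.
final class ToyNotifier {
    private final List<ToyListener> listeners = new ArrayList<>();

    void addListener(ToyListener listener) {
        listeners.add(listener);
    }

    // Same shape as the documented method: apply the consumer to each listener in turn.
    void notifyListeners(Consumer<ToyListener> listenerConsumer) {
        listeners.forEach(listenerConsumer);
    }
}

final class NotifyListenersDemo {

    // Registers two listeners, fires one event, and returns what they recorded.
    static List<String> run() {
        List<String> log = new ArrayList<>();
        ToyNotifier notifier = new ToyNotifier();
        notifier.addListener(epoch -> log.add("first saw epoch " + epoch));
        notifier.addListener(epoch -> log.add("second saw epoch " + epoch));
        // The caller chooses which listener method to invoke.
        notifier.notifyListeners(listener -> listener.onEpoch(1));
        return log;
    }

    public static void main(String[] args) {
        run().forEach(System.out::println);
    }
}
```

Passing a Consumer lets the caller decide which listener callback to trigger without the notifying class needing one method per event type.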
      • finalize

        protected void finalize()
                         throws java.lang.Throwable
        Overrides:
        finalize in class java.lang.Object
        Throws:
        java.lang.Throwable
      • close

        public void close()
        Specified by:
        close in interface java.lang.AutoCloseable
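
Since close is specified by java.lang.AutoCloseable, an instance can be managed with try-with-resources so that close runs even when the body throws. The sketch below shows the mechanism with a hypothetical ToyResource class standing in for Trainer; what the real close releases is not described here, so the stand-in only records the order of calls.

```java
import java.util.ArrayList;
import java.util.List;

// Hypothetical stand-in for Trainer; it only records lifecycle events.
final class ToyResource implements AutoCloseable {
    private final List<String> events;

    ToyResource(List<String> events) {
        this.events = events;
        events.add("open");
    }

    void step() {
        events.add("step");
    }

    @Override
    public void close() {
        events.add("close");
    }
}

final class CloseDemo {

    // Uses the resource inside try-with-resources and returns the recorded events.
    static List<String> run() {
        List<String> events = new ArrayList<>();
        try (ToyResource trainer = new ToyResource(events)) {
            trainer.step();
        } // close() is invoked automatically here
        return events;
    }

    public static void main(String[] args) {
        System.out.println(run());
    }
}
```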
      • addMetric

        public void addMetric(java.lang.String metricName,
                              long begin)
        Helper method that records an elapsed time as a metric.
        Parameters:
        metricName - the metric name
        begin - the start of the timed interval (the time at which this method is called marks the end)
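
The calling convention for addMetric (capture a start time, do the work, then pass the start time in) can be illustrated as follows. This is a sketch under stated assumptions: ToyMetrics is a hypothetical stand-in, the injected clock and the use of System.nanoTime are choices made for this example, and the time unit used by the real method is not stated in this documentation.

```java
import java.util.LinkedHashMap;
import java.util.Map;
import java.util.function.LongSupplier;

// Hypothetical stand-in for the metric-recording side of Trainer.
final class ToyMetrics {
    private final Map<String, Long> values = new LinkedHashMap<>();
    private final LongSupplier clock;

    // The clock is injected so the behavior can be checked deterministically.
    ToyMetrics(LongSupplier clock) {
        this.clock = clock;
    }

    long now() {
        return clock.getAsLong();
    }

    // Records (current time - begin) under the given name.
    void addMetric(String metricName, long begin) {
        values.put(metricName, clock.getAsLong() - begin);
    }

    Long get(String metricName) {
        return values.get(metricName);
    }
}

final class AddMetricDemo {
    public static void main(String[] args) {
        // Assumption: nanosecond clock; the real unit may differ.
        ToyMetrics metrics = new ToyMetrics(System::nanoTime);
        long begin = metrics.now();
        // ... the work being timed would go here ...
        metrics.addMetric("forward", begin);
        System.out.println("forward took " + metrics.get("forward") + " ns");
    }
}
```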