public class Trainer
extends java.lang.Object
implements java.lang.AutoCloseable
The Trainer class provides a session for model training. Trainer provides an easy and manageable interface for training. Trainer is not thread-safe.
See the tutorials on:
| Constructor and Description |
|---|
| Trainer(Model model, TrainingConfig trainingConfig) |
| Modifier and Type | Method | Description |
|---|---|---|
| void | addMetric(java.lang.String metricName, long begin) | Helper to add a metric for a time difference. |
| void | close() | |
| NDList | evaluate(NDList input) | Evaluates the function of the model once on the given input NDList. |
| protected void | finalize() | |
| NDList | forward(NDList input) | Applies the forward function of the model once on the given input NDList. |
| NDList | forward(NDList data, NDList labels) | Applies the forward function of the model once with both data and labels. |
| DataManager | getDataManager() | Returns the DataManager. |
| Device[] | getDevices() | Returns the devices used for training. |
| java.util.List<Evaluator> | getEvaluators() | Gets all Evaluators. |
| Loss | getLoss() | Gets the training Loss function of the trainer. |
| NDManager | getManager() | Gets the NDManager from the model. |
| Metrics | getMetrics() | Returns the Metrics object used for benchmarking. |
| Model | getModel() | Returns the model used to create this trainer. |
| TrainingResult | getTrainingResult() | Returns the TrainingResult. |
| void | initialize(Shape... shapes) | Initializes the Model that the Trainer is going to train. |
| java.lang.Iterable<Batch> | iterateDataset(Dataset dataset) | Fetches an iterator that can iterate through the given Dataset. |
| GradientCollector | newGradientCollector() | Returns a new instance of GradientCollector. |
| void | notifyListeners(java.util.function.Consumer<TrainingListener> listenerConsumer) | Executes a method on each of the TrainingListeners. |
| void | setMetrics(Metrics metrics) | Attaches a Metrics object to use for benchmarking. |
| void | step() | Updates all of the parameters of the model once. |

public Trainer(Model model, TrainingConfig trainingConfig)
Parameters:
model - the model the trainer will train on
trainingConfig - the configuration used by the trainer

public void initialize(Shape... shapes)
Initializes the Model that the Trainer is going to train.
Parameters:
shapes - the Shapes of the inputs

public java.lang.Iterable<Batch> iterateDataset(Dataset dataset) throws java.io.IOException, TranslateException
Fetches an iterator that can iterate through the given Dataset.
Parameters:
dataset - the dataset to iterate through
Returns:
an Iterable of Batch that contains batches of data from the dataset
Throws:
java.io.IOException - for various exceptions depending on the dataset
TranslateException - if there is an error while processing input

public GradientCollector newGradientCollector()
Returns a new instance of GradientCollector.
Returns:
a new instance of GradientCollector
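
The idea behind iterateDataset, handing back an Iterable that yields the data one batch at a time, can be illustrated with a minimal plain-Java sketch. Everything in it is hypothetical: BatchIterableSketch and batches are invented names, a batch is just a List, and the fixed batchSize argument is an assumption for illustration, since this page does not say how the batch size is chosen. It does not use DJL's Dataset or Batch types.

```java
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import java.util.NoSuchElementException;

/** Hypothetical illustration of batch iteration; not DJL's Dataset or Batch API. */
class BatchIterableSketch {

    /** Returns an Iterable whose iterators yield consecutive batches of at most batchSize items. */
    static <T> Iterable<List<T>> batches(List<T> data, int batchSize) {
        if (batchSize <= 0) {
            throw new IllegalArgumentException("batchSize must be positive");
        }
        // Each call to iterator() starts a fresh pass (one epoch) over the data.
        return () -> new Iterator<List<T>>() {
            private int start = 0; // index of the first item in the next batch

            @Override
            public boolean hasNext() {
                return start < data.size();
            }

            @Override
            public List<T> next() {
                if (!hasNext()) {
                    throw new NoSuchElementException();
                }
                int end = Math.min(start + batchSize, data.size());
                // Copy the slice so each batch is independent of the source list.
                List<T> batch = new ArrayList<>(data.subList(start, end));
                start = end;
                return batch;
            }
        };
    }

    public static void main(String[] args) {
        List<Integer> samples = List.of(1, 2, 3, 4, 5);
        for (List<Integer> batch : batches(samples, 2)) {
            System.out.println(batch); // [1, 2], then [3, 4], then [5]; the last batch is smaller
        }
    }
}
```

Returning an Iterable rather than a single Iterator means a caller can run a plain for-each loop once per epoch over the same object.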

public NDList forward(NDList input)
Applies the forward function of the model once on the given input NDList.
Parameters:
input - the input NDList

public NDList forward(NDList data, NDList labels)
Applies the forward function of the model once with both data and labels.

public NDList evaluate(NDList input)
Evaluates the function of the model once on the given input NDList.
Parameters:
input - the input NDList

public void step()
Updates all of the parameters of the model once.
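
To show how a forward pass, gradient computation and step fit together in one training iteration, here is a self-contained toy in plain Java. It is a sketch under stated assumptions, not DJL's implementation: ToyLinearTrainer is a hypothetical class with a single parameter w for the model y = w * x, the loss is assumed to be mean squared error, the gradient is derived by hand in lossAndGradient (standing in for what automatic differentiation would compute inside a GradientCollector scope), and step applies plain stochastic gradient descent, w <- w - lr * grad.

```java
/**
 * Hypothetical one-parameter "trainer" (not DJL's API) that mirrors the
 * forward -> gradient -> step cycle for the model y = w * x with mean squared error.
 */
class ToyLinearTrainer {
    private double w;                  // the single model parameter
    private final double learningRate;
    private double grad;               // gradient from the most recent lossAndGradient call

    ToyLinearTrainer(double initialWeight, double learningRate) {
        this.w = initialWeight;
        this.learningRate = learningRate;
    }

    /** Forward pass: predictions for a batch of inputs. */
    double[] forward(double[] x) {
        double[] y = new double[x.length];
        for (int i = 0; i < x.length; i++) {
            y[i] = w * x[i];
        }
        return y;
    }

    /** Returns the mean squared error and stores dLoss/dw, computed analytically. */
    double lossAndGradient(double[] x, double[] labels) {
        double[] pred = forward(x);
        double loss = 0.0;
        double g = 0.0;
        for (int i = 0; i < x.length; i++) {
            double err = pred[i] - labels[i];
            loss += err * err;
            g += 2.0 * err * x[i]; // d(err^2)/dw = 2 * err * x
        }
        grad = g / x.length;
        return loss / x.length;
    }

    /** Updates the parameter once with plain SGD: w <- w - lr * grad. */
    void step() {
        w -= learningRate * grad;
    }

    double weight() {
        return w;
    }

    public static void main(String[] args) {
        double[] x = {1.0, 2.0, 3.0};
        double[] labels = {2.0, 4.0, 6.0}; // generated by the true weight w = 2
        ToyLinearTrainer trainer = new ToyLinearTrainer(0.0, 0.05);
        for (int iteration = 0; iteration < 100; iteration++) {
            trainer.lossAndGradient(x, labels);
            trainer.step();
        }
        System.out.printf("learned w = %.4f%n", trainer.weight()); // learned w = 2.0000
    }
}
```

With these numbers the gradient is 2 * (w - 2) * mean(x^2) = (28/3) * (w - 2), so each step multiplies the error (w - 2) by 1 - 0.05 * 28/3, about 0.53, and the weight converges to 2.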

public Metrics getMetrics()
Returns the Metrics object used for benchmarking.
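
public void setMetrics(Metrics metrics)
Attaches a Metrics object to use for benchmarking.
Parameters:
metrics - the Metrics instance

public Device[] getDevices()
Returns the devices used for training.

public Loss getLoss()
Gets the training Loss function of the trainer.
Returns:
the Loss function

public Model getModel()
Returns the model used to create this trainer.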

public DataManager getDataManager()
Returns the DataManager.
Returns:
the DataManager

public java.util.List<Evaluator> getEvaluators()
Gets all Evaluators.

public void notifyListeners(java.util.function.Consumer<TrainingListener> listenerConsumer)
Executes a method on each of the TrainingListeners.
Parameters:
listenerConsumer - a consumer that executes the method

public TrainingResult getTrainingResult()
Returns the TrainingResult.
Returns:
the TrainingResult
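
notifyListeners takes a Consumer<TrainingListener> and applies it to every registered listener, so one method can trigger any callback. A minimal sketch of that dispatch pattern follows; it assumes a hypothetical single-method ToyListener interface in place of DJL's TrainingListener, and ListenerDispatchSketch and addListener are invented names, not DJL's implementation.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.function.Consumer;

/** Hypothetical sketch of Consumer-based listener dispatch; not DJL's implementation. */
class ListenerDispatchSketch {

    /** Hypothetical single-callback listener standing in for DJL's TrainingListener. */
    interface ToyListener {
        void onEpoch(int epoch);
    }

    private final List<ToyListener> listeners = new ArrayList<>();

    void addListener(ToyListener listener) {
        listeners.add(listener);
    }

    /** Applies the given action to each registered listener, in registration order. */
    void notifyListeners(Consumer<ToyListener> listenerConsumer) {
        listeners.forEach(listenerConsumer);
    }

    public static void main(String[] args) {
        ListenerDispatchSketch trainer = new ListenerDispatchSketch();
        trainer.addListener(epoch -> System.out.println("logger saw epoch " + epoch));
        trainer.addListener(epoch -> System.out.println("checkpointer saw epoch " + epoch));
        // The caller decides which callback to invoke; the dispatcher only applies it to every listener.
        trainer.notifyListeners(listener -> listener.onEpoch(3));
    }
}
```

Passing a Consumer keeps the dispatcher independent of how many callback methods the listener type has: adding a new callback needs no new notify method.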

protected void finalize() throws java.lang.Throwable
Overrides:
finalize in class java.lang.Object
Throws:
java.lang.Throwable

public void close()
Specified by:
close in interface java.lang.AutoCloseable
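
Because Trainer implements java.lang.AutoCloseable, it can be managed with try-with-resources so that close() runs even if training throws. The sketch below shows that pattern on a hypothetical SessionSketch class, not DJL code; making close() idempotent and rejecting work after close are design assumptions of the sketch, not documented Trainer behavior. Relying on an explicit close() rather than on finalize() is also the safer habit, since Object.finalize has been deprecated since Java 9.

```java
/** Hypothetical resource-owning session (not DJL's Trainer) showing the AutoCloseable pattern. */
class SessionSketch implements AutoCloseable {
    private boolean closed = false;
    private int releaseCount = 0; // how many times resources were actually released

    void trainOneStep() {
        if (closed) {
            throw new IllegalStateException("session already closed");
        }
        // Training work would happen here.
    }

    /** Releases resources once; later calls are no-ops. Declares no checked exception. */
    @Override
    public void close() {
        if (!closed) {
            closed = true;
            releaseCount++;
        }
    }

    boolean isClosed() {
        return closed;
    }

    int releaseCount() {
        return releaseCount;
    }

    public static void main(String[] args) {
        SessionSketch session = new SessionSketch();
        try (SessionSketch s = session) {
            s.trainOneStep();
        } // close() runs here automatically, even if trainOneStep() had thrown
        session.close(); // a second close() does nothing
        System.out.println("closed=" + session.isClosed() + ", releases=" + session.releaseCount());
        // prints "closed=true, releases=1"
    }
}
```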

public void addMetric(java.lang.String metricName, long begin)
Helper to add a metric for a time difference.
Parameters:
metricName - the metric name
begin - the time difference start (this method is called at the time difference end)
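
The begin parameter of addMetric marks the start of an interval, and the call itself marks the end. A sketch of that convention follows, assuming timestamps come from System.nanoTime() and using a hypothetical TimingMetricsSketch map in place of DJL's Metrics class; the unit and the storage layout are assumptions, not taken from this page.

```java
import java.util.ArrayList;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;

/** Hypothetical stand-in for a metrics store; not DJL's Metrics class. */
class TimingMetricsSketch {
    private final Map<String, List<Long>> metrics = new LinkedHashMap<>();

    /**
     * Records the time elapsed since begin under metricName. Call it at the END of the
     * measured interval, passing the System.nanoTime() value captured at the START.
     */
    void addMetric(String metricName, long begin) {
        long elapsedNanos = System.nanoTime() - begin;
        metrics.computeIfAbsent(metricName, k -> new ArrayList<>()).add(elapsedNanos);
    }

    /** Returns every recorded duration for metricName, or an empty list if there are none. */
    List<Long> get(String metricName) {
        return metrics.getOrDefault(metricName, List.of());
    }

    public static void main(String[] args) throws InterruptedException {
        TimingMetricsSketch metrics = new TimingMetricsSketch();
        long begin = System.nanoTime();   // start of the interval
        Thread.sleep(10);                 // stand-in for a forward pass or optimizer step
        metrics.addMetric("step", begin); // end of the interval
        System.out.println("step took ~" + metrics.get("step").get(0) / 1_000_000 + " ms");
    }
}
```

Keeping a list per metric name, rather than a single value, lets repeated calls (one per training step, say) accumulate samples that can later be averaged.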