public interface Trainer
extends java.lang.AutoCloseable

The Trainer interface provides a session for model training.

Trainer provides an easy and manageable interface for training. Trainer is not thread-safe.
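Because Trainer extends java.lang.AutoCloseable, a training session fits naturally into try-with-resources, which releases the session's resources even when training throws. The minimal sketch below illustrates that guarantee with hypothetical stand-in classes (SessionTrainer, SessionDemo). They are not DJL's Trainer implementation, and the integer batch index is an assumption made only for the example.

```java
import java.util.ArrayList;
import java.util.List;

// Hypothetical stand-in for a Trainer session; not DJL's implementation.
class SessionTrainer implements AutoCloseable {
    // Shared log so the demo can observe when sessions open and close.
    static final List<String> EVENTS = new ArrayList<>();

    private final String name;
    private boolean closed;

    SessionTrainer(String name) {
        this.name = name;
        EVENTS.add("open " + name);
    }

    void trainBatch(int batchIndex) {
        if (closed) {
            throw new IllegalStateException("trainer " + name + " is closed");
        }
        if (batchIndex < 0) {
            throw new IllegalArgumentException("bad batch " + batchIndex);
        }
        EVENTS.add(name + " trained batch " + batchIndex);
    }

    boolean isClosed() {
        return closed;
    }

    @Override
    public void close() {
        // A real trainer would release native memory and devices here.
        closed = true;
        EVENTS.add("close " + name);
    }
}

class SessionDemo {
    public static void main(String[] args) {
        // Normal path: the session is closed when the block exits.
        try (SessionTrainer trainer = new SessionTrainer("a")) {
            trainer.trainBatch(0);
            trainer.trainBatch(1);
        }
        // Failure path: close() runs before the catch clause sees the exception.
        try (SessionTrainer trainer = new SessionTrainer("b")) {
            trainer.trainBatch(-1);
        } catch (IllegalArgumentException e) {
            SessionTrainer.EVENTS.add("caught " + e.getMessage());
        }
        System.out.println(SessionTrainer.EVENTS);
    }
}
```

Java closes try-with-resources resources before any catch or finally clause runs, so "close b" is logged before "caught bad batch -1".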
See the tutorials on:
Modifier and Type | Method and Description |
---|---|
void | close() |
void | endEpoch(): Runs the end-of-epoch actions. |
NDList | forward(NDList input): Applies the forward function of the model once on the given input NDList. |
java.util.List<Device> | getDevices(): Returns the devices used for training. |
<T extends Evaluator> T | getEvaluator(java.lang.Class<T> clazz): Gets the Evaluator that is an instance of the given Class. |
java.util.List<Evaluator> | getEvaluators(): Gets all Evaluators. |
Loss | getLoss(): Gets the training Loss function of the trainer. |
NDManager | getManager(): Gets the NDManager from the model. |
Metrics | getMetrics(): Returns the Metrics object used for benchmarking. |
Model | getModel(): Returns the model used to create this trainer. |
void | initialize(Shape... shapes): Initializes the Model that the Trainer is going to train. |
default java.lang.Iterable<Batch> | iterateDataset(Dataset dataset): Fetches an iterator that can iterate through the given Dataset. |
GradientCollector | newGradientCollector(): Returns a new instance of GradientCollector. |
void | setMetrics(Metrics metrics): Attaches a Metrics object to use for benchmarking. |
void | step(): Updates all of the parameters of the model once. |
void | trainBatch(Batch batch): Trains the model with one iteration of the given Batch of data. |
void | validateBatch(Batch batch): Validates the given batch of data. |
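Read together, the summary entries suggest an order of calls for a training loop: initialize the model once, then for each epoch iterate over the dataset, call trainBatch and step() for every batch, and finish the epoch with endEpoch(). The sketch below only records that order. RecordingTrainer, TrainingLoop.fit, the int... shape, and the List<Integer> batches are hypothetical stand-ins, not DJL types.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

// Hypothetical stand-in that only logs which Trainer-style methods ran, in order.
class RecordingTrainer implements AutoCloseable {
    final List<String> calls = new ArrayList<>();

    // A single int[] stands in for Shape; the real method takes one Shape per input.
    void initialize(int... inputShape) {
        calls.add("initialize" + Arrays.toString(inputShape));
    }

    Iterable<List<Integer>> iterateDataset(List<List<Integer>> dataset) {
        calls.add("iterateDataset");
        return dataset;
    }

    void trainBatch(List<Integer> batch) {
        calls.add("trainBatch" + batch);
    }

    void step() {
        calls.add("step");
    }

    void endEpoch() {
        calls.add("endEpoch");
    }

    @Override
    public void close() {
        calls.add("close");
    }
}

class TrainingLoop {
    // Runs the loop order described above and returns the recorded call log.
    static List<String> fit(List<List<Integer>> dataset, int epochs) {
        RecordingTrainer trainer = new RecordingTrainer();
        try (trainer) {  // Java 9+: effectively final resource variable
            trainer.initialize(2, 4);
            for (int epoch = 0; epoch < epochs; epoch++) {
                for (List<Integer> batch : trainer.iterateDataset(dataset)) {
                    trainer.trainBatch(batch);  // one training iteration
                    trainer.step();             // one parameter update
                }
                trainer.endEpoch();
            }
        }
        return trainer.calls;
    }

    public static void main(String[] args) {
        List<List<Integer>> data = List.of(List.of(1, 2), List.of(3, 4));
        System.out.println(fit(data, 1));
    }
}
```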

void initialize(Shape... shapes)
Initializes the Model that the Trainer is going to train.
Parameters:
shapes - an array of Shape objects describing the inputs

default java.lang.Iterable<Batch> iterateDataset(Dataset dataset)
Fetches an iterator that can iterate through the given Dataset.
Parameters:
dataset - the dataset to iterate through
Returns:
an Iterable of Batch that contains batches of data from the dataset

GradientCollector newGradientCollector()
Returns a new instance of GradientCollector.
Returns:
a new instance of GradientCollector
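iterateDataset is a default method, so implementations inherit it unless they override it. As a rough illustration of that idea, and not DJL's actual implementation, the sketch below defines a hypothetical BatchingTrainer interface whose default iterateDataset lazily splits a list of samples into fixed-size batches. The batchSize() method and the use of List<Integer> in place of Dataset and Batch are assumptions made for the example.

```java
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import java.util.NoSuchElementException;

// Hypothetical interface; only the idea of a default iterateDataset mirrors Trainer.
interface BatchingTrainer {
    int batchSize();

    // Default method: lazily yields consecutive batches; the last one may be smaller.
    default Iterable<List<Integer>> iterateDataset(List<Integer> dataset) {
        int size = batchSize();
        if (size <= 0) {
            throw new IllegalArgumentException("batch size must be positive");
        }
        return () -> new Iterator<List<Integer>>() {
            private int position = 0;

            @Override
            public boolean hasNext() {
                return position < dataset.size();
            }

            @Override
            public List<Integer> next() {
                if (!hasNext()) {
                    throw new NoSuchElementException();
                }
                int end = Math.min(position + size, dataset.size());
                List<Integer> batch = new ArrayList<>(dataset.subList(position, end));
                position = end;
                return batch;
            }
        };
    }
}

class BatchingDemo {
    public static void main(String[] args) {
        BatchingTrainer trainer = () -> 2;  // a trainer whose batch size is 2
        for (List<Integer> batch : trainer.iterateDataset(List.of(1, 2, 3, 4, 5))) {
            System.out.println(batch);
        }
    }
}
```

Returning an Iterable rather than a materialized list lets each for-each loop start a fresh, lazy pass over the data, which suits one pass per epoch.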
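
void trainBatch(Batch batch)
Trains the model with one iteration of the given Batch of data.
Parameters:
batch - a Batch that contains data and its respective labels
Throws:
java.lang.IllegalArgumentException - if the batch engine does not match the trainer engine

NDList forward(NDList input)
Applies the forward function of the model once on the given input NDList.
Parameters:
input - the input NDList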
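Because step() is listed separately as the method that updates the parameters, one plausible reading is that trainBatch runs forward, computes the loss, and records gradients, while step() applies the update. The sketch below shows that split on a hypothetical scalar model y = w * x with mean squared error and plain SGD (learning rate 0.1). ScalarTrainer, its gradient field (standing in for what a GradientCollector would record), and the numbers are all assumptions for illustration, not DJL code.

```java
// Hypothetical scalar model y = w * x trained with mean squared error; not DJL code.
class ScalarTrainer {
    private double weight;
    private final double learningRate;
    // Accumulated dL/dw; stands in for what a GradientCollector would record.
    private double gradient;

    ScalarTrainer(double initialWeight, double learningRate) {
        this.weight = initialWeight;
        this.learningRate = learningRate;
    }

    double getWeight() {
        return weight;
    }

    // Applies the model once to every input, analogous to forward(NDList input).
    double[] forward(double[] inputs) {
        double[] outputs = new double[inputs.length];
        for (int i = 0; i < inputs.length; i++) {
            outputs[i] = weight * inputs[i];
        }
        return outputs;
    }

    // One iteration: forward pass, loss, and gradient. The weight is NOT updated here.
    double trainBatch(double[] inputs, double[] labels) {
        if (inputs.length != labels.length || inputs.length == 0) {
            throw new IllegalArgumentException("inputs and labels must be non-empty and equal in length");
        }
        double[] predictions = forward(inputs);
        int n = inputs.length;
        double loss = 0;
        for (int i = 0; i < n; i++) {
            double error = predictions[i] - labels[i];
            loss += error * error / n;
            // d/dw of (w * x - y)^2 / n is 2 * (w * x - y) * x / n
            gradient += 2 * error * inputs[i] / n;
        }
        return loss;
    }

    // Updates the parameter once with plain SGD, then clears the gradient.
    void step() {
        weight -= learningRate * gradient;
        gradient = 0;
    }
}

class ScalarDemo {
    public static void main(String[] args) {
        ScalarTrainer trainer = new ScalarTrainer(0.0, 0.1);
        double[] x = {1, 2};
        double[] y = {2, 4};  // generated by the target weight w = 2
        for (int i = 0; i < 5; i++) {
            double loss = trainer.trainBatch(x, y);
            trainer.step();
            System.out.printf("iteration %d: loss=%.4f weight=%.4f%n", i, loss, trainer.getWeight());
        }
    }
}
```

With these inputs the first batch has loss 10 and gradient -10, so one step moves the weight from 0 to 1; each later step halves the remaining distance to 2.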

void validateBatch(Batch batch)
Validates the given batch of data.
During validation, the evaluators and losses are computed, but gradients are not computed and parameters are not updated.
Parameters:
batch - a Batch of data
Throws:
java.lang.IllegalArgumentException - if the batch engine does not match the trainer engine

void step()
Updates all of the parameters of the model once.
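The validateBatch contract above (losses computed, no gradients, no parameter updates) can be contrasted with training in a few lines. The sketch assumes a hypothetical scalar model y = w * x. ValidatingTrainer, its running-loss fields, and meanValidationLoss() are illustrative names, not DJL API, and the length check is a plain input check, unrelated to the engine check in the real method.

```java
// Hypothetical scalar model y = w * x contrasting training and validation; not DJL code.
class ValidatingTrainer {
    private final double weight;      // would be changed only by a step() method, omitted here
    private double gradient;          // written only by trainBatch
    private double validationLossSum; // evaluator-style running total
    private int validationCount;

    ValidatingTrainer(double weight) {
        this.weight = weight;
    }

    double getGradient() {
        return gradient;
    }

    // Training: computes the gradient of the mean squared error for later use by step().
    void trainBatch(double[] x, double[] y) {
        checkBatch(x, y);
        for (int i = 0; i < x.length; i++) {
            double error = weight * x[i] - y[i];
            gradient += 2 * error * x[i] / x.length;
        }
    }

    // Validation: records the loss for reporting, computes no gradient, updates nothing.
    void validateBatch(double[] x, double[] y) {
        checkBatch(x, y);
        for (int i = 0; i < x.length; i++) {
            double error = weight * x[i] - y[i];
            validationLossSum += error * error;
            validationCount++;
        }
    }

    double meanValidationLoss() {
        return validationCount == 0 ? 0.0 : validationLossSum / validationCount;
    }

    private static void checkBatch(double[] x, double[] y) {
        if (x.length != y.length || x.length == 0) {
            throw new IllegalArgumentException("x and y must be non-empty and equal in length");
        }
    }
}

class ValidationDemo {
    public static void main(String[] args) {
        ValidatingTrainer trainer = new ValidatingTrainer(1.5);
        trainer.validateBatch(new double[] {1, 2}, new double[] {2, 4});
        System.out.println("validation loss " + trainer.meanValidationLoss()
                + ", gradient " + trainer.getGradient());
    }
}
```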
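
Metrics getMetrics()
Returns the Metrics object used for benchmarking.

void setMetrics(Metrics metrics)
Attaches a Metrics object to use for benchmarking.
Parameters:
metrics - the Metrics object to attach

java.util.List<Device> getDevices()
Returns the devices used for training.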
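getMetrics and setMetrics let a caller attach an object that collects benchmarking measurements. The hedged sketch below shows one way such an optional hook can work: the trainer times each batch only when a metrics object is attached. SimpleMetrics, BenchmarkedTrainer, and the "train_batch_nanos" key are hypothetical; DJL has its own Metrics class, whose API is not reproduced here.

```java
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

// Hypothetical metrics container; not DJL's Metrics API.
class SimpleMetrics {
    private final Map<String, List<Long>> values = new HashMap<>();

    void add(String name, long value) {
        values.computeIfAbsent(name, k -> new ArrayList<>()).add(value);
    }

    List<Long> get(String name) {
        return values.getOrDefault(name, List.of());
    }
}

// Hypothetical trainer that benchmarks batches only when metrics are attached.
class BenchmarkedTrainer {
    private SimpleMetrics metrics;  // null means benchmarking is off

    void setMetrics(SimpleMetrics metrics) {
        this.metrics = metrics;
    }

    SimpleMetrics getMetrics() {
        return metrics;
    }

    void trainBatch(Runnable work) {
        long start = System.nanoTime();
        work.run();  // placeholder for the real forward/backward work
        if (metrics != null) {
            metrics.add("train_batch_nanos", System.nanoTime() - start);
        }
    }
}

class BenchmarkDemo {
    public static void main(String[] args) {
        BenchmarkedTrainer trainer = new BenchmarkedTrainer();
        trainer.trainBatch(() -> { });  // no metrics attached: nothing is recorded
        SimpleMetrics metrics = new SimpleMetrics();
        trainer.setMetrics(metrics);
        for (int i = 0; i < 3; i++) {
            trainer.trainBatch(() -> { });
        }
        System.out.println(metrics.get("train_batch_nanos").size() + " batches timed");
    }
}
```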

void endEpoch()
Runs the end-of-epoch actions.
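The documentation does not list which end-of-epoch actions run. A common way to structure such a hook is a list of registered callbacks that fire once per epoch and then reset per-epoch state. The sketch below is a hypothetical illustration of that pattern. EpochAction, EpochTrainer, and the loss accumulator are assumed names, and DJL's own end-of-epoch behavior may differ.

```java
import java.util.ArrayList;
import java.util.List;

// Hypothetical callback type; not a DJL interface.
interface EpochAction {
    void onEpochEnd(int epoch, double epochLoss);
}

// Hypothetical trainer whose endEpoch() runs every registered end-of-epoch action.
class EpochTrainer {
    private final List<EpochAction> actions = new ArrayList<>();
    private int epoch;         // index of the epoch in progress (0-based)
    private double epochLoss;  // loss accumulated during the current epoch

    void addEpochAction(EpochAction action) {
        actions.add(action);
    }

    void recordLoss(double batchLoss) {
        epochLoss += batchLoss;
    }

    double getEpochLoss() {
        return epochLoss;
    }

    void endEpoch() {
        for (EpochAction action : actions) {
            action.onEpochEnd(epoch, epochLoss);  // e.g. log, checkpoint, adjust learning rate
        }
        epochLoss = 0;  // reset per-epoch state for the next epoch
        epoch++;
    }
}

class EpochDemo {
    public static void main(String[] args) {
        EpochTrainer trainer = new EpochTrainer();
        trainer.addEpochAction((finished, loss) ->
                System.out.printf("epoch %d finished, total loss %.3f%n", finished, loss));
        for (int i = 0; i < 2; i++) {
            trainer.recordLoss(1.0 / (i + 1));
            trainer.endEpoch();
        }
    }
}
```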

Model getModel()
Returns the model used to create this trainer.

java.util.List<Evaluator> getEvaluators()
Gets all Evaluators.

<T extends Evaluator> T getEvaluator(java.lang.Class<T> clazz)
Gets the Evaluator that is an instance of the given Class.
Type Parameters:
T - the type of the training evaluator
Parameters:
clazz - the Class of the Evaluator sought

void close()
Specified by:
close in interface java.lang.AutoCloseable
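The getEvaluator(java.lang.Class<T> clazz) method above is a typed lookup: the Class token lets it return a T without an unchecked cast. The minimal sketch below shows how such a lookup is commonly written in Java with Class.isInstance and Class.cast. Evaluator, Accuracy, TopKAccuracy, LossEvaluator, and EvaluatorLookup are hypothetical stand-ins, not DJL's classes, and returning null when nothing matches is an assumption of the sketch.

```java
import java.util.List;

// Hypothetical stand-ins for evaluator types; not DJL's classes.
interface Evaluator {
    String name();
}

class Accuracy implements Evaluator {
    public String name() {
        return "Accuracy";
    }
}

class TopKAccuracy extends Accuracy {
    @Override
    public String name() {
        return "TopKAccuracy";
    }
}

class LossEvaluator implements Evaluator {
    public String name() {
        return "LossEvaluator";
    }
}

class EvaluatorLookup {
    // Returns the first evaluator that is an instance of clazz, or null if none matches.
    static <T extends Evaluator> T getEvaluator(List<Evaluator> evaluators, Class<T> clazz) {
        for (Evaluator evaluator : evaluators) {
            if (clazz.isInstance(evaluator)) {
                return clazz.cast(evaluator);  // checked cast, no unchecked warning
            }
        }
        return null;
    }
}

class EvaluatorDemo {
    public static void main(String[] args) {
        List<Evaluator> evaluators = List.of(new LossEvaluator(), new TopKAccuracy(), new Accuracy());
        Accuracy accuracy = EvaluatorLookup.getEvaluator(evaluators, Accuracy.class);
        System.out.println(accuracy.name());  // TopKAccuracy: isInstance also matches subclasses
    }
}
```

Because "instance of" also matches subclasses, the order of the evaluator list decides which one is returned when several qualify.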