Package ai.djl.training.hyperparameter
Class EasyHpo
java.lang.Object
ai.djl.training.hyperparameter.EasyHpo
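public abstract class EasyHpo
extends Object
Helper for easy training with hyperparameters. A subclass implements the abstract methods to supply the dataset, the hyperparameter search space, the model, and its training configuration, and then calls fit() to train with each hyperparameter set and obtain the best model and its training results.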
Constructor Summary
EasyHpo()
Method Summary
protected abstract Model buildModel(HpSet hpVals)
    Returns the model to train for the given hyperparameter values.
ai.djl.util.Pair<Model,TrainingResult> fit()
    Fits the model given the implemented abstract methods.
protected abstract RandomAccessDataset getDataset(Dataset.Usage usage)
    Returns the dataset to train with.
protected abstract Shape inputShape(HpSet hpVals)
    Returns the input shape for the model.
protected abstract int numEpochs(HpSet hpVals)
    Returns the number of epochs to train for the current hyperparameter set.
protected abstract int numHyperParameterTests()
    Returns the number of hyperparameter sets to train with.
protected void saveModel(Model model, TrainingResult result)
    Saves the model trained with the best hyperparameter set.
protected abstract HpSet setupHyperParams()
    Returns the initial hyperparameters.
protected abstract TrainingConfig setupTrainingConfig(HpSet hpVals)
    Returns the TrainingConfig to use to train each hyperparameter set.
Constructor Details
EasyHpo
public EasyHpo()

Method Details
fit
public ai.djl.util.Pair<Model,TrainingResult> fit() throws IOException, TranslateException
Fits the model given the implemented abstract methods.
Returns:
the best model and training results
Throws:
IOException - for various exceptions depending on the dataset
TranslateException - if there is an error while processing input
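To make the control flow concrete, here is a simplified, dependency-free sketch of the template-method pattern that fit() builds on: the base class drives the search and the subclass only fills in hooks. It is not DJL's implementation. MiniHpo, Trial, sampleHyperParams, trainAndValidate, and saveBest are hypothetical names, hyperparameter sets are plain maps instead of HpSet, and the search strategy (independent random sampling, keeping the set with the lowest validation loss) is an assumption.

```java
import java.util.Map;
import java.util.Random;

/**
 * Hypothetical, dependency-free stand-in for the EasyHpo template method.
 * Not DJL code: hyperparameter sets are plain maps and "training" is a hook
 * that returns a validation loss.
 */
abstract class MiniHpo {

    /** Outcome of training with one hyperparameter set. */
    static final class Trial {
        final Map<String, Double> hp;
        final double validationLoss;

        Trial(Map<String, Double> hp, double validationLoss) {
            this.hp = hp;
            this.validationLoss = validationLoss;
        }
    }

    // Hooks a subclass implements, loosely mirroring EasyHpo's abstract methods.
    protected abstract Map<String, Double> sampleHyperParams(Random rng);

    protected abstract int numHyperParameterTests();

    protected abstract double trainAndValidate(Map<String, Double> hp);

    /** Optional hook, like the non-abstract saveModel; this sketch's default does nothing. */
    protected void saveBest(Trial best) {}

    /**
     * Runs numHyperParameterTests() trials with sampled hyperparameters and
     * returns the one with the lowest validation loss (null if no trials ran).
     */
    public Trial fit(long seed) {
        Random rng = new Random(seed);
        int trials = numHyperParameterTests();
        Trial best = null;
        for (int i = 0; i < trials; i++) {
            Map<String, Double> hp = sampleHyperParams(rng);
            double loss = trainAndValidate(hp);
            if (best == null || loss < best.validationLoss) {
                best = new Trial(hp, loss);
            }
        }
        if (best != null) {
            saveBest(best);
        }
        return best;
    }
}
```

Because the search loop lives in the base class, a subclass cannot skip the comparison or the save step; it only decides what to sample and how to train.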
setupHyperParams
protected abstract HpSet setupHyperParams()
Returns the initial hyperparameters.
Returns:
the initial hyperparameters
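setupHyperParams() returns an HpSet, which describes the space the search draws from. The dependency-free sketch below models that idea with hypothetical Range and SearchSpace classes (not the DJL HpSet API): each named range yields one value per sample, and an optional log scale, a common choice for learning rates, spreads draws evenly across orders of magnitude.

```java
import java.util.LinkedHashMap;
import java.util.Map;
import java.util.Random;

/** Hypothetical numeric search range; not DJL's HpSet or Hyperparameter API. */
final class Range {
    final double lower;
    final double upper;
    final boolean logScale;

    Range(double lower, double upper, boolean logScale) {
        if (!(lower < upper)) throw new IllegalArgumentException("lower must be < upper");
        if (logScale && lower <= 0) throw new IllegalArgumentException("log scale needs lower > 0");
        this.lower = lower;
        this.upper = upper;
        this.logScale = logScale;
    }

    /** Draws one value from the range, uniformly or log-uniformly. */
    double sample(Random rng) {
        double u = rng.nextDouble(); // in [0, 1)
        if (logScale) {
            double lo = Math.log(lower);
            double hi = Math.log(upper);
            return Math.exp(lo + u * (hi - lo));
        }
        return lower + u * (upper - lower);
    }
}

/** Hypothetical search space: a named collection of ranges. */
final class SearchSpace {
    private final Map<String, Range> ranges = new LinkedHashMap<>();

    SearchSpace add(String name, Range range) {
        ranges.put(name, range);
        return this;
    }

    /** Samples one hyperparameter set: one value per named range. */
    Map<String, Double> sample(Random rng) {
        Map<String, Double> values = new LinkedHashMap<>();
        for (Map.Entry<String, Range> e : ranges.entrySet()) {
            values.put(e.getKey(), e.getValue().sample(rng));
        }
        return values;
    }
}
```

With a log scale over 1e-4 to 1e-1, roughly two thirds of the draws fall below 1e-2, whereas a linear scale would put about 90% of them above it.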
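getDataset
protected abstract RandomAccessDataset getDataset(Dataset.Usage usage) throws IOException
Returns the dataset to train with.
Parameters:
usage - the usage of the dataset (for example, Dataset.Usage.TRAIN or Dataset.Usage.TEST)
Returns:
the dataset to train with
Throws:
IOException - if the dataset could not be loaded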
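Since getDataset receives a Dataset.Usage, one implementation can serve both the training and the evaluation split. This sketch shows one way to do that over an in-memory list; the Usage enum is a hypothetical stand-in for Dataset.Usage, and the shuffle-once-with-a-fixed-seed design and the split fraction are assumptions.

```java
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Random;

/** Hypothetical stand-in for Dataset.Usage. */
enum Usage { TRAIN, TEST }

/** Serves disjoint, reproducible train/test splits of one in-memory list. */
final class SplitProvider<T> {
    private final List<T> train;
    private final List<T> test;

    /** Shuffles once with a fixed seed so every call sees the same split. */
    SplitProvider(List<T> all, double trainFraction, long seed) {
        if (trainFraction <= 0 || trainFraction >= 1) {
            throw new IllegalArgumentException("trainFraction must be in (0, 1)");
        }
        List<T> shuffled = new ArrayList<>(all);
        Collections.shuffle(shuffled, new Random(seed));
        int cut = (int) Math.round(shuffled.size() * trainFraction);
        this.train = Collections.unmodifiableList(new ArrayList<>(shuffled.subList(0, cut)));
        this.test = Collections.unmodifiableList(new ArrayList<>(shuffled.subList(cut, shuffled.size())));
    }

    List<T> get(Usage usage) {
        switch (usage) {
            case TRAIN:
                return train;
            case TEST:
                return test;
            default:
                throw new IllegalArgumentException("unsupported usage " + usage);
        }
    }
}
```

Fixing the seed keeps the split identical across calls, so every hyperparameter set is evaluated on the same held-out data and the losses stay comparable.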
setupTrainingConfig
protected abstract TrainingConfig setupTrainingConfig(HpSet hpVals)
Returns the TrainingConfig to use to train each hyperparameter set.
Parameters:
hpVals - the hyperparameters to train with
Returns:
the TrainingConfig to use to train each hyperparameter set
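setupTrainingConfig is where hyperparameter values become training settings. The sketch below uses a hypothetical TrainerSettings class in place of DJL's TrainingConfig, and the hyperparameter names learningRate and batchSize are illustrative. It builds a new settings object per hyperparameter set and falls back to a default batch size (32, an arbitrary choice here) when the search does not tune it.

```java
import java.util.Map;

/** Hypothetical, simplified stand-in for a TrainingConfig. */
final class TrainerSettings {
    final double learningRate;
    final int batchSize;

    TrainerSettings(double learningRate, int batchSize) {
        if (learningRate <= 0) throw new IllegalArgumentException("learningRate must be > 0");
        if (batchSize <= 0) throw new IllegalArgumentException("batchSize must be > 0");
        this.learningRate = learningRate;
        this.batchSize = batchSize;
    }

    /**
     * Builds fresh settings for one hyperparameter set. learningRate is required;
     * batchSize defaults to 32 when the set does not contain it.
     */
    static TrainerSettings fromHyperParams(Map<String, Double> hp) {
        Double lr = hp.get("learningRate");
        if (lr == null) throw new IllegalArgumentException("missing learningRate");
        int batch = (int) Math.round(hp.getOrDefault("batchSize", 32.0));
        return new TrainerSettings(lr, batch);
    }
}
```

Creating a new object for every set, rather than mutating a shared one, keeps one trial's settings from leaking into the next.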
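buildModel
protected abstract Model buildModel(HpSet hpVals)
Returns the model to train for the given hyperparameter values.
Parameters:
hpVals - the hyperparameter values to use for the model
Returns:
the model to train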
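Because buildModel receives the hyperparameter values, the architecture itself can be part of the search. A dependency-free sketch, assuming a fully connected network whose hypothetical hiddenSize and numLayers hyperparameters set the layer widths; a real implementation would then assemble the corresponding DJL blocks from these widths.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.Map;

/** Hypothetical helper: derives MLP layer widths from one hyperparameter set. */
final class MlpSpec {
    private MlpSpec() {}

    /** Returns [inputSize, hidden x numLayers, outputSize]. */
    static List<Integer> layerWidths(Map<String, Double> hp, int inputSize, int outputSize) {
        int hidden = (int) Math.round(hp.get("hiddenSize"));
        int layers = (int) Math.round(hp.get("numLayers"));
        if (hidden <= 0 || layers < 0) {
            throw new IllegalArgumentException("hiddenSize must be > 0 and numLayers >= 0");
        }
        List<Integer> widths = new ArrayList<>();
        widths.add(inputSize);
        for (int i = 0; i < layers; i++) {
            widths.add(hidden);
        }
        widths.add(outputSize);
        return widths;
    }
}
```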
inputShape
protected abstract Shape inputShape(HpSet hpVals)
Returns the input shape for the model.
Parameters:
hpVals - the hyperparameter values for the model
Returns:
the model input shape
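The input shape has to agree with any hyperparameter that changes the input the model sees. This hypothetical sketch assumes a windowed time-series model where windowLength is being tuned, uses a plain long[] in place of DJL's Shape, and assumes a leading batch dimension of 1 is sufficient because the shape is only used to initialize the model's parameters.

```java
import java.util.Arrays;
import java.util.Map;

/** Hypothetical sketch of an inputShape-style hook for a windowed time series. */
final class InputShapes {
    private InputShapes() {}

    /** Returns {batch = 1, windowLength, numFeatures}. */
    static long[] windowedInput(Map<String, Double> hp, long numFeatures) {
        long window = Math.round(hp.get("windowLength"));
        if (window <= 0) throw new IllegalArgumentException("windowLength must be > 0");
        return new long[] {1, window, numFeatures};
    }

    /** Total number of elements in a shape, useful for sanity checks. */
    static long size(long[] shape) {
        return Arrays.stream(shape).reduce(1L, (a, b) -> a * b);
    }
}
```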
numEpochs
protected abstract int numEpochs(HpSet hpVals)
Returns the number of epochs to train for the current hyperparameter set.
Parameters:
hpVals - the current hyperparameter set
Returns:
the number of epochs
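numEpochs receives the current hyperparameter set, so the epoch count can adapt to it. One possible policy, sketched below with hypothetical names: give every set the same budget of optimizer steps (assuming one step per batch), so sets with larger batches, and therefore fewer steps per epoch, get more epochs instead of being under-trained.

```java
import java.util.Map;

/** Hypothetical policy for a numEpochs-style hook: a fixed optimizer-step budget. */
final class EpochBudget {
    private EpochBudget() {}

    /** Returns the fewest epochs that reach at least targetSteps optimizer steps. */
    static int epochsFor(Map<String, Double> hp, long datasetSize, long targetSteps) {
        long batchSize = Math.round(hp.get("batchSize"));
        if (batchSize <= 0 || datasetSize <= 0 || targetSteps <= 0) {
            throw new IllegalArgumentException("all sizes must be > 0");
        }
        long stepsPerEpoch = (datasetSize + batchSize - 1) / batchSize; // ceiling division
        long epochs = (targetSteps + stepsPerEpoch - 1) / stepsPerEpoch; // ceiling division
        return (int) epochs;
    }
}
```

For 60,000 examples and a 5,000-step budget, a batch size of 32 gives 1,875 steps per epoch and 3 epochs, while a batch size of 128 gives 469 steps per epoch and 11 epochs.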
numHyperParameterTests
protected abstract int numHyperParameterTests()
Returns the number of hyperparameter sets to train with.
Returns:
the number of hyperparameter sets to train with
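How many sets to try can be estimated with a standard random-search argument, assuming each set is drawn independently and uniformly from the search space (whether that matches the optimizer used by fit() is an assumption here): the chance that at least one of n draws lands in the best fraction q of the space is 1 - (1 - q)^n. The hypothetical helper below solves for the smallest such n.

```java
/** Sizing a random search, assuming independent, uniformly random draws. */
final class TrialBudget {
    private TrialBudget() {}

    /**
     * Smallest n with 1 - (1 - topFraction)^n >= confidence, i.e. the number of
     * draws needed to hit the best topFraction of the space with that probability.
     */
    static int trialsNeeded(double topFraction, double confidence) {
        if (topFraction <= 0 || topFraction >= 1 || confidence <= 0 || confidence >= 1) {
            throw new IllegalArgumentException("arguments must be in (0, 1)");
        }
        double n = Math.log(1 - confidence) / Math.log(1 - topFraction);
        return (int) Math.ceil(n);
    }
}
```

For example, a 95% chance of hitting the top 5% of the space needs 59 sets, and a 90% chance of hitting the top 10% needs 22.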
saveModel
protected void saveModel(Model model, TrainingResult result) throws IOException
Saves the model trained with the best hyperparameter set.
Parameters:
model - the model to save
result - the training result for training with this model's hyperparameters
Throws:
IOException - if the model could not be saved
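saveModel is not abstract, so overriding it is optional. An override might, for example, persist the model itself (DJL's Model has a save(Path, String) method) along with a record of which hyperparameters won. The sketch below covers only the record-keeping part; BestRunRecorder and the best-hyperparameters.properties file name are hypothetical, and java.util.Properties is an assumed choice of format.

```java
import java.io.IOException;
import java.io.Reader;
import java.io.Writer;
import java.nio.charset.StandardCharsets;
import java.nio.file.Files;
import java.nio.file.Path;
import java.util.Map;
import java.util.Properties;

/** Hypothetical helper for a saveModel override: records the winning hyperparameters. */
final class BestRunRecorder {
    private BestRunRecorder() {}

    /** Writes hyperparameters and validation loss to dir/best-hyperparameters.properties. */
    static Path write(Path dir, Map<String, Double> hp, double validationLoss) throws IOException {
        Properties props = new Properties();
        for (Map.Entry<String, Double> e : hp.entrySet()) {
            props.setProperty("hp." + e.getKey(), Double.toString(e.getValue()));
        }
        props.setProperty("validationLoss", Double.toString(validationLoss));
        Files.createDirectories(dir);
        Path file = dir.resolve("best-hyperparameters.properties");
        try (Writer out = Files.newBufferedWriter(file, StandardCharsets.UTF_8)) {
            props.store(out, "Best hyperparameter set");
        }
        return file;
    }

    /** Reads the record back, e.g. to reproduce the winning run later. */
    static Properties read(Path file) throws IOException {
        Properties props = new Properties();
        try (Reader in = Files.newBufferedReader(file, StandardCharsets.UTF_8)) {
            props.load(in);
        }
        return props;
    }
}
```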