public abstract class EasyHpo
extends java.lang.Object
Constructor and Description |
---|
EasyHpo() |
Modifier and Type | Method and Description |
---|---|
protected abstract Model | buildModel(HpSet hpVals): Builds the model for the given hyperparameter values. |
ai.djl.util.Pair<Model,TrainingResult> | fit(): Fits the model given the implemented abstract methods. |
protected abstract RandomAccessDataset | getDataset(Dataset.Usage usage): Returns the dataset to train with. |
protected abstract Shape | inputShape(HpSet hpVals): Returns the input shape for the model. |
protected abstract int | numEpochs(HpSet hpVals): Returns the number of epochs to train for the current hyperparameter set. |
protected abstract int | numHyperParameterTests(): Returns the number of hyperparameter sets to train with. |
protected void | saveModel(Model model, TrainingResult result): Saves the model trained with the best hyperparameter set. |
protected abstract HpSet | setupHyperParams(): Returns the initial hyperparameters. |
protected abstract TrainingConfig | setupTrainingConfig(HpSet hpVals): Returns the TrainingConfig to use to train each hyperparameter set. |
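The summary above describes a template-method design: a subclass fills in the abstract hooks and fit() drives the search. The dependency-free sketch below models that control flow under stated assumptions: hyperparameter sets are plain Map<String, Double> values instead of HpSet, dataset loading, model building and training are collapsed into one hypothetical trainAndValidate hook that returns a validation loss, candidates are drawn at random, and the set with the lowest loss wins. MiniHpo, sampleHyperParams and trainAndValidate are illustrative names only; this is not DJL's implementation of fit().

```java
import java.util.HashMap;
import java.util.Map;
import java.util.Random;

/**
 * Hypothetical, dependency-free stand-in for the EasyHpo workflow (not the DJL class):
 * subclasses supply the hooks, and fit() runs a random search over hyperparameter sets.
 */
abstract class MiniHpo {

    /** Number of hyperparameter sets to try (mirrors numHyperParameterTests()). */
    protected abstract int numHyperParameterTests();

    /** Draws one candidate hyperparameter set (assumed stand-in for HpSet sampling). */
    protected abstract Map<String, Double> sampleHyperParams(Random rng);

    /** Trains with the given values and returns a validation loss; lower is better. */
    protected abstract double trainAndValidate(Map<String, Double> hpVals);

    /** Runs every trial and returns a copy of the best set, or null if no trial ran. */
    public Map<String, Double> fit(long seed) {
        Random rng = new Random(seed);
        Map<String, Double> best = null;
        double bestLoss = Double.POSITIVE_INFINITY;
        for (int i = 0; i < numHyperParameterTests(); i++) {
            Map<String, Double> hpVals = sampleHyperParams(rng);
            double loss = trainAndValidate(hpVals);
            // Keep only the candidate with the lowest validation loss seen so far.
            if (loss < bestLoss) {
                bestLoss = loss;
                best = new HashMap<>(hpVals);
            }
        }
        return best;
    }
}
```

Keeping the loop in the base class means every subclass gets the same search and bookkeeping and only has to describe what to train; a toy subclass whose "loss" is (lr - 0.1)^2 should end up with a learning rate near 0.1.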
public ai.djl.util.Pair<Model,TrainingResult> fit() throws java.io.IOException, TranslateException
Fits the model given the implemented abstract methods.
Throws:
java.io.IOException - for various exceptions depending on the dataset
TranslateException - if there is an error while processing input

protected abstract HpSet setupHyperParams()
Returns the initial hyperparameters.
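setupHyperParams() returns an HpSet describing the values to search over. DJL provides its own parameter classes for this (for example HpInt, HpFloat and HpCategorical); the sketch below does not use that API but illustrates, self-contained, what such a search space does: each named entry knows how to draw a value, and sampling the whole space yields one candidate set. IntRange, LogFloatRange, Choice and SearchSpace are hypothetical names, and sampling the learning rate on a log scale is a common convention assumed here, not something this page specifies.

```java
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Random;

/** Hypothetical search-space entry (not a DJL type): knows how to draw one value. */
interface HpRange {
    Object sample(Random rng);
}

/** Uniform integer range, inclusive of both bounds. */
final class IntRange implements HpRange {
    private final int lower;
    private final int upper;

    IntRange(int lower, int upper) { this.lower = lower; this.upper = upper; }

    public Object sample(Random rng) { return lower + rng.nextInt(upper - lower + 1); }
}

/** Float range sampled uniformly in log space; assumes 0 < lower < upper. */
final class LogFloatRange implements HpRange {
    private final double logLower;
    private final double logUpper;

    LogFloatRange(double lower, double upper) {
        this.logLower = Math.log(lower);
        this.logUpper = Math.log(upper);
    }

    public Object sample(Random rng) {
        return (float) Math.exp(logLower + rng.nextDouble() * (logUpper - logLower));
    }
}

/** Picks one of a fixed list of choices with equal probability. */
final class Choice implements HpRange {
    private final List<?> options;

    Choice(List<?> options) { this.options = options; }

    public Object sample(Random rng) { return options.get(rng.nextInt(options.size())); }
}

/** Named collection of ranges, similar in spirit to an HpSet. */
final class SearchSpace {
    private final Map<String, HpRange> ranges = new LinkedHashMap<>();

    SearchSpace add(String name, HpRange range) {
        ranges.put(name, range);
        return this;
    }

    /** Draws one value per entry, preserving insertion order. */
    Map<String, Object> sample(Random rng) {
        Map<String, Object> values = new LinkedHashMap<>();
        for (Map.Entry<String, HpRange> e : ranges.entrySet()) {
            values.put(e.getKey(), e.getValue().sample(rng));
        }
        return values;
    }
}
```

With new SearchSpace().add("learningRate", new LogFloatRange(0.001, 0.1)), each decade between 0.001 and 0.1 is equally likely, whereas uniform sampling over the same interval would put roughly 90% of draws above 0.01.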

protected abstract RandomAccessDataset getDataset(Dataset.Usage usage) throws java.io.IOException
Returns the dataset to train with.
Parameters:
usage - the usage of the dataset
Throws:
java.io.IOException - if the dataset could not be loaded

protected abstract TrainingConfig setupTrainingConfig(HpSet hpVals)
Returns the TrainingConfig to use to train each hyperparameter set.
Parameters:
hpVals - the hyperparameters to train with
Returns:
the TrainingConfig to use to train each hyperparameter set

protected abstract Model buildModel(HpSet hpVals)
Builds the model for the given hyperparameter values.
Parameters:
hpVals - the hyperparameter values to use for the model

protected abstract Shape inputShape(HpSet hpVals)
Returns the input shape for the model.
Parameters:
hpVals - the hyperparameter values for the model

protected abstract int numEpochs(HpSet hpVals)
Returns the number of epochs to train for the current hyperparameter set.
Parameters:
hpVals - the current hyperparameter set

protected abstract int numHyperParameterTests()
Returns the number of hyperparameter sets to train with.
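Because getDataset receives a Dataset.Usage, an implementation that derives its training and test data from one underlying dataset has to return consistent, non-overlapping splits every time it is called. A minimal sketch of that idea over sample indices, assuming an 80/20 train/test split and a fixed seed; Usage here is a hypothetical two-value stand-in for Dataset.Usage, and IndexSplitter is illustrative, not a DJL class.

```java
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Random;

/** Hypothetical two-value stand-in for DJL's Dataset.Usage. */
enum Usage { TRAIN, TEST }

/** Deterministic train/test split over sample indices (illustrative helper). */
final class IndexSplitter {
    private final int size;
    private final double trainFraction;
    private final long seed;

    IndexSplitter(int size, double trainFraction, long seed) {
        this.size = size;
        this.trainFraction = trainFraction;
        this.seed = seed;
    }

    /** Returns the indices for the requested usage; the same seed gives the same split on every call. */
    List<Integer> indicesFor(Usage usage) {
        List<Integer> all = new ArrayList<>();
        for (int i = 0; i < size; i++) {
            all.add(i);
        }
        // A fresh Random with the fixed seed makes every call produce the same shuffle.
        Collections.shuffle(all, new Random(seed));
        int cut = (int) Math.round(size * trainFraction);
        List<Integer> part = usage == Usage.TRAIN ? all.subList(0, cut) : all.subList(cut, size);
        return new ArrayList<>(part);
    }
}
```

Re-seeding on every call, instead of sharing one Random across calls, is what keeps separate TRAIN and TEST requests from seeing different shuffles and leaking samples between the splits.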

protected void saveModel(Model model, TrainingResult result) throws java.io.IOException
Saves the model trained with the best hyperparameter set.
Parameters:
model - the model to save
result - the training result for training with this model's hyperparameters
Throws:
java.io.IOException - if the model could not be saved
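saveModel is not abstract, so a subclass may override it, for example to record which hyperparameters won alongside whatever it does with the Model itself. Since saveModel receives only the Model and its TrainingResult, this assumes the subclass has the winning values at hand; how it tracks them is up to the subclass. The sketch below writes them, plus a validation loss, to a java.util.Properties file; BestTrialStore, the "hp." key prefix and the file layout are hypothetical choices, not part of DJL.

```java
import java.io.IOException;
import java.io.Reader;
import java.io.Writer;
import java.nio.charset.StandardCharsets;
import java.nio.file.Files;
import java.nio.file.Path;
import java.util.Map;
import java.util.Properties;

/** Hypothetical helper that records the winning hyperparameters next to a saved model. */
final class BestTrialStore {
    private BestTrialStore() {}

    /** Writes each hyperparameter (prefixed with "hp.") plus the validation loss as key=value lines. */
    static void save(Path file, Map<String, ?> hpVals, double validationLoss) throws IOException {
        Properties props = new Properties();
        for (Map.Entry<String, ?> e : hpVals.entrySet()) {
            props.setProperty("hp." + e.getKey(), String.valueOf(e.getValue()));
        }
        props.setProperty("validationLoss", Double.toString(validationLoss));
        try (Writer out = Files.newBufferedWriter(file, StandardCharsets.UTF_8)) {
            props.store(out, "best hyperparameter set");
        }
    }

    /** Reads the file back; all values come back as strings. */
    static Properties load(Path file) throws IOException {
        Properties props = new Properties();
        try (Reader in = Files.newBufferedReader(file, StandardCharsets.UTF_8)) {
            props.load(in);
        }
        return props;
    }
}
```

A plain key=value file keeps the record human-readable and diffable between runs; values are stored as strings, so a reader must parse them back to the types the search space used.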