Class EasyHpo

java.lang.Object
ai.djl.training.hyperparameter.EasyHpo

public abstract class EasyHpo extends Object
Helper for easy training with hyperparameters.
  • Constructor Details

    • EasyHpo

      public EasyHpo()
  • Method Details

    • fit

      public ai.djl.util.Pair<Model,TrainingResult> fit() throws IOException, TranslateException
      Fits the model given the implemented abstract methods.
      Returns:
      the best model and its training result
      Throws:
      IOException - if the dataset could not be loaded or the model could not be saved
      TranslateException - if there is an error while processing input
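
      The overall shape of fit() can be pictured as a template method: the base class owns the search loop, and the subclass supplies the pieces. The sketch below is self-contained and does not use DJL. Every type and method name in it (HpoTemplateSketch, Trial, sampleLearningRate, trainAndEvaluate, numTrials, saveBest) is hypothetical, and the selection rule (keep the trial with the lowest loss) is an assumption, since this page does not state how the best model is chosen.

```java
import java.util.Random;

// Hypothetical stand-ins: none of these types are DJL classes.
abstract class HpoTemplateSketch {

    /** Result of one trial: the sampled learning rate and the loss it reached. */
    static final class Trial {
        final double learningRate;
        final double loss;

        Trial(double learningRate, double loss) {
            this.learningRate = learningRate;
            this.loss = loss;
        }
    }

    // Hooks the subclass must implement (loosely mirroring the abstract methods).
    protected abstract double sampleLearningRate(Random rng);

    protected abstract double trainAndEvaluate(double learningRate);

    protected abstract int numTrials();

    // Optional hook with a do-nothing default; subclasses may override it.
    protected void saveBest(Trial best) {}

    /** Template method: runs every trial and keeps the one with the lowest loss. */
    public final Trial fit(long seed) {
        Random rng = new Random(seed);
        Trial best = null;
        for (int i = 0; i < numTrials(); i++) {
            double lr = sampleLearningRate(rng);
            double loss = trainAndEvaluate(lr);
            if (best == null || loss < best.loss) {
                best = new Trial(lr, loss);
            }
        }
        if (best != null) {
            saveBest(best);
        }
        return best;
    }
}

/** Toy subclass: "training" is a quadratic whose minimum sits at 0.1. */
class QuadraticHpo extends HpoTemplateSketch {
    @Override
    protected double sampleLearningRate(Random rng) {
        return rng.nextDouble(); // uniform in [0, 1)
    }

    @Override
    protected double trainAndEvaluate(double learningRate) {
        double d = learningRate - 0.1;
        return d * d;
    }

    @Override
    protected int numTrials() {
        return 50;
    }
}
```

      Calling new QuadraticHpo().fit(42L) returns the trial whose sampled value came closest to 0.1. In the real class, the corresponding pieces would be supplied through the abstract methods documented on this page.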
    • setupHyperParams

      protected abstract HpSet setupHyperParams()
      Returns the initial hyperparameters.
      Returns:
      the initial hyperparameters
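
      One way to think about the value returned here is as a search space: a named set of candidate values from which each trial draws one concrete combination. This page does not spell out that behavior, so treat it as an assumption. The sketch below illustrates the idea with plain Java collections. SearchSpaceSketch and the hyperparameter names "batchSize" and "learningRate" are hypothetical and are not part of the DJL API.

```java
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Random;

// Hypothetical stand-in for a hyperparameter set; not a DJL class.
final class SearchSpaceSketch {
    private final Map<String, List<?>> candidates = new LinkedHashMap<>();

    /** Registers a named hyperparameter together with its candidate values. */
    SearchSpaceSketch add(String name, List<?> values) {
        if (values.isEmpty()) {
            throw new IllegalArgumentException("no candidates for " + name);
        }
        candidates.put(name, values);
        return this;
    }

    /** Draws one concrete value per hyperparameter, uniformly at random. */
    Map<String, Object> sample(Random rng) {
        Map<String, Object> picked = new LinkedHashMap<>();
        for (Map.Entry<String, List<?>> e : candidates.entrySet()) {
            List<?> values = e.getValue();
            picked.put(e.getKey(), values.get(rng.nextInt(values.size())));
        }
        return picked;
    }

    /** Example space; the names and values are illustrative only. */
    static SearchSpaceSketch example() {
        return new SearchSpaceSketch()
                .add("batchSize", List.of(32, 64, 128))
                .add("learningRate", List.of(0.1, 0.01, 0.001));
    }
}
```

      Each call to sample yields one value per name, which is the kind of concrete set that methods taking an hpVals argument would then receive.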
    • getDataset

      protected abstract RandomAccessDataset getDataset(Dataset.Usage usage) throws IOException
      Returns the dataset to train with.
      Parameters:
      usage - the usage of the dataset
      Returns:
      the dataset to train with
      Throws:
      IOException - if the dataset could not be loaded
    • setupTrainingConfig

      protected abstract TrainingConfig setupTrainingConfig(HpSet hpVals)
      Returns the TrainingConfig to use to train each hyperparameter set.
      Parameters:
      hpVals - the hyperparameters to train with
      Returns:
      the TrainingConfig to use to train each hyperparameter set
    • buildModel

      protected abstract Model buildModel(HpSet hpVals)
      Builds the Model and Block to train.
      Parameters:
      hpVals - the hyperparameter values to use for the model
      Returns:
      the model to train
    • inputShape

      protected abstract Shape inputShape(HpSet hpVals)
      Returns the input shape for the model.
      Parameters:
      hpVals - the hyperparameter values for the model
      Returns:
      the model input shape
    • numEpochs

      protected abstract int numEpochs(HpSet hpVals)
      Returns the number of epochs to train for the current hyperparameter set.
      Parameters:
      hpVals - the current hyperparameter set
      Returns:
      the number of epochs
    • numHyperParameterTests

      protected abstract int numHyperParameterTests()
      Returns the number of hyperparameter sets to train with.
      Returns:
      the number of hyperparameter sets to train with
    • saveModel

      protected void saveModel(Model model, TrainingResult result) throws IOException
      Saves the model trained with the best hyperparameter set.
      Parameters:
      model - the model to save
      result - the training result for training with this model's hyperparameters
      Throws:
      IOException - if the model could not be saved