public static class TrainingConfig.Builder extends Object
Constructor and Description |
---|
Builder() |
Modifier and Type | Method and Description |
---|---|
TrainingConfig.Builder | addEvaluations(boolean validation, @NonNull String variableName, int labelIndex, IEvaluation... evaluations) Add requested evaluations for a parameter/variable, for either training or validation. |
TrainingConfig.Builder | addRegularization(Regularization... regularizations) Add regularization to all trainable parameters in the network. |
TrainingConfig | build() |
TrainingConfig.Builder | dataSetFeatureMapping(List<String> dataSetFeatureMapping) Set the name of the placeholders/variables that should be set using the feature INDArray(s) from the DataSet or MultiDataSet. |
TrainingConfig.Builder | dataSetFeatureMapping(String... dataSetFeatureMapping) Set the name of the placeholders/variables that should be set using the feature INDArray(s) from the DataSet or MultiDataSet. |
TrainingConfig.Builder | dataSetFeatureMaskMapping(List<String> dataSetFeatureMaskMapping) Set the name of the placeholders/variables that should be set using the feature mask INDArray(s) from the DataSet or MultiDataSet. |
TrainingConfig.Builder | dataSetFeatureMaskMapping(String... dataSetFeatureMaskMapping) |
TrainingConfig.Builder | dataSetLabelMapping(List<String> dataSetLabelMapping) Set the name of the placeholders/variables that should be set using the labels INDArray(s) from the DataSet or MultiDataSet. |
TrainingConfig.Builder | dataSetLabelMapping(String... dataSetLabelMapping) Set the name of the placeholders/variables that should be set using the labels INDArray(s) from the DataSet or MultiDataSet. |
TrainingConfig.Builder | dataSetLabelMaskMapping(List<String> dataSetLabelMaskMapping) Set the name of the placeholders/variables that should be set using the label mask INDArray(s) from the DataSet or MultiDataSet. |
TrainingConfig.Builder | dataSetLabelMaskMapping(String... dataSetLabelMaskMapping) |
TrainingConfig.Builder | l1(double l1) Sets the L1 regularization coefficient for all trainable parameters. |
TrainingConfig.Builder | l2(double l2) Sets the L2 regularization coefficient for all trainable parameters. |
TrainingConfig.Builder | markLabelsUnused() Calling this method will mark the labels as unused. |
TrainingConfig.Builder | minimize(boolean minimize) Sets whether the loss function should be minimized (true) or maximized (false). The loss function is usually minimized in SGD. Default: true. |
TrainingConfig.Builder | minimize(String... lossVariables) |
TrainingConfig.Builder | regularization(List<Regularization> regularization) Set the regularization for all trainable parameters in the network. |
TrainingConfig.Builder | regularization(Regularization... regularization) Set the regularization for all trainable parameters in the network. |
TrainingConfig.Builder | skipBuilderValidation(boolean skip) |
TrainingConfig.Builder | trainEvaluation(@NonNull SDVariable variable, int labelIndex, IEvaluation... evaluations) Add requested History training evaluations for a parameter/variable. |
TrainingConfig.Builder | trainEvaluation(@NonNull String variableName, int labelIndex, IEvaluation... evaluations) Add requested History training evaluations for a parameter/variable. |
TrainingConfig.Builder | updater(IUpdater updater) |
TrainingConfig.Builder | validationEvaluation(@NonNull SDVariable variable, int labelIndex, IEvaluation... evaluations) Add requested History validation evaluations for a parameter/variable. |
TrainingConfig.Builder | validationEvaluation(@NonNull String variableName, int labelIndex, IEvaluation... evaluations) Add requested History validation evaluations for a parameter/variable. |
TrainingConfig.Builder | weightDecay(double coefficient, boolean applyLR) Add weight decay regularization for all trainable parameters. |
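
The dataSet...Mapping methods listed above bind arrays to placeholders by position: the array at index i of the DataSet or MultiDataSet goes to the i-th name in the mapping. The following is a minimal sketch of that idea using plain Java collections only. The class FeatureMappingSketch and the method bindByPosition are hypothetical names introduced for illustration, not part of the library, and double[] stands in for INDArray.

```java
import java.util.Arrays;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;

// Illustrative only: a stand-in for positional name-to-array binding.
public class FeatureMappingSketch {

    /** Pairs the i-th array with the i-th placeholder name, preserving order. */
    public static Map<String, double[]> bindByPosition(List<String> names, List<double[]> arrays) {
        if (names.size() != arrays.size()) {
            throw new IllegalArgumentException(
                    "Expected " + names.size() + " arrays but got " + arrays.size());
        }
        Map<String, double[]> placeholders = new LinkedHashMap<>();
        for (int i = 0; i < names.size(); i++) {
            placeholders.put(names.get(i), arrays.get(i));
        }
        return placeholders;
    }

    public static void main(String[] args) {
        // Mapping as it would be passed to dataSetFeatureMapping("input1", "input2").
        List<String> mapping = Arrays.asList("input1", "input2");
        // Stand-ins for MultiDataSet.getFeatures(0) and MultiDataSet.getFeatures(1).
        List<double[]> features = Arrays.asList(new double[]{1.0, 2.0}, new double[]{3.0});

        Map<String, double[]> bound = bindByPosition(mapping, features);
        System.out.println(bound.keySet()); // prints [input1, input2]
    }
}
```

The same positional rule is assumed here for labels and for the feature and label masks.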
public TrainingConfig.Builder updater(IUpdater updater)
Set the updater (such as Adam, Nesterovs etc.). This is also how the learning rate (or learning rate schedule) is set.
updater - Updater to set

public TrainingConfig.Builder l1(double l1)
Sets the L1 regularization coefficient for all trainable parameters. See L1Regularization for more details.
l1 - L1 regularization coefficient

public TrainingConfig.Builder l2(double l2)
Sets the L2 regularization coefficient for all trainable parameters. Note: WeightDecay (set via weightDecay(double,boolean)) should be preferred to L2 regularization. See the WeightDecay javadoc for further details.
See also: weightDecay(double, boolean)
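
The preference for weight decay over L2 regularization is easier to see with numbers. The following is a minimal sketch under plain SGD, assuming the usual formulations: an L2 penalty of 0.5 * l2 * w^2 adds l2 * w to the gradient, while weight decay shrinks the weight directly, scaled by the learning rate only when applyLR is true. It does not call the library's WeightDecay or L2Regularization classes; the class and method names are hypothetical.

```java
// Illustrative only: single-weight update rules under plain SGD.
public class RegularizationSketch {

    /** SGD step where the L2 penalty is folded into the gradient before the update. */
    public static double sgdStepWithL2(double w, double grad, double lr, double l2) {
        double penalizedGrad = grad + l2 * w;
        return w - lr * penalizedGrad;
    }

    /** SGD step followed by a decay term applied directly to the weight. */
    public static double sgdStepWithWeightDecay(double w, double grad, double lr,
                                                double coefficient, boolean applyLR) {
        // With applyLR the decay is scaled by the learning rate; otherwise it is not.
        double decay = (applyLR ? lr : 1.0) * coefficient * w;
        return w - lr * grad - decay;
    }

    public static void main(String[] args) {
        double w = 2.0, grad = 0.5, lr = 0.1, coefficient = 0.01;
        System.out.println("L2:                    " + sgdStepWithL2(w, grad, lr, coefficient));
        System.out.println("Weight decay, applyLR: " + sgdStepWithWeightDecay(w, grad, lr, coefficient, true));
        System.out.println("Weight decay, no LR:   " + sgdStepWithWeightDecay(w, grad, lr, coefficient, false));
    }
}
```

Under plain SGD the first two updates coincide. With adaptive updaters such as Adam they generally do not, because the L2 term passes through the updater's gradient scaling while the decay term does not.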
public TrainingConfig.Builder weightDecay(double coefficient, boolean applyLR)
Add weight decay regularization for all trainable parameters. See WeightDecay for more details.
coefficient - Weight decay regularization coefficient
applyLR - Whether the learning rate should be multiplied in when performing weight decay updates. See WeightDecay for more details.

public TrainingConfig.Builder addRegularization(Regularization... regularizations)
Add regularization to all trainable parameters in the network.
regularizations - Regularization type(s) to add

public TrainingConfig.Builder regularization(Regularization... regularization)
Set the regularization for all trainable parameters in the network.
regularization - Regularization type(s) to add

public TrainingConfig.Builder regularization(List<Regularization> regularization)
Set the regularization for all trainable parameters in the network.
regularization - Regularization type(s) to add

public TrainingConfig.Builder minimize(boolean minimize)
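Sets whether the loss function should be minimized (true) or maximized (false). The loss function is usually minimized in SGD. Default: true.
minimize - True to minimize, false to maximize

public TrainingConfig.Builder dataSetFeatureMapping(String... dataSetFeatureMapping)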
Set the name of the placeholders/variables that should be set using the feature INDArray(s) from the DataSet or MultiDataSet. For example, if the mapping is MultiDataSet.getFeatures(0)->"input1" and MultiDataSet.getFeatures(1)->"input2", then this should be set to "input1", "input2".
dataSetFeatureMapping - Names of the variables/placeholders that the feature arrays should be mapped to

public TrainingConfig.Builder dataSetFeatureMapping(List<String> dataSetFeatureMapping)
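Set the name of the placeholders/variables that should be set using the feature INDArray(s) from the DataSet or MultiDataSet. For example, if the mapping is MultiDataSet.getFeatures(0)->"input1" and MultiDataSet.getFeatures(1)->"input2", then this should be set to a List containing "input1" and "input2".
dataSetFeatureMapping - Names of the variables/placeholders that the feature arrays should be mapped to

public TrainingConfig.Builder dataSetLabelMapping(String... dataSetLabelMapping)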
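Set the name of the placeholders/variables that should be set using the labels INDArray(s) from the DataSet or MultiDataSet. For example, if the mapping is MultiDataSet.getLabels(0)->"label1" and MultiDataSet.getLabels(1)->"label2", then this should be set to "label1", "label2".
dataSetLabelMapping - Names of the variables/placeholders that the label arrays should be mapped to

public TrainingConfig.Builder dataSetLabelMapping(List<String> dataSetLabelMapping)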
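Set the name of the placeholders/variables that should be set using the labels INDArray(s) from the DataSet or MultiDataSet. For example, if the mapping is MultiDataSet.getLabels(0)->"label1" and MultiDataSet.getLabels(1)->"label2", then this should be set to a List containing "label1" and "label2".
dataSetLabelMapping - Names of the variables/placeholders that the label arrays should be mapped to

public TrainingConfig.Builder markLabelsUnused()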
Calling this method will mark the labels as unused. Usually you need to call dataSetLabelMapping(String...) to set labels; this method allows you to say that the DataSet/MultiDataSet labels aren't used in training.

public TrainingConfig.Builder dataSetFeatureMaskMapping(String... dataSetFeatureMaskMapping)

public TrainingConfig.Builder dataSetFeatureMaskMapping(List<String> dataSetFeatureMaskMapping)
Set the name of the placeholders/variables that should be set using the feature mask INDArray(s) from the DataSet or MultiDataSet. For example, if the mapping is MultiDataSet.getFeatureMaskArray(0)->"mask1" and MultiDataSet.getFeatureMaskArray(1)->"mask2", then this should be set to "mask1", "mask2".
dataSetFeatureMaskMapping - Names of the variables/placeholders that the feature mask arrays should be mapped to

public TrainingConfig.Builder dataSetLabelMaskMapping(String... dataSetLabelMaskMapping)

public TrainingConfig.Builder dataSetLabelMaskMapping(List<String> dataSetLabelMaskMapping)
Set the name of the placeholders/variables that should be set using the label mask INDArray(s) from the DataSet or MultiDataSet. For example, if the mapping is MultiDataSet.getLabelMaskArray(0)->"mask1" and MultiDataSet.getLabelMaskArray(1)->"mask2", then this should be set to "mask1", "mask2".
dataSetLabelMaskMapping - Names of the variables/placeholders that the label mask arrays should be mapped to

public TrainingConfig.Builder skipBuilderValidation(boolean skip)

public TrainingConfig.Builder minimize(String... lossVariables)

public TrainingConfig.Builder trainEvaluation(@NonNull String variableName, int labelIndex, @NonNull IEvaluation... evaluations)
Add requested History training evaluations for a parameter/variable. These evaluations are reported in the History object returned by fit.
variableName - The variable to evaluate
labelIndex - The index of the label to evaluate against
evaluations - The evaluations to run

public TrainingConfig.Builder trainEvaluation(@NonNull SDVariable variable, int labelIndex, @NonNull IEvaluation... evaluations)
Add requested History training evaluations for a parameter/variable. These evaluations are reported in the History object returned by fit.
variable - The variable to evaluate
labelIndex - The index of the label to evaluate against
evaluations - The evaluations to run

public TrainingConfig.Builder validationEvaluation(@NonNull String variableName, int labelIndex, @NonNull IEvaluation... evaluations)
Add requested History validation evaluations for a parameter/variable. These evaluations are reported in the History object returned by fit.
variableName - The variable to evaluate
labelIndex - The index of the label to evaluate against
evaluations - The evaluations to run

public TrainingConfig.Builder validationEvaluation(@NonNull SDVariable variable, int labelIndex, @NonNull IEvaluation... evaluations)
Add requested History validation evaluations for a parameter/variable. These evaluations are reported in the History object returned by fit.
variable - The variable to evaluate
labelIndex - The index of the label to evaluate against
evaluations - The evaluations to run

public TrainingConfig.Builder addEvaluations(boolean validation, @NonNull String variableName, int labelIndex, @NonNull IEvaluation... evaluations)
Add requested evaluations for a parameter/variable, for either training or validation. These evaluations are reported in the History object returned by fit.
validation - Whether to add these evaluations as validation or training
variableName - The variable to evaluate
labelIndex - The index of the label to evaluate against
evaluations - The evaluations to run

public TrainingConfig build()
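
Putting the methods above together, a typical chain might look like the sketch below. It assumes the ND4J/SameDiff artifacts are on the classpath and that the imported package names match the release in use. The variable names "input", "label" and "softmax" are hypothetical and would need to match placeholders and variables defined in the SameDiff graph being trained.

```java
import org.nd4j.autodiff.samediff.TrainingConfig;
import org.nd4j.evaluation.classification.Evaluation;
import org.nd4j.linalg.learning.config.Adam;

// Illustrative only: requires the ND4J/SameDiff dependencies to compile.
public class TrainingConfigSketch {

    public static TrainingConfig buildConfig() {
        return new TrainingConfig.Builder()
                .updater(new Adam(1e-3))                          // updater also carries the learning rate
                .weightDecay(1e-4, true)                          // applyLR = true scales decay by the learning rate
                .dataSetFeatureMapping("input")                   // features array 0 -> placeholder "input" (assumed name)
                .dataSetLabelMapping("label")                     // labels array 0 -> placeholder "label" (assumed name)
                .trainEvaluation("softmax", 0, new Evaluation())  // evaluate "softmax" (assumed name) against label 0
                .build();
    }
}
```

For a model trained without labels, markLabelsUnused() would take the place of the dataSetLabelMapping call.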
Copyright © 2020. All rights reserved.