public class TrainingConfig extends Object

TrainingConfig is a configuration class for training a SameDiff instance. Among other settings, it defines the IUpdater to use (e.g., Adam, Nesterovs, etc.). The IUpdater instance is also how the learning rate (or learning rate schedule) is set.

| Modifier and Type | Class and Description |
|---|---|
| static class | TrainingConfig.Builder |
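Because the learning rate (or schedule) lives inside the updater rather than in TrainingConfig itself, it can help to see that split in isolation. The following is a minimal, JDK-only sketch under stated assumptions: `SimpleUpdater` and `Sgd` are hypothetical stand-ins, not the ND4J `IUpdater`, `Adam` or `Nesterovs` classes, and the step-decay schedule is an arbitrary choice for illustration.

```java
import java.util.Arrays;
import java.util.function.IntToDoubleFunction;

// Hypothetical, simplified stand-in for an updater: it owns the learning
// rate (fixed or scheduled) and turns gradients into parameter updates.
// This is NOT the ND4J IUpdater API.
public class UpdaterSketch {

    /** Hypothetical minimal updater interface. */
    interface SimpleUpdater {
        /** Apply one update in place; the iteration number drives any schedule. */
        void apply(double[] params, double[] grad, int iteration);
    }

    /** Plain SGD whose learning rate comes from a schedule over iterations. */
    static final class Sgd implements SimpleUpdater {
        private final IntToDoubleFunction lrSchedule;

        /** Fixed learning rate: a schedule that ignores the iteration. */
        Sgd(double fixedLr) {
            this(i -> fixedLr);
        }

        Sgd(IntToDoubleFunction lrSchedule) {
            this.lrSchedule = lrSchedule;
        }

        @Override
        public void apply(double[] params, double[] grad, int iteration) {
            double lr = lrSchedule.applyAsDouble(iteration);
            for (int i = 0; i < params.length; i++) {
                params[i] -= lr * grad[i]; // gradient descent step
            }
        }
    }

    public static void main(String[] args) {
        double[] w = {1.0, 2.0};
        double[] g = {0.5, 0.5};
        // Step decay (assumed for illustration): halve the rate every 10 iterations.
        SimpleUpdater updater = new Sgd(i -> 0.1 * Math.pow(0.5, i / 10));
        updater.apply(w, g, 0);
        System.out.println(Arrays.toString(w));
    }
}
```

Swapping the updater object is then the only change needed to use a different learning rate or schedule; the rest of the configuration is untouched.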

| Modifier | Constructor and Description |
|---|---|
| | TrainingConfig(IUpdater updater, List<Regularization> regularization, boolean minimize, List<String> dataSetFeatureMapping, List<String> dataSetLabelMapping, List<String> dataSetFeatureMaskMapping, List<String> dataSetLabelMaskMapping, List<String> lossVariables) Create a training configuration suitable for training both single input/output and multi input/output networks. See also TrainingConfig.Builder for creating a TrainingConfig. |
| protected | TrainingConfig(IUpdater updater, List<Regularization> regularization, boolean minimize, List<String> dataSetFeatureMapping, List<String> dataSetLabelMapping, List<String> dataSetFeatureMaskMapping, List<String> dataSetLabelMaskMapping, List<String> lossVariables, Map<String,List<IEvaluation>> trainEvaluations, Map<String,Integer> trainEvaluationLabels, Map<String,List<IEvaluation>> validationEvaluations, Map<String,Integer> validationEvaluationLabels) |
| | TrainingConfig(IUpdater updater, List<Regularization> regularization, String dataSetFeatureMapping, String dataSetLabelMapping) Create a training configuration suitable for training a single-input, single-output network. See also TrainingConfig.Builder for creating a TrainingConfig. |
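The feature and label mappings taken by these constructors associate the i-th array of a DataSet/MultiDataSet with the i-th placeholder name. Below is a hedged, JDK-only sketch of that positional binding; `bind` is a hypothetical helper (not part of SameDiff), `double[]` stands in for `INDArray`, and treating a null name list as "no mapping" mirrors the "may be null" wording for the mask mappings.

```java
import java.util.Arrays;
import java.util.Collections;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;

// Hypothetical sketch of applying a feature/label name mapping:
// the i-th array is bound to the i-th placeholder name.
// double[] stands in for INDArray so the example needs only the JDK.
public class PlaceholderMappingSketch {

    /**
     * Zip placeholder names with arrays, preserving order. A null name list
     * (as allowed for the mask mappings) means "no mapping": nothing is bound.
     */
    static Map<String, double[]> bind(List<String> names, List<double[]> arrays) {
        Map<String, double[]> out = new LinkedHashMap<>();
        if (names == null) {
            return out;
        }
        if (arrays == null || arrays.size() != names.size()) {
            throw new IllegalArgumentException("Expected " + names.size() + " arrays, got "
                    + (arrays == null ? 0 : arrays.size()));
        }
        for (int i = 0; i < names.size(); i++) {
            out.put(names.get(i), arrays.get(i));
        }
        return out;
    }

    public static void main(String[] args) {
        // Two-input network: features(0) -> "input1", features(1) -> "input2"
        List<String> featureMapping = Arrays.asList("input1", "input2");
        List<double[]> features = Arrays.asList(new double[]{1, 2}, new double[]{3});
        Map<String, double[]> placeholders = bind(featureMapping, features);

        // Single-input case viewed as a one-element mapping, e.g. "input"
        Map<String, double[]> single = bind(Collections.singletonList("input"),
                Collections.singletonList(new double[]{0.5}));

        System.out.println(placeholders.keySet() + " " + single.keySet());
    }
}
```

Failing fast on a count mismatch is a design choice of this sketch; it surfaces a misconfigured mapping before any training step runs.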

| Modifier and Type | Method and Description |
|---|---|
| static TrainingConfig.Builder | builder() |
| static TrainingConfig | fromJson(@NonNull String json) |
| void | incrementEpochCount() Increment the epoch count by 1. |
| void | incrementIterationCount() Increment the iteration count by 1. |
| int | labelIdx(String s) Get the index of the label array that the specified variable is associated with. |
| static void | removeInstances(List<?> list, Class<?> remove) Remove any instances of the specified type from the list. |
| static void | removeInstancesWithWarning(List<?> list, Class<?> remove, String warning) |
| String | toJson() |
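labelIdx(String) answers "which label array does this variable correspond to?". A minimal sketch of that lookup, assuming the index is simply the variable's position in the label mapping list; `LabelIdxSketch` is hypothetical, and throwing on an unmapped name (rather than returning -1) is an assumption, not documented behavior.

```java
import java.util.Arrays;
import java.util.List;

// Hypothetical re-implementation of the idea behind labelIdx(String):
// find which label array (by position) a variable name is mapped to.
public class LabelIdxSketch {

    private final List<String> dataSetLabelMapping;

    LabelIdxSketch(List<String> dataSetLabelMapping) {
        this.dataSetLabelMapping = dataSetLabelMapping;
    }

    /** Return the label-array index for variable s; throws if s is not mapped (assumed behavior). */
    int labelIdx(String s) {
        int idx = dataSetLabelMapping.indexOf(s);
        if (idx < 0) {
            throw new IllegalArgumentException("Variable \"" + s
                    + "\" is not in the label mapping " + dataSetLabelMapping);
        }
        return idx;
    }

    public static void main(String[] args) {
        LabelIdxSketch cfg = new LabelIdxSketch(Arrays.asList("label1", "label2"));
        System.out.println(cfg.labelIdx("label2"));
    }
}
```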

public TrainingConfig(IUpdater updater, List<Regularization> regularization, String dataSetFeatureMapping, String dataSetLabelMapping)
Create a training configuration suitable for training a single-input, single-output network. See also TrainingConfig.Builder for creating a TrainingConfig.
Parameters:
- updater: The updater configuration to use
- regularization: Regularization for all trainable parameters
- dataSetFeatureMapping: The name of the placeholder/variable that should be set using the feature INDArray from the DataSet (or the first/only feature from a MultiDataSet). For example, if the network input placeholder was called "input", then this should be set to "input".
- dataSetLabelMapping: The name of the placeholder/variable that should be set using the label INDArray from the DataSet (or the first/only label from a MultiDataSet). For example, if the network label placeholder was called "label", then this should be set to "label".

public TrainingConfig(IUpdater updater, List<Regularization> regularization, boolean minimize, List<String> dataSetFeatureMapping, List<String> dataSetLabelMapping, List<String> dataSetFeatureMaskMapping, List<String> dataSetLabelMaskMapping, List<String> lossVariables)
Create a training configuration suitable for training both single input/output and multi input/output networks. See also TrainingConfig.Builder for creating a TrainingConfig.
Parameters:
- updater: The updater configuration to use
- regularization: Regularization for all trainable parameters
- minimize: Set to true if the loss function should be minimized (usually true); false to maximize it
- dataSetFeatureMapping: The names of the placeholders/variables that should be set using the feature INDArray(s) from the DataSet or MultiDataSet. For example, if the network had two inputs called "input1" and "input2", and the MultiDataSet features should be mapped as MultiDataSet.getFeatures(0) -> "input1" and MultiDataSet.getFeatures(1) -> "input2", then this should be set to a list containing "input1" and "input2", in that order.
- dataSetLabelMapping: As per dataSetFeatureMapping, but for the DataSet/MultiDataSet labels
- dataSetFeatureMaskMapping: May be null. If non-null, the variables that the MultiDataSet feature mask arrays should be associated with.
- dataSetLabelMaskMapping: May be null. If non-null, the variables that the MultiDataSet label mask arrays should be associated with.

protected TrainingConfig(IUpdater updater, List<Regularization> regularization, boolean minimize, List<String> dataSetFeatureMapping, List<String> dataSetLabelMapping, List<String> dataSetFeatureMaskMapping, List<String> dataSetLabelMaskMapping, List<String> lossVariables, Map<String,List<IEvaluation>> trainEvaluations, Map<String,Integer> trainEvaluationLabels, Map<String,List<IEvaluation>> validationEvaluations, Map<String,Integer> validationEvaluationLabels)
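The minimize flag documented for the eight-argument constructor decides whether the optimizer descends or ascends the loss. The sketch below shows only that sign choice on a scalar parameter; `MinimizeFlagSketch.step` is hypothetical, uses plain gradient steps, and says nothing about how SameDiff applies the flag internally.

```java
import java.util.function.DoubleUnaryOperator;

// Hypothetical illustration of a minimize flag: the same gradient is used
// either to descend (minimize = true) or to ascend (minimize = false).
public class MinimizeFlagSketch {

    /** One gradient step on a scalar parameter x. */
    static double step(double x, DoubleUnaryOperator gradient, double lr, boolean minimize) {
        double g = gradient.applyAsDouble(x);
        return minimize ? x - lr * g : x + lr * g;
    }

    public static void main(String[] args) {
        // f(x) = (x - 3)^2 has gradient 2(x - 3) and its minimum at x = 3.
        DoubleUnaryOperator grad = v -> 2 * (v - 3);
        double x = 0.0;
        for (int i = 0; i < 100; i++) {
            x = step(x, grad, 0.1, true);
        }
        System.out.println(x); // converges toward 3
    }
}
```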

public void incrementIterationCount()
Increment the iteration count by 1.

public void incrementEpochCount()
Increment the epoch count by 1.
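The two increment methods suggest the usual convention of one iteration per minibatch and one epoch per full pass over the data. The sketch below assumes that convention; `TrainingCountersSketch` and its getters are hypothetical and only mirror the method names above.

```java
// Hypothetical sketch of how iteration and epoch counters are typically
// advanced in a training loop: once per minibatch and once per full pass.
public class TrainingCountersSketch {

    private int iterationCount;
    private int epochCount;

    void incrementIterationCount() { iterationCount++; }

    void incrementEpochCount() { epochCount++; }

    int getIterationCount() { return iterationCount; }

    int getEpochCount() { return epochCount; }

    /** Simulate training: numEpochs passes over numBatches minibatches each. */
    static TrainingCountersSketch simulate(int numEpochs, int numBatches) {
        TrainingCountersSketch counters = new TrainingCountersSketch();
        for (int epoch = 0; epoch < numEpochs; epoch++) {
            for (int batch = 0; batch < numBatches; batch++) {
                // ... fit one minibatch here ...
                counters.incrementIterationCount();
            }
            counters.incrementEpochCount();
        }
        return counters;
    }

    public static void main(String[] args) {
        TrainingCountersSketch c = simulate(3, 5);
        System.out.println("iterations=" + c.getIterationCount() + ", epochs=" + c.getEpochCount());
    }
}
```

Counters like these are what an iteration- or epoch-based learning rate schedule would typically read.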
public static TrainingConfig.Builder builder()
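builder() returns a TrainingConfig.Builder, which avoids calling the long constructors directly. The following stripped-down config shows the builder pattern itself. Everything in it is hypothetical: the real TrainingConfig.Builder has its own method set (check its Javadoc), learningRate(...) is a simplification of passing an IUpdater, and the 1e-3 default and the build-time validation rule are assumptions.

```java
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;

// Hypothetical, stripped-down config + builder illustrating the builder
// pattern. Field names mirror the TrainingConfig constructor parameters;
// the updater is reduced to a single learning rate.
public class ConfigBuilderSketch {

    final double learningRate;
    final boolean minimize;
    final List<String> dataSetFeatureMapping;
    final List<String> dataSetLabelMapping;

    private ConfigBuilderSketch(Builder b) {
        this.learningRate = b.learningRate;
        this.minimize = b.minimize;
        // Defensive, immutable copies so the built config cannot change later
        this.dataSetFeatureMapping = Collections.unmodifiableList(new ArrayList<>(b.features));
        this.dataSetLabelMapping = Collections.unmodifiableList(new ArrayList<>(b.labels));
    }

    static Builder builder() {
        return new Builder();
    }

    static final class Builder {
        private double learningRate = 1e-3; // assumed default for this sketch
        private boolean minimize = true;    // "usually true" per the docs
        private final List<String> features = new ArrayList<>();
        private final List<String> labels = new ArrayList<>();

        Builder learningRate(double lr) { this.learningRate = lr; return this; }

        Builder minimize(boolean minimize) { this.minimize = minimize; return this; }

        Builder dataSetFeatureMapping(String... names) {
            Collections.addAll(features, names);
            return this;
        }

        Builder dataSetLabelMapping(String... names) {
            Collections.addAll(labels, names);
            return this;
        }

        /** Validate (assumed rule: both mappings required) and create the config. */
        ConfigBuilderSketch build() {
            if (features.isEmpty() || labels.isEmpty()) {
                throw new IllegalStateException("Feature and label mappings must be set");
            }
            return new ConfigBuilderSketch(this);
        }
    }

    public static void main(String[] args) {
        ConfigBuilderSketch cfg = ConfigBuilderSketch.builder()
                .learningRate(0.01)
                .dataSetFeatureMapping("input")
                .dataSetLabelMapping("label")
                .build();
        System.out.println(cfg.dataSetFeatureMapping + " -> " + cfg.dataSetLabelMapping);
    }
}
```

Named setter calls make it obvious which argument is which, something the eight- and twelve-argument constructors cannot offer.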

public int labelIdx(String s)
Get the index of the label array that the specified variable is associated with.
Parameters:
- s: Name of the variable

public static void removeInstances(List<?> list, Class<?> remove)
Remove any instances of the specified type from the list.
Parameters:
- list: List. May be null
- remove: Type of objects to remove

public static void removeInstancesWithWarning(List<?> list, Class<?> remove, String warning)
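A hedged sketch of removeInstances and removeInstancesWithWarning using only the JDK. Matching via `Class.isInstance` (so subclasses are removed too) and logging the warning once, only when something was actually removed, are assumptions; the documentation above does not specify either. Both methods are re-implementations for illustration, not the library code.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.logging.Logger;

// Hypothetical sketch of removeInstances / removeInstancesWithWarning:
// drop every element that is an instance of the given class (subclasses
// included), tolerate a null list, and optionally log a warning.
public class RemoveInstancesSketch {

    private static final Logger LOG = Logger.getLogger(RemoveInstancesSketch.class.getName());

    /** Remove all instances of {@code remove} from {@code list}; no-op if list is null. */
    static void removeInstances(List<?> list, Class<?> remove) {
        if (list == null) {
            return;
        }
        list.removeIf(remove::isInstance);
    }

    /** As above, but log {@code warning} once if at least one element was removed (assumed). */
    static void removeInstancesWithWarning(List<?> list, Class<?> remove, String warning) {
        if (list == null) {
            return;
        }
        if (list.removeIf(remove::isInstance)) {
            LOG.warning(warning);
        }
    }

    public static void main(String[] args) {
        List<Object> items = new ArrayList<>(Arrays.asList(1, "a", 2.5, "b"));
        removeInstances(items, String.class);
        System.out.println(items);
    }
}
```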
public String toJson()
public static TrainingConfig fromJson(@NonNull String json)
Copyright © 2020. All rights reserved.