Class TransferLearningHelper
java.lang.Object
    org.deeplearning4j.nn.transferlearning.TransferLearningHelper

public class TransferLearningHelper extends Object

Constructor Summary
- TransferLearningHelper(ComputationGraph orig): Expects a computation graph in which some vertices are already frozen.
- TransferLearningHelper(ComputationGraph orig, String... frozenOutputAt): Modifies the given computation graph (in place!) to freeze the vertices from the input up to the specified vertex (or vertices).
- TransferLearningHelper(MultiLayerNetwork orig): Expects a MultiLayerNetwork in which some layers are already frozen.
- TransferLearningHelper(MultiLayerNetwork orig, int frozenTill): Modifies the given MultiLayerNetwork (in place!) to freeze the specified layer and all layers below it, holding their parameters constant during training.

Method Summary
- void errorIfGraphIfMLN()
- DataSet featurize(DataSet input): During training, frozen vertices/layers can be treated as "featurizing" the input. The forward pass through these frozen layers/vertices can be done in advance and the resulting dataset saved to disk, so that training can iterate quickly on the smaller unfrozen part of the model. Datasets with feature masks are currently not supported.
- MultiDataSet featurize(MultiDataSet input): The MultiDataSet counterpart of featurize(DataSet), with the same behavior and the same restriction on feature masks.
- void fitFeaturized(DataSetIterator iter)
- void fitFeaturized(MultiDataSetIterator iter): Fits from a featurized dataset.
- void fitFeaturized(DataSet input)
- void fitFeaturized(MultiDataSet input)
- INDArray outputFromFeaturized(INDArray input): Returns the output for a featurized input.
- INDArray[] outputFromFeaturized(INDArray[] input): Returns the outputs for featurized inputs.
- ComputationGraph unfrozenGraph(): Returns the unfrozen subset of the original computation graph as a ComputationGraph. Note that each call to fitFeaturized also updates the parameters of the original computation graph.
- MultiLayerNetwork unfrozenMLN(): Returns the unfrozen layers of the original MultiLayerNetwork as a MultiLayerNetwork. Note that each call to fitFeaturized also updates the parameters of the original MLN.

Constructor Detail

TransferLearningHelper
public TransferLearningHelper(ComputationGraph orig, String... frozenOutputAt)
Modifies the given computation graph (in place!) to freeze the vertices from the input up to the specified vertex.
Parameters:
orig - the computation graph
frozenOutputAt - name(s) of the vertex (or vertices) to freeze at; frozen parameters are held constant during training

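One reading of "freeze vertices from input to the vertex specified" is: the named vertices plus every vertex they transitively depend on. The hypothetical, dependency-free Java toy below sketches that reading with a plain map from each vertex name to the names of its inputs. It does not use DL4J's ComputationGraph, and FrozenVertexSketch and verticesToFreeze are illustrative names, not library API.

```java
import java.util.ArrayDeque;
import java.util.Arrays;
import java.util.Deque;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;

// Hypothetical toy, not DL4J code: collects the vertices between the network inputs
// and the named vertices by walking backwards along each vertex's inputs.
final class FrozenVertexSketch {

    // inputsOf maps each vertex name to the names of the vertices that feed into it.
    static Set<String> verticesToFreeze(Map<String, List<String>> inputsOf, String... frozenOutputAt) {
        Set<String> frozen = new LinkedHashSet<>();
        Deque<String> toVisit = new ArrayDeque<>(Arrays.asList(frozenOutputAt));
        while (!toVisit.isEmpty()) {
            String vertex = toVisit.removeFirst();
            if (frozen.add(vertex)) {
                // First visit: queue its inputs so everything upstream is frozen too.
                toVisit.addAll(inputsOf.getOrDefault(vertex, List.of()));
            }
        }
        return frozen;
    }

    public static void main(String[] args) {
        Map<String, List<String>> inputsOf = Map.of(
                "in", List.of(),
                "dense1", List.of("in"),
                "dense2", List.of("dense1"),
                "out", List.of("dense2"));
        // Freezing at "dense1" leaves "dense2" and "out" trainable.
        System.out.println(verticesToFreeze(inputsOf, "dense1"));
    }
}
```

Vertices downstream of the named ones are never visited, so they stay trainable in this toy.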
TransferLearningHelper
public TransferLearningHelper(ComputationGraph orig)
Expects a computation graph in which some vertices are already frozen.
Parameters:
orig - the computation graph, with some vertices already frozen

TransferLearningHelper
public TransferLearningHelper(MultiLayerNetwork orig, int frozenTill)
Modifies the given MultiLayerNetwork (in place!) to freeze the specified layer and all layers below it, holding their parameters constant during training.
Parameters:
orig - the MultiLayerNetwork to freeze
frozenTill - index of the topmost layer to freeze; this layer and all layers below it are frozen

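The hypothetical, dependency-free toy below illustrates what freezing "the specified layer and all layers below it" means during training: every layer receives a plain SGD update except those whose index is at most frozenTill. It is not DL4J's implementation; FrozenTillSketch, sgdStep and the one-array-per-layer parameter layout are assumptions made for illustration.

```java
import java.util.Arrays;

// Hypothetical toy, not DL4J code: layers 0..frozenTill keep their parameters,
// every layer above frozenTill takes a plain SGD step.
final class FrozenTillSketch {

    static void sgdStep(double[][] layerParams, double[][] layerGrads, int frozenTill, double learningRate) {
        for (int layer = 0; layer < layerParams.length; layer++) {
            if (layer <= frozenTill) {
                continue; // frozen: the specified layer and everything below it stay constant
            }
            for (int i = 0; i < layerParams[layer].length; i++) {
                layerParams[layer][i] -= learningRate * layerGrads[layer][i];
            }
        }
    }

    public static void main(String[] args) {
        double[][] params = {{1.0, 1.0}, {1.0, 1.0}, {1.0, 1.0}};
        double[][] grads = {{0.5, 0.5}, {0.5, 0.5}, {0.5, 0.5}};
        sgdStep(params, grads, 1, 0.1); // freeze layers 0 and 1; only layer 2 changes
        System.out.println(Arrays.deepToString(params));
    }
}
```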
TransferLearningHelper
public TransferLearningHelper(MultiLayerNetwork orig)
Expects a MultiLayerNetwork in which some layers are already frozen.
Parameters:
orig - the MultiLayerNetwork, with some layers already frozen

Method Detail

errorIfGraphIfMLN
public void errorIfGraphIfMLN()

unfrozenGraph
public ComputationGraph unfrozenGraph()
Returns the unfrozen subset of the original computation graph as a ComputationGraph. Note that each call to fitFeaturized also updates the parameters of the original computation graph.

unfrozenMLN
public MultiLayerNetwork unfrozenMLN()
Returns the unfrozen layers of the original MultiLayerNetwork as a MultiLayerNetwork. Note that each call to fitFeaturized also updates the parameters of the original MLN.

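This page does not say whether the unfrozen MultiLayerNetwork shares parameter storage with the original or copies values back after training. The hypothetical toy below assumes the copy-back variant, since it matches the statement that the parameters of the original MLN are also updated: the subset is a deep copy of the layers above frozenTill, and writeBack pushes its values into the original. UnfrozenSubsetSketch, unfrozenSubset and writeBack are illustrative names, not DL4J API.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

// Hypothetical toy, not DL4J code: the "unfrozen subset" is a copy of the layers above
// frozenTill, and its parameters are written back into the original after training.
final class UnfrozenSubsetSketch {

    // Deep-copies the parameters of layers frozenTill+1 .. last.
    static List<double[]> unfrozenSubset(List<double[]> originalLayers, int frozenTill) {
        List<double[]> subset = new ArrayList<>();
        for (int layer = frozenTill + 1; layer < originalLayers.size(); layer++) {
            subset.add(originalLayers.get(layer).clone());
        }
        return subset;
    }

    // Copies the subset's (possibly trained) parameters back over the original unfrozen layers.
    static void writeBack(List<double[]> originalLayers, List<double[]> subset, int frozenTill) {
        for (int i = 0; i < subset.size(); i++) {
            double[] target = originalLayers.get(frozenTill + 1 + i);
            System.arraycopy(subset.get(i), 0, target, 0, target.length);
        }
    }

    public static void main(String[] args) {
        List<double[]> original = List.of(new double[] {1, 2}, new double[] {3, 4}, new double[] {5, 6});
        List<double[]> subset = unfrozenSubset(original, 0); // layers 1 and 2
        subset.get(1)[0] = 50; // pretend training changed the last layer
        writeBack(original, subset, 0);
        original.forEach(p -> System.out.println(Arrays.toString(p)));
    }
}
```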
outputFromFeaturized
public INDArray[] outputFromFeaturized(INDArray[] input)
Returns the outputs for featurized input, that is, input that has already been passed through the frozen part of the model.
Parameters:
input - featurized data
Returns:
output

outputFromFeaturized
public INDArray outputFromFeaturized(INDArray input)
Returns the output for featurized input, that is, input that has already been passed through the frozen part of the model.
Parameters:
input - featurized data
Returns:
output

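Because featurized input has already passed through the frozen part, producing output from it only requires the unfrozen part. The hypothetical toy below, which uses two plain Java functions in place of a real network, checks that running the full model on x gives the same result as running only the unfrozen part on the cached frozen output. FROZEN, UNFROZEN, fullOutput and outputFromCached are illustrative names, not DL4J API.

```java
import java.util.function.UnaryOperator;

// Hypothetical toy, not DL4J code: a "model" split into a frozen part and an unfrozen part.
final class OutputFromFeaturizedSketch {

    static final UnaryOperator<double[]> FROZEN = x -> new double[] {2 * x[0], x[0] + x[1]};
    static final UnaryOperator<double[]> UNFROZEN = f -> new double[] {f[0] - f[1]};

    // Full forward pass: frozen part first, then unfrozen part.
    static double[] fullOutput(double[] x) {
        return UNFROZEN.apply(FROZEN.apply(x));
    }

    // Input that was already featurized skips the frozen part entirely.
    static double[] outputFromCached(double[] featurized) {
        return UNFROZEN.apply(featurized);
    }

    public static void main(String[] args) {
        double[] x = {3.0, 1.0};
        double[] cached = FROZEN.apply(x); // computed once, as featurize would
        System.out.println(fullOutput(x)[0] + " == " + outputFromCached(cached)[0]);
    }
}
```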
featurize
public MultiDataSet featurize(MultiDataSet input)
During training, frozen vertices/layers can be treated as "featurizing" the input. The forward pass through these frozen layers/vertices can be done in advance and the resulting dataset saved to disk, so that training can iterate quickly on the smaller unfrozen part of the model. Datasets with feature masks are currently not supported.
Parameters:
input - MultiDataSet to feed into the computation graph with frozen layers/vertices
Returns:
a MultiDataSet whose input features are the outputs of the frozen layers/vertices, paired with the original labels

featurize
public DataSet featurize(DataSet input)
During training, frozen vertices/layers can be treated as "featurizing" the input. The forward pass through these frozen layers/vertices can be done in advance and the resulting dataset saved to disk, so that training can iterate quickly on the smaller unfrozen part of the model. Datasets with feature masks are currently not supported.
Parameters:
input - DataSet to feed into the network with frozen layers/vertices
Returns:
a DataSet whose input features are the outputs of the frozen layers/vertices, paired with the original labels

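A hypothetical, dependency-free sketch of the featurizing idea: each example's features are run through the frozen part once and replaced by its output, while the labels pass through unchanged. The frozen part is modelled as a UnaryOperator<double[]>, and FeaturizeSketch and Example are illustrative stand-ins for DL4J's DataSet, not its API; saving the result to disk is not shown.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.function.UnaryOperator;

// Hypothetical toy, not DL4J code: "featurizing" means running the frozen part once per
// example and keeping the original labels, so later epochs only need the unfrozen part.
final class FeaturizeSketch {

    static final class Example {
        final double[] features;
        final double[] labels;

        Example(double[] features, double[] labels) {
            this.features = features;
            this.labels = labels;
        }
    }

    static List<Example> featurize(List<Example> data, UnaryOperator<double[]> frozenForward) {
        List<Example> featurized = new ArrayList<>(data.size());
        for (Example e : data) {
            // Features are replaced by the frozen part's output; labels pass through unchanged.
            featurized.add(new Example(frozenForward.apply(e.features), e.labels));
        }
        return featurized;
    }

    public static void main(String[] args) {
        UnaryOperator<double[]> frozen = x -> new double[] {x[0] + x[1]};
        List<Example> data = List.of(
                new Example(new double[] {1, 2}, new double[] {0}),
                new Example(new double[] {3, 4}, new double[] {1}));
        for (Example e : featurize(data, frozen)) {
            System.out.println(e.features[0] + " -> label " + e.labels[0]);
        }
    }
}
```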
fitFeaturized
public void fitFeaturized(MultiDataSetIterator iter)
Fits from a featurized dataset. The fit is conducted on an internally instantiated subset model that represents the unfrozen part of the original model. After each call to fit, the parameters of the original model are updated.
Parameters:
iter - iterator over featurized MultiDataSets

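The hypothetical toy below mimics the fitFeaturized flow on a model with one weight per part, y = head * (frozen * x): only the head weight is trained, on precomputed featurized pairs (f, y) with f = frozen * x, and the trained value is written back into the original parameter array after each call. Plain SGD on squared error is an assumption chosen for brevity; FitFeaturizedSketch is an illustrative name, not DL4J code.

```java
// Hypothetical toy, not DL4J code: params[0] is a frozen scalar weight, params[1] the
// trainable head weight; the model is y = params[1] * (params[0] * x).
final class FitFeaturizedSketch {

    // One pass of SGD on squared error over featurized pairs {f, y}, where f = params[0] * x.
    static void fitFeaturized(double[] originalParams, double[][] featurized, double learningRate) {
        double head = originalParams[1]; // the "subset model": only the unfrozen weight
        for (double[] example : featurized) {
            double f = example[0];
            double y = example[1];
            double grad = 2 * (head * f - y) * f; // derivative of (head * f - y)^2 w.r.t. head
            head -= learningRate * grad;
        }
        originalParams[1] = head; // write back, so the original model sees the update
    }

    public static void main(String[] args) {
        double[] params = {2.0, 0.0}; // frozen weight 2, head starts at 0
        double[] xs = {1.0, 2.0, 3.0};
        double[][] featurized = new double[xs.length][];
        for (int i = 0; i < xs.length; i++) {
            double f = params[0] * xs[i]; // frozen forward pass, done once
            featurized[i] = new double[] {f, 3.0 * f}; // targets chosen so the ideal head weight is 3
        }
        for (int epoch = 0; epoch < 50; epoch++) {
            fitFeaturized(params, featurized, 0.01);
        }
        System.out.println("frozen=" + params[0] + " head=" + params[1]);
    }
}
```

Since the featurized pairs are computed once, each extra epoch costs only the head's work, which is the speed-up that pairing featurize with fitFeaturized is meant to provide.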
fitFeaturized
public void fitFeaturized(MultiDataSet input)

fitFeaturized
public void fitFeaturized(DataSet input)

fitFeaturized
public void fitFeaturized(DataSetIterator iter)