Class TransferLearningHelper


  • public class TransferLearningHelper
    extends Object
    • Constructor Detail

      • TransferLearningHelper

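        public TransferLearningHelper(ComputationGraph orig,
                                      String... frozenOutputAt)
        Modifies the given computation graph in place, freezing every vertex from the input up to and including the specified vertices.
        Parameters:
        orig - the computation graph to freeze
        frozenOutputAt - name(s) of the vertex or vertices to freeze at; these vertices and all vertices between them and the input hold their parameters constant during training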
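To make "from the input up to the specified vertices" concrete, the sketch below models the rule in plain Java rather than DL4J: it walks backwards from the frozenOutputAt vertices and collects every vertex it reaches. The map-based graph, the class FrozenVerticesDemo and the vertex names are hypothetical; this illustrates the freezing rule under that assumption and is not DL4J's implementation.

```java
import java.util.ArrayDeque;
import java.util.Arrays;
import java.util.Deque;
import java.util.HashMap;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;

// Toy illustration, NOT DL4J code: a graph is modelled as a map from each
// (hypothetical) vertex name to the names of the vertices that feed into it.
public class FrozenVerticesDemo {

    // Collects the frozenOutputAt vertices plus every vertex reachable by walking
    // backwards from them towards the network input.
    static Set<String> verticesToFreeze(Map<String, List<String>> inputsOf,
                                        String... frozenOutputAt) {
        Set<String> frozen = new LinkedHashSet<>();
        Deque<String> pending = new ArrayDeque<>(Arrays.asList(frozenOutputAt));
        while (!pending.isEmpty()) {
            String vertex = pending.pop();
            if (frozen.add(vertex)) {
                // First visit: queue this vertex's inputs so they are frozen too
                for (String input : inputsOf.getOrDefault(vertex, List.of())) {
                    pending.push(input);
                }
            }
        }
        return frozen;
    }

    public static void main(String[] args) {
        // input -> dense0 -> dense1 -> out
        Map<String, List<String>> inputsOf = new HashMap<>();
        inputsOf.put("dense0", List.of("input"));
        inputsOf.put("dense1", List.of("dense0"));
        inputsOf.put("out", List.of("dense1"));

        // Freezing at "dense1" freezes dense1, dense0 and input; "out" stays trainable
        Set<String> frozen = verticesToFreeze(inputsOf, "dense1");
        System.out.println(frozen.contains("out")); // false
    }
}
```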
      • TransferLearningHelper

        public TransferLearningHelper(ComputationGraph orig)
        Expects a computation graph in which some vertices are already frozen.
        Parameters:
        orig - a computation graph containing frozen vertices
      • TransferLearningHelper

        public TransferLearningHelper(MultiLayerNetwork orig,
                                      int frozenTill)
        Modifies the given MultiLayerNetwork in place, freezing the specified layer and all layers below it (their parameters are held constant during training).
        Parameters:
        orig - the MultiLayerNetwork to freeze
        frozenTill - index of the last layer to freeze; layers 0 through frozenTill (inclusive) are frozen
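A minimal plain-Java sketch of the frozenTill rule as documented: the layer at index frozenTill and every layer below it are frozen. FrozenTillDemo and its methods are hypothetical illustrations, not DL4J API.

```java
import java.util.ArrayList;
import java.util.List;

// Toy illustration, NOT DL4J code: models which layer indices the
// MultiLayerNetwork constructor freezes for a given frozenTill.
public class FrozenTillDemo {

    // A layer is frozen if its index is at or below frozenTill
    static boolean isFrozen(int layerIndex, int frozenTill) {
        return layerIndex <= frozenTill;
    }

    // Lists the frozen layer indices for a network with numLayers layers
    static List<Integer> frozenLayers(int numLayers, int frozenTill) {
        List<Integer> frozen = new ArrayList<>();
        for (int i = 0; i < numLayers; i++) {
            if (isFrozen(i, frozenTill)) {
                frozen.add(i);
            }
        }
        return frozen;
    }

    public static void main(String[] args) {
        // A 5-layer network with frozenTill = 2
        System.out.println(frozenLayers(5, 2)); // [0, 1, 2]
    }
}
```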
      • TransferLearningHelper

        public TransferLearningHelper(MultiLayerNetwork orig)
        Expects a MultiLayerNetwork in which some layers are already frozen.
        Parameters:
        orig - a MultiLayerNetwork containing frozen layers
    • Method Detail

      • errorIfGraphIfMLN

        public void errorIfGraphIfMLN()
      • unfrozenGraph

        public ComputationGraph unfrozenGraph()
        Returns the unfrozen subset of the original computation graph as a ComputationGraph. Note that with each call to fitFeaturized, the parameters of the original computation graph are also updated.
      • unfrozenMLN

        public MultiLayerNetwork unfrozenMLN()
        Returns the unfrozen layers of the original MultiLayerNetwork as a MultiLayerNetwork. Note that with each call to fitFeaturized, the parameters of the original MultiLayerNetwork are also updated.
      • outputFromFeaturized

        public INDArray[] outputFromFeaturized(INDArray[] input)
        Returns the model output for an input that has already been featurized (passed through the frozen vertices/layers).
        Parameters:
        input - featurized data
        Returns:
        the model output
      • outputFromFeaturized

        public INDArray outputFromFeaturized(INDArray input)
        Returns the model output for an input that has already been featurized (passed through the frozen vertices/layers).
        Parameters:
        input - featurized data
        Returns:
        the model output
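The relationship between featurize and outputFromFeaturized can be sketched with a toy model in plain Java, assuming the network is just a list of scalar functions applied in order: running the frozen layers and then the unfrozen ones gives the same result as a full forward pass. All names here (FeaturizedOutputDemo, LAYERS, FROZEN_TILL) are hypothetical; no DL4J classes are used.

```java
import java.util.List;
import java.util.function.DoubleUnaryOperator;

// Toy illustration, NOT DL4J code: a "network" is a list of scalar layers applied in order.
public class FeaturizedOutputDemo {

    static final List<DoubleUnaryOperator> LAYERS = List.of(
        x -> 2 * x,   // layer 0 (frozen)
        x -> x + 1,   // layer 1 (frozen)
        x -> x * x    // layer 2 (unfrozen)
    );
    static final int FROZEN_TILL = 1;

    // Applies layers with indices in [from, to) in order
    static double forward(double x, int from, int to) {
        for (int i = from; i < to; i++) {
            x = LAYERS.get(i).applyAsDouble(x);
        }
        return x;
    }

    // Analogue of featurize: output of the frozen layers only
    static double featurize(double x) {
        return forward(x, 0, FROZEN_TILL + 1);
    }

    // Analogue of outputFromFeaturized: run only the unfrozen layers
    static double outputFromFeaturized(double features) {
        return forward(features, FROZEN_TILL + 1, LAYERS.size());
    }

    // Full forward pass through every layer
    static double output(double x) {
        return forward(x, 0, LAYERS.size());
    }

    public static void main(String[] args) {
        double x = 3.0;
        System.out.println(output(x) + " " + outputFromFeaturized(featurize(x))); // 49.0 49.0
    }
}
```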
      • featurize

        public MultiDataSet featurize(MultiDataSet input)
        During training, the frozen vertices/layers can be treated as "featurizing" the input. The forward pass through these frozen vertices/layers can be done in advance and the resulting dataset saved to disk, so that training iterates quickly on the smaller, unfrozen part of the model. Datasets with feature masks are currently not supported.
        Parameters:
        input - the MultiDataSet to feed into the computation graph with frozen vertices
        Returns:
        a MultiDataSet whose input features are the outputs of the frozen vertices and whose labels are the original labels
      • featurize

        public DataSet featurize(DataSet input)
        During training, the frozen vertices/layers can be treated as "featurizing" the input. The forward pass through these frozen vertices/layers can be done in advance and the resulting dataset saved to disk, so that training iterates quickly on the smaller, unfrozen part of the model. Datasets with feature masks are currently not supported.
        Parameters:
        input - the DataSet to feed into the network with frozen layers/vertices
        Returns:
        a DataSet whose input features are the outputs of the frozen layers/vertices and whose labels are the original labels
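Why featurizing in advance pays off can be seen by counting frozen forward passes in a toy setting. This plain-Java sketch (hypothetical names, not DL4J) featurizes three examples once and trains a one-parameter "unfrozen part" for ten epochs, while the baseline recomputes the frozen pass in every epoch. Under these assumptions the cached run performs 3 frozen passes versus 30, and both runs end with the same weight because the frozen function never changes.

```java
import java.util.ArrayList;
import java.util.List;

// Toy illustration, NOT DL4J code: counts how often the frozen part of a model
// runs, with and without featurizing the dataset up front.
public class FeaturizeCacheDemo {

    static int frozenCalls = 0;   // number of frozen forward passes performed
    static double weight = 0.0;   // single parameter of the toy unfrozen part

    // Stand-in for the frozen layers: a fixed function whose parameters never change
    static double frozenForward(double x) {
        frozenCalls++;
        return 2 * x + 1;
    }

    // One SGD step on the unfrozen part, fitting label ~= weight * feature (squared error)
    static void trainUnfrozenStep(double feature, double label) {
        double error = weight * feature - label;
        weight -= 0.01 * 2 * error * feature;
    }

    // Analogue of featurize: one frozen pass per example, paired with the original label
    static List<double[]> featurize(double[] xs, double[] labels) {
        List<double[]> featurized = new ArrayList<>();
        for (int i = 0; i < xs.length; i++) {
            featurized.add(new double[] {frozenForward(xs[i]), labels[i]});
        }
        return featurized;
    }

    // Featurize once, then train only the unfrozen part on the cached features
    static int runWithCache(double[] xs, double[] labels, int epochs) {
        frozenCalls = 0;
        weight = 0.0;
        List<double[]> featurized = featurize(xs, labels);
        for (int e = 0; e < epochs; e++) {
            for (double[] example : featurized) {
                trainUnfrozenStep(example[0], example[1]);
            }
        }
        return frozenCalls;
    }

    // Baseline: recompute the frozen forward pass for every example in every epoch
    static int runWithoutCache(double[] xs, double[] labels, int epochs) {
        frozenCalls = 0;
        weight = 0.0;
        for (int e = 0; e < epochs; e++) {
            for (int i = 0; i < xs.length; i++) {
                trainUnfrozenStep(frozenForward(xs[i]), labels[i]);
            }
        }
        return frozenCalls;
    }

    public static void main(String[] args) {
        double[] xs = {0.0, 1.0, 2.0};
        double[] labels = {1.0, 3.0, 5.0};
        System.out.println(runWithCache(xs, labels, 10));    // 3
        System.out.println(runWithoutCache(xs, labels, 10)); // 30
    }
}
```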
      • fitFeaturized

        public void fitFeaturized(MultiDataSetIterator iter)
        Fits the model on a featurized dataset. The fit is performed on an internally instantiated subset model that represents the unfrozen part of the original model; after each call, the parameters of the original model are updated as well.
        Parameters:
        iter - iterator over featurized data
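The Javadoc says only that the original model's parameters are updated after each fit. One way to achieve that, assumed here purely for illustration, is to train a separate copy of the unfrozen layers and copy its parameters back afterwards. The sketch models layers as plain double[] arrays; CopyBackDemo, unfrozenSubset and fitAndCopyBack are hypothetical names, and a fixed nudge stands in for a real training step.

```java
import java.util.Arrays;

// Toy illustration, NOT DL4J code: each "layer" is just a double[] of parameters.
public class CopyBackDemo {

    // Analogue of unfrozenMLN(): independent copies of the layers above frozenTill
    static double[][] unfrozenSubset(double[][] original, int frozenTill) {
        double[][] subset = new double[original.length - frozenTill - 1][];
        for (int i = 0; i < subset.length; i++) {
            subset[i] = original[frozenTill + 1 + i].clone();
        }
        return subset;
    }

    // Analogue of fitFeaturized (assumed mechanism): update only the subset, where a
    // fixed nudge stands in for training, then copy its parameters into the original
    static void fitAndCopyBack(double[][] original, int frozenTill,
                               double[][] subset, double delta) {
        for (double[] layer : subset) {
            for (int j = 0; j < layer.length; j++) {
                layer[j] += delta;
            }
        }
        for (int i = 0; i < subset.length; i++) {
            System.arraycopy(subset[i], 0, original[frozenTill + 1 + i], 0, subset[i].length);
        }
    }

    public static void main(String[] args) {
        double[][] network = {{1.0, 1.0}, {2.0, 2.0}, {3.0, 3.0}};
        double[][] subset = unfrozenSubset(network, 1);   // only layer 2 is trainable
        fitAndCopyBack(network, 1, subset, 0.5);
        System.out.println(Arrays.deepToString(network)); // [[1.0, 1.0], [2.0, 2.0], [3.5, 3.5]]
    }
}
```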
      • fitFeaturized

        public void fitFeaturized(MultiDataSet input)
      • fitFeaturized

        public void fitFeaturized(DataSet input)