Class BaseMultiLayerUpdater<T extends Model>

    • Field Detail

      • network

        protected final T network
      • updaterStateViewArray

        protected INDArray updaterStateViewArray
      • initializedMinibatchDivision

        protected boolean initializedMinibatchDivision
      • gradientsForMinibatchDivision

        protected List<INDArray> gradientsForMinibatchDivision
    • Constructor Detail

      • BaseMultiLayerUpdater

        public BaseMultiLayerUpdater(T network)
      • BaseMultiLayerUpdater

        public BaseMultiLayerUpdater(T network,
                                     INDArray updaterState)
        Parameters:
        network - Network to create the updater for
        updaterState - The updater state to use. Note: this array is used *directly* and is not copied or cloned.
    • Method Detail

      • getOrderedLayers

        protected abstract Trainable[] getOrderedLayers()
        Returns:
        Array of layers, in the correct order (i.e., same order as the parameter/gradient/updater flattening order - input to output for MultiLayerNetwork, or topological order for ComputationGraph)
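        The ordering matters because parameters, gradients and updater state each live in one contiguous flattened array, and each layer works on its own sub-range of it. A minimal sketch of that idea on plain Java arrays, assuming a hypothetical `FlattenedViews` helper and using `java.nio.DoubleBuffer` slices as stand-ins for INDArray views (the layer sizes are illustrative):

```java
import java.nio.DoubleBuffer;
import java.util.ArrayList;
import java.util.List;

// Hypothetical illustration: carve one flattened array into per-layer views,
// in the same order the layers are returned (e.g. input to output).
public class FlattenedViews {

    // Returns one view per layer; each view shares storage with 'flat' (no copying).
    static List<DoubleBuffer> perLayerViews(double[] flat, int[] numParamsPerLayer) {
        List<DoubleBuffer> views = new ArrayList<>();
        int offset = 0;
        for (int n : numParamsPerLayer) {
            // wrap(...) positions the buffer at 'offset'; slice() makes that position index 0
            views.add(DoubleBuffer.wrap(flat, offset, n).slice());
            offset += n;
        }
        if (offset != flat.length) {
            throw new IllegalArgumentException("Layer sizes do not sum to the flattened length");
        }
        return views;
    }

    public static void main(String[] args) {
        double[] flat = new double[7];                 // 3 layers with 2, 4 and 1 parameters
        List<DoubleBuffer> views = perLayerViews(flat, new int[] {2, 4, 1});
        views.get(1).put(0, 5.0);                      // write through the second layer's view
        System.out.println(flat[2]);                   // 5.0: the write is visible in the flat array
    }
}
```

        If the views were built in a different order than the flattening order, each layer would silently read another layer's values, which is why the order must match.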
      • getFlattenedGradientsView

        public abstract INDArray getFlattenedGradientsView()
        Returns:
        The flattened gradient view array for the model
      • getParams

        protected abstract INDArray getParams()
        Returns:
        The flattened parameter array for the model
      • isMiniBatch

        protected abstract boolean isMiniBatch()
        Returns:
        True if the configuration for the model is set to minibatch (divide by minibatch size), false otherwise
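        When this returns true, a gradient that was summed over the examples in a minibatch is scaled by 1/batchSize. A minimal sketch on a plain double[], assuming in-place division; the `MinibatchDivision` class is hypothetical and is not the library's divideByMinibatch:

```java
// Hypothetical sketch of "divide by minibatch size", applied in place.
public class MinibatchDivision {

    static void divideByMinibatch(double[] gradient, int batchSize, boolean isMiniBatch) {
        if (!isMiniBatch) {
            return;                                    // configuration says: leave the gradient as-is
        }
        if (batchSize <= 0) {
            throw new IllegalArgumentException("batchSize must be positive");
        }
        for (int i = 0; i < gradient.length; i++) {
            gradient[i] /= batchSize;                  // in place, as if operating on a view array
        }
    }

    public static void main(String[] args) {
        double[] g = {4.0, -8.0, 2.0};                 // gradient summed over 4 examples
        divideByMinibatch(g, 4, true);
        System.out.println(java.util.Arrays.toString(g));  // [1.0, -2.0, 0.5]
    }
}
```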
      • setStateViewArray

        public void setStateViewArray(INDArray viewArray)
        Set the view array. Note that this does an assign operation - the provided array is not stored internally.
        Parameters:
        viewArray - The new updater state
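        The constructor and this method follow different ownership rules: the two-argument constructor uses the caller's array directly, whereas setStateViewArray(INDArray) copies values into the existing state array and keeps no reference to its argument. The hypothetical `StateOwnership` class below contrasts the two, with plain double[] arrays standing in for INDArray:

```java
import java.util.Arrays;

// Hypothetical class contrasting "used directly" (constructor) with "assign" (setter).
public class StateOwnership {
    private final double[] state;

    StateOwnership(double[] updaterState) {
        this.state = updaterState;                     // used directly: caller and updater share it
    }

    void setStateViewArray(double[] viewArray) {
        if (viewArray.length != state.length) {
            throw new IllegalArgumentException("Length mismatch");
        }
        System.arraycopy(viewArray, 0, state, 0, state.length);  // assign: copy values in
    }

    double[] state() {
        return state;
    }

    public static void main(String[] args) {
        double[] shared = new double[3];
        StateOwnership u = new StateOwnership(shared);
        shared[0] = 1.0;                               // visible to the updater (same array)

        double[] replacement = {7.0, 8.0, 9.0};
        u.setStateViewArray(replacement);
        replacement[0] = -1.0;                         // not visible: values were copied
        System.out.println(Arrays.toString(u.state()));  // [7.0, 8.0, 9.0]
        System.out.println(shared == u.state());       // true
    }
}
```

        A practical consequence: after the setter returns, the caller may reuse or discard its array freely, while an array passed to the constructor must not be modified unless the change is intended to affect the updater.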
      • setStateViewArray

        public void setStateViewArray(Trainable layer,
                                      INDArray viewArray,
                                      boolean initialize)
        Description copied from interface: Updater
        Set the internal (historical) state view array for this updater.
        Specified by:
        setStateViewArray in interface Updater
        Parameters:
        layer - Layer that this updater belongs to
        viewArray - View array
        initialize - Whether to initialize the array or not
      • getStateViewArrayCopy

        public INDArray getStateViewArrayCopy()
        A synchronized version of getStateViewArray() that duplicates the view array internally. This should be used in preference to getStateViewArray() when the updater state is accessed in one thread while another thread is using the updater for training.
        Returns:
        A copy (duplicate) of the updater state
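        A sketch of why the synchronized copy matters, assuming (hypothetically) that training mutates the state inside a method synchronized on the same object. `SynchronizedStateCopy` is illustrative only and uses double[] in place of INDArray:

```java
import java.util.Arrays;

// Hypothetical sketch: update() and getStateViewArrayCopy() share one lock,
// so a snapshot is never taken halfway through an update.
public class SynchronizedStateCopy {
    private final double[] state = new double[4];

    // Stand-in for one training step that modifies the updater state in place.
    synchronized void update(double delta) {
        for (int i = 0; i < state.length; i++) {
            state[i] += delta;
        }
    }

    // Live view: cheap, but its contents may change (or be mid-update) while the caller reads it.
    double[] getStateViewArray() {
        return state;
    }

    // Snapshot taken under the lock; later updates do not affect the returned copy.
    synchronized double[] getStateViewArrayCopy() {
        return state.clone();
    }

    public static void main(String[] args) throws InterruptedException {
        SynchronizedStateCopy u = new SynchronizedStateCopy();
        Thread trainer = new Thread(() -> {
            for (int i = 0; i < 1000; i++) {
                u.update(1.0);
            }
        });
        trainer.start();
        double[] snapshot = u.getStateViewArrayCopy();
        // All elements of a snapshot are equal, because each update runs entirely under the lock
        System.out.println(Arrays.stream(snapshot).distinct().count());  // 1
        trainer.join();
    }
}
```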
      • update

        public void update(Gradient gradient,
                           int iteration,
                           int epoch,
                           int batchSize,
                           LayerWorkspaceMgr workspaceMgr)
        Update the gradient for the model. This operates in 3 steps:
        1. Pre-apply: gradient clipping, etc., on a per-layer basis
        2. Execute the updater (Adam, Nesterov momentum, etc.) in blocks of layers at a time
        3. Divide by minibatch size
        Parameters:
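        gradient - Gradient to update
        iteration - The current iteration (i.e., number of parameter updates so far)
        epoch - The current epoch count
        batchSize - The current minibatch size (number of examples)
        workspaceMgr - Workspace manager used for the updater's working memory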
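        The three steps can be sketched on plain arrays as follows. This is a hypothetical illustration, not the library's implementation: it assumes element-wise clipping for step 1, plain SGD (step = learning rate * gradient) as the updater in step 2, and treats all layers as a single block instead of grouping them into updater blocks. As in the method described above, the gradient arrays are overwritten in place.

```java
import java.util.Arrays;

// Hypothetical sketch of the three update steps, in the order listed above.
public class ThreeStepUpdate {

    static void update(double[][] layerGradients, double learningRate,
                       double clipValue, int batchSize) {
        // 1. Pre-apply, per layer: clip each gradient element to [-clipValue, clipValue]
        for (double[] g : layerGradients) {
            for (int i = 0; i < g.length; i++) {
                g[i] = Math.max(-clipValue, Math.min(clipValue, g[i]));
            }
        }
        // 2. Run the updater: turn each gradient into the step to apply (here: plain SGD)
        for (double[] g : layerGradients) {
            for (int i = 0; i < g.length; i++) {
                g[i] *= learningRate;
            }
        }
        // 3. Divide by minibatch size
        for (double[] g : layerGradients) {
            for (int i = 0; i < g.length; i++) {
                g[i] /= batchSize;
            }
        }
    }

    public static void main(String[] args) {
        double[][] grads = { {10.0, -0.5}, {2.0} };   // two layers
        update(grads, 0.1, 1.0, 2);
        System.out.println(Arrays.deepToString(grads));
    }
}
```

        Because plain SGD is linear, swapping steps 2 and 3 would not change this sketch's result; with an adaptive updater such as Adam the order does matter.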
      • divideByMinibatch

        protected void divideByMinibatch(boolean isExternal,
                                         Gradient gradient,
                                         int batchSize)
      • getMinibatchDivisionSubsets

        protected List<INDArray> getMinibatchDivisionSubsets(INDArray from)
      • isSingleLayerUpdater

        protected boolean isSingleLayerUpdater()
      • preApply

        public void preApply(Trainable layer,
                             Gradient gradient,
                             int iteration)
        Pre-apply: Apply gradient normalization/clipping
        Parameters:
        layer - Layer to apply gradient normalization/clipping for
        gradient - Gradient to update
        iteration - The current iteration (i.e., number of parameter updates so far)
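        One common form of pre-apply is per-layer L2-norm clipping: if the norm of a layer's gradient exceeds a threshold, the gradient is rescaled so its norm equals the threshold. A minimal sketch under that assumption; the `L2Clipping` class and its `threshold` parameter are hypothetical, and the library's other normalization modes are not reproduced here:

```java
// Hypothetical pre-apply step: clip one layer's gradient by its L2 norm.
public class L2Clipping {

    static double l2Norm(double[] g) {
        double sumSq = 0.0;
        for (double v : g) {
            sumSq += v * v;
        }
        return Math.sqrt(sumSq);
    }

    // If ||g||_2 > threshold, rescale g in place to have norm == threshold; else leave it unchanged.
    static void clipL2PerLayer(double[] layerGradient, double threshold) {
        double norm = l2Norm(layerGradient);
        if (norm > threshold) {
            double scale = threshold / norm;
            for (int i = 0; i < layerGradient.length; i++) {
                layerGradient[i] *= scale;
            }
        }
    }

    public static void main(String[] args) {
        double[] g = {3.0, 4.0};                       // L2 norm = 5
        clipL2PerLayer(g, 1.0);                        // rescaled to (approximately) unit norm
        System.out.println(l2Norm(g));
    }
}
```

        Clipping by norm keeps the gradient's direction and only shrinks its magnitude, unlike element-wise clipping, which can change the direction.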
      • hashCode

        public int hashCode()
        Overrides:
        hashCode in class Object