Class ComputationGraphUpdater
- java.lang.Object
  - org.deeplearning4j.nn.updater.BaseMultiLayerUpdater<ComputationGraph>
    - org.deeplearning4j.nn.updater.graph.ComputationGraphUpdater
All Implemented Interfaces:
Serializable, Updater
public class ComputationGraphUpdater extends BaseMultiLayerUpdater<ComputationGraph>
See Also: Serialized Form
Field Summary
Fields (Modifier and Type, Field):
- protected Trainable[] orderedLayers
Fields inherited from class org.deeplearning4j.nn.updater.BaseMultiLayerUpdater:
- gradientsForMinibatchDivision, initializedMinibatchDivision, layersByName, network, updaterBlocks, updaterStateViewArray
Constructor Summary
Constructors:
- ComputationGraphUpdater(ComputationGraph graph)
- ComputationGraphUpdater(ComputationGraph graph, INDArray updaterState)
Method Summary
Methods (Modifier and Type, Method):
- INDArray getFlattenedGradientsView()
- protected Trainable[] getOrderedLayers()
- protected INDArray getParams()
- protected boolean isMiniBatch()
Methods inherited from class org.deeplearning4j.nn.updater.BaseMultiLayerUpdater:
- divideByMinibatch, equals, getMinibatchDivisionSubsets, getStateViewArray, getStateViewArrayCopy, hashCode, isSingleLayerUpdater, preApply, setStateViewArray, setStateViewArray, update, update
Field Detail

orderedLayers
protected Trainable[] orderedLayers
Constructor Detail

ComputationGraphUpdater
public ComputationGraphUpdater(ComputationGraph graph)

ComputationGraphUpdater
public ComputationGraphUpdater(ComputationGraph graph, INDArray updaterState)
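The page does not describe the two constructors. A plausible reading, which this sketch assumes, is that the first allocates fresh updater state while the second adopts an existing updaterState array (for example, one restored along with a saved model). The class below is hypothetical plain Java, not DL4J code; its explicit stateSize argument stands in for the state size the real updater would presumably derive from the graph's layers and their updater configuration.

```java
import java.util.Arrays;

// Hypothetical sketch, not the DL4J implementation: mirrors the two
// ComputationGraphUpdater constructors with a plain double[] as "updater state".
class UpdaterStateSketch {
    private final double[] state;

    // Analogous to ComputationGraphUpdater(graph): allocate zero-initialized state.
    UpdaterStateSketch(int stateSize) {
        this(stateSize, null);
    }

    // Analogous to ComputationGraphUpdater(graph, updaterState): adopt the given
    // array when it is non-null (assumption), otherwise allocate a new one.
    UpdaterStateSketch(int stateSize, double[] existingState) {
        if (existingState == null) {
            this.state = new double[stateSize];
        } else {
            if (existingState.length != stateSize) {
                throw new IllegalArgumentException("Expected updater state of length "
                        + stateSize + " but got " + existingState.length);
            }
            this.state = existingState; // no copy: shares storage with the caller
        }
    }

    double[] getState() {
        return state;
    }

    public static void main(String[] args) {
        UpdaterStateSketch fresh = new UpdaterStateSketch(4);
        double[] saved = {0.1, 0.2, 0.3, 0.4};
        UpdaterStateSketch restored = new UpdaterStateSketch(4, saved);
        System.out.println(Arrays.toString(fresh.getState())); // [0.0, 0.0, 0.0, 0.0]
        System.out.println(restored.getState() == saved);      // true
    }
}
```

Sharing the supplied array rather than copying it keeps memory use flat for large models, at the cost of the caller and the updater seeing each other's writes.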
Method Detail

getOrderedLayers
protected Trainable[] getOrderedLayers()
- Specified by: getOrderedLayers in class BaseMultiLayerUpdater<ComputationGraph>
- Returns: array of layers in the correct order, i.e., the same order as the parameter/gradient/updater flattening order (input to output for MultiLayerNetwork, or topological order for ComputationGraph)
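For a ComputationGraph, layers come back in topological order: every vertex appears after all vertices that feed into it. The sketch below illustrates that ordering with Kahn's algorithm on a hypothetical graph, given as a map from vertex name to the names of its inputs. It only illustrates the concept and is not how ComputationGraph computes its own order; it assumes every vertex, including the network inputs, appears as a key in the map.

```java
import java.util.*;

// Hypothetical sketch: topological ordering of graph vertices via Kahn's algorithm.
class TopologicalOrderSketch {

    // Returns vertices ordered so that each one appears after all of its inputs.
    static List<String> topologicalOrder(Map<String, List<String>> inputsByVertex) {
        // Count unresolved inputs per vertex and record reverse edges (input -> consumers).
        Map<String, Integer> remainingInputs = new LinkedHashMap<>();
        Map<String, List<String>> consumers = new HashMap<>();
        for (Map.Entry<String, List<String>> e : inputsByVertex.entrySet()) {
            remainingInputs.put(e.getKey(), e.getValue().size());
            for (String input : e.getValue()) {
                consumers.computeIfAbsent(input, k -> new ArrayList<>()).add(e.getKey());
            }
        }
        // Start from vertices with no inputs (the network inputs).
        Deque<String> ready = new ArrayDeque<>();
        remainingInputs.forEach((v, n) -> { if (n == 0) ready.add(v); });

        List<String> order = new ArrayList<>();
        while (!ready.isEmpty()) {
            String v = ready.poll();
            order.add(v);
            // A consumer becomes ready once its last unresolved input has been emitted.
            for (String c : consumers.getOrDefault(v, List.of())) {
                if (remainingInputs.merge(c, -1, Integer::sum) == 0) {
                    ready.add(c);
                }
            }
        }
        if (order.size() != remainingInputs.size()) {
            throw new IllegalStateException("Graph contains a cycle");
        }
        return order;
    }

    public static void main(String[] args) {
        // Hypothetical graph: "in" feeds two branches that merge before the output.
        Map<String, List<String>> graph = new LinkedHashMap<>();
        graph.put("in", List.of());
        graph.put("dense1", List.of("in"));
        graph.put("dense2", List.of("in"));
        graph.put("merge", List.of("dense1", "dense2"));
        graph.put("out", List.of("merge"));
        System.out.println(topologicalOrder(graph)); // [in, dense1, dense2, merge, out]
    }
}
```

Vertices with no dependency between them (dense1 and dense2 here) could legitimately appear in either order; what matters is that the chosen order is used consistently for parameters, gradients, and updater state.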

getFlattenedGradientsView
public INDArray getFlattenedGradientsView()
- Specified by: getFlattenedGradientsView in class BaseMultiLayerUpdater<ComputationGraph>
- Returns: the flattened gradient view array for the model
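A flattened gradient view means the gradients of all layers live in one contiguous array, with each layer writing through a view onto its own slice. The plain-Java sketch below mimics this with java.nio.DoubleBuffer slices over a double[]. The layer sizes, the learning rate, and plain SGD as the update rule are assumptions for illustration only; ComputationGraphUpdater itself works with ND4J INDArray views and whichever updater the model is configured with.

```java
import java.nio.DoubleBuffer;
import java.util.Arrays;

// Hypothetical sketch: per-layer gradient views over one flattened gradient array.
class FlatGradientSketch {

    // Returns a view of flat[offset, offset + length) whose writes go to flat.
    static DoubleBuffer view(double[] flat, int offset, int length) {
        return DoubleBuffer.wrap(flat, offset, length).slice();
    }

    // Plain SGD (a stand-in update rule) over the whole flattened arrays in one pass.
    static void sgdStep(double[] params, double[] gradients, double learningRate) {
        for (int i = 0; i < params.length; i++) {
            params[i] -= learningRate * gradients[i];
        }
    }

    public static void main(String[] args) {
        double[] flatGradients = new double[5];
        double[] flatParams = {1.0, 1.0, 1.0, 1.0, 1.0};

        // Hypothetical layout: layer 0 owns elements 0..2, layer 1 owns elements 3..4.
        DoubleBuffer layer0Grad = view(flatGradients, 0, 3);
        DoubleBuffer layer1Grad = view(flatGradients, 3, 2);
        layer0Grad.put(0, 0.5); // layer 0 writes its first gradient element
        layer1Grad.put(1, 2.0); // layer 1 writes its last gradient element

        System.out.println(Arrays.toString(flatGradients)); // [0.5, 0.0, 0.0, 0.0, 2.0]
        sgdStep(flatParams, flatGradients, 0.1);
        System.out.println(Arrays.toString(flatParams));    // [0.95, 1.0, 1.0, 1.0, 0.8]
    }
}
```

Because the layers write into shared storage, the update rule can run as a single loop over the whole array instead of once per layer.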

getParams
protected INDArray getParams()
- Specified by: getParams in class BaseMultiLayerUpdater<ComputationGraph>
- Returns: the flattened parameter array for the model
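Since getOrderedLayers() defines the flattening order, each layer's slice of the flattened parameter array can be located by summing the parameter counts of the layers before it. The hypothetical sketch below computes those offsets. The layer names and the dense-layer parameter counts (weights plus biases) are made up, and a LinkedHashMap's insertion order stands in for the ordered layer array.

```java
import java.util.LinkedHashMap;
import java.util.Map;

// Hypothetical sketch: partitioning a flattened parameter array among ordered layers.
class FlatParamLayoutSketch {

    // Returns {offset, length} per layer, packing layers back to back in map order.
    static Map<String, int[]> layout(Map<String, Integer> paramCountsInOrder) {
        Map<String, int[]> ranges = new LinkedHashMap<>();
        int offset = 0;
        for (Map.Entry<String, Integer> e : paramCountsInOrder.entrySet()) {
            ranges.put(e.getKey(), new int[] {offset, e.getValue()});
            offset += e.getValue(); // the next layer starts where this one ends
        }
        return ranges;
    }

    public static void main(String[] args) {
        Map<String, Integer> counts = new LinkedHashMap<>();
        counts.put("dense1", 4 * 3 + 3); // hypothetical: 4x3 weights + 3 biases = 15
        counts.put("out", 3 * 2 + 2);    // hypothetical: 3x2 weights + 2 biases = 8
        layout(counts).forEach((name, r) ->
                System.out.println(name + ": offset=" + r[0] + ", length=" + r[1]));
        // dense1: offset=0, length=15
        // out: offset=15, length=8
    }
}
```

The same offsets apply to the flattened gradient array, which is why parameters, gradients, and updater state all need to follow one agreed layer order.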

isMiniBatch
protected boolean isMiniBatch()
- Specified by: isMiniBatch in class BaseMultiLayerUpdater<ComputationGraph>
- Returns: true if the model's configuration is set to minibatch mode (gradients are divided by the minibatch size), false otherwise
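In minibatch mode the gradient is divided by the minibatch size before the update is applied, presumably the job of the inherited divideByMinibatch method. The sketch below assumes the incoming gradient is a sum over the examples in the batch, so the division turns it into a per-example average. The method name and values are hypothetical and it does not use DL4J.

```java
import java.util.Arrays;

// Hypothetical sketch of the minibatch division controlled by isMiniBatch().
class MinibatchDivisionSketch {

    // Divides a batch-summed gradient by the batch size when miniBatch is true.
    static void maybeDivideByMinibatch(double[] summedGradients, int batchSize, boolean miniBatch) {
        if (!miniBatch) {
            return; // minibatch mode off: leave the gradient untouched
        }
        if (batchSize <= 0) {
            throw new IllegalArgumentException("batchSize must be positive: " + batchSize);
        }
        for (int i = 0; i < summedGradients.length; i++) {
            summedGradients[i] /= batchSize;
        }
    }

    public static void main(String[] args) {
        double[] g1 = {8.0, -4.0};
        maybeDivideByMinibatch(g1, 4, true);
        System.out.println(Arrays.toString(g1)); // [2.0, -1.0]

        double[] g2 = {8.0, -4.0};
        maybeDivideByMinibatch(g2, 4, false);
        System.out.println(Arrays.toString(g2)); // [8.0, -4.0]
    }
}
```

Averaging keeps the effective step size roughly independent of the batch size, so a learning rate tuned at one batch size transfers more easily to another.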