Class ComputationGraphUpdater
java.lang.Object
  org.deeplearning4j.nn.updater.BaseMultiLayerUpdater<ComputationGraph>
    org.deeplearning4j.nn.updater.graph.ComputationGraphUpdater
All Implemented Interfaces:
  Serializable, Updater
public class ComputationGraphUpdater extends BaseMultiLayerUpdater<ComputationGraph>
See Also:
  Serialized Form
Field Summary
Fields
Modifier and Type        Field            Description
protected Trainable[]    orderedLayers
Fields inherited from class org.deeplearning4j.nn.updater.BaseMultiLayerUpdater
gradientsForMinibatchDivision, initializedMinibatchDivision, layersByName, network, updaterBlocks, updaterStateViewArray
Constructor Summary
Constructors
Constructor                                                               Description
ComputationGraphUpdater(ComputationGraph graph)
ComputationGraphUpdater(ComputationGraph graph, INDArray updaterState)
Method Summary
Methods
Modifier and Type        Method                           Description
INDArray                 getFlattenedGradientsView()
protected Trainable[]    getOrderedLayers()
protected INDArray       getParams()
protected boolean        isMiniBatch()
Methods inherited from class org.deeplearning4j.nn.updater.BaseMultiLayerUpdater
divideByMinibatch, equals, getMinibatchDivisionSubsets, getStateViewArray, getStateViewArrayCopy, hashCode, isSingleLayerUpdater, preApply, setStateViewArray, setStateViewArray, update, update
Field Detail
orderedLayers
protected Trainable[] orderedLayers
Constructor Detail
ComputationGraphUpdater
public ComputationGraphUpdater(ComputationGraph graph)
ComputationGraphUpdater
public ComputationGraphUpdater(ComputationGraph graph, INDArray updaterState)
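The second constructor takes an existing updater state array. Stateful updaters such as momentum or Adam keep per-parameter state between iterations, so passing saved state presumably lets training resume without resetting that state. The sketch below illustrates the idea with a hypothetical plain-Java momentum updater (`MomentumStateSketch` is an illustrative name, not a DL4J class). It does not reproduce how `ComputationGraphUpdater` stores or lays out its state.

```java
/**
 * Hypothetical momentum-SGD updater (not part of DL4J), used only to show why an
 * updater might be constructed from previously saved state.
 */
final class MomentumStateSketch {
    private final double[] velocity; // one state value per parameter
    private final double learningRate;
    private final double momentum;

    /** Fresh updater: state starts at zero, as for a newly initialized model. */
    MomentumStateSketch(int numParams, double learningRate, double momentum) {
        this(new double[numParams], learningRate, momentum);
    }

    /** Resumed updater: adopts saved state, e.g. restored from a checkpoint. */
    MomentumStateSketch(double[] savedState, double learningRate, double momentum) {
        this.velocity = savedState;
        this.learningRate = learningRate;
        this.momentum = momentum;
    }

    /** One in-place step: v = momentum * v - lr * g, then p = p + v. */
    void step(double[] params, double[] grads) {
        for (int i = 0; i < params.length; i++) {
            velocity[i] = momentum * velocity[i] - learningRate * grads[i];
            params[i] += velocity[i];
        }
    }

    double[] state() {
        return velocity;
    }

    public static void main(String[] args) {
        double[] params = {1.0};
        MomentumStateSketch fresh = new MomentumStateSketch(1, 0.5, 0.5);
        fresh.step(params, new double[] {1.0});
        // Resuming from the saved velocity gives a different step than starting from zero.
        MomentumStateSketch resumed = new MomentumStateSketch(fresh.state().clone(), 0.5, 0.5);
        resumed.step(params, new double[] {1.0});
        System.out.println(params[0]);
    }
}
```

Starting the second step from zero state would move the parameter by 0.5. Resuming from the saved velocity moves it by 0.75, because the earlier momentum is carried over.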
Method Detail
getOrderedLayers
protected Trainable[] getOrderedLayers()
Specified by:
  getOrderedLayers in class BaseMultiLayerUpdater<ComputationGraph>
Returns:
  Array of layers in the correct order, i.e., the same order as the parameter/gradient/updater flattening order: input to output for MultiLayerNetwork, or topological order for ComputationGraph
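For a ComputationGraph, the documented flattening order is a topological order, meaning every layer comes after all the vertices that feed it. The sketch below computes one such order with Kahn's algorithm over a hypothetical graph, given as a map from each vertex name to the vertices that consume its output. It is illustrative only. DL4J computes the graph's order internally, and its tie-breaking between independent vertices may differ from this sketch.

```java
import java.util.ArrayDeque;
import java.util.ArrayList;
import java.util.Deque;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;

/** Illustrative topological sort (Kahn's algorithm) over a hypothetical layer graph. */
final class TopoOrderSketch {
    /**
     * @param edges vertex name -> names of vertices that consume its output
     * @return vertex names ordered so that each vertex appears after all of its inputs
     */
    static List<String> topologicalOrder(Map<String, List<String>> edges) {
        // Count incoming edges per vertex; LinkedHashMap keeps a deterministic order.
        Map<String, Integer> inDegree = new LinkedHashMap<>();
        for (Map.Entry<String, List<String>> e : edges.entrySet()) {
            inDegree.putIfAbsent(e.getKey(), 0);
            for (String to : e.getValue()) {
                inDegree.merge(to, 1, Integer::sum);
            }
        }
        // Vertices with no inputs (network inputs) are ready first.
        Deque<String> ready = new ArrayDeque<>();
        for (Map.Entry<String, Integer> e : inDegree.entrySet()) {
            if (e.getValue() == 0) ready.add(e.getKey());
        }
        List<String> order = new ArrayList<>();
        while (!ready.isEmpty()) {
            String v = ready.poll();
            order.add(v);
            // Emitting v satisfies one input of each consumer.
            for (String to : edges.getOrDefault(v, List.of())) {
                if (inDegree.merge(to, -1, Integer::sum) == 0) ready.add(to);
            }
        }
        if (order.size() != inDegree.size()) {
            throw new IllegalArgumentException("graph has a cycle");
        }
        return order;
    }

    public static void main(String[] args) {
        Map<String, List<String>> graph = new LinkedHashMap<>();
        graph.put("in", List.of("dense1", "dense2"));
        graph.put("dense1", List.of("merge"));
        graph.put("dense2", List.of("merge"));
        graph.put("merge", List.of("out"));
        System.out.println(topologicalOrder(graph));
    }
}
```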
getFlattenedGradientsView
public INDArray getFlattenedGradientsView()
Specified by:
  getFlattenedGradientsView in class BaseMultiLayerUpdater<ComputationGraph>
Returns:
  The flattened gradient view array for the model
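Here, a "flattened view" means one contiguous array for the whole model, in which each layer's gradients occupy a fixed slice that shares the same storage. The sketch below imitates that layout in plain Java using `java.nio.DoubleBuffer` views over a single `double[]`. The layer names and sizes are hypothetical. DL4J itself works with ND4J `INDArray` views, as the return type suggests, not NIO buffers.

```java
import java.nio.DoubleBuffer;
import java.util.Arrays;
import java.util.LinkedHashMap;
import java.util.Map;

/** One flat gradient array plus per-layer views that share its storage (illustrative only). */
final class FlatGradientViewSketch {
    final double[] flat;
    final Map<String, DoubleBuffer> layerViews = new LinkedHashMap<>();

    /** @param layerSizes hypothetical layer name -> parameter count, in flattening order */
    FlatGradientViewSketch(Map<String, Integer> layerSizes) {
        int total = 0;
        for (int size : layerSizes.values()) total += size;
        flat = new double[total];
        int offset = 0;
        for (Map.Entry<String, Integer> e : layerSizes.entrySet()) {
            // wrap(...).slice() creates a buffer backed by `flat` starting at `offset`;
            // writes through the view land directly in the flat array.
            layerViews.put(e.getKey(), DoubleBuffer.wrap(flat, offset, e.getValue()).slice());
            offset += e.getValue();
        }
    }

    public static void main(String[] args) {
        Map<String, Integer> sizes = new LinkedHashMap<>();
        sizes.put("dense1", 3);
        sizes.put("out", 2);
        FlatGradientViewSketch grads = new FlatGradientViewSketch(sizes);
        grads.layerViews.get("out").put(0, 7.0); // a layer writes its own gradient...
        System.out.println(Arrays.toString(grads.flat)); // ...and the flat array sees it
    }
}
```

Because the views and the flat array share storage, no copying is needed to move between the per-layer and whole-model representations.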
getParams
protected INDArray getParams()
Specified by:
  getParams in class BaseMultiLayerUpdater<ComputationGraph>
Returns:
  The flattened parameter array for the model
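When parameters and gradients are both flattened in the same layer order, an update can be applied to the whole model in one pass over two equally sized arrays. The sketch below shows this with plain SGD (p = p - lr * g) on hypothetical `double[]` arrays. It is a stand-in for illustration, not the updater logic DL4J applies through its updater blocks.

```java
import java.util.Arrays;

/** Plain SGD over flattened arrays (illustrative stand-in, not DL4J's updater). */
final class FlatSgdSketch {
    /** Updates every parameter in place; both arrays must use the same flattening order. */
    static void sgdStep(double[] flatParams, double[] flatGrads, double learningRate) {
        if (flatParams.length != flatGrads.length) {
            throw new IllegalArgumentException("params and gradients must have the same length");
        }
        for (int i = 0; i < flatParams.length; i++) {
            flatParams[i] -= learningRate * flatGrads[i];
        }
    }

    public static void main(String[] args) {
        // Hypothetical layout: indices 0-1 belong to layer A, index 2 to layer B.
        double[] params = {1.0, 2.0, 3.0};
        double[] grads = {1.0, 1.0, 2.0};
        sgdStep(params, grads, 0.5);
        System.out.println(Arrays.toString(params));
    }
}
```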
isMiniBatch
protected boolean isMiniBatch()
Specified by:
  isMiniBatch in class BaseMultiLayerUpdater<ComputationGraph>
Returns:
  true if the model configuration is set to minibatch mode (gradients are divided by the minibatch size), false otherwise
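In minibatch mode, the gradient accumulated over a minibatch is divided by the number of examples B. It therefore represents an average, g = (1/B) * (g_1 + ... + g_B), and its scale does not grow with the batch size. The sketch below shows that arithmetic on a hypothetical summed-gradient array. In DL4J the division is performed through the inherited `divideByMinibatch` method, and this sketch does not model which gradients that method actually touches.

```java
import java.util.Arrays;

/** Illustrative minibatch averaging of a summed gradient (not DL4J's implementation). */
final class MinibatchDivisionSketch {
    /** Divides the summed gradient in place by the minibatch size when miniBatch is true. */
    static void maybeDivide(double[] summedGrads, int minibatchSize, boolean miniBatch) {
        if (!miniBatch) {
            return; // non-minibatch mode: leave the summed gradient unchanged
        }
        if (minibatchSize <= 0) {
            throw new IllegalArgumentException("minibatch size must be positive");
        }
        for (int i = 0; i < summedGrads.length; i++) {
            summedGrads[i] /= minibatchSize;
        }
    }

    public static void main(String[] args) {
        double[] grads = {4.0, -8.0};
        maybeDivide(grads, 4, true);
        System.out.println(Arrays.toString(grads));
    }
}
```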