Class BaseGraphVertex
- java.lang.Object
-
- org.deeplearning4j.nn.graph.vertex.BaseGraphVertex
-
- All Implemented Interfaces:
Serializable, Trainable, GraphVertex
- Direct Known Subclasses:
DuplicateToTimeSeriesVertex, ElementWiseVertex, InputVertex, L2NormalizeVertex, L2Vertex, LastTimeStepVertex, LayerVertex, MergeVertex, PoolHelperVertex, PreprocessorVertex, ReshapeVertex, ReverseTimeSeriesVertex, SameDiffGraphVertex, ScaleVertex, ShiftVertex, StackVertex, SubsetVertex, UnstackVertex
public abstract class BaseGraphVertex extends Object implements GraphVertex
- See Also:
- Serialized Form
-
-
Field Summary
Fields (Modifier and Type, Field, Description):
protected DataType dataType
protected INDArray epsilon
protected ComputationGraph graph
protected INDArray[] inputs
protected VertexIndices[] inputVertices: A representation of the vertices that are inputs to this vertex (inputs during forward pass). Specifically, if inputVertices[X].getVertexIndex() = Y, and inputVertices[X].getVertexEdgeNumber() = Z then the Zth output of vertex Y is the Xth input to this vertex
protected boolean outputVertex
protected VertexIndices[] outputVertices: A representation of the vertices that this vertex is connected to (outputs during forward pass). Specifically, if outputVertices[X].getVertexIndex() = Y, and outputVertices[X].getVertexEdgeNumber() = Z then the output of this vertex (there is only one output) is connected to the Zth input of vertex Y
protected int vertexIndex: The index of this vertex
protected String vertexName
-
Constructor Summary
Constructors (Modifier, Constructor, Description):
protected BaseGraphVertex(ComputationGraph graph, String name, int vertexIndex, VertexIndices[] inputVertices, VertexIndices[] outputVertices, DataType dataType)
-
Method Summary
Methods (Modifier and Type, Method, Description):
boolean canDoBackward(): Whether the GraphVertex can do backward pass.
boolean canDoForward(): Whether the GraphVertex can do forward pass.
void clear(): Clear the internal state (if any) of the GraphVertex.
void clearVertex(): This method clears the input for this vertex
TrainingConfig getConfig()
INDArray getEpsilon(): Get the epsilon/error (i.e., dL/dOutput) array previously set for this GraphVertex
INDArray getGradientsViewArray()
VertexIndices[] getInputVertices(): A representation of the vertices that are inputs to this vertex (inputs during forward pass). Specifically, if inputVertices[X].getVertexIndex() = Y, and inputVertices[X].getVertexEdgeNumber() = Z then the Zth output of vertex Y is the Xth input to this vertex
int getNumInputArrays(): Get the number of input arrays.
int getNumOutputConnections(): Get the number of outgoing connections from this GraphVertex.
VertexIndices[] getOutputVertices(): A representation of the vertices that this vertex is connected to (outputs during forward pass). Specifically, if outputVertices[X].getVertexIndex() = Y, and outputVertices[X].getVertexEdgeNumber() = Z then the Xth output of this vertex is connected to the Zth input of vertex Y
int getVertexIndex(): Get the index of the GraphVertex
String getVertexName(): Get the name/label of the GraphVertex
boolean isInputVertex(): Whether the GraphVertex is an input vertex
long numParams()
INDArray params()
Map<String,INDArray> paramTable(boolean backpropOnly): Get the parameter table for the vertex
void setEpsilon(INDArray epsilon): Set the errors (epsilon - aka dL/dActivation) for this GraphVertex
void setInput(int inputNumber, INDArray input, LayerWorkspaceMgr workspaceMgr): Set the input activations.
void setInputVertices(VertexIndices[] inputVertices): Sets the input vertices.
void setLayerAsFrozen(): Only applies to layer vertices.
void setOutputVertices(VertexIndices[] outputVertices): Set the output vertices.
abstract String toString()
boolean updaterDivideByMinibatch(String paramName): DL4J layers typically produce the sum of the gradients during the backward pass for each layer, and if required (if minibatch=true) then divide by the minibatch size. However, there are some exceptions, such as the batch norm mean/variance estimate parameters: these "gradients" are actually not gradients, but are updates to be applied directly to the parameter vector.
-
Methods inherited from class java.lang.Object
clone, equals, finalize, getClass, hashCode, notify, notifyAll, wait, wait, wait
-
Methods inherited from interface org.deeplearning4j.nn.graph.vertex.GraphVertex
doBackward, doForward, feedForwardMaskArrays, getInputs, getLayer, hasLayer, isOutputVertex, setBackpropGradientsViewArray, setInputs, setOutputVertex
-
-
-
-
Field Detail
-
graph
protected ComputationGraph graph
-
vertexName
protected String vertexName
-
vertexIndex
protected int vertexIndex
The index of this vertex
-
inputVertices
protected VertexIndices[] inputVertices
A representation of the vertices that are inputs to this vertex (inputs during forward pass).
Specifically, if inputVertices[X].getVertexIndex() = Y, and inputVertices[X].getVertexEdgeNumber() = Z then the Zth output of vertex Y is the Xth input to this vertex
-
outputVertices
protected VertexIndices[] outputVertices
A representation of the vertices that this vertex is connected to (outputs during forward pass).
Specifically, if outputVertices[X].getVertexIndex() = Y, and outputVertices[X].getVertexEdgeNumber() = Z then the output of this vertex (there is only one output) is connected to the Zth input of vertex Y
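The index convention used by inputVertices and outputVertices can be illustrated with a small self-contained sketch. It does not use DL4J: EdgeRef is a hypothetical stand-in for VertexIndices (only the two getters named above are mirrored), and invert is an illustrative helper, not the way ComputationGraph necessarily builds these arrays.

```java
import java.util.ArrayList;
import java.util.List;

final class EdgeEncodingSketch {
    /**
     * Derives outgoing connections from incoming ones.
     * inputs[v][x] = (y, z) means: output z of vertex y is input x of vertex v.
     * The result at index y then holds (v, x): the output of y feeds input x of v.
     */
    static List<List<EdgeRef>> invert(EdgeRef[][] inputs) {
        List<List<EdgeRef>> out = new ArrayList<>();
        for (int i = 0; i < inputs.length; i++) {
            out.add(new ArrayList<>());
        }
        for (int v = 0; v < inputs.length; v++) {
            for (int x = 0; x < inputs[v].length; x++) {
                int y = inputs[v][x].getVertexIndex();
                out.get(y).add(new EdgeRef(v, x));
            }
        }
        return out;
    }

    public static void main(String[] args) {
        EdgeRef[][] inputs = {
            {},                                       // vertex 0: network input, no incoming edges
            {},                                       // vertex 1: network input, no incoming edges
            {new EdgeRef(0, 0), new EdgeRef(1, 0)}    // vertex 2: takes the outputs of 0 and 1
        };
        // Vertex 0 feeds input 0 of vertex 2; vertex 1 feeds input 1 of vertex 2.
        System.out.println(invert(inputs));
    }
}

// Hypothetical stand-in for VertexIndices: a (vertexIndex, vertexEdgeNumber) pair.
final class EdgeRef {
    private final int vertexIndex;
    private final int vertexEdgeNumber;

    EdgeRef(int vertexIndex, int vertexEdgeNumber) {
        this.vertexIndex = vertexIndex;
        this.vertexEdgeNumber = vertexEdgeNumber;
    }

    int getVertexIndex() { return vertexIndex; }

    int getVertexEdgeNumber() { return vertexEdgeNumber; }

    @Override
    public String toString() { return "(" + vertexIndex + "," + vertexEdgeNumber + ")"; }
}
```

In this sketch the third vertex has two input arrays, while each of the first two vertices has exactly one outgoing connection.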
-
inputs
protected INDArray[] inputs
-
epsilon
protected INDArray epsilon
-
outputVertex
protected boolean outputVertex
-
dataType
protected DataType dataType
-
-
Constructor Detail
-
BaseGraphVertex
protected BaseGraphVertex(ComputationGraph graph, String name, int vertexIndex, VertexIndices[] inputVertices, VertexIndices[] outputVertices, DataType dataType)
-
-
Method Detail
-
getVertexName
public String getVertexName()
Description copied from interface: GraphVertex
Get the name/label of the GraphVertex
- Specified by:
getVertexName in interface GraphVertex
-
getVertexIndex
public int getVertexIndex()
Description copied from interface: GraphVertex
Get the index of the GraphVertex
- Specified by:
getVertexIndex in interface GraphVertex
-
getNumInputArrays
public int getNumInputArrays()
Description copied from interface: GraphVertex
Get the number of input arrays. For example, a Layer may have only one input array, but in general a GraphVertex may have an arbitrary (>=1) number of input arrays (for example, from multiple other layers)
- Specified by:
getNumInputArrays in interface GraphVertex
-
getNumOutputConnections
public int getNumOutputConnections()
Description copied from interface: GraphVertex
Get the number of outgoing connections from this GraphVertex. A GraphVertex may only have a single output (for example, the activations out of a layer), but this output may be used as the input to an arbitrary number of other GraphVertex instances. This method returns the number of GraphVertex instances the output of this GraphVertex is input for.
- Specified by:
getNumOutputConnections in interface GraphVertex
-
getInputVertices
public VertexIndices[] getInputVertices()
A representation of the vertices that are inputs to this vertex (inputs during forward pass).
Specifically, if inputVertices[X].getVertexIndex() = Y, and inputVertices[X].getVertexEdgeNumber() = Z then the Zth output of vertex Y is the Xth input to this vertex
- Specified by:
getInputVertices in interface GraphVertex
-
setInputVertices
public void setInputVertices(VertexIndices[] inputVertices)
Description copied from interface: GraphVertex
Sets the input vertices.
- Specified by:
setInputVertices in interface GraphVertex
- See Also:
GraphVertex.getInputVertices()
-
getOutputVertices
public VertexIndices[] getOutputVertices()
A representation of the vertices that this vertex is connected to (outputs during forward pass).
Specifically, if outputVertices[X].getVertexIndex() = Y, and outputVertices[X].getVertexEdgeNumber() = Z then the Xth output of this vertex is connected to the Zth input of vertex Y
- Specified by:
getOutputVertices in interface GraphVertex
-
setOutputVertices
public void setOutputVertices(VertexIndices[] outputVertices)
Description copied from interface: GraphVertex
Set the output vertices.
- Specified by:
setOutputVertices in interface GraphVertex
- See Also:
GraphVertex.getOutputVertices()
-
isInputVertex
public boolean isInputVertex()
Description copied from interface: GraphVertex
Whether the GraphVertex is an input vertex
- Specified by:
isInputVertex in interface GraphVertex
-
setInput
public void setInput(int inputNumber, INDArray input, LayerWorkspaceMgr workspaceMgr)
Description copied from interface: GraphVertex
Set the input activations.
- Specified by:
setInput in interface GraphVertex
- Parameters:
inputNumber - Must be in range 0 to GraphVertex.getNumInputArrays()-1
input - The input array
-
setEpsilon
public void setEpsilon(INDArray epsilon)
Description copied from interface: GraphVertex
Set the errors (epsilon - aka dL/dActivation) for this GraphVertex
- Specified by:
setEpsilon in interface GraphVertex
-
clear
public void clear()
Description copied from interface: GraphVertex
Clear the internal state (if any) of the GraphVertex. For example, any stored inputs/errors
- Specified by:
clear in interface GraphVertex
-
canDoForward
public boolean canDoForward()
Description copied from interface: GraphVertex
Whether the GraphVertex can do forward pass. Typically, this is just whether all inputs are set.
- Specified by:
canDoForward in interface GraphVertex
-
canDoBackward
public boolean canDoBackward()
Description copied from interface: GraphVertex
Whether the GraphVertex can do backward pass. Typically, this is just whether all errors/epsilons are set
- Specified by:
canDoBackward in interface GraphVertex
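The readiness checks described for setInput, setEpsilon, clear, canDoForward and canDoBackward can be sketched without DL4J as follows. ReadinessSketch is a hypothetical stand-in, double[] replaces INDArray, the workspace manager argument is omitted, and the exception type is an assumption. It follows the "typically" wording above and is not the actual BaseGraphVertex implementation.

```java
// Hypothetical stand-in for a vertex's input/epsilon bookkeeping (not DL4J code).
final class ReadinessSketch {
    private final double[][] inputs;   // one slot per input array, null until set
    private double[] epsilon;          // dL/dOutput, null until set

    ReadinessSketch(int numInputArrays) {
        this.inputs = new double[numInputArrays][];
    }

    void setInput(int inputNumber, double[] input) {
        // Must be in range 0 to numInputArrays-1, as documented for setInput.
        if (inputNumber < 0 || inputNumber >= inputs.length) {
            throw new IllegalArgumentException("Invalid input number: " + inputNumber);
        }
        inputs[inputNumber] = input;
    }

    void setEpsilon(double[] epsilon) {
        this.epsilon = epsilon;
    }

    // Forward pass is possible once every input slot has been filled.
    boolean canDoForward() {
        for (double[] in : inputs) {
            if (in == null) {
                return false;
            }
        }
        return true;
    }

    // Simplification: backward pass is possible once the epsilon has been set.
    boolean canDoBackward() {
        return epsilon != null;
    }

    // Drops stored inputs and errors so the vertex can be reused.
    void clear() {
        for (int i = 0; i < inputs.length; i++) {
            inputs[i] = null;
        }
        epsilon = null;
    }

    public static void main(String[] args) {
        ReadinessSketch v = new ReadinessSketch(2);
        v.setInput(0, new double[] {1.0});
        System.out.println(v.canDoForward());   // false: input 1 is still missing
        v.setInput(1, new double[] {2.0});
        System.out.println(v.canDoForward());   // true: all inputs are set
        v.clear();
        System.out.println(v.canDoForward());   // false again after clear()
    }
}
```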
-
getEpsilon
public INDArray getEpsilon()
Description copied from interface: GraphVertex
Get the epsilon/error (i.e., dL/dOutput) array previously set for this GraphVertex
- Specified by:
getEpsilon in interface GraphVertex
-
setLayerAsFrozen
public void setLayerAsFrozen()
Description copied from interface: GraphVertex
Only applies to layer vertices. Will throw exceptions on others. If applied to a layer vertex, it will treat the parameters of the layer within it as constant. Activations through these will be calculated as they would be at test time, regardless of training mode
- Specified by:
setLayerAsFrozen in interface GraphVertex
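The "throws on non-layer vertices" contract can be pictured with the following sketch. All class names are hypothetical stand-ins, and the exception type (UnsupportedOperationException) is an assumption made for illustration. The documentation above only states that an exception is thrown.

```java
final class FrozenDemo {
    public static void main(String[] args) {
        LayerVertexSketch layer = new LayerVertexSketch();
        layer.setLayerAsFrozen();
        System.out.println(layer.isFrozen());   // true

        try {
            new PlainVertexSketch().setLayerAsFrozen();
        } catch (UnsupportedOperationException e) {
            System.out.println("rejected: " + e.getMessage());
        }
    }
}

// Hypothetical base class: freezing is rejected unless a subclass supports it.
abstract class FreezableVertexSketch {
    void setLayerAsFrozen() {
        // Assumed exception type; chosen only for this sketch.
        throw new UnsupportedOperationException("Only layer vertices can be frozen");
    }
}

// A vertex without a layer (e.g. a merge-style vertex) keeps the rejecting default.
final class PlainVertexSketch extends FreezableVertexSketch { }

// A vertex wrapping a layer records that its parameters are to be treated as constant.
final class LayerVertexSketch extends FreezableVertexSketch {
    private boolean frozen;

    @Override
    void setLayerAsFrozen() {
        frozen = true;
    }

    boolean isFrozen() {
        return frozen;
    }
}
```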
-
clearVertex
public void clearVertex()
Description copied from interface: GraphVertex
This method clears the input for this vertex
- Specified by:
clearVertex in interface GraphVertex
-
paramTable
public Map<String,INDArray> paramTable(boolean backpropOnly)
Description copied from interface: GraphVertex
Get the parameter table for the vertex
- Specified by:
paramTable in interface GraphVertex
- Specified by:
paramTable in interface Trainable
- Parameters:
backpropOnly - If true: exclude unsupervised training parameters
- Returns:
- Parameter table
-
numParams
public long numParams()
-
getConfig
public TrainingConfig getConfig()
-
params
public INDArray params()
-
getGradientsViewArray
public INDArray getGradientsViewArray()
- Specified by:
getGradientsViewArray in interface Trainable
- Returns:
- 1D gradients view array
-
updaterDivideByMinibatch
public boolean updaterDivideByMinibatch(String paramName)
Description copied from interface: Trainable
DL4J layers typically produce the sum of the gradients during the backward pass for each layer, and if required (if minibatch=true) then divide by the minibatch size.
However, there are some exceptions, such as the batch norm mean/variance estimate parameters: these "gradients" are actually not gradients, but are updates to be applied directly to the parameter vector. Put another way, most gradients should be divided by the minibatch to get the average; some "gradients" are actually final updates already, and should not be divided by the minibatch size.
- Specified by:
updaterDivideByMinibatch in interface Trainable
- Parameters:
paramName - Name of the parameter
- Returns:
- True if gradients should be divided by minibatch (most params); false otherwise (edge cases like batch norm mean/variance estimates)
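The effect of this flag can be sketched in plain Java. MinibatchDivideSketch, prepare, and the parameter names "W", "mean" and "var" are hypothetical and chosen only for illustration. The sketch shows how a caller might use the returned boolean, not how DL4J's updaters are implemented.

```java
import java.util.Arrays;
import java.util.HashSet;
import java.util.Set;

final class MinibatchDivideSketch {
    // Hypothetical names of parameters whose "gradients" are already final updates.
    private static final Set<String> DIRECT_UPDATE_PARAMS =
            new HashSet<>(Arrays.asList("mean", "var"));

    // True for ordinary parameters; false for the direct-update edge cases.
    static boolean updaterDivideByMinibatch(String paramName) {
        return !DIRECT_UPDATE_PARAMS.contains(paramName);
    }

    // Returns the values an updater would consume for one parameter.
    static double[] prepare(String paramName, double[] summed, int minibatchSize) {
        double[] result = summed.clone();
        if (updaterDivideByMinibatch(paramName)) {
            // Summed gradient: average over the minibatch.
            for (int i = 0; i < result.length; i++) {
                result[i] /= minibatchSize;
            }
        }
        // Otherwise the values are applied as-is.
        return result;
    }

    public static void main(String[] args) {
        double[] summed = {8.0, 4.0};
        System.out.println(Arrays.toString(prepare("W", summed, 4)));     // [2.0, 1.0]
        System.out.println(Arrays.toString(prepare("mean", summed, 4)));  // [8.0, 4.0]
    }
}
```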
-
-