Class BaseWrapperVertex
- java.lang.Object
  - org.deeplearning4j.nn.graph.vertex.BaseWrapperVertex
- All Implemented Interfaces:
Serializable, Trainable, GraphVertex
- Direct Known Subclasses:
FrozenVertex
public abstract class BaseWrapperVertex extends Object implements GraphVertex
- See Also:
Serialized Form
Field Summary
- protected GraphVertex underlying
-
Constructor Summary
- protected BaseWrapperVertex(GraphVertex underlying)
-
Method Summary
- boolean canDoBackward()
  Whether the GraphVertex can do backward pass.
- boolean canDoForward()
  Whether the GraphVertex can do forward pass.
- void clear()
  Clear the internal state (if any) of the GraphVertex.
- void clearVertex()
  This method clears input for this vertex.
- Pair<Gradient,INDArray[]> doBackward(boolean tbptt, LayerWorkspaceMgr workspaceMgr)
  Do backward pass.
- INDArray doForward(boolean training, LayerWorkspaceMgr workspaceMgr)
  Do forward pass using the stored inputs.
- Pair<INDArray,MaskState> feedForwardMaskArrays(INDArray[] maskArrays, MaskState currentMaskState, int minibatchSize)
- TrainingConfig getConfig()
- INDArray getEpsilon()
  Get the epsilon/error (i.e., dL/dOutput) array previously set for this GraphVertex.
- INDArray getGradientsViewArray()
- INDArray[] getInputs()
  Get the array of inputs previously set for this GraphVertex.
- VertexIndices[] getInputVertices()
  A representation of the vertices that are inputs to this vertex (inputs during forward pass). Specifically, if inputVertices[X].getVertexIndex() = Y and inputVertices[X].getVertexEdgeNumber() = Z, then the Zth output connection (see GraphVertex.getNumOutputConnections()) of vertex Y is the Xth input to this vertex.
- Layer getLayer()
  Get the Layer (if any).
- int getNumInputArrays()
  Get the number of input arrays.
- int getNumOutputConnections()
  Get the number of outgoing connections from this GraphVertex.
- VertexIndices[] getOutputVertices()
  A representation of the vertices that this vertex is connected to (outputs during forward pass). Specifically, if outputVertices[X].getVertexIndex() = Y and outputVertices[X].getVertexEdgeNumber() = Z, then the Xth output of this vertex is connected to the Zth input of vertex Y.
- int getVertexIndex()
  Get the index of the GraphVertex.
- String getVertexName()
  Get the name/label of the GraphVertex.
- boolean hasLayer()
  Whether the GraphVertex contains a Layer object or not.
- boolean isInputVertex()
  Whether the GraphVertex is an input vertex.
- boolean isOutputVertex()
  Whether the GraphVertex is an output vertex.
- long numParams()
- INDArray params()
- Map<String,INDArray> paramTable(boolean backpropOnly)
  Get the parameter table for the vertex.
- void setBackpropGradientsViewArray(INDArray backpropGradientsViewArray)
- void setEpsilon(INDArray epsilon)
  Set the errors (epsilon, aka dL/dActivation) for this GraphVertex.
- void setInput(int inputNumber, INDArray input, LayerWorkspaceMgr workspaceMgr)
  Set the input activations.
- void setInputs(INDArray... inputs)
  Set all inputs for this GraphVertex.
- void setInputVertices(VertexIndices[] inputVertices)
  Sets the input vertices.
- void setLayerAsFrozen()
  Only applies to layer vertices.
- void setOutputVertex(boolean outputVertex)
  Set the GraphVertex to be an output vertex.
- void setOutputVertices(VertexIndices[] outputVertices)
  Sets the output vertices.
- boolean updaterDivideByMinibatch(String paramName)
  DL4J layers typically produce the sum of the gradients during the backward pass for each layer and, if required (if minibatch=true), then divide by the minibatch size. However, there are some exceptions, such as the batch norm mean/variance estimate parameters: these "gradients" are actually not gradients, but updates to be applied directly to the parameter vector.
Field Detail
-
underlying
protected GraphVertex underlying
Constructor Detail
-
BaseWrapperVertex
protected BaseWrapperVertex(GraphVertex underlying)
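The constructor stores the wrapped vertex in the protected `underlying` field, and the class name suggests that GraphVertex calls are forwarded to it, with subclasses such as FrozenVertex presumably changing selected behavior. The sketch below illustrates that delegation pattern in plain Java under that assumption. `SimpleVertex`, `ScalingVertex`, `BaseWrapperSketch` and `InferenceOnlyVertex` are hypothetical simplifications, not DL4J types; the override that ignores the training flag is modeled on the setLayerAsFrozen() description ("calculated as they would at test time regardless of training mode").

```java
// Hypothetical, simplified stand-in for GraphVertex (not the DL4J interface).
interface SimpleVertex {
    String name();
    double[] forward(double[] input, boolean training);
}

// Hypothetical vertex whose output depends on the training flag:
// it halves activations at training time, purely for illustration.
class ScalingVertex implements SimpleVertex {
    public String name() { return "scale"; }

    public double[] forward(double[] input, boolean training) {
        double factor = training ? 0.5 : 1.0;
        double[] out = new double[input.length];
        for (int i = 0; i < input.length; i++) {
            out[i] = input[i] * factor;
        }
        return out;
    }
}

// Wrapper base: every call is forwarded to the protected `underlying` vertex.
abstract class BaseWrapperSketch implements SimpleVertex {
    protected final SimpleVertex underlying;

    protected BaseWrapperSketch(SimpleVertex underlying) {
        this.underlying = underlying;
    }

    public String name() { return underlying.name(); }

    public double[] forward(double[] input, boolean training) {
        return underlying.forward(input, training);
    }
}

// Overrides a single method: always runs the wrapped vertex in inference mode,
// whatever training flag the caller passes. name() is inherited unchanged.
class InferenceOnlyVertex extends BaseWrapperSketch {
    InferenceOnlyVertex(SimpleVertex underlying) {
        super(underlying);
    }

    @Override
    public double[] forward(double[] input, boolean training) {
        return underlying.forward(input, false);
    }
}
```

The design choice is the usual decorator trade-off: the base class writes the pass-through code once, so a subclass only overrides the behavior it actually changes.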
Method Detail
-
getVertexName
public String getVertexName()
Description copied from interface: GraphVertex
Get the name/label of the GraphVertex.
- Specified by:
getVertexName in interface GraphVertex
-
getVertexIndex
public int getVertexIndex()
Description copied from interface: GraphVertex
Get the index of the GraphVertex.
- Specified by:
getVertexIndex in interface GraphVertex
-
getNumInputArrays
public int getNumInputArrays()
Description copied from interface: GraphVertex
Get the number of input arrays. For example, a Layer may have only one input array, but in general a GraphVertex may have an arbitrary (>=1) number of input arrays (for example, from multiple other layers).
- Specified by:
getNumInputArrays in interface GraphVertex
-
getNumOutputConnections
public int getNumOutputConnections()
Description copied from interface: GraphVertex
Get the number of outgoing connections from this GraphVertex. A GraphVertex has only a single output (for example, the activations out of a layer), but this output may be used as the input to an arbitrary number of other GraphVertex instances. This method returns the number of GraphVertex instances that take the output of this GraphVertex as an input.
- Specified by:
getNumOutputConnections in interface GraphVertex
-
getInputVertices
public VertexIndices[] getInputVertices()
Description copied from interface: GraphVertex
A representation of the vertices that are inputs to this vertex (inputs during forward pass). Specifically, if inputVertices[X].getVertexIndex() = Y and inputVertices[X].getVertexEdgeNumber() = Z, then the Zth output connection (see GraphVertex.getNumOutputConnections()) of vertex Y is the Xth input to this vertex.
- Specified by:
getInputVertices in interface GraphVertex
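The indexing rule above can be read mechanically. The sketch turns an inputVertices-style array into readable edges and gathers the actual input arrays from each source vertex's output. `EdgeRef` is a hypothetical stand-in for VertexIndices, and plain `double[]` replaces INDArray. Because a GraphVertex has a single output that may feed many connections, input X is just the output of vertex Y; the edge number Z only says which of Y's outgoing connections carries it.

```java
import java.util.ArrayList;
import java.util.List;

// Hypothetical stand-in for VertexIndices: a vertex index plus an edge number.
final class EdgeRef {
    final int vertexIndex;
    final int edgeNumber;

    EdgeRef(int vertexIndex, int edgeNumber) {
        this.vertexIndex = vertexIndex;
        this.edgeNumber = edgeNumber;
    }
}

final class InputEdges {
    // Entry X = (Y, Z) reads as: "output connection Z of vertex Y is input X of this vertex".
    static List<String> describe(EdgeRef[] inputVertices) {
        List<String> lines = new ArrayList<>();
        for (int x = 0; x < inputVertices.length; x++) {
            EdgeRef e = inputVertices[x];
            lines.add("input " + x + " <- vertex " + e.vertexIndex
                    + ", output connection " + e.edgeNumber);
        }
        return lines;
    }

    // Each vertex has one output array, so input X is simply the output of vertex Y.
    static double[][] gather(EdgeRef[] inputVertices, double[][] outputsByVertex) {
        double[][] inputs = new double[inputVertices.length][];
        for (int x = 0; x < inputVertices.length; x++) {
            inputs[x] = outputsByVertex[inputVertices[x].vertexIndex];
        }
        return inputs;
    }
}
```

For example, `{ new EdgeRef(0, 1), new EdgeRef(2, 0) }` describes a two-input vertex fed by connection 1 of vertex 0 and connection 0 of vertex 2.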
-
setInputVertices
public void setInputVertices(VertexIndices[] inputVertices)
Description copied from interface: GraphVertex
Sets the input vertices.
- Specified by:
setInputVertices in interface GraphVertex
- See Also:
GraphVertex.getInputVertices()
-
getOutputVertices
public VertexIndices[] getOutputVertices()
Description copied from interface: GraphVertex
A representation of the vertices that this vertex is connected to (outputs during forward pass). Specifically, if outputVertices[X].getVertexIndex() = Y and outputVertices[X].getVertexEdgeNumber() = Z, then the Xth output of this vertex is connected to the Zth input of vertex Y.
- Specified by:
getOutputVertices in interface GraphVertex
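Taken together, the inputVertices and outputVertices rules describe the same edges from both ends: if vertex t has inputVertices[x] = (s, e), then vertex s should have outputVertices[e] = (t, x). The sketch derives every vertex's output array from all vertices' input arrays under that reading. `VIdx` and `EdgeInversion` are hypothetical helpers, not DL4J classes, and how DL4J actually builds these arrays is not described on this page.

```java
// Hypothetical stand-in for VertexIndices.
final class VIdx {
    final int vertexIndex;
    final int edgeNumber;

    VIdx(int vertexIndex, int edgeNumber) {
        this.vertexIndex = vertexIndex;
        this.edgeNumber = edgeNumber;
    }

    @Override
    public String toString() {
        return "(" + vertexIndex + "," + edgeNumber + ")";
    }
}

final class EdgeInversion {
    // inputs[t] is vertex t's inputVertices array (null for vertices with no inputs).
    // Returns each vertex's outputVertices array using the rule:
    // if inputs[t][x] = (s, e), then outputs[s][e] = (t, x).
    static VIdx[][] invert(VIdx[][] inputs) {
        int n = inputs.length;
        // Number of output connections per vertex = highest edge number used + 1.
        int[] numOutputConnections = new int[n];
        for (VIdx[] in : inputs) {
            if (in == null) continue;
            for (VIdx ref : in) {
                numOutputConnections[ref.vertexIndex] =
                        Math.max(numOutputConnections[ref.vertexIndex], ref.edgeNumber + 1);
            }
        }
        VIdx[][] outputs = new VIdx[n][];
        for (int s = 0; s < n; s++) {
            outputs[s] = new VIdx[numOutputConnections[s]];
        }
        for (int t = 0; t < n; t++) {
            if (inputs[t] == null) continue;
            for (int x = 0; x < inputs[t].length; x++) {
                VIdx ref = inputs[t][x];
                outputs[ref.vertexIndex][ref.edgeNumber] = new VIdx(t, x);
            }
        }
        return outputs;
    }
}
```

In a three-vertex graph where vertex 1 reads connection 0 of vertex 0, and vertex 2 merges connection 1 of vertex 0 with connection 0 of vertex 1, vertex 0 ends up with two output connections, (1,0) and (2,0).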
-
setOutputVertices
public void setOutputVertices(VertexIndices[] outputVertices)
Description copied from interface: GraphVertex
Sets the output vertices.
- Specified by:
setOutputVertices in interface GraphVertex
- See Also:
GraphVertex.getOutputVertices()
-
hasLayer
public boolean hasLayer()
Description copied from interface: GraphVertex
Whether the GraphVertex contains a Layer object or not.
- Specified by:
hasLayer in interface GraphVertex
-
isInputVertex
public boolean isInputVertex()
Description copied from interface: GraphVertex
Whether the GraphVertex is an input vertex.
- Specified by:
isInputVertex in interface GraphVertex
-
isOutputVertex
public boolean isOutputVertex()
Description copied from interface: GraphVertex
Whether the GraphVertex is an output vertex.
- Specified by:
isOutputVertex in interface GraphVertex
-
setOutputVertex
public void setOutputVertex(boolean outputVertex)
Description copied from interface: GraphVertex
Set the GraphVertex to be an output vertex.
- Specified by:
setOutputVertex in interface GraphVertex
-
getLayer
public Layer getLayer()
Description copied from interface: GraphVertex
Get the Layer (if any). Returns null if GraphVertex.hasLayer() == false.
- Specified by:
getLayer in interface GraphVertex
-
setInput
public void setInput(int inputNumber, INDArray input, LayerWorkspaceMgr workspaceMgr)
Description copied from interface: GraphVertex
Set the input activations.
- Specified by:
setInput in interface GraphVertex
- Parameters:
inputNumber - Must be in range 0 to GraphVertex.getNumInputArrays()-1
input - The input array
-
setEpsilon
public void setEpsilon(INDArray epsilon)
Description copied from interface: GraphVertex
Set the errors (epsilon, aka dL/dActivation) for this GraphVertex.
- Specified by:
setEpsilon in interface GraphVertex
-
clear
public void clear()
Description copied from interface: GraphVertex
Clear the internal state (if any) of the GraphVertex. For example, any stored inputs/errors.
- Specified by:
clear in interface GraphVertex
-
canDoForward
public boolean canDoForward()
Description copied from interface: GraphVertex
Whether the GraphVertex can do forward pass. Typically, this is just whether all inputs are set.
- Specified by:
canDoForward in interface GraphVertex
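The input-related methods (setInput, canDoForward, clear) describe simple bookkeeping: one slot per input array, a forward pass that becomes possible once every slot is filled, and a reset between passes. The sketch models that with `double[]` instead of INDArray; `InputSlots` is a hypothetical class, and the range check mirrors the documented "0 to getNumInputArrays()-1" constraint rather than any specific DL4J exception.

```java
import java.util.Arrays;

// Hypothetical per-vertex input bookkeeping, using double[] in place of INDArray.
class InputSlots {
    private final double[][] inputs;

    InputSlots(int numInputArrays) {
        inputs = new double[numInputArrays][];
    }

    // Like setInput(inputNumber, input, ...): inputNumber must be in 0..numInputArrays-1.
    void setInput(int inputNumber, double[] input) {
        if (inputNumber < 0 || inputNumber >= inputs.length) {
            throw new IllegalArgumentException("inputNumber out of range: " + inputNumber);
        }
        inputs[inputNumber] = input;
    }

    // "Typically, this is just whether all inputs are set."
    boolean canDoForward() {
        for (double[] in : inputs) {
            if (in == null) return false;
        }
        return true;
    }

    // Like clear(): drop the stored inputs so the next pass starts fresh.
    void clear() {
        Arrays.fill(inputs, null);
    }
}
```

A graph executor can therefore visit vertices in topological order and run each one as soon as canDoForward() turns true.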
-
canDoBackward
public boolean canDoBackward()
Description copied from interface: GraphVertex
Whether the GraphVertex can do backward pass. Typically, this is just whether all errors/epsilons are set.
- Specified by:
canDoBackward in interface GraphVertex
-
doForward
public INDArray doForward(boolean training, LayerWorkspaceMgr workspaceMgr)
Description copied from interface: GraphVertex
Do forward pass using the stored inputs.
- Specified by:
doForward in interface GraphVertex
- Parameters:
training - If true: forward pass at training time. If false: forward pass at test time.
- Returns:
The output (for example, activations) of the GraphVertex
-
doBackward
public Pair<Gradient,INDArray[]> doBackward(boolean tbptt, LayerWorkspaceMgr workspaceMgr)
Description copied from interface: GraphVertex
Do backward pass.
- Specified by:
doBackward in interface GraphVertex
- Parameters:
tbptt - If true: do backprop using truncated BPTT.
- Returns:
The gradients (may be null), and the errors/epsilons for all inputs to this GraphVertex
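To make the doForward/doBackward contract concrete, the sketch implements a hypothetical parameter-free vertex that adds its inputs elementwise. Its forward pass returns the sum; its backward pass returns a null parameter gradient (matching "the gradients (may be null)") plus one epsilon per input. Since the derivative of a sum with respect to each addend is the identity, every input's epsilon equals this vertex's dL/dOutput. `AddVertexSketch` and `BackwardResult` are illustrative names, and `double[]` stands in for INDArray.

```java
// Hypothetical parameter-free vertex that adds its inputs elementwise,
// using double[] in place of INDArray.
final class AddVertexSketch {
    // Result of a backward pass: parameter gradient (null here) plus one epsilon per input.
    static final class BackwardResult {
        final double[] gradient;
        final double[][] inputEpsilons;

        BackwardResult(double[] gradient, double[][] inputEpsilons) {
            this.gradient = gradient;
            this.inputEpsilons = inputEpsilons;
        }
    }

    // Forward: out = in_0 + in_1 + ...; all inputs are assumed to have equal length.
    static double[] doForward(double[][] inputs) {
        double[] out = new double[inputs[0].length];
        for (double[] in : inputs) {
            for (int i = 0; i < out.length; i++) {
                out[i] += in[i];
            }
        }
        return out;
    }

    // Backward: d(out)/d(in_k) is the identity, so every input receives a copy of
    // this vertex's epsilon (dL/dOutput). With no parameters, the gradient is null.
    static BackwardResult doBackward(double[] epsilon, int numInputs) {
        double[][] inputEpsilons = new double[numInputs][];
        for (int k = 0; k < numInputs; k++) {
            inputEpsilons[k] = epsilon.clone();
        }
        return new BackwardResult(null, inputEpsilons);
    }
}
```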
-
getInputs
public INDArray[] getInputs()
Description copied from interface: GraphVertex
Get the array of inputs previously set for this GraphVertex.
- Specified by:
getInputs in interface GraphVertex
-
getEpsilon
public INDArray getEpsilon()
Description copied from interface: GraphVertex
Get the epsilon/error (i.e., dL/dOutput) array previously set for this GraphVertex.
- Specified by:
getEpsilon in interface GraphVertex
-
setInputs
public void setInputs(INDArray... inputs)
Description copied from interface: GraphVertex
Set all inputs for this GraphVertex.
- Specified by:
setInputs in interface GraphVertex
- See Also:
GraphVertex.setInput(int, INDArray, LayerWorkspaceMgr)
-
getGradientsViewArray
public INDArray getGradientsViewArray()
- Specified by:
getGradientsViewArray in interface Trainable
- Returns:
1D gradients view array
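As an illustration of what a "1D gradients view array" means, the sketch keeps all gradients in one flat `double[]` and exposes per-parameter slices through `java.nio.DoubleBuffer`, whose slices share storage with the backing array in roughly the way INDArray views share storage with their parent. `FlatGradients` and the weight/bias split are hypothetical; DL4J itself uses INDArray views, not DoubleBuffer.

```java
import java.nio.DoubleBuffer;

// Hypothetical flat gradient storage: one 1D array, with per-parameter views into it.
final class FlatGradients {
    final double[] flat;            // weights first, then biases
    final DoubleBuffer weightGrad;  // view of flat[0 .. nWeights-1]
    final DoubleBuffer biasGrad;    // view of flat[nWeights .. nWeights+nBias-1]

    FlatGradients(int nWeights, int nBias) {
        flat = new double[nWeights + nBias];
        // wrap(array, offset, length) sets position/limit; slice() starts a new
        // buffer at that position that writes through to the same array.
        weightGrad = DoubleBuffer.wrap(flat, 0, nWeights).slice();
        biasGrad = DoubleBuffer.wrap(flat, nWeights, nBias).slice();
    }
}
```

Writing a value through `biasGrad` changes `flat` directly, so code that works on the whole 1D array (for example, an updater) sees every parameter's gradient without copying.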
-
setBackpropGradientsViewArray
public void setBackpropGradientsViewArray(INDArray backpropGradientsViewArray)
Description copied from interface: GraphVertex
- Specified by:
setBackpropGradientsViewArray in interface GraphVertex
-
feedForwardMaskArrays
public Pair<INDArray,MaskState> feedForwardMaskArrays(INDArray[] maskArrays, MaskState currentMaskState, int minibatchSize)
- Specified by:
feedForwardMaskArrays in interface GraphVertex
-
setLayerAsFrozen
public void setLayerAsFrozen()
Description copied from interface: GraphVertex
Only applies to layer vertices. Will throw an exception for other vertex types. If applied to a layer vertex, it will treat the parameters of the layer within it as constant. Activations through the layer will be calculated as they would be at test time, regardless of training mode.
- Specified by:
setLayerAsFrozen in interface GraphVertex
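The "parameters treated as constant" half of setLayerAsFrozen() can be pictured as an update step that becomes a no-op once the vertex is frozen. The sketch uses a plain SGD step (params -= learningRate * grad) purely as an example updater; `ParamVertexSketch`, its `frozen` flag and `applyUpdate` are hypothetical, and the page does not say how DL4J enforces this internally.

```java
// Hypothetical vertex holding one parameter vector and a frozen flag.
final class ParamVertexSketch {
    final double[] params;
    boolean frozen = false;

    ParamVertexSketch(double... params) {
        this.params = params;
    }

    // Mirrors the idea of setLayerAsFrozen(): from now on, parameters are constants.
    void setLayerAsFrozen() {
        frozen = true;
    }

    // Plain SGD step, params -= learningRate * grad, skipped entirely when frozen.
    void applyUpdate(double[] grad, double learningRate) {
        if (frozen) return;
        for (int i = 0; i < params.length; i++) {
            params[i] -= learningRate * grad[i];
        }
    }
}
```

Starting from params {1.0, 2.0}, one update with gradient {1.0, 1.0} and learning rate 0.5 gives {0.5, 1.5}; after freezing, further updates leave the values unchanged.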
-
clearVertex
public void clearVertex()
Description copied from interface: GraphVertex
This method clears input for this vertex.
- Specified by:
clearVertex in interface GraphVertex
-
paramTable
public Map<String,INDArray> paramTable(boolean backpropOnly)
Description copied from interface: GraphVertex
Get the parameter table for the vertex.
- Specified by:
paramTable in interface GraphVertex
- Specified by:
paramTable in interface Trainable
- Parameters:
backpropOnly - If true: exclude unsupervised training parameters
- Returns:
Parameter table
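The backpropOnly flag filters the parameter table. The sketch assumes some parameters are used only for unsupervised (pretraining) updates and are identified by name; the name `vb` and the `PRETRAIN_ONLY` set are hypothetical choices for illustration, not taken from this page. A `LinkedHashMap` keeps the original parameter order.

```java
import java.util.LinkedHashMap;
import java.util.Map;
import java.util.Set;

final class ParamTableSketch {
    // Hypothetical: names of parameters used only for unsupervised (pretraining) updates.
    static final Set<String> PRETRAIN_ONLY = Set.of("vb");

    // Returns all parameters, or only the backprop ones when backpropOnly is true.
    static Map<String, double[]> paramTable(Map<String, double[]> all, boolean backpropOnly) {
        Map<String, double[]> out = new LinkedHashMap<>();
        for (Map.Entry<String, double[]> e : all.entrySet()) {
            if (backpropOnly && PRETRAIN_ONLY.contains(e.getKey())) continue;
            out.put(e.getKey(), e.getValue());
        }
        return out;
    }
}
```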
-
getConfig
public TrainingConfig getConfig()
-
params
public INDArray params()
-
numParams
public long numParams()
-
updaterDivideByMinibatch
public boolean updaterDivideByMinibatch(String paramName)
Description copied from interface: Trainable
DL4J layers typically produce the sum of the gradients during the backward pass for each layer and, if required (if minibatch=true), then divide by the minibatch size.
However, there are some exceptions, such as the batch norm mean/variance estimate parameters: these "gradients" are actually not gradients, but updates to be applied directly to the parameter vector. Put another way, most gradients should be divided by the minibatch size to get the average; some "gradients" are already final updates and should not be divided by the minibatch size.
- Specified by:
updaterDivideByMinibatch in interface Trainable
- Parameters:
paramName - Name of the parameter
- Returns:
True if gradients should be divided by minibatch (most params); false otherwise (edge cases like batch norm mean/variance estimates)
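The rule above can be sketched as a per-parameter decision applied to the summed gradient before it reaches the updater: divide by the minibatch size to get an average, except for parameters whose "gradient" is already a direct update. The parameter names `mean` and `var` and the `MinibatchScaling` class are assumptions for illustration only.

```java
import java.util.Set;

final class MinibatchScaling {
    // Hypothetical names of batch-norm-style running statistics whose "gradients"
    // are really direct updates and must not be averaged.
    static final Set<String> DIRECT_UPDATE_PARAMS = Set.of("mean", "var");

    static boolean updaterDivideByMinibatch(String paramName) {
        return !DIRECT_UPDATE_PARAMS.contains(paramName);
    }

    // Turns the summed gradient from backprop into the value handed to the updater.
    static double[] prepare(String paramName, double[] summedGrad, int minibatchSize) {
        double[] out = summedGrad.clone();
        if (updaterDivideByMinibatch(paramName)) {
            for (int i = 0; i < out.length; i++) {
                out[i] /= minibatchSize;
            }
        }
        return out;
    }
}
```

With a minibatch of 4, a summed weight gradient {8.0, 4.0} becomes the average {2.0, 1.0}, while the same values for `mean` pass through unchanged.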