public abstract class BaseWrapperVertex extends Object implements GraphVertex
| Modifier and Type | Field and Description |
|---|---|
| protected GraphVertex | underlying |
| Modifier | Constructor and Description |
|---|---|
| protected | BaseWrapperVertex(GraphVertex underlying) |
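BaseWrapperVertex is an abstract wrapper: the protected constructor stores the wrapped vertex in the underlying field, and the class implements every GraphVertex method, presumably by forwarding each call to underlying so that a subclass only overrides what it needs to change. The stdlib-only Java sketch below illustrates that delegation pattern under stated assumptions. SimpleVertex, ScaleVertex, BaseWrapper and CountingWrapper are hypothetical, heavily simplified stand-ins (three methods instead of the full GraphVertex interface); they are not DL4J types.

```java
public class WrapperVertexSketch {

    // Hypothetical, simplified stand-in for GraphVertex (NOT the DL4J interface).
    interface SimpleVertex {
        String getVertexName();
        long numParams();
        double[] doForward(boolean training, double[] input);
    }

    // A plain vertex that multiplies its input by a single fixed parameter.
    static final class ScaleVertex implements SimpleVertex {
        private final String name;
        private final double scale;

        ScaleVertex(String name, double scale) {
            this.name = name;
            this.scale = scale;
        }

        @Override public String getVertexName() { return name; }
        @Override public long numParams() { return 1; }

        @Override public double[] doForward(boolean training, double[] input) {
            double[] out = new double[input.length];
            for (int i = 0; i < input.length; i++) {
                out[i] = input[i] * scale;
            }
            return out;
        }
    }

    // Analogue of a wrapper base class: keeps the wrapped vertex in a protected
    // field and forwards every call to it unchanged.
    abstract static class BaseWrapper implements SimpleVertex {
        protected final SimpleVertex underlying;

        protected BaseWrapper(SimpleVertex underlying) {
            this.underlying = underlying;
        }

        @Override public String getVertexName() { return underlying.getVertexName(); }
        @Override public long numParams() { return underlying.numParams(); }

        @Override public double[] doForward(boolean training, double[] input) {
            return underlying.doForward(training, input);
        }
    }

    // Hypothetical subclass: overrides one method, inherits the other delegations.
    static final class CountingWrapper extends BaseWrapper {
        int forwardCalls = 0;

        CountingWrapper(SimpleVertex underlying) { super(underlying); }

        @Override public double[] doForward(boolean training, double[] input) {
            forwardCalls++;                          // extra behaviour added by the wrapper
            return super.doForward(training, input); // then delegate as usual
        }
    }

    public static void main(String[] args) {
        CountingWrapper v = new CountingWrapper(new ScaleVertex("scale", 2.0));
        double[] out = v.doForward(false, new double[] {1.0, 3.0});
        // prints "scale 2.0 6.0 calls=1"
        System.out.println(v.getVertexName() + " " + out[0] + " " + out[1] + " calls=" + v.forwardCalls);
    }
}
```

The design choice this illustrates: because the base class already forwards everything, a wrapper that changes one behaviour (here, counting forward passes) stays a few lines long.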
| Modifier and Type | Method and Description |
|---|---|
| boolean | canDoBackward(): Whether the GraphVertex can do a backward pass. |
| boolean | canDoForward(): Whether the GraphVertex can do a forward pass. |
| void | clear(): Clear the internal state (if any) of the GraphVertex. |
| void | clearVertex(): Clears the input for this vertex. |
| Pair<Gradient,INDArray[]> | doBackward(boolean tbptt, LayerWorkspaceMgr workspaceMgr): Do the backward pass. |
| INDArray | doForward(boolean training, LayerWorkspaceMgr workspaceMgr): Do the forward pass using the stored inputs. |
| Pair<INDArray,MaskState> | feedForwardMaskArrays(INDArray[] maskArrays, MaskState currentMaskState, int minibatchSize) |
| TrainingConfig | getConfig() |
| INDArray | getEpsilon(): Get the epsilon/error (i.e., dL/dOutput) array previously set for this GraphVertex. |
| INDArray | getGradientsViewArray() |
| INDArray[] | getInputs(): Get the array of inputs previously set for this GraphVertex. |
| VertexIndices[] | getInputVertices(): A representation of the vertices that are inputs to this vertex (inputs during the forward pass). Specifically, if inputVertices[X].getVertexIndex() = Y and inputVertices[X].getVertexEdgeNumber() = Z, then the Zth output connection (see GraphVertex.getNumOutputConnections()) of vertex Y is the Xth input to this vertex. |
| Layer | getLayer(): Get the Layer (if any). |
| int | getNumInputArrays(): Get the number of input arrays. |
| int | getNumOutputConnections(): Get the number of outgoing connections from this GraphVertex. |
| VertexIndices[] | getOutputVertices(): A representation of the vertices that this vertex is connected to (outputs during the forward pass). Specifically, if outputVertices[X].getVertexIndex() = Y and outputVertices[X].getVertexEdgeNumber() = Z, then the Xth output of this vertex is connected to the Zth input of vertex Y. |
| int | getVertexIndex(): Get the index of the GraphVertex. |
| String | getVertexName(): Get the name/label of the GraphVertex. |
| boolean | hasLayer(): Whether the GraphVertex contains a Layer object or not. |
| boolean | isInputVertex(): Whether the GraphVertex is an input vertex. |
| boolean | isOutputVertex(): Whether the GraphVertex is an output vertex. |
| long | numParams() |
| INDArray | params() |
| Map<String,INDArray> | paramTable(boolean backpropOnly): Get the parameter table for the vertex. |
| void | setBackpropGradientsViewArray(INDArray backpropGradientsViewArray) |
| void | setEpsilon(INDArray epsilon): Set the errors (epsilon, aka dL/dActivation) for this GraphVertex. |
| void | setInput(int inputNumber, INDArray input, LayerWorkspaceMgr workspaceMgr): Set the input activations. |
| void | setInputs(INDArray... inputs): Set all inputs for this GraphVertex. |
| void | setInputVertices(VertexIndices[] inputVertices): Set the input vertices. |
| void | setLayerAsFrozen(): Only applies to layer vertices. |
| void | setOutputVertex(boolean outputVertex): Set the GraphVertex to be an output vertex. |
| void | setOutputVertices(VertexIndices[] outputVertices): Set the output vertices. |
| boolean | updaterDivideByMinibatch(String paramName): DL4J layers typically produce the sum of the gradients during the backward pass for each layer and, if required (if minibatch=true), then divide by the minibatch size. However, there are some exceptions, such as the batch norm mean/variance estimate parameters: these "gradients" are actually not gradients, but updates to be applied directly to the parameter vector. |
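To make the updaterDivideByMinibatch(String) contract concrete, here is a minimal, stdlib-only Java sketch under stated assumptions: gradients arrive already summed over the minibatch, a return value of true means the summed gradient is divided by the minibatch size when the minibatch flag is on, and the parameter names "mean" and "var" stand in for the batch norm estimate parameters. The class, the helper normalize, and the name set are hypothetical and are not DL4J's implementation.

```java
import java.util.Arrays;
import java.util.HashSet;
import java.util.Set;

public class DivideByMinibatchSketch {

    // Hypothetical names of parameters whose "gradients" are really direct updates
    // (modelled on batch norm mean/variance estimates); an assumption for illustration.
    static final Set<String> DIRECT_UPDATE_PARAMS = new HashSet<>(Arrays.asList("mean", "var"));

    // true: divide the summed gradient by the minibatch size; false: apply as-is.
    static boolean updaterDivideByMinibatch(String paramName) {
        return !DIRECT_UPDATE_PARAMS.contains(paramName);
    }

    // Turns a gradient summed over the minibatch into the value handed to the updater.
    static double[] normalize(String paramName, double[] summedGradient,
                              int minibatchSize, boolean minibatch) {
        double[] out = summedGradient.clone();
        if (minibatch && updaterDivideByMinibatch(paramName)) {
            for (int i = 0; i < out.length; i++) {
                out[i] /= minibatchSize;
            }
        }
        return out;
    }

    public static void main(String[] args) {
        double[] summed = {8.0, 4.0};
        // prints "W:    [2.0, 1.0]" (ordinary parameter: averaged over 4 examples)
        System.out.println("W:    " + Arrays.toString(normalize("W", summed, 4, true)));
        // prints "mean: [8.0, 4.0]" (direct update: left undivided)
        System.out.println("mean: " + Arrays.toString(normalize("mean", summed, 4, true)));
    }
}
```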
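The index conventions of getInputVertices() and getOutputVertices() are easy to mix up. Read together, they imply an invariant: if vertex A has outputVertices[x] pointing at vertex B with edge number z, then vertex B has inputVertices[z] pointing at vertex A with edge number x. The stdlib-only Java sketch below derives every vertex's input array from the output arrays under that reading. Edge is a hypothetical stand-in for VertexIndices (its vertexIndex and edgeNumber fields play the roles of getVertexIndex() and getVertexEdgeNumber()), and the four-vertex graph is made up for illustration.

```java
public class VertexIndexSketch {

    // Hypothetical stand-in for VertexIndices: (vertex index, edge number).
    static final class Edge {
        final int vertexIndex;
        final int edgeNumber;

        Edge(int vertexIndex, int edgeNumber) {
            this.vertexIndex = vertexIndex;
            this.edgeNumber = edgeNumber;
        }
    }

    // outputs[Y][X] = (V, Z) means: the Xth output of vertex Y feeds the Zth input of V.
    // Therefore, from V's side: inputs[V][Z] = (Y, X).
    static Edge[][] inputsFromOutputs(Edge[][] outputs, int[] numInputs) {
        Edge[][] inputs = new Edge[outputs.length][];
        for (int v = 0; v < outputs.length; v++) {
            inputs[v] = new Edge[numInputs[v]];
        }
        for (int y = 0; y < outputs.length; y++) {
            for (int x = 0; x < outputs[y].length; x++) {
                Edge out = outputs[y][x];
                inputs[out.vertexIndex][out.edgeNumber] = new Edge(y, x);
            }
        }
        return inputs;
    }

    public static void main(String[] args) {
        Edge[][] outputs = {
            {new Edge(1, 0), new Edge(2, 0)}, // vertex 0 feeds input 0 of vertices 1 and 2
            {new Edge(3, 0)},                 // vertex 1 feeds input 0 of vertex 3
            {new Edge(3, 1)},                 // vertex 2 feeds input 1 of vertex 3
            {}                                // vertex 3 has no outgoing connections
        };
        Edge[][] inputs = inputsFromOutputs(outputs, new int[] {0, 1, 1, 2});
        // prints "vertex 3 input 0 <- vertex 1, output connection 0"
        //    and "vertex 3 input 1 <- vertex 2, output connection 0"
        for (int x = 0; x < inputs[3].length; x++) {
            System.out.println("vertex 3 input " + x + " <- vertex " + inputs[3][x].vertexIndex
                + ", output connection " + inputs[3][x].edgeNumber);
        }
    }
}
```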
protected GraphVertex underlying
protected BaseWrapperVertex(GraphVertex underlying)
public String getVertexName()
Description copied from interface: GraphVertex
Get the name/label of the GraphVertex.
Specified by: getVertexName in interface GraphVertex

public int getVertexIndex()
Description copied from interface: GraphVertex
Get the index of the GraphVertex.
Specified by: getVertexIndex in interface GraphVertex

public int getNumInputArrays()
Description copied from interface: GraphVertex
Get the number of input arrays.
Specified by: getNumInputArrays in interface GraphVertex

public int getNumOutputConnections()
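Description copied from interface: GraphVertex
Get the number of outgoing connections from this GraphVertex.
Specified by: getNumOutputConnections in interface GraphVertex

public VertexIndices[] getInputVertices()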
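Description copied from interface: GraphVertex
A representation of the vertices that are inputs to this vertex (inputs during the forward pass). Specifically, if inputVertices[X].getVertexIndex() = Y and inputVertices[X].getVertexEdgeNumber() = Z, then the Zth output connection (see GraphVertex.getNumOutputConnections()) of vertex Y is the Xth input to this vertex.
Specified by: getInputVertices in interface GraphVertex

public void setInputVertices(VertexIndices[] inputVertices)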
Description copied from interface: GraphVertex
Sets the input vertices.
Specified by: setInputVertices in interface GraphVertex
See Also: GraphVertex.getInputVertices()

public VertexIndices[] getOutputVertices()
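Description copied from interface: GraphVertex
A representation of the vertices that this vertex is connected to (outputs during the forward pass). Specifically, if outputVertices[X].getVertexIndex() = Y and outputVertices[X].getVertexEdgeNumber() = Z, then the Xth output of this vertex is connected to the Zth input of vertex Y.
Specified by: getOutputVertices in interface GraphVertex

public void setOutputVertices(VertexIndices[] outputVertices)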
Description copied from interface: GraphVertex
Set the output vertices.
Specified by: setOutputVertices in interface GraphVertex
See Also: GraphVertex.getOutputVertices()

public boolean hasLayer()
Description copied from interface: GraphVertex
Whether the GraphVertex contains a Layer object or not.
Specified by: hasLayer in interface GraphVertex

public boolean isInputVertex()
Description copied from interface: GraphVertex
Whether the GraphVertex is an input vertex.
Specified by: isInputVertex in interface GraphVertex

public boolean isOutputVertex()
Description copied from interface: GraphVertex
Whether the GraphVertex is an output vertex.
Specified by: isOutputVertex in interface GraphVertex

public void setOutputVertex(boolean outputVertex)
Description copied from interface: GraphVertex
Set the GraphVertex to be an output vertex.
Specified by: setOutputVertex in interface GraphVertex

public Layer getLayer()
Description copied from interface: GraphVertex
Get the Layer (if any). Returns null if GraphVertex.hasLayer() == false.
Specified by: getLayer in interface GraphVertex

public void setInput(int inputNumber, INDArray input, LayerWorkspaceMgr workspaceMgr)
Description copied from interface: GraphVertex
Set the input activations.
Specified by: setInput in interface GraphVertex
Parameters:
inputNumber - Must be in range 0 to GraphVertex.getNumInputArrays()-1
input - The input array

public void setEpsilon(INDArray epsilon)
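Description copied from interface: GraphVertex
Set the errors (epsilon, aka dL/dActivation) for this GraphVertex.
Specified by: setEpsilon in interface GraphVertex

public void clear()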
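Description copied from interface: GraphVertex
Clear the internal state (if any) of the GraphVertex.
Specified by: clear in interface GraphVertex

public boolean canDoForward()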
Description copied from interface: GraphVertex
Whether the GraphVertex can do a forward pass.
Specified by: canDoForward in interface GraphVertex

public boolean canDoBackward()
Description copied from interface: GraphVertex
Whether the GraphVertex can do a backward pass.
Specified by: canDoBackward in interface GraphVertex

public INDArray doForward(boolean training, LayerWorkspaceMgr workspaceMgr)
Description copied from interface: GraphVertex
Do the forward pass using the stored inputs.
Specified by: doForward in interface GraphVertex
Parameters:
training - If true: forward pass at training time. If false: forward pass at test time

public Pair<Gradient,INDArray[]> doBackward(boolean tbptt, LayerWorkspaceMgr workspaceMgr)
Description copied from interface: GraphVertex
Do the backward pass.
Specified by: doBackward in interface GraphVertex
Parameters:
tbptt - If true: do backprop using truncated BPTT

public INDArray[] getInputs()
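Description copied from interface: GraphVertex
Get the array of inputs previously set for this GraphVertex.
Specified by: getInputs in interface GraphVertex

public INDArray getEpsilon()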
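Description copied from interface: GraphVertex
Get the epsilon/error (i.e., dL/dOutput) array previously set for this GraphVertex.
Specified by: getEpsilon in interface GraphVertex

public void setInputs(INDArray... inputs)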
Description copied from interface: GraphVertex
Set all inputs for this GraphVertex.
Specified by: setInputs in interface GraphVertex
See Also: GraphVertex.setInput(int, INDArray, LayerWorkspaceMgr)

public INDArray getGradientsViewArray()
Specified by: getGradientsViewArray in interface Trainable

public void setBackpropGradientsViewArray(INDArray backpropGradientsViewArray)
Specified by: setBackpropGradientsViewArray in interface GraphVertex

public Pair<INDArray,MaskState> feedForwardMaskArrays(INDArray[] maskArrays, MaskState currentMaskState, int minibatchSize)
Specified by: feedForwardMaskArrays in interface GraphVertex

public void setLayerAsFrozen()
Description copied from interface: GraphVertex
Only applies to layer vertices.
Specified by: setLayerAsFrozen in interface GraphVertex

public void clearVertex()
Description copied from interface: GraphVertex
Clears the input for this vertex.
Specified by: clearVertex in interface GraphVertex

public Map<String,INDArray> paramTable(boolean backpropOnly)
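Description copied from interface: GraphVertex
Get the parameter table for the vertex.
Specified by: paramTable in interface Trainable
Specified by: paramTable in interface GraphVertex
Parameters:
backpropOnly - If true: exclude unsupervised training parameters

public TrainingConfig getConfig()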
public INDArray params()
public long numParams()
public boolean updaterDivideByMinibatch(String paramName)
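Description copied from interface: Trainable
DL4J layers typically produce the sum of the gradients during the backward pass for each layer and, if required (if minibatch=true), then divide by the minibatch size. However, there are some exceptions, such as the batch norm mean/variance estimate parameters: these "gradients" are actually not gradients, but updates to be applied directly to the parameter vector.
Specified by: updaterDivideByMinibatch in interface Trainable
Parameters:
paramName - Name of the parameter

Copyright © 2020. All rights reserved.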