Class SameDiffGraphVertex
- java.lang.Object
-
- org.deeplearning4j.nn.graph.vertex.BaseGraphVertex
-
- org.deeplearning4j.nn.layers.samediff.SameDiffGraphVertex
-
- All Implemented Interfaces: Serializable, Trainable, GraphVertex
public class SameDiffGraphVertex extends BaseGraphVertex
- See Also:
- Serialized Form
-
-
Field Summary
Fields
Modifier and Type                   Field
protected SameDiffVertex            config
protected ExternalErrorsFunction    fn
protected INDArray                  gradients
protected Map<String,INDArray>      gradTable
protected Map<String,SDVariable>    inputVars
protected INDArray[]                maskArrays
protected String                    outputKey
protected SDVariable                outputVar
protected INDArray                  params
protected Map<String,INDArray>      paramTable
protected SameDiff                  sameDiff
-
Fields inherited from class org.deeplearning4j.nn.graph.vertex.BaseGraphVertex
dataType, epsilon, graph, inputs, inputVertices, outputVertex, outputVertices, vertexIndex, vertexName
-
-
Constructor Summary
Constructors
Constructor
SameDiffGraphVertex(SameDiffVertex config, ComputationGraph graph, String name, int vertexIndex, INDArray paramsView, boolean initParams, DataType dataType)
-
Method Summary
Modifier and Type    Method    Description
void    clearVertex()    This method clears input for this vertex
Pair<Gradient,INDArray[]>    doBackward(boolean tbptt, LayerWorkspaceMgr workspaceMgr)    Do backward pass
INDArray    doForward(boolean training, LayerWorkspaceMgr workspaceMgr)    Do forward pass using the stored inputs
protected void    doInit()
Pair<INDArray,MaskState>    feedForwardMaskArrays(INDArray[] maskArrays, MaskState currentMaskState, int minibatchSize)
TrainingConfig    getConfig()
INDArray    getGradientsViewArray()
Layer    getLayer()    Get the Layer (if any).
boolean    hasLayer()    Whether the GraphVertex contains a Layer object or not
INDArray    params()
Map<String,INDArray>    paramTable(boolean backpropOnly)    Get the parameter table for the vertex
void    setBackpropGradientsViewArray(INDArray backpropGradientsViewArray)
String    toString()
-
Methods inherited from class org.deeplearning4j.nn.graph.vertex.BaseGraphVertex
canDoBackward, canDoForward, clear, getEpsilon, getInputVertices, getNumInputArrays, getNumOutputConnections, getOutputVertices, getVertexIndex, getVertexName, isInputVertex, numParams, setEpsilon, setInput, setInputVertices, setLayerAsFrozen, setOutputVertices, updaterDivideByMinibatch
-
Methods inherited from class java.lang.Object
clone, equals, finalize, getClass, hashCode, notify, notifyAll, wait, wait, wait
-
Methods inherited from interface org.deeplearning4j.nn.graph.vertex.GraphVertex
getInputs, isOutputVertex, setInputs, setOutputVertex
-
-
-
-
Field Detail
-
config
protected SameDiffVertex config
-
sameDiff
protected SameDiff sameDiff
-
outputVar
protected SDVariable outputVar
-
fn
protected ExternalErrorsFunction fn
-
outputKey
protected String outputKey
-
inputVars
protected Map<String,SDVariable> inputVars
-
maskArrays
protected INDArray[] maskArrays
-
params
protected INDArray params
-
gradients
protected INDArray gradients
-
-
Constructor Detail
-
SameDiffGraphVertex
public SameDiffGraphVertex(SameDiffVertex config, ComputationGraph graph, String name, int vertexIndex, INDArray paramsView, boolean initParams, DataType dataType)
-
-
Method Detail
-
toString
public String toString()
- Specified by: toString in class BaseGraphVertex
-
hasLayer
public boolean hasLayer()
Description copied from interface: GraphVertex
Whether the GraphVertex contains a Layer object or not
-
getLayer
public Layer getLayer()
Description copied from interface: GraphVertex
Get the Layer (if any). Returns null if GraphVertex.hasLayer() == false
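The null contract described above means callers should check hasLayer() before using the result of getLayer(). The following is a minimal, self-contained sketch of that guard. It does not use the DL4J types: the nested Layer and GraphVertex interfaces, the two vertex classes, and the layerOf and describe helpers are all hypothetical stand-ins introduced only for illustration.

```java
import java.util.Optional;

public class VertexLayerGuard {
    /** Hypothetical stand-in for the real Layer type. */
    interface Layer { String name(); }

    /** Hypothetical stand-in exposing only the two methods discussed above. */
    interface GraphVertex {
        boolean hasLayer();
        Layer getLayer();
    }

    /** A vertex with no layer: hasLayer() is false, so getLayer() returns null. */
    static final class LayerlessVertex implements GraphVertex {
        public boolean hasLayer() { return false; }
        public Layer getLayer() { return null; }
    }

    /** A vertex that wraps a layer identified by a name. */
    static final class LayerVertex implements GraphVertex {
        private final String layerName;
        LayerVertex(String layerName) { this.layerName = layerName; }
        public boolean hasLayer() { return true; }
        public Layer getLayer() { return () -> layerName; }
    }

    /** Wraps the hasLayer()/getLayer() pair so callers never touch a raw null. */
    static Optional<Layer> layerOf(GraphVertex vertex) {
        return vertex.hasLayer() ? Optional.ofNullable(vertex.getLayer()) : Optional.empty();
    }

    static String describe(GraphVertex vertex) {
        return layerOf(vertex).map(l -> "layer: " + l.name()).orElse("no layer");
    }

    public static void main(String[] args) {
        System.out.println(describe(new LayerlessVertex()));      // prints "no layer"
        System.out.println(describe(new LayerVertex("dense0")));  // prints "layer: dense0"
    }
}
```

Wrapping the pair in an Optional is one design choice among several; a plain `if (vertex.hasLayer())` check serves the same purpose.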
-
doForward
public INDArray doForward(boolean training, LayerWorkspaceMgr workspaceMgr)
Description copied from interface: GraphVertex
Do forward pass using the stored inputs
- Parameters:
training - if true: forward pass at training time. If false: forward pass at test time
- Returns:
- The output (for example, activations) of the GraphVertex
-
doBackward
public Pair<Gradient,INDArray[]> doBackward(boolean tbptt, LayerWorkspaceMgr workspaceMgr)
Description copied from interface: GraphVertex
Do backward pass
- Parameters:
tbptt - If true: do backprop using truncated BPTT
- Returns:
- The gradients (may be null), and the errors/epsilons for all inputs to this GraphVertex
-
setBackpropGradientsViewArray
public void setBackpropGradientsViewArray(INDArray backpropGradientsViewArray)
Description copied from interface: GraphVertex
-
feedForwardMaskArrays
public Pair<INDArray,MaskState> feedForwardMaskArrays(INDArray[] maskArrays, MaskState currentMaskState, int minibatchSize)
-
doInit
protected void doInit()
-
clearVertex
public void clearVertex()
Description copied from interface: GraphVertex
This method clears input for this vertex
- Specified by: clearVertex in interface GraphVertex
- Overrides: clearVertex in class BaseGraphVertex
-
paramTable
public Map<String,INDArray> paramTable(boolean backpropOnly)
Description copied from interface: GraphVertex
Get the parameter table for the vertex
- Specified by: paramTable in interface GraphVertex
- Specified by: paramTable in interface Trainable
- Overrides: paramTable in class BaseGraphVertex
- Parameters:
backpropOnly - If true: exclude unsupervised training parameters
- Returns:
- Parameter table
-
getConfig
public TrainingConfig getConfig()
- Specified by: getConfig in interface Trainable
- Overrides: getConfig in class BaseGraphVertex
- Returns:
- Training configuration
-
params
public INDArray params()
- Specified by: params in interface Trainable
- Overrides: params in class BaseGraphVertex
- Returns:
- 1d parameter vector
-
getGradientsViewArray
public INDArray getGradientsViewArray()
- Specified by: getGradientsViewArray in interface Trainable
- Overrides: getGradientsViewArray in class BaseGraphVertex
- Returns:
- 1D gradients view array
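The "1d parameter vector" and "view array" wording suggests that the named entries of a parameter table share storage with a single flat array, so a write through one is visible through the other. This page does not state that relationship explicitly, so treat it as an assumption. The sketch below illustrates the idea with java.nio.FloatBuffer as a stand-in for INDArray; the sliceByName helper and the parameter names "W" and "b" are hypothetical and not part of this class.

```java
import java.nio.FloatBuffer;
import java.util.LinkedHashMap;
import java.util.Map;

public class FlatParamViews {
    /**
     * Splits a flat buffer into consecutive named slices that share its memory.
     * The sizes map must iterate in layout order (for example, a LinkedHashMap).
     */
    static Map<String, FloatBuffer> sliceByName(FloatBuffer flat, Map<String, Integer> sizes) {
        Map<String, FloatBuffer> table = new LinkedHashMap<>();
        int offset = 0;
        for (Map.Entry<String, Integer> entry : sizes.entrySet()) {
            // duplicate() shares content but has independent position and limit
            FloatBuffer window = flat.duplicate();
            window.position(offset);
            window.limit(offset + entry.getValue());
            // slice() yields a zero-based view over [position, limit)
            table.put(entry.getKey(), window.slice());
            offset += entry.getValue();
        }
        return table;
    }

    public static void main(String[] args) {
        FloatBuffer params = FloatBuffer.allocate(6);  // flat 1D parameter vector
        Map<String, Integer> sizes = new LinkedHashMap<>();
        sizes.put("W", 4);  // hypothetical weight block
        sizes.put("b", 2);  // hypothetical bias block

        Map<String, FloatBuffer> table = sliceByName(params, sizes);
        table.get("b").put(0, 0.5f);        // write through the named view
        System.out.println(params.get(4));  // prints 0.5: the flat vector sees the write
    }
}
```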
-
-