Class SameDiffGraphVertex
java.lang.Object
  org.deeplearning4j.nn.graph.vertex.BaseGraphVertex
    org.deeplearning4j.nn.layers.samediff.SameDiffGraphVertex
- All Implemented Interfaces:
Serializable, Trainable, GraphVertex
public class SameDiffGraphVertex extends BaseGraphVertex
- See Also:
- Serialized Form
Field Summary
Fields
protected SameDiffVertex config
protected ExternalErrorsFunction fn
protected INDArray gradients
protected Map<String,INDArray> gradTable
protected Map<String,SDVariable> inputVars
protected INDArray[] maskArrays
protected String outputKey
protected SDVariable outputVar
protected INDArray params
protected Map<String,INDArray> paramTable
protected SameDiff sameDiff
Fields inherited from class org.deeplearning4j.nn.graph.vertex.BaseGraphVertex
dataType, epsilon, graph, inputs, inputVertices, outputVertex, outputVertices, vertexIndex, vertexName
Constructor Summary
Constructors
SameDiffGraphVertex(SameDiffVertex config, ComputationGraph graph, String name, int vertexIndex, INDArray paramsView, boolean initParams, DataType dataType)
Method Summary
Methods
void clearVertex() - This method clears the input for this vertex
Pair<Gradient,INDArray[]> doBackward(boolean tbptt, LayerWorkspaceMgr workspaceMgr) - Do backward pass
INDArray doForward(boolean training, LayerWorkspaceMgr workspaceMgr) - Do forward pass using the stored inputs
protected void doInit()
Pair<INDArray,MaskState> feedForwardMaskArrays(INDArray[] maskArrays, MaskState currentMaskState, int minibatchSize)
TrainingConfig getConfig()
INDArray getGradientsViewArray()
Layer getLayer() - Get the Layer (if any).
boolean hasLayer() - Whether the GraphVertex contains a Layer object or not
INDArray params()
Map<String,INDArray> paramTable(boolean backpropOnly) - Get the parameter table for the vertex
void setBackpropGradientsViewArray(INDArray backpropGradientsViewArray)
String toString()
Methods inherited from class org.deeplearning4j.nn.graph.vertex.BaseGraphVertex
canDoBackward, canDoForward, clear, getEpsilon, getInputVertices, getNumInputArrays, getNumOutputConnections, getOutputVertices, getVertexIndex, getVertexName, isInputVertex, numParams, setEpsilon, setInput, setInputVertices, setLayerAsFrozen, setOutputVertices, updaterDivideByMinibatch
-
Methods inherited from class java.lang.Object
clone, equals, finalize, getClass, hashCode, notify, notifyAll, wait, wait, wait
-
Methods inherited from interface org.deeplearning4j.nn.graph.vertex.GraphVertex
getInputs, isOutputVertex, setInputs, setOutputVertex
Field Detail
config
protected SameDiffVertex config
-
sameDiff
protected SameDiff sameDiff
-
outputVar
protected SDVariable outputVar
-
fn
protected ExternalErrorsFunction fn
-
outputKey
protected String outputKey
-
inputVars
protected Map<String,SDVariable> inputVars
-
maskArrays
protected INDArray[] maskArrays
-
params
protected INDArray params
-
gradients
protected INDArray gradients
Constructor Detail
SameDiffGraphVertex
public SameDiffGraphVertex(SameDiffVertex config, ComputationGraph graph, String name, int vertexIndex, INDArray paramsView, boolean initParams, DataType dataType)
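The constructor receives `paramsView`, which, as its name suggests, is a view into the network's parameter storage rather than a private copy, along with an `initParams` flag. The following is a minimal JDK-only sketch of that pattern, not the DL4J implementation. It assumes a plain `double[]` as the network's flat parameter array and `java.nio.DoubleBuffer` slices as the views. The class `ParamViewSketch`, the parameter names, and the constant initial value are all hypothetical.

```java
import java.nio.DoubleBuffer;
import java.util.LinkedHashMap;
import java.util.Map;

/**
 * Hypothetical sketch (not the DL4J API): carve named views out of one flat
 * parameter array. Values are written only when initParams is true; otherwise
 * whatever is already in the array (e.g. restored values) is left untouched.
 */
class ParamViewSketch {
    final Map<String, DoubleBuffer> paramTable = new LinkedHashMap<>();

    ParamViewSketch(double[] flatParams, int offset, Map<String, Integer> sizes, boolean initParams) {
        int pos = offset;
        for (Map.Entry<String, Integer> e : sizes.entrySet()) {
            // wrap(array, offset, length).slice() shares storage with flatParams
            DoubleBuffer view = DoubleBuffer.wrap(flatParams, pos, e.getValue()).slice();
            if (initParams) {
                for (int i = 0; i < view.capacity(); i++) {
                    view.put(i, 0.01); // placeholder value; real init schemes differ
                }
            }
            paramTable.put(e.getKey(), view);
            pos += e.getValue();
        }
    }
}
```

Because each slice shares storage with the flat array, a value written through `paramTable.get("b")` is immediately visible in the flat array. A view-based design relies on exactly that property.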
Method Detail
toString
public String toString()
- Specified by:
toString in class BaseGraphVertex
-
hasLayer
public boolean hasLayer()
Description copied from interface: GraphVertex
Whether the GraphVertex contains a Layer object or not.
-
getLayer
public Layer getLayer()
Description copied from interface: GraphVertex
Get the Layer (if any). Returns null if GraphVertex.hasLayer() == false.
-
doForward
public INDArray doForward(boolean training, LayerWorkspaceMgr workspaceMgr)
Description copied from interface: GraphVertex
Do forward pass using the stored inputs.
- Parameters:
training - if true: forward pass at training time. If false: forward pass at test time
- Returns:
The output (for example, activations) of the GraphVertex
-
doBackward
public Pair<Gradient,INDArray[]> doBackward(boolean tbptt, LayerWorkspaceMgr workspaceMgr)
Description copied from interface: GraphVertex
Do backward pass.
- Parameters:
tbptt - If true: do backprop using truncated BPTT
- Returns:
The gradients (may be null), and the errors/epsilons for all inputs to this GraphVertex
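The doForward/doBackward contract can be illustrated with a toy, parameter-free vertex. The forward pass computes from the stored inputs, and the backward pass returns gradients plus one error (epsilon) per input. This is a JDK-only sketch, not SameDiffGraphVertex's implementation. `ProductVertexSketch` is hypothetical and uses `double[]` in place of INDArray. It also omits the workspace manager, the training flag, and TBPTT.

```java
/**
 * Hypothetical sketch: a vertex with two inputs computing out = a * b
 * elementwise. It has no parameters, so backward returns only the epsilons
 * (dL/dInput), one array per input, in input order.
 */
class ProductVertexSketch {
    private double[][] inputs;  // stored inputs, set before the forward pass
    private double[] epsilon;   // dL/dOutput, set before the backward pass

    void setInputs(double[] a, double[] b) { inputs = new double[][] {a, b}; }

    void setEpsilon(double[] eps) { epsilon = eps; }

    /** Forward pass using the stored inputs. */
    double[] doForward() {
        double[] a = inputs[0], b = inputs[1];
        double[] out = new double[a.length];
        for (int i = 0; i < a.length; i++) {
            out[i] = a[i] * b[i];
        }
        return out;
    }

    /** Backward pass: d(a*b)/da = b and d(a*b)/db = a, each scaled by epsilon. */
    double[][] doBackward() {
        double[] a = inputs[0], b = inputs[1];
        double[] epsA = new double[a.length];
        double[] epsB = new double[b.length];
        for (int i = 0; i < a.length; i++) {
            epsA[i] = epsilon[i] * b[i];
            epsB[i] = epsilon[i] * a[i];
        }
        return new double[][] {epsA, epsB};
    }

    /** Rough analogue of clearVertex(): drop references to stored arrays. */
    void clear() {
        inputs = null;
        epsilon = null;
    }
}
```

A vertex with trainable parameters would also populate the gradient half of the returned pair. In this toy case, that half would be empty.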
-
setBackpropGradientsViewArray
public void setBackpropGradientsViewArray(INDArray backpropGradientsViewArray)
Description copied from interface: GraphVertex
-
feedForwardMaskArrays
public Pair<INDArray,MaskState> feedForwardMaskArrays(INDArray[] maskArrays, MaskState currentMaskState, int minibatchSize)
-
doInit
protected void doInit()
-
clearVertex
public void clearVertex()
Description copied from interface: GraphVertex
This method clears the input for this vertex.
- Specified by:
clearVertex in interface GraphVertex
- Overrides:
clearVertex in class BaseGraphVertex
-
paramTable
public Map<String,INDArray> paramTable(boolean backpropOnly)
Description copied from interface: GraphVertex
Get the parameter table for the vertex.
- Specified by:
paramTable in interface GraphVertex
- Specified by:
paramTable in interface Trainable
- Overrides:
paramTable in class BaseGraphVertex
- Parameters:
backpropOnly - If true: exclude unsupervised training parameters
- Returns:
Parameter table
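The `backpropOnly` flag means that parameters used only for unsupervised training are left out of the returned table. The following is a minimal JDK-only sketch of that filtering. It assumes parameters are tracked by name and that the set of unsupervised-only names is known up front. `ParamTableSketch` and the names `W`, `b`, `vb` are hypothetical, and this page does not say whether SameDiffGraphVertex has any such parameters.

```java
import java.util.Collections;
import java.util.LinkedHashMap;
import java.util.Map;
import java.util.Set;

/**
 * Hypothetical sketch of the backpropOnly contract: names in pretrainOnly
 * are excluded from the table when backpropOnly is true.
 */
class ParamTableSketch {
    private final Map<String, double[]> params = new LinkedHashMap<>();
    private final Set<String> pretrainOnly;

    ParamTableSketch(Map<String, double[]> params, Set<String> pretrainOnly) {
        this.params.putAll(params);
        this.pretrainOnly = pretrainOnly;
    }

    Map<String, double[]> paramTable(boolean backpropOnly) {
        if (!backpropOnly) {
            return Collections.unmodifiableMap(params);
        }
        Map<String, double[]> filtered = new LinkedHashMap<>();
        for (Map.Entry<String, double[]> e : params.entrySet()) {
            if (!pretrainOnly.contains(e.getKey())) {
                // Same array object, not a copy, so the entry stays "live"
                filtered.put(e.getKey(), e.getValue());
            }
        }
        return filtered;
    }
}
```

Returning the stored arrays rather than copies keeps each table entry pointing at the live parameter storage.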
-
getConfig
public TrainingConfig getConfig()
- Specified by:
getConfig in interface Trainable
- Overrides:
getConfig in class BaseGraphVertex
- Returns:
Training configuration
-
params
public INDArray params()
- Specified by:
params in interface Trainable
- Overrides:
params in class BaseGraphVertex
- Returns:
1D parameter vector
-
getGradientsViewArray
public INDArray getGradientsViewArray()
- Specified by:
getGradientsViewArray in interface Trainable
- Overrides:
getGradientsViewArray in class BaseGraphVertex
- Returns:
1D gradients view array
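`getGradientsViewArray()` returns a single 1D gradients view array, and `setBackpropGradientsViewArray` supplies one. The following JDK-only sketch shows how per-parameter gradient views can share that flat array. It assumes `double[]` storage and `DoubleBuffer` slices. `GradientViewSketch`, `attachGradientView`, and `writeGradient` are hypothetical names. Unlike the real method, `attachGradientView` also takes the parameter sizes.

```java
import java.nio.DoubleBuffer;
import java.util.LinkedHashMap;
import java.util.Map;

/**
 * Hypothetical sketch (not the DL4J API): one flat gradient array plus named
 * views into it. Writes through a named view land in the flat array.
 */
class GradientViewSketch {
    private DoubleBuffer gradientsView;
    private final Map<String, DoubleBuffer> gradTable = new LinkedHashMap<>();

    void attachGradientView(double[] flatGrads, Map<String, Integer> sizes) {
        gradientsView = DoubleBuffer.wrap(flatGrads); // whole array, shared storage
        int pos = 0;
        for (Map.Entry<String, Integer> e : sizes.entrySet()) {
            gradTable.put(e.getKey(), DoubleBuffer.wrap(flatGrads, pos, e.getValue()).slice());
            pos += e.getValue();
        }
    }

    /** Stand-in for a backward pass writing one gradient entry. */
    void writeGradient(String name, int index, double value) {
        gradTable.get(name).put(index, value);
    }

    DoubleBuffer getGradientsView() {
        return gradientsView;
    }
}
```

With this layout, an optimizer can read or update the whole flat array at once, while the backward pass writes each gradient through its named view.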