Class SameDiffVertex
- java.lang.Object
  - org.deeplearning4j.nn.conf.graph.GraphVertex
    - org.deeplearning4j.nn.conf.layers.samediff.SameDiffVertex
- All Implemented Interfaces:
  Serializable, Cloneable, TrainingConfig
- Direct Known Subclasses:
  AttentionVertex, SameDiffLambdaVertex

public abstract class SameDiffVertex extends GraphVertex implements TrainingConfig
- See Also:
  Serialized Form
-
-
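For illustration, a minimal concrete subclass might look like the sketch below. It implements the three abstract methods as a small dense (fully connected) transform with a tanh activation; the input/parameter names ("in", "W", "b") and the fixed sizes (nIn=4, nOut=3) are illustrative choices for this sketch, not requirements of the API.

```java
import java.util.Map;

import org.deeplearning4j.nn.conf.layers.samediff.SDVertexParams;
import org.deeplearning4j.nn.conf.layers.samediff.SameDiffVertex;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;

// Sketch: a dense-style SameDiffVertex with one input, one weight and one bias.
public class MinimalDenseVertex extends SameDiffVertex {

    private static final int N_IN = 4;
    private static final int N_OUT = 3;

    @Override
    public void defineParametersAndInputs(SDVertexParams params) {
        params.defineInputs("in");                 // one named input to this vertex
        params.addWeightParam("W", N_IN, N_OUT);   // weight matrix [nIn, nOut]
        params.addBiasParam("b", 1, N_OUT);        // bias row vector [1, nOut]
    }

    @Override
    public void initializeParameters(Map<String, INDArray> params) {
        params.get("b").assign(0.0);                                 // zero bias
        params.get("W").assign(Nd4j.randn(N_IN, N_OUT).muli(0.01));  // small random weights
    }

    @Override
    public SDVariable defineVertex(SameDiff sameDiff, Map<String, SDVariable> layerInput,
            Map<String, SDVariable> paramTable, Map<String, SDVariable> maskVars) {
        SDVariable in = layerInput.get("in");
        return sameDiff.math().tanh(in.mmul(paramTable.get("W")).add(paramTable.get("b")));
    }
}
```

Such a vertex could then be added to a graph via the GraphBuilder's addVertex(name, vertex, inputs...) method; note that, depending on the DL4J version, getOutputType(int, InputType...) may also need to be overridden so the graph can infer the output shape.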
Field Summary
- protected IUpdater biasUpdater
- protected DataType dataType
- protected GradientNormalization gradientNormalization
- protected double gradientNormalizationThreshold
- protected List<Regularization> regularization
- protected List<Regularization> regularizationBias
- protected IUpdater updater
-
Constructor Summary
- SameDiffVertex()
-
Method Summary
- void applyGlobalConfig(NeuralNetConfiguration.Builder b)
- void applyGlobalConfigToLayer(NeuralNetConfiguration.Builder globalConfig)
- abstract void defineParametersAndInputs(SDVertexParams params) - Define the parameters and inputs for the network.
- abstract SDVariable defineVertex(SameDiff sameDiff, Map<String,SDVariable> layerInput, Map<String,SDVariable> paramTable, Map<String,SDVariable> maskVars) - Define the vertex.
- Pair<INDArray,MaskState> feedForwardMaskArrays(INDArray[] maskArrays, MaskState currentMaskState, int minibatchSize)
- GradientNormalization getGradientNormalization()
- double getGradientNormalizationThreshold()
- String getLayerName()
- MemoryReport getMemoryReport(InputType... inputTypes) - Report of the estimated memory consumption for the given vertex.
- InputType getOutputType(int layerIndex, InputType... vertexInputs) - Determine the type of output for this GraphVertex, given the specified inputs.
- List<Regularization> getRegularizationByParam(String paramName) - Get the regularization types (l1/l2/weight decay) for the given parameter.
- IUpdater getUpdaterByParam(String paramName) - Get the updater for the given parameter.
- SDVertexParams getVertexParams()
- abstract void initializeParameters(Map<String,INDArray> params) - Set the initial parameter values for this layer, if required.
- GraphVertex instantiate(ComputationGraph graph, String name, int idx, INDArray paramsView, boolean initializeParams, DataType networkDatatype) - Create a GraphVertex instance for the given computation graph, given the configuration instance.
- boolean isPretrainParam(String paramName) - Is the specified parameter a layerwise pretraining only parameter?
- int maxVertexInputs()
- int minVertexInputs()
- long numParams(boolean backprop)
- char paramReshapeOrder(String paramName)
- void setDataType(DataType dataType)
- void validateInput(INDArray[] input) - Validate input arrays to confirm that they fulfill the assumptions of the layer.
Methods inherited from class org.deeplearning4j.nn.conf.graph.GraphVertex
clone, equals, hashCode
-
Field Detail
-
regularization
protected List<Regularization> regularization
-
regularizationBias
protected List<Regularization> regularizationBias
-
updater
protected IUpdater updater
-
biasUpdater
protected IUpdater biasUpdater
-
gradientNormalization
protected GradientNormalization gradientNormalization
-
gradientNormalizationThreshold
protected double gradientNormalizationThreshold
-
dataType
protected DataType dataType
-
-
Method Detail
-
defineVertex
public abstract SDVariable defineVertex(SameDiff sameDiff, Map<String,SDVariable> layerInput, Map<String,SDVariable> paramTable, Map<String,SDVariable> maskVars)
Define the vertex.
- Parameters:
  sameDiff - SameDiff instance
  layerInput - Input to the layer; keys as defined by defineParametersAndInputs(SDVertexParams)
  paramTable - Parameter table; keys as defined by defineParametersAndInputs(SDVertexParams)
  maskVars - Masks of input, if available; keys as defined by defineParametersAndInputs(SDVertexParams)
- Returns:
  The final layer variable corresponding to the activations/output from the forward pass
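As a sketch of how the mask map might be consulted inside defineVertex, the parameter-free vertex below doubles its input and zeroes out masked positions when a mask is supplied. The input name "in" and the doubling operation are illustrative assumptions, not part of the API.

```java
import java.util.Map;

import org.deeplearning4j.nn.conf.layers.samediff.SDVertexParams;
import org.deeplearning4j.nn.conf.layers.samediff.SameDiffVertex;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.linalg.api.ndarray.INDArray;

// Sketch: a vertex with no trainable parameters whose defineVertex uses maskVars.
public class MaskedScaleVertex extends SameDiffVertex {

    @Override
    public void defineParametersAndInputs(SDVertexParams params) {
        params.defineInputs("in");   // one input, no weight/bias parameters
    }

    @Override
    public void initializeParameters(Map<String, INDArray> params) {
        // No parameters to initialize
    }

    @Override
    public SDVariable defineVertex(SameDiff sameDiff, Map<String, SDVariable> layerInput,
            Map<String, SDVariable> paramTable, Map<String, SDVariable> maskVars) {
        SDVariable in = layerInput.get("in").mul(2.0);   // the "real" computation (a sketch)
        SDVariable mask = maskVars == null ? null : maskVars.get("in");
        // Masks are only present for some input types (e.g. padded sequences);
        // zero out masked positions when one is supplied.
        return mask == null ? in : in.mul(mask);
    }
}
```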
-
defineParametersAndInputs
public abstract void defineParametersAndInputs(SDVertexParams params)
Define the parameters and inputs for the network. Use SDLayerParams.addWeightParam(String, long...) and SDLayerParams.addBiasParam(String, long...). Note also that you must define (and optionally name) the inputs to the vertex. This is required so that DL4J knows how many inputs exist for the vertex.
- Parameters:
  params - Object used to set parameters for this layer
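A vertex may declare more than one named input via defineInputs; the order of the names determines the order of the vertex inputs. A sketch of a two-input, parameter-free vertex (the names "first"/"second" and the averaging operation are illustrative assumptions):

```java
import java.util.Map;

import org.deeplearning4j.nn.conf.layers.samediff.SDVertexParams;
import org.deeplearning4j.nn.conf.layers.samediff.SameDiffVertex;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.linalg.api.ndarray.INDArray;

// Sketch: a vertex with two named inputs and no trainable parameters.
public class PairAverageVertex extends SameDiffVertex {

    @Override
    public void defineParametersAndInputs(SDVertexParams params) {
        params.defineInputs("first", "second");   // two inputs, in this order
    }

    @Override
    public void initializeParameters(Map<String, INDArray> params) {
        // No parameters to initialize
    }

    @Override
    public SDVariable defineVertex(SameDiff sameDiff, Map<String, SDVariable> layerInput,
            Map<String, SDVariable> paramTable, Map<String, SDVariable> maskVars) {
        // Element-wise average of the two inputs
        return layerInput.get("first").add(layerInput.get("second")).mul(0.5);
    }
}
```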
-
initializeParameters
public abstract void initializeParameters(Map<String,INDArray> params)
Set the initial parameter values for this layer, if required.
- Parameters:
  params - Parameter arrays that may be initialized
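The arrays passed in are typically views of the network's flattened parameter array, so values should be written in place with assign(...) rather than by replacing the map entries. A sketch of a Xavier/Glorot-style uniform initialization (the key names "W"/"b" and the nIn=4, nOut=3 shapes are assumptions of this example, matching its own defineParametersAndInputs):

```java
import java.util.Map;

import org.deeplearning4j.nn.conf.layers.samediff.SDVertexParams;
import org.deeplearning4j.nn.conf.layers.samediff.SameDiffVertex;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;

// Sketch: initializeParameters writing Xavier-style weights and a zero bias in place.
public class XavierInitVertex extends SameDiffVertex {

    private static final int N_IN = 4;
    private static final int N_OUT = 3;

    @Override
    public void defineParametersAndInputs(SDVertexParams params) {
        params.defineInputs("in");
        params.addWeightParam("W", N_IN, N_OUT);
        params.addBiasParam("b", 1, N_OUT);
    }

    @Override
    public void initializeParameters(Map<String, INDArray> params) {
        params.get("b").assign(0.0);                       // zero bias, written in place
        double limit = Math.sqrt(6.0 / (N_IN + N_OUT));    // Xavier/Glorot uniform bound
        // Uniform samples in [-limit, limit], assigned into the existing view
        params.get("W").assign(Nd4j.rand(N_IN, N_OUT).muli(2 * limit).subi(limit));
    }

    @Override
    public SDVariable defineVertex(SameDiff sameDiff, Map<String, SDVariable> layerInput,
            Map<String, SDVariable> paramTable, Map<String, SDVariable> maskVars) {
        return layerInput.get("in").mmul(paramTable.get("W")).add(paramTable.get("b"));
    }
}
```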
-
getVertexParams
public SDVertexParams getVertexParams()
-
numParams
public long numParams(boolean backprop)
- Specified by:
  numParams in class GraphVertex
-
minVertexInputs
public int minVertexInputs()
- Specified by:
  minVertexInputs in class GraphVertex
- Returns:
  The smallest valid number of inputs to this vertex
-
maxVertexInputs
public int maxVertexInputs()
- Specified by:
  maxVertexInputs in class GraphVertex
- Returns:
  The largest valid number of inputs to this vertex
-
instantiate
public GraphVertex instantiate(ComputationGraph graph, String name, int idx, INDArray paramsView, boolean initializeParams, DataType networkDatatype)
Description copied from class: GraphVertex
Create a GraphVertex instance for the given computation graph, given the configuration instance.
- Specified by:
  instantiate in class GraphVertex
- Parameters:
  graph - The computation graph that this GraphVertex is to be part of
  name - The name of the GraphVertex object
  idx - The index of the GraphVertex
  paramsView - A view of the full parameters array
  initializeParams - If true: initialize the parameters. If false: make no change to the values in the paramsView array
- Returns:
  The implementation GraphVertex object (i.e., the implementation, not the configuration)
-
getOutputType
public InputType getOutputType(int layerIndex, InputType... vertexInputs) throws InvalidInputTypeException
Description copied from class: GraphVertex
Determine the type of output for this GraphVertex, given the specified inputs. Given that a GraphVertex may do arbitrary processing or modifications of the inputs, the output types can be quite different from the input type(s). This is generally used to determine when to add preprocessors, as well as the input sizes etc. for layers.
- Specified by:
  getOutputType in class GraphVertex
- Parameters:
  layerIndex - The index of the layer (if appropriate/necessary)
  vertexInputs - The inputs to this vertex
- Returns:
  The type of output for this vertex
- Throws:
  InvalidInputTypeException - If the input type is invalid for this type of GraphVertex
-
feedForwardMaskArrays
public Pair<INDArray,MaskState> feedForwardMaskArrays(INDArray[] maskArrays, MaskState currentMaskState, int minibatchSize)
-
validateInput
public void validateInput(INDArray[] input)
Validate input arrays to confirm that they fulfill the assumptions of the layer. If they don't, throw an exception.
- Parameters:
  input - Inputs to the layer
-
getMemoryReport
public MemoryReport getMemoryReport(InputType... inputTypes)
Description copied from class: GraphVertex
This is a report of the estimated memory consumption for the given vertex.
- Specified by:
  getMemoryReport in class GraphVertex
- Parameters:
  inputTypes - Input types to the vertex. Memory consumption is often a function of the input type
- Returns:
  Memory report for the vertex
-
paramReshapeOrder
public char paramReshapeOrder(String paramName)
-
applyGlobalConfig
public void applyGlobalConfig(NeuralNetConfiguration.Builder b)
-
applyGlobalConfigToLayer
public void applyGlobalConfigToLayer(NeuralNetConfiguration.Builder globalConfig)
-
getLayerName
public String getLayerName()
- Specified by:
  getLayerName in interface TrainingConfig
- Returns:
- Name of the layer
-
getRegularizationByParam
public List<Regularization> getRegularizationByParam(String paramName)
Description copied from interface: TrainingConfig
Get the regularization types (l1/l2/weight decay) for the given parameter. Different parameters may have different regularization types.
- Specified by:
  getRegularizationByParam in interface TrainingConfig
- Parameters:
  paramName - Parameter name ("W", "b" etc.)
- Returns:
  Regularization types (if any) for the specified parameter
-
isPretrainParam
public boolean isPretrainParam(String paramName)
Description copied from interface: TrainingConfig
Is the specified parameter a layerwise pretraining only parameter? For example, visible bias params in an autoencoder (or decoder params in a variational autoencoder) aren't used during supervised backprop. Layers (like DenseLayer, etc.) with no pretrainable parameters will return false for all (valid) inputs.
- Specified by:
  isPretrainParam in interface TrainingConfig
- Parameters:
  paramName - Parameter name/key
- Returns:
  True if the parameter is for layerwise pretraining only, false otherwise
-
getUpdaterByParam
public IUpdater getUpdaterByParam(String paramName)
Description copied from interface: TrainingConfig
Get the updater for the given parameter. Typically the same updater will be used for all parameters, but this is not necessarily the case.
- Specified by:
  getUpdaterByParam in interface TrainingConfig
- Parameters:
  paramName - Parameter name
- Returns:
  IUpdater for the parameter
-
getGradientNormalization
public GradientNormalization getGradientNormalization()
- Specified by:
  getGradientNormalization in interface TrainingConfig
- Returns:
- The gradient normalization configuration
-
getGradientNormalizationThreshold
public double getGradientNormalizationThreshold()
- Specified by:
  getGradientNormalizationThreshold in interface TrainingConfig
- Returns:
- The gradient normalization threshold
-
setDataType
public void setDataType(DataType dataType)
- Specified by:
  setDataType in interface TrainingConfig
- Overrides:
  setDataType in class GraphVertex
-