public abstract class SameDiffVertex extends GraphVertex implements TrainingConfig
See also: SameDiffLayer, SameDiffOutputLayer, Serialized Form
Modifier and Type | Field and Description |
---|---|
protected IUpdater | biasUpdater |
protected DataType | dataType |
protected GradientNormalization | gradientNormalization |
protected double | gradientNormalizationThreshold |
protected List<Regularization> | regularization |
protected List<Regularization> | regularizationBias |
protected IUpdater | updater |
Constructor and Description |
---|
SameDiffVertex() |
Modifier and Type | Method and Description |
---|---|
void | applyGlobalConfig(NeuralNetConfiguration.Builder b) |
void | applyGlobalConfigToLayer(NeuralNetConfiguration.Builder globalConfig) |
abstract void | defineParametersAndInputs(SDVertexParams params) Define the parameters - and inputs - for the network. |
abstract SDVariable | defineVertex(SameDiff sameDiff, Map<String,SDVariable> layerInput, Map<String,SDVariable> paramTable, Map<String,SDVariable> maskVars) Define the vertex. |
Pair<INDArray,MaskState> | feedForwardMaskArrays(INDArray[] maskArrays, MaskState currentMaskState, int minibatchSize) |
GradientNormalization | getGradientNormalization() |
double | getGradientNormalizationThreshold() |
String | getLayerName() |
MemoryReport | getMemoryReport(InputType... inputTypes) This is a report of the estimated memory consumption for the given vertex. |
InputType | getOutputType(int layerIndex, InputType... vertexInputs) Determine the type of output for this GraphVertex, given the specified inputs. |
List<Regularization> | getRegularizationByParam(String paramName) Get the regularization types (l1/l2/weight decay) for the given parameter. |
IUpdater | getUpdaterByParam(String paramName) Get the updater for the given parameter. |
SDVertexParams | getVertexParams() |
abstract void | initializeParameters(Map<String,INDArray> params) Set the initial parameter values for this layer, if required. |
GraphVertex | instantiate(ComputationGraph graph, String name, int idx, INDArray paramsView, boolean initializeParams, DataType networkDatatype) Create a GraphVertex instance, for the given computation graph, given the configuration instance. |
boolean | isPretrainParam(String paramName) Is the specified parameter a layerwise pretraining only parameter? For example, visible bias params in an autoencoder (or decoder params in a variational autoencoder) aren't used during supervised backprop. Layers (like DenseLayer, etc.) with no pretrainable parameters will return false for all (valid) inputs. |
int | maxVertexInputs() |
int | minVertexInputs() |
long | numParams(boolean backprop) |
char | paramReshapeOrder(String paramName) |
void | setDataType(DataType dataType) |
void | validateInput(INDArray[] input) Validate input arrays to confirm that they fulfill the assumptions of the layer. |
Methods inherited from class GraphVertex: clone, equals, hashCode
protected List<Regularization> regularization
protected List<Regularization> regularizationBias
protected IUpdater updater
protected IUpdater biasUpdater
protected GradientNormalization gradientNormalization
protected double gradientNormalizationThreshold
protected DataType dataType
public abstract SDVariable defineVertex(SameDiff sameDiff, Map<String,SDVariable> layerInput, Map<String,SDVariable> paramTable, Map<String,SDVariable> maskVars)
Define the vertex.
Parameters:
sameDiff - SameDiff instance
layerInput - Input to the layer - keys as defined by defineParametersAndInputs(SDVertexParams)
paramTable - Parameter table - keys as defined by defineParametersAndInputs(SDVertexParams)
maskVars - Masks of input, if available - keys as defined by defineParametersAndInputs(SDVertexParams)
public abstract void defineParametersAndInputs(SDVertexParams params)
Define the parameters - and inputs - for the network. Use SDLayerParams.addWeightParam(String, long...) and SDLayerParams.addBiasParam(String, long...).
Note also that you must define (and optionally name) the inputs to the vertex. This is required so that DL4J knows how many inputs exist for the vertex.
Parameters:
params - Object used to set parameters for this layer
public abstract void initializeParameters(Map<String,INDArray> params)
Set the initial parameter values for this layer, if required.
Parameters:
params - Parameter arrays that may be initialized
public SDVertexParams getVertexParams()
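As an illustration of the declaration step described for defineParametersAndInputs, the following plain-Java sketch models the bookkeeping involved. It is a hypothetical stand-in, not DL4J code: the class VertexParamSpecSketch, its defineInputs method, and the element-counting rule in totalElements are assumptions made for illustration. Only the addWeightParam and addBiasParam names come from the text above. The point it tries to show is that the names registered in this step are the keys later expected in the layerInput and paramTable maps passed to defineVertex.

```java
import java.util.LinkedHashMap;
import java.util.LinkedHashSet;
import java.util.Map;
import java.util.Set;

// Hypothetical stand-in for the declarations a vertex makes; NOT a DL4J class.
final class VertexParamSpecSketch {
    private final Set<String> inputs = new LinkedHashSet<>();
    private final Map<String, long[]> weightShapes = new LinkedHashMap<>();
    private final Map<String, long[]> biasShapes = new LinkedHashMap<>();

    // Name the inputs so that the number of inputs is known up front (assumed method name).
    void defineInputs(String... names) {
        for (String n : names) {
            if (!inputs.add(n)) {
                throw new IllegalArgumentException("Duplicate input name: " + n);
            }
        }
    }

    // Register a weight parameter under a key, with its shape.
    void addWeightParam(String key, long... shape) {
        weightShapes.put(key, shape.clone());
    }

    // Register a bias parameter under a key, with its shape.
    void addBiasParam(String key, long... shape) {
        biasShapes.put(key, shape.clone());
    }

    // Keys a layerInput-style map would be expected to contain.
    Set<String> inputKeys() {
        return inputs;
    }

    // Keys a paramTable-style map would be expected to contain.
    Set<String> paramKeys() {
        Set<String> keys = new LinkedHashSet<>(weightShapes.keySet());
        keys.addAll(biasShapes.keySet());
        return keys;
    }

    // Assumed counting rule: sum, over all parameters, of the product of each shape.
    long totalElements() {
        long total = 0;
        for (long[] s : weightShapes.values()) total += product(s);
        for (long[] s : biasShapes.values()) total += product(s);
        return total;
    }

    private static long product(long[] shape) {
        long p = 1;
        for (long d : shape) p *= d;
        return p;
    }

    public static void main(String[] args) {
        VertexParamSpecSketch spec = new VertexParamSpecSketch();
        spec.defineInputs("in1", "in2");   // two named inputs
        spec.addWeightParam("W", 4, 3);    // 4 x 3 weight matrix
        spec.addBiasParam("b", 1, 3);      // 1 x 3 bias row
        System.out.println(spec.inputKeys() + " " + spec.paramKeys() + " " + spec.totalElements());
    }
}
```

Insertion-ordered collections are used here so that keys come back in declaration order; whether DL4J preserves that order is not stated in this page.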
public long numParams(boolean backprop)
numParams in class GraphVertex
public int minVertexInputs()
minVertexInputs in class GraphVertex
public int maxVertexInputs()
maxVertexInputs in class GraphVertex
public GraphVertex instantiate(ComputationGraph graph, String name, int idx, INDArray paramsView, boolean initializeParams, DataType networkDatatype)
Description copied from class: GraphVertex
Create a GraphVertex instance, for the given computation graph, given the configuration instance.
instantiate in class GraphVertex
Parameters:
graph - The computation graph that this GraphVertex is to be part of
name - The name of the GraphVertex object
idx - The index of the GraphVertex
paramsView - A view of the full parameters array
initializeParams - If true: initialize the parameters. If false: make no change to the values in the paramsView array
public InputType getOutputType(int layerIndex, InputType... vertexInputs) throws InvalidInputTypeException
Description copied from class: GraphVertex
Determine the type of output for this GraphVertex, given the specified inputs.
getOutputType in class GraphVertex
Parameters:
layerIndex - The index of the layer (if appropriate/necessary).
vertexInputs - The inputs to this vertex
Throws:
InvalidInputTypeException - If the input type is invalid for this type of GraphVertex
public Pair<INDArray,MaskState> feedForwardMaskArrays(INDArray[] maskArrays, MaskState currentMaskState, int minibatchSize)
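The paramsView and initializeParams arguments of instantiate, described above, can be pictured with a small plain-Java sketch. It is not DL4J code: a java.nio.DoubleBuffer slice stands in for an INDArray view, and the class ParamsViewSketch, the attach method, and the placeholder initial values are all hypothetical. The idea it illustrates is that a view shares storage with the full parameters array, so initializing through the view changes the full array, while passing false leaves the existing values alone.

```java
import java.nio.DoubleBuffer;

// Hypothetical stand-in, NOT DL4J code: a DoubleBuffer slice plays the role of a parameter view.
final class ParamsViewSketch {

    // Returns a view over flat[offset, offset + length) that shares storage with flat.
    static DoubleBuffer viewOf(double[] flat, int offset, int length) {
        return DoubleBuffer.wrap(flat, offset, length).slice();
    }

    // Attaches a vertex to its slice of the full parameter array.
    // Values are written only when initializeParams is true.
    static DoubleBuffer attach(double[] flat, int offset, int length, boolean initializeParams) {
        DoubleBuffer view = viewOf(flat, offset, length);
        if (initializeParams) {
            for (int i = 0; i < view.capacity(); i++) {
                view.put(i, 0.01 * (i + 1)); // placeholder initial values
            }
        }
        return view;
    }

    public static void main(String[] args) {
        double[] allParams = {9, 9, 9, 9, 9, 9};
        attach(allParams, 2, 3, false); // existing values untouched
        System.out.println(java.util.Arrays.toString(allParams));
        attach(allParams, 2, 3, true);  // elements 2..4 overwritten through the view
        System.out.println(java.util.Arrays.toString(allParams));
    }
}
```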
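public void validateInput(INDArray[] input)
Validate input arrays to confirm that they fulfill the assumptions of the layer.
Parameters:
input - inputs to the layer
public MemoryReport getMemoryReport(InputType... inputTypes)
Description copied from class: GraphVertex
This is a report of the estimated memory consumption for the given vertex.
getMemoryReport in class GraphVertex
Parameters:
inputTypes - Input types to the vertex. Memory consumption is often a function of the input type
public char paramReshapeOrder(String paramName)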
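paramReshapeOrder carries no description on this page. Assuming it returns an array-order character in the usual ND4J sense ('c' for row-major, 'f' for column-major), the plain-Java sketch below shows what that choice means when a flat parameter vector is interpreted as a matrix. The class ReshapeOrderSketch and its reshape method are hypothetical and are not part of DL4J.

```java
import java.util.Arrays;

// Hypothetical stand-in, NOT DL4J code: shows the effect of 'c' versus 'f' reshape order.
final class ReshapeOrderSketch {

    // Interprets a flat vector as a rows x cols matrix.
    // 'c' reads it row by row (row-major); 'f' reads it column by column (column-major).
    static double[][] reshape(double[] flat, int rows, int cols, char order) {
        if (flat.length != rows * cols) {
            throw new IllegalArgumentException("Length " + flat.length + " does not match " + rows + " x " + cols);
        }
        if (order != 'c' && order != 'f') {
            throw new IllegalArgumentException("Order must be 'c' or 'f', got: " + order);
        }
        double[][] out = new double[rows][cols];
        for (int r = 0; r < rows; r++) {
            for (int c = 0; c < cols; c++) {
                int idx = (order == 'c') ? r * cols + c : c * rows + r;
                out[r][c] = flat[idx];
            }
        }
        return out;
    }

    public static void main(String[] args) {
        double[] flat = {1, 2, 3, 4, 5, 6};
        System.out.println(Arrays.deepToString(reshape(flat, 2, 3, 'c')));
        System.out.println(Arrays.deepToString(reshape(flat, 2, 3, 'f')));
    }
}
```

The same six values produce different matrices under the two orders, which is why the order matters when parameters are compared against, or copied from, another layer implementation.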
public void applyGlobalConfig(NeuralNetConfiguration.Builder b)
public void applyGlobalConfigToLayer(NeuralNetConfiguration.Builder globalConfig)
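public String getLayerName()
getLayerName in interface TrainingConfig
public List<Regularization> getRegularizationByParam(String paramName)
Description copied from interface: TrainingConfig
Get the regularization types (l1/l2/weight decay) for the given parameter.
getRegularizationByParam in interface TrainingConfig
Parameters:
paramName - Parameter name ("W", "b" etc)
public boolean isPretrainParam(String paramName)
Description copied from interface: TrainingConfig
Is the specified parameter a layerwise pretraining only parameter? For example, visible bias params in an autoencoder (or decoder params in a variational autoencoder) aren't used during supervised backprop. Layers (like DenseLayer, etc.) with no pretrainable parameters will return false for all (valid) inputs.
isPretrainParam in interface TrainingConfig
Parameters:
paramName - Parameter name/key
public IUpdater getUpdaterByParam(String paramName)
Description copied from interface: TrainingConfig
Get the updater for the given parameter.
getUpdaterByParam in interface TrainingConfig
Parameters:
paramName - Parameter name
public GradientNormalization getGradientNormalization()
getGradientNormalization in interface TrainingConfig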
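The class holds separate updater and biasUpdater fields (and likewise regularization and regularizationBias), while getUpdaterByParam and getRegularizationByParam take a parameter name. One plausible reading, sketched below in plain Java, is that the lookup dispatches on whether the name was registered as a weight or as a bias. This is an assumption rather than documented behavior: the class PerParamUpdaterSketch is hypothetical, strings stand in for IUpdater instances, and the fallback to the weight updater when no bias updater is set is a guess.

```java
import java.util.Arrays;
import java.util.HashSet;
import java.util.Set;

// Hypothetical stand-in, NOT DL4J code: strings stand in for IUpdater instances.
final class PerParamUpdaterSketch {
    private final Set<String> weightKeys;
    private final Set<String> biasKeys;
    private final String updater;
    private final String biasUpdater; // may be null

    PerParamUpdaterSketch(Set<String> weightKeys, Set<String> biasKeys, String updater, String biasUpdater) {
        this.weightKeys = new HashSet<>(weightKeys);
        this.biasKeys = new HashSet<>(biasKeys);
        this.updater = updater;
        this.biasUpdater = biasUpdater;
    }

    // Assumed rule: weights use updater; biases use biasUpdater when set, otherwise updater.
    String updaterFor(String paramName) {
        if (weightKeys.contains(paramName)) {
            return updater;
        }
        if (biasKeys.contains(paramName)) {
            return biasUpdater != null ? biasUpdater : updater;
        }
        throw new IllegalArgumentException("Unknown parameter: " + paramName);
    }

    public static void main(String[] args) {
        Set<String> weights = new HashSet<>(Arrays.asList("W"));
        Set<String> biases = new HashSet<>(Arrays.asList("b"));
        PerParamUpdaterSketch shared = new PerParamUpdaterSketch(weights, biases, "adam", null);
        PerParamUpdaterSketch split = new PerParamUpdaterSketch(weights, biases, "adam", "sgd");
        System.out.println(shared.updaterFor("b") + " " + split.updaterFor("b"));
    }
}
```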
public double getGradientNormalizationThreshold()
getGradientNormalizationThreshold in interface TrainingConfig
public void setDataType(DataType dataType)
setDataType in interface TrainingConfig
setDataType in class GraphVertex
Copyright © 2020. All rights reserved.