Package org.deeplearning4j.nn.conf.graph
Class AttentionVertex
- java.lang.Object
  - org.deeplearning4j.nn.conf.graph.GraphVertex
    - org.deeplearning4j.nn.conf.layers.samediff.SameDiffVertex
      - org.deeplearning4j.nn.conf.graph.AttentionVertex
- All Implemented Interfaces: Serializable, Cloneable, TrainingConfig
public class AttentionVertex extends SameDiffVertex
- See Also: Serialized Form
Nested Class Summary
- static class AttentionVertex.Builder
Field Summary
- protected WeightInit weightInit
Fields inherited from class org.deeplearning4j.nn.conf.layers.samediff.SameDiffVertex
biasUpdater, dataType, gradientNormalization, gradientNormalizationThreshold, regularization, regularizationBias, updater
Constructor Summary
- protected AttentionVertex(AttentionVertex.Builder builder)
Method Summary
- AttentionVertex clone()
- void defineParametersAndInputs(SDVertexParams params)
  Define the parameters and inputs of this vertex.
- SDVariable defineVertex(SameDiff sameDiff, Map<String,SDVariable> layerInput, Map<String,SDVariable> paramTable, Map<String,SDVariable> maskVars)
  Define the vertex: build its forward pass in the given SameDiff instance.
- Pair<INDArray,MaskState> feedForwardMaskArrays(INDArray[] maskArrays, MaskState currentMaskState, int minibatchSize)
- InputType getOutputType(int layerIndex, InputType... vertexInputs)
  Determine the type of output for this GraphVertex, given the specified inputs.
- void initializeParameters(Map<String,INDArray> params)
  Set the initial parameter values for this layer, if required.
Methods inherited from class org.deeplearning4j.nn.conf.layers.samediff.SameDiffVertex
applyGlobalConfig, applyGlobalConfigToLayer, getGradientNormalization, getGradientNormalizationThreshold, getLayerName, getMemoryReport, getRegularizationByParam, getUpdaterByParam, getVertexParams, instantiate, isPretrainParam, maxVertexInputs, minVertexInputs, numParams, paramReshapeOrder, setDataType, validateInput
Methods inherited from class org.deeplearning4j.nn.conf.graph.GraphVertex
equals, hashCode
Field Detail
weightInit
protected WeightInit weightInit
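The page gives only the declaration of `weightInit`, a `WeightInit` setting. To illustrate what such a scheme computes, here is a minimal plain-Java sketch of Xavier (Glorot) uniform initialization, which draws each weight from U(-s, s) with s = sqrt(6 / (fanIn + fanOut)). The scheme is chosen for illustration only: it is an assumption that it matches any particular `WeightInit` value or the default of this vertex, and `XavierUniformSketch` is a hypothetical name.

```java
import java.util.Random;

/** Hypothetical helper: Xavier/Glorot uniform init in plain Java (illustrative, not DL4J code). */
final class XavierUniformSketch {

    /** Bound s of the uniform distribution U(-s, s): sqrt(6 / (fanIn + fanOut)). */
    static double bound(int fanIn, int fanOut) {
        if (fanIn <= 0 || fanOut <= 0) {
            throw new IllegalArgumentException("fanIn and fanOut must be positive");
        }
        return Math.sqrt(6.0 / (fanIn + fanOut));
    }

    /** Returns a fanIn x fanOut matrix with entries drawn uniformly from [-s, s). */
    static double[][] init(int fanIn, int fanOut, long seed) {
        double s = bound(fanIn, fanOut);
        Random rng = new Random(seed);
        double[][] w = new double[fanIn][fanOut];
        for (int i = 0; i < fanIn; i++) {
            for (int j = 0; j < fanOut; j++) {
                // nextDouble() is in [0, 1); map it linearly to [-s, s)
                w[i][j] = (2.0 * rng.nextDouble() - 1.0) * s;
            }
        }
        return w;
    }

    public static void main(String[] args) {
        double[][] w = init(4, 2, 42L);
        System.out.println("bound = " + bound(4, 2) + ", w[0][0] = " + w[0][0]);
    }
}
```

Tying the bound to fanIn + fanOut keeps the variance of activations and gradients roughly stable across layers of different widths, which is the usual motivation for Xavier-style schemes.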
Constructor Detail
AttentionVertex
protected AttentionVertex(AttentionVertex.Builder builder)
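The constructor is `protected` and takes an `AttentionVertex.Builder`, so instances are meant to be created through the nested builder rather than by calling the constructor directly. The plain-Java sketch below shows that builder-only pattern in isolation; `SketchVertex` and its `nOut`/`nHeads` settings are hypothetical stand-ins, not the actual options of `AttentionVertex.Builder`.

```java
/** Hypothetical class illustrating builder-only construction (not the DL4J API). */
class SketchVertex {
    private final int nOut;    // hypothetical setting
    private final int nHeads;  // hypothetical setting

    // Protected constructor: code outside the package (other than subclasses) must use Builder.build()
    protected SketchVertex(Builder builder) {
        this.nOut = builder.nOut;
        this.nHeads = builder.nHeads;
    }

    int getNOut() { return nOut; }
    int getNHeads() { return nHeads; }

    static class Builder {
        private int nOut = -1;   // -1 means "not set yet"
        private int nHeads = 1;  // hypothetical default: a single head

        Builder nOut(int nOut) { this.nOut = nOut; return this; }
        Builder nHeads(int nHeads) { this.nHeads = nHeads; return this; }

        SketchVertex build() {
            // Validate once, at build time, so every constructed instance is consistent
            if (nOut <= 0) throw new IllegalStateException("nOut must be set to a positive value");
            if (nHeads <= 0) throw new IllegalStateException("nHeads must be positive");
            return new SketchVertex(this);
        }
    }

    public static void main(String[] args) {
        SketchVertex v = new SketchVertex.Builder().nOut(16).nHeads(4).build();
        System.out.println(v.getNOut() + " outputs, " + v.getNHeads() + " heads");
    }
}
```

Centralizing validation in `build()` means a half-configured object never escapes, which is the usual reason for pairing a builder with a non-public constructor.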
Method Detail
clone
public AttentionVertex clone()
- Specified by: clone in class GraphVertex
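`clone()` is declared with the covariant return type `AttentionVertex` rather than `GraphVertex`, so callers get a typed copy without a cast. A common way to write such an override is sketched below in plain Java; `SketchConfig` is a hypothetical class, and this page does not state whether `AttentionVertex.clone()` deep-copies its fields.

```java
import java.util.Arrays;

/** Hypothetical config class showing a typed, deep-copying clone() (not DL4J code). */
class SketchConfig implements Cloneable {
    int nOut;
    long[] shape;  // mutable state that copies must not share

    SketchConfig(int nOut, long[] shape) {
        this.nOut = nOut;
        this.shape = shape;
    }

    // Covariant return type: callers receive SketchConfig, not Object, so no cast is needed
    @Override
    public SketchConfig clone() {
        try {
            SketchConfig copy = (SketchConfig) super.clone();  // shallow field-by-field copy
            // Replace the shared array reference with an independent copy
            copy.shape = (shape == null) ? null : Arrays.copyOf(shape, shape.length);
            return copy;
        } catch (CloneNotSupportedException e) {
            // Unreachable: this class implements Cloneable
            throw new AssertionError(e);
        }
    }

    public static void main(String[] args) {
        SketchConfig a = new SketchConfig(8, new long[] {2, 8, 16});
        SketchConfig b = a.clone();
        b.shape[0] = 99;  // modifying the copy leaves the original untouched
        System.out.println(a.shape[0] + " vs " + b.shape[0]);
    }
}
```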
getOutputType
public InputType getOutputType(int layerIndex, InputType... vertexInputs) throws InvalidInputTypeException
Description copied from class: GraphVertex
Determine the type of output for this GraphVertex, given the specified inputs. Because a GraphVertex may process or modify its inputs arbitrarily, the output type can differ considerably from the input type(s).
This is generally used to decide when to add preprocessors and to determine the input sizes (and similar settings) of layers.
- Overrides: getOutputType in class SameDiffVertex
- Parameters:
  layerIndex - The index of the layer (if appropriate/necessary).
  vertexInputs - The inputs to this vertex.
- Returns:
  The type of output for this vertex.
- Throws:
  InvalidInputTypeException - If the input type is invalid for this type of GraphVertex.
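For plain dot-product attention without learned projections, the output has one vector per query time step, and each vector has the size of the value vectors. The sketch below encodes that shape rule and throws on incompatible inputs, much as `getOutputType` declares `InvalidInputTypeException`. It is an assumption that `AttentionVertex.getOutputType` follows exactly this rule (projections or multiple heads would change the output size), and `RecurrentShape` and `AttentionShapes` are hypothetical stand-ins for DL4J's `InputType`.

```java
/** Hypothetical shape descriptor: feature size and number of time steps per example. */
final class RecurrentShape {
    final long size;
    final long timeSteps;

    RecurrentShape(long size, long timeSteps) {
        this.size = size;
        this.timeSteps = timeSteps;
    }
}

/** Hypothetical shape rule for unprojected, single-head dot-product attention. */
final class AttentionShapes {

    static RecurrentShape outputShape(RecurrentShape q, RecurrentShape k, RecurrentShape v) {
        if (q.size != k.size) {
            // Scores are dot products of query and key vectors, so their sizes must agree
            throw new IllegalArgumentException("query size " + q.size + " != key size " + k.size);
        }
        if (k.timeSteps != v.timeSteps) {
            // Each key time step weights the value at the same time step
            throw new IllegalArgumentException("keys and values need the same number of time steps");
        }
        // One output vector per query time step, each as wide as a value vector
        return new RecurrentShape(v.size, q.timeSteps);
    }

    public static void main(String[] args) {
        RecurrentShape out = outputShape(
                new RecurrentShape(16, 5), new RecurrentShape(16, 7), new RecurrentShape(32, 7));
        System.out.println(out.size + " x " + out.timeSteps);
    }
}
```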
defineParametersAndInputs
public void defineParametersAndInputs(SDVertexParams params)
Description copied from class: SameDiffVertex
Define the parameters and inputs of the vertex. Register weight and bias parameters with SDLayerParams.addWeightParam(String, long...) and SDLayerParams.addBiasParam(String, long...). You must also define (and optionally name) the inputs to the vertex; this is required so that DL4J knows how many inputs the vertex has.
- Specified by: defineParametersAndInputs in class SameDiffVertex
- Parameters:
  params - Object used to set parameters for this layer.
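The contract above (register weight and bias shapes, and declare the inputs so the framework knows how many the vertex takes) can be mimicked in plain Java to make the bookkeeping concrete. `ParamRegistrySketch` is hypothetical: its `addWeightParam`/`addBiasParam` only echo the roles of the `SDLayerParams` methods named above, its `defineInputs` is an assumed name for declaring inputs, and the input names `queries`, `keys`, `values` and the parameter names `Wq`, `bq` are illustrative.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;

/** Hypothetical stand-in for a parameter/input registry (not the DL4J SDVertexParams API). */
final class ParamRegistrySketch {
    private final Map<String, long[]> weights = new LinkedHashMap<>();
    private final Map<String, long[]> biases = new LinkedHashMap<>();
    private final List<String> inputs = new ArrayList<>();

    void addWeightParam(String name, long... shape) { weights.put(name, shape.clone()); }
    void addBiasParam(String name, long... shape) { biases.put(name, shape.clone()); }

    // Declaring the inputs is what tells the framework how many inputs the vertex expects
    void defineInputs(String... names) { inputs.addAll(Arrays.asList(names)); }

    int numInputs() { return inputs.size(); }
    List<String> inputNames() { return Collections.unmodifiableList(inputs); }

    /** Total number of scalar parameters: product of each registered shape, summed. */
    long numParams() {
        long total = 0;
        for (Map<String, long[]> group : List.of(weights, biases)) {
            for (long[] shape : group.values()) {
                long n = 1;
                for (long d : shape) n *= d;
                total += n;
            }
        }
        return total;
    }

    public static void main(String[] args) {
        ParamRegistrySketch params = new ParamRegistrySketch();
        params.defineInputs("queries", "keys", "values");  // hypothetical input names
        params.addWeightParam("Wq", 16, 8);                 // hypothetical 16x8 weight
        params.addBiasParam("bq", 1, 8);                    // hypothetical 1x8 bias
        System.out.println(params.numInputs() + " inputs, " + params.numParams() + " params");
    }
}
```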
initializeParameters
public void initializeParameters(Map<String,INDArray> params)
Description copied from class: SameDiffVertex
Set the initial parameter values for this layer, if required.
- Specified by: initializeParameters in class SameDiffVertex
- Parameters:
  params - Parameter arrays that may be initialized.
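Because `initializeParameters` returns `void` and receives the parameter arrays keyed by name, an implementation sets initial values by writing into those arrays rather than returning new ones. A minimal plain-Java sketch of that in-place pattern follows; it uses `double[]` instead of `INDArray`, and both the naming convention (keys starting with `b` are biases) and the uniform range are hypothetical choices for illustration.

```java
import java.util.Arrays;
import java.util.LinkedHashMap;
import java.util.Map;
import java.util.Random;

/** Hypothetical in-place initializer over a name -> flat array map (plain Java, not DL4J). */
final class InitParamsSketch {

    static void initializeParameters(Map<String, double[]> params, long seed) {
        Random rng = new Random(seed);
        for (Map.Entry<String, double[]> e : params.entrySet()) {
            double[] arr = e.getValue();  // write into the caller's array, do not replace it
            if (e.getKey().startsWith("b")) {
                // Hypothetical convention: names starting with "b" are biases, set to zero
                Arrays.fill(arr, 0.0);
            } else {
                // Everything else is treated as a weight: small uniform values in [-0.1, 0.1)
                for (int i = 0; i < arr.length; i++) {
                    arr[i] = (2.0 * rng.nextDouble() - 1.0) * 0.1;
                }
            }
        }
    }

    public static void main(String[] args) {
        Map<String, double[]> params = new LinkedHashMap<>();
        params.put("W", new double[6]);
        params.put("b", new double[] {5, 5, 5});
        initializeParameters(params, 123L);
        System.out.println(Arrays.toString(params.get("b")));
    }
}
```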
feedForwardMaskArrays
public Pair<INDArray,MaskState> feedForwardMaskArrays(INDArray[] maskArrays, MaskState currentMaskState, int minibatchSize)
- Overrides: feedForwardMaskArrays in class SameDiffVertex
defineVertex
public SDVariable defineVertex(SameDiff sameDiff, Map<String,SDVariable> layerInput, Map<String,SDVariable> paramTable, Map<String,SDVariable> maskVars)
Description copied from class: SameDiffVertex
Define the vertex: build its forward pass in the given SameDiff instance.
- Specified by: defineVertex in class SameDiffVertex
- Parameters:
  sameDiff - SameDiff instance.
  layerInput - Input to the layer; keys as defined by SameDiffVertex.defineParametersAndInputs(SDVertexParams).
  paramTable - Parameter table; keys as defined by SameDiffVertex.defineParametersAndInputs(SDVertexParams).
  maskVars - Masks of the inputs, if available; keys as defined by SameDiffVertex.defineParametersAndInputs(SDVertexParams).
- Returns:
  The final layer variable corresponding to the activations/output from the forward pass.
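`defineVertex` builds the forward pass as a SameDiff graph. As a plain-Java reference for the kind of computation an attention vertex performs, here is a sketch of single-head scaled dot-product attention, softmax(Q K^T / sqrt(d)) V, with an optional key mask that excludes padded time steps, in the spirit of the `maskVars` argument. It loops over plain arrays instead of building `SDVariable` operations, and it is an assumption that this matches the vertex's configuration (single head, no learned projections, scaling by sqrt(d)); `DotProductAttentionSketch` is a hypothetical name.

```java
/** Plain-Java scaled dot-product attention for one example (illustrative; not the SameDiff code path). */
final class DotProductAttentionSketch {

    /**
     * @param q       queries, shape [nQ][d]
     * @param k       keys, shape [nK][d]
     * @param v       values, shape [nK][dv]
     * @param keyMask optional, length nK; false marks a padded key step to ignore (null = no mask)
     * @return        attention output, shape [nQ][dv]
     */
    static double[][] attend(double[][] q, double[][] k, double[][] v, boolean[] keyMask) {
        int nQ = q.length, nK = k.length, d = q[0].length, dv = v[0].length;
        double scale = 1.0 / Math.sqrt(d);
        double[][] out = new double[nQ][dv];

        for (int i = 0; i < nQ; i++) {
            // 1) Scaled scores of query i against every key; masked keys get -infinity
            double[] scores = new double[nK];
            double max = Double.NEGATIVE_INFINITY;
            for (int j = 0; j < nK; j++) {
                if (keyMask != null && !keyMask[j]) {
                    scores[j] = Double.NEGATIVE_INFINITY;
                    continue;
                }
                double dot = 0.0;
                for (int t = 0; t < d; t++) dot += q[i][t] * k[j][t];
                scores[j] = dot * scale;
                max = Math.max(max, scores[j]);
            }
            if (max == Double.NEGATIVE_INFINITY) {
                throw new IllegalArgumentException("every key is masked for query " + i);
            }
            // 2) Numerically stable softmax over the keys (subtract the max before exp)
            double sum = 0.0;
            for (int j = 0; j < nK; j++) {
                scores[j] = Math.exp(scores[j] - max);  // exp(-inf) == 0 for masked keys
                sum += scores[j];
            }
            // 3) Output row = softmax-weighted sum of the value vectors
            for (int j = 0; j < nK; j++) {
                double w = scores[j] / sum;
                for (int t = 0; t < dv; t++) out[i][t] += w * v[j][t];
            }
        }
        return out;
    }

    public static void main(String[] args) {
        double[][] q = {{1, 0}};
        double[][] k = {{1, 0}, {0, 1}};
        double[][] v = {{10, 0}, {0, 10}};
        // Masking the second key makes the query attend only to the first value
        double[][] out = attend(q, k, v, new boolean[] {true, false});
        System.out.println(out[0][0] + ", " + out[0][1]);
    }
}
```

Subtracting the row maximum before exponentiating leaves the softmax unchanged but avoids overflow for large scores, and giving masked keys a score of negative infinity makes their softmax weight exactly zero.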