org.platanios.tensorflow.api.ops.rnn.attention
Attention layer weights to use for projecting the computed attention.
Attention mechanisms to use.
RNN cell being wrapped.
Function that takes the original cell input tensor and the attention tensor as inputs and returns the mixed cell input to use. Defaults to concatenating the two tensors across their last axis.
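Concretely, the default mixing behavior described above can be written as a plain function; a minimal sketch using simple Array-backed vectors rather than the library's tensor type:

```scala
// Default cellInputFn behavior: concatenate the cell input and the
// attention tensor across their last axis (for rank-1 vectors this is
// plain concatenation). Illustrative only, not the library's signature.
val defaultCellInputFn: (Array[Double], Array[Double]) => Array[Double] =
  (input, attention) => input ++ attention
```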
Performs a step using this attention-wrapped RNN cell.
The step proceeds as follows:

  1. Mix the inputs and the previous step's attention output via cellInputFn.
  2. Call the wrapped cell with the mixed input and its previous state.
  3. Score the cell's output using attentionMechanism.
  4. Compute the alignments by passing the score through the normalizer.
  5. Compute the context vector as the inner product between the alignments and the attention mechanism's values (memory).
  6. Compute the attention output by concatenating the cell output and the context through the attention layer (a linear layer with attentionLayerWeights.shape(-1) outputs).
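A schematic, library-free Scala version of these six stages follows; the dot-product scorer, the softmax normalizer, and all names here are illustrative stand-ins, not the tensorflow_scala API:

```scala
object AttentionStepSketch {
  type Vec = Array[Double]

  def concat(a: Vec, b: Vec): Vec = a ++ b
  def dot(a: Vec, b: Vec): Double = a.zip(b).map { case (x, y) => x * y }.sum

  def softmax(scores: Vec): Vec = {
    val m = scores.max
    val exps = scores.map(s => math.exp(s - m))
    val z = exps.sum
    exps.map(_ / z)
  }

  def step(
      input: Vec,
      prevAttention: Vec,
      prevCellState: Vec,
      memory: Array[Vec],                 // the attention mechanism's "values"
      cell: (Vec, Vec) => (Vec, Vec),     // (input, state) => (output, newState)
      attentionLayer: Vec => Vec          // stand-in for the linear attention layer
  ): (Vec, Vec) = {
    val mixedInput = concat(input, prevAttention)                      // 1. cellInputFn (default: concatenation)
    val (cellOutput, nextCellState) = cell(mixedInput, prevCellState)  // 2. wrapped cell
    val scores = memory.map(v => dot(v, cellOutput))                   // 3. score (dot-product stand-in)
    val alignments = softmax(scores)                                   // 4. normalizer (softmax stand-in)
    val context = memory.transpose.map(col => dot(col, alignments))    // 5. context vector
    val attention = attentionLayer(concat(cellOutput, context))        // 6. attention layer projection
    (attention, nextCellState)
  }
}
```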
Input tuple to the attention wrapper cell.
Next tuple, containing the cell output and the next state.
Returns an initial state for this attention cell wrapper.
Initial state for the wrapped cell.
Optional data type which defaults to the data type of the last tensor in initialCellState.
Initial state for this attention cell wrapper.
Name prefix used for all new ops.
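For intuition, the wrapper's state bundles the wrapped cell's state with attention bookkeeping. A minimal sketch follows; the field names are modeled on a typical attention wrapper state but are assumptions, not guaranteed to match the library's state class exactly:

```scala
// Illustrative state bundle: the wrapped cell's state plus attention bookkeeping.
case class AttentionWrapperStateSketch(
    cellState: Array[Double],                  // state of the wrapped cell
    attention: Array[Double],                  // last attention output (zeros initially)
    time: Int,                                 // current step index
    alignmentsHistory: Vector[Array[Double]])  // populated only if history is stored

def initialStateSketch(initialCellState: Array[Double], attentionSize: Int): AttentionWrapperStateSketch =
  AttentionWrapperStateSketch(
    cellState = initialCellState,
    attention = Array.fill(attentionSize)(0.0), // no attention computed yet
    time = 0,
    alignmentsHistory = Vector.empty)
```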
If true (the default), the output of this cell at each step is the attention value. This is the behavior of Luong-style attention mechanisms. If false, the output at each step is the output of cell. This is the behavior of Bahdanau-style attention mechanisms. In both cases, the attention tensor is propagated to the next time step via the state and is used there. This flag only controls whether the attention mechanism is propagated up to the next cell in an RNN stack or to the top RNN output.
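In other words, the flag only selects which tensor is emitted as the step output; schematically (variable names are illustrative):

```scala
// Luong-style (outputAttention = true): emit the attention value.
// Bahdanau-style (outputAttention = false): emit the raw cell output.
def stepOutput(outputAttention: Boolean, attention: Array[Double], cellOutput: Array[Double]): Array[Double] =
  if (outputAttention) attention else cellOutput
```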
If true, the alignments history from all steps is stored in the final output state (currently stored as a time-major TensorArray on which you must call stack()). Defaults to false.
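As a rough analogue of that time-major buffer, one alignments vector can be appended per step and only assembled into a single tensor at the end; everything here except the TensorArray/stack() terminology is an illustrative stand-in:

```scala
import scala.collection.mutable.ArrayBuffer

// Time-major history: one alignments vector per step, assembled at the end.
val alignmentsHistory = ArrayBuffer.empty[Array[Double]]

def record(alignments: Array[Double]): Unit =
  alignmentsHistory += alignments   // appended once per time step

def stacked: Array[Array[Double]] =
  alignmentsHistory.toArray         // analogue of TensorArray.stack(): shape [time, memorySize]
```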
RNN cell that wraps another RNN cell and adds support for attention to it.