Basic sampling Recurrent Neural Network (RNN) decoder.
Recurrent Neural Network (RNN) decoder that uses beam search to find the highest scoring sequence (i.e., to perform decoding).
Recurrent Neural Network (RNN) decoder abstract interface.
Concepts used by this interface:
input
: (Structure of) tensors and tensor arrays that is passed as input to the RNN cell composing the decoder, at each time step.

state
: Sequence of tensors that is passed to the RNN cell instance as the state.

finished
: Boolean tensor indicating whether each sequence in the batch has finished decoding.
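
To illustrate how `input`, `state`, and `finished` interact, here is a minimal sketch in plain Python. The `Decoder` class, its `initialize` and `step` methods, the toy `CountdownDecoder`, and the `decode` loop are all hypothetical names chosen for illustration; they are not this library's API, and plain lists stand in for tensors.

```python
from abc import ABC, abstractmethod
from typing import Any, List, Tuple


class Decoder(ABC):
    """Hypothetical step-based decoder interface (illustrative names only)."""

    @abstractmethod
    def initialize(self, batch_size: int) -> Tuple[List[bool], Any, Any]:
        """Returns (finished, first_input, initial_state)."""

    @abstractmethod
    def step(self, time: int, inp: Any, state: Any) -> Tuple[Any, Any, Any, List[bool]]:
        """Returns (output, next_input, next_state, finished)."""


class CountdownDecoder(Decoder):
    """Toy decoder: each sequence counts down from its start value to zero."""

    def __init__(self, starts: List[int]):
        self.starts = starts

    def initialize(self, batch_size):
        assert batch_size == len(self.starts)
        finished = [s <= 0 for s in self.starts]
        return finished, list(self.starts), list(self.starts)

    def step(self, time, inp, state):
        # The "cell" here just decrements the state; the output is fed back as input.
        next_state = [max(s - 1, 0) for s in state]
        output = list(next_state)
        finished = [s == 0 for s in next_state]
        return output, output, next_state, finished


def decode(decoder: Decoder, batch_size: int, max_steps: int):
    """Runs the decoder until every sequence is finished or max_steps is reached."""
    finished, inp, state = decoder.initialize(batch_size)
    outputs = []
    time = 0
    while not all(finished) and time < max_steps:
        output, inp, state, step_finished = decoder.step(time, inp, state)
        # Once a sequence has finished, it stays finished.
        finished = [f or sf for f, sf in zip(finished, step_finished)]
        outputs.append(output)
        time += 1
    return outputs, finished
```

In this sketch, `finished` has one Boolean per batch entry, and the loop stops only when all entries are `True` or the step limit is hit.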
Exponential length penalty function. The penalty is equal to `sequenceLengths ^ alpha`, where all operations are performed element-wise.
Length penalty weight (disabled if set to `0.0f`).
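
As a quick illustration of the exponential penalty formula, the following plain-Python sketch computes it element-wise over a list of sequence lengths. The function name `exponential_length_penalty` is hypothetical, and lists stand in for tensors.

```python
from typing import List


def exponential_length_penalty(sequence_lengths: List[int], alpha: float) -> List[float]:
    """Computes sequence_lengths ^ alpha element-wise."""
    if alpha == 0.0:
        # A weight of 0.0 disables the penalty: every entry becomes 1.0.
        return [1.0] * len(sequence_lengths)
    return [float(n) ** alpha for n in sequence_lengths]
```

For example, with `alpha = 0.5` the penalty for a sequence of length 4 is 2, and with `alpha = 0.0` every penalty is 1.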
Google length penalty function described in
[Google's Neural Machine Translation System: Bridging the Gap between Human and Machine Translation](https://arxiv.org/abs/1609.08144).
The penalty is equal to `((5 + sequenceLengths) / 6) ^ alpha`, where all operations are performed element-wise.
Length penalty weight (disabled if set to `0.0f`).
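
The Google length penalty formula can be sketched the same way in plain Python. The function name `google_length_penalty` is hypothetical, and lists stand in for tensors.

```python
from typing import List


def google_length_penalty(sequence_lengths: List[int], alpha: float) -> List[float]:
    """Computes ((5 + sequence_lengths) / 6) ^ alpha element-wise."""
    return [((5.0 + n) / 6.0) ** alpha for n in sequence_lengths]
```

Note that a sequence of length 1 always gets a penalty of 1, whatever the value of `alpha`, and that `alpha = 0.0` yields a penalty of 1 for every length.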
Length penalty function to be used while decoding.
No length penalty.
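
The sketch below shows one way a length penalty might be applied when ranking beam search hypotheses. It assumes the convention from the GNMT paper cited above, where each hypothesis's log-probability is divided by its length penalty (the paper's coverage penalty is left out); whether this library applies the penalty in exactly this way is an assumption. The names `rescore` and `google_penalty` are hypothetical.

```python
from typing import Callable, List, Optional

PenaltyFn = Callable[[List[int]], List[float]]


def google_penalty(alpha: float) -> PenaltyFn:
    """Returns a function computing ((5 + length) / 6) ^ alpha per sequence."""
    return lambda lengths: [((5.0 + n) / 6.0) ** alpha for n in lengths]


def rescore(log_probs: List[float], lengths: List[int],
            penalty_fn: Optional[PenaltyFn] = None) -> List[float]:
    """Divides each log-probability by its length penalty (assumed convention)."""
    if penalty_fn is None:
        # No length penalty: scores are the raw log-probabilities.
        return list(log_probs)
    penalties = penalty_fn(lengths)
    return [lp / p for lp, p in zip(log_probs, penalties)]


# Two hypotheses: a short one and a longer one with a lower raw log-probability.
log_probs = [-4.0, -6.0]
lengths = [1, 7]

raw = rescore(log_probs, lengths)
penalized = rescore(log_probs, lengths, google_penalty(alpha=1.0))
```

Without a penalty the shorter hypothesis ranks first; with the penalty and `alpha = 1.0` the longer one does, because its log-probability is divided by 2 while the shorter one's is divided by 1.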