Transformers are RNNs: Fast Autoregressive Transformers with Linear Attention by Katharopoulos, A., Vyas, A., Pappas, N., & Fleuret, F. (2020)

tags
Transformers, RNN
source
(Katharopoulos et al. 2020)

Summary

Transformers have traditionally been described as a different class of models from RNNs: instead of processing the sequence one token at a time, they use attention to process all elements simultaneously.

The paper introduces an interesting new formulation, replacing the softmax attention with a feature map-based dot product.

This new formulation yields better time and memory complexity, as well as a model that is causal and autoregressive (similar to RNNs).

A Transformer applied to a sequence \(x\) is presented as a composition of multiple Transformer layers \(T_l\), with

\[ T_l(x) = f_l(A_l(x) + x) \]

Function \(f_l\) is applied to each position of the sequence independently, while attention \(A_l\) operates over the whole input sequence.

Softmax self-attention at layer \(l\), with query, key and value matrices \(Q\), \(K\) and \(V\), is written

\[A_l(x) = V' = \text{softmax}\left( \dfrac{QK^{T}}{\sqrt{D}} \right) V.\]

The equation above can be generalized to any similarity function \(\text{sim}\); if \(V'_i\) denotes the \(i\)-th row of \(V'\),

\[ V'_i = \dfrac{\sum_{j = 1}^N \text{sim}(Q_i, K_j) V_j}{\sum_{j = 1}^N \text{sim}(Q_i, K_j)} \]
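
As a concrete reference point, here is a minimal NumPy sketch of this generalized attention, computed row by row in \(\mathcal{O}(N^2)\) time; the function names are my own, not taken from the paper's code.

```python
import numpy as np

def generalized_attention(Q, K, V, sim):
    """Attention with an arbitrary non-negative similarity function sim(q, k).

    Q, K: (N, D) query/key matrices; V: (N, M) value matrix.
    Row i of the output is a weighted average of the rows of V,
    with weights sim(Q_i, K_j) normalized over j.
    """
    N = Q.shape[0]
    out = np.zeros((N, V.shape[1]))
    for i in range(N):
        weights = np.array([sim(Q[i], K[j]) for j in range(N)])  # (N,)
        out[i] = weights @ V / weights.sum()
    return out

# Standard softmax attention is recovered with sim(q, k) = exp(q . k / sqrt(D)).
def softmax_sim(q, k):
    return np.exp(q @ k / np.sqrt(q.shape[0]))
```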

Linearizing attention

In particular, all kernels \(k(x, y) = \langle\phi(x), \phi(y)\rangle_\mathcal{S} : \mathbb{R}^{2\times F} \rightarrow \mathbb{R}_+\) can be used as a similarity function, changing the equation above to

\[ V'_i = \dfrac{\phi(Q_i)^T \sum_{j = 1}^N \phi(K_j) V_j^T}{ \phi(Q_i)^T \sum_{j = 1}^N \phi(K_j)}.\]

Because the sums in the numerator and denominator above do not depend on \(i\), they can be computed once for the whole sequence, so time and memory complexity become \(\mathcal{O}(N)\).
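
A minimal NumPy sketch of this linearized (non-causal) attention, assuming the feature map \(\phi(x) = \text{elu}(x) + 1\) used in the paper; variable names and shapes are my own.

```python
import numpy as np

def elu_feature_map(x):
    # phi(x) = elu(x) + 1: positive everywhere, so the similarity stays non-negative.
    return np.where(x > 0, x + 1.0, np.exp(x))

def linear_attention(Q, K, V, phi=elu_feature_map):
    """O(N) attention: the sums over j are computed once and shared by every query row."""
    Qp, Kp = phi(Q), phi(K)               # (N, C) feature-mapped queries and keys
    KV = Kp.T @ V                         # (C, M) = sum_j phi(K_j) V_j^T
    Z = Kp.sum(axis=0)                    # (C,)   = sum_j phi(K_j)
    return (Qp @ KV) / (Qp @ Z)[:, None]  # row i: phi(Q_i)^T KV / phi(Q_i)^T Z
```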

Masking for autoregressive models

By replacing the upper summation bound \(N\) by \(i\) in the expression above, one readily obtains a formulation in which the output at position \(i\) only depends on the previous tokens. This is used in particular to train language models, where the prediction of a token can only depend on the tokens that precede it.
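
A sketch of the masked (causal) version, again my own code rather than the authors' implementation: the sums become running (cumulative) sums over \(j \le i\). For clarity it materializes all running sums at once, which uses more memory than the constant-size state kept in the paper.

```python
import numpy as np

def causal_linear_attention(Q, K, V):
    """Masked linear attention: position i only uses keys and values with j <= i."""
    phi = lambda x: np.where(x > 0, x + 1.0, np.exp(x))    # elu(x) + 1, as above
    Qp, Kp = phi(Q), phi(K)
    # Running sums S_i = sum_{j<=i} phi(K_j) V_j^T and Z_i = sum_{j<=i} phi(K_j).
    S = np.cumsum(Kp[:, :, None] * V[:, None, :], axis=0)  # (N, C, M)
    Z = np.cumsum(Kp, axis=0)                               # (N, C)
    num = np.einsum('nc,ncm->nm', Qp, S)                    # phi(Q_i)^T S_i
    den = np.einsum('nc,nc->n', Qp, Z)                      # phi(Q_i)^T Z_i
    return num / den[:, None]
```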

Transformers are RNNs

By rewriting the kernel formulation above, one sees that a Transformer layer can actually be viewed as an RNN. Timesteps of the recurrence are denoted by subscripts.

\[ \begin{aligned} & s_0 = 0, z_0 = 0 \newline & s_i = s_{i-1} + \phi(x_i W_K) (x_i W_V)^T \newline & z_i = z_{i-1} + \phi(x_i W_K) \newline & y_i = f_l \left( \dfrac{\phi(x_i W_Q)^T s_i}{\phi(x_i W_Q)^T z_i} + x_i \right) \end{aligned} \]

The resulting RNN has two hidden states, namely the attention memory \(s\) and the normalizer memory \(z\).
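
A minimal sketch of this recurrent view, with my own variable names and a stand-in for \(f_l\) (here simply \(\tanh\) instead of the feed-forward block): each step updates the fixed-size state \((s, z)\), so generation costs \(\mathcal{O}(1)\) per token with respect to sequence length.

```python
import numpy as np

def recurrent_step(x_i, s, z, W_Q, W_K, W_V, f):
    """One step of the linear Transformer layer viewed as an RNN.

    Hidden state: attention memory s of shape (C, M) and normalizer memory z of shape (C,).
    """
    phi = lambda x: np.where(x > 0, x + 1.0, np.exp(x))   # elu(x) + 1
    q, k, v = phi(x_i @ W_Q), phi(x_i @ W_K), x_i @ W_V
    s = s + np.outer(k, v)          # s_i = s_{i-1} + phi(x_i W_K)(x_i W_V)^T
    z = z + k                       # z_i = z_{i-1} + phi(x_i W_K)
    y = f((q @ s) / (q @ z) + x_i)  # y_i = f_l(phi(x_i W_Q)^T s_i / (phi(x_i W_Q)^T z_i) + x_i)
    return y, s, z

# Toy usage: F = C = M = 4 so that the residual connection type-checks.
F = C = M = 4
rng = np.random.default_rng(0)
W_Q, W_K, W_V = (rng.standard_normal((F, C)) for _ in range(3))
s, z = np.zeros((C, M)), np.zeros(C)
for x_i in rng.standard_normal((10, F)):          # a sequence of 10 token embeddings
    y_i, s, z = recurrent_step(x_i, s, z, W_Q, W_K, W_V, f=np.tanh)
```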

Comments

The parallel between RNNs and Transformer models is clearly made in this paper. I believe this is significant because it gives insight into why Transformers might be better at language modeling than RNN-based models.

From this new formulation, it would seem that Transformers are not inherently better than RNNs; rather, the update function (in the equation above) of the RNN they are equivalent to is superior to that of standard RNN cells.

Another possibility is that RNNs and Transformers have always had the same potential. The hype might have fuelled more effort into making Transformer models work better, and thus widened the performance gap between the two otherwise equivalent models. Recent research into RNN models also seems to have favored a few dominant cells (standard RNN, LSTM and GRU), which might have slowed the discovery of other, more effective ones.

The experiments in the paper only demonstrate the performance of the new model on small tasks, and I would like to see how it holds up for language modeling.

Bibliography

  1. . . "Transformers Are Rnns: Fast Autoregressive Transformers with Linear Attention". Arxiv:2006.16236 [cs, Stat]. http://arxiv.org/abs/2006.16236.