Memorizing Transformers by Wu, Y., Rabe, M. N., Hutchins, D., & Szegedy, C. (2022)

source
(Wu et al. 2022)
tags
Transformers, Memory in neural networks

Summary

This paper introduces a method to extend the classical Transformer neural network model with an addressable memory that can be queried and updated at inference time.

The memory is a set of cached attention (key, value) vector pairs computed from previously processed tokens, and it is addressed using an attention mechanism. The memory mechanism is inserted at a single, chosen depth of the attention “stack”.
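To make the data structure concrete, here is a minimal plain-Python sketch of such a memory. The class name `KVMemory`, the `capacity` parameter, and the oldest-first (FIFO) eviction policy are assumptions made for illustration. The note only says that the memory is a set of cached (key, value) pairs that is updated at inference time.

```python
from collections import deque


class KVMemory:
    """Hypothetical fixed-capacity store of cached (key, value) vector pairs.

    Assumption (not from the note): when full, the oldest pairs are evicted first.
    """

    def __init__(self, capacity):
        self.capacity = capacity
        # A deque with maxlen silently drops its oldest entry on overflow
        self.keys = deque(maxlen=capacity)
        self.values = deque(maxlen=capacity)

    def add(self, keys, values):
        """Append the (key, value) pairs computed for a new chunk of tokens."""
        for k, v in zip(keys, values):
            self.keys.append(list(k))
            self.values.append(list(v))

    def __len__(self):
        return len(self.keys)


mem = KVMemory(capacity=3)
mem.add([[1.0, 0.0], [0.0, 1.0]], [[10.0], [20.0]])
mem.add([[1.0, 1.0], [-1.0, 0.0]], [[30.0], [40.0]])
print(len(mem))          # 3
print(list(mem.values))  # [[20.0], [30.0], [40.0]]
```

A bounded `deque` keeps the sketch short. An unbounded list would work equally well if no eviction is wanted.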

The same query matrix \(\bm{Q}\) is used both for self-attention over the local context (the other tokens in the input) and for attention over the (key, value) pairs stored in the memory. We write \(\bm{K}_l\) and \(\bm{V}_l\) for the key and value matrices of the local context, and \(D\) for the dimension of the queries and keys. The local self-attention output (denoted \(\bm{V}_c\) in the paper) is computed as follows:

\[\bm{V}_l' = \text{softmax}\left( \dfrac{\bm{Q}\bm{K}_l^{T}}{\sqrt{D}} \right) \bm{V}_l.\]
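The local self-attention formula can be traced with a small pure-Python sketch. The function names `softmax` and `local_attention` are hypothetical. Causal masking, which an autoregressive language model would apply, is deliberately omitted to keep the example close to the equation.

```python
import math


def softmax(xs):
    # Subtract the max for numerical stability before exponentiating
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]


def local_attention(Q, K_l, V_l):
    """Compute V_l' = softmax(Q K_l^T / sqrt(D)) V_l, one query row at a time.

    Q and K_l are lists of D-dimensional vectors; V_l is a list of value vectors.
    No causal mask is applied (simplifying assumption).
    """
    D = len(Q[0])
    out = []
    for q in Q:
        # Scaled dot-product scores of this query against every local key
        scores = [sum(a * b for a, b in zip(q, k)) / math.sqrt(D) for k in K_l]
        weights = softmax(scores)
        # Weighted sum of the local value vectors
        out.append([sum(w * v[j] for w, v in zip(weights, V_l))
                    for j in range(len(V_l[0]))])
    return out


Q = [[1.0, 0.0]]
K_l = [[1.0, 0.0], [1.0, 0.0]]
V_l = [[2.0], [4.0]]
print(local_attention(Q, K_l, V_l))  # [[3.0]]
```

With two identical keys the softmax weights are both 0.5, so the output is the average of the two values.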

For each query (i.e. each input token position), \(k\) (key, value) pairs are retrieved from the memory: those whose keys are the \(k\) nearest neighbors of the query.
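A brute-force version of the retrieval step might look like the following. The sketch assumes that "nearest" means highest dot product (the same score the attention uses) and that the search is exact. Both are simplifications for illustration, and the function name `knn_retrieve` is hypothetical.

```python
import heapq


def knn_retrieve(q, mem_keys, mem_values, k):
    """Return the k memory (key, value) pairs whose keys score highest against q.

    Assumptions: similarity is the dot product, and the search is exact brute force.
    """
    def dot(a, b):
        return sum(x * y for x, y in zip(a, b))

    # Indices of the k keys with the largest dot product with the query, best first
    top = heapq.nlargest(k, range(len(mem_keys)),
                         key=lambda i: dot(q, mem_keys[i]))
    K_m = [mem_keys[i] for i in top]
    V_m = [mem_values[i] for i in top]
    return K_m, V_m


mem_keys = [[1.0, 0.0], [0.0, 1.0], [0.9, 0.1], [-1.0, 0.0]]
mem_values = [[1.0], [2.0], [3.0], [4.0]]
K_m, V_m = knn_retrieve([1.0, 0.0], mem_keys, mem_values, k=2)
print(V_m)  # [[1.0], [3.0]]
```

The cost is linear in the memory size per query, which is why a large memory would call for an approximate index in practice.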

Combining the memory with the query is done as in standard attention, except that each query attends to its own set of (key, value) pairs. With \(\bm{K}_m\) and \(\bm{V}_m\) the matrices of keys and values retrieved from the memory for query \(\bm{q}_i\), we have: \[\bm{v}_{m, i}' = \text{softmax}\left(\dfrac{\bm{q}_i \bm{K}_m^{T}}{\sqrt{D}} \right) \bm{V}_m.\]
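The per-query attention over retrieved pairs can be sketched the same way. Each query gets its own `(K_m, V_m)` pair of lists, and the per-query output vectors are stacked row by row into one output matrix. The names `memory_attention` and `memory_attention_all` and the list-of-lists layout are illustrative assumptions.

```python
import math


def softmax(xs):
    # Subtract the max for numerical stability before exponentiating
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]


def memory_attention(q_i, K_m, V_m):
    """Compute v'_{m,i} = softmax(q_i K_m^T / sqrt(D)) V_m for a single query.

    K_m and V_m are the (key, value) pairs retrieved for this query only.
    """
    D = len(q_i)
    scores = [sum(a * b for a, b in zip(q_i, key)) / math.sqrt(D) for key in K_m]
    weights = softmax(scores)
    return [sum(w * v[j] for w, v in zip(weights, V_m)) for j in range(len(V_m[0]))]


def memory_attention_all(Q, retrieved):
    """Stack per-query outputs row by row; retrieved[i] = (K_m, V_m) for query i."""
    return [memory_attention(q, K_m, V_m) for q, (K_m, V_m) in zip(Q, retrieved)]


Q = [[1.0, 0.0], [0.0, 1.0]]
retrieved = [
    ([[0.0, 0.0], [0.0, 0.0]], [[1.0], [3.0]]),  # pairs retrieved for query 0
    ([[0.0, 0.0], [0.0, 0.0]], [[5.0], [7.0]]),  # different pairs for query 1
]
print(memory_attention_all(Q, retrieved))  # [[2.0], [6.0]]
```

Unlike the local attention, the keys are not shared across queries, so the scores cannot be computed as one \(\bm{Q}\bm{K}^T\) product over a common key matrix.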

The final value matrix is then computed by combining \(\bm{V}_l'\) with \(\bm{V}_m'\), the matrix whose \(i\)-th row is \(\bm{v}_{m,i}'\):

\[ \bm{V} = \bm{V}_m' \odot g + \bm{V}_l' \odot (1 - g) \]

where \(g = \sigma(b_g)\) is the sigmoid of a learned per-head scalar parameter \(b_g\); it determines how much weight is given to the external memory relative to the local context.
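The gated combination is short once both outputs are available. The sketch below assumes both inputs are lists of equal-length rows and treats \(b_g\) as a plain float for a single head; the function name `gate_combine` is hypothetical.

```python
import math


def gate_combine(V_m_prime, V_l_prime, b_g):
    """Blend memory and local attention outputs: V = V_m' * g + V_l' * (1 - g).

    g = sigmoid(b_g), where b_g is a single scalar for this attention head.
    """
    g = 1.0 / (1.0 + math.exp(-b_g))
    return [[m * g + l * (1.0 - g) for m, l in zip(row_m, row_l)]
            for row_m, row_l in zip(V_m_prime, V_l_prime)]


V_m_prime = [[4.0, 0.0]]
V_l_prime = [[0.0, 2.0]]
print(gate_combine(V_m_prime, V_l_prime, b_g=0.0))  # [[2.0, 1.0]]
```

With \(b_g = 0\) the gate is \(g = 0.5\) and the two outputs are averaged. A large positive \(b_g\) pushes \(g\) toward 1 (rely on memory), and a large negative one pushes it toward 0 (rely on local context).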

Bibliography

  1. Wu, Y., Rabe, M. N., Hutchins, D., & Szegedy, C. (2022). "Memorizing Transformers". arXiv. DOI.
