Visualizing equivariances in transformer neural networks
Transformer neural networks have become the dominant architecture within many subfields of deep learning. Their success is partly owed to the fact that self-attention is a very generic operation in terms of the geometric priors it uses[1]. The following blog post interactively visualizes some of these geometric priors in transformers. It is aimed at readers who are already (at least vaguely) familiar with self-attention and want to see some simple visualizations of what positional encodings do to it[2].
As a quick prerequisite, let us recap the self-attention formula via a simple example. Consider the sentence: "Love conquers all". Each word in this sentence can be assigned an embedding vector, which may look like the following:
These three vectors give us an input matrix
Note that I have chosen simple word embeddings and projections so that the computations are easy to follow. In practice, these weights are learned.
After projection, self-attention is performed as follows:
With
From this equation, one sees that, intuitively, self-attention is nothing more than three projections of
For the purpose of these visualizations, I have ignored the multiple attention heads that are typically used in self-attention.
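For readers who prefer to follow along in code, here is a minimal plain-Python sketch of single-head self-attention. It assumes the standard scaled dot-product form softmax(QK^T / sqrt(d_k)) V from the original transformer paper[3], with Q, K and V obtained by multiplying the input matrix with three projection matrices. The function names and the toy embeddings are hypothetical; they are not the values used in the figures.

```python
import math

Matrix = list[list[float]]


def matmul(a: Matrix, b: Matrix) -> Matrix:
    """Plain-Python product of an (n x m) and an (m x p) matrix."""
    return [[sum(a[i][k] * b[k][j] for k in range(len(b))) for j in range(len(b[0]))]
            for i in range(len(a))]


def transpose(a: Matrix) -> Matrix:
    return [list(col) for col in zip(*a)]


def softmax_rows(a: Matrix) -> Matrix:
    """Row-wise softmax; subtracting the row maximum avoids overflow in exp."""
    out = []
    for row in a:
        m = max(row)
        exps = [math.exp(x - m) for x in row]
        total = sum(exps)
        out.append([e / total for e in exps])
    return out


def self_attention(x: Matrix, w_q: Matrix, w_k: Matrix, w_v: Matrix) -> Matrix:
    """Single-head scaled dot-product self-attention: softmax(Q K^T / sqrt(d_k)) V."""
    q, k, v = matmul(x, w_q), matmul(x, w_k), matmul(x, w_v)
    d_k = len(k[0])
    scores = [[s / math.sqrt(d_k) for s in row] for row in matmul(q, transpose(k))]
    return matmul(softmax_rows(scores), v)


# Hypothetical toy embeddings for "Love", "conquers", "all" (not the post's values),
# with identity projections so that Q = K = V = X.
X = [[1.0, 0.0, 1.0],
     [0.0, 2.0, 0.0],
     [1.0, 1.0, 0.0]]
I3 = [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]]
out = self_attention(X, I3, I3, I3)
for word, row in zip(["Love", "conquers", "all"], out):
    print(word, [round(v, 3) for v in row])
```

Identity projections keep the arithmetic easy to trace: each output row is then simply a softmax-weighted average of the input rows.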
Permutation equivariance in self-attention
The words in the previous example sentence "Love conquers all" can be scrambled in a number of ways and still form a correct sentence, e.g. "All love conquers". We can play around with the previous visualizations of self-attention by adding a shuffle button:
If you play around with the shuffling, you will notice that, if elements of
Permutation equivariance is useful for any data modality where the inputs are not really a sequence but are better described as a set (i.e., ordering does not matter). In language, however, ordering does matter, which is why transformers were originally proposed along with positional encodings.
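The equivariance can also be checked numerically. The sketch below uses hypothetical random weights (not the post's values): it shuffles the input tokens and verifies that each output row reappears, unchanged, at its shuffled position. The attention function is repeated so that the snippet runs on its own; all names are illustrative.

```python
import itertools
import math
import random


def attention(x, w_q, w_k, w_v):
    """Single-head scaled dot-product self-attention; x is a list of token vectors."""
    def project(w):
        return [[sum(r[i] * w[i][j] for i in range(len(w))) for j in range(len(w[0]))] for r in x]

    q, k, v = project(w_q), project(w_k), project(w_v)
    scale = math.sqrt(len(k[0]))
    out = []
    for qi in q:
        scores = [sum(a * b for a, b in zip(qi, kj)) / scale for kj in k]
        m = max(scores)
        weights = [math.exp(s - m) for s in scores]
        total = sum(weights)
        out.append([sum(w * vj[c] for w, vj in zip(weights, v)) / total for c in range(len(v[0]))])
    return out


def is_permutation_equivariant(x, w_q, w_k, w_v, perm, tol=1e-9):
    """True if attention(shuffled x) equals the equally shuffled attention(x)."""
    out = attention(x, w_q, w_k, w_v)
    out_shuffled = attention([x[i] for i in perm], w_q, w_k, w_v)
    # Row j of the shuffled output should match row perm[j] of the original output.
    return all(abs(a - b) < tol
               for i, row in zip(perm, out_shuffled)
               for a, b in zip(out[i], row))


def rand_matrix(n, m):
    """n x m matrix with entries drawn uniformly from [-1, 1]."""
    return [[random.uniform(-1.0, 1.0) for _ in range(m)] for _ in range(n)]


random.seed(0)
X = rand_matrix(3, 4)  # three hypothetical 4-d token embeddings
W_q, W_k, W_v = rand_matrix(4, 4), rand_matrix(4, 4), rand_matrix(4, 4)
print(all(is_permutation_equivariant(X, W_q, W_k, W_v, p)
          for p in itertools.permutations(range(3))))  # → True
```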
Positional embeddings break permutation equivariance
Positional encodings are most simply introduced by adding a position-dependent signal to the input. For example:
You will see that the positional encodings break permutation equivariance: a shuffled input no longer returns the same output, shuffled the same way. Because the positional indices are injected into the input, the model no longer operates on an unordered set. Rather, it becomes a true sequence model.
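Here is a minimal sketch of why this happens. It assumes the simplest possible encoding, adding the index p to every component of token p, which may differ from the exact encoding used in the figure. With two hypothetical one-dimensional tokens and identity projections, swapping the tokens is harmless without positions. Once positions are added, however, the swapped input is no longer just a reordering of the original one. The helper names are illustrative.

```python
import math


def attention(x, w_q, w_k, w_v):
    """Single-head scaled dot-product self-attention; x is a list of token vectors."""
    def project(w):
        return [[sum(r[i] * w[i][j] for i in range(len(w))) for j in range(len(w[0]))] for r in x]

    q, k, v = project(w_q), project(w_k), project(w_v)
    scale = math.sqrt(len(k[0]))
    out = []
    for qi in q:
        scores = [sum(a * b for a, b in zip(qi, kj)) / scale for kj in k]
        m = max(scores)
        weights = [math.exp(s - m) for s in scores]
        total = sum(weights)
        out.append([sum(w * vj[c] for w, vj in zip(weights, v)) / total for c in range(len(v[0]))])
    return out


def add_index_encoding(x):
    """Assumed toy encoding: add position index p to every component of token p."""
    return [[v + p for v in row] for p, row in enumerate(x)]


def no_encoding(x):
    return x


def is_permutation_equivariant(x, w_q, w_k, w_v, perm, encode, tol=1e-9):
    """Shuffle the raw tokens, encode them, and compare with the shuffled original output."""
    out = attention(encode(x), w_q, w_k, w_v)
    out_shuffled = attention(encode([x[i] for i in perm]), w_q, w_k, w_v)
    return all(abs(a - b) < tol
               for i, row in zip(perm, out_shuffled)
               for a, b in zip(out[i], row))


# Hypothetical 1-d toy example: two tokens, identity projections.
X = [[0.0], [1.0]]
I1 = [[1.0]]
print(is_permutation_equivariant(X, I1, I1, I1, [1, 0], no_encoding))         # → True
print(is_permutation_equivariant(X, I1, I1, I1, [1, 0], add_index_encoding))  # → False
```

After shuffling, each token receives a different position offset than before. As a result, the attention scores themselves change, not just their order.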
Note that this example uses positional encodings that simply contain the positional indices. In practice, positional encodings may be sinusoidal, as in the original transformer publication[3]. The dot product between two sinusoidal encodings depends only on the distance between the positions and tends to decrease as that distance grows, which has a nice decaying similarity effect on the dot-product attention matrix. Given enough data, one may also learn the positional embeddings from scratch, as in the BERT model[4].
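The offset-only dependence of sinusoidal encodings can be checked with a short sketch of the encoding as defined by Vaswani et al.[3]: sine on even dimensions, cosine on odd dimensions, base 10000. The choice d_model = 16 and the function names are illustrative assumptions.

```python
import math


def sinusoidal_encoding(pos, d_model, base=10000.0):
    """PE[pos, 2i] = sin(pos / base^(2i/d_model)), PE[pos, 2i+1] = cos(same angle)."""
    enc = []
    for two_i in range(0, d_model, 2):
        angle = pos / base ** (two_i / d_model)
        enc.extend([math.sin(angle), math.cos(angle)])
    return enc[:d_model]  # trims the last cosine if d_model is odd


def similarity(p, q, d_model=16):
    """Dot product between the encodings of positions p and q."""
    return sum(a * b for a, b in zip(sinusoidal_encoding(p, d_model),
                                     sinusoidal_encoding(q, d_model)))


# Same offset (2) at different absolute positions gives the same similarity.
print(abs(similarity(3, 5) - similarity(10, 12)) < 1e-9)  # → True
# Nearby positions are more similar than distant ones here.
print(similarity(0, 1) > similarity(0, 20))  # → True
```

Each sine/cosine pair contributes cos((p - q) * omega_i) to the dot product, so only the offset p - q matters. The decay with distance is a tendency rather than a strictly monotonic decrease.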
Time-shift (translation) equivariance through relative positional encodings
In many data domains, the absolute positioning of elements in the sequence does not matter. In these domains, how signals co-occur relative to each other may be more relevant. For example, in images, the absolute location of a cat's ear is inconsequential to its detection. What matters is that a cat's ear should be located on top of its head. Convolutions have this built in, as they are translation equivariant: if the input is shifted by some number of pixels, a convolution returns the same activation map, shifted by the same amount. In language, this might also be an attractive feature. Consider that the triplet of words "Love conquers all" may occur anywhere within a larger paragraph of text:
Regardless of its location within a paragraph, "Love" will always be the grammatical subject of the clause, and "conquers" its verb. How the words interact within the three-word clause remains the same, no matter where in the paragraph it appears. A language model could hence benefit from being robust to translations, or time shifts, of words. One can build such translation equivariance (for sequence models also called time-shift equivariance) into transformers using relative positional encodings. One example of such a relative positional encoding scheme is rotary embeddings[5], which many recent LLMs use instead of absolute encodings.
To visualize rotary embeddings in action, let us add a slider to the previous example that lets you control where you place the words:
Embedding the three-word clause in the same way:
With rotary embeddings, the query and key vectors are rotated by angles proportional to their position index:
Using the same attention operations:
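One sees that the dot product between a rotated query and a rotated key is preserved when both position indices are shifted by the same amount, because it depends only on the difference between the two rotation angles. This, in turn, gives us a translation-equivariant self-attention mechanism.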
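The rotation argument can be sketched in a few lines. The function below follows the pairwise rotation scheme of RoFormer[5], rotating each dimension pair (2i, 2i+1) by the angle pos * 10000^(-2i/d). The 4-d query and key values are hypothetical, and the function names are illustrative. Shifting the whole clause from index 0 to index 17 leaves every query-key logit unchanged, and therefore the softmax attention weights as well.

```python
import math


def rotate(vec, pos, base=10000.0):
    """Rotary position embedding: rotate each pair (x[2i], x[2i+1]) by pos * base^(-2i/d)."""
    d = len(vec)
    if d % 2:
        raise ValueError("rotary embeddings need an even dimension")
    out = []
    for two_i in range(0, d, 2):
        angle = pos * base ** (-two_i / d)
        c, s = math.cos(angle), math.sin(angle)
        x, y = vec[two_i], vec[two_i + 1]
        out.extend([x * c - y * s, x * s + y * c])
    return out


def rotary_logits(queries, keys, start):
    """Unscaled attention logits for a clause whose first token sits at index `start`."""
    rq = [rotate(q, start + i) for i, q in enumerate(queries)]
    rk = [rotate(k, start + j) for j, k in enumerate(keys)]
    return [[sum(a * b for a, b in zip(qi, kj)) for kj in rk] for qi in rq]


# Hypothetical 4-d queries and keys for "Love", "conquers", "all".
Q = [[1.0, 0.0, 0.5, -0.5], [0.0, 1.0, 1.0, 0.0], [0.3, 0.3, 0.0, 1.0]]
K = [[0.5, 1.0, 0.0, 0.0], [1.0, -1.0, 0.5, 0.5], [0.0, 0.2, 1.0, 1.0]]
at_start = rotary_logits(Q, K, start=0)
shifted = rotary_logits(Q, K, start=17)
print(all(abs(a - b) < 1e-9
          for r0, r1 in zip(at_start, shifted) for a, b in zip(r0, r1)))  # → True
```

This works because rotating q by m * theta and k by n * theta gives the same dot product as leaving q fixed and rotating k by (n - m) * theta.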
Note that in this example, the visualized outputs do not change when the inputs are shifted, which suggests invariance rather than equivariance. Keep in mind, however, that this visualization only shows the three example tokens within the larger sequence. In the broader context of the paragraph, the output embeddings of these three tokens would shift along with their position indices.
Exploiting other symmetries
Recently, SE(3)-equivariant self-attention variants have been described[6]. These operations are equivariant to rotations and translations of their inputs, which is a useful property when the inputs are 3D coordinates. For example, a neural network should deliver the same output for a molecule regardless of how the molecule as a whole is rotated or translated in space. An interactive D3.js visualization of this mechanism is left for a future post.
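The following sketch illustrates only the invariance idea. It is not the SE(3)-Transformer of Fuchs et al.[6], which uses more expressive equivariant features. Attention weights computed purely from pairwise atom distances do not change when the whole molecule is rotated and translated. The toy coordinates, the logits -||p_i - p_j||^2, the rotation about the z-axis, and the function names are all assumptions chosen for simplicity.

```python
import math


def rigid_motion(points, angle, shift):
    """Rotate 3D points about the z-axis by `angle` radians, then translate by `shift`."""
    c, s = math.cos(angle), math.sin(angle)
    return [(c * x - s * y + shift[0], s * x + c * y + shift[1], z + shift[2])
            for x, y, z in points]


def distance_attention(points):
    """Attention weights softmax_j(-||p_i - p_j||^2), built only from pairwise distances."""
    weights = []
    for p in points:
        logits = [-sum((a - b) ** 2 for a, b in zip(p, q)) for q in points]
        m = max(logits)
        exps = [math.exp(logit - m) for logit in logits]
        total = sum(exps)
        weights.append([e / total for e in exps])
    return weights


# Hypothetical 3-atom "molecule" and an arbitrary rigid motion.
atoms = [(0.0, 0.0, 0.0), (1.0, 0.0, 0.0), (0.0, 1.5, 0.5)]
moved = rigid_motion(atoms, angle=0.7, shift=(3.0, -2.0, 1.0))
before, after = distance_attention(atoms), distance_attention(moved)
print(all(abs(a - b) < 1e-9
          for r0, r1 in zip(before, after) for a, b in zip(r0, r1)))  # → True
```

Perturbing individual atoms, rather than moving the molecule rigidly, would change the distances and hence the weights. The symmetry only covers rotations and translations of the whole input.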
Addendum: What about invariance?
Equivariance is a desirable property for a model's internal layers.
In the end, however, the final representation still depends on the original ordering.
Imagine concatenating all token representations across a sequence and linearly projecting them to make a final prediction.
In that case, different orderings of the data will still result in different predictions.
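Here is a tiny sketch of this point, using hypothetical token representations and readout weights. A linear readout over the concatenated tokens changes when the tokens are reversed. Averaging the tokens before the readout does not. Mean pooling is used here only as one common example of an order-insensitive aggregation, and the function names are illustrative.

```python
def concat_readout(tokens, weights):
    """Concatenate all token vectors, then apply one linear map (dot product with `weights`)."""
    flat = [v for tok in tokens for v in tok]
    return sum(w * v for w, v in zip(weights, flat))


def mean_pool_readout(tokens, weights):
    """Average the token vectors first, then apply a linear map: the result ignores order."""
    d = len(tokens[0])
    pooled = [sum(tok[c] for tok in tokens) / len(tokens) for c in range(d)]
    return sum(w * v for w, v in zip(weights, pooled))


# Hypothetical final token representations and readout weights.
tokens = [[1.0, 0.0], [0.0, 2.0], [3.0, 1.0]]
reordered = tokens[::-1]
w_concat = [1, 2, 3, 4, 5, 6]
print(concat_readout(tokens, w_concat), concat_readout(reordered, w_concat))  # → 30.0 18.0
print(mean_pool_readout(tokens, [1.0, 1.0]) == mean_pool_readout(reordered, [1.0, 1.0]))  # → True
```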
What we want at the end of the model is, hence, often invariance.
A simple way to achieve this with a transformer is either through pre-pending a classification (
References and footnotes
[1] For more information on what I mean by geometric priors, refer to the Geometric Deep Learning book.
[2] If you are not familiar with transformers, consider reading The Illustrated Transformer by Jay Alammar.
[3] Vaswani, Ashish, et al. "Attention is all you need." Advances in Neural Information Processing Systems 30 (2017).
[4] Devlin, Jacob, et al. "BERT: Pre-training of deep bidirectional transformers for language understanding." arXiv preprint arXiv:1810.04805 (2018).
[5] Su, Jianlin, et al. "RoFormer: Enhanced transformer with rotary position embedding." Neurocomputing 568 (2024): 127063.
[6] Fuchs, Fabian, et al. "SE(3)-Transformers: 3D roto-translation equivariant attention networks." Advances in Neural Information Processing Systems 33 (2020): 1970-1981.