I was reading through A Mathematical Framework for Transformer Circuits recently, which studies attention-only Transformers (i.e., Transformers with no MLP layers). Their goal is to identify the mechanisms that Transformers can express in theory and tend to express in practice. To do so, they consider three levels of increasing complexity: zero-layer, one-layer, and two-layer attention-only Transformers.
In this post, I’d like to summarize the differences in what each level can learn, since I didn’t come away with a strong understanding of this on my first few read-throughs. Notably, I won’t be getting into the math and empirical data that justify these differences (there’s plenty of that in the original post!). Also, I haven’t redefined all of the necessary terminology, so this post isn’t self-contained; think of it as supplementary material for the original post.
The task considered in the original post is next-token prediction. A zero-layer Transformer has no attention pattern, so it can’t incorporate any context, and the best it can do to predict the next token is model bigram statistics. That is, given the current token $A$, we use the matrix $W_U W_E$ to model the probability that $B$ follows $A$, for each possible $B$ (e.g., “cool beans” is more likely than “cool fire”).
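To make the “only the current token matters” point concrete, here is a minimal sketch of a zero-layer model, assuming a tiny hand-written vocabulary and made-up weights (`VOCAB`, `W_E`, `W_U`, and every number are hypothetical). The logits are just the column of $W_U W_E$ indexed by the last token.

```python
# Toy zero-layer Transformer: logits = W_U @ W_E @ onehot(current token).
# Vocabulary and weights are hypothetical, chosen by hand for illustration.
VOCAB = ["cool", "beans", "fire", "said"]

# W_E is d_model x vocab (column j embeds token j); here d_model = 2.
W_E = [
    [1.0, 0.0, 0.0, 0.0],
    [0.0, 1.0, 1.0, 1.0],
]
# W_U is vocab x d_model (row i reads out the logit of token i).
W_U = [
    [0.0, 0.2],  # cool
    [2.0, 0.0],  # beans
    [0.5, 0.0],  # fire
    [0.0, 0.1],  # said
]

def matmul(a, b):
    """Plain-Python matrix product of nested lists."""
    return [
        [sum(a[i][k] * b[k][j] for k in range(len(b))) for j in range(len(b[0]))]
        for i in range(len(a))
    ]

# BIGRAM[b][a] is the logit for next token b when the current token is a.
BIGRAM = matmul(W_U, W_E)

def predict_next(context):
    """Only the last token is used; the rest of the context is ignored."""
    a = VOCAB.index(context[-1])
    logits = [BIGRAM[b][a] for b in range(len(VOCAB))]
    return VOCAB[max(range(len(VOCAB)), key=logits.__getitem__)]

print(predict_next(["he", "said", "cool"]))  # beans
```

Because `predict_next` only looks at `context[-1]`, any prompt ending in “cool” gets “beans”, which is exactly the failure mode in the example that follows.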
<aside> 📖
Here’s an example where bigram probabilities go wrong, assuming “cool beans” is the most common bigram starting with “cool”.
Prompt: “He was going to say ‘cool beans’, but instead said ‘cool ___’”
Prediction: “He was going to say ‘cool beans’, but instead said ‘cool beans’”
Answer: “He was going to say ‘cool beans’, but instead said ‘cool cool’”
</aside>
Once we add an attention layer, we can model what are called skip-trigrams. A skip-trigram has the form $A \ldots BC$, where an arbitrary number of tokens may sit between $A$ and $B$. We call $A$ the source token, $B$ the destination token, and $C$ the output token. An attention layer can model skip-trigrams as follows. Suppose the current token is the destination token $B$. By computing attention scores, $B$ can “search” earlier in the context for instances of $A$, and if it finds one, the layer uses the likelihood of $A \ldots BC$, for each possible $C$, to adjust the output logits. Of course, attention isn’t binary, so in practice $B$ attends over several skip-trigram prefixes at once ($A \ldots B$, $A' \ldots B$, etc.), but this explanation is meant to be illustrative.
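Here is a minimal sketch of one attention head acting as a skip-trigram model. It assumes the head’s behavior can be summarized by two hypothetical lookup tables: `QK[(dest, src)]`, scoring how strongly a destination token attends to a source token, and `OV[src]`, giving how much attending to a source token raises each output logit. The tables and their numbers are made up; a real head learns dense weight matrices instead.

```python
import math

# Hypothetical hand-set tables standing in for one head's learned weights.
# QK[(dest, src)]: attention score from destination `dest` to source `src`.
QK = {("is", "perfect"): 4.0, ("looks", "perfect"): 4.0}
# OV[src]: how much attending to `src` raises each output token's logit.
OV = {"perfect": {"perfect": 3.0, "super": 1.0, "absolute": 1.0, "pure": 1.0}}

def head_logit_boosts(tokens):
    """Logit contributions of one head at the last position (the destination B)."""
    dest = tokens[-1]
    # Causal attention: the destination scores itself and every earlier token.
    scores = [QK.get((dest, src), 0.0) for src in tokens]
    # Numerically stable softmax over the scores.
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    boosts = {}
    for e, src in zip(exps, tokens):
        # Each attended source A adds its OV row, weighted by attention, to C's logit.
        for out, v in OV.get(src, {}).items():
            boosts[out] = boosts.get(out, 0.0) + (e / total) * v
    return boosts

print(head_logit_boosts(["the", "view", "was", "perfect", "and", "it", "is"]))
```

In this toy, “is” attends mostly to the earlier “perfect”, so the skip-trigram “perfect … is perfect” gets the largest boost, with smaller boosts for “super”, “absolute”, and “pure”.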
In the original post, they show a table of examples of the relationships Transformers use skip-trigrams to model. Here’s how to read it. In terms of attention, the destination token is the one that sends out the query vector, the source token sends the key vector, and the output token is the predicted next token after the destination token. The table was constructed by choosing a source token, finding the destination tokens that attend most strongly to it, and then finding the output tokens whose probabilities the source token increases the most.
So in the first example in that table, they choose the source token “perfect”. They find that, for the attention head under study, the destination tokens “are”, “looks”, “is”, and “provides” all attend strongly to “perfect”. Attending to “perfect” then mostly boosts the completion “perfect”, and also somewhat boosts “super”, “absolute”, and “pure”.
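The table-construction procedure can be sketched as two rankings, assuming a single head’s source/destination affinities and output effects are available as small score tables. In the original post’s framework these roles are played by the head’s QK and OV circuits; the `qk` and `ov` dictionaries, the helper `skip_trigram_row`, and every number below are hypothetical, picked so the output mirrors the “perfect” example.

```python
# Hypothetical scores for one head, indexed by token strings.
# qk[dest][src]: attention score from destination `dest` to source `src`.
qk = {
    "are": {"perfect": 3.1, "the": 0.2},
    "looks": {"perfect": 2.8, "the": 0.1},
    "is": {"perfect": 2.5, "the": 0.3},
    "provides": {"perfect": 2.2, "the": 0.4},
    "the": {"perfect": 0.1, "the": 0.0},
}
# ov[src][out]: how much attending to `src` raises the logit of `out`.
ov = {
    "perfect": {"perfect": 5.0, "super": 1.4, "absolute": 1.2, "pure": 1.1, "the": -0.3},
    "the": {"perfect": 0.0, "super": 0.0, "absolute": 0.0, "pure": 0.0, "the": 0.2},
}

def skip_trigram_row(src, k=4):
    """One table row: top-k destinations attending to `src`, top-k outputs it boosts."""
    # Rank destination tokens by how strongly they attend to the chosen source.
    dests = sorted(qk, key=lambda d: qk[d].get(src, float("-inf")), reverse=True)[:k]
    # Rank output tokens by how much attending to the source raises their logits.
    outs = sorted(ov[src], key=ov[src].get, reverse=True)[:k]
    return dests, outs

print(skip_trigram_row("perfect"))
# (['are', 'looks', 'is', 'provides'], ['perfect', 'super', 'absolute', 'pure'])
```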
Note that the most strongly boosted output token is often a copy of the source token itself, while the weaker ones are of the same type as the source token (in the example, other adjectives like “super” and “absolute”). This suggests skip-trigrams can be used both for copying and for cuing what type of token should follow.
It will be useful to focus on the copying form of skip-trigrams (i.e., $B \ldots AB$, where the output token is a copy of the source token) when we compare with two-layer Transformers.
This form of copying gives a way of “refining” the bigram statistics you get from zero-layer, attention-only Transformers. Instead of blindly choosing the most likely completion $AB'$ according to bigram statistics, the current token $A$ attends to earlier tokens $B$ that would be plausible completions and boosts the probability of the ones it attends to. Note that when the context contains no plausible completions to attend to, we can still fall back on bigram statistics: the residual stream carries the direct path, so the logits still include the zero-layer $W_U W_E$ term.
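A minimal sketch of this refinement, under hand-picked assumptions: `BIGRAM` stands in for the zero-layer $W_U W_E$ term, `QK` marks which earlier tokens a destination token treats as plausible completions, and the head’s OV behavior is assumed to be pure copying scaled by `COPY_STRENGTH`. All names and numbers are hypothetical.

```python
import math

# Hypothetical zero-layer bigram logits (the W_U W_E direct path).
BIGRAM = {"hot": {"dog": 2.0, "day": 1.5}}
# Hypothetical QK scores: destination A attends to earlier tokens B
# that are plausible completions of A (i.e., AB is a common bigram).
QK = {("hot", "dog"): 3.0, ("hot", "day"): 3.0}
# Hypothetical OV behavior: copy the attended token, scaled by this factor.
COPY_STRENGTH = 2.0

def one_layer_logits(tokens):
    a = tokens[-1]
    # Direct path through the residual stream: plain bigram statistics.
    logits = dict(BIGRAM.get(a, {}))
    # Copying head: softmax attention over the context, then copy what it attends to.
    scores = [QK.get((a, b), 0.0) for b in tokens]
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    z = sum(exps)
    for e, b in zip(exps, tokens):
        logits[b] = logits.get(b, 0.0) + COPY_STRENGTH * e / z
    return logits

def predict(tokens):
    logits = one_layer_logits(tokens)
    return max(logits, key=logits.get)

print(predict(["it", "was", "a", "hot"]))                # dog (bigram fallback)
print(predict(["what", "a", "day", "it", "is", "hot"]))  # day (copied from context)
```

With nothing relevant in context, the copy mass is spread thinly over irrelevant tokens and the bigram term wins; with “day” in context, copying overrides it. The same toy also predicts “dog” for a prompt containing an earlier “dog”, reproducing the first failure case in the examples that follow.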
<aside> 📖 Below are some examples where skip-trigrams can predict the wrong token if they copy blindly. In each predicted completion, the copied word comes from the first element of the skip-trigram (the earlier “dog” in the first example and “Computing” in the second).
Prompt: “Seeing the dog panting, we all agreed it was a hot ___.”
Prediction: “Seeing the dog panting, we all agreed it was a hot dog.”
Answer: “Seeing the dog panting, we all agreed it was a hot day.”
Prompt: “Computing the weather in his head, Chris determined there would be a lot of cloud ___ today.”
Prediction: “Computing the weather in his head, Chris determined there would be a lot of cloud computing today.”
Answer: “Computing the weather in his head, Chris determined there would be a lot of cloud cover today.”
While the network can now model context, it doesn’t model enough of it to recognize that a plausible completion appearing in the context isn’t necessarily the right one. Of course, we’re looking at a single attention head, so the zero-layer term and the other attention heads could prevent the degenerate predictions above.
</aside>
Now, a two-layer, attention-only Transformer gives us a slightly more sophisticated notion of copying. In particular, given a destination token $A$, it can search the context for earlier instances of the bigram $AB$ and promote $B$ “more confidently”. The difference is subtle. With a one-layer model, we searched backwards for instances of the second token of a bigram, without caring which token preceded it, and used them to boost the probability of that token following the current token. With a two-layer model, we can search for instances of the entire bigram $AB$, and when we find one, copy the token $B$ to complete $A$.
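Here is a hard-coded caricature of that two-layer mechanism, which the original post calls an induction head. As a simplification, it assumes layer 1 contains a perfect previous-token head, so each position “knows” which token preceded it; layer 2 then lets the current token $A$ attend to positions whose previous token was $A$ and copy the token found there. `MATCH_SCORE`, `COPY_STRENGTH`, and the helper names are hypothetical.

```python
import math

MATCH_SCORE = 5.0    # hypothetical QK score when a position's previous token matches A
COPY_STRENGTH = 3.0  # hypothetical OV gain: copy the attended token's identity

def previous_token_head(tokens):
    """Layer 1 (hard-coded): position i records which token preceded it."""
    return [None] + tokens[:-1]

def induction_head_boosts(tokens):
    """Layer 2: the last token A attends to positions whose previous token was A
    (information only available thanks to layer 1), then copies the token there,
    i.e. the B of an earlier AB."""
    a = tokens[-1]
    prev = previous_token_head(tokens)
    scores = [MATCH_SCORE if p == a else 0.0 for p in prev]
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    z = sum(exps)
    boosts = {}
    for e, b in zip(exps, tokens):
        boosts[b] = boosts.get(b, 0.0) + COPY_STRENGTH * e / z
    return boosts

tokens = ["a", "hot", "day", "the", "dog", "was", "a", "hot"]
boosts = induction_head_boosts(tokens)
print(max(boosts, key=boosts.get))  # day
```

Unlike the one-layer copying head, which would attend to both “day” and “dog” because each is a plausible completion of “hot”, this head only attends to the token that actually followed “hot” earlier, so it copies “day”.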