Attention
How tokens look at each other — the mechanism that made transformers possible.
The problem attention solves
After the embedding step, you have a stack of vectors — one per token — each carrying meaning but in isolation. The vector for "it" in "the cat sat on the mat because it was warm" just says "this is the pronoun 'it'". It says nothing about what "it" refers to. To resolve that, the vector for "it" needs to look at the other vectors and pull in information from the relevant ones.
That's the whole job of attention: let each token update itself based on a weighted blend of the tokens in the sequence, itself included.
Step by step on a 4-token toy example
Take the sentence "the cat sat down". Four tokens. Each starts with its own embedding vector (let's pretend each is just 4 numbers for illustration):
the  → [0.1, 0.5, 0.0, 0.2]
cat  → [0.4, 0.1, 0.7, 0.0]
sat  → [0.0, 0.6, 0.2, 0.4]
down → [0.2, 0.0, 0.5, 0.3]
Attention will produce four new vectors of the same shape — refined versions where each one has incorporated information from the others.
1. Compute Q, K, V for every token
The model has three learned weight matrices: W_Q, W_K, W_V. Multiplying each token's embedding by each of these matrices gives three new vectors:
- Query (Q) — "what am I looking for from the others?"
- Key (K) — "what am I advertising about myself?"
- Value (V) — "what will I actually share if you choose to listen?"
Imagine each token wears a name-tag (K) and is silently asking a question (Q). When two tokens look at each other, they check: how well does my question match your name-tag? If it's a strong match, I'll listen carefully to what you have to say (V). If it's a weak match, I'll mostly ignore you.
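A minimal plain-Python sketch of step 1, using the toy embeddings above. The weight matrices are hypothetical random stand-ins (seeded so the run is repeatable); in a real model W_Q, W_K and W_V are learned during training, and the Q/K/V size (`d_head` here, an illustrative name) does not have to match the embedding size.

```python
import random

# Toy 4-dimensional embeddings from the example above.
EMBEDDINGS = {
    "the": [0.1, 0.5, 0.0, 0.2],
    "cat": [0.4, 0.1, 0.7, 0.0],
    "sat": [0.0, 0.6, 0.2, 0.4],
    "down": [0.2, 0.0, 0.5, 0.3],
}


def random_matrix(rows, cols, rng):
    """Hypothetical stand-in for a learned weight matrix."""
    return [[rng.uniform(-1.0, 1.0) for _ in range(cols)] for _ in range(rows)]


def project(x, w):
    """Multiply row vector x (length d_in) by matrix w (d_in rows, d_out columns)."""
    return [sum(x[i] * w[i][j] for i in range(len(x))) for j in range(len(w[0]))]


rng = random.Random(0)  # fixed seed so the sketch is repeatable
d_model, d_head = 4, 4
W_Q = random_matrix(d_model, d_head, rng)
W_K = random_matrix(d_model, d_head, rng)
W_V = random_matrix(d_model, d_head, rng)

# One (Q, K, V) triple per token: the same three matrices are shared by every token.
qkv = {tok: (project(x, W_Q), project(x, W_K), project(x, W_V))
       for tok, x in EMBEDDINGS.items()}

for tok, (q, k, v) in qkv.items():
    print(f"{tok:>5} Q={[round(n, 2) for n in q]}")
```

The key design point is that W_Q, W_K and W_V are shared across positions: every token goes through the same three projections, so the differences between their Q, K and V vectors come entirely from the embeddings.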
2. Compute attention scores
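For each pair of tokens A and B (including each token paired with itself), compute the dot product of A's Q with B's K. That's the raw "how much should A attend to B" score. (Real implementations also divide each score by √d_k, the square root of the key dimension, so the softmax in the next step doesn't saturate; this toy example leaves that out.) Do this for every pair and you get a matrix: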
        the   cat   sat   down
the   [ 0.6   0.3   0.1   0.0 ]
cat   [ 0.2   0.7   0.4   0.1 ]
sat   [ 0.1   0.6   0.5   0.3 ]
down  [ 0.0   0.4   0.5   0.4 ]
Read row 2 left to right: "cat looks at the with score 0.2, at itself with 0.7, at sat with 0.4, at down with 0.1."
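Step 2 in the same plain-Python style. To keep it self-contained, this sketch assumes W_Q and W_K are both the identity, so each token's query and key are just its toy embedding. That simplification makes the matrix symmetric; with two different learned matrices the A→B and B→A scores generally differ, as they do in the table above. The optional `scale` flag (an illustrative parameter) applies the 1/√d_k factor used in standard scaled dot-product attention.

```python
import math


def dot(a, b):
    """Plain dot product of two equal-length vectors."""
    return sum(x * y for x, y in zip(a, b))


def attention_scores(queries, keys, scale=False):
    """scores[a][b] = Q_a . K_b, optionally divided by sqrt(d_k)."""
    d_k = len(keys[0])
    factor = 1.0 / math.sqrt(d_k) if scale else 1.0
    return [[dot(q, k) * factor for k in keys] for q in queries]


tokens = ["the", "cat", "sat", "down"]
# Simplifying assumption: W_Q = W_K = identity, so Q = K = the toy embeddings.
emb = [[0.1, 0.5, 0.0, 0.2], [0.4, 0.1, 0.7, 0.0],
       [0.0, 0.6, 0.2, 0.4], [0.2, 0.0, 0.5, 0.3]]
scores = attention_scores(emb, emb)

for tok, row in zip(tokens, scores):
    print(f"{tok:>5}", [round(s, 2) for s in row])
```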
3. Softmax → attention weights
Apply softmax across each row. This turns the raw scores into a probability distribution that sums to 1:
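        the   cat   sat   down
the   [ 0.35  0.26  0.21  0.19 ]
cat   [ 0.21  0.35  0.26  0.19 ]
sat   [ 0.19  0.31  0.28  0.23 ]
down  [ 0.18  0.26  0.29  0.26 ]
(Values are rounded to two decimals, so a row may not add up to exactly 1.00.)
Each row now says what fraction of its attention a token pays to each token, itself included. Because the raw scores here are small and close together, the weights come out fairly even; larger gaps between scores would give sharper, more peaked rows.
4. Weighted sum of V's
For each token, compute its new vector as the weighted sum of all the V vectors, weighted by that token's attention row. So the new vector for "cat" = 0.21 × V(the) + 0.35 × V(cat) + 0.26 × V(sat) + 0.19 × V(down).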
Output: four new vectors, same shape, each having absorbed contextual information from the others according to learned attention patterns.
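Steps 3 and 4 as a plain-Python sketch. It takes the raw score matrix from step 2 and applies a row-wise softmax, subtracting each row's maximum first (a standard trick that avoids overflow without changing the result). It then blends value vectors. The value vectors are hypothetical one-hot vectors, chosen so that each output equals its token's weight row, which makes the blend easy to see; in a real model V comes from W_V. The printed weights are an exact softmax of the toy scores.

```python
import math


def softmax(row):
    """Turn a row of raw scores into weights that sum to 1."""
    m = max(row)  # subtract the max for numerical stability
    exps = [math.exp(s - m) for s in row]
    total = sum(exps)
    return [e / total for e in exps]


def weighted_sum(weights, values):
    """Blend the value vectors according to one row of attention weights."""
    dim = len(values[0])
    return [sum(w * v[i] for w, v in zip(weights, values)) for i in range(dim)]


tokens = ["the", "cat", "sat", "down"]
# Raw scores from step 2 (rows = who is looking, columns = who is looked at).
scores = [
    [0.6, 0.3, 0.1, 0.0],
    [0.2, 0.7, 0.4, 0.1],
    [0.1, 0.6, 0.5, 0.3],
    [0.0, 0.4, 0.5, 0.4],
]
# Hypothetical one-hot value vectors; a real model computes V with W_V.
values = [
    [1.0, 0.0, 0.0, 0.0],
    [0.0, 1.0, 0.0, 0.0],
    [0.0, 0.0, 1.0, 0.0],
    [0.0, 0.0, 0.0, 1.0],
]

weights = [softmax(row) for row in scores]
outputs = [weighted_sum(w, values) for w in weights]

for tok, w in zip(tokens, weights):
    print(f"{tok:>5}", [round(x, 2) for x in w])
```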
The pattern is context-sensitive
Here's a famous test case (a "Winograd schema") where flipping one word changes which earlier noun the pronoun "it" refers to. In a sentence like "The trophy doesn't fit in the suitcase because it is too big," the phrase "too big" implies the trophy is the problem. Flip to "too small" and the suitcase is now the issue. The attention weights from "it" flip accordingly, leaning toward "trophy" in the first variant and toward "suitcase" in the second.
This is the whole reason attention is powerful: it isn't a fixed lookup. The same word ("it") routes attention differently depending on the rest of the sentence, because Q, K, V are computed from the contextually influenced embeddings flowing up through the layers.
Multi-head attention
A single attention computation captures one kind of relationship at a time — maybe "what does this pronoun refer to," maybe "what's the verb for this subject," maybe "which token came right before me." Real models run attention many times in parallel, with different learned Q/K/V projections. Each parallel run is called a head.
Picture the same meeting happening in eight rooms simultaneously, each room paying attention to a different aspect. Room 1 listens for grammatical agreement. Room 2 listens for "who is doing what to whom." Room 3 listens for proximity. After all rooms finish, their notes are collected and merged into a single update for each token.
Concretely: if your embedding dimension is 4096 and you use 32 heads, each head works with its own 128-dimensional queries, keys, and values (4096 / 32 = 128), projected from the full embedding. The 32 heads' 128-dim outputs are concatenated back into a 4096-dim vector and then passed through one more learned linear projection. This way the cost per head is small, but you get many parallel attention patterns.
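A minimal multi-head sketch in plain Python, shrunk to toy sizes (d_model = 4, two heads of dimension 2). Following the common implementation pattern, it slices already-projected Q, K, V vectors into per-head chunks, runs scaled dot-product attention in each head, and concatenates the results. For brevity it assumes identity projections (Q = K = V = the toy embeddings) and omits the final output projection; both simplifications are marked in the code.

```python
import math


def softmax(row):
    m = max(row)  # subtract the max for numerical stability
    exps = [math.exp(s - m) for s in row]
    total = sum(exps)
    return [e / total for e in exps]


def single_head(qs, ks, vs):
    """Scaled dot-product attention for one head over lists of vectors."""
    d_k = len(ks[0])
    outputs = []
    for q in qs:
        scores = [sum(a * b for a, b in zip(q, k)) / math.sqrt(d_k) for k in ks]
        w = softmax(scores)
        outputs.append([sum(wi * v[i] for wi, v in zip(w, vs)) for i in range(len(vs[0]))])
    return outputs


def multi_head(qs, ks, vs, n_heads):
    """Split each projected vector into n_heads chunks, attend per head, concatenate."""
    d_model = len(qs[0])
    if d_model % n_heads:
        raise ValueError("d_model must be divisible by n_heads")
    head_dim = d_model // n_heads
    per_head = []
    for h in range(n_heads):
        part = slice(h * head_dim, (h + 1) * head_dim)
        per_head.append(single_head([q[part] for q in qs],
                                    [k[part] for k in ks],
                                    [v[part] for v in vs]))
    # Concatenate the head outputs for each token. A real model would now apply
    # a learned output projection (often called W_O); omitted in this sketch.
    merged = []
    for t in range(len(qs)):
        row = []
        for h in range(n_heads):
            row.extend(per_head[h][t])
        merged.append(row)
    return merged


# Simplifying assumption: identity projections, so Q = K = V = toy embeddings.
emb = [[0.1, 0.5, 0.0, 0.2], [0.4, 0.1, 0.7, 0.0],
       [0.0, 0.6, 0.2, 0.4], [0.2, 0.0, 0.5, 0.3]]
out = multi_head(emb, emb, emb, n_heads=2)
print(len(out), len(out[0]))  # 4 tokens, each still 4-dimensional
print(4096 // 32)             # per-head size in the 4096-dim, 32-head example
```

Each head only ever sees its own slice of Q, K and V, so the heads can compute entirely different attention patterns before their outputs are stitched back together.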
The causal autoregressive property
In a decoder-only LLM, each token can only attend to itself and earlier tokens, never later ones. This isn't a quirky training shortcut — it's the structural property that makes the model a next-token predictor at all.
If you're asked "what comes next?", you can only read what's been written so far, not what hasn't been written yet. The causal property is the model living by the same rule. During training, this is enforced by a "causal mask" that sets every forward-looking attention score to negative infinity before the softmax, so those attention weights come out as exactly zero. During generation, the autoregressive loop enforces it naturally because future tokens literally don't exist yet.
It's tempting to call the causal mask a "training trick" because training code is where you see the mask most often. But when a model generates one token at a time there's nothing to mask: the newest token only has past tokens to look at. (When a whole prompt is processed in one parallel pass, the mask is applied just as it is in training.) The deeper truth is that the entire architecture is designed around the principle that each position predicts what comes next given only what came before. The mask is just how you train this efficiently in parallel without leaking the future.
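A plain-Python sketch of a causal mask, reusing the toy score matrix from step 2. Before the softmax, every score whose key position is later than the query position is replaced with negative infinity, and `math.exp(-math.inf)` is exactly 0.0, so those weights vanish. The last lines show why the mask uses negative infinity rather than 0: a score of 0 still receives a positive weight after softmax.

```python
import math


def causal_softmax(scores):
    """Row-wise softmax where query position i may only attend to keys 0..i."""
    weights = []
    for i, row in enumerate(scores):
        # Mask: future positions (j > i) get -inf, so exp() turns them into 0.0.
        masked = [s if j <= i else -math.inf for j, s in enumerate(row)]
        m = max(masked)  # finite, since position i can always see itself
        exps = [math.exp(s - m) for s in masked]
        total = sum(exps)
        weights.append([e / total for e in exps])
    return weights


scores = [
    [0.6, 0.3, 0.1, 0.0],
    [0.2, 0.7, 0.4, 0.1],
    [0.1, 0.6, 0.5, 0.3],
    [0.0, 0.4, 0.5, 0.4],
]
weights = causal_softmax(scores)
for row in weights:
    print([round(w, 2) for w in row])

# Contrast: zeroing the future scores (instead of using -inf) still leaks attention.
zeroed = [0.6, 0.0, 0.0, 0.0]
leak = math.exp(0.0) / sum(math.exp(s) for s in zeroed)
print(round(leak, 2))  # 0.21 of the weight still goes to each "future" token
```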
Why attention is O(n²) — and what that costs
For a sequence of n tokens, attention computes a score for every pair. That's n × n = n² scores. Per layer. Per head. For modest contexts (say 4k tokens, 32 heads, 32 layers), that's already 16 million × 32 × 32 ≈ 16 billion attention scores. Double the context to 8k? You get four times the cost.
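The arithmetic above as a small Python helper (the function name is illustrative). Like the paragraph, it counts the full n × n score matrix per head per layer; a causal model only needs the lower triangle, roughly half, but the quadratic scaling is the same. With n = 4096 exactly the total is about 17.2 billion; the ≈16 billion figure treats 4k as 4,000.

```python
def attention_score_count(n_tokens: int, n_heads: int, n_layers: int) -> int:
    """Query-key scores for one forward pass over the full n x n matrix."""
    return n_tokens * n_tokens * n_heads * n_layers


base = attention_score_count(4096, 32, 32)
doubled = attention_score_count(8192, 32, 32)
print(f"{base:,}")     # 17,179,869,184
print(doubled / base)  # 4.0: doubling the context quadruples the work
```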
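This quadratic scaling is why long context is the most expensive lever you can pull. It's also why so many alternative attention schemes exist. Approximate ones such as sliding-window, sparse, and linear attention compute fewer scores or approximate the softmax, trading some quality for speed. FlashAttention takes a different route: it is exact and still computes all n² scores, but it processes them in tiles so the full n × n matrix never has to be stored in slow GPU memory, which cuts memory use and wall-clock time without approximating the result.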
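To see how an approximate scheme attacks the n² term, here is a sketch that counts the scores a causal model computes with and without one simple form of sliding window, assuming each token attends only to itself and the previous `window - 1` tokens. The function names and the window size of 1024 are hypothetical choices for illustration. A brute-force loop double-checks the closed form on a small case.

```python
def causal_score_count(n: int, window=None) -> int:
    """Scores needed when position i sees keys max(0, i - window + 1)..i."""
    if window is None or window >= n:
        return n * (n + 1) // 2                # full lower triangle
    ramp = window * (window + 1) // 2          # first `window` positions see i + 1 keys
    return ramp + (n - window) * window        # the rest see exactly `window` keys


def brute_force_count(n: int, window=None) -> int:
    """Enumerate every allowed (query i, key j) pair directly."""
    return sum(1 for i in range(n) for j in range(i + 1)
               if window is None or i - j < window)


assert causal_score_count(10, 3) == brute_force_count(10, 3)

for n in (4096, 8192):
    full = causal_score_count(n)
    windowed = causal_score_count(n, window=1024)
    print(n, full, windowed)
```

Doubling n from 4096 to 8192 roughly quadruples the full causal count but only about doubles the windowed one. The price is that a token can no longer attend directly to anything outside its window.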
Multi-query attention (MQA) and grouped-query attention (GQA). In standard multi-head attention, every head has its own K and V. The KV cache (see inference.html) stores K and V for every past token, every head, and every layer, so its size scales with n_heads. MQA shares a single K/V head across all query heads, which shrinks the cache by a factor equal to the head count at a small quality cost. GQA is the middle ground: heads are split into groups that share K/V (for example, 32 query heads sharing 8 K/V heads, for a 4× smaller cache). Most modern open models (Llama 3, Mistral, Mixtral) use GQA, and it is a large part of why their KV cache is small enough to run on consumer hardware.
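A back-of-the-envelope KV cache calculator for the three layouts. The configuration (32 layers, 32 query heads of dimension 128, an 8,192-token context, 2 bytes per value as in fp16/bf16) is a hypothetical round-number setup, not the published config of any particular model, and `kv_cache_bytes` is an illustrative helper name.

```python
def kv_cache_bytes(n_tokens, n_layers, n_kv_heads, head_dim, bytes_per_value=2):
    """Cached keys and values: 2 tensors x tokens x layers x K/V heads x head_dim."""
    return 2 * n_tokens * n_layers * n_kv_heads * head_dim * bytes_per_value


GIB = 2 ** 30
# Hypothetical round-number configuration; only the K/V head count varies below.
config = dict(n_tokens=8192, n_layers=32, head_dim=128)

for name, n_kv_heads in [("MHA (32 K/V heads)", 32),
                         ("GQA (8 K/V heads)", 8),
                         ("MQA (1 K/V head)", 1)]:
    size = kv_cache_bytes(n_kv_heads=n_kv_heads, **config)
    print(f"{name:<20} {size / GIB:.3f} GiB")
```

Only the number of K/V heads changes between rows, so the cache shrinks by exactly the grouping factor: 4× for GQA with 8 groups, 32× for MQA in this setup.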