Tokens Need Context
A language model generates tokens one at a time. At each step it must decide which earlier tokens are relevant — across potentially hundreds of thousands of positions. Attention is the mechanism that makes that comparison tractable.
Consider: "The trophy doesn't fit in the suitcase because it's too large."
What does "it" refer to? Answering requires relating "it" to "trophy" and "suitcase"
several words back, not just processing the most recent word. Attention gives every token a way to
look at every other token it is allowed to see (in a decoder, every earlier token) and pull in what matters.
Queries, Keys, and Values
Inside an attention layer, each token's embedding is linearly projected into three vectors. Conceptually:
- Query (Q): what this token is searching for.
- Key (K): what this token advertises as searchable.
- Value (V): what this token contributes if selected.
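The dot product between a query and every key produces relevance scores. These scores are scaled by 1/√d_k, where d_k is the per-head key dimension. A softmax then normalizes them into weights that sum to 1. Those weights blend the values into the output: Attention(Q, K, V) = softmax(QKᵀ / √d_k) V.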
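The 1/√d_k scaling keeps dot-product magnitudes stable as the head
dimension grows. Suppose query and key entries are roughly independent, with zero mean and unit
variance. Their dot product then has variance d_k, so dividing by √d_k brings it back to about 1.
Without the scaling, softmax saturates into near-one-hot distributions and gradients vanish
during training.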
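The computation described above (scores, 1/√d_k scaling, softmax, weighted sum of values) can be sketched for a single head in NumPy. This is a minimal illustration, not a production kernel. The `causal` flag is an assumption that models decoder-style masking. It assumes queries and keys cover the same positions.

```python
import numpy as np

def softmax(x: np.ndarray, axis: int = -1) -> np.ndarray:
    # Subtract the row max before exponentiating for numerical stability.
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def scaled_dot_product_attention(q, k, v, causal=False):
    """q: (n_q, d_k), k: (n_k, d_k), v: (n_k, d_v) -> output (n_q, d_v), weights (n_q, n_k)."""
    d_k = q.shape[-1]
    scores = q @ k.T / np.sqrt(d_k)  # relevance of every key to every query
    if causal:
        # Query i may only attend to keys 0..i (assumes n_q == n_k).
        n_q, n_k = scores.shape
        mask = np.triu(np.ones((n_q, n_k), dtype=bool), k=1)
        scores = np.where(mask, -np.inf, scores)
    weights = softmax(scores, axis=-1)  # each row sums to 1
    return weights @ v, weights

rng = np.random.default_rng(0)
q = rng.standard_normal((4, 8))
k = rng.standard_normal((4, 8))
v = rng.standard_normal((4, 8))
out, w = scaled_dot_product_attention(q, k, v, causal=True)
print(out.shape)  # (4, 8)
```

With the causal mask, the first token can only attend to itself, so its output is exactly its own value vector.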
Why Multiple Heads?
A single attention head computes one kind of "relevance". Multi-head attention runs H independent attention operations in parallel, each with its own projection matrices. Different heads can specialize: one might track syntactic agreement, another coreference, another positional proximity. Their outputs are concatenated and projected back to the model dimension.
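The split, attend, concatenate, project pattern can be sketched in NumPy. The function name and the square weight layout are illustrative assumptions: `w_q`, `w_k`, `w_v`, and `w_o` each have shape (d_model, d_model), and the heads are contiguous slices of the projected vectors. A causal mask is applied as in a decoder.

```python
import numpy as np

def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def multi_head_attention(x, w_q, w_k, w_v, w_o, num_heads):
    """x: (seq_len, d_model); each w_*: (d_model, d_model). Returns (seq_len, d_model)."""
    seq_len, d_model = x.shape
    head_dim = d_model // num_heads

    def split_heads(t):
        # (seq_len, d_model) -> (num_heads, seq_len, head_dim)
        return t.reshape(seq_len, num_heads, head_dim).transpose(1, 0, 2)

    q, k, v = split_heads(x @ w_q), split_heads(x @ w_k), split_heads(x @ w_v)
    scores = q @ k.transpose(0, 2, 1) / np.sqrt(head_dim)  # (H, S, S), one map per head
    causal = np.triu(np.ones((seq_len, seq_len), dtype=bool), k=1)
    scores = np.where(causal, -np.inf, scores)
    heads = softmax(scores) @ v  # (H, S, head_dim)
    # Concatenate the heads back to (seq_len, d_model), then project.
    concat = heads.transpose(1, 0, 2).reshape(seq_len, d_model)
    return concat @ w_o

rng = np.random.default_rng(0)
d_model, num_heads, seq_len = 64, 8, 10
x = rng.standard_normal((seq_len, d_model))
w_q, w_k, w_v, w_o = (rng.standard_normal((d_model, d_model)) / np.sqrt(d_model) for _ in range(4))
y = multi_head_attention(x, w_q, w_k, w_v, w_o, num_heads)
print(y.shape)  # (10, 64)
```

Each head attends within its own head_dim-sized subspace. Only the final projection w_o mixes information across heads.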
The KV Cache: Speed Optimization That Becomes a Memory Problem
During training, all tokens in a sequence are processed simultaneously — fully parallelized across the time dimension. During autoregressive generation it's different: the model emits one token per step, and each new token must attend to every prior token.
Without caching, generating token 512 would require recomputing the keys and values for all 511 preceding tokens, then doing it all again for token 513. That is quadratic work in sequence length, far too wasteful for production serving.
The fix is the KV cache: store each token's K and V vectors once they are computed, and reuse them for all future decode steps. Only the newest token needs a fresh Q, K, and V. Memory bandwidth replaces redundant compute.
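A toy single-layer, single-head decode loop makes the idea concrete. `KVCache`, `decode_step`, and `full_recompute` are hypothetical names used only for illustration. The token representations `xs` are random stand-ins. In a real model, each layer caches K/V computed from its own hidden states. The check shows that attending over cached K/V gives the same output as re-projecting the whole prefix at every step.

```python
import numpy as np

def softmax(x: np.ndarray) -> np.ndarray:
    x = x - x.max()  # stabilize before exponentiating
    e = np.exp(x)
    return e / e.sum()

class KVCache:
    """Hypothetical append-only K/V store for a single attention head."""

    def __init__(self, head_dim: int):
        self.k = np.empty((0, head_dim))
        self.v = np.empty((0, head_dim))

    def append(self, k_new: np.ndarray, v_new: np.ndarray) -> None:
        # Real caches preallocate; vstack reallocates each step but keeps the sketch short.
        self.k = np.vstack([self.k, k_new[None, :]])
        self.v = np.vstack([self.v, v_new[None, :]])

def decode_step(x_t, w_q, w_k, w_v, cache: KVCache) -> np.ndarray:
    """Project only the newest token, cache its K/V once, attend over everything cached."""
    q = x_t @ w_q
    cache.append(x_t @ w_k, x_t @ w_v)
    scores = cache.k @ q / np.sqrt(q.shape[-1])
    return softmax(scores) @ cache.v

def full_recompute(prefix, w_q, w_k, w_v) -> np.ndarray:
    """Reference without a cache: re-project K/V for the whole prefix every step."""
    q = prefix[-1] @ w_q
    k, v = prefix @ w_k, prefix @ w_v
    return softmax(k @ q / np.sqrt(q.shape[-1])) @ v

rng = np.random.default_rng(0)
d_model, head_dim, steps = 32, 16, 6
w_q, w_k, w_v = (rng.standard_normal((d_model, head_dim)) / np.sqrt(d_model) for _ in range(3))
xs = rng.standard_normal((steps, d_model))  # stand-in token representations

cache = KVCache(head_dim)
for t in range(steps):
    out = decode_step(xs[t], w_q, w_k, w_v, cache)
    assert np.allclose(out, full_recompute(xs[: t + 1], w_q, w_k, w_v))
print(cache.k.shape)  # (6, 16): one cached key per generated token
```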
How Fast Does It Grow?
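The KV cache size follows directly from the model shape:
KV cache bytes = 2 (one K, one V) × num_layers × kv_heads × head_dim × bytes_per_element × seq_len × batch_size
Plug in realistic numbers for a Llama-3-70B style model (80 layers, 8 KV heads, head_dim 128, FP16 at 2 bytes per element, 128K = 131,072-token context, 32 concurrent users):
2 × 80 × 8 × 128 × 2 B = 327,680 B = 320 KiB per token
320 KiB × 131,072 tokens = 40 GiB per sequence
40 GiB × 32 users = 1,280 GiB = 1.25 TiB
That is before model weights, runtime buffers, and scheduling headroom. Full MHA would use 64 KV heads, one per query head, instead of 8, and that 1.25 TiB balloons to 10 TiB. This is why KV head count is one of the most consequential architectural choices in modern LLM design.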
The key term in the formula is kv_heads. Standard MHA sets
kv_heads = query_heads. GQA and MQA reduce kv_heads
directly — everything else stays the same.
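A small calculator for this arithmetic follows. It assumes the cache holds one K and one V tensor per layer, each of shape (kv_heads, seq_len, head_dim), with no paging or quantization overhead. With 8 KV heads it reproduces the 1.25 TiB figure. Changing only kv_heads shows the MHA and MQA extremes for the same 64-query-head model.

```python
def kv_cache_bytes(num_layers, kv_heads, head_dim, seq_len, batch_size, bytes_per_elem=2):
    # Leading 2: one K tensor and one V tensor per layer.
    return 2 * num_layers * kv_heads * head_dim * bytes_per_elem * seq_len * batch_size

TIB = 1024 ** 4
cfg = dict(num_layers=80, head_dim=128, seq_len=128 * 1024, batch_size=32)  # FP16 by default
for name, kv_heads in [("MHA", 64), ("GQA-8", 8), ("MQA", 1)]:
    size = kv_cache_bytes(kv_heads=kv_heads, **cfg)
    print(f"{name:6s} {size / TIB:8.4f} TiB")  # MHA ~10, GQA-8 ~1.25, MQA ~0.16 TiB
```

Every other term is held fixed, so the cache scales linearly with kv_heads.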
Multi-Head Attention: Expressive, Expensive
Standard Multi-Head Attention assigns each query head its own dedicated key head and value head. For a model with 32 query heads, that means 32 K heads and 32 V heads — all of which must be stored and loaded from the KV cache during every decode step.
The benefit is expressivity. Each head operates in its own learned subspace. One head may capture syntactic dependencies, another semantic similarity, another long-range coreference. MHA gives the model maximum freedom in how it routes information.
The cost scales with head count. A 32-head MHA model stores 32 K/V pairs per token per layer, for every sequence in the batch. At long contexts this KV footprint becomes the dominant memory consumer, often eclipsing the model weights themselves.
Multi-Query Attention: One Shared K/V
Introduced by Noam Shazeer in 2019, Multi-Query Attention takes the aggressive position: keep many query heads, but collapse keys and values down to a single shared head. All query heads attend to the exact same K and V representations.
This collapses the KV cache by a factor of H (the number of query heads). For a 32-head model, that's a 32× reduction. Batch sizes can increase, throughput scales up, and long-context serving becomes dramatically cheaper.
MQA does cost some quality. Ainslie et al. (2023) report that MQA can degrade quality relative to MHA. Their grouped variant keeps a few KV heads instead of one. It reaches quality close to MHA while keeping decode speed close to MQA, which is a better quality/efficiency point for most use cases.
Grouped-Query Attention: The Practical Middle Ground
Grouped-Query Attention (Ainslie et al., 2023) occupies the design space between MHA and MQA. Query heads are partitioned into G groups. Each group shares one key head and one value head, learned independently across groups.
If a model has 32 query heads and 8 KV heads, each KV head serves exactly 4 query heads (group size = 4). The KV cache is 4× smaller than full MHA, while preserving 8 distinct K/V representations across groups.
- MHA: num_kv_heads = num_query_heads
- MQA: num_kv_heads = 1
- GQA: 1 < num_kv_heads < num_query_heads
- Group size (query heads per KV head): num_query_heads / num_kv_heads
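The grouping can be sketched as a single NumPy function covering all three variants. It assumes contiguous groups, so query head h reads KV head h // group_size. The function name and the head-first tensor layout are illustrative assumptions. The check confirms that GQA gives the same result as MHA with each KV head physically repeated group_size times. The only difference is how much K/V must be stored.

```python
import numpy as np

def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def grouped_query_attention(q, k, v):
    """q: (num_query_heads, seq_len, head_dim); k, v: (num_kv_heads, seq_len, head_dim).
    MHA when num_kv_heads == num_query_heads, MQA when num_kv_heads == 1."""
    num_query_heads, seq_len, head_dim = q.shape
    num_kv_heads = k.shape[0]
    assert num_query_heads % num_kv_heads == 0, "group size must be an integer"
    group_size = num_query_heads // num_kv_heads
    causal = np.triu(np.ones((seq_len, seq_len), dtype=bool), k=1)
    out = np.empty_like(q)
    for h in range(num_query_heads):
        kv = h // group_size  # query heads in the same group share this KV head
        scores = q[h] @ k[kv].T / np.sqrt(head_dim)
        scores = np.where(causal, -np.inf, scores)
        out[h] = softmax(scores) @ v[kv]
    return out

rng = np.random.default_rng(0)
q = rng.standard_normal((32, 10, 16))  # 32 query heads
k = rng.standard_normal((8, 10, 16))   # 8 KV heads -> group size 4
v = rng.standard_normal((8, 10, 16))
gqa = grouped_query_attention(q, k, v)
# Same result as MHA with each KV head copied 4 times (np.repeat keeps groups contiguous):
mha = grouped_query_attention(q, np.repeat(k, 4, axis=0), np.repeat(v, 4, axis=0))
print(np.allclose(gqa, mha))  # True
```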
(Figure: KV cache reduction at a glance.)
Why GQA Wins in Practice
The GQA paper's experiments used T5-XXL, which has 64 query heads. MHA checkpoints converted to GQA with 8 KV groups, then briefly trained further, scored close to the MHA baseline across summarization, translation, and question-answering tasks, while offering large inference savings. Converted MQA models were faster still but showed a larger quality drop, so MQA can need more training to close the gap.
GQA is now the dominant choice for frontier open-weight models. It hits the efficiency target without paying a meaningful quality tax.
What Real Models Actually Use
Abstract principles are clearer with concrete examples. Here are the attention configurations of widely-used models:
| Model | Query Heads | KV Heads | Type | Q:KV Ratio |
|---|---|---|---|---|
| GPT-2 (all sizes) | 12–25 | = Q heads | MHA | 1× |
| Llama 2 7B | 32 | 32 | MHA | 1× |
| Llama 2 70B | 64 | 8 | GQA | 8× |
| Llama 3 8B | 32 | 8 | GQA | 4× |
| Llama 3 70B | 64 | 8 | GQA | 8× |
| Mistral 7B v0.1 | 32 | 8 | GQA | 4× |
| Falcon 7B | 71 | 1 | MQA | 71× |
| Gemma 2 9B | 16 | 8 | GQA | 2× |
| Qwen2.5 72B | 64 | 8 | GQA | 8× |
The pattern is clear: most models released from mid-2023 onward adopt GQA, with 8 KV heads being especially common. Pure MQA (1 KV head) appears in models where decode throughput is the absolute priority, such as Falcon 7B. MHA survives mainly in earlier or smaller models such as GPT-2 and Llama 2 7B, trained before GQA became standard.
The Full Tradeoff Map
Choosing an attention variant means navigating four axes simultaneously: model quality, KV cache size, decode throughput, and training cost.
- MHA (most expressive): Each query head has dedicated K/V representations. It has the strongest quality ceiling, especially on tasks requiring fine-grained multi-perspective attention. Largest KV cache, lowest decode throughput.
- GQA (practical balance): Groups of query heads share K/V. The cache is 4–8× smaller, and quality typically stays close to MHA at equivalent training compute. Dominant choice for production LLMs since 2023.
- MQA (maximum efficiency): A single K/V head is shared by all query heads. Smallest possible cache and highest decode throughput, but more training may be needed to recover quality. Best when raw throughput matters more than the last margin of quality.
| Dimension | MHA | GQA | MQA |
|---|---|---|---|
| KV heads | = Q heads | fraction of Q heads | 1 |
| KV cache size | Largest | Reduced (÷ group size) | Minimum (÷ Q heads) |
| Decode throughput | Lowest | High | Highest |
| Max batch size (fixed VRAM) | Smallest | Larger | Largest |
| Representational diversity | Highest | Good | Lowest |
| Training to match MHA quality | Baseline | ~same budget | Needs more steps |
| Production adoption (2024–25) | Legacy | Dominant | Niche |
A useful mental model: MHA is like giving every analyst their own private notes. MQA forces everyone to share a single whiteboard. GQA creates small teams — each team shares notes internally, but different teams have different notebooks.
Implementation Intuition
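In code, GQA changes only the tensor shapes for K and V. Queries retain the full head count, while keys and values are smaller. q has shape (batch, seq_len, num_query_heads, head_dim). k and v have shape (batch, seq_len, num_kv_heads, head_dim).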
Each query head must be paired with the correct KV head. Logically, each KV head is
repeated group_size times. Production kernels avoid physically copying
data and instead broadcast in-place, but the conceptual mapping is straightforward:
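```python
import torch

def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor:
    """Expand KV heads to match query head count.
    x: (batch, seq_len, num_kv_heads, head_dim)
    returns: (batch, seq_len, num_kv_heads * n_rep, head_dim)
    """
    bs, seq_len, n_kv_heads, head_dim = x.shape
    if n_rep == 1:  # MHA: already one KV head per query head
        return x
    return (
        x[:, :, :, None, :]  # insert group dim
        .expand(bs, seq_len, n_kv_heads, n_rep, head_dim)  # broadcast view, no copy yet
        .reshape(bs, seq_len, n_kv_heads * n_rep, head_dim)  # materializes the copy
    )

# Example: Llama-3-8B, 32 query heads, 8 KV heads -> group_size 4
num_query_heads = 32
num_kv_heads = 8
head_dim = 128
group_size = num_query_heads // num_kv_heads  # 4

batch, seq_len = 2, 16  # illustrative sizes
k = torch.randn(batch, seq_len, num_kv_heads, head_dim)  # as stored in the KV cache
v = torch.randn(batch, seq_len, num_kv_heads, head_dim)

k = repeat_kv(k, group_size)  # (B, S, 32, head_dim)
v = repeat_kv(v, group_size)  # (B, S, 32, head_dim)
# Query head h now lines up with KV head h // group_size, so standard scaled
# dot-product attention works as-is (transpose to (B, H, S, D) if the kernel expects it).
```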
The KV cache only stores the original num_kv_heads tensors.
The repetition is applied at attention time, not stored. This is where the memory
saving actually comes from: you cache 8 heads' worth of K/V instead of 32, then
expand on the fly during compute.
Flash Attention and GQA
FlashAttention-2 and FlashAttention-3 support GQA natively. flash_attn_func accepts K and V
with fewer heads than Q, provided the number of query heads is divisible by the number of KV heads.
The kernel handles the broadcast internally without materializing the expanded tensor,
so real deployments benefit from both the cache savings and efficient fused attention.
Using repeat_kv explicitly is mainly useful for understanding or for non-Flash implementations.
Practical Gotchas
A few non-obvious constraints to know before reaching for GQA in a new architecture:
Group size must be an integer. If you want 32 query heads, valid KV head counts are 1, 2, 4, 8, 16, or 32. Choosing 6 or 10 KV heads would create unequal groups and is not standard.
When sharding attention across multiple GPUs (tensor parallelism degree T),
each GPU handles num_query_heads / T query heads. For this to work cleanly
with GQA, you also need num_kv_heads ≥ T, and in practice divisible by T, so that each GPU
owns at least one whole KV head. Otherwise, frameworks typically replicate KV heads across GPUs,
which gives back some of the memory savings. This is one reason 8 KV heads is common:
it supports up to 8-way tensor parallelism cleanly.
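The two sizing constraints above can be checked mechanically. Both helpers are hypothetical. `shards_cleanly` encodes one strict reading of "cleanly": every GPU gets whole query heads and whole, unreplicated KV heads. Real frameworks may relax this by replicating KV heads.

```python
def valid_kv_head_counts(num_query_heads: int) -> list[int]:
    """KV head counts that split the query heads into equal integer groups."""
    return [n for n in range(1, num_query_heads + 1) if num_query_heads % n == 0]

def shards_cleanly(num_query_heads: int, num_kv_heads: int, tp_degree: int) -> bool:
    """Hypothetical check: equal groups, whole query heads per GPU, and whole KV heads
    per GPU with no replication (num_kv_heads divisible by tp_degree implies >= tp_degree)."""
    return (
        num_query_heads % num_kv_heads == 0
        and num_query_heads % tp_degree == 0
        and num_kv_heads % tp_degree == 0
    )

print(valid_kv_head_counts(32))  # [1, 2, 4, 8, 16, 32]
print([t for t in (1, 2, 4, 8, 16) if shards_cleanly(64, 8, t)])  # [1, 2, 4, 8]
```

Contiguous groups line up with the shard boundaries under these conditions. With 64 query heads and 8 KV heads at T = 4, for example, each GPU holds 16 query heads and exactly the 2 KV heads they read.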
Speculative decoding uses a draft model to propose multiple tokens, which are then verified in parallel by the main model. With GQA, the verification step is cheaper per token — but multiple draft tokens mean multiple simultaneous KV appends. Profile the combined system rather than reasoning about each component in isolation.
The GQA paper's conversion recipe mean-pools the original key and value heads within each group, turning an MHA checkpoint into a GQA or MQA one. It then continues pre-training ("uptraining") for a small fraction of the original compute, about 5% in the paper. This avoids full retraining. Still, validate a converted model on your target tasks rather than assuming it matches one trained with GQA from scratch.
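The mean-pooling step can be sketched on a single projection matrix. This covers only the weight conversion; the uptraining that follows is not shown. It assumes a layout where K = x @ w_k and the heads are contiguous head_dim-sized column blocks of w_k. Real checkpoints may store weights transposed (PyTorch's nn.Linear keeps (out_features, in_features)), so the axes would need adjusting. The function name is hypothetical.

```python
import numpy as np

def mean_pool_kv_heads(w: np.ndarray, num_query_heads: int, num_kv_heads: int) -> np.ndarray:
    """Average MHA key (or value) heads within each group to get GQA heads.
    w: (d_model, num_query_heads * head_dim), heads contiguous along the last axis.
    returns: (d_model, num_kv_heads * head_dim)"""
    d_model, total = w.shape
    head_dim = total // num_query_heads
    group_size = num_query_heads // num_kv_heads
    # Head h = g * group_size + r, so groups are contiguous runs of heads.
    grouped = w.reshape(d_model, num_kv_heads, group_size, head_dim)
    return grouped.mean(axis=2).reshape(d_model, num_kv_heads * head_dim)

rng = np.random.default_rng(0)
w_k = rng.standard_normal((64, 8 * 16))  # toy MHA: d_model 64, 8 heads of dim 16
w_k_gqa = mean_pool_kv_heads(w_k, num_query_heads=8, num_kv_heads=2)
print(w_k_gqa.shape)  # (64, 32)
```

Setting num_kv_heads=1 averages all heads into one, which is the MQA case.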
Many serving frameworks (TensorRT-LLM and vLLM among them) can quantize the KV cache to INT8 or FP8 to further reduce memory pressure. GQA and KV quantization are orthogonal optimizations and can be combined. Test on your target tasks, though: with GQA, each quantized KV head is read by every query head in its group, so its quantization error propagates to all of them.
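To make the idea concrete, here is a simple symmetric absmax INT8 scheme with one scale per cached K vector. It was chosen for illustration only and is not necessarily what TensorRT-LLM or vLLM implement. The helper names are hypothetical. Under GQA, each dequantized head k_hat[g] is what every query head in group g attends over, so its rounding error is shared across the whole group.

```python
import numpy as np

def quantize_int8(x: np.ndarray, axis: int = -1):
    """Symmetric absmax INT8 quantization with one scale per vector along `axis`."""
    scale = np.abs(x).max(axis=axis, keepdims=True) / 127.0
    scale = np.where(scale == 0, 1.0, scale)  # avoid division by zero for all-zero vectors
    q = np.clip(np.round(x / scale), -127, 127).astype(np.int8)
    return q, scale

def dequantize_int8(q: np.ndarray, scale: np.ndarray) -> np.ndarray:
    return q.astype(np.float32) * scale

rng = np.random.default_rng(0)
k_cache = rng.standard_normal((8, 1024, 128)).astype(np.float32)  # (num_kv_heads, seq_len, head_dim)
q8, scale = quantize_int8(k_cache)
k_hat = dequantize_int8(q8, scale)
err = np.abs(k_hat - k_cache)  # rounding error, at most half a quantization step
print(q8.nbytes / (k_cache.size * 2))  # 0.5: INT8 payload vs FP16 (scales not counted)
```

The per-vector scales add a small overhead on top of the halved payload. Here that is one float32 per 128 elements.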
The Short Version
Standard Multi-Head Attention (MHA) is expressive: every query head has private key and value projections. But storing all those KV vectors at long contexts and high concurrency consumes enormous GPU memory.
Multi-Query Attention (MQA) goes to the other extreme: one shared K/V head for all queries. Maximum cache savings, highest throughput, but real quality costs.
Grouped-Query Attention (GQA) finds the pragmatic midpoint: small groups of query heads share a K/V head. The cache shrinks by the group size factor (typically 4–8×). Quality stays close to MHA with equivalent training. Most major open-weight model families released since mid-2023 use GQA by default.
As context windows grow from tens of thousands to millions of tokens, and as inference moves to higher concurrency and longer agentic chains, the KV cache becomes more dominant in system design. GQA — and increasingly hybrid architectures that combine it with KV cache quantization, paged attention, and CXL memory tiering — is how the field keeps memory pressure under control without sacrificing the quality of the models people actually want to run.
Sources & Further Reading
- Ainslie et al. (2023). GQA: Training Generalized Multi-Query Transformer Models from Multi-Head Checkpoints. arxiv.org/abs/2305.13245 — the paper that popularized GQA and introduced the mean-pooling checkpoint conversion method.
- Shazeer (2019). Fast Transformer Decoding: One Write-Head is All You Need. arxiv.org/abs/1911.02150 — the original MQA paper.
- Vaswani et al. (2017). Attention Is All You Need. arxiv.org/abs/1706.03762 — the original Transformer paper introducing MHA.
- Meta AI (2024). Llama 3 Model Card — source for Llama 3 8B and 70B attention configs.
- Dao et al. (2022). FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness. arxiv.org/abs/2205.14135 — context for how GQA is implemented efficiently in practice.