MAN\SH AI / Writings

· Transformer Inference · 14 min read

Transformer Internals · From First Principles

GQA & MQA, decoded.

Why modern LLMs stopped giving every attention head its own keys and values — and what that one change means for memory, latency, and serving economics at scale.

Multi-Head · Multi-Query · Grouped-Query · KV cache mechanics · Real model configs · Implementation & gotchas
Figure: Query heads vs KV heads (8 query heads shown). MHA pairs Q1–Q8 with 8 K/V heads, GQA shares 2 K/V heads among them, and MQA shares a single K/V head. KV cache pressure drops accordingly.
Contents: Attention 101 · KV Cache · MHA · MQA · GQA · Real Models · Tradeoffs · Code · Gotchas · Summary

Tokens Need Context

A language model generates tokens one at a time. At each step it must decide which earlier tokens are relevant — across potentially hundreds of thousands of positions. Attention is the mechanism that makes that comparison tractable.

Consider: "The trophy doesn't fit in the suitcase because it's too large." What does "it" refer to? Resolving it to "the trophy" requires relating words across the sentence, not just reading left to right. Attention gives every token a way to look at the other tokens (in a decoder, every earlier token) and pull in what matters.

Queries, Keys, and Values

Inside an attention layer, each token's embedding is linearly projected into three vectors. Conceptually:

  • Query (Q): what this token is searching for.
  • Key (K): what this token advertises as searchable.
  • Value (V): what this token contributes if selected.

The dot products between a query and all keys produce relevance scores. A softmax normalizes them into weights, and those weights blend the values into the final output.

-- Scaled dot-product attention --
Attention(Q, K, V) = softmax( Q·Kᵀ / √d_k ) · V

The 1/√d_k scaling keeps dot-product magnitudes stable as the head dimension grows — without it, softmax saturates into near-one-hot distributions and gradients vanish during training.
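A minimal NumPy sketch of the formula above follows. It covers a single head with no masking or batching, and both of those simplifications are ours. The second half illustrates the scaling argument. For random vectors with unit-variance entries, the standard deviation of the raw dot product grows like √d_k. After dividing by √d_k it stays near 1.

```python
import numpy as np


def softmax(x: np.ndarray, axis: int = -1) -> np.ndarray:
    # Subtracting the max keeps exp() from overflowing; the result is unchanged.
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)


def scaled_dot_product_attention(q: np.ndarray, k: np.ndarray, v: np.ndarray) -> np.ndarray:
    """q: (n_q, d_k), k: (n_kv, d_k), v: (n_kv, d_v) -> (n_q, d_v)."""
    d_k = q.shape[-1]
    scores = q @ k.T / np.sqrt(d_k)     # relevance of every key to every query
    weights = softmax(scores, axis=-1)  # each row sums to 1
    return weights @ v                  # blend the values by weight


rng = np.random.default_rng(0)
q = rng.standard_normal((4, 64))
k = rng.standard_normal((10, 64))
v = rng.standard_normal((10, 32))
out = scaled_dot_product_attention(q, k, v)
print(out.shape)  # (4, 32)

# Why 1/sqrt(d_k): the raw std grows like sqrt(d); the scaled std stays near 1.
for d in (16, 256, 4096):
    a = rng.standard_normal((2000, d))
    b = rng.standard_normal((2000, d))
    raw = (a * b).sum(axis=1)
    print(d, round(raw.std(), 1), round((raw / np.sqrt(d)).std(), 2))
```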

Why Multiple Heads?

A single attention head computes one kind of "relevance". Multi-head attention runs H independent attention operations in parallel, each with its own projection matrices. Different heads can specialize: one might track syntactic agreement, another coreference, another positional proximity. Their outputs are concatenated and projected back to the model dimension.

-- Multi-head attention --
MultiHead(Q, K, V) = Concat(head₁, …, headₕ) · W_O
headᵢ = Attention(Q·W_Qᵢ, K·W_Kᵢ, V·W_Vᵢ)
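A minimal NumPy sketch of that formula for self-attention over one sequence follows. As a simplification, W_Q, W_K and W_V are stored as single d_model × d_model matrices, and their column blocks play the role of the per-head W_Qᵢ, W_Kᵢ and W_Vᵢ. Causal masking and batching are omitted.

```python
import numpy as np


def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)


def multi_head_attention(x, w_q, w_k, w_v, w_o, num_heads):
    """Self-attention over x: (seq_len, d_model); all weights are (d_model, d_model).

    Head i uses column block i of each projection: one K/V head per query head (MHA).
    """
    seq_len, d_model = x.shape
    head_dim = d_model // num_heads

    def split_heads(t):  # (seq_len, d_model) -> (heads, seq_len, head_dim)
        return t.reshape(seq_len, num_heads, head_dim).transpose(1, 0, 2)

    q, k, v = split_heads(x @ w_q), split_heads(x @ w_k), split_heads(x @ w_v)
    scores = q @ k.transpose(0, 2, 1) / np.sqrt(head_dim)        # (heads, seq, seq)
    heads = softmax(scores) @ v                                   # (heads, seq, head_dim)
    concat = heads.transpose(1, 0, 2).reshape(seq_len, d_model)   # Concat(head_1..head_h)
    return concat @ w_o                                           # project back to d_model


rng = np.random.default_rng(0)
d_model, num_heads, seq_len = 64, 8, 5
x = rng.standard_normal((seq_len, d_model))
w_q, w_k, w_v, w_o = (rng.standard_normal((d_model, d_model)) / np.sqrt(d_model) for _ in range(4))
y = multi_head_attention(x, w_q, w_k, w_v, w_o, num_heads)
print(y.shape)  # (5, 64)
```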

The KV Cache: Speed Optimization That Becomes a Memory Problem

During training, all tokens in a sequence are processed simultaneously — fully parallelized across the time dimension. During autoregressive generation it's different: the model emits one token per step, and each new token must attend to every prior token.

Without caching, generating token 512 would require recomputing the keys and values for all 511 preceding tokens in every layer, and then doing it all again for token 513. Summed over a sequence of n tokens, that is O(n²) redundant K/V computation, which is far too wasteful for production.

The fix is the KV cache: store each token's K and V vectors once they are computed, and reuse them for all future decode steps. Only the newest token needs a fresh Q, K, and V. Memory bandwidth replaces redundant compute.
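A minimal sketch of that equivalence follows. It assumes a single attention head in a single layer, with random weights and random stand-in embeddings rather than a real model. The cached loop projects only the newest token at each step. The uncached loop recomputes K and V for the whole prefix. Both produce the same outputs.

```python
import numpy as np


def softmax(x):
    x = x - x.max(axis=-1, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=-1, keepdims=True)


class KVCache:
    """Append-only K/V store for one head (a simplified sketch, not a real serving cache)."""

    def __init__(self):
        self.keys, self.values = [], []

    def append(self, k, v):
        self.keys.append(k)
        self.values.append(v)

    def as_arrays(self):
        return np.stack(self.keys), np.stack(self.values)


rng = np.random.default_rng(0)
d = 16
w_q, w_k, w_v = (rng.standard_normal((d, d)) / np.sqrt(d) for _ in range(3))
tokens = rng.standard_normal((6, d))  # stand-ins for the embeddings of 6 tokens

# Cached decode: project ONLY the newest token, then attend over all cached K/V.
cache, cached_out = KVCache(), []
for x_t in tokens:
    cache.append(x_t @ w_k, x_t @ w_v)  # computed once, reused on every later step
    K, V = cache.as_arrays()
    q_t = x_t @ w_q
    cached_out.append(softmax(q_t @ K.T / np.sqrt(d)) @ V)

# Uncached: at step t, recompute K and V for the entire prefix tokens[: t + 1].
uncached_out = []
for t in range(len(tokens)):
    prefix = tokens[: t + 1]
    K, V = prefix @ w_k, prefix @ w_v   # redundant work that grows with t
    q_t = tokens[t] @ w_q
    uncached_out.append(softmax(q_t @ K.T / np.sqrt(d)) @ V)

print(np.allclose(cached_out, uncached_out))  # True
```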

The fundamental tension: the KV cache is the single most important inference optimization, and it is also the single largest consumer of GPU HBM for long-context, high-concurrency workloads.

How Fast Does It Grow?

-- KV cache size in bytes (per layer) --
size = 2 × batch_size × seq_len × kv_heads × head_dim × bytes_per_element
-- The "2" accounts for K and V; repeated across all layers --
total = size × num_layers

Plug in realistic numbers for a Llama-3-70B-style model (80 layers, 8 KV heads, head_dim 128, FP16, 128K context, 32 concurrent users):

~40 GiB of KV cache per active sequence (128K context)
~1.25 TiB in total at 32 concurrent users

That is before model weights, runtime buffers, and scheduling headroom. With full MHA (64 KV heads, one per query head, instead of 8), the 1.25 TiB balloons to roughly 10 TiB. This is why KV head count is one of the most consequential architectural choices in modern LLM design.
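The arithmetic is easy to check with a direct transcription of the formula. In this sketch "128K" is read as 131,072 tokens and results are reported in binary units (GiB/TiB). Both are our assumptions.

```python
def kv_cache_bytes(num_layers: int, kv_heads: int, head_dim: int, seq_len: int,
                   batch_size: int = 1, bytes_per_element: int = 2) -> int:
    """Total KV cache size: the per-layer formula times num_layers (the 2 is K and V)."""
    per_layer = 2 * batch_size * seq_len * kv_heads * head_dim * bytes_per_element
    return per_layer * num_layers


GiB, TiB = 2**30, 2**40
CTX = 128 * 1024  # "128K" read as 131,072 tokens (assumption)

# Llama-3-70B-style shape: 80 layers, 8 KV heads, head_dim 128, FP16 (2 bytes).
print(kv_cache_bytes(80, 8, 128, CTX) / GiB)                  # 40.0
print(kv_cache_bytes(80, 8, 128, CTX, batch_size=32) / TiB)   # 1.25
# Same shape with full MHA: 64 KV heads, one per query head.
print(kv_cache_bytes(80, 64, 128, CTX, batch_size=32) / TiB)  # 10.0
```

Only kv_heads changes between the last two lines, which is why the ratio of query heads to KV heads carries straight through to memory.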

The key term in the formula is kv_heads. Standard MHA sets kv_heads = query_heads. GQA and MQA reduce kv_heads directly — everything else stays the same.

Multi-Head Attention: Expressive, Expensive

Standard Multi-Head Attention assigns each query head its own dedicated key head and value head. For a model with 32 query heads, that means 32 K heads and 32 V heads — all of which must be stored and loaded from the KV cache during every decode step.

MHA: Q1 → K/V1 · Q2 → K/V2 · Q3 → K/V3 · Q4 → K/V4 ···

The benefit is expressivity. Each head operates in its own learned subspace. One head may capture syntactic dependencies, another semantic similarity, another long-range coreference. MHA gives the model maximum freedom in how it routes information.

The cost is proportional to head count. A 32-head MHA model stores 32 K/V pairs per token per layer, for every sequence in the batch. At long contexts this KV footprint becomes the dominant memory consumer, often eclipsing the model weights themselves.

Memory bandwidth, not compute, is often the bottleneck. During decode, the GPU reads K and V vectors from HBM for every token in the cache on every step. Fewer KV heads means fewer bytes moved — that's the direct link between head count and decode throughput.
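A rough sketch of that link follows. It computes the bytes one decode step must stream from HBM just to read the KV cache, and the resulting lower bound on step time. The 3 TB/s bandwidth figure and the model shape (32 layers, head_dim 128, 8K cached tokens, 16 sequences) are illustrative assumptions. The bound ignores weight reads, compute, and kernel overheads.

```python
def kv_bytes_read_per_step(num_layers: int, kv_heads: int, head_dim: int,
                           cached_tokens: int, batch_size: int,
                           bytes_per_element: int = 2) -> int:
    """Bytes of K and V streamed from HBM in one decode step (every cached vector once)."""
    return 2 * num_layers * kv_heads * head_dim * cached_tokens * batch_size * bytes_per_element


def kv_read_floor_ms(bytes_read: int, hbm_bytes_per_s: float) -> float:
    """Lower bound on step time from KV reads alone; ignores weights and compute."""
    return bytes_read / hbm_bytes_per_s * 1e3


HBM_BYTES_PER_S = 3.0e12  # hypothetical accelerator with 3 TB/s of HBM bandwidth

# Illustrative shape: 32 layers, head_dim 128, 8K cached tokens, 16 concurrent sequences.
for name, kv_heads in (("MHA", 32), ("GQA-8", 8), ("MQA", 1)):
    read = kv_bytes_read_per_step(32, kv_heads, 128, 8192, 16)
    print(f"{name:6s} {read / 2**30:5.1f} GiB/step, >= {kv_read_floor_ms(read, HBM_BYTES_PER_S):.2f} ms")
```

The floors come out to roughly 22.9 ms, 5.7 ms, and 0.7 ms. The KV-read cost scales directly with the number of KV heads.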

Multi-Query Attention: One Shared K/V

Introduced by Noam Shazeer in 2019, Multi-Query Attention takes the aggressive position: keep many query heads, but collapse keys and values down to a single shared head. All query heads attend to the exact same K and V representations.

MQA: Q1 Q2 Q3 Q4 Q5 ··· → one shared K/V (1×)

This collapses the KV cache by a factor of H (the number of query heads). For a 32-head model, that's a 32× reduction. Batch sizes can increase, throughput scales up, and long-context serving becomes dramatically cheaper.
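MQA's shape change is easy to see with broadcasting. This sketch uses NumPy with no batching or causal mask, and the shapes are our choice. Queries keep a head axis while K and V have none, so a single cached K/V is reused by all 8 query heads. Per token per layer, the cache holds 2 × head_dim values instead of 2 × num_heads × head_dim.

```python
import numpy as np


def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)


def mqa_attention(q, k, v):
    """q: (num_heads, seq, head_dim); k, v: (seq, head_dim), one K/V shared by all heads."""
    head_dim = q.shape[-1]
    scores = q @ k.T / np.sqrt(head_dim)  # k.T broadcasts across the head axis
    return softmax(scores) @ v            # (num_heads, seq, head_dim)


rng = np.random.default_rng(0)
num_heads, seq, head_dim = 8, 6, 16
q = rng.standard_normal((num_heads, seq, head_dim))
k = rng.standard_normal((seq, head_dim))  # cached once per token, not once per head
v = rng.standard_normal((seq, head_dim))
out = mqa_attention(q, k, v)
print(out.shape)  # (8, 6, 16)
```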

The quality tradeoff is real. All query heads share a single representational view of keys and values. Some heads may compensate by learning more specialized query projections, but the model has less capacity to track diverse relationship types at once. Ainslie et al. (2023) note that MQA can lead to quality degradation, and in their experiments it trailed GQA on the same benchmarks.

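Training with MQA from the start is generally expected to recover more quality than retrofitting it onto an MHA checkpoint. The Ainslie et al. (2023) GQA paper studied the retrofit path. It converts an MHA checkpoint, then continues pre-training ("uptraining") for a small fraction of the original compute, about 5% in the main setting. In that setting, uptrained GQA came close to MHA quality at decode speeds near MQA, while uptrained MQA gave up more quality.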

Grouped-Query Attention: The Practical Middle Ground

Grouped-Query Attention (Ainslie et al., 2023) occupies the design space between MHA and MQA. Query heads are partitioned into G groups. Each group shares one key head and one value head, learned independently across groups.

GQA: Q1 Q2 Q3 Q4 → K/V1 · Q5 Q6 Q7 Q8 → K/V2

If a model has 32 query heads and 8 KV heads, each KV head serves exactly 4 query heads (group size = 4). The KV cache is 4× smaller than full MHA, while preserving 8 distinct K/V representations across groups.

MHA (identity): num_kv_heads = num_query_heads
MQA (extreme): num_kv_heads = 1
GQA (general): 1 < num_kv_heads < num_query_heads
Group size: num_query_heads / num_kv_heads
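The identities above translate directly into a small classifier and a head-mapping rule. This is a sketch with hypothetical helper names. Mapping consecutive query heads to the same KV head matches the Q1–Q4 → K/V1 layout in the diagram and the ordering that repeat_kv produces later in this post. Other codebases could order heads differently.

```python
def attention_variant(num_query_heads: int, num_kv_heads: int) -> str:
    """Classify a config as MHA, GQA or MQA using the identities above."""
    if num_kv_heads < 1 or num_query_heads % num_kv_heads != 0:
        raise ValueError("num_query_heads must be a positive multiple of num_kv_heads")
    if num_kv_heads == num_query_heads:
        return "MHA"
    if num_kv_heads == 1:
        return "MQA"
    return "GQA"


def kv_head_for(query_head: int, num_query_heads: int, num_kv_heads: int) -> int:
    """Index of the KV head a query head reads: consecutive query heads share one."""
    group_size = num_query_heads // num_kv_heads
    return query_head // group_size


print(attention_variant(32, 8))                   # GQA
print([kv_head_for(h, 8, 2) for h in range(8)])   # [0, 0, 0, 0, 1, 1, 1, 1]
```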

KV Cache Reduction at a Glance

MHA: 32 KV heads, 100% of the MHA cache
GQA (group size 4): 8 KV heads, 25%
MQA: 1 KV head, 3.1%

Why GQA Wins in Practice

In the GQA paper's experiments (uptrained T5-XXL, which has 64 query heads), GQA with 8 KV groups reached quality close to the MHA baseline while decoding nearly as fast as MQA. Uptrained MQA was similarly fast but lost more quality, and it fell short on tasks that demand fine-grained contextual discrimination.

GQA is now the dominant choice for frontier open-weight models. It hits the efficiency target without paying a meaningful quality tax.

What Real Models Actually Use

Abstract principles are clearer with concrete examples. Here are the attention configurations of widely-used models:

Configurations from model cards and technical reports. GQA ratio = query_heads / kv_heads.

| Model | Query Heads | KV Heads | Type | GQA Ratio |
|---|---|---|---|---|
| GPT-2 (all sizes) | 12–25 | = Q heads | MHA | 1× |
| Llama 2 7B | 32 | 32 | MHA | 1× |
| Llama 2 70B | 64 | 8 | GQA | 8× |
| Llama 3 8B | 32 | 8 | GQA | 4× |
| Llama 3 70B | 64 | 8 | GQA | 8× |
| Mistral 7B v0.1 | 32 | 8 | GQA | 4× |
| Falcon 7B | 71 | 1 | MQA | 71× |
| Gemma 2 9B | 16 | 8 | GQA | 2× |
| Qwen2.5 72B | 64 | 8 | GQA | 8× |

The pattern is clear: most major open-weight models released from 2023 onward adopt GQA, and 8 KV heads is an especially common choice. Pure MQA (1 KV head) still appears in models that prioritize decode throughput, such as Falcon 7B. MHA remains in older models trained before GQA became standard, such as GPT-2 and Llama 2 7B.

The Full Tradeoff Map

Choosing an attention variant means navigating four axes simultaneously: model quality, KV cache size, decode throughput, and training cost.

MHA

Most expressive

Each query head has dedicated K/V representations. Strongest quality ceiling, especially on tasks requiring fine-grained multi-perspective attention.

↳ Largest KV cache · Lowest decode throughput

GQA

Practical balance

Groups of query heads share K/V. The cache is typically 4–8× smaller, and reported quality stays close to MHA.

↳ Dominant choice for production LLMs since 2023

MQA

Maximum efficiency

Single shared K/V for all queries. Smallest possible cache, highest decode throughput, but requires more training to recover quality.

↳ Best when raw throughput > quality margin

| Dimension | MHA | GQA | MQA |
|---|---|---|---|
| KV heads | = Q heads | fraction of Q heads | 1 |
| KV cache size | Largest | Reduced (÷ group size) | Minimum (÷ Q heads) |
| Decode throughput | Lowest | High | Highest |
| Max batch size (fixed VRAM) | Smallest | Larger | Largest |
| Representational diversity | Highest | Good | Lowest |
| Training to match MHA quality | Baseline | ~same budget | Needs more steps |
| Production adoption (2024–25) | Legacy | Dominant | Niche |

A useful mental model: MHA is like giving every analyst their own private notes. MQA forces everyone to share a single whiteboard. GQA creates small teams — each team shares notes internally, but different teams have different notebooks.

Implementation Intuition

In code, GQA changes the tensor shapes for K and V. Queries retain the full head count; keys and values are smaller:

-- Tensor shapes in a GQA layer --
Q: [batch, seq_len, num_query_heads, head_dim]
K: [batch, seq_len, num_kv_heads, head_dim]
V: [batch, seq_len, num_kv_heads, head_dim]
-- group_size = num_query_heads // num_kv_heads --

Each query head must be paired with the correct KV head. Logically, each KV head is repeated group_size times. Optimized kernels avoid physically copying the data and read the shared KV head directly. The conceptual mapping is still straightforward, though the expand-then-reshape below does materialize a copy:

import torch


def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor:
    """Expand KV heads to match the query head count.

    x: (batch, seq_len, num_kv_heads, head_dim)
    returns: (batch, seq_len, num_kv_heads * n_rep, head_dim)
    """
    bs, seq_len, n_kv_heads, head_dim = x.shape
    if n_rep == 1:  # MHA: already one KV head per query head
        return x
    return (
        x[:, :, :, None, :]                                  # insert group dim
        .expand(bs, seq_len, n_kv_heads, n_rep, head_dim)    # broadcast view, no copy yet
        .reshape(bs, seq_len, n_kv_heads * n_rep, head_dim)  # merging dims copies here
    )


# Example: Llama-3-8B has 32 query heads and 8 KV heads, so group_size is 4
num_query_heads = 32
num_kv_heads = 8
head_dim = 128
group_size = num_query_heads // num_kv_heads  # 4

# Dummy cached K/V for illustration: (batch=2, seq_len=16, 8 KV heads, head_dim)
k = torch.randn(2, 16, num_kv_heads, head_dim)
v = torch.randn(2, 16, num_kv_heads, head_dim)

k = repeat_kv(k, group_size)  # (2, 16, 32, head_dim)
v = repeat_kv(v, group_size)  # (2, 16, 32, head_dim)

# Now standard scaled dot-product attention works as-is

The KV cache only stores the original num_kv_heads tensors. The repetition is applied at attention time, not stored. This is where the memory saving actually comes from: you cache 8 heads' worth of K/V instead of 32, then expand on the fly during compute.

Flash Attention and GQA

FlashAttention 2 and 3 support GQA natively: flash_attn_func accepts K and V with fewer heads than Q, provided the query head count is divisible by the KV head count. The kernel maps each query head to its KV head internally without materializing an expanded tensor. Real deployments therefore get both the cache savings and efficient fused attention. Calling repeat_kv explicitly is mainly useful for understanding, or in implementations without a GQA-aware kernel.
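The idea of broadcasting instead of copying can be sketched in NumPy. The query heads are viewed as (num_kv_heads, group_size), and each group is contracted against its single KV head with einsum. The result is then checked against the copy-based path. This only illustrates the indexing and is not how FlashAttention is implemented. The (heads, seq, head_dim) layout without a batch axis is a simplification.

```python
import numpy as np


def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)


def gqa_grouped(q, k, v):
    """GQA without expanding K/V. q: (num_q_heads, seq, d); k, v: (num_kv_heads, seq, d)."""
    hq, s, d = q.shape
    hkv = k.shape[0]
    qg = q.reshape(hkv, hq // hkv, s, d)                          # (groups, group_size, seq, d)
    scores = np.einsum("gqsd,gtd->gqst", qg, k) / np.sqrt(d)     # each group reads its own KV head
    out = np.einsum("gqst,gtd->gqsd", softmax(scores), v)
    return out.reshape(hq, s, d)


def gqa_repeated(q, k, v):
    """Reference path: physically repeat K/V (as repeat_kv does), then plain MHA math."""
    n_rep = q.shape[0] // k.shape[0]
    k_full = np.repeat(k, n_rep, axis=0)  # KV head j serves query heads j*n_rep .. j*n_rep+n_rep-1
    v_full = np.repeat(v, n_rep, axis=0)
    scores = q @ k_full.transpose(0, 2, 1) / np.sqrt(q.shape[-1])
    return softmax(scores) @ v_full


rng = np.random.default_rng(0)
q = rng.standard_normal((32, 10, 64))
k = rng.standard_normal((8, 10, 64))
v = rng.standard_normal((8, 10, 64))
print(np.allclose(gqa_grouped(q, k, v), gqa_repeated(q, k, v)))  # True
```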

Practical Gotchas

A few non-obvious constraints to know before reaching for GQA in a new architecture:

⚠️
num_query_heads must be divisible by num_kv_heads

Group size must be an integer. If you want 32 query heads, valid KV head counts are 1, 2, 4, 8, 16, or 32. Choosing 6 or 10 KV heads would create unequal groups and is not standard.
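Enumerating the valid choices comes down to listing divisors. A small sketch with a hypothetical helper name:

```python
def valid_kv_head_counts(num_query_heads: int) -> list[int]:
    """KV head counts that split query heads into equal integer groups (the divisors)."""
    return [n for n in range(1, num_query_heads + 1) if num_query_heads % n == 0]


print(valid_kv_head_counts(32))  # [1, 2, 4, 8, 16, 32]
print(valid_kv_head_counts(71))  # [1, 71]: Falcon 7B's 71 heads allow only MQA or MHA
```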

⚠️
Tensor parallelism interacts with KV head count

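When sharding attention across multiple GPUs (tensor parallelism degree T), each GPU handles num_query_heads / T query heads, so num_query_heads must be divisible by T. For this to work cleanly with GQA, you also want num_kv_heads ≥ T (and divisible by T) so that each GPU owns whole KV heads. Past that point, frameworks typically replicate KV heads across GPUs, which gives back part of the memory saving. This is one reason 8 KV heads is common: it supports up to 8-way tensor parallelism without replication.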
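A sketch of the placement logic follows. Query heads are split evenly across GPUs, and each GPU then holds whichever KV heads its query heads map to. When T exceeds num_kv_heads, several GPUs need the same KV head, so it ends up replicated. The helper name and the dict layout are hypothetical. This is an illustration, not any framework's actual sharding code.

```python
def shard_heads(num_query_heads: int, num_kv_heads: int, tp: int) -> list[dict]:
    """Illustrative per-GPU head placement for tensor-parallel degree tp."""
    if num_query_heads % tp or num_query_heads % num_kv_heads:
        raise ValueError("head counts must divide evenly")
    group_size = num_query_heads // num_kv_heads
    q_per_gpu = num_query_heads // tp
    plan = []
    for rank in range(tp):
        q_heads = range(rank * q_per_gpu, (rank + 1) * q_per_gpu)
        kv_heads = sorted({h // group_size for h in q_heads})  # KV heads these queries read
        plan.append({"rank": rank, "q_heads": list(q_heads), "kv_heads": kv_heads})
    return plan


for tp in (8, 16):
    plan = shard_heads(64, 8, tp)  # Llama-3-70B-like: 64 query heads, 8 KV heads
    copies = sum(len(p["kv_heads"]) for p in plan)
    print(f"tp={tp}: GPU 0 holds KV heads {plan[0]['kv_heads']}, {copies} KV head copies in total")
```

At tp=8 each GPU holds exactly one of the 8 KV heads, so there is no duplication. At tp=16 every KV head is stored on two GPUs (16 copies).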

⚠️
Speculative decoding changes the bandwidth math

Speculative decoding uses a draft model to propose multiple tokens, which are then verified in parallel by the main model. With GQA, the verification step is cheaper per token — but multiple draft tokens mean multiple simultaneous KV appends. Profile the combined system rather than reasoning about each component in isolation.

⚠️
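Retrofitting GQA or MQA onto an MHA checkpoint is imperfect

The GQA paper introduced a recipe for converting an MHA checkpoint to GQA or MQA. It mean-pools the original key and value projection heads within each group, then uptrains for a small fraction of the original pre-training compute. This is far cheaper than retraining from scratch, but it is not free. In the paper's experiments, uptrained MQA lost more quality than uptrained GQA, and a converted model may not match one trained with GQA from the start. Evaluate converted checkpoints on your target tasks before relying on them.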
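The pooling step of that conversion can be sketched in NumPy. The sketch assumes a K (or V) projection stored as a d_model × (num_heads · head_dim) matrix with columns laid out head by head. The helper name is hypothetical. The subsequent uptraining, which does the real work of recovering quality, is not shown.

```python
import numpy as np


def mean_pool_kv_heads(w: np.ndarray, num_heads: int, num_groups: int) -> np.ndarray:
    """Average an MHA K or V projection's heads within each group.

    w: (d_model, num_heads * head_dim), columns laid out head by head.
    returns: (d_model, num_groups * head_dim)
    """
    d_model, total = w.shape
    head_dim = total // num_heads
    group_size = num_heads // num_groups
    per_head = w.reshape(d_model, num_groups, group_size, head_dim)
    return per_head.mean(axis=2).reshape(d_model, num_groups * head_dim)


rng = np.random.default_rng(0)
w_k = rng.standard_normal((512, 8 * 64))  # toy MHA projection: 8 heads, head_dim 64
w_k_gqa = mean_pool_kv_heads(w_k, num_heads=8, num_groups=2)
w_k_mqa = mean_pool_kv_heads(w_k, num_heads=8, num_groups=1)
print(w_k_gqa.shape, w_k_mqa.shape)  # (512, 128) (512, 64)
```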

⚠️
KV cache quantization compounds with GQA

Many serving frameworks (TensorRT-LLM, vLLM) can quantize the KV cache to INT8 or FP8 to further reduce memory pressure. GQA and KV quantization are orthogonal and can be combined, but test on your target tasks. With GQA, each quantized KV head is shared by every query head in its group, so its rounding error reaches all of them at once.
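To make that interaction concrete, here is a simplified sketch of symmetric INT8 quantization with one scale per KV head. The per-head scaling granularity and the helper names are our assumptions for illustration. Real frameworks choose their own schemes (including FP8 formats), and this is not their implementation.

```python
import numpy as np


def quantize_int8_per_head(kv: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
    """Symmetric absmax INT8 quantization, one scale per KV head.

    kv: (num_kv_heads, seq_len, head_dim). Returns (int8 codes, per-head scales).
    """
    absmax = np.abs(kv).max(axis=(1, 2), keepdims=True)
    scale = np.maximum(absmax, 1e-8) / 127.0  # guard against all-zero heads
    codes = np.clip(np.round(kv / scale), -127, 127).astype(np.int8)
    return codes, scale


def dequantize_int8(codes: np.ndarray, scale: np.ndarray) -> np.ndarray:
    return codes.astype(np.float32) * scale


rng = np.random.default_rng(0)
k = rng.standard_normal((8, 128, 64)).astype(np.float32)  # 8 KV heads, 128 cached tokens
codes, scale = quantize_int8_per_head(k)
k_hat = dequantize_int8(codes, scale)

print(codes.nbytes / k.astype(np.float16).nbytes)  # 0.5: INT8 halves an FP16 cache
# Every query head in group g reads the same k_hat[g], so this per-head error is shared group-wide.
print(np.abs(k - k_hat).max(axis=(1, 2)))
```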

The Short Version

Standard Multi-Head Attention (MHA) is expressive: every query head has private key and value projections. But storing all those KV vectors at long contexts and high concurrency consumes enormous GPU memory.

Multi-Query Attention (MQA) goes to the other extreme: one shared K/V head for all queries. Maximum cache savings, highest throughput, but real quality costs.

Grouped-Query Attention (GQA) finds the pragmatic midpoint: small groups of query heads share a K/V head. The cache shrinks by the group-size factor (often 4–8×), and quality stays close to MHA with equivalent training. Most major open-weight model families released since 2023 use GQA by default.

-- The design space, in one line --
MHA: kv_heads == query_heads       -- maximum quality
GQA: 1 < kv_heads < query_heads    -- practical balance ← frontier default
MQA: kv_heads == 1                 -- maximum throughput

As context windows grow from tens of thousands to millions of tokens, and as inference moves to higher concurrency and longer agentic chains, the KV cache becomes more dominant in system design. GQA — and increasingly hybrid architectures that combine it with KV cache quantization, paged attention, and CXL memory tiering — is how the field keeps memory pressure under control without sacrificing the quality of the models people actually want to run.

Sources & Further Reading

  • Ainslie et al. (2023). GQA: Training Generalized Multi-Query Transformer Models from Multi-Head Checkpoints. arxiv.org/abs/2305.13245 — the paper that popularized GQA and introduced the mean-pooling checkpoint conversion method.
  • Shazeer (2019). Fast Transformer Decoding: One Write-Head is All You Need. arxiv.org/abs/1911.02150 — the original MQA paper.
  • Vaswani et al. (2017). Attention Is All You Need. arxiv.org/abs/1706.03762 — the original Transformer paper introducing MHA.
  • Meta AI (2024). Llama 3 Model Card — source for Llama 3 8B and 70B attention configs.
  • Dao et al. (2022). FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness. arxiv.org/abs/2205.14135 — context for how GQA is implemented efficiently in practice.