Four Things Nobody Has
Written About FlashDecode
The decode-vs-prefill kernel argument is settled. What isn't settled: exactly when to switch, what kills FlashDecode at high partition count, why CXL breaks the tiling math, and how a runtime should decide between kernels on a live queue. This piece tries to close those gaps with derivations rather than vibes.
What Everyone Already Knows
The core FlashDecode argument is well-established. In autoregressive decode,
Q_len = 1 while KV_len = N. FlashAttention-2's outer loop
over query tiles collapses to a single iteration, leaving most SMs idle — occupancy
as low as 6% on an H100 at batch=1. FlashDecode fixes this by partitioning the KV
sequence across P SMs and reducing partial softmax outputs via a cheap
tree reduction.
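For concreteness, here is a minimal NumPy sketch of that reduction step, showing how per-partition partials (running max, local normaliser, unnormalised output) recombine into the exact attention output. The function name and shapes are illustrative, not any library's API.

```python
import numpy as np

def merge_partials(partials):
    """Combine FlashDecode-style per-partition softmax partials.

    Each partition p over its slice of the KV sequence reports:
      m_p : max of its attention scores
      l_p : sum_j exp(s_j - m_p)          (local normaliser)
      O_p : sum_j exp(s_j - m_p) * V_j    (unnormalised partial output)
    Rescaling every partial to the global max and summing reproduces the
    exact softmax(qK^T)V result for the full sequence.
    """
    m = max(m_p for m_p, _, _ in partials)                    # global max
    l = sum(l_p * np.exp(m_p - m) for m_p, l_p, _ in partials)
    O = sum(O_p * np.exp(m_p - m) for m_p, _, O_p in partials)
    return O / l

# Quick check against a single-pass softmax over the whole sequence:
d, N, P = 8, 64, 4
q, K, V = np.random.randn(d), np.random.randn(N, d), np.random.randn(N, d)
s = K @ q
parts = []
for chunk in np.array_split(np.arange(N), P):
    m_p = s[chunk].max()
    w = np.exp(s[chunk] - m_p)
    parts.append((m_p, w.sum(), w @ V[chunk]))
ref = (np.exp(s - s.max()) / np.exp(s - s.max()).sum()) @ V
assert np.allclose(merge_partials(parts), ref)
```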
The FlashDecoding paper (Dao et al., 2023) showed 2–4× throughput gains at long context. vLLM shipped it. FlashInfer built JIT variants. The mechanistic story is understood.
What isn't in any of those papers: a closed-form for when FA2 catches up, what happens to the reduction itself at high P, how CXL changes optimal partition size, or how a serving runtime should actually choose between kernels at runtime. That's what this piece is about.
The Known Gap, Quantified
Quick calibration. H100 has 132 SMs. In decode with batch=1, 8 heads,
FlashAttention-2 launches 8 threadblocks — one per (batch, head) pair.
Most SMs sit idle. HBM bandwidth utilisation collapses not because bandwidth is
insufficient but because there aren't enough inflight memory requests to saturate it.
This is the settled argument. Everything below is about the edges of that argument — where it breaks down, what kills it, and what it means for system design.
── Original Analysis Begins ──
The Crossover Formula: Exactly When FA2 Catches Up
Every FlashDecode writeup says "the crossover is around 16K context" and leaves it there. Nobody derives why or gives you a formula to compute it for your own hardware and batch configuration. Here it is.
Setting up the occupancy model
Let S = number of SMs, B = batch size, H = heads per layer.
Under FlashAttention-2, the grid size in decode is exactly B × H threadblocks —
no more. Utilised fraction is:
U_FA2 = min(1, B×H / S)
// saturates at 1.0 when B×H ≥ S; equals B×H/S otherwise
Under FlashDecode with partition factor P, the grid is B × H × P:
U_FD = min(1, B×H×P / S)
// saturates when B×H×P ≥ S
Now, bandwidth-limited throughput scales with utilisation. The relative throughput of FA2 vs FlashDecode converges to 1 when their utilisations are equal. FA2 catches up when:
B × H ≥ S
∴ B* = ⌈S / H⌉
// minimum batch size at which FA2 fully saturates SM occupancy
For H100 (S=132) with Llama-3-70B, which uses GQA with 8 KV heads (the decode grid parallelises over KV heads, so H=8):
B* = ⌈132 / 8⌉ = 17
// FA2 fully occupies H100 once batch ≥ 17 on GQA models
For H100 with Llama-1-65B (MHA: H=64 heads): B* = ⌈132 / 64⌉ = 3
// MHA models recover much faster — batch=3 is enough
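A small sketch of this occupancy model; the function names are mine, and the heads argument should be whatever head count the decode grid actually parallelises over (KV heads under GQA, query heads under MHA).

```python
import math

def utilisation_fa2(batch, heads, sms=132):
    # FA2 decode grid is batch x heads threadblocks
    return min(1.0, batch * heads / sms)

def utilisation_fd(batch, heads, partitions, sms=132):
    # FlashDecode decode grid is batch x heads x partitions threadblocks
    return min(1.0, batch * heads * partitions / sms)

def b_star(heads, sms=132):
    # Minimum batch size at which FA2 alone saturates the SMs
    return math.ceil(sms / heads)

print(b_star(heads=8))    # GQA, 8 KV heads -> 17
print(b_star(heads=64))   # MHA, 64 heads   -> 3
```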
The full crossover curve
But batch size alone isn't the full story. Context length N matters too: the larger N is, the more KV data each decode step must stream, and the better FlashDecode's extra launch overhead is amortised. The complete condition for FlashDecode to win, accounting for kernel launch overhead L (roughly 3–5 µs per launch on H100):
B × H < S // FA2 leaves SMs idle
AND
N × d × 2 / BW > P × L // KV streaming time exceeds the added launch overhead
// N=context, d=head_dim, BW=HBM bandwidth, P=partition count, L=launch latency
Substituting H100 values (BW=3.35TB/s, L=4µs, d=128):
N > P × 4e-6 × 3.35e12 / (128 × 2) ≈ P × 52K tokens
// For P=16: FlashDecode only beats FA2 on launch overhead if N > 830K tokens
// For P=1: already beneficial above N > 52K
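The combined condition as a small calculator, under the H100 assumptions stated above (3.35 TB/s HBM, 4 µs per launch, d=128). The function name and defaults are illustrative.

```python
def flashdecode_wins(batch, heads, n_ctx, partitions,
                     sms=132, head_dim=128,
                     hbm_bw=3.35e12, launch_s=4e-6):
    """Both conditions from the text: FA2 undersaturated AND the per-step
    KV traffic large enough to amortise FlashDecode's extra launches."""
    undersaturated = batch * heads < sms
    # N x d x 2 bytes / BW  >  P x L   (the launch-overhead condition above)
    amortised = n_ctx * head_dim * 2 / hbm_bw > partitions * launch_s
    return undersaturated and amortised

# For P=16 the launch-overhead condition alone needs N > ~830K tokens;
# for P=1 it is already met above ~52K tokens.
print(flashdecode_wins(batch=1, heads=8, n_ctx=1_000_000, partitions=16))  # True
print(flashdecode_wins(batch=1, heads=8, n_ctx=100_000, partitions=16))    # False
```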
The Reduction Inversion: When FlashDecode Kills Itself
Every FlashDecode analysis assumes the cross-SM reduction is "cheap." For small P it is. But nobody has derived the P at which the reduction itself becomes the bottleneck — the point where adding more partitions makes things worse.
Modelling the reduction cost
Each SM produces a partial result: local max m_p, normaliser l_p,
and partial output O_p (shape d × 1). The tree reduction
requires O(log P) rounds, each of which transfers one partial between SMs: (d + 2) values (the d-element partial output plus m_p and l_p) at 4 bytes each.
num_rounds = log₂(P)
bytes_per_round = (d + 2) × 4
Total reduction cost:
T_reduce(P) = (d + 2) × 4 × log₂(P) / BW_SMEM
// BW_SMEM ≈ 19 TB/s on H100 for intra-cluster transfers
The KV loading time for one partition:
T_load(P) = N × d × 4 / (P × BW_HBM)
// N tokens × d dims × K+V × FP16, split across P partitions
// assuming full HBM BW: 3.35 TB/s on H100
Total FlashDecode step time:
T_step(P) = T_load(P) + T_reduce(P)
          = N×d×4 / (P × BW_HBM) + (d+2)×4×log₂(P) / BW_SMEM
To find the optimal P*, differentiate with respect to P and set to zero:
dT_step/dP = -N×d×4 / (P² × BW_HBM) + (d+2)×4 / (P × ln2 × BW_SMEM) = 0
Solving for P*:
P* = N × BW_SMEM × ln2 / (BW_HBM × (d+2)/d)
// ≈ N × BW_SMEM × ln2 / BW_HBM for large d
Substituting H100 numbers (BW_HBM = 3.35 TB/s, BW_SMEM ≈ 19 TB/s):
At N = 128K tokens: P* ≈ 503K // far beyond 132 SMs — reduction not yet a bottleneck
At N = 32K tokens: P* ≈ 126K // still not a bottleneck at hardware scale
// Constrained by hardware: P_max = ⌊S / (B×H)⌋
// For B=1, H=8: P_max = ⌊132/8⌋ = 16
Conclusion: at current GPU scale, reduction is never the bottleneck. You should always use max P = floor(S / (B×H)).
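A quick numerical check of the step-time model, using the bandwidth figures assumed above. It reproduces the scale of P* and shows that the hardware cap on P binds long before the reduction does; the helper names are illustrative.

```python
import math

HBM_BW = 3.35e12    # bytes/s
SMEM_BW = 19e12     # bytes/s, the intra-cluster figure assumed above
D = 128             # head_dim

def t_step(n_ctx, p):
    t_load = n_ctx * D * 4 / (p * HBM_BW)              # K+V in FP16, split P ways
    t_reduce = (D + 2) * 4 * math.log2(p) / SMEM_BW    # tree reduction
    return t_load + t_reduce

def p_star(n_ctx):
    # Unconstrained optimum from setting dT_step/dP = 0
    return n_ctx * SMEM_BW * math.log(2) / (HBM_BW * (D + 2) / D)

def p_max(sms=132, batch=1, heads=8):
    # What the hardware actually allows
    return sms // (batch * heads)

print(f"P* at 128K ctx: {p_star(128 * 1024):,.0f}")   # ~5e5, far beyond 132 SMs
print(f"P_max (B=1, H=8): {p_max()}")                 # 16
print(f"step time, P=16 vs P=1: {t_step(128 * 1024, 16):.2e}s vs {t_step(128 * 1024, 1):.2e}s")
```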
CXL Breaks the Tiling Math: Non-Uniform Partition Sizes
At 10M context, KV cache for a 70B model in FP16 exceeds 1 TB. It cannot fit in HBM. CXL-attached DRAM becomes the overflow tier. Every existing FlashDecode analysis assumes a uniform memory access latency across all KV tiles. That assumption dies the moment any KV page lives on CXL.
The latency asymmetry
HBM3 on an H100 delivers roughly 3.35 TB/s; the CXL tier in the example below delivers about 400 GB/s, roughly 8× less, with higher access latency on top. Any KV tile resident on CXL therefore takes several times longer to stream than an equal-sized tile in HBM.
Why uniform partitions break
Standard FlashDecode assigns each SM a partition of size N/P. If N is split
across HBM (first N_HBM tokens) and CXL (remaining N_CXL tokens),
and all partitions are equal size, then some SMs load from HBM and finish early, while
others load from CXL and finish late. The reduction must wait for the slowest SM —
the stragglers dominate step latency.
// If P_HBM = P_CXL = P/2 and BW_CXL ≈ BW_HBM/8:
// CXL partition takes ~8× longer — HBM SMs sit idle 87% of the step
The fix: bandwidth-proportional partitioning
Allocate partitions proportional to each tier's bandwidth so all SMs finish in roughly the same time:
Total partitions: P = P_HBM + P_CXL
P_HBM = P × BW_HBM / (BW_HBM + BW_CXL)
P_CXL = P × BW_CXL / (BW_HBM + BW_CXL)
H100 + CXL example (BW_HBM=3.35 TB/s, BW_CXL=400 GB/s, P=16):
P_HBM = 16 × 3350/(3350+400) = 14.3 → 14
P_CXL = 16 × 400/(3350+400) = 1.7 → 2
// 14 SMs handle HBM tokens, 2 SMs handle CXL tokens
// Both finish in approximately equal time → no straggler penalty
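A sketch of that allocation rule with the rounding handled explicitly; the tier bandwidths are the example figures above and the function name is illustrative.

```python
def split_partitions(p_total, bw_hbm=3.35e12, bw_cxl=400e9):
    """Assign partition counts to each memory tier proportional to its bandwidth,
    so SMs reading HBM and SMs reading CXL finish in roughly the same time."""
    p_hbm = round(p_total * bw_hbm / (bw_hbm + bw_cxl))
    p_hbm = min(max(p_hbm, 1), p_total - 1)   # keep at least one partition per tier
    return p_hbm, p_total - p_hbm

print(split_partitions(16))   # (14, 2) for the H100 + CXL example above
```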
Dynamic Kernel Switching: A Runtime Policy That Actually Works
Existing serving runtimes pick FA2 or FlashDecode at startup and keep it for the session. That's wrong. The optimal kernel changes per-step as batch composition shifts. Here's a lightweight policy that switches correctly.
Why static choice is wrong
A production serving system at any given moment holds a mix of requests:
some at 2K context (recently arrived), some at 128K (long-running agents),
some being prefilled. The effective B × H for the decode step changes
every iteration as requests complete and new ones start.
Static FA2 means agents at 256K context get bad latency indefinitely. Static FlashDecode means short requests pay unnecessary reduction overhead.
The switching policy
From the crossover analysis in the first section, the choice is simply whether FA2 can saturate SM occupancy. Combine that with a minimum-N threshold so that short-context steps don't pay FlashDecode's extra launch overhead.
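A minimal sketch of the resulting per-step decision as it might sit in a scheduler loop. The threshold constant, the request fields, and choose_decode_kernel are illustrative names, not any existing runtime's API; the 8K minimum-context cutoff is an assumed example value.

```python
from dataclasses import dataclass

SMS = 132                    # H100
MIN_CTX_FOR_FD = 8 * 1024    # below this, splitting the KV buys nothing (assumed cutoff)

@dataclass
class DecodeBatch:
    batch_size: int
    kv_heads: int        # heads the decode grid parallelises over (KV heads under GQA)
    max_context: int     # longest KV length in the batch this step

def choose_decode_kernel(b: DecodeBatch) -> tuple[str, int]:
    """Return (kernel, partition_count) for the next decode step."""
    if b.batch_size * b.kv_heads >= SMS:
        return "fa2", 1                    # grid already saturates the SMs
    if b.max_context < MIN_CTX_FOR_FD:
        return "fa2", 1                    # too little KV to be worth splitting
    p = max(1, SMS // (b.batch_size * b.kv_heads))   # max P from the reduction analysis
    return "flash_decode", p

# Re-evaluated every iteration as requests join and leave the batch:
print(choose_decode_kernel(DecodeBatch(batch_size=1, kv_heads=8, max_context=128_000)))
print(choose_decode_kernel(DecodeBatch(batch_size=32, kv_heads=8, max_context=128_000)))
```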
The scheduling-kernel feedback loop
This switching logic creates a second-order effect worth designing for: the scheduler can favour batching long-context requests together to keep FlashDecode's high-P configuration stable across multiple steps, rather than mixing long and short requests in a way that constantly crosses the threshold.
Expected gains from dynamic switching
In a realistic serving trace with 70% short requests (<8K) and 30% long requests (>64K), static FlashDecode wastes ~12 µs per short-context step on reduction overhead that buys nothing (FA2 would be equally fast and cheaper to launch). Static FA2 leaves ~3× throughput on the table for long-context steps.
What These Four Things Mean Together
The four analyses converge on a few actionable conclusions that differ from the current community consensus:
| Common belief | What the math shows |
|---|---|
| "Crossover is ~16K context" | Crossover is batch-dependent. For GQA models with H=8, FA2 needs batch ≥ 17 to recover occupancy on H100 — independent of context length. |
| "Higher P is always better" | Max achievable P = floor(S/BH). There's no benefit to P beyond that. The reduction is not the constraint at hardware scale; you should always use max P. |
| "CXL extends context, same kernel works" | Uniform partition sizes create 3–8× stragglers across HBM/CXL boundary. Bandwidth-proportional partitioning is required for CXL-extended KV serving. |
| "Pick a kernel at startup" | Static choice leaves significant performance on the table for mixed traffic. Dynamic switching is a scheduler change that takes <1 µs and can deliver 2.6× average improvement over static FA2. |
What to Build
In order of implementation difficulty and expected impact:
| # | Work | Difficulty | Expected gain | Where |
|---|---|---|---|---|
| 1 | Dynamic FA2/FD kernel switching in vLLM scheduler | ~1 week | 2–3× on mixed traffic p99 | vLLM scheduler |
| 2 | Expose P as a runtime parameter (not compile-time) | ~2 days | Allows crossover-optimal P per step | FlashInfer / vLLM |
| 3 | Tier-aware partition allocation for CXL KV | ~2–3 weeks | Eliminates straggler penalty at 10M ctx | Runtime + kernel |
| 4 | Validate P* formula with ncu profiling at N=1M | ~3 days | Confirms reduction cost model | Benchmarking |
| 5 | GQA-aware B* computation in autotune | ~1 day | Correct crossover for all model families | Triton autotune |