Kernel Architecture · Unpublished Analysis

Four Things Nobody Has Written About FlashDecode

The decode-vs-prefill kernel argument is settled. What isn't settled: exactly when to switch, what kills FlashDecode at high partition count, why CXL breaks the tiling math, and how a runtime should decide between kernels on a live queue. This piece tries to close those gaps with derivations rather than vibes.

4 original analytical contributions · Derivations from first principles · ~20 min read
00

What Everyone Already Knows

The core FlashDecode argument is well-established. In autoregressive decode, Q_len = 1 while KV_len = N. FlashAttention-2's outer loop over query tiles collapses to a single iteration, leaving most SMs idle — occupancy as low as 6% on an H100 at batch=1. FlashDecode fixes this by partitioning the KV sequence across P SMs and reducing partial softmax outputs via a cheap tree reduction.

The FlashDecoding paper (Dao et al., 2023) showed 2–4× throughput gains at long context. vLLM shipped it. FlashInfer built JIT variants. The mechanistic story is understood.

What isn't in any of those papers: a closed-form for when FA2 catches up, what happens to the reduction itself at high P, how CXL changes optimal partition size, or how a serving runtime should actually choose between kernels at runtime. That's what this piece is about.

01

The Known Gap, Quantified

Quick calibration. H100 has 132 SMs. In decode with batch=1, 8 heads, FlashAttention-2 launches 8 threadblocks — one per (batch, head) pair. Most SMs sit idle. HBM bandwidth utilisation collapses not because bandwidth is insufficient but because there aren't enough inflight memory requests to saturate it.

[Figure: SM utilisation grid — H100 (132 SMs), batch=1, 8 heads. FA2: 8/132 SMs = 6%; FlashDecode P=16: 128/132 = 97%. Each cell = 1 SM; 22 columns × 6 rows.]
Fig 0: The baseline problem. FA2 decode at batch=1, 8 heads leaves 94% of an H100 dark. P=16 FlashDecode partitions fill the grid.

This is the settled argument. Everything below is about the edges of that argument — where it breaks down, what kills it, and what it means for system design.


── Original Analysis Begins ──

I

The Crossover Formula: Exactly When FA2 Catches Up

original

Every FlashDecode writeup says "the crossover is around 16K context" and leaves it there. Nobody derives why or gives you a formula to compute it for your own hardware and batch configuration. Here it is.

Setting up the occupancy model

Let S = number of SMs, B = batch size, H = heads per layer. Under FlashAttention-2, the grid size in decode is exactly B × H threadblocks — no more. Utilised fraction is:

FA2 SM utilisation (decode) util_FA2 = min(1, B × H / S)
// saturates at 1.0 when B×H ≥ S; equals B×H/S otherwise

Under FlashDecode with partition factor P, the grid is B × H × P:

FlashDecode SM utilisation util_FD = min(1, B × H × P / S)
// saturates when B×H×P ≥ S

Now, bandwidth-limited throughput scales with utilisation. The relative throughput of FA2 vs FlashDecode converges to 1 when their utilisations are equal. FA2 catches up when:

Crossover condition: B × H ≥ S

B* = ⌈S / H⌉
// minimum batch size at which FA2 fully saturates SM occupancy

For H100 (S=132) with Llama-3-70B (64 query heads, but H=8 KV heads under GQA):

H100 · Llama-3-70B (GQA: H=8 KV heads) B* = ⌈132 / 8⌉ = 17
// FA2 fully occupies H100 once batch ≥ 17 on GQA models

H100 · Llama-1-65B (MHA: H=64 heads) B* = ⌈132 / 64⌉ = 3
// MHA models recover much faster — batch=3 is enough
Key insight
GQA (Grouped Query Attention) — used in Llama-3, Mistral, Gemma — dramatically raises B* by reducing H. A model with 8 KV heads needs batch=17 to saturate the H100 with FA2. Below that, FlashDecode wins. GQA doesn't just save KV memory; it inadvertently makes FA2 worse at low batch and FlashDecode more valuable.

The full crossover curve

But batch size alone isn't the full story. Context length N matters too: FlashDecode pays extra launch and reduction overhead — modelled here as P × L, for a per-launch latency L of roughly 3–5 µs on H100 — and that overhead is only amortised when loading the KV cache takes long enough to hide it. The complete condition:

Full crossover: FA2 strictly wins when both conditions hold
B × H ≥ S // occupancy recovered
AND
N × d × 4 / BW < P × L // KV load time no longer dwarfs FD's launch overhead

// N=context, d=head_dim, BW=HBM bandwidth, P=partition count, L=launch latency
// (N × d × 4 bytes = N tokens × d dims × K+V × FP16, matching the load model in §II below)

Substituting H100 values (BW=3.35 TB/s, L=4 µs, d=128), FlashDecode's overhead is fully amortised only when:
N > P × 4e-6 × 3.35e12 / (128 × 4) ≈ P × 26K tokens

// For P=16: overhead fully amortised only above N ≈ 420K tokens
// For P=4: amortised above N ≈ 105K — below that, each extra partition is a tax
Practical implication
At moderate context (32K–128K), the benefit is occupancy-driven, not launch-overhead-driven. The right P for moderate context is small (4–8), not the P=64 often cited. P=64 is for million-token contexts only. Using P=64 at 32K context adds reduction overhead that partially negates the occupancy gain.

Compute Your Own B* and Optimal P
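In place of the page's interactive calculator, here is a minimal Python sketch of the same computation. The defaults mirror the widget's inputs (S=132, H=8, N=128K, B=1); the helper name crossover and its constants are illustrative assumptions using the H100 figures above, not an existing API.

crossover.py — B*, max useful P, and the amortisation threshold from Section I
import math

def crossover(S=132, H=8, N=128_000, B=1,
              d=128, bw_hbm=3.35e12, launch_s=4e-6):
    # Illustrative sketch of the Section I formulas — not a library API.
    b_star = math.ceil(S / H)          # batch at which FA2 saturates occupancy
    p_max = max(1, S // (B * H))       # hardware-achievable partition count
    t_load_s = N * d * 4 / bw_hbm      # N tokens × d dims × K+V × FP16
    n_amortised = p_max * launch_s * bw_hbm / (4 * d)
    return {
        "B*": b_star,
        "P_max": p_max,
        "use_flashdecode": B * H < S,
        "kv_load_us": round(t_load_s * 1e6, 1),
        "N_amortised": round(n_amortised),  # launch overhead negligible above this
    }

print(crossover())               # GQA model at batch=1 → FlashDecode, P=16
print(crossover(H=64, B=4))      # 64-head MHA at batch=4 → FA2 already saturated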
II

The Reduction Inversion: When FlashDecode Kills Itself

original

Every FlashDecode analysis assumes the cross-SM reduction is "cheap." For small P it is. But nobody has derived the P at which the reduction itself becomes the bottleneck — the point where adding more partitions makes things worse.

Modelling the reduction cost

Each SM produces a partial result: local max m_p, normaliser l_p, and partial output O_p (shape d × 1). The tree reduction requires O(log P) rounds. Each round involves:

Cost per reduction round cost_round = (d + 2) × 4 bytes // d outputs + scalar m + scalar l, FP32
num_rounds = log₂(P)

Total reduction cost:
T_reduce(P) = (d + 2) × 4 × log₂(P) / BW_SMEM
// BW_SMEM ≈ 19 TB/s on H100 for intra-cluster transfers

The KV loading time for one partition:

KV load time per partition T_load(P) = N × d × 2 × 2 bytes / (P × BW_HBM)
// N tokens × d dims × K+V × FP16, split across P partitions
// assuming full HBM BW: 3.35 TB/s on H100

Total FlashDecode step time:

T_FD(P) = T_load(P) + T_reduce(P)
= N×d×4 / (P × BW_HBM) + (d+2)×4×log₂(P) / BW_SMEM

To find the optimal P*, differentiate with respect to P and set to zero:

∂T_FD/∂P = 0 -N×d×4 / (P² × BW_HBM) + (d+2)×4 / (P × ln2 × BW_SMEM) = 0

Solving for P*:
P* = N × BW_SMEM × ln2 / (BW_HBM × (d+2)/d)
// ≈ N × BW_SMEM × ln2 / BW_HBM for large d

Substituting H100 numbers (BW_HBM = 3.35 TB/s, BW_SMEM ≈ 19 TB/s):

H100 optimal P* formula P* ≈ N × 19e12 × 0.693 / 3.35e12 = N × 3.93

At N = 128K tokens: P* ≈ 503K // far beyond 132 SMs — reduction not yet a bottleneck
At N = 32K tokens: P* ≈ 126K // still not a bottleneck at hardware scale

// Constrained by hardware: P_max = floor(S / (B×H))
// For B=1, H=8: P_max = floor(132/8) = 16
Conclusion: at current GPU scale, reduction is never the bottleneck. You should always use max P = floor(S / (B×H)).
[Figure: T_FD(P) — load vs reduction at N=128K on H100. X-axis: partition count P (4–64); Y-axis: time (µs). Curves: T_load (HBM), T_reduce (SMEM), T_total; P_max=16 marked for B=1, H=8.]
Fig I: T_load falls hyperbolically with P (more SMs share the work). T_reduce grows as log₂(P) but its magnitude is negligible at hardware-achievable P values. The reduction never becomes the bottleneck — you should always use max P.
When this changes
The analysis flips only if the BW_SMEM / BW_HBM ratio collapses — for instance, if partial results have to be reduced across a slower inter-cluster fabric or spilled through global memory — or on a future part whose SM count grows far faster than its reduction bandwidth. Plug candidate numbers into the P* formula and it gives you the threshold. For current H100/A100/H200: always max P.
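To make the model concrete, here is a short Python sketch of T_FD(P) and the P* optimum, using the H100 figures assumed above (including the ~19 TB/s BW_SMEM estimate). Function names are mine, for illustration only.

t_fd_model.py — load vs reduction cost from Section II
import math

HBM_BW = 3.35e12    # H100 HBM bandwidth, bytes/s
SMEM_BW = 19e12     # assumed intra-cluster reduction bandwidth, bytes/s

def t_fd(P, N, d=128, bw_hbm=HBM_BW, bw_smem=SMEM_BW):
    # Step time model: FP16 K+V load split over P SMs,
    # plus log2(P) rounds of FP32 partial-result reduction.
    t_load = N * d * 4 / (P * bw_hbm)
    t_reduce = (d + 2) * 4 * math.log2(P) / bw_smem
    return t_load, t_reduce

def p_star(N, d=128, bw_hbm=HBM_BW, bw_smem=SMEM_BW):
    # Analytic optimum: where dT_FD/dP = 0.
    return N * bw_smem * math.log(2) * d / (bw_hbm * (d + 2))

N = 128_000
for P in (4, 8, 16):
    load, red = t_fd(P, N)
    print(f"P={P:2d}  load={load * 1e6:6.2f} us  reduce={red * 1e9:5.2f} ns")
print(f"P* ≈ {p_star(N):,.0f}   (vs P_max = 16 at B=1, H=8)")

The reduction term prints in nanoseconds while the load term is in microseconds — three orders of magnitude apart at hardware-achievable P, which is the whole point.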
III

CXL Breaks the Tiling Math: Non-Uniform Partition Sizes

original

At 10M context, KV cache for a 70B model in FP16 exceeds 1 TB. It cannot fit in HBM. CXL-attached DRAM becomes the overflow tier. Every existing FlashDecode analysis assumes a uniform memory access latency across all KV tiles. That assumption dies the moment any KV page lives on CXL.

The latency asymmetry

Memory tier latencies — decode access pattern (latency not to scale):
HBM — 80 GB · 3.35 TB/s · ~100 ns
CXL DRAM — 1–2 TB · ~400 GB/s · ~300–500 ns
CPU DRAM (pinned) — ~200 GB/s · ~1 µs
NVMe SSD — ~14 GB/s · ~100 µs (too slow for decode)
Fig II: HBM latency is ~100 ns; CXL latency is 3–5× higher at ~300–500 ns. Uniform partition sizes assume equal latency — CXL violates that.

Why uniform partitions break

Standard FlashDecode assigns each SM a partition of size N/P. If N is split across HBM (first N_HBM tokens) and CXL (remaining N_CXL tokens), and all partitions are equal size, then some SMs load from HBM and finish early, while others load from CXL and finish late. The reduction must wait for the slowest SM — the stragglers dominate step latency.

Straggler-dominated latency with uniform partitions
T_uniform = max(T_HBM_partition, T_CXL_partition)
          = max( N_HBM × d × 4 / (P_HBM × BW_HBM),
                 N_CXL × d × 4 / (P_CXL × BW_CXL) )

// If P_HBM = P_CXL = P/2 and BW_CXL ≈ BW_HBM / 8:
// the CXL half takes ~8× longer — HBM SMs sit idle for ~87% of the step

The fix: bandwidth-proportional partitioning

Give every SM a partition it can drain in the same time as every other SM: tokens-per-partition proportional to the bandwidth of the tier the partition reads from. That means SM counts per tier end up proportional to each tier's total drain time, N_tier / BW_tier:

Equal-finish condition
N_HBM / (P_HBM × BW_HBM) = N_CXL / (P_CXL × BW_CXL)
→ P_HBM / P_CXL = (N_HBM / BW_HBM) / (N_CXL / BW_CXL)

Total partitions: P = P_HBM + P_CXL
P_HBM = P × (N_HBM / BW_HBM) / (N_HBM / BW_HBM + N_CXL / BW_CXL)

H100 + CXL example (BW_HBM=3.35 TB/s, BW_CXL=400 GB/s, P=16, 90% of tokens in HBM, 10% spilled to CXL):
P_HBM = 16 × (0.90/3350) / (0.90/3350 + 0.10/400) = 8.3 → 8
P_CXL = 16 − 8 = 8

// Even a 10% CXL spill claims half the SMs — CXL bandwidth is ~8× lower
// Both tiers now drain in ≈ equal time → no straggler penalty
[Figure: Uniform vs bandwidth-proportional partitioning across HBM+CXL. Uniform split — HBM partitions done early while CXL partitions are still loading (straggler). Proportional split — SMs allocated by tier drain time, all partitions finish together.]
Fig III: Uniform partitions create stragglers when memory tiers have different bandwidths. Bandwidth-proportional allocation assigns more tokens-per-SM to HBM (fast) and fewer to CXL (slow), equalising completion time across all partitions.
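As a sketch of the allocation rule — the allocate_partitions helper and the 90/10 token split are illustrative assumptions, not an existing runtime API:

tier_alloc.py — equal-finish SM allocation across memory tiers
def allocate_partitions(n_tokens, tiers, P=16):
    # Split P SMs across tiers so every partition drains in the same time.
    # tiers: {name: (fraction_of_tokens, bandwidth_bytes_per_s)}
    # Equal-finish condition: P_tier proportional to N_tier / BW_tier.
    drain = {name: frac * n_tokens / bw for name, (frac, bw) in tiers.items()}
    total = sum(drain.values())
    # round() can drift from P by ±1; a production version would fix up the sum
    return {name: max(1, round(P * w / total)) for name, w in drain.items()}

tiers = {"HBM": (0.90, 3.35e12), "CXL": (0.10, 0.40e12)}
print(allocate_partitions(10_000_000, tiers))   # → {'HBM': 8, 'CXL': 8}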
Why this isn't implemented yet
Bandwidth-proportional partitioning requires the runtime to know which KV pages live on which tier at kernel launch time. Current PagedAttention runtimes don't expose tier affinity to the attention kernel — the kernel just sees physical addresses. This requires a runtime change, not just a kernel change. It's a 2–3 week engineering project, not a research problem.
IV

Dynamic Kernel Switching: A Runtime Policy That Actually Works

original

Existing serving runtimes pick FA2 or FlashDecode at startup and keep it for the session. That's wrong. The optimal kernel changes per-step as batch composition shifts. Here's a lightweight policy that switches correctly.

Why static choice is wrong

A production serving system at any given moment holds a mix of requests: some at 2K context (recently arrived), some at 128K (long-running agents), some being prefilled. The effective B × H for the decode step changes every iteration as requests complete and new ones start.

Static FA2 means agents at 256K context get bad latency indefinitely. Static FlashDecode means short requests pay unnecessary reduction overhead.

The switching policy

From Section I, the crossover is simply whether FA2 can saturate SM occupancy. Combine with a minimum-N threshold to avoid unnecessary FlashDecode launch overhead at short context:

Per-step kernel selection policy
if (B × H ≥ S) and (N_avg < N_threshold):
    use FA2                      // fully occupied, low context → FA2 wins
elif (B × H < S):
    P = floor(S / (B × H))
    use FlashDecode(P)           // under-occupied → FD with max P
else:
    use FlashDecode(P=2)         // occupied but long context: small P, modest overhead

// N_threshold ≈ 16K tokens (empirically safe crossover)
// Re-evaluate every scheduling step (~every 1–5 ms in real systems)
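The same policy as a runnable Python sketch — select_kernel is a hypothetical hook, not vLLM's or FlashInfer's actual API:

select_kernel.py — per-step kernel choice
from dataclasses import dataclass

S = 132                # SMs on the target GPU (H100)
N_THRESHOLD = 16_000   # empirically safe crossover from Section I

@dataclass
class KernelChoice:
    kernel: str        # "fa2" or "flashdecode"
    partitions: int = 1

def select_kernel(batch: int, kv_heads: int, n_avg: int) -> KernelChoice:
    # Re-run every scheduling step; cost is a few comparisons.
    occupied = batch * kv_heads >= S
    if occupied and n_avg < N_THRESHOLD:
        return KernelChoice("fa2")                    # saturated, short context
    if not occupied:
        p = S // (batch * kv_heads)                   # under-occupied: max P
        return KernelChoice("flashdecode", p)
    return KernelChoice("flashdecode", 2)             # saturated but long context

print(select_kernel(batch=1, kv_heads=8, n_avg=128_000))   # flashdecode, P=16
print(select_kernel(batch=32, kv_heads=8, n_avg=4_000))    # fa2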

The scheduling-kernel feedback loop

This switching logic creates a second-order effect worth designing for: the scheduler can favour batching long-context requests together to keep FlashDecode's high-P configuration stable across multiple steps, rather than mixing long and short requests in a way that constantly crosses the threshold.

[Figure: Runtime kernel state machine. Observe — sample B, H, N_avg each step; Decide — compare B×H vs S and N_avg vs N_thr; B×H ≥ S and N < N_thr → FlashAttention-2 (full occupancy, low ctx); B×H < S or N ≥ N_thr → FlashDecode(P), P = floor(S / (B×H)); loop on queue state.]
Fig IV: Per-step kernel selection policy. The runtime samples B, H, and N_avg each scheduling step (microseconds). Decision cost is negligible vs kernel latency. Switching eliminates the static-choice penalty for heterogeneous traffic.

Expected gains from dynamic switching

In a realistic serving trace with 70% short requests (<8K) and 30% long requests (>64K), static FlashDecode wastes ~12 µs per short-context step on reduction overhead that buys nothing (FA2 would be equally fast and cheaper to launch). Static FA2 leaves ~3× throughput on the table for long-context steps.

Estimated TPOT (time per output token) improvement — mixed traffic · 70% short / 30% long
Static FA2 — baseline
Static FD — +1.35× avg
Dynamic — +2.6× avg · p99 −40%
Implementation note
The switching decision adds <1 µs of overhead per step (it's a few comparisons and a conditional kernel dispatch). The kernel itself already exists in vLLM and FlashInfer. This is a scheduler change, not a kernel change. It can be shipped without touching CUDA code.

05

What These Four Things Mean Together

The four analyses converge on a few actionable conclusions that differ from the current community consensus:

Common belief → What the math shows

"Crossover is ~16K context" → Crossover is batch-dependent. For GQA models with H=8, FA2 needs batch ≥ 17 to recover occupancy on H100 — independent of context length.

"Higher P is always better" → Max achievable P = floor(S / (B×H)). There's no benefit to P beyond that. The reduction is not the constraint at hardware scale; you should always use max P.

"CXL extends context, same kernel works" → Uniform partition sizes create 3–8× stragglers across the HBM/CXL boundary. Bandwidth-proportional partitioning is required for CXL-extended KV serving.

"Pick a kernel at startup" → Static choice leaves significant performance on the table for mixed traffic. Dynamic switching is a scheduler change that takes <1 µs per step and can deliver 2.6× average improvement over static FA2.
06

What to Build

In order of implementation difficulty and expected impact:

# · Work · Difficulty · Expected gain · Where
1 · Dynamic FA2/FD kernel switching in vLLM scheduler · ~1 week · 2–3× on mixed-traffic p99 · vLLM scheduler
2 · Expose P as a runtime parameter (not compile-time) · ~2 days · Allows crossover-optimal P per step · FlashInfer / vLLM
3 · Tier-aware partition allocation for CXL KV · ~2–3 weeks · Eliminates straggler penalty at 10M ctx · Runtime + kernel
4 · Validate P* formula with ncu profiling at N=1M · ~3 days · Confirms reduction cost model · Benchmarking
5 · GQA-aware B* computation in autotune · ~1 day · Correct crossover for all model families · Triton autotune
The kernel exists. The math is clear. What's missing is a runtime that uses it correctly — dynamically, per-step, with awareness of batch composition and memory tier topology. That's a scheduling problem, not a research problem.