Gradient Checkpointing Is a Memory Problem, Not a Training Trick
The inference memory essays in this series talk about KV cache, weight residency, and HBM pressure. Training has its own version of all three — and the dominant technique for managing training memory, gradient checkpointing, is widely taught as an "algorithm" when it is really a data movement policy. This essay derives it from first principles and shows what an optimal implementation looks like.
The Inference/Training Memory Gap Nobody Explains
This series has spent considerable space on inference memory: KV cache management, weight streaming, HBM residency, and the hierarchy from HBM to DRAM to CXL. All of that is about the memory required to run a model forward on new inputs.
Training has a different and in some ways harder memory problem. During inference, once a layer computes its output, the layer's input activations can be discarded. The model never looks backward. During training, the backward pass must compute gradients for every parameter in every layer — and computing those gradients requires knowing the activations that were present during the forward pass. The backward pass is, in a precise sense, a read-heavy workload on the intermediate results of the forward pass.
This creates the activation memory problem: to train a model with backpropagation, you must store all intermediate activations for the full forward pass until the corresponding backward pass completes. For large models at large sequence lengths, this activation footprint is larger — often much larger — than the weight footprint that training discussions usually focus on.
During inference: memory = weights + KV cache. Activations are transient and small. During training: memory = weights + gradients + optimizer states + all activations for the full forward pass. For a 70B model at sequence length 8192 with batch size 4, activation memory alone exceeds 200 GB — larger than the weights in many precision regimes.
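To make the contrast concrete, here is a back-of-envelope tally in Python. The sizes follow the essay's 70B / B=4 / S=8192 example; the 40 GB KV cache figure is an illustrative assumption, not a measurement.

```python
# Back-of-envelope memory tally for inference vs training.
def inference_memory_gb(params_b=70, bytes_per_weight=2, kv_cache_gb=40):
    # weights + KV cache; activations are transient and negligible
    return params_b * bytes_per_weight + kv_cache_gb

def training_memory_gb(params_b=70, act_per_layer_gb=2.5, layers=80):
    weights = params_b * 2                   # BF16 weights
    grads = params_b * 2                     # BF16 gradients
    optimizer = params_b * 8                 # FP32 Adam first + second moments
    activations = act_per_layer_gb * layers  # full forward pass retained
    return weights + grads + optimizer + activations

print(inference_memory_gb())   # 180
print(training_memory_gb())    # 1040.0  (140 + 140 + 560 + 200)
```

The 200 GB activation term is what gradient checkpointing attacks; the rest of this essay works through how.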
Why the Backward Pass Needs the Forward Pass's Memory
To understand gradient checkpointing, you need to understand why the backward pass needs forward-pass activations. Let us be specific. For a single transformer layer computing y = F(x; W):
The forward pass produces y from input x and weights W. During backpropagation, the backward pass receives the upstream gradient ∂L/∂y and must compute:
- ∂L/∂W (weight gradient): used to update the weights. Requires x (the layer's input activation) and ∂L/∂y. For a linear layer: ∂L/∂W = ∂L/∂y · xᵀ.
- ∂L/∂x (input gradient): passed to the previous layer for its backward pass. Requires W and ∂L/∂y. For a linear layer: ∂L/∂x = Wᵀ · ∂L/∂y.
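These two formulas can be verified numerically with a toy linear layer. A NumPy sketch; the loss is taken as L = Σ(∂L/∂y ⊙ y), so the upstream gradient is exact by construction:

```python
import numpy as np

# Numerical check of the backward formulas for y = W @ x (toy shapes).
rng = np.random.default_rng(0)
W = rng.standard_normal((3, 4))
x = rng.standard_normal((4, 1))
dL_dy = rng.standard_normal((3, 1))   # upstream gradient, given

dL_dW = dL_dy @ x.T        # weight gradient: needs x and dL/dy
dL_dx = W.T @ dL_dy        # input gradient: needs W and dL/dy

# Finite-difference check of dL_dW, with L = sum(dL_dy * y)
eps = 1e-6
fd = np.zeros_like(W)
for i in range(W.shape[0]):
    for j in range(W.shape[1]):
        Wp = W.copy(); Wp[i, j] += eps
        fd[i, j] = ((dL_dy * (Wp @ x)).sum() - (dL_dy * (W @ x)).sum()) / eps
assert np.allclose(fd, dL_dW, atol=1e-4)
```

Note that ∂L/∂W needs x, the forward-pass activation, while ∂L/∂x needs only W and the upstream gradient. That asymmetry is the whole story.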
For the linear layer, the weight gradient only needs x, which the layer can retain cheaply. But for the attention mechanism, the situation is worse. Computing ∂L/∂Q, ∂L/∂K, ∂L/∂V requires the softmax output matrix P (shape [B, H, S, S] for standard attention), which is generated during the forward pass. That matrix occupies B × H × S² × dtype_bytes bytes of memory. At B=4, H=64, S=8192, FP16: 4 × 64 × 8192² × 2 = 34 GB. Per layer.
Across 80 layers: 2.7 TB of attention matrices that would need to be held in memory simultaneously. This is why naive full activation retention is impossible for large models at non-trivial sequence lengths.
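The arithmetic behind these figures, as a one-liner:

```python
# Size of one attention score matrix [B, H, S, S], decimal GB.
def attn_matrix_gb(B=4, H=64, S=8192, dtype_bytes=2):
    return B * H * S * S * dtype_bytes / 1e9

per_layer = attn_matrix_gb()
print(round(per_layer, 1))        # 34.4 GB per layer
print(round(per_layer * 80))      # ~2749 GB (~2.7 TB) across 80 layers
```

Note the quadratic dependence on S: doubling the sequence length to 16384 quadruples this term to ~137 GB per layer.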
What Gradient Checkpointing Actually Does
Gradient checkpointing (also called activation recomputation) solves the activation memory problem by trading compute for memory. The core idea: rather than storing all activations through the forward pass, discard most of them and recompute from a checkpoint when the backward pass needs them.
In its simplest form, only the layer inputs (checkpoints) are stored during the forward pass. The per-layer intermediate activations (attention matrices, hidden states at each sublayer) are discarded. When the backward pass reaches layer L, it re-runs the forward pass of layer L starting from the checkpoint to regenerate the discarded activations, then uses them for gradient computation, then discards them again.
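A minimal sketch of this keep-checkpoints, recompute-intermediates loop, using toy tanh layers (illustrative NumPy, not any framework's API):

```python
import numpy as np

# Toy layer: y = tanh(W @ x). The pre-activation h is the "intermediate"
# we choose to discard and recompute.
rng = np.random.default_rng(0)
Ws = [rng.standard_normal((8, 8)) * 0.1 for _ in range(4)]

def forward(x):
    # Keep ONLY each layer's input (the checkpoints); discard h.
    checkpoints = []
    for W in Ws:
        checkpoints.append(x)
        x = np.tanh(W @ x)
    return x, checkpoints

def backward(dL_dy, checkpoints):
    grads = []
    for W, x in zip(reversed(Ws), reversed(checkpoints)):
        h = W @ x                             # recomputation from checkpoint
        dL_dh = dL_dy * (1 - np.tanh(h) ** 2)  # tanh'(h)
        grads.append(dL_dh @ x.T)             # dL/dW
        dL_dy = W.T @ dL_dh                   # gradient for previous layer
    return list(reversed(grads))

y, ckpts = forward(rng.standard_normal((8, 1)))
grads = backward(np.ones_like(y), ckpts)
```

At no point do two layers' intermediates coexist: `h` lives only inside one backward iteration, which is exactly the O(per_layer_act) peak claimed below.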
The memory tradeoff is precise: full checkpointing (a checkpoint at every layer boundary) reduces peak intermediate-activation memory from O(L × per_layer_act) to O(per_layer_act), because at any moment only one layer's intermediate activations exist (the much smaller layer-input checkpoints are still retained for all L layers). The cost is one additional forward pass per layer during backpropagation. In forward-equivalents: baseline training costs roughly one forward plus a backward worth about two forwards, three units total; with checkpointing it costs two forwards plus the same backward, four units total. The overhead is therefore about 33% more FLOPs.
The Optimal Checkpoint Interval
Full checkpointing (checkpoint every layer) minimizes live intermediate memory but maximizes recomputation. There is a better balance: checkpoint every √L layers, which gives O(√L × layer_act) peak memory (far below the O(L × layer_act) of no checkpointing) while recomputing at most √L layers per backward step. Because the checkpointed layers themselves never need recomputation, the FLOPs overhead also lands somewhat below full checkpointing's 33%.
# Optimal checkpoint strategy for L layers
# Memory: O(√L) checkpoints × O(1) per-layer activation = O(√L)
# Recompute: at most √L layers per backward step
def optimal_checkpoint_interval(L, activation_size_per_layer):
    # k = checkpoint every k layers
    # Memory = L/k checkpoints × chkpt_size + k × activation_size_per_layer
    # Minimize over k: d/dk (L/k × chkpt_size + k × act_size) = 0
    # → k_opt = √(L × chkpt_size / act_size)
    # For chkpt_size ≈ act_size (typical): k_opt = √L
    k_opt = L ** 0.5
    memory_at_opt = (L / k_opt + k_opt) * activation_size_per_layer
    return k_opt, memory_at_opt

# Example: L=80, act_size=2.5 GB/layer (70B, S=8192)
k, mem = optimal_checkpoint_interval(80, 2.5)  # k≈8.9, mem≈45 GB
# vs no checkpointing: 200 GB
# vs full checkpointing: ~2.5 GB live per layer (but +33% FLOPs)
How Big Is the Activation Footprint, Really?
Let us be concrete for a 70B-class model. Per transformer layer at sequence length S, batch size B:
| Activation Tensor | Shape | BF16 Size (B=4, S=8192) | Checkpointable? |
|---|---|---|---|
| Layer input | B × S × d | ~430 MB | Checkpoint (keep) |
| Attention scores Q·Kᵀ | B × H × S × S | ~34 GB (S=8192!) | Discard (recompute) |
| Attention weights (post-softmax) | B × H × S × S | ~34 GB | Discard (recompute) |
| Attention output | B × S × d | ~430 MB | Discard (recompute) |
| FFN intermediate (SwiGLU gate) | B × S × d_ff | ~1.5 GB | Discard (recompute) |
| FFN intermediate (SwiGLU up) | B × S × d_ff | ~1.5 GB | Discard (recompute) |
| Total per layer (naïve) | — | ~72 GB | — |
| Checkpointed per layer | — | ~430 MB | Keep only input |
The attention score matrices dominate. At sequence length 8192, they occupy 34 GB per layer × 2 (scores + weights) = 68 GB — out of ~72 GB total activation memory per layer. This is the deep reason why FlashAttention's contribution to training is at least as large as its contribution to inference: FlashAttention recomputes the attention matrices during the backward pass from the smaller Q, K, V tensors, effectively applying "free" gradient checkpointing to the largest activation tensors in the model.
FlashAttention does not store the S×S attention matrix during the forward pass. It computes attention in tiles, discards the per-tile matrices, and stores only the O, L (logsumexp), and m (running max) statistics needed for the backward pass recomputation. This is mathematically equivalent to checkpointing the attention sublayer, with zero recomputation cost because the backward-pass recomputation is fused into the kernel. For S=8192 at B=4, H=64: this saves ~68 GB per layer of activation memory, for free.
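The running-statistics trick can be demonstrated in a few lines: an online softmax over tiles, carrying only a running max m and normalizer l, reproduces the full-softmax attention output exactly. A toy single-query NumPy sketch (not the actual fused kernel):

```python
import numpy as np

# Online (tiled) softmax-attention for one query row.
rng = np.random.default_rng(0)
S, d, tile = 16, 8, 4
q = rng.standard_normal(d)
K = rng.standard_normal((S, d))
V = rng.standard_normal((S, d))

m, l = -np.inf, 0.0          # running max and running normalizer
o = np.zeros(d)              # unnormalized output accumulator
for t in range(0, S, tile):
    s = K[t:t+tile] @ q                      # scores for this tile ONLY
    m_new = max(m, s.max())
    l = l * np.exp(m - m_new) + np.exp(s - m_new).sum()
    o = o * np.exp(m - m_new) + np.exp(s - m_new) @ V[t:t+tile]
    m = m_new
o /= l                                       # normalize at the end

# Reference: materialize all S scores at once
s_full = K @ q
ref = np.exp(s_full - s_full.max()) / np.exp(s_full - s_full.max()).sum() @ V
assert np.allclose(o, ref)
```

The full S-length score row never exists; only the per-tile slice plus (m, l, o) do. Scale S to 8192 and this is the 34 GB-per-matrix saving described above.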
Choosing the Right Checkpointing Strategy
| Strategy | Peak Activation Memory | FLOPs Overhead | Practical Use Case |
|---|---|---|---|
| No checkpointing | O(L) — 200+ GB (70B, S=8K) | 0% | Small models, short sequences |
| Selective (attention only) | ~O(L × non-attn) | ~10–15% | FA2/FA3 + partial checkpointing |
| √L interval | O(√L) — ~45 GB (70B) | ~20–25% | Memory-constrained dense training |
| Full (every layer) | O(1) — ~430 MB/layer | ~33% | Maximum memory compression needed |
| Full + FA3 | O(1) — ~430 MB/layer | ~15–18% | Best practical choice for large S |
The modern standard for training large models at long sequence lengths is full gradient checkpointing combined with FlashAttention. FlashAttention eliminates the dominant activation tensors (the S×S attention matrices) with zero recomputation overhead. Full checkpointing on the remaining activations (FFN intermediates, layer inputs) costs 15–18% additional FLOPs rather than 33%, because the attention recomputation comes free inside the kernel. The resulting peak activation memory is dominated by the retained layer-input checkpoints: roughly 430 MB × 80 layers ≈ 34 GB at B=4, S=8192, manageable within HBM budgets that must also accommodate weights, gradients, and optimizer states.
Training Memory Is Activations + Weights + Gradients + Optimizer States
The activation analysis above is only part of the training memory picture. A complete accounting for a 70B BF16 training run includes:
| Component | Size (70B params) | Notes |
|---|---|---|
| Weights (BF16) | 140 GB | Model parameters |
| Gradients (BF16) | 140 GB | One gradient per weight, same precision |
| Optimizer: first moment (FP32) | 280 GB | Adam m, kept in FP32 for numerical stability |
| Optimizer: second moment (FP32) | 280 GB | Adam v |
| Activations (full chkpt + FA3) | ~34 GB | B=4, S=8192; one ~430 MB input checkpoint per layer × 80 |
| Total | ~874 GB | For B=4, S=8192, single GPU (theoretical) |
Activations, after checkpointing, are a modest fraction of total training memory. Optimizer states — particularly the FP32 Adam momentum and variance tensors — are the dominant consumer at 560 GB for a 70B model. This is why distributed training at 70B scale uses FSDP (Fully Sharded Data Parallelism) or ZeRO-3: these techniques shard the optimizer states across data-parallel replicas, reducing per-GPU optimizer state memory by the sharding degree.
ZeRO-3 with 64-way data parallelism shards weights + gradients + optimizer states across 64 GPUs, reducing per-GPU memory for all three from 840 GB to ~13 GB. The activations — which are not sharded by ZeRO (they are per-batch, not per-parameter) — become the dominant per-GPU memory consumer at high ZeRO sharding degrees. This is why very large-scale training runs combine ZeRO-3 with aggressive activation checkpointing and FlashAttention: each technique attacks a different part of the memory budget.
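The sharding arithmetic in the paragraph above, as a sketch (pure arithmetic, not tied to any framework's configuration):

```python
# Per-GPU memory for the sharded states under ZeRO-3-style partitioning.
# Activations are deliberately excluded: they are per-batch, not
# per-parameter, and stay unsharded on each replica.
def per_gpu_state_gb(params_b=70, dp_degree=64):
    weights = params_b * 2      # BF16
    grads = params_b * 2        # BF16
    optimizer = params_b * 8    # FP32 Adam first + second moments
    return (weights + grads + optimizer) / dp_degree

print(per_gpu_state_gb())  # 13.125  (840 GB / 64)
```

At dp_degree=64 the sharded states drop to ~13 GB per GPU, at which point the unsharded activations become the dominant per-GPU consumer, which is exactly why checkpointing and FlashAttention still matter under ZeRO-3.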
What This Means for Training Infrastructure Design
Gradient checkpointing is usually described as a model training technique. But its memory reduction behavior has direct implications for infrastructure design — specifically for how HBM capacity should be allocated and how DMA bandwidth should be budgeted during training.
The Recompute-vs-Offload Decision
For activation tensors that cannot be eliminated by FlashAttention (FFN intermediates, normalization outputs), there are two options: recompute from the checkpoint (costs FLOPs, no memory), or offload to DRAM between the forward and backward passes (costs DMA bandwidth, saves FLOPs). The right choice depends on the ratio of available compute to available DMA bandwidth:
- If HBM→DRAM bandwidth is underutilized during the forward pass window: offload activations to DRAM, free them from HBM immediately, reload before backward. This is a bandwidth cost, not a compute cost.
- If the training run is compute-bound (MFU near 100%): recomputation wastes compute cycles that could be used for useful work. Offload is preferable if bandwidth permits.
- If the training run is already bandwidth-bound (large batch, long sequence): adding DMA offload contends for the same HBM bus. Recomputation is preferable.
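The decision rule in the list above reduces to a cost comparison between two time estimates. A sketch, with all throughput and FLOP numbers as illustrative assumptions:

```python
# Recompute-vs-offload as a cost comparison (illustrative numbers only).
def should_offload(tensor_bytes, recompute_flops,
                   spare_dma_gbps, spare_compute_tflops):
    # Offload cost: write to DRAM after forward + read back before backward.
    offload_s = 2 * tensor_bytes / (spare_dma_gbps * 1e9)
    # Recompute cost: re-run the producing ops on spare compute.
    recompute_s = recompute_flops / (spare_compute_tflops * 1e12)
    return offload_s < recompute_s

# Hypothetical FFN intermediate: 1.5 GB tensor, ~25 TFLOP to recompute.
print(should_offload(1.5e9, 25e12, spare_dma_gbps=50,
                     spare_compute_tflops=400))   # True: offload wins
print(should_offload(1.5e9, 25e12, spare_dma_gbps=5,
                     spare_compute_tflops=400))   # False: bus is scarce
```

A runtime policy would feed this comparison with measured (not assumed) spare bandwidth and spare compute, per layer and per phase, as the following paragraph argues.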
Production training frameworks are beginning to implement this decision as a runtime policy rather than a static flag. The decision should be made per-layer, per-phase, based on real-time measurements of compute utilization and DMA bandwidth consumption — the same principles as the KV cache memory scheduler in inference systems.
The Connection to Inference Memory Systems
The gradient checkpointing problem and the KV cache eviction problem are structurally identical in one important way: both require deciding, in real time, which tensors to keep in fast memory and which to spill or recompute, under a memory budget that is tighter than the full working set. The solutions share the same toolkit: cost-benefit analysis (memory savings vs recomputation FLOPs), access pattern prediction (which activations will be needed soon), and tiered storage with DMA-based movement.
The key difference is predictability. In inference, KV access patterns are partially unpredictable (request length, tool calls). In training, the backward pass always accesses layer activations in reverse layer order — a perfectly predictable pattern. This means training checkpointing can use an optimal static policy rather than the adaptive policies inference requires. It also makes the DMA prefetch problem simpler: always prefetch the next-needed checkpoint a few layers before the backward pass reaches it.
Training and inference are the same memory problem in different clothes. The forward pass stores activations; the backward pass consumes them. The question — what to keep, what to recompute, what to offload — has the same structure as every KV cache eviction decision in inference. The toolkit is identical. The access patterns are just more predictable.