
Gradient Checkpointing Is a Memory Problem, Not a Training Trick


The inference memory essays in this series talk about KV cache, weight residency, and HBM pressure. Training has its own version of all three — and the dominant technique for managing training memory, gradient checkpointing, is widely taught as an "algorithm" when it is really a data movement policy. This essay derives it from first principles and shows what an optimal implementation looks like.

MANISH AI · April 2026 · ~17 min read · Training · Memory · Backward Pass · Activations
3–5× — activation memory multiplier for a 70B dense model at seq=8192 vs. weights alone
+33% — FLOPs overhead of full checkpointing, in exchange for reducing live per-layer activation memory from O(L) to O(1)
O(√L) — optimal checkpoint interval for L layers reduces activation memory from O(L) to O(√L)
~80% — fraction of training memory that is activations, not weights, at large sequence length (before checkpointing)
01 — Why Training Memory Is Different

The Inference/Training Memory Gap Nobody Explains

This series has spent considerable space on inference memory: KV cache management, weight streaming, HBM residency, and the hierarchy from HBM to DRAM to CXL. All of that is about the memory required to run a model forward on new inputs.

Training has a different and in some ways harder memory problem. During inference, once a layer computes its output, the layer's input activations can be discarded. The model never looks backward. During training, the backward pass must compute gradients for every parameter in every layer — and computing those gradients requires knowing the activations that were present during the forward pass. The backward pass is, in a precise sense, a read-heavy workload on the intermediate results of the forward pass.

This creates the activation memory problem: to train a model with backpropagation, you must store all intermediate activations for the full forward pass until the corresponding backward pass completes. For large models at large sequence lengths, this activation footprint is larger — often much larger — than the weight footprint that training discussions usually focus on.

The Key Asymmetry

During inference: memory = weights + KV cache. Activations are transient and small. During training: memory = weights + gradients + optimizer states + all activations for the full forward pass. For a 70B model at sequence length 8192 with batch size 4, activation memory alone exceeds 200 GB — larger than the weights in many precision regimes.

02 — First Principles

Why the Backward Pass Needs the Forward Pass's Memory

To understand gradient checkpointing, you need to understand why the backward pass needs forward-pass activations. Let us be specific. For a single transformer layer computing y = F(x; W):

The forward pass produces y from input x and weights W. During backpropagation, the backward pass receives the upstream gradient ∂L/∂y and must compute:

  • ∂L/∂W (weight gradient): Used to update the weights. Requires x (the layer's input activation) and ∂L/∂y. Formula for a linear layer: ∂L/∂W = ∂L/∂y · xᵀ.
  • ∂L/∂x (input gradient): Passed to the previous layer for its backward pass. Requires W and ∂L/∂y. Formula: ∂L/∂x = Wᵀ · ∂L/∂y.
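These two formulas can be sanity-checked numerically with a toy bias-free linear layer — a sketch, with illustrative shapes and the scalar loss L = Σy:

```python
import numpy as np

rng = np.random.default_rng(0)
W = rng.standard_normal((3, 4))    # weights, shape [out, in]
x = rng.standard_normal((4, 1))    # input activation — must be retained

y = W @ x                          # forward: y = F(x; W)
dL_dy = np.ones_like(y)            # upstream gradient for the toy loss L = Σy

dL_dW = dL_dy @ x.T                # weight gradient: needs x
dL_dx = W.T @ dL_dy                # input gradient: needs only W

# finite-difference check on one weight entry
eps = 1e-6
W_pert = W.copy(); W_pert[1, 2] += eps
numeric = ((W_pert @ x).sum() - y.sum()) / eps
assert abs(numeric - dL_dW[1, 2]) < 1e-4
```

The assertion passing is the whole point: ∂L/∂W cannot be formed without x, so x must survive until the backward pass reaches this layer.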

For the linear layer, the weight gradient only needs x, which the layer can retain cheaply. But for the attention mechanism, the situation is worse. Computing ∂L/∂Q, ∂L/∂K, ∂L/∂V requires the softmax output matrix P (shape [B, H, S, S] for standard attention), which is generated during the forward pass. That matrix occupies B × H × S² × dtype_bytes bytes of memory. At B=4, H=64, S=8192, FP16: 4 × 64 × 8192² × 2 = 34 GB. Per layer.

Across 80 layers: 2.7 TB of attention matrices that would need to be held in memory simultaneously. This is why naive full activation retention is impossible for large models at non-trivial sequence lengths.
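The arithmetic above, as a small helper (sizes in bytes; the ~34 GB figure is decimal GB, which is 32 GiB):

```python
# One attention matrix of shape [B, H, S, S] in FP16 (2 bytes/element)
def attn_matrix_bytes(B, H, S, dtype_bytes=2):
    return B * H * S * S * dtype_bytes

per_layer = attn_matrix_bytes(4, 64, 8192)
print(per_layer / 1e9)        # ≈ 34.4 GB per matrix, per layer
print(80 * per_layer / 1e12)  # ≈ 2.7 TB across 80 layers
```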

Fig 1 — Memory During Forward and Backward Pass (No Checkpointing)
[Figure: activation-memory timeline — forward pass (L1 → L80) accumulates activations to a peak of ~200 GB (70B @ S=8K, all 80 layers resident); backward pass (L80 → L1) frees them as it proceeds.]
Without checkpointing, all 80 layers' activations accumulate during the forward pass and are freed one-by-one during the backward pass. Peak memory occurs at the end of the forward pass, when all activations are in memory simultaneously. For a 70B model at S=8192, this peak can reach 200+ GB of activation memory alone.
03 — The Checkpointing Solution

What Gradient Checkpointing Actually Does

Gradient checkpointing (also called activation recomputation) solves the activation memory problem by trading compute for memory. The core idea: rather than storing all activations through the forward pass, discard most of them and recompute from a checkpoint when the backward pass needs them.

In its simplest form, only the layer inputs (checkpoints) are stored during the forward pass. The per-layer intermediate activations (attention matrices, hidden states at each sublayer) are discarded. When the backward pass reaches layer L, it re-runs the forward pass of layer L starting from the checkpoint to regenerate the discarded activations, then uses them for gradient computation, then discards them again.

The memory tradeoff is precise: full checkpointing (a checkpoint at every layer boundary) reduces peak activation memory from O(L × per_layer_act) to O(per_layer_act), because at any moment only one layer's intermediate activations exist. The cost is exactly one additional forward pass per layer during backpropagation. Counting the backward pass as roughly two forward-equivalents of FLOPs, baseline training costs 1 forward + 2 backward = 3 forward-equivalents; checkpointed training costs 2 + 2 = 4 — an overhead of approximately 33%.
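The forward-equivalent bookkeeping, spelled out (assuming the usual backward ≈ 2× forward FLOPs for dense layers):

```python
# FLOP bookkeeping in forward-equivalents
baseline     = 1 + 2    # forward + backward
checkpointed = 2 + 2    # one extra forward re-run during the backward pass
overhead = checkpointed / baseline - 1
print(f"{overhead:.0%}")   # → 33%
```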

The Optimal Checkpoint Interval

Full checkpointing (a checkpoint at every layer) minimizes live intermediates but recomputes every layer. There is a middle ground: checkpoint every √L layers. This stores O(√L) checkpoints and, during the backward pass, holds at most one segment of √L layers' recomputed activations at a time — O(√L × layer_act) peak memory instead of the O(L) of no checkpointing, while in practice requiring somewhat less recomputation than checkpointing every layer (see the tradeoff table in section 05).

# Optimal checkpoint strategy for L layers
# Memory: O(√L) checkpoints × O(1) per-layer activation = O(√L)
# Recompute: at most √L layers per backward segment

def optimal_checkpoint_interval(L, activation_size_per_layer):
    # k = checkpoint every k layers
    # Memory = L/k checkpoints × chkpt_size + k × activation_size_per_layer
    # Minimize over k: d/dk (L/k × chkpt_size + k × act_size) = 0
    # → k_opt = √(L × chkpt_size / act_size)
    # For chkpt_size ≈ act_size (typical): k_opt = √L
    k_opt = (L * 1.0) ** 0.5
    memory_at_opt = (L / k_opt + k_opt) * activation_size_per_layer
    return k_opt, memory_at_opt

# Example: L=80, act_size=2.5 GB/layer (70B, S=8192)
k, mem = optimal_checkpoint_interval(80, 2.5)  # k≈9, mem≈45 GB
# vs no checkpointing: 80 × 2.5 = 200 GB
# vs full checkpointing: ~2.5 GB of live intermediates (but +33% FLOPs)
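A brute-force sweep over k confirms the closed-form optimum (same L/k + k cost model as the code above, in units of per-layer activations):

```python
# Sweep every checkpoint interval k and find the one minimizing peak memory
def peak_units(L, k):
    return L / k + k            # L/k stored checkpoints + k live layers

L = 80
best_k = min(range(1, L + 1), key=lambda k: peak_units(L, k))
print(best_k)                                   # → 9, i.e. ≈ √80
print(round(peak_units(L, best_k) * 2.5, 1))    # ≈ 44.7 GB at 2.5 GB/layer
```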
04 — Quantifying Activation Memory

How Big Is the Activation Footprint, Really?

Let us be concrete for a 70B-class model. Per transformer layer at sequence length S, batch size B:

Activation Tensor                | Shape         | BF16 Size (B=4, S=8192) | Checkpointable?
---------------------------------|---------------|-------------------------|--------------------
Layer input                      | B × S × d     | ~430 MB                 | Checkpoint (keep)
Attention scores Q·Kᵀ            | B × H × S × S | ~34 GB (S=8192!)        | Discard (recompute)
Attention weights (post-softmax) | B × H × S × S | ~34 GB                  | Discard (recompute)
Attention output                 | B × S × d     | ~430 MB                 | Discard (recompute)
FFN intermediate (SwiGLU gate)   | B × S × d_ff  | ~1.5 GB                 | Discard (recompute)
FFN intermediate (SwiGLU up)     | B × S × d_ff  | ~1.5 GB                 | Discard (recompute)
Total per layer (naïve)          |               | ~72 GB                  |
Checkpointed per layer           |               | ~430 MB                 | Keep only input

The attention score matrices dominate. At sequence length 8192, they occupy 34 GB per layer × 2 (scores + weights) = 68 GB — out of ~72 GB total activation memory per layer. This is the deep reason why FlashAttention's contribution to training is at least as large as its contribution to inference: FlashAttention recomputes the attention matrices during the backward pass from the smaller Q, K, V tensors, effectively applying "free" gradient checkpointing to the largest activation tensors in the model.
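Summing the table's rows confirms both the ~72 GB total and attention's share of it (numbers copied from the table above):

```python
# Per-layer activation totals (GB, BF16, B=4, S=8192)
tensors_gb = {
    "layer_input":  0.43,
    "attn_scores":  34.0,
    "attn_softmax": 34.0,
    "attn_output":  0.43,
    "ffn_gate":     1.5,
    "ffn_up":       1.5,
}
total = sum(tensors_gb.values())
attn = tensors_gb["attn_scores"] + tensors_gb["attn_softmax"]
print(round(total, 1))           # ≈ 71.9 GB per layer
print(f"{attn / total:.0%}")     # attention matrices ≈ 95% of it
```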

FlashAttention Is Gradient Checkpointing for Attention

FlashAttention does not store the S×S attention matrix during the forward pass. It computes attention in tiles, discards the per-tile matrices, and stores only the O, L (logsumexp), and m (running max) statistics needed for the backward pass recomputation. This is mathematically equivalent to checkpointing the attention sublayer, with zero recomputation cost because the backward-pass recomputation is fused into the kernel. For S=8192 at B=4, H=64: this saves ~68 GB per layer of activation memory, for free.
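The statistic-only storage can be demonstrated in a few lines — a sketch of the idea rather than the fused kernel: keeping only the per-row logsumexp is enough to regenerate a softmax row later from the (recomputed) scores.

```python
import math

scores = [2.0, -1.0, 0.5, 3.0]     # one row of Q·Kᵀ (toy size)

# forward: compute softmax, then keep ONLY the logsumexp statistic
m = max(scores)                     # running max, for numerical stability
lse = m + math.log(sum(math.exp(s - m) for s in scores))
softmax_fwd = [math.exp(s - lse) for s in scores]
# ...the S×S row is now discarded; one scalar (lse) per row survives

# backward: recompute the row from Q·Kᵀ (cheap) and the stored lse
softmax_recomputed = [math.exp(s - lse) for s in scores]
assert all(abs(a - b) < 1e-12
           for a, b in zip(softmax_fwd, softmax_recomputed))
```

At scale this replaces a B × H × S × S tensor with a B × H × S vector of statistics — the source of the ~68 GB per-layer saving quoted above.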

05 — The Memory/Compute Tradeoff Landscape

Choosing the Right Checkpointing Strategy

Strategy                   | Peak Activation Memory     | FLOPs Overhead | Practical Use Case
---------------------------|----------------------------|----------------|----------------------------------
No checkpointing           | O(L) — 200+ GB (70B, S=8K) | 0%             | Small models, short sequences
Selective (attention only) | ~O(L × non-attn)           | ~10–15%        | FA2/FA3 + partial checkpointing
√L interval                | O(√L) — ~45 GB (70B)       | ~20–25%        | Memory-constrained dense training
Full (every layer)         | O(1) — ~430 MB/layer       | ~33%           | Maximum memory compression needed
Full + FA3                 | O(1) — ~430 MB/layer       | ~15–18%        | Best practical choice for large S

The modern standard for training large models at long sequence lengths is full gradient checkpointing combined with FlashAttention. FlashAttention eliminates the dominant activation tensors (S×S attention matrices) with zero recomputation overhead. Full checkpointing on the remaining activations (FFN intermediates, layer inputs) costs 15–18% additional FLOPs rather than 33%, because the attention forward pass recomputation is free. The resulting peak activation memory is on the order of 2–3 GB per concurrent batch item — manageable within HBM budgets that must also accommodate weights, gradients, and optimizer states.

06 — Optimizer State: The Forgotten Memory Consumer

Training Memory Is Activations + Weights + Gradients + Optimizer States

The activation analysis above is only part of the training memory picture. A complete accounting for a 70B BF16 training run includes:

Component                      | Size (70B params) | Notes
-------------------------------|-------------------|-------------------------------------------
Weights (BF16)                 | 140 GB            | Model parameters
Gradients (BF16)               | 140 GB            | One gradient per weight, same precision
Optimizer: momentum (FP32)     | 280 GB            | Adam m₁, FP32 copy for numerical stability
Optimizer: variance (FP32)     | 280 GB            | Adam m₂
Activations (full chkpt + FA3) | ~12 GB            | B=4, S=8192, only layer inputs kept
Total                          | ~852 GB           | For B=4, S=8192, single GPU (theoretical)

Activations, after checkpointing, are a modest fraction of total training memory. Optimizer states — particularly the FP32 Adam momentum and variance tensors — are the dominant consumer at 560 GB for a 70B model. This is why distributed training at 70B scale uses FSDP (Fully Sharded Data Parallelism) or ZeRO-3: these techniques shard the optimizer states across data-parallel replicas, reducing per-GPU optimizer state memory by the sharding degree.

The ZeRO Connection

ZeRO-3 with 64-way data parallelism shards weights + gradients + optimizer states across 64 GPUs, reducing per-GPU memory for all three from 840 GB to ~13 GB. The activations — which are not sharded by ZeRO (they are per-batch, not per-parameter) — become the dominant per-GPU memory consumer at high ZeRO sharding degrees. This is why very large-scale training runs combine ZeRO-3 with aggressive activation checkpointing and FlashAttention: each technique attacks a different part of the memory budget.
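A back-of-the-envelope helper for the ZeRO-3 arithmetic above (sizes in GB, taken from the table in section 06):

```python
# Per-GPU training memory under ZeRO-3 sharding
def per_gpu_gb(dp_degree, weights=140, grads=140, optim=560, activations=12):
    sharded = (weights + grads + optim) / dp_degree  # split across replicas
    return sharded + activations                     # activations stay local

print(per_gpu_gb(1))    # 852.0 — the single-GPU (theoretical) total
print(per_gpu_gb(64))   # 25.125 — ~13 GB sharded state + 12 GB activations
```

Note how the unsharded activations dominate at high sharding degrees — exactly why ZeRO-3 is paired with checkpointing and FlashAttention.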

07 — Systems Implications

What This Means for Training Infrastructure Design

Gradient checkpointing is usually described as a model training technique. But its memory reduction behavior has direct implications for infrastructure design — specifically for how HBM capacity should be allocated and how DMA bandwidth should be budgeted during training.

The Recompute-vs-Offload Decision

For activation tensors that cannot be eliminated by FlashAttention (FFN intermediates, normalization outputs), there are two options: recompute from the checkpoint (costs FLOPs, no memory), or offload to DRAM between the forward and backward passes (costs DMA bandwidth, saves FLOPs). The right choice depends on the ratio of available compute to available DMA bandwidth:

  • If HBM→DRAM bandwidth is underutilized during the forward pass window: offload activations to DRAM, free them from HBM immediately, reload before backward. This is a bandwidth cost, not a compute cost.
  • If the training run is compute-bound (MFU near 100%): recomputation wastes compute cycles that could be used for useful work. Offload is preferable if bandwidth permits.
  • If the training run is already bandwidth-bound (large batch, long sequence): adding DMA offload contends for the same HBM bus. Recomputation is preferable.
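The bullets above can be sketched as a tiny cost model. Everything here is an illustrative assumption — the function name, the sample tensor, and the spare-capacity figures:

```python
# Recompute-vs-offload: pick whichever fits in the spare capacity fastest
def cheaper_policy(tensor_bytes, recompute_flops,
                   spare_flops_per_s, spare_dma_bytes_per_s):
    recompute_s = recompute_flops / spare_flops_per_s
    offload_s = 2 * tensor_bytes / spare_dma_bytes_per_s  # write out + read back
    return "offload" if offload_s < recompute_s else "recompute"

# e.g. a 1.5 GB FFN intermediate, 50 GFLOP to regenerate,
# 10 TFLOP/s of spare compute, 40 GB/s of spare HBM<->DRAM bandwidth
print(cheaper_policy(1.5e9, 50e9, 10e12, 40e9))   # → recompute
```

With those (assumed) numbers, recomputation costs ~5 ms against ~75 ms of round-trip DMA, so the policy recomputes; flip the spare-capacity ratio and it offloads.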

Production training frameworks are beginning to implement this decision as a runtime policy rather than a static flag. The decision should be made per-layer, per-phase, based on real-time measurements of compute utilization and DMA bandwidth consumption — the same principles as the KV cache memory scheduler in inference systems.

The Connection to Inference Memory Systems

The gradient checkpointing problem and the KV cache eviction problem are structurally identical in one important way: both require deciding, in real time, which tensors to keep in fast memory and which to spill or recompute, under a memory budget that is tighter than the full working set. The solutions share the same toolkit: cost-benefit analysis (memory savings vs recomputation FLOPs), access pattern prediction (which activations will be needed soon), and tiered storage with DMA-based movement.

The key difference is predictability. In inference, KV access patterns are partially unpredictable (request length, tool calls). In training, the backward pass always accesses layer activations in reverse layer order — a perfectly predictable pattern. This means training checkpointing can use an optimal static policy rather than the adaptive policies needed for inference. It also means the DMA prefetch problem is simpler: always prefetch the next-to-be-needed checkpoint a fixed number of layers before the backward pass arrives at it.
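The predictable reverse order makes the prefetch schedule itself trivial; a toy sketch with a lookahead of one layer and hypothetical op strings:

```python
# Static DMA prefetch schedule for the backward pass
def prefetch_schedule(num_layers, lookahead=1):
    ops = []
    for i in reversed(range(num_layers)):
        if i - lookahead >= 0:
            ops.append(f"prefetch chkpt[{i - lookahead}]")  # issue DMA early
        ops.append(f"backward layer[{i}]")
    return ops

for op in prefetch_schedule(4):
    print(op)
# every checkpoint's DMA is issued one layer before it is consumed
```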

Training and inference are the same memory problem in different clothes. The forward pass stores activations; the backward pass consumes them. The question — what to keep, what to recompute, what to offload — has the same structure as every KV cache eviction decision in inference. The toolkit is identical. The access patterns are just more predictable.