LLM Architecture Survey: From Transformers to Modern Variants
A technical survey of large language model architectures covering backbone design, attention mechanisms, sparsity patterns, mixture-of-experts, normalization strategies, loss modifications, position encodings, and systems-level considerations for training and inference at scale.
1. Introduction
The transformer architecture introduced by Vaswani et al. [1] remains the dominant paradigm for large language models. However, nearly every component has been re-examined and modified to improve training stability, inference efficiency, and scaling behavior. Modern LLMs deviate significantly from the original encoder-decoder design.
This survey tracks the evolution from the vanilla transformer to current decoder-only dense and mixture-of-experts models. We focus on architectural choices that meaningfully impact model quality, throughput, and memory at the 7B to 500B+ parameter scale. The goal is to systematize the design space and connect implementation details to first-principles trade-offs.
2. Transformer Backbone
Virtually all modern LLMs use a decoder-only architecture with causal masking. The encoder-decoder structure was dropped early, both for simplicity and because decoder-only models trained with next-token prediction were observed to scale smoothly to zero-shot and few-shot tasks.
2.1 Standard Decoder Block
A single transformer block composes attention and feed-forward networks around residual connections. The exact ordering has major implications for gradient flow and training stability.
With pre-norm, the block computes x ← x + F(LN(x)): the residual stream x flows through the block with only linear additions, preserving gradient magnitude across depth. Post-norm instead computes LN(x + F(x)), applying normalization after the sub-layer, which can attenuate gradients.
2.2 Pre-norm vs Post-norm Gradient Flow
The original transformer used post-norm, but this becomes unstable beyond ~12 layers without careful learning rate warmup. Pre-norm was adopted by GPT-2 and is now universal. The difference is critical for scaling to 100+ layers.
With pre-norm, the identity residual path gives ∂x_L/∂x_l ≈ I, so ∂L/∂x_l ≈ ∂L/∂x_L for all layers l and gradient magnitude stays roughly constant with depth. Post-norm applies LayerNorm to the sum x + F(x), injecting a scaling factor γ/σ per block. At depth L, gradients scale as (γ/σ)^L, causing vanishing gradients unless compensated by long warmup or layer-wise learning rates.
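The compounding effect of a per-block scale factor can be made concrete with a back-of-envelope calculation. This is a minimal sketch; the per-block attenuation value 0.95 is an illustrative assumption, not a measurement from any model.

```python
def depth_gradient_scale(gamma_over_sigma: float, num_layers: int) -> float:
    """Gradient magnitude multiplier after backpropagating through
    `num_layers` blocks, each contributing a factor of gamma/sigma."""
    return gamma_over_sigma ** num_layers

# Pre-norm: the identity residual path keeps the factor at ~1 per block.
pre_norm_scale = depth_gradient_scale(1.0, 100)    # stays 1.0 at any depth

# Post-norm with a mild, assumed per-block attenuation of 0.95:
post_norm_scale = depth_gradient_scale(0.95, 100)  # ≈ 0.006 after 100 layers
```

Even a 5% per-block attenuation shrinks gradients by more than two orders of magnitude at 100 layers, which is why post-norm needs aggressive warmup at depth while pre-norm does not.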
3. Attention Revisited
Self-attention remains O(N²D) in time and memory, where N is sequence length and D is model dimension. For N=32k, attention dominates compute and memory. Optimizations target both the activation footprint and the arithmetic intensity.
3.1 MHA, MQA, and GQA
Multi-Head Attention (MHA) uses H independent sets of queries, keys, and values. During autoregressive decoding, the KV-cache stores 2NHD_k = 2ND elements per layer, where D_k = D/H. For LLaMA-65B (D=8192, H=64, 80 layers) at N=2048, this is ~5.4 GB in FP16 per sequence, growing linearly with batch size (over 100 GB at serving batches of a few dozen sequences).
Multi-Query Attention (MQA) [4] shares a single K and V across all heads, reducing the KV-cache to 2ND_k. This cuts memory by a factor of H but degrades quality. Grouped-Query Attention (GQA) [5] interpolates by using G key-value heads for H query heads, with G < H. LLaMA 2 70B uses G=8, H=64.
| Method | Q Heads | KV Heads | KV-Cache Size | Quality |
|---|---|---|---|---|
| MHA | H | H | 2NHD_k | Baseline |
| GQA | H | G | 2NGD_k | ≈ MHA |
| MQA | H | 1 | 2ND_k | −0.5 to −1.5% |
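The cache sizes in the table can be checked with a small calculator. A minimal sketch, using a LLaMA-2-70B-like configuration (H=64, D_k=128, 80 layers) as the assumed example:

```python
def kv_cache_bytes(n_tokens: int, n_layers: int, kv_heads: int,
                   head_dim: int, bytes_per_elem: int = 2) -> int:
    """Total KV-cache size: 2 (K and V) * N * kv_heads * head_dim per layer,
    times the number of layers, in FP16 by default."""
    return 2 * n_tokens * n_layers * kv_heads * head_dim * bytes_per_elem

N, L, Dk = 4096, 80, 128
mha = kv_cache_bytes(N, L, kv_heads=64, head_dim=Dk)  # H KV heads ≈ 10.7 GB
gqa = kv_cache_bytes(N, L, kv_heads=8,  head_dim=Dk)  # G=8 KV heads
mqa = kv_cache_bytes(N, L, kv_heads=1,  head_dim=Dk)  # single KV head
# mha / gqa == 8 and mha / mqa == 64: exactly the H/G reduction in the table.
```

The reduction factor is independent of sequence length and precision, which is why GQA's benefit holds at any context size.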
3.2 KV-Cache Arithmetic Intensity
Inference has two phases: prefill processes the prompt, and decode generates tokens one-by-one. Their computational profiles differ drastically. Prefill is compute-bound due to large batch matrix multiplies. Decode is memory-bound because it loads the entire KV-cache to compute a single vector.
Prefill with batch B and sequence length N performs large batched matrix multiplies with high arithmetic intensity (many FLOPs per byte loaded), saturating GPU compute. Decode with B=1 loads the ~2ND-element KV-cache from HBM per layer to perform only ~4ND FLOPs, giving an intensity of about 1 FLOP/byte. Since modern GPUs deliver on the order of 100 FLOPs per byte of HBM bandwidth, decode is memory-bound. GQA/MQA shrinks the cache traffic by H/G, directly increasing decode throughput. FlashAttention [11] attacks the attention kernel itself via tiling and recomputation.
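The decode-side arithmetic can be written out explicitly. A minimal sketch under the counting assumptions stated in the comments (FP16 cache; ~4ND FLOPs for the score and value matmuls):

```python
def decode_attention_intensity(n_ctx: int, d_model: int,
                               bytes_per_elem: int = 2) -> float:
    """FLOPs per byte for single-token decode attention, per layer.
    Assumes the KV-cache (2*N*D elements) is read once from HBM, and
    counts ~4*N*D FLOPs: 2ND for Q·K^T scores, 2ND for the weighted
    sum over V."""
    flops = 4 * n_ctx * d_model
    bytes_moved = 2 * n_ctx * d_model * bytes_per_elem
    return flops / bytes_moved

# The N and D terms cancel: intensity = 4ND / (4ND bytes) = 1 FLOP/byte,
# regardless of context length or model size.
```

Because both numerator and denominator scale as N·D, longer contexts do not improve decode intensity; only shrinking the cache (GQA/MQA) or batching requests does.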
4. Sparsity & Long Context
Full N² attention becomes prohibitive beyond 8k tokens. Sparse patterns reduce complexity while preserving long-range dependencies. The key insight is that most attention weights are near-zero; we only need to compute entries likely to be large.
4.1 Sliding Window + Full Attention
Mistral 7B [12] and Command A [13] use alternating sparsity: most layers use sliding window attention with size w=1024 to 4096, while every k-th layer uses full attention. Window layers cost O(Nw), and their receptive field grows by w per layer.
Consider k=4 and w=1024. Layers 0, 4, 8, … use full O(N²) attention, acting as global mixers; the intermediate layers use sliding-window O(Nw) attention for local context. Stacking k=4 window layers grows the local receptive field to 4w = 4096 tokens, and inserting full attention every k layers guarantees global information flow after at most k layers, rather than the O(N/w) layers that sliding windows alone would require. This pattern enables 128k context at roughly a quarter of the compute of full attention, since the periodic full-attention layers dominate the cost.
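The alternating pattern can be sketched as a per-layer mask builder. This is a pure-Python illustration; the toy sizes and the `full_every` parameter name are assumptions for the example, not from any particular codebase.

```python
def attention_mask(n: int, layer_idx: int, window: int, full_every: int):
    """Causal attention mask for one layer: full attention on every
    `full_every`-th layer, sliding-window attention otherwise.
    mask[q][k] is True where query position q may attend to key position k."""
    is_full = (layer_idx % full_every == 0)
    mask = []
    for q in range(n):
        row = []
        for k in range(n):
            causal = k <= q                 # never attend to the future
            in_window = (q - k) < window    # within the local window
            row.append(causal and (is_full or in_window))
        mask.append(row)
    return mask

# Layer 0 (full): token 7 can see token 0. Layer 1 (window=4): it cannot.
full = attention_mask(8, layer_idx=0, window=4, full_every=4)
sw   = attention_mask(8, layer_idx=1, window=4, full_every=4)
```

Stacking the window layers widens the effective receptive field by `window` per layer, while the periodic full layers short-circuit long-range propagation.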
5. Mixture of Experts
MoE replaces the dense FFN with E expert networks, routing each token to K experts. This increases parameter count without increasing FLOPs: a 1.6T parameter MoE with K=2 active experts costs similar FLOPs to a 100B dense model.
The routing function is typically top_k(G·x) where G is a learned gating matrix. Load balancing loss prevents expert collapse. Recent models like Mixtral 8x7B use E=8, K=2.
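The top-k routing step can be sketched in a few lines. A minimal illustration of top_k gating over pre-computed router logits; the example logits are arbitrary, and the load-balancing loss is omitted:

```python
import math

def top_k_route(logits, k: int = 2):
    """Select the top-k experts by router logit and renormalize their
    gate values with a softmax over just those k entries."""
    idx = sorted(range(len(logits)), key=lambda i: logits[i], reverse=True)[:k]
    m = max(logits[i] for i in idx)                    # for numerical stability
    exps = [math.exp(logits[i] - m) for i in idx]
    z = sum(exps)
    return list(zip(idx, [e / z for e in exps]))       # [(expert_id, weight), ...]

# E=8 experts, K=2: this token is routed to experts 2 and 5,
# whose gate weights sum to 1.
routes = top_k_route([0.1, -1.0, 2.0, 0.0, 0.3, 1.5, -0.5, 0.2], k=2)
```

Only the selected experts' FFNs run for this token, which is how parameter count grows by E while per-token FLOPs grow only by K.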
6. Normalization & Residuals
RMSNorm [6] replaces LayerNorm by removing mean-centering: RMSNorm(x) = x / RMS(x) · γ. It saves ~7% compute and is now standard in LLaMA, PaLM, and others.
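The savings come from dropping the mean subtraction; the formula above is otherwise a direct computation. A minimal pure-Python sketch (the example vector is arbitrary):

```python
import math

def rms_norm(x, gamma, eps: float = 1e-6):
    """RMSNorm: x / RMS(x) * gamma. Unlike LayerNorm, there is no
    mean-centering and no beta shift -- only a scale by the root mean square."""
    rms = math.sqrt(sum(v * v for v in x) / len(x) + eps)
    return [v / rms * g for v, g in zip(x, gamma)]

out = rms_norm([3.0, 4.0], gamma=[1.0, 1.0])
# RMS([3, 4]) = sqrt((9 + 16) / 2) ≈ 3.5355, so out ≈ [0.8485, 1.1314]
```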
SwiGLU replaces the ReLU FFN. Instead of max(0, xW₁)W₂, it computes (Swish(xW₁) ⊙ xW₂)W₃. This increases parameters 1.5x for the same hidden size but improves quality per parameter [7].
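The gated structure is easiest to see in code. A minimal sketch with naive list-based matrix-vector products; the identity weight matrices in the example are illustrative placeholders, not trained values:

```python
import math

def swish(v: float) -> float:
    """Swish / SiLU activation: v * sigmoid(v)."""
    return v / (1.0 + math.exp(-v))

def swiglu_ffn(x, W1, W2, W3):
    """SwiGLU FFN: (Swish(x W1) ⊙ (x W2)) W3.
    Three weight matrices (gate, up, down) versus two for the ReLU FFN,
    hence the 1.5x parameter count at the same hidden size."""
    def matvec(W, v):
        return [sum(wi * vi for wi, vi in zip(row, v)) for row in W]
    gate = [swish(h) for h in matvec(W1, x)]   # gating branch
    up = matvec(W2, x)                         # linear branch
    hidden = [g * u for g, u in zip(gate, up)] # elementwise product
    return matvec(W3, hidden)                  # down-projection

W_id = [[1.0, 0.0], [0.0, 1.0]]
out = swiglu_ffn([1.0, 2.0], W_id, W_id, W_id)
```

To hold total FFN parameters constant against a ReLU FFN, implementations typically shrink the hidden dimension to 2/3 of the ReLU width.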
7. Losses & Stability
Large logits in the output softmax cause numerical overflow and gradient spikes. Two techniques stabilize training: z-loss and query-key normalization.
7.1 Z-Loss Intuition
Z-loss adds a penalty on the log-partition function: L_z = α · (log Z)² where Z = Σ exp(logits_i). This encourages logits to stay small without changing the softmax output. PaLM uses α=10⁻⁴ [2, 15].
Expanded, L_z = α · (log Σ_i exp(z_i))². Without it, logits grow uncontrollably while minimizing cross-entropy, eventually overflowing low-precision exponentials (exp overflows FP16 for arguments above ~11 and FP32 above ~88). Z-loss keeps log Z close to 0, so softmax(z) stays numerically stable. The gradient ∂L_z/∂z_i = 2α · log Z · p_i pulls log Z toward zero while preserving the relative differences between logits. See [15] for a derivation.
QK-Normalization applies LayerNorm to queries and keys before the dot product: Attn = softmax(LN(Q)LN(K)ᵀ/√d). This bounds attention scores and prevents entropy collapse at scale.
8. Positional Encodings
RoPE (Rotary Position Embedding) [8] encodes position by rotating query and key vectors by an angle proportional to their position. For each 2D subspace [x_i, x_{i+1}], position m applies a rotation by mθ_i with θ_i = base^{-2i/d}: [x_i, x_{i+1}] ↦ [x_i cos mθ_i − x_{i+1} sin mθ_i, x_i sin mθ_i + x_{i+1} cos mθ_i]. Because rotations compose, the dot product of a query rotated at position m with a key rotated at position n depends only on the offset m − n.
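The rotation and its relative-position property can be sketched directly from the formula above (pure-Python illustration; real implementations vectorize this across heads and positions):

```python
import math

def rope_rotate(vec, position: int, base: float = 10000.0):
    """Rotate each 2D subspace (x_i, x_{i+1}) of `vec` by angle m * theta,
    with theta = base^(-i/d) for the pair starting at even index i
    (equivalently base^(-2j/d) for pair index j). Applied identically
    to queries and keys."""
    d = len(vec)
    out = []
    for i in range(0, d, 2):
        theta = base ** (-i / d)
        angle = position * theta
        c, s = math.cos(angle), math.sin(angle)
        x, y = vec[i], vec[i + 1]
        out += [x * c - y * s, x * s + y * c]
    return out
```

Since each 2D rotation is orthogonal, rotated vectors keep their norm, and the query-key dot product depends only on the position offset, which is what makes attention scores translation-invariant in position.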
RoPE enables length extrapolation: a model trained on 2k tokens can run on 8k without retraining by scaling base or using NTK-aware interpolation [14]. ALiBi uses additive bias instead [9].
9. Systems View: Training and Inference
Architecture and systems are co-designed. FlashAttention [11] computes attention in SRAM tiles with an online softmax and recomputation, reducing HBM accesses from O(N² + Nd) to O(N²d²/M), where M is the SRAM size. This makes attention compute-bound rather than memory-bound during training and prefill.
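The key trick that makes tiling possible is the online softmax: a running maximum and a rescaled running sum let each tile of scores be visited once. A minimal single-row sketch of just the streaming normalizer (the weighted-value accumulation that FlashAttention fuses in is omitted for clarity; the example scores are arbitrary):

```python
import math

def online_softmax(scores, tile: int = 4):
    """Softmax computed tile-by-tile with a running max `m` and a running
    sum `s` of exp(v - m); whenever the max grows, the old sum is rescaled
    by exp(m_old - m_new) so it stays consistent."""
    m, s = float("-inf"), 0.0
    for start in range(0, len(scores), tile):
        chunk = scores[start:start + tile]
        m_new = max(m, max(chunk))
        s = s * math.exp(m - m_new) + sum(math.exp(v - m_new) for v in chunk)
        m = m_new
    # Final normalization pass (fused into the output accumulation in practice).
    return [math.exp(v - m) / s for v in scores]

probs = online_softmax([0.3, 2.0, -1.0, 0.7, 1.1, -0.2], tile=4)
```

Because `m` and `s` are a constant amount of state per row, the full N×N score matrix never needs to be materialized in HBM.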
Activation checkpointing saves memory by recomputing activations during backward. Combined with FSDP or 3D parallelism, 70B models train on 2k GPUs with high MFU.
For inference, paged attention manages KV-cache in non-contiguous blocks, enabling larger batch sizes and longer contexts without fragmentation. GQA reduces this cache by H/G, as shown in Figure 3.
10. Modern Recipes: Putting It Together
A representative 70B model combines these innovations:
| Component | Choice | Rationale |
|---|---|---|
| Backbone | Decoder-only, Pre-norm | Stability at 80 layers |
| Attention | GQA, H=64, G=8 | 8x KV-cache reduction |
| FFN | SwiGLU, 1.5x expansion | Quality per FLOP |
| Norm | RMSNorm | 7% faster than LayerNorm |
| Position | RoPE, base=10,000 | Length extrapolation |
| Sparsity | Alternating SW=4k + Full | 128k context at 4x cost |
| Stability | Z-loss, QK-norm | Prevent logit explosion |
This configuration trains stably at 8192 token sequences, runs inference at 32k with sliding window, and achieves >50% MFU on H100s with FlashAttention-2 and FSDP.
11. References
1. Vaswani, A., et al. "Attention is all you need." NeurIPS 2017.
2. Chowdhery, A., et al. "PaLM: Scaling language modeling with pathways." arXiv:2204.02311, 2022.
3. Touvron, H., et al. "Llama 2: Open foundation and fine-tuned chat models." arXiv:2307.09288, 2023.
4. Shazeer, N. "Fast transformer decoding: One write-head is all you need." arXiv:1911.02150, 2019.
5. Ainslie, J., et al. "GQA: Training generalized multi-query transformer models from multi-head checkpoints." arXiv:2305.13245, 2023.
6. Zhang, B. & Sennrich, R. "Root mean square layer normalization." NeurIPS 2019.
7. Shazeer, N. "GLU variants improve transformer." arXiv:2002.05202, 2020.
8. Su, J., et al. "RoFormer: Enhanced transformer with rotary position embedding." arXiv:2104.09864, 2021.
9. Press, O., et al. "Train short, test long: Attention with linear biases enables input length extrapolation." ICLR 2022.
10. Kaplan, J., et al. "Scaling laws for neural language models." arXiv:2001.08361, 2020.
11. Dao, T., et al. "FlashAttention: Fast and memory-efficient exact attention with IO-awareness." NeurIPS 2022.
12. Jiang, A. Q., et al. "Mistral 7B." arXiv:2310.06825, 2023.
13. Cohere. "Command A technical report." 2024.
14. Sun, Y., et al. "A length-extrapolatable transformer." ACL 2023.
15. Develin, M. "Softmax temperature and z-loss." 2014. [Technique used in PaLM]