Attention Mechanism
Video (best)
- 3Blue1Brown — “Attention in transformers, visually explained | Chapter 6, Deep Learning”
- Watch: YouTube
- Why: Exceptional visual intuition for how attention scores are computed, how queries/keys/values interact, and why the mechanism works. Grant Sanderson’s geometric framing makes abstract matrix operations concrete. Part of a coherent series so learners have scaffolding.
- Level: beginner/intermediate
Blog / Written explainer (best)
- Jay Alammar — “The Illustrated Transformer”
- Link: https://jalammar.github.io/illustrated-transformer/
- Why: The gold standard written explainer for attention. Step-by-step diagrams show exactly how Q, K, V matrices are formed and combined, multi-head attention is visualized clearly, and the encoder-decoder attention is distinguished from self-attention. Widely cited in courses precisely because it bridges intuition and math without losing either.
- Level: beginner/intermediate
Deep dive
- Lilian Weng — “Attention? Attention!”
- Link: https://lilianweng.github.io/posts/2018-06-24-attention/
- Why: Comprehensive taxonomy of attention variants (soft vs. hard, self-attention, global vs. local, additive vs. dot-product). Covers the historical progression from Bahdanau through Transformer attention with mathematical rigor. Excellent reference when learners need to understand why design choices were made, not just what they are.
- Level: intermediate/advanced
Original paper
- Vaswani et al., 2017 — “Attention Is All You Need”
- Link: https://arxiv.org/abs/1706.03762
- Why: The seminal paper that crystallized scaled dot-product attention and multi-head attention as the dominant paradigm. Unusually readable for a landmark paper — the architecture description is self-contained and the ablations are instructive. The clear notation has become the field’s standard vocabulary.
- Level: intermediate/advanced
Code walkthrough
- Andrej Karpathy — “Let’s build GPT: from scratch, in code, spelled out.”
- Watch: YouTube
- Why: Karpathy builds self-attention from a blank Python file, deriving each line from first principles (starting from the “mathematical trick” of masked self-attention). Learners see exactly how the Q/K/V projections, scaled dot-product, softmax, and multi-head assembly translate to ~50 lines of PyTorch. The incremental build-up makes debugging intuitions explicit.
- Level: intermediate/advanced
Coverage notes
- Strong: Visual/conceptual explanation (3B1B video + Jay Alammar blog form a near-perfect beginner ramp); mathematical formalism (Lilian Weng); hands-on implementation (Karpathy); seminal theory (Vaswani et al.)
- Weak: Attention variants beyond the Transformer (e.g., linear attention, sparse attention, cross-attention in diffusion models) are not well covered by any single beginner-friendly resource
- Gap: No excellent standalone resource specifically covers cross-attention (encoder-decoder attention) in isolation with worked code examples; most resources treat it as a footnote to self-attention. Learners building seq2seq systems may need to supplement with the original Bahdanau paper (arxiv.org/abs/1409.0473) and the older Jay Alammar post “Visualizing A Neural Machine Translation Model.”
Additional Resources for Tutor Depth
8 sources — papers, official docs, working code, benchmarks, and deep explainers that give the AI tutor precision on this topic.
📄 FlashAttention-2 (exact IO-aware attention via tiling + online softmax)
Paper · source
Block/tiling scheme for exact attention; online softmax with running ((m,\ell)) / logsumexp (L); recomputation to avoid materializing (N\times N) attention.
Key content
- Standard attention (Sec. 2.2): Given (Q,K,V\in\mathbb{R}^{N\times d}):
(S=QK^\top\in\mathbb{R}^{N\times N}), (P=\mathrm{softmax}(S)) (row-wise), (O=PV\in\mathbb{R}^{N\times d}).
Backward: (dV=P^\top dO); (dP=dO\,V^\top); (dS=\mathrm{dsoftmax}(dP)); (dQ=dS\,K); (dK=dS^\top Q).
- FlashAttention-2 forward algorithm (Alg. 1, Sec. 3.1.1): Tile rows/cols: (T_r=\lceil N/B_r\rceil), (T_c=\lceil N/B_c\rceil). For each row block (Q_i): init on-chip (O_i^{(0)}=0), (\ell_i^{(0)}=0), (m_i^{(0)}=-\infty). For each col block (K_j,V_j):
(S_{ij}=Q_iK_j^\top). Update running row-wise max and exp-sum:
(m_i^{(j)}=\max(m_i^{(j-1)}, \mathrm{rowmax}(S_{ij}))); (\tilde P_{ij}=\exp(S_{ij}-m_i^{(j)}));
(\ell_i^{(j)}=\exp(m_i^{(j-1)}-m_i^{(j)})\ell_i^{(j-1)}+\mathrm{rowsum}(\tilde P_{ij})).
Update unnormalized output: (O_i^{(j)}=\mathrm{diag}(\exp(m_i^{(j-1)}-m_i^{(j)}))O_i^{(j-1)}+\tilde P_{ij}V_j).
Final: (O_i=\mathrm{diag}((\ell_i^{(T_c)})^{-1})O_i^{(T_c)}); store (L_i=m_i^{(T_c)}+\log(\ell_i^{(T_c)})) (logsumexp).
- Rationale (Sec. 3.1): Reduce expensive non-matmul FLOPs (A100: 312 TFLOPs/s FP16/BF16 matmul vs 19.5 TFLOPs/s FP32 non-matmul; ~16× gap). Keep (O) unscaled until end; store only (L) (not both (m,\ell)).
- Memory/compute: Exact output (no approximation); avoids materializing (S,P) in HBM; extra memory (O(N)) (store (L)); FLOPs (O(N^2 d)) (Sec. 3.1.1).
- Causal mask (Sec. 3.1.1): Skip blocks entirely above diagonal (~half blocks) → ~1.7–1.8× speedup vs non-causal; per row apply mask to only ~1 block (square blocks).
- Parallelism/work partitioning (Secs. 3.2–3.3): Forward parallelize over row blocks (sequence length) + batch + heads; backward parallelize over column blocks; atomic adds for (dQ). Avoid “split-K”: FlashAttn-2 splits Q across warps (K,V shared) to reduce shared-memory traffic.
- Empirical (A100 80GB, Sec. 4.1): FlashAttention-2 1.7–3.0× faster than FlashAttention; 3–10× faster than PyTorch attention; forward reaches up to 73% of theoretical peak; end-to-end training up to 225 TFLOPs/s per A100 (72% MFU) (Table 1).
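The running-max and exp-sum updates in the forward algorithm can be checked numerically. The following is a minimal pure-Python sketch, not the paper's kernel: it uses row blocks of size 1, nested lists instead of GPU tiles, and the hypothetical helper names `standard_attention` and `tiled_attention`. It is only meant to show that the tiled update reproduces standard attention exactly and that `m + log(l)` is the row-wise logsumexp.

```python
import math

def standard_attention(q, k, v):
    """Reference: materialize each full row of S = Q K^T, softmax it, multiply by V."""
    out = []
    for q_row in q:
        scores = [sum(a * b for a, b in zip(q_row, k_row)) for k_row in k]
        row_max = max(scores)
        exps = [math.exp(s - row_max) for s in scores]
        denom = sum(exps)
        out.append([sum(e * v_row[c] for e, v_row in zip(exps, v)) / denom
                    for c in range(len(v[0]))])
    return out

def tiled_attention(q, k, v, block_cols=2):
    """Visit K/V in column blocks, keeping a running max m and exp-sum l per row."""
    d_v = len(v[0])
    out, logsumexp = [], []
    for q_row in q:                          # row blocks of size 1 for simplicity
        m, l = -math.inf, 0.0
        acc = [0.0] * d_v                    # unnormalized output O_i
        for start in range(0, len(k), block_cols):
            k_blk = k[start:start + block_cols]
            v_blk = v[start:start + block_cols]
            s = [sum(a * b for a, b in zip(q_row, k_row)) for k_row in k_blk]
            m_new = max(m, max(s))
            p = [math.exp(x - m_new) for x in s]
            rescale = math.exp(m - m_new)    # exp(-inf) is 0.0 on the first block
            l = rescale * l + sum(p)
            acc = [rescale * a + sum(pj * v_row[c] for pj, v_row in zip(p, v_blk))
                   for c, a in enumerate(acc)]
            m = m_new
        out.append([a / l for a in acc])     # normalize once, at the end
        logsumexp.append(m + math.log(l))    # the only statistic that is stored
    return out, logsumexp

q = [[1.0, 0.0], [0.5, -1.0], [2.0, 1.0]]
k = [[1.0, 2.0], [0.0, 1.0], [-1.0, 0.5], [3.0, -2.0], [0.2, 0.2]]
v = [[1.0, 0.0, 2.0], [0.0, 1.0, 0.5], [2.0, 2.0, 2.0], [-1.0, 0.5, 0.0], [0.3, 0.3, 0.3]]

reference = standard_attention(q, k, v)
tiled, lse = tiled_attention(q, k, v, block_cols=2)
max_diff = max(abs(a - b) for r1, r2 in zip(reference, tiled) for a, b in zip(r1, r2))
```

Here `max_diff` should be at floating-point noise level, since the tiling changes the order of operations but not the mathematical result.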
📄 FlashAttention-3 (Hopper): asynchrony + FP8 attention
Paper · source
Consolidated H100 benchmarks + kernel design changes beyond FlashAttention-2 (asynchrony/pipelining, FP8)
Key content
- Attention formula (Sec. 2.1): For one head with sequence length (n), head dim (d):
[ O=\mathrm{softmax}(S)V,\quad S=\alpha QK^\top,\ \alpha=\frac{1}{\sqrt d} ]
Softmax applied row-wise; subtract row max from (S) for numerical stability.
- Why FA-3 (Intro): FlashAttention-2 achieves ~35% utilization on H100 vs 80–90% for optimized GEMM; FA-3 redesigns for Hopper asynchrony (Tensor Cores + TMA) and low precision (FP8).
- 3 main techniques (Intro, Sec. 3):
- Producer–consumer warp specialization: separate warps issue TMA loads vs WGMMA compute; use setmaxnreg to reallocate registers to compute warps.
- Overlap softmax under async GEMMs: 2-stage pipeline across iterations; “pingpong scheduling” uses barriers so one warpgroup does softmax while another runs GEMMs. Example gain: 570 → 620–640 TFLOPs/s (FP16 fwd, headdim 128, seqlen 8192).
- FP8 support: FP8 WGMMA requires k-major operands; FA-3 uses in-kernel transpose (LDSM/STSM) + register byte-permute to satisfy layout for back-to-back GEMMs.
- Empirical performance (Abstract/Sec. 4):
- FP16: 1.5–2.0× faster than FA-2 forward; up to 740 TFLOPs/s (~75% H100 peak). Backward: 1.5–1.75× faster.
- FP8: near 1.2 PFLOPs/s.
- Ablation (Table 2, non-causal FP16): FA-3 3.538 ms, 661 TFLOPs/s; removing pipelining 4.021 ms, 582; removing warp-specialization 4.105 ms, 570.
- Numerical error (Table 3): RMSE vs FP64 reference with outlier (Q) (0.1% entries add (\mathcal N(0,10)) to (\mathcal N(0,1))):
- Baseline FP16 3.2e-4; FA-2 FP16 1.9e-4; FA-3 FP16 1.9e-4 (softmax intermediates kept FP32).
- Baseline FP8 (per-tensor scaling) 2.4e-2; FA-3 FP8 9.1e-3 (~2.6× lower error). “No incoherent processing” returns to 2.4e-2.
- Benchmark defaults (Sec. 4.1/C.1): H100 80GB SXM5; CUDA 12.3, cuDNN 9.1.1.17, CUTLASS 3.5, FA2 2.5.8, Triton nightly 3.0, PyTorch 2.3; GPU clock fixed 1830 MHz; 100 repeats avg. Seqlen 512–16k; total tokens 16k; hidden dim 2048; headdim 64/128/256 (32/16/8 heads). Forward FLOPs: (4n^2d+2n^2); causal halves FLOPs; backward ≈ 2.5× forward.
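The attention formula above notes that the row max is subtracted from the scores for numerical stability. The sketch below illustrates why, in plain Python with hypothetical helper names (`scaled_scores`, `softmax_naive`, `softmax_stable`) and deliberately large inputs chosen for illustration. It is a toy demonstration of the scaling factor and the max-subtraction trick, not FA-3 code.

```python
import math

def scaled_scores(q_row, keys):
    """S = alpha * q K^T for one query row, with alpha = 1/sqrt(d)."""
    alpha = 1.0 / math.sqrt(len(q_row))
    return [alpha * sum(a * b for a, b in zip(q_row, k_row)) for k_row in keys]

def softmax_naive(scores):
    exps = [math.exp(s) for s in scores]        # raises OverflowError for large scores
    total = sum(exps)
    return [e / total for e in exps]

def softmax_stable(scores):
    row_max = max(scores)
    exps = [math.exp(s - row_max) for s in scores]  # largest term is exp(0) = 1
    total = sum(exps)
    return [e / total for e in exps]

# Illustrative inputs: head dim 64 with large activations.
q_row = [40.0] * 64
keys = [[40.0] * 64, [39.0] * 64, [0.0] * 64]
scores = scaled_scores(q_row, keys)

try:
    softmax_naive(scores)
    naive_overflowed = False
except OverflowError:
    naive_overflowed = True

probs = softmax_stable(scores)
```

Both functions compute the same quantity mathematically; only the stable variant keeps every exponent at or below zero.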
📄 Multi-Query Attention (MQA) for Faster Autoregressive Decoding
Paper · source
MQA rationale + incremental decoding bottleneck: KV-cache memory-bandwidth; sharing K/V across heads shrinks cache/loads.
Key content
- Dot-Product Attention (Sec. 2.1):
Given query (q\in\mathbb{R}^{k}), keys (K\in\mathbb{R}^{m\times k}), values (V\in\mathbb{R}^{m\times v}):
(\text{logits}= qK^\top\in\mathbb{R}^{m}); (\alpha=\text{softmax}(\text{logits})); output (y=\alpha^\top V\in\mathbb{R}^{v}).
- Multi-Head Attention (Sec. 2.2–2.3):
For (h) heads, projections (P_q\in\mathbb{R}^{h\times d\times k}), (P_k\in\mathbb{R}^{h\times d\times k}), (P_v\in\mathbb{R}^{h\times d\times v}), (P_o\in\mathbb{R}^{h\times d\times v}).
(Q=\text{einsum}(X,P_q)\in\mathbb{R}^{b\times h\times n\times k}); (K=\text{einsum}(M,P_k)\in\mathbb{R}^{b\times h\times m\times k}); (V=\text{einsum}(M,P_v)\in\mathbb{R}^{b\times h\times m\times v}).
logits (\in\mathbb{R}^{b\times h\times n\times m}); (O=\text{softmax}(\text{logits})\,V); (Y=\text{einsum}(O,P_o)\in\mathbb{R}^{b\times n\times d}).
- Incremental decoding bottleneck (Sec. 2.4.1): autoregressive inference can’t parallelize over time; repeatedly loading cached (K,V) dominates due to memory bandwidth. Under the simplifying assumptions (m=n, k=v=d/h), the cached (K) and (V) each hold (bhnk=bnd) elements, so reloading them at every one of the (n) decoding steps costs (\Theta(bn^2d)) memory accesses.
- MQA definition (Sec. 3): keep multi-head queries but share keys/values across heads (remove head dim from (K,V,P_k,P_v)):
(K\in\mathbb{R}^{b\times m\times k}), (V\in\mathbb{R}^{b\times m\times v}); logits (=\text{einsum}(Q,K)\in\mathbb{R}^{b\times h\times n\times m}).
KV-cache size/load reduced by factor (\approx h) vs multi-head.
- Empirical speed (Sec. 4.3, TPUv2): incremental greedy inference, batch 1024, src=128, tgt=128.
Baseline decoder: 47 ms/step ⇒ 46 µs/token; MQA decoder: 3.9 ms/step ⇒ 3.8 µs/token (~12× faster). Encoder: 222 ms (1.7 µs/token) baseline vs 195 ms (1.5 µs/token) MQA.
- Training setup (Sec. 4.1): WMT14 En-De, 6-layer encoder-decoder, (d_{model}=1024), (d_{ff}=4096) baseline, (h=8), (d_k=d_v=128), learned positional embeddings, embed/output weight sharing; 100k steps, batch 128 examples, each 256 src + 256 tgt tokens; TPUv3 32-core. MQA widens FFN to 5440 to match params (211M).
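The factor-of-(h) cache reduction follows directly from dropping the head dimension of (K) and (V). The sketch below is an illustration under stated assumptions, not the paper's code: `kv_cache_elements`, `attend`, and `multi_query_attention` are hypothetical helpers, tensors are nested lists, and there is no batching or projection step. The sizes reuse (h=8), (d_k=d_v=128), and a 128-token memory from the setup above.

```python
import math

def kv_cache_elements(batch, heads, mem_len, k_dim, v_dim, shared_kv):
    """Count cached K and V elements; MQA (shared_kv=True) has no head dimension."""
    kv_heads = 1 if shared_kv else heads
    return batch * kv_heads * mem_len * (k_dim + v_dim)

def attend(query, keys, values):
    """Dot-product attention for a single query vector."""
    logits = [sum(a * b for a, b in zip(query, key)) for key in keys]
    top = max(logits)
    weights = [math.exp(x - top) for x in logits]
    total = sum(weights)
    return [sum(w * val[c] for w, val in zip(weights, values)) / total
            for c in range(len(values[0]))]

def multi_query_attention(queries_per_head, keys, values):
    """Each head has its own query, but all heads read the same K and V."""
    return [attend(query, keys, values) for query in queries_per_head]

mha_cache = kv_cache_elements(batch=1, heads=8, mem_len=128, k_dim=128, v_dim=128,
                              shared_kv=False)
mqa_cache = kv_cache_elements(batch=1, heads=8, mem_len=128, k_dim=128, v_dim=128,
                              shared_kv=True)
reduction = mha_cache // mqa_cache   # equals the number of heads

# Two heads sharing one small K/V memory.
shared_keys = [[1.0, 0.0], [0.0, 1.0]]
shared_values = [[1.0, 0.0], [0.0, 1.0]]
head_outputs = multi_query_attention([[0.0, 0.0], [10.0, 0.0]], shared_keys, shared_values)
```

The per-head outputs still differ because the queries differ; only the memory that must be cached and reloaded at each decoding step is shared.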
📄 PagedAttention & vLLM KV-cache paging for LLM serving
Paper · source
PagedAttention + vLLM system design: paged KV-cache layout, block allocation/eviction to avoid fragmentation, throughput/latency under dynamic batching
Key content
- Problem (Sec. 1, 3): KV cache dominates serving memory + is dynamic. Example memory split on A100-40GB with 13B model: ~65% weights, ~30% KV cache, small activations. KV cache grows/shrinks per request; inefficient management limits batch size → throughput.
- KV cache size formula (Sec. 3): For OPT-13B, per-token KV cache ≈ 800 KB, computed as
2 (K,V) × 5120 (hidden size) × 40 (layers) × 2 bytes (FP16).
- Why existing systems waste memory (Sec. 3): contiguous pre-allocation to max length causes:
- Reserved slots for future tokens (held for entire request lifetime)
- Internal fragmentation (unknown output length)
- External fragmentation (allocator/buddy system).
- Profiling: only 20–40% of KV cache space utilized; “effective memory … as low as 20%”.
- PagedAttention algorithm (Sec. 4.1): partition each sequence’s KV cache into fixed-size KV blocks (pages). Logical blocks are contiguous; physical blocks can be non-contiguous. Attention kernel uses a block table (logical→physical) to fetch blocks during attention.
- Block-size tradeoff (Sec. 4.3, 7.2 mention): larger blocks improve parallelism/utilization (lower latency) but increase fragmentation; block size 16 “generally works well” (slides/transcript).
- Write path optimization (Sec. 4.3): fused reshape + block write per layer: split new KV into blocks, reshape to block-read-friendly layout, write to physical blocks via block table.
- Sharing (Sec. 4.4): block-level copy-on-write enables sharing prompt KV across parallel sampling/beam search; reported savings: ~30% (parallel sampling) and >60% (beam search); prompt itself noted as ~12% of total KV in one experiment.
- Preemption/recovery (talk): when out of blocks, either swap to CPU or recompute KV; recomputation can be faster for small blocks; vLLM uses recomputation “whenever possible”.
- Empirical result (Abstract): vLLM improves throughput 2–4× vs FasterTransformer/Orca at similar latency; memory utilization >96% average (talk) and 2.5–5× KV efficiency improvement.
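The per-token size formula and the logical-to-physical block table can be sketched in a few lines. This is a toy model under stated assumptions, not vLLM's implementation: `kv_bytes_per_token` and the `PagedKVCache` class are hypothetical names, the cache tracks only slot bookkeeping (no tensors), and running out of blocks simply raises instead of swapping or recomputing.

```python
def kv_bytes_per_token(hidden_size, num_layers, bytes_per_value):
    """2 (K and V) x hidden size x layers x bytes per value."""
    return 2 * hidden_size * num_layers * bytes_per_value

class PagedKVCache:
    """Toy block allocator: logical blocks are contiguous, physical blocks need not be."""

    def __init__(self, num_physical_blocks, block_size=16):
        self.block_size = block_size
        self.free_blocks = list(range(num_physical_blocks))
        self.block_tables = {}   # seq_id -> physical block ids, indexed by logical block
        self.seq_lens = {}       # seq_id -> number of tokens stored

    def append_token(self, seq_id):
        """Reserve one slot; a new block is allocated only when the last one is full."""
        table = self.block_tables.setdefault(seq_id, [])
        length = self.seq_lens.get(seq_id, 0)
        if length == len(table) * self.block_size:
            if not self.free_blocks:
                raise MemoryError("out of KV blocks; a real system would preempt here")
            table.append(self.free_blocks.pop())
        self.seq_lens[seq_id] = length + 1

    def locate(self, seq_id, token_index):
        """Translate a logical token position into (physical block, offset)."""
        block = self.block_tables[seq_id][token_index // self.block_size]
        return block, token_index % self.block_size

    def free(self, seq_id):
        """Return all blocks of a finished sequence to the free list."""
        self.free_blocks.extend(self.block_tables.pop(seq_id))
        del self.seq_lens[seq_id]

opt13b_bytes = kv_bytes_per_token(hidden_size=5120, num_layers=40, bytes_per_value=2)

cache = PagedKVCache(num_physical_blocks=4, block_size=16)
for step in range(17):           # sequence "a" grows to 17 tokens
    cache.append_token("a")
    if step == 0:
        cache.append_token("b")  # sequence "b" grabs a block in between
table_a = list(cache.block_tables["a"])
```

Because blocks are handed out on demand, the only unused space per sequence is the tail of its last block, which is at most `block_size - 1` slots.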
📊 FlashAttention-2 benchmarks & efficiency claims
Benchmark · source
Benchmark comparisons (throughput/latency, causal vs non-causal, training-relevant fwd+bwd) and end-to-end GPT training TFLOPs/s.
Key content
- Attention equations (Section 2):
- Scores: (S = QK^\top \in \mathbb{R}^{N\times N})
- Probabilities: (P=\mathrm{softmax}(S)) (row-wise)
- Output: (O = PV \in \mathbb{R}^{N\times d})
- Variables: (N)=sequence length, (d)=head dimension; computed per head and batch in MHA.
- Core FlashAttention idea (Section 2.3): tile blocks of (Q,K,V) from HBM→SRAM, compute block attention with online softmax rescaling, and avoid materializing (S) and (P) in HBM → memory drops from (O(N^2)) to (O(N)) (stores row-wise logsumexp (L)).
- FlashAttention-2 design rationale (Sections 3.1–3.3):
- Reduce non-matmul FLOPs because A100 peak FP16/BF16 matmul is 312 TFLOPs/s vs 19.5 TFLOPs/s non-matmul FP32 (~16× gap).
- Increase occupancy by parallelizing over sequence length (not just batch×heads).
- Reduce shared-memory traffic by switching warp partitioning from split-K (FlashAttention) to split-Q (FlashAttention-2), avoiding inter-warp reductions in forward pass.
- Empirical speed/efficiency (benchmarks, Section 4 + abstract):
- FlashAttention-2: ~2× faster than FlashAttention; reaches 50–73% of theoretical max FLOPs/s on A100; up to 230 TFLOPs/s (A100, FP16/BF16).
- Compared to standard PyTorch attention: up to 9× faster (forward+backward benchmarks mentioned).
- End-to-end GPT training throughput (table in text):
- GPT3-1.3B, 2k: Baseline 142, FA 189, FA-2 196 TFLOPs/s
- GPT3-1.3B, 8k: Baseline 72, FA 170, FA-2 220 TFLOPs/s
- GPT3-2.7B, 2k: Baseline 149, FA 189, FA-2 205 TFLOPs/s
- GPT3-2.7B, 8k: Baseline 80, FA 175, FA-2 225 TFLOPs/s (72% model FLOPs utilization); ~1.3× end-to-end speedup over FlashAttention.
- Causal masking optimization (Section 3.1.1): skip blocks where column index > row index → ~1.7–1.8× speedup vs non-causal (since ~half entries computed).
- Supported configs/features: head dim up to 256; supports MQA/GQA (reduces KV-cache size, improves inference throughput).
📖 PyTorch scaled_dot_product_attention (SDPA) semantics
Reference Doc · source
Authoritative parameter semantics/defaults, tensor shapes, mask behavior, kernel selection (Flash/MemEff/Math), GQA constraints.
Key content
- API + defaults:
`scaled_dot_product_attention(query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False, scale=None, enable_gqa=False) -> Tensor`. `scale` is keyword-only.
- Eq. 1 (scores + scaling):
Let `L = query.size(-2)`, `S = key.size(-2)`, `d = query.size(-1)`.
`scale_factor = 1/sqrt(d)` if `scale is None` else `scale`.
`attn_weight = (query @ key.transpose(-2, -1)) * scale_factor` with shape `(..., L, S)`.
- Eq. 2 (mask/bias + softmax + dropout + output):
`attn_bias` is initialized to zeros of shape `(L, S)`.
If `is_causal=True`: requires `attn_mask is None`; apply lower-triangular allow-mask, fill others with `-inf`.
If `attn_mask` is provided (broadcastable to `(..., L, S)`):
- bool mask: `True` = participates/allowed; `False` filled with `-inf`.
- float mask (same dtype as q/k/v): added to scores (`attn_bias = attn_mask + attn_bias`).
- After masking: `softmax(dim=-1)`, `dropout(attn_weight, dropout_p, train=True)`, output `attn_weight @ value` (shape like `query`).
- Dropout behavior: always applied per `dropout_p`; to disable in eval, pass `0.0` when `not self.training`.
- Mask semantics note: SDPA bool mask is inverse of `MultiheadAttention.key_padding_mask` (MHA: `True` = masked out). Invert when migrating (`~mask` / `logical_not()`).
- Backends: auto-select among FlashAttention-2, Memory-Efficient, and PyTorch C++ “math”; control via `torch.nn.attention.sdpa_kernel()` (preferred) or global CUDA toggles.
- GQA (`enable_gqa=True`) constraints: works only for Flash + math on CUDA; no NestedTensor. Requires `num_heads_q % num_heads_kv == 0` and `heads_key == heads_value`; implemented via repeating K/V along head dim.
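The mask and scaling semantics above can be mirrored in a dependency-free reference. The sketch below is not PyTorch code: `sdpa_reference` is a hypothetical single-head, unbatched function on nested lists, dropout is omitted (equivalent to `dropout_p=0.0`), and rows where every position is masked are not handled. It is intended only to make the bool-versus-float mask rules and the `key_padding_mask` inversion concrete.

```python
import math

NEG_INF = float("-inf")

def sdpa_reference(query, key, value, attn_mask=None, is_causal=False, scale=None):
    """Nested-list sketch of the SDPA equations for one head, without dropout."""
    L, S, d = len(query), len(key), len(query[0])
    scale_factor = 1.0 / math.sqrt(d) if scale is None else scale
    attn_bias = [[0.0] * S for _ in range(L)]
    if is_causal:
        assert attn_mask is None, "is_causal=True requires attn_mask to be None"
        for i in range(L):
            for j in range(i + 1, S):
                attn_bias[i][j] = NEG_INF        # lower-triangular allow-mask
    if attn_mask is not None:
        for i in range(L):
            for j in range(S):
                entry = attn_mask[i][j]
                if isinstance(entry, bool):
                    if not entry:                # bool mask: False is filled with -inf
                        attn_bias[i][j] = NEG_INF
                else:
                    attn_bias[i][j] += entry     # float mask: added to the scores
    output = []
    for i in range(L):
        scores = [scale_factor * sum(a * b for a, b in zip(query[i], key[j])) + attn_bias[i][j]
                  for j in range(S)]
        row_max = max(scores)
        weights = [math.exp(s - row_max) for s in scores]
        total = sum(weights)
        output.append([sum(w * value[j][c] for j, w in enumerate(weights)) / total
                       for c in range(len(value[0]))])
    return output

q = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
k = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
v = [[1.0, 2.0], [3.0, 4.0], [1000.0, 1000.0]]

causal_out = sdpa_reference(q, k, v, is_causal=True)
tril_mask = [[j <= i for j in range(3)] for i in range(3)]
bool_out = sdpa_reference(q, k, v, attn_mask=tril_mask)

# MHA-style padding mask (True = masked out) must be inverted for SDPA (True = allowed).
mha_key_padding_mask = [False, False, True]
sdpa_mask = [[not masked for masked in mha_key_padding_mask] for _ in range(3)]
padded_out = sdpa_reference(q, k, v, attn_mask=sdpa_mask)
```

In `padded_out` the third value row (the large one) receives zero weight, which is a quick way to confirm the mask polarity after migrating from `key_padding_mask`.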
📋 Source: https://aclanthology.org/2023.emnlp-main.298/
Source ·
🔍 PyTorch Scaled Dot Product Attention (SDPA) usage + masking
Explainer · source
Concrete invocation patterns for F.scaled_dot_product_attention, backend dispatch control, causal masking (is_causal vs bias tensors), and training vs inference knobs (dropout_p).
Key content
- SDPA call signature (usage pattern):
`torch.nn.functional.scaled_dot_product_attention(query, key, value, attn_mask=None, dropout_p=..., is_causal=...)`
Example (dense): `F.scaled_dot_product_attention(q, k, v)` where `q`, `k`, `v` are shaped like `(B, ..., L, D)`; tutorial uses `(2, 3, 8)` and multihead `(B, H, L, D)`.
- Backend dispatch (CUDA): chooses among FlashAttention, Memory-Efficient Attention, or C++ math implementation.
Explicit control via context manager:
`from torch.nn.attention import SDPBackend, sdpa_kernel` then `with sdpa_kernel(SDPBackend.FLASH_ATTENTION): ...` (also `MATH`, `EFFICIENT_ATTENTION`).
- Benchmark setup + empirical timings (example):
`B=32, L=1024, H=32, D=32, dtype=float16`.
Reported: default 2333.687 µs, math 87407.322 µs, flash 2316.913 µs, efficient 4577.936 µs.
- Causal self-attention module procedure: project `x` with `Linear(embed_dim, 3*embed_dim)`, `chunk(3)` into `q`, `k`, `v`, reshape to `(B, H, L, head_dim)` via `.view(...).transpose(1, 2)`, then SDPA.
Training vs eval knobs: if `self.training`: `dropout_p=self.dropout`, `is_causal=self.is_causal`; else `dropout_p=0.0`, `is_causal=False`.
- Causal bias tensors (PyTorch ≥2.3):
`from torch.nn.attention.bias import causal_upper_left, causal_lower_right`
`is_causal=True` is equivalent to `causal_upper_left(Lq, Lkv)`; differs from `causal_lower_right` when attention score matrix is non-square (common in decoding).
- NestedTensor note: SDPA supports NestedTensor + dense; fused implementations currently don’t support NestedTensor for training; example eval benchmark: Random NT 599.388 µs vs Random Dense 964.192 µs (flash backend).
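The difference between the two causal alignments only shows up when the score matrix is non-square. The sketch below builds both masks as plain boolean lists, with `True` meaning "may attend". The helper names `upper_left_mask` and `lower_right_mask` are hypothetical stand-ins chosen to avoid implying they are the PyTorch functions; they only illustrate the alignment rule.

```python
def upper_left_mask(l_q, l_kv):
    """Diagonal anchored at the top-left corner: query i sees keys 0..i."""
    return [[j <= i for j in range(l_kv)] for i in range(l_q)]

def lower_right_mask(l_q, l_kv):
    """Diagonal anchored at the bottom-right corner: the last query sees every key."""
    offset = l_kv - l_q
    return [[j <= i + offset for j in range(l_kv)] for i in range(l_q)]

def show(mask):
    """Render a mask as rows of 1/0 for quick inspection."""
    return ["".join("1" if allowed else "0" for allowed in row) for row in mask]

# Decoding-style shape: 2 new queries attending over 4 cached keys.
ul = show(upper_left_mask(2, 4))
lr = show(lower_right_mask(2, 4))

# Square case: both alignments coincide.
square_equal = upper_left_mask(3, 3) == lower_right_mask(3, 3)
```

For incremental decoding with a KV cache, the lower-right alignment is usually the intended one, because the newest query should see all previously cached keys.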