1. Lecture 3
CS336
Tatsu H
EVERYTHING YOU DIDN’T WANT TO KNOW ABOUT
LM ARCHITECTURE AND TRAINING
2. Logistics
❖ Join the slack!
❖ Check to make sure you have the latest version of the assignment!
3. Outline and goals
❖ Quick recap of the ‘standard’ transformer (what you implement)
❖ What do most of the large LMs have in common?
❖ What are common variations to the architecture / training process?
Today’s theme: the best way to learn is hands-on experience
the second best way is to try to learn from others’ experience
4. Starting point: the ‘original’ transformer
Review: choices in the standard transformer
Position embedding: sines and cosines
FFN: ReLU
Norm type: post-norm, LayerNorm
5. What you implemented – simple, modern variant
Differences:
• LayerNorm is in front of the block
• Rotary position embeddings (RoPE)
• FF layers use SwiGLU, not ReLU
• Linear layers (and LayerNorm) have no bias (constant) terms
Why did we pick these?
What should you pick?
6. How should we think about architectures?
Lots of architectures, just in the last year since the last offering of 336.
Over 19 new dense model releases, many of them with minor architecture tweaks.
7. Let’s look at the data (on dense architectures)
Learn from the many other models (and papers) out there
We will talk through many major architecture and hyperparameter variants.
• What do all these models have in common?
• What parts vary?
• What can we learn from this?
8. What are we going to cover?
Common architecture variations
• Activations, FFN
• Attention variants
• Position embeddings
Hyperparameters that (do or don’t) matter
• What is ff_dim? Do multi_head dims always sum to model_dim?
• How many vocab elements?
Stability tricks
9. Architecture variations..
Let’s think about the core architecture piece
High level view:
• Low consensus (except pre-norm)
• Trends toward ‘LLaMA-like’ architectures
10. Pre-vs-post norm
The one thing everyone agrees on (in 2024)
Figure from Xiong 2020
Set up LayerNorm so that it doesn’t affect the main residual signal path (on the left)
Almost all modern LMs use pre-norm (but BERT was post-norm)
(One somewhat funny exception – OPT350M. I don’t know why this is post-norm)
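To make the difference concrete, here is a minimal plain-Python sketch on a single vector. The helper names are hypothetical, the LayerNorm omits its learned gain and bias, and `sublayer` stands in for attention or the MLP. With a sublayer that outputs zeros, the pre-norm block passes the residual through unchanged, while the post-norm block still rescales it.

```python
import math

def layer_norm(x, eps=1e-5):
    """LayerNorm over one vector (learned gain/bias omitted for illustration)."""
    mean = sum(x) / len(x)
    var = sum((v - mean) ** 2 for v in x) / len(x)
    return [(v - mean) / math.sqrt(var + eps) for v in x]

def add(a, b):
    return [u + v for u, v in zip(a, b)]

def post_norm_block(x, sublayer):
    # Original transformer: the norm sits on the residual path itself.
    return layer_norm(add(x, sublayer(x)))

def pre_norm_block(x, sublayer):
    # Pre-norm: only the sublayer input is normalized; the residual
    # path x -> x + ... is left untouched.
    return add(x, sublayer(layer_norm(x)))
```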
12. Pre-vs-post norm, explanations?
Gradient attenuation [Xiong 2020]
Gradient spikes [Salazar and Nguyen]
Original stated advantage: removing warmup.
Today – stability and larger LRs for large networks
13. New things – ‘double’ norm.
If putting LayerNorms in the residual stream is bad, why not post-norm outside the stream?
Recent models: Grok, Gemma 2. OLMo 2 only does the non-residual post-norm.
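A minimal sketch of the three placements, assuming RMSNorm without learned gains and a stand-in sublayer `f`; the function names are hypothetical. `double_norm` follows the Grok/Gemma 2 description above (norm on both the sublayer input and output), and `output_norm_only` follows the OLMo 2 description (non-residual post-norm only). In both, the residual stream `x` itself is never normalized, so even a sublayer with a huge output adds a unit-RMS update.

```python
import math

def rms_norm(x, eps=1e-6):
    rms = math.sqrt(sum(v * v for v in x) / len(x) + eps)
    return [v / rms for v in x]

def pre_norm(x, f):
    # Plain pre-norm: normalize the sublayer input only.
    return [a + b for a, b in zip(x, f(rms_norm(x)))]

def double_norm(x, f):
    # Pre-norm plus a second norm on the sublayer output, applied before
    # adding back into the residual stream (the stream is never normalized).
    return [a + b for a, b in zip(x, rms_norm(f(rms_norm(x))))]

def output_norm_only(x, f):
    # Only the non-residual post-norm: normalize the sublayer output.
    return [a + b for a, b in zip(x, rms_norm(f(x)))]
```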
14. LayerNorm vs RMSNorm
Original transformer: LayerNorm – normalizes the mean and variance across $d_{model}$
Many modern LMs: RMSNorm – does not subtract the mean or add a bias term
$y = \frac{x}{\sqrt{\frac{1}{d_{model}}\|x\|_2^2 + \varepsilon}} * \gamma$
Notable models (LayerNorm): GPT3/2/1, OPT, GPT-J, BLOOM
Notable models (RMSNorm): LLaMA-family, PaLM, Chinchilla, T5
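A minimal sketch of both norms on one vector, in plain Python (function names are illustrative, not from any library). RMSNorm drops the mean subtraction and the bias vector, so it has $d_{model}$ learned parameters instead of $2 d_{model}$.

```python
import math

def layer_norm(x, gamma, beta, eps=1e-5):
    # Subtract the mean, divide by the std, then apply gain and bias.
    mean = sum(x) / len(x)
    var = sum((v - mean) ** 2 for v in x) / len(x)
    inv = 1.0 / math.sqrt(var + eps)
    return [(v - mean) * inv * g + b for v, g, b in zip(x, gamma, beta)]

def rms_norm(x, gamma, eps=1e-5):
    # No mean subtraction and no bias: divide by the root-mean-square only.
    inv = 1.0 / math.sqrt(sum(v * v for v in x) / len(x) + eps)
    return [v * inv * g for v, g in zip(x, gamma)]
```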
15. Why RMSNorm?
Modern explanation – it’s faster (and just as good).
• Fewer operations (no mean calculation)
• Fewer parameters (no bias term to store)
Does this explanation make sense?
Matrix multiplies are the vast majority of FLOPs (and memory)
[Ivanov et al 2023]
16. Why RMSNorm (2)
Important lesson: FLOPS are not runtime! (we will discuss this in far more detail later)
[Ivanov et al 2023]
Left top (“43G”) is FLOPS
Right top (“153”) is the FLOP-to-memory ratio
RMSNorm can still matter due to the importance of data movement
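A back-of-the-envelope sketch of why FLOP share and runtime share can differ. It assumes 2-byte (bf16) elements, counts each tensor as read or written exactly once, and uses a rough 4-FLOPs-per-element count for RMSNorm; these are assumptions, not measurements. Under them, a large matmul does on the order of a thousand FLOPs per byte moved, while RMSNorm does about one, so the norm is limited by memory bandwidth rather than compute.

```python
def matmul_intensity(n, d_in, d_out, bytes_per_el=2):
    """FLOPs per byte moved for an (n x d_in) @ (d_in x d_out) matmul."""
    flops = 2 * n * d_in * d_out  # one multiply + one add per term
    moved = bytes_per_el * (n * d_in + d_in * d_out + n * d_out)  # read A, B; write C
    return flops / moved

def rmsnorm_intensity(n, d, bytes_per_el=2):
    """Rough FLOPs per byte for RMSNorm over an (n x d) activation."""
    flops = 4 * n * d  # square, sum, scale, multiply by gain (rough count)
    moved = bytes_per_el * (2 * n * d + d)  # read x and gain; write y
    return flops / moved
```

For example, `matmul_intensity(4096, 4096, 16384)` is roughly 1820, while `rmsnorm_intensity(4096, 4096)` is just under 1.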
18. More generally: dropping bias terms
Most modern transformers don’t have bias terms.
Original Transformer: $FFN(x) = \max(0, xW_1 + b_1)\,W_2 + b_2$
Most implementations (if they’re not gated): $FFN(x) = \sigma(xW_1)\,W_2$
Reasons: memory (similar to RMSnorm) and optimization stability
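A quick parameter count (hypothetical helper) for a non-gated FFN with and without biases, using $d_{model} = 4096$ and $d_{ff} = 4 d_{model}$ as an example shape. The bias vectors are a tiny fraction of the parameters, which fits the framing above: the motivation is memory movement and optimization stability rather than raw parameter count.

```python
def ffn_params(d_model, d_ff, bias):
    """Parameter count of a non-gated FFN: x W1 (+ b1), then W2 (+ b2)."""
    weights = d_model * d_ff + d_ff * d_model
    biases = (d_ff + d_model) if bias else 0
    return weights + biases
```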
19. LayerNorm: recap
• Basically everyone does pre-norm.
• Intuition – keep the good parts of residual connections
• Observations – nicer gradient propagation, fewer spikes
• Some people add a second norm outside the residual stream (NOT post-norm)
• Most people do RMSnorm
• In practice, works as well as LayerNorm
• But, has fewer parameters to move around, which saves on wallclock time
• People more generally drop bias terms since the compute/param tradeoffs are not great.
20. Activations
A whole zoo of activations ..
ReLU, GeLU, Swish, ELU, GLU, GeGLU, ReGLU, SeLU, SwiGLU, LiGLU
What are these things? What do people use? Does it matter?
21. A few of the common activations
ReLU
$FF(x) = \max(0, xW_1)\,W_2$
Notable models: Original transformer, T5, Gopher, Chinchilla, OPT
GeLU
$FF(x) = \mathrm{GELU}(xW_1)\,W_2$, where $\mathrm{GELU}(x) := x\Phi(x)$ and $\Phi$ is the standard normal CDF
Notable models: GPT1/2/3, GPT-J, GPT-NeoX, BLOOM
SwiGLU / GeGLU (next slide..)
Notable models: LLaMA, PaLM, T5 v1.1, most models post 2023
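The activations above, sketched as scalar functions in plain Python. GELU uses the exact form $x\Phi(x)$ via `math.erf` (many codebases use a tanh approximation instead); Swish with $\beta = 1$ is also known as SiLU and is the nonlinearity inside SwiGLU.

```python
import math

def relu(x):
    return max(0.0, x)

def gelu(x):
    # Exact GELU: x * Phi(x), where Phi is the standard normal CDF.
    return x * 0.5 * (1.0 + math.erf(x / math.sqrt(2.0)))

def sigmoid(z):
    # Branch on the sign so math.exp never overflows.
    if z >= 0:
        return 1.0 / (1.0 + math.exp(-z))
    e = math.exp(z)
    return e / (1.0 + e)

def swish(x, beta=1.0):
    # Swish: x * sigmoid(beta * x); beta = 1 gives SiLU.
    return x * sigmoid(beta * x)
```

Unlike ReLU, GELU and Swish are smooth and pass small negative values through (e.g. `gelu(-1.0)` is about -0.16).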
22. Gated activations (*GLU)
GLUs modify the ‘first part’ of a FF layer
$FF(x) = \max(0, xW_1)\,W_2$
Instead of a linear + ReLU, augment the above with an (entrywise) linear term
$\max(0, xW_1) \to \max(0, xW_1) \otimes (xV)$
This gives the gated variant (ReGLU) – note that we have an extra parameter (V)
$FF_{\text{ReGLU}}(x) = (\max(0, xW_1) \otimes xV)\,W_2$
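A minimal sketch of a gated FFN on a single token, with matrices as nested lists (one row per input dimension). `glu_ffn` is a hypothetical helper: passing `relu` gives ReGLU, and passing `silu` (Swish with $\beta = 1$) gives SwiGLU. Note the third matrix `V`, which is the extra parameter mentioned above.

```python
import math

def vec_mat(x, W):
    """Row vector x (len d_in) times matrix W (d_in rows of length d_out)."""
    return [sum(x[i] * W[i][j] for i in range(len(x))) for j in range(len(W[0]))]

def relu(z):
    return max(0.0, z)

def silu(z):
    # Swish with beta = 1, written to avoid overflow for large |z|.
    if z >= 0:
        return z / (1.0 + math.exp(-z))
    e = math.exp(z)
    return z * e / (1.0 + e)

def glu_ffn(x, W1, V, W2, act):
    """Gated FFN: (act(x W1) * (x V)) W2, where * is the elementwise product."""
    gate = [act(z) for z in vec_mat(x, W1)]
    lin = vec_mat(x, V)
    return vec_mat([g * l for g, l in zip(gate, lin)], W2)
```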
23. Gated variants of standard FF layers
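GeGLU: $FF_{\text{GeGLU}}(x) = (\mathrm{GELU}(xW_1) \otimes xV)\,W_2$
SwiGLU: $FF_{\text{SwiGLU}}(x) = (\mathrm{Swish}(xW_1) \otimes xV)\,W_2$ (swish is $x * \mathrm{sigmoid}(x)$)
Note: gated models shrink $d_{ff}$ by a factor of 2/3 (to keep the parameter count the same)
Notable models (GeGLU): T5 v1.1, mT5, LaMDA, Phi3, Gemma 2, Gemma 3
Notable models (SwiGLU): LLaMA 1/2/3, PaLM, Mistral, OLMo, most models post 2023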
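A quick check of where the 2/3 comes from: a gated FFN has three weight matrices instead of two, so shrinking $d_{ff}$ from $4 d_{model}$ to $\frac{8}{3} d_{model}$ keeps the parameter count equal. Because $\frac{8}{3} d_{model}$ is usually not an integer, implementations round it; the rounding rule below (up to a multiple of 256) is one convention we assume for illustration, and it reproduces the $d_{ff} = 11008$ used by LLaMA 7B ($d_{model} = 4096$).

```python
def ffn_params(d_model, d_ff, gated):
    """Weight count (no biases) of an FFN block."""
    n_mats = 3 if gated else 2  # gated: W1, V, W2; plain: W1, W2
    return n_mats * d_model * d_ff

def gated_d_ff(d_model, multiple_of=256):
    # Assumed convention: take 8/3 * d_model, then round up to a multiple.
    d_ff = int(8 * d_model / 3)
    return multiple_of * ((d_ff + multiple_of - 1) // multiple_of)
```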
24. Do gated linear units work?
Yes, fairly consistently so.
Shazeer 2020
25. Do gated linear units work (2)?
Yes, with other works corroborating Shazeer 2020
Narang et al 2020
26. Gating, activations
• Many variations (ReLU, GeLU, *GLU) across models.
• *GLU isn’t necessary for a good model (see GPT3), but it’s probably helpful
Other recent outlier models:
Nemotron 340B (Squared ReLU), Falcon 2 11b (ReLU)
• But evidence points towards somewhat consistent gains from Swi/GeGLU
27. Serial vs Parallel layers
Normal transformer blocks are serial – they compute attention, then the MLP
Could we parallelize the transformer block?
28. Parallel layers
A few models (GPTJ, PaLM, GPT-NeoX) do parallel layers. Originally in GPT-J
If implemented right, LayerNorm can be shared, and matrix multiplies can be fused
Recent Models: Cohere Command A, Falcon 2 11B, Command R+
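A plain-Python sketch of the two block layouts (hypothetical helper names; RMSNorm without learned gains is used for brevity). In the parallel form both sublayers read the same normalized input, which is what allows the norm to be shared and the first matmuls of attention and the MLP to be fused.

```python
import math

def rms_norm(x, eps=1e-6):
    inv = 1.0 / math.sqrt(sum(v * v for v in x) / len(x) + eps)
    return [v * inv for v in x]

def add(*vecs):
    return [sum(vals) for vals in zip(*vecs)]

def serial_block(x, attn, mlp):
    # Standard pre-norm block: the MLP sees the output of attention.
    h = add(x, attn(rms_norm(x)))
    return add(h, mlp(rms_norm(h)))

def parallel_block(x, attn, mlp):
    # GPT-J/PaLM style: one shared norm, both sublayers read the same input.
    n = rms_norm(x)
    return add(x, attn(n), mlp(n))
```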
29. Summary: architectures
Pre-vs-post norm:
• Everyone does pre-norm (except OPT350M), likely with good reason.
Layer vs RMSnorm:
• RMSnorm has clear compute wins, sometimes even performance wins
Gating:
• GLUs seem generally better, though differences are small
Serial vs parallel layers:
• No extremely serious ablations, but has a compute win.
30. Many variations in position embeddings
Sine embeddings: add sines and cosines that enable localization
$Embed(x, i) = v_x + PE_i$
Notable models: Original transformer
Absolute embeddings: add a position vector to the embedding
$Embed(x, i) = v_x + u_i$
Notable models: GPT1/2/3, OPT
Relative embeddings: add a vector to the attention computation
Notable models: T5, Gopher, Chinchilla
RoPE embeddings (next slides..)
Notable models: GPTJ, PaLM, LLaMA, most 2024+ models
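A minimal sketch of the original sine/cosine scheme, assuming an even $d_{model}$ and the base of 10000 from the original transformer: even dimensions get $\sin(pos / 10000^{2i/d_{model}})$, odd dimensions the matching cosine, and the result is simply added to the token vector. Function names are hypothetical.

```python
import math

def sinusoidal_pe(pos, d_model, base=10000.0):
    """Sine/cosine position encoding for one position (d_model assumed even)."""
    pe = [0.0] * d_model
    for i in range(0, d_model, 2):
        angle = pos / base ** (i / d_model)
        pe[i] = math.sin(angle)      # even dims: sine
        pe[i + 1] = math.cos(angle)  # odd dims: cosine
    return pe

def embed(token_vec, pos):
    # Embed(x, i) = v_x + PE_i: the position signal is simply added.
    pe = sinusoidal_pe(pos, len(token_vec))
    return [v + p for v, p in zip(token_vec, pe)]
```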
31. RoPE: rotary position embeddings
High level thought process: a relative position embedding should be some $f(x, i)$ s.t.
$\langle f(x, i), f(y, j) \rangle = g(x, y, i - j)$
That is, the attention function only gets to depend on the relative position (i - j). How do existing embeddings not fulfill this goal?
• Sine: has various cross-terms that are not relative
$\langle Embed(x, i), Embed(y, j) \rangle = \langle v_x, v_y \rangle + \langle PE_i, v_y \rangle + \dots$
• Absolute: obviously not relative
• Relative embeddings: not an inner product
32. RoPE: rotary position embeddings
How can we solve this problem?
• We want our embeddings to be invariant to absolute position
• We know that inner products are invariant to arbitrary rotation (applied to both vectors).
[Figure: a position-independent embedding of “we” and “know”, rotated by position. In “we know that”, “we” is rotated by 0 positions and “know” by 1; in “of course we know”, “we” is rotated by 2 positions and “know” by 3. The relative rotation between the two words is the same in both.]
33. RoPE: rotary position embeddings
There are many rotations, which one do you pick?
Just pair up the coordinates and rotate them in 2d (motivation: complex numbers)
[Su et al 2021]
34. The actual RoPE math
Multiply with sines and cosines
Difference with sine embeddings – not additive, no cross terms
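The rotation written out, following the notation of [Su et al 2021] with base 10000 (the common default, assumed here). Each coordinate pair is rotated by an angle proportional to the position, and because rotations compose, the inner product between a rotated query and key depends only on the offset $n - m$:

```latex
R(m\theta_j) =
\begin{pmatrix}
  \cos m\theta_j & -\sin m\theta_j \\
  \sin m\theta_j & \cos m\theta_j
\end{pmatrix},
\qquad \theta_j = 10000^{-2j/d}, \quad j = 0, \dots, d/2 - 1

R_{\Theta,m} = \mathrm{diag}\big(R(m\theta_0), \dots, R(m\theta_{d/2-1})\big),
\qquad f(x, m) = R_{\Theta,m}\, x

\langle f(q, m), f(k, n) \rangle
  = q^\top R_{\Theta,m}^\top R_{\Theta,n}\, k
  = q^\top R_{\Theta,n-m}\, k
```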
35. Implementation and code for RoPE
…
[Code figure; annotations: “Same stuff as the usual multi-head self attention below”, “Get the RoPE matrix cos/sin”, “Multiply query/key inputs”, “Usual attention stuff”]
Note: RoPE is applied to the queries and keys at every attention operation (not once at the input) to enforce position invariance
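A minimal sketch of applying RoPE to a single query and key in plain Python (function names are hypothetical, base 10000 assumed). Consecutive pairs $(x_{2j}, x_{2j+1})$ are rotated as in the paper; many open-source implementations instead pair dimension $j$ with $j + d/2$, and the relative-position property holds either way.

```python
import math

def rope(x, pos, base=10000.0):
    """Rotate consecutive pairs (x[2j], x[2j+1]) by angle pos * theta_j."""
    d = len(x)
    out = [0.0] * d
    for j in range(d // 2):
        theta = base ** (-2 * j / d)
        c, s = math.cos(pos * theta), math.sin(pos * theta)
        a, b = x[2 * j], x[2 * j + 1]
        out[2 * j] = a * c - b * s      # 2-D rotation of the pair
        out[2 * j + 1] = a * s + b * c
    return out

def rope_score(q, k, q_pos, k_pos):
    # Rotate queries and keys (not values), then take the scaled dot product.
    qr, kr = rope(q, q_pos), rope(k, k_pos)
    return sum(a * b for a, b in zip(qr, kr)) / math.sqrt(len(q))
```

For example, with any query `q` and key `k` of even length, `rope_score(q, k, 5, 3)` and `rope_score(q, k, 105, 103)` agree up to floating-point error, because only the offset between positions matters.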
36. Hyperparameters
Transformer hyperparameter questions you might have had in 224n..
• How much bigger should the feedforward size be compared to hidden size?
• How many heads, and should num_heads always divide hidden size?
• What should my vocab size be?
And other model setting questions
• Do people even regularize these huge LMs?
• How do people scale these models - very deep or very wide?
37. Surprising (?) consensus hyperparameter 1
Feedforward – model dimension ratio.
There are two dimensions that are relevant – the feedforward dim ($d_{ff}$) and model dim ($d_{model}$). What should their relationship be?
$d_{ff} = 4\, d_{model}$
This is almost always true. There are just a few exceptions.
38. Exception #1 – GLU variants
Remember that GLU variants scale down by 2/3. This means most GLU variants have $d_{ff} = \frac{8}{3} d_{model}$. This is mostly what happens. Some notable examples:
Model | $d_{ff}/d_{model}$
PaLM | 4
Mistral 7B | 3.5
LLaMA-2 70B | 3.5
LLaMA 65B | 2.68
Qwen 14B | 2.67
DeepSeek 67B | 2.68
Yi 34B | 2.85
T5 v1.1 | 2.5
Models are roughly in this range, though PaLM, LLaMA2 and Mistral are slightly larger
39. Exception #2 – T5
As we have seen (and will see), most LMs have boring, conservative hyperparameters.
One exception is T5 [Raffel et al 2020], which has some very bold settings.
In particular, for the 11B model, they set
$d_{ff} = 65{,}536$
$d_{model} = 1024$
for an astounding 64-times multiplier.
Other, recent exceptions – Gemma 2 (8x), SmolLM/Gemma 3 (4x, GLU)
40. Why this range of multipliers?
Empirically, there’s a basin between 1-10 where this hyperparameter is near-optimal
[Kaplan+ 2020]
41. What can we learn from the model-dim hyperparam?
• The ‘default’ choices of $d_{ff} = 4 d_{model}$ and $d_{ff} = 2.66 d_{model}$ have worked well for nearly all modern LLMs.
• But T5 does show that even radical choices of $d_{ff} = 64 d_{model}$ can work. This hyperparameter choice isn’t written in stone.
• That said, T5 has a follow-up model (T5 v1.1) that is ‘improved’ and uses a much more standard 2.5 multiplier on GeGLU, so the 64-times multiplier is likely suboptimal.
42. Surprising (?) consensus hyperparameter 2
Head-dim*num-heads to model-dim ratio. As a reminder, slide from 224n.
This doesn’t have to be true: we can have head-dimensions > model-dim / num-heads.
But most models do follow this guideline
43. How many heads, what’s the model dim?
Some examples of this hyperparameter
Model | Num heads | Head dim | Model dim | Ratio
GPT3 | 96 | 128 | 12288 | 1
T5 | 128 | 128 | 1024 | 16
T5 v1.1 | 64 | 64 | 4096 | 1
LaMDA | 128 | 128 | 8192 | 2
PaLM | 48 | 258 | 18432 | 1.48
LLaMA2 | 64 | 128 | 8192 | 1
Most models have ratios around 1; the notable exceptions are some Google models.
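A small helper (hypothetical) for the quantity in the table: for GPT3, T5, and LaMDA, the Ratio column equals (num heads × head dim) / model dim. The common convention of a ratio of 1 amounts to splitting $d_{model}$ evenly across heads.

```python
def head_ratio(num_heads, head_dim, d_model):
    """(num_heads * head_dim) / d_model; 1 means the heads exactly tile d_model."""
    return num_heads * head_dim / d_model

def default_head_dim(d_model, num_heads):
    # The common convention: split d_model evenly across heads.
    if d_model % num_heads != 0:
        raise ValueError("num_heads should divide d_model")
    return d_model // num_heads
```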
44. Evidence for 1-1 ratio?
There have been papers written against the 1-1 ratio [Bhojanapalli et al 2020]
But we don’t seem to be seeing significant ‘low rank bottlenecks’ in practice..
45. Aspect ratios
Should my model be deep or wide? How deep and how wide?
Most models are surprisingly consistent on this one too!
Model | $d_{model}/n_{layer}$
BLOOM | 205
T5 v1.1 | 171
PaLM (540B) | 156
GPT3/OPT/Mistral/Qwen | 128
LLaMA / LLaMA2 / Chinchilla | 102
T5 (11B) | 43
GPT2 | 33
Sweet spot?
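A sketch of how aspect ratio and size interact, using the common approximation of about $12\, n_{layer} d_{model}^2$ non-embedding parameters ($4 d_{model}^2$ for the attention projections plus $8 d_{model}^2$ for a $4\times$ FFN, ignoring embeddings, norms, and biases). With GPT3’s $d_{model} = 12288$ and 96 layers (aspect ratio 128), it gives roughly 174B parameters; inverting it gives $d_{model}$ for a chosen budget and aspect ratio. Function names are hypothetical.

```python
def approx_params(d_model, n_layer):
    """Rough non-embedding parameter count: ~12 * n_layer * d_model^2."""
    return 12 * n_layer * d_model ** 2

def d_model_for(n_params, aspect_ratio):
    """Solve 12 * (d / r) * d^2 = N for d, given r = d_model / n_layer."""
    return (n_params * aspect_ratio / 12) ** (1 / 3)
```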
46. Considerations about aspect ratio
Extremely deep models are harder to parallelize and have higher latency
[Tay et al 2021]
48. What are typical vocabulary sizes?
Monolingual models – 30-50k vocab
Model | Token count
Original transformer | 37000
GPT | 40257
GPT2/3 | 50257
T5/T5v1.1 | 32128
LLaMA | 32000
Multilingual / production systems – 100-250k
Model | Token count
mT5 | 250000
PaLM | 256000
GPT4 | 100276
Command A | 255000
DeepSeek | 100000
Qwen 15B | 152064
Yi | 64000
Monolingual vocabs don’t need to be huge, but multilingual ones do
49. Dropout and other regularization
Do we need regularization during pretraining?
Arguments against:
• There is a lot of data (trillions of tokens), more than parameters.
• SGD only does a single pass on a corpus (hard to memorize)
This is all quite reasonable.. but what do people do in practice?
50. Dropout and weight decay in practice
* Most of the time, papers just don’t discuss dropout. For open models, this closely matches not using dropout.
This may not be true of closed models.
Model | Dropout* | Weight decay
Original transformer | 0.1 | 0
GPT2 | 0.1 | 0.1
T5 | 0.1 | 0
GPT3 | 0.1 | 0.1
T5 v1.1 | 0 | 0
PaLM | 0 | (variable)
OPT | 0.1 | 0.1
LLaMA | 0 | 0.1
Qwen 14B | 0.1 | 0.1
Many older models used dropout during pretraining
Newer models (except Qwen) rely only on weight decay
51. Why weight decay LLMs?
[Andriushchenko et al 2023] has interesting observations about LLM weight decay:
It’s not to control overfitting.
Weight decay interacts with learning rates (cosine schedule).
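The slide does not spell out the mechanism, so here is a minimal sketch of one plausible reading, assuming AdamW-style decoupled weight decay: the shrink applied each step is $lr_t \cdot \lambda \cdot w$, so its strength follows the cosine schedule. Function names are hypothetical, and `update` stands in for the optimizer’s (e.g. Adam’s) step direction.

```python
import math

def cosine_lr(step, total_steps, lr_max, lr_min=0.0):
    """Cosine decay from lr_max at step 0 to lr_min at total_steps (no warmup)."""
    progress = min(step, total_steps) / total_steps
    return lr_min + 0.5 * (lr_max - lr_min) * (1.0 + math.cos(math.pi * progress))

def decoupled_wd_step(w, update, lr, wd):
    # Decoupled weight decay: the shrink term is scaled by the current
    # learning rate, so it weakens as the schedule decays.
    return [wi - lr * (ui + wd * wi) for wi, ui in zip(w, update)]
```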
52. Summary: hyperparameters
Feedforward
• Factor-of-4 rule of thumb (8/3 for GLUs) is standard (with some evidence)
Head dim
• Head dim × num heads = $d_{model}$ is standard – but low to no validation
Aspect ratio
• Wide range of ‘good’ values (100-200). Systems concerns dictate the value
Regularization
• You still ‘regularize’ LMs, but its effects are primarily on optimization dynamics
54. Where do the issues arise? Beware of softmaxes!
Softmaxes – can be ill-behaved due to exponentials / division by zero
55. Output softmax stability – the ‘z-loss’
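Recall the softmax calculation: $P(x) = \frac{e^{U_r(x)}}{Z(x)}$, with $Z(x) = \sum_{r'=1}^{|V|} e^{U_{r'}(x)}$
Adding an auxiliary loss $\alpha \log^2 Z(x)$ encourages the normalizer $Z(x)$ to stay close to 1 [From Devlin 2014].
This is useful for stability! PaLM pioneered this ‘z-loss’ trick.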
Other examples: Baichuan 2 (2023), DCLM (2024), OLMo 2 (2025)
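A minimal sketch of the z-loss on one position, in plain Python (hypothetical function name). The default coefficient of $10^{-4}$ is the value reported for PaLM; treat it as an assumption for other models. Shifting all logits by a constant leaves the cross-entropy unchanged but grows $\log Z$, which is exactly the drift the z-loss penalizes.

```python
import math

def ce_with_z_loss(logits, target, z_coeff=1e-4):
    """Cross-entropy plus z-loss = z_coeff * (log Z)^2, Z = sum_j exp(logit_j)."""
    m = max(logits)
    log_z = m + math.log(sum(math.exp(v - m) for v in logits))  # stable log-sum-exp
    ce = log_z - logits[target]       # -log softmax(logits)[target]
    z_loss = z_coeff * log_z ** 2     # pushes log Z toward 0, i.e. Z toward 1
    return ce, z_loss
```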
56. Attention softmax stability – the ‘QK norm’
The query and keys are Layer (RMS) normed before going into the softmax operation.
Other examples: DCLM, OLMo2, Gemma 2
Originally from vision and multimodal models [Dehghani 2023, IDEFICS, Chameleon]
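A minimal sketch of QK-norm for a single query/key pair (hypothetical names; RMSNorm without its learned gain). After normalization each vector has norm at most $\sqrt{d}$, so the scaled logit is bounded by $\sqrt{d}$ no matter how large the raw activations grow, which keeps the attention softmax from saturating.

```python
import math

def rms_norm(x, eps=1e-6):
    inv = 1.0 / math.sqrt(sum(v * v for v in x) / len(x) + eps)
    return [v * inv for v in x]

def attn_logit(q, k, qk_norm=True):
    """Pre-softmax attention logit q.k / sqrt(d), optionally with QK-norm."""
    if qk_norm:
        q, k = rms_norm(q), rms_norm(k)  # learned gains omitted for brevity
    return sum(a * b for a, b in zip(q, k)) / math.sqrt(len(q))
```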
58. Attention heads
Most models don’t touch the attention heads much at all with a few minor exceptions..
GQA / MQA : Saving inference costs by reducing the number of heads
Sparse or sliding window attention (GPT4/Mistral): restricting the attention pattern
to reduce compute cost
Exotic SSM stuff (Jamba, Falcon 3, etc): not covered (sorry!)
59. GQA/MQA – Reducing attention head cost
Let’s think about the compute involved for attention
Total arithmetic operations $O(bnd^2)$, total memory accesses $O(bnd + bhn^2 + d^2)$
Arithmetic intensity is high: $O\left(\left(\frac{1}{k} + \frac{1}{bn}\right)^{-1}\right)$, where $k = d/h$ is the head dimension – we can keep our GPUs running
60. GQA/MQA – Reducing attention head cost
What about the incremental case when we generate text?
Key difference: can’t parallelize the generation process – needs to be step by step
In this case – we need to incrementally re-compute/update attention via the ‘KV cache’
[Animation from https://medium.com/@joaolages/kv-caching-explained-276520203249]
61. GQA/MQA – Reducing attention head cost
What’s the incremental arithmetic intensity?
Total arithmetic operations $O(bnd^2)$, total memory accesses $O(bn^2d + nd^2)$
Arithmetic intensity is not good: $O\left(\left(\frac{n}{d} + \frac{1}{b}\right)^{-1}\right)$ – need large batches + short seq length ($n$) or big model dimensions ($d$)
Is there some way around this? The n/d term is difficult to reduce.
62. MQA – just have fewer key dimensions.
Key idea – have multiple query heads, but just one head for keys and values
We have far fewer items to move in and out of memory (KV cache)
Total memory access $O(bnd + bn^2k + nd^2)$, arithmetic intensity $O\left(\left(\frac{1}{d} + \frac{n}{dh} + \frac{1}{b}\right)^{-1}\right)$
[figure from https://blog.fireworks.ai/multi-query-attention-is-all-you-need-db072e758055]
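The three asymptotic arithmetic-intensity expressions above, written as functions so they can be compared numerically (hypothetical names). Constant factors are dropped, as in the big-O forms, so only relative comparisons are meaningful, and the example shape ($b = 64$, $n = 2048$, $d = 4096$, $h = 32$) is hypothetical. For that shape, incremental MHA lands below 2 while MQA is more than ten times higher.

```python
def intensity_batched(b, n, d, h):
    """Full-sequence MHA: O((1/k + 1/(b n))^-1), with k = d / h."""
    k = d / h
    return 1.0 / (1.0 / k + 1.0 / (b * n))

def intensity_incremental_mha(b, n, d):
    """Step-by-step decoding with a KV cache: O((n/d + 1/b)^-1)."""
    return 1.0 / (n / d + 1.0 / b)

def intensity_incremental_mqa(b, n, d, h):
    """Decoding with multi-query attention: O((1/d + n/(d h) + 1/b)^-1)."""
    return 1.0 / (1.0 / d + n / (d * h) + 1.0 / b)
```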
63. Recent extension – GQA
Don’t go all the way to a single KV head – use fewer KV heads than query heads, each shared by a group of queries
Simple knob to control expressiveness (key-query ratio) and inference efficiency
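A sketch of the KV-cache size that GQA/MQA shrink, for a hypothetical model with 32 layers and 32 query heads of dimension 128, batch 8, 4096 tokens, and 2-byte elements (hypothetical function name). The factor of 2 counts keys and values; the cache scales linearly with the number of KV heads, so 8 KV heads (GQA) use a quarter of the MHA cache and a single KV head (MQA) uses 1/32.

```python
def kv_cache_bytes(n_layers, batch, seq_len, n_kv_heads, head_dim, bytes_per_el=2):
    """Bytes to cache K and V for every layer, sequence position, and KV head."""
    return 2 * n_layers * batch * seq_len * n_kv_heads * head_dim * bytes_per_el
```

With one KV head per query head (plain MHA), this shape comes to 16 GiB of cache.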
64. Does MQA hurt? Sometimes..
Small PPL hit w/ MQA [Shazeer 2019]
Low/no hit w/ GQA [Ainslie 2023]
65. Sparse / sliding window attention
Attending to the entire context can be expensive (quadratic).
Build sparse / structured attention that trades off expressiveness vs runtime (GPT3)
[Child et al 2019]
66. Sliding window attention
Another variation on this idea – sliding window attention
Just use the main part of the strided pattern – let depth extend effective context (Mistral)
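A minimal sketch of a causal sliding-window mask and of how depth extends the reach (hypothetical helper names). Each layer lets a token look back at most `window - 1` positions, so after `n_layers` layers information can, in principle, propagate about `n_layers * (window - 1)` positions; this is an upper bound on reach, not a guarantee that the model uses it.

```python
def sliding_window_mask(seq_len, window):
    """mask[i][j] is True if query i may attend to key j (causal, last `window` keys)."""
    return [[i - window < j <= i for j in range(seq_len)] for i in range(seq_len)]

def receptive_field(n_layers, window):
    # Each layer adds at most window - 1 positions of look-back.
    return n_layers * (window - 1) + 1
```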
67. Current standard trick – interleave ‘full’ and ‘LR’ attention
From Cohere Command A – Every 4th layer is a full attention
Long-range info via NoPE, short-range info via RoPE + SWA.
Other models – LLaMA 4, Gemma does SWA+Full RoPE.
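A hypothetical layer schedule matching the Command A description above: every 4th layer uses full attention with no position embedding (NoPE), and the remaining layers use sliding-window attention with RoPE. The tuple labels are illustrative, not any framework’s config format.

```python
def layer_schedule(n_layers, full_every=4):
    """Hypothetical layout: every `full_every`-th layer is full attention with NoPE;
    the rest are sliding-window attention layers with RoPE."""
    schedule = []
    for layer in range(1, n_layers + 1):
        if layer % full_every == 0:
            schedule.append(("full", "nope"))
        else:
            schedule.append(("sliding_window", "rope"))
    return schedule
```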
68. Recap, conclusion, etc.
Many aspects (arch, hparams) of transformers are shared across the big LMs
Major differences? Position embeddings, activations, tokenization