
Commit 33d159b

Merge pull request openai#315 from jfprincz/submission/11l-partialrope-lateqat-1.1248
Record: 11L Partial RoPE + LN Scale + EMA + XSA4 (val_bpb: 1.1248)
2 parents d947625 + 123120e commit 33d159b

7 files changed

Lines changed: 8642 additions & 0 deletions

## Record: 11L Partial RoPE + LN Scale + EMA + XSA4 (val_bpb: 1.1248)

**val_bpb = 1.1248** (sliding window, stride=64) | **15.6 MB** artifact | 8xH100 SXM, 600s

Previous: [PR #70](https://github.com/openai/parameter-golf/pull/70) (9L, 1.1659) → [PR #164](https://github.com/openai/parameter-golf/pull/164) (9L, 1.1524) → [PR #198](https://github.com/openai/parameter-golf/pull/198) (11L, 1.1318) → [PR #287](https://github.com/openai/parameter-golf/pull/287) (11L, 1.1271) → this
### Changes from PR #287

|  | [PR #287](https://github.com/openai/parameter-golf/pull/287) | This PR |
|---|---|---|
| val_bpb (sliding window, stride=64) | 1.1271 | **1.1248** |
| Partial RoPE | None (full 64 dims) | 16 of 64 dims |
| LN Scale | None | 1/sqrt(layer_idx+1) |
| Artifact size | 15.5 MB | 15.6 MB |
| Everything else | Same | Same |

### What's new
1. **Partial RoPE (16 of 64 dims).** Rotary position embeddings are applied to only the first 16 of the 64 head dimensions (25%); the remaining 48 dimensions attend without positional bias, letting the model learn position-invariant patterns. Zero new parameters. (Sketched after this list.)

2. **LN Scale.** RMSNorm outputs are scaled by 1/sqrt(layer_idx+1), damping the contributions of deeper layers. This stabilizes training and improves convergence in deep models. Zero new parameters. (Sketched after this list.)
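
A minimal sketch of both changes, assuming a conventional rotate-half RoPE helper and a `(B, T, n_heads, 64)` head layout; the function names and the precomputed `cos`/`sin` tables are illustrative, not the exact code in `train_gpt.py`:

```python
import torch

def apply_partial_rope(q, k, cos, sin, rope_dims=16):
    """Rotate only the first `rope_dims` of each 64-dim head; pass the rest through.

    `cos`/`sin` are assumed to be precomputed per-position tables whose last dim is `rope_dims`.
    """
    def rotate(x):
        x_rope, x_pass = x[..., :rope_dims], x[..., rope_dims:]
        x1, x2 = x_rope.chunk(2, dim=-1)
        x_rot = torch.cat((-x2, x1), dim=-1)
        x_rope = x_rope * cos + x_rot * sin
        return torch.cat((x_rope, x_pass), dim=-1)
    return rotate(q), rotate(k)

def scaled_rmsnorm(x, layer_idx, eps=1e-6):
    """RMSNorm followed by the 1/sqrt(layer_idx + 1) LN-Scale damping factor."""
    x = x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + eps)
    return x * (layer_idx + 1) ** -0.5
```

With 11 layers the damping factor falls from 1.0 at layer 0 to about 0.30 at layer 10, which is what keeps the deepest layers from dominating the residual stream.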
### Carried from PR #287

- 11 transformer layers with U-Net skip connections
- Exclusive Self Attention (XSA) on the last 4 layers
- EMA weight averaging (decay=0.997, applied every step); see the first sketch after this list
- Orthogonal + muP-scaled init on all large matrices
- 3x MLP (hidden=1536), relu² activation
- Int6 mixed quantization + zstd-22 (int6 on MLP + attention weights, int8 on embeddings); see the second sketch after this list
- Weight decay 0.04 (Muon + AdamW)
- SmearGate (learned token-blending gate, ~512 params)
- Bigram Hash Embedding (2048 buckets, dim=128, projected to 512)
- FlashAttention 3 (direct `flash_attn_func` calls)
- Sequence length 2048 with NTK-aware RoPE
- Muon optimizer: momentum 0.99 with warmup, warmdown over 3000 iters, grad clip 0.3
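
The EMA averaging carried over from PR #287 is, in essence, a standard exponential moving average of the weights; a minimal sketch, assuming the 0.997 decay is applied to every parameter after every optimizer step (the helper name and loop shape are illustrative):

```python
import copy
import torch

@torch.no_grad()
def ema_update(ema_model, model, decay=0.997):
    # ema <- decay * ema + (1 - decay) * online weights, run once per optimizer step.
    for ema_p, p in zip(ema_model.parameters(), model.parameters()):
        ema_p.lerp_(p, 1.0 - decay)

# Typical usage: keep a frozen copy, update it every step, and evaluate /
# export the EMA copy instead of the raw online weights.
# ema_model = copy.deepcopy(model).requires_grad_(False)
# for step in range(num_steps):
#     loss.backward(); optimizer.step(); optimizer.zero_grad()
#     ema_update(ema_model, model)
```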
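
The int6 path can be pictured as symmetric per-row quantization followed by zstd at level 22; a minimal sketch under that assumption (how the 6-bit codes are packed into the final artifact is not shown, and the function names are illustrative):

```python
import torch
import zstandard  # pip install zstandard

def quantize_int6_per_row(w: torch.Tensor):
    # Symmetric per-row quantization to the signed 6-bit range [-31, 31].
    scale = w.abs().amax(dim=1, keepdim=True).clamp(min=1e-8) / 31.0
    q = (w / scale).round().clamp(-31, 31).to(torch.int8)  # 6-bit codes in an int8 carrier
    return q, scale

def dequantize_int6(q: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
    return q.to(torch.float32) * scale

def compress_codes(q: torch.Tensor) -> bytes:
    # zstd level 22 squeezes the redundancy out of the (unpacked) codes.
    return zstandard.ZstdCompressor(level=22).compress(q.contiguous().numpy().tobytes())
```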
### Configuration

```bash
NUM_LAYERS=11 BIGRAM_VOCAB_SIZE=2048 XSA_LAST_N=4 \
EMA_ENABLED=1 EMA_DECAY=0.997 SWA_ENABLED=0 \
ROPE_DIMS=16 LN_SCALE=1 LATE_QAT=1 QAT_THRESHOLD=0.1 \
MUON_WD=0.04 ADAM_WD=0.04 \
MATRIX_LR=0.025 SCALAR_LR=0.025 TIED_EMBED_LR=0.035 \
MUON_MOMENTUM=0.99 MUON_MOMENTUM_WARMUP_START=0.92 \
MUON_MOMENTUM_WARMUP_STEPS=1500 WARMDOWN_ITERS=3000 \
ITERATIONS=9000 MAX_WALLCLOCK_SECONDS=600 EVAL_STRIDE=64 \
torchrun --standalone --nproc_per_node=8 train_gpt.py
```
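
For orientation, the momentum-warmup and warmdown flags above typically translate into simple linear schedules; this is an assumed reading of the flags, not the exact curves in `train_gpt.py`:

```python
def muon_momentum(step, start=0.92, end=0.99, warmup_steps=1500):
    # Linear ramp of Muon momentum from MUON_MOMENTUM_WARMUP_START to MUON_MOMENTUM.
    if step >= warmup_steps:
        return end
    return start + (end - start) * step / warmup_steps

def lr_scale(step, total_steps=9000, warmdown_iters=3000):
    # Flat learning rate, then a linear warmdown to zero over the final WARMDOWN_ITERS steps.
    return min(1.0, max(0.0, (total_steps - step) / warmdown_iters))
```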
### Key Metrics
- 7,051 steps in 600 s (~85 ms/step)
- ~5.5B training tokens (7,051 steps × 786,432 tokens/step)
- Peak memory: ~20,600 MiB per GPU

| Metric | Value |
|--------|-------|
| Pre-quant val_bpb | 1.1418 |
| Int6 roundtrip val_bpb | 1.1485 |
| **Int6 sliding val_bpb (stride=64)** | **1.1248** |
| Compressed artifact (int6 + zstd) | 15,544,691 bytes |
| Code size | 67,617 bytes |
| **Total submission size** | **15,612,308 bytes** |
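
For reference, "sliding val_bpb (stride=64)" refers to sliding-window evaluation: the 2048-token context advances 64 tokens at a time and only the newest 64 positions are scored, so almost every token is predicted with nearly a full window of context. A minimal sketch, assuming an average bytes-per-token factor for the eval set (loop shape and names are illustrative, edge handling simplified):

```python
import math
import torch
import torch.nn.functional as F

@torch.no_grad()
def sliding_window_bpb(model, tokens, bytes_per_token, seq_len=2048, stride=64):
    nll_sum, n_scored = 0.0, 0
    for start in range(0, len(tokens) - 1, stride):
        window = tokens[start:start + seq_len + 1]
        if len(window) < 2:
            break
        x = torch.tensor(window[:-1])[None]  # inputs
        y = torch.tensor(window[1:])[None]   # targets
        logits = model(x)
        # Score every position of the first window, then only the newest `stride` positions.
        keep = y.shape[1] if start == 0 else min(stride, y.shape[1])
        nll_sum += F.cross_entropy(logits[0, -keep:], y[0, -keep:], reduction="sum").item()
        n_scored += keep
    nats_per_token = nll_sum / n_scored
    return nats_per_token / math.log(2) / bytes_per_token  # bits per byte
```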
### Reproducibility
| Seed | Steps | Sliding val_bpb (stride=64) | Submission size (bytes) |
|------|-------|-----------------------------|-------------------------|
| **2025** | **7,051** | **1.1248** | **15,612,308** |
| 42 | 7,061 | 1.1250 | 15,528,666 |
| 1337 | 7,063 | 1.1253 | 15,639,340 |

Mean val_bpb: **1.1250**. Submitted: seed 2025 (best). Inter-seed spread: 0.0005 bpb.
### Included files
- `train_gpt.py` — full training + quantization + evaluation script
- `train.log` — training log from the best seed (2025)
- `train_seed2025.log`, `train_seed42.log`, `train_seed1337.log` — logs for all three seeds
- `submission.json` — leaderboard metadata

### Note on Late QAT
The submitted code includes a Late QAT flag (`LATE_QAT=1`) intended to enable STE int6 fake-quantization in the final 4% of training. Post-submission analysis (credit: @152334H) revealed that `torch.compile` constant-folds the `CastedLinear._qat_enabled` class attribute at first trace, so the STE branch is dead-code-eliminated and never activates during training. Late QAT had no effect on the results. The score is driven entirely by Partial RoPE and LN Scale.
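
A minimal illustration of the failure mode described above, assuming `_qat_enabled` is a plain class attribute read inside a compiled forward; the class name matches the description, but the body is a sketch rather than the submitted `train_gpt.py`, and whether a later flag flip forces a recompile can vary with PyTorch version:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class CastedLinear(nn.Linear):
    _qat_enabled = False  # meant to be flipped to True for the final ~4% of training

    def forward(self, x):
        w = self.weight
        if CastedLinear._qat_enabled:
            # STE int6 fake-quant: quantize in the forward pass, pass gradients straight through.
            scale = w.abs().amax(dim=1, keepdim=True).clamp(min=1e-8) / 31.0
            w_q = (w / scale).round().clamp(-31, 31) * scale
            w = w + (w_q - w).detach()
        return F.linear(x, w, self.bias)

layer = CastedLinear(8, 8)
compiled = torch.compile(layer)
compiled(torch.randn(2, 8))        # traced while _qat_enabled is False
CastedLinear._qat_enabled = True   # flipped late in training ...
compiled(torch.randn(2, 8))        # ... but the folded graph still skips the STE branch
```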
`submission.json` (16 additions):

{
  "author": "Jack Princz",
  "github_id": "jfprincz",
  "name": "Record: 11L Partial RoPE + LN Scale + EMA + XSA4",
  "blurb": "11 layers with Partial RoPE (16 of 64 dims), LN Scale (1/sqrt(l+1)), EMA weight averaging (decay=0.997), Exclusive Self Attention (XSA) on last 4 layers. Int6 per-row on all MLP+attention weights, int8 tok_emb, zstd-22. Weight decay 0.04 (Muon+AdamW). OrthoInit + muP scaling. SmearGate + BigramHash(2048x128). FA3. Sliding window eval stride=64, seq=2048. Note: Late QAT flag is present in the code but inactive due to torch.compile constant-folding.",
  "date": "2026-03-21T06:00:00Z",
  "val_loss": 1.89924867,
  "val_bpb": 1.12484502,
  "pre_quant_val_loss": 1.9279,
  "pre_quant_val_bpb": 1.1418,
  "int6_zstd_val_loss": 1.93912126,
  "int6_zstd_val_bpb": 1.14845684,
  "bytes_total": 15612308,
  "bytes_model_int6_zstd": 15544691,
  "bytes_code": 67617
}
