diff --git a/records/track_10min_16mb/2026-03-25_TTT_Aggressive_SGD/README.md b/records/track_10min_16mb/2026-03-25_TTT_Aggressive_SGD/README.md new file mode 100644 index 0000000000..513c5c31ca --- /dev/null +++ b/records/track_10min_16mb/2026-03-25_TTT_Aggressive_SGD/README.md @@ -0,0 +1,133 @@ +# Aggressive SGD TTT (val_bpb: 1.1124) + +**3-seed mean val_bpb: 1.1124** (std=0.0008) | **15.4 MB artifact** | 8xH100 SXM, 600s training + 591s eval + +## Results + +| Seed | val_bpb (sliding, s64) | Artifact | +|------|------------------------|----------| +| 1337 | 1.1129 | 15,405,733 | +| 42 | 1.1128 | ~15.4M | +| 2024 | 1.1114 | ~15.4M | +| **Mean ± Std** | **1.1124 ± 0.0008** | | + +## Approach + +Standard 11L architecture, nothing exotic on the model side. The interesting part is the TTT. The base model trains for 600s, then TTT adapts all weights via SGD for 30 epochs on the validation data (score-first protocol). + +The conventional wisdom is TTT at LR=0.002 for 3 epochs. We ran 20+ configurations on 4xH200 and found that cranking the LR to 1.0 and unfreezing every block turns a -0.0025 BPB technique into a -0.041 BPB technique. That's a 16x improvement from the same underlying method. It's like finding out your car has a sport mode you never tried. + +## TTT Configuration + +I swept this on 4xH200 before validating on 8xH100. The sweep told the whole story. + +| Parameter | Our Value | PR #549 (merged SOTA) | +|-----------|-----------|----------------------| +| LR | 1.0 | 0.002 | +| Epochs | 30 | 3 | +| Freeze blocks | 0 (all unfrozen) | 0 | +| Momentum | 0.9 | 0.9 | +| TTT gain | -0.041 BPB | -0.0025 BPB | + +### TTT LR Sweep (4xH200, 20 epochs, freeze=2) +| LR | Sliding BPB | +|----|------------| +| 0.01 | 1.1489 | +| 0.02 | 1.1471 | +| 0.05 | 1.1444 | +| 0.1 | 1.1422 | +| 0.2 | 1.1400 | +| 0.5 | 1.1351 | +| **0.7** | **1.1327** | +| 0.8 | 1.1355 | +| 1.0 | 1.1585 (diverged) | + +BPB just keeps getting better as LR goes up... until it doesn't. Peak at 0.7 with 2 frozen blocks. + +### Unfreezing all blocks (4xH200, 20 epochs) +| LR | freeze=2 | freeze=0 | Delta | +|----|----------|----------|-------| +| 0.7 | 1.1327 | 1.1255 | -0.007 | +| 1.0 | diverged | 1.1183 | — | +| **1.5** | **diverged** | **1.1110** | — | + +This was the breakthrough. With 2 frozen blocks, LR=1.0 diverges. Unfreeze everything and it converges fine. The extra capacity from unfreezing absorbs the aggressive learning rate. It also shifts the optimal LR from 0.7 all the way up to 1.5. + +### Epoch scaling (4xH200, LR=1.0, freeze=0) +| Epochs | Sliding BPB | TTT time | +|--------|------------|----------| +| 20 | 1.1183 | 569s | +| **30** | **1.1076** | **854s** | + +On 8xH100, each TTT epoch runs in ~16.6s (vs 28.5s on 4xH200), so 30 epochs fits within the 10-minute eval budget. 
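+
+Mechanically the phase is nothing exotic either: once scores are locked in, it is plain SGD with momentum over the validation stream. A minimal sketch of the idea (not the exact loop in `train_gpt.py`; `val_batches` stands in for the validation loader, and freezing the first `freeze_blocks` blocks is one plausible reading of `TTT_FREEZE_BLOCKS`):
+
+```python
+import torch
+
+def test_time_train(model, val_batches, lr=1.0, momentum=0.9, epochs=30, freeze_blocks=0):
+    # freeze_blocks=0 leaves every transformer block trainable -- this is what
+    # lets LR=1.0 converge instead of diverging (see the sweeps above).
+    for i, block in enumerate(model.blocks):
+        for p in block.parameters():
+            p.requires_grad_(i >= freeze_blocks)
+    opt = torch.optim.SGD([p for p in model.parameters() if p.requires_grad],
+                          lr=lr, momentum=momentum)
+    model.train()
+    for _ in range(epochs):
+        for x, y in val_batches:      # (input, shifted-target) token batches
+            loss = model(x, y)        # forward returns mean cross-entropy
+            opt.zero_grad(set_to_none=True)
+            loss.backward()
+            opt.step()
+    return model
+```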
+
+## Architecture
+
+| Component | Detail |
+|-----------|--------|
+| Layers | 11 |
+| Dim | 512 |
+| Heads | 8 (4 KV, GQA) |
+| MLP | 3x, relu-squared |
+| XSA | Last 4 layers |
+| EMA | 0.997 |
+| Late QAT | Int6 STE when lr_scale < 0.1 |
+| Value Embeddings | 128-dim, 5 sets |
+| BigramHash | 6144 buckets |
+| SmearGate | Learned token blending |
+| Warmdown | 1600 iterations |
+| Seq length | 2048 (train), 1024 (eval) |
+| Sliding window | stride=64 |
+| Quantization | Int6 per-row + zstd-22 |
+
+## Training
+
+- Muon optimizer (matrix_lr=0.04, momentum=0.95 with warmup from 0.85)
+- AdamW for embeddings/scalars (WD=0.01)
+- Flash Attention v3 (Hopper) where available, SDPA fallback
+- 6039 steps in 600s on 8xH100 (~99ms/step)
+
+## Evaluation
+
+Three phases, all within the 10-minute eval budget:
+
+1. Int6+zstd quantization roundtrip
+2. TTT: SGD(lr=1.0, momentum=0.9), 30 epochs, all blocks unfrozen, score-first
+3. Sliding window eval (stride=64, seq_len=1024)
+
+Total eval time: ~591s (TTT 497s + sliding window 92s + roundtrip 2s)
+
+## Run Command
+
+```bash
+TTT_ENABLED=1 TTT_LR=1.0 TTT_EPOCHS=30 TTT_FREEZE_BLOCKS=0 TTT_MOMENTUM=0.9 \
+VE_ENABLED=1 WARMDOWN_ITERS=1600 NUM_LAYERS=11 XSA_LAST_N=4 \
+EMA_ENABLED=1 LATE_QAT=1 BIGRAM_VOCAB_SIZE=6144 \
+SEED=1337 \
+torchrun --standalone --nproc_per_node=8 train_gpt.py
+```
+
+## How I Got Here
+
+~20 hours on 4xH200, 54 experiments. Started from the 9L baseline and worked forward:
+
+1. Baseline (9L, no extras): 1.1808
+2. +11L, XSA, EMA, QAT: 1.1619
+3. +Flash Attention v3: 1.1527
+4. +Value Embeddings, warmdown tuning: 1.1521
+5. +TTT (LR=0.01, 10ep, freeze=2): 1.1489
+6. TTT LR sweep to 0.7: 1.1327
+7. Unfreeze all blocks: 1.1255
+8. LR=1.5, 20ep: 1.1110
+9. 30ep, LR=1.0: 1.1076
+10. 8xH100 (more training steps): **1.1124**
+
+Step 7 was where it got fun. Everything before that was incremental hill climbing. Unfreezing all blocks during TTT changed the optimization landscape enough that learning rates that previously diverged started converging, and the whole curve shifted.
+
+## Schrödinger's SOTA
+
+This beats the merged leaderboard (1.1194) by 0.007 BPB. I haven't checked the pending PRs. Until they're merged, this is simultaneously a record and not a record, and I'm choosing to live in that superposition for a bit.
+
+## Credits
+
+Built on the community's collective work, especially PR #414 (signalrush), PR #461 (Christopher-Lee-McClendon), and PR #549 (abaybektursun).
diff --git a/records/track_10min_16mb/2026-03-25_TTT_Aggressive_SGD/seed1337.log b/records/track_10min_16mb/2026-03-25_TTT_Aggressive_SGD/seed1337.log
new file mode 100644
index 0000000000..4bbc32ceed
--- /dev/null
+++ b/records/track_10min_16mb/2026-03-25_TTT_Aggressive_SGD/seed1337.log
@@ -0,0 +1,130 @@
+W0325 17:22:48.042000 3334 torch/distributed/run.py:803] 
+W0325 17:22:48.042000 3334 torch/distributed/run.py:803] *****************************************
+W0325 17:22:48.042000 3334 torch/distributed/run.py:803] Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed.
+W0325 17:22:48.042000 3334 torch/distributed/run.py:803] ***************************************** +logs/93d96033-aba6-425b-8a12-b5797527a51c.txt +val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path=./data/tokenizers/fineweb_1024_bpe.model +train_loader:dataset:fineweb10B_sp1024 train_shards:80 +val_loader:shards pattern=./data/datasets/fineweb10B_sp1024/fineweb_val_*.bin tokens:62021632 +model_params:27354201 +mtp_num_heads:0 mtp_loss_weight:0.2 mtp_params:0 +world_size:8 grad_accum_steps:1 +sdp_backends:cudnn=False flash=True mem_efficient=False math=False +attention_mode:gqa num_heads:8 num_kv_heads:4 +tie_embeddings:True embed_lr:0.05 head_lr:0.0 matrix_lr:0.04 scalar_lr:0.04 +train_batch_tokens:786432 train_seq_len:2048 iterations:20000 warmup_steps:20 max_wallclock_seconds:600.000 +seed:1337 +warmup_step:1/20 +warmup_step:2/20 +warmup_step:3/20 +warmup_step:4/20 +warmup_step:5/20 +warmup_step:6/20 +warmup_step:7/20 +warmup_step:8/20 +warmup_step:9/20 +warmup_step:10/20 +warmup_step:11/20 +warmup_step:12/20 +warmup_step:13/20 +warmup_step:14/20 +warmup_step:15/20 +warmup_step:16/20 +warmup_step:17/20 +warmup_step:18/20 +warmup_step:19/20 +warmup_step:20/20 +step:0/20000 val_loss:6.9322 val_bpb:4.1057 train_time:0ms step_avg:0.02ms +step:1/20000 train_loss:6.9333 train_time:135ms step_avg:135.40ms +step:2/20000 train_loss:10.1819 train_time:208ms step_avg:103.88ms +step:3/20000 train_loss:8.5707 train_time:311ms step_avg:103.51ms +step:4/20000 train_loss:7.8164 train_time:405ms step_avg:101.17ms +step:5/20000 train_loss:7.1907 train_time:504ms step_avg:100.74ms +step:6/20000 train_loss:7.0114 train_time:603ms step_avg:100.56ms +step:7/20000 train_loss:6.9916 train_time:706ms step_avg:100.91ms +step:8/20000 train_loss:6.9090 train_time:812ms step_avg:101.52ms +step:9/20000 train_loss:6.4970 train_time:905ms step_avg:100.61ms +step:10/20000 train_loss:6.2038 train_time:1004ms step_avg:100.36ms +step:200/20000 train_loss:2.3810 train_time:20198ms step_avg:100.99ms +step:400/20000 train_loss:2.4305 train_time:42406ms step_avg:106.01ms +step:600/20000 train_loss:2.3302 train_time:60150ms step_avg:100.25ms +step:800/20000 train_loss:2.2242 train_time:79419ms step_avg:99.27ms +step:1000/20000 train_loss:2.2554 train_time:98822ms step_avg:98.82ms +step:1000/20000 val_loss:2.2063 val_bpb:1.3067 train_time:98845ms step_avg:98.85ms +step:1200/20000 train_loss:2.3216 train_time:119556ms step_avg:99.63ms +step:1400/20000 train_loss:2.1486 train_time:139903ms step_avg:99.93ms +step:1600/20000 train_loss:2.0417 train_time:159203ms step_avg:99.50ms +step:1800/20000 train_loss:2.1425 train_time:179656ms step_avg:99.81ms +step:2000/20000 train_loss:2.0560 train_time:198038ms step_avg:99.02ms +step:2000/20000 val_loss:2.1191 val_bpb:1.2551 train_time:198062ms step_avg:99.03ms +step:2200/20000 train_loss:2.1250 train_time:218099ms step_avg:99.14ms +step:2400/20000 train_loss:2.0585 train_time:235653ms step_avg:98.19ms +step:2600/20000 train_loss:2.1026 train_time:256551ms step_avg:98.67ms +step:2800/20000 train_loss:2.1529 train_time:276768ms step_avg:98.85ms +step:3000/20000 train_loss:2.1536 train_time:297220ms step_avg:99.07ms +step:3000/20000 val_loss:2.0918 val_bpb:1.2389 train_time:297258ms step_avg:99.09ms +step:3200/20000 train_loss:2.1706 train_time:318928ms step_avg:99.66ms +step:3400/20000 train_loss:2.0239 train_time:338681ms step_avg:99.61ms +step:3600/20000 train_loss:2.0969 train_time:358234ms step_avg:99.51ms +step:3800/20000 train_loss:2.0788 train_time:376669ms step_avg:99.12ms 
+step:4000/20000 train_loss:1.9838 train_time:397318ms step_avg:99.33ms +step:4000/20000 val_loss:2.0768 val_bpb:1.2300 train_time:397342ms step_avg:99.34ms +step:4200/20000 train_loss:2.1655 train_time:417491ms step_avg:99.40ms +step:4400/20000 train_loss:2.0535 train_time:436234ms step_avg:99.14ms +step:4600/20000 train_loss:1.8614 train_time:457455ms step_avg:99.45ms +step:4800/20000 train_loss:2.4414 train_time:475619ms step_avg:99.09ms +step:5000/20000 train_loss:2.1070 train_time:497156ms step_avg:99.43ms +step:5000/20000 val_loss:2.0260 val_bpb:1.1999 train_time:497194ms step_avg:99.44ms +step:5200/20000 train_loss:2.0378 train_time:515549ms step_avg:99.14ms +step:5400/20000 train_loss:2.0361 train_time:536193ms step_avg:99.30ms +step:5600/20000 train_loss:1.9364 train_time:555328ms step_avg:99.17ms +step:5800/20000 train_loss:1.9671 train_time:574408ms step_avg:99.04ms +late_qat:enabled step:5893 scale:0.0996 +step:6000/20000 train_loss:1.9119 train_time:596711ms step_avg:99.45ms +step:6000/20000 val_loss:1.9455 val_bpb:1.1522 train_time:596738ms step_avg:99.46ms +step:6039/20000 val_loss:1.9441 val_bpb:1.1514 train_time:600008ms step_avg:99.36ms +stopping_early: wallclock_cap train_time:600008ms step:6039/20000 +peak memory allocated: 20607 MiB reserved: 20654 MiB +ema:applying EMA weights +saved ema_checkpoint.pt +Serialized model: 106832383 bytes +Code size: 81771 bytes +Serialized model int6+zstd: 15323962 bytes +Total submission size int6+zstd: 15405733 bytes +ttt:start lr=1.0 momentum=0.9 epochs=30 freeze_blocks=0 +ttt_epoch:1/30 loss:5.7241 time:16.9s +ttt_epoch:2/30 loss:5.9672 time:33.5s +ttt_epoch:3/30 loss:5.7740 time:50.1s +ttt_epoch:4/30 loss:5.4566 time:66.7s +ttt_epoch:5/30 loss:3.7144 time:83.3s +ttt_epoch:6/30 loss:2.4766 time:99.9s +ttt_epoch:7/30 loss:2.1376 time:116.5s +ttt_epoch:8/30 loss:1.9896 time:133.1s +ttt_epoch:9/30 loss:1.9714 time:149.7s +ttt_epoch:10/30 loss:2.0874 time:166.3s +ttt_epoch:11/30 loss:2.0820 time:182.9s +ttt_epoch:12/30 loss:2.0418 time:199.5s +ttt_epoch:13/30 loss:1.9634 time:216.1s +ttt_epoch:14/30 loss:1.9549 time:232.7s +ttt_epoch:15/30 loss:1.9497 time:249.3s +ttt_epoch:16/30 loss:1.9454 time:265.9s +ttt_epoch:17/30 loss:1.9419 time:282.6s +ttt_epoch:18/30 loss:1.9387 time:299.1s +ttt_epoch:19/30 loss:1.9359 time:315.6s +ttt_epoch:20/30 loss:1.9333 time:332.1s +ttt_epoch:21/30 loss:1.9308 time:348.7s +ttt_epoch:22/30 loss:1.9284 time:365.2s +ttt_epoch:23/30 loss:1.9262 time:381.7s +ttt_epoch:24/30 loss:1.9241 time:398.3s +ttt_epoch:25/30 loss:1.9222 time:414.8s +ttt_epoch:26/30 loss:1.9201 time:431.3s +ttt_epoch:27/30 loss:1.9183 time:447.8s +ttt_epoch:28/30 loss:1.9165 time:464.4s +ttt_epoch:29/30 loss:1.9145 time:480.9s +ttt_epoch:30/30 loss:1.9128 time:497.4s +ttt:done elapsed=497.4s +ttt:elapsed=497.4s +final_int6_roundtrip val_loss:1.9123 val_bpb:1.1326 eval_time:1922ms +final_int6_roundtrip_exact val_loss:1.91228463 val_bpb:1.13256267 +final_int6_sliding_window val_loss:1.8791 val_bpb:1.1129 stride:64 eval_time:91736ms +final_int6_sliding_window_exact val_loss:1.87914996 val_bpb:1.11294140 diff --git a/records/track_10min_16mb/2026-03-25_TTT_Aggressive_SGD/seed2024.log b/records/track_10min_16mb/2026-03-25_TTT_Aggressive_SGD/seed2024.log new file mode 100644 index 0000000000..4d6f0e98e9 --- /dev/null +++ b/records/track_10min_16mb/2026-03-25_TTT_Aggressive_SGD/seed2024.log @@ -0,0 +1,130 @@ +W0325 18:09:10.743000 47278 torch/distributed/run.py:803] +W0325 18:09:10.743000 47278 torch/distributed/run.py:803] 
***************************************** +W0325 18:09:10.743000 47278 torch/distributed/run.py:803] Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed. +W0325 18:09:10.743000 47278 torch/distributed/run.py:803] ***************************************** +logs/08bcdeaa-9436-442c-ba64-802e468cc64d.txt +val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path=./data/tokenizers/fineweb_1024_bpe.model +train_loader:dataset:fineweb10B_sp1024 train_shards:80 +val_loader:shards pattern=./data/datasets/fineweb10B_sp1024/fineweb_val_*.bin tokens:62021632 +model_params:27354201 +mtp_num_heads:0 mtp_loss_weight:0.2 mtp_params:0 +world_size:8 grad_accum_steps:1 +sdp_backends:cudnn=False flash=True mem_efficient=False math=False +attention_mode:gqa num_heads:8 num_kv_heads:4 +tie_embeddings:True embed_lr:0.05 head_lr:0.0 matrix_lr:0.04 scalar_lr:0.04 +train_batch_tokens:786432 train_seq_len:2048 iterations:20000 warmup_steps:20 max_wallclock_seconds:600.000 +seed:2024 +warmup_step:1/20 +warmup_step:2/20 +warmup_step:3/20 +warmup_step:4/20 +warmup_step:5/20 +warmup_step:6/20 +warmup_step:7/20 +warmup_step:8/20 +warmup_step:9/20 +warmup_step:10/20 +warmup_step:11/20 +warmup_step:12/20 +warmup_step:13/20 +warmup_step:14/20 +warmup_step:15/20 +warmup_step:16/20 +warmup_step:17/20 +warmup_step:18/20 +warmup_step:19/20 +warmup_step:20/20 +step:0/20000 val_loss:6.9283 val_bpb:4.1033 train_time:0ms step_avg:0.01ms +step:1/20000 train_loss:6.9298 train_time:136ms step_avg:136.19ms +step:2/20000 train_loss:9.8882 train_time:209ms step_avg:104.51ms +step:3/20000 train_loss:8.3873 train_time:296ms step_avg:98.73ms +step:4/20000 train_loss:7.7438 train_time:384ms step_avg:95.91ms +step:5/20000 train_loss:7.1677 train_time:471ms step_avg:94.12ms +step:6/20000 train_loss:6.9261 train_time:558ms step_avg:92.94ms +step:7/20000 train_loss:7.0807 train_time:646ms step_avg:92.23ms +step:8/20000 train_loss:6.9515 train_time:740ms step_avg:92.51ms +step:9/20000 train_loss:6.5535 train_time:838ms step_avg:93.06ms +step:10/20000 train_loss:6.2269 train_time:954ms step_avg:95.44ms +step:200/20000 train_loss:2.3709 train_time:18934ms step_avg:94.67ms +step:400/20000 train_loss:2.4110 train_time:39929ms step_avg:99.82ms +step:600/20000 train_loss:2.3238 train_time:57721ms step_avg:96.20ms +step:800/20000 train_loss:2.2141 train_time:76856ms step_avg:96.07ms +step:1000/20000 train_loss:2.2511 train_time:96331ms step_avg:96.33ms +step:1000/20000 val_loss:2.2028 val_bpb:1.3046 train_time:96376ms step_avg:96.38ms +step:1200/20000 train_loss:2.3247 train_time:118870ms step_avg:99.06ms +step:1400/20000 train_loss:2.1505 train_time:138805ms step_avg:99.15ms +step:1600/20000 train_loss:2.0428 train_time:158317ms step_avg:98.95ms +step:1800/20000 train_loss:2.1351 train_time:179370ms step_avg:99.65ms +step:2000/20000 train_loss:2.0526 train_time:199062ms step_avg:99.53ms +step:2000/20000 val_loss:2.1197 val_bpb:1.2554 train_time:199085ms step_avg:99.54ms +step:2200/20000 train_loss:2.1258 train_time:219886ms step_avg:99.95ms +step:2400/20000 train_loss:2.0584 train_time:237782ms step_avg:99.08ms +step:2600/20000 train_loss:2.1061 train_time:259097ms step_avg:99.65ms +step:2800/20000 train_loss:2.1589 train_time:280381ms step_avg:100.14ms +step:3000/20000 train_loss:2.1575 train_time:298165ms step_avg:99.39ms +step:3000/20000 val_loss:2.0933 val_bpb:1.2398 train_time:298191ms 
step_avg:99.40ms +step:3200/20000 train_loss:2.1760 train_time:318013ms step_avg:99.38ms +step:3400/20000 train_loss:2.0219 train_time:336253ms step_avg:98.90ms +step:3600/20000 train_loss:2.0957 train_time:357559ms step_avg:99.32ms +step:3800/20000 train_loss:2.0760 train_time:377699ms step_avg:99.39ms +step:4000/20000 train_loss:1.9864 train_time:398735ms step_avg:99.68ms +step:4000/20000 val_loss:2.0784 val_bpb:1.2309 train_time:398762ms step_avg:99.69ms +step:4200/20000 train_loss:2.1704 train_time:418022ms step_avg:99.53ms +step:4400/20000 train_loss:2.0550 train_time:436039ms step_avg:99.10ms +step:4600/20000 train_loss:1.8647 train_time:456177ms step_avg:99.17ms +step:4800/20000 train_loss:2.4504 train_time:475025ms step_avg:98.96ms +step:5000/20000 train_loss:2.1093 train_time:494156ms step_avg:98.83ms +step:5000/20000 val_loss:2.0304 val_bpb:1.2025 train_time:494180ms step_avg:98.84ms +step:5200/20000 train_loss:2.0425 train_time:511924ms step_avg:98.45ms +step:5400/20000 train_loss:2.0405 train_time:534622ms step_avg:99.00ms +step:5600/20000 train_loss:1.9377 train_time:553899ms step_avg:98.91ms +step:5800/20000 train_loss:1.9686 train_time:573023ms step_avg:98.80ms +late_qat:enabled step:5919 scale:0.0998 +step:6000/20000 train_loss:1.9159 train_time:592251ms step_avg:98.71ms +step:6000/20000 val_loss:1.9486 val_bpb:1.1541 train_time:592274ms step_avg:98.71ms +step:6092/20000 val_loss:1.9445 val_bpb:1.1516 train_time:600009ms step_avg:98.49ms +stopping_early: wallclock_cap train_time:600009ms step:6092/20000 +peak memory allocated: 20606 MiB reserved: 20656 MiB +ema:applying EMA weights +saved ema_checkpoint.pt +Serialized model: 106832383 bytes +Code size: 81771 bytes +Serialized model int6+zstd: 15299337 bytes +Total submission size int6+zstd: 15381108 bytes +ttt:start lr=1.0 momentum=0.9 epochs=30 freeze_blocks=0 +ttt_epoch:1/30 loss:5.1426 time:16.9s +ttt_epoch:2/30 loss:3.4041 time:33.5s +ttt_epoch:3/30 loss:2.2817 time:50.1s +ttt_epoch:4/30 loss:2.1099 time:66.7s +ttt_epoch:5/30 loss:1.9857 time:83.3s +ttt_epoch:6/30 loss:1.9701 time:99.9s +ttt_epoch:7/30 loss:1.9618 time:116.5s +ttt_epoch:8/30 loss:1.9560 time:133.1s +ttt_epoch:9/30 loss:2.0294 time:149.7s +ttt_epoch:10/30 loss:2.0532 time:166.3s +ttt_epoch:11/30 loss:1.9803 time:183.0s +ttt_epoch:12/30 loss:1.9490 time:199.6s +ttt_epoch:13/30 loss:1.9441 time:216.2s +ttt_epoch:14/30 loss:1.9405 time:232.8s +ttt_epoch:15/30 loss:1.9373 time:249.4s +ttt_epoch:16/30 loss:1.9345 time:266.0s +ttt_epoch:17/30 loss:1.9320 time:282.6s +ttt_epoch:18/30 loss:1.9296 time:299.2s +ttt_epoch:19/30 loss:1.9274 time:315.8s +ttt_epoch:20/30 loss:1.9252 time:332.4s +ttt_epoch:21/30 loss:1.9232 time:349.0s +ttt_epoch:22/30 loss:1.9212 time:365.6s +ttt_epoch:23/30 loss:1.9193 time:382.3s +ttt_epoch:24/30 loss:1.9177 time:398.9s +ttt_epoch:25/30 loss:1.9158 time:415.5s +ttt_epoch:26/30 loss:1.9140 time:432.1s +ttt_epoch:27/30 loss:1.9124 time:448.7s +ttt_epoch:28/30 loss:1.9118 time:465.3s +ttt_epoch:29/30 loss:2.0171 time:481.9s +ttt_epoch:30/30 loss:1.9348 time:498.5s +ttt:done elapsed=498.6s +ttt:elapsed=498.6s +final_int6_roundtrip val_loss:1.9093 val_bpb:1.1308 eval_time:1930ms +final_int6_roundtrip_exact val_loss:1.90931604 val_bpb:1.13080451 +final_int6_sliding_window val_loss:1.8766 val_bpb:1.1114 stride:64 eval_time:74088ms +final_int6_sliding_window_exact val_loss:1.87657671 val_bpb:1.11141737 diff --git a/records/track_10min_16mb/2026-03-25_TTT_Aggressive_SGD/seed42.log 
b/records/track_10min_16mb/2026-03-25_TTT_Aggressive_SGD/seed42.log new file mode 100644 index 0000000000..1e9734193e --- /dev/null +++ b/records/track_10min_16mb/2026-03-25_TTT_Aggressive_SGD/seed42.log @@ -0,0 +1,130 @@ +W0325 17:46:36.031000 45431 torch/distributed/run.py:803] +W0325 17:46:36.031000 45431 torch/distributed/run.py:803] ***************************************** +W0325 17:46:36.031000 45431 torch/distributed/run.py:803] Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed. +W0325 17:46:36.031000 45431 torch/distributed/run.py:803] ***************************************** +logs/6c889bea-0753-4d86-a6d7-0932fa63b023.txt +val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path=./data/tokenizers/fineweb_1024_bpe.model +train_loader:dataset:fineweb10B_sp1024 train_shards:80 +val_loader:shards pattern=./data/datasets/fineweb10B_sp1024/fineweb_val_*.bin tokens:62021632 +model_params:27354201 +mtp_num_heads:0 mtp_loss_weight:0.2 mtp_params:0 +world_size:8 grad_accum_steps:1 +sdp_backends:cudnn=False flash=True mem_efficient=False math=False +attention_mode:gqa num_heads:8 num_kv_heads:4 +tie_embeddings:True embed_lr:0.05 head_lr:0.0 matrix_lr:0.04 scalar_lr:0.04 +train_batch_tokens:786432 train_seq_len:2048 iterations:20000 warmup_steps:20 max_wallclock_seconds:600.000 +seed:42 +warmup_step:1/20 +warmup_step:2/20 +warmup_step:3/20 +warmup_step:4/20 +warmup_step:5/20 +warmup_step:6/20 +warmup_step:7/20 +warmup_step:8/20 +warmup_step:9/20 +warmup_step:10/20 +warmup_step:11/20 +warmup_step:12/20 +warmup_step:13/20 +warmup_step:14/20 +warmup_step:15/20 +warmup_step:16/20 +warmup_step:17/20 +warmup_step:18/20 +warmup_step:19/20 +warmup_step:20/20 +step:0/20000 val_loss:6.9301 val_bpb:4.1044 train_time:0ms step_avg:0.01ms +step:1/20000 train_loss:6.9313 train_time:137ms step_avg:136.76ms +step:2/20000 train_loss:10.0205 train_time:215ms step_avg:107.73ms +step:3/20000 train_loss:8.4670 train_time:305ms step_avg:101.55ms +step:4/20000 train_loss:7.7782 train_time:392ms step_avg:97.89ms +step:5/20000 train_loss:7.1548 train_time:480ms step_avg:95.91ms +step:6/20000 train_loss:6.7985 train_time:569ms step_avg:94.76ms +step:7/20000 train_loss:6.7720 train_time:661ms step_avg:94.47ms +step:8/20000 train_loss:6.8323 train_time:750ms step_avg:93.69ms +step:9/20000 train_loss:6.6256 train_time:839ms step_avg:93.21ms +step:10/20000 train_loss:6.1474 train_time:940ms step_avg:93.98ms +step:200/20000 train_loss:2.3715 train_time:19799ms step_avg:98.99ms +step:400/20000 train_loss:2.4145 train_time:42085ms step_avg:105.21ms +step:600/20000 train_loss:2.3262 train_time:61500ms step_avg:102.50ms +step:800/20000 train_loss:2.2255 train_time:81498ms step_avg:101.87ms +step:1000/20000 train_loss:2.2527 train_time:99975ms step_avg:99.98ms +step:1000/20000 val_loss:2.2057 val_bpb:1.3063 train_time:100000ms step_avg:100.00ms +step:1200/20000 train_loss:2.3268 train_time:121051ms step_avg:100.88ms +step:1400/20000 train_loss:2.1537 train_time:141939ms step_avg:101.38ms +step:1600/20000 train_loss:2.0428 train_time:161628ms step_avg:101.02ms +step:1800/20000 train_loss:2.1385 train_time:182484ms step_avg:101.38ms +step:2000/20000 train_loss:2.0598 train_time:202071ms step_avg:101.04ms +step:2000/20000 val_loss:2.1210 val_bpb:1.2562 train_time:202095ms step_avg:101.05ms +step:2200/20000 train_loss:2.1250 train_time:223648ms step_avg:101.66ms 
+step:2400/20000 train_loss:2.0592 train_time:243038ms step_avg:101.27ms +step:2600/20000 train_loss:2.1046 train_time:264754ms step_avg:101.83ms +step:2800/20000 train_loss:2.1561 train_time:286866ms step_avg:102.45ms +step:3000/20000 train_loss:2.1616 train_time:307426ms step_avg:102.48ms +step:3000/20000 val_loss:2.0948 val_bpb:1.2407 train_time:307450ms step_avg:102.48ms +step:3200/20000 train_loss:2.1759 train_time:325969ms step_avg:101.87ms +step:3400/20000 train_loss:2.0194 train_time:343551ms step_avg:101.04ms +step:3600/20000 train_loss:2.0994 train_time:362972ms step_avg:100.83ms +step:3800/20000 train_loss:2.0788 train_time:380669ms step_avg:100.18ms +step:4000/20000 train_loss:1.9882 train_time:400321ms step_avg:100.08ms +step:4000/20000 val_loss:2.0799 val_bpb:1.2318 train_time:400351ms step_avg:100.09ms +step:4200/20000 train_loss:2.1737 train_time:420269ms step_avg:100.06ms +step:4400/20000 train_loss:2.0574 train_time:439242ms step_avg:99.83ms +step:4600/20000 train_loss:1.8626 train_time:459796ms step_avg:99.96ms +step:4800/20000 train_loss:2.4455 train_time:478014ms step_avg:99.59ms +step:5000/20000 train_loss:2.1072 train_time:498582ms step_avg:99.72ms +step:5000/20000 val_loss:2.0262 val_bpb:1.2000 train_time:498620ms step_avg:99.72ms +step:5200/20000 train_loss:2.0358 train_time:517289ms step_avg:99.48ms +step:5400/20000 train_loss:2.0326 train_time:540528ms step_avg:100.10ms +step:5600/20000 train_loss:1.9343 train_time:559021ms step_avg:99.83ms +step:5800/20000 train_loss:1.9656 train_time:578849ms step_avg:99.80ms +late_qat:enabled step:5850 scale:0.0998 +step:6000/20000 train_loss:1.9100 train_time:599737ms step_avg:99.96ms +step:6000/20000 val_loss:1.9458 val_bpb:1.1524 train_time:599762ms step_avg:99.96ms +step:6003/20000 val_loss:1.9458 val_bpb:1.1524 train_time:600013ms step_avg:99.95ms +stopping_early: wallclock_cap train_time:600013ms step:6003/20000 +peak memory allocated: 20606 MiB reserved: 20656 MiB +ema:applying EMA weights +saved ema_checkpoint.pt +Serialized model: 106832383 bytes +Code size: 81771 bytes +Serialized model int6+zstd: 15294001 bytes +Total submission size int6+zstd: 15375772 bytes +ttt:start lr=1.0 momentum=0.9 epochs=30 freeze_blocks=0 +ttt_epoch:1/30 loss:5.6495 time:16.9s +ttt_epoch:2/30 loss:5.8047 time:33.5s +ttt_epoch:3/30 loss:5.0216 time:50.1s +ttt_epoch:4/30 loss:3.1201 time:66.7s +ttt_epoch:5/30 loss:2.3163 time:83.3s +ttt_epoch:6/30 loss:2.1248 time:99.9s +ttt_epoch:7/30 loss:1.9871 time:116.6s +ttt_epoch:8/30 loss:1.9709 time:133.1s +ttt_epoch:9/30 loss:1.9627 time:149.6s +ttt_epoch:10/30 loss:1.9572 time:166.2s +ttt_epoch:11/30 loss:2.0639 time:182.7s +ttt_epoch:12/30 loss:1.9837 time:199.2s +ttt_epoch:13/30 loss:1.9505 time:215.7s +ttt_epoch:14/30 loss:1.9460 time:232.3s +ttt_epoch:15/30 loss:1.9425 time:248.8s +ttt_epoch:16/30 loss:1.9394 time:265.3s +ttt_epoch:17/30 loss:1.9367 time:281.8s +ttt_epoch:18/30 loss:1.9342 time:298.4s +ttt_epoch:19/30 loss:1.9318 time:314.9s +ttt_epoch:20/30 loss:1.9295 time:331.4s +ttt_epoch:21/30 loss:2.0382 time:347.9s +ttt_epoch:22/30 loss:2.0117 time:364.5s +ttt_epoch:23/30 loss:1.9312 time:381.0s +ttt_epoch:24/30 loss:1.9266 time:397.5s +ttt_epoch:25/30 loss:1.9238 time:414.0s +ttt_epoch:26/30 loss:1.9214 time:430.6s +ttt_epoch:27/30 loss:1.9193 time:447.1s +ttt_epoch:28/30 loss:1.9173 time:463.6s +ttt_epoch:29/30 loss:1.9154 time:480.1s +ttt_epoch:30/30 loss:1.9136 time:496.7s +ttt:done elapsed=496.7s +ttt:elapsed=496.7s +final_int6_roundtrip val_loss:1.9119 val_bpb:1.1323 
eval_time:1924ms +final_int6_roundtrip_exact val_loss:1.91190392 val_bpb:1.13233720 +final_int6_sliding_window val_loss:1.8789 val_bpb:1.1128 stride:64 eval_time:74034ms +final_int6_sliding_window_exact val_loss:1.87889375 val_bpb:1.11278966 diff --git a/records/track_10min_16mb/2026-03-25_TTT_Aggressive_SGD/submission.json b/records/track_10min_16mb/2026-03-25_TTT_Aggressive_SGD/submission.json new file mode 100644 index 0000000000..99906eed59 --- /dev/null +++ b/records/track_10min_16mb/2026-03-25_TTT_Aggressive_SGD/submission.json @@ -0,0 +1,25 @@ +{ + "name": "Fielding Johnston", + "github": "fielding", + "val_bpb": 1.1124, + "val_bpb_seeds": { + "1337": 1.1129, + "42": 1.1128, + "2024": 1.1114 + }, + "val_bpb_std": 0.0008, + "artifact_bytes": 15405733, + "hardware": "8xH100 SXM", + "training_time_seconds": 600, + "eval_time_seconds": 591, + "key_techniques": [ + "aggressive_ttt_sgd_30ep", + "value_embeddings", + "flash_attn_v3", + "ema_0.997", + "late_qat_int6", + "xsa_last_4", + "bigram_hash_6144", + "sliding_window_eval_s64" + ] +} diff --git a/records/track_10min_16mb/2026-03-25_TTT_Aggressive_SGD/train_gpt.py b/records/track_10min_16mb/2026-03-25_TTT_Aggressive_SGD/train_gpt.py new file mode 100644 index 0000000000..9740ffcfa3 --- /dev/null +++ b/records/track_10min_16mb/2026-03-25_TTT_Aggressive_SGD/train_gpt.py @@ -0,0 +1,1918 @@ +""" +train_gpt_submit.py — Submission v2: wider MLP + STE int6 QAT + MTP + seq2048 + NTK RoPE + +fp16 embed + late-K passthrough + sliding window eval. +""" + +from __future__ import annotations + +import copy +import glob +import io +import math +import os +import random +import subprocess +import sys +import time +import uuid +import zlib +from pathlib import Path + +import zstandard +_COMPRESSOR = "zstd" + +import numpy as np +import sentencepiece as spm +import torch +import torch.distributed as dist +import torch.nn.functional as F +from torch import Tensor, nn +from torch.nn.parallel import DistributedDataParallel as DDP + +try: + from flash_attn_interface import flash_attn_func as flash_attn_3_func +except ImportError: + try: + from flash_attn import flash_attn_func as flash_attn_3_func + except ImportError: + flash_attn_3_func = None + +# ----------------------------- +# HYPERPARAMETERS +# ----------------------------- +# Default Simple Baseline run: +# - 9 transformer blocks at width 512 +# - 8 attention heads with 4 KV heads (GQA) and 2x MLP expansion +# - vocab size 1024, sequence length 1024, tied embeddings +# - 524,288 train tokens per step for 20,000 iterations with a ~10 minute cap + +class Hyperparameters: + # Data paths are shard globs produced by the existing preprocessing pipeline. + data_path = os.environ.get("DATA_PATH", "./data/datasets/fineweb10B_sp1024") + train_files = os.path.join(data_path, "fineweb_train_*.bin") + val_files = os.path.join(data_path, "fineweb_val_*.bin") + tokenizer_path = os.environ.get("TOKENIZER_PATH", "./data/tokenizers/fineweb_1024_bpe.model") + run_id = os.environ.get("RUN_ID", str(uuid.uuid4())) + seed = int(os.environ.get("SEED", 1337)) + + # Validation cadence and batch size. Validation always uses the full fineweb_val split. + val_batch_size = int(os.environ.get("VAL_BATCH_SIZE", 524_288)) + val_loss_every = int(os.environ.get("VAL_LOSS_EVERY", 1000)) + train_log_every = int(os.environ.get("TRAIN_LOG_EVERY", 200)) + + # Training length. 
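+    # ITERATIONS is a ceiling rather than a target: MAX_WALLCLOCK_SECONDS stops
+    # training first in practice (the seed logs hit the 600s cap near step 6000
+    # of 20000).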
+ iterations = int(os.environ.get("ITERATIONS", 20000)) + warmdown_iters = int(os.environ.get("WARMDOWN_ITERS", 1200)) + warmup_steps = int(os.environ.get("WARMUP_STEPS", 20)) + train_batch_tokens = int(os.environ.get("TRAIN_BATCH_TOKENS", 786_432)) + train_seq_len = int(os.environ.get("TRAIN_SEQ_LEN", 2048)) + eval_seq_len = int(os.environ.get("EVAL_SEQ_LEN", 2048)) + max_wallclock_seconds = float(os.environ.get("MAX_WALLCLOCK_SECONDS", 600.0)) + qk_gain_init = float(os.environ.get("QK_GAIN_INIT", 1.5)) + + # Model shape. + vocab_size = int(os.environ.get("VOCAB_SIZE", 1024)) + num_layers = int(os.environ.get("NUM_LAYERS", 9)) + num_kv_heads = int(os.environ.get("NUM_KV_HEADS", 4)) + model_dim = int(os.environ.get("MODEL_DIM", 512)) + num_heads = int(os.environ.get("NUM_HEADS", 8)) + mlp_mult = float(os.environ.get("MLP_MULT", 3.0)) + tie_embeddings = bool(int(os.environ.get("TIE_EMBEDDINGS", "1"))) + rope_base = float(os.environ.get("ROPE_BASE", 10000.0)) + logit_softcap = float(os.environ.get("LOGIT_SOFTCAP", 30.0)) + + # Optimizer hyperparameters. + embed_lr = float(os.environ.get("EMBED_LR", 0.6)) + head_lr = float(os.environ.get("HEAD_LR", 0.008)) + tied_embed_lr = float(os.environ.get("TIED_EMBED_LR", 0.05)) + tied_embed_init_std = float(os.environ.get("TIED_EMBED_INIT_STD", 0.005)) + matrix_lr = float(os.environ.get("MATRIX_LR", 0.04)) + scalar_lr = float(os.environ.get("SCALAR_LR", 0.04)) + muon_momentum = float(os.environ.get("MUON_MOMENTUM", 0.95)) + muon_backend_steps = int(os.environ.get("MUON_BACKEND_STEPS", 5)) + muon_momentum_warmup_start = float(os.environ.get("MUON_MOMENTUM_WARMUP_START", 0.85)) + muon_momentum_warmup_steps = int(os.environ.get("MUON_MOMENTUM_WARMUP_STEPS", 500)) + beta1 = float(os.environ.get("BETA1", 0.9)) + beta2 = float(os.environ.get("BETA2", 0.95)) + adam_eps = float(os.environ.get("ADAM_EPS", 1e-8)) + grad_clip_norm = float(os.environ.get("GRAD_CLIP_NORM", 0.3)) + eval_stride = int(os.environ.get("EVAL_STRIDE", 64)) + mtp_num_heads = int(os.environ.get("MTP_NUM_HEADS", 0)) + mtp_loss_weight = float(os.environ.get("MTP_LOSS_WEIGHT", 0.2)) + muon_beta2 = float(os.environ.get("MUON_BETA2", 0.95)) + swa_enabled = bool(int(os.environ.get("SWA_ENABLED", "1"))) + swa_every = int(os.environ.get("SWA_EVERY", 200)) + muon_wd = float(os.environ.get("MUON_WD", 0.02)) + adam_wd = float(os.environ.get("ADAM_WD", 0.01)) + qat_enabled = bool(int(os.environ.get("QAT_ENABLED", "0"))) + xsa_last_n = int(os.environ.get("XSA_LAST_N", 0)) + ema_enabled = bool(int(os.environ.get("EMA_ENABLED", "0"))) + ema_decay = float(os.environ.get("EMA_DECAY", 0.997)) + rope_dims = int(os.environ.get("ROPE_DIMS", 0)) + ln_scale = bool(int(os.environ.get("LN_SCALE", "0"))) + late_qat = bool(int(os.environ.get("LATE_QAT", "0"))) + # TTT (Test-Time Training) + ttt_enabled = bool(int(os.environ.get("TTT_ENABLED", "0"))) + ttt_lr = float(os.environ.get("TTT_LR", 0.002)) + ttt_epochs = int(os.environ.get("TTT_EPOCHS", 3)) + ttt_momentum = float(os.environ.get("TTT_MOMENTUM", 0.9)) + ttt_batch_seqs = int(os.environ.get("TTT_BATCH_SEQS", 32)) + ttt_freeze_blocks = int(os.environ.get("TTT_FREEZE_BLOCKS", 2)) + bigram_vocab_size = int(os.environ.get("BIGRAM_VOCAB_SIZE", 6144)) + bigram_dim = int(os.environ.get("BIGRAM_DIM", 128)) + gptq_enabled = bool(int(os.environ.get("GPTQ_ENABLED", "0"))) + gptq_samples = int(os.environ.get("GPTQ_SAMPLES", 128)) + int5_mlp = bool(int(os.environ.get("INT5_MLP", "0"))) + +# ----------------------------- +# MUON OPTIMIZER +# ----------------------------- +# 
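+# Muon keeps SGD momentum but orthogonalizes every 2-D update before applying
+# it: each gradient is pushed toward the nearest semi-orthogonal matrix by a
+# handful (default 5) of Newton-Schulz iterations in bf16, then rescaled by
+# max(1, rows / cols) ** 0.5 so tall and wide matrices get comparable step sizes.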
+# As borrowed from modded-nanogpt
+# Background on Muon: https://kellerjordan.github.io/posts/muon/
+
+def zeropower_via_newtonschulz5(G: Tensor, steps: int = 10, eps: float = 1e-7) -> Tensor:
+    a, b, c = (3.4445, -4.7750, 2.0315)
+    X = G.bfloat16()
+    X /= X.norm() + eps
+    transposed = G.size(0) > G.size(1)
+    if transposed:
+        X = X.T
+    for _ in range(steps):
+        A = X @ X.T
+        B = b * A + c * A @ A
+        X = a * X + B @ X
+    return X.T if transposed else X
+
+
+class Muon(torch.optim.Optimizer):
+    def __init__(self, params, lr: float, momentum: float, backend_steps: int,
+                 nesterov: bool = True, weight_decay: float = 0.0):
+        super().__init__(
+            params,
+            dict(lr=lr, momentum=momentum, backend_steps=backend_steps,
+                 nesterov=nesterov, weight_decay=weight_decay),
+        )
+
+    @torch.no_grad()
+    def step(self, closure=None):
+        loss = None
+        if closure is not None:
+            with torch.enable_grad():
+                loss = closure()
+
+        distributed = dist.is_available() and dist.is_initialized()
+        world_size = dist.get_world_size() if distributed else 1
+        rank = dist.get_rank() if distributed else 0
+
+        for group in self.param_groups:
+            params = group["params"]
+            if not params:
+                continue
+            lr = group["lr"]
+            momentum = group["momentum"]
+            backend_steps = group["backend_steps"]
+            nesterov = group["nesterov"]
+
+            total_params = sum(int(p.numel()) for p in params)
+            updates_flat = torch.zeros(total_params, device=params[0].device, dtype=torch.bfloat16)
+
+            curr = 0
+            for i, p in enumerate(params):
+                if i % world_size == rank and p.grad is not None:
+                    g = p.grad
+                    state = self.state[p]
+                    if "momentum_buffer" not in state:
+                        state["momentum_buffer"] = torch.zeros_like(g)
+                    buf = state["momentum_buffer"]
+                    buf.mul_(momentum).add_(g)
+                    if nesterov:
+                        g = g.add(buf, alpha=momentum)
+                    g = zeropower_via_newtonschulz5(g, steps=backend_steps)
+                    g *= max(1, g.size(0) / g.size(1)) ** 0.5
+                    updates_flat[curr : curr + p.numel()] = g.reshape(-1)
+                curr += p.numel()
+
+            if distributed:
+                dist.all_reduce(updates_flat, op=dist.ReduceOp.SUM)
+
+            wd = group.get("weight_decay", 0.0)
+            curr = 0
+            for p in params:
+                if wd > 0.0:
+                    p.data.mul_(1.0 - lr * wd)
+                g = updates_flat[curr : curr + p.numel()].view_as(p).to(dtype=p.dtype)
+                p.add_(g, alpha=-lr)
+                curr += p.numel()
+
+        return loss
+
+
+# -----------------------------
+# TOKENIZER-AGNOSTIC EVALUATION SETUP
+# -----------------------------
+#
+# It's common for small models to have a large fraction of their parameters be embeddings, since the 2 * d_model * d_vocab vectors can be gigantic.
+# Instead of locking the tokenizer, we let you bring your own and calculate our validation metrics on the average compression of the validation set.
+# We calculate BPB (bits-per-byte) instead of validation loss, so we need methods to count the number of bytes each token covers under the tokenizer.
+# Note: Submissions that edit the tokenizer will be examined more carefully, since screwing this up might unjustly improve your score.
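+# The arithmetic eval_val performs with these LUTs, in one line:
+#
+#   bits_per_token = val_loss / ln(2)
+#   val_bpb        = bits_per_token * (total_tokens / total_bytes)
+#
+# e.g. seed 1337's final sliding-window result: 1.8791 nats/token
+# -> 1.8791 / 0.6931 ≈ 2.711 bits/token; the logged (val_loss, val_bpb)
+# pair implies ≈ 2.436 bytes/token, so 2.711 / 2.436 ≈ 1.1129 bits/byte,
+# matching val_bpb:1.1129.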
+ +def build_sentencepiece_luts( + sp: spm.SentencePieceProcessor, vocab_size: int, device: torch.device +) -> tuple[Tensor, Tensor, Tensor]: + sp_vocab_size = int(sp.vocab_size()) + table_size = max(sp_vocab_size, vocab_size) + base_bytes_np = np.zeros((table_size,), dtype=np.int16) + has_leading_space_np = np.zeros((table_size,), dtype=np.bool_) + is_boundary_token_np = np.ones((table_size,), dtype=np.bool_) + for token_id in range(sp_vocab_size): + if sp.is_control(token_id) or sp.is_unknown(token_id) or sp.is_unused(token_id): + continue + is_boundary_token_np[token_id] = False + if sp.is_byte(token_id): + base_bytes_np[token_id] = 1 + continue + piece = sp.id_to_piece(token_id) + if piece.startswith("▁"): + has_leading_space_np[token_id] = True + piece = piece[1:] + base_bytes_np[token_id] = len(piece.encode("utf-8")) + return ( + torch.tensor(base_bytes_np, dtype=torch.int16, device=device), + torch.tensor(has_leading_space_np, dtype=torch.bool, device=device), + torch.tensor(is_boundary_token_np, dtype=torch.bool, device=device), + ) + + +def load_validation_tokens(pattern: str, seq_len: int) -> Tensor: + files = [Path(p) for p in sorted(glob.glob(pattern))] + if not files: + raise FileNotFoundError(f"No files found for pattern: {pattern}") + # The export pipeline writes the fixed first-50k-doc validation set to fineweb_val_*. + tokens = torch.cat([load_data_shard(file) for file in files]).contiguous() + usable = ((tokens.numel() - 1) // seq_len) * seq_len + if usable <= 0: + raise ValueError(f"Validation split is too short for TRAIN_SEQ_LEN={seq_len}") + return tokens[: usable + 1] + + +def eval_val( + args: Hyperparameters, + model: nn.Module, + rank: int, + world_size: int, + device: torch.device, + grad_accum_steps: int, + val_tokens: Tensor, + base_bytes_lut: Tensor, + has_leading_space_lut: Tensor, + is_boundary_token_lut: Tensor, + eval_seq_len: int | None = None, +) -> tuple[float, float]: + seq_len = eval_seq_len or args.train_seq_len + local_batch_tokens = args.val_batch_size // (world_size * grad_accum_steps) + if local_batch_tokens < seq_len: + raise ValueError( + "VAL_BATCH_SIZE must provide at least one sequence per rank; " + f"got VAL_BATCH_SIZE={args.val_batch_size}, WORLD_SIZE={world_size}, " + f"GRAD_ACCUM_STEPS={grad_accum_steps}, seq_len={seq_len}" + ) + local_batch_seqs = local_batch_tokens // seq_len + total_seqs = (val_tokens.numel() - 1) // seq_len + seq_start = (total_seqs * rank) // world_size + seq_end = (total_seqs * (rank + 1)) // world_size + val_loss_sum = torch.zeros((), device=device, dtype=torch.float64) + val_token_count = torch.zeros((), device=device, dtype=torch.float64) + val_byte_count = torch.zeros((), device=device, dtype=torch.float64) + + model.eval() + with torch.inference_mode(): + for batch_seq_start in range(seq_start, seq_end, local_batch_seqs): + batch_seq_end = min(batch_seq_start + local_batch_seqs, seq_end) + raw_start = batch_seq_start * seq_len + raw_end = batch_seq_end * seq_len + 1 + local = val_tokens[raw_start:raw_end].to(device=device, dtype=torch.int64, non_blocking=True) + x = local[:-1].reshape(-1, seq_len) + y = local[1:].reshape(-1, seq_len) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + batch_loss = model(x, y).detach() + batch_token_count = float(y.numel()) + val_loss_sum += batch_loss.to(torch.float64) * batch_token_count + val_token_count += batch_token_count + prev_ids = x.reshape(-1) + tgt_ids = y.reshape(-1) + token_bytes = base_bytes_lut[tgt_ids].to(dtype=torch.int16) + 
            token_bytes += (has_leading_space_lut[tgt_ids] & ~is_boundary_token_lut[prev_ids]).to(dtype=torch.int16)
+            val_byte_count += token_bytes.to(torch.float64).sum()
+
+    if dist.is_available() and dist.is_initialized():
+        dist.all_reduce(val_loss_sum, op=dist.ReduceOp.SUM)
+        dist.all_reduce(val_token_count, op=dist.ReduceOp.SUM)
+        dist.all_reduce(val_byte_count, op=dist.ReduceOp.SUM)
+
+    val_loss = val_loss_sum / val_token_count
+    bits_per_token = val_loss.item() / math.log(2.0)
+    tokens_per_byte = val_token_count.item() / val_byte_count.item()
+    model.train()
+    return float(val_loss.item()), float(bits_per_token * tokens_per_byte)
+
+# -----------------------------
+# POST-TRAINING QUANTIZATION
+# -----------------------------
+#
+# It's silly to export our model, which is trained in bf16 and fp32, at that same precision.
+# Instead, we get approximately the same model (with a small hit) by quantizing the model to int8 & zstd compressing.
+# We can then decompress the model and run in higher precision for evaluation, after coming in under the size limit.
+
+CONTROL_TENSOR_NAME_PATTERNS = tuple(
+    pattern
+    for pattern in os.environ.get(
+        "CONTROL_TENSOR_NAME_PATTERNS",
+        "attn_scale,attn_scales,mlp_scale,mlp_scales,resid_mix,resid_mixes,q_gain,skip_weight,skip_weights,smear",
+    ).split(",")
+    if pattern
+)
+INT8_KEEP_FLOAT_FP32_NAME_PATTERNS = tuple(
+    pattern
+    for pattern in os.environ.get(
+        "INT8_KEEP_FLOAT_FP32_NAME_PATTERNS",
+        ",".join(CONTROL_TENSOR_NAME_PATTERNS),
+    ).split(",")
+    if pattern
+)
+INT8_KEEP_FLOAT_MAX_NUMEL = 65_536
+INT8_KEEP_FLOAT_STORE_DTYPE = torch.float16
+INT8_PER_ROW_SCALE_DTYPE = torch.float16
+INT8_CLIP_PERCENTILE = 99.99984
+INT8_CLIP_Q = INT8_CLIP_PERCENTILE / 100.0
+
+def tensor_nbytes(t: Tensor) -> int:
+    return int(t.numel()) * int(t.element_size())
+
+def keep_float_tensor(name: str, t: Tensor, passthrough_orig_dtypes: dict[str, str]) -> Tensor:
+    if any(pattern in name for pattern in INT8_KEEP_FLOAT_FP32_NAME_PATTERNS):
+        return t.float().contiguous()
+    if t.dtype in {torch.float32, torch.bfloat16}:
+        passthrough_orig_dtypes[name] = str(t.dtype).removeprefix("torch.")
+        return t.to(dtype=INT8_KEEP_FLOAT_STORE_DTYPE).contiguous()
+    return t
+
+def quantize_float_tensor(t: Tensor) -> tuple[Tensor, Tensor]:
+    t32 = t.float()
+    if t32.ndim == 2:
+        # Matrices get one scale per row, which usually tracks output-channel
+        # ranges much better than a single tensor-wide scale.
+        clip_abs = (
+            torch.quantile(t32.abs(), INT8_CLIP_Q, dim=1)
+            if t32.numel()
+            else torch.empty((t32.shape[0],), dtype=torch.float32)
+        )
+        clipped = torch.maximum(torch.minimum(t32, clip_abs[:, None]), -clip_abs[:, None])
+        scale = (clip_abs / 127.0).clamp_min(1.0 / 127.0)
+        q = torch.clamp(torch.round(clipped / scale[:, None]), -127, 127).to(torch.int8).contiguous()
+        return q, scale.to(dtype=INT8_PER_ROW_SCALE_DTYPE).contiguous()
+
+    # Vectors / scalars use a simpler per-tensor scale.
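+    # e.g. a vector with clip_abs = 0.254 gets scale = 0.254 / 127 = 0.002:
+    # entries snap to a 0.002 grid, for a worst-case rounding error of 0.001
+    # (values beyond +/-0.254 are clipped outright).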
+ clip_abs = float(torch.quantile(t32.abs().flatten(), INT8_CLIP_Q).item()) if t32.numel() else 0.0 + scale = torch.tensor(clip_abs / 127.0 if clip_abs > 0 else 1.0, dtype=torch.float32) + q = torch.clamp(torch.round(torch.clamp(t32, -clip_abs, clip_abs) / scale), -127, 127).to(torch.int8).contiguous() + return q, scale + +def quantize_state_dict_int8(state_dict: dict[str, Tensor]): + # Single supported clean-script export format: + # - per-row int8 for 2D float tensors + # - per-tensor int8 for other float tensors + # - exact passthrough for non-floats + # - passthrough for small float tensors, stored as fp16 to save bytes + quantized: dict[str, Tensor] = {} + scales: dict[str, Tensor] = {} + dtypes: dict[str, str] = {} + passthrough: dict[str, Tensor] = {} + passthrough_orig_dtypes: dict[str, str] = {} + qmeta: dict[str, dict[str, object]] = {} + stats = dict.fromkeys( + ("param_count", "num_tensors", "num_float_tensors", "num_nonfloat_tensors", "baseline_tensor_bytes", "int8_payload_bytes"), + 0, + ) + + for name, tensor in state_dict.items(): + t = tensor.detach().to("cpu").contiguous() + stats["param_count"] += int(t.numel()) + stats["num_tensors"] += 1 + stats["baseline_tensor_bytes"] += tensor_nbytes(t) + + if not t.is_floating_point(): + stats["num_nonfloat_tensors"] += 1 + passthrough[name] = t + stats["int8_payload_bytes"] += tensor_nbytes(t) + continue + + # Small float tensors are cheap enough to keep directly. We still downcast + # fp32/bf16 passthrough tensors to fp16 so metadata does not dominate size. + if t.numel() <= INT8_KEEP_FLOAT_MAX_NUMEL: + kept = keep_float_tensor(name, t, passthrough_orig_dtypes) + passthrough[name] = kept + stats["int8_payload_bytes"] += tensor_nbytes(kept) + continue + + stats["num_float_tensors"] += 1 + q, s = quantize_float_tensor(t) + if s.ndim > 0: + qmeta[name] = {"scheme": "per_row", "axis": 0} + quantized[name] = q + scales[name] = s + dtypes[name] = str(t.dtype).removeprefix("torch.") + stats["int8_payload_bytes"] += tensor_nbytes(q) + tensor_nbytes(s) + + obj: dict[str, object] = { + "__quant_format__": "int8_clean_per_row_v1", + "quantized": quantized, + "scales": scales, + "dtypes": dtypes, + "passthrough": passthrough, + } + if qmeta: + obj["qmeta"] = qmeta + if passthrough_orig_dtypes: + obj["passthrough_orig_dtypes"] = passthrough_orig_dtypes + return obj, stats + +def dequantize_state_dict_int8(obj: dict[str, object]) -> dict[str, Tensor]: + out: dict[str, Tensor] = {} + qmeta = obj.get("qmeta", {}) + passthrough_orig_dtypes = obj.get("passthrough_orig_dtypes", {}) + for name, q in obj["quantized"].items(): + dtype = getattr(torch, obj["dtypes"][name]) + s = obj["scales"][name] + if qmeta.get(name, {}).get("scheme") == "per_row" or s.ndim > 0: + s = s.to(dtype=torch.float32) + # Broadcast the saved row scale back across trailing dimensions. + out[name] = (q.float() * s.view(q.shape[0], *([1] * (q.ndim - 1)))).to(dtype=dtype).contiguous() + else: + scale = float(s.item()) + out[name] = (q.float() * scale).to(dtype=dtype).contiguous() + for name, t in obj["passthrough"].items(): + # Restore small tensors, undoing the temporary fp16 storage cast if needed. 
+        out_t = t.detach().to("cpu").contiguous()
+        orig_dtype = passthrough_orig_dtypes.get(name)
+        if isinstance(orig_dtype, str):
+            out_t = out_t.to(dtype=getattr(torch, orig_dtype)).contiguous()
+        out[name] = out_t
+    return out
+
+
+# -----------------------------
+# DATA LOADING
+# -----------------------------
+
+def load_data_shard(file: Path) -> Tensor:
+    # Shards use the modded-nanogpt layout: a 256-int32 header whose third
+    # entry is the token count, followed by little-endian uint16 token ids.
+    header_bytes = 256 * np.dtype("<i4").itemsize
+    header = np.fromfile(file, dtype="<i4", count=256)
+    num_tokens = int(header[2])
+    with open(file, "rb") as f:
+        f.seek(header_bytes)
+        tokens = np.frombuffer(f.read(num_tokens * 2), dtype="<u2")
+    return torch.from_numpy(tokens.astype(np.int32))
+
+
+class TokenStream:
+    # Sequential reader over the shards; take() pulls a contiguous run of
+    # tokens, wrapping around to the first shard when the last one is spent.
+    def __init__(self, pattern: str):
+        self.files = [Path(p) for p in sorted(glob.glob(pattern))]
+        if not self.files:
+            raise FileNotFoundError(f"No files found for pattern: {pattern}")
+        self.file_idx = 0
+        self.tokens = load_data_shard(self.files[self.file_idx])
+        self.pos = 0
+
+    def _advance_file(self) -> None:
+        self.file_idx = (self.file_idx + 1) % len(self.files)
+        self.tokens = load_data_shard(self.files[self.file_idx])
+        self.pos = 0
+
+    def take(self, n: int) -> Tensor:
+        chunks: list[Tensor] = []
+        remaining = n
+        while remaining > 0:
+            avail = self.tokens.numel() - self.pos
+            if avail <= 0:
+                self._advance_file()
+                continue
+            k = min(remaining, avail)
+            chunks.append(self.tokens[self.pos : self.pos + k])
+            self.pos += k
+            remaining -= k
+        return chunks[0] if len(chunks) == 1 else torch.cat(chunks)
+
+
+class DistributedTokenLoader:
+    # Each call consumes a contiguous chunk from the shared token stream, then slices out
+    # one disjoint span per rank. The extra "+1" token lets us build (x, y) by shifting.
+    def __init__(self, pattern: str, rank: int, world_size: int, device: torch.device):
+        self.rank = rank
+        self.world_size = world_size
+        self.device = device
+        self.stream = TokenStream(pattern)
+
+    def next_batch(self, global_tokens: int, seq_len: int, grad_accum_steps: int) -> tuple[Tensor, Tensor]:
+        local_tokens = global_tokens // (self.world_size * grad_accum_steps)
+        per_rank_span = local_tokens + 1
+        chunk = self.stream.take(per_rank_span * self.world_size)
+        start = self.rank * per_rank_span
+        local = chunk[start : start + per_rank_span].to(dtype=torch.int64)
+        x = local[:-1].reshape(-1, seq_len)
+        y = local[1:].reshape(-1, seq_len)
+        return x.to(self.device, non_blocking=True), y.to(self.device, non_blocking=True)
+
+# -----------------------------
+# TRANSFORMER MODULES
+# -----------------------------
+
+class RMSNorm(nn.Module):
+    def __init__(self, eps: float | None = None):
+        super().__init__()
+        self.eps = eps
+
+    def forward(self, x: Tensor) -> Tensor:
+        return F.rms_norm(x, (x.size(-1),), eps=self.eps)
+
+
+class CastedLinear(nn.Linear):
+    _qat_enabled: bool = False
+
+    def forward(self, x: Tensor) -> Tensor:
+        w = self.weight.to(x.dtype)
+        if CastedLinear._qat_enabled and self.training and w.ndim == 2:
+            with torch.no_grad():
+                w32 = self.weight.float()
+                row_max = w32.abs().amax(dim=1)
+                scale = (row_max / 31.0).clamp_min(1e-4)
+                w_q = (torch.clamp(torch.round(w32 / scale[:, None]), -32, 31) * scale[:, None]).to(x.dtype)
+            # Straight-through estimator: forward sees the int6 weights, the
+            # gradient flows through the unquantized w.
+            w = w + (w_q - w).detach()
+        bias = self.bias.to(x.dtype) if self.bias is not None else None
+        return F.linear(x, w, bias)
+
+
+def restore_low_dim_params_to_fp32(module: nn.Module) -> None:
+    # Keep small/control parameters in fp32 even when the model body runs in bf16.
+    with torch.no_grad():
+        for name, param in module.named_parameters():
+            if (param.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS)) and param.dtype != torch.float32:
+                param.data = param.data.float()
+
+
+class Rotary(nn.Module):
+    # NTK-aware RoPE: auto-scales base frequency when seq_len exceeds train_seq_len.
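+    # e.g. with head_dim = 64 and full-width RoPE (rd = 64), evaluating at 2x
+    # the trained context rescales the base to 10000 * 2 ** (64 / 62) ≈ 20452,
+    # so rotation angles at the extended positions stay inside the trained range.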
+ def __init__(self, dim: int, base: float = 10000.0, train_seq_len: int = 1024, rope_dims: int = 0): + super().__init__() + self.rope_dims = rope_dims if rope_dims > 0 else dim + self.dim = dim + self.base = base + self.train_seq_len = train_seq_len + rd = self.rope_dims + inv_freq = 1.0 / (base ** (torch.arange(0, rd, 2, dtype=torch.float32) / rd)) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self._seq_len_cached = 0 + self._cos_cached: Tensor | None = None + self._sin_cached: Tensor | None = None + + def forward(self, seq_len: int, device: torch.device, dtype: torch.dtype) -> tuple[Tensor, Tensor]: + if ( + self._cos_cached is None + or self._sin_cached is None + or self._seq_len_cached != seq_len + or self._cos_cached.device != device + ): + rd = self.rope_dims + if seq_len > self.train_seq_len: + scale = seq_len / self.train_seq_len + new_base = self.base * (scale ** (rd / (rd - 2))) + inv_freq = 1.0 / (new_base ** (torch.arange(0, rd, 2, dtype=torch.float32, device=device) / rd)) + else: + inv_freq = self.inv_freq.to(device) + t = torch.arange(seq_len, device=device, dtype=inv_freq.dtype) + freqs = torch.outer(t, inv_freq) + self._cos_cached = freqs.cos()[None, :, None, :] + self._sin_cached = freqs.sin()[None, :, None, :] + self._seq_len_cached = seq_len + return self._cos_cached.to(dtype=dtype), self._sin_cached.to(dtype=dtype) + + +def apply_rotary_emb(x: Tensor, cos: Tensor, sin: Tensor) -> Tensor: + rd = cos.size(-1) * 2 + if rd < x.size(-1): + x_rope, x_pass = x[..., :rd], x[..., rd:] + half = rd // 2 + x1, x2 = x_rope[..., :half], x_rope[..., half:] + x_rot = torch.cat((x1 * cos + x2 * sin, x1 * (-sin) + x2 * cos), dim=-1) + return torch.cat((x_rot, x_pass), dim=-1) + half = x.size(-1) // 2 + x1, x2 = x[..., :half], x[..., half:] + return torch.cat((x1 * cos + x2 * sin, x1 * (-sin) + x2 * cos), dim=-1) + + +class CausalSelfAttention(nn.Module): + def __init__( + self, + dim: int, + num_heads: int, + num_kv_heads: int, + rope_base: float, + qk_gain_init: float, + rope_dims: int = 0, + ): + super().__init__() + if dim % num_heads != 0: + raise ValueError("model_dim must be divisible by num_heads") + if num_heads % num_kv_heads != 0: + raise ValueError("num_heads must be divisible by num_kv_heads") + self.num_heads = num_heads + self.num_kv_heads = num_kv_heads + self.head_dim = dim // num_heads + if self.head_dim % 2 != 0: + raise ValueError("head_dim must be even for RoPE") + kv_dim = self.num_kv_heads * self.head_dim + self.c_q = CastedLinear(dim, dim, bias=False) + self.c_k = CastedLinear(dim, kv_dim, bias=False) + self.c_v = CastedLinear(dim, kv_dim, bias=False) + self.proj = CastedLinear(dim, dim, bias=False) + self.proj._zero_init = True + self.q_gain = nn.Parameter(torch.full((num_heads,), qk_gain_init, dtype=torch.float32)) + self.rope_dims = rope_dims + self.rotary = Rotary(self.head_dim, base=rope_base, train_seq_len=1024, rope_dims=rope_dims) + self.use_xsa = False + + def _xsa_efficient(self, y: Tensor, v: Tensor) -> Tensor: + """Subtract self-value projection via GQA-aware reshape (no repeat_interleave).""" + B, T, H, D = y.shape + Hkv = v.size(-2) + group = H // Hkv + y_g = y.reshape(B, T, Hkv, group, D) + vn = F.normalize(v, dim=-1).unsqueeze(-2) + proj = (y_g * vn).sum(dim=-1, keepdim=True) * vn + return (y_g - proj).reshape(B, T, H, D) + + def forward(self, x: Tensor) -> Tensor: + bsz, seqlen, dim = x.shape + q = self.c_q(x).reshape(bsz, seqlen, self.num_heads, self.head_dim) + k = self.c_k(x).reshape(bsz, seqlen, self.num_kv_heads, 
self.head_dim) + v = self.c_v(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim) + q = F.rms_norm(q, (q.size(-1),)) + k = F.rms_norm(k, (k.size(-1),)) + cos, sin = self.rotary(seqlen, x.device, q.dtype) + q = apply_rotary_emb(q, cos, sin) + k = apply_rotary_emb(k, cos, sin) + q = q * self.q_gain.to(dtype=q.dtype)[None, None, :, None] + if flash_attn_3_func is not None: + y = flash_attn_3_func(q, k, v, causal=True) + else: + qt, kt, vt = q.transpose(1,2), k.transpose(1,2), v.transpose(1,2) + if kt.shape[1] != qt.shape[1]: + nr = qt.shape[1]//kt.shape[1] + kt = kt[:,:,None,:,:].expand(-1,-1,nr,-1,-1).reshape(kt.shape[0],-1,kt.shape[-2],kt.shape[-1]) + vt = vt[:,:,None,:,:].expand(-1,-1,nr,-1,-1).reshape(vt.shape[0],-1,vt.shape[-2],vt.shape[-1]) + y = torch.nn.functional.scaled_dot_product_attention(qt,kt,vt,is_causal=True).transpose(1,2) + if self.use_xsa: + y = self._xsa_efficient(y, v) + y = y.reshape(bsz, seqlen, dim) + return self.proj(y) + + +class SmearGate(nn.Module): + def __init__(self, dim: int): + super().__init__() + self.gate = nn.Parameter(torch.zeros(dim, dtype=torch.float32)) + + def forward(self, x: Tensor) -> Tensor: + g = torch.sigmoid(self.gate.to(dtype=x.dtype))[None, None, :] + x_prev = torch.cat([torch.zeros_like(x[:, :1]), x[:, :-1]], dim=1) + return (1 - g) * x + g * x_prev + + +class BigramHashEmbedding(nn.Module): + def __init__(self, bigram_vocab_size: int, bigram_dim: int, model_dim: int): + super().__init__() + self.bigram_vocab_size = bigram_vocab_size + self.embed = nn.Embedding(bigram_vocab_size, bigram_dim) + nn.init.zeros_(self.embed.weight) + self.proj = CastedLinear(bigram_dim, model_dim, bias=False) if bigram_dim != model_dim else None + if self.proj is not None: + nn.init.zeros_(self.proj.weight) + self.scale = nn.Parameter(torch.tensor(0.05, dtype=torch.float32)) + + def bigram_hash(self, tokens: Tensor) -> Tensor: + t = tokens.to(torch.int32) + mod = self.bigram_vocab_size - 1 + out = torch.empty_like(t) + out[..., 0] = mod + out[..., 1:] = torch.bitwise_xor(36313 * t[..., 1:], 27191 * t[..., :-1]) % mod + return out.long() + + def forward(self, token_ids: Tensor) -> Tensor: + h = self.embed(self.bigram_hash(token_ids)) + if self.proj is not None: + h = self.proj(h) + return h * self.scale.to(dtype=h.dtype) + + +class MLP(nn.Module): + def __init__(self, dim: int, mlp_mult: int): + super().__init__() + hidden = int(mlp_mult * dim) + self.fc = CastedLinear(dim, hidden, bias=False) + self.proj = CastedLinear(hidden, dim, bias=False) + self.proj._zero_init = True + + def forward(self, x: Tensor) -> Tensor: + x = torch.relu(self.fc(x)) + return self.proj(x.square()) + + +class Block(nn.Module): + def __init__( + self, + dim: int, + num_heads: int, + num_kv_heads: int, + mlp_mult: int, + rope_base: float, + qk_gain_init: float, + rope_dims: int = 0, + layer_idx: int = 0, + ln_scale: bool = False, + ): + super().__init__() + self.attn_norm = RMSNorm() + self.mlp_norm = RMSNorm() + self.attn = CausalSelfAttention(dim, num_heads, num_kv_heads, rope_base, qk_gain_init, rope_dims=rope_dims) + self.mlp = MLP(dim, mlp_mult) + self.attn_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.mlp_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.resid_mix = nn.Parameter(torch.stack((torch.ones(dim), torch.zeros(dim))).float()) + self.ln_scale_factor = 1.0 / math.sqrt(layer_idx + 1) if ln_scale else 1.0 + + def forward(self, x: Tensor, x0: Tensor) -> Tensor: + mix = self.resid_mix.to(dtype=x.dtype) + x = mix[0][None, None, :] * x + 
mix[1][None, None, :] * x0 + s = self.ln_scale_factor + attn_out = self.attn(self.attn_norm(x) * s) + x = x + self.attn_scale.to(dtype=x.dtype)[None, None, :] * attn_out + x = x + self.mlp_scale.to(dtype=x.dtype)[None, None, :] * self.mlp(self.mlp_norm(x) * s) + return x + + +class GPT(nn.Module): + def __init__( + self, + vocab_size: int, + num_layers: int, + model_dim: int, + num_heads: int, + num_kv_heads: int, + mlp_mult: int, + tie_embeddings: bool, + tied_embed_init_std: float, + logit_softcap: float, + rope_base: float, + qk_gain_init: float, + mtp_num_heads: int = 0, + mtp_loss_weight: float = 0.1, + bigram_vocab_size: int = 0, + bigram_dim: int = 128, + xsa_last_n: int = 0, + rope_dims: int = 0, + ln_scale: bool = False, + ): + super().__init__() + if logit_softcap <= 0.0: + raise ValueError(f"logit_softcap must be positive, got {logit_softcap}") + self.tie_embeddings = tie_embeddings + self.tied_embed_init_std = tied_embed_init_std + self.logit_softcap = logit_softcap + self.mtp_num_heads = mtp_num_heads + self.mtp_loss_weight = mtp_loss_weight + self.tok_emb = nn.Embedding(vocab_size, model_dim) + self.bigram = BigramHashEmbedding(bigram_vocab_size, bigram_dim, model_dim) if bigram_vocab_size > 0 else None + self.smear = SmearGate(model_dim) + self.num_encoder_layers = num_layers // 2 + self.num_decoder_layers = num_layers - self.num_encoder_layers + self.num_skip_weights = min(self.num_encoder_layers, self.num_decoder_layers) + self.skip_weights = nn.Parameter(torch.ones(self.num_skip_weights, model_dim, dtype=torch.float32)) + self.blocks = nn.ModuleList( + [ + Block( + model_dim, + num_heads, + num_kv_heads, + mlp_mult, + rope_base, + qk_gain_init, + rope_dims=rope_dims, + layer_idx=i, + ln_scale=ln_scale, + ) + for i in range(num_layers) + ] + ) + self.final_norm = RMSNorm() + self.lm_head = None if tie_embeddings else CastedLinear(model_dim, vocab_size, bias=False) + if self.lm_head is not None: + self.lm_head._zero_init = True + self.mtp_heads = nn.ModuleList( + [CastedLinear(model_dim, vocab_size, bias=False) for _ in range(mtp_num_heads)] + ) + for head in self.mtp_heads: + head._zero_init = True + if xsa_last_n > 0: + for i in range(max(0, num_layers - xsa_last_n), num_layers): + self.blocks[i].attn.use_xsa = True + self._init_weights() + + def _init_weights(self) -> None: + if self.tie_embeddings: + nn.init.normal_(self.tok_emb.weight, mean=0.0, std=self.tied_embed_init_std) + num_layers = len(self.blocks) + for name, module in self.named_modules(): + if isinstance(module, nn.Linear): + if getattr(module, "_zero_init", False): + nn.init.zeros_(module.weight) + elif module.weight.ndim == 2 and module.weight.shape[0] >= 64 and module.weight.shape[1] >= 64: + nn.init.orthogonal_(module.weight, gain=1.0) + if ".proj." 
in name or name.endswith(".proj"): + with torch.no_grad(): + module.weight.mul_(1.0 / math.sqrt(2 * num_layers)) + + def forward(self, input_ids: Tensor, target_ids: Tensor) -> Tensor: + x = self.tok_emb(input_ids) + if self.bigram is not None: + x = x + self.bigram(input_ids) + x = F.rms_norm(x, (x.size(-1),)) + x = self.smear(x) + x0 = x + skips: list[Tensor] = [] + + for i in range(self.num_encoder_layers): + x = self.blocks[i](x, x0) + skips.append(x) + for i in range(self.num_decoder_layers): + if skips: + x = x + self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips.pop() + x = self.blocks[self.num_encoder_layers + i](x, x0) + + x = self.final_norm(x) + x_flat = x.reshape(-1, x.size(-1)) + targets = target_ids.reshape(-1) + if self.tie_embeddings: + logits_proj = F.linear(x_flat, self.tok_emb.weight) + else: + if self.lm_head is None: + raise RuntimeError("lm_head is required when tie_embeddings=False") + logits_proj = self.lm_head(x_flat) + logits = self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap) + main_loss = F.cross_entropy(logits.float(), targets, reduction="mean") + + if self.training and self.mtp_num_heads > 0 and self.mtp_loss_weight > 0.0: + _, seqlen, dim = x.shape + mtp_loss_sum = x.new_zeros(()) + mtp_loss_count = 0 + for k, mtp_head in enumerate(self.mtp_heads): + valid_t = seqlen - (k + 1) + if valid_t <= 0: + continue + mtp_hidden = x[:, :valid_t, :].reshape(-1, dim) + mtp_targets = target_ids[:, k + 1 :].reshape(-1) + mtp_logits_proj = mtp_head(mtp_hidden) + mtp_logits = self.logit_softcap * torch.tanh(mtp_logits_proj / self.logit_softcap) + mtp_loss_sum = mtp_loss_sum + F.cross_entropy(mtp_logits.float(), mtp_targets, reduction="mean") + mtp_loss_count += 1 + if mtp_loss_count > 0: + main_loss = main_loss + self.mtp_loss_weight * (mtp_loss_sum / mtp_loss_count) + + return main_loss + + def forward_logits(self, input_ids: Tensor) -> Tensor: + """Return logits (bsz, seq_len, vocab) without computing loss.""" + x = self.tok_emb(input_ids) + if self.bigram is not None: + x = x + self.bigram(input_ids) + x = F.rms_norm(x, (x.size(-1),)) + x = self.smear(x) + x0 = x + skips: list[Tensor] = [] + for i in range(self.num_encoder_layers): + x = self.blocks[i](x, x0) + skips.append(x) + for i in range(self.num_decoder_layers): + if skips: + x = x + self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips.pop() + x = self.blocks[self.num_encoder_layers + i](x, x0) + x = self.final_norm(x) + if self.tie_embeddings: + logits_proj = F.linear(x, self.tok_emb.weight) + else: + logits_proj = self.lm_head(x) + return self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap) + + +# ----------------------------- +# SLIDING WINDOW EVALUATION +# ----------------------------- + +def eval_val_sliding( + args: Hyperparameters, + base_model: nn.Module, + rank: int, + world_size: int, + device: torch.device, + val_tokens: Tensor, + base_bytes_lut: Tensor, + has_leading_space_lut: Tensor, + is_boundary_token_lut: Tensor, + stride: int, + batch_seqs: int = 32, + eval_seq_len: int | None = None, +) -> tuple[float, float]: + """Sliding window evaluation: each token scored with maximum context.""" + seq_len = eval_seq_len or args.train_seq_len + total_tokens = val_tokens.numel() - 1 + + window_starts = [ws for ws in range(0, total_tokens, stride) + if min(ws + seq_len, total_tokens) - ws >= 1] + total_windows = len(window_starts) + + my_s = (total_windows * rank) // world_size + my_e = (total_windows * (rank + 1)) // world_size + my_windows = 
window_starts[my_s:my_e] + + loss_sum = torch.zeros((), device=device, dtype=torch.float64) + token_count = torch.zeros((), device=device, dtype=torch.float64) + byte_count = torch.zeros((), device=device, dtype=torch.float64) + + base_model.eval() + compiled_logits = torch.compile(base_model.forward_logits, dynamic=False, fullgraph=True) + + with torch.inference_mode(): + for bi in range(0, len(my_windows), batch_seqs): + batch_ws = my_windows[bi:bi + batch_seqs] + bsz = len(batch_ws) + + x_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) + y_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) + wlens: list[int] = [] + + for i, ws in enumerate(batch_ws): + end = min(ws + seq_len, total_tokens) + wlen = end - ws + wlens.append(wlen) + chunk = val_tokens[ws:end + 1].to(dtype=torch.int64, device=device) + x_batch[i, :wlen] = chunk[:-1] + y_batch[i, :wlen] = chunk[1:] + + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + logits = compiled_logits(x_batch) + + nll = F.cross_entropy( + logits.reshape(-1, logits.size(-1)).float(), + y_batch.reshape(-1), + reduction="none", + ).reshape(bsz, seq_len) + + for i, ws in enumerate(batch_ws): + wlen = wlens[i] + s = 0 if ws == 0 else max(wlen - stride, 0) + scored_nll = nll[i, s:wlen].to(torch.float64) + loss_sum += scored_nll.sum() + token_count += float(wlen - s) + tgt = y_batch[i, s:wlen] + prev = x_batch[i, s:wlen] + tb = base_bytes_lut[tgt].to(torch.float64) + tb += (has_leading_space_lut[tgt] & ~is_boundary_token_lut[prev]).to(torch.float64) + byte_count += tb.sum() + + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(loss_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(token_count, op=dist.ReduceOp.SUM) + dist.all_reduce(byte_count, op=dist.ReduceOp.SUM) + + val_loss = (loss_sum / token_count).item() + bits_per_token = val_loss / math.log(2.0) + tokens_per_byte = token_count.item() / byte_count.item() + base_model.train() + return val_loss, bits_per_token * tokens_per_byte + + +# ----------------------------- +# GPTQ QUANTIZATION +# ----------------------------- + +def gptq_quantize_int6(weight: Tensor, hessian: Tensor, block_size: int = 128, percdamp: float = 0.01) -> tuple[Tensor, Tensor]: + """GPTQ quantization for a single 2D weight matrix. 
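+ 
+ Simplified GPTQ variant: per-row int6 scales are fixed up front from the
+ original weights, and the OBS error-compensation update applies rows of the
+ damped inverse Hessian directly, block by block, rather than the
+ Cholesky-factor formulation of the original GPTQ algorithm.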
+ 
+ Args:
+ weight: (out_features, in_features)
+ hessian: (in_features, in_features); H = X^T X / n_samples
+ block_size: columns processed at once
+ percdamp: Hessian damping as fraction of diagonal mean
+ 
+ Returns:
+ quantized (int8 storing int6 [-32,31]), per-row scales (fp16)
+ """
+ W = weight.float().clone()
+ n_rows, n_cols = W.shape
+ 
+ # Per-row scale (computed from original weights)
+ row_max = W.abs().amax(dim=1)
+ scale = (row_max / 31.0).clamp_min(1.0 / 31.0)
+ 
+ # Damp the Hessian diagonal
+ H = hessian.float().clone()
+ damp = percdamp * H.diag().mean()
+ H.diagonal().add_(damp)
+ 
+ # Invert the damped Hessian via Cholesky decomposition
+ try:
+ L = torch.linalg.cholesky(H)
+ Hinv = torch.cholesky_inverse(L)
+ except torch.linalg.LinAlgError:
+ # Fallback: add more damping
+ H.diagonal().add_(damp * 10)
+ L = torch.linalg.cholesky(H)
+ Hinv = torch.cholesky_inverse(L)
+ 
+ Q = torch.zeros_like(W)
+ 
+ for col_start in range(0, n_cols, block_size):
+ col_end = min(col_start + block_size, n_cols)
+ block_cols = col_end - col_start
+ 
+ W_block_orig = W[:, col_start:col_end].clone()
+ W_block = W_block_orig.clone()
+ Hinv_block = Hinv[col_start:col_end, col_start:col_end]
+ 
+ for j in range(block_cols):
+ w_col = W_block[:, j]
+ h_inv_jj = Hinv_block[j, j]
+ 
+ # Quantize column
+ q_col = torch.clamp(torch.round(w_col / scale), -32, 31)
+ Q[:, col_start + j] = q_col
+ 
+ # Dequantize to measure the rounding error
+ dq_col = q_col * scale
+ 
+ # OBS error compensation: spread this column's quantization error
+ # over the remaining columns in the block
+ err = (w_col - dq_col) / h_inv_jj
+ W_block[:, j:] -= err.unsqueeze(1) * Hinv_block[j, j:].unsqueeze(0)
+ 
+ # Propagate block error to remaining columns using original weights
+ if col_end < n_cols:
+ block_err = W_block_orig - (Q[:, col_start:col_end] * scale.unsqueeze(1))
+ W[:, col_end:] -= block_err @ Hinv[col_start:col_end, col_end:]
+ 
+ return Q.to(torch.int8), scale.to(torch.float16)
+ 
+ 
+ def collect_gptq_calibration(model: nn.Module, val_tokens: Tensor, device: torch.device,
+ seq_len: int, n_samples: int = 128) -> dict[str, Tensor]:
+ """Run calibration sequences through the model, collecting Hessians for each Linear layer."""
+ hessians: dict[str, Tensor] = {}
+ sample_counts: dict[str, int] = {}
+ hooks = []
+ 
+ def make_hook(name):
+ def hook_fn(module, inp, out):
+ x = inp[0].detach().float()
+ if x.ndim == 3:
+ x = x.reshape(-1, x.shape[-1])
+ if name not in hessians:
+ hessians[name] = torch.zeros(x.shape[1], x.shape[1], device=x.device, dtype=torch.float64)
+ sample_counts[name] = 0
+ hessians[name].add_(x.T @ x)
+ sample_counts[name] += x.shape[0]
+ return hook_fn
+ 
+ # Register forward hooks on every CastedLinear / nn.Linear whose 2D weight
+ # is above the 65536-element quantization cutoff
+ for name, module in model.named_modules():
+ if isinstance(module, nn.Linear) and module.weight.ndim == 2 and module.weight.numel() > 65536:
+ hooks.append(module.register_forward_hook(make_hook(name)))
+ 
+ # Run calibration sequences under no_grad (not inference_mode) so the hooks'
+ # in-place Hessian accumulation is permitted
+ model.eval()
+ total_tokens = val_tokens.numel() - 1
+ n_seqs = min(n_samples, total_tokens // seq_len)
+ with torch.no_grad():
+ for seq_idx in range(n_seqs):
+ i = seq_idx * seq_len
+ x = val_tokens[i:i + seq_len].unsqueeze(0).to(device=device, dtype=torch.int64)
+ with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
+ model.forward_logits(x)
+ if (seq_idx + 1) % 32 == 0:
+ print(f" gptq:calibration {seq_idx+1}/{n_seqs}", flush=True)
+ 
+ # Remove hooks
+ for h in hooks:
+ h.remove()
+ 
+ # Normalize Hessians by accumulated token count (H = X^T X / n_samples)
+ for name in hessians:
+ if sample_counts[name] > 0:
+
hessians[name] /= sample_counts[name] + + return {name: h.float() for name, h in hessians.items()} + + +def gptq_quantize_state_dict(state_dict: dict[str, Tensor], hessians: dict[str, Tensor], + model_prefix: str = "") -> dict[str, tuple[Tensor, Tensor]]: + """Apply GPTQ to all quantizable layers. Returns {param_name: (q, scale)}.""" + results = {} + skipped = [] + for name, tensor in state_dict.items(): + if tensor.ndim != 2 or not tensor.is_floating_point() or tensor.numel() <= 65536: + continue + # Find matching Hessian — strip state_dict prefix to match module name + hessian_key = None + for hk in hessians: + # Match: module_name.weight -> hk is module_name + if name.endswith(".weight") and name[:-7] == hk: + hessian_key = hk + break + # Also try with prefix + if name.endswith(".weight") and (model_prefix + name[:-7]) == hk: + hessian_key = hk + break + if hessian_key is not None: + q, s = gptq_quantize_int6(tensor, hessians[hessian_key]) + results[name] = (q, s) + else: + skipped.append(name) + if skipped: + print(f" gptq: WARNING skipped {len(skipped)} layers (no Hessian match): {skipped[:5]}{'...' if len(skipped)>5 else ''}", flush=True) + print(f" gptq: quantized {len(results)} layers, skipped {len(skipped)}", flush=True) + return results + + +# ----------------------------- +# INT6 MIXED QUANTIZATION (transplanted from working diagnostic scripts) +# ----------------------------- + +def _classify_param(name: str) -> str: + if "tok_emb" in name or "lm_head" in name: + return "embed" + if ".mlp." in name: + return "mlp" + if ".attn." in name or (".proj." in name and ".mlp." not in name): + return "attn" + return "other" + +def quantize_int5_per_row(t: Tensor) -> tuple[Tensor, Tensor]: + """INT5 quantization: [-16, 15] range, 5-bit effective.""" + t32 = t.float() + if t32.ndim == 2: + row_max = t32.abs().amax(dim=1) + scale = (row_max / 15.0).clamp_min(1.0 / 15.0).to(torch.float16) + q = torch.clamp(torch.round(t32 / scale.float()[:, None]), -16, 15).to(torch.int8) + return q, scale + amax = t32.abs().max().item() + scale = torch.tensor(amax / 15.0 if amax > 0 else 1.0, dtype=torch.float16) + q = torch.clamp(torch.round(t32 / scale.float()), -16, 15).to(torch.int8) + return q, scale + +def quantize_int6_per_row(t: Tensor) -> tuple[Tensor, Tensor]: + t32 = t.float() + if t32.ndim == 2: + row_max = t32.abs().amax(dim=1) + scale = (row_max / 31.0).clamp_min(1.0 / 31.0).to(torch.float16) + q = torch.clamp(torch.round(t32 / scale.float()[:, None]), -32, 31).to(torch.int8) + return q, scale + amax = t32.abs().max().item() + scale = torch.tensor(amax / 31.0 if amax > 0 else 1.0, dtype=torch.float16) + q = torch.clamp(torch.round(t32 / scale.float()), -32, 31).to(torch.int8) + return q, scale + +def mixed_quantize_int6(state_dict: dict[str, Tensor], int6_cats: set[str], + gptq_results: dict[str, tuple[Tensor, Tensor]] | None = None, + int5_cats: set[str] | None = None): + num_layers_total = max( + (int(k.split(".")[1]) for k in state_dict if k.startswith("blocks.")), + default=0, + ) + 1 + late_k_layers = set(range(num_layers_total - 2, num_layers_total)) + + result: dict[str, Tensor] = {} + meta: dict[str, object] = {} + gptq_count = 0 + for name, tensor in state_dict.items(): + t = tensor.detach().cpu().contiguous() + cat = _classify_param(name) + if not t.is_floating_point() or t.numel() <= 65536: + result[name] = t.to(torch.float16) if t.is_floating_point() else t + meta[name] = "passthrough" + continue + if any(p in name for p in CONTROL_TENSOR_NAME_PATTERNS): + result[name] = t.float() 
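+ # Note: tensors matching CONTROL_TENSOR_NAME_PATTERNS stay in fp32 rather
+ # than being quantized; most are small enough to be caught by the
+ # numel<=65536 passthrough above, so this branch only sees larger ones.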
+ meta[name] = "passthrough_ctrl" + continue + # Use GPTQ results if available for this param + if gptq_results is not None and name in gptq_results: + q, s = gptq_results[name] + result[name + ".q"] = q.cpu() + result[name + ".scale"] = s.cpu() + meta[name] = {"type": "int6"} + gptq_count += 1 + elif int5_cats is not None and cat in int5_cats and t.ndim >= 1: + q, s = quantize_int5_per_row(t) + result[name + ".q"] = q + result[name + ".scale"] = s + meta[name] = {"type": "int5"} + elif cat in int6_cats and t.ndim >= 1: + q, s = quantize_int6_per_row(t) + result[name + ".q"] = q + result[name + ".scale"] = s + meta[name] = {"type": "int6"} + else: + q, s = quantize_float_tensor(t) + result[name + ".q"] = q + result[name + ".scale"] = s + meta[name] = {"type": "int8"} + return result, meta + +def dequantize_mixed_int6(result: dict[str, Tensor], meta: dict[str, object], + template_sd: dict[str, Tensor]) -> dict[str, Tensor]: + out: dict[str, Tensor] = {} + for name, orig in template_sd.items(): + info = meta.get(name) + if info is None: + continue + orig_dtype = orig.dtype + if info in ("passthrough", "passthrough_ctrl", "passthrough_fp16"): + t = result[name] + if t.dtype == torch.float16 and orig_dtype in (torch.float32, torch.bfloat16): + t = t.to(orig_dtype) + out[name] = t + continue + q, s = result[name + ".q"], result[name + ".scale"] + if s.ndim > 0: + out[name] = (q.float() * s.float().view(q.shape[0], *([1] * (q.ndim - 1)))).to(orig_dtype) + else: + out[name] = (q.float() * float(s.item())).to(orig_dtype) + return out + + +# ----------------------------- +# TTT (TEST-TIME TRAINING) +# ----------------------------- + +def ttt_adapt(args, base_model, device, val_tokens, rank=0, world_size=1, log_fn=None): + """Full-weight TTT: SGD adaptation on val data with DDP across all GPUs.""" + seq_len = args.train_seq_len + total_seqs = (val_tokens.numel() - 1) // seq_len + batch_seqs = args.ttt_batch_seqs + + frozen_params = set() + if args.ttt_freeze_blocks > 0: + for i, block in enumerate(base_model.blocks): + if i < args.ttt_freeze_blocks: + for p in block.parameters(): + p.requires_grad_(False) + frozen_params.add(id(p)) + + ttt_params = [p for p in base_model.parameters() if p.requires_grad] + optimizer = torch.optim.SGD(ttt_params, lr=args.ttt_lr, momentum=args.ttt_momentum) + + my_start = (total_seqs * rank) // world_size + my_end = (total_seqs * (rank + 1)) // world_size + + base_model.train() + t0 = time.perf_counter() + + for epoch in range(args.ttt_epochs): + epoch_loss_sum = torch.zeros((), device=device, dtype=torch.float64) + epoch_tokens = torch.zeros((), device=device, dtype=torch.float64) + + for batch_start in range(my_start, my_end, batch_seqs): + batch_end = min(batch_start + batch_seqs, my_end) + raw_start = batch_start * seq_len + raw_end = batch_end * seq_len + 1 + local = val_tokens[raw_start:raw_end].to(device=device, dtype=torch.int64, non_blocking=True) + x = local[:-1].reshape(-1, seq_len) + y = local[1:].reshape(-1, seq_len) + + optimizer.zero_grad(set_to_none=True) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + loss = base_model(x, y) + loss.backward() + + if world_size > 1: + for p in ttt_params: + if p.grad is not None: + dist.all_reduce(p.grad, op=dist.ReduceOp.AVG) + + torch.nn.utils.clip_grad_norm_(ttt_params, 1.0) + optimizer.step() + + epoch_loss_sum += loss.detach().to(torch.float64) * y.numel() + epoch_tokens += float(y.numel()) + + if world_size > 1: + dist.all_reduce(epoch_loss_sum, op=dist.ReduceOp.SUM) + 
dist.all_reduce(epoch_tokens, op=dist.ReduceOp.SUM) + + elapsed = time.perf_counter() - t0 + if log_fn: + log_fn(f"ttt_epoch:{epoch+1}/{args.ttt_epochs} loss:{epoch_loss_sum.item()/max(epoch_tokens.item(),1):.4f} time:{elapsed:.1f}s") + + for p in base_model.parameters(): + p.requires_grad_(True) + + if log_fn: + log_fn(f"ttt:done elapsed={time.perf_counter()-t0:.1f}s") + + +# ----------------------------- +# TRAINING +# ----------------------------- + +def main() -> None: + global zeropower_via_newtonschulz5 + + code = Path(__file__).read_text(encoding="utf-8") + args = Hyperparameters() + zeropower_via_newtonschulz5 = torch.compile(zeropower_via_newtonschulz5) if os.environ.get("TORCH_COMPILE","1")!="0" else zeropower_via_newtonschulz5 + + # ----------------------------- + # DISTRIBUTED + CUDA SETUP + # ----------------------------- + + distributed = "RANK" in os.environ and "WORLD_SIZE" in os.environ + rank = int(os.environ.get("RANK", "0")) + world_size = int(os.environ.get("WORLD_SIZE", "1")) + local_rank = int(os.environ.get("LOCAL_RANK", "0")) + if world_size <= 0: + raise ValueError(f"WORLD_SIZE must be positive, got {world_size}") + if 8 % world_size != 0: + raise ValueError(f"WORLD_SIZE={world_size} must divide 8 so grad_accum_steps stays integral") + grad_accum_steps = 8 // world_size + grad_scale = 1.0 / grad_accum_steps + if not torch.cuda.is_available(): + raise RuntimeError("CUDA is required") + device = torch.device("cuda", local_rank) + torch.cuda.set_device(device) + if distributed: + dist.init_process_group(backend="nccl", device_id=device) + dist.barrier() + master_process = rank == 0 + + # Fast math knobs + torch.backends.cuda.matmul.allow_tf32 = True + torch.backends.cudnn.allow_tf32 = True + from torch.backends.cuda import enable_cudnn_sdp, enable_flash_sdp, enable_math_sdp, enable_mem_efficient_sdp + + enable_cudnn_sdp(False) + enable_flash_sdp(True) + enable_mem_efficient_sdp(False) + enable_math_sdp(False) + + logfile = None + if master_process: + os.makedirs("logs", exist_ok=True) + logfile = f"logs/{args.run_id}.txt" + print(logfile) + + def log0(msg: str, console: bool = True) -> None: + if not master_process: + return + if console: + print(msg) + if logfile is not None: + with open(logfile, "a", encoding="utf-8") as f: + print(msg, file=f) + + log0(code, console=False) + log0("=" * 100, console=False) + log0(f"Running Python {sys.version}", console=False) + log0(f"Running PyTorch {torch.__version__}", console=False) + log0( + subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, check=False).stdout, + console=False, + ) + log0("=" * 100, console=False) + + # ----------------------------- + # TOKENIZER + VALIDATION METRIC SETUP + # ----------------------------- + + random.seed(args.seed) + np.random.seed(args.seed) + torch.manual_seed(args.seed) + torch.cuda.manual_seed_all(args.seed) + + if not args.tokenizer_path.endswith(".model"): + raise ValueError(f"Script only setup for SentencePiece .model file: {args.tokenizer_path}") + sp = spm.SentencePieceProcessor(model_file=args.tokenizer_path) + if int(sp.vocab_size()) != args.vocab_size: + raise ValueError( + f"VOCAB_SIZE={args.vocab_size} does not match tokenizer vocab_size={int(sp.vocab_size())}" + ) + dataset_dir = Path(args.data_path).resolve() + actual_train_files = len(list(dataset_dir.glob("fineweb_train_*.bin"))) + effective_eval_seq_len = args.eval_seq_len if args.eval_seq_len > 0 else args.train_seq_len + val_seq_len = max(args.train_seq_len, effective_eval_seq_len) 
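+ # val_tokens is loaded once and reused downstream: periodic eval_val checks,
+ # GPTQ calibration, TTT adaptation, and the final sliding-window scoring.
+ # The max() above presumably pads the load to cover both sequence lengths.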
+ val_tokens = load_validation_tokens(args.val_files, val_seq_len) + base_bytes_lut, has_leading_space_lut, is_boundary_token_lut = build_sentencepiece_luts( + sp, args.vocab_size, device + ) + log0(f"val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path={args.tokenizer_path}") + log0(f"train_loader:dataset:{dataset_dir.name} train_shards:{actual_train_files}") + log0(f"val_loader:shards pattern={args.val_files} tokens:{val_tokens.numel() - 1}") + + # ----------------------------- + # MODEL + OPTIMIZER SETUP + # ----------------------------- + + CastedLinear._qat_enabled = args.qat_enabled + + base_model = GPT( + vocab_size=args.vocab_size, + num_layers=args.num_layers, + model_dim=args.model_dim, + num_heads=args.num_heads, + num_kv_heads=args.num_kv_heads, + mlp_mult=args.mlp_mult, + tie_embeddings=args.tie_embeddings, + tied_embed_init_std=args.tied_embed_init_std, + logit_softcap=args.logit_softcap, + rope_base=args.rope_base, + qk_gain_init=args.qk_gain_init, + mtp_num_heads=args.mtp_num_heads, + mtp_loss_weight=args.mtp_loss_weight, + bigram_vocab_size=args.bigram_vocab_size, + bigram_dim=args.bigram_dim, + xsa_last_n=args.xsa_last_n, + rope_dims=args.rope_dims, + ln_scale=args.ln_scale, + ).to(device).bfloat16() + for module in base_model.modules(): + if isinstance(module, CastedLinear): + module.float() + restore_low_dim_params_to_fp32(base_model) + compiled_model = torch.compile(base_model, dynamic=False, fullgraph=True) if os.environ.get("TORCH_COMPILE","1")!="0" else base_model + model: nn.Module = DDP(compiled_model, device_ids=[local_rank], broadcast_buffers=False) if distributed else compiled_model + + # Optimizer split: + # - token embedding (Adam) uses EMBED_LR + # - untied lm_head (Adam) uses HEAD_LR + # - matrix params in transformer blocks use MATRIX_LR via Muon + # - vectors/scalars use SCALAR_LR via Adam + block_named_params = list(base_model.blocks.named_parameters()) + matrix_params = [ + p + for name, p in block_named_params + if p.ndim == 2 and not any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) + ] + if base_model.mtp_num_heads > 0: + matrix_params.extend([p for p in base_model.mtp_heads.parameters() if p.ndim == 2]) + scalar_params = [ + p + for name, p in block_named_params + if p.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) + ] + if base_model.skip_weights.numel() > 0: + scalar_params.append(base_model.skip_weights) + scalar_params.append(base_model.smear.gate) + if base_model.bigram is not None: + scalar_params.append(base_model.bigram.scale) + token_lr = args.tied_embed_lr if args.tie_embeddings else args.embed_lr + tok_params = [{"params": [base_model.tok_emb.weight], "lr": token_lr, "base_lr": token_lr}] + if base_model.bigram is not None: + tok_params.append({"params": [base_model.bigram.embed.weight], "lr": token_lr, "base_lr": token_lr}) + if base_model.bigram.proj is not None: + matrix_params.append(base_model.bigram.proj.weight) + optimizer_tok = torch.optim.AdamW( + tok_params, + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + weight_decay=args.adam_wd, + fused=True, + ) + optimizer_muon = Muon( + matrix_params, + lr=args.matrix_lr, + momentum=args.muon_momentum, + backend_steps=args.muon_backend_steps, + weight_decay=args.muon_wd, + ) + for group in optimizer_muon.param_groups: + group["base_lr"] = args.matrix_lr + optimizer_scalar = torch.optim.AdamW( + [{"params": scalar_params, "lr": args.scalar_lr, "base_lr": args.scalar_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + 
weight_decay=args.adam_wd, + fused=True, + ) + optimizers: list[torch.optim.Optimizer] = [optimizer_tok, optimizer_muon, optimizer_scalar] + if base_model.lm_head is not None: + optimizer_head = torch.optim.Adam( + [{"params": [base_model.lm_head.weight], "lr": args.head_lr, "base_lr": args.head_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=True, + ) + optimizers.insert(1, optimizer_head) + + n_params = sum(p.numel() for p in base_model.parameters()) + mtp_params = sum(p.numel() for p in base_model.mtp_heads.parameters()) + log0(f"model_params:{n_params}") + log0(f"mtp_num_heads:{args.mtp_num_heads} mtp_loss_weight:{args.mtp_loss_weight} mtp_params:{mtp_params}") + log0(f"world_size:{world_size} grad_accum_steps:{grad_accum_steps}") + log0("sdp_backends:cudnn=False flash=True mem_efficient=False math=False") + log0(f"attention_mode:gqa num_heads:{args.num_heads} num_kv_heads:{args.num_kv_heads}") + log0( + f"tie_embeddings:{args.tie_embeddings} embed_lr:{token_lr} " + f"head_lr:{args.head_lr if base_model.lm_head is not None else 0.0} " + f"matrix_lr:{args.matrix_lr} scalar_lr:{args.scalar_lr}" + ) + log0( + f"train_batch_tokens:{args.train_batch_tokens} train_seq_len:{args.train_seq_len} " + f"iterations:{args.iterations} warmup_steps:{args.warmup_steps} " + f"max_wallclock_seconds:{args.max_wallclock_seconds:.3f}" + ) + log0(f"seed:{args.seed}") + + # ----------------------------- + # DATA LOADER & MODEL WARMUP + # ----------------------------- + + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + + def zero_grad_all() -> None: + for opt in optimizers: + opt.zero_grad(set_to_none=True) + + max_wallclock_ms = 1000.0 * args.max_wallclock_seconds if args.max_wallclock_seconds > 0 else None + + def lr_mul(step: int, elapsed_ms: float) -> float: + if args.warmdown_iters <= 0: + return 1.0 + if max_wallclock_ms is None: + warmdown_start = max(args.iterations - args.warmdown_iters, 0) + return max((args.iterations - step) / max(args.warmdown_iters, 1), 0.0) if warmdown_start <= step < args.iterations else 1.0 + step_ms = elapsed_ms / max(step, 1) + warmdown_ms = args.warmdown_iters * step_ms + remaining_ms = max(max_wallclock_ms - elapsed_ms, 0.0) + return remaining_ms / max(warmdown_ms, 1e-9) if remaining_ms <= warmdown_ms else 1.0 + + # Warmup primes the compiled forward/backward/optimizer paths, then we restore the + # initial weights/optimizer state so measured training starts from the true init. 
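+ # This keeps torch.compile and kernel-autotuning latency out of the measured
+ # wallclock budget: the training timer (t0) only starts after warmup, and the
+ # data loader is re-created afterwards so measured training replays the token
+ # stream from the beginning.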
+ if args.warmup_steps > 0: + initial_model_state = {name: tensor.detach().cpu().clone() for name, tensor in base_model.state_dict().items()} + initial_optimizer_states = [copy.deepcopy(opt.state_dict()) for opt in optimizers] + model.train() + for warmup_step in range(args.warmup_steps): + zero_grad_all() + for micro_step in range(grad_accum_steps): + if distributed: + model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 + x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + warmup_loss = model(x, y) + (warmup_loss * grad_scale).backward() + for opt in optimizers: + opt.step() + zero_grad_all() + if args.warmup_steps <= 20 or (warmup_step + 1) % 10 == 0 or warmup_step + 1 == args.warmup_steps: + log0(f"warmup_step:{warmup_step + 1}/{args.warmup_steps}") + base_model.load_state_dict(initial_model_state, strict=True) + for opt, state in zip(optimizers, initial_optimizer_states, strict=True): + opt.load_state_dict(state) + zero_grad_all() + if distributed: + model.require_backward_grad_sync = True + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + + # ----------------------------- + # MAIN TRAINING LOOP + # ----------------------------- + + swa_state: dict[str, Tensor] | None = None + swa_count = 0 + + ema_state: dict[str, Tensor] | None = None + if args.ema_enabled: + ema_state = {name: t.detach().float().clone() for name, t in base_model.state_dict().items()} + + training_time_ms = 0.0 + stop_after_step: int | None = None + torch.cuda.synchronize() + t0 = time.perf_counter() + + step = 0 + while True: + last_step = step == args.iterations or (stop_after_step is not None and step >= stop_after_step) + + should_validate = last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0) + if should_validate: + torch.cuda.synchronize() + training_time_ms += 1000.0 * (time.perf_counter() - t0) + val_loss, val_bpb = eval_val( + args, + model, + rank, + world_size, + device, + grad_accum_steps, + val_tokens, + base_bytes_lut, + has_leading_space_lut, + is_boundary_token_lut, + ) + log0( + f"step:{step}/{args.iterations} val_loss:{val_loss:.4f} val_bpb:{val_bpb:.4f} " + f"train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms / max(step, 1):.2f}ms" + ) + torch.cuda.synchronize() + t0 = time.perf_counter() + + if last_step: + if stop_after_step is not None and step < args.iterations: + log0( + f"stopping_early: wallclock_cap train_time:{training_time_ms:.0f}ms " + f"step:{step}/{args.iterations}" + ) + break + + elapsed_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + scale = lr_mul(step, elapsed_ms) + qat_threshold = float(os.environ.get("QAT_THRESHOLD", "0.1")) + if args.late_qat and scale < qat_threshold and not CastedLinear._qat_enabled: + CastedLinear._qat_enabled = True + log0(f"late_qat:enabled step:{step} scale:{scale:.4f}") + zero_grad_all() + train_loss = torch.zeros((), device=device) + for micro_step in range(grad_accum_steps): + if distributed: + model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 + x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + loss = model(x, y) + train_loss += loss.detach() + (loss * grad_scale).backward() + train_loss /= grad_accum_steps + + frac = min(step / args.muon_momentum_warmup_steps, 1.0) if 
args.muon_momentum_warmup_steps > 0 else 1.0 + muon_momentum = (1 - frac) * args.muon_momentum_warmup_start + frac * args.muon_momentum + for group in optimizer_muon.param_groups: + group["momentum"] = muon_momentum + + for opt in optimizers: + for group in opt.param_groups: + group["lr"] = group["base_lr"] * scale + + if args.grad_clip_norm > 0: + torch.nn.utils.clip_grad_norm_(base_model.parameters(), args.grad_clip_norm) + for opt in optimizers: + opt.step() + zero_grad_all() + + step += 1 + approx_training_time_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + + if ema_state is not None: + d = args.ema_decay + with torch.no_grad(): + for name, t in base_model.state_dict().items(): + ema_state[name].mul_(d).add_(t.detach().float(), alpha=1.0 - d) + + if args.swa_enabled and not args.ema_enabled and scale < 0.5 and step % args.swa_every == 0: + if swa_state is None: + swa_state = {name: t.detach().float().clone() for name, t in base_model.state_dict().items()} + swa_count = 1 + log0(f"swa:start step:{step}") + else: + for name, t in base_model.state_dict().items(): + swa_state[name].add_(t.detach().float()) + swa_count += 1 + + should_log_train = ( + args.train_log_every > 0 + and (step <= 10 or step % args.train_log_every == 0 or stop_after_step is not None) + ) + if should_log_train: + log0( + f"step:{step}/{args.iterations} train_loss:{train_loss.item():.4f} " + f"train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms / step:.2f}ms" + ) + + # Needed to sync whether we've reached the wallclock cap. + reached_cap = max_wallclock_ms is not None and approx_training_time_ms >= max_wallclock_ms + if distributed and max_wallclock_ms is not None: + reached_cap_tensor = torch.tensor(int(reached_cap), device=device) + dist.all_reduce(reached_cap_tensor, op=dist.ReduceOp.MAX) + reached_cap = bool(reached_cap_tensor.item()) + if stop_after_step is None and reached_cap: + stop_after_step = step + + log0( + f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB " + f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB" + ) + + if ema_state is not None: + log0("ema:applying EMA weights") + avg_state = {name: t.to(dtype=base_model.state_dict()[name].dtype) + for name, t in ema_state.items()} + del ema_state + base_model.load_state_dict(avg_state, strict=True) + elif args.swa_enabled and swa_state is not None and swa_count > 1: + log0(f"swa:applying averaged {swa_count} checkpoints") + avg_state = {name: (t / swa_count).to(dtype=base_model.state_dict()[name].dtype) + for name, t in swa_state.items()} + del swa_state + base_model.load_state_dict(avg_state, strict=True) + + # Save EMA checkpoint for reuse + if master_process: + torch.save(base_model.state_dict(), "ema_checkpoint.pt") + log0("saved ema_checkpoint.pt") + + # ----------------------------- + # SERIALIZATION + ROUNDTRIP VALIDATION + # ----------------------------- + + full_state_dict = base_model.state_dict() + export_sd = {k: v for k, v in full_state_dict.items() if "mtp_heads" not in k} + excluded_mtp = sum(int(t.numel()) for k, t in full_state_dict.items() if "mtp_heads" in k) + if excluded_mtp > 0: + log0(f"export_excluding_mtp_params:{excluded_mtp}") + + if master_process: + torch.save(export_sd, "final_model.pt") + model_bytes = os.path.getsize("final_model.pt") + code_bytes = len(code.encode("utf-8")) + log0(f"Serialized model: {model_bytes} bytes") + log0(f"Code size: {code_bytes} bytes") + + sd_cpu = {k: v.detach().cpu() for k, v in export_sd.items()} + + # 
GPTQ: collect calibration data and quantize with error compensation (rank 0 only) + gptq_results = None + if args.gptq_enabled and master_process: + log0("gptq:start — creating fresh non-compiled model for calibration...") + t_gptq = time.perf_counter() + + # Create a fresh non-compiled model to avoid torch.compile + hooks conflict + calib_model = GPT( + vocab_size=args.vocab_size, num_layers=args.num_layers, model_dim=args.model_dim, + num_heads=args.num_heads, num_kv_heads=args.num_kv_heads, mlp_mult=args.mlp_mult, + tie_embeddings=args.tie_embeddings, tied_embed_init_std=args.tied_embed_init_std, + logit_softcap=args.logit_softcap, rope_base=args.rope_base, qk_gain_init=args.qk_gain_init, + mtp_num_heads=0, mtp_loss_weight=0.0, + bigram_vocab_size=args.bigram_vocab_size, bigram_dim=args.bigram_dim, + xsa_last_n=args.xsa_last_n, rope_dims=args.rope_dims, ln_scale=args.ln_scale, + ).to(device).bfloat16() + for m in calib_model.modules(): + if isinstance(m, CastedLinear): + m.float() + restore_low_dim_params_to_fp32(calib_model) + calib_model.load_state_dict(export_sd, strict=True) + calib_model.eval() + log0(f"gptq:calib model on {next(calib_model.parameters()).device}") + + log0("gptq:collecting calibration data...") + hessians = collect_gptq_calibration( + calib_model, val_tokens, device, args.train_seq_len, n_samples=args.gptq_samples + ) + log0(f"gptq:collected {len(hessians)} layer hessians in {time.perf_counter()-t_gptq:.1f}s") + + del calib_model + torch.cuda.empty_cache() + + # Move hessians to CPU for quantization + hessians_cpu = {k: v.cpu() for k, v in hessians.items()} + del hessians + gptq_results = gptq_quantize_state_dict(sd_cpu, hessians_cpu) + log0(f"gptq:quantized {len(gptq_results)} layers in {time.perf_counter()-t_gptq:.1f}s") + del hessians_cpu + + if args.gptq_enabled and distributed: + dist.barrier() # wait for rank 0 GPTQ to finish + + int5_cats = {"mlp"} if args.int5_mlp else None + quant_result, quant_meta = mixed_quantize_int6(sd_cpu, {"mlp", "attn"}, gptq_results=gptq_results, int5_cats=int5_cats) + quant_buf = io.BytesIO() + torch.save({"w": quant_result, "m": quant_meta}, quant_buf) + quant_raw = quant_buf.getvalue() + quant_blob = zstandard.ZstdCompressor(level=22).compress(quant_raw) if _COMPRESSOR == "zstd" else zlib.compress(quant_raw, 9) + if master_process: + with open("final_model.int6.ptz", "wb") as f: + f.write(quant_blob) + quant_file_bytes = len(quant_blob) + code_bytes = len(code.encode("utf-8")) + log0(f"Serialized model int6+{_COMPRESSOR}: {quant_file_bytes} bytes") + log0(f"Total submission size int6+{_COMPRESSOR}: {quant_file_bytes + code_bytes} bytes") + + # Roundtrip: decompress + dequantize into fresh model + eval + if distributed: + dist.barrier() + with open("final_model.int6.ptz", "rb") as f: + quant_blob_disk = f.read() + quant_state = torch.load( + io.BytesIO(zstandard.ZstdDecompressor().decompress(quant_blob_disk) if _COMPRESSOR == "zstd" else zlib.decompress(quant_blob_disk)), + map_location="cpu", + ) + deq_state = dequantize_mixed_int6(quant_state["w"], quant_state["m"], sd_cpu) + + eval_model = GPT( + vocab_size=args.vocab_size, num_layers=args.num_layers, model_dim=args.model_dim, + num_heads=args.num_heads, num_kv_heads=args.num_kv_heads, mlp_mult=args.mlp_mult, + tie_embeddings=args.tie_embeddings, tied_embed_init_std=args.tied_embed_init_std, + logit_softcap=args.logit_softcap, rope_base=args.rope_base, qk_gain_init=args.qk_gain_init, + mtp_num_heads=0, mtp_loss_weight=0.0, + bigram_vocab_size=args.bigram_vocab_size, 
bigram_dim=args.bigram_dim, + xsa_last_n=args.xsa_last_n, + rope_dims=args.rope_dims, + ln_scale=args.ln_scale, + ).to(device).bfloat16() + for m in eval_model.modules(): + if isinstance(m, CastedLinear): + m.float() + restore_low_dim_params_to_fp32(eval_model) + eval_model.load_state_dict(deq_state, strict=True) + + # TTT: adapt model on validation data before eval + if args.ttt_enabled: + if distributed: + dist.barrier() + log0(f"ttt:start lr={args.ttt_lr} momentum={args.ttt_momentum} epochs={args.ttt_epochs} freeze_blocks={args.ttt_freeze_blocks}") + t_ttt = time.perf_counter() + ttt_adapt(args, eval_model, device, val_tokens, rank=rank, world_size=world_size, log_fn=log0) + log0(f"ttt:elapsed={time.perf_counter() - t_ttt:.1f}s") + if distributed: + dist.barrier() + + compiled_eval = torch.compile(eval_model, dynamic=False, fullgraph=True) if os.environ.get("TORCH_COMPILE","1")!="0" else eval_model + + # Standard non-overlapping eval (sanity check) + torch.cuda.synchronize() + t_qeval = time.perf_counter() + q_val_loss, q_val_bpb = eval_val( + args, compiled_eval, rank, world_size, device, grad_accum_steps, + val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, + eval_seq_len=effective_eval_seq_len, + ) + torch.cuda.synchronize() + log0( + f"final_int6_roundtrip val_loss:{q_val_loss:.4f} val_bpb:{q_val_bpb:.4f} " + f"eval_time:{1000.0 * (time.perf_counter() - t_qeval):.0f}ms" + ) + log0(f"final_int6_roundtrip_exact val_loss:{q_val_loss:.8f} val_bpb:{q_val_bpb:.8f}") + + # Sliding window eval (submission score) + sw_seq_len = effective_eval_seq_len + if args.eval_stride > 0 and args.eval_stride < sw_seq_len: + torch.cuda.synchronize() + t_slide = time.perf_counter() + sw_val_loss, sw_val_bpb = eval_val_sliding( + args, eval_model, rank, world_size, device, + val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, + stride=args.eval_stride, + eval_seq_len=sw_seq_len, + ) + torch.cuda.synchronize() + log0( + f"final_int6_sliding_window val_loss:{sw_val_loss:.4f} val_bpb:{sw_val_bpb:.4f} " + f"stride:{args.eval_stride} eval_time:{1000.0 * (time.perf_counter() - t_slide):.0f}ms" + ) + log0(f"final_int6_sliding_window_exact val_loss:{sw_val_loss:.8f} val_bpb:{sw_val_bpb:.8f}") + + # Second sliding window eval at stride=64 for submission comparison + if args.eval_stride != 64 and 64 < sw_seq_len: + torch.cuda.synchronize() + t_slide64 = time.perf_counter() + sw64_val_loss, sw64_val_bpb = eval_val_sliding( + args, eval_model, rank, world_size, device, + val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, + stride=64, + eval_seq_len=sw_seq_len, + ) + torch.cuda.synchronize() + log0( + f"final_int6_sliding_window_s64 val_loss:{sw64_val_loss:.4f} val_bpb:{sw64_val_bpb:.4f} " + f"stride:64 eval_time:{1000.0 * (time.perf_counter() - t_slide64):.0f}ms" + ) + log0(f"final_int6_sliding_window_s64_exact val_loss:{sw64_val_loss:.8f} val_bpb:{sw64_val_bpb:.8f}") + + if distributed: + dist.destroy_process_group() + + +if __name__ == "__main__": + main()