diff --git a/1x H100 SXM5 Logs/10L_qk4_wd1500_20260401_094047.log b/1x H100 SXM5 Logs/10L_qk4_wd1500_20260401_094047.log new file mode 100644 index 0000000000..e69de29bb2 diff --git a/1x H100 SXM5 Logs/10L_qk4_wd2000_20260401_081938.log b/1x H100 SXM5 Logs/10L_qk4_wd2000_20260401_081938.log new file mode 100644 index 0000000000..e69de29bb2 diff --git a/1x H100 SXM5 Logs/11L_mlp35_20260401_074042.log b/1x H100 SXM5 Logs/11L_mlp35_20260401_074042.log new file mode 100644 index 0000000000..e69de29bb2 diff --git a/1x H100 SXM5 Logs/11L_qk4_20260401_070337.log b/1x H100 SXM5 Logs/11L_qk4_20260401_070337.log new file mode 100644 index 0000000000..e69de29bb2 diff --git a/1x H100 SXM5 Logs/11L_qk4_20260401_070342.log b/1x H100 SXM5 Logs/11L_qk4_20260401_070342.log new file mode 100644 index 0000000000..e69de29bb2 diff --git a/1x H100 SXM5 Logs/11L_qk4_wd1000_20260401_101532.log b/1x H100 SXM5 Logs/11L_qk4_wd1000_20260401_101532.log new file mode 100644 index 0000000000..e69de29bb2 diff --git a/1x H100 SXM5 Logs/11L_qk4_wd1000_tied005_20260401_154931.log b/1x H100 SXM5 Logs/11L_qk4_wd1000_tied005_20260401_154931.log new file mode 100644 index 0000000000..e69de29bb2 diff --git a/1x H100 SXM5 Logs/11L_qk4_wd1200_fa3_20260401_172016.log b/1x H100 SXM5 Logs/11L_qk4_wd1200_fa3_20260401_172016.log new file mode 100644 index 0000000000..e69de29bb2 diff --git a/1x H100 SXM5 Logs/11L_qk4_wd1200_fa3_20260401_183913.log b/1x H100 SXM5 Logs/11L_qk4_wd1200_fa3_20260401_183913.log new file mode 100644 index 0000000000..e69de29bb2 diff --git a/1x H100 SXM5 Logs/11L_qk4_wd1500_20260401_085426.log b/1x H100 SXM5 Logs/11L_qk4_wd1500_20260401_085426.log new file mode 100644 index 0000000000..e69de29bb2 diff --git a/1x H100 SXM5 Logs/11L_qk4_wd1500_fa3_20260401_171958.log b/1x H100 SXM5 Logs/11L_qk4_wd1500_fa3_20260401_171958.log new file mode 100644 index 0000000000..e69de29bb2 diff --git a/1x H100 SXM5 Logs/11L_qk4_wd1500_fa3_20260401_180210.log b/1x H100 SXM5 
Logs/11L_qk4_wd1500_fa3_20260401_180210.log new file mode 100644 index 0000000000..e69de29bb2 diff --git a/1x H100 SXM5 Logs/11L_qk4_wd1500_fa3_20260401_180428.log b/1x H100 SXM5 Logs/11L_qk4_wd1500_fa3_20260401_180428.log new file mode 100644 index 0000000000..e69de29bb2 diff --git a/1x H100 SXM5 Logs/11L_qk4_wd1500_mlr025_20260401_105230.log b/1x H100 SXM5 Logs/11L_qk4_wd1500_mlr025_20260401_105230.log new file mode 100644 index 0000000000..e69de29bb2 diff --git a/1x H100 SXM5 Logs/11L_qk4_wd1500_swa25_20260401_143529.log b/1x H100 SXM5 Logs/11L_qk4_wd1500_swa25_20260401_143529.log new file mode 100644 index 0000000000..e69de29bb2 diff --git a/1x H100 SXM5 Logs/11L_qk4_wd1500_tied005_20260401_124435.log b/1x H100 SXM5 Logs/11L_qk4_wd1500_tied005_20260401_124435.log new file mode 100644 index 0000000000..e69de29bb2 diff --git a/1x H100 SXM5 Logs/11L_qk4_wd1500_tied005_mlr025_20260401_151230.log b/1x H100 SXM5 Logs/11L_qk4_wd1500_tied005_mlr025_20260401_151230.log new file mode 100644 index 0000000000..e69de29bb2 diff --git a/1x H100 SXM5 Logs/11L_qk6_wd1500_20260401_132133.log b/1x H100 SXM5 Logs/11L_qk6_wd1500_20260401_132133.log new file mode 100644 index 0000000000..e69de29bb2 diff --git a/1x H100 SXM5 Logs/11L_qk8_wd1500_20260401_135834.log b/1x H100 SXM5 Logs/11L_qk8_wd1500_20260401_135834.log new file mode 100644 index 0000000000..e69de29bb2 diff --git a/1x H100 SXM5 Logs/11L_wd3500_20260401_081928.log b/1x H100 SXM5 Logs/11L_wd3500_20260401_081928.log new file mode 100644 index 0000000000..e69de29bb2 diff --git a/1x H100 SXM5 Logs/11layers_20260331_233309.log b/1x H100 SXM5 Logs/11layers_20260331_233309.log new file mode 100644 index 0000000000..e69de29bb2 diff --git a/1x H100 SXM5 Logs/12L_qk4_wd1200_20260401_120458.log b/1x H100 SXM5 Logs/12L_qk4_wd1200_20260401_120458.log new file mode 100644 index 0000000000..e69de29bb2 diff --git a/1x H100 SXM5 Logs/baseline_10L_20260331_214036.log b/1x H100 SXM5 Logs/baseline_10L_20260331_214036.log new file mode 
100644 index 0000000000..e69de29bb2 diff --git a/1x H100 SXM5 Logs/baseline_786k_20260331_222116.log b/1x H100 SXM5 Logs/baseline_786k_20260331_222116.log new file mode 100644 index 0000000000..e69de29bb2 diff --git a/1x H100 SXM5 Logs/depth_recurrence.log b/1x H100 SXM5 Logs/depth_recurrence.log new file mode 100644 index 0000000000..e69de29bb2 diff --git a/1x H100 SXM5 Logs/free_wins.log b/1x H100 SXM5 Logs/free_wins.log new file mode 100644 index 0000000000..e69de29bb2 diff --git a/1x H100 SXM5 Logs/int6_qat.log b/1x H100 SXM5 Logs/int6_qat.log new file mode 100644 index 0000000000..e69de29bb2 diff --git a/1x H100 SXM5 Logs/mega_xsa_ema_fa3_20260401_162629.log b/1x H100 SXM5 Logs/mega_xsa_ema_fa3_20260401_162629.log new file mode 100644 index 0000000000..e69de29bb2 diff --git a/1x H100 SXM5 Logs/mega_xsa_ema_fa3_20260401_164231.log b/1x H100 SXM5 Logs/mega_xsa_ema_fa3_20260401_164231.log new file mode 100644 index 0000000000..e69de29bb2 diff --git a/1x H100 SXM5 Logs/mega_xsa_ema_fa3_20260401_173311.log b/1x H100 SXM5 Logs/mega_xsa_ema_fa3_20260401_173311.log new file mode 100644 index 0000000000..e69de29bb2 diff --git a/1x H100 SXM5 Logs/mlp35_20260401_033845.log b/1x H100 SXM5 Logs/mlp35_20260401_033845.log new file mode 100644 index 0000000000..e69de29bb2 diff --git a/1x H100 SXM5 Logs/naive_baseline_9L_mlp2_seq1024_20260401_113000.log b/1x H100 SXM5 Logs/naive_baseline_9L_mlp2_seq1024_20260401_113000.log new file mode 100644 index 0000000000..e69de29bb2 diff --git a/1x H100 SXM5 Logs/ngram_cache.log b/1x H100 SXM5 Logs/ngram_cache.log new file mode 100644 index 0000000000..e69de29bb2 diff --git a/1x H100 SXM5 Logs/parallel_residuals.log b/1x H100 SXM5 Logs/parallel_residuals.log new file mode 100644 index 0000000000..e69de29bb2 diff --git a/1x H100 SXM5 Logs/qk_gain_20260401_001427.log b/1x H100 SXM5 Logs/qk_gain_20260401_001427.log new file mode 100644 index 0000000000..0a19ee082b --- /dev/null +++ b/1x H100 SXM5 Logs/qk_gain_20260401_001427.log @@ -0,0 
+1,709 @@ +logs/phase2_qk_gain_20260401_001427.txt +val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path=/workspace/runpod-testing/data/tokenizers/fineweb_1024_bpe.model +train_loader:dataset:fineweb10B_sp1024 train_shards:80 +val_loader:shards pattern=/workspace/runpod-testing/data/datasets/fineweb10B_sp1024/fineweb_val_*.bin tokens:62021632 +model_params:25517137 +world_size:1 grad_accum_steps:8 +attention_mode:gqa num_heads:8 num_kv_heads:4 +tie_embeddings:True embed_lr:0.03 matrix_lr:0.02 scalar_lr:0.02 +train_batch_tokens:262144 train_seq_len:2048 iterations:20000 warmup_steps:20 max_wallclock_seconds:600.000 +seed:42 +warmup_step:1/20 +warmup_step:2/20 +warmup_step:3/20 +warmup_step:4/20 +warmup_step:5/20 +warmup_step:6/20 +warmup_step:7/20 +warmup_step:8/20 +warmup_step:9/20 +warmup_step:10/20 +warmup_step:11/20 +warmup_step:12/20 +warmup_step:13/20 +warmup_step:14/20 +warmup_step:15/20 +warmup_step:16/20 +warmup_step:17/20 +warmup_step:18/20 +warmup_step:19/20 +warmup_step:20/20 +step:0/20000 val_loss:6.9323 val_bpb:4.1057 train_time:0ms step_avg:0.01ms +step:1/20000 train_loss:6.9338 train_time:313ms step_avg:313.18ms +step:2/20000 train_loss:8.2629 train_time:577ms step_avg:288.48ms +step:3/20000 train_loss:7.8358 train_time:841ms step_avg:280.26ms +step:4/20000 train_loss:7.0947 train_time:1104ms step_avg:276.06ms +step:5/20000 train_loss:6.5743 train_time:1368ms step_avg:273.65ms +step:6/20000 train_loss:6.3174 train_time:1632ms step_avg:271.99ms +step:7/20000 train_loss:6.2934 train_time:1896ms step_avg:270.81ms +step:8/20000 train_loss:6.1892 train_time:2159ms step_avg:269.92ms +step:9/20000 train_loss:6.1419 train_time:2423ms step_avg:269.22ms +step:10/20000 train_loss:6.0544 train_time:2687ms step_avg:268.67ms +step:50/20000 train_loss:3.7730 train_time:13243ms step_avg:264.85ms +step:100/20000 train_loss:3.3028 train_time:26475ms step_avg:264.75ms +step:150/20000 train_loss:3.1008 train_time:39666ms step_avg:264.44ms +step:200/20000 
train_loss:2.8570 train_time:52851ms step_avg:264.26ms +step:250/20000 train_loss:2.7286 train_time:66037ms step_avg:264.15ms +step:300/20000 train_loss:2.6685 train_time:79242ms step_avg:264.14ms +step:350/20000 train_loss:2.4599 train_time:92491ms step_avg:264.26ms +step:400/20000 train_loss:2.4586 train_time:105701ms step_avg:264.25ms +step:450/20000 train_loss:2.4659 train_time:118888ms step_avg:264.20ms +step:500/20000 train_loss:2.4882 train_time:132071ms step_avg:264.14ms +step:500/20000 val_loss:2.4696 val_bpb:1.4627 train_time:132071ms step_avg:264.14ms +step:550/20000 train_loss:2.4288 train_time:145264ms step_avg:264.12ms +step:600/20000 train_loss:2.4215 train_time:158448ms step_avg:264.08ms +step:650/20000 train_loss:2.4073 train_time:171631ms step_avg:264.05ms +step:700/20000 train_loss:2.3756 train_time:184866ms step_avg:264.09ms +step:750/20000 train_loss:2.4568 train_time:198063ms step_avg:264.08ms +step:800/20000 train_loss:2.1979 train_time:211325ms step_avg:264.16ms +step:850/20000 train_loss:2.3230 train_time:224576ms step_avg:264.21ms +step:900/20000 train_loss:2.3208 train_time:237767ms step_avg:264.19ms +step:950/20000 train_loss:2.3669 train_time:250958ms step_avg:264.17ms +step:1000/20000 train_loss:2.2878 train_time:264193ms step_avg:264.19ms +step:1000/20000 val_loss:2.3192 val_bpb:1.3736 train_time:264194ms step_avg:264.19ms +step:1050/20000 train_loss:2.4122 train_time:277374ms step_avg:264.17ms +swa:start step:1100 +step:1100/20000 train_loss:2.2196 train_time:290572ms step_avg:264.16ms +step:1150/20000 train_loss:2.2358 train_time:303890ms step_avg:264.25ms +step:1200/20000 train_loss:2.3712 train_time:317084ms step_avg:264.24ms +step:1250/20000 train_loss:2.2891 train_time:330328ms step_avg:264.26ms +step:1300/20000 train_loss:2.1527 train_time:343529ms step_avg:264.25ms +step:1350/20000 train_loss:2.1327 train_time:356810ms step_avg:264.30ms +step:1400/20000 train_loss:2.3299 train_time:370039ms step_avg:264.31ms +step:1450/20000 
train_loss:2.3405 train_time:383316ms step_avg:264.36ms +step:1500/20000 train_loss:2.2303 train_time:396571ms step_avg:264.38ms +step:1500/20000 val_loss:2.2487 val_bpb:1.3318 train_time:396587ms step_avg:264.39ms +step:1550/20000 train_loss:2.3584 train_time:409963ms step_avg:264.49ms +step:1600/20000 train_loss:2.1952 train_time:423248ms step_avg:264.53ms +step:1650/20000 train_loss:2.2934 train_time:436519ms step_avg:264.56ms +step:1700/20000 train_loss:2.2806 train_time:449799ms step_avg:264.59ms +step:1750/20000 train_loss:2.2113 train_time:463067ms step_avg:264.61ms +step:1800/20000 train_loss:2.2273 train_time:476285ms step_avg:264.60ms +step:1850/20000 train_loss:2.1953 train_time:489503ms step_avg:264.60ms +step:1900/20000 train_loss:2.2019 train_time:502744ms step_avg:264.60ms +step:1950/20000 train_loss:2.1589 train_time:516025ms step_avg:264.63ms +step:2000/20000 train_loss:2.2230 train_time:529242ms step_avg:264.62ms +step:2000/20000 val_loss:2.1696 val_bpb:1.2850 train_time:529258ms step_avg:264.63ms +step:2050/20000 train_loss:2.3380 train_time:542465ms step_avg:264.62ms +step:2100/20000 train_loss:2.1885 train_time:555727ms step_avg:264.63ms +step:2150/20000 train_loss:2.1286 train_time:569021ms step_avg:264.66ms +step:2200/20000 train_loss:2.1141 train_time:582243ms step_avg:264.66ms +step:2250/20000 train_loss:2.1736 train_time:595466ms step_avg:264.65ms +step:2268/20000 val_loss:2.1398 val_bpb:1.2673 train_time:600242ms step_avg:264.66ms +stopping_early: wallclock_cap train_time:600242ms step:2268/20000 +peak memory allocated: 6687 MiB reserved: 6954 MiB +swa:applying averaged 24 checkpoints +Serialized model: 98437419 bytes +Code size: 52930 bytes +Total submission size: 98490349 bytes +Serialized model int6+zstd: 15681053 bytes +Total submission size int8+zlib: 15733983 bytes +final_eval_mode:sliding_window stride:64 batch_seqs:32 + sliding_eval [ 0.0%] 32/969088 windows running_bpb=1.327616 + sliding_eval [ 0.2%] 1632/969088 windows 
running_bpb=1.261600 + sliding_eval [ 0.3%] 3232/969088 windows running_bpb=1.256557 + sliding_eval [ 0.5%] 4832/969088 windows running_bpb=1.252669 + sliding_eval [ 0.7%] 6432/969088 windows running_bpb=1.265404 + sliding_eval [ 0.8%] 8032/969088 windows running_bpb=1.267272 + sliding_eval [ 1.0%] 9632/969088 windows running_bpb=1.269369 + sliding_eval [ 1.2%] 11232/969088 windows running_bpb=1.264131 + sliding_eval [ 1.3%] 12832/969088 windows running_bpb=1.261458 + sliding_eval [ 1.5%] 14432/969088 windows running_bpb=1.263414 + sliding_eval [ 1.7%] 16032/969088 windows running_bpb=1.272222 + sliding_eval [ 1.8%] 17632/969088 windows running_bpb=1.270399 + sliding_eval [ 2.0%] 19232/969088 windows running_bpb=1.271390 + sliding_eval [ 2.1%] 20832/969088 windows running_bpb=1.270020 + sliding_eval [ 2.3%] 22432/969088 windows running_bpb=1.268801 + sliding_eval [ 2.5%] 24032/969088 windows running_bpb=1.269308 + sliding_eval [ 2.6%] 25632/969088 windows running_bpb=1.270880 + sliding_eval [ 2.8%] 27232/969088 windows running_bpb=1.271293 + sliding_eval [ 3.0%] 28832/969088 windows running_bpb=1.277273 + sliding_eval [ 3.1%] 30432/969088 windows running_bpb=1.274588 + sliding_eval [ 3.3%] 32032/969088 windows running_bpb=1.275893 + sliding_eval [ 3.5%] 33632/969088 windows running_bpb=1.274255 + sliding_eval [ 3.6%] 35232/969088 windows running_bpb=1.273579 + sliding_eval [ 3.8%] 36832/969088 windows running_bpb=1.272868 + sliding_eval [ 4.0%] 38432/969088 windows running_bpb=1.273439 + sliding_eval [ 4.1%] 40032/969088 windows running_bpb=1.270970 + sliding_eval [ 4.3%] 41632/969088 windows running_bpb=1.270063 + sliding_eval [ 4.5%] 43232/969088 windows running_bpb=1.270502 + sliding_eval [ 4.6%] 44832/969088 windows running_bpb=1.269600 + sliding_eval [ 4.8%] 46432/969088 windows running_bpb=1.269449 + sliding_eval [ 5.0%] 48032/969088 windows running_bpb=1.268561 + sliding_eval [ 5.1%] 49632/969088 windows running_bpb=1.269933 + sliding_eval [ 5.3%] 
51232/969088 windows running_bpb=1.271127 + sliding_eval [ 5.5%] 52832/969088 windows running_bpb=1.271732 + sliding_eval [ 5.6%] 54432/969088 windows running_bpb=1.271367 + sliding_eval [ 5.8%] 56032/969088 windows running_bpb=1.271841 + sliding_eval [ 5.9%] 57632/969088 windows running_bpb=1.271210 + sliding_eval [ 6.1%] 59232/969088 windows running_bpb=1.267078 + sliding_eval [ 6.3%] 60832/969088 windows running_bpb=1.267244 + sliding_eval [ 6.4%] 62432/969088 windows running_bpb=1.267915 + sliding_eval [ 6.6%] 64032/969088 windows running_bpb=1.268265 + sliding_eval [ 6.8%] 65632/969088 windows running_bpb=1.268133 + sliding_eval [ 6.9%] 67232/969088 windows running_bpb=1.266928 + sliding_eval [ 7.1%] 68832/969088 windows running_bpb=1.266619 + sliding_eval [ 7.3%] 70432/969088 windows running_bpb=1.265738 + sliding_eval [ 7.4%] 72032/969088 windows running_bpb=1.265808 + sliding_eval [ 7.6%] 73632/969088 windows running_bpb=1.265901 + sliding_eval [ 7.8%] 75232/969088 windows running_bpb=1.266207 + sliding_eval [ 7.9%] 76832/969088 windows running_bpb=1.265891 + sliding_eval [ 8.1%] 78432/969088 windows running_bpb=1.266307 + sliding_eval [ 8.3%] 80032/969088 windows running_bpb=1.266662 + sliding_eval [ 8.4%] 81632/969088 windows running_bpb=1.266539 + sliding_eval [ 8.6%] 83232/969088 windows running_bpb=1.267783 + sliding_eval [ 8.8%] 84832/969088 windows running_bpb=1.269720 + sliding_eval [ 8.9%] 86432/969088 windows running_bpb=1.268847 + sliding_eval [ 9.1%] 88032/969088 windows running_bpb=1.269542 + sliding_eval [ 9.2%] 89632/969088 windows running_bpb=1.269922 + sliding_eval [ 9.4%] 91232/969088 windows running_bpb=1.269789 + sliding_eval [ 9.6%] 92832/969088 windows running_bpb=1.269301 + sliding_eval [ 9.7%] 94432/969088 windows running_bpb=1.269579 + sliding_eval [ 9.9%] 96032/969088 windows running_bpb=1.269098 + sliding_eval [ 10.1%] 97632/969088 windows running_bpb=1.271978 + sliding_eval [ 10.2%] 99232/969088 windows running_bpb=1.272048 + 
sliding_eval [ 10.4%] 100832/969088 windows running_bpb=1.272048 + sliding_eval [ 10.6%] 102432/969088 windows running_bpb=1.271684 + sliding_eval [ 10.7%] 104032/969088 windows running_bpb=1.271232 + sliding_eval [ 10.9%] 105632/969088 windows running_bpb=1.270531 + sliding_eval [ 11.1%] 107232/969088 windows running_bpb=1.270574 + sliding_eval [ 11.2%] 108832/969088 windows running_bpb=1.271209 + sliding_eval [ 11.4%] 110432/969088 windows running_bpb=1.271339 + sliding_eval [ 11.6%] 112032/969088 windows running_bpb=1.271263 + sliding_eval [ 11.7%] 113632/969088 windows running_bpb=1.271749 + sliding_eval [ 11.9%] 115232/969088 windows running_bpb=1.271577 + sliding_eval [ 12.1%] 116832/969088 windows running_bpb=1.271252 + sliding_eval [ 12.2%] 118432/969088 windows running_bpb=1.271533 + sliding_eval [ 12.4%] 120032/969088 windows running_bpb=1.271649 + sliding_eval [ 12.6%] 121632/969088 windows running_bpb=1.271855 + sliding_eval [ 12.7%] 123232/969088 windows running_bpb=1.271852 + sliding_eval [ 12.9%] 124832/969088 windows running_bpb=1.271382 + sliding_eval [ 13.0%] 126432/969088 windows running_bpb=1.271333 + sliding_eval [ 13.2%] 128032/969088 windows running_bpb=1.271278 + sliding_eval [ 13.4%] 129632/969088 windows running_bpb=1.271425 + sliding_eval [ 13.5%] 131232/969088 windows running_bpb=1.271651 + sliding_eval [ 13.7%] 132832/969088 windows running_bpb=1.271228 + sliding_eval [ 13.9%] 134432/969088 windows running_bpb=1.270717 + sliding_eval [ 14.0%] 136032/969088 windows running_bpb=1.269474 + sliding_eval [ 14.2%] 137632/969088 windows running_bpb=1.270013 + sliding_eval [ 14.4%] 139232/969088 windows running_bpb=1.269831 + sliding_eval [ 14.5%] 140832/969088 windows running_bpb=1.270564 + sliding_eval [ 14.7%] 142432/969088 windows running_bpb=1.270960 + sliding_eval [ 14.9%] 144032/969088 windows running_bpb=1.271369 + sliding_eval [ 15.0%] 145632/969088 windows running_bpb=1.271222 + sliding_eval [ 15.2%] 147232/969088 windows 
running_bpb=1.271078 + sliding_eval [ 15.4%] 148832/969088 windows running_bpb=1.270822 + sliding_eval [ 15.5%] 150432/969088 windows running_bpb=1.270562 + sliding_eval [ 15.7%] 152032/969088 windows running_bpb=1.270250 + sliding_eval [ 15.9%] 153632/969088 windows running_bpb=1.271080 + sliding_eval [ 16.0%] 155232/969088 windows running_bpb=1.271040 + sliding_eval [ 16.2%] 156832/969088 windows running_bpb=1.271511 + sliding_eval [ 16.3%] 158432/969088 windows running_bpb=1.271365 + sliding_eval [ 16.5%] 160032/969088 windows running_bpb=1.271826 + sliding_eval [ 16.7%] 161632/969088 windows running_bpb=1.271952 + sliding_eval [ 16.8%] 163232/969088 windows running_bpb=1.271903 + sliding_eval [ 17.0%] 164832/969088 windows running_bpb=1.271854 + sliding_eval [ 17.2%] 166432/969088 windows running_bpb=1.271962 + sliding_eval [ 17.3%] 168032/969088 windows running_bpb=1.271422 + sliding_eval [ 17.5%] 169632/969088 windows running_bpb=1.271394 + sliding_eval [ 17.7%] 171232/969088 windows running_bpb=1.271111 + sliding_eval [ 17.8%] 172832/969088 windows running_bpb=1.270950 + sliding_eval [ 18.0%] 174432/969088 windows running_bpb=1.270949 + sliding_eval [ 18.2%] 176032/969088 windows running_bpb=1.270765 + sliding_eval [ 18.3%] 177632/969088 windows running_bpb=1.271008 + sliding_eval [ 18.5%] 179232/969088 windows running_bpb=1.271340 + sliding_eval [ 18.7%] 180832/969088 windows running_bpb=1.271754 + sliding_eval [ 18.8%] 182432/969088 windows running_bpb=1.272114 + sliding_eval [ 19.0%] 184032/969088 windows running_bpb=1.272663 + sliding_eval [ 19.2%] 185632/969088 windows running_bpb=1.272292 + sliding_eval [ 19.3%] 187232/969088 windows running_bpb=1.272264 + sliding_eval [ 19.5%] 188832/969088 windows running_bpb=1.272443 + sliding_eval [ 19.7%] 190432/969088 windows running_bpb=1.272334 + sliding_eval [ 19.8%] 192032/969088 windows running_bpb=1.272513 + sliding_eval [ 20.0%] 193632/969088 windows running_bpb=1.272583 + sliding_eval [ 20.1%] 
195232/969088 windows running_bpb=1.272061 + sliding_eval [ 20.3%] 196832/969088 windows running_bpb=1.271873 + sliding_eval [ 20.5%] 198432/969088 windows running_bpb=1.272143 + sliding_eval [ 20.6%] 200032/969088 windows running_bpb=1.272345 + sliding_eval [ 20.8%] 201632/969088 windows running_bpb=1.272318 + sliding_eval [ 21.0%] 203232/969088 windows running_bpb=1.272211 + sliding_eval [ 21.1%] 204832/969088 windows running_bpb=1.272028 + sliding_eval [ 21.3%] 206432/969088 windows running_bpb=1.271766 + sliding_eval [ 21.5%] 208032/969088 windows running_bpb=1.271303 + sliding_eval [ 21.6%] 209632/969088 windows running_bpb=1.271095 + sliding_eval [ 21.8%] 211232/969088 windows running_bpb=1.270817 + sliding_eval [ 22.0%] 212832/969088 windows running_bpb=1.271227 + sliding_eval [ 22.1%] 214432/969088 windows running_bpb=1.270967 + sliding_eval [ 22.3%] 216032/969088 windows running_bpb=1.271272 + sliding_eval [ 22.5%] 217632/969088 windows running_bpb=1.271531 + sliding_eval [ 22.6%] 219232/969088 windows running_bpb=1.271631 + sliding_eval [ 22.8%] 220832/969088 windows running_bpb=1.271423 + sliding_eval [ 23.0%] 222432/969088 windows running_bpb=1.271042 + sliding_eval [ 23.1%] 224032/969088 windows running_bpb=1.271145 + sliding_eval [ 23.3%] 225632/969088 windows running_bpb=1.270691 + sliding_eval [ 23.4%] 227232/969088 windows running_bpb=1.270395 + sliding_eval [ 23.6%] 228832/969088 windows running_bpb=1.270886 + sliding_eval [ 23.8%] 230432/969088 windows running_bpb=1.270672 + sliding_eval [ 23.9%] 232032/969088 windows running_bpb=1.270485 + sliding_eval [ 24.1%] 233632/969088 windows running_bpb=1.270197 + sliding_eval [ 24.3%] 235232/969088 windows running_bpb=1.270241 + sliding_eval [ 24.4%] 236832/969088 windows running_bpb=1.270439 + sliding_eval [ 24.6%] 238432/969088 windows running_bpb=1.270469 + sliding_eval [ 24.8%] 240032/969088 windows running_bpb=1.270108 + sliding_eval [ 24.9%] 241632/969088 windows running_bpb=1.269781 + 
sliding_eval [ 25.1%] 243232/969088 windows running_bpb=1.269641 + sliding_eval [ 25.3%] 244832/969088 windows running_bpb=1.269534 + sliding_eval [ 25.4%] 246432/969088 windows running_bpb=1.269605 + sliding_eval [ 25.6%] 248032/969088 windows running_bpb=1.269170 + sliding_eval [ 25.8%] 249632/969088 windows running_bpb=1.269712 + sliding_eval [ 25.9%] 251232/969088 windows running_bpb=1.269651 + sliding_eval [ 26.1%] 252832/969088 windows running_bpb=1.269824 + sliding_eval [ 26.3%] 254432/969088 windows running_bpb=1.269653 + sliding_eval [ 26.4%] 256032/969088 windows running_bpb=1.269329 + sliding_eval [ 26.6%] 257632/969088 windows running_bpb=1.269210 + sliding_eval [ 26.8%] 259232/969088 windows running_bpb=1.268952 + sliding_eval [ 26.9%] 260832/969088 windows running_bpb=1.268704 + sliding_eval [ 27.1%] 262432/969088 windows running_bpb=1.268630 + sliding_eval [ 27.2%] 264032/969088 windows running_bpb=1.268460 + sliding_eval [ 27.4%] 265632/969088 windows running_bpb=1.268535 + sliding_eval [ 27.6%] 267232/969088 windows running_bpb=1.268241 + sliding_eval [ 27.7%] 268832/969088 windows running_bpb=1.268238 + sliding_eval [ 27.9%] 270432/969088 windows running_bpb=1.268634 + sliding_eval [ 28.1%] 272032/969088 windows running_bpb=1.269013 + sliding_eval [ 28.2%] 273632/969088 windows running_bpb=1.268796 + sliding_eval [ 28.4%] 275232/969088 windows running_bpb=1.268599 + sliding_eval [ 28.6%] 276832/969088 windows running_bpb=1.268903 + sliding_eval [ 28.7%] 278432/969088 windows running_bpb=1.268666 + sliding_eval [ 28.9%] 280032/969088 windows running_bpb=1.268596 + sliding_eval [ 29.1%] 281632/969088 windows running_bpb=1.268302 + sliding_eval [ 29.2%] 283232/969088 windows running_bpb=1.268368 + sliding_eval [ 29.4%] 284832/969088 windows running_bpb=1.268234 + sliding_eval [ 29.6%] 286432/969088 windows running_bpb=1.268099 + sliding_eval [ 29.7%] 288032/969088 windows running_bpb=1.268075 + sliding_eval [ 29.9%] 289632/969088 windows 
running_bpb=1.267881 + sliding_eval [ 30.1%] 291232/969088 windows running_bpb=1.267570 + sliding_eval [ 30.2%] 292832/969088 windows running_bpb=1.267587 + sliding_eval [ 30.4%] 294432/969088 windows running_bpb=1.267450 + sliding_eval [ 30.5%] 296032/969088 windows running_bpb=1.267525 + sliding_eval [ 30.7%] 297632/969088 windows running_bpb=1.267252 + sliding_eval [ 30.9%] 299232/969088 windows running_bpb=1.267328 + sliding_eval [ 31.0%] 300832/969088 windows running_bpb=1.267040 + sliding_eval [ 31.2%] 302432/969088 windows running_bpb=1.266652 + sliding_eval [ 31.4%] 304032/969088 windows running_bpb=1.266765 + sliding_eval [ 31.5%] 305632/969088 windows running_bpb=1.266762 + sliding_eval [ 31.7%] 307232/969088 windows running_bpb=1.266703 + sliding_eval [ 31.9%] 308832/969088 windows running_bpb=1.266470 + sliding_eval [ 32.0%] 310432/969088 windows running_bpb=1.266428 + sliding_eval [ 32.2%] 312032/969088 windows running_bpb=1.266326 + sliding_eval [ 32.4%] 313632/969088 windows running_bpb=1.266206 + sliding_eval [ 32.5%] 315232/969088 windows running_bpb=1.266215 + sliding_eval [ 32.7%] 316832/969088 windows running_bpb=1.266327 + sliding_eval [ 32.9%] 318432/969088 windows running_bpb=1.266007 + sliding_eval [ 33.0%] 320032/969088 windows running_bpb=1.265925 + sliding_eval [ 33.2%] 321632/969088 windows running_bpb=1.265878 + sliding_eval [ 33.4%] 323232/969088 windows running_bpb=1.265597 + sliding_eval [ 33.5%] 324832/969088 windows running_bpb=1.265278 + sliding_eval [ 33.7%] 326432/969088 windows running_bpb=1.265078 + sliding_eval [ 33.8%] 328032/969088 windows running_bpb=1.265165 + sliding_eval [ 34.0%] 329632/969088 windows running_bpb=1.265265 + sliding_eval [ 34.2%] 331232/969088 windows running_bpb=1.264920 + sliding_eval [ 34.3%] 332832/969088 windows running_bpb=1.264635 + sliding_eval [ 34.5%] 334432/969088 windows running_bpb=1.264501 + sliding_eval [ 34.7%] 336032/969088 windows running_bpb=1.264476 + sliding_eval [ 34.8%] 
337632/969088 windows running_bpb=1.264350 + sliding_eval [ 35.0%] 339232/969088 windows running_bpb=1.264462 + sliding_eval [ 35.2%] 340832/969088 windows running_bpb=1.264235 + sliding_eval [ 35.3%] 342432/969088 windows running_bpb=1.264160 + sliding_eval [ 35.5%] 344032/969088 windows running_bpb=1.263768 + sliding_eval [ 35.7%] 345632/969088 windows running_bpb=1.263487 + sliding_eval [ 35.8%] 347232/969088 windows running_bpb=1.263377 + sliding_eval [ 36.0%] 348832/969088 windows running_bpb=1.263193 + sliding_eval [ 36.2%] 350432/969088 windows running_bpb=1.263019 + sliding_eval [ 36.3%] 352032/969088 windows running_bpb=1.263175 + sliding_eval [ 36.5%] 353632/969088 windows running_bpb=1.263404 + sliding_eval [ 36.7%] 355232/969088 windows running_bpb=1.263090 + sliding_eval [ 36.8%] 356832/969088 windows running_bpb=1.263001 + sliding_eval [ 37.0%] 358432/969088 windows running_bpb=1.262676 + sliding_eval [ 37.2%] 360032/969088 windows running_bpb=1.262326 + sliding_eval [ 37.3%] 361632/969088 windows running_bpb=1.262160 + sliding_eval [ 37.5%] 363232/969088 windows running_bpb=1.262465 + sliding_eval [ 37.6%] 364832/969088 windows running_bpb=1.262464 + sliding_eval [ 37.8%] 366432/969088 windows running_bpb=1.262292 + sliding_eval [ 38.0%] 368032/969088 windows running_bpb=1.262200 + sliding_eval [ 38.1%] 369632/969088 windows running_bpb=1.262211 + sliding_eval [ 38.3%] 371232/969088 windows running_bpb=1.262190 + sliding_eval [ 38.5%] 372832/969088 windows running_bpb=1.262254 + sliding_eval [ 38.6%] 374432/969088 windows running_bpb=1.262580 + sliding_eval [ 38.8%] 376032/969088 windows running_bpb=1.262475 + sliding_eval [ 39.0%] 377632/969088 windows running_bpb=1.262592 + sliding_eval [ 39.1%] 379232/969088 windows running_bpb=1.262508 + sliding_eval [ 39.3%] 380832/969088 windows running_bpb=1.262255 + sliding_eval [ 39.5%] 382432/969088 windows running_bpb=1.262257 + sliding_eval [ 39.6%] 384032/969088 windows running_bpb=1.262043 + 
sliding_eval [ 39.8%] 385632/969088 windows running_bpb=1.262182 + sliding_eval [ 40.0%] 387232/969088 windows running_bpb=1.262201 + sliding_eval [ 40.1%] 388832/969088 windows running_bpb=1.262286 + sliding_eval [ 40.3%] 390432/969088 windows running_bpb=1.262194 + sliding_eval [ 40.5%] 392032/969088 windows running_bpb=1.262177 + sliding_eval [ 40.6%] 393632/969088 windows running_bpb=1.262183 + sliding_eval [ 40.8%] 395232/969088 windows running_bpb=1.262009 + sliding_eval [ 40.9%] 396832/969088 windows running_bpb=1.262231 + sliding_eval [ 41.1%] 398432/969088 windows running_bpb=1.262281 + sliding_eval [ 41.3%] 400032/969088 windows running_bpb=1.262257 + sliding_eval [ 41.4%] 401632/969088 windows running_bpb=1.262213 + sliding_eval [ 41.6%] 403232/969088 windows running_bpb=1.262109 + sliding_eval [ 41.8%] 404832/969088 windows running_bpb=1.262173 + sliding_eval [ 41.9%] 406432/969088 windows running_bpb=1.261979 + sliding_eval [ 42.1%] 408032/969088 windows running_bpb=1.262050 + sliding_eval [ 42.3%] 409632/969088 windows running_bpb=1.262077 + sliding_eval [ 42.4%] 411232/969088 windows running_bpb=1.261955 + sliding_eval [ 42.6%] 412832/969088 windows running_bpb=1.262052 + sliding_eval [ 42.8%] 414432/969088 windows running_bpb=1.262088 + sliding_eval [ 42.9%] 416032/969088 windows running_bpb=1.262073 + sliding_eval [ 43.1%] 417632/969088 windows running_bpb=1.261927 + sliding_eval [ 43.3%] 419232/969088 windows running_bpb=1.261883 + sliding_eval [ 43.4%] 420832/969088 windows running_bpb=1.262109 + sliding_eval [ 43.6%] 422432/969088 windows running_bpb=1.262078 + sliding_eval [ 43.8%] 424032/969088 windows running_bpb=1.261874 + sliding_eval [ 43.9%] 425632/969088 windows running_bpb=1.261815 + sliding_eval [ 44.1%] 427232/969088 windows running_bpb=1.261651 + sliding_eval [ 44.3%] 428832/969088 windows running_bpb=1.261653 + sliding_eval [ 44.4%] 430432/969088 windows running_bpb=1.261610 + sliding_eval [ 44.6%] 432032/969088 windows 
running_bpb=1.261759 + sliding_eval [ 44.7%] 433632/969088 windows running_bpb=1.261769 + sliding_eval [ 44.9%] 435232/969088 windows running_bpb=1.261667 + sliding_eval [ 45.1%] 436832/969088 windows running_bpb=1.261853 + sliding_eval [ 45.2%] 438432/969088 windows running_bpb=1.261848 + sliding_eval [ 45.4%] 440032/969088 windows running_bpb=1.261822 + sliding_eval [ 45.6%] 441632/969088 windows running_bpb=1.261969 + sliding_eval [ 45.7%] 443232/969088 windows running_bpb=1.261944 + sliding_eval [ 45.9%] 444832/969088 windows running_bpb=1.262014 + sliding_eval [ 46.1%] 446432/969088 windows running_bpb=1.262152 + sliding_eval [ 46.2%] 448032/969088 windows running_bpb=1.262126 + sliding_eval [ 46.4%] 449632/969088 windows running_bpb=1.262140 + sliding_eval [ 46.6%] 451232/969088 windows running_bpb=1.262234 + sliding_eval [ 46.7%] 452832/969088 windows running_bpb=1.262298 + sliding_eval [ 46.9%] 454432/969088 windows running_bpb=1.262055 + sliding_eval [ 47.1%] 456032/969088 windows running_bpb=1.261852 + sliding_eval [ 47.2%] 457632/969088 windows running_bpb=1.262055 + sliding_eval [ 47.4%] 459232/969088 windows running_bpb=1.261977 + sliding_eval [ 47.6%] 460832/969088 windows running_bpb=1.261978 + sliding_eval [ 47.7%] 462432/969088 windows running_bpb=1.261829 + sliding_eval [ 47.9%] 464032/969088 windows running_bpb=1.261724 + sliding_eval [ 48.0%] 465632/969088 windows running_bpb=1.261794 + sliding_eval [ 48.2%] 467232/969088 windows running_bpb=1.261777 + sliding_eval [ 48.4%] 468832/969088 windows running_bpb=1.261847 + sliding_eval [ 48.5%] 470432/969088 windows running_bpb=1.261857 + sliding_eval [ 48.7%] 472032/969088 windows running_bpb=1.261974 + sliding_eval [ 48.9%] 473632/969088 windows running_bpb=1.261856 + sliding_eval [ 49.0%] 475232/969088 windows running_bpb=1.261868 + sliding_eval [ 49.2%] 476832/969088 windows running_bpb=1.261894 + sliding_eval [ 49.4%] 478432/969088 windows running_bpb=1.261816 + sliding_eval [ 49.5%] 
480032/969088 windows running_bpb=1.262246 + sliding_eval [ 49.7%] 481632/969088 windows running_bpb=1.262176 + sliding_eval [ 49.9%] 483232/969088 windows running_bpb=1.262277 + sliding_eval [ 50.0%] 484832/969088 windows running_bpb=1.262588 + sliding_eval [ 50.2%] 486432/969088 windows running_bpb=1.262634 + sliding_eval [ 50.4%] 488032/969088 windows running_bpb=1.262520 + sliding_eval [ 50.5%] 489632/969088 windows running_bpb=1.262679 + sliding_eval [ 50.7%] 491232/969088 windows running_bpb=1.262610 + sliding_eval [ 50.9%] 492832/969088 windows running_bpb=1.262802 + sliding_eval [ 51.0%] 494432/969088 windows running_bpb=1.262954 + sliding_eval [ 51.2%] 496032/969088 windows running_bpb=1.263181 + sliding_eval [ 51.4%] 497632/969088 windows running_bpb=1.263216 + sliding_eval [ 51.5%] 499232/969088 windows running_bpb=1.263289 + sliding_eval [ 51.7%] 500832/969088 windows running_bpb=1.263310 + sliding_eval [ 51.8%] 502432/969088 windows running_bpb=1.263358 + sliding_eval [ 52.0%] 504032/969088 windows running_bpb=1.263434 + sliding_eval [ 52.2%] 505632/969088 windows running_bpb=1.263406 + sliding_eval [ 52.3%] 507232/969088 windows running_bpb=1.263272 + sliding_eval [ 52.5%] 508832/969088 windows running_bpb=1.263375 + sliding_eval [ 52.7%] 510432/969088 windows running_bpb=1.263439 + sliding_eval [ 52.8%] 512032/969088 windows running_bpb=1.263522 + sliding_eval [ 53.0%] 513632/969088 windows running_bpb=1.263625 + sliding_eval [ 53.2%] 515232/969088 windows running_bpb=1.263655 + sliding_eval [ 53.3%] 516832/969088 windows running_bpb=1.263817 + sliding_eval [ 53.5%] 518432/969088 windows running_bpb=1.263766 + sliding_eval [ 53.7%] 520032/969088 windows running_bpb=1.263836 + sliding_eval [ 53.8%] 521632/969088 windows running_bpb=1.263838 + sliding_eval [ 54.0%] 523232/969088 windows running_bpb=1.264082 + sliding_eval [ 54.2%] 524832/969088 windows running_bpb=1.264247 + sliding_eval [ 54.3%] 526432/969088 windows running_bpb=1.264312 + 
sliding_eval [ 54.5%] 528032/969088 windows running_bpb=1.264463 + sliding_eval [ 54.7%] 529632/969088 windows running_bpb=1.264536 + sliding_eval [ 54.8%] 531232/969088 windows running_bpb=1.264549 + sliding_eval [ 55.0%] 532832/969088 windows running_bpb=1.264798 + sliding_eval [ 55.1%] 534432/969088 windows running_bpb=1.264701 + sliding_eval [ 55.3%] 536032/969088 windows running_bpb=1.264638 + sliding_eval [ 55.5%] 537632/969088 windows running_bpb=1.264678 + sliding_eval [ 55.6%] 539232/969088 windows running_bpb=1.264807 + sliding_eval [ 55.8%] 540832/969088 windows running_bpb=1.264809 + sliding_eval [ 56.0%] 542432/969088 windows running_bpb=1.264753 + sliding_eval [ 56.1%] 544032/969088 windows running_bpb=1.264694 + sliding_eval [ 56.3%] 545632/969088 windows running_bpb=1.264905 + sliding_eval [ 56.5%] 547232/969088 windows running_bpb=1.265018 + sliding_eval [ 56.6%] 548832/969088 windows running_bpb=1.264920 + sliding_eval [ 56.8%] 550432/969088 windows running_bpb=1.264933 + sliding_eval [ 57.0%] 552032/969088 windows running_bpb=1.265006 + sliding_eval [ 57.1%] 553632/969088 windows running_bpb=1.264937 + sliding_eval [ 57.3%] 555232/969088 windows running_bpb=1.265157 + sliding_eval [ 57.5%] 556832/969088 windows running_bpb=1.265247 + sliding_eval [ 57.6%] 558432/969088 windows running_bpb=1.265219 + sliding_eval [ 57.8%] 560032/969088 windows running_bpb=1.265299 + sliding_eval [ 58.0%] 561632/969088 windows running_bpb=1.265433 + sliding_eval [ 58.1%] 563232/969088 windows running_bpb=1.265374 + sliding_eval [ 58.3%] 564832/969088 windows running_bpb=1.265277 + sliding_eval [ 58.5%] 566432/969088 windows running_bpb=1.265269 + sliding_eval [ 58.6%] 568032/969088 windows running_bpb=1.265165 + sliding_eval [ 58.8%] 569632/969088 windows running_bpb=1.265129 + sliding_eval [ 58.9%] 571232/969088 windows running_bpb=1.265112 + sliding_eval [ 59.1%] 572832/969088 windows running_bpb=1.264986 + sliding_eval [ 59.3%] 574432/969088 windows 
running_bpb=1.264766 + sliding_eval [ 59.4%] 576032/969088 windows running_bpb=1.264664 + sliding_eval [ 59.6%] 577632/969088 windows running_bpb=1.264721 + sliding_eval [ 59.8%] 579232/969088 windows running_bpb=1.264722 + sliding_eval [ 59.9%] 580832/969088 windows running_bpb=1.264482 + sliding_eval [ 60.1%] 582432/969088 windows running_bpb=1.264403 + sliding_eval [ 60.3%] 584032/969088 windows running_bpb=1.264419 + sliding_eval [ 60.4%] 585632/969088 windows running_bpb=1.264433 + sliding_eval [ 60.6%] 587232/969088 windows running_bpb=1.264417 + sliding_eval [ 60.8%] 588832/969088 windows running_bpb=1.264425 + sliding_eval [ 60.9%] 590432/969088 windows running_bpb=1.264371 + sliding_eval [ 61.1%] 592032/969088 windows running_bpb=1.264296 + sliding_eval [ 61.3%] 593632/969088 windows running_bpb=1.264346 + sliding_eval [ 61.4%] 595232/969088 windows running_bpb=1.264281 + sliding_eval [ 61.6%] 596832/969088 windows running_bpb=1.264313 + sliding_eval [ 61.8%] 598432/969088 windows running_bpb=1.264045 + sliding_eval [ 61.9%] 600032/969088 windows running_bpb=1.263984 + sliding_eval [ 62.1%] 601632/969088 windows running_bpb=1.263901 + sliding_eval [ 62.2%] 603232/969088 windows running_bpb=1.263755 + sliding_eval [ 62.4%] 604832/969088 windows running_bpb=1.263714 + sliding_eval [ 62.6%] 606432/969088 windows running_bpb=1.263713 + sliding_eval [ 62.7%] 608032/969088 windows running_bpb=1.263807 + sliding_eval [ 62.9%] 609632/969088 windows running_bpb=1.263825 + sliding_eval [ 63.1%] 611232/969088 windows running_bpb=1.264050 + sliding_eval [ 63.2%] 612832/969088 windows running_bpb=1.264034 + sliding_eval [ 63.4%] 614432/969088 windows running_bpb=1.264008 + sliding_eval [ 63.6%] 616032/969088 windows running_bpb=1.263978 + sliding_eval [ 63.7%] 617632/969088 windows running_bpb=1.263826 + sliding_eval [ 63.9%] 619232/969088 windows running_bpb=1.263568 + sliding_eval [ 64.1%] 620832/969088 windows running_bpb=1.263770 + sliding_eval [ 64.2%] 
622432/969088 windows running_bpb=1.263931 + sliding_eval [ 64.4%] 624032/969088 windows running_bpb=1.264076 + sliding_eval [ 64.6%] 625632/969088 windows running_bpb=1.263926 + sliding_eval [ 64.7%] 627232/969088 windows running_bpb=1.263917 + sliding_eval [ 64.9%] 628832/969088 windows running_bpb=1.263871 + sliding_eval [ 65.1%] 630432/969088 windows running_bpb=1.263846 + sliding_eval [ 65.2%] 632032/969088 windows running_bpb=1.263641 + sliding_eval [ 65.4%] 633632/969088 windows running_bpb=1.263553 + sliding_eval [ 65.5%] 635232/969088 windows running_bpb=1.263489 + sliding_eval [ 65.7%] 636832/969088 windows running_bpb=1.263446 + sliding_eval [ 65.9%] 638432/969088 windows running_bpb=1.263295 + sliding_eval [ 66.0%] 640032/969088 windows running_bpb=1.263028 + sliding_eval [ 66.2%] 641632/969088 windows running_bpb=1.262839 + sliding_eval [ 66.4%] 643232/969088 windows running_bpb=1.262761 + sliding_eval [ 66.5%] 644832/969088 windows running_bpb=1.262727 + sliding_eval [ 66.7%] 646432/969088 windows running_bpb=1.262672 + sliding_eval [ 66.9%] 648032/969088 windows running_bpb=1.262638 + sliding_eval [ 67.0%] 649632/969088 windows running_bpb=1.262520 + sliding_eval [ 67.2%] 651232/969088 windows running_bpb=1.262336 + sliding_eval [ 67.4%] 652832/969088 windows running_bpb=1.262216 + sliding_eval [ 67.5%] 654432/969088 windows running_bpb=1.262046 + sliding_eval [ 67.7%] 656032/969088 windows running_bpb=1.262035 + sliding_eval [ 67.9%] 657632/969088 windows running_bpb=1.261947 + sliding_eval [ 68.0%] 659232/969088 windows running_bpb=1.261900 + sliding_eval [ 68.2%] 660832/969088 windows running_bpb=1.261762 + sliding_eval [ 68.4%] 662432/969088 windows running_bpb=1.261752 + sliding_eval [ 68.5%] 664032/969088 windows running_bpb=1.261836 + sliding_eval [ 68.7%] 665632/969088 windows running_bpb=1.261678 + sliding_eval [ 68.9%] 667232/969088 windows running_bpb=1.261578 + sliding_eval [ 69.0%] 668832/969088 windows running_bpb=1.261568 + 
sliding_eval [ 69.2%] 670432/969088 windows running_bpb=1.261394 + sliding_eval [ 69.3%] 672032/969088 windows running_bpb=1.261277 + sliding_eval [ 69.5%] 673632/969088 windows running_bpb=1.261238 + sliding_eval [ 69.7%] 675232/969088 windows running_bpb=1.261065 + sliding_eval [ 69.8%] 676832/969088 windows running_bpb=1.260937 + sliding_eval [ 70.0%] 678432/969088 windows running_bpb=1.260830 + sliding_eval [ 70.2%] 680032/969088 windows running_bpb=1.260760 + sliding_eval [ 70.3%] 681632/969088 windows running_bpb=1.260727 + sliding_eval [ 70.5%] 683232/969088 windows running_bpb=1.260661 + sliding_eval [ 70.7%] 684832/969088 windows running_bpb=1.260529 + sliding_eval [ 70.8%] 686432/969088 windows running_bpb=1.260546 + sliding_eval [ 71.0%] 688032/969088 windows running_bpb=1.260542 + sliding_eval [ 71.2%] 689632/969088 windows running_bpb=1.260448 + sliding_eval [ 71.3%] 691232/969088 windows running_bpb=1.260416 + sliding_eval [ 71.5%] 692832/969088 windows running_bpb=1.260424 + sliding_eval [ 71.7%] 694432/969088 windows running_bpb=1.260478 + sliding_eval [ 71.8%] 696032/969088 windows running_bpb=1.260560 + sliding_eval [ 72.0%] 697632/969088 windows running_bpb=1.260597 + sliding_eval [ 72.2%] 699232/969088 windows running_bpb=1.260789 + sliding_eval [ 72.3%] 700832/969088 windows running_bpb=1.260762 + sliding_eval [ 72.5%] 702432/969088 windows running_bpb=1.260844 + sliding_eval [ 72.6%] 704032/969088 windows running_bpb=1.260918 + sliding_eval [ 72.8%] 705632/969088 windows running_bpb=1.261072 + sliding_eval [ 73.0%] 707232/969088 windows running_bpb=1.261081 + sliding_eval [ 73.1%] 708832/969088 windows running_bpb=1.261198 + sliding_eval [ 73.3%] 710432/969088 windows running_bpb=1.261121 + sliding_eval [ 73.5%] 712032/969088 windows running_bpb=1.260857 + sliding_eval [ 73.6%] 713632/969088 windows running_bpb=1.260964 + sliding_eval [ 73.8%] 715232/969088 windows running_bpb=1.260791 + sliding_eval [ 74.0%] 716832/969088 windows 
running_bpb=1.260885 + sliding_eval [ 74.1%] 718432/969088 windows running_bpb=1.260894 + sliding_eval [ 74.3%] 720032/969088 windows running_bpb=1.260986 + sliding_eval [ 74.5%] 721632/969088 windows running_bpb=1.261096 + sliding_eval [ 74.6%] 723232/969088 windows running_bpb=1.261061 + sliding_eval [ 74.8%] 724832/969088 windows running_bpb=1.261124 + sliding_eval [ 75.0%] 726432/969088 windows running_bpb=1.261119 + sliding_eval [ 75.1%] 728032/969088 windows running_bpb=1.261122 + sliding_eval [ 75.3%] 729632/969088 windows running_bpb=1.261062 + sliding_eval [ 75.5%] 731232/969088 windows running_bpb=1.261026 + sliding_eval [ 75.6%] 732832/969088 windows running_bpb=1.261082 + sliding_eval [ 75.8%] 734432/969088 windows running_bpb=1.261188 + sliding_eval [ 76.0%] 736032/969088 windows running_bpb=1.261347 + sliding_eval [ 76.1%] 737632/969088 windows running_bpb=1.261573 + sliding_eval [ 76.3%] 739232/969088 windows running_bpb=1.261591 + sliding_eval [ 76.4%] 740832/969088 windows running_bpb=1.261546 + sliding_eval [ 76.6%] 742432/969088 windows running_bpb=1.261508 + sliding_eval [ 76.8%] 744032/969088 windows running_bpb=1.261483 + sliding_eval [ 76.9%] 745632/969088 windows running_bpb=1.261381 + sliding_eval [ 77.1%] 747232/969088 windows running_bpb=1.261416 + sliding_eval [ 77.3%] 748832/969088 windows running_bpb=1.261450 + sliding_eval [ 77.4%] 750432/969088 windows running_bpb=1.261508 + sliding_eval [ 77.6%] 752032/969088 windows running_bpb=1.261987 + sliding_eval [ 77.8%] 753632/969088 windows running_bpb=1.262073 + sliding_eval [ 77.9%] 755232/969088 windows running_bpb=1.262066 + sliding_eval [ 78.1%] 756832/969088 windows running_bpb=1.261993 + sliding_eval [ 78.3%] 758432/969088 windows running_bpb=1.261973 + sliding_eval [ 78.4%] 760032/969088 windows running_bpb=1.262267 + sliding_eval [ 78.6%] 761632/969088 windows running_bpb=1.262375 + sliding_eval [ 78.8%] 763232/969088 windows running_bpb=1.262359 + sliding_eval [ 78.9%] 
764832/969088 windows running_bpb=1.262417 + sliding_eval [ 79.1%] 766432/969088 windows running_bpb=1.262378 + sliding_eval [ 79.3%] 768032/969088 windows running_bpb=1.262387 + sliding_eval [ 79.4%] 769632/969088 windows running_bpb=1.262423 + sliding_eval [ 79.6%] 771232/969088 windows running_bpb=1.262489 + sliding_eval [ 79.7%] 772832/969088 windows running_bpb=1.262454 + sliding_eval [ 79.9%] 774432/969088 windows running_bpb=1.262402 + sliding_eval [ 80.1%] 776032/969088 windows running_bpb=1.262430 + sliding_eval [ 80.2%] 777632/969088 windows running_bpb=1.262588 + sliding_eval [ 80.4%] 779232/969088 windows running_bpb=1.262601 + sliding_eval [ 80.6%] 780832/969088 windows running_bpb=1.262630 + sliding_eval [ 80.7%] 782432/969088 windows running_bpb=1.262837 + sliding_eval [ 80.9%] 784032/969088 windows running_bpb=1.262847 + sliding_eval [ 81.1%] 785632/969088 windows running_bpb=1.262763 + sliding_eval [ 81.2%] 787232/969088 windows running_bpb=1.262770 + sliding_eval [ 81.4%] 788832/969088 windows running_bpb=1.262904 + sliding_eval [ 81.6%] 790432/969088 windows running_bpb=1.262938 + sliding_eval [ 81.7%] 792032/969088 windows running_bpb=1.263061 + sliding_eval [ 81.9%] 793632/969088 windows running_bpb=1.263134 + sliding_eval [ 82.1%] 795232/969088 windows running_bpb=1.263151 + sliding_eval [ 82.2%] 796832/969088 windows running_bpb=1.263203 + sliding_eval [ 82.4%] 798432/969088 windows running_bpb=1.263219 + sliding_eval [ 82.6%] 800032/969088 windows running_bpb=1.263308 + sliding_eval [ 82.7%] 801632/969088 windows running_bpb=1.263352 + sliding_eval [ 82.9%] 803232/969088 windows running_bpb=1.263336 + sliding_eval [ 83.1%] 804832/969088 windows running_bpb=1.263400 + sliding_eval [ 83.2%] 806432/969088 windows running_bpb=1.263407 + sliding_eval [ 83.4%] 808032/969088 windows running_bpb=1.263518 + sliding_eval [ 83.5%] 809632/969088 windows running_bpb=1.263576 + sliding_eval [ 83.7%] 811232/969088 windows running_bpb=1.263671 + 
sliding_eval [ 83.9%] 812832/969088 windows running_bpb=1.263616 + sliding_eval [ 84.0%] 814432/969088 windows running_bpb=1.263598 + sliding_eval [ 84.2%] 816032/969088 windows running_bpb=1.263657 + sliding_eval [ 84.4%] 817632/969088 windows running_bpb=1.263743 + sliding_eval [ 84.5%] 819232/969088 windows running_bpb=1.263722 + sliding_eval [ 84.7%] 820832/969088 windows running_bpb=1.263733 + sliding_eval [ 84.9%] 822432/969088 windows running_bpb=1.263815 + sliding_eval [ 85.0%] 824032/969088 windows running_bpb=1.263872 + sliding_eval [ 85.2%] 825632/969088 windows running_bpb=1.263957 + sliding_eval [ 85.4%] 827232/969088 windows running_bpb=1.263982 + sliding_eval [ 85.5%] 828832/969088 windows running_bpb=1.263887 + sliding_eval [ 85.7%] 830432/969088 windows running_bpb=1.263761 + sliding_eval [ 85.9%] 832032/969088 windows running_bpb=1.263811 + sliding_eval [ 86.0%] 833632/969088 windows running_bpb=1.263864 + sliding_eval [ 86.2%] 835232/969088 windows running_bpb=1.263797 + sliding_eval [ 86.4%] 836832/969088 windows running_bpb=1.263736 + sliding_eval [ 86.5%] 838432/969088 windows running_bpb=1.263825 + sliding_eval [ 86.7%] 840032/969088 windows running_bpb=1.263822 + sliding_eval [ 86.8%] 841632/969088 windows running_bpb=1.263915 + sliding_eval [ 87.0%] 843232/969088 windows running_bpb=1.263953 + sliding_eval [ 87.2%] 844832/969088 windows running_bpb=1.263828 + sliding_eval [ 87.3%] 846432/969088 windows running_bpb=1.263993 + sliding_eval [ 87.5%] 848032/969088 windows running_bpb=1.264043 + sliding_eval [ 87.7%] 849632/969088 windows running_bpb=1.264029 + sliding_eval [ 87.8%] 851232/969088 windows running_bpb=1.264025 + sliding_eval [ 88.0%] 852832/969088 windows running_bpb=1.264106 + sliding_eval [ 88.2%] 854432/969088 windows running_bpb=1.264184 + sliding_eval [ 88.3%] 856032/969088 windows running_bpb=1.264189 + sliding_eval [ 88.5%] 857632/969088 windows running_bpb=1.264230 + sliding_eval [ 88.7%] 859232/969088 windows 
running_bpb=1.264197 + sliding_eval [ 88.8%] 860832/969088 windows running_bpb=1.264344 + sliding_eval [ 89.0%] 862432/969088 windows running_bpb=1.264327 + sliding_eval [ 89.2%] 864032/969088 windows running_bpb=1.264333 + sliding_eval [ 89.3%] 865632/969088 windows running_bpb=1.264425 + sliding_eval [ 89.5%] 867232/969088 windows running_bpb=1.264424 + sliding_eval [ 89.7%] 868832/969088 windows running_bpb=1.264363 + sliding_eval [ 89.8%] 870432/969088 windows running_bpb=1.264533 + sliding_eval [ 90.0%] 872032/969088 windows running_bpb=1.264541 + sliding_eval [ 90.1%] 873632/969088 windows running_bpb=1.264541 + sliding_eval [ 90.3%] 875232/969088 windows running_bpb=1.264564 + sliding_eval [ 90.5%] 876832/969088 windows running_bpb=1.264401 + sliding_eval [ 90.6%] 878432/969088 windows running_bpb=1.264392 + sliding_eval [ 90.8%] 880032/969088 windows running_bpb=1.264353 + sliding_eval [ 91.0%] 881632/969088 windows running_bpb=1.264358 + sliding_eval [ 91.1%] 883232/969088 windows running_bpb=1.264360 + sliding_eval [ 91.3%] 884832/969088 windows running_bpb=1.264395 + sliding_eval [ 91.5%] 886432/969088 windows running_bpb=1.264398 + sliding_eval [ 91.6%] 888032/969088 windows running_bpb=1.264343 + sliding_eval [ 91.8%] 889632/969088 windows running_bpb=1.264260 + sliding_eval [ 92.0%] 891232/969088 windows running_bpb=1.264132 + sliding_eval [ 92.1%] 892832/969088 windows running_bpb=1.264068 + sliding_eval [ 92.3%] 894432/969088 windows running_bpb=1.264037 + sliding_eval [ 92.5%] 896032/969088 windows running_bpb=1.263937 + sliding_eval [ 92.6%] 897632/969088 windows running_bpb=1.263984 + sliding_eval [ 92.8%] 899232/969088 windows running_bpb=1.264008 + sliding_eval [ 93.0%] 900832/969088 windows running_bpb=1.263964 + sliding_eval [ 93.1%] 902432/969088 windows running_bpb=1.263944 + sliding_eval [ 93.3%] 904032/969088 windows running_bpb=1.263930 + sliding_eval [ 93.5%] 905632/969088 windows running_bpb=1.263898 + sliding_eval [ 93.6%] 
907232/969088 windows running_bpb=1.263938 + sliding_eval [ 93.8%] 908832/969088 windows running_bpb=1.263859 + sliding_eval [ 93.9%] 910432/969088 windows running_bpb=1.263817 + sliding_eval [ 94.1%] 912032/969088 windows running_bpb=1.263781 + sliding_eval [ 94.3%] 913632/969088 windows running_bpb=1.263612 + sliding_eval [ 94.4%] 915232/969088 windows running_bpb=1.263472 + sliding_eval [ 94.6%] 916832/969088 windows running_bpb=1.263443 + sliding_eval [ 94.8%] 918432/969088 windows running_bpb=1.263391 + sliding_eval [ 94.9%] 920032/969088 windows running_bpb=1.263375 + sliding_eval [ 95.1%] 921632/969088 windows running_bpb=1.263382 + sliding_eval [ 95.3%] 923232/969088 windows running_bpb=1.263347 + sliding_eval [ 95.4%] 924832/969088 windows running_bpb=1.263302 + sliding_eval [ 95.6%] 926432/969088 windows running_bpb=1.263281 + sliding_eval [ 95.8%] 928032/969088 windows running_bpb=1.263311 + sliding_eval [ 95.9%] 929632/969088 windows running_bpb=1.263380 + sliding_eval [ 96.1%] 931232/969088 windows running_bpb=1.263341 + sliding_eval [ 96.3%] 932832/969088 windows running_bpb=1.263254 + sliding_eval [ 96.4%] 934432/969088 windows running_bpb=1.263193 + sliding_eval [ 96.6%] 936032/969088 windows running_bpb=1.263131 + sliding_eval [ 96.8%] 937632/969088 windows running_bpb=1.263094 + sliding_eval [ 96.9%] 939232/969088 windows running_bpb=1.263299 + sliding_eval [ 97.1%] 940832/969088 windows running_bpb=1.263200 + sliding_eval [ 97.2%] 942432/969088 windows running_bpb=1.263132 + sliding_eval [ 97.4%] 944032/969088 windows running_bpb=1.263015 + sliding_eval [ 97.6%] 945632/969088 windows running_bpb=1.262919 + sliding_eval [ 97.7%] 947232/969088 windows running_bpb=1.262910 + sliding_eval [ 97.9%] 948832/969088 windows running_bpb=1.262871 + sliding_eval [ 98.1%] 950432/969088 windows running_bpb=1.262847 + sliding_eval [ 98.2%] 952032/969088 windows running_bpb=1.262792 + sliding_eval [ 98.4%] 953632/969088 windows running_bpb=1.262799 + 
sliding_eval [ 98.6%] 955232/969088 windows running_bpb=1.262786 + sliding_eval [ 98.7%] 956832/969088 windows running_bpb=1.262777 + sliding_eval [ 98.9%] 958432/969088 windows running_bpb=1.262755 + sliding_eval [ 99.1%] 960032/969088 windows running_bpb=1.262655 + sliding_eval [ 99.2%] 961632/969088 windows running_bpb=1.262523 + sliding_eval [ 99.4%] 963232/969088 windows running_bpb=1.262467 + sliding_eval [ 99.6%] 964832/969088 windows running_bpb=1.262467 + sliding_eval [ 99.7%] 966432/969088 windows running_bpb=1.262387 + sliding_eval [ 99.9%] 968032/969088 windows running_bpb=1.262462 +final_int8_zlib_roundtrip val_loss:2.1317 val_bpb:1.2625 eval_time:1346945ms +final_int8_zlib_roundtrip_exact val_loss:2.13166442 val_bpb:1.26249508 diff --git a/1x H100 SXM5 Logs/score_first_ttt.log b/1x H100 SXM5 Logs/score_first_ttt.log new file mode 100644 index 0000000000..e69de29bb2 diff --git a/1x H100 SXM5 Logs/sp8192_full_stack.log b/1x H100 SXM5 Logs/sp8192_full_stack.log new file mode 100644 index 0000000000..e69de29bb2 diff --git a/1x H100 SXM5 Logs/warmdown3500_20260401_062854.log b/1x H100 SXM5 Logs/warmdown3500_20260401_062854.log new file mode 100644 index 0000000000..e69de29bb2 diff --git a/1x H100 SXM5 Logs/xsa_ema.log b/1x H100 SXM5 Logs/xsa_ema.log new file mode 100644 index 0000000000..e69de29bb2 diff --git a/Graphs/PLAN_beat_SOTA.md b/Graphs/PLAN_beat_SOTA.md new file mode 100644 index 0000000000..b6f7bb8d18 --- /dev/null +++ b/Graphs/PLAN_beat_SOTA.md @@ -0,0 +1,156 @@ +# Plan: Beat SOTA (1.1428 bpb) + +**Date**: 2026-03-21 +**Current SOTA**: 1.1428 (thwu1, PR #180) +**Emerging**: 1.1303 (PR #254), 1.1307 (PR #265) — not yet on leaderboard +**Our target**: < 1.13 bpb + +--- + +## Strategy: Combine proven techniques nobody has stacked together yet + +The key insight from analyzing all PRs: **no single submission combines ALL the best techniques**. Each top entry uses a subset. We stack them all. 
+ +--- + +## The Stack + +### Layer 1: Base Architecture (from thwu1 #180) +- 10-11 layers, dim=512, 8 heads, 4 KV heads (GQA) +- MLP 3x (hidden=1536), ReLU-squared +- U-Net skip connections +- Tied embeddings (FP16 passthrough) +- Logit softcap=30 + +### Layer 2: Quantization (from thwu1 #180) +- Int5 for MLP weights (saves ~1.86MB for extra layer/features) +- Int6 for attention weights +- zstd-22 compression +- 3% magnitude pruning post-training (better compression) +- WD=0.04 for quantization robustness + +### Layer 3: Input Augmentation (from thwu1 #180 + #265) +- BigramHash(10240) buckets, dim=128, projected to 512 +- SmearGate (proven compatible, +0.005-0.008) + +### Layer 4: Training Optimization (best of all PRs) +- Muon: lr=0.02, WD=0.04, momentum warmup 0.92→0.99 over 1500 steps (from #265) +- SWA: start_frac=0.4, every=50 steps (from thwu1) +- OrthoInit + muP scaling +- Warmdown=3000, warmup=20, grad_clip=0.3 +- Seq2048, batch=524K tokens (from #236 — more gradient updates) + +### Layer 5: Speed (from #265 + modded-nanogpt) +- FlashAttention 3 (Hopper native) — ~5% faster steps +- Fused Linear+ReLU^2 Triton kernel — ~10% MLP speedup +- torch.compile mode="max-autotune" + +### Layer 6: Eval-Time (from #265 + #267) +- Sliding window eval (stride=64) +- Partial XSA on last 3 layers (from #265, +0.002 bpb, only 2ms/step) +- Causal TTT: SGD on val chunks after scoring (from #267, +0.003 bpb) + +### Layer 7: Free Training Signal +- MTP auxiliary head (predict t+2, t+3) — discarded at save, zero artifact cost +- From PR #88 — provides gradient enrichment during training + +--- + +## Expected Impact Breakdown + +| Technique | bpb gain over baseline | Source | +|-----------|----------------------|--------| +| Int5/6 + MLP3x + 10L | ~0.08 | thwu1 baseline | +| BigramHash(10240) | ~0.01 | thwu1 | +| SmearGate | ~0.006 | PR #162 | +| SWA | ~0.005 | thwu1 | +| OrthoInit + muP | ~0.004 | PR #198 | +| Sliding Window | ~0.03 | All top PRs | +| Seq2048 | ~0.015 | PR #198 
| +| Smaller batch (524K) | ~0.003 | PR #236 | +| FA3 + fused kernels (more steps) | ~0.005 | PR #265 | +| Partial XSA (last 3 layers) | ~0.002 | PR #265 | +| Causal TTT | ~0.003 | PR #267 | +| MTP auxiliary | ~0.002 | PR #88 | +| **Total from 1.2244 baseline** | **~0.165** | | +| **Projected bpb** | **~1.06-1.10** | | + +Conservative estimate: **1.10-1.12 bpb** (not everything stacks perfectly). + +--- + +## Implementation Phases + +### Phase 1: Fork SOTA code (~2 hours) +- Take thwu1's train_gpt.py from PR #180 as base +- Verify it reproduces 1.1428 on 8xH100 (10 min run, ~$3) +- This becomes our baseline to improve upon + +### Phase 2: Add proven extras (~3 hours) +- Add SmearGate (if not already in thwu1's code) +- Add Muon momentum warmup (0.92→0.99) +- Switch to batch=524K +- Add FlashAttention 3 +- Test on 1xH100 for quick validation + +### Phase 3: Add novel techniques (~4 hours) +- Implement Partial XSA on last 3 layers (from PR #265) +- Add MTP auxiliary head (from PR #88) +- Add fused Triton kernels (Linear+ReLU^2, softcapped CE) +- Test on 1xH100 + +### Phase 4: Eval-time optimization (~2 hours) +- Implement Causal TTT (SGD, 3 epochs per chunk) +- Tune TTT hyperparameters (lr, momentum, epochs) +- Test on 1xH100 + +### Phase 5: Record attempt (~$20) +- Full run on 8xH100, 10 min +- Submit to record track +- If < 1.13 → PR to openai/parameter-golf + +--- + +## Compute Budget + +| Phase | Hardware | Time | Cost | +|-------|----------|------|------| +| Phase 1 | 8xH100 | 15 min | ~$5 | +| Phase 2 | 1xH100 | 30 min | ~$2 | +| Phase 3 | 1xH100 | 1 hour | ~$4 | +| Phase 4 | 1xH100 | 30 min | ~$2 | +| Phase 5 | 8xH100 | 15 min | ~$5 | +| Buffer | — | — | ~$5 | +| **Total** | | | **~$23** | + +--- + +## What Makes This Novel + +Nobody has combined ALL of these: +1. Int5/Int6 mixed quant + 10-11L (thwu1) +2. + Partial XSA (PR #265, brand new technique) +3. + MTP auxiliary training (PR #88, free signal) +4. + Causal TTT (PR #267) +5. 
+ FA3 + fused Triton kernels (modded-nanogpt) +6. + Optimized batch size (PR #236) + +Each top PR uses 3-4 of these. We use all 6+. + +--- + +## Risk Assessment + +| Risk | Mitigation | +|------|-----------| +| Techniques don't stack as expected | Phase-by-phase testing on 1xH100 | +| XSA + TTT conflict | Test independently first | +| Int5 fragile with new techniques | Fall back to Int6 if quant degrades | +| Compute budget overrun | 1xH100 validation before 8xH100 record | +| FA3 install issues on RunPod | FA3 may already be in the template; fall back to FA2 | + +--- + +## Immediate Next Step + +Pull thwu1's code from PR #180 and start Phase 1. diff --git a/Graphs/speed_optimizations.md b/Graphs/speed_optimizations.md new file mode 100644 index 0000000000..c58770bbec --- /dev/null +++ b/Graphs/speed_optimizations.md @@ -0,0 +1,99 @@ +# Speed Optimizations: Triton Kernels & Libraries + +**Goal**: More training steps in the same wallclock = better bpb + +--- + +## Priority 1: FlashAttention 3 (~5% step time reduction) + +**What**: H100-optimized attention using Hopper async TMA + warp specialization +**Speedup**: 1.5-2x over FA2 in attention forward, ~5% overall step time +**Integration**: Drop-in replacement +```python +from flash_attn_interface import flash_attn_func as flash_attn_3_func +``` +**Status**: Proven — PRs #198 and #164 use this. Only external library in top submissions. 
+**Install**: `pip install flash-attn --no-build-isolation` (from hopper branch) + +--- + +## Priority 2: Fused Linear+ReLU^2 Triton Kernel (~5-15% MLP speedup) + +**What**: Fuses CastedLinear + relu().square() into one Triton kernel +**Source**: modded-nanogpt `triton_kernels.FusedLinearReLUSquareFunction` +**Why it helps**: Eliminates intermediate tensor materialization in MLP (which is 3x expanded) +**Integration**: Copy Triton kernel, replace MLP forward pass +**Status**: Used in modded-nanogpt speedrun, not yet in any Parameter Golf PR + +--- + +## Priority 3: Fused Softcapped Cross-Entropy (~2-5% loss speedup) + +**What**: Fuses logit_softcap + cross_entropy into one Triton kernel +**Source**: modded-nanogpt `triton_kernels.FusedSoftcappedCrossEntropy` +**Why it helps**: Avoids materializing softcapped logits tensor +**Integration**: Copy Triton kernel, replace loss computation +**Note**: Only applies to non-MoS path (MoS uses nll_loss on log-probs) +**Status**: Used in modded-nanogpt speedrun, not yet in any Parameter Golf PR + +--- + +## Priority 4: torch.compile Tuning (0-5% overall) + +```python +# Current +torch.compile(model, dynamic=False, fullgraph=True) + +# Try +torch.compile(model, dynamic=False, fullgraph=True, mode="max-autotune") +``` + +Also set env var: +```bash +export PYTORCH_ALLOC_CONF="expandable_segments:True" +``` + +--- + +## Priority 5: Gradient Checkpointing (enables larger batch/seq) + +**What**: Recompute activations in backward pass instead of storing them +**Benefit**: 50-70% activation memory reduction, enables seq=2048 or larger batch on 1xH100 +**Cost**: ~20-33% more compute (5-10% wall-clock in practice) +**When to use**: If moving to seq=2048+ on 1xH100 + +--- + +## Priority 6: Custom Triton MoS Kernel (if MoS proves useful) + +**What**: Fuse log_softmax over K components + logsumexp mixture into one kernel +**Expected**: Reduce MoS overhead from ~5ms to ~2-3ms per step +**Effort**: ~50-100 lines of Triton, based on fused 
softmax tutorial +**Note**: The bigger bottleneck is the K einsum matmuls, not the softmax + +--- + +## NOT Worth It at Our Scale + +| Technique | Why Skip | +|-----------|----------| +| FP8 training (torchao) | dim=512 matrices too small, overhead > benefit | +| Fused RMSNorm | torch.compile already fuses it | +| Apex FusedAdam | Already using fused=True, marginal gain | +| Liger FusedCE | Logit tensor tiny at vocab=1024 | +| bitsandbytes 8-bit optimizer | Model too small to benefit | + +--- + +## Impact Estimate + +| Optimization | Step Time Reduction | Extra Steps in 10min | bpb Impact | +|-------------|--------------------|--------------------|------------| +| FA3 | ~5% | +1000 steps | ~0.005 bpb | +| Fused MLP | ~10% | +2000 steps | ~0.008 bpb | +| Fused CE | ~3% | +600 steps | ~0.002 bpb | +| max-autotune | ~2% | +400 steps | ~0.001 bpb | +| **Combined** | **~20%** | **+4000 steps** | **~0.015 bpb** | + +At current ~500ms/step, 20% reduction = 400ms/step = ~1500 steps in 10min → ~1875 steps. +On 8xH100 at ~27ms/step, 20% = ~22ms/step = ~27,300 steps vs ~22,200. diff --git a/data/train_tokenizer.py b/data/train_tokenizer.py new file mode 100644 index 0000000000..44f4b0b415 --- /dev/null +++ b/data/train_tokenizer.py @@ -0,0 +1,655 @@ +#!/usr/bin/env python3 +""" +Tokenizer Trainer for Parameter Golf Competition + +Trains custom SentencePiece tokenizers on FineWeb data, evaluates quality, +and optionally exports binary shards compatible with train_gpt.py. + +Supports both BPE and Unigram model types. Research suggests Unigram often +outperforms BPE at small vocab sizes (512-4096) due to its top-down pruning +strategy selecting globally-useful tokens vs BPE's greedy bottom-up merges. 
+ +Usage: + python train_tokenizer.py --vocab-size 1024 --model-type bpe + python train_tokenizer.py --vocab-size 1024 --model-type unigram + python train_tokenizer.py --compare # Compare BPE vs Unigram across vocab sizes + python train_tokenizer.py --vocab-size 1024 --export-shards # Train + export .bin shards + python train_tokenizer.py --evaluate path/to/model.model # Evaluate existing tokenizer + +Integration: + The trained .model file plugs directly into train_gpt.py: + VOCAB_SIZE=1024 python train_gpt.py --tokenizer-path ./tokenizers/spm_bpe_1024.model +""" + +import argparse +import json +import math +import os +import sys +import time +from collections import Counter +from pathlib import Path +from typing import Any, Dict, List, Optional + +import numpy as np + +# ============================================================================= +# CONFIGURATION +# ============================================================================= + +FINEWEB_DOCS_PATH = Path(os.environ.get( + "DOCS_JSONL_PATH", + "./data/docs_selected.jsonl" # Default for RunPod; override for local use +)) +OUTPUT_DIR = Path("./tokenizers") + +# Binary shard format (must match train_gpt.py / download_hf_docs_and_tokenize.py) +DATAFILE_MAGIC = 20240520 +DATAFILE_VERSION = 1 +SHARD_SIZE = 10**8 # 100M tokens per shard +NUM_VAL_DOCS = 50_000 +APPEND_EOS = False + +# Training data size recommendations by vocab size +# Larger vocabs need more data to see enough merge candidates +TRAIN_DOCS_COUNT = { + 512: 50_000, + 1024: 100_000, + 2048: 200_000, + 4096: 500_000, +} + +# Sample sentences for manual inspection of tokenization quality +SAMPLE_SENTENCES = [ + "The quick brown fox jumps over the lazy dog.", + "Machine learning models require careful hyperparameter tuning.", + "In 2024, researchers published 3,847 papers on language models.", + "def fibonacci(n): return n if n <= 1 else fibonacci(n-1) + fibonacci(n-2)", + "The café served crème brûlée for €12.50 — absolutely délicieux!", + 
"HTTP/1.1 200 OK\nContent-Type: application/json\n{\"status\": \"success\"}", + "∀x ∈ ℝ, |sin(x)| ≤ 1", +] + + +# ============================================================================= +# DATA LOADING +# ============================================================================= + +def iter_docs_jsonl(docs_path: Path, max_docs: Optional[int] = None): + """Iterate over texts from FineWeb docs_selected.jsonl (streaming, low memory).""" + if not docs_path.exists(): + raise FileNotFoundError( + f"{docs_path} not found!\n" + "Download FineWeb data first:\n" + " cd ../parameter-golf && python data/cached_challenge_fineweb.py --variant sp1024" + ) + count = 0 + with open(docs_path) as f: + for line in f: + if max_docs is not None and count >= max_docs: + break + line = line.strip() + if line: + yield json.loads(line)["text"] + count += 1 + print(f"Loaded {count:,} documents from {docs_path}") + + +def iter_sentences_for_spm(docs_path: Path, max_docs: Optional[int] = None): + """Yield individual sentences for SentencePiece sentence_iterator. + + SentencePiece's sentence_iterator expects one sentence per yield. + Splitting on newlines gives cleaner training signal than full documents. + """ + for text in iter_docs_jsonl(docs_path, max_docs): + for line in text.split("\n"): + line = line.strip() + if line: + yield line + + +# ============================================================================= +# SENTENCEPIECE TRAINER +# ============================================================================= + +def train_sentencepiece( + docs_path: Path, + vocab_size: int, + model_type: str, + output_dir: Path, + max_docs: Optional[int] = None, + *, + character_coverage: float = 0.995, + max_sentencepiece_length: int = 16, + min_frequency: int = 2, + input_sentence_size: int = 0, +) -> Dict[str, Any]: + """Train a SentencePiece tokenizer (BPE or Unigram). 
+ + Args: + docs_path: Path to docs_selected.jsonl + vocab_size: Target vocabulary size + model_type: 'bpe' or 'unigram' + output_dir: Directory to save model files + max_docs: Max documents to use for training + character_coverage: Unicode character coverage (0.995 recommended for small vocabs) + max_sentencepiece_length: Max token length in chars (prevents overly long tokens) + min_frequency: Minimum frequency for a token to be kept + input_sentence_size: Max sentences to use (0 = all). Set to 5M+ for large corpora. + """ + try: + import sentencepiece as spm + except ImportError: + print("ERROR: sentencepiece not installed. Run: pip install sentencepiece") + return {} + + output_dir.mkdir(parents=True, exist_ok=True) + model_prefix = str(output_dir / f"spm_{model_type}_{vocab_size}") + + print(f"\n{'=' * 60}") + print(f"Training SentencePiece {model_type.upper()} (vocab={vocab_size})") + print(f"{'=' * 60}") + + # Build training kwargs — uses sentence_iterator (streaming) instead of + # writing a temp file, matching the official data pipeline. + kwargs: Dict[str, Any] = { + "sentence_iterator": iter_sentences_for_spm(docs_path, max_docs), + "model_prefix": model_prefix, + "model_type": model_type, + "vocab_size": vocab_size, + # --- Coverage & fallback --- + # 0.995 saves vocab slots vs 0.9995; rare Unicode falls back to bytes. + "character_coverage": character_coverage, + # Critical for small vocab: reserves 256 byte tokens so any byte sequence + # is representable (no output). Costs 256 vocab slots. + "byte_fallback": True, + # --- Splitting rules --- + "split_digits": True, # Each digit 0-9 is its own token + "split_by_unicode_script": True, # Prevent cross-script merges + "split_by_number": True, # Prevent number-letter merges + # --- Normalization --- + # nmt_nfkc collapses Unicode variants (fullwidth chars, ligatures, etc.) + # saving vocab slots. Must match at inference time. 
+ "normalization_rule_name": "nmt_nfkc", + # False matches the official data pipeline. Means "Hello" stays "Hello" + # (no leading ▁), but " Hello" gets ▁. + "add_dummy_prefix": False, + # --- Special token IDs (must match train_gpt.py) --- + "pad_id": 0, + "bos_id": 1, + "eos_id": 2, + "unk_id": 3, + # --- Vocab constraints --- + "hard_vocab_limit": False, + "max_sentencepiece_length": max_sentencepiece_length, + } + + if min_frequency > 0: + # SentencePiece: --min_frequency is only used in BPE mode (ignored for unigram) + kwargs["min_frequency"] = min_frequency + + if input_sentence_size > 0: + kwargs["input_sentence_size"] = input_sentence_size + kwargs["shuffle_input_sentence"] = True + + start_time = time.time() + spm.SentencePieceTrainer.train(**kwargs) + train_time = time.time() - start_time + + print(f"Training completed in {train_time:.1f}s") + + # Load and verify + sp = spm.SentencePieceProcessor(model_file=f"{model_prefix}.model") + + # Count token types + n_byte = sum(1 for i in range(sp.vocab_size()) if sp.is_byte(i)) + n_control = sum(1 for i in range(sp.vocab_size()) if sp.is_control(i)) + n_unknown = sum(1 for i in range(sp.vocab_size()) if sp.is_unknown(i)) + n_learned = sp.vocab_size() - n_byte - n_control - n_unknown + + print(f"Actual vocab size: {sp.vocab_size()}") + print(f" Learned subword tokens: {n_learned}") + print(f" Byte fallback tokens: {n_byte}") + print(f" Control tokens: {n_control}") + print(f" Unknown tokens: {n_unknown}") + + return { + "method": f"sentencepiece_{model_type}", + "model_type": model_type, + "vocab_size": vocab_size, + "actual_vocab": sp.vocab_size(), + "learned_tokens": n_learned, + "byte_tokens": n_byte, + "train_time_sec": train_time, + "model_path": f"{model_prefix}.model", + "vocab_path": f"{model_prefix}.vocab", + } + + +# ============================================================================= +# TOKENIZER EVALUATION +# ============================================================================= + +def 
def evaluate_tokenizer(model_path: str, docs_path: Path, n_eval_docs: int = 5000) -> Dict[str, Any]:
    """Comprehensive evaluation of a trained SentencePiece tokenizer.

    Computes:
    - Compression ratio (bytes per token) — higher is better
    - Fertility (tokens per word) — lower is better
    - Token length distribution
    - Coverage analysis
    - Sample tokenizations for manual inspection
    """
    # Unconditional import: unlike train_sentencepiece, this function requires
    # sentencepiece and raises ImportError if it is missing.
    import sentencepiece as spm

    sp = spm.SentencePieceProcessor(model_file=model_path)

    print(f"\n{'=' * 60}")
    print(f"Evaluating: {model_path}")
    print(f"Vocab size: {sp.vocab_size()}, Eval docs: {n_eval_docs:,}")
    print(f"{'=' * 60}")

    # Running totals over the eval corpus.
    total_bytes = 0
    total_tokens = 0
    total_words = 0
    token_lengths: List[int] = []  # UTF-8 byte length of each token's text
    token_freq: Counter = Counter()

    start = time.time()
    for text in iter_docs_jsonl(docs_path, max_docs=n_eval_docs):
        text_bytes = len(text.encode("utf-8"))
        # Encode twice: ids for frequency stats, pieces for length stats.
        ids = sp.encode(text, out_type=int)
        pieces = sp.encode(text, out_type=str)
        # Fertility uses whitespace-split "words" as the denominator.
        words = text.split()

        total_bytes += text_bytes
        total_tokens += len(ids)
        total_words += len(words)

        for piece in pieces:
            # Strip the SentencePiece word-boundary marker (all leading ▁)
            # so lengths reflect visible text only.
            clean = piece.lstrip("▁")
            token_lengths.append(len(clean.encode("utf-8")))

        token_freq.update(ids)

    eval_time = time.time() - start

    # Core metrics
    bytes_per_token = total_bytes / total_tokens if total_tokens else 0
    fertility = total_tokens / total_words if total_words else 0
    bits_per_byte_ceiling = math.log2(sp.vocab_size())  # Theoretical worst case

    # Token length distribution
    # NOTE(review): assumes at least one token was produced; an empty eval
    # corpus would divide by zero here — confirm docs_path is non-empty.
    lengths_arr = np.array(token_lengths)
    pct_single_byte = np.sum(lengths_arr <= 1) / len(lengths_arr) * 100
    pct_multi_char = np.sum(lengths_arr >= 3) / len(lengths_arr) * 100

    # Vocab utilization: how many unique tokens actually appear
    unique_used = len(token_freq)
    vocab_utilization = unique_used / sp.vocab_size() * 100

    # Top tokens
    top_20 = token_freq.most_common(20)

    print(f"\n--- Compression ---")
    print(f" Bytes per token: {bytes_per_token:.2f} (higher = better)")
    print(f" Tokens per word: {fertility:.2f} (lower = better)")
    print(f" Bits/byte ceiling: {bits_per_byte_ceiling:.2f} (log2(vocab_size))")
    print(f" Effective BPB: ~{bits_per_byte_ceiling / bytes_per_token:.2f} (ceiling / compression)")

    print(f"\n--- Token Length Distribution ---")
    print(f" Mean token length: {lengths_arr.mean():.1f} bytes")
    print(f" Median: {np.median(lengths_arr):.0f} bytes")
    print(f" Single-byte tokens: {pct_single_byte:.1f}%")
    print(f" Multi-char (≥3B): {pct_multi_char:.1f}%")

    print(f"\n--- Vocab Utilization ---")
    print(f" Unique tokens used: {unique_used:,} / {sp.vocab_size()} ({vocab_utilization:.1f}%)")

    print(f"\n--- Top 20 Tokens ---")
    for token_id, count in top_20:
        piece = sp.id_to_piece(token_id)
        print(f" {token_id:5d} | {piece:20s} | {count:,}")

    # Sample tokenizations for eyeballing quality on known-tricky inputs.
    print(f"\n--- Sample Tokenizations ---")
    for sentence in SAMPLE_SENTENCES:
        pieces = sp.encode(sentence, out_type=str)
        ids = sp.encode(sentence, out_type=int)
        print(f"\n Input: {sentence[:80]}")
        print(f" Tokens: {len(ids)}")
        print(f" Pieces: {' | '.join(pieces[:30])}")

    print(f"\n Eval time: {eval_time:.1f}s")

    return {
        "model_path": model_path,
        "vocab_size": sp.vocab_size(),
        "bytes_per_token": bytes_per_token,
        "fertility": fertility,
        "bits_per_byte_ceiling": bits_per_byte_ceiling,
        "mean_token_length_bytes": float(lengths_arr.mean()),
        "pct_single_byte_tokens": pct_single_byte,
        "pct_multi_char_tokens": pct_multi_char,
        "vocab_utilization_pct": vocab_utilization,
        "eval_docs": n_eval_docs,
        "total_tokens": total_tokens,
        "total_bytes": total_bytes,
    }
matching train_gpt.py format.""" + if len(toks) >= 2**31: + raise ValueError("token count too large") + header = np.zeros(256, dtype=" Dict[str, int]: + """Export tokenized binary shards for train_gpt.py. + + Format: [256-int32 header][uint16 token stream] + Each document: [BOS_ID] [encoded tokens] (no EOS by default) + First `num_val_docs` go to val shards, rest to train shards. + """ + import sentencepiece as spm + + sp = spm.SentencePieceProcessor(model_file=model_path) + vocab_size = sp.vocab_size() + + if vocab_size > 2**16: + raise ValueError(f"vocab_size={vocab_size} too large for uint16 shard storage") + + output_dir.mkdir(parents=True, exist_ok=True) + + # Clean stale shards + for pattern in ("fineweb_train_*.bin", "fineweb_val_*.bin"): + for stale in output_dir.glob(pattern): + stale.unlink() + + stats = {k: 0 for k in [ + "docs_total", "docs_val", "docs_train", + "files_total", "files_val", "files_train", + "tokens_total", "tokens_val", "tokens_train", + ]} + + buf = np.empty((shard_size,), dtype=np.uint16) + fill = 0 + split = "val" + shards = {"val": 0, "train": 0} + + def flush(): + nonlocal fill + if fill == 0: + return + path = output_dir / f"fineweb_{split}_{shards[split]:06d}.bin" + write_datafile(path, buf[:fill]) + stats["files_total"] += 1 + stats[f"files_{split}"] += 1 + shards[split] += 1 + fill = 0 + + bos_id = sp.bos_id() + + print(f"\nExporting binary shards to {output_dir}/") + print(f" Tokenizer: {model_path} (vocab={vocab_size})") + print(f" Val docs: {num_val_docs:,}, Shard size: {shard_size:,} tokens") + + for text in iter_docs_jsonl(docs_path): + doc_split = "val" if stats["docs_total"] < num_val_docs else "train" + if doc_split != split: + flush() + split = doc_split + + encoded = np.asarray(sp.encode(text, out_type=int), dtype=np.int32) + toks = np.empty((encoded.size + 1 + int(APPEND_EOS),), dtype=np.int32) + toks[0] = bos_id + toks[1: 1 + encoded.size] = encoded + if APPEND_EOS: + toks[-1] = sp.eos_id() + + if not ((0 <= 
toks).all() and (toks < vocab_size).all()): + bad = int(toks[(toks < 0) | (toks >= vocab_size)][0]) + raise ValueError(f"token id {bad} outside vocab_size={vocab_size}") + toks = toks.astype(" Dict[str, Any]: + """Calculate embedding size impact on the 16MB artifact budget.""" + bytes_fp16 = vocab_size * model_dim * 2 + artifact_budget = 16_000_000 + baseline_vocab = 1024 + baseline_bytes = baseline_vocab * model_dim * 2 + + return { + "vocab_size": vocab_size, + "model_dim": model_dim, + "embedding_bytes_fp16": bytes_fp16, + "embedding_mb_fp16": bytes_fp16 / 1_000_000, + "budget_remaining": artifact_budget - bytes_fp16, + "vs_baseline_1024": bytes_fp16 - baseline_bytes, + "pct_of_budget": bytes_fp16 / artifact_budget * 100, + } + + +# ============================================================================= +# COMPARISON +# ============================================================================= + +def run_comparison(docs_path: Path, vocab_sizes: List[int], model_types: List[str], + max_docs: Optional[int] = None, n_eval_docs: int = 5000): + """Train and evaluate all combinations, print comparison table.""" + results = [] + + for vocab_size in vocab_sizes: + n_docs = max_docs or TRAIN_DOCS_COUNT.get(vocab_size, 100_000) + for model_type in model_types: + info = train_sentencepiece(docs_path, vocab_size, model_type, OUTPUT_DIR, max_docs=n_docs) + if not info: + continue + eval_result = evaluate_tokenizer(info["model_path"], docs_path, n_eval_docs=n_eval_docs) + budget = calc_embedding_budget(vocab_size) + results.append({**info, **eval_result, "budget": budget}) + + # Print comparison table + print(f"\n{'=' * 100}") + print("COMPARISON TABLE") + print(f"{'=' * 100}") + print(f"{'Method':<20} {'Vocab':>6} {'Actual':>6} {'Learned':>7} {'B/Tok':>6} " + f"{'Fert':>6} {'Emb MB':>7} {'Budget%':>8} {'Train(s)':>9}") + print("-" * 100) + + for r in results: + print(f"{r['method']:<20} {r['vocab_size']:>6} {r['actual_vocab']:>6} " + f"{r['learned_tokens']:>7} 
def run_comparison(docs_path: Path, vocab_sizes: List[int], model_types: List[str],
                   max_docs: Optional[int] = None, n_eval_docs: int = 5000):
    """Train and evaluate all combinations, print comparison table.

    For each (vocab_size, model_type) pair: train a tokenizer, evaluate it,
    compute its embedding-budget impact, and collect everything into one row.
    Results are also dumped to tokenizers/comparison_results.json.
    """
    results = []

    for vocab_size in vocab_sizes:
        # Larger vocabs get more training docs (see TRAIN_DOCS_COUNT).
        # NOTE(review): `max_docs or ...` means an explicit 0 also falls back
        # to the table default — confirm that is intended.
        n_docs = max_docs or TRAIN_DOCS_COUNT.get(vocab_size, 100_000)
        for model_type in model_types:
            info = train_sentencepiece(docs_path, vocab_size, model_type, OUTPUT_DIR, max_docs=n_docs)
            if not info:
                # Empty dict => sentencepiece missing; skip this combination.
                continue
            eval_result = evaluate_tokenizer(info["model_path"], docs_path, n_eval_docs=n_eval_docs)
            budget = calc_embedding_budget(vocab_size)
            results.append({**info, **eval_result, "budget": budget})

    # Print comparison table
    print(f"\n{'=' * 100}")
    print("COMPARISON TABLE")
    print(f"{'=' * 100}")
    print(f"{'Method':<20} {'Vocab':>6} {'Actual':>6} {'Learned':>7} {'B/Tok':>6} "
          f"{'Fert':>6} {'Emb MB':>7} {'Budget%':>8} {'Train(s)':>9}")
    print("-" * 100)

    for r in results:
        print(f"{r['method']:<20} {r['vocab_size']:>6} {r['actual_vocab']:>6} "
              f"{r['learned_tokens']:>7} {r['bytes_per_token']:>6.2f} "
              f"{r['fertility']:>6.2f} {r['budget']['embedding_mb_fp16']:>7.2f} "
              f"{r['budget']['pct_of_budget']:>7.1f}% {r['train_time_sec']:>8.1f}")

    # Save results (default=str handles any non-JSON-serializable values)
    results_path = OUTPUT_DIR / "comparison_results.json"
    with open(results_path, "w") as f:
        json.dump(results, f, indent=2, default=str)
    print(f"\nResults saved to {results_path}")

    # Recommendation: best compression = most bytes packed per token.
    if results:
        best = max(results, key=lambda r: r["bytes_per_token"])
        print(f"\nBest compression: {best['method']} vocab={best['vocab_size']} "
              f"({best['bytes_per_token']:.2f} bytes/token)")

    return results


# =============================================================================
# MAIN
# =============================================================================

def main():
    """CLI entry point: evaluate / compare / train (+ optional shard export)."""
    parser = argparse.ArgumentParser(
        description="Train custom SentencePiece tokenizers for Parameter Golf",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
 # Train BPE tokenizer with vocab 1024
 python train_tokenizer.py --vocab-size 1024 --model-type bpe

 # Train Unigram tokenizer (often better for small vocabs)
 python train_tokenizer.py --vocab-size 1024 --model-type unigram

 # Compare BPE vs Unigram across vocab sizes
 python train_tokenizer.py --compare

 # Train + export binary shards for train_gpt.py
 python train_tokenizer.py --vocab-size 1024 --model-type bpe --export-shards

 # Evaluate an existing tokenizer
 python train_tokenizer.py --evaluate ./tokenizers/spm_bpe_1024.model
 """,
    )
    parser.add_argument("--vocab-size", type=int, default=1024, help="Target vocabulary size (default: 1024)")
    parser.add_argument("--model-type", type=str, default="bpe", choices=["bpe", "unigram"],
                        help="SentencePiece model type (default: bpe)")
    parser.add_argument("--compare", action="store_true",
                        help="Compare BPE vs Unigram across vocab sizes [512, 1024, 2048, 4096]")
    parser.add_argument("--evaluate", type=str, default=None, metavar="MODEL_PATH",
                        help="Evaluate an existing .model file instead of training")
    parser.add_argument("--export-shards", action="store_true",
                        help="Export binary shards after training (for train_gpt.py)")
    parser.add_argument("--shard-output-dir", type=str, default=None,
                        help="Output directory for binary shards (default: alongside tokenizer)")
    parser.add_argument("--docs-path", type=str, default=None, help="Path to docs_selected.jsonl")
    parser.add_argument("--max-docs", type=int, default=None, help="Max docs for tokenizer training")
    parser.add_argument("--eval-docs", type=int, default=5000, help="Docs for evaluation (default: 5000)")
    parser.add_argument("--character-coverage", type=float, default=0.995,
                        help="Unicode character coverage (default: 0.995)")
    parser.add_argument("--max-token-length", type=int, default=16,
                        help="Max SentencePiece token length in chars (default: 16)")
    parser.add_argument("--input-sentence-size", type=int, default=0,
                        help="Max sentences for SPM training (0=all, set >0 for large corpora)")
    args = parser.parse_args()

    # CLI --docs-path wins over the DOCS_JSONL_PATH env default.
    docs_path = Path(args.docs_path) if args.docs_path else FINEWEB_DOCS_PATH
    OUTPUT_DIR.mkdir(parents=True, exist_ok=True)

    # --- Evaluate existing model ---
    if args.evaluate:
        if not Path(args.evaluate).exists():
            print(f"ERROR: {args.evaluate} not found!")
            sys.exit(1)
        evaluate_tokenizer(args.evaluate, docs_path, n_eval_docs=args.eval_docs)
        return

    # --- Compare mode ---
    if args.compare:
        run_comparison(
            docs_path,
            vocab_sizes=[512, 1024, 2048, 4096],
            model_types=["bpe", "unigram"],
            max_docs=args.max_docs,
            n_eval_docs=args.eval_docs,
        )
        return

    # --- Train single tokenizer ---
    max_docs = args.max_docs or TRAIN_DOCS_COUNT.get(args.vocab_size, 100_000)

    info = train_sentencepiece(
        docs_path,
        args.vocab_size,
        args.model_type,
        OUTPUT_DIR,
        max_docs=max_docs,
        character_coverage=args.character_coverage,
        max_sentencepiece_length=args.max_token_length,
        input_sentence_size=args.input_sentence_size,
    )

    if not info:
        # train_sentencepiece returns {} when sentencepiece is not installed.
        print("ERROR: Training failed!")
        sys.exit(1)

    # Evaluate the tokenizer we just trained.
    eval_result = evaluate_tokenizer(info["model_path"], docs_path, n_eval_docs=args.eval_docs)
    budget = calc_embedding_budget(args.vocab_size)

    # Print budget info
    print(f"\n--- Embedding Budget Impact (d_model=512, FP16) ---")
    print(f" Embedding size: {budget['embedding_mb_fp16']:.2f} MB")
    print(f" % of 16MB budget: {budget['pct_of_budget']:.1f}%")
    print(f" vs baseline (1024): {budget['vs_baseline_1024']:+,} bytes")

    # Save combined train + eval + budget results for this run.
    all_results = {**info, **eval_result, "budget": budget}
    results_path = OUTPUT_DIR / f"result_{args.model_type}_{args.vocab_size}.json"
    with open(results_path, "w") as f:
        json.dump(all_results, f, indent=2, default=str)
    print(f"\nResults saved to {results_path}")

    # Export shards if requested
    if args.export_shards:
        shard_dir = Path(args.shard_output_dir) if args.shard_output_dir else (
            OUTPUT_DIR / f"shards_{args.model_type}_{args.vocab_size}"
        )
        export_binary_shards(info["model_path"], docs_path, shard_dir)

    # Print integration instructions
    print(f"\n--- Integration ---")
    print(f" To use with train_gpt.py:")
    print(f" VOCAB_SIZE={info['actual_vocab']} python train_gpt.py \\")
    print(f" --tokenizer-path {info['model_path']}")
    if not args.export_shards:
        print(f"\n To export binary shards:")
        print(f" python train_tokenizer.py --vocab-size {args.vocab_size} "
              f"--model-type {args.model_type} --export-shards")


if __name__ == "__main__":
    main()
# Background keep-alive: poll the GPU once a minute.
# NOTE(review): presumably this keeps the cloud instance from being reaped as
# idle — confirm the provider's idle-detection behavior.
(while true; do sleep 60; nvidia-smi > /dev/null 2>&1; done) &
# $! expands NOW (double quotes), pinning the keep-alive loop's PID so it is
# killed when the script exits on any path.
trap "kill $! 2>/dev/null" EXIT

# Count attached GPUs; tr strips wc's leading padding on some platforms.
GPU_COUNT=$(nvidia-smi --list-gpus 2>/dev/null | wc -l | tr -d ' ')
log "Detected ${GPU_COUNT} GPUs"

# Resolve the directory this script lives in and work from there.
SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
cd "${SCRIPT_DIR}"

# Clone parameter-golf if needed
if [ ! -d "$HOME/parameter-golf" ]; then
  log "Cloning parameter-golf..."
  git clone https://github.com/openai/parameter-golf.git "$HOME/parameter-golf"
fi

# Build FA3 selectively (~5 min on H100)
# Skipped entirely when the module already imports cleanly.
if ! python3 -c "from flash_attn_interface import flash_attn_func" 2>/dev/null; then
  log "Building Flash Attention 3 (selective, ~5 min)..."

  if [ ! -d "$HOME/flash-attention" ]; then
    git clone https://github.com/Dao-AILab/flash-attention.git "$HOME/flash-attention"
  fi

  cd "$HOME/flash-attention/hopper"
  rm -rf build/
  mkdir -p flash_attn_3

  # Only build bf16 hdim64 SM90 - skip everything else
  # Each flag disables one kernel variant, cutting compile time drastically.
  export FLASH_ATTENTION_DISABLE_FP16=TRUE
  export FLASH_ATTENTION_DISABLE_FP8=TRUE
  export FLASH_ATTENTION_DISABLE_HDIM96=TRUE
  export FLASH_ATTENTION_DISABLE_HDIM128=TRUE
  export FLASH_ATTENTION_DISABLE_HDIM192=TRUE
  export FLASH_ATTENTION_DISABLE_HDIM256=TRUE
  export FLASH_ATTENTION_DISABLE_SM80=TRUE
  export FLASH_ATTENTION_DISABLE_PAGEDKV=TRUE
  export FLASH_ATTENTION_DISABLE_APPENDKV=TRUE
  export FLASH_ATTENTION_DISABLE_SOFTCAP=TRUE
  export FLASH_ATTENTION_DISABLE_PACKGQA=TRUE
  export FLASH_ATTENTION_DISABLE_VARLEN=TRUE
  export FLASH_ATTENTION_DISABLE_SPLIT=TRUE
  export FLASH_ATTENTION_DISABLE_LOCAL=TRUE
  export FLASH_ATTENTION_DISABLE_CLUSTER=TRUE
  export FLASH_ATTENTION_DISABLE_HDIMDIFF64=TRUE
  export FLASH_ATTENTION_DISABLE_HDIMDIFF192=TRUE

  pip install --no-build-isolation --break-system-packages -e .

  # Symlink config to torch (fixes torch.compile backward crash)
  # NOTE(review): assumes the editable install landed in the user site dir —
  # confirm when pip installs system-wide instead.
  SITE_PACKAGES=$(python3 -c "import site; print(site.getusersitepackages())")
  TORCH_PATH=$(python3 -c "import torch; import os; print(os.path.dirname(torch.__file__))")
  ln -sf "${SITE_PACKAGES}/flash_attn_3/flash_attn_config.py" "${TORCH_PATH}/flash_attn_config.py" 2>/dev/null || true

  # Smoke test: fail loudly (set -e) if the build did not produce a working module.
  python3 -c "from flash_attn_interface import flash_attn_func; print('FA3: OK')"
fi

# Download dataset
cd "$HOME/parameter-golf"
log "Downloading FineWeb dataset (8B tokens)..."
python3 data/cached_challenge_fineweb.py --variant sp1024 --train-shards 80

# Symlink data to runpod-testing
# The [ ! -L ... ] guards make re-runs idempotent (skip if link exists).
cd "${SCRIPT_DIR}"
mkdir -p data/datasets data/tokenizers
[ ! -L "data/datasets/fineweb10B_sp1024" ] && \
  ln -s "$HOME/parameter-golf/data/datasets/fineweb10B_sp1024" data/datasets/
[ ! -L "data/tokenizers/fineweb_1024_bpe.model" ] && \
  ln -s "$HOME/parameter-golf/data/tokenizers/fineweb_1024_bpe.model" data/tokenizers/

log ""
log "=== Setup Complete ==="
log "GPUs: ${GPU_COUNT}"
log "FA3: $(python3 -c 'from flash_attn_interface import flash_attn_func; print("OK")' 2>/dev/null || echo 'FAILED')"
log "Dataset: $(ls -1 data/datasets/fineweb10B_sp1024/fineweb_train_*.bin 2>/dev/null | wc -l) train shards"
log ""
log "Ready! Run:"
log "  MODE=mos bash run_mos_sota.sh"
+ +Configuration: +- Track: `non-record`, 1x H100 SXM, 10 min wallclock +- Layout: `VOCAB_SIZE=1024 NUM_LAYERS=9 MODEL_DIM=512 NUM_HEADS=8 NUM_KV_HEADS=4 MLP_MULT=2` +- MoS: `USE_MOS=1 MOS_K=2 MOS_RANK=64` (low-rank factorization, ~99K extra params) +- Tied embeddings, seed=42 + +Command: +```bash +RUN_ID=mos_k2_r64_pilot \ +DATA_PATH=./data/datasets/fineweb10B_sp1024 \ +TOKENIZER_PATH=./data/tokenizers/fineweb_1024_bpe.model \ +VOCAB_SIZE=1024 SEED=42 \ +USE_MOS=1 MOS_K=2 MOS_RANK=64 \ +MAX_WALLCLOCK_SECONDS=600 \ +VAL_LOSS_EVERY=500 TRAIN_LOG_EVERY=100 \ +torchrun --standalone --nproc_per_node=1 train_gpt.py +``` + +Key metrics: +- Stopped at step 1113/20000 (wallclock cap) +- Pre-quant: `val_loss:2.3505 val_bpb:1.3921` +- Post-quant (int8+zlib): `val_loss:2.3523 val_bpb:1.3932` +- Quantization degradation: +0.0011 bpb (minimal) +- Model params: 17,159,240 +- Artifact: 12,764,492 bytes int8+zlib (12.8MB, 3.2MB under 16MB cap) +- Code: 63,345 bytes +- Total: 12,827,837 bytes +- Peak memory: 11,012 MiB allocated +- Step avg: 539ms/step on 1x H100 + +Training curve: +| Step | Train Loss | Val BPB | Time | +|------|-----------|---------|------| +| 0 | 6.93 | 4.11 | 0s | +| 100 | 3.27 | — | 54s | +| 500 | 2.58 | 1.52 | 271s | +| 1000 | 2.40 | 1.40 | 542s | +| 1113 | — | 1.39 | 600s | + +Notes: +- Loss still dropping at wallclock stop — model had more to learn +- No TTT/LoRA eval was run (only int8 roundtrip) +- No same-conditions baseline for direct comparison (8xH100 baseline: ~1.2244 bpb at 20K steps) +- 1x H100 = ~1/8 throughput → only 1113 steps vs ~20K on 8xH100 diff --git a/records/track_non_record_16mb/2026-03-21_MoS_K2_R64_1xH100_10min/submission.json b/records/track_non_record_16mb/2026-03-21_MoS_K2_R64_1xH100_10min/submission.json new file mode 100644 index 0000000000..4112a7586f --- /dev/null +++ b/records/track_non_record_16mb/2026-03-21_MoS_K2_R64_1xH100_10min/submission.json @@ -0,0 +1,31 @@ +{ + "author": "billyendson", + "github_id": "User123331", + 
"name": "MoS K=2 Rank=64 Pilot (1xH100, 10min)", + "blurb": "First pilot of Mixture of Softmax (K=2, low-rank=64) on 1xH100 SXM for 10 minutes. Tests softmax bottleneck breaking with minimal parameter overhead (~99K params, 97KB). Artifact 12.8MB, well under 16MB cap. No TTT eval. Loss still dropping at wallclock stop.", + "date": "2026-03-21T19:48:40Z", + "track": "non-record-unlimited-compute-16mb", + "val_loss": 2.35234121, + "val_bpb": 1.39318897, + "pre_quant_val_loss": 2.3505, + "pre_quant_val_bpb": 1.3921, + "step_stop": 1113, + "wallclock_seconds": 600.423, + "bytes_total": 12827837, + "bytes_model_int8_zlib": 12764492, + "bytes_code": 63345, + "gpu": "1xH100_SXM", + "config": { + "USE_MOS": 1, + "MOS_K": 2, + "MOS_RANK": 64, + "VOCAB_SIZE": 1024, + "NUM_LAYERS": 9, + "MODEL_DIM": 512, + "NUM_HEADS": 8, + "NUM_KV_HEADS": 4, + "MLP_MULT": 2, + "SEED": 42, + "TRAIN_SEQ_LEN": 1024 + } +} diff --git a/records/track_non_record_16mb/2026-03-21_MoS_K2_R64_1xH100_10min/train.log b/records/track_non_record_16mb/2026-03-21_MoS_K2_R64_1xH100_10min/train.log new file mode 100644 index 0000000000..4904cb52fb --- /dev/null +++ b/records/track_non_record_16mb/2026-03-21_MoS_K2_R64_1xH100_10min/train.log @@ -0,0 +1,93 @@ +logs/mos_k2_r64_pilot.txt +val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path=./data/tokenizers/fineweb_1024_bpe.model +train_loader:dataset:fineweb10B_sp1024 train_shards:80 +val_loader:shards pattern=./data/datasets/fineweb10B_sp1024/fineweb_val_*.bin tokens:62021632 +model_params:17159240 +world_size:1 grad_accum_steps:8 +sdp_backends:cudnn=False flash=True mem_efficient=False math=False +attention_mode:gqa num_heads:8 num_kv_heads:4 +tie_embeddings:True embed_lr:0.05 head_lr:0.0 matrix_lr:0.04 scalar_lr:0.04 +train_batch_tokens:524288 train_seq_len:1024 iterations:20000 warmup_steps:20 max_wallclock_seconds:600.000 +seed:42 +/usr/local/lib/python3.12/dist-packages/torch/_inductor/lowering.py:7242: UserWarning: +Online softmax is disabled on 
the fly since Inductor decides to +split the reduction. Cut an issue to PyTorch if this is an +important use case and you want to speed it up with online +softmax. + + warnings.warn( +/usr/local/lib/python3.12/dist-packages/torch/_inductor/lowering.py:7242: UserWarning: +Online softmax is disabled on the fly since Inductor decides to +split the reduction. Cut an issue to PyTorch if this is an +important use case and you want to speed it up with online +softmax. + + warnings.warn( +warmup_step:1/20 +warmup_step:2/20 +warmup_step:3/20 +warmup_step:4/20 +warmup_step:5/20 +warmup_step:6/20 +warmup_step:7/20 +warmup_step:8/20 +warmup_step:9/20 +warmup_step:10/20 +warmup_step:11/20 +warmup_step:12/20 +warmup_step:13/20 +warmup_step:14/20 +warmup_step:15/20 +warmup_step:16/20 +warmup_step:17/20 +warmup_step:18/20 +warmup_step:19/20 +warmup_step:20/20 +/usr/local/lib/python3.12/dist-packages/torch/_inductor/lowering.py:7242: UserWarning: +Online softmax is disabled on the fly since Inductor decides to +split the reduction. Cut an issue to PyTorch if this is an +important use case and you want to speed it up with online +softmax. + + warnings.warn( +/usr/local/lib/python3.12/dist-packages/torch/_inductor/lowering.py:7242: UserWarning: +Online softmax is disabled on the fly since Inductor decides to +split the reduction. Cut an issue to PyTorch if this is an +important use case and you want to speed it up with online +softmax. 
+ + warnings.warn( +step:0/20000 val_loss:6.9314 val_bpb:4.1052 train_time:0ms step_avg:0.01ms +step:1/20000 train_loss:6.9314 train_time:581ms step_avg:581.11ms +step:2/20000 train_loss:6.8515 train_time:1295ms step_avg:647.47ms +step:3/20000 train_loss:5.8655 train_time:1996ms step_avg:665.19ms +step:4/20000 train_loss:5.4250 train_time:2793ms step_avg:698.33ms +step:5/20000 train_loss:5.0728 train_time:3413ms step_avg:682.51ms +step:6/20000 train_loss:4.9797 train_time:4016ms step_avg:669.27ms +step:7/20000 train_loss:4.8555 train_time:4676ms step_avg:668.03ms +step:8/20000 train_loss:4.7612 train_time:5341ms step_avg:667.67ms +step:9/20000 train_loss:4.6900 train_time:5990ms step_avg:665.54ms +step:10/20000 train_loss:4.7029 train_time:6682ms step_avg:668.23ms +step:100/20000 train_loss:3.2746 train_time:54475ms step_avg:544.75ms +step:200/20000 train_loss:2.8511 train_time:108479ms step_avg:542.40ms +step:300/20000 train_loss:2.7046 train_time:162973ms step_avg:543.24ms +step:400/20000 train_loss:2.4804 train_time:217390ms step_avg:543.47ms +step:500/20000 train_loss:2.5755 train_time:271183ms step_avg:542.37ms +step:500/20000 val_loss:2.5703 val_bpb:1.5223 train_time:271193ms step_avg:542.39ms +step:600/20000 train_loss:2.5630 train_time:324786ms step_avg:541.31ms +step:700/20000 train_loss:2.5112 train_time:378359ms step_avg:540.51ms +step:800/20000 train_loss:2.3957 train_time:432963ms step_avg:541.20ms +step:900/20000 train_loss:2.4135 train_time:487589ms step_avg:541.77ms +step:1000/20000 train_loss:2.4031 train_time:542181ms step_avg:542.18ms +step:1000/20000 val_loss:2.3696 val_bpb:1.4034 train_time:542248ms step_avg:542.25ms +step:1100/20000 train_loss:2.3186 train_time:594115ms step_avg:540.10ms +step:1113/20000 val_loss:2.3505 val_bpb:1.3921 train_time:600423ms step_avg:539.46ms +stopping_early: wallclock_cap train_time:600423ms step:1113/20000 +peak memory allocated: 11012 MiB reserved: 11320 MiB +Serialized model: 67623386 bytes +Code size: 63345 
bytes +Total submission size: 67686731 bytes +Serialized model int8+zlib: 12764492 bytes (payload:17377568 raw_torch:17423635 payload_ratio:3.89x) +Total submission size int8+zlib: 12827837 bytes +final_int8_zlib_roundtrip val_loss:2.3523 val_bpb:1.3932 eval_time:11887ms +final_int8_zlib_roundtrip_exact val_loss:2.35234121 val_bpb:1.39318897 diff --git a/records/track_non_record_16mb/2026-03-21_MoS_K2_R64_1xH100_10min/train_gpt.py b/records/track_non_record_16mb/2026-03-21_MoS_K2_R64_1xH100_10min/train_gpt.py new file mode 100644 index 0000000000..40db9e1f6d --- /dev/null +++ b/records/track_non_record_16mb/2026-03-21_MoS_K2_R64_1xH100_10min/train_gpt.py @@ -0,0 +1,1567 @@ +""" +The `train_gpt.py` and `train_gpt_mlx.py` scripts are intended as good launching-off points for new participants, not SOTA configs. We'll accept PRs that tune, improve, or simplify these scripts without significantly increasing complexity, but competitive submissions should stay in the `/records` folder. + +Hard stop: To keep readable for newcomers, let's make sure `train_gpt.py` and `train_gpt_mlx.py` never are longer than 1500 lines. 
+""" + +from __future__ import annotations + +import copy +import glob +import io +import math +import os +import random +import subprocess +import sys +import time +import uuid +import zlib +from pathlib import Path + +import numpy as np +import sentencepiece as spm +import torch +import torch.distributed as dist +import torch.nn.functional as F +from torch import Tensor, nn +from torch.nn.parallel import DistributedDataParallel as DDP + +# ----------------------------- +# HYPERPARAMETERS +# ----------------------------- +# Default Simple Baseline run: +# - 9 transformer blocks at width 512 +# - 8 attention heads with 4 KV heads (GQA) and 2x MLP expansion +# - vocab size 1024, sequence length 1024, tied embeddings +# - 524,288 train tokens per step for 20,000 iterations with a ~10 minute cap + +class Hyperparameters: + # Data paths are shard globs produced by the existing preprocessing pipeline. + data_path = os.environ.get("DATA_PATH", "./data/datasets/fineweb10B_sp1024") + train_files = os.path.join(data_path, "fineweb_train_*.bin") + val_files = os.path.join(data_path, "fineweb_val_*.bin") + tokenizer_path = os.environ.get("TOKENIZER_PATH", "./data/tokenizers/fineweb_1024_bpe.model") + run_id = os.environ.get("RUN_ID", str(uuid.uuid4())) + seed = int(os.environ.get("SEED", 1337)) + + # Validation cadence and batch size. Validation always uses the full fineweb_val split. + val_batch_size = int(os.environ.get("VAL_BATCH_SIZE", 524_288)) + val_loss_every = int(os.environ.get("VAL_LOSS_EVERY", 1000)) + train_log_every = int(os.environ.get("TRAIN_LOG_EVERY", 200)) + + # Training length. 
+ iterations = int(os.environ.get("ITERATIONS", 20000)) + warmdown_iters = int(os.environ.get("WARMDOWN_ITERS", 1200)) + warmup_steps = int(os.environ.get("WARMUP_STEPS", 20)) + train_batch_tokens = int(os.environ.get("TRAIN_BATCH_TOKENS", 524_288)) + train_seq_len = int(os.environ.get("TRAIN_SEQ_LEN", 1024)) + max_wallclock_seconds = float(os.environ.get("MAX_WALLCLOCK_SECONDS", 600.0)) + qk_gain_init = float(os.environ.get("QK_GAIN_INIT", 1.5)) + + # Model shape. + vocab_size = int(os.environ.get("VOCAB_SIZE", 1024)) + num_layers = int(os.environ.get("NUM_LAYERS", 9)) + num_kv_heads = int(os.environ.get("NUM_KV_HEADS", 4)) + model_dim = int(os.environ.get("MODEL_DIM", 512)) + num_heads = int(os.environ.get("NUM_HEADS", 8)) + mlp_mult = int(os.environ.get("MLP_MULT", 2)) + tie_embeddings = bool(int(os.environ.get("TIE_EMBEDDINGS", "1"))) + rope_base = float(os.environ.get("ROPE_BASE", 10000.0)) + logit_softcap = float(os.environ.get("LOGIT_SOFTCAP", 30.0)) + + # Optimizer hyperparameters. + embed_lr = float(os.environ.get("EMBED_LR", 0.6)) + head_lr = float(os.environ.get("HEAD_LR", 0.008)) + tied_embed_lr = float(os.environ.get("TIED_EMBED_LR", 0.05)) + tied_embed_init_std = float(os.environ.get("TIED_EMBED_INIT_STD", 0.005)) + matrix_lr = float(os.environ.get("MATRIX_LR", 0.04)) + scalar_lr = float(os.environ.get("SCALAR_LR", 0.04)) + muon_momentum = float(os.environ.get("MUON_MOMENTUM", 0.95)) + muon_backend_steps = int(os.environ.get("MUON_BACKEND_STEPS", 5)) + muon_momentum_warmup_start = float(os.environ.get("MUON_MOMENTUM_WARMUP_START", 0.85)) + muon_momentum_warmup_steps = int(os.environ.get("MUON_MOMENTUM_WARMUP_STEPS", 500)) + beta1 = float(os.environ.get("BETA1", 0.9)) + beta2 = float(os.environ.get("BETA2", 0.95)) + adam_eps = float(os.environ.get("ADAM_EPS", 1e-8)) + grad_clip_norm = float(os.environ.get("GRAD_CLIP_NORM", 0.0)) + + # Test-time training (LoRA) hyperparameters. 
+ ttt_lora_rank = int(os.environ.get("TTT_LORA_RANK", 8)) + ttt_lora_lr = float(os.environ.get("TTT_LORA_LR", 0.01)) + ttt_chunk_size = int(os.environ.get("TTT_CHUNK_SIZE", 256)) + ttt_eval_seq_len = int(os.environ.get("TTT_EVAL_SEQ_LEN", 1024)) + ttt_batch_size = int(os.environ.get("TTT_BATCH_SIZE", 64)) + + # Mixture of Softmax (MoS) output layer - breaks softmax bottleneck. + # At vocab=1024, dim=512, standard softmax has rank ≤ 513 (binding constraint). + # MoS with K=2 lifts this to rank ≤ 1026, enabling richer output distributions. + use_mos = bool(int(os.environ.get("USE_MOS", "0"))) + mos_k = int(os.environ.get("MOS_K", 2)) + mos_rank = int(os.environ.get("MOS_RANK", 64)) # 0 = full-rank, >0 = low-rank factorization + +# ----------------------------- +# MUON OPTIMIZER +# ----------------------------- +# +# As borrowed from modded-nanogpt +# Background on Muon: https://kellerjordan.github.io/posts/muon/ + +def zeropower_via_newtonschulz5(G: Tensor, steps: int = 10, eps: float = 1e-7) -> Tensor: + # Orthogonalize a 2D update matrix with a fast Newton-Schulz iteration. + # Muon uses this to normalize matrix-shaped gradients before applying them. 
+ a, b, c = (3.4445, -4.7750, 2.0315) + X = G.bfloat16() + X /= X.norm() + eps + transposed = G.size(0) > G.size(1) + if transposed: + X = X.T + for _ in range(steps): + A = X @ X.T + B = b * A + c * A @ A + X = a * X + B @ X + return X.T if transposed else X + + +class Muon(torch.optim.Optimizer): + def __init__(self, params, lr: float, momentum: float, backend_steps: int, nesterov: bool = True): + super().__init__( + params, + dict(lr=lr, momentum=momentum, backend_steps=backend_steps, nesterov=nesterov), + ) + + @torch.no_grad() + def step(self, closure=None): + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + distributed = dist.is_available() and dist.is_initialized() + world_size = dist.get_world_size() if distributed else 1 + rank = dist.get_rank() if distributed else 0 + + for group in self.param_groups: + params = group["params"] + if not params: + continue + lr = group["lr"] + momentum = group["momentum"] + backend_steps = group["backend_steps"] + nesterov = group["nesterov"] + + total_params = sum(int(p.numel()) for p in params) + updates_flat = torch.zeros(total_params, device=params[0].device, dtype=torch.bfloat16) + + curr = 0 + for i, p in enumerate(params): + if i % world_size == rank and p.grad is not None: + g = p.grad + state = self.state[p] + if "momentum_buffer" not in state: + state["momentum_buffer"] = torch.zeros_like(g) + buf = state["momentum_buffer"] + buf.mul_(momentum).add_(g) + if nesterov: + g = g.add(buf, alpha=momentum) + g = zeropower_via_newtonschulz5(g, steps=backend_steps) + # Scale correction from Muon reference implementations. 
+ g *= max(1, g.size(0) / g.size(1)) ** 0.5 + updates_flat[curr : curr + p.numel()] = g.reshape(-1) + curr += p.numel() + + if distributed: + dist.all_reduce(updates_flat, op=dist.ReduceOp.SUM) + + curr = 0 + for p in params: + g = updates_flat[curr : curr + p.numel()].view_as(p).to(dtype=p.dtype) + p.add_(g, alpha=-lr) + curr += p.numel() + + return loss + + +# ----------------------------- +# TOKENIZER-AGNOSTIC EVALUATION SETUP +# ----------------------------- +# +# It's common for small models have a large fraction of their parameters be embeddings, since the 2 * d_model * d_vocab vectors can be gigantic. +# Instead of locking the tokenizer, we let you bring your own and calculate our validation metrics on the average compression of the validation set. +# We calculate BPB (bits-per-byte) instead of validation loss, so we need methods to count the number of bits per token in the tokenizer. +# Note: Submissions that edit the tokenizer will be examined more carefully, since screwing this up might unjustly improve your score. 

def build_sentencepiece_luts(
    sp: spm.SentencePieceProcessor, vocab_size: int, device: torch.device
) -> tuple[Tensor, Tensor, Tensor]:
    """Precompute per-token lookup tables used for bits-per-byte accounting.

    Returns three aligned tables sized max(sp_vocab, vocab_size):
      base byte count of each piece (without any inferred leading space),
      whether the piece carries a leading-space marker ("▁"),
      and whether the token is a boundary token (control/unknown/unused).
    """
    sp_vocab_size = int(sp.vocab_size())
    table_size = max(sp_vocab_size, vocab_size)
    byte_counts = np.zeros((table_size,), dtype=np.int16)
    leading_space = np.zeros((table_size,), dtype=np.bool_)
    boundary = np.ones((table_size,), dtype=np.bool_)
    for token_id in range(sp_vocab_size):
        if sp.is_control(token_id) or sp.is_unknown(token_id) or sp.is_unused(token_id):
            continue
        boundary[token_id] = False
        if sp.is_byte(token_id):
            byte_counts[token_id] = 1
            continue
        piece = sp.id_to_piece(token_id)
        if piece.startswith("▁"):
            leading_space[token_id] = True
            piece = piece[1:]
        byte_counts[token_id] = len(piece.encode("utf-8"))
    return (
        torch.tensor(byte_counts, dtype=torch.int16, device=device),
        torch.tensor(leading_space, dtype=torch.bool, device=device),
        torch.tensor(boundary, dtype=torch.bool, device=device),
    )


def load_validation_tokens(pattern: str, seq_len: int) -> Tensor:
    """Concatenate the validation shards and trim to a whole number of sequences (+1 shift token)."""
    shard_paths = [Path(p) for p in sorted(glob.glob(pattern))]
    if not shard_paths:
        raise FileNotFoundError(f"No files found for pattern: {pattern}")
    # The export pipeline writes the fixed first-50k-doc validation set to fineweb_val_*.
    stream = torch.cat([load_data_shard(path) for path in shard_paths]).contiguous()
    usable = ((stream.numel() - 1) // seq_len) * seq_len
    if usable <= 0:
        raise ValueError(f"Validation split is too short for TRAIN_SEQ_LEN={seq_len}")
    return stream[: usable + 1]


def eval_val(
    args: Hyperparameters,
    model: nn.Module,
    rank: int,
    world_size: int,
    device: torch.device,
    grad_accum_steps: int,
    val_tokens: Tensor,
    base_bytes_lut: Tensor,
    has_leading_space_lut: Tensor,
    is_boundary_token_lut: Tensor,
) -> tuple[float, float]:
    """Run full-split validation and return (val_loss, val_bpb).

    val_loss is token cross-entropy (natural log); val_bpb is the
    tokenizer-agnostic compression metric used by the challenge.
    """
    local_batch_tokens = args.val_batch_size // (world_size * grad_accum_steps)
    if local_batch_tokens < args.train_seq_len:
        raise ValueError(
            "VAL_BATCH_SIZE must provide at least one sequence per rank; "
            f"got VAL_BATCH_SIZE={args.val_batch_size}, WORLD_SIZE={world_size}, "
            f"GRAD_ACCUM_STEPS={grad_accum_steps}, TRAIN_SEQ_LEN={args.train_seq_len}"
        )
    seqs_per_batch = local_batch_tokens // args.train_seq_len
    total_seqs = (val_tokens.numel() - 1) // args.train_seq_len
    # Each rank owns a contiguous, disjoint span of sequences.
    first_seq = (total_seqs * rank) // world_size
    last_seq = (total_seqs * (rank + 1)) // world_size
    loss_acc = torch.zeros((), device=device, dtype=torch.float64)
    tok_acc = torch.zeros((), device=device, dtype=torch.float64)
    byte_acc = torch.zeros((), device=device, dtype=torch.float64)

    model.eval()
    with torch.inference_mode():
        for s0 in range(first_seq, last_seq, seqs_per_batch):
            s1 = min(s0 + seqs_per_batch, last_seq)
            raw_start = s0 * args.train_seq_len
            raw_end = s1 * args.train_seq_len + 1
            local = val_tokens[raw_start:raw_end].to(device=device, dtype=torch.int64, non_blocking=True)
            x = local[:-1].reshape(-1, args.train_seq_len)
            y = local[1:].reshape(-1, args.train_seq_len)
            with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True):
                batch_loss = model(x, y).detach()
            n_targets = float(y.numel())
            loss_acc += batch_loss.to(torch.float64) * n_targets
            tok_acc += n_targets
            # A target token "costs" its piece bytes plus one byte for a leading
            # space, unless the preceding token was a boundary token.
            prev_ids = x.reshape(-1)
            tgt_ids = y.reshape(-1)
            token_bytes = base_bytes_lut[tgt_ids].to(dtype=torch.int16)
            token_bytes += (has_leading_space_lut[tgt_ids] & ~is_boundary_token_lut[prev_ids]).to(dtype=torch.int16)
            byte_acc += token_bytes.to(torch.float64).sum()

    if dist.is_available() and dist.is_initialized():
        dist.all_reduce(loss_acc, op=dist.ReduceOp.SUM)
        dist.all_reduce(tok_acc, op=dist.ReduceOp.SUM)
        dist.all_reduce(byte_acc, op=dist.ReduceOp.SUM)

    val_loss = loss_acc / tok_acc
    bits_per_token = val_loss.item() / math.log(2.0)
    tokens_per_byte = tok_acc.item() / byte_acc.item()
    model.train()
    return float(val_loss.item()), float(bits_per_token * tokens_per_byte)

# -----------------------------
# POST-TRAINING QUANTIZATION
# -----------------------------
#
# It's silly to export our model, which is trained in bf16 and fp32, at that same precision.
# Instead, we get approximately the same model (with a small hit) by quantizing the model to int8 & zlib compressing.
# We can then decompress the model and run in higher precision for evaluation, after closing in under the size limit.
+ +CONTROL_TENSOR_NAME_PATTERNS = tuple( + pattern + for pattern in os.environ.get( + "CONTROL_TENSOR_NAME_PATTERNS", + "attn_scale,attn_scales,mlp_scale,mlp_scales,resid_mix,resid_mixes,q_gain,skip_weight,skip_weights", + ).split(",") + if pattern +) +INT8_KEEP_FLOAT_FP32_NAME_PATTERNS = tuple( + pattern + for pattern in os.environ.get( + "INT8_KEEP_FLOAT_FP32_NAME_PATTERNS", + ",".join(CONTROL_TENSOR_NAME_PATTERNS), + ).split(",") + if pattern +) +INT8_KEEP_FLOAT_MAX_NUMEL = 65_536 +INT8_KEEP_FLOAT_STORE_DTYPE = torch.float16 +INT8_PER_ROW_SCALE_DTYPE = torch.float16 +INT8_CLIP_PERCENTILE = 99.99984 +INT8_CLIP_Q = INT8_CLIP_PERCENTILE / 100.0 + +def tensor_nbytes(t: Tensor) -> int: + return int(t.numel()) * int(t.element_size()) + +def keep_float_tensor(name: str, t: Tensor, passthrough_orig_dtypes: dict[str, str]) -> Tensor: + if any(pattern in name for pattern in INT8_KEEP_FLOAT_FP32_NAME_PATTERNS): + return t.float().contiguous() + if t.dtype in {torch.float32, torch.bfloat16}: + passthrough_orig_dtypes[name] = str(t.dtype).removeprefix("torch.") + return t.to(dtype=INT8_KEEP_FLOAT_STORE_DTYPE).contiguous() + return t + +def quantize_float_tensor(t: Tensor) -> tuple[Tensor, Tensor]: + t32 = t.float() + if t32.ndim == 2: + # Matrices get one scale per row, which usually tracks output-channel + # ranges much better than a single tensor-wide scale. + clip_abs = ( + torch.quantile(t32.abs(), INT8_CLIP_Q, dim=1) + if t32.numel() + else torch.empty((t32.shape[0],), dtype=torch.float32) + ) + clipped = torch.maximum(torch.minimum(t32, clip_abs[:, None]), -clip_abs[:, None]) + scale = (clip_abs / 127.0).clamp_min(1.0 / 127.0) + q = torch.clamp(torch.round(clipped / scale[:, None]), -127, 127).to(torch.int8).contiguous() + return q, scale.to(dtype=INT8_PER_ROW_SCALE_DTYPE).contiguous() + + # Vectors / scalars use a simpler per-tensor scale. 
+ clip_abs = float(torch.quantile(t32.abs().flatten(), INT8_CLIP_Q).item()) if t32.numel() else 0.0 + scale = torch.tensor(clip_abs / 127.0 if clip_abs > 0 else 1.0, dtype=torch.float32) + q = torch.clamp(torch.round(torch.clamp(t32, -clip_abs, clip_abs) / scale), -127, 127).to(torch.int8).contiguous() + return q, scale + +def quantize_state_dict_int8(state_dict: dict[str, Tensor]): + # Single supported clean-script export format: + # - per-row int8 for 2D float tensors + # - per-tensor int8 for other float tensors + # - exact passthrough for non-floats + # - passthrough for small float tensors, stored as fp16 to save bytes + quantized: dict[str, Tensor] = {} + scales: dict[str, Tensor] = {} + dtypes: dict[str, str] = {} + passthrough: dict[str, Tensor] = {} + passthrough_orig_dtypes: dict[str, str] = {} + qmeta: dict[str, dict[str, object]] = {} + stats = dict.fromkeys( + ("param_count", "num_tensors", "num_float_tensors", "num_nonfloat_tensors", "baseline_tensor_bytes", "int8_payload_bytes"), + 0, + ) + + for name, tensor in state_dict.items(): + t = tensor.detach().to("cpu").contiguous() + stats["param_count"] += int(t.numel()) + stats["num_tensors"] += 1 + stats["baseline_tensor_bytes"] += tensor_nbytes(t) + + if not t.is_floating_point(): + stats["num_nonfloat_tensors"] += 1 + passthrough[name] = t + stats["int8_payload_bytes"] += tensor_nbytes(t) + continue + + # Small float tensors are cheap enough to keep directly. We still downcast + # fp32/bf16 passthrough tensors to fp16 so metadata does not dominate size. 
+ if t.numel() <= INT8_KEEP_FLOAT_MAX_NUMEL: + kept = keep_float_tensor(name, t, passthrough_orig_dtypes) + passthrough[name] = kept + stats["int8_payload_bytes"] += tensor_nbytes(kept) + continue + + stats["num_float_tensors"] += 1 + q, s = quantize_float_tensor(t) + if s.ndim > 0: + qmeta[name] = {"scheme": "per_row", "axis": 0} + quantized[name] = q + scales[name] = s + dtypes[name] = str(t.dtype).removeprefix("torch.") + stats["int8_payload_bytes"] += tensor_nbytes(q) + tensor_nbytes(s) + + obj: dict[str, object] = { + "__quant_format__": "int8_clean_per_row_v1", + "quantized": quantized, + "scales": scales, + "dtypes": dtypes, + "passthrough": passthrough, + } + if qmeta: + obj["qmeta"] = qmeta + if passthrough_orig_dtypes: + obj["passthrough_orig_dtypes"] = passthrough_orig_dtypes + return obj, stats + +def dequantize_state_dict_int8(obj: dict[str, object]) -> dict[str, Tensor]: + out: dict[str, Tensor] = {} + qmeta = obj.get("qmeta", {}) + passthrough_orig_dtypes = obj.get("passthrough_orig_dtypes", {}) + for name, q in obj["quantized"].items(): + dtype = getattr(torch, obj["dtypes"][name]) + s = obj["scales"][name] + if qmeta.get(name, {}).get("scheme") == "per_row" or s.ndim > 0: + s = s.to(dtype=torch.float32) + # Broadcast the saved row scale back across trailing dimensions. + out[name] = (q.float() * s.view(q.shape[0], *([1] * (q.ndim - 1)))).to(dtype=dtype).contiguous() + else: + scale = float(s.item()) + out[name] = (q.float() * scale).to(dtype=dtype).contiguous() + for name, t in obj["passthrough"].items(): + # Restore small tensors, undoing the temporary fp16 storage cast if needed. 
+ out_t = t.detach().to("cpu").contiguous() + orig_dtype = passthrough_orig_dtypes.get(name) + if isinstance(orig_dtype, str): + out_t = out_t.to(dtype=getattr(torch, orig_dtype)).contiguous() + out[name] = out_t + return out + + +# ----------------------------- +# DATA LOADING +# ----------------------------- + +def load_data_shard(file: Path) -> Tensor: + header_bytes = 256 * np.dtype(" None: + self.file_idx = (self.file_idx + 1) % len(self.files) + self.tokens = load_data_shard(self.files[self.file_idx]) + self.pos = 0 + + def take(self, n: int) -> Tensor: + chunks: list[Tensor] = [] + remaining = n + while remaining > 0: + avail = self.tokens.numel() - self.pos + if avail <= 0: + self._advance_file() + continue + k = min(remaining, avail) + chunks.append(self.tokens[self.pos : self.pos + k]) + self.pos += k + remaining -= k + return chunks[0] if len(chunks) == 1 else torch.cat(chunks) + + +class DistributedTokenLoader: + # Each call consumes a contiguous chunk from the shared token stream, then slices out + # one disjoint span per rank. The extra "+1" token lets us build (x, y) by shifting. 
+ def __init__(self, pattern: str, rank: int, world_size: int, device: torch.device): + self.rank = rank + self.world_size = world_size + self.device = device + self.stream = TokenStream(pattern) + + def next_batch(self, global_tokens: int, seq_len: int, grad_accum_steps: int) -> tuple[Tensor, Tensor]: + local_tokens = global_tokens // (self.world_size * grad_accum_steps) + per_rank_span = local_tokens + 1 + chunk = self.stream.take(per_rank_span * self.world_size) + start = self.rank * per_rank_span + local = chunk[start : start + per_rank_span].to(dtype=torch.int64) + x = local[:-1].reshape(-1, seq_len) + y = local[1:].reshape(-1, seq_len) + return x.to(self.device, non_blocking=True), y.to(self.device, non_blocking=True) + +# ----------------------------- +# TRANSFORMER MODULES +# ----------------------------- + +class RMSNorm(nn.Module): + def __init__(self, eps: float | None = None): + super().__init__() + self.eps = eps + + def forward(self, x: Tensor) -> Tensor: + return F.rms_norm(x, (x.size(-1),), eps=self.eps) + + +class CastedLinear(nn.Linear): + # Keep weights in fp32 for optimizer/state quality, cast at matmul time for bf16 compute. + def forward(self, x: Tensor) -> Tensor: + bias = self.bias.to(x.dtype) if self.bias is not None else None + return F.linear(x, self.weight.to(x.dtype), bias) + + +def restore_low_dim_params_to_fp32(module: nn.Module) -> None: + # Keep small/control parameters in fp32 even when the model body runs in bf16. + with torch.no_grad(): + for name, param in module.named_parameters(): + if (param.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS)) and param.dtype != torch.float32: + param.data = param.data.float() + + +class Rotary(nn.Module): + # Caches cos/sin tables per sequence length on the current device. 
+ def __init__(self, dim: int, base: float = 10000.0): + super().__init__() + inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim)) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self._seq_len_cached = 0 + self._cos_cached: Tensor | None = None + self._sin_cached: Tensor | None = None + + def forward(self, seq_len: int, device: torch.device, dtype: torch.dtype) -> tuple[Tensor, Tensor]: + if ( + self._cos_cached is None + or self._sin_cached is None + or self._seq_len_cached != seq_len + or self._cos_cached.device != device + ): + t = torch.arange(seq_len, device=device, dtype=self.inv_freq.dtype) + freqs = torch.outer(t, self.inv_freq.to(device)) + self._cos_cached = freqs.cos()[None, None, :, :] + self._sin_cached = freqs.sin()[None, None, :, :] + self._seq_len_cached = seq_len + return self._cos_cached.to(dtype=dtype), self._sin_cached.to(dtype=dtype) + + +def apply_rotary_emb(x: Tensor, cos: Tensor, sin: Tensor) -> Tensor: + half = x.size(-1) // 2 + x1, x2 = x[..., :half], x[..., half:] + return torch.cat((x1 * cos + x2 * sin, x1 * (-sin) + x2 * cos), dim=-1) + + +class CausalSelfAttention(nn.Module): + def __init__( + self, + dim: int, + num_heads: int, + num_kv_heads: int, + rope_base: float, + qk_gain_init: float, + ): + super().__init__() + if dim % num_heads != 0: + raise ValueError("model_dim must be divisible by num_heads") + if num_heads % num_kv_heads != 0: + raise ValueError("num_heads must be divisible by num_kv_heads") + self.num_heads = num_heads + self.num_kv_heads = num_kv_heads + self.head_dim = dim // num_heads + if self.head_dim % 2 != 0: + raise ValueError("head_dim must be even for RoPE") + kv_dim = self.num_kv_heads * self.head_dim + self.c_q = CastedLinear(dim, dim, bias=False) + self.c_k = CastedLinear(dim, kv_dim, bias=False) + self.c_v = CastedLinear(dim, kv_dim, bias=False) + self.proj = CastedLinear(dim, dim, bias=False) + self.proj._zero_init = True + self.q_gain = 
nn.Parameter(torch.full((num_heads,), qk_gain_init, dtype=torch.float32)) + self.rotary = Rotary(self.head_dim, base=rope_base) + + def forward(self, x: Tensor, q_delta=None, v_delta=None) -> Tensor: + bsz, seqlen, dim = x.shape + q = self.c_q(x) + (q_delta if q_delta is not None else 0) + k = self.c_k(x) + v = self.c_v(x) + (v_delta if v_delta is not None else 0) + q = q.reshape(bsz, seqlen, self.num_heads, self.head_dim).transpose(1, 2) + k = k.reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) + v = v.reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) + q = F.rms_norm(q, (q.size(-1),)) + k = F.rms_norm(k, (k.size(-1),)) + cos, sin = self.rotary(seqlen, x.device, q.dtype) + q = apply_rotary_emb(q, cos, sin) + k = apply_rotary_emb(k, cos, sin) + q = q * self.q_gain.to(dtype=q.dtype)[None, :, None, None] + y = F.scaled_dot_product_attention( + q, + k, + v, + attn_mask=None, + is_causal=True, + enable_gqa=(self.num_kv_heads != self.num_heads), + ) + y = y.transpose(1, 2).contiguous().reshape(bsz, seqlen, dim) + return self.proj(y) + + +class MLP(nn.Module): + # relu^2 MLP from the original modded-nanogpt setup + def __init__(self, dim: int, mlp_mult: int): + super().__init__() + hidden = mlp_mult * dim + self.fc = CastedLinear(dim, hidden, bias=False) + self.proj = CastedLinear(hidden, dim, bias=False) + self.proj._zero_init = True + + def forward(self, x: Tensor) -> Tensor: + x = torch.relu(self.fc(x)) + return self.proj(x.square()) + + +class Block(nn.Module): + def __init__( + self, + dim: int, + num_heads: int, + num_kv_heads: int, + mlp_mult: int, + rope_base: float, + qk_gain_init: float, + ): + super().__init__() + self.attn_norm = RMSNorm() + self.mlp_norm = RMSNorm() + self.attn = CausalSelfAttention(dim, num_heads, num_kv_heads, rope_base, qk_gain_init) + self.mlp = MLP(dim, mlp_mult) + self.attn_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.mlp_scale = nn.Parameter(torch.ones(dim, 
dtype=torch.float32)) + self.resid_mix = nn.Parameter(torch.stack((torch.ones(dim), torch.zeros(dim))).float()) + + def forward(self, x: Tensor, x0: Tensor, q_delta_fn=None, v_delta_fn=None) -> Tensor: + mix = self.resid_mix.to(dtype=x.dtype) + x = mix[0][None, None, :] * x + mix[1][None, None, :] * x0 + n = self.attn_norm(x) + qd = q_delta_fn(n) if q_delta_fn is not None else None + vd = v_delta_fn(n) if v_delta_fn is not None else None + attn_out = self.attn(n, qd, vd) + x = x + self.attn_scale.to(dtype=x.dtype)[None, None, :] * attn_out + x = x + self.mlp_scale.to(dtype=x.dtype)[None, None, :] * self.mlp(self.mlp_norm(x)) + return x + + +class MixtureOfSoftmax(nn.Module): + """Mixture of Softmax output layer for breaking the softmax bottleneck. + + At vocab=1024, dim=512, the standard softmax has rank ≤ 513. + MoS with K=2 lifts this to rank ≤ 1026, enabling richer output distributions. + + When mos_rank > 0, uses low-rank factorization to save parameters: + instead of dim -> K*dim projection, uses dim -> rank -> K*dim. + + Paper: Yang et al. (2018), "Breaking the Softmax Bottleneck", ICLR 2018. 
+ """ + + def __init__(self, model_dim: int, vocab_size: int, n_mixtures: int = 2, rank: int = 0): + super().__init__() + self.n_mixtures = n_mixtures + self.model_dim = model_dim + self.vocab_size = vocab_size + self.rank = rank + + if rank > 0: + # Low-rank factorization: dim -> rank -> K*dim + self.proj_down = CastedLinear(model_dim, rank, bias=False) + self.proj_up = CastedLinear(rank, n_mixtures * model_dim, bias=False) + nn.init.normal_(self.proj_down.weight, mean=0.0, std=0.02) + nn.init.normal_(self.proj_up.weight, mean=0.0, std=0.02) + else: + # Full-rank: dim -> K*dim + self.projections = CastedLinear(model_dim, n_mixtures * model_dim, bias=False) + nn.init.normal_(self.projections.weight, mean=0.0, std=0.02) + + # Mixing weight predictor + self.gate = CastedLinear(model_dim, n_mixtures, bias=False) + nn.init.normal_(self.gate.weight, mean=0.0, std=0.02) + + def forward(self, hidden: Tensor, weight_matrix: Tensor) -> Tensor: + """Compute mixed softmax distribution. + + Args: + hidden: (bsz, seq_len, dim) - final hidden states + weight_matrix: (vocab_size, dim) - tied embedding weights + + Returns: + log_probs: (bsz, seq_len, vocab_size) - mixed log probabilities + """ + bsz, seq_len, dim = hidden.shape + K = self.n_mixtures + + # Compute mixing weights: (bsz, seq, K) + pi = F.softmax(self.gate(hidden), dim=-1) + + # Project to K different spaces: (bsz, seq, K * dim) -> (bsz, seq, K, dim) + if self.rank > 0: + projected = self.proj_up(self.proj_down(hidden)).view(bsz, seq_len, K, dim) + else: + projected = self.projections(hidden).view(bsz, seq_len, K, dim) + + # Compute K different logit vectors: (bsz, seq, K, vocab) + logits = torch.einsum('bskd,vd->bskv', projected, weight_matrix) + + # Mix softmax distributions using log-space for numerical stability + log_probs = F.log_softmax(logits, dim=-1) # (bsz, seq, K, vocab) + log_pi = torch.log(pi.unsqueeze(-1) + 1e-10) # (bsz, seq, K, 1) + mixed_log_probs = torch.logsumexp(log_probs + log_pi, dim=2) # (bsz, 
seq, vocab) + + return mixed_log_probs + + +class GPT(nn.Module): + def __init__( + self, + vocab_size: int, + num_layers: int, + model_dim: int, + num_heads: int, + num_kv_heads: int, + mlp_mult: int, + tie_embeddings: bool, + tied_embed_init_std: float, + logit_softcap: float, + rope_base: float, + qk_gain_init: float, + use_mos: bool = False, + mos_k: int = 2, + mos_rank: int = 0, + ): + super().__init__() + if logit_softcap <= 0.0: + raise ValueError(f"logit_softcap must be positive, got {logit_softcap}") + self.tie_embeddings = tie_embeddings + self.tied_embed_init_std = tied_embed_init_std + self.logit_softcap = logit_softcap + self.use_mos = use_mos + self.tok_emb = nn.Embedding(vocab_size, model_dim) + self.num_encoder_layers = num_layers // 2 + self.num_decoder_layers = num_layers - self.num_encoder_layers + self.num_skip_weights = min(self.num_encoder_layers, self.num_decoder_layers) + self.skip_weights = nn.Parameter(torch.ones(self.num_skip_weights, model_dim, dtype=torch.float32)) + self.blocks = nn.ModuleList( + [ + Block( + model_dim, + num_heads, + num_kv_heads, + mlp_mult, + rope_base, + qk_gain_init, + ) + for i in range(num_layers) + ] + ) + self.final_norm = RMSNorm() + # MoS output layer (optional) - breaks softmax bottleneck + if use_mos: + self.mos = MixtureOfSoftmax(model_dim, vocab_size, n_mixtures=mos_k, rank=mos_rank) + else: + self.mos = None + self.lm_head = None if tie_embeddings else CastedLinear(model_dim, vocab_size, bias=False) + if self.lm_head is not None: + self.lm_head._zero_init = True + self._init_weights() + + def _init_weights(self) -> None: + if self.tie_embeddings: + nn.init.normal_(self.tok_emb.weight, mean=0.0, std=self.tied_embed_init_std) + for module in self.modules(): + if isinstance(module, nn.Linear) and getattr(module, "_zero_init", False): + nn.init.zeros_(module.weight) + + def forward(self, input_ids: Tensor, target_ids: Tensor, lora=None) -> Tensor: + x = self.tok_emb(input_ids) + x = F.rms_norm(x, 
(x.size(-1),)) + x0 = x + skips: list[Tensor] = [] + + # First half stores skips; second half reuses them in reverse order. + for i in range(self.num_encoder_layers): + qd = lora.q_loras[i] if lora else None + vd = lora.v_loras[i] if lora else None + x = self.blocks[i](x, x0, qd, vd) + skips.append(x) + for i in range(self.num_decoder_layers): + bi = self.num_encoder_layers + i + if skips: + x = x + self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips.pop() + qd = lora.q_loras[bi] if lora else None + vd = lora.v_loras[bi] if lora else None + x = self.blocks[bi](x, x0, qd, vd) + x = self.final_norm(x) + # Output layer + if self.mos is not None and self.tie_embeddings: + # MoS: returns log-probs (already log-softmaxed), use NLL loss directly + log_probs = self.mos(x, self.tok_emb.weight) + if lora: + # LoRA correction breaks normalization; re-normalize via log_softmax + log_probs = F.log_softmax(log_probs + lora.lm_head_lora(x), dim=-1) + bsz, sl, V = log_probs.shape + return F.nll_loss( + log_probs.float().reshape(-1, V), target_ids.reshape(-1), reduction="none").reshape(bsz, sl) + return F.nll_loss(log_probs.float().reshape(-1, log_probs.size(-1)), target_ids.reshape(-1), reduction="mean") + elif self.tie_embeddings: + logits = F.linear(x, self.tok_emb.weight) + else: + logits = self.lm_head(x) + logits = logits + (lora.lm_head_lora(x) if lora else 0) + logits = self.logit_softcap * torch.tanh(logits / self.logit_softcap) + if lora: + bsz, sl, V = logits.shape + return F.cross_entropy( + logits.float().reshape(-1, V), target_ids.reshape(-1), reduction="none").reshape(bsz, sl) + return F.cross_entropy(logits.float().reshape(-1, logits.size(-1)), target_ids.reshape(-1), reduction="mean") + + +# ----------------------------- +# TEST-TIME TRAINING (LoRA) +# ----------------------------- +# +# At evaluation time, we adapt per-document low-rank adapters on the validation data. +# Each document gets its own adapter, so there is no inter-document dependency. 

BOS_ID = 1

class BatchedLinearLoRA(nn.Module):
    """LoRA for a linear layer, with independent weights per batch element.
    Computes x @ Aᵀ @ Bᵀ = x @ (BA)ᵀ, i.e. the LoRA delta is ΔW = BA."""
    def __init__(self, bsz: int, in_features: int, out_features: int, rank: int):
        super().__init__()
        self.in_features = in_features
        self.A = nn.Parameter(torch.empty(bsz, rank, in_features))  # down-projection
        self.B = nn.Parameter(torch.zeros(bsz, out_features, rank))  # up-projection
        self.reset()

    def forward(self, x: Tensor) -> Tensor:
        # x is (bsz, T, in); batched matmul applies each element's own adapter.
        return (x @ self.A.transpose(1, 2)) @ self.B.transpose(1, 2)  # (bsz, T, out)

    def reset(self) -> None:
        # B starts at zero so a freshly-reset adapter contributes no delta.
        bound = 1.0 / math.sqrt(self.in_features)
        with torch.no_grad():
            self.A.uniform_(-bound, bound)  # kaiming-uniform
            self.B.zero_()

class BatchedTTTLoRA(nn.Module):
    """All LoRA adapters for one batch: LM head and Q/V per block."""
    def __init__(self, bsz: int, model: GPT, rank: int):
        super().__init__()
        dim = model.tok_emb.embedding_dim
        vocab = model.tok_emb.num_embeddings
        self.lm_head_lora = BatchedLinearLoRA(bsz, dim, vocab, rank)
        self.q_loras = nn.ModuleList()
        self.v_loras = nn.ModuleList()
        # One Q and one V adapter per transformer block; output widths follow
        # the block's own projection shapes (Q is full dim, V is the KV width).
        for block in model.blocks:
            self.q_loras.append(BatchedLinearLoRA(bsz, dim, block.attn.c_q.weight.shape[0], rank))
            self.v_loras.append(BatchedLinearLoRA(bsz, dim, block.attn.c_v.weight.shape[0], rank))

    def reset(self) -> None:
        # Re-randomize every adapter in place (used when reusing the persistent batch).
        for m in self.modules():
            if isinstance(m, BatchedLinearLoRA):
                m.reset()

def _reset_ttt_optimizer(opt):
    # Zero Adam's moments and step counters in place instead of rebuilding the
    # optimizer, so buffer allocations are reused across document batches.
    for group in opt.param_groups:
        for p in group['params']:
            s = opt.state.get(p)
            if not s:  # Fresh state.
                continue
            s['exp_avg'].zero_()
            s['exp_avg_sq'].zero_()
            s['step'].fill_(0)

def _build_ttt_optimizer(lora, args: Hyperparameters):
    # NOTE(review): eps is hard-coded to 1e-10 here rather than args.adam_eps —
    # presumably intentional for TTT; confirm.
    return torch.optim.Adam(lora.parameters(), lr=args.ttt_lora_lr, betas=(args.beta1, args.beta2), eps=1e-10)

def _find_docs(all_tokens: Tensor, include_next_bos: bool = True) -> list[tuple[int, int]]:
    """Return (start_offset, length) for each document, identified by BOS boundaries.

    If include_next_bos is True, include next document's BOS (to match continuous-stream
    eval token count exactly).
    """
    bos_positions = (all_tokens == BOS_ID).nonzero(as_tuple=True)[0].numpy()
    docs = []
    for i in range(len(bos_positions)):
        start = int(bos_positions[i])
        # The final document runs to the end of the stream.
        end = int(bos_positions[i + 1]) if i + 1 < len(bos_positions) else all_tokens.numel()
        if include_next_bos and i + 1 < len(bos_positions):
            end += 1
        assert end - start >= 2
        docs.append((start, end - start))
    return docs

def _compute_chunk_window(ci: int, pred_len: int, num_chunks: int, chunk_size: int, eval_seq_len: int):
    """Return (win_start, win_len, chunk_offset, chunk_len) for chunk `ci` of a doc."""
    chunk_start = ci * chunk_size
    # The last chunk may be shorter than chunk_size.
    chunk_end = pred_len if ci == num_chunks - 1 else (ci + 1) * chunk_size
    # Slide a window of at most eval_seq_len tokens ending at the chunk's end,
    # so the chunk is always scored with maximal left context.
    win_start = max(0, chunk_end - eval_seq_len)
    win_len = chunk_end - win_start
    chunk_offset = chunk_start - win_start
    chunk_len = chunk_end - chunk_start
    return win_start, win_len, chunk_offset, chunk_len

def _accumulate_bpb(
    ptl: Tensor, x: Tensor, y: Tensor,
    batch_i: int, chunk_offset: int, chunk_len: int,
    base_bytes_lut: Tensor, has_leading_space_lut: Tensor, is_boundary_token_lut: Tensor,
    loss_sum: Tensor, byte_sum: Tensor, token_count: Tensor,
):
    """Add one doc-chunk's contribution to the running BPB accumulators."""
    # ptl holds per-token losses; slice out just this doc's chunk inside its window.
    lbl = ptl[batch_i, chunk_offset:chunk_offset + chunk_len].to(torch.float64)
    prev = x[batch_i, chunk_offset:chunk_offset + chunk_len]
    tgt = y[batch_i, chunk_offset:chunk_offset + chunk_len]
    tok_bytes = base_bytes_lut[tgt].to(torch.float64)
    # A leading-space byte counts unless the previous token was a boundary token.
    tok_bytes += has_leading_space_lut[tgt] & ~is_boundary_token_lut[prev]
    loss_sum += lbl.sum()
    byte_sum += tok_bytes.sum()
    token_count += chunk_len

def eval_val_ttt_lora(
    args: Hyperparameters,
    base_model: GPT,
    rank: int,
    world_size: int,
    device: torch.device,
    base_bytes_lut: Tensor,
    has_leading_space_lut: Tensor,
    is_boundary_token_lut: Tensor,
) -> tuple[float, float]:
    """Evaluate with batched LoRA test-time training. Returns (val_loss, val_bpb)."""
    # Load validation tokens and find document boundaries
    files = sorted(glob.glob(args.val_files))
    all_tokens = torch.cat([load_data_shard(Path(f)) for f in files])
    docs = _find_docs(all_tokens)

    # Each rank takes a contiguous slice of documents
    rank_docs = docs[(len(docs) * rank) // world_size : (len(docs) * (rank + 1)) // world_size]
    chunk_size = args.ttt_chunk_size
    eval_seq_len = args.ttt_eval_seq_len
    batch_size = args.ttt_batch_size
    lora_rank = args.ttt_lora_rank

    # Sort by chunk count so batched documents finish at similar times,
    # minimizing wasted compute on padded inactive rows.
    rank_docs.sort(key=lambda d: (d[1] - 2) // chunk_size)

    base_model.eval()
    # Only the LoRA adapters receive gradients; the base model is frozen.
    for p in base_model.parameters():
        p.requires_grad_(False)

    # Persistent full-size adapter batch + optimizer, reused (via reset) for
    # every full batch; ragged final batches get a fresh, smaller pair.
    lora = BatchedTTTLoRA(batch_size, base_model, lora_rank).to(device)
    opt = _build_ttt_optimizer(lora, args)

    loss_sum = torch.zeros((), device=device, dtype=torch.float64)
    byte_sum = torch.zeros((), device=device, dtype=torch.float64)
    token_count = torch.zeros((), device=device, dtype=torch.float64)

    for bi in range(0, len(rank_docs), batch_size):
        batch = rank_docs[bi:bi + batch_size]
        bsz = len(batch)

        if bsz == batch_size:
            cur_lora, cur_opt = lora, opt
            cur_lora.reset()
            _reset_ttt_optimizer(cur_opt)
        else:
            cur_lora = BatchedTTTLoRA(bsz, base_model, lora_rank).to(device)
            cur_opt = _build_ttt_optimizer(cur_lora, args)

        # pred_len = doc_len - 1 prediction targets per document.
        pred_lens = [doc_len - 1 for _, doc_len in batch]
        num_chunks = [(pl + chunk_size - 1) // chunk_size for pl in pred_lens]
        max_nc = max(num_chunks)

        for ci in range(max_nc):
            # Window geometry for a full-length chunk; per-doc geometry (which may
            # differ on each doc's final, shorter chunk) is computed below.
            chunk_stats = _compute_chunk_window(ci, (ci + 1) * chunk_size, ci + 1, chunk_size, eval_seq_len)
            context_size, chunk_offset = chunk_stats[1], chunk_stats[2]

            active = [ci < nc for nc in num_chunks]
            # Train only while some doc still has a later chunk to score.
            needs_train = any(ci < nc - 1 for nc in num_chunks)

            x = torch.zeros(bsz, context_size, dtype=torch.int64, device=device)
            y = torch.zeros(bsz, context_size, dtype=torch.int64, device=device)
            doc_info = []  # (chunk_offset, chunk_len) per doc
            for b in range(bsz):
                if not active[b]:
                    doc_info.append((0, 0))
                    continue
                ds, dl = batch[b]
                ws, wl, co, cl = _compute_chunk_window(ci, pred_lens[b], num_chunks[b], chunk_size, eval_seq_len)
                # wl + 1 tokens yield wl (input, target) pairs after the shift.
                chunk = all_tokens[ds + ws: ds + ws + wl + 1]
                toks = chunk.to(dtype=torch.int64, device=device)
                x[b, :wl] = toks[:-1]
                y[b, :wl] = toks[1:]
                doc_info.append((co, cl))

            # Forward pass (keep grad graph alive only when we need to train)
            if needs_train:
                with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
                    ptl = base_model(x, y, lora=cur_lora)
            else:
                with torch.no_grad(), torch.autocast(device_type="cuda", dtype=torch.bfloat16):
                    ptl = base_model(x, y, lora=cur_lora)

            # Score: accumulate loss and byte counts for BPB (before training on chunk)
            with torch.no_grad():
                for b in range(bsz):
                    if not active[b]:
                        continue
                    co, cl = doc_info[b]
                    _accumulate_bpb(
                        ptl, x, y, b, co, cl, base_bytes_lut, has_leading_space_lut,
                        is_boundary_token_lut, loss_sum, byte_sum, token_count)

            # Train: one Adam step on the LoRA params using this chunk's loss
            if needs_train:
                # Mask out docs already on their last chunk; the remaining docs'
                # current chunks are full-length, so the global offset/size apply.
                mask = torch.tensor([float(ci < num_chunks[b] - 1) for b in range(bsz)], device=device)
                per_doc = ptl[:, chunk_offset:chunk_offset + chunk_size].mean(dim=-1)
                cur_opt.zero_grad()
                (per_doc * mask).sum().backward()
                cur_opt.step()

    if dist.is_available() and dist.is_initialized():
        dist.all_reduce(loss_sum, op=dist.ReduceOp.SUM)
        dist.all_reduce(byte_sum, op=dist.ReduceOp.SUM)
        dist.all_reduce(token_count,
op=dist.ReduceOp.SUM) + + val_loss = float(loss_sum.item() / token_count.item()) + val_bpb = float((loss_sum.item() / math.log(2.0)) / byte_sum.item()) + return val_loss, val_bpb + +# ----------------------------- +# TRAINING +# ----------------------------- + +def main() -> None: + global zeropower_via_newtonschulz5 + + code = Path(__file__).read_text(encoding="utf-8") + args = Hyperparameters() + zeropower_via_newtonschulz5 = torch.compile(zeropower_via_newtonschulz5) + + # ----------------------------- + # DISTRIBUTED + CUDA SETUP + # ----------------------------- + + distributed = "RANK" in os.environ and "WORLD_SIZE" in os.environ + rank = int(os.environ.get("RANK", "0")) + world_size = int(os.environ.get("WORLD_SIZE", "1")) + local_rank = int(os.environ.get("LOCAL_RANK", "0")) + if world_size <= 0: + raise ValueError(f"WORLD_SIZE must be positive, got {world_size}") + if 8 % world_size != 0: + raise ValueError(f"WORLD_SIZE={world_size} must divide 8 so grad_accum_steps stays integral") + grad_accum_steps = 8 // world_size + grad_scale = 1.0 / grad_accum_steps + if not torch.cuda.is_available(): + raise RuntimeError("CUDA is required") + device = torch.device("cuda", local_rank) + torch.cuda.set_device(device) + if distributed: + dist.init_process_group(backend="nccl", device_id=device) + dist.barrier() + master_process = rank == 0 + + # Fast math knobs + torch.backends.cuda.matmul.allow_tf32 = True + torch.backends.cudnn.allow_tf32 = True + from torch.backends.cuda import enable_cudnn_sdp, enable_flash_sdp, enable_math_sdp, enable_mem_efficient_sdp + + enable_cudnn_sdp(False) + enable_flash_sdp(True) + enable_mem_efficient_sdp(False) + enable_math_sdp(False) + + logfile = None + if master_process: + os.makedirs("logs", exist_ok=True) + logfile = f"logs/{args.run_id}.txt" + print(logfile) + + def log0(msg: str, console: bool = True) -> None: + if not master_process: + return + if console: + print(msg) + if logfile is not None: + with open(logfile, "a", 
encoding="utf-8") as f: + print(msg, file=f) + + log0(code, console=False) + log0("=" * 100, console=False) + log0(f"Running Python {sys.version}", console=False) + log0(f"Running PyTorch {torch.__version__}", console=False) + log0( + subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, check=False).stdout, + console=False, + ) + log0("=" * 100, console=False) + + # ----------------------------- + # TOKENIZER + VALIDATION METRIC SETUP + # ----------------------------- + + random.seed(args.seed) + np.random.seed(args.seed) + torch.manual_seed(args.seed) + torch.cuda.manual_seed_all(args.seed) + + if not args.tokenizer_path.endswith(".model"): + raise ValueError(f"Script only setup for SentencePiece .model file: {args.tokenizer_path}") + sp = spm.SentencePieceProcessor(model_file=args.tokenizer_path) + if int(sp.vocab_size()) != args.vocab_size: + raise ValueError( + f"VOCAB_SIZE={args.vocab_size} does not match tokenizer vocab_size={int(sp.vocab_size())}" + ) + dataset_dir = Path(args.data_path).resolve() + actual_train_files = len(list(dataset_dir.glob("fineweb_train_*.bin"))) + val_tokens = load_validation_tokens(args.val_files, args.train_seq_len) + base_bytes_lut, has_leading_space_lut, is_boundary_token_lut = build_sentencepiece_luts( + sp, args.vocab_size, device + ) + log0(f"val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path={args.tokenizer_path}") + log0(f"train_loader:dataset:{dataset_dir.name} train_shards:{actual_train_files}") + log0(f"val_loader:shards pattern={args.val_files} tokens:{val_tokens.numel() - 1}") + + # ----------------------------- + # MODEL + OPTIMIZER SETUP + # ----------------------------- + + base_model = GPT( + vocab_size=args.vocab_size, + num_layers=args.num_layers, + model_dim=args.model_dim, + num_heads=args.num_heads, + num_kv_heads=args.num_kv_heads, + mlp_mult=args.mlp_mult, + tie_embeddings=args.tie_embeddings, + tied_embed_init_std=args.tied_embed_init_std, + 
logit_softcap=args.logit_softcap, + rope_base=args.rope_base, + qk_gain_init=args.qk_gain_init, + use_mos=args.use_mos, + mos_k=args.mos_k, + mos_rank=args.mos_rank, + ).to(device).bfloat16() + for module in base_model.modules(): + if isinstance(module, CastedLinear): + module.float() + if isinstance(module, Rotary): + module.inv_freq.data = module.inv_freq.data.float() + restore_low_dim_params_to_fp32(base_model) + compiled_model = torch.compile(base_model, dynamic=False, fullgraph=True) + model: nn.Module = DDP(compiled_model, device_ids=[local_rank], broadcast_buffers=False) if distributed else compiled_model + + # Optimizer split: + # - token embedding (Adam) uses EMBED_LR + # - untied lm_head (Adam) uses HEAD_LR + # - matrix params in transformer blocks use MATRIX_LR via Muon + # - vectors/scalars use SCALAR_LR via Adam + block_named_params = list(base_model.blocks.named_parameters()) + matrix_params = [ + p + for name, p in block_named_params + if p.ndim == 2 and not any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) + ] + scalar_params = [ + p + for name, p in block_named_params + if p.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) + ] + if base_model.skip_weights.numel() > 0: + scalar_params.append(base_model.skip_weights) + # MoS parameters: 2D projection weights go to Muon, gate goes to scalar optimizer + if base_model.mos is not None: + if base_model.mos.rank > 0: + matrix_params.append(base_model.mos.proj_down.weight) + matrix_params.append(base_model.mos.proj_up.weight) + else: + matrix_params.append(base_model.mos.projections.weight) + scalar_params.append(base_model.mos.gate.weight) + token_lr = args.tied_embed_lr if args.tie_embeddings else args.embed_lr + optimizer_tok = torch.optim.Adam( + [{"params": [base_model.tok_emb.weight], "lr": token_lr, "base_lr": token_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=True, + ) + optimizer_muon = Muon( + matrix_params, + lr=args.matrix_lr, + 
momentum=args.muon_momentum, + backend_steps=args.muon_backend_steps, + ) + for group in optimizer_muon.param_groups: + group["base_lr"] = args.matrix_lr + optimizer_scalar = torch.optim.Adam( + [{"params": scalar_params, "lr": args.scalar_lr, "base_lr": args.scalar_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=True, + ) + optimizers: list[torch.optim.Optimizer] = [optimizer_tok, optimizer_muon, optimizer_scalar] + if base_model.lm_head is not None: + optimizer_head = torch.optim.Adam( + [{"params": [base_model.lm_head.weight], "lr": args.head_lr, "base_lr": args.head_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=True, + ) + optimizers.insert(1, optimizer_head) + + n_params = sum(p.numel() for p in base_model.parameters()) + log0(f"model_params:{n_params}") + log0(f"world_size:{world_size} grad_accum_steps:{grad_accum_steps}") + log0("sdp_backends:cudnn=False flash=True mem_efficient=False math=False") + log0(f"attention_mode:gqa num_heads:{args.num_heads} num_kv_heads:{args.num_kv_heads}") + log0( + f"tie_embeddings:{args.tie_embeddings} embed_lr:{token_lr} " + f"head_lr:{args.head_lr if base_model.lm_head is not None else 0.0} " + f"matrix_lr:{args.matrix_lr} scalar_lr:{args.scalar_lr}" + ) + log0( + f"train_batch_tokens:{args.train_batch_tokens} train_seq_len:{args.train_seq_len} " + f"iterations:{args.iterations} warmup_steps:{args.warmup_steps} " + f"max_wallclock_seconds:{args.max_wallclock_seconds:.3f}" + ) + log0(f"seed:{args.seed}") + + # ----------------------------- + # DATA LOADER & MODEL WARMUP + # ----------------------------- + + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + + def zero_grad_all() -> None: + for opt in optimizers: + opt.zero_grad(set_to_none=True) + + max_wallclock_ms = 1000.0 * args.max_wallclock_seconds if args.max_wallclock_seconds > 0 else None + + def lr_mul(step: int, elapsed_ms: float) -> float: + if args.warmdown_iters <= 0: + return 1.0 + if 
max_wallclock_ms is None: + warmdown_start = max(args.iterations - args.warmdown_iters, 0) + return max((args.iterations - step) / max(args.warmdown_iters, 1), 0.0) if warmdown_start <= step < args.iterations else 1.0 + step_ms = elapsed_ms / max(step, 1) + warmdown_ms = args.warmdown_iters * step_ms + remaining_ms = max(max_wallclock_ms - elapsed_ms, 0.0) + return remaining_ms / max(warmdown_ms, 1e-9) if remaining_ms <= warmdown_ms else 1.0 + + # Warmup primes the compiled forward/backward/optimizer paths, then we restore the + # initial weights/optimizer state so measured training starts from the true init. + if args.warmup_steps > 0: + initial_model_state = {name: tensor.detach().cpu().clone() for name, tensor in base_model.state_dict().items()} + initial_optimizer_states = [copy.deepcopy(opt.state_dict()) for opt in optimizers] + model.train() + for warmup_step in range(args.warmup_steps): + zero_grad_all() + for micro_step in range(grad_accum_steps): + if distributed: + model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 + x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + warmup_loss = model(x, y) + (warmup_loss * grad_scale).backward() + for opt in optimizers: + opt.step() + zero_grad_all() + if args.warmup_steps <= 20 or (warmup_step + 1) % 10 == 0 or warmup_step + 1 == args.warmup_steps: + log0(f"warmup_step:{warmup_step + 1}/{args.warmup_steps}") + base_model.load_state_dict(initial_model_state, strict=True) + for opt, state in zip(optimizers, initial_optimizer_states, strict=True): + opt.load_state_dict(state) + zero_grad_all() + if distributed: + model.require_backward_grad_sync = True + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + + # ----------------------------- + # MAIN TRAINING LOOP + # ----------------------------- + + training_time_ms = 0.0 + stop_after_step: int | 
None = None + torch.cuda.synchronize() + t0 = time.perf_counter() + + step = 0 + while True: + last_step = step == args.iterations or (stop_after_step is not None and step >= stop_after_step) + + should_validate = last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0) + if should_validate: + torch.cuda.synchronize() + training_time_ms += 1000.0 * (time.perf_counter() - t0) + val_loss, val_bpb = eval_val( + args, + model, + rank, + world_size, + device, + grad_accum_steps, + val_tokens, + base_bytes_lut, + has_leading_space_lut, + is_boundary_token_lut, + ) + log0( + f"step:{step}/{args.iterations} val_loss:{val_loss:.4f} val_bpb:{val_bpb:.4f} " + f"train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms / max(step, 1):.2f}ms" + ) + torch.cuda.synchronize() + t0 = time.perf_counter() + + if last_step: + if stop_after_step is not None and step < args.iterations: + log0( + f"stopping_early: wallclock_cap train_time:{training_time_ms:.0f}ms " + f"step:{step}/{args.iterations}" + ) + break + + elapsed_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + scale = lr_mul(step, elapsed_ms) + zero_grad_all() + train_loss = torch.zeros((), device=device) + for micro_step in range(grad_accum_steps): + if distributed: + model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 + x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + loss = model(x, y) + train_loss += loss.detach() + (loss * grad_scale).backward() + train_loss /= grad_accum_steps + + frac = min(step / args.muon_momentum_warmup_steps, 1.0) if args.muon_momentum_warmup_steps > 0 else 1.0 + muon_momentum = (1 - frac) * args.muon_momentum_warmup_start + frac * args.muon_momentum + for group in optimizer_muon.param_groups: + group["momentum"] = muon_momentum + + for opt in optimizers: + for group in opt.param_groups: + group["lr"] = group["base_lr"] * 
scale + + if args.grad_clip_norm > 0: + torch.nn.utils.clip_grad_norm_(base_model.parameters(), args.grad_clip_norm) + for opt in optimizers: + opt.step() + zero_grad_all() + + step += 1 + approx_training_time_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + should_log_train = ( + args.train_log_every > 0 + and (step <= 10 or step % args.train_log_every == 0 or stop_after_step is not None) + ) + if should_log_train: + log0( + f"step:{step}/{args.iterations} train_loss:{train_loss.item():.4f} " + f"train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms / step:.2f}ms" + ) + + # Needed to sync whether we've reached the wallclock cap. + reached_cap = max_wallclock_ms is not None and approx_training_time_ms >= max_wallclock_ms + if distributed and max_wallclock_ms is not None: + reached_cap_tensor = torch.tensor(int(reached_cap), device=device) + dist.all_reduce(reached_cap_tensor, op=dist.ReduceOp.MAX) + reached_cap = bool(reached_cap_tensor.item()) + if stop_after_step is None and reached_cap: + stop_after_step = step + + log0( + f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB " + f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB" + ) + + # ----------------------------- + # SERIALIZATION + ROUNDTRIP VALIDATION + # ----------------------------- + # Save the raw state (useful for debugging/loading in PyTorch directly), then always produce + # the compressed int8+zlib artifact and validate the round-tripped weights. 
+ + if master_process: + torch.save(base_model.state_dict(), "final_model.pt") + model_bytes = os.path.getsize("final_model.pt") + code_bytes = len(code.encode("utf-8")) + log0(f"Serialized model: {model_bytes} bytes") + log0(f"Code size: {code_bytes} bytes") + log0(f"Total submission size: {model_bytes + code_bytes} bytes") + + quant_obj, quant_stats = quantize_state_dict_int8(base_model.state_dict()) + quant_buf = io.BytesIO() + torch.save(quant_obj, quant_buf) + quant_raw = quant_buf.getvalue() + quant_blob = zlib.compress(quant_raw, level=9) + quant_raw_bytes = len(quant_raw) + if master_process: + with open("final_model.int8.ptz", "wb") as f: + f.write(quant_blob) + quant_file_bytes = os.path.getsize("final_model.int8.ptz") + code_bytes = len(code.encode("utf-8")) + ratio = quant_stats["baseline_tensor_bytes"] / max(quant_stats["int8_payload_bytes"], 1) + log0( + f"Serialized model int8+zlib: {quant_file_bytes} bytes " + f"(payload:{quant_stats['int8_payload_bytes']} raw_torch:{quant_raw_bytes} payload_ratio:{ratio:.2f}x)" + ) + log0(f"Total submission size int8+zlib: {quant_file_bytes + code_bytes} bytes") + + if distributed: + dist.barrier() + with open("final_model.int8.ptz", "rb") as f: + quant_blob_disk = f.read() + quant_state = torch.load(io.BytesIO(zlib.decompress(quant_blob_disk)), map_location="cpu") + base_model.load_state_dict(dequantize_state_dict_int8(quant_state), strict=True) + torch.cuda.synchronize() + t_qeval = time.perf_counter() + q_val_loss, q_val_bpb = eval_val( + args, + model, + rank, + world_size, + device, + grad_accum_steps, + val_tokens, + base_bytes_lut, + has_leading_space_lut, + is_boundary_token_lut, + ) + torch.cuda.synchronize() + log0( + f"final_int8_zlib_roundtrip val_loss:{q_val_loss:.4f} val_bpb:{q_val_bpb:.4f} " + f"eval_time:{1000.0 * (time.perf_counter() - t_qeval):.0f}ms" + ) + log0(f"final_int8_zlib_roundtrip_exact val_loss:{q_val_loss:.8f} val_bpb:{q_val_bpb:.8f}") + + # LoRA test-time training evaluation (the 
competition score) + torch._dynamo.reset() + torch.cuda.synchronize() + t_ttt = time.perf_counter() + ttt_val_loss, ttt_val_bpb = eval_val_ttt_lora( + args, base_model, rank, world_size, device, + base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, + ) + torch.cuda.synchronize() + log0( + f"final_int8_ttt_lora val_loss:{ttt_val_loss:.4f} val_bpb:{ttt_val_bpb:.4f} " + f"eval_time:{1000.0 * (time.perf_counter() - t_ttt):.0f}ms" + ) + + if distributed: + dist.destroy_process_group() + + +if __name__ == "__main__": + main() + +==================================================================================================== +Running Python 3.12.3 (main, Nov 6 2025, 13:44:16) [GCC 13.3.0] +Running PyTorch 2.9.1+cu128 +Fri Mar 20 19:48:46 2026 ++-----------------------------------------------------------------------------------------+ +| NVIDIA-SMI 580.126.09 Driver Version: 580.126.09 CUDA Version: 13.0 | ++-----------------------------------------+------------------------+----------------------+ +| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC | +| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. | +| | | MIG M. 
| +|=========================================+========================+======================| +| 0 NVIDIA H100 80GB HBM3 On | 00000000:DB:00.0 Off | 0 | +| N/A 27C P0 96W / 700W | 1185MiB / 81559MiB | 10% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ + ++-----------------------------------------------------------------------------------------+ +| Processes: | +| GPU GI CI PID Type Process name GPU Memory | +| ID ID Usage | +|=========================================================================================| +| 0 N/A N/A 1166 C /usr/local/bin/python 1176MiB | ++-----------------------------------------------------------------------------------------+ + +==================================================================================================== +val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path=./data/tokenizers/fineweb_1024_bpe.model +train_loader:dataset:fineweb10B_sp1024 train_shards:80 +val_loader:shards pattern=./data/datasets/fineweb10B_sp1024/fineweb_val_*.bin tokens:62021632 +model_params:17159240 +world_size:1 grad_accum_steps:8 +sdp_backends:cudnn=False flash=True mem_efficient=False math=False +attention_mode:gqa num_heads:8 num_kv_heads:4 +tie_embeddings:True embed_lr:0.05 head_lr:0.0 matrix_lr:0.04 scalar_lr:0.04 +train_batch_tokens:524288 train_seq_len:1024 iterations:20000 warmup_steps:20 max_wallclock_seconds:600.000 +seed:42 +warmup_step:1/20 +warmup_step:2/20 +warmup_step:3/20 +warmup_step:4/20 +warmup_step:5/20 +warmup_step:6/20 +warmup_step:7/20 +warmup_step:8/20 +warmup_step:9/20 +warmup_step:10/20 +warmup_step:11/20 +warmup_step:12/20 +warmup_step:13/20 +warmup_step:14/20 +warmup_step:15/20 +warmup_step:16/20 +warmup_step:17/20 +warmup_step:18/20 +warmup_step:19/20 +warmup_step:20/20 +step:0/20000 val_loss:6.9314 val_bpb:4.1052 train_time:0ms step_avg:0.01ms +step:1/20000 train_loss:6.9314 train_time:581ms step_avg:581.11ms +step:2/20000 
train_loss:6.8515 train_time:1295ms step_avg:647.47ms +step:3/20000 train_loss:5.8655 train_time:1996ms step_avg:665.19ms +step:4/20000 train_loss:5.4250 train_time:2793ms step_avg:698.33ms +step:5/20000 train_loss:5.0728 train_time:3413ms step_avg:682.51ms +step:6/20000 train_loss:4.9797 train_time:4016ms step_avg:669.27ms +step:7/20000 train_loss:4.8555 train_time:4676ms step_avg:668.03ms +step:8/20000 train_loss:4.7612 train_time:5341ms step_avg:667.67ms +step:9/20000 train_loss:4.6900 train_time:5990ms step_avg:665.54ms +step:10/20000 train_loss:4.7029 train_time:6682ms step_avg:668.23ms +step:100/20000 train_loss:3.2746 train_time:54475ms step_avg:544.75ms +step:200/20000 train_loss:2.8511 train_time:108479ms step_avg:542.40ms +step:300/20000 train_loss:2.7046 train_time:162973ms step_avg:543.24ms +step:400/20000 train_loss:2.4804 train_time:217390ms step_avg:543.47ms +step:500/20000 train_loss:2.5755 train_time:271183ms step_avg:542.37ms +step:500/20000 val_loss:2.5703 val_bpb:1.5223 train_time:271193ms step_avg:542.39ms +step:600/20000 train_loss:2.5630 train_time:324786ms step_avg:541.31ms +step:700/20000 train_loss:2.5112 train_time:378359ms step_avg:540.51ms +step:800/20000 train_loss:2.3957 train_time:432963ms step_avg:541.20ms +step:900/20000 train_loss:2.4135 train_time:487589ms step_avg:541.77ms +step:1000/20000 train_loss:2.4031 train_time:542181ms step_avg:542.18ms +step:1000/20000 val_loss:2.3696 val_bpb:1.4034 train_time:542248ms step_avg:542.25ms +step:1100/20000 train_loss:2.3186 train_time:594115ms step_avg:540.10ms +step:1113/20000 val_loss:2.3505 val_bpb:1.3921 train_time:600423ms step_avg:539.46ms +stopping_early: wallclock_cap train_time:600423ms step:1113/20000 +peak memory allocated: 11012 MiB reserved: 11320 MiB +Serialized model: 67623386 bytes +Code size: 63345 bytes +Total submission size: 67686731 bytes +Serialized model int8+zlib: 12764492 bytes (payload:17377568 raw_torch:17423635 payload_ratio:3.89x) +Total submission size 
int8+zlib: 12827837 bytes +final_int8_zlib_roundtrip val_loss:2.3523 val_bpb:1.3932 eval_time:11887ms +final_int8_zlib_roundtrip_exact val_loss:2.35234121 val_bpb:1.39318897 diff --git a/run_baseline_10min.sh b/run_baseline_10min.sh new file mode 100644 index 0000000000..d81a4526e2 --- /dev/null +++ b/run_baseline_10min.sh @@ -0,0 +1,20 @@ +#!/bin/bash +# Vanilla baseline, 10 min, 1x H100. Survives terminal disconnect. +# Monitor: tail -f /workspace/baseline_10min_log.txt + +cd /workspace/parameter-golf + +nohup bash -c ' +RUN_ID=baseline_10min \ +DATA_PATH=./data/datasets/fineweb10B_sp1024 \ +TOKENIZER_PATH=./data/tokenizers/fineweb_1024_bpe.model \ +VOCAB_SIZE=1024 \ +SEED=42 \ +MAX_WALLCLOCK_SECONDS=600 \ +VAL_LOSS_EVERY=500 \ +TRAIN_LOG_EVERY=100 \ +torchrun --standalone --nproc_per_node=1 train_gpt.py +' > /workspace/baseline_10min_log.txt 2>&1 & + +echo "PID: $!" +echo "Monitor: tail -f /workspace/baseline_10min_log.txt" diff --git a/run_competitive.sh b/run_competitive.sh new file mode 100755 index 0000000000..b36d9ba575 --- /dev/null +++ b/run_competitive.sh @@ -0,0 +1,92 @@ +#!/bin/bash +# === Parameter Golf: Competitive Entry (SOTA stack) === +# Based on thwu1's #1 submission (1.1428 bpb) +# Techniques: 10L + BigramHash(10240) + SmearGate + Int5/Int6 mixed quant +# + 3x MLP + OrthoInit + SWA(0.4) + WD=0.04 + sliding eval +# + zstd-22 + magnitude pruning +# +# Requirements: 8x H100 SXM (or adjust WORLD_SIZE) +# Expected: ~1.14 bpb in 10 minutes +# +# Usage: +# # 8x H100 (full competitive run) +# bash run_competitive.sh +# +# # 1x H100 (pilot test) +# bash run_competitive.sh --pilot + +set -e + +PILOT=0 +if [[ "$1" == "--pilot" ]]; then + PILOT=1 +fi + +cd /workspace/parameter-golf + +# Install zstandard for better compression +pip install zstandard 2>/dev/null || true + +# Download dataset if not present +if [ ! 
-f data/datasets/fineweb10B_sp1024/fineweb_train_000000.bin ]; then + echo "=== Downloading dataset ===" + export HF_TOKEN="${HF_TOKEN:-}" + python3 data/cached_challenge_fineweb.py --variant sp1024 +fi + +echo "Train shards: $(ls data/datasets/fineweb10B_sp1024/fineweb_train_*.bin 2>/dev/null | wc -l)" +echo "Val shards: $(ls data/datasets/fineweb10B_sp1024/fineweb_val_*.bin 2>/dev/null | wc -l)" + +if [ "$PILOT" -eq 1 ]; then + echo "" + echo "=== PILOT RUN: 1x H100, 10 min ===" + echo "Start: $(date)" + NPROC=1 +else + echo "" + echo "=== COMPETITIVE RUN: 8x H100, 10 min ===" + echo "Start: $(date)" + NPROC=$(nvidia-smi --list-gpus | wc -l) + echo "Detected GPUs: $NPROC" +fi + +RUN_ID="competitive_$(date +%Y%m%d_%H%M%S)" \ +DATA_PATH=./data/datasets/fineweb10B_sp1024 \ +TOKENIZER_PATH=./data/tokenizers/fineweb_1024_bpe.model \ +VOCAB_SIZE=1024 \ +SEED=42 \ +NUM_LAYERS=10 \ +MODEL_DIM=512 \ +NUM_HEADS=8 \ +NUM_KV_HEADS=4 \ +MLP_MULT=3.0 \ +TIE_EMBEDDINGS=1 \ +TRAIN_SEQ_LEN=2048 \ +TRAIN_BATCH_TOKENS=786432 \ +WARMDOWN_ITERS=3000 \ +MAX_WALLCLOCK_SECONDS=600 \ +VAL_LOSS_EVERY=500 \ +TRAIN_LOG_EVERY=100 \ +WEIGHT_DECAY=0.04 \ +MATRIX_LR=0.02 \ +SCALAR_LR=0.02 \ +TIED_EMBED_LR=0.03 \ +MUON_MOMENTUM=0.99 \ +MUON_MOMENTUM_WARMUP_START=0.92 \ +MUON_MOMENTUM_WARMUP_STEPS=1500 \ +GRAD_CLIP_NORM=0.3 \ +BIGRAM_VOCAB_SIZE=10240 \ +BIGRAM_DIM=128 \ +SWA_ENABLED=1 \ +SWA_START_FRAC=0.4 \ +SWA_EVERY=50 \ +EVAL_STRIDE=64 \ +EVAL_BATCH_SEQS=32 \ +torchrun --standalone --nproc_per_node=$NPROC train_gpt.py 2>&1 | tee /workspace/competitive_log.txt + +echo "" +echo "=== RESULTS ===" +grep -E 'val_bpb|final_int8|submission|model_params|swa:' /workspace/competitive_log.txt | tail -20 +echo "" +echo "Target: val_bpb < 1.1428 (current SOTA)" +echo "Done: $(date)" diff --git a/run_competitive_custom_tok.sh b/run_competitive_custom_tok.sh new file mode 100755 index 0000000000..1e8b6a4dfa --- /dev/null +++ b/run_competitive_custom_tok.sh @@ -0,0 +1,87 @@ +#!/bin/bash +# === Parameter 
Golf: Competitive Entry with Custom Tokenizer === +# Same SOTA stack as run_competitive.sh but uses a custom-trained tokenizer. +# +# WORKFLOW: +# 1. Train tokenizer: python3 ../trainer-tokenizer/train_tokenizer.py --vocab-size 1024 --model-type unigram +# 2. Export shards: python3 ../trainer-tokenizer/train_tokenizer.py --vocab-size 1024 --model-type unigram --export-shards +# 3. Run this script with the paths +# +# Usage: +# CUSTOM_TOKENIZER=./tokenizers/spm_unigram_1024.model \ +# CUSTOM_DATA=./tokenizers/shards_unigram_1024 \ +# CUSTOM_VOCAB=1024 \ +# bash run_competitive_custom_tok.sh [--pilot] + +set -e + +PILOT=0 +if [[ "$1" == "--pilot" ]]; then + PILOT=1 +fi + +# Custom tokenizer paths (must be set) +CUSTOM_TOKENIZER="${CUSTOM_TOKENIZER:?Set CUSTOM_TOKENIZER to your .model file}" +CUSTOM_DATA="${CUSTOM_DATA:?Set CUSTOM_DATA to your shard directory}" +CUSTOM_VOCAB="${CUSTOM_VOCAB:-1024}" + +cd /workspace/parameter-golf + +pip install zstandard 2>/dev/null || true + +echo "=== Custom Tokenizer Run ===" +echo "Tokenizer: $CUSTOM_TOKENIZER" +echo "Data: $CUSTOM_DATA" +echo "Vocab: $CUSTOM_VOCAB" +echo "Train shards: $(ls ${CUSTOM_DATA}/fineweb_train_*.bin 2>/dev/null | wc -l)" +echo "Val shards: $(ls ${CUSTOM_DATA}/fineweb_val_*.bin 2>/dev/null | wc -l)" + +if [ "$PILOT" -eq 1 ]; then + NPROC=1 + echo "Mode: PILOT (1x GPU)" +else + NPROC=$(nvidia-smi --list-gpus | wc -l) + echo "Mode: COMPETITIVE ($NPROC GPUs)" +fi + +echo "Start: $(date)" + +RUN_ID="custom_tok_$(date +%Y%m%d_%H%M%S)" \ +DATA_PATH="$CUSTOM_DATA" \ +TOKENIZER_PATH="$CUSTOM_TOKENIZER" \ +VOCAB_SIZE="$CUSTOM_VOCAB" \ +SEED=42 \ +NUM_LAYERS=10 \ +MODEL_DIM=512 \ +NUM_HEADS=8 \ +NUM_KV_HEADS=4 \ +MLP_MULT=3.0 \ +TIE_EMBEDDINGS=1 \ +TRAIN_SEQ_LEN=2048 \ +TRAIN_BATCH_TOKENS=786432 \ +WARMDOWN_ITERS=3000 \ +MAX_WALLCLOCK_SECONDS=600 \ +VAL_LOSS_EVERY=500 \ +TRAIN_LOG_EVERY=100 \ +WEIGHT_DECAY=0.04 \ +MATRIX_LR=0.02 \ +SCALAR_LR=0.02 \ +TIED_EMBED_LR=0.03 \ +MUON_MOMENTUM=0.99 \ 
+MUON_MOMENTUM_WARMUP_START=0.92 \ +MUON_MOMENTUM_WARMUP_STEPS=1500 \ +GRAD_CLIP_NORM=0.3 \ +BIGRAM_VOCAB_SIZE=10240 \ +BIGRAM_DIM=128 \ +SWA_ENABLED=1 \ +SWA_START_FRAC=0.4 \ +SWA_EVERY=50 \ +EVAL_STRIDE=64 \ +EVAL_BATCH_SEQS=32 \ +torchrun --standalone --nproc_per_node=$NPROC train_gpt.py 2>&1 | tee /workspace/custom_tok_log.txt + +echo "" +echo "=== RESULTS ===" +grep -E 'val_bpb|final_int8|submission|model_params|swa:' /workspace/custom_tok_log.txt | tail -20 +echo "" +echo "Done: $(date)" diff --git a/run_custom_tokenizer_pipeline.sh b/run_custom_tokenizer_pipeline.sh new file mode 100755 index 0000000000..f3da04c45b --- /dev/null +++ b/run_custom_tokenizer_pipeline.sh @@ -0,0 +1,149 @@ +#!/bin/bash +# ============================================================================= +# Parameter Golf: Custom Tokenizer + Competitive Run (All-in-One) +# ============================================================================= +# +# Steps: +# 1. Download docs_selected.jsonl (~45GB, 10-30 min) +# 2. Train unigram tokenizer (5-10 min) +# 3. Export binary shards (30-60 min) +# 4. 
Run SOTA training with custom tokenizer (10 min) +# +# Usage (paste into RunPod terminal): +# +# git clone https://github.com/User123331/parameter-golf.git +# cd parameter-golf +# git pull +# bash run_custom_tokenizer_pipeline.sh +# +# ============================================================================= + +set -e + +VOCAB_SIZE=1024 +MODEL_TYPE=unigram +MAX_TRAIN_DOCS=200000 +EVAL_DOCS=10000 +HF_TOKEN="${HF_TOKEN:-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}" + +DATA_DIR="./data/datasets" +TOKENIZER_DIR="./data/tokenizers_custom" +DOCS_JSONL="${DATA_DIR}/docs_selected.jsonl" +CUSTOM_SHARDS="${DATA_DIR}/fineweb10B_custom_${MODEL_TYPE}${VOCAB_SIZE}" +CUSTOM_MODEL="${TOKENIZER_DIR}/spm_${MODEL_TYPE}_${VOCAB_SIZE}.model" + +GREEN='\033[0;32m' +NC='\033[0m' +log() { echo -e "${GREEN}[$(date +%H:%M:%S)]${NC} $*"; } + +# ============================================================================= +# Step 1: Download docs_selected.jsonl +# ============================================================================= +log "Step 1: Downloading docs_selected.jsonl (~45GB)..." + +mkdir -p "${DATA_DIR}" + +if [ ! -f "${DOCS_JSONL}" ]; then + pip install --quiet huggingface_hub + + python3 -c " +from huggingface_hub import hf_hub_download +import shutil, os +cached = hf_hub_download( + repo_id='willdepueoai/parameter-golf', + filename='docs_selected.jsonl', + subfolder='datasets', + repo_type='dataset', +) +src = os.path.realpath(cached) +dst = '${DOCS_JSONL}' +print(f'Copying to {dst}') +try: + os.link(src, dst) +except OSError: + shutil.copy2(src, dst) +" +fi + +log "Docs ready: $(du -h "${DOCS_JSONL}" 2>/dev/null | cut -f1)" + +# ============================================================================= +# Step 2: Train custom tokenizer +# ============================================================================= +log "Step 2: Training ${MODEL_TYPE} tokenizer..." 
+ +mkdir -p "${TOKENIZER_DIR}" +pip install --quiet sentencepiece numpy + +python3 data/train_tokenizer.py \ + --vocab-size ${VOCAB_SIZE} \ + --model-type ${MODEL_TYPE} \ + --docs-path "${DOCS_JSONL}" \ + --max-docs ${MAX_TRAIN_DOCS} \ + --eval-docs ${EVAL_DOCS} \ + --character-coverage 0.995 + +log "Tokenizer ready: ${CUSTOM_MODEL}" + +# ============================================================================= +# Step 3: Export binary shards +# ============================================================================= +log "Step 3: Exporting binary shards (30-60 min)..." + +python3 data/train_tokenizer.py \ + --vocab-size ${VOCAB_SIZE} \ + --model-type ${MODEL_TYPE} \ + --docs-path "${DOCS_JSONL}" \ + --export-shards \ + --shard-output-dir "${CUSTOM_SHARDS}" + +log "Shards ready: ${CUSTOM_SHARDS}" +log "Train shards: $(ls ${CUSTOM_SHARDS}/fineweb_train_*.bin 2>/dev/null | wc -l | tr -d ' ')" +log "Val shards: $(ls ${CUSTOM_SHARDS}/fineweb_val_*.bin 2>/dev/null | wc -l | tr -d ' ')" + +# ============================================================================= +# Step 4: Run training with custom tokenizer +# ============================================================================= +log "Step 4: Running SOTA training..." 
+ +pip install --quiet zstandard + +NPROC=$(nvidia-smi --list-gpus 2>/dev/null | wc -l | tr -d ' ') +[ -z "$NPROC" ] || [ "$NPROC" -lt 1 ] && NPROC=1 + +RUN_ID="custom_${MODEL_TYPE}${VOCAB_SIZE}_$(date +%Y%m%d_%H%M%S)" \ +DATA_PATH="${CUSTOM_SHARDS}" \ +TOKENIZER_PATH="${CUSTOM_MODEL}" \ +VOCAB_SIZE=${VOCAB_SIZE} \ +SEED=42 \ +NUM_LAYERS=10 \ +MODEL_DIM=512 \ +NUM_HEADS=8 \ +NUM_KV_HEADS=4 \ +MLP_MULT=3.0 \ +TIE_EMBEDDINGS=1 \ +TRAIN_SEQ_LEN=2048 \ +TRAIN_BATCH_TOKENS=786432 \ +WARMDOWN_ITERS=3000 \ +MAX_WALLCLOCK_SECONDS=600 \ +VAL_LOSS_EVERY=500 \ +TRAIN_LOG_EVERY=100 \ +WEIGHT_DECAY=0.04 \ +MATRIX_LR=0.02 \ +SCALAR_LR=0.02 \ +TIED_EMBED_LR=0.03 \ +MUON_MOMENTUM=0.99 \ +MUON_MOMENTUM_WARMUP_START=0.92 \ +MUON_MOMENTUM_WARMUP_STEPS=1500 \ +GRAD_CLIP_NORM=0.3 \ +BIGRAM_VOCAB_SIZE=10240 \ +BIGRAM_DIM=128 \ +SWA_ENABLED=1 \ +SWA_START_FRAC=0.4 \ +SWA_EVERY=50 \ +EVAL_STRIDE=64 \ +EVAL_BATCH_SEQS=32 \ +torchrun --standalone --nproc_per_node=$NPROC train_gpt.py 2>&1 | tee /workspace/custom_tok_train.log + +log "Done!" +grep -E 'val_bpb|final_int8' /workspace/custom_tok_train.log | tail -5 \ No newline at end of file diff --git a/run_mos_sota.sh b/run_mos_sota.sh new file mode 100755 index 0000000000..3c04d53c27 --- /dev/null +++ b/run_mos_sota.sh @@ -0,0 +1,166 @@ +#!/usr/bin/env bash +# === Parameter Golf: MoS + SOTA Techniques on 1x/8x H100 (RunPod) === +# Tests Mixture of Softmax (K=2) with full SOTA technique stack. 
+# +# Usage on RunPod: +# git clone https://github.com/User123331/runpod-testing.git +# cd runpod-testing +# bash run_mos_sota.sh +# +# Modes: +# MODE=baseline bash run_mos_sota.sh # SOTA stack without MoS (control) +# MODE=mos bash run_mos_sota.sh # SOTA stack + MoS K=2 (experiment) +# MODE=smoke bash run_mos_sota.sh # Quick 300s smoke test with MoS + +set -euo pipefail + +log() { printf '[%s] %s\n' "$(date '+%Y-%m-%d %H:%M:%S')" "$*"; } + +# Keep-alive heartbeat: prevents RunPod from killing pod during long builds +(while true; do sleep 60; nvidia-smi > /dev/null 2>&1; done) & +KEEPALIVE_PID=$! +trap "kill ${KEEPALIVE_PID} 2>/dev/null" EXIT + +MODE="${MODE:-mos}" +SEED="${SEED:-1337}" +MAX_WALLCLOCK_SECONDS="${MAX_WALLCLOCK_SECONDS:-600}" +HF_TOKEN="${HF_TOKEN:-${HUGGINGFACE_TOKEN:-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}}" + +case "${MODE}" in + baseline) + USE_MOS=0 + MOS_K=2 + BIGRAM_VOCAB_SIZE="${BIGRAM_VOCAB_SIZE:-2048}" + RUN_TAG="sota_baseline" + ;; + mos) + USE_MOS=1 + MOS_K="${MOS_K:-2}" + BIGRAM_VOCAB_SIZE="${BIGRAM_VOCAB_SIZE:-1024}" # reduced to fit MoS in 16MB + RUN_TAG="sota_mos_k${MOS_K}" + ;; + smoke) + USE_MOS=1 + MOS_K="${MOS_K:-2}" + BIGRAM_VOCAB_SIZE="${BIGRAM_VOCAB_SIZE:-1024}" + MAX_WALLCLOCK_SECONDS=300 + RUN_TAG="sota_mos_smoke" + ;; + *) + echo "Unknown MODE=${MODE}. Use: baseline, mos, smoke" >&2 + exit 1 + ;; +esac + +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +TRAIN_SCRIPT="${SCRIPT_DIR}/train_gpt_mos_sota.py" +RUN_ID="${RUN_TAG}_seed${SEED}_$(date +%Y%m%d_%H%M%S)" +LOG_DIR="${LOG_DIR:-${SCRIPT_DIR}/logs}" +mkdir -p "${LOG_DIR}" +LOG_PATH="${LOG_DIR}/${RUN_ID}.log" + +[ -f "${TRAIN_SCRIPT}" ] || { echo "ERROR: ${TRAIN_SCRIPT} not found"; exit 1; } + +# Ensure deps +python3 -c "import huggingface_hub, zstandard, sentencepiece, numpy" 2>/dev/null || \ + pip install --quiet huggingface_hub zstandard sentencepiece numpy --break-system-packages + +# Build FA3 (selective, ~5 min) if not already installed +if ! 
python3 -c "from flash_attn_interface import flash_attn_func" 2>/dev/null; then + log "FA3 not found. Building selectively (~5 min)..." + FA3_DIR="${HOME}/flash-attention" + if [ ! -d "${FA3_DIR}" ]; then + git clone https://github.com/Dao-AILab/flash-attention.git "${FA3_DIR}" + fi + cd "${FA3_DIR}/hopper" + rm -rf build/ # clear any stale full-build artifacts + mkdir -p flash_attn_3 # pip copies .so here; dir must exist + # Only build bf16 hdim64 SM90 causal — skip everything else + export FLASH_ATTENTION_DISABLE_FP16=TRUE + export FLASH_ATTENTION_DISABLE_FP8=TRUE + export FLASH_ATTENTION_DISABLE_HDIM96=TRUE + export FLASH_ATTENTION_DISABLE_HDIM128=TRUE + export FLASH_ATTENTION_DISABLE_HDIM192=TRUE + export FLASH_ATTENTION_DISABLE_HDIM256=TRUE + export FLASH_ATTENTION_DISABLE_SM80=TRUE + export FLASH_ATTENTION_DISABLE_PAGEDKV=TRUE + export FLASH_ATTENTION_DISABLE_APPENDKV=TRUE + export FLASH_ATTENTION_DISABLE_SOFTCAP=TRUE + export FLASH_ATTENTION_DISABLE_PACKGQA=TRUE + export FLASH_ATTENTION_DISABLE_VARLEN=TRUE + export FLASH_ATTENTION_DISABLE_SPLIT=TRUE + export FLASH_ATTENTION_DISABLE_LOCAL=TRUE + export FLASH_ATTENTION_DISABLE_CLUSTER=TRUE + export FLASH_ATTENTION_DISABLE_HDIMDIFF64=TRUE + export FLASH_ATTENTION_DISABLE_HDIMDIFF192=TRUE + pip install --no-build-isolation --break-system-packages -e . + cd "${SCRIPT_DIR}" + log "FA3 build complete." +else + log "FA3 already installed." +fi + +# Download dataset if needed +DATA_DIR="data/datasets/fineweb10B_sp1024" +TOK_PATH="data/tokenizers/fineweb_1024_bpe.model" +if [ ! -f "${DATA_DIR}/fineweb_train_000000.bin" ] || [ ! -f "${TOK_PATH}" ]; then + log "Downloading FineWeb dataset..." + if [ -n "${HF_TOKEN}" ]; then export HF_TOKEN; fi + python3 data/cached_challenge_fineweb.py --variant sp1024 --train-shards 80 +fi + +GPU_COUNT="$(nvidia-smi --list-gpus 2>/dev/null | wc -l | tr -d ' ')" +log "Detected ${GPU_COUNT} GPU(s). 
Mode: ${MODE}" +log "MoS: USE_MOS=${USE_MOS} MOS_K=${MOS_K} BIGRAM_VOCAB_SIZE=${BIGRAM_VOCAB_SIZE}" +log "Run ID: ${RUN_ID}" + +export PYTHONUNBUFFERED=1 +export RUN_ID +export DATA_PATH="./${DATA_DIR}" +export TOKENIZER_PATH="./${TOK_PATH}" +export VOCAB_SIZE=1024 +export NUM_LAYERS=11 +export MODEL_DIM=512 +export NUM_HEADS=8 +export NUM_KV_HEADS=4 +export MLP_MULT=3.0 +export TIE_EMBEDDINGS=1 +export TRAIN_BATCH_TOKENS=786432 +export TRAIN_SEQ_LEN=2048 +export BIGRAM_VOCAB_SIZE +export BIGRAM_DIM=128 +export MATRIX_LR=0.025 +export SCALAR_LR=0.025 +export TIED_EMBED_LR=0.035 +export MUON_MOMENTUM=0.99 +export MUON_MOMENTUM_WARMUP_START=0.92 +export MUON_MOMENTUM_WARMUP_STEPS=1500 +export WARMDOWN_ITERS=3000 +export ITERATIONS=9000 +export MAX_WALLCLOCK_SECONDS +export EVAL_STRIDE=64 +export SWA_ENABLED=1 +export SWA_EVERY=50 +export MUON_WD=0.04 +export ADAM_WD=0.04 +export XSA_LAST_N=4 +export ROPE_DIMS=16 +export LN_SCALE=1 +export LATE_QAT_THRESHOLD=0.1 +export VE_ENABLED=1 +export VE_DIM=128 +export VE_LAYERS="9,10" +export USE_MOS +export MOS_K +export SEED +export DISABLE_COMPILE="${DISABLE_COMPILE:-1}" # Disable torch.compile by default (fixes inductor issues) + +log "Starting training..." +log "Log file: ${LOG_PATH}" +torchrun --standalone --nproc_per_node="${GPU_COUNT}" "${TRAIN_SCRIPT}" 2>&1 | tee "${LOG_PATH}" +TRAIN_EXIT=${PIPESTATUS[0]} + +log "Training finished (exit code: ${TRAIN_EXIT}). Key metrics:" +grep -E 'val_bpb|model_params|mos_params|final_int|submission|Serialized|artifact|swa:' "${LOG_PATH}" | tail -20 || true + +log "Done. 
Log: ${LOG_PATH}" diff --git a/run_pilot.sh b/run_pilot.sh new file mode 100755 index 0000000000..cb12b317b3 --- /dev/null +++ b/run_pilot.sh @@ -0,0 +1,42 @@ +#!/bin/bash +# Quick start script for 1x H100 MoS pilot +# Run from the parameter-golf repo root directory + +set -e + +echo "=== Parameter Golf MoS Pilot ===" +echo "Date: $(date)" +echo "GPU: 1x H100 SXM" +echo "" + +# Configuration (all via env vars — train_gpt.py has no argparse) +ITERATIONS=2000 +SEED=42 +MOS_K=2 +MOS_RANK=64 # Low-rank to fit in 16MB budget (~100KB vs ~500KB full-rank) + +echo "Configuration:" +echo " Iterations: $ITERATIONS" +echo " Seed: $SEED" +echo " MoS K: $MOS_K" +echo " MoS Rank: $MOS_RANK (0=full-rank)" +echo "" + +# Baseline run +echo "=== Running Baseline ===" +ITERATIONS=$ITERATIONS SEED=$SEED MAX_WALLCLOCK_SECONDS=99999 \ + python3 train_gpt.py 2>&1 | tee baseline_log.txt + +echo "" +echo "=== Running MoS K=$MOS_K rank=$MOS_RANK ===" +ITERATIONS=$ITERATIONS SEED=$SEED MAX_WALLCLOCK_SECONDS=99999 \ + USE_MOS=1 MOS_K=$MOS_K MOS_RANK=$MOS_RANK \ + python3 train_gpt.py 2>&1 | tee mos_k${MOS_K}_r${MOS_RANK}_log.txt + +echo "" +echo "=== Done ===" +echo "Compare results:" +echo " grep 'val_bpb' baseline_log.txt" +echo " grep 'val_bpb' mos_k${MOS_K}_r${MOS_RANK}_log.txt" +echo " grep 'bytes' baseline_log.txt" +echo " grep 'bytes' mos_k${MOS_K}_r${MOS_RANK}_log.txt" \ No newline at end of file diff --git a/setup_and_run.sh b/setup_and_run.sh new file mode 100755 index 0000000000..af74eca5ad --- /dev/null +++ b/setup_and_run.sh @@ -0,0 +1,46 @@ +#!/bin/bash +# === Parameter Golf: MoS Pilot on 1x H100 === +# Paste this into your RunPod terminal. 
+# Total time: ~18 min download + 10 min MoS run = ~28 min +# Baseline already known: ~1.2244 bpb (10min/8xH100) or ~1.2074 (4hr/8xH100) + +set -e + +echo "=== Step 1: Download dataset ===" +# Run from the repo root (already cloned) +cd /workspace/parameter-golf + +# HF token for faster downloads (avoids rate limiting) +export HF_TOKEN="${HF_TOKEN:-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}" + +# Download full dataset (~18 min) +python3 data/cached_challenge_fineweb.py --variant sp1024 + +# Verify dataset +echo "Train shards: $(ls data/datasets/fineweb10B_sp1024/fineweb_train_*.bin 2>/dev/null | wc -l)" +echo "Val shards: $(ls data/datasets/fineweb10B_sp1024/fineweb_val_*.bin 2>/dev/null | wc -l)" + +echo "" +echo "=== Step 2: Run MoS K=2 rank=64 (10 min, 1x H100) ===" +echo "Start time: $(date)" + +RUN_ID=mos_k2_r64_pilot \ +DATA_PATH=./data/datasets/fineweb10B_sp1024 \ +TOKENIZER_PATH=./data/tokenizers/fineweb_1024_bpe.model \ +VOCAB_SIZE=1024 \ +SEED=42 \ +USE_MOS=1 \ +MOS_K=2 \ +MOS_RANK=64 \ +MAX_WALLCLOCK_SECONDS=600 \ +VAL_LOSS_EVERY=500 \ +TRAIN_LOG_EVERY=100 \ +torchrun --standalone --nproc_per_node=1 train_gpt.py 2>&1 | tee /workspace/mos_log.txt + +echo "" +echo "=== RESULTS ===" +echo "" +grep -E 'val_bpb|val_loss|bytes|param|model_params' /workspace/mos_log.txt | tail -15 +echo "" +echo "Known baseline: val_bpb ~1.2244 (10min/8xH100)" +echo "Done at: $(date)" diff --git a/setup_and_run_1h.sh b/setup_and_run_1h.sh new file mode 100644 index 0000000000..1efe040661 --- /dev/null +++ b/setup_and_run_1h.sh @@ -0,0 +1,71 @@ +#!/bin/bash +# === Parameter Golf: MoS 1-Hour Validation on 1x H100 === +# Usage: bash setup_and_run_1h.sh +# The script runs training inside nohup so it survives terminal disconnects. 
+# Log is written to /workspace/mos_1h_log.txt — check with: tail -f /workspace/mos_1h_log.txt + +set -e + +echo "=== Step 1: Download dataset ===" +cd /workspace/parameter-golf + +# HF token for faster downloads +export HF_TOKEN="${HF_TOKEN:-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}" + +# Download full dataset (~18 min, skips if already present) +python3 data/cached_challenge_fineweb.py --variant sp1024 + +# Verify dataset +TRAIN_COUNT=$(ls data/datasets/fineweb10B_sp1024/fineweb_train_*.bin 2>/dev/null | wc -l) +VAL_COUNT=$(ls data/datasets/fineweb10B_sp1024/fineweb_val_*.bin 2>/dev/null | wc -l) +echo "Train shards: $TRAIN_COUNT Val shards: $VAL_COUNT" +if [ "$TRAIN_COUNT" -lt 1 ]; then + echo "ERROR: No training shards found. Dataset download failed." + exit 1 +fi + +echo "" +echo "=== Step 2: Run MoS K=2 R=64 (1 HOUR, 1x H100) ===" +echo "Start time: $(date)" +echo "" +echo "Training will run in the background via nohup." +echo "Monitor with: tail -f /workspace/mos_1h_log.txt" +echo "Check GPU with: nvidia-smi" +echo "Safe to close terminal — training will continue." +echo "" + +nohup bash -c ' +RUN_ID=mos_k2_r64_1h \ +DATA_PATH=./data/datasets/fineweb10B_sp1024 \ +TOKENIZER_PATH=./data/tokenizers/fineweb_1024_bpe.model \ +VOCAB_SIZE=1024 \ +SEED=42 \ +USE_MOS=1 \ +MOS_K=2 \ +MOS_RANK=64 \ +WARMDOWN_ITERS=100 \ +MAX_WALLCLOCK_SECONDS=3600 \ +VAL_LOSS_EVERY=500 \ +TRAIN_LOG_EVERY=100 \ +torchrun --standalone --nproc_per_node=1 train_gpt.py +' > /workspace/mos_1h_log.txt 2>&1 & + +TRAIN_PID=$! +echo "Training PID: $TRAIN_PID" +echo "PID saved to /workspace/train.pid" +echo "$TRAIN_PID" > /workspace/train.pid + +# Wait a few seconds and confirm it started +sleep 5 +if kill -0 $TRAIN_PID 2>/dev/null; then + echo "Training is running. You can safely close this terminal." 
+ echo "" + echo "=== Quick commands ===" + echo " Monitor: tail -f /workspace/mos_1h_log.txt" + echo " Status: nvidia-smi" + echo " Kill: kill \$(cat /workspace/train.pid)" +else + echo "ERROR: Training process died. Check /workspace/mos_1h_log.txt" + tail -20 /workspace/mos_1h_log.txt + exit 1 +fi diff --git a/setup_and_run_sota_1xh100.sh b/setup_and_run_sota_1xh100.sh new file mode 100755 index 0000000000..3fe4b71b25 --- /dev/null +++ b/setup_and_run_sota_1xh100.sh @@ -0,0 +1,262 @@ +#!/usr/bin/env bash +# === Parameter Golf: SOTA comparison on 1x H100 (RunPod) === +# Default target: PR #198 (1.1318 bpb on 8xH100, current best open result in local notes) +# +# Usage on RunPod: +# git clone https://github.com/User123331/runpod-testing.git +# cd runpod-testing +# HF_TOKEN=hf_xxx bash setup_and_run_sota_1xh100.sh +# +# Optional: +# TARGET_PR=180 bash setup_and_run_sota_1xh100.sh # thwu1 merged record +# SEED=42 bash setup_and_run_sota_1xh100.sh +# TRAIN_SHARDS=1 bash setup_and_run_sota_1xh100.sh # smoke download only + +set -euo pipefail + +log() { printf '[%s] %s\n' "$(date '+%Y-%m-%d %H:%M:%S')" "$*"; } +warn() { printf '[%s] WARNING: %s\n' "$(date '+%Y-%m-%d %H:%M:%S')" "$*" >&2; } +die() { printf '[%s] ERROR: %s\n' "$(date '+%Y-%m-%d %H:%M:%S')" "$*" >&2; exit 1; } + +require_cmd() { + command -v "$1" >/dev/null 2>&1 || die "Required command not found: $1" +} + +require_clean_checkout() { + if ! git diff --quiet || ! git diff --cached --quiet; then + die "Existing checkout at ${SRC_DIR} has uncommitted changes. Use a fresh SRC_DIR." 
+ fi +} + +discover_legacy_hf_token() { + python3 - "$@" <<'PY' +from pathlib import Path +import re +import sys + +pattern = re.compile(r'export\s+HF_TOKEN="\$\{HF_TOKEN:-([^"}]+)\}"') + +for raw_path in sys.argv[1:]: + path = Path(raw_path) + if not path.is_file(): + continue + text = path.read_text(encoding="utf-8", errors="ignore") + match = pattern.search(text) + if match: + print(f"{match.group(1)}\t{path}") + sys.exit(0) + +sys.exit(1) +PY +} + +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +TARGET_PR="${TARGET_PR:-198}" +WORKSPACE_DIR="${WORKSPACE_DIR:-/workspace}" +SRC_DIR="${SRC_DIR:-${WORKSPACE_DIR}/parameter-golf-pr${TARGET_PR}}" +LOG_DIR="${LOG_DIR:-${WORKSPACE_DIR}/logs}" +TRAIN_SHARDS="${TRAIN_SHARDS:-80}" +MAX_WALLCLOCK_SECONDS="${MAX_WALLCLOCK_SECONDS:-600}" +TRAIN_BATCH_TOKENS="${TRAIN_BATCH_TOKENS:-786432}" +TRAIN_SEQ_LEN="${TRAIN_SEQ_LEN:-2048}" +HF_TOKEN="${HF_TOKEN:-${HUGGINGFACE_TOKEN:-}}" +HF_TOKEN_SOURCE="" + +case "${TARGET_PR}" in + 198) + TARGET_SHA="${TARGET_SHA:-372bddea57f465c7217c5e26af2252a803221518}" + TRAIN_SCRIPT_REL="records/track_10min_16mb/2026-03-20_11L_Int6_MLP3x_WD04_SmearBigram2k_1.1318/train_gpt.py" + NUM_LAYERS="${NUM_LAYERS:-11}" + BIGRAM_VOCAB_SIZE="${BIGRAM_VOCAB_SIZE:-2048}" + MATRIX_LR="${MATRIX_LR:-0.025}" + SCALAR_LR="${SCALAR_LR:-0.025}" + TIED_EMBED_LR="${TIED_EMBED_LR:-0.035}" + MUON_MOMENTUM="${MUON_MOMENTUM:-0.99}" + MUON_MOMENTUM_WARMUP_START="${MUON_MOMENTUM_WARMUP_START:-0.92}" + MUON_MOMENTUM_WARMUP_STEPS="${MUON_MOMENTUM_WARMUP_STEPS:-1500}" + WARMDOWN_ITERS="${WARMDOWN_ITERS:-3000}" + ITERATIONS="${ITERATIONS:-9000}" + EVAL_STRIDE="${EVAL_STRIDE:-64}" + VAL_LOSS_EVERY="${VAL_LOSS_EVERY:-1000}" + TRAIN_LOG_EVERY="${TRAIN_LOG_EVERY:-200}" + SWA_EVERY="${SWA_EVERY:-200}" + MUON_WD="${MUON_WD:-0.04}" + ADAM_WD="${ADAM_WD:-0.04}" + SEED="${SEED:-1337}" + REQUIRED_PY_MODULES="flash_attn_interface" + ;; + 180) + TARGET_SHA="${TARGET_SHA:-1a8be36c17e20b1fb53dbf4975e1d67f5b8a63e9}" + 
TRAIN_SCRIPT_REL="records/track_10min_16mb/2026-03-20_10L_Int5MLP_MuonWD04_SWA50/train_gpt.py" + NUM_LAYERS="${NUM_LAYERS:-10}" + BIGRAM_VOCAB_SIZE="${BIGRAM_VOCAB_SIZE:-10240}" + MATRIX_LR="${MATRIX_LR:-0.02}" + SCALAR_LR="${SCALAR_LR:-0.02}" + TIED_EMBED_LR="${TIED_EMBED_LR:-0.03}" + MUON_MOMENTUM="${MUON_MOMENTUM:-0.99}" + MUON_MOMENTUM_WARMUP_START="${MUON_MOMENTUM_WARMUP_START:-0.92}" + MUON_MOMENTUM_WARMUP_STEPS="${MUON_MOMENTUM_WARMUP_STEPS:-1500}" + WARMDOWN_ITERS="${WARMDOWN_ITERS:-3000}" + ITERATIONS="${ITERATIONS:-9000}" + EVAL_STRIDE="${EVAL_STRIDE:-64}" + VAL_LOSS_EVERY="${VAL_LOSS_EVERY:-500}" + TRAIN_LOG_EVERY="${TRAIN_LOG_EVERY:-100}" + SWA_START_FRAC="${SWA_START_FRAC:-0.4}" + SWA_EVERY="${SWA_EVERY:-50}" + WEIGHT_DECAY="${WEIGHT_DECAY:-0.04}" + SEED="${SEED:-42}" + REQUIRED_PY_MODULES="" + ;; + *) + die "Unsupported TARGET_PR=${TARGET_PR}. Use 198 (default) or 180." + ;; +esac + +RUN_ID="${RUN_ID:-pr${TARGET_PR}_1xh100_seed${SEED}_$(date +%Y%m%d_%H%M%S)}" +LOG_PATH="${LOG_DIR}/${RUN_ID}.log" + +require_cmd git +require_cmd python3 +require_cmd torchrun +require_cmd nvidia-smi + +GPU_COUNT="$(nvidia-smi --list-gpus | wc -l | tr -d ' ')" +[ "${GPU_COUNT}" -ge 1 ] || die "No GPUs detected." + +log "Detected ${GPU_COUNT} GPU(s). This script uses exactly 1 GPU." 
+if [ -z "${HF_TOKEN}" ]; then + if TOKEN_RECORD="$( + discover_legacy_hf_token \ + "${SCRIPT_DIR}/setup_and_run.sh" \ + "${SCRIPT_DIR}/setup_and_run_1h.sh" \ + "${SCRIPT_DIR}/run_custom_tokenizer_pipeline.sh" \ + "${SCRIPT_DIR}/../parameter-golf/setup_and_run.sh" \ + "${SCRIPT_DIR}/../parameter-golf/setup_and_run_1h.sh" \ + "${SCRIPT_DIR}/../trainer-tokenizer/setup_runpod.sh" + )"; then + IFS=$'\t' read -r HF_TOKEN HF_TOKEN_SOURCE <<< "${TOKEN_RECORD}" + fi +fi + +if [ -n "${HF_TOKEN}" ]; then + export HF_TOKEN + if [ -n "${HF_TOKEN_SOURCE}" ]; then + warn "Using HF token found in existing local script: ${HF_TOKEN_SOURCE}" + else + log "HF token detected in environment; authenticated downloads enabled." + fi +else + warn "HF_TOKEN/HUGGINGFACE_TOKEN not set. Public downloads may still work, but auth is recommended." +fi + +python3 - <<'PY' +import importlib.util +import subprocess +import sys + +missing = [pkg for pkg in ("huggingface_hub", "zstandard") if importlib.util.find_spec(pkg) is None] +if missing: + subprocess.check_call([sys.executable, "-m", "pip", "install", "--quiet", *missing]) +PY + +mkdir -p "${LOG_DIR}" + +if [ ! 
-d "${SRC_DIR}/.git" ]; then + log "Cloning openai/parameter-golf into ${SRC_DIR}" + git clone https://github.com/openai/parameter-golf.git "${SRC_DIR}" +fi + +cd "${SRC_DIR}" +require_clean_checkout + +log "Fetching PR #${TARGET_PR}" +git fetch origin "pull/${TARGET_PR}/head:runpod-pr-${TARGET_PR}" --force +git checkout --detach "${TARGET_SHA}" + +CURRENT_SHA="$(git rev-parse HEAD)" +[ "${CURRENT_SHA}" = "${TARGET_SHA}" ] || die "Checked out ${CURRENT_SHA}, expected ${TARGET_SHA}" +log "Checked out PR #${TARGET_PR} commit ${CURRENT_SHA}" +[ -f "${TRAIN_SCRIPT_REL}" ] || die "Target training script not found: ${TRAIN_SCRIPT_REL}" + +python3 - "${TRAIN_SCRIPT_REL}" "${REQUIRED_PY_MODULES}" <<'PY' +from pathlib import Path +import importlib.util +import sys + +train_script = Path(sys.argv[1]) +required_modules = [m for m in sys.argv[2].split(",") if m] + +source = train_script.read_text(encoding="utf-8") +compile(source, str(train_script), "exec") + +missing = [m for m in required_modules if importlib.util.find_spec(m) is None] +if missing: + raise SystemExit(f"Missing required Python modules for {train_script}: {', '.join(missing)}") +PY +log "Preflight compile check passed for ${TRAIN_SCRIPT_REL}" + +DATASET_DIR="data/datasets/fineweb10B_sp1024" +TOKENIZER_PATH="data/tokenizers/fineweb_1024_bpe.model" + +if [ ! -f "${DATASET_DIR}/fineweb_train_000000.bin" ] || [ ! -f "${TOKENIZER_PATH}" ]; then + log "Downloading FineWeb cached dataset/tokenizer (train_shards=${TRAIN_SHARDS})" + python3 data/cached_challenge_fineweb.py --variant sp1024 --train-shards "${TRAIN_SHARDS}" +else + log "Dataset/tokenizer already present; skipping download." 
+fi + +TRAIN_COUNT="$(find "${DATASET_DIR}" -maxdepth 1 -name 'fineweb_train_*.bin' 2>/dev/null | wc -l | tr -d ' ')" +VAL_COUNT="$(find "${DATASET_DIR}" -maxdepth 1 -name 'fineweb_val_*.bin' 2>/dev/null | wc -l | tr -d ' ')" +log "Dataset ready: train_shards=${TRAIN_COUNT} val_shards=${VAL_COUNT}" + +export PYTHONUNBUFFERED=1 +export RUN_ID +export DATA_PATH="./${DATASET_DIR}" +export TOKENIZER_PATH="./${TOKENIZER_PATH}" +export VOCAB_SIZE="${VOCAB_SIZE:-1024}" +export NUM_LAYERS +export MODEL_DIM="${MODEL_DIM:-512}" +export NUM_HEADS="${NUM_HEADS:-8}" +export NUM_KV_HEADS="${NUM_KV_HEADS:-4}" +export MLP_MULT="${MLP_MULT:-3.0}" +export TIE_EMBEDDINGS="${TIE_EMBEDDINGS:-1}" +export TRAIN_BATCH_TOKENS +export TRAIN_SEQ_LEN +export BIGRAM_VOCAB_SIZE +export BIGRAM_DIM="${BIGRAM_DIM:-128}" +export MATRIX_LR +export SCALAR_LR +export TIED_EMBED_LR +export MUON_MOMENTUM +export MUON_MOMENTUM_WARMUP_START +export MUON_MOMENTUM_WARMUP_STEPS +export SWA_ENABLED="${SWA_ENABLED:-1}" +export EVAL_STRIDE +export EVAL_BATCH_SEQS="${EVAL_BATCH_SEQS:-32}" +export ITERATIONS +export WARMDOWN_ITERS +export MAX_WALLCLOCK_SECONDS +export VAL_LOSS_EVERY +export TRAIN_LOG_EVERY +export SEED + +if [ "${TARGET_PR}" = "198" ]; then + export MUON_WD + export ADAM_WD + export SWA_EVERY +else + export WEIGHT_DECAY + export SWA_START_FRAC + export SWA_EVERY +fi + +log "Starting 1xH100 run for PR #${TARGET_PR}" +log "Run ID: ${RUN_ID}" +log "Log file: ${LOG_PATH}" + +torchrun --standalone --nproc_per_node=1 "${TRAIN_SCRIPT_REL}" 2>&1 | tee "${LOG_PATH}" + +log "Run completed. Final metrics:" +grep -E 'val_bpb|val_loss|artifact|bytes|final_int|submission|model_params|swa:' "${LOG_PATH}" | tail -20 || true + +log "Done." 
diff --git a/setup_hyperbolic.sh b/setup_hyperbolic.sh new file mode 100644 index 0000000000..532e277ff1 --- /dev/null +++ b/setup_hyperbolic.sh @@ -0,0 +1,100 @@ +#!/usr/bin/env bash +# === Hyperbolic.ai 8x H100 Setup Script === +# Run this after SSHing into your instance +# +# Usage: +# wget https://raw.githubusercontent.com/User123331/runpod-testing/main/setup_hyperbolic.sh +# chmod +x setup_hyperbolic.sh +# ./setup_hyperbolic.sh +# +# Or paste directly from clipboard + +set -euo pipefail + +log() { printf '[%s] %s\n' "$(date '+%Y-%m-%d %H:%M:%S')" "$*"; } + +# Check GPU count +GPU_COUNT=$(nvidia-smi --list-gpus 2>/dev/null | wc -l | tr -d ' ') +log "Detected ${GPU_COUNT} GPU(s)" + +if [ "${GPU_COUNT}" -lt 8 ]; then + log "WARNING: Expected 8 GPUs, found ${GPU_COUNT}" +fi + +# Keep-alive to prevent timeout during long builds +(while true; do sleep 60; nvidia-smi > /dev/null 2>&1; done) & +KEEPALIVE_PID=$! +trap "kill ${KEEPALIVE_PID} 2>/dev/null" EXIT + +# 1. Clone the competition repo (already in image, but verify) +if [ ! -d "/workspace/parameter-golf" ]; then + log "Cloning parameter-golf repo..." + cd /workspace + git clone https://github.com/openai/parameter-golf.git + cd parameter-golf +else + log "parameter-golf repo already exists" + cd /workspace/parameter-golf +fi + +# 2. Clone our MoS-enhanced training scripts +log "Cloning runpod-testing repo with MoS implementation..." +if [ ! -d "/workspace/runpod-testing" ]; then + cd /workspace + git clone https://github.com/User123331/runpod-testing.git +else + cd /workspace/runpod-testing + git pull || true +fi + +# 3. Build Flash Attention 3 (selective, ~5 min) +log "Building Flash Attention 3 (selective kernels only)..." +if python3 -c "from flash_attn_interface import flash_attn_func" 2>/dev/null; then + log "FA3 already installed" +else + FA3_DIR="/workspace/flash-attention" + if [ ! 
-d "${FA3_DIR}" ]; then + git clone https://github.com/Dao-AILab/flash-attention.git "${FA3_DIR}" + fi + cd "${FA3_DIR}/hopper" + rm -rf build/ + mkdir -p flash_attn_3 + + # Only build bf16 hdim64 SM90 causal — skip everything else + export FLASH_ATTENTION_DISABLE_FP16=TRUE + export FLASH_ATTENTION_DISABLE_FP8=TRUE + export FLASH_ATTENTION_DISABLE_HDIM96=TRUE + export FLASH_ATTENTION_DISABLE_HDIM128=TRUE + export FLASH_ATTENTION_DISABLE_HDIM192=TRUE + export FLASH_ATTENTION_DISABLE_HDIM256=TRUE + export FLASH_ATTENTION_DISABLE_SM80=TRUE + export FLASH_ATTENTION_DISABLE_PAGEDKV=TRUE + export FLASH_ATTENTION_DISABLE_APPENDKV=TRUE + export FLASH_ATTENTION_DISABLE_SOFTCAP=TRUE + export FLASH_ATTENTION_DISABLE_PACKGQA=TRUE + export FLASH_ATTENTION_DISABLE_VARLEN=TRUE + export FLASH_ATTENTION_DISABLE_SPLIT=TRUE + export FLASH_ATTENTION_DISABLE_LOCAL=TRUE + export FLASH_ATTENTION_DISABLE_CLUSTER=TRUE + export FLASH_ATTENTION_DISABLE_HDIMDIFF64=TRUE + export FLASH_ATTENTION_DISABLE_HDIMDIFF192=TRUE + + log "Starting FA3 selective build (~5 min)..." + pip install --no-build-isolation -e . + log "FA3 build complete" +fi + +# 4. Download dataset (80 train shards = 8B tokens) +cd /workspace/parameter-golf +log "Downloading FineWeb dataset (8B tokens)..." +HF_TOKEN="${HF_TOKEN:-${HUGGINGFACE_TOKEN:-}}" python3 data/cached_challenge_fineweb.py --variant sp1024 --train-shards 80 + +# 5. Quick sanity check +log "" +log "=== Setup Complete ===" +log "GPU Count: ${GPU_COUNT}" +log "FA3 Status: $(python3 -c 'from flash_attn_interface import flash_attn_func; print("OK")' 2>/dev/null || echo 'FAILED')" +log "Dataset: $(ls -1 data/datasets/fineweb10B_sp1024/fineweb_train_*.bin 2>/dev/null | wc -l) train shards" +log "" +log "Ready to run experiments!" 
+log "See: /workspace/runpod-testing/run_mos_sota.sh" \ No newline at end of file diff --git a/test_mos.py b/test_mos.py new file mode 100644 index 0000000000..6732b343b0 --- /dev/null +++ b/test_mos.py @@ -0,0 +1,98 @@ +#!/usr/bin/env python3 +"""Quick test to verify MoS implementation works correctly.""" + +import torch +import torch.nn.functional as F + +# Mock CastedLinear for testing +class CastedLinear(torch.nn.Linear): + def forward(self, x): + bias = self.bias.to(x.dtype) if self.bias is not None else None + return F.linear(x, self.weight.to(x.dtype), bias) + + +class MixtureOfSoftmax(torch.nn.Module): + """Mixture of Softmax output layer for breaking the softmax bottleneck.""" + + def __init__(self, model_dim: int, vocab_size: int, n_mixtures: int = 2, rank: int = 0): + super().__init__() + self.n_mixtures = n_mixtures + self.model_dim = model_dim + self.vocab_size = vocab_size + self.rank = rank + + if rank > 0: + self.proj_down = CastedLinear(model_dim, rank, bias=False) + self.proj_up = CastedLinear(rank, n_mixtures * model_dim, bias=False) + torch.nn.init.normal_(self.proj_down.weight, mean=0.0, std=0.02) + torch.nn.init.normal_(self.proj_up.weight, mean=0.0, std=0.02) + else: + self.projections = CastedLinear(model_dim, n_mixtures * model_dim, bias=False) + torch.nn.init.normal_(self.projections.weight, mean=0.0, std=0.02) + + self.gate = CastedLinear(model_dim, n_mixtures, bias=False) + torch.nn.init.normal_(self.gate.weight, mean=0.0, std=0.02) + + def forward(self, hidden: torch.Tensor, weight_matrix: torch.Tensor) -> torch.Tensor: + bsz, seq_len, dim = hidden.shape + K = self.n_mixtures + + pi = F.softmax(self.gate(hidden), dim=-1) + + if self.rank > 0: + projected = self.proj_up(self.proj_down(hidden)).view(bsz, seq_len, K, dim) + else: + projected = self.projections(hidden).view(bsz, seq_len, K, dim) + + logits = torch.einsum('bskd,vd->bskv', projected, weight_matrix) + + log_probs = F.log_softmax(logits, dim=-1) + log_pi = 
torch.log(pi.unsqueeze(-1) + 1e-10) + mixed_log_probs = torch.logsumexp(log_probs + log_pi, dim=2) + + return mixed_log_probs + + +def test_mos(): + """Test MoS forward pass.""" + print("Testing MoS implementation...") + + vocab_size = 1024 + model_dim = 512 + batch_size = 2 + seq_len = 16 + + hidden = torch.randn(batch_size, seq_len, model_dim) + weight_matrix = torch.randn(vocab_size, model_dim) + + for K in [1, 2, 3]: + for rank in [0, 32, 64]: + mos = MixtureOfSoftmax(model_dim, vocab_size, n_mixtures=K, rank=rank) + output = mos(hidden, weight_matrix) + + assert output.shape == (batch_size, seq_len, vocab_size), f"Wrong shape: {output.shape}" + + # Verify output is valid log probabilities + probs = torch.exp(output) + prob_sum = probs.sum(dim=-1) + assert torch.allclose(prob_sum, torch.ones_like(prob_sum), atol=1e-4), \ + f"K={K} rank={rank}: probs don't sum to 1: {prob_sum.mean():.6f}" + + # Count parameters + params = sum(p.numel() for p in mos.parameters()) + size_kb = params / 1024 + print(f" K={K} rank={rank:>3d}: {params:>10,} params ({size_kb:>7.1f} KB at int8)") + + # Verify NLL loss works correctly with MoS output + mos = MixtureOfSoftmax(model_dim, vocab_size, n_mixtures=2, rank=64) + output = mos(hidden, weight_matrix) + targets = torch.randint(0, vocab_size, (batch_size, seq_len)) + loss = F.nll_loss(output.reshape(-1, vocab_size), targets.reshape(-1)) + assert loss.isfinite(), f"NLL loss is not finite: {loss}" + print(f"\n NLL loss test: {loss.item():.4f} (should be ~6.93 for random)") + + print("\nAll tests passed!") + + +if __name__ == "__main__": + test_mos() diff --git a/train_gpt.py b/train_gpt.py index 85e2cc463a..bbe5ab2943 100644 --- a/train_gpt.py +++ b/train_gpt.py @@ -19,6 +19,12 @@ import zlib from pathlib import Path +try: + import zstandard + _COMPRESSOR = "zstd" +except ImportError: + _COMPRESSOR = "zlib" + import numpy as np import sentencepiece as spm import torch @@ -30,79 +36,68 @@ # ----------------------------- # 
HYPERPARAMETERS # ----------------------------- -# Default Simple Baseline run: -# - 9 transformer blocks at width 512 -# - 8 attention heads with 4 KV heads (GQA) and 2x MLP expansion -# - vocab size 1024, sequence length 1024, tied embeddings -# - 524,288 train tokens per step for 20,000 iterations with a ~10 minute cap class Hyperparameters: - # Data paths are shard globs produced by the existing preprocessing pipeline. data_path = os.environ.get("DATA_PATH", "./data/datasets/fineweb10B_sp1024") train_files = os.path.join(data_path, "fineweb_train_*.bin") val_files = os.path.join(data_path, "fineweb_val_*.bin") tokenizer_path = os.environ.get("TOKENIZER_PATH", "./data/tokenizers/fineweb_1024_bpe.model") run_id = os.environ.get("RUN_ID", str(uuid.uuid4())) - seed = int(os.environ.get("SEED", 1337)) + seed = int(os.environ.get("SEED", 42)) - # Validation cadence and batch size. Validation always uses the full fineweb_val split. val_batch_size = int(os.environ.get("VAL_BATCH_SIZE", 524_288)) - val_loss_every = int(os.environ.get("VAL_LOSS_EVERY", 1000)) - train_log_every = int(os.environ.get("TRAIN_LOG_EVERY", 200)) + val_loss_every = int(os.environ.get("VAL_LOSS_EVERY", 500)) + train_log_every = int(os.environ.get("TRAIN_LOG_EVERY", 100)) - # Training length. iterations = int(os.environ.get("ITERATIONS", 20000)) - warmdown_iters = int(os.environ.get("WARMDOWN_ITERS", 1200)) + warmdown_iters = int(os.environ.get("WARMDOWN_ITERS", 3000)) warmup_steps = int(os.environ.get("WARMUP_STEPS", 20)) - train_batch_tokens = int(os.environ.get("TRAIN_BATCH_TOKENS", 524_288)) - train_seq_len = int(os.environ.get("TRAIN_SEQ_LEN", 1024)) + train_batch_tokens = int(os.environ.get("TRAIN_BATCH_TOKENS", 786_432)) + train_seq_len = int(os.environ.get("TRAIN_SEQ_LEN", 2048)) max_wallclock_seconds = float(os.environ.get("MAX_WALLCLOCK_SECONDS", 600.0)) qk_gain_init = float(os.environ.get("QK_GAIN_INIT", 1.5)) - # Model shape. 
vocab_size = int(os.environ.get("VOCAB_SIZE", 1024)) - num_layers = int(os.environ.get("NUM_LAYERS", 9)) + num_layers = int(os.environ.get("NUM_LAYERS", 10)) num_kv_heads = int(os.environ.get("NUM_KV_HEADS", 4)) model_dim = int(os.environ.get("MODEL_DIM", 512)) num_heads = int(os.environ.get("NUM_HEADS", 8)) - mlp_mult = int(os.environ.get("MLP_MULT", 2)) + mlp_mult = float(os.environ.get("MLP_MULT", 3.0)) tie_embeddings = bool(int(os.environ.get("TIE_EMBEDDINGS", "1"))) rope_base = float(os.environ.get("ROPE_BASE", 10000.0)) logit_softcap = float(os.environ.get("LOGIT_SOFTCAP", 30.0)) - # Optimizer hyperparameters. embed_lr = float(os.environ.get("EMBED_LR", 0.6)) head_lr = float(os.environ.get("HEAD_LR", 0.008)) - tied_embed_lr = float(os.environ.get("TIED_EMBED_LR", 0.05)) + tied_embed_lr = float(os.environ.get("TIED_EMBED_LR", 0.03)) tied_embed_init_std = float(os.environ.get("TIED_EMBED_INIT_STD", 0.005)) - matrix_lr = float(os.environ.get("MATRIX_LR", 0.04)) - scalar_lr = float(os.environ.get("SCALAR_LR", 0.04)) - muon_momentum = float(os.environ.get("MUON_MOMENTUM", 0.95)) + matrix_lr = float(os.environ.get("MATRIX_LR", 0.02)) + scalar_lr = float(os.environ.get("SCALAR_LR", 0.02)) + muon_momentum = float(os.environ.get("MUON_MOMENTUM", 0.99)) muon_backend_steps = int(os.environ.get("MUON_BACKEND_STEPS", 5)) - muon_momentum_warmup_start = float(os.environ.get("MUON_MOMENTUM_WARMUP_START", 0.85)) - muon_momentum_warmup_steps = int(os.environ.get("MUON_MOMENTUM_WARMUP_STEPS", 500)) + muon_momentum_warmup_start = float(os.environ.get("MUON_MOMENTUM_WARMUP_START", 0.92)) + muon_momentum_warmup_steps = int(os.environ.get("MUON_MOMENTUM_WARMUP_STEPS", 1500)) beta1 = float(os.environ.get("BETA1", 0.9)) beta2 = float(os.environ.get("BETA2", 0.95)) adam_eps = float(os.environ.get("ADAM_EPS", 1e-8)) - grad_clip_norm = float(os.environ.get("GRAD_CLIP_NORM", 0.0)) + grad_clip_norm = float(os.environ.get("GRAD_CLIP_NORM", 0.3)) + weight_decay = 
float(os.environ.get("WEIGHT_DECAY", 0.04)) + + eval_stride = int(os.environ.get("EVAL_STRIDE", 64)) + eval_batch_seqs = int(os.environ.get("EVAL_BATCH_SEQS", 32)) - # Test-time training (LoRA) hyperparameters. - ttt_lora_rank = int(os.environ.get("TTT_LORA_RANK", 8)) - ttt_lora_lr = float(os.environ.get("TTT_LORA_LR", 0.01)) - ttt_chunk_size = int(os.environ.get("TTT_CHUNK_SIZE", 256)) - ttt_eval_seq_len = int(os.environ.get("TTT_EVAL_SEQ_LEN", 1024)) - ttt_batch_size = int(os.environ.get("TTT_BATCH_SIZE", 64)) + bigram_vocab_size = int(os.environ.get("BIGRAM_VOCAB_SIZE", 10240)) + bigram_dim = int(os.environ.get("BIGRAM_DIM", 128)) + + swa_enabled = bool(int(os.environ.get("SWA_ENABLED", "1"))) + swa_start_frac = float(os.environ.get("SWA_START_FRAC", 0.4)) + swa_every = int(os.environ.get("SWA_EVERY", 50)) # ----------------------------- -# MUON OPTIMIZER +# MUON OPTIMIZER # ----------------------------- -# -# As borrowed from modded-nanogpt -# Background on Muon: https://kellerjordan.github.io/posts/muon/ def zeropower_via_newtonschulz5(G: Tensor, steps: int = 10, eps: float = 1e-7) -> Tensor: - # Orthogonalize a 2D update matrix with a fast Newton-Schulz iteration. - # Muon uses this to normalize matrix-shaped gradients before applying them. 
a, b, c = (3.4445, -4.7750, 2.0315) X = G.bfloat16() X /= X.norm() + eps @@ -117,10 +112,10 @@ def zeropower_via_newtonschulz5(G: Tensor, steps: int = 10, eps: float = 1e-7) - class Muon(torch.optim.Optimizer): - def __init__(self, params, lr: float, momentum: float, backend_steps: int, nesterov: bool = True): + def __init__(self, params, lr: float, momentum: float, backend_steps: int, nesterov: bool = True, weight_decay: float = 0.0): super().__init__( params, - dict(lr=lr, momentum=momentum, backend_steps=backend_steps, nesterov=nesterov), + dict(lr=lr, momentum=momentum, backend_steps=backend_steps, nesterov=nesterov, weight_decay=weight_decay), ) @torch.no_grad() @@ -129,7 +124,6 @@ def step(self, closure=None): if closure is not None: with torch.enable_grad(): loss = closure() - distributed = dist.is_available() and dist.is_initialized() world_size = dist.get_world_size() if distributed else 1 rank = dist.get_rank() if distributed else 0 @@ -158,7 +152,6 @@ def step(self, closure=None): if nesterov: g = g.add(buf, alpha=momentum) g = zeropower_via_newtonschulz5(g, steps=backend_steps) - # Scale correction from Muon reference implementations. g *= max(1, g.size(0) / g.size(1)) ** 0.5 updates_flat[curr : curr + p.numel()] = g.reshape(-1) curr += p.numel() @@ -166,23 +159,20 @@ def step(self, closure=None): if distributed: dist.all_reduce(updates_flat, op=dist.ReduceOp.SUM) + wd = group.get("weight_decay", 0.0) curr = 0 for p in params: g = updates_flat[curr : curr + p.numel()].view_as(p).to(dtype=p.dtype) + if wd > 0: + p.data.mul_(1.0 - lr * wd) p.add_(g, alpha=-lr) curr += p.numel() - return loss # ----------------------------- -# TOKENIZER-AGNOSTIC EVALUATION SETUP +# TOKENIZER-AGNOSTIC EVALUATION # ----------------------------- -# -# It's common for small models have a large fraction of their parameters be embeddings, since the 2 * d_model * d_vocab vectors can be gigantic. 
-# Instead of locking the tokenizer, we let you bring your own and calculate our validation metrics on the average compression of the validation set. -# We calculate BPB (bits-per-byte) instead of validation loss, so we need methods to count the number of bits per token in the tokenizer. -# Note: Submissions that edit the tokenizer will be examined more carefully, since screwing this up might unjustly improve your score. def build_sentencepiece_luts( sp: spm.SentencePieceProcessor, vocab_size: int, device: torch.device @@ -200,7 +190,7 @@ def build_sentencepiece_luts( base_bytes_np[token_id] = 1 continue piece = sp.id_to_piece(token_id) - if piece.startswith("▁"): + if piece.startswith("\u2581"): has_leading_space_np[token_id] = True piece = piece[1:] base_bytes_np[token_id] = len(piece.encode("utf-8")) @@ -215,7 +205,6 @@ def load_validation_tokens(pattern: str, seq_len: int) -> Tensor: files = [Path(p) for p in sorted(glob.glob(pattern))] if not files: raise FileNotFoundError(f"No files found for pattern: {pattern}") - # The export pipeline writes the fixed first-50k-doc validation set to fineweb_val_*. 
tokens = torch.cat([load_data_shard(file) for file in files]).contiguous() usable = ((tokens.numel() - 1) // seq_len) * seq_len if usable <= 0: @@ -235,9 +224,6 @@ def eval_val( has_leading_space_lut: Tensor, is_boundary_token_lut: Tensor, ) -> tuple[float, float]: - # Validation computes two metrics: - # - val_loss: token cross-entropy (natural log) - # - val_bpb: tokenizer-agnostic compression metric used by the challenge local_batch_tokens = args.val_batch_size // (world_size * grad_accum_steps) if local_batch_tokens < args.train_seq_len: raise ValueError( @@ -252,7 +238,6 @@ def eval_val( val_loss_sum = torch.zeros((), device=device, dtype=torch.float64) val_token_count = torch.zeros((), device=device, dtype=torch.float64) val_byte_count = torch.zeros((), device=device, dtype=torch.float64) - model.eval() with torch.inference_mode(): for batch_seq_start in range(seq_start, seq_end, local_batch_seqs): @@ -272,34 +257,34 @@ def eval_val( token_bytes = base_bytes_lut[tgt_ids].to(dtype=torch.int16) token_bytes += (has_leading_space_lut[tgt_ids] & ~is_boundary_token_lut[prev_ids]).to(dtype=torch.int16) val_byte_count += token_bytes.to(torch.float64).sum() - if dist.is_available() and dist.is_initialized(): dist.all_reduce(val_loss_sum, op=dist.ReduceOp.SUM) dist.all_reduce(val_token_count, op=dist.ReduceOp.SUM) dist.all_reduce(val_byte_count, op=dist.ReduceOp.SUM) - val_loss = val_loss_sum / val_token_count bits_per_token = val_loss.item() / math.log(2.0) tokens_per_byte = val_token_count.item() / val_byte_count.item() model.train() return float(val_loss.item()), float(bits_per_token * tokens_per_byte) + # ----------------------------- -# POST-TRAINING QUANTIZATION +# POST-TRAINING QUANTIZATION (INT8 legacy + INT6 mixed) # ----------------------------- -# -# It's silly to export our model, which is trained in bf16 and fp32, at that same precision. -# Instead, we get approximately the same model (with a small hit) by quantizing the model to int8 & zlib compressing. 
-# We can then decompress the model and run in higher precision for evaluation, after closing in under the size limit. CONTROL_TENSOR_NAME_PATTERNS = tuple( pattern for pattern in os.environ.get( "CONTROL_TENSOR_NAME_PATTERNS", - "attn_scale,attn_scales,mlp_scale,mlp_scales,resid_mix,resid_mixes,q_gain,skip_weight,skip_weights", + "attn_scale,attn_scales,mlp_scale,mlp_scales,resid_mix,resid_mixes,q_gain,skip_weight,skip_weights,smear,bigram.scale", ).split(",") if pattern ) +FP16_KEEP_NAME_PATTERNS = tuple( + pattern + for pattern in os.environ.get("FP16_KEEP_NAME_PATTERNS", "tok_emb,blocks.8.attn.c_k").split(",") + if pattern +) INT8_KEEP_FLOAT_FP32_NAME_PATTERNS = tuple( pattern for pattern in os.environ.get( @@ -317,19 +302,9 @@ def eval_val( def tensor_nbytes(t: Tensor) -> int: return int(t.numel()) * int(t.element_size()) -def keep_float_tensor(name: str, t: Tensor, passthrough_orig_dtypes: dict[str, str]) -> Tensor: - if any(pattern in name for pattern in INT8_KEEP_FLOAT_FP32_NAME_PATTERNS): - return t.float().contiguous() - if t.dtype in {torch.float32, torch.bfloat16}: - passthrough_orig_dtypes[name] = str(t.dtype).removeprefix("torch.") - return t.to(dtype=INT8_KEEP_FLOAT_STORE_DTYPE).contiguous() - return t - def quantize_float_tensor(t: Tensor) -> tuple[Tensor, Tensor]: t32 = t.float() if t32.ndim == 2: - # Matrices get one scale per row, which usually tracks output-channel - # ranges much better than a single tensor-wide scale. clip_abs = ( torch.quantile(t32.abs(), INT8_CLIP_Q, dim=1) if t32.numel() @@ -339,105 +314,95 @@ def quantize_float_tensor(t: Tensor) -> tuple[Tensor, Tensor]: scale = (clip_abs / 127.0).clamp_min(1.0 / 127.0) q = torch.clamp(torch.round(clipped / scale[:, None]), -127, 127).to(torch.int8).contiguous() return q, scale.to(dtype=INT8_PER_ROW_SCALE_DTYPE).contiguous() - - # Vectors / scalars use a simpler per-tensor scale. 
clip_abs = float(torch.quantile(t32.abs().flatten(), INT8_CLIP_Q).item()) if t32.numel() else 0.0 scale = torch.tensor(clip_abs / 127.0 if clip_abs > 0 else 1.0, dtype=torch.float32) q = torch.clamp(torch.round(torch.clamp(t32, -clip_abs, clip_abs) / scale), -127, 127).to(torch.int8).contiguous() return q, scale -def quantize_state_dict_int8(state_dict: dict[str, Tensor]): - # Single supported clean-script export format: - # - per-row int8 for 2D float tensors - # - per-tensor int8 for other float tensors - # - exact passthrough for non-floats - # - passthrough for small float tensors, stored as fp16 to save bytes - quantized: dict[str, Tensor] = {} - scales: dict[str, Tensor] = {} - dtypes: dict[str, str] = {} - passthrough: dict[str, Tensor] = {} - passthrough_orig_dtypes: dict[str, str] = {} - qmeta: dict[str, dict[str, object]] = {} - stats = dict.fromkeys( - ("param_count", "num_tensors", "num_float_tensors", "num_nonfloat_tensors", "baseline_tensor_bytes", "int8_payload_bytes"), - 0, - ) +def _classify_param(name: str) -> str: + if "tok_emb" in name or "lm_head" in name: + return "embed" + if ".mlp." in name: + return "mlp" + if "bigram" in name: + return "bigram" + if ".attn." in name or (".proj." in name and ".mlp." 
not in name): + return "attn" + return "other" + +def quantize_intN_per_row(t: Tensor, clip_range: int = 31) -> tuple[Tensor, Tensor]: + t32 = t.float() + if t32.ndim == 2: + row_max = t32.abs().amax(dim=1) + scale = (row_max / clip_range).clamp_min(1e-12).to(torch.float16) + scale = scale.clamp_min(torch.finfo(torch.float16).tiny) + q = torch.clamp(torch.round(t32 / scale.float()[:, None]), -(clip_range+1), clip_range).to(torch.int8) + return q, scale + amax = t32.abs().max().item() + scale = torch.tensor(max(amax / clip_range, 1e-12), dtype=torch.float16) + q = torch.clamp(torch.round(t32 / scale.float()), -(clip_range+1), clip_range).to(torch.int8) + return q, scale + +def mixed_quantize_int6(state_dict: dict[str, Tensor], int6_cats: set[str]): + result: dict[str, Tensor] = {} + meta: dict[str, object] = {} for name, tensor in state_dict.items(): - t = tensor.detach().to("cpu").contiguous() - stats["param_count"] += int(t.numel()) - stats["num_tensors"] += 1 - stats["baseline_tensor_bytes"] += tensor_nbytes(t) - - if not t.is_floating_point(): - stats["num_nonfloat_tensors"] += 1 - passthrough[name] = t - stats["int8_payload_bytes"] += tensor_nbytes(t) + t = tensor.detach().cpu().contiguous() + cat = _classify_param(name) + if not t.is_floating_point() or t.numel() <= 8192: + result[name] = t.to(torch.float16) if t.is_floating_point() else t + meta[name] = "passthrough" continue - - # Small float tensors are cheap enough to keep directly. We still downcast - # fp32/bf16 passthrough tensors to fp16 so metadata does not dominate size. 
- if t.numel() <= INT8_KEEP_FLOAT_MAX_NUMEL: - kept = keep_float_tensor(name, t, passthrough_orig_dtypes) - passthrough[name] = kept - stats["int8_payload_bytes"] += tensor_nbytes(kept) + if any(p in name for p in CONTROL_TENSOR_NAME_PATTERNS): + result[name] = t.float() + meta[name] = "passthrough_ctrl" continue - - stats["num_float_tensors"] += 1 - q, s = quantize_float_tensor(t) - if s.ndim > 0: - qmeta[name] = {"scheme": "per_row", "axis": 0} - quantized[name] = q - scales[name] = s - dtypes[name] = str(t.dtype).removeprefix("torch.") - stats["int8_payload_bytes"] += tensor_nbytes(q) + tensor_nbytes(s) - - obj: dict[str, object] = { - "__quant_format__": "int8_clean_per_row_v1", - "quantized": quantized, - "scales": scales, - "dtypes": dtypes, - "passthrough": passthrough, - } - if qmeta: - obj["qmeta"] = qmeta - if passthrough_orig_dtypes: - obj["passthrough_orig_dtypes"] = passthrough_orig_dtypes - return obj, stats - -def dequantize_state_dict_int8(obj: dict[str, object]) -> dict[str, Tensor]: + if any(pattern in name for pattern in FP16_KEEP_NAME_PATTERNS): + result[name] = t.to(dtype=torch.float16).contiguous() + meta[name] = "passthrough_fp16" + continue + if cat in int6_cats and t.ndim >= 1: + clip = 15 if cat == "mlp" else 31 # int5 for MLP, int6 for attention + q, s = quantize_intN_per_row(t, clip_range=clip) + result[name + ".q"] = q + result[name + ".scale"] = s + meta[name] = {"type": f"int{5 if cat == 'mlp' else 6}"} + else: + q, s = quantize_float_tensor(t) + result[name + ".q"] = q + result[name + ".scale"] = s + meta[name] = {"type": "int8"} + return result, meta + +def dequantize_mixed_int6(result: dict[str, Tensor], meta: dict[str, object], + template_sd: dict[str, Tensor]) -> dict[str, Tensor]: out: dict[str, Tensor] = {} - qmeta = obj.get("qmeta", {}) - passthrough_orig_dtypes = obj.get("passthrough_orig_dtypes", {}) - for name, q in obj["quantized"].items(): - dtype = getattr(torch, obj["dtypes"][name]) - s = obj["scales"][name] - if 
qmeta.get(name, {}).get("scheme") == "per_row" or s.ndim > 0: - s = s.to(dtype=torch.float32) - # Broadcast the saved row scale back across trailing dimensions. - out[name] = (q.float() * s.view(q.shape[0], *([1] * (q.ndim - 1)))).to(dtype=dtype).contiguous() + for name, orig in template_sd.items(): + info = meta[name] + orig_dtype = orig.dtype + if info in ("passthrough", "passthrough_ctrl", "passthrough_fp16"): + t = result[name] + if t.dtype == torch.float16 and orig_dtype in (torch.float32, torch.bfloat16): + t = t.to(orig_dtype) + out[name] = t + continue + q, s = result[name + ".q"], result[name + ".scale"] + if s.ndim > 0: + out[name] = (q.float() * s.float().view(q.shape[0], *([1] * (q.ndim - 1)))).to(orig_dtype) else: - scale = float(s.item()) - out[name] = (q.float() * scale).to(dtype=dtype).contiguous() - for name, t in obj["passthrough"].items(): - # Restore small tensors, undoing the temporary fp16 storage cast if needed. - out_t = t.detach().to("cpu").contiguous() - orig_dtype = passthrough_orig_dtypes.get(name) - if isinstance(orig_dtype, str): - out_t = out_t.to(dtype=getattr(torch, orig_dtype)).contiguous() - out[name] = out_t + out[name] = (q.float() * float(s.item())).to(orig_dtype) return out # ----------------------------- -# DATA LOADING +# DATA LOADING # ----------------------------- def load_data_shard(file: Path) -> Tensor: header_bytes = 256 * np.dtype(" Tensor: class TokenStream: - # Reads shards sequentially and wraps around forever. The training loop therefore - # has deterministic, simple streaming behavior with no sampling or workers. def __init__(self, pattern: str): self.files = [Path(p) for p in sorted(glob.glob(pattern))] if not self.files: @@ -482,8 +445,6 @@ def take(self, n: int) -> Tensor: class DistributedTokenLoader: - # Each call consumes a contiguous chunk from the shared token stream, then slices out - # one disjoint span per rank. The extra "+1" token lets us build (x, y) by shifting. 
def __init__(self, pattern: str, rank: int, world_size: int, device: torch.device): self.rank = rank self.world_size = world_size @@ -500,6 +461,7 @@ def next_batch(self, global_tokens: int, seq_len: int, grad_accum_steps: int) -> y = local[1:].reshape(-1, seq_len) return x.to(self.device, non_blocking=True), y.to(self.device, non_blocking=True) + # ----------------------------- # TRANSFORMER MODULES # ----------------------------- @@ -514,14 +476,13 @@ def forward(self, x: Tensor) -> Tensor: class CastedLinear(nn.Linear): - # Keep weights in fp32 for optimizer/state quality, cast at matmul time for bf16 compute. def forward(self, x: Tensor) -> Tensor: + w = self.weight.to(x.dtype) bias = self.bias.to(x.dtype) if self.bias is not None else None - return F.linear(x, self.weight.to(x.dtype), bias) + return F.linear(x, w, bias) def restore_low_dim_params_to_fp32(module: nn.Module) -> None: - # Keep small/control parameters in fp32 even when the model body runs in bf16. with torch.no_grad(): for name, param in module.named_parameters(): if (param.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS)) and param.dtype != torch.float32: @@ -529,7 +490,6 @@ def restore_low_dim_params_to_fp32(module: nn.Module) -> None: class Rotary(nn.Module): - # Caches cos/sin tables per sequence length on the current device. 
def __init__(self, dim: int, base: float = 10000.0): super().__init__() inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim)) @@ -560,14 +520,7 @@ def apply_rotary_emb(x: Tensor, cos: Tensor, sin: Tensor) -> Tensor: class CausalSelfAttention(nn.Module): - def __init__( - self, - dim: int, - num_heads: int, - num_kv_heads: int, - rope_base: float, - qk_gain_init: float, - ): + def __init__(self, dim: int, num_heads: int, num_kv_heads: int, rope_base: float, qk_gain_init: float): super().__init__() if dim % num_heads != 0: raise ValueError("model_dim must be divisible by num_heads") @@ -587,14 +540,11 @@ def __init__( self.q_gain = nn.Parameter(torch.full((num_heads,), qk_gain_init, dtype=torch.float32)) self.rotary = Rotary(self.head_dim, base=rope_base) - def forward(self, x: Tensor, q_delta=None, v_delta=None) -> Tensor: + def forward(self, x: Tensor) -> Tensor: bsz, seqlen, dim = x.shape - q = self.c_q(x) + (q_delta if q_delta is not None else 0) - k = self.c_k(x) - v = self.c_v(x) + (v_delta if v_delta is not None else 0) - q = q.reshape(bsz, seqlen, self.num_heads, self.head_dim).transpose(1, 2) - k = k.reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) - v = v.reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) + q = self.c_q(x).reshape(bsz, seqlen, self.num_heads, self.head_dim).transpose(1, 2) + k = self.c_k(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) + v = self.c_v(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) q = F.rms_norm(q, (q.size(-1),)) k = F.rms_norm(k, (k.size(-1),)) cos, sin = self.rotary(seqlen, x.device, q.dtype) @@ -602,11 +552,7 @@ def forward(self, x: Tensor, q_delta=None, v_delta=None) -> Tensor: k = apply_rotary_emb(k, cos, sin) q = q * self.q_gain.to(dtype=q.dtype)[None, :, None, None] y = F.scaled_dot_product_attention( - q, - k, - v, - attn_mask=None, - is_causal=True, + q, k, v, attn_mask=None, is_causal=True, 
enable_gqa=(self.num_kv_heads != self.num_heads), ) y = y.transpose(1, 2).contiguous().reshape(bsz, seqlen, dim) @@ -614,10 +560,9 @@ def forward(self, x: Tensor, q_delta=None, v_delta=None) -> Tensor: class MLP(nn.Module): - # relu^2 MLP from the original modded-nanogpt setup - def __init__(self, dim: int, mlp_mult: int): + def __init__(self, dim: int, mlp_mult: float): super().__init__() - hidden = mlp_mult * dim + hidden = int(mlp_mult * dim) self.fc = CastedLinear(dim, hidden, bias=False) self.proj = CastedLinear(hidden, dim, bias=False) self.proj._zero_init = True @@ -627,16 +572,47 @@ def forward(self, x: Tensor) -> Tensor: return self.proj(x.square()) +class SmearGate(nn.Module): + """Blend each token's embedding with the previous token's embedding.""" + def __init__(self, dim: int): + super().__init__() + self.gate = nn.Parameter(torch.zeros(dim, dtype=torch.float32)) + + def forward(self, x: Tensor) -> Tensor: + g = torch.sigmoid(self.gate.to(dtype=x.dtype))[None, None, :] + x_prev = torch.cat([torch.zeros_like(x[:, :1]), x[:, :-1]], dim=1) + return (1 - g) * x + g * x_prev + + +class BigramHashEmbedding(nn.Module): + """Hash consecutive token pairs into a learned embedding table.""" + def __init__(self, bigram_vocab_size: int, bigram_dim: int, model_dim: int): + super().__init__() + self.bigram_vocab_size = bigram_vocab_size + self.embed = nn.Embedding(bigram_vocab_size, bigram_dim) + nn.init.zeros_(self.embed.weight) + self.proj = CastedLinear(bigram_dim, model_dim, bias=False) if bigram_dim != model_dim else None + if self.proj is not None: + nn.init.zeros_(self.proj.weight) + self.scale = nn.Parameter(torch.tensor(0.05, dtype=torch.float32)) + + def bigram_hash(self, tokens: Tensor) -> Tensor: + t = tokens.to(torch.int32) + mod = self.bigram_vocab_size - 1 + out = torch.empty_like(t) + out[..., 0] = mod + out[..., 1:] = torch.bitwise_xor(36313 * t[..., 1:], 27191 * t[..., :-1]) % mod + return out.long() + + def forward(self, token_ids: Tensor) -> 
Tensor: + h = self.embed(self.bigram_hash(token_ids)) + if self.proj is not None: + h = self.proj(h) + return h * self.scale.to(dtype=h.dtype) + + class Block(nn.Module): - def __init__( - self, - dim: int, - num_heads: int, - num_kv_heads: int, - mlp_mult: int, - rope_base: float, - qk_gain_init: float, - ): + def __init__(self, dim: int, num_heads: int, num_kv_heads: int, mlp_mult: float, rope_base: float, qk_gain_init: float): super().__init__() self.attn_norm = RMSNorm() self.mlp_norm = RMSNorm() @@ -646,13 +622,10 @@ def __init__( self.mlp_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) self.resid_mix = nn.Parameter(torch.stack((torch.ones(dim), torch.zeros(dim))).float()) - def forward(self, x: Tensor, x0: Tensor, q_delta_fn=None, v_delta_fn=None) -> Tensor: + def forward(self, x: Tensor, x0: Tensor) -> Tensor: mix = self.resid_mix.to(dtype=x.dtype) x = mix[0][None, None, :] * x + mix[1][None, None, :] * x0 - n = self.attn_norm(x) - qd = q_delta_fn(n) if q_delta_fn is not None else None - vd = v_delta_fn(n) if v_delta_fn is not None else None - attn_out = self.attn(n, qd, vd) + attn_out = self.attn(self.attn_norm(x)) x = x + self.attn_scale.to(dtype=x.dtype)[None, None, :] * attn_out x = x + self.mlp_scale.to(dtype=x.dtype)[None, None, :] * self.mlp(self.mlp_norm(x)) return x @@ -666,12 +639,14 @@ def __init__( model_dim: int, num_heads: int, num_kv_heads: int, - mlp_mult: int, + mlp_mult: float, tie_embeddings: bool, tied_embed_init_std: float, logit_softcap: float, rope_base: float, qk_gain_init: float, + bigram_vocab_size: int = 0, + bigram_dim: int = 128, ): super().__init__() if logit_softcap <= 0.0: @@ -680,21 +655,16 @@ def __init__( self.tied_embed_init_std = tied_embed_init_std self.logit_softcap = logit_softcap self.tok_emb = nn.Embedding(vocab_size, model_dim) + self.bigram = BigramHashEmbedding(bigram_vocab_size, bigram_dim, model_dim) if bigram_vocab_size > 0 else None self.num_encoder_layers = num_layers // 2 self.num_decoder_layers = 
num_layers - self.num_encoder_layers self.num_skip_weights = min(self.num_encoder_layers, self.num_decoder_layers) self.skip_weights = nn.Parameter(torch.ones(self.num_skip_weights, model_dim, dtype=torch.float32)) + self.smear = SmearGate(model_dim) self.blocks = nn.ModuleList( [ - Block( - model_dim, - num_heads, - num_kv_heads, - mlp_mult, - rope_base, - qk_gain_init, - ) - for i in range(num_layers) + Block(model_dim, num_heads, num_kv_heads, mlp_mult, rope_base, qk_gain_init) + for _ in range(num_layers) ] ) self.final_norm = RMSNorm() @@ -706,253 +676,145 @@ def __init__( def _init_weights(self) -> None: if self.tie_embeddings: nn.init.normal_(self.tok_emb.weight, mean=0.0, std=self.tied_embed_init_std) - for module in self.modules(): - if isinstance(module, nn.Linear) and getattr(module, "_zero_init", False): - nn.init.zeros_(module.weight) - - def forward(self, input_ids: Tensor, target_ids: Tensor, lora=None) -> Tensor: + num_layers = len(self.blocks) + for name, module in self.named_modules(): + if isinstance(module, nn.Linear): + if getattr(module, "_zero_init", False): + nn.init.zeros_(module.weight) + elif module.weight.ndim == 2 and module.weight.shape[0] >= 64 and module.weight.shape[1] >= 64: + nn.init.orthogonal_(module.weight, gain=1.0) + if ".proj." 
in name or name.endswith(".proj"): + with torch.no_grad(): + module.weight.mul_(1.0 / math.sqrt(2 * num_layers)) + + def forward(self, input_ids: Tensor, target_ids: Tensor) -> Tensor: x = self.tok_emb(input_ids) + if self.bigram is not None: + x = x + self.bigram(input_ids) x = F.rms_norm(x, (x.size(-1),)) + x = self.smear(x) x0 = x skips: list[Tensor] = [] + for i in range(self.num_encoder_layers): + x = self.blocks[i](x, x0) + skips.append(x) + for i in range(self.num_decoder_layers): + if skips: + x = x + self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips.pop() + x = self.blocks[self.num_encoder_layers + i](x, x0) + x = self.final_norm(x).reshape(-1, x.size(-1)) + targets = target_ids.reshape(-1) + if self.tie_embeddings: + logits_proj = F.linear(x, self.tok_emb.weight) + else: + if self.lm_head is None: + raise RuntimeError("lm_head is required when tie_embeddings=False") + logits_proj = self.lm_head(x) + logits = self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap) + return F.cross_entropy(logits.float(), targets, reduction="mean") - # First half stores skips; second half reuses them in reverse order. 
+ def forward_logits(self, input_ids: Tensor) -> Tensor: + x = self.tok_emb(input_ids) + if self.bigram is not None: + x = x + self.bigram(input_ids) + x = F.rms_norm(x, (x.size(-1),)) + x = self.smear(x) + x0 = x + skips: list[Tensor] = [] for i in range(self.num_encoder_layers): - qd = lora.q_loras[i] if lora else None - vd = lora.v_loras[i] if lora else None - x = self.blocks[i](x, x0, qd, vd) + x = self.blocks[i](x, x0) skips.append(x) for i in range(self.num_decoder_layers): - bi = self.num_encoder_layers + i if skips: x = x + self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips.pop() - qd = lora.q_loras[bi] if lora else None - vd = lora.v_loras[bi] if lora else None - x = self.blocks[bi](x, x0, qd, vd) + x = self.blocks[self.num_encoder_layers + i](x, x0) x = self.final_norm(x) if self.tie_embeddings: - logits = F.linear(x, self.tok_emb.weight) + logits_proj = F.linear(x, self.tok_emb.weight) else: - logits = self.lm_head(x) - logits = logits + (lora.lm_head_lora(x) if lora else 0) - logits = self.logit_softcap * torch.tanh(logits / self.logit_softcap) - if lora: - bsz, sl, V = logits.shape - return F.cross_entropy( - logits.float().reshape(-1, V), target_ids.reshape(-1), reduction="none").reshape(bsz, sl) - return F.cross_entropy(logits.float().reshape(-1, logits.size(-1)), target_ids.reshape(-1), reduction="mean") - - -# ----------------------------- -# TEST-TIME TRAINING (LoRA) -# ----------------------------- -# -# At evaluation time, we adapt per-document low-rank adapters on the validation data. -# Each document gets its own adapter, so there is no inter-document dependency. - -BOS_ID = 1 + logits_proj = self.lm_head(x) + return self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap) -class BatchedLinearLoRA(nn.Module): - """LoRA for a linear layer, with independent weights per batch element. - Computes x @ Aᵀ @ Bᵀ = x @ (BA)ᵀ, i.e. 
the LoRA delta is ΔW = BA.""" - def __init__(self, bsz: int, in_features: int, out_features: int, rank: int): - super().__init__() - self.in_features = in_features - self.A = nn.Parameter(torch.empty(bsz, rank, in_features)) # down-projection - self.B = nn.Parameter(torch.zeros(bsz, out_features, rank)) # up-projection - self.reset() - - def forward(self, x: Tensor) -> Tensor: - return (x @ self.A.transpose(1, 2)) @ self.B.transpose(1, 2) # (bsz, T, out) - - def reset(self) -> None: - bound = 1.0 / math.sqrt(self.in_features) - with torch.no_grad(): - self.A.uniform_(-bound, bound) # kaiming-uniform - self.B.zero_() -class BatchedTTTLoRA(nn.Module): - """All LoRA adapters for one batch: LM head and Q/V per block.""" - def __init__(self, bsz: int, model: GPT, rank: int): - super().__init__() - dim = model.tok_emb.embedding_dim - vocab = model.tok_emb.num_embeddings - self.lm_head_lora = BatchedLinearLoRA(bsz, dim, vocab, rank) - self.q_loras = nn.ModuleList() - self.v_loras = nn.ModuleList() - for block in model.blocks: - self.q_loras.append(BatchedLinearLoRA(bsz, dim, block.attn.c_q.weight.shape[0], rank)) - self.v_loras.append(BatchedLinearLoRA(bsz, dim, block.attn.c_v.weight.shape[0], rank)) - - def reset(self) -> None: - for m in self.modules(): - if isinstance(m, BatchedLinearLoRA): - m.reset() - -def _reset_ttt_optimizer(opt): - for group in opt.param_groups: - for p in group['params']: - s = opt.state.get(p) - if not s: # Fresh state. - continue - s['exp_avg'].zero_() - s['exp_avg_sq'].zero_() - s['step'].fill_(0) - -def _build_ttt_optimizer(lora, args: Hyperparameters): - return torch.optim.Adam(lora.parameters(), lr=args.ttt_lora_lr, betas=(args.beta1, args.beta2), eps=1e-10) - -def _find_docs(all_tokens: Tensor, include_next_bos: bool = True) -> list[tuple[int, int]]: - """Return (start_offset, length) for each document, identified by BOS boundaries. 
- - If include_next_bos is True, include next document's BOS (to match continuous-stream - eval token count exactly). - """ - bos_positions = (all_tokens == BOS_ID).nonzero(as_tuple=True)[0].numpy() - docs = [] - for i in range(len(bos_positions)): - start = int(bos_positions[i]) - end = int(bos_positions[i + 1]) if i + 1 < len(bos_positions) else all_tokens.numel() - if include_next_bos and i + 1 < len(bos_positions): - end += 1 - assert end - start >= 2 - docs.append((start, end - start)) - return docs - -def _compute_chunk_window(ci: int, pred_len: int, num_chunks: int, chunk_size: int, eval_seq_len: int): - """Return (win_start, win_len, chunk_offset, chunk_len) for chunk `ci` of a doc.""" - chunk_start = ci * chunk_size - chunk_end = pred_len if ci == num_chunks - 1 else (ci + 1) * chunk_size - win_start = max(0, chunk_end - eval_seq_len) - win_len = chunk_end - win_start - chunk_offset = chunk_start - win_start - chunk_len = chunk_end - chunk_start - return win_start, win_len, chunk_offset, chunk_len - -def _accumulate_bpb( - ptl: Tensor, x: Tensor, y: Tensor, - batch_i: int, chunk_offset: int, chunk_len: int, - base_bytes_lut: Tensor, has_leading_space_lut: Tensor, is_boundary_token_lut: Tensor, - loss_sum: Tensor, byte_sum: Tensor, token_count: Tensor, -): - """Add one doc-chunk's contribution to the running BPB accumulators.""" - lbl = ptl[batch_i, chunk_offset:chunk_offset + chunk_len].to(torch.float64) - prev = x[batch_i, chunk_offset:chunk_offset + chunk_len] - tgt = y[batch_i, chunk_offset:chunk_offset + chunk_len] - tok_bytes = base_bytes_lut[tgt].to(torch.float64) - tok_bytes += has_leading_space_lut[tgt] & ~is_boundary_token_lut[prev] - loss_sum += lbl.sum() - byte_sum += tok_bytes.sum() - token_count += chunk_len - -def eval_val_ttt_lora( +def eval_val_sliding( args: Hyperparameters, - base_model: GPT, + base_model: nn.Module, rank: int, world_size: int, device: torch.device, + val_tokens: Tensor, base_bytes_lut: Tensor, has_leading_space_lut: 
Tensor, is_boundary_token_lut: Tensor, + stride: int, + batch_seqs: int = 32, ) -> tuple[float, float]: - """Evaluate with batched LoRA test-time training. Returns (val_loss, val_bpb).""" - # Load validation tokens and find document boundaries - files = sorted(glob.glob(args.val_files)) - all_tokens = torch.cat([load_data_shard(Path(f)) for f in files]) - docs = _find_docs(all_tokens) - - # Each rank takes a contiguous slice of documents - rank_docs = docs[(len(docs) * rank) // world_size : (len(docs) * (rank + 1)) // world_size] - chunk_size = args.ttt_chunk_size - eval_seq_len = args.ttt_eval_seq_len - batch_size = args.ttt_batch_size - lora_rank = args.ttt_lora_rank - - rank_docs.sort(key=lambda d: (d[1] - 2) // chunk_size) - - base_model.eval() - for p in base_model.parameters(): - p.requires_grad_(False) - - lora = BatchedTTTLoRA(batch_size, base_model, lora_rank).to(device) - opt = _build_ttt_optimizer(lora, args) + seq_len = args.train_seq_len + total_tokens = val_tokens.numel() - 1 + window_starts = [ws for ws in range(0, total_tokens, stride) + if min(ws + seq_len, total_tokens) - ws >= stride or ws == 0] + total_windows = len(window_starts) + my_s = (total_windows * rank) // world_size + my_e = (total_windows * (rank + 1)) // world_size + my_windows = window_starts[my_s:my_e] loss_sum = torch.zeros((), device=device, dtype=torch.float64) - byte_sum = torch.zeros((), device=device, dtype=torch.float64) token_count = torch.zeros((), device=device, dtype=torch.float64) + byte_count = torch.zeros((), device=device, dtype=torch.float64) - for bi in range(0, len(rank_docs), batch_size): - batch = rank_docs[bi:bi + batch_size] - bsz = len(batch) - - if bsz == batch_size: - cur_lora, cur_opt = lora, opt - cur_lora.reset() - _reset_ttt_optimizer(cur_opt) - else: - cur_lora = BatchedTTTLoRA(bsz, base_model, lora_rank).to(device) - cur_opt = _build_ttt_optimizer(cur_lora, args) - - pred_lens = [doc_len - 1 for _, doc_len in batch] - num_chunks = [(pl + chunk_size - 
1) // chunk_size for pl in pred_lens] - max_nc = max(num_chunks) - - for ci in range(max_nc): - chunk_stats = _compute_chunk_window(ci, (ci + 1) * chunk_size, ci + 1, chunk_size, eval_seq_len) - context_size, chunk_offset = chunk_stats[1], chunk_stats[2] - - active = [ci < nc for nc in num_chunks] - needs_train = any(ci < nc - 1 for nc in num_chunks) - - x = torch.zeros(bsz, context_size, dtype=torch.int64, device=device) - y = torch.zeros(bsz, context_size, dtype=torch.int64, device=device) - doc_info = [] # (chunk_offset, chunk_len) per doc - for b in range(bsz): - if not active[b]: - doc_info.append((0, 0)) - continue - ds, dl = batch[b] - ws, wl, co, cl = _compute_chunk_window(ci, pred_lens[b], num_chunks[b], chunk_size, eval_seq_len) - chunk = all_tokens[ds + ws: ds + ws + wl + 1] - toks = chunk.to(dtype=torch.int64, device=device) - x[b, :wl] = toks[:-1] - y[b, :wl] = toks[1:] - doc_info.append((co, cl)) - - # Forward pass (keep grad graph alive only when we need to train) - if needs_train: - with torch.autocast(device_type="cuda", dtype=torch.bfloat16): - ptl = base_model(x, y, lora=cur_lora) - else: - with torch.no_grad(), torch.autocast(device_type="cuda", dtype=torch.bfloat16): - ptl = base_model(x, y, lora=cur_lora) - - # Score: accumulate loss and byte counts for BPB (before training on chunk) - with torch.no_grad(): - for b in range(bsz): - if not active[b]: - continue - co, cl = doc_info[b] - _accumulate_bpb( - ptl, x, y, b, co, cl, base_bytes_lut, has_leading_space_lut, - is_boundary_token_lut, loss_sum, byte_sum, token_count) - - # Train: one Adam step on the LoRA params using this chunk's loss - if needs_train: - mask = torch.tensor([float(ci < num_chunks[b] - 1) for b in range(bsz)], device=device) - per_doc = ptl[:, chunk_offset:chunk_offset + chunk_size].mean(dim=-1) - cur_opt.zero_grad() - (per_doc * mask).sum().backward() - cur_opt.step() + base_model.eval() + with torch.inference_mode(): + for bi in range(0, len(my_windows), batch_seqs): + 
batch_ws = my_windows[bi:bi + batch_seqs] + bsz = len(batch_ws) + x_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) + y_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) + wlens: list[int] = [] + for i, ws in enumerate(batch_ws): + end = min(ws + seq_len, total_tokens) + wlen = end - ws + wlens.append(wlen) + chunk = val_tokens[ws:end + 1].to(dtype=torch.int64, device=device) + x_batch[i, :wlen] = chunk[:-1] + y_batch[i, :wlen] = chunk[1:] + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + logits = base_model.forward_logits(x_batch) + nll = F.cross_entropy( + logits.reshape(-1, logits.size(-1)).float(), + y_batch.reshape(-1), + reduction="none", + ).reshape(bsz, seq_len) + for i, ws in enumerate(batch_ws): + wlen = wlens[i] + s = 0 if ws == 0 else max(wlen - stride, 0) + scored_nll = nll[i, s:wlen].to(torch.float64) + loss_sum += scored_nll.sum() + token_count += float(wlen - s) + tgt = y_batch[i, s:wlen] + prev = x_batch[i, s:wlen] + tb = base_bytes_lut[tgt].to(torch.float64) + tb += (has_leading_space_lut[tgt] & ~is_boundary_token_lut[prev]).to(torch.float64) + byte_count += tb.sum() + if rank == 0 and (bi // batch_seqs) % 50 == 0: + done = min(bi + batch_seqs, len(my_windows)) + pct = done / len(my_windows) * 100 + running_bpb = 0.0 + if token_count.item() > 0: + rl = (loss_sum / token_count).item() + running_bpb = rl / math.log(2.0) * (token_count.item() / byte_count.item()) + print(f" sliding_eval [{pct:5.1f}%] {done}/{len(my_windows)} windows running_bpb={running_bpb:.6f}", flush=True) if dist.is_available() and dist.is_initialized(): dist.all_reduce(loss_sum, op=dist.ReduceOp.SUM) - dist.all_reduce(byte_sum, op=dist.ReduceOp.SUM) dist.all_reduce(token_count, op=dist.ReduceOp.SUM) + dist.all_reduce(byte_count, op=dist.ReduceOp.SUM) + + val_loss = (loss_sum / token_count).item() + bits_per_token = val_loss / math.log(2.0) + tokens_per_byte = token_count.item() / byte_count.item() + base_model.train() + 
return val_loss, bits_per_token * tokens_per_byte - val_loss = float(loss_sum.item() / token_count.item()) - val_bpb = float((loss_sum.item() / math.log(2.0)) / byte_sum.item()) - return val_loss, val_bpb # ----------------------------- # TRAINING @@ -965,10 +827,6 @@ def main() -> None: args = Hyperparameters() zeropower_via_newtonschulz5 = torch.compile(zeropower_via_newtonschulz5) - # ----------------------------- - # DISTRIBUTED + CUDA SETUP - # ----------------------------- - distributed = "RANK" in os.environ and "WORLD_SIZE" in os.environ rank = int(os.environ.get("RANK", "0")) world_size = int(os.environ.get("WORLD_SIZE", "1")) @@ -988,11 +846,9 @@ def main() -> None: dist.barrier() master_process = rank == 0 - # Fast math knobs torch.backends.cuda.matmul.allow_tf32 = True torch.backends.cudnn.allow_tf32 = True from torch.backends.cuda import enable_cudnn_sdp, enable_flash_sdp, enable_math_sdp, enable_mem_efficient_sdp - enable_cudnn_sdp(False) enable_flash_sdp(True) enable_mem_efficient_sdp(False) @@ -1023,10 +879,6 @@ def log0(msg: str, console: bool = True) -> None: ) log0("=" * 100, console=False) - # ----------------------------- - # TOKENIZER + VALIDATION METRIC SETUP - # ----------------------------- - random.seed(args.seed) np.random.seed(args.seed) torch.manual_seed(args.seed) @@ -1049,10 +901,7 @@ def log0(msg: str, console: bool = True) -> None: log0(f"train_loader:dataset:{dataset_dir.name} train_shards:{actual_train_files}") log0(f"val_loader:shards pattern={args.val_files} tokens:{val_tokens.numel() - 1}") - # ----------------------------- # MODEL + OPTIMIZER SETUP - # ----------------------------- - base_model = GPT( vocab_size=args.vocab_size, num_layers=args.num_layers, @@ -1065,39 +914,43 @@ def log0(msg: str, console: bool = True) -> None: logit_softcap=args.logit_softcap, rope_base=args.rope_base, qk_gain_init=args.qk_gain_init, + bigram_vocab_size=args.bigram_vocab_size, + bigram_dim=args.bigram_dim, ).to(device).bfloat16() for module 
in base_model.modules(): if isinstance(module, CastedLinear): module.float() - if isinstance(module, Rotary): - module.inv_freq.data = module.inv_freq.data.float() restore_low_dim_params_to_fp32(base_model) compiled_model = torch.compile(base_model, dynamic=False, fullgraph=True) model: nn.Module = DDP(compiled_model, device_ids=[local_rank], broadcast_buffers=False) if distributed else compiled_model - # Optimizer split: - # - token embedding (Adam) uses EMBED_LR - # - untied lm_head (Adam) uses HEAD_LR - # - matrix params in transformer blocks use MATRIX_LR via Muon - # - vectors/scalars use SCALAR_LR via Adam block_named_params = list(base_model.blocks.named_parameters()) matrix_params = [ - p - for name, p in block_named_params + p for name, p in block_named_params if p.ndim == 2 and not any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) ] scalar_params = [ - p - for name, p in block_named_params + p for name, p in block_named_params if p.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) ] if base_model.skip_weights.numel() > 0: scalar_params.append(base_model.skip_weights) + scalar_params.append(base_model.smear.gate) + if base_model.bigram is not None: + scalar_params.append(base_model.bigram.scale) + token_lr = args.tied_embed_lr if args.tie_embeddings else args.embed_lr - optimizer_tok = torch.optim.Adam( - [{"params": [base_model.tok_emb.weight], "lr": token_lr, "base_lr": token_lr}], + tok_params = [{"params": [base_model.tok_emb.weight], "lr": token_lr, "base_lr": token_lr}] + if base_model.bigram is not None: + tok_params.append({"params": [base_model.bigram.embed.weight], "lr": token_lr, "base_lr": token_lr}) + if base_model.bigram.proj is not None: + matrix_params.append(base_model.bigram.proj.weight) + + optimizer_tok = torch.optim.AdamW( + tok_params, betas=(args.beta1, args.beta2), eps=args.adam_eps, + weight_decay=args.weight_decay, fused=True, ) optimizer_muon = Muon( @@ -1105,13 +958,15 @@ def log0(msg: 
str, console: bool = True) -> None: lr=args.matrix_lr, momentum=args.muon_momentum, backend_steps=args.muon_backend_steps, + weight_decay=0.04, ) for group in optimizer_muon.param_groups: group["base_lr"] = args.matrix_lr - optimizer_scalar = torch.optim.Adam( + optimizer_scalar = torch.optim.AdamW( [{"params": scalar_params, "lr": args.scalar_lr, "base_lr": args.scalar_lr}], betas=(args.beta1, args.beta2), eps=args.adam_eps, + weight_decay=args.weight_decay, fused=True, ) optimizers: list[torch.optim.Optimizer] = [optimizer_tok, optimizer_muon, optimizer_scalar] @@ -1127,11 +982,9 @@ def log0(msg: str, console: bool = True) -> None: n_params = sum(p.numel() for p in base_model.parameters()) log0(f"model_params:{n_params}") log0(f"world_size:{world_size} grad_accum_steps:{grad_accum_steps}") - log0("sdp_backends:cudnn=False flash=True mem_efficient=False math=False") log0(f"attention_mode:gqa num_heads:{args.num_heads} num_kv_heads:{args.num_kv_heads}") log0( f"tie_embeddings:{args.tie_embeddings} embed_lr:{token_lr} " - f"head_lr:{args.head_lr if base_model.lm_head is not None else 0.0} " f"matrix_lr:{args.matrix_lr} scalar_lr:{args.scalar_lr}" ) log0( @@ -1141,10 +994,7 @@ def log0(msg: str, console: bool = True) -> None: ) log0(f"seed:{args.seed}") - # ----------------------------- # DATA LOADER & MODEL WARMUP - # ----------------------------- - train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) def zero_grad_all() -> None: @@ -1164,8 +1014,6 @@ def lr_mul(step: int, elapsed_ms: float) -> float: remaining_ms = max(max_wallclock_ms - elapsed_ms, 0.0) return remaining_ms / max(warmdown_ms, 1e-9) if remaining_ms <= warmdown_ms else 1.0 - # Warmup primes the compiled forward/backward/optimizer paths, then we restore the - # initial weights/optimizer state so measured training starts from the true init. 
if args.warmup_steps > 0: initial_model_state = {name: tensor.detach().cpu().clone() for name, tensor in base_model.state_dict().items()} initial_optimizer_states = [copy.deepcopy(opt.state_dict()) for opt in optimizers] @@ -1192,12 +1040,11 @@ def lr_mul(step: int, elapsed_ms: float) -> float: model.require_backward_grad_sync = True train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) - # ----------------------------- # MAIN TRAINING LOOP - # ----------------------------- - training_time_ms = 0.0 stop_after_step: int | None = None + swa_state: dict[str, Tensor] | None = None + swa_count = 0 torch.cuda.synchronize() t0 = time.perf_counter() @@ -1210,16 +1057,8 @@ def lr_mul(step: int, elapsed_ms: float) -> float: torch.cuda.synchronize() training_time_ms += 1000.0 * (time.perf_counter() - t0) val_loss, val_bpb = eval_val( - args, - model, - rank, - world_size, - device, - grad_accum_steps, - val_tokens, - base_bytes_lut, - has_leading_space_lut, - is_boundary_token_lut, + args, model, rank, world_size, device, grad_accum_steps, + val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, ) log0( f"step:{step}/{args.iterations} val_loss:{val_loss:.4f} val_bpb:{val_bpb:.4f} " @@ -1267,6 +1106,18 @@ def lr_mul(step: int, elapsed_ms: float) -> float: step += 1 approx_training_time_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + + # SWA: collect checkpoints during warmdown + if args.swa_enabled and scale < args.swa_start_frac and step % args.swa_every == 0: + if swa_state is None: + swa_state = {name: t.detach().cpu().clone() for name, t in base_model.state_dict().items()} + swa_count = 1 + log0(f"swa:start step:{step}") + else: + for name, t in base_model.state_dict().items(): + swa_state[name] += t.detach().cpu() + swa_count += 1 + should_log_train = ( args.train_log_every > 0 and (step <= 10 or step % args.train_log_every == 0 or stop_after_step is not None) @@ -1277,7 +1128,6 @@ def lr_mul(step: int, 
elapsed_ms: float) -> float: f"train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms / step:.2f}ms" ) - # Needed to sync whether we've reached the wallclock cap. reached_cap = max_wallclock_ms is not None and approx_training_time_ms >= max_wallclock_ms if distributed and max_wallclock_ms is not None: reached_cap_tensor = torch.tensor(int(reached_cap), device=device) @@ -1291,12 +1141,17 @@ def lr_mul(step: int, elapsed_ms: float) -> float: f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB" ) - # ----------------------------- - # SERIALIZATION + ROUNDTRIP VALIDATION - # ----------------------------- - # Save the raw state (useful for debugging/loading in PyTorch directly), then always produce - # the compressed int8+zlib artifact and validate the round-tripped weights. + # Apply SWA if collected + if args.swa_enabled and swa_state is not None and swa_count > 1: + log0(f"swa:applying averaged {swa_count} checkpoints") + current_state = base_model.state_dict() + avg_state = { + name: (tensor / swa_count).to(dtype=current_state[name].dtype) + for name, tensor in swa_state.items() + } + base_model.load_state_dict(avg_state, strict=True) + # SERIALIZATION + ROUNDTRIP VALIDATION if master_process: torch.save(base_model.state_dict(), "final_model.pt") model_bytes = os.path.getsize("final_model.pt") @@ -1305,44 +1160,60 @@ def lr_mul(step: int, elapsed_ms: float) -> float: log0(f"Code size: {code_bytes} bytes") log0(f"Total submission size: {model_bytes + code_bytes} bytes") - quant_obj, quant_stats = quantize_state_dict_int8(base_model.state_dict()) + # Magnitude pruning: zero out smallest weights to improve compression + with torch.no_grad(): + for name, param in base_model.named_parameters(): + if param.ndim == 2 and param.numel() > 65536: + threshold = torch.quantile(param.abs().float().flatten(), 0.03) + mask = param.abs() < threshold + param.masked_fill_(mask, 0.0) + + # INT6 mixed quantization + zstd/zlib export + sd_cpu = {k: 
v.detach().cpu() for k, v in base_model.state_dict().items()} + quant_result, quant_meta = mixed_quantize_int6(sd_cpu, {"mlp", "attn", "bigram"}) quant_buf = io.BytesIO() - torch.save(quant_obj, quant_buf) + torch.save({"w": quant_result, "m": quant_meta}, quant_buf) quant_raw = quant_buf.getvalue() - quant_blob = zlib.compress(quant_raw, level=9) - quant_raw_bytes = len(quant_raw) + if _COMPRESSOR == "zstd": + quant_blob = zstandard.ZstdCompressor(level=22).compress(quant_raw) + else: + quant_blob = zlib.compress(quant_raw, 9) if master_process: with open("final_model.int8.ptz", "wb") as f: f.write(quant_blob) quant_file_bytes = os.path.getsize("final_model.int8.ptz") code_bytes = len(code.encode("utf-8")) - ratio = quant_stats["baseline_tensor_bytes"] / max(quant_stats["int8_payload_bytes"], 1) - log0( - f"Serialized model int8+zlib: {quant_file_bytes} bytes " - f"(payload:{quant_stats['int8_payload_bytes']} raw_torch:{quant_raw_bytes} payload_ratio:{ratio:.2f}x)" - ) + log0(f"Serialized model int6+{_COMPRESSOR}: {quant_file_bytes} bytes") log0(f"Total submission size int8+zlib: {quant_file_bytes + code_bytes} bytes") if distributed: dist.barrier() with open("final_model.int8.ptz", "rb") as f: quant_blob_disk = f.read() - quant_state = torch.load(io.BytesIO(zlib.decompress(quant_blob_disk)), map_location="cpu") - base_model.load_state_dict(dequantize_state_dict_int8(quant_state), strict=True) + if _COMPRESSOR == "zstd": + decompressed = zstandard.ZstdDecompressor().decompress(quant_blob_disk) + else: + decompressed = zlib.decompress(quant_blob_disk) + quant_state = torch.load(io.BytesIO(decompressed), map_location="cpu") + deq_state = dequantize_mixed_int6(quant_state["w"], quant_state["m"], sd_cpu) + base_model.load_state_dict(deq_state, strict=True) + + # Sliding window eval on int6-roundtripped weights torch.cuda.synchronize() t_qeval = time.perf_counter() - q_val_loss, q_val_bpb = eval_val( - args, - model, - rank, - world_size, - device, - grad_accum_steps, 
- val_tokens, - base_bytes_lut, - has_leading_space_lut, - is_boundary_token_lut, - ) + if args.eval_stride > 0 and args.eval_stride < args.train_seq_len: + log0(f"final_eval_mode:sliding_window stride:{args.eval_stride} batch_seqs:{args.eval_batch_seqs}") + q_val_loss, q_val_bpb = eval_val_sliding( + args, base_model, rank, world_size, device, + val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, + stride=args.eval_stride, batch_seqs=args.eval_batch_seqs, + ) + else: + log0("final_eval_mode:standard") + q_val_loss, q_val_bpb = eval_val( + args, model, rank, world_size, device, grad_accum_steps, + val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, + ) torch.cuda.synchronize() log0( f"final_int8_zlib_roundtrip val_loss:{q_val_loss:.4f} val_bpb:{q_val_bpb:.4f} " @@ -1350,23 +1221,11 @@ def lr_mul(step: int, elapsed_ms: float) -> float: ) log0(f"final_int8_zlib_roundtrip_exact val_loss:{q_val_loss:.8f} val_bpb:{q_val_bpb:.8f}") - # LoRA test-time training evaluation (the competition score) - torch._dynamo.reset() - torch.cuda.synchronize() - t_ttt = time.perf_counter() - ttt_val_loss, ttt_val_bpb = eval_val_ttt_lora( - args, base_model, rank, world_size, device, - base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, - ) - torch.cuda.synchronize() - log0( - f"final_int8_ttt_lora val_loss:{ttt_val_loss:.4f} val_bpb:{ttt_val_bpb:.4f} " - f"eval_time:{1000.0 * (time.perf_counter() - t_ttt):.0f}ms" - ) - if distributed: dist.destroy_process_group() if __name__ == "__main__": main() +# fixes applied +# tuned diff --git a/train_gpt_mos_sota.py b/train_gpt_mos_sota.py new file mode 100644 index 0000000000..c3df5b8ab2 --- /dev/null +++ b/train_gpt_mos_sota.py @@ -0,0 +1,1755 @@ +""" +MoS + SOTA technique stack for Parameter Golf. +Mixture of Softmax (K=2) output layer with 11L Int6 + XSA4 + Partial RoPE + LN Scale + +Tight SWA + Shared VE128 + U-Net skips + Late QAT + SmearGate + BigramHash + FA3. 
+""" + +from __future__ import annotations + +import copy +import glob +import io +import math +import os +import random +import subprocess +import sys +import time +import uuid +import zlib +from pathlib import Path + +try: + import zstandard + _COMPRESSOR = "zstd" +except ImportError: + _COMPRESSOR = "zlib" + +import numpy as np +import sentencepiece as spm +import torch +import torch._dynamo +torch._dynamo.config.optimize_ddp = False # Required for FA3 + torch.compile backward pass +# Disable torch.compile if environment variable is set (fixes inductor issues on some systems) +_DISABLE_COMPILE = bool(int(os.environ.get("DISABLE_COMPILE", "0"))) +if _DISABLE_COMPILE: + torch._dynamo.config.suppress_errors = True +import torch.distributed as dist +import torch.nn.functional as F +from torch import Tensor, nn +from torch.nn.parallel import DistributedDataParallel as DDP + +try: + from flash_attn_interface import flash_attn_func as flash_attn_3_func + _FA3_AVAILABLE = True +except ImportError: + _FA3_AVAILABLE = False + +# ----------------------------- +# HYPERPARAMETERS +# ----------------------------- +# Default Simple Baseline run: +# - 9 transformer blocks at width 512 +# - 8 attention heads with 4 KV heads (GQA) and 2x MLP expansion +# - vocab size 1024, sequence length 1024, tied embeddings +# - 524,288 train tokens per step for 20,000 iterations with a ~10 minute cap + +class Hyperparameters: + # Data paths are shard globs produced by the existing preprocessing pipeline. + data_path = os.environ.get("DATA_PATH", "./data/datasets/fineweb10B_sp1024") + train_files = os.path.join(data_path, "fineweb_train_*.bin") + val_files = os.path.join(data_path, "fineweb_val_*.bin") + tokenizer_path = os.environ.get("TOKENIZER_PATH", "./data/tokenizers/fineweb_1024_bpe.model") + run_id = os.environ.get("RUN_ID", str(uuid.uuid4())) + seed = int(os.environ.get("SEED", 1337)) + + # Validation cadence and batch size. Validation always uses the full fineweb_val split. 
+ val_batch_size = int(os.environ.get("VAL_BATCH_SIZE", 524_288)) + val_loss_every = int(os.environ.get("VAL_LOSS_EVERY", 4000)) + train_log_every = int(os.environ.get("TRAIN_LOG_EVERY", 500)) + + # Training length. + iterations = int(os.environ.get("ITERATIONS", 20000)) + warmdown_iters = int(os.environ.get("WARMDOWN_ITERS", 3000)) + warmup_steps = int(os.environ.get("WARMUP_STEPS", 20)) + train_batch_tokens = int(os.environ.get("TRAIN_BATCH_TOKENS", 786_432)) + train_seq_len = int(os.environ.get("TRAIN_SEQ_LEN", 2048)) + eval_seq_len = int(os.environ.get("EVAL_SEQ_LEN", 2048)) + max_wallclock_seconds = float(os.environ.get("MAX_WALLCLOCK_SECONDS", 600.0)) + qk_gain_init = float(os.environ.get("QK_GAIN_INIT", 1.5)) + + # Model shape. + vocab_size = int(os.environ.get("VOCAB_SIZE", 1024)) + num_layers = int(os.environ.get("NUM_LAYERS", 11)) + num_kv_heads = int(os.environ.get("NUM_KV_HEADS", 4)) + model_dim = int(os.environ.get("MODEL_DIM", 512)) + num_heads = int(os.environ.get("NUM_HEADS", 8)) + mlp_mult = float(os.environ.get("MLP_MULT", 3.0)) + tie_embeddings = bool(int(os.environ.get("TIE_EMBEDDINGS", "1"))) + rope_base = float(os.environ.get("ROPE_BASE", 10000.0)) + logit_softcap = float(os.environ.get("LOGIT_SOFTCAP", 30.0)) + + # Optimizer hyperparameters. 
+ embed_lr = float(os.environ.get("EMBED_LR", 0.6)) + head_lr = float(os.environ.get("HEAD_LR", 0.008)) + tied_embed_lr = float(os.environ.get("TIED_EMBED_LR", 0.035)) + tied_embed_init_std = float(os.environ.get("TIED_EMBED_INIT_STD", 0.005)) + matrix_lr = float(os.environ.get("MATRIX_LR", 0.025)) + scalar_lr = float(os.environ.get("SCALAR_LR", 0.025)) + muon_momentum = float(os.environ.get("MUON_MOMENTUM", 0.99)) + muon_backend_steps = int(os.environ.get("MUON_BACKEND_STEPS", 5)) + muon_momentum_warmup_start = float(os.environ.get("MUON_MOMENTUM_WARMUP_START", 0.92)) + muon_momentum_warmup_steps = int(os.environ.get("MUON_MOMENTUM_WARMUP_STEPS", 1500)) + beta1 = float(os.environ.get("BETA1", 0.9)) + beta2 = float(os.environ.get("BETA2", 0.95)) + adam_eps = float(os.environ.get("ADAM_EPS", 1e-8)) + grad_clip_norm = float(os.environ.get("GRAD_CLIP_NORM", 0.3)) + eval_stride = int(os.environ.get("EVAL_STRIDE", 64)) + mtp_num_heads = int(os.environ.get("MTP_NUM_HEADS", 0)) + mtp_loss_weight = float(os.environ.get("MTP_LOSS_WEIGHT", 0.2)) + muon_beta2 = float(os.environ.get("MUON_BETA2", 0.95)) + swa_enabled = bool(int(os.environ.get("SWA_ENABLED", "1"))) + swa_every = int(os.environ.get("SWA_EVERY", 50)) # tighter: collect more recent checkpoints + muon_wd = float(os.environ.get("MUON_WD", 0.04)) + adam_wd = float(os.environ.get("ADAM_WD", 0.04)) + qat_enabled = bool(int(os.environ.get("QAT_ENABLED", "0"))) + bigram_vocab_size = int(os.environ.get("BIGRAM_VOCAB_SIZE", 2048)) + bigram_dim = int(os.environ.get("BIGRAM_DIM", 128)) + + # Efficient partial XSA: apply to last N layers only (deep layers have highest self-attention bias) + xsa_last_n = int(os.environ.get("XSA_LAST_N", 4)) # XSA on last 4 layers (0 = disabled) + rope_dims = int(os.environ.get("ROPE_DIMS", 16)) + ln_scale = bool(int(os.environ.get("LN_SCALE", "1"))) + dtg_enabled = bool(int(os.environ.get("DTG_ENABLED", "0"))) + late_qat_threshold = float(os.environ.get("LATE_QAT_THRESHOLD", 0.1)) + + # Value 
Embeddings: 1 shared table, per-layer scales (saves 50% VE params) + ve_enabled = bool(int(os.environ.get("VE_ENABLED", "1"))) + ve_dim = int(os.environ.get("VE_DIM", 128)) + ve_layers = os.environ.get("VE_LAYERS", "9,10") + + # Mixture of Softmax output layer + use_mos = bool(int(os.environ.get("USE_MOS", "0"))) + mos_k = int(os.environ.get("MOS_K", 2)) + +# ----------------------------- +# MUON OPTIMIZER +# ----------------------------- +# +# As borrowed from modded-nanogpt +# Background on Muon: https://kellerjordan.github.io/posts/muon/ + +def zeropower_via_newtonschulz5(G: Tensor, steps: int = 10, eps: float = 1e-7) -> Tensor: + a, b, c = (3.4445, -4.7750, 2.0315) + X = G.bfloat16() + X /= X.norm() + eps + transposed = G.size(0) > G.size(1) + if transposed: + X = X.T + for _ in range(steps): + A = X @ X.T + B = b * A + c * A @ A + X = a * X + B @ X + return X.T if transposed else X + + +class Muon(torch.optim.Optimizer): + def __init__(self, params, lr: float, momentum: float, backend_steps: int, + nesterov: bool = True, weight_decay: float = 0.0): + super().__init__( + params, + dict(lr=lr, momentum=momentum, backend_steps=backend_steps, + nesterov=nesterov, weight_decay=weight_decay), + ) + + @torch.no_grad() + def step(self, closure=None): + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + distributed = dist.is_available() and dist.is_initialized() + world_size = dist.get_world_size() if distributed else 1 + rank = dist.get_rank() if distributed else 0 + + for group in self.param_groups: + params = group["params"] + if not params: + continue + lr = group["lr"] + momentum = group["momentum"] + backend_steps = group["backend_steps"] + nesterov = group["nesterov"] + + total_params = sum(int(p.numel()) for p in params) + updates_flat = torch.zeros(total_params, device=params[0].device, dtype=torch.bfloat16) + + curr = 0 + for i, p in enumerate(params): + if i % world_size == rank and p.grad is not None: + g = p.grad + 
state = self.state[p] + if "momentum_buffer" not in state: + state["momentum_buffer"] = torch.zeros_like(g) + buf = state["momentum_buffer"] + buf.mul_(momentum).add_(g) + if nesterov: + g = g.add(buf, alpha=momentum) + g = zeropower_via_newtonschulz5(g, steps=backend_steps) + g *= max(1, g.size(0) / g.size(1)) ** 0.5 + updates_flat[curr : curr + p.numel()] = g.reshape(-1) + curr += p.numel() + + if distributed: + dist.all_reduce(updates_flat, op=dist.ReduceOp.SUM) + + wd = group.get("weight_decay", 0.0) + curr = 0 + for p in params: + if wd > 0.0: + p.data.mul_(1.0 - lr * wd) + g = updates_flat[curr : curr + p.numel()].view_as(p).to(dtype=p.dtype) + p.add_(g, alpha=-lr) + curr += p.numel() + + return loss + + +# ----------------------------- +# TOKENIZER-AGNOSTIC EVALUATION SETUP +# ----------------------------- +# +# It's common for small models have a large fraction of their parameters be embeddings, since the 2 * d_model * d_vocab vectors can be gigantic. +# Instead of locking the tokenizer, we let you bring your own and calculate our validation metrics on the average compression of the validation set. +# We calculate BPB (bits-per-byte) instead of validation loss, so we need methods to count the number of bits per token in the tokenizer. +# Note: Submissions that edit the tokenizer will be examined more carefully, since screwing this up might unjustly improve your score. 
+ +def build_sentencepiece_luts( + sp: spm.SentencePieceProcessor, vocab_size: int, device: torch.device +) -> tuple[Tensor, Tensor, Tensor]: + sp_vocab_size = int(sp.vocab_size()) + table_size = max(sp_vocab_size, vocab_size) + base_bytes_np = np.zeros((table_size,), dtype=np.int16) + has_leading_space_np = np.zeros((table_size,), dtype=np.bool_) + is_boundary_token_np = np.ones((table_size,), dtype=np.bool_) + for token_id in range(sp_vocab_size): + if sp.is_control(token_id) or sp.is_unknown(token_id) or sp.is_unused(token_id): + continue + is_boundary_token_np[token_id] = False + if sp.is_byte(token_id): + base_bytes_np[token_id] = 1 + continue + piece = sp.id_to_piece(token_id) + if piece.startswith("▁"): + has_leading_space_np[token_id] = True + piece = piece[1:] + base_bytes_np[token_id] = len(piece.encode("utf-8")) + return ( + torch.tensor(base_bytes_np, dtype=torch.int16, device=device), + torch.tensor(has_leading_space_np, dtype=torch.bool, device=device), + torch.tensor(is_boundary_token_np, dtype=torch.bool, device=device), + ) + + +def load_validation_tokens(pattern: str, seq_len: int) -> Tensor: + files = [Path(p) for p in sorted(glob.glob(pattern))] + if not files: + raise FileNotFoundError(f"No files found for pattern: {pattern}") + # The export pipeline writes the fixed first-50k-doc validation set to fineweb_val_*. 
+ tokens = torch.cat([load_data_shard(file) for file in files]).contiguous() + usable = ((tokens.numel() - 1) // seq_len) * seq_len + if usable <= 0: + raise ValueError(f"Validation split is too short for TRAIN_SEQ_LEN={seq_len}") + return tokens[: usable + 1] + + +def eval_val( + args: Hyperparameters, + model: nn.Module, + rank: int, + world_size: int, + device: torch.device, + grad_accum_steps: int, + val_tokens: Tensor, + base_bytes_lut: Tensor, + has_leading_space_lut: Tensor, + is_boundary_token_lut: Tensor, + eval_seq_len: int | None = None, +) -> tuple[float, float]: + seq_len = eval_seq_len or args.train_seq_len + local_batch_tokens = args.val_batch_size // (world_size * grad_accum_steps) + if local_batch_tokens < seq_len: + raise ValueError( + "VAL_BATCH_SIZE must provide at least one sequence per rank; " + f"got VAL_BATCH_SIZE={args.val_batch_size}, WORLD_SIZE={world_size}, " + f"GRAD_ACCUM_STEPS={grad_accum_steps}, seq_len={seq_len}" + ) + local_batch_seqs = local_batch_tokens // seq_len + total_seqs = (val_tokens.numel() - 1) // seq_len + seq_start = (total_seqs * rank) // world_size + seq_end = (total_seqs * (rank + 1)) // world_size + val_loss_sum = torch.zeros((), device=device, dtype=torch.float64) + val_token_count = torch.zeros((), device=device, dtype=torch.float64) + val_byte_count = torch.zeros((), device=device, dtype=torch.float64) + + model.eval() + with torch.inference_mode(): + for batch_seq_start in range(seq_start, seq_end, local_batch_seqs): + batch_seq_end = min(batch_seq_start + local_batch_seqs, seq_end) + raw_start = batch_seq_start * seq_len + raw_end = batch_seq_end * seq_len + 1 + local = val_tokens[raw_start:raw_end].to(device=device, dtype=torch.int64, non_blocking=True) + x = local[:-1].reshape(-1, seq_len) + y = local[1:].reshape(-1, seq_len) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + batch_loss = model(x, y).detach() + batch_token_count = float(y.numel()) + val_loss_sum += 
batch_loss.to(torch.float64) * batch_token_count + val_token_count += batch_token_count + prev_ids = x.reshape(-1) + tgt_ids = y.reshape(-1) + token_bytes = base_bytes_lut[tgt_ids].to(dtype=torch.int16) + token_bytes += (has_leading_space_lut[tgt_ids] & ~is_boundary_token_lut[prev_ids]).to(dtype=torch.int16) + val_byte_count += token_bytes.to(torch.float64).sum() + + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(val_loss_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(val_token_count, op=dist.ReduceOp.SUM) + dist.all_reduce(val_byte_count, op=dist.ReduceOp.SUM) + + val_loss = val_loss_sum / val_token_count + bits_per_token = val_loss.item() / math.log(2.0) + tokens_per_byte = val_token_count.item() / val_byte_count.item() + model.train() + return float(val_loss.item()), float(bits_per_token * tokens_per_byte) + +# ----------------------------- +# POST-TRAINING QUANTIZATION +# ----------------------------- +# +# It's silly to export our model, which is trained in bf16 and fp32, at that same precision. +# Instead, we get approximately the same model (with a small hit) by quantizing the model to int8 & zlib compressing. +# We can then decompress the model and run in higher precision for evaluation, after closing in under the size limit. 
+ +CONTROL_TENSOR_NAME_PATTERNS = tuple( + pattern + for pattern in os.environ.get( + "CONTROL_TENSOR_NAME_PATTERNS", + "attn_scale,attn_scales,mlp_scale,mlp_scales,resid_mix,resid_mixes,q_gain,skip_weight,skip_weights,smear,dtg_gate,ve_layer_scales,ve_shared.scale", + ).split(",") + if pattern +) +INT8_KEEP_FLOAT_FP32_NAME_PATTERNS = tuple( + pattern + for pattern in os.environ.get( + "INT8_KEEP_FLOAT_FP32_NAME_PATTERNS", + ",".join(CONTROL_TENSOR_NAME_PATTERNS), + ).split(",") + if pattern +) +INT8_KEEP_FLOAT_MAX_NUMEL = 65_536 +INT8_KEEP_FLOAT_STORE_DTYPE = torch.float16 +INT8_PER_ROW_SCALE_DTYPE = torch.float16 +INT8_CLIP_PERCENTILE = 99.99984 +INT8_CLIP_Q = INT8_CLIP_PERCENTILE / 100.0 + +def tensor_nbytes(t: Tensor) -> int: + return int(t.numel()) * int(t.element_size()) + +def keep_float_tensor(name: str, t: Tensor, passthrough_orig_dtypes: dict[str, str]) -> Tensor: + if any(pattern in name for pattern in INT8_KEEP_FLOAT_FP32_NAME_PATTERNS): + return t.float().contiguous() + if t.dtype in {torch.float32, torch.bfloat16}: + passthrough_orig_dtypes[name] = str(t.dtype).removeprefix("torch.") + return t.to(dtype=INT8_KEEP_FLOAT_STORE_DTYPE).contiguous() + return t + +def quantize_float_tensor(t: Tensor) -> tuple[Tensor, Tensor]: + t32 = t.float() + if t32.ndim == 2: + # Matrices get one scale per row, which usually tracks output-channel + # ranges much better than a single tensor-wide scale. + clip_abs = ( + torch.quantile(t32.abs(), INT8_CLIP_Q, dim=1) + if t32.numel() + else torch.empty((t32.shape[0],), dtype=torch.float32) + ) + clipped = torch.maximum(torch.minimum(t32, clip_abs[:, None]), -clip_abs[:, None]) + scale = (clip_abs / 127.0).clamp_min(1.0 / 127.0) + q = torch.clamp(torch.round(clipped / scale[:, None]), -127, 127).to(torch.int8).contiguous() + return q, scale.to(dtype=INT8_PER_ROW_SCALE_DTYPE).contiguous() + + # Vectors / scalars use a simpler per-tensor scale. 
+ clip_abs = float(torch.quantile(t32.abs().flatten(), INT8_CLIP_Q).item()) if t32.numel() else 0.0 + scale = torch.tensor(clip_abs / 127.0 if clip_abs > 0 else 1.0, dtype=torch.float32) + q = torch.clamp(torch.round(torch.clamp(t32, -clip_abs, clip_abs) / scale), -127, 127).to(torch.int8).contiguous() + return q, scale + +def quantize_state_dict_int8(state_dict: dict[str, Tensor]): + # Single supported clean-script export format: + # - per-row int8 for 2D float tensors + # - per-tensor int8 for other float tensors + # - exact passthrough for non-floats + # - passthrough for small float tensors, stored as fp16 to save bytes + quantized: dict[str, Tensor] = {} + scales: dict[str, Tensor] = {} + dtypes: dict[str, str] = {} + passthrough: dict[str, Tensor] = {} + passthrough_orig_dtypes: dict[str, str] = {} + qmeta: dict[str, dict[str, object]] = {} + stats = dict.fromkeys( + ("param_count", "num_tensors", "num_float_tensors", "num_nonfloat_tensors", "baseline_tensor_bytes", "int8_payload_bytes"), + 0, + ) + + for name, tensor in state_dict.items(): + t = tensor.detach().to("cpu").contiguous() + stats["param_count"] += int(t.numel()) + stats["num_tensors"] += 1 + stats["baseline_tensor_bytes"] += tensor_nbytes(t) + + if not t.is_floating_point(): + stats["num_nonfloat_tensors"] += 1 + passthrough[name] = t + stats["int8_payload_bytes"] += tensor_nbytes(t) + continue + + # Small float tensors are cheap enough to keep directly. We still downcast + # fp32/bf16 passthrough tensors to fp16 so metadata does not dominate size. 
+ if t.numel() <= INT8_KEEP_FLOAT_MAX_NUMEL: + kept = keep_float_tensor(name, t, passthrough_orig_dtypes) + passthrough[name] = kept + stats["int8_payload_bytes"] += tensor_nbytes(kept) + continue + + stats["num_float_tensors"] += 1 + q, s = quantize_float_tensor(t) + if s.ndim > 0: + qmeta[name] = {"scheme": "per_row", "axis": 0} + quantized[name] = q + scales[name] = s + dtypes[name] = str(t.dtype).removeprefix("torch.") + stats["int8_payload_bytes"] += tensor_nbytes(q) + tensor_nbytes(s) + + obj: dict[str, object] = { + "__quant_format__": "int8_clean_per_row_v1", + "quantized": quantized, + "scales": scales, + "dtypes": dtypes, + "passthrough": passthrough, + } + if qmeta: + obj["qmeta"] = qmeta + if passthrough_orig_dtypes: + obj["passthrough_orig_dtypes"] = passthrough_orig_dtypes + return obj, stats + +def dequantize_state_dict_int8(obj: dict[str, object]) -> dict[str, Tensor]: + out: dict[str, Tensor] = {} + qmeta = obj.get("qmeta", {}) + passthrough_orig_dtypes = obj.get("passthrough_orig_dtypes", {}) + for name, q in obj["quantized"].items(): + dtype = getattr(torch, obj["dtypes"][name]) + s = obj["scales"][name] + if qmeta.get(name, {}).get("scheme") == "per_row" or s.ndim > 0: + s = s.to(dtype=torch.float32) + # Broadcast the saved row scale back across trailing dimensions. + out[name] = (q.float() * s.view(q.shape[0], *([1] * (q.ndim - 1)))).to(dtype=dtype).contiguous() + else: + scale = float(s.item()) + out[name] = (q.float() * scale).to(dtype=dtype).contiguous() + for name, t in obj["passthrough"].items(): + # Restore small tensors, undoing the temporary fp16 storage cast if needed. 
+ out_t = t.detach().to("cpu").contiguous() + orig_dtype = passthrough_orig_dtypes.get(name) + if isinstance(orig_dtype, str): + out_t = out_t.to(dtype=getattr(torch, orig_dtype)).contiguous() + out[name] = out_t + return out + + +# ----------------------------- +# DATA LOADING +# ----------------------------- + +def load_data_shard(file: Path) -> Tensor: + header_bytes = 256 * np.dtype(" None: + self.file_idx = (self.file_idx + 1) % len(self.files) + self.tokens = load_data_shard(self.files[self.file_idx]) + self.pos = 0 + + def take(self, n: int) -> Tensor: + chunks: list[Tensor] = [] + remaining = n + while remaining > 0: + avail = self.tokens.numel() - self.pos + if avail <= 0: + self._advance_file() + continue + k = min(remaining, avail) + chunks.append(self.tokens[self.pos : self.pos + k]) + self.pos += k + remaining -= k + return chunks[0] if len(chunks) == 1 else torch.cat(chunks) + + +class DistributedTokenLoader: + # Each call consumes a contiguous chunk from the shared token stream, then slices out + # one disjoint span per rank. The extra "+1" token lets us build (x, y) by shifting. 
+ def __init__(self, pattern: str, rank: int, world_size: int, device: torch.device): + self.rank = rank + self.world_size = world_size + self.device = device + self.stream = TokenStream(pattern) + + def next_batch(self, global_tokens: int, seq_len: int, grad_accum_steps: int) -> tuple[Tensor, Tensor]: + local_tokens = global_tokens // (self.world_size * grad_accum_steps) + per_rank_span = local_tokens + 1 + chunk = self.stream.take(per_rank_span * self.world_size) + start = self.rank * per_rank_span + local = chunk[start : start + per_rank_span].to(dtype=torch.int64) + x = local[:-1].reshape(-1, seq_len) + y = local[1:].reshape(-1, seq_len) + return x.to(self.device, non_blocking=True), y.to(self.device, non_blocking=True) + +# ----------------------------- +# TRANSFORMER MODULES +# ----------------------------- + +class RMSNorm(nn.Module): + def __init__(self, eps: float | None = None): + super().__init__() + self.eps = eps + + def forward(self, x: Tensor) -> Tensor: + return F.rms_norm(x, (x.size(-1),), eps=self.eps) + + +class CastedLinear(nn.Linear): + _qat_enabled: bool = False + + def forward(self, x: Tensor) -> Tensor: + w = self.weight.to(x.dtype) + if CastedLinear._qat_enabled and self.training and w.ndim == 2: + with torch.no_grad(): + w32 = self.weight.float() + row_max = w32.abs().amax(dim=1) + scale = (row_max / 31.0).clamp_min(1.0 / 31.0) + w_q = (torch.clamp(torch.round(w32 / scale[:, None]), -32, 31) * scale[:, None]).to(x.dtype) + w = w + (w_q - w).detach() + bias = self.bias.to(x.dtype) if self.bias is not None else None + return F.linear(x, w, bias) + + +def restore_low_dim_params_to_fp32(module: nn.Module) -> None: + # Keep small/control parameters in fp32 even when the model body runs in bf16. 
+ with torch.no_grad(): + for name, param in module.named_parameters(): + if (param.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS)) and param.dtype != torch.float32: + param.data = param.data.float() + + +class Rotary(nn.Module): + # NTK-aware RoPE: auto-scales base frequency when seq_len exceeds train_seq_len. + def __init__(self, dim: int, base: float = 10000.0, train_seq_len: int = 1024, rope_dims: int = 0): + super().__init__() + self.dim = dim + self.base = base + self.train_seq_len = train_seq_len + self.rope_dims = rope_dims if rope_dims > 0 else dim + inv_freq = 1.0 / (base ** (torch.arange(0, self.rope_dims, 2, dtype=torch.float32) / self.rope_dims)) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self._seq_len_cached = 0 + self._cos_cached: Tensor | None = None + self._sin_cached: Tensor | None = None + + def forward(self, seq_len: int, device: torch.device, dtype: torch.dtype) -> tuple[Tensor, Tensor]: + if ( + self._cos_cached is None + or self._sin_cached is None + or self._seq_len_cached != seq_len + or self._cos_cached.device != device + ): + rd = self.rope_dims + if seq_len > self.train_seq_len: + scale = seq_len / self.train_seq_len + new_base = self.base * (scale ** (rd / (rd - 2))) + inv_freq = 1.0 / (new_base ** (torch.arange(0, rd, 2, dtype=torch.float32, device=device) / rd)) + else: + inv_freq = self.inv_freq.to(device) + t = torch.arange(seq_len, device=device, dtype=inv_freq.dtype) + freqs = torch.outer(t, inv_freq) + self._cos_cached = freqs.cos()[None, :, None, :] + self._sin_cached = freqs.sin()[None, :, None, :] + self._seq_len_cached = seq_len + return self._cos_cached.to(dtype=dtype), self._sin_cached.to(dtype=dtype) + + +def apply_rotary_emb(x: Tensor, cos: Tensor, sin: Tensor, rope_dims: int = 0) -> Tensor: + if rope_dims > 0 and rope_dims < x.size(-1): + x_rope, x_pass = x[..., :rope_dims], x[..., rope_dims:] + half = rope_dims // 2 + x1, x2 = x_rope[..., :half], x_rope[..., half:] + 
x_rope = torch.cat((x1 * cos + x2 * sin, x1 * (-sin) + x2 * cos), dim=-1) + return torch.cat((x_rope, x_pass), dim=-1) + half = x.size(-1) // 2 + x1, x2 = x[..., :half], x[..., half:] + return torch.cat((x1 * cos + x2 * sin, x1 * (-sin) + x2 * cos), dim=-1) + + +class CausalSelfAttention(nn.Module): + def __init__( + self, + dim: int, + num_heads: int, + num_kv_heads: int, + rope_base: float, + qk_gain_init: float, + ): + super().__init__() + if dim % num_heads != 0: + raise ValueError("model_dim must be divisible by num_heads") + if num_heads % num_kv_heads != 0: + raise ValueError("num_heads must be divisible by num_kv_heads") + self.num_heads = num_heads + self.num_kv_heads = num_kv_heads + self.head_dim = dim // num_heads + if self.head_dim % 2 != 0: + raise ValueError("head_dim must be even for RoPE") + kv_dim = self.num_kv_heads * self.head_dim + self.c_q = CastedLinear(dim, dim, bias=False) + self.c_k = CastedLinear(dim, kv_dim, bias=False) + self.c_v = CastedLinear(dim, kv_dim, bias=False) + self.proj = CastedLinear(dim, dim, bias=False) + self.proj._zero_init = True + self.q_gain = nn.Parameter(torch.full((num_heads,), qk_gain_init, dtype=torch.float32)) + self.rope_dims = 0 # set by GPT.__init__ for partial RoPE + self.rotary = Rotary(self.head_dim, base=rope_base, train_seq_len=1024) + self.use_xsa = False # set by GPT.__init__ for deep layers only + + def _xsa_efficient(self, y: Tensor, v: Tensor) -> Tensor: + """Efficient XSA: subtract self-value projection via GQA-aware reshape (no repeat_interleave). + y: [B, T, H, D], v: [B, T, Hkv, D]. 
H must be divisible by Hkv.""" + B, T, H, D = y.shape + Hkv = v.size(-2) + group = H // Hkv + # Reshape y into KV head groups — free view, no memory alloc + y_g = y.reshape(B, T, Hkv, group, D) # [B, T, Hkv, group, D] + vn = F.normalize(v, dim=-1).unsqueeze(-2) # [B, T, Hkv, 1, D] — broadcast ready + # Project out self-value component per KV head group + proj = (y_g * vn).sum(dim=-1, keepdim=True) * vn + return (y_g - proj).reshape(B, T, H, D) + + def forward(self, x: Tensor, v_embed: Tensor | None = None) -> Tensor: + bsz, seqlen, dim = x.shape + q = self.c_q(x).reshape(bsz, seqlen, self.num_heads, self.head_dim) + k = self.c_k(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim) + v = self.c_v(x) + # Value embedding: add token identity directly to values before reshape + if v_embed is not None: + v = v + v_embed + v = v.reshape(bsz, seqlen, self.num_kv_heads, self.head_dim) + q = F.rms_norm(q, (q.size(-1),)) + k = F.rms_norm(k, (k.size(-1),)) + cos, sin = self.rotary(seqlen, x.device, q.dtype) + q = apply_rotary_emb(q, cos, sin, self.rope_dims) + k = apply_rotary_emb(k, cos, sin, self.rope_dims) + q = q * self.q_gain.to(dtype=q.dtype)[None, None, :, None] + if _FA3_AVAILABLE: + y = flash_attn_3_func(q, k, v, causal=True) + else: + # SDPA fallback: FA3 uses [B, T, H, D], SDPA uses [B, H, T, D] + q_t = q.transpose(1, 2) + k_t = k.transpose(1, 2) + v_t = v.transpose(1, 2) + y = F.scaled_dot_product_attention( + q_t, k_t, v_t, attn_mask=None, is_causal=True, + enable_gqa=(self.num_kv_heads != self.num_heads), + ).transpose(1, 2).contiguous() + if self.use_xsa: + y = self._xsa_efficient(y, v) + y = y.reshape(bsz, seqlen, dim) + return self.proj(y) + + +class SmearGate(nn.Module): + def __init__(self, dim: int): + super().__init__() + self.gate = nn.Parameter(torch.zeros(dim, dtype=torch.float32)) + + def forward(self, x: Tensor) -> Tensor: + g = torch.sigmoid(self.gate.to(dtype=x.dtype))[None, None, :] + x_prev = torch.cat([torch.zeros_like(x[:, :1]), x[:, :-1]], 
dim=1) + return (1 - g) * x + g * x_prev + + +class BigramHashEmbedding(nn.Module): + def __init__(self, bigram_vocab_size: int, bigram_dim: int, model_dim: int): + super().__init__() + self.bigram_vocab_size = bigram_vocab_size + self.embed = nn.Embedding(bigram_vocab_size, bigram_dim) + nn.init.zeros_(self.embed.weight) + self.proj = CastedLinear(bigram_dim, model_dim, bias=False) if bigram_dim != model_dim else None + if self.proj is not None: + nn.init.zeros_(self.proj.weight) + self.scale = nn.Parameter(torch.tensor(0.05, dtype=torch.float32)) + + def bigram_hash(self, tokens: Tensor) -> Tensor: + t = tokens.to(torch.int32) + mod = self.bigram_vocab_size - 1 + out = torch.empty_like(t) + out[..., 0] = mod + out[..., 1:] = torch.bitwise_xor(36313 * t[..., 1:], 27191 * t[..., :-1]) % mod + return out.long() + + def forward(self, token_ids: Tensor) -> Tensor: + h = self.embed(self.bigram_hash(token_ids)) + if self.proj is not None: + h = self.proj(h) + return h * self.scale.to(dtype=h.dtype) + + +class ValueEmbedding(nn.Module): + """Reinject token identity into attention values at specific layers. + Each table maps vocab tokens to a low-dim embedding, projected to model_dim.""" + def __init__(self, vocab_size: int, ve_dim: int, model_dim: int): + super().__init__() + self.embed = nn.Embedding(vocab_size, ve_dim) + nn.init.normal_(self.embed.weight, std=0.01) + self.proj = CastedLinear(ve_dim, model_dim, bias=False) if ve_dim != model_dim else None + if self.proj is not None: + nn.init.zeros_(self.proj.weight) + self.scale = nn.Parameter(torch.tensor(0.1, dtype=torch.float32)) + + def forward(self, token_ids: Tensor) -> Tensor: + h = self.embed(token_ids) + if self.proj is not None: + h = self.proj(h) + return h * self.scale.to(dtype=h.dtype) + + +class MixtureOfSoftmax(nn.Module): + """MoS output layer: K separate softmax distributions mixed by learned gate. + Returns log-probabilities, not logits. 
Use F.nll_loss, not F.cross_entropy.""" + + def __init__(self, model_dim: int, n_mixtures: int = 2): + super().__init__() + self.n_mixtures = n_mixtures + self.model_dim = model_dim + self.projections = CastedLinear(model_dim, n_mixtures * model_dim, bias=False) + self.gate = CastedLinear(model_dim, n_mixtures, bias=False) + + def forward(self, hidden: Tensor, weight_matrix: Tensor) -> Tensor: + # hidden: [B*T, D], weight_matrix: [V, D] + K = self.n_mixtures + D = self.model_dim + pi = F.softmax(self.gate(hidden), dim=-1) # [B*T, K] + projected = self.projections(hidden).view(-1, K, D) # [B*T, K, D] + logits = projected @ weight_matrix.T # [B*T, K, V] + log_probs = F.log_softmax(logits, dim=-1) # [B*T, K, V] + log_pi = torch.log(pi.unsqueeze(-1) + 1e-10) # [B*T, K, 1] + return torch.logsumexp(log_probs + log_pi, dim=1) # [B*T, V] + + +class MLP(nn.Module): + def __init__(self, dim: int, mlp_mult: int): + super().__init__() + hidden = int(mlp_mult * dim) + self.fc = CastedLinear(dim, hidden, bias=False) + self.proj = CastedLinear(hidden, dim, bias=False) + self.proj._zero_init = True + + def forward(self, x: Tensor) -> Tensor: + x = torch.relu(self.fc(x)) + return self.proj(x.square()) + + +class Block(nn.Module): + def __init__( + self, + dim: int, + num_heads: int, + num_kv_heads: int, + mlp_mult: int, + rope_base: float, + qk_gain_init: float, + layer_idx: int = 0, + ln_scale: bool = False, + dtg: bool = False, + ): + super().__init__() + self.attn_norm = RMSNorm() + self.mlp_norm = RMSNorm() + self.attn = CausalSelfAttention(dim, num_heads, num_kv_heads, rope_base, qk_gain_init) + self.mlp = MLP(dim, mlp_mult) + self.attn_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.mlp_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.resid_mix = nn.Parameter(torch.stack((torch.ones(dim), torch.zeros(dim))).float()) + self.ln_scale_factor = 1.0 / math.sqrt(layer_idx + 1) if ln_scale else 1.0 + if dtg: + self.dtg_gate = nn.Linear(dim, 1, 
bias=True) + nn.init.zeros_(self.dtg_gate.weight) + nn.init.constant_(self.dtg_gate.bias, 2.0) + else: + self.dtg_gate = None + + def forward(self, x: Tensor, x0: Tensor, v_embed: Tensor | None = None) -> Tensor: + mix = self.resid_mix.to(dtype=x.dtype) + x_in = mix[0][None, None, :] * x + mix[1][None, None, :] * x0 + attn_out = self.attn(self.attn_norm(x_in) * self.ln_scale_factor, v_embed=v_embed) + x_out = x_in + self.attn_scale.to(dtype=x_in.dtype)[None, None, :] * attn_out + x_out = x_out + self.mlp_scale.to(dtype=x_out.dtype)[None, None, :] * self.mlp(self.mlp_norm(x_out) * self.ln_scale_factor) + if self.dtg_gate is not None: + gate = torch.sigmoid(self.dtg_gate(x_in.detach())) + x_out = x_in + gate * (x_out - x_in) + return x_out + + +class GPT(nn.Module): + def __init__( + self, + vocab_size: int, + num_layers: int, + model_dim: int, + num_heads: int, + num_kv_heads: int, + mlp_mult: int, + tie_embeddings: bool, + tied_embed_init_std: float, + logit_softcap: float, + rope_base: float, + qk_gain_init: float, + mtp_num_heads: int = 0, + mtp_loss_weight: float = 0.1, + bigram_vocab_size: int = 0, + bigram_dim: int = 128, + xsa_last_n: int = 0, + rope_dims: int = 0, + ln_scale: bool = False, + dtg: bool = False, + ve_enabled: bool = False, + ve_dim: int = 128, + ve_layers: str = "9,10", + use_mos: bool = False, + mos_k: int = 2, + ): + super().__init__() + self._ve_target_dim = num_kv_heads * (model_dim // num_heads) # kv_dim for value projection + if logit_softcap <= 0.0: + raise ValueError(f"logit_softcap must be positive, got {logit_softcap}") + self.tie_embeddings = tie_embeddings + self.tied_embed_init_std = tied_embed_init_std + self.logit_softcap = logit_softcap + self.mtp_num_heads = mtp_num_heads + self.mtp_loss_weight = mtp_loss_weight + self.tok_emb = nn.Embedding(vocab_size, model_dim) + self.bigram = BigramHashEmbedding(bigram_vocab_size, bigram_dim, model_dim) if bigram_vocab_size > 0 else None + self.smear = SmearGate(model_dim) + 
self.num_encoder_layers = num_layers // 2 + self.num_decoder_layers = num_layers - self.num_encoder_layers + self.num_skip_weights = min(self.num_encoder_layers, self.num_decoder_layers) + self.skip_weights = nn.Parameter(torch.ones(self.num_skip_weights, model_dim, dtype=torch.float32)) + self.blocks = nn.ModuleList( + [ + Block( + model_dim, + num_heads, + num_kv_heads, + mlp_mult, + rope_base, + qk_gain_init, + layer_idx=i, + ln_scale=ln_scale, + dtg=dtg, + ) + for i in range(num_layers) + ] + ) + # Set partial RoPE on all attention blocks + if rope_dims > 0: + head_dim = model_dim // num_heads + for block in self.blocks: + block.attn.rope_dims = rope_dims + block.attn.rotary = Rotary(head_dim, base=rope_base, train_seq_len=1024, rope_dims=rope_dims) + # Value embeddings: 1 SHARED table, per-layer learned scales + self.ve_layer_indices = [int(x) for x in ve_layers.split(",") if x.strip()] if ve_enabled else [] + kv_dim = self._ve_target_dim + if self.ve_layer_indices: + self.ve_shared = ValueEmbedding(vocab_size, ve_dim, kv_dim) + self.ve_layer_scales = nn.ParameterList( + [nn.Parameter(torch.ones(1, dtype=torch.float32)) for _ in self.ve_layer_indices] + ) + else: + self.ve_shared = None + self.ve_layer_scales = nn.ParameterList() + self.value_embeds = nn.ModuleList() # keep empty for compat + self.final_norm = RMSNorm() + self.lm_head = None if tie_embeddings else CastedLinear(model_dim, vocab_size, bias=False) + if self.lm_head is not None: + self.lm_head._zero_init = True + self.mtp_heads = nn.ModuleList( + [CastedLinear(model_dim, vocab_size, bias=False) for _ in range(mtp_num_heads)] + ) + for head in self.mtp_heads: + head._zero_init = True + # Enable efficient XSA on the deepest layers (highest self-attention bias) + if xsa_last_n > 0: + for i in range(max(0, num_layers - xsa_last_n), num_layers): + self.blocks[i].attn.use_xsa = True + # Mixture of Softmax output layer + self.use_mos = use_mos + self.returns_log_probs = use_mos + self.mos = 
MixtureOfSoftmax(model_dim, n_mixtures=mos_k) if use_mos else None + self._init_weights() + + def _init_weights(self) -> None: + if self.tie_embeddings: + nn.init.normal_(self.tok_emb.weight, mean=0.0, std=self.tied_embed_init_std) + num_layers = len(self.blocks) + for name, module in self.named_modules(): + if isinstance(module, nn.Linear): + if getattr(module, "_zero_init", False): + nn.init.zeros_(module.weight) + elif module.weight.ndim == 2 and module.weight.shape[0] >= 64 and module.weight.shape[1] >= 64: + nn.init.orthogonal_(module.weight, gain=1.0) + if ".proj." in name or name.endswith(".proj"): + with torch.no_grad(): + module.weight.mul_(1.0 / math.sqrt(2 * num_layers)) + + def _get_ve(self, layer_idx: int, input_ids: Tensor, ve_cache: dict | None = None) -> Tensor | None: + """Get value embedding for a specific layer using shared table + per-layer scale.""" + if self.ve_shared is None or layer_idx not in self.ve_layer_indices: + return None + # Cache the shared VE computation (same for all layers, different scale) + if ve_cache is not None and 've' not in ve_cache: + ve_cache['ve'] = self.ve_shared(input_ids) + ve_base = ve_cache['ve'] if ve_cache is not None else self.ve_shared(input_ids) + ve_idx = self.ve_layer_indices.index(layer_idx) + return ve_base * self.ve_layer_scales[ve_idx].to(dtype=ve_base.dtype) + + def forward(self, input_ids: Tensor, target_ids: Tensor) -> Tensor: + x = self.tok_emb(input_ids) + if self.bigram is not None: + x = x + self.bigram(input_ids) + x = F.rms_norm(x, (x.size(-1),)) + x = self.smear(x) + x0 = x + skips: list[Tensor] = [] + ve_cache: dict = {} + + for i in range(self.num_encoder_layers): + ve = self._get_ve(i, input_ids, ve_cache) + x = self.blocks[i](x, x0, v_embed=ve) + skips.append(x) + for i in range(self.num_decoder_layers): + bi = self.num_encoder_layers + i + if skips: + x = x + self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips.pop() + ve = self._get_ve(bi, input_ids, ve_cache) + x = 
self.blocks[bi](x, x0, v_embed=ve) + + x = self.final_norm(x) + x_flat = x.reshape(-1, x.size(-1)) + targets = target_ids.reshape(-1) + if self.mos is not None: + log_probs = self.mos(x_flat, self.tok_emb.weight) + main_loss = F.nll_loss(log_probs.float(), targets, reduction="mean") + else: + if self.tie_embeddings: + logits_proj = F.linear(x_flat, self.tok_emb.weight) + else: + if self.lm_head is None: + raise RuntimeError("lm_head is required when tie_embeddings=False") + logits_proj = self.lm_head(x_flat) + logits = self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap) + main_loss = F.cross_entropy(logits.float(), targets, reduction="mean") + + if self.training and self.mtp_num_heads > 0 and self.mtp_loss_weight > 0.0: + _, seqlen, dim = x.shape + mtp_loss_sum = x.new_zeros(()) + mtp_loss_count = 0 + for k, mtp_head in enumerate(self.mtp_heads): + valid_t = seqlen - (k + 1) + if valid_t <= 0: + continue + mtp_hidden = x[:, :valid_t, :].reshape(-1, dim) + mtp_targets = target_ids[:, k + 1 :].reshape(-1) + mtp_logits_proj = mtp_head(mtp_hidden) + mtp_logits = self.logit_softcap * torch.tanh(mtp_logits_proj / self.logit_softcap) + mtp_loss_sum = mtp_loss_sum + F.cross_entropy(mtp_logits.float(), mtp_targets, reduction="mean") + mtp_loss_count += 1 + if mtp_loss_count > 0: + main_loss = main_loss + self.mtp_loss_weight * (mtp_loss_sum / mtp_loss_count) + + return main_loss + + def forward_logits(self, input_ids: Tensor) -> Tensor: + """Return logits or log-probs (bsz, seq_len, vocab) without computing loss.""" + x = self.tok_emb(input_ids) + if self.bigram is not None: + x = x + self.bigram(input_ids) + x = F.rms_norm(x, (x.size(-1),)) + x = self.smear(x) + x0 = x + skips: list[Tensor] = [] + ve_cache: dict = {} + for i in range(self.num_encoder_layers): + ve = self._get_ve(i, input_ids, ve_cache) + x = self.blocks[i](x, x0, v_embed=ve) + skips.append(x) + for i in range(self.num_decoder_layers): + bi = self.num_encoder_layers + i + if skips: + x = x + 
self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips.pop() + ve = self._get_ve(bi, input_ids, ve_cache) + x = self.blocks[bi](x, x0, v_embed=ve) + x = self.final_norm(x) + if self.mos is not None: + bsz, seqlen, dim = x.shape + x_flat = x.reshape(-1, dim) + log_probs = self.mos(x_flat, self.tok_emb.weight) + return log_probs.reshape(bsz, seqlen, -1) + if self.tie_embeddings: + logits_proj = F.linear(x, self.tok_emb.weight) + else: + logits_proj = self.lm_head(x) + return self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap) + + +# ----------------------------- +# SLIDING WINDOW EVALUATION +# ----------------------------- + +def eval_val_sliding( + args: Hyperparameters, + base_model: nn.Module, + rank: int, + world_size: int, + device: torch.device, + val_tokens: Tensor, + base_bytes_lut: Tensor, + has_leading_space_lut: Tensor, + is_boundary_token_lut: Tensor, + stride: int, + batch_seqs: int = 32, + eval_seq_len: int | None = None, +) -> tuple[float, float]: + """Sliding window evaluation: each token scored with maximum context.""" + seq_len = eval_seq_len or args.train_seq_len + total_tokens = val_tokens.numel() - 1 + + window_starts = [ws for ws in range(0, total_tokens, stride) + if min(ws + seq_len, total_tokens) - ws >= 1] + total_windows = len(window_starts) + + my_s = (total_windows * rank) // world_size + my_e = (total_windows * (rank + 1)) // world_size + my_windows = window_starts[my_s:my_e] + + loss_sum = torch.zeros((), device=device, dtype=torch.float64) + token_count = torch.zeros((), device=device, dtype=torch.float64) + byte_count = torch.zeros((), device=device, dtype=torch.float64) + + base_model.eval() + compiled_logits = base_model.forward_logits if _DISABLE_COMPILE else torch.compile(base_model.forward_logits, dynamic=False, fullgraph=True) + _use_nll = getattr(base_model, 'returns_log_probs', False) + + with torch.inference_mode(): + for bi in range(0, len(my_windows), batch_seqs): + batch_ws = my_windows[bi:bi + 
batch_seqs] + bsz = len(batch_ws) + + x_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) + y_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) + wlens: list[int] = [] + + for i, ws in enumerate(batch_ws): + end = min(ws + seq_len, total_tokens) + wlen = end - ws + wlens.append(wlen) + chunk = val_tokens[ws:end + 1].to(dtype=torch.int64, device=device) + x_batch[i, :wlen] = chunk[:-1] + y_batch[i, :wlen] = chunk[1:] + + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + logits = compiled_logits(x_batch) + + output_flat = logits.reshape(-1, logits.size(-1)).float() + targets_flat = y_batch.reshape(-1) + if _use_nll: + nll = F.nll_loss(output_flat, targets_flat, reduction="none").reshape(bsz, seq_len) + else: + nll = F.cross_entropy(output_flat, targets_flat, reduction="none").reshape(bsz, seq_len) + + for i, ws in enumerate(batch_ws): + wlen = wlens[i] + s = 0 if ws == 0 else max(wlen - stride, 0) + scored_nll = nll[i, s:wlen].to(torch.float64) + loss_sum += scored_nll.sum() + token_count += float(wlen - s) + tgt = y_batch[i, s:wlen] + prev = x_batch[i, s:wlen] + tb = base_bytes_lut[tgt].to(torch.float64) + tb += (has_leading_space_lut[tgt] & ~is_boundary_token_lut[prev]).to(torch.float64) + byte_count += tb.sum() + + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(loss_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(token_count, op=dist.ReduceOp.SUM) + dist.all_reduce(byte_count, op=dist.ReduceOp.SUM) + + val_loss = (loss_sum / token_count).item() + bits_per_token = val_loss / math.log(2.0) + tokens_per_byte = token_count.item() / byte_count.item() + base_model.train() + return val_loss, bits_per_token * tokens_per_byte + + +# ----------------------------- +# INT6 MIXED QUANTIZATION (transplanted from working diagnostic scripts) +# ----------------------------- + +def _classify_param(name: str) -> str: + if "tok_emb" in name or "lm_head" in name: + return "embed" + if ".mlp." 
in name: + return "mlp" + if "mos." in name: + return "mos" + if ".attn." in name or (".proj." in name and ".mlp." not in name): + return "attn" + return "other" + +def quantize_int6_per_row(t: Tensor) -> tuple[Tensor, Tensor]: + t32 = t.float() + if t32.ndim == 2: + row_max = t32.abs().amax(dim=1) + scale = (row_max / 31.0).clamp_min(1.0 / 31.0).to(torch.float16) + q = torch.clamp(torch.round(t32 / scale.float()[:, None]), -32, 31).to(torch.int8) + return q, scale + amax = t32.abs().max().item() + scale = torch.tensor(amax / 31.0 if amax > 0 else 1.0, dtype=torch.float16) + q = torch.clamp(torch.round(t32 / scale.float()), -32, 31).to(torch.int8) + return q, scale + +def mixed_quantize_int6(state_dict: dict[str, Tensor], int6_cats: set[str]): + num_layers_total = max( + (int(k.split(".")[1]) for k in state_dict if k.startswith("blocks.")), + default=0, + ) + 1 + late_k_layers = set(range(num_layers_total - 2, num_layers_total)) + + result: dict[str, Tensor] = {} + meta: dict[str, object] = {} + for name, tensor in state_dict.items(): + t = tensor.detach().cpu().contiguous() + cat = _classify_param(name) + if not t.is_floating_point() or t.numel() <= 65536: + result[name] = t.to(torch.float16) if t.is_floating_point() else t + meta[name] = "passthrough" + continue + if any(p in name for p in CONTROL_TENSOR_NAME_PATTERNS): + result[name] = t.float() + meta[name] = "passthrough_ctrl" + continue + # tok_emb.weight falls through to int8 via "embed" category + if cat in int6_cats and t.ndim >= 1: + q, s = quantize_int6_per_row(t) + result[name + ".q"] = q + result[name + ".scale"] = s + meta[name] = {"type": "int6"} + else: + q, s = quantize_float_tensor(t) + result[name + ".q"] = q + result[name + ".scale"] = s + meta[name] = {"type": "int8"} + return result, meta + +def dequantize_mixed_int6(result: dict[str, Tensor], meta: dict[str, object], + template_sd: dict[str, Tensor]) -> dict[str, Tensor]: + out: dict[str, Tensor] = {} + for name, orig in template_sd.items(): 
        # NOTE(review): this is the tail of a dequantization helper whose `def`
        # line is above this chunk (presumably dequantize_mixed_int6, given the
        # later call site — confirm). `meta` maps param name -> quantization info;
        # `orig` is the corresponding reference tensor (loop header not visible).
        info = meta.get(name)
        if info is None:
            continue
        orig_dtype = orig.dtype
        if info in ("passthrough", "passthrough_ctrl", "passthrough_fp16"):
            # Stored unquantized; only restore the original float dtype if it was
            # downcast to fp16 for storage.
            t = result[name]
            if t.dtype == torch.float16 and orig_dtype in (torch.float32, torch.bfloat16):
                t = t.to(orig_dtype)
            out[name] = t
            continue
        # Quantized entry: int payload at "<name>.q" plus either a per-row scale
        # vector (broadcast over trailing dims) or a single scalar scale.
        q, s = result[name + ".q"], result[name + ".scale"]
        if s.ndim > 0:
            out[name] = (q.float() * s.float().view(q.shape[0], *([1] * (q.ndim - 1)))).to(orig_dtype)
        else:
            out[name] = (q.float() * float(s.item())).to(orig_dtype)
    return out


# -----------------------------
# TRAINING
# -----------------------------

def main() -> None:
    """End-to-end training entry point.

    Sequence: (optional) distributed/CUDA setup -> tokenizer + validation data ->
    model + split optimizers (Adam for embeddings/head/scalars, Muon for block
    matrices) -> compile warmup with state restore -> timed training loop with
    optional SWA and a wallclock cap -> int6 quantization + compression ->
    roundtrip deserialization and final evals.
    """
    global zeropower_via_newtonschulz5

    # Source of this script is logged and counted toward submission size below.
    code = Path(__file__).read_text(encoding="utf-8")
    args = Hyperparameters()
    if not _DISABLE_COMPILE:
        zeropower_via_newtonschulz5 = torch.compile(zeropower_via_newtonschulz5)

    # -----------------------------
    # DISTRIBUTED + CUDA SETUP
    # -----------------------------

    # torchrun sets RANK/WORLD_SIZE; absent those we run single-process.
    distributed = "RANK" in os.environ and "WORLD_SIZE" in os.environ
    rank = int(os.environ.get("RANK", "0"))
    world_size = int(os.environ.get("WORLD_SIZE", "1"))
    local_rank = int(os.environ.get("LOCAL_RANK", "0"))
    if world_size <= 0:
        raise ValueError(f"WORLD_SIZE must be positive, got {world_size}")
    if 8 % world_size != 0:
        raise ValueError(f"WORLD_SIZE={world_size} must divide 8 so grad_accum_steps stays integral")
    # Fixed global batch of 8 micro-batches, split across ranks.
    grad_accum_steps = 8 // world_size
    grad_scale = 1.0 / grad_accum_steps
    if not torch.cuda.is_available():
        raise RuntimeError("CUDA is required")
    device = torch.device("cuda", local_rank)
    torch.cuda.set_device(device)
    if distributed:
        dist.init_process_group(backend="nccl", device_id=device)
        dist.barrier()
    master_process = rank == 0

    # Fast math knobs
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True
    from torch.backends.cuda import enable_cudnn_sdp, enable_flash_sdp, enable_math_sdp, enable_mem_efficient_sdp

    # Force the flash SDP kernel only (logged below so runs are auditable).
    enable_cudnn_sdp(False)
    enable_flash_sdp(True)
    enable_mem_efficient_sdp(False)
    enable_math_sdp(False)

    logfile = None
    if master_process:
        os.makedirs("logs", exist_ok=True)
        logfile = f"logs/{args.run_id}.txt"
        print(logfile)

    def log0(msg: str, console: bool = True) -> None:
        # Rank-0-only logger; console=False writes to the logfile only.
        if not master_process:
            return
        if console:
            print(msg)
        if logfile is not None:
            with open(logfile, "a", encoding="utf-8") as f:
                print(msg, file=f)

    # Log full source + environment fingerprint for reproducibility.
    log0(code, console=False)
    log0("=" * 100, console=False)
    log0(f"Running Python {sys.version}", console=False)
    log0(f"Running PyTorch {torch.__version__}", console=False)
    log0(
        subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, check=False).stdout,
        console=False,
    )
    log0("=" * 100, console=False)

    # -----------------------------
    # TOKENIZER + VALIDATION METRIC SETUP
    # -----------------------------

    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)

    if not args.tokenizer_path.endswith(".model"):
        raise ValueError(f"Script only setup for SentencePiece .model file: {args.tokenizer_path}")
    sp = spm.SentencePieceProcessor(model_file=args.tokenizer_path)
    if int(sp.vocab_size()) != args.vocab_size:
        raise ValueError(
            f"VOCAB_SIZE={args.vocab_size} does not match tokenizer vocab_size={int(sp.vocab_size())}"
        )
    dataset_dir = Path(args.data_path).resolve()
    actual_train_files = len(list(dataset_dir.glob("fineweb_train_*.bin")))
    # eval_seq_len <= 0 means "same as training"; validation tokens are loaded
    # long enough for the larger of the two.
    effective_eval_seq_len = args.eval_seq_len if args.eval_seq_len > 0 else args.train_seq_len
    val_seq_len = max(args.train_seq_len, effective_eval_seq_len)
    val_tokens = load_validation_tokens(args.val_files, val_seq_len)
    # Lookup tables used to convert token-level loss into bits-per-byte.
    base_bytes_lut, has_leading_space_lut, is_boundary_token_lut = build_sentencepiece_luts(
        sp, args.vocab_size, device
    )
    log0(f"val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path={args.tokenizer_path}")
    log0(f"train_loader:dataset:{dataset_dir.name} train_shards:{actual_train_files}")
    log0(f"val_loader:shards pattern={args.val_files} tokens:{val_tokens.numel() - 1}")

    # -----------------------------
    # MODEL + OPTIMIZER SETUP
    # -----------------------------

    CastedLinear._qat_enabled = args.qat_enabled

    base_model = GPT(
        vocab_size=args.vocab_size,
        num_layers=args.num_layers,
        model_dim=args.model_dim,
        num_heads=args.num_heads,
        num_kv_heads=args.num_kv_heads,
        mlp_mult=args.mlp_mult,
        tie_embeddings=args.tie_embeddings,
        tied_embed_init_std=args.tied_embed_init_std,
        logit_softcap=args.logit_softcap,
        rope_base=args.rope_base,
        qk_gain_init=args.qk_gain_init,
        mtp_num_heads=args.mtp_num_heads,
        mtp_loss_weight=args.mtp_loss_weight,
        bigram_vocab_size=args.bigram_vocab_size,
        bigram_dim=args.bigram_dim,
        xsa_last_n=args.xsa_last_n,
        rope_dims=args.rope_dims,
        ln_scale=args.ln_scale,
        dtg=args.dtg_enabled,
        ve_enabled=args.ve_enabled,
        ve_dim=args.ve_dim,
        ve_layers=args.ve_layers,
        use_mos=args.use_mos,
        mos_k=args.mos_k,
    ).to(device).bfloat16()
    # Whole model is bf16 except CastedLinear layers and low-dim params, which
    # are kept/restored in fp32.
    for module in base_model.modules():
        if isinstance(module, CastedLinear):
            module.float()
    restore_low_dim_params_to_fp32(base_model)
    compiled_model = base_model if _DISABLE_COMPILE else torch.compile(base_model, dynamic=False, fullgraph=True)
    model: nn.Module = DDP(compiled_model, device_ids=[local_rank], broadcast_buffers=False) if distributed else compiled_model

    # Optimizer split:
    # - token embedding (Adam) uses EMBED_LR
    # - untied lm_head (Adam) uses HEAD_LR
    # - matrix params in transformer blocks use MATRIX_LR via Muon
    # - vectors/scalars use SCALAR_LR via Adam
    block_named_params = list(base_model.blocks.named_parameters())
    # 2-D block weights go to Muon unless their name marks them as control
    # tensors; everything else (vectors, scalars, control tensors) goes to AdamW.
    matrix_params = [
        p
        for name, p in block_named_params
        if p.ndim == 2 and not any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS)
    ]
    if base_model.mtp_num_heads > 0:
        matrix_params.extend([p for p in base_model.mtp_heads.parameters() if p.ndim == 2])
    scalar_params = [
        p
        for name, p in block_named_params
        if p.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS)
    ]
    if base_model.skip_weights.numel() > 0:
        scalar_params.append(base_model.skip_weights)
    scalar_params.append(base_model.smear.gate)
    # gnp_scale removed in v22
    if base_model.bigram is not None:
        scalar_params.append(base_model.bigram.scale)
    token_lr = args.tied_embed_lr if args.tie_embeddings else args.embed_lr
    # "base_lr" is stashed in each group so the LR schedule below can rescale.
    tok_params = [{"params": [base_model.tok_emb.weight], "lr": token_lr, "base_lr": token_lr}]
    if base_model.bigram is not None:
        tok_params.append({"params": [base_model.bigram.embed.weight], "lr": token_lr, "base_lr": token_lr})
        if base_model.bigram.proj is not None:
            matrix_params.append(base_model.bigram.proj.weight)
    if base_model.ve_shared is not None:
        tok_params.append({"params": [base_model.ve_shared.embed.weight], "lr": token_lr, "base_lr": token_lr})
        if base_model.ve_shared.proj is not None:
            matrix_params.append(base_model.ve_shared.proj.weight)
        scalar_params.append(base_model.ve_shared.scale)
    # NOTE(review): per-layer VE scales registered unconditionally — assumed
    # empty when value embeddings are disabled; confirm against GPT.__init__.
    for s in base_model.ve_layer_scales:
        scalar_params.append(s)
    # MoS optimizer assignment
    if base_model.mos is not None:
        matrix_params.append(base_model.mos.projections.weight)  # 512x1024, 2D → Muon
        scalar_params.append(base_model.mos.gate.weight)  # 512x2, too narrow for NS5 → AdamW
    optimizer_tok = torch.optim.AdamW(
        tok_params,
        betas=(args.beta1, args.beta2),
        eps=args.adam_eps,
        weight_decay=args.adam_wd,
        fused=True,
    )
    optimizer_muon = Muon(
        matrix_params,
        lr=args.matrix_lr,
        momentum=args.muon_momentum,
        backend_steps=args.muon_backend_steps,
        weight_decay=args.muon_wd,
    )
    for group in optimizer_muon.param_groups:
        group["base_lr"] = args.matrix_lr
    optimizer_scalar = torch.optim.AdamW(
        [{"params": scalar_params, "lr": args.scalar_lr, "base_lr": args.scalar_lr}],
        betas=(args.beta1, args.beta2),
        eps=args.adam_eps,
        weight_decay=args.adam_wd,
        fused=True,
    )
    optimizers: list[torch.optim.Optimizer] = [optimizer_tok, optimizer_muon, optimizer_scalar]
    if base_model.lm_head is not None:
        # Untied output head: plain Adam (no weight decay passed).
        optimizer_head = torch.optim.Adam(
            [{"params": [base_model.lm_head.weight], "lr": args.head_lr, "base_lr": args.head_lr}],
            betas=(args.beta1, args.beta2),
            eps=args.adam_eps,
            fused=True,
        )
        optimizers.insert(1, optimizer_head)

    n_params = sum(p.numel() for p in base_model.parameters())
    mtp_params = sum(p.numel() for p in base_model.mtp_heads.parameters())
    mos_params = sum(p.numel() for p in base_model.mos.parameters()) if base_model.mos is not None else 0
    log0(f"model_params:{n_params}")
    if mos_params > 0:
        log0(f"mos_params:{mos_params} mos_k:{args.mos_k}")
    log0(f"mtp_num_heads:{args.mtp_num_heads} mtp_loss_weight:{args.mtp_loss_weight} mtp_params:{mtp_params}")
    xsa_layers = [i for i, b in enumerate(base_model.blocks) if b.attn.use_xsa]
    log0(f"XSA:last_{args.xsa_last_n} active_layers:{xsa_layers}")
    log0(f"world_size:{world_size} grad_accum_steps:{grad_accum_steps}")
    log0("sdp_backends:cudnn=False flash=True mem_efficient=False math=False")
    log0(f"attention_mode:gqa num_heads:{args.num_heads} num_kv_heads:{args.num_kv_heads}")
    log0(
        f"tie_embeddings:{args.tie_embeddings} embed_lr:{token_lr} "
        f"head_lr:{args.head_lr if base_model.lm_head is not None else 0.0} "
        f"matrix_lr:{args.matrix_lr} scalar_lr:{args.scalar_lr}"
    )
    log0(
        f"train_batch_tokens:{args.train_batch_tokens} train_seq_len:{args.train_seq_len} "
        f"iterations:{args.iterations} warmup_steps:{args.warmup_steps} "
        f"max_wallclock_seconds:{args.max_wallclock_seconds:.3f}"
    )
    log0(f"seed:{args.seed}")

    # -----------------------------
    # DATA LOADER & MODEL WARMUP
    # -----------------------------

    train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device)

    def zero_grad_all() -> None:
        for opt in optimizers:
            opt.zero_grad(set_to_none=True)

    max_wallclock_ms = 1000.0 * args.max_wallclock_seconds if args.max_wallclock_seconds > 0 else None

    def lr_mul(step: int, elapsed_ms: float) -> float:
        # LR multiplier: 1.0 during the main phase, linearly decaying to 0 over
        # the warmdown. With a wallclock cap the warmdown window is estimated
        # from the average step time instead of a fixed step index.
        if args.warmdown_iters <= 0:
            return 1.0
        if max_wallclock_ms is None:
            warmdown_start = max(args.iterations - args.warmdown_iters, 0)
            return max((args.iterations - step) / max(args.warmdown_iters, 1), 0.0) if warmdown_start <= step < args.iterations else 1.0
        step_ms = elapsed_ms / max(step, 1)
        warmdown_ms = args.warmdown_iters * step_ms
        remaining_ms = max(max_wallclock_ms - elapsed_ms, 0.0)
        return remaining_ms / max(warmdown_ms, 1e-9) if remaining_ms <= warmdown_ms else 1.0

    # Warmup primes the compiled forward/backward/optimizer paths, then we restore the
    # initial weights/optimizer state so measured training starts from the true init.
    if args.warmup_steps > 0:
        initial_model_state = {name: tensor.detach().cpu().clone() for name, tensor in base_model.state_dict().items()}
        initial_optimizer_states = [copy.deepcopy(opt.state_dict()) for opt in optimizers]
        model.train()
        for warmup_step in range(args.warmup_steps):
            zero_grad_all()
            for micro_step in range(grad_accum_steps):
                if distributed:
                    # Only all-reduce grads on the last micro-step.
                    model.require_backward_grad_sync = micro_step == grad_accum_steps - 1
                x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps)
                with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True):
                    warmup_loss = model(x, y)
                (warmup_loss * grad_scale).backward()
            for opt in optimizers:
                opt.step()
            zero_grad_all()
            if args.warmup_steps <= 20 or (warmup_step + 1) % 10 == 0 or warmup_step + 1 == args.warmup_steps:
                log0(f"warmup_step:{warmup_step + 1}/{args.warmup_steps}")
        base_model.load_state_dict(initial_model_state, strict=True)
        for opt, state in zip(optimizers, initial_optimizer_states, strict=True):
            opt.load_state_dict(state)
        zero_grad_all()
        if distributed:
            model.require_backward_grad_sync = True
        # Fresh loader so measured training consumes the data stream from the start.
        train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device)

    # -----------------------------
    # MAIN TRAINING LOOP
    # -----------------------------

    swa_state: dict[str, Tensor] | None = None  # running SUM of checkpoints (divided at the end)
    swa_count = 0

    training_time_ms = 0.0
    stop_after_step: int | None = None  # set once the wallclock cap is hit
    torch.cuda.synchronize()
    t0 = time.perf_counter()

    step = 0
    while True:
        last_step = step == args.iterations or (stop_after_step is not None and step >= stop_after_step)

        should_validate = last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0)
        if should_validate:
            # Pause the training clock around validation.
            torch.cuda.synchronize()
            training_time_ms += 1000.0 * (time.perf_counter() - t0)
            val_loss, val_bpb = eval_val(
                args,
                model,
                rank,
                world_size,
                device,
                grad_accum_steps,
                val_tokens,
                base_bytes_lut,
                has_leading_space_lut,
                is_boundary_token_lut,
            )
            log0(
                f"step:{step}/{args.iterations} val_loss:{val_loss:.4f} val_bpb:{val_bpb:.4f} "
                f"train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms / max(step, 1):.2f}ms"
            )
            torch.cuda.synchronize()
            t0 = time.perf_counter()

        if last_step:
            if stop_after_step is not None and step < args.iterations:
                log0(
                    f"stopping_early: wallclock_cap train_time:{training_time_ms:.0f}ms "
                    f"step:{step}/{args.iterations}"
                )
            break

        elapsed_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0)
        scale = lr_mul(step, elapsed_ms)
        # Optionally switch QAT on late in training once the LR has decayed enough.
        if args.late_qat_threshold > 0 and scale < args.late_qat_threshold and not CastedLinear._qat_enabled:
            CastedLinear._qat_enabled = True
            log0(f"late_qat:enabled step:{step} scale:{scale:.4f}")
        zero_grad_all()
        train_loss = torch.zeros((), device=device)
        for micro_step in range(grad_accum_steps):
            if distributed:
                model.require_backward_grad_sync = micro_step == grad_accum_steps - 1
            x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps)
            with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True):
                loss = model(x, y)
            train_loss += loss.detach()
            (loss * grad_scale).backward()
        train_loss /= grad_accum_steps

        # Linear Muon momentum warmup from warmup_start to the configured value.
        frac = min(step / args.muon_momentum_warmup_steps, 1.0) if args.muon_momentum_warmup_steps > 0 else 1.0
        muon_momentum = (1 - frac) * args.muon_momentum_warmup_start + frac * args.muon_momentum
        for group in optimizer_muon.param_groups:
            group["momentum"] = muon_momentum

        for opt in optimizers:
            for group in opt.param_groups:
                group["lr"] = group["base_lr"] * scale

        if args.grad_clip_norm > 0:
            torch.nn.utils.clip_grad_norm_(base_model.parameters(), args.grad_clip_norm)
        for opt in optimizers:
            opt.step()
        zero_grad_all()

        step += 1
        approx_training_time_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0)

        # SWA kicks in once the LR has decayed below 0.2x; checkpoints are summed
        # on CPU and averaged after the loop.
        if args.swa_enabled and scale < 0.2 and step % args.swa_every == 0:
            if swa_state is None:
                swa_state = {name: t.detach().cpu().clone() for name, t in base_model.state_dict().items()}
                swa_count = 1
                log0(f"swa:start step:{step}")
            else:
                for name, t in base_model.state_dict().items():
                    swa_state[name] += t.detach().cpu()
                swa_count += 1

        should_log_train = (
            args.train_log_every > 0
            and (step <= 10 or step % args.train_log_every == 0 or stop_after_step is not None)
        )
        if should_log_train:
            log0(
                f"step:{step}/{args.iterations} train_loss:{train_loss.item():.4f} "
                f"train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms / step:.2f}ms"
            )

        # Needed to sync whether we've reached the wallclock cap.
        reached_cap = max_wallclock_ms is not None and approx_training_time_ms >= max_wallclock_ms
        if distributed and max_wallclock_ms is not None:
            # MAX-reduce so all ranks agree to stop if any rank hit the cap.
            reached_cap_tensor = torch.tensor(int(reached_cap), device=device)
            dist.all_reduce(reached_cap_tensor, op=dist.ReduceOp.MAX)
            reached_cap = bool(reached_cap_tensor.item())
        if stop_after_step is None and reached_cap:
            stop_after_step = step

    log0(
        f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB "
        f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB"
    )

    if args.swa_enabled and swa_state is not None and swa_count > 1:
        log0(f"swa:applying averaged {swa_count} checkpoints")
        avg_state = {name: (t / swa_count).to(dtype=base_model.state_dict()[name].dtype)
                     for name, t in swa_state.items()}
        base_model.load_state_dict(avg_state, strict=True)

        # Diagnostic eval: measure quality after SWA, before quantization
        # NOTE(review): nested under the SWA branch here (the "post_swa" log tag
        # suggests it only runs when an average was applied) — confirm intent.
        torch.cuda.synchronize()
        t_diag = time.perf_counter()
        diag_val_loss, diag_val_bpb = eval_val(
            args, compiled_model, rank, world_size, device, grad_accum_steps,
            val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut,
        )
        torch.cuda.synchronize()
        log0(
            f"DIAGNOSTIC post_swa val_loss:{diag_val_loss:.4f} val_bpb:{diag_val_bpb:.4f} "
            f"eval_time:{1000.0 * (time.perf_counter() - t_diag):.0f}ms"
        )

    # -----------------------------
    # SERIALIZATION + ROUNDTRIP VALIDATION
    # -----------------------------

    full_state_dict = base_model.state_dict()
    # MTP heads are training-only; drop them from the exported checkpoint.
    export_sd = {k: v for k, v in full_state_dict.items() if "mtp_heads" not in k}
    excluded_mtp = sum(int(t.numel()) for k, t in full_state_dict.items() if "mtp_heads" in k)
    if excluded_mtp > 0:
        log0(f"export_excluding_mtp_params:{excluded_mtp}")

    if master_process:
        torch.save(export_sd, "final_model.pt")
        model_bytes = os.path.getsize("final_model.pt")
        code_bytes = len(code.encode("utf-8"))
        log0(f"Serialized model: {model_bytes} bytes")
        log0(f"Code size: {code_bytes} bytes")

    # Quantize to int6 (mlp/attn/mos tensors), then compress the pickled payload.
    sd_cpu = {k: v.detach().cpu() for k, v in export_sd.items()}
    quant_result, quant_meta = mixed_quantize_int6(sd_cpu, {"mlp", "attn", "mos"})
    quant_buf = io.BytesIO()
    torch.save({"w": quant_result, "m": quant_meta}, quant_buf)
    quant_raw = quant_buf.getvalue()
    quant_blob = zstandard.ZstdCompressor(level=22).compress(quant_raw) if _COMPRESSOR == "zstd" else zlib.compress(quant_raw, 9)
    if master_process:
        with open("final_model.int6.ptz", "wb") as f:
            f.write(quant_blob)
        quant_file_bytes = len(quant_blob)
        code_bytes = len(code.encode("utf-8"))
        log0(f"Serialized model int6+{_COMPRESSOR}: {quant_file_bytes} bytes")
        log0(f"Total submission size int6+{_COMPRESSOR}: {quant_file_bytes + code_bytes} bytes")

    # Roundtrip: decompress + dequantize into fresh model + eval
    if distributed:
        dist.barrier()
    with open("final_model.int6.ptz", "rb") as f:
        quant_blob_disk = f.read()
    quant_state = torch.load(
        io.BytesIO(zstandard.ZstdDecompressor().decompress(quant_blob_disk) if _COMPRESSOR == "zstd" else zlib.decompress(quant_blob_disk)),
        map_location="cpu",
    )
    deq_state = dequantize_mixed_int6(quant_state["w"], quant_state["m"], sd_cpu)

    # Fresh model with MTP disabled (those heads are excluded from the export).
    eval_model = GPT(
        vocab_size=args.vocab_size, num_layers=args.num_layers, model_dim=args.model_dim,
        num_heads=args.num_heads, num_kv_heads=args.num_kv_heads, mlp_mult=args.mlp_mult,
        tie_embeddings=args.tie_embeddings, tied_embed_init_std=args.tied_embed_init_std,
        logit_softcap=args.logit_softcap, rope_base=args.rope_base, qk_gain_init=args.qk_gain_init,
        mtp_num_heads=0, mtp_loss_weight=0.0,
        bigram_vocab_size=args.bigram_vocab_size, bigram_dim=args.bigram_dim,
        xsa_last_n=args.xsa_last_n,  # must match training model
        rope_dims=args.rope_dims, ln_scale=args.ln_scale, dtg=args.dtg_enabled,
        ve_enabled=args.ve_enabled, ve_dim=args.ve_dim, ve_layers=args.ve_layers,
        use_mos=args.use_mos, mos_k=args.mos_k,
    ).to(device).bfloat16()
    for m in eval_model.modules():
        if isinstance(m, CastedLinear):
            m.float()
    restore_low_dim_params_to_fp32(eval_model)
    eval_model.load_state_dict(deq_state, strict=True)
    compiled_eval = eval_model if _DISABLE_COMPILE else torch.compile(eval_model, dynamic=False, fullgraph=True)

    # Standard non-overlapping eval (sanity check)
    torch.cuda.synchronize()
    t_qeval = time.perf_counter()
    q_val_loss, q_val_bpb = eval_val(
        args, compiled_eval, rank, world_size, device, grad_accum_steps,
        val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut,
        eval_seq_len=effective_eval_seq_len,
    )
    torch.cuda.synchronize()
    log0(
        f"final_int6_roundtrip val_loss:{q_val_loss:.4f} val_bpb:{q_val_bpb:.4f} "
        f"eval_time:{1000.0 * (time.perf_counter() - t_qeval):.0f}ms"
    )
    log0(f"final_int6_roundtrip_exact val_loss:{q_val_loss:.8f} val_bpb:{q_val_bpb:.8f}")

    # Sliding window eval (submission score)
    sw_seq_len = effective_eval_seq_len
    if args.eval_stride > 0 and args.eval_stride < sw_seq_len:
        torch.cuda.synchronize()
        t_slide = time.perf_counter()
        # Uncompiled model: window lengths vary, which would retrigger compilation.
        sw_val_loss, sw_val_bpb = eval_val_sliding(
            args, eval_model, rank, world_size, device,
            val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut,
            stride=args.eval_stride,
            eval_seq_len=sw_seq_len,
        )
        torch.cuda.synchronize()
        log0(
            f"final_int6_sliding_window val_loss:{sw_val_loss:.4f} val_bpb:{sw_val_bpb:.4f} "
            f"stride:{args.eval_stride} eval_time:{1000.0 * (time.perf_counter() - t_slide):.0f}ms"
        )
        log0(f"final_int6_sliding_window_exact val_loss:{sw_val_loss:.8f} val_bpb:{sw_val_bpb:.8f}")

    # Second sliding window eval at stride=64 for submission comparison
    if args.eval_stride != 64 and 64 < sw_seq_len:
        torch.cuda.synchronize()
        t_slide64 = time.perf_counter()
        sw64_val_loss, sw64_val_bpb = eval_val_sliding(
            args, eval_model, rank, world_size, device,
            val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut,
            stride=64,
            eval_seq_len=sw_seq_len,
        )
torch.cuda.synchronize() + log0( + f"final_int6_sliding_window_s64 val_loss:{sw64_val_loss:.4f} val_bpb:{sw64_val_bpb:.4f} " + f"stride:64 eval_time:{1000.0 * (time.perf_counter() - t_slide64):.0f}ms" + ) + log0(f"final_int6_sliding_window_s64_exact val_loss:{sw64_val_loss:.8f} val_bpb:{sw64_val_bpb:.8f}") + + if distributed: + dist.destroy_process_group() + + +if __name__ == "__main__": + main() diff --git a/train_tokenizer_only.sh b/train_tokenizer_only.sh new file mode 100755 index 0000000000..26971f6eee --- /dev/null +++ b/train_tokenizer_only.sh @@ -0,0 +1,122 @@ +#!/bin/bash +# ============================================================================= +# Parameter Golf: Train Custom Tokenizer Only +# ============================================================================= +# +# This script: +# 1. Downloads docs_selected.jsonl (~45GB) +# 2. Trains unigram tokenizer +# 3. Evaluates against baseline +# +# Output: data/tokenizers_custom/spm_unigram_1024.model (~1MB) +# +# Usage on RunPod: +# git clone https://github.com/User123331/parameter-golf.git +# cd parameter-golf +# bash train_tokenizer_only.sh +# ============================================================================= + +set -e + +VOCAB_SIZE=1024 +MODEL_TYPE=unigram +MAX_TRAIN_DOCS=200000 +EVAL_DOCS=10000 + +DATA_DIR="./data/datasets" +TOKENIZER_DIR="./data/tokenizers_custom" +DOCS_JSONL="${DATA_DIR}/docs_selected.jsonl" + +GREEN='\033[0;32m' +CYAN='\033[0;36m' +NC='\033[0m' +log() { echo -e "${GREEN}[$(date +%H:%M:%S)]${NC} $*"; } + +echo "" +echo "============================================================" +echo " Custom Tokenizer Training" +echo " Vocab: ${VOCAB_SIZE}, Type: ${MODEL_TYPE}" +echo "============================================================" +echo "" + +# Check disk +AVAIL_GB=$(df -BG . 
2>/dev/null | tail -1 | awk '{print $4}' | tr -d 'G' || echo "?") +log "Available disk: ${AVAIL_GB} GB (need ~50GB)" + +# ============================================================================= +# Step 1: Download docs_selected.jsonl +# ============================================================================= +if [ -f "${DOCS_JSONL}" ]; then + log "Docs already exist: $(du -h "${DOCS_JSONL}" | cut -f1)" +else + log "Downloading docs_selected.jsonl (~45GB, 10-30 min)..." + + mkdir -p "${DATA_DIR}" + pip install --quiet huggingface_hub + + python3 -c " +from huggingface_hub import hf_hub_download +import shutil, os + +cached = hf_hub_download( + repo_id='willdepueoai/parameter-golf', + filename='docs_selected.jsonl', + subfolder='datasets', + repo_type='dataset', +) +src = os.path.realpath(cached) +dst = '${DOCS_JSONL}' +print(f'Copying {src} -> {dst}') +try: + os.link(src, dst) + print('Hard-linked (no extra disk)') +except OSError: + shutil.copy2(src, dst) + print('Copied') +" + log "Download complete: $(du -h "${DOCS_JSONL}" | cut -f1)" +fi + +# ============================================================================= +# Step 2: Train unigram tokenizer +# ============================================================================= +log "Training ${MODEL_TYPE} tokenizer (vocab=${VOCAB_SIZE})..." + +mkdir -p "${TOKENIZER_DIR}" +pip install --quiet sentencepiece numpy + +python3 data/train_tokenizer.py \ + --vocab-size ${VOCAB_SIZE} \ + --model-type ${MODEL_TYPE} \ + --docs-path "${DOCS_JSONL}" \ + --max-docs ${MAX_TRAIN_DOCS} \ + --eval-docs ${EVAL_DOCS} \ + --character-coverage 0.995 + +# ============================================================================= +# Step 3: Evaluate baseline for comparison +# ============================================================================= +log "" +log "Evaluating baseline tokenizer for comparison..." 
+ +BASELINE_MODEL="./data/tokenizers/fineweb_1024_bpe.model" +if [ -f "${BASELINE_MODEL}" ]; then + python3 data/train_tokenizer.py \ + --evaluate "${BASELINE_MODEL}" \ + --docs-path "${DOCS_JSONL}" \ + --eval-docs ${EVAL_DOCS} +else + log "Baseline not found at ${BASELINE_MODEL}" + log "Download it with: python data/cached_challenge_fineweb.py --variant sp1024" +fi + +# ============================================================================= +# Summary +# ============================================================================= +echo "" +echo "============================================================" +log "TOKENIZER TRAINED: ${TOKENIZER_DIR}/spm_${MODEL_TYPE}_${VOCAB_SIZE}.model" +log "" +log "To download from RunPod:" +log " scp root@:/workspace/parameter-golf/${TOKENIZER_DIR}/spm_${MODEL_TYPE}_${VOCAB_SIZE}.model ." +echo "============================================================" \ No newline at end of file