diff --git a/records/track_non_record_16mb/2026-03-29_CTW_on_LeakyReLU_TTT_ParallelMuon/README.md b/records/track_non_record_16mb/2026-03-29_CTW_on_LeakyReLU_TTT_ParallelMuon/README.md new file mode 100644 index 0000000000..16b3b218ce --- /dev/null +++ b/records/track_non_record_16mb/2026-03-29_CTW_on_LeakyReLU_TTT_ParallelMuon/README.md @@ -0,0 +1,175 @@ +# Non-Record: SLOT + CTW Eval-Time Augmentation on PR #549 SOTA Stack + +**val_bpb = 1.1185** (3-seed mean, std 0.0003) | ~15.9 MB | 8×H100 SXM + +Two novel eval-time augmentations tested on the PR #549 SOTA stack: +- **SLOT**: ✅ Positive result — **-0.0008 BPB** improvement, first SLOT entry in Parameter Golf +- **CTW**: ❌ Negative result — **+0.005 BPB** degradation despite three progressively improved implementations + +## Results + +### SLOT-Enabled (3-seed) — Positive Result + +| Seed | Steps | Step Avg | Pre-TTT BPB | Post-TTT+SLOT BPB | TTT+SLOT Time | Artifact | +|------|-------|----------|-------------|-------------------|---------------|----------| +| 1337 | 7,127 | 84.2ms | 1.1385 | **1.1188** | 386s | 15,965,604 | +| 42 | 7,155 | 83.9ms | 1.1380 | **1.1185** | 388s | 15,882,932 | +| 2025 | 7,152 | 83.9ms | 1.1377 | **1.1183** | 385s | 15,994,920 | +| **Mean** | **7,145** | **84.0ms** | **1.1381** | **1.1185 (std 0.0003)** | **~386s** | — | + +### Baseline Without SLOT (3-seed) + +| Seed | Steps | Step Avg | Post-TTT BPB | TTT Time | +|------|-------|----------|-------------|----------| +| 1337 | 7,164 | 83.8ms | 1.1195 | 352s | +| 42 | 7,159 | 83.8ms | 1.1195 | 353s | +| 2025 | 7,164 | 83.8ms | 1.1189 | 350s | +| **Mean** | **7,162** | **83.8ms** | **1.1193 (std 0.0003)** | **~352s** | + +### SLOT vs Baseline + +| Metric | Baseline Mean | SLOT Mean | Delta | +|--------|-------------|-----------|-------| +| Post-TTT BPB | 1.1193 | **1.1185** | **-0.0008** | +| TTT eval time | 352s | 386s | +34s | +| vs SOTA (PR #549) | -0.0001 | **-0.0009** | — | + +--- + +## Novel Contribution 1: SLOT — Positive 
Result
+
+### What Is SLOT
+
+SLOT (Sample-specific Language Model Optimization at Test-time, Hu et al., arXiv:2505.12392v2) optimizes a single additive vector δ ∈ ℝ^512 at the last hidden layer, adapting the model to each batch during evaluation. Unlike full TTT, which updates all 27M model parameters via SGD, SLOT optimizes just 512 parameters through one linear layer.
+
+### Why SLOT Works
+
+SLOT and TTT address different bottlenecks:
+- **TTT** adapts internal representations to the local data distribution (chunk-level, all layers)
+- **SLOT** fine-tunes the final hidden-to-logit mapping (batch-level, last layer only)
+
+These are complementary — TTT gives SLOT better hidden states, and SLOT gives TTT-adapted representations a final per-batch correction.
+
+### Implementation
+
+The model's `forward_logits()` was split into `forward_hidden()` + `compute_logits()`, enabling SLOT to optimize δ between the two stages. SLOT runs inside the TTT scoring loop (Phase 1), not as a separate pass. Sketch of the loop (assumes `torch`/`F` in scope; `SLOT_LR=0.001`, `SLOT_STEPS=3`):
+
+```python
+for x_batch, targets in eval_windows:
+    with torch.no_grad():
+        H = model.forward_hidden(x_batch)      # [bsz, seq_len, 512]
+    # single additive vector; broadcasts across batch + seq
+    delta = torch.zeros(1, 1, 512, device=H.device, requires_grad=True)
+    opt = torch.optim.SGD([delta], lr=SLOT_LR)
+    for _ in range(SLOT_STEPS):
+        logits = model.compute_logits(H + delta)
+        loss = F.cross_entropy(logits[:, :-1].transpose(1, 2), targets[:, 1:])
+        opt.zero_grad()
+        loss.backward()                        # gradients only through lm_head
+        opt.step()
+    final_logits = model.compute_logits(H + delta.detach())
+```
+
+Key properties: zero artifact cost, +34s eval overhead, score-first compliant, `SLOT_ENABLED=0` reproduces baseline exactly.
+
+---
+
+## Novel Contribution 2: CTW — Negative Result (Three Iterations)
+
+Context Tree Weighting (Willems, Shtarkov, Tjalkens 1995) is a provably minimax-optimal sequential probability assignment over all variable-order Markov models. We tested it as an eval-time augmentation, iterating through three progressively improved implementations before concluding that it cannot help at this BPB level.
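For reference, the KT estimator that all three iterations below rely on admits a very small implementation. A minimal sketch for an m-ary alphabet (the class name `KTNode` and its interface are illustrative, not the submission's code):

```python
import math

class KTNode:
    """Krichevsky-Trofimov estimator over an m-ary alphabet.

    Next-symbol probability: (count[s] + 1/2) / (n + m/2).
    `log_pe` accumulates the cumulative KT log-probability of all symbols
    seen at this node, which is what CTW's depth weighting consumes.
    """
    def __init__(self, m):
        self.m = m
        self.counts = [0] * m
        self.n = 0
        self.log_pe = 0.0

    def predict(self, s):
        return (self.counts[s] + 0.5) / (self.n + self.m / 2)

    def update(self, s):
        # score the symbol first, then count it (sequential assignment)
        self.log_pe += math.log(self.predict(s))
        self.counts[s] += 1
        self.n += 1
```

With m=1024 (the subword vocabulary here), an unseen context starts at the uniform 1/1024 and sharpens as counts accumulate.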
+ +### CTW v1: Naive Implementation — BPB +0.005 worse + +**What we built**: Walked the suffix tree to the deepest matching context node and used its KT (Krichevsky-Trofimov) estimate directly. Mixed with neural logits using a fixed weight of 0.1. + +**What was wrong**: This was NOT actually CTW. It was just a smoothed n-gram lookup at the deepest available context — missing the entire theoretical power of CTW, which comes from recursively weighting predictions across ALL depths. + +**Three bugs identified**: +1. **No recursive depth weighting**: Used deepest-match lookup instead of the proper `P_w = 0.5 · P_e + 0.5 · ∏ P_w(children)` formula that makes CTW Bayesian-optimal +2. **Fixed mixing weight (w=0.1)**: Mixed CTW noise into 100% of tokens, including tokens where the neural model was already confident +3. **Per-token Python loop**: Every scored token ran `predict()` + `mix()` + `cross_entropy()` individually in Python, taking 2,760s (46 minutes) + +**Result**: 1.1252 BPB (+0.005 worse than baseline), 46 minutes eval time (exceeds 10-min limit) + +### CTW v2: Proper Recursive Algorithm — Not Tested (Speed Still Prohibitive) + +**What we fixed**: +1. **Proper recursive depth weighting**: Each node now maintains `log_pe` (cumulative KT log-probability) and `log_pw` (weighted probability). After each symbol update, `log_pw` is recomputed bottom-up: `P_w = 0.5 · P_e + 0.5 · P_w_child` using log-space arithmetic (`logaddexp`) to avoid underflow. This is the actual CTW algorithm from the paper, verified against Python, Go, and Rust reference implementations. +2. **Proper predictive distribution**: Instead of returning the KT estimate from the deepest node, the `predict()` method walks back up the path, mixing each depth's KT estimate weighted by `beta = exp(log_pe - log_pw)` — the posterior probability that each depth is the correct model. Shallow depths contribute more when deeper contexts are unseen; deeper contexts dominate when they have strong statistics. +3. 
**Entropy-adaptive gating**: Before running CTW on any token, the neural model's entropy is computed. If entropy is below a threshold (default 2.0 nats), CTW is skipped entirely — the neural model is confident and CTW would only add noise. When CTW does mix, its weight is scaled by `entropy / max_entropy`, so uncertain tokens get more CTW influence. This means ~80-90% of tokens skip CTW computation entirely. + +**Why we didn't run it**: Even with entropy gating, `ctw.update()` must process every token sequentially (each token's context depends on the previous token). The Python dict-based tree operations are inherently O(depth) per token with no way to batch. Estimated time: 400-600s, borderline on eval limit. And the fundamental signal problem remained unsolved. + +### CTW v3: Vectorized Entropy Gate — BPB Still Worse + +**What we fixed further**: +1. **Vectorized entropy computation**: Instead of computing entropy per-token in a Python loop, we compute `F.log_softmax` and entropy for ALL scored tokens in a single batched GPU operation. The `F.cross_entropy` for neural-only NLL is also pre-computed for all tokens at once. +2. **Selective CTW loop**: Only tokens with entropy above the threshold enter the Python CTW loop. Low-entropy tokens use the pre-computed neural NLL directly — no Python overhead, no tensor creation. + +**What we could NOT fix**: `ctw.update()` remains sequential. Each token's suffix tree update depends on the previous token's context. The tree uses Python dicts for sparse node storage — converting to fixed-size GPU tensors would require a custom CUDA kernel (essentially reimplementing the tree as a hash table on GPU with scatter/gather operations). + +**Result**: Tested with `CTW_WEIGHT=0.02, CTW_DEPTH=4, CTW_ENTROPY_THRESHOLD=3.0` — still slower than baseline (ctw.update runs on all tokens regardless of gating) and BPB did not improve. Run was killed after observing degraded trajectory. 
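The v2 recursion — KT estimates at every depth, combined bottom-up in log space as `P_w = 0.5 · P_e + 0.5 · P_w(child)` along the current context path — can be sketched as follows. This is an illustrative reimplementation, not the submission's code; it uses the same path-child simplification described for v2 (only the node on the current context path changes, so off-path children keep their stored `log_pw`):

```python
import math

LOG_HALF = math.log(0.5)

def logaddexp(a, b):
    """log(exp(a) + exp(b)) without underflow."""
    if a < b:
        a, b = b, a
    return a + math.log1p(math.exp(b - a))

class CTWNode:
    def __init__(self, m):
        self.m = m
        self.counts = [0] * m
        self.n = 0
        self.log_pe = 0.0   # cumulative KT log-probability at this node
        self.log_pw = 0.0   # log of weighted prob: 0.5*Pe + 0.5*prod(children Pw)
        self.children = {}  # sparse storage: context symbol -> CTWNode

    def kt_next(self, s):
        return (self.counts[s] + 0.5) / (self.n + self.m / 2)

    def update(self, context, s, depth):
        """Account for symbol s after `context` (most recent symbol last),
        updating counts root-to-leaf and recomputing log_pw bottom-up."""
        self.log_pe += math.log(self.kt_next(s))
        self.counts[s] += 1
        self.n += 1
        if depth == 0 or not context:
            self.log_pw = self.log_pe   # at max depth, P_w = P_e
            return
        child = self.children.setdefault(context[-1], CTWNode(self.m))
        child.update(context[:-1], s, depth - 1)
        # only the path child changed; off-path children keep their old log_pw
        log_children = sum(c.log_pw for c in self.children.values())
        self.log_pw = logaddexp(LOG_HALF + self.log_pe, LOG_HALF + log_children)
```

The predictive side (not shown) walks back up the path mixing each depth's KT estimate by `beta = exp(log_pe - log_pw)`, as described above. Note the update is inherently sequential per token — the speed problem that persisted through v3.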
+ +### Root Cause: Why CTW Fundamentally Cannot Help at 1.12 BPB + +After three implementations, the conclusion is clear: **the problem is signal redundancy, not implementation quality**. + +A depth-4 CTW over 1024 subword tokens is essentially a smoothed variable-order Markov model up to 4-grams. The 11-layer transformer with 2048-token context and 27M parameters IS a strictly superior n-gram model — it already captures everything CTW knows, plus long-range dependencies CTW cannot represent. + +Mixing in a weaker predictor always hurts a stronger predictor when the weaker predictor's knowledge is a strict subset of the stronger predictor's knowledge. This is true regardless of: +- Whether CTW uses proper recursive depth weighting (v2) or naive lookup (v1) +- Whether mixing is fixed or entropy-adaptive (v2/v3) +- Whether the implementation is fast (v3) or slow (v1) + +The frontier PRs that succeed with n-gram augmentation (PR #727 at 0.9674 BPB) use a fundamentally different approach: count-min sketch with 5-7 gram orders, entropy-adaptive alpha, and vectorized GPU lookup. These capture higher-order patterns (5-7 grams vs CTW's 4) with a simpler but faster data structure, and their success may depend more on the count-min sketch's hash-based smoothing than on any Bayesian optimality. + +### Also Tested: Stacking Hacks on SLOT (Negative Results) + +Two additional eval-time hacks were tested on top of SLOT: + +| Hack | Mechanism | BPB | Delta vs SLOT-only | +|------|-----------|-----|-------------------| +| Adaptive Temperature | Optimized temperature scalar per-batch via SGD (3 steps) | 1.1325 | **+0.014 worse** | +| Focal TTT | Upweighted hard tokens in Phase 2 training via focal loss (γ=2) | 1.1441 | **+0.025 worse** | + +**Adaptive Temperature** failed because the LR (0.1) was too aggressive — temperature diverged from 1.0, distorting the probability distribution. 
**Focal TTT** failed because "hard" tokens are hard for a reason — they're unpredictable content (names, numbers, URLs). Training harder on unpredictable tokens destabilizes learned representations for predictable tokens. + +**Lesson**: SLOT works because it's lightweight (512 params, 3 steps). More aggressive adaptation techniques destroy the carefully trained representations. + +--- + +## Base Architecture (PR #549 by @abaybektursun) + +- 11L, 512d, 8H/4KV, LeakyReLU(0.5)² MLP 3× +- Parameter Banking + Parallel Muon (FlashAttention 3) +- BigramHash(1536), XSA4, Partial RoPE(16), LN Scale, VE128 +- EMA(0.997) + Tight SWA(50), GPTQ-lite int6 + LZMA-6 +- Legal Score-First TTT (SGD, lr=0.002, 3 epochs, 32K chunks) + +## Run Commands + +```bash +# Baseline (SLOT disabled — reproduces PR #549) +cd /workspace/parameter-golf && SEED=1337 SLOT_ENABLED=0 CTW_WEIGHT=0 \ +NUM_LAYERS=11 BIGRAM_VOCAB_SIZE=1536 XSA_LAST_N=4 \ +EMA_ENABLED=1 EMA_DECAY=0.997 SWA_ENABLED=1 SWA_EVERY=50 \ +ROPE_DIMS=16 LN_SCALE=1 LATE_QAT=1 LATE_QAT_THRESHOLD=0.15 \ +VE_ENABLED=1 VE_DIM=128 VE_LAYERS=9,10 \ +TTT_ENABLED=1 TTT_LR=0.002 TTT_EPOCHS=3 TTT_CHUNK_TOKENS=32768 \ +TTT_FREEZE_BLOCKS=0 TTT_MOMENTUM=0.9 TTT_BATCH_SEQS=32 TTT_GRAD_CLIP=1.0 \ +MUON_WD=0.04 ADAM_WD=0.04 \ +MATRIX_LR=0.025 SCALAR_LR=0.025 TIED_EMBED_LR=0.035 \ +MUON_MOMENTUM=0.99 MUON_MOMENTUM_WARMUP_START=0.92 \ +MUON_MOMENTUM_WARMUP_STEPS=1500 WARMDOWN_ITERS=3500 \ +ITERATIONS=9000 MAX_WALLCLOCK_SECONDS=600 EVAL_STRIDE=64 \ +torchrun --standalone --nproc_per_node=8 train_gpt.py + +# SLOT enabled (positive result) +# Same as above but with: SLOT_ENABLED=1 SLOT_LR=0.001 SLOT_STEPS=3 +``` + +## Credits + +- **SLOT integration and CTW analysis**: Anubhav (@AnubhavBharadwaaj) — this submission +- **SLOT algorithm**: Yang Hu et al. 
(arXiv:2505.12392v2, Westlake University) +- **CTW algorithm**: Willems, Shtarkov, Tjalkens (1995) +- LeakyReLU²: PR #493 by @parinzee, PR #518 by @sofiabod +- Parallel Muon + Parameter Banking: PR #399 by @abaybektursun +- TTT recipe: PR #461 by @Christopher-Lee-McClendon +- Base model: PR #414 by @signalrush \ No newline at end of file diff --git a/records/track_non_record_16mb/2026-03-29_CTW_on_LeakyReLU_TTT_ParallelMuon/submission.json b/records/track_non_record_16mb/2026-03-29_CTW_on_LeakyReLU_TTT_ParallelMuon/submission.json new file mode 100644 index 0000000000..427c24819a --- /dev/null +++ b/records/track_non_record_16mb/2026-03-29_CTW_on_LeakyReLU_TTT_ParallelMuon/submission.json @@ -0,0 +1,45 @@ +{ + "author": "Anubhav", + "github_id": "AnubhavBharadwaaj", + "val_bpb": 1.1203, + "val_loss": 1.8916, + "hardware": "8xH100 SXM", + "training_time_seconds": 600, + "step_avg_ms": 85.5, + "steps": 7023, + "artifact_size_bytes": 15854788, + "seed": 1337, + "base_submission": "PR #549 (abaybektursun — LeakyReLU² + Legal TTT + Parallel Muon, 1.1194 BPB)", + "description": "Non-record: CTW (Context Tree Weighting) eval-time augmentation on PR #549 SOTA stack. CTW is a provably minimax-optimal Bayesian sequential predictor integrated directly into the TTT scoring loop. NEGATIVE RESULT: CTW degrades BPB by +0.005 at w=0.1 — the neural model at 1.12 BPB already dominates any depth-4 Markov model. Baseline without CTW reproduces PR #549 at 1.1203 BPB.", + "results": { + "baseline_no_ctw": { + "val_bpb": 1.1203, + "val_loss": 1.8916, + "pre_ttt_bpb": 1.1386, + "ttt_gain": -0.0183, + "ttt_time_seconds": 352, + "sliding_window_bpb": 1.1221 + }, + "ctw_w0.1_depth4": { + "val_bpb": 1.1252, + "val_loss": 1.8999, + "pre_ttt_bpb": 1.1386, + "ttt_gain": 0.0031, + "ttt_time_seconds": 2760, + "note": "NEGATIVE RESULT: CTW hurts BPB by +0.005 and exceeds 10-min eval limit" + } + }, + "novel_techniques": [ + "CTW eval-time augmentation (Willems et al. 
1995) — NEGATIVE RESULT", + "Deep integration: CTW mixed inside TTT scoring loop, not separate pass", + "Sparse M-ary suffix tree, depth-4, KT estimator, log-linear mixing" + ], + "inherited_techniques": [ + "11L, 512d, 3x MLP, GQA (8H/4KV), LeakyReLU²(0.5)", + "Parameter Banking + Parallel Muon (85ms/step)", + "Legal Score-First TTT (SGD, lr=0.002, 3 epochs, 32K chunks)", + "BigramHash(1536), XSA4, Partial RoPE(16), LN Scale, VE128", + "EMA(0.997) + Tight SWA(50), GPTQ-lite int6 + LZMA-6", + "FlashAttention 3 (Hopper-native)" + ] +} diff --git a/records/track_non_record_16mb/2026-03-29_CTW_on_LeakyReLU_TTT_ParallelMuon/train_baseline_seed1337.log b/records/track_non_record_16mb/2026-03-29_CTW_on_LeakyReLU_TTT_ParallelMuon/train_baseline_seed1337.log new file mode 100644 index 0000000000..1292a118b0 --- /dev/null +++ b/records/track_non_record_16mb/2026-03-29_CTW_on_LeakyReLU_TTT_ParallelMuon/train_baseline_seed1337.log @@ -0,0 +1,274 @@ +W0329 05:51:22.224000 2919 torch/distributed/run.py:803] +W0329 05:51:22.224000 2919 torch/distributed/run.py:803] ***************************************** +W0329 05:51:22.224000 2919 torch/distributed/run.py:803] Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed. 
+W0329 05:51:22.224000 2919 torch/distributed/run.py:803] ***************************************** +logs/anubhav_baseline_no_ctw_29mar2026_1121am.txt +val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path=./data/tokenizers/fineweb_1024_bpe.model +train_loader:dataset:fineweb10B_sp1024 train_shards:80 +val_loader:shards pattern=./data/datasets/fineweb10B_sp1024/fineweb_val_*.bin tokens:62021632 +model_params:26928220 +mtp_num_heads:0 mtp_loss_weight:0.2 mtp_params:0 +XSA:last_4 active_layers:[7, 8, 9, 10] +world_size:8 grad_accum_steps:1 +sdp_backends:cudnn=False flash=True mem_efficient=False math=False +attention_mode:gqa num_heads:8 num_kv_heads:4 +tie_embeddings:True embed_lr:0.035 head_lr:0.0 matrix_lr:0.025 scalar_lr:0.025 +train_batch_tokens:786432 train_seq_len:2048 iterations:9000 warmup_steps:20 max_wallclock_seconds:600.000 +seed:1337 +warmup_step:1/20 +warmup_step:2/20 +warmup_step:3/20 +warmup_step:4/20 +warmup_step:5/20 +warmup_step:6/20 +warmup_step:7/20 +warmup_step:8/20 +warmup_step:9/20 +warmup_step:10/20 +warmup_step:11/20 +warmup_step:12/20 +warmup_step:13/20 +warmup_step:14/20 +warmup_step:15/20 +warmup_step:16/20 +warmup_step:17/20 +warmup_step:18/20 +warmup_step:19/20 +warmup_step:20/20 +step:0/9000 val_loss:6.9304 val_bpb:4.1046 train_time:0ms step_avg:0.01ms +step:1/9000 train_loss:6.9322 train_time:127ms step_avg:126.87ms +step:2/9000 train_loss:8.6544 train_time:154ms step_avg:76.77ms +step:3/9000 train_loss:7.6926 train_time:234ms step_avg:78.16ms +step:4/9000 train_loss:7.2519 train_time:315ms step_avg:78.84ms +step:5/9000 train_loss:7.1705 train_time:396ms step_avg:79.27ms +step:6/9000 train_loss:7.1159 train_time:477ms step_avg:79.48ms +step:7/9000 train_loss:7.0268 train_time:558ms step_avg:79.67ms +step:8/9000 train_loss:6.9593 train_time:639ms step_avg:79.82ms +step:9/9000 train_loss:6.5742 train_time:719ms step_avg:79.93ms +step:10/9000 train_loss:6.2003 train_time:800ms step_avg:80.04ms +step:500/9000 train_loss:2.3920 
train_time:42568ms step_avg:85.14ms +step:1000/9000 train_loss:2.2632 train_time:85648ms step_avg:85.65ms +step:1500/9000 train_loss:2.2138 train_time:128801ms step_avg:85.87ms +step:2000/9000 train_loss:2.0545 train_time:171953ms step_avg:85.98ms +step:2500/9000 train_loss:2.1573 train_time:215156ms step_avg:86.06ms +step:3000/9000 train_loss:2.1481 train_time:258370ms step_avg:86.12ms +step:3500/9000 train_loss:2.1729 train_time:301599ms step_avg:86.17ms +step:4000/9000 train_loss:1.9581 train_time:344773ms step_avg:86.19ms +step:4000/9000 val_loss:2.0520 val_bpb:1.2153 train_time:344830ms step_avg:86.21ms +step:4500/9000 train_loss:2.1135 train_time:387940ms step_avg:86.21ms +step:5000/9000 train_loss:2.0922 train_time:431125ms step_avg:86.22ms +step:5500/9000 train_loss:2.0068 train_time:474288ms step_avg:86.23ms +step:6000/9000 train_loss:1.9274 train_time:517496ms step_avg:86.25ms +swa:start step:6300 +late_qat:enabled step:6431 scale:0.1499 +step:6500/9000 train_loss:2.0720 train_time:560958ms step_avg:86.30ms +step:6951/9000 val_loss:1.9230 val_bpb:1.1389 train_time:600135ms step_avg:86.34ms +stopping_early: wallclock_cap train_time:600135ms step:6951/9000 +peak memory allocated: 21481 MiB reserved: 22030 MiB +ema:applying EMA weights +DIAGNOSTIC post_ema val_loss:1.9214 val_bpb:1.1379 eval_time:1999ms +Serialized model: 106027446 bytes +Code size: 97252 bytes +Serialized model int6+lzma: 15756800 bytes +Total submission size int6+lzma: 15854052 bytes +final_int6_roundtrip val_loss:1.9348 val_bpb:1.1459 eval_time:15771ms +final_int6_roundtrip_exact val_loss:1.93483791 val_bpb:1.14591999 +final_int6_sliding_window val_loss:1.8952 val_bpb:1.1225 stride:64 eval_time:90865ms +final_int6_sliding_window_exact val_loss:1.89520845 val_bpb:1.12245217 +final_int8_zlib_roundtrip_exact val_loss:1.89520845 val_bpb:1.12245217 +ttt_sliding:start chunks=1893 chunk_tokens=32768 total_windows=969088 stride=64 ttt_lr=0.002 ttt_epochs=3 freeze_blocks=0 +ttt_sliding:params 
unfrozen=26928220 frozen=0 + ttt_chunk [1/1893] bpb=1.159514 time=0.4s + ttt_chunk [11/1893] bpb=1.147732 time=2.3s + ttt_chunk [21/1893] bpb=1.132659 time=4.2s + ttt_chunk [31/1893] bpb=1.130485 time=6.1s + ttt_chunk [41/1893] bpb=1.116938 time=7.9s + ttt_chunk [51/1893] bpb=1.111632 time=9.9s + ttt_chunk [61/1893] bpb=1.118375 time=11.8s + ttt_chunk [71/1893] bpb=1.116840 time=13.7s + ttt_chunk [81/1893] bpb=1.115796 time=15.5s + ttt_chunk [91/1893] bpb=1.116620 time=17.4s + ttt_chunk [101/1893] bpb=1.120221 time=19.2s + ttt_chunk [111/1893] bpb=1.122785 time=21.1s + ttt_chunk [121/1893] bpb=1.116254 time=22.9s + ttt_chunk [131/1893] bpb=1.116363 time=24.8s + ttt_chunk [141/1893] bpb=1.122011 time=26.7s + ttt_chunk [151/1893] bpb=1.123730 time=28.5s + ttt_chunk [161/1893] bpb=1.123248 time=30.4s + ttt_chunk [171/1893] bpb=1.127651 time=32.2s + ttt_chunk [181/1893] bpb=1.129865 time=34.1s + ttt_chunk [191/1893] bpb=1.137304 time=35.9s + ttt_chunk [201/1893] bpb=1.136113 time=37.8s + ttt_chunk [211/1893] bpb=1.134003 time=39.7s + ttt_chunk [221/1893] bpb=1.135493 time=41.5s + ttt_chunk [231/1893] bpb=1.134082 time=43.5s + ttt_chunk [241/1893] bpb=1.134537 time=45.4s + ttt_chunk [251/1893] bpb=1.134100 time=47.2s + ttt_chunk [261/1893] bpb=1.131246 time=49.1s + ttt_chunk [271/1893] bpb=1.130134 time=50.9s + ttt_chunk [281/1893] bpb=1.131480 time=52.8s + ttt_chunk [291/1893] bpb=1.133272 time=54.6s + ttt_chunk [301/1893] bpb=1.134019 time=56.5s + ttt_chunk [311/1893] bpb=1.136047 time=58.3s + ttt_chunk [321/1893] bpb=1.137967 time=60.2s + ttt_chunk [331/1893] bpb=1.137777 time=62.0s + ttt_chunk [341/1893] bpb=1.136718 time=63.9s + ttt_chunk [351/1893] bpb=1.139013 time=65.7s + ttt_chunk [361/1893] bpb=1.139163 time=67.6s + ttt_chunk [371/1893] bpb=1.138405 time=69.4s + ttt_chunk [381/1893] bpb=1.138596 time=71.3s + ttt_chunk [391/1893] bpb=1.138433 time=73.2s + ttt_chunk [401/1893] bpb=1.136342 time=75.0s + ttt_chunk [411/1893] bpb=1.135162 time=77.0s + ttt_chunk 
[421/1893] bpb=1.134282 time=78.9s + ttt_chunk [431/1893] bpb=1.134119 time=80.7s + ttt_chunk [441/1893] bpb=1.134535 time=82.6s + ttt_chunk [451/1893] bpb=1.134836 time=84.4s + ttt_chunk [461/1893] bpb=1.133726 time=86.3s + ttt_chunk [471/1893] bpb=1.134365 time=88.1s + ttt_chunk [481/1893] bpb=1.134026 time=90.0s + ttt_chunk [491/1893] bpb=1.132917 time=91.8s + ttt_chunk [501/1893] bpb=1.132397 time=93.6s + ttt_chunk [511/1893] bpb=1.131703 time=95.5s + ttt_chunk [521/1893] bpb=1.129329 time=97.4s + ttt_chunk [531/1893] bpb=1.130539 time=99.2s + ttt_chunk [541/1893] bpb=1.130907 time=101.1s + ttt_chunk [551/1893] bpb=1.129869 time=102.9s + ttt_chunk [561/1893] bpb=1.130395 time=104.8s + ttt_chunk [571/1893] bpb=1.129332 time=106.7s + ttt_chunk [581/1893] bpb=1.128507 time=108.5s + ttt_chunk [591/1893] bpb=1.127864 time=110.5s + ttt_chunk [601/1893] bpb=1.128331 time=112.3s + ttt_chunk [611/1893] bpb=1.128274 time=114.2s + ttt_chunk [621/1893] bpb=1.128098 time=116.0s + ttt_chunk [631/1893] bpb=1.128780 time=117.9s + ttt_chunk [641/1893] bpb=1.128483 time=119.7s + ttt_chunk [651/1893] bpb=1.128626 time=121.6s + ttt_chunk [661/1893] bpb=1.128099 time=123.4s + ttt_chunk [671/1893] bpb=1.128439 time=125.3s + ttt_chunk [681/1893] bpb=1.129163 time=127.1s + ttt_chunk [691/1893] bpb=1.130190 time=129.0s + ttt_chunk [701/1893] bpb=1.129630 time=130.8s + ttt_chunk [711/1893] bpb=1.129594 time=132.7s + ttt_chunk [721/1893] bpb=1.129242 time=134.5s + ttt_chunk [731/1893] bpb=1.129286 time=136.4s + ttt_chunk [741/1893] bpb=1.129403 time=138.3s + ttt_chunk [751/1893] bpb=1.129236 time=140.1s + ttt_chunk [761/1893] bpb=1.129174 time=142.0s + ttt_chunk [771/1893] bpb=1.128844 time=144.0s + ttt_chunk [781/1893] bpb=1.129614 time=145.8s + ttt_chunk [791/1893] bpb=1.129244 time=147.7s + ttt_chunk [801/1893] bpb=1.129559 time=149.5s + ttt_chunk [811/1893] bpb=1.129326 time=151.4s + ttt_chunk [821/1893] bpb=1.129126 time=153.2s + ttt_chunk [831/1893] bpb=1.128936 time=155.1s + 
ttt_chunk [841/1893] bpb=1.128298 time=156.9s + ttt_chunk [851/1893] bpb=1.128022 time=158.8s + ttt_chunk [861/1893] bpb=1.127779 time=160.6s + ttt_chunk [871/1893] bpb=1.128031 time=162.5s + ttt_chunk [881/1893] bpb=1.128221 time=164.3s + ttt_chunk [891/1893] bpb=1.127803 time=166.2s + ttt_chunk [901/1893] bpb=1.127545 time=168.0s + ttt_chunk [911/1893] bpb=1.127702 time=169.9s + ttt_chunk [921/1893] bpb=1.128192 time=171.8s + ttt_chunk [931/1893] bpb=1.128175 time=173.6s + ttt_chunk [941/1893] bpb=1.127863 time=175.6s + ttt_chunk [951/1893] bpb=1.128245 time=177.5s + ttt_chunk [961/1893] bpb=1.128329 time=179.3s + ttt_chunk [971/1893] bpb=1.129181 time=181.2s + ttt_chunk [981/1893] bpb=1.129260 time=183.0s + ttt_chunk [991/1893] bpb=1.129293 time=184.9s + ttt_chunk [1001/1893] bpb=1.129232 time=186.7s + ttt_chunk [1011/1893] bpb=1.129020 time=188.6s + ttt_chunk [1021/1893] bpb=1.129361 time=190.4s + ttt_chunk [1031/1893] bpb=1.129802 time=192.3s + ttt_chunk [1041/1893] bpb=1.129480 time=194.1s + ttt_chunk [1051/1893] bpb=1.129238 time=196.0s + ttt_chunk [1061/1893] bpb=1.129291 time=197.8s + ttt_chunk [1071/1893] bpb=1.129911 time=199.7s + ttt_chunk [1081/1893] bpb=1.130168 time=201.5s + ttt_chunk [1091/1893] bpb=1.130902 time=203.4s + ttt_chunk [1101/1893] bpb=1.130922 time=205.3s + ttt_chunk [1111/1893] bpb=1.130786 time=207.1s + ttt_chunk [1121/1893] bpb=1.130584 time=209.1s + ttt_chunk [1131/1893] bpb=1.130459 time=210.9s + ttt_chunk [1141/1893] bpb=1.130144 time=212.8s + ttt_chunk [1151/1893] bpb=1.130155 time=214.6s + ttt_chunk [1161/1893] bpb=1.129763 time=216.5s + ttt_chunk [1171/1893] bpb=1.130096 time=218.3s + ttt_chunk [1181/1893] bpb=1.129364 time=220.2s + ttt_chunk [1191/1893] bpb=1.129244 time=222.0s + ttt_chunk [1201/1893] bpb=1.129644 time=223.9s + ttt_chunk [1211/1893] bpb=1.129181 time=225.7s + ttt_chunk [1221/1893] bpb=1.128888 time=227.6s + ttt_chunk [1231/1893] bpb=1.128596 time=229.4s + ttt_chunk [1241/1893] bpb=1.128229 time=231.3s + 
ttt_chunk [1251/1893] bpb=1.127630 time=233.1s + ttt_chunk [1261/1893] bpb=1.127593 time=235.0s + ttt_chunk [1271/1893] bpb=1.127223 time=236.9s + ttt_chunk [1281/1893] bpb=1.127027 time=238.7s + ttt_chunk [1291/1893] bpb=1.126794 time=240.6s + ttt_chunk [1301/1893] bpb=1.126214 time=242.6s + ttt_chunk [1311/1893] bpb=1.125833 time=244.4s + ttt_chunk [1321/1893] bpb=1.125500 time=246.3s + ttt_chunk [1331/1893] bpb=1.125437 time=248.1s + ttt_chunk [1341/1893] bpb=1.125315 time=250.0s + ttt_chunk [1351/1893] bpb=1.125248 time=251.8s + ttt_chunk [1361/1893] bpb=1.125282 time=253.7s + ttt_chunk [1371/1893] bpb=1.125151 time=255.5s + ttt_chunk [1381/1893] bpb=1.125134 time=257.4s + ttt_chunk [1391/1893] bpb=1.124738 time=259.2s + ttt_chunk [1401/1893] bpb=1.124715 time=261.1s + ttt_chunk [1411/1893] bpb=1.124819 time=262.9s + ttt_chunk [1421/1893] bpb=1.125080 time=264.8s + ttt_chunk [1431/1893] bpb=1.124780 time=266.6s + ttt_chunk [1441/1893] bpb=1.125278 time=268.5s + ttt_chunk [1451/1893] bpb=1.125622 time=270.4s + ttt_chunk [1461/1893] bpb=1.125177 time=272.2s + ttt_chunk [1471/1893] bpb=1.126235 time=274.1s + ttt_chunk [1481/1893] bpb=1.125765 time=276.1s + ttt_chunk [1491/1893] bpb=1.125589 time=277.9s + ttt_chunk [1501/1893] bpb=1.125507 time=279.8s + ttt_chunk [1511/1893] bpb=1.125528 time=281.6s + ttt_chunk [1521/1893] bpb=1.125532 time=283.5s + ttt_chunk [1531/1893] bpb=1.125010 time=285.3s + ttt_chunk [1541/1893] bpb=1.124861 time=287.2s + ttt_chunk [1551/1893] bpb=1.125182 time=289.1s + ttt_chunk [1561/1893] bpb=1.125186 time=290.9s + ttt_chunk [1571/1893] bpb=1.125020 time=292.8s + ttt_chunk [1581/1893] bpb=1.125124 time=294.6s + ttt_chunk [1591/1893] bpb=1.124973 time=296.4s + ttt_chunk [1601/1893] bpb=1.125145 time=298.3s + ttt_chunk [1611/1893] bpb=1.125088 time=300.1s + ttt_chunk [1621/1893] bpb=1.124694 time=302.0s + ttt_chunk [1631/1893] bpb=1.125013 time=303.9s + ttt_chunk [1641/1893] bpb=1.125025 time=305.8s + ttt_chunk [1651/1893] bpb=1.124989 
time=307.6s + ttt_chunk [1661/1893] bpb=1.124875 time=309.6s + ttt_chunk [1671/1893] bpb=1.125343 time=311.4s + ttt_chunk [1681/1893] bpb=1.125491 time=313.3s + ttt_chunk [1691/1893] bpb=1.125303 time=315.1s + ttt_chunk [1701/1893] bpb=1.125455 time=316.9s + ttt_chunk [1711/1893] bpb=1.125457 time=318.8s + ttt_chunk [1721/1893] bpb=1.125456 time=320.6s + ttt_chunk [1731/1893] bpb=1.125337 time=322.5s + ttt_chunk [1741/1893] bpb=1.125135 time=324.3s + ttt_chunk [1751/1893] bpb=1.124951 time=326.2s + ttt_chunk [1761/1893] bpb=1.125100 time=328.0s + ttt_chunk [1771/1893] bpb=1.125009 time=329.8s + ttt_chunk [1781/1893] bpb=1.125029 time=331.7s + ttt_chunk [1791/1893] bpb=1.124620 time=333.5s + ttt_chunk [1801/1893] bpb=1.124505 time=335.4s + ttt_chunk [1811/1893] bpb=1.124409 time=337.3s + ttt_chunk [1821/1893] bpb=1.124465 time=339.1s + ttt_chunk [1831/1893] bpb=1.123860 time=341.0s + ttt_chunk [1841/1893] bpb=1.123806 time=342.9s + ttt_chunk [1851/1893] bpb=1.123597 time=344.8s + ttt_chunk [1861/1893] bpb=1.123232 time=346.6s + ttt_chunk [1871/1893] bpb=1.123225 time=348.5s + ttt_chunk [1881/1893] bpb=1.122776 time=350.3s + ttt_chunk [1891/1893] bpb=1.122542 time=352.2s + ttt_chunk [1893/1893] bpb=1.122587 time=352.4s +ttt_sliding:done val_loss=1.891617 val_bpb=1.120325 elapsed=352.4s +legal_ttt val_loss:1.8916 val_bpb:1.1203 eval_time:352873ms +legal_ttt_exact val_loss:1.89161712 val_bpb:1.12032517 diff --git a/records/track_non_record_16mb/2026-03-29_CTW_on_LeakyReLU_TTT_ParallelMuon/train_ctw_seed1337.log b/records/track_non_record_16mb/2026-03-29_CTW_on_LeakyReLU_TTT_ParallelMuon/train_ctw_seed1337.log new file mode 100644 index 0000000000..36075c157b --- /dev/null +++ b/records/track_non_record_16mb/2026-03-29_CTW_on_LeakyReLU_TTT_ParallelMuon/train_ctw_seed1337.log @@ -0,0 +1,275 @@ +W0329 06:17:37.124000 61588 torch/distributed/run.py:803] +W0329 06:17:37.124000 61588 torch/distributed/run.py:803] ***************************************** +W0329 
06:17:37.124000 61588 torch/distributed/run.py:803] Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed. +W0329 06:17:37.124000 61588 torch/distributed/run.py:803] ***************************************** +logs/anubhav_ctw_0_1_depth4_29mar2026_1147am.txt +val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path=./data/tokenizers/fineweb_1024_bpe.model +train_loader:dataset:fineweb10B_sp1024 train_shards:80 +val_loader:shards pattern=./data/datasets/fineweb10B_sp1024/fineweb_val_*.bin tokens:62021632 +model_params:26928220 +mtp_num_heads:0 mtp_loss_weight:0.2 mtp_params:0 +XSA:last_4 active_layers:[7, 8, 9, 10] +world_size:8 grad_accum_steps:1 +sdp_backends:cudnn=False flash=True mem_efficient=False math=False +attention_mode:gqa num_heads:8 num_kv_heads:4 +tie_embeddings:True embed_lr:0.035 head_lr:0.0 matrix_lr:0.025 scalar_lr:0.025 +train_batch_tokens:786432 train_seq_len:2048 iterations:9000 warmup_steps:20 max_wallclock_seconds:600.000 +seed:1337 +warmup_step:1/20 +warmup_step:2/20 +warmup_step:3/20 +warmup_step:4/20 +warmup_step:5/20 +warmup_step:6/20 +warmup_step:7/20 +warmup_step:8/20 +warmup_step:9/20 +warmup_step:10/20 +warmup_step:11/20 +warmup_step:12/20 +warmup_step:13/20 +warmup_step:14/20 +warmup_step:15/20 +warmup_step:16/20 +warmup_step:17/20 +warmup_step:18/20 +warmup_step:19/20 +warmup_step:20/20 +step:0/9000 val_loss:6.9304 val_bpb:4.1046 train_time:0ms step_avg:0.02ms +step:1/9000 train_loss:6.9322 train_time:124ms step_avg:123.56ms +step:2/9000 train_loss:8.6544 train_time:151ms step_avg:75.48ms +step:3/9000 train_loss:7.6927 train_time:231ms step_avg:77.01ms +step:4/9000 train_loss:7.2519 train_time:312ms step_avg:78.08ms +step:5/9000 train_loss:7.1706 train_time:393ms step_avg:78.68ms +step:6/9000 train_loss:7.1159 train_time:474ms step_avg:78.98ms +step:7/9000 train_loss:7.0261 
train_time:554ms step_avg:79.17ms +step:8/9000 train_loss:6.9587 train_time:635ms step_avg:79.38ms +step:9/9000 train_loss:6.5735 train_time:716ms step_avg:79.54ms +step:10/9000 train_loss:6.1997 train_time:797ms step_avg:79.68ms +step:500/9000 train_loss:2.3954 train_time:42113ms step_avg:84.23ms +step:1000/9000 train_loss:2.2630 train_time:84695ms step_avg:84.69ms +step:1500/9000 train_loss:2.2111 train_time:127457ms step_avg:84.97ms +step:2000/9000 train_loss:2.0533 train_time:170197ms step_avg:85.10ms +step:2500/9000 train_loss:2.1540 train_time:212967ms step_avg:85.19ms +step:3000/9000 train_loss:2.1480 train_time:255718ms step_avg:85.24ms +step:3500/9000 train_loss:2.1721 train_time:298492ms step_avg:85.28ms +step:4000/9000 train_loss:1.9638 train_time:341213ms step_avg:85.30ms +step:4000/9000 val_loss:2.0535 val_bpb:1.2162 train_time:341270ms step_avg:85.32ms +step:4500/9000 train_loss:2.1115 train_time:383932ms step_avg:85.32ms +step:5000/9000 train_loss:2.0962 train_time:426610ms step_avg:85.32ms +step:5500/9000 train_loss:2.0111 train_time:469352ms step_avg:85.34ms +step:6000/9000 train_loss:1.9310 train_time:512059ms step_avg:85.34ms +swa:start step:6350 +step:6500/9000 train_loss:2.0717 train_time:554949ms step_avg:85.38ms +late_qat:enabled step:6502 scale:0.1498 +step:7000/9000 train_loss:1.7851 train_time:598111ms step_avg:85.44ms +step:7023/9000 val_loss:1.9224 val_bpb:1.1386 train_time:600130ms step_avg:85.45ms +stopping_early: wallclock_cap train_time:600130ms step:7023/9000 +peak memory allocated: 21471 MiB reserved: 22002 MiB +ema:applying EMA weights +DIAGNOSTIC post_ema val_loss:1.9207 val_bpb:1.1375 eval_time:2001ms +Serialized model: 106027446 bytes +Code size: 97252 bytes +Serialized model int6+lzma: 15757536 bytes +Total submission size int6+lzma: 15854788 bytes +final_int6_roundtrip val_loss:1.9344 val_bpb:1.1457 eval_time:5555ms +final_int6_roundtrip_exact val_loss:1.93440371 val_bpb:1.14566284 +final_int6_sliding_window val_loss:1.8946 
val_bpb:1.1221 stride:64 eval_time:73689ms +final_int6_sliding_window_exact val_loss:1.89464129 val_bpb:1.12211626 +final_int8_zlib_roundtrip_exact val_loss:1.89464129 val_bpb:1.12211626 +ttt_sliding:start chunks=1893 chunk_tokens=32768 total_windows=969088 stride=64 ttt_lr=0.002 ttt_epochs=3 freeze_blocks=0 ctw_weight=0.1 ctw_depth=4 +ttt_sliding:params unfrozen=26928220 frozen=0 + ttt_chunk [1/1893] bpb=1.166304 time=3.6s + ttt_chunk [11/1893] bpb=1.154979 time=19.5s + ttt_chunk [21/1893] bpb=1.141350 time=33.7s + ttt_chunk [31/1893] bpb=1.138283 time=49.1s + ttt_chunk [41/1893] bpb=1.124255 time=63.1s + ttt_chunk [51/1893] bpb=1.118337 time=77.1s + ttt_chunk [61/1893] bpb=1.124713 time=92.2s + ttt_chunk [71/1893] bpb=1.123083 time=106.4s + ttt_chunk [81/1893] bpb=1.122308 time=120.4s + ttt_chunk [91/1893] bpb=1.123001 time=134.2s + ttt_chunk [101/1893] bpb=1.126490 time=148.3s + ttt_chunk [111/1893] bpb=1.128794 time=164.3s + ttt_chunk [121/1893] bpb=1.122378 time=177.9s + ttt_chunk [131/1893] bpb=1.122824 time=191.5s + ttt_chunk [141/1893] bpb=1.128473 time=205.5s + ttt_chunk [151/1893] bpb=1.130326 time=219.4s + ttt_chunk [161/1893] bpb=1.129723 time=233.1s + ttt_chunk [171/1893] bpb=1.133974 time=247.3s + ttt_chunk [181/1893] bpb=1.136129 time=263.8s + ttt_chunk [191/1893] bpb=1.143343 time=278.7s + ttt_chunk [201/1893] bpb=1.142047 time=292.5s + ttt_chunk [211/1893] bpb=1.139768 time=306.2s + ttt_chunk [221/1893] bpb=1.141298 time=320.1s + ttt_chunk [231/1893] bpb=1.139916 time=333.9s + ttt_chunk [241/1893] bpb=1.140212 time=347.5s + ttt_chunk [251/1893] bpb=1.139710 time=361.2s + ttt_chunk [261/1893] bpb=1.136854 time=374.7s + ttt_chunk [271/1893] bpb=1.135716 time=389.0s + ttt_chunk [281/1893] bpb=1.137024 time=405.8s + ttt_chunk [291/1893] bpb=1.138886 time=419.5s + ttt_chunk [301/1893] bpb=1.139589 time=433.3s + ttt_chunk [311/1893] bpb=1.141652 time=447.0s + ttt_chunk [321/1893] bpb=1.143615 time=460.7s + ttt_chunk [331/1893] bpb=1.143466 time=474.3s + 
ttt_chunk [341/1893] bpb=1.142417 time=488.1s + ttt_chunk [351/1893] bpb=1.144734 time=501.9s + ttt_chunk [361/1893] bpb=1.144896 time=515.7s + ttt_chunk [371/1893] bpb=1.144218 time=529.4s + ttt_chunk [381/1893] bpb=1.144361 time=543.4s + ttt_chunk [391/1893] bpb=1.144199 time=557.4s + ttt_chunk [401/1893] bpb=1.142179 time=572.1s + ttt_chunk [411/1893] bpb=1.141018 time=591.6s + ttt_chunk [421/1893] bpb=1.140093 time=606.2s + ttt_chunk [431/1893] bpb=1.139901 time=620.4s + ttt_chunk [441/1893] bpb=1.140234 time=634.3s + ttt_chunk [451/1893] bpb=1.140487 time=648.4s + ttt_chunk [461/1893] bpb=1.139389 time=662.2s + ttt_chunk [471/1893] bpb=1.140024 time=676.4s + ttt_chunk [481/1893] bpb=1.139647 time=690.3s + ttt_chunk [491/1893] bpb=1.138510 time=704.3s + ttt_chunk [501/1893] bpb=1.137966 time=718.2s + ttt_chunk [511/1893] bpb=1.137248 time=732.1s + ttt_chunk [521/1893] bpb=1.134855 time=746.1s + ttt_chunk [531/1893] bpb=1.136011 time=760.0s + ttt_chunk [541/1893] bpb=1.136320 time=774.1s + ttt_chunk [551/1893] bpb=1.135254 time=787.9s + ttt_chunk [561/1893] bpb=1.135779 time=801.9s + ttt_chunk [571/1893] bpb=1.134711 time=815.9s + ttt_chunk [581/1893] bpb=1.133879 time=830.8s + ttt_chunk [591/1893] bpb=1.133225 time=853.4s + ttt_chunk [601/1893] bpb=1.133693 time=867.1s + ttt_chunk [611/1893] bpb=1.133601 time=880.8s + ttt_chunk [621/1893] bpb=1.133449 time=894.7s + ttt_chunk [631/1893] bpb=1.134158 time=908.5s + ttt_chunk [641/1893] bpb=1.133896 time=922.1s + ttt_chunk [651/1893] bpb=1.133980 time=935.9s + ttt_chunk [661/1893] bpb=1.133461 time=949.7s + ttt_chunk [671/1893] bpb=1.133777 time=963.5s + ttt_chunk [681/1893] bpb=1.134496 time=977.4s + ttt_chunk [691/1893] bpb=1.135502 time=991.6s + ttt_chunk [701/1893] bpb=1.134941 time=1005.3s + ttt_chunk [711/1893] bpb=1.134908 time=1019.2s + ttt_chunk [721/1893] bpb=1.134524 time=1033.2s + ttt_chunk [731/1893] bpb=1.134540 time=1047.0s + ttt_chunk [741/1893] bpb=1.134625 time=1060.9s + ttt_chunk [751/1893] 
bpb=1.134459 time=1074.8s + ttt_chunk [761/1893] bpb=1.134375 time=1088.6s + ttt_chunk [771/1893] bpb=1.134064 time=1102.4s + ttt_chunk [781/1893] bpb=1.134808 time=1116.2s + ttt_chunk [791/1893] bpb=1.134361 time=1130.1s + ttt_chunk [801/1893] bpb=1.134652 time=1143.9s + ttt_chunk [811/1893] bpb=1.134393 time=1158.0s + ttt_chunk [821/1893] bpb=1.134160 time=1171.8s + ttt_chunk [831/1893] bpb=1.133982 time=1187.4s + ttt_chunk [841/1893] bpb=1.133324 time=1204.9s + ttt_chunk [851/1893] bpb=1.133080 time=1222.5s + ttt_chunk [861/1893] bpb=1.132826 time=1236.6s + ttt_chunk [871/1893] bpb=1.133093 time=1250.9s + ttt_chunk [881/1893] bpb=1.133272 time=1264.8s + ttt_chunk [891/1893] bpb=1.132815 time=1278.5s + ttt_chunk [901/1893] bpb=1.132533 time=1292.2s + ttt_chunk [911/1893] bpb=1.132658 time=1306.0s + ttt_chunk [921/1893] bpb=1.133130 time=1320.0s + ttt_chunk [931/1893] bpb=1.133119 time=1333.7s + ttt_chunk [941/1893] bpb=1.132806 time=1348.2s + ttt_chunk [951/1893] bpb=1.133187 time=1362.4s + ttt_chunk [961/1893] bpb=1.133255 time=1376.6s + ttt_chunk [971/1893] bpb=1.134111 time=1390.7s + ttt_chunk [981/1893] bpb=1.134181 time=1404.8s + ttt_chunk [991/1893] bpb=1.134213 time=1418.9s + ttt_chunk [1001/1893] bpb=1.134164 time=1432.8s + ttt_chunk [1011/1893] bpb=1.133953 time=1447.1s + ttt_chunk [1021/1893] bpb=1.134293 time=1461.3s + ttt_chunk [1031/1893] bpb=1.134740 time=1475.5s + ttt_chunk [1041/1893] bpb=1.134378 time=1489.9s + ttt_chunk [1051/1893] bpb=1.134121 time=1504.2s + ttt_chunk [1061/1893] bpb=1.134156 time=1518.5s + ttt_chunk [1071/1893] bpb=1.134761 time=1532.9s + ttt_chunk [1081/1893] bpb=1.135037 time=1547.3s + ttt_chunk [1091/1893] bpb=1.135768 time=1561.8s + ttt_chunk [1101/1893] bpb=1.135813 time=1575.9s + ttt_chunk [1111/1893] bpb=1.135656 time=1590.5s + ttt_chunk [1121/1893] bpb=1.135425 time=1604.8s + ttt_chunk [1131/1893] bpb=1.135306 time=1618.8s + ttt_chunk [1141/1893] bpb=1.135000 time=1632.9s + ttt_chunk [1151/1893] bpb=1.135020 
time=1647.1s + ttt_chunk [1161/1893] bpb=1.134642 time=1661.2s + ttt_chunk [1171/1893] bpb=1.134957 time=1677.8s + ttt_chunk [1181/1893] bpb=1.134215 time=1696.8s + ttt_chunk [1191/1893] bpb=1.134096 time=1720.4s + ttt_chunk [1201/1893] bpb=1.134501 time=1737.3s + ttt_chunk [1211/1893] bpb=1.134026 time=1751.6s + ttt_chunk [1221/1893] bpb=1.133725 time=1766.3s + ttt_chunk [1231/1893] bpb=1.133435 time=1780.6s + ttt_chunk [1241/1893] bpb=1.133076 time=1794.6s + ttt_chunk [1251/1893] bpb=1.132494 time=1808.5s + ttt_chunk [1261/1893] bpb=1.132457 time=1822.7s + ttt_chunk [1271/1893] bpb=1.132076 time=1837.0s + ttt_chunk [1281/1893] bpb=1.131874 time=1850.9s + ttt_chunk [1291/1893] bpb=1.131640 time=1865.2s + ttt_chunk [1301/1893] bpb=1.131042 time=1879.6s + ttt_chunk [1311/1893] bpb=1.130650 time=1893.8s + ttt_chunk [1321/1893] bpb=1.130309 time=1908.1s + ttt_chunk [1331/1893] bpb=1.130258 time=1922.4s + ttt_chunk [1341/1893] bpb=1.130139 time=1936.7s + ttt_chunk [1351/1893] bpb=1.130080 time=1950.9s + ttt_chunk [1361/1893] bpb=1.130147 time=1965.3s + ttt_chunk [1371/1893] bpb=1.130003 time=1979.5s + ttt_chunk [1381/1893] bpb=1.129981 time=1993.6s + ttt_chunk [1391/1893] bpb=1.129574 time=2008.0s + ttt_chunk [1401/1893] bpb=1.129548 time=2022.3s + ttt_chunk [1411/1893] bpb=1.129669 time=2036.5s + ttt_chunk [1421/1893] bpb=1.129922 time=2050.7s + ttt_chunk [1431/1893] bpb=1.129632 time=2065.0s + ttt_chunk [1441/1893] bpb=1.130144 time=2079.4s + ttt_chunk [1451/1893] bpb=1.130485 time=2093.6s + ttt_chunk [1461/1893] bpb=1.130051 time=2107.7s + ttt_chunk [1471/1893] bpb=1.131078 time=2122.2s + ttt_chunk [1481/1893] bpb=1.130632 time=2136.6s + ttt_chunk [1491/1893] bpb=1.130452 time=2151.4s + ttt_chunk [1501/1893] bpb=1.130353 time=2165.6s + ttt_chunk [1511/1893] bpb=1.130375 time=2180.2s + ttt_chunk [1521/1893] bpb=1.130410 time=2194.9s + ttt_chunk [1531/1893] bpb=1.129894 time=2210.2s + ttt_chunk [1541/1893] bpb=1.129753 time=2224.8s + ttt_chunk [1551/1893] bpb=1.130071 
time=2239.3s + ttt_chunk [1561/1893] bpb=1.130054 time=2253.7s + ttt_chunk [1571/1893] bpb=1.129905 time=2268.0s + ttt_chunk [1581/1893] bpb=1.130010 time=2282.3s + ttt_chunk [1591/1893] bpb=1.129858 time=2296.6s + ttt_chunk [1601/1893] bpb=1.130033 time=2311.0s + ttt_chunk [1611/1893] bpb=1.129970 time=2325.6s + ttt_chunk [1621/1893] bpb=1.129571 time=2340.0s + ttt_chunk [1631/1893] bpb=1.129884 time=2354.4s + ttt_chunk [1641/1893] bpb=1.129895 time=2368.9s + ttt_chunk [1651/1893] bpb=1.129844 time=2386.2s + ttt_chunk [1661/1893] bpb=1.129722 time=2407.0s + ttt_chunk [1671/1893] bpb=1.130197 time=2430.6s + ttt_chunk [1681/1893] bpb=1.130338 time=2451.3s + ttt_chunk [1691/1893] bpb=1.130159 time=2465.9s + ttt_chunk [1701/1893] bpb=1.130321 time=2480.6s + ttt_chunk [1711/1893] bpb=1.130332 time=2495.2s + ttt_chunk [1721/1893] bpb=1.130342 time=2509.8s + ttt_chunk [1731/1893] bpb=1.130212 time=2524.3s + ttt_chunk [1741/1893] bpb=1.130030 time=2538.7s + ttt_chunk [1751/1893] bpb=1.129853 time=2553.2s + ttt_chunk [1761/1893] bpb=1.129995 time=2567.7s + ttt_chunk [1771/1893] bpb=1.129897 time=2582.2s + ttt_chunk [1781/1893] bpb=1.129920 time=2596.6s + ttt_chunk [1791/1893] bpb=1.129517 time=2611.0s + ttt_chunk [1801/1893] bpb=1.129376 time=2625.4s + ttt_chunk [1811/1893] bpb=1.129276 time=2640.0s + ttt_chunk [1821/1893] bpb=1.129337 time=2654.8s + ttt_chunk [1831/1893] bpb=1.128723 time=2669.4s + ttt_chunk [1841/1893] bpb=1.128649 time=2684.3s + ttt_chunk [1851/1893] bpb=1.128421 time=2698.8s + ttt_chunk [1861/1893] bpb=1.128056 time=2713.7s + ttt_chunk [1871/1893] bpb=1.128033 time=2728.5s + ttt_chunk [1881/1893] bpb=1.127573 time=2743.0s + ttt_chunk [1891/1893] bpb=1.127331 time=2757.5s + ttt_chunk [1893/1893] bpb=1.127376 time=2760.0s +ttt_sliding:done val_loss=1.899909 val_bpb=1.125236 elapsed=2760.1s +legal_ttt val_loss:1.8999 val_bpb:1.1252 eval_time:2764541ms +legal_ttt_exact val_loss:1.89990933 val_bpb:1.12523630 diff --git 
a/records/track_non_record_16mb/2026-03-29_CTW_on_LeakyReLU_TTT_ParallelMuon/train_gpt.py b/records/track_non_record_16mb/2026-03-29_CTW_on_LeakyReLU_TTT_ParallelMuon/train_gpt.py
new file mode 100644
index 0000000000..26752b8ab4
--- /dev/null
+++ b/records/track_non_record_16mb/2026-03-29_CTW_on_LeakyReLU_TTT_ParallelMuon/train_gpt.py
@@ -0,0 +1,2055 @@
+from __future__ import annotations
+import copy
+import glob
+import io
+import lzma
+import math
+import os
+import random
+import subprocess
+import sys
+import time
+import uuid
+import zlib
+from pathlib import Path
+try:
+    import zstandard
+    _COMPRESSOR = "zstd"
+except ImportError:
+    _COMPRESSOR = "zlib"
+import numpy as np
+import sentencepiece as spm
+import torch
+import torch.distributed as dist
+import torch.nn.functional as F
+from torch import Tensor, nn
+from torch.nn.parallel import DistributedDataParallel as DDP
+from flash_attn_interface import flash_attn_func as flash_attn_3_func
+class Hyperparameters:
+    data_path = os.environ.get("DATA_PATH", "./data/datasets/fineweb10B_sp1024")
+    train_files = os.path.join(data_path, "fineweb_train_*.bin")
+    val_files = os.path.join(data_path, "fineweb_val_*.bin")
+    tokenizer_path = os.environ.get("TOKENIZER_PATH", "./data/tokenizers/fineweb_1024_bpe.model")
+    run_id = os.environ.get("RUN_ID", str(uuid.uuid4()))
+    seed = int(os.environ.get("SEED", 1337))
+    val_batch_size = int(os.environ.get("VAL_BATCH_SIZE", 524_288))
+    val_loss_every = int(os.environ.get("VAL_LOSS_EVERY", 4000))
+    train_log_every = int(os.environ.get("TRAIN_LOG_EVERY", 500))
+    iterations = int(os.environ.get("ITERATIONS", 20000))
+    warmdown_iters = int(os.environ.get("WARMDOWN_ITERS", 3500))
+    warmup_steps = int(os.environ.get("WARMUP_STEPS", 20))
+    train_batch_tokens = int(os.environ.get("TRAIN_BATCH_TOKENS", 786_432))
+    train_seq_len = int(os.environ.get("TRAIN_SEQ_LEN", 2048))
+    eval_seq_len = int(os.environ.get("EVAL_SEQ_LEN", 2048))
+    max_wallclock_seconds = float(os.environ.get("MAX_WALLCLOCK_SECONDS", 600.0))
+    qk_gain_init = float(os.environ.get("QK_GAIN_INIT", 1.5))
+    vocab_size = int(os.environ.get("VOCAB_SIZE", 1024))
+    num_layers = int(os.environ.get("NUM_LAYERS", 11))
+    num_kv_heads = int(os.environ.get("NUM_KV_HEADS", 4))
+    model_dim = int(os.environ.get("MODEL_DIM", 512))
+    num_heads = int(os.environ.get("NUM_HEADS", 8))
+    mlp_mult = float(os.environ.get("MLP_MULT", 3.0))
+    tie_embeddings = bool(int(os.environ.get("TIE_EMBEDDINGS", "1")))
+    rope_base = float(os.environ.get("ROPE_BASE", 10000.0))
+    logit_softcap = float(os.environ.get("LOGIT_SOFTCAP", 30.0))
+    embed_lr = float(os.environ.get("EMBED_LR", 0.6))
+    head_lr = float(os.environ.get("HEAD_LR", 0.008))
+    tied_embed_lr = float(os.environ.get("TIED_EMBED_LR", 0.035))
+    tied_embed_init_std = float(os.environ.get("TIED_EMBED_INIT_STD", 0.005))
+    matrix_lr = float(os.environ.get("MATRIX_LR", 0.025))
+    scalar_lr = float(os.environ.get("SCALAR_LR", 0.025))
+    muon_momentum = float(os.environ.get("MUON_MOMENTUM", 0.99))
+    muon_backend_steps = int(os.environ.get("MUON_BACKEND_STEPS", 5))
+    muon_momentum_warmup_start = float(os.environ.get("MUON_MOMENTUM_WARMUP_START", 0.92))
+    muon_momentum_warmup_steps = int(os.environ.get("MUON_MOMENTUM_WARMUP_STEPS", 1500))
+    beta1 = float(os.environ.get("BETA1", 0.9))
+    beta2 = float(os.environ.get("BETA2", 0.95))
+    adam_eps = float(os.environ.get("ADAM_EPS", 1e-8))
+    grad_clip_norm = float(os.environ.get("GRAD_CLIP_NORM", 0.3))
+    eval_stride = int(os.environ.get("EVAL_STRIDE", 64))
+    mtp_num_heads = int(os.environ.get("MTP_NUM_HEADS", 0))
+    mtp_loss_weight = float(os.environ.get("MTP_LOSS_WEIGHT", 0.2))
+    muon_beta2 = float(os.environ.get("MUON_BETA2", 0.95))
+    swa_enabled = bool(int(os.environ.get("SWA_ENABLED", "1")))
+    swa_every = int(os.environ.get("SWA_EVERY", 50))
+    lawa_enabled = bool(int(os.environ.get("LAWA_ENABLED", "0")))
+    lawa_k = int(os.environ.get("LAWA_K", 10))
+    lawa_freq = int(os.environ.get("LAWA_FREQ", 100))
+    muon_wd = float(os.environ.get("MUON_WD", 0.04))
+    adam_wd = float(os.environ.get("ADAM_WD", 0.04))
+    qat_enabled = bool(int(os.environ.get("QAT_ENABLED", "0")))
+    bigram_vocab_size = int(os.environ.get("BIGRAM_VOCAB_SIZE", 2048))
+    bigram_dim = int(os.environ.get("BIGRAM_DIM", 128))
+    xsa_last_n = int(os.environ.get("XSA_LAST_N", 4))
+    rope_dims = int(os.environ.get("ROPE_DIMS", 16))
+    ln_scale = bool(int(os.environ.get("LN_SCALE", "1")))
+    dtg_enabled = bool(int(os.environ.get("DTG_ENABLED", "0")))
+    late_qat_threshold = float(os.environ.get("LATE_QAT_THRESHOLD", 0.15))
+    ve_enabled = bool(int(os.environ.get("VE_ENABLED", "1")))
+    ve_dim = int(os.environ.get("VE_DIM", 128))
+    ve_layers = os.environ.get("VE_LAYERS", "9,10")
+    gated_attention = bool(int(os.environ.get("GATED_ATTENTION", "0")))
+    value_residual = bool(int(os.environ.get("VALUE_RESIDUAL", "0")))
+    ttt_enabled = bool(int(os.environ.get("TTT_ENABLED", "0")))
+    ttt_lr = float(os.environ.get("TTT_LR", 0.002))
+    ttt_epochs = int(os.environ.get("TTT_EPOCHS", 3))
+    ttt_chunk_tokens = int(os.environ.get("TTT_CHUNK_TOKENS", 32768))
+    ttt_freeze_blocks = int(os.environ.get("TTT_FREEZE_BLOCKS", 2))
+    ttt_momentum = float(os.environ.get("TTT_MOMENTUM", 0.9))
+    ttt_batch_seqs = int(os.environ.get("TTT_BATCH_SEQS", 32))
+    ttt_grad_clip = float(os.environ.get("TTT_GRAD_CLIP", 1.0))
+    ctw_weight = float(os.environ.get("CTW_WEIGHT", 0.0)) # Enable with CTW_WEIGHT=0.1
+    ctw_depth = int(os.environ.get("CTW_DEPTH", 4))
+
+# --- Batched Newton-Schulz orthogonalization ---
+
+def zeropower_via_newtonschulz5(G: Tensor, steps: int = 5, eps: float = 1e-7) -> Tensor:
+    """Batched Newton-Schulz orthogonalization. G: (B,M,N) or (M,N)."""
+    a, b, c = (3.4445, -4.7750, 2.0315)
+    was_2d = G.ndim == 2
+    if was_2d:
+        G = G.unsqueeze(0)
+    X = G.bfloat16()
+    transposed = X.size(-2) > X.size(-1)
+    if transposed:
+        X = X.mT
+    X = X / (X.norm(dim=(-2, -1), keepdim=True) + eps)
+    for _ in range(steps):
+        A = X @ X.mT
+        B = b * A + c * (A @ A)
+        X = a * X + B @ X
+    if transposed:
+        X = X.mT
+    if was_2d:
+        X = X.squeeze(0)
+    return X
+
+# --- Parallel Muon optimizer ---
+
+class Muon(torch.optim.Optimizer):
+    """Parallel Muon: post-backward reduce-scatter -> local NS5 -> all-gather.
+
+    No DDP for bank params. After backward, this optimizer:
+    1. Launches async reduce-scatter for all banks (biggest first)
+    2. Returns control so Adam can step on small params while RS is in-flight
+    3. Waits for each RS, runs local NS5 on the shard, launches async all-gather
+    4. Each all-gather overlaps with next bank's NS5
+    """
+    def __init__(self, params, lr: float, momentum: float, backend_steps: int,
+                 nesterov: bool = True, weight_decay: float = 0.0):
+        super().__init__(
+            params,
+            dict(lr=lr, momentum=momentum, backend_steps=backend_steps,
+                 nesterov=nesterov, weight_decay=weight_decay),
+        )
+        self._built = False
+
+    def _build(self):
+        self._distributed = dist.is_available() and dist.is_initialized()
+        self._world_size = dist.get_world_size() if self._distributed else 1
+        self._rank = dist.get_rank() if self._distributed else 0
+        ws = self._world_size
+
+        self._bank_meta = []
+        for group in self.param_groups:
+            for p in group["params"]:
+                B = p.shape[0]
+                padded_B = ((B + ws - 1) // ws) * ws
+                shard_B = padded_B // ws
+                tail = p.shape[1:]
+                dev = p.device
+                self._bank_meta.append({
+                    'p': p,
+                    'B': B,
+                    'padded_grad': torch.zeros(padded_B, *tail, device=dev, dtype=torch.bfloat16),
+                    'shard': torch.zeros(shard_B, *tail, device=dev, dtype=torch.bfloat16),
+                    'shard_mom': torch.zeros(shard_B, *tail, device=dev, dtype=torch.bfloat16),
+                    'full_update': torch.zeros(padded_B, *tail, device=dev, dtype=torch.bfloat16),
+                    'scale': max(1, p.shape[-2] / p.shape[-1]) ** 0.5,
+                })
+        # Sort by size descending -- launch biggest reduce-scatters first
+        self._bank_meta.sort(key=lambda m: -m['p'].numel())
+        self._built = True
+
+    def launch_reduce_scatters(self):
+        """Phase 1: launch async reduce-scatter for all banks. Call right after backward."""
+        if not self._built:
+            self._build()
+        if not self._distributed:
+            return
+        self._rs_futures = []
+        for m in self._bank_meta:
+            p = m['p']
+            if p.grad is None:
+                self._rs_futures.append(None)
+                continue
+            pg = m['padded_grad']
+            pg[:m['B']].copy_(p.grad.bfloat16())
+            if pg.shape[0] > m['B']:
+                pg[m['B']:].zero_()
+            fut = dist.reduce_scatter_tensor(m['shard'], pg, op=dist.ReduceOp.AVG, async_op=True)
+            self._rs_futures.append(fut)
+
+    @torch.no_grad()
+    def step(self, closure=None):
+        """Phase 3: wait for RS, local NS5, all-gather. Call AFTER Adam steps."""
+        loss = None
+        if closure is not None:
+            with torch.enable_grad():
+                loss = closure()
+
+        if not self._built:
+            self._build()
+
+        for group in self.param_groups:
+            lr = group["lr"]
+            momentum = group["momentum"]
+            backend_steps = group["backend_steps"]
+            nesterov = group["nesterov"]
+            wd = group.get("weight_decay", 0.0)
+
+            prev_ag_handle = None
+            prev_m = None
+
+            sharded = self._distributed and hasattr(self, '_rs_futures')
+
+            for i, m in enumerate(self._bank_meta):
+                p = m['p']
+                if p.grad is None:
+                    continue
+
+                if prev_ag_handle is not None:
+                    prev_ag_handle.wait()
+                    pp = prev_m['p']
+                    upd = prev_m['full_update'][:prev_m['B']]
+                    if wd > 0.0:
+                        pp.data.mul_(1.0 - lr * wd)
+                    pp.add_(upd.to(dtype=pp.dtype), alpha=-lr * prev_m['scale'])
+
+                if sharded and self._rs_futures[i] is not None:
+                    self._rs_futures[i].wait()
+                    g = m['shard']
+                    buf = m['shard_mom']
+                else:
+                    g = p.grad.bfloat16()
+                    state = self.state[p]
+                    if "momentum_buffer" not in state:
+                        state["momentum_buffer"] = torch.zeros_like(g)
+                    buf = state["momentum_buffer"]
+
+                buf.mul_(momentum).add_(g)
+                if 
nesterov: + update = g.add(buf, alpha=momentum) + else: + update = buf + + update = zeropower_via_newtonschulz5(update, steps=backend_steps) + + if sharded: + prev_ag_handle = dist.all_gather_into_tensor( + m['full_update'], update, async_op=True) + prev_m = m + else: + if wd > 0.0: + p.data.mul_(1.0 - lr * wd) + p.add_(update.to(dtype=p.dtype), alpha=-lr * m['scale']) + + if prev_ag_handle is not None: + prev_ag_handle.wait() + pp = prev_m['p'] + upd = prev_m['full_update'][:prev_m['B']] + if wd > 0.0: + pp.data.mul_(1.0 - lr * wd) + pp.add_(upd.to(dtype=pp.dtype), alpha=-lr * prev_m['scale']) + + if hasattr(self, '_rs_futures'): + del self._rs_futures + + return loss + +# --- Tokenizer evaluation helpers --- + +def build_sentencepiece_luts( + sp: spm.SentencePieceProcessor, vocab_size: int, device: torch.device +) -> tuple[Tensor, Tensor, Tensor]: + sp_vocab_size = int(sp.vocab_size()) + table_size = max(sp_vocab_size, vocab_size) + base_bytes_np = np.zeros((table_size,), dtype=np.int16) + has_leading_space_np = np.zeros((table_size,), dtype=np.bool_) + is_boundary_token_np = np.ones((table_size,), dtype=np.bool_) + for token_id in range(sp_vocab_size): + if sp.is_control(token_id) or sp.is_unknown(token_id) or sp.is_unused(token_id): + continue + is_boundary_token_np[token_id] = False + if sp.is_byte(token_id): + base_bytes_np[token_id] = 1 + continue + piece = sp.id_to_piece(token_id) + if piece.startswith("\u2581"): + has_leading_space_np[token_id] = True + piece = piece[1:] + base_bytes_np[token_id] = len(piece.encode("utf-8")) + return ( + torch.tensor(base_bytes_np, dtype=torch.int16, device=device), + torch.tensor(has_leading_space_np, dtype=torch.bool, device=device), + torch.tensor(is_boundary_token_np, dtype=torch.bool, device=device), + ) +def load_validation_tokens(pattern: str, seq_len: int) -> Tensor: + files = [Path(p) for p in sorted(glob.glob(pattern))] + if not files: + raise FileNotFoundError(f"No files found for pattern: {pattern}") + tokens = 
torch.cat([load_data_shard(file) for file in files]).contiguous() + usable = ((tokens.numel() - 1) // seq_len) * seq_len + if usable <= 0: + raise ValueError(f"Validation split is too short for TRAIN_SEQ_LEN={seq_len}") + return tokens[: usable + 1] +def eval_val( + args: Hyperparameters, + model: nn.Module, + rank: int, + world_size: int, + device: torch.device, + grad_accum_steps: int, + val_tokens: Tensor, + base_bytes_lut: Tensor, + has_leading_space_lut: Tensor, + is_boundary_token_lut: Tensor, + eval_seq_len: int | None = None, +) -> tuple[float, float]: + seq_len = eval_seq_len or args.train_seq_len + local_batch_tokens = args.val_batch_size // (world_size * grad_accum_steps) + if local_batch_tokens < seq_len: + raise ValueError( + "VAL_BATCH_SIZE must provide at least one sequence per rank; " + f"got VAL_BATCH_SIZE={args.val_batch_size}, WORLD_SIZE={world_size}, " + f"GRAD_ACCUM_STEPS={grad_accum_steps}, seq_len={seq_len}" + ) + local_batch_seqs = local_batch_tokens // seq_len + total_seqs = (val_tokens.numel() - 1) // seq_len + seq_start = (total_seqs * rank) // world_size + seq_end = (total_seqs * (rank + 1)) // world_size + val_loss_sum = torch.zeros((), device=device, dtype=torch.float64) + val_token_count = torch.zeros((), device=device, dtype=torch.float64) + val_byte_count = torch.zeros((), device=device, dtype=torch.float64) + model.eval() + with torch.inference_mode(): + for batch_seq_start in range(seq_start, seq_end, local_batch_seqs): + batch_seq_end = min(batch_seq_start + local_batch_seqs, seq_end) + raw_start = batch_seq_start * seq_len + raw_end = batch_seq_end * seq_len + 1 + local = val_tokens[raw_start:raw_end].to(device=device, dtype=torch.int64, non_blocking=True) + x = local[:-1].reshape(-1, seq_len) + y = local[1:].reshape(-1, seq_len) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + batch_loss = model(x, y).detach() + batch_token_count = float(y.numel()) + val_loss_sum += batch_loss.to(torch.float64) * 
batch_token_count + val_token_count += batch_token_count + prev_ids = x.reshape(-1) + tgt_ids = y.reshape(-1) + token_bytes = base_bytes_lut[tgt_ids].to(dtype=torch.int16) + token_bytes += (has_leading_space_lut[tgt_ids] & ~is_boundary_token_lut[prev_ids]).to(dtype=torch.int16) + val_byte_count += token_bytes.to(torch.float64).sum() + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(val_loss_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(val_token_count, op=dist.ReduceOp.SUM) + dist.all_reduce(val_byte_count, op=dist.ReduceOp.SUM) + val_loss = val_loss_sum / val_token_count + bits_per_token = val_loss.item() / math.log(2.0) + tokens_per_byte = val_token_count.item() / val_byte_count.item() + model.train() + return float(val_loss.item()), float(bits_per_token * tokens_per_byte) + +# --- Quantization helpers --- + +CONTROL_TENSOR_NAME_PATTERNS = tuple( + pattern + for pattern in os.environ.get( + "CONTROL_TENSOR_NAME_PATTERNS", + "attn_scale,attn_scales,mlp_scale,mlp_scales,resid_mix,resid_mixes,q_gain,skip_weight,skip_weights,smear,dtg_gate,ve_layer_scales,ve_shared.scale,attn_gate,vr_lambda", + ).split(",") + if pattern +) +INT8_KEEP_FLOAT_FP32_NAME_PATTERNS = tuple( + pattern + for pattern in os.environ.get( + "INT8_KEEP_FLOAT_FP32_NAME_PATTERNS", + ",".join(CONTROL_TENSOR_NAME_PATTERNS), + ).split(",") + if pattern +) +INT8_KEEP_FLOAT_MAX_NUMEL = 65_536 +INT8_KEEP_FLOAT_STORE_DTYPE = torch.float16 +INT8_PER_ROW_SCALE_DTYPE = torch.float16 +INT8_CLIP_PERCENTILE = 99.99984 +INT8_CLIP_Q = INT8_CLIP_PERCENTILE / 100.0 +def tensor_nbytes(t: Tensor) -> int: + return int(t.numel()) * int(t.element_size()) +def keep_float_tensor(name: str, t: Tensor, passthrough_orig_dtypes: dict[str, str]) -> Tensor: + if any(pattern in name for pattern in INT8_KEEP_FLOAT_FP32_NAME_PATTERNS): + return t.float().contiguous() + if t.dtype in {torch.float32, torch.bfloat16}: + passthrough_orig_dtypes[name] = str(t.dtype).removeprefix("torch.") + return 
t.to(dtype=INT8_KEEP_FLOAT_STORE_DTYPE).contiguous() + return t +def quantize_float_tensor(t: Tensor) -> tuple[Tensor, Tensor]: + t32 = t.float() + if t32.ndim == 2: + clip_abs = ( + torch.quantile(t32.abs(), INT8_CLIP_Q, dim=1) + if t32.numel() + else torch.empty((t32.shape[0],), dtype=torch.float32) + ) + clipped = torch.maximum(torch.minimum(t32, clip_abs[:, None]), -clip_abs[:, None]) + scale = (clip_abs / 127.0).clamp_min(1.0 / 127.0) + q = torch.clamp(torch.round(clipped / scale[:, None]), -127, 127).to(torch.int8).contiguous() + return q, scale.to(dtype=INT8_PER_ROW_SCALE_DTYPE).contiguous() + clip_abs = float(torch.quantile(t32.abs().flatten(), INT8_CLIP_Q).item()) if t32.numel() else 0.0 + scale = torch.tensor(clip_abs / 127.0 if clip_abs > 0 else 1.0, dtype=torch.float32) + q = torch.clamp(torch.round(torch.clamp(t32, -clip_abs, clip_abs) / scale), -127, 127).to(torch.int8).contiguous() + return q, scale +def quantize_state_dict_int8(state_dict: dict[str, Tensor]): + quantized: dict[str, Tensor] = {} + scales: dict[str, Tensor] = {} + dtypes: dict[str, str] = {} + passthrough: dict[str, Tensor] = {} + passthrough_orig_dtypes: dict[str, str] = {} + qmeta: dict[str, dict[str, object]] = {} + stats = dict.fromkeys( + ("param_count", "num_tensors", "num_float_tensors", "num_nonfloat_tensors", "baseline_tensor_bytes", "int8_payload_bytes"), + 0, + ) + for name, tensor in state_dict.items(): + t = tensor.detach().to("cpu").contiguous() + stats["param_count"] += int(t.numel()) + stats["num_tensors"] += 1 + stats["baseline_tensor_bytes"] += tensor_nbytes(t) + if not t.is_floating_point(): + stats["num_nonfloat_tensors"] += 1 + passthrough[name] = t + stats["int8_payload_bytes"] += tensor_nbytes(t) + continue + if t.numel() <= INT8_KEEP_FLOAT_MAX_NUMEL: + kept = keep_float_tensor(name, t, passthrough_orig_dtypes) + passthrough[name] = kept + stats["int8_payload_bytes"] += tensor_nbytes(kept) + continue + stats["num_float_tensors"] += 1 + q, s = 
quantize_float_tensor(t) + if s.ndim > 0: + qmeta[name] = {"scheme": "per_row", "axis": 0} + quantized[name] = q + scales[name] = s + dtypes[name] = str(t.dtype).removeprefix("torch.") + stats["int8_payload_bytes"] += tensor_nbytes(q) + tensor_nbytes(s) + obj: dict[str, object] = { + "__quant_format__": "int8_clean_per_row_v1", + "quantized": quantized, + "scales": scales, + "dtypes": dtypes, + "passthrough": passthrough, + } + if qmeta: + obj["qmeta"] = qmeta + if passthrough_orig_dtypes: + obj["passthrough_orig_dtypes"] = passthrough_orig_dtypes + return obj, stats +def dequantize_state_dict_int8(obj: dict[str, object]) -> dict[str, Tensor]: + out: dict[str, Tensor] = {} + qmeta = obj.get("qmeta", {}) + passthrough_orig_dtypes = obj.get("passthrough_orig_dtypes", {}) + for name, q in obj["quantized"].items(): + dtype = getattr(torch, obj["dtypes"][name]) + s = obj["scales"][name] + if qmeta.get(name, {}).get("scheme") == "per_row" or s.ndim > 0: + s = s.to(dtype=torch.float32) + out[name] = (q.float() * s.view(q.shape[0], *([1] * (q.ndim - 1)))).to(dtype=dtype).contiguous() + else: + scale = float(s.item()) + out[name] = (q.float() * scale).to(dtype=dtype).contiguous() + for name, t in obj["passthrough"].items(): + out_t = t.detach().to("cpu").contiguous() + orig_dtype = passthrough_orig_dtypes.get(name) + if isinstance(orig_dtype, str): + out_t = out_t.to(dtype=getattr(torch, orig_dtype)).contiguous() + out[name] = out_t + return out + +# --- Data loading --- + +def load_data_shard(file: Path) -> Tensor: + header_bytes = 256 * np.dtype(" None: + self.file_idx = (self.file_idx + 1) % len(self.files) + self.tokens = load_data_shard(self.files[self.file_idx]) + self.pos = 0 + def take(self, n: int) -> Tensor: + chunks: list[Tensor] = [] + remaining = n + while remaining > 0: + avail = self.tokens.numel() - self.pos + if avail <= 0: + self._advance_file() + continue + k = min(remaining, avail) + chunks.append(self.tokens[self.pos : self.pos + k]) + self.pos += k + 
remaining -= k + return chunks[0] if len(chunks) == 1 else torch.cat(chunks) +class DistributedTokenLoader: + def __init__(self, pattern: str, rank: int, world_size: int, device: torch.device): + self.rank = rank + self.world_size = world_size + self.device = device + self.stream = TokenStream(pattern) + def next_batch(self, global_tokens: int, seq_len: int, grad_accum_steps: int) -> tuple[Tensor, Tensor]: + local_tokens = global_tokens // (self.world_size * grad_accum_steps) + per_rank_span = local_tokens + 1 + chunk = self.stream.take(per_rank_span * self.world_size) + start = self.rank * per_rank_span + local = chunk[start : start + per_rank_span].to(dtype=torch.int64) + x = local[:-1].reshape(-1, seq_len) + y = local[1:].reshape(-1, seq_len) + return x.to(self.device, non_blocking=True), y.to(self.device, non_blocking=True) + +# --- Transformer modules --- + +class RMSNorm(nn.Module): + def __init__(self, eps: float | None = None): + super().__init__() + self.eps = eps + def forward(self, x: Tensor) -> Tensor: + return F.rms_norm(x, (x.size(-1),), eps=self.eps) +class CastedLinear(nn.Linear): + _qat_enabled: bool = False + def forward(self, x: Tensor) -> Tensor: + w = self.weight.to(x.dtype) + if CastedLinear._qat_enabled and self.training and w.ndim == 2: + with torch.no_grad(): + w32 = self.weight.float() + row_max = w32.abs().amax(dim=1) + scale = (row_max / 31.0).clamp_min(1.0 / 31.0) + w_q = (torch.clamp(torch.round(w32 / scale[:, None]), -32, 31) * scale[:, None]).to(x.dtype) + w = w + (w_q - w).detach() + bias = self.bias.to(x.dtype) if self.bias is not None else None + return F.linear(x, w, bias) +def restore_low_dim_params_to_fp32(module: nn.Module) -> None: + with torch.no_grad(): + for name, param in module.named_parameters(): + if (param.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS)) and param.dtype != torch.float32: + param.data = param.data.float() +class Rotary(nn.Module): + def __init__(self, dim: int, base: float 
= 10000.0, train_seq_len: int = 1024, rope_dims: int = 0): + super().__init__() + self.dim = dim + self.base = base + self.train_seq_len = train_seq_len + self.rope_dims = rope_dims if rope_dims > 0 else dim + inv_freq = 1.0 / (base ** (torch.arange(0, self.rope_dims, 2, dtype=torch.float32) / self.rope_dims)) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self._seq_len_cached = 0 + self._cos_cached: Tensor | None = None + self._sin_cached: Tensor | None = None + def forward(self, seq_len: int, device: torch.device, dtype: torch.dtype) -> tuple[Tensor, Tensor]: + if ( + self._cos_cached is None + or self._sin_cached is None + or self._seq_len_cached != seq_len + or self._cos_cached.device != device + ): + rd = self.rope_dims + if seq_len > self.train_seq_len: + scale = seq_len / self.train_seq_len + new_base = self.base * (scale ** (rd / (rd - 2))) + inv_freq = 1.0 / (new_base ** (torch.arange(0, rd, 2, dtype=torch.float32, device=device) / rd)) + else: + inv_freq = self.inv_freq.to(device) + t = torch.arange(seq_len, device=device, dtype=inv_freq.dtype) + freqs = torch.outer(t, inv_freq) + self._cos_cached = freqs.cos()[None, :, None, :] + self._sin_cached = freqs.sin()[None, :, None, :] + self._seq_len_cached = seq_len + return self._cos_cached.to(dtype=dtype), self._sin_cached.to(dtype=dtype) +def apply_rotary_emb(x: Tensor, cos: Tensor, sin: Tensor, rope_dims: int = 0) -> Tensor: + if rope_dims > 0 and rope_dims < x.size(-1): + x_rope, x_pass = x[..., :rope_dims], x[..., rope_dims:] + half = rope_dims // 2 + x1, x2 = x_rope[..., :half], x_rope[..., half:] + x_rope = torch.cat((x1 * cos + x2 * sin, x1 * (-sin) + x2 * cos), dim=-1) + return torch.cat((x_rope, x_pass), dim=-1) + half = x.size(-1) // 2 + x1, x2 = x[..., :half], x[..., half:] + return torch.cat((x1 * cos + x2 * sin, x1 * (-sin) + x2 * cos), dim=-1) + +class CausalSelfAttention(nn.Module): + def __init__( + self, + dim: int, + num_heads: int, + num_kv_heads: int, + rope_base: float, 
+ qk_gain_init: float, + gated_attention: bool = False, + value_residual: bool = False, + ): + super().__init__() + if dim % num_heads != 0: + raise ValueError("model_dim must be divisible by num_heads") + if num_heads % num_kv_heads != 0: + raise ValueError("num_heads must be divisible by num_kv_heads") + self.num_heads = num_heads + self.num_kv_heads = num_kv_heads + self.head_dim = dim // num_heads + if self.head_dim % 2 != 0: + raise ValueError("head_dim must be even for RoPE") + # No CastedLinear -- weights come from banks + self.q_gain = nn.Parameter(torch.full((num_heads,), qk_gain_init, dtype=torch.float32)) + self.rope_dims = 0 # set by GPT.__init__ for partial RoPE + self.rotary = Rotary(self.head_dim, base=rope_base, train_seq_len=1024) + self.use_xsa = False # set by GPT.__init__ for deep layers only + # Gated attention and value residual (non-banked small params) + self.gated_attention = gated_attention + if gated_attention: + self.attn_gate = nn.Linear(dim, num_heads, bias=True) + nn.init.zeros_(self.attn_gate.weight) + nn.init.constant_(self.attn_gate.bias, 4.0) + self.value_residual = value_residual + if value_residual: + self.vr_lambda = nn.Parameter(torch.tensor([0.5, 0.5], dtype=torch.float32)) + def _xsa_efficient(self, y: Tensor, v: Tensor) -> Tensor: + """Efficient XSA: subtract self-value projection via GQA-aware reshape (no repeat_interleave). + y: [B, T, H, D], v: [B, T, Hkv, D]. 
H must be divisible by Hkv.""" + B, T, H, D = y.shape + Hkv = v.size(-2) + group = H // Hkv + y_g = y.reshape(B, T, Hkv, group, D) # [B, T, Hkv, group, D] + vn = F.normalize(v, dim=-1).unsqueeze(-2) # [B, T, Hkv, 1, D] -- broadcast ready + proj = (y_g * vn).sum(dim=-1, keepdim=True) * vn + return (y_g - proj).reshape(B, T, H, D) + def forward(self, x: Tensor, q_w: Tensor, k_w: Tensor, v_w: Tensor, out_w: Tensor, v_embed: Tensor | None = None, v0: Tensor | None = None) -> tuple[Tensor, Tensor | None]: + bsz, seqlen, dim = x.shape + q = F.linear(x, q_w.to(x.dtype)).reshape(bsz, seqlen, self.num_heads, self.head_dim) + k = F.linear(x, k_w.to(x.dtype)).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim) + v = F.linear(x, v_w.to(x.dtype)) + if v_embed is not None: + v = v + v_embed + v = v.reshape(bsz, seqlen, self.num_kv_heads, self.head_dim) + raw_v = v if self.value_residual else None + if self.value_residual and v0 is not None: + lam = self.vr_lambda.to(dtype=v.dtype) + v = lam[0] * v0 + lam[1] * v + q = F.rms_norm(q, (q.size(-1),)) + k = F.rms_norm(k, (k.size(-1),)) + cos, sin = self.rotary(seqlen, x.device, q.dtype) + q = apply_rotary_emb(q, cos, sin, self.rope_dims) + k = apply_rotary_emb(k, cos, sin, self.rope_dims) + q = q * self.q_gain.to(dtype=q.dtype)[None, None, :, None] + y = flash_attn_3_func(q, k, v, causal=True) + if self.use_xsa: + y = self._xsa_efficient(y, v) + if self.gated_attention: + # gate shape: (bsz, seqlen, num_heads) -> (bsz, seqlen, num_heads, 1) for B,T,H,D layout + gate = torch.sigmoid(self.attn_gate(x)).unsqueeze(-1) + y = y * gate + y = y.reshape(bsz, seqlen, dim) + return F.linear(y, out_w.to(x.dtype)), raw_v + +class SmearGate(nn.Module): + def __init__(self, dim: int): + super().__init__() + self.gate = nn.Parameter(torch.zeros(dim, dtype=torch.float32)) + def forward(self, x: Tensor) -> Tensor: + g = torch.sigmoid(self.gate.to(dtype=x.dtype))[None, None, :] + x_prev = torch.cat([torch.zeros_like(x[:, :1]), x[:, :-1]], dim=1) + 
return (1 - g) * x + g * x_prev + +class BigramHashEmbedding(nn.Module): + def __init__(self, bigram_vocab_size: int, bigram_dim: int, model_dim: int): + super().__init__() + self.bigram_vocab_size = bigram_vocab_size + self.embed = nn.Embedding(bigram_vocab_size, bigram_dim) + nn.init.zeros_(self.embed.weight) + self.proj = CastedLinear(bigram_dim, model_dim, bias=False) if bigram_dim != model_dim else None + if self.proj is not None: + nn.init.zeros_(self.proj.weight) + self.scale = nn.Parameter(torch.tensor(0.05, dtype=torch.float32)) + def bigram_hash(self, tokens: Tensor) -> Tensor: + # Hash each (prev, cur) token pair into [0, mod); index mod (= bigram_vocab_size - 1) + # is reserved for position 0, which has no predecessor. + t = tokens.to(torch.int32) + mod = self.bigram_vocab_size - 1 + out = torch.empty_like(t) + out[..., 0] = mod + out[..., 1:] = torch.bitwise_xor(36313 * t[..., 1:], 27191 * t[..., :-1]) % mod + return out.long() + def forward(self, token_ids: Tensor) -> Tensor: + h = self.embed(self.bigram_hash(token_ids)) + if self.proj is not None: + h = self.proj(h) + return h * self.scale.to(dtype=h.dtype) + +class ValueEmbedding(nn.Module): + """Reinject token identity into attention values at specific layers. 
+ Each table maps vocab tokens to a low-dim embedding, projected to model_dim.""" + def __init__(self, vocab_size: int, ve_dim: int, model_dim: int): + super().__init__() + self.embed = nn.Embedding(vocab_size, ve_dim) + nn.init.normal_(self.embed.weight, std=0.01) + self.proj = CastedLinear(ve_dim, model_dim, bias=False) if ve_dim != model_dim else None + if self.proj is not None: + nn.init.zeros_(self.proj.weight) + self.scale = nn.Parameter(torch.tensor(0.1, dtype=torch.float32)) + def forward(self, token_ids: Tensor) -> Tensor: + h = self.embed(token_ids) + if self.proj is not None: + h = self.proj(h) + return h * self.scale.to(dtype=h.dtype) + +class MLP(nn.Module): + def __init__(self, dim: int, mlp_mult: int): + super().__init__() + # No CastedLinear -- weights come from banks + def forward(self, x: Tensor, up_w: Tensor, down_w: Tensor) -> Tensor: + # Squared LeakyReLU: leaky_relu(x)^2 equals x^2 for x > 0 and 0.25*x^2 for x < 0 + # at negative_slope=0.5 + x = F.leaky_relu(F.linear(x, up_w.to(x.dtype)), negative_slope=0.5) + return F.linear(x.square(), down_w.to(x.dtype)) + +class Block(nn.Module): + def __init__( + self, + dim: int, + num_heads: int, + num_kv_heads: int, + mlp_mult: int, + rope_base: float, + qk_gain_init: float, + layer_idx: int = 0, + ln_scale: bool = False, + dtg: bool = False, + gated_attention: bool = False, + value_residual: bool = False, + ): + super().__init__() + self.attn_norm = RMSNorm() + self.mlp_norm = RMSNorm() + self.attn = CausalSelfAttention(dim, num_heads, num_kv_heads, rope_base, qk_gain_init, + gated_attention=gated_attention, value_residual=value_residual) + self.mlp = MLP(dim, mlp_mult) + self.attn_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.mlp_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.resid_mix = nn.Parameter(torch.stack((torch.ones(dim), torch.zeros(dim))).float()) + self.ln_scale_factor = 1.0 / math.sqrt(layer_idx + 1) if ln_scale else 1.0 + if dtg: + self.dtg_gate = nn.Linear(dim, 1, bias=True) + nn.init.zeros_(self.dtg_gate.weight) + nn.init.constant_(self.dtg_gate.bias, 
2.0) + else: + self.dtg_gate = None + def forward(self, x: Tensor, x0: Tensor, q_w: Tensor, k_w: Tensor, v_w: Tensor, out_w: Tensor, up_w: Tensor, down_w: Tensor, v_embed: Tensor | None = None, v0: Tensor | None = None) -> tuple[Tensor, Tensor | None]: + mix = self.resid_mix.to(dtype=x.dtype) + x_in = mix[0][None, None, :] * x + mix[1][None, None, :] * x0 + attn_out, raw_v = self.attn(self.attn_norm(x_in) * self.ln_scale_factor, q_w, k_w, v_w, out_w, v_embed=v_embed, v0=v0) + x_out = x_in + self.attn_scale.to(dtype=x_in.dtype)[None, None, :] * attn_out + x_out = x_out + self.mlp_scale.to(dtype=x_out.dtype)[None, None, :] * self.mlp(self.mlp_norm(x_out) * self.ln_scale_factor, up_w, down_w) + if self.dtg_gate is not None: + gate = torch.sigmoid(self.dtg_gate(x_in.detach())) + x_out = x_in + gate * (x_out - x_in) + return x_out, raw_v + +class GPT(nn.Module): + def __init__( + self, + vocab_size: int, + num_layers: int, + model_dim: int, + num_heads: int, + num_kv_heads: int, + mlp_mult: int, + tie_embeddings: bool, + tied_embed_init_std: float, + logit_softcap: float, + rope_base: float, + qk_gain_init: float, + mtp_num_heads: int = 0, + mtp_loss_weight: float = 0.1, + bigram_vocab_size: int = 0, + bigram_dim: int = 128, + xsa_last_n: int = 0, + rope_dims: int = 0, + ln_scale: bool = False, + dtg: bool = False, + ve_enabled: bool = False, + ve_dim: int = 128, + ve_layers: str = "9,10", + gated_attention: bool = False, + value_residual: bool = False, + ): + super().__init__() + self._ve_target_dim = num_kv_heads * (model_dim // num_heads) # kv_dim for value projection + if logit_softcap <= 0.0: + raise ValueError(f"logit_softcap must be positive, got {logit_softcap}") + self.tie_embeddings = tie_embeddings + self.tied_embed_init_std = tied_embed_init_std + self.logit_softcap = logit_softcap + self.value_residual = value_residual + self.mtp_num_heads = mtp_num_heads + self.mtp_loss_weight = mtp_loss_weight + self.tok_emb = nn.Embedding(vocab_size, model_dim) + 
self.bigram = BigramHashEmbedding(bigram_vocab_size, bigram_dim, model_dim) if bigram_vocab_size > 0 else None + self.smear = SmearGate(model_dim) + self.num_encoder_layers = num_layers // 2 + self.num_decoder_layers = num_layers - self.num_encoder_layers + self.num_skip_weights = min(self.num_encoder_layers, self.num_decoder_layers) + self.skip_weights = nn.Parameter(torch.ones(self.num_skip_weights, model_dim, dtype=torch.float32)) + # Parameter banks: contiguous 3D tensors for batched optimizer + head_dim = model_dim // num_heads + kv_dim = num_kv_heads * head_dim + mlp_dim = int(mlp_mult * model_dim) + self.num_layers = num_layers + self.qo_bank = nn.Parameter(torch.empty(2 * num_layers, model_dim, model_dim)) + self.kv_bank = nn.Parameter(torch.empty(2 * num_layers, kv_dim, model_dim)) + self.mlp_up_bank = nn.Parameter(torch.empty(num_layers, mlp_dim, model_dim)) + self.mlp_down_bank = nn.Parameter(torch.empty(num_layers, model_dim, mlp_dim)) + self.blocks = nn.ModuleList( + [ + Block( + model_dim, + num_heads, + num_kv_heads, + mlp_mult, + rope_base, + qk_gain_init, + layer_idx=i, + ln_scale=ln_scale, + dtg=dtg, + gated_attention=gated_attention, + value_residual=value_residual, + ) + for i in range(num_layers) + ] + ) + if rope_dims > 0: + head_dim = model_dim // num_heads + for block in self.blocks: + block.attn.rope_dims = rope_dims + block.attn.rotary = Rotary(head_dim, base=rope_base, train_seq_len=1024, rope_dims=rope_dims) + self.ve_layer_indices = [int(x) for x in ve_layers.split(",") if x.strip()] if ve_enabled else [] + kv_dim_ve = self._ve_target_dim + if self.ve_layer_indices: + self.ve_shared = ValueEmbedding(vocab_size, ve_dim, kv_dim_ve) + self.ve_layer_scales = nn.ParameterList( + [nn.Parameter(torch.ones(1, dtype=torch.float32)) for _ in self.ve_layer_indices] + ) + else: + self.ve_shared = None + self.ve_layer_scales = nn.ParameterList() + self.value_embeds = nn.ModuleList() # keep empty for compat + self.final_norm = RMSNorm() + 
self.lm_head = None if tie_embeddings else CastedLinear(model_dim, vocab_size, bias=False) + if self.lm_head is not None: + self.lm_head._zero_init = True + self.mtp_heads = nn.ModuleList( + [CastedLinear(model_dim, vocab_size, bias=False) for _ in range(mtp_num_heads)] + ) + for head in self.mtp_heads: + head._zero_init = True + if xsa_last_n > 0: + for i in range(max(0, num_layers - xsa_last_n), num_layers): + self.blocks[i].attn.use_xsa = True + self._init_weights() + def _init_weights(self) -> None: + if self.tie_embeddings: + nn.init.normal_(self.tok_emb.weight, mean=0.0, std=self.tied_embed_init_std) + n = self.num_layers + proj_scale = 1.0 / math.sqrt(2 * n) + # Init banks: orthogonal, with proj layers scaled down and out/down zero-init + for i in range(n): + nn.init.orthogonal_(self.qo_bank.data[i], gain=1.0) # Q + nn.init.zeros_(self.qo_bank.data[n + i]) # Out (zero init) + nn.init.orthogonal_(self.kv_bank.data[i], gain=1.0) # K + nn.init.orthogonal_(self.kv_bank.data[n + i], gain=1.0) # V + nn.init.orthogonal_(self.mlp_up_bank.data[i], gain=1.0) # MLP up + nn.init.zeros_(self.mlp_down_bank.data[i]) # MLP down (zero init) + # Scale proj layers (out_proj and mlp_down are "proj" layers) + self.qo_bank.data[n + i].mul_(proj_scale) + self.mlp_down_bank.data[i].mul_(proj_scale) + # Init remaining nn.Linear modules (bigram proj, mtp heads, lm_head) + for name, module in self.named_modules(): + if isinstance(module, nn.Linear): + if getattr(module, "_zero_init", False): + nn.init.zeros_(module.weight) + elif module.weight.ndim == 2 and module.weight.shape[0] >= 64 and module.weight.shape[1] >= 64: + nn.init.orthogonal_(module.weight, gain=1.0) + def _get_ve(self, layer_idx: int, input_ids: Tensor, ve_cache: dict | None = None) -> Tensor | None: + """Get value embedding for a specific layer using shared table + per-layer scale.""" + if self.ve_shared is None or layer_idx not in self.ve_layer_indices: + return None + if ve_cache is not None and 've' not in 
ve_cache: + ve_cache['ve'] = self.ve_shared(input_ids) + ve_base = ve_cache['ve'] if ve_cache is not None else self.ve_shared(input_ids) + ve_idx = self.ve_layer_indices.index(layer_idx) + return ve_base * self.ve_layer_scales[ve_idx].to(dtype=ve_base.dtype) + def forward(self, input_ids: Tensor, target_ids: Tensor) -> Tensor: + n = self.num_layers + x = self.tok_emb(input_ids) + if self.bigram is not None: + x = x + self.bigram(input_ids) + x = F.rms_norm(x, (x.size(-1),)) + x = self.smear(x) + x0 = x + v0 = None + skips: list[Tensor] = [] + ve_cache: dict = {} + for i in range(self.num_encoder_layers): + ve = self._get_ve(i, input_ids, ve_cache) + x, raw_v = self.blocks[i](x, x0, + self.qo_bank[i], self.kv_bank[i], self.kv_bank[n + i], + self.qo_bank[n + i], self.mlp_up_bank[i], self.mlp_down_bank[i], + v_embed=ve, v0=v0) + if v0 is None and raw_v is not None: + v0 = raw_v + skips.append(x) + for i in range(self.num_decoder_layers): + bi = self.num_encoder_layers + i + if skips: + x = x + self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips.pop() + ve = self._get_ve(bi, input_ids, ve_cache) + x, _ = self.blocks[bi](x, x0, + self.qo_bank[bi], self.kv_bank[bi], self.kv_bank[n + bi], + self.qo_bank[n + bi], self.mlp_up_bank[bi], self.mlp_down_bank[bi], + v_embed=ve, v0=v0) + x = self.final_norm(x) + x_flat = x.reshape(-1, x.size(-1)) + targets = target_ids.reshape(-1) + if self.tie_embeddings: + logits_proj = F.linear(x_flat, self.tok_emb.weight) + else: + if self.lm_head is None: + raise RuntimeError("lm_head is required when tie_embeddings=False") + logits_proj = self.lm_head(x_flat) + logits = self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap) + main_loss = F.cross_entropy(logits.float(), targets, reduction="mean") + if self.training and self.mtp_num_heads > 0 and self.mtp_loss_weight > 0.0: + _, seqlen, dim = x.shape + mtp_loss_sum = x.new_zeros(()) + mtp_loss_count = 0 + for k, mtp_head in enumerate(self.mtp_heads): + valid_t = seqlen - 
(k + 1) + if valid_t <= 0: + continue + mtp_hidden = x[:, :valid_t, :].reshape(-1, dim) + mtp_targets = target_ids[:, k + 1 :].reshape(-1) + mtp_logits_proj = mtp_head(mtp_hidden) + mtp_logits = self.logit_softcap * torch.tanh(mtp_logits_proj / self.logit_softcap) + mtp_loss_sum = mtp_loss_sum + F.cross_entropy(mtp_logits.float(), mtp_targets, reduction="mean") + mtp_loss_count += 1 + if mtp_loss_count > 0: + main_loss = main_loss + self.mtp_loss_weight * (mtp_loss_sum / mtp_loss_count) + return main_loss + def forward_logits(self, input_ids: Tensor) -> Tensor: + """Return logits (bsz, seq_len, vocab) without computing loss.""" + n = self.num_layers + x = self.tok_emb(input_ids) + if self.bigram is not None: + x = x + self.bigram(input_ids) + x = F.rms_norm(x, (x.size(-1),)) + x = self.smear(x) + x0 = x + v0 = None + skips: list[Tensor] = [] + ve_cache: dict = {} + for i in range(self.num_encoder_layers): + ve = self._get_ve(i, input_ids, ve_cache) + x, raw_v = self.blocks[i](x, x0, + self.qo_bank[i], self.kv_bank[i], self.kv_bank[n + i], + self.qo_bank[n + i], self.mlp_up_bank[i], self.mlp_down_bank[i], + v_embed=ve, v0=v0) + if v0 is None and raw_v is not None: + v0 = raw_v + skips.append(x) + for i in range(self.num_decoder_layers): + bi = self.num_encoder_layers + i + if skips: + x = x + self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips.pop() + ve = self._get_ve(bi, input_ids, ve_cache) + x, _ = self.blocks[bi](x, x0, + self.qo_bank[bi], self.kv_bank[bi], self.kv_bank[n + bi], + self.qo_bank[n + bi], self.mlp_up_bank[bi], self.mlp_down_bank[bi], + v_embed=ve, v0=v0) + x = self.final_norm(x) + if self.tie_embeddings: + logits_proj = F.linear(x, self.tok_emb.weight) + else: + logits_proj = self.lm_head(x) + return self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap) + +# --- CTW: Context Tree Weighting (novel eval-time augmentation) --- +# Sparse lazy M-ary CTW — zero artifact cost, Bayesian-optimal +# Ref: Willems, Shtarkov, Tjalkens 
(1995) + +class CTWNode: + __slots__ = ['counts', 'total', 'children'] + def __init__(self): + self.counts = {} + self.total = 0 + self.children = {} + +class SparseCTW: + """Sparse M-ary context tree for eval-time augmentation; nodes allocated on-demand. + Counts are kept at every depth along the context path, but prediction reads only the + deepest node matching the current suffix (longest-match) -- a lazy approximation of + full depth-weighted CTW mixing.""" + def __init__(self, depth: int = 4, vocab_size: int = 1024, alpha: float = 0.5): + self.depth = depth + self.vocab_size = vocab_size + self.alpha = alpha + self.alpha_sum = alpha * vocab_size + self.root = CTWNode() + self.context: list[int] = [] + + def update(self, symbol: int): + node = self.root + node.counts[symbol] = node.counts.get(symbol, 0) + 1 + node.total += 1 + for i in range(min(len(self.context), self.depth)): + ctx_sym = self.context[-(i + 1)] + if ctx_sym not in node.children: + node.children[ctx_sym] = CTWNode() + node = node.children[ctx_sym] + node.counts[symbol] = node.counts.get(symbol, 0) + 1 + node.total += 1 + self.context.append(symbol) + if len(self.context) > self.depth + 1: + self.context = self.context[-(self.depth + 1):] + + def predict_logprobs(self, device: torch.device) -> Tensor: + """Additively smoothed (KT-style, alpha=0.5) log-probabilities from the deepest + node matching the current context.""" + node = self.root + for i in range(min(len(self.context), self.depth)): + ctx_sym = self.context[-(i + 1)] + if ctx_sym in node.children: + node = node.children[ctx_sym] + else: + break + probs = torch.full((self.vocab_size,), self.alpha / (node.total + self.alpha_sum), device=device) + for sym, count in node.counts.items(): + if sym < self.vocab_size: + probs[sym] = (count + self.alpha) / (node.total + self.alpha_sum) + return torch.log(probs.clamp(1e-10)) + + def mix_with_neural(self, neural_logits: Tensor, w_ctw: float = 0.1) -> Tensor: + """Mix CTW and neural log-probs log-linearly (a renormalized geometric mixture).""" + if not self.context: + return neural_logits + ctw_lp = self.predict_logprobs(neural_logits.device) + neural_lp = F.log_softmax(neural_logits, dim=-1) + mixed = (1 - w_ctw) * neural_lp + w_ctw * ctw_lp + return mixed - mixed.logsumexp(dim=-1, 
keepdim=True) + +# --- Sliding window evaluation --- + +def eval_val_sliding( + args: Hyperparameters, + base_model: nn.Module, + rank: int, + world_size: int, + device: torch.device, + val_tokens: Tensor, + base_bytes_lut: Tensor, + has_leading_space_lut: Tensor, + is_boundary_token_lut: Tensor, + stride: int, + batch_seqs: int = 32, + eval_seq_len: int | None = None, +) -> tuple[float, float]: + """Sliding window evaluation: each token scored with maximum context.""" + seq_len = eval_seq_len or args.train_seq_len + total_tokens = val_tokens.numel() - 1 + window_starts = [ws for ws in range(0, total_tokens, stride) + if min(ws + seq_len, total_tokens) - ws >= 1] + total_windows = len(window_starts) + my_s = (total_windows * rank) // world_size + my_e = (total_windows * (rank + 1)) // world_size + my_windows = window_starts[my_s:my_e] + loss_sum = torch.zeros((), device=device, dtype=torch.float64) + token_count = torch.zeros((), device=device, dtype=torch.float64) + byte_count = torch.zeros((), device=device, dtype=torch.float64) + base_model.eval() + compiled_logits = torch.compile(base_model.forward_logits, dynamic=False, fullgraph=True) + with torch.inference_mode(): + for bi in range(0, len(my_windows), batch_seqs): + batch_ws = my_windows[bi:bi + batch_seqs] + bsz = len(batch_ws) + x_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) + y_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) + wlens: list[int] = [] + for i, ws in enumerate(batch_ws): + end = min(ws + seq_len, total_tokens) + wlen = end - ws + wlens.append(wlen) + chunk = val_tokens[ws:end + 1].to(dtype=torch.int64, device=device) + x_batch[i, :wlen] = chunk[:-1] + y_batch[i, :wlen] = chunk[1:] + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + logits = compiled_logits(x_batch) + nll = F.cross_entropy( + logits.reshape(-1, logits.size(-1)).float(), + y_batch.reshape(-1), + reduction="none", + ).reshape(bsz, seq_len) + for i, ws in 
enumerate(batch_ws): + wlen = wlens[i] + s = 0 if ws == 0 else max(wlen - stride, 0) + scored_nll = nll[i, s:wlen].to(torch.float64) + loss_sum += scored_nll.sum() + token_count += float(wlen - s) + tgt = y_batch[i, s:wlen] + prev = x_batch[i, s:wlen] + tb = base_bytes_lut[tgt].to(torch.float64) + tb += (has_leading_space_lut[tgt] & ~is_boundary_token_lut[prev]).to(torch.float64) + byte_count += tb.sum() + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(loss_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(token_count, op=dist.ReduceOp.SUM) + dist.all_reduce(byte_count, op=dist.ReduceOp.SUM) + val_loss = (loss_sum / token_count).item() + bits_per_token = val_loss / math.log(2.0) + tokens_per_byte = token_count.item() / byte_count.item() + base_model.train() + return val_loss, bits_per_token * tokens_per_byte + + +def eval_val_sliding_ttt( + args: Hyperparameters, base_model: nn.Module, rank: int, world_size: int, + device: torch.device, val_tokens: Tensor, base_bytes_lut: Tensor, + has_leading_space_lut: Tensor, is_boundary_token_lut: Tensor, + stride: int, batch_seqs: int = 32, log0=print, +) -> tuple[float, float]: + """Legal score-first TTT (PR #461 recipe) with optional CTW augmentation: + score each chunk with sliding windows, then train on it. + Every token scored BEFORE any update that could use it. 
+ When CTW is enabled, neural logits are mixed with CTW predictions per-token.""" + seq_len = args.train_seq_len + total_tokens = val_tokens.numel() - 1 + ttt_chunk = args.ttt_chunk_tokens + + # CTW integration: build suffix tree from scored tokens, mix into logits + use_ctw = args.ctw_weight > 0 + ctw = SparseCTW(depth=args.ctw_depth, vocab_size=args.vocab_size) if use_ctw else None + w_ctw = args.ctw_weight + + # Pre-compute all window starts + window_starts = [ws for ws in range(0, total_tokens, stride) + if min(ws + seq_len, total_tokens) - ws >= stride or ws == 0] + + # Assign each window to a chunk based on the first token it scores + num_chunks = (total_tokens + ttt_chunk - 1) // ttt_chunk + chunk_windows: list[list[int]] = [[] for _ in range(num_chunks)] + for ws in window_starts: + end = min(ws + seq_len, total_tokens) + wlen = end - ws + s = 0 if ws == 0 else max(wlen - stride, 0) + scored_start = ws + s + ci = min(scored_start // ttt_chunk, num_chunks - 1) + chunk_windows[ci].append(ws) + + log0(f"ttt_sliding:start chunks={num_chunks} chunk_tokens={ttt_chunk} " + f"total_windows={len(window_starts)} stride={stride} " + f"ttt_lr={args.ttt_lr} ttt_epochs={args.ttt_epochs} " + f"freeze_blocks={args.ttt_freeze_blocks}" + f"{f' ctw_weight={w_ctw} ctw_depth={args.ctw_depth}' if use_ctw else ''}") + + loss_sum = torch.zeros((), device=device, dtype=torch.float64) + token_count = torch.zeros((), device=device, dtype=torch.float64) + byte_count = torch.zeros((), device=device, dtype=torch.float64) + + # Freeze first N blocks + frozen_block_ids = set(range(min(args.ttt_freeze_blocks, len(base_model.blocks)))) + ttt_params = [] + for name, p in base_model.named_parameters(): + freeze = False + for bi in frozen_block_ids: + if f"blocks.{bi}." 
in name: + freeze = True + break + if freeze: + p.requires_grad_(False) + else: + p.requires_grad_(True) + ttt_params.append(p) + + log0(f"ttt_sliding:params unfrozen={sum(p.numel() for p in ttt_params)} " + f"frozen={sum(p.numel() for p in base_model.parameters() if not p.requires_grad)}") + + optimizer = torch.optim.SGD(ttt_params, lr=args.ttt_lr, momentum=args.ttt_momentum) + t0 = time.perf_counter() + + for ci in range(num_chunks): + windows = chunk_windows[ci] + if not windows: + continue + chunk_start = ci * ttt_chunk + chunk_end = min((ci + 1) * ttt_chunk, total_tokens) + + # --- Phase 1: SCORE this chunk's windows (inference_mode) --- + my_s = (len(windows) * rank) // world_size + my_e = (len(windows) * (rank + 1)) // world_size + my_windows = windows[my_s:my_e] + + base_model.eval() + with torch.inference_mode(): + for bi in range(0, len(my_windows), batch_seqs): + batch_ws = my_windows[bi:bi + batch_seqs] + bsz = len(batch_ws) + x_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) + y_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) + wlens: list[int] = [] + for i, ws in enumerate(batch_ws): + end = min(ws + seq_len, total_tokens) + wlen = end - ws + wlens.append(wlen) + chunk_tok = val_tokens[ws:end + 1].to(dtype=torch.int64, device=device) + x_batch[i, :wlen] = chunk_tok[:-1] + y_batch[i, :wlen] = chunk_tok[1:] + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + logits = base_model.forward_logits(x_batch) + nll = F.cross_entropy( + logits.reshape(-1, logits.size(-1)).float(), + y_batch.reshape(-1), reduction="none", + ).reshape(bsz, seq_len) + for i, ws in enumerate(batch_ws): + wlen = wlens[i] + s = 0 if ws == 0 else max(wlen - stride, 0) + if use_ctw: + # Per-token CTW mixing on scored suffix + for t_idx in range(s, wlen): + token_logits = logits[i, t_idx, :].float() + mixed_lp = ctw.mix_with_neural(token_logits, w_ctw=w_ctw) + token_nll = F.cross_entropy(mixed_lp.unsqueeze(0), y_batch[i, 
t_idx:t_idx+1], reduction="sum") + loss_sum += token_nll.to(torch.float64) + ctw.update(y_batch[i, t_idx].item()) + else: + scored_nll = nll[i, s:wlen].to(torch.float64) + loss_sum += scored_nll.sum() + token_count += float(wlen - s) + tgt, prev = y_batch[i, s:wlen], x_batch[i, s:wlen] + tb = base_bytes_lut[tgt].to(torch.float64) + tb += (has_leading_space_lut[tgt] & ~is_boundary_token_lut[prev]).to(torch.float64) + byte_count += tb.sum() + + # --- Phase 2: TRAIN on this chunk (already scored = legal) --- + is_last_chunk = (ci == num_chunks - 1) + if not is_last_chunk and args.ttt_epochs > 0: + base_model.train() + chunk_seqs = (chunk_end - chunk_start) // seq_len + if chunk_seqs > 0: + cos_lr = args.ttt_lr * 0.5 * (1.0 + math.cos(math.pi * ci / max(num_chunks - 1, 1))) + for pg in optimizer.param_groups: + pg['lr'] = cos_lr + my_seq_s = (chunk_seqs * rank) // world_size + my_seq_e = (chunk_seqs * (rank + 1)) // world_size + my_chunk_seqs = my_seq_e - my_seq_s + for _ep in range(args.ttt_epochs): + for bs in range(0, my_chunk_seqs, args.ttt_batch_seqs): + be = min(bs + args.ttt_batch_seqs, my_chunk_seqs) + actual_bs = my_seq_s + bs + start_tok = chunk_start + actual_bs * seq_len + end_tok = chunk_start + (my_seq_s + be) * seq_len + 1 + if end_tok > val_tokens.numel(): + continue + local = val_tokens[start_tok:end_tok].to(device=device, dtype=torch.int64) + x = local[:-1].reshape(-1, seq_len) + y = local[1:].reshape(-1, seq_len) + optimizer.zero_grad(set_to_none=True) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + loss = base_model(x, y) + loss.backward() + if world_size > 1: + for p in ttt_params: + if p.grad is not None: + dist.all_reduce(p.grad, op=dist.ReduceOp.AVG) + torch.nn.utils.clip_grad_norm_(ttt_params, args.ttt_grad_clip) + optimizer.step() + + if rank == 0 and (ci % 10 == 0 or ci == num_chunks - 1): + elapsed = time.perf_counter() - t0 + rl = loss_sum.item() / max(token_count.item(), 1) + rbpb = rl / math.log(2.0) * 
(token_count.item() / max(byte_count.item(), 1)) if token_count.item() > 0 else 0.0 + log0(f" ttt_chunk [{ci+1}/{num_chunks}] bpb={rbpb:.6f} time={elapsed:.1f}s") + + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(loss_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(token_count, op=dist.ReduceOp.SUM) + dist.all_reduce(byte_count, op=dist.ReduceOp.SUM) + + val_loss = (loss_sum / token_count).item() + val_bpb = val_loss / math.log(2.0) * (token_count.item() / byte_count.item()) + + for p in base_model.parameters(): + p.requires_grad_(True) + base_model.eval() + + log0(f"ttt_sliding:done val_loss={val_loss:.6f} val_bpb={val_bpb:.6f} " + f"elapsed={time.perf_counter() - t0:.1f}s") + return val_loss, val_bpb + + +def eval_val_sliding_ctw( + args: Hyperparameters, base_model: nn.Module, rank: int, world_size: int, + device: torch.device, val_tokens: Tensor, base_bytes_lut: Tensor, + has_leading_space_lut: Tensor, is_boundary_token_lut: Tensor, + stride: int, log0=print, +) -> tuple[float, float]: + """CTW-augmented sliding window eval. Mixes neural logits with CTW predictions. 
+ Novel contribution: Bayesian-optimal sequential probability assignment at zero artifact cost.""" + seq_len = args.train_seq_len + total_tokens = val_tokens.numel() - 1 + window_starts = [ws for ws in range(0, total_tokens, stride) + if min(ws + seq_len, total_tokens) - ws >= 1] + total_windows = len(window_starts) + my_s = (total_windows * rank) // world_size + my_e = (total_windows * (rank + 1)) // world_size + my_windows = window_starts[my_s:my_e] + + loss_sum = torch.zeros((), device=device, dtype=torch.float64) + token_count = torch.zeros((), device=device, dtype=torch.float64) + byte_count = torch.zeros((), device=device, dtype=torch.float64) + ctw = SparseCTW(depth=args.ctw_depth, vocab_size=args.vocab_size, alpha=0.5) + w_ctw = args.ctw_weight + + log0(f"ctw_eval:start windows={len(my_windows)} ctw_weight={w_ctw} depth={args.ctw_depth}") + base_model.eval() + t0 = time.perf_counter() + with torch.inference_mode(): + for wi, ws in enumerate(my_windows): + if rank == 0 and wi % max(1, len(my_windows) // 10) == 0: + pct = 100.0 * wi / max(len(my_windows), 1) + log0(f" ctw_eval: {pct:.0f}% | {time.perf_counter() - t0:.1f}s") + end = min(ws + seq_len, total_tokens) + wlen = end - ws + chunk = val_tokens[ws:end + 1].to(dtype=torch.int64, device=device) + x = chunk[:-1].unsqueeze(0) + y = chunk[1:] + s = 0 if ws == 0 else max(wlen - stride, 0) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + logits = base_model.forward_logits(x) + logits_scored = logits[0, s:wlen, :].float() + targets_scored = y[s:wlen] + # Per-token CTW mixing + for t_idx in range(logits_scored.size(0)): + mixed = ctw.mix_with_neural(logits_scored[t_idx], w_ctw=w_ctw) + nll = F.cross_entropy(mixed.unsqueeze(0), targets_scored[t_idx:t_idx+1], reduction="sum") + loss_sum += nll.to(torch.float64) + ctw.update(targets_scored[t_idx].item()) + token_count += float(logits_scored.size(0)) + tgt = y[s:wlen] + prev = chunk[s:wlen] + tb = base_bytes_lut[tgt].to(torch.float64) + tb += 
(has_leading_space_lut[tgt] & ~is_boundary_token_lut[prev]).to(torch.float64) + byte_count += tb.sum() + + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(loss_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(token_count, op=dist.ReduceOp.SUM) + dist.all_reduce(byte_count, op=dist.ReduceOp.SUM) + val_loss = (loss_sum / token_count).item() + val_bpb = val_loss / math.log(2.0) * (token_count.item() / byte_count.item()) + log0(f"ctw_eval:done val_loss={val_loss:.6f} val_bpb={val_bpb:.6f} elapsed={time.perf_counter() - t0:.1f}s") + return val_loss, val_bpb + + +# --- GPTQ-lite int6 quantization --- + +def _classify_param(name: str) -> str: + if "tok_emb" in name or "lm_head" in name: + return "embed" + if ".mlp." in name: + return "mlp" + if ".attn." in name or (".proj." in name and ".mlp." not in name): + return "attn" + return "other" +def quantize_int6_per_row(t: Tensor, clip_range: int = 31) -> tuple[Tensor, Tensor]: + t32 = t.float() + if t32.ndim == 2: + best_q, best_s, best_err = None, None, float('inf') + for pct in [0.9990, 0.9995, 0.9999, 0.99999, 1.0]: + if pct < 1.0: + row_clip = torch.quantile(t32.abs(), pct, dim=1) + else: + row_clip = t32.abs().amax(dim=1) + s = (row_clip / clip_range).clamp_min(1.0 / clip_range).to(torch.float16) + q = torch.clamp(torch.round(t32 / s.float()[:, None]), -clip_range, clip_range).to(torch.int8) + recon = q.float() * s.float()[:, None] + err = (t32 - recon).pow(2).mean().item() + if err < best_err: + best_q, best_s, best_err = q, s, err + return best_q, best_s + amax = t32.abs().max().item() + scale = torch.tensor(amax / clip_range if amax > 0 else 1.0, dtype=torch.float16) + q = torch.clamp(torch.round(t32 / scale.float()), -clip_range, clip_range).to(torch.int8) + return q, scale + +def _unbank_state_dict(sd: dict[str, Tensor], num_layers: int) -> dict[str, Tensor]: + """Convert 3D bank tensors into individual 2D tensors with standard names.""" + out: dict[str, Tensor] = {} + n = num_layers + for name, 
tensor in sd.items(): + if name == "qo_bank": + for i in range(n): + out[f"blocks.{i}.attn.c_q.weight"] = tensor[i] + out[f"blocks.{i}.attn.proj.weight"] = tensor[n + i] + elif name == "kv_bank": + for i in range(n): + out[f"blocks.{i}.attn.c_k.weight"] = tensor[i] + out[f"blocks.{i}.attn.c_v.weight"] = tensor[n + i] + elif name == "mlp_up_bank": + for i in range(n): + out[f"blocks.{i}.mlp.fc.weight"] = tensor[i] + elif name == "mlp_down_bank": + for i in range(n): + out[f"blocks.{i}.mlp.proj.weight"] = tensor[i] + else: + out[name] = tensor + return out + +def _rebank_state_dict(sd: dict[str, Tensor], num_layers: int, template_sd: dict[str, Tensor]) -> dict[str, Tensor]: + """Convert individual 2D tensors back into 3D bank tensors.""" + out: dict[str, Tensor] = {} + n = num_layers + # Reconstruct banks from individual weight keys + qo_slices = [None] * (2 * n) + kv_slices = [None] * (2 * n) + up_slices = [None] * n + down_slices = [None] * n + consumed = set() + for i in range(n): + qk = f"blocks.{i}.attn.c_q.weight" + if qk in sd: + qo_slices[i] = sd[qk] + consumed.add(qk) + ok = f"blocks.{i}.attn.proj.weight" + if ok in sd: + qo_slices[n + i] = sd[ok] + consumed.add(ok) + kk = f"blocks.{i}.attn.c_k.weight" + if kk in sd: + kv_slices[i] = sd[kk] + consumed.add(kk) + vk = f"blocks.{i}.attn.c_v.weight" + if vk in sd: + kv_slices[n + i] = sd[vk] + consumed.add(vk) + fk = f"blocks.{i}.mlp.fc.weight" + if fk in sd: + up_slices[i] = sd[fk] + consumed.add(fk) + dk = f"blocks.{i}.mlp.proj.weight" + if dk in sd: + down_slices[i] = sd[dk] + consumed.add(dk) + out["qo_bank"] = torch.stack(qo_slices).to(dtype=template_sd["qo_bank"].dtype) + out["kv_bank"] = torch.stack(kv_slices).to(dtype=template_sd["kv_bank"].dtype) + out["mlp_up_bank"] = torch.stack(up_slices).to(dtype=template_sd["mlp_up_bank"].dtype) + out["mlp_down_bank"] = torch.stack(down_slices).to(dtype=template_sd["mlp_down_bank"].dtype) + for name, tensor in sd.items(): + if name not in consumed: + out[name] = 
tensor + return out + +def mixed_quantize_int6(state_dict: dict[str, Tensor], int6_cats: set[str]): + result: dict[str, Tensor] = {} + meta: dict[str, object] = {} + for name, tensor in state_dict.items(): + t = tensor.detach().cpu().contiguous() + cat = _classify_param(name) + if not t.is_floating_point() or t.numel() <= 65536: + result[name] = t.to(torch.float16) if t.is_floating_point() else t + meta[name] = "passthrough" + continue + if any(p in name for p in CONTROL_TENSOR_NAME_PATTERNS): + result[name] = t.float() + meta[name] = "passthrough_ctrl" + continue + if cat in int6_cats and t.ndim >= 1: + q, s = quantize_int6_per_row(t) + result[name + ".q"] = q + result[name + ".scale"] = s + meta[name] = {"type": "int6"} + else: + q, s = quantize_float_tensor(t) + result[name + ".q"] = q + result[name + ".scale"] = s + meta[name] = {"type": "int8"} + return result, meta +def dequantize_mixed_int6(result: dict[str, Tensor], meta: dict[str, object], + template_sd: dict[str, Tensor]) -> dict[str, Tensor]: + out: dict[str, Tensor] = {} + for name, orig in template_sd.items(): + info = meta.get(name) + if info is None: + continue + orig_dtype = orig.dtype + if info in ("passthrough", "passthrough_ctrl", "passthrough_fp16"): + t = result[name] + if t.dtype == torch.float16 and orig_dtype in (torch.float32, torch.bfloat16): + t = t.to(orig_dtype) + out[name] = t + continue + q, s = result[name + ".q"], result[name + ".scale"] + if s.ndim > 0: + out[name] = (q.float() * s.float().view(q.shape[0], *([1] * (q.ndim - 1)))).to(orig_dtype) + else: + out[name] = (q.float() * float(s.item())).to(orig_dtype) + return out + +# --- Training --- + +def main() -> None: + code = Path(__file__).read_text(encoding="utf-8") + args = Hyperparameters() + # zeropower_via_newtonschulz5 runs eagerly with 
bmm -- do NOT compile + distributed = "RANK" in os.environ and "WORLD_SIZE" in os.environ + rank = int(os.environ.get("RANK", "0")) + world_size = int(os.environ.get("WORLD_SIZE", "1")) + local_rank = int(os.environ.get("LOCAL_RANK", "0")) + if world_size <= 0: + raise ValueError(f"WORLD_SIZE must be positive, got {world_size}") + if 8 % world_size != 0: + raise ValueError(f"WORLD_SIZE={world_size} must divide 8 so grad_accum_steps stays integral") + grad_accum_steps = 8 // world_size + grad_scale = 1.0 / grad_accum_steps + if not torch.cuda.is_available(): + raise RuntimeError("CUDA is required") + device = torch.device("cuda", local_rank) + torch.cuda.set_device(device) + if distributed: + dist.init_process_group(backend="nccl", device_id=device) + dist.barrier() + master_process = rank == 0 + torch.backends.cuda.matmul.allow_tf32 = True + torch.backends.cudnn.allow_tf32 = True + from torch.backends.cuda import enable_cudnn_sdp, enable_flash_sdp, enable_math_sdp, enable_mem_efficient_sdp + enable_cudnn_sdp(False) + enable_flash_sdp(True) + enable_mem_efficient_sdp(False) + enable_math_sdp(False) + logfile = None + if master_process: + os.makedirs("logs", exist_ok=True) + logfile = f"logs/{args.run_id}.txt" + print(logfile) + def log0(msg: str, console: bool = True) -> None: + if not master_process: + return + if console: + print(msg) + if logfile is not None: + with open(logfile, "a", encoding="utf-8") as f: + print(msg, file=f) + log0(code, console=False) + log0("=" * 100, console=False) + log0(f"Running Python {sys.version}", console=False) + log0(f"Running PyTorch {torch.__version__}", console=False) + log0( + subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, check=False).stdout, + console=False, + ) + log0("=" * 100, console=False) + random.seed(args.seed) + np.random.seed(args.seed) + torch.manual_seed(args.seed) + torch.cuda.manual_seed_all(args.seed) + if not args.tokenizer_path.endswith(".model"): + raise 
ValueError(f"Script only set up for SentencePiece .model file: {args.tokenizer_path}") + sp = spm.SentencePieceProcessor(model_file=args.tokenizer_path) + if int(sp.vocab_size()) != args.vocab_size: + raise ValueError( + f"VOCAB_SIZE={args.vocab_size} does not match tokenizer vocab_size={int(sp.vocab_size())}" + ) + dataset_dir = Path(args.data_path).resolve() + actual_train_files = len(list(dataset_dir.glob("fineweb_train_*.bin"))) + effective_eval_seq_len = args.eval_seq_len if args.eval_seq_len > 0 else args.train_seq_len + val_seq_len = max(args.train_seq_len, effective_eval_seq_len) + val_tokens = load_validation_tokens(args.val_files, val_seq_len) + base_bytes_lut, has_leading_space_lut, is_boundary_token_lut = build_sentencepiece_luts( + sp, args.vocab_size, device + ) + log0(f"val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path={args.tokenizer_path}") + log0(f"train_loader:dataset:{dataset_dir.name} train_shards:{actual_train_files}") + log0(f"val_loader:shards pattern={args.val_files} tokens:{val_tokens.numel() - 1}") + CastedLinear._qat_enabled = args.qat_enabled + base_model = GPT( + vocab_size=args.vocab_size, + num_layers=args.num_layers, + model_dim=args.model_dim, + num_heads=args.num_heads, + num_kv_heads=args.num_kv_heads, + mlp_mult=args.mlp_mult, + tie_embeddings=args.tie_embeddings, + tied_embed_init_std=args.tied_embed_init_std, + logit_softcap=args.logit_softcap, + rope_base=args.rope_base, + qk_gain_init=args.qk_gain_init, + mtp_num_heads=args.mtp_num_heads, + mtp_loss_weight=args.mtp_loss_weight, + bigram_vocab_size=args.bigram_vocab_size, + bigram_dim=args.bigram_dim, + xsa_last_n=args.xsa_last_n, + rope_dims=args.rope_dims, + ln_scale=args.ln_scale, + dtg=args.dtg_enabled, + ve_enabled=args.ve_enabled, + ve_dim=args.ve_dim, + ve_layers=args.ve_layers, + gated_attention=args.gated_attention, + value_residual=args.value_residual, + ).to(device).bfloat16() + # Banks stay FP32 (like CastedLinear weights), cast to BF16 in forward + 
base_model.qo_bank.data = base_model.qo_bank.data.float() + base_model.kv_bank.data = base_model.kv_bank.data.float() + base_model.mlp_up_bank.data = base_model.mlp_up_bank.data.float() + base_model.mlp_down_bank.data = base_model.mlp_down_bank.data.float() + for module in base_model.modules(): + if isinstance(module, CastedLinear): + module.float() + restore_low_dim_params_to_fp32(base_model) + # No DDP -- Parallel Muon handles bank grad communication via reduce-scatter, + # and non-bank grads are manually all-reduced before Adam steps. + compiled_model = torch.compile(base_model, dynamic=False, fullgraph=True) + model = compiled_model + + # Optimizer split: + # - 4 parameter banks -> Muon (batched Newton-Schulz) + # - token embedding -> Adam + # - scalars/control tensors -> Adam + # - bigram proj, mtp heads, VE proj -> Adam (small matrix params not worth banking) + matrix_params = [ + base_model.qo_bank, base_model.kv_bank, + base_model.mlp_up_bank, base_model.mlp_down_bank, + ] + block_named_params = list(base_model.blocks.named_parameters()) + scalar_params = [ + p + for name, p in block_named_params + if p.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) + ] + if base_model.skip_weights.numel() > 0: + scalar_params.append(base_model.skip_weights) + scalar_params.append(base_model.smear.gate) + if base_model.bigram is not None: + scalar_params.append(base_model.bigram.scale) + token_lr = args.tied_embed_lr if args.tie_embeddings else args.embed_lr + tok_params = [{"params": [base_model.tok_emb.weight], "lr": token_lr, "base_lr": token_lr}] + if base_model.bigram is not None: + tok_params.append({"params": [base_model.bigram.embed.weight], "lr": token_lr, "base_lr": token_lr}) + if base_model.bigram.proj is not None: + scalar_params.append(base_model.bigram.proj.weight) + if base_model.ve_shared is not None: + tok_params.append({"params": [base_model.ve_shared.embed.weight], "lr": token_lr, "base_lr": token_lr}) + if 
base_model.ve_shared.proj is not None: + scalar_params.append(base_model.ve_shared.proj.weight) + scalar_params.append(base_model.ve_shared.scale) + for s in base_model.ve_layer_scales: + scalar_params.append(s) + optimizer_tok = torch.optim.AdamW( + tok_params, + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + weight_decay=args.adam_wd, + fused=True, + ) + optimizer_muon = Muon( + matrix_params, + lr=args.matrix_lr, + momentum=args.muon_momentum, + backend_steps=args.muon_backend_steps, + weight_decay=args.muon_wd, + ) + for group in optimizer_muon.param_groups: + group["base_lr"] = args.matrix_lr + optimizer_scalar = torch.optim.AdamW( + [{"params": scalar_params, "lr": args.scalar_lr, "base_lr": args.scalar_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + weight_decay=args.adam_wd, + fused=True, + ) + # Non-bank params that need manual all-reduce (replicated across GPUs) + replicated_params = list(optimizer_tok.param_groups[0]["params"]) + for pg in optimizer_tok.param_groups[1:]: + replicated_params.extend(pg["params"]) + replicated_params.extend(scalar_params) + + optimizer_head = None + if base_model.lm_head is not None: + optimizer_head = torch.optim.Adam( + [{"params": [base_model.lm_head.weight], "lr": args.head_lr, "base_lr": args.head_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=True, + ) + replicated_params.append(base_model.lm_head.weight) + optimizers: list[torch.optim.Optimizer] = [optimizer_tok, optimizer_muon, optimizer_scalar] + if optimizer_head is not None: + optimizers.append(optimizer_head) + n_params = sum(p.numel() for p in base_model.parameters()) + mtp_params = sum(p.numel() for p in base_model.mtp_heads.parameters()) + log0(f"model_params:{n_params}") + log0(f"mtp_num_heads:{args.mtp_num_heads} mtp_loss_weight:{args.mtp_loss_weight} mtp_params:{mtp_params}") + xsa_layers = [i for i, b in enumerate(base_model.blocks) if b.attn.use_xsa] + log0(f"XSA:last_{args.xsa_last_n} 
active_layers:{xsa_layers}") + log0(f"world_size:{world_size} grad_accum_steps:{grad_accum_steps}") + log0("sdp_backends:cudnn=False flash=True mem_efficient=False math=False") + log0(f"attention_mode:gqa num_heads:{args.num_heads} num_kv_heads:{args.num_kv_heads}") + log0( + f"tie_embeddings:{args.tie_embeddings} embed_lr:{token_lr} " + f"head_lr:{args.head_lr if base_model.lm_head is not None else 0.0} " + f"matrix_lr:{args.matrix_lr} scalar_lr:{args.scalar_lr}" + ) + log0( + f"train_batch_tokens:{args.train_batch_tokens} train_seq_len:{args.train_seq_len} " + f"iterations:{args.iterations} warmup_steps:{args.warmup_steps} " + f"max_wallclock_seconds:{args.max_wallclock_seconds:.3f}" + ) + log0(f"seed:{args.seed}") + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + def zero_grad_all() -> None: + for opt in optimizers: + opt.zero_grad(set_to_none=True) + max_wallclock_ms = 1000.0 * args.max_wallclock_seconds if args.max_wallclock_seconds > 0 else None + def lr_mul(step: int, elapsed_ms: float) -> float: + if args.warmdown_iters <= 0: + return 1.0 + if max_wallclock_ms is None: + warmdown_start = max(args.iterations - args.warmdown_iters, 0) + return max((args.iterations - step) / max(args.warmdown_iters, 1), 0.0) if warmdown_start <= step < args.iterations else 1.0 + step_ms = elapsed_ms / max(step, 1) + warmdown_ms = args.warmdown_iters * step_ms + remaining_ms = max(max_wallclock_ms - elapsed_ms, 0.0) + return remaining_ms / max(warmdown_ms, 1e-9) if remaining_ms <= warmdown_ms else 1.0 + if args.warmup_steps > 0: + initial_model_state = {name: tensor.detach().cpu().clone() for name, tensor in base_model.state_dict().items()} + initial_optimizer_states = [copy.deepcopy(opt.state_dict()) for opt in optimizers] + model.train() + for warmup_step in range(args.warmup_steps): + zero_grad_all() + for micro_step in range(grad_accum_steps): + x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) 
+ with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + warmup_loss = model(x, y) + (warmup_loss * grad_scale).backward() + # All-reduce all grads for warmup (simple, not optimized) + if distributed: + for p in base_model.parameters(): + if p.grad is not None: + dist.all_reduce(p.grad, op=dist.ReduceOp.AVG) + for opt in optimizers: + opt.step() + zero_grad_all() + if args.warmup_steps <= 20 or (warmup_step + 1) % 10 == 0 or warmup_step + 1 == args.warmup_steps: + log0(f"warmup_step:{warmup_step + 1}/{args.warmup_steps}") + # Warmup done -- restore the saved pre-warmup model/optimizer state and recreate the data loader, so the timed run below starts from scratch + base_model.load_state_dict(initial_model_state, strict=True) + for opt, state in zip(optimizers, initial_optimizer_states, strict=True): + opt.load_state_dict(state) + zero_grad_all() + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + swa_state: dict[str, Tensor] | None = None + swa_count = 0 + from collections import deque + lawa_queue: deque[dict[str, Tensor]] = deque(maxlen=args.lawa_k) + ema_state = {name: t.detach().float().clone() for name, t in base_model.state_dict().items()} + ema_decay = 0.997 + training_time_ms = 0.0 + stop_after_step: int | None = None + torch.cuda.synchronize() + t0 = time.perf_counter() + step = 0 + while True: + last_step = step == args.iterations or (stop_after_step is not None and step >= stop_after_step) + should_validate = last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0) + if should_validate: + torch.cuda.synchronize() + training_time_ms += 1000.0 * (time.perf_counter() - t0) + val_loss, val_bpb = eval_val( + args, + model, + rank, + world_size, + device, + grad_accum_steps, + val_tokens, + base_bytes_lut, + has_leading_space_lut, + is_boundary_token_lut, + ) + log0( + f"step:{step}/{args.iterations} val_loss:{val_loss:.4f} val_bpb:{val_bpb:.4f} " + f"train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms / max(step, 1):.2f}ms" + ) + torch.cuda.synchronize() + t0 = time.perf_counter() + if last_step: + if 
stop_after_step is not None and step < args.iterations: + log0( + f"stopping_early: wallclock_cap train_time:{training_time_ms:.0f}ms " + f"step:{step}/{args.iterations}" + ) + break + elapsed_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + scale = lr_mul(step, elapsed_ms) + if args.late_qat_threshold > 0 and scale < args.late_qat_threshold and not CastedLinear._qat_enabled: + CastedLinear._qat_enabled = True + log0(f"late_qat:enabled step:{step} scale:{scale:.4f}") + zero_grad_all() + train_loss = torch.zeros((), device=device) + for micro_step in range(grad_accum_steps): + x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + loss = model(x, y) + train_loss += loss.detach() + (loss * grad_scale).backward() + train_loss /= grad_accum_steps + frac = min(step / args.muon_momentum_warmup_steps, 1.0) if args.muon_momentum_warmup_steps > 0 else 1.0 + muon_momentum = (1 - frac) * args.muon_momentum_warmup_start + frac * args.muon_momentum + for group in optimizer_muon.param_groups: + group["momentum"] = muon_momentum + for opt in optimizers: + for group in opt.param_groups: + group["lr"] = group["base_lr"] * scale + if args.grad_clip_norm > 0: + torch.nn.utils.clip_grad_norm_(base_model.parameters(), args.grad_clip_norm) + # === 3-phase overlapped optimizer step === + # Phase 1: Launch async reduce-scatter for banks (biggest first) + optimizer_muon.launch_reduce_scatters() + # Phase 2: All-reduce non-bank grads + step Adam (while bank RS is in-flight) + if distributed: + for p in replicated_params: + if p.grad is not None: + dist.all_reduce(p.grad, op=dist.ReduceOp.AVG) + optimizer_tok.step() + optimizer_scalar.step() + if optimizer_head is not None: + optimizer_head.step() + # Phase 3: Wait for RS, local NS5, all-gather (banks processed last) + optimizer_muon.step() + zero_grad_all() + # EMA update + with torch.no_grad(): + for name, t 
in base_model.state_dict().items(): + ema_state[name].mul_(ema_decay).add_(t.detach().float(), alpha=1.0 - ema_decay) + step += 1 + approx_training_time_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + if args.swa_enabled and scale < 0.2 and step % args.swa_every == 0: + if swa_state is None: + swa_state = {name: t.detach().cpu().clone() for name, t in base_model.state_dict().items()} + swa_count = 1 + log0(f"swa:start step:{step}") + else: + for name, t in base_model.state_dict().items(): + swa_state[name] += t.detach().cpu() + swa_count += 1 + if args.lawa_enabled and step % args.lawa_freq == 0: + lawa_queue.append({name: t.detach().cpu().clone() for name, t in base_model.state_dict().items()}) + should_log_train = ( + args.train_log_every > 0 + and (step <= 10 or step % args.train_log_every == 0 or stop_after_step is not None) + ) + if should_log_train: + log0( + f"step:{step}/{args.iterations} train_loss:{train_loss.item():.4f} " + f"train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms / step:.2f}ms" + ) + reached_cap = max_wallclock_ms is not None and approx_training_time_ms >= max_wallclock_ms + if distributed and max_wallclock_ms is not None: + reached_cap_tensor = torch.tensor(int(reached_cap), device=device) + dist.all_reduce(reached_cap_tensor, op=dist.ReduceOp.MAX) + reached_cap = bool(reached_cap_tensor.item()) + if stop_after_step is None and reached_cap: + stop_after_step = step + log0( + f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB " + f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB" + ) + # Apply weight averaging + if args.lawa_enabled and len(lawa_queue) > 1: + log0(f"lawa:applying LAWA averaging k={len(lawa_queue)}") + current_state = base_model.state_dict() + avg_state = {name: torch.zeros(t.shape, dtype=torch.float32, device='cpu') for name, t in current_state.items()} + for snap in lawa_queue: + for name in avg_state: + avg_state[name] += 
snap[name].float() + for name in avg_state: + avg_state[name] /= len(lawa_queue) + avg_state[name] = avg_state[name].to(dtype=current_state[name].dtype) + base_model.load_state_dict(avg_state, strict=True) + else: + log0("ema:applying EMA weights") + current_state = base_model.state_dict() + avg_state = {name: t.to(dtype=current_state[name].dtype) for name, t in ema_state.items()} + base_model.load_state_dict(avg_state, strict=True) + torch.cuda.synchronize() + t_diag = time.perf_counter() + diag_val_loss, diag_val_bpb = eval_val( + args, compiled_model, rank, world_size, device, grad_accum_steps, + val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, + ) + torch.cuda.synchronize() + log0( + f"DIAGNOSTIC post_ema val_loss:{diag_val_loss:.4f} val_bpb:{diag_val_bpb:.4f} " + f"eval_time:{1000.0 * (time.perf_counter() - t_diag):.0f}ms" + ) + full_state_dict = base_model.state_dict() + export_sd = {k: v for k, v in full_state_dict.items() if "mtp_heads" not in k} + excluded_mtp = sum(int(t.numel()) for k, t in full_state_dict.items() if "mtp_heads" in k) + if excluded_mtp > 0: + log0(f"export_excluding_mtp_params:{excluded_mtp}") + if master_process: + torch.save(export_sd, "final_model.pt") + model_bytes = os.path.getsize("final_model.pt") + code_bytes = len(code.encode("utf-8")) + log0(f"Serialized model: {model_bytes} bytes") + log0(f"Code size: {code_bytes} bytes") + # Unbank 3D tensors into individual 2D tensors for quantization + sd_cpu = {k: v.detach().cpu() for k, v in export_sd.items()} + unbanked_sd = _unbank_state_dict(sd_cpu, args.num_layers) + quant_result, quant_meta = mixed_quantize_int6(unbanked_sd, {"mlp", "attn"}) + quant_buf = io.BytesIO() + torch.save({"w": quant_result, "m": quant_meta}, quant_buf) + quant_raw = quant_buf.getvalue() + quant_blob = lzma.compress(quant_raw, preset=6) + if master_process: + with open("final_model.int6.ptz", "wb") as f: + f.write(quant_blob) + quant_file_bytes = len(quant_blob) + code_bytes = 
len(code.encode("utf-8")) + log0(f"Serialized model int6+lzma: {quant_file_bytes} bytes") + log0(f"Total submission size int6+lzma: {quant_file_bytes + code_bytes} bytes") + if distributed: + dist.barrier() + with open("final_model.int6.ptz", "rb") as f: + quant_blob_disk = f.read() + quant_state = torch.load( + io.BytesIO(lzma.decompress(quant_blob_disk)), + map_location="cpu", + ) + deq_unbanked = dequantize_mixed_int6(quant_state["w"], quant_state["m"], unbanked_sd) + # Re-bank the dequantized tensors + deq_state = _rebank_state_dict(deq_unbanked, args.num_layers, sd_cpu) + eval_model = GPT( + vocab_size=args.vocab_size, num_layers=args.num_layers, model_dim=args.model_dim, + num_heads=args.num_heads, num_kv_heads=args.num_kv_heads, mlp_mult=args.mlp_mult, + tie_embeddings=args.tie_embeddings, tied_embed_init_std=args.tied_embed_init_std, + logit_softcap=args.logit_softcap, rope_base=args.rope_base, qk_gain_init=args.qk_gain_init, + mtp_num_heads=0, mtp_loss_weight=0.0, + bigram_vocab_size=args.bigram_vocab_size, bigram_dim=args.bigram_dim, + xsa_last_n=args.xsa_last_n, + rope_dims=args.rope_dims, ln_scale=args.ln_scale, dtg=args.dtg_enabled, + ve_enabled=args.ve_enabled, ve_dim=args.ve_dim, ve_layers=args.ve_layers, + gated_attention=args.gated_attention, value_residual=args.value_residual, + ).to(device).bfloat16() + eval_model.qo_bank.data = eval_model.qo_bank.data.float() + eval_model.kv_bank.data = eval_model.kv_bank.data.float() + eval_model.mlp_up_bank.data = eval_model.mlp_up_bank.data.float() + eval_model.mlp_down_bank.data = eval_model.mlp_down_bank.data.float() + for m in eval_model.modules(): + if isinstance(m, CastedLinear): + m.float() + restore_low_dim_params_to_fp32(eval_model) + eval_model.load_state_dict(deq_state, strict=True) + compiled_eval = torch.compile(eval_model, dynamic=False, fullgraph=True) + torch.cuda.synchronize() + t_qeval = time.perf_counter() + q_val_loss, q_val_bpb = eval_val( + args, compiled_eval, rank, world_size, device, 
grad_accum_steps, + val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, + eval_seq_len=effective_eval_seq_len, + ) + torch.cuda.synchronize() + log0( + f"final_int6_roundtrip val_loss:{q_val_loss:.4f} val_bpb:{q_val_bpb:.4f} " + f"eval_time:{1000.0 * (time.perf_counter() - t_qeval):.0f}ms" + ) + log0(f"final_int6_roundtrip_exact val_loss:{q_val_loss:.8f} val_bpb:{q_val_bpb:.8f}") + sw_seq_len = effective_eval_seq_len + if args.eval_stride > 0 and args.eval_stride < sw_seq_len: + torch.cuda.synchronize() + t_slide = time.perf_counter() + sw_val_loss, sw_val_bpb = eval_val_sliding( + args, eval_model, rank, world_size, device, + val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, + stride=args.eval_stride, + eval_seq_len=sw_seq_len, + ) + torch.cuda.synchronize() + log0( + f"final_int6_sliding_window val_loss:{sw_val_loss:.4f} val_bpb:{sw_val_bpb:.4f} " + f"stride:{args.eval_stride} eval_time:{1000.0 * (time.perf_counter() - t_slide):.0f}ms" + ) + log0(f"final_int6_sliding_window_exact val_loss:{sw_val_loss:.8f} val_bpb:{sw_val_bpb:.8f}") + log0(f"final_int8_zlib_roundtrip_exact val_loss:{sw_val_loss:.8f} val_bpb:{sw_val_bpb:.8f}") + if args.eval_stride != 64 and 64 < sw_seq_len: + torch.cuda.synchronize() + t_slide64 = time.perf_counter() + sw64_val_loss, sw64_val_bpb = eval_val_sliding( + args, eval_model, rank, world_size, device, + val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, + stride=64, + eval_seq_len=sw_seq_len, + ) + torch.cuda.synchronize() + log0( + f"final_int6_sliding_window_s64 val_loss:{sw64_val_loss:.4f} val_bpb:{sw64_val_bpb:.4f} " + f"stride:64 eval_time:{1000.0 * (time.perf_counter() - t_slide64):.0f}ms" + ) + log0(f"final_int6_sliding_window_s64_exact val_loss:{sw64_val_loss:.8f} val_bpb:{sw64_val_bpb:.8f}") + log0(f"final_int8_zlib_roundtrip_exact val_loss:{sw64_val_loss:.8f} val_bpb:{sw64_val_bpb:.8f}") + # Legal score-first TTT (PR #461 recipe) + if args.ttt_enabled: + 
torch.cuda.synchronize() + t_ttt = time.perf_counter() + ttt_loss, ttt_bpb = eval_val_sliding_ttt( + args, eval_model, rank, world_size, device, + val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, + stride=args.eval_stride, log0=log0, + ) + torch.cuda.synchronize() + log0(f"legal_ttt val_loss:{ttt_loss:.4f} val_bpb:{ttt_bpb:.4f} " + f"eval_time:{1000.0 * (time.perf_counter() - t_ttt):.0f}ms") + log0(f"legal_ttt_exact val_loss:{ttt_loss:.8f} val_bpb:{ttt_bpb:.8f}") + # Novel: CTW eval-time augmentation (standalone, when TTT is disabled) + # When TTT is enabled, CTW is already integrated into the TTT scoring loop above + if args.ctw_weight > 0 and not args.ttt_enabled: + torch.cuda.synchronize() + t_ctw = time.perf_counter() + ctw_loss, ctw_bpb = eval_val_sliding_ctw( + args, eval_model, rank, world_size, device, + val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, + stride=args.eval_stride, log0=log0, + ) + torch.cuda.synchronize() + log0(f"ctw_augmented val_loss:{ctw_loss:.4f} val_bpb:{ctw_bpb:.4f} " + f"eval_time:{1000.0 * (time.perf_counter() - t_ctw):.0f}ms") + log0(f"ctw_augmented_exact val_loss:{ctw_loss:.8f} val_bpb:{ctw_bpb:.8f}") + if distributed: + dist.destroy_process_group() +if __name__ == "__main__": + main() diff --git a/records/track_non_record_16mb/2026-03-29_CTW_on_LeakyReLU_TTT_ParallelMuon/train_gpt_improved_ctw_implementation_dont_have_compute_left_to_run_somebody_else_do_it.py b/records/track_non_record_16mb/2026-03-29_CTW_on_LeakyReLU_TTT_ParallelMuon/train_gpt_improved_ctw_implementation_dont_have_compute_left_to_run_somebody_else_do_it.py new file mode 100644 index 0000000000..4c2638273c --- /dev/null +++ b/records/track_non_record_16mb/2026-03-29_CTW_on_LeakyReLU_TTT_ParallelMuon/train_gpt_improved_ctw_implementation_dont_have_compute_left_to_run_somebody_else_do_it.py @@ -0,0 +1,2228 @@ +from __future__ import annotations +import copy +import glob +import io +import lzma +import math +import 
os +import random +import subprocess +import sys +import time +import uuid +import zlib +from pathlib import Path +try: + import zstandard + _COMPRESSOR = "zstd" +except ImportError: + _COMPRESSOR = "zlib" +import numpy as np +import sentencepiece as spm +import torch +import torch.distributed as dist +import torch.nn.functional as F +from torch import Tensor, nn +from torch.nn.parallel import DistributedDataParallel as DDP +from flash_attn_interface import flash_attn_func as flash_attn_3_func +class Hyperparameters: + data_path = os.environ.get("DATA_PATH", "./data/datasets/fineweb10B_sp1024") + train_files = os.path.join(data_path, "fineweb_train_*.bin") + val_files = os.path.join(data_path, "fineweb_val_*.bin") + tokenizer_path = os.environ.get("TOKENIZER_PATH", "./data/tokenizers/fineweb_1024_bpe.model") + run_id = os.environ.get("RUN_ID", str(uuid.uuid4())) + seed = int(os.environ.get("SEED", 1337)) + val_batch_size = int(os.environ.get("VAL_BATCH_SIZE", 524_288)) + val_loss_every = int(os.environ.get("VAL_LOSS_EVERY", 4000)) + train_log_every = int(os.environ.get("TRAIN_LOG_EVERY", 500)) + iterations = int(os.environ.get("ITERATIONS", 20000)) + warmdown_iters = int(os.environ.get("WARMDOWN_ITERS", 3500)) + warmup_steps = int(os.environ.get("WARMUP_STEPS", 20)) + train_batch_tokens = int(os.environ.get("TRAIN_BATCH_TOKENS", 786_432)) + train_seq_len = int(os.environ.get("TRAIN_SEQ_LEN", 2048)) + eval_seq_len = int(os.environ.get("EVAL_SEQ_LEN", 2048)) + max_wallclock_seconds = float(os.environ.get("MAX_WALLCLOCK_SECONDS", 600.0)) + qk_gain_init = float(os.environ.get("QK_GAIN_INIT", 1.5)) + vocab_size = int(os.environ.get("VOCAB_SIZE", 1024)) + num_layers = int(os.environ.get("NUM_LAYERS", 11)) + num_kv_heads = int(os.environ.get("NUM_KV_HEADS", 4)) + model_dim = int(os.environ.get("MODEL_DIM", 512)) + num_heads = int(os.environ.get("NUM_HEADS", 8)) + mlp_mult = float(os.environ.get("MLP_MULT", 3.0)) + tie_embeddings = bool(int(os.environ.get("TIE_EMBEDDINGS", 
"1"))) + rope_base = float(os.environ.get("ROPE_BASE", 10000.0)) + logit_softcap = float(os.environ.get("LOGIT_SOFTCAP", 30.0)) + embed_lr = float(os.environ.get("EMBED_LR", 0.6)) + head_lr = float(os.environ.get("HEAD_LR", 0.008)) + tied_embed_lr = float(os.environ.get("TIED_EMBED_LR", 0.035)) + tied_embed_init_std = float(os.environ.get("TIED_EMBED_INIT_STD", 0.005)) + matrix_lr = float(os.environ.get("MATRIX_LR", 0.025)) + scalar_lr = float(os.environ.get("SCALAR_LR", 0.025)) + muon_momentum = float(os.environ.get("MUON_MOMENTUM", 0.99)) + muon_backend_steps = int(os.environ.get("MUON_BACKEND_STEPS", 5)) + muon_momentum_warmup_start = float(os.environ.get("MUON_MOMENTUM_WARMUP_START", 0.92)) + muon_momentum_warmup_steps = int(os.environ.get("MUON_MOMENTUM_WARMUP_STEPS", 1500)) + beta1 = float(os.environ.get("BETA1", 0.9)) + beta2 = float(os.environ.get("BETA2", 0.95)) + adam_eps = float(os.environ.get("ADAM_EPS", 1e-8)) + grad_clip_norm = float(os.environ.get("GRAD_CLIP_NORM", 0.3)) + eval_stride = int(os.environ.get("EVAL_STRIDE", 64)) + mtp_num_heads = int(os.environ.get("MTP_NUM_HEADS", 0)) + mtp_loss_weight = float(os.environ.get("MTP_LOSS_WEIGHT", 0.2)) + muon_beta2 = float(os.environ.get("MUON_BETA2", 0.95)) + swa_enabled = bool(int(os.environ.get("SWA_ENABLED", "1"))) + swa_every = int(os.environ.get("SWA_EVERY", 50)) + lawa_enabled = bool(int(os.environ.get("LAWA_ENABLED", "0"))) + lawa_k = int(os.environ.get("LAWA_K", 10)) + lawa_freq = int(os.environ.get("LAWA_FREQ", 100)) + muon_wd = float(os.environ.get("MUON_WD", 0.04)) + adam_wd = float(os.environ.get("ADAM_WD", 0.04)) + qat_enabled = bool(int(os.environ.get("QAT_ENABLED", "0"))) + bigram_vocab_size = int(os.environ.get("BIGRAM_VOCAB_SIZE", 2048)) + bigram_dim = int(os.environ.get("BIGRAM_DIM", 128)) + xsa_last_n = int(os.environ.get("XSA_LAST_N", 4)) + rope_dims = int(os.environ.get("ROPE_DIMS", 16)) + ln_scale = bool(int(os.environ.get("LN_SCALE", "1"))) + dtg_enabled = 
bool(int(os.environ.get("DTG_ENABLED", "0"))) + late_qat_threshold = float(os.environ.get("LATE_QAT_THRESHOLD", 0.15)) + ve_enabled = bool(int(os.environ.get("VE_ENABLED", "1"))) + ve_dim = int(os.environ.get("VE_DIM", 128)) + ve_layers = os.environ.get("VE_LAYERS", "9,10") + gated_attention = bool(int(os.environ.get("GATED_ATTENTION", "0"))) + value_residual = bool(int(os.environ.get("VALUE_RESIDUAL", "0"))) + ttt_enabled = bool(int(os.environ.get("TTT_ENABLED", "0"))) + ttt_lr = float(os.environ.get("TTT_LR", 0.002)) + ttt_epochs = int(os.environ.get("TTT_EPOCHS", 3)) + ttt_chunk_tokens = int(os.environ.get("TTT_CHUNK_TOKENS", 32768)) + ttt_freeze_blocks = int(os.environ.get("TTT_FREEZE_BLOCKS", 2)) + ttt_momentum = float(os.environ.get("TTT_MOMENTUM", 0.9)) + ttt_batch_seqs = int(os.environ.get("TTT_BATCH_SEQS", 32)) + ttt_grad_clip = float(os.environ.get("TTT_GRAD_CLIP", 1.0)) + ctw_weight = float(os.environ.get("CTW_WEIGHT", 0.0)) # Enable with CTW_WEIGHT=0.1 + ctw_depth = int(os.environ.get("CTW_DEPTH", 4)) + ctw_entropy_threshold = float(os.environ.get("CTW_ENTROPY_THRESHOLD", 2.0)) # Skip CTW when H < this + slot_enabled = bool(int(os.environ.get("SLOT_ENABLED", "0"))) + slot_lr = float(os.environ.get("SLOT_LR", 0.001)) + slot_steps = int(os.environ.get("SLOT_STEPS", 3)) + +# --- Batched Newton-Schulz orthogonalization --- + +def zeropower_via_newtonschulz5(G: Tensor, steps: int = 5, eps: float = 1e-7) -> Tensor: + """Batched Newton-Schulz orthogonalization. 
G: (B,M,N) or (M,N)."""
+    a, b, c = (3.4445, -4.7750, 2.0315)
+    was_2d = G.ndim == 2
+    if was_2d:
+        G = G.unsqueeze(0)
+    X = G.bfloat16()
+    transposed = X.size(-2) > X.size(-1)
+    if transposed:
+        X = X.mT
+    X = X / (X.norm(dim=(-2, -1), keepdim=True) + eps)
+    for _ in range(steps):
+        A = X @ X.mT
+        B = b * A + c * (A @ A)
+        X = a * X + B @ X
+    if transposed:
+        X = X.mT
+    if was_2d:
+        X = X.squeeze(0)
+    return X
+
+# --- Parallel Muon optimizer ---
+
+class Muon(torch.optim.Optimizer):
+    """Parallel Muon: post-backward reduce-scatter -> local NS5 -> all-gather.
+
+    No DDP for bank params. After backward, this optimizer:
+    1. Launches async reduce-scatter for all banks (biggest first)
+    2. Returns control so Adam can step on small params while RS is in-flight
+    3. Waits for each RS, runs local NS5 on the shard, launches async all-gather
+    4. Each all-gather overlaps with next bank's NS5
+    """
+    def __init__(self, params, lr: float, momentum: float, backend_steps: int,
+                 nesterov: bool = True, weight_decay: float = 0.0):
+        super().__init__(
+            params,
+            dict(lr=lr, momentum=momentum, backend_steps=backend_steps,
+                 nesterov=nesterov, weight_decay=weight_decay),
+        )
+        self._built = False
+
+    def _build(self):
+        self._distributed = dist.is_available() and dist.is_initialized()
+        self._world_size = dist.get_world_size() if self._distributed else 1
+        self._rank = dist.get_rank() if self._distributed else 0
+        ws = self._world_size
+
+        self._bank_meta = []
+        for group in self.param_groups:
+            for p in group["params"]:
+                B = p.shape[0]
+                padded_B = ((B + ws - 1) // ws) * ws
+                shard_B = padded_B // ws
+                tail = p.shape[1:]
+                dev = p.device
+                self._bank_meta.append({
+                    'p': p,
+                    'B': B,
+                    'padded_grad': torch.zeros(padded_B, *tail, device=dev, dtype=torch.bfloat16),
+                    'shard': torch.zeros(shard_B, *tail, device=dev, dtype=torch.bfloat16),
+                    'shard_mom': torch.zeros(shard_B, *tail, device=dev, dtype=torch.bfloat16),
+                    'full_update': torch.zeros(padded_B, *tail, device=dev, dtype=torch.bfloat16),
+                    'scale': max(1, p.shape[-2] / p.shape[-1]) ** 0.5,
+                })
+        # Sort by size descending -- launch biggest reduce-scatters first
+        self._bank_meta.sort(key=lambda m: -m['p'].numel())
+        self._built = True
+
+    def launch_reduce_scatters(self):
+        """Phase 1: launch async reduce-scatter for all banks. Call right after backward."""
+        if not self._built:
+            self._build()
+        if not self._distributed:
+            return
+        self._rs_futures = []
+        for m in self._bank_meta:
+            p = m['p']
+            if p.grad is None:
+                self._rs_futures.append(None)
+                continue
+            pg = m['padded_grad']
+            pg[:m['B']].copy_(p.grad.bfloat16())
+            if pg.shape[0] > m['B']:
+                pg[m['B']:].zero_()
+            fut = dist.reduce_scatter_tensor(m['shard'], pg, op=dist.ReduceOp.AVG, async_op=True)
+            self._rs_futures.append(fut)
+
+    @torch.no_grad()
+    def step(self, closure=None):
+        """Phase 3: wait for RS, local NS5, all-gather. Call AFTER Adam steps."""
+        loss = None
+        if closure is not None:
+            with torch.enable_grad():
+                loss = closure()
+
+        if not self._built:
+            self._build()
+
+        for group in self.param_groups:
+            lr = group["lr"]
+            momentum = group["momentum"]
+            backend_steps = group["backend_steps"]
+            nesterov = group["nesterov"]
+            wd = group.get("weight_decay", 0.0)
+
+            prev_ag_handle = None
+            prev_m = None
+
+            sharded = self._distributed and hasattr(self, '_rs_futures')
+
+            for i, m in enumerate(self._bank_meta):
+                p = m['p']
+                if p.grad is None:
+                    continue
+
+                if prev_ag_handle is not None:
+                    prev_ag_handle.wait()
+                    pp = prev_m['p']
+                    upd = prev_m['full_update'][:prev_m['B']]
+                    if wd > 0.0:
+                        pp.data.mul_(1.0 - lr * wd)
+                    pp.add_(upd.to(dtype=pp.dtype), alpha=-lr * prev_m['scale'])
+
+                if sharded and self._rs_futures[i] is not None:
+                    self._rs_futures[i].wait()
+                    g = m['shard']
+                    buf = m['shard_mom']
+                else:
+                    g = p.grad.bfloat16()
+                    state = self.state[p]
+                    if "momentum_buffer" not in state:
+                        state["momentum_buffer"] = torch.zeros_like(g)
+                    buf = state["momentum_buffer"]
+
+                buf.mul_(momentum).add_(g)
+                if \
nesterov:
+                    update = g.add(buf, alpha=momentum)
+                else:
+                    update = buf
+
+                update = zeropower_via_newtonschulz5(update, steps=backend_steps)
+
+                if sharded:
+                    prev_ag_handle = dist.all_gather_into_tensor(
+                        m['full_update'], update, async_op=True)
+                    prev_m = m
+                else:
+                    if wd > 0.0:
+                        p.data.mul_(1.0 - lr * wd)
+                    p.add_(update.to(dtype=p.dtype), alpha=-lr * m['scale'])
+
+            if prev_ag_handle is not None:
+                prev_ag_handle.wait()
+                pp = prev_m['p']
+                upd = prev_m['full_update'][:prev_m['B']]
+                if wd > 0.0:
+                    pp.data.mul_(1.0 - lr * wd)
+                pp.add_(upd.to(dtype=pp.dtype), alpha=-lr * prev_m['scale'])
+
+        if hasattr(self, '_rs_futures'):
+            del self._rs_futures
+
+        return loss
+
+# --- Tokenizer evaluation helpers ---
+
+def build_sentencepiece_luts(
+    sp: spm.SentencePieceProcessor, vocab_size: int, device: torch.device
+) -> tuple[Tensor, Tensor, Tensor]:
+    sp_vocab_size = int(sp.vocab_size())
+    table_size = max(sp_vocab_size, vocab_size)
+    base_bytes_np = np.zeros((table_size,), dtype=np.int16)
+    has_leading_space_np = np.zeros((table_size,), dtype=np.bool_)
+    is_boundary_token_np = np.ones((table_size,), dtype=np.bool_)
+    for token_id in range(sp_vocab_size):
+        if sp.is_control(token_id) or sp.is_unknown(token_id) or sp.is_unused(token_id):
+            continue
+        is_boundary_token_np[token_id] = False
+        if sp.is_byte(token_id):
+            base_bytes_np[token_id] = 1
+            continue
+        piece = sp.id_to_piece(token_id)
+        if piece.startswith("\u2581"):
+            has_leading_space_np[token_id] = True
+            piece = piece[1:]
+        base_bytes_np[token_id] = len(piece.encode("utf-8"))
+    return (
+        torch.tensor(base_bytes_np, dtype=torch.int16, device=device),
+        torch.tensor(has_leading_space_np, dtype=torch.bool, device=device),
+        torch.tensor(is_boundary_token_np, dtype=torch.bool, device=device),
+    )
+
+def load_validation_tokens(pattern: str, seq_len: int) -> Tensor:
+    files = [Path(p) for p in sorted(glob.glob(pattern))]
+    if not files:
+        raise FileNotFoundError(f"No files found for pattern: {pattern}")
+    tokens = torch.cat([load_data_shard(file) for file in files]).contiguous()
+    usable = ((tokens.numel() - 1) // seq_len) * seq_len
+    if usable <= 0:
+        raise ValueError(f"Validation split is too short for TRAIN_SEQ_LEN={seq_len}")
+    return tokens[: usable + 1]
+
+def eval_val(
+    args: Hyperparameters,
+    model: nn.Module,
+    rank: int,
+    world_size: int,
+    device: torch.device,
+    grad_accum_steps: int,
+    val_tokens: Tensor,
+    base_bytes_lut: Tensor,
+    has_leading_space_lut: Tensor,
+    is_boundary_token_lut: Tensor,
+    eval_seq_len: int | None = None,
+) -> tuple[float, float]:
+    seq_len = eval_seq_len or args.train_seq_len
+    local_batch_tokens = args.val_batch_size // (world_size * grad_accum_steps)
+    if local_batch_tokens < seq_len:
+        raise ValueError(
+            "VAL_BATCH_SIZE must provide at least one sequence per rank; "
+            f"got VAL_BATCH_SIZE={args.val_batch_size}, WORLD_SIZE={world_size}, "
+            f"GRAD_ACCUM_STEPS={grad_accum_steps}, seq_len={seq_len}"
+        )
+    local_batch_seqs = local_batch_tokens // seq_len
+    total_seqs = (val_tokens.numel() - 1) // seq_len
+    seq_start = (total_seqs * rank) // world_size
+    seq_end = (total_seqs * (rank + 1)) // world_size
+    val_loss_sum = torch.zeros((), device=device, dtype=torch.float64)
+    val_token_count = torch.zeros((), device=device, dtype=torch.float64)
+    val_byte_count = torch.zeros((), device=device, dtype=torch.float64)
+    model.eval()
+    with torch.inference_mode():
+        for batch_seq_start in range(seq_start, seq_end, local_batch_seqs):
+            batch_seq_end = min(batch_seq_start + local_batch_seqs, seq_end)
+            raw_start = batch_seq_start * seq_len
+            raw_end = batch_seq_end * seq_len + 1
+            local = val_tokens[raw_start:raw_end].to(device=device, dtype=torch.int64, non_blocking=True)
+            x = local[:-1].reshape(-1, seq_len)
+            y = local[1:].reshape(-1, seq_len)
+            with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True):
+                batch_loss = model(x, y).detach()
+            batch_token_count = float(y.numel())
+            val_loss_sum += batch_loss.to(torch.float64) * \
batch_token_count
+            val_token_count += batch_token_count
+            prev_ids = x.reshape(-1)
+            tgt_ids = y.reshape(-1)
+            token_bytes = base_bytes_lut[tgt_ids].to(dtype=torch.int16)
+            token_bytes += (has_leading_space_lut[tgt_ids] & ~is_boundary_token_lut[prev_ids]).to(dtype=torch.int16)
+            val_byte_count += token_bytes.to(torch.float64).sum()
+    if dist.is_available() and dist.is_initialized():
+        dist.all_reduce(val_loss_sum, op=dist.ReduceOp.SUM)
+        dist.all_reduce(val_token_count, op=dist.ReduceOp.SUM)
+        dist.all_reduce(val_byte_count, op=dist.ReduceOp.SUM)
+    val_loss = val_loss_sum / val_token_count
+    bits_per_token = val_loss.item() / math.log(2.0)
+    tokens_per_byte = val_token_count.item() / val_byte_count.item()
+    model.train()
+    return float(val_loss.item()), float(bits_per_token * tokens_per_byte)
+
+# --- Quantization helpers ---
+
+CONTROL_TENSOR_NAME_PATTERNS = tuple(
+    pattern
+    for pattern in os.environ.get(
+        "CONTROL_TENSOR_NAME_PATTERNS",
+        "attn_scale,attn_scales,mlp_scale,mlp_scales,resid_mix,resid_mixes,q_gain,skip_weight,skip_weights,smear,dtg_gate,ve_layer_scales,ve_shared.scale,attn_gate,vr_lambda",
+    ).split(",")
+    if pattern
+)
+INT8_KEEP_FLOAT_FP32_NAME_PATTERNS = tuple(
+    pattern
+    for pattern in os.environ.get(
+        "INT8_KEEP_FLOAT_FP32_NAME_PATTERNS",
+        ",".join(CONTROL_TENSOR_NAME_PATTERNS),
+    ).split(",")
+    if pattern
+)
+INT8_KEEP_FLOAT_MAX_NUMEL = 65_536
+INT8_KEEP_FLOAT_STORE_DTYPE = torch.float16
+INT8_PER_ROW_SCALE_DTYPE = torch.float16
+INT8_CLIP_PERCENTILE = 99.99984
+INT8_CLIP_Q = INT8_CLIP_PERCENTILE / 100.0
+
+def tensor_nbytes(t: Tensor) -> int:
+    return int(t.numel()) * int(t.element_size())
+
+def keep_float_tensor(name: str, t: Tensor, passthrough_orig_dtypes: dict[str, str]) -> Tensor:
+    if any(pattern in name for pattern in INT8_KEEP_FLOAT_FP32_NAME_PATTERNS):
+        return t.float().contiguous()
+    if t.dtype in {torch.float32, torch.bfloat16}:
+        passthrough_orig_dtypes[name] = str(t.dtype).removeprefix("torch.")
+        return t.to(dtype=INT8_KEEP_FLOAT_STORE_DTYPE).contiguous()
+    return t
+
+def quantize_float_tensor(t: Tensor) -> tuple[Tensor, Tensor]:
+    t32 = t.float()
+    if t32.ndim == 2:
+        clip_abs = (
+            torch.quantile(t32.abs(), INT8_CLIP_Q, dim=1)
+            if t32.numel()
+            else torch.zeros((t32.shape[0],), dtype=torch.float32)
+        )
+        clipped = torch.maximum(torch.minimum(t32, clip_abs[:, None]), -clip_abs[:, None])
+        scale = (clip_abs / 127.0).clamp_min(1.0 / 127.0)
+        q = torch.clamp(torch.round(clipped / scale[:, None]), -127, 127).to(torch.int8).contiguous()
+        return q, scale.to(dtype=INT8_PER_ROW_SCALE_DTYPE).contiguous()
+    clip_abs = float(torch.quantile(t32.abs().flatten(), INT8_CLIP_Q).item()) if t32.numel() else 0.0
+    scale = torch.tensor(clip_abs / 127.0 if clip_abs > 0 else 1.0, dtype=torch.float32)
+    q = torch.clamp(torch.round(torch.clamp(t32, -clip_abs, clip_abs) / scale), -127, 127).to(torch.int8).contiguous()
+    return q, scale
+
+def quantize_state_dict_int8(state_dict: dict[str, Tensor]):
+    quantized: dict[str, Tensor] = {}
+    scales: dict[str, Tensor] = {}
+    dtypes: dict[str, str] = {}
+    passthrough: dict[str, Tensor] = {}
+    passthrough_orig_dtypes: dict[str, str] = {}
+    qmeta: dict[str, dict[str, object]] = {}
+    stats = dict.fromkeys(
+        ("param_count", "num_tensors", "num_float_tensors", "num_nonfloat_tensors", "baseline_tensor_bytes", "int8_payload_bytes"),
+        0,
+    )
+    for name, tensor in state_dict.items():
+        t = tensor.detach().to("cpu").contiguous()
+        stats["param_count"] += int(t.numel())
+        stats["num_tensors"] += 1
+        stats["baseline_tensor_bytes"] += tensor_nbytes(t)
+        if not t.is_floating_point():
+            stats["num_nonfloat_tensors"] += 1
+            passthrough[name] = t
+            stats["int8_payload_bytes"] += tensor_nbytes(t)
+            continue
+        if t.numel() <= INT8_KEEP_FLOAT_MAX_NUMEL:
+            kept = keep_float_tensor(name, t, passthrough_orig_dtypes)
+            passthrough[name] = kept
+            stats["int8_payload_bytes"] += tensor_nbytes(kept)
+            continue
+        stats["num_float_tensors"] += 1
+        q, s = \
quantize_float_tensor(t)
+        if s.ndim > 0:
+            qmeta[name] = {"scheme": "per_row", "axis": 0}
+        quantized[name] = q
+        scales[name] = s
+        dtypes[name] = str(t.dtype).removeprefix("torch.")
+        stats["int8_payload_bytes"] += tensor_nbytes(q) + tensor_nbytes(s)
+    obj: dict[str, object] = {
+        "__quant_format__": "int8_clean_per_row_v1",
+        "quantized": quantized,
+        "scales": scales,
+        "dtypes": dtypes,
+        "passthrough": passthrough,
+    }
+    if qmeta:
+        obj["qmeta"] = qmeta
+    if passthrough_orig_dtypes:
+        obj["passthrough_orig_dtypes"] = passthrough_orig_dtypes
+    return obj, stats
+
+def dequantize_state_dict_int8(obj: dict[str, object]) -> dict[str, Tensor]:
+    out: dict[str, Tensor] = {}
+    qmeta = obj.get("qmeta", {})
+    passthrough_orig_dtypes = obj.get("passthrough_orig_dtypes", {})
+    for name, q in obj["quantized"].items():
+        dtype = getattr(torch, obj["dtypes"][name])
+        s = obj["scales"][name]
+        if qmeta.get(name, {}).get("scheme") == "per_row" or s.ndim > 0:
+            s = s.to(dtype=torch.float32)
+            out[name] = (q.float() * s.view(q.shape[0], *([1] * (q.ndim - 1)))).to(dtype=dtype).contiguous()
+        else:
+            scale = float(s.item())
+            out[name] = (q.float() * scale).to(dtype=dtype).contiguous()
+    for name, t in obj["passthrough"].items():
+        out_t = t.detach().to("cpu").contiguous()
+        orig_dtype = passthrough_orig_dtypes.get(name)
+        if isinstance(orig_dtype, str):
+            out_t = out_t.to(dtype=getattr(torch, orig_dtype)).contiguous()
+        out[name] = out_t
+    return out
+
+# --- Data loading ---
+
+def load_data_shard(file: Path) -> Tensor:
+    # Shard format assumed here: a 256-entry int32 header (token count at
+    # header[2]) followed by little-endian uint16 token ids.
+    header_bytes = 256 * np.dtype("<i4").itemsize
+    header = np.fromfile(file, dtype="<i4", count=256)
+    num_tokens = int(header[2])
+    data = np.fromfile(file, dtype="<u2", count=num_tokens, offset=header_bytes)
+    return torch.from_numpy(data.astype(np.int32))
+
+class TokenStream:
+    # State assumed from usage below: sorted shard files plus file/position cursors.
+    def __init__(self, pattern: str):
+        self.files = [Path(p) for p in sorted(glob.glob(pattern))]
+        if not self.files:
+            raise FileNotFoundError(f"No files found for pattern: {pattern}")
+        self.file_idx = 0
+        self.tokens = load_data_shard(self.files[0])
+        self.pos = 0
+    def _advance_file(self) -> None:
+        self.file_idx = (self.file_idx + 1) % len(self.files)
+        self.tokens = load_data_shard(self.files[self.file_idx])
+        self.pos = 0
+    def take(self, n: int) -> Tensor:
+        chunks: list[Tensor] = []
+        remaining = n
+        while remaining > 0:
+            avail = self.tokens.numel() - self.pos
+            if avail <= 0:
+                self._advance_file()
+                continue
+            k = min(remaining, avail)
+            chunks.append(self.tokens[self.pos : self.pos + k])
+            self.pos += k
+            remaining -= k
+        return chunks[0] if len(chunks) == 1 else torch.cat(chunks)
+
+class DistributedTokenLoader:
+    def __init__(self, pattern: str, rank: int, world_size: int, device: torch.device):
+        self.rank = rank
+        self.world_size = world_size
+        self.device = device
+        self.stream = TokenStream(pattern)
+    def next_batch(self, global_tokens: int, seq_len: int, grad_accum_steps: int) -> tuple[Tensor, Tensor]:
+        local_tokens = global_tokens // (self.world_size * grad_accum_steps)
+        per_rank_span = local_tokens + 1
+        chunk = self.stream.take(per_rank_span * self.world_size)
+        start = self.rank * per_rank_span
+        local = chunk[start : start + per_rank_span].to(dtype=torch.int64)
+        x = local[:-1].reshape(-1, seq_len)
+        y = local[1:].reshape(-1, seq_len)
+        return x.to(self.device, non_blocking=True), y.to(self.device, non_blocking=True)
+
+# --- Transformer modules ---
+
+class RMSNorm(nn.Module):
+    def __init__(self, eps: float | None = None):
+        super().__init__()
+        self.eps = eps
+    def forward(self, x: Tensor) -> Tensor:
+        return F.rms_norm(x, (x.size(-1),), eps=self.eps)
+
+class CastedLinear(nn.Linear):
+    _qat_enabled: bool = False
+    def forward(self, x: Tensor) -> Tensor:
+        w = self.weight.to(x.dtype)
+        if CastedLinear._qat_enabled and self.training and w.ndim == 2:
+            with torch.no_grad():
+                w32 = self.weight.float()
+                row_max = w32.abs().amax(dim=1)
+                scale = (row_max / 31.0).clamp_min(1.0 / 31.0)
+                w_q = (torch.clamp(torch.round(w32 / scale[:, None]), -32, 31) * scale[:, None]).to(x.dtype)
+            # Straight-through estimator: quantized values forward, full gradient back
+            w = w + (w_q - w).detach()
+        bias = self.bias.to(x.dtype) if self.bias is not None else None
+        return F.linear(x, w, bias)
+
+def restore_low_dim_params_to_fp32(module: nn.Module) -> None:
+    with torch.no_grad():
+        for name, param in module.named_parameters():
+            if (param.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS)) and param.dtype != torch.float32:
+                param.data = param.data.float()
+
+class Rotary(nn.Module):
+    def __init__(self, dim: int, base: float 
= 10000.0, train_seq_len: int = 1024, rope_dims: int = 0):
+        super().__init__()
+        self.dim = dim
+        self.base = base
+        self.train_seq_len = train_seq_len
+        self.rope_dims = rope_dims if rope_dims > 0 else dim
+        inv_freq = 1.0 / (base ** (torch.arange(0, self.rope_dims, 2, dtype=torch.float32) / self.rope_dims))
+        self.register_buffer("inv_freq", inv_freq, persistent=False)
+        self._seq_len_cached = 0
+        self._cos_cached: Tensor | None = None
+        self._sin_cached: Tensor | None = None
+    def forward(self, seq_len: int, device: torch.device, dtype: torch.dtype) -> tuple[Tensor, Tensor]:
+        if (
+            self._cos_cached is None
+            or self._sin_cached is None
+            or self._seq_len_cached != seq_len
+            or self._cos_cached.device != device
+        ):
+            rd = self.rope_dims
+            if seq_len > self.train_seq_len:
+                scale = seq_len / self.train_seq_len
+                new_base = self.base * (scale ** (rd / (rd - 2)))
+                inv_freq = 1.0 / (new_base ** (torch.arange(0, rd, 2, dtype=torch.float32, device=device) / rd))
+            else:
+                inv_freq = self.inv_freq.to(device)
+            t = torch.arange(seq_len, device=device, dtype=inv_freq.dtype)
+            freqs = torch.outer(t, inv_freq)
+            self._cos_cached = freqs.cos()[None, :, None, :]
+            self._sin_cached = freqs.sin()[None, :, None, :]
+            self._seq_len_cached = seq_len
+        return self._cos_cached.to(dtype=dtype), self._sin_cached.to(dtype=dtype)
+
+def apply_rotary_emb(x: Tensor, cos: Tensor, sin: Tensor, rope_dims: int = 0) -> Tensor:
+    if rope_dims > 0 and rope_dims < x.size(-1):
+        x_rope, x_pass = x[..., :rope_dims], x[..., rope_dims:]
+        half = rope_dims // 2
+        x1, x2 = x_rope[..., :half], x_rope[..., half:]
+        x_rope = torch.cat((x1 * cos + x2 * sin, x1 * (-sin) + x2 * cos), dim=-1)
+        return torch.cat((x_rope, x_pass), dim=-1)
+    half = x.size(-1) // 2
+    x1, x2 = x[..., :half], x[..., half:]
+    return torch.cat((x1 * cos + x2 * sin, x1 * (-sin) + x2 * cos), dim=-1)
+
+class CausalSelfAttention(nn.Module):
+    def __init__(
+        self,
+        dim: int,
+        num_heads: int,
+        num_kv_heads: int,
+        rope_base: float,
+        qk_gain_init: float,
+        gated_attention: bool = False,
+        value_residual: bool = False,
+    ):
+        super().__init__()
+        if dim % num_heads != 0:
+            raise ValueError("model_dim must be divisible by num_heads")
+        if num_heads % num_kv_heads != 0:
+            raise ValueError("num_heads must be divisible by num_kv_heads")
+        self.num_heads = num_heads
+        self.num_kv_heads = num_kv_heads
+        self.head_dim = dim // num_heads
+        if self.head_dim % 2 != 0:
+            raise ValueError("head_dim must be even for RoPE")
+        # No CastedLinear -- weights come from banks
+        self.q_gain = nn.Parameter(torch.full((num_heads,), qk_gain_init, dtype=torch.float32))
+        self.rope_dims = 0  # set by GPT.__init__ for partial RoPE
+        self.rotary = Rotary(self.head_dim, base=rope_base, train_seq_len=1024)
+        self.use_xsa = False  # set by GPT.__init__ for deep layers only
+        # Gated attention and value residual (non-banked small params)
+        self.gated_attention = gated_attention
+        if gated_attention:
+            self.attn_gate = nn.Linear(dim, num_heads, bias=True)
+            nn.init.zeros_(self.attn_gate.weight)
+            nn.init.constant_(self.attn_gate.bias, 4.0)
+        self.value_residual = value_residual
+        if value_residual:
+            self.vr_lambda = nn.Parameter(torch.tensor([0.5, 0.5], dtype=torch.float32))
+    def _xsa_efficient(self, y: Tensor, v: Tensor) -> Tensor:
+        """Efficient XSA: subtract self-value projection via GQA-aware reshape (no repeat_interleave).
+        y: [B, T, H, D], v: [B, T, Hkv, D].
H must be divisible by Hkv."""
+        B, T, H, D = y.shape
+        Hkv = v.size(-2)
+        group = H // Hkv
+        y_g = y.reshape(B, T, Hkv, group, D)       # [B, T, Hkv, group, D]
+        vn = F.normalize(v, dim=-1).unsqueeze(-2)  # [B, T, Hkv, 1, D] -- broadcast ready
+        proj = (y_g * vn).sum(dim=-1, keepdim=True) * vn
+        return (y_g - proj).reshape(B, T, H, D)
+    def forward(self, x: Tensor, q_w: Tensor, k_w: Tensor, v_w: Tensor, out_w: Tensor, v_embed: Tensor | None = None, v0: Tensor | None = None) -> tuple[Tensor, Tensor | None]:
+        bsz, seqlen, dim = x.shape
+        q = F.linear(x, q_w.to(x.dtype)).reshape(bsz, seqlen, self.num_heads, self.head_dim)
+        k = F.linear(x, k_w.to(x.dtype)).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim)
+        v = F.linear(x, v_w.to(x.dtype))
+        if v_embed is not None:
+            v = v + v_embed
+        v = v.reshape(bsz, seqlen, self.num_kv_heads, self.head_dim)
+        raw_v = v if self.value_residual else None
+        if self.value_residual and v0 is not None:
+            lam = self.vr_lambda.to(dtype=v.dtype)
+            v = lam[0] * v0 + lam[1] * v
+        q = F.rms_norm(q, (q.size(-1),))
+        k = F.rms_norm(k, (k.size(-1),))
+        cos, sin = self.rotary(seqlen, x.device, q.dtype)
+        q = apply_rotary_emb(q, cos, sin, self.rope_dims)
+        k = apply_rotary_emb(k, cos, sin, self.rope_dims)
+        q = q * self.q_gain.to(dtype=q.dtype)[None, None, :, None]
+        y = flash_attn_3_func(q, k, v, causal=True)
+        if self.use_xsa:
+            y = self._xsa_efficient(y, v)
+        if self.gated_attention:
+            # gate shape: (bsz, seqlen, num_heads) -> (bsz, seqlen, num_heads, 1) for B,T,H,D layout
+            gate = torch.sigmoid(self.attn_gate(x)).unsqueeze(-1)
+            y = y * gate
+        y = y.reshape(bsz, seqlen, dim)
+        return F.linear(y, out_w.to(x.dtype)), raw_v
+
+class SmearGate(nn.Module):
+    def __init__(self, dim: int):
+        super().__init__()
+        self.gate = nn.Parameter(torch.zeros(dim, dtype=torch.float32))
+    def forward(self, x: Tensor) -> Tensor:
+        g = torch.sigmoid(self.gate.to(dtype=x.dtype))[None, None, :]
+        x_prev = torch.cat([torch.zeros_like(x[:, :1]), x[:, :-1]], dim=1)
+        return (1 - g) * x + g * x_prev
+
+class BigramHashEmbedding(nn.Module):
+    def __init__(self, bigram_vocab_size: int, bigram_dim: int, model_dim: int):
+        super().__init__()
+        self.bigram_vocab_size = bigram_vocab_size
+        self.embed = nn.Embedding(bigram_vocab_size, bigram_dim)
+        nn.init.zeros_(self.embed.weight)
+        self.proj = CastedLinear(bigram_dim, model_dim, bias=False) if bigram_dim != model_dim else None
+        if self.proj is not None:
+            nn.init.zeros_(self.proj.weight)
+        self.scale = nn.Parameter(torch.tensor(0.05, dtype=torch.float32))
+    def bigram_hash(self, tokens: Tensor) -> Tensor:
+        t = tokens.to(torch.int32)
+        mod = self.bigram_vocab_size - 1
+        out = torch.empty_like(t)
+        out[..., 0] = mod
+        out[..., 1:] = torch.bitwise_xor(36313 * t[..., 1:], 27191 * t[..., :-1]) % mod
+        return out.long()
+    def forward(self, token_ids: Tensor) -> Tensor:
+        h = self.embed(self.bigram_hash(token_ids))
+        if self.proj is not None:
+            h = self.proj(h)
+        return h * self.scale.to(dtype=h.dtype)
+
+class ValueEmbedding(nn.Module):
+    """Reinject token identity into attention values at specific layers.
+    Each table maps vocab tokens to a low-dim embedding, projected to model_dim."""
+    def __init__(self, vocab_size: int, ve_dim: int, model_dim: int):
+        super().__init__()
+        self.embed = nn.Embedding(vocab_size, ve_dim)
+        nn.init.normal_(self.embed.weight, std=0.01)
+        self.proj = CastedLinear(ve_dim, model_dim, bias=False) if ve_dim != model_dim else None
+        if self.proj is not None:
+            nn.init.zeros_(self.proj.weight)
+        self.scale = nn.Parameter(torch.tensor(0.1, dtype=torch.float32))
+    def forward(self, token_ids: Tensor) -> Tensor:
+        h = self.embed(token_ids)
+        if self.proj is not None:
+            h = self.proj(h)
+        return h * self.scale.to(dtype=h.dtype)
+
+class MLP(nn.Module):
+    def __init__(self, dim: int, mlp_mult: int):
+        super().__init__()
+        # No CastedLinear -- weights come from banks
+    def forward(self, x: Tensor, up_w: Tensor, down_w: Tensor) -> Tensor:
+        x = F.leaky_relu(F.linear(x, up_w.to(x.dtype)), negative_slope=0.5)
+        return F.linear(x.square(), down_w.to(x.dtype))
+
+class Block(nn.Module):
+    def __init__(
+        self,
+        dim: int,
+        num_heads: int,
+        num_kv_heads: int,
+        mlp_mult: int,
+        rope_base: float,
+        qk_gain_init: float,
+        layer_idx: int = 0,
+        ln_scale: bool = False,
+        dtg: bool = False,
+        gated_attention: bool = False,
+        value_residual: bool = False,
+    ):
+        super().__init__()
+        self.attn_norm = RMSNorm()
+        self.mlp_norm = RMSNorm()
+        self.attn = CausalSelfAttention(dim, num_heads, num_kv_heads, rope_base, qk_gain_init,
+                                        gated_attention=gated_attention, value_residual=value_residual)
+        self.mlp = MLP(dim, mlp_mult)
+        self.attn_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32))
+        self.mlp_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32))
+        self.resid_mix = nn.Parameter(torch.stack((torch.ones(dim), torch.zeros(dim))).float())
+        self.ln_scale_factor = 1.0 / math.sqrt(layer_idx + 1) if ln_scale else 1.0
+        if dtg:
+            self.dtg_gate = nn.Linear(dim, 1, bias=True)
+            nn.init.zeros_(self.dtg_gate.weight)
+            nn.init.constant_(self.dtg_gate.bias, 2.0)
+        else:
+            self.dtg_gate = None
+    def forward(self, x: Tensor, x0: Tensor, q_w: Tensor, k_w: Tensor, v_w: Tensor, out_w: Tensor, up_w: Tensor, down_w: Tensor, v_embed: Tensor | None = None, v0: Tensor | None = None) -> tuple[Tensor, Tensor | None]:
+        mix = self.resid_mix.to(dtype=x.dtype)
+        x_in = mix[0][None, None, :] * x + mix[1][None, None, :] * x0
+        attn_out, raw_v = self.attn(self.attn_norm(x_in) * self.ln_scale_factor, q_w, k_w, v_w, out_w, v_embed=v_embed, v0=v0)
+        x_out = x_in + self.attn_scale.to(dtype=x_in.dtype)[None, None, :] * attn_out
+        x_out = x_out + self.mlp_scale.to(dtype=x_out.dtype)[None, None, :] * self.mlp(self.mlp_norm(x_out) * self.ln_scale_factor, up_w, down_w)
+        if self.dtg_gate is not None:
+            gate = torch.sigmoid(self.dtg_gate(x_in.detach()))
+            x_out = x_in + gate * (x_out - x_in)
+        return x_out, raw_v
+
+class GPT(nn.Module):
+    def __init__(
+        self,
+        vocab_size: int,
+        num_layers: int,
+        model_dim: int,
+        num_heads: int,
+        num_kv_heads: int,
+        mlp_mult: int,
+        tie_embeddings: bool,
+        tied_embed_init_std: float,
+        logit_softcap: float,
+        rope_base: float,
+        qk_gain_init: float,
+        mtp_num_heads: int = 0,
+        mtp_loss_weight: float = 0.1,
+        bigram_vocab_size: int = 0,
+        bigram_dim: int = 128,
+        xsa_last_n: int = 0,
+        rope_dims: int = 0,
+        ln_scale: bool = False,
+        dtg: bool = False,
+        ve_enabled: bool = False,
+        ve_dim: int = 128,
+        ve_layers: str = "9,10",
+        gated_attention: bool = False,
+        value_residual: bool = False,
+    ):
+        super().__init__()
+        self._ve_target_dim = num_kv_heads * (model_dim // num_heads)  # kv_dim for value projection
+        if logit_softcap <= 0.0:
+            raise ValueError(f"logit_softcap must be positive, got {logit_softcap}")
+        self.tie_embeddings = tie_embeddings
+        self.tied_embed_init_std = tied_embed_init_std
+        self.logit_softcap = logit_softcap
+        self.value_residual = value_residual
+        self.mtp_num_heads = mtp_num_heads
+        self.mtp_loss_weight = mtp_loss_weight
+        self.tok_emb = nn.Embedding(vocab_size, model_dim)
self.bigram = BigramHashEmbedding(bigram_vocab_size, bigram_dim, model_dim) if bigram_vocab_size > 0 else None
+        self.smear = SmearGate(model_dim)
+        self.num_encoder_layers = num_layers // 2
+        self.num_decoder_layers = num_layers - self.num_encoder_layers
+        self.num_skip_weights = min(self.num_encoder_layers, self.num_decoder_layers)
+        self.skip_weights = nn.Parameter(torch.ones(self.num_skip_weights, model_dim, dtype=torch.float32))
+        # Parameter banks: contiguous 3D tensors for batched optimizer
+        head_dim = model_dim // num_heads
+        kv_dim = num_kv_heads * head_dim
+        mlp_dim = int(mlp_mult * model_dim)
+        self.num_layers = num_layers
+        self.qo_bank = nn.Parameter(torch.empty(2 * num_layers, model_dim, model_dim))
+        self.kv_bank = nn.Parameter(torch.empty(2 * num_layers, kv_dim, model_dim))
+        self.mlp_up_bank = nn.Parameter(torch.empty(num_layers, mlp_dim, model_dim))
+        self.mlp_down_bank = nn.Parameter(torch.empty(num_layers, model_dim, mlp_dim))
+        self.blocks = nn.ModuleList(
+            [
+                Block(
+                    model_dim,
+                    num_heads,
+                    num_kv_heads,
+                    mlp_mult,
+                    rope_base,
+                    qk_gain_init,
+                    layer_idx=i,
+                    ln_scale=ln_scale,
+                    dtg=dtg,
+                    gated_attention=gated_attention,
+                    value_residual=value_residual,
+                )
+                for i in range(num_layers)
+            ]
+        )
+        if rope_dims > 0:
+            head_dim = model_dim // num_heads
+            for block in self.blocks:
+                block.attn.rope_dims = rope_dims
+                block.attn.rotary = Rotary(head_dim, base=rope_base, train_seq_len=1024, rope_dims=rope_dims)
+        self.ve_layer_indices = [int(x) for x in ve_layers.split(",") if x.strip()] if ve_enabled else []
+        kv_dim_ve = self._ve_target_dim
+        if self.ve_layer_indices:
+            self.ve_shared = ValueEmbedding(vocab_size, ve_dim, kv_dim_ve)
+            self.ve_layer_scales = nn.ParameterList(
+                [nn.Parameter(torch.ones(1, dtype=torch.float32)) for _ in self.ve_layer_indices]
+            )
+        else:
+            self.ve_shared = None
+            self.ve_layer_scales = nn.ParameterList()
+        self.value_embeds = nn.ModuleList()  # keep empty for compat
+        self.final_norm = RMSNorm()
+        self.lm_head = None if tie_embeddings else CastedLinear(model_dim, vocab_size, bias=False)
+        if self.lm_head is not None:
+            self.lm_head._zero_init = True
+        self.mtp_heads = nn.ModuleList(
+            [CastedLinear(model_dim, vocab_size, bias=False) for _ in range(mtp_num_heads)]
+        )
+        for head in self.mtp_heads:
+            head._zero_init = True
+        if xsa_last_n > 0:
+            for i in range(max(0, num_layers - xsa_last_n), num_layers):
+                self.blocks[i].attn.use_xsa = True
+        self._init_weights()
+    def _init_weights(self) -> None:
+        if self.tie_embeddings:
+            nn.init.normal_(self.tok_emb.weight, mean=0.0, std=self.tied_embed_init_std)
+        n = self.num_layers
+        proj_scale = 1.0 / math.sqrt(2 * n)
+        # Init banks: orthogonal, with proj layers scaled down and out/down zero-init
+        for i in range(n):
+            nn.init.orthogonal_(self.qo_bank.data[i], gain=1.0)       # Q
+            nn.init.zeros_(self.qo_bank.data[n + i])                  # Out (zero init)
+            nn.init.orthogonal_(self.kv_bank.data[i], gain=1.0)       # K
+            nn.init.orthogonal_(self.kv_bank.data[n + i], gain=1.0)   # V
+            nn.init.orthogonal_(self.mlp_up_bank.data[i], gain=1.0)   # MLP up
+            nn.init.zeros_(self.mlp_down_bank.data[i])                # MLP down (zero init)
+            # Scale proj layers (out_proj and mlp_down are "proj" layers)
+            self.qo_bank.data[n + i].mul_(proj_scale)
+            self.mlp_down_bank.data[i].mul_(proj_scale)
+        # Init remaining nn.Linear modules (bigram proj, mtp heads, lm_head)
+        for name, module in self.named_modules():
+            if isinstance(module, nn.Linear):
+                if getattr(module, "_zero_init", False):
+                    nn.init.zeros_(module.weight)
+                elif module.weight.ndim == 2 and module.weight.shape[0] >= 64 and module.weight.shape[1] >= 64:
+                    nn.init.orthogonal_(module.weight, gain=1.0)
+    def _get_ve(self, layer_idx: int, input_ids: Tensor, ve_cache: dict | None = None) -> Tensor | None:
+        """Get value embedding for a specific layer using shared table + per-layer scale."""
+        if self.ve_shared is None or layer_idx not in self.ve_layer_indices:
+            return None
+        if ve_cache is not None and 've' not in \
ve_cache:
+            ve_cache['ve'] = self.ve_shared(input_ids)
+        ve_base = ve_cache['ve'] if ve_cache is not None else self.ve_shared(input_ids)
+        ve_idx = self.ve_layer_indices.index(layer_idx)
+        return ve_base * self.ve_layer_scales[ve_idx].to(dtype=ve_base.dtype)
+    def forward(self, input_ids: Tensor, target_ids: Tensor) -> Tensor:
+        n = self.num_layers
+        x = self.tok_emb(input_ids)
+        if self.bigram is not None:
+            x = x + self.bigram(input_ids)
+        x = F.rms_norm(x, (x.size(-1),))
+        x = self.smear(x)
+        x0 = x
+        v0 = None
+        skips: list[Tensor] = []
+        ve_cache: dict = {}
+        for i in range(self.num_encoder_layers):
+            ve = self._get_ve(i, input_ids, ve_cache)
+            x, raw_v = self.blocks[i](x, x0,
+                self.qo_bank[i], self.kv_bank[i], self.kv_bank[n + i],
+                self.qo_bank[n + i], self.mlp_up_bank[i], self.mlp_down_bank[i],
+                v_embed=ve, v0=v0)
+            if v0 is None and raw_v is not None:
+                v0 = raw_v
+            skips.append(x)
+        for i in range(self.num_decoder_layers):
+            bi = self.num_encoder_layers + i
+            if skips:
+                x = x + self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips.pop()
+            ve = self._get_ve(bi, input_ids, ve_cache)
+            x, _ = self.blocks[bi](x, x0,
+                self.qo_bank[bi], self.kv_bank[bi], self.kv_bank[n + bi],
+                self.qo_bank[n + bi], self.mlp_up_bank[bi], self.mlp_down_bank[bi],
+                v_embed=ve, v0=v0)
+        x = self.final_norm(x)
+        x_flat = x.reshape(-1, x.size(-1))
+        targets = target_ids.reshape(-1)
+        if self.tie_embeddings:
+            logits_proj = F.linear(x_flat, self.tok_emb.weight)
+        else:
+            if self.lm_head is None:
+                raise RuntimeError("lm_head is required when tie_embeddings=False")
+            logits_proj = self.lm_head(x_flat)
+        logits = self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap)
+        main_loss = F.cross_entropy(logits.float(), targets, reduction="mean")
+        if self.training and self.mtp_num_heads > 0 and self.mtp_loss_weight > 0.0:
+            _, seqlen, dim = x.shape
+            mtp_loss_sum = x.new_zeros(())
+            mtp_loss_count = 0
+            for k, mtp_head in enumerate(self.mtp_heads):
+                valid_t = seqlen - (k + 1)
+                if valid_t <= 0:
+                    continue
+                mtp_hidden = x[:, :valid_t, :].reshape(-1, dim)
+                mtp_targets = target_ids[:, k + 1 :].reshape(-1)
+                mtp_logits_proj = mtp_head(mtp_hidden)
+                mtp_logits = self.logit_softcap * torch.tanh(mtp_logits_proj / self.logit_softcap)
+                mtp_loss_sum = mtp_loss_sum + F.cross_entropy(mtp_logits.float(), mtp_targets, reduction="mean")
+                mtp_loss_count += 1
+            if mtp_loss_count > 0:
+                main_loss = main_loss + self.mtp_loss_weight * (mtp_loss_sum / mtp_loss_count)
+        return main_loss
+    def forward_hidden(self, input_ids: Tensor) -> Tensor:
+        """Return final hidden states (bsz, seq_len, model_dim) before lm_head projection."""
+        n = self.num_layers
+        x = self.tok_emb(input_ids)
+        if self.bigram is not None:
+            x = x + self.bigram(input_ids)
+        x = F.rms_norm(x, (x.size(-1),))
+        x = self.smear(x)
+        x0 = x
+        v0 = None
+        skips: list[Tensor] = []
+        ve_cache: dict = {}
+        for i in range(self.num_encoder_layers):
+            ve = self._get_ve(i, input_ids, ve_cache)
+            x, raw_v = self.blocks[i](x, x0,
+                self.qo_bank[i], self.kv_bank[i], self.kv_bank[n + i],
+                self.qo_bank[n + i], self.mlp_up_bank[i], self.mlp_down_bank[i],
+                v_embed=ve, v0=v0)
+            if v0 is None and raw_v is not None:
+                v0 = raw_v
+            skips.append(x)
+        for i in range(self.num_decoder_layers):
+            bi = self.num_encoder_layers + i
+            if skips:
+                x = x + self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips.pop()
+            ve = self._get_ve(bi, input_ids, ve_cache)
+            x, _ = self.blocks[bi](x, x0,
+                self.qo_bank[bi], self.kv_bank[bi], self.kv_bank[n + bi],
+                self.qo_bank[n + bi], self.mlp_up_bank[bi], self.mlp_down_bank[bi],
+                v_embed=ve, v0=v0)
+        return self.final_norm(x)
+
+    def compute_logits(self, hidden_states: Tensor) -> Tensor:
+        """Project hidden states to logits with softcap."""
+        if self.tie_embeddings:
+            logits_proj = F.linear(hidden_states, self.tok_emb.weight)
+        else:
+            logits_proj = self.lm_head(hidden_states)
+        return self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap)
+
+    def \
forward_logits(self, input_ids: Tensor) -> Tensor: + """Return logits (bsz, seq_len, vocab) without computing loss.""" + return self.compute_logits(self.forward_hidden(input_ids)) + +# --- CTW: Proper Context Tree Weighting (recursive depth weighting) --- +# Sparse lazy M-ary CTW with beta-tracked depth weights +# Ref: Willems, Shtarkov, Tjalkens (1995); eta propagation per Willems & Tjalkens (1997) +# Fixed: previous version used deepest-match lookup (just an n-gram model). +# This version computes the proper recursive weighted mixture across ALL depths. + +class CTWNode: + __slots__ = ['counts', 'total', 'children', 'log_pe', 'log_pw'] + def __init__(self): + self.counts = {} # symbol -> count + self.total = 0 + self.children = {} # context_symbol -> child node + self.log_pe = 0.0 # cumulative log P_e (KT sequential probability) + self.log_pw = 0.0 # cumulative log P_w (weighted probability) + +class SparseCTW: + """Proper sparse M-ary CTW. Nodes allocated on-demand. Beta-weighted depth mixing. + Verified against Python/Go/Rust implementations in project knowledge.""" + def __init__(self, depth: int = 4, vocab_size: int = 1024, alpha: float = 0.5): + self.depth = depth + self.vocab_size = vocab_size + self.alpha = alpha + self.alpha_sum = alpha * vocab_size + self.root = CTWNode() + self.context: list[int] = [] + + def _kt_prob(self, node: CTWNode, symbol: int) -> float: + """KT estimator: P_e(symbol | counts) = (count_s + alpha) / (total + alpha_sum)""" + return (node.counts.get(symbol, 0) + self.alpha) / (node.total + self.alpha_sum) + + def update(self, symbol: int): + """Update tree bottom-up along context path. 
Maintains log_pe and log_pw.""" + # Collect path nodes (root to deepest) + path: list[CTWNode] = [self.root] + node = self.root + for i in range(min(len(self.context), self.depth)): + ctx_sym = self.context[-(i + 1)] + if ctx_sym not in node.children: + node.children[ctx_sym] = CTWNode() + node = node.children[ctx_sym] + path.append(node) + + # Update counts and log_pe at each node on path + for node in path: + p_kt = self._kt_prob(node, symbol) + node.log_pe += math.log(max(p_kt, 1e-30)) + node.counts[symbol] = node.counts.get(symbol, 0) + 1 + node.total += 1 + + # Recompute log_pw bottom-up along path + # Leaf (deepest): P_w = P_e + path[-1].log_pw = path[-1].log_pe + # Walk back up + for i in range(len(path) - 2, -1, -1): + node = path[i] + child = path[i + 1] + # P_w(node) = 0.5 * P_e(node) + 0.5 * P_w(child) + # (Other children unchanged, their P_w = P_e = 1.0 initially, log = 0) + # For the child on the context path: + lpe = node.log_pe + lpw_child = child.log_pw + # log(0.5 * exp(lpe) + 0.5 * exp(lpw_child)) + max_lp = max(lpe, lpw_child) + node.log_pw = max_lp + math.log( + 0.5 * math.exp(min(lpe - max_lp, 20)) + + 0.5 * math.exp(min(lpw_child - max_lp, 20)) + ) + + self.context.append(symbol) + if len(self.context) > self.depth + 1: + self.context = self.context[-(self.depth + 1):] + + def predict(self, device: torch.device) -> Tensor: + """Compute proper CTW predictive distribution using beta-weighted depth mixing. 
+ P_ctw(s) = sum over depths d: w_d * P_e_d(s), where w_d comes from + the recursive beta = exp(log_pe - log_pw) at each node.""" + # Walk down context path, collect nodes + path: list[CTWNode] = [self.root] + node = self.root + for i in range(min(len(self.context), self.depth)): + ctx_sym = self.context[-(i + 1)] + if ctx_sym in node.children: + node = node.children[ctx_sym] + path.append(node) + else: + break + + # Bottom-up: compute weighted predictive distribution + # Start from deepest node: P_w = P_e (KT estimate) + deepest = path[-1] + probs = torch.full((self.vocab_size,), self.alpha / (deepest.total + self.alpha_sum), device=device) + for sym, count in deepest.counts.items(): + if sym < self.vocab_size: + probs[sym] = (count + self.alpha) / (deepest.total + self.alpha_sum) + + # Walk back up, mixing KT at each depth with deeper weighted estimate + for i in range(len(path) - 2, -1, -1): + node = path[i] + # KT estimate at this depth + pe = torch.full((self.vocab_size,), self.alpha / (node.total + self.alpha_sum), device=device) + for sym, count in node.counts.items(): + if sym < self.vocab_size: + pe[sym] = (count + self.alpha) / (node.total + self.alpha_sum) + # Beta = exp(log_pe - log_pw): posterior weight for "this node is a leaf" + log_beta = node.log_pe - node.log_pw + log_beta = max(min(log_beta, 20), -20) # clamp + beta = math.exp(log_beta) + w_local = beta / (1.0 + beta) # weight for local KT vs deeper + probs = w_local * pe + (1.0 - w_local) * probs + + return probs + + def mix_with_neural(self, neural_logits: Tensor, w_ctw: float = 0.1, + entropy_threshold: float = 0.0) -> Tensor: + """Entropy-adaptive CTW mixing. Only mixes when neural entropy > threshold. 
+ When neural model is confident (low entropy), CTW adds noise — skip it.""" + if not self.context: + return neural_logits + + # Entropy-adaptive gating: compute neural entropy + if entropy_threshold > 0: + neural_probs = F.softmax(neural_logits, dim=-1) + entropy = -(neural_probs * torch.log(neural_probs + 1e-10)).sum() + if entropy.item() < entropy_threshold: + return neural_logits # Neural is confident — skip CTW + # Scale CTW weight by entropy (higher entropy = more CTW influence) + max_entropy = math.log(self.vocab_size) # ~6.93 for vocab 1024 + entropy_scale = min(entropy.item() / max_entropy, 1.0) + effective_w = w_ctw * entropy_scale + else: + effective_w = w_ctw + + ctw_probs = self.predict(neural_logits.device).clamp(1e-10) + ctw_lp = torch.log(ctw_probs) + neural_lp = F.log_softmax(neural_logits, dim=-1) + mixed = (1.0 - effective_w) * neural_lp + effective_w * ctw_lp + return mixed - mixed.logsumexp(dim=-1, keepdim=True) + +# --- Sliding window evaluation --- + +def eval_val_sliding( + args: Hyperparameters, + base_model: nn.Module, + rank: int, + world_size: int, + device: torch.device, + val_tokens: Tensor, + base_bytes_lut: Tensor, + has_leading_space_lut: Tensor, + is_boundary_token_lut: Tensor, + stride: int, + batch_seqs: int = 32, + eval_seq_len: int | None = None, +) -> tuple[float, float]: + """Sliding window evaluation: each token scored with maximum context.""" + seq_len = eval_seq_len or args.train_seq_len + total_tokens = val_tokens.numel() - 1 + window_starts = [ws for ws in range(0, total_tokens, stride) + if min(ws + seq_len, total_tokens) - ws >= 1] + total_windows = len(window_starts) + my_s = (total_windows * rank) // world_size + my_e = (total_windows * (rank + 1)) // world_size + my_windows = window_starts[my_s:my_e] + loss_sum = torch.zeros((), device=device, dtype=torch.float64) + token_count = torch.zeros((), device=device, dtype=torch.float64) + byte_count = torch.zeros((), device=device, dtype=torch.float64) + base_model.eval() 
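The `mix_with_neural` method above combines the two predictors as a log-linear (geometric) mixture of log-probabilities, renormalized with a logsumexp. A minimal pure-Python sketch of just that mixing step, on toy 3-symbol distributions (`log_linear_mix` is a hypothetical name, not from the script):

```python
import math

def log_linear_mix(neural_probs, ctw_probs, w):
    """Geometric mixture in log space: (1-w)*log p + w*log q, renormalized."""
    mixed_lp = [(1.0 - w) * math.log(p) + w * math.log(q)
                for p, q in zip(neural_probs, ctw_probs)]
    # logsumexp renormalization, as in mix_with_neural
    m = max(mixed_lp)
    lse = m + math.log(sum(math.exp(lp - m) for lp in mixed_lp))
    return [lp - lse for lp in mixed_lp]

neural = [0.7, 0.2, 0.1]
ctw = [0.4, 0.4, 0.2]
mixed = log_linear_mix(neural, ctw, w=0.1)
probs = [math.exp(lp) for lp in mixed]
```

At `w=0` the neural distribution is returned unchanged, which is consistent with `ctw_weight=0` reproducing the baseline evaluation.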
+ compiled_logits = torch.compile(base_model.forward_logits, dynamic=False, fullgraph=True) + with torch.inference_mode(): + for bi in range(0, len(my_windows), batch_seqs): + batch_ws = my_windows[bi:bi + batch_seqs] + bsz = len(batch_ws) + x_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) + y_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) + wlens: list[int] = [] + for i, ws in enumerate(batch_ws): + end = min(ws + seq_len, total_tokens) + wlen = end - ws + wlens.append(wlen) + chunk = val_tokens[ws:end + 1].to(dtype=torch.int64, device=device) + x_batch[i, :wlen] = chunk[:-1] + y_batch[i, :wlen] = chunk[1:] + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + logits = compiled_logits(x_batch) + nll = F.cross_entropy( + logits.reshape(-1, logits.size(-1)).float(), + y_batch.reshape(-1), + reduction="none", + ).reshape(bsz, seq_len) + for i, ws in enumerate(batch_ws): + wlen = wlens[i] + s = 0 if ws == 0 else max(wlen - stride, 0) + scored_nll = nll[i, s:wlen].to(torch.float64) + loss_sum += scored_nll.sum() + token_count += float(wlen - s) + tgt = y_batch[i, s:wlen] + prev = x_batch[i, s:wlen] + tb = base_bytes_lut[tgt].to(torch.float64) + tb += (has_leading_space_lut[tgt] & ~is_boundary_token_lut[prev]).to(torch.float64) + byte_count += tb.sum() + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(loss_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(token_count, op=dist.ReduceOp.SUM) + dist.all_reduce(byte_count, op=dist.ReduceOp.SUM) + val_loss = (loss_sum / token_count).item() + bits_per_token = val_loss / math.log(2.0) + tokens_per_byte = token_count.item() / byte_count.item() + base_model.train() + return val_loss, bits_per_token * tokens_per_byte + + +def eval_val_sliding_ttt( + args: Hyperparameters, base_model: nn.Module, rank: int, world_size: int, + device: torch.device, val_tokens: Tensor, base_bytes_lut: Tensor, + has_leading_space_lut: Tensor, is_boundary_token_lut: Tensor, + stride: 
int, batch_seqs: int = 32, log0=print, +) -> tuple[float, float]: + """Legal score-first TTT (PR #461 recipe) with optional CTW augmentation: + score each chunk with sliding windows, then train on it. + Every token scored BEFORE any update that could use it. + When CTW is enabled, neural logits are mixed with CTW predictions per-token.""" + seq_len = args.train_seq_len + total_tokens = val_tokens.numel() - 1 + ttt_chunk = args.ttt_chunk_tokens + + # CTW integration: build suffix tree from scored tokens, mix into logits + use_ctw = args.ctw_weight > 0 + ctw = SparseCTW(depth=args.ctw_depth, vocab_size=args.vocab_size) if use_ctw else None + w_ctw = args.ctw_weight + + # SLOT integration: per-batch delta optimization at last hidden layer + use_slot = args.slot_enabled + + # Pre-compute all window starts + window_starts = [ws for ws in range(0, total_tokens, stride) + if min(ws + seq_len, total_tokens) - ws >= stride or ws == 0] + + # Assign each window to a chunk based on the first token it scores + num_chunks = (total_tokens + ttt_chunk - 1) // ttt_chunk + chunk_windows: list[list[int]] = [[] for _ in range(num_chunks)] + for ws in window_starts: + end = min(ws + seq_len, total_tokens) + wlen = end - ws + s = 0 if ws == 0 else max(wlen - stride, 0) + scored_start = ws + s + ci = min(scored_start // ttt_chunk, num_chunks - 1) + chunk_windows[ci].append(ws) + + log0(f"ttt_sliding:start chunks={num_chunks} chunk_tokens={ttt_chunk} " + f"total_windows={len(window_starts)} stride={stride} " + f"ttt_lr={args.ttt_lr} ttt_epochs={args.ttt_epochs} " + f"freeze_blocks={args.ttt_freeze_blocks}" + f"{f' ctw_weight={w_ctw} ctw_depth={args.ctw_depth} entropy_thresh={args.ctw_entropy_threshold}' if use_ctw else ''}" + f"{f' slot_lr={args.slot_lr} slot_steps={args.slot_steps}' if use_slot else ''}") + + loss_sum = torch.zeros((), device=device, dtype=torch.float64) + token_count = torch.zeros((), device=device, dtype=torch.float64) + byte_count = torch.zeros((), device=device, 
dtype=torch.float64) + + # Freeze first N blocks + frozen_block_ids = set(range(min(args.ttt_freeze_blocks, len(base_model.blocks)))) + ttt_params = [] + for name, p in base_model.named_parameters(): + freeze = False + for bi in frozen_block_ids: + if f"blocks.{bi}." in name: + freeze = True + break + if freeze: + p.requires_grad_(False) + else: + p.requires_grad_(True) + ttt_params.append(p) + + log0(f"ttt_sliding:params unfrozen={sum(p.numel() for p in ttt_params)} " + f"frozen={sum(p.numel() for p in base_model.parameters() if not p.requires_grad)}") + + optimizer = torch.optim.SGD(ttt_params, lr=args.ttt_lr, momentum=args.ttt_momentum) + t0 = time.perf_counter() + + for ci in range(num_chunks): + windows = chunk_windows[ci] + if not windows: + continue + chunk_start = ci * ttt_chunk + chunk_end = min((ci + 1) * ttt_chunk, total_tokens) + + # --- Phase 1: SCORE this chunk's windows (inference_mode) --- + my_s = (len(windows) * rank) // world_size + my_e = (len(windows) * (rank + 1)) // world_size + my_windows = windows[my_s:my_e] + + base_model.eval() + if use_slot: + # SLOT path: need gradients for delta optimization, can't use inference_mode + for bi in range(0, len(my_windows), batch_seqs): + batch_ws = my_windows[bi:bi + batch_seqs] + bsz = len(batch_ws) + x_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) + y_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) + wlens: list[int] = [] + for i, ws in enumerate(batch_ws): + end = min(ws + seq_len, total_tokens) + wlen = end - ws + wlens.append(wlen) + chunk_tok = val_tokens[ws:end + 1].to(dtype=torch.int64, device=device) + x_batch[i, :wlen] = chunk_tok[:-1] + y_batch[i, :wlen] = chunk_tok[1:] + # Step 1: Get hidden states (no grad needed for forward) + with torch.no_grad(): + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + H = base_model.forward_hidden(x_batch) + H = H.detach().float() # detach from model graph, float for SLOT precision + # Step 2: Optimize 
delta (needs gradients through compute_logits only) + delta = torch.zeros(1, 1, H.shape[-1], device=device, dtype=H.dtype, requires_grad=True) + slot_opt = torch.optim.AdamW([delta], lr=args.slot_lr, weight_decay=1e-8, eps=1e-5) + for _step in range(args.slot_steps): + slot_opt.zero_grad() + adapted_logits = base_model.compute_logits((H + delta).to(torch.bfloat16)).float() + shift_logits = adapted_logits[:, :-1, :].contiguous() + shift_targets = y_batch[:, :seq_len-1].contiguous() + slot_loss = F.cross_entropy(shift_logits.reshape(-1, shift_logits.size(-1)), + shift_targets.reshape(-1), reduction="mean") + slot_loss.backward() + slot_opt.step() + # Step 3: Score with adapted hidden states + with torch.no_grad(): + logits = base_model.compute_logits((H + delta.detach()).to(torch.bfloat16)).float() + nll = F.cross_entropy( + logits.reshape(-1, logits.size(-1)), + y_batch.reshape(-1), reduction="none", + ).reshape(bsz, seq_len) + for i, ws in enumerate(batch_ws): + wlen = wlens[i] + s = 0 if ws == 0 else max(wlen - stride, 0) + scored_nll = nll[i, s:wlen].to(torch.float64) + loss_sum += scored_nll.sum() + token_count += float(wlen - s) + tgt, prev = y_batch[i, s:wlen], x_batch[i, s:wlen] + tb = base_bytes_lut[tgt].to(torch.float64) + tb += (has_leading_space_lut[tgt] & ~is_boundary_token_lut[prev]).to(torch.float64) + byte_count += tb.sum() + else: + # Standard path: inference_mode, optional CTW + with torch.inference_mode(): + for bi in range(0, len(my_windows), batch_seqs): + batch_ws = my_windows[bi:bi + batch_seqs] + bsz = len(batch_ws) + x_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) + y_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) + wlens: list[int] = [] + for i, ws in enumerate(batch_ws): + end = min(ws + seq_len, total_tokens) + wlen = end - ws + wlens.append(wlen) + chunk_tok = val_tokens[ws:end + 1].to(dtype=torch.int64, device=device) + x_batch[i, :wlen] = chunk_tok[:-1] + y_batch[i, :wlen] = chunk_tok[1:] + 
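The SLOT branch above optimizes only the additive vector δ while the transformer and `lm_head` stay frozen. The same mechanics in a self-contained pure-Python toy: a 3-class softmax head with a hand-written cross-entropy gradient. For brevity δ is added directly to the logits here, whereas the real code adds it to the hidden states before `compute_logits` and steps AdamW over a 512-dim δ under bfloat16 autocast:

```python
import math

def softmax(z):
    m = max(z)
    e = [math.exp(v - m) for v in z]
    s = sum(e)
    return [v / s for v in e]

def batch_loss(delta):
    total = 0.0
    for row, t in zip(batch_logits, targets):
        p = softmax([z + d for z, d in zip(row, delta)])
        total += -math.log(p[t])
    return total / len(targets)

# Frozen per-example logits from a hypothetical 3-class head; the batch
# disagrees with the head's favorite class, so a shared delta can help.
batch_logits = [[2.0, 0.5, 0.1], [1.8, 0.2, 0.4], [2.2, 0.3, 0.0]]
targets = [1, 1, 2]
delta = [0.0, 0.0, 0.0]  # SLOT-style additive vector, shared across the batch
lr = 0.5

loss_before = batch_loss(delta)
for _ in range(3):  # mirrors slot_steps=3
    # d(mean CE)/d(delta) = mean over batch of (softmax(z + delta) - onehot(t))
    grad = [0.0] * 3
    for row, t in zip(batch_logits, targets):
        p = softmax([z + d for z, d in zip(row, delta)])
        for c in range(3):
            grad[c] += (p[c] - (1.0 if c == t else 0.0)) / len(targets)
    delta = [d - lr * g for d, g in zip(delta, grad)]
loss_after = batch_loss(delta)
```

Because the logits are linear in δ, the loss is convex in δ and plain gradient descent with a modest step size reduces it monotonically on this toy batch.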
with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + logits = base_model.forward_logits(x_batch) + nll = F.cross_entropy( + logits.reshape(-1, logits.size(-1)).float(), + y_batch.reshape(-1), reduction="none", + ).reshape(bsz, seq_len) + for i, ws in enumerate(batch_ws): + wlen = wlens[i] + s = 0 if ws == 0 else max(wlen - stride, 0) + if use_ctw: + # VECTORIZED entropy gate: compute entropy for ALL scored tokens at once + scored_logits = logits[i, s:wlen, :].float() + scored_targets = y_batch[i, s:wlen] + neural_lp = F.log_softmax(scored_logits, dim=-1) + neural_probs = neural_lp.exp() + token_entropy = -(neural_probs * neural_lp).sum(dim=-1) + all_nll = F.cross_entropy(scored_logits, scored_targets, reduction="none") + for t_idx in range(scored_logits.size(0)): + abs_idx = s + t_idx + target_tok = y_batch[i, abs_idx].item() + if token_entropy[t_idx].item() >= args.ctw_entropy_threshold: + ctw_probs = ctw.predict(logits.device).clamp(1e-10) + ctw_lp = torch.log(ctw_probs) + max_ent = math.log(ctw.vocab_size) + ew = w_ctw * min(token_entropy[t_idx].item() / max_ent, 1.0) + mixed = (1.0 - ew) * neural_lp[t_idx] + ew * ctw_lp + mixed = mixed - mixed.logsumexp(dim=-1, keepdim=True) + token_nll = F.cross_entropy(mixed.unsqueeze(0), scored_targets[t_idx:t_idx+1], reduction="sum") + loss_sum += token_nll.to(torch.float64) + else: + loss_sum += all_nll[t_idx].to(torch.float64) + ctw.update(target_tok) + else: + scored_nll = nll[i, s:wlen].to(torch.float64) + loss_sum += scored_nll.sum() + token_count += float(wlen - s) + tgt, prev = y_batch[i, s:wlen], x_batch[i, s:wlen] + tb = base_bytes_lut[tgt].to(torch.float64) + tb += (has_leading_space_lut[tgt] & ~is_boundary_token_lut[prev]).to(torch.float64) + byte_count += tb.sum() + + # --- Phase 2: TRAIN on this chunk (already scored = legal) --- + is_last_chunk = (ci == num_chunks - 1) + if not is_last_chunk and args.ttt_epochs > 0: + base_model.train() + chunk_seqs = (chunk_end - chunk_start) // seq_len + if 
chunk_seqs > 0: + cos_lr = args.ttt_lr * 0.5 * (1.0 + math.cos(math.pi * ci / max(num_chunks - 1, 1))) + for pg in optimizer.param_groups: + pg['lr'] = cos_lr + my_seq_s = (chunk_seqs * rank) // world_size + my_seq_e = (chunk_seqs * (rank + 1)) // world_size + my_chunk_seqs = my_seq_e - my_seq_s + for _ep in range(args.ttt_epochs): + for bs in range(0, my_chunk_seqs, args.ttt_batch_seqs): + be = min(bs + args.ttt_batch_seqs, my_chunk_seqs) + actual_bs = my_seq_s + bs + start_tok = chunk_start + actual_bs * seq_len + end_tok = chunk_start + (my_seq_s + be) * seq_len + 1 + if end_tok > val_tokens.numel(): + continue + local = val_tokens[start_tok:end_tok].to(device=device, dtype=torch.int64) + x = local[:-1].reshape(-1, seq_len) + y = local[1:].reshape(-1, seq_len) + optimizer.zero_grad(set_to_none=True) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + loss = base_model(x, y) + loss.backward() + if world_size > 1: + for p in ttt_params: + if p.grad is not None: + dist.all_reduce(p.grad, op=dist.ReduceOp.AVG) + torch.nn.utils.clip_grad_norm_(ttt_params, args.ttt_grad_clip) + optimizer.step() + + if rank == 0 and (ci % 10 == 0 or ci == num_chunks - 1): + elapsed = time.perf_counter() - t0 + rl = loss_sum.item() / max(token_count.item(), 1) + rbpb = rl / math.log(2.0) * (token_count.item() / max(byte_count.item(), 1)) if token_count.item() > 0 else 0.0 + log0(f" ttt_chunk [{ci+1}/{num_chunks}] bpb={rbpb:.6f} time={elapsed:.1f}s") + + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(loss_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(token_count, op=dist.ReduceOp.SUM) + dist.all_reduce(byte_count, op=dist.ReduceOp.SUM) + + val_loss = (loss_sum / token_count).item() + val_bpb = val_loss / math.log(2.0) * (token_count.item() / byte_count.item()) + + for p in base_model.parameters(): + p.requires_grad_(True) + base_model.eval() + + log0(f"ttt_sliding:done val_loss={val_loss:.6f} val_bpb={val_bpb:.6f} " + f"elapsed={time.perf_counter() - 
t0:.1f}s") + return val_loss, val_bpb + + +def eval_val_sliding_ctw( + args: Hyperparameters, base_model: nn.Module, rank: int, world_size: int, + device: torch.device, val_tokens: Tensor, base_bytes_lut: Tensor, + has_leading_space_lut: Tensor, is_boundary_token_lut: Tensor, + stride: int, log0=print, +) -> tuple[float, float]: + """CTW-augmented sliding window eval. Mixes neural logits with CTW predictions. + Novel contribution: Bayesian-optimal sequential probability assignment at zero artifact cost.""" + seq_len = args.train_seq_len + total_tokens = val_tokens.numel() - 1 + window_starts = [ws for ws in range(0, total_tokens, stride) + if min(ws + seq_len, total_tokens) - ws >= 1] + total_windows = len(window_starts) + my_s = (total_windows * rank) // world_size + my_e = (total_windows * (rank + 1)) // world_size + my_windows = window_starts[my_s:my_e] + + loss_sum = torch.zeros((), device=device, dtype=torch.float64) + token_count = torch.zeros((), device=device, dtype=torch.float64) + byte_count = torch.zeros((), device=device, dtype=torch.float64) + ctw = SparseCTW(depth=args.ctw_depth, vocab_size=args.vocab_size, alpha=0.5) + w_ctw = args.ctw_weight + + log0(f"ctw_eval:start windows={len(my_windows)} ctw_weight={w_ctw} depth={args.ctw_depth}") + base_model.eval() + t0 = time.perf_counter() + with torch.inference_mode(): + for wi, ws in enumerate(my_windows): + if rank == 0 and wi % max(1, len(my_windows) // 10) == 0: + pct = 100.0 * wi / max(len(my_windows), 1) + log0(f" ctw_eval: {pct:.0f}% | {time.perf_counter() - t0:.1f}s") + end = min(ws + seq_len, total_tokens) + wlen = end - ws + chunk = val_tokens[ws:end + 1].to(dtype=torch.int64, device=device) + x = chunk[:-1].unsqueeze(0) + y = chunk[1:] + s = 0 if ws == 0 else max(wlen - stride, 0) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + logits = base_model.forward_logits(x) + logits_scored = logits[0, s:wlen, :].float() + targets_scored = y[s:wlen] + # VECTORIZED entropy gate + 
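The entropy gate computed next can be sketched standalone: CTW is skipped entirely when the neural entropy falls below `ctw_entropy_threshold`, and otherwise the mix weight is scaled by the entropy normalized against `log(vocab_size)`. A pure-Python sketch with a toy 4-symbol vocabulary (`entropy` and `effective_ctw_weight` are illustrative names):

```python
import math

def entropy(probs):
    return -sum(p * math.log(p) for p in probs if p > 0.0)

def effective_ctw_weight(neural_probs, w_ctw, threshold, vocab_size):
    """Entropy gate: no CTW below threshold; otherwise scale by H / log(V)."""
    h = entropy(neural_probs)
    if h < threshold:
        return 0.0
    return w_ctw * min(h / math.log(vocab_size), 1.0)

V = 4
uniform = [1.0 / V] * V               # maximum entropy: log(4) ~ 1.386
confident = [0.97, 0.01, 0.01, 0.01]  # low entropy, well under the threshold
w_uniform = effective_ctw_weight(uniform, w_ctw=0.1, threshold=0.5, vocab_size=V)
w_conf = effective_ctw_weight(confident, w_ctw=0.1, threshold=0.5, vocab_size=V)
```

A maximally uncertain neural prediction gets the full `w_ctw`, while a confident one bypasses CTW entirely, matching the gating logic in the scoring loop.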
neural_lp = F.log_softmax(logits_scored, dim=-1) + neural_probs = neural_lp.exp() + token_entropy = -(neural_probs * neural_lp).sum(dim=-1) + all_nll = F.cross_entropy(logits_scored, targets_scored, reduction="none") + + for t_idx in range(logits_scored.size(0)): + target_tok = targets_scored[t_idx].item() + if token_entropy[t_idx].item() >= args.ctw_entropy_threshold: + ctw_probs = ctw.predict(logits.device).clamp(1e-10) + ctw_lp = torch.log(ctw_probs) + max_ent = math.log(ctw.vocab_size) + ew = w_ctw * min(token_entropy[t_idx].item() / max_ent, 1.0) + mixed = (1.0 - ew) * neural_lp[t_idx] + ew * ctw_lp + mixed = mixed - mixed.logsumexp(dim=-1, keepdim=True) + token_nll = F.cross_entropy(mixed.unsqueeze(0), targets_scored[t_idx:t_idx+1], reduction="sum") + loss_sum += token_nll.to(torch.float64) + else: + loss_sum += all_nll[t_idx].to(torch.float64) + ctw.update(target_tok) + token_count += float(logits_scored.size(0)) + tgt = y[s:wlen] + prev = chunk[s:wlen] + tb = base_bytes_lut[tgt].to(torch.float64) + tb += (has_leading_space_lut[tgt] & ~is_boundary_token_lut[prev]).to(torch.float64) + byte_count += tb.sum() + + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(loss_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(token_count, op=dist.ReduceOp.SUM) + dist.all_reduce(byte_count, op=dist.ReduceOp.SUM) + val_loss = (loss_sum / token_count).item() + val_bpb = val_loss / math.log(2.0) * (token_count.item() / byte_count.item()) + log0(f"ctw_eval:done val_loss={val_loss:.6f} val_bpb={val_bpb:.6f} elapsed={time.perf_counter() - t0:.1f}s") + return val_loss, val_bpb + + +# --- GPTQ-lite int6 quantization --- + +def _classify_param(name: str) -> str: + if "tok_emb" in name or "lm_head" in name: + return "embed" + if ".mlp." in name: + return "mlp" + if ".attn." in name or (".proj." in name and ".mlp." 
not in name): + return "attn" + return "other" +def quantize_int6_per_row(t: Tensor, clip_range: int = 31) -> tuple[Tensor, Tensor]: + t32 = t.float() + if t32.ndim == 2: + best_q, best_s, best_err = None, None, float('inf') + for pct in [0.9990, 0.9995, 0.9999, 0.99999, 1.0]: + if pct < 1.0: + row_clip = torch.quantile(t32.abs(), pct, dim=1) + else: + row_clip = t32.abs().amax(dim=1) + s = (row_clip / clip_range).clamp_min(1.0 / clip_range).to(torch.float16) + q = torch.clamp(torch.round(t32 / s.float()[:, None]), -clip_range, clip_range).to(torch.int8) + recon = q.float() * s.float()[:, None] + err = (t32 - recon).pow(2).mean().item() + if err < best_err: + best_q, best_s, best_err = q, s, err + return best_q, best_s + amax = t32.abs().max().item() + scale = torch.tensor(amax / clip_range if amax > 0 else 1.0, dtype=torch.float16) + q = torch.clamp(torch.round(t32 / scale.float()), -clip_range, clip_range).to(torch.int8) + return q, scale + +def _unbank_state_dict(sd: dict[str, Tensor], num_layers: int) -> dict[str, Tensor]: + """Convert 3D bank tensors into individual 2D tensors with standard names.""" + out: dict[str, Tensor] = {} + n = num_layers + for name, tensor in sd.items(): + if name == "qo_bank": + for i in range(n): + out[f"blocks.{i}.attn.c_q.weight"] = tensor[i] + out[f"blocks.{i}.attn.proj.weight"] = tensor[n + i] + elif name == "kv_bank": + for i in range(n): + out[f"blocks.{i}.attn.c_k.weight"] = tensor[i] + out[f"blocks.{i}.attn.c_v.weight"] = tensor[n + i] + elif name == "mlp_up_bank": + for i in range(n): + out[f"blocks.{i}.mlp.fc.weight"] = tensor[i] + elif name == "mlp_down_bank": + for i in range(n): + out[f"blocks.{i}.mlp.proj.weight"] = tensor[i] + else: + out[name] = tensor + return out + +def _rebank_state_dict(sd: dict[str, Tensor], num_layers: int, template_sd: dict[str, Tensor]) -> dict[str, Tensor]: + """Convert individual 2D tensors back into 3D bank tensors.""" + out: dict[str, Tensor] = {} + n = num_layers + # Reconstruct banks 
from individual weight keys + qo_slices = [None] * (2 * n) + kv_slices = [None] * (2 * n) + up_slices = [None] * n + down_slices = [None] * n + consumed = set() + for i in range(n): + qk = f"blocks.{i}.attn.c_q.weight" + if qk in sd: + qo_slices[i] = sd[qk] + consumed.add(qk) + ok = f"blocks.{i}.attn.proj.weight" + if ok in sd: + qo_slices[n + i] = sd[ok] + consumed.add(ok) + kk = f"blocks.{i}.attn.c_k.weight" + if kk in sd: + kv_slices[i] = sd[kk] + consumed.add(kk) + vk = f"blocks.{i}.attn.c_v.weight" + if vk in sd: + kv_slices[n + i] = sd[vk] + consumed.add(vk) + fk = f"blocks.{i}.mlp.fc.weight" + if fk in sd: + up_slices[i] = sd[fk] + consumed.add(fk) + dk = f"blocks.{i}.mlp.proj.weight" + if dk in sd: + down_slices[i] = sd[dk] + consumed.add(dk) + out["qo_bank"] = torch.stack(qo_slices).to(dtype=template_sd["qo_bank"].dtype) + out["kv_bank"] = torch.stack(kv_slices).to(dtype=template_sd["kv_bank"].dtype) + out["mlp_up_bank"] = torch.stack(up_slices).to(dtype=template_sd["mlp_up_bank"].dtype) + out["mlp_down_bank"] = torch.stack(down_slices).to(dtype=template_sd["mlp_down_bank"].dtype) + for name, tensor in sd.items(): + if name not in consumed: + out[name] = tensor + return out + +def mixed_quantize_int6(state_dict: dict[str, Tensor], int6_cats: set[str]): + num_layers_total = max( + (int(k.split(".")[1]) for k in state_dict if k.startswith("blocks.")), + default=0, + ) + 1 + late_k_layers = set(range(num_layers_total - 2, num_layers_total)) + result: dict[str, Tensor] = {} + meta: dict[str, object] = {} + for name, tensor in state_dict.items(): + t = tensor.detach().cpu().contiguous() + cat = _classify_param(name) + if not t.is_floating_point() or t.numel() <= 65536: + result[name] = t.to(torch.float16) if t.is_floating_point() else t + meta[name] = "passthrough" + continue + if any(p in name for p in CONTROL_TENSOR_NAME_PATTERNS): + result[name] = t.float() + meta[name] = "passthrough_ctrl" + continue + if cat in int6_cats and t.ndim >= 1: + q, s = 
quantize_int6_per_row(t) + result[name + ".q"] = q + result[name + ".scale"] = s + meta[name] = {"type": "int6"} + else: + q, s = quantize_float_tensor(t) + result[name + ".q"] = q + result[name + ".scale"] = s + meta[name] = {"type": "int8"} + return result, meta +def dequantize_mixed_int6(result: dict[str, Tensor], meta: dict[str, object], + template_sd: dict[str, Tensor]) -> dict[str, Tensor]: + out: dict[str, Tensor] = {} + for name, orig in template_sd.items(): + info = meta.get(name) + if info is None: + continue + orig_dtype = orig.dtype + if info in ("passthrough", "passthrough_ctrl", "passthrough_fp16"): + t = result[name] + if t.dtype == torch.float16 and orig_dtype in (torch.float32, torch.bfloat16): + t = t.to(orig_dtype) + out[name] = t + continue + q, s = result[name + ".q"], result[name + ".scale"] + if s.ndim > 0: + out[name] = (q.float() * s.float().view(q.shape[0], *([1] * (q.ndim - 1)))).to(orig_dtype) + else: + out[name] = (q.float() * float(s.item())).to(orig_dtype) + return out + +# --- Training --- + +def main() -> None: + code = Path(__file__).read_text(encoding="utf-8") + args = Hyperparameters() + # zeropower_via_newtonschulz5 runs eagerly with bmm -- do NOT compile + distributed = "RANK" in os.environ and "WORLD_SIZE" in os.environ + rank = int(os.environ.get("RANK", "0")) + world_size = int(os.environ.get("WORLD_SIZE", "1")) + local_rank = int(os.environ.get("LOCAL_RANK", "0")) + if world_size <= 0: + raise ValueError(f"WORLD_SIZE must be positive, got {world_size}") + if 8 % world_size != 0: + raise ValueError(f"WORLD_SIZE={world_size} must divide 8 so grad_accum_steps stays integral") + grad_accum_steps = 8 // world_size + grad_scale = 1.0 / grad_accum_steps + if not torch.cuda.is_available(): + raise RuntimeError("CUDA is required") + device = torch.device("cuda", local_rank) + torch.cuda.set_device(device) + if distributed: + dist.init_process_group(backend="nccl", device_id=device) + dist.barrier() + master_process = rank == 0 + 
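The `quantize_int6_per_row` routine above searches several clip percentiles per row; its core step is symmetric round-to-nearest with a per-row scale of `max|x| / 31`. A pure-Python sketch of that single step for one row (no percentile search, plain floats instead of the script's fp16 scales):

```python
def quantize_row_int6(row, clip_range=31):
    """Symmetric per-row quantization to integers in [-clip_range, clip_range]."""
    amax = max(abs(v) for v in row)
    scale = amax / clip_range if amax > 0 else 1.0
    q = [max(-clip_range, min(clip_range, round(v / scale))) for v in row]
    return q, scale

def dequantize_row(q, scale):
    return [v * scale for v in q]

row = [0.31, -1.24, 0.007, 0.62, -0.05]
q, scale = quantize_row_int6(row)
recon = dequantize_row(q, scale)
max_err = max(abs(a - b) for a, b in zip(row, recon))
```

With the max-abs scale, no value is actually clipped, so the reconstruction error is bounded by half a quantization step (`scale / 2`) per element.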
torch.backends.cuda.matmul.allow_tf32 = True + torch.backends.cudnn.allow_tf32 = True + from torch.backends.cuda import enable_cudnn_sdp, enable_flash_sdp, enable_math_sdp, enable_mem_efficient_sdp + enable_cudnn_sdp(False) + enable_flash_sdp(True) + enable_mem_efficient_sdp(False) + enable_math_sdp(False) + logfile = None + if master_process: + os.makedirs("logs", exist_ok=True) + logfile = f"logs/{args.run_id}.txt" + print(logfile) + def log0(msg: str, console: bool = True) -> None: + if not master_process: + return + if console: + print(msg) + if logfile is not None: + with open(logfile, "a", encoding="utf-8") as f: + print(msg, file=f) + log0(code, console=False) + log0("=" * 100, console=False) + log0(f"Running Python {sys.version}", console=False) + log0(f"Running PyTorch {torch.__version__}", console=False) + log0( + subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, check=False).stdout, + console=False, + ) + log0("=" * 100, console=False) + random.seed(args.seed) + np.random.seed(args.seed) + torch.manual_seed(args.seed) + torch.cuda.manual_seed_all(args.seed) + if not args.tokenizer_path.endswith(".model"): + raise ValueError(f"Script only setup for SentencePiece .model file: {args.tokenizer_path}") + sp = spm.SentencePieceProcessor(model_file=args.tokenizer_path) + if int(sp.vocab_size()) != args.vocab_size: + raise ValueError( + f"VOCAB_SIZE={args.vocab_size} does not match tokenizer vocab_size={int(sp.vocab_size())}" + ) + dataset_dir = Path(args.data_path).resolve() + actual_train_files = len(list(dataset_dir.glob("fineweb_train_*.bin"))) + effective_eval_seq_len = args.eval_seq_len if args.eval_seq_len > 0 else args.train_seq_len + val_seq_len = max(args.train_seq_len, effective_eval_seq_len) + val_tokens = load_validation_tokens(args.val_files, val_seq_len) + base_bytes_lut, has_leading_space_lut, is_boundary_token_lut = build_sentencepiece_luts( + sp, args.vocab_size, device + ) + log0(f"val_bpb:enabled 
tokenizer_kind=sentencepiece tokenizer_path={args.tokenizer_path}") + log0(f"train_loader:dataset:{dataset_dir.name} train_shards:{actual_train_files}") + log0(f"val_loader:shards pattern={args.val_files} tokens:{val_tokens.numel() - 1}") + CastedLinear._qat_enabled = args.qat_enabled + base_model = GPT( + vocab_size=args.vocab_size, + num_layers=args.num_layers, + model_dim=args.model_dim, + num_heads=args.num_heads, + num_kv_heads=args.num_kv_heads, + mlp_mult=args.mlp_mult, + tie_embeddings=args.tie_embeddings, + tied_embed_init_std=args.tied_embed_init_std, + logit_softcap=args.logit_softcap, + rope_base=args.rope_base, + qk_gain_init=args.qk_gain_init, + mtp_num_heads=args.mtp_num_heads, + mtp_loss_weight=args.mtp_loss_weight, + bigram_vocab_size=args.bigram_vocab_size, + bigram_dim=args.bigram_dim, + xsa_last_n=args.xsa_last_n, + rope_dims=args.rope_dims, + ln_scale=args.ln_scale, + dtg=args.dtg_enabled, + ve_enabled=args.ve_enabled, + ve_dim=args.ve_dim, + ve_layers=args.ve_layers, + gated_attention=args.gated_attention, + value_residual=args.value_residual, + ).to(device).bfloat16() + # Banks stay FP32 (like CastedLinear weights), cast to BF16 in forward + base_model.qo_bank.data = base_model.qo_bank.data.float() + base_model.kv_bank.data = base_model.kv_bank.data.float() + base_model.mlp_up_bank.data = base_model.mlp_up_bank.data.float() + base_model.mlp_down_bank.data = base_model.mlp_down_bank.data.float() + for module in base_model.modules(): + if isinstance(module, CastedLinear): + module.float() + restore_low_dim_params_to_fp32(base_model) + # No DDP -- Parallel Muon handles bank grad communication via reduce-scatter, + # and non-bank grads are manually all-reduced before Adam steps. 
+ compiled_model = torch.compile(base_model, dynamic=False, fullgraph=True) + model = compiled_model + + # Optimizer split: + # - 4 parameter banks -> Muon (batched Newton-Schulz) + # - token embedding -> Adam + # - scalars/control tensors -> Adam + # - bigram proj, mtp heads, VE proj -> Adam (small matrix params not worth banking) + matrix_params = [ + base_model.qo_bank, base_model.kv_bank, + base_model.mlp_up_bank, base_model.mlp_down_bank, + ] + block_named_params = list(base_model.blocks.named_parameters()) + scalar_params = [ + p + for name, p in block_named_params + if p.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) + ] + if base_model.skip_weights.numel() > 0: + scalar_params.append(base_model.skip_weights) + scalar_params.append(base_model.smear.gate) + if base_model.bigram is not None: + scalar_params.append(base_model.bigram.scale) + token_lr = args.tied_embed_lr if args.tie_embeddings else args.embed_lr + tok_params = [{"params": [base_model.tok_emb.weight], "lr": token_lr, "base_lr": token_lr}] + if base_model.bigram is not None: + tok_params.append({"params": [base_model.bigram.embed.weight], "lr": token_lr, "base_lr": token_lr}) + if base_model.bigram.proj is not None: + scalar_params.append(base_model.bigram.proj.weight) + if base_model.ve_shared is not None: + tok_params.append({"params": [base_model.ve_shared.embed.weight], "lr": token_lr, "base_lr": token_lr}) + if base_model.ve_shared.proj is not None: + scalar_params.append(base_model.ve_shared.proj.weight) + scalar_params.append(base_model.ve_shared.scale) + for s in base_model.ve_layer_scales: + scalar_params.append(s) + optimizer_tok = torch.optim.AdamW( + tok_params, + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + weight_decay=args.adam_wd, + fused=True, + ) + optimizer_muon = Muon( + matrix_params, + lr=args.matrix_lr, + momentum=args.muon_momentum, + backend_steps=args.muon_backend_steps, + weight_decay=args.muon_wd, + ) + for group in 
optimizer_muon.param_groups: + group["base_lr"] = args.matrix_lr + optimizer_scalar = torch.optim.AdamW( + [{"params": scalar_params, "lr": args.scalar_lr, "base_lr": args.scalar_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + weight_decay=args.adam_wd, + fused=True, + ) + # Non-bank params that need manual all-reduce (replicated across GPUs) + replicated_params = list(optimizer_tok.param_groups[0]["params"]) + for pg in optimizer_tok.param_groups[1:]: + replicated_params.extend(pg["params"]) + replicated_params.extend(scalar_params) + + optimizer_head = None + if base_model.lm_head is not None: + optimizer_head = torch.optim.Adam( + [{"params": [base_model.lm_head.weight], "lr": args.head_lr, "base_lr": args.head_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=True, + ) + replicated_params.append(base_model.lm_head.weight) + optimizers: list[torch.optim.Optimizer] = [optimizer_tok, optimizer_muon, optimizer_scalar] + if optimizer_head is not None: + optimizers.append(optimizer_head) + n_params = sum(p.numel() for p in base_model.parameters()) + mtp_params = sum(p.numel() for p in base_model.mtp_heads.parameters()) + log0(f"model_params:{n_params}") + log0(f"mtp_num_heads:{args.mtp_num_heads} mtp_loss_weight:{args.mtp_loss_weight} mtp_params:{mtp_params}") + xsa_layers = [i for i, b in enumerate(base_model.blocks) if b.attn.use_xsa] + log0(f"XSA:last_{args.xsa_last_n} active_layers:{xsa_layers}") + log0(f"world_size:{world_size} grad_accum_steps:{grad_accum_steps}") + log0("sdp_backends:cudnn=False flash=True mem_efficient=False math=False") + log0(f"attention_mode:gqa num_heads:{args.num_heads} num_kv_heads:{args.num_kv_heads}") + log0( + f"tie_embeddings:{args.tie_embeddings} embed_lr:{token_lr} " + f"head_lr:{args.head_lr if base_model.lm_head is not None else 0.0} " + f"matrix_lr:{args.matrix_lr} scalar_lr:{args.scalar_lr}" + ) + log0( + f"train_batch_tokens:{args.train_batch_tokens} train_seq_len:{args.train_seq_len} " + 
f"iterations:{args.iterations} warmup_steps:{args.warmup_steps} " + f"max_wallclock_seconds:{args.max_wallclock_seconds:.3f}" + ) + log0(f"seed:{args.seed}") + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + def zero_grad_all() -> None: + for opt in optimizers: + opt.zero_grad(set_to_none=True) + max_wallclock_ms = 1000.0 * args.max_wallclock_seconds if args.max_wallclock_seconds > 0 else None + def lr_mul(step: int, elapsed_ms: float) -> float: + if args.warmdown_iters <= 0: + return 1.0 + if max_wallclock_ms is None: + warmdown_start = max(args.iterations - args.warmdown_iters, 0) + return max((args.iterations - step) / max(args.warmdown_iters, 1), 0.0) if warmdown_start <= step < args.iterations else 1.0 + step_ms = elapsed_ms / max(step, 1) + warmdown_ms = args.warmdown_iters * step_ms + remaining_ms = max(max_wallclock_ms - elapsed_ms, 0.0) + return remaining_ms / max(warmdown_ms, 1e-9) if remaining_ms <= warmdown_ms else 1.0 + if args.warmup_steps > 0: + initial_model_state = {name: tensor.detach().cpu().clone() for name, tensor in base_model.state_dict().items()} + initial_optimizer_states = [copy.deepcopy(opt.state_dict()) for opt in optimizers] + model.train() + for warmup_step in range(args.warmup_steps): + zero_grad_all() + for micro_step in range(grad_accum_steps): + x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + warmup_loss = model(x, y) + (warmup_loss * grad_scale).backward() + # All-reduce all grads for warmup (simple, not optimized) + if distributed: + for p in base_model.parameters(): + if p.grad is not None: + dist.all_reduce(p.grad, op=dist.ReduceOp.AVG) + for opt in optimizers: + opt.step() + zero_grad_all() + if args.warmup_steps <= 20 or (warmup_step + 1) % 10 == 0 or warmup_step + 1 == args.warmup_steps: + log0(f"warmup_step:{warmup_step + 1}/{args.warmup_steps}") + 
base_model.load_state_dict(initial_model_state, strict=True) + for opt, state in zip(optimizers, initial_optimizer_states, strict=True): + opt.load_state_dict(state) + zero_grad_all() + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + swa_state: dict[str, Tensor] | None = None + swa_count = 0 + from collections import deque + lawa_queue: deque[dict[str, Tensor]] = deque(maxlen=args.lawa_k) + ema_state = {name: t.detach().float().clone() for name, t in base_model.state_dict().items()} + ema_decay = 0.997 + training_time_ms = 0.0 + stop_after_step: int | None = None + torch.cuda.synchronize() + t0 = time.perf_counter() + step = 0 + while True: + last_step = step == args.iterations or (stop_after_step is not None and step >= stop_after_step) + should_validate = last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0) + if should_validate: + torch.cuda.synchronize() + training_time_ms += 1000.0 * (time.perf_counter() - t0) + val_loss, val_bpb = eval_val( + args, + model, + rank, + world_size, + device, + grad_accum_steps, + val_tokens, + base_bytes_lut, + has_leading_space_lut, + is_boundary_token_lut, + ) + log0( + f"step:{step}/{args.iterations} val_loss:{val_loss:.4f} val_bpb:{val_bpb:.4f} " + f"train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms / max(step, 1):.2f}ms" + ) + torch.cuda.synchronize() + t0 = time.perf_counter() + if last_step: + if stop_after_step is not None and step < args.iterations: + log0( + f"stopping_early: wallclock_cap train_time:{training_time_ms:.0f}ms " + f"step:{step}/{args.iterations}" + ) + break + elapsed_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + scale = lr_mul(step, elapsed_ms) + if args.late_qat_threshold > 0 and scale < args.late_qat_threshold and not CastedLinear._qat_enabled: + CastedLinear._qat_enabled = True + log0(f"late_qat:enabled step:{step} scale:{scale:.4f}") + zero_grad_all() + train_loss = torch.zeros((), device=device) + for micro_step 
in range(grad_accum_steps): + x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + loss = model(x, y) + train_loss += loss.detach() + (loss * grad_scale).backward() + train_loss /= grad_accum_steps + frac = min(step / args.muon_momentum_warmup_steps, 1.0) if args.muon_momentum_warmup_steps > 0 else 1.0 + muon_momentum = (1 - frac) * args.muon_momentum_warmup_start + frac * args.muon_momentum + for group in optimizer_muon.param_groups: + group["momentum"] = muon_momentum + for opt in optimizers: + for group in opt.param_groups: + group["lr"] = group["base_lr"] * scale + if args.grad_clip_norm > 0: + torch.nn.utils.clip_grad_norm_(base_model.parameters(), args.grad_clip_norm) + # === 3-phase overlapped optimizer step === + # Phase 1: Launch async reduce-scatter for banks (biggest first) + optimizer_muon.launch_reduce_scatters() + # Phase 2: All-reduce non-bank grads + step Adam (while bank RS is in-flight) + if distributed: + for p in replicated_params: + if p.grad is not None: + dist.all_reduce(p.grad, op=dist.ReduceOp.AVG) + optimizer_tok.step() + optimizer_scalar.step() + if optimizer_head is not None: + optimizer_head.step() + # Phase 3: Wait for RS, local NS5, all-gather (banks processed last) + optimizer_muon.step() + zero_grad_all() + # EMA update + with torch.no_grad(): + for name, t in base_model.state_dict().items(): + ema_state[name].mul_(ema_decay).add_(t.detach().float(), alpha=1.0 - ema_decay) + step += 1 + approx_training_time_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + if args.swa_enabled and scale < 0.2 and step % args.swa_every == 0: + if swa_state is None: + swa_state = {name: t.detach().cpu().clone() for name, t in base_model.state_dict().items()} + swa_count = 1 + log0(f"swa:start step:{step}") + else: + for name, t in base_model.state_dict().items(): + swa_state[name] += t.detach().cpu() + swa_count += 1 + 
if args.lawa_enabled and step % args.lawa_freq == 0: + lawa_queue.append({name: t.detach().cpu().clone() for name, t in base_model.state_dict().items()}) + should_log_train = ( + args.train_log_every > 0 + and (step <= 10 or step % args.train_log_every == 0 or stop_after_step is not None) + ) + if should_log_train: + log0( + f"step:{step}/{args.iterations} train_loss:{train_loss.item():.4f} " + f"train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms / step:.2f}ms" + ) + reached_cap = max_wallclock_ms is not None and approx_training_time_ms >= max_wallclock_ms + if distributed and max_wallclock_ms is not None: + reached_cap_tensor = torch.tensor(int(reached_cap), device=device) + dist.all_reduce(reached_cap_tensor, op=dist.ReduceOp.MAX) + reached_cap = bool(reached_cap_tensor.item()) + if stop_after_step is None and reached_cap: + stop_after_step = step + log0( + f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB " + f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB" + ) + # Apply weight averaging + if args.lawa_enabled and len(lawa_queue) > 1: + log0(f"lawa:applying LAWA averaging k={len(lawa_queue)}") + current_state = base_model.state_dict() + avg_state = {name: torch.zeros(t.shape, dtype=torch.float32, device='cpu') for name, t in current_state.items()} + for snap in lawa_queue: + for name in avg_state: + avg_state[name] += snap[name].float() + for name in avg_state: + avg_state[name] /= len(lawa_queue) + avg_state[name] = avg_state[name].to(dtype=current_state[name].dtype) + base_model.load_state_dict(avg_state, strict=True) + else: + log0("ema:applying EMA weights") + current_state = base_model.state_dict() + avg_state = {name: t.to(dtype=current_state[name].dtype) for name, t in ema_state.items()} + base_model.load_state_dict(avg_state, strict=True) + torch.cuda.synchronize() + t_diag = time.perf_counter() + diag_val_loss, diag_val_bpb = eval_val( + args, compiled_model, rank, world_size, 
device, grad_accum_steps, + val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, + ) + torch.cuda.synchronize() + log0( + f"DIAGNOSTIC post_ema val_loss:{diag_val_loss:.4f} val_bpb:{diag_val_bpb:.4f} " + f"eval_time:{1000.0 * (time.perf_counter() - t_diag):.0f}ms" + ) + full_state_dict = base_model.state_dict() + export_sd = {k: v for k, v in full_state_dict.items() if "mtp_heads" not in k} + excluded_mtp = sum(int(t.numel()) for k, t in full_state_dict.items() if "mtp_heads" in k) + if excluded_mtp > 0: + log0(f"export_excluding_mtp_params:{excluded_mtp}") + if master_process: + torch.save(export_sd, "final_model.pt") + model_bytes = os.path.getsize("final_model.pt") + code_bytes = len(code.encode("utf-8")) + log0(f"Serialized model: {model_bytes} bytes") + log0(f"Code size: {code_bytes} bytes") + # Unbank 3D tensors into individual 2D tensors for quantization + sd_cpu = {k: v.detach().cpu() for k, v in export_sd.items()} + unbanked_sd = _unbank_state_dict(sd_cpu, args.num_layers) + quant_result, quant_meta = mixed_quantize_int6(unbanked_sd, {"mlp", "attn"}) + quant_buf = io.BytesIO() + torch.save({"w": quant_result, "m": quant_meta}, quant_buf) + quant_raw = quant_buf.getvalue() + quant_blob = lzma.compress(quant_raw, preset=6) + if master_process: + with open("final_model.int6.ptz", "wb") as f: + f.write(quant_blob) + quant_file_bytes = len(quant_blob) + code_bytes = len(code.encode("utf-8")) + log0(f"Serialized model int6+lzma: {quant_file_bytes} bytes") + log0(f"Total submission size int6+lzma: {quant_file_bytes + code_bytes} bytes") + if distributed: + dist.barrier() + with open("final_model.int6.ptz", "rb") as f: + quant_blob_disk = f.read() + quant_state = torch.load( + io.BytesIO(lzma.decompress(quant_blob_disk)), + map_location="cpu", + ) + deq_unbanked = dequantize_mixed_int6(quant_state["w"], quant_state["m"], unbanked_sd) + # Re-bank the dequantized tensors + deq_state = _rebank_state_dict(deq_unbanked, args.num_layers, sd_cpu) + 
eval_model = GPT( + vocab_size=args.vocab_size, num_layers=args.num_layers, model_dim=args.model_dim, + num_heads=args.num_heads, num_kv_heads=args.num_kv_heads, mlp_mult=args.mlp_mult, + tie_embeddings=args.tie_embeddings, tied_embed_init_std=args.tied_embed_init_std, + logit_softcap=args.logit_softcap, rope_base=args.rope_base, qk_gain_init=args.qk_gain_init, + mtp_num_heads=0, mtp_loss_weight=0.0, + bigram_vocab_size=args.bigram_vocab_size, bigram_dim=args.bigram_dim, + xsa_last_n=args.xsa_last_n, + rope_dims=args.rope_dims, ln_scale=args.ln_scale, dtg=args.dtg_enabled, + ve_enabled=args.ve_enabled, ve_dim=args.ve_dim, ve_layers=args.ve_layers, + gated_attention=args.gated_attention, value_residual=args.value_residual, + ).to(device).bfloat16() + eval_model.qo_bank.data = eval_model.qo_bank.data.float() + eval_model.kv_bank.data = eval_model.kv_bank.data.float() + eval_model.mlp_up_bank.data = eval_model.mlp_up_bank.data.float() + eval_model.mlp_down_bank.data = eval_model.mlp_down_bank.data.float() + for m in eval_model.modules(): + if isinstance(m, CastedLinear): + m.float() + restore_low_dim_params_to_fp32(eval_model) + eval_model.load_state_dict(deq_state, strict=True) + compiled_eval = torch.compile(eval_model, dynamic=False, fullgraph=True) + torch.cuda.synchronize() + t_qeval = time.perf_counter() + q_val_loss, q_val_bpb = eval_val( + args, compiled_eval, rank, world_size, device, grad_accum_steps, + val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, + eval_seq_len=effective_eval_seq_len, + ) + torch.cuda.synchronize() + log0( + f"final_int6_roundtrip val_loss:{q_val_loss:.4f} val_bpb:{q_val_bpb:.4f} " + f"eval_time:{1000.0 * (time.perf_counter() - t_qeval):.0f}ms" + ) + log0(f"final_int6_roundtrip_exact val_loss:{q_val_loss:.8f} val_bpb:{q_val_bpb:.8f}") + sw_seq_len = effective_eval_seq_len + if args.eval_stride > 0 and args.eval_stride < sw_seq_len: + torch.cuda.synchronize() + t_slide = time.perf_counter() + sw_val_loss, 
sw_val_bpb = eval_val_sliding( + args, eval_model, rank, world_size, device, + val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, + stride=args.eval_stride, + eval_seq_len=sw_seq_len, + ) + torch.cuda.synchronize() + log0( + f"final_int6_sliding_window val_loss:{sw_val_loss:.4f} val_bpb:{sw_val_bpb:.4f} " + f"stride:{args.eval_stride} eval_time:{1000.0 * (time.perf_counter() - t_slide):.0f}ms" + ) + log0(f"final_int6_sliding_window_exact val_loss:{sw_val_loss:.8f} val_bpb:{sw_val_bpb:.8f}") + log0(f"final_int8_zlib_roundtrip_exact val_loss:{sw_val_loss:.8f} val_bpb:{sw_val_bpb:.8f}") + if args.eval_stride != 64 and 64 < sw_seq_len: + torch.cuda.synchronize() + t_slide64 = time.perf_counter() + sw64_val_loss, sw64_val_bpb = eval_val_sliding( + args, eval_model, rank, world_size, device, + val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, + stride=64, + eval_seq_len=sw_seq_len, + ) + torch.cuda.synchronize() + log0( + f"final_int6_sliding_window_s64 val_loss:{sw64_val_loss:.4f} val_bpb:{sw64_val_bpb:.4f} " + f"stride:64 eval_time:{1000.0 * (time.perf_counter() - t_slide64):.0f}ms" + ) + log0(f"final_int6_sliding_window_s64_exact val_loss:{sw64_val_loss:.8f} val_bpb:{sw64_val_bpb:.8f}") + log0(f"final_int8_zlib_roundtrip_exact val_loss:{sw64_val_loss:.8f} val_bpb:{sw64_val_bpb:.8f}") + # Legal score-first TTT (PR #461 recipe) + if args.ttt_enabled: + torch.cuda.synchronize() + t_ttt = time.perf_counter() + ttt_loss, ttt_bpb = eval_val_sliding_ttt( + args, eval_model, rank, world_size, device, + val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, + stride=args.eval_stride, log0=log0, + ) + torch.cuda.synchronize() + log0(f"legal_ttt val_loss:{ttt_loss:.4f} val_bpb:{ttt_bpb:.4f} " + f"eval_time:{1000.0 * (time.perf_counter() - t_ttt):.0f}ms") + log0(f"legal_ttt_exact val_loss:{ttt_loss:.8f} val_bpb:{ttt_bpb:.8f}") + # Novel: CTW eval-time augmentation (standalone, when TTT is disabled) + # When TTT is 
enabled, CTW is already integrated into the TTT scoring loop above + if args.ctw_weight > 0 and not args.ttt_enabled: + torch.cuda.synchronize() + t_ctw = time.perf_counter() + ctw_loss, ctw_bpb = eval_val_sliding_ctw( + args, eval_model, rank, world_size, device, + val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, + stride=args.eval_stride, log0=log0, + ) + torch.cuda.synchronize() + log0(f"ctw_augmented val_loss:{ctw_loss:.4f} val_bpb:{ctw_bpb:.4f} " + f"eval_time:{1000.0 * (time.perf_counter() - t_ctw):.0f}ms") + log0(f"ctw_augmented_exact val_loss:{ctw_loss:.8f} val_bpb:{ctw_bpb:.8f}") + if distributed: + dist.destroy_process_group() +if __name__ == "__main__": + main() diff --git a/records/track_non_record_16mb/2026-03-29_SLOT_on_LeakyReLU_TTT_ParallelMuon/README.md b/records/track_non_record_16mb/2026-03-29_SLOT_on_LeakyReLU_TTT_ParallelMuon/README.md new file mode 100644 index 0000000000..cba47b1094 --- /dev/null +++ b/records/track_non_record_16mb/2026-03-29_SLOT_on_LeakyReLU_TTT_ParallelMuon/README.md @@ -0,0 +1,161 @@ +# Non-Record: SLOT Eval-Time Augmentation on PR #549 SOTA Stack + +**val_bpb = 1.1185** (3-seed mean, std 0.0003) | ~15.9 MB | 8×H100 SXM + +First SLOT-based entry in Parameter Golf. Novel eval-time augmentation achieving **-0.0008 BPB** improvement over the baseline, consistent across all 3 seeds. 
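The headline "1.1185 (std 0.0003)" can be checked directly against the per-seed numbers in the tables below; a small sketch (values copied from the SLOT-enabled table) confirming that the reported std is the sample (n-1) standard deviation:

```python
# Sanity check of the headline mean/std against the per-seed Post-TTT+SLOT BPBs.
# Values are copied from the results table; stdev() is the sample (n-1) estimator.
from statistics import mean, stdev

slot_bpb = [1.1188, 1.1185, 1.1183]  # seeds 1337, 42, 2025
print(round(mean(slot_bpb), 4))   # 1.1185
print(round(stdev(slot_bpb), 4))  # 0.0003
```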
+ +## Results + +### SLOT-Enabled (3-seed) + +| Seed | Steps | Step Avg | Pre-TTT BPB | Post-TTT+SLOT BPB | TTT+SLOT Time | Artifact | +|------|-------|----------|-------------|-------------------|---------------|----------| +| 1337 | 7,127 | 84.2ms | 1.1385 | **1.1188** | 386s | 15,965,604 | +| 42 | 7,155 | 83.9ms | 1.1380 | **1.1185** | 388s | 15,882,932 | +| 2025 | 7,152 | 83.9ms | 1.1377 | **1.1183** | 385s | 15,994,920 | +| **Mean** | **7,145** | **84.0ms** | **1.1381** | **1.1185 (std 0.0003)** | **~386s** | — | + +### Baseline Without SLOT (3-seed, same codebase with SLOT_ENABLED=0) + +| Seed | Steps | Step Avg | Post-TTT BPB | TTT Time | +|------|-------|----------|-------------|----------| +| 1337 | 7,164 | 83.8ms | 1.1195 | 352s | +| 42 | 7,159 | 83.8ms | 1.1195 | 353s | +| 2025 | 7,164 | 83.8ms | 1.1189 | 350s | +| **Mean** | **7,162** | **83.8ms** | **1.1193 (std 0.0003)** | **~352s** | + +### SLOT vs Baseline Comparison + +| Metric | Baseline Mean | SLOT Mean | Delta | +|--------|-------------|-----------|-------| +| Post-TTT BPB | 1.1193 | **1.1185** | **-0.0008** | +| TTT eval time | 352s | 386s | +34s | +| SOTA (PR #549) | 1.1194 | — | — | +| **vs SOTA** | -0.0001 | **-0.0009** | — | + +### Also Tested: CTW (Negative Result) + +| Run | CTW Weight | Depth | BPB | TTT Time | Verdict | +|-----|-----------|-------|-----|----------|---------| +| CTW v1 (broken impl) | 0.1 | 4 | 1.1252 | 2,760s | **+0.005 worse, 46 min eval** | + +CTW (Context Tree Weighting) was also integrated and tested. A depth-4 Markov model over 1024 subword tokens provides no useful signal on top of a 1.12 BPB transformer — the neural model already captures everything CTW knows. Documented as a negative result. + +## Novel Contribution: SLOT (Sample-specific LM Optimization at Test-time) + +### What Is SLOT + +SLOT (Hu et al., arXiv:2505.12392v2) optimizes a single additive δ ∈ ℝ^d vector at the last hidden layer to adapt the model to each batch of sequences during evaluation. 
Unlike full TTT which updates all 27M model parameters via SGD, SLOT optimizes just 512 parameters through one linear layer. + +### Why SLOT Works + +SLOT addresses a different bottleneck than TTT: +- **TTT** adapts the model's internal representations to local data distribution (chunk-level) +- **SLOT** fine-tunes the mapping from final hidden states to logits (batch-level) + +These are complementary — TTT gives SLOT better hidden states to work with, and SLOT gives TTT-adapted representations a final correction before scoring. + +### Implementation: Deep Integration Inside TTT + +SLOT is integrated directly into the TTT scoring loop's Phase 1 — not as a separate eval pass. The architecture splits `forward_logits()` into `forward_hidden()` + `compute_logits()`, enabling SLOT to optimize δ between the two: + +```python +# Inside eval_val_sliding_ttt, Phase 1 scoring: +for each batch of windows: + # 1. Get hidden states from TTT-adapted model + H = model.forward_hidden(x_batch) # [bsz, seq_len, 512] + + # 2. SLOT: optimize delta on this batch + delta = zeros(1, 1, 512) # single vector, broadcasts + optimizer = AdamW([delta], lr=0.001) + for step in range(3): + logits = model.compute_logits(H + delta) + loss = CE(logits[:, :-1], targets[:, 1:]) + loss.backward() # gradients only through lm_head + optimizer.step() + + # 3. 
Score with adapted logits + final_logits = model.compute_logits(H + delta) + nll = CE(final_logits, targets) # used for BPB +``` + +Key properties: +- **Stacks on TTT**: δ operates on TTT-adapted hidden states, not base model outputs +- **Single combined score**: one BPB number from SLOT-adapted logits +- **Minimal overhead**: +34s to TTT eval (386s vs 352s), well within 10-min eval budget +- **Zero artifact cost**: δ is optimized from scratch per-batch during eval +- **Score-first compliant**: δ is optimized on the very tokens being scored, via the standard autoregressive shift (same tokens, but the model never conditions on future tokens) +- **Clean toggle**: `SLOT_ENABLED=0` reproduces baseline exactly + +### Score-First Legality Argument + +SLOT does not violate the score-first constraint because: +1. The model weights that generated H are frozen during δ optimization +2. δ is optimized using the standard autoregressive objective (predict token t+1 from tokens 1..t) +3. δ is a constant offset vector — it does not give the model access to future tokens +4. Each batch's δ is independent — no information leaks between batches + +SLOT is analogous to learned post-processing (like temperature scaling) rather than model training. 
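To make the mechanism above concrete, here is a self-contained toy sketch of the SLOT step. All dimensions and values are hypothetical (D=4, V=5, six positions), plain gradient descent stands in for the AdamW used in the run, and the cross-entropy gradient with respect to δ is written out by hand: with the head weights W and hidden states H frozen, a few gradient steps on the single additive vector δ reduce the scored loss.

```python
# Toy SLOT sketch (hypothetical sizes, NOT the repo's implementation):
# optimize one additive vector delta on frozen hidden states through a
# frozen linear head, using the autoregressive cross-entropy objective.
import math
import random

random.seed(0)
D, V, T = 4, 5, 6                                                   # hidden dim, vocab, positions
W = [[random.uniform(-1, 1) for _ in range(V)] for _ in range(D)]   # frozen lm_head weights
H = [[random.uniform(-1, 1) for _ in range(D)] for _ in range(T)]   # frozen hidden states
y = [random.randrange(V) for _ in range(T)]                         # next-token targets

def loss_and_grad(delta):
    """Mean CE over positions, plus its gradient w.r.t. delta only."""
    total, grad = 0.0, [0.0] * D
    for t in range(T):
        h = [H[t][d] + delta[d] for d in range(D)]
        logits = [sum(h[d] * W[d][v] for d in range(D)) for v in range(V)]
        m = max(logits)                                  # stable softmax
        exps = [math.exp(l - m) for l in logits]
        Z = sum(exps)
        p = [e / Z for e in exps]
        total += -math.log(p[y[t]])
        dlog = [p[v] - (1.0 if v == y[t] else 0.0) for v in range(V)]
        for d in range(D):                               # chain rule through the frozen head
            grad[d] += sum(W[d][v] * dlog[v] for v in range(V)) / T
    return total / T, grad

delta = [0.0] * D                                        # zero init, as in the run
losses = []
for _ in range(3):                                       # SLOT_STEPS=3
    loss, g = loss_and_grad(delta)
    losses.append(loss)
    delta = [delta[d] - 0.05 * g[d] for d in range(D)]   # plain GD stands in for AdamW
final_loss, _ = loss_and_grad(delta)
print(final_loss < losses[0])                            # expected: True (convex objective, small step)
```

Because the objective is convex in δ and the step size is small, the loss decreases monotonically; in the real run the same loop operates on [1, 1, 512] and broadcasts over batch and sequence.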
+ +## Base Architecture (PR #549 by @abaybektursun) + +- 11L, 512d, 8H/4KV, LeakyReLU(0.5)² MLP 3× +- Parameter Banking + Parallel Muon (FlashAttention 3) +- BigramHash(1536), XSA4, Partial RoPE(16), LN Scale, VE128 +- EMA(0.997) + Tight SWA(50), GPTQ-lite int6 + LZMA-6 +- Legal Score-First TTT (SGD, lr=0.002, 3 epochs, 32K chunks) + +## Run Commands + +```bash +# Baseline (SLOT disabled — reproduces PR #549) +cd /workspace/parameter-golf && SEED=1337 SLOT_ENABLED=0 CTW_WEIGHT=0 \ +NUM_LAYERS=11 BIGRAM_VOCAB_SIZE=1536 XSA_LAST_N=4 \ +EMA_ENABLED=1 EMA_DECAY=0.997 SWA_ENABLED=1 SWA_EVERY=50 \ +ROPE_DIMS=16 LN_SCALE=1 LATE_QAT=1 LATE_QAT_THRESHOLD=0.15 \ +VE_ENABLED=1 VE_DIM=128 VE_LAYERS=9,10 \ +TTT_ENABLED=1 TTT_LR=0.002 TTT_EPOCHS=3 TTT_CHUNK_TOKENS=32768 \ +TTT_FREEZE_BLOCKS=0 TTT_MOMENTUM=0.9 TTT_BATCH_SEQS=32 TTT_GRAD_CLIP=1.0 \ +MUON_WD=0.04 ADAM_WD=0.04 \ +MATRIX_LR=0.025 SCALAR_LR=0.025 TIED_EMBED_LR=0.035 \ +MUON_MOMENTUM=0.99 MUON_MOMENTUM_WARMUP_START=0.92 \ +MUON_MOMENTUM_WARMUP_STEPS=1500 WARMDOWN_ITERS=3500 \ +ITERATIONS=9000 MAX_WALLCLOCK_SECONDS=600 EVAL_STRIDE=64 \ +torchrun --standalone --nproc_per_node=8 train_gpt.py + +# SLOT enabled (novel contribution) +cd /workspace/parameter-golf && SEED=1337 SLOT_ENABLED=1 SLOT_LR=0.001 SLOT_STEPS=3 CTW_WEIGHT=0 \ +NUM_LAYERS=11 BIGRAM_VOCAB_SIZE=1536 XSA_LAST_N=4 \ +EMA_ENABLED=1 EMA_DECAY=0.997 SWA_ENABLED=1 SWA_EVERY=50 \ +ROPE_DIMS=16 LN_SCALE=1 LATE_QAT=1 LATE_QAT_THRESHOLD=0.15 \ +VE_ENABLED=1 VE_DIM=128 VE_LAYERS=9,10 \ +TTT_ENABLED=1 TTT_LR=0.002 TTT_EPOCHS=3 TTT_CHUNK_TOKENS=32768 \ +TTT_FREEZE_BLOCKS=0 TTT_MOMENTUM=0.9 TTT_BATCH_SEQS=32 TTT_GRAD_CLIP=1.0 \ +MUON_WD=0.04 ADAM_WD=0.04 \ +MATRIX_LR=0.025 SCALAR_LR=0.025 TIED_EMBED_LR=0.035 \ +MUON_MOMENTUM=0.99 MUON_MOMENTUM_WARMUP_START=0.92 \ +MUON_MOMENTUM_WARMUP_STEPS=1500 WARMDOWN_ITERS=3500 \ +ITERATIONS=9000 MAX_WALLCLOCK_SECONDS=600 EVAL_STRIDE=64 \ +torchrun --standalone --nproc_per_node=8 train_gpt.py +``` + +## SLOT Hyperparameters + +| 
Parameter | Value | Env Var | Notes | +|-----------|-------|---------|-------| +| Enabled | true | `SLOT_ENABLED=1` | Set to 0 for baseline | +| Learning rate | 0.001 | `SLOT_LR=0.001` | Matches SLOT paper default for 7B model | +| Optimization steps | 3 | `SLOT_STEPS=3` | Paper default; more steps didn't help in their ablation | +| Optimizer | AdamW | — | weight_decay=1e-8, eps=1e-5 (from paper) | +| Delta shape | [1, 1, 512] | — | Broadcasts across batch and sequence dimensions | +| Delta init | zeros | — | Matches paper: `0.0 * torch.randn(...)` | + +## Credits + +- **SLOT integration and analysis**: Anubhav (@AnubhavBharadwaaj) — this submission +- **SLOT algorithm**: Yang Hu et al. (arXiv:2505.12392v2, Westlake University) +- **CTW negative result analysis**: Anubhav — this submission +- LeakyReLU²: PR #493 by @parinzee, PR #518 by @sofiabod +- Parallel Muon + Parameter Banking: PR #399 by @abaybektursun +- TTT recipe: PR #461 by @Christopher-Lee-McClendon +- Base model: PR #414 by @signalrush diff --git a/records/track_non_record_16mb/2026-03-29_SLOT_on_LeakyReLU_TTT_ParallelMuon/anubhav_slot_seed1337.log b/records/track_non_record_16mb/2026-03-29_SLOT_on_LeakyReLU_TTT_ParallelMuon/anubhav_slot_seed1337.log new file mode 100644 index 0000000000..4d3c0ba5c8 --- /dev/null +++ b/records/track_non_record_16mb/2026-03-29_SLOT_on_LeakyReLU_TTT_ParallelMuon/anubhav_slot_seed1337.log @@ -0,0 +1,275 @@ +W0329 14:18:16.007000 1688 torch/distributed/run.py:803] +W0329 14:18:16.007000 1688 torch/distributed/run.py:803] ***************************************** +W0329 14:18:16.007000 1688 torch/distributed/run.py:803] Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed. 
+W0329 14:18:16.007000 1688 torch/distributed/run.py:803] ***************************************** +logs/9858f668-6927-4a7b-8af1-1e5ccad75aa7.txt +val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path=./data/tokenizers/fineweb_1024_bpe.model +train_loader:dataset:fineweb10B_sp1024 train_shards:80 +val_loader:shards pattern=./data/datasets/fineweb10B_sp1024/fineweb_val_*.bin tokens:62021632 +model_params:26928220 +mtp_num_heads:0 mtp_loss_weight:0.2 mtp_params:0 +XSA:last_4 active_layers:[7, 8, 9, 10] +world_size:8 grad_accum_steps:1 +sdp_backends:cudnn=False flash=True mem_efficient=False math=False +attention_mode:gqa num_heads:8 num_kv_heads:4 +tie_embeddings:True embed_lr:0.035 head_lr:0.0 matrix_lr:0.025 scalar_lr:0.025 +train_batch_tokens:786432 train_seq_len:2048 iterations:9000 warmup_steps:20 max_wallclock_seconds:600.000 +seed:1337 +warmup_step:1/20 +warmup_step:2/20 +warmup_step:3/20 +warmup_step:4/20 +warmup_step:5/20 +warmup_step:6/20 +warmup_step:7/20 +warmup_step:8/20 +warmup_step:9/20 +warmup_step:10/20 +warmup_step:11/20 +warmup_step:12/20 +warmup_step:13/20 +warmup_step:14/20 +warmup_step:15/20 +warmup_step:16/20 +warmup_step:17/20 +warmup_step:18/20 +warmup_step:19/20 +warmup_step:20/20 +step:0/9000 val_loss:6.9304 val_bpb:4.1046 train_time:0ms step_avg:0.01ms +step:1/9000 train_loss:6.9322 train_time:126ms step_avg:125.90ms +step:2/9000 train_loss:8.6544 train_time:153ms step_avg:76.36ms +step:3/9000 train_loss:7.6926 train_time:233ms step_avg:77.55ms +step:4/9000 train_loss:7.2519 train_time:314ms step_avg:78.40ms +step:5/9000 train_loss:7.1705 train_time:396ms step_avg:79.12ms +step:6/9000 train_loss:7.1160 train_time:476ms step_avg:79.38ms +step:7/9000 train_loss:7.0266 train_time:558ms step_avg:79.69ms +step:8/9000 train_loss:6.9593 train_time:638ms step_avg:79.80ms +step:9/9000 train_loss:6.5741 train_time:718ms step_avg:79.83ms +step:10/9000 train_loss:6.2002 train_time:799ms step_avg:79.86ms +step:500/9000 train_loss:2.3969 
train_time:41559ms step_avg:83.12ms +step:1000/9000 train_loss:2.2655 train_time:83374ms step_avg:83.37ms +step:1500/9000 train_loss:2.2113 train_time:125215ms step_avg:83.48ms +step:2000/9000 train_loss:2.0546 train_time:167115ms step_avg:83.56ms +step:2500/9000 train_loss:2.1590 train_time:209039ms step_avg:83.62ms +step:3000/9000 train_loss:2.1498 train_time:250967ms step_avg:83.66ms +step:3500/9000 train_loss:2.1701 train_time:292901ms step_avg:83.69ms +step:4000/9000 train_loss:1.9673 train_time:334827ms step_avg:83.71ms +step:4000/9000 val_loss:2.0580 val_bpb:1.2188 train_time:334885ms step_avg:83.72ms +step:4500/9000 train_loss:2.1167 train_time:376774ms step_avg:83.73ms +step:5000/9000 train_loss:2.0973 train_time:418706ms step_avg:83.74ms +step:5500/9000 train_loss:2.0133 train_time:460623ms step_avg:83.75ms +step:6000/9000 train_loss:1.9357 train_time:502544ms step_avg:83.76ms +swa:start step:6500 +step:6500/9000 train_loss:2.0802 train_time:544452ms step_avg:83.76ms +late_qat:enabled step:6635 scale:0.1499 +step:7000/9000 train_loss:1.7877 train_time:586945ms step_avg:83.85ms +step:7154/9000 val_loss:1.9221 val_bpb:1.1384 train_time:600125ms step_avg:83.89ms +stopping_early: wallclock_cap train_time:600125ms step:7154/9000 +peak memory allocated: 21481 MiB reserved: 22030 MiB +ema:applying EMA weights +DIAGNOSTIC post_ema val_loss:1.9204 val_bpb:1.1373 eval_time:2006ms +Serialized model: 106027446 bytes +Code size: 107448 bytes +Serialized model int6+lzma: 15765784 bytes +Total submission size int6+lzma: 15873232 bytes +final_int6_roundtrip val_loss:1.9348 val_bpb:1.1459 eval_time:15944ms +final_int6_roundtrip_exact val_loss:1.93481952 val_bpb:1.14590910 +final_int6_sliding_window val_loss:1.8951 val_bpb:1.1224 stride:64 eval_time:89854ms +final_int6_sliding_window_exact val_loss:1.89507335 val_bpb:1.12237215 +final_int8_zlib_roundtrip_exact val_loss:1.89507335 val_bpb:1.12237215 +ttt_sliding:start chunks=1893 chunk_tokens=32768 total_windows=969088 
stride=64 ttt_lr=0.002 ttt_epochs=3 freeze_blocks=0 slot_lr=0.001 slot_steps=3 +ttt_sliding:params unfrozen=26928220 frozen=0 + ttt_chunk [1/1893] bpb=1.161977 time=0.4s + ttt_chunk [11/1893] bpb=1.145797 time=2.5s + ttt_chunk [21/1893] bpb=1.131499 time=4.5s + ttt_chunk [31/1893] bpb=1.129426 time=6.5s + ttt_chunk [41/1893] bpb=1.115800 time=8.6s + ttt_chunk [51/1893] bpb=1.110176 time=10.6s + ttt_chunk [61/1893] bpb=1.116838 time=12.7s + ttt_chunk [71/1893] bpb=1.115371 time=14.7s + ttt_chunk [81/1893] bpb=1.114329 time=16.8s + ttt_chunk [91/1893] bpb=1.115325 time=18.8s + ttt_chunk [101/1893] bpb=1.118955 time=20.9s + ttt_chunk [111/1893] bpb=1.121406 time=22.9s + ttt_chunk [121/1893] bpb=1.114962 time=24.9s + ttt_chunk [131/1893] bpb=1.115296 time=27.0s + ttt_chunk [141/1893] bpb=1.120936 time=29.1s + ttt_chunk [151/1893] bpb=1.122706 time=31.1s + ttt_chunk [161/1893] bpb=1.122306 time=33.1s + ttt_chunk [171/1893] bpb=1.126736 time=35.1s + ttt_chunk [181/1893] bpb=1.128919 time=37.2s + ttt_chunk [191/1893] bpb=1.136162 time=39.2s + ttt_chunk [201/1893] bpb=1.134922 time=41.2s + ttt_chunk [211/1893] bpb=1.132759 time=43.2s + ttt_chunk [221/1893] bpb=1.134285 time=45.3s + ttt_chunk [231/1893] bpb=1.132939 time=47.4s + ttt_chunk [241/1893] bpb=1.133184 time=49.4s + ttt_chunk [251/1893] bpb=1.132645 time=51.4s + ttt_chunk [261/1893] bpb=1.129811 time=53.5s + ttt_chunk [271/1893] bpb=1.128778 time=55.5s + ttt_chunk [281/1893] bpb=1.130118 time=57.5s + ttt_chunk [291/1893] bpb=1.131805 time=59.6s + ttt_chunk [301/1893] bpb=1.132502 time=61.6s + ttt_chunk [311/1893] bpb=1.134554 time=63.7s + ttt_chunk [321/1893] bpb=1.136463 time=65.7s + ttt_chunk [331/1893] bpb=1.136277 time=67.7s + ttt_chunk [341/1893] bpb=1.135255 time=69.8s + ttt_chunk [351/1893] bpb=1.137493 time=71.8s + ttt_chunk [361/1893] bpb=1.137682 time=73.8s + ttt_chunk [371/1893] bpb=1.136975 time=75.8s + ttt_chunk [381/1893] bpb=1.137173 time=77.9s + ttt_chunk [391/1893] bpb=1.136993 time=79.9s + 
ttt_chunk [401/1893] bpb=1.134916 time=82.0s + ttt_chunk [411/1893] bpb=1.133733 time=84.0s + ttt_chunk [421/1893] bpb=1.132829 time=86.0s + ttt_chunk [431/1893] bpb=1.132639 time=88.1s + ttt_chunk [441/1893] bpb=1.133032 time=90.1s + ttt_chunk [451/1893] bpb=1.133335 time=92.2s + ttt_chunk [461/1893] bpb=1.132250 time=94.2s + ttt_chunk [471/1893] bpb=1.132866 time=96.3s + ttt_chunk [481/1893] bpb=1.132547 time=98.3s + ttt_chunk [491/1893] bpb=1.131467 time=100.3s + ttt_chunk [501/1893] bpb=1.130971 time=102.4s + ttt_chunk [511/1893] bpb=1.130271 time=104.4s + ttt_chunk [521/1893] bpb=1.127866 time=106.4s + ttt_chunk [531/1893] bpb=1.129048 time=108.4s + ttt_chunk [541/1893] bpb=1.129404 time=110.5s + ttt_chunk [551/1893] bpb=1.128370 time=112.5s + ttt_chunk [561/1893] bpb=1.128908 time=114.6s + ttt_chunk [571/1893] bpb=1.127838 time=116.6s + ttt_chunk [581/1893] bpb=1.127029 time=118.7s + ttt_chunk [591/1893] bpb=1.126394 time=120.7s + ttt_chunk [601/1893] bpb=1.126885 time=122.8s + ttt_chunk [611/1893] bpb=1.126811 time=124.8s + ttt_chunk [621/1893] bpb=1.126680 time=126.8s + ttt_chunk [631/1893] bpb=1.127378 time=128.9s + ttt_chunk [641/1893] bpb=1.127120 time=130.9s + ttt_chunk [651/1893] bpb=1.127237 time=133.0s + ttt_chunk [661/1893] bpb=1.126682 time=135.0s + ttt_chunk [671/1893] bpb=1.126977 time=137.0s + ttt_chunk [681/1893] bpb=1.127695 time=139.0s + ttt_chunk [691/1893] bpb=1.128704 time=141.1s + ttt_chunk [701/1893] bpb=1.128151 time=143.1s + ttt_chunk [711/1893] bpb=1.128147 time=145.1s + ttt_chunk [721/1893] bpb=1.127820 time=147.2s + ttt_chunk [731/1893] bpb=1.127855 time=149.3s + ttt_chunk [741/1893] bpb=1.127951 time=151.3s + ttt_chunk [751/1893] bpb=1.127802 time=153.3s + ttt_chunk [761/1893] bpb=1.127769 time=155.4s + ttt_chunk [771/1893] bpb=1.127467 time=157.4s + ttt_chunk [781/1893] bpb=1.128215 time=159.5s + ttt_chunk [791/1893] bpb=1.127779 time=161.5s + ttt_chunk [801/1893] bpb=1.128056 time=163.6s + ttt_chunk [811/1893] bpb=1.127814 
time=165.6s + ttt_chunk [821/1893] bpb=1.127567 time=167.6s + ttt_chunk [831/1893] bpb=1.127396 time=169.7s + ttt_chunk [841/1893] bpb=1.126753 time=171.7s + ttt_chunk [851/1893] bpb=1.126480 time=173.7s + ttt_chunk [861/1893] bpb=1.126244 time=175.7s + ttt_chunk [871/1893] bpb=1.126532 time=177.8s + ttt_chunk [881/1893] bpb=1.126719 time=179.8s + ttt_chunk [891/1893] bpb=1.126284 time=181.9s + ttt_chunk [901/1893] bpb=1.126028 time=183.9s + ttt_chunk [911/1893] bpb=1.126148 time=186.0s + ttt_chunk [921/1893] bpb=1.126637 time=188.0s + ttt_chunk [931/1893] bpb=1.126619 time=190.1s + ttt_chunk [941/1893] bpb=1.126304 time=192.1s + ttt_chunk [951/1893] bpb=1.126684 time=194.1s + ttt_chunk [961/1893] bpb=1.126762 time=196.2s + ttt_chunk [971/1893] bpb=1.127621 time=198.2s + ttt_chunk [981/1893] bpb=1.127695 time=200.3s + ttt_chunk [991/1893] bpb=1.127704 time=202.3s + ttt_chunk [1001/1893] bpb=1.127670 time=204.3s + ttt_chunk [1011/1893] bpb=1.127456 time=206.3s + ttt_chunk [1021/1893] bpb=1.127792 time=208.3s + ttt_chunk [1031/1893] bpb=1.128238 time=210.4s + ttt_chunk [1041/1893] bpb=1.127917 time=212.4s + ttt_chunk [1051/1893] bpb=1.127682 time=214.4s + ttt_chunk [1061/1893] bpb=1.127693 time=216.5s + ttt_chunk [1071/1893] bpb=1.128308 time=218.6s + ttt_chunk [1081/1893] bpb=1.128586 time=220.6s + ttt_chunk [1091/1893] bpb=1.129315 time=222.6s + ttt_chunk [1101/1893] bpb=1.129352 time=224.7s + ttt_chunk [1111/1893] bpb=1.129197 time=226.7s + ttt_chunk [1121/1893] bpb=1.128978 time=228.8s + ttt_chunk [1131/1893] bpb=1.128858 time=230.8s + ttt_chunk [1141/1893] bpb=1.128549 time=232.8s + ttt_chunk [1151/1893] bpb=1.128545 time=234.9s + ttt_chunk [1161/1893] bpb=1.128141 time=236.9s + ttt_chunk [1171/1893] bpb=1.128476 time=238.9s + ttt_chunk [1181/1893] bpb=1.127724 time=240.9s + ttt_chunk [1191/1893] bpb=1.127606 time=243.0s + ttt_chunk [1201/1893] bpb=1.128008 time=245.0s + ttt_chunk [1211/1893] bpb=1.127546 time=247.0s + ttt_chunk [1221/1893] bpb=1.127243 
time=249.1s + ttt_chunk [1231/1893] bpb=1.126968 time=251.1s + ttt_chunk [1241/1893] bpb=1.126624 time=253.2s + ttt_chunk [1251/1893] bpb=1.126037 time=255.2s + ttt_chunk [1261/1893] bpb=1.126008 time=257.3s + ttt_chunk [1271/1893] bpb=1.125636 time=259.3s + ttt_chunk [1281/1893] bpb=1.125435 time=261.3s + ttt_chunk [1291/1893] bpb=1.125207 time=263.4s + ttt_chunk [1301/1893] bpb=1.124630 time=265.4s + ttt_chunk [1311/1893] bpb=1.124253 time=267.5s + ttt_chunk [1321/1893] bpb=1.123918 time=269.5s + ttt_chunk [1331/1893] bpb=1.123850 time=271.5s + ttt_chunk [1341/1893] bpb=1.123702 time=273.5s + ttt_chunk [1351/1893] bpb=1.123641 time=275.5s + ttt_chunk [1361/1893] bpb=1.123694 time=277.6s + ttt_chunk [1371/1893] bpb=1.123538 time=279.6s + ttt_chunk [1381/1893] bpb=1.123520 time=281.6s + ttt_chunk [1391/1893] bpb=1.123111 time=283.7s + ttt_chunk [1401/1893] bpb=1.123081 time=285.8s + ttt_chunk [1411/1893] bpb=1.123197 time=287.8s + ttt_chunk [1421/1893] bpb=1.123439 time=289.8s + ttt_chunk [1431/1893] bpb=1.123148 time=291.9s + ttt_chunk [1441/1893] bpb=1.123658 time=293.9s + ttt_chunk [1451/1893] bpb=1.124008 time=296.0s + ttt_chunk [1461/1893] bpb=1.123547 time=298.0s + ttt_chunk [1471/1893] bpb=1.124574 time=300.0s + ttt_chunk [1481/1893] bpb=1.124125 time=302.1s + ttt_chunk [1491/1893] bpb=1.123946 time=304.1s + ttt_chunk [1501/1893] bpb=1.123850 time=306.1s + ttt_chunk [1511/1893] bpb=1.123864 time=308.1s + ttt_chunk [1521/1893] bpb=1.123898 time=310.1s + ttt_chunk [1531/1893] bpb=1.123398 time=312.2s + ttt_chunk [1541/1893] bpb=1.123253 time=314.2s + ttt_chunk [1551/1893] bpb=1.123572 time=316.3s + ttt_chunk [1561/1893] bpb=1.123556 time=318.3s + ttt_chunk [1571/1893] bpb=1.123387 time=320.4s + ttt_chunk [1581/1893] bpb=1.123505 time=322.4s + ttt_chunk [1591/1893] bpb=1.123354 time=324.4s + ttt_chunk [1601/1893] bpb=1.123520 time=326.5s + ttt_chunk [1611/1893] bpb=1.123468 time=328.6s + ttt_chunk [1621/1893] bpb=1.123051 time=330.6s + ttt_chunk [1631/1893] 
bpb=1.123343 time=332.6s + ttt_chunk [1641/1893] bpb=1.123355 time=334.6s + ttt_chunk [1651/1893] bpb=1.123309 time=336.7s + ttt_chunk [1661/1893] bpb=1.123185 time=338.7s + ttt_chunk [1671/1893] bpb=1.123666 time=340.7s + ttt_chunk [1681/1893] bpb=1.123804 time=342.7s + ttt_chunk [1691/1893] bpb=1.123635 time=344.8s + ttt_chunk [1701/1893] bpb=1.123798 time=346.8s + ttt_chunk [1711/1893] bpb=1.123798 time=348.8s + ttt_chunk [1721/1893] bpb=1.123813 time=350.9s + ttt_chunk [1731/1893] bpb=1.123703 time=352.9s + ttt_chunk [1741/1893] bpb=1.123496 time=355.0s + ttt_chunk [1751/1893] bpb=1.123312 time=357.0s + ttt_chunk [1761/1893] bpb=1.123453 time=359.1s + ttt_chunk [1771/1893] bpb=1.123358 time=361.1s + ttt_chunk [1781/1893] bpb=1.123389 time=363.2s + ttt_chunk [1791/1893] bpb=1.122975 time=365.2s + ttt_chunk [1801/1893] bpb=1.122858 time=367.2s + ttt_chunk [1811/1893] bpb=1.122749 time=369.3s + ttt_chunk [1821/1893] bpb=1.122812 time=371.3s + ttt_chunk [1831/1893] bpb=1.122221 time=373.3s + ttt_chunk [1841/1893] bpb=1.122170 time=375.3s + ttt_chunk [1851/1893] bpb=1.121956 time=377.3s + ttt_chunk [1861/1893] bpb=1.121587 time=379.4s + ttt_chunk [1871/1893] bpb=1.121577 time=381.4s + ttt_chunk [1881/1893] bpb=1.121122 time=383.4s + ttt_chunk [1891/1893] bpb=1.120886 time=385.5s + ttt_chunk [1893/1893] bpb=1.120926 time=385.8s +ttt_sliding:done val_loss=1.889032 val_bpb=1.118794 elapsed=385.8s +legal_ttt val_loss:1.8890 val_bpb:1.1188 eval_time:386301ms +legal_ttt_exact val_loss:1.88903226 val_bpb:1.11879427 diff --git a/records/track_non_record_16mb/2026-03-29_SLOT_on_LeakyReLU_TTT_ParallelMuon/anubhav_slot_seed2025.log b/records/track_non_record_16mb/2026-03-29_SLOT_on_LeakyReLU_TTT_ParallelMuon/anubhav_slot_seed2025.log new file mode 100644 index 0000000000..2e99fc614c --- /dev/null +++ b/records/track_non_record_16mb/2026-03-29_SLOT_on_LeakyReLU_TTT_ParallelMuon/anubhav_slot_seed2025.log @@ -0,0 +1,275 @@ +W0329 15:37:06.102000 65982 
torch/distributed/run.py:803] +W0329 15:37:06.102000 65982 torch/distributed/run.py:803] ***************************************** +W0329 15:37:06.102000 65982 torch/distributed/run.py:803] Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed. +W0329 15:37:06.102000 65982 torch/distributed/run.py:803] ***************************************** +logs/b091a36b-6392-412f-8c6f-341e5693fafc.txt +val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path=./data/tokenizers/fineweb_1024_bpe.model +train_loader:dataset:fineweb10B_sp1024 train_shards:80 +val_loader:shards pattern=./data/datasets/fineweb10B_sp1024/fineweb_val_*.bin tokens:62021632 +model_params:26928220 +mtp_num_heads:0 mtp_loss_weight:0.2 mtp_params:0 +XSA:last_4 active_layers:[7, 8, 9, 10] +world_size:8 grad_accum_steps:1 +sdp_backends:cudnn=False flash=True mem_efficient=False math=False +attention_mode:gqa num_heads:8 num_kv_heads:4 +tie_embeddings:True embed_lr:0.035 head_lr:0.0 matrix_lr:0.025 scalar_lr:0.025 +train_batch_tokens:786432 train_seq_len:2048 iterations:9000 warmup_steps:20 max_wallclock_seconds:600.000 +seed:2025 +warmup_step:1/20 +warmup_step:2/20 +warmup_step:3/20 +warmup_step:4/20 +warmup_step:5/20 +warmup_step:6/20 +warmup_step:7/20 +warmup_step:8/20 +warmup_step:9/20 +warmup_step:10/20 +warmup_step:11/20 +warmup_step:12/20 +warmup_step:13/20 +warmup_step:14/20 +warmup_step:15/20 +warmup_step:16/20 +warmup_step:17/20 +warmup_step:18/20 +warmup_step:19/20 +warmup_step:20/20 +step:0/9000 val_loss:6.9302 val_bpb:4.1045 train_time:0ms step_avg:0.01ms +step:1/9000 train_loss:6.9311 train_time:123ms step_avg:122.64ms +step:2/9000 train_loss:8.6818 train_time:150ms step_avg:74.84ms +step:3/9000 train_loss:7.7058 train_time:230ms step_avg:76.56ms +step:4/9000 train_loss:7.2716 train_time:311ms step_avg:77.79ms +step:5/9000 train_loss:7.1780 
train_time:392ms step_avg:78.48ms +step:6/9000 train_loss:7.0952 train_time:473ms step_avg:78.89ms +step:7/9000 train_loss:7.0229 train_time:554ms step_avg:79.20ms +step:8/9000 train_loss:6.9411 train_time:636ms step_avg:79.52ms +step:9/9000 train_loss:6.6071 train_time:717ms step_avg:79.68ms +step:10/9000 train_loss:6.2031 train_time:798ms step_avg:79.76ms +step:500/9000 train_loss:2.3979 train_time:41512ms step_avg:83.02ms +step:1000/9000 train_loss:2.2673 train_time:83317ms step_avg:83.32ms +step:1500/9000 train_loss:2.2100 train_time:125161ms step_avg:83.44ms +step:2000/9000 train_loss:2.0560 train_time:167142ms step_avg:83.57ms +step:2500/9000 train_loss:2.1554 train_time:209077ms step_avg:83.63ms +step:3000/9000 train_loss:2.1531 train_time:251005ms step_avg:83.67ms +step:3500/9000 train_loss:2.1685 train_time:292941ms step_avg:83.70ms +step:4000/9000 train_loss:1.9676 train_time:334885ms step_avg:83.72ms +step:4000/9000 val_loss:2.0570 val_bpb:1.2183 train_time:334942ms step_avg:83.74ms +step:4500/9000 train_loss:2.1192 train_time:376819ms step_avg:83.74ms +step:5000/9000 train_loss:2.0962 train_time:418760ms step_avg:83.75ms +step:5500/9000 train_loss:2.0166 train_time:460680ms step_avg:83.76ms +step:6000/9000 train_loss:1.9371 train_time:502613ms step_avg:83.77ms +swa:start step:6500 +step:6500/9000 train_loss:2.0793 train_time:544540ms step_avg:83.78ms +late_qat:enabled step:6634 scale:0.1498 +step:7000/9000 train_loss:1.7873 train_time:587108ms step_avg:83.87ms +step:7152/9000 val_loss:1.9209 val_bpb:1.1377 train_time:600128ms step_avg:83.91ms +stopping_early: wallclock_cap train_time:600128ms step:7152/9000 +peak memory allocated: 21471 MiB reserved: 22002 MiB +ema:applying EMA weights +DIAGNOSTIC post_ema val_loss:1.9192 val_bpb:1.1367 eval_time:2004ms +Serialized model: 106027446 bytes +Code size: 107448 bytes +Serialized model int6+lzma: 15887472 bytes +Total submission size int6+lzma: 15994920 bytes +final_int6_roundtrip val_loss:1.9334 
val_bpb:1.1450 eval_time:5558ms +final_int6_roundtrip_exact val_loss:1.93335133 val_bpb:1.14503956 +final_int6_sliding_window val_loss:1.8938 val_bpb:1.1216 stride:64 eval_time:73715ms +final_int6_sliding_window_exact val_loss:1.89378189 val_bpb:1.12160727 +final_int8_zlib_roundtrip_exact val_loss:1.89378189 val_bpb:1.12160727 +ttt_sliding:start chunks=1893 chunk_tokens=32768 total_windows=969088 stride=64 ttt_lr=0.002 ttt_epochs=3 freeze_blocks=0 slot_lr=0.001 slot_steps=3 +ttt_sliding:params unfrozen=26928220 frozen=0 + ttt_chunk [1/1893] bpb=1.158602 time=0.4s + ttt_chunk [11/1893] bpb=1.146663 time=2.4s + ttt_chunk [21/1893] bpb=1.131803 time=4.5s + ttt_chunk [31/1893] bpb=1.130162 time=6.5s + ttt_chunk [41/1893] bpb=1.116261 time=8.6s + ttt_chunk [51/1893] bpb=1.110222 time=10.6s + ttt_chunk [61/1893] bpb=1.117171 time=12.7s + ttt_chunk [71/1893] bpb=1.115582 time=14.7s + ttt_chunk [81/1893] bpb=1.114522 time=16.8s + ttt_chunk [91/1893] bpb=1.115273 time=18.8s + ttt_chunk [101/1893] bpb=1.118964 time=20.8s + ttt_chunk [111/1893] bpb=1.121363 time=22.9s + ttt_chunk [121/1893] bpb=1.114757 time=24.9s + ttt_chunk [131/1893] bpb=1.114939 time=27.0s + ttt_chunk [141/1893] bpb=1.120719 time=29.1s + ttt_chunk [151/1893] bpb=1.122376 time=31.1s + ttt_chunk [161/1893] bpb=1.121869 time=33.1s + ttt_chunk [171/1893] bpb=1.126338 time=35.2s + ttt_chunk [181/1893] bpb=1.128624 time=37.2s + ttt_chunk [191/1893] bpb=1.135971 time=39.2s + ttt_chunk [201/1893] bpb=1.134775 time=41.2s + ttt_chunk [211/1893] bpb=1.132542 time=43.3s + ttt_chunk [221/1893] bpb=1.134113 time=45.3s + ttt_chunk [231/1893] bpb=1.132794 time=47.3s + ttt_chunk [241/1893] bpb=1.133098 time=49.4s + ttt_chunk [251/1893] bpb=1.132591 time=51.4s + ttt_chunk [261/1893] bpb=1.129583 time=53.4s + ttt_chunk [271/1893] bpb=1.128425 time=55.4s + ttt_chunk [281/1893] bpb=1.129804 time=57.5s + ttt_chunk [291/1893] bpb=1.131610 time=59.5s + ttt_chunk [301/1893] bpb=1.132314 time=61.5s + ttt_chunk [311/1893] 
bpb=1.134374 time=63.6s + ttt_chunk [321/1893] bpb=1.136239 time=65.6s + ttt_chunk [331/1893] bpb=1.136040 time=67.6s + ttt_chunk [341/1893] bpb=1.135099 time=69.7s + ttt_chunk [351/1893] bpb=1.137359 time=71.7s + ttt_chunk [361/1893] bpb=1.137532 time=73.7s + ttt_chunk [371/1893] bpb=1.136812 time=75.8s + ttt_chunk [381/1893] bpb=1.136979 time=77.8s + ttt_chunk [391/1893] bpb=1.136733 time=79.8s + ttt_chunk [401/1893] bpb=1.134662 time=81.9s + ttt_chunk [411/1893] bpb=1.133502 time=83.9s + ttt_chunk [421/1893] bpb=1.132542 time=85.9s + ttt_chunk [431/1893] bpb=1.132427 time=87.9s + ttt_chunk [441/1893] bpb=1.132754 time=90.0s + ttt_chunk [451/1893] bpb=1.133030 time=92.0s + ttt_chunk [461/1893] bpb=1.131931 time=94.1s + ttt_chunk [471/1893] bpb=1.132525 time=96.1s + ttt_chunk [481/1893] bpb=1.132164 time=98.1s + ttt_chunk [491/1893] bpb=1.131067 time=100.2s + ttt_chunk [501/1893] bpb=1.130579 time=102.2s + ttt_chunk [511/1893] bpb=1.129902 time=104.2s + ttt_chunk [521/1893] bpb=1.127482 time=106.3s + ttt_chunk [531/1893] bpb=1.128666 time=108.3s + ttt_chunk [541/1893] bpb=1.128989 time=110.3s + ttt_chunk [551/1893] bpb=1.127960 time=112.4s + ttt_chunk [561/1893] bpb=1.128487 time=114.4s + ttt_chunk [571/1893] bpb=1.127462 time=116.4s + ttt_chunk [581/1893] bpb=1.126654 time=118.4s + ttt_chunk [591/1893] bpb=1.126015 time=120.5s + ttt_chunk [601/1893] bpb=1.126483 time=122.5s + ttt_chunk [611/1893] bpb=1.126390 time=124.5s + ttt_chunk [621/1893] bpb=1.126242 time=126.6s + ttt_chunk [631/1893] bpb=1.126938 time=128.6s + ttt_chunk [641/1893] bpb=1.126679 time=130.6s + ttt_chunk [651/1893] bpb=1.126799 time=132.6s + ttt_chunk [661/1893] bpb=1.126266 time=134.7s + ttt_chunk [671/1893] bpb=1.126641 time=136.7s + ttt_chunk [681/1893] bpb=1.127313 time=138.7s + ttt_chunk [691/1893] bpb=1.128301 time=140.8s + ttt_chunk [701/1893] bpb=1.127738 time=142.8s + ttt_chunk [711/1893] bpb=1.127731 time=144.8s + ttt_chunk [721/1893] bpb=1.127429 time=146.9s + ttt_chunk [731/1893] 
bpb=1.127448 time=148.9s + ttt_chunk [741/1893] bpb=1.127534 time=150.9s + ttt_chunk [751/1893] bpb=1.127371 time=152.9s + ttt_chunk [761/1893] bpb=1.127309 time=155.0s + ttt_chunk [771/1893] bpb=1.127038 time=157.0s + ttt_chunk [781/1893] bpb=1.127768 time=159.0s + ttt_chunk [791/1893] bpb=1.127357 time=161.1s + ttt_chunk [801/1893] bpb=1.127651 time=163.1s + ttt_chunk [811/1893] bpb=1.127392 time=165.1s + ttt_chunk [821/1893] bpb=1.127159 time=167.2s + ttt_chunk [831/1893] bpb=1.127006 time=169.2s + ttt_chunk [841/1893] bpb=1.126353 time=171.2s + ttt_chunk [851/1893] bpb=1.126086 time=173.2s + ttt_chunk [861/1893] bpb=1.125833 time=175.3s + ttt_chunk [871/1893] bpb=1.126128 time=177.3s + ttt_chunk [881/1893] bpb=1.126311 time=179.3s + ttt_chunk [891/1893] bpb=1.125880 time=181.4s + ttt_chunk [901/1893] bpb=1.125620 time=183.4s + ttt_chunk [911/1893] bpb=1.125751 time=185.4s + ttt_chunk [921/1893] bpb=1.126249 time=187.4s + ttt_chunk [931/1893] bpb=1.126221 time=189.5s + ttt_chunk [941/1893] bpb=1.125899 time=191.5s + ttt_chunk [951/1893] bpb=1.126300 time=193.5s + ttt_chunk [961/1893] bpb=1.126368 time=195.6s + ttt_chunk [971/1893] bpb=1.127175 time=197.6s + ttt_chunk [981/1893] bpb=1.127250 time=199.6s + ttt_chunk [991/1893] bpb=1.127239 time=201.7s + ttt_chunk [1001/1893] bpb=1.127198 time=203.7s + ttt_chunk [1011/1893] bpb=1.126997 time=205.7s + ttt_chunk [1021/1893] bpb=1.127345 time=207.8s + ttt_chunk [1031/1893] bpb=1.127801 time=209.8s + ttt_chunk [1041/1893] bpb=1.127458 time=211.8s + ttt_chunk [1051/1893] bpb=1.127208 time=213.8s + ttt_chunk [1061/1893] bpb=1.127238 time=215.9s + ttt_chunk [1071/1893] bpb=1.127857 time=217.9s + ttt_chunk [1081/1893] bpb=1.128126 time=219.9s + ttt_chunk [1091/1893] bpb=1.128856 time=221.9s + ttt_chunk [1101/1893] bpb=1.128867 time=224.0s + ttt_chunk [1111/1893] bpb=1.128709 time=226.0s + ttt_chunk [1121/1893] bpb=1.128475 time=228.0s + ttt_chunk [1131/1893] bpb=1.128344 time=230.1s + ttt_chunk [1141/1893] bpb=1.128033 
time=232.1s + ttt_chunk [1151/1893] bpb=1.128039 time=234.1s + ttt_chunk [1161/1893] bpb=1.127644 time=236.2s + ttt_chunk [1171/1893] bpb=1.127967 time=238.2s + ttt_chunk [1181/1893] bpb=1.127188 time=240.2s + ttt_chunk [1191/1893] bpb=1.127053 time=242.2s + ttt_chunk [1201/1893] bpb=1.127467 time=244.3s + ttt_chunk [1211/1893] bpb=1.126990 time=246.3s + ttt_chunk [1221/1893] bpb=1.126700 time=248.3s + ttt_chunk [1231/1893] bpb=1.126422 time=250.4s + ttt_chunk [1241/1893] bpb=1.126060 time=252.4s + ttt_chunk [1251/1893] bpb=1.125474 time=254.4s + ttt_chunk [1261/1893] bpb=1.125442 time=256.4s + ttt_chunk [1271/1893] bpb=1.125078 time=258.5s + ttt_chunk [1281/1893] bpb=1.124883 time=260.5s + ttt_chunk [1291/1893] bpb=1.124642 time=262.5s + ttt_chunk [1301/1893] bpb=1.124061 time=264.6s + ttt_chunk [1311/1893] bpb=1.123677 time=266.6s + ttt_chunk [1321/1893] bpb=1.123359 time=268.6s + ttt_chunk [1331/1893] bpb=1.123282 time=270.7s + ttt_chunk [1341/1893] bpb=1.123170 time=272.7s + ttt_chunk [1351/1893] bpb=1.123109 time=274.7s + ttt_chunk [1361/1893] bpb=1.123165 time=276.7s + ttt_chunk [1371/1893] bpb=1.123032 time=278.8s + ttt_chunk [1381/1893] bpb=1.123016 time=280.8s + ttt_chunk [1391/1893] bpb=1.122617 time=282.8s + ttt_chunk [1401/1893] bpb=1.122598 time=284.9s + ttt_chunk [1411/1893] bpb=1.122718 time=286.9s + ttt_chunk [1421/1893] bpb=1.122974 time=288.9s + ttt_chunk [1431/1893] bpb=1.122671 time=291.0s + ttt_chunk [1441/1893] bpb=1.123185 time=293.0s + ttt_chunk [1451/1893] bpb=1.123531 time=295.0s + ttt_chunk [1461/1893] bpb=1.123070 time=297.0s + ttt_chunk [1471/1893] bpb=1.124130 time=299.1s + ttt_chunk [1481/1893] bpb=1.123658 time=301.1s + ttt_chunk [1491/1893] bpb=1.123484 time=303.1s + ttt_chunk [1501/1893] bpb=1.123390 time=305.2s + ttt_chunk [1511/1893] bpb=1.123415 time=307.2s + ttt_chunk [1521/1893] bpb=1.123441 time=309.2s + ttt_chunk [1531/1893] bpb=1.122932 time=311.3s + ttt_chunk [1541/1893] bpb=1.122790 time=313.3s + ttt_chunk [1551/1893] 
bpb=1.123104 time=315.3s + ttt_chunk [1561/1893] bpb=1.123111 time=317.3s + ttt_chunk [1571/1893] bpb=1.122949 time=319.4s + ttt_chunk [1581/1893] bpb=1.123063 time=321.4s + ttt_chunk [1591/1893] bpb=1.122918 time=323.5s + ttt_chunk [1601/1893] bpb=1.123090 time=325.5s + ttt_chunk [1611/1893] bpb=1.123034 time=327.5s + ttt_chunk [1621/1893] bpb=1.122638 time=329.6s + ttt_chunk [1631/1893] bpb=1.122930 time=331.6s + ttt_chunk [1641/1893] bpb=1.122940 time=333.6s + ttt_chunk [1651/1893] bpb=1.122891 time=335.7s + ttt_chunk [1661/1893] bpb=1.122768 time=337.7s + ttt_chunk [1671/1893] bpb=1.123245 time=339.7s + ttt_chunk [1681/1893] bpb=1.123400 time=341.8s + ttt_chunk [1691/1893] bpb=1.123225 time=343.8s + ttt_chunk [1701/1893] bpb=1.123382 time=345.8s + ttt_chunk [1711/1893] bpb=1.123382 time=347.9s + ttt_chunk [1721/1893] bpb=1.123371 time=349.9s + ttt_chunk [1731/1893] bpb=1.123251 time=351.9s + ttt_chunk [1741/1893] bpb=1.123061 time=354.0s + ttt_chunk [1751/1893] bpb=1.122891 time=356.0s + ttt_chunk [1761/1893] bpb=1.123035 time=358.0s + ttt_chunk [1771/1893] bpb=1.122934 time=360.1s + ttt_chunk [1781/1893] bpb=1.122967 time=362.1s + ttt_chunk [1791/1893] bpb=1.122569 time=364.1s + ttt_chunk [1801/1893] bpb=1.122465 time=366.2s + ttt_chunk [1811/1893] bpb=1.122370 time=368.2s + ttt_chunk [1821/1893] bpb=1.122431 time=370.2s + ttt_chunk [1831/1893] bpb=1.121824 time=372.3s + ttt_chunk [1841/1893] bpb=1.121756 time=374.3s + ttt_chunk [1851/1893] bpb=1.121540 time=376.3s + ttt_chunk [1861/1893] bpb=1.121175 time=378.4s + ttt_chunk [1871/1893] bpb=1.121159 time=380.4s + ttt_chunk [1881/1893] bpb=1.120711 time=382.4s + ttt_chunk [1891/1893] bpb=1.120482 time=384.4s + ttt_chunk [1893/1893] bpb=1.120524 time=384.7s +ttt_sliding:done val_loss=1.888199 val_bpb=1.118301 elapsed=384.8s +legal_ttt val_loss:1.8882 val_bpb:1.1183 eval_time:385197ms +legal_ttt_exact val_loss:1.88819913 val_bpb:1.11830084 diff --git 
a/records/track_non_record_16mb/2026-03-29_SLOT_on_LeakyReLU_TTT_ParallelMuon/anubhav_slot_seed42.log b/records/track_non_record_16mb/2026-03-29_SLOT_on_LeakyReLU_TTT_ParallelMuon/anubhav_slot_seed42.log new file mode 100644 index 0000000000..3db9ef63b6 --- /dev/null +++ b/records/track_non_record_16mb/2026-03-29_SLOT_on_LeakyReLU_TTT_ParallelMuon/anubhav_slot_seed42.log @@ -0,0 +1,275 @@ +W0329 15:12:49.352000 1372 torch/distributed/run.py:803] +W0329 15:12:49.352000 1372 torch/distributed/run.py:803] ***************************************** +W0329 15:12:49.352000 1372 torch/distributed/run.py:803] Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed. +W0329 15:12:49.352000 1372 torch/distributed/run.py:803] ***************************************** +logs/00f6cdee-9460-4e37-8366-d3f44dacff61.txt +val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path=./data/tokenizers/fineweb_1024_bpe.model +train_loader:dataset:fineweb10B_sp1024 train_shards:80 +val_loader:shards pattern=./data/datasets/fineweb10B_sp1024/fineweb_val_*.bin tokens:62021632 +model_params:26928220 +mtp_num_heads:0 mtp_loss_weight:0.2 mtp_params:0 +XSA:last_4 active_layers:[7, 8, 9, 10] +world_size:8 grad_accum_steps:1 +sdp_backends:cudnn=False flash=True mem_efficient=False math=False +attention_mode:gqa num_heads:8 num_kv_heads:4 +tie_embeddings:True embed_lr:0.035 head_lr:0.0 matrix_lr:0.025 scalar_lr:0.025 +train_batch_tokens:786432 train_seq_len:2048 iterations:9000 warmup_steps:20 max_wallclock_seconds:600.000 +seed:42 +warmup_step:1/20 +warmup_step:2/20 +warmup_step:3/20 +warmup_step:4/20 +warmup_step:5/20 +warmup_step:6/20 +warmup_step:7/20 +warmup_step:8/20 +warmup_step:9/20 +warmup_step:10/20 +warmup_step:11/20 +warmup_step:12/20 +warmup_step:13/20 +warmup_step:14/20 +warmup_step:15/20 +warmup_step:16/20 +warmup_step:17/20 
+warmup_step:18/20 +warmup_step:19/20 +warmup_step:20/20 +step:0/9000 val_loss:6.9293 val_bpb:4.1039 train_time:0ms step_avg:0.01ms +step:1/9000 train_loss:6.9308 train_time:126ms step_avg:125.70ms +step:2/9000 train_loss:8.6422 train_time:154ms step_avg:77.22ms +step:3/9000 train_loss:7.6901 train_time:234ms step_avg:78.07ms +step:4/9000 train_loss:7.2781 train_time:315ms step_avg:78.66ms +step:5/9000 train_loss:7.2221 train_time:395ms step_avg:79.08ms +step:6/9000 train_loss:7.1409 train_time:477ms step_avg:79.43ms +step:7/9000 train_loss:7.0919 train_time:557ms step_avg:79.54ms +step:8/9000 train_loss:7.0288 train_time:638ms step_avg:79.79ms +step:9/9000 train_loss:6.6336 train_time:719ms step_avg:79.94ms +step:10/9000 train_loss:6.2562 train_time:801ms step_avg:80.08ms +step:500/9000 train_loss:2.3964 train_time:41467ms step_avg:82.93ms +step:1000/9000 train_loss:2.2634 train_time:83241ms step_avg:83.24ms +step:1500/9000 train_loss:2.2133 train_time:125071ms step_avg:83.38ms +step:2000/9000 train_loss:2.0541 train_time:166956ms step_avg:83.48ms +step:2500/9000 train_loss:2.1589 train_time:208878ms step_avg:83.55ms +step:3000/9000 train_loss:2.1529 train_time:250807ms step_avg:83.60ms +step:3500/9000 train_loss:2.1685 train_time:292739ms step_avg:83.64ms +step:4000/9000 train_loss:1.9651 train_time:334700ms step_avg:83.67ms +step:4000/9000 val_loss:2.0572 val_bpb:1.2184 train_time:334756ms step_avg:83.69ms +step:4500/9000 train_loss:2.1169 train_time:376617ms step_avg:83.69ms +step:5000/9000 train_loss:2.0994 train_time:418544ms step_avg:83.71ms +step:5500/9000 train_loss:2.0141 train_time:460464ms step_avg:83.72ms +step:6000/9000 train_loss:1.9384 train_time:502390ms step_avg:83.73ms +swa:start step:6500 +step:6500/9000 train_loss:2.0803 train_time:544307ms step_avg:83.74ms +late_qat:enabled step:6637 scale:0.1499 +step:7000/9000 train_loss:1.7889 train_time:586813ms step_avg:83.83ms +step:7155/9000 val_loss:1.9215 val_bpb:1.1380 train_time:600085ms 
step_avg:83.87ms +stopping_early: wallclock_cap train_time:600085ms step:7155/9000 +peak memory allocated: 21481 MiB reserved: 22030 MiB +ema:applying EMA weights +DIAGNOSTIC post_ema val_loss:1.9197 val_bpb:1.1370 eval_time:2005ms +Serialized model: 106027446 bytes +Code size: 107448 bytes +Serialized model int6+lzma: 15775484 bytes +Total submission size int6+lzma: 15882932 bytes +final_int6_roundtrip val_loss:1.9340 val_bpb:1.1454 eval_time:16979ms +final_int6_roundtrip_exact val_loss:1.93402733 val_bpb:1.14543992 +final_int6_sliding_window val_loss:1.8944 val_bpb:1.1219 stride:64 eval_time:90650ms +final_int6_sliding_window_exact val_loss:1.89435099 val_bpb:1.12194433 +final_int8_zlib_roundtrip_exact val_loss:1.89435099 val_bpb:1.12194433 +ttt_sliding:start chunks=1893 chunk_tokens=32768 total_windows=969088 stride=64 ttt_lr=0.002 ttt_epochs=3 freeze_blocks=0 slot_lr=0.001 slot_steps=3 +ttt_sliding:params unfrozen=26928220 frozen=0 + ttt_chunk [1/1893] bpb=1.158793 time=0.4s + ttt_chunk [11/1893] bpb=1.147837 time=2.5s + ttt_chunk [21/1893] bpb=1.133006 time=4.7s + ttt_chunk [31/1893] bpb=1.130278 time=6.7s + ttt_chunk [41/1893] bpb=1.115889 time=8.7s + ttt_chunk [51/1893] bpb=1.110123 time=10.8s + ttt_chunk [61/1893] bpb=1.116609 time=12.9s + ttt_chunk [71/1893] bpb=1.114978 time=15.0s + ttt_chunk [81/1893] bpb=1.114243 time=17.0s + ttt_chunk [91/1893] bpb=1.115144 time=19.1s + ttt_chunk [101/1893] bpb=1.118884 time=21.1s + ttt_chunk [111/1893] bpb=1.121409 time=23.2s + ttt_chunk [121/1893] bpb=1.114844 time=25.2s + ttt_chunk [131/1893] bpb=1.114960 time=27.3s + ttt_chunk [141/1893] bpb=1.120668 time=29.3s + ttt_chunk [151/1893] bpb=1.122528 time=31.3s + ttt_chunk [161/1893] bpb=1.122083 time=33.4s + ttt_chunk [171/1893] bpb=1.126428 time=35.4s + ttt_chunk [181/1893] bpb=1.128655 time=37.5s + ttt_chunk [191/1893] bpb=1.136033 time=39.6s + ttt_chunk [201/1893] bpb=1.134954 time=41.6s + ttt_chunk [211/1893] bpb=1.132778 time=43.7s + ttt_chunk [221/1893] 
bpb=1.134312 time=45.8s + ttt_chunk [231/1893] bpb=1.132917 time=47.9s + ttt_chunk [241/1893] bpb=1.133269 time=49.9s + ttt_chunk [251/1893] bpb=1.132832 time=52.0s + ttt_chunk [261/1893] bpb=1.130027 time=54.0s + ttt_chunk [271/1893] bpb=1.128902 time=56.0s + ttt_chunk [281/1893] bpb=1.130179 time=58.1s + ttt_chunk [291/1893] bpb=1.131893 time=60.1s + ttt_chunk [301/1893] bpb=1.132494 time=62.2s + ttt_chunk [311/1893] bpb=1.134535 time=64.2s + ttt_chunk [321/1893] bpb=1.136476 time=66.2s + ttt_chunk [331/1893] bpb=1.136298 time=68.3s + ttt_chunk [341/1893] bpb=1.135298 time=70.3s + ttt_chunk [351/1893] bpb=1.137486 time=72.5s + ttt_chunk [361/1893] bpb=1.137602 time=74.5s + ttt_chunk [371/1893] bpb=1.136887 time=76.5s + ttt_chunk [381/1893] bpb=1.137086 time=78.7s + ttt_chunk [391/1893] bpb=1.136840 time=80.8s + ttt_chunk [401/1893] bpb=1.134777 time=82.8s + ttt_chunk [411/1893] bpb=1.133609 time=84.8s + ttt_chunk [421/1893] bpb=1.132634 time=86.9s + ttt_chunk [431/1893] bpb=1.132479 time=88.9s + ttt_chunk [441/1893] bpb=1.132890 time=90.9s + ttt_chunk [451/1893] bpb=1.133206 time=93.0s + ttt_chunk [461/1893] bpb=1.132102 time=95.0s + ttt_chunk [471/1893] bpb=1.132793 time=97.1s + ttt_chunk [481/1893] bpb=1.132421 time=99.1s + ttt_chunk [491/1893] bpb=1.131351 time=101.1s + ttt_chunk [501/1893] bpb=1.130839 time=103.2s + ttt_chunk [511/1893] bpb=1.130154 time=105.3s + ttt_chunk [521/1893] bpb=1.127719 time=107.3s + ttt_chunk [531/1893] bpb=1.128902 time=109.4s + ttt_chunk [541/1893] bpb=1.129274 time=111.5s + ttt_chunk [551/1893] bpb=1.128227 time=113.6s + ttt_chunk [561/1893] bpb=1.128756 time=115.6s + ttt_chunk [571/1893] bpb=1.127727 time=117.7s + ttt_chunk [581/1893] bpb=1.126929 time=119.7s + ttt_chunk [591/1893] bpb=1.126263 time=121.8s + ttt_chunk [601/1893] bpb=1.126748 time=123.8s + ttt_chunk [611/1893] bpb=1.126630 time=125.8s + ttt_chunk [621/1893] bpb=1.126485 time=127.9s + ttt_chunk [631/1893] bpb=1.127166 time=129.9s + ttt_chunk [641/1893] 
bpb=1.126938 time=132.0s + ttt_chunk [651/1893] bpb=1.127038 time=134.0s + [... intermediate ttt_chunk progress entries [661/1893]-[1861/1893] elided: bpb oscillates between ~1.121 and ~1.129 while time advances 136s -> 382s ...] + ttt_chunk [1871/1893]
bpb=1.121249 time=384.1s + ttt_chunk [1881/1893] bpb=1.120799 time=386.1s + ttt_chunk [1891/1893] bpb=1.120555 time=388.1s + ttt_chunk [1893/1893] bpb=1.120599 time=388.4s +ttt_sliding:done val_loss=1.888522 val_bpb=1.118492 elapsed=388.4s +legal_ttt val_loss:1.8885 val_bpb:1.1185 eval_time:388892ms +legal_ttt_exact val_loss:1.88852230 val_bpb:1.11849224 diff --git a/records/track_non_record_16mb/2026-03-29_SLOT_on_LeakyReLU_TTT_ParallelMuon/submission.json b/records/track_non_record_16mb/2026-03-29_SLOT_on_LeakyReLU_TTT_ParallelMuon/submission.json new file mode 100644 index 0000000000..e3f27dfbf1 --- /dev/null +++ b/records/track_non_record_16mb/2026-03-29_SLOT_on_LeakyReLU_TTT_ParallelMuon/submission.json @@ -0,0 +1,100 @@ +{ + "author": "Anubhav", + "github_id": "AnubhavBharadwaaj", + "val_bpb": 1.1185, + "val_bpb_std": 0.0003, + "val_loss": 1.8886, + "hardware": "8xH100 SXM", + "training_time_seconds": 600, + "eval_time_seconds_total": 486, + "eval_time_ttt_seconds": 386, + "artifact_size_bytes_mean": 15947819, + "seeds": [1337, 42, 2025], + "base_submission": "PR #549 (abaybektursun — LeakyReLU² + Legal TTT + Parallel Muon, 1.1194 BPB)", + "description": "Non-record: SLOT (Sample-specific LM Optimization at Test-time) eval-time augmentation on PR #549 SOTA stack. First SLOT-based entry in Parameter Golf. SLOT optimizes a single delta vector at the last hidden layer per-batch during TTT scoring, achieving -0.0008 BPB improvement (1.1185 vs 1.1193 baseline, 3-seed mean). 
Also documents CTW as a negative result (+0.005 BPB).", + "results": { + "slot_seed1337": { + "val_bpb": 1.1188, + "val_loss": 1.8890, + "steps": 7127, + "step_avg_ms": 84.2, + "ttt_time_seconds": 386, + "artifact_bytes": 15965604 + }, + "slot_seed42": { + "val_bpb": 1.1185, + "val_loss": 1.8885, + "steps": 7155, + "step_avg_ms": 83.9, + "ttt_time_seconds": 388, + "artifact_bytes": 15882932 + }, + "slot_seed2025": { + "val_bpb": 1.1183, + "val_loss": 1.8882, + "steps": 7152, + "step_avg_ms": 83.9, + "ttt_time_seconds": 385, + "artifact_bytes": 15994920 + }, + "baseline_seed1337": { + "val_bpb": 1.1195, + "val_loss": 1.8903, + "steps": 7164, + "step_avg_ms": 83.8, + "ttt_time_seconds": 352 + }, + "baseline_seed42": { + "val_bpb": 1.1195, + "val_loss": 1.8903, + "steps": 7159, + "step_avg_ms": 83.8, + "ttt_time_seconds": 353 + }, + "baseline_seed2025": { + "val_bpb": 1.1189, + "val_loss": 1.8893, + "steps": 7164, + "step_avg_ms": 83.8, + "ttt_time_seconds": 350 + }, + "ctw_seed1337_negative": { + "val_bpb": 1.1252, + "val_loss": 1.8999, + "ctw_weight": 0.1, + "ctw_depth": 4, + "ttt_time_seconds": 2760, + "note": "NEGATIVE RESULT: CTW hurts BPB by +0.005 and exceeds 10-min eval limit" + } + }, + "slot_config": { + "enabled": true, + "lr": 0.001, + "steps": 3, + "optimizer": "AdamW", + "weight_decay": 1e-8, + "eps": 1e-5, + "delta_shape": "[1, 1, 512]", + "delta_init": "zeros", + "integration_point": "inside TTT Phase 1 scoring loop, between forward_hidden() and compute_logits()" + }, + "novel_techniques": [ + "SLOT eval-time augmentation (Hu et al., arXiv:2505.12392v2) — POSITIVE RESULT (-0.0008 BPB)", + "Deep integration: SLOT delta optimization inside TTT scoring loop on TTT-adapted hidden states", + "Model split: forward_logits() decomposed into forward_hidden() + compute_logits() for SLOT access", + "CTW eval-time augmentation (Willems et al., 1995) — NEGATIVE RESULT (+0.005 BPB)" + ], + "inherited_techniques": [ + "11L, 512d, 3x MLP, GQA (8H/4KV), 
LeakyReLU²(0.5)", + "Parameter Banking + Parallel Muon (84ms/step)", + "Legal Score-First TTT (SGD, lr=0.002, 3 epochs, 32K chunks)", + "BigramHash(1536), XSA4, Partial RoPE(16), LN Scale, VE128", + "EMA(0.997) + Tight SWA(50), GPTQ-lite int6 + LZMA-6", + "FlashAttention 3 (Hopper-native)" + ], + "reference": { + "slot_paper": "arXiv:2505.12392v2", + "slot_repo": "https://github.com/maple-research-lab/SLOT", + "ctw_paper": "Willems, Shtarkov, Tjalkens (1995)" + } +} diff --git a/records/track_non_record_16mb/2026-03-29_SLOT_on_LeakyReLU_TTT_ParallelMuon/train_gpt.py b/records/track_non_record_16mb/2026-03-29_SLOT_on_LeakyReLU_TTT_ParallelMuon/train_gpt.py new file mode 100644 index 0000000000..4c2638273c --- /dev/null +++ b/records/track_non_record_16mb/2026-03-29_SLOT_on_LeakyReLU_TTT_ParallelMuon/train_gpt.py @@ -0,0 +1,2228 @@ +from __future__ import annotations +import copy +import glob +import io +import lzma +import math +import os +import random +import subprocess +import sys +import time +import uuid +import zlib +from pathlib import Path +try: + import zstandard + _COMPRESSOR = "zstd" +except ImportError: + _COMPRESSOR = "zlib" +import numpy as np +import sentencepiece as spm +import torch +import torch.distributed as dist +import torch.nn.functional as F +from torch import Tensor, nn +from torch.nn.parallel import DistributedDataParallel as DDP +from flash_attn_interface import flash_attn_func as flash_attn_3_func +class Hyperparameters: + data_path = os.environ.get("DATA_PATH", "./data/datasets/fineweb10B_sp1024") + train_files = os.path.join(data_path, "fineweb_train_*.bin") + val_files = os.path.join(data_path, "fineweb_val_*.bin") + tokenizer_path = os.environ.get("TOKENIZER_PATH", "./data/tokenizers/fineweb_1024_bpe.model") + run_id = os.environ.get("RUN_ID", str(uuid.uuid4())) + seed = int(os.environ.get("SEED", 1337)) + val_batch_size = int(os.environ.get("VAL_BATCH_SIZE", 524_288)) + val_loss_every = int(os.environ.get("VAL_LOSS_EVERY", 4000)) + 
train_log_every = int(os.environ.get("TRAIN_LOG_EVERY", 500)) + iterations = int(os.environ.get("ITERATIONS", 20000)) + warmdown_iters = int(os.environ.get("WARMDOWN_ITERS", 3500)) + warmup_steps = int(os.environ.get("WARMUP_STEPS", 20)) + train_batch_tokens = int(os.environ.get("TRAIN_BATCH_TOKENS", 786_432)) + train_seq_len = int(os.environ.get("TRAIN_SEQ_LEN", 2048)) + eval_seq_len = int(os.environ.get("EVAL_SEQ_LEN", 2048)) + max_wallclock_seconds = float(os.environ.get("MAX_WALLCLOCK_SECONDS", 600.0)) + qk_gain_init = float(os.environ.get("QK_GAIN_INIT", 1.5)) + vocab_size = int(os.environ.get("VOCAB_SIZE", 1024)) + num_layers = int(os.environ.get("NUM_LAYERS", 11)) + num_kv_heads = int(os.environ.get("NUM_KV_HEADS", 4)) + model_dim = int(os.environ.get("MODEL_DIM", 512)) + num_heads = int(os.environ.get("NUM_HEADS", 8)) + mlp_mult = float(os.environ.get("MLP_MULT", 3.0)) + tie_embeddings = bool(int(os.environ.get("TIE_EMBEDDINGS", "1"))) + rope_base = float(os.environ.get("ROPE_BASE", 10000.0)) + logit_softcap = float(os.environ.get("LOGIT_SOFTCAP", 30.0)) + embed_lr = float(os.environ.get("EMBED_LR", 0.6)) + head_lr = float(os.environ.get("HEAD_LR", 0.008)) + tied_embed_lr = float(os.environ.get("TIED_EMBED_LR", 0.035)) + tied_embed_init_std = float(os.environ.get("TIED_EMBED_INIT_STD", 0.005)) + matrix_lr = float(os.environ.get("MATRIX_LR", 0.025)) + scalar_lr = float(os.environ.get("SCALAR_LR", 0.025)) + muon_momentum = float(os.environ.get("MUON_MOMENTUM", 0.99)) + muon_backend_steps = int(os.environ.get("MUON_BACKEND_STEPS", 5)) + muon_momentum_warmup_start = float(os.environ.get("MUON_MOMENTUM_WARMUP_START", 0.92)) + muon_momentum_warmup_steps = int(os.environ.get("MUON_MOMENTUM_WARMUP_STEPS", 1500)) + beta1 = float(os.environ.get("BETA1", 0.9)) + beta2 = float(os.environ.get("BETA2", 0.95)) + adam_eps = float(os.environ.get("ADAM_EPS", 1e-8)) + grad_clip_norm = float(os.environ.get("GRAD_CLIP_NORM", 0.3)) + eval_stride = 
int(os.environ.get("EVAL_STRIDE", 64)) + mtp_num_heads = int(os.environ.get("MTP_NUM_HEADS", 0)) + mtp_loss_weight = float(os.environ.get("MTP_LOSS_WEIGHT", 0.2)) + muon_beta2 = float(os.environ.get("MUON_BETA2", 0.95)) + swa_enabled = bool(int(os.environ.get("SWA_ENABLED", "1"))) + swa_every = int(os.environ.get("SWA_EVERY", 50)) + lawa_enabled = bool(int(os.environ.get("LAWA_ENABLED", "0"))) + lawa_k = int(os.environ.get("LAWA_K", 10)) + lawa_freq = int(os.environ.get("LAWA_FREQ", 100)) + muon_wd = float(os.environ.get("MUON_WD", 0.04)) + adam_wd = float(os.environ.get("ADAM_WD", 0.04)) + qat_enabled = bool(int(os.environ.get("QAT_ENABLED", "0"))) + bigram_vocab_size = int(os.environ.get("BIGRAM_VOCAB_SIZE", 2048)) + bigram_dim = int(os.environ.get("BIGRAM_DIM", 128)) + xsa_last_n = int(os.environ.get("XSA_LAST_N", 4)) + rope_dims = int(os.environ.get("ROPE_DIMS", 16)) + ln_scale = bool(int(os.environ.get("LN_SCALE", "1"))) + dtg_enabled = bool(int(os.environ.get("DTG_ENABLED", "0"))) + late_qat_threshold = float(os.environ.get("LATE_QAT_THRESHOLD", 0.15)) + ve_enabled = bool(int(os.environ.get("VE_ENABLED", "1"))) + ve_dim = int(os.environ.get("VE_DIM", 128)) + ve_layers = os.environ.get("VE_LAYERS", "9,10") + gated_attention = bool(int(os.environ.get("GATED_ATTENTION", "0"))) + value_residual = bool(int(os.environ.get("VALUE_RESIDUAL", "0"))) + ttt_enabled = bool(int(os.environ.get("TTT_ENABLED", "0"))) + ttt_lr = float(os.environ.get("TTT_LR", 0.002)) + ttt_epochs = int(os.environ.get("TTT_EPOCHS", 3)) + ttt_chunk_tokens = int(os.environ.get("TTT_CHUNK_TOKENS", 32768)) + ttt_freeze_blocks = int(os.environ.get("TTT_FREEZE_BLOCKS", 2)) + ttt_momentum = float(os.environ.get("TTT_MOMENTUM", 0.9)) + ttt_batch_seqs = int(os.environ.get("TTT_BATCH_SEQS", 32)) + ttt_grad_clip = float(os.environ.get("TTT_GRAD_CLIP", 1.0)) + ctw_weight = float(os.environ.get("CTW_WEIGHT", 0.0)) # Enable with CTW_WEIGHT=0.1 + ctw_depth = int(os.environ.get("CTW_DEPTH", 4)) + 
ctw_entropy_threshold = float(os.environ.get("CTW_ENTROPY_THRESHOLD", 2.0)) # Skip CTW when H < this + slot_enabled = bool(int(os.environ.get("SLOT_ENABLED", "0"))) + slot_lr = float(os.environ.get("SLOT_LR", 0.001)) + slot_steps = int(os.environ.get("SLOT_STEPS", 3)) + +# --- Batched Newton-Schulz orthogonalization --- + +def zeropower_via_newtonschulz5(G: Tensor, steps: int = 5, eps: float = 1e-7) -> Tensor: + """Batched Newton-Schulz orthogonalization. G: (B,M,N) or (M,N).""" + a, b, c = (3.4445, -4.7750, 2.0315) + was_2d = G.ndim == 2 + if was_2d: + G = G.unsqueeze(0) + X = G.bfloat16() + transposed = X.size(-2) > X.size(-1) + if transposed: + X = X.mT + X = X / (X.norm(dim=(-2, -1), keepdim=True) + eps) + for _ in range(steps): + A = X @ X.mT + B = b * A + c * (A @ A) + X = a * X + B @ X + if transposed: + X = X.mT + if was_2d: + X = X.squeeze(0) + return X + +# --- Parallel Muon optimizer --- + +class Muon(torch.optim.Optimizer): + """Parallel Muon: post-backward reduce-scatter -> local NS5 -> all-gather. + + No DDP for bank params. After backward, this optimizer: + 1. Launches async reduce-scatter for all banks (biggest first) + 2. Returns control so Adam can step on small params while RS is in-flight + 3. Waits for each RS, runs local NS5 on the shard, launches async all-gather + 4. 
Each all-gather overlaps with next bank's NS5 + """ + def __init__(self, params, lr: float, momentum: float, backend_steps: int, + nesterov: bool = True, weight_decay: float = 0.0): + super().__init__( + params, + dict(lr=lr, momentum=momentum, backend_steps=backend_steps, + nesterov=nesterov, weight_decay=weight_decay), + ) + self._built = False + + def _build(self): + self._distributed = dist.is_available() and dist.is_initialized() + self._world_size = dist.get_world_size() if self._distributed else 1 + self._rank = dist.get_rank() if self._distributed else 0 + ws = self._world_size + + self._bank_meta = [] + for group in self.param_groups: + for p in group["params"]: + B = p.shape[0] + padded_B = ((B + ws - 1) // ws) * ws + shard_B = padded_B // ws + tail = p.shape[1:] + dev = p.device + self._bank_meta.append({ + 'p': p, + 'B': B, + 'padded_grad': torch.zeros(padded_B, *tail, device=dev, dtype=torch.bfloat16), + 'shard': torch.zeros(shard_B, *tail, device=dev, dtype=torch.bfloat16), + 'shard_mom': torch.zeros(shard_B, *tail, device=dev, dtype=torch.bfloat16), + 'full_update': torch.zeros(padded_B, *tail, device=dev, dtype=torch.bfloat16), + 'scale': max(1, p.shape[-2] / p.shape[-1]) ** 0.5, + }) + # Sort by size descending -- launch biggest reduce-scatters first + self._bank_meta.sort(key=lambda m: -m['p'].numel()) + self._built = True + + def launch_reduce_scatters(self): + """Phase 1: launch async reduce-scatter for all banks. 
Call right after backward.""" + if not self._built: + self._build() + if not self._distributed: + return + self._rs_futures = [] + for m in self._bank_meta: + p = m['p'] + if p.grad is None: + self._rs_futures.append(None) + continue + pg = m['padded_grad'] + pg[:m['B']].copy_(p.grad.bfloat16()) + if pg.shape[0] > m['B']: + pg[m['B']:].zero_() + fut = dist.reduce_scatter_tensor(m['shard'], pg, op=dist.ReduceOp.AVG, async_op=True) + self._rs_futures.append(fut) + + @torch.no_grad() + def step(self, closure=None): + """Phase 3: wait for RS, local NS5, all-gather. Call AFTER Adam steps.""" + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + if not self._built: + self._build() + + for group in self.param_groups: + lr = group["lr"] + momentum = group["momentum"] + backend_steps = group["backend_steps"] + nesterov = group["nesterov"] + wd = group.get("weight_decay", 0.0) + + prev_ag_handle = None + prev_m = None + + sharded = self._distributed and hasattr(self, '_rs_futures') + + for i, m in enumerate(self._bank_meta): + p = m['p'] + if p.grad is None: + continue + + if prev_ag_handle is not None: + prev_ag_handle.wait() + pp = prev_m['p'] + upd = prev_m['full_update'][:prev_m['B']] + if wd > 0.0: + pp.data.mul_(1.0 - lr * wd) + pp.add_(upd.to(dtype=pp.dtype), alpha=-lr * prev_m['scale']) + + if sharded and self._rs_futures[i] is not None: + self._rs_futures[i].wait() + g = m['shard'] + buf = m['shard_mom'] + else: + g = p.grad.bfloat16() + state = self.state[p] + if "momentum_buffer" not in state: + state["momentum_buffer"] = torch.zeros_like(g) + buf = state["momentum_buffer"] + + buf.mul_(momentum).add_(g) + if nesterov: + update = g.add(buf, alpha=momentum) + else: + update = buf + + update = zeropower_via_newtonschulz5(update, steps=backend_steps) + + if sharded: + prev_ag_handle = dist.all_gather_into_tensor( + m['full_update'], update, async_op=True) + prev_m = m + else: + if wd > 0.0: + p.data.mul_(1.0 - lr * wd) + 
p.add_(update.to(dtype=p.dtype), alpha=-lr * m['scale']) + + if prev_ag_handle is not None: + prev_ag_handle.wait() + pp = prev_m['p'] + upd = prev_m['full_update'][:prev_m['B']] + if wd > 0.0: + pp.data.mul_(1.0 - lr * wd) + pp.add_(upd.to(dtype=pp.dtype), alpha=-lr * prev_m['scale']) + + if hasattr(self, '_rs_futures'): + del self._rs_futures + + return loss + +# --- Tokenizer evaluation helpers --- + +def build_sentencepiece_luts( + sp: spm.SentencePieceProcessor, vocab_size: int, device: torch.device +) -> tuple[Tensor, Tensor, Tensor]: + sp_vocab_size = int(sp.vocab_size()) + table_size = max(sp_vocab_size, vocab_size) + base_bytes_np = np.zeros((table_size,), dtype=np.int16) + has_leading_space_np = np.zeros((table_size,), dtype=np.bool_) + is_boundary_token_np = np.ones((table_size,), dtype=np.bool_) + for token_id in range(sp_vocab_size): + if sp.is_control(token_id) or sp.is_unknown(token_id) or sp.is_unused(token_id): + continue + is_boundary_token_np[token_id] = False + if sp.is_byte(token_id): + base_bytes_np[token_id] = 1 + continue + piece = sp.id_to_piece(token_id) + if piece.startswith("\u2581"): + has_leading_space_np[token_id] = True + piece = piece[1:] + base_bytes_np[token_id] = len(piece.encode("utf-8")) + return ( + torch.tensor(base_bytes_np, dtype=torch.int16, device=device), + torch.tensor(has_leading_space_np, dtype=torch.bool, device=device), + torch.tensor(is_boundary_token_np, dtype=torch.bool, device=device), + ) +def load_validation_tokens(pattern: str, seq_len: int) -> Tensor: + files = [Path(p) for p in sorted(glob.glob(pattern))] + if not files: + raise FileNotFoundError(f"No files found for pattern: {pattern}") + tokens = torch.cat([load_data_shard(file) for file in files]).contiguous() + usable = ((tokens.numel() - 1) // seq_len) * seq_len + if usable <= 0: + raise ValueError(f"Validation split is too short for TRAIN_SEQ_LEN={seq_len}") + return tokens[: usable + 1] +def eval_val( + args: Hyperparameters, + model: nn.Module, + 
rank: int, + world_size: int, + device: torch.device, + grad_accum_steps: int, + val_tokens: Tensor, + base_bytes_lut: Tensor, + has_leading_space_lut: Tensor, + is_boundary_token_lut: Tensor, + eval_seq_len: int | None = None, +) -> tuple[float, float]: + seq_len = eval_seq_len or args.train_seq_len + local_batch_tokens = args.val_batch_size // (world_size * grad_accum_steps) + if local_batch_tokens < seq_len: + raise ValueError( + "VAL_BATCH_SIZE must provide at least one sequence per rank; " + f"got VAL_BATCH_SIZE={args.val_batch_size}, WORLD_SIZE={world_size}, " + f"GRAD_ACCUM_STEPS={grad_accum_steps}, seq_len={seq_len}" + ) + local_batch_seqs = local_batch_tokens // seq_len + total_seqs = (val_tokens.numel() - 1) // seq_len + seq_start = (total_seqs * rank) // world_size + seq_end = (total_seqs * (rank + 1)) // world_size + val_loss_sum = torch.zeros((), device=device, dtype=torch.float64) + val_token_count = torch.zeros((), device=device, dtype=torch.float64) + val_byte_count = torch.zeros((), device=device, dtype=torch.float64) + model.eval() + with torch.inference_mode(): + for batch_seq_start in range(seq_start, seq_end, local_batch_seqs): + batch_seq_end = min(batch_seq_start + local_batch_seqs, seq_end) + raw_start = batch_seq_start * seq_len + raw_end = batch_seq_end * seq_len + 1 + local = val_tokens[raw_start:raw_end].to(device=device, dtype=torch.int64, non_blocking=True) + x = local[:-1].reshape(-1, seq_len) + y = local[1:].reshape(-1, seq_len) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + batch_loss = model(x, y).detach() + batch_token_count = float(y.numel()) + val_loss_sum += batch_loss.to(torch.float64) * batch_token_count + val_token_count += batch_token_count + prev_ids = x.reshape(-1) + tgt_ids = y.reshape(-1) + token_bytes = base_bytes_lut[tgt_ids].to(dtype=torch.int16) + token_bytes += (has_leading_space_lut[tgt_ids] & ~is_boundary_token_lut[prev_ids]).to(dtype=torch.int16) + val_byte_count += 
token_bytes.to(torch.float64).sum() + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(val_loss_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(val_token_count, op=dist.ReduceOp.SUM) + dist.all_reduce(val_byte_count, op=dist.ReduceOp.SUM) + val_loss = val_loss_sum / val_token_count + bits_per_token = val_loss.item() / math.log(2.0) + tokens_per_byte = val_token_count.item() / val_byte_count.item() + model.train() + return float(val_loss.item()), float(bits_per_token * tokens_per_byte) + +# --- Quantization helpers --- + +CONTROL_TENSOR_NAME_PATTERNS = tuple( + pattern + for pattern in os.environ.get( + "CONTROL_TENSOR_NAME_PATTERNS", + "attn_scale,attn_scales,mlp_scale,mlp_scales,resid_mix,resid_mixes,q_gain,skip_weight,skip_weights,smear,dtg_gate,ve_layer_scales,ve_shared.scale,attn_gate,vr_lambda", + ).split(",") + if pattern +) +INT8_KEEP_FLOAT_FP32_NAME_PATTERNS = tuple( + pattern + for pattern in os.environ.get( + "INT8_KEEP_FLOAT_FP32_NAME_PATTERNS", + ",".join(CONTROL_TENSOR_NAME_PATTERNS), + ).split(",") + if pattern +) +INT8_KEEP_FLOAT_MAX_NUMEL = 65_536 +INT8_KEEP_FLOAT_STORE_DTYPE = torch.float16 +INT8_PER_ROW_SCALE_DTYPE = torch.float16 +INT8_CLIP_PERCENTILE = 99.99984 +INT8_CLIP_Q = INT8_CLIP_PERCENTILE / 100.0 +def tensor_nbytes(t: Tensor) -> int: + return int(t.numel()) * int(t.element_size()) +def keep_float_tensor(name: str, t: Tensor, passthrough_orig_dtypes: dict[str, str]) -> Tensor: + if any(pattern in name for pattern in INT8_KEEP_FLOAT_FP32_NAME_PATTERNS): + return t.float().contiguous() + if t.dtype in {torch.float32, torch.bfloat16}: + passthrough_orig_dtypes[name] = str(t.dtype).removeprefix("torch.") + return t.to(dtype=INT8_KEEP_FLOAT_STORE_DTYPE).contiguous() + return t +def quantize_float_tensor(t: Tensor) -> tuple[Tensor, Tensor]: + t32 = t.float() + if t32.ndim == 2: + clip_abs = ( + torch.quantile(t32.abs(), INT8_CLIP_Q, dim=1) + if t32.numel() + else torch.empty((t32.shape[0],), dtype=torch.float32) + ) + clipped = 
torch.maximum(torch.minimum(t32, clip_abs[:, None]), -clip_abs[:, None]) + scale = (clip_abs / 127.0).clamp_min(1.0 / 127.0) + q = torch.clamp(torch.round(clipped / scale[:, None]), -127, 127).to(torch.int8).contiguous() + return q, scale.to(dtype=INT8_PER_ROW_SCALE_DTYPE).contiguous() + clip_abs = float(torch.quantile(t32.abs().flatten(), INT8_CLIP_Q).item()) if t32.numel() else 0.0 + scale = torch.tensor(clip_abs / 127.0 if clip_abs > 0 else 1.0, dtype=torch.float32) + q = torch.clamp(torch.round(torch.clamp(t32, -clip_abs, clip_abs) / scale), -127, 127).to(torch.int8).contiguous() + return q, scale +def quantize_state_dict_int8(state_dict: dict[str, Tensor]): + quantized: dict[str, Tensor] = {} + scales: dict[str, Tensor] = {} + dtypes: dict[str, str] = {} + passthrough: dict[str, Tensor] = {} + passthrough_orig_dtypes: dict[str, str] = {} + qmeta: dict[str, dict[str, object]] = {} + stats = dict.fromkeys( + ("param_count", "num_tensors", "num_float_tensors", "num_nonfloat_tensors", "baseline_tensor_bytes", "int8_payload_bytes"), + 0, + ) + for name, tensor in state_dict.items(): + t = tensor.detach().to("cpu").contiguous() + stats["param_count"] += int(t.numel()) + stats["num_tensors"] += 1 + stats["baseline_tensor_bytes"] += tensor_nbytes(t) + if not t.is_floating_point(): + stats["num_nonfloat_tensors"] += 1 + passthrough[name] = t + stats["int8_payload_bytes"] += tensor_nbytes(t) + continue + if t.numel() <= INT8_KEEP_FLOAT_MAX_NUMEL: + kept = keep_float_tensor(name, t, passthrough_orig_dtypes) + passthrough[name] = kept + stats["int8_payload_bytes"] += tensor_nbytes(kept) + continue + stats["num_float_tensors"] += 1 + q, s = quantize_float_tensor(t) + if s.ndim > 0: + qmeta[name] = {"scheme": "per_row", "axis": 0} + quantized[name] = q + scales[name] = s + dtypes[name] = str(t.dtype).removeprefix("torch.") + stats["int8_payload_bytes"] += tensor_nbytes(q) + tensor_nbytes(s) + obj: dict[str, object] = { + "__quant_format__": "int8_clean_per_row_v1", + 
"quantized": quantized, + "scales": scales, + "dtypes": dtypes, + "passthrough": passthrough, + } + if qmeta: + obj["qmeta"] = qmeta + if passthrough_orig_dtypes: + obj["passthrough_orig_dtypes"] = passthrough_orig_dtypes + return obj, stats +def dequantize_state_dict_int8(obj: dict[str, object]) -> dict[str, Tensor]: + out: dict[str, Tensor] = {} + qmeta = obj.get("qmeta", {}) + passthrough_orig_dtypes = obj.get("passthrough_orig_dtypes", {}) + for name, q in obj["quantized"].items(): + dtype = getattr(torch, obj["dtypes"][name]) + s = obj["scales"][name] + if qmeta.get(name, {}).get("scheme") == "per_row" or s.ndim > 0: + s = s.to(dtype=torch.float32) + out[name] = (q.float() * s.view(q.shape[0], *([1] * (q.ndim - 1)))).to(dtype=dtype).contiguous() + else: + scale = float(s.item()) + out[name] = (q.float() * scale).to(dtype=dtype).contiguous() + for name, t in obj["passthrough"].items(): + out_t = t.detach().to("cpu").contiguous() + orig_dtype = passthrough_orig_dtypes.get(name) + if isinstance(orig_dtype, str): + out_t = out_t.to(dtype=getattr(torch, orig_dtype)).contiguous() + out[name] = out_t + return out + +# --- Data loading --- + +def load_data_shard(file: Path) -> Tensor: + # assumed shard layout: 256-entry int32 header, then uint16 token ids + header_bytes = 256 * np.dtype("<i4").itemsize + tokens_np = np.fromfile(file, dtype=np.uint16, offset=header_bytes) + return torch.from_numpy(tokens_np.astype(np.int32)) +class TokenStream: + def __init__(self, pattern: str): + self.files = [Path(p) for p in sorted(glob.glob(pattern))] + if not self.files: + raise FileNotFoundError(f"No files found for pattern: {pattern}") + self.file_idx = 0 + self.tokens = load_data_shard(self.files[0]) + self.pos = 0 + def _advance_file(self) -> None: + self.file_idx = (self.file_idx + 1) % len(self.files) + self.tokens = load_data_shard(self.files[self.file_idx]) + self.pos = 0 + def take(self, n: int) -> Tensor: + chunks: list[Tensor] = [] + remaining = n + while remaining > 0: + avail = self.tokens.numel() - self.pos + if avail <= 0: + self._advance_file() + continue + k = min(remaining, avail) + chunks.append(self.tokens[self.pos : self.pos + k]) + self.pos += k + remaining -= k + return chunks[0] if len(chunks) == 1 else torch.cat(chunks) +class DistributedTokenLoader: + def __init__(self, pattern: str, rank: int, world_size: int, device: torch.device): + self.rank = rank + self.world_size = world_size + self.device = device + self.stream = TokenStream(pattern) + def next_batch(self,
global_tokens: int, seq_len: int, grad_accum_steps: int) -> tuple[Tensor, Tensor]: + local_tokens = global_tokens // (self.world_size * grad_accum_steps) + per_rank_span = local_tokens + 1 + chunk = self.stream.take(per_rank_span * self.world_size) + start = self.rank * per_rank_span + local = chunk[start : start + per_rank_span].to(dtype=torch.int64) + x = local[:-1].reshape(-1, seq_len) + y = local[1:].reshape(-1, seq_len) + return x.to(self.device, non_blocking=True), y.to(self.device, non_blocking=True) + +# --- Transformer modules --- + +class RMSNorm(nn.Module): + def __init__(self, eps: float | None = None): + super().__init__() + self.eps = eps + def forward(self, x: Tensor) -> Tensor: + return F.rms_norm(x, (x.size(-1),), eps=self.eps) +class CastedLinear(nn.Linear): + _qat_enabled: bool = False + def forward(self, x: Tensor) -> Tensor: + w = self.weight.to(x.dtype) + if CastedLinear._qat_enabled and self.training and w.ndim == 2: + with torch.no_grad(): + w32 = self.weight.float() + row_max = w32.abs().amax(dim=1) + scale = (row_max / 31.0).clamp_min(1.0 / 31.0) + w_q = (torch.clamp(torch.round(w32 / scale[:, None]), -32, 31) * scale[:, None]).to(x.dtype) + w = w + (w_q - w).detach() + bias = self.bias.to(x.dtype) if self.bias is not None else None + return F.linear(x, w, bias) +def restore_low_dim_params_to_fp32(module: nn.Module) -> None: + with torch.no_grad(): + for name, param in module.named_parameters(): + if (param.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS)) and param.dtype != torch.float32: + param.data = param.data.float() +class Rotary(nn.Module): + def __init__(self, dim: int, base: float = 10000.0, train_seq_len: int = 1024, rope_dims: int = 0): + super().__init__() + self.dim = dim + self.base = base + self.train_seq_len = train_seq_len + self.rope_dims = rope_dims if rope_dims > 0 else dim + inv_freq = 1.0 / (base ** (torch.arange(0, self.rope_dims, 2, dtype=torch.float32) / self.rope_dims)) + 
self.register_buffer("inv_freq", inv_freq, persistent=False) + self._seq_len_cached = 0 + self._cos_cached: Tensor | None = None + self._sin_cached: Tensor | None = None + def forward(self, seq_len: int, device: torch.device, dtype: torch.dtype) -> tuple[Tensor, Tensor]: + if ( + self._cos_cached is None + or self._sin_cached is None + or self._seq_len_cached != seq_len + or self._cos_cached.device != device + ): + rd = self.rope_dims + if seq_len > self.train_seq_len: + scale = seq_len / self.train_seq_len + new_base = self.base * (scale ** (rd / (rd - 2))) + inv_freq = 1.0 / (new_base ** (torch.arange(0, rd, 2, dtype=torch.float32, device=device) / rd)) + else: + inv_freq = self.inv_freq.to(device) + t = torch.arange(seq_len, device=device, dtype=inv_freq.dtype) + freqs = torch.outer(t, inv_freq) + self._cos_cached = freqs.cos()[None, :, None, :] + self._sin_cached = freqs.sin()[None, :, None, :] + self._seq_len_cached = seq_len + return self._cos_cached.to(dtype=dtype), self._sin_cached.to(dtype=dtype) +def apply_rotary_emb(x: Tensor, cos: Tensor, sin: Tensor, rope_dims: int = 0) -> Tensor: + if rope_dims > 0 and rope_dims < x.size(-1): + x_rope, x_pass = x[..., :rope_dims], x[..., rope_dims:] + half = rope_dims // 2 + x1, x2 = x_rope[..., :half], x_rope[..., half:] + x_rope = torch.cat((x1 * cos + x2 * sin, x1 * (-sin) + x2 * cos), dim=-1) + return torch.cat((x_rope, x_pass), dim=-1) + half = x.size(-1) // 2 + x1, x2 = x[..., :half], x[..., half:] + return torch.cat((x1 * cos + x2 * sin, x1 * (-sin) + x2 * cos), dim=-1) + +class CausalSelfAttention(nn.Module): + def __init__( + self, + dim: int, + num_heads: int, + num_kv_heads: int, + rope_base: float, + qk_gain_init: float, + gated_attention: bool = False, + value_residual: bool = False, + ): + super().__init__() + if dim % num_heads != 0: + raise ValueError("model_dim must be divisible by num_heads") + if num_heads % num_kv_heads != 0: + raise ValueError("num_heads must be divisible by num_kv_heads") + 
self.num_heads = num_heads + self.num_kv_heads = num_kv_heads + self.head_dim = dim // num_heads + if self.head_dim % 2 != 0: + raise ValueError("head_dim must be even for RoPE") + # No CastedLinear -- weights come from banks + self.q_gain = nn.Parameter(torch.full((num_heads,), qk_gain_init, dtype=torch.float32)) + self.rope_dims = 0 # set by GPT.__init__ for partial RoPE + self.rotary = Rotary(self.head_dim, base=rope_base, train_seq_len=1024) + self.use_xsa = False # set by GPT.__init__ for deep layers only + # Gated attention and value residual (non-banked small params) + self.gated_attention = gated_attention + if gated_attention: + self.attn_gate = nn.Linear(dim, num_heads, bias=True) + nn.init.zeros_(self.attn_gate.weight) + nn.init.constant_(self.attn_gate.bias, 4.0) + self.value_residual = value_residual + if value_residual: + self.vr_lambda = nn.Parameter(torch.tensor([0.5, 0.5], dtype=torch.float32)) + def _xsa_efficient(self, y: Tensor, v: Tensor) -> Tensor: + """Efficient XSA: subtract self-value projection via GQA-aware reshape (no repeat_interleave). + y: [B, T, H, D], v: [B, T, Hkv, D]. 
H must be divisible by Hkv.""" + B, T, H, D = y.shape + Hkv = v.size(-2) + group = H // Hkv + y_g = y.reshape(B, T, Hkv, group, D) # [B, T, Hkv, group, D] + vn = F.normalize(v, dim=-1).unsqueeze(-2) # [B, T, Hkv, 1, D] -- broadcast ready + proj = (y_g * vn).sum(dim=-1, keepdim=True) * vn + return (y_g - proj).reshape(B, T, H, D) + def forward(self, x: Tensor, q_w: Tensor, k_w: Tensor, v_w: Tensor, out_w: Tensor, v_embed: Tensor | None = None, v0: Tensor | None = None) -> tuple[Tensor, Tensor | None]: + bsz, seqlen, dim = x.shape + q = F.linear(x, q_w.to(x.dtype)).reshape(bsz, seqlen, self.num_heads, self.head_dim) + k = F.linear(x, k_w.to(x.dtype)).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim) + v = F.linear(x, v_w.to(x.dtype)) + if v_embed is not None: + v = v + v_embed + v = v.reshape(bsz, seqlen, self.num_kv_heads, self.head_dim) + raw_v = v if self.value_residual else None + if self.value_residual and v0 is not None: + lam = self.vr_lambda.to(dtype=v.dtype) + v = lam[0] * v0 + lam[1] * v + q = F.rms_norm(q, (q.size(-1),)) + k = F.rms_norm(k, (k.size(-1),)) + cos, sin = self.rotary(seqlen, x.device, q.dtype) + q = apply_rotary_emb(q, cos, sin, self.rope_dims) + k = apply_rotary_emb(k, cos, sin, self.rope_dims) + q = q * self.q_gain.to(dtype=q.dtype)[None, None, :, None] + y = flash_attn_3_func(q, k, v, causal=True) + if self.use_xsa: + y = self._xsa_efficient(y, v) + if self.gated_attention: + # gate shape: (bsz, seqlen, num_heads) -> (bsz, seqlen, num_heads, 1) for B,T,H,D layout + gate = torch.sigmoid(self.attn_gate(x)).unsqueeze(-1) + y = y * gate + y = y.reshape(bsz, seqlen, dim) + return F.linear(y, out_w.to(x.dtype)), raw_v + +class SmearGate(nn.Module): + def __init__(self, dim: int): + super().__init__() + self.gate = nn.Parameter(torch.zeros(dim, dtype=torch.float32)) + def forward(self, x: Tensor) -> Tensor: + g = torch.sigmoid(self.gate.to(dtype=x.dtype))[None, None, :] + x_prev = torch.cat([torch.zeros_like(x[:, :1]), x[:, :-1]], dim=1) + 
return (1 - g) * x + g * x_prev + +class BigramHashEmbedding(nn.Module): + def __init__(self, bigram_vocab_size: int, bigram_dim: int, model_dim: int): + super().__init__() + self.bigram_vocab_size = bigram_vocab_size + self.embed = nn.Embedding(bigram_vocab_size, bigram_dim) + nn.init.zeros_(self.embed.weight) + self.proj = CastedLinear(bigram_dim, model_dim, bias=False) if bigram_dim != model_dim else None + if self.proj is not None: + nn.init.zeros_(self.proj.weight) + self.scale = nn.Parameter(torch.tensor(0.05, dtype=torch.float32)) + def bigram_hash(self, tokens: Tensor) -> Tensor: + t = tokens.to(torch.int32) + mod = self.bigram_vocab_size - 1 + out = torch.empty_like(t) + out[..., 0] = mod + out[..., 1:] = torch.bitwise_xor(36313 * t[..., 1:], 27191 * t[..., :-1]) % mod + return out.long() + def forward(self, token_ids: Tensor) -> Tensor: + h = self.embed(self.bigram_hash(token_ids)) + if self.proj is not None: + h = self.proj(h) + return h * self.scale.to(dtype=h.dtype) + +class ValueEmbedding(nn.Module): + """Reinject token identity into attention values at specific layers. 
+ Each table maps vocab tokens to a low-dim embedding, projected to model_dim.""" + def __init__(self, vocab_size: int, ve_dim: int, model_dim: int): + super().__init__() + self.embed = nn.Embedding(vocab_size, ve_dim) + nn.init.normal_(self.embed.weight, std=0.01) + self.proj = CastedLinear(ve_dim, model_dim, bias=False) if ve_dim != model_dim else None + if self.proj is not None: + nn.init.zeros_(self.proj.weight) + self.scale = nn.Parameter(torch.tensor(0.1, dtype=torch.float32)) + def forward(self, token_ids: Tensor) -> Tensor: + h = self.embed(token_ids) + if self.proj is not None: + h = self.proj(h) + return h * self.scale.to(dtype=h.dtype) + +class MLP(nn.Module): + def __init__(self, dim: int, mlp_mult: int): + super().__init__() + # No CastedLinear -- weights come from banks + def forward(self, x: Tensor, up_w: Tensor, down_w: Tensor) -> Tensor: + x = F.leaky_relu(F.linear(x, up_w.to(x.dtype)), negative_slope=0.5) + return F.linear(x.square(), down_w.to(x.dtype)) + +class Block(nn.Module): + def __init__( + self, + dim: int, + num_heads: int, + num_kv_heads: int, + mlp_mult: int, + rope_base: float, + qk_gain_init: float, + layer_idx: int = 0, + ln_scale: bool = False, + dtg: bool = False, + gated_attention: bool = False, + value_residual: bool = False, + ): + super().__init__() + self.attn_norm = RMSNorm() + self.mlp_norm = RMSNorm() + self.attn = CausalSelfAttention(dim, num_heads, num_kv_heads, rope_base, qk_gain_init, + gated_attention=gated_attention, value_residual=value_residual) + self.mlp = MLP(dim, mlp_mult) + self.attn_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.mlp_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.resid_mix = nn.Parameter(torch.stack((torch.ones(dim), torch.zeros(dim))).float()) + self.ln_scale_factor = 1.0 / math.sqrt(layer_idx + 1) if ln_scale else 1.0 + if dtg: + self.dtg_gate = nn.Linear(dim, 1, bias=True) + nn.init.zeros_(self.dtg_gate.weight) + nn.init.constant_(self.dtg_gate.bias, 
2.0) + else: + self.dtg_gate = None + def forward(self, x: Tensor, x0: Tensor, q_w: Tensor, k_w: Tensor, v_w: Tensor, out_w: Tensor, up_w: Tensor, down_w: Tensor, v_embed: Tensor | None = None, v0: Tensor | None = None) -> tuple[Tensor, Tensor | None]: + mix = self.resid_mix.to(dtype=x.dtype) + x_in = mix[0][None, None, :] * x + mix[1][None, None, :] * x0 + attn_out, raw_v = self.attn(self.attn_norm(x_in) * self.ln_scale_factor, q_w, k_w, v_w, out_w, v_embed=v_embed, v0=v0) + x_out = x_in + self.attn_scale.to(dtype=x_in.dtype)[None, None, :] * attn_out + x_out = x_out + self.mlp_scale.to(dtype=x_out.dtype)[None, None, :] * self.mlp(self.mlp_norm(x_out) * self.ln_scale_factor, up_w, down_w) + if self.dtg_gate is not None: + gate = torch.sigmoid(self.dtg_gate(x_in.detach())) + x_out = x_in + gate * (x_out - x_in) + return x_out, raw_v + +class GPT(nn.Module): + def __init__( + self, + vocab_size: int, + num_layers: int, + model_dim: int, + num_heads: int, + num_kv_heads: int, + mlp_mult: int, + tie_embeddings: bool, + tied_embed_init_std: float, + logit_softcap: float, + rope_base: float, + qk_gain_init: float, + mtp_num_heads: int = 0, + mtp_loss_weight: float = 0.1, + bigram_vocab_size: int = 0, + bigram_dim: int = 128, + xsa_last_n: int = 0, + rope_dims: int = 0, + ln_scale: bool = False, + dtg: bool = False, + ve_enabled: bool = False, + ve_dim: int = 128, + ve_layers: str = "9,10", + gated_attention: bool = False, + value_residual: bool = False, + ): + super().__init__() + self._ve_target_dim = num_kv_heads * (model_dim // num_heads) # kv_dim for value projection + if logit_softcap <= 0.0: + raise ValueError(f"logit_softcap must be positive, got {logit_softcap}") + self.tie_embeddings = tie_embeddings + self.tied_embed_init_std = tied_embed_init_std + self.logit_softcap = logit_softcap + self.value_residual = value_residual + self.mtp_num_heads = mtp_num_heads + self.mtp_loss_weight = mtp_loss_weight + self.tok_emb = nn.Embedding(vocab_size, model_dim) + 
self.bigram = BigramHashEmbedding(bigram_vocab_size, bigram_dim, model_dim) if bigram_vocab_size > 0 else None + self.smear = SmearGate(model_dim) + self.num_encoder_layers = num_layers // 2 + self.num_decoder_layers = num_layers - self.num_encoder_layers + self.num_skip_weights = min(self.num_encoder_layers, self.num_decoder_layers) + self.skip_weights = nn.Parameter(torch.ones(self.num_skip_weights, model_dim, dtype=torch.float32)) + # Parameter banks: contiguous 3D tensors for batched optimizer + head_dim = model_dim // num_heads + kv_dim = num_kv_heads * head_dim + mlp_dim = int(mlp_mult * model_dim) + self.num_layers = num_layers + self.qo_bank = nn.Parameter(torch.empty(2 * num_layers, model_dim, model_dim)) + self.kv_bank = nn.Parameter(torch.empty(2 * num_layers, kv_dim, model_dim)) + self.mlp_up_bank = nn.Parameter(torch.empty(num_layers, mlp_dim, model_dim)) + self.mlp_down_bank = nn.Parameter(torch.empty(num_layers, model_dim, mlp_dim)) + self.blocks = nn.ModuleList( + [ + Block( + model_dim, + num_heads, + num_kv_heads, + mlp_mult, + rope_base, + qk_gain_init, + layer_idx=i, + ln_scale=ln_scale, + dtg=dtg, + gated_attention=gated_attention, + value_residual=value_residual, + ) + for i in range(num_layers) + ] + ) + if rope_dims > 0: + head_dim = model_dim // num_heads + for block in self.blocks: + block.attn.rope_dims = rope_dims + block.attn.rotary = Rotary(head_dim, base=rope_base, train_seq_len=1024, rope_dims=rope_dims) + self.ve_layer_indices = [int(x) for x in ve_layers.split(",") if x.strip()] if ve_enabled else [] + kv_dim_ve = self._ve_target_dim + if self.ve_layer_indices: + self.ve_shared = ValueEmbedding(vocab_size, ve_dim, kv_dim_ve) + self.ve_layer_scales = nn.ParameterList( + [nn.Parameter(torch.ones(1, dtype=torch.float32)) for _ in self.ve_layer_indices] + ) + else: + self.ve_shared = None + self.ve_layer_scales = nn.ParameterList() + self.value_embeds = nn.ModuleList() # keep empty for compat + self.final_norm = RMSNorm() + 
self.lm_head = None if tie_embeddings else CastedLinear(model_dim, vocab_size, bias=False) + if self.lm_head is not None: + self.lm_head._zero_init = True + self.mtp_heads = nn.ModuleList( + [CastedLinear(model_dim, vocab_size, bias=False) for _ in range(mtp_num_heads)] + ) + for head in self.mtp_heads: + head._zero_init = True + if xsa_last_n > 0: + for i in range(max(0, num_layers - xsa_last_n), num_layers): + self.blocks[i].attn.use_xsa = True + self._init_weights() + def _init_weights(self) -> None: + if self.tie_embeddings: + nn.init.normal_(self.tok_emb.weight, mean=0.0, std=self.tied_embed_init_std) + n = self.num_layers + proj_scale = 1.0 / math.sqrt(2 * n) + # Init banks: orthogonal, with proj layers scaled down and out/down zero-init + for i in range(n): + nn.init.orthogonal_(self.qo_bank.data[i], gain=1.0) # Q + nn.init.zeros_(self.qo_bank.data[n + i]) # Out (zero init) + nn.init.orthogonal_(self.kv_bank.data[i], gain=1.0) # K + nn.init.orthogonal_(self.kv_bank.data[n + i], gain=1.0) # V + nn.init.orthogonal_(self.mlp_up_bank.data[i], gain=1.0) # MLP up + nn.init.zeros_(self.mlp_down_bank.data[i]) # MLP down (zero init) + # Scale proj layers (out_proj and mlp_down are "proj" layers) + self.qo_bank.data[n + i].mul_(proj_scale) + self.mlp_down_bank.data[i].mul_(proj_scale) + # Init remaining nn.Linear modules (bigram proj, mtp heads, lm_head) + for name, module in self.named_modules(): + if isinstance(module, nn.Linear): + if getattr(module, "_zero_init", False): + nn.init.zeros_(module.weight) + elif module.weight.ndim == 2 and module.weight.shape[0] >= 64 and module.weight.shape[1] >= 64: + nn.init.orthogonal_(module.weight, gain=1.0) + def _get_ve(self, layer_idx: int, input_ids: Tensor, ve_cache: dict | None = None) -> Tensor | None: + """Get value embedding for a specific layer using shared table + per-layer scale.""" + if self.ve_shared is None or layer_idx not in self.ve_layer_indices: + return None + if ve_cache is not None and 've' not in 
ve_cache: + ve_cache['ve'] = self.ve_shared(input_ids) + ve_base = ve_cache['ve'] if ve_cache is not None else self.ve_shared(input_ids) + ve_idx = self.ve_layer_indices.index(layer_idx) + return ve_base * self.ve_layer_scales[ve_idx].to(dtype=ve_base.dtype) + def forward(self, input_ids: Tensor, target_ids: Tensor) -> Tensor: + n = self.num_layers + x = self.tok_emb(input_ids) + if self.bigram is not None: + x = x + self.bigram(input_ids) + x = F.rms_norm(x, (x.size(-1),)) + x = self.smear(x) + x0 = x + v0 = None + skips: list[Tensor] = [] + ve_cache: dict = {} + for i in range(self.num_encoder_layers): + ve = self._get_ve(i, input_ids, ve_cache) + x, raw_v = self.blocks[i](x, x0, + self.qo_bank[i], self.kv_bank[i], self.kv_bank[n + i], + self.qo_bank[n + i], self.mlp_up_bank[i], self.mlp_down_bank[i], + v_embed=ve, v0=v0) + if v0 is None and raw_v is not None: + v0 = raw_v + skips.append(x) + for i in range(self.num_decoder_layers): + bi = self.num_encoder_layers + i + if skips: + x = x + self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips.pop() + ve = self._get_ve(bi, input_ids, ve_cache) + x, _ = self.blocks[bi](x, x0, + self.qo_bank[bi], self.kv_bank[bi], self.kv_bank[n + bi], + self.qo_bank[n + bi], self.mlp_up_bank[bi], self.mlp_down_bank[bi], + v_embed=ve, v0=v0) + x = self.final_norm(x) + x_flat = x.reshape(-1, x.size(-1)) + targets = target_ids.reshape(-1) + if self.tie_embeddings: + logits_proj = F.linear(x_flat, self.tok_emb.weight) + else: + if self.lm_head is None: + raise RuntimeError("lm_head is required when tie_embeddings=False") + logits_proj = self.lm_head(x_flat) + logits = self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap) + main_loss = F.cross_entropy(logits.float(), targets, reduction="mean") + if self.training and self.mtp_num_heads > 0 and self.mtp_loss_weight > 0.0: + _, seqlen, dim = x.shape + mtp_loss_sum = x.new_zeros(()) + mtp_loss_count = 0 + for k, mtp_head in enumerate(self.mtp_heads): + valid_t = seqlen - 
(k + 1) + if valid_t <= 0: + continue + mtp_hidden = x[:, :valid_t, :].reshape(-1, dim) + mtp_targets = target_ids[:, k + 1 :].reshape(-1) + mtp_logits_proj = mtp_head(mtp_hidden) + mtp_logits = self.logit_softcap * torch.tanh(mtp_logits_proj / self.logit_softcap) + mtp_loss_sum = mtp_loss_sum + F.cross_entropy(mtp_logits.float(), mtp_targets, reduction="mean") + mtp_loss_count += 1 + if mtp_loss_count > 0: + main_loss = main_loss + self.mtp_loss_weight * (mtp_loss_sum / mtp_loss_count) + return main_loss + def forward_hidden(self, input_ids: Tensor) -> Tensor: + """Return final hidden states (bsz, seq_len, model_dim) before lm_head projection.""" + n = self.num_layers + x = self.tok_emb(input_ids) + if self.bigram is not None: + x = x + self.bigram(input_ids) + x = F.rms_norm(x, (x.size(-1),)) + x = self.smear(x) + x0 = x + v0 = None + skips: list[Tensor] = [] + ve_cache: dict = {} + for i in range(self.num_encoder_layers): + ve = self._get_ve(i, input_ids, ve_cache) + x, raw_v = self.blocks[i](x, x0, + self.qo_bank[i], self.kv_bank[i], self.kv_bank[n + i], + self.qo_bank[n + i], self.mlp_up_bank[i], self.mlp_down_bank[i], + v_embed=ve, v0=v0) + if v0 is None and raw_v is not None: + v0 = raw_v + skips.append(x) + for i in range(self.num_decoder_layers): + bi = self.num_encoder_layers + i + if skips: + x = x + self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips.pop() + ve = self._get_ve(bi, input_ids, ve_cache) + x, _ = self.blocks[bi](x, x0, + self.qo_bank[bi], self.kv_bank[bi], self.kv_bank[n + bi], + self.qo_bank[n + bi], self.mlp_up_bank[bi], self.mlp_down_bank[bi], + v_embed=ve, v0=v0) + return self.final_norm(x) + + def compute_logits(self, hidden_states: Tensor) -> Tensor: + """Project hidden states to logits with softcap.""" + if self.tie_embeddings: + logits_proj = F.linear(hidden_states, self.tok_emb.weight) + else: + logits_proj = self.lm_head(hidden_states) + return self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap) + + def 
forward_logits(self, input_ids: Tensor) -> Tensor: + """Return logits (bsz, seq_len, vocab) without computing loss.""" + return self.compute_logits(self.forward_hidden(input_ids)) + +# --- CTW: Proper Context Tree Weighting (recursive depth weighting) --- +# Sparse lazy M-ary CTW with beta-tracked depth weights +# Ref: Willems, Shtarkov, Tjalkens (1995); eta propagation per Willems & Tjalkens (1997) +# Fixed: previous version used deepest-match lookup (just an n-gram model). +# This version computes the proper recursive weighted mixture across ALL depths. + +class CTWNode: + __slots__ = ['counts', 'total', 'children', 'log_pe', 'log_pw'] + def __init__(self): + self.counts = {} # symbol -> count + self.total = 0 + self.children = {} # context_symbol -> child node + self.log_pe = 0.0 # cumulative log P_e (KT sequential probability) + self.log_pw = 0.0 # cumulative log P_w (weighted probability) + +class SparseCTW: + """Proper sparse M-ary CTW. Nodes allocated on-demand. Beta-weighted depth mixing. + Verified against reference Python/Go/Rust implementations.""" + def __init__(self, depth: int = 4, vocab_size: int = 1024, alpha: float = 0.5): + self.depth = depth + self.vocab_size = vocab_size + self.alpha = alpha + self.alpha_sum = alpha * vocab_size + self.root = CTWNode() + self.context: list[int] = [] + + def _kt_prob(self, node: CTWNode, symbol: int) -> float: + """KT estimator: P_e(symbol | counts) = (count_s + alpha) / (total + alpha_sum)""" + return (node.counts.get(symbol, 0) + self.alpha) / (node.total + self.alpha_sum) + + def update(self, symbol: int): + """Update tree bottom-up along context path. 
Maintains log_pe and log_pw.""" + # Collect path nodes (root to deepest) + path: list[CTWNode] = [self.root] + node = self.root + for i in range(min(len(self.context), self.depth)): + ctx_sym = self.context[-(i + 1)] + if ctx_sym not in node.children: + node.children[ctx_sym] = CTWNode() + node = node.children[ctx_sym] + path.append(node) + + # Update counts and log_pe at each node on path + for node in path: + p_kt = self._kt_prob(node, symbol) + node.log_pe += math.log(max(p_kt, 1e-30)) + node.counts[symbol] = node.counts.get(symbol, 0) + 1 + node.total += 1 + + # Recompute log_pw bottom-up along path + # Leaf (deepest): P_w = P_e + path[-1].log_pw = path[-1].log_pe + # Walk back up + for i in range(len(path) - 2, -1, -1): + node = path[i] + child = path[i + 1] + # P_w(node) = 0.5 * P_e(node) + 0.5 * P_w(child) + # (Other children unchanged, their P_w = P_e = 1.0 initially, log = 0) + # For the child on the context path: + lpe = node.log_pe + lpw_child = child.log_pw + # log(0.5 * exp(lpe) + 0.5 * exp(lpw_child)) + max_lp = max(lpe, lpw_child) + node.log_pw = max_lp + math.log( + 0.5 * math.exp(min(lpe - max_lp, 20)) + + 0.5 * math.exp(min(lpw_child - max_lp, 20)) + ) + + self.context.append(symbol) + if len(self.context) > self.depth + 1: + self.context = self.context[-(self.depth + 1):] + + def predict(self, device: torch.device) -> Tensor: + """Compute proper CTW predictive distribution using beta-weighted depth mixing. 
+ P_ctw(s) = sum over depths d: w_d * P_e_d(s), where w_d comes from + the recursive beta = exp(log_pe - log_pw) at each node.""" + # Walk down context path, collect nodes + path: list[CTWNode] = [self.root] + node = self.root + for i in range(min(len(self.context), self.depth)): + ctx_sym = self.context[-(i + 1)] + if ctx_sym in node.children: + node = node.children[ctx_sym] + path.append(node) + else: + break + + # Bottom-up: compute weighted predictive distribution + # Start from deepest node: P_w = P_e (KT estimate) + deepest = path[-1] + probs = torch.full((self.vocab_size,), self.alpha / (deepest.total + self.alpha_sum), device=device) + for sym, count in deepest.counts.items(): + if sym < self.vocab_size: + probs[sym] = (count + self.alpha) / (deepest.total + self.alpha_sum) + + # Walk back up, mixing KT at each depth with deeper weighted estimate + for i in range(len(path) - 2, -1, -1): + node = path[i] + # KT estimate at this depth + pe = torch.full((self.vocab_size,), self.alpha / (node.total + self.alpha_sum), device=device) + for sym, count in node.counts.items(): + if sym < self.vocab_size: + pe[sym] = (count + self.alpha) / (node.total + self.alpha_sum) + # Beta = exp(log_pe - log_pw): posterior weight for "this node is a leaf" + log_beta = node.log_pe - node.log_pw + log_beta = max(min(log_beta, 20), -20) # clamp + beta = math.exp(log_beta) + w_local = beta / (1.0 + beta) # weight for local KT vs deeper + probs = w_local * pe + (1.0 - w_local) * probs + + return probs + + def mix_with_neural(self, neural_logits: Tensor, w_ctw: float = 0.1, + entropy_threshold: float = 0.0) -> Tensor: + """Entropy-adaptive CTW mixing. Only mixes when neural entropy > threshold. 
+ When neural model is confident (low entropy), CTW adds noise — skip it.""" + if not self.context: + return neural_logits + + # Entropy-adaptive gating: compute neural entropy + if entropy_threshold > 0: + neural_probs = F.softmax(neural_logits, dim=-1) + entropy = -(neural_probs * torch.log(neural_probs + 1e-10)).sum() + if entropy.item() < entropy_threshold: + return neural_logits # Neural is confident — skip CTW + # Scale CTW weight by entropy (higher entropy = more CTW influence) + max_entropy = math.log(self.vocab_size) # ~6.93 for vocab 1024 + entropy_scale = min(entropy.item() / max_entropy, 1.0) + effective_w = w_ctw * entropy_scale + else: + effective_w = w_ctw + + ctw_probs = self.predict(neural_logits.device).clamp(1e-10) + ctw_lp = torch.log(ctw_probs) + neural_lp = F.log_softmax(neural_logits, dim=-1) + mixed = (1.0 - effective_w) * neural_lp + effective_w * ctw_lp + return mixed - mixed.logsumexp(dim=-1, keepdim=True) + +# --- Sliding window evaluation --- + +def eval_val_sliding( + args: Hyperparameters, + base_model: nn.Module, + rank: int, + world_size: int, + device: torch.device, + val_tokens: Tensor, + base_bytes_lut: Tensor, + has_leading_space_lut: Tensor, + is_boundary_token_lut: Tensor, + stride: int, + batch_seqs: int = 32, + eval_seq_len: int | None = None, +) -> tuple[float, float]: + """Sliding window evaluation: each token scored with maximum context.""" + seq_len = eval_seq_len or args.train_seq_len + total_tokens = val_tokens.numel() - 1 + window_starts = [ws for ws in range(0, total_tokens, stride) + if min(ws + seq_len, total_tokens) - ws >= 1] + total_windows = len(window_starts) + my_s = (total_windows * rank) // world_size + my_e = (total_windows * (rank + 1)) // world_size + my_windows = window_starts[my_s:my_e] + loss_sum = torch.zeros((), device=device, dtype=torch.float64) + token_count = torch.zeros((), device=device, dtype=torch.float64) + byte_count = torch.zeros((), device=device, dtype=torch.float64) + base_model.eval() 
+ compiled_logits = torch.compile(base_model.forward_logits, dynamic=False, fullgraph=True) + with torch.inference_mode(): + for bi in range(0, len(my_windows), batch_seqs): + batch_ws = my_windows[bi:bi + batch_seqs] + bsz = len(batch_ws) + x_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) + y_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) + wlens: list[int] = [] + for i, ws in enumerate(batch_ws): + end = min(ws + seq_len, total_tokens) + wlen = end - ws + wlens.append(wlen) + chunk = val_tokens[ws:end + 1].to(dtype=torch.int64, device=device) + x_batch[i, :wlen] = chunk[:-1] + y_batch[i, :wlen] = chunk[1:] + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + logits = compiled_logits(x_batch) + nll = F.cross_entropy( + logits.reshape(-1, logits.size(-1)).float(), + y_batch.reshape(-1), + reduction="none", + ).reshape(bsz, seq_len) + for i, ws in enumerate(batch_ws): + wlen = wlens[i] + s = 0 if ws == 0 else max(wlen - stride, 0) + scored_nll = nll[i, s:wlen].to(torch.float64) + loss_sum += scored_nll.sum() + token_count += float(wlen - s) + tgt = y_batch[i, s:wlen] + prev = x_batch[i, s:wlen] + tb = base_bytes_lut[tgt].to(torch.float64) + tb += (has_leading_space_lut[tgt] & ~is_boundary_token_lut[prev]).to(torch.float64) + byte_count += tb.sum() + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(loss_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(token_count, op=dist.ReduceOp.SUM) + dist.all_reduce(byte_count, op=dist.ReduceOp.SUM) + val_loss = (loss_sum / token_count).item() + bits_per_token = val_loss / math.log(2.0) + tokens_per_byte = token_count.item() / byte_count.item() + base_model.train() + return val_loss, bits_per_token * tokens_per_byte + + +def eval_val_sliding_ttt( + args: Hyperparameters, base_model: nn.Module, rank: int, world_size: int, + device: torch.device, val_tokens: Tensor, base_bytes_lut: Tensor, + has_leading_space_lut: Tensor, is_boundary_token_lut: Tensor, + stride: 
int, batch_seqs: int = 32, log0=print, +) -> tuple[float, float]: + """Legal score-first TTT (PR #461 recipe) with optional CTW augmentation: + score each chunk with sliding windows, then train on it. + Every token scored BEFORE any update that could use it. + When CTW is enabled, neural logits are mixed with CTW predictions per-token.""" + seq_len = args.train_seq_len + total_tokens = val_tokens.numel() - 1 + ttt_chunk = args.ttt_chunk_tokens + + # CTW integration: build suffix tree from scored tokens, mix into logits + use_ctw = args.ctw_weight > 0 + ctw = SparseCTW(depth=args.ctw_depth, vocab_size=args.vocab_size) if use_ctw else None + w_ctw = args.ctw_weight + + # SLOT integration: per-batch delta optimization at last hidden layer + use_slot = args.slot_enabled + + # Pre-compute all window starts + window_starts = [ws for ws in range(0, total_tokens, stride) + if min(ws + seq_len, total_tokens) - ws >= stride or ws == 0] + + # Assign each window to a chunk based on the first token it scores + num_chunks = (total_tokens + ttt_chunk - 1) // ttt_chunk + chunk_windows: list[list[int]] = [[] for _ in range(num_chunks)] + for ws in window_starts: + end = min(ws + seq_len, total_tokens) + wlen = end - ws + s = 0 if ws == 0 else max(wlen - stride, 0) + scored_start = ws + s + ci = min(scored_start // ttt_chunk, num_chunks - 1) + chunk_windows[ci].append(ws) + + log0(f"ttt_sliding:start chunks={num_chunks} chunk_tokens={ttt_chunk} " + f"total_windows={len(window_starts)} stride={stride} " + f"ttt_lr={args.ttt_lr} ttt_epochs={args.ttt_epochs} " + f"freeze_blocks={args.ttt_freeze_blocks}" + f"{f' ctw_weight={w_ctw} ctw_depth={args.ctw_depth} entropy_thresh={args.ctw_entropy_threshold}' if use_ctw else ''}" + f"{f' slot_lr={args.slot_lr} slot_steps={args.slot_steps}' if use_slot else ''}") + + loss_sum = torch.zeros((), device=device, dtype=torch.float64) + token_count = torch.zeros((), device=device, dtype=torch.float64) + byte_count = torch.zeros((), device=device, 
dtype=torch.float64) + + # Freeze first N blocks + frozen_block_ids = set(range(min(args.ttt_freeze_blocks, len(base_model.blocks)))) + ttt_params = [] + for name, p in base_model.named_parameters(): + freeze = False + for bi in frozen_block_ids: + if f"blocks.{bi}." in name: + freeze = True + break + if freeze: + p.requires_grad_(False) + else: + p.requires_grad_(True) + ttt_params.append(p) + + log0(f"ttt_sliding:params unfrozen={sum(p.numel() for p in ttt_params)} " + f"frozen={sum(p.numel() for p in base_model.parameters() if not p.requires_grad)}") + + optimizer = torch.optim.SGD(ttt_params, lr=args.ttt_lr, momentum=args.ttt_momentum) + t0 = time.perf_counter() + + for ci in range(num_chunks): + windows = chunk_windows[ci] + if not windows: + continue + chunk_start = ci * ttt_chunk + chunk_end = min((ci + 1) * ttt_chunk, total_tokens) + + # --- Phase 1: SCORE this chunk's windows (inference_mode) --- + my_s = (len(windows) * rank) // world_size + my_e = (len(windows) * (rank + 1)) // world_size + my_windows = windows[my_s:my_e] + + base_model.eval() + if use_slot: + # SLOT path: need gradients for delta optimization, can't use inference_mode + for bi in range(0, len(my_windows), batch_seqs): + batch_ws = my_windows[bi:bi + batch_seqs] + bsz = len(batch_ws) + x_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) + y_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) + wlens: list[int] = [] + for i, ws in enumerate(batch_ws): + end = min(ws + seq_len, total_tokens) + wlen = end - ws + wlens.append(wlen) + chunk_tok = val_tokens[ws:end + 1].to(dtype=torch.int64, device=device) + x_batch[i, :wlen] = chunk_tok[:-1] + y_batch[i, :wlen] = chunk_tok[1:] + # Step 1: Get hidden states (no grad needed for forward) + with torch.no_grad(): + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + H = base_model.forward_hidden(x_batch) + H = H.detach().float() # detach from model graph, float for SLOT precision + # Step 2: Optimize 
delta (needs gradients through compute_logits only) + delta = torch.zeros(1, 1, H.shape[-1], device=device, dtype=H.dtype, requires_grad=True) + slot_opt = torch.optim.AdamW([delta], lr=args.slot_lr, weight_decay=1e-8, eps=1e-5) + for _step in range(args.slot_steps): + slot_opt.zero_grad() + adapted_logits = base_model.compute_logits((H + delta).to(torch.bfloat16)).float() + shift_logits = adapted_logits[:, :-1, :].contiguous() + shift_targets = y_batch[:, :seq_len-1].contiguous() + slot_loss = F.cross_entropy(shift_logits.reshape(-1, shift_logits.size(-1)), + shift_targets.reshape(-1), reduction="mean") + slot_loss.backward() + slot_opt.step() + # Step 3: Score with adapted hidden states + with torch.no_grad(): + logits = base_model.compute_logits((H + delta.detach()).to(torch.bfloat16)).float() + nll = F.cross_entropy( + logits.reshape(-1, logits.size(-1)), + y_batch.reshape(-1), reduction="none", + ).reshape(bsz, seq_len) + for i, ws in enumerate(batch_ws): + wlen = wlens[i] + s = 0 if ws == 0 else max(wlen - stride, 0) + scored_nll = nll[i, s:wlen].to(torch.float64) + loss_sum += scored_nll.sum() + token_count += float(wlen - s) + tgt, prev = y_batch[i, s:wlen], x_batch[i, s:wlen] + tb = base_bytes_lut[tgt].to(torch.float64) + tb += (has_leading_space_lut[tgt] & ~is_boundary_token_lut[prev]).to(torch.float64) + byte_count += tb.sum() + else: + # Standard path: inference_mode, optional CTW + with torch.inference_mode(): + for bi in range(0, len(my_windows), batch_seqs): + batch_ws = my_windows[bi:bi + batch_seqs] + bsz = len(batch_ws) + x_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) + y_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) + wlens: list[int] = [] + for i, ws in enumerate(batch_ws): + end = min(ws + seq_len, total_tokens) + wlen = end - ws + wlens.append(wlen) + chunk_tok = val_tokens[ws:end + 1].to(dtype=torch.int64, device=device) + x_batch[i, :wlen] = chunk_tok[:-1] + y_batch[i, :wlen] = chunk_tok[1:] + 
with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + logits = base_model.forward_logits(x_batch) + nll = F.cross_entropy( + logits.reshape(-1, logits.size(-1)).float(), + y_batch.reshape(-1), reduction="none", + ).reshape(bsz, seq_len) + for i, ws in enumerate(batch_ws): + wlen = wlens[i] + s = 0 if ws == 0 else max(wlen - stride, 0) + if use_ctw: + # VECTORIZED entropy gate: compute entropy for ALL scored tokens at once + scored_logits = logits[i, s:wlen, :].float() + scored_targets = y_batch[i, s:wlen] + neural_lp = F.log_softmax(scored_logits, dim=-1) + neural_probs = neural_lp.exp() + token_entropy = -(neural_probs * neural_lp).sum(dim=-1) + all_nll = F.cross_entropy(scored_logits, scored_targets, reduction="none") + for t_idx in range(scored_logits.size(0)): + abs_idx = s + t_idx + target_tok = y_batch[i, abs_idx].item() + if token_entropy[t_idx].item() >= args.ctw_entropy_threshold: + ctw_probs = ctw.predict(logits.device).clamp(1e-10) + ctw_lp = torch.log(ctw_probs) + max_ent = math.log(ctw.vocab_size) + ew = w_ctw * min(token_entropy[t_idx].item() / max_ent, 1.0) + mixed = (1.0 - ew) * neural_lp[t_idx] + ew * ctw_lp + mixed = mixed - mixed.logsumexp(dim=-1, keepdim=True) + token_nll = F.cross_entropy(mixed.unsqueeze(0), scored_targets[t_idx:t_idx+1], reduction="sum") + loss_sum += token_nll.to(torch.float64) + else: + loss_sum += all_nll[t_idx].to(torch.float64) + ctw.update(target_tok) + else: + scored_nll = nll[i, s:wlen].to(torch.float64) + loss_sum += scored_nll.sum() + token_count += float(wlen - s) + tgt, prev = y_batch[i, s:wlen], x_batch[i, s:wlen] + tb = base_bytes_lut[tgt].to(torch.float64) + tb += (has_leading_space_lut[tgt] & ~is_boundary_token_lut[prev]).to(torch.float64) + byte_count += tb.sum() + + # --- Phase 2: TRAIN on this chunk (already scored = legal) --- + is_last_chunk = (ci == num_chunks - 1) + if not is_last_chunk and args.ttt_epochs > 0: + base_model.train() + chunk_seqs = (chunk_end - chunk_start) // seq_len + if 
chunk_seqs > 0: + cos_lr = args.ttt_lr * 0.5 * (1.0 + math.cos(math.pi * ci / max(num_chunks - 1, 1))) + for pg in optimizer.param_groups: + pg['lr'] = cos_lr + my_seq_s = (chunk_seqs * rank) // world_size + my_seq_e = (chunk_seqs * (rank + 1)) // world_size + my_chunk_seqs = my_seq_e - my_seq_s + for _ep in range(args.ttt_epochs): + for bs in range(0, my_chunk_seqs, args.ttt_batch_seqs): + be = min(bs + args.ttt_batch_seqs, my_chunk_seqs) + actual_bs = my_seq_s + bs + start_tok = chunk_start + actual_bs * seq_len + end_tok = chunk_start + (my_seq_s + be) * seq_len + 1 + if end_tok > val_tokens.numel(): + continue + local = val_tokens[start_tok:end_tok].to(device=device, dtype=torch.int64) + x = local[:-1].reshape(-1, seq_len) + y = local[1:].reshape(-1, seq_len) + optimizer.zero_grad(set_to_none=True) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + loss = base_model(x, y) + loss.backward() + if world_size > 1: + for p in ttt_params: + if p.grad is not None: + dist.all_reduce(p.grad, op=dist.ReduceOp.AVG) + torch.nn.utils.clip_grad_norm_(ttt_params, args.ttt_grad_clip) + optimizer.step() + + if rank == 0 and (ci % 10 == 0 or ci == num_chunks - 1): + elapsed = time.perf_counter() - t0 + rl = loss_sum.item() / max(token_count.item(), 1) + rbpb = rl / math.log(2.0) * (token_count.item() / max(byte_count.item(), 1)) if token_count.item() > 0 else 0.0 + log0(f" ttt_chunk [{ci+1}/{num_chunks}] bpb={rbpb:.6f} time={elapsed:.1f}s") + + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(loss_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(token_count, op=dist.ReduceOp.SUM) + dist.all_reduce(byte_count, op=dist.ReduceOp.SUM) + + val_loss = (loss_sum / token_count).item() + val_bpb = val_loss / math.log(2.0) * (token_count.item() / byte_count.item()) + + for p in base_model.parameters(): + p.requires_grad_(True) + base_model.eval() + + log0(f"ttt_sliding:done val_loss={val_loss:.6f} val_bpb={val_bpb:.6f} " + f"elapsed={time.perf_counter() - 
t0:.1f}s") + return val_loss, val_bpb + + +def eval_val_sliding_ctw( + args: Hyperparameters, base_model: nn.Module, rank: int, world_size: int, + device: torch.device, val_tokens: Tensor, base_bytes_lut: Tensor, + has_leading_space_lut: Tensor, is_boundary_token_lut: Tensor, + stride: int, log0=print, +) -> tuple[float, float]: + """CTW-augmented sliding window eval. Mixes neural logits with CTW predictions. + Novel contribution: Bayesian-optimal sequential probability assignment at zero artifact cost.""" + seq_len = args.train_seq_len + total_tokens = val_tokens.numel() - 1 + window_starts = [ws for ws in range(0, total_tokens, stride) + if min(ws + seq_len, total_tokens) - ws >= 1] + total_windows = len(window_starts) + my_s = (total_windows * rank) // world_size + my_e = (total_windows * (rank + 1)) // world_size + my_windows = window_starts[my_s:my_e] + + loss_sum = torch.zeros((), device=device, dtype=torch.float64) + token_count = torch.zeros((), device=device, dtype=torch.float64) + byte_count = torch.zeros((), device=device, dtype=torch.float64) + ctw = SparseCTW(depth=args.ctw_depth, vocab_size=args.vocab_size, alpha=0.5) + w_ctw = args.ctw_weight + + log0(f"ctw_eval:start windows={len(my_windows)} ctw_weight={w_ctw} depth={args.ctw_depth}") + base_model.eval() + t0 = time.perf_counter() + with torch.inference_mode(): + for wi, ws in enumerate(my_windows): + if rank == 0 and wi % max(1, len(my_windows) // 10) == 0: + pct = 100.0 * wi / max(len(my_windows), 1) + log0(f" ctw_eval: {pct:.0f}% | {time.perf_counter() - t0:.1f}s") + end = min(ws + seq_len, total_tokens) + wlen = end - ws + chunk = val_tokens[ws:end + 1].to(dtype=torch.int64, device=device) + x = chunk[:-1].unsqueeze(0) + y = chunk[1:] + s = 0 if ws == 0 else max(wlen - stride, 0) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + logits = base_model.forward_logits(x) + logits_scored = logits[0, s:wlen, :].float() + targets_scored = y[s:wlen] + # VECTORIZED entropy gate + 
neural_lp = F.log_softmax(logits_scored, dim=-1) + neural_probs = neural_lp.exp() + token_entropy = -(neural_probs * neural_lp).sum(dim=-1) + all_nll = F.cross_entropy(logits_scored, targets_scored, reduction="none") + + for t_idx in range(logits_scored.size(0)): + target_tok = targets_scored[t_idx].item() + if token_entropy[t_idx].item() >= args.ctw_entropy_threshold: + ctw_probs = ctw.predict(logits.device).clamp(1e-10) + ctw_lp = torch.log(ctw_probs) + max_ent = math.log(ctw.vocab_size) + ew = w_ctw * min(token_entropy[t_idx].item() / max_ent, 1.0) + mixed = (1.0 - ew) * neural_lp[t_idx] + ew * ctw_lp + mixed = mixed - mixed.logsumexp(dim=-1, keepdim=True) + token_nll = F.cross_entropy(mixed.unsqueeze(0), targets_scored[t_idx:t_idx+1], reduction="sum") + loss_sum += token_nll.to(torch.float64) + else: + loss_sum += all_nll[t_idx].to(torch.float64) + ctw.update(target_tok) + token_count += float(logits_scored.size(0)) + tgt = y[s:wlen] + prev = chunk[s:wlen] + tb = base_bytes_lut[tgt].to(torch.float64) + tb += (has_leading_space_lut[tgt] & ~is_boundary_token_lut[prev]).to(torch.float64) + byte_count += tb.sum() + + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(loss_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(token_count, op=dist.ReduceOp.SUM) + dist.all_reduce(byte_count, op=dist.ReduceOp.SUM) + val_loss = (loss_sum / token_count).item() + val_bpb = val_loss / math.log(2.0) * (token_count.item() / byte_count.item()) + log0(f"ctw_eval:done val_loss={val_loss:.6f} val_bpb={val_bpb:.6f} elapsed={time.perf_counter() - t0:.1f}s") + return val_loss, val_bpb + + +# --- GPTQ-lite int6 quantization --- + +def _classify_param(name: str) -> str: + if "tok_emb" in name or "lm_head" in name: + return "embed" + if ".mlp." in name: + return "mlp" + if ".attn." in name or (".proj." in name and ".mlp." 
not in name): + return "attn" + return "other" +def quantize_int6_per_row(t: Tensor, clip_range: int = 31) -> tuple[Tensor, Tensor]: + t32 = t.float() + if t32.ndim == 2: + best_q, best_s, best_err = None, None, float('inf') + for pct in [0.9990, 0.9995, 0.9999, 0.99999, 1.0]: + if pct < 1.0: + row_clip = torch.quantile(t32.abs(), pct, dim=1) + else: + row_clip = t32.abs().amax(dim=1) + s = (row_clip / clip_range).clamp_min(1.0 / clip_range).to(torch.float16) + q = torch.clamp(torch.round(t32 / s.float()[:, None]), -clip_range, clip_range).to(torch.int8) + recon = q.float() * s.float()[:, None] + err = (t32 - recon).pow(2).mean().item() + if err < best_err: + best_q, best_s, best_err = q, s, err + return best_q, best_s + amax = t32.abs().max().item() + scale = torch.tensor(amax / clip_range if amax > 0 else 1.0, dtype=torch.float16) + q = torch.clamp(torch.round(t32 / scale.float()), -clip_range, clip_range).to(torch.int8) + return q, scale + +def _unbank_state_dict(sd: dict[str, Tensor], num_layers: int) -> dict[str, Tensor]: + """Convert 3D bank tensors into individual 2D tensors with standard names.""" + out: dict[str, Tensor] = {} + n = num_layers + for name, tensor in sd.items(): + if name == "qo_bank": + for i in range(n): + out[f"blocks.{i}.attn.c_q.weight"] = tensor[i] + out[f"blocks.{i}.attn.proj.weight"] = tensor[n + i] + elif name == "kv_bank": + for i in range(n): + out[f"blocks.{i}.attn.c_k.weight"] = tensor[i] + out[f"blocks.{i}.attn.c_v.weight"] = tensor[n + i] + elif name == "mlp_up_bank": + for i in range(n): + out[f"blocks.{i}.mlp.fc.weight"] = tensor[i] + elif name == "mlp_down_bank": + for i in range(n): + out[f"blocks.{i}.mlp.proj.weight"] = tensor[i] + else: + out[name] = tensor + return out + +def _rebank_state_dict(sd: dict[str, Tensor], num_layers: int, template_sd: dict[str, Tensor]) -> dict[str, Tensor]: + """Convert individual 2D tensors back into 3D bank tensors.""" + out: dict[str, Tensor] = {} + n = num_layers + # Reconstruct banks 
from individual weight keys + qo_slices = [None] * (2 * n) + kv_slices = [None] * (2 * n) + up_slices = [None] * n + down_slices = [None] * n + consumed = set() + for i in range(n): + qk = f"blocks.{i}.attn.c_q.weight" + if qk in sd: + qo_slices[i] = sd[qk] + consumed.add(qk) + ok = f"blocks.{i}.attn.proj.weight" + if ok in sd: + qo_slices[n + i] = sd[ok] + consumed.add(ok) + kk = f"blocks.{i}.attn.c_k.weight" + if kk in sd: + kv_slices[i] = sd[kk] + consumed.add(kk) + vk = f"blocks.{i}.attn.c_v.weight" + if vk in sd: + kv_slices[n + i] = sd[vk] + consumed.add(vk) + fk = f"blocks.{i}.mlp.fc.weight" + if fk in sd: + up_slices[i] = sd[fk] + consumed.add(fk) + dk = f"blocks.{i}.mlp.proj.weight" + if dk in sd: + down_slices[i] = sd[dk] + consumed.add(dk) + out["qo_bank"] = torch.stack(qo_slices).to(dtype=template_sd["qo_bank"].dtype) + out["kv_bank"] = torch.stack(kv_slices).to(dtype=template_sd["kv_bank"].dtype) + out["mlp_up_bank"] = torch.stack(up_slices).to(dtype=template_sd["mlp_up_bank"].dtype) + out["mlp_down_bank"] = torch.stack(down_slices).to(dtype=template_sd["mlp_down_bank"].dtype) + for name, tensor in sd.items(): + if name not in consumed: + out[name] = tensor + return out + +def mixed_quantize_int6(state_dict: dict[str, Tensor], int6_cats: set[str]): + num_layers_total = max( + (int(k.split(".")[1]) for k in state_dict if k.startswith("blocks.")), + default=0, + ) + 1 + late_k_layers = set(range(num_layers_total - 2, num_layers_total)) + result: dict[str, Tensor] = {} + meta: dict[str, object] = {} + for name, tensor in state_dict.items(): + t = tensor.detach().cpu().contiguous() + cat = _classify_param(name) + if not t.is_floating_point() or t.numel() <= 65536: + result[name] = t.to(torch.float16) if t.is_floating_point() else t + meta[name] = "passthrough" + continue + if any(p in name for p in CONTROL_TENSOR_NAME_PATTERNS): + result[name] = t.float() + meta[name] = "passthrough_ctrl" + continue + if cat in int6_cats and t.ndim >= 1: + q, s = 
quantize_int6_per_row(t) + result[name + ".q"] = q + result[name + ".scale"] = s + meta[name] = {"type": "int6"} + else: + q, s = quantize_float_tensor(t) + result[name + ".q"] = q + result[name + ".scale"] = s + meta[name] = {"type": "int8"} + return result, meta +def dequantize_mixed_int6(result: dict[str, Tensor], meta: dict[str, object], + template_sd: dict[str, Tensor]) -> dict[str, Tensor]: + out: dict[str, Tensor] = {} + for name, orig in template_sd.items(): + info = meta.get(name) + if info is None: + continue + orig_dtype = orig.dtype + if info in ("passthrough", "passthrough_ctrl", "passthrough_fp16"): + t = result[name] + if t.dtype == torch.float16 and orig_dtype in (torch.float32, torch.bfloat16): + t = t.to(orig_dtype) + out[name] = t + continue + q, s = result[name + ".q"], result[name + ".scale"] + if s.ndim > 0: + out[name] = (q.float() * s.float().view(q.shape[0], *([1] * (q.ndim - 1)))).to(orig_dtype) + else: + out[name] = (q.float() * float(s.item())).to(orig_dtype) + return out + +# --- Training --- + +def main() -> None: + code = Path(__file__).read_text(encoding="utf-8") + args = Hyperparameters() + # zeropower_via_newtonschulz5 runs eagerly with bmm -- do NOT compile + distributed = "RANK" in os.environ and "WORLD_SIZE" in os.environ + rank = int(os.environ.get("RANK", "0")) + world_size = int(os.environ.get("WORLD_SIZE", "1")) + local_rank = int(os.environ.get("LOCAL_RANK", "0")) + if world_size <= 0: + raise ValueError(f"WORLD_SIZE must be positive, got {world_size}") + if 8 % world_size != 0: + raise ValueError(f"WORLD_SIZE={world_size} must divide 8 so grad_accum_steps stays integral") + grad_accum_steps = 8 // world_size + grad_scale = 1.0 / grad_accum_steps + if not torch.cuda.is_available(): + raise RuntimeError("CUDA is required") + device = torch.device("cuda", local_rank) + torch.cuda.set_device(device) + if distributed: + dist.init_process_group(backend="nccl", device_id=device) + dist.barrier() + master_process = rank == 0 + 
torch.backends.cuda.matmul.allow_tf32 = True + torch.backends.cudnn.allow_tf32 = True + from torch.backends.cuda import enable_cudnn_sdp, enable_flash_sdp, enable_math_sdp, enable_mem_efficient_sdp + enable_cudnn_sdp(False) + enable_flash_sdp(True) + enable_mem_efficient_sdp(False) + enable_math_sdp(False) + logfile = None + if master_process: + os.makedirs("logs", exist_ok=True) + logfile = f"logs/{args.run_id}.txt" + print(logfile) + def log0(msg: str, console: bool = True) -> None: + if not master_process: + return + if console: + print(msg) + if logfile is not None: + with open(logfile, "a", encoding="utf-8") as f: + print(msg, file=f) + log0(code, console=False) + log0("=" * 100, console=False) + log0(f"Running Python {sys.version}", console=False) + log0(f"Running PyTorch {torch.__version__}", console=False) + log0( + subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, check=False).stdout, + console=False, + ) + log0("=" * 100, console=False) + random.seed(args.seed) + np.random.seed(args.seed) + torch.manual_seed(args.seed) + torch.cuda.manual_seed_all(args.seed) + if not args.tokenizer_path.endswith(".model"): + raise ValueError(f"Script only setup for SentencePiece .model file: {args.tokenizer_path}") + sp = spm.SentencePieceProcessor(model_file=args.tokenizer_path) + if int(sp.vocab_size()) != args.vocab_size: + raise ValueError( + f"VOCAB_SIZE={args.vocab_size} does not match tokenizer vocab_size={int(sp.vocab_size())}" + ) + dataset_dir = Path(args.data_path).resolve() + actual_train_files = len(list(dataset_dir.glob("fineweb_train_*.bin"))) + effective_eval_seq_len = args.eval_seq_len if args.eval_seq_len > 0 else args.train_seq_len + val_seq_len = max(args.train_seq_len, effective_eval_seq_len) + val_tokens = load_validation_tokens(args.val_files, val_seq_len) + base_bytes_lut, has_leading_space_lut, is_boundary_token_lut = build_sentencepiece_luts( + sp, args.vocab_size, device + ) + log0(f"val_bpb:enabled 
tokenizer_kind=sentencepiece tokenizer_path={args.tokenizer_path}") + log0(f"train_loader:dataset:{dataset_dir.name} train_shards:{actual_train_files}") + log0(f"val_loader:shards pattern={args.val_files} tokens:{val_tokens.numel() - 1}") + CastedLinear._qat_enabled = args.qat_enabled + base_model = GPT( + vocab_size=args.vocab_size, + num_layers=args.num_layers, + model_dim=args.model_dim, + num_heads=args.num_heads, + num_kv_heads=args.num_kv_heads, + mlp_mult=args.mlp_mult, + tie_embeddings=args.tie_embeddings, + tied_embed_init_std=args.tied_embed_init_std, + logit_softcap=args.logit_softcap, + rope_base=args.rope_base, + qk_gain_init=args.qk_gain_init, + mtp_num_heads=args.mtp_num_heads, + mtp_loss_weight=args.mtp_loss_weight, + bigram_vocab_size=args.bigram_vocab_size, + bigram_dim=args.bigram_dim, + xsa_last_n=args.xsa_last_n, + rope_dims=args.rope_dims, + ln_scale=args.ln_scale, + dtg=args.dtg_enabled, + ve_enabled=args.ve_enabled, + ve_dim=args.ve_dim, + ve_layers=args.ve_layers, + gated_attention=args.gated_attention, + value_residual=args.value_residual, + ).to(device).bfloat16() + # Banks stay FP32 (like CastedLinear weights), cast to BF16 in forward + base_model.qo_bank.data = base_model.qo_bank.data.float() + base_model.kv_bank.data = base_model.kv_bank.data.float() + base_model.mlp_up_bank.data = base_model.mlp_up_bank.data.float() + base_model.mlp_down_bank.data = base_model.mlp_down_bank.data.float() + for module in base_model.modules(): + if isinstance(module, CastedLinear): + module.float() + restore_low_dim_params_to_fp32(base_model) + # No DDP -- Parallel Muon handles bank grad communication via reduce-scatter, + # and non-bank grads are manually all-reduced before Adam steps. 
+ compiled_model = torch.compile(base_model, dynamic=False, fullgraph=True) + model = compiled_model + + # Optimizer split: + # - 4 parameter banks -> Muon (batched Newton-Schulz) + # - token embedding -> Adam + # - scalars/control tensors -> Adam + # - bigram proj, mtp heads, VE proj -> Adam (small matrix params not worth banking) + matrix_params = [ + base_model.qo_bank, base_model.kv_bank, + base_model.mlp_up_bank, base_model.mlp_down_bank, + ] + block_named_params = list(base_model.blocks.named_parameters()) + scalar_params = [ + p + for name, p in block_named_params + if p.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) + ] + if base_model.skip_weights.numel() > 0: + scalar_params.append(base_model.skip_weights) + scalar_params.append(base_model.smear.gate) + if base_model.bigram is not None: + scalar_params.append(base_model.bigram.scale) + token_lr = args.tied_embed_lr if args.tie_embeddings else args.embed_lr + tok_params = [{"params": [base_model.tok_emb.weight], "lr": token_lr, "base_lr": token_lr}] + if base_model.bigram is not None: + tok_params.append({"params": [base_model.bigram.embed.weight], "lr": token_lr, "base_lr": token_lr}) + if base_model.bigram.proj is not None: + scalar_params.append(base_model.bigram.proj.weight) + if base_model.ve_shared is not None: + tok_params.append({"params": [base_model.ve_shared.embed.weight], "lr": token_lr, "base_lr": token_lr}) + if base_model.ve_shared.proj is not None: + scalar_params.append(base_model.ve_shared.proj.weight) + scalar_params.append(base_model.ve_shared.scale) + for s in base_model.ve_layer_scales: + scalar_params.append(s) + optimizer_tok = torch.optim.AdamW( + tok_params, + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + weight_decay=args.adam_wd, + fused=True, + ) + optimizer_muon = Muon( + matrix_params, + lr=args.matrix_lr, + momentum=args.muon_momentum, + backend_steps=args.muon_backend_steps, + weight_decay=args.muon_wd, + ) + for group in 
optimizer_muon.param_groups: + group["base_lr"] = args.matrix_lr + optimizer_scalar = torch.optim.AdamW( + [{"params": scalar_params, "lr": args.scalar_lr, "base_lr": args.scalar_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + weight_decay=args.adam_wd, + fused=True, + ) + # Non-bank params that need manual all-reduce (replicated across GPUs) + replicated_params = list(optimizer_tok.param_groups[0]["params"]) + for pg in optimizer_tok.param_groups[1:]: + replicated_params.extend(pg["params"]) + replicated_params.extend(scalar_params) + + optimizer_head = None + if base_model.lm_head is not None: + optimizer_head = torch.optim.Adam( + [{"params": [base_model.lm_head.weight], "lr": args.head_lr, "base_lr": args.head_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=True, + ) + replicated_params.append(base_model.lm_head.weight) + optimizers: list[torch.optim.Optimizer] = [optimizer_tok, optimizer_muon, optimizer_scalar] + if optimizer_head is not None: + optimizers.append(optimizer_head) + n_params = sum(p.numel() for p in base_model.parameters()) + mtp_params = sum(p.numel() for p in base_model.mtp_heads.parameters()) + log0(f"model_params:{n_params}") + log0(f"mtp_num_heads:{args.mtp_num_heads} mtp_loss_weight:{args.mtp_loss_weight} mtp_params:{mtp_params}") + xsa_layers = [i for i, b in enumerate(base_model.blocks) if b.attn.use_xsa] + log0(f"XSA:last_{args.xsa_last_n} active_layers:{xsa_layers}") + log0(f"world_size:{world_size} grad_accum_steps:{grad_accum_steps}") + log0("sdp_backends:cudnn=False flash=True mem_efficient=False math=False") + log0(f"attention_mode:gqa num_heads:{args.num_heads} num_kv_heads:{args.num_kv_heads}") + log0( + f"tie_embeddings:{args.tie_embeddings} embed_lr:{token_lr} " + f"head_lr:{args.head_lr if base_model.lm_head is not None else 0.0} " + f"matrix_lr:{args.matrix_lr} scalar_lr:{args.scalar_lr}" + ) + log0( + f"train_batch_tokens:{args.train_batch_tokens} train_seq_len:{args.train_seq_len} " + 
f"iterations:{args.iterations} warmup_steps:{args.warmup_steps} " + f"max_wallclock_seconds:{args.max_wallclock_seconds:.3f}" + ) + log0(f"seed:{args.seed}") + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + def zero_grad_all() -> None: + for opt in optimizers: + opt.zero_grad(set_to_none=True) + max_wallclock_ms = 1000.0 * args.max_wallclock_seconds if args.max_wallclock_seconds > 0 else None + def lr_mul(step: int, elapsed_ms: float) -> float: + if args.warmdown_iters <= 0: + return 1.0 + if max_wallclock_ms is None: + warmdown_start = max(args.iterations - args.warmdown_iters, 0) + return max((args.iterations - step) / max(args.warmdown_iters, 1), 0.0) if warmdown_start <= step < args.iterations else 1.0 + step_ms = elapsed_ms / max(step, 1) + warmdown_ms = args.warmdown_iters * step_ms + remaining_ms = max(max_wallclock_ms - elapsed_ms, 0.0) + return remaining_ms / max(warmdown_ms, 1e-9) if remaining_ms <= warmdown_ms else 1.0 + if args.warmup_steps > 0: + initial_model_state = {name: tensor.detach().cpu().clone() for name, tensor in base_model.state_dict().items()} + initial_optimizer_states = [copy.deepcopy(opt.state_dict()) for opt in optimizers] + model.train() + for warmup_step in range(args.warmup_steps): + zero_grad_all() + for micro_step in range(grad_accum_steps): + x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + warmup_loss = model(x, y) + (warmup_loss * grad_scale).backward() + # All-reduce all grads for warmup (simple, not optimized) + if distributed: + for p in base_model.parameters(): + if p.grad is not None: + dist.all_reduce(p.grad, op=dist.ReduceOp.AVG) + for opt in optimizers: + opt.step() + zero_grad_all() + if args.warmup_steps <= 20 or (warmup_step + 1) % 10 == 0 or warmup_step + 1 == args.warmup_steps: + log0(f"warmup_step:{warmup_step + 1}/{args.warmup_steps}") + 
base_model.load_state_dict(initial_model_state, strict=True) + for opt, state in zip(optimizers, initial_optimizer_states, strict=True): + opt.load_state_dict(state) + zero_grad_all() + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + swa_state: dict[str, Tensor] | None = None + swa_count = 0 + from collections import deque + lawa_queue: deque[dict[str, Tensor]] = deque(maxlen=args.lawa_k) + ema_state = {name: t.detach().float().clone() for name, t in base_model.state_dict().items()} + ema_decay = 0.997 + training_time_ms = 0.0 + stop_after_step: int | None = None + torch.cuda.synchronize() + t0 = time.perf_counter() + step = 0 + while True: + last_step = step == args.iterations or (stop_after_step is not None and step >= stop_after_step) + should_validate = last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0) + if should_validate: + torch.cuda.synchronize() + training_time_ms += 1000.0 * (time.perf_counter() - t0) + val_loss, val_bpb = eval_val( + args, + model, + rank, + world_size, + device, + grad_accum_steps, + val_tokens, + base_bytes_lut, + has_leading_space_lut, + is_boundary_token_lut, + ) + log0( + f"step:{step}/{args.iterations} val_loss:{val_loss:.4f} val_bpb:{val_bpb:.4f} " + f"train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms / max(step, 1):.2f}ms" + ) + torch.cuda.synchronize() + t0 = time.perf_counter() + if last_step: + if stop_after_step is not None and step < args.iterations: + log0( + f"stopping_early: wallclock_cap train_time:{training_time_ms:.0f}ms " + f"step:{step}/{args.iterations}" + ) + break + elapsed_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + scale = lr_mul(step, elapsed_ms) + if args.late_qat_threshold > 0 and scale < args.late_qat_threshold and not CastedLinear._qat_enabled: + CastedLinear._qat_enabled = True + log0(f"late_qat:enabled step:{step} scale:{scale:.4f}") + zero_grad_all() + train_loss = torch.zeros((), device=device) + for micro_step 
in range(grad_accum_steps): + x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + loss = model(x, y) + train_loss += loss.detach() + (loss * grad_scale).backward() + train_loss /= grad_accum_steps + frac = min(step / args.muon_momentum_warmup_steps, 1.0) if args.muon_momentum_warmup_steps > 0 else 1.0 + muon_momentum = (1 - frac) * args.muon_momentum_warmup_start + frac * args.muon_momentum + for group in optimizer_muon.param_groups: + group["momentum"] = muon_momentum + for opt in optimizers: + for group in opt.param_groups: + group["lr"] = group["base_lr"] * scale + if args.grad_clip_norm > 0: + torch.nn.utils.clip_grad_norm_(base_model.parameters(), args.grad_clip_norm) + # === 3-phase overlapped optimizer step === + # Phase 1: Launch async reduce-scatter for banks (biggest first) + optimizer_muon.launch_reduce_scatters() + # Phase 2: All-reduce non-bank grads + step Adam (while bank RS is in-flight) + if distributed: + for p in replicated_params: + if p.grad is not None: + dist.all_reduce(p.grad, op=dist.ReduceOp.AVG) + optimizer_tok.step() + optimizer_scalar.step() + if optimizer_head is not None: + optimizer_head.step() + # Phase 3: Wait for RS, local NS5, all-gather (banks processed last) + optimizer_muon.step() + zero_grad_all() + # EMA update + with torch.no_grad(): + for name, t in base_model.state_dict().items(): + ema_state[name].mul_(ema_decay).add_(t.detach().float(), alpha=1.0 - ema_decay) + step += 1 + approx_training_time_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + if args.swa_enabled and scale < 0.2 and step % args.swa_every == 0: + if swa_state is None: + swa_state = {name: t.detach().cpu().clone() for name, t in base_model.state_dict().items()} + swa_count = 1 + log0(f"swa:start step:{step}") + else: + for name, t in base_model.state_dict().items(): + swa_state[name] += t.detach().cpu() + swa_count += 1 + 
if args.lawa_enabled and step % args.lawa_freq == 0: + lawa_queue.append({name: t.detach().cpu().clone() for name, t in base_model.state_dict().items()}) + should_log_train = ( + args.train_log_every > 0 + and (step <= 10 or step % args.train_log_every == 0 or stop_after_step is not None) + ) + if should_log_train: + log0( + f"step:{step}/{args.iterations} train_loss:{train_loss.item():.4f} " + f"train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms / step:.2f}ms" + ) + reached_cap = max_wallclock_ms is not None and approx_training_time_ms >= max_wallclock_ms + if distributed and max_wallclock_ms is not None: + reached_cap_tensor = torch.tensor(int(reached_cap), device=device) + dist.all_reduce(reached_cap_tensor, op=dist.ReduceOp.MAX) + reached_cap = bool(reached_cap_tensor.item()) + if stop_after_step is None and reached_cap: + stop_after_step = step + log0( + f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB " + f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB" + ) + # Apply weight averaging + if args.lawa_enabled and len(lawa_queue) > 1: + log0(f"lawa:applying LAWA averaging k={len(lawa_queue)}") + current_state = base_model.state_dict() + avg_state = {name: torch.zeros(t.shape, dtype=torch.float32, device='cpu') for name, t in current_state.items()} + for snap in lawa_queue: + for name in avg_state: + avg_state[name] += snap[name].float() + for name in avg_state: + avg_state[name] /= len(lawa_queue) + avg_state[name] = avg_state[name].to(dtype=current_state[name].dtype) + base_model.load_state_dict(avg_state, strict=True) + else: + log0("ema:applying EMA weights") + current_state = base_model.state_dict() + avg_state = {name: t.to(dtype=current_state[name].dtype) for name, t in ema_state.items()} + base_model.load_state_dict(avg_state, strict=True) + torch.cuda.synchronize() + t_diag = time.perf_counter() + diag_val_loss, diag_val_bpb = eval_val( + args, compiled_model, rank, world_size, 
device, grad_accum_steps, + val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, + ) + torch.cuda.synchronize() + log0( + f"DIAGNOSTIC post_ema val_loss:{diag_val_loss:.4f} val_bpb:{diag_val_bpb:.4f} " + f"eval_time:{1000.0 * (time.perf_counter() - t_diag):.0f}ms" + ) + full_state_dict = base_model.state_dict() + export_sd = {k: v for k, v in full_state_dict.items() if "mtp_heads" not in k} + excluded_mtp = sum(int(t.numel()) for k, t in full_state_dict.items() if "mtp_heads" in k) + if excluded_mtp > 0: + log0(f"export_excluding_mtp_params:{excluded_mtp}") + if master_process: + torch.save(export_sd, "final_model.pt") + model_bytes = os.path.getsize("final_model.pt") + code_bytes = len(code.encode("utf-8")) + log0(f"Serialized model: {model_bytes} bytes") + log0(f"Code size: {code_bytes} bytes") + # Unbank 3D tensors into individual 2D tensors for quantization + sd_cpu = {k: v.detach().cpu() for k, v in export_sd.items()} + unbanked_sd = _unbank_state_dict(sd_cpu, args.num_layers) + quant_result, quant_meta = mixed_quantize_int6(unbanked_sd, {"mlp", "attn"}) + quant_buf = io.BytesIO() + torch.save({"w": quant_result, "m": quant_meta}, quant_buf) + quant_raw = quant_buf.getvalue() + quant_blob = lzma.compress(quant_raw, preset=6) + if master_process: + with open("final_model.int6.ptz", "wb") as f: + f.write(quant_blob) + quant_file_bytes = len(quant_blob) + code_bytes = len(code.encode("utf-8")) + log0(f"Serialized model int6+lzma: {quant_file_bytes} bytes") + log0(f"Total submission size int6+lzma: {quant_file_bytes + code_bytes} bytes") + if distributed: + dist.barrier() + with open("final_model.int6.ptz", "rb") as f: + quant_blob_disk = f.read() + quant_state = torch.load( + io.BytesIO(lzma.decompress(quant_blob_disk)), + map_location="cpu", + ) + deq_unbanked = dequantize_mixed_int6(quant_state["w"], quant_state["m"], unbanked_sd) + # Re-bank the dequantized tensors + deq_state = _rebank_state_dict(deq_unbanked, args.num_layers, sd_cpu) + 
eval_model = GPT( + vocab_size=args.vocab_size, num_layers=args.num_layers, model_dim=args.model_dim, + num_heads=args.num_heads, num_kv_heads=args.num_kv_heads, mlp_mult=args.mlp_mult, + tie_embeddings=args.tie_embeddings, tied_embed_init_std=args.tied_embed_init_std, + logit_softcap=args.logit_softcap, rope_base=args.rope_base, qk_gain_init=args.qk_gain_init, + mtp_num_heads=0, mtp_loss_weight=0.0, + bigram_vocab_size=args.bigram_vocab_size, bigram_dim=args.bigram_dim, + xsa_last_n=args.xsa_last_n, + rope_dims=args.rope_dims, ln_scale=args.ln_scale, dtg=args.dtg_enabled, + ve_enabled=args.ve_enabled, ve_dim=args.ve_dim, ve_layers=args.ve_layers, + gated_attention=args.gated_attention, value_residual=args.value_residual, + ).to(device).bfloat16() + eval_model.qo_bank.data = eval_model.qo_bank.data.float() + eval_model.kv_bank.data = eval_model.kv_bank.data.float() + eval_model.mlp_up_bank.data = eval_model.mlp_up_bank.data.float() + eval_model.mlp_down_bank.data = eval_model.mlp_down_bank.data.float() + for m in eval_model.modules(): + if isinstance(m, CastedLinear): + m.float() + restore_low_dim_params_to_fp32(eval_model) + eval_model.load_state_dict(deq_state, strict=True) + compiled_eval = torch.compile(eval_model, dynamic=False, fullgraph=True) + torch.cuda.synchronize() + t_qeval = time.perf_counter() + q_val_loss, q_val_bpb = eval_val( + args, compiled_eval, rank, world_size, device, grad_accum_steps, + val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, + eval_seq_len=effective_eval_seq_len, + ) + torch.cuda.synchronize() + log0( + f"final_int6_roundtrip val_loss:{q_val_loss:.4f} val_bpb:{q_val_bpb:.4f} " + f"eval_time:{1000.0 * (time.perf_counter() - t_qeval):.0f}ms" + ) + log0(f"final_int6_roundtrip_exact val_loss:{q_val_loss:.8f} val_bpb:{q_val_bpb:.8f}") + sw_seq_len = effective_eval_seq_len + if args.eval_stride > 0 and args.eval_stride < sw_seq_len: + torch.cuda.synchronize() + t_slide = time.perf_counter() + sw_val_loss, 
sw_val_bpb = eval_val_sliding( + args, eval_model, rank, world_size, device, + val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, + stride=args.eval_stride, + eval_seq_len=sw_seq_len, + ) + torch.cuda.synchronize() + log0( + f"final_int6_sliding_window val_loss:{sw_val_loss:.4f} val_bpb:{sw_val_bpb:.4f} " + f"stride:{args.eval_stride} eval_time:{1000.0 * (time.perf_counter() - t_slide):.0f}ms" + ) + log0(f"final_int6_sliding_window_exact val_loss:{sw_val_loss:.8f} val_bpb:{sw_val_bpb:.8f}") + log0(f"final_int8_zlib_roundtrip_exact val_loss:{sw_val_loss:.8f} val_bpb:{sw_val_bpb:.8f}") + if args.eval_stride != 64 and 64 < sw_seq_len: + torch.cuda.synchronize() + t_slide64 = time.perf_counter() + sw64_val_loss, sw64_val_bpb = eval_val_sliding( + args, eval_model, rank, world_size, device, + val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, + stride=64, + eval_seq_len=sw_seq_len, + ) + torch.cuda.synchronize() + log0( + f"final_int6_sliding_window_s64 val_loss:{sw64_val_loss:.4f} val_bpb:{sw64_val_bpb:.4f} " + f"stride:64 eval_time:{1000.0 * (time.perf_counter() - t_slide64):.0f}ms" + ) + log0(f"final_int6_sliding_window_s64_exact val_loss:{sw64_val_loss:.8f} val_bpb:{sw64_val_bpb:.8f}") + log0(f"final_int8_zlib_roundtrip_exact val_loss:{sw64_val_loss:.8f} val_bpb:{sw64_val_bpb:.8f}") + # Legal score-first TTT (PR #461 recipe) + if args.ttt_enabled: + torch.cuda.synchronize() + t_ttt = time.perf_counter() + ttt_loss, ttt_bpb = eval_val_sliding_ttt( + args, eval_model, rank, world_size, device, + val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, + stride=args.eval_stride, log0=log0, + ) + torch.cuda.synchronize() + log0(f"legal_ttt val_loss:{ttt_loss:.4f} val_bpb:{ttt_bpb:.4f} " + f"eval_time:{1000.0 * (time.perf_counter() - t_ttt):.0f}ms") + log0(f"legal_ttt_exact val_loss:{ttt_loss:.8f} val_bpb:{ttt_bpb:.8f}") + # Novel: CTW eval-time augmentation (standalone, when TTT is disabled) + # When TTT is 
enabled, CTW is already integrated into the TTT scoring loop above + if args.ctw_weight > 0 and not args.ttt_enabled: + torch.cuda.synchronize() + t_ctw = time.perf_counter() + ctw_loss, ctw_bpb = eval_val_sliding_ctw( + args, eval_model, rank, world_size, device, + val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, + stride=args.eval_stride, log0=log0, + ) + torch.cuda.synchronize() + log0(f"ctw_augmented val_loss:{ctw_loss:.4f} val_bpb:{ctw_bpb:.4f} " + f"eval_time:{1000.0 * (time.perf_counter() - t_ctw):.0f}ms") + log0(f"ctw_augmented_exact val_loss:{ctw_loss:.8f} val_bpb:{ctw_bpb:.8f}") + if distributed: + dist.destroy_process_group() +if __name__ == "__main__": + main()
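The SLOT branch of the TTT scoring loop above optimizes one shared delta vector against cross-entropy through the frozen `compute_logits` head. A minimal numpy sketch of that update, with toy dimensions and hypothetical stand-in weights (the real run uses a 512-dim hidden state, a 27M-param model, and AdamW; plain gradient descent stands in for the optimizer here):

```python
import numpy as np

# Toy-scale sketch of the SLOT delta update (all values hypothetical):
# one delta vector, shared across the batch, is optimized against
# cross-entropy through a frozen linear head.
rng = np.random.default_rng(0)
B, D, V = 8, 16, 32                   # batch, hidden dim, vocab (real run: 512-dim hidden)
W = rng.normal(0.0, 0.3, (V, D))      # frozen lm_head weight (stand-in)
H = rng.normal(0.0, 1.0, (B, D))      # hidden states, as from forward_hidden()
y = rng.integers(0, V, size=B)        # next-token targets

def ce_and_grad(delta):
    logits = (H + delta) @ W.T                        # delta broadcasts over the batch
    logits = logits - logits.max(axis=1, keepdims=True)  # stable softmax
    p = np.exp(logits)
    p /= p.sum(axis=1, keepdims=True)
    loss = -np.log(p[np.arange(B), y] + 1e-12).mean()
    g = p.copy()
    g[np.arange(B), y] -= 1.0                         # dCE/dlogits = softmax - onehot
    return loss, (g @ W).mean(axis=0)                 # chain rule into the shared delta

delta = np.zeros(D)
loss0, _ = ce_and_grad(delta)
for _ in range(3):                                    # 3 SLOT steps, as in the record
    _, grad = ce_and_grad(delta)
    delta -= 0.1 * grad                               # GD stands in for AdamW
loss1, _ = ce_and_grad(delta)
print(f"CE before={loss0:.4f} after={loss1:.4f}")
```

Because gradients flow only through the head, each SLOT step is one matmul plus a softmax backward over 512 parameters, which is why the measured overhead stays at roughly +34s for the whole eval.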