# Legal Score-First TTT + Parallel Muon + Parameter Banking

**val_bpb: 1.1218** (legal TTT, 2-seed mean; 3rd seed in progress) | **~15.8 MB** | 8×H100 SXM, 600s training + 400s TTT eval

## Results (8×H100 80GB SXM, PyTorch 2.9.1+cu128)

| Seed | step_avg | steps | Pre-TTT bpb | **Post-TTT bpb** | TTT gain | Artifact size |
|------|----------|-------|-------------|------------------|----------|---------------|
| 1337 | 82.3ms | 7,278 | 1.1234 | **1.1213** | -0.0021 | 15,841,722 B |
| 42 | 82.4ms | 7,265 | 1.1242 | **1.1222** | -0.0020 | pending |
| 2025 | running | | | | | |
| **Mean** | **82.3ms** | **~7,272** | **1.1238** | **1.1218** | **-0.0021** | **~15.8 MB** |

## Legal TTT Protocol (from PR #461)

Every validation token is **scored BEFORE any weight update** that could use it:

```
for each 32K-token chunk of val data:
  Phase 1 — SCORE: sliding-window eval under torch.inference_mode()
    Record per-token NLL. This is the official score.
  Phase 2 — TRAIN: SGD(lr=0.002, momentum=0.9) for 3 epochs
    Freeze first 2 blocks. Grad clip 1.0. Cosine LR decay.
  Model adapts, improving predictions for FUTURE chunks only.
```

Scoring under `inference_mode()` guarantees that no gradients are computed and no weights change while a chunk is being scored, and the score-then-train ordering makes the protocol strictly causal: no token's recorded NLL is ever influenced by an update that has seen that token.
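The two-phase loop above can be sketched in PyTorch on a toy model. All names and sizes here are illustrative stand-ins, not the repo's actual code; freezing the embedding mimics the "freeze first 2 blocks" rule at miniature scale:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
vocab, dim, chunk_len = 50, 16, 64

# Toy stand-in for the real model; the frozen embedding mirrors
# the freezing of the first 2 transformer blocks.
model = nn.Sequential(nn.Embedding(vocab, dim), nn.Linear(dim, vocab))
model[0].weight.requires_grad_(False)

tokens = torch.randint(0, vocab, (4 * chunk_len + 1,))  # fake val stream
opt = torch.optim.SGD(model.parameters(), lr=0.002, momentum=0.9)

nlls = []
for start in range(0, len(tokens) - 1, chunk_len):
    x = tokens[start:start + chunk_len]
    y = tokens[start + 1:start + chunk_len + 1]
    # Phase 1 (SCORE): record per-token NLL before any update sees the chunk.
    with torch.inference_mode():
        nlls.append(F.cross_entropy(model(x), y, reduction="none"))
    # Phase 2 (TRAIN): adapt on the scored chunk; only future chunks benefit.
    for _ in range(3):  # TTT_EPOCHS
        opt.zero_grad()
        F.cross_entropy(model(x), y).backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        opt.step()

official_nll = torch.cat(nlls)  # the official score, fixed during Phase 1
```

The key property: `official_nll` for each chunk is materialized before `opt.step()` ever runs on that chunk's tokens.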

### TTT Hyperparameters

| Parameter | Value |
|-----------|-------|
| Chunk size | 32,768 tokens |
| Optimizer | SGD + momentum(0.9) |
| Learning rate | 0.002 (cosine decay across chunks) |
| Epochs per chunk | 3 |
| Frozen blocks | First 2 of 11 |
| Gradient clip | 1.0 |
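The per-chunk cosine schedule from the table can be sketched as below. This assumes the rate decays to zero by the final chunk; the actual floor is not stated here:

```python
import math

def ttt_chunk_lr(chunk_idx: int, n_chunks: int, base_lr: float = 0.002) -> float:
    # Cosine decay of the TTT learning rate across val chunks.
    # Assumes decay to zero at the last chunk (floor is an assumption).
    progress = chunk_idx / max(n_chunks - 1, 1)
    return 0.5 * base_lr * (1.0 + math.cos(math.pi * progress))

lrs = [ttt_chunk_lr(i, n_chunks=10) for i in range(10)]
```

Each chunk's 3 epochs would then run at `lrs[chunk_idx]`, starting from the base 0.002.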

## Training Architecture

Built on PR #414's stack with Parameter Banking + the Parallel Muon optimizer:

- 11 layers, 512d, 8 heads / 4 KV heads, MLP 3× (relu²)
- XSA on last 4 layers, partial RoPE (16/64 dims), LN Scale
- SmearGate, BigramHash(2048), VE128 on layers 9-10
- EMA(0.997) + tight SWA (every 50 steps)
- GPTQ-lite int6 quantization + lzma compression
- **Parameter Banking**: 4 contiguous 3D banks replace 66 `nn.Linear` weights
- **Parallel Muon**: no DDP for banks; post-backward reduce-scatter → local Newton-Schulz → all-gather
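The local step at the heart of Parallel Muon is a Newton-Schulz orthogonalization of each rank's gradient shard. A minimal sketch follows; the quintic coefficients are the ones popularized by Muon implementations (treat them as an assumption here), and the distributed wrapping is shown only as comments:

```python
import torch

def newton_schulz(G: torch.Tensor, steps: int = 5, eps: float = 1e-7) -> torch.Tensor:
    """Quintic Newton-Schulz iteration pushing G's singular values toward 1,
    i.e. an approximate orthogonalization of the gradient matrix."""
    a, b, c = 3.4445, -4.7750, 2.0315  # common Muon coefficients (assumption)
    X = G / (G.norm() + eps)           # normalize so the iteration converges
    transposed = G.size(0) > G.size(1)
    if transposed:
        X = X.T
    for _ in range(steps):
        A = X @ X.T
        X = a * X + (b * A + c * A @ A) @ X
    return X.T if transposed else X

# Parallel Muon over a parameter bank (no DDP on the banks):
#   1. post-backward: reduce-scatter the bank gradient across the 8 ranks
#   2. each rank applies momentum + newton_schulz to its local shard
#   3. all-gather the orthogonalized shards back into the full bank
torch.manual_seed(0)
G = torch.randn(32, 64)   # one rank's local shard (illustrative shape)
update = newton_schulz(G)
```

Because each rank orthogonalizes only its own shard, the expensive matrix iterations are parallelized across GPUs instead of duplicated, which is the point of dropping DDP for the banks.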

## Run Command

```bash
NUM_LAYERS=11 BIGRAM_VOCAB_SIZE=2048 XSA_LAST_N=4 \
EMA_ENABLED=1 EMA_DECAY=0.997 SWA_ENABLED=1 SWA_EVERY=50 \
ROPE_DIMS=16 LN_SCALE=1 LATE_QAT=1 LATE_QAT_THRESHOLD=0.15 \
VE_ENABLED=1 VE_DIM=128 VE_LAYERS=9,10 \
TTT_ENABLED=1 TTT_LR=0.002 TTT_EPOCHS=3 TTT_CHUNK_TOKENS=32768 \
TTT_FREEZE_BLOCKS=2 TTT_MOMENTUM=0.9 TTT_BATCH_SEQS=32 TTT_GRAD_CLIP=1.0 \
MUON_WD=0.04 ADAM_WD=0.04 \
MATRIX_LR=0.025 SCALAR_LR=0.025 TIED_EMBED_LR=0.035 \
MUON_MOMENTUM=0.99 MUON_MOMENTUM_WARMUP_START=0.92 \
MUON_MOMENTUM_WARMUP_STEPS=1500 WARMDOWN_ITERS=3500 \
ITERATIONS=9000 MAX_WALLCLOCK_SECONDS=600 EVAL_STRIDE=64 \
SEED=1337 \
torchrun --standalone --nproc_per_node=8 train_gpt.py
```

## Credits

- **TTT recipe**: PR #461 by @anantdgoel — legal score-first TTT with SGD+momentum, selective freezing
- **Base model**: PR #414 by @signalrush — GPTQ-lite, VE128, Tight SWA, warmdown=3500
- **Optimizer**: Parameter Banking + Parallel Muon (arXiv:2511.07464)