diff --git a/records/track_10min_16mb/2026-03-22_CosineTTT_PerLayer/README.md b/records/track_10min_16mb/2026-03-22_CosineTTT_PerLayer/README.md new file mode 100644 index 000000000..9819a6827 --- /dev/null +++ b/records/track_10min_16mb/2026-03-22_CosineTTT_PerLayer/README.md @@ -0,0 +1,89 @@ +# Cosine TTT Scheduling with Per-Layer Learning Rates + +Mean val_bpb = 1.0970 (3 seeds, std=0.0010) | 8×H100 SXM | 600s train + 465s TTT + 187s eval + +## Results + +| Seed | Steps | Pre-TTT | Post-TTT | Artifact | +|------|-------|---------|----------|----------| +| 1337 | 7,101 | 1.1577 | 1.0959 | 15.4 MB | +| 42 | 6,700 | 1.1588 | 1.0971 | 15.5 MB | +| 7 | 6,987 | 1.1580 | 1.0979 | 15.8 MB | + +## Background + +Starting from the community stack (PRs #162, #180, #315, #398), we spent several days exploring ways to improve compression and eval-time adaptation. Many of these did not improve the result but informed the direction that eventually worked. + +### Compression research (did not improve score) + +We analyzed trained checkpoints to evaluate alternative quantization and compression approaches: + +- **Learned codebook quantization** (K-means, K=256): 87% lower reconstruction MSE than uniform int6, but 25% larger compressed artifact under zstd-22. Codebook indices have higher byte entropy than clamped int6 values. +- **Symmetry-transport** (Procrustes alignment across layers): Layers share 91-93% rotational structure, but storing the rotation matrices costs more than storing the weights directly. Low-rank approximation of the rotation delta (rank-128) captured only 16.6% of variance. +- **Embedding low-rank factorization** (SVD): Rank-64 explains 41.9% of variance on tok_emb (1024×512). Not viable at this vocabulary size. +- **Magnitude pruning**: Non-monotonic interaction with zstd-22. 3% pruning increased artifact size by 728KB on our checkpoint. + +These results indicated that int6+zstd is close to optimal for this model architecture and that compression was not the path to further improvement. + +### Architectural exploration (did not improve score) + +- **Progressive layer dropping**: Randomly skipping layers during training for regularization. Caused 0.06 BPB regression at step 1000 when combined with head dropout. The DDP implementation also introduced higher-order ops incompatible with torch.compile + DDPOptimizer. +- **Depth recurrence** (Huginn-style, 3 shared blocks × 3 loops): Blocks learned position-specific functions rather than general refiners. Eval at 2× trained depth produced val_bpb 4.34. Not viable below ~100M params per unique layer. +- **Neural cache** (cross-window KV caching at eval): Implemented but not validated on hardware. The original proposal (PR #318) was blocked by a torch.compile issue. + +### TTT analysis (led to the finding) + +Analyzing our trained checkpoint, we observed: + +1. **Quantization error is uniformly distributed** — the top 1% of weights by error magnitude account for only 3.9% of total reconstruction error. This confirmed that outlier protection approaches would not be effective. +2. **Quantization damage varies 3.4× across layer types** — MLP output projections (512×1536) have systematically higher relative error than input projections (1536×512). +3. **TTT improvement exceeds quantization repair** — the TTT contribution (~0.06 BPB on our model) is roughly 2.4× larger than the quantization gap (~0.008), indicating TTT performs distribution adaptation beyond repairing quantization damage. + +These observations motivated exploring the TTT schedule rather than the training architecture or compression scheme. + +## TTT schedule + +Two modifications to AdamW TTT (PR #442): + +**Cosine lr decay** over 30 epochs instead of flat lr over 10 epochs. Quantization introduces both large-scale damage (outlier weight rounding) and distributed noise (small perturbations across all weights). A flat lr must compromise between these two regimes. Cosine decay applies full lr early to address large damage, then progressively reduces to refine without overshooting. + +**Per-layer lr groups** based on the quantization damage measurements above. MLP output projections receive 3× base lr, input projections 0.5×, all other parameters 1×. This allocates more adaptation capacity to more damaged layers. The ratios are specific to our model — other architectures may show different damage profiles. + +We tested 34 TTT configurations across optimizers (AdamW, Adam, SGD), learning rates (1e-4 to 2e-3), epoch counts (3 to 30), schedules (flat, cosine, warmup+cosine), per-layer groupings, freeze strategies, and loss functions (cross-entropy, focal loss γ=1-3, KL divergence from pre-quant model). + +Focal loss did not improve over cross-entropy — hard tokens appear to be unpredictable rather than undertrained. KL divergence from the pre-quant model was less effective than cross-entropy — the pre-quant and post-quant models are similar enough that the KL signal is weak relative to the cross-entropy signal from the validation data. + +## TTT config + +``` +TTT_OPTIMIZER=adamw TTT_LR=0.0005 TTT_EPOCHS=30 +TTT_COSINE=1 TTT_PERLAYER=1 TTT_FREEZE_BLOCKS=0 +TTT_BATCH_SEQS=64 (per GPU, 512 total with DDP sharding) +``` + +Each GPU processes a contiguous 1/8 shard of the validation tokens with gradient all_reduce (ReduceOp.AVG). 30 epochs at ~15.5s/epoch = ~465s total. + +## Training config + +Standard community stack. 11L, 512d, 8H/4KV (GQA), 3x MLP (relu-squared), U-Net skips, SmearGate, BigramHash(2048), OrthoInit, Partial RoPE (16/64 dims), LN Scale, EMA(0.997), tied embeddings. XSA disabled. Int6 per-row + zstd-22. + +## Notes + +- All runs used FA2. FA3 Hopper would improve pre-TTT quality through faster training steps. The TTT schedule is independent of the attention kernel. +- The cosine + per-layer schedule adds no artifact cost and minimal code complexity over flat-lr TTT. +- See PR #212 for a non-record submission documenting 25+ additional experiments. + +## Reproduction + +```bash +git clone https://github.com/mrdavtan/parameter-golf.git +cd parameter-golf && git checkout next-gen +pip install flash-attn --no-cache-dir --no-build-isolation +pip install zstandard sentencepiece huggingface_hub +python3 data/cached_challenge_fineweb.py --variant sp1024 +bash run_competition.sh 1337 +``` + +Hardware: 8×H100 SXM (RunPod), PyTorch 2.9.1+cu128, Flash Attention 2 + +Builds on PRs #162, #180, #77, #398, #442, #417, #315, and modded-nanogpt. diff --git a/records/track_10min_16mb/2026-03-22_CosineTTT_PerLayer/submission.json b/records/track_10min_16mb/2026-03-22_CosineTTT_PerLayer/submission.json new file mode 100644 index 000000000..91bdea376 --- /dev/null +++ b/records/track_10min_16mb/2026-03-22_CosineTTT_PerLayer/submission.json @@ -0,0 +1,48 @@ +{ + "author": "mrdavtan", + "github_id": "mrdavtan", + "name": "Cosine TTT scheduling with per-layer lr (mean val_bpb=1.0970, 3 seeds)", + "blurb": "AdamW TTT with cosine lr decay and per-layer lr groups. 30 epochs, 3x lr for MLP output projections, 0.5x for input projections.", + "date": "2026-03-22", + "val_loss": 1.8504, + "val_bpb": 1.0959, + "mean_val_bpb": 1.0970, + "std_val_bpb": 0.0010, + "seed": 1337, + "num_seeds": 3, + "seed_results": { + "1337": 1.0959, + "42": 1.0971, + "7": 1.0979 + }, + "step_stop": 7101, + "wallclock_seconds": 600.0, + "ttt_time_seconds": 465.4, + "eval_time_seconds": 186.5, + "bytes_total": 15362557, + "bytes_model_int8_zstd": 15258143, + "bytes_code": 104414, + "hardware": "8xH100 SXM (RunPod), PyTorch 2.9.1+cu128, FA2", + "track": "track_10min_16mb", + "model": { + "num_layers": 11, + "model_dim": 512, + "num_heads": 8, + "num_kv_heads": 4, + "mlp_mult": 3, + "vocab_size": 1024, + "tie_embeddings": true, + "total_params": 26829913 + }, + "ttt_config": { + "optimizer": "adamw", + "lr": 0.0005, + "epochs": 30, + "cosine": true, + "perlayer": true, + "perlayer_proj_mult": 3.0, + "perlayer_fc_mult": 0.5, + "freeze_blocks": 0, + "batch_seqs_per_gpu": 64 + } +} diff --git a/records/track_10min_16mb/2026-03-22_CosineTTT_PerLayer/train_gpt.py b/records/track_10min_16mb/2026-03-22_CosineTTT_PerLayer/train_gpt.py new file mode 100644 index 000000000..77e2691c0 --- /dev/null +++ b/records/track_10min_16mb/2026-03-22_CosineTTT_PerLayer/train_gpt.py @@ -0,0 +1,2263 @@ +""" +The `train_gpt.py` and `train_gpt_mlx.py` scripts are intended as good launching-off points for new participants, not SOTA configs. We'll accept PRs that tune, improve, or simplify these scripts without significantly increasing complexity, but competitive submissions should stay in the `/records` folder. + +Hard stop: To keep readable for newcomers, let's make sure `train_gpt.py` and `train_gpt_mlx.py` never are longer than 1500 lines. +""" + +from __future__ import annotations + +import copy +import glob +import io +import math +import os +import random +import subprocess +import sys +import time +import uuid +import zlib +try: + import zstandard as zstd + _HAS_ZSTD = True +except ImportError: + _HAS_ZSTD = False +from pathlib import Path + +import numpy as np +import sentencepiece as spm +import torch +import torch.distributed as dist +import torch.nn.functional as F +from torch import Tensor, nn +from torch.nn.parallel import DistributedDataParallel as DDP + +try: + from flash_attn_interface import flash_attn_func as _flash_attn_3_func + _HAS_FA3 = True +except ImportError: + try: + from flash_attn import flash_attn_func as _flash_attn_3_func + _HAS_FA3 = True # FA2 fallback — same input format, slightly slower than FA3 + except ImportError: + _HAS_FA3 = False +_use_fa3: bool = False # set at runtime after args are parsed + +# ----------------------------- +# HYPERPARAMETERS +# ----------------------------- +# Default Simple Baseline run: +# - 9 transformer blocks at width 512 +# - 8 attention heads with 4 KV heads (GQA) and 2x MLP expansion +# - vocab size 1024, sequence length 1024, tied embeddings +# - 524,288 train tokens per step for 20,000 iterations with a ~10 minute cap + +class Hyperparameters: + # Data paths are shard globs produced by the existing preprocessing pipeline. + data_path = os.environ.get("DATA_PATH", "./data/datasets/fineweb10B_sp1024") + train_files = os.path.join(data_path, "fineweb_train_*.bin") + val_files = os.path.join(data_path, "fineweb_val_*.bin") + tokenizer_path = os.environ.get("TOKENIZER_PATH", "./data/tokenizers/fineweb_1024_bpe.model") + run_id = os.environ.get("RUN_ID", str(uuid.uuid4())) + seed = int(os.environ.get("SEED", 1337)) + + # Validation cadence and batch size. Validation always uses the full fineweb_val split. + val_batch_size = int(os.environ.get("VAL_BATCH_SIZE", 524_288)) + val_loss_every = int(os.environ.get("VAL_LOSS_EVERY", 1000)) + train_log_every = int(os.environ.get("TRAIN_LOG_EVERY", 200)) + + # Training length. + iterations = int(os.environ.get("ITERATIONS", 20000)) + warmdown_iters = int(os.environ.get("WARMDOWN_ITERS", 3000)) + warmup_steps = int(os.environ.get("WARMUP_STEPS", 20)) + train_batch_tokens = int(os.environ.get("TRAIN_BATCH_TOKENS", 524_288)) + train_seq_len = int(os.environ.get("TRAIN_SEQ_LEN", 1024)) + eval_seq_len = int(os.environ.get("EVAL_SEQ_LEN", 2048)) # eval window; NTK-RoPE scales if > train_seq_len + max_wallclock_seconds = float(os.environ.get("MAX_WALLCLOCK_SECONDS", 600.0)) + qk_gain_init = float(os.environ.get("QK_GAIN_INIT", 1.5)) + + # ── Tier-2 proxy mode ──────────────────────────────────────────────────── + # TIER2_MODE=1 overrides key settings for a fast 3-minute proxy run. + # Use this to validate architectural features before committing a full run. + # Schedule-dependent features (EMA, TTT, SWA) are disabled — they can only + # be evaluated meaningfully at full training duration. + # Compare val_bpb at step ~2000 against a baseline TIER2_MODE=1 run. + _tier2 = bool(int(os.environ.get("TIER2_MODE", "0"))) + if _tier2: + max_wallclock_seconds = float(os.environ.get("MAX_WALLCLOCK_SECONDS", 180.0)) + iterations = int(os.environ.get("ITERATIONS", 3000)) + warmdown_iters = int(os.environ.get("WARMDOWN_ITERS", 500)) + val_loss_every = int(os.environ.get("VAL_LOSS_EVERY", 500)) + eval_stride = int(os.environ.get("EVAL_STRIDE", 256)) # faster final eval + + # Model shape. + vocab_size = int(os.environ.get("VOCAB_SIZE", 1024)) + num_layers = int(os.environ.get("NUM_LAYERS", 11)) + num_kv_heads = int(os.environ.get("NUM_KV_HEADS", 4)) + model_dim = int(os.environ.get("MODEL_DIM", 512)) + num_heads = int(os.environ.get("NUM_HEADS", 8)) + mlp_mult = int(os.environ.get("MLP_MULT", 3)) + mlp_hidden = int(os.environ.get("MLP_HIDDEN", 0)) # 0 = use mlp_mult * model_dim + fp16_embed_export = bool(int(os.environ.get("FP16_EMBED_EXPORT", "1"))) # keep tok_emb in fp16 at export + num_loops = int(os.environ.get("NUM_LOOPS", 1)) + lora_rank = int(os.environ.get("LORA_RANK", 0)) + qat = bool(int(os.environ.get("QAT", "1"))) + qat_min_seconds = float(os.environ.get("QAT_MIN_SECONDS", 120.0)) # guarantee QAT runs for at least this long + eval_stride = int(os.environ.get("EVAL_STRIDE", 64)) + eval_batch_seqs = int(os.environ.get("EVAL_BATCH_SEQS", 32)) + doc_isolated_eval = bool(int(os.environ.get("DOC_ISOLATED_EVAL", "1"))) # eval per-document, no cross-doc context + smear_gate = bool(int(os.environ.get("SMEAR_GATE", "1"))) # cheap bigram context at embedding layer + bigram_hash = bool(int(os.environ.get("BIGRAM_HASH", "1"))) # hash-based bigram embedding + bigram_hash_buckets = int(os.environ.get("BIGRAM_HASH_BUCKETS", 2048)) + bigram_hash_dim = int(os.environ.get("BIGRAM_HASH_DIM", 128)) + swa = bool(int(os.environ.get("SWA", "0"))) # stochastic weight averaging (disabled: EMA preferred) + swa_start_frac = float(os.environ.get("SWA_START_FRAC", 0.5)) + xsa_last_n = int(os.environ.get("XSA_LAST_N", 4)) # XSA on last N layers (0=disabled) + ema_enabled = bool(int(os.environ.get("EMA_ENABLED", "1"))) # exponential moving average (replaces SWA) + ema_decay = float(os.environ.get("EMA_DECAY", 0.997)) # EMA decay per step + ttt_enabled = bool(int(os.environ.get("TTT_ENABLED", "1"))) # test-time training on val data + ttt_lr = float(os.environ.get("TTT_LR", 0.0005)) # AdamW lr (PR #442: AdamW beats SGD by 0.019 BPB) + ttt_epochs = int(os.environ.get("TTT_EPOCHS", 30)) # 30ep cosine beats 10ep flat by 16% + ttt_momentum = float(os.environ.get("TTT_MOMENTUM", 0.9)) # only used if TTT_OPTIMIZER=sgd + ttt_freeze_blocks = int(os.environ.get("TTT_FREEZE_BLOCKS", 0)) # 0 = all blocks unfrozen (PR #398) + ttt_optimizer = os.environ.get("TTT_OPTIMIZER", "adamw") # "adamw" or "sgd" + ttt_max_steps = int(os.environ.get("TTT_MAX_STEPS", 300)) # cap steps per epoch (~10s per epoch) + ttt_batch_seqs = int(os.environ.get("TTT_BATCH_SEQS", 64)) # seqs per GPU per step (64*8=512 total) + ttt_cosine = bool(int(os.environ.get("TTT_COSINE", "1"))) # cosine lr decay during TTT (+16% over flat) + ttt_warmup_frac = float(os.environ.get("TTT_WARMUP_FRAC", 0.0)) # linear warmup fraction (0.1 = 10%) + ttt_perlayer = bool(int(os.environ.get("TTT_PERLAYER", "1"))) # per-layer lr (3x proj, 0.5x fc) — +23.5% combined with cosine + # Two-phase TTT (matches PR #415/#417 approach) + ttt_two_phase = bool(int(os.environ.get("TTT_TWO_PHASE", "0"))) # enable two-phase TTT + ttt_p1_epochs = int(os.environ.get("TTT_P1_EPOCHS", 50)) # phase 1: norm-only recalibration + ttt_p1_lr = float(os.environ.get("TTT_P1_LR", 0.01)) # phase 1: Adam lr + ttt_p2_epochs = int(os.environ.get("TTT_P2_EPOCHS", 10)) # phase 2: selective block adaptation + ttt_p2_lr = float(os.environ.get("TTT_P2_LR", 0.005)) # phase 2: SGD lr + ttt_p2_unfreeze_blocks = int(os.environ.get("TTT_P2_UNFREEZE_BLOCKS", 3)) # phase 2: unfreeze last N blocks + use_zstd = bool(int(os.environ.get("USE_ZSTD", "1"))) # use zstd instead of zlib for compression + curriculum = bool(int(os.environ.get("CURRICULUM", "0"))) # sort training shards by doc length (easy first) + quant_bits = int(os.environ.get("QUANT_BITS", 6)) # 8=int8, 6=int6 (int6 fits ~3x more params in 16MB) + tie_embeddings = bool(int(os.environ.get("TIE_EMBEDDINGS", "1"))) + rope_base = float(os.environ.get("ROPE_BASE", 50000.0)) + logit_softcap = float(os.environ.get("LOGIT_SOFTCAP", 30.0)) + + # Optimizer hyperparameters. + embed_lr = float(os.environ.get("EMBED_LR", 0.6)) + head_lr = float(os.environ.get("HEAD_LR", 0.008)) + tied_embed_lr = float(os.environ.get("TIED_EMBED_LR", 0.035)) + tied_embed_init_std = float(os.environ.get("TIED_EMBED_INIT_STD", 0.005)) + matrix_lr = float(os.environ.get("MATRIX_LR", 0.025)) + scalar_lr = float(os.environ.get("SCALAR_LR", 0.025)) + muon_momentum = float(os.environ.get("MUON_MOMENTUM", 0.99)) + muon_backend_steps = int(os.environ.get("MUON_BACKEND_STEPS", 5)) + muon_momentum_warmup_start = float(os.environ.get("MUON_MOMENTUM_WARMUP_START", 0.92)) + muon_momentum_warmup_steps = int(os.environ.get("MUON_MOMENTUM_WARMUP_STEPS", 1500)) + beta1 = float(os.environ.get("BETA1", 0.9)) + beta2 = float(os.environ.get("BETA2", 0.95)) + adam_eps = float(os.environ.get("ADAM_EPS", 1e-8)) + lora_lr = float(os.environ.get("LORA_LR", 0.01)) + grad_clip_norm = float(os.environ.get("GRAD_CLIP_NORM", 0.3)) + muon_wd = float(os.environ.get("MUON_WD", 0.04)) + adam_wd = float(os.environ.get("ADAM_WD", 0.04)) + ortho_init = bool(int(os.environ.get("ORTHO_INIT", "1"))) + late_k_fp16 = bool(int(os.environ.get("LATE_K_FP16", "1"))) + use_fa3 = bool(int(os.environ.get("USE_FA3", "1"))) + rope_dims = int(os.environ.get("ROPE_DIMS", 0)) # 0 = full RoPE; >0 = apply RoPE to only first N dims per head + ln_scale = bool(int(os.environ.get("LN_SCALE", "0"))) # scale block output by 1/sqrt(layer_idx+1) + unet_skips = bool(int(os.environ.get("UNET_SKIPS", "0"))) # U-Net style skip connections + prune_pct = float(os.environ.get("PRUNE_PCT", 0.0)) # magnitude pruning: zero smallest N% of weights before compression + gptq_lite = bool(int(os.environ.get("GPTQ_LITE", "1"))) # per-row optimal clip search at quantization (default ON, zero training cost) + reptile_enabled = bool(int(os.environ.get("REPTILE_TTT", "0"))) # Reptile meta-TTT before standard TTT + reptile_budget_s = float(os.environ.get("REPTILE_BUDGET_S", 60.0)) # seconds for Reptile meta-learning + reptile_inner_steps = int(os.environ.get("REPTILE_INNER_STEPS", 3)) # SGD steps per inner loop + reptile_inner_lr = float(os.environ.get("REPTILE_INNER_LR", 0.1)) # inner SGD learning rate + reptile_outer_lr = float(os.environ.get("REPTILE_OUTER_LR", 0.01)) # outer interpolation rate + ve_enabled = bool(int(os.environ.get("VE_ENABLED", "0"))) # shared value embedding + ve_dim = int(os.environ.get("VE_DIM", 128)) # value embedding dimension + ve_layers = os.environ.get("VE_LAYERS", "9,10") # comma-separated layer indices + + # Disable schedule-dependent features in TIER2_MODE unless explicitly overridden + if _tier2: + qat = bool(int(os.environ.get("QAT", "0"))) + ema_enabled = bool(int(os.environ.get("EMA_ENABLED", "0"))) + ttt_enabled = bool(int(os.environ.get("TTT_ENABLED", "0"))) + swa = bool(int(os.environ.get("SWA", "0"))) + +# ----------------------------- +# MUON OPTIMIZER +# ----------------------------- +# +# As borrowed from modded-nanogpt +# Background on Muon: https://kellerjordan.github.io/posts/muon/ + +def zeropower_via_newtonschulz5(G: Tensor, steps: int = 10, eps: float = 1e-7) -> Tensor: + # Orthogonalize a 2D update matrix with a fast Newton-Schulz iteration. + # Muon uses this to normalize matrix-shaped gradients before applying them. + a, b, c = (3.4445, -4.7750, 2.0315) + X = G.bfloat16() + X /= X.norm() + eps + transposed = G.size(0) > G.size(1) + if transposed: + X = X.T + for _ in range(steps): + A = X @ X.T + B = b * A + c * A @ A + X = a * X + B @ X + return X.T if transposed else X + + +class Muon(torch.optim.Optimizer): + def __init__(self, params, lr: float, momentum: float, backend_steps: int, + nesterov: bool = True, weight_decay: float = 0.0): + super().__init__( + params, + dict(lr=lr, momentum=momentum, backend_steps=backend_steps, + nesterov=nesterov, weight_decay=weight_decay), + ) + + @torch.no_grad() + def step(self, closure=None): + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + distributed = dist.is_available() and dist.is_initialized() + world_size = dist.get_world_size() if distributed else 1 + rank = dist.get_rank() if distributed else 0 + + for group in self.param_groups: + params = group["params"] + if not params: + continue + lr = group["lr"] + momentum = group["momentum"] + backend_steps = group["backend_steps"] + nesterov = group["nesterov"] + wd = group.get("weight_decay", 0.0) + + # Cache flat buffer to avoid per-step allocation + if "_updates_flat" not in group: + group["_total_params"] = sum(int(p.numel()) for p in params) + group["_updates_flat"] = torch.zeros(group["_total_params"], device=params[0].device, dtype=torch.bfloat16) + total_params = group["_total_params"] + updates_flat = group["_updates_flat"] + updates_flat.zero_() + + curr = 0 + for i, p in enumerate(params): + if i % world_size == rank and p.grad is not None: + g = p.grad + state = self.state[p] + if "momentum_buffer" not in state: + state["momentum_buffer"] = torch.zeros_like(g) + buf = state["momentum_buffer"] + buf.mul_(momentum).add_(g) + if nesterov: + g = g.add(buf, alpha=momentum) + g = zeropower_via_newtonschulz5(g, steps=backend_steps) + # Scale correction from Muon reference implementations. + g *= max(1, g.size(0) / g.size(1)) ** 0.5 + updates_flat[curr : curr + p.numel()] = g.reshape(-1) + curr += p.numel() + + if distributed: + dist.all_reduce(updates_flat, op=dist.ReduceOp.SUM) + + curr = 0 + for p in params: + g = updates_flat[curr : curr + p.numel()].view_as(p).to(dtype=p.dtype) + p.add_(g, alpha=-lr) + if wd > 0.0: + p.mul_(1.0 - wd * lr) + curr += p.numel() + + return loss + + +# ----------------------------- +# TOKENIZER-AGNOSTIC EVALUATION SETUP +# ----------------------------- +# +# It's common for small models have a large fraction of their parameters be embeddings, since the 2 * d_model * d_vocab vectors can be gigantic. +# Instead of locking the tokenizer, we let you bring your own and calculate our validation metrics on the average compression of the validation set. +# We calculate BPB (bits-per-byte) instead of validation loss, so we need methods to count the number of bits per token in the tokenizer. +# Note: Submissions that edit the tokenizer will be examined more carefully, since screwing this up might unjustly improve your score. + +def build_sentencepiece_luts( + sp: spm.SentencePieceProcessor, vocab_size: int, device: torch.device +) -> tuple[Tensor, Tensor, Tensor]: + sp_vocab_size = int(sp.vocab_size()) + table_size = max(sp_vocab_size, vocab_size) + base_bytes_np = np.zeros((table_size,), dtype=np.int16) + has_leading_space_np = np.zeros((table_size,), dtype=np.bool_) + is_boundary_token_np = np.ones((table_size,), dtype=np.bool_) + for token_id in range(sp_vocab_size): + if sp.is_control(token_id) or sp.is_unknown(token_id) or sp.is_unused(token_id): + continue + is_boundary_token_np[token_id] = False + if sp.is_byte(token_id): + base_bytes_np[token_id] = 1 + continue + piece = sp.id_to_piece(token_id) + if piece.startswith("▁"): + has_leading_space_np[token_id] = True + piece = piece[1:] + base_bytes_np[token_id] = len(piece.encode("utf-8")) + return ( + torch.tensor(base_bytes_np, dtype=torch.int16, device=device), + torch.tensor(has_leading_space_np, dtype=torch.bool, device=device), + torch.tensor(is_boundary_token_np, dtype=torch.bool, device=device), + ) + + +def load_validation_tokens(pattern: str, seq_len: int) -> Tensor: + files = [Path(p) for p in sorted(glob.glob(pattern))] + if not files: + raise FileNotFoundError(f"No files found for pattern: {pattern}") + # The export pipeline writes the fixed first-50k-doc validation set to fineweb_val_*. + tokens = torch.cat([load_data_shard(file) for file in files]).contiguous() + usable = ((tokens.numel() - 1) // seq_len) * seq_len + if usable <= 0: + raise ValueError(f"Validation split is too short for TRAIN_SEQ_LEN={seq_len}") + return tokens[: usable + 1] + + +def eval_val( + args: Hyperparameters, + model: nn.Module, + rank: int, + world_size: int, + device: torch.device, + grad_accum_steps: int, + val_tokens: Tensor, + base_bytes_lut: Tensor, + has_leading_space_lut: Tensor, + is_boundary_token_lut: Tensor, +) -> tuple[float, float]: + # Validation computes two metrics: + # - val_loss: token cross-entropy (natural log) + # - val_bpb: tokenizer-agnostic compression metric used by the challenge + local_batch_tokens = args.val_batch_size // (world_size * grad_accum_steps) + if local_batch_tokens < args.train_seq_len: + raise ValueError( + "VAL_BATCH_SIZE must provide at least one sequence per rank; " + f"got VAL_BATCH_SIZE={args.val_batch_size}, WORLD_SIZE={world_size}, " + f"GRAD_ACCUM_STEPS={grad_accum_steps}, TRAIN_SEQ_LEN={args.train_seq_len}" + ) + local_batch_seqs = local_batch_tokens // args.train_seq_len + total_seqs = (val_tokens.numel() - 1) // args.train_seq_len + seq_start = (total_seqs * rank) // world_size + seq_end = (total_seqs * (rank + 1)) // world_size + val_loss_sum = torch.zeros((), device=device, dtype=torch.float64) + val_token_count = torch.zeros((), device=device, dtype=torch.float64) + val_byte_count = torch.zeros((), device=device, dtype=torch.float64) + + model.eval() + with torch.inference_mode(): + for batch_seq_start in range(seq_start, seq_end, local_batch_seqs): + batch_seq_end = min(batch_seq_start + local_batch_seqs, seq_end) + raw_start = batch_seq_start * args.train_seq_len + raw_end = batch_seq_end * args.train_seq_len + 1 + local = val_tokens[raw_start:raw_end].to(device=device, dtype=torch.int64, non_blocking=True) + x = local[:-1].reshape(-1, args.train_seq_len) + y = local[1:].reshape(-1, args.train_seq_len) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + batch_loss = model(x, y).detach() + batch_token_count = float(y.numel()) + val_loss_sum += batch_loss.to(torch.float64) * batch_token_count + val_token_count += batch_token_count + prev_ids = x.reshape(-1) + tgt_ids = y.reshape(-1) + token_bytes = base_bytes_lut[tgt_ids].to(dtype=torch.int16) + token_bytes += (has_leading_space_lut[tgt_ids] & ~is_boundary_token_lut[prev_ids]).to(dtype=torch.int16) + val_byte_count += token_bytes.to(torch.float64).sum() + + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(val_loss_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(val_token_count, op=dist.ReduceOp.SUM) + dist.all_reduce(val_byte_count, op=dist.ReduceOp.SUM) + + val_loss = val_loss_sum / val_token_count + bits_per_token = val_loss.item() / math.log(2.0) + tokens_per_byte = val_token_count.item() / val_byte_count.item() + model.train() + return float(val_loss.item()), float(bits_per_token * tokens_per_byte) + +# ----------------------------- +# POST-TRAINING QUANTIZATION +# ----------------------------- +# +# It's silly to export our model, which is trained in bf16 and fp32, at that same precision. +# Instead, we get approximately the same model (with a small hit) by quantizing the model to int8 & zlib compressing. +# We can then decompress the model and run in higher precision for evaluation, after closing in under the size limit. + +CONTROL_TENSOR_NAME_PATTERNS = tuple( + pattern + for pattern in os.environ.get( + "CONTROL_TENSOR_NAME_PATTERNS", + "attn_scale,attn_scales,mlp_scale,mlp_scales,resid_mix,resid_mixes,q_gain,skip_weight,skip_weights", + ).split(",") + if pattern +) +INT8_KEEP_FLOAT_FP32_NAME_PATTERNS = tuple( + pattern + for pattern in os.environ.get( + "INT8_KEEP_FLOAT_FP32_NAME_PATTERNS", + ",".join(CONTROL_TENSOR_NAME_PATTERNS), + ).split(",") + if pattern +) +INT8_KEEP_FLOAT_MAX_NUMEL = 65_536 +INT8_KEEP_FLOAT_STORE_DTYPE = torch.float16 +INT8_PER_ROW_SCALE_DTYPE = torch.float16 +INT8_CLIP_PERCENTILE = 99.99984 +INT8_CLIP_Q = INT8_CLIP_PERCENTILE / 100.0 + +def tensor_nbytes(t: Tensor) -> int: + return int(t.numel()) * int(t.element_size()) + +def keep_float_tensor(name: str, t: Tensor, passthrough_orig_dtypes: dict[str, str]) -> Tensor: + if any(pattern in name for pattern in INT8_KEEP_FLOAT_FP32_NAME_PATTERNS): + return t.float().contiguous() + if t.dtype in {torch.float32, torch.bfloat16}: + passthrough_orig_dtypes[name] = str(t.dtype).removeprefix("torch.") + return t.to(dtype=INT8_KEEP_FLOAT_STORE_DTYPE).contiguous() + return t + +def _quantize_2d_with_clip(t32: Tensor, clip_abs: Tensor, max_val: int) -> tuple[Tensor, Tensor, Tensor]: + """Quantize a 2D tensor with given per-row clip values. Returns (q, scale, reconstruction_error).""" + clipped = torch.maximum(torch.minimum(t32, clip_abs[:, None]), -clip_abs[:, None]) + scale = (clip_abs / float(max_val)).clamp_min(1.0 / float(max_val)) + q = torch.clamp(torch.round(clipped / scale[:, None]), -max_val, max_val).to(torch.int8) + # Reconstruction error: MSE between original and dequantized + deq = q.float() * scale[:, None] + err = (t32 - deq).square().mean(dim=1) + return q, scale, err + + +# GPTQ-lite clip ratios: search these percentiles per weight matrix to find optimal clipping. +_GPTQ_CLIP_RATIOS = [0.9, 0.95, 0.99, 0.999, 0.99999] +_gptq_lite: bool = False # set at runtime from GPTQ_LITE env var + + +def quantize_float_tensor(t: Tensor, bits: int = 8) -> tuple[Tensor, Tensor]: + max_val = 127 if bits == 8 else (2 ** (bits - 1)) - 1 # int6: 31, int8: 127 + t32 = t.float() + if t32.ndim == 2: + if _gptq_lite and t32.numel() > 0: + # GPTQ-lite: try multiple clip percentiles per row, pick lowest reconstruction error. + abs_vals = t32.abs() + best_q: Tensor | None = None + best_scale: Tensor | None = None + best_err: Tensor | None = None + for ratio in _GPTQ_CLIP_RATIOS: + clip_abs = torch.quantile(abs_vals, ratio, dim=1) + q, scale, err = _quantize_2d_with_clip(t32, clip_abs, max_val) + if best_err is None: + best_q, best_scale, best_err = q, scale, err + else: + # Per-row: keep whichever clip ratio gave lower error for each row + better = err < best_err + best_q[better] = q[better] + best_scale[better] = scale[better] + best_err[better] = err[better] + return best_q.contiguous(), best_scale.to(dtype=INT8_PER_ROW_SCALE_DTYPE).contiguous() + + # Standard: single clip percentile + clip_abs = ( + torch.quantile(t32.abs(), INT8_CLIP_Q, dim=1) + if t32.numel() + else torch.empty((t32.shape[0],), dtype=torch.float32) + ) + clipped = torch.maximum(torch.minimum(t32, clip_abs[:, None]), -clip_abs[:, None]) + scale = (clip_abs / float(max_val)).clamp_min(1.0 / float(max_val)) + q = torch.clamp(torch.round(clipped / scale[:, None]), -max_val, max_val).to(torch.int8).contiguous() + return q, scale.to(dtype=INT8_PER_ROW_SCALE_DTYPE).contiguous() + + # Vectors / scalars use a simpler per-tensor scale. + clip_abs = float(torch.quantile(t32.abs().flatten(), INT8_CLIP_Q).item()) if t32.numel() else 0.0 + scale = torch.tensor(clip_abs / float(max_val) if clip_abs > 0 else 1.0, dtype=torch.float32) + q = torch.clamp(torch.round(torch.clamp(t32, -clip_abs, clip_abs) / scale), -max_val, max_val).to(torch.int8).contiguous() + return q, scale + +# Names of tensors to keep in fp16 at export instead of quantizing to int8. +# Populated at runtime when fp16_embed_export=True. +_FP16_EXPORT_NAMES: set[str] = set() + +def quantize_state_dict_int8(state_dict: dict[str, Tensor], bits: int = 8): + # Clean-script export format supporting int8 or int6: + # - per-row quantization for 2D float tensors (int8: [-127,127], int6: [-31,31]) + # - per-tensor quantization for other float tensors + # - exact passthrough for non-floats + # - passthrough for small float tensors, stored as fp16 to save bytes + # - fp16 passthrough for tensors in _FP16_EXPORT_NAMES (e.g. tied embeddings) + quantized: dict[str, Tensor] = {} + scales: dict[str, Tensor] = {} + dtypes: dict[str, str] = {} + passthrough: dict[str, Tensor] = {} + passthrough_orig_dtypes: dict[str, str] = {} + qmeta: dict[str, dict[str, object]] = {} + stats = dict.fromkeys( + ("param_count", "num_tensors", "num_float_tensors", "num_nonfloat_tensors", "baseline_tensor_bytes", "int8_payload_bytes"), + 0, + ) + + for name, tensor in state_dict.items(): + t = tensor.detach().to("cpu").contiguous() + stats["param_count"] += int(t.numel()) + stats["num_tensors"] += 1 + stats["baseline_tensor_bytes"] += tensor_nbytes(t) + + if not t.is_floating_point(): + stats["num_nonfloat_tensors"] += 1 + passthrough[name] = t + stats["int8_payload_bytes"] += tensor_nbytes(t) + continue + + # FP16 export bypass: keep specified tensors (e.g. tied embeddings) in fp16 + # instead of quantizing to int8. Avoids compounding int8 errors through both + # the input embedding and output projection paths. + if name in _FP16_EXPORT_NAMES: + kept = t.to(dtype=torch.float16).contiguous() + passthrough_orig_dtypes[name] = str(t.dtype).removeprefix("torch.") + passthrough[name] = kept + stats["int8_payload_bytes"] += tensor_nbytes(kept) + continue + + # Small float tensors are cheap enough to keep directly. We still downcast + # fp32/bf16 passthrough tensors to fp16 so metadata does not dominate size. + if t.numel() <= INT8_KEEP_FLOAT_MAX_NUMEL: + kept = keep_float_tensor(name, t, passthrough_orig_dtypes) + passthrough[name] = kept + stats["int8_payload_bytes"] += tensor_nbytes(kept) + continue + + stats["num_float_tensors"] += 1 + q, s = quantize_float_tensor(t, bits=bits) + if s.ndim > 0: + qmeta[name] = {"scheme": "per_row", "axis": 0} + quantized[name] = q + scales[name] = s + dtypes[name] = str(t.dtype).removeprefix("torch.") + stats["int8_payload_bytes"] += tensor_nbytes(q) + tensor_nbytes(s) + + obj: dict[str, object] = { + "__quant_format__": "int8_clean_per_row_v1", + "quantized": quantized, + "scales": scales, + "dtypes": dtypes, + "passthrough": passthrough, + } + if qmeta: + obj["qmeta"] = qmeta + if passthrough_orig_dtypes: + obj["passthrough_orig_dtypes"] = passthrough_orig_dtypes + return obj, stats + +def dequantize_state_dict_int8(obj: dict[str, object]) -> dict[str, Tensor]: + out: dict[str, Tensor] = {} + qmeta = obj.get("qmeta", {}) + passthrough_orig_dtypes = obj.get("passthrough_orig_dtypes", {}) + for name, q in obj["quantized"].items(): + dtype = getattr(torch, obj["dtypes"][name]) + s = obj["scales"][name] + if qmeta.get(name, {}).get("scheme") == "per_row" or s.ndim > 0: + s = s.to(dtype=torch.float32) + # Broadcast the saved row scale back across trailing dimensions. + out[name] = (q.float() * s.view(q.shape[0], *([1] * (q.ndim - 1)))).to(dtype=dtype).contiguous() + else: + scale = float(s.item()) + out[name] = (q.float() * scale).to(dtype=dtype).contiguous() + for name, t in obj["passthrough"].items(): + # Restore small tensors, undoing the temporary fp16 storage cast if needed. + out_t = t.detach().to("cpu").contiguous() + orig_dtype = passthrough_orig_dtypes.get(name) + if isinstance(orig_dtype, str): + out_t = out_t.to(dtype=getattr(torch, orig_dtype)).contiguous() + out[name] = out_t + return out + + +# ----------------------------- +# DATA LOADING +# ----------------------------- + +def load_data_shard(file: Path) -> Tensor: + header_bytes = 256 * np.dtype(" list[Path]: + """Sort shards by average document length (shorter docs = easier = first). + + Reads only the first 100K tokens of each shard to estimate avg doc length. + Shards with shorter average documents contain simpler, more repetitive text + that helps the model learn basic patterns before encountering harder material. + """ + difficulties: list[tuple[float, Path]] = [] + for f in files: + header = np.fromfile(f, dtype=" None: + self.file_idx = (self.file_idx + 1) % len(self.files) + self.tokens = load_data_shard(self.files[self.file_idx]) + self.pos = 0 + + def take(self, n: int) -> Tensor: + chunks: list[Tensor] = [] + remaining = n + while remaining > 0: + avail = self.tokens.numel() - self.pos + if avail <= 0: + self._advance_file() + continue + k = min(remaining, avail) + chunks.append(self.tokens[self.pos : self.pos + k]) + self.pos += k + remaining -= k + return chunks[0] if len(chunks) == 1 else torch.cat(chunks) + + +class DistributedTokenLoader: + # Each call consumes a contiguous chunk from the shared token stream, then slices out + # one disjoint span per rank. The extra "+1" token lets us build (x, y) by shifting. + def __init__(self, pattern: str, rank: int, world_size: int, device: torch.device, + sorted_files: list[Path] | None = None): + self.rank = rank + self.world_size = world_size + self.device = device + self.stream = TokenStream(pattern, sorted_files=sorted_files) + + def next_batch(self, global_tokens: int, seq_len: int, grad_accum_steps: int) -> tuple[Tensor, Tensor]: + local_tokens = global_tokens // (self.world_size * grad_accum_steps) + per_rank_span = local_tokens + 1 + chunk = self.stream.take(per_rank_span * self.world_size) + start = self.rank * per_rank_span + local = chunk[start : start + per_rank_span].to(dtype=torch.int64) + x = local[:-1].reshape(-1, seq_len) + y = local[1:].reshape(-1, seq_len) + return x.to(self.device, non_blocking=True), y.to(self.device, non_blocking=True) + +# ----------------------------- +# TRANSFORMER MODULES +# ----------------------------- + +class RMSNorm(nn.Module): + def __init__(self, eps: float | None = None): + super().__init__() + self.eps = eps + + def forward(self, x: Tensor) -> Tensor: + return F.rms_norm(x, (x.size(-1),), eps=self.eps) + + +def fake_quantize_int8_per_row(w: Tensor) -> Tensor: + """Simulate per-row int6 quantization with straight-through estimator. + + Uses absmax per-row scale (no torch.quantile — O(n) instead of O(n log n)). + Matches the int6 export range [-31, 31]. + Backward: gradients pass through as if no quantization happened (STE). + """ + w32 = w.float() + scale = (w32.abs().amax(dim=1) / 31.0).clamp_min(1.0 / 31.0) + w_q = torch.clamp(torch.round(w32 / scale[:, None]), -31, 31) + w_deq = w_q * scale[:, None] + return w + (w_deq.to(w.dtype) - w).detach() + + +class CastedLinear(nn.Linear): + # Keep weights in fp32 for optimizer/state quality, cast at matmul time for bf16 compute. + _qat: bool = False + + def forward(self, x: Tensor) -> Tensor: + w = self.weight + if self._qat and self.training: + w = fake_quantize_int8_per_row(w) + bias = self.bias.to(x.dtype) if self.bias is not None else None + return F.linear(x, w.to(x.dtype), bias) + + +class AttentionLoRA(nn.Module): + """Per-iteration LoRA adapters for attention Q, K, V, and output projections. + + Initialized so that the LoRA contribution is zero at the start of training + (B matrices are zeros). During training, the optimizer learns per-iteration + specialization while the base attention weights remain shared across loops. + """ + def __init__(self, dim: int, kv_dim: int, rank: int): + super().__init__() + self.q_A = nn.Parameter(torch.empty(dim, rank)) + self.q_B = nn.Parameter(torch.zeros(rank, dim)) + self.k_A = nn.Parameter(torch.empty(dim, rank)) + self.k_B = nn.Parameter(torch.zeros(rank, kv_dim)) + self.v_A = nn.Parameter(torch.empty(dim, rank)) + self.v_B = nn.Parameter(torch.zeros(rank, kv_dim)) + self.proj_A = nn.Parameter(torch.empty(dim, rank)) + self.proj_B = nn.Parameter(torch.zeros(rank, dim)) + self._init_lora() + + def _init_lora(self) -> None: + for name in ("q_A", "k_A", "v_A", "proj_A"): + nn.init.kaiming_uniform_(getattr(self, name), a=math.sqrt(5)) + + +def restore_low_dim_params_to_fp32(module: nn.Module) -> None: + # Keep small/control parameters in fp32 even when the model body runs in bf16. + with torch.no_grad(): + for name, param in module.named_parameters(): + if (param.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS)) and param.dtype != torch.float32: + param.data = param.data.float() + + +class Rotary(nn.Module): + # Caches cos/sin tables per sequence length on the current device. + # ntk_base_seq_len: if > 0, apply NTK-aware RoPE scaling when seq_len > ntk_base_seq_len + # (lets a model trained at seq_len=1024 generalise to seq_len=2048 at eval with no quality loss) + def __init__(self, dim: int, base: float = 10000.0, ntk_base_seq_len: int = 0): + super().__init__() + self._dim = dim + self._base = base + self._ntk_base_seq_len = ntk_base_seq_len + inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim)) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self._seq_len_cached = 0 + self._cos_cached: Tensor | None = None + self._sin_cached: Tensor | None = None + + def forward(self, seq_len: int, device: torch.device, dtype: torch.dtype) -> tuple[Tensor, Tensor]: + if ( + self._cos_cached is None + or self._sin_cached is None + or self._seq_len_cached != seq_len + or self._cos_cached.device != device + ): + if self._ntk_base_seq_len > 0 and seq_len > self._ntk_base_seq_len: + # NTK-aware scaling: extend context without fine-tuning + scale = seq_len / self._ntk_base_seq_len + ntk_base = self._base * (scale ** (self._dim / (self._dim - 2))) + inv_freq = 1.0 / (ntk_base ** (torch.arange(0, self._dim, 2, dtype=torch.float32, device=device) / self._dim)) + else: + inv_freq = self.inv_freq.to(device) + t = torch.arange(seq_len, device=device, dtype=inv_freq.dtype) + freqs = torch.outer(t, inv_freq) + self._cos_cached = freqs.cos()[None, None, :, :] + self._sin_cached = freqs.sin()[None, None, :, :] + self._seq_len_cached = seq_len + return self._cos_cached.to(dtype=dtype), self._sin_cached.to(dtype=dtype) + + +def apply_rotary_emb(x: Tensor, cos: Tensor, sin: Tensor) -> Tensor: + half = x.size(-1) // 2 + x1, x2 = x[..., :half], x[..., half:] + return torch.cat((x1 * cos + x2 * sin, x1 * (-sin) + x2 * cos), dim=-1) + + +class CausalSelfAttention(nn.Module): + def __init__( + self, + dim: int, + num_heads: int, + num_kv_heads: int, + rope_base: float, + qk_gain_init: float, + ntk_base_seq_len: int = 0, + rope_dims: int = 0, + ): + super().__init__() + if dim % num_heads != 0: + raise ValueError("model_dim must be divisible by num_heads") + if num_heads % num_kv_heads != 0: + raise ValueError("num_heads must be divisible by num_kv_heads") + self.num_heads = num_heads + self.num_kv_heads = num_kv_heads + self.head_dim = dim // num_heads + if self.head_dim % 2 != 0: + raise ValueError("head_dim must be even for RoPE") + kv_dim = self.num_kv_heads * self.head_dim + self.c_q = CastedLinear(dim, dim, bias=False) + self.c_k = CastedLinear(dim, kv_dim, bias=False) + self.c_v = CastedLinear(dim, kv_dim, bias=False) + self.proj = CastedLinear(dim, dim, bias=False) + self.proj._zero_init = True + self.q_gain = nn.Parameter(torch.full((num_heads,), qk_gain_init, dtype=torch.float32)) + # rope_dims: if >0 apply RoPE only to first rope_dims dims of each head; rest are position-free + self.rope_dims = rope_dims if rope_dims > 0 else self.head_dim + self.rotary = Rotary(self.rope_dims, base=rope_base, ntk_base_seq_len=ntk_base_seq_len) + self.use_xsa = False # enabled on last N layers by GPT.__init__ + + def _xsa_efficient(self, y: Tensor, v: Tensor) -> Tensor: + """Exclusive Self Attention: subtract self-value projection (arXiv:2603.09078). + GQA-aware reshape avoids repeat_interleave — zero extra allocation.""" + B, T, H, D = y.shape + Hkv = v.size(-2) + group = H // Hkv + y_g = y.reshape(B, T, Hkv, group, D) + vn = F.normalize(v, dim=-1).unsqueeze(-2) # (B, T, Hkv, 1, D) + proj = (y_g * vn).sum(dim=-1, keepdim=True) * vn + return (y_g - proj).reshape(B, T, H, D) + + def forward(self, x: Tensor, lora: AttentionLoRA | None = None, v_embed: Tensor | None = None) -> Tensor: + bsz, seqlen, dim = x.shape + q = self.c_q(x) + k = self.c_k(x) + v = self.c_v(x) + if v_embed is not None: + v = v + v_embed + if lora is not None: + # LoRA delta: (bsz, seqlen, dim) @ (dim, rank) @ (rank, out_dim) + # autocast handles fp32->bf16 cast of LoRA params automatically + q = q + (x @ lora.q_A) @ lora.q_B + k = k + (x @ lora.k_A) @ lora.k_B + v = v + (x @ lora.v_A) @ lora.v_B + q = q.reshape(bsz, seqlen, self.num_heads, self.head_dim).transpose(1, 2) + k = k.reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) + v = v.reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) + q = F.rms_norm(q, (q.size(-1),)) + k = F.rms_norm(k, (k.size(-1),)) + cos, sin = self.rotary(seqlen, x.device, q.dtype) + if self.rope_dims < self.head_dim: + # Partial RoPE: apply rotation to first rope_dims dims, leave remaining dims untouched + q = torch.cat([apply_rotary_emb(q[..., :self.rope_dims], cos, sin), q[..., self.rope_dims:]], dim=-1) + k = torch.cat([apply_rotary_emb(k[..., :self.rope_dims], cos, sin), k[..., self.rope_dims:]], dim=-1) + else: + q = apply_rotary_emb(q, cos, sin) + k = apply_rotary_emb(k, cos, sin) + q = q * self.q_gain.to(dtype=q.dtype)[None, :, None, None] + if _HAS_FA3 and _use_fa3: + # FA3 expects [bsz, seqlen, heads, head_dim] + q_fa = q.transpose(1, 2).contiguous() + k_fa = k.transpose(1, 2).contiguous() + v_fa = v.transpose(1, 2).contiguous() + y = _flash_attn_3_func(q_fa, k_fa, v_fa, causal=True) + if self.use_xsa: + y = self._xsa_efficient(y, v_fa) + y = y.reshape(bsz, seqlen, dim) + else: + y = F.scaled_dot_product_attention( + q, k, v, attn_mask=None, is_causal=True, + enable_gqa=(self.num_kv_heads != self.num_heads), + ) + if self.use_xsa: + y = self._xsa_efficient(y.transpose(1, 2), v.transpose(1, 2)) + else: + y = y.transpose(1, 2) + y = y.contiguous().reshape(bsz, seqlen, dim) + out = self.proj(y) + if lora is not None: + out = out + (y @ lora.proj_A) @ lora.proj_B + return out + + +class MLP(nn.Module): + # relu^2 MLP from the original modded-nanogpt setup + def __init__(self, dim: int, mlp_mult: int, mlp_hidden: int = 0): + super().__init__() + hidden = mlp_hidden if mlp_hidden > 0 else mlp_mult * dim + self.fc = CastedLinear(dim, hidden, bias=False) + self.proj = CastedLinear(hidden, dim, bias=False) + self.proj._zero_init = True + + def forward(self, x: Tensor) -> Tensor: + x = torch.relu(self.fc(x)) + return self.proj(x.square()) + + +class Block(nn.Module): + def __init__( + self, + dim: int, + num_heads: int, + num_kv_heads: int, + mlp_mult: int, + rope_base: float, + qk_gain_init: float, + mlp_hidden: int = 0, + ntk_base_seq_len: int = 0, + rope_dims: int = 0, + layer_idx: int = 0, + ln_scale: bool = False, + ): + super().__init__() + self.attn_norm = RMSNorm() + self.mlp_norm = RMSNorm() + self.attn = CausalSelfAttention(dim, num_heads, num_kv_heads, rope_base, qk_gain_init, ntk_base_seq_len=ntk_base_seq_len, rope_dims=rope_dims) + self.mlp = MLP(dim, mlp_mult, mlp_hidden=mlp_hidden) + self.attn_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.mlp_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.resid_mix = nn.Parameter(torch.stack((torch.ones(dim), torch.zeros(dim))).float()) + self.ln_scale_factor = 1.0 / math.sqrt(layer_idx + 1) if ln_scale else 1.0 + + def forward(self, x: Tensor, x0: Tensor, lora: AttentionLoRA | None = None, v_embed: Tensor | None = None) -> Tensor: + mix = self.resid_mix.to(dtype=x.dtype) + x = mix[0][None, None, :] * x + mix[1][None, None, :] * x0 + s = self.ln_scale_factor + attn_out = self.attn(self.attn_norm(x) * s, lora=lora, v_embed=v_embed) + x = x + self.attn_scale.to(dtype=x.dtype)[None, None, :] * attn_out + x = x + self.mlp_scale.to(dtype=x.dtype)[None, None, :] * self.mlp(self.mlp_norm(x) * s) + return x + + +# ── Shared Value Embedding (VE) ────────────────────────────────────────────── +class ValueEmbedding(nn.Module): + """Learned embedding added to attention values in selected layers. + One shared table across layers, with per-layer learned scales.""" + def __init__(self, vocab_size: int, ve_dim: int, kv_dim: int): + super().__init__() + self.embed = nn.Embedding(vocab_size, ve_dim) + nn.init.normal_(self.embed.weight, std=0.01) + self.proj = CastedLinear(ve_dim, kv_dim, bias=False) if ve_dim != kv_dim else None + if self.proj is not None: + nn.init.zeros_(self.proj.weight) + self.scale = nn.Parameter(torch.tensor(0.1, dtype=torch.float32)) + + def forward(self, token_ids: Tensor) -> Tensor: + h = self.embed(token_ids) + if self.proj is not None: + h = self.proj(h) + return h * self.scale.to(dtype=h.dtype) + + +# ── SmearGate: cheap bigram context at embedding layer ────────────────────── +class SmearGate(nn.Module): + """Blends each token's embedding with the previous token's via per-channel sigmoid gates. + 512 independent channel gates (vs scalar gate) give the model richer bigram context.""" + def __init__(self, dim: int): + super().__init__() + self.gate = nn.Parameter(torch.zeros(dim, dtype=torch.float32)) + + def forward(self, x: Tensor) -> Tensor: + # x: (bsz, seq_len, dim) + prev = torch.cat([torch.zeros_like(x[:, :1]), x[:, :-1]], dim=1) + g = torch.sigmoid(self.gate.to(dtype=x.dtype))[None, None, :] # (1, 1, dim) + return torch.lerp(x, prev, g) + + +# ── BigramHash: hash-based bigram embedding ───────────────────────────────── +class BigramHashEmbedding(nn.Module): + """Maps consecutive token pairs to embeddings via a hash table. + Injects explicit bigram statistics into the residual stream.""" + def __init__(self, num_buckets: int, hash_dim: int, model_dim: int): + super().__init__() + self.num_buckets = num_buckets + self.embed = nn.Embedding(num_buckets, hash_dim) + self.proj = nn.Linear(hash_dim, model_dim, bias=False) + self.scale = nn.Parameter(torch.tensor(0.05, dtype=torch.float32)) + nn.init.normal_(self.embed.weight, std=0.01) + nn.init.zeros_(self.proj.weight) + + def _hash_pair(self, prev_ids: Tensor, cur_ids: Tensor) -> Tensor: + # XOR hash with large primes for better distribution + return (torch.bitwise_xor(36313 * prev_ids.long(), 27191 * cur_ids.long()) % max(self.num_buckets - 1, 1)).to(prev_ids.device) + + def forward(self, input_ids: Tensor) -> Tensor: + # input_ids: (bsz, seq_len) + prev_ids = torch.cat([torch.zeros_like(input_ids[:, :1]), input_ids[:, :-1]], dim=1) + bucket_ids = self._hash_pair(prev_ids, input_ids) + return self.scale.to(dtype=input_ids.dtype if input_ids.is_floating_point() else torch.bfloat16) * self.proj(self.embed(bucket_ids)) + + +class GPT(nn.Module): + def __init__( + self, + vocab_size: int, + num_layers: int, + model_dim: int, + num_heads: int, + num_kv_heads: int, + mlp_mult: int, + tie_embeddings: bool, + tied_embed_init_std: float, + logit_softcap: float, + rope_base: float, + qk_gain_init: float, + num_loops: int = 1, + lora_rank: int = 0, + mlp_hidden: int = 0, + smear_gate: bool = False, + bigram_hash: bool = False, + bigram_hash_buckets: int = 4096, + bigram_hash_dim: int = 128, + ortho_init: bool = True, + xsa_last_n: int = 0, + ntk_base_seq_len: int = 0, + rope_dims: int = 0, + ln_scale: bool = False, + unet_skips: bool = False, + ve_enabled: bool = False, + ve_dim: int = 128, + ve_layers: str = "9,10", + ): + super().__init__() + if logit_softcap <= 0.0: + raise ValueError(f"logit_softcap must be positive, got {logit_softcap}") + self.ortho_init = ortho_init + self.tie_embeddings = tie_embeddings + self.tied_embed_init_std = tied_embed_init_std + self.logit_softcap = logit_softcap + self.num_unique_layers = num_layers + self.num_loops = num_loops + effective_depth = num_layers * num_loops + self.tok_emb = nn.Embedding(vocab_size, model_dim) + self.smear_gate = SmearGate(model_dim) if smear_gate else None + self.bigram_hash = BigramHashEmbedding(bigram_hash_buckets, bigram_hash_dim, model_dim) if bigram_hash else None + # Shared Value Embedding: one table, added to V in selected layers + self.ve_layer_indices = [int(x) for x in ve_layers.split(",") if x.strip()] if ve_enabled else [] + kv_dim = num_kv_heads * (model_dim // num_heads) + self.ve_shared = ValueEmbedding(vocab_size, ve_dim, kv_dim) if ve_enabled else None + self.ve_layer_scales = nn.ParameterList( + [nn.Parameter(torch.ones(1, dtype=torch.float32)) for _ in self.ve_layer_indices] + ) if ve_enabled else None + self.unet_skips = unet_skips + self.num_encoder_layers = effective_depth // 2 + self.num_decoder_layers = effective_depth - self.num_encoder_layers + self.num_skip_weights = min(self.num_encoder_layers, self.num_decoder_layers) + self.skip_weights = nn.Parameter(torch.ones(self.num_skip_weights, model_dim, dtype=torch.float32)) if unet_skips else None + self.blocks = nn.ModuleList( + [ + Block( + model_dim, + num_heads, + num_kv_heads, + mlp_mult, + rope_base, + qk_gain_init, + mlp_hidden=mlp_hidden, + ntk_base_seq_len=ntk_base_seq_len, + rope_dims=rope_dims, + layer_idx=i, + ln_scale=ln_scale, + ) + for i in range(num_layers) + ] + ) + # Per-(loop, block) LoRA adapters for attention projections. + # Only created when num_loops > 1 and lora_rank > 0. + kv_dim = num_kv_heads * (model_dim // num_heads) + if lora_rank > 0 and num_loops > 1: + self.lora_adapters = nn.ModuleList( + [ + nn.ModuleList( + [AttentionLoRA(model_dim, kv_dim, lora_rank) for _ in range(num_layers)] + ) + for _ in range(num_loops) + ] + ) + else: + self.lora_adapters = None + if xsa_last_n > 0: + for i in range(max(0, num_layers - xsa_last_n), num_layers): + self.blocks[i].attn.use_xsa = True + self.final_norm = RMSNorm() + self.lm_head = None if tie_embeddings else CastedLinear(model_dim, vocab_size, bias=False) + if self.lm_head is not None: + self.lm_head._zero_init = True + self._init_weights() + + def _init_weights(self) -> None: + if self.tie_embeddings: + nn.init.normal_(self.tok_emb.weight, mean=0.0, std=self.tied_embed_init_std) + for module in self.modules(): + if isinstance(module, (nn.Linear, CastedLinear)) and getattr(module, "_zero_init", False): + nn.init.zeros_(module.weight) + elif self.ortho_init and isinstance(module, (nn.Linear, CastedLinear)) \ + and not getattr(module, "_zero_init", False) \ + and module.weight.ndim >= 2: + nn.init.orthogonal_(module.weight, gain=1.0) + + def forward(self, input_ids: Tensor, target_ids: Tensor) -> Tensor: + x = self.tok_emb(input_ids) + if self.bigram_hash is not None: + x = x + self.bigram_hash(input_ids) + x = F.rms_norm(x, (x.size(-1),)) + if self.smear_gate is not None: + x = self.smear_gate(x) + x0 = x + num_enc = self.num_encoder_layers + num_dec = self.num_decoder_layers + + # Compute shared VE once (cached across layers) + ve_base = self.ve_shared(input_ids) if self.ve_shared is not None else None + + def _get_ve(layer_idx: int) -> Tensor | None: + if ve_base is None or layer_idx not in self.ve_layer_indices: + return None + ve_idx = self.ve_layer_indices.index(layer_idx) + return ve_base * self.ve_layer_scales[ve_idx].to(dtype=ve_base.dtype) + + if self.unet_skips: + skips: list[Tensor] = [] + for i in range(num_enc): + x = self.blocks[i](x, x0, v_embed=_get_ve(i)) + skips.append(x) + for i in range(num_dec): + if skips: + x = x + self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips.pop() + x = self.blocks[num_enc + i](x, x0, v_embed=_get_ve(num_enc + i)) + else: + eff_idx = 0 + for loop_idx in range(self.num_loops): + for block_idx in range(self.num_unique_layers): + lora = self.lora_adapters[loop_idx][block_idx] if self.lora_adapters is not None else None + x = self.blocks[block_idx](x, x0, lora=lora, v_embed=_get_ve(eff_idx)) + eff_idx += 1 + + x = self.final_norm(x).reshape(-1, x.size(-1)) + targets = target_ids.reshape(-1) + if self.tie_embeddings: + logits_proj = F.linear(x, self.tok_emb.weight) + else: + if self.lm_head is None: + raise RuntimeError("lm_head is required when tie_embeddings=False") + logits_proj = self.lm_head(x) + logits = self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap) + return F.cross_entropy(logits.float(), targets, reduction="mean") + + def forward_logits(self, input_ids: Tensor) -> Tensor: + """Return logits (bsz, seq_len, vocab) without computing loss.""" + x = self.tok_emb(input_ids) + if self.bigram_hash is not None: + x = x + self.bigram_hash(input_ids) + x = F.rms_norm(x, (x.size(-1),)) + if self.smear_gate is not None: + x = self.smear_gate(x) + x0 = x + num_enc = self.num_encoder_layers + num_dec = self.num_decoder_layers + + if self.unet_skips: + skips: list[Tensor] = [] + for i in range(num_enc): + x = self.blocks[i](x, x0) + skips.append(x) + for i in range(num_dec): + if skips: + x = x + self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips.pop() + x = self.blocks[num_enc + i](x, x0) + else: + eff_idx = 0 + for loop_idx in range(self.num_loops): + for block_idx in range(self.num_unique_layers): + lora = self.lora_adapters[loop_idx][block_idx] if self.lora_adapters is not None else None + x = self.blocks[block_idx](x, x0, lora=lora) + eff_idx += 1 + x = self.final_norm(x) + if self.tie_embeddings: + logits_proj = F.linear(x, self.tok_emb.weight) + else: + logits_proj = self.lm_head(x) + return self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap) + + +BOS_ID = 1 # SentencePiece BOS token ID + +def _find_docs(all_tokens: Tensor) -> list[tuple[int, int]]: + """Return (start, length) for each document, identified by BOS boundaries. + + Each document starts at a BOS token and extends to just before the next BOS. + The last document extends to the end of the token stream. + """ + bos_positions = (all_tokens == BOS_ID).nonzero(as_tuple=True)[0].tolist() + if not bos_positions: + return [(0, all_tokens.numel())] + docs = [] + for i in range(len(bos_positions)): + start = bos_positions[i] + end = bos_positions[i + 1] if i + 1 < len(bos_positions) else all_tokens.numel() + if end - start >= 2: # need at least 2 tokens for (x, y) pair + docs.append((start, end - start)) + return docs + + +def _build_sliding_windows( + total_tokens: int, seq_len: int, stride: int +) -> list[tuple[int, int]]: + """Return (window_start, score_start) pairs covering every token exactly once. + + Every token in [0, total_tokens) is scored by exactly one window. + Full windows score their last `stride` positions (first window scores all seq_len). + One tail-aligned window covers any tokens beyond the last full window's end. + """ + if total_tokens <= 0: + return [] + if total_tokens <= seq_len: + return [(0, 0)] + + windows: list[tuple[int, int]] = [] + last_full_end = 0 + ws = 0 + while ws + seq_len <= total_tokens: + s = 0 if ws == 0 else seq_len - stride + windows.append((ws, s)) + last_full_end = ws + seq_len + ws += stride + + # One tail window ending exactly at total_tokens covers any remainder. + if last_full_end < total_tokens: + tail_ws = total_tokens - seq_len + tail_s = last_full_end - tail_ws # skip already-scored prefix + windows.append((tail_ws, tail_s)) + + return windows + + +def eval_val_sliding( + args: Hyperparameters, + base_model: nn.Module, + rank: int, + world_size: int, + device: torch.device, + val_tokens: Tensor, + base_bytes_lut: Tensor, + has_leading_space_lut: Tensor, + is_boundary_token_lut: Tensor, + stride: int, + batch_seqs: int = 32, +) -> tuple[float, float]: + """Sliding window evaluation: each token scored with maximum context. + + Windows of eval_seq_len advance by `stride`. Only newly covered tokens + contribute to the score (first window scores all seq_len; non-first full + windows score the last `stride` tokens; one tail window covers any remainder). + Every validation token is counted exactly once. + """ + seq_len = args.eval_seq_len + total_tokens = val_tokens.numel() - 1 + + if args.doc_isolated_eval: + # Build windows per document — context never crosses document boundaries. + # Each window's (ws, s) is in absolute token-stream coordinates. + docs = _find_docs(val_tokens) + all_windows: list[tuple[int, int]] = [] + for doc_start, doc_len in docs: + doc_pred_len = doc_len - 1 # number of prediction positions + doc_windows = _build_sliding_windows(doc_pred_len, seq_len, stride) + for ws, s in doc_windows: + all_windows.append((doc_start + ws, s)) + else: + all_windows = _build_sliding_windows(total_tokens, seq_len, stride) + total_windows = len(all_windows) + + # Distribute across ranks + my_s = (total_windows * rank) // world_size + my_e = (total_windows * (rank + 1)) // world_size + my_windows = all_windows[my_s:my_e] + + loss_sum = torch.zeros((), device=device, dtype=torch.float64) + token_count = torch.zeros((), device=device, dtype=torch.float64) + byte_count = torch.zeros((), device=device, dtype=torch.float64) + + base_model.eval() + with torch.inference_mode(): + for bi in range(0, len(my_windows), batch_seqs): + batch_items = my_windows[bi:bi + batch_seqs] + bsz = len(batch_items) + + x_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) + y_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) + wlens: list[int] = [] + score_starts: list[int] = [] + + for i, (ws, s) in enumerate(batch_items): + end = min(ws + seq_len, total_tokens) + wlen = end - ws + wlens.append(wlen) + score_starts.append(s) + chunk = val_tokens[ws:end + 1].to(dtype=torch.int64, device=device) + x_batch[i, :wlen] = chunk[:-1] + y_batch[i, :wlen] = chunk[1:] + + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + logits = base_model.forward_logits(x_batch) + + nll = F.cross_entropy( + logits.reshape(-1, logits.size(-1)).float(), + y_batch.reshape(-1), + reduction="none", + ).reshape(bsz, seq_len) + + for i, (ws, s) in enumerate(batch_items): + wlen = wlens[i] + scored_nll = nll[i, s:wlen].to(torch.float64) + loss_sum += scored_nll.sum() + token_count += float(wlen - s) + tgt = y_batch[i, s:wlen] + prev = x_batch[i, s:wlen] + tb = base_bytes_lut[tgt].to(torch.float64) + tb += (has_leading_space_lut[tgt] & ~is_boundary_token_lut[prev]).to(torch.float64) + byte_count += tb.sum() + + # Progress (rank 0 only) + if rank == 0 and (bi // batch_seqs) % 50 == 0: + done = min(bi + batch_seqs, len(my_windows)) + pct = done / len(my_windows) * 100 + running_bpb = 0.0 + if token_count.item() > 0: + rl = (loss_sum / token_count).item() + running_bpb = rl / math.log(2.0) * (token_count.item() / byte_count.item()) + print(f" sliding_eval [{pct:5.1f}%] {done}/{len(my_windows)} windows running_bpb={running_bpb:.6f}", flush=True) + + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(loss_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(token_count, op=dist.ReduceOp.SUM) + dist.all_reduce(byte_count, op=dist.ReduceOp.SUM) + + val_loss = (loss_sum / token_count).item() + bits_per_token = val_loss / math.log(2.0) + tokens_per_byte = token_count.item() / byte_count.item() + base_model.train() + return val_loss, bits_per_token * tokens_per_byte + + +# ----------------------------- +# TEST-TIME TRAINING (TTT) +# ----------------------------- + +def _ttt_run_phase( + model: nn.Module, + ttt_params: list[torch.nn.Parameter], + optimizer: torch.optim.Optimizer, + val_tokens: torch.Tensor, + seq_len: int, + batch_seqs: int, + epochs: int, + max_steps: int, + device: torch.device, + rank: int, + world_size: int, + phase_name: str, + t0: float, + cosine: bool = False, + warmup_frac: float = 0.0, +) -> None: + """Run one TTT phase with DDP gradient sharding across GPUs. + Each GPU processes batch_seqs sequences, gradients are manually all_reduced. + Supports cosine lr decay and linear warmup.""" + distributed = world_size > 1 + n_tokens = val_tokens.numel() + total_seqs = (n_tokens - 1) // seq_len + my_start_seq = (total_seqs * rank) // world_size + my_end_seq = (total_seqs * (rank + 1)) // world_size + + # Store initial lr for cosine/warmup scheduling + if cosine or warmup_frac > 0: + for g in optimizer.param_groups: + g["initial_lr"] = g["lr"] + # Estimate actual steps per epoch for cosine schedule + steps_per_epoch = min((my_end_seq - my_start_seq) // max(batch_seqs, 1), max_steps) + total_steps = epochs * steps_per_epoch + global_step = 0 + + model.train() + for epoch in range(epochs): + epoch_loss = torch.zeros((), device=device, dtype=torch.float64) + epoch_tokens = torch.zeros((), device=device, dtype=torch.float64) + step_i = 0 + # Contiguous slicing over this GPU's shard (matches #398 pattern) + for batch_start in range(my_start_seq, my_end_seq, batch_seqs): + if step_i >= max_steps: + break + + # LR schedule: warmup then cosine decay + if (cosine or warmup_frac > 0) and total_steps > 0: + progress = global_step / total_steps + if warmup_frac > 0 and progress < warmup_frac: + mul = progress / warmup_frac + elif cosine: + cos_start = warmup_frac if warmup_frac > 0 else 0.0 + cos_progress = (progress - cos_start) / (1.0 - cos_start) if cos_start < 1.0 else 0.0 + mul = 0.5 * (1.0 + math.cos(math.pi * min(cos_progress, 1.0))) + else: + mul = 1.0 + for g in optimizer.param_groups: + g["lr"] = g["initial_lr"] * mul + + batch_end = min(batch_start + batch_seqs, my_end_seq) + raw_start = batch_start * seq_len + raw_end = batch_end * seq_len + 1 + if raw_end > n_tokens: + break + local = val_tokens[raw_start:raw_end].to(device=device, dtype=torch.int64, non_blocking=True) + x = local[:-1].reshape(-1, seq_len) + y = local[1:].reshape(-1, seq_len) + + optimizer.zero_grad(set_to_none=True) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + loss = model(x, y) + loss.backward() + + if distributed: + for p in ttt_params: + if p.grad is not None: + dist.all_reduce(p.grad, op=dist.ReduceOp.AVG) + + torch.nn.utils.clip_grad_norm_(ttt_params, 1.0) + optimizer.step() + epoch_loss += loss.detach().to(torch.float64) * x.numel() + epoch_tokens += x.numel() + step_i += 1 + global_step += 1 + + if distributed: + dist.all_reduce(epoch_loss, op=dist.ReduceOp.SUM) + dist.all_reduce(epoch_tokens, op=dist.ReduceOp.SUM) + + if rank == 0: + avg_loss = epoch_loss.item() / max(epoch_tokens.item(), 1) + cur_lr = optimizer.param_groups[0]["lr"] + print(f"ttt_{phase_name} epoch:{epoch+1}/{epochs} loss:{avg_loss:.4f} " + f"lr:{cur_lr:.6f} steps:{step_i} time:{time.perf_counter()-t0:.1f}s", flush=True) + + +def ttt_adapt( + args: Hyperparameters, + model: nn.Module, + device: torch.device, + val_tokens: torch.Tensor, + rank: int = 0, + world_size: int = 1, +) -> None: + """Two-phase TTT with DDP gradient sharding. + + Phase 1 (norm-only): Fix quantization artifacts by adapting only LayerNorm + weights, scales, resid_mix, and q_gain (~22K params). Low risk, high epoch count. + + Phase 2 (selective blocks): Adapt last N blocks + norms + head to val distribution. + Higher risk, lower epoch count. + + Falls back to single-phase SGD if TTT_TWO_PHASE=0. + """ + t0 = time.perf_counter() + seq_len = args.train_seq_len + batch_seqs = args.ttt_batch_seqs + + if args.ttt_two_phase: + # ── Phase 1: Norm-only recalibration ──────────────────────────── + # Freeze everything except norms, scales, resid_mix, q_gain + norm_params = [] + for p in model.parameters(): + p.requires_grad_(False) + for name, p in model.named_parameters(): + if any(k in name for k in ("norm", "scale", "resid_mix", "q_gain", "skip_weight")): + p.requires_grad_(True) + norm_params.append(p) + n_norm = sum(p.numel() for p in norm_params) + if rank == 0: + print(f"ttt_phase1:start params:{n_norm} epochs:{args.ttt_p1_epochs} lr:{args.ttt_p1_lr}", flush=True) + + optimizer_p1 = torch.optim.Adam(norm_params, lr=args.ttt_p1_lr) + _ttt_run_phase( + model, norm_params, optimizer_p1, val_tokens, seq_len, batch_seqs, + epochs=args.ttt_p1_epochs, max_steps=args.ttt_max_steps, + device=device, rank=rank, world_size=world_size, + phase_name="phase1", t0=t0, + cosine=args.ttt_cosine, warmup_frac=args.ttt_warmup_frac, + ) + del optimizer_p1 + + # ── Phase 2: Selective block adaptation ───────────────────────── + # Unfreeze last N blocks + all norms + head + embeddings + for p in model.parameters(): + p.requires_grad_(False) + num_layers = len(list(model.blocks)) # type: ignore[attr-defined] + phase2_params = [] + for name, p in model.named_parameters(): + is_late_block = False + for i in range(max(0, num_layers - args.ttt_p2_unfreeze_blocks), num_layers): + if f"blocks.{i}." in name: + is_late_block = True + break + is_norm_or_scale = any(k in name for k in ("norm", "scale", "resid_mix", "q_gain", "skip_weight")) + is_head = "lm_head" in name or "tok_emb" in name + if is_late_block or is_norm_or_scale or is_head: + p.requires_grad_(True) + phase2_params.append(p) + n_p2 = sum(p.numel() for p in phase2_params) + if rank == 0: + print(f"ttt_phase2:start params:{n_p2} epochs:{args.ttt_p2_epochs} lr:{args.ttt_p2_lr}", flush=True) + + if args.ttt_optimizer == "adamw": + optimizer_p2 = torch.optim.AdamW(phase2_params, lr=args.ttt_p2_lr, weight_decay=0.0) + else: + optimizer_p2 = torch.optim.SGD(phase2_params, lr=args.ttt_p2_lr, momentum=args.ttt_momentum) + _ttt_run_phase( + model, phase2_params, optimizer_p2, val_tokens, seq_len, batch_seqs, + epochs=args.ttt_p2_epochs, max_steps=args.ttt_max_steps, + device=device, rank=rank, world_size=world_size, + phase_name="phase2", t0=t0, + cosine=args.ttt_cosine, warmup_frac=args.ttt_warmup_frac, + ) + del optimizer_p2 + else: + # ── Single-phase TTT with cosine + per-layer lr ─────────────── + frozen = set() + for i, block in enumerate(model.blocks): # type: ignore[attr-defined] + if i < args.ttt_freeze_blocks: + for p in block.parameters(): + frozen.add(id(p)) + + if args.ttt_perlayer: + # Per-layer lr: higher for more quant-damaged MLP proj, lower for fc + proj_params = [p for n, p in model.named_parameters() + if "mlp.proj" in n and p.requires_grad and id(p) not in frozen] + fc_params = [p for n, p in model.named_parameters() + if "mlp.fc" in n and p.requires_grad and id(p) not in frozen] + other_params = [p for p in model.parameters() + if p.requires_grad and id(p) not in frozen + and id(p) not in {id(q) for q in proj_params + fc_params}] + param_groups = [ + {"params": proj_params, "lr": args.ttt_lr * 3.0}, + {"params": fc_params, "lr": args.ttt_lr * 0.5}, + {"params": other_params, "lr": args.ttt_lr}, + ] + param_groups = [g for g in param_groups if g["params"]] + ttt_params = proj_params + fc_params + other_params + else: + ttt_params = [p for p in model.parameters() if p.requires_grad and id(p) not in frozen] + param_groups = [{"params": ttt_params, "lr": args.ttt_lr}] + + if rank == 0: + n_ttt = sum(p.numel() for p in ttt_params) + print(f"ttt:start params:{n_ttt} epochs:{args.ttt_epochs} lr:{args.ttt_lr} " + f"freeze:{args.ttt_freeze_blocks} optimizer:{args.ttt_optimizer} " + f"cosine:{args.ttt_cosine} warmup:{args.ttt_warmup_frac} perlayer:{args.ttt_perlayer}", flush=True) + + if args.ttt_optimizer == "adamw": + optimizer = torch.optim.AdamW(param_groups, weight_decay=0.0) + else: + optimizer = torch.optim.SGD(param_groups, momentum=args.ttt_momentum) + _ttt_run_phase( + model, ttt_params, optimizer, val_tokens, seq_len, batch_seqs, + epochs=args.ttt_epochs, max_steps=args.ttt_max_steps, + device=device, rank=rank, world_size=world_size, + phase_name="single", t0=t0, + cosine=args.ttt_cosine, warmup_frac=args.ttt_warmup_frac, + ) + del optimizer + + # Unfreeze all params for eval + for p in model.parameters(): + p.requires_grad_(True) + if rank == 0: + print(f"ttt:done elapsed={time.perf_counter()-t0:.1f}s", flush=True) + + +def reptile_ttt( + args: Hyperparameters, + model: nn.Module, + device: torch.device, + val_tokens: torch.Tensor, + rank: int = 0, +) -> None: + """Reptile meta-TTT: find weights that adapt fast to val distribution. + Runs after EMA/SWA, before standard TTT. Makes TTT ~10x more effective.""" + t0 = time.perf_counter() + seq_len = args.train_seq_len + n_tokens = val_tokens.numel() + + # Only adapt MLP params of last 1/4 blocks + num_blocks = len(model.blocks) + suffix_start = num_blocks - num_blocks // 4 + ttt_params = {} + for name, p in model.named_parameters(): + if any(f'blocks.{i}.' in name and '.mlp.' in name for i in range(suffix_start, num_blocks)): + ttt_params[name] = p + + base_state = {name: p.data.clone() for name, p in ttt_params.items()} + reptile_steps = 0 + + while (time.perf_counter() - t0) < args.reptile_budget_s: + # Save current params + saved = {name: p.data.clone() for name, p in ttt_params.items()} + + # Inner loop: N SGD steps on a random chunk + model.train() + start = random.randint(0, max(n_tokens - seq_len - 1, 0)) + chunk = val_tokens[start:start + seq_len + 1].to(device=device, dtype=torch.int64) + x = chunk[:-1].unsqueeze(0) + y = chunk[1:].unsqueeze(0) + + for inner_step in range(args.reptile_inner_steps): + model.zero_grad(set_to_none=True) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + loss = model(x, y) + loss.backward() + with torch.no_grad(): + for name, param in ttt_params.items(): + if param.grad is not None: + param.data -= args.reptile_inner_lr * param.grad + + # Outer loop: move base toward adapted params + with torch.no_grad(): + for name, param in ttt_params.items(): + base_state[name] += args.reptile_outer_lr * (param.data - base_state[name]) + param.data.copy_(base_state[name]) + + reptile_steps += 1 + + if rank == 0: + print(f"reptile_ttt:done steps:{reptile_steps} elapsed:{time.perf_counter()-t0:.1f}s", flush=True) + + +# ----------------------------- +# TRAINING +# ----------------------------- + +def main() -> None: + global zeropower_via_newtonschulz5 + + code = Path(__file__).read_text(encoding="utf-8") + args = Hyperparameters() + zeropower_via_newtonschulz5 = torch.compile(zeropower_via_newtonschulz5) + + # ----------------------------- + # DISTRIBUTED + CUDA SETUP + # ----------------------------- + + distributed = "RANK" in os.environ and "WORLD_SIZE" in os.environ + rank = int(os.environ.get("RANK", "0")) + world_size = int(os.environ.get("WORLD_SIZE", "1")) + local_rank = int(os.environ.get("LOCAL_RANK", "0")) + if world_size <= 0: + raise ValueError(f"WORLD_SIZE must be positive, got {world_size}") + if 8 % world_size != 0: + raise ValueError(f"WORLD_SIZE={world_size} must divide 8 so grad_accum_steps stays integral") + grad_accum_steps = 8 // world_size + grad_scale = 1.0 / grad_accum_steps + if not torch.cuda.is_available(): + raise RuntimeError("CUDA is required") + device = torch.device("cuda", local_rank) + torch.cuda.set_device(device) + if distributed: + dist.init_process_group(backend="nccl", device_id=device) + dist.barrier() + master_process = rank == 0 + + # Fast math knobs + torch.backends.cuda.matmul.allow_tf32 = True + torch.backends.cudnn.allow_tf32 = True + from torch.backends.cuda import enable_cudnn_sdp, enable_flash_sdp, enable_math_sdp, enable_mem_efficient_sdp + + enable_cudnn_sdp(False) + enable_flash_sdp(True) + enable_mem_efficient_sdp(False) + enable_math_sdp(False) + + # Set module-level flag so CausalSelfAttention.forward can use FA3. + global _use_fa3 + _use_fa3 = args.use_fa3 and _HAS_FA3 + + logfile = None + if master_process: + os.makedirs("logs", exist_ok=True) + logfile = f"logs/{args.run_id}.txt" + print(logfile) + + def log0(msg: str, console: bool = True) -> None: + if not master_process: + return + if console: + print(msg) + if logfile is not None: + with open(logfile, "a", encoding="utf-8") as f: + print(msg, file=f) + + log0(code, console=False) + log0("=" * 100, console=False) + log0(f"Running Python {sys.version}", console=False) + log0(f"Running PyTorch {torch.__version__}", console=False) + log0( + subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, check=False).stdout, + console=False, + ) + log0("=" * 100, console=False) + + # ----------------------------- + # TOKENIZER + VALIDATION METRIC SETUP + # ----------------------------- + + random.seed(args.seed) + np.random.seed(args.seed) + torch.manual_seed(args.seed) + torch.cuda.manual_seed_all(args.seed) + + if not args.tokenizer_path.endswith(".model"): + raise ValueError(f"Script only setup for SentencePiece .model file: {args.tokenizer_path}") + sp = spm.SentencePieceProcessor(model_file=args.tokenizer_path) + if int(sp.vocab_size()) != args.vocab_size: + raise ValueError( + f"VOCAB_SIZE={args.vocab_size} does not match tokenizer vocab_size={int(sp.vocab_size())}" + ) + dataset_dir = Path(args.data_path).resolve() + actual_train_files = len(list(dataset_dir.glob("fineweb_train_*.bin"))) + val_tokens = load_validation_tokens(args.val_files, args.eval_seq_len) + base_bytes_lut, has_leading_space_lut, is_boundary_token_lut = build_sentencepiece_luts( + sp, args.vocab_size, device + ) + log0(f"val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path={args.tokenizer_path}") + log0(f"train_loader:dataset:{dataset_dir.name} train_shards:{actual_train_files}") + log0(f"val_loader:shards pattern={args.val_files} tokens:{val_tokens.numel() - 1}") + + # ----------------------------- + # MODEL + OPTIMIZER SETUP + # ----------------------------- + + base_model = GPT( + vocab_size=args.vocab_size, + num_layers=args.num_layers, + model_dim=args.model_dim, + num_heads=args.num_heads, + num_kv_heads=args.num_kv_heads, + mlp_mult=args.mlp_mult, + tie_embeddings=args.tie_embeddings, + tied_embed_init_std=args.tied_embed_init_std, + logit_softcap=args.logit_softcap, + rope_base=args.rope_base, + qk_gain_init=args.qk_gain_init, + num_loops=args.num_loops, + lora_rank=args.lora_rank, + mlp_hidden=args.mlp_hidden, + smear_gate=args.smear_gate, + bigram_hash=args.bigram_hash, + bigram_hash_buckets=args.bigram_hash_buckets, + bigram_hash_dim=args.bigram_hash_dim, + ortho_init=args.ortho_init, + xsa_last_n=args.xsa_last_n, + ntk_base_seq_len=args.train_seq_len if args.eval_seq_len > args.train_seq_len else 0, + rope_dims=args.rope_dims, + ln_scale=args.ln_scale, + unet_skips=args.unet_skips, + ve_enabled=args.ve_enabled, + ve_dim=args.ve_dim, + ve_layers=args.ve_layers, + ).to(device).bfloat16() + if args._tier2: + log0(f"*** TIER2_MODE: proxy run max={args.max_wallclock_seconds:.0f}s iters={args.iterations} " + f"ema={args.ema_enabled} ttt={args.ttt_enabled} qat={args.qat} " + f"-- compare val_bpb@step2000 against baseline tier2 run ***") + xsa_layers = [i for i, b in enumerate(base_model.blocks) if b.attn.use_xsa] + log0(f"xsa_last_n:{args.xsa_last_n} active_layers:{xsa_layers}") + ntk_active = args.eval_seq_len > args.train_seq_len + log0(f"ntk_rope:{'enabled' if ntk_active else 'disabled'} train_seq_len:{args.train_seq_len} eval_seq_len:{args.eval_seq_len}") + head_dim = args.model_dim // args.num_heads + _rope_dims_active = args.rope_dims > 0 and args.rope_dims < head_dim + log0(f"partial_rope:{'enabled' if _rope_dims_active else 'disabled'} rope_dims:{args.rope_dims if _rope_dims_active else head_dim}/{head_dim} (ROPE_DIMS={args.rope_dims})") + log0(f"ln_scale:{'enabled' if args.ln_scale else 'disabled'} (scale RMSNorm output by 1/sqrt(layer_idx+1))") + log0(f"unet_skips:{'enabled' if args.unet_skips else 'disabled'} (U-Net skip connections, enc={base_model.num_encoder_layers} dec={base_model.num_decoder_layers})") + if args.ve_enabled: + log0(f"ve:enabled dim={args.ve_dim} layers={base_model.ve_layer_indices}") + else: + log0("ve:disabled") + + log0(f"smear_gate:{args.smear_gate} bigram_hash:{args.bigram_hash} swa:{args.swa} " + f"ortho_init:{args.ortho_init} late_k_fp16:{args.late_k_fp16} " + f"fa3:{_use_fa3}(available={_HAS_FA3}) muon_wd:{args.muon_wd} adam_wd:{args.adam_wd}") + + # FP16 tied embedding export: skip int8 quantization for tok_emb.weight at export time. + # Avoids compounding int8 errors through both input embedding and output projection. + if args.fp16_embed_export and args.tie_embeddings: + _FP16_EXPORT_NAMES.add("tok_emb.weight") + log0(f"fp16_embed_export:enabled (tok_emb.weight kept in fp16, ~{args.vocab_size * args.model_dim * 2 / 1024:.0f}KB)") + + # Late-K: keep K projections of last 2 layers in fp16 (not quantized). + # Saves per-query context accuracy where it matters most — near the output. + if args.late_k_fp16: + effective_depth = args.num_layers * args.num_loops + for layer_idx in range(effective_depth - 2, effective_depth): + block_idx = layer_idx % args.num_layers + key_name = f"blocks.{block_idx}.attn.c_k.weight" + _FP16_EXPORT_NAMES.add(key_name) + log0(f"late_k_fp16:enabled (last 2 effective layers' c_k.weight kept in fp16)") + + for module in base_model.modules(): + if isinstance(module, (CastedLinear, AttentionLoRA)): + module.float() + restore_low_dim_params_to_fp32(base_model) + log0(f"qat:{args.qat} (activates when lr_scale < 0.1; absmax int6 STE)") + compiled_model = torch.compile(base_model, dynamic=False, fullgraph=True) + model: nn.Module = DDP(compiled_model, device_ids=[local_rank], broadcast_buffers=False) if distributed else compiled_model + + # Optimizer split: + # - token embedding (Adam) uses EMBED_LR + # - untied lm_head (Adam) uses HEAD_LR + # - matrix params in transformer blocks use MATRIX_LR via Muon + # - vectors/scalars use SCALAR_LR via Adam + block_named_params = list(base_model.blocks.named_parameters()) + matrix_params = [ + p + for name, p in block_named_params + if p.ndim == 2 and not any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) + ] + # bigram_hash.proj is a dense 2D projection — Muon is appropriate + if base_model.bigram_hash is not None: + matrix_params.append(base_model.bigram_hash.proj.weight) + scalar_params = [ + p + for name, p in block_named_params + if p.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) + ] + if base_model.skip_weights is not None: + scalar_params.append(base_model.skip_weights) + # smear_gate.gate is now a single nn.Parameter(dim) — AdamW at scalar_lr + if base_model.smear_gate is not None: + scalar_params.append(base_model.smear_gate.gate) + # bigram_hash.scale is a learned scalar — AdamW at scalar_lr + if base_model.bigram_hash is not None: + scalar_params.append(base_model.bigram_hash.scale) + # VE: scales go to scalar, proj to matrix, embed to tok group + if base_model.ve_shared is not None: + scalar_params.append(base_model.ve_shared.scale) + if base_model.ve_shared.proj is not None: + matrix_params.append(base_model.ve_shared.proj.weight) + if base_model.ve_layer_scales is not None: + for vs in base_model.ve_layer_scales: + scalar_params.append(vs) + token_lr = args.tied_embed_lr if args.tie_embeddings else args.embed_lr + # bigram_hash.embed is an embedding table — train with AdamW alongside tok_emb + embed_params = [base_model.tok_emb.weight] + if base_model.bigram_hash is not None: + embed_params.append(base_model.bigram_hash.embed.weight) + if base_model.ve_shared is not None: + embed_params.append(base_model.ve_shared.embed.weight) + optimizer_tok = torch.optim.AdamW( + [{"params": embed_params, "lr": token_lr, "base_lr": token_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + weight_decay=args.adam_wd, + fused=True, + ) + optimizer_muon = Muon( + matrix_params, + lr=args.matrix_lr, + momentum=args.muon_momentum, + backend_steps=args.muon_backend_steps, + weight_decay=args.muon_wd, + ) + for group in optimizer_muon.param_groups: + group["base_lr"] = args.matrix_lr + optimizer_scalar = torch.optim.AdamW( + [{"params": scalar_params, "lr": args.scalar_lr, "base_lr": args.scalar_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + weight_decay=args.adam_wd, + fused=True, + ) + optimizers: list[torch.optim.Optimizer] = [optimizer_tok, optimizer_muon, optimizer_scalar] + if base_model.lora_adapters is not None: + lora_params = list(base_model.lora_adapters.parameters()) + optimizer_lora = torch.optim.Adam( + [{"params": lora_params, "lr": args.lora_lr, "base_lr": args.lora_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=True, + ) + optimizers.append(optimizer_lora) + if base_model.lm_head is not None: + optimizer_head = torch.optim.Adam( + [{"params": [base_model.lm_head.weight], "lr": args.head_lr, "base_lr": args.head_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=True, + ) + optimizers.insert(1, optimizer_head) + + n_params = sum(p.numel() for p in base_model.parameters()) + n_lora = sum(p.numel() for p in base_model.lora_adapters.parameters()) if base_model.lora_adapters is not None else 0 + effective_depth = args.num_layers * args.num_loops + log0(f"model_params:{n_params} (unique_layers:{args.num_layers} loops:{args.num_loops} effective_depth:{effective_depth} lora_rank:{args.lora_rank} lora_params:{n_lora})") + log0(f"world_size:{world_size} grad_accum_steps:{grad_accum_steps}") + log0("sdp_backends:cudnn=False flash=True mem_efficient=False math=False") + log0(f"attention_mode:gqa num_heads:{args.num_heads} num_kv_heads:{args.num_kv_heads}") + log0( + f"tie_embeddings:{args.tie_embeddings} embed_lr:{token_lr} " + f"head_lr:{args.head_lr if base_model.lm_head is not None else 0.0} " + f"matrix_lr:{args.matrix_lr} scalar_lr:{args.scalar_lr}" + ) + log0( + f"train_batch_tokens:{args.train_batch_tokens} train_seq_len:{args.train_seq_len} " + f"iterations:{args.iterations} warmup_steps:{args.warmup_steps} " + f"max_wallclock_seconds:{args.max_wallclock_seconds:.3f}" + ) + log0(f"seed:{args.seed}") + + # ----------------------------- + # DATA LOADER & MODEL WARMUP + # ----------------------------- + + # Curriculum learning: sort shards by average document length (easy first). + curriculum_files: list[Path] | None = None + if args.curriculum: + train_file_list = [Path(p) for p in sorted(glob.glob(args.train_files))] + curriculum_files = sort_shards_by_doc_length(train_file_list) + log0(f"curriculum:enabled shards_sorted_by_doc_length ({len(curriculum_files)} shards)") + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device, + sorted_files=curriculum_files) + + def zero_grad_all() -> None: + for opt in optimizers: + opt.zero_grad(set_to_none=True) + + max_wallclock_ms = 1000.0 * args.max_wallclock_seconds if args.max_wallclock_seconds > 0 else None + + def lr_mul(step: int, elapsed_ms: float) -> float: + if args.warmdown_iters <= 0: + return 1.0 + if max_wallclock_ms is None: + warmdown_start = max(args.iterations - args.warmdown_iters, 0) + return max((args.iterations - step) / max(args.warmdown_iters, 1), 0.0) if warmdown_start <= step < args.iterations else 1.0 + step_ms = elapsed_ms / max(step, 1) + warmdown_ms = args.warmdown_iters * step_ms + remaining_ms = max(max_wallclock_ms - elapsed_ms, 0.0) + return remaining_ms / max(warmdown_ms, 1e-9) if remaining_ms <= warmdown_ms else 1.0 + + # Warmup primes the compiled forward/backward/optimizer paths, then we restore the + # initial weights/optimizer state so measured training starts from the true init. + if args.warmup_steps > 0: + initial_model_state = {name: tensor.detach().cpu().clone() for name, tensor in base_model.state_dict().items()} + initial_optimizer_states = [copy.deepcopy(opt.state_dict()) for opt in optimizers] + model.train() + for warmup_step in range(args.warmup_steps): + zero_grad_all() + for micro_step in range(grad_accum_steps): + if distributed: + model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 + x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + warmup_loss = model(x, y) + (warmup_loss * grad_scale).backward() + for opt in optimizers: + opt.step() + zero_grad_all() + if args.warmup_steps <= 20 or (warmup_step + 1) % 10 == 0 or warmup_step + 1 == args.warmup_steps: + log0(f"warmup_step:{warmup_step + 1}/{args.warmup_steps}") + # NOTE: QAT graph priming removed — it caused torch.compile to use a slower + # compilation path for the non-QAT forward pass (step_avg jumped from 44ms to 58ms). + # The one-time recompile when QAT activates (~30-90s) is cheaper than the cumulative + # overhead of a slower non-QAT path across thousands of steps. + + base_model.load_state_dict(initial_model_state, strict=True) + for opt, state in zip(optimizers, initial_optimizer_states, strict=True): + opt.load_state_dict(state) + zero_grad_all() + if distributed: + model.require_backward_grad_sync = True + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device, + sorted_files=curriculum_files) + + # ----------------------------- + # MAIN TRAINING LOOP + # ----------------------------- + + training_time_ms = 0.0 + qat_active = False + + # EMA: exponential moving average — smoother than SWA, better quantization compression. + ema_state: dict[str, Tensor] | None = None + if args.ema_enabled: + ema_state = {k: v.detach().float().clone() for k, v in base_model.state_dict().items()} + log0(f"ema:initialized decay={args.ema_decay}") + + # SWA: fallback if EMA disabled + swa_state: dict[str, Tensor] | None = None + swa_count = 0 + stop_after_step: int | None = None + torch.cuda.synchronize() + t0 = time.perf_counter() + + step = 0 + while True: + last_step = step == args.iterations or (stop_after_step is not None and step >= stop_after_step) + + should_validate = last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0) + if should_validate: + torch.cuda.synchronize() + training_time_ms += 1000.0 * (time.perf_counter() - t0) + val_loss, val_bpb = eval_val( + args, + model, + rank, + world_size, + device, + grad_accum_steps, + val_tokens, + base_bytes_lut, + has_leading_space_lut, + is_boundary_token_lut, + ) + log0( + f"step:{step}/{args.iterations} val_loss:{val_loss:.4f} val_bpb:{val_bpb:.4f} " + f"train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms / max(step, 1):.2f}ms" + ) + torch.cuda.synchronize() + t0 = time.perf_counter() + + if last_step: + if stop_after_step is not None and step < args.iterations: + log0( + f"stopping_early: wallclock_cap train_time:{training_time_ms:.0f}ms " + f"step:{step}/{args.iterations}" + ) + break + + elapsed_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + scale = lr_mul(step, elapsed_ms) + zero_grad_all() + train_loss = torch.zeros((), device=device) + for micro_step in range(grad_accum_steps): + if distributed: + model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 + x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + loss = model(x, y) + train_loss += loss.detach() + (loss * grad_scale).backward() + train_loss /= grad_accum_steps + + frac = min(step / args.muon_momentum_warmup_steps, 1.0) if args.muon_momentum_warmup_steps > 0 else 1.0 + muon_momentum = (1 - frac) * args.muon_momentum_warmup_start + frac * args.muon_momentum + for group in optimizer_muon.param_groups: + group["momentum"] = muon_momentum + + for opt in optimizers: + for group in opt.param_groups: + group["lr"] = group["base_lr"] * scale + + if args.grad_clip_norm > 0: + torch.nn.utils.clip_grad_norm_(base_model.parameters(), args.grad_clip_norm) + for opt in optimizers: + opt.step() + zero_grad_all() + + step += 1 + + # EMA: update every step during training + if ema_state is not None: + d = args.ema_decay + with torch.no_grad(): + for k, v in base_model.state_dict().items(): + ema_state[k].mul_(d).add_(v.detach().float(), alpha=1.0 - d) + + # LR-scale-based QAT activation: activate when lr_scale < 0.1 (last ~10% of warmdown, + # ~300 steps). Zero overhead for 90%+ of training; absmax scale makes per-step cost minimal. + if args.qat and not qat_active: + if scale < 0.1: + qat_active = True + for module in base_model.modules(): + if isinstance(module, CastedLinear): + module._qat = True + elapsed_s = training_time_ms / 1000.0 + log0(f"qat_activated step:{step}/{args.iterations} lr_scale:{scale:.4f} elapsed:{elapsed_s:.1f}s remaining:{args.max_wallclock_seconds - elapsed_s:.1f}s") + + # SWA: accumulate weight averages during warmdown for smoother quantization. + # Accumulate in float32 to avoid bf16 precision loss over thousands of additions. + # Sample every 200 steps for sufficient checkpoint diversity. + if args.swa and step >= int(args.iterations * args.swa_start_frac) and step % 200 == 0: + if swa_state is None: + swa_state = {k: v.detach().float().clone() for k, v in base_model.state_dict().items()} + swa_count = 1 + log0(f"swa_started step:{step}") + else: + swa_count += 1 + for k, v in base_model.state_dict().items(): + swa_state[k] += v.detach().float() + + approx_training_time_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + should_log_train = ( + args.train_log_every > 0 + and (step <= 10 or step % args.train_log_every == 0 or stop_after_step is not None) + ) + if should_log_train: + log0( + f"step:{step}/{args.iterations} train_loss:{train_loss.item():.4f} " + f"train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms / step:.2f}ms" + ) + + # Needed to sync whether we've reached the wallclock cap. + reached_cap = max_wallclock_ms is not None and approx_training_time_ms >= max_wallclock_ms + if distributed and max_wallclock_ms is not None: + reached_cap_tensor = torch.tensor(int(reached_cap), device=device) + dist.all_reduce(reached_cap_tensor, op=dist.ReduceOp.MAX) + reached_cap = bool(reached_cap_tensor.item()) + if stop_after_step is None and reached_cap: + stop_after_step = step + + log0( + f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB " + f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB" + ) + + # ----------------------------- + # SERIALIZATION + ROUNDTRIP VALIDATION + # ----------------------------- + # Save the raw state (useful for debugging/loading in PyTorch directly), then always produce + # the compressed int8+zlib artifact and validate the round-tripped weights. + + # Apply EMA weights (preferred) or SWA fallback before serialization. + if ema_state is not None: + log0("ema:applying EMA weights") + current_state = base_model.state_dict() + base_model.load_state_dict( + {k: v.to(dtype=current_state[k].dtype) for k, v in ema_state.items()}, strict=True + ) + del ema_state + elif args.swa and swa_state is not None and swa_count > 1: + log0(f"swa_applied count:{swa_count}") + current_state = base_model.state_dict() + avg_state = {} + for k, v in swa_state.items(): + avg = v / swa_count + avg_state[k] = avg.to(dtype=current_state[k].dtype) + base_model.load_state_dict(avg_state, strict=True) + + # Reptile meta-TTT: makes subsequent TTT ~10x more effective by finding weights that adapt fast. + if args.reptile_enabled: + log0(f"reptile_ttt:start budget={args.reptile_budget_s:.0f}s inner_steps={args.reptile_inner_steps} inner_lr={args.reptile_inner_lr} outer_lr={args.reptile_outer_lr}") + reptile_ttt(args, base_model, device, val_tokens, rank=rank) + + # TTT: adapt to val distribution before eval + if args.ttt_enabled: + if args.ttt_two_phase: + log0(f"ttt:start two_phase p1_epochs={args.ttt_p1_epochs} p1_lr={args.ttt_p1_lr} " + f"p2_epochs={args.ttt_p2_epochs} p2_lr={args.ttt_p2_lr} p2_blocks={args.ttt_p2_unfreeze_blocks} " + f"batch_seqs={args.ttt_batch_seqs}") + else: + log0(f"ttt:start lr={args.ttt_lr} epochs={args.ttt_epochs} freeze_blocks={args.ttt_freeze_blocks} " + f"batch_seqs={args.ttt_batch_seqs}") + ttt_adapt(args, base_model, device, val_tokens, rank=rank, world_size=world_size) + + # Magnitude pruning: zero the smallest weights for better zstd compression. + # Zeroed weights compress to nearly nothing. Applied after TTT, before serialization. + if args.prune_pct > 0.0: + pruned_count = 0 + total_count = 0 + with torch.no_grad(): + for name, p in base_model.named_parameters(): + if p.ndim >= 2 and p.numel() >= 65536: + threshold = torch.quantile(p.abs().float().flatten(), args.prune_pct / 100.0) + mask = p.abs() > threshold + pruned_count += (~mask).sum().item() + total_count += p.numel() + p.mul_(mask) + log0(f"prune:{args.prune_pct:.1f}% zeroed {pruned_count}/{total_count} weights ({100*pruned_count/max(total_count,1):.1f}%)") + + if master_process: + torch.save(base_model.state_dict(), "final_model.pt") + model_bytes = os.path.getsize("final_model.pt") + code_bytes = len(code.encode("utf-8")) + log0(f"Serialized model: {model_bytes} bytes") + log0(f"Code size: {code_bytes} bytes") + log0(f"Total submission size: {model_bytes + code_bytes} bytes") + + global _gptq_lite + _gptq_lite = args.gptq_lite + log0(f"quantization: {args.quant_bits}-bit gptq_lite:{_gptq_lite}") + quant_obj, quant_stats = quantize_state_dict_int8(base_model.state_dict(), bits=args.quant_bits) + quant_buf = io.BytesIO() + torch.save(quant_obj, quant_buf) + quant_raw = quant_buf.getvalue() + if args.use_zstd and _HAS_ZSTD: + cctx = zstd.ZstdCompressor(level=22) + quant_blob = cctx.compress(quant_raw) + compress_method = "zstd-22" + else: + quant_blob = zlib.compress(quant_raw, level=9) + compress_method = "zlib-9" + quant_raw_bytes = len(quant_raw) + if master_process: + with open("final_model.int8.ptz", "wb") as f: + f.write(quant_blob) + quant_file_bytes = os.path.getsize("final_model.int8.ptz") + code_bytes = len(code.encode("utf-8")) + ratio = quant_stats["baseline_tensor_bytes"] / max(quant_stats["int8_payload_bytes"], 1) + log0( + f"Serialized model int8+{compress_method}: {quant_file_bytes} bytes " + f"(payload:{quant_stats['int8_payload_bytes']} raw_torch:{quant_raw_bytes} payload_ratio:{ratio:.2f}x)" + ) + log0(f"Total submission size int8+{compress_method}: {quant_file_bytes + code_bytes} bytes") + + if distributed: + dist.barrier() + with open("final_model.int8.ptz", "rb") as f: + quant_blob_disk = f.read() + # Decompress with the same method used for compression. + if args.use_zstd and _HAS_ZSTD: + dctx = zstd.ZstdDecompressor() + quant_raw_disk = dctx.decompress(quant_blob_disk) + else: + quant_raw_disk = zlib.decompress(quant_blob_disk) + quant_state = torch.load(io.BytesIO(quant_raw_disk), map_location="cpu") + base_model.load_state_dict(dequantize_state_dict_int8(quant_state), strict=True) + torch.cuda.synchronize() + t_qeval = time.perf_counter() + if args.eval_stride > 0 and args.eval_stride < args.eval_seq_len: + log0(f"final_eval_mode:sliding_window stride:{args.eval_stride} batch_seqs:{args.eval_batch_seqs} doc_isolated:{args.doc_isolated_eval}") + q_val_loss, q_val_bpb = eval_val_sliding( + args, + base_model, + rank, + world_size, + device, + val_tokens, + base_bytes_lut, + has_leading_space_lut, + is_boundary_token_lut, + stride=args.eval_stride, + batch_seqs=args.eval_batch_seqs, + ) + else: + log0("final_eval_mode:standard") + q_val_loss, q_val_bpb = eval_val( + args, + model, + rank, + world_size, + device, + grad_accum_steps, + val_tokens, + base_bytes_lut, + has_leading_space_lut, + is_boundary_token_lut, + ) + torch.cuda.synchronize() + log0( + f"final_int8_zlib_roundtrip val_loss:{q_val_loss:.4f} val_bpb:{q_val_bpb:.4f} " + f"eval_time:{1000.0 * (time.perf_counter() - t_qeval):.0f}ms" + ) + log0(f"final_int8_zlib_roundtrip_exact val_loss:{q_val_loss:.8f} val_bpb:{q_val_bpb:.8f}") + + if distributed: + dist.destroy_process_group() + + +if __name__ == "__main__": + main() diff --git a/records/track_10min_16mb/2026-03-22_CosineTTT_PerLayer/train_seed1337.log b/records/track_10min_16mb/2026-03-22_CosineTTT_PerLayer/train_seed1337.log new file mode 100644 index 000000000..8179cb585 --- /dev/null +++ b/records/track_10min_16mb/2026-03-22_CosineTTT_PerLayer/train_seed1337.log @@ -0,0 +1,2430 @@ +""" +The `train_gpt.py` and `train_gpt_mlx.py` scripts are intended as good launching-off points for new participants, not SOTA configs. We'll accept PRs that tune, improve, or simplify these scripts without significantly increasing complexity, but competitive submissions should stay in the `/records` folder. + +Hard stop: To keep readable for newcomers, let's make sure `train_gpt.py` and `train_gpt_mlx.py` never are longer than 1500 lines. +""" + +from __future__ import annotations + +import copy +import glob +import io +import math +import os +import random +import subprocess +import sys +import time +import uuid +import zlib +try: + import zstandard as zstd + _HAS_ZSTD = True +except ImportError: + _HAS_ZSTD = False +from pathlib import Path + +import numpy as np +import sentencepiece as spm +import torch +import torch.distributed as dist +import torch.nn.functional as F +from torch import Tensor, nn +from torch.nn.parallel import DistributedDataParallel as DDP + +try: + from flash_attn_interface import flash_attn_func as _flash_attn_3_func + _HAS_FA3 = True +except ImportError: + try: + from flash_attn import flash_attn_func as _flash_attn_3_func + _HAS_FA3 = True # FA2 fallback — same input format, slightly slower than FA3 + except ImportError: + _HAS_FA3 = False +_use_fa3: bool = False # set at runtime after args are parsed + +# ----------------------------- +# HYPERPARAMETERS +# ----------------------------- +# Default Simple Baseline run: +# - 9 transformer blocks at width 512 +# - 8 attention heads with 4 KV heads (GQA) and 2x MLP expansion +# - vocab size 1024, sequence length 1024, tied embeddings +# - 524,288 train tokens per step for 20,000 iterations with a ~10 minute cap + +class Hyperparameters: + # Data paths are shard globs produced by the existing preprocessing pipeline. + data_path = os.environ.get("DATA_PATH", "./data/datasets/fineweb10B_sp1024") + train_files = os.path.join(data_path, "fineweb_train_*.bin") + val_files = os.path.join(data_path, "fineweb_val_*.bin") + tokenizer_path = os.environ.get("TOKENIZER_PATH", "./data/tokenizers/fineweb_1024_bpe.model") + run_id = os.environ.get("RUN_ID", str(uuid.uuid4())) + seed = int(os.environ.get("SEED", 1337)) + + # Validation cadence and batch size. Validation always uses the full fineweb_val split. + val_batch_size = int(os.environ.get("VAL_BATCH_SIZE", 524_288)) + val_loss_every = int(os.environ.get("VAL_LOSS_EVERY", 1000)) + train_log_every = int(os.environ.get("TRAIN_LOG_EVERY", 200)) + + # Training length. + iterations = int(os.environ.get("ITERATIONS", 20000)) + warmdown_iters = int(os.environ.get("WARMDOWN_ITERS", 3000)) + warmup_steps = int(os.environ.get("WARMUP_STEPS", 20)) + train_batch_tokens = int(os.environ.get("TRAIN_BATCH_TOKENS", 524_288)) + train_seq_len = int(os.environ.get("TRAIN_SEQ_LEN", 1024)) + eval_seq_len = int(os.environ.get("EVAL_SEQ_LEN", 2048)) # eval window; NTK-RoPE scales if > train_seq_len + max_wallclock_seconds = float(os.environ.get("MAX_WALLCLOCK_SECONDS", 600.0)) + qk_gain_init = float(os.environ.get("QK_GAIN_INIT", 1.5)) + + # ── Tier-2 proxy mode ──────────────────────────────────────────────────── + # TIER2_MODE=1 overrides key settings for a fast 3-minute proxy run. + # Use this to validate architectural features before committing a full run. + # Schedule-dependent features (EMA, TTT, SWA) are disabled — they can only + # be evaluated meaningfully at full training duration. + # Compare val_bpb at step ~2000 against a baseline TIER2_MODE=1 run. + _tier2 = bool(int(os.environ.get("TIER2_MODE", "0"))) + if _tier2: + max_wallclock_seconds = float(os.environ.get("MAX_WALLCLOCK_SECONDS", 180.0)) + iterations = int(os.environ.get("ITERATIONS", 3000)) + warmdown_iters = int(os.environ.get("WARMDOWN_ITERS", 500)) + val_loss_every = int(os.environ.get("VAL_LOSS_EVERY", 500)) + eval_stride = int(os.environ.get("EVAL_STRIDE", 256)) # faster final eval + + # Model shape. + vocab_size = int(os.environ.get("VOCAB_SIZE", 1024)) + num_layers = int(os.environ.get("NUM_LAYERS", 11)) + num_kv_heads = int(os.environ.get("NUM_KV_HEADS", 4)) + model_dim = int(os.environ.get("MODEL_DIM", 512)) + num_heads = int(os.environ.get("NUM_HEADS", 8)) + mlp_mult = int(os.environ.get("MLP_MULT", 3)) + mlp_hidden = int(os.environ.get("MLP_HIDDEN", 0)) # 0 = use mlp_mult * model_dim + fp16_embed_export = bool(int(os.environ.get("FP16_EMBED_EXPORT", "1"))) # keep tok_emb in fp16 at export + num_loops = int(os.environ.get("NUM_LOOPS", 1)) + lora_rank = int(os.environ.get("LORA_RANK", 0)) + qat = bool(int(os.environ.get("QAT", "1"))) + qat_min_seconds = float(os.environ.get("QAT_MIN_SECONDS", 120.0)) # guarantee QAT runs for at least this long + eval_stride = int(os.environ.get("EVAL_STRIDE", 64)) + eval_batch_seqs = int(os.environ.get("EVAL_BATCH_SEQS", 32)) + doc_isolated_eval = bool(int(os.environ.get("DOC_ISOLATED_EVAL", "1"))) # eval per-document, no cross-doc context + smear_gate = bool(int(os.environ.get("SMEAR_GATE", "1"))) # cheap bigram context at embedding layer + bigram_hash = bool(int(os.environ.get("BIGRAM_HASH", "1"))) # hash-based bigram embedding + bigram_hash_buckets = int(os.environ.get("BIGRAM_HASH_BUCKETS", 2048)) + bigram_hash_dim = int(os.environ.get("BIGRAM_HASH_DIM", 128)) + swa = bool(int(os.environ.get("SWA", "0"))) # stochastic weight averaging (disabled: EMA preferred) + swa_start_frac = float(os.environ.get("SWA_START_FRAC", 0.5)) + xsa_last_n = int(os.environ.get("XSA_LAST_N", 4)) # XSA on last N layers (0=disabled) + ema_enabled = bool(int(os.environ.get("EMA_ENABLED", "1"))) # exponential moving average (replaces SWA) + ema_decay = float(os.environ.get("EMA_DECAY", 0.997)) # EMA decay per step + ttt_enabled = bool(int(os.environ.get("TTT_ENABLED", "1"))) # test-time training on val data + ttt_lr = float(os.environ.get("TTT_LR", 0.0005)) # AdamW lr (PR #442: AdamW beats SGD by 0.019 BPB) + ttt_epochs = int(os.environ.get("TTT_EPOCHS", 30)) # 30ep cosine beats 10ep flat by 16% + ttt_momentum = float(os.environ.get("TTT_MOMENTUM", 0.9)) # only used if TTT_OPTIMIZER=sgd + ttt_freeze_blocks = int(os.environ.get("TTT_FREEZE_BLOCKS", 0)) # 0 = all blocks unfrozen (PR #398) + ttt_optimizer = os.environ.get("TTT_OPTIMIZER", "adamw") # "adamw" or "sgd" + ttt_max_steps = int(os.environ.get("TTT_MAX_STEPS", 300)) # cap steps per epoch (~10s per epoch) + ttt_batch_seqs = int(os.environ.get("TTT_BATCH_SEQS", 64)) # seqs per GPU per step (64*8=512 total) + ttt_cosine = bool(int(os.environ.get("TTT_COSINE", "1"))) # cosine lr decay during TTT (+16% over flat) + ttt_warmup_frac = float(os.environ.get("TTT_WARMUP_FRAC", 0.0)) # linear warmup fraction (0.1 = 10%) + ttt_perlayer = bool(int(os.environ.get("TTT_PERLAYER", "1"))) # per-layer lr (3x proj, 0.5x fc) — +23.5% combined with cosine + # Two-phase TTT (matches PR #415/#417 approach) + ttt_two_phase = bool(int(os.environ.get("TTT_TWO_PHASE", "0"))) # enable two-phase TTT + ttt_p1_epochs = int(os.environ.get("TTT_P1_EPOCHS", 50)) # phase 1: norm-only recalibration + ttt_p1_lr = float(os.environ.get("TTT_P1_LR", 0.01)) # phase 1: Adam lr + ttt_p2_epochs = int(os.environ.get("TTT_P2_EPOCHS", 10)) # phase 2: selective block adaptation + ttt_p2_lr = float(os.environ.get("TTT_P2_LR", 0.005)) # phase 2: SGD lr + ttt_p2_unfreeze_blocks = int(os.environ.get("TTT_P2_UNFREEZE_BLOCKS", 3)) # phase 2: unfreeze last N blocks + use_zstd = bool(int(os.environ.get("USE_ZSTD", "1"))) # use zstd instead of zlib for compression + curriculum = bool(int(os.environ.get("CURRICULUM", "0"))) # sort training shards by doc length (easy first) + quant_bits = int(os.environ.get("QUANT_BITS", 6)) # 8=int8, 6=int6 (int6 fits ~3x more params in 16MB) + tie_embeddings = bool(int(os.environ.get("TIE_EMBEDDINGS", "1"))) + rope_base = float(os.environ.get("ROPE_BASE", 50000.0)) + logit_softcap = float(os.environ.get("LOGIT_SOFTCAP", 30.0)) + + # Optimizer hyperparameters. + embed_lr = float(os.environ.get("EMBED_LR", 0.6)) + head_lr = float(os.environ.get("HEAD_LR", 0.008)) + tied_embed_lr = float(os.environ.get("TIED_EMBED_LR", 0.035)) + tied_embed_init_std = float(os.environ.get("TIED_EMBED_INIT_STD", 0.005)) + matrix_lr = float(os.environ.get("MATRIX_LR", 0.025)) + scalar_lr = float(os.environ.get("SCALAR_LR", 0.025)) + muon_momentum = float(os.environ.get("MUON_MOMENTUM", 0.99)) + muon_backend_steps = int(os.environ.get("MUON_BACKEND_STEPS", 5)) + muon_momentum_warmup_start = float(os.environ.get("MUON_MOMENTUM_WARMUP_START", 0.92)) + muon_momentum_warmup_steps = int(os.environ.get("MUON_MOMENTUM_WARMUP_STEPS", 1500)) + beta1 = float(os.environ.get("BETA1", 0.9)) + beta2 = float(os.environ.get("BETA2", 0.95)) + adam_eps = float(os.environ.get("ADAM_EPS", 1e-8)) + lora_lr = float(os.environ.get("LORA_LR", 0.01)) + grad_clip_norm = float(os.environ.get("GRAD_CLIP_NORM", 0.3)) + muon_wd = float(os.environ.get("MUON_WD", 0.04)) + adam_wd = float(os.environ.get("ADAM_WD", 0.04)) + ortho_init = bool(int(os.environ.get("ORTHO_INIT", "1"))) + late_k_fp16 = bool(int(os.environ.get("LATE_K_FP16", "1"))) + use_fa3 = bool(int(os.environ.get("USE_FA3", "1"))) + rope_dims = int(os.environ.get("ROPE_DIMS", 0)) # 0 = full RoPE; >0 = apply RoPE to only first N dims per head + ln_scale = bool(int(os.environ.get("LN_SCALE", "0"))) # scale block output by 1/sqrt(layer_idx+1) + unet_skips = bool(int(os.environ.get("UNET_SKIPS", "0"))) # U-Net style skip connections + prune_pct = float(os.environ.get("PRUNE_PCT", 0.0)) # magnitude pruning: zero smallest N% of weights before compression + gptq_lite = bool(int(os.environ.get("GPTQ_LITE", "1"))) # per-row optimal clip search at quantization (default ON, zero training cost) + reptile_enabled = bool(int(os.environ.get("REPTILE_TTT", "0"))) # Reptile meta-TTT before standard TTT + reptile_budget_s = float(os.environ.get("REPTILE_BUDGET_S", 60.0)) # seconds for Reptile meta-learning + reptile_inner_steps = int(os.environ.get("REPTILE_INNER_STEPS", 3)) # SGD steps per inner loop + reptile_inner_lr = float(os.environ.get("REPTILE_INNER_LR", 0.1)) # inner SGD learning rate + reptile_outer_lr = float(os.environ.get("REPTILE_OUTER_LR", 0.01)) # outer interpolation rate + ve_enabled = bool(int(os.environ.get("VE_ENABLED", "0"))) # shared value embedding + ve_dim = int(os.environ.get("VE_DIM", 128)) # value embedding dimension + ve_layers = os.environ.get("VE_LAYERS", "9,10") # comma-separated layer indices + + # Disable schedule-dependent features in TIER2_MODE unless explicitly overridden + if _tier2: + qat = bool(int(os.environ.get("QAT", "0"))) + ema_enabled = bool(int(os.environ.get("EMA_ENABLED", "0"))) + ttt_enabled = bool(int(os.environ.get("TTT_ENABLED", "0"))) + swa = bool(int(os.environ.get("SWA", "0"))) + +# ----------------------------- +# MUON OPTIMIZER +# ----------------------------- +# +# As borrowed from modded-nanogpt +# Background on Muon: https://kellerjordan.github.io/posts/muon/ + +def zeropower_via_newtonschulz5(G: Tensor, steps: int = 10, eps: float = 1e-7) -> Tensor: + # Orthogonalize a 2D update matrix with a fast Newton-Schulz iteration. + # Muon uses this to normalize matrix-shaped gradients before applying them. + a, b, c = (3.4445, -4.7750, 2.0315) + X = G.bfloat16() + X /= X.norm() + eps + transposed = G.size(0) > G.size(1) + if transposed: + X = X.T + for _ in range(steps): + A = X @ X.T + B = b * A + c * A @ A + X = a * X + B @ X + return X.T if transposed else X + + +class Muon(torch.optim.Optimizer): + def __init__(self, params, lr: float, momentum: float, backend_steps: int, + nesterov: bool = True, weight_decay: float = 0.0): + super().__init__( + params, + dict(lr=lr, momentum=momentum, backend_steps=backend_steps, + nesterov=nesterov, weight_decay=weight_decay), + ) + + @torch.no_grad() + def step(self, closure=None): + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + distributed = dist.is_available() and dist.is_initialized() + world_size = dist.get_world_size() if distributed else 1 + rank = dist.get_rank() if distributed else 0 + + for group in self.param_groups: + params = group["params"] + if not params: + continue + lr = group["lr"] + momentum = group["momentum"] + backend_steps = group["backend_steps"] + nesterov = group["nesterov"] + wd = group.get("weight_decay", 0.0) + + # Cache flat buffer to avoid per-step allocation + if "_updates_flat" not in group: + group["_total_params"] = sum(int(p.numel()) for p in params) + group["_updates_flat"] = torch.zeros(group["_total_params"], device=params[0].device, dtype=torch.bfloat16) + total_params = group["_total_params"] + updates_flat = group["_updates_flat"] + updates_flat.zero_() + + curr = 0 + for i, p in enumerate(params): + if i % world_size == rank and p.grad is not None: + g = p.grad + state = self.state[p] + if "momentum_buffer" not in state: + state["momentum_buffer"] = torch.zeros_like(g) + buf = state["momentum_buffer"] + buf.mul_(momentum).add_(g) + if nesterov: + g = g.add(buf, alpha=momentum) + g = zeropower_via_newtonschulz5(g, steps=backend_steps) + # Scale correction from Muon reference implementations. + g *= max(1, g.size(0) / g.size(1)) ** 0.5 + updates_flat[curr : curr + p.numel()] = g.reshape(-1) + curr += p.numel() + + if distributed: + dist.all_reduce(updates_flat, op=dist.ReduceOp.SUM) + + curr = 0 + for p in params: + g = updates_flat[curr : curr + p.numel()].view_as(p).to(dtype=p.dtype) + p.add_(g, alpha=-lr) + if wd > 0.0: + p.mul_(1.0 - wd * lr) + curr += p.numel() + + return loss + + +# ----------------------------- +# TOKENIZER-AGNOSTIC EVALUATION SETUP +# ----------------------------- +# +# It's common for small models have a large fraction of their parameters be embeddings, since the 2 * d_model * d_vocab vectors can be gigantic. +# Instead of locking the tokenizer, we let you bring your own and calculate our validation metrics on the average compression of the validation set. +# We calculate BPB (bits-per-byte) instead of validation loss, so we need methods to count the number of bits per token in the tokenizer. +# Note: Submissions that edit the tokenizer will be examined more carefully, since screwing this up might unjustly improve your score. + +def build_sentencepiece_luts( + sp: spm.SentencePieceProcessor, vocab_size: int, device: torch.device +) -> tuple[Tensor, Tensor, Tensor]: + sp_vocab_size = int(sp.vocab_size()) + table_size = max(sp_vocab_size, vocab_size) + base_bytes_np = np.zeros((table_size,), dtype=np.int16) + has_leading_space_np = np.zeros((table_size,), dtype=np.bool_) + is_boundary_token_np = np.ones((table_size,), dtype=np.bool_) + for token_id in range(sp_vocab_size): + if sp.is_control(token_id) or sp.is_unknown(token_id) or sp.is_unused(token_id): + continue + is_boundary_token_np[token_id] = False + if sp.is_byte(token_id): + base_bytes_np[token_id] = 1 + continue + piece = sp.id_to_piece(token_id) + if piece.startswith("▁"): + has_leading_space_np[token_id] = True + piece = piece[1:] + base_bytes_np[token_id] = len(piece.encode("utf-8")) + return ( + torch.tensor(base_bytes_np, dtype=torch.int16, device=device), + torch.tensor(has_leading_space_np, dtype=torch.bool, device=device), + torch.tensor(is_boundary_token_np, dtype=torch.bool, device=device), + ) + + +def load_validation_tokens(pattern: str, seq_len: int) -> Tensor: + files = [Path(p) for p in sorted(glob.glob(pattern))] + if not files: + raise FileNotFoundError(f"No files found for pattern: {pattern}") + # The export pipeline writes the fixed first-50k-doc validation set to fineweb_val_*. + tokens = torch.cat([load_data_shard(file) for file in files]).contiguous() + usable = ((tokens.numel() - 1) // seq_len) * seq_len + if usable <= 0: + raise ValueError(f"Validation split is too short for TRAIN_SEQ_LEN={seq_len}") + return tokens[: usable + 1] + + +def eval_val( + args: Hyperparameters, + model: nn.Module, + rank: int, + world_size: int, + device: torch.device, + grad_accum_steps: int, + val_tokens: Tensor, + base_bytes_lut: Tensor, + has_leading_space_lut: Tensor, + is_boundary_token_lut: Tensor, +) -> tuple[float, float]: + # Validation computes two metrics: + # - val_loss: token cross-entropy (natural log) + # - val_bpb: tokenizer-agnostic compression metric used by the challenge + local_batch_tokens = args.val_batch_size // (world_size * grad_accum_steps) + if local_batch_tokens < args.train_seq_len: + raise ValueError( + "VAL_BATCH_SIZE must provide at least one sequence per rank; " + f"got VAL_BATCH_SIZE={args.val_batch_size}, WORLD_SIZE={world_size}, " + f"GRAD_ACCUM_STEPS={grad_accum_steps}, TRAIN_SEQ_LEN={args.train_seq_len}" + ) + local_batch_seqs = local_batch_tokens // args.train_seq_len + total_seqs = (val_tokens.numel() - 1) // args.train_seq_len + seq_start = (total_seqs * rank) // world_size + seq_end = (total_seqs * (rank + 1)) // world_size + val_loss_sum = torch.zeros((), device=device, dtype=torch.float64) + val_token_count = torch.zeros((), device=device, dtype=torch.float64) + val_byte_count = torch.zeros((), device=device, dtype=torch.float64) + + model.eval() + with torch.inference_mode(): + for batch_seq_start in range(seq_start, seq_end, local_batch_seqs): + batch_seq_end = min(batch_seq_start + local_batch_seqs, seq_end) + raw_start = batch_seq_start * args.train_seq_len + raw_end = batch_seq_end * args.train_seq_len + 1 + local = val_tokens[raw_start:raw_end].to(device=device, dtype=torch.int64, non_blocking=True) + x = local[:-1].reshape(-1, args.train_seq_len) + y = local[1:].reshape(-1, args.train_seq_len) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + batch_loss = model(x, y).detach() + batch_token_count = float(y.numel()) + val_loss_sum += batch_loss.to(torch.float64) * batch_token_count + val_token_count += batch_token_count + prev_ids = x.reshape(-1) + tgt_ids = y.reshape(-1) + token_bytes = base_bytes_lut[tgt_ids].to(dtype=torch.int16) + token_bytes += (has_leading_space_lut[tgt_ids] & ~is_boundary_token_lut[prev_ids]).to(dtype=torch.int16) + val_byte_count += token_bytes.to(torch.float64).sum() + + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(val_loss_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(val_token_count, op=dist.ReduceOp.SUM) + dist.all_reduce(val_byte_count, op=dist.ReduceOp.SUM) + + val_loss = val_loss_sum / val_token_count + bits_per_token = val_loss.item() / math.log(2.0) + tokens_per_byte = val_token_count.item() / val_byte_count.item() + model.train() + return float(val_loss.item()), float(bits_per_token * tokens_per_byte) + +# ----------------------------- +# POST-TRAINING QUANTIZATION +# ----------------------------- +# +# It's silly to export our model, which is trained in bf16 and fp32, at that same precision. +# Instead, we get approximately the same model (with a small hit) by quantizing the model to int8 & zlib compressing. +# We can then decompress the model and run in higher precision for evaluation, after closing in under the size limit. + +CONTROL_TENSOR_NAME_PATTERNS = tuple( + pattern + for pattern in os.environ.get( + "CONTROL_TENSOR_NAME_PATTERNS", + "attn_scale,attn_scales,mlp_scale,mlp_scales,resid_mix,resid_mixes,q_gain,skip_weight,skip_weights", + ).split(",") + if pattern +) +INT8_KEEP_FLOAT_FP32_NAME_PATTERNS = tuple( + pattern + for pattern in os.environ.get( + "INT8_KEEP_FLOAT_FP32_NAME_PATTERNS", + ",".join(CONTROL_TENSOR_NAME_PATTERNS), + ).split(",") + if pattern +) +INT8_KEEP_FLOAT_MAX_NUMEL = 65_536 +INT8_KEEP_FLOAT_STORE_DTYPE = torch.float16 +INT8_PER_ROW_SCALE_DTYPE = torch.float16 +INT8_CLIP_PERCENTILE = 99.99984 +INT8_CLIP_Q = INT8_CLIP_PERCENTILE / 100.0 + +def tensor_nbytes(t: Tensor) -> int: + return int(t.numel()) * int(t.element_size()) + +def keep_float_tensor(name: str, t: Tensor, passthrough_orig_dtypes: dict[str, str]) -> Tensor: + if any(pattern in name for pattern in INT8_KEEP_FLOAT_FP32_NAME_PATTERNS): + return t.float().contiguous() + if t.dtype in {torch.float32, torch.bfloat16}: + passthrough_orig_dtypes[name] = str(t.dtype).removeprefix("torch.") + return t.to(dtype=INT8_KEEP_FLOAT_STORE_DTYPE).contiguous() + return t + +def _quantize_2d_with_clip(t32: Tensor, clip_abs: Tensor, max_val: int) -> tuple[Tensor, Tensor, Tensor]: + """Quantize a 2D tensor with given per-row clip values. Returns (q, scale, reconstruction_error).""" + clipped = torch.maximum(torch.minimum(t32, clip_abs[:, None]), -clip_abs[:, None]) + scale = (clip_abs / float(max_val)).clamp_min(1.0 / float(max_val)) + q = torch.clamp(torch.round(clipped / scale[:, None]), -max_val, max_val).to(torch.int8) + # Reconstruction error: MSE between original and dequantized + deq = q.float() * scale[:, None] + err = (t32 - deq).square().mean(dim=1) + return q, scale, err + + +# GPTQ-lite clip ratios: search these percentiles per weight matrix to find optimal clipping. +_GPTQ_CLIP_RATIOS = [0.9, 0.95, 0.99, 0.999, 0.99999] +_gptq_lite: bool = False # set at runtime from GPTQ_LITE env var + + +def quantize_float_tensor(t: Tensor, bits: int = 8) -> tuple[Tensor, Tensor]: + max_val = 127 if bits == 8 else (2 ** (bits - 1)) - 1 # int6: 31, int8: 127 + t32 = t.float() + if t32.ndim == 2: + if _gptq_lite and t32.numel() > 0: + # GPTQ-lite: try multiple clip percentiles per row, pick lowest reconstruction error. + abs_vals = t32.abs() + best_q: Tensor | None = None + best_scale: Tensor | None = None + best_err: Tensor | None = None + for ratio in _GPTQ_CLIP_RATIOS: + clip_abs = torch.quantile(abs_vals, ratio, dim=1) + q, scale, err = _quantize_2d_with_clip(t32, clip_abs, max_val) + if best_err is None: + best_q, best_scale, best_err = q, scale, err + else: + # Per-row: keep whichever clip ratio gave lower error for each row + better = err < best_err + best_q[better] = q[better] + best_scale[better] = scale[better] + best_err[better] = err[better] + return best_q.contiguous(), best_scale.to(dtype=INT8_PER_ROW_SCALE_DTYPE).contiguous() + + # Standard: single clip percentile + clip_abs = ( + torch.quantile(t32.abs(), INT8_CLIP_Q, dim=1) + if t32.numel() + else torch.empty((t32.shape[0],), dtype=torch.float32) + ) + clipped = torch.maximum(torch.minimum(t32, clip_abs[:, None]), -clip_abs[:, None]) + scale = (clip_abs / float(max_val)).clamp_min(1.0 / float(max_val)) + q = torch.clamp(torch.round(clipped / scale[:, None]), -max_val, max_val).to(torch.int8).contiguous() + return q, scale.to(dtype=INT8_PER_ROW_SCALE_DTYPE).contiguous() + + # Vectors / scalars use a simpler per-tensor scale. + clip_abs = float(torch.quantile(t32.abs().flatten(), INT8_CLIP_Q).item()) if t32.numel() else 0.0 + scale = torch.tensor(clip_abs / float(max_val) if clip_abs > 0 else 1.0, dtype=torch.float32) + q = torch.clamp(torch.round(torch.clamp(t32, -clip_abs, clip_abs) / scale), -max_val, max_val).to(torch.int8).contiguous() + return q, scale + +# Names of tensors to keep in fp16 at export instead of quantizing to int8. +# Populated at runtime when fp16_embed_export=True. +_FP16_EXPORT_NAMES: set[str] = set() + +def quantize_state_dict_int8(state_dict: dict[str, Tensor], bits: int = 8): + # Clean-script export format supporting int8 or int6: + # - per-row quantization for 2D float tensors (int8: [-127,127], int6: [-31,31]) + # - per-tensor quantization for other float tensors + # - exact passthrough for non-floats + # - passthrough for small float tensors, stored as fp16 to save bytes + # - fp16 passthrough for tensors in _FP16_EXPORT_NAMES (e.g. tied embeddings) + quantized: dict[str, Tensor] = {} + scales: dict[str, Tensor] = {} + dtypes: dict[str, str] = {} + passthrough: dict[str, Tensor] = {} + passthrough_orig_dtypes: dict[str, str] = {} + qmeta: dict[str, dict[str, object]] = {} + stats = dict.fromkeys( + ("param_count", "num_tensors", "num_float_tensors", "num_nonfloat_tensors", "baseline_tensor_bytes", "int8_payload_bytes"), + 0, + ) + + for name, tensor in state_dict.items(): + t = tensor.detach().to("cpu").contiguous() + stats["param_count"] += int(t.numel()) + stats["num_tensors"] += 1 + stats["baseline_tensor_bytes"] += tensor_nbytes(t) + + if not t.is_floating_point(): + stats["num_nonfloat_tensors"] += 1 + passthrough[name] = t + stats["int8_payload_bytes"] += tensor_nbytes(t) + continue + + # FP16 export bypass: keep specified tensors (e.g. tied embeddings) in fp16 + # instead of quantizing to int8. Avoids compounding int8 errors through both + # the input embedding and output projection paths. + if name in _FP16_EXPORT_NAMES: + kept = t.to(dtype=torch.float16).contiguous() + passthrough_orig_dtypes[name] = str(t.dtype).removeprefix("torch.") + passthrough[name] = kept + stats["int8_payload_bytes"] += tensor_nbytes(kept) + continue + + # Small float tensors are cheap enough to keep directly. We still downcast + # fp32/bf16 passthrough tensors to fp16 so metadata does not dominate size. + if t.numel() <= INT8_KEEP_FLOAT_MAX_NUMEL: + kept = keep_float_tensor(name, t, passthrough_orig_dtypes) + passthrough[name] = kept + stats["int8_payload_bytes"] += tensor_nbytes(kept) + continue + + stats["num_float_tensors"] += 1 + q, s = quantize_float_tensor(t, bits=bits) + if s.ndim > 0: + qmeta[name] = {"scheme": "per_row", "axis": 0} + quantized[name] = q + scales[name] = s + dtypes[name] = str(t.dtype).removeprefix("torch.") + stats["int8_payload_bytes"] += tensor_nbytes(q) + tensor_nbytes(s) + + obj: dict[str, object] = { + "__quant_format__": "int8_clean_per_row_v1", + "quantized": quantized, + "scales": scales, + "dtypes": dtypes, + "passthrough": passthrough, + } + if qmeta: + obj["qmeta"] = qmeta + if passthrough_orig_dtypes: + obj["passthrough_orig_dtypes"] = passthrough_orig_dtypes + return obj, stats + +def dequantize_state_dict_int8(obj: dict[str, object]) -> dict[str, Tensor]: + out: dict[str, Tensor] = {} + qmeta = obj.get("qmeta", {}) + passthrough_orig_dtypes = obj.get("passthrough_orig_dtypes", {}) + for name, q in obj["quantized"].items(): + dtype = getattr(torch, obj["dtypes"][name]) + s = obj["scales"][name] + if qmeta.get(name, {}).get("scheme") == "per_row" or s.ndim > 0: + s = s.to(dtype=torch.float32) + # Broadcast the saved row scale back across trailing dimensions. + out[name] = (q.float() * s.view(q.shape[0], *([1] * (q.ndim - 1)))).to(dtype=dtype).contiguous() + else: + scale = float(s.item()) + out[name] = (q.float() * scale).to(dtype=dtype).contiguous() + for name, t in obj["passthrough"].items(): + # Restore small tensors, undoing the temporary fp16 storage cast if needed. + out_t = t.detach().to("cpu").contiguous() + orig_dtype = passthrough_orig_dtypes.get(name) + if isinstance(orig_dtype, str): + out_t = out_t.to(dtype=getattr(torch, orig_dtype)).contiguous() + out[name] = out_t + return out + + +# ----------------------------- +# DATA LOADING +# ----------------------------- + +def load_data_shard(file: Path) -> Tensor: + header_bytes = 256 * np.dtype(" list[Path]: + """Sort shards by average document length (shorter docs = easier = first). + + Reads only the first 100K tokens of each shard to estimate avg doc length. + Shards with shorter average documents contain simpler, more repetitive text + that helps the model learn basic patterns before encountering harder material. + """ + difficulties: list[tuple[float, Path]] = [] + for f in files: + header = np.fromfile(f, dtype=" None: + self.file_idx = (self.file_idx + 1) % len(self.files) + self.tokens = load_data_shard(self.files[self.file_idx]) + self.pos = 0 + + def take(self, n: int) -> Tensor: + chunks: list[Tensor] = [] + remaining = n + while remaining > 0: + avail = self.tokens.numel() - self.pos + if avail <= 0: + self._advance_file() + continue + k = min(remaining, avail) + chunks.append(self.tokens[self.pos : self.pos + k]) + self.pos += k + remaining -= k + return chunks[0] if len(chunks) == 1 else torch.cat(chunks) + + +class DistributedTokenLoader: + # Each call consumes a contiguous chunk from the shared token stream, then slices out + # one disjoint span per rank. The extra "+1" token lets us build (x, y) by shifting. + def __init__(self, pattern: str, rank: int, world_size: int, device: torch.device, + sorted_files: list[Path] | None = None): + self.rank = rank + self.world_size = world_size + self.device = device + self.stream = TokenStream(pattern, sorted_files=sorted_files) + + def next_batch(self, global_tokens: int, seq_len: int, grad_accum_steps: int) -> tuple[Tensor, Tensor]: + local_tokens = global_tokens // (self.world_size * grad_accum_steps) + per_rank_span = local_tokens + 1 + chunk = self.stream.take(per_rank_span * self.world_size) + start = self.rank * per_rank_span + local = chunk[start : start + per_rank_span].to(dtype=torch.int64) + x = local[:-1].reshape(-1, seq_len) + y = local[1:].reshape(-1, seq_len) + return x.to(self.device, non_blocking=True), y.to(self.device, non_blocking=True) + +# ----------------------------- +# TRANSFORMER MODULES +# ----------------------------- + +class RMSNorm(nn.Module): + def __init__(self, eps: float | None = None): + super().__init__() + self.eps = eps + + def forward(self, x: Tensor) -> Tensor: + return F.rms_norm(x, (x.size(-1),), eps=self.eps) + + +def fake_quantize_int8_per_row(w: Tensor) -> Tensor: + """Simulate per-row int6 quantization with straight-through estimator. + + Uses absmax per-row scale (no torch.quantile — O(n) instead of O(n log n)). + Matches the int6 export range [-31, 31]. + Backward: gradients pass through as if no quantization happened (STE). + """ + w32 = w.float() + scale = (w32.abs().amax(dim=1) / 31.0).clamp_min(1.0 / 31.0) + w_q = torch.clamp(torch.round(w32 / scale[:, None]), -31, 31) + w_deq = w_q * scale[:, None] + return w + (w_deq.to(w.dtype) - w).detach() + + +class CastedLinear(nn.Linear): + # Keep weights in fp32 for optimizer/state quality, cast at matmul time for bf16 compute. + _qat: bool = False + + def forward(self, x: Tensor) -> Tensor: + w = self.weight + if self._qat and self.training: + w = fake_quantize_int8_per_row(w) + bias = self.bias.to(x.dtype) if self.bias is not None else None + return F.linear(x, w.to(x.dtype), bias) + + +class AttentionLoRA(nn.Module): + """Per-iteration LoRA adapters for attention Q, K, V, and output projections. + + Initialized so that the LoRA contribution is zero at the start of training + (B matrices are zeros). During training, the optimizer learns per-iteration + specialization while the base attention weights remain shared across loops. + """ + def __init__(self, dim: int, kv_dim: int, rank: int): + super().__init__() + self.q_A = nn.Parameter(torch.empty(dim, rank)) + self.q_B = nn.Parameter(torch.zeros(rank, dim)) + self.k_A = nn.Parameter(torch.empty(dim, rank)) + self.k_B = nn.Parameter(torch.zeros(rank, kv_dim)) + self.v_A = nn.Parameter(torch.empty(dim, rank)) + self.v_B = nn.Parameter(torch.zeros(rank, kv_dim)) + self.proj_A = nn.Parameter(torch.empty(dim, rank)) + self.proj_B = nn.Parameter(torch.zeros(rank, dim)) + self._init_lora() + + def _init_lora(self) -> None: + for name in ("q_A", "k_A", "v_A", "proj_A"): + nn.init.kaiming_uniform_(getattr(self, name), a=math.sqrt(5)) + + +def restore_low_dim_params_to_fp32(module: nn.Module) -> None: + # Keep small/control parameters in fp32 even when the model body runs in bf16. + with torch.no_grad(): + for name, param in module.named_parameters(): + if (param.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS)) and param.dtype != torch.float32: + param.data = param.data.float() + + +class Rotary(nn.Module): + # Caches cos/sin tables per sequence length on the current device. + # ntk_base_seq_len: if > 0, apply NTK-aware RoPE scaling when seq_len > ntk_base_seq_len + # (lets a model trained at seq_len=1024 generalise to seq_len=2048 at eval with no quality loss) + def __init__(self, dim: int, base: float = 10000.0, ntk_base_seq_len: int = 0): + super().__init__() + self._dim = dim + self._base = base + self._ntk_base_seq_len = ntk_base_seq_len + inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim)) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self._seq_len_cached = 0 + self._cos_cached: Tensor | None = None + self._sin_cached: Tensor | None = None + + def forward(self, seq_len: int, device: torch.device, dtype: torch.dtype) -> tuple[Tensor, Tensor]: + if ( + self._cos_cached is None + or self._sin_cached is None + or self._seq_len_cached != seq_len + or self._cos_cached.device != device + ): + if self._ntk_base_seq_len > 0 and seq_len > self._ntk_base_seq_len: + # NTK-aware scaling: extend context without fine-tuning + scale = seq_len / self._ntk_base_seq_len + ntk_base = self._base * (scale ** (self._dim / (self._dim - 2))) + inv_freq = 1.0 / (ntk_base ** (torch.arange(0, self._dim, 2, dtype=torch.float32, device=device) / self._dim)) + else: + inv_freq = self.inv_freq.to(device) + t = torch.arange(seq_len, device=device, dtype=inv_freq.dtype) + freqs = torch.outer(t, inv_freq) + self._cos_cached = freqs.cos()[None, None, :, :] + self._sin_cached = freqs.sin()[None, None, :, :] + self._seq_len_cached = seq_len + return self._cos_cached.to(dtype=dtype), self._sin_cached.to(dtype=dtype) + + +def apply_rotary_emb(x: Tensor, cos: Tensor, sin: Tensor) -> Tensor: + half = x.size(-1) // 2 + x1, x2 = x[..., :half], x[..., half:] + return torch.cat((x1 * cos + x2 * sin, x1 * (-sin) + x2 * cos), dim=-1) + + +class CausalSelfAttention(nn.Module): + def __init__( + self, + dim: int, + num_heads: int, + num_kv_heads: int, + rope_base: float, + qk_gain_init: float, + ntk_base_seq_len: int = 0, + rope_dims: int = 0, + ): + super().__init__() + if dim % num_heads != 0: + raise ValueError("model_dim must be divisible by num_heads") + if num_heads % num_kv_heads != 0: + raise ValueError("num_heads must be divisible by num_kv_heads") + self.num_heads = num_heads + self.num_kv_heads = num_kv_heads + self.head_dim = dim // num_heads + if self.head_dim % 2 != 0: + raise ValueError("head_dim must be even for RoPE") + kv_dim = self.num_kv_heads * self.head_dim + self.c_q = CastedLinear(dim, dim, bias=False) + self.c_k = CastedLinear(dim, kv_dim, bias=False) + self.c_v = CastedLinear(dim, kv_dim, bias=False) + self.proj = CastedLinear(dim, dim, bias=False) + self.proj._zero_init = True + self.q_gain = nn.Parameter(torch.full((num_heads,), qk_gain_init, dtype=torch.float32)) + # rope_dims: if >0 apply RoPE only to first rope_dims dims of each head; rest are position-free + self.rope_dims = rope_dims if rope_dims > 0 else self.head_dim + self.rotary = Rotary(self.rope_dims, base=rope_base, ntk_base_seq_len=ntk_base_seq_len) + self.use_xsa = False # enabled on last N layers by GPT.__init__ + + def _xsa_efficient(self, y: Tensor, v: Tensor) -> Tensor: + """Exclusive Self Attention: subtract self-value projection (arXiv:2603.09078). + GQA-aware reshape avoids repeat_interleave — zero extra allocation.""" + B, T, H, D = y.shape + Hkv = v.size(-2) + group = H // Hkv + y_g = y.reshape(B, T, Hkv, group, D) + vn = F.normalize(v, dim=-1).unsqueeze(-2) # (B, T, Hkv, 1, D) + proj = (y_g * vn).sum(dim=-1, keepdim=True) * vn + return (y_g - proj).reshape(B, T, H, D) + + def forward(self, x: Tensor, lora: AttentionLoRA | None = None, v_embed: Tensor | None = None) -> Tensor: + bsz, seqlen, dim = x.shape + q = self.c_q(x) + k = self.c_k(x) + v = self.c_v(x) + if v_embed is not None: + v = v + v_embed + if lora is not None: + # LoRA delta: (bsz, seqlen, dim) @ (dim, rank) @ (rank, out_dim) + # autocast handles fp32->bf16 cast of LoRA params automatically + q = q + (x @ lora.q_A) @ lora.q_B + k = k + (x @ lora.k_A) @ lora.k_B + v = v + (x @ lora.v_A) @ lora.v_B + q = q.reshape(bsz, seqlen, self.num_heads, self.head_dim).transpose(1, 2) + k = k.reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) + v = v.reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) + q = F.rms_norm(q, (q.size(-1),)) + k = F.rms_norm(k, (k.size(-1),)) + cos, sin = self.rotary(seqlen, x.device, q.dtype) + if self.rope_dims < self.head_dim: + # Partial RoPE: apply rotation to first rope_dims dims, leave remaining dims untouched + q = torch.cat([apply_rotary_emb(q[..., :self.rope_dims], cos, sin), q[..., self.rope_dims:]], dim=-1) + k = torch.cat([apply_rotary_emb(k[..., :self.rope_dims], cos, sin), k[..., self.rope_dims:]], dim=-1) + else: + q = apply_rotary_emb(q, cos, sin) + k = apply_rotary_emb(k, cos, sin) + q = q * self.q_gain.to(dtype=q.dtype)[None, :, None, None] + if _HAS_FA3 and _use_fa3: + # FA3 expects [bsz, seqlen, heads, head_dim] + q_fa = q.transpose(1, 2).contiguous() + k_fa = k.transpose(1, 2).contiguous() + v_fa = v.transpose(1, 2).contiguous() + y = _flash_attn_3_func(q_fa, k_fa, v_fa, causal=True) + if self.use_xsa: + y = self._xsa_efficient(y, v_fa) + y = y.reshape(bsz, seqlen, dim) + else: + y = F.scaled_dot_product_attention( + q, k, v, attn_mask=None, is_causal=True, + enable_gqa=(self.num_kv_heads != self.num_heads), + ) + if self.use_xsa: + y = self._xsa_efficient(y.transpose(1, 2), v.transpose(1, 2)) + else: + y = y.transpose(1, 2) + y = y.contiguous().reshape(bsz, seqlen, dim) + out = self.proj(y) + if lora is not None: + out = out + (y @ lora.proj_A) @ lora.proj_B + return out + + +class MLP(nn.Module): + # relu^2 MLP from the original modded-nanogpt setup + def __init__(self, dim: int, mlp_mult: int, mlp_hidden: int = 0): + super().__init__() + hidden = mlp_hidden if mlp_hidden > 0 else mlp_mult * dim + self.fc = CastedLinear(dim, hidden, bias=False) + self.proj = CastedLinear(hidden, dim, bias=False) + self.proj._zero_init = True + + def forward(self, x: Tensor) -> Tensor: + x = torch.relu(self.fc(x)) + return self.proj(x.square()) + + +class Block(nn.Module): + def __init__( + self, + dim: int, + num_heads: int, + num_kv_heads: int, + mlp_mult: int, + rope_base: float, + qk_gain_init: float, + mlp_hidden: int = 0, + ntk_base_seq_len: int = 0, + rope_dims: int = 0, + layer_idx: int = 0, + ln_scale: bool = False, + ): + super().__init__() + self.attn_norm = RMSNorm() + self.mlp_norm = RMSNorm() + self.attn = CausalSelfAttention(dim, num_heads, num_kv_heads, rope_base, qk_gain_init, ntk_base_seq_len=ntk_base_seq_len, rope_dims=rope_dims) + self.mlp = MLP(dim, mlp_mult, mlp_hidden=mlp_hidden) + self.attn_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.mlp_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.resid_mix = nn.Parameter(torch.stack((torch.ones(dim), torch.zeros(dim))).float()) + self.ln_scale_factor = 1.0 / math.sqrt(layer_idx + 1) if ln_scale else 1.0 + + def forward(self, x: Tensor, x0: Tensor, lora: AttentionLoRA | None = None, v_embed: Tensor | None = None) -> Tensor: + mix = self.resid_mix.to(dtype=x.dtype) + x = mix[0][None, None, :] * x + mix[1][None, None, :] * x0 + s = self.ln_scale_factor + attn_out = self.attn(self.attn_norm(x) * s, lora=lora, v_embed=v_embed) + x = x + self.attn_scale.to(dtype=x.dtype)[None, None, :] * attn_out + x = x + self.mlp_scale.to(dtype=x.dtype)[None, None, :] * self.mlp(self.mlp_norm(x) * s) + return x + + +# ── Shared Value Embedding (VE) ────────────────────────────────────────────── +class ValueEmbedding(nn.Module): + """Learned embedding added to attention values in selected layers. + One shared table across layers, with per-layer learned scales.""" + def __init__(self, vocab_size: int, ve_dim: int, kv_dim: int): + super().__init__() + self.embed = nn.Embedding(vocab_size, ve_dim) + nn.init.normal_(self.embed.weight, std=0.01) + self.proj = CastedLinear(ve_dim, kv_dim, bias=False) if ve_dim != kv_dim else None + if self.proj is not None: + nn.init.zeros_(self.proj.weight) + self.scale = nn.Parameter(torch.tensor(0.1, dtype=torch.float32)) + + def forward(self, token_ids: Tensor) -> Tensor: + h = self.embed(token_ids) + if self.proj is not None: + h = self.proj(h) + return h * self.scale.to(dtype=h.dtype) + + +# ── SmearGate: cheap bigram context at embedding layer ────────────────────── +class SmearGate(nn.Module): + """Blends each token's embedding with the previous token's via per-channel sigmoid gates. + 512 independent channel gates (vs scalar gate) give the model richer bigram context.""" + def __init__(self, dim: int): + super().__init__() + self.gate = nn.Parameter(torch.zeros(dim, dtype=torch.float32)) + + def forward(self, x: Tensor) -> Tensor: + # x: (bsz, seq_len, dim) + prev = torch.cat([torch.zeros_like(x[:, :1]), x[:, :-1]], dim=1) + g = torch.sigmoid(self.gate.to(dtype=x.dtype))[None, None, :] # (1, 1, dim) + return torch.lerp(x, prev, g) + + +# ── BigramHash: hash-based bigram embedding ───────────────────────────────── +class BigramHashEmbedding(nn.Module): + """Maps consecutive token pairs to embeddings via a hash table. + Injects explicit bigram statistics into the residual stream.""" + def __init__(self, num_buckets: int, hash_dim: int, model_dim: int): + super().__init__() + self.num_buckets = num_buckets + self.embed = nn.Embedding(num_buckets, hash_dim) + self.proj = nn.Linear(hash_dim, model_dim, bias=False) + self.scale = nn.Parameter(torch.tensor(0.05, dtype=torch.float32)) + nn.init.normal_(self.embed.weight, std=0.01) + nn.init.zeros_(self.proj.weight) + + def _hash_pair(self, prev_ids: Tensor, cur_ids: Tensor) -> Tensor: + # XOR hash with large primes for better distribution + return (torch.bitwise_xor(36313 * prev_ids.long(), 27191 * cur_ids.long()) % max(self.num_buckets - 1, 1)).to(prev_ids.device) + + def forward(self, input_ids: Tensor) -> Tensor: + # input_ids: (bsz, seq_len) + prev_ids = torch.cat([torch.zeros_like(input_ids[:, :1]), input_ids[:, :-1]], dim=1) + bucket_ids = self._hash_pair(prev_ids, input_ids) + return self.scale.to(dtype=input_ids.dtype if input_ids.is_floating_point() else torch.bfloat16) * self.proj(self.embed(bucket_ids)) + + +class GPT(nn.Module): + def __init__( + self, + vocab_size: int, + num_layers: int, + model_dim: int, + num_heads: int, + num_kv_heads: int, + mlp_mult: int, + tie_embeddings: bool, + tied_embed_init_std: float, + logit_softcap: float, + rope_base: float, + qk_gain_init: float, + num_loops: int = 1, + lora_rank: int = 0, + mlp_hidden: int = 0, + smear_gate: bool = False, + bigram_hash: bool = False, + bigram_hash_buckets: int = 4096, + bigram_hash_dim: int = 128, + ortho_init: bool = True, + xsa_last_n: int = 0, + ntk_base_seq_len: int = 0, + rope_dims: int = 0, + ln_scale: bool = False, + unet_skips: bool = False, + ve_enabled: bool = False, + ve_dim: int = 128, + ve_layers: str = "9,10", + ): + super().__init__() + if logit_softcap <= 0.0: + raise ValueError(f"logit_softcap must be positive, got {logit_softcap}") + self.ortho_init = ortho_init + self.tie_embeddings = tie_embeddings + self.tied_embed_init_std = tied_embed_init_std + self.logit_softcap = logit_softcap + self.num_unique_layers = num_layers + self.num_loops = num_loops + effective_depth = num_layers * num_loops + self.tok_emb = nn.Embedding(vocab_size, model_dim) + self.smear_gate = SmearGate(model_dim) if smear_gate else None + self.bigram_hash = BigramHashEmbedding(bigram_hash_buckets, bigram_hash_dim, model_dim) if bigram_hash else None + # Shared Value Embedding: one table, added to V in selected layers + self.ve_layer_indices = [int(x) for x in ve_layers.split(",") if x.strip()] if ve_enabled else [] + kv_dim = num_kv_heads * (model_dim // num_heads) + self.ve_shared = ValueEmbedding(vocab_size, ve_dim, kv_dim) if ve_enabled else None + self.ve_layer_scales = nn.ParameterList( + [nn.Parameter(torch.ones(1, dtype=torch.float32)) for _ in self.ve_layer_indices] + ) if ve_enabled else None + self.unet_skips = unet_skips + self.num_encoder_layers = effective_depth // 2 + self.num_decoder_layers = effective_depth - self.num_encoder_layers + self.num_skip_weights = min(self.num_encoder_layers, self.num_decoder_layers) + self.skip_weights = nn.Parameter(torch.ones(self.num_skip_weights, model_dim, dtype=torch.float32)) if unet_skips else None + self.blocks = nn.ModuleList( + [ + Block( + model_dim, + num_heads, + num_kv_heads, + mlp_mult, + rope_base, + qk_gain_init, + mlp_hidden=mlp_hidden, + ntk_base_seq_len=ntk_base_seq_len, + rope_dims=rope_dims, + layer_idx=i, + ln_scale=ln_scale, + ) + for i in range(num_layers) + ] + ) + # Per-(loop, block) LoRA adapters for attention projections. + # Only created when num_loops > 1 and lora_rank > 0. + kv_dim = num_kv_heads * (model_dim // num_heads) + if lora_rank > 0 and num_loops > 1: + self.lora_adapters = nn.ModuleList( + [ + nn.ModuleList( + [AttentionLoRA(model_dim, kv_dim, lora_rank) for _ in range(num_layers)] + ) + for _ in range(num_loops) + ] + ) + else: + self.lora_adapters = None + if xsa_last_n > 0: + for i in range(max(0, num_layers - xsa_last_n), num_layers): + self.blocks[i].attn.use_xsa = True + self.final_norm = RMSNorm() + self.lm_head = None if tie_embeddings else CastedLinear(model_dim, vocab_size, bias=False) + if self.lm_head is not None: + self.lm_head._zero_init = True + self._init_weights() + + def _init_weights(self) -> None: + if self.tie_embeddings: + nn.init.normal_(self.tok_emb.weight, mean=0.0, std=self.tied_embed_init_std) + for module in self.modules(): + if isinstance(module, (nn.Linear, CastedLinear)) and getattr(module, "_zero_init", False): + nn.init.zeros_(module.weight) + elif self.ortho_init and isinstance(module, (nn.Linear, CastedLinear)) \ + and not getattr(module, "_zero_init", False) \ + and module.weight.ndim >= 2: + nn.init.orthogonal_(module.weight, gain=1.0) + + def forward(self, input_ids: Tensor, target_ids: Tensor) -> Tensor: + x = self.tok_emb(input_ids) + if self.bigram_hash is not None: + x = x + self.bigram_hash(input_ids) + x = F.rms_norm(x, (x.size(-1),)) + if self.smear_gate is not None: + x = self.smear_gate(x) + x0 = x + num_enc = self.num_encoder_layers + num_dec = self.num_decoder_layers + + # Compute shared VE once (cached across layers) + ve_base = self.ve_shared(input_ids) if self.ve_shared is not None else None + + def _get_ve(layer_idx: int) -> Tensor | None: + if ve_base is None or layer_idx not in self.ve_layer_indices: + return None + ve_idx = self.ve_layer_indices.index(layer_idx) + return ve_base * self.ve_layer_scales[ve_idx].to(dtype=ve_base.dtype) + + if self.unet_skips: + skips: list[Tensor] = [] + for i in range(num_enc): + x = self.blocks[i](x, x0, v_embed=_get_ve(i)) + skips.append(x) + for i in range(num_dec): + if skips: + x = x + self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips.pop() + x = self.blocks[num_enc + i](x, x0, v_embed=_get_ve(num_enc + i)) + else: + eff_idx = 0 + for loop_idx in range(self.num_loops): + for block_idx in range(self.num_unique_layers): + lora = self.lora_adapters[loop_idx][block_idx] if self.lora_adapters is not None else None + x = self.blocks[block_idx](x, x0, lora=lora, v_embed=_get_ve(eff_idx)) + eff_idx += 1 + + x = self.final_norm(x).reshape(-1, x.size(-1)) + targets = target_ids.reshape(-1) + if self.tie_embeddings: + logits_proj = F.linear(x, self.tok_emb.weight) + else: + if self.lm_head is None: + raise RuntimeError("lm_head is required when tie_embeddings=False") + logits_proj = self.lm_head(x) + logits = self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap) + return F.cross_entropy(logits.float(), targets, reduction="mean") + + def forward_logits(self, input_ids: Tensor) -> Tensor: + """Return logits (bsz, seq_len, vocab) without computing loss.""" + x = self.tok_emb(input_ids) + if self.bigram_hash is not None: + x = x + self.bigram_hash(input_ids) + x = F.rms_norm(x, (x.size(-1),)) + if self.smear_gate is not None: + x = self.smear_gate(x) + x0 = x + num_enc = self.num_encoder_layers + num_dec = self.num_decoder_layers + + if self.unet_skips: + skips: list[Tensor] = [] + for i in range(num_enc): + x = self.blocks[i](x, x0) + skips.append(x) + for i in range(num_dec): + if skips: + x = x + self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips.pop() + x = self.blocks[num_enc + i](x, x0) + else: + eff_idx = 0 + for loop_idx in range(self.num_loops): + for block_idx in range(self.num_unique_layers): + lora = self.lora_adapters[loop_idx][block_idx] if self.lora_adapters is not None else None + x = self.blocks[block_idx](x, x0, lora=lora) + eff_idx += 1 + x = self.final_norm(x) + if self.tie_embeddings: + logits_proj = F.linear(x, self.tok_emb.weight) + else: + logits_proj = self.lm_head(x) + return self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap) + + +BOS_ID = 1 # SentencePiece BOS token ID + +def _find_docs(all_tokens: Tensor) -> list[tuple[int, int]]: + """Return (start, length) for each document, identified by BOS boundaries. + + Each document starts at a BOS token and extends to just before the next BOS. + The last document extends to the end of the token stream. + """ + bos_positions = (all_tokens == BOS_ID).nonzero(as_tuple=True)[0].tolist() + if not bos_positions: + return [(0, all_tokens.numel())] + docs = [] + for i in range(len(bos_positions)): + start = bos_positions[i] + end = bos_positions[i + 1] if i + 1 < len(bos_positions) else all_tokens.numel() + if end - start >= 2: # need at least 2 tokens for (x, y) pair + docs.append((start, end - start)) + return docs + + +def _build_sliding_windows( + total_tokens: int, seq_len: int, stride: int +) -> list[tuple[int, int]]: + """Return (window_start, score_start) pairs covering every token exactly once. + + Every token in [0, total_tokens) is scored by exactly one window. + Full windows score their last `stride` positions (first window scores all seq_len). + One tail-aligned window covers any tokens beyond the last full window's end. + """ + if total_tokens <= 0: + return [] + if total_tokens <= seq_len: + return [(0, 0)] + + windows: list[tuple[int, int]] = [] + last_full_end = 0 + ws = 0 + while ws + seq_len <= total_tokens: + s = 0 if ws == 0 else seq_len - stride + windows.append((ws, s)) + last_full_end = ws + seq_len + ws += stride + + # One tail window ending exactly at total_tokens covers any remainder. + if last_full_end < total_tokens: + tail_ws = total_tokens - seq_len + tail_s = last_full_end - tail_ws # skip already-scored prefix + windows.append((tail_ws, tail_s)) + + return windows + + +def eval_val_sliding( + args: Hyperparameters, + base_model: nn.Module, + rank: int, + world_size: int, + device: torch.device, + val_tokens: Tensor, + base_bytes_lut: Tensor, + has_leading_space_lut: Tensor, + is_boundary_token_lut: Tensor, + stride: int, + batch_seqs: int = 32, +) -> tuple[float, float]: + """Sliding window evaluation: each token scored with maximum context. + + Windows of eval_seq_len advance by `stride`. Only newly covered tokens + contribute to the score (first window scores all seq_len; non-first full + windows score the last `stride` tokens; one tail window covers any remainder). + Every validation token is counted exactly once. + """ + seq_len = args.eval_seq_len + total_tokens = val_tokens.numel() - 1 + + if args.doc_isolated_eval: + # Build windows per document — context never crosses document boundaries. + # Each window's (ws, s) is in absolute token-stream coordinates. + docs = _find_docs(val_tokens) + all_windows: list[tuple[int, int]] = [] + for doc_start, doc_len in docs: + doc_pred_len = doc_len - 1 # number of prediction positions + doc_windows = _build_sliding_windows(doc_pred_len, seq_len, stride) + for ws, s in doc_windows: + all_windows.append((doc_start + ws, s)) + else: + all_windows = _build_sliding_windows(total_tokens, seq_len, stride) + total_windows = len(all_windows) + + # Distribute across ranks + my_s = (total_windows * rank) // world_size + my_e = (total_windows * (rank + 1)) // world_size + my_windows = all_windows[my_s:my_e] + + loss_sum = torch.zeros((), device=device, dtype=torch.float64) + token_count = torch.zeros((), device=device, dtype=torch.float64) + byte_count = torch.zeros((), device=device, dtype=torch.float64) + + base_model.eval() + with torch.inference_mode(): + for bi in range(0, len(my_windows), batch_seqs): + batch_items = my_windows[bi:bi + batch_seqs] + bsz = len(batch_items) + + x_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) + y_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) + wlens: list[int] = [] + score_starts: list[int] = [] + + for i, (ws, s) in enumerate(batch_items): + end = min(ws + seq_len, total_tokens) + wlen = end - ws + wlens.append(wlen) + score_starts.append(s) + chunk = val_tokens[ws:end + 1].to(dtype=torch.int64, device=device) + x_batch[i, :wlen] = chunk[:-1] + y_batch[i, :wlen] = chunk[1:] + + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + logits = base_model.forward_logits(x_batch) + + nll = F.cross_entropy( + logits.reshape(-1, logits.size(-1)).float(), + y_batch.reshape(-1), + reduction="none", + ).reshape(bsz, seq_len) + + for i, (ws, s) in enumerate(batch_items): + wlen = wlens[i] + scored_nll = nll[i, s:wlen].to(torch.float64) + loss_sum += scored_nll.sum() + token_count += float(wlen - s) + tgt = y_batch[i, s:wlen] + prev = x_batch[i, s:wlen] + tb = base_bytes_lut[tgt].to(torch.float64) + tb += (has_leading_space_lut[tgt] & ~is_boundary_token_lut[prev]).to(torch.float64) + byte_count += tb.sum() + + # Progress (rank 0 only) + if rank == 0 and (bi // batch_seqs) % 50 == 0: + done = min(bi + batch_seqs, len(my_windows)) + pct = done / len(my_windows) * 100 + running_bpb = 0.0 + if token_count.item() > 0: + rl = (loss_sum / token_count).item() + running_bpb = rl / math.log(2.0) * (token_count.item() / byte_count.item()) + print(f" sliding_eval [{pct:5.1f}%] {done}/{len(my_windows)} windows running_bpb={running_bpb:.6f}", flush=True) + + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(loss_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(token_count, op=dist.ReduceOp.SUM) + dist.all_reduce(byte_count, op=dist.ReduceOp.SUM) + + val_loss = (loss_sum / token_count).item() + bits_per_token = val_loss / math.log(2.0) + tokens_per_byte = token_count.item() / byte_count.item() + base_model.train() + return val_loss, bits_per_token * tokens_per_byte + + +# ----------------------------- +# TEST-TIME TRAINING (TTT) +# ----------------------------- + +def _ttt_run_phase( + model: nn.Module, + ttt_params: list[torch.nn.Parameter], + optimizer: torch.optim.Optimizer, + val_tokens: torch.Tensor, + seq_len: int, + batch_seqs: int, + epochs: int, + max_steps: int, + device: torch.device, + rank: int, + world_size: int, + phase_name: str, + t0: float, + cosine: bool = False, + warmup_frac: float = 0.0, +) -> None: + """Run one TTT phase with DDP gradient sharding across GPUs. + Each GPU processes batch_seqs sequences, gradients are manually all_reduced. + Supports cosine lr decay and linear warmup.""" + distributed = world_size > 1 + n_tokens = val_tokens.numel() + total_seqs = (n_tokens - 1) // seq_len + my_start_seq = (total_seqs * rank) // world_size + my_end_seq = (total_seqs * (rank + 1)) // world_size + + # Store initial lr for cosine/warmup scheduling + if cosine or warmup_frac > 0: + for g in optimizer.param_groups: + g["initial_lr"] = g["lr"] + # Estimate actual steps per epoch for cosine schedule + steps_per_epoch = min((my_end_seq - my_start_seq) // max(batch_seqs, 1), max_steps) + total_steps = epochs * steps_per_epoch + global_step = 0 + + model.train() + for epoch in range(epochs): + epoch_loss = torch.zeros((), device=device, dtype=torch.float64) + epoch_tokens = torch.zeros((), device=device, dtype=torch.float64) + step_i = 0 + # Contiguous slicing over this GPU's shard (matches #398 pattern) + for batch_start in range(my_start_seq, my_end_seq, batch_seqs): + if step_i >= max_steps: + break + + # LR schedule: warmup then cosine decay + if (cosine or warmup_frac > 0) and total_steps > 0: + progress = global_step / total_steps + if warmup_frac > 0 and progress < warmup_frac: + mul = progress / warmup_frac + elif cosine: + cos_start = warmup_frac if warmup_frac > 0 else 0.0 + cos_progress = (progress - cos_start) / (1.0 - cos_start) if cos_start < 1.0 else 0.0 + mul = 0.5 * (1.0 + math.cos(math.pi * min(cos_progress, 1.0))) + else: + mul = 1.0 + for g in optimizer.param_groups: + g["lr"] = g["initial_lr"] * mul + + batch_end = min(batch_start + batch_seqs, my_end_seq) + raw_start = batch_start * seq_len + raw_end = batch_end * seq_len + 1 + if raw_end > n_tokens: + break + local = val_tokens[raw_start:raw_end].to(device=device, dtype=torch.int64, non_blocking=True) + x = local[:-1].reshape(-1, seq_len) + y = local[1:].reshape(-1, seq_len) + + optimizer.zero_grad(set_to_none=True) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + loss = model(x, y) + loss.backward() + + if distributed: + for p in ttt_params: + if p.grad is not None: + dist.all_reduce(p.grad, op=dist.ReduceOp.AVG) + + torch.nn.utils.clip_grad_norm_(ttt_params, 1.0) + optimizer.step() + epoch_loss += loss.detach().to(torch.float64) * x.numel() + epoch_tokens += x.numel() + step_i += 1 + global_step += 1 + + if distributed: + dist.all_reduce(epoch_loss, op=dist.ReduceOp.SUM) + dist.all_reduce(epoch_tokens, op=dist.ReduceOp.SUM) + + if rank == 0: + avg_loss = epoch_loss.item() / max(epoch_tokens.item(), 1) + cur_lr = optimizer.param_groups[0]["lr"] + print(f"ttt_{phase_name} epoch:{epoch+1}/{epochs} loss:{avg_loss:.4f} " + f"lr:{cur_lr:.6f} steps:{step_i} time:{time.perf_counter()-t0:.1f}s", flush=True) + + +def ttt_adapt( + args: Hyperparameters, + model: nn.Module, + device: torch.device, + val_tokens: torch.Tensor, + rank: int = 0, + world_size: int = 1, +) -> None: + """Two-phase TTT with DDP gradient sharding. + + Phase 1 (norm-only): Fix quantization artifacts by adapting only LayerNorm + weights, scales, resid_mix, and q_gain (~22K params). Low risk, high epoch count. + + Phase 2 (selective blocks): Adapt last N blocks + norms + head to val distribution. + Higher risk, lower epoch count. + + Falls back to single-phase SGD if TTT_TWO_PHASE=0. + """ + t0 = time.perf_counter() + seq_len = args.train_seq_len + batch_seqs = args.ttt_batch_seqs + + if args.ttt_two_phase: + # ── Phase 1: Norm-only recalibration ──────────────────────────── + # Freeze everything except norms, scales, resid_mix, q_gain + norm_params = [] + for p in model.parameters(): + p.requires_grad_(False) + for name, p in model.named_parameters(): + if any(k in name for k in ("norm", "scale", "resid_mix", "q_gain", "skip_weight")): + p.requires_grad_(True) + norm_params.append(p) + n_norm = sum(p.numel() for p in norm_params) + if rank == 0: + print(f"ttt_phase1:start params:{n_norm} epochs:{args.ttt_p1_epochs} lr:{args.ttt_p1_lr}", flush=True) + + optimizer_p1 = torch.optim.Adam(norm_params, lr=args.ttt_p1_lr) + _ttt_run_phase( + model, norm_params, optimizer_p1, val_tokens, seq_len, batch_seqs, + epochs=args.ttt_p1_epochs, max_steps=args.ttt_max_steps, + device=device, rank=rank, world_size=world_size, + phase_name="phase1", t0=t0, + cosine=args.ttt_cosine, warmup_frac=args.ttt_warmup_frac, + ) + del optimizer_p1 + + # ── Phase 2: Selective block adaptation ───────────────────────── + # Unfreeze last N blocks + all norms + head + embeddings + for p in model.parameters(): + p.requires_grad_(False) + num_layers = len(list(model.blocks)) # type: ignore[attr-defined] + phase2_params = [] + for name, p in model.named_parameters(): + is_late_block = False + for i in range(max(0, num_layers - args.ttt_p2_unfreeze_blocks), num_layers): + if f"blocks.{i}." in name: + is_late_block = True + break + is_norm_or_scale = any(k in name for k in ("norm", "scale", "resid_mix", "q_gain", "skip_weight")) + is_head = "lm_head" in name or "tok_emb" in name + if is_late_block or is_norm_or_scale or is_head: + p.requires_grad_(True) + phase2_params.append(p) + n_p2 = sum(p.numel() for p in phase2_params) + if rank == 0: + print(f"ttt_phase2:start params:{n_p2} epochs:{args.ttt_p2_epochs} lr:{args.ttt_p2_lr}", flush=True) + + if args.ttt_optimizer == "adamw": + optimizer_p2 = torch.optim.AdamW(phase2_params, lr=args.ttt_p2_lr, weight_decay=0.0) + else: + optimizer_p2 = torch.optim.SGD(phase2_params, lr=args.ttt_p2_lr, momentum=args.ttt_momentum) + _ttt_run_phase( + model, phase2_params, optimizer_p2, val_tokens, seq_len, batch_seqs, + epochs=args.ttt_p2_epochs, max_steps=args.ttt_max_steps, + device=device, rank=rank, world_size=world_size, + phase_name="phase2", t0=t0, + cosine=args.ttt_cosine, warmup_frac=args.ttt_warmup_frac, + ) + del optimizer_p2 + else: + # ── Single-phase TTT with cosine + per-layer lr ─────────────── + frozen = set() + for i, block in enumerate(model.blocks): # type: ignore[attr-defined] + if i < args.ttt_freeze_blocks: + for p in block.parameters(): + frozen.add(id(p)) + + if args.ttt_perlayer: + # Per-layer lr: higher for more quant-damaged MLP proj, lower for fc + proj_params = [p for n, p in model.named_parameters() + if "mlp.proj" in n and p.requires_grad and id(p) not in frozen] + fc_params = [p for n, p in model.named_parameters() + if "mlp.fc" in n and p.requires_grad and id(p) not in frozen] + other_params = [p for p in model.parameters() + if p.requires_grad and id(p) not in frozen + and id(p) not in {id(q) for q in proj_params + fc_params}] + param_groups = [ + {"params": proj_params, "lr": args.ttt_lr * 3.0}, + {"params": fc_params, "lr": args.ttt_lr * 0.5}, + {"params": other_params, "lr": args.ttt_lr}, + ] + param_groups = [g for g in param_groups if g["params"]] + ttt_params = proj_params + fc_params + other_params + else: + ttt_params = [p for p in model.parameters() if p.requires_grad and id(p) not in frozen] + param_groups = [{"params": ttt_params, "lr": args.ttt_lr}] + + if rank == 0: + n_ttt = sum(p.numel() for p in ttt_params) + print(f"ttt:start params:{n_ttt} epochs:{args.ttt_epochs} lr:{args.ttt_lr} " + f"freeze:{args.ttt_freeze_blocks} optimizer:{args.ttt_optimizer} " + f"cosine:{args.ttt_cosine} warmup:{args.ttt_warmup_frac} perlayer:{args.ttt_perlayer}", flush=True) + + if args.ttt_optimizer == "adamw": + optimizer = torch.optim.AdamW(param_groups, weight_decay=0.0) + else: + optimizer = torch.optim.SGD(param_groups, momentum=args.ttt_momentum) + _ttt_run_phase( + model, ttt_params, optimizer, val_tokens, seq_len, batch_seqs, + epochs=args.ttt_epochs, max_steps=args.ttt_max_steps, + device=device, rank=rank, world_size=world_size, + phase_name="single", t0=t0, + cosine=args.ttt_cosine, warmup_frac=args.ttt_warmup_frac, + ) + del optimizer + + # Unfreeze all params for eval + for p in model.parameters(): + p.requires_grad_(True) + if rank == 0: + print(f"ttt:done elapsed={time.perf_counter()-t0:.1f}s", flush=True) + + +def reptile_ttt( + args: Hyperparameters, + model: nn.Module, + device: torch.device, + val_tokens: torch.Tensor, + rank: int = 0, +) -> None: + """Reptile meta-TTT: find weights that adapt fast to val distribution. + Runs after EMA/SWA, before standard TTT. Makes TTT ~10x more effective.""" + t0 = time.perf_counter() + seq_len = args.train_seq_len + n_tokens = val_tokens.numel() + + # Only adapt MLP params of last 1/4 blocks + num_blocks = len(model.blocks) + suffix_start = num_blocks - num_blocks // 4 + ttt_params = {} + for name, p in model.named_parameters(): + if any(f'blocks.{i}.' in name and '.mlp.' in name for i in range(suffix_start, num_blocks)): + ttt_params[name] = p + + base_state = {name: p.data.clone() for name, p in ttt_params.items()} + reptile_steps = 0 + + while (time.perf_counter() - t0) < args.reptile_budget_s: + # Save current params + saved = {name: p.data.clone() for name, p in ttt_params.items()} + + # Inner loop: N SGD steps on a random chunk + model.train() + start = random.randint(0, max(n_tokens - seq_len - 1, 0)) + chunk = val_tokens[start:start + seq_len + 1].to(device=device, dtype=torch.int64) + x = chunk[:-1].unsqueeze(0) + y = chunk[1:].unsqueeze(0) + + for inner_step in range(args.reptile_inner_steps): + model.zero_grad(set_to_none=True) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + loss = model(x, y) + loss.backward() + with torch.no_grad(): + for name, param in ttt_params.items(): + if param.grad is not None: + param.data -= args.reptile_inner_lr * param.grad + + # Outer loop: move base toward adapted params + with torch.no_grad(): + for name, param in ttt_params.items(): + base_state[name] += args.reptile_outer_lr * (param.data - base_state[name]) + param.data.copy_(base_state[name]) + + reptile_steps += 1 + + if rank == 0: + print(f"reptile_ttt:done steps:{reptile_steps} elapsed:{time.perf_counter()-t0:.1f}s", flush=True) + + +# ----------------------------- +# TRAINING +# ----------------------------- + +def main() -> None: + global zeropower_via_newtonschulz5 + + code = Path(__file__).read_text(encoding="utf-8") + args = Hyperparameters() + zeropower_via_newtonschulz5 = torch.compile(zeropower_via_newtonschulz5) + + # ----------------------------- + # DISTRIBUTED + CUDA SETUP + # ----------------------------- + + distributed = "RANK" in os.environ and "WORLD_SIZE" in os.environ + rank = int(os.environ.get("RANK", "0")) + world_size = int(os.environ.get("WORLD_SIZE", "1")) + local_rank = int(os.environ.get("LOCAL_RANK", "0")) + if world_size <= 0: + raise ValueError(f"WORLD_SIZE must be positive, got {world_size}") + if 8 % world_size != 0: + raise ValueError(f"WORLD_SIZE={world_size} must divide 8 so grad_accum_steps stays integral") + grad_accum_steps = 8 // world_size + grad_scale = 1.0 / grad_accum_steps + if not torch.cuda.is_available(): + raise RuntimeError("CUDA is required") + device = torch.device("cuda", local_rank) + torch.cuda.set_device(device) + if distributed: + dist.init_process_group(backend="nccl", device_id=device) + dist.barrier() + master_process = rank == 0 + + # Fast math knobs + torch.backends.cuda.matmul.allow_tf32 = True + torch.backends.cudnn.allow_tf32 = True + from torch.backends.cuda import enable_cudnn_sdp, enable_flash_sdp, enable_math_sdp, enable_mem_efficient_sdp + + enable_cudnn_sdp(False) + enable_flash_sdp(True) + enable_mem_efficient_sdp(False) + enable_math_sdp(False) + + # Set module-level flag so CausalSelfAttention.forward can use FA3. + global _use_fa3 + _use_fa3 = args.use_fa3 and _HAS_FA3 + + logfile = None + if master_process: + os.makedirs("logs", exist_ok=True) + logfile = f"logs/{args.run_id}.txt" + print(logfile) + + def log0(msg: str, console: bool = True) -> None: + if not master_process: + return + if console: + print(msg) + if logfile is not None: + with open(logfile, "a", encoding="utf-8") as f: + print(msg, file=f) + + log0(code, console=False) + log0("=" * 100, console=False) + log0(f"Running Python {sys.version}", console=False) + log0(f"Running PyTorch {torch.__version__}", console=False) + log0( + subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, check=False).stdout, + console=False, + ) + log0("=" * 100, console=False) + + # ----------------------------- + # TOKENIZER + VALIDATION METRIC SETUP + # ----------------------------- + + random.seed(args.seed) + np.random.seed(args.seed) + torch.manual_seed(args.seed) + torch.cuda.manual_seed_all(args.seed) + + if not args.tokenizer_path.endswith(".model"): + raise ValueError(f"Script only setup for SentencePiece .model file: {args.tokenizer_path}") + sp = spm.SentencePieceProcessor(model_file=args.tokenizer_path) + if int(sp.vocab_size()) != args.vocab_size: + raise ValueError( + f"VOCAB_SIZE={args.vocab_size} does not match tokenizer vocab_size={int(sp.vocab_size())}" + ) + dataset_dir = Path(args.data_path).resolve() + actual_train_files = len(list(dataset_dir.glob("fineweb_train_*.bin"))) + val_tokens = load_validation_tokens(args.val_files, args.eval_seq_len) + base_bytes_lut, has_leading_space_lut, is_boundary_token_lut = build_sentencepiece_luts( + sp, args.vocab_size, device + ) + log0(f"val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path={args.tokenizer_path}") + log0(f"train_loader:dataset:{dataset_dir.name} train_shards:{actual_train_files}") + log0(f"val_loader:shards pattern={args.val_files} tokens:{val_tokens.numel() - 1}") + + # ----------------------------- + # MODEL + OPTIMIZER SETUP + # ----------------------------- + + base_model = GPT( + vocab_size=args.vocab_size, + num_layers=args.num_layers, + model_dim=args.model_dim, + num_heads=args.num_heads, + num_kv_heads=args.num_kv_heads, + mlp_mult=args.mlp_mult, + tie_embeddings=args.tie_embeddings, + tied_embed_init_std=args.tied_embed_init_std, + logit_softcap=args.logit_softcap, + rope_base=args.rope_base, + qk_gain_init=args.qk_gain_init, + num_loops=args.num_loops, + lora_rank=args.lora_rank, + mlp_hidden=args.mlp_hidden, + smear_gate=args.smear_gate, + bigram_hash=args.bigram_hash, + bigram_hash_buckets=args.bigram_hash_buckets, + bigram_hash_dim=args.bigram_hash_dim, + ortho_init=args.ortho_init, + xsa_last_n=args.xsa_last_n, + ntk_base_seq_len=args.train_seq_len if args.eval_seq_len > args.train_seq_len else 0, + rope_dims=args.rope_dims, + ln_scale=args.ln_scale, + unet_skips=args.unet_skips, + ve_enabled=args.ve_enabled, + ve_dim=args.ve_dim, + ve_layers=args.ve_layers, + ).to(device).bfloat16() + if args._tier2: + log0(f"*** TIER2_MODE: proxy run max={args.max_wallclock_seconds:.0f}s iters={args.iterations} " + f"ema={args.ema_enabled} ttt={args.ttt_enabled} qat={args.qat} " + f"-- compare val_bpb@step2000 against baseline tier2 run ***") + xsa_layers = [i for i, b in enumerate(base_model.blocks) if b.attn.use_xsa] + log0(f"xsa_last_n:{args.xsa_last_n} active_layers:{xsa_layers}") + ntk_active = args.eval_seq_len > args.train_seq_len + log0(f"ntk_rope:{'enabled' if ntk_active else 'disabled'} train_seq_len:{args.train_seq_len} eval_seq_len:{args.eval_seq_len}") + head_dim = args.model_dim // args.num_heads + _rope_dims_active = args.rope_dims > 0 and args.rope_dims < head_dim + log0(f"partial_rope:{'enabled' if _rope_dims_active else 'disabled'} rope_dims:{args.rope_dims if _rope_dims_active else head_dim}/{head_dim} (ROPE_DIMS={args.rope_dims})") + log0(f"ln_scale:{'enabled' if args.ln_scale else 'disabled'} (scale RMSNorm output by 1/sqrt(layer_idx+1))") + log0(f"unet_skips:{'enabled' if args.unet_skips else 'disabled'} (U-Net skip connections, enc={base_model.num_encoder_layers} dec={base_model.num_decoder_layers})") + if args.ve_enabled: + log0(f"ve:enabled dim={args.ve_dim} layers={base_model.ve_layer_indices}") + else: + log0("ve:disabled") + + log0(f"smear_gate:{args.smear_gate} bigram_hash:{args.bigram_hash} swa:{args.swa} " + f"ortho_init:{args.ortho_init} late_k_fp16:{args.late_k_fp16} " + f"fa3:{_use_fa3}(available={_HAS_FA3}) muon_wd:{args.muon_wd} adam_wd:{args.adam_wd}") + + # FP16 tied embedding export: skip int8 quantization for tok_emb.weight at export time. + # Avoids compounding int8 errors through both input embedding and output projection. + if args.fp16_embed_export and args.tie_embeddings: + _FP16_EXPORT_NAMES.add("tok_emb.weight") + log0(f"fp16_embed_export:enabled (tok_emb.weight kept in fp16, ~{args.vocab_size * args.model_dim * 2 / 1024:.0f}KB)") + + # Late-K: keep K projections of last 2 layers in fp16 (not quantized). + # Saves per-query context accuracy where it matters most — near the output. + if args.late_k_fp16: + effective_depth = args.num_layers * args.num_loops + for layer_idx in range(effective_depth - 2, effective_depth): + block_idx = layer_idx % args.num_layers + key_name = f"blocks.{block_idx}.attn.c_k.weight" + _FP16_EXPORT_NAMES.add(key_name) + log0(f"late_k_fp16:enabled (last 2 effective layers' c_k.weight kept in fp16)") + + for module in base_model.modules(): + if isinstance(module, (CastedLinear, AttentionLoRA)): + module.float() + restore_low_dim_params_to_fp32(base_model) + log0(f"qat:{args.qat} (activates when lr_scale < 0.1; absmax int6 STE)") + compiled_model = torch.compile(base_model, dynamic=False, fullgraph=True) + model: nn.Module = DDP(compiled_model, device_ids=[local_rank], broadcast_buffers=False) if distributed else compiled_model + + # Optimizer split: + # - token embedding (Adam) uses EMBED_LR + # - untied lm_head (Adam) uses HEAD_LR + # - matrix params in transformer blocks use MATRIX_LR via Muon + # - vectors/scalars use SCALAR_LR via Adam + block_named_params = list(base_model.blocks.named_parameters()) + matrix_params = [ + p + for name, p in block_named_params + if p.ndim == 2 and not any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) + ] + # bigram_hash.proj is a dense 2D projection — Muon is appropriate + if base_model.bigram_hash is not None: + matrix_params.append(base_model.bigram_hash.proj.weight) + scalar_params = [ + p + for name, p in block_named_params + if p.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) + ] + if base_model.skip_weights is not None: + scalar_params.append(base_model.skip_weights) + # smear_gate.gate is now a single nn.Parameter(dim) — AdamW at scalar_lr + if base_model.smear_gate is not None: + scalar_params.append(base_model.smear_gate.gate) + # bigram_hash.scale is a learned scalar — AdamW at scalar_lr + if base_model.bigram_hash is not None: + scalar_params.append(base_model.bigram_hash.scale) + # VE: scales go to scalar, proj to matrix, embed to tok group + if base_model.ve_shared is not None: + scalar_params.append(base_model.ve_shared.scale) + if base_model.ve_shared.proj is not None: + matrix_params.append(base_model.ve_shared.proj.weight) + if base_model.ve_layer_scales is not None: + for vs in base_model.ve_layer_scales: + scalar_params.append(vs) + token_lr = args.tied_embed_lr if args.tie_embeddings else args.embed_lr + # bigram_hash.embed is an embedding table — train with AdamW alongside tok_emb + embed_params = [base_model.tok_emb.weight] + if base_model.bigram_hash is not None: + embed_params.append(base_model.bigram_hash.embed.weight) + if base_model.ve_shared is not None: + embed_params.append(base_model.ve_shared.embed.weight) + optimizer_tok = torch.optim.AdamW( + [{"params": embed_params, "lr": token_lr, "base_lr": token_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + weight_decay=args.adam_wd, + fused=True, + ) + optimizer_muon = Muon( + matrix_params, + lr=args.matrix_lr, + momentum=args.muon_momentum, + backend_steps=args.muon_backend_steps, + weight_decay=args.muon_wd, + ) + for group in optimizer_muon.param_groups: + group["base_lr"] = args.matrix_lr + optimizer_scalar = torch.optim.AdamW( + [{"params": scalar_params, "lr": args.scalar_lr, "base_lr": args.scalar_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + weight_decay=args.adam_wd, + fused=True, + ) + optimizers: list[torch.optim.Optimizer] = [optimizer_tok, optimizer_muon, optimizer_scalar] + if base_model.lora_adapters is not None: + lora_params = list(base_model.lora_adapters.parameters()) + optimizer_lora = torch.optim.Adam( + [{"params": lora_params, "lr": args.lora_lr, "base_lr": args.lora_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=True, + ) + optimizers.append(optimizer_lora) + if base_model.lm_head is not None: + optimizer_head = torch.optim.Adam( + [{"params": [base_model.lm_head.weight], "lr": args.head_lr, "base_lr": args.head_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=True, + ) + optimizers.insert(1, optimizer_head) + + n_params = sum(p.numel() for p in base_model.parameters()) + n_lora = sum(p.numel() for p in base_model.lora_adapters.parameters()) if base_model.lora_adapters is not None else 0 + effective_depth = args.num_layers * args.num_loops + log0(f"model_params:{n_params} (unique_layers:{args.num_layers} loops:{args.num_loops} effective_depth:{effective_depth} lora_rank:{args.lora_rank} lora_params:{n_lora})") + log0(f"world_size:{world_size} grad_accum_steps:{grad_accum_steps}") + log0("sdp_backends:cudnn=False flash=True mem_efficient=False math=False") + log0(f"attention_mode:gqa num_heads:{args.num_heads} num_kv_heads:{args.num_kv_heads}") + log0( + f"tie_embeddings:{args.tie_embeddings} embed_lr:{token_lr} " + f"head_lr:{args.head_lr if base_model.lm_head is not None else 0.0} " + f"matrix_lr:{args.matrix_lr} scalar_lr:{args.scalar_lr}" + ) + log0( + f"train_batch_tokens:{args.train_batch_tokens} train_seq_len:{args.train_seq_len} " + f"iterations:{args.iterations} warmup_steps:{args.warmup_steps} " + f"max_wallclock_seconds:{args.max_wallclock_seconds:.3f}" + ) + log0(f"seed:{args.seed}") + + # ----------------------------- + # DATA LOADER & MODEL WARMUP + # ----------------------------- + + # Curriculum learning: sort shards by average document length (easy first). + curriculum_files: list[Path] | None = None + if args.curriculum: + train_file_list = [Path(p) for p in sorted(glob.glob(args.train_files))] + curriculum_files = sort_shards_by_doc_length(train_file_list) + log0(f"curriculum:enabled shards_sorted_by_doc_length ({len(curriculum_files)} shards)") + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device, + sorted_files=curriculum_files) + + def zero_grad_all() -> None: + for opt in optimizers: + opt.zero_grad(set_to_none=True) + + max_wallclock_ms = 1000.0 * args.max_wallclock_seconds if args.max_wallclock_seconds > 0 else None + + def lr_mul(step: int, elapsed_ms: float) -> float: + if args.warmdown_iters <= 0: + return 1.0 + if max_wallclock_ms is None: + warmdown_start = max(args.iterations - args.warmdown_iters, 0) + return max((args.iterations - step) / max(args.warmdown_iters, 1), 0.0) if warmdown_start <= step < args.iterations else 1.0 + step_ms = elapsed_ms / max(step, 1) + warmdown_ms = args.warmdown_iters * step_ms + remaining_ms = max(max_wallclock_ms - elapsed_ms, 0.0) + return remaining_ms / max(warmdown_ms, 1e-9) if remaining_ms <= warmdown_ms else 1.0 + + # Warmup primes the compiled forward/backward/optimizer paths, then we restore the + # initial weights/optimizer state so measured training starts from the true init. + if args.warmup_steps > 0: + initial_model_state = {name: tensor.detach().cpu().clone() for name, tensor in base_model.state_dict().items()} + initial_optimizer_states = [copy.deepcopy(opt.state_dict()) for opt in optimizers] + model.train() + for warmup_step in range(args.warmup_steps): + zero_grad_all() + for micro_step in range(grad_accum_steps): + if distributed: + model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 + x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + warmup_loss = model(x, y) + (warmup_loss * grad_scale).backward() + for opt in optimizers: + opt.step() + zero_grad_all() + if args.warmup_steps <= 20 or (warmup_step + 1) % 10 == 0 or warmup_step + 1 == args.warmup_steps: + log0(f"warmup_step:{warmup_step + 1}/{args.warmup_steps}") + # NOTE: QAT graph priming removed — it caused torch.compile to use a slower + # compilation path for the non-QAT forward pass (step_avg jumped from 44ms to 58ms). + # The one-time recompile when QAT activates (~30-90s) is cheaper than the cumulative + # overhead of a slower non-QAT path across thousands of steps. + + base_model.load_state_dict(initial_model_state, strict=True) + for opt, state in zip(optimizers, initial_optimizer_states, strict=True): + opt.load_state_dict(state) + zero_grad_all() + if distributed: + model.require_backward_grad_sync = True + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device, + sorted_files=curriculum_files) + + # ----------------------------- + # MAIN TRAINING LOOP + # ----------------------------- + + training_time_ms = 0.0 + qat_active = False + + # EMA: exponential moving average — smoother than SWA, better quantization compression. + ema_state: dict[str, Tensor] | None = None + if args.ema_enabled: + ema_state = {k: v.detach().float().clone() for k, v in base_model.state_dict().items()} + log0(f"ema:initialized decay={args.ema_decay}") + + # SWA: fallback if EMA disabled + swa_state: dict[str, Tensor] | None = None + swa_count = 0 + stop_after_step: int | None = None + torch.cuda.synchronize() + t0 = time.perf_counter() + + step = 0 + while True: + last_step = step == args.iterations or (stop_after_step is not None and step >= stop_after_step) + + should_validate = last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0) + if should_validate: + torch.cuda.synchronize() + training_time_ms += 1000.0 * (time.perf_counter() - t0) + val_loss, val_bpb = eval_val( + args, + model, + rank, + world_size, + device, + grad_accum_steps, + val_tokens, + base_bytes_lut, + has_leading_space_lut, + is_boundary_token_lut, + ) + log0( + f"step:{step}/{args.iterations} val_loss:{val_loss:.4f} val_bpb:{val_bpb:.4f} " + f"train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms / max(step, 1):.2f}ms" + ) + torch.cuda.synchronize() + t0 = time.perf_counter() + + if last_step: + if stop_after_step is not None and step < args.iterations: + log0( + f"stopping_early: wallclock_cap train_time:{training_time_ms:.0f}ms " + f"step:{step}/{args.iterations}" + ) + break + + elapsed_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + scale = lr_mul(step, elapsed_ms) + zero_grad_all() + train_loss = torch.zeros((), device=device) + for micro_step in range(grad_accum_steps): + if distributed: + model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 + x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + loss = model(x, y) + train_loss += loss.detach() + (loss * grad_scale).backward() + train_loss /= grad_accum_steps + + frac = min(step / args.muon_momentum_warmup_steps, 1.0) if args.muon_momentum_warmup_steps > 0 else 1.0 + muon_momentum = (1 - frac) * args.muon_momentum_warmup_start + frac * args.muon_momentum + for group in optimizer_muon.param_groups: + group["momentum"] = muon_momentum + + for opt in optimizers: + for group in opt.param_groups: + group["lr"] = group["base_lr"] * scale + + if args.grad_clip_norm > 0: + torch.nn.utils.clip_grad_norm_(base_model.parameters(), args.grad_clip_norm) + for opt in optimizers: + opt.step() + zero_grad_all() + + step += 1 + + # EMA: update every step during training + if ema_state is not None: + d = args.ema_decay + with torch.no_grad(): + for k, v in base_model.state_dict().items(): + ema_state[k].mul_(d).add_(v.detach().float(), alpha=1.0 - d) + + # LR-scale-based QAT activation: activate when lr_scale < 0.1 (last ~10% of warmdown, + # ~300 steps). Zero overhead for 90%+ of training; absmax scale makes per-step cost minimal. + if args.qat and not qat_active: + if scale < 0.1: + qat_active = True + for module in base_model.modules(): + if isinstance(module, CastedLinear): + module._qat = True + elapsed_s = training_time_ms / 1000.0 + log0(f"qat_activated step:{step}/{args.iterations} lr_scale:{scale:.4f} elapsed:{elapsed_s:.1f}s remaining:{args.max_wallclock_seconds - elapsed_s:.1f}s") + + # SWA: accumulate weight averages during warmdown for smoother quantization. + # Accumulate in float32 to avoid bf16 precision loss over thousands of additions. + # Sample every 200 steps for sufficient checkpoint diversity. + if args.swa and step >= int(args.iterations * args.swa_start_frac) and step % 200 == 0: + if swa_state is None: + swa_state = {k: v.detach().float().clone() for k, v in base_model.state_dict().items()} + swa_count = 1 + log0(f"swa_started step:{step}") + else: + swa_count += 1 + for k, v in base_model.state_dict().items(): + swa_state[k] += v.detach().float() + + approx_training_time_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + should_log_train = ( + args.train_log_every > 0 + and (step <= 10 or step % args.train_log_every == 0 or stop_after_step is not None) + ) + if should_log_train: + log0( + f"step:{step}/{args.iterations} train_loss:{train_loss.item():.4f} " + f"train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms / step:.2f}ms" + ) + + # Needed to sync whether we've reached the wallclock cap. + reached_cap = max_wallclock_ms is not None and approx_training_time_ms >= max_wallclock_ms + if distributed and max_wallclock_ms is not None: + reached_cap_tensor = torch.tensor(int(reached_cap), device=device) + dist.all_reduce(reached_cap_tensor, op=dist.ReduceOp.MAX) + reached_cap = bool(reached_cap_tensor.item()) + if stop_after_step is None and reached_cap: + stop_after_step = step + + log0( + f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB " + f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB" + ) + + # ----------------------------- + # SERIALIZATION + ROUNDTRIP VALIDATION + # ----------------------------- + # Save the raw state (useful for debugging/loading in PyTorch directly), then always produce + # the compressed int8+zlib artifact and validate the round-tripped weights. + + # Apply EMA weights (preferred) or SWA fallback before serialization. + if ema_state is not None: + log0("ema:applying EMA weights") + current_state = base_model.state_dict() + base_model.load_state_dict( + {k: v.to(dtype=current_state[k].dtype) for k, v in ema_state.items()}, strict=True + ) + del ema_state + elif args.swa and swa_state is not None and swa_count > 1: + log0(f"swa_applied count:{swa_count}") + current_state = base_model.state_dict() + avg_state = {} + for k, v in swa_state.items(): + avg = v / swa_count + avg_state[k] = avg.to(dtype=current_state[k].dtype) + base_model.load_state_dict(avg_state, strict=True) + + # Reptile meta-TTT: makes subsequent TTT ~10x more effective by finding weights that adapt fast. + if args.reptile_enabled: + log0(f"reptile_ttt:start budget={args.reptile_budget_s:.0f}s inner_steps={args.reptile_inner_steps} inner_lr={args.reptile_inner_lr} outer_lr={args.reptile_outer_lr}") + reptile_ttt(args, base_model, device, val_tokens, rank=rank) + + # TTT: adapt to val distribution before eval + if args.ttt_enabled: + if args.ttt_two_phase: + log0(f"ttt:start two_phase p1_epochs={args.ttt_p1_epochs} p1_lr={args.ttt_p1_lr} " + f"p2_epochs={args.ttt_p2_epochs} p2_lr={args.ttt_p2_lr} p2_blocks={args.ttt_p2_unfreeze_blocks} " + f"batch_seqs={args.ttt_batch_seqs}") + else: + log0(f"ttt:start lr={args.ttt_lr} epochs={args.ttt_epochs} freeze_blocks={args.ttt_freeze_blocks} " + f"batch_seqs={args.ttt_batch_seqs}") + ttt_adapt(args, base_model, device, val_tokens, rank=rank, world_size=world_size) + + # Magnitude pruning: zero the smallest weights for better zstd compression. + # Zeroed weights compress to nearly nothing. Applied after TTT, before serialization. + if args.prune_pct > 0.0: + pruned_count = 0 + total_count = 0 + with torch.no_grad(): + for name, p in base_model.named_parameters(): + if p.ndim >= 2 and p.numel() >= 65536: + threshold = torch.quantile(p.abs().float().flatten(), args.prune_pct / 100.0) + mask = p.abs() > threshold + pruned_count += (~mask).sum().item() + total_count += p.numel() + p.mul_(mask) + log0(f"prune:{args.prune_pct:.1f}% zeroed {pruned_count}/{total_count} weights ({100*pruned_count/max(total_count,1):.1f}%)") + + if master_process: + torch.save(base_model.state_dict(), "final_model.pt") + model_bytes = os.path.getsize("final_model.pt") + code_bytes = len(code.encode("utf-8")) + log0(f"Serialized model: {model_bytes} bytes") + log0(f"Code size: {code_bytes} bytes") + log0(f"Total submission size: {model_bytes + code_bytes} bytes") + + global _gptq_lite + _gptq_lite = args.gptq_lite + log0(f"quantization: {args.quant_bits}-bit gptq_lite:{_gptq_lite}") + quant_obj, quant_stats = quantize_state_dict_int8(base_model.state_dict(), bits=args.quant_bits) + quant_buf = io.BytesIO() + torch.save(quant_obj, quant_buf) + quant_raw = quant_buf.getvalue() + if args.use_zstd and _HAS_ZSTD: + cctx = zstd.ZstdCompressor(level=22) + quant_blob = cctx.compress(quant_raw) + compress_method = "zstd-22" + else: + quant_blob = zlib.compress(quant_raw, level=9) + compress_method = "zlib-9" + quant_raw_bytes = len(quant_raw) + if master_process: + with open("final_model.int8.ptz", "wb") as f: + f.write(quant_blob) + quant_file_bytes = os.path.getsize("final_model.int8.ptz") + code_bytes = len(code.encode("utf-8")) + ratio = quant_stats["baseline_tensor_bytes"] / max(quant_stats["int8_payload_bytes"], 1) + log0( + f"Serialized model int8+{compress_method}: {quant_file_bytes} bytes " + f"(payload:{quant_stats['int8_payload_bytes']} raw_torch:{quant_raw_bytes} payload_ratio:{ratio:.2f}x)" + ) + log0(f"Total submission size int8+{compress_method}: {quant_file_bytes + code_bytes} bytes") + + if distributed: + dist.barrier() + with open("final_model.int8.ptz", "rb") as f: + quant_blob_disk = f.read() + # Decompress with the same method used for compression. + if args.use_zstd and _HAS_ZSTD: + dctx = zstd.ZstdDecompressor() + quant_raw_disk = dctx.decompress(quant_blob_disk) + else: + quant_raw_disk = zlib.decompress(quant_blob_disk) + quant_state = torch.load(io.BytesIO(quant_raw_disk), map_location="cpu") + base_model.load_state_dict(dequantize_state_dict_int8(quant_state), strict=True) + torch.cuda.synchronize() + t_qeval = time.perf_counter() + if args.eval_stride > 0 and args.eval_stride < args.eval_seq_len: + log0(f"final_eval_mode:sliding_window stride:{args.eval_stride} batch_seqs:{args.eval_batch_seqs} doc_isolated:{args.doc_isolated_eval}") + q_val_loss, q_val_bpb = eval_val_sliding( + args, + base_model, + rank, + world_size, + device, + val_tokens, + base_bytes_lut, + has_leading_space_lut, + is_boundary_token_lut, + stride=args.eval_stride, + batch_seqs=args.eval_batch_seqs, + ) + else: + log0("final_eval_mode:standard") + q_val_loss, q_val_bpb = eval_val( + args, + model, + rank, + world_size, + device, + grad_accum_steps, + val_tokens, + base_bytes_lut, + has_leading_space_lut, + is_boundary_token_lut, + ) + torch.cuda.synchronize() + log0( + f"final_int8_zlib_roundtrip val_loss:{q_val_loss:.4f} val_bpb:{q_val_bpb:.4f} " + f"eval_time:{1000.0 * (time.perf_counter() - t_qeval):.0f}ms" + ) + log0(f"final_int8_zlib_roundtrip_exact val_loss:{q_val_loss:.8f} val_bpb:{q_val_bpb:.8f}") + + if distributed: + dist.destroy_process_group() + + +if __name__ == "__main__": + main() + +==================================================================================================== +Running Python 3.12.3 (main, Nov 6 2025, 13:44:16) [GCC 13.3.0] +Running PyTorch 2.9.1+cu128 +Sun Mar 22 22:50:00 2026 ++-----------------------------------------------------------------------------------------+ +| NVIDIA-SMI 570.211.01 Driver Version: 570.211.01 CUDA Version: 12.8 | +|-----------------------------------------+------------------------+----------------------+ +| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC | +| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. | +| | | MIG M. | +|=========================================+========================+======================| +| 0 NVIDIA H100 80GB HBM3 On | 00000000:18:00.0 Off | 0 | +| N/A 23C P0 116W / 700W | 1519MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 1 NVIDIA H100 80GB HBM3 On | 00000000:2A:00.0 Off | 0 | +| N/A 23C P0 112W / 700W | 1519MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 2 NVIDIA H100 80GB HBM3 On | 00000000:3A:00.0 Off | 0 | +| N/A 25C P0 114W / 700W | 1519MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 3 NVIDIA H100 80GB HBM3 On | 00000000:5D:00.0 Off | 0 | +| N/A 22C P0 118W / 700W | 1519MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 4 NVIDIA H100 80GB HBM3 On | 00000000:84:00.0 Off | 0 | +| N/A 23C P0 114W / 700W | 1519MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 5 NVIDIA H100 80GB HBM3 On | 00000000:8B:00.0 Off | 0 | +| N/A 23C P0 117W / 700W | 1519MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 6 NVIDIA H100 80GB HBM3 On | 00000000:91:00.0 Off | 0 | +| N/A 21C P0 109W / 700W | 1519MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 7 NVIDIA H100 80GB HBM3 On | 00000000:E4:00.0 Off | 0 | +| N/A 21C P0 110W / 700W | 1519MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ + ++-----------------------------------------------------------------------------------------+ +| Processes: | +| GPU GI CI PID Type Process name GPU Memory | +| ID ID Usage | +|=========================================================================================| +| 0 N/A N/A 3799 C /usr/local/bin/python 1510MiB | +| 1 N/A N/A 3800 C /usr/local/bin/python 1510MiB | +| 2 N/A N/A 3801 C /usr/local/bin/python 1510MiB | +| 3 N/A N/A 3802 C /usr/local/bin/python 1510MiB | +| 4 N/A N/A 3803 C /usr/local/bin/python 1510MiB | +| 5 N/A N/A 3804 C /usr/local/bin/python 1510MiB | +| 6 N/A N/A 3805 C /usr/local/bin/python 1510MiB | +| 7 N/A N/A 3806 C /usr/local/bin/python 1510MiB | ++-----------------------------------------------------------------------------------------+ + +==================================================================================================== +val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path=./data/tokenizers/fineweb_1024_bpe.model +train_loader:dataset:fineweb10B_sp1024 train_shards:80 +val_loader:shards pattern=./data/datasets/fineweb10B_sp1024/fineweb_val_*.bin tokens:62021632 +xsa_last_n:0 active_layers:[] +ntk_rope:disabled train_seq_len:2048 eval_seq_len:2048 +partial_rope:enabled rope_dims:16/64 (ROPE_DIMS=16) +ln_scale:enabled (scale RMSNorm output by 1/sqrt(layer_idx+1)) +unet_skips:enabled (U-Net skip connections, enc=5 dec=6) +ve:disabled +smear_gate:True bigram_hash:True swa:False ortho_init:True late_k_fp16:False fa3:True(available=True) muon_wd:0.04 adam_wd:0.04 +qat:False (activates when lr_scale < 0.1; absmax int6 STE) +model_params:26829913 (unique_layers:11 loops:1 effective_depth:11 lora_rank:0 lora_params:0) +world_size:8 grad_accum_steps:1 +sdp_backends:cudnn=False flash=True mem_efficient=False math=False +attention_mode:gqa num_heads:8 num_kv_heads:4 +tie_embeddings:True embed_lr:0.035 head_lr:0.0 matrix_lr:0.025 scalar_lr:0.025 +train_batch_tokens:524288 train_seq_len:2048 iterations:20000 warmup_steps:20 max_wallclock_seconds:600.000 +seed:1337 +warmup_step:1/20 +warmup_step:2/20 +warmup_step:3/20 +warmup_step:4/20 +warmup_step:5/20 +warmup_step:6/20 +warmup_step:7/20 +warmup_step:8/20 +warmup_step:9/20 +warmup_step:10/20 +warmup_step:11/20 +warmup_step:12/20 +warmup_step:13/20 +warmup_step:14/20 +warmup_step:15/20 +warmup_step:16/20 +warmup_step:17/20 +warmup_step:18/20 +warmup_step:19/20 +warmup_step:20/20 +ema:initialized decay=0.997 +step:0/20000 val_loss:6.9285 val_bpb:4.1034 train_time:0ms step_avg:0.02ms +step:1/20000 train_loss:6.9289 train_time:138ms step_avg:137.67ms +step:2/20000 train_loss:8.5619 train_time:209ms step_avg:104.50ms +step:3/20000 train_loss:7.7343 train_time:303ms step_avg:100.90ms +step:4/20000 train_loss:7.2715 train_time:396ms step_avg:99.05ms +step:5/20000 train_loss:6.8782 train_time:473ms step_avg:94.65ms +step:6/20000 train_loss:7.7641 train_time:550ms step_avg:91.67ms +step:7/20000 train_loss:6.7733 train_time:626ms step_avg:89.44ms +step:8/20000 train_loss:6.6122 train_time:704ms step_avg:88.05ms +step:9/20000 train_loss:6.3570 train_time:796ms step_avg:88.45ms +step:10/20000 train_loss:6.1513 train_time:873ms step_avg:87.26ms +step:200/20000 train_loss:2.8050 train_time:16315ms step_avg:81.58ms +step:400/20000 train_loss:2.2925 train_time:32849ms step_avg:82.12ms +step:600/20000 train_loss:2.4861 train_time:49684ms step_avg:82.81ms +step:800/20000 train_loss:2.2366 train_time:66449ms step_avg:83.06ms +step:1000/20000 train_loss:2.3285 train_time:82903ms step_avg:82.90ms +step:1000/20000 val_loss:2.2851 val_bpb:1.3533 train_time:82936ms step_avg:82.94ms +step:1200/20000 train_loss:2.3561 train_time:99797ms step_avg:83.16ms +step:1400/20000 train_loss:2.3815 train_time:116560ms step_avg:83.26ms +step:1600/20000 train_loss:2.0522 train_time:133196ms step_avg:83.25ms +step:1800/20000 train_loss:2.1601 train_time:150277ms step_avg:83.49ms +step:2000/20000 train_loss:2.1978 train_time:167136ms step_avg:83.57ms +step:2000/20000 val_loss:2.1802 val_bpb:1.2912 train_time:167151ms step_avg:83.58ms +step:2200/20000 train_loss:2.0209 train_time:184129ms step_avg:83.70ms +step:2400/20000 train_loss:2.1381 train_time:201031ms step_avg:83.76ms +step:2600/20000 train_loss:2.3681 train_time:218483ms step_avg:84.03ms +step:2800/20000 train_loss:2.1756 train_time:235286ms step_avg:84.03ms +step:3000/20000 train_loss:2.1665 train_time:251640ms step_avg:83.88ms +step:3000/20000 val_loss:2.1330 val_bpb:1.2633 train_time:251658ms step_avg:83.89ms +step:3200/20000 train_loss:2.1288 train_time:267912ms step_avg:83.72ms +step:3400/20000 train_loss:2.1066 train_time:284687ms step_avg:83.73ms +step:3600/20000 train_loss:2.0447 train_time:301087ms step_avg:83.64ms +step:3800/20000 train_loss:2.1516 train_time:317366ms step_avg:83.52ms +step:4000/20000 train_loss:2.1214 train_time:333489ms step_avg:83.37ms +step:4000/20000 val_loss:2.1126 val_bpb:1.2512 train_time:333514ms step_avg:83.38ms +step:4200/20000 train_loss:2.1213 train_time:353194ms step_avg:84.09ms +step:4400/20000 train_loss:2.0476 train_time:370516ms step_avg:84.21ms +step:4600/20000 train_loss:1.9050 train_time:387834ms step_avg:84.31ms +step:4800/20000 train_loss:2.1856 train_time:403996ms step_avg:84.17ms +step:5000/20000 train_loss:1.9324 train_time:421226ms step_avg:84.25ms +step:5000/20000 val_loss:2.0719 val_bpb:1.2271 train_time:421241ms step_avg:84.25ms +step:5200/20000 train_loss:2.0882 train_time:437349ms step_avg:84.11ms +step:5400/20000 train_loss:2.0943 train_time:454459ms step_avg:84.16ms +step:5600/20000 train_loss:2.0790 train_time:470648ms step_avg:84.04ms +step:5800/20000 train_loss:2.0243 train_time:487385ms step_avg:84.03ms +step:6000/20000 train_loss:2.0991 train_time:503972ms step_avg:84.00ms +step:6000/20000 val_loss:2.0237 val_bpb:1.1986 train_time:503988ms step_avg:84.00ms +step:6200/20000 train_loss:1.9592 train_time:521687ms step_avg:84.14ms +step:6400/20000 train_loss:2.0271 train_time:538498ms step_avg:84.14ms +step:6600/20000 train_loss:1.9730 train_time:555509ms step_avg:84.17ms +step:6800/20000 train_loss:2.0246 train_time:571939ms step_avg:84.11ms +step:7000/20000 train_loss:2.0558 train_time:588810ms step_avg:84.12ms +step:7000/20000 val_loss:1.9585 val_bpb:1.1599 train_time:588825ms step_avg:84.12ms +step:7101/20000 val_loss:1.9547 val_bpb:1.1577 train_time:599890ms step_avg:84.48ms +stopping_early: wallclock_cap train_time:599890ms step:7101/20000 +peak memory allocated: 13405 MiB reserved: 13690 MiB +ema:applying EMA weights +ttt:start lr=0.0005 epochs=30 freeze_blocks=0 batch_seqs=64 +Serialized model: 105658303 bytes +Code size: 104414 bytes +Total submission size: 105762717 bytes +quantization: 6-bit gptq_lite:True +Serialized model int8+zstd-22: 15258143 bytes (payload:27056482 raw_torch:27113039 payload_ratio:3.90x) +Total submission size int8+zstd-22: 15362557 bytes +final_eval_mode:sliding_window stride:64 batch_seqs:32 doc_isolated:False +final_int8_zlib_roundtrip val_loss:1.8504 val_bpb:1.0959 eval_time:186475ms +final_int8_zlib_roundtrip_exact val_loss:1.85037787 val_bpb:1.09589800 diff --git a/records/track_10min_16mb/2026-03-22_CosineTTT_PerLayer/train_seed42.log b/records/track_10min_16mb/2026-03-22_CosineTTT_PerLayer/train_seed42.log new file mode 100644 index 000000000..e221016b0 --- /dev/null +++ b/records/track_10min_16mb/2026-03-22_CosineTTT_PerLayer/train_seed42.log @@ -0,0 +1,2427 @@ +""" +The `train_gpt.py` and `train_gpt_mlx.py` scripts are intended as good launching-off points for new participants, not SOTA configs. We'll accept PRs that tune, improve, or simplify these scripts without significantly increasing complexity, but competitive submissions should stay in the `/records` folder. + +Hard stop: To keep readable for newcomers, let's make sure `train_gpt.py` and `train_gpt_mlx.py` never are longer than 1500 lines. +""" + +from __future__ import annotations + +import copy +import glob +import io +import math +import os +import random +import subprocess +import sys +import time +import uuid +import zlib +try: + import zstandard as zstd + _HAS_ZSTD = True +except ImportError: + _HAS_ZSTD = False +from pathlib import Path + +import numpy as np +import sentencepiece as spm +import torch +import torch.distributed as dist +import torch.nn.functional as F +from torch import Tensor, nn +from torch.nn.parallel import DistributedDataParallel as DDP + +try: + from flash_attn_interface import flash_attn_func as _flash_attn_3_func + _HAS_FA3 = True +except ImportError: + try: + from flash_attn import flash_attn_func as _flash_attn_3_func + _HAS_FA3 = True # FA2 fallback — same input format, slightly slower than FA3 + except ImportError: + _HAS_FA3 = False +_use_fa3: bool = False # set at runtime after args are parsed + +# ----------------------------- +# HYPERPARAMETERS +# ----------------------------- +# Default Simple Baseline run: +# - 9 transformer blocks at width 512 +# - 8 attention heads with 4 KV heads (GQA) and 2x MLP expansion +# - vocab size 1024, sequence length 1024, tied embeddings +# - 524,288 train tokens per step for 20,000 iterations with a ~10 minute cap + +class Hyperparameters: + # Data paths are shard globs produced by the existing preprocessing pipeline. + data_path = os.environ.get("DATA_PATH", "./data/datasets/fineweb10B_sp1024") + train_files = os.path.join(data_path, "fineweb_train_*.bin") + val_files = os.path.join(data_path, "fineweb_val_*.bin") + tokenizer_path = os.environ.get("TOKENIZER_PATH", "./data/tokenizers/fineweb_1024_bpe.model") + run_id = os.environ.get("RUN_ID", str(uuid.uuid4())) + seed = int(os.environ.get("SEED", 1337)) + + # Validation cadence and batch size. Validation always uses the full fineweb_val split. + val_batch_size = int(os.environ.get("VAL_BATCH_SIZE", 524_288)) + val_loss_every = int(os.environ.get("VAL_LOSS_EVERY", 1000)) + train_log_every = int(os.environ.get("TRAIN_LOG_EVERY", 200)) + + # Training length. + iterations = int(os.environ.get("ITERATIONS", 20000)) + warmdown_iters = int(os.environ.get("WARMDOWN_ITERS", 3000)) + warmup_steps = int(os.environ.get("WARMUP_STEPS", 20)) + train_batch_tokens = int(os.environ.get("TRAIN_BATCH_TOKENS", 524_288)) + train_seq_len = int(os.environ.get("TRAIN_SEQ_LEN", 1024)) + eval_seq_len = int(os.environ.get("EVAL_SEQ_LEN", 2048)) # eval window; NTK-RoPE scales if > train_seq_len + max_wallclock_seconds = float(os.environ.get("MAX_WALLCLOCK_SECONDS", 600.0)) + qk_gain_init = float(os.environ.get("QK_GAIN_INIT", 1.5)) + + # ── Tier-2 proxy mode ──────────────────────────────────────────────────── + # TIER2_MODE=1 overrides key settings for a fast 3-minute proxy run. + # Use this to validate architectural features before committing a full run. + # Schedule-dependent features (EMA, TTT, SWA) are disabled — they can only + # be evaluated meaningfully at full training duration. + # Compare val_bpb at step ~2000 against a baseline TIER2_MODE=1 run. + _tier2 = bool(int(os.environ.get("TIER2_MODE", "0"))) + if _tier2: + max_wallclock_seconds = float(os.environ.get("MAX_WALLCLOCK_SECONDS", 180.0)) + iterations = int(os.environ.get("ITERATIONS", 3000)) + warmdown_iters = int(os.environ.get("WARMDOWN_ITERS", 500)) + val_loss_every = int(os.environ.get("VAL_LOSS_EVERY", 500)) + eval_stride = int(os.environ.get("EVAL_STRIDE", 256)) # faster final eval + + # Model shape. + vocab_size = int(os.environ.get("VOCAB_SIZE", 1024)) + num_layers = int(os.environ.get("NUM_LAYERS", 11)) + num_kv_heads = int(os.environ.get("NUM_KV_HEADS", 4)) + model_dim = int(os.environ.get("MODEL_DIM", 512)) + num_heads = int(os.environ.get("NUM_HEADS", 8)) + mlp_mult = int(os.environ.get("MLP_MULT", 3)) + mlp_hidden = int(os.environ.get("MLP_HIDDEN", 0)) # 0 = use mlp_mult * model_dim + fp16_embed_export = bool(int(os.environ.get("FP16_EMBED_EXPORT", "1"))) # keep tok_emb in fp16 at export + num_loops = int(os.environ.get("NUM_LOOPS", 1)) + lora_rank = int(os.environ.get("LORA_RANK", 0)) + qat = bool(int(os.environ.get("QAT", "1"))) + qat_min_seconds = float(os.environ.get("QAT_MIN_SECONDS", 120.0)) # guarantee QAT runs for at least this long + eval_stride = int(os.environ.get("EVAL_STRIDE", 64)) + eval_batch_seqs = int(os.environ.get("EVAL_BATCH_SEQS", 32)) + doc_isolated_eval = bool(int(os.environ.get("DOC_ISOLATED_EVAL", "1"))) # eval per-document, no cross-doc context + smear_gate = bool(int(os.environ.get("SMEAR_GATE", "1"))) # cheap bigram context at embedding layer + bigram_hash = bool(int(os.environ.get("BIGRAM_HASH", "1"))) # hash-based bigram embedding + bigram_hash_buckets = int(os.environ.get("BIGRAM_HASH_BUCKETS", 2048)) + bigram_hash_dim = int(os.environ.get("BIGRAM_HASH_DIM", 128)) + swa = bool(int(os.environ.get("SWA", "0"))) # stochastic weight averaging (disabled: EMA preferred) + swa_start_frac = float(os.environ.get("SWA_START_FRAC", 0.5)) + xsa_last_n = int(os.environ.get("XSA_LAST_N", 4)) # XSA on last N layers (0=disabled) + ema_enabled = bool(int(os.environ.get("EMA_ENABLED", "1"))) # exponential moving average (replaces SWA) + ema_decay = float(os.environ.get("EMA_DECAY", 0.997)) # EMA decay per step + ttt_enabled = bool(int(os.environ.get("TTT_ENABLED", "1"))) # test-time training on val data + ttt_lr = float(os.environ.get("TTT_LR", 0.0005)) # AdamW lr (PR #442: AdamW beats SGD by 0.019 BPB) + ttt_epochs = int(os.environ.get("TTT_EPOCHS", 30)) # 30ep cosine beats 10ep flat by 16% + ttt_momentum = float(os.environ.get("TTT_MOMENTUM", 0.9)) # only used if TTT_OPTIMIZER=sgd + ttt_freeze_blocks = int(os.environ.get("TTT_FREEZE_BLOCKS", 0)) # 0 = all blocks unfrozen (PR #398) + ttt_optimizer = os.environ.get("TTT_OPTIMIZER", "adamw") # "adamw" or "sgd" + ttt_max_steps = int(os.environ.get("TTT_MAX_STEPS", 300)) # cap steps per epoch (~10s per epoch) + ttt_batch_seqs = int(os.environ.get("TTT_BATCH_SEQS", 64)) # seqs per GPU per step (64*8=512 total) + ttt_cosine = bool(int(os.environ.get("TTT_COSINE", "1"))) # cosine lr decay during TTT (+16% over flat) + ttt_warmup_frac = float(os.environ.get("TTT_WARMUP_FRAC", 0.0)) # linear warmup fraction (0.1 = 10%) + ttt_perlayer = bool(int(os.environ.get("TTT_PERLAYER", "1"))) # per-layer lr (3x proj, 0.5x fc) — +23.5% combined with cosine + # Two-phase TTT (matches PR #415/#417 approach) + ttt_two_phase = bool(int(os.environ.get("TTT_TWO_PHASE", "0"))) # enable two-phase TTT + ttt_p1_epochs = int(os.environ.get("TTT_P1_EPOCHS", 50)) # phase 1: norm-only recalibration + ttt_p1_lr = float(os.environ.get("TTT_P1_LR", 0.01)) # phase 1: Adam lr + ttt_p2_epochs = int(os.environ.get("TTT_P2_EPOCHS", 10)) # phase 2: selective block adaptation + ttt_p2_lr = float(os.environ.get("TTT_P2_LR", 0.005)) # phase 2: SGD lr + ttt_p2_unfreeze_blocks = int(os.environ.get("TTT_P2_UNFREEZE_BLOCKS", 3)) # phase 2: unfreeze last N blocks + use_zstd = bool(int(os.environ.get("USE_ZSTD", "1"))) # use zstd instead of zlib for compression + curriculum = bool(int(os.environ.get("CURRICULUM", "0"))) # sort training shards by doc length (easy first) + quant_bits = int(os.environ.get("QUANT_BITS", 6)) # 8=int8, 6=int6 (int6 fits ~3x more params in 16MB) + tie_embeddings = bool(int(os.environ.get("TIE_EMBEDDINGS", "1"))) + rope_base = float(os.environ.get("ROPE_BASE", 50000.0)) + logit_softcap = float(os.environ.get("LOGIT_SOFTCAP", 30.0)) + + # Optimizer hyperparameters. + embed_lr = float(os.environ.get("EMBED_LR", 0.6)) + head_lr = float(os.environ.get("HEAD_LR", 0.008)) + tied_embed_lr = float(os.environ.get("TIED_EMBED_LR", 0.035)) + tied_embed_init_std = float(os.environ.get("TIED_EMBED_INIT_STD", 0.005)) + matrix_lr = float(os.environ.get("MATRIX_LR", 0.025)) + scalar_lr = float(os.environ.get("SCALAR_LR", 0.025)) + muon_momentum = float(os.environ.get("MUON_MOMENTUM", 0.99)) + muon_backend_steps = int(os.environ.get("MUON_BACKEND_STEPS", 5)) + muon_momentum_warmup_start = float(os.environ.get("MUON_MOMENTUM_WARMUP_START", 0.92)) + muon_momentum_warmup_steps = int(os.environ.get("MUON_MOMENTUM_WARMUP_STEPS", 1500)) + beta1 = float(os.environ.get("BETA1", 0.9)) + beta2 = float(os.environ.get("BETA2", 0.95)) + adam_eps = float(os.environ.get("ADAM_EPS", 1e-8)) + lora_lr = float(os.environ.get("LORA_LR", 0.01)) + grad_clip_norm = float(os.environ.get("GRAD_CLIP_NORM", 0.3)) + muon_wd = float(os.environ.get("MUON_WD", 0.04)) + adam_wd = float(os.environ.get("ADAM_WD", 0.04)) + ortho_init = bool(int(os.environ.get("ORTHO_INIT", "1"))) + late_k_fp16 = bool(int(os.environ.get("LATE_K_FP16", "1"))) + use_fa3 = bool(int(os.environ.get("USE_FA3", "1"))) + rope_dims = int(os.environ.get("ROPE_DIMS", 0)) # 0 = full RoPE; >0 = apply RoPE to only first N dims per head + ln_scale = bool(int(os.environ.get("LN_SCALE", "0"))) # scale block output by 1/sqrt(layer_idx+1) + unet_skips = bool(int(os.environ.get("UNET_SKIPS", "0"))) # U-Net style skip connections + prune_pct = float(os.environ.get("PRUNE_PCT", 0.0)) # magnitude pruning: zero smallest N% of weights before compression + gptq_lite = bool(int(os.environ.get("GPTQ_LITE", "1"))) # per-row optimal clip search at quantization (default ON, zero training cost) + reptile_enabled = bool(int(os.environ.get("REPTILE_TTT", "0"))) # Reptile meta-TTT before standard TTT + reptile_budget_s = float(os.environ.get("REPTILE_BUDGET_S", 60.0)) # seconds for Reptile meta-learning + reptile_inner_steps = int(os.environ.get("REPTILE_INNER_STEPS", 3)) # SGD steps per inner loop + reptile_inner_lr = float(os.environ.get("REPTILE_INNER_LR", 0.1)) # inner SGD learning rate + reptile_outer_lr = float(os.environ.get("REPTILE_OUTER_LR", 0.01)) # outer interpolation rate + ve_enabled = bool(int(os.environ.get("VE_ENABLED", "0"))) # shared value embedding + ve_dim = int(os.environ.get("VE_DIM", 128)) # value embedding dimension + ve_layers = os.environ.get("VE_LAYERS", "9,10") # comma-separated layer indices + + # Disable schedule-dependent features in TIER2_MODE unless explicitly overridden + if _tier2: + qat = bool(int(os.environ.get("QAT", "0"))) + ema_enabled = bool(int(os.environ.get("EMA_ENABLED", "0"))) + ttt_enabled = bool(int(os.environ.get("TTT_ENABLED", "0"))) + swa = bool(int(os.environ.get("SWA", "0"))) + +# ----------------------------- +# MUON OPTIMIZER +# ----------------------------- +# +# As borrowed from modded-nanogpt +# Background on Muon: https://kellerjordan.github.io/posts/muon/ + +def zeropower_via_newtonschulz5(G: Tensor, steps: int = 10, eps: float = 1e-7) -> Tensor: + # Orthogonalize a 2D update matrix with a fast Newton-Schulz iteration. + # Muon uses this to normalize matrix-shaped gradients before applying them. + a, b, c = (3.4445, -4.7750, 2.0315) + X = G.bfloat16() + X /= X.norm() + eps + transposed = G.size(0) > G.size(1) + if transposed: + X = X.T + for _ in range(steps): + A = X @ X.T + B = b * A + c * A @ A + X = a * X + B @ X + return X.T if transposed else X + + +class Muon(torch.optim.Optimizer): + def __init__(self, params, lr: float, momentum: float, backend_steps: int, + nesterov: bool = True, weight_decay: float = 0.0): + super().__init__( + params, + dict(lr=lr, momentum=momentum, backend_steps=backend_steps, + nesterov=nesterov, weight_decay=weight_decay), + ) + + @torch.no_grad() + def step(self, closure=None): + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + distributed = dist.is_available() and dist.is_initialized() + world_size = dist.get_world_size() if distributed else 1 + rank = dist.get_rank() if distributed else 0 + + for group in self.param_groups: + params = group["params"] + if not params: + continue + lr = group["lr"] + momentum = group["momentum"] + backend_steps = group["backend_steps"] + nesterov = group["nesterov"] + wd = group.get("weight_decay", 0.0) + + # Cache flat buffer to avoid per-step allocation + if "_updates_flat" not in group: + group["_total_params"] = sum(int(p.numel()) for p in params) + group["_updates_flat"] = torch.zeros(group["_total_params"], device=params[0].device, dtype=torch.bfloat16) + total_params = group["_total_params"] + updates_flat = group["_updates_flat"] + updates_flat.zero_() + + curr = 0 + for i, p in enumerate(params): + if i % world_size == rank and p.grad is not None: + g = p.grad + state = self.state[p] + if "momentum_buffer" not in state: + state["momentum_buffer"] = torch.zeros_like(g) + buf = state["momentum_buffer"] + buf.mul_(momentum).add_(g) + if nesterov: + g = g.add(buf, alpha=momentum) + g = zeropower_via_newtonschulz5(g, steps=backend_steps) + # Scale correction from Muon reference implementations. + g *= max(1, g.size(0) / g.size(1)) ** 0.5 + updates_flat[curr : curr + p.numel()] = g.reshape(-1) + curr += p.numel() + + if distributed: + dist.all_reduce(updates_flat, op=dist.ReduceOp.SUM) + + curr = 0 + for p in params: + g = updates_flat[curr : curr + p.numel()].view_as(p).to(dtype=p.dtype) + p.add_(g, alpha=-lr) + if wd > 0.0: + p.mul_(1.0 - wd * lr) + curr += p.numel() + + return loss + + +# ----------------------------- +# TOKENIZER-AGNOSTIC EVALUATION SETUP +# ----------------------------- +# +# It's common for small models have a large fraction of their parameters be embeddings, since the 2 * d_model * d_vocab vectors can be gigantic. +# Instead of locking the tokenizer, we let you bring your own and calculate our validation metrics on the average compression of the validation set. +# We calculate BPB (bits-per-byte) instead of validation loss, so we need methods to count the number of bits per token in the tokenizer. +# Note: Submissions that edit the tokenizer will be examined more carefully, since screwing this up might unjustly improve your score. + +def build_sentencepiece_luts( + sp: spm.SentencePieceProcessor, vocab_size: int, device: torch.device +) -> tuple[Tensor, Tensor, Tensor]: + sp_vocab_size = int(sp.vocab_size()) + table_size = max(sp_vocab_size, vocab_size) + base_bytes_np = np.zeros((table_size,), dtype=np.int16) + has_leading_space_np = np.zeros((table_size,), dtype=np.bool_) + is_boundary_token_np = np.ones((table_size,), dtype=np.bool_) + for token_id in range(sp_vocab_size): + if sp.is_control(token_id) or sp.is_unknown(token_id) or sp.is_unused(token_id): + continue + is_boundary_token_np[token_id] = False + if sp.is_byte(token_id): + base_bytes_np[token_id] = 1 + continue + piece = sp.id_to_piece(token_id) + if piece.startswith("▁"): + has_leading_space_np[token_id] = True + piece = piece[1:] + base_bytes_np[token_id] = len(piece.encode("utf-8")) + return ( + torch.tensor(base_bytes_np, dtype=torch.int16, device=device), + torch.tensor(has_leading_space_np, dtype=torch.bool, device=device), + torch.tensor(is_boundary_token_np, dtype=torch.bool, device=device), + ) + + +def load_validation_tokens(pattern: str, seq_len: int) -> Tensor: + files = [Path(p) for p in sorted(glob.glob(pattern))] + if not files: + raise FileNotFoundError(f"No files found for pattern: {pattern}") + # The export pipeline writes the fixed first-50k-doc validation set to fineweb_val_*. + tokens = torch.cat([load_data_shard(file) for file in files]).contiguous() + usable = ((tokens.numel() - 1) // seq_len) * seq_len + if usable <= 0: + raise ValueError(f"Validation split is too short for TRAIN_SEQ_LEN={seq_len}") + return tokens[: usable + 1] + + +def eval_val( + args: Hyperparameters, + model: nn.Module, + rank: int, + world_size: int, + device: torch.device, + grad_accum_steps: int, + val_tokens: Tensor, + base_bytes_lut: Tensor, + has_leading_space_lut: Tensor, + is_boundary_token_lut: Tensor, +) -> tuple[float, float]: + # Validation computes two metrics: + # - val_loss: token cross-entropy (natural log) + # - val_bpb: tokenizer-agnostic compression metric used by the challenge + local_batch_tokens = args.val_batch_size // (world_size * grad_accum_steps) + if local_batch_tokens < args.train_seq_len: + raise ValueError( + "VAL_BATCH_SIZE must provide at least one sequence per rank; " + f"got VAL_BATCH_SIZE={args.val_batch_size}, WORLD_SIZE={world_size}, " + f"GRAD_ACCUM_STEPS={grad_accum_steps}, TRAIN_SEQ_LEN={args.train_seq_len}" + ) + local_batch_seqs = local_batch_tokens // args.train_seq_len + total_seqs = (val_tokens.numel() - 1) // args.train_seq_len + seq_start = (total_seqs * rank) // world_size + seq_end = (total_seqs * (rank + 1)) // world_size + val_loss_sum = torch.zeros((), device=device, dtype=torch.float64) + val_token_count = torch.zeros((), device=device, dtype=torch.float64) + val_byte_count = torch.zeros((), device=device, dtype=torch.float64) + + model.eval() + with torch.inference_mode(): + for batch_seq_start in range(seq_start, seq_end, local_batch_seqs): + batch_seq_end = min(batch_seq_start + local_batch_seqs, seq_end) + raw_start = batch_seq_start * args.train_seq_len + raw_end = batch_seq_end * args.train_seq_len + 1 + local = val_tokens[raw_start:raw_end].to(device=device, dtype=torch.int64, non_blocking=True) + x = local[:-1].reshape(-1, args.train_seq_len) + y = local[1:].reshape(-1, args.train_seq_len) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + batch_loss = model(x, y).detach() + batch_token_count = float(y.numel()) + val_loss_sum += batch_loss.to(torch.float64) * batch_token_count + val_token_count += batch_token_count + prev_ids = x.reshape(-1) + tgt_ids = y.reshape(-1) + token_bytes = base_bytes_lut[tgt_ids].to(dtype=torch.int16) + token_bytes += (has_leading_space_lut[tgt_ids] & ~is_boundary_token_lut[prev_ids]).to(dtype=torch.int16) + val_byte_count += token_bytes.to(torch.float64).sum() + + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(val_loss_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(val_token_count, op=dist.ReduceOp.SUM) + dist.all_reduce(val_byte_count, op=dist.ReduceOp.SUM) + + val_loss = val_loss_sum / val_token_count + bits_per_token = val_loss.item() / math.log(2.0) + tokens_per_byte = val_token_count.item() / val_byte_count.item() + model.train() + return float(val_loss.item()), float(bits_per_token * tokens_per_byte) + +# ----------------------------- +# POST-TRAINING QUANTIZATION +# ----------------------------- +# +# It's silly to export our model, which is trained in bf16 and fp32, at that same precision. +# Instead, we get approximately the same model (with a small hit) by quantizing the model to int8 & zlib compressing. +# We can then decompress the model and run in higher precision for evaluation, after closing in under the size limit. + +CONTROL_TENSOR_NAME_PATTERNS = tuple( + pattern + for pattern in os.environ.get( + "CONTROL_TENSOR_NAME_PATTERNS", + "attn_scale,attn_scales,mlp_scale,mlp_scales,resid_mix,resid_mixes,q_gain,skip_weight,skip_weights", + ).split(",") + if pattern +) +INT8_KEEP_FLOAT_FP32_NAME_PATTERNS = tuple( + pattern + for pattern in os.environ.get( + "INT8_KEEP_FLOAT_FP32_NAME_PATTERNS", + ",".join(CONTROL_TENSOR_NAME_PATTERNS), + ).split(",") + if pattern +) +INT8_KEEP_FLOAT_MAX_NUMEL = 65_536 +INT8_KEEP_FLOAT_STORE_DTYPE = torch.float16 +INT8_PER_ROW_SCALE_DTYPE = torch.float16 +INT8_CLIP_PERCENTILE = 99.99984 +INT8_CLIP_Q = INT8_CLIP_PERCENTILE / 100.0 + +def tensor_nbytes(t: Tensor) -> int: + return int(t.numel()) * int(t.element_size()) + +def keep_float_tensor(name: str, t: Tensor, passthrough_orig_dtypes: dict[str, str]) -> Tensor: + if any(pattern in name for pattern in INT8_KEEP_FLOAT_FP32_NAME_PATTERNS): + return t.float().contiguous() + if t.dtype in {torch.float32, torch.bfloat16}: + passthrough_orig_dtypes[name] = str(t.dtype).removeprefix("torch.") + return t.to(dtype=INT8_KEEP_FLOAT_STORE_DTYPE).contiguous() + return t + +def _quantize_2d_with_clip(t32: Tensor, clip_abs: Tensor, max_val: int) -> tuple[Tensor, Tensor, Tensor]: + """Quantize a 2D tensor with given per-row clip values. Returns (q, scale, reconstruction_error).""" + clipped = torch.maximum(torch.minimum(t32, clip_abs[:, None]), -clip_abs[:, None]) + scale = (clip_abs / float(max_val)).clamp_min(1.0 / float(max_val)) + q = torch.clamp(torch.round(clipped / scale[:, None]), -max_val, max_val).to(torch.int8) + # Reconstruction error: MSE between original and dequantized + deq = q.float() * scale[:, None] + err = (t32 - deq).square().mean(dim=1) + return q, scale, err + + +# GPTQ-lite clip ratios: search these percentiles per weight matrix to find optimal clipping. +_GPTQ_CLIP_RATIOS = [0.9, 0.95, 0.99, 0.999, 0.99999] +_gptq_lite: bool = False # set at runtime from GPTQ_LITE env var + + +def quantize_float_tensor(t: Tensor, bits: int = 8) -> tuple[Tensor, Tensor]: + max_val = 127 if bits == 8 else (2 ** (bits - 1)) - 1 # int6: 31, int8: 127 + t32 = t.float() + if t32.ndim == 2: + if _gptq_lite and t32.numel() > 0: + # GPTQ-lite: try multiple clip percentiles per row, pick lowest reconstruction error. + abs_vals = t32.abs() + best_q: Tensor | None = None + best_scale: Tensor | None = None + best_err: Tensor | None = None + for ratio in _GPTQ_CLIP_RATIOS: + clip_abs = torch.quantile(abs_vals, ratio, dim=1) + q, scale, err = _quantize_2d_with_clip(t32, clip_abs, max_val) + if best_err is None: + best_q, best_scale, best_err = q, scale, err + else: + # Per-row: keep whichever clip ratio gave lower error for each row + better = err < best_err + best_q[better] = q[better] + best_scale[better] = scale[better] + best_err[better] = err[better] + return best_q.contiguous(), best_scale.to(dtype=INT8_PER_ROW_SCALE_DTYPE).contiguous() + + # Standard: single clip percentile + clip_abs = ( + torch.quantile(t32.abs(), INT8_CLIP_Q, dim=1) + if t32.numel() + else torch.empty((t32.shape[0],), dtype=torch.float32) + ) + clipped = torch.maximum(torch.minimum(t32, clip_abs[:, None]), -clip_abs[:, None]) + scale = (clip_abs / float(max_val)).clamp_min(1.0 / float(max_val)) + q = torch.clamp(torch.round(clipped / scale[:, None]), -max_val, max_val).to(torch.int8).contiguous() + return q, scale.to(dtype=INT8_PER_ROW_SCALE_DTYPE).contiguous() + + # Vectors / scalars use a simpler per-tensor scale. + clip_abs = float(torch.quantile(t32.abs().flatten(), INT8_CLIP_Q).item()) if t32.numel() else 0.0 + scale = torch.tensor(clip_abs / float(max_val) if clip_abs > 0 else 1.0, dtype=torch.float32) + q = torch.clamp(torch.round(torch.clamp(t32, -clip_abs, clip_abs) / scale), -max_val, max_val).to(torch.int8).contiguous() + return q, scale + +# Names of tensors to keep in fp16 at export instead of quantizing to int8. +# Populated at runtime when fp16_embed_export=True. +_FP16_EXPORT_NAMES: set[str] = set() + +def quantize_state_dict_int8(state_dict: dict[str, Tensor], bits: int = 8): + # Clean-script export format supporting int8 or int6: + # - per-row quantization for 2D float tensors (int8: [-127,127], int6: [-31,31]) + # - per-tensor quantization for other float tensors + # - exact passthrough for non-floats + # - passthrough for small float tensors, stored as fp16 to save bytes + # - fp16 passthrough for tensors in _FP16_EXPORT_NAMES (e.g. tied embeddings) + quantized: dict[str, Tensor] = {} + scales: dict[str, Tensor] = {} + dtypes: dict[str, str] = {} + passthrough: dict[str, Tensor] = {} + passthrough_orig_dtypes: dict[str, str] = {} + qmeta: dict[str, dict[str, object]] = {} + stats = dict.fromkeys( + ("param_count", "num_tensors", "num_float_tensors", "num_nonfloat_tensors", "baseline_tensor_bytes", "int8_payload_bytes"), + 0, + ) + + for name, tensor in state_dict.items(): + t = tensor.detach().to("cpu").contiguous() + stats["param_count"] += int(t.numel()) + stats["num_tensors"] += 1 + stats["baseline_tensor_bytes"] += tensor_nbytes(t) + + if not t.is_floating_point(): + stats["num_nonfloat_tensors"] += 1 + passthrough[name] = t + stats["int8_payload_bytes"] += tensor_nbytes(t) + continue + + # FP16 export bypass: keep specified tensors (e.g. tied embeddings) in fp16 + # instead of quantizing to int8. Avoids compounding int8 errors through both + # the input embedding and output projection paths. + if name in _FP16_EXPORT_NAMES: + kept = t.to(dtype=torch.float16).contiguous() + passthrough_orig_dtypes[name] = str(t.dtype).removeprefix("torch.") + passthrough[name] = kept + stats["int8_payload_bytes"] += tensor_nbytes(kept) + continue + + # Small float tensors are cheap enough to keep directly. We still downcast + # fp32/bf16 passthrough tensors to fp16 so metadata does not dominate size. + if t.numel() <= INT8_KEEP_FLOAT_MAX_NUMEL: + kept = keep_float_tensor(name, t, passthrough_orig_dtypes) + passthrough[name] = kept + stats["int8_payload_bytes"] += tensor_nbytes(kept) + continue + + stats["num_float_tensors"] += 1 + q, s = quantize_float_tensor(t, bits=bits) + if s.ndim > 0: + qmeta[name] = {"scheme": "per_row", "axis": 0} + quantized[name] = q + scales[name] = s + dtypes[name] = str(t.dtype).removeprefix("torch.") + stats["int8_payload_bytes"] += tensor_nbytes(q) + tensor_nbytes(s) + + obj: dict[str, object] = { + "__quant_format__": "int8_clean_per_row_v1", + "quantized": quantized, + "scales": scales, + "dtypes": dtypes, + "passthrough": passthrough, + } + if qmeta: + obj["qmeta"] = qmeta + if passthrough_orig_dtypes: + obj["passthrough_orig_dtypes"] = passthrough_orig_dtypes + return obj, stats + +def dequantize_state_dict_int8(obj: dict[str, object]) -> dict[str, Tensor]: + out: dict[str, Tensor] = {} + qmeta = obj.get("qmeta", {}) + passthrough_orig_dtypes = obj.get("passthrough_orig_dtypes", {}) + for name, q in obj["quantized"].items(): + dtype = getattr(torch, obj["dtypes"][name]) + s = obj["scales"][name] + if qmeta.get(name, {}).get("scheme") == "per_row" or s.ndim > 0: + s = s.to(dtype=torch.float32) + # Broadcast the saved row scale back across trailing dimensions. + out[name] = (q.float() * s.view(q.shape[0], *([1] * (q.ndim - 1)))).to(dtype=dtype).contiguous() + else: + scale = float(s.item()) + out[name] = (q.float() * scale).to(dtype=dtype).contiguous() + for name, t in obj["passthrough"].items(): + # Restore small tensors, undoing the temporary fp16 storage cast if needed. + out_t = t.detach().to("cpu").contiguous() + orig_dtype = passthrough_orig_dtypes.get(name) + if isinstance(orig_dtype, str): + out_t = out_t.to(dtype=getattr(torch, orig_dtype)).contiguous() + out[name] = out_t + return out + + +# ----------------------------- +# DATA LOADING +# ----------------------------- + +def load_data_shard(file: Path) -> Tensor: + header_bytes = 256 * np.dtype(" list[Path]: + """Sort shards by average document length (shorter docs = easier = first). + + Reads only the first 100K tokens of each shard to estimate avg doc length. + Shards with shorter average documents contain simpler, more repetitive text + that helps the model learn basic patterns before encountering harder material. + """ + difficulties: list[tuple[float, Path]] = [] + for f in files: + header = np.fromfile(f, dtype=" None: + self.file_idx = (self.file_idx + 1) % len(self.files) + self.tokens = load_data_shard(self.files[self.file_idx]) + self.pos = 0 + + def take(self, n: int) -> Tensor: + chunks: list[Tensor] = [] + remaining = n + while remaining > 0: + avail = self.tokens.numel() - self.pos + if avail <= 0: + self._advance_file() + continue + k = min(remaining, avail) + chunks.append(self.tokens[self.pos : self.pos + k]) + self.pos += k + remaining -= k + return chunks[0] if len(chunks) == 1 else torch.cat(chunks) + + +class DistributedTokenLoader: + # Each call consumes a contiguous chunk from the shared token stream, then slices out + # one disjoint span per rank. The extra "+1" token lets us build (x, y) by shifting. + def __init__(self, pattern: str, rank: int, world_size: int, device: torch.device, + sorted_files: list[Path] | None = None): + self.rank = rank + self.world_size = world_size + self.device = device + self.stream = TokenStream(pattern, sorted_files=sorted_files) + + def next_batch(self, global_tokens: int, seq_len: int, grad_accum_steps: int) -> tuple[Tensor, Tensor]: + local_tokens = global_tokens // (self.world_size * grad_accum_steps) + per_rank_span = local_tokens + 1 + chunk = self.stream.take(per_rank_span * self.world_size) + start = self.rank * per_rank_span + local = chunk[start : start + per_rank_span].to(dtype=torch.int64) + x = local[:-1].reshape(-1, seq_len) + y = local[1:].reshape(-1, seq_len) + return x.to(self.device, non_blocking=True), y.to(self.device, non_blocking=True) + +# ----------------------------- +# TRANSFORMER MODULES +# ----------------------------- + +class RMSNorm(nn.Module): + def __init__(self, eps: float | None = None): + super().__init__() + self.eps = eps + + def forward(self, x: Tensor) -> Tensor: + return F.rms_norm(x, (x.size(-1),), eps=self.eps) + + +def fake_quantize_int8_per_row(w: Tensor) -> Tensor: + """Simulate per-row int6 quantization with straight-through estimator. + + Uses absmax per-row scale (no torch.quantile — O(n) instead of O(n log n)). + Matches the int6 export range [-31, 31]. + Backward: gradients pass through as if no quantization happened (STE). + """ + w32 = w.float() + scale = (w32.abs().amax(dim=1) / 31.0).clamp_min(1.0 / 31.0) + w_q = torch.clamp(torch.round(w32 / scale[:, None]), -31, 31) + w_deq = w_q * scale[:, None] + return w + (w_deq.to(w.dtype) - w).detach() + + +class CastedLinear(nn.Linear): + # Keep weights in fp32 for optimizer/state quality, cast at matmul time for bf16 compute. + _qat: bool = False + + def forward(self, x: Tensor) -> Tensor: + w = self.weight + if self._qat and self.training: + w = fake_quantize_int8_per_row(w) + bias = self.bias.to(x.dtype) if self.bias is not None else None + return F.linear(x, w.to(x.dtype), bias) + + +class AttentionLoRA(nn.Module): + """Per-iteration LoRA adapters for attention Q, K, V, and output projections. + + Initialized so that the LoRA contribution is zero at the start of training + (B matrices are zeros). During training, the optimizer learns per-iteration + specialization while the base attention weights remain shared across loops. + """ + def __init__(self, dim: int, kv_dim: int, rank: int): + super().__init__() + self.q_A = nn.Parameter(torch.empty(dim, rank)) + self.q_B = nn.Parameter(torch.zeros(rank, dim)) + self.k_A = nn.Parameter(torch.empty(dim, rank)) + self.k_B = nn.Parameter(torch.zeros(rank, kv_dim)) + self.v_A = nn.Parameter(torch.empty(dim, rank)) + self.v_B = nn.Parameter(torch.zeros(rank, kv_dim)) + self.proj_A = nn.Parameter(torch.empty(dim, rank)) + self.proj_B = nn.Parameter(torch.zeros(rank, dim)) + self._init_lora() + + def _init_lora(self) -> None: + for name in ("q_A", "k_A", "v_A", "proj_A"): + nn.init.kaiming_uniform_(getattr(self, name), a=math.sqrt(5)) + + +def restore_low_dim_params_to_fp32(module: nn.Module) -> None: + # Keep small/control parameters in fp32 even when the model body runs in bf16. + with torch.no_grad(): + for name, param in module.named_parameters(): + if (param.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS)) and param.dtype != torch.float32: + param.data = param.data.float() + + +class Rotary(nn.Module): + # Caches cos/sin tables per sequence length on the current device. + # ntk_base_seq_len: if > 0, apply NTK-aware RoPE scaling when seq_len > ntk_base_seq_len + # (lets a model trained at seq_len=1024 generalise to seq_len=2048 at eval with no quality loss) + def __init__(self, dim: int, base: float = 10000.0, ntk_base_seq_len: int = 0): + super().__init__() + self._dim = dim + self._base = base + self._ntk_base_seq_len = ntk_base_seq_len + inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim)) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self._seq_len_cached = 0 + self._cos_cached: Tensor | None = None + self._sin_cached: Tensor | None = None + + def forward(self, seq_len: int, device: torch.device, dtype: torch.dtype) -> tuple[Tensor, Tensor]: + if ( + self._cos_cached is None + or self._sin_cached is None + or self._seq_len_cached != seq_len + or self._cos_cached.device != device + ): + if self._ntk_base_seq_len > 0 and seq_len > self._ntk_base_seq_len: + # NTK-aware scaling: extend context without fine-tuning + scale = seq_len / self._ntk_base_seq_len + ntk_base = self._base * (scale ** (self._dim / (self._dim - 2))) + inv_freq = 1.0 / (ntk_base ** (torch.arange(0, self._dim, 2, dtype=torch.float32, device=device) / self._dim)) + else: + inv_freq = self.inv_freq.to(device) + t = torch.arange(seq_len, device=device, dtype=inv_freq.dtype) + freqs = torch.outer(t, inv_freq) + self._cos_cached = freqs.cos()[None, None, :, :] + self._sin_cached = freqs.sin()[None, None, :, :] + self._seq_len_cached = seq_len + return self._cos_cached.to(dtype=dtype), self._sin_cached.to(dtype=dtype) + + +def apply_rotary_emb(x: Tensor, cos: Tensor, sin: Tensor) -> Tensor: + half = x.size(-1) // 2 + x1, x2 = x[..., :half], x[..., half:] + return torch.cat((x1 * cos + x2 * sin, x1 * (-sin) + x2 * cos), dim=-1) + + +class CausalSelfAttention(nn.Module): + def __init__( + self, + dim: int, + num_heads: int, + num_kv_heads: int, + rope_base: float, + qk_gain_init: float, + ntk_base_seq_len: int = 0, + rope_dims: int = 0, + ): + super().__init__() + if dim % num_heads != 0: + raise ValueError("model_dim must be divisible by num_heads") + if num_heads % num_kv_heads != 0: + raise ValueError("num_heads must be divisible by num_kv_heads") + self.num_heads = num_heads + self.num_kv_heads = num_kv_heads + self.head_dim = dim // num_heads + if self.head_dim % 2 != 0: + raise ValueError("head_dim must be even for RoPE") + kv_dim = self.num_kv_heads * self.head_dim + self.c_q = CastedLinear(dim, dim, bias=False) + self.c_k = CastedLinear(dim, kv_dim, bias=False) + self.c_v = CastedLinear(dim, kv_dim, bias=False) + self.proj = CastedLinear(dim, dim, bias=False) + self.proj._zero_init = True + self.q_gain = nn.Parameter(torch.full((num_heads,), qk_gain_init, dtype=torch.float32)) + # rope_dims: if >0 apply RoPE only to first rope_dims dims of each head; rest are position-free + self.rope_dims = rope_dims if rope_dims > 0 else self.head_dim + self.rotary = Rotary(self.rope_dims, base=rope_base, ntk_base_seq_len=ntk_base_seq_len) + self.use_xsa = False # enabled on last N layers by GPT.__init__ + + def _xsa_efficient(self, y: Tensor, v: Tensor) -> Tensor: + """Exclusive Self Attention: subtract self-value projection (arXiv:2603.09078). + GQA-aware reshape avoids repeat_interleave — zero extra allocation.""" + B, T, H, D = y.shape + Hkv = v.size(-2) + group = H // Hkv + y_g = y.reshape(B, T, Hkv, group, D) + vn = F.normalize(v, dim=-1).unsqueeze(-2) # (B, T, Hkv, 1, D) + proj = (y_g * vn).sum(dim=-1, keepdim=True) * vn + return (y_g - proj).reshape(B, T, H, D) + + def forward(self, x: Tensor, lora: AttentionLoRA | None = None, v_embed: Tensor | None = None) -> Tensor: + bsz, seqlen, dim = x.shape + q = self.c_q(x) + k = self.c_k(x) + v = self.c_v(x) + if v_embed is not None: + v = v + v_embed + if lora is not None: + # LoRA delta: (bsz, seqlen, dim) @ (dim, rank) @ (rank, out_dim) + # autocast handles fp32->bf16 cast of LoRA params automatically + q = q + (x @ lora.q_A) @ lora.q_B + k = k + (x @ lora.k_A) @ lora.k_B + v = v + (x @ lora.v_A) @ lora.v_B + q = q.reshape(bsz, seqlen, self.num_heads, self.head_dim).transpose(1, 2) + k = k.reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) + v = v.reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) + q = F.rms_norm(q, (q.size(-1),)) + k = F.rms_norm(k, (k.size(-1),)) + cos, sin = self.rotary(seqlen, x.device, q.dtype) + if self.rope_dims < self.head_dim: + # Partial RoPE: apply rotation to first rope_dims dims, leave remaining dims untouched + q = torch.cat([apply_rotary_emb(q[..., :self.rope_dims], cos, sin), q[..., self.rope_dims:]], dim=-1) + k = torch.cat([apply_rotary_emb(k[..., :self.rope_dims], cos, sin), k[..., self.rope_dims:]], dim=-1) + else: + q = apply_rotary_emb(q, cos, sin) + k = apply_rotary_emb(k, cos, sin) + q = q * self.q_gain.to(dtype=q.dtype)[None, :, None, None] + if _HAS_FA3 and _use_fa3: + # FA3 expects [bsz, seqlen, heads, head_dim] + q_fa = q.transpose(1, 2).contiguous() + k_fa = k.transpose(1, 2).contiguous() + v_fa = v.transpose(1, 2).contiguous() + y = _flash_attn_3_func(q_fa, k_fa, v_fa, causal=True) + if self.use_xsa: + y = self._xsa_efficient(y, v_fa) + y = y.reshape(bsz, seqlen, dim) + else: + y = F.scaled_dot_product_attention( + q, k, v, attn_mask=None, is_causal=True, + enable_gqa=(self.num_kv_heads != self.num_heads), + ) + if self.use_xsa: + y = self._xsa_efficient(y.transpose(1, 2), v.transpose(1, 2)) + else: + y = y.transpose(1, 2) + y = y.contiguous().reshape(bsz, seqlen, dim) + out = self.proj(y) + if lora is not None: + out = out + (y @ lora.proj_A) @ lora.proj_B + return out + + +class MLP(nn.Module): + # relu^2 MLP from the original modded-nanogpt setup + def __init__(self, dim: int, mlp_mult: int, mlp_hidden: int = 0): + super().__init__() + hidden = mlp_hidden if mlp_hidden > 0 else mlp_mult * dim + self.fc = CastedLinear(dim, hidden, bias=False) + self.proj = CastedLinear(hidden, dim, bias=False) + self.proj._zero_init = True + + def forward(self, x: Tensor) -> Tensor: + x = torch.relu(self.fc(x)) + return self.proj(x.square()) + + +class Block(nn.Module): + def __init__( + self, + dim: int, + num_heads: int, + num_kv_heads: int, + mlp_mult: int, + rope_base: float, + qk_gain_init: float, + mlp_hidden: int = 0, + ntk_base_seq_len: int = 0, + rope_dims: int = 0, + layer_idx: int = 0, + ln_scale: bool = False, + ): + super().__init__() + self.attn_norm = RMSNorm() + self.mlp_norm = RMSNorm() + self.attn = CausalSelfAttention(dim, num_heads, num_kv_heads, rope_base, qk_gain_init, ntk_base_seq_len=ntk_base_seq_len, rope_dims=rope_dims) + self.mlp = MLP(dim, mlp_mult, mlp_hidden=mlp_hidden) + self.attn_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.mlp_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.resid_mix = nn.Parameter(torch.stack((torch.ones(dim), torch.zeros(dim))).float()) + self.ln_scale_factor = 1.0 / math.sqrt(layer_idx + 1) if ln_scale else 1.0 + + def forward(self, x: Tensor, x0: Tensor, lora: AttentionLoRA | None = None, v_embed: Tensor | None = None) -> Tensor: + mix = self.resid_mix.to(dtype=x.dtype) + x = mix[0][None, None, :] * x + mix[1][None, None, :] * x0 + s = self.ln_scale_factor + attn_out = self.attn(self.attn_norm(x) * s, lora=lora, v_embed=v_embed) + x = x + self.attn_scale.to(dtype=x.dtype)[None, None, :] * attn_out + x = x + self.mlp_scale.to(dtype=x.dtype)[None, None, :] * self.mlp(self.mlp_norm(x) * s) + return x + + +# ── Shared Value Embedding (VE) ────────────────────────────────────────────── +class ValueEmbedding(nn.Module): + """Learned embedding added to attention values in selected layers. + One shared table across layers, with per-layer learned scales.""" + def __init__(self, vocab_size: int, ve_dim: int, kv_dim: int): + super().__init__() + self.embed = nn.Embedding(vocab_size, ve_dim) + nn.init.normal_(self.embed.weight, std=0.01) + self.proj = CastedLinear(ve_dim, kv_dim, bias=False) if ve_dim != kv_dim else None + if self.proj is not None: + nn.init.zeros_(self.proj.weight) + self.scale = nn.Parameter(torch.tensor(0.1, dtype=torch.float32)) + + def forward(self, token_ids: Tensor) -> Tensor: + h = self.embed(token_ids) + if self.proj is not None: + h = self.proj(h) + return h * self.scale.to(dtype=h.dtype) + + +# ── SmearGate: cheap bigram context at embedding layer ────────────────────── +class SmearGate(nn.Module): + """Blends each token's embedding with the previous token's via per-channel sigmoid gates. + 512 independent channel gates (vs scalar gate) give the model richer bigram context.""" + def __init__(self, dim: int): + super().__init__() + self.gate = nn.Parameter(torch.zeros(dim, dtype=torch.float32)) + + def forward(self, x: Tensor) -> Tensor: + # x: (bsz, seq_len, dim) + prev = torch.cat([torch.zeros_like(x[:, :1]), x[:, :-1]], dim=1) + g = torch.sigmoid(self.gate.to(dtype=x.dtype))[None, None, :] # (1, 1, dim) + return torch.lerp(x, prev, g) + + +# ── BigramHash: hash-based bigram embedding ───────────────────────────────── +class BigramHashEmbedding(nn.Module): + """Maps consecutive token pairs to embeddings via a hash table. + Injects explicit bigram statistics into the residual stream.""" + def __init__(self, num_buckets: int, hash_dim: int, model_dim: int): + super().__init__() + self.num_buckets = num_buckets + self.embed = nn.Embedding(num_buckets, hash_dim) + self.proj = nn.Linear(hash_dim, model_dim, bias=False) + self.scale = nn.Parameter(torch.tensor(0.05, dtype=torch.float32)) + nn.init.normal_(self.embed.weight, std=0.01) + nn.init.zeros_(self.proj.weight) + + def _hash_pair(self, prev_ids: Tensor, cur_ids: Tensor) -> Tensor: + # XOR hash with large primes for better distribution + return (torch.bitwise_xor(36313 * prev_ids.long(), 27191 * cur_ids.long()) % max(self.num_buckets - 1, 1)).to(prev_ids.device) + + def forward(self, input_ids: Tensor) -> Tensor: + # input_ids: (bsz, seq_len) + prev_ids = torch.cat([torch.zeros_like(input_ids[:, :1]), input_ids[:, :-1]], dim=1) + bucket_ids = self._hash_pair(prev_ids, input_ids) + return self.scale.to(dtype=input_ids.dtype if input_ids.is_floating_point() else torch.bfloat16) * self.proj(self.embed(bucket_ids)) + + +class GPT(nn.Module): + def __init__( + self, + vocab_size: int, + num_layers: int, + model_dim: int, + num_heads: int, + num_kv_heads: int, + mlp_mult: int, + tie_embeddings: bool, + tied_embed_init_std: float, + logit_softcap: float, + rope_base: float, + qk_gain_init: float, + num_loops: int = 1, + lora_rank: int = 0, + mlp_hidden: int = 0, + smear_gate: bool = False, + bigram_hash: bool = False, + bigram_hash_buckets: int = 4096, + bigram_hash_dim: int = 128, + ortho_init: bool = True, + xsa_last_n: int = 0, + ntk_base_seq_len: int = 0, + rope_dims: int = 0, + ln_scale: bool = False, + unet_skips: bool = False, + ve_enabled: bool = False, + ve_dim: int = 128, + ve_layers: str = "9,10", + ): + super().__init__() + if logit_softcap <= 0.0: + raise ValueError(f"logit_softcap must be positive, got {logit_softcap}") + self.ortho_init = ortho_init + self.tie_embeddings = tie_embeddings + self.tied_embed_init_std = tied_embed_init_std + self.logit_softcap = logit_softcap + self.num_unique_layers = num_layers + self.num_loops = num_loops + effective_depth = num_layers * num_loops + self.tok_emb = nn.Embedding(vocab_size, model_dim) + self.smear_gate = SmearGate(model_dim) if smear_gate else None + self.bigram_hash = BigramHashEmbedding(bigram_hash_buckets, bigram_hash_dim, model_dim) if bigram_hash else None + # Shared Value Embedding: one table, added to V in selected layers + self.ve_layer_indices = [int(x) for x in ve_layers.split(",") if x.strip()] if ve_enabled else [] + kv_dim = num_kv_heads * (model_dim // num_heads) + self.ve_shared = ValueEmbedding(vocab_size, ve_dim, kv_dim) if ve_enabled else None + self.ve_layer_scales = nn.ParameterList( + [nn.Parameter(torch.ones(1, dtype=torch.float32)) for _ in self.ve_layer_indices] + ) if ve_enabled else None + self.unet_skips = unet_skips + self.num_encoder_layers = effective_depth // 2 + self.num_decoder_layers = effective_depth - self.num_encoder_layers + self.num_skip_weights = min(self.num_encoder_layers, self.num_decoder_layers) + self.skip_weights = nn.Parameter(torch.ones(self.num_skip_weights, model_dim, dtype=torch.float32)) if unet_skips else None + self.blocks = nn.ModuleList( + [ + Block( + model_dim, + num_heads, + num_kv_heads, + mlp_mult, + rope_base, + qk_gain_init, + mlp_hidden=mlp_hidden, + ntk_base_seq_len=ntk_base_seq_len, + rope_dims=rope_dims, + layer_idx=i, + ln_scale=ln_scale, + ) + for i in range(num_layers) + ] + ) + # Per-(loop, block) LoRA adapters for attention projections. + # Only created when num_loops > 1 and lora_rank > 0. + kv_dim = num_kv_heads * (model_dim // num_heads) + if lora_rank > 0 and num_loops > 1: + self.lora_adapters = nn.ModuleList( + [ + nn.ModuleList( + [AttentionLoRA(model_dim, kv_dim, lora_rank) for _ in range(num_layers)] + ) + for _ in range(num_loops) + ] + ) + else: + self.lora_adapters = None + if xsa_last_n > 0: + for i in range(max(0, num_layers - xsa_last_n), num_layers): + self.blocks[i].attn.use_xsa = True + self.final_norm = RMSNorm() + self.lm_head = None if tie_embeddings else CastedLinear(model_dim, vocab_size, bias=False) + if self.lm_head is not None: + self.lm_head._zero_init = True + self._init_weights() + + def _init_weights(self) -> None: + if self.tie_embeddings: + nn.init.normal_(self.tok_emb.weight, mean=0.0, std=self.tied_embed_init_std) + for module in self.modules(): + if isinstance(module, (nn.Linear, CastedLinear)) and getattr(module, "_zero_init", False): + nn.init.zeros_(module.weight) + elif self.ortho_init and isinstance(module, (nn.Linear, CastedLinear)) \ + and not getattr(module, "_zero_init", False) \ + and module.weight.ndim >= 2: + nn.init.orthogonal_(module.weight, gain=1.0) + + def forward(self, input_ids: Tensor, target_ids: Tensor) -> Tensor: + x = self.tok_emb(input_ids) + if self.bigram_hash is not None: + x = x + self.bigram_hash(input_ids) + x = F.rms_norm(x, (x.size(-1),)) + if self.smear_gate is not None: + x = self.smear_gate(x) + x0 = x + num_enc = self.num_encoder_layers + num_dec = self.num_decoder_layers + + # Compute shared VE once (cached across layers) + ve_base = self.ve_shared(input_ids) if self.ve_shared is not None else None + + def _get_ve(layer_idx: int) -> Tensor | None: + if ve_base is None or layer_idx not in self.ve_layer_indices: + return None + ve_idx = self.ve_layer_indices.index(layer_idx) + return ve_base * self.ve_layer_scales[ve_idx].to(dtype=ve_base.dtype) + + if self.unet_skips: + skips: list[Tensor] = [] + for i in range(num_enc): + x = self.blocks[i](x, x0, v_embed=_get_ve(i)) + skips.append(x) + for i in range(num_dec): + if skips: + x = x + self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips.pop() + x = self.blocks[num_enc + i](x, x0, v_embed=_get_ve(num_enc + i)) + else: + eff_idx = 0 + for loop_idx in range(self.num_loops): + for block_idx in range(self.num_unique_layers): + lora = self.lora_adapters[loop_idx][block_idx] if self.lora_adapters is not None else None + x = self.blocks[block_idx](x, x0, lora=lora, v_embed=_get_ve(eff_idx)) + eff_idx += 1 + + x = self.final_norm(x).reshape(-1, x.size(-1)) + targets = target_ids.reshape(-1) + if self.tie_embeddings: + logits_proj = F.linear(x, self.tok_emb.weight) + else: + if self.lm_head is None: + raise RuntimeError("lm_head is required when tie_embeddings=False") + logits_proj = self.lm_head(x) + logits = self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap) + return F.cross_entropy(logits.float(), targets, reduction="mean") + + def forward_logits(self, input_ids: Tensor) -> Tensor: + """Return logits (bsz, seq_len, vocab) without computing loss.""" + x = self.tok_emb(input_ids) + if self.bigram_hash is not None: + x = x + self.bigram_hash(input_ids) + x = F.rms_norm(x, (x.size(-1),)) + if self.smear_gate is not None: + x = self.smear_gate(x) + x0 = x + num_enc = self.num_encoder_layers + num_dec = self.num_decoder_layers + + if self.unet_skips: + skips: list[Tensor] = [] + for i in range(num_enc): + x = self.blocks[i](x, x0) + skips.append(x) + for i in range(num_dec): + if skips: + x = x + self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips.pop() + x = self.blocks[num_enc + i](x, x0) + else: + eff_idx = 0 + for loop_idx in range(self.num_loops): + for block_idx in range(self.num_unique_layers): + lora = self.lora_adapters[loop_idx][block_idx] if self.lora_adapters is not None else None + x = self.blocks[block_idx](x, x0, lora=lora) + eff_idx += 1 + x = self.final_norm(x) + if self.tie_embeddings: + logits_proj = F.linear(x, self.tok_emb.weight) + else: + logits_proj = self.lm_head(x) + return self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap) + + +BOS_ID = 1 # SentencePiece BOS token ID + +def _find_docs(all_tokens: Tensor) -> list[tuple[int, int]]: + """Return (start, length) for each document, identified by BOS boundaries. + + Each document starts at a BOS token and extends to just before the next BOS. + The last document extends to the end of the token stream. + """ + bos_positions = (all_tokens == BOS_ID).nonzero(as_tuple=True)[0].tolist() + if not bos_positions: + return [(0, all_tokens.numel())] + docs = [] + for i in range(len(bos_positions)): + start = bos_positions[i] + end = bos_positions[i + 1] if i + 1 < len(bos_positions) else all_tokens.numel() + if end - start >= 2: # need at least 2 tokens for (x, y) pair + docs.append((start, end - start)) + return docs + + +def _build_sliding_windows( + total_tokens: int, seq_len: int, stride: int +) -> list[tuple[int, int]]: + """Return (window_start, score_start) pairs covering every token exactly once. + + Every token in [0, total_tokens) is scored by exactly one window. + Full windows score their last `stride` positions (first window scores all seq_len). + One tail-aligned window covers any tokens beyond the last full window's end. + """ + if total_tokens <= 0: + return [] + if total_tokens <= seq_len: + return [(0, 0)] + + windows: list[tuple[int, int]] = [] + last_full_end = 0 + ws = 0 + while ws + seq_len <= total_tokens: + s = 0 if ws == 0 else seq_len - stride + windows.append((ws, s)) + last_full_end = ws + seq_len + ws += stride + + # One tail window ending exactly at total_tokens covers any remainder. + if last_full_end < total_tokens: + tail_ws = total_tokens - seq_len + tail_s = last_full_end - tail_ws # skip already-scored prefix + windows.append((tail_ws, tail_s)) + + return windows + + +def eval_val_sliding( + args: Hyperparameters, + base_model: nn.Module, + rank: int, + world_size: int, + device: torch.device, + val_tokens: Tensor, + base_bytes_lut: Tensor, + has_leading_space_lut: Tensor, + is_boundary_token_lut: Tensor, + stride: int, + batch_seqs: int = 32, +) -> tuple[float, float]: + """Sliding window evaluation: each token scored with maximum context. + + Windows of eval_seq_len advance by `stride`. Only newly covered tokens + contribute to the score (first window scores all seq_len; non-first full + windows score the last `stride` tokens; one tail window covers any remainder). + Every validation token is counted exactly once. + """ + seq_len = args.eval_seq_len + total_tokens = val_tokens.numel() - 1 + + if args.doc_isolated_eval: + # Build windows per document — context never crosses document boundaries. + # Each window's (ws, s) is in absolute token-stream coordinates. + docs = _find_docs(val_tokens) + all_windows: list[tuple[int, int]] = [] + for doc_start, doc_len in docs: + doc_pred_len = doc_len - 1 # number of prediction positions + doc_windows = _build_sliding_windows(doc_pred_len, seq_len, stride) + for ws, s in doc_windows: + all_windows.append((doc_start + ws, s)) + else: + all_windows = _build_sliding_windows(total_tokens, seq_len, stride) + total_windows = len(all_windows) + + # Distribute across ranks + my_s = (total_windows * rank) // world_size + my_e = (total_windows * (rank + 1)) // world_size + my_windows = all_windows[my_s:my_e] + + loss_sum = torch.zeros((), device=device, dtype=torch.float64) + token_count = torch.zeros((), device=device, dtype=torch.float64) + byte_count = torch.zeros((), device=device, dtype=torch.float64) + + base_model.eval() + with torch.inference_mode(): + for bi in range(0, len(my_windows), batch_seqs): + batch_items = my_windows[bi:bi + batch_seqs] + bsz = len(batch_items) + + x_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) + y_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) + wlens: list[int] = [] + score_starts: list[int] = [] + + for i, (ws, s) in enumerate(batch_items): + end = min(ws + seq_len, total_tokens) + wlen = end - ws + wlens.append(wlen) + score_starts.append(s) + chunk = val_tokens[ws:end + 1].to(dtype=torch.int64, device=device) + x_batch[i, :wlen] = chunk[:-1] + y_batch[i, :wlen] = chunk[1:] + + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + logits = base_model.forward_logits(x_batch) + + nll = F.cross_entropy( + logits.reshape(-1, logits.size(-1)).float(), + y_batch.reshape(-1), + reduction="none", + ).reshape(bsz, seq_len) + + for i, (ws, s) in enumerate(batch_items): + wlen = wlens[i] + scored_nll = nll[i, s:wlen].to(torch.float64) + loss_sum += scored_nll.sum() + token_count += float(wlen - s) + tgt = y_batch[i, s:wlen] + prev = x_batch[i, s:wlen] + tb = base_bytes_lut[tgt].to(torch.float64) + tb += (has_leading_space_lut[tgt] & ~is_boundary_token_lut[prev]).to(torch.float64) + byte_count += tb.sum() + + # Progress (rank 0 only) + if rank == 0 and (bi // batch_seqs) % 50 == 0: + done = min(bi + batch_seqs, len(my_windows)) + pct = done / len(my_windows) * 100 + running_bpb = 0.0 + if token_count.item() > 0: + rl = (loss_sum / token_count).item() + running_bpb = rl / math.log(2.0) * (token_count.item() / byte_count.item()) + print(f" sliding_eval [{pct:5.1f}%] {done}/{len(my_windows)} windows running_bpb={running_bpb:.6f}", flush=True) + + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(loss_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(token_count, op=dist.ReduceOp.SUM) + dist.all_reduce(byte_count, op=dist.ReduceOp.SUM) + + val_loss = (loss_sum / token_count).item() + bits_per_token = val_loss / math.log(2.0) + tokens_per_byte = token_count.item() / byte_count.item() + base_model.train() + return val_loss, bits_per_token * tokens_per_byte + + +# ----------------------------- +# TEST-TIME TRAINING (TTT) +# ----------------------------- + +def _ttt_run_phase( + model: nn.Module, + ttt_params: list[torch.nn.Parameter], + optimizer: torch.optim.Optimizer, + val_tokens: torch.Tensor, + seq_len: int, + batch_seqs: int, + epochs: int, + max_steps: int, + device: torch.device, + rank: int, + world_size: int, + phase_name: str, + t0: float, + cosine: bool = False, + warmup_frac: float = 0.0, +) -> None: + """Run one TTT phase with DDP gradient sharding across GPUs. + Each GPU processes batch_seqs sequences, gradients are manually all_reduced. + Supports cosine lr decay and linear warmup.""" + distributed = world_size > 1 + n_tokens = val_tokens.numel() + total_seqs = (n_tokens - 1) // seq_len + my_start_seq = (total_seqs * rank) // world_size + my_end_seq = (total_seqs * (rank + 1)) // world_size + + # Store initial lr for cosine/warmup scheduling + if cosine or warmup_frac > 0: + for g in optimizer.param_groups: + g["initial_lr"] = g["lr"] + # Estimate actual steps per epoch for cosine schedule + steps_per_epoch = min((my_end_seq - my_start_seq) // max(batch_seqs, 1), max_steps) + total_steps = epochs * steps_per_epoch + global_step = 0 + + model.train() + for epoch in range(epochs): + epoch_loss = torch.zeros((), device=device, dtype=torch.float64) + epoch_tokens = torch.zeros((), device=device, dtype=torch.float64) + step_i = 0 + # Contiguous slicing over this GPU's shard (matches #398 pattern) + for batch_start in range(my_start_seq, my_end_seq, batch_seqs): + if step_i >= max_steps: + break + + # LR schedule: warmup then cosine decay + if (cosine or warmup_frac > 0) and total_steps > 0: + progress = global_step / total_steps + if warmup_frac > 0 and progress < warmup_frac: + mul = progress / warmup_frac + elif cosine: + cos_start = warmup_frac if warmup_frac > 0 else 0.0 + cos_progress = (progress - cos_start) / (1.0 - cos_start) if cos_start < 1.0 else 0.0 + mul = 0.5 * (1.0 + math.cos(math.pi * min(cos_progress, 1.0))) + else: + mul = 1.0 + for g in optimizer.param_groups: + g["lr"] = g["initial_lr"] * mul + + batch_end = min(batch_start + batch_seqs, my_end_seq) + raw_start = batch_start * seq_len + raw_end = batch_end * seq_len + 1 + if raw_end > n_tokens: + break + local = val_tokens[raw_start:raw_end].to(device=device, dtype=torch.int64, non_blocking=True) + x = local[:-1].reshape(-1, seq_len) + y = local[1:].reshape(-1, seq_len) + + optimizer.zero_grad(set_to_none=True) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + loss = model(x, y) + loss.backward() + + if distributed: + for p in ttt_params: + if p.grad is not None: + dist.all_reduce(p.grad, op=dist.ReduceOp.AVG) + + torch.nn.utils.clip_grad_norm_(ttt_params, 1.0) + optimizer.step() + epoch_loss += loss.detach().to(torch.float64) * x.numel() + epoch_tokens += x.numel() + step_i += 1 + global_step += 1 + + if distributed: + dist.all_reduce(epoch_loss, op=dist.ReduceOp.SUM) + dist.all_reduce(epoch_tokens, op=dist.ReduceOp.SUM) + + if rank == 0: + avg_loss = epoch_loss.item() / max(epoch_tokens.item(), 1) + cur_lr = optimizer.param_groups[0]["lr"] + print(f"ttt_{phase_name} epoch:{epoch+1}/{epochs} loss:{avg_loss:.4f} " + f"lr:{cur_lr:.6f} steps:{step_i} time:{time.perf_counter()-t0:.1f}s", flush=True) + + +def ttt_adapt( + args: Hyperparameters, + model: nn.Module, + device: torch.device, + val_tokens: torch.Tensor, + rank: int = 0, + world_size: int = 1, +) -> None: + """Two-phase TTT with DDP gradient sharding. + + Phase 1 (norm-only): Fix quantization artifacts by adapting only LayerNorm + weights, scales, resid_mix, and q_gain (~22K params). Low risk, high epoch count. + + Phase 2 (selective blocks): Adapt last N blocks + norms + head to val distribution. + Higher risk, lower epoch count. + + Falls back to single-phase SGD if TTT_TWO_PHASE=0. + """ + t0 = time.perf_counter() + seq_len = args.train_seq_len + batch_seqs = args.ttt_batch_seqs + + if args.ttt_two_phase: + # ── Phase 1: Norm-only recalibration ──────────────────────────── + # Freeze everything except norms, scales, resid_mix, q_gain + norm_params = [] + for p in model.parameters(): + p.requires_grad_(False) + for name, p in model.named_parameters(): + if any(k in name for k in ("norm", "scale", "resid_mix", "q_gain", "skip_weight")): + p.requires_grad_(True) + norm_params.append(p) + n_norm = sum(p.numel() for p in norm_params) + if rank == 0: + print(f"ttt_phase1:start params:{n_norm} epochs:{args.ttt_p1_epochs} lr:{args.ttt_p1_lr}", flush=True) + + optimizer_p1 = torch.optim.Adam(norm_params, lr=args.ttt_p1_lr) + _ttt_run_phase( + model, norm_params, optimizer_p1, val_tokens, seq_len, batch_seqs, + epochs=args.ttt_p1_epochs, max_steps=args.ttt_max_steps, + device=device, rank=rank, world_size=world_size, + phase_name="phase1", t0=t0, + cosine=args.ttt_cosine, warmup_frac=args.ttt_warmup_frac, + ) + del optimizer_p1 + + # ── Phase 2: Selective block adaptation ───────────────────────── + # Unfreeze last N blocks + all norms + head + embeddings + for p in model.parameters(): + p.requires_grad_(False) + num_layers = len(list(model.blocks)) # type: ignore[attr-defined] + phase2_params = [] + for name, p in model.named_parameters(): + is_late_block = False + for i in range(max(0, num_layers - args.ttt_p2_unfreeze_blocks), num_layers): + if f"blocks.{i}." in name: + is_late_block = True + break + is_norm_or_scale = any(k in name for k in ("norm", "scale", "resid_mix", "q_gain", "skip_weight")) + is_head = "lm_head" in name or "tok_emb" in name + if is_late_block or is_norm_or_scale or is_head: + p.requires_grad_(True) + phase2_params.append(p) + n_p2 = sum(p.numel() for p in phase2_params) + if rank == 0: + print(f"ttt_phase2:start params:{n_p2} epochs:{args.ttt_p2_epochs} lr:{args.ttt_p2_lr}", flush=True) + + if args.ttt_optimizer == "adamw": + optimizer_p2 = torch.optim.AdamW(phase2_params, lr=args.ttt_p2_lr, weight_decay=0.0) + else: + optimizer_p2 = torch.optim.SGD(phase2_params, lr=args.ttt_p2_lr, momentum=args.ttt_momentum) + _ttt_run_phase( + model, phase2_params, optimizer_p2, val_tokens, seq_len, batch_seqs, + epochs=args.ttt_p2_epochs, max_steps=args.ttt_max_steps, + device=device, rank=rank, world_size=world_size, + phase_name="phase2", t0=t0, + cosine=args.ttt_cosine, warmup_frac=args.ttt_warmup_frac, + ) + del optimizer_p2 + else: + # ── Single-phase TTT with cosine + per-layer lr ─────────────── + frozen = set() + for i, block in enumerate(model.blocks): # type: ignore[attr-defined] + if i < args.ttt_freeze_blocks: + for p in block.parameters(): + frozen.add(id(p)) + + if args.ttt_perlayer: + # Per-layer lr: higher for more quant-damaged MLP proj, lower for fc + proj_params = [p for n, p in model.named_parameters() + if "mlp.proj" in n and p.requires_grad and id(p) not in frozen] + fc_params = [p for n, p in model.named_parameters() + if "mlp.fc" in n and p.requires_grad and id(p) not in frozen] + other_params = [p for p in model.parameters() + if p.requires_grad and id(p) not in frozen + and id(p) not in {id(q) for q in proj_params + fc_params}] + param_groups = [ + {"params": proj_params, "lr": args.ttt_lr * 3.0}, + {"params": fc_params, "lr": args.ttt_lr * 0.5}, + {"params": other_params, "lr": args.ttt_lr}, + ] + param_groups = [g for g in param_groups if g["params"]] + ttt_params = proj_params + fc_params + other_params + else: + ttt_params = [p for p in model.parameters() if p.requires_grad and id(p) not in frozen] + param_groups = [{"params": ttt_params, "lr": args.ttt_lr}] + + if rank == 0: + n_ttt = sum(p.numel() for p in ttt_params) + print(f"ttt:start params:{n_ttt} epochs:{args.ttt_epochs} lr:{args.ttt_lr} " + f"freeze:{args.ttt_freeze_blocks} optimizer:{args.ttt_optimizer} " + f"cosine:{args.ttt_cosine} warmup:{args.ttt_warmup_frac} perlayer:{args.ttt_perlayer}", flush=True) + + if args.ttt_optimizer == "adamw": + optimizer = torch.optim.AdamW(param_groups, weight_decay=0.0) + else: + optimizer = torch.optim.SGD(param_groups, momentum=args.ttt_momentum) + _ttt_run_phase( + model, ttt_params, optimizer, val_tokens, seq_len, batch_seqs, + epochs=args.ttt_epochs, max_steps=args.ttt_max_steps, + device=device, rank=rank, world_size=world_size, + phase_name="single", t0=t0, + cosine=args.ttt_cosine, warmup_frac=args.ttt_warmup_frac, + ) + del optimizer + + # Unfreeze all params for eval + for p in model.parameters(): + p.requires_grad_(True) + if rank == 0: + print(f"ttt:done elapsed={time.perf_counter()-t0:.1f}s", flush=True) + + +def reptile_ttt( + args: Hyperparameters, + model: nn.Module, + device: torch.device, + val_tokens: torch.Tensor, + rank: int = 0, +) -> None: + """Reptile meta-TTT: find weights that adapt fast to val distribution. + Runs after EMA/SWA, before standard TTT. Makes TTT ~10x more effective.""" + t0 = time.perf_counter() + seq_len = args.train_seq_len + n_tokens = val_tokens.numel() + + # Only adapt MLP params of last 1/4 blocks + num_blocks = len(model.blocks) + suffix_start = num_blocks - num_blocks // 4 + ttt_params = {} + for name, p in model.named_parameters(): + if any(f'blocks.{i}.' in name and '.mlp.' in name for i in range(suffix_start, num_blocks)): + ttt_params[name] = p + + base_state = {name: p.data.clone() for name, p in ttt_params.items()} + reptile_steps = 0 + + while (time.perf_counter() - t0) < args.reptile_budget_s: + # Save current params + saved = {name: p.data.clone() for name, p in ttt_params.items()} + + # Inner loop: N SGD steps on a random chunk + model.train() + start = random.randint(0, max(n_tokens - seq_len - 1, 0)) + chunk = val_tokens[start:start + seq_len + 1].to(device=device, dtype=torch.int64) + x = chunk[:-1].unsqueeze(0) + y = chunk[1:].unsqueeze(0) + + for inner_step in range(args.reptile_inner_steps): + model.zero_grad(set_to_none=True) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + loss = model(x, y) + loss.backward() + with torch.no_grad(): + for name, param in ttt_params.items(): + if param.grad is not None: + param.data -= args.reptile_inner_lr * param.grad + + # Outer loop: move base toward adapted params + with torch.no_grad(): + for name, param in ttt_params.items(): + base_state[name] += args.reptile_outer_lr * (param.data - base_state[name]) + param.data.copy_(base_state[name]) + + reptile_steps += 1 + + if rank == 0: + print(f"reptile_ttt:done steps:{reptile_steps} elapsed:{time.perf_counter()-t0:.1f}s", flush=True) + + +# ----------------------------- +# TRAINING +# ----------------------------- + +def main() -> None: + global zeropower_via_newtonschulz5 + + code = Path(__file__).read_text(encoding="utf-8") + args = Hyperparameters() + zeropower_via_newtonschulz5 = torch.compile(zeropower_via_newtonschulz5) + + # ----------------------------- + # DISTRIBUTED + CUDA SETUP + # ----------------------------- + + distributed = "RANK" in os.environ and "WORLD_SIZE" in os.environ + rank = int(os.environ.get("RANK", "0")) + world_size = int(os.environ.get("WORLD_SIZE", "1")) + local_rank = int(os.environ.get("LOCAL_RANK", "0")) + if world_size <= 0: + raise ValueError(f"WORLD_SIZE must be positive, got {world_size}") + if 8 % world_size != 0: + raise ValueError(f"WORLD_SIZE={world_size} must divide 8 so grad_accum_steps stays integral") + grad_accum_steps = 8 // world_size + grad_scale = 1.0 / grad_accum_steps + if not torch.cuda.is_available(): + raise RuntimeError("CUDA is required") + device = torch.device("cuda", local_rank) + torch.cuda.set_device(device) + if distributed: + dist.init_process_group(backend="nccl", device_id=device) + dist.barrier() + master_process = rank == 0 + + # Fast math knobs + torch.backends.cuda.matmul.allow_tf32 = True + torch.backends.cudnn.allow_tf32 = True + from torch.backends.cuda import enable_cudnn_sdp, enable_flash_sdp, enable_math_sdp, enable_mem_efficient_sdp + + enable_cudnn_sdp(False) + enable_flash_sdp(True) + enable_mem_efficient_sdp(False) + enable_math_sdp(False) + + # Set module-level flag so CausalSelfAttention.forward can use FA3. + global _use_fa3 + _use_fa3 = args.use_fa3 and _HAS_FA3 + + logfile = None + if master_process: + os.makedirs("logs", exist_ok=True) + logfile = f"logs/{args.run_id}.txt" + print(logfile) + + def log0(msg: str, console: bool = True) -> None: + if not master_process: + return + if console: + print(msg) + if logfile is not None: + with open(logfile, "a", encoding="utf-8") as f: + print(msg, file=f) + + log0(code, console=False) + log0("=" * 100, console=False) + log0(f"Running Python {sys.version}", console=False) + log0(f"Running PyTorch {torch.__version__}", console=False) + log0( + subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, check=False).stdout, + console=False, + ) + log0("=" * 100, console=False) + + # ----------------------------- + # TOKENIZER + VALIDATION METRIC SETUP + # ----------------------------- + + random.seed(args.seed) + np.random.seed(args.seed) + torch.manual_seed(args.seed) + torch.cuda.manual_seed_all(args.seed) + + if not args.tokenizer_path.endswith(".model"): + raise ValueError(f"Script only setup for SentencePiece .model file: {args.tokenizer_path}") + sp = spm.SentencePieceProcessor(model_file=args.tokenizer_path) + if int(sp.vocab_size()) != args.vocab_size: + raise ValueError( + f"VOCAB_SIZE={args.vocab_size} does not match tokenizer vocab_size={int(sp.vocab_size())}" + ) + dataset_dir = Path(args.data_path).resolve() + actual_train_files = len(list(dataset_dir.glob("fineweb_train_*.bin"))) + val_tokens = load_validation_tokens(args.val_files, args.eval_seq_len) + base_bytes_lut, has_leading_space_lut, is_boundary_token_lut = build_sentencepiece_luts( + sp, args.vocab_size, device + ) + log0(f"val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path={args.tokenizer_path}") + log0(f"train_loader:dataset:{dataset_dir.name} train_shards:{actual_train_files}") + log0(f"val_loader:shards pattern={args.val_files} tokens:{val_tokens.numel() - 1}") + + # ----------------------------- + # MODEL + OPTIMIZER SETUP + # ----------------------------- + + base_model = GPT( + vocab_size=args.vocab_size, + num_layers=args.num_layers, + model_dim=args.model_dim, + num_heads=args.num_heads, + num_kv_heads=args.num_kv_heads, + mlp_mult=args.mlp_mult, + tie_embeddings=args.tie_embeddings, + tied_embed_init_std=args.tied_embed_init_std, + logit_softcap=args.logit_softcap, + rope_base=args.rope_base, + qk_gain_init=args.qk_gain_init, + num_loops=args.num_loops, + lora_rank=args.lora_rank, + mlp_hidden=args.mlp_hidden, + smear_gate=args.smear_gate, + bigram_hash=args.bigram_hash, + bigram_hash_buckets=args.bigram_hash_buckets, + bigram_hash_dim=args.bigram_hash_dim, + ortho_init=args.ortho_init, + xsa_last_n=args.xsa_last_n, + ntk_base_seq_len=args.train_seq_len if args.eval_seq_len > args.train_seq_len else 0, + rope_dims=args.rope_dims, + ln_scale=args.ln_scale, + unet_skips=args.unet_skips, + ve_enabled=args.ve_enabled, + ve_dim=args.ve_dim, + ve_layers=args.ve_layers, + ).to(device).bfloat16() + if args._tier2: + log0(f"*** TIER2_MODE: proxy run max={args.max_wallclock_seconds:.0f}s iters={args.iterations} " + f"ema={args.ema_enabled} ttt={args.ttt_enabled} qat={args.qat} " + f"-- compare val_bpb@step2000 against baseline tier2 run ***") + xsa_layers = [i for i, b in enumerate(base_model.blocks) if b.attn.use_xsa] + log0(f"xsa_last_n:{args.xsa_last_n} active_layers:{xsa_layers}") + ntk_active = args.eval_seq_len > args.train_seq_len + log0(f"ntk_rope:{'enabled' if ntk_active else 'disabled'} train_seq_len:{args.train_seq_len} eval_seq_len:{args.eval_seq_len}") + head_dim = args.model_dim // args.num_heads + _rope_dims_active = args.rope_dims > 0 and args.rope_dims < head_dim + log0(f"partial_rope:{'enabled' if _rope_dims_active else 'disabled'} rope_dims:{args.rope_dims if _rope_dims_active else head_dim}/{head_dim} (ROPE_DIMS={args.rope_dims})") + log0(f"ln_scale:{'enabled' if args.ln_scale else 'disabled'} (scale RMSNorm output by 1/sqrt(layer_idx+1))") + log0(f"unet_skips:{'enabled' if args.unet_skips else 'disabled'} (U-Net skip connections, enc={base_model.num_encoder_layers} dec={base_model.num_decoder_layers})") + if args.ve_enabled: + log0(f"ve:enabled dim={args.ve_dim} layers={base_model.ve_layer_indices}") + else: + log0("ve:disabled") + + log0(f"smear_gate:{args.smear_gate} bigram_hash:{args.bigram_hash} swa:{args.swa} " + f"ortho_init:{args.ortho_init} late_k_fp16:{args.late_k_fp16} " + f"fa3:{_use_fa3}(available={_HAS_FA3}) muon_wd:{args.muon_wd} adam_wd:{args.adam_wd}") + + # FP16 tied embedding export: skip int8 quantization for tok_emb.weight at export time. + # Avoids compounding int8 errors through both input embedding and output projection. + if args.fp16_embed_export and args.tie_embeddings: + _FP16_EXPORT_NAMES.add("tok_emb.weight") + log0(f"fp16_embed_export:enabled (tok_emb.weight kept in fp16, ~{args.vocab_size * args.model_dim * 2 / 1024:.0f}KB)") + + # Late-K: keep K projections of last 2 layers in fp16 (not quantized). + # Saves per-query context accuracy where it matters most — near the output. + if args.late_k_fp16: + effective_depth = args.num_layers * args.num_loops + for layer_idx in range(effective_depth - 2, effective_depth): + block_idx = layer_idx % args.num_layers + key_name = f"blocks.{block_idx}.attn.c_k.weight" + _FP16_EXPORT_NAMES.add(key_name) + log0(f"late_k_fp16:enabled (last 2 effective layers' c_k.weight kept in fp16)") + + for module in base_model.modules(): + if isinstance(module, (CastedLinear, AttentionLoRA)): + module.float() + restore_low_dim_params_to_fp32(base_model) + log0(f"qat:{args.qat} (activates when lr_scale < 0.1; absmax int6 STE)") + compiled_model = torch.compile(base_model, dynamic=False, fullgraph=True) + model: nn.Module = DDP(compiled_model, device_ids=[local_rank], broadcast_buffers=False) if distributed else compiled_model + + # Optimizer split: + # - token embedding (Adam) uses EMBED_LR + # - untied lm_head (Adam) uses HEAD_LR + # - matrix params in transformer blocks use MATRIX_LR via Muon + # - vectors/scalars use SCALAR_LR via Adam + block_named_params = list(base_model.blocks.named_parameters()) + matrix_params = [ + p + for name, p in block_named_params + if p.ndim == 2 and not any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) + ] + # bigram_hash.proj is a dense 2D projection — Muon is appropriate + if base_model.bigram_hash is not None: + matrix_params.append(base_model.bigram_hash.proj.weight) + scalar_params = [ + p + for name, p in block_named_params + if p.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) + ] + if base_model.skip_weights is not None: + scalar_params.append(base_model.skip_weights) + # smear_gate.gate is now a single nn.Parameter(dim) — AdamW at scalar_lr + if base_model.smear_gate is not None: + scalar_params.append(base_model.smear_gate.gate) + # bigram_hash.scale is a learned scalar — AdamW at scalar_lr + if base_model.bigram_hash is not None: + scalar_params.append(base_model.bigram_hash.scale) + # VE: scales go to scalar, proj to matrix, embed to tok group + if base_model.ve_shared is not None: + scalar_params.append(base_model.ve_shared.scale) + if base_model.ve_shared.proj is not None: + matrix_params.append(base_model.ve_shared.proj.weight) + if base_model.ve_layer_scales is not None: + for vs in base_model.ve_layer_scales: + scalar_params.append(vs) + token_lr = args.tied_embed_lr if args.tie_embeddings else args.embed_lr + # bigram_hash.embed is an embedding table — train with AdamW alongside tok_emb + embed_params = [base_model.tok_emb.weight] + if base_model.bigram_hash is not None: + embed_params.append(base_model.bigram_hash.embed.weight) + if base_model.ve_shared is not None: + embed_params.append(base_model.ve_shared.embed.weight) + optimizer_tok = torch.optim.AdamW( + [{"params": embed_params, "lr": token_lr, "base_lr": token_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + weight_decay=args.adam_wd, + fused=True, + ) + optimizer_muon = Muon( + matrix_params, + lr=args.matrix_lr, + momentum=args.muon_momentum, + backend_steps=args.muon_backend_steps, + weight_decay=args.muon_wd, + ) + for group in optimizer_muon.param_groups: + group["base_lr"] = args.matrix_lr + optimizer_scalar = torch.optim.AdamW( + [{"params": scalar_params, "lr": args.scalar_lr, "base_lr": args.scalar_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + weight_decay=args.adam_wd, + fused=True, + ) + optimizers: list[torch.optim.Optimizer] = [optimizer_tok, optimizer_muon, optimizer_scalar] + if base_model.lora_adapters is not None: + lora_params = list(base_model.lora_adapters.parameters()) + optimizer_lora = torch.optim.Adam( + [{"params": lora_params, "lr": args.lora_lr, "base_lr": args.lora_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=True, + ) + optimizers.append(optimizer_lora) + if base_model.lm_head is not None: + optimizer_head = torch.optim.Adam( + [{"params": [base_model.lm_head.weight], "lr": args.head_lr, "base_lr": args.head_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=True, + ) + optimizers.insert(1, optimizer_head) + + n_params = sum(p.numel() for p in base_model.parameters()) + n_lora = sum(p.numel() for p in base_model.lora_adapters.parameters()) if base_model.lora_adapters is not None else 0 + effective_depth = args.num_layers * args.num_loops + log0(f"model_params:{n_params} (unique_layers:{args.num_layers} loops:{args.num_loops} effective_depth:{effective_depth} lora_rank:{args.lora_rank} lora_params:{n_lora})") + log0(f"world_size:{world_size} grad_accum_steps:{grad_accum_steps}") + log0("sdp_backends:cudnn=False flash=True mem_efficient=False math=False") + log0(f"attention_mode:gqa num_heads:{args.num_heads} num_kv_heads:{args.num_kv_heads}") + log0( + f"tie_embeddings:{args.tie_embeddings} embed_lr:{token_lr} " + f"head_lr:{args.head_lr if base_model.lm_head is not None else 0.0} " + f"matrix_lr:{args.matrix_lr} scalar_lr:{args.scalar_lr}" + ) + log0( + f"train_batch_tokens:{args.train_batch_tokens} train_seq_len:{args.train_seq_len} " + f"iterations:{args.iterations} warmup_steps:{args.warmup_steps} " + f"max_wallclock_seconds:{args.max_wallclock_seconds:.3f}" + ) + log0(f"seed:{args.seed}") + + # ----------------------------- + # DATA LOADER & MODEL WARMUP + # ----------------------------- + + # Curriculum learning: sort shards by average document length (easy first). + curriculum_files: list[Path] | None = None + if args.curriculum: + train_file_list = [Path(p) for p in sorted(glob.glob(args.train_files))] + curriculum_files = sort_shards_by_doc_length(train_file_list) + log0(f"curriculum:enabled shards_sorted_by_doc_length ({len(curriculum_files)} shards)") + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device, + sorted_files=curriculum_files) + + def zero_grad_all() -> None: + for opt in optimizers: + opt.zero_grad(set_to_none=True) + + max_wallclock_ms = 1000.0 * args.max_wallclock_seconds if args.max_wallclock_seconds > 0 else None + + def lr_mul(step: int, elapsed_ms: float) -> float: + if args.warmdown_iters <= 0: + return 1.0 + if max_wallclock_ms is None: + warmdown_start = max(args.iterations - args.warmdown_iters, 0) + return max((args.iterations - step) / max(args.warmdown_iters, 1), 0.0) if warmdown_start <= step < args.iterations else 1.0 + step_ms = elapsed_ms / max(step, 1) + warmdown_ms = args.warmdown_iters * step_ms + remaining_ms = max(max_wallclock_ms - elapsed_ms, 0.0) + return remaining_ms / max(warmdown_ms, 1e-9) if remaining_ms <= warmdown_ms else 1.0 + + # Warmup primes the compiled forward/backward/optimizer paths, then we restore the + # initial weights/optimizer state so measured training starts from the true init. + if args.warmup_steps > 0: + initial_model_state = {name: tensor.detach().cpu().clone() for name, tensor in base_model.state_dict().items()} + initial_optimizer_states = [copy.deepcopy(opt.state_dict()) for opt in optimizers] + model.train() + for warmup_step in range(args.warmup_steps): + zero_grad_all() + for micro_step in range(grad_accum_steps): + if distributed: + model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 + x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + warmup_loss = model(x, y) + (warmup_loss * grad_scale).backward() + for opt in optimizers: + opt.step() + zero_grad_all() + if args.warmup_steps <= 20 or (warmup_step + 1) % 10 == 0 or warmup_step + 1 == args.warmup_steps: + log0(f"warmup_step:{warmup_step + 1}/{args.warmup_steps}") + # NOTE: QAT graph priming removed — it caused torch.compile to use a slower + # compilation path for the non-QAT forward pass (step_avg jumped from 44ms to 58ms). + # The one-time recompile when QAT activates (~30-90s) is cheaper than the cumulative + # overhead of a slower non-QAT path across thousands of steps. + + base_model.load_state_dict(initial_model_state, strict=True) + for opt, state in zip(optimizers, initial_optimizer_states, strict=True): + opt.load_state_dict(state) + zero_grad_all() + if distributed: + model.require_backward_grad_sync = True + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device, + sorted_files=curriculum_files) + + # ----------------------------- + # MAIN TRAINING LOOP + # ----------------------------- + + training_time_ms = 0.0 + qat_active = False + + # EMA: exponential moving average — smoother than SWA, better quantization compression. + ema_state: dict[str, Tensor] | None = None + if args.ema_enabled: + ema_state = {k: v.detach().float().clone() for k, v in base_model.state_dict().items()} + log0(f"ema:initialized decay={args.ema_decay}") + + # SWA: fallback if EMA disabled + swa_state: dict[str, Tensor] | None = None + swa_count = 0 + stop_after_step: int | None = None + torch.cuda.synchronize() + t0 = time.perf_counter() + + step = 0 + while True: + last_step = step == args.iterations or (stop_after_step is not None and step >= stop_after_step) + + should_validate = last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0) + if should_validate: + torch.cuda.synchronize() + training_time_ms += 1000.0 * (time.perf_counter() - t0) + val_loss, val_bpb = eval_val( + args, + model, + rank, + world_size, + device, + grad_accum_steps, + val_tokens, + base_bytes_lut, + has_leading_space_lut, + is_boundary_token_lut, + ) + log0( + f"step:{step}/{args.iterations} val_loss:{val_loss:.4f} val_bpb:{val_bpb:.4f} " + f"train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms / max(step, 1):.2f}ms" + ) + torch.cuda.synchronize() + t0 = time.perf_counter() + + if last_step: + if stop_after_step is not None and step < args.iterations: + log0( + f"stopping_early: wallclock_cap train_time:{training_time_ms:.0f}ms " + f"step:{step}/{args.iterations}" + ) + break + + elapsed_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + scale = lr_mul(step, elapsed_ms) + zero_grad_all() + train_loss = torch.zeros((), device=device) + for micro_step in range(grad_accum_steps): + if distributed: + model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 + x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + loss = model(x, y) + train_loss += loss.detach() + (loss * grad_scale).backward() + train_loss /= grad_accum_steps + + frac = min(step / args.muon_momentum_warmup_steps, 1.0) if args.muon_momentum_warmup_steps > 0 else 1.0 + muon_momentum = (1 - frac) * args.muon_momentum_warmup_start + frac * args.muon_momentum + for group in optimizer_muon.param_groups: + group["momentum"] = muon_momentum + + for opt in optimizers: + for group in opt.param_groups: + group["lr"] = group["base_lr"] * scale + + if args.grad_clip_norm > 0: + torch.nn.utils.clip_grad_norm_(base_model.parameters(), args.grad_clip_norm) + for opt in optimizers: + opt.step() + zero_grad_all() + + step += 1 + + # EMA: update every step during training + if ema_state is not None: + d = args.ema_decay + with torch.no_grad(): + for k, v in base_model.state_dict().items(): + ema_state[k].mul_(d).add_(v.detach().float(), alpha=1.0 - d) + + # LR-scale-based QAT activation: activate when lr_scale < 0.1 (last ~10% of warmdown, + # ~300 steps). Zero overhead for 90%+ of training; absmax scale makes per-step cost minimal. + if args.qat and not qat_active: + if scale < 0.1: + qat_active = True + for module in base_model.modules(): + if isinstance(module, CastedLinear): + module._qat = True + elapsed_s = training_time_ms / 1000.0 + log0(f"qat_activated step:{step}/{args.iterations} lr_scale:{scale:.4f} elapsed:{elapsed_s:.1f}s remaining:{args.max_wallclock_seconds - elapsed_s:.1f}s") + + # SWA: accumulate weight averages during warmdown for smoother quantization. + # Accumulate in float32 to avoid bf16 precision loss over thousands of additions. + # Sample every 200 steps for sufficient checkpoint diversity. + if args.swa and step >= int(args.iterations * args.swa_start_frac) and step % 200 == 0: + if swa_state is None: + swa_state = {k: v.detach().float().clone() for k, v in base_model.state_dict().items()} + swa_count = 1 + log0(f"swa_started step:{step}") + else: + swa_count += 1 + for k, v in base_model.state_dict().items(): + swa_state[k] += v.detach().float() + + approx_training_time_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + should_log_train = ( + args.train_log_every > 0 + and (step <= 10 or step % args.train_log_every == 0 or stop_after_step is not None) + ) + if should_log_train: + log0( + f"step:{step}/{args.iterations} train_loss:{train_loss.item():.4f} " + f"train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms / step:.2f}ms" + ) + + # Needed to sync whether we've reached the wallclock cap. + reached_cap = max_wallclock_ms is not None and approx_training_time_ms >= max_wallclock_ms + if distributed and max_wallclock_ms is not None: + reached_cap_tensor = torch.tensor(int(reached_cap), device=device) + dist.all_reduce(reached_cap_tensor, op=dist.ReduceOp.MAX) + reached_cap = bool(reached_cap_tensor.item()) + if stop_after_step is None and reached_cap: + stop_after_step = step + + log0( + f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB " + f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB" + ) + + # ----------------------------- + # SERIALIZATION + ROUNDTRIP VALIDATION + # ----------------------------- + # Save the raw state (useful for debugging/loading in PyTorch directly), then always produce + # the compressed int8+zlib artifact and validate the round-tripped weights. + + # Apply EMA weights (preferred) or SWA fallback before serialization. + if ema_state is not None: + log0("ema:applying EMA weights") + current_state = base_model.state_dict() + base_model.load_state_dict( + {k: v.to(dtype=current_state[k].dtype) for k, v in ema_state.items()}, strict=True + ) + del ema_state + elif args.swa and swa_state is not None and swa_count > 1: + log0(f"swa_applied count:{swa_count}") + current_state = base_model.state_dict() + avg_state = {} + for k, v in swa_state.items(): + avg = v / swa_count + avg_state[k] = avg.to(dtype=current_state[k].dtype) + base_model.load_state_dict(avg_state, strict=True) + + # Reptile meta-TTT: makes subsequent TTT ~10x more effective by finding weights that adapt fast. + if args.reptile_enabled: + log0(f"reptile_ttt:start budget={args.reptile_budget_s:.0f}s inner_steps={args.reptile_inner_steps} inner_lr={args.reptile_inner_lr} outer_lr={args.reptile_outer_lr}") + reptile_ttt(args, base_model, device, val_tokens, rank=rank) + + # TTT: adapt to val distribution before eval + if args.ttt_enabled: + if args.ttt_two_phase: + log0(f"ttt:start two_phase p1_epochs={args.ttt_p1_epochs} p1_lr={args.ttt_p1_lr} " + f"p2_epochs={args.ttt_p2_epochs} p2_lr={args.ttt_p2_lr} p2_blocks={args.ttt_p2_unfreeze_blocks} " + f"batch_seqs={args.ttt_batch_seqs}") + else: + log0(f"ttt:start lr={args.ttt_lr} epochs={args.ttt_epochs} freeze_blocks={args.ttt_freeze_blocks} " + f"batch_seqs={args.ttt_batch_seqs}") + ttt_adapt(args, base_model, device, val_tokens, rank=rank, world_size=world_size) + + # Magnitude pruning: zero the smallest weights for better zstd compression. + # Zeroed weights compress to nearly nothing. Applied after TTT, before serialization. + if args.prune_pct > 0.0: + pruned_count = 0 + total_count = 0 + with torch.no_grad(): + for name, p in base_model.named_parameters(): + if p.ndim >= 2 and p.numel() >= 65536: + threshold = torch.quantile(p.abs().float().flatten(), args.prune_pct / 100.0) + mask = p.abs() > threshold + pruned_count += (~mask).sum().item() + total_count += p.numel() + p.mul_(mask) + log0(f"prune:{args.prune_pct:.1f}% zeroed {pruned_count}/{total_count} weights ({100*pruned_count/max(total_count,1):.1f}%)") + + if master_process: + torch.save(base_model.state_dict(), "final_model.pt") + model_bytes = os.path.getsize("final_model.pt") + code_bytes = len(code.encode("utf-8")) + log0(f"Serialized model: {model_bytes} bytes") + log0(f"Code size: {code_bytes} bytes") + log0(f"Total submission size: {model_bytes + code_bytes} bytes") + + global _gptq_lite + _gptq_lite = args.gptq_lite + log0(f"quantization: {args.quant_bits}-bit gptq_lite:{_gptq_lite}") + quant_obj, quant_stats = quantize_state_dict_int8(base_model.state_dict(), bits=args.quant_bits) + quant_buf = io.BytesIO() + torch.save(quant_obj, quant_buf) + quant_raw = quant_buf.getvalue() + if args.use_zstd and _HAS_ZSTD: + cctx = zstd.ZstdCompressor(level=22) + quant_blob = cctx.compress(quant_raw) + compress_method = "zstd-22" + else: + quant_blob = zlib.compress(quant_raw, level=9) + compress_method = "zlib-9" + quant_raw_bytes = len(quant_raw) + if master_process: + with open("final_model.int8.ptz", "wb") as f: + f.write(quant_blob) + quant_file_bytes = os.path.getsize("final_model.int8.ptz") + code_bytes = len(code.encode("utf-8")) + ratio = quant_stats["baseline_tensor_bytes"] / max(quant_stats["int8_payload_bytes"], 1) + log0( + f"Serialized model int8+{compress_method}: {quant_file_bytes} bytes " + f"(payload:{quant_stats['int8_payload_bytes']} raw_torch:{quant_raw_bytes} payload_ratio:{ratio:.2f}x)" + ) + log0(f"Total submission size int8+{compress_method}: {quant_file_bytes + code_bytes} bytes") + + if distributed: + dist.barrier() + with open("final_model.int8.ptz", "rb") as f: + quant_blob_disk = f.read() + # Decompress with the same method used for compression. + if args.use_zstd and _HAS_ZSTD: + dctx = zstd.ZstdDecompressor() + quant_raw_disk = dctx.decompress(quant_blob_disk) + else: + quant_raw_disk = zlib.decompress(quant_blob_disk) + quant_state = torch.load(io.BytesIO(quant_raw_disk), map_location="cpu") + base_model.load_state_dict(dequantize_state_dict_int8(quant_state), strict=True) + torch.cuda.synchronize() + t_qeval = time.perf_counter() + if args.eval_stride > 0 and args.eval_stride < args.eval_seq_len: + log0(f"final_eval_mode:sliding_window stride:{args.eval_stride} batch_seqs:{args.eval_batch_seqs} doc_isolated:{args.doc_isolated_eval}") + q_val_loss, q_val_bpb = eval_val_sliding( + args, + base_model, + rank, + world_size, + device, + val_tokens, + base_bytes_lut, + has_leading_space_lut, + is_boundary_token_lut, + stride=args.eval_stride, + batch_seqs=args.eval_batch_seqs, + ) + else: + log0("final_eval_mode:standard") + q_val_loss, q_val_bpb = eval_val( + args, + model, + rank, + world_size, + device, + grad_accum_steps, + val_tokens, + base_bytes_lut, + has_leading_space_lut, + is_boundary_token_lut, + ) + torch.cuda.synchronize() + log0( + f"final_int8_zlib_roundtrip val_loss:{q_val_loss:.4f} val_bpb:{q_val_bpb:.4f} " + f"eval_time:{1000.0 * (time.perf_counter() - t_qeval):.0f}ms" + ) + log0(f"final_int8_zlib_roundtrip_exact val_loss:{q_val_loss:.8f} val_bpb:{q_val_bpb:.8f}") + + if distributed: + dist.destroy_process_group() + + +if __name__ == "__main__": + main() + +==================================================================================================== +Running Python 3.12.3 (main, Nov 6 2025, 13:44:16) [GCC 13.3.0] +Running PyTorch 2.9.1+cu128 +Sun Mar 22 23:17:21 2026 ++-----------------------------------------------------------------------------------------+ +| NVIDIA-SMI 570.211.01 Driver Version: 570.211.01 CUDA Version: 12.8 | +|-----------------------------------------+------------------------+----------------------+ +| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC | +| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. | +| | | MIG M. | +|=========================================+========================+======================| +| 0 NVIDIA H100 80GB HBM3 On | 00000000:18:00.0 Off | 0 | +| N/A 24C P0 118W / 700W | 1519MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 1 NVIDIA H100 80GB HBM3 On | 00000000:2A:00.0 Off | 0 | +| N/A 24C P0 112W / 700W | 1519MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 2 NVIDIA H100 80GB HBM3 On | 00000000:3A:00.0 Off | 0 | +| N/A 26C P0 114W / 700W | 1519MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 3 NVIDIA H100 80GB HBM3 On | 00000000:5D:00.0 Off | 0 | +| N/A 23C P0 116W / 700W | 1519MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 4 NVIDIA H100 80GB HBM3 On | 00000000:84:00.0 Off | 0 | +| N/A 23C P0 114W / 700W | 1519MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 5 NVIDIA H100 80GB HBM3 On | 00000000:8B:00.0 Off | 0 | +| N/A 25C P0 116W / 700W | 1519MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 6 NVIDIA H100 80GB HBM3 On | 00000000:91:00.0 Off | 0 | +| N/A 22C P0 108W / 700W | 1519MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 7 NVIDIA H100 80GB HBM3 On | 00000000:E4:00.0 Off | 0 | +| N/A 22C P0 110W / 700W | 1519MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ + ++-----------------------------------------------------------------------------------------+ +| Processes: | +| GPU GI CI PID Type Process name GPU Memory | +| ID ID Usage | +|=========================================================================================| +| 0 N/A N/A 46695 C /usr/local/bin/python 1510MiB | +| 1 N/A N/A 46696 C /usr/local/bin/python 1510MiB | +| 2 N/A N/A 46697 C /usr/local/bin/python 1510MiB | +| 3 N/A N/A 46698 C /usr/local/bin/python 1510MiB | +| 4 N/A N/A 46699 C /usr/local/bin/python 1510MiB | +| 5 N/A N/A 46700 C /usr/local/bin/python 1510MiB | +| 6 N/A N/A 46701 C /usr/local/bin/python 1510MiB | +| 7 N/A N/A 46702 C /usr/local/bin/python 1510MiB | ++-----------------------------------------------------------------------------------------+ + +==================================================================================================== +val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path=./data/tokenizers/fineweb_1024_bpe.model +train_loader:dataset:fineweb10B_sp1024 train_shards:80 +val_loader:shards pattern=./data/datasets/fineweb10B_sp1024/fineweb_val_*.bin tokens:62021632 +xsa_last_n:0 active_layers:[] +ntk_rope:disabled train_seq_len:2048 eval_seq_len:2048 +partial_rope:enabled rope_dims:16/64 (ROPE_DIMS=16) +ln_scale:enabled (scale RMSNorm output by 1/sqrt(layer_idx+1)) +unet_skips:enabled (U-Net skip connections, enc=5 dec=6) +ve:disabled +smear_gate:True bigram_hash:True swa:False ortho_init:True late_k_fp16:False fa3:True(available=True) muon_wd:0.04 adam_wd:0.04 +qat:False (activates when lr_scale < 0.1; absmax int6 STE) +model_params:26829913 (unique_layers:11 loops:1 effective_depth:11 lora_rank:0 lora_params:0) +world_size:8 grad_accum_steps:1 +sdp_backends:cudnn=False flash=True mem_efficient=False math=False +attention_mode:gqa num_heads:8 num_kv_heads:4 +tie_embeddings:True embed_lr:0.035 head_lr:0.0 matrix_lr:0.025 scalar_lr:0.025 +train_batch_tokens:524288 train_seq_len:2048 iterations:20000 warmup_steps:20 max_wallclock_seconds:600.000 +seed:42 +warmup_step:1/20 +warmup_step:2/20 +warmup_step:3/20 +warmup_step:4/20 +warmup_step:5/20 +warmup_step:6/20 +warmup_step:7/20 +warmup_step:8/20 +warmup_step:9/20 +warmup_step:10/20 +warmup_step:11/20 +warmup_step:12/20 +warmup_step:13/20 +warmup_step:14/20 +warmup_step:15/20 +warmup_step:16/20 +warmup_step:17/20 +warmup_step:18/20 +warmup_step:19/20 +warmup_step:20/20 +ema:initialized decay=0.997 +step:0/20000 val_loss:6.9320 val_bpb:4.1055 train_time:0ms step_avg:0.02ms +step:1/20000 train_loss:6.9334 train_time:115ms step_avg:114.68ms +step:2/20000 train_loss:8.7739 train_time:202ms step_avg:101.23ms +step:3/20000 train_loss:7.8777 train_time:303ms step_avg:101.04ms +step:4/20000 train_loss:7.3190 train_time:378ms step_avg:94.53ms +step:5/20000 train_loss:6.9660 train_time:453ms step_avg:90.68ms +step:6/20000 train_loss:7.7223 train_time:529ms step_avg:88.25ms +step:7/20000 train_loss:6.7481 train_time:606ms step_avg:86.58ms +step:8/20000 train_loss:6.6493 train_time:700ms step_avg:87.44ms +step:9/20000 train_loss:6.4476 train_time:775ms step_avg:86.16ms +step:10/20000 train_loss:6.1911 train_time:852ms step_avg:85.16ms +step:200/20000 train_loss:2.8193 train_time:16817ms step_avg:84.08ms +step:400/20000 train_loss:2.2926 train_time:33784ms step_avg:84.46ms +step:600/20000 train_loss:2.4888 train_time:50822ms step_avg:84.70ms +step:800/20000 train_loss:2.2321 train_time:67843ms step_avg:84.80ms +step:1000/20000 train_loss:2.3249 train_time:84772ms step_avg:84.77ms +step:1000/20000 val_loss:2.2804 val_bpb:1.3506 train_time:84787ms step_avg:84.79ms +step:1200/20000 train_loss:2.3468 train_time:102195ms step_avg:85.16ms +step:1400/20000 train_loss:2.3824 train_time:119182ms step_avg:85.13ms +step:1600/20000 train_loss:2.0522 train_time:136100ms step_avg:85.06ms +step:1800/20000 train_loss:2.1528 train_time:153770ms step_avg:85.43ms +step:2000/20000 train_loss:2.1926 train_time:171094ms step_avg:85.55ms +step:2000/20000 val_loss:2.1756 val_bpb:1.2885 train_time:171109ms step_avg:85.55ms +step:2200/20000 train_loss:2.0115 train_time:188283ms step_avg:85.58ms +step:2400/20000 train_loss:2.1350 train_time:205403ms step_avg:85.58ms +step:2600/20000 train_loss:2.3684 train_time:223316ms step_avg:85.89ms +step:2800/20000 train_loss:2.1761 train_time:242615ms step_avg:86.65ms +step:3000/20000 train_loss:2.1601 train_time:260572ms step_avg:86.86ms +step:3000/20000 val_loss:2.1304 val_bpb:1.2617 train_time:260588ms step_avg:86.86ms +step:3200/20000 train_loss:2.1246 train_time:280254ms step_avg:87.58ms +step:3400/20000 train_loss:2.0994 train_time:299008ms step_avg:87.94ms +step:3600/20000 train_loss:2.0371 train_time:316259ms step_avg:87.85ms +step:3800/20000 train_loss:2.1477 train_time:333634ms step_avg:87.80ms +step:4000/20000 train_loss:2.1157 train_time:350325ms step_avg:87.58ms +step:4000/20000 val_loss:2.1068 val_bpb:1.2477 train_time:350341ms step_avg:87.59ms +step:4200/20000 train_loss:2.1034 train_time:374563ms step_avg:89.18ms +step:4400/20000 train_loss:2.0322 train_time:392829ms step_avg:89.28ms +step:4600/20000 train_loss:1.8875 train_time:409977ms step_avg:89.13ms +step:4800/20000 train_loss:2.1690 train_time:426555ms step_avg:88.87ms +step:5000/20000 train_loss:1.9182 train_time:444189ms step_avg:88.84ms +step:5000/20000 val_loss:2.0565 val_bpb:1.2180 train_time:444203ms step_avg:88.84ms +step:5200/20000 train_loss:2.0714 train_time:461720ms step_avg:88.79ms +step:5400/20000 train_loss:2.0776 train_time:482351ms step_avg:89.32ms +step:5600/20000 train_loss:2.0573 train_time:500127ms step_avg:89.31ms +step:5800/20000 train_loss:2.0011 train_time:519711ms step_avg:89.61ms +step:6000/20000 train_loss:2.0728 train_time:537932ms step_avg:89.66ms +step:6000/20000 val_loss:1.9982 val_bpb:1.1834 train_time:537947ms step_avg:89.66ms +step:6200/20000 train_loss:1.9339 train_time:555647ms step_avg:89.62ms +step:6400/20000 train_loss:2.0048 train_time:572457ms step_avg:89.45ms +step:6600/20000 train_loss:1.9458 train_time:589921ms step_avg:89.38ms +step:6700/20000 val_loss:1.9566 val_bpb:1.1588 train_time:599899ms step_avg:89.54ms +stopping_early: wallclock_cap train_time:599899ms step:6700/20000 +peak memory allocated: 13401 MiB reserved: 13550 MiB +ema:applying EMA weights +ttt:start lr=0.0005 epochs=30 freeze_blocks=0 batch_seqs=64 +Serialized model: 105658303 bytes +Code size: 104414 bytes +Total submission size: 105762717 bytes +quantization: 6-bit gptq_lite:True +Serialized model int8+zstd-22: 15362215 bytes (payload:27056482 raw_torch:27113039 payload_ratio:3.90x) +Total submission size int8+zstd-22: 15466629 bytes +final_eval_mode:sliding_window stride:64 batch_seqs:32 doc_isolated:False +final_int8_zlib_roundtrip val_loss:1.8524 val_bpb:1.0971 eval_time:186610ms +final_int8_zlib_roundtrip_exact val_loss:1.85244699 val_bpb:1.09712345 diff --git a/records/track_10min_16mb/2026-03-22_CosineTTT_PerLayer/train_seed7.log b/records/track_10min_16mb/2026-03-22_CosineTTT_PerLayer/train_seed7.log new file mode 100644 index 000000000..3b7c53791 --- /dev/null +++ b/records/track_10min_16mb/2026-03-22_CosineTTT_PerLayer/train_seed7.log @@ -0,0 +1,2428 @@ +""" +The `train_gpt.py` and `train_gpt_mlx.py` scripts are intended as good launching-off points for new participants, not SOTA configs. We'll accept PRs that tune, improve, or simplify these scripts without significantly increasing complexity, but competitive submissions should stay in the `/records` folder. + +Hard stop: To keep readable for newcomers, let's make sure `train_gpt.py` and `train_gpt_mlx.py` never are longer than 1500 lines. +""" + +from __future__ import annotations + +import copy +import glob +import io +import math +import os +import random +import subprocess +import sys +import time +import uuid +import zlib +try: + import zstandard as zstd + _HAS_ZSTD = True +except ImportError: + _HAS_ZSTD = False +from pathlib import Path + +import numpy as np +import sentencepiece as spm +import torch +import torch.distributed as dist +import torch.nn.functional as F +from torch import Tensor, nn +from torch.nn.parallel import DistributedDataParallel as DDP + +try: + from flash_attn_interface import flash_attn_func as _flash_attn_3_func + _HAS_FA3 = True +except ImportError: + try: + from flash_attn import flash_attn_func as _flash_attn_3_func + _HAS_FA3 = True # FA2 fallback — same input format, slightly slower than FA3 + except ImportError: + _HAS_FA3 = False +_use_fa3: bool = False # set at runtime after args are parsed + +# ----------------------------- +# HYPERPARAMETERS +# ----------------------------- +# Default Simple Baseline run: +# - 9 transformer blocks at width 512 +# - 8 attention heads with 4 KV heads (GQA) and 2x MLP expansion +# - vocab size 1024, sequence length 1024, tied embeddings +# - 524,288 train tokens per step for 20,000 iterations with a ~10 minute cap + +class Hyperparameters: + # Data paths are shard globs produced by the existing preprocessing pipeline. + data_path = os.environ.get("DATA_PATH", "./data/datasets/fineweb10B_sp1024") + train_files = os.path.join(data_path, "fineweb_train_*.bin") + val_files = os.path.join(data_path, "fineweb_val_*.bin") + tokenizer_path = os.environ.get("TOKENIZER_PATH", "./data/tokenizers/fineweb_1024_bpe.model") + run_id = os.environ.get("RUN_ID", str(uuid.uuid4())) + seed = int(os.environ.get("SEED", 1337)) + + # Validation cadence and batch size. Validation always uses the full fineweb_val split. + val_batch_size = int(os.environ.get("VAL_BATCH_SIZE", 524_288)) + val_loss_every = int(os.environ.get("VAL_LOSS_EVERY", 1000)) + train_log_every = int(os.environ.get("TRAIN_LOG_EVERY", 200)) + + # Training length. + iterations = int(os.environ.get("ITERATIONS", 20000)) + warmdown_iters = int(os.environ.get("WARMDOWN_ITERS", 3000)) + warmup_steps = int(os.environ.get("WARMUP_STEPS", 20)) + train_batch_tokens = int(os.environ.get("TRAIN_BATCH_TOKENS", 524_288)) + train_seq_len = int(os.environ.get("TRAIN_SEQ_LEN", 1024)) + eval_seq_len = int(os.environ.get("EVAL_SEQ_LEN", 2048)) # eval window; NTK-RoPE scales if > train_seq_len + max_wallclock_seconds = float(os.environ.get("MAX_WALLCLOCK_SECONDS", 600.0)) + qk_gain_init = float(os.environ.get("QK_GAIN_INIT", 1.5)) + + # ── Tier-2 proxy mode ──────────────────────────────────────────────────── + # TIER2_MODE=1 overrides key settings for a fast 3-minute proxy run. + # Use this to validate architectural features before committing a full run. + # Schedule-dependent features (EMA, TTT, SWA) are disabled — they can only + # be evaluated meaningfully at full training duration. + # Compare val_bpb at step ~2000 against a baseline TIER2_MODE=1 run. + _tier2 = bool(int(os.environ.get("TIER2_MODE", "0"))) + if _tier2: + max_wallclock_seconds = float(os.environ.get("MAX_WALLCLOCK_SECONDS", 180.0)) + iterations = int(os.environ.get("ITERATIONS", 3000)) + warmdown_iters = int(os.environ.get("WARMDOWN_ITERS", 500)) + val_loss_every = int(os.environ.get("VAL_LOSS_EVERY", 500)) + eval_stride = int(os.environ.get("EVAL_STRIDE", 256)) # faster final eval + + # Model shape. + vocab_size = int(os.environ.get("VOCAB_SIZE", 1024)) + num_layers = int(os.environ.get("NUM_LAYERS", 11)) + num_kv_heads = int(os.environ.get("NUM_KV_HEADS", 4)) + model_dim = int(os.environ.get("MODEL_DIM", 512)) + num_heads = int(os.environ.get("NUM_HEADS", 8)) + mlp_mult = int(os.environ.get("MLP_MULT", 3)) + mlp_hidden = int(os.environ.get("MLP_HIDDEN", 0)) # 0 = use mlp_mult * model_dim + fp16_embed_export = bool(int(os.environ.get("FP16_EMBED_EXPORT", "1"))) # keep tok_emb in fp16 at export + num_loops = int(os.environ.get("NUM_LOOPS", 1)) + lora_rank = int(os.environ.get("LORA_RANK", 0)) + qat = bool(int(os.environ.get("QAT", "1"))) + qat_min_seconds = float(os.environ.get("QAT_MIN_SECONDS", 120.0)) # guarantee QAT runs for at least this long + eval_stride = int(os.environ.get("EVAL_STRIDE", 64)) + eval_batch_seqs = int(os.environ.get("EVAL_BATCH_SEQS", 32)) + doc_isolated_eval = bool(int(os.environ.get("DOC_ISOLATED_EVAL", "1"))) # eval per-document, no cross-doc context + smear_gate = bool(int(os.environ.get("SMEAR_GATE", "1"))) # cheap bigram context at embedding layer + bigram_hash = bool(int(os.environ.get("BIGRAM_HASH", "1"))) # hash-based bigram embedding + bigram_hash_buckets = int(os.environ.get("BIGRAM_HASH_BUCKETS", 2048)) + bigram_hash_dim = int(os.environ.get("BIGRAM_HASH_DIM", 128)) + swa = bool(int(os.environ.get("SWA", "0"))) # stochastic weight averaging (disabled: EMA preferred) + swa_start_frac = float(os.environ.get("SWA_START_FRAC", 0.5)) + xsa_last_n = int(os.environ.get("XSA_LAST_N", 4)) # XSA on last N layers (0=disabled) + ema_enabled = bool(int(os.environ.get("EMA_ENABLED", "1"))) # exponential moving average (replaces SWA) + ema_decay = float(os.environ.get("EMA_DECAY", 0.997)) # EMA decay per step + ttt_enabled = bool(int(os.environ.get("TTT_ENABLED", "1"))) # test-time training on val data + ttt_lr = float(os.environ.get("TTT_LR", 0.0005)) # AdamW lr (PR #442: AdamW beats SGD by 0.019 BPB) + ttt_epochs = int(os.environ.get("TTT_EPOCHS", 30)) # 30ep cosine beats 10ep flat by 16% + ttt_momentum = float(os.environ.get("TTT_MOMENTUM", 0.9)) # only used if TTT_OPTIMIZER=sgd + ttt_freeze_blocks = int(os.environ.get("TTT_FREEZE_BLOCKS", 0)) # 0 = all blocks unfrozen (PR #398) + ttt_optimizer = os.environ.get("TTT_OPTIMIZER", "adamw") # "adamw" or "sgd" + ttt_max_steps = int(os.environ.get("TTT_MAX_STEPS", 300)) # cap steps per epoch (~10s per epoch) + ttt_batch_seqs = int(os.environ.get("TTT_BATCH_SEQS", 64)) # seqs per GPU per step (64*8=512 total) + ttt_cosine = bool(int(os.environ.get("TTT_COSINE", "1"))) # cosine lr decay during TTT (+16% over flat) + ttt_warmup_frac = float(os.environ.get("TTT_WARMUP_FRAC", 0.0)) # linear warmup fraction (0.1 = 10%) + ttt_perlayer = bool(int(os.environ.get("TTT_PERLAYER", "1"))) # per-layer lr (3x proj, 0.5x fc) — +23.5% combined with cosine + # Two-phase TTT (matches PR #415/#417 approach) + ttt_two_phase = bool(int(os.environ.get("TTT_TWO_PHASE", "0"))) # enable two-phase TTT + ttt_p1_epochs = int(os.environ.get("TTT_P1_EPOCHS", 50)) # phase 1: norm-only recalibration + ttt_p1_lr = float(os.environ.get("TTT_P1_LR", 0.01)) # phase 1: Adam lr + ttt_p2_epochs = int(os.environ.get("TTT_P2_EPOCHS", 10)) # phase 2: selective block adaptation + ttt_p2_lr = float(os.environ.get("TTT_P2_LR", 0.005)) # phase 2: SGD lr + ttt_p2_unfreeze_blocks = int(os.environ.get("TTT_P2_UNFREEZE_BLOCKS", 3)) # phase 2: unfreeze last N blocks + use_zstd = bool(int(os.environ.get("USE_ZSTD", "1"))) # use zstd instead of zlib for compression + curriculum = bool(int(os.environ.get("CURRICULUM", "0"))) # sort training shards by doc length (easy first) + quant_bits = int(os.environ.get("QUANT_BITS", 6)) # 8=int8, 6=int6 (int6 fits ~3x more params in 16MB) + tie_embeddings = bool(int(os.environ.get("TIE_EMBEDDINGS", "1"))) + rope_base = float(os.environ.get("ROPE_BASE", 50000.0)) + logit_softcap = float(os.environ.get("LOGIT_SOFTCAP", 30.0)) + + # Optimizer hyperparameters. + embed_lr = float(os.environ.get("EMBED_LR", 0.6)) + head_lr = float(os.environ.get("HEAD_LR", 0.008)) + tied_embed_lr = float(os.environ.get("TIED_EMBED_LR", 0.035)) + tied_embed_init_std = float(os.environ.get("TIED_EMBED_INIT_STD", 0.005)) + matrix_lr = float(os.environ.get("MATRIX_LR", 0.025)) + scalar_lr = float(os.environ.get("SCALAR_LR", 0.025)) + muon_momentum = float(os.environ.get("MUON_MOMENTUM", 0.99)) + muon_backend_steps = int(os.environ.get("MUON_BACKEND_STEPS", 5)) + muon_momentum_warmup_start = float(os.environ.get("MUON_MOMENTUM_WARMUP_START", 0.92)) + muon_momentum_warmup_steps = int(os.environ.get("MUON_MOMENTUM_WARMUP_STEPS", 1500)) + beta1 = float(os.environ.get("BETA1", 0.9)) + beta2 = float(os.environ.get("BETA2", 0.95)) + adam_eps = float(os.environ.get("ADAM_EPS", 1e-8)) + lora_lr = float(os.environ.get("LORA_LR", 0.01)) + grad_clip_norm = float(os.environ.get("GRAD_CLIP_NORM", 0.3)) + muon_wd = float(os.environ.get("MUON_WD", 0.04)) + adam_wd = float(os.environ.get("ADAM_WD", 0.04)) + ortho_init = bool(int(os.environ.get("ORTHO_INIT", "1"))) + late_k_fp16 = bool(int(os.environ.get("LATE_K_FP16", "1"))) + use_fa3 = bool(int(os.environ.get("USE_FA3", "1"))) + rope_dims = int(os.environ.get("ROPE_DIMS", 0)) # 0 = full RoPE; >0 = apply RoPE to only first N dims per head + ln_scale = bool(int(os.environ.get("LN_SCALE", "0"))) # scale block output by 1/sqrt(layer_idx+1) + unet_skips = bool(int(os.environ.get("UNET_SKIPS", "0"))) # U-Net style skip connections + prune_pct = float(os.environ.get("PRUNE_PCT", 0.0)) # magnitude pruning: zero smallest N% of weights before compression + gptq_lite = bool(int(os.environ.get("GPTQ_LITE", "1"))) # per-row optimal clip search at quantization (default ON, zero training cost) + reptile_enabled = bool(int(os.environ.get("REPTILE_TTT", "0"))) # Reptile meta-TTT before standard TTT + reptile_budget_s = float(os.environ.get("REPTILE_BUDGET_S", 60.0)) # seconds for Reptile meta-learning + reptile_inner_steps = int(os.environ.get("REPTILE_INNER_STEPS", 3)) # SGD steps per inner loop + reptile_inner_lr = float(os.environ.get("REPTILE_INNER_LR", 0.1)) # inner SGD learning rate + reptile_outer_lr = float(os.environ.get("REPTILE_OUTER_LR", 0.01)) # outer interpolation rate + ve_enabled = bool(int(os.environ.get("VE_ENABLED", "0"))) # shared value embedding + ve_dim = int(os.environ.get("VE_DIM", 128)) # value embedding dimension + ve_layers = os.environ.get("VE_LAYERS", "9,10") # comma-separated layer indices + + # Disable schedule-dependent features in TIER2_MODE unless explicitly overridden + if _tier2: + qat = bool(int(os.environ.get("QAT", "0"))) + ema_enabled = bool(int(os.environ.get("EMA_ENABLED", "0"))) + ttt_enabled = bool(int(os.environ.get("TTT_ENABLED", "0"))) + swa = bool(int(os.environ.get("SWA", "0"))) + +# ----------------------------- +# MUON OPTIMIZER +# ----------------------------- +# +# As borrowed from modded-nanogpt +# Background on Muon: https://kellerjordan.github.io/posts/muon/ + +def zeropower_via_newtonschulz5(G: Tensor, steps: int = 10, eps: float = 1e-7) -> Tensor: + # Orthogonalize a 2D update matrix with a fast Newton-Schulz iteration. + # Muon uses this to normalize matrix-shaped gradients before applying them. + a, b, c = (3.4445, -4.7750, 2.0315) + X = G.bfloat16() + X /= X.norm() + eps + transposed = G.size(0) > G.size(1) + if transposed: + X = X.T + for _ in range(steps): + A = X @ X.T + B = b * A + c * A @ A + X = a * X + B @ X + return X.T if transposed else X + + +class Muon(torch.optim.Optimizer): + def __init__(self, params, lr: float, momentum: float, backend_steps: int, + nesterov: bool = True, weight_decay: float = 0.0): + super().__init__( + params, + dict(lr=lr, momentum=momentum, backend_steps=backend_steps, + nesterov=nesterov, weight_decay=weight_decay), + ) + + @torch.no_grad() + def step(self, closure=None): + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + distributed = dist.is_available() and dist.is_initialized() + world_size = dist.get_world_size() if distributed else 1 + rank = dist.get_rank() if distributed else 0 + + for group in self.param_groups: + params = group["params"] + if not params: + continue + lr = group["lr"] + momentum = group["momentum"] + backend_steps = group["backend_steps"] + nesterov = group["nesterov"] + wd = group.get("weight_decay", 0.0) + + # Cache flat buffer to avoid per-step allocation + if "_updates_flat" not in group: + group["_total_params"] = sum(int(p.numel()) for p in params) + group["_updates_flat"] = torch.zeros(group["_total_params"], device=params[0].device, dtype=torch.bfloat16) + total_params = group["_total_params"] + updates_flat = group["_updates_flat"] + updates_flat.zero_() + + curr = 0 + for i, p in enumerate(params): + if i % world_size == rank and p.grad is not None: + g = p.grad + state = self.state[p] + if "momentum_buffer" not in state: + state["momentum_buffer"] = torch.zeros_like(g) + buf = state["momentum_buffer"] + buf.mul_(momentum).add_(g) + if nesterov: + g = g.add(buf, alpha=momentum) + g = zeropower_via_newtonschulz5(g, steps=backend_steps) + # Scale correction from Muon reference implementations. + g *= max(1, g.size(0) / g.size(1)) ** 0.5 + updates_flat[curr : curr + p.numel()] = g.reshape(-1) + curr += p.numel() + + if distributed: + dist.all_reduce(updates_flat, op=dist.ReduceOp.SUM) + + curr = 0 + for p in params: + g = updates_flat[curr : curr + p.numel()].view_as(p).to(dtype=p.dtype) + p.add_(g, alpha=-lr) + if wd > 0.0: + p.mul_(1.0 - wd * lr) + curr += p.numel() + + return loss + + +# ----------------------------- +# TOKENIZER-AGNOSTIC EVALUATION SETUP +# ----------------------------- +# +# It's common for small models have a large fraction of their parameters be embeddings, since the 2 * d_model * d_vocab vectors can be gigantic. +# Instead of locking the tokenizer, we let you bring your own and calculate our validation metrics on the average compression of the validation set. +# We calculate BPB (bits-per-byte) instead of validation loss, so we need methods to count the number of bits per token in the tokenizer. +# Note: Submissions that edit the tokenizer will be examined more carefully, since screwing this up might unjustly improve your score. + +def build_sentencepiece_luts( + sp: spm.SentencePieceProcessor, vocab_size: int, device: torch.device +) -> tuple[Tensor, Tensor, Tensor]: + sp_vocab_size = int(sp.vocab_size()) + table_size = max(sp_vocab_size, vocab_size) + base_bytes_np = np.zeros((table_size,), dtype=np.int16) + has_leading_space_np = np.zeros((table_size,), dtype=np.bool_) + is_boundary_token_np = np.ones((table_size,), dtype=np.bool_) + for token_id in range(sp_vocab_size): + if sp.is_control(token_id) or sp.is_unknown(token_id) or sp.is_unused(token_id): + continue + is_boundary_token_np[token_id] = False + if sp.is_byte(token_id): + base_bytes_np[token_id] = 1 + continue + piece = sp.id_to_piece(token_id) + if piece.startswith("▁"): + has_leading_space_np[token_id] = True + piece = piece[1:] + base_bytes_np[token_id] = len(piece.encode("utf-8")) + return ( + torch.tensor(base_bytes_np, dtype=torch.int16, device=device), + torch.tensor(has_leading_space_np, dtype=torch.bool, device=device), + torch.tensor(is_boundary_token_np, dtype=torch.bool, device=device), + ) + + +def load_validation_tokens(pattern: str, seq_len: int) -> Tensor: + files = [Path(p) for p in sorted(glob.glob(pattern))] + if not files: + raise FileNotFoundError(f"No files found for pattern: {pattern}") + # The export pipeline writes the fixed first-50k-doc validation set to fineweb_val_*. + tokens = torch.cat([load_data_shard(file) for file in files]).contiguous() + usable = ((tokens.numel() - 1) // seq_len) * seq_len + if usable <= 0: + raise ValueError(f"Validation split is too short for TRAIN_SEQ_LEN={seq_len}") + return tokens[: usable + 1] + + +def eval_val( + args: Hyperparameters, + model: nn.Module, + rank: int, + world_size: int, + device: torch.device, + grad_accum_steps: int, + val_tokens: Tensor, + base_bytes_lut: Tensor, + has_leading_space_lut: Tensor, + is_boundary_token_lut: Tensor, +) -> tuple[float, float]: + # Validation computes two metrics: + # - val_loss: token cross-entropy (natural log) + # - val_bpb: tokenizer-agnostic compression metric used by the challenge + local_batch_tokens = args.val_batch_size // (world_size * grad_accum_steps) + if local_batch_tokens < args.train_seq_len: + raise ValueError( + "VAL_BATCH_SIZE must provide at least one sequence per rank; " + f"got VAL_BATCH_SIZE={args.val_batch_size}, WORLD_SIZE={world_size}, " + f"GRAD_ACCUM_STEPS={grad_accum_steps}, TRAIN_SEQ_LEN={args.train_seq_len}" + ) + local_batch_seqs = local_batch_tokens // args.train_seq_len + total_seqs = (val_tokens.numel() - 1) // args.train_seq_len + seq_start = (total_seqs * rank) // world_size + seq_end = (total_seqs * (rank + 1)) // world_size + val_loss_sum = torch.zeros((), device=device, dtype=torch.float64) + val_token_count = torch.zeros((), device=device, dtype=torch.float64) + val_byte_count = torch.zeros((), device=device, dtype=torch.float64) + + model.eval() + with torch.inference_mode(): + for batch_seq_start in range(seq_start, seq_end, local_batch_seqs): + batch_seq_end = min(batch_seq_start + local_batch_seqs, seq_end) + raw_start = batch_seq_start * args.train_seq_len + raw_end = batch_seq_end * args.train_seq_len + 1 + local = val_tokens[raw_start:raw_end].to(device=device, dtype=torch.int64, non_blocking=True) + x = local[:-1].reshape(-1, args.train_seq_len) + y = local[1:].reshape(-1, args.train_seq_len) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + batch_loss = model(x, y).detach() + batch_token_count = float(y.numel()) + val_loss_sum += batch_loss.to(torch.float64) * batch_token_count + val_token_count += batch_token_count + prev_ids = x.reshape(-1) + tgt_ids = y.reshape(-1) + token_bytes = base_bytes_lut[tgt_ids].to(dtype=torch.int16) + token_bytes += (has_leading_space_lut[tgt_ids] & ~is_boundary_token_lut[prev_ids]).to(dtype=torch.int16) + val_byte_count += token_bytes.to(torch.float64).sum() + + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(val_loss_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(val_token_count, op=dist.ReduceOp.SUM) + dist.all_reduce(val_byte_count, op=dist.ReduceOp.SUM) + + val_loss = val_loss_sum / val_token_count + bits_per_token = val_loss.item() / math.log(2.0) + tokens_per_byte = val_token_count.item() / val_byte_count.item() + model.train() + return float(val_loss.item()), float(bits_per_token * tokens_per_byte) + +# ----------------------------- +# POST-TRAINING QUANTIZATION +# ----------------------------- +# +# It's silly to export our model, which is trained in bf16 and fp32, at that same precision. +# Instead, we get approximately the same model (with a small hit) by quantizing the model to int8 & zlib compressing. +# We can then decompress the model and run in higher precision for evaluation, after closing in under the size limit. + +CONTROL_TENSOR_NAME_PATTERNS = tuple( + pattern + for pattern in os.environ.get( + "CONTROL_TENSOR_NAME_PATTERNS", + "attn_scale,attn_scales,mlp_scale,mlp_scales,resid_mix,resid_mixes,q_gain,skip_weight,skip_weights", + ).split(",") + if pattern +) +INT8_KEEP_FLOAT_FP32_NAME_PATTERNS = tuple( + pattern + for pattern in os.environ.get( + "INT8_KEEP_FLOAT_FP32_NAME_PATTERNS", + ",".join(CONTROL_TENSOR_NAME_PATTERNS), + ).split(",") + if pattern +) +INT8_KEEP_FLOAT_MAX_NUMEL = 65_536 +INT8_KEEP_FLOAT_STORE_DTYPE = torch.float16 +INT8_PER_ROW_SCALE_DTYPE = torch.float16 +INT8_CLIP_PERCENTILE = 99.99984 +INT8_CLIP_Q = INT8_CLIP_PERCENTILE / 100.0 + +def tensor_nbytes(t: Tensor) -> int: + return int(t.numel()) * int(t.element_size()) + +def keep_float_tensor(name: str, t: Tensor, passthrough_orig_dtypes: dict[str, str]) -> Tensor: + if any(pattern in name for pattern in INT8_KEEP_FLOAT_FP32_NAME_PATTERNS): + return t.float().contiguous() + if t.dtype in {torch.float32, torch.bfloat16}: + passthrough_orig_dtypes[name] = str(t.dtype).removeprefix("torch.") + return t.to(dtype=INT8_KEEP_FLOAT_STORE_DTYPE).contiguous() + return t + +def _quantize_2d_with_clip(t32: Tensor, clip_abs: Tensor, max_val: int) -> tuple[Tensor, Tensor, Tensor]: + """Quantize a 2D tensor with given per-row clip values. Returns (q, scale, reconstruction_error).""" + clipped = torch.maximum(torch.minimum(t32, clip_abs[:, None]), -clip_abs[:, None]) + scale = (clip_abs / float(max_val)).clamp_min(1.0 / float(max_val)) + q = torch.clamp(torch.round(clipped / scale[:, None]), -max_val, max_val).to(torch.int8) + # Reconstruction error: MSE between original and dequantized + deq = q.float() * scale[:, None] + err = (t32 - deq).square().mean(dim=1) + return q, scale, err + + +# GPTQ-lite clip ratios: search these percentiles per weight matrix to find optimal clipping. +_GPTQ_CLIP_RATIOS = [0.9, 0.95, 0.99, 0.999, 0.99999] +_gptq_lite: bool = False # set at runtime from GPTQ_LITE env var + + +def quantize_float_tensor(t: Tensor, bits: int = 8) -> tuple[Tensor, Tensor]: + max_val = 127 if bits == 8 else (2 ** (bits - 1)) - 1 # int6: 31, int8: 127 + t32 = t.float() + if t32.ndim == 2: + if _gptq_lite and t32.numel() > 0: + # GPTQ-lite: try multiple clip percentiles per row, pick lowest reconstruction error. + abs_vals = t32.abs() + best_q: Tensor | None = None + best_scale: Tensor | None = None + best_err: Tensor | None = None + for ratio in _GPTQ_CLIP_RATIOS: + clip_abs = torch.quantile(abs_vals, ratio, dim=1) + q, scale, err = _quantize_2d_with_clip(t32, clip_abs, max_val) + if best_err is None: + best_q, best_scale, best_err = q, scale, err + else: + # Per-row: keep whichever clip ratio gave lower error for each row + better = err < best_err + best_q[better] = q[better] + best_scale[better] = scale[better] + best_err[better] = err[better] + return best_q.contiguous(), best_scale.to(dtype=INT8_PER_ROW_SCALE_DTYPE).contiguous() + + # Standard: single clip percentile + clip_abs = ( + torch.quantile(t32.abs(), INT8_CLIP_Q, dim=1) + if t32.numel() + else torch.empty((t32.shape[0],), dtype=torch.float32) + ) + clipped = torch.maximum(torch.minimum(t32, clip_abs[:, None]), -clip_abs[:, None]) + scale = (clip_abs / float(max_val)).clamp_min(1.0 / float(max_val)) + q = torch.clamp(torch.round(clipped / scale[:, None]), -max_val, max_val).to(torch.int8).contiguous() + return q, scale.to(dtype=INT8_PER_ROW_SCALE_DTYPE).contiguous() + + # Vectors / scalars use a simpler per-tensor scale. + clip_abs = float(torch.quantile(t32.abs().flatten(), INT8_CLIP_Q).item()) if t32.numel() else 0.0 + scale = torch.tensor(clip_abs / float(max_val) if clip_abs > 0 else 1.0, dtype=torch.float32) + q = torch.clamp(torch.round(torch.clamp(t32, -clip_abs, clip_abs) / scale), -max_val, max_val).to(torch.int8).contiguous() + return q, scale + +# Names of tensors to keep in fp16 at export instead of quantizing to int8. +# Populated at runtime when fp16_embed_export=True. +_FP16_EXPORT_NAMES: set[str] = set() + +def quantize_state_dict_int8(state_dict: dict[str, Tensor], bits: int = 8): + # Clean-script export format supporting int8 or int6: + # - per-row quantization for 2D float tensors (int8: [-127,127], int6: [-31,31]) + # - per-tensor quantization for other float tensors + # - exact passthrough for non-floats + # - passthrough for small float tensors, stored as fp16 to save bytes + # - fp16 passthrough for tensors in _FP16_EXPORT_NAMES (e.g. tied embeddings) + quantized: dict[str, Tensor] = {} + scales: dict[str, Tensor] = {} + dtypes: dict[str, str] = {} + passthrough: dict[str, Tensor] = {} + passthrough_orig_dtypes: dict[str, str] = {} + qmeta: dict[str, dict[str, object]] = {} + stats = dict.fromkeys( + ("param_count", "num_tensors", "num_float_tensors", "num_nonfloat_tensors", "baseline_tensor_bytes", "int8_payload_bytes"), + 0, + ) + + for name, tensor in state_dict.items(): + t = tensor.detach().to("cpu").contiguous() + stats["param_count"] += int(t.numel()) + stats["num_tensors"] += 1 + stats["baseline_tensor_bytes"] += tensor_nbytes(t) + + if not t.is_floating_point(): + stats["num_nonfloat_tensors"] += 1 + passthrough[name] = t + stats["int8_payload_bytes"] += tensor_nbytes(t) + continue + + # FP16 export bypass: keep specified tensors (e.g. tied embeddings) in fp16 + # instead of quantizing to int8. Avoids compounding int8 errors through both + # the input embedding and output projection paths. + if name in _FP16_EXPORT_NAMES: + kept = t.to(dtype=torch.float16).contiguous() + passthrough_orig_dtypes[name] = str(t.dtype).removeprefix("torch.") + passthrough[name] = kept + stats["int8_payload_bytes"] += tensor_nbytes(kept) + continue + + # Small float tensors are cheap enough to keep directly. We still downcast + # fp32/bf16 passthrough tensors to fp16 so metadata does not dominate size. + if t.numel() <= INT8_KEEP_FLOAT_MAX_NUMEL: + kept = keep_float_tensor(name, t, passthrough_orig_dtypes) + passthrough[name] = kept + stats["int8_payload_bytes"] += tensor_nbytes(kept) + continue + + stats["num_float_tensors"] += 1 + q, s = quantize_float_tensor(t, bits=bits) + if s.ndim > 0: + qmeta[name] = {"scheme": "per_row", "axis": 0} + quantized[name] = q + scales[name] = s + dtypes[name] = str(t.dtype).removeprefix("torch.") + stats["int8_payload_bytes"] += tensor_nbytes(q) + tensor_nbytes(s) + + obj: dict[str, object] = { + "__quant_format__": "int8_clean_per_row_v1", + "quantized": quantized, + "scales": scales, + "dtypes": dtypes, + "passthrough": passthrough, + } + if qmeta: + obj["qmeta"] = qmeta + if passthrough_orig_dtypes: + obj["passthrough_orig_dtypes"] = passthrough_orig_dtypes + return obj, stats + +def dequantize_state_dict_int8(obj: dict[str, object]) -> dict[str, Tensor]: + out: dict[str, Tensor] = {} + qmeta = obj.get("qmeta", {}) + passthrough_orig_dtypes = obj.get("passthrough_orig_dtypes", {}) + for name, q in obj["quantized"].items(): + dtype = getattr(torch, obj["dtypes"][name]) + s = obj["scales"][name] + if qmeta.get(name, {}).get("scheme") == "per_row" or s.ndim > 0: + s = s.to(dtype=torch.float32) + # Broadcast the saved row scale back across trailing dimensions. + out[name] = (q.float() * s.view(q.shape[0], *([1] * (q.ndim - 1)))).to(dtype=dtype).contiguous() + else: + scale = float(s.item()) + out[name] = (q.float() * scale).to(dtype=dtype).contiguous() + for name, t in obj["passthrough"].items(): + # Restore small tensors, undoing the temporary fp16 storage cast if needed. + out_t = t.detach().to("cpu").contiguous() + orig_dtype = passthrough_orig_dtypes.get(name) + if isinstance(orig_dtype, str): + out_t = out_t.to(dtype=getattr(torch, orig_dtype)).contiguous() + out[name] = out_t + return out + + +# ----------------------------- +# DATA LOADING +# ----------------------------- + +def load_data_shard(file: Path) -> Tensor: + header_bytes = 256 * np.dtype(" list[Path]: + """Sort shards by average document length (shorter docs = easier = first). + + Reads only the first 100K tokens of each shard to estimate avg doc length. + Shards with shorter average documents contain simpler, more repetitive text + that helps the model learn basic patterns before encountering harder material. + """ + difficulties: list[tuple[float, Path]] = [] + for f in files: + header = np.fromfile(f, dtype=" None: + self.file_idx = (self.file_idx + 1) % len(self.files) + self.tokens = load_data_shard(self.files[self.file_idx]) + self.pos = 0 + + def take(self, n: int) -> Tensor: + chunks: list[Tensor] = [] + remaining = n + while remaining > 0: + avail = self.tokens.numel() - self.pos + if avail <= 0: + self._advance_file() + continue + k = min(remaining, avail) + chunks.append(self.tokens[self.pos : self.pos + k]) + self.pos += k + remaining -= k + return chunks[0] if len(chunks) == 1 else torch.cat(chunks) + + +class DistributedTokenLoader: + # Each call consumes a contiguous chunk from the shared token stream, then slices out + # one disjoint span per rank. The extra "+1" token lets us build (x, y) by shifting. + def __init__(self, pattern: str, rank: int, world_size: int, device: torch.device, + sorted_files: list[Path] | None = None): + self.rank = rank + self.world_size = world_size + self.device = device + self.stream = TokenStream(pattern, sorted_files=sorted_files) + + def next_batch(self, global_tokens: int, seq_len: int, grad_accum_steps: int) -> tuple[Tensor, Tensor]: + local_tokens = global_tokens // (self.world_size * grad_accum_steps) + per_rank_span = local_tokens + 1 + chunk = self.stream.take(per_rank_span * self.world_size) + start = self.rank * per_rank_span + local = chunk[start : start + per_rank_span].to(dtype=torch.int64) + x = local[:-1].reshape(-1, seq_len) + y = local[1:].reshape(-1, seq_len) + return x.to(self.device, non_blocking=True), y.to(self.device, non_blocking=True) + +# ----------------------------- +# TRANSFORMER MODULES +# ----------------------------- + +class RMSNorm(nn.Module): + def __init__(self, eps: float | None = None): + super().__init__() + self.eps = eps + + def forward(self, x: Tensor) -> Tensor: + return F.rms_norm(x, (x.size(-1),), eps=self.eps) + + +def fake_quantize_int8_per_row(w: Tensor) -> Tensor: + """Simulate per-row int6 quantization with straight-through estimator. + + Uses absmax per-row scale (no torch.quantile — O(n) instead of O(n log n)). + Matches the int6 export range [-31, 31]. + Backward: gradients pass through as if no quantization happened (STE). + """ + w32 = w.float() + scale = (w32.abs().amax(dim=1) / 31.0).clamp_min(1.0 / 31.0) + w_q = torch.clamp(torch.round(w32 / scale[:, None]), -31, 31) + w_deq = w_q * scale[:, None] + return w + (w_deq.to(w.dtype) - w).detach() + + +class CastedLinear(nn.Linear): + # Keep weights in fp32 for optimizer/state quality, cast at matmul time for bf16 compute. + _qat: bool = False + + def forward(self, x: Tensor) -> Tensor: + w = self.weight + if self._qat and self.training: + w = fake_quantize_int8_per_row(w) + bias = self.bias.to(x.dtype) if self.bias is not None else None + return F.linear(x, w.to(x.dtype), bias) + + +class AttentionLoRA(nn.Module): + """Per-iteration LoRA adapters for attention Q, K, V, and output projections. + + Initialized so that the LoRA contribution is zero at the start of training + (B matrices are zeros). During training, the optimizer learns per-iteration + specialization while the base attention weights remain shared across loops. + """ + def __init__(self, dim: int, kv_dim: int, rank: int): + super().__init__() + self.q_A = nn.Parameter(torch.empty(dim, rank)) + self.q_B = nn.Parameter(torch.zeros(rank, dim)) + self.k_A = nn.Parameter(torch.empty(dim, rank)) + self.k_B = nn.Parameter(torch.zeros(rank, kv_dim)) + self.v_A = nn.Parameter(torch.empty(dim, rank)) + self.v_B = nn.Parameter(torch.zeros(rank, kv_dim)) + self.proj_A = nn.Parameter(torch.empty(dim, rank)) + self.proj_B = nn.Parameter(torch.zeros(rank, dim)) + self._init_lora() + + def _init_lora(self) -> None: + for name in ("q_A", "k_A", "v_A", "proj_A"): + nn.init.kaiming_uniform_(getattr(self, name), a=math.sqrt(5)) + + +def restore_low_dim_params_to_fp32(module: nn.Module) -> None: + # Keep small/control parameters in fp32 even when the model body runs in bf16. + with torch.no_grad(): + for name, param in module.named_parameters(): + if (param.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS)) and param.dtype != torch.float32: + param.data = param.data.float() + + +class Rotary(nn.Module): + # Caches cos/sin tables per sequence length on the current device. + # ntk_base_seq_len: if > 0, apply NTK-aware RoPE scaling when seq_len > ntk_base_seq_len + # (lets a model trained at seq_len=1024 generalise to seq_len=2048 at eval with no quality loss) + def __init__(self, dim: int, base: float = 10000.0, ntk_base_seq_len: int = 0): + super().__init__() + self._dim = dim + self._base = base + self._ntk_base_seq_len = ntk_base_seq_len + inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim)) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self._seq_len_cached = 0 + self._cos_cached: Tensor | None = None + self._sin_cached: Tensor | None = None + + def forward(self, seq_len: int, device: torch.device, dtype: torch.dtype) -> tuple[Tensor, Tensor]: + if ( + self._cos_cached is None + or self._sin_cached is None + or self._seq_len_cached != seq_len + or self._cos_cached.device != device + ): + if self._ntk_base_seq_len > 0 and seq_len > self._ntk_base_seq_len: + # NTK-aware scaling: extend context without fine-tuning + scale = seq_len / self._ntk_base_seq_len + ntk_base = self._base * (scale ** (self._dim / (self._dim - 2))) + inv_freq = 1.0 / (ntk_base ** (torch.arange(0, self._dim, 2, dtype=torch.float32, device=device) / self._dim)) + else: + inv_freq = self.inv_freq.to(device) + t = torch.arange(seq_len, device=device, dtype=inv_freq.dtype) + freqs = torch.outer(t, inv_freq) + self._cos_cached = freqs.cos()[None, None, :, :] + self._sin_cached = freqs.sin()[None, None, :, :] + self._seq_len_cached = seq_len + return self._cos_cached.to(dtype=dtype), self._sin_cached.to(dtype=dtype) + + +def apply_rotary_emb(x: Tensor, cos: Tensor, sin: Tensor) -> Tensor: + half = x.size(-1) // 2 + x1, x2 = x[..., :half], x[..., half:] + return torch.cat((x1 * cos + x2 * sin, x1 * (-sin) + x2 * cos), dim=-1) + + +class CausalSelfAttention(nn.Module): + def __init__( + self, + dim: int, + num_heads: int, + num_kv_heads: int, + rope_base: float, + qk_gain_init: float, + ntk_base_seq_len: int = 0, + rope_dims: int = 0, + ): + super().__init__() + if dim % num_heads != 0: + raise ValueError("model_dim must be divisible by num_heads") + if num_heads % num_kv_heads != 0: + raise ValueError("num_heads must be divisible by num_kv_heads") + self.num_heads = num_heads + self.num_kv_heads = num_kv_heads + self.head_dim = dim // num_heads + if self.head_dim % 2 != 0: + raise ValueError("head_dim must be even for RoPE") + kv_dim = self.num_kv_heads * self.head_dim + self.c_q = CastedLinear(dim, dim, bias=False) + self.c_k = CastedLinear(dim, kv_dim, bias=False) + self.c_v = CastedLinear(dim, kv_dim, bias=False) + self.proj = CastedLinear(dim, dim, bias=False) + self.proj._zero_init = True + self.q_gain = nn.Parameter(torch.full((num_heads,), qk_gain_init, dtype=torch.float32)) + # rope_dims: if >0 apply RoPE only to first rope_dims dims of each head; rest are position-free + self.rope_dims = rope_dims if rope_dims > 0 else self.head_dim + self.rotary = Rotary(self.rope_dims, base=rope_base, ntk_base_seq_len=ntk_base_seq_len) + self.use_xsa = False # enabled on last N layers by GPT.__init__ + + def _xsa_efficient(self, y: Tensor, v: Tensor) -> Tensor: + """Exclusive Self Attention: subtract self-value projection (arXiv:2603.09078). + GQA-aware reshape avoids repeat_interleave — zero extra allocation.""" + B, T, H, D = y.shape + Hkv = v.size(-2) + group = H // Hkv + y_g = y.reshape(B, T, Hkv, group, D) + vn = F.normalize(v, dim=-1).unsqueeze(-2) # (B, T, Hkv, 1, D) + proj = (y_g * vn).sum(dim=-1, keepdim=True) * vn + return (y_g - proj).reshape(B, T, H, D) + + def forward(self, x: Tensor, lora: AttentionLoRA | None = None, v_embed: Tensor | None = None) -> Tensor: + bsz, seqlen, dim = x.shape + q = self.c_q(x) + k = self.c_k(x) + v = self.c_v(x) + if v_embed is not None: + v = v + v_embed + if lora is not None: + # LoRA delta: (bsz, seqlen, dim) @ (dim, rank) @ (rank, out_dim) + # autocast handles fp32->bf16 cast of LoRA params automatically + q = q + (x @ lora.q_A) @ lora.q_B + k = k + (x @ lora.k_A) @ lora.k_B + v = v + (x @ lora.v_A) @ lora.v_B + q = q.reshape(bsz, seqlen, self.num_heads, self.head_dim).transpose(1, 2) + k = k.reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) + v = v.reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) + q = F.rms_norm(q, (q.size(-1),)) + k = F.rms_norm(k, (k.size(-1),)) + cos, sin = self.rotary(seqlen, x.device, q.dtype) + if self.rope_dims < self.head_dim: + # Partial RoPE: apply rotation to first rope_dims dims, leave remaining dims untouched + q = torch.cat([apply_rotary_emb(q[..., :self.rope_dims], cos, sin), q[..., self.rope_dims:]], dim=-1) + k = torch.cat([apply_rotary_emb(k[..., :self.rope_dims], cos, sin), k[..., self.rope_dims:]], dim=-1) + else: + q = apply_rotary_emb(q, cos, sin) + k = apply_rotary_emb(k, cos, sin) + q = q * self.q_gain.to(dtype=q.dtype)[None, :, None, None] + if _HAS_FA3 and _use_fa3: + # FA3 expects [bsz, seqlen, heads, head_dim] + q_fa = q.transpose(1, 2).contiguous() + k_fa = k.transpose(1, 2).contiguous() + v_fa = v.transpose(1, 2).contiguous() + y = _flash_attn_3_func(q_fa, k_fa, v_fa, causal=True) + if self.use_xsa: + y = self._xsa_efficient(y, v_fa) + y = y.reshape(bsz, seqlen, dim) + else: + y = F.scaled_dot_product_attention( + q, k, v, attn_mask=None, is_causal=True, + enable_gqa=(self.num_kv_heads != self.num_heads), + ) + if self.use_xsa: + y = self._xsa_efficient(y.transpose(1, 2), v.transpose(1, 2)) + else: + y = y.transpose(1, 2) + y = y.contiguous().reshape(bsz, seqlen, dim) + out = self.proj(y) + if lora is not None: + out = out + (y @ lora.proj_A) @ lora.proj_B + return out + + +class MLP(nn.Module): + # relu^2 MLP from the original modded-nanogpt setup + def __init__(self, dim: int, mlp_mult: int, mlp_hidden: int = 0): + super().__init__() + hidden = mlp_hidden if mlp_hidden > 0 else mlp_mult * dim + self.fc = CastedLinear(dim, hidden, bias=False) + self.proj = CastedLinear(hidden, dim, bias=False) + self.proj._zero_init = True + + def forward(self, x: Tensor) -> Tensor: + x = torch.relu(self.fc(x)) + return self.proj(x.square()) + + +class Block(nn.Module): + def __init__( + self, + dim: int, + num_heads: int, + num_kv_heads: int, + mlp_mult: int, + rope_base: float, + qk_gain_init: float, + mlp_hidden: int = 0, + ntk_base_seq_len: int = 0, + rope_dims: int = 0, + layer_idx: int = 0, + ln_scale: bool = False, + ): + super().__init__() + self.attn_norm = RMSNorm() + self.mlp_norm = RMSNorm() + self.attn = CausalSelfAttention(dim, num_heads, num_kv_heads, rope_base, qk_gain_init, ntk_base_seq_len=ntk_base_seq_len, rope_dims=rope_dims) + self.mlp = MLP(dim, mlp_mult, mlp_hidden=mlp_hidden) + self.attn_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.mlp_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.resid_mix = nn.Parameter(torch.stack((torch.ones(dim), torch.zeros(dim))).float()) + self.ln_scale_factor = 1.0 / math.sqrt(layer_idx + 1) if ln_scale else 1.0 + + def forward(self, x: Tensor, x0: Tensor, lora: AttentionLoRA | None = None, v_embed: Tensor | None = None) -> Tensor: + mix = self.resid_mix.to(dtype=x.dtype) + x = mix[0][None, None, :] * x + mix[1][None, None, :] * x0 + s = self.ln_scale_factor + attn_out = self.attn(self.attn_norm(x) * s, lora=lora, v_embed=v_embed) + x = x + self.attn_scale.to(dtype=x.dtype)[None, None, :] * attn_out + x = x + self.mlp_scale.to(dtype=x.dtype)[None, None, :] * self.mlp(self.mlp_norm(x) * s) + return x + + +# ── Shared Value Embedding (VE) ────────────────────────────────────────────── +class ValueEmbedding(nn.Module): + """Learned embedding added to attention values in selected layers. + One shared table across layers, with per-layer learned scales.""" + def __init__(self, vocab_size: int, ve_dim: int, kv_dim: int): + super().__init__() + self.embed = nn.Embedding(vocab_size, ve_dim) + nn.init.normal_(self.embed.weight, std=0.01) + self.proj = CastedLinear(ve_dim, kv_dim, bias=False) if ve_dim != kv_dim else None + if self.proj is not None: + nn.init.zeros_(self.proj.weight) + self.scale = nn.Parameter(torch.tensor(0.1, dtype=torch.float32)) + + def forward(self, token_ids: Tensor) -> Tensor: + h = self.embed(token_ids) + if self.proj is not None: + h = self.proj(h) + return h * self.scale.to(dtype=h.dtype) + + +# ── SmearGate: cheap bigram context at embedding layer ────────────────────── +class SmearGate(nn.Module): + """Blends each token's embedding with the previous token's via per-channel sigmoid gates. + 512 independent channel gates (vs scalar gate) give the model richer bigram context.""" + def __init__(self, dim: int): + super().__init__() + self.gate = nn.Parameter(torch.zeros(dim, dtype=torch.float32)) + + def forward(self, x: Tensor) -> Tensor: + # x: (bsz, seq_len, dim) + prev = torch.cat([torch.zeros_like(x[:, :1]), x[:, :-1]], dim=1) + g = torch.sigmoid(self.gate.to(dtype=x.dtype))[None, None, :] # (1, 1, dim) + return torch.lerp(x, prev, g) + + +# ── BigramHash: hash-based bigram embedding ───────────────────────────────── +class BigramHashEmbedding(nn.Module): + """Maps consecutive token pairs to embeddings via a hash table. + Injects explicit bigram statistics into the residual stream.""" + def __init__(self, num_buckets: int, hash_dim: int, model_dim: int): + super().__init__() + self.num_buckets = num_buckets + self.embed = nn.Embedding(num_buckets, hash_dim) + self.proj = nn.Linear(hash_dim, model_dim, bias=False) + self.scale = nn.Parameter(torch.tensor(0.05, dtype=torch.float32)) + nn.init.normal_(self.embed.weight, std=0.01) + nn.init.zeros_(self.proj.weight) + + def _hash_pair(self, prev_ids: Tensor, cur_ids: Tensor) -> Tensor: + # XOR hash with large primes for better distribution + return (torch.bitwise_xor(36313 * prev_ids.long(), 27191 * cur_ids.long()) % max(self.num_buckets - 1, 1)).to(prev_ids.device) + + def forward(self, input_ids: Tensor) -> Tensor: + # input_ids: (bsz, seq_len) + prev_ids = torch.cat([torch.zeros_like(input_ids[:, :1]), input_ids[:, :-1]], dim=1) + bucket_ids = self._hash_pair(prev_ids, input_ids) + return self.scale.to(dtype=input_ids.dtype if input_ids.is_floating_point() else torch.bfloat16) * self.proj(self.embed(bucket_ids)) + + +class GPT(nn.Module): + def __init__( + self, + vocab_size: int, + num_layers: int, + model_dim: int, + num_heads: int, + num_kv_heads: int, + mlp_mult: int, + tie_embeddings: bool, + tied_embed_init_std: float, + logit_softcap: float, + rope_base: float, + qk_gain_init: float, + num_loops: int = 1, + lora_rank: int = 0, + mlp_hidden: int = 0, + smear_gate: bool = False, + bigram_hash: bool = False, + bigram_hash_buckets: int = 4096, + bigram_hash_dim: int = 128, + ortho_init: bool = True, + xsa_last_n: int = 0, + ntk_base_seq_len: int = 0, + rope_dims: int = 0, + ln_scale: bool = False, + unet_skips: bool = False, + ve_enabled: bool = False, + ve_dim: int = 128, + ve_layers: str = "9,10", + ): + super().__init__() + if logit_softcap <= 0.0: + raise ValueError(f"logit_softcap must be positive, got {logit_softcap}") + self.ortho_init = ortho_init + self.tie_embeddings = tie_embeddings + self.tied_embed_init_std = tied_embed_init_std + self.logit_softcap = logit_softcap + self.num_unique_layers = num_layers + self.num_loops = num_loops + effective_depth = num_layers * num_loops + self.tok_emb = nn.Embedding(vocab_size, model_dim) + self.smear_gate = SmearGate(model_dim) if smear_gate else None + self.bigram_hash = BigramHashEmbedding(bigram_hash_buckets, bigram_hash_dim, model_dim) if bigram_hash else None + # Shared Value Embedding: one table, added to V in selected layers + self.ve_layer_indices = [int(x) for x in ve_layers.split(",") if x.strip()] if ve_enabled else [] + kv_dim = num_kv_heads * (model_dim // num_heads) + self.ve_shared = ValueEmbedding(vocab_size, ve_dim, kv_dim) if ve_enabled else None + self.ve_layer_scales = nn.ParameterList( + [nn.Parameter(torch.ones(1, dtype=torch.float32)) for _ in self.ve_layer_indices] + ) if ve_enabled else None + self.unet_skips = unet_skips + self.num_encoder_layers = effective_depth // 2 + self.num_decoder_layers = effective_depth - self.num_encoder_layers + self.num_skip_weights = min(self.num_encoder_layers, self.num_decoder_layers) + self.skip_weights = nn.Parameter(torch.ones(self.num_skip_weights, model_dim, dtype=torch.float32)) if unet_skips else None + self.blocks = nn.ModuleList( + [ + Block( + model_dim, + num_heads, + num_kv_heads, + mlp_mult, + rope_base, + qk_gain_init, + mlp_hidden=mlp_hidden, + ntk_base_seq_len=ntk_base_seq_len, + rope_dims=rope_dims, + layer_idx=i, + ln_scale=ln_scale, + ) + for i in range(num_layers) + ] + ) + # Per-(loop, block) LoRA adapters for attention projections. + # Only created when num_loops > 1 and lora_rank > 0. + kv_dim = num_kv_heads * (model_dim // num_heads) + if lora_rank > 0 and num_loops > 1: + self.lora_adapters = nn.ModuleList( + [ + nn.ModuleList( + [AttentionLoRA(model_dim, kv_dim, lora_rank) for _ in range(num_layers)] + ) + for _ in range(num_loops) + ] + ) + else: + self.lora_adapters = None + if xsa_last_n > 0: + for i in range(max(0, num_layers - xsa_last_n), num_layers): + self.blocks[i].attn.use_xsa = True + self.final_norm = RMSNorm() + self.lm_head = None if tie_embeddings else CastedLinear(model_dim, vocab_size, bias=False) + if self.lm_head is not None: + self.lm_head._zero_init = True + self._init_weights() + + def _init_weights(self) -> None: + if self.tie_embeddings: + nn.init.normal_(self.tok_emb.weight, mean=0.0, std=self.tied_embed_init_std) + for module in self.modules(): + if isinstance(module, (nn.Linear, CastedLinear)) and getattr(module, "_zero_init", False): + nn.init.zeros_(module.weight) + elif self.ortho_init and isinstance(module, (nn.Linear, CastedLinear)) \ + and not getattr(module, "_zero_init", False) \ + and module.weight.ndim >= 2: + nn.init.orthogonal_(module.weight, gain=1.0) + + def forward(self, input_ids: Tensor, target_ids: Tensor) -> Tensor: + x = self.tok_emb(input_ids) + if self.bigram_hash is not None: + x = x + self.bigram_hash(input_ids) + x = F.rms_norm(x, (x.size(-1),)) + if self.smear_gate is not None: + x = self.smear_gate(x) + x0 = x + num_enc = self.num_encoder_layers + num_dec = self.num_decoder_layers + + # Compute shared VE once (cached across layers) + ve_base = self.ve_shared(input_ids) if self.ve_shared is not None else None + + def _get_ve(layer_idx: int) -> Tensor | None: + if ve_base is None or layer_idx not in self.ve_layer_indices: + return None + ve_idx = self.ve_layer_indices.index(layer_idx) + return ve_base * self.ve_layer_scales[ve_idx].to(dtype=ve_base.dtype) + + if self.unet_skips: + skips: list[Tensor] = [] + for i in range(num_enc): + x = self.blocks[i](x, x0, v_embed=_get_ve(i)) + skips.append(x) + for i in range(num_dec): + if skips: + x = x + self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips.pop() + x = self.blocks[num_enc + i](x, x0, v_embed=_get_ve(num_enc + i)) + else: + eff_idx = 0 + for loop_idx in range(self.num_loops): + for block_idx in range(self.num_unique_layers): + lora = self.lora_adapters[loop_idx][block_idx] if self.lora_adapters is not None else None + x = self.blocks[block_idx](x, x0, lora=lora, v_embed=_get_ve(eff_idx)) + eff_idx += 1 + + x = self.final_norm(x).reshape(-1, x.size(-1)) + targets = target_ids.reshape(-1) + if self.tie_embeddings: + logits_proj = F.linear(x, self.tok_emb.weight) + else: + if self.lm_head is None: + raise RuntimeError("lm_head is required when tie_embeddings=False") + logits_proj = self.lm_head(x) + logits = self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap) + return F.cross_entropy(logits.float(), targets, reduction="mean") + + def forward_logits(self, input_ids: Tensor) -> Tensor: + """Return logits (bsz, seq_len, vocab) without computing loss.""" + x = self.tok_emb(input_ids) + if self.bigram_hash is not None: + x = x + self.bigram_hash(input_ids) + x = F.rms_norm(x, (x.size(-1),)) + if self.smear_gate is not None: + x = self.smear_gate(x) + x0 = x + num_enc = self.num_encoder_layers + num_dec = self.num_decoder_layers + + if self.unet_skips: + skips: list[Tensor] = [] + for i in range(num_enc): + x = self.blocks[i](x, x0) + skips.append(x) + for i in range(num_dec): + if skips: + x = x + self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips.pop() + x = self.blocks[num_enc + i](x, x0) + else: + eff_idx = 0 + for loop_idx in range(self.num_loops): + for block_idx in range(self.num_unique_layers): + lora = self.lora_adapters[loop_idx][block_idx] if self.lora_adapters is not None else None + x = self.blocks[block_idx](x, x0, lora=lora) + eff_idx += 1 + x = self.final_norm(x) + if self.tie_embeddings: + logits_proj = F.linear(x, self.tok_emb.weight) + else: + logits_proj = self.lm_head(x) + return self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap) + + +BOS_ID = 1 # SentencePiece BOS token ID + +def _find_docs(all_tokens: Tensor) -> list[tuple[int, int]]: + """Return (start, length) for each document, identified by BOS boundaries. + + Each document starts at a BOS token and extends to just before the next BOS. + The last document extends to the end of the token stream. + """ + bos_positions = (all_tokens == BOS_ID).nonzero(as_tuple=True)[0].tolist() + if not bos_positions: + return [(0, all_tokens.numel())] + docs = [] + for i in range(len(bos_positions)): + start = bos_positions[i] + end = bos_positions[i + 1] if i + 1 < len(bos_positions) else all_tokens.numel() + if end - start >= 2: # need at least 2 tokens for (x, y) pair + docs.append((start, end - start)) + return docs + + +def _build_sliding_windows( + total_tokens: int, seq_len: int, stride: int +) -> list[tuple[int, int]]: + """Return (window_start, score_start) pairs covering every token exactly once. + + Every token in [0, total_tokens) is scored by exactly one window. + Full windows score their last `stride` positions (first window scores all seq_len). + One tail-aligned window covers any tokens beyond the last full window's end. + """ + if total_tokens <= 0: + return [] + if total_tokens <= seq_len: + return [(0, 0)] + + windows: list[tuple[int, int]] = [] + last_full_end = 0 + ws = 0 + while ws + seq_len <= total_tokens: + s = 0 if ws == 0 else seq_len - stride + windows.append((ws, s)) + last_full_end = ws + seq_len + ws += stride + + # One tail window ending exactly at total_tokens covers any remainder. + if last_full_end < total_tokens: + tail_ws = total_tokens - seq_len + tail_s = last_full_end - tail_ws # skip already-scored prefix + windows.append((tail_ws, tail_s)) + + return windows + + +def eval_val_sliding( + args: Hyperparameters, + base_model: nn.Module, + rank: int, + world_size: int, + device: torch.device, + val_tokens: Tensor, + base_bytes_lut: Tensor, + has_leading_space_lut: Tensor, + is_boundary_token_lut: Tensor, + stride: int, + batch_seqs: int = 32, +) -> tuple[float, float]: + """Sliding window evaluation: each token scored with maximum context. + + Windows of eval_seq_len advance by `stride`. Only newly covered tokens + contribute to the score (first window scores all seq_len; non-first full + windows score the last `stride` tokens; one tail window covers any remainder). + Every validation token is counted exactly once. + """ + seq_len = args.eval_seq_len + total_tokens = val_tokens.numel() - 1 + + if args.doc_isolated_eval: + # Build windows per document — context never crosses document boundaries. + # Each window's (ws, s) is in absolute token-stream coordinates. + docs = _find_docs(val_tokens) + all_windows: list[tuple[int, int]] = [] + for doc_start, doc_len in docs: + doc_pred_len = doc_len - 1 # number of prediction positions + doc_windows = _build_sliding_windows(doc_pred_len, seq_len, stride) + for ws, s in doc_windows: + all_windows.append((doc_start + ws, s)) + else: + all_windows = _build_sliding_windows(total_tokens, seq_len, stride) + total_windows = len(all_windows) + + # Distribute across ranks + my_s = (total_windows * rank) // world_size + my_e = (total_windows * (rank + 1)) // world_size + my_windows = all_windows[my_s:my_e] + + loss_sum = torch.zeros((), device=device, dtype=torch.float64) + token_count = torch.zeros((), device=device, dtype=torch.float64) + byte_count = torch.zeros((), device=device, dtype=torch.float64) + + base_model.eval() + with torch.inference_mode(): + for bi in range(0, len(my_windows), batch_seqs): + batch_items = my_windows[bi:bi + batch_seqs] + bsz = len(batch_items) + + x_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) + y_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) + wlens: list[int] = [] + score_starts: list[int] = [] + + for i, (ws, s) in enumerate(batch_items): + end = min(ws + seq_len, total_tokens) + wlen = end - ws + wlens.append(wlen) + score_starts.append(s) + chunk = val_tokens[ws:end + 1].to(dtype=torch.int64, device=device) + x_batch[i, :wlen] = chunk[:-1] + y_batch[i, :wlen] = chunk[1:] + + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + logits = base_model.forward_logits(x_batch) + + nll = F.cross_entropy( + logits.reshape(-1, logits.size(-1)).float(), + y_batch.reshape(-1), + reduction="none", + ).reshape(bsz, seq_len) + + for i, (ws, s) in enumerate(batch_items): + wlen = wlens[i] + scored_nll = nll[i, s:wlen].to(torch.float64) + loss_sum += scored_nll.sum() + token_count += float(wlen - s) + tgt = y_batch[i, s:wlen] + prev = x_batch[i, s:wlen] + tb = base_bytes_lut[tgt].to(torch.float64) + tb += (has_leading_space_lut[tgt] & ~is_boundary_token_lut[prev]).to(torch.float64) + byte_count += tb.sum() + + # Progress (rank 0 only) + if rank == 0 and (bi // batch_seqs) % 50 == 0: + done = min(bi + batch_seqs, len(my_windows)) + pct = done / len(my_windows) * 100 + running_bpb = 0.0 + if token_count.item() > 0: + rl = (loss_sum / token_count).item() + running_bpb = rl / math.log(2.0) * (token_count.item() / byte_count.item()) + print(f" sliding_eval [{pct:5.1f}%] {done}/{len(my_windows)} windows running_bpb={running_bpb:.6f}", flush=True) + + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(loss_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(token_count, op=dist.ReduceOp.SUM) + dist.all_reduce(byte_count, op=dist.ReduceOp.SUM) + + val_loss = (loss_sum / token_count).item() + bits_per_token = val_loss / math.log(2.0) + tokens_per_byte = token_count.item() / byte_count.item() + base_model.train() + return val_loss, bits_per_token * tokens_per_byte + + +# ----------------------------- +# TEST-TIME TRAINING (TTT) +# ----------------------------- + +def _ttt_run_phase( + model: nn.Module, + ttt_params: list[torch.nn.Parameter], + optimizer: torch.optim.Optimizer, + val_tokens: torch.Tensor, + seq_len: int, + batch_seqs: int, + epochs: int, + max_steps: int, + device: torch.device, + rank: int, + world_size: int, + phase_name: str, + t0: float, + cosine: bool = False, + warmup_frac: float = 0.0, +) -> None: + """Run one TTT phase with DDP gradient sharding across GPUs. + Each GPU processes batch_seqs sequences, gradients are manually all_reduced. + Supports cosine lr decay and linear warmup.""" + distributed = world_size > 1 + n_tokens = val_tokens.numel() + total_seqs = (n_tokens - 1) // seq_len + my_start_seq = (total_seqs * rank) // world_size + my_end_seq = (total_seqs * (rank + 1)) // world_size + + # Store initial lr for cosine/warmup scheduling + if cosine or warmup_frac > 0: + for g in optimizer.param_groups: + g["initial_lr"] = g["lr"] + # Estimate actual steps per epoch for cosine schedule + steps_per_epoch = min((my_end_seq - my_start_seq) // max(batch_seqs, 1), max_steps) + total_steps = epochs * steps_per_epoch + global_step = 0 + + model.train() + for epoch in range(epochs): + epoch_loss = torch.zeros((), device=device, dtype=torch.float64) + epoch_tokens = torch.zeros((), device=device, dtype=torch.float64) + step_i = 0 + # Contiguous slicing over this GPU's shard (matches #398 pattern) + for batch_start in range(my_start_seq, my_end_seq, batch_seqs): + if step_i >= max_steps: + break + + # LR schedule: warmup then cosine decay + if (cosine or warmup_frac > 0) and total_steps > 0: + progress = global_step / total_steps + if warmup_frac > 0 and progress < warmup_frac: + mul = progress / warmup_frac + elif cosine: + cos_start = warmup_frac if warmup_frac > 0 else 0.0 + cos_progress = (progress - cos_start) / (1.0 - cos_start) if cos_start < 1.0 else 0.0 + mul = 0.5 * (1.0 + math.cos(math.pi * min(cos_progress, 1.0))) + else: + mul = 1.0 + for g in optimizer.param_groups: + g["lr"] = g["initial_lr"] * mul + + batch_end = min(batch_start + batch_seqs, my_end_seq) + raw_start = batch_start * seq_len + raw_end = batch_end * seq_len + 1 + if raw_end > n_tokens: + break + local = val_tokens[raw_start:raw_end].to(device=device, dtype=torch.int64, non_blocking=True) + x = local[:-1].reshape(-1, seq_len) + y = local[1:].reshape(-1, seq_len) + + optimizer.zero_grad(set_to_none=True) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + loss = model(x, y) + loss.backward() + + if distributed: + for p in ttt_params: + if p.grad is not None: + dist.all_reduce(p.grad, op=dist.ReduceOp.AVG) + + torch.nn.utils.clip_grad_norm_(ttt_params, 1.0) + optimizer.step() + epoch_loss += loss.detach().to(torch.float64) * x.numel() + epoch_tokens += x.numel() + step_i += 1 + global_step += 1 + + if distributed: + dist.all_reduce(epoch_loss, op=dist.ReduceOp.SUM) + dist.all_reduce(epoch_tokens, op=dist.ReduceOp.SUM) + + if rank == 0: + avg_loss = epoch_loss.item() / max(epoch_tokens.item(), 1) + cur_lr = optimizer.param_groups[0]["lr"] + print(f"ttt_{phase_name} epoch:{epoch+1}/{epochs} loss:{avg_loss:.4f} " + f"lr:{cur_lr:.6f} steps:{step_i} time:{time.perf_counter()-t0:.1f}s", flush=True) + + +def ttt_adapt( + args: Hyperparameters, + model: nn.Module, + device: torch.device, + val_tokens: torch.Tensor, + rank: int = 0, + world_size: int = 1, +) -> None: + """Two-phase TTT with DDP gradient sharding. + + Phase 1 (norm-only): Fix quantization artifacts by adapting only LayerNorm + weights, scales, resid_mix, and q_gain (~22K params). Low risk, high epoch count. + + Phase 2 (selective blocks): Adapt last N blocks + norms + head to val distribution. + Higher risk, lower epoch count. + + Falls back to single-phase SGD if TTT_TWO_PHASE=0. + """ + t0 = time.perf_counter() + seq_len = args.train_seq_len + batch_seqs = args.ttt_batch_seqs + + if args.ttt_two_phase: + # ── Phase 1: Norm-only recalibration ──────────────────────────── + # Freeze everything except norms, scales, resid_mix, q_gain + norm_params = [] + for p in model.parameters(): + p.requires_grad_(False) + for name, p in model.named_parameters(): + if any(k in name for k in ("norm", "scale", "resid_mix", "q_gain", "skip_weight")): + p.requires_grad_(True) + norm_params.append(p) + n_norm = sum(p.numel() for p in norm_params) + if rank == 0: + print(f"ttt_phase1:start params:{n_norm} epochs:{args.ttt_p1_epochs} lr:{args.ttt_p1_lr}", flush=True) + + optimizer_p1 = torch.optim.Adam(norm_params, lr=args.ttt_p1_lr) + _ttt_run_phase( + model, norm_params, optimizer_p1, val_tokens, seq_len, batch_seqs, + epochs=args.ttt_p1_epochs, max_steps=args.ttt_max_steps, + device=device, rank=rank, world_size=world_size, + phase_name="phase1", t0=t0, + cosine=args.ttt_cosine, warmup_frac=args.ttt_warmup_frac, + ) + del optimizer_p1 + + # ── Phase 2: Selective block adaptation ───────────────────────── + # Unfreeze last N blocks + all norms + head + embeddings + for p in model.parameters(): + p.requires_grad_(False) + num_layers = len(list(model.blocks)) # type: ignore[attr-defined] + phase2_params = [] + for name, p in model.named_parameters(): + is_late_block = False + for i in range(max(0, num_layers - args.ttt_p2_unfreeze_blocks), num_layers): + if f"blocks.{i}." in name: + is_late_block = True + break + is_norm_or_scale = any(k in name for k in ("norm", "scale", "resid_mix", "q_gain", "skip_weight")) + is_head = "lm_head" in name or "tok_emb" in name + if is_late_block or is_norm_or_scale or is_head: + p.requires_grad_(True) + phase2_params.append(p) + n_p2 = sum(p.numel() for p in phase2_params) + if rank == 0: + print(f"ttt_phase2:start params:{n_p2} epochs:{args.ttt_p2_epochs} lr:{args.ttt_p2_lr}", flush=True) + + if args.ttt_optimizer == "adamw": + optimizer_p2 = torch.optim.AdamW(phase2_params, lr=args.ttt_p2_lr, weight_decay=0.0) + else: + optimizer_p2 = torch.optim.SGD(phase2_params, lr=args.ttt_p2_lr, momentum=args.ttt_momentum) + _ttt_run_phase( + model, phase2_params, optimizer_p2, val_tokens, seq_len, batch_seqs, + epochs=args.ttt_p2_epochs, max_steps=args.ttt_max_steps, + device=device, rank=rank, world_size=world_size, + phase_name="phase2", t0=t0, + cosine=args.ttt_cosine, warmup_frac=args.ttt_warmup_frac, + ) + del optimizer_p2 + else: + # ── Single-phase TTT with cosine + per-layer lr ─────────────── + frozen = set() + for i, block in enumerate(model.blocks): # type: ignore[attr-defined] + if i < args.ttt_freeze_blocks: + for p in block.parameters(): + frozen.add(id(p)) + + if args.ttt_perlayer: + # Per-layer lr: higher for more quant-damaged MLP proj, lower for fc + proj_params = [p for n, p in model.named_parameters() + if "mlp.proj" in n and p.requires_grad and id(p) not in frozen] + fc_params = [p for n, p in model.named_parameters() + if "mlp.fc" in n and p.requires_grad and id(p) not in frozen] + other_params = [p for p in model.parameters() + if p.requires_grad and id(p) not in frozen + and id(p) not in {id(q) for q in proj_params + fc_params}] + param_groups = [ + {"params": proj_params, "lr": args.ttt_lr * 3.0}, + {"params": fc_params, "lr": args.ttt_lr * 0.5}, + {"params": other_params, "lr": args.ttt_lr}, + ] + param_groups = [g for g in param_groups if g["params"]] + ttt_params = proj_params + fc_params + other_params + else: + ttt_params = [p for p in model.parameters() if p.requires_grad and id(p) not in frozen] + param_groups = [{"params": ttt_params, "lr": args.ttt_lr}] + + if rank == 0: + n_ttt = sum(p.numel() for p in ttt_params) + print(f"ttt:start params:{n_ttt} epochs:{args.ttt_epochs} lr:{args.ttt_lr} " + f"freeze:{args.ttt_freeze_blocks} optimizer:{args.ttt_optimizer} " + f"cosine:{args.ttt_cosine} warmup:{args.ttt_warmup_frac} perlayer:{args.ttt_perlayer}", flush=True) + + if args.ttt_optimizer == "adamw": + optimizer = torch.optim.AdamW(param_groups, weight_decay=0.0) + else: + optimizer = torch.optim.SGD(param_groups, momentum=args.ttt_momentum) + _ttt_run_phase( + model, ttt_params, optimizer, val_tokens, seq_len, batch_seqs, + epochs=args.ttt_epochs, max_steps=args.ttt_max_steps, + device=device, rank=rank, world_size=world_size, + phase_name="single", t0=t0, + cosine=args.ttt_cosine, warmup_frac=args.ttt_warmup_frac, + ) + del optimizer + + # Unfreeze all params for eval + for p in model.parameters(): + p.requires_grad_(True) + if rank == 0: + print(f"ttt:done elapsed={time.perf_counter()-t0:.1f}s", flush=True) + + +def reptile_ttt( + args: Hyperparameters, + model: nn.Module, + device: torch.device, + val_tokens: torch.Tensor, + rank: int = 0, +) -> None: + """Reptile meta-TTT: find weights that adapt fast to val distribution. + Runs after EMA/SWA, before standard TTT. Makes TTT ~10x more effective.""" + t0 = time.perf_counter() + seq_len = args.train_seq_len + n_tokens = val_tokens.numel() + + # Only adapt MLP params of last 1/4 blocks + num_blocks = len(model.blocks) + suffix_start = num_blocks - num_blocks // 4 + ttt_params = {} + for name, p in model.named_parameters(): + if any(f'blocks.{i}.' in name and '.mlp.' in name for i in range(suffix_start, num_blocks)): + ttt_params[name] = p + + base_state = {name: p.data.clone() for name, p in ttt_params.items()} + reptile_steps = 0 + + while (time.perf_counter() - t0) < args.reptile_budget_s: + # Save current params + saved = {name: p.data.clone() for name, p in ttt_params.items()} + + # Inner loop: N SGD steps on a random chunk + model.train() + start = random.randint(0, max(n_tokens - seq_len - 1, 0)) + chunk = val_tokens[start:start + seq_len + 1].to(device=device, dtype=torch.int64) + x = chunk[:-1].unsqueeze(0) + y = chunk[1:].unsqueeze(0) + + for inner_step in range(args.reptile_inner_steps): + model.zero_grad(set_to_none=True) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + loss = model(x, y) + loss.backward() + with torch.no_grad(): + for name, param in ttt_params.items(): + if param.grad is not None: + param.data -= args.reptile_inner_lr * param.grad + + # Outer loop: move base toward adapted params + with torch.no_grad(): + for name, param in ttt_params.items(): + base_state[name] += args.reptile_outer_lr * (param.data - base_state[name]) + param.data.copy_(base_state[name]) + + reptile_steps += 1 + + if rank == 0: + print(f"reptile_ttt:done steps:{reptile_steps} elapsed:{time.perf_counter()-t0:.1f}s", flush=True) + + +# ----------------------------- +# TRAINING +# ----------------------------- + +def main() -> None: + global zeropower_via_newtonschulz5 + + code = Path(__file__).read_text(encoding="utf-8") + args = Hyperparameters() + zeropower_via_newtonschulz5 = torch.compile(zeropower_via_newtonschulz5) + + # ----------------------------- + # DISTRIBUTED + CUDA SETUP + # ----------------------------- + + distributed = "RANK" in os.environ and "WORLD_SIZE" in os.environ + rank = int(os.environ.get("RANK", "0")) + world_size = int(os.environ.get("WORLD_SIZE", "1")) + local_rank = int(os.environ.get("LOCAL_RANK", "0")) + if world_size <= 0: + raise ValueError(f"WORLD_SIZE must be positive, got {world_size}") + if 8 % world_size != 0: + raise ValueError(f"WORLD_SIZE={world_size} must divide 8 so grad_accum_steps stays integral") + grad_accum_steps = 8 // world_size + grad_scale = 1.0 / grad_accum_steps + if not torch.cuda.is_available(): + raise RuntimeError("CUDA is required") + device = torch.device("cuda", local_rank) + torch.cuda.set_device(device) + if distributed: + dist.init_process_group(backend="nccl", device_id=device) + dist.barrier() + master_process = rank == 0 + + # Fast math knobs + torch.backends.cuda.matmul.allow_tf32 = True + torch.backends.cudnn.allow_tf32 = True + from torch.backends.cuda import enable_cudnn_sdp, enable_flash_sdp, enable_math_sdp, enable_mem_efficient_sdp + + enable_cudnn_sdp(False) + enable_flash_sdp(True) + enable_mem_efficient_sdp(False) + enable_math_sdp(False) + + # Set module-level flag so CausalSelfAttention.forward can use FA3. + global _use_fa3 + _use_fa3 = args.use_fa3 and _HAS_FA3 + + logfile = None + if master_process: + os.makedirs("logs", exist_ok=True) + logfile = f"logs/{args.run_id}.txt" + print(logfile) + + def log0(msg: str, console: bool = True) -> None: + if not master_process: + return + if console: + print(msg) + if logfile is not None: + with open(logfile, "a", encoding="utf-8") as f: + print(msg, file=f) + + log0(code, console=False) + log0("=" * 100, console=False) + log0(f"Running Python {sys.version}", console=False) + log0(f"Running PyTorch {torch.__version__}", console=False) + log0( + subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, check=False).stdout, + console=False, + ) + log0("=" * 100, console=False) + + # ----------------------------- + # TOKENIZER + VALIDATION METRIC SETUP + # ----------------------------- + + random.seed(args.seed) + np.random.seed(args.seed) + torch.manual_seed(args.seed) + torch.cuda.manual_seed_all(args.seed) + + if not args.tokenizer_path.endswith(".model"): + raise ValueError(f"Script only setup for SentencePiece .model file: {args.tokenizer_path}") + sp = spm.SentencePieceProcessor(model_file=args.tokenizer_path) + if int(sp.vocab_size()) != args.vocab_size: + raise ValueError( + f"VOCAB_SIZE={args.vocab_size} does not match tokenizer vocab_size={int(sp.vocab_size())}" + ) + dataset_dir = Path(args.data_path).resolve() + actual_train_files = len(list(dataset_dir.glob("fineweb_train_*.bin"))) + val_tokens = load_validation_tokens(args.val_files, args.eval_seq_len) + base_bytes_lut, has_leading_space_lut, is_boundary_token_lut = build_sentencepiece_luts( + sp, args.vocab_size, device + ) + log0(f"val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path={args.tokenizer_path}") + log0(f"train_loader:dataset:{dataset_dir.name} train_shards:{actual_train_files}") + log0(f"val_loader:shards pattern={args.val_files} tokens:{val_tokens.numel() - 1}") + + # ----------------------------- + # MODEL + OPTIMIZER SETUP + # ----------------------------- + + base_model = GPT( + vocab_size=args.vocab_size, + num_layers=args.num_layers, + model_dim=args.model_dim, + num_heads=args.num_heads, + num_kv_heads=args.num_kv_heads, + mlp_mult=args.mlp_mult, + tie_embeddings=args.tie_embeddings, + tied_embed_init_std=args.tied_embed_init_std, + logit_softcap=args.logit_softcap, + rope_base=args.rope_base, + qk_gain_init=args.qk_gain_init, + num_loops=args.num_loops, + lora_rank=args.lora_rank, + mlp_hidden=args.mlp_hidden, + smear_gate=args.smear_gate, + bigram_hash=args.bigram_hash, + bigram_hash_buckets=args.bigram_hash_buckets, + bigram_hash_dim=args.bigram_hash_dim, + ortho_init=args.ortho_init, + xsa_last_n=args.xsa_last_n, + ntk_base_seq_len=args.train_seq_len if args.eval_seq_len > args.train_seq_len else 0, + rope_dims=args.rope_dims, + ln_scale=args.ln_scale, + unet_skips=args.unet_skips, + ve_enabled=args.ve_enabled, + ve_dim=args.ve_dim, + ve_layers=args.ve_layers, + ).to(device).bfloat16() + if args._tier2: + log0(f"*** TIER2_MODE: proxy run max={args.max_wallclock_seconds:.0f}s iters={args.iterations} " + f"ema={args.ema_enabled} ttt={args.ttt_enabled} qat={args.qat} " + f"-- compare val_bpb@step2000 against baseline tier2 run ***") + xsa_layers = [i for i, b in enumerate(base_model.blocks) if b.attn.use_xsa] + log0(f"xsa_last_n:{args.xsa_last_n} active_layers:{xsa_layers}") + ntk_active = args.eval_seq_len > args.train_seq_len + log0(f"ntk_rope:{'enabled' if ntk_active else 'disabled'} train_seq_len:{args.train_seq_len} eval_seq_len:{args.eval_seq_len}") + head_dim = args.model_dim // args.num_heads + _rope_dims_active = args.rope_dims > 0 and args.rope_dims < head_dim + log0(f"partial_rope:{'enabled' if _rope_dims_active else 'disabled'} rope_dims:{args.rope_dims if _rope_dims_active else head_dim}/{head_dim} (ROPE_DIMS={args.rope_dims})") + log0(f"ln_scale:{'enabled' if args.ln_scale else 'disabled'} (scale RMSNorm output by 1/sqrt(layer_idx+1))") + log0(f"unet_skips:{'enabled' if args.unet_skips else 'disabled'} (U-Net skip connections, enc={base_model.num_encoder_layers} dec={base_model.num_decoder_layers})") + if args.ve_enabled: + log0(f"ve:enabled dim={args.ve_dim} layers={base_model.ve_layer_indices}") + else: + log0("ve:disabled") + + log0(f"smear_gate:{args.smear_gate} bigram_hash:{args.bigram_hash} swa:{args.swa} " + f"ortho_init:{args.ortho_init} late_k_fp16:{args.late_k_fp16} " + f"fa3:{_use_fa3}(available={_HAS_FA3}) muon_wd:{args.muon_wd} adam_wd:{args.adam_wd}") + + # FP16 tied embedding export: skip int8 quantization for tok_emb.weight at export time. + # Avoids compounding int8 errors through both input embedding and output projection. + if args.fp16_embed_export and args.tie_embeddings: + _FP16_EXPORT_NAMES.add("tok_emb.weight") + log0(f"fp16_embed_export:enabled (tok_emb.weight kept in fp16, ~{args.vocab_size * args.model_dim * 2 / 1024:.0f}KB)") + + # Late-K: keep K projections of last 2 layers in fp16 (not quantized). + # Saves per-query context accuracy where it matters most — near the output. + if args.late_k_fp16: + effective_depth = args.num_layers * args.num_loops + for layer_idx in range(effective_depth - 2, effective_depth): + block_idx = layer_idx % args.num_layers + key_name = f"blocks.{block_idx}.attn.c_k.weight" + _FP16_EXPORT_NAMES.add(key_name) + log0(f"late_k_fp16:enabled (last 2 effective layers' c_k.weight kept in fp16)") + + for module in base_model.modules(): + if isinstance(module, (CastedLinear, AttentionLoRA)): + module.float() + restore_low_dim_params_to_fp32(base_model) + log0(f"qat:{args.qat} (activates when lr_scale < 0.1; absmax int6 STE)") + compiled_model = torch.compile(base_model, dynamic=False, fullgraph=True) + model: nn.Module = DDP(compiled_model, device_ids=[local_rank], broadcast_buffers=False) if distributed else compiled_model + + # Optimizer split: + # - token embedding (Adam) uses EMBED_LR + # - untied lm_head (Adam) uses HEAD_LR + # - matrix params in transformer blocks use MATRIX_LR via Muon + # - vectors/scalars use SCALAR_LR via Adam + block_named_params = list(base_model.blocks.named_parameters()) + matrix_params = [ + p + for name, p in block_named_params + if p.ndim == 2 and not any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) + ] + # bigram_hash.proj is a dense 2D projection — Muon is appropriate + if base_model.bigram_hash is not None: + matrix_params.append(base_model.bigram_hash.proj.weight) + scalar_params = [ + p + for name, p in block_named_params + if p.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) + ] + if base_model.skip_weights is not None: + scalar_params.append(base_model.skip_weights) + # smear_gate.gate is now a single nn.Parameter(dim) — AdamW at scalar_lr + if base_model.smear_gate is not None: + scalar_params.append(base_model.smear_gate.gate) + # bigram_hash.scale is a learned scalar — AdamW at scalar_lr + if base_model.bigram_hash is not None: + scalar_params.append(base_model.bigram_hash.scale) + # VE: scales go to scalar, proj to matrix, embed to tok group + if base_model.ve_shared is not None: + scalar_params.append(base_model.ve_shared.scale) + if base_model.ve_shared.proj is not None: + matrix_params.append(base_model.ve_shared.proj.weight) + if base_model.ve_layer_scales is not None: + for vs in base_model.ve_layer_scales: + scalar_params.append(vs) + token_lr = args.tied_embed_lr if args.tie_embeddings else args.embed_lr + # bigram_hash.embed is an embedding table — train with AdamW alongside tok_emb + embed_params = [base_model.tok_emb.weight] + if base_model.bigram_hash is not None: + embed_params.append(base_model.bigram_hash.embed.weight) + if base_model.ve_shared is not None: + embed_params.append(base_model.ve_shared.embed.weight) + optimizer_tok = torch.optim.AdamW( + [{"params": embed_params, "lr": token_lr, "base_lr": token_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + weight_decay=args.adam_wd, + fused=True, + ) + optimizer_muon = Muon( + matrix_params, + lr=args.matrix_lr, + momentum=args.muon_momentum, + backend_steps=args.muon_backend_steps, + weight_decay=args.muon_wd, + ) + for group in optimizer_muon.param_groups: + group["base_lr"] = args.matrix_lr + optimizer_scalar = torch.optim.AdamW( + [{"params": scalar_params, "lr": args.scalar_lr, "base_lr": args.scalar_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + weight_decay=args.adam_wd, + fused=True, + ) + optimizers: list[torch.optim.Optimizer] = [optimizer_tok, optimizer_muon, optimizer_scalar] + if base_model.lora_adapters is not None: + lora_params = list(base_model.lora_adapters.parameters()) + optimizer_lora = torch.optim.Adam( + [{"params": lora_params, "lr": args.lora_lr, "base_lr": args.lora_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=True, + ) + optimizers.append(optimizer_lora) + if base_model.lm_head is not None: + optimizer_head = torch.optim.Adam( + [{"params": [base_model.lm_head.weight], "lr": args.head_lr, "base_lr": args.head_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=True, + ) + optimizers.insert(1, optimizer_head) + + n_params = sum(p.numel() for p in base_model.parameters()) + n_lora = sum(p.numel() for p in base_model.lora_adapters.parameters()) if base_model.lora_adapters is not None else 0 + effective_depth = args.num_layers * args.num_loops + log0(f"model_params:{n_params} (unique_layers:{args.num_layers} loops:{args.num_loops} effective_depth:{effective_depth} lora_rank:{args.lora_rank} lora_params:{n_lora})") + log0(f"world_size:{world_size} grad_accum_steps:{grad_accum_steps}") + log0("sdp_backends:cudnn=False flash=True mem_efficient=False math=False") + log0(f"attention_mode:gqa num_heads:{args.num_heads} num_kv_heads:{args.num_kv_heads}") + log0( + f"tie_embeddings:{args.tie_embeddings} embed_lr:{token_lr} " + f"head_lr:{args.head_lr if base_model.lm_head is not None else 0.0} " + f"matrix_lr:{args.matrix_lr} scalar_lr:{args.scalar_lr}" + ) + log0( + f"train_batch_tokens:{args.train_batch_tokens} train_seq_len:{args.train_seq_len} " + f"iterations:{args.iterations} warmup_steps:{args.warmup_steps} " + f"max_wallclock_seconds:{args.max_wallclock_seconds:.3f}" + ) + log0(f"seed:{args.seed}") + + # ----------------------------- + # DATA LOADER & MODEL WARMUP + # ----------------------------- + + # Curriculum learning: sort shards by average document length (easy first). + curriculum_files: list[Path] | None = None + if args.curriculum: + train_file_list = [Path(p) for p in sorted(glob.glob(args.train_files))] + curriculum_files = sort_shards_by_doc_length(train_file_list) + log0(f"curriculum:enabled shards_sorted_by_doc_length ({len(curriculum_files)} shards)") + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device, + sorted_files=curriculum_files) + + def zero_grad_all() -> None: + for opt in optimizers: + opt.zero_grad(set_to_none=True) + + max_wallclock_ms = 1000.0 * args.max_wallclock_seconds if args.max_wallclock_seconds > 0 else None + + def lr_mul(step: int, elapsed_ms: float) -> float: + if args.warmdown_iters <= 0: + return 1.0 + if max_wallclock_ms is None: + warmdown_start = max(args.iterations - args.warmdown_iters, 0) + return max((args.iterations - step) / max(args.warmdown_iters, 1), 0.0) if warmdown_start <= step < args.iterations else 1.0 + step_ms = elapsed_ms / max(step, 1) + warmdown_ms = args.warmdown_iters * step_ms + remaining_ms = max(max_wallclock_ms - elapsed_ms, 0.0) + return remaining_ms / max(warmdown_ms, 1e-9) if remaining_ms <= warmdown_ms else 1.0 + + # Warmup primes the compiled forward/backward/optimizer paths, then we restore the + # initial weights/optimizer state so measured training starts from the true init. + if args.warmup_steps > 0: + initial_model_state = {name: tensor.detach().cpu().clone() for name, tensor in base_model.state_dict().items()} + initial_optimizer_states = [copy.deepcopy(opt.state_dict()) for opt in optimizers] + model.train() + for warmup_step in range(args.warmup_steps): + zero_grad_all() + for micro_step in range(grad_accum_steps): + if distributed: + model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 + x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + warmup_loss = model(x, y) + (warmup_loss * grad_scale).backward() + for opt in optimizers: + opt.step() + zero_grad_all() + if args.warmup_steps <= 20 or (warmup_step + 1) % 10 == 0 or warmup_step + 1 == args.warmup_steps: + log0(f"warmup_step:{warmup_step + 1}/{args.warmup_steps}") + # NOTE: QAT graph priming removed — it caused torch.compile to use a slower + # compilation path for the non-QAT forward pass (step_avg jumped from 44ms to 58ms). + # The one-time recompile when QAT activates (~30-90s) is cheaper than the cumulative + # overhead of a slower non-QAT path across thousands of steps. + + base_model.load_state_dict(initial_model_state, strict=True) + for opt, state in zip(optimizers, initial_optimizer_states, strict=True): + opt.load_state_dict(state) + zero_grad_all() + if distributed: + model.require_backward_grad_sync = True + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device, + sorted_files=curriculum_files) + + # ----------------------------- + # MAIN TRAINING LOOP + # ----------------------------- + + training_time_ms = 0.0 + qat_active = False + + # EMA: exponential moving average — smoother than SWA, better quantization compression. + ema_state: dict[str, Tensor] | None = None + if args.ema_enabled: + ema_state = {k: v.detach().float().clone() for k, v in base_model.state_dict().items()} + log0(f"ema:initialized decay={args.ema_decay}") + + # SWA: fallback if EMA disabled + swa_state: dict[str, Tensor] | None = None + swa_count = 0 + stop_after_step: int | None = None + torch.cuda.synchronize() + t0 = time.perf_counter() + + step = 0 + while True: + last_step = step == args.iterations or (stop_after_step is not None and step >= stop_after_step) + + should_validate = last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0) + if should_validate: + torch.cuda.synchronize() + training_time_ms += 1000.0 * (time.perf_counter() - t0) + val_loss, val_bpb = eval_val( + args, + model, + rank, + world_size, + device, + grad_accum_steps, + val_tokens, + base_bytes_lut, + has_leading_space_lut, + is_boundary_token_lut, + ) + log0( + f"step:{step}/{args.iterations} val_loss:{val_loss:.4f} val_bpb:{val_bpb:.4f} " + f"train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms / max(step, 1):.2f}ms" + ) + torch.cuda.synchronize() + t0 = time.perf_counter() + + if last_step: + if stop_after_step is not None and step < args.iterations: + log0( + f"stopping_early: wallclock_cap train_time:{training_time_ms:.0f}ms " + f"step:{step}/{args.iterations}" + ) + break + + elapsed_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + scale = lr_mul(step, elapsed_ms) + zero_grad_all() + train_loss = torch.zeros((), device=device) + for micro_step in range(grad_accum_steps): + if distributed: + model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 + x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + loss = model(x, y) + train_loss += loss.detach() + (loss * grad_scale).backward() + train_loss /= grad_accum_steps + + frac = min(step / args.muon_momentum_warmup_steps, 1.0) if args.muon_momentum_warmup_steps > 0 else 1.0 + muon_momentum = (1 - frac) * args.muon_momentum_warmup_start + frac * args.muon_momentum + for group in optimizer_muon.param_groups: + group["momentum"] = muon_momentum + + for opt in optimizers: + for group in opt.param_groups: + group["lr"] = group["base_lr"] * scale + + if args.grad_clip_norm > 0: + torch.nn.utils.clip_grad_norm_(base_model.parameters(), args.grad_clip_norm) + for opt in optimizers: + opt.step() + zero_grad_all() + + step += 1 + + # EMA: update every step during training + if ema_state is not None: + d = args.ema_decay + with torch.no_grad(): + for k, v in base_model.state_dict().items(): + ema_state[k].mul_(d).add_(v.detach().float(), alpha=1.0 - d) + + # LR-scale-based QAT activation: activate when lr_scale < 0.1 (last ~10% of warmdown, + # ~300 steps). Zero overhead for 90%+ of training; absmax scale makes per-step cost minimal. + if args.qat and not qat_active: + if scale < 0.1: + qat_active = True + for module in base_model.modules(): + if isinstance(module, CastedLinear): + module._qat = True + elapsed_s = training_time_ms / 1000.0 + log0(f"qat_activated step:{step}/{args.iterations} lr_scale:{scale:.4f} elapsed:{elapsed_s:.1f}s remaining:{args.max_wallclock_seconds - elapsed_s:.1f}s") + + # SWA: accumulate weight averages during warmdown for smoother quantization. + # Accumulate in float32 to avoid bf16 precision loss over thousands of additions. + # Sample every 200 steps for sufficient checkpoint diversity. + if args.swa and step >= int(args.iterations * args.swa_start_frac) and step % 200 == 0: + if swa_state is None: + swa_state = {k: v.detach().float().clone() for k, v in base_model.state_dict().items()} + swa_count = 1 + log0(f"swa_started step:{step}") + else: + swa_count += 1 + for k, v in base_model.state_dict().items(): + swa_state[k] += v.detach().float() + + approx_training_time_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + should_log_train = ( + args.train_log_every > 0 + and (step <= 10 or step % args.train_log_every == 0 or stop_after_step is not None) + ) + if should_log_train: + log0( + f"step:{step}/{args.iterations} train_loss:{train_loss.item():.4f} " + f"train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms / step:.2f}ms" + ) + + # Needed to sync whether we've reached the wallclock cap. + reached_cap = max_wallclock_ms is not None and approx_training_time_ms >= max_wallclock_ms + if distributed and max_wallclock_ms is not None: + reached_cap_tensor = torch.tensor(int(reached_cap), device=device) + dist.all_reduce(reached_cap_tensor, op=dist.ReduceOp.MAX) + reached_cap = bool(reached_cap_tensor.item()) + if stop_after_step is None and reached_cap: + stop_after_step = step + + log0( + f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB " + f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB" + ) + + # ----------------------------- + # SERIALIZATION + ROUNDTRIP VALIDATION + # ----------------------------- + # Save the raw state (useful for debugging/loading in PyTorch directly), then always produce + # the compressed int8+zlib artifact and validate the round-tripped weights. + + # Apply EMA weights (preferred) or SWA fallback before serialization. + if ema_state is not None: + log0("ema:applying EMA weights") + current_state = base_model.state_dict() + base_model.load_state_dict( + {k: v.to(dtype=current_state[k].dtype) for k, v in ema_state.items()}, strict=True + ) + del ema_state + elif args.swa and swa_state is not None and swa_count > 1: + log0(f"swa_applied count:{swa_count}") + current_state = base_model.state_dict() + avg_state = {} + for k, v in swa_state.items(): + avg = v / swa_count + avg_state[k] = avg.to(dtype=current_state[k].dtype) + base_model.load_state_dict(avg_state, strict=True) + + # Reptile meta-TTT: makes subsequent TTT ~10x more effective by finding weights that adapt fast. + if args.reptile_enabled: + log0(f"reptile_ttt:start budget={args.reptile_budget_s:.0f}s inner_steps={args.reptile_inner_steps} inner_lr={args.reptile_inner_lr} outer_lr={args.reptile_outer_lr}") + reptile_ttt(args, base_model, device, val_tokens, rank=rank) + + # TTT: adapt to val distribution before eval + if args.ttt_enabled: + if args.ttt_two_phase: + log0(f"ttt:start two_phase p1_epochs={args.ttt_p1_epochs} p1_lr={args.ttt_p1_lr} " + f"p2_epochs={args.ttt_p2_epochs} p2_lr={args.ttt_p2_lr} p2_blocks={args.ttt_p2_unfreeze_blocks} " + f"batch_seqs={args.ttt_batch_seqs}") + else: + log0(f"ttt:start lr={args.ttt_lr} epochs={args.ttt_epochs} freeze_blocks={args.ttt_freeze_blocks} " + f"batch_seqs={args.ttt_batch_seqs}") + ttt_adapt(args, base_model, device, val_tokens, rank=rank, world_size=world_size) + + # Magnitude pruning: zero the smallest weights for better zstd compression. + # Zeroed weights compress to nearly nothing. Applied after TTT, before serialization. + if args.prune_pct > 0.0: + pruned_count = 0 + total_count = 0 + with torch.no_grad(): + for name, p in base_model.named_parameters(): + if p.ndim >= 2 and p.numel() >= 65536: + threshold = torch.quantile(p.abs().float().flatten(), args.prune_pct / 100.0) + mask = p.abs() > threshold + pruned_count += (~mask).sum().item() + total_count += p.numel() + p.mul_(mask) + log0(f"prune:{args.prune_pct:.1f}% zeroed {pruned_count}/{total_count} weights ({100*pruned_count/max(total_count,1):.1f}%)") + + if master_process: + torch.save(base_model.state_dict(), "final_model.pt") + model_bytes = os.path.getsize("final_model.pt") + code_bytes = len(code.encode("utf-8")) + log0(f"Serialized model: {model_bytes} bytes") + log0(f"Code size: {code_bytes} bytes") + log0(f"Total submission size: {model_bytes + code_bytes} bytes") + + global _gptq_lite + _gptq_lite = args.gptq_lite + log0(f"quantization: {args.quant_bits}-bit gptq_lite:{_gptq_lite}") + quant_obj, quant_stats = quantize_state_dict_int8(base_model.state_dict(), bits=args.quant_bits) + quant_buf = io.BytesIO() + torch.save(quant_obj, quant_buf) + quant_raw = quant_buf.getvalue() + if args.use_zstd and _HAS_ZSTD: + cctx = zstd.ZstdCompressor(level=22) + quant_blob = cctx.compress(quant_raw) + compress_method = "zstd-22" + else: + quant_blob = zlib.compress(quant_raw, level=9) + compress_method = "zlib-9" + quant_raw_bytes = len(quant_raw) + if master_process: + with open("final_model.int8.ptz", "wb") as f: + f.write(quant_blob) + quant_file_bytes = os.path.getsize("final_model.int8.ptz") + code_bytes = len(code.encode("utf-8")) + ratio = quant_stats["baseline_tensor_bytes"] / max(quant_stats["int8_payload_bytes"], 1) + log0( + f"Serialized model int8+{compress_method}: {quant_file_bytes} bytes " + f"(payload:{quant_stats['int8_payload_bytes']} raw_torch:{quant_raw_bytes} payload_ratio:{ratio:.2f}x)" + ) + log0(f"Total submission size int8+{compress_method}: {quant_file_bytes + code_bytes} bytes") + + if distributed: + dist.barrier() + with open("final_model.int8.ptz", "rb") as f: + quant_blob_disk = f.read() + # Decompress with the same method used for compression. + if args.use_zstd and _HAS_ZSTD: + dctx = zstd.ZstdDecompressor() + quant_raw_disk = dctx.decompress(quant_blob_disk) + else: + quant_raw_disk = zlib.decompress(quant_blob_disk) + quant_state = torch.load(io.BytesIO(quant_raw_disk), map_location="cpu") + base_model.load_state_dict(dequantize_state_dict_int8(quant_state), strict=True) + torch.cuda.synchronize() + t_qeval = time.perf_counter() + if args.eval_stride > 0 and args.eval_stride < args.eval_seq_len: + log0(f"final_eval_mode:sliding_window stride:{args.eval_stride} batch_seqs:{args.eval_batch_seqs} doc_isolated:{args.doc_isolated_eval}") + q_val_loss, q_val_bpb = eval_val_sliding( + args, + base_model, + rank, + world_size, + device, + val_tokens, + base_bytes_lut, + has_leading_space_lut, + is_boundary_token_lut, + stride=args.eval_stride, + batch_seqs=args.eval_batch_seqs, + ) + else: + log0("final_eval_mode:standard") + q_val_loss, q_val_bpb = eval_val( + args, + model, + rank, + world_size, + device, + grad_accum_steps, + val_tokens, + base_bytes_lut, + has_leading_space_lut, + is_boundary_token_lut, + ) + torch.cuda.synchronize() + log0( + f"final_int8_zlib_roundtrip val_loss:{q_val_loss:.4f} val_bpb:{q_val_bpb:.4f} " + f"eval_time:{1000.0 * (time.perf_counter() - t_qeval):.0f}ms" + ) + log0(f"final_int8_zlib_roundtrip_exact val_loss:{q_val_loss:.8f} val_bpb:{q_val_bpb:.8f}") + + if distributed: + dist.destroy_process_group() + + +if __name__ == "__main__": + main() + +==================================================================================================== +Running Python 3.12.3 (main, Nov 6 2025, 13:44:16) [GCC 13.3.0] +Running PyTorch 2.9.1+cu128 +Sun Mar 22 23:42:21 2026 ++-----------------------------------------------------------------------------------------+ +| NVIDIA-SMI 570.211.01 Driver Version: 570.211.01 CUDA Version: 12.8 | +|-----------------------------------------+------------------------+----------------------+ +| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC | +| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. | +| | | MIG M. | +|=========================================+========================+======================| +| 0 NVIDIA H100 80GB HBM3 On | 00000000:18:00.0 Off | 0 | +| N/A 24C P0 118W / 700W | 1519MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 1 NVIDIA H100 80GB HBM3 On | 00000000:2A:00.0 Off | 0 | +| N/A 25C P0 114W / 700W | 1519MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 2 NVIDIA H100 80GB HBM3 On | 00000000:3A:00.0 Off | 0 | +| N/A 27C P0 115W / 700W | 1519MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 3 NVIDIA H100 80GB HBM3 On | 00000000:5D:00.0 Off | 0 | +| N/A 23C P0 116W / 700W | 1519MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 4 NVIDIA H100 80GB HBM3 On | 00000000:84:00.0 Off | 0 | +| N/A 24C P0 114W / 700W | 1519MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 5 NVIDIA H100 80GB HBM3 On | 00000000:8B:00.0 Off | 0 | +| N/A 25C P0 116W / 700W | 1519MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 6 NVIDIA H100 80GB HBM3 On | 00000000:91:00.0 Off | 0 | +| N/A 23C P0 108W / 700W | 1519MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 7 NVIDIA H100 80GB HBM3 On | 00000000:E4:00.0 Off | 0 | +| N/A 22C P0 110W / 700W | 1519MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ + ++-----------------------------------------------------------------------------------------+ +| Processes: | +| GPU GI CI PID Type Process name GPU Memory | +| ID ID Usage | +|=========================================================================================| +| 0 N/A N/A 53848 C /usr/local/bin/python 1510MiB | +| 1 N/A N/A 53849 C /usr/local/bin/python 1510MiB | +| 2 N/A N/A 53850 C /usr/local/bin/python 1510MiB | +| 3 N/A N/A 53851 C /usr/local/bin/python 1510MiB | +| 4 N/A N/A 53852 C /usr/local/bin/python 1510MiB | +| 5 N/A N/A 53853 C /usr/local/bin/python 1510MiB | +| 6 N/A N/A 53854 C /usr/local/bin/python 1510MiB | +| 7 N/A N/A 53855 C /usr/local/bin/python 1510MiB | ++-----------------------------------------------------------------------------------------+ + +==================================================================================================== +val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path=./data/tokenizers/fineweb_1024_bpe.model +train_loader:dataset:fineweb10B_sp1024 train_shards:80 +val_loader:shards pattern=./data/datasets/fineweb10B_sp1024/fineweb_val_*.bin tokens:62021632 +xsa_last_n:0 active_layers:[] +ntk_rope:disabled train_seq_len:2048 eval_seq_len:2048 +partial_rope:enabled rope_dims:16/64 (ROPE_DIMS=16) +ln_scale:enabled (scale RMSNorm output by 1/sqrt(layer_idx+1)) +unet_skips:enabled (U-Net skip connections, enc=5 dec=6) +ve:disabled +smear_gate:True bigram_hash:True swa:False ortho_init:True late_k_fp16:False fa3:True(available=True) muon_wd:0.04 adam_wd:0.04 +qat:False (activates when lr_scale < 0.1; absmax int6 STE) +model_params:26829913 (unique_layers:11 loops:1 effective_depth:11 lora_rank:0 lora_params:0) +world_size:8 grad_accum_steps:1 +sdp_backends:cudnn=False flash=True mem_efficient=False math=False +attention_mode:gqa num_heads:8 num_kv_heads:4 +tie_embeddings:True embed_lr:0.035 head_lr:0.0 matrix_lr:0.025 scalar_lr:0.025 +train_batch_tokens:524288 train_seq_len:2048 iterations:20000 warmup_steps:20 max_wallclock_seconds:600.000 +seed:7 +warmup_step:1/20 +warmup_step:2/20 +warmup_step:3/20 +warmup_step:4/20 +warmup_step:5/20 +warmup_step:6/20 +warmup_step:7/20 +warmup_step:8/20 +warmup_step:9/20 +warmup_step:10/20 +warmup_step:11/20 +warmup_step:12/20 +warmup_step:13/20 +warmup_step:14/20 +warmup_step:15/20 +warmup_step:16/20 +warmup_step:17/20 +warmup_step:18/20 +warmup_step:19/20 +warmup_step:20/20 +ema:initialized decay=0.997 +step:0/20000 val_loss:6.9316 val_bpb:4.1053 train_time:0ms step_avg:0.02ms +step:1/20000 train_loss:6.9332 train_time:111ms step_avg:111.20ms +step:2/20000 train_loss:8.8137 train_time:206ms step_avg:102.89ms +step:3/20000 train_loss:7.9149 train_time:299ms step_avg:99.69ms +step:4/20000 train_loss:7.2103 train_time:390ms step_avg:97.61ms +step:5/20000 train_loss:6.8448 train_time:466ms step_avg:93.27ms +step:6/20000 train_loss:7.6745 train_time:546ms step_avg:91.04ms +step:7/20000 train_loss:6.6129 train_time:622ms step_avg:88.89ms +step:8/20000 train_loss:6.5272 train_time:704ms step_avg:87.95ms +step:9/20000 train_loss:6.3834 train_time:796ms step_avg:88.46ms +step:10/20000 train_loss:6.1265 train_time:888ms step_avg:88.77ms +step:200/20000 train_loss:2.8059 train_time:16971ms step_avg:84.86ms +step:400/20000 train_loss:2.2829 train_time:34025ms step_avg:85.06ms +step:600/20000 train_loss:2.4805 train_time:51074ms step_avg:85.12ms +step:800/20000 train_loss:2.2281 train_time:68477ms step_avg:85.60ms +step:1000/20000 train_loss:2.3261 train_time:85321ms step_avg:85.32ms +step:1000/20000 val_loss:2.2778 val_bpb:1.3491 train_time:85336ms step_avg:85.34ms +step:1200/20000 train_loss:2.3523 train_time:102800ms step_avg:85.67ms +step:1400/20000 train_loss:2.3749 train_time:119777ms step_avg:85.55ms +step:1600/20000 train_loss:2.0517 train_time:137141ms step_avg:85.71ms +step:1800/20000 train_loss:2.1466 train_time:154555ms step_avg:85.86ms +step:2000/20000 train_loss:2.1912 train_time:171745ms step_avg:85.87ms +step:2000/20000 val_loss:2.1744 val_bpb:1.2878 train_time:171760ms step_avg:85.88ms +step:2200/20000 train_loss:2.0102 train_time:189290ms step_avg:86.04ms +step:2400/20000 train_loss:2.1357 train_time:206378ms step_avg:85.99ms +step:2600/20000 train_loss:2.3751 train_time:224461ms step_avg:86.33ms +step:2800/20000 train_loss:2.1744 train_time:241405ms step_avg:86.22ms +step:3000/20000 train_loss:2.1648 train_time:257975ms step_avg:85.99ms +step:3000/20000 val_loss:2.1303 val_bpb:1.2617 train_time:257991ms step_avg:86.00ms +step:3200/20000 train_loss:2.1305 train_time:274715ms step_avg:85.85ms +step:3400/20000 train_loss:2.1023 train_time:291341ms step_avg:85.69ms +step:3600/20000 train_loss:2.0442 train_time:308244ms step_avg:85.62ms +step:3800/20000 train_loss:2.1444 train_time:325058ms step_avg:85.54ms +step:4000/20000 train_loss:2.1215 train_time:341744ms step_avg:85.44ms +step:4000/20000 val_loss:2.1114 val_bpb:1.2505 train_time:341760ms step_avg:85.44ms +step:4200/20000 train_loss:2.1101 train_time:361932ms step_avg:86.17ms +step:4400/20000 train_loss:2.0416 train_time:379301ms step_avg:86.20ms +step:4600/20000 train_loss:1.8983 train_time:396324ms step_avg:86.16ms +step:4800/20000 train_loss:2.1773 train_time:412710ms step_avg:85.98ms +step:5000/20000 train_loss:1.9271 train_time:430625ms step_avg:86.12ms +step:5000/20000 val_loss:2.0657 val_bpb:1.2234 train_time:430652ms step_avg:86.13ms +step:5200/20000 train_loss:2.0829 train_time:447201ms step_avg:86.00ms +step:5400/20000 train_loss:2.0880 train_time:464862ms step_avg:86.09ms +step:5600/20000 train_loss:2.0704 train_time:481485ms step_avg:85.98ms +step:5800/20000 train_loss:2.0222 train_time:498342ms step_avg:85.92ms +step:6000/20000 train_loss:2.0924 train_time:515560ms step_avg:85.93ms +step:6000/20000 val_loss:2.0154 val_bpb:1.1937 train_time:515579ms step_avg:85.93ms +step:6200/20000 train_loss:1.9535 train_time:533199ms step_avg:86.00ms +step:6400/20000 train_loss:2.0203 train_time:550125ms step_avg:85.96ms +step:6600/20000 train_loss:1.9604 train_time:567066ms step_avg:85.92ms +step:6800/20000 train_loss:2.0152 train_time:584022ms step_avg:85.89ms +step:6987/20000 val_loss:1.9552 val_bpb:1.1580 train_time:599941ms step_avg:85.87ms +stopping_early: wallclock_cap train_time:599941ms step:6987/20000 +peak memory allocated: 13401 MiB reserved: 13552 MiB +ema:applying EMA weights +ttt:start lr=0.0005 epochs=30 freeze_blocks=0 batch_seqs=64 +Serialized model: 105658303 bytes +Code size: 104414 bytes +Total submission size: 105762717 bytes +quantization: 6-bit gptq_lite:True +Serialized model int8+zstd-22: 15730216 bytes (payload:27056482 raw_torch:27113039 payload_ratio:3.90x) +Total submission size int8+zstd-22: 15834630 bytes +final_eval_mode:sliding_window stride:64 batch_seqs:32 doc_isolated:False +final_int8_zlib_roundtrip val_loss:1.8537 val_bpb:1.0979 eval_time:186906ms +final_int8_zlib_roundtrip_exact val_loss:1.85369369 val_bpb:1.09786182 diff --git a/records/track_10min_16mb/FINDINGS.md b/records/track_10min_16mb/FINDINGS.md new file mode 100644 index 000000000..dceaaf31c --- /dev/null +++ b/records/track_10min_16mb/FINDINGS.md @@ -0,0 +1,310 @@ +# Parameter Golf — Experimental Findings +**Updated: 2026-03-21** + +Everything we've tested, what we learned, and what remains untested. + +--- + +## Tested by us — with real data + +### 1. Int8 QAT (Quantization-Aware Training) +- **PR:** #145 (submitted) +- **Hypothesis:** Training through the int8 quantization grid via STE reduces the post-export quality gap. +- **Implementation:** `fake_quantize_int8_per_row` using `torch.quantile(w.abs(), INT8_CLIP_Q, dim=1)` matching export pipeline exactly. Activated at 30% of training. +- **Result:** NEGATIVE. `torch.quantile` adds ~20% per-step overhead (64ms → 77ms), costing ~2,000 training steps. Post-quant val_bpb: 1.2052 vs control's 1.1929. The lost training tokens hurt more than the quant gap recovery. +- **Lesson:** Int8 QAT with exact percentile matching is too expensive under a 10-minute wallclock cap. QAT only pays off with int6 (larger gap) or a faster approximate quantile (amax). +- **Hardware note:** Confirmed on both slow (65ms/step) and fast (44ms/step) RunPod H100 nodes. +- **Update:** See Finding #2b — Late QAT is also dead code under torch.compile, making the entire QAT approach moot regardless of overhead. + +### 2. QAT Graph Priming +- **Hypothesis:** Pre-compiling both QAT and non-QAT torch.compile graphs during warmup avoids a 30-90s mid-training recompile. +- **Result:** NEGATIVE. Graph priming caused torch.compile to use a slower compilation path for the NON-QAT forward pass. step_avg was 65ms from step 1 (vs 44ms baseline), even before QAT activated. +- **Lesson:** Don't pre-prime conditional code paths under `torch.compile(dynamic=False, fullgraph=True)`. Accept the one-time recompile cost instead. + +### 2b. Late QAT is dead code under torch.compile +- **Discovery:** `torch.compile(dynamic=False, fullgraph=True)` constant-folds `CastedLinear._qat` at first trace (when `_qat=False`). The STE branch in `forward()` is dead-code-eliminated. Setting `_qat=True` later triggers a recompile but does not restore the STE path. +- **Effect:** Late QAT activation does nothing except cause a recompile spike (~100 lost steps). No quantization-aware gradients are produced. +- **Evidence:** Explains why (a) QAT activation always caused a step_avg spike, (b) QAT never visibly improved roundtrip scores, (c) PR #332 found "Late QAT is counterproductive at 12L." +- **Credit:** @152334H identified the constant-folding mechanism; documented in PR #315. +- **Fix:** Set `QAT=0` to avoid the recompile cost entirely. A proper fix would require passing the QAT flag as a function argument rather than a class attribute, but at this point the simplest path is to disable it. + +### 3. Sliding Window Evaluation (stride=64) +- **Result:** POSITIVE (reproduced). val_bpb 1.1929 vs baseline 1.2244. Improvement: -0.032 BPB. +- **Note:** This is a reproduction of the SlidingWindowEval entry (#50), not our original work. Confirmed on fast hardware (44ms/step, 13,651 steps). + +### 4. Doc-Isolated Sliding Window Evaluation +- **Hypothesis:** Evaluating per-document (no cross-doc context bleed) improves BPB, based on LoRA TTT ablation showing +0.011 BPB. +- **Result:** INCONCLUSIVE. Tested only on slow hardware (70ms/step) stacked with other changes. val_bpb 1.1952 (10L) and 1.2045 (9L). Cannot isolate doc-isolation's effect from other variables. +- **Observation:** Produced only 37,402 windows vs 121,134 for flat-stream eval. Faster eval time (43s vs 73s). +- **Needs:** Clean A/B test on fast hardware — flat-stream vs doc-isolated, same model. + +### 5. FP16 Tied Embedding Export +- **Result:** POSITIVE (reproduced from WarmdownQuantization entry). Avoids int8 compounding through input+output paths. Costs ~512KB artifact space, offset by MLP_HIDDEN=992. +- **Note:** 10L + FP16 embed exceeds 16MB cap. Only works with 9L + MLP_HIDDEN=992 or with int6 quantization. + +### 6. Aggressive Warmdown (WD=20000) +- **Result:** Used in kitchen sink runs but never isolated. WarmdownQuantization entry reports quant gap drops from 0.014 to 0.005 BPB. +- **Needs:** Clean A/B test — WD=1200 vs WD=20000, same model, same eval. + +### 7. 10 Layers (vs 9) +- **Result:** 10L model trains but artifact exceeds 16MB at int8 (17.6MB). Needs int6 or weight sharing to fit. +- **Observation:** 10L at 76ms/step (10L+seq2048) vs 44ms (9L+seq1024) — the extra layer + longer sequence is significantly slower per step. + +### 8. SWA (Stochastic Weight Averaging) +- **Hypothesis:** Weight averaging during warmdown finds flatter minima that quantize better. +- **Result (v1):** SEVERE BUG. Accumulated in bf16 for 3,596 steps → precision overflow → val_bpb 2.62. +- **Bug fix:** Accumulate in float32, sample every 50 steps, cast back to model dtype. +- **Result (v2, from ablation table):** +0.0004 BPB vs control — effectively no effect at WD=1200. +- **Lesson:** SWA gains reported by other entries likely require very long warmdown (WD=20000+) to accumulate enough diverse snapshots. Superseded by EMA which is smoother and always-on. + +### 9. Seq2048 Training +- **Result:** Used in kitchen sink but never isolated. Slower per step (76ms vs 44ms for seq1024) but each step processes 2x longer sequences. +- **Needs:** Clean A/B test — seq1024 vs seq2048, same total training time. + +### 10. Tuned LRs (MATRIX_LR=0.06, etc.) +- **Result:** Used in kitchen sink but never isolated. From WarmdownQuantization entry, optimized for WD=20000. +- **Note:** Different optimal LRs for different warmdown schedules. 10L_MixedPrecision entry uses 0.02 (lower), WarmdownQuantization uses 0.06 (higher). + +### 11. Curriculum Learning (sort shards by document length) +- **Hypothesis:** Feeding easier data (shorter documents) first accelerates early convergence, producing a better model in the same wallclock budget. +- **Implementation:** Sort 80 training shards by average document length (estimated from first 100K tokens per shard via BOS token counting). Shorter docs first. +- **Result:** MARGINAL NEGATIVE. val_bpb 1.1942 vs control 1.1929 (+0.0013 BPB worse). +- **Observation:** Early training loss WAS lower with curriculum (step 1000: 2.194 vs 2.370), confirming easier data helps early convergence. But final model was marginally worse — suggesting the model overfitted to easy patterns rather than learning generalizable features. +- **Lesson:** Simple shard-order curriculum doesn't help at this scale. More sophisticated curriculum (within-batch difficulty mixing, anti-curriculum) might work but adds complexity. + +### 12. Int6 + 3x MLP (our implementation) +- **Result:** STRONG POSITIVE. val_bpb **1.1708** vs control 1.1929 (-0.0221 BPP). +- **Implementation:** `QUANT_BITS=6` in quantize_float_tensor (max_val=31 instead of 127), zstd-22 compression, MLP_HIDDEN=1536, FP16_EMBED_EXPORT=1. +- **Details:** 21.8M params, 15.2MB artifact (824KB headroom), 48ms/step, 12,507 steps. +- **Lesson:** Int6 + wider MLP is the single biggest lever. The freed artifact budget enables 27% more parameters in 4.5% less space. + +### 13. Depth Recurrence + Huginn Eval-Time Scaling (Depth Sharing) +- **Hypothesis:** 3 shared blocks × 3 loops = 9 effective layers with 1/3 unique params. At eval, increase to 6 loops for free depth. +- **Result v1 (with U-Net skips):** NOT VIABLE. Pre-quant val_bpb 1.2934, post-quant 6-loop eval: val_bpb **4.34** (near-random noise). +- **Result v2 (flat loops, no skips):** WORSE. val_bpb **5.58** — flat loops amplify errors rather than refine. +- **Root cause:** Blocks learn a position-specific function for their depth in the 3-loop stack, not a general iterative refinement operator. Extra loops compound distribution mismatch, not refinement. +- **Scale argument:** Huginn validated at 3.5B params (~100M+ per unique layer). Our unique layers are ~2.5M params — insufficient capacity to simultaneously be a good LM and a general refiner. +- **Artifact:** 5.6MB — proves 10MB+ headroom exists with weight sharing + int6. +- **Lesson:** Do not retry Huginn-style eval scaling at this scale without fundamentally different approach. The technique does not transfer below ~100M params per unique layer. + +### 14. Multi-Token Prediction +- **Hypothesis:** Predicting t+2 alongside t+1 gives richer gradients per step. +- **Implementation:** Auxiliary Linear(dim, dim) head predicts t+2, weighted 0.5× in loss. Excluded from artifact. +- **Result:** MARGINAL NEGATIVE. val_bpb 1.1947 vs control 1.1929 (+0.0018). 3% overhead (45.3ms vs 43.9ms), 410 fewer steps. +- **Lesson:** Auxiliary multi-token prediction doesn't improve the primary task at this scale. The gradient signal from predicting t+2 doesn't transfer to t+1 quality. + +### 15. Int5 Quantization (QUANT_BITS=5, MLP_HIDDEN=1920) +- **Hypothesis:** 5-bit quantization frees even more artifact budget than int6, enabling MLP_HIDDEN=1920. +- **Result:** NOT VIABLE. Pre-quant val_bpb 1.1885 (better than int6!), post-quant val_bpb **1.5458**. Quantization gap: +0.357 BPB — 15× larger than int6's +0.024. +- **Root cause:** Int5 has only 31 representable levels (vs int6's 63). Per-row quantization error at 31 levels is large enough to destroy language modeling entirely. +- **Note:** Pre-quant result was actually better (more params fit in budget), but quantization erases all gains. +- **Lesson:** Int5 is not viable with post-training quantization. Would require int5-aware QAT (simulating 5-bit during training), which is not worth building given int8 QAT was already a negative finding. + +### 16. Optimizer Coverage Bug (SmearGate + BigramHash frozen) +- **Discovery:** SmearGate and BigramHashEmbedding were not in any optimizer parameter group in all runs prior to 2026-03-21. +- **Effect:** Both modules trained frozen from initialization: + - `SmearGate.gate` (initialized to zeros) → sigmoid(0) = 0.5 fixed for all 512 channels + - `BigramHash.proj.weight` (initialized to zeros) → permanent zero projection, hash embeddings contributed nothing +- **Implication:** Every result using `SMEAR_GATE=1 BIGRAM_HASH=1` (the default since their introduction) had these features silently disabled. The ablation result "SmearGate + BigramHash hurt with int6 (+0.003 BPB)" from the Int6_3xMLP README was measured with BOTH FEATURES BROKEN. +- **Fix:** Added bigram_hash.proj.weight to Muon matrix_params, smear_gate.gate to AdamW scalar_params, bigram_hash.embed.weight to tok_emb AdamW group. +- **Same bug independently found by other participants at roughly the same time.** +- **Lesson:** Always verify that every nn.Module submodule appears in at least one optimizer parameter group. A module that initializes to zero or near-zero and is never updated will appear to "work" (no crash) while contributing nothing. + +### 17. 11 Layers — step count trap +- **Hypothesis:** 11 layers (vs 9) gives more model capacity; int6 provides budget headroom. +- **Result:** NEGATIVE. val_bpb 1.1907 vs 9L's 1.1708. Regression of +0.020 BPB. +- **Root cause:** 11 layers runs at ~83ms/step vs 9L's 48ms/step. In 600s: ~7,200 steps vs ~12,500 steps. The ~40% step count loss outweighs the depth gain. +- **Additional factor:** Hyperparameters were suboptimal (muon_momentum=0.95, rope_base=10000, grad_clip disabled) — these were fixed in the XSA+EMA+TTT run. +- **Lesson:** Adding layers only helps if step time increase is less than capacity gain. At this scale, each extra layer costs ~7ms/step. Going from 9→11 layers costs ~35ms/step total, too much for 600s budget. Would need to pair depth increase with NTK-RoPE seq_len=1024 to recover steps. + +### 18. Flash Attention 2 (FA2) +- **Result:** POSITIVE for step time at long sequences. No measurable val_bpb difference vs SDPA. +- **Details:** FA2 saves ~5-8ms/step at seq_len=2048 vs `F.scaled_dot_product_attention`. At seq_len=1024 the benefit is smaller. +- **Install note (RunPod):** Requires `pip install flash-attn --no-cache-dir --no-build-isolation`. FA3 (`flash_attn_interface`) is NOT available — cross-device link error on RunPod filesystem with torch2.9+cu128. +- **Lesson:** Worth including for the step time savings. No quality effect — it's a mathematically equivalent kernel. + +### 19. XSA — Exclusive Self Attention (arXiv:2603.09078) +- **Result:** POSITIVE (part of combined 1.1401 run, not isolated). Zero parameters, minimal compute overhead with GQA-aware implementation. +- **Mechanism:** Subtract self-value projection from attention output: `y = y - (y·v̂)v̂`. Forces each head to draw from other tokens rather than looping back to its own value. +- **Applied to last 4 layers only** — early layers benefit from self-loops for feature extraction. + +### 20. EMA (Exponential Moving Average, decay=0.997) +- **Result:** POSITIVE (part of combined 1.1401 run, not isolated). Replaces SWA for smoother weight averaging. +- **Mechanism:** Shadow copy of weights updated every step: `ema = 0.997*ema + 0.003*weights`. Loaded before quantization, replacing last-step weights. +- **vs SWA:** EMA is always-on, smoother, no snapshot schedule needed. Float32 accumulation avoids the bf16 precision bug that caused SWA v1 to fail. + +### 21. TTT — Test-Time Training +- **Result:** POSITIVE (part of combined 1.1401 run, not isolated). +- **Mechanism:** 3-epoch SGD (lr=0.002, momentum=0.9) on validation data after EMA applied, before final eval. First 2 blocks frozen for stability. +- **Runtime:** ~40–60s. + +### 21b. TTT Optimizer and Schedule Comparison (34 configs, 4 rounds) +- **Finding:** AdamW with cosine lr decay and per-layer lr groups improved TTT effectiveness by 23.5% over flat-lr AdamW in our testing across 34 configurations. +- **Round 1 (8 configs):** AdamW at flat lr=5e-4 confirmed as best uniform optimizer (consistent with PR #442). SGD is 24-41% less effective regardless of lr or momentum. Adam ≈ AdamW (weight decay irrelevant at wd=0; adaptive lr is the mechanism). Freezing early blocks reduces improvement by 3.6%. +- **Round 2 (8 configs):** Lower lr with more epochs (2e-4, 20ep) gives 5.5% more improvement than higher lr fewer epochs (5e-4, 10ep). Cosine decay helps at higher starting lr (+4.1% at 1e-3 start). Per-layer lr shows positive signal (+1.5%): 3× for MLP output projections, 0.5× for input projections, matching their 3.4× quantization damage ratio. +- **Round 3 (10 configs):** Cosine decay at 30 epochs gives +16% over 10-epoch flat. Warmup (10%) + cosine at higher lr gives +15.1% in 2/3 the time. Per-layer + cosine combined: +11.6%. +- **Round 4 (8 configs):** Per-layer + cosine + 30 epochs: +23.5% combined (more than the sum of individual gains). Focal loss (γ=1,2,3) did not improve over cross-entropy — hard tokens appear to be unpredictable rather than undertrained. KL divergence from pre-quant teacher did not improve over cross-entropy — the signal is too weak relative to CE on validation tokens. +- **Validated on 8×H100:** 30 epochs at ~15.5s/epoch with DDP gradient sharding. val_bpb 1.0959 (seed=1337). +- **Mechanism:** Cosine schedule applies full lr early to address large-scale quantization damage, then progressively reduces to refine without overshooting. Per-layer lr allocates more adaptation to parameters with higher quantization damage. + +### 22. NTK-aware RoPE (train_seq_len=1024, eval_seq_len=2048) +- **Result:** INCONCLUSIVE. Used in the 1.1401 run but superseded by training at seq_len=2048 directly (no NTK scaling needed). The NTK mechanism works correctly but training at native seq_len=2048 is simpler and avoids extrapolation risk. +- **Mechanism:** Scale RoPE base at eval: `ntk_base = rope_base * (eval_seq_len / train_seq_len) ** (head_dim / (head_dim - 2))`. +- **Note:** Top leaderboard entries train at seq_len=2048 with rope_base=10000 rather than using NTK scaling. + +### 23. Magnitude Pruning + zstd Interaction +- **Hypothesis:** Zeroing the smallest N% of weights before zstd compression improves compression ratio. +- **Result:** NON-MONOTONIC. 1% pruning: neutral. 3% pruning: artifact 728KB **larger**. 5% pruning: neutral. Tested on real checkpoint with re-export (no retraining). +- **Root cause:** zstd-22 compression interacts non-trivially with the distribution of zero values. At 3%, the pattern of zeroed weights creates byte sequences that zstd handles less efficiently than the original near-zero values. At 5%, enough contiguous zeros appear for run-length encoding to dominate. +- **Lesson:** Always measure full compressed artifact size, not just payload size. Pruning percentage should be validated empirically for each model, not assumed monotonic. + +### 24. Learned Codebook Quantization (K-means, K=256) +- **Hypothesis:** K-means codebook places 256 levels at actual weight cluster centers rather than a uniform grid. Should achieve lower reconstruction error at same storage cost. +- **Result:** MIXED. Reconstruction MSE 87% lower than int6 uniform. But compressed artifact 25% **larger** (585KB vs 466KB per MLP tensor under zstd-22). +- **Root cause:** Codebook indices (uint8, 256 possible values) have higher byte entropy than int6 values (int8 clamped to [-31,31], 63 effective values). zstd-22 compresses the low-entropy int6 stream more efficiently. +- **Lesson:** Reconstruction accuracy and compressed artifact size measure different things. The downstream codec must be considered jointly with the quantization scheme. + +### 25. Embedding Low-Rank Structure (SVD Analysis) +- **Hypothesis:** If the embedding weight matrix is low-rank, a compact generator (codes + projection) could replace full storage. +- **Result:** NOT LOW-RANK. Rank-64 explains 41.9% of variance. Rank-256 explains 83.3%. With 1024 BPE tokens × 512 dimensions, each token embedding is distinct. +- **Lesson:** Low-rank factorization of the tied embedding is not viable at vocab_size=1024. The vocabulary is small enough that each token carries unique information. + +### 26. Symmetry-Transport Compression (Procrustes Alignment) +- **Hypothesis:** Weight matrices across layers may be rotational variants of each other. Store one prototype + per-layer rotation matrix instead of N independent matrices. +- **Analysis:** Procrustes alignment shows 91-93% MSE reduction for MLP output projections across all 11 layer pairs. Cross-seed analysis confirms this is an architecture-level property (90% reduction between same layers of different seeds). +- **Result:** NOT VIABLE. Full rotation matrices (512×1536 fp16 = 1.5MB each, 10 rotations = 15MB) exceed the artifact budget. Low-rank approximation of the rotation delta (rank-128) captures only 16.6% of variance. +- **Lesson:** Reconstruction MSE reduction does not guarantee compressed size reduction. The rotation matrices are dense and full-rank — they do not compress efficiently under zstd-22. + +### 27. Progressive Layer Dropping +- **Hypothesis:** Randomly skipping layers during training (with probability proportional to depth) forces per-layer independence, acting as regularization. +- **Result:** NOT EFFECTIVE in our testing. Combined with head dropout, caused 0.06 BPB regression at step 1000 (val_bpb 1.4067 vs baseline 1.3394). The DDP implementation also introduced higher-order ops incompatible with torch.compile + DDPOptimizer. +- **Note:** Layer dropping and head dropout were tested together; the individual contribution of each was not isolated. + +### 28. Focal Loss for TTT +- **Hypothesis:** Down-weighting easy tokens during TTT (focal loss, γ=1,2,3) focuses gradient on hard tokens where adaptation is most needed. +- **Result:** NOT EFFECTIVE. γ=1 was neutral (-0.4%), γ=2 and γ=3 were worse (-1.5%, -3.1%) compared to standard cross-entropy. Consistent with the observation that hard tokens are unpredictable rather than undertrained. + +### 29. KL Divergence TTT from Pre-Quant Model +- **Hypothesis:** Using KL divergence from the pre-quantization model as the TTT loss preserves the model's learned distribution better than cross-entropy on raw tokens. +- **Result:** NOT EFFECTIVE. Pure KL barely moved the model (delta -0.066 vs -1.717 for CE). Even a 50/50 blend of KL and CE was 28% less effective than CE alone. The pre-quant and post-quant models are similar enough that the KL signal is weak relative to cross-entropy on the validation data. + +--- + +## Tested by others — from leaderboard and PRs + +### Int6 Quantization + zstd (community) +- **Source:** PRs #114, #128, #162, #164, #173, #179, #180 +- **Effect:** Frees ~3MB of artifact budget. Enables 3x MLP expansion (hidden=1536) which alone gives ~0.02 BPB. +- **Status:** The dominant meta. Every top-5 pending PR uses int6. +- **Our status:** ✅ Implemented — PR #212 (val_bpb=1.1708). + +### SmearGate (per-channel) +- **Source:** modded-nanogpt community, PRs #162, #164 +- **Effect:** Learned sigmoid gate blending each token with previous token's embedding. Original: scalar gate. Our version: 512 per-channel gates (more expressive). +- **Our status:** ✅ Implemented and optimizer-fixed in 11L_XSA_EMA_TTT. Prior runs had this frozen (see Finding #16). + +### BigramHash +- **Source:** modded-nanogpt community, PRs #162, #164 +- **Effect:** Hash-based bigram embedding table (2048 buckets → 128d → 512d projection). ~524K params. +- **Our status:** ✅ Implemented and optimizer-fixed in 11L_XSA_EMA_TTT. Prior runs had proj zeroed out (see Finding #16). +- **Note:** Buckets=2048 (our default); some entries use 4096. May be worth trying. + +### NorMuon Optimizer +- **Source:** arxiv 2510.05491, PR #173 +- **Effect:** Muon variant with per-neuron normalization. Drop-in replacement. "Modest but repeatable" gains. +- **Our status:** Not implemented. Low priority given other higher-impact changes in flight. + +### OrthoInit +- **Source:** modded-nanogpt community +- **Effect:** Orthogonal init on all large matrices. Faster early convergence, better conditioning. +- **Our status:** ✅ Implemented (`ORTHO_INIT=1`, default on). + +### Depth Recurrence / Layer Sharing +- **Source:** PR #167, Subformer (arxiv 2101.00234), Huginn (arxiv 2502.05171) +- **Effect:** 3 shared blocks → 9 effective layers = 1/3 params. Frees ~10MB artifact budget. +- **Our status:** ❌ Tested — did not produce usable results (see Finding #13). Do not retry without fundamentally different approach. + +### LoRA Test-Time Training (TTT) +- **Source:** PR #77 (@samacquaviva) +- **Effect:** Per-document LoRA adaptation during eval. +0.003 BPB from TTT itself, +0.011 from doc isolation, +0.034 from striding. +- **Our status:** ✅ Implemented as full-weight SGD TTT (not LoRA). Part of combined 1.1401 run. + +### Paid Prefix +- **Source:** PR #168 +- **Effect:** Store 12.9M val tokens verbatim in artifact. val_bpb 1.0238. Rules exploit. +- **Our status:** Not pursuing. Non-standard — will not be merged. + +### Int5 Quantization (community) +- **Source:** PR #180 +- **Effect:** 5-bit quantization enabling larger models. +- **Our status:** ❌ Tested — not viable (see Finding #15). Post-quant gap 15× worse than int6. Not viable. + +### Hyperparameter tuning (community findings) +- **Source:** Top leaderboard entries +- **Key values:** muon_momentum=0.99, warmup 0.92→0.99 over 1500 steps, rope_base=10000, grad_clip=0.3, tied_embed_lr=0.035 +- **Our status:** ✅ All applied as defaults in 11L_XSA_EMA_TTT. + +### Reptile Meta-Learning for TTT (community finding) +- **Source:** Non-record research submission in the competition queue +- **Finding 1 (POSITIVE):** Reptile meta-learning improves SmearGate-enabled TTT by **0.011 BPB** — 10× better than naive TTT (+0.001). Standard TTT barely adapts because SmearGate already captures local bigram context; Reptile's outer-loop objective forces the model to learn representations that are fast to adapt. +- **Finding 2 (NEGATIVE):** Error-guided TTT — concentrating adaptation steps on the highest-loss tokens — does not improve val_loss. Hard tokens are **genuinely unpredictable**, not undertrained. Per-token loss analysis: hardest 2.7% of tokens account for ~15% of total loss, and those tokens resist TTT regardless of method. +- **Finding 3:** 13 layers outperforms 10 layers on 8×H100 (val_bpb 1.1884 vs 1.2090) despite 23% fewer training steps. Depth gain outweighs step-count loss at 13L — this is the crossover point we missed when going 9→11L. +- **Our status:** Not implementing Reptile. Plain SGD TTT is already implemented; if it underperforms, Reptile is the next option. + +### SWA Checkpoint Count Matters (community finding) +- **Source:** Non-record quantization findings in the competition queue +- **Finding:** With 84 checkpoint average (SWA every step in warmdown, WD=20000), int6+zstd roundtrip BPB is **lower** than pre-quant BPB (1.5164 vs 1.5536, gap = -0.037). SWA with enough checkpoints eliminates quantization-sensitive weight outliers entirely — quantization actually *improves* the score. +- **Why our SWA showed no effect:** We tested SWA at WD=1200 producing only a handful of checkpoints. The smoothing is insufficient to remove outliers. The effect requires ~50+ checkpoint average across a long warmdown. +- **Implication:** SWA is not a weak technique — it was undertested. At WD=20000 with frequent sampling it may outperform EMA. Not worth switching now, but explains the discrepancy between our result (+0.0004) and entries that report SWA gains. +- **Our status:** Superseded by EMA for now. If EMA underperforms, revisit SWA with WD=20000. + +--- + +## Not yet tested by us + +### BitNet b1.58 (Ternary Weights) +- **Source:** arxiv 2402.17764, 2407.09527 +- **Potential:** 1.58 bits/param → 60M+ params in 16MB. 3x more params than int8. +- **Risk:** Needs ~2x params to match FP16 quality. Training stability at 20M scale unproven. +- **Competition status:** PR #126 attempted but didn't converge (1.7510 BPB). PR #139 got 1.2029 with 65M params — works but undertrained. + +### Learned Compression Codebooks +- **Potential:** Train a small codebook that compresses better than int6+zstd. +- **Status:** Nobody has tried this. + +### BigramHash 4096 buckets (vs our 2048) +- **Potential:** More hash buckets = less collision = cleaner bigram signal. ~1MB extra params. +- **Status:** Low-risk tweak, untested. + +### TTT with more epochs (5 vs 3) +- **Potential:** More adaptation to val distribution. Diminishing returns beyond ~5 epochs. +- **Status:** Easy knob to turn if 3-epoch result is competitive. + +### Disable QAT during training +- **Potential:** QAT adds ~20% step overhead (Finding #1). Disabling it frees ~1,400 extra steps in 600s. +- **Risk:** Larger quantization gap at export. May or may not be worth it depending on int6 gap magnitude. +- **Status:** Worth testing as a quick ablation once base result is known. + +--- + +## Key meta-lessons + +1. **Hardware variance matters more than most techniques.** 44ms vs 70ms/step is a 60% difference — that's 13,600 vs 8,500 steps, worth ~0.015 BPB. + +2. **The competition rewards composition, not innovation.** The top entries stack 5-8 known techniques. Clean ablations are rare and valued. + +3. **Int6 + 3x MLP is the dominant meta.** Everything else is marginal by comparison. The ~0.02 BPB from wider MLP is larger than most other individual techniques. + +4. **torch.compile is fragile.** Conditional code paths, graph priming, and mutable module attributes all cause subtle performance regressions. + +5. **bf16 accumulation is dangerous.** Any running sum over thousands of steps must use float32. + +6. **The 16MB artifact cap is the binding constraint.** Every architectural decision is downstream of "how do I fit more effective parameters in 16MB?" + +7. **Always verify optimizer parameter coverage.** Every nn.Module must appear in at least one optimizer group. Modules initializing to zero and never updating produce no error and no training signal. + +8. **Adding layers has a step-count cost.** Each extra layer adds ~7ms/step. Check that capacity gain exceeds the step loss for a 600s budget before going deeper. + +9. **SWA only works with many checkpoints over a long warmdown.** A handful of checkpoints at WD=1200 shows no effect. 84 checkpoints at WD=20000 reverses the quantization gap entirely. The technique was undertested, not ineffective. + +11. **Hard tokens resist TTT regardless of method.** The hardest 2.7% of tokens (by loss) account for ~15% of total loss and are genuinely unpredictable — not a training artifact. Don't design TTT strategies targeting these tokens.