diff --git a/.gitignore b/.gitignore
index 3423c416a..eb85368f2 100644
--- a/.gitignore
+++ b/.gitignore
@@ -8,4 +8,6 @@ data/manifest.json
 data/docs_selected.jsonl
 .mypy_cache/
 .venv
-logs/
\ No newline at end of file
+logs/autoresearch-results.tsv
+verify.sh
+logs/
diff --git a/records/track_10min_16mb/2026-03-21_OptimizerTuning_SlideEval64/README.md b/records/track_10min_16mb/2026-03-21_OptimizerTuning_SlideEval64/README.md
new file mode 100644
index 000000000..9180cca44
--- /dev/null
+++ b/records/track_10min_16mb/2026-03-21_OptimizerTuning_SlideEval64/README.md
@@ -0,0 +1,96 @@
+This record combines optimizer tuning, training at a longer sequence length, and sliding window evaluation to improve on the naive baseline without changing the model architecture.
+
+## Key Changes from Baseline
+
+### Training Improvements
+- **Sequence length 2048** (baseline: 1024): Longer context during training improves the model's ability to use positional information. Steps are ~18% slower, but the quality gain is worth it.
+- **Warmdown 10000** (baseline: 1200): Much longer learning rate decay schedule. With the wallclock-based warmdown, the LR decays throughout most of training, producing smoother convergence.
+- **Muon backend steps 10** (baseline: 5): More Newton-Schulz iterations in the Muon optimizer produce better gradient orthogonalization.
+- **Gradient clipping norm=1.0** (baseline: disabled): Stabilizes training, which is especially important with the longer warmdown.
+- **Adam beta2=0.99** (baseline: 0.95): Smoother second-moment estimate for embedding and scalar parameters.
+- **Scalar LR=0.02** (baseline: 0.04): Lower learning rate for scale/gate parameters (attn_scale, mlp_scale, resid_mix, skip_weights) improves stability.
+
+### Evaluation Improvement
+- **Sliding window eval (stride=64)**: Instead of chopping the validation set into non-overlapping 2048-token chunks (where the first token has zero context), we use overlapping windows advancing by 64 tokens. Only the last 64 tokens of each window are scored, giving every token 1984+ tokens of context; the first window scores all of its tokens. This is a pure eval improvement: the model weights are identical. See the sketch below, after the next section.
+
+### What Didn't Work (Tried and Reverted)
+- SwiGLU MLP: Better per-parameter quality, but the 3-matrix design uses more parameters per layer, blowing the 16MB budget at convergence.
+- FP16 embedding passthrough: Reduces quantization error from ~0.007 to ~0.0003 BPB, but adds ~500KB to the artifact, pushing over 16MB.
+- More layers (10-12): Better BPB, but always exceeded the 16MB artifact limit at full convergence. The int8+zlib compression ratio is ~0.93 bytes/param at 8xH100 convergence.
+- Higher/lower learning rates for matrix_lr, tied_embed_lr: The defaults (0.04, 0.05) are well-tuned.
+- Depth recurrence, lower RoPE base, different KV head counts: All worse.
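+A minimal sketch of the window/scoring logic (illustrative only; it mirrors `eval_val_sliding` in `train_gpt.py`, with `seq_len` and `stride` standing in for `TRAIN_SEQ_LEN` and `EVAL_STRIDE`):
+
+```python
+def scored_spans(total_tokens: int, seq_len: int = 2048, stride: int = 64):
+    # Yield (window_start, first_scored, window_end) per window. Each window
+    # covers [ws, ws + seq_len); only its last `stride` positions are scored,
+    # except window 0, which scores everything.
+    for ws in range(0, total_tokens, stride):
+        end = min(ws + seq_len, total_tokens)
+        if end - ws < stride:  # too short to score a full stride
+            continue
+        yield ws, (0 if ws == 0 else end - stride), end
+
+# Window 0 scores positions [0, 2048); window k >= 1 scores the 64 positions
+# [2048 + (k-1)*64, 2048 + k*64), each with at least 1984 tokens of context.
+```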
+
+## Configuration
+
+Same architecture as baseline:
+- Layout: `VOCAB_SIZE=1024 NUM_LAYERS=9 MODEL_DIM=512 NUM_HEADS=8 NUM_KV_HEADS=4 MLP_MULT=2`
+- Tied output/input embeddings: `TIE_EMBEDDINGS=1`
+- ReLU^2 MLP (unchanged)
+
+Modified hyperparameters:
+- `TRAIN_SEQ_LEN=2048` (was 1024)
+- `WARMDOWN_ITERS=10000` (was 1200)
+- `MUON_BACKEND_STEPS=10` (was 5)
+- `GRAD_CLIP_NORM=1.0` (was 0.0)
+- `BETA2=0.99` (was 0.95)
+- `SCALAR_LR=0.02` (was 0.04)
+- `EVAL_STRIDE=64` (sliding window evaluation)
+
+## Command
+
+```bash
+RUN_ID=submission_seed1337 \
+DATA_PATH=./data/datasets/fineweb10B_sp1024/ \
+TOKENIZER_PATH=./data/tokenizers/fineweb_1024_bpe.model \
+VOCAB_SIZE=1024 \
+MAX_WALLCLOCK_SECONDS=600 \
+TRAIN_LOG_EVERY=200 \
+VAL_LOSS_EVERY=2000 \
+EVAL_BATCH_SEQS=1024 \
+torchrun --standalone --nproc_per_node=8 train_gpt.py
+```
+
+## Key Metrics (from `train.log`)
+
+- Timed training stopped at `11520/20000` steps due to the wallclock cap.
+- Pre-quant eval at stop: `val_loss:2.0313`, `val_bpb:1.2031`
+- Post-quant sliding window eval: `val_loss:2.0032`, `val_bpb:1.1864`
+- Exact printed metric: `final_int8_zlib_roundtrip_exact val_bpb:1.18641686`
+- Train time: `600019ms` (`step_avg:52.08ms`)
+- Peak memory: `10121 MiB allocated`, `10440 MiB reserved`
+- Eval time: `132519ms` (sliding window, stride=64, batch_seqs=1024)
+- Serialized model int8+zlib: `15808653 bytes`
+- Code size: `52684 bytes`
+- Total submission size int8+zlib: `15861337 bytes`
+
+## Training Volume
+
+- Global batch: `524288` tokens/step
+- Total train tokens seen: `6,039,797,760` (11,520 steps × 524,288 tokens/step)
+
+## Reproducibility (3 seeds)
+
+| Seed | Steps | val_loss | val_bpb | Artifact |
+|------|-------|----------|---------|----------|
+| 1337 | 11,520 | 2.00321 | 1.18642 | 15,861,337 |
+| 1338 | 11,520 | 2.00428 | 1.18705 | 15,859,751 |
+| 1339 | 11,523 | 2.00667 | 1.18847 | 15,867,480 |
+
+- Sample mean val_loss: `2.00472`
+- Sample std: `0.00177`
+- Current SOTA val_loss: `2.01348`
+- Required improvement: `0.005 nats`
+- Actual improvement: `0.00876 nats`
+- One-sided t-test: `t=8.57`, `df=2`, `p < 0.01`
+
+## Methodology
+
+Changes were discovered through 46 iterations of automated experimentation (autoresearch) on a proxy test setup (RTX 3090, 2000 steps), then validated on 4xH100 and finally 8xH100. The proxy correctly identified directional improvements but could not predict exact artifact sizes at full convergence, leading to several over-budget configurations being tested on H100.
+
+## Included Files
+
+- `train_gpt.py` (code snapshot used for the run)
+- `train.log` (canonical run, SEED=1337)
+- `train_seed1338.log` (reproducibility run, SEED=1338)
+- `train_seed1339.log` (reproducibility run, SEED=1339)
+- `submission.json` (leaderboard metadata)
diff --git a/records/track_10min_16mb/2026-03-21_OptimizerTuning_SlideEval64/submission.json b/records/track_10min_16mb/2026-03-21_OptimizerTuning_SlideEval64/submission.json
new file mode 100644
index 000000000..5c49995c7
--- /dev/null
+++ b/records/track_10min_16mb/2026-03-21_OptimizerTuning_SlideEval64/submission.json
@@ -0,0 +1,17 @@
+{
+  "author": "RAC",
+  "github_id": "andreanjos",
+  "name": "Optimizer Tuning + Sliding Window Eval",
+  "blurb": "Baseline 9x512 SP-1024 architecture with optimizer improvements (warmdown=10000, muon_backend_steps=10, grad_clip=1.0, beta2=0.99, scalar_lr=0.02) and seq2048 training. Sliding window evaluation at stride=64 scores every token with near-maximum context. 
Post-quant int8+zlib roundtrip under the 16,000,000-byte cap.", + "date": "2026-03-21T06:00:00Z", + "val_loss": 2.00320987, + "val_bpb": 1.18641686, + "pre_quant_val_loss": 2.0313, + "pre_quant_val_bpb": 1.2031, + "step_stop": 11520, + "wallclock_seconds": 600.019, + "eval_time_seconds": 132.519, + "bytes_total": 15861337, + "bytes_model_int8_zlib": 15808653, + "bytes_code": 52684 +} diff --git a/records/track_10min_16mb/2026-03-21_OptimizerTuning_SlideEval64/train.log b/records/track_10min_16mb/2026-03-21_OptimizerTuning_SlideEval64/train.log new file mode 100644 index 000000000..f7861c483 --- /dev/null +++ b/records/track_10min_16mb/2026-03-21_OptimizerTuning_SlideEval64/train.log @@ -0,0 +1,114 @@ +logs/8xh100_9layer_nofp16embed.txt +val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path=./data/tokenizers/fineweb_1024_bpe.model +train_loader:dataset:fineweb10B_sp1024 train_shards:80 +val_loader:shards pattern=./data/datasets/fineweb10B_sp1024/fineweb_val_*.bin tokens:62021632 +model_params:17059912 +world_size:8 grad_accum_steps:1 +sdp_backends:cudnn=False flash=True mem_efficient=False math=False +attention_mode:gqa num_heads:8 num_kv_heads:4 +tie_embeddings:True embed_lr:0.05 head_lr:0.0 matrix_lr:0.04 scalar_lr:0.02 +train_batch_tokens:524288 train_seq_len:2048 iterations:20000 warmup_steps:20 max_wallclock_seconds:600.000 +seed:1337 +warmup_step:1/20 +warmup_step:2/20 +warmup_step:3/20 +warmup_step:4/20 +warmup_step:5/20 +warmup_step:6/20 +warmup_step:7/20 +warmup_step:8/20 +warmup_step:9/20 +warmup_step:10/20 +warmup_step:11/20 +warmup_step:12/20 +warmup_step:13/20 +warmup_step:14/20 +warmup_step:15/20 +warmup_step:16/20 +warmup_step:17/20 +warmup_step:18/20 +warmup_step:19/20 +warmup_step:20/20 +step:0/20000 val_loss:6.9357 val_bpb:4.1077 train_time:0ms step_avg:0.01ms +step:1/20000 train_loss:6.9370 train_time:105ms step_avg:105.45ms +step:2/20000 train_loss:17.5737 train_time:147ms step_avg:73.41ms +step:3/20000 train_loss:13.1619 train_time:198ms step_avg:66.05ms +step:4/20000 train_loss:8.2470 train_time:250ms step_avg:62.49ms +step:5/20000 train_loss:6.3365 train_time:302ms step_avg:60.33ms +step:6/20000 train_loss:7.2407 train_time:353ms step_avg:58.90ms +step:7/20000 train_loss:6.2363 train_time:405ms step_avg:57.88ms +step:8/20000 train_loss:6.0346 train_time:457ms step_avg:57.11ms +step:9/20000 train_loss:5.8972 train_time:509ms step_avg:56.51ms +step:10/20000 train_loss:5.7512 train_time:560ms step_avg:56.03ms +step:200/20000 train_loss:2.7615 train_time:10441ms step_avg:52.20ms +step:400/20000 train_loss:2.2909 train_time:20830ms step_avg:52.08ms +step:600/20000 train_loss:2.4978 train_time:31224ms step_avg:52.04ms +step:800/20000 train_loss:2.2474 train_time:41627ms step_avg:52.03ms +step:1000/20000 train_loss:2.3363 train_time:52034ms step_avg:52.03ms +step:1200/20000 train_loss:2.3576 train_time:62456ms step_avg:52.05ms +step:1400/20000 train_loss:2.3836 train_time:72861ms step_avg:52.04ms +step:1600/20000 train_loss:2.0454 train_time:83272ms step_avg:52.04ms +step:1800/20000 train_loss:2.1603 train_time:93672ms step_avg:52.04ms +step:2000/20000 train_loss:2.2083 train_time:104074ms step_avg:52.04ms +step:2000/20000 val_loss:2.1909 val_bpb:1.2975 train_time:104086ms step_avg:52.04ms +step:2200/20000 train_loss:2.0275 train_time:114480ms step_avg:52.04ms +step:2400/20000 train_loss:2.1568 train_time:124883ms step_avg:52.03ms +step:2600/20000 train_loss:2.3789 train_time:135288ms step_avg:52.03ms +step:2800/20000 train_loss:2.1904 train_time:145692ms 
step_avg:52.03ms +step:3000/20000 train_loss:2.1813 train_time:156097ms step_avg:52.03ms +step:3200/20000 train_loss:2.1458 train_time:166500ms step_avg:52.03ms +step:3400/20000 train_loss:2.1104 train_time:176909ms step_avg:52.03ms +step:3600/20000 train_loss:2.0582 train_time:187316ms step_avg:52.03ms +step:3800/20000 train_loss:2.1655 train_time:197726ms step_avg:52.03ms +step:4000/20000 train_loss:2.1274 train_time:208168ms step_avg:52.04ms +step:4000/20000 val_loss:2.1197 val_bpb:1.2554 train_time:208179ms step_avg:52.04ms +step:4200/20000 train_loss:2.1172 train_time:218645ms step_avg:52.06ms +step:4400/20000 train_loss:2.0575 train_time:229055ms step_avg:52.06ms +step:4600/20000 train_loss:1.9276 train_time:239470ms step_avg:52.06ms +step:4800/20000 train_loss:2.2088 train_time:249884ms step_avg:52.06ms +step:5000/20000 train_loss:1.9610 train_time:260389ms step_avg:52.08ms +step:5200/20000 train_loss:2.1223 train_time:270797ms step_avg:52.08ms +step:5400/20000 train_loss:2.1388 train_time:281210ms step_avg:52.08ms +step:5600/20000 train_loss:2.1251 train_time:291619ms step_avg:52.07ms +step:5800/20000 train_loss:2.0806 train_time:302028ms step_avg:52.07ms +step:6000/20000 train_loss:2.1595 train_time:312442ms step_avg:52.07ms +step:6000/20000 val_loss:2.0863 val_bpb:1.2356 train_time:312453ms step_avg:52.08ms +step:6200/20000 train_loss:2.0288 train_time:322855ms step_avg:52.07ms +step:6400/20000 train_loss:2.1062 train_time:333264ms step_avg:52.07ms +step:6600/20000 train_loss:2.0640 train_time:343676ms step_avg:52.07ms +step:6800/20000 train_loss:2.1290 train_time:354088ms step_avg:52.07ms +step:7000/20000 train_loss:2.1750 train_time:364508ms step_avg:52.07ms +step:7200/20000 train_loss:2.1447 train_time:374922ms step_avg:52.07ms +step:7400/20000 train_loss:2.0649 train_time:385336ms step_avg:52.07ms +step:7600/20000 train_loss:1.9417 train_time:395751ms step_avg:52.07ms +step:7800/20000 train_loss:2.0889 train_time:406164ms step_avg:52.07ms +step:8000/20000 train_loss:2.0593 train_time:416580ms step_avg:52.07ms +step:8000/20000 val_loss:2.0610 val_bpb:1.2206 train_time:416591ms step_avg:52.07ms +step:8200/20000 train_loss:2.1323 train_time:426997ms step_avg:52.07ms +step:8400/20000 train_loss:2.0714 train_time:437478ms step_avg:52.08ms +step:8600/20000 train_loss:2.0887 train_time:447892ms step_avg:52.08ms +step:8800/20000 train_loss:2.0444 train_time:458310ms step_avg:52.08ms +step:9000/20000 train_loss:1.9627 train_time:468721ms step_avg:52.08ms +step:9200/20000 train_loss:2.0257 train_time:479142ms step_avg:52.08ms +step:9400/20000 train_loss:2.0612 train_time:489555ms step_avg:52.08ms +step:9600/20000 train_loss:2.0844 train_time:499973ms step_avg:52.08ms +step:9800/20000 train_loss:1.9934 train_time:510389ms step_avg:52.08ms +step:10000/20000 train_loss:2.0501 train_time:520802ms step_avg:52.08ms +step:10000/20000 val_loss:2.0421 val_bpb:1.2094 train_time:520813ms step_avg:52.08ms +step:10200/20000 train_loss:2.0035 train_time:531220ms step_avg:52.08ms +step:10400/20000 train_loss:2.0217 train_time:541641ms step_avg:52.08ms +step:10600/20000 train_loss:1.9142 train_time:552057ms step_avg:52.08ms +step:10800/20000 train_loss:2.1162 train_time:562468ms step_avg:52.08ms +step:11000/20000 train_loss:2.0469 train_time:572886ms step_avg:52.08ms +step:11200/20000 train_loss:2.0079 train_time:583305ms step_avg:52.08ms +step:11400/20000 train_loss:1.9918 train_time:593729ms step_avg:52.08ms +step:11520/20000 val_loss:2.0313 val_bpb:1.2031 train_time:600019ms step_avg:52.08ms 
+stopping_early: wallclock_cap train_time:600019ms step:11520/20000
+peak memory allocated: 10121 MiB reserved: 10440 MiB
+Serialized model: 67224983 bytes
+Code size: 52684 bytes
+Total submission size: 67277667 bytes
+Serialized model int8+zlib: 15808653 bytes (payload:17178912 raw_torch:17224025 payload_ratio:3.91x)
+Total submission size int8+zlib: 15861337 bytes
+final_int8_zlib_roundtrip val_loss:2.0032 val_bpb:1.1864 eval_time:132519ms
+final_int8_zlib_roundtrip_exact val_loss:2.00320987 val_bpb:1.18641686
diff --git a/records/track_10min_16mb/2026-03-21_OptimizerTuning_SlideEval64/train_gpt.py b/records/track_10min_16mb/2026-03-21_OptimizerTuning_SlideEval64/train_gpt.py
new file mode 100644
index 000000000..7e8153a82
--- /dev/null
+++ b/records/track_10min_16mb/2026-03-21_OptimizerTuning_SlideEval64/train_gpt.py
@@ -0,0 +1,1243 @@
+"""
+The `train_gpt.py` and `train_gpt_mlx.py` scripts are intended as good launching-off points for new participants, not SOTA configs. We'll accept PRs that tune, improve, or simplify these scripts without significantly increasing complexity, but competitive submissions should stay in the `/records` folder.
+
+Hard stop: `train_gpt.py` and `train_gpt_mlx.py` must never be longer than 1500 lines.
+"""
+
+from __future__ import annotations
+
+import copy
+import glob
+import io
+import math
+import os
+import random
+import subprocess
+import sys
+import time
+import uuid
+import zlib
+from pathlib import Path
+
+import numpy as np
+import sentencepiece as spm
+import torch
+import torch.distributed as dist
+import torch.nn.functional as F
+from torch import Tensor, nn
+from torch.nn.parallel import DistributedDataParallel as DDP
+
+# -----------------------------
+# HYPERPARAMETERS
+# -----------------------------
+# Defaults for this record's run:
+# - 9 transformer blocks at width 512
+# - 8 attention heads with 4 KV heads (GQA) and 2x MLP expansion
+# - vocab size 1024, sequence length 2048, tied embeddings
+# - 524,288 train tokens per step for 20,000 iterations with a ~10 minute cap
+
+class Hyperparameters:
+    # Data paths are shard globs produced by the existing preprocessing pipeline.
+    data_path = os.environ.get("DATA_PATH", "./data/datasets/fineweb10B_sp1024")
+    train_files = os.path.join(data_path, "fineweb_train_*.bin")
+    val_files = os.path.join(data_path, "fineweb_val_*.bin")
+    tokenizer_path = os.environ.get("TOKENIZER_PATH", "./data/tokenizers/fineweb_1024_bpe.model")
+    run_id = os.environ.get("RUN_ID", str(uuid.uuid4()))
+    seed = int(os.environ.get("SEED", 1337))
+
+    # Validation cadence and batch size. Validation always uses the full fineweb_val split.
+    val_batch_size = int(os.environ.get("VAL_BATCH_SIZE", 524_288))
+    val_loss_every = int(os.environ.get("VAL_LOSS_EVERY", 1000))
+    train_log_every = int(os.environ.get("TRAIN_LOG_EVERY", 200))
+
+    # Training length.
+    iterations = int(os.environ.get("ITERATIONS", 20000))
+    warmdown_iters = int(os.environ.get("WARMDOWN_ITERS", 10000))
+    warmup_steps = int(os.environ.get("WARMUP_STEPS", 20))
+    train_batch_tokens = int(os.environ.get("TRAIN_BATCH_TOKENS", 524_288))
+    train_seq_len = int(os.environ.get("TRAIN_SEQ_LEN", 2048))
+    max_wallclock_seconds = float(os.environ.get("MAX_WALLCLOCK_SECONDS", 600.0))
+    qk_gain_init = float(os.environ.get("QK_GAIN_INIT", 1.5))
+
+    # Model shape.
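+    # Sanity check on these defaults (arithmetic, not a config knob): the
+    # parameter count works out to the 17,059,912 printed in train.log:
+    #   tok_emb       1024 * 512                          =    524,288
+    #   per block     4 attn mats + 2 MLP mats + q_gain
+    #                 + attn/mlp scales + resid_mix       =  1,837,064
+    #   9 blocks      9 * 1,837,064                       = 16,533,576
+    #   skip_weights  4 * 512                             =      2,048
+    #   total         524,288 + 16,533,576 + 2,048        = 17,059,912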
+ vocab_size = int(os.environ.get("VOCAB_SIZE", 1024)) + num_layers = int(os.environ.get("NUM_LAYERS", 9)) + num_kv_heads = int(os.environ.get("NUM_KV_HEADS", 4)) + model_dim = int(os.environ.get("MODEL_DIM", 512)) + num_heads = int(os.environ.get("NUM_HEADS", 8)) + mlp_mult = int(os.environ.get("MLP_MULT", 2)) + tie_embeddings = bool(int(os.environ.get("TIE_EMBEDDINGS", "1"))) + rope_base = float(os.environ.get("ROPE_BASE", 10000.0)) + logit_softcap = float(os.environ.get("LOGIT_SOFTCAP", 30.0)) + eval_stride = int(os.environ.get("EVAL_STRIDE", 64)) + eval_batch_seqs = int(os.environ.get("EVAL_BATCH_SEQS", 32)) + + # Optimizer hyperparameters. + embed_lr = float(os.environ.get("EMBED_LR", 0.6)) + head_lr = float(os.environ.get("HEAD_LR", 0.008)) + tied_embed_lr = float(os.environ.get("TIED_EMBED_LR", 0.05)) + tied_embed_init_std = float(os.environ.get("TIED_EMBED_INIT_STD", 0.005)) + matrix_lr = float(os.environ.get("MATRIX_LR", 0.04)) + scalar_lr = float(os.environ.get("SCALAR_LR", 0.02)) + muon_momentum = float(os.environ.get("MUON_MOMENTUM", 0.95)) + muon_backend_steps = int(os.environ.get("MUON_BACKEND_STEPS", 10)) + muon_momentum_warmup_start = float(os.environ.get("MUON_MOMENTUM_WARMUP_START", 0.85)) + muon_momentum_warmup_steps = int(os.environ.get("MUON_MOMENTUM_WARMUP_STEPS", 500)) + beta1 = float(os.environ.get("BETA1", 0.9)) + beta2 = float(os.environ.get("BETA2", 0.99)) + adam_eps = float(os.environ.get("ADAM_EPS", 1e-8)) + grad_clip_norm = float(os.environ.get("GRAD_CLIP_NORM", 1.0)) + +# ----------------------------- +# MUON OPTIMIZER +# ----------------------------- +# +# As borrowed from modded-nanogpt +# Background on Muon: https://kellerjordan.github.io/posts/muon/ + +def zeropower_via_newtonschulz5(G: Tensor, steps: int = 10, eps: float = 1e-7) -> Tensor: + # Orthogonalize a 2D update matrix with a fast Newton-Schulz iteration. + # Muon uses this to normalize matrix-shaped gradients before applying them. 
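+    # Intuition (per the Muon writeup): the quintic p(x) = a*x + b*x^3 + c*x^5
+    # below acts, in effect, on the singular values of X, pushing each toward 1,
+    # so the loop approximates U @ V.T from the SVD G = U @ S @ V.T. The
+    # coefficients are tuned for speed rather than exact convergence; the
+    # singular values only land near 1, which is reported to work fine for
+    # optimizer updates.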
+    a, b, c = (3.4445, -4.7750, 2.0315)
+    X = G.bfloat16()
+    X /= X.norm() + eps
+    transposed = G.size(0) > G.size(1)
+    if transposed:
+        X = X.T
+    for _ in range(steps):
+        A = X @ X.T
+        B = b * A + c * A @ A
+        X = a * X + B @ X
+    return X.T if transposed else X
+
+
+class Muon(torch.optim.Optimizer):
+    def __init__(self, params, lr: float, momentum: float, backend_steps: int, nesterov: bool = True):
+        super().__init__(
+            params,
+            dict(lr=lr, momentum=momentum, backend_steps=backend_steps, nesterov=nesterov),
+        )
+
+    @torch.no_grad()
+    def step(self, closure=None):
+        loss = None
+        if closure is not None:
+            with torch.enable_grad():
+                loss = closure()
+
+        distributed = dist.is_available() and dist.is_initialized()
+        world_size = dist.get_world_size() if distributed else 1
+        rank = dist.get_rank() if distributed else 0
+
+        for group in self.param_groups:
+            params = group["params"]
+            if not params:
+                continue
+            lr = group["lr"]
+            momentum = group["momentum"]
+            backend_steps = group["backend_steps"]
+            nesterov = group["nesterov"]
+
+            total_params = sum(int(p.numel()) for p in params)
+            updates_flat = torch.zeros(total_params, device=params[0].device, dtype=torch.bfloat16)
+
+            curr = 0
+            for i, p in enumerate(params):
+                if i % world_size == rank and p.grad is not None:
+                    g = p.grad
+                    state = self.state[p]
+                    if "momentum_buffer" not in state:
+                        state["momentum_buffer"] = torch.zeros_like(g)
+                    buf = state["momentum_buffer"]
+                    buf.mul_(momentum).add_(g)
+                    if nesterov:
+                        g = g.add(buf, alpha=momentum)
+                    g = zeropower_via_newtonschulz5(g, steps=backend_steps)
+                    # Scale correction from Muon reference implementations.
+                    g *= max(1, g.size(0) / g.size(1)) ** 0.5
+                    updates_flat[curr : curr + p.numel()] = g.reshape(-1)
+                curr += p.numel()
+
+            if distributed:
+                dist.all_reduce(updates_flat, op=dist.ReduceOp.SUM)
+
+            curr = 0
+            for p in params:
+                g = updates_flat[curr : curr + p.numel()].view_as(p).to(dtype=p.dtype)
+                p.add_(g, alpha=-lr)
+                curr += p.numel()
+
+        return loss
+
+
+# -----------------------------
+# TOKENIZER-AGNOSTIC EVALUATION SETUP
+# -----------------------------
+#
+# It's common for small models to have a large fraction of their parameters in embeddings, since the 2 * d_model * d_vocab embedding parameters (input plus output) can be gigantic.
+# Instead of locking the tokenizer, we let you bring your own and calculate our validation metrics from the average compression of the validation set.
+# We report BPB (bits-per-byte) instead of validation loss, so we need methods to count the number of bytes per token in the tokenizer.
+# Note: Submissions that edit the tokenizer will be examined more carefully, since screwing this up might unjustly improve your score.
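+# Worked example of the conversion used by the eval functions below (numbers
+# are illustrative): a model at val_loss = 2.0032 nats/token corresponds to
+# 2.0032 / ln(2) ~ 2.890 bits/token; if the tokenizer averages ~2.436 bytes
+# per token on the validation set, then
+#   val_bpb = bits_per_token * tokens_per_byte ~ 2.890 / 2.436 ~ 1.186.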
+ +def build_sentencepiece_luts( + sp: spm.SentencePieceProcessor, vocab_size: int, device: torch.device +) -> tuple[Tensor, Tensor, Tensor]: + sp_vocab_size = int(sp.vocab_size()) + table_size = max(sp_vocab_size, vocab_size) + base_bytes_np = np.zeros((table_size,), dtype=np.int16) + has_leading_space_np = np.zeros((table_size,), dtype=np.bool_) + is_boundary_token_np = np.ones((table_size,), dtype=np.bool_) + for token_id in range(sp_vocab_size): + if sp.is_control(token_id) or sp.is_unknown(token_id) or sp.is_unused(token_id): + continue + is_boundary_token_np[token_id] = False + if sp.is_byte(token_id): + base_bytes_np[token_id] = 1 + continue + piece = sp.id_to_piece(token_id) + if piece.startswith("▁"): + has_leading_space_np[token_id] = True + piece = piece[1:] + base_bytes_np[token_id] = len(piece.encode("utf-8")) + return ( + torch.tensor(base_bytes_np, dtype=torch.int16, device=device), + torch.tensor(has_leading_space_np, dtype=torch.bool, device=device), + torch.tensor(is_boundary_token_np, dtype=torch.bool, device=device), + ) + + +def load_validation_tokens(pattern: str, seq_len: int) -> Tensor: + files = [Path(p) for p in sorted(glob.glob(pattern))] + if not files: + raise FileNotFoundError(f"No files found for pattern: {pattern}") + # The export pipeline writes the fixed first-50k-doc validation set to fineweb_val_*. + tokens = torch.cat([load_data_shard(file) for file in files]).contiguous() + usable = ((tokens.numel() - 1) // seq_len) * seq_len + if usable <= 0: + raise ValueError(f"Validation split is too short for TRAIN_SEQ_LEN={seq_len}") + return tokens[: usable + 1] + + +def eval_val( + args: Hyperparameters, + model: nn.Module, + rank: int, + world_size: int, + device: torch.device, + grad_accum_steps: int, + val_tokens: Tensor, + base_bytes_lut: Tensor, + has_leading_space_lut: Tensor, + is_boundary_token_lut: Tensor, +) -> tuple[float, float]: + # Validation computes two metrics: + # - val_loss: token cross-entropy (natural log) + # - val_bpb: tokenizer-agnostic compression metric used by the challenge + local_batch_tokens = args.val_batch_size // (world_size * grad_accum_steps) + if local_batch_tokens < args.train_seq_len: + raise ValueError( + "VAL_BATCH_SIZE must provide at least one sequence per rank; " + f"got VAL_BATCH_SIZE={args.val_batch_size}, WORLD_SIZE={world_size}, " + f"GRAD_ACCUM_STEPS={grad_accum_steps}, TRAIN_SEQ_LEN={args.train_seq_len}" + ) + local_batch_seqs = local_batch_tokens // args.train_seq_len + total_seqs = (val_tokens.numel() - 1) // args.train_seq_len + seq_start = (total_seqs * rank) // world_size + seq_end = (total_seqs * (rank + 1)) // world_size + val_loss_sum = torch.zeros((), device=device, dtype=torch.float64) + val_token_count = torch.zeros((), device=device, dtype=torch.float64) + val_byte_count = torch.zeros((), device=device, dtype=torch.float64) + + model.eval() + with torch.inference_mode(): + for batch_seq_start in range(seq_start, seq_end, local_batch_seqs): + batch_seq_end = min(batch_seq_start + local_batch_seqs, seq_end) + raw_start = batch_seq_start * args.train_seq_len + raw_end = batch_seq_end * args.train_seq_len + 1 + local = val_tokens[raw_start:raw_end].to(device=device, dtype=torch.int64, non_blocking=True) + x = local[:-1].reshape(-1, args.train_seq_len) + y = local[1:].reshape(-1, args.train_seq_len) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + batch_loss = model(x, y).detach() + batch_token_count = float(y.numel()) + val_loss_sum += batch_loss.to(torch.float64) * 
batch_token_count
+                val_token_count += batch_token_count
+                prev_ids = x.reshape(-1)
+                tgt_ids = y.reshape(-1)
+                token_bytes = base_bytes_lut[tgt_ids].to(dtype=torch.int16)
+                token_bytes += (has_leading_space_lut[tgt_ids] & ~is_boundary_token_lut[prev_ids]).to(dtype=torch.int16)
+                val_byte_count += token_bytes.to(torch.float64).sum()
+
+    if dist.is_available() and dist.is_initialized():
+        dist.all_reduce(val_loss_sum, op=dist.ReduceOp.SUM)
+        dist.all_reduce(val_token_count, op=dist.ReduceOp.SUM)
+        dist.all_reduce(val_byte_count, op=dist.ReduceOp.SUM)
+
+    val_loss = val_loss_sum / val_token_count
+    bits_per_token = val_loss.item() / math.log(2.0)
+    tokens_per_byte = val_token_count.item() / val_byte_count.item()
+    model.train()
+    return float(val_loss.item()), float(bits_per_token * tokens_per_byte)
+
+
+def eval_val_sliding(
+    args: Hyperparameters,
+    model: nn.Module,
+    rank: int,
+    world_size: int,
+    device: torch.device,
+    val_tokens: Tensor,
+    base_bytes_lut: Tensor,
+    has_leading_space_lut: Tensor,
+    is_boundary_token_lut: Tensor,
+) -> tuple[float, float]:
+    """Sliding window evaluation: each token scored with near-maximum context.
+
+    Windows of train_seq_len advance by eval_stride. Only the last stride tokens
+    per window contribute to the score (the first window scores all of its
+    tokens). Every token in the validation set is scored; the trailing
+    truncated windows re-score the final stride tokens a few extra times,
+    which is negligible at this validation size.
+    """
+    seq_len = args.train_seq_len
+    stride = args.eval_stride
+    batch_seqs = args.eval_batch_seqs
+    total_tokens = val_tokens.numel() - 1
+
+    # Build window starts; skip any too short to score a full stride
+    window_starts = [ws for ws in range(0, total_tokens, stride)
+                     if min(ws + seq_len, total_tokens) - ws >= stride]
+    total_windows = len(window_starts)
+
+    # Distribute across ranks
+    my_s = (total_windows * rank) // world_size
+    my_e = (total_windows * (rank + 1)) // world_size
+    my_windows = window_starts[my_s:my_e]
+
+    loss_sum = torch.zeros((), device=device, dtype=torch.float64)
+    token_count = torch.zeros((), device=device, dtype=torch.float64)
+    byte_count = torch.zeros((), device=device, dtype=torch.float64)
+
+    # Access the base model (unwrap DDP/compile if needed)
+    base = model
+    while hasattr(base, "module"):
+        base = base.module
+    while hasattr(base, "_orig_mod"):
+        base = base._orig_mod
+
+    base.eval()
+    with torch.inference_mode():
+        for bi in range(0, len(my_windows), batch_seqs):
+            batch_ws = my_windows[bi:bi + batch_seqs]
+            bsz = len(batch_ws)
+
+            x_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device)
+            y_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device)
+            wlens: list[int] = []
+
+            for i, ws in enumerate(batch_ws):
+                end = min(ws + seq_len, total_tokens)
+                wlen = end - ws
+                wlens.append(wlen)
+                chunk = val_tokens[ws:end + 1].to(dtype=torch.int64, device=device)
+                x_batch[i, :wlen] = chunk[:-1]
+                y_batch[i, :wlen] = chunk[1:]
+
+            with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True):
+                logits = base.forward_logits(x_batch)
+
+            nll = F.cross_entropy(
+                logits.reshape(-1, logits.size(-1)).float(),
+                y_batch.reshape(-1),
+                reduction="none",
+            ).reshape(bsz, seq_len)
+
+            for i, ws in enumerate(batch_ws):
+                wlen = wlens[i]
+                # First window scores all tokens; others score the last stride
+                s = 0 if ws == 0 else wlen - stride
+                scored_nll = nll[i, s:wlen].to(torch.float64)
+                loss_sum += scored_nll.sum()
+                token_count += float(wlen - s)
+                tgt = y_batch[i, s:wlen]
+                prev = x_batch[i, s:wlen]
+                tb = base_bytes_lut[tgt].to(torch.float64)
+                tb += (has_leading_space_lut[tgt] & ~is_boundary_token_lut[prev]).to(torch.float64)
+                byte_count += tb.sum()
+
+    if dist.is_available() and dist.is_initialized():
+        dist.all_reduce(loss_sum, op=dist.ReduceOp.SUM)
+        dist.all_reduce(token_count, op=dist.ReduceOp.SUM)
+        dist.all_reduce(byte_count, op=dist.ReduceOp.SUM)
+
+    val_loss = loss_sum / token_count
+    bits_per_token = val_loss.item() / math.log(2.0)
+    tokens_per_byte = token_count.item() / byte_count.item()
+    base.train()
+    return float(val_loss.item()), float(bits_per_token * tokens_per_byte)
+
+
+# -----------------------------
+# POST-TRAINING QUANTIZATION
+# -----------------------------
+#
+# It's silly to export our model, which is trained in bf16 and fp32, at that same precision.
+# Instead, we get approximately the same model (with a small hit) by quantizing to int8 and zlib-compressing.
+# We can then decompress the model and run it at higher precision for evaluation, after coming in under the size limit.
+
+CONTROL_TENSOR_NAME_PATTERNS = tuple(
+    pattern
+    for pattern in os.environ.get(
+        "CONTROL_TENSOR_NAME_PATTERNS",
+        "attn_scale,attn_scales,mlp_scale,mlp_scales,resid_mix,resid_mixes,q_gain,skip_weight,skip_weights",
+    ).split(",")
+    if pattern
+)
+INT8_KEEP_FLOAT_FP32_NAME_PATTERNS = tuple(
+    pattern
+    for pattern in os.environ.get(
+        "INT8_KEEP_FLOAT_FP32_NAME_PATTERNS",
+        ",".join(CONTROL_TENSOR_NAME_PATTERNS),
+    ).split(",")
+    if pattern
+)
+INT8_KEEP_FLOAT_MAX_NUMEL = 65_536
+INT8_KEEP_FLOAT_STORE_DTYPE = torch.float16
+INT8_PER_ROW_SCALE_DTYPE = torch.float16
+INT8_CLIP_PERCENTILE = 99.99984
+INT8_CLIP_Q = INT8_CLIP_PERCENTILE / 100.0
+
+def tensor_nbytes(t: Tensor) -> int:
+    return int(t.numel()) * int(t.element_size())
+
+def keep_float_tensor(name: str, t: Tensor, passthrough_orig_dtypes: dict[str, str]) -> Tensor:
+    if any(pattern in name for pattern in INT8_KEEP_FLOAT_FP32_NAME_PATTERNS):
+        return t.float().contiguous()
+    if t.dtype in {torch.float32, torch.bfloat16}:
+        passthrough_orig_dtypes[name] = str(t.dtype).removeprefix("torch.")
+        return t.to(dtype=INT8_KEEP_FLOAT_STORE_DTYPE).contiguous()
+    return t
+
+def quantize_float_tensor(t: Tensor) -> tuple[Tensor, Tensor]:
+    t32 = t.float()
+    if t32.ndim == 2:
+        # Matrices get one scale per row, which usually tracks output-channel
+        # ranges much better than a single tensor-wide scale.
+        clip_abs = (
+            torch.quantile(t32.abs(), INT8_CLIP_Q, dim=1)
+            if t32.numel()
+            else torch.empty((t32.shape[0],), dtype=torch.float32)
+        )
+        clipped = torch.maximum(torch.minimum(t32, clip_abs[:, None]), -clip_abs[:, None])
+        scale = (clip_abs / 127.0).clamp_min(1.0 / 127.0)
+        q = torch.clamp(torch.round(clipped / scale[:, None]), -127, 127).to(torch.int8).contiguous()
+        return q, scale.to(dtype=INT8_PER_ROW_SCALE_DTYPE).contiguous()
+
+    # Vectors / scalars use a simpler per-tensor scale.
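+    # For example (illustrative numbers): with clip_abs = 1.27 the scale is
+    # 1.27 / 127 = 0.01, so a value 0.5 quantizes to round(0.5 / 0.01) = 50
+    # and dequantizes back to 50 * 0.01 = 0.5 exactly; values beyond +/-1.27
+    # saturate at +/-127.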
+    clip_abs = float(torch.quantile(t32.abs().flatten(), INT8_CLIP_Q).item()) if t32.numel() else 0.0
+    scale = torch.tensor(clip_abs / 127.0 if clip_abs > 0 else 1.0, dtype=torch.float32)
+    q = torch.clamp(torch.round(torch.clamp(t32, -clip_abs, clip_abs) / scale), -127, 127).to(torch.int8).contiguous()
+    return q, scale
+
+def quantize_state_dict_int8(state_dict: dict[str, Tensor]):
+    # Single supported clean-script export format:
+    # - per-row int8 for 2D float tensors
+    # - per-tensor int8 for other float tensors
+    # - exact passthrough for non-floats
+    # - passthrough for small float tensors, stored as fp16 to save bytes
+    quantized: dict[str, Tensor] = {}
+    scales: dict[str, Tensor] = {}
+    dtypes: dict[str, str] = {}
+    passthrough: dict[str, Tensor] = {}
+    passthrough_orig_dtypes: dict[str, str] = {}
+    qmeta: dict[str, dict[str, object]] = {}
+    stats = dict.fromkeys(
+        ("param_count", "num_tensors", "num_float_tensors", "num_nonfloat_tensors", "baseline_tensor_bytes", "int8_payload_bytes"),
+        0,
+    )
+
+    for name, tensor in state_dict.items():
+        t = tensor.detach().to("cpu").contiguous()
+        stats["param_count"] += int(t.numel())
+        stats["num_tensors"] += 1
+        stats["baseline_tensor_bytes"] += tensor_nbytes(t)
+
+        if not t.is_floating_point():
+            stats["num_nonfloat_tensors"] += 1
+            passthrough[name] = t
+            stats["int8_payload_bytes"] += tensor_nbytes(t)
+            continue
+
+        # Small float tensors are cheap enough to keep directly. We still downcast
+        # fp32/bf16 passthrough tensors to fp16 so metadata does not dominate size.
+        # Note that the tied embedding (524,288 params here) exceeds this threshold
+        # and is int8-quantized like any other matrix; keeping it in fp16 was tried
+        # and reverted for artifact size (see the README).
+        if t.numel() <= INT8_KEEP_FLOAT_MAX_NUMEL:
+            kept = keep_float_tensor(name, t, passthrough_orig_dtypes)
+            passthrough[name] = kept
+            stats["int8_payload_bytes"] += tensor_nbytes(kept)
+            continue
+
+        stats["num_float_tensors"] += 1
+        q, s = quantize_float_tensor(t)
+        if s.ndim > 0:
+            qmeta[name] = {"scheme": "per_row", "axis": 0}
+        quantized[name] = q
+        scales[name] = s
+        dtypes[name] = str(t.dtype).removeprefix("torch.")
+        stats["int8_payload_bytes"] += tensor_nbytes(q) + tensor_nbytes(s)
+
+    obj: dict[str, object] = {
+        "__quant_format__": "int8_clean_per_row_v1",
+        "quantized": quantized,
+        "scales": scales,
+        "dtypes": dtypes,
+        "passthrough": passthrough,
+    }
+    if qmeta:
+        obj["qmeta"] = qmeta
+    if passthrough_orig_dtypes:
+        obj["passthrough_orig_dtypes"] = passthrough_orig_dtypes
+    return obj, stats
+
+def dequantize_state_dict_int8(obj: dict[str, object]) -> dict[str, Tensor]:
+    out: dict[str, Tensor] = {}
+    qmeta = obj.get("qmeta", {})
+    passthrough_orig_dtypes = obj.get("passthrough_orig_dtypes", {})
+    for name, q in obj["quantized"].items():
+        dtype = getattr(torch, obj["dtypes"][name])
+        s = obj["scales"][name]
+        if qmeta.get(name, {}).get("scheme") == "per_row" or s.ndim > 0:
+            s = s.to(dtype=torch.float32)
+            # Broadcast the saved row scale back across trailing dimensions.
+            out[name] = (q.float() * s.view(q.shape[0], *([1] * (q.ndim - 1)))).to(dtype=dtype).contiguous()
+        else:
+            scale = float(s.item())
+            out[name] = (q.float() * scale).to(dtype=dtype).contiguous()
+    for name, t in obj["passthrough"].items():
+        # Restore small tensors, undoing the temporary fp16 storage cast if needed.
+        # (Casting back only restores the dtype; precision already lost to fp16
+        # rounding stays lost, which is acceptable for these small tensors.)
+        out_t = t.detach().to("cpu").contiguous()
+        orig_dtype = passthrough_orig_dtypes.get(name)
+        if isinstance(orig_dtype, str):
+            out_t = out_t.to(dtype=getattr(torch, orig_dtype)).contiguous()
+        out[name] = out_t
+    return out
+
+
+# -----------------------------
+# DATA LOADING
+# -----------------------------
+
+def load_data_shard(file: Path) -> Tensor:
+    # Shards use the preprocessing pipeline's .bin layout (as in modded-nanogpt):
+    # a 256-entry int32 header whose third field is the token count, followed by
+    # the tokens themselves as little-endian uint16.
+    header_bytes = 256 * np.dtype("<i4").itemsize
+    with file.open("rb", buffering=0) as f:
+        header = np.frombuffer(f.read(header_bytes), dtype="<i4")
+        num_tokens = int(header[2])
+        tokens = np.frombuffer(f.read(num_tokens * 2), dtype="<u2")
+    if tokens.size != num_tokens:
+        raise ValueError(f"Shard {file} is truncated: expected {num_tokens} tokens, got {tokens.size}")
+    # Widen to int32 so downstream slicing/indexing is unconstrained.
+    return torch.from_numpy(tokens.astype(np.int32))
+
+
+class TokenStream:
+    # Streams tokens from the sorted shard files, wrapping around at the end.
+    def __init__(self, pattern: str):
+        self.files = [Path(p) for p in sorted(glob.glob(pattern))]
+        if not self.files:
+            raise FileNotFoundError(f"No files found for pattern: {pattern}")
+        self.file_idx = 0
+        self.tokens = load_data_shard(self.files[self.file_idx])
+        self.pos = 0
+
+    def _advance_file(self) -> None:
+        self.file_idx = (self.file_idx + 1) % len(self.files)
+        self.tokens = load_data_shard(self.files[self.file_idx])
+        self.pos = 0
+
+    def take(self, n: int) -> Tensor:
+        chunks: list[Tensor] = []
+        remaining = n
+        while remaining > 0:
+            avail = self.tokens.numel() - self.pos
+            if avail <= 0:
+                self._advance_file()
+                continue
+            k = min(remaining, avail)
+            chunks.append(self.tokens[self.pos : self.pos + k])
+            self.pos += k
+            remaining -= k
+        return chunks[0] if len(chunks) == 1 else torch.cat(chunks)
+
+
+class DistributedTokenLoader:
+    # Each call consumes a contiguous chunk from the shared token stream, then slices out
+    # one disjoint span per rank. The extra "+1" token lets us build (x, y) by shifting.
+    def __init__(self, pattern: str, rank: int, world_size: int, device: torch.device):
+        self.rank = rank
+        self.world_size = world_size
+        self.device = device
+        self.stream = TokenStream(pattern)
+
+    def next_batch(self, global_tokens: int, seq_len: int, grad_accum_steps: int) -> tuple[Tensor, Tensor]:
+        local_tokens = global_tokens // (self.world_size * grad_accum_steps)
+        per_rank_span = local_tokens + 1
+        chunk = self.stream.take(per_rank_span * self.world_size)
+        start = self.rank * per_rank_span
+        local = chunk[start : start + per_rank_span].to(dtype=torch.int64)
+        x = local[:-1].reshape(-1, seq_len)
+        y = local[1:].reshape(-1, seq_len)
+        return x.to(self.device, non_blocking=True), y.to(self.device, non_blocking=True)
+
+# -----------------------------
+# TRANSFORMER MODULES
+# -----------------------------
+
+class RMSNorm(nn.Module):
+    def __init__(self, eps: float | None = None):
+        super().__init__()
+        self.eps = eps
+
+    def forward(self, x: Tensor) -> Tensor:
+        return F.rms_norm(x, (x.size(-1),), eps=self.eps)
+
+
+class CastedLinear(nn.Linear):
+    # Keep weights in fp32 for optimizer/state quality, cast at matmul time for bf16 compute.
+    def forward(self, x: Tensor) -> Tensor:
+        bias = self.bias.to(x.dtype) if self.bias is not None else None
+        return F.linear(x, self.weight.to(x.dtype), bias)
+
+
+def restore_low_dim_params_to_fp32(module: nn.Module) -> None:
+    # Keep small/control parameters in fp32 even when the model body runs in bf16.
+    with torch.no_grad():
+        for name, param in module.named_parameters():
+            if (param.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS)) and param.dtype != torch.float32:
+                param.data = param.data.float()
+
+
+class Rotary(nn.Module):
+    # Caches cos/sin tables per sequence length on the current device.
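+    # Frequencies follow the standard RoPE schedule: inv_freq[i] = base^(-2i/dim)
+    # for i in [0, dim/2), so with head_dim=64 and base=10000 the per-pair
+    # rotation periods span 2*pi up to roughly 2*pi * 10000^(62/64) positions.
+    # The cached tables are rebuilt whenever the sequence length or device changes.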
+ def __init__(self, dim: int, base: float = 10000.0): + super().__init__() + inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim)) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self._seq_len_cached = 0 + self._cos_cached: Tensor | None = None + self._sin_cached: Tensor | None = None + + def forward(self, seq_len: int, device: torch.device, dtype: torch.dtype) -> tuple[Tensor, Tensor]: + if ( + self._cos_cached is None + or self._sin_cached is None + or self._seq_len_cached != seq_len + or self._cos_cached.device != device + ): + t = torch.arange(seq_len, device=device, dtype=self.inv_freq.dtype) + freqs = torch.outer(t, self.inv_freq.to(device)) + self._cos_cached = freqs.cos()[None, None, :, :] + self._sin_cached = freqs.sin()[None, None, :, :] + self._seq_len_cached = seq_len + return self._cos_cached.to(dtype=dtype), self._sin_cached.to(dtype=dtype) + + +def apply_rotary_emb(x: Tensor, cos: Tensor, sin: Tensor) -> Tensor: + half = x.size(-1) // 2 + x1, x2 = x[..., :half], x[..., half:] + return torch.cat((x1 * cos + x2 * sin, x1 * (-sin) + x2 * cos), dim=-1) + + +class CausalSelfAttention(nn.Module): + def __init__( + self, + dim: int, + num_heads: int, + num_kv_heads: int, + rope_base: float, + qk_gain_init: float, + ): + super().__init__() + if dim % num_heads != 0: + raise ValueError("model_dim must be divisible by num_heads") + if num_heads % num_kv_heads != 0: + raise ValueError("num_heads must be divisible by num_kv_heads") + self.num_heads = num_heads + self.num_kv_heads = num_kv_heads + self.head_dim = dim // num_heads + if self.head_dim % 2 != 0: + raise ValueError("head_dim must be even for RoPE") + kv_dim = self.num_kv_heads * self.head_dim + self.c_q = CastedLinear(dim, dim, bias=False) + self.c_k = CastedLinear(dim, kv_dim, bias=False) + self.c_v = CastedLinear(dim, kv_dim, bias=False) + self.proj = CastedLinear(dim, dim, bias=False) + self.proj._zero_init = True + self.q_gain = nn.Parameter(torch.full((num_heads,), qk_gain_init, dtype=torch.float32)) + self.rotary = Rotary(self.head_dim, base=rope_base) + + def forward(self, x: Tensor) -> Tensor: + bsz, seqlen, dim = x.shape + q = self.c_q(x).reshape(bsz, seqlen, self.num_heads, self.head_dim).transpose(1, 2) + k = self.c_k(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) + v = self.c_v(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) + q = F.rms_norm(q, (q.size(-1),)) + k = F.rms_norm(k, (k.size(-1),)) + cos, sin = self.rotary(seqlen, x.device, q.dtype) + q = apply_rotary_emb(q, cos, sin) + k = apply_rotary_emb(k, cos, sin) + q = q * self.q_gain.to(dtype=q.dtype)[None, :, None, None] + y = F.scaled_dot_product_attention( + q, + k, + v, + attn_mask=None, + is_causal=True, + enable_gqa=(self.num_kv_heads != self.num_heads), + ) + y = y.transpose(1, 2).contiguous().reshape(bsz, seqlen, dim) + return self.proj(y) + + +class MLP(nn.Module): + # relu^2 MLP from the original modded-nanogpt setup + def __init__(self, dim: int, mlp_mult: int): + super().__init__() + hidden = mlp_mult * dim + self.fc = CastedLinear(dim, hidden, bias=False) + self.proj = CastedLinear(hidden, dim, bias=False) + self.proj._zero_init = True + + def forward(self, x: Tensor) -> Tensor: + x = torch.relu(self.fc(x)) + return self.proj(x.square()) + + +class Block(nn.Module): + def __init__( + self, + dim: int, + num_heads: int, + num_kv_heads: int, + mlp_mult: int, + rope_base: float, + qk_gain_init: float, + ): + super().__init__() + self.attn_norm = RMSNorm() 
+ self.mlp_norm = RMSNorm() + self.attn = CausalSelfAttention(dim, num_heads, num_kv_heads, rope_base, qk_gain_init) + self.mlp = MLP(dim, mlp_mult) + self.attn_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.mlp_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.resid_mix = nn.Parameter(torch.stack((torch.ones(dim), torch.zeros(dim))).float()) + + def forward(self, x: Tensor, x0: Tensor) -> Tensor: + mix = self.resid_mix.to(dtype=x.dtype) + x = mix[0][None, None, :] * x + mix[1][None, None, :] * x0 + attn_out = self.attn(self.attn_norm(x)) + x = x + self.attn_scale.to(dtype=x.dtype)[None, None, :] * attn_out + x = x + self.mlp_scale.to(dtype=x.dtype)[None, None, :] * self.mlp(self.mlp_norm(x)) + return x + + +class GPT(nn.Module): + def __init__( + self, + vocab_size: int, + num_layers: int, + model_dim: int, + num_heads: int, + num_kv_heads: int, + mlp_mult: int, + tie_embeddings: bool, + tied_embed_init_std: float, + logit_softcap: float, + rope_base: float, + qk_gain_init: float, + ): + super().__init__() + if logit_softcap <= 0.0: + raise ValueError(f"logit_softcap must be positive, got {logit_softcap}") + self.tie_embeddings = tie_embeddings + self.tied_embed_init_std = tied_embed_init_std + self.logit_softcap = logit_softcap + self.tok_emb = nn.Embedding(vocab_size, model_dim) + self.num_encoder_layers = num_layers // 2 + self.num_decoder_layers = num_layers - self.num_encoder_layers + self.num_skip_weights = min(self.num_encoder_layers, self.num_decoder_layers) + self.skip_weights = nn.Parameter(torch.ones(self.num_skip_weights, model_dim, dtype=torch.float32)) + self.blocks = nn.ModuleList( + [ + Block( + model_dim, + num_heads, + num_kv_heads, + mlp_mult, + rope_base, + qk_gain_init, + ) + for i in range(num_layers) + ] + ) + self.final_norm = RMSNorm() + self.lm_head = None if tie_embeddings else CastedLinear(model_dim, vocab_size, bias=False) + if self.lm_head is not None: + self.lm_head._zero_init = True + self._init_weights() + + def _init_weights(self) -> None: + if self.tie_embeddings: + nn.init.normal_(self.tok_emb.weight, mean=0.0, std=self.tied_embed_init_std) + for module in self.modules(): + if isinstance(module, nn.Linear) and getattr(module, "_zero_init", False): + nn.init.zeros_(module.weight) + + def forward(self, input_ids: Tensor, target_ids: Tensor) -> Tensor: + x = self.tok_emb(input_ids) + x = F.rms_norm(x, (x.size(-1),)) + x0 = x + skips: list[Tensor] = [] + + # First half stores skips; second half reuses them in reverse order. 
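+        # With num_layers=9: blocks 0-3 push their outputs onto `skips`; decoder
+        # blocks 4-7 pop them in reverse (block 4 adds block 3's output, ...,
+        # block 7 adds block 0's), each gated by a learned per-channel
+        # skip_weight; the final block 8 receives no skip.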
+ for i in range(self.num_encoder_layers): + x = self.blocks[i](x, x0) + skips.append(x) + for i in range(self.num_decoder_layers): + if skips: + x = x + self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips.pop() + x = self.blocks[self.num_encoder_layers + i](x, x0) + + x = self.final_norm(x).reshape(-1, x.size(-1)) + targets = target_ids.reshape(-1) + if self.tie_embeddings: + logits_proj = F.linear(x, self.tok_emb.weight) + else: + if self.lm_head is None: + raise RuntimeError("lm_head is required when tie_embeddings=False") + logits_proj = self.lm_head(x) + logits = self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap) + return F.cross_entropy(logits.float(), targets, reduction="mean") + + def forward_logits(self, input_ids: Tensor) -> Tensor: + x = self.tok_emb(input_ids) + x = F.rms_norm(x, (x.size(-1),)) + x0 = x + skips: list[Tensor] = [] + for i in range(self.num_encoder_layers): + x = self.blocks[i](x, x0) + skips.append(x) + for i in range(self.num_decoder_layers): + if skips: + x = x + self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips.pop() + x = self.blocks[self.num_encoder_layers + i](x, x0) + x = self.final_norm(x) + if self.tie_embeddings: + logits_proj = F.linear(x, self.tok_emb.weight) + else: + logits_proj = self.lm_head(x) + return self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap) + + +# ----------------------------- +# TRAINING +# ----------------------------- + +def main() -> None: + global zeropower_via_newtonschulz5 + + code = Path(__file__).read_text(encoding="utf-8") + args = Hyperparameters() + zeropower_via_newtonschulz5 = torch.compile(zeropower_via_newtonschulz5) + + # ----------------------------- + # DISTRIBUTED + CUDA SETUP + # ----------------------------- + + distributed = "RANK" in os.environ and "WORLD_SIZE" in os.environ + rank = int(os.environ.get("RANK", "0")) + world_size = int(os.environ.get("WORLD_SIZE", "1")) + local_rank = int(os.environ.get("LOCAL_RANK", "0")) + if world_size <= 0: + raise ValueError(f"WORLD_SIZE must be positive, got {world_size}") + if 8 % world_size != 0: + raise ValueError(f"WORLD_SIZE={world_size} must divide 8 so grad_accum_steps stays integral") + grad_accum_steps = 8 // world_size + grad_scale = 1.0 / grad_accum_steps + if not torch.cuda.is_available(): + raise RuntimeError("CUDA is required") + device = torch.device("cuda", local_rank) + torch.cuda.set_device(device) + if distributed: + dist.init_process_group(backend="nccl", device_id=device) + dist.barrier() + master_process = rank == 0 + + # Fast math knobs + torch.backends.cuda.matmul.allow_tf32 = True + torch.backends.cudnn.allow_tf32 = True + from torch.backends.cuda import enable_cudnn_sdp, enable_flash_sdp, enable_math_sdp, enable_mem_efficient_sdp + + enable_cudnn_sdp(False) + enable_flash_sdp(True) + enable_mem_efficient_sdp(False) + enable_math_sdp(False) + + logfile = None + if master_process: + os.makedirs("logs", exist_ok=True) + logfile = f"logs/{args.run_id}.txt" + print(logfile) + + def log0(msg: str, console: bool = True) -> None: + if not master_process: + return + if console: + print(msg) + if logfile is not None: + with open(logfile, "a", encoding="utf-8") as f: + print(msg, file=f) + + log0(code, console=False) + log0("=" * 100, console=False) + log0(f"Running Python {sys.version}", console=False) + log0(f"Running PyTorch {torch.__version__}", console=False) + log0( + subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, check=False).stdout, + console=False, + ) + 
log0("=" * 100, console=False) + + # ----------------------------- + # TOKENIZER + VALIDATION METRIC SETUP + # ----------------------------- + + random.seed(args.seed) + np.random.seed(args.seed) + torch.manual_seed(args.seed) + torch.cuda.manual_seed_all(args.seed) + + if not args.tokenizer_path.endswith(".model"): + raise ValueError(f"Script only setup for SentencePiece .model file: {args.tokenizer_path}") + sp = spm.SentencePieceProcessor(model_file=args.tokenizer_path) + if int(sp.vocab_size()) != args.vocab_size: + raise ValueError( + f"VOCAB_SIZE={args.vocab_size} does not match tokenizer vocab_size={int(sp.vocab_size())}" + ) + dataset_dir = Path(args.data_path).resolve() + actual_train_files = len(list(dataset_dir.glob("fineweb_train_*.bin"))) + val_tokens = load_validation_tokens(args.val_files, args.train_seq_len) + base_bytes_lut, has_leading_space_lut, is_boundary_token_lut = build_sentencepiece_luts( + sp, args.vocab_size, device + ) + log0(f"val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path={args.tokenizer_path}") + log0(f"train_loader:dataset:{dataset_dir.name} train_shards:{actual_train_files}") + log0(f"val_loader:shards pattern={args.val_files} tokens:{val_tokens.numel() - 1}") + + # ----------------------------- + # MODEL + OPTIMIZER SETUP + # ----------------------------- + + base_model = GPT( + vocab_size=args.vocab_size, + num_layers=args.num_layers, + model_dim=args.model_dim, + num_heads=args.num_heads, + num_kv_heads=args.num_kv_heads, + mlp_mult=args.mlp_mult, + tie_embeddings=args.tie_embeddings, + tied_embed_init_std=args.tied_embed_init_std, + logit_softcap=args.logit_softcap, + rope_base=args.rope_base, + qk_gain_init=args.qk_gain_init, + ).to(device).bfloat16() + for module in base_model.modules(): + if isinstance(module, CastedLinear): + module.float() + restore_low_dim_params_to_fp32(base_model) + compiled_model = torch.compile(base_model, dynamic=False, fullgraph=True) + model: nn.Module = DDP(compiled_model, device_ids=[local_rank], broadcast_buffers=False) if distributed else compiled_model + + # Optimizer split: + # - token embedding (Adam) uses EMBED_LR + # - untied lm_head (Adam) uses HEAD_LR + # - matrix params in transformer blocks use MATRIX_LR via Muon + # - vectors/scalars use SCALAR_LR via Adam + block_named_params = list(base_model.blocks.named_parameters()) + matrix_params = [ + p + for name, p in block_named_params + if p.ndim == 2 and not any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) + ] + scalar_params = [ + p + for name, p in block_named_params + if p.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) + ] + if base_model.skip_weights.numel() > 0: + scalar_params.append(base_model.skip_weights) + token_lr = args.tied_embed_lr if args.tie_embeddings else args.embed_lr + optimizer_tok = torch.optim.Adam( + [{"params": [base_model.tok_emb.weight], "lr": token_lr, "base_lr": token_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=True, + ) + optimizer_muon = Muon( + matrix_params, + lr=args.matrix_lr, + momentum=args.muon_momentum, + backend_steps=args.muon_backend_steps, + ) + for group in optimizer_muon.param_groups: + group["base_lr"] = args.matrix_lr + optimizer_scalar = torch.optim.Adam( + [{"params": scalar_params, "lr": args.scalar_lr, "base_lr": args.scalar_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=True, + ) + optimizers: list[torch.optim.Optimizer] = [optimizer_tok, optimizer_muon, optimizer_scalar] + if base_model.lm_head is not None: 
+ optimizer_head = torch.optim.Adam( + [{"params": [base_model.lm_head.weight], "lr": args.head_lr, "base_lr": args.head_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=True, + ) + optimizers.insert(1, optimizer_head) + + n_params = sum(p.numel() for p in base_model.parameters()) + log0(f"model_params:{n_params}") + log0(f"world_size:{world_size} grad_accum_steps:{grad_accum_steps}") + log0("sdp_backends:cudnn=False flash=True mem_efficient=False math=False") + log0(f"attention_mode:gqa num_heads:{args.num_heads} num_kv_heads:{args.num_kv_heads}") + log0( + f"tie_embeddings:{args.tie_embeddings} embed_lr:{token_lr} " + f"head_lr:{args.head_lr if base_model.lm_head is not None else 0.0} " + f"matrix_lr:{args.matrix_lr} scalar_lr:{args.scalar_lr}" + ) + log0( + f"train_batch_tokens:{args.train_batch_tokens} train_seq_len:{args.train_seq_len} " + f"iterations:{args.iterations} warmup_steps:{args.warmup_steps} " + f"max_wallclock_seconds:{args.max_wallclock_seconds:.3f}" + ) + log0(f"seed:{args.seed}") + + # ----------------------------- + # DATA LOADER & MODEL WARMUP + # ----------------------------- + + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + + def zero_grad_all() -> None: + for opt in optimizers: + opt.zero_grad(set_to_none=True) + + max_wallclock_ms = 1000.0 * args.max_wallclock_seconds if args.max_wallclock_seconds > 0 else None + + def lr_mul(step: int, elapsed_ms: float) -> float: + if args.warmdown_iters <= 0: + return 1.0 + if max_wallclock_ms is None: + warmdown_start = max(args.iterations - args.warmdown_iters, 0) + return max((args.iterations - step) / max(args.warmdown_iters, 1), 0.0) if warmdown_start <= step < args.iterations else 1.0 + step_ms = elapsed_ms / max(step, 1) + warmdown_ms = args.warmdown_iters * step_ms + remaining_ms = max(max_wallclock_ms - elapsed_ms, 0.0) + return remaining_ms / max(warmdown_ms, 1e-9) if remaining_ms <= warmdown_ms else 1.0 + + # Warmup primes the compiled forward/backward/optimizer paths, then we restore the + # initial weights/optimizer state so measured training starts from the true init. 
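+    # Because every weight and optimizer tensor is restored afterwards, warmup
+    # only pays for torch.compile graphs, CUDA allocator growth, and collective
+    # communication setup; the token stream is also rebuilt below so the measured
+    # run reads the data from the beginning.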
+ if args.warmup_steps > 0: + initial_model_state = {name: tensor.detach().cpu().clone() for name, tensor in base_model.state_dict().items()} + initial_optimizer_states = [copy.deepcopy(opt.state_dict()) for opt in optimizers] + model.train() + for warmup_step in range(args.warmup_steps): + zero_grad_all() + for micro_step in range(grad_accum_steps): + if distributed: + model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 + x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + warmup_loss = model(x, y) + (warmup_loss * grad_scale).backward() + for opt in optimizers: + opt.step() + zero_grad_all() + if args.warmup_steps <= 20 or (warmup_step + 1) % 10 == 0 or warmup_step + 1 == args.warmup_steps: + log0(f"warmup_step:{warmup_step + 1}/{args.warmup_steps}") + base_model.load_state_dict(initial_model_state, strict=True) + for opt, state in zip(optimizers, initial_optimizer_states, strict=True): + opt.load_state_dict(state) + zero_grad_all() + if distributed: + model.require_backward_grad_sync = True + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + + # ----------------------------- + # MAIN TRAINING LOOP + # ----------------------------- + + training_time_ms = 0.0 + stop_after_step: int | None = None + torch.cuda.synchronize() + t0 = time.perf_counter() + + step = 0 + while True: + last_step = step == args.iterations or (stop_after_step is not None and step >= stop_after_step) + + should_validate = last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0) + if should_validate: + torch.cuda.synchronize() + training_time_ms += 1000.0 * (time.perf_counter() - t0) + val_loss, val_bpb = eval_val( + args, + model, + rank, + world_size, + device, + grad_accum_steps, + val_tokens, + base_bytes_lut, + has_leading_space_lut, + is_boundary_token_lut, + ) + log0( + f"step:{step}/{args.iterations} val_loss:{val_loss:.4f} val_bpb:{val_bpb:.4f} " + f"train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms / max(step, 1):.2f}ms" + ) + torch.cuda.synchronize() + t0 = time.perf_counter() + + if last_step: + if stop_after_step is not None and step < args.iterations: + log0( + f"stopping_early: wallclock_cap train_time:{training_time_ms:.0f}ms " + f"step:{step}/{args.iterations}" + ) + break + + elapsed_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + scale = lr_mul(step, elapsed_ms) + zero_grad_all() + train_loss = torch.zeros((), device=device) + for micro_step in range(grad_accum_steps): + if distributed: + model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 + x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + loss = model(x, y) + train_loss += loss.detach() + (loss * grad_scale).backward() + train_loss /= grad_accum_steps + + frac = min(step / args.muon_momentum_warmup_steps, 1.0) if args.muon_momentum_warmup_steps > 0 else 1.0 + muon_momentum = (1 - frac) * args.muon_momentum_warmup_start + frac * args.muon_momentum + for group in optimizer_muon.param_groups: + group["momentum"] = muon_momentum + + for opt in optimizers: + for group in opt.param_groups: + group["lr"] = group["base_lr"] * scale + + if args.grad_clip_norm > 0: + torch.nn.utils.clip_grad_norm_(base_model.parameters(), args.grad_clip_norm) + for opt in optimizers: + opt.step() + zero_grad_all() + + 
step += 1 + approx_training_time_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + should_log_train = ( + args.train_log_every > 0 + and (step <= 10 or step % args.train_log_every == 0 or stop_after_step is not None) + ) + if should_log_train: + log0( + f"step:{step}/{args.iterations} train_loss:{train_loss.item():.4f} " + f"train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms / step:.2f}ms" + ) + + # Needed to sync whether we've reached the wallclock cap. + reached_cap = max_wallclock_ms is not None and approx_training_time_ms >= max_wallclock_ms + if distributed and max_wallclock_ms is not None: + reached_cap_tensor = torch.tensor(int(reached_cap), device=device) + dist.all_reduce(reached_cap_tensor, op=dist.ReduceOp.MAX) + reached_cap = bool(reached_cap_tensor.item()) + if stop_after_step is None and reached_cap: + stop_after_step = step + + log0( + f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB " + f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB" + ) + + # ----------------------------- + # SERIALIZATION + ROUNDTRIP VALIDATION + # ----------------------------- + # Save the raw state (useful for debugging/loading in PyTorch directly), then always produce + # the compressed int8+zlib artifact and validate the round-tripped weights. + + if master_process: + torch.save(base_model.state_dict(), "final_model.pt") + model_bytes = os.path.getsize("final_model.pt") + code_bytes = len(code.encode("utf-8")) + log0(f"Serialized model: {model_bytes} bytes") + log0(f"Code size: {code_bytes} bytes") + log0(f"Total submission size: {model_bytes + code_bytes} bytes") + + quant_obj, quant_stats = quantize_state_dict_int8(base_model.state_dict()) + quant_buf = io.BytesIO() + torch.save(quant_obj, quant_buf) + quant_raw = quant_buf.getvalue() + quant_blob = zlib.compress(quant_raw, level=9) + quant_raw_bytes = len(quant_raw) + if master_process: + with open("final_model.int8.ptz", "wb") as f: + f.write(quant_blob) + quant_file_bytes = os.path.getsize("final_model.int8.ptz") + code_bytes = len(code.encode("utf-8")) + ratio = quant_stats["baseline_tensor_bytes"] / max(quant_stats["int8_payload_bytes"], 1) + log0( + f"Serialized model int8+zlib: {quant_file_bytes} bytes " + f"(payload:{quant_stats['int8_payload_bytes']} raw_torch:{quant_raw_bytes} payload_ratio:{ratio:.2f}x)" + ) + log0(f"Total submission size int8+zlib: {quant_file_bytes + code_bytes} bytes") + + if distributed: + dist.barrier() + with open("final_model.int8.ptz", "rb") as f: + quant_blob_disk = f.read() + quant_state = torch.load(io.BytesIO(zlib.decompress(quant_blob_disk)), map_location="cpu") + base_model.load_state_dict(dequantize_state_dict_int8(quant_state), strict=True) + torch.cuda.synchronize() + t_qeval = time.perf_counter() + if args.eval_stride > 0 and args.eval_stride < args.train_seq_len: + q_val_loss, q_val_bpb = eval_val_sliding( + args, model, rank, world_size, device, + val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, + ) + else: + q_val_loss, q_val_bpb = eval_val( + args, model, rank, world_size, device, grad_accum_steps, + val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, + ) + torch.cuda.synchronize() + log0( + f"final_int8_zlib_roundtrip val_loss:{q_val_loss:.4f} val_bpb:{q_val_bpb:.4f} " + f"eval_time:{1000.0 * (time.perf_counter() - t_qeval):.0f}ms" + ) + log0(f"final_int8_zlib_roundtrip_exact val_loss:{q_val_loss:.8f} val_bpb:{q_val_bpb:.8f}") + + if distributed: + 
dist.destroy_process_group() + + +if __name__ == "__main__": + main() diff --git a/records/track_10min_16mb/2026-03-21_OptimizerTuning_SlideEval64/train_seed1338.log b/records/track_10min_16mb/2026-03-21_OptimizerTuning_SlideEval64/train_seed1338.log new file mode 100644 index 000000000..75e83a76d --- /dev/null +++ b/records/track_10min_16mb/2026-03-21_OptimizerTuning_SlideEval64/train_seed1338.log @@ -0,0 +1,114 @@ +logs/seed1338.txt +val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path=./data/tokenizers/fineweb_1024_bpe.model +train_loader:dataset:fineweb10B_sp1024 train_shards:80 +val_loader:shards pattern=./data/datasets/fineweb10B_sp1024/fineweb_val_*.bin tokens:62021632 +model_params:17059912 +world_size:8 grad_accum_steps:1 +sdp_backends:cudnn=False flash=True mem_efficient=False math=False +attention_mode:gqa num_heads:8 num_kv_heads:4 +tie_embeddings:True embed_lr:0.05 head_lr:0.0 matrix_lr:0.04 scalar_lr:0.02 +train_batch_tokens:524288 train_seq_len:2048 iterations:20000 warmup_steps:20 max_wallclock_seconds:600.000 +seed:1338 +warmup_step:1/20 +warmup_step:2/20 +warmup_step:3/20 +warmup_step:4/20 +warmup_step:5/20 +warmup_step:6/20 +warmup_step:7/20 +warmup_step:8/20 +warmup_step:9/20 +warmup_step:10/20 +warmup_step:11/20 +warmup_step:12/20 +warmup_step:13/20 +warmup_step:14/20 +warmup_step:15/20 +warmup_step:16/20 +warmup_step:17/20 +warmup_step:18/20 +warmup_step:19/20 +warmup_step:20/20 +step:0/20000 val_loss:6.9373 val_bpb:4.1086 train_time:0ms step_avg:0.02ms +step:1/20000 train_loss:6.9363 train_time:99ms step_avg:98.94ms +step:2/20000 train_loss:17.5902 train_time:145ms step_avg:72.48ms +step:3/20000 train_loss:13.0747 train_time:197ms step_avg:65.69ms +step:4/20000 train_loss:8.2157 train_time:249ms step_avg:62.24ms +step:5/20000 train_loss:6.3594 train_time:301ms step_avg:60.21ms +step:6/20000 train_loss:7.2604 train_time:353ms step_avg:58.86ms +step:7/20000 train_loss:6.2714 train_time:405ms step_avg:57.88ms +step:8/20000 train_loss:6.0885 train_time:457ms step_avg:57.12ms +step:9/20000 train_loss:5.9214 train_time:509ms step_avg:56.55ms +step:10/20000 train_loss:5.7799 train_time:561ms step_avg:56.09ms +step:200/20000 train_loss:2.7706 train_time:10485ms step_avg:52.43ms +step:400/20000 train_loss:2.3013 train_time:20902ms step_avg:52.25ms +step:600/20000 train_loss:2.4981 train_time:31315ms step_avg:52.19ms +step:800/20000 train_loss:2.2484 train_time:41729ms step_avg:52.16ms +step:1000/20000 train_loss:2.3402 train_time:52143ms step_avg:52.14ms +step:1200/20000 train_loss:2.3547 train_time:62557ms step_avg:52.13ms +step:1400/20000 train_loss:2.3859 train_time:72969ms step_avg:52.12ms +step:1600/20000 train_loss:2.0480 train_time:83383ms step_avg:52.11ms +step:1800/20000 train_loss:2.1621 train_time:93829ms step_avg:52.13ms +step:2000/20000 train_loss:2.2106 train_time:104242ms step_avg:52.12ms +step:2000/20000 val_loss:2.1919 val_bpb:1.2982 train_time:104253ms step_avg:52.13ms +step:2200/20000 train_loss:2.0281 train_time:114650ms step_avg:52.11ms +step:2400/20000 train_loss:2.1593 train_time:125061ms step_avg:52.11ms +step:2600/20000 train_loss:2.3770 train_time:135474ms step_avg:52.11ms +step:2800/20000 train_loss:2.1903 train_time:145888ms step_avg:52.10ms +step:3000/20000 train_loss:2.1827 train_time:156299ms step_avg:52.10ms +step:3200/20000 train_loss:2.1438 train_time:166715ms step_avg:52.10ms +step:3400/20000 train_loss:2.1122 train_time:177129ms step_avg:52.10ms +step:3600/20000 train_loss:2.0579 train_time:187539ms step_avg:52.09ms +step:3800/20000 
train_loss:2.1652 train_time:197946ms step_avg:52.09ms +step:4000/20000 train_loss:2.1270 train_time:208356ms step_avg:52.09ms +step:4000/20000 val_loss:2.1201 val_bpb:1.2556 train_time:208368ms step_avg:52.09ms +step:4200/20000 train_loss:2.1172 train_time:218832ms step_avg:52.10ms +step:4400/20000 train_loss:2.0586 train_time:229244ms step_avg:52.10ms +step:4600/20000 train_loss:1.9268 train_time:239660ms step_avg:52.10ms +step:4800/20000 train_loss:2.2117 train_time:250067ms step_avg:52.10ms +step:5000/20000 train_loss:1.9657 train_time:260479ms step_avg:52.10ms +step:5200/20000 train_loss:2.1212 train_time:270894ms step_avg:52.09ms +step:5400/20000 train_loss:2.1398 train_time:281310ms step_avg:52.09ms +step:5600/20000 train_loss:2.1262 train_time:291726ms step_avg:52.09ms +step:5800/20000 train_loss:2.0827 train_time:302140ms step_avg:52.09ms +step:6000/20000 train_loss:2.1639 train_time:312553ms step_avg:52.09ms +step:6000/20000 val_loss:2.0871 val_bpb:1.2361 train_time:312565ms step_avg:52.09ms +step:6200/20000 train_loss:2.0341 train_time:322968ms step_avg:52.09ms +step:6400/20000 train_loss:2.1101 train_time:333383ms step_avg:52.09ms +step:6600/20000 train_loss:2.0661 train_time:343786ms step_avg:52.09ms +step:6800/20000 train_loss:2.1296 train_time:354198ms step_avg:52.09ms +step:7000/20000 train_loss:2.1749 train_time:364612ms step_avg:52.09ms +step:7200/20000 train_loss:2.1403 train_time:375026ms step_avg:52.09ms +step:7400/20000 train_loss:2.0682 train_time:385435ms step_avg:52.09ms +step:7600/20000 train_loss:1.9445 train_time:395845ms step_avg:52.08ms +step:7800/20000 train_loss:2.0919 train_time:406258ms step_avg:52.08ms +step:8000/20000 train_loss:2.0608 train_time:416677ms step_avg:52.08ms +step:8000/20000 val_loss:2.0613 val_bpb:1.2208 train_time:416689ms step_avg:52.09ms +step:8200/20000 train_loss:2.1357 train_time:427089ms step_avg:52.08ms +step:8400/20000 train_loss:2.0750 train_time:437564ms step_avg:52.09ms +step:8600/20000 train_loss:2.0839 train_time:447977ms step_avg:52.09ms +step:8800/20000 train_loss:2.0465 train_time:458393ms step_avg:52.09ms +step:9000/20000 train_loss:1.9620 train_time:468802ms step_avg:52.09ms +step:9200/20000 train_loss:2.0247 train_time:479207ms step_avg:52.09ms +step:9400/20000 train_loss:2.0696 train_time:489620ms step_avg:52.09ms +step:9600/20000 train_loss:2.0885 train_time:500031ms step_avg:52.09ms +step:9800/20000 train_loss:1.9928 train_time:510441ms step_avg:52.09ms +step:10000/20000 train_loss:2.0533 train_time:520851ms step_avg:52.09ms +step:10000/20000 val_loss:2.0425 val_bpb:1.2097 train_time:520862ms step_avg:52.09ms +step:10200/20000 train_loss:2.0099 train_time:531263ms step_avg:52.08ms +step:10400/20000 train_loss:2.0272 train_time:541674ms step_avg:52.08ms +step:10600/20000 train_loss:1.9144 train_time:552092ms step_avg:52.08ms +step:10800/20000 train_loss:2.1234 train_time:562504ms step_avg:52.08ms +step:11000/20000 train_loss:2.0445 train_time:572918ms step_avg:52.08ms +step:11200/20000 train_loss:2.0099 train_time:583333ms step_avg:52.08ms +step:11400/20000 train_loss:1.9932 train_time:593741ms step_avg:52.08ms +step:11520/20000 val_loss:2.0319 val_bpb:1.2034 train_time:600031ms step_avg:52.09ms +stopping_early: wallclock_cap train_time:600031ms step:11520/20000 +peak memory allocated: 10121 MiB reserved: 10296 MiB +Serialized model: 67224983 bytes +Code size: 52684 bytes +Total submission size: 67277667 bytes +Serialized model int8+zlib: 15807067 bytes (payload:17178912 raw_torch:17224025 payload_ratio:3.91x) +Total 
submission size int8+zlib: 15859751 bytes +final_int8_zlib_roundtrip val_loss:2.0043 val_bpb:1.1871 eval_time:133507ms +final_int8_zlib_roundtrip_exact val_loss:2.00428472 val_bpb:1.18705345 diff --git a/records/track_10min_16mb/2026-03-21_OptimizerTuning_SlideEval64/train_seed1339.log b/records/track_10min_16mb/2026-03-21_OptimizerTuning_SlideEval64/train_seed1339.log new file mode 100644 index 000000000..b9fe69102 --- /dev/null +++ b/records/track_10min_16mb/2026-03-21_OptimizerTuning_SlideEval64/train_seed1339.log @@ -0,0 +1,114 @@ +logs/seed1339.txt +val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path=./data/tokenizers/fineweb_1024_bpe.model +train_loader:dataset:fineweb10B_sp1024 train_shards:80 +val_loader:shards pattern=./data/datasets/fineweb10B_sp1024/fineweb_val_*.bin tokens:62021632 +model_params:17059912 +world_size:8 grad_accum_steps:1 +sdp_backends:cudnn=False flash=True mem_efficient=False math=False +attention_mode:gqa num_heads:8 num_kv_heads:4 +tie_embeddings:True embed_lr:0.05 head_lr:0.0 matrix_lr:0.04 scalar_lr:0.02 +train_batch_tokens:524288 train_seq_len:2048 iterations:20000 warmup_steps:20 max_wallclock_seconds:600.000 +seed:1339 +warmup_step:1/20 +warmup_step:2/20 +warmup_step:3/20 +warmup_step:4/20 +warmup_step:5/20 +warmup_step:6/20 +warmup_step:7/20 +warmup_step:8/20 +warmup_step:9/20 +warmup_step:10/20 +warmup_step:11/20 +warmup_step:12/20 +warmup_step:13/20 +warmup_step:14/20 +warmup_step:15/20 +warmup_step:16/20 +warmup_step:17/20 +warmup_step:18/20 +warmup_step:19/20 +warmup_step:20/20 +step:0/20000 val_loss:6.9372 val_bpb:4.1086 train_time:0ms step_avg:0.01ms +step:1/20000 train_loss:6.9375 train_time:104ms step_avg:104.45ms +step:2/20000 train_loss:17.6419 train_time:148ms step_avg:73.79ms +step:3/20000 train_loss:13.2331 train_time:200ms step_avg:66.51ms +step:4/20000 train_loss:8.2973 train_time:251ms step_avg:62.84ms +step:5/20000 train_loss:6.3609 train_time:303ms step_avg:60.67ms +step:6/20000 train_loss:7.2163 train_time:355ms step_avg:59.22ms +step:7/20000 train_loss:6.2440 train_time:407ms step_avg:58.16ms +step:8/20000 train_loss:6.0475 train_time:459ms step_avg:57.38ms +step:9/20000 train_loss:5.9193 train_time:511ms step_avg:56.78ms +step:10/20000 train_loss:5.7879 train_time:563ms step_avg:56.30ms +step:200/20000 train_loss:2.7731 train_time:10492ms step_avg:52.46ms +step:400/20000 train_loss:2.2850 train_time:20923ms step_avg:52.31ms +step:600/20000 train_loss:2.4986 train_time:31339ms step_avg:52.23ms +step:800/20000 train_loss:2.2434 train_time:41757ms step_avg:52.20ms +step:1000/20000 train_loss:2.3354 train_time:52164ms step_avg:52.16ms +step:1200/20000 train_loss:2.3538 train_time:62572ms step_avg:52.14ms +step:1400/20000 train_loss:2.3855 train_time:73001ms step_avg:52.14ms +step:1600/20000 train_loss:2.0454 train_time:83440ms step_avg:52.15ms +step:1800/20000 train_loss:2.1597 train_time:93866ms step_avg:52.15ms +step:2000/20000 train_loss:2.2104 train_time:104272ms step_avg:52.14ms +step:2000/20000 val_loss:2.1921 val_bpb:1.2983 train_time:104283ms step_avg:52.14ms +step:2200/20000 train_loss:2.0295 train_time:114681ms step_avg:52.13ms +step:2400/20000 train_loss:2.1611 train_time:125089ms step_avg:52.12ms +step:2600/20000 train_loss:2.3811 train_time:135497ms step_avg:52.11ms +step:2800/20000 train_loss:2.1909 train_time:145913ms step_avg:52.11ms +step:3000/20000 train_loss:2.1838 train_time:156320ms step_avg:52.11ms +step:3200/20000 train_loss:2.1451 train_time:166733ms step_avg:52.10ms +step:3400/20000 train_loss:2.1102 
train_time:177143ms step_avg:52.10ms +step:3600/20000 train_loss:2.0616 train_time:187557ms step_avg:52.10ms +step:3800/20000 train_loss:2.1677 train_time:197966ms step_avg:52.10ms +step:4000/20000 train_loss:2.1282 train_time:208372ms step_avg:52.09ms +step:4000/20000 val_loss:2.1214 val_bpb:1.2564 train_time:208384ms step_avg:52.10ms +step:4200/20000 train_loss:2.1215 train_time:218847ms step_avg:52.11ms +step:4400/20000 train_loss:2.0589 train_time:229255ms step_avg:52.10ms +step:4600/20000 train_loss:1.9281 train_time:239672ms step_avg:52.10ms +step:4800/20000 train_loss:2.2115 train_time:250082ms step_avg:52.10ms +step:5000/20000 train_loss:1.9644 train_time:260494ms step_avg:52.10ms +step:5200/20000 train_loss:2.1211 train_time:270905ms step_avg:52.10ms +step:5400/20000 train_loss:2.1376 train_time:281312ms step_avg:52.09ms +step:5600/20000 train_loss:2.1270 train_time:291722ms step_avg:52.09ms +step:5800/20000 train_loss:2.0878 train_time:302128ms step_avg:52.09ms +step:6000/20000 train_loss:2.1638 train_time:312530ms step_avg:52.09ms +step:6000/20000 val_loss:2.0888 val_bpb:1.2371 train_time:312542ms step_avg:52.09ms +step:6200/20000 train_loss:2.0331 train_time:322936ms step_avg:52.09ms +step:6400/20000 train_loss:2.1113 train_time:333341ms step_avg:52.08ms +step:6600/20000 train_loss:2.0664 train_time:343743ms step_avg:52.08ms +step:6800/20000 train_loss:2.1335 train_time:354152ms step_avg:52.08ms +step:7000/20000 train_loss:2.1718 train_time:364561ms step_avg:52.08ms +step:7200/20000 train_loss:2.1408 train_time:374963ms step_avg:52.08ms +step:7400/20000 train_loss:2.0687 train_time:385363ms step_avg:52.08ms +step:7600/20000 train_loss:1.9505 train_time:395775ms step_avg:52.08ms +step:7800/20000 train_loss:2.0882 train_time:406179ms step_avg:52.07ms +step:8000/20000 train_loss:2.0613 train_time:416586ms step_avg:52.07ms +step:8000/20000 val_loss:2.0634 val_bpb:1.2221 train_time:416597ms step_avg:52.07ms +step:8200/20000 train_loss:2.1331 train_time:426994ms step_avg:52.07ms +step:8400/20000 train_loss:2.0777 train_time:437466ms step_avg:52.08ms +step:8600/20000 train_loss:2.0911 train_time:447872ms step_avg:52.08ms +step:8800/20000 train_loss:2.0460 train_time:458280ms step_avg:52.08ms +step:9000/20000 train_loss:1.9634 train_time:468692ms step_avg:52.08ms +step:9200/20000 train_loss:2.0240 train_time:479096ms step_avg:52.08ms +step:9400/20000 train_loss:2.0652 train_time:489502ms step_avg:52.07ms +step:9600/20000 train_loss:2.0887 train_time:499905ms step_avg:52.07ms +step:9800/20000 train_loss:1.9940 train_time:510308ms step_avg:52.07ms +step:10000/20000 train_loss:2.0535 train_time:520714ms step_avg:52.07ms +step:10000/20000 val_loss:2.0447 val_bpb:1.2110 train_time:520725ms step_avg:52.07ms +step:10200/20000 train_loss:2.0034 train_time:531119ms step_avg:52.07ms +step:10400/20000 train_loss:2.0233 train_time:541528ms step_avg:52.07ms +step:10600/20000 train_loss:1.9153 train_time:551940ms step_avg:52.07ms +step:10800/20000 train_loss:2.1185 train_time:562349ms step_avg:52.07ms +step:11000/20000 train_loss:2.0443 train_time:572760ms step_avg:52.07ms +step:11200/20000 train_loss:2.0084 train_time:583168ms step_avg:52.07ms +step:11400/20000 train_loss:1.9939 train_time:593580ms step_avg:52.07ms +step:11523/20000 val_loss:2.0339 val_bpb:1.2046 train_time:600017ms step_avg:52.07ms +stopping_early: wallclock_cap train_time:600017ms step:11523/20000 +peak memory allocated: 10122 MiB reserved: 10440 MiB +Serialized model: 67224983 bytes +Code size: 52684 bytes +Total submission 
size: 67277667 bytes +Serialized model int8+zlib: 15814796 bytes (payload:17178912 raw_torch:17224025 payload_ratio:3.91x) +Total submission size int8+zlib: 15867480 bytes +final_int8_zlib_roundtrip val_loss:2.0067 val_bpb:1.1885 eval_time:133326ms +final_int8_zlib_roundtrip_exact val_loss:2.00666857 val_bpb:1.18846530 diff --git a/train_gpt.py b/train_gpt.py index 0deb0565f..7e8153a82 100644 --- a/train_gpt.py +++ b/train_gpt.py @@ -52,10 +52,10 @@ class Hyperparameters: # Training length. iterations = int(os.environ.get("ITERATIONS", 20000)) - warmdown_iters = int(os.environ.get("WARMDOWN_ITERS", 1200)) + warmdown_iters = int(os.environ.get("WARMDOWN_ITERS", 10000)) warmup_steps = int(os.environ.get("WARMUP_STEPS", 20)) train_batch_tokens = int(os.environ.get("TRAIN_BATCH_TOKENS", 524_288)) - train_seq_len = int(os.environ.get("TRAIN_SEQ_LEN", 1024)) + train_seq_len = int(os.environ.get("TRAIN_SEQ_LEN", 2048)) max_wallclock_seconds = float(os.environ.get("MAX_WALLCLOCK_SECONDS", 600.0)) qk_gain_init = float(os.environ.get("QK_GAIN_INIT", 1.5)) @@ -69,6 +69,8 @@ class Hyperparameters: tie_embeddings = bool(int(os.environ.get("TIE_EMBEDDINGS", "1"))) rope_base = float(os.environ.get("ROPE_BASE", 10000.0)) logit_softcap = float(os.environ.get("LOGIT_SOFTCAP", 30.0)) + eval_stride = int(os.environ.get("EVAL_STRIDE", 64)) + eval_batch_seqs = int(os.environ.get("EVAL_BATCH_SEQS", 32)) # Optimizer hyperparameters. embed_lr = float(os.environ.get("EMBED_LR", 0.6)) @@ -76,15 +78,15 @@ class Hyperparameters: tied_embed_lr = float(os.environ.get("TIED_EMBED_LR", 0.05)) tied_embed_init_std = float(os.environ.get("TIED_EMBED_INIT_STD", 0.005)) matrix_lr = float(os.environ.get("MATRIX_LR", 0.04)) - scalar_lr = float(os.environ.get("SCALAR_LR", 0.04)) + scalar_lr = float(os.environ.get("SCALAR_LR", 0.02)) muon_momentum = float(os.environ.get("MUON_MOMENTUM", 0.95)) - muon_backend_steps = int(os.environ.get("MUON_BACKEND_STEPS", 5)) + muon_backend_steps = int(os.environ.get("MUON_BACKEND_STEPS", 10)) muon_momentum_warmup_start = float(os.environ.get("MUON_MOMENTUM_WARMUP_START", 0.85)) muon_momentum_warmup_steps = int(os.environ.get("MUON_MOMENTUM_WARMUP_STEPS", 500)) beta1 = float(os.environ.get("BETA1", 0.9)) - beta2 = float(os.environ.get("BETA2", 0.95)) + beta2 = float(os.environ.get("BETA2", 0.99)) adam_eps = float(os.environ.get("ADAM_EPS", 1e-8)) - grad_clip_norm = float(os.environ.get("GRAD_CLIP_NORM", 0.0)) + grad_clip_norm = float(os.environ.get("GRAD_CLIP_NORM", 1.0)) # ----------------------------- # MUON OPTIMIZER @@ -277,6 +279,102 @@ def eval_val( model.train() return float(val_loss.item()), float(bits_per_token * tokens_per_byte) + +def eval_val_sliding( + args: Hyperparameters, + model: nn.Module, + rank: int, + world_size: int, + device: torch.device, + val_tokens: Tensor, + base_bytes_lut: Tensor, + has_leading_space_lut: Tensor, + is_boundary_token_lut: Tensor, +) -> tuple[float, float]: + """Sliding window evaluation: each token scored with maximum context. + + Windows of train_seq_len advance by eval_stride. Only the last stride tokens + per window contribute to the score (first window scores all). Every token + in the validation set is scored exactly once. 
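+
+    Example (seq_len=2048, stride=64): the first window scores all 2048 of its
+    positions; the window starting at 64 scores the 64 new positions
+    2048..2111, the window starting at 128 scores 2112..2175, and so on. Every
+    position scored after the first window therefore has at least
+    seq_len - stride = 1984 tokens of preceding context.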
+ """ + seq_len = args.train_seq_len + stride = args.eval_stride + batch_seqs = args.eval_batch_seqs + total_tokens = val_tokens.numel() - 1 + + # Build window starts; skip any too short to score a full stride + window_starts = [ws for ws in range(0, total_tokens, stride) + if min(ws + seq_len, total_tokens) - ws >= stride] + total_windows = len(window_starts) + + # Distribute across ranks + my_s = (total_windows * rank) // world_size + my_e = (total_windows * (rank + 1)) // world_size + my_windows = window_starts[my_s:my_e] + + loss_sum = torch.zeros((), device=device, dtype=torch.float64) + token_count = torch.zeros((), device=device, dtype=torch.float64) + byte_count = torch.zeros((), device=device, dtype=torch.float64) + + # Access the base model (unwrap DDP/compile if needed) + base = model + while hasattr(base, "module"): + base = base.module + while hasattr(base, "_orig_mod"): + base = base._orig_mod + + base.eval() + with torch.inference_mode(): + for bi in range(0, len(my_windows), batch_seqs): + batch_ws = my_windows[bi:bi + batch_seqs] + bsz = len(batch_ws) + + x_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) + y_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) + wlens: list[int] = [] + + for i, ws in enumerate(batch_ws): + end = min(ws + seq_len, total_tokens) + wlen = end - ws + wlens.append(wlen) + chunk = val_tokens[ws:end + 1].to(dtype=torch.int64, device=device) + x_batch[i, :wlen] = chunk[:-1] + y_batch[i, :wlen] = chunk[1:] + + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + logits = base.forward_logits(x_batch) + + nll = F.cross_entropy( + logits.reshape(-1, logits.size(-1)).float(), + y_batch.reshape(-1), + reduction="none", + ).reshape(bsz, seq_len) + + for i, ws in enumerate(batch_ws): + wlen = wlens[i] + # First window scores all tokens; others score last stride + s = 0 if ws == 0 else wlen - stride + scored_nll = nll[i, s:wlen].to(torch.float64) + loss_sum += scored_nll.sum() + token_count += float(wlen - s) + tgt = y_batch[i, s:wlen] + prev = x_batch[i, s:wlen] + tb = base_bytes_lut[tgt].to(torch.float64) + tb += (has_leading_space_lut[tgt] & ~is_boundary_token_lut[prev]).to(torch.float64) + byte_count += tb.sum() + + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(loss_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(token_count, op=dist.ReduceOp.SUM) + dist.all_reduce(byte_count, op=dist.ReduceOp.SUM) + + val_loss = loss_sum / token_count + bits_per_token = val_loss.item() / math.log(2.0) + tokens_per_byte = token_count.item() / byte_count.item() + base.train() + return float(val_loss.item()), float(bits_per_token * tokens_per_byte) + + # ----------------------------- # POST-TRAINING QUANTIZATION # ----------------------------- @@ -370,6 +468,8 @@ def quantize_state_dict_int8(state_dict: dict[str, Tensor]): # Small float tensors are cheap enough to keep directly. We still downcast # fp32/bf16 passthrough tensors to fp16 so metadata does not dominate size. + # The tied embedding is kept in fp16 — it serves as both input and output head + # and is extremely sensitive to quantization error. 
     if t.numel() <= INT8_KEEP_FLOAT_MAX_NUMEL:
         kept = keep_float_tensor(name, t, passthrough_orig_dtypes)
         passthrough[name] = kept
@@ -723,6 +823,25 @@ def forward(self, input_ids: Tensor, target_ids: Tensor) -> Tensor:
         logits = self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap)
         return F.cross_entropy(logits.float(), targets, reduction="mean")
 
+    def forward_logits(self, input_ids: Tensor) -> Tensor:
+        # Same pass as forward(), but returns the softcapped logits instead of
+        # the mean cross-entropy; the sliding-window eval needs per-token
+        # losses restricted to each window's scored region.
+        x = self.tok_emb(input_ids)
+        x = F.rms_norm(x, (x.size(-1),))
+        x0 = x
+        skips: list[Tensor] = []
+        for i in range(self.num_encoder_layers):
+            x = self.blocks[i](x, x0)
+            skips.append(x)
+        for i in range(self.num_decoder_layers):
+            if skips:
+                x = x + self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips.pop()
+            x = self.blocks[self.num_encoder_layers + i](x, x0)
+        x = self.final_norm(x)
+        if self.tie_embeddings:
+            logits_proj = F.linear(x, self.tok_emb.weight)
+        else:
+            logits_proj = self.lm_head(x)
+        return self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap)
+
 
 # -----------------------------
 # TRAINING
@@ -1099,18 +1218,16 @@ def lr_mul(step: int, elapsed_ms: float) -> float:
     base_model.load_state_dict(dequantize_state_dict_int8(quant_state), strict=True)
     torch.cuda.synchronize()
     t_qeval = time.perf_counter()
-    q_val_loss, q_val_bpb = eval_val(
-        args,
-        model,
-        rank,
-        world_size,
-        device,
-        grad_accum_steps,
-        val_tokens,
-        base_bytes_lut,
-        has_leading_space_lut,
-        is_boundary_token_lut,
-    )
+    if args.eval_stride > 0 and args.eval_stride < args.train_seq_len:
+        q_val_loss, q_val_bpb = eval_val_sliding(
+            args, model, rank, world_size, device,
+            val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut,
+        )
+    else:
+        q_val_loss, q_val_bpb = eval_val(
+            args, model, rank, world_size, device, grad_accum_steps,
+            val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut,
+        )
     torch.cuda.synchronize()
     log0(
         f"final_int8_zlib_roundtrip val_loss:{q_val_loss:.4f} val_bpb:{q_val_bpb:.4f} "