diff --git a/records/track_10min_16mb/2026-03-27_DominationV2_BigramCache_TTT/README.md b/records/track_10min_16mb/2026-03-27_DominationV2_BigramCache_TTT/README.md new file mode 100644 index 000000000..a42773d0a --- /dev/null +++ b/records/track_10min_16mb/2026-03-27_DominationV2_BigramCache_TTT/README.md @@ -0,0 +1,83 @@ +# DominationV2 + BOS-Reset Bigram Cache + TTT + +**val_bpb: 1.1382** (3-seed mean, std 0.0010) | **~15.5 MB** | 8xH100 SXM + +## Results (8xH100 80GB SXM, PyTorch 2.9.1+cu128) + +| Seed | step_avg | steps | val_bpb | Artifact | +|------|----------|-------|---------|----------| +| 1337 | 69.7ms | 8,611 | **1.1371** | 15,504,722 | +| 42 | 69.8ms | 8,605 | **1.1385** | 15,579,418 | +| 2025 | 69.7ms | 8,621 | **1.1389** | 15,505,762 | +| **Mean** | **69.7ms** | **8,612** | **1.1382** | | + +### Timing Budget + +| Phase | Time | +|-------|------| +| Training (8,611 steps @ 69.7ms) | 600s | +| TTT (3 epochs) | ~10s | +| Sliding window + cache eval | ~223s | +| **Total eval** | **~233s** | + +## BOS-Reset Bigram Cache + +An eval-time bigram cache applied during sliding window evaluation, after quantization roundtrip and TTT. + +For each scored token, the cache tracks bigram counts from already-scored tokens within the current document and blends with model probabilities: + +``` +p_final = (1 - alpha_eff) * p_model + alpha_eff * p_cache + +p_cache = count(prev, target) / count(prev) +alpha_eff = 0.20 * count / (count + 8) scales with observed data +alpha_eff *= (entropy / max_entropy) higher when model is uncertain +``` + +Cache resets at every BOS token (document boundary). Updated only after each token is scored (score-first, same ordering as TTT in PR #549). 
+ +## Architecture + +DominationV2 stack: + +| Component | Setting | +|-----------|---------| +| Layers | 11 (512d, 8H, 4KV) | +| MLP | 3x relu² | +| U-Net | 5 encoder + 6 decoder with skip connections | +| XSA | Last 4 layers | +| SmearGate | Per-dimension blend with previous token | +| BigramHash | 4096 buckets, dim=128 | +| OrthoInit | Orthogonal init with depth scaling | +| EMA | Decay=0.997 | +| Quantization | Mixed int6/int8 + zstd-22 | +| TTT | 3 epochs, lr=1e-4 | + +### Cache Settings + +| Parameter | Value | +|-----------|-------| +| CACHE_ALPHA | 0.20 | +| CACHE_TAU | 8.0 | +| CACHE_ENTROPY_POWER | 1.0 | +| Eval stride | 64 | + +## Run Command + +```bash +python3 data/cached_challenge_fineweb.py --variant sp1024 +pip install zstandard + +cd records/track_10min_16mb/2026-03-27_DominationV2_BigramCache_TTT + +DATA_PATH=../../data/datasets/fineweb10B_sp1024 \ +TOKENIZER_PATH=../../data/tokenizers/fineweb_1024_bpe.model \ +SEED=1337 \ +torchrun --standalone --nproc_per_node=8 train_gpt.py +``` + +## Credits + +- DominationV2 base: built on upstream PR #64 and PR #198 +- Bigram cache: inspired by classical cache language models (Grave et al., 2016) +- TTT: adapted from PR #461 diff --git a/records/track_10min_16mb/2026-03-27_DominationV2_BigramCache_TTT/submission.json b/records/track_10min_16mb/2026-03-27_DominationV2_BigramCache_TTT/submission.json new file mode 100644 index 000000000..df22c5d19 --- /dev/null +++ b/records/track_10min_16mb/2026-03-27_DominationV2_BigramCache_TTT/submission.json @@ -0,0 +1,17 @@ +{ + "author": "Shouryamaan Jain", + "github_id": "shouryamaanjain", + "name": "DominationV2 + BOS-Reset Bigram Cache + TTT", + "blurb": "DominationV2 (11L, 3x relu², XSA-4, EMA, SmearGate, BigramHash, OrthoInit, mixed int6/int8 + zstd-22) with eval-time BOS-reset bigram cache after TTT. Cache builds document-local bigram counts from already-scored tokens, blended with model probabilities gated by entropy. 
3-seed mean: 1.1382 (std 0.0010).", + "date": "2026-03-27", + "val_loss": 1.91991417, + "val_bpb": 1.13708132, + "bytes_total": 15504722, + "seeds": { + "1337": {"val_loss": 1.91991417, "val_bpb": 1.13708132, "bytes": 15504722}, + "42": {"val_loss": 1.92231620, "val_bpb": 1.13850394, "bytes": 15579418}, + "2025": {"val_loss": 1.92302940, "val_bpb": 1.13892633, "bytes": 15505762} + }, + "mean_bpb": 1.1382, + "std_bpb": 0.0010 +} diff --git a/records/track_10min_16mb/2026-03-27_DominationV2_BigramCache_TTT/train_gpt.py b/records/track_10min_16mb/2026-03-27_DominationV2_BigramCache_TTT/train_gpt.py new file mode 100644 index 000000000..da71e1a05 --- /dev/null +++ b/records/track_10min_16mb/2026-03-27_DominationV2_BigramCache_TTT/train_gpt.py @@ -0,0 +1,985 @@ +""" +Domination V2: Full frontier technique stack for parameter golf. + +Adds XSA (Exclusive Self Attention, arXiv:2603.09078) on last N layers, +EMA weight averaging (replacing SWA), and TTT (test-time training) at eval. +Built on V1: 11L MLP-3x, per-dim SmearGate, BigramHash, OrthoInit, Muon WD. 
+""" + +from __future__ import annotations + +import bisect +import copy +import glob +import io +import math +import os +import random +import subprocess +import sys +import time +import uuid +import zlib +from pathlib import Path + +import numpy as np +import sentencepiece as spm +import torch +import torch.distributed as dist +import torch.nn.functional as F +from torch import Tensor, nn +from torch.nn.parallel import DistributedDataParallel as DDP + +try: + import zstandard as zstd + HAS_ZSTD = True +except ImportError: + HAS_ZSTD = False + +BOS_ID = 1 + +# ----------------------------- +# HYPERPARAMETERS +# ----------------------------- + +class Hyperparameters: + data_path = os.environ.get("DATA_PATH", "./data/datasets/fineweb10B_sp1024") + train_files = os.path.join(data_path, "fineweb_train_*.bin") + val_files = os.path.join(data_path, "fineweb_val_*.bin") + tokenizer_path = os.environ.get("TOKENIZER_PATH", "./data/tokenizers/fineweb_1024_bpe.model") + run_id = os.environ.get("RUN_ID", str(uuid.uuid4())) + seed = int(os.environ.get("SEED", 1337)) + val_batch_size = int(os.environ.get("VAL_BATCH_SIZE", 524_288)) + val_loss_every = int(os.environ.get("VAL_LOSS_EVERY", 500)) + train_log_every = int(os.environ.get("TRAIN_LOG_EVERY", 100)) + iterations = int(os.environ.get("ITERATIONS", 20000)) + warmdown_iters = int(os.environ.get("WARMDOWN_ITERS", 3000)) + warmup_steps = int(os.environ.get("WARMUP_STEPS", 20)) + train_batch_tokens = int(os.environ.get("TRAIN_BATCH_TOKENS", 524_288)) + train_seq_len = int(os.environ.get("TRAIN_SEQ_LEN", 2048)) + max_wallclock_seconds = float(os.environ.get("MAX_WALLCLOCK_SECONDS", 600.0)) + qk_gain_init = float(os.environ.get("QK_GAIN_INIT", 1.5)) + vocab_size = int(os.environ.get("VOCAB_SIZE", 1024)) + num_layers = int(os.environ.get("NUM_LAYERS", 11)) + num_kv_heads = int(os.environ.get("NUM_KV_HEADS", 4)) + model_dim = int(os.environ.get("MODEL_DIM", 512)) + num_heads = int(os.environ.get("NUM_HEADS", 8)) + mlp_mult = 
int(os.environ.get("MLP_MULT", 3)) + tie_embeddings = bool(int(os.environ.get("TIE_EMBEDDINGS", "1"))) + rope_base = float(os.environ.get("ROPE_BASE", 10000.0)) + logit_softcap = float(os.environ.get("LOGIT_SOFTCAP", 30.0)) + embed_lr = float(os.environ.get("EMBED_LR", 0.6)) + head_lr = float(os.environ.get("HEAD_LR", 0.008)) + tied_embed_lr = float(os.environ.get("TIED_EMBED_LR", 0.035)) + tied_embed_init_std = float(os.environ.get("TIED_EMBED_INIT_STD", 0.005)) + matrix_lr = float(os.environ.get("MATRIX_LR", 0.025)) + scalar_lr = float(os.environ.get("SCALAR_LR", 0.025)) + xsa_last_n = int(os.environ.get("XSA_LAST_N", 4)) + ema_enabled = bool(int(os.environ.get("EMA_ENABLED", "1"))) + ema_decay = float(os.environ.get("EMA_DECAY", 0.997)) + ttt_enabled = bool(int(os.environ.get("TTT_ENABLED", "1"))) + ttt_epochs = int(os.environ.get("TTT_EPOCHS", 3)) + ttt_lr = float(os.environ.get("TTT_LR", 1e-4)) + cache_enabled = bool(int(os.environ.get("CACHE_ENABLED", "1"))) + cache_alpha = float(os.environ.get("CACHE_ALPHA", 0.20)) + cache_tau = float(os.environ.get("CACHE_TAU", 8.0)) + cache_entropy_power = float(os.environ.get("CACHE_ENTROPY_POWER", 1.0)) + muon_momentum = float(os.environ.get("MUON_MOMENTUM", 0.99)) + muon_backend_steps = int(os.environ.get("MUON_BACKEND_STEPS", 5)) + muon_momentum_warmup_start = float(os.environ.get("MUON_MOMENTUM_WARMUP_START", 0.92)) + muon_momentum_warmup_steps = int(os.environ.get("MUON_MOMENTUM_WARMUP_STEPS", 1500)) + beta1 = float(os.environ.get("BETA1", 0.9)) + beta2 = float(os.environ.get("BETA2", 0.95)) + adam_eps = float(os.environ.get("ADAM_EPS", 1e-8)) + grad_clip_norm = float(os.environ.get("GRAD_CLIP_NORM", 0.3)) + weight_decay = float(os.environ.get("WEIGHT_DECAY", 0.04)) + muon_wd = float(os.environ.get("MUON_WD", 0.04)) + eval_stride = int(os.environ.get("EVAL_STRIDE", 64)) + eval_batch_seqs = int(os.environ.get("EVAL_BATCH_SEQS", 32)) + bigram_vocab_size = int(os.environ.get("BIGRAM_VOCAB_SIZE", 4096)) + bigram_dim = 
int(os.environ.get("BIGRAM_DIM", 128)) + swa_enabled = bool(int(os.environ.get("SWA_ENABLED", "1"))) + swa_start_frac = float(os.environ.get("SWA_START_FRAC", 0.5)) + swa_every = int(os.environ.get("SWA_EVERY", 50)) + +# ----------------------------- +# MUON OPTIMIZER WITH WEIGHT DECAY +# ----------------------------- + +def zeropower_via_newtonschulz5(G: Tensor, steps: int = 10, eps: float = 1e-7) -> Tensor: + a, b, c = (3.4445, -4.7750, 2.0315) + X = G.bfloat16() + X /= X.norm() + eps + transposed = G.size(0) > G.size(1) + if transposed: + X = X.T + for _ in range(steps): + A = X @ X.T + B = b * A + c * A @ A + X = a * X + B @ X + return X.T if transposed else X + +class Muon(torch.optim.Optimizer): + def __init__(self, params, lr: float, momentum: float, backend_steps: int, + nesterov: bool = True, weight_decay: float = 0.0): + super().__init__(params, dict(lr=lr, momentum=momentum, backend_steps=backend_steps, + nesterov=nesterov, weight_decay=weight_decay)) + @torch.no_grad() + def step(self, closure=None): + loss = None + if closure is not None: + with torch.enable_grad(): loss = closure() + distributed = dist.is_available() and dist.is_initialized() + world_size = dist.get_world_size() if distributed else 1 + rank = dist.get_rank() if distributed else 0 + for group in self.param_groups: + params = group["params"] + if not params: continue + lr, momentum = group["lr"], group["momentum"] + backend_steps, nesterov = group["backend_steps"], group["nesterov"] + total_params = sum(int(p.numel()) for p in params) + updates_flat = torch.zeros(total_params, device=params[0].device, dtype=torch.bfloat16) + curr = 0 + for i, p in enumerate(params): + if i % world_size == rank and p.grad is not None: + g = p.grad + state = self.state[p] + if "momentum_buffer" not in state: + state["momentum_buffer"] = torch.zeros_like(g) + buf = state["momentum_buffer"] + buf.mul_(momentum).add_(g) + if nesterov: g = g.add(buf, alpha=momentum) + g = zeropower_via_newtonschulz5(g, 
steps=backend_steps) + g *= max(1, g.size(0) / g.size(1)) ** 0.5 + updates_flat[curr : curr + p.numel()] = g.reshape(-1) + curr += p.numel() + if distributed: dist.all_reduce(updates_flat, op=dist.ReduceOp.SUM) + wd = group.get("weight_decay", 0.0) + curr = 0 + for p in params: + g = updates_flat[curr : curr + p.numel()].view_as(p).to(dtype=p.dtype) + if wd > 0: p.data.mul_(1.0 - lr * wd) + p.add_(g, alpha=-lr) + curr += p.numel() + return loss + +# ----------------------------- +# TOKENIZER-AGNOSTIC EVALUATION +# ----------------------------- + +def build_sentencepiece_luts(sp: spm.SentencePieceProcessor, vocab_size: int, device: torch.device) -> tuple[Tensor, Tensor, Tensor]: + sp_vocab_size = int(sp.vocab_size()) + table_size = max(sp_vocab_size, vocab_size) + base_bytes_np = np.zeros((table_size,), dtype=np.int16) + has_leading_space_np = np.zeros((table_size,), dtype=np.bool_) + is_boundary_token_np = np.ones((table_size,), dtype=np.bool_) + for token_id in range(sp_vocab_size): + if sp.is_control(token_id) or sp.is_unknown(token_id) or sp.is_unused(token_id): continue + is_boundary_token_np[token_id] = False + if sp.is_byte(token_id): base_bytes_np[token_id] = 1; continue + piece = sp.id_to_piece(token_id) + if piece.startswith("\u2581"): has_leading_space_np[token_id] = True; piece = piece[1:] + base_bytes_np[token_id] = len(piece.encode("utf-8")) + return (torch.tensor(base_bytes_np, dtype=torch.int16, device=device), + torch.tensor(has_leading_space_np, dtype=torch.bool, device=device), + torch.tensor(is_boundary_token_np, dtype=torch.bool, device=device)) + +def load_validation_tokens(pattern: str, seq_len: int) -> Tensor: + files = [Path(p) for p in sorted(glob.glob(pattern))] + if not files: raise FileNotFoundError(f"No files found for pattern: {pattern}") + tokens = torch.cat([load_data_shard(file) for file in files]).contiguous() + usable = ((tokens.numel() - 1) // seq_len) * seq_len + if usable <= 0: raise ValueError(f"Validation split is too short 
for seq_len={seq_len}") + return tokens[: usable + 1] + +def find_docs(all_tokens: Tensor) -> list[tuple[int, int]]: + bos_positions = (all_tokens == BOS_ID).nonzero(as_tuple=True)[0].tolist() + docs = [] + for i, start in enumerate(bos_positions): + end = bos_positions[i + 1] + 1 if i + 1 < len(bos_positions) else int(all_tokens.numel()) + if end - start >= 2: docs.append((int(start), int(end - start))) + return docs + +def shard_docs_by_tokens(docs: list[tuple[int, int]], rank: int, world_size: int) -> list[tuple[int, int]]: + if rank < 0 or rank >= world_size: + raise ValueError(f"rank {rank} out of range for world_size={world_size}") + if not docs: + return [] + prefix = [0] + for _, doc_len in docs: + prefix.append(prefix[-1] + max(doc_len - 1, 0)) + total_pred_tokens = prefix[-1] + bounds = [0] + prev = 0 + for shard_idx in range(1, world_size): + target = (total_pred_tokens * shard_idx) / world_size + lo = bisect.bisect_left(prefix, target, lo=prev, hi=len(prefix)) + candidates = [idx for idx in (lo - 1, lo) if prev <= idx <= len(docs)] + if not candidates: + cut = prev + else: + cut = min(candidates, key=lambda idx: (abs(prefix[idx] - target), -idx)) + bounds.append(cut) + prev = cut + bounds.append(len(docs)) + return docs[bounds[rank]:bounds[rank + 1]] + +def iter_flat_score_windows(total_tokens: int, seq_len: int, stride: int): + if total_tokens <= 0: + return + first_end = min(seq_len, total_tokens) + yield 0, first_end, 0, first_end + scored_end = first_end + while scored_end < total_tokens: + next_end = min(scored_end + stride, total_tokens) + ws = max(0, next_end - seq_len) + yield ws, next_end - ws, scored_end, next_end + scored_end = next_end + +def eval_val(args, model, rank, world_size, device, grad_accum_steps, val_tokens, + base_bytes_lut, has_leading_space_lut, is_boundary_token_lut): + local_batch_tokens = args.val_batch_size // (world_size * grad_accum_steps) + if local_batch_tokens < args.train_seq_len: + raise ValueError(f"VAL_BATCH_SIZE 
too small for world={world_size} accum={grad_accum_steps} seq={args.train_seq_len}") + local_batch_seqs = local_batch_tokens // args.train_seq_len + total_seqs = (val_tokens.numel() - 1) // args.train_seq_len + seq_start = (total_seqs * rank) // world_size + seq_end = (total_seqs * (rank + 1)) // world_size + vls = torch.zeros((), device=device, dtype=torch.float64) + vtc = torch.zeros((), device=device, dtype=torch.float64) + vbc = torch.zeros((), device=device, dtype=torch.float64) + model.eval() + with torch.inference_mode(): + for bs in range(seq_start, seq_end, local_batch_seqs): + be = min(bs + local_batch_seqs, seq_end) + rs, re = bs * args.train_seq_len, be * args.train_seq_len + 1 + local = val_tokens[rs:re].to(device=device, dtype=torch.int64, non_blocking=True) + x, y = local[:-1].reshape(-1, args.train_seq_len), local[1:].reshape(-1, args.train_seq_len) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + bl = model(x, y).detach() + bc = float(y.numel()) + vls += bl.to(torch.float64) * bc; vtc += bc + pi, ti = x.reshape(-1), y.reshape(-1) + tb = base_bytes_lut[ti].to(dtype=torch.int16) + tb += (has_leading_space_lut[ti] & ~is_boundary_token_lut[pi]).to(dtype=torch.int16) + vbc += tb.to(torch.float64).sum() + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(vls, op=dist.ReduceOp.SUM); dist.all_reduce(vtc, op=dist.ReduceOp.SUM); dist.all_reduce(vbc, op=dist.ReduceOp.SUM) + vl = vls / vtc; bpt = vl.item() / math.log(2.0); tpb = vtc.item() / vbc.item() + model.train(); return float(vl.item()), float(bpt * tpb) + +def eval_val_sliding(args, base_model, rank, world_size, device, val_tokens, + base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, + stride, batch_seqs=32): + seq_len = args.train_seq_len + total_tokens = val_tokens.numel() - 1 + windows = list(iter_flat_score_windows(total_tokens, seq_len, stride)) + total_windows = len(windows) + my_s = (total_windows * rank) // world_size + my_e = 
(total_windows * (rank + 1)) // world_size + my_windows = windows[my_s:my_e] + ls = torch.zeros((), device=device, dtype=torch.float64) + tc = torch.zeros((), device=device, dtype=torch.float64) + bc = torch.zeros((), device=device, dtype=torch.float64) + base_model.eval() + with torch.inference_mode(): + for bi in range(0, len(my_windows), batch_seqs): + bw = my_windows[bi:bi + batch_seqs]; bsz = len(bw) + xb = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) + yb = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) + wlens = [] + score_starts = [] + for i, (ws, wl, score_start, _) in enumerate(bw): + wlens.append(wl); score_starts.append(score_start) + ch = val_tokens[ws:ws + wl + 1].to(dtype=torch.int64, device=device) + xb[i, :wl] = ch[:-1]; yb[i, :wl] = ch[1:] + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + logits = base_model.forward_logits(xb) + nll = F.cross_entropy(logits.reshape(-1, logits.size(-1)).float(), yb.reshape(-1), reduction="none").reshape(bsz, seq_len) + for i, (ws, _, _, _) in enumerate(bw): + wl = wlens[i]; s = score_starts[i] - ws + sn = nll[i, s:wl].to(torch.float64); ls += sn.sum(); tc += float(wl - s) + tgt, prev = yb[i, s:wl], xb[i, s:wl] + tb = base_bytes_lut[tgt].to(torch.float64) + tb += (has_leading_space_lut[tgt] & ~is_boundary_token_lut[prev]).to(torch.float64) + bc += tb.sum() + if rank == 0 and (bi // batch_seqs) % 50 == 0: + done = min(bi + batch_seqs, len(my_windows)); pct = done / len(my_windows) * 100 + rb = 0.0 + if tc.item() > 0: rl = (ls / tc).item(); rb = rl / math.log(2.0) * (tc.item() / bc.item()) + print(f" sliding_eval [{pct:5.1f}%] {done}/{len(my_windows)} windows running_bpb={rb:.6f}", flush=True) + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(ls, op=dist.ReduceOp.SUM); dist.all_reduce(tc, op=dist.ReduceOp.SUM); dist.all_reduce(bc, op=dist.ReduceOp.SUM) + vl = (ls / tc).item(); bpt = vl / math.log(2.0); tpb = tc.item() / bc.item() + base_model.train(); 
return vl, bpt * tpb + +def eval_val_sliding_cache(args, base_model, rank, world_size, device, val_tokens, docs, + base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, + stride, batch_seqs=32): + seq_len = args.train_seq_len + total_tokens = val_tokens.numel() - 1 + local_docs = shard_docs_by_tokens(docs, rank, world_size) + loss_sum = 0.0; token_count = 0.0; byte_count = 0.0 + if not local_docs: + ls = torch.tensor(loss_sum, device=device, dtype=torch.float64) + tc = torch.tensor(token_count, device=device, dtype=torch.float64) + bc = torch.tensor(byte_count, device=device, dtype=torch.float64) + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(ls, op=dist.ReduceOp.SUM); dist.all_reduce(tc, op=dist.ReduceOp.SUM); dist.all_reduce(bc, op=dist.ReduceOp.SUM) + vl = (ls / tc).item(); bpt = vl / math.log(2.0); tpb = tc.item() / bc.item() + base_model.train(); return vl, bpt * tpb + + local_start = local_docs[0][0] + local_end = local_docs[-1][0] + local_docs[-1][1] - 1 + local_windows = [] + for ws, wl, score_start, score_end in iter_flat_score_windows(total_tokens, seq_len, stride): + if score_end <= local_start: continue + if score_start >= local_end: break + ss = max(score_start, local_start); se = min(score_end, local_end) + if ss < se: local_windows.append((ws, wl, ss, se)) + + cache_active = args.cache_enabled + base_model.eval() + base_bytes_np = base_bytes_lut.cpu().numpy().astype(np.int64) + has_space_np = has_leading_space_lut.cpu().numpy().astype(np.bool_) + is_boundary_np = is_boundary_token_lut.cpu().numpy().astype(np.bool_) + max_entropy = math.log(base_bytes_np.shape[0]) + max_count = np.iinfo(np.uint16).max + row_totals = np.zeros(base_bytes_np.shape[0], dtype=np.int32) + counts = np.zeros((base_bytes_np.shape[0], base_bytes_np.shape[0]), dtype=np.uint16) + touched_pairs: list[tuple[int, int]] = [] + touched_rows: list[int] = [] + doc_idx = 0 + doc_end = local_docs[doc_idx][0] + local_docs[doc_idx][1] - 1 + + def 
reset_doc_cache(): + for p, y in touched_pairs: counts[p, y] = 0 + for p in touched_rows: row_totals[p] = 0 + touched_pairs.clear(); touched_rows.clear() + + with torch.inference_mode(): + for bi in range(0, len(local_windows), batch_seqs): + batch = local_windows[bi:bi + batch_seqs]; bsz = len(batch) + xb = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) + yb = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) + for i, (ws, wl, _, _) in enumerate(batch): + ch = val_tokens[ws:ws + wl + 1].to(dtype=torch.int64, device=device) + xb[i, :wl] = ch[:-1]; yb[i, :wl] = ch[1:] + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + logits = base_model.forward_logits(xb) + + if cache_active: + log_probs = F.log_softmax(logits.float(), dim=-1) + entropy = -(log_probs.exp() * log_probs).sum(dim=-1) + else: + nll = F.cross_entropy(logits.reshape(-1, logits.size(-1)).float(), yb.reshape(-1), reduction="none").reshape(bsz, seq_len) + + for i, (ws, _, score_start, score_end) in enumerate(batch): + lo = score_start - ws; hi = score_end - ws + if cache_active: + tgt = yb[i, lo:hi] + prev_np = xb[i, lo:hi].detach().cpu().numpy().astype(np.uint16, copy=True) + tgt_np = tgt.detach().cpu().numpy().astype(np.uint16, copy=True) + lp_np = log_probs[i, lo:hi].gather(1, tgt[:, None]).squeeze(1).detach().cpu().numpy().astype(np.float32, copy=True) + ent_np = entropy[i, lo:hi].detach().cpu().numpy().astype(np.float32, copy=True) + base_prob_np = np.exp(lp_np.astype(np.float64, copy=False)) + pos = score_start; off = 0 + while pos < score_end: + seg_end = min(score_end, doc_end) + take = seg_end - pos + ps = prev_np[off:off + take] + ys = tgt_np[off:off + take] + probs = base_prob_np[off:off + take] + ents = ent_np[off:off + take] + for j in range(take): + p = int(ps[j]); y = int(ys[j]) + total = int(row_totals[p]); seen = int(counts[p, y]); cache_p = (seen / total) if total > 0 else 0.0 + eff_alpha = args.cache_alpha * (total / (total + args.cache_tau)) if 
total > 0 else 0.0 + eff_alpha *= max(float(ents[j]) / max_entropy, 0.0) ** args.cache_entropy_power + mix_p = (1.0 - eff_alpha) * float(probs[j]) + eff_alpha * cache_p + loss_sum += -math.log(max(mix_p, 1e-30)) + if total == 0: touched_rows.append(p) + if seen == 0: touched_pairs.append((p, y)) + counts[p, y] = min(seen + 1, max_count) + row_totals[p] = total + 1 + tb = base_bytes_np[ys] + tb = tb + np.logical_and(has_space_np[ys], np.logical_not(is_boundary_np[ps])).astype(np.int64) + byte_count += float(tb.sum()); token_count += int(take) + pos = seg_end; off += take + if pos >= doc_end: + reset_doc_cache(); doc_idx += 1 + if doc_idx < len(local_docs): + doc_end = local_docs[doc_idx][0] + local_docs[doc_idx][1] - 1 + else: + sn = nll[i, lo:hi]; loss_sum += float(sn.sum().item()); token_count += float(hi - lo) + prev, tgt = xb[i, lo:hi], yb[i, lo:hi] + tb = base_bytes_lut[tgt].to(torch.float64) + tb += (has_leading_space_lut[tgt] & ~is_boundary_token_lut[prev]).to(torch.float64) + byte_count += float(tb.sum().item()) + if rank == 0 and (bi // batch_seqs) % 50 == 0: + done = min(bi + batch_seqs, len(local_windows)); pct = done / len(local_windows) * 100 + rb = (loss_sum / math.log(2.0)) / byte_count if byte_count > 0 else 0.0 + print(f" cache_eval [{pct:5.1f}%] {done}/{len(local_windows)} windows running_bpb={rb:.6f}", flush=True) + + ls = torch.tensor(loss_sum, device=device, dtype=torch.float64) + tc = torch.tensor(token_count, device=device, dtype=torch.float64) + bc = torch.tensor(byte_count, device=device, dtype=torch.float64) + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(ls, op=dist.ReduceOp.SUM); dist.all_reduce(tc, op=dist.ReduceOp.SUM); dist.all_reduce(bc, op=dist.ReduceOp.SUM) + vl = (ls / tc).item(); bpt = vl / math.log(2.0); tpb = tc.item() / bc.item() + base_model.train(); return vl, bpt * tpb + +# ----------------------------- +# POST-TRAINING QUANTIZATION +# ----------------------------- + +CONTROL_TENSOR_NAME_PATTERNS = 
tuple(p for p in os.environ.get("CONTROL_TENSOR_NAME_PATTERNS", + "attn_scale,attn_scales,mlp_scale,mlp_scales,resid_mix,resid_mixes,q_gain,skip_weight,skip_weights,smear,bigram.scale").split(",") if p) + +INT6_QUANT_RANGE = 31 +MIXED_QUANT_INT6_CATS = frozenset( + c.strip() for c in os.environ.get("MIXED_QUANT_INT6_CATS", "mlp,attn").split(",") if c.strip() +) +STE_QAT_ENABLED = bool(int(os.environ.get("STE_QAT_ENABLED", "0"))) +STE_QAT_RANGE = int(os.environ.get("STE_QAT_RANGE", INT6_QUANT_RANGE)) + +def _classify_param(name): + if "tok_emb" in name or "lm_head" in name: return "embed" + if ".mlp." in name: return "mlp" + if ".attn." in name or (".proj." in name and ".mlp." not in name): return "attn" + return "other" + +def _get_ste_range_for_param(name): + return STE_QAT_RANGE + +def quantize_int6_per_row(t): + t32 = t.float() + if t32.ndim == 2: + rm = t32.abs().amax(dim=1) + sc = (rm / 31.0).clamp_min(1.0 / 31.0).to(torch.float16) + q = torch.clamp(torch.round(t32 / sc.float()[:, None]), -32, 31).to(torch.int8) + return q, sc + am = t32.abs().max().item() + sc = torch.tensor(am / 31.0 if am > 0 else 1.0, dtype=torch.float16) + q = torch.clamp(torch.round(t32 / sc.float()), -32, 31).to(torch.int8) + return q, sc + +def quantize_int8_per_row(t): + t32 = t.float() + if t32.ndim == 2: + rm = t32.abs().amax(dim=1) + sc = (rm / 127.0).clamp_min(1e-8).to(torch.float16) + q = torch.clamp(torch.round(t32 / sc.float()[:, None]), -127, 127).to(torch.int8) + return q, sc + am = t32.abs().max().item() + sc = torch.tensor(am / 127.0 if am > 0 else 1.0, dtype=torch.float16) + q = torch.clamp(torch.round(t32 / sc.float()), -127, 127).to(torch.int8) + return q, sc + +def mixed_quantize(state_dict, int6_cats): + result, meta = {}, {} + for name, tensor in state_dict.items(): + t = tensor.detach().cpu().contiguous() + cat = _classify_param(name) + if not t.is_floating_point() or t.numel() <= 65536: + result[name] = t.to(torch.float16) if t.is_floating_point() else t + 
meta[name] = "passthrough" + continue + if any(p in name for p in CONTROL_TENSOR_NAME_PATTERNS): + result[name] = t.float() + meta[name] = "passthrough_ctrl" + continue + if cat in int6_cats and t.ndim >= 1: + q, s = quantize_int6_per_row(t) + result[name + ".q"] = q; result[name + ".scale"] = s + meta[name] = {"type": "int6"} + else: + q, s = quantize_int8_per_row(t) + result[name + ".q"] = q; result[name + ".scale"] = s + meta[name] = {"type": "int8"} + return result, meta + +def dequantize_mixed(result, meta, template_sd): + out = {} + for name, orig in template_sd.items(): + info = meta.get(name) + if info is None: continue + if info in ("passthrough", "passthrough_ctrl", "passthrough_fp16"): + t = result[name] + if t.dtype == torch.float16 and orig.dtype in (torch.float32, torch.bfloat16): + t = t.to(orig.dtype) + out[name] = t; continue + q, s = result[name + ".q"], result[name + ".scale"] + if s.ndim > 0: + out[name] = (q.float() * s.float().view(q.shape[0], *([1] * (q.ndim - 1)))).to(orig.dtype) + else: + out[name] = (q.float() * float(s.item())).to(orig.dtype) + return out + +# ----------------------------- +# DATA LOADING +# ----------------------------- + +def load_data_shard(file): + hb = 256 * np.dtype(" 0: + avail = self.tokens.numel() - self.pos + if avail <= 0: self._advance_file(); continue + k = min(remaining, avail); chunks.append(self.tokens[self.pos:self.pos + k]); self.pos += k; remaining -= k + return chunks[0] if len(chunks) == 1 else torch.cat(chunks) + +class DistributedTokenLoader: + def __init__(self, pattern, rank, world_size, device): + self.rank, self.world_size, self.device, self.stream = rank, world_size, device, TokenStream(pattern) + def next_batch(self, global_tokens, seq_len, grad_accum_steps): + lt = global_tokens // (self.world_size * grad_accum_steps); prs = lt + 1 + chunk = self.stream.take(prs * self.world_size); s = self.rank * prs + local = chunk[s:s + prs].to(dtype=torch.int64) + x, y = local[:-1].reshape(-1, seq_len), 
local[1:].reshape(-1, seq_len) + return x.to(self.device, non_blocking=True), y.to(self.device, non_blocking=True) + +# ----------------------------- +# TRANSFORMER MODULES +# ----------------------------- + +class RMSNorm(nn.Module): + def __init__(self, eps=None): super().__init__(); self.eps = eps + def forward(self, x): return F.rms_norm(x, (x.size(-1),), eps=self.eps) + +class CastedLinear(nn.Linear): + def forward(self, x): + w = self.weight.to(x.dtype) + if self.training and STE_QAT_ENABLED and w.ndim == 2: + with torch.no_grad(): + w32 = w.float(); rm = w32.abs().amax(dim=1).clamp_min(1e-8) + sc = rm / 31.0; wc = torch.clamp(w32, -rm[:, None], rm[:, None]) + wq = (torch.round(wc / sc[:, None]) * sc[:, None]).to(x.dtype) + w = w + (wq - w).detach() + bias = self.bias.to(x.dtype) if self.bias is not None else None + return F.linear(x, w, bias) + +def restore_low_dim_params_to_fp32(module): + with torch.no_grad(): + for name, param in module.named_parameters(): + if (param.ndim < 2 or any(p in name for p in CONTROL_TENSOR_NAME_PATTERNS)) and param.dtype != torch.float32: + param.data = param.data.float() + +class Rotary(nn.Module): + def __init__(self, dim, base=10000.0): + super().__init__() + self.register_buffer("inv_freq", 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim)), persistent=False) + self._seq_len_cached = 0; self._cos_cached = None; self._sin_cached = None + def forward(self, seq_len, device, dtype): + if self._cos_cached is None or self._seq_len_cached != seq_len or self._cos_cached.device != device: + t = torch.arange(seq_len, device=device, dtype=self.inv_freq.dtype) + freqs = torch.outer(t, self.inv_freq.to(device)) + self._cos_cached = freqs.cos()[None, None, :, :]; self._sin_cached = freqs.sin()[None, None, :, :] + self._seq_len_cached = seq_len + return self._cos_cached.to(dtype=dtype), self._sin_cached.to(dtype=dtype) + +def apply_rotary_emb(x, cos, sin): + h = x.size(-1) // 2; x1, x2 = x[..., :h], x[..., h:] + return 
torch.cat((x1 * cos + x2 * sin, x1 * (-sin) + x2 * cos), dim=-1) + +class CausalSelfAttention(nn.Module): + def __init__(self, dim, num_heads, num_kv_heads, rope_base, qk_gain_init): + super().__init__() + self.num_heads, self.num_kv_heads = num_heads, num_kv_heads; self.head_dim = dim // num_heads + kv_dim = num_kv_heads * self.head_dim + self.c_q = CastedLinear(dim, dim, bias=False); self.c_k = CastedLinear(dim, kv_dim, bias=False) + self.c_v = CastedLinear(dim, kv_dim, bias=False); self.proj = CastedLinear(dim, dim, bias=False) + self.proj._zero_init = True + self.q_gain = nn.Parameter(torch.full((num_heads,), qk_gain_init, dtype=torch.float32)) + self.rotary = Rotary(self.head_dim, base=rope_base) + self.use_xsa = False + + def _xsa_efficient(self, y, v): + B, T, H, D = y.shape + Hkv = v.size(2) + group = H // Hkv + y_g = y.reshape(B, T, Hkv, group, D) + vn = F.normalize(v, dim=-1).unsqueeze(-2) + proj = (y_g * vn).sum(dim=-1, keepdim=True) * vn + return (y_g - proj).reshape(B, T, H, D) + + def forward(self, x): + bsz, seqlen, dim = x.shape + q = self.c_q(x).reshape(bsz, seqlen, self.num_heads, self.head_dim).transpose(1, 2) + k = self.c_k(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) + v = self.c_v(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) + q, k = F.rms_norm(q, (q.size(-1),)), F.rms_norm(k, (k.size(-1),)) + cos, sin = self.rotary(seqlen, x.device, q.dtype) + q, k = apply_rotary_emb(q, cos, sin), apply_rotary_emb(k, cos, sin) + q = q * self.q_gain.to(dtype=q.dtype)[None, :, None, None] + y = F.scaled_dot_product_attention(q, k, v, attn_mask=None, is_causal=True, enable_gqa=(self.num_kv_heads != self.num_heads)) + y = y.transpose(1, 2).contiguous() + if self.use_xsa: + v_bthd = v.transpose(1, 2).contiguous() + y = self._xsa_efficient(y, v_bthd) + return self.proj(y.reshape(bsz, seqlen, dim)) + +class MLP(nn.Module): + def __init__(self, dim, mlp_mult): + super().__init__(); hidden = mlp_mult * dim + 
# NOTE(review): continuation of MLP.__init__ started in the previous chunk.
self.fc = CastedLinear(dim, hidden, bias=False); self.proj = CastedLinear(hidden, dim, bias=False)
        self.proj._zero_init = True
    def forward(self, x): return self.proj(torch.relu(self.fc(x)).square())

class SmearGate(nn.Module):
    """Per-dimension SmearGate (from PR #194): each dim has its own blend ratio."""
    def __init__(self, dim):
        # Zero-init gate => sigmoid(0) = 0.5 blend at start of training.
        super().__init__(); self.gate = nn.Parameter(torch.zeros(dim, dtype=torch.float32))
    def forward(self, x):
        g = torch.sigmoid(self.gate.to(dtype=x.dtype))[None, None, :]
        # Shift the sequence right by one (zeros at position 0) and blend.
        x_prev = torch.cat([torch.zeros_like(x[:, :1]), x[:, :-1]], dim=1)
        return (1 - g) * x + g * x_prev

class BigramHashEmbedding(nn.Module):
    """Embedding of hashed (prev, cur) token bigrams, added to token embeddings.

    Bigrams are hashed with a multiplicative XOR into bvs-1 buckets; bucket
    `bvs - 1` is reserved for position 0 (no previous token). Both the table and
    the optional projection are zero-initialized, so the module starts as a no-op
    and its contribution is learned via the `scale` parameter.
    """
    def __init__(self, bigram_vocab_size, bigram_dim, model_dim):
        super().__init__(); self.bvs = bigram_vocab_size
        self.embed = nn.Embedding(bigram_vocab_size, bigram_dim); nn.init.zeros_(self.embed.weight)
        self.proj = CastedLinear(bigram_dim, model_dim, bias=False) if bigram_dim != model_dim else None
        if self.proj is not None: nn.init.zeros_(self.proj.weight)
        self.scale = nn.Parameter(torch.tensor(0.05, dtype=torch.float32))
    def forward(self, token_ids):
        t = token_ids.to(torch.int32); mod = self.bvs - 1
        out = torch.empty_like(t); out[..., 0] = mod  # reserved bucket for first position
        # Multiplicative-XOR bigram hash; torch's % follows Python sign semantics,
        # so the result is non-negative even if the int32 products wrap.
        out[..., 1:] = torch.bitwise_xor(36313 * t[..., 1:], 27191 * t[..., :-1]) % mod
        h = self.embed(out.long())
        if self.proj is not None: h = self.proj(h)
        return h * self.scale.to(dtype=h.dtype)

class Block(nn.Module):
    """Transformer block with per-dim learned residual mixing of the running
    stream x and the embedding stream x0, plus per-dim attn/mlp residual scales."""
    def __init__(self, dim, num_heads, num_kv_heads, mlp_mult, rope_base, qk_gain_init):
        super().__init__()
        self.attn_norm, self.mlp_norm = RMSNorm(), RMSNorm()
        self.attn = CausalSelfAttention(dim, num_heads, num_kv_heads, rope_base, qk_gain_init)
        self.mlp = MLP(dim, mlp_mult)
        self.attn_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32))
        self.mlp_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32))
        # resid_mix[0] weights x, resid_mix[1] weights x0; init = identity mix.
        self.resid_mix = nn.Parameter(torch.stack((torch.ones(dim), torch.zeros(dim))).float())
    def forward(self, x, x0):
        mix = self.resid_mix.to(dtype=x.dtype)
        x = mix[0][None, None, :] * x + mix[1][None, None, :] * x0
        x = x + self.attn_scale.to(dtype=x.dtype)[None, None, :] * self.attn(self.attn_norm(x))
        x = x + self.mlp_scale.to(dtype=x.dtype)[None, None, :] * self.mlp(self.mlp_norm(x))
        return x

class GPT(nn.Module):
    """DominationV2 GPT: token + hashed-bigram embeddings, SmearGate, a U-Net
    encoder/decoder block layout with learned skip weights, optional tied
    embeddings, tanh logit soft-capping, and XSA on the last xsa_last_n layers."""
    def __init__(self, vocab_size, num_layers, model_dim, num_heads, num_kv_heads, mlp_mult,
                 tie_embeddings, tied_embed_init_std, logit_softcap, rope_base, qk_gain_init,
                 bigram_vocab_size=0, bigram_dim=128, xsa_last_n=0):
        super().__init__()
        self.tie_embeddings, self.tied_embed_init_std, self.logit_softcap = tie_embeddings, tied_embed_init_std, logit_softcap
        self.tok_emb = nn.Embedding(vocab_size, model_dim)
        self.bigram = BigramHashEmbedding(bigram_vocab_size, bigram_dim, model_dim) if bigram_vocab_size > 0 else None
        self.smear = SmearGate(model_dim)
        # U-Net split: floor-half encoder, remainder decoder (decoder >= encoder).
        self.num_encoder_layers = num_layers // 2
        self.num_decoder_layers = num_layers - self.num_encoder_layers
        self.num_skip_weights = min(self.num_encoder_layers, self.num_decoder_layers)
        self.skip_weights = nn.Parameter(torch.ones(self.num_skip_weights, model_dim, dtype=torch.float32))
        self.blocks = nn.ModuleList([Block(model_dim, num_heads, num_kv_heads, mlp_mult, rope_base, qk_gain_init) for _ in range(num_layers)])
        self.final_norm = RMSNorm()
        self.lm_head = None if tie_embeddings else CastedLinear(model_dim, vocab_size, bias=False)
        if self.lm_head is not None: self.lm_head._zero_init = True
        if xsa_last_n > 0:
            for i in range(max(0, num_layers - xsa_last_n), num_layers):
                self.blocks[i].attn.use_xsa = True
        self._init_weights()

    def _init_weights(self):
        """Orthogonal init for large matrices, zeros for _zero_init projections,
        normal init for tied embeddings (continued in the next chunk)."""
        if self.tie_embeddings: nn.init.normal_(self.tok_emb.weight, mean=0.0, std=self.tied_embed_init_std)
        nl = len(self.blocks)
        for name, module in self.named_modules():
            if isinstance(module, nn.Linear):
                if getattr(module, "_zero_init", False): nn.init.zeros_(module.weight)
                elif module.weight.ndim == 2 and
# NOTE(review): continuation of the `elif` condition split in the source chunking.
module.weight.shape[0] >= 64 and module.weight.shape[1] >= 64:
                    nn.init.orthogonal_(module.weight, gain=1.0)
                    # Depth-scaled projections (1/sqrt(2*num_layers)); note the
                    # block/attn proj matrices are _zero_init, so in practice this
                    # scaling only touches non-zero-initialized *.proj weights.
                    if ".proj." in name or name.endswith(".proj"):
                        with torch.no_grad(): module.weight.mul_(1.0 / math.sqrt(2 * nl))

    def _run_blocks(self, x, x0):
        # U-Net pass: encoder outputs are stacked and consumed LIFO by the
        # decoder, each weighted per-dimension by skip_weights[i]. With more
        # decoder than encoder layers, the extra decoder layers get no skip.
        skips = []
        for i in range(self.num_encoder_layers): x = self.blocks[i](x, x0); skips.append(x)
        for i in range(self.num_decoder_layers):
            if skips: x = x + self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips.pop()
            x = self.blocks[self.num_encoder_layers + i](x, x0)
        return x

    def _embed(self, input_ids):
        # Token embedding (+ hashed bigram embedding), RMS-normalized, smeared.
        x = self.tok_emb(input_ids)
        if self.bigram is not None: x = x + self.bigram(input_ids)
        x = F.rms_norm(x, (x.size(-1),))
        return self.smear(x)

    def _logits(self, x):
        # Tied head reuses tok_emb; logits are tanh-soft-capped to +/- logit_softcap.
        if self.tie_embeddings: lp = F.linear(x, self.tok_emb.weight)
        else: lp = self.lm_head(x)
        return self.logit_softcap * torch.tanh(lp / self.logit_softcap)

    def forward(self, input_ids, target_ids):
        """Return mean cross-entropy over all positions (logits in float32)."""
        x0 = self._embed(input_ids)
        x = self.final_norm(self._run_blocks(x0, x0)).reshape(-1, x0.size(-1))
        return F.cross_entropy(self._logits(x).float(), target_ids.reshape(-1), reduction="mean")

    def forward_logits(self, input_ids):
        """Return soft-capped logits (used by the eval paths)."""
        x0 = self._embed(input_ids)
        return self._logits(self.final_norm(self._run_blocks(x0, x0)))

# -----------------------------
# TRAINING
# -----------------------------

def main():
    # Entry point: distributed setup, model/optimizer construction, timed
    # training, EMA/SWA weight averaging, quantized export + round-trip, TTT,
    # and final sliding-window evaluation.
    global zeropower_via_newtonschulz5
    code = Path(__file__).read_text(encoding="utf-8"); args = Hyperparameters()
    zeropower_via_newtonschulz5 = torch.compile(zeropower_via_newtonschulz5)
    # torchrun sets RANK/WORLD_SIZE/LOCAL_RANK; absent => single-process run.
    distributed = "RANK" in os.environ and "WORLD_SIZE" in os.environ
    rank, world_size = int(os.environ.get("RANK", "0")), int(os.environ.get("WORLD_SIZE", "1"))
    local_rank = int(os.environ.get("LOCAL_RANK", "0"))
    # Global batch is fixed at 8 micro-batches; fewer ranks => more accumulation.
    grad_accum_steps = 8 // world_size; grad_scale = 1.0 / grad_accum_steps
    device = torch.device("cuda", local_rank); torch.cuda.set_device(device)
    if distributed: dist.init_process_group(backend="nccl", device_id=device); dist.barrier()
    master_process = rank == 0
    torch.backends.cuda.matmul.allow_tf32 = True; torch.backends.cudnn.allow_tf32 = True
    from torch.backends.cuda import enable_cudnn_sdp, enable_flash_sdp, enable_math_sdp, enable_mem_efficient_sdp
    # Force the flash SDPA kernel only.
    enable_cudnn_sdp(False); enable_flash_sdp(True); enable_mem_efficient_sdp(False); enable_math_sdp(False)
    logfile = None
    if master_process: os.makedirs("logs", exist_ok=True); logfile = f"logs/{args.run_id}.txt"; print(logfile)
    def log0(msg, console=True):
        # Rank-0-only logging to console and the run logfile.
        if not master_process: return
        if console: print(msg)
        if logfile:
            with open(logfile, "a", encoding="utf-8") as f: print(msg, file=f)
    # Record the full source and nvidia-smi output in the logfile for provenance.
    log0(code, console=False); log0("=" * 100, console=False)
    log0(subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, check=False).stdout, console=False)
    log0("=" * 100, console=False)
    random.seed(args.seed); np.random.seed(args.seed); torch.manual_seed(args.seed); torch.cuda.manual_seed_all(args.seed)
    sp = spm.SentencePieceProcessor(model_file=args.tokenizer_path)
    dataset_dir = Path(args.data_path).resolve()
    actual_train_files = len(list(dataset_dir.glob("fineweb_train_*.bin")))
    val_tokens = load_validation_tokens(args.val_files, args.train_seq_len)
    base_bytes_lut, has_leading_space_lut, is_boundary_token_lut = build_sentencepiece_luts(sp, args.vocab_size, device)
    log0(f"train_loader:dataset:{dataset_dir.name} train_shards:{actual_train_files}")
    log0(f"val_loader:tokens:{val_tokens.numel() - 1}")

    base_model = GPT(vocab_size=args.vocab_size, num_layers=args.num_layers, model_dim=args.model_dim,
                     num_heads=args.num_heads, num_kv_heads=args.num_kv_heads, mlp_mult=args.mlp_mult,
                     tie_embeddings=args.tie_embeddings, tied_embed_init_std=args.tied_embed_init_std,
                     logit_softcap=args.logit_softcap, rope_base=args.rope_base, qk_gain_init=args.qk_gain_init,
                     bigram_vocab_size=args.bigram_vocab_size, bigram_dim=args.bigram_dim,
                     xsa_last_n=args.xsa_last_n).to(device).bfloat16()
    # Keep CastedLinear master weights (and low-dim/control params) in fp32;
    # CastedLinear casts to the activation dtype at call time.
    for module in base_model.modules():
        if isinstance(module, CastedLinear): module.float()
    restore_low_dim_params_to_fp32(base_model)
    compiled_model = torch.compile(base_model, dynamic=False, fullgraph=True)
    model = DDP(compiled_model, device_ids=[local_rank], broadcast_buffers=False) if distributed else compiled_model

    # Partition parameters: 2-D block matrices -> Muon; everything else
    # (scalars/gains plus control tensors by name) -> AdamW.
    block_named_params = list(base_model.blocks.named_parameters())
    matrix_params = [p for n, p in block_named_params if p.ndim == 2 and not any(pat in n for pat in CONTROL_TENSOR_NAME_PATTERNS)]
    scalar_params = [p for n, p in block_named_params if p.ndim < 2 or any(pat in n for pat in CONTROL_TENSOR_NAME_PATTERNS)]
    if base_model.skip_weights.numel() > 0: scalar_params.append(base_model.skip_weights)
    for p in base_model.smear.parameters(): scalar_params.append(p)
    if base_model.bigram is not None:
        for n, p in base_model.bigram.named_parameters():
            if p.ndim == 2 and p.shape[0] >= 64 and p.shape[1] >= 64: matrix_params.append(p)
            else: scalar_params.append(p)

    token_lr = args.tied_embed_lr if args.tie_embeddings else args.embed_lr
    # "base_lr" is stashed per group so the scheduler can rescale "lr" each step.
    optimizer_tok = torch.optim.AdamW([{"params": [base_model.tok_emb.weight], "lr": token_lr, "base_lr": token_lr}],
                                      betas=(args.beta1, args.beta2), eps=args.adam_eps, weight_decay=args.weight_decay, fused=True)
    optimizer_muon = Muon(matrix_params, lr=args.matrix_lr, momentum=args.muon_momentum,
                          backend_steps=args.muon_backend_steps, weight_decay=args.muon_wd)
    for group in optimizer_muon.param_groups: group["base_lr"] = args.matrix_lr
    optimizer_scalar = torch.optim.AdamW([{"params": scalar_params, "lr": args.scalar_lr, "base_lr": args.scalar_lr}],
                                         betas=(args.beta1, args.beta2), eps=args.adam_eps, weight_decay=args.weight_decay, fused=True)
    optimizers = [optimizer_tok, optimizer_muon, optimizer_scalar]
    # Untied head gets its own AdamW (continued in the next chunk).
    if base_model.lm_head is not None:
        optimizers.insert(1, torch.optim.AdamW([{"params":
# NOTE(review): continuation of the head-optimizer AdamW call split in the source chunking.
[base_model.lm_head.weight], "lr": args.head_lr, "base_lr": args.head_lr}],
                                               betas=(args.beta1, args.beta2), eps=args.adam_eps, weight_decay=args.weight_decay, fused=True))

    n_params = sum(p.numel() for p in base_model.parameters())
    log0(f"model_params:{n_params} swa:{args.swa_enabled} compression:{'zstd-22' if HAS_ZSTD else 'zlib-9'}")
    log0(f"bigram_vocab:{args.bigram_vocab_size} bigram_dim:{args.bigram_dim} grad_clip:{args.grad_clip_norm} muon_wd:{args.muon_wd}")
    log0(f"train_batch_tokens:{args.train_batch_tokens} train_seq_len:{args.train_seq_len} warmdown:{args.warmdown_iters} seed:{args.seed}")
    log0(f"ste_qat:{STE_QAT_ENABLED} ste_range:{STE_QAT_RANGE} int6_cats:{MIXED_QUANT_INT6_CATS}")
    log0(f"xsa_last_n:{args.xsa_last_n} ema:{args.ema_enabled} ema_decay:{args.ema_decay} ttt:{args.ttt_enabled} ttt_epochs:{args.ttt_epochs}")
    log0(f"cache_enabled:{args.cache_enabled} cache_alpha:{args.cache_alpha} cache_tau:{args.cache_tau} cache_entropy_power:{args.cache_entropy_power}")

    train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device)
    def zero_grad_all():
        for opt in optimizers: opt.zero_grad(set_to_none=True)
    max_wallclock_ms = 1000.0 * args.max_wallclock_seconds if args.max_wallclock_seconds > 0 else None
    def lr_mul(step, elapsed_ms):
        # LR multiplier: constant 1.0, then a linear warmdown. With a wallclock
        # cap, the warmdown is keyed to remaining time (estimated from the
        # running ms/step) rather than to a fixed iteration count.
        if args.warmdown_iters <= 0: return 1.0
        if max_wallclock_ms is None:
            ws = max(args.iterations - args.warmdown_iters, 0)
            return max((args.iterations - step) / max(args.warmdown_iters, 1), 0.0) if ws <= step < args.iterations else 1.0
        step_ms = elapsed_ms / max(step, 1); wms = args.warmdown_iters * step_ms
        rms = max(max_wallclock_ms - elapsed_ms, 0.0)
        return rms / max(wms, 1e-9) if rms <= wms else 1.0

    # Warmup: run a few real steps to trigger compilation/allocator warm paths,
    # then restore the initial model + optimizer state so timing starts clean.
    if args.warmup_steps > 0:
        init_sd = {n: t.detach().cpu().clone() for n, t in base_model.state_dict().items()}
        init_opts = [copy.deepcopy(o.state_dict()) for o in optimizers]; model.train()
        for ws in range(args.warmup_steps):
            zero_grad_all()
            for ms in range(grad_accum_steps):
                # Only sync DDP gradients on the last micro-batch.
                if distributed: model.require_backward_grad_sync = ms == grad_accum_steps - 1
                x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps)
                with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): wl = model(x, y)
                (wl * grad_scale).backward()
            for o in optimizers: o.step()
            zero_grad_all()
            if args.warmup_steps <= 20 or (ws + 1) % 10 == 0: log0(f"warmup_step:{ws + 1}/{args.warmup_steps}")
        base_model.load_state_dict(init_sd, strict=True)
        for o, s in zip(optimizers, init_opts, strict=True): o.load_state_dict(s)
        zero_grad_all()
        if distributed: model.require_backward_grad_sync = True
        train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device)

    training_time_ms = 0.0; stop_after_step = None; swa_state = None; swa_count = 0
    ema_state = None
    if args.ema_enabled:
        ema_state = {n: t.detach().float().clone() for n, t in base_model.state_dict().items()}
    torch.cuda.synchronize(); t0 = time.perf_counter(); step = 0
    while True:
        last_step = step == args.iterations or (stop_after_step is not None and step >= stop_after_step)
        # Periodic validation; the timer is paused around eval so only training
        # time accumulates into training_time_ms.
        if last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0):
            torch.cuda.synchronize(); training_time_ms += 1000.0 * (time.perf_counter() - t0)
            vl, vb = eval_val(args, model, rank, world_size, device, grad_accum_steps, val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut)
            log0(f"step:{step}/{args.iterations} val_loss:{vl:.4f} val_bpb:{vb:.4f} train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms / max(step, 1):.2f}ms")
            torch.cuda.synchronize(); t0 = time.perf_counter()
        if last_step:
            if stop_after_step is not None and step < args.iterations: log0(f"stopping_early: wallclock_cap train_time:{training_time_ms:.0f}ms step:{step}/{args.iterations}")
            break
        elapsed_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0); scale = lr_mul(step, elapsed_ms)
        # SWA checkpoint averaging (only when EMA is off) once the warmdown has begun.
        if args.swa_enabled and not args.ema_enabled and scale < args.swa_start_frac and step % args.swa_every == 0:
            current = {k: v.detach().cpu().clone() for k, v in base_model.state_dict().items()}
            if swa_state is None:
                swa_state = current; swa_count = 1
            else:
                inv = 1.0 / (swa_count + 1); keep = 1.0 - inv
                for k, t in current.items():
                    if torch.is_floating_point(swa_state[k]): swa_state[k].mul_(keep).add_(t, alpha=inv)
                    else: swa_state[k] = t
                swa_count += 1
        zero_grad_all(); train_loss = torch.zeros((), device=device)
        for ms in range(grad_accum_steps):
            if distributed: model.require_backward_grad_sync = ms == grad_accum_steps - 1
            x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps)
            with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): loss = model(x, y)
            train_loss += loss.detach(); (loss * grad_scale).backward()
        train_loss /= grad_accum_steps
        # Linear Muon momentum warmup, then per-group LR scaling by the schedule.
        frac = min(step / args.muon_momentum_warmup_steps, 1.0) if args.muon_momentum_warmup_steps > 0 else 1.0
        for group in optimizer_muon.param_groups: group["momentum"] = (1 - frac) * args.muon_momentum_warmup_start + frac * args.muon_momentum
        for opt in optimizers:
            for group in opt.param_groups: group["lr"] = group["base_lr"] * scale
        if args.grad_clip_norm > 0: torch.nn.utils.clip_grad_norm_(base_model.parameters(), args.grad_clip_norm)
        for opt in optimizers: opt.step()
        # EMA of the full state dict in float32.
        if ema_state is not None:
            d = args.ema_decay
            with torch.no_grad():
                for n, t in base_model.state_dict().items():
                    ema_state[n].mul_(d).add_(t.detach().float(), alpha=1.0 - d)
        zero_grad_all(); step += 1
        approx_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0)
        if args.train_log_every > 0 and (step <= 10 or step % args.train_log_every == 0 or stop_after_step is not None):
            log0(f"step:{step}/{args.iterations} train_loss:{train_loss.item():.4f} train_time:{approx_ms:.0f}ms step_avg:{approx_ms / step:.2f}ms")
        # Wallclock cap: all-reduce with MAX so every rank stops at the same step.
        reached_cap = max_wallclock_ms is not None and approx_ms >= max_wallclock_ms
        if distributed and max_wallclock_ms is not None:
            rct = torch.tensor(int(reached_cap), device=device); dist.all_reduce(rct, op=dist.ReduceOp.MAX); reached_cap = bool(rct.item())
        if stop_after_step is None and reached_cap: stop_after_step = step

    log0(f"peak memory: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB")
    # Swap in averaged weights: EMA takes precedence over SWA.
    if ema_state is not None:
        log0("ema: applying EMA weights")
        avg_sd = {n: t.to(dtype=base_model.state_dict()[n].dtype) for n, t in ema_state.items()}
        base_model.load_state_dict(avg_sd, strict=True); del ema_state, avg_sd
    elif swa_state is not None:
        log0(f"swa: averaging {swa_count} checkpoints")
        base_model.load_state_dict(swa_state, strict=True); del swa_state

    # Export: quantize (mixed int6/int8), compress, then reload the round-tripped
    # weights so the final eval scores exactly what was serialized.
    export_sd = base_model.state_dict()
    if master_process:
        torch.save(export_sd, "final_model.pt")
        log0(f"Serialized model: {os.path.getsize('final_model.pt')} bytes Code: {len(code.encode('utf-8'))} bytes")
    sd_cpu = {k: v.detach().cpu() for k, v in export_sd.items()}
    quant_result, quant_meta = mixed_quantize(sd_cpu, MIXED_QUANT_INT6_CATS)
    qbuf = io.BytesIO(); torch.save({"w": quant_result, "m": quant_meta}, qbuf); qraw = qbuf.getvalue()
    if HAS_ZSTD: qblob = zstd.ZstdCompressor(level=22).compress(qraw); cl = "zstd-22"
    else: qblob = zlib.compress(qraw, level=9); cl = "zlib-9"
    if master_process:
        with open("final_model.int8.ptz", "wb") as f: f.write(qblob)
        qfb = len(qblob); cb = len(code.encode("utf-8"))
        log0(f"final_int8_zlib_roundtrip compressed_model_bytes:{qfb} code_bytes:{cb} total_artifact_bytes:{qfb + cb}")
        log0(f"Serialized {cl}: {qfb} bytes Total: {qfb + cb} bytes")
    # Barrier so non-master ranks read the artifact only after it is written.
    # NOTE(review): assumes a shared filesystem across ranks — true for the
    # single-node torchrun setup recorded here.
    if distributed: dist.barrier()
    with open("final_model.int8.ptz", "rb") as f: qbd = f.read()
    rd = zstd.ZstdDecompressor().decompress(qbd) if HAS_ZSTD else zlib.decompress(qbd)
    qs = torch.load(io.BytesIO(rd), map_location="cpu")
    base_model.load_state_dict(dequantize_mixed(qs["w"], qs["m"], sd_cpu), strict=True)

    # Test-time training: each rank takes a disjoint slice of the validation
    # stream (by construction of my_start/my_end) and fine-tunes its own copy
    # with SGD in micro-batches of 4 sequences.
    if args.ttt_enabled and args.ttt_epochs > 0:
        log0(f"ttt: starting {args.ttt_epochs} epochs of test-time training (lr={args.ttt_lr})")
        torch.cuda.synchronize(); ttt_start = time.perf_counter()
        base_model.train()
        ttt_params = [p for p in base_model.parameters() if p.requires_grad]
        ttt_opt = torch.optim.SGD(ttt_params, lr=args.ttt_lr)
        ttt_seq_len = args.train_seq_len
        total_val = val_tokens.numel() - 1
        total_seqs = total_val // ttt_seq_len
        my_start = (total_seqs * rank) // world_size
        my_end = (total_seqs * (rank + 1)) // world_size
        for ttt_ep in range(args.ttt_epochs):
            ttt_loss_sum = 0.0; ttt_count = 0
            for si in range(my_start, my_end, 4):
                se = min(si + 4, my_end); bsz = se - si
                # +1 token so inputs and next-token targets can both be sliced.
                rs = si * ttt_seq_len; re = se * ttt_seq_len + 1
                local = val_tokens[rs:re].to(device=device, dtype=torch.int64)
                x = local[:-1].reshape(bsz, ttt_seq_len)
                y = local[1:].reshape(bsz, ttt_seq_len)
                ttt_opt.zero_grad()
                with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
                    loss = base_model(x, y)
                loss.backward()
                ttt_opt.step()
                ttt_loss_sum += loss.item() * bsz; ttt_count += bsz
            if rank == 0:
                # Rank 0 logs only its own shard's loss.
                log0(f" ttt_epoch:{ttt_ep+1}/{args.ttt_epochs} loss:{ttt_loss_sum/max(ttt_count,1):.4f}")
        torch.cuda.synchronize()
        log0(f"ttt: done in {time.perf_counter() - ttt_start:.1f}s")

    # Final evaluation on the round-tripped (and optionally TTT-updated) weights:
    # sliding-window eval, with the BOS-reset bigram cache when enabled.
    torch.cuda.synchronize(); tqe = time.perf_counter()
    if args.eval_stride > 0 and args.eval_stride < args.train_seq_len:
        if args.cache_enabled:
            docs = find_docs(val_tokens)
            log0(f"final_eval_mode:sliding_window_cache stride:{args.eval_stride} batch_seqs:{args.eval_batch_seqs} docs:{len(docs)}")
            qvl, qvb = eval_val_sliding_cache(args, base_model, rank, world_size, device, val_tokens, docs, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, stride=args.eval_stride, batch_seqs=args.eval_batch_seqs)
        else:
            log0(f"final_eval_mode:sliding_window stride:{args.eval_stride} batch_seqs:{args.eval_batch_seqs}")
            qvl, qvb = eval_val_sliding(args, base_model, rank, world_size, device, val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, stride=args.eval_stride, batch_seqs=args.eval_batch_seqs)
    else:
        qvl, qvb = eval_val(args, model, rank, world_size, device, grad_accum_steps, val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut)
    torch.cuda.synchronize()
    log0(f"final_roundtrip val_loss:{qvl:.4f} val_bpb:{qvb:.4f} eval_time:{1000.0 * (time.perf_counter() - tqe):.0f}ms")
    log0(f"final_roundtrip_exact val_loss:{qvl:.8f} val_bpb:{qvb:.8f}")
    if distributed: dist.destroy_process_group()

if __name__ == "__main__":
    main()
diff --git a/records/track_10min_16mb/2026-03-27_DominationV2_BigramCache_TTT/train_seed1337.log b/records/track_10min_16mb/2026-03-27_DominationV2_BigramCache_TTT/train_seed1337.log
new file mode 100644
index 000000000..0c3da1e9e
--- /dev/null
+++ b/records/track_10min_16mb/2026-03-27_DominationV2_BigramCache_TTT/train_seed1337.log
@@ -0,0 +1,238 @@
+W0327 11:48:04.875000 61577 torch/distributed/run.py:803] 
+W0327 11:48:04.875000 61577 torch/distributed/run.py:803] *****************************************
+W0327 11:48:04.875000 61577 torch/distributed/run.py:803] Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed.
+W0327 11:48:04.875000 61577 torch/distributed/run.py:803] ***************************************** +logs/domv2_cache_seed1337.txt +train_loader:dataset:fineweb10B_sp1024 train_shards:80 +val_loader:tokens:62021632 +model_params:26829913 swa:False compression:zstd-22 +bigram_vocab:2048 bigram_dim:128 grad_clip:0.3 muon_wd:0.04 +train_batch_tokens:524288 train_seq_len:2048 warmdown:3000 seed:1337 +ste_qat:False ste_range:31 int6_cats:frozenset({'attn', 'mlp'}) +xsa_last_n:4 ema:True ema_decay:0.997 ttt:True ttt_epochs:3 +cache_enabled:True cache_alpha:0.2 cache_tau:8.0 cache_entropy_power:1.0 +warmup_step:1/20 +warmup_step:2/20 +warmup_step:3/20 +warmup_step:4/20 +warmup_step:5/20 +warmup_step:6/20 +warmup_step:7/20 +warmup_step:8/20 +warmup_step:9/20 +warmup_step:10/20 +warmup_step:11/20 +warmup_step:12/20 +warmup_step:13/20 +warmup_step:14/20 +warmup_step:15/20 +warmup_step:16/20 +warmup_step:17/20 +warmup_step:18/20 +warmup_step:19/20 +warmup_step:20/20 +step:0/20000 val_loss:6.9303 val_bpb:4.1045 train_time:0ms step_avg:0.01ms +step:1/20000 train_loss:6.9315 train_time:128ms step_avg:128.17ms +step:2/20000 train_loss:8.6437 train_time:195ms step_avg:97.32ms +step:3/20000 train_loss:7.9089 train_time:263ms step_avg:87.77ms +step:4/20000 train_loss:7.3376 train_time:332ms step_avg:83.05ms +step:5/20000 train_loss:6.9265 train_time:401ms step_avg:80.24ms +step:6/20000 train_loss:7.8223 train_time:470ms step_avg:78.36ms +step:7/20000 train_loss:6.7293 train_time:539ms step_avg:77.00ms +step:8/20000 train_loss:6.6041 train_time:608ms step_avg:75.98ms +step:9/20000 train_loss:6.3944 train_time:677ms step_avg:75.18ms +step:10/20000 train_loss:6.1955 train_time:745ms step_avg:74.54ms +step:100/20000 train_loss:3.2807 train_time:6982ms step_avg:69.82ms +step:200/20000 train_loss:2.7487 train_time:13982ms step_avg:69.91ms +step:300/20000 train_loss:2.3597 train_time:20918ms step_avg:69.73ms +step:400/20000 train_loss:2.2403 train_time:27926ms step_avg:69.81ms 
+step:500/20000 train_loss:2.3874 train_time:34885ms step_avg:69.77ms +step:500/20000 val_loss:2.3863 val_bpb:1.4133 train_time:34890ms step_avg:69.78ms +step:600/20000 train_loss:2.4504 train_time:41919ms step_avg:69.87ms +step:700/20000 train_loss:2.3476 train_time:48892ms step_avg:69.85ms +step:800/20000 train_loss:2.2090 train_time:55932ms step_avg:69.92ms +step:900/20000 train_loss:2.2599 train_time:62920ms step_avg:69.91ms +step:1000/20000 train_loss:2.3160 train_time:69971ms step_avg:69.97ms +step:1000/20000 val_loss:2.2601 val_bpb:1.3385 train_time:69977ms step_avg:69.98ms +step:1100/20000 train_loss:2.1826 train_time:76957ms step_avg:69.96ms +step:1200/20000 train_loss:2.3348 train_time:83989ms step_avg:69.99ms +step:1300/20000 train_loss:2.3141 train_time:90971ms step_avg:69.98ms +step:1400/20000 train_loss:2.3712 train_time:98002ms step_avg:70.00ms +step:1500/20000 train_loss:2.1763 train_time:104969ms step_avg:69.98ms +step:1500/20000 val_loss:2.2199 val_bpb:1.3148 train_time:104974ms step_avg:69.98ms +step:1600/20000 train_loss:2.0403 train_time:111984ms step_avg:69.99ms +step:1700/20000 train_loss:2.1093 train_time:118936ms step_avg:69.96ms +step:1800/20000 train_loss:2.1454 train_time:125943ms step_avg:69.97ms +step:1900/20000 train_loss:2.1336 train_time:132898ms step_avg:69.95ms +step:2000/20000 train_loss:2.1843 train_time:139908ms step_avg:69.95ms +step:2000/20000 val_loss:2.1688 val_bpb:1.2845 train_time:139912ms step_avg:69.96ms +step:2100/20000 train_loss:2.2052 train_time:146915ms step_avg:69.96ms +step:2200/20000 train_loss:2.0032 train_time:153858ms step_avg:69.94ms +step:2300/20000 train_loss:2.3074 train_time:160859ms step_avg:69.94ms +step:2400/20000 train_loss:2.1279 train_time:167804ms step_avg:69.92ms +step:2500/20000 train_loss:2.0646 train_time:174798ms step_avg:69.92ms +step:2500/20000 val_loss:2.1409 val_bpb:1.2680 train_time:174803ms step_avg:69.92ms +step:2600/20000 train_loss:2.3640 train_time:181733ms step_avg:69.90ms 
+step:2700/20000 train_loss:2.0865 train_time:188726ms step_avg:69.90ms +step:2800/20000 train_loss:2.1716 train_time:195653ms step_avg:69.88ms +step:2900/20000 train_loss:2.1187 train_time:202642ms step_avg:69.88ms +step:3000/20000 train_loss:2.1611 train_time:209576ms step_avg:69.86ms +step:3000/20000 val_loss:2.1268 val_bpb:1.2596 train_time:209582ms step_avg:69.86ms +step:3100/20000 train_loss:2.1346 train_time:216574ms step_avg:69.86ms +step:3200/20000 train_loss:2.1279 train_time:223498ms step_avg:69.84ms +step:3300/20000 train_loss:2.1765 train_time:230486ms step_avg:69.84ms +step:3400/20000 train_loss:2.0995 train_time:237412ms step_avg:69.83ms +step:3500/20000 train_loss:2.1940 train_time:244408ms step_avg:69.83ms +step:3500/20000 val_loss:2.1184 val_bpb:1.2546 train_time:244413ms step_avg:69.83ms +step:3600/20000 train_loss:2.0437 train_time:251331ms step_avg:69.81ms +step:3700/20000 train_loss:2.0750 train_time:258319ms step_avg:69.82ms +step:3800/20000 train_loss:2.1454 train_time:265243ms step_avg:69.80ms +step:3900/20000 train_loss:1.9305 train_time:272239ms step_avg:69.80ms +step:4000/20000 train_loss:2.1165 train_time:279164ms step_avg:69.79ms +step:4000/20000 val_loss:2.1099 val_bpb:1.2496 train_time:279169ms step_avg:69.79ms +step:4100/20000 train_loss:2.1329 train_time:286145ms step_avg:69.79ms +step:4200/20000 train_loss:2.1127 train_time:293126ms step_avg:69.79ms +step:4300/20000 train_loss:1.9561 train_time:300048ms step_avg:69.78ms +step:4400/20000 train_loss:2.0532 train_time:307024ms step_avg:69.78ms +step:4500/20000 train_loss:2.2005 train_time:313945ms step_avg:69.77ms +step:4500/20000 val_loss:2.1042 val_bpb:1.2462 train_time:313949ms step_avg:69.77ms +step:4600/20000 train_loss:1.9146 train_time:320931ms step_avg:69.77ms +step:4700/20000 train_loss:2.2180 train_time:327852ms step_avg:69.76ms +step:4800/20000 train_loss:2.2052 train_time:334837ms step_avg:69.76ms +step:4900/20000 train_loss:2.1131 train_time:341759ms step_avg:69.75ms 
+step:5000/20000 train_loss:1.9651 train_time:348744ms step_avg:69.75ms +step:5000/20000 val_loss:2.0988 val_bpb:1.2430 train_time:348748ms step_avg:69.75ms +step:5100/20000 train_loss:1.9738 train_time:355661ms step_avg:69.74ms +step:5200/20000 train_loss:2.1269 train_time:362651ms step_avg:69.74ms +step:5300/20000 train_loss:2.1538 train_time:369571ms step_avg:69.73ms +step:5400/20000 train_loss:2.1349 train_time:376563ms step_avg:69.73ms +step:5500/20000 train_loss:2.0968 train_time:383478ms step_avg:69.72ms +step:5500/20000 val_loss:2.0985 val_bpb:1.2428 train_time:383484ms step_avg:69.72ms +step:5600/20000 train_loss:2.1303 train_time:390455ms step_avg:69.72ms +step:5700/20000 train_loss:2.1201 train_time:397383ms step_avg:69.72ms +step:5800/20000 train_loss:2.0813 train_time:404359ms step_avg:69.72ms +step:5900/20000 train_loss:2.0338 train_time:411289ms step_avg:69.71ms +step:6000/20000 train_loss:2.1607 train_time:418272ms step_avg:69.71ms +step:6000/20000 val_loss:2.0832 val_bpb:1.2338 train_time:418276ms step_avg:69.71ms +step:6100/20000 train_loss:2.0559 train_time:425188ms step_avg:69.70ms +step:6200/20000 train_loss:2.0215 train_time:432180ms step_avg:69.71ms +step:6300/20000 train_loss:1.9619 train_time:439164ms step_avg:69.71ms +step:6400/20000 train_loss:2.0982 train_time:446093ms step_avg:69.70ms +step:6500/20000 train_loss:2.0016 train_time:453076ms step_avg:69.70ms +step:6500/20000 val_loss:2.0612 val_bpb:1.2208 train_time:453081ms step_avg:69.70ms +step:6600/20000 train_loss:2.0407 train_time:459993ms step_avg:69.70ms +step:6700/20000 train_loss:2.0758 train_time:466971ms step_avg:69.70ms +step:6800/20000 train_loss:2.0988 train_time:473898ms step_avg:69.69ms +step:6900/20000 train_loss:2.0083 train_time:480886ms step_avg:69.69ms +step:7000/20000 train_loss:2.1383 train_time:487802ms step_avg:69.69ms +step:7000/20000 val_loss:2.0368 val_bpb:1.2063 train_time:487807ms step_avg:69.69ms +step:7100/20000 train_loss:1.9693 train_time:494781ms 
step_avg:69.69ms +step:7200/20000 train_loss:2.0954 train_time:501701ms step_avg:69.68ms +step:7300/20000 train_loss:1.9896 train_time:508689ms step_avg:69.68ms +step:7400/20000 train_loss:2.0143 train_time:515620ms step_avg:69.68ms +step:7500/20000 train_loss:2.0002 train_time:522598ms step_avg:69.68ms +step:7500/20000 val_loss:2.0099 val_bpb:1.1904 train_time:522603ms step_avg:69.68ms +step:7600/20000 train_loss:1.8822 train_time:529518ms step_avg:69.67ms +step:7700/20000 train_loss:1.9543 train_time:536507ms step_avg:69.68ms +step:7800/20000 train_loss:2.0183 train_time:543439ms step_avg:69.67ms +step:7900/20000 train_loss:1.9950 train_time:550422ms step_avg:69.67ms +step:8000/20000 train_loss:1.9773 train_time:557345ms step_avg:69.67ms +step:8000/20000 val_loss:1.9792 val_bpb:1.1722 train_time:557350ms step_avg:69.67ms +step:8100/20000 train_loss:2.0069 train_time:564433ms step_avg:69.68ms +step:8200/20000 train_loss:2.0409 train_time:571355ms step_avg:69.68ms +step:8300/20000 train_loss:1.9618 train_time:578336ms step_avg:69.68ms +step:8400/20000 train_loss:1.9728 train_time:585328ms step_avg:69.68ms +step:8500/20000 train_loss:1.9490 train_time:592266ms step_avg:69.68ms +step:8500/20000 val_loss:1.9477 val_bpb:1.1535 train_time:592271ms step_avg:69.68ms +step:8600/20000 train_loss:1.9772 train_time:599256ms step_avg:69.68ms +step:8611/20000 val_loss:1.9440 val_bpb:1.1513 train_time:600019ms step_avg:69.68ms +stopping_early: wallclock_cap train_time:600019ms step:8611/20000 +peak memory: 14070 MiB +ema: applying EMA weights +Serialized model: 105789375 bytes Code: 55299 bytes +final_int8_zlib_roundtrip compressed_model_bytes:15449423 code_bytes:55299 total_artifact_bytes:15504722 +Serialized zstd-22: 15449423 bytes Total: 15504722 bytes +ttt: starting 3 epochs of test-time training (lr=0.0001) + ttt_epoch:1/3 loss:1.9384 + ttt_epoch:2/3 loss:1.9383 + ttt_epoch:3/3 loss:1.9383 +ttt: done in 115.0s +final_eval_mode:sliding_window_cache stride:64 batch_seqs:32 
docs:50000 + cache_eval [ 0.0%] 32/121106 windows running_bpb=1.200132 + cache_eval [ 1.3%] 1632/121106 windows running_bpb=1.128135 + cache_eval [ 2.7%] 3232/121106 windows running_bpb=1.131753 + cache_eval [ 4.0%] 4832/121106 windows running_bpb=1.125248 + cache_eval [ 5.3%] 6432/121106 windows running_bpb=1.136743 + cache_eval [ 6.6%] 8032/121106 windows running_bpb=1.138170 + cache_eval [ 8.0%] 9632/121106 windows running_bpb=1.139946 + cache_eval [ 9.3%] 11232/121106 windows running_bpb=1.135473 + cache_eval [ 10.6%] 12832/121106 windows running_bpb=1.132988 + cache_eval [ 11.9%] 14432/121106 windows running_bpb=1.134791 + cache_eval [ 13.2%] 16032/121106 windows running_bpb=1.143362 + cache_eval [ 14.6%] 17632/121106 windows running_bpb=1.141847 + cache_eval [ 15.9%] 19232/121106 windows running_bpb=1.143136 + cache_eval [ 17.2%] 20832/121106 windows running_bpb=1.141409 + cache_eval [ 18.5%] 22432/121106 windows running_bpb=1.139894 + cache_eval [ 19.8%] 24032/121106 windows running_bpb=1.140201 + cache_eval [ 21.2%] 25632/121106 windows running_bpb=1.141730 + cache_eval [ 22.5%] 27232/121106 windows running_bpb=1.142189 + cache_eval [ 23.8%] 28832/121106 windows running_bpb=1.148183 + cache_eval [ 25.1%] 30432/121106 windows running_bpb=1.145760 + cache_eval [ 26.4%] 32032/121106 windows running_bpb=1.146732 + cache_eval [ 27.8%] 33632/121106 windows running_bpb=1.145414 + cache_eval [ 29.1%] 35232/121106 windows running_bpb=1.144865 + cache_eval [ 30.4%] 36832/121106 windows running_bpb=1.144488 + cache_eval [ 31.7%] 38432/121106 windows running_bpb=1.145160 + cache_eval [ 33.1%] 40032/121106 windows running_bpb=1.142655 + cache_eval [ 34.4%] 41632/121106 windows running_bpb=1.141567 + cache_eval [ 35.7%] 43232/121106 windows running_bpb=1.141951 + cache_eval [ 37.0%] 44832/121106 windows running_bpb=1.140691 + cache_eval [ 38.3%] 46432/121106 windows running_bpb=1.140579 + cache_eval [ 39.7%] 48032/121106 windows running_bpb=1.139845 + cache_eval [ 41.0%] 
49632/121106 windows running_bpb=1.141030 + cache_eval [ 42.3%] 51232/121106 windows running_bpb=1.142129 + cache_eval [ 43.6%] 52832/121106 windows running_bpb=1.142639 + cache_eval [ 44.9%] 54432/121106 windows running_bpb=1.142166 + cache_eval [ 46.3%] 56032/121106 windows running_bpb=1.142539 + cache_eval [ 47.6%] 57632/121106 windows running_bpb=1.141640 + cache_eval [ 48.9%] 59232/121106 windows running_bpb=1.137732 + cache_eval [ 50.2%] 60832/121106 windows running_bpb=1.137876 + cache_eval [ 51.6%] 62432/121106 windows running_bpb=1.138834 + cache_eval [ 52.9%] 64032/121106 windows running_bpb=1.138990 + cache_eval [ 54.2%] 65632/121106 windows running_bpb=1.138892 + cache_eval [ 55.5%] 67232/121106 windows running_bpb=1.137706 + cache_eval [ 56.8%] 68832/121106 windows running_bpb=1.137473 + cache_eval [ 58.2%] 70432/121106 windows running_bpb=1.136778 + cache_eval [ 59.5%] 72032/121106 windows running_bpb=1.136860 + cache_eval [ 60.8%] 73632/121106 windows running_bpb=1.136786 + cache_eval [ 62.1%] 75232/121106 windows running_bpb=1.136981 + cache_eval [ 63.4%] 76832/121106 windows running_bpb=1.136759 + cache_eval [ 64.8%] 78432/121106 windows running_bpb=1.137413 + cache_eval [ 66.1%] 80032/121106 windows running_bpb=1.137712 + cache_eval [ 67.4%] 81632/121106 windows running_bpb=1.137421 + cache_eval [ 68.7%] 83232/121106 windows running_bpb=1.138431 + cache_eval [ 70.0%] 84832/121106 windows running_bpb=1.140362 + cache_eval [ 71.4%] 86432/121106 windows running_bpb=1.139637 + cache_eval [ 72.7%] 88032/121106 windows running_bpb=1.140307 + cache_eval [ 74.0%] 89632/121106 windows running_bpb=1.140658 + cache_eval [ 75.3%] 91232/121106 windows running_bpb=1.140666 + cache_eval [ 76.7%] 92832/121106 windows running_bpb=1.140263 + cache_eval [ 78.0%] 94432/121106 windows running_bpb=1.140510 + cache_eval [ 79.3%] 96032/121106 windows running_bpb=1.139904 + cache_eval [ 80.6%] 97632/121106 windows running_bpb=1.142744 + cache_eval [ 81.9%] 99232/121106 
windows running_bpb=1.142768 + cache_eval [ 83.3%] 100832/121106 windows running_bpb=1.142777 + cache_eval [ 84.6%] 102432/121106 windows running_bpb=1.142421 + cache_eval [ 85.9%] 104032/121106 windows running_bpb=1.141930 + cache_eval [ 87.2%] 105632/121106 windows running_bpb=1.141202 + cache_eval [ 88.5%] 107232/121106 windows running_bpb=1.141189 + cache_eval [ 89.9%] 108832/121106 windows running_bpb=1.141821 + cache_eval [ 91.2%] 110432/121106 windows running_bpb=1.141865 + cache_eval [ 92.5%] 112032/121106 windows running_bpb=1.141870 + cache_eval [ 93.8%] 113632/121106 windows running_bpb=1.142325 + cache_eval [ 95.1%] 115232/121106 windows running_bpb=1.142088 + cache_eval [ 96.5%] 116832/121106 windows running_bpb=1.141720 + cache_eval [ 97.8%] 118432/121106 windows running_bpb=1.142036 + cache_eval [ 99.1%] 120032/121106 windows running_bpb=1.142149 +final_roundtrip val_loss:1.9199 val_bpb:1.1371 eval_time:222952ms +final_roundtrip_exact val_loss:1.91991417 val_bpb:1.13708132 diff --git a/records/track_10min_16mb/2026-03-27_DominationV2_BigramCache_TTT/train_seed2025.log b/records/track_10min_16mb/2026-03-27_DominationV2_BigramCache_TTT/train_seed2025.log new file mode 100644 index 000000000..439ebf733 --- /dev/null +++ b/records/track_10min_16mb/2026-03-27_DominationV2_BigramCache_TTT/train_seed2025.log @@ -0,0 +1,238 @@ +W0327 12:46:56.558000 100729 torch/distributed/run.py:803] +W0327 12:46:56.558000 100729 torch/distributed/run.py:803] ***************************************** +W0327 12:46:56.558000 100729 torch/distributed/run.py:803] Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed. 
+W0327 12:46:56.558000 100729 torch/distributed/run.py:803] ***************************************** +logs/domv2_cache_seed2025.txt +train_loader:dataset:fineweb10B_sp1024 train_shards:80 +val_loader:tokens:62021632 +model_params:26829913 swa:False compression:zstd-22 +bigram_vocab:2048 bigram_dim:128 grad_clip:0.3 muon_wd:0.04 +train_batch_tokens:524288 train_seq_len:2048 warmdown:3000 seed:2025 +ste_qat:False ste_range:31 int6_cats:frozenset({'attn', 'mlp'}) +xsa_last_n:4 ema:True ema_decay:0.997 ttt:True ttt_epochs:3 +cache_enabled:True cache_alpha:0.2 cache_tau:8.0 cache_entropy_power:1.0 +warmup_step:1/20 +warmup_step:2/20 +warmup_step:3/20 +warmup_step:4/20 +warmup_step:5/20 +warmup_step:6/20 +warmup_step:7/20 +warmup_step:8/20 +warmup_step:9/20 +warmup_step:10/20 +warmup_step:11/20 +warmup_step:12/20 +warmup_step:13/20 +warmup_step:14/20 +warmup_step:15/20 +warmup_step:16/20 +warmup_step:17/20 +warmup_step:18/20 +warmup_step:19/20 +warmup_step:20/20 +step:0/20000 val_loss:6.9277 val_bpb:4.1030 train_time:0ms step_avg:0.02ms +step:1/20000 train_loss:6.9275 train_time:124ms step_avg:124.07ms +step:2/20000 train_loss:8.5401 train_time:189ms step_avg:94.51ms +step:3/20000 train_loss:7.7941 train_time:258ms step_avg:85.96ms +step:4/20000 train_loss:7.2821 train_time:327ms step_avg:81.83ms +step:5/20000 train_loss:6.8628 train_time:396ms step_avg:79.20ms +step:6/20000 train_loss:7.7249 train_time:465ms step_avg:77.45ms +step:7/20000 train_loss:6.7533 train_time:533ms step_avg:76.20ms +step:8/20000 train_loss:6.6126 train_time:602ms step_avg:75.25ms +step:9/20000 train_loss:6.4335 train_time:671ms step_avg:74.51ms +step:10/20000 train_loss:6.1966 train_time:739ms step_avg:73.93ms +step:100/20000 train_loss:3.2883 train_time:6961ms step_avg:69.61ms +step:200/20000 train_loss:2.7451 train_time:13946ms step_avg:69.73ms +step:300/20000 train_loss:2.3738 train_time:20881ms step_avg:69.60ms +step:400/20000 train_loss:2.2459 train_time:27882ms step_avg:69.70ms 
+step:500/20000 train_loss:2.3930 train_time:34827ms step_avg:69.65ms +step:500/20000 val_loss:2.3952 val_bpb:1.4186 train_time:34831ms step_avg:69.66ms +step:600/20000 train_loss:2.4555 train_time:41838ms step_avg:69.73ms +step:700/20000 train_loss:2.3517 train_time:48798ms step_avg:69.71ms +step:800/20000 train_loss:2.2098 train_time:55821ms step_avg:69.78ms +step:900/20000 train_loss:2.2570 train_time:62795ms step_avg:69.77ms +step:1000/20000 train_loss:2.3149 train_time:69826ms step_avg:69.83ms +step:1000/20000 val_loss:2.2656 val_bpb:1.3418 train_time:69831ms step_avg:69.83ms +step:1100/20000 train_loss:2.1855 train_time:76798ms step_avg:69.82ms +step:1200/20000 train_loss:2.3350 train_time:83826ms step_avg:69.85ms +step:1300/20000 train_loss:2.3135 train_time:90786ms step_avg:69.84ms +step:1400/20000 train_loss:2.3723 train_time:97803ms step_avg:69.86ms +step:1500/20000 train_loss:2.1758 train_time:104765ms step_avg:69.84ms +step:1500/20000 val_loss:2.2216 val_bpb:1.3157 train_time:104771ms step_avg:69.85ms +step:1600/20000 train_loss:2.0461 train_time:111788ms step_avg:69.87ms +step:1700/20000 train_loss:2.1150 train_time:118741ms step_avg:69.85ms +step:1800/20000 train_loss:2.1446 train_time:125748ms step_avg:69.86ms +step:1900/20000 train_loss:2.1321 train_time:132777ms step_avg:69.88ms +step:2000/20000 train_loss:2.1842 train_time:139783ms step_avg:69.89ms +step:2000/20000 val_loss:2.1692 val_bpb:1.2847 train_time:139788ms step_avg:69.89ms +step:2100/20000 train_loss:2.2051 train_time:146787ms step_avg:69.90ms +step:2200/20000 train_loss:2.0067 train_time:153722ms step_avg:69.87ms +step:2300/20000 train_loss:2.3036 train_time:160719ms step_avg:69.88ms +step:2400/20000 train_loss:2.1309 train_time:167653ms step_avg:69.86ms +step:2500/20000 train_loss:2.0619 train_time:174661ms step_avg:69.86ms +step:2500/20000 val_loss:2.1415 val_bpb:1.2683 train_time:174666ms step_avg:69.87ms +step:2600/20000 train_loss:2.3659 train_time:181586ms step_avg:69.84ms 
+step:2700/20000 train_loss:2.0898 train_time:188571ms step_avg:69.84ms +step:2800/20000 train_loss:2.1759 train_time:195501ms step_avg:69.82ms +step:2900/20000 train_loss:2.1235 train_time:202486ms step_avg:69.82ms +step:3000/20000 train_loss:2.1651 train_time:209422ms step_avg:69.81ms +step:3000/20000 val_loss:2.1282 val_bpb:1.2604 train_time:209426ms step_avg:69.81ms +step:3100/20000 train_loss:2.1399 train_time:216403ms step_avg:69.81ms +step:3200/20000 train_loss:2.1256 train_time:223321ms step_avg:69.79ms +step:3300/20000 train_loss:2.1788 train_time:230301ms step_avg:69.79ms +step:3400/20000 train_loss:2.1022 train_time:237215ms step_avg:69.77ms +step:3500/20000 train_loss:2.1929 train_time:244197ms step_avg:69.77ms +step:3500/20000 val_loss:2.1201 val_bpb:1.2556 train_time:244202ms step_avg:69.77ms +step:3600/20000 train_loss:2.0407 train_time:251116ms step_avg:69.75ms +step:3700/20000 train_loss:2.0792 train_time:258090ms step_avg:69.75ms +step:3800/20000 train_loss:2.1496 train_time:265003ms step_avg:69.74ms +step:3900/20000 train_loss:1.9308 train_time:271986ms step_avg:69.74ms +step:4000/20000 train_loss:2.1223 train_time:278902ms step_avg:69.73ms +step:4000/20000 val_loss:2.1119 val_bpb:1.2508 train_time:278907ms step_avg:69.73ms +step:4100/20000 train_loss:2.1384 train_time:285877ms step_avg:69.73ms +step:4200/20000 train_loss:2.1164 train_time:292857ms step_avg:69.73ms +step:4300/20000 train_loss:1.9602 train_time:299770ms step_avg:69.71ms +step:4400/20000 train_loss:2.0572 train_time:306746ms step_avg:69.71ms +step:4500/20000 train_loss:2.2054 train_time:313660ms step_avg:69.70ms +step:4500/20000 val_loss:2.1067 val_bpb:1.2477 train_time:313665ms step_avg:69.70ms +step:4600/20000 train_loss:1.9148 train_time:320640ms step_avg:69.70ms +step:4700/20000 train_loss:2.2181 train_time:327560ms step_avg:69.69ms +step:4800/20000 train_loss:2.2066 train_time:334533ms step_avg:69.69ms +step:4900/20000 train_loss:2.1138 train_time:341448ms step_avg:69.68ms 
+step:5000/20000 train_loss:1.9636 train_time:348423ms step_avg:69.68ms +step:5000/20000 val_loss:2.1013 val_bpb:1.2445 train_time:348427ms step_avg:69.69ms +step:5100/20000 train_loss:1.9763 train_time:355345ms step_avg:69.68ms +step:5200/20000 train_loss:2.1252 train_time:362321ms step_avg:69.68ms +step:5300/20000 train_loss:2.1580 train_time:369233ms step_avg:69.67ms +step:5400/20000 train_loss:2.1434 train_time:376212ms step_avg:69.67ms +step:5500/20000 train_loss:2.0974 train_time:383124ms step_avg:69.66ms +step:5500/20000 val_loss:2.0999 val_bpb:1.2436 train_time:383128ms step_avg:69.66ms +step:5600/20000 train_loss:2.1380 train_time:390098ms step_avg:69.66ms +step:5700/20000 train_loss:2.1239 train_time:397018ms step_avg:69.65ms +step:5800/20000 train_loss:2.0867 train_time:403994ms step_avg:69.65ms +step:5900/20000 train_loss:2.0385 train_time:410909ms step_avg:69.65ms +step:6000/20000 train_loss:2.1580 train_time:417884ms step_avg:69.65ms +step:6000/20000 val_loss:2.0859 val_bpb:1.2354 train_time:417888ms step_avg:69.65ms +step:6100/20000 train_loss:2.0588 train_time:424800ms step_avg:69.64ms +step:6200/20000 train_loss:2.0257 train_time:431782ms step_avg:69.64ms +step:6300/20000 train_loss:1.9636 train_time:438754ms step_avg:69.64ms +step:6400/20000 train_loss:2.0944 train_time:445661ms step_avg:69.63ms +step:6500/20000 train_loss:2.0108 train_time:452650ms step_avg:69.64ms +step:6500/20000 val_loss:2.0640 val_bpb:1.2224 train_time:452655ms step_avg:69.64ms +step:6600/20000 train_loss:2.0481 train_time:459565ms step_avg:69.63ms +step:6700/20000 train_loss:2.0820 train_time:466549ms step_avg:69.63ms +step:6800/20000 train_loss:2.1004 train_time:473464ms step_avg:69.63ms +step:6900/20000 train_loss:2.0111 train_time:480444ms step_avg:69.63ms +step:7000/20000 train_loss:2.1371 train_time:487358ms step_avg:69.62ms +step:7000/20000 val_loss:2.0395 val_bpb:1.2079 train_time:487364ms step_avg:69.62ms +step:7100/20000 train_loss:1.9692 train_time:494334ms 
step_avg:69.62ms +step:7200/20000 train_loss:2.1018 train_time:501246ms step_avg:69.62ms +step:7300/20000 train_loss:1.9896 train_time:508228ms step_avg:69.62ms +step:7400/20000 train_loss:2.0154 train_time:515144ms step_avg:69.61ms +step:7500/20000 train_loss:2.0061 train_time:522119ms step_avg:69.62ms +step:7500/20000 val_loss:2.0130 val_bpb:1.1922 train_time:522124ms step_avg:69.62ms +step:7600/20000 train_loss:1.8836 train_time:529032ms step_avg:69.61ms +step:7700/20000 train_loss:1.9588 train_time:536017ms step_avg:69.61ms +step:7800/20000 train_loss:2.0204 train_time:542937ms step_avg:69.61ms +step:7900/20000 train_loss:1.9967 train_time:549912ms step_avg:69.61ms +step:8000/20000 train_loss:1.9800 train_time:556833ms step_avg:69.60ms +step:8000/20000 val_loss:1.9824 val_bpb:1.1741 train_time:556837ms step_avg:69.60ms +step:8100/20000 train_loss:2.0078 train_time:563815ms step_avg:69.61ms +step:8200/20000 train_loss:2.0440 train_time:570730ms step_avg:69.60ms +step:8300/20000 train_loss:1.9585 train_time:577713ms step_avg:69.60ms +step:8400/20000 train_loss:1.9727 train_time:584696ms step_avg:69.61ms +step:8500/20000 train_loss:1.9535 train_time:591618ms step_avg:69.60ms +step:8500/20000 val_loss:1.9508 val_bpb:1.1554 train_time:591623ms step_avg:69.60ms +step:8600/20000 train_loss:1.9812 train_time:598599ms step_avg:69.60ms +step:8621/20000 val_loss:1.9467 val_bpb:1.1530 train_time:600055ms step_avg:69.60ms +stopping_early: wallclock_cap train_time:600055ms step:8621/20000 +peak memory: 14069 MiB +ema: applying EMA weights +Serialized model: 105789375 bytes Code: 55299 bytes +final_int8_zlib_roundtrip compressed_model_bytes:15450463 code_bytes:55299 total_artifact_bytes:15505762 +Serialized zstd-22: 15450463 bytes Total: 15505762 bytes +ttt: starting 3 epochs of test-time training (lr=0.0001) + ttt_epoch:1/3 loss:1.9414 + ttt_epoch:2/3 loss:1.9414 + ttt_epoch:3/3 loss:1.9413 +ttt: done in 118.5s +final_eval_mode:sliding_window_cache stride:64 batch_seqs:32 
docs:50000 + cache_eval [ 0.0%] 32/121106 windows running_bpb=1.205968 + cache_eval [ 1.3%] 1632/121106 windows running_bpb=1.130748 + cache_eval [ 2.7%] 3232/121106 windows running_bpb=1.133911 + cache_eval [ 4.0%] 4832/121106 windows running_bpb=1.127734 + cache_eval [ 5.3%] 6432/121106 windows running_bpb=1.138936 + cache_eval [ 6.6%] 8032/121106 windows running_bpb=1.140306 + cache_eval [ 8.0%] 9632/121106 windows running_bpb=1.142094 + cache_eval [ 9.3%] 11232/121106 windows running_bpb=1.137667 + cache_eval [ 10.6%] 12832/121106 windows running_bpb=1.135047 + cache_eval [ 11.9%] 14432/121106 windows running_bpb=1.136844 + cache_eval [ 13.2%] 16032/121106 windows running_bpb=1.145309 + cache_eval [ 14.6%] 17632/121106 windows running_bpb=1.143876 + cache_eval [ 15.9%] 19232/121106 windows running_bpb=1.145193 + cache_eval [ 17.2%] 20832/121106 windows running_bpb=1.143445 + cache_eval [ 18.5%] 22432/121106 windows running_bpb=1.142004 + cache_eval [ 19.8%] 24032/121106 windows running_bpb=1.142330 + cache_eval [ 21.2%] 25632/121106 windows running_bpb=1.143715 + cache_eval [ 22.5%] 27232/121106 windows running_bpb=1.144164 + cache_eval [ 23.8%] 28832/121106 windows running_bpb=1.150123 + cache_eval [ 25.1%] 30432/121106 windows running_bpb=1.147637 + cache_eval [ 26.4%] 32032/121106 windows running_bpb=1.148574 + cache_eval [ 27.8%] 33632/121106 windows running_bpb=1.147185 + cache_eval [ 29.1%] 35232/121106 windows running_bpb=1.146599 + cache_eval [ 30.4%] 36832/121106 windows running_bpb=1.146223 + cache_eval [ 31.7%] 38432/121106 windows running_bpb=1.146896 + cache_eval [ 33.1%] 40032/121106 windows running_bpb=1.144406 + cache_eval [ 34.4%] 41632/121106 windows running_bpb=1.143350 + cache_eval [ 35.7%] 43232/121106 windows running_bpb=1.143715 + cache_eval [ 37.0%] 44832/121106 windows running_bpb=1.142444 + cache_eval [ 38.3%] 46432/121106 windows running_bpb=1.142401 + cache_eval [ 39.7%] 48032/121106 windows running_bpb=1.141617 + cache_eval [ 41.0%] 
49632/121106 windows running_bpb=1.142826 + cache_eval [ 42.3%] 51232/121106 windows running_bpb=1.143886 + cache_eval [ 43.6%] 52832/121106 windows running_bpb=1.144447 + cache_eval [ 44.9%] 54432/121106 windows running_bpb=1.144011 + cache_eval [ 46.3%] 56032/121106 windows running_bpb=1.144430 + cache_eval [ 47.6%] 57632/121106 windows running_bpb=1.143505 + cache_eval [ 48.9%] 59232/121106 windows running_bpb=1.139606 + cache_eval [ 50.2%] 60832/121106 windows running_bpb=1.139748 + cache_eval [ 51.6%] 62432/121106 windows running_bpb=1.140688 + cache_eval [ 52.9%] 64032/121106 windows running_bpb=1.140847 + cache_eval [ 54.2%] 65632/121106 windows running_bpb=1.140738 + cache_eval [ 55.5%] 67232/121106 windows running_bpb=1.139558 + cache_eval [ 56.8%] 68832/121106 windows running_bpb=1.139337 + cache_eval [ 58.2%] 70432/121106 windows running_bpb=1.138679 + cache_eval [ 59.5%] 72032/121106 windows running_bpb=1.138838 + cache_eval [ 60.8%] 73632/121106 windows running_bpb=1.138780 + cache_eval [ 62.1%] 75232/121106 windows running_bpb=1.138948 + cache_eval [ 63.4%] 76832/121106 windows running_bpb=1.138726 + cache_eval [ 64.8%] 78432/121106 windows running_bpb=1.139326 + cache_eval [ 66.1%] 80032/121106 windows running_bpb=1.139616 + cache_eval [ 67.4%] 81632/121106 windows running_bpb=1.139306 + cache_eval [ 68.7%] 83232/121106 windows running_bpb=1.140344 + cache_eval [ 70.0%] 84832/121106 windows running_bpb=1.142276 + cache_eval [ 71.4%] 86432/121106 windows running_bpb=1.141561 + cache_eval [ 72.7%] 88032/121106 windows running_bpb=1.142234 + cache_eval [ 74.0%] 89632/121106 windows running_bpb=1.142599 + cache_eval [ 75.3%] 91232/121106 windows running_bpb=1.142596 + cache_eval [ 76.7%] 92832/121106 windows running_bpb=1.142200 + cache_eval [ 78.0%] 94432/121106 windows running_bpb=1.142448 + cache_eval [ 79.3%] 96032/121106 windows running_bpb=1.141854 + cache_eval [ 80.6%] 97632/121106 windows running_bpb=1.144712 + cache_eval [ 81.9%] 99232/121106 
windows running_bpb=1.144708 + cache_eval [ 83.3%] 100832/121106 windows running_bpb=1.144742 + cache_eval [ 84.6%] 102432/121106 windows running_bpb=1.144390 + cache_eval [ 85.9%] 104032/121106 windows running_bpb=1.143899 + cache_eval [ 87.2%] 105632/121106 windows running_bpb=1.143180 + cache_eval [ 88.5%] 107232/121106 windows running_bpb=1.143176 + cache_eval [ 89.9%] 108832/121106 windows running_bpb=1.143798 + cache_eval [ 91.2%] 110432/121106 windows running_bpb=1.143838 + cache_eval [ 92.5%] 112032/121106 windows running_bpb=1.143843 + cache_eval [ 93.8%] 113632/121106 windows running_bpb=1.144276 + cache_eval [ 95.1%] 115232/121106 windows running_bpb=1.144065 + cache_eval [ 96.5%] 116832/121106 windows running_bpb=1.143722 + cache_eval [ 97.8%] 118432/121106 windows running_bpb=1.144020 + cache_eval [ 99.1%] 120032/121106 windows running_bpb=1.144131 +final_roundtrip val_loss:1.9230 val_bpb:1.1389 eval_time:218156ms +final_roundtrip_exact val_loss:1.92302940 val_bpb:1.13892633 diff --git a/records/track_10min_16mb/2026-03-27_DominationV2_BigramCache_TTT/train_seed42.log b/records/track_10min_16mb/2026-03-27_DominationV2_BigramCache_TTT/train_seed42.log new file mode 100644 index 000000000..060639346 --- /dev/null +++ b/records/track_10min_16mb/2026-03-27_DominationV2_BigramCache_TTT/train_seed42.log @@ -0,0 +1,238 @@ +W0327 12:21:11.273000 99757 torch/distributed/run.py:803] +W0327 12:21:11.273000 99757 torch/distributed/run.py:803] ***************************************** +W0327 12:21:11.273000 99757 torch/distributed/run.py:803] Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed. 
+W0327 12:21:11.273000 99757 torch/distributed/run.py:803] ***************************************** +logs/domv2_cache_seed42.txt +train_loader:dataset:fineweb10B_sp1024 train_shards:80 +val_loader:tokens:62021632 +model_params:26829913 swa:False compression:zstd-22 +bigram_vocab:2048 bigram_dim:128 grad_clip:0.3 muon_wd:0.04 +train_batch_tokens:524288 train_seq_len:2048 warmdown:3000 seed:42 +ste_qat:False ste_range:31 int6_cats:frozenset({'attn', 'mlp'}) +xsa_last_n:4 ema:True ema_decay:0.997 ttt:True ttt_epochs:3 +cache_enabled:True cache_alpha:0.2 cache_tau:8.0 cache_entropy_power:1.0 +warmup_step:1/20 +warmup_step:2/20 +warmup_step:3/20 +warmup_step:4/20 +warmup_step:5/20 +warmup_step:6/20 +warmup_step:7/20 +warmup_step:8/20 +warmup_step:9/20 +warmup_step:10/20 +warmup_step:11/20 +warmup_step:12/20 +warmup_step:13/20 +warmup_step:14/20 +warmup_step:15/20 +warmup_step:16/20 +warmup_step:17/20 +warmup_step:18/20 +warmup_step:19/20 +warmup_step:20/20 +step:0/20000 val_loss:6.9307 val_bpb:4.1047 train_time:0ms step_avg:0.01ms +step:1/20000 train_loss:6.9307 train_time:126ms step_avg:125.72ms +step:2/20000 train_loss:8.6915 train_time:190ms step_avg:95.17ms +step:3/20000 train_loss:7.8965 train_time:259ms step_avg:86.36ms +step:4/20000 train_loss:7.2439 train_time:328ms step_avg:81.96ms +step:5/20000 train_loss:6.8270 train_time:397ms step_avg:79.31ms +step:6/20000 train_loss:7.7771 train_time:465ms step_avg:77.53ms +step:7/20000 train_loss:6.7029 train_time:534ms step_avg:76.25ms +step:8/20000 train_loss:6.6068 train_time:602ms step_avg:75.30ms +step:9/20000 train_loss:6.4193 train_time:671ms step_avg:74.57ms +step:10/20000 train_loss:6.1913 train_time:740ms step_avg:74.01ms +step:100/20000 train_loss:3.2843 train_time:6975ms step_avg:69.75ms +step:200/20000 train_loss:2.7476 train_time:13973ms step_avg:69.87ms +step:300/20000 train_loss:2.3722 train_time:20909ms step_avg:69.70ms +step:400/20000 train_loss:2.2493 train_time:27911ms step_avg:69.78ms +step:500/20000 
train_loss:2.3924 train_time:34858ms step_avg:69.72ms +step:500/20000 val_loss:2.3869 val_bpb:1.4136 train_time:34862ms step_avg:69.72ms +step:600/20000 train_loss:2.4527 train_time:41875ms step_avg:69.79ms +step:700/20000 train_loss:2.3473 train_time:48849ms step_avg:69.78ms +step:800/20000 train_loss:2.2073 train_time:55890ms step_avg:69.86ms +step:900/20000 train_loss:2.2624 train_time:62871ms step_avg:69.86ms +step:1000/20000 train_loss:2.3122 train_time:69919ms step_avg:69.92ms +step:1000/20000 val_loss:2.2634 val_bpb:1.3405 train_time:69923ms step_avg:69.92ms +step:1100/20000 train_loss:2.1868 train_time:76910ms step_avg:69.92ms +step:1200/20000 train_loss:2.3350 train_time:83955ms step_avg:69.96ms +step:1300/20000 train_loss:2.3151 train_time:90954ms step_avg:69.96ms +step:1400/20000 train_loss:2.3734 train_time:97987ms step_avg:69.99ms +step:1500/20000 train_loss:2.1814 train_time:104960ms step_avg:69.97ms +step:1500/20000 val_loss:2.2217 val_bpb:1.3158 train_time:104965ms step_avg:69.98ms +step:1600/20000 train_loss:2.0455 train_time:111994ms step_avg:70.00ms +step:1700/20000 train_loss:2.1141 train_time:118977ms step_avg:69.99ms +step:1800/20000 train_loss:2.1447 train_time:125996ms step_avg:70.00ms +step:1900/20000 train_loss:2.1353 train_time:132969ms step_avg:69.98ms +step:2000/20000 train_loss:2.1861 train_time:139991ms step_avg:70.00ms +step:2000/20000 val_loss:2.1697 val_bpb:1.2850 train_time:139996ms step_avg:70.00ms +step:2100/20000 train_loss:2.2077 train_time:147013ms step_avg:70.01ms +step:2200/20000 train_loss:2.0064 train_time:153970ms step_avg:69.99ms +step:2300/20000 train_loss:2.3035 train_time:160984ms step_avg:69.99ms +step:2400/20000 train_loss:2.1315 train_time:167943ms step_avg:69.98ms +step:2500/20000 train_loss:2.0674 train_time:174957ms step_avg:69.98ms +step:2500/20000 val_loss:2.1422 val_bpb:1.2687 train_time:174961ms step_avg:69.98ms +step:2600/20000 train_loss:2.3621 train_time:181924ms step_avg:69.97ms +step:2700/20000 
train_loss:2.0891 train_time:188934ms step_avg:69.98ms +step:2800/20000 train_loss:2.1726 train_time:195881ms step_avg:69.96ms +step:2900/20000 train_loss:2.1192 train_time:202881ms step_avg:69.96ms +step:3000/20000 train_loss:2.1615 train_time:209824ms step_avg:69.94ms +step:3000/20000 val_loss:2.1282 val_bpb:1.2604 train_time:209829ms step_avg:69.94ms +step:3100/20000 train_loss:2.1407 train_time:216829ms step_avg:69.94ms +step:3200/20000 train_loss:2.1315 train_time:223771ms step_avg:69.93ms +step:3300/20000 train_loss:2.1792 train_time:230765ms step_avg:69.93ms +step:3400/20000 train_loss:2.1065 train_time:237707ms step_avg:69.91ms +step:3500/20000 train_loss:2.1938 train_time:244699ms step_avg:69.91ms +step:3500/20000 val_loss:2.1210 val_bpb:1.2562 train_time:244702ms step_avg:69.91ms +step:3600/20000 train_loss:2.0435 train_time:251633ms step_avg:69.90ms +step:3700/20000 train_loss:2.0760 train_time:258626ms step_avg:69.90ms +step:3800/20000 train_loss:2.1514 train_time:265559ms step_avg:69.88ms +step:3900/20000 train_loss:1.9336 train_time:272547ms step_avg:69.88ms +step:4000/20000 train_loss:2.1178 train_time:279476ms step_avg:69.87ms +step:4000/20000 val_loss:2.1106 val_bpb:1.2500 train_time:279481ms step_avg:69.87ms +step:4100/20000 train_loss:2.1353 train_time:286468ms step_avg:69.87ms +step:4200/20000 train_loss:2.1135 train_time:293459ms step_avg:69.87ms +step:4300/20000 train_loss:1.9542 train_time:300398ms step_avg:69.86ms +step:4400/20000 train_loss:2.0547 train_time:307385ms step_avg:69.86ms +step:4500/20000 train_loss:2.2044 train_time:314322ms step_avg:69.85ms +step:4500/20000 val_loss:2.1075 val_bpb:1.2482 train_time:314327ms step_avg:69.85ms +step:4600/20000 train_loss:1.9174 train_time:321307ms step_avg:69.85ms +step:4700/20000 train_loss:2.2206 train_time:328234ms step_avg:69.84ms +step:4800/20000 train_loss:2.2065 train_time:335222ms step_avg:69.84ms +step:4900/20000 train_loss:2.1158 train_time:342146ms step_avg:69.83ms +step:5000/20000 
train_loss:1.9690 train_time:349130ms step_avg:69.83ms +step:5000/20000 val_loss:2.1014 val_bpb:1.2446 train_time:349135ms step_avg:69.83ms +step:5100/20000 train_loss:1.9775 train_time:356055ms step_avg:69.81ms +step:5200/20000 train_loss:2.1256 train_time:363038ms step_avg:69.81ms +step:5300/20000 train_loss:2.1586 train_time:369967ms step_avg:69.81ms +step:5400/20000 train_loss:2.1436 train_time:376953ms step_avg:69.81ms +step:5500/20000 train_loss:2.0943 train_time:383881ms step_avg:69.80ms +step:5500/20000 val_loss:2.1008 val_bpb:1.2442 train_time:383886ms step_avg:69.80ms +step:5600/20000 train_loss:2.1372 train_time:390858ms step_avg:69.80ms +step:5700/20000 train_loss:2.1220 train_time:397789ms step_avg:69.79ms +step:5800/20000 train_loss:2.0821 train_time:404769ms step_avg:69.79ms +step:5900/20000 train_loss:2.0377 train_time:411693ms step_avg:69.78ms +step:6000/20000 train_loss:2.1624 train_time:418677ms step_avg:69.78ms +step:6000/20000 val_loss:2.0866 val_bpb:1.2358 train_time:418682ms step_avg:69.78ms +step:6100/20000 train_loss:2.0556 train_time:425603ms step_avg:69.77ms +step:6200/20000 train_loss:2.0224 train_time:432585ms step_avg:69.77ms +step:6300/20000 train_loss:1.9659 train_time:439571ms step_avg:69.77ms +step:6400/20000 train_loss:2.0977 train_time:446494ms step_avg:69.76ms +step:6500/20000 train_loss:2.0067 train_time:453515ms step_avg:69.77ms +step:6500/20000 val_loss:2.0629 val_bpb:1.2218 train_time:453520ms step_avg:69.77ms +step:6600/20000 train_loss:2.0437 train_time:460444ms step_avg:69.76ms +step:6700/20000 train_loss:2.0778 train_time:467431ms step_avg:69.77ms +step:6800/20000 train_loss:2.0975 train_time:474358ms step_avg:69.76ms +step:6900/20000 train_loss:2.0120 train_time:481349ms step_avg:69.76ms +step:7000/20000 train_loss:2.1340 train_time:488273ms step_avg:69.75ms +step:7000/20000 val_loss:2.0388 val_bpb:1.2075 train_time:488277ms step_avg:69.75ms +step:7100/20000 train_loss:1.9695 train_time:495256ms step_avg:69.75ms 
+step:7200/20000 train_loss:2.1017 train_time:502186ms step_avg:69.75ms +step:7300/20000 train_loss:1.9890 train_time:509173ms step_avg:69.75ms +step:7400/20000 train_loss:2.0143 train_time:516103ms step_avg:69.74ms +step:7500/20000 train_loss:2.0008 train_time:523089ms step_avg:69.75ms +step:7500/20000 val_loss:2.0117 val_bpb:1.1915 train_time:523094ms step_avg:69.75ms +step:7600/20000 train_loss:1.8796 train_time:530022ms step_avg:69.74ms +step:7700/20000 train_loss:1.9559 train_time:537010ms step_avg:69.74ms +step:7800/20000 train_loss:2.0185 train_time:543939ms step_avg:69.74ms +step:7900/20000 train_loss:1.9954 train_time:550925ms step_avg:69.74ms +step:8000/20000 train_loss:1.9790 train_time:557849ms step_avg:69.73ms +step:8000/20000 val_loss:1.9808 val_bpb:1.1731 train_time:557853ms step_avg:69.73ms +step:8100/20000 train_loss:2.0079 train_time:564845ms step_avg:69.73ms +step:8200/20000 train_loss:2.0415 train_time:571778ms step_avg:69.73ms +step:8300/20000 train_loss:1.9585 train_time:578772ms step_avg:69.73ms +step:8400/20000 train_loss:1.9716 train_time:585772ms step_avg:69.73ms +step:8500/20000 train_loss:1.9465 train_time:592706ms step_avg:69.73ms +step:8500/20000 val_loss:1.9492 val_bpb:1.1544 train_time:592711ms step_avg:69.73ms +step:8600/20000 train_loss:1.9772 train_time:599696ms step_avg:69.73ms +step:8605/20000 val_loss:1.9459 val_bpb:1.1524 train_time:600045ms step_avg:69.73ms +stopping_early: wallclock_cap train_time:600045ms step:8605/20000 +peak memory: 14069 MiB +ema: applying EMA weights +Serialized model: 105789375 bytes Code: 55299 bytes +final_int8_zlib_roundtrip compressed_model_bytes:15524119 code_bytes:55299 total_artifact_bytes:15579418 +Serialized zstd-22: 15524119 bytes Total: 15579418 bytes +ttt: starting 3 epochs of test-time training (lr=0.0001) + ttt_epoch:1/3 loss:1.9412 + ttt_epoch:2/3 loss:1.9412 + ttt_epoch:3/3 loss:1.9411 +ttt: done in 112.2s +final_eval_mode:sliding_window_cache stride:64 batch_seqs:32 docs:50000 + 
cache_eval [ 0.0%] 32/121106 windows running_bpb=1.205432 + cache_eval [ 1.3%] 1632/121106 windows running_bpb=1.130149 + cache_eval [ 2.7%] 3232/121106 windows running_bpb=1.133107 + cache_eval [ 4.0%] 4832/121106 windows running_bpb=1.126767 + cache_eval [ 5.3%] 6432/121106 windows running_bpb=1.138463 + cache_eval [ 6.6%] 8032/121106 windows running_bpb=1.140096 + cache_eval [ 8.0%] 9632/121106 windows running_bpb=1.141881 + cache_eval [ 9.3%] 11232/121106 windows running_bpb=1.137747 + cache_eval [ 10.6%] 12832/121106 windows running_bpb=1.135296 + cache_eval [ 11.9%] 14432/121106 windows running_bpb=1.137162 + cache_eval [ 13.2%] 16032/121106 windows running_bpb=1.145607 + cache_eval [ 14.6%] 17632/121106 windows running_bpb=1.144110 + cache_eval [ 15.9%] 19232/121106 windows running_bpb=1.145228 + cache_eval [ 17.2%] 20832/121106 windows running_bpb=1.143506 + cache_eval [ 18.5%] 22432/121106 windows running_bpb=1.142122 + cache_eval [ 19.8%] 24032/121106 windows running_bpb=1.142332 + cache_eval [ 21.2%] 25632/121106 windows running_bpb=1.143682 + cache_eval [ 22.5%] 27232/121106 windows running_bpb=1.144200 + cache_eval [ 23.8%] 28832/121106 windows running_bpb=1.150219 + cache_eval [ 25.1%] 30432/121106 windows running_bpb=1.147678 + cache_eval [ 26.4%] 32032/121106 windows running_bpb=1.148642 + cache_eval [ 27.8%] 33632/121106 windows running_bpb=1.147317 + cache_eval [ 29.1%] 35232/121106 windows running_bpb=1.146657 + cache_eval [ 30.4%] 36832/121106 windows running_bpb=1.146221 + cache_eval [ 31.7%] 38432/121106 windows running_bpb=1.146848 + cache_eval [ 33.1%] 40032/121106 windows running_bpb=1.144372 + cache_eval [ 34.4%] 41632/121106 windows running_bpb=1.143269 + cache_eval [ 35.7%] 43232/121106 windows running_bpb=1.143604 + cache_eval [ 37.0%] 44832/121106 windows running_bpb=1.142279 + cache_eval [ 38.3%] 46432/121106 windows running_bpb=1.142217 + cache_eval [ 39.7%] 48032/121106 windows running_bpb=1.141502 + cache_eval [ 41.0%] 49632/121106 
windows running_bpb=1.142742 + cache_eval [ 42.3%] 51232/121106 windows running_bpb=1.143821 + cache_eval [ 43.6%] 52832/121106 windows running_bpb=1.144354 + cache_eval [ 44.9%] 54432/121106 windows running_bpb=1.143875 + cache_eval [ 46.3%] 56032/121106 windows running_bpb=1.144298 + cache_eval [ 47.6%] 57632/121106 windows running_bpb=1.143404 + cache_eval [ 48.9%] 59232/121106 windows running_bpb=1.139477 + cache_eval [ 50.2%] 60832/121106 windows running_bpb=1.139575 + cache_eval [ 51.6%] 62432/121106 windows running_bpb=1.140468 + cache_eval [ 52.9%] 64032/121106 windows running_bpb=1.140648 + cache_eval [ 54.2%] 65632/121106 windows running_bpb=1.140529 + cache_eval [ 55.5%] 67232/121106 windows running_bpb=1.139347 + cache_eval [ 56.8%] 68832/121106 windows running_bpb=1.139129 + cache_eval [ 58.2%] 70432/121106 windows running_bpb=1.138437 + cache_eval [ 59.5%] 72032/121106 windows running_bpb=1.138572 + cache_eval [ 60.8%] 73632/121106 windows running_bpb=1.138551 + cache_eval [ 62.1%] 75232/121106 windows running_bpb=1.138722 + cache_eval [ 63.4%] 76832/121106 windows running_bpb=1.138477 + cache_eval [ 64.8%] 78432/121106 windows running_bpb=1.139092 + cache_eval [ 66.1%] 80032/121106 windows running_bpb=1.139366 + cache_eval [ 67.4%] 81632/121106 windows running_bpb=1.139086 + cache_eval [ 68.7%] 83232/121106 windows running_bpb=1.140113 + cache_eval [ 70.0%] 84832/121106 windows running_bpb=1.142051 + cache_eval [ 71.4%] 86432/121106 windows running_bpb=1.141351 + cache_eval [ 72.7%] 88032/121106 windows running_bpb=1.142039 + cache_eval [ 74.0%] 89632/121106 windows running_bpb=1.142409 + cache_eval [ 75.3%] 91232/121106 windows running_bpb=1.142402 + cache_eval [ 76.7%] 92832/121106 windows running_bpb=1.141974 + cache_eval [ 78.0%] 94432/121106 windows running_bpb=1.142212 + cache_eval [ 79.3%] 96032/121106 windows running_bpb=1.141602 + cache_eval [ 80.6%] 97632/121106 windows running_bpb=1.144434 + cache_eval [ 81.9%] 99232/121106 windows 
running_bpb=1.144489 + cache_eval [ 83.3%] 100832/121106 windows running_bpb=1.144520 + cache_eval [ 84.6%] 102432/121106 windows running_bpb=1.144170 + cache_eval [ 85.9%] 104032/121106 windows running_bpb=1.143686 + cache_eval [ 87.2%] 105632/121106 windows running_bpb=1.142948 + cache_eval [ 88.5%] 107232/121106 windows running_bpb=1.142941 + cache_eval [ 89.9%] 108832/121106 windows running_bpb=1.143579 + cache_eval [ 91.2%] 110432/121106 windows running_bpb=1.143620 + cache_eval [ 92.5%] 112032/121106 windows running_bpb=1.143623 + cache_eval [ 93.8%] 113632/121106 windows running_bpb=1.144071 + cache_eval [ 95.1%] 115232/121106 windows running_bpb=1.143844 + cache_eval [ 96.5%] 116832/121106 windows running_bpb=1.143463 + cache_eval [ 97.8%] 118432/121106 windows running_bpb=1.143759 + cache_eval [ 99.1%] 120032/121106 windows running_bpb=1.143859 +final_roundtrip val_loss:1.9223 val_bpb:1.1385 eval_time:224846ms +final_roundtrip_exact val_loss:1.92231620 val_bpb:1.13850394