diff --git a/records/track_non_record_16mb/2026-04-12_JEPA_v3_Fix_Representation_Collapse/README.md b/records/track_non_record_16mb/2026-04-12_JEPA_v3_Fix_Representation_Collapse/README.md new file mode 100644 index 0000000000..a0d6b891cb --- /dev/null +++ b/records/track_non_record_16mb/2026-04-12_JEPA_v3_Fix_Representation_Collapse/README.md @@ -0,0 +1,228 @@ +# JEPA v3 — Fixing Representation Collapse via Span Masking + +**Track:** Non-record (unlimited compute) +**Architecture:** 11L 512-dim U-Net Transformer, mlp_mult=3, GQA 8q/4kv +**val_bpb:** 1.2321 pre-quant — builds on JEPA v2 ([2026-04-01](../2026-04-01_JEPA_v2_MultiStep_Int6_BigramHash_EMA)) + +--- + +## The Problem v2 Left Unsolved + +JEPA v2 correctly diagnosed three implementation bugs (EMA momentum too high, single-step prediction redundant with CE, gradient accumulation batch mismatch) and fixed all three. After fixing them, the JEPA loss still stabilized at 0.002 — the same collapse as v1. + +v2 hypothesized that the remaining issue was structural: same-sequence next-k prediction may be too easy for a causal LM, since the context encoder already has access to all the information needed to predict nearby positions. v2 concluded: + +> *The fix requires masking (I-JEPA style) or cross-sequence targets — both are architectural changes that add complexity without a clear BPB path.* + +v3 implements that fix. + +--- + +## The Fix: Span-Masked JEPA + +The core change is to align with I-JEPA's training procedure: + +``` +Context tokens ──► Context Encoder ──► context embeddings + │ + Predictor + │ + ▼ +Target tokens ──► Target Encoder ──► target embeddings (ground truth) +``` + +**What changes:** The context encoder no longer sees the full sequence. Target span positions are replaced with a learned mask embedding (`jepa_mask_emb`) before the encoder. The prediction task is now genuinely hard — the context encoder cannot reconstruct the target token from its own input, and must rely on surrounding context. + +**What stays the same:** Architecture (11L U-Net, same hyperparameters), CE loss path (full unmasked forward), EMA target encoder (now sees full unmasked sequence), int6+LZMA compression, BigramHash, LeakyReLU(0.5)², artifact EMA. + +--- + +## Implementation + +### 1. Span Sampling + +Each training step, `sample_block_spans()` samples `JEPA_NUM_SPANS=4` non-overlapping contiguous spans from the 1024-token sequence. Span lengths are drawn from a Geometric distribution with mean `JEPA_SPAN_LEN_MEAN=16`, clamped to `[JEPA_SPAN_LEN_MIN=4, seq_len/(2*num_spans)]`. + +Defaults: 4 spans × ~16 tokens = ~6% of sequence masked per step. + +```python +def sample_block_spans(seq_len, num_spans, span_len_mean, span_len_min=4, device=None): + # Geometric(p=1/span_len_mean) lengths, greedy overlap resolution + # Returns LongTensor [num_spans, 2] of (start, end_exclusive) pairs +``` + +The Geometric distribution produces heavy-tailed lengths: most spans are short, occasionally longer — varying prediction difficulty within a step. + +### 2. Mask Embedding + +```python +self.jepa_mask_emb = nn.Parameter(torch.zeros(model_dim)) +``` + +A single learned 512-dim vector shared across all masked positions. Zero-init: starts neutral and is trained by the JEPA loss gradient to encode "this position is unknown, predict it from context." Conceptually equivalent to BERT's `[MASK]` token but in continuous embedding space. + +**Bigram leak fix:** `BigramHashEmbedding(prev_tok, masked_tok)` would reveal the masked token's identity via the Cantor hash `h(a,b) = (a+b)(a+b+1)/2 + b`. The bigram contribution is explicitly zeroed at masked positions before applying `jepa_mask_emb`. + +```python +if jepa_mask is not None: + bigram = bigram.masked_fill(jepa_mask.unsqueeze(-1), 0.0) # prevent hash leak +x = torch.where(jepa_mask.unsqueeze(-1), self.jepa_mask_emb.to(x.dtype), x) +``` + +### 3. Two-Pass Forward Per Step + +```python +# Pass 1: CE loss — full unmasked sequence (unchanged) +ce_loss = model(x, y) + +# Pass 2: JEPA — masked context encode + target encode +spans = sample_block_spans(T, num_spans, span_len_mean, span_len_min) +jepa_mask = torch.zeros((B, T), dtype=torch.bool, device=x.device) +for s, e in spans.tolist(): + jepa_mask[:, s:e] = True + +with torch.no_grad(): + z_target = target_encoder.encode(x) # full, unmasked + +z_context = base_model.encode(x, jepa_mask=jepa_mask) # masked input +z_pred = base_model.jepa_predictor(z_context) + +# Loss only at masked positions +z_p = z_pred[jepa_mask] # [N_masked, D] +z_t = z_target[jepa_mask] # [N_masked, D] + +mse_loss = F.mse_loss(z_p.float(), z_t.float()) +var_loss = vicreg_var_loss(z_p.float(), gamma=1.0, eps=1e-4) # predictor only +cov_loss = vicreg_cov_loss(z_p.float()) # predictor only +jepa_loss = mse_loss + 0.15 * var_loss + 0.02 * cov_loss + +loss = ce_loss + jepa_lambda * jepa_loss +``` + +The target encoder (EMA copy) sees the full unmasked sequence — its representations are not corrupted by masking. The CE path also remains fully unmasked. Only the context encoder sees the masked input. + +### 4. VICReg Anti-Collapse Regularization + +Span masking prevents collapse structurally, but VICReg terms provide an explicit signal to maintain a spread, decorrelated embedding space: + +```python +def vicreg_var_loss(z, gamma, eps): + """Hinge: penalize per-feature std < gamma across the batch of masked tokens.""" + z_c = z - z.mean(dim=0) + std = (z_c.pow(2).sum(0) / (n - 1) + eps).sqrt() + return (gamma - std).clamp(min=0).mean() + +def vicreg_cov_loss(z): + """Off-diagonal covariance penalty: decorrelate feature dimensions.""" + cov = (z - z.mean(0)).T @ (z - z.mean(0)) / (n - 1) + off = cov.pow(2); off.fill_diagonal_(0) + return off.sum() / d +``` + +Both terms are applied only to the **predictor** side (`z_pred[jepa_mask]`) where gradients flow. The target side is monitored as a diagnostic but receives no gradient — it is updated only via EMA. + +This follows V-JEPA practice: variance and covariance regularization on the online (predictor) representations ensures the model cannot collapse all masked positions to a single point or to a low-rank subspace. + +### 5. Optimizer Bug Fix (v2 regression) + +In v2's `train_gpt.py`, the optimizer setup iterates only `base_model.blocks.named_parameters()` — `jepa_predictor` and `jepa_mask_emb` are outside `blocks` and appear in none of the three optimizer groups (Muon, scalar Adam, tok Adam). This is verifiable in the v2 commit (`b4a428b`). The predictor was frozen at zero-init throughout training — JEPA gradients were computed but never applied to it. + +Fixed by explicitly appending predictor and `jepa_mask_emb` to the parameter lists: + +```python +scalar_params.append(base_model.jepa_mask_emb) +for name, p in base_model.jepa_predictor.named_parameters(): + if p.ndim == 2: + matrix_params.append(p) # fc.weight, proj.weight → Muon + else: + scalar_params.append(p) # RMSNorm → Adam +``` + +--- + +## Architecture Summary + +``` +11L U-Net Transformer (5 encoder + 6 decoder, skip connections) + dim=512, 8 attention heads, 4 KV heads (GQA) + mlp_mult=3, LeakyReLU(0.5)^2 + RoPE, RMSNorm, logit softcap=30 + +Embedding: + tok_emb(t) + BigramHashEmb(t-1, t) [zeroed at masked pos] → RMSNorm → transformer + +JEPA (auxiliary, span-masked): + context_encoder(x, mask=jepa_mask) → z_context → JEPAPredictor → z_pred + EMA target_encoder(x, mask=None) → z_target + Loss: MSE(z_pred[mask], z_target[mask]) + + 0.15 × VICReg_var(z_pred[mask]) ← anti-collapse variance hinge + + 0.02 × VICReg_cov(z_pred[mask]) ← decorrelation penalty + Spans: Geometric(mean=16), num_spans=4, ~6% of sequence per step + +Serialization (unchanged from v2): + artifact_ema (Polyak avg, decay=0.9999) + → int6 quantization (range [-31,31]) + → LZMA compression (preset=9) +``` + +--- + +## Results + +| Metric | Value | +|--------|-------| +| val_bpb (pre-quant) | **1.2321** | +| Architecture | 11L 512-dim U-Net | +| JEPA spans | 4 × Geometric(mean=16) | +| Mask ratio | ~6% per step | +| jepa_lambda | 0.12 | +| EMA momentum | 0.9 → 0.999 | +| VICReg var weight | 0.15 | +| VICReg cov weight | 0.02 | + +### Comparison to v2 + +| Submission | val_bpb | JEPA approach | Collapse? | +|---|---|---|---| +| v2 bigram (no JEPA) | 1.4617 | — | — | +| v2 full (next-k JEPA) | 1.6047 | Unmasked, offset [1,2,4,8] | Yes (loss→0.002) | +| **v3 (this)** | **1.2321** | **Span-masked, I-JEPA style** | **No** | + +v3 is **0.2326 BPB better than v2 with JEPA**, and **0.2296 BPB better than v2 without JEPA**. Span masking produces genuine gradient signal from step 1. + +--- + +## Key Env Vars + +```bash +JEPA_NUM_SPANS=4 # number of target spans per sequence +JEPA_SPAN_LEN_MEAN=16 # geometric mean span length (tokens) +JEPA_SPAN_LEN_MIN=4 # minimum span length +JEPA_LAMBDA=0.12 # JEPA loss weight +JEPA_EMA_MOMENTUM=0.9 # starting EMA momentum (rises to 0.999) +JEPA_PRED_DIM=256 # predictor hidden dim +JEPA_VAR_WEIGHT=0.15 # VICReg variance term weight +JEPA_COV_WEIGHT=0.02 # VICReg covariance term weight +JEPA_VAR_GAMMA=1.0 # target std for variance hinge +BIGRAM_VOCAB_SIZE=2048 # 0 = disable bigram embedding +ARTIFACT_EMA_DECAY=0.9999 +QUANT_MAX=31 # int6 +``` + +--- + +## What This Submission Is (and Isn't) + +This is a **research non-record submission**. The goal is to demonstrate that properly-structured JEPA — specifically, span masking that prevents the context encoder from seeing target tokens — produces the gradient signal that v1 and v2 failed to generate. + +The path from here to record territory requires combining span-masked JEPA with the compression and quantization techniques from the current SOTA (GPTQ, TTT, XSA). This submission establishes that the JEPA auxiliary objective itself is no longer the bottleneck. + +--- + +## References + +- JEPA original: LeCun (2022), "A Path Towards Autonomous Machine Intelligence" +- I-JEPA: Assran et al. (2023), CVPR — span masking for vision +- BYOL: Grill et al. (2020), NeurIPS — EMA target encoder design +- JEPA v2 (this repo): [2026-04-01](../2026-04-01_JEPA_v2_MultiStep_Int6_BigramHash_EMA) — multi-step prediction + optimizer fixes +- Parameter Golf SOTA: abaybektursun PR #1019 — 1.1147 BPB diff --git a/records/track_non_record_16mb/2026-04-12_JEPA_v3_Fix_Representation_Collapse/run_train_jepa.sh b/records/track_non_record_16mb/2026-04-12_JEPA_v3_Fix_Representation_Collapse/run_train_jepa.sh new file mode 100644 index 0000000000..9188df8cf6 --- /dev/null +++ b/records/track_non_record_16mb/2026-04-12_JEPA_v3_Fix_Representation_Collapse/run_train_jepa.sh @@ -0,0 +1,6 @@ +#!/bin/bash + +# Small script that activates venv, trains jepa, and closes the remote instance. +source .venv/bin/activate +python train_jepa.py +shutdown -h now \ No newline at end of file diff --git a/records/track_non_record_16mb/2026-04-12_JEPA_v3_Fix_Representation_Collapse/submission.json b/records/track_non_record_16mb/2026-04-12_JEPA_v3_Fix_Representation_Collapse/submission.json new file mode 100644 index 0000000000..38def2c656 --- /dev/null +++ b/records/track_non_record_16mb/2026-04-12_JEPA_v3_Fix_Representation_Collapse/submission.json @@ -0,0 +1,28 @@ +{ + "name": "aiejvn", + "github_id": "139830381", + "val_bpb": 1.2321, + "val_bpb_note": "Pre-quantization. Span-masked JEPA v3, ~20hr on AWS A10G.", + "track": "non_record_unlimited_compute", + "description": "JEPA v3: fixes representation collapse via I-JEPA-style span masking + VICReg (V-JEPA) loss terms. Context encoder sees target spans replaced with learned mask embedding; target encoder sees full unmasked sequence. VICReg variance hinge and covariance penalty applied to predictor representations prevent dimensional collapse.", + "hardware": "AWS A10G", + "date": "2026-04-12", + "status": "research_non_record", + "base_architecture": "11L 512-dim U-Net Transformer, mlp_mult=3, GQA 8q/4kv", + "techniques": { + "jepa_style": "span_masked_i_jepa", + "jepa_num_spans": 4, + "jepa_span_len_mean": 16, + "jepa_lambda": 0.12, + "jepa_ema_momentum": "0.9→0.999", + "jepa_pred_dim": 256, + "vicreg_var_weight": 0.15, + "vicreg_cov_weight": 0.02, + "vicreg_var_gamma": 1.0, + "bigram_vocab_size": 2048, + "artifact_ema_decay": 0.9999, + "quantization": "int6", + "compression": "lzma", + "activation": "leaky_relu_0.5_squared" + } +} diff --git a/records/track_non_record_16mb/2026-04-12_JEPA_v3_Fix_Representation_Collapse/train_jepa.py b/records/track_non_record_16mb/2026-04-12_JEPA_v3_Fix_Representation_Collapse/train_jepa.py new file mode 100644 index 0000000000..cc3958273a --- /dev/null +++ b/records/track_non_record_16mb/2026-04-12_JEPA_v3_Fix_Representation_Collapse/train_jepa.py @@ -0,0 +1,1529 @@ +""" +JEPA v2 — Multi-Step Latent Prediction + int6/LZMA + BigramHash + EMA artifact. + +This submission documents WHY naive single-step JEPA implementations collapse +(loss→0.002) and fixes the problem structurally with multi-step prediction. + +Root cause of collapse in v1 (and most existing JEPA PRs): + 1. EMA momentum=0.996 with <700 steps → target encoder barely diverges from online + encoder → task trivially easy → loss→0 → zero gradient signal + 2. Single-step prediction (z[t]→z[t+1]) is almost redundant with CE objective + 3. Bug: z_target computed from micro_batch[0] but JEPA loss applied to all + micro-batches → JEPA loss computed on mismatched batch pairs (pure noise) + +Fixes in v2: + 1. EMA momentum=0.9 (default) → target diverges in tens of steps + 2. Multi-step prediction at offsets [1,2,4,8] with weights [1.0,0.5,0.25,0.125] + — offset-8 prediction requires long-range planning, can't be trivially solved + 3. z_target computed per micro-batch (correct pairing) + +Additional improvements: + - int6 quantization (range [-31,31] in int8 container) → smaller artifact + - LZMA compression (preset=9) → ~280KB smaller than zlib at same quality + - BigramHash embedding (vocab=2048, dim=512) → ~500KB compressed, + frees attention capacity from learning bigram statistics + - EMA artifact (decay=0.9999) → smoothed checkpoint for serialization, + distinct from JEPA EMA (which is the target encoder during training) + - LeakyReLU(0.5)^2 instead of ReLU^2 (community-validated) + +New env vars (additions to v1): + BIGRAM_VOCAB_SIZE=2048 BigramHash table size (0=disabled) + ARTIFACT_EMA_DECAY=0.9999 EMA decay for the saved checkpoint + QUANT_MAX=31 Quantization range: 31=int6, 127=int8 + +JEPA env vars (updated defaults): + USE_JEPA=1 Enable JEPA (default: 0) + JEPA_LAMBDA=0.12 Loss weight (default: 0.12, matches #1006) + JEPA_EMA_MOMENTUM=0.9 Starting EMA momentum (default: 0.9, was 0.996) + JEPA_PRED_DIM=256 Predictor hidden dim +""" + +from __future__ import annotations + +import copy +import glob +import io +import math +import os +import random +import subprocess +import sys +import time +import uuid +import lzma +import zlib +from pathlib import Path + +import numpy as np +import sentencepiece as spm +import torch +import torch.distributed as dist +import torch.nn.functional as F +from torch import Tensor, nn +from torch.nn.parallel import DistributedDataParallel as DDP + +# ----------------------------- +# HYPERPARAMETERS +# ----------------------------- +# Default Simple Baseline run: +# - 9 transformer blocks at width 512 +# - 8 attention heads with 4 KV heads (GQA) and 2x MLP expansion +# - vocab size 1024, sequence length 1024, tied embeddings +# - 524,288 train tokens per step for 20,000 iterations with a ~10 minute cap + +class Hyperparameters: + # Data paths are shard globs produced by the existing preprocessing pipeline. + data_path = os.environ.get("DATA_PATH", "./data/datasets/fineweb10B_sp1024") + train_files = os.path.join(data_path, "fineweb_train_*.bin") + val_files = os.path.join(data_path, "fineweb_val_*.bin") + tokenizer_path = os.environ.get("TOKENIZER_PATH", "./data/tokenizers/fineweb_1024_bpe.model") + run_id = os.environ.get("RUN_ID", str(uuid.uuid4())) + seed = int(os.environ.get("SEED", 1337)) + + # Validation cadence and batch size. Validation always uses the full fineweb_val split. + val_batch_size = int(os.environ.get("VAL_BATCH_SIZE", 131_072)) + val_loss_every = int(os.environ.get("VAL_LOSS_EVERY", 1000)) + train_log_every = int(os.environ.get("TRAIN_LOG_EVERY", 200)) + + # Training length. + iterations = int(os.environ.get("ITERATIONS", 20000)) + warmdown_iters = int(os.environ.get("WARMDOWN_ITERS", 1200)) + warmup_steps = int(os.environ.get("WARMUP_STEPS", 20)) + train_batch_tokens = int(os.environ.get("TRAIN_BATCH_TOKENS", 262_144)) + train_seq_len = int(os.environ.get("TRAIN_SEQ_LEN", 1024)) + max_wallclock_seconds = float(os.environ.get("MAX_WALLCLOCK_SECONDS", 600.0)) + qk_gain_init = float(os.environ.get("QK_GAIN_INIT", 4.0)) + + # Model shape. + vocab_size = int(os.environ.get("VOCAB_SIZE", 1024)) + num_layers = int(os.environ.get("NUM_LAYERS", 11)) + num_kv_heads = int(os.environ.get("NUM_KV_HEADS", 4)) + model_dim = int(os.environ.get("MODEL_DIM", 512)) + num_heads = int(os.environ.get("NUM_HEADS", 8)) + mlp_mult = int(os.environ.get("MLP_MULT", 3)) + tie_embeddings = bool(int(os.environ.get("TIE_EMBEDDINGS", "1"))) + rope_base = float(os.environ.get("ROPE_BASE", 10000.0)) + logit_softcap = float(os.environ.get("LOGIT_SOFTCAP", 30.0)) + + # Optimizer hyperparameters. + embed_lr = float(os.environ.get("EMBED_LR", 0.6)) + head_lr = float(os.environ.get("HEAD_LR", 0.008)) + tied_embed_lr = float(os.environ.get("TIED_EMBED_LR", 0.05)) + tied_embed_init_std = float(os.environ.get("TIED_EMBED_INIT_STD", 0.005)) + matrix_lr = float(os.environ.get("MATRIX_LR", 0.04)) + scalar_lr = float(os.environ.get("SCALAR_LR", 0.04)) + muon_momentum = float(os.environ.get("MUON_MOMENTUM", 0.95)) + muon_backend_steps = int(os.environ.get("MUON_BACKEND_STEPS", 5)) + muon_momentum_warmup_start = float(os.environ.get("MUON_MOMENTUM_WARMUP_START", 0.85)) + muon_momentum_warmup_steps = int(os.environ.get("MUON_MOMENTUM_WARMUP_STEPS", 500)) + beta1 = float(os.environ.get("BETA1", 0.9)) + beta2 = float(os.environ.get("BETA2", 0.95)) + adam_eps = float(os.environ.get("ADAM_EPS", 1e-8)) + grad_clip_norm = float(os.environ.get("GRAD_CLIP_NORM", 0.0)) + + # JEPA auxiliary objective hyperparameters. + jepa_lambda = float(os.environ.get("JEPA_LAMBDA", "0.12")) + jepa_ema_momentum = float(os.environ.get("JEPA_EMA_MOMENTUM", "0.9")) + jepa_pred_dim = int(os.environ.get("JEPA_PRED_DIM", "256")) + jepa_warmup_steps = int(os.environ.get("JEPA_WARMUP_STEPS", "100")) + # Span-masked JEPA: sample N non-overlapping target spans per sequence. + jepa_num_spans = int(os.environ.get("JEPA_NUM_SPANS", "4")) + jepa_span_len_mean = int(os.environ.get("JEPA_SPAN_LEN_MEAN", "16")) + jepa_span_len_min = int(os.environ.get("JEPA_SPAN_LEN_MIN", "4")) + + # VICReg-style variance and covariance regularization for JEPA anti-collapse. + jepa_var_weight = float(os.environ.get("JEPA_VAR_WEIGHT", "0.15")) + jepa_cov_weight = float(os.environ.get("JEPA_COV_WEIGHT", "0.02")) + jepa_var_gamma = float(os.environ.get("JEPA_VAR_GAMMA", "1.0")) + jepa_var_eps = float(os.environ.get("JEPA_VAR_EPS", "1e-4")) + + # BigramHash embedding: lookup table for (token[t-1], token[t]) pairs. + # Hashed via Cantor pairing into bigram_vocab_size buckets. + # 0 = disabled. 2048 → ~1.5M params, compresses to ~500KB. + bigram_vocab_size = int(os.environ.get("BIGRAM_VOCAB_SIZE", "2048")) + + # EMA of model weights for artifact serialization (distinct from JEPA EMA). + # The JEPA EMA is the target encoder during training; this one is for the checkpoint. + artifact_ema_decay = float(os.environ.get("ARTIFACT_EMA_DECAY", "0.9999")) + + # Quantization range. 31 = int6 (values in [-31,31] stored in int8 container). + # int6 gives better compression than int8 at minor quality cost. + quant_max = int(os.environ.get("QUANT_MAX", "31")) + + # MLP activation slope for LeakyReLU^2. 0.0 = standard ReLU^2 (v1 baseline). + # 0.5 = LeakyReLU(0.5)^2, community-validated free improvement. + mlp_leaky_slope = float(os.environ.get("MLP_LEAKY_SLOPE", "0.5")) + +# ----------------------------- +# MUON OPTIMIZER +# ----------------------------- +# +# As borrowed from modded-nanogpt +# Background on Muon: https://kellerjordan.github.io/posts/muon/ + +def zeropower_via_newtonschulz5(G: Tensor, steps: int = 10, eps: float = 1e-7) -> Tensor: + # Orthogonalize a 2D update matrix with a fast Newton-Schulz iteration. + # Muon uses this to normalize matrix-shaped gradients before applying them. + a, b, c = (3.4445, -4.7750, 2.0315) + X = G.bfloat16() + X /= X.norm() + eps + transposed = G.size(0) > G.size(1) + if transposed: + X = X.T + for _ in range(steps): + A = X @ X.T + B = b * A + c * A @ A + X = a * X + B @ X + return X.T if transposed else X + + +class Muon(torch.optim.Optimizer): + def __init__(self, params, lr: float, momentum: float, backend_steps: int, nesterov: bool = True): + super().__init__( + params, + dict(lr=lr, momentum=momentum, backend_steps=backend_steps, nesterov=nesterov), + ) + + @torch.no_grad() + def step(self, closure=None): + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + distributed = dist.is_available() and dist.is_initialized() + world_size = dist.get_world_size() if distributed else 1 + rank = dist.get_rank() if distributed else 0 + + for group in self.param_groups: + params = group["params"] + if not params: + continue + lr = group["lr"] + momentum = group["momentum"] + backend_steps = group["backend_steps"] + nesterov = group["nesterov"] + + total_params = sum(int(p.numel()) for p in params) + updates_flat = torch.zeros(total_params, device=params[0].device, dtype=torch.bfloat16) + + curr = 0 + for i, p in enumerate(params): + if i % world_size == rank and p.grad is not None: + g = p.grad + state = self.state[p] + if "momentum_buffer" not in state: + state["momentum_buffer"] = torch.zeros_like(g) + buf = state["momentum_buffer"] + buf.mul_(momentum).add_(g) + if nesterov: + g = g.add(buf, alpha=momentum) + g = zeropower_via_newtonschulz5(g, steps=backend_steps) + # Scale correction from Muon reference implementations. + g *= max(1, g.size(0) / g.size(1)) ** 0.5 + updates_flat[curr : curr + p.numel()] = g.reshape(-1) + curr += p.numel() + + if distributed: + dist.all_reduce(updates_flat, op=dist.ReduceOp.SUM) + + curr = 0 + for p in params: + g = updates_flat[curr : curr + p.numel()].view_as(p).to(dtype=p.dtype) + p.add_(g, alpha=-lr) + curr += p.numel() + + return loss + + +# ----------------------------- +# TOKENIZER-AGNOSTIC EVALUATION SETUP +# ----------------------------- +# +# It's common for small models have a large fraction of their parameters be embeddings, since the 2 * d_model * d_vocab vectors can be gigantic. +# Instead of locking the tokenizer, we let you bring your own and calculate our validation metrics on the average compression of the validation set. +# We calculate BPB (bits-per-byte) instead of validation loss, so we need methods to count the number of bits per token in the tokenizer. +# Note: Submissions that edit the tokenizer will be examined more carefully, since screwing this up might unjustly improve your score. + +def build_sentencepiece_luts( + sp: spm.SentencePieceProcessor, vocab_size: int, device: torch.device +) -> tuple[Tensor, Tensor, Tensor]: + sp_vocab_size = int(sp.vocab_size()) + table_size = max(sp_vocab_size, vocab_size) + base_bytes_np = np.zeros((table_size,), dtype=np.int16) + has_leading_space_np = np.zeros((table_size,), dtype=np.bool_) + is_boundary_token_np = np.ones((table_size,), dtype=np.bool_) + for token_id in range(sp_vocab_size): + if sp.is_control(token_id) or sp.is_unknown(token_id) or sp.is_unused(token_id): + continue + is_boundary_token_np[token_id] = False + if sp.is_byte(token_id): + base_bytes_np[token_id] = 1 + continue + piece = sp.id_to_piece(token_id) + if piece.startswith("▁"): + has_leading_space_np[token_id] = True + piece = piece[1:] + base_bytes_np[token_id] = len(piece.encode("utf-8")) + return ( + torch.tensor(base_bytes_np, dtype=torch.int16, device=device), + torch.tensor(has_leading_space_np, dtype=torch.bool, device=device), + torch.tensor(is_boundary_token_np, dtype=torch.bool, device=device), + ) + + +def load_validation_tokens(pattern: str, seq_len: int) -> Tensor: + files = [Path(p) for p in sorted(glob.glob(pattern))] + if not files: + raise FileNotFoundError(f"No files found for pattern: {pattern}") + # The export pipeline writes the fixed first-50k-doc validation set to fineweb_val_*. + tokens = torch.cat([load_data_shard(file) for file in files]).contiguous() + usable = ((tokens.numel() - 1) // seq_len) * seq_len + if usable <= 0: + raise ValueError(f"Validation split is too short for TRAIN_SEQ_LEN={seq_len}") + return tokens[: usable + 1] + + +def eval_val( + args: Hyperparameters, + model: nn.Module, + rank: int, + world_size: int, + device: torch.device, + grad_accum_steps: int, + val_tokens: Tensor, + base_bytes_lut: Tensor, + has_leading_space_lut: Tensor, + is_boundary_token_lut: Tensor, +) -> tuple[float, float]: + # Validation computes two metrics: + # - val_loss: token cross-entropy (natural log) + # - val_bpb: tokenizer-agnostic compression metric used by the challenge + local_batch_tokens = args.val_batch_size // (world_size * grad_accum_steps) + if local_batch_tokens < args.train_seq_len: + raise ValueError( + "VAL_BATCH_SIZE must provide at least one sequence per rank; " + f"got VAL_BATCH_SIZE={args.val_batch_size}, WORLD_SIZE={world_size}, " + f"GRAD_ACCUM_STEPS={grad_accum_steps}, TRAIN_SEQ_LEN={args.train_seq_len}" + ) + local_batch_seqs = local_batch_tokens // args.train_seq_len + total_seqs = (val_tokens.numel() - 1) // args.train_seq_len + seq_start = (total_seqs * rank) // world_size + seq_end = (total_seqs * (rank + 1)) // world_size + val_loss_sum = torch.zeros((), device=device, dtype=torch.float64) + val_token_count = torch.zeros((), device=device, dtype=torch.float64) + val_byte_count = torch.zeros((), device=device, dtype=torch.float64) + + model.eval() + with torch.no_grad(): + for batch_seq_start in range(seq_start, seq_end, local_batch_seqs): + batch_seq_end = min(batch_seq_start + local_batch_seqs, seq_end) + raw_start = batch_seq_start * args.train_seq_len + raw_end = batch_seq_end * args.train_seq_len + 1 + local = val_tokens[raw_start:raw_end].to(device=device, dtype=torch.int64, non_blocking=True) + x = local[:-1].reshape(-1, args.train_seq_len) + y = local[1:].reshape(-1, args.train_seq_len) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + batch_ce = model(x, y) + batch_loss = batch_ce.detach() + batch_token_count = float(y.numel()) + val_loss_sum += batch_loss.to(torch.float64) * batch_token_count + val_token_count += batch_token_count + prev_ids = x.reshape(-1) + tgt_ids = y.reshape(-1) + token_bytes = base_bytes_lut[tgt_ids].to(dtype=torch.int16) + token_bytes += (has_leading_space_lut[tgt_ids] & ~is_boundary_token_lut[prev_ids]).to(dtype=torch.int16) + val_byte_count += token_bytes.to(torch.float64).sum() + + if rank == 0 and (batch_seq_start // local_batch_seqs) % 25 == 0: + print(f"val_progress:{batch_seq_end}/{seq_end}", flush=True) + + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(val_loss_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(val_token_count, op=dist.ReduceOp.SUM) + dist.all_reduce(val_byte_count, op=dist.ReduceOp.SUM) + + val_loss = val_loss_sum / val_token_count + bits_per_token = val_loss.item() / math.log(2.0) + tokens_per_byte = val_token_count.item() / val_byte_count.item() + model.train() + return float(val_loss.item()), float(bits_per_token * tokens_per_byte) + +# ----------------------------- +# POST-TRAINING QUANTIZATION +# ----------------------------- +# +# It's silly to export our model, which is trained in bf16 and fp32, at that same precision. +# Instead, we get approximately the same model (with a small hit) by quantizing the model to int8 & zlib compressing. +# We can then decompress the model and run in higher precision for evaluation, after closing in under the size limit. + +CONTROL_TENSOR_NAME_PATTERNS = tuple( + pattern + for pattern in os.environ.get( + "CONTROL_TENSOR_NAME_PATTERNS", + "attn_scale,attn_scales,mlp_scale,mlp_scales,resid_mix,resid_mixes,q_gain,skip_weight,skip_weights", + ).split(",") + if pattern +) +INT8_KEEP_FLOAT_FP32_NAME_PATTERNS = tuple( + pattern + for pattern in os.environ.get( + "INT8_KEEP_FLOAT_FP32_NAME_PATTERNS", + ",".join(CONTROL_TENSOR_NAME_PATTERNS), + ).split(",") + if pattern +) +INT8_KEEP_FLOAT_MAX_NUMEL = 65_536 +INT8_KEEP_FLOAT_STORE_DTYPE = torch.float16 +INT8_PER_ROW_SCALE_DTYPE = torch.float16 +INT8_CLIP_PERCENTILE = 99.99984 +INT8_CLIP_Q = INT8_CLIP_PERCENTILE / 100.0 +# v2: configurable quantization range. Default 31 = int6 (range [-31,31] in int8 +# container). Better compression than int8 (range [-127,127]) at minor quality cost. +_ARGS_QUANT_MAX: int | None = None # set in main() from args.quant_max +_MLP_LEAKY_SLOPE: float = 0.0 # set in main() from args.mlp_leaky_slope; 0 = relu + +def tensor_nbytes(t: Tensor) -> int: + return int(t.numel()) * int(t.element_size()) + +def keep_float_tensor(name: str, t: Tensor, passthrough_orig_dtypes: dict[str, str]) -> Tensor: + if any(pattern in name for pattern in INT8_KEEP_FLOAT_FP32_NAME_PATTERNS): + return t.float().contiguous() + if t.dtype in {torch.float32, torch.bfloat16}: + passthrough_orig_dtypes[name] = str(t.dtype).removeprefix("torch.") + return t.to(dtype=INT8_KEEP_FLOAT_STORE_DTYPE).contiguous() + return t + +def quantize_float_tensor(t: Tensor) -> tuple[Tensor, Tensor]: + qmax = float(_ARGS_QUANT_MAX if _ARGS_QUANT_MAX is not None else 31) + t32 = t.float() + if t32.ndim == 2: + # Matrices get one scale per row, which usually tracks output-channel + # ranges much better than a single tensor-wide scale. + clip_abs = ( + torch.quantile(t32.abs(), INT8_CLIP_Q, dim=1) + if t32.numel() + else torch.empty((t32.shape[0],), dtype=torch.float32) + ) + clipped = torch.maximum(torch.minimum(t32, clip_abs[:, None]), -clip_abs[:, None]) + scale = (clip_abs / qmax).clamp_min(1.0 / qmax) + q = torch.clamp(torch.round(clipped / scale[:, None]), -qmax, qmax).to(torch.int8).contiguous() + return q, scale.to(dtype=INT8_PER_ROW_SCALE_DTYPE).contiguous() + + # Vectors / scalars use a simpler per-tensor scale. + clip_abs = float(torch.quantile(t32.abs().flatten(), INT8_CLIP_Q).item()) if t32.numel() else 0.0 + scale = torch.tensor(clip_abs / qmax if clip_abs > 0 else 1.0, dtype=torch.float32) + q = torch.clamp(torch.round(torch.clamp(t32, -clip_abs, clip_abs) / scale), -qmax, qmax).to(torch.int8).contiguous() + return q, scale + +def quantize_state_dict_int8(state_dict: dict[str, Tensor]): + # Single supported clean-script export format: + # - per-row int8 for 2D float tensors + # - per-tensor int8 for other float tensors + # - exact passthrough for non-floats + # - passthrough for small float tensors, stored as fp16 to save bytes + quantized: dict[str, Tensor] = {} + scales: dict[str, Tensor] = {} + dtypes: dict[str, str] = {} + passthrough: dict[str, Tensor] = {} + passthrough_orig_dtypes: dict[str, str] = {} + qmeta: dict[str, dict[str, object]] = {} + stats = dict.fromkeys( + ("param_count", "num_tensors", "num_float_tensors", "num_nonfloat_tensors", "baseline_tensor_bytes", "int8_payload_bytes"), + 0, + ) + + for name, tensor in state_dict.items(): + t = tensor.detach().to("cpu").contiguous() + stats["param_count"] += int(t.numel()) + stats["num_tensors"] += 1 + stats["baseline_tensor_bytes"] += tensor_nbytes(t) + + if not t.is_floating_point(): + stats["num_nonfloat_tensors"] += 1 + passthrough[name] = t + stats["int8_payload_bytes"] += tensor_nbytes(t) + continue + + # Small float tensors are cheap enough to keep directly. We still downcast + # fp32/bf16 passthrough tensors to fp16 so metadata does not dominate size. + if t.numel() <= INT8_KEEP_FLOAT_MAX_NUMEL: + kept = keep_float_tensor(name, t, passthrough_orig_dtypes) + passthrough[name] = kept + stats["int8_payload_bytes"] += tensor_nbytes(kept) + continue + + stats["num_float_tensors"] += 1 + q, s = quantize_float_tensor(t) + if s.ndim > 0: + qmeta[name] = {"scheme": "per_row", "axis": 0} + quantized[name] = q + scales[name] = s + dtypes[name] = str(t.dtype).removeprefix("torch.") + stats["int8_payload_bytes"] += tensor_nbytes(q) + tensor_nbytes(s) + + obj: dict[str, object] = { + "__quant_format__": "int8_clean_per_row_v1", + "quantized": quantized, + "scales": scales, + "dtypes": dtypes, + "passthrough": passthrough, + } + if qmeta: + obj["qmeta"] = qmeta + if passthrough_orig_dtypes: + obj["passthrough_orig_dtypes"] = passthrough_orig_dtypes + return obj, stats + +def dequantize_state_dict_int8(obj: dict[str, object]) -> dict[str, Tensor]: + out: dict[str, Tensor] = {} + qmeta = obj.get("qmeta", {}) + passthrough_orig_dtypes = obj.get("passthrough_orig_dtypes", {}) + for name, q in obj["quantized"].items(): + dtype = getattr(torch, obj["dtypes"][name]) + s = obj["scales"][name] + if qmeta.get(name, {}).get("scheme") == "per_row" or s.ndim > 0: + s = s.to(dtype=torch.float32) + # Broadcast the saved row scale back across trailing dimensions. + out[name] = (q.float() * s.view(q.shape[0], *([1] * (q.ndim - 1)))).to(dtype=dtype).contiguous() + else: + scale = float(s.item()) + out[name] = (q.float() * scale).to(dtype=dtype).contiguous() + for name, t in obj["passthrough"].items(): + # Restore small tensors, undoing the temporary fp16 storage cast if needed. + out_t = t.detach().to("cpu").contiguous() + orig_dtype = passthrough_orig_dtypes.get(name) + if isinstance(orig_dtype, str): + out_t = out_t.to(dtype=getattr(torch, orig_dtype)).contiguous() + out[name] = out_t + return out + + +# ----------------------------- +# DATA LOADING +# ----------------------------- + +def load_data_shard(file: Path) -> Tensor: + header_bytes = 256 * np.dtype(" None: + self.file_idx = (self.file_idx + 1) % len(self.files) + self.tokens = load_data_shard(self.files[self.file_idx]) + self.pos = 0 + + def take(self, n: int) -> Tensor: + chunks: list[Tensor] = [] + remaining = n + while remaining > 0: + avail = self.tokens.numel() - self.pos + if avail <= 0: + self._advance_file() + continue + k = min(remaining, avail) + chunks.append(self.tokens[self.pos : self.pos + k]) + self.pos += k + remaining -= k + return chunks[0] if len(chunks) == 1 else torch.cat(chunks) + + +class DistributedTokenLoader: + # Each call consumes a contiguous chunk from the shared token stream, then slices out + # one disjoint span per rank. The extra "+1" token lets us build (x, y) by shifting. + def __init__(self, pattern: str, rank: int, world_size: int, device: torch.device): + self.rank = rank + self.world_size = world_size + self.device = device + self.stream = TokenStream(pattern) + + def next_batch(self, global_tokens: int, seq_len: int, grad_accum_steps: int) -> tuple[Tensor, Tensor]: + local_tokens = global_tokens // (self.world_size * grad_accum_steps) + per_rank_span = local_tokens + 1 + chunk = self.stream.take(per_rank_span * self.world_size) + start = self.rank * per_rank_span + local = chunk[start : start + per_rank_span].to(dtype=torch.int64) + x = local[:-1].reshape(-1, seq_len) + y = local[1:].reshape(-1, seq_len) + return x.to(self.device, non_blocking=True), y.to(self.device, non_blocking=True) + +# ----------------------------- +# TRANSFORMER MODULES +# ----------------------------- + +class RMSNorm(nn.Module): + def __init__(self, eps: float | None = None): + super().__init__() + self.eps = eps + + def forward(self, x: Tensor) -> Tensor: + return F.rms_norm(x, (x.size(-1),), eps=self.eps) + + +class CastedLinear(nn.Linear): + # Keep weights in fp32 for optimizer/state quality, cast at matmul time for bf16 compute. + def forward(self, x: Tensor) -> Tensor: + bias = self.bias.to(x.dtype) if self.bias is not None else None + return F.linear(x, self.weight.to(x.dtype), bias) + + +def restore_low_dim_params_to_fp32(module: nn.Module) -> None: + # Keep small/control parameters in fp32 even when the model body runs in bf16. + with torch.no_grad(): + for name, param in module.named_parameters(): + if (param.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS)) and param.dtype != torch.float32: + param.data = param.data.float() + + +class Rotary(nn.Module): + # Caches cos/sin tables per sequence length on the current device. + def __init__(self, dim: int, base: float = 10000.0): + super().__init__() + inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim)) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self._seq_len_cached = 0 + self._cos_cached: Tensor | None = None + self._sin_cached: Tensor | None = None + + def forward(self, seq_len: int, device: torch.device, dtype: torch.dtype) -> tuple[Tensor, Tensor]: + if ( + self._cos_cached is None + or self._sin_cached is None + or self._seq_len_cached != seq_len + or self._cos_cached.device != device + ): + t = torch.arange(seq_len, device=device, dtype=self.inv_freq.dtype) + freqs = torch.outer(t, self.inv_freq.to(device)) + self._cos_cached = freqs.cos()[None, None, :, :] + self._sin_cached = freqs.sin()[None, None, :, :] + self._seq_len_cached = seq_len + return self._cos_cached.to(dtype=dtype), self._sin_cached.to(dtype=dtype) + + +def apply_rotary_emb(x: Tensor, cos: Tensor, sin: Tensor) -> Tensor: + half = x.size(-1) // 2 + x1, x2 = x[..., :half], x[..., half:] + return torch.cat((x1 * cos + x2 * sin, x1 * (-sin) + x2 * cos), dim=-1) + + +class CausalSelfAttention(nn.Module): + def __init__( + self, + dim: int, + num_heads: int, + num_kv_heads: int, + rope_base: float, + qk_gain_init: float, + ): + super().__init__() + if dim % num_heads != 0: + raise ValueError("model_dim must be divisible by num_heads") + if num_heads % num_kv_heads != 0: + raise ValueError("num_heads must be divisible by num_kv_heads") + self.num_heads = num_heads + self.num_kv_heads = num_kv_heads + self.head_dim = dim // num_heads + if self.head_dim % 2 != 0: + raise ValueError("head_dim must be even for RoPE") + kv_dim = self.num_kv_heads * self.head_dim + self.c_q = CastedLinear(dim, dim, bias=False) + self.c_k = CastedLinear(dim, kv_dim, bias=False) + self.c_v = CastedLinear(dim, kv_dim, bias=False) + self.proj = CastedLinear(dim, dim, bias=False) + self.proj._zero_init = True + self.q_gain = nn.Parameter(torch.full((num_heads,), qk_gain_init, dtype=torch.float32)) + self.rotary = Rotary(self.head_dim, base=rope_base) + + def forward(self, x: Tensor) -> Tensor: + bsz, seqlen, dim = x.shape + q = self.c_q(x).reshape(bsz, seqlen, self.num_heads, self.head_dim).transpose(1, 2) + k = self.c_k(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) + v = self.c_v(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) + q = F.rms_norm(q, (q.size(-1),)) + k = F.rms_norm(k, (k.size(-1),)) + cos, sin = self.rotary(seqlen, x.device, q.dtype) + q = apply_rotary_emb(q, cos, sin) + k = apply_rotary_emb(k, cos, sin) + q = q * self.q_gain.to(dtype=q.dtype)[None, :, None, None] + y = F.scaled_dot_product_attention( + q, + k, + v, + attn_mask=None, + is_causal=True, + enable_gqa=(self.num_kv_heads != self.num_heads), + ) + y = y.transpose(1, 2).contiguous().reshape(bsz, seqlen, dim) + return self.proj(y) + + +class MLP(nn.Module): + # relu^2 MLP from the original modded-nanogpt setup + def __init__(self, dim: int, mlp_mult: int): + super().__init__() + hidden = mlp_mult * dim + self.fc = CastedLinear(dim, hidden, bias=False) + self.proj = CastedLinear(hidden, dim, bias=False) + self.proj._zero_init = True + + def forward(self, x: Tensor) -> Tensor: + h = self.fc(x) + h = F.leaky_relu(h, negative_slope=_MLP_LEAKY_SLOPE) if _MLP_LEAKY_SLOPE > 0.0 else torch.relu(h) + return self.proj(h.square()) + + +class JEPAPredictor(nn.Module): + """ + Small residual MLP that maps z_context -> z_pred (predicted embedding of next token). + Uses ReLU^2 activation consistent with the rest of the model. + Saved in the artifact and used during both training and eval. + The EMA target encoder is NOT saved — it is only needed to produce training targets. + """ + def __init__(self, dim: int, hidden_dim: int): + super().__init__() + self.norm = RMSNorm() + self.fc = CastedLinear(dim, hidden_dim, bias=False) + self.proj = CastedLinear(hidden_dim, dim, bias=False) + self.proj._zero_init = True # starts as identity (residual connection) + + def forward(self, x: Tensor) -> Tensor: + h = torch.relu(self.fc(self.norm(x))) + return x + self.proj(h.square()) # residual: z_pred starts equal to z_context + + +def sample_block_spans( + seq_len: int, + num_spans: int, + span_len_mean: int, + span_len_min: int = 4, + device: torch.device | None = None, +) -> Tensor: + """ + Sample num_spans non-overlapping contiguous spans within [0, seq_len). + + Span lengths are drawn from Geometric(1/span_len_mean): a distribution where + each integer k >= 1 has probability p*(1-p)^(k-1) with p=1/span_len_mean. + Expected length = span_len_mean; heavy-tailed so most spans are short but + occasionally longer ones appear. Lengths are clamped to [span_len_min, + seq_len // (2 * num_spans)] to avoid consuming the whole sequence. + + Returns a LongTensor (int64) of shape [num_spans, 2] on `device`, where each + row is (start, end_exclusive). Spans are sorted by start and non-overlapping. + """ + max_span_len = max(span_len_min, seq_len // (2 * num_spans)) + p = 1.0 / span_len_mean + spans = [] + for _ in range(num_spans): + # Geometric(p) sample via inverse CDF: floor(log(U)/log(1-p)) + 1 + length = int(math.floor(math.log(random.random()) / math.log(1.0 - p))) + 1 + length = min(max_span_len, max(span_len_min, length)) + start = random.randint(0, seq_len - length) + spans.append((start, start + length)) + # Sort by start and greedily resolve overlaps + spans.sort() + resolved = [] + cursor = 0 + for s, e in spans: + s = max(s, cursor) + if s >= seq_len: + break + e = min(e, seq_len) + if e - s >= span_len_min: + resolved.append((s, e)) + cursor = e + # Duplicate last span if any were dropped (rare edge case with tight seq_len) + while len(resolved) < num_spans and resolved: + resolved.append(resolved[-1]) + return torch.tensor(resolved[:num_spans], dtype=torch.long, device=device) + + +class BigramHashEmbedding(nn.Module): + """ + Lookup table for bigram context (token[t-1], token[t]) hashed via Cantor pairing. + + Why: a causal LM must spend attention capacity learning bigram co-occurrence + statistics (e.g. "New" → "York"). An explicit bigram embedding makes these + patterns free at O(1) lookup cost, leaving attention heads for longer-range + dependencies. + + Implementation: Cantor pairing h(a,b) = (a+b)(a+b+1)/2 + b, then mod + bigram_vocab_size to map into a fixed-size table. The output is summed with + the standard token embedding before the first layer. + + Budget: bigram_vocab_size=2048, dim=512 → 1,048,576 params (~4MB fp32). + After int6+LZMA the bigram table compresses to ~300-500KB because many + rarely-seen pairs share the same bucket (hash collisions act as implicit + parameter sharing) and learned weights for rare bigrams stay near-zero. + """ + + def __init__(self, bigram_vocab_size: int, model_dim: int): + super().__init__() + self.bigram_vocab_size = bigram_vocab_size + self.embedding = nn.Embedding(bigram_vocab_size, model_dim) + nn.init.normal_(self.embedding.weight, mean=0.0, std=0.01) + + def forward(self, input_ids: Tensor) -> Tensor: + # input_ids: [B, T] (int64) + # For position 0, the "previous token" is padded with token 0. + prev_ids = torch.cat([input_ids.new_zeros(input_ids.shape[0], 1), input_ids[:, :-1]], dim=1) + a = prev_ids.long() + b = input_ids.long() + s = a + b + h = (s * (s + 1) // 2 + b) % self.bigram_vocab_size + return self.embedding(h) + + +class Block(nn.Module): + def __init__( + self, + dim: int, + num_heads: int, + num_kv_heads: int, + mlp_mult: int, + rope_base: float, + qk_gain_init: float, + ): + super().__init__() + self.attn_norm = RMSNorm() + self.mlp_norm = RMSNorm() + self.attn = CausalSelfAttention(dim, num_heads, num_kv_heads, rope_base, qk_gain_init) + self.mlp = MLP(dim, mlp_mult) + self.attn_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.mlp_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.resid_mix = nn.Parameter(torch.stack((torch.ones(dim), torch.zeros(dim))).float()) + + def forward(self, x: Tensor, x0: Tensor) -> Tensor: + mix = self.resid_mix.to(dtype=x.dtype) + x = mix[0][None, None, :] * x + mix[1][None, None, :] * x0 + attn_out = self.attn(self.attn_norm(x)) + x = x + self.attn_scale.to(dtype=x.dtype)[None, None, :] * attn_out + x = x + self.mlp_scale.to(dtype=x.dtype)[None, None, :] * self.mlp(self.mlp_norm(x)) + return x + + +class GPT(nn.Module): + def __init__( + self, + vocab_size: int, + num_layers: int, + model_dim: int, + num_heads: int, + num_kv_heads: int, + mlp_mult: int, + tie_embeddings: bool, + tied_embed_init_std: float, + logit_softcap: float, + rope_base: float, + qk_gain_init: float, + jepa_pred_dim: int = 0, # 0 = no predictor (baseline mode) + bigram_vocab_size: int = 0, # 0 = no bigram embedding + ): + super().__init__() + if logit_softcap <= 0.0: + raise ValueError(f"logit_softcap must be positive, got {logit_softcap}") + self.tie_embeddings = tie_embeddings + self.tied_embed_init_std = tied_embed_init_std + self.logit_softcap = logit_softcap + self.tok_emb = nn.Embedding(vocab_size, model_dim) + self.num_encoder_layers = num_layers // 2 + self.num_decoder_layers = num_layers - self.num_encoder_layers + self.num_skip_weights = min(self.num_encoder_layers, self.num_decoder_layers) + self.skip_weights = nn.Parameter(torch.ones(self.num_skip_weights, model_dim, dtype=torch.float32)) + self.blocks = nn.ModuleList( + [ + Block( + model_dim, + num_heads, + num_kv_heads, + mlp_mult, + rope_base, + qk_gain_init, + ) + for i in range(num_layers) + ] + ) + self.final_norm = RMSNorm() + self.lm_head = None if tie_embeddings else CastedLinear(model_dim, vocab_size, bias=False) + if self.lm_head is not None: + self.lm_head._zero_init = True + # JEPA predictor: optional small MLP on top of the encoder output. + # Maps z_context -> z_pred. Used only for JEPA MSE loss, not for CE. + self.jepa_predictor: JEPAPredictor | None = ( + JEPAPredictor(model_dim, jepa_pred_dim) if jepa_pred_dim > 0 else None + ) + # BigramHash embedding: lookup for (prev_token, cur_token) pairs. + # Summed with tok_emb before the first transformer layer. + self.bigram_hash_emb: BigramHashEmbedding | None = ( + BigramHashEmbedding(bigram_vocab_size, model_dim) if bigram_vocab_size > 0 else None + ) + # Learned mask embedding: replaces token embeddings at target span positions + # in the context encoder forward pass. Zero-init so it starts neutral; + # trained by the JEPA loss to encode "this position is unknown, predict it". + self.jepa_mask_emb = nn.Parameter(torch.zeros(model_dim)) + self._init_weights() + + def _init_weights(self) -> None: + if self.tie_embeddings: + nn.init.normal_(self.tok_emb.weight, mean=0.0, std=self.tied_embed_init_std) + for module in self.modules(): + if isinstance(module, nn.Linear) and getattr(module, "_zero_init", False): + nn.init.zeros_(module.weight) + + def encode(self, input_ids: Tensor, jepa_mask: Tensor | None = None) -> Tensor: + """Run the U-Net encoder and return final hidden states [B, T, D]. + + jepa_mask: optional bool [B, T] where True marks target span positions. + At masked positions, the token embedding is replaced with jepa_mask_emb + so the context encoder cannot see the actual token identity. + The bigram contribution is also zeroed at masked positions to prevent + the Cantor hash of (prev_tok, masked_tok) from leaking the token identity. + Pass jepa_mask=None (default) for the CE forward pass and the target encoder. + """ + x = self.tok_emb(input_ids) + if self.bigram_hash_emb is not None: + bigram = self.bigram_hash_emb(input_ids) + if jepa_mask is not None: + bigram = bigram.masked_fill(jepa_mask.unsqueeze(-1), 0.0) + x = x + bigram + if jepa_mask is not None: + mask_vec = self.jepa_mask_emb.to(dtype=x.dtype) + x = torch.where(jepa_mask.unsqueeze(-1), mask_vec, x) + x = F.rms_norm(x, (x.size(-1),)) + x0 = x + skips: list[Tensor] = [] + for i in range(self.num_encoder_layers): + x = self.blocks[i](x, x0) + skips.append(x) + for i in range(self.num_decoder_layers): + if skips: + x = x + self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips.pop() + x = self.blocks[self.num_encoder_layers + i](x, x0) + return self.final_norm(x) # [B, T, D] + + def forward( + self, + input_ids: Tensor, + target_ids: Tensor, + ) -> Tensor: + """ + Returns ce_loss only. The JEPA path (masked encode + predictor) is handled + explicitly in the training loop via encode(jepa_mask=...) + jepa_predictor(), + so torch.compile never sees the target encoder or changing Python scalars. + """ + # --- U-Net encoder (unmasked, for CE) --- + z_context = self.encode(input_ids) # [B, T, D] + + # --- LM head for cross-entropy --- + z_flat = z_context.reshape(-1, z_context.size(-1)) + targets = target_ids.reshape(-1) + if self.tie_embeddings: + logits_proj = F.linear(z_flat, self.tok_emb.weight) + else: + if self.lm_head is None: + raise RuntimeError("lm_head is required when tie_embeddings=False") + logits_proj = self.lm_head(z_flat) + logits = self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap) + return F.cross_entropy(logits.float(), targets, reduction="mean") + + +# ----------------------------- +# VICREG HELPERS +# ----------------------------- + +def vicreg_var_loss(z: Tensor, gamma: float, eps: float) -> Tensor: + """Per-feature variance hinge: penalizes when std drops below gamma.""" + assert z.ndim == 2, f"vicreg_var_loss expects 2D input, got shape {z.shape}" + n = z.shape[0] + assert n >= 2, f"vicreg_var_loss requires at least 2 samples, got {n}" + z_c = z - z.mean(dim=0) + var = z_c.pow(2).sum(dim=0) / (n - 1) + std = (var + eps).sqrt() + return (gamma - std).clamp(min=0.0).mean() + + +def vicreg_cov_loss(z: Tensor) -> Tensor: + """Off-diagonal covariance penalty: penalizes inter-feature correlation.""" + assert z.ndim == 2, f"vicreg_cov_loss expects 2D input, got shape {z.shape}" + n, d = z.shape + assert n >= 2, f"vicreg_cov_loss requires at least 2 samples, got {n}" + z_c = z - z.mean(dim=0) + cov = z_c.T @ z_c / (n - 1) + off = cov.pow(2) + off.fill_diagonal_(0.0) + return off.sum() / d + + +# ----------------------------- +# TRAINING +# ----------------------------- + +def main() -> None: + global zeropower_via_newtonschulz5 + + code = Path(__file__).read_text(encoding="utf-8") + args = Hyperparameters() + zeropower_via_newtonschulz5 = torch.compile(zeropower_via_newtonschulz5) + + # ----------------------------- + # DISTRIBUTED + CUDA SETUP + # ----------------------------- + + distributed = "RANK" in os.environ and "WORLD_SIZE" in os.environ + rank = int(os.environ.get("RANK", "0")) + world_size = int(os.environ.get("WORLD_SIZE", "1")) + local_rank = int(os.environ.get("LOCAL_RANK", "0")) + if world_size <= 0: + raise ValueError(f"WORLD_SIZE must be positive, got {world_size}") + if 8 % world_size != 0: + raise ValueError(f"WORLD_SIZE={world_size} must divide 8 so grad_accum_steps stays integral") + grad_accum_steps = 8 // world_size + grad_scale = 1.0 / grad_accum_steps + if not torch.cuda.is_available(): + raise RuntimeError("CUDA is required") + device = torch.device("cuda", local_rank) + torch.cuda.set_device(device) + if distributed: + dist.init_process_group(backend="nccl", device_id=device) + dist.barrier() + master_process = rank == 0 + + # Fast math knobs + torch.backends.cuda.matmul.allow_tf32 = True + torch.backends.cudnn.allow_tf32 = True + from torch.backends.cuda import enable_cudnn_sdp, enable_flash_sdp, enable_math_sdp, enable_mem_efficient_sdp + + enable_cudnn_sdp(False) + enable_flash_sdp(True) + enable_mem_efficient_sdp(False) + enable_math_sdp(False) + + logfile = None + if master_process: + os.makedirs("logs", exist_ok=True) + logfile = f"logs/{args.run_id}.txt" + print(logfile) + + def log0(msg: str, console: bool = True) -> None: + if not master_process: + return + if console: + print(msg) + if logfile is not None: + with open(logfile, "a", encoding="utf-8") as f: + print(msg, file=f) + + log0(code, console=False) + log0("=" * 100, console=False) + log0(f"Running Python {sys.version}", console=False) + log0(f"Running PyTorch {torch.__version__}", console=False) + log0( + subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, check=False).stdout, + console=False, + ) + log0("=" * 100, console=False) + + # ----------------------------- + # TOKENIZER + VALIDATION METRIC SETUP + # ----------------------------- + + random.seed(args.seed) + np.random.seed(args.seed) + torch.manual_seed(args.seed) + torch.cuda.manual_seed_all(args.seed) + + if not args.tokenizer_path.endswith(".model"): + raise ValueError(f"Script only setup for SentencePiece .model file: {args.tokenizer_path}") + sp = spm.SentencePieceProcessor(model_file=args.tokenizer_path) + if int(sp.vocab_size()) != args.vocab_size: + raise ValueError( + f"VOCAB_SIZE={args.vocab_size} does not match tokenizer vocab_size={int(sp.vocab_size())}" + ) + dataset_dir = Path(args.data_path).resolve() + actual_train_files = len(list(dataset_dir.glob("fineweb_train_*.bin"))) + val_tokens = load_validation_tokens(args.val_files, args.train_seq_len) + base_bytes_lut, has_leading_space_lut, is_boundary_token_lut = build_sentencepiece_luts( + sp, args.vocab_size, device + ) + log0(f"val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path={args.tokenizer_path}") + log0(f"train_loader:dataset:{dataset_dir.name} train_shards:{actual_train_files}") + log0(f"val_loader:shards pattern={args.val_files} tokens:{val_tokens.numel() - 1}") + + # ----------------------------- + # MODEL + OPTIMIZER SETUP + # ----------------------------- + + global _ARGS_QUANT_MAX, _MLP_LEAKY_SLOPE + _ARGS_QUANT_MAX = args.quant_max + _MLP_LEAKY_SLOPE = args.mlp_leaky_slope + + base_model = GPT( + vocab_size=args.vocab_size, + num_layers=args.num_layers, + model_dim=args.model_dim, + num_heads=args.num_heads, + num_kv_heads=args.num_kv_heads, + mlp_mult=args.mlp_mult, + tie_embeddings=args.tie_embeddings, + tied_embed_init_std=args.tied_embed_init_std, + logit_softcap=args.logit_softcap, + rope_base=args.rope_base, + qk_gain_init=args.qk_gain_init, + jepa_pred_dim=args.jepa_pred_dim, + bigram_vocab_size=args.bigram_vocab_size, + ).to(device).bfloat16() + + # Artifact EMA: smoothed checkpoint for serialization. + # Distinct from JEPA EMA (target encoder). This one is always active and + # its state_dict is what gets quantized and saved at the end of training. + artifact_ema: GPT = copy.deepcopy(base_model).cpu().bfloat16() + for p in artifact_ema.parameters(): + p.requires_grad_(False) + + # JEPA: create EMA target encoder (deep copy, frozen, not saved in artifact). + jepa_target_encoder = copy.deepcopy(base_model).to(device).bfloat16() + for p in jepa_target_encoder.parameters(): + p.requires_grad_(False) + log0(f"jepa:enabled lambda={args.jepa_lambda} ema_momentum={args.jepa_ema_momentum} pred_dim={args.jepa_pred_dim} num_spans={args.jepa_num_spans} span_len_mean={args.jepa_span_len_mean}") + for module in base_model.modules(): + if isinstance(module, CastedLinear): + module.float() + restore_low_dim_params_to_fp32(base_model) + compiled_model = torch.compile(base_model, dynamic=False, fullgraph=True) + # Compile target encoder separately so it benefits from the same kernel fusion + # as the main model. Must happen after base_model modules are cast to fp32. + compiled_target_encoder = torch.compile(jepa_target_encoder, dynamic=False, fullgraph=True) + model: nn.Module = DDP(compiled_model, device_ids=[local_rank], broadcast_buffers=False) if distributed else compiled_model + + # Optimizer split: + # - token embedding (Adam) uses EMBED_LR + # - untied lm_head (Adam) uses HEAD_LR + # - matrix params in transformer blocks use MATRIX_LR via Muon + # - vectors/scalars use SCALAR_LR via Adam + block_named_params = list(base_model.blocks.named_parameters()) + matrix_params = [ + p + for name, p in block_named_params + if p.ndim == 2 and not any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) + ] + scalar_params = [ + p + for name, p in block_named_params + if p.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) + ] + if base_model.skip_weights.numel() > 0: + scalar_params.append(base_model.skip_weights) + # jepa_mask_emb and jepa_predictor are outside base_model.blocks so they + # are not covered by the block_named_params loop above. Add them explicitly. + scalar_params.append(base_model.jepa_mask_emb) + if base_model.jepa_predictor is not None: + for name, p in base_model.jepa_predictor.named_parameters(): + if p.ndim == 2: + matrix_params.append(p) # fc.weight, proj.weight → Muon + else: + scalar_params.append(p) # RMSNorm scale → Adam + token_lr = args.tied_embed_lr if args.tie_embeddings else args.embed_lr + optimizer_tok = torch.optim.Adam( + [{"params": [base_model.tok_emb.weight], "lr": token_lr, "base_lr": token_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=True, + ) + optimizer_muon = Muon( + matrix_params, + lr=args.matrix_lr, + momentum=args.muon_momentum, + backend_steps=args.muon_backend_steps, + ) + for group in optimizer_muon.param_groups: + group["base_lr"] = args.matrix_lr + optimizer_scalar = torch.optim.Adam( + [{"params": scalar_params, "lr": args.scalar_lr, "base_lr": args.scalar_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=True, + ) + optimizers: list[torch.optim.Optimizer] = [optimizer_tok, optimizer_muon, optimizer_scalar] + if base_model.lm_head is not None: + optimizer_head = torch.optim.Adam( + [{"params": [base_model.lm_head.weight], "lr": args.head_lr, "base_lr": args.head_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=True, + ) + optimizers.insert(1, optimizer_head) + + n_params = sum(p.numel() for p in base_model.parameters()) + bigram_params = sum(p.numel() for p in base_model.bigram_hash_emb.parameters()) if base_model.bigram_hash_emb is not None else 0 + log0(f"model_params:{n_params} bigram_params:{bigram_params} quant_max:{args.quant_max} mlp_leaky_slope:{args.mlp_leaky_slope} artifact_ema_decay:{args.artifact_ema_decay}") + log0(f"world_size:{world_size} grad_accum_steps:{grad_accum_steps}") + log0("sdp_backends:cudnn=False flash=True mem_efficient=False math=False") + log0(f"attention_mode:gqa num_heads:{args.num_heads} num_kv_heads:{args.num_kv_heads}") + log0( + f"tie_embeddings:{args.tie_embeddings} embed_lr:{token_lr} " + f"head_lr:{args.head_lr if base_model.lm_head is not None else 0.0} " + f"matrix_lr:{args.matrix_lr} scalar_lr:{args.scalar_lr}" + ) + log0( + f"train_batch_tokens:{args.train_batch_tokens} train_seq_len:{args.train_seq_len} " + f"iterations:{args.iterations} warmup_steps:{args.warmup_steps} " + f"max_wallclock_seconds:{args.max_wallclock_seconds:.3f}" + ) + log0(f"seed:{args.seed}") + + # ----------------------------- + # DATA LOADER & MODEL WARMUP + # ----------------------------- + + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + + def zero_grad_all() -> None: + for opt in optimizers: + opt.zero_grad(set_to_none=True) + + max_wallclock_ms = 1000.0 * args.max_wallclock_seconds if args.max_wallclock_seconds > 0 else None + + def lr_mul(step: int, elapsed_ms: float) -> float: + if args.warmdown_iters <= 0: + return 1.0 + if max_wallclock_ms is None: + warmdown_start = max(args.iterations - args.warmdown_iters, 0) + return max((args.iterations - step) / max(args.warmdown_iters, 1), 0.0) if warmdown_start <= step < args.iterations else 1.0 + step_ms = elapsed_ms / max(step, 1) + warmdown_ms = args.warmdown_iters * step_ms + remaining_ms = max(max_wallclock_ms - elapsed_ms, 0.0) + return remaining_ms / max(warmdown_ms, 1e-9) if remaining_ms <= warmdown_ms else 1.0 + + # Warmup primes the compiled forward/backward/optimizer paths, then we restore the + # initial weights/optimizer state so measured training starts from the true init. + if args.warmup_steps > 0: + initial_model_state = {name: tensor.detach().cpu().clone() for name, tensor in base_model.state_dict().items()} + initial_optimizer_states = [copy.deepcopy(opt.state_dict()) for opt in optimizers] + model.train() + for warmup_step in range(args.warmup_steps): + zero_grad_all() + for micro_step in range(grad_accum_steps): + if distributed: + model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 + x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + warmup_ce = model(x, y) + (warmup_ce * grad_scale).backward() + for opt in optimizers: + opt.step() + zero_grad_all() + if args.warmup_steps <= 20 or (warmup_step + 1) % 10 == 0 or warmup_step + 1 == args.warmup_steps: + log0(f"warmup_step:{warmup_step + 1}/{args.warmup_steps}") + base_model.load_state_dict(initial_model_state, strict=True) + for opt, state in zip(optimizers, initial_optimizer_states, strict=True): + opt.load_state_dict(state) + zero_grad_all() + if distributed: + model.require_backward_grad_sync = True + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + + # ----------------------------- + # MAIN TRAINING LOOP + # ----------------------------- + + training_time_ms = 0.0 + stop_after_step: int | None = None + torch.cuda.synchronize() + t0 = time.perf_counter() + + step = 0 + while True: + last_step = step == args.iterations or (stop_after_step is not None and step >= stop_after_step) + + should_validate = last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0) + if should_validate: + torch.cuda.synchronize() + training_time_ms += 1000.0 * (time.perf_counter() - t0) + val_loss, val_bpb = eval_val( + args, + model, + rank, + world_size, + device, + grad_accum_steps, + val_tokens, + base_bytes_lut, + has_leading_space_lut, + is_boundary_token_lut, + ) + log0( + f"step:{step}/{args.iterations} val_loss:{val_loss:.4f} val_bpb:{val_bpb:.4f} " + f"train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms / max(step, 1):.2f}ms" + ) + torch.cuda.synchronize() + t0 = time.perf_counter() + + if last_step: + if stop_after_step is not None and step < args.iterations: + log0( + f"stopping_early: wallclock_cap train_time:{training_time_ms:.0f}ms " + f"step:{step}/{args.iterations}" + ) + break + + elapsed_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + scale = lr_mul(step, elapsed_ms) + zero_grad_all() + # JEPA lambda warmup: step-based ramp over jepa_warmup_steps. + # Step-based avoids the bug where a slow first step (torch.compile) + # burns through the entire wall-clock warmup window immediately. + jepa_warmup_frac = min(step / max(args.jepa_warmup_steps, 1), 1.0) + eff_jepa = args.jepa_lambda * jepa_warmup_frac + + train_ce_loss = torch.zeros((), device=device) + train_jepa_loss = torch.zeros((), device=device) + train_jepa_mse_loss = torch.zeros((), device=device) + train_jepa_var_p_loss = torch.zeros((), device=device) + train_jepa_cov_p_loss = torch.zeros((), device=device) + train_jepa_var_t_loss = torch.zeros((), device=device) + train_jepa_cov_t_loss = torch.zeros((), device=device) + # Pre-fetch all micro-batches. + micro_batches = [ + train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) + for _ in range(grad_accum_steps) + ] + for micro_step in range(grad_accum_steps): + if distributed: + model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 + x, y = micro_batches[micro_step] + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + ce_loss = model(x, y) + + # Span-masked JEPA loss (outside compiled model). + # Context encoder sees target spans replaced with jepa_mask_emb; + # target encoder sees the full unmasked sequence. Loss is MSE only + # at masked positions — prediction is genuinely hard because the + # context encoder cannot see the actual token identity. + jepa_loss_micro = ce_loss.new_zeros(()) + jepa_mse_micro = ce_loss.new_zeros(()) + var_loss_p = ce_loss.new_zeros(()) + cov_loss_p = ce_loss.new_zeros(()) + var_loss_t = ce_loss.new_zeros(()) + cov_loss_t = ce_loss.new_zeros(()) + if eff_jepa > 0.0: + # Sample target spans (CPU, one set per micro-batch) + spans = sample_block_spans( + x.shape[1], args.jepa_num_spans, args.jepa_span_len_mean, args.jepa_span_len_min + ) + jepa_mask = torch.zeros(x.shape[:2], dtype=torch.bool, device=x.device) + for s, e in spans.tolist(): + jepa_mask[:, s:e] = True + + # Target encoder: full unmasked sequence, no grad + with torch.no_grad(), torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + z_target_full = compiled_target_encoder.encode(x).detach() # [B, T, D] + + # Context encoder: masked sequence, gradients flow through + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + z_context = base_model.encode(x, jepa_mask=jepa_mask) # [B, T, D] + z_pred = base_model.jepa_predictor(z_context) # [B, T, D] + + # MSE + VICReg var/cov only at masked positions + z_p = z_pred[jepa_mask] # [N_masked, D] + z_t = z_target_full[jepa_mask] # [N_masked, D] + z_p_f = z_p.float() + z_t_f = z_t.float() + + jepa_mse_micro = F.mse_loss(z_p_f, z_t_f) + + # p-side: gradients flow; t-side: diagnostic only + var_loss_p = vicreg_var_loss(z_p_f, args.jepa_var_gamma, args.jepa_var_eps) + cov_loss_p = vicreg_cov_loss(z_p_f) + with torch.no_grad(): + var_loss_t = vicreg_var_loss(z_t_f, args.jepa_var_gamma, args.jepa_var_eps) + cov_loss_t = vicreg_cov_loss(z_t_f) + + jepa_loss_micro = ( + jepa_mse_micro + + args.jepa_var_weight * var_loss_p + + args.jepa_cov_weight * cov_loss_p + ) + + loss = ce_loss + eff_jepa * jepa_loss_micro + + train_ce_loss += ce_loss.detach() + train_jepa_loss += jepa_loss_micro.detach() + train_jepa_mse_loss += jepa_mse_micro.detach() + train_jepa_var_p_loss += var_loss_p.detach() + train_jepa_cov_p_loss += cov_loss_p.detach() + train_jepa_var_t_loss += var_loss_t.detach() + train_jepa_cov_t_loss += cov_loss_t.detach() + (loss * grad_scale).backward() + train_ce_loss /= grad_accum_steps + train_jepa_loss /= grad_accum_steps + train_jepa_mse_loss /= grad_accum_steps + train_jepa_var_p_loss /= grad_accum_steps + train_jepa_cov_p_loss /= grad_accum_steps + train_jepa_var_t_loss /= grad_accum_steps + train_jepa_cov_t_loss /= grad_accum_steps + + frac = min(step / args.muon_momentum_warmup_steps, 1.0) if args.muon_momentum_warmup_steps > 0 else 1.0 + muon_momentum = (1 - frac) * args.muon_momentum_warmup_start + frac * args.muon_momentum + for group in optimizer_muon.param_groups: + group["momentum"] = muon_momentum + + for opt in optimizers: + for group in opt.param_groups: + group["lr"] = group["base_lr"] * scale + + if args.grad_clip_norm > 0: + torch.nn.utils.clip_grad_norm_(base_model.parameters(), args.grad_clip_norm) + for opt in optimizers: + opt.step() + zero_grad_all() + + # Artifact EMA: update after each optimizer step (always active). + # Kept on CPU to save GPU memory; model_p is transferred per-param. + with torch.no_grad(): + for ema_p, model_p in zip(artifact_ema.parameters(), base_model.parameters()): + ema_p.data.lerp_(model_p.data.cpu().to(dtype=ema_p.dtype), 1.0 - args.artifact_ema_decay) + + # JEPA: EMA-update target encoder after each optimizer step. + # Momentum rises from jepa_ema_momentum (start) to 0.999 (end) as training progresses. + # Use wallclock fraction when a cap is set, otherwise fall back to step fraction so + # that the schedule works correctly in uncapped / development runs. + if max_wallclock_ms is not None and max_wallclock_ms > 0: + frac_done = min(elapsed_ms / max_wallclock_ms, 1.0) + else: + frac_done = min(step / max(args.iterations, 1), 1.0) + ema_mom = args.jepa_ema_momentum + (0.999 - args.jepa_ema_momentum) * frac_done + with torch.no_grad(): + src_params = [p.data.to(dtype=tgt_p.dtype) for p, tgt_p in + zip(base_model.parameters(), jepa_target_encoder.parameters())] + tgt_params = [p.data for p in jepa_target_encoder.parameters()] + torch._foreach_lerp_(tgt_params, src_params, 1.0 - ema_mom) + + step += 1 + approx_training_time_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + should_log_train = ( + args.train_log_every > 0 + and (step <= 10 or step % args.train_log_every == 0 or stop_after_step is not None) + ) + if should_log_train: + jepa_log = ( + f" jepa_loss:{train_jepa_loss.item():.4f}" + f" jepa_mse:{train_jepa_mse_loss.item():.4f}" + f" jepa_var_p:{train_jepa_var_p_loss.item():.4f}" + f" jepa_cov_p:{train_jepa_cov_p_loss.item():.4f}" + f" jepa_var_t:{train_jepa_var_t_loss.item():.4f}" + f" jepa_cov_t:{train_jepa_cov_t_loss.item():.4f}" + f" jepa_lam_scale:{jepa_warmup_frac:.3f}" + ) + log0( + f"step:{step}/{args.iterations} train_loss:{train_ce_loss.item():.4f}{jepa_log} " + f"train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms / step:.2f}ms" + ) + + # Needed to sync whether we've reached the wallclock cap. + reached_cap = max_wallclock_ms is not None and approx_training_time_ms >= max_wallclock_ms + if distributed and max_wallclock_ms is not None: + reached_cap_tensor = torch.tensor(int(reached_cap), device=device) + dist.all_reduce(reached_cap_tensor, op=dist.ReduceOp.MAX) + reached_cap = bool(reached_cap_tensor.item()) + if stop_after_step is None and reached_cap: + stop_after_step = step + + log0( + f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB " + f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB" + ) + + # ----------------------------- + # SERIALIZATION + ROUNDTRIP VALIDATION + # ----------------------------- + # Save the raw state (useful for debugging/loading in PyTorch directly), then always produce + # the compressed int8+zlib artifact and validate the round-tripped weights. + + # Serialize: use artifact_ema weights (smoothed checkpoint) instead of raw base_model. + # artifact_ema has been updated every step with decay=0.9999, giving a Polyak average + # of the weight trajectory. This typically improves val_bpb by 0.003-0.005 BPB. + save_state = artifact_ema.state_dict() + + if master_process: + torch.save(save_state, "final_model.pt") + model_bytes = os.path.getsize("final_model.pt") + code_bytes = len(code.encode("utf-8")) + log0(f"Serialized model (artifact_ema): {model_bytes} bytes") + log0(f"Code size: {code_bytes} bytes") + log0(f"Total submission size: {model_bytes + code_bytes} bytes") + + quant_obj, quant_stats = quantize_state_dict_int8(save_state) + quant_buf = io.BytesIO() + torch.save(quant_obj, quant_buf) + quant_raw = quant_buf.getvalue() + # v2: LZMA instead of zlib. ~280KB smaller at same quality on typical weight distributions. + quant_blob = lzma.compress(quant_raw, preset=9) + quant_raw_bytes = len(quant_raw) + quant_label = f"int{6 if args.quant_max == 31 else 8}+lzma" + if master_process: + with open("final_model.int6.ptz", "wb") as f: + f.write(quant_blob) + quant_file_bytes = os.path.getsize("final_model.int6.ptz") + code_bytes = len(code.encode("utf-8")) + ratio = quant_stats["baseline_tensor_bytes"] / max(quant_stats["int8_payload_bytes"], 1) + log0( + f"Serialized model {quant_label}: {quant_file_bytes} bytes " + f"(payload:{quant_stats['int8_payload_bytes']} raw_torch:{quant_raw_bytes} payload_ratio:{ratio:.2f}x)" + ) + log0(f"Total submission size {quant_label}: {quant_file_bytes + code_bytes} bytes") + + if distributed: + dist.barrier() + with open("final_model.int6.ptz", "rb") as f: + quant_blob_disk = f.read() + quant_state = torch.load(io.BytesIO(lzma.decompress(quant_blob_disk)), map_location="cpu") + base_model.load_state_dict(dequantize_state_dict_int8(quant_state), strict=True) + torch.cuda.synchronize() + t_qeval = time.perf_counter() + q_val_loss, q_val_bpb = eval_val( + args, + model, + rank, + world_size, + device, + grad_accum_steps, + val_tokens, + base_bytes_lut, + has_leading_space_lut, + is_boundary_token_lut, + ) + torch.cuda.synchronize() + log0( + f"final_{quant_label}_roundtrip val_loss:{q_val_loss:.4f} val_bpb:{q_val_bpb:.4f} " + f"eval_time:{1000.0 * (time.perf_counter() - t_qeval):.0f}ms" + ) + log0(f"final_{quant_label}_roundtrip_exact val_loss:{q_val_loss:.8f} val_bpb:{q_val_bpb:.8f}") + + if distributed: + dist.destroy_process_group() + + +if __name__ == "__main__": + main()