diff --git a/records/track_non_record_16mb/2026-03-24_11L_GEPA_MixedQuant_7kSteps_LegalTTT/README.md b/records/track_non_record_16mb/2026-03-24_11L_GEPA_MixedQuant_7kSteps_LegalTTT/README.md new file mode 100644 index 0000000000..fd6aeb8bee --- /dev/null +++ b/records/track_non_record_16mb/2026-03-24_11L_GEPA_MixedQuant_7kSteps_LegalTTT/README.md @@ -0,0 +1,215 @@ +# 11L GEPA + Mixed Int6/Int8 Quantization + 7k Steps + Legal Score-First TTT + +**val_bpb = 1.1334** | Pre-TTT: 1.1476 | TTT gain: **−0.0142** | Artifact: 15.70 MB + +> Non-record unlimited-compute submission (trained on 4×A100-40GB, eval 2194s on 4×A100). + +--- + +## Headline Result + +This submission pushes base model quality by training for **7000 steps** (35% more than the standard 5200) using a **GEPA architecture** (SwiGLU + U-Net + BigramHash + EMA + XSA + GPTQ-lite) with **mixed int6/int8 quantization** to stay under the 16MB cap. The longer training brings the pre-TTT baseline down to **1.1476 BPB**, and legal score-first TTT with SGD momentum (10 epochs) yields a further **−0.0142 BPB** improvement to reach **1.1334 BPB**. + +The mixed quantization strategy — int6 per-row with GPTQ-lite clip search for the large QAT-trained attention and MLP weights, int8 per-tensor scalar for the rest — compresses the 27M-parameter model into a 15.70 MB artifact with minimal accuracy loss. + +--- + +## Novel & Creative Contributions + +### 1. Extended Training (7000 Steps) for Deeper Loss Basin + +Training 35% longer than the standard 5200 steps, with warmdown starting at step 3500, gives the model more time to converge. The extra steps particularly help during the warmdown cosine anneal from step 3500 to 7000, allowing gentler learning rate decay. Base model improves from 1.1570 BPB (5200 steps, prior submissions) to **1.1476 BPB** (7000 steps). 
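The schedule above (20 warmup steps, a constant plateau, then a cosine warmdown from step 3500 to 7000, per the Training Details table) can be sketched as follows. This is a minimal illustration under stated assumptions, not the submission's code: `lr_scale` is a hypothetical helper, and the warmup shape and anneal-to-zero endpoint are assumptions.

```python
import math

ITERATIONS = 7000       # total training steps (vs. the standard 5200)
WARMDOWN_START = 3500   # cosine anneal begins at the halfway point
WARMUP_STEPS = 20

def lr_scale(step: int) -> float:
    """Multiplier applied to each parameter group's base learning rate."""
    if step < WARMUP_STEPS:
        return (step + 1) / WARMUP_STEPS  # brief linear warmup
    if step <= WARMDOWN_START:
        return 1.0                        # constant plateau
    # cosine anneal from 1.0 at step 3500 down to 0.0 at step 7000
    frac = (step - WARMDOWN_START) / (ITERATIONS - WARMDOWN_START)
    return 0.5 * (1.0 + math.cos(math.pi * frac))
```

Stretching the anneal over 3500 steps, rather than the ~1700 a 5200-step run would leave, is what makes the decay gentler: the multiplier falls at roughly half the rate per step.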
+
+The key enabler is mixed quantization: the int6+int8 scheme compresses the model enough that even with a larger 27M-parameter model trained for 7000 steps, the artifact stays under 16MB.
+
+### 2. Mixed Int6/Int8 Quantization for Size-Aware Compression
+
+A dual-scheme quantization strategy:
+
+- **QAT-trained attention and MLP weights** (the bulk of parameters) use **int6 per-row quantization** with GPTQ-lite clip search (5 candidates per row), minimizing per-row MSE.
+- **Smaller tensors** (layer norms, value embeddings, biases, embedding tables) use **int8 per-tensor scalar quantization**, since these small tensors are the most sensitive to low-precision quantization.
+
+This preserves accuracy where it matters most while keeping the large tensors compact. The mixed scheme compresses 27M parameters into 15.63 MB of model data (+ 76 KB code = 15.70 MB total).
+
+### 3. GEPA Architecture (SwiGLU + U-Net + BigramHash + EMA + XSA + GPTQ-lite)
+
+Combines several proven techniques into a coherent architecture:
+
+- **ReLU² activation** in 3× MLP (1536 hidden)
+- **Exclusive self attention (XSA)** on the last 4 layers — removes self-value bias via orthogonal projection
+- **Exponential moving average** (decay 0.997) applied every step
+- **Bigram hash embeddings** (2048 buckets, 128-dim) — cheap bigram context via hash of consecutive token pairs
+- **Partial RoPE** (16 of 64 dims) with YARN scaling — concentrates position info in a compact subspace
+- **Late QAT** with 5-candidate GPTQ-lite clip search, triggered when the LR scale drops below 0.15
+- **Value embeddings** (128d) on layers 9–10 — direct token identity signal in the value stream for deep layers
+- **U-Net skip connections** across layer pairs
+- **LN depth scaling** (`1/√(layer+1)`) for stable deep training
+
+### 4. 
Legal Score-First TTT with SGD Momentum
+
+Competition-legal test-time training using sliding windows with a score-first protocol — every token is scored under `torch.inference_mode()` before any weight update. SGD with momentum (0.9) at lr=0.002 for 10 epochs per 32K-token chunk, freezing the first 2 blocks. Yields a further **−0.0142 BPB** gain (1.1476 → 1.1334).
+
+---
+
+## Architecture Summary
+
+| Component | Configuration |
+|---|---|
+| Layers | 11 |
+| Embedding dim | 512 |
+| Heads | 8 query, 4 KV (GQA) |
+| MLP | 3× expansion (1536 hidden), ReLU² activation |
+| Vocab | 1024 (SentencePiece BPE) |
+| BigramHash | 2048 buckets, 128-dim embeddings |
+| RoPE | Partial: 16/64 dims, YARN scaling (train_seq=1024) |
+| Value Embeddings | 128d on layers 9–10, per-layer scale (init 0.1) |
+| LN Scale | `1/√(layer+1)` depth scaling |
+| XSA | Exclusive self attention on last 4 layers |
+| U-Net skips | Residual connections across layer pairs |
+| SmearGate | Learned token-mixing gate on input embeddings |
+| Tied Embeddings | Yes |
+| Parameters | 27,030,108 total |
+
+## Training Details
+
+| Setting | Value |
+|---|---|
+| Hardware | 4×A100-40GB (NVIDIA) |
+| Steps | 7,000 |
+| Warmdown | Cosine anneal from step 3,500 to 7,000 |
+| Warmup | 20 steps |
+| Training wallclock | 3,490s (~58 min) |
+| Batch tokens | 786,432 |
+| Sequence length | 2,048 |
+| Optimizer | Muon (hidden/attn) + Adam (embeddings/scalars) |
+| Muon WD | 0.04 |
+| Adam WD | 0.04 |
+| Decoder LR mult | 2.0 |
+| Grad clip | 0.3 |
+| EMA | Decay 0.997, every step |
+| Late QAT | Enabled at step 6,476 (scale < 0.15) |
+
+## TTT Protocol (Legal Score-First)
+
+```
+for each 32K-token chunk:
+  1. model.eval() + torch.inference_mode()
+     → Forward pass on chunk, accumulate NLL    ← SCORE (graded)
+  2. model.train()
+     → SGD(lr=0.002, momentum=0.9), 10 epochs   ← TRAIN (adaptation)
+  3. 
Advance to next chunk with updated weights
+```
+
+Every target token is scored exactly once, strictly before any gradient update that could benefit from it. Because scoring runs under `torch.inference_mode()`, autograd is disabled during the scoring pass, so scored tokens cannot contribute gradients to any later update.
+
+| TTT Setting | Value |
+|---|---|
+| Optimizer | SGD, momentum=0.9 |
+| Learning rate | 0.002 |
+| Epochs per chunk | 10 |
+| Chunk size | 32,768 tokens |
+| Stride | 64 |
+| Frozen blocks | First 2 (of 11) |
+| Trainable params | 22,301,260 / 27,030,108 |
+| Eval time | 2,194s (4×A100) |
+
+## Quantization & Size
+
+| Component | Bytes |
+|---|---|
+| Model (mixed int6/int8 + zstd-22) | 15,626,769 |
+| Code (train_gpt.py) | 76,429 |
+| **Total** | **15,703,198** |
+| Limit | 16,000,000 |
+| Headroom | 296,802 (1.9%) |
+
+Mixed quantization breakdown:
+- **Int6 per-row** (GPTQ-lite, 5 clip candidates): attention projections, MLP weights — all QAT-trained tensors
+- **Int8 per-tensor** (scalar scale): layer norms, value embeddings, biases, embedding tables
+- **Payload**: 27,522,422 bytes → **zstd-22**: 15,626,769 bytes (1.76× compression)
+
+## Training Curve
+
+| Step | Val BPB | Notes |
+|---|---|---|
+| 0 | 4.1044 | |
+| 500 | 1.4108 | |
+| 1000 | 1.3334 | |
+| 1500 | 1.3050 | |
+| 2000 | 1.2711 | |
+| 2500 | 1.2592 | |
+| 3000 | 1.2509 | |
+| 3500 | 1.2475 | Warmdown begins |
+| 4000 | 1.2354 | |
+| 4500 | 1.2231 | |
+| 5000 | 1.2104 | |
+| 5500 | 1.1970 | |
+| 6000 | 1.1834 | |
+| 6500 | 1.1642 | Late QAT enabled at 6476 |
+| **7000** | **1.1476** | Pre-TTT baseline (EMA applied) |
+| **TTT** | **1.1334** | −0.0142 from legal score-first TTT (10 epochs) |
+
+## Comparison to Prior Submissions
+
+| Submission | BPB | Status |
+|---|---|---|
+| **This work** | **1.1334** | Non-record |
+| Prior: VE128+PartialRoPE+LegalTTT 30ep | 1.1425 | Non-record |
+| Prior: VE128+PartialRoPE+LegalTTT 10ep | 1.1451 | Non-record |
+| Record SOTA (signalrush) | 1.1228 | Record |
+
+Key improvements over the prior 
VE128+PartialRoPE submission (1.1425): +- **Better base model**: 1.1476 vs 1.1609 (−0.0133) from 7000 steps + GEPA architecture +- **Fewer TTT epochs needed**: 10 epochs vs 30 epochs, yet still reaches a better final BPB +- **Faster eval**: 2,194s vs 3,662s (40% faster, partly from fewer TTT epochs and 4-GPU eval) + +## Reproducibility + +```bash +# Environment: Python 3.10+, PyTorch 2.x with CUDA +# From the repo root: +RUN_ID=gep_v27k \ +SEED=42 \ +DATA_PATH=./data/datasets/fineweb10B_sp1024/ \ +TOKENIZER_PATH=./data/tokenizers/fineweb_1024_bpe.model \ +VOCAB_SIZE=1024 \ +MAX_WALLCLOCK_SECONDS=0 \ +ITERATIONS=7000 \ +WARMDOWN_ITERS=3500 \ +VAL_LOSS_EVERY=500 \ +NUM_LAYERS=11 \ +MODEL_DIM=512 \ +NUM_HEADS=8 \ +NUM_KV_HEADS=4 \ +MLP_HIDDEN=1536 \ +TIE_EMBEDDINGS=1 \ +BIGRAM_BUCKETS=2048 \ +BIGRAM_EMBED_DIM=128 \ +ROPE_DIMS=16 \ +ROPE_TRAIN_SEQ_LEN=1024 \ +LN_SCALE=1 \ +XSA_LAYERS=4 \ +EVAL_STRIDE=64 \ +EMA_ENABLED=1 EMA_DECAY=0.997 \ +LATE_QAT=1 QAT_THRESHOLD=0.15 \ +VE_ENABLED=1 VE_DIM=128 VE_LAYERS=9,10 \ +GRAD_CLIP_NORM=0.3 \ +MIXED_QUANT=1 QUANT_EMBED=1 \ +TTT_ENABLED=1 TTT_OPTIMIZER=sgd \ +TTT_LR=0.002 TTT_EPOCHS=10 \ +TTT_FREEZE_BLOCKS=2 TTT_BATCH_SEQS=32 \ +TTT_CHUNK_TOKENS=32768 TTT_MOMENTUM=0.9 \ +torchrun --standalone --nproc_per_node=4 \ + records/track_non_record_16mb/2026-03-24_11L_GEPA_MixedQuant_7kSteps_LegalTTT/train_gpt.py +``` + +## Credits + +This submission builds on work from many contributors to the parameter-golf competition: + +- **Muon optimizer** — Baseline (`modded-nanogpt`); Newton-Schulz orthogonal preconditioning +- **BigramHash embeddings** — PR #65 (aquariouseworkman): hash consecutive token pairs for cheap bigram context +- **SmearGate** — PR #65 (aquariouseworkman): per-dim sigmoid gate blending adjacent token embeddings +- **XSA (Exclusive Self Attention)** — PR #187 (Idan3011): removes self-value bias via orthogonal projection; GQA-aware variant in PR #265 (unnir) +- **Value Embeddings** — Per-layer learned embeddings added to 
the value stream +- **GPTQ-lite clip search** — Per-row optimal clip percentile search for int6 quantization diff --git a/records/track_non_record_16mb/2026-03-24_11L_GEPA_MixedQuant_7kSteps_LegalTTT/submission.json b/records/track_non_record_16mb/2026-03-24_11L_GEPA_MixedQuant_7kSteps_LegalTTT/submission.json new file mode 100644 index 0000000000..8f33db3596 --- /dev/null +++ b/records/track_non_record_16mb/2026-03-24_11L_GEPA_MixedQuant_7kSteps_LegalTTT/submission.json @@ -0,0 +1,17 @@ +{ + "author": "Christopher Lee McClendon", + "github_id": "Christopher-Lee-McClendon", + "name": "11L GEPA + Mixed Int6/Int8 Quantization + 7k Steps + Legal TTT", + "blurb": "GEPA architecture with mixed quantization (int6 per-row for QAT-trained MLP/attn, int8 per-tensor for rest) enables 7000 training steps to push base model quality while staying under 16MB. Combined with legal score-first TTT (SGD, 10 epochs), achieves 1.1334 BPB.", + "date": "2026-03-24T00:00:00Z", + "track": "non_record_16mb", + "val_loss": 1.913751, + "val_bpb": 1.13343416, + "pre_ttt_val_loss": 1.9376, + "pre_ttt_val_bpb": 1.1476, + "bytes_total": 15703198, + "eval_time_seconds": 2194, + "gpu": "4xA100-40GB", + "wallclock_seconds": 3490, + "seed": 42 +} diff --git a/records/track_non_record_16mb/2026-03-24_11L_GEPA_MixedQuant_7kSteps_LegalTTT/train_gpt.py b/records/track_non_record_16mb/2026-03-24_11L_GEPA_MixedQuant_7kSteps_LegalTTT/train_gpt.py new file mode 100644 index 0000000000..f048f262e0 --- /dev/null +++ b/records/track_non_record_16mb/2026-03-24_11L_GEPA_MixedQuant_7kSteps_LegalTTT/train_gpt.py @@ -0,0 +1,1746 @@ +"""train_gpt.py — GEPA (SwiGLU + U-Net + BigramHash + EMA + XSA4 + GPTQ-lite) + Legal TTT.""" + +from __future__ import annotations + +import copy +import glob +import io +import math +import os +import random +import subprocess +import sys +import time +import uuid +from pathlib import Path + +import numpy as np +import sentencepiece as spm +import torch +import torch.distributed as 
dist +import torch.nn.functional as F +from torch import Tensor, nn +from torch.nn.parallel import DistributedDataParallel as DDP + +# zstd-22 compression with zlib fallback +try: + import zstandard as zstd + USE_ZSTD = True +except ImportError: + import zlib + USE_ZSTD = False + +# ----------------------------- +# HYPERPARAMETERS +# ----------------------------- + +class Hyperparameters: + data_path = os.environ.get("DATA_PATH", "./data/datasets/fineweb10B_sp1024") + train_files = os.path.join(data_path, "fineweb_train_*.bin") + val_files = os.path.join(data_path, "fineweb_val_*.bin") + tokenizer_path = os.environ.get("TOKENIZER_PATH", "./data/tokenizers/fineweb_1024_bpe.model") + run_id = os.environ.get("RUN_ID", str(uuid.uuid4())) + seed = int(os.environ.get("SEED", 42)) + + val_batch_size = int(os.environ.get("VAL_BATCH_SIZE", 524_288)) + val_loss_every = int(os.environ.get("VAL_LOSS_EVERY", 1000)) + train_log_every = int(os.environ.get("TRAIN_LOG_EVERY", 200)) + + iterations = int(os.environ.get("ITERATIONS", 20000)) + warmdown_iters = int(os.environ.get("WARMDOWN_ITERS", "3500")) + warmup_steps = int(os.environ.get("WARMUP_STEPS", 20)) + train_batch_tokens = int(os.environ.get("TRAIN_BATCH_TOKENS", 786_432)) + train_seq_len = int(os.environ.get("TRAIN_SEQ_LEN", 2048)) + max_wallclock_seconds = float(os.environ.get("MAX_WALLCLOCK_SECONDS", 600.0)) + qk_gain_init = float(os.environ.get("QK_GAIN_INIT", 1.5)) + + vocab_size = int(os.environ.get("VOCAB_SIZE", 1024)) + num_layers = int(os.environ.get("NUM_LAYERS", "11")) # Up from 9 + num_kv_heads = int(os.environ.get("NUM_KV_HEADS", "8")) + model_dim = int(os.environ.get("MODEL_DIM", 512)) + num_heads = int(os.environ.get("NUM_HEADS", 8)) + mlp_mult = int(os.environ.get("MLP_MULT", 3)) # Unused by Star-ReLU + mlp_hidden = int(os.environ.get("MLP_HIDDEN", "1792")) + tie_embeddings = bool(int(os.environ.get("TIE_EMBEDDINGS", "1"))) + eval_stride = int(os.environ.get("EVAL_STRIDE", 64)) + eval_batch_seqs = 
int(os.environ.get("EVAL_BATCH_SEQS", 32)) + rope_base = float(os.environ.get("ROPE_BASE", 10000.0)) + logit_softcap = float(os.environ.get("LOGIT_SOFTCAP", 30.0)) + + # BigramHash config + bigram_buckets = int(os.environ.get("BIGRAM_BUCKETS", "8192")) + bigram_embed_dim = int(os.environ.get("BIGRAM_EMBED_DIM", 128)) + + # Partial RoPE: apply rotary to only first ROPE_DIMS of head_dim (0 = full) + rope_dims = int(os.environ.get("ROPE_DIMS", 16)) + + # YARN RoPE: train_seq_len for YARN scaling (default 1024 enables YARN at seq_len=2048) + rope_train_seq_len = int(os.environ.get("ROPE_TRAIN_SEQ_LEN", "1024")) + + # LN Scale: scale norm input by 1/sqrt(layer_idx+1) per block + ln_scale = bool(int(os.environ.get("LN_SCALE", "1"))) + + # Optimizer hyperparameters (updated to match #1 team) + embed_lr = float(os.environ.get("EMBED_LR", 0.6)) + head_lr = float(os.environ.get("HEAD_LR", 0.008)) + tied_embed_lr = float(os.environ.get("TIED_EMBED_LR", 0.035)) + tied_embed_init_std = float(os.environ.get("TIED_EMBED_INIT_STD", 0.005)) + matrix_lr = float(os.environ.get("MATRIX_LR", 0.025)) + scalar_lr = float(os.environ.get("SCALAR_LR", 0.025)) + decoder_lr_mult = float(os.environ.get("DECODER_LR_MULT", 2.0)) + muon_momentum = float(os.environ.get("MUON_MOMENTUM", 0.99)) + muon_backend_steps = int(os.environ.get("MUON_BACKEND_STEPS", 5)) + muon_momentum_warmup_start = float(os.environ.get("MUON_MOMENTUM_WARMUP_START", 0.92)) + muon_momentum_warmup_steps = int(os.environ.get("MUON_MOMENTUM_WARMUP_STEPS", 1500)) + muon_wd = float(os.environ.get("MUON_WD", 0.04)) + adam_wd = float(os.environ.get("ADAM_WD", 0.04)) + beta1 = float(os.environ.get("BETA1", 0.9)) + beta2 = float(os.environ.get("BETA2", 0.95)) + adam_eps = float(os.environ.get("ADAM_EPS", 1e-8)) + grad_clip_norm = float(os.environ.get("GRAD_CLIP_NORM", 0.0)) + + # V2 improvements (each controllable via env var) + ortho_init = bool(int(os.environ.get("ORTHO_INIT", "1"))) + bigram_xor_hash = 
bool(int(os.environ.get("BIGRAM_XOR_HASH", "1"))) + + # EMA: exponential moving average, updates every step (priority over SWA) + ema_enabled = bool(int(os.environ.get("EMA_ENABLED", "1"))) + ema_decay = float(os.environ.get("EMA_DECAY", "0.997")) + + # SWA config (fallback when EMA disabled) + swa_start_frac = float(os.environ.get("SWA_START_FRAC", 0.5)) + swa_every = int(os.environ.get("SWA_EVERY", 50)) + + # Late QAT: enable fake int6 quantization when LR scale < qat_threshold + late_qat = bool(int(os.environ.get("LATE_QAT", "1"))) + qat_threshold = float(os.environ.get("QAT_THRESHOLD", "0.15")) + + # Value Embeddings: reinject token identity into attention values at deep layers + ve_enabled = bool(int(os.environ.get("VE_ENABLED", "1"))) + ve_dim = int(os.environ.get("VE_DIM", 128)) + ve_layers = os.environ.get("VE_LAYERS", "9,10") + + # TTT: legal score-first fine-tune on val data after training + xsa_layers = int(os.environ.get("XSA_LAYERS", "4")) + ttt_enabled = bool(int(os.environ.get("TTT_ENABLED", "1"))) + ttt_lr = float(os.environ.get("TTT_LR", 0.003)) + ttt_epochs = int(os.environ.get("TTT_EPOCHS", 3)) + ttt_freeze_blocks = int(os.environ.get("TTT_FREEZE_BLOCKS", 2)) + ttt_grad_clip = float(os.environ.get("TTT_GRAD_CLIP", 1.0)) + ttt_batch_seqs = int(os.environ.get("TTT_BATCH_SEQS", 32)) + ttt_chunk_tokens = int(os.environ.get("TTT_CHUNK_TOKENS", 32768)) + ttt_optimizer = os.environ.get("TTT_OPTIMIZER", "adamw") # "sgd" or "adamw" + ttt_momentum = float(os.environ.get("TTT_MOMENTUM", 0.9)) # for SGD + ttt_adam_wd = float(os.environ.get("TTT_ADAM_WD", 0.01)) # for AdamW + + +# ----------------------------- +# MUON OPTIMIZER WITH WEIGHT DECAY +# ----------------------------- + +def zeropower_via_newtonschulz5(G: Tensor, steps: int = 10, eps: float = 1e-7) -> Tensor: + a, b, c = (3.4445, -4.7750, 2.0315) + X = G.bfloat16() + X /= X.norm() + eps + transposed = G.size(0) > G.size(1) + if transposed: + X = X.T + for _ in range(steps): + A = X @ X.T + B = b * A 
+ c * A @ A + X = a * X + B @ X + return X.T if transposed else X + + +class Muon(torch.optim.Optimizer): + def __init__(self, params, lr: float, momentum: float, backend_steps: int, nesterov: bool = True, weight_decay: float = 0.02): + super().__init__( + params, + dict(lr=lr, momentum=momentum, backend_steps=backend_steps, nesterov=nesterov, weight_decay=weight_decay), + ) + + @torch.no_grad() + def step(self, closure=None): + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + distributed = dist.is_available() and dist.is_initialized() + world_size = dist.get_world_size() if distributed else 1 + rank = dist.get_rank() if distributed else 0 + + for group in self.param_groups: + params = group["params"] + if not params: + continue + lr = group["lr"] + momentum = group["momentum"] + backend_steps = group["backend_steps"] + nesterov = group["nesterov"] + wd = group["weight_decay"] + + total_params = sum(int(p.numel()) for p in params) + updates_flat = torch.zeros(total_params, device=params[0].device, dtype=torch.bfloat16) + + curr = 0 + for i, p in enumerate(params): + if i % world_size == rank and p.grad is not None: + g = p.grad + state = self.state[p] + if "momentum_buffer" not in state: + state["momentum_buffer"] = torch.zeros_like(g) + buf = state["momentum_buffer"] + buf.mul_(momentum).add_(g) + if nesterov: + g = g.add(buf, alpha=momentum) + g = zeropower_via_newtonschulz5(g, steps=backend_steps) + g *= max(1, g.size(0) / g.size(1)) ** 0.5 + updates_flat[curr : curr + p.numel()] = g.reshape(-1) + curr += p.numel() + + if distributed: + dist.all_reduce(updates_flat, op=dist.ReduceOp.SUM) + + curr = 0 + for p in params: + # Apply weight decay BEFORE update (standard decoupled WD) + if wd > 0: + p.data.mul_(1.0 - lr * wd) + g = updates_flat[curr : curr + p.numel()].view_as(p).to(dtype=p.dtype) + p.add_(g, alpha=-lr) + curr += p.numel() + + return loss + + +# ----------------------------- +# TOKENIZER-AGNOSTIC EVALUATION 
SETUP +# ----------------------------- + +def build_sentencepiece_luts( + sp: spm.SentencePieceProcessor, vocab_size: int, device: torch.device +) -> tuple[Tensor, Tensor, Tensor]: + sp_vocab_size = int(sp.vocab_size()) + table_size = max(sp_vocab_size, vocab_size) + base_bytes_np = np.zeros((table_size,), dtype=np.int16) + has_leading_space_np = np.zeros((table_size,), dtype=np.bool_) + is_boundary_token_np = np.ones((table_size,), dtype=np.bool_) + for token_id in range(sp_vocab_size): + if sp.is_control(token_id) or sp.is_unknown(token_id) or sp.is_unused(token_id): + continue + is_boundary_token_np[token_id] = False + if sp.is_byte(token_id): + base_bytes_np[token_id] = 1 + continue + piece = sp.id_to_piece(token_id) + if piece.startswith("\u2581"): + has_leading_space_np[token_id] = True + piece = piece[1:] + base_bytes_np[token_id] = len(piece.encode("utf-8")) + return ( + torch.tensor(base_bytes_np, dtype=torch.int16, device=device), + torch.tensor(has_leading_space_np, dtype=torch.bool, device=device), + torch.tensor(is_boundary_token_np, dtype=torch.bool, device=device), + ) + + +def load_validation_tokens(pattern: str, seq_len: int) -> Tensor: + files = [Path(p) for p in sorted(glob.glob(pattern))] + if not files: + raise FileNotFoundError(f"No files found for pattern: {pattern}") + tokens = torch.cat([load_data_shard(file) for file in files]).contiguous() + usable = ((tokens.numel() - 1) // seq_len) * seq_len + if usable <= 0: + raise ValueError(f"Validation split is too short for TRAIN_SEQ_LEN={seq_len}") + return tokens[: usable + 1] + + +def eval_val( + args: Hyperparameters, + model: nn.Module, + rank: int, + world_size: int, + device: torch.device, + grad_accum_steps: int, + val_tokens: Tensor, + base_bytes_lut: Tensor, + has_leading_space_lut: Tensor, + is_boundary_token_lut: Tensor, +) -> tuple[float, float]: + local_batch_tokens = args.val_batch_size // (world_size * grad_accum_steps) + if local_batch_tokens < args.train_seq_len: + raise 
ValueError( + "VAL_BATCH_SIZE must provide at least one sequence per rank; " + f"got VAL_BATCH_SIZE={args.val_batch_size}, WORLD_SIZE={world_size}, " + f"GRAD_ACCUM_STEPS={grad_accum_steps}, TRAIN_SEQ_LEN={args.train_seq_len}" + ) + local_batch_seqs = local_batch_tokens // args.train_seq_len + total_seqs = (val_tokens.numel() - 1) // args.train_seq_len + seq_start = (total_seqs * rank) // world_size + seq_end = (total_seqs * (rank + 1)) // world_size + val_loss_sum = torch.zeros((), device=device, dtype=torch.float64) + val_token_count = torch.zeros((), device=device, dtype=torch.float64) + val_byte_count = torch.zeros((), device=device, dtype=torch.float64) + + model.eval() + with torch.inference_mode(): + for batch_seq_start in range(seq_start, seq_end, local_batch_seqs): + batch_seq_end = min(batch_seq_start + local_batch_seqs, seq_end) + raw_start = batch_seq_start * args.train_seq_len + raw_end = batch_seq_end * args.train_seq_len + 1 + local = val_tokens[raw_start:raw_end].to(device=device, dtype=torch.int64, non_blocking=True) + x = local[:-1].reshape(-1, args.train_seq_len) + y = local[1:].reshape(-1, args.train_seq_len) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + batch_loss = model(x, y).detach() + batch_token_count = float(y.numel()) + val_loss_sum += batch_loss.to(torch.float64) * batch_token_count + val_token_count += batch_token_count + prev_ids = x.reshape(-1) + tgt_ids = y.reshape(-1) + token_bytes = base_bytes_lut[tgt_ids].to(dtype=torch.int16) + token_bytes += (has_leading_space_lut[tgt_ids] & ~is_boundary_token_lut[prev_ids]).to(dtype=torch.int16) + val_byte_count += token_bytes.to(torch.float64).sum() + + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(val_loss_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(val_token_count, op=dist.ReduceOp.SUM) + dist.all_reduce(val_byte_count, op=dist.ReduceOp.SUM) + + val_loss = val_loss_sum / val_token_count + bits_per_token = val_loss.item() / 
math.log(2.0) + tokens_per_byte = val_token_count.item() / val_byte_count.item() + model.train() + return float(val_loss.item()), float(bits_per_token * tokens_per_byte) + + +def eval_val_sliding( + args: Hyperparameters, + base_model: nn.Module, + rank: int, + world_size: int, + device: torch.device, + val_tokens: Tensor, + base_bytes_lut: Tensor, + has_leading_space_lut: Tensor, + is_boundary_token_lut: Tensor, + stride: int, + batch_seqs: int = 32, +) -> tuple[float, float]: + """Sliding window evaluation: each token scored with maximum context.""" + seq_len = args.train_seq_len + total_tokens = val_tokens.numel() - 1 + + window_starts = [ws for ws in range(0, total_tokens, stride) + if min(ws + seq_len, total_tokens) - ws >= stride] + total_windows = len(window_starts) + + my_s = (total_windows * rank) // world_size + my_e = (total_windows * (rank + 1)) // world_size + my_windows = window_starts[my_s:my_e] + + loss_sum = torch.zeros((), device=device, dtype=torch.float64) + token_count = torch.zeros((), device=device, dtype=torch.float64) + byte_count = torch.zeros((), device=device, dtype=torch.float64) + + base_model.eval() + with torch.inference_mode(): + for bi in range(0, len(my_windows), batch_seqs): + batch_ws = my_windows[bi:bi + batch_seqs] + bsz = len(batch_ws) + + x_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) + y_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) + wlens: list[int] = [] + + for i, ws in enumerate(batch_ws): + end = min(ws + seq_len, total_tokens) + wlen = end - ws + wlens.append(wlen) + chunk = val_tokens[ws:end + 1].to(dtype=torch.int64, device=device) + x_batch[i, :wlen] = chunk[:-1] + y_batch[i, :wlen] = chunk[1:] + + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + logits = base_model.forward_logits(x_batch) + + nll = F.cross_entropy( + logits.reshape(-1, logits.size(-1)).float(), + y_batch.reshape(-1), + reduction="none", + ).reshape(bsz, seq_len) + + for i, ws in 
enumerate(batch_ws): + wlen = wlens[i] + s = 0 if ws == 0 else wlen - stride + scored_nll = nll[i, s:wlen].to(torch.float64) + loss_sum += scored_nll.sum() + + scored_prev = x_batch[i, s:wlen] + scored_tgt = y_batch[i, s:wlen] + tb = base_bytes_lut[scored_tgt].to(torch.int16) + tb += (has_leading_space_lut[scored_tgt] & ~is_boundary_token_lut[scored_prev]).to(torch.int16) + byte_count += tb.to(torch.float64).sum() + token_count += float(wlen - s) + + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(loss_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(token_count, op=dist.ReduceOp.SUM) + dist.all_reduce(byte_count, op=dist.ReduceOp.SUM) + + val_loss = loss_sum / token_count + bits_per_token = val_loss.item() / math.log(2.0) + tokens_per_byte = token_count.item() / byte_count.item() + base_model.train() + return float(val_loss.item()), float(bits_per_token * tokens_per_byte) + + +def eval_val_sliding_ttt( + args: Hyperparameters, base_model: nn.Module, rank: int, world_size: int, + device: torch.device, val_tokens: Tensor, base_bytes_lut: Tensor, + has_leading_space_lut: Tensor, is_boundary_token_lut: Tensor, + stride: int, batch_seqs: int = 32, log0=print, +) -> tuple[float, float]: + """Legal score-first TTT: score each chunk with sliding windows, then train on it.""" + seq_len = args.train_seq_len + total_tokens = val_tokens.numel() - 1 + ttt_chunk = args.ttt_chunk_tokens + + window_starts = [ws for ws in range(0, total_tokens, stride) + if min(ws + seq_len, total_tokens) - ws >= stride or ws == 0] + + num_chunks = (total_tokens + ttt_chunk - 1) // ttt_chunk + chunk_windows: list[list[int]] = [[] for _ in range(num_chunks)] + for ws in window_starts: + end = min(ws + seq_len, total_tokens) + wlen = end - ws + s = 0 if ws == 0 else max(wlen - stride, 0) + scored_start = ws + s + ci = min(scored_start // ttt_chunk, num_chunks - 1) + chunk_windows[ci].append(ws) + + log0(f"ttt_sliding:start chunks={num_chunks} chunk_tokens={ttt_chunk} " + 
f"total_windows={len(window_starts)} stride={stride} " + f"ttt_lr={args.ttt_lr} ttt_epochs={args.ttt_epochs} " + f"ttt_optimizer={args.ttt_optimizer} freeze_blocks={args.ttt_freeze_blocks}") + + loss_sum = torch.zeros((), device=device, dtype=torch.float64) + token_count = torch.zeros((), device=device, dtype=torch.float64) + byte_count = torch.zeros((), device=device, dtype=torch.float64) + + # Freeze first N blocks (GEPA has no depth recurrence, freeze by block index) + n_blocks = len(base_model.blocks) + frozen_block_ids = set(range(min(args.ttt_freeze_blocks, n_blocks))) + + ttt_params = [] + for name, p in base_model.named_parameters(): + freeze = False + for bi in frozen_block_ids: + if f"blocks.{bi}." in name: + freeze = True + break + if freeze: + p.requires_grad_(False) + else: + p.requires_grad_(True) + ttt_params.append(p) + + log0(f"ttt_sliding:params unfrozen={sum(p.numel() for p in ttt_params)} " + f"frozen={sum(p.numel() for p in base_model.parameters() if not p.requires_grad)}") + + if args.ttt_optimizer == "adamw": + optimizer = torch.optim.AdamW(ttt_params, lr=args.ttt_lr, weight_decay=args.ttt_adam_wd) + else: + optimizer = torch.optim.SGD(ttt_params, lr=args.ttt_lr, momentum=args.ttt_momentum) + + t0 = time.perf_counter() + + for ci in range(num_chunks): + windows = chunk_windows[ci] + if not windows: + continue + + # --- Phase 1: SCORE this chunk's windows --- + my_s = (len(windows) * rank) // world_size + my_e = (len(windows) * (rank + 1)) // world_size + my_windows = windows[my_s:my_e] + + base_model.eval() + with torch.inference_mode(): + for bi in range(0, len(my_windows), batch_seqs): + batch_ws = my_windows[bi:bi + batch_seqs] + bsz = len(batch_ws) + x_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) + y_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) + wlens: list[int] = [] + for i, ws in enumerate(batch_ws): + end = min(ws + seq_len, total_tokens) + wlen = end - ws + wlens.append(wlen) + chunk_tok 
= val_tokens[ws:end + 1].to(dtype=torch.int64, device=device) + x_batch[i, :wlen] = chunk_tok[:-1] + y_batch[i, :wlen] = chunk_tok[1:] + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + logits = base_model.forward_logits(x_batch) + nll = F.cross_entropy( + logits.reshape(-1, logits.size(-1)).float(), + y_batch.reshape(-1), reduction="none", + ).reshape(bsz, seq_len) + for i, ws in enumerate(batch_ws): + wlen = wlens[i] + s = 0 if ws == 0 else max(wlen - stride, 0) + scored_nll = nll[i, s:wlen].to(torch.float64) + loss_sum += scored_nll.sum() + token_count += float(wlen - s) + tgt, prev = y_batch[i, s:wlen], x_batch[i, s:wlen] + tb = base_bytes_lut[tgt].to(torch.float64) + tb += (has_leading_space_lut[tgt] & ~is_boundary_token_lut[prev]).to(torch.float64) + byte_count += tb.sum() + + # --- Phase 2: TRAIN on this chunk's tokens (already scored = legal) --- + is_last_chunk = (ci == num_chunks - 1) + if not is_last_chunk and args.ttt_epochs > 0: + base_model.train() + chunk_start = ci * ttt_chunk + chunk_end = min((ci + 1) * ttt_chunk, total_tokens) + chunk_seqs = (chunk_end - chunk_start) // seq_len + if chunk_seqs > 0: + cos_lr = args.ttt_lr * 0.5 * (1.0 + math.cos(math.pi * ci / max(num_chunks - 1, 1))) + for pg in optimizer.param_groups: + pg['lr'] = cos_lr + + my_seq_s = (chunk_seqs * rank) // world_size + my_seq_e = (chunk_seqs * (rank + 1)) // world_size + my_chunk_seqs = my_seq_e - my_seq_s + + for _ep in range(args.ttt_epochs): + for bs in range(0, my_chunk_seqs, args.ttt_batch_seqs): + be = min(bs + args.ttt_batch_seqs, my_chunk_seqs) + actual_bs = my_seq_s + bs + actual_be = my_seq_s + be + start_tok = chunk_start + actual_bs * seq_len + end_tok = chunk_start + actual_be * seq_len + 1 + if end_tok > val_tokens.numel(): + continue + local = val_tokens[start_tok:end_tok].to(device=device, dtype=torch.int64) + x = local[:-1].reshape(-1, seq_len) + y = local[1:].reshape(-1, seq_len) + optimizer.zero_grad(set_to_none=True) + with 
torch.autocast(device_type="cuda", dtype=torch.bfloat16): + loss = base_model(x, y) + loss.backward() + if world_size > 1: + for p in ttt_params: + if p.grad is not None: + dist.all_reduce(p.grad, op=dist.ReduceOp.AVG) + if args.ttt_grad_clip > 0: + torch.nn.utils.clip_grad_norm_(ttt_params, args.ttt_grad_clip) + optimizer.step() + + if rank == 0 and (ci % 10 == 0 or ci == num_chunks - 1): + elapsed = time.perf_counter() - t0 + rl = loss_sum.item() / max(token_count.item(), 1) + rbpb = rl / math.log(2.0) * (token_count.item() / max(byte_count.item(), 1)) if token_count.item() > 0 else 0.0 + log0(f" ttt_chunk [{ci+1}/{num_chunks}] bpb={rbpb:.6f} time={elapsed:.1f}s") + + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(loss_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(token_count, op=dist.ReduceOp.SUM) + dist.all_reduce(byte_count, op=dist.ReduceOp.SUM) + + val_loss = (loss_sum / token_count).item() + val_bpb = val_loss / math.log(2.0) * (token_count.item() / byte_count.item()) + + for p in base_model.parameters(): + p.requires_grad_(True) + base_model.eval() + + log0(f"ttt_sliding:done val_loss={val_loss:.6f} val_bpb={val_bpb:.6f} " + f"elapsed={time.perf_counter() - t0:.1f}s") + return val_loss, val_bpb + + +# ----------------------------- +# POST-TRAINING INT6 QUANTIZATION +# ----------------------------- + +INT6_MIN = -32 +INT6_MAX = 31 +INT6_CLIP_PERCENTILE = 99.99984 +INT6_CLIP_Q = INT6_CLIP_PERCENTILE / 100.0 + +CONTROL_TENSOR_NAME_PATTERNS = tuple( + pattern + for pattern in os.environ.get( + "CONTROL_TENSOR_NAME_PATTERNS", + "attn_scale,attn_scales,mlp_scale,mlp_scales,resid_mix,resid_mixes,q_gain,skip_weight,skip_weights,smear_gate,bigram,skip_gates,ve_shared,ve_layer_scales", + ).split(",") + if pattern +) +INT6_KEEP_FLOAT_FP32_NAME_PATTERNS = tuple( + pattern + for pattern in os.environ.get( + "INT6_KEEP_FLOAT_FP32_NAME_PATTERNS", + ",".join(CONTROL_TENSOR_NAME_PATTERNS), + ).split(",") + if pattern +) +INT6_KEEP_FLOAT_MAX_NUMEL = 
65_536 + INT6_KEEP_FLOAT_STORE_DTYPE = torch.float16 + INT6_PER_ROW_SCALE_DTYPE = torch.float16 + + def tensor_nbytes(t: Tensor) -> int: + return int(t.numel()) * int(t.element_size()) + + def keep_float_tensor(name: str, t: Tensor, passthrough_orig_dtypes: dict[str, str]) -> Tensor: + if any(pattern in name for pattern in INT6_KEEP_FLOAT_FP32_NAME_PATTERNS): + return t.float().contiguous() + if t.dtype in {torch.float32, torch.bfloat16}: + passthrough_orig_dtypes[name] = str(t.dtype).removeprefix("torch.") + return t.to(dtype=INT6_KEEP_FLOAT_STORE_DTYPE).contiguous() + return t + + # Storage format: INT8_STORAGE=1 uses [-127,127] for better accuracy (larger artifact) + # Default (int6): uses [-32,31] for smaller artifacts but more quantization noise + # MIXED_QUANT=1: int6 per-row for MLP+attn (QAT-trained), int8 per-tensor for rest + INT8_STORAGE = int(os.environ.get("INT8_STORAGE", "0")) + MIXED_QUANT = int(os.environ.get("MIXED_QUANT", "0")) + QUANT_MIN = -127 if INT8_STORAGE else INT6_MIN + QUANT_MAX = 127 if INT8_STORAGE else INT6_MAX + + def _classify_param(name: str) -> str: + """Classify parameter by category for mixed quantization.""" + if "tok_emb" in name or "lm_head" in name: + return "embed" + if ".mlp." in name: + return "mlp" + if ".attn." in name or (".proj." in name and ".mlp." not in name): + return "attn" + return "other" + + def quantize_float_tensor_int8_scalar(t: Tensor) -> tuple[Tensor, Tensor]: + """Quantize to int8 [-127,127] range with per-tensor scalar scale.""" + t32 = t.float() + amax = t32.abs().max().item() + scale = torch.tensor(amax / 127.0 if amax > 0 else 1.0, dtype=torch.float16) + q = torch.clamp(torch.round(t32 / scale.float()), -127, 127).to(torch.int8) + return q.contiguous(), scale + + def quantize_float_tensor_int6(t: Tensor) -> tuple[Tensor, Tensor]: + """Quantize to int6 [-32,31] or int8 [-127,127] range, stored as int8. + Uses GPTQ-lite: searches 5 candidate clip percentiles, each applied as a + per-row clip, keeping the one with lowest whole-tensor reconstruction MSE. 
Free improvement over fixed percentile.""" + t32 = t.float() + qmin, qmax = QUANT_MIN, QUANT_MAX + if t32.ndim == 2 and t32.numel() > 0: + # GPTQ-lite: try 5 clip percentiles per-row, pick minimum MSE + best_q, best_s, best_err = None, None, float('inf') + for pct in [0.9990, 0.9995, 0.9999, 0.99999, 1.0]: + if pct < 1.0: + row_clip = torch.quantile(t32.abs(), pct, dim=1) + else: + row_clip = t32.abs().amax(dim=1) + scale = (row_clip / float(qmax)).clamp_min(1.0 / float(qmax)) + q = torch.clamp(torch.round(t32 / scale[:, None]), qmin, qmax).to(torch.int8) + recon = q.float() * scale[:, None] + err = (t32 - recon).pow(2).mean().item() + if err < best_err: + best_q = q.contiguous() + best_s = scale.to(dtype=INT6_PER_ROW_SCALE_DTYPE).contiguous() + best_err = err + return best_q, best_s + + clip_abs = float(torch.quantile(t32.abs().flatten(), INT6_CLIP_Q).item()) if t32.numel() else 0.0 + scale = torch.tensor(clip_abs / float(qmax) if clip_abs > 0 else 1.0, dtype=torch.float32) + q = torch.clamp(torch.round(torch.clamp(t32, -clip_abs, clip_abs) / scale), qmin, qmax).to(torch.int8).contiguous() + return q, scale + +def quantize_state_dict_int6(state_dict: dict[str, Tensor]): + quantized: dict[str, Tensor] = {} + scales: dict[str, Tensor] = {} + dtypes: dict[str, str] = {} + passthrough: dict[str, Tensor] = {} + passthrough_orig_dtypes: dict[str, str] = {} + qmeta: dict[str, dict[str, object]] = {} + stats = dict.fromkeys( + ("param_count", "num_tensors", "num_float_tensors", "num_nonfloat_tensors", "baseline_tensor_bytes", "int6_payload_bytes"), + 0, + ) + + for name, tensor in state_dict.items(): + t = tensor.detach().to("cpu").contiguous() + stats["param_count"] += int(t.numel()) + stats["num_tensors"] += 1 + stats["baseline_tensor_bytes"] += tensor_nbytes(t) + + if not t.is_floating_point(): + stats["num_nonfloat_tensors"] += 1 + passthrough[name] = t + stats["int6_payload_bytes"] += tensor_nbytes(t) + continue + + # Keep small float tensors (and tok_emb.weight 
unless QUANT_EMBED=1) in fp16 + quant_embed = int(os.environ.get("QUANT_EMBED", "0")) + if t.numel() <= INT6_KEEP_FLOAT_MAX_NUMEL or (name == "tok_emb.weight" and not quant_embed): + kept = keep_float_tensor(name, t, passthrough_orig_dtypes) + passthrough[name] = kept + stats["int6_payload_bytes"] += tensor_nbytes(kept) + continue + + stats["num_float_tensors"] += 1 + # Mixed quantization: int6 per-row for MLP+attn, int8 per-tensor for rest + cat = _classify_param(name) + if MIXED_QUANT and cat not in ("mlp", "attn"): + q, s = quantize_float_tensor_int8_scalar(t) + qmeta[name] = {"scheme": "per_tensor", "quant_type": "int8"} + else: + q, s = quantize_float_tensor_int6(t) + if s.ndim > 0: + qmeta[name] = {"scheme": "per_row", "axis": 0} + quantized[name] = q + scales[name] = s + dtypes[name] = str(t.dtype).removeprefix("torch.") + stats["int6_payload_bytes"] += tensor_nbytes(q) + tensor_nbytes(s) + + quant_label = "mixed" if MIXED_QUANT else ("int8" if INT8_STORAGE else "int6") + obj: dict[str, object] = { + "__quant_format__": f"{quant_label}_clean_per_row_v1", + "quantized": quantized, + "scales": scales, + "dtypes": dtypes, + "passthrough": passthrough, + } + if qmeta: + obj["qmeta"] = qmeta + if passthrough_orig_dtypes: + obj["passthrough_orig_dtypes"] = passthrough_orig_dtypes + return obj, stats + +def dequantize_state_dict_int6(obj: dict[str, object]) -> dict[str, Tensor]: + out: dict[str, Tensor] = {} + qmeta = obj.get("qmeta", {}) + passthrough_orig_dtypes = obj.get("passthrough_orig_dtypes", {}) + for name, q in obj["quantized"].items(): + dtype = getattr(torch, obj["dtypes"][name]) + s = obj["scales"][name] + if qmeta.get(name, {}).get("scheme") == "per_row" or s.ndim > 0: + s = s.to(dtype=torch.float32) + out[name] = (q.float() * s.view(q.shape[0], *([1] * (q.ndim - 1)))).to(dtype=dtype).contiguous() + else: + scale = float(s.item()) + out[name] = (q.float() * scale).to(dtype=dtype).contiguous() + for name, t in obj["passthrough"].items(): + out_t = 
t.detach().to("cpu").contiguous() + orig_dtype = passthrough_orig_dtypes.get(name) + if isinstance(orig_dtype, str): + out_t = out_t.to(dtype=getattr(torch, orig_dtype)).contiguous() + out[name] = out_t + return out + + + # ----------------------------- + # DATA LOADING + # ----------------------------- + + def load_data_shard(file: Path) -> Tensor: + header_bytes = 256 * np.dtype("<i4").itemsize + # NOTE: reader body reconstructed; a 256-entry little-endian int32 header with + # the token count at header[2] and a "<u2" token payload are assumptions. + with file.open("rb") as f: + header = np.frombuffer(f.read(header_bytes), dtype="<i4") + num_tokens = int(header[2]) + tokens = np.frombuffer(f.read(num_tokens * np.dtype("<u2").itemsize), dtype="<u2") + return torch.from_numpy(tokens.astype(np.int32)) + + + class TokenStream: + def __init__(self, pattern: str): + # NOTE: reconstructed; shard list assumed to come from globbing the pattern. + p = Path(pattern) + self.files = sorted(p.parent.glob(p.name)) + if not self.files: + raise FileNotFoundError(f"no token shards match pattern: {pattern}") + self.file_idx = 0 + self.tokens = load_data_shard(self.files[self.file_idx]) + self.pos = 0 + + def _advance_file(self) -> None: + self.file_idx = (self.file_idx + 1) % len(self.files) + self.tokens = load_data_shard(self.files[self.file_idx]) + self.pos = 0 + + def take(self, n: int) -> Tensor: + chunks: list[Tensor] = [] + remaining = n + while remaining > 0: + avail = self.tokens.numel() - self.pos + if avail <= 0: + self._advance_file() + continue + k = min(remaining, avail) + chunks.append(self.tokens[self.pos : self.pos + k]) + self.pos += k + remaining -= k + return chunks[0] if len(chunks) == 1 else torch.cat(chunks) + + + class DistributedTokenLoader: + def __init__(self, pattern: str, rank: int, world_size: int, device: torch.device): + self.rank = rank + self.world_size = world_size + self.device = device + self.stream = TokenStream(pattern) + + def next_batch(self, global_tokens: int, seq_len: int, grad_accum_steps: int) -> tuple[Tensor, Tensor]: + local_tokens = global_tokens // (self.world_size * grad_accum_steps) + per_rank_span = local_tokens + 1 + chunk = self.stream.take(per_rank_span * self.world_size) + start = self.rank * per_rank_span + local = chunk[start : start + per_rank_span].to(dtype=torch.int64) + x = local[:-1].reshape(-1, seq_len) + y = local[1:].reshape(-1, seq_len) + return x.to(self.device, non_blocking=True), y.to(self.device, non_blocking=True) + + + # ----------------------------- + # TRANSFORMER MODULES + # ----------------------------- + + class RMSNorm(nn.Module): + def __init__(self, eps: float | None = None): + super().__init__() + self.eps = eps + + def forward(self, x: Tensor) -> Tensor: + return F.rms_norm(x, (x.size(-1),), eps=self.eps) + + + class 
CastedLinear(nn.Linear): + # Class-level flag: set True during late-QAT phase to enable fake int6 STE + _qat_enabled: bool = False + + def forward(self, x: Tensor) -> Tensor: + w = self.weight.to(x.dtype) + if CastedLinear._qat_enabled and self.training and w.ndim == 2: + # Fake int6 quantization via straight-through estimator + with torch.no_grad(): + w32 = self.weight.float() + row_max = w32.abs().amax(dim=1) + scale = (row_max / 31.0).clamp_min(1.0 / 31.0) + w_q = (torch.clamp(torch.round(w32 / scale[:, None]), -32, 31) * scale[:, None]).to(x.dtype) + w = w + (w_q - w).detach() + bias = self.bias.to(x.dtype) if self.bias is not None else None + return F.linear(x, w, bias) + + +def restore_low_dim_params_to_fp32(module: nn.Module) -> None: + with torch.no_grad(): + for name, param in module.named_parameters(): + if (param.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS)) and param.dtype != torch.float32: + param.data = param.data.float() + + +class Rotary(nn.Module): + """RoPE with optional partial application and YARN scaling.""" + def __init__(self, dim: int, base: float = 10000.0, rope_dims: int = 0, train_seq_len: int = 1024): + super().__init__() + # rope_dims=0 means full head_dim; otherwise rotate only first rope_dims dims + rope_d = rope_dims if rope_dims > 0 else dim + self.rope_d = rope_d + self.base = base + self.train_seq_len = train_seq_len + inv_freq = 1.0 / (base ** (torch.arange(0, rope_d, 2, dtype=torch.float32) / rope_d)) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self._seq_len_cached = 0 + self._cos_cached: Tensor | None = None + self._sin_cached: Tensor | None = None + + def forward(self, seq_len: int, device: torch.device, dtype: torch.dtype) -> tuple[Tensor, Tensor]: + if ( + self._cos_cached is None + or self._sin_cached is None + or self._seq_len_cached != seq_len + or self._cos_cached.device != device + ): + rd = self.rope_d + if seq_len > self.train_seq_len: + scale = seq_len / 
self.train_seq_len + new_base = self.base * (scale ** (rd / (rd - 2))) + inv_freq = 1.0 / (new_base ** (torch.arange(0, rd, 2, dtype=torch.float32, device=device) / rd)) + else: + inv_freq = self.inv_freq.to(device) + t = torch.arange(seq_len, device=device, dtype=inv_freq.dtype) + freqs = torch.outer(t, inv_freq) + self._cos_cached = freqs.cos()[None, None, :, :] + self._sin_cached = freqs.sin()[None, None, :, :] + self._seq_len_cached = seq_len + return self._cos_cached.to(dtype=dtype), self._sin_cached.to(dtype=dtype) + + +def apply_rotary_emb(x: Tensor, cos: Tensor, sin: Tensor) -> Tensor: + """Apply RoPE; if cos covers fewer dims than x, rotate only those dims.""" + rd = cos.size(-1) * 2 + if rd < x.size(-1): + x_rope = x[..., :rd] + x_pass = x[..., rd:] + half = rd // 2 + x1 = x_rope[..., :half] + x2 = x_rope[..., half:] + x_rot = torch.cat((x1 * cos + x2 * sin, x1 * (-sin) + x2 * cos), dim=-1) + return torch.cat((x_rot, x_pass), dim=-1) + half = x.size(-1) // 2 + x1, x2 = x[..., :half], x[..., half:] + return torch.cat((x1 * cos + x2 * sin, x1 * (-sin) + x2 * cos), dim=-1) + + +class CausalSelfAttention(nn.Module): + def __init__(self, dim: int, num_heads: int, num_kv_heads: int, rope_base: float, qk_gain_init: float, rope_dims: int = 0, rope_train_seq_len: int = 1024): + super().__init__() + if dim % num_heads != 0: + raise ValueError("model_dim must be divisible by num_heads") + if num_heads % num_kv_heads != 0: + raise ValueError("num_heads must be divisible by num_kv_heads") + self.num_heads = num_heads + self.num_kv_heads = num_kv_heads + self.head_dim = dim // num_heads + if self.head_dim % 2 != 0: + raise ValueError("head_dim must be even for RoPE") + kv_dim = self.num_kv_heads * self.head_dim + self.c_q = CastedLinear(dim, dim, bias=False) + self.c_k = CastedLinear(dim, kv_dim, bias=False) + self.c_v = CastedLinear(dim, kv_dim, bias=False) + self.proj = CastedLinear(dim, dim, bias=False) + self.proj._zero_init = True + self.q_gain = 
nn.Parameter(torch.full((num_heads,), qk_gain_init, dtype=torch.float32)) + self.rotary = Rotary(self.head_dim, base=rope_base, rope_dims=rope_dims, train_seq_len=rope_train_seq_len) + self.use_xsa = False + + def _xsa_efficient(self, y: Tensor, v: Tensor) -> Tensor: + """Subtract self-value projection via GQA-aware reshape (no repeat_interleave).""" + B, T, H, D = y.shape + Hkv = v.size(-2) + group = H // Hkv + y_g = y.reshape(B, T, Hkv, group, D) + vn = F.normalize(v, dim=-1).unsqueeze(-2) + proj = (y_g * vn).sum(dim=-1, keepdim=True) * vn + return (y_g - proj).reshape(B, T, H, D) + + def forward(self, x: Tensor, v_embed: Tensor | None = None) -> Tensor: + bsz, seqlen, dim = x.shape + q = self.c_q(x).reshape(bsz, seqlen, self.num_heads, self.head_dim).transpose(1, 2) + k = self.c_k(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) + v = self.c_v(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) + # Add value embeddings to v before attention if provided + if v_embed is not None: + ve_reshaped = v_embed.reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) + v = v + ve_reshaped + q = F.rms_norm(q, (q.size(-1),)) + k = F.rms_norm(k, (k.size(-1),)) + cos, sin = self.rotary(seqlen, x.device, q.dtype) + q = apply_rotary_emb(q, cos, sin) + k = apply_rotary_emb(k, cos, sin) + q = q * self.q_gain.to(dtype=q.dtype)[None, :, None, None] + y = F.scaled_dot_product_attention(q, k, v, attn_mask=None, is_causal=True, enable_gqa=(self.num_kv_heads != self.num_heads)) + if self.use_xsa: + y_xsa = y.transpose(1, 2) + v_xsa = v.transpose(1, 2) + y_xsa = self._xsa_efficient(y_xsa, v_xsa) + y = y_xsa.reshape(bsz, seqlen, dim) + else: + y = y.transpose(1, 2).contiguous().reshape(bsz, seqlen, dim) + return self.proj(y) + + +class MLP(nn.Module): + def __init__(self, dim: int, mlp_mult: int, mlp_hidden: int = 0): + super().__init__() + # Star-ReLU implementation. mlp_mult is unused. 
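# StarReLU ("MetaFormer Baselines for Vision", Yu et al. 2022) computes + # y = s * relu(x)**2 + b with a learned per-channel scale s (init 1.0) and + # bias b (init 0.0); the scale/bias parameters below implement this on top + # of the squared ReLU in forward().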
+ hidden = mlp_hidden if mlp_hidden > 0 else int(dim * 3) + self.up_proj = CastedLinear(dim, hidden, bias=False) + self.down_proj = CastedLinear(hidden, dim, bias=False) + self.down_proj._zero_init = True + self.scale = nn.Parameter(torch.ones(hidden, dtype=torch.float32)) + self.bias = nn.Parameter(torch.zeros(hidden, dtype=torch.float32)) + + def forward(self, x: Tensor) -> Tensor: + x_up = self.up_proj(x) + activated = F.relu(x_up).pow(2) + activated = activated * self.scale.to(dtype=activated.dtype) + self.bias.to(dtype=activated.dtype) + return self.down_proj(activated) + + +class Block(nn.Module): + def __init__( + self, + dim: int, + num_heads: int, + num_kv_heads: int, + mlp_mult: int, + rope_base: float, + qk_gain_init: float, + mlp_hidden: int = 0, + rope_dims: int = 0, + layer_idx: int = 0, + ln_scale: bool = False, + rope_train_seq_len: int = 1024, + ): + super().__init__() + self.attn_norm = RMSNorm() + self.mlp_norm = RMSNorm() + self.attn = CausalSelfAttention(dim, num_heads, num_kv_heads, rope_base, qk_gain_init, rope_dims=rope_dims, rope_train_seq_len=rope_train_seq_len) + self.mlp = MLP(dim, mlp_mult, mlp_hidden=mlp_hidden) + self.attn_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.mlp_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.resid_mix = nn.Parameter(torch.stack((torch.ones(dim), torch.zeros(dim))).float()) + # LN Scale: dampen norm inputs by 1/sqrt(layer_idx+1) for deeper layers + self.ln_scale_factor = 1.0 / math.sqrt(layer_idx + 1) if ln_scale else 1.0 + + def forward(self, x: Tensor, x0: Tensor, v_embed: Tensor | None = None) -> Tensor: + mix = self.resid_mix.to(dtype=x.dtype) + x = mix[0][None, None, :] * x + mix[1][None, None, :] * x0 + s = self.ln_scale_factor + attn_out = self.attn(self.attn_norm(x) * s, v_embed=v_embed) + x = x + self.attn_scale.to(dtype=x.dtype)[None, None, :] * attn_out + x = x + self.mlp_scale.to(dtype=x.dtype)[None, None, :] * self.mlp(self.mlp_norm(x) * s) + return x + + +# 
----------------------------- +# BIGRAM HASH EMBEDDING +# ----------------------------- + +class BigramHashEmbedding(nn.Module): + """Hash-based bigram embedding with optional XOR hash and learned scale.""" + def __init__(self, num_buckets: int, embed_dim: int, model_dim: int, use_xor_hash: bool = True): + super().__init__() + self.num_buckets = num_buckets + self.use_xor_hash = use_xor_hash + self.embed = nn.Embedding(num_buckets, embed_dim) + self.proj = CastedLinear(embed_dim, model_dim, bias=False) + if use_xor_hash: + nn.init.zeros_(self.embed.weight) # Zero init with learned scale + nn.init.zeros_(self.proj.weight) + self.scale = nn.Parameter(torch.tensor(0.05, dtype=torch.float32)) + else: + nn.init.normal_(self.embed.weight, std=0.01) + nn.init.zeros_(self.proj.weight) + self.scale = None + + def bigram_hash(self, tokens: Tensor) -> Tensor: + """XOR-based bigram hash with large primes for uniform distribution.""" + t = tokens.to(torch.int32) + mod = self.num_buckets - 1 + out = torch.empty_like(t) + out[..., 0] = mod # Special bucket for first position + out[..., 1:] = torch.bitwise_xor(36313 * t[..., 1:], 27191 * t[..., :-1]) % mod + return out.long() + + def forward(self, input_ids: Tensor) -> Tensor: + if self.use_xor_hash: + h = self.embed(self.bigram_hash(input_ids)) + h = self.proj(h) + return h * self.scale.to(dtype=h.dtype) + else: + bsz, seq_len = input_ids.shape + prev_ids = F.pad(input_ids[:, :-1], (1, 0), value=0) + bigram_hash = (prev_ids * 1009 + input_ids) % self.num_buckets + bigram_emb = self.embed(bigram_hash) + return self.proj(bigram_emb) + + +# ----------------------------- +# SMEAR GATE +# ----------------------------- + +class SmearGate(nn.Module): + """Learned blending of current position with previous position.""" + def __init__(self, dim: int): + super().__init__() + self.gate = nn.Parameter(torch.zeros(dim, dtype=torch.float32)) + + def forward(self, x: Tensor) -> Tensor: + # x: (bsz, seq_len, dim) + gate = 
torch.sigmoid(self.gate.to(dtype=x.dtype)) + # Shift x to get previous position, pad with zeros + x_prev = F.pad(x[:, :-1], (0, 0, 1, 0)) + return (1 - gate) * x + gate * x_prev + + +class ValueEmbedding(nn.Module): + """Reinject token identity into attention values at specific layers.""" + def __init__(self, vocab_size: int, ve_dim: int, kv_dim: int): + super().__init__() + self.embed = nn.Embedding(vocab_size, ve_dim) + nn.init.normal_(self.embed.weight, std=0.01) + self.proj = CastedLinear(ve_dim, kv_dim, bias=False) if ve_dim != kv_dim else None + if self.proj is not None: + nn.init.zeros_(self.proj.weight) + self.scale = nn.Parameter(torch.tensor(0.1, dtype=torch.float32)) + + def forward(self, token_ids: Tensor) -> Tensor: + h = self.embed(token_ids) + if self.proj is not None: + h = self.proj(h) + return h * self.scale.to(dtype=h.dtype) + + +class GPT(nn.Module): + def __init__( + self, + vocab_size: int, + num_layers: int, + model_dim: int, + num_heads: int, + num_kv_heads: int, + mlp_mult: int, + tie_embeddings: bool, + tied_embed_init_std: float, + logit_softcap: float, + rope_base: float, + qk_gain_init: float, + mlp_hidden: int = 0, + bigram_buckets: int = 4096, + bigram_embed_dim: int = 128, + bigram_xor_hash: bool = True, + rope_dims: int = 0, + ln_scale: bool = False, + xsa_last_n: int = 0, + ve_enabled: bool = False, + ve_dim: int = 128, + ve_layers: str = "9,10", + rope_train_seq_len: int = 1024, + ortho_init: bool = True, + ): + super().__init__() + if logit_softcap <= 0.0: + raise ValueError(f"logit_softcap must be positive, got {logit_softcap}") + self.tie_embeddings = tie_embeddings + self.tied_embed_init_std = tied_embed_init_std + self.logit_softcap = logit_softcap + self.ortho_init = ortho_init + self.tok_emb = nn.Embedding(vocab_size, model_dim) + self.bigram_emb = BigramHashEmbedding(bigram_buckets, bigram_embed_dim, model_dim, use_xor_hash=bigram_xor_hash) if bigram_buckets > 0 else None + self.smear_gate = SmearGate(model_dim) + 
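# U-Net pairing: the first num_layers//2 blocks act as encoder layers whose + # outputs are stacked; each decoder block in the second half blends its input + # with the matching saved activation (popped in reverse order) through a + # learned per-channel sigmoid gate. Skip weights/gates are defined next.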
self.num_encoder_layers = num_layers // 2 + self.num_decoder_layers = num_layers - self.num_encoder_layers + self.num_skip_weights = min(self.num_encoder_layers, self.num_decoder_layers) + self.skip_weights = nn.Parameter(torch.ones(self.num_skip_weights, model_dim, dtype=torch.float32)) + self.skip_gates = nn.Parameter(torch.zeros(self.num_skip_weights, model_dim, dtype=torch.float32)) + self.blocks = nn.ModuleList([ + Block( + model_dim, num_heads, num_kv_heads, mlp_mult, rope_base, qk_gain_init, + mlp_hidden=mlp_hidden, rope_dims=rope_dims, layer_idx=i, ln_scale=ln_scale, + rope_train_seq_len=rope_train_seq_len, + ) + for i in range(num_layers) + ]) + if xsa_last_n > 0: + for i in range(max(0, num_layers - xsa_last_n), num_layers): + self.blocks[i].attn.use_xsa = True + # Value Embeddings: reinject token identity into values at deep layers + kv_dim = model_dim // num_heads * num_kv_heads + self.ve_layer_indices = [int(x) for x in ve_layers.split(",") if x.strip()] if ve_enabled else [] + if self.ve_layer_indices: + self.ve_shared = ValueEmbedding(vocab_size, ve_dim, kv_dim) + self.ve_layer_scales = nn.ParameterList( + [nn.Parameter(torch.ones(1, dtype=torch.float32)) for _ in self.ve_layer_indices] + ) + else: + self.ve_shared = None + self.final_norm = RMSNorm() + self.lm_head = None if tie_embeddings else CastedLinear(model_dim, vocab_size, bias=False) + if self.lm_head is not None: + self.lm_head._zero_init = True + self._init_weights() + + def _init_weights(self) -> None: + if self.tie_embeddings: + nn.init.normal_(self.tok_emb.weight, mean=0.0, std=self.tied_embed_init_std) + num_layers = len(self.blocks) + for name, module in self.named_modules(): + if isinstance(module, nn.Linear): + if getattr(module, "_zero_init", False): + nn.init.zeros_(module.weight) + elif self.ortho_init and module.weight.ndim == 2 and module.weight.shape[0] >= 64 and module.weight.shape[1] >= 64: + nn.init.orthogonal_(module.weight, gain=1.0) + if ".proj." 
in name or name.endswith(".proj"): + with torch.no_grad(): + module.weight.mul_(1.0 / math.sqrt(2 * num_layers)) + + def _get_ve(self, layer_idx: int, input_ids: Tensor, ve_cache: dict) -> Tensor | None: + if self.ve_shared is None or layer_idx not in self.ve_layer_indices: + return None + if 've' not in ve_cache: + ve_cache['ve'] = self.ve_shared(input_ids) + ve_base = ve_cache['ve'] + ve_idx = self.ve_layer_indices.index(layer_idx) + return ve_base * self.ve_layer_scales[ve_idx].to(dtype=ve_base.dtype) + + def forward(self, input_ids: Tensor, target_ids: Tensor) -> Tensor: + x = self.tok_emb(input_ids) + if self.bigram_emb is not None: + x = x + self.bigram_emb(input_ids) + x = F.rms_norm(x, (x.size(-1),)) + x = self.smear_gate(x) + x0 = x + skips: list[Tensor] = [] + ve_cache: dict = {} + + for i in range(self.num_encoder_layers): + x = self.blocks[i](x, x0, v_embed=self._get_ve(i, input_ids, ve_cache)) + skips.append(x) + for i in range(self.num_decoder_layers): + bi = self.num_encoder_layers + i + if skips: + skip = skips.pop() + gate = torch.sigmoid(self.skip_gates[i].to(dtype=x.dtype)) + scaled_skip = self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skip + x = gate[None, None, :] * x + (1.0 - gate[None, None, :]) * scaled_skip + x = self.blocks[bi](x, x0, v_embed=self._get_ve(bi, input_ids, ve_cache)) + + x = self.final_norm(x).reshape(-1, x.size(-1)) + targets = target_ids.reshape(-1) + if self.tie_embeddings: + logits_proj = F.linear(x, self.tok_emb.weight) + else: + if self.lm_head is None: + raise RuntimeError("lm_head is required when tie_embeddings=False") + logits_proj = self.lm_head(x) + logits = self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap) + return F.cross_entropy(logits.float(), targets, reduction="mean") + + def forward_logits(self, input_ids: Tensor) -> Tensor: + """Return logits (bsz, seq_len, vocab) without computing loss.""" + x = self.tok_emb(input_ids) + if self.bigram_emb is not None: + x = x + 
self.bigram_emb(input_ids) + x = F.rms_norm(x, (x.size(-1),)) + x = self.smear_gate(x) + x0 = x + skips: list[Tensor] = [] + ve_cache: dict = {} + for i in range(self.num_encoder_layers): + x = self.blocks[i](x, x0, v_embed=self._get_ve(i, input_ids, ve_cache)) + skips.append(x) + for i in range(self.num_decoder_layers): + bi = self.num_encoder_layers + i + if skips: + skip = skips.pop() + gate = torch.sigmoid(self.skip_gates[i].to(dtype=x.dtype)) + scaled_skip = self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skip + x = gate[None, None, :] * x + (1.0 - gate[None, None, :]) * scaled_skip + x = self.blocks[bi](x, x0, v_embed=self._get_ve(bi, input_ids, ve_cache)) + x = self.final_norm(x) + if self.tie_embeddings: + logits_proj = F.linear(x, self.tok_emb.weight) + else: + logits_proj = self.lm_head(x) + return self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap) + + + +# ----------------------------- +# TRAINING +# ----------------------------- + +def main() -> None: + global zeropower_via_newtonschulz5 + + code = Path(__file__).read_text(encoding="utf-8") + args = Hyperparameters() + zeropower_via_newtonschulz5 = torch.compile(zeropower_via_newtonschulz5) + + # ----------------------------- + # DISTRIBUTED + CUDA SETUP + # ----------------------------- + + distributed = "RANK" in os.environ and "WORLD_SIZE" in os.environ + rank = int(os.environ.get("RANK", "0")) + world_size = int(os.environ.get("WORLD_SIZE", "1")) + local_rank = int(os.environ.get("LOCAL_RANK", "0")) + if world_size <= 0: + raise ValueError(f"WORLD_SIZE must be positive, got {world_size}") + if 8 % world_size != 0: + raise ValueError(f"WORLD_SIZE={world_size} must divide 8 so grad_accum_steps stays integral") + grad_accum_steps = 8 // world_size + grad_scale = 1.0 / grad_accum_steps + if not torch.cuda.is_available(): + raise RuntimeError("CUDA is required") + device = torch.device("cuda", local_rank) + torch.cuda.set_device(device) + if distributed: + 
dist.init_process_group(backend="nccl", device_id=device) + dist.barrier() + master_process = rank == 0 + + torch.backends.cuda.matmul.allow_tf32 = True + torch.backends.cudnn.allow_tf32 = True + from torch.backends.cuda import enable_cudnn_sdp, enable_flash_sdp, enable_math_sdp, enable_mem_efficient_sdp + enable_cudnn_sdp(False) + enable_flash_sdp(True) + enable_mem_efficient_sdp(False) + enable_math_sdp(False) + + logfile = None + if master_process: + os.makedirs("logs", exist_ok=True) + logfile = f"logs/{args.run_id}.txt" + print(logfile) + + def log0(msg: str, console: bool = True) -> None: + if not master_process: + return + if console: + print(msg) + if logfile is not None: + with open(logfile, "a", encoding="utf-8") as f: + print(msg, file=f) + + log0(code, console=False) + log0("=" * 100, console=False) + log0(f"Running Python {sys.version}", console=False) + log0(f"Running PyTorch {torch.__version__}", console=False) + log0(subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, check=False).stdout, console=False) + log0("=" * 100, console=False) + + # ----------------------------- + # TOKENIZER + VALIDATION METRIC SETUP + # ----------------------------- + + random.seed(args.seed) + np.random.seed(args.seed) + torch.manual_seed(args.seed) + torch.cuda.manual_seed_all(args.seed) + + if not args.tokenizer_path.endswith(".model"): + raise ValueError(f"Script only setup for SentencePiece .model file: {args.tokenizer_path}") + sp = spm.SentencePieceProcessor(model_file=args.tokenizer_path) + if int(sp.vocab_size()) != args.vocab_size: + raise ValueError(f"VOCAB_SIZE={args.vocab_size} does not match tokenizer vocab_size={int(sp.vocab_size())}") + dataset_dir = Path(args.data_path).resolve() + actual_train_files = len(list(dataset_dir.glob("fineweb_train_*.bin"))) + val_tokens = load_validation_tokens(args.val_files, args.train_seq_len) + base_bytes_lut, has_leading_space_lut, is_boundary_token_lut = build_sentencepiece_luts(sp, 
args.vocab_size, device) + log0(f"val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path={args.tokenizer_path}") + log0(f"train_loader:dataset:{dataset_dir.name} train_shards:{actual_train_files}") + log0(f"val_loader:shards pattern={args.val_files} tokens:{val_tokens.numel() - 1}") + + # ----------------------------- + # MODEL + OPTIMIZER SETUP + # ----------------------------- + + CastedLinear._qat_enabled = False # start with QAT off; late_qat enables it mid-run + + base_model = GPT( + vocab_size=args.vocab_size, + num_layers=args.num_layers, + model_dim=args.model_dim, + num_heads=args.num_heads, + num_kv_heads=args.num_kv_heads, + mlp_mult=args.mlp_mult, + tie_embeddings=args.tie_embeddings, + tied_embed_init_std=args.tied_embed_init_std, + logit_softcap=args.logit_softcap, + rope_base=args.rope_base, + qk_gain_init=args.qk_gain_init, + mlp_hidden=args.mlp_hidden, + bigram_buckets=args.bigram_buckets, + bigram_embed_dim=args.bigram_embed_dim, + bigram_xor_hash=args.bigram_xor_hash, + rope_dims=args.rope_dims, + ln_scale=args.ln_scale, + xsa_last_n=args.xsa_layers, + ve_enabled=args.ve_enabled, + ve_dim=args.ve_dim, + ve_layers=args.ve_layers, + rope_train_seq_len=args.rope_train_seq_len, + ortho_init=args.ortho_init, + ).to(device).bfloat16() + for module in base_model.modules(): + if isinstance(module, CastedLinear): + module.float() + restore_low_dim_params_to_fp32(base_model) + compiled_model = torch.compile(base_model, dynamic=False, fullgraph=True) + model: nn.Module = DDP(compiled_model, device_ids=[local_rank], broadcast_buffers=False) if distributed else compiled_model + + # Differential LR setup + matrix_params_enc, scalar_params_enc = [], [] + matrix_params_dec, scalar_params_dec = [], [] + num_encoder_layers = base_model.num_encoder_layers + for i, block in enumerate(base_model.blocks): + is_decoder = i >= num_encoder_layers + for name, p in block.named_parameters(): + if p.ndim == 2 and not any(pattern in name for pattern in 
CONTROL_TENSOR_NAME_PATTERNS): + (matrix_params_dec if is_decoder else matrix_params_enc).append(p) + else: + (scalar_params_dec if is_decoder else scalar_params_enc).append(p) + + # Non-block scalar parameters + other_scalar_params = [base_model.smear_gate.gate] + if base_model.bigram_emb is not None: + other_scalar_params.append(base_model.bigram_emb.embed.weight) + if base_model.skip_weights.numel() > 0: + other_scalar_params.append(base_model.skip_weights) + if hasattr(base_model, 'skip_gates') and base_model.skip_gates.numel() > 0: + other_scalar_params.append(base_model.skip_gates) + # Value Embedding parameters + if base_model.ve_shared is not None: + other_scalar_params.extend(list(base_model.ve_shared.parameters())) + other_scalar_params.extend(list(base_model.ve_layer_scales.parameters())) + + token_lr = args.tied_embed_lr if args.tie_embeddings else args.embed_lr + optimizer_tok = torch.optim.AdamW( + [{"params": [base_model.tok_emb.weight], "lr": token_lr, "base_lr": token_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + weight_decay=args.adam_wd, + fused=True, + ) + + matrix_lr_dec = args.matrix_lr * args.decoder_lr_mult + optimizer_muon = Muon( + [ + {'params': matrix_params_enc, 'lr': args.matrix_lr, 'base_lr': args.matrix_lr}, + {'params': matrix_params_dec, 'lr': matrix_lr_dec, 'base_lr': matrix_lr_dec}, + ], + lr=args.matrix_lr, + momentum=args.muon_momentum, + backend_steps=args.muon_backend_steps, + weight_decay=args.muon_wd, + ) + + scalar_lr_dec = args.scalar_lr * args.decoder_lr_mult + optimizer_scalar = torch.optim.AdamW( + [ + {'params': scalar_params_enc, 'lr': args.scalar_lr, 'base_lr': args.scalar_lr}, + {'params': scalar_params_dec, 'lr': scalar_lr_dec, 'base_lr': scalar_lr_dec}, + {'params': other_scalar_params, 'lr': args.scalar_lr, 'base_lr': args.scalar_lr}, + ], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + weight_decay=args.adam_wd, + fused=True, + ) + + optimizer_bigram_proj = Muon( + 
[base_model.bigram_emb.proj.weight], + lr=args.matrix_lr, + momentum=args.muon_momentum, + backend_steps=args.muon_backend_steps, + weight_decay=args.muon_wd, + ) + for group in optimizer_bigram_proj.param_groups: + group["base_lr"] = args.matrix_lr + + optimizers: list[torch.optim.Optimizer] = [optimizer_tok, optimizer_muon, optimizer_scalar, optimizer_bigram_proj] + if base_model.lm_head is not None: + optimizer_head = torch.optim.AdamW( + [{"params": [base_model.lm_head.weight], "lr": args.head_lr, "base_lr": args.head_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + weight_decay=args.adam_wd, + fused=True, + ) + optimizers.insert(1, optimizer_head) + + n_params = sum(p.numel() for p in base_model.parameters()) + log0(f"model_params:{n_params}") + log0(f"world_size:{world_size} grad_accum_steps:{grad_accum_steps}") + log0("sdp_backends:cudnn=False flash=True mem_efficient=False math=False") + log0(f"attention_mode:gqa num_heads:{args.num_heads} num_kv_heads:{args.num_kv_heads}") + log0(f"tie_embeddings:{args.tie_embeddings} embed_lr:{token_lr} head_lr:{args.head_lr if base_model.lm_head is not None else 0.0} matrix_lr:{args.matrix_lr} scalar_lr:{args.scalar_lr} decoder_lr_mult:{args.decoder_lr_mult}") + log0(f"train_batch_tokens:{args.train_batch_tokens} train_seq_len:{args.train_seq_len} iterations:{args.iterations} warmup_steps:{args.warmup_steps} max_wallclock_seconds:{args.max_wallclock_seconds:.3f}") + log0(f"rope_dims:{args.rope_dims} rope_train_seq_len:{args.rope_train_seq_len} ln_scale:{args.ln_scale}") + log0(f"muon_wd:{args.muon_wd} adam_wd:{args.adam_wd} ema_enabled:{args.ema_enabled} late_qat:{args.late_qat}") + log0(f"bigram_buckets:{args.bigram_buckets} bigram_embed_dim:{args.bigram_embed_dim} seed:{args.seed}") + if args.ttt_enabled: + log0(f"ttt:enabled optimizer:{args.ttt_optimizer} lr:{args.ttt_lr} epochs:{args.ttt_epochs} " + f"freeze_blocks:{args.ttt_freeze_blocks} chunk_tokens:{args.ttt_chunk_tokens}") + + # 
----------------------------- + # DATA LOADER & MODEL WARMUP + # ----------------------------- + + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + + def zero_grad_all() -> None: + for opt in optimizers: + opt.zero_grad(set_to_none=True) + + max_wallclock_ms = 1000.0 * args.max_wallclock_seconds if args.max_wallclock_seconds > 0 else None + + def lr_mul(step: int, elapsed_ms: float) -> float: + if args.warmdown_iters <= 0: + return 1.0 + if max_wallclock_ms is None: + warmdown_start = max(args.iterations - args.warmdown_iters, 0) + return max((args.iterations - step) / max(args.warmdown_iters, 1), 0.0) if warmdown_start <= step < args.iterations else 1.0 + step_ms = elapsed_ms / max(step, 1) + warmdown_ms = args.warmdown_iters * step_ms + remaining_ms = max(max_wallclock_ms - elapsed_ms, 0.0) + return remaining_ms / max(warmdown_ms, 1e-9) if remaining_ms <= warmdown_ms else 1.0 + + if args.warmup_steps > 0: + initial_model_state = {name: tensor.detach().cpu().clone() for name, tensor in base_model.state_dict().items()} + initial_optimizer_states = [copy.deepcopy(opt.state_dict()) for opt in optimizers] + model.train() + for warmup_step in range(args.warmup_steps): + zero_grad_all() + for micro_step in range(grad_accum_steps): + if distributed: + model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 + x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + warmup_loss = model(x, y) + (warmup_loss * grad_scale).backward() + for opt in optimizers: + opt.step() + zero_grad_all() + if args.warmup_steps <= 20 or (warmup_step + 1) % 10 == 0 or warmup_step + 1 == args.warmup_steps: + log0(f"warmup_step:{warmup_step + 1}/{args.warmup_steps}") + base_model.load_state_dict(initial_model_state, strict=True) + for opt, state in zip(optimizers, initial_optimizer_states, strict=True): + opt.load_state_dict(state) 
+ zero_grad_all() + if distributed: + model.require_backward_grad_sync = True + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + + # ----------------------------- + # EMA / SWA STATE + # ----------------------------- + + # EMA takes priority; SWA is fallback (mutually exclusive) + ema_state: dict[str, Tensor] | None = None + if args.ema_enabled: + ema_state = {name: t.detach().float().clone() for name, t in base_model.state_dict().items()} + log0(f"ema:init decay={args.ema_decay}") + + swa_state: dict[str, Tensor] = {} + swa_count = 0 + + def update_swa(): + nonlocal swa_count + with torch.no_grad(): + for name, param in base_model.state_dict().items(): + if name not in swa_state: + swa_state[name] = param.detach().cpu().clone().float() + else: + swa_state[name].add_(param.detach().cpu().float()) + swa_count += 1 + + def get_swa_state() -> dict[str, Tensor]: + return {name: (t / swa_count).to(dtype=base_model.state_dict()[name].dtype) for name, t in swa_state.items()} + + # ----------------------------- + # MAIN TRAINING LOOP + # ----------------------------- + + training_time_ms = 0.0 + stop_after_step: int | None = None + torch.cuda.synchronize() + t0 = time.perf_counter() + + # Estimate total steps for SWA start + estimated_total_steps = args.iterations + if max_wallclock_ms is not None: + estimated_total_steps = min(args.iterations, int(max_wallclock_ms / 30)) # rough estimate + + step = 0 + while True: + last_step = step == args.iterations or (stop_after_step is not None and step >= stop_after_step) + + should_validate = last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0) + if should_validate: + torch.cuda.synchronize() + training_time_ms += 1000.0 * (time.perf_counter() - t0) + val_loss, val_bpb = eval_val( + args, model, rank, world_size, device, grad_accum_steps, + val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, + ) + log0(f"step:{step}/{args.iterations} 
val_loss:{val_loss:.4f} val_bpb:{val_bpb:.4f} train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms / max(step, 1):.2f}ms") + torch.cuda.synchronize() + t0 = time.perf_counter() + + if last_step: + if stop_after_step is not None and step < args.iterations: + log0(f"stopping_early: wallclock_cap train_time:{training_time_ms:.0f}ms step:{step}/{args.iterations}") + break + + elapsed_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + scale = lr_mul(step, elapsed_ms) + + # Late QAT: enable fake int6 quantization once LR scale drops below threshold + # NOTE: torch.compile constant-folds _qat_enabled at first trace, so we must + # reset dynamo caches to force recompilation with QAT branch active. + if args.late_qat and not CastedLinear._qat_enabled and scale < args.qat_threshold: + CastedLinear._qat_enabled = True + torch._dynamo.reset() # force recompile with QAT enabled + log0(f"late_qat:enabled step:{step} scale:{scale:.4f} (dynamo reset for recompile)") + + zero_grad_all() + train_loss = torch.zeros((), device=device) + for micro_step in range(grad_accum_steps): + if distributed: + model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 + x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + loss = model(x, y) + train_loss += loss.detach() + (loss * grad_scale).backward() + train_loss /= grad_accum_steps + + frac = min(step / args.muon_momentum_warmup_steps, 1.0) if args.muon_momentum_warmup_steps > 0 else 1.0 + muon_momentum = (1 - frac) * args.muon_momentum_warmup_start + frac * args.muon_momentum + for group in optimizer_muon.param_groups: + group["momentum"] = muon_momentum + for group in optimizer_bigram_proj.param_groups: + group["momentum"] = muon_momentum + + for opt in optimizers: + for group in opt.param_groups: + group["lr"] = group["base_lr"] * scale + + if args.grad_clip_norm > 0: + 
torch.nn.utils.clip_grad_norm_(base_model.parameters(), args.grad_clip_norm) + for opt in optimizers: + opt.step() + zero_grad_all() + + step += 1 + + # EMA update every step (takes priority over SWA) + if ema_state is not None: + d = args.ema_decay + with torch.no_grad(): + for name, t in base_model.state_dict().items(): + ema_state[name].mul_(d).add_(t.detach().float(), alpha=1.0 - d) + + # SWA update (only when EMA disabled) + swa_start_step = int(estimated_total_steps * args.swa_start_frac) + if ema_state is None and step >= swa_start_step and step % args.swa_every == 0: + update_swa() + + approx_training_time_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + should_log_train = args.train_log_every > 0 and (step <= 10 or step % args.train_log_every == 0 or stop_after_step is not None) + if should_log_train: + log0(f"step:{step}/{args.iterations} train_loss:{train_loss.item():.4f} train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms / step:.2f}ms") + + reached_cap = max_wallclock_ms is not None and approx_training_time_ms >= max_wallclock_ms + if distributed and max_wallclock_ms is not None: + reached_cap_tensor = torch.tensor(int(reached_cap), device=device) + dist.all_reduce(reached_cap_tensor, op=dist.ReduceOp.MAX) + reached_cap = bool(reached_cap_tensor.item()) + if stop_after_step is None and reached_cap: + stop_after_step = step + + # Final SWA update (only if EMA disabled and no SWA yet) + if ema_state is None and swa_count == 0: + update_swa() + + log0(f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB") + + # Apply EMA or SWA weights (EMA takes priority) + if ema_state is not None: + log0("ema:applying EMA weights") + avg_state = {name: t.to(dtype=base_model.state_dict()[name].dtype) for name, t in ema_state.items()} + del ema_state + base_model.load_state_dict(avg_state, strict=True) + del avg_state + elif swa_count > 0: + 
log0(f"swa:applying averaged {swa_count} checkpoints") + base_model.load_state_dict(get_swa_state(), strict=True) + else: + log0("weight_avg:skipped (no EMA or SWA state)") + + # ----------------------------- + # TTT: score-first fine-tune on val data runs inside the final eval (eval_val_sliding_ttt), after the quantization roundtrip + # ----------------------------- + + # ----------------------------- + # SERIALIZATION + ROUNDTRIP VALIDATION + # ----------------------------- + + if master_process: + torch.save(base_model.state_dict(), "final_model.pt") + model_bytes = os.path.getsize("final_model.pt") + code_bytes = len(code.encode("utf-8")) + log0(f"Serialized model: {model_bytes} bytes") + log0(f"Code size: {code_bytes} bytes") + log0(f"Total submission size: {model_bytes + code_bytes} bytes") + + quant_obj, quant_stats = quantize_state_dict_int6(base_model.state_dict()) + quant_buf = io.BytesIO() + torch.save(quant_obj, quant_buf) + quant_raw = quant_buf.getvalue() + + # Use zstd-22 for compression (or zlib fallback) + if USE_ZSTD: + cctx = zstd.ZstdCompressor(level=22) + quant_blob = cctx.compress(quant_raw) + compression_method = "zstd-22" + else: + import zlib + quant_blob = zlib.compress(quant_raw, level=9) + compression_method = "zlib-9" + + quant_raw_bytes = len(quant_raw) + if master_process: + with open("final_model.int6.ptz", "wb") as f: + f.write(quant_blob) + quant_file_bytes = os.path.getsize("final_model.int6.ptz") + code_bytes = len(code.encode("utf-8")) + ratio = quant_stats["baseline_tensor_bytes"] / max(quant_stats["int6_payload_bytes"], 1) + q_label = "mixed" if MIXED_QUANT else ("int8" if INT8_STORAGE else "int6") + log0(f"Serialized model {q_label}+{compression_method}: {quant_file_bytes} bytes (payload:{quant_stats['int6_payload_bytes']} raw_torch:{quant_raw_bytes} payload_ratio:{ratio:.2f}x)") + log0(f"Total submission size {q_label}+{compression_method}: {quant_file_bytes + code_bytes} bytes") + + if distributed: + dist.barrier() + with open("final_model.int6.ptz", "rb") as f: + quant_blob_disk = 
f.read() + + # Decompress + if USE_ZSTD: + dctx = zstd.ZstdDecompressor() + quant_raw_disk = dctx.decompress(quant_blob_disk) + else: + import zlib + quant_raw_disk = zlib.decompress(quant_blob_disk) + + quant_state = torch.load(io.BytesIO(quant_raw_disk), map_location="cpu") + base_model.load_state_dict(dequantize_state_dict_int6(quant_state), strict=True) + torch.cuda.synchronize() + t_qeval = time.perf_counter() + if args.ttt_enabled and args.eval_stride > 0 and args.eval_stride < args.train_seq_len: + log0(f"final_eval_mode:sliding_window_ttt stride:{args.eval_stride} " + f"chunk_tokens:{args.ttt_chunk_tokens} optimizer:{args.ttt_optimizer}") + q_val_loss, q_val_bpb = eval_val_sliding_ttt( + args, base_model, rank, world_size, device, + val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, + stride=args.eval_stride, batch_seqs=args.eval_batch_seqs, log0=log0) + elif args.eval_stride > 0 and args.eval_stride < args.train_seq_len: + log0(f"final_eval_mode:sliding_window stride:{args.eval_stride}") + q_val_loss, q_val_bpb = eval_val_sliding( + args, base_model, rank, world_size, device, val_tokens, + base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, + stride=args.eval_stride, batch_seqs=args.eval_batch_seqs, + ) + else: + log0("final_eval_mode:standard") + q_val_loss, q_val_bpb = eval_val( + args, model, rank, world_size, device, grad_accum_steps, + val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut) + torch.cuda.synchronize() + log0(f"final_int6_{compression_method}_roundtrip val_loss:{q_val_loss:.4f} val_bpb:{q_val_bpb:.4f} eval:{1000.0*(time.perf_counter()-t_qeval):.0f}ms") + log0(f"final_int6_{compression_method}_roundtrip_exact val_bpb:{q_val_bpb:.8f}") + + if distributed: + dist.destroy_process_group() + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/records/track_non_record_16mb/2026-03-24_11L_GEPA_MixedQuant_7kSteps_LegalTTT/train_seed42.log 
b/records/track_non_record_16mb/2026-03-24_11L_GEPA_MixedQuant_7kSteps_LegalTTT/train_seed42.log new file mode 100644 index 0000000000..13e8e79b8c --- /dev/null +++ b/records/track_non_record_16mb/2026-03-24_11L_GEPA_MixedQuant_7kSteps_LegalTTT/train_seed42.log @@ -0,0 +1,1981 @@ +logs/gep_v27k_55229918.txt +val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path=/hpfs/scratch/gpfs/mcclec07/code/parameter_golf/repo/data/tokenizers/fineweb_1024_bpe.model +train_loader:dataset:fineweb10B_sp1024 train_shards:80 +val_loader:shards pattern=/hpfs/scratch/gpfs/mcclec07/code/parameter_golf/repo/data/datasets/fineweb10B_sp1024/fineweb_val_*.bin tokens:62021632 +model_params:27030108 +world_size:4 grad_accum_steps:2 +sdp_backends:cudnn=False flash=True mem_efficient=False math=False +attention_mode:gqa num_heads:8 num_kv_heads:4 +tie_embeddings:True embed_lr:0.035 head_lr:0.0 matrix_lr:0.025 scalar_lr:0.025 decoder_lr_mult:2.0 +train_batch_tokens:786432 train_seq_len:2048 iterations:7000 warmup_steps:20 max_wallclock_seconds:0.000 +rope_dims:16 rope_train_seq_len:1024 ln_scale:True +muon_wd:0.04 adam_wd:0.04 ema_enabled:True late_qat:True +bigram_buckets:2048 bigram_embed_dim:128 seed:42 +ttt:enabled optimizer:sgd lr:0.002 epochs:10 freeze_blocks:2 chunk_tokens:32768 +warmup_step:1/20 +warmup_step:2/20 +warmup_step:3/20 +warmup_step:4/20 +warmup_step:5/20 +warmup_step:6/20 +warmup_step:7/20 +warmup_step:8/20 +warmup_step:9/20 +warmup_step:10/20 +warmup_step:11/20 +warmup_step:12/20 +warmup_step:13/20 +warmup_step:14/20 +warmup_step:15/20 +warmup_step:16/20 +warmup_step:17/20 +warmup_step:18/20 +warmup_step:19/20 +warmup_step:20/20 +ema:init decay=0.997 +step:0/7000 val_loss:6.9301 val_bpb:4.1044 train_time:0ms step_avg:0.01ms +step:1/7000 train_loss:6.9313 train_time:490ms step_avg:490.27ms +step:2/7000 train_loss:8.4238 train_time:979ms step_avg:489.55ms +step:3/7000 train_loss:7.5556 train_time:1478ms step_avg:492.83ms +step:4/7000 train_loss:8.6123 train_time:1980ms 
step_avg:494.92ms +step:5/7000 train_loss:9.0143 train_time:2477ms step_avg:495.36ms +step:6/7000 train_loss:8.8303 train_time:2971ms step_avg:495.10ms +step:7/7000 train_loss:8.2794 train_time:3471ms step_avg:495.90ms +step:8/7000 train_loss:7.4327 train_time:3973ms step_avg:496.67ms +step:9/7000 train_loss:6.8659 train_time:4477ms step_avg:497.42ms +step:10/7000 train_loss:6.4393 train_time:4972ms step_avg:497.22ms +step:200/7000 train_loss:2.5808 train_time:100029ms step_avg:500.15ms +step:400/7000 train_loss:2.4553 train_time:200389ms step_avg:500.97ms +step:500/7000 val_loss:2.3821 val_bpb:1.4108 train_time:250213ms step_avg:500.43ms +step:600/7000 train_loss:2.3758 train_time:300013ms step_avg:500.02ms +step:800/7000 train_loss:2.2630 train_time:399630ms step_avg:499.54ms +step:1000/7000 train_loss:2.3020 train_time:499269ms step_avg:499.27ms +step:1000/7000 val_loss:2.2514 val_bpb:1.3334 train_time:499278ms step_avg:499.28ms +step:1200/7000 train_loss:2.2961 train_time:599114ms step_avg:499.26ms +step:1400/7000 train_loss:2.2551 train_time:698793ms step_avg:499.14ms +step:1500/7000 val_loss:2.2034 val_bpb:1.3050 train_time:748645ms step_avg:499.10ms +step:1600/7000 train_loss:2.1487 train_time:798455ms step_avg:499.03ms +step:1800/7000 train_loss:2.1519 train_time:898339ms step_avg:499.08ms +step:2000/7000 train_loss:2.0396 train_time:998108ms step_avg:499.05ms +step:2000/7000 val_loss:2.1463 val_bpb:1.2711 train_time:998116ms step_avg:499.06ms +step:2200/7000 train_loss:2.1582 train_time:1097880ms step_avg:499.04ms +step:2400/7000 train_loss:2.0984 train_time:1197471ms step_avg:498.95ms +step:2500/7000 val_loss:2.1260 val_bpb:1.2592 train_time:1247279ms step_avg:498.91ms +step:2600/7000 train_loss:2.1654 train_time:1297122ms step_avg:498.89ms +step:2800/7000 train_loss:2.1999 train_time:1397127ms step_avg:498.97ms +step:3000/7000 train_loss:2.1277 train_time:1496756ms step_avg:498.92ms +step:3000/7000 val_loss:2.1122 val_bpb:1.2509 train_time:1496764ms 
step_avg:498.92ms +step:3200/7000 train_loss:2.1651 train_time:1596443ms step_avg:498.89ms +step:3400/7000 train_loss:2.1145 train_time:1696169ms step_avg:498.87ms +step:3500/7000 val_loss:2.1064 val_bpb:1.2475 train_time:1746216ms step_avg:498.92ms +step:3600/7000 train_loss:2.1113 train_time:1796178ms step_avg:498.94ms +step:3800/7000 train_loss:2.0975 train_time:1895885ms step_avg:498.92ms +step:4000/7000 train_loss:2.1540 train_time:1995522ms step_avg:498.88ms +step:4000/7000 val_loss:2.0859 val_bpb:1.2354 train_time:1995531ms step_avg:498.88ms +step:4200/7000 train_loss:2.0917 train_time:2095260ms step_avg:498.87ms +step:4400/7000 train_loss:2.0158 train_time:2194994ms step_avg:498.86ms +step:4500/7000 val_loss:2.0651 val_bpb:1.2231 train_time:2244995ms step_avg:498.89ms +step:4600/7000 train_loss:1.9621 train_time:2294720ms step_avg:498.85ms +step:4800/7000 train_loss:2.2581 train_time:2394137ms step_avg:498.78ms +step:5000/7000 train_loss:2.0617 train_time:2493859ms step_avg:498.77ms +step:5000/7000 val_loss:2.0437 val_bpb:1.2104 train_time:2493868ms step_avg:498.77ms +step:5200/7000 train_loss:2.0638 train_time:2593605ms step_avg:498.77ms +step:5400/7000 train_loss:1.9975 train_time:2693386ms step_avg:498.78ms +step:5500/7000 val_loss:2.0211 val_bpb:1.1970 train_time:2743156ms step_avg:498.76ms +step:5600/7000 train_loss:2.0094 train_time:2792858ms step_avg:498.72ms +step:5800/7000 train_loss:1.9736 train_time:2892326ms step_avg:498.68ms +step:6000/7000 train_loss:1.9671 train_time:2992131ms step_avg:498.69ms +step:6000/7000 val_loss:1.9981 val_bpb:1.1834 train_time:2992140ms step_avg:498.69ms +step:6200/7000 train_loss:2.0387 train_time:3091840ms step_avg:498.68ms +step:6400/7000 train_loss:1.9905 train_time:3191372ms step_avg:498.65ms +late_qat:enabled step:6476 scale:0.1497 +step:6500/7000 val_loss:1.9657 val_bpb:1.1642 train_time:3241122ms step_avg:498.63ms +step:6600/7000 train_loss:1.8821 train_time:3290897ms step_avg:498.62ms +step:6800/7000 
train_loss:2.0409 train_time:3390711ms step_avg:498.63ms +step:7000/7000 train_loss:1.8150 train_time:3490351ms step_avg:498.62ms +step:7000/7000 val_loss:1.9376 val_bpb:1.1476 train_time:3490359ms step_avg:498.62ms +peak memory allocated: 21327 MiB reserved: 21512 MiB +ema:applying EMA weights +Serialized model: 107123447 bytes +Code size: 76429 bytes +Total submission size: 107199876 bytes +Serialized model int6+zstd-22: 15626769 bytes (payload:27522422 raw_torch:27588053 payload_ratio:3.89x) +Total submission size int6+zstd-22: 15703198 bytes +final_eval_mode:sliding_window_ttt stride:64 chunk_tokens:32768 optimizer:sgd +ttt_sliding:start chunks=1893 chunk_tokens=32768 total_windows=969088 stride=64 ttt_lr=0.002 ttt_epochs=10 ttt_optimizer=sgd freeze_blocks=2 +ttt_sliding:params unfrozen=22301260 frozen=4728848 + ttt_chunk [1/1893] bpb=1.174449 time=1.3s + ttt_chunk [11/1893] bpb=1.131772 time=12.8s + ttt_chunk [21/1893] bpb=1.136036 time=24.3s + ttt_chunk [31/1893] bpb=1.138824 time=35.8s + ttt_chunk [41/1893] bpb=1.128172 time=47.4s + ttt_chunk [51/1893] bpb=1.125337 time=58.9s + ttt_chunk [61/1893] bpb=1.131113 time=70.5s + ttt_chunk [71/1893] bpb=1.128140 time=82.1s + ttt_chunk [81/1893] bpb=1.128327 time=93.6s + ttt_chunk [91/1893] bpb=1.127725 time=105.2s + ttt_chunk [101/1893] bpb=1.130981 time=116.8s + ttt_chunk [111/1893] bpb=1.132478 time=128.5s + ttt_chunk [121/1893] bpb=1.129591 time=140.1s + ttt_chunk [131/1893] bpb=1.129946 time=151.7s + ttt_chunk [141/1893] bpb=1.129857 time=163.3s + ttt_chunk [151/1893] bpb=1.133145 time=174.8s + ttt_chunk [161/1893] bpb=1.135286 time=186.4s + ttt_chunk [171/1893] bpb=1.136401 time=198.0s + ttt_chunk [181/1893] bpb=1.136577 time=209.6s + ttt_chunk [191/1893] bpb=1.140083 time=221.2s + ttt_chunk [201/1893] bpb=1.140433 time=232.8s + ttt_chunk [211/1893] bpb=1.138257 time=244.4s + ttt_chunk [221/1893] bpb=1.140323 time=255.9s + ttt_chunk [231/1893] bpb=1.139984 time=267.5s + ttt_chunk [241/1893] bpb=1.139897 
time=279.1s + ttt_chunk [251/1893] bpb=1.138439 time=290.7s + ttt_chunk [261/1893] bpb=1.136828 time=302.3s + ttt_chunk [271/1893] bpb=1.135533 time=313.8s + ttt_chunk [281/1893] bpb=1.137994 time=325.4s + ttt_chunk [291/1893] bpb=1.138875 time=337.0s + ttt_chunk [301/1893] bpb=1.139468 time=348.6s + ttt_chunk [311/1893] bpb=1.141148 time=360.1s + ttt_chunk [321/1893] bpb=1.142669 time=371.7s + ttt_chunk [331/1893] bpb=1.142534 time=383.3s + ttt_chunk [341/1893] bpb=1.142923 time=394.8s + ttt_chunk [351/1893] bpb=1.144321 time=406.4s + ttt_chunk [361/1893] bpb=1.145631 time=418.0s + ttt_chunk [371/1893] bpb=1.145187 time=429.5s + ttt_chunk [381/1893] bpb=1.145098 time=441.1s + ttt_chunk [391/1893] bpb=1.144739 time=452.7s + ttt_chunk [401/1893] bpb=1.143393 time=464.2s + ttt_chunk [411/1893] bpb=1.142374 time=475.8s + ttt_chunk [421/1893] bpb=1.141834 time=487.4s + ttt_chunk [431/1893] bpb=1.142580 time=498.9s + ttt_chunk [441/1893] bpb=1.142468 time=510.5s + ttt_chunk [451/1893] bpb=1.142226 time=522.1s + ttt_chunk [461/1893] bpb=1.141511 time=533.6s + ttt_chunk [471/1893] bpb=1.141128 time=545.2s + ttt_chunk [481/1893] bpb=1.140896 time=556.7s + ttt_chunk [491/1893] bpb=1.140638 time=568.3s + ttt_chunk [501/1893] bpb=1.140157 time=579.8s + ttt_chunk [511/1893] bpb=1.139646 time=591.4s + ttt_chunk [521/1893] bpb=1.138782 time=603.0s + ttt_chunk [531/1893] bpb=1.138896 time=614.5s + ttt_chunk [541/1893] bpb=1.138847 time=626.1s + ttt_chunk [551/1893] bpb=1.137660 time=637.7s + ttt_chunk [561/1893] bpb=1.138154 time=649.2s + ttt_chunk [571/1893] bpb=1.137454 time=660.8s + ttt_chunk [581/1893] bpb=1.136930 time=672.3s + ttt_chunk [591/1893] bpb=1.136281 time=683.9s + ttt_chunk [601/1893] bpb=1.137008 time=695.5s + ttt_chunk [611/1893] bpb=1.136609 time=707.0s + ttt_chunk [621/1893] bpb=1.136580 time=718.6s + ttt_chunk [631/1893] bpb=1.136996 time=730.2s + ttt_chunk [641/1893] bpb=1.136793 time=741.7s + ttt_chunk [651/1893] bpb=1.136826 time=753.3s + ttt_chunk 
[661/1893] bpb=1.136755 time=764.9s + ttt_chunk [671/1893] bpb=1.136399 time=776.4s + ttt_chunk [681/1893] bpb=1.136668 time=788.1s + ttt_chunk [691/1893] bpb=1.137390 time=799.9s + ttt_chunk [701/1893] bpb=1.136588 time=811.4s + ttt_chunk [711/1893] bpb=1.137210 time=823.0s + ttt_chunk [721/1893] bpb=1.136866 time=834.5s + ttt_chunk [731/1893] bpb=1.137361 time=846.1s + ttt_chunk [741/1893] bpb=1.137336 time=857.6s + ttt_chunk [751/1893] bpb=1.136976 time=869.2s + ttt_chunk [761/1893] bpb=1.136875 time=880.8s + ttt_chunk [771/1893] bpb=1.136668 time=892.3s + ttt_chunk [781/1893] bpb=1.137227 time=903.9s + ttt_chunk [791/1893] bpb=1.136940 time=915.4s + ttt_chunk [801/1893] bpb=1.137045 time=927.0s + ttt_chunk [811/1893] bpb=1.136616 time=938.5s + ttt_chunk [821/1893] bpb=1.136474 time=950.2s + ttt_chunk [831/1893] bpb=1.136085 time=961.9s + ttt_chunk [841/1893] bpb=1.135570 time=973.7s + ttt_chunk [851/1893] bpb=1.135510 time=985.4s + ttt_chunk [861/1893] bpb=1.135696 time=997.0s + ttt_chunk [871/1893] bpb=1.135789 time=1008.7s + ttt_chunk [881/1893] bpb=1.135866 time=1020.3s + ttt_chunk [891/1893] bpb=1.135748 time=1031.9s + ttt_chunk [901/1893] bpb=1.135774 time=1043.5s + ttt_chunk [911/1893] bpb=1.135824 time=1055.2s + ttt_chunk [921/1893] bpb=1.136241 time=1066.8s + ttt_chunk [931/1893] bpb=1.136100 time=1078.5s + ttt_chunk [941/1893] bpb=1.136016 time=1090.1s + ttt_chunk [951/1893] bpb=1.136056 time=1101.7s + ttt_chunk [961/1893] bpb=1.135829 time=1113.3s + ttt_chunk [971/1893] bpb=1.136600 time=1124.9s + ttt_chunk [981/1893] bpb=1.136761 time=1136.5s + ttt_chunk [991/1893] bpb=1.136700 time=1148.1s + ttt_chunk [1001/1893] bpb=1.136864 time=1159.8s + ttt_chunk [1011/1893] bpb=1.137154 time=1171.4s + ttt_chunk [1021/1893] bpb=1.137333 time=1183.0s + ttt_chunk [1031/1893] bpb=1.137876 time=1194.6s + ttt_chunk [1041/1893] bpb=1.137535 time=1206.2s + ttt_chunk [1051/1893] bpb=1.137250 time=1217.8s + ttt_chunk [1061/1893] bpb=1.137518 time=1229.4s + ttt_chunk 
[1071/1893] bpb=1.138023 time=1241.1s + ttt_chunk [1081/1893] bpb=1.138054 time=1252.7s + ttt_chunk [1091/1893] bpb=1.138459 time=1264.3s + ttt_chunk [1101/1893] bpb=1.138604 time=1275.9s + ttt_chunk [1111/1893] bpb=1.138362 time=1287.6s + ttt_chunk [1121/1893] bpb=1.138294 time=1299.2s + ttt_chunk [1131/1893] bpb=1.138173 time=1310.8s + ttt_chunk [1141/1893] bpb=1.138040 time=1322.4s + ttt_chunk [1151/1893] bpb=1.138090 time=1334.0s + ttt_chunk [1161/1893] bpb=1.137486 time=1345.7s + ttt_chunk [1171/1893] bpb=1.138066 time=1357.3s + ttt_chunk [1181/1893] bpb=1.137575 time=1368.9s + ttt_chunk [1191/1893] bpb=1.137310 time=1380.5s + ttt_chunk [1201/1893] bpb=1.137891 time=1392.2s + ttt_chunk [1211/1893] bpb=1.137319 time=1403.8s + ttt_chunk [1221/1893] bpb=1.137033 time=1415.4s + ttt_chunk [1231/1893] bpb=1.136922 time=1427.0s + ttt_chunk [1241/1893] bpb=1.136730 time=1438.7s + ttt_chunk [1251/1893] bpb=1.136489 time=1450.3s + ttt_chunk [1261/1893] bpb=1.136446 time=1461.9s + ttt_chunk [1271/1893] bpb=1.136259 time=1473.5s + ttt_chunk [1281/1893] bpb=1.136060 time=1485.1s + ttt_chunk [1291/1893] bpb=1.135912 time=1496.8s + ttt_chunk [1301/1893] bpb=1.135555 time=1508.4s + ttt_chunk [1311/1893] bpb=1.135233 time=1520.0s + ttt_chunk [1321/1893] bpb=1.135061 time=1531.7s + ttt_chunk [1331/1893] bpb=1.134972 time=1543.3s + ttt_chunk [1341/1893] bpb=1.134865 time=1554.9s + ttt_chunk [1351/1893] bpb=1.134818 time=1566.6s + ttt_chunk [1361/1893] bpb=1.135040 time=1578.2s + ttt_chunk [1371/1893] bpb=1.134911 time=1589.8s + ttt_chunk [1381/1893] bpb=1.134864 time=1601.4s + ttt_chunk [1391/1893] bpb=1.134337 time=1613.0s + ttt_chunk [1401/1893] bpb=1.134377 time=1624.6s + ttt_chunk [1411/1893] bpb=1.134402 time=1636.3s + ttt_chunk [1421/1893] bpb=1.134667 time=1647.9s + ttt_chunk [1431/1893] bpb=1.134547 time=1659.5s + ttt_chunk [1441/1893] bpb=1.135215 time=1671.2s + ttt_chunk [1451/1893] bpb=1.135365 time=1682.8s + ttt_chunk [1461/1893] bpb=1.135118 time=1694.4s + ttt_chunk 
[1471/1893] bpb=1.136046 time=1706.0s + ttt_chunk [1481/1893] bpb=1.135855 time=1717.6s + ttt_chunk [1491/1893] bpb=1.135874 time=1729.3s + ttt_chunk [1501/1893] bpb=1.136058 time=1740.9s + ttt_chunk [1511/1893] bpb=1.136165 time=1752.5s + ttt_chunk [1521/1893] bpb=1.136196 time=1764.1s + ttt_chunk [1531/1893] bpb=1.136048 time=1775.8s + ttt_chunk [1541/1893] bpb=1.136026 time=1787.4s + ttt_chunk [1551/1893] bpb=1.136391 time=1799.0s + ttt_chunk [1561/1893] bpb=1.136546 time=1810.6s + ttt_chunk [1571/1893] bpb=1.136693 time=1822.3s + ttt_chunk [1581/1893] bpb=1.136814 time=1833.9s + ttt_chunk [1591/1893] bpb=1.136764 time=1845.5s + ttt_chunk [1601/1893] bpb=1.136940 time=1857.1s + ttt_chunk [1611/1893] bpb=1.137021 time=1868.8s + ttt_chunk [1621/1893] bpb=1.136896 time=1880.4s + ttt_chunk [1631/1893] bpb=1.137069 time=1892.0s + ttt_chunk [1641/1893] bpb=1.136961 time=1903.7s + ttt_chunk [1651/1893] bpb=1.136906 time=1915.4s + ttt_chunk [1661/1893] bpb=1.136796 time=1927.1s + ttt_chunk [1671/1893] bpb=1.137180 time=1938.6s + ttt_chunk [1681/1893] bpb=1.137454 time=1950.1s + ttt_chunk [1691/1893] bpb=1.137422 time=1961.7s + ttt_chunk [1701/1893] bpb=1.137408 time=1973.2s + ttt_chunk [1711/1893] bpb=1.137268 time=1984.7s + ttt_chunk [1721/1893] bpb=1.137122 time=1996.2s + ttt_chunk [1731/1893] bpb=1.137091 time=2007.8s + ttt_chunk [1741/1893] bpb=1.136882 time=2019.3s + ttt_chunk [1751/1893] bpb=1.136725 time=2030.8s + ttt_chunk [1761/1893] bpb=1.136813 time=2042.4s + ttt_chunk [1771/1893] bpb=1.136744 time=2053.9s + ttt_chunk [1781/1893] bpb=1.136720 time=2065.4s + ttt_chunk [1791/1893] bpb=1.136318 time=2077.0s + ttt_chunk [1801/1893] bpb=1.136328 time=2088.5s + ttt_chunk [1811/1893] bpb=1.136175 time=2100.0s + ttt_chunk [1821/1893] bpb=1.136196 time=2111.6s + ttt_chunk [1831/1893] bpb=1.135816 time=2123.1s + ttt_chunk [1841/1893] bpb=1.135838 time=2134.6s + ttt_chunk [1851/1893] bpb=1.135612 time=2146.1s + ttt_chunk [1861/1893] bpb=1.135142 time=2157.7s + ttt_chunk 
[1871/1893] bpb=1.134990 time=2169.2s + ttt_chunk [1881/1893] bpb=1.134613 time=2180.7s + ttt_chunk [1891/1893] bpb=1.134454 time=2192.2s + ttt_chunk [1893/1893] bpb=1.134470 time=2193.8s +ttt_sliding:done val_loss=1.913751 val_bpb=1.133434 elapsed=2193.8s +final_int6_zstd-22_roundtrip val_loss:1.9138 val_bpb:1.1334 eval:2194299ms +final_int6_zstd-22_roundtrip_exact val_bpb:1.13343416 + tb += (has_leading_space_lut[scored_tgt] & ~is_boundary_token_lut[scored_prev]).to(torch.int16) + byte_count += tb.to(torch.float64).sum() + token_count += float(wlen - s) + + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(loss_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(token_count, op=dist.ReduceOp.SUM) + dist.all_reduce(byte_count, op=dist.ReduceOp.SUM) + + val_loss = loss_sum / token_count + bits_per_token = val_loss.item() / math.log(2.0) + tokens_per_byte = token_count.item() / byte_count.item() + base_model.train() + return float(val_loss.item()), float(bits_per_token * tokens_per_byte) + + +def eval_val_sliding_ttt( + args: Hyperparameters, base_model: nn.Module, rank: int, world_size: int, + device: torch.device, val_tokens: Tensor, base_bytes_lut: Tensor, + has_leading_space_lut: Tensor, is_boundary_token_lut: Tensor, + stride: int, batch_seqs: int = 32, log0=print, +) -> tuple[float, float]: + """Legal score-first TTT: score each chunk with sliding windows, then train on it.""" + seq_len = args.train_seq_len + total_tokens = val_tokens.numel() - 1 + ttt_chunk = args.ttt_chunk_tokens + + window_starts = [ws for ws in range(0, total_tokens, stride) + if min(ws + seq_len, total_tokens) - ws >= stride or ws == 0] + + num_chunks = (total_tokens + ttt_chunk - 1) // ttt_chunk + chunk_windows: list[list[int]] = [[] for _ in range(num_chunks)] + for ws in window_starts: + end = min(ws + seq_len, total_tokens) + wlen = end - ws + s = 0 if ws == 0 else max(wlen - stride, 0) + scored_start = ws + s + ci = min(scored_start // ttt_chunk, num_chunks - 1) + 
chunk_windows[ci].append(ws) + + log0(f"ttt_sliding:start chunks={num_chunks} chunk_tokens={ttt_chunk} " + f"total_windows={len(window_starts)} stride={stride} " + f"ttt_lr={args.ttt_lr} ttt_epochs={args.ttt_epochs} " + f"ttt_optimizer={args.ttt_optimizer} freeze_blocks={args.ttt_freeze_blocks}") + + loss_sum = torch.zeros((), device=device, dtype=torch.float64) + token_count = torch.zeros((), device=device, dtype=torch.float64) + byte_count = torch.zeros((), device=device, dtype=torch.float64) + + # Freeze first N blocks (GEPA has no depth recurrence, freeze by block index) + n_blocks = len(base_model.blocks) + frozen_block_ids = set(range(min(args.ttt_freeze_blocks, n_blocks))) + + ttt_params = [] + for name, p in base_model.named_parameters(): + freeze = False + for bi in frozen_block_ids: + if f"blocks.{bi}." in name: + freeze = True + break + if freeze: + p.requires_grad_(False) + else: + p.requires_grad_(True) + ttt_params.append(p) + + log0(f"ttt_sliding:params unfrozen={sum(p.numel() for p in ttt_params)} " + f"frozen={sum(p.numel() for p in base_model.parameters() if not p.requires_grad)}") + + if args.ttt_optimizer == "adamw": + optimizer = torch.optim.AdamW(ttt_params, lr=args.ttt_lr, weight_decay=args.ttt_adam_wd) + else: + optimizer = torch.optim.SGD(ttt_params, lr=args.ttt_lr, momentum=args.ttt_momentum) + + t0 = time.perf_counter() + + for ci in range(num_chunks): + windows = chunk_windows[ci] + if not windows: + continue + + # --- Phase 1: SCORE this chunk's windows --- + my_s = (len(windows) * rank) // world_size + my_e = (len(windows) * (rank + 1)) // world_size + my_windows = windows[my_s:my_e] + + base_model.eval() + with torch.inference_mode(): + for bi in range(0, len(my_windows), batch_seqs): + batch_ws = my_windows[bi:bi + batch_seqs] + bsz = len(batch_ws) + x_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) + y_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) + wlens: list[int] = [] + for i, ws in 
enumerate(batch_ws): + end = min(ws + seq_len, total_tokens) + wlen = end - ws + wlens.append(wlen) + chunk_tok = val_tokens[ws:end + 1].to(dtype=torch.int64, device=device) + x_batch[i, :wlen] = chunk_tok[:-1] + y_batch[i, :wlen] = chunk_tok[1:] + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + logits = base_model.forward_logits(x_batch) + nll = F.cross_entropy( + logits.reshape(-1, logits.size(-1)).float(), + y_batch.reshape(-1), reduction="none", + ).reshape(bsz, seq_len) + for i, ws in enumerate(batch_ws): + wlen = wlens[i] + s = 0 if ws == 0 else max(wlen - stride, 0) + scored_nll = nll[i, s:wlen].to(torch.float64) + loss_sum += scored_nll.sum() + token_count += float(wlen - s) + tgt, prev = y_batch[i, s:wlen], x_batch[i, s:wlen] + tb = base_bytes_lut[tgt].to(torch.float64) + tb += (has_leading_space_lut[tgt] & ~is_boundary_token_lut[prev]).to(torch.float64) + byte_count += tb.sum() + + # --- Phase 2: TRAIN on this chunk's tokens (already scored = legal) --- + is_last_chunk = (ci == num_chunks - 1) + if not is_last_chunk and args.ttt_epochs > 0: + base_model.train() + chunk_start = ci * ttt_chunk + chunk_end = min((ci + 1) * ttt_chunk, total_tokens) + chunk_seqs = (chunk_end - chunk_start) // seq_len + if chunk_seqs > 0: + cos_lr = args.ttt_lr * 0.5 * (1.0 + math.cos(math.pi * ci / max(num_chunks - 1, 1))) + for pg in optimizer.param_groups: + pg['lr'] = cos_lr + + my_seq_s = (chunk_seqs * rank) // world_size + my_seq_e = (chunk_seqs * (rank + 1)) // world_size + my_chunk_seqs = my_seq_e - my_seq_s + + for _ep in range(args.ttt_epochs): + for bs in range(0, my_chunk_seqs, args.ttt_batch_seqs): + be = min(bs + args.ttt_batch_seqs, my_chunk_seqs) + actual_bs = my_seq_s + bs + actual_be = my_seq_s + be + start_tok = chunk_start + actual_bs * seq_len + end_tok = chunk_start + actual_be * seq_len + 1 + if end_tok > val_tokens.numel(): + continue + local = val_tokens[start_tok:end_tok].to(device=device, dtype=torch.int64) + x = 
local[:-1].reshape(-1, seq_len) + y = local[1:].reshape(-1, seq_len) + optimizer.zero_grad(set_to_none=True) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + loss = base_model(x, y) + loss.backward() + if world_size > 1: + for p in ttt_params: + if p.grad is not None: + dist.all_reduce(p.grad, op=dist.ReduceOp.AVG) + if args.ttt_grad_clip > 0: + torch.nn.utils.clip_grad_norm_(ttt_params, args.ttt_grad_clip) + optimizer.step() + + if rank == 0 and (ci % 10 == 0 or ci == num_chunks - 1): + elapsed = time.perf_counter() - t0 + rl = loss_sum.item() / max(token_count.item(), 1) + rbpb = rl / math.log(2.0) * (token_count.item() / max(byte_count.item(), 1)) if token_count.item() > 0 else 0.0 + log0(f" ttt_chunk [{ci+1}/{num_chunks}] bpb={rbpb:.6f} time={elapsed:.1f}s") + + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(loss_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(token_count, op=dist.ReduceOp.SUM) + dist.all_reduce(byte_count, op=dist.ReduceOp.SUM) + + val_loss = (loss_sum / token_count).item() + val_bpb = val_loss / math.log(2.0) * (token_count.item() / byte_count.item()) + + for p in base_model.parameters(): + p.requires_grad_(True) + base_model.eval() + + log0(f"ttt_sliding:done val_loss={val_loss:.6f} val_bpb={val_bpb:.6f} " + f"elapsed={time.perf_counter() - t0:.1f}s") + return val_loss, val_bpb + + +# ----------------------------- +# POST-TRAINING INT6 QUANTIZATION +# ----------------------------- + +INT6_MIN = -32 +INT6_MAX = 31 +INT6_CLIP_PERCENTILE = 99.99984 +INT6_CLIP_Q = INT6_CLIP_PERCENTILE / 100.0 + +CONTROL_TENSOR_NAME_PATTERNS = tuple( + pattern + for pattern in os.environ.get( + "CONTROL_TENSOR_NAME_PATTERNS", + "attn_scale,attn_scales,mlp_scale,mlp_scales,resid_mix,resid_mixes,q_gain,skip_weight,skip_weights,smear_gate,bigram,skip_gates,ve_shared,ve_layer_scales", + ).split(",") + if pattern +) +INT6_KEEP_FLOAT_FP32_NAME_PATTERNS = tuple( + pattern + for pattern in os.environ.get( + 
"INT6_KEEP_FLOAT_FP32_NAME_PATTERNS", + ",".join(CONTROL_TENSOR_NAME_PATTERNS), + ).split(",") + if pattern +) +INT6_KEEP_FLOAT_MAX_NUMEL = 65_536 +INT6_KEEP_FLOAT_STORE_DTYPE = torch.float16 +INT6_PER_ROW_SCALE_DTYPE = torch.float16 + +def tensor_nbytes(t: Tensor) -> int: + return int(t.numel()) * int(t.element_size()) + +def keep_float_tensor(name: str, t: Tensor, passthrough_orig_dtypes: dict[str, str]) -> Tensor: + if any(pattern in name for pattern in INT6_KEEP_FLOAT_FP32_NAME_PATTERNS): + return t.float().contiguous() + if t.dtype in {torch.float32, torch.bfloat16}: + passthrough_orig_dtypes[name] = str(t.dtype).removeprefix("torch.") + return t.to(dtype=INT6_KEEP_FLOAT_STORE_DTYPE).contiguous() + return t + +# Storage format: INT8_STORAGE=1 uses [-127,127] for better accuracy (larger artifact) +# Default (int6): uses [-32,31] for smaller artifacts but more quantization noise +# MIXED_QUANT=1: int6 per-row for MLP+attn (QAT-trained), int8 per-tensor for rest +INT8_STORAGE = int(os.environ.get("INT8_STORAGE", "0")) +MIXED_QUANT = int(os.environ.get("MIXED_QUANT", "0")) +QUANT_MIN = -127 if INT8_STORAGE else INT6_MIN +QUANT_MAX = 127 if INT8_STORAGE else INT6_MAX + +def _classify_param(name: str) -> str: + """Classify parameter by category for mixed quantization.""" + if "tok_emb" in name or "lm_head" in name: + return "embed" + if ".mlp." in name: + return "mlp" + if ".attn." in name or (".proj." in name and ".mlp." 
not in name): + return "attn" + return "other" + +def quantize_float_tensor_int8_scalar(t: Tensor) -> tuple[Tensor, Tensor]: + """Quantize to int8 [-127,127] range with per-tensor scalar scale.""" + t32 = t.float() + amax = t32.abs().max().item() + scale = torch.tensor(amax / 127.0 if amax > 0 else 1.0, dtype=torch.float16) + q = torch.clamp(torch.round(t32 / scale.float()), -127, 127).to(torch.int8) + return q.contiguous(), scale + +def quantize_float_tensor_int6(t: Tensor) -> tuple[Tensor, Tensor]: + """Quantize to int6 [-32,31] or int8 [-127,127] range, stored as int8. + Uses GPTQ-lite: per-row optimal clip percentile search (5 candidates) + to minimize reconstruction MSE. Free improvement over fixed percentile.""" + t32 = t.float() + qmin, qmax = QUANT_MIN, QUANT_MAX + if t32.ndim == 2 and t32.numel() > 0: + # GPTQ-lite: try 5 clip percentiles per-row, pick minimum MSE + best_q, best_s, best_err = None, None, float('inf') + for pct in [0.9990, 0.9995, 0.9999, 0.99999, 1.0]: + if pct < 1.0: + row_clip = torch.quantile(t32.abs(), pct, dim=1) + else: + row_clip = t32.abs().amax(dim=1) + scale = (row_clip / float(qmax)).clamp_min(1.0 / float(qmax)) + q = torch.clamp(torch.round(t32 / scale[:, None]), qmin, qmax).to(torch.int8) + recon = q.float() * scale[:, None] + err = (t32 - recon).pow(2).mean().item() + if err < best_err: + best_q = q.contiguous() + best_s = scale.to(dtype=INT6_PER_ROW_SCALE_DTYPE).contiguous() + best_err = err + return best_q, best_s + + clip_abs = float(torch.quantile(t32.abs().flatten(), INT6_CLIP_Q).item()) if t32.numel() else 0.0 + scale = torch.tensor(clip_abs / float(qmax) if clip_abs > 0 else 1.0, dtype=torch.float32) + q = torch.clamp(torch.round(torch.clamp(t32, -clip_abs, clip_abs) / scale), qmin, qmax).to(torch.int8).contiguous() + return q, scale + +def quantize_state_dict_int6(state_dict: dict[str, Tensor]): + quantized: dict[str, Tensor] = {} + scales: dict[str, Tensor] = {} + dtypes: dict[str, str] = {} + passthrough: dict[str, 
Tensor] = {} + passthrough_orig_dtypes: dict[str, str] = {} + qmeta: dict[str, dict[str, object]] = {} + stats = dict.fromkeys( + ("param_count", "num_tensors", "num_float_tensors", "num_nonfloat_tensors", "baseline_tensor_bytes", "int6_payload_bytes"), + 0, + ) + + for name, tensor in state_dict.items(): + t = tensor.detach().to("cpu").contiguous() + stats["param_count"] += int(t.numel()) + stats["num_tensors"] += 1 + stats["baseline_tensor_bytes"] += tensor_nbytes(t) + + if not t.is_floating_point(): + stats["num_nonfloat_tensors"] += 1 + passthrough[name] = t + stats["int6_payload_bytes"] += tensor_nbytes(t) + continue + + # Keep small float tensors (and tok_emb.weight unless QUANT_EMBED=1) in fp16 + quant_embed = int(os.environ.get("QUANT_EMBED", "0")) + if t.numel() <= INT6_KEEP_FLOAT_MAX_NUMEL or (name == "tok_emb.weight" and not quant_embed): + kept = keep_float_tensor(name, t, passthrough_orig_dtypes) + passthrough[name] = kept + stats["int6_payload_bytes"] += tensor_nbytes(kept) + continue + + stats["num_float_tensors"] += 1 + # Mixed quantization: int6 per-row for MLP+attn, int8 per-tensor for rest + cat = _classify_param(name) + if MIXED_QUANT and cat not in ("mlp", "attn"): + q, s = quantize_float_tensor_int8_scalar(t) + qmeta[name] = {"scheme": "per_tensor", "quant_type": "int8"} + else: + q, s = quantize_float_tensor_int6(t) + if s.ndim > 0: + qmeta[name] = {"scheme": "per_row", "axis": 0} + quantized[name] = q + scales[name] = s + dtypes[name] = str(t.dtype).removeprefix("torch.") + stats["int6_payload_bytes"] += tensor_nbytes(q) + tensor_nbytes(s) + + quant_label = "mixed" if MIXED_QUANT else ("int8" if INT8_STORAGE else "int6") + obj: dict[str, object] = { + "__quant_format__": f"{quant_label}_clean_per_row_v1", + "quantized": quantized, + "scales": scales, + "dtypes": dtypes, + "passthrough": passthrough, + } + if qmeta: + obj["qmeta"] = qmeta + if passthrough_orig_dtypes: + obj["passthrough_orig_dtypes"] = passthrough_orig_dtypes + return obj, 
stats + +def dequantize_state_dict_int6(obj: dict[str, object]) -> dict[str, Tensor]: + out: dict[str, Tensor] = {} + qmeta = obj.get("qmeta", {}) + passthrough_orig_dtypes = obj.get("passthrough_orig_dtypes", {}) + for name, q in obj["quantized"].items(): + dtype = getattr(torch, obj["dtypes"][name]) + s = obj["scales"][name] + if qmeta.get(name, {}).get("scheme") == "per_row" or s.ndim > 0: + s = s.to(dtype=torch.float32) + out[name] = (q.float() * s.view(q.shape[0], *([1] * (q.ndim - 1)))).to(dtype=dtype).contiguous() + else: + scale = float(s.item()) + out[name] = (q.float() * scale).to(dtype=dtype).contiguous() + for name, t in obj["passthrough"].items(): + out_t = t.detach().to("cpu").contiguous() + orig_dtype = passthrough_orig_dtypes.get(name) + if isinstance(orig_dtype, str): + out_t = out_t.to(dtype=getattr(torch, orig_dtype)).contiguous() + out[name] = out_t + return out + + +# ----------------------------- +# DATA LOADING +# ----------------------------- + +def load_data_shard(file: Path) -> Tensor: + # Shard layout (assumed from the fineweb_train_*.bin convention): 256 x int32 header + # with the token count at header[2], followed by uint16 token ids. + header_bytes = 256 * np.dtype("<i4").itemsize + with file.open("rb") as f: + header = np.frombuffer(f.read(header_bytes), dtype="<i4") + num_tokens = int(header[2]) + tokens = np.frombuffer(f.read(), dtype=np.uint16, count=num_tokens) + return torch.from_numpy(tokens.astype(np.int32)) + + +class TokenStream: + def __init__(self, pattern: str): + p = Path(pattern) + self.files = sorted(p.parent.glob(p.name)) + if not self.files: + raise FileNotFoundError(f"no token shards match pattern: {pattern}") + self.file_idx = 0 + self.tokens = load_data_shard(self.files[self.file_idx]) + self.pos = 0 + + def _advance_file(self) -> None: + self.file_idx = (self.file_idx + 1) % len(self.files) + self.tokens = load_data_shard(self.files[self.file_idx]) + self.pos = 0 + + def take(self, n: int) -> Tensor: + chunks: list[Tensor] = [] + remaining = n + while remaining > 0: + avail = self.tokens.numel() - self.pos + if avail <= 0: + self._advance_file() + continue + k = min(remaining, avail) + chunks.append(self.tokens[self.pos : self.pos + k]) + self.pos += k + remaining -= k + return chunks[0] if len(chunks) == 1 else torch.cat(chunks) + + +class DistributedTokenLoader: + def __init__(self, pattern: str, rank: int, world_size: int, device: torch.device): + self.rank = rank + self.world_size = world_size + self.device = device + self.stream = TokenStream(pattern) + + def next_batch(self, global_tokens: int, seq_len: int, grad_accum_steps: int) -> tuple[Tensor, Tensor]: + local_tokens = global_tokens // (self.world_size * grad_accum_steps) + per_rank_span = 
local_tokens + 1 + chunk = self.stream.take(per_rank_span * self.world_size) + start = self.rank * per_rank_span + local = chunk[start : start + per_rank_span].to(dtype=torch.int64) + x = local[:-1].reshape(-1, seq_len) + y = local[1:].reshape(-1, seq_len) + return x.to(self.device, non_blocking=True), y.to(self.device, non_blocking=True) + + +# ----------------------------- +# TRANSFORMER MODULES +# ----------------------------- + +class RMSNorm(nn.Module): + def __init__(self, eps: float | None = None): + super().__init__() + self.eps = eps + + def forward(self, x: Tensor) -> Tensor: + return F.rms_norm(x, (x.size(-1),), eps=self.eps) + + +class CastedLinear(nn.Linear): + # Class-level flag: set True during late-QAT phase to enable fake int6 STE + _qat_enabled: bool = False + + def forward(self, x: Tensor) -> Tensor: + w = self.weight.to(x.dtype) + if CastedLinear._qat_enabled and self.training and w.ndim == 2: + # Fake int6 quantization via straight-through estimator + with torch.no_grad(): + w32 = self.weight.float() + row_max = w32.abs().amax(dim=1) + scale = (row_max / 31.0).clamp_min(1.0 / 31.0) + w_q = (torch.clamp(torch.round(w32 / scale[:, None]), -32, 31) * scale[:, None]).to(x.dtype) + w = w + (w_q - w).detach() + bias = self.bias.to(x.dtype) if self.bias is not None else None + return F.linear(x, w, bias) + + +def restore_low_dim_params_to_fp32(module: nn.Module) -> None: + with torch.no_grad(): + for name, param in module.named_parameters(): + if (param.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS)) and param.dtype != torch.float32: + param.data = param.data.float() + + +class Rotary(nn.Module): + """RoPE with optional partial application and YARN scaling.""" + def __init__(self, dim: int, base: float = 10000.0, rope_dims: int = 0, train_seq_len: int = 1024): + super().__init__() + # rope_dims=0 means full head_dim; otherwise rotate only first rope_dims dims + rope_d = rope_dims if rope_dims > 0 else dim + self.rope_d = 
rope_d + self.base = base + self.train_seq_len = train_seq_len + inv_freq = 1.0 / (base ** (torch.arange(0, rope_d, 2, dtype=torch.float32) / rope_d)) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self._seq_len_cached = 0 + self._cos_cached: Tensor | None = None + self._sin_cached: Tensor | None = None + + def forward(self, seq_len: int, device: torch.device, dtype: torch.dtype) -> tuple[Tensor, Tensor]: + if ( + self._cos_cached is None + or self._sin_cached is None + or self._seq_len_cached != seq_len + or self._cos_cached.device != device + ): + rd = self.rope_d + if seq_len > self.train_seq_len: + scale = seq_len / self.train_seq_len + new_base = self.base * (scale ** (rd / (rd - 2))) + inv_freq = 1.0 / (new_base ** (torch.arange(0, rd, 2, dtype=torch.float32, device=device) / rd)) + else: + inv_freq = self.inv_freq.to(device) + t = torch.arange(seq_len, device=device, dtype=inv_freq.dtype) + freqs = torch.outer(t, inv_freq) + self._cos_cached = freqs.cos()[None, None, :, :] + self._sin_cached = freqs.sin()[None, None, :, :] + self._seq_len_cached = seq_len + return self._cos_cached.to(dtype=dtype), self._sin_cached.to(dtype=dtype) + + +def apply_rotary_emb(x: Tensor, cos: Tensor, sin: Tensor) -> Tensor: + """Apply RoPE; if cos covers fewer dims than x, rotate only those dims.""" + rd = cos.size(-1) * 2 + if rd < x.size(-1): + x_rope = x[..., :rd] + x_pass = x[..., rd:] + half = rd // 2 + x1 = x_rope[..., :half] + x2 = x_rope[..., half:] + x_rot = torch.cat((x1 * cos + x2 * sin, x1 * (-sin) + x2 * cos), dim=-1) + return torch.cat((x_rot, x_pass), dim=-1) + half = x.size(-1) // 2 + x1, x2 = x[..., :half], x[..., half:] + return torch.cat((x1 * cos + x2 * sin, x1 * (-sin) + x2 * cos), dim=-1) + + +class CausalSelfAttention(nn.Module): + def __init__(self, dim: int, num_heads: int, num_kv_heads: int, rope_base: float, qk_gain_init: float, rope_dims: int = 0, rope_train_seq_len: int = 1024): + super().__init__() + if dim % num_heads != 0: + raise 
ValueError("model_dim must be divisible by num_heads") + if num_heads % num_kv_heads != 0: + raise ValueError("num_heads must be divisible by num_kv_heads") + self.num_heads = num_heads + self.num_kv_heads = num_kv_heads + self.head_dim = dim // num_heads + if self.head_dim % 2 != 0: + raise ValueError("head_dim must be even for RoPE") + kv_dim = self.num_kv_heads * self.head_dim + self.c_q = CastedLinear(dim, dim, bias=False) + self.c_k = CastedLinear(dim, kv_dim, bias=False) + self.c_v = CastedLinear(dim, kv_dim, bias=False) + self.proj = CastedLinear(dim, dim, bias=False) + self.proj._zero_init = True + self.q_gain = nn.Parameter(torch.full((num_heads,), qk_gain_init, dtype=torch.float32)) + self.rotary = Rotary(self.head_dim, base=rope_base, rope_dims=rope_dims, train_seq_len=rope_train_seq_len) + self.use_xsa = False + + def _xsa_efficient(self, y: Tensor, v: Tensor) -> Tensor: + """Subtract self-value projection via GQA-aware reshape (no repeat_interleave).""" + B, T, H, D = y.shape + Hkv = v.size(-2) + group = H // Hkv + y_g = y.reshape(B, T, Hkv, group, D) + vn = F.normalize(v, dim=-1).unsqueeze(-2) + proj = (y_g * vn).sum(dim=-1, keepdim=True) * vn + return (y_g - proj).reshape(B, T, H, D) + + def forward(self, x: Tensor, v_embed: Tensor | None = None) -> Tensor: + bsz, seqlen, dim = x.shape + q = self.c_q(x).reshape(bsz, seqlen, self.num_heads, self.head_dim).transpose(1, 2) + k = self.c_k(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) + v = self.c_v(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) + # Add value embeddings to v before attention if provided + if v_embed is not None: + ve_reshaped = v_embed.reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) + v = v + ve_reshaped + q = F.rms_norm(q, (q.size(-1),)) + k = F.rms_norm(k, (k.size(-1),)) + cos, sin = self.rotary(seqlen, x.device, q.dtype) + q = apply_rotary_emb(q, cos, sin) + k = apply_rotary_emb(k, cos, sin) + q = q * 
self.q_gain.to(dtype=q.dtype)[None, :, None, None] + y = F.scaled_dot_product_attention(q, k, v, attn_mask=None, is_causal=True, enable_gqa=(self.num_kv_heads != self.num_heads)) + if self.use_xsa: + y_xsa = y.transpose(1, 2) + v_xsa = v.transpose(1, 2) + y_xsa = self._xsa_efficient(y_xsa, v_xsa) + y = y_xsa.reshape(bsz, seqlen, dim) + else: + y = y.transpose(1, 2).contiguous().reshape(bsz, seqlen, dim) + return self.proj(y) + + +class MLP(nn.Module): + def __init__(self, dim: int, mlp_mult: int, mlp_hidden: int = 0): + super().__init__() + # Star-ReLU implementation. mlp_mult is unused. + hidden = mlp_hidden if mlp_hidden > 0 else int(dim * 3) + self.up_proj = CastedLinear(dim, hidden, bias=False) + self.down_proj = CastedLinear(hidden, dim, bias=False) + self.down_proj._zero_init = True + self.scale = nn.Parameter(torch.ones(hidden, dtype=torch.float32)) + self.bias = nn.Parameter(torch.zeros(hidden, dtype=torch.float32)) + + def forward(self, x: Tensor) -> Tensor: + x_up = self.up_proj(x) + activated = F.relu(x_up).pow(2) + activated = activated * self.scale.to(dtype=activated.dtype) + self.bias.to(dtype=activated.dtype) + return self.down_proj(activated) + + +class Block(nn.Module): + def __init__( + self, + dim: int, + num_heads: int, + num_kv_heads: int, + mlp_mult: int, + rope_base: float, + qk_gain_init: float, + mlp_hidden: int = 0, + rope_dims: int = 0, + layer_idx: int = 0, + ln_scale: bool = False, + rope_train_seq_len: int = 1024, + ): + super().__init__() + self.attn_norm = RMSNorm() + self.mlp_norm = RMSNorm() + self.attn = CausalSelfAttention(dim, num_heads, num_kv_heads, rope_base, qk_gain_init, rope_dims=rope_dims, rope_train_seq_len=rope_train_seq_len) + self.mlp = MLP(dim, mlp_mult, mlp_hidden=mlp_hidden) + self.attn_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.mlp_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.resid_mix = nn.Parameter(torch.stack((torch.ones(dim), torch.zeros(dim))).float()) + # LN Scale: 
dampen norm inputs by 1/sqrt(layer_idx+1) for deeper layers + self.ln_scale_factor = 1.0 / math.sqrt(layer_idx + 1) if ln_scale else 1.0 + + def forward(self, x: Tensor, x0: Tensor, v_embed: Tensor | None = None) -> Tensor: + mix = self.resid_mix.to(dtype=x.dtype) + x = mix[0][None, None, :] * x + mix[1][None, None, :] * x0 + s = self.ln_scale_factor + attn_out = self.attn(self.attn_norm(x) * s, v_embed=v_embed) + x = x + self.attn_scale.to(dtype=x.dtype)[None, None, :] * attn_out + x = x + self.mlp_scale.to(dtype=x.dtype)[None, None, :] * self.mlp(self.mlp_norm(x) * s) + return x + + +# ----------------------------- +# BIGRAM HASH EMBEDDING +# ----------------------------- + +class BigramHashEmbedding(nn.Module): + """Hash-based bigram embedding with XOR hash and learned scale.""" + def __init__(self, num_buckets: int, embed_dim: int, model_dim: int): + super().__init__() + self.num_buckets = num_buckets + self.embed = nn.Embedding(num_buckets, embed_dim) + self.proj = CastedLinear(embed_dim, model_dim, bias=False) + nn.init.zeros_(self.embed.weight) # Zero init (signalrush convention) + nn.init.zeros_(self.proj.weight) + self.scale = nn.Parameter(torch.tensor(0.05, dtype=torch.float32)) + + def bigram_hash(self, tokens: Tensor) -> Tensor: + """XOR-based bigram hash with large primes for uniform distribution.""" + t = tokens.to(torch.int32) + mod = self.num_buckets - 1 + out = torch.empty_like(t) + out[..., 0] = mod # Special bucket for first position + out[..., 1:] = torch.bitwise_xor(36313 * t[..., 1:], 27191 * t[..., :-1]) % mod + return out.long() + + def forward(self, input_ids: Tensor) -> Tensor: + h = self.embed(self.bigram_hash(input_ids)) + h = self.proj(h) + return h * self.scale.to(dtype=h.dtype) + + +# ----------------------------- +# SMEAR GATE +# ----------------------------- + +class SmearGate(nn.Module): + """Learned blending of current position with previous position.""" + def __init__(self, dim: int): + super().__init__() + self.gate = 
nn.Parameter(torch.zeros(dim, dtype=torch.float32)) + + def forward(self, x: Tensor) -> Tensor: + # x: (bsz, seq_len, dim) + gate = torch.sigmoid(self.gate.to(dtype=x.dtype)) + # Shift x to get previous position, pad with zeros + x_prev = F.pad(x[:, :-1], (0, 0, 1, 0)) + return (1 - gate) * x + gate * x_prev + + +class ValueEmbedding(nn.Module): + """Reinject token identity into attention values at specific layers.""" + def __init__(self, vocab_size: int, ve_dim: int, kv_dim: int): + super().__init__() + self.embed = nn.Embedding(vocab_size, ve_dim) + nn.init.normal_(self.embed.weight, std=0.01) + self.proj = CastedLinear(ve_dim, kv_dim, bias=False) if ve_dim != kv_dim else None + if self.proj is not None: + nn.init.zeros_(self.proj.weight) + self.scale = nn.Parameter(torch.tensor(0.1, dtype=torch.float32)) + + def forward(self, token_ids: Tensor) -> Tensor: + h = self.embed(token_ids) + if self.proj is not None: + h = self.proj(h) + return h * self.scale.to(dtype=h.dtype) + + +class GPT(nn.Module): + def __init__( + self, + vocab_size: int, + num_layers: int, + model_dim: int, + num_heads: int, + num_kv_heads: int, + mlp_mult: int, + tie_embeddings: bool, + tied_embed_init_std: float, + logit_softcap: float, + rope_base: float, + qk_gain_init: float, + mlp_hidden: int = 0, + bigram_buckets: int = 4096, + bigram_embed_dim: int = 128, + rope_dims: int = 0, + ln_scale: bool = False, + xsa_last_n: int = 0, + ve_enabled: bool = False, + ve_dim: int = 128, + ve_layers: str = "9,10", + rope_train_seq_len: int = 1024, + ): + super().__init__() + if logit_softcap <= 0.0: + raise ValueError(f"logit_softcap must be positive, got {logit_softcap}") + self.tie_embeddings = tie_embeddings + self.tied_embed_init_std = tied_embed_init_std + self.logit_softcap = logit_softcap + self.tok_emb = nn.Embedding(vocab_size, model_dim) + self.bigram_emb = BigramHashEmbedding(bigram_buckets, bigram_embed_dim, model_dim) if bigram_buckets > 0 else None + self.smear_gate = 
SmearGate(model_dim) + self.num_encoder_layers = num_layers // 2 + self.num_decoder_layers = num_layers - self.num_encoder_layers + self.num_skip_weights = min(self.num_encoder_layers, self.num_decoder_layers) + self.skip_weights = nn.Parameter(torch.ones(self.num_skip_weights, model_dim, dtype=torch.float32)) + self.skip_gates = nn.Parameter(torch.zeros(self.num_skip_weights, model_dim, dtype=torch.float32)) + self.blocks = nn.ModuleList([ + Block( + model_dim, num_heads, num_kv_heads, mlp_mult, rope_base, qk_gain_init, + mlp_hidden=mlp_hidden, rope_dims=rope_dims, layer_idx=i, ln_scale=ln_scale, + rope_train_seq_len=rope_train_seq_len, + ) + for i in range(num_layers) + ]) + if xsa_last_n > 0: + for i in range(max(0, num_layers - xsa_last_n), num_layers): + self.blocks[i].attn.use_xsa = True + # Value Embeddings: reinject token identity into values at deep layers + kv_dim = model_dim // num_heads * num_kv_heads + self.ve_layer_indices = [int(x) for x in ve_layers.split(",") if x.strip()] if ve_enabled else [] + if self.ve_layer_indices: + self.ve_shared = ValueEmbedding(vocab_size, ve_dim, kv_dim) + self.ve_layer_scales = nn.ParameterList( + [nn.Parameter(torch.ones(1, dtype=torch.float32)) for _ in self.ve_layer_indices] + ) + else: + self.ve_shared = None + self.final_norm = RMSNorm() + self.lm_head = None if tie_embeddings else CastedLinear(model_dim, vocab_size, bias=False) + if self.lm_head is not None: + self.lm_head._zero_init = True + self._init_weights() + + def _init_weights(self) -> None: + if self.tie_embeddings: + nn.init.normal_(self.tok_emb.weight, mean=0.0, std=self.tied_embed_init_std) + num_layers = len(self.blocks) + for name, module in self.named_modules(): + if isinstance(module, nn.Linear): + if getattr(module, "_zero_init", False): + nn.init.zeros_(module.weight) + elif module.weight.ndim == 2 and module.weight.shape[0] >= 64 and module.weight.shape[1] >= 64: + nn.init.orthogonal_(module.weight, gain=1.0) + if ".proj." 
in name or name.endswith(".proj"): + with torch.no_grad(): + module.weight.mul_(1.0 / math.sqrt(2 * num_layers)) + + def _get_ve(self, layer_idx: int, input_ids: Tensor, ve_cache: dict) -> Tensor | None: + if self.ve_shared is None or layer_idx not in self.ve_layer_indices: + return None + if 've' not in ve_cache: + ve_cache['ve'] = self.ve_shared(input_ids) + ve_base = ve_cache['ve'] + ve_idx = self.ve_layer_indices.index(layer_idx) + return ve_base * self.ve_layer_scales[ve_idx].to(dtype=ve_base.dtype) + + def forward(self, input_ids: Tensor, target_ids: Tensor) -> Tensor: + x = self.tok_emb(input_ids) + if self.bigram_emb is not None: + x = x + self.bigram_emb(input_ids) + x = F.rms_norm(x, (x.size(-1),)) + x = self.smear_gate(x) + x0 = x + skips: list[Tensor] = [] + ve_cache: dict = {} + + for i in range(self.num_encoder_layers): + x = self.blocks[i](x, x0, v_embed=self._get_ve(i, input_ids, ve_cache)) + skips.append(x) + for i in range(self.num_decoder_layers): + bi = self.num_encoder_layers + i + if skips: + skip = skips.pop() + gate = torch.sigmoid(self.skip_gates[i].to(dtype=x.dtype)) + scaled_skip = self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skip + x = gate[None, None, :] * x + (1.0 - gate[None, None, :]) * scaled_skip + x = self.blocks[bi](x, x0, v_embed=self._get_ve(bi, input_ids, ve_cache)) + + x = self.final_norm(x).reshape(-1, x.size(-1)) + targets = target_ids.reshape(-1) + if self.tie_embeddings: + logits_proj = F.linear(x, self.tok_emb.weight) + else: + if self.lm_head is None: + raise RuntimeError("lm_head is required when tie_embeddings=False") + logits_proj = self.lm_head(x) + logits = self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap) + return F.cross_entropy(logits.float(), targets, reduction="mean") + + def forward_logits(self, input_ids: Tensor) -> Tensor: + """Return logits (bsz, seq_len, vocab) without computing loss.""" + x = self.tok_emb(input_ids) + if self.bigram_emb is not None: + x = x + 
self.bigram_emb(input_ids) + x = F.rms_norm(x, (x.size(-1),)) + x = self.smear_gate(x) + x0 = x + skips: list[Tensor] = [] + ve_cache: dict = {} + for i in range(self.num_encoder_layers): + x = self.blocks[i](x, x0, v_embed=self._get_ve(i, input_ids, ve_cache)) + skips.append(x) + for i in range(self.num_decoder_layers): + bi = self.num_encoder_layers + i + if skips: + skip = skips.pop() + gate = torch.sigmoid(self.skip_gates[i].to(dtype=x.dtype)) + scaled_skip = self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skip + x = gate[None, None, :] * x + (1.0 - gate[None, None, :]) * scaled_skip + x = self.blocks[bi](x, x0, v_embed=self._get_ve(bi, input_ids, ve_cache)) + x = self.final_norm(x) + if self.tie_embeddings: + logits_proj = F.linear(x, self.tok_emb.weight) + else: + logits_proj = self.lm_head(x) + return self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap) + + + +# ----------------------------- +# TRAINING +# ----------------------------- + +def main() -> None: + global zeropower_via_newtonschulz5 + + code = Path(__file__).read_text(encoding="utf-8") + args = Hyperparameters() + zeropower_via_newtonschulz5 = torch.compile(zeropower_via_newtonschulz5) + + # ----------------------------- + # DISTRIBUTED + CUDA SETUP + # ----------------------------- + + distributed = "RANK" in os.environ and "WORLD_SIZE" in os.environ + rank = int(os.environ.get("RANK", "0")) + world_size = int(os.environ.get("WORLD_SIZE", "1")) + local_rank = int(os.environ.get("LOCAL_RANK", "0")) + if world_size <= 0: + raise ValueError(f"WORLD_SIZE must be positive, got {world_size}") + if 8 % world_size != 0: + raise ValueError(f"WORLD_SIZE={world_size} must divide 8 so grad_accum_steps stays integral") + grad_accum_steps = 8 // world_size + grad_scale = 1.0 / grad_accum_steps + if not torch.cuda.is_available(): + raise RuntimeError("CUDA is required") + device = torch.device("cuda", local_rank) + torch.cuda.set_device(device) + if distributed: + 
dist.init_process_group(backend="nccl", device_id=device) + dist.barrier() + master_process = rank == 0 + + torch.backends.cuda.matmul.allow_tf32 = True + torch.backends.cudnn.allow_tf32 = True + from torch.backends.cuda import enable_cudnn_sdp, enable_flash_sdp, enable_math_sdp, enable_mem_efficient_sdp + enable_cudnn_sdp(False) + enable_flash_sdp(True) + enable_mem_efficient_sdp(False) + enable_math_sdp(False) + + logfile = None + if master_process: + os.makedirs("logs", exist_ok=True) + logfile = f"logs/{args.run_id}.txt" + print(logfile) + + def log0(msg: str, console: bool = True) -> None: + if not master_process: + return + if console: + print(msg) + if logfile is not None: + with open(logfile, "a", encoding="utf-8") as f: + print(msg, file=f) + + log0(code, console=False) + log0("=" * 100, console=False) + log0(f"Running Python {sys.version}", console=False) + log0(f"Running PyTorch {torch.__version__}", console=False) + log0(subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, check=False).stdout, console=False) + log0("=" * 100, console=False) + + # ----------------------------- + # TOKENIZER + VALIDATION METRIC SETUP + # ----------------------------- + + random.seed(args.seed) + np.random.seed(args.seed) + torch.manual_seed(args.seed) + torch.cuda.manual_seed_all(args.seed) + + if not args.tokenizer_path.endswith(".model"): + raise ValueError(f"Script only setup for SentencePiece .model file: {args.tokenizer_path}") + sp = spm.SentencePieceProcessor(model_file=args.tokenizer_path) + if int(sp.vocab_size()) != args.vocab_size: + raise ValueError(f"VOCAB_SIZE={args.vocab_size} does not match tokenizer vocab_size={int(sp.vocab_size())}") + dataset_dir = Path(args.data_path).resolve() + actual_train_files = len(list(dataset_dir.glob("fineweb_train_*.bin"))) + val_tokens = load_validation_tokens(args.val_files, args.train_seq_len) + base_bytes_lut, has_leading_space_lut, is_boundary_token_lut = build_sentencepiece_luts(sp, 
args.vocab_size, device) + log0(f"val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path={args.tokenizer_path}") + log0(f"train_loader:dataset:{dataset_dir.name} train_shards:{actual_train_files}") + log0(f"val_loader:shards pattern={args.val_files} tokens:{val_tokens.numel() - 1}") + + # ----------------------------- + # MODEL + OPTIMIZER SETUP + # ----------------------------- + + CastedLinear._qat_enabled = False # start with QAT off; late_qat enables it mid-run + + base_model = GPT( + vocab_size=args.vocab_size, + num_layers=args.num_layers, + model_dim=args.model_dim, + num_heads=args.num_heads, + num_kv_heads=args.num_kv_heads, + mlp_mult=args.mlp_mult, + tie_embeddings=args.tie_embeddings, + tied_embed_init_std=args.tied_embed_init_std, + logit_softcap=args.logit_softcap, + rope_base=args.rope_base, + qk_gain_init=args.qk_gain_init, + mlp_hidden=args.mlp_hidden, + bigram_buckets=args.bigram_buckets, + bigram_embed_dim=args.bigram_embed_dim, + rope_dims=args.rope_dims, + ln_scale=args.ln_scale, + xsa_last_n=args.xsa_layers, + ve_enabled=args.ve_enabled, + ve_dim=args.ve_dim, + ve_layers=args.ve_layers, + rope_train_seq_len=args.rope_train_seq_len, + ).to(device).bfloat16() + for module in base_model.modules(): + if isinstance(module, CastedLinear): + module.float() + restore_low_dim_params_to_fp32(base_model) + compiled_model = torch.compile(base_model, dynamic=False, fullgraph=True) + model: nn.Module = DDP(compiled_model, device_ids=[local_rank], broadcast_buffers=False) if distributed else compiled_model + + # Differential LR setup + matrix_params_enc, scalar_params_enc = [], [] + matrix_params_dec, scalar_params_dec = [], [] + num_encoder_layers = base_model.num_encoder_layers + for i, block in enumerate(base_model.blocks): + is_decoder = i >= num_encoder_layers + for name, p in block.named_parameters(): + if p.ndim == 2 and not any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS): + (matrix_params_dec if is_decoder else 
+                 matrix_params_enc).append(p)
+            else:
+                (scalar_params_dec if is_decoder else scalar_params_enc).append(p)
+
+    # Non-block scalar parameters
+    other_scalar_params = [base_model.smear_gate.gate]
+    if base_model.bigram_emb is not None:
+        other_scalar_params.append(base_model.bigram_emb.embed.weight)
+    if base_model.skip_weights.numel() > 0:
+        other_scalar_params.append(base_model.skip_weights)
+    if hasattr(base_model, 'skip_gates') and base_model.skip_gates.numel() > 0:
+        other_scalar_params.append(base_model.skip_gates)
+    # Value Embedding parameters
+    if base_model.ve_shared is not None:
+        other_scalar_params.extend(list(base_model.ve_shared.parameters()))
+        other_scalar_params.extend(list(base_model.ve_layer_scales.parameters()))
+
+    token_lr = args.tied_embed_lr if args.tie_embeddings else args.embed_lr
+    optimizer_tok = torch.optim.AdamW(
+        [{"params": [base_model.tok_emb.weight], "lr": token_lr, "base_lr": token_lr}],
+        betas=(args.beta1, args.beta2),
+        eps=args.adam_eps,
+        weight_decay=args.adam_wd,
+        fused=True,
+    )
+
+    matrix_lr_dec = args.matrix_lr * args.decoder_lr_mult
+    optimizer_muon = Muon(
+        [
+            {'params': matrix_params_enc, 'lr': args.matrix_lr, 'base_lr': args.matrix_lr},
+            {'params': matrix_params_dec, 'lr': matrix_lr_dec, 'base_lr': matrix_lr_dec},
+        ],
+        lr=args.matrix_lr,
+        momentum=args.muon_momentum,
+        backend_steps=args.muon_backend_steps,
+        weight_decay=args.muon_wd,
+    )
+
+    scalar_lr_dec = args.scalar_lr * args.decoder_lr_mult
+    optimizer_scalar = torch.optim.AdamW(
+        [
+            {'params': scalar_params_enc, 'lr': args.scalar_lr, 'base_lr': args.scalar_lr},
+            {'params': scalar_params_dec, 'lr': scalar_lr_dec, 'base_lr': scalar_lr_dec},
+            {'params': other_scalar_params, 'lr': args.scalar_lr, 'base_lr': args.scalar_lr},
+        ],
+        betas=(args.beta1, args.beta2),
+        eps=args.adam_eps,
+        weight_decay=args.adam_wd,
+        fused=True,
+    )
+
+    optimizer_bigram_proj = Muon(
+        [base_model.bigram_emb.proj.weight],
+        lr=args.matrix_lr,
+        momentum=args.muon_momentum,
+        backend_steps=args.muon_backend_steps,
+        weight_decay=args.muon_wd,
+    )
+    for group in optimizer_bigram_proj.param_groups:
+        group["base_lr"] = args.matrix_lr
+
+    optimizers: list[torch.optim.Optimizer] = [optimizer_tok, optimizer_muon, optimizer_scalar, optimizer_bigram_proj]
+    if base_model.lm_head is not None:
+        optimizer_head = torch.optim.AdamW(
+            [{"params": [base_model.lm_head.weight], "lr": args.head_lr, "base_lr": args.head_lr}],
+            betas=(args.beta1, args.beta2),
+            eps=args.adam_eps,
+            weight_decay=args.adam_wd,
+            fused=True,
+        )
+        optimizers.insert(1, optimizer_head)
+
+    n_params = sum(p.numel() for p in base_model.parameters())
+    log0(f"model_params:{n_params}")
+    log0(f"world_size:{world_size} grad_accum_steps:{grad_accum_steps}")
+    log0("sdp_backends:cudnn=False flash=True mem_efficient=False math=False")
+    log0(f"attention_mode:gqa num_heads:{args.num_heads} num_kv_heads:{args.num_kv_heads}")
+    log0(f"tie_embeddings:{args.tie_embeddings} embed_lr:{token_lr} head_lr:{args.head_lr if base_model.lm_head is not None else 0.0} matrix_lr:{args.matrix_lr} scalar_lr:{args.scalar_lr} decoder_lr_mult:{args.decoder_lr_mult}")
+    log0(f"train_batch_tokens:{args.train_batch_tokens} train_seq_len:{args.train_seq_len} iterations:{args.iterations} warmup_steps:{args.warmup_steps} max_wallclock_seconds:{args.max_wallclock_seconds:.3f}")
+    log0(f"rope_dims:{args.rope_dims} rope_train_seq_len:{args.rope_train_seq_len} ln_scale:{args.ln_scale}")
+    log0(f"muon_wd:{args.muon_wd} adam_wd:{args.adam_wd} ema_enabled:{args.ema_enabled} late_qat:{args.late_qat}")
+    log0(f"bigram_buckets:{args.bigram_buckets} bigram_embed_dim:{args.bigram_embed_dim} seed:{args.seed}")
+    if args.ttt_enabled:
+        log0(f"ttt:enabled optimizer:{args.ttt_optimizer} lr:{args.ttt_lr} epochs:{args.ttt_epochs} "
+             f"freeze_blocks:{args.ttt_freeze_blocks} chunk_tokens:{args.ttt_chunk_tokens}")
+
+    # -----------------------------
+    # DATA LOADER & MODEL WARMUP
+    # -----------------------------
+
+    train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device)
+
+    def zero_grad_all() -> None:
+        for opt in optimizers:
+            opt.zero_grad(set_to_none=True)
+
+    max_wallclock_ms = 1000.0 * args.max_wallclock_seconds if args.max_wallclock_seconds > 0 else None
+
+    def lr_mul(step: int, elapsed_ms: float) -> float:
+        if args.warmdown_iters <= 0:
+            return 1.0
+        if max_wallclock_ms is None:
+            warmdown_start = max(args.iterations - args.warmdown_iters, 0)
+            return max((args.iterations - step) / max(args.warmdown_iters, 1), 0.0) if warmdown_start <= step < args.iterations else 1.0
+        step_ms = elapsed_ms / max(step, 1)
+        warmdown_ms = args.warmdown_iters * step_ms
+        remaining_ms = max(max_wallclock_ms - elapsed_ms, 0.0)
+        return remaining_ms / max(warmdown_ms, 1e-9) if remaining_ms <= warmdown_ms else 1.0
+
+    if args.warmup_steps > 0:
+        initial_model_state = {name: tensor.detach().cpu().clone() for name, tensor in base_model.state_dict().items()}
+        initial_optimizer_states = [copy.deepcopy(opt.state_dict()) for opt in optimizers]
+        model.train()
+        for warmup_step in range(args.warmup_steps):
+            zero_grad_all()
+            for micro_step in range(grad_accum_steps):
+                if distributed:
+                    model.require_backward_grad_sync = micro_step == grad_accum_steps - 1
+                x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps)
+                with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True):
+                    warmup_loss = model(x, y)
+                (warmup_loss * grad_scale).backward()
+            for opt in optimizers:
+                opt.step()
+            zero_grad_all()
+            if args.warmup_steps <= 20 or (warmup_step + 1) % 10 == 0 or warmup_step + 1 == args.warmup_steps:
+                log0(f"warmup_step:{warmup_step + 1}/{args.warmup_steps}")
+        base_model.load_state_dict(initial_model_state, strict=True)
+        for opt, state in zip(optimizers, initial_optimizer_states, strict=True):
+            opt.load_state_dict(state)
+        zero_grad_all()
+        if distributed:
+            model.require_backward_grad_sync = True
+        train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device)
+
+    # -----------------------------
+    # EMA / SWA STATE
+    # -----------------------------
+
+    # EMA takes priority; SWA is fallback (mutually exclusive)
+    ema_state: dict[str, Tensor] | None = None
+    if args.ema_enabled:
+        ema_state = {name: t.detach().float().clone() for name, t in base_model.state_dict().items()}
+        log0(f"ema:init decay={args.ema_decay}")
+
+    swa_state: dict[str, Tensor] = {}
+    swa_count = 0
+
+    def update_swa():
+        nonlocal swa_count
+        with torch.no_grad():
+            for name, param in base_model.state_dict().items():
+                if name not in swa_state:
+                    swa_state[name] = param.detach().cpu().clone().float()
+                else:
+                    swa_state[name].add_(param.detach().cpu().float())
+        swa_count += 1
+
+    def get_swa_state() -> dict[str, Tensor]:
+        return {name: (t / swa_count).to(dtype=base_model.state_dict()[name].dtype) for name, t in swa_state.items()}
+
+    # -----------------------------
+    # MAIN TRAINING LOOP
+    # -----------------------------
+
+    training_time_ms = 0.0
+    stop_after_step: int | None = None
+    torch.cuda.synchronize()
+    t0 = time.perf_counter()
+
+    # Estimate total steps for SWA start
+    estimated_total_steps = args.iterations
+    if max_wallclock_ms is not None:
+        estimated_total_steps = min(args.iterations, int(max_wallclock_ms / 30))  # rough estimate
+
+    step = 0
+    while True:
+        last_step = step == args.iterations or (stop_after_step is not None and step >= stop_after_step)
+
+        should_validate = last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0)
+        if should_validate:
+            torch.cuda.synchronize()
+            training_time_ms += 1000.0 * (time.perf_counter() - t0)
+            val_loss, val_bpb = eval_val(
+                args, model, rank, world_size, device, grad_accum_steps,
+                val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut,
+            )
+            log0(f"step:{step}/{args.iterations} val_loss:{val_loss:.4f} val_bpb:{val_bpb:.4f} train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms / max(step, 1):.2f}ms")
+            torch.cuda.synchronize()
+            t0 = time.perf_counter()
+
+        if last_step:
+            if stop_after_step is not None and step < args.iterations:
+                log0(f"stopping_early: wallclock_cap train_time:{training_time_ms:.0f}ms step:{step}/{args.iterations}")
+            break
+
+        elapsed_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0)
+        scale = lr_mul(step, elapsed_ms)
+
+        # Late QAT: enable fake int6 quantization once LR scale drops below threshold
+        if args.late_qat and not CastedLinear._qat_enabled and scale < args.qat_threshold:
+            CastedLinear._qat_enabled = True
+            log0(f"late_qat:enabled step:{step} scale:{scale:.4f}")
+
+        zero_grad_all()
+        train_loss = torch.zeros((), device=device)
+        for micro_step in range(grad_accum_steps):
+            if distributed:
+                model.require_backward_grad_sync = micro_step == grad_accum_steps - 1
+            x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps)
+            with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True):
+                loss = model(x, y)
+            train_loss += loss.detach()
+            (loss * grad_scale).backward()
+        train_loss /= grad_accum_steps
+
+        frac = min(step / args.muon_momentum_warmup_steps, 1.0) if args.muon_momentum_warmup_steps > 0 else 1.0
+        muon_momentum = (1 - frac) * args.muon_momentum_warmup_start + frac * args.muon_momentum
+        for group in optimizer_muon.param_groups:
+            group["momentum"] = muon_momentum
+        for group in optimizer_bigram_proj.param_groups:
+            group["momentum"] = muon_momentum
+
+        for opt in optimizers:
+            for group in opt.param_groups:
+                group["lr"] = group["base_lr"] * scale
+
+        if args.grad_clip_norm > 0:
+            torch.nn.utils.clip_grad_norm_(base_model.parameters(), args.grad_clip_norm)
+        for opt in optimizers:
+            opt.step()
+        zero_grad_all()
+
+        step += 1
+
+        # EMA update every step (takes priority over SWA)
+        if ema_state is not None:
+            d = args.ema_decay
+            with torch.no_grad():
+                for name, t in base_model.state_dict().items():
+                    ema_state[name].mul_(d).add_(t.detach().float(), alpha=1.0 - d)
+
+        # SWA update (only when EMA disabled)
+        swa_start_step = int(estimated_total_steps * args.swa_start_frac)
+        if ema_state is None and step >= swa_start_step and step % args.swa_every == 0:
+            update_swa()
+
+        approx_training_time_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0)
+        should_log_train = args.train_log_every > 0 and (step <= 10 or step % args.train_log_every == 0 or stop_after_step is not None)
+        if should_log_train:
+            log0(f"step:{step}/{args.iterations} train_loss:{train_loss.item():.4f} train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms / step:.2f}ms")
+
+        reached_cap = max_wallclock_ms is not None and approx_training_time_ms >= max_wallclock_ms
+        if distributed and max_wallclock_ms is not None:
+            reached_cap_tensor = torch.tensor(int(reached_cap), device=device)
+            dist.all_reduce(reached_cap_tensor, op=dist.ReduceOp.MAX)
+            reached_cap = bool(reached_cap_tensor.item())
+        if stop_after_step is None and reached_cap:
+            stop_after_step = step
+
+    # Final SWA update (only if EMA disabled and no SWA yet)
+    if ema_state is None and swa_count == 0:
+        update_swa()
+
+    log0(f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB")
+
+    # Apply EMA or SWA weights (EMA takes priority)
+    if ema_state is not None:
+        log0("ema:applying EMA weights")
+        avg_state = {name: t.to(dtype=base_model.state_dict()[name].dtype) for name, t in ema_state.items()}
+        del ema_state
+        base_model.load_state_dict(avg_state, strict=True)
+        del avg_state
+    elif swa_count > 0:
+        log0(f"swa:applying averaged {swa_count} checkpoints")
+        base_model.load_state_dict(get_swa_state(), strict=True)
+    else:
+        log0("weight_avg:skipped (no EMA or SWA state)")
+
+    # -----------------------------
+    # TTT: fine-tune on val data AFTER EMA/SWA, BEFORE quantization
+    # -----------------------------
+
+    # -----------------------------
+    # SERIALIZATION + ROUNDTRIP VALIDATION
+    # -----------------------------
+
+    if master_process:
+        torch.save(base_model.state_dict(), "final_model.pt")
+        model_bytes = os.path.getsize("final_model.pt")
+        code_bytes = len(code.encode("utf-8"))
+        log0(f"Serialized model: {model_bytes} bytes")
+        log0(f"Code size: {code_bytes} bytes")
+        log0(f"Total submission size: {model_bytes + code_bytes} bytes")
+
+    quant_obj, quant_stats = quantize_state_dict_int6(base_model.state_dict())
+    quant_buf = io.BytesIO()
+    torch.save(quant_obj, quant_buf)
+    quant_raw = quant_buf.getvalue()
+
+    # Use zstd-22 for compression (or zlib fallback)
+    if USE_ZSTD:
+        cctx = zstd.ZstdCompressor(level=22)
+        quant_blob = cctx.compress(quant_raw)
+        compression_method = "zstd-22"
+    else:
+        import zlib
+        quant_blob = zlib.compress(quant_raw, level=9)
+        compression_method = "zlib-9"
+
+    quant_raw_bytes = len(quant_raw)
+    if master_process:
+        with open("final_model.int6.ptz", "wb") as f:
+            f.write(quant_blob)
+        quant_file_bytes = os.path.getsize("final_model.int6.ptz")
+        code_bytes = len(code.encode("utf-8"))
+        ratio = quant_stats["baseline_tensor_bytes"] / max(quant_stats["int6_payload_bytes"], 1)
+        log0(f"Serialized model int6+{compression_method}: {quant_file_bytes} bytes (payload:{quant_stats['int6_payload_bytes']} raw_torch:{quant_raw_bytes} payload_ratio:{ratio:.2f}x)")
+        log0(f"Total submission size int6+{compression_method}: {quant_file_bytes + code_bytes} bytes")
+
+    if distributed:
+        dist.barrier()
+    with open("final_model.int6.ptz", "rb") as f:
+        quant_blob_disk = f.read()
+
+    # Decompress
+    if USE_ZSTD:
+        dctx = zstd.ZstdDecompressor()
+        quant_raw_disk = dctx.decompress(quant_blob_disk)
+    else:
+        import zlib
+        quant_raw_disk = zlib.decompress(quant_blob_disk)
+
+    quant_state = torch.load(io.BytesIO(quant_raw_disk), map_location="cpu")
+    base_model.load_state_dict(dequantize_state_dict_int6(quant_state), strict=True)
+    torch.cuda.synchronize()
+    t_qeval = time.perf_counter()
+    if args.ttt_enabled and args.eval_stride > 0 and args.eval_stride < args.train_seq_len:
+        log0(f"final_eval_mode:sliding_window_ttt stride:{args.eval_stride} "
+             f"chunk_tokens:{args.ttt_chunk_tokens} optimizer:{args.ttt_optimizer}")
+        q_val_loss, q_val_bpb = eval_val_sliding_ttt(
+            args, base_model, rank, world_size, device,
+            val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut,
+            stride=args.eval_stride, batch_seqs=args.eval_batch_seqs, log0=log0)
+    elif args.eval_stride > 0 and args.eval_stride < args.train_seq_len:
+        log0(f"final_eval_mode:sliding_window stride:{args.eval_stride}")
+        q_val_loss, q_val_bpb = eval_val_sliding(
+            args, base_model, rank, world_size, device, val_tokens,
+            base_bytes_lut, has_leading_space_lut, is_boundary_token_lut,
+            stride=args.eval_stride, batch_seqs=args.eval_batch_seqs,
+        )
+    else:
+        log0("final_eval_mode:standard")
+        q_val_loss, q_val_bpb = eval_val(
+            args, model, rank, world_size, device, grad_accum_steps,
+            val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut)
+    torch.cuda.synchronize()
+    log0(f"final_int6_{compression_method}_roundtrip val_loss:{q_val_loss:.4f} val_bpb:{q_val_bpb:.4f} eval:{1000.0*(time.perf_counter()-t_qeval):.0f}ms")
+    log0(f"final_int6_{compression_method}_roundtrip_exact val_bpb:{q_val_bpb:.8f}")
+
+    if distributed:
+        dist.destroy_process_group()
+
+if __name__ == "__main__":
+    main()
+====================================================================================================
+Running Python 3.12.7 | packaged by conda-forge | (main, Oct 4 2024, 16:05:46) [GCC 13.3.0]
+Running PyTorch 2.10.0+cu128
+Mon Mar 23 21:30:46 2026
++-----------------------------------------------------------------------------------------+
+| NVIDIA-SMI 570.172.08             Driver Version: 570.172.08     CUDA Version: 12.8     |
+|-----------------------------------------+------------------------+----------------------+
+| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
+| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
+|                                         |                        |               MIG M. |
+|=========================================+========================+======================|
+|   0  NVIDIA A100-PCIE-40GB          On  |   00000000:17:00.0 Off |                    0 |
+| N/A   46C    P0             50W /  250W |     667MiB /  40960MiB |      0%      Default |
+|                                         |                        |             Disabled |
++-----------------------------------------+------------------------+----------------------+
+|   1  NVIDIA A100-PCIE-40GB          On  |   00000000:65:00.0 Off |                    0 |
+| N/A   47C    P0             53W /  250W |     667MiB /  40960MiB |      1%      Default |
+|                                         |                        |             Disabled |
++-----------------------------------------+------------------------+----------------------+
+|   2  NVIDIA A100-PCIE-40GB          On  |   00000000:CA:00.0 Off |                    0 |
+| N/A   46C    P0             51W /  250W |     667MiB /  40960MiB |      3%      Default |
+|                                         |                        |             Disabled |
++-----------------------------------------+------------------------+----------------------+
+|   3  NVIDIA A100-PCIE-40GB          On  |   00000000:E3:00.0 Off |                    0 |
+| N/A   47C    P0             53W /  250W |     667MiB /  40960MiB |     10%      Default |
+|                                         |                        |             Disabled |
++-----------------------------------------+------------------------+----------------------+
+
++-----------------------------------------------------------------------------------------+
+| Processes:                                                                              |
+|  GPU   GI   CI              PID   Type   Process name                        GPU Memory |
+|        ID   ID                                                               Usage      |
+|=========================================================================================|
+|    0   N/A  N/A          150823      C   ...ameter_golf/.venv/bin/python3        658MiB |
+|    1   N/A  N/A          150824      C   ...ameter_golf/.venv/bin/python3        658MiB |
+|    2   N/A  N/A          150825      C   ...ameter_golf/.venv/bin/python3        658MiB |
+|    3   N/A  N/A          150827      C   ...ameter_golf/.venv/bin/python3        658MiB |
++-----------------------------------------------------------------------------------------+
+
+====================================================================================================
+val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path=/hpfs/scratch/gpfs/mcclec07/code/parameter_golf/repo/data/tokenizers/fineweb_1024_bpe.model
+train_loader:dataset:fineweb10B_sp1024 train_shards:80
+val_loader:shards pattern=/hpfs/scratch/gpfs/mcclec07/code/parameter_golf/repo/data/datasets/fineweb10B_sp1024/fineweb_val_*.bin tokens:62021632
+model_params:27030108
+world_size:4 grad_accum_steps:2
+sdp_backends:cudnn=False flash=True mem_efficient=False math=False
+attention_mode:gqa num_heads:8 num_kv_heads:4
+tie_embeddings:True embed_lr:0.035 head_lr:0.0 matrix_lr:0.025 scalar_lr:0.025 decoder_lr_mult:2.0
+train_batch_tokens:786432 train_seq_len:2048 iterations:7000 warmup_steps:20 max_wallclock_seconds:0.000
+rope_dims:16 rope_train_seq_len:1024 ln_scale:True
+muon_wd:0.04 adam_wd:0.04 ema_enabled:True late_qat:True
+bigram_buckets:2048 bigram_embed_dim:128 seed:42
+ttt:enabled optimizer:sgd lr:0.002 epochs:10 freeze_blocks:2 chunk_tokens:32768
+warmup_step:1/20
+warmup_step:2/20
+warmup_step:3/20
+warmup_step:4/20
+warmup_step:5/20
+warmup_step:6/20
+warmup_step:7/20
+warmup_step:8/20
+warmup_step:9/20
+warmup_step:10/20
+warmup_step:11/20
+warmup_step:12/20
+warmup_step:13/20
+warmup_step:14/20
+warmup_step:15/20
+warmup_step:16/20
+warmup_step:17/20
+warmup_step:18/20
+warmup_step:19/20
+warmup_step:20/20
+ema:init decay=0.997
+step:0/7000 val_loss:6.9301 val_bpb:4.1044 train_time:0ms step_avg:0.01ms
+step:1/7000 train_loss:6.9313 train_time:490ms step_avg:490.27ms
+step:2/7000 train_loss:8.4238 train_time:979ms step_avg:489.55ms
+step:3/7000 train_loss:7.5556 train_time:1478ms step_avg:492.83ms
+step:4/7000 train_loss:8.6123 train_time:1980ms step_avg:494.92ms
+step:5/7000 train_loss:9.0143 train_time:2477ms step_avg:495.36ms
+step:6/7000 train_loss:8.8303 train_time:2971ms step_avg:495.10ms
+step:7/7000 train_loss:8.2794 train_time:3471ms step_avg:495.90ms
+step:8/7000 train_loss:7.4327 train_time:3973ms step_avg:496.67ms
+step:9/7000 train_loss:6.8659 train_time:4477ms step_avg:497.42ms
+step:10/7000 train_loss:6.4393 train_time:4972ms step_avg:497.22ms
+step:200/7000 train_loss:2.5808 train_time:100029ms step_avg:500.15ms
+step:400/7000 train_loss:2.4553 train_time:200389ms step_avg:500.97ms
+step:500/7000 val_loss:2.3821 val_bpb:1.4108 train_time:250213ms step_avg:500.43ms
+step:600/7000 train_loss:2.3758 train_time:300013ms step_avg:500.02ms
+step:800/7000 train_loss:2.2630 train_time:399630ms step_avg:499.54ms
+step:1000/7000 train_loss:2.3020 train_time:499269ms step_avg:499.27ms
+step:1000/7000 val_loss:2.2514 val_bpb:1.3334 train_time:499278ms step_avg:499.28ms
+step:1200/7000 train_loss:2.2961 train_time:599114ms step_avg:499.26ms
+step:1400/7000 train_loss:2.2551 train_time:698793ms step_avg:499.14ms
+step:1500/7000 val_loss:2.2034 val_bpb:1.3050 train_time:748645ms step_avg:499.10ms
+step:1600/7000 train_loss:2.1487 train_time:798455ms step_avg:499.03ms
+step:1800/7000 train_loss:2.1519 train_time:898339ms step_avg:499.08ms
+step:2000/7000 train_loss:2.0396 train_time:998108ms step_avg:499.05ms
+step:2000/7000 val_loss:2.1463 val_bpb:1.2711 train_time:998116ms step_avg:499.06ms
+step:2200/7000 train_loss:2.1582 train_time:1097880ms step_avg:499.04ms
+step:2400/7000 train_loss:2.0984 train_time:1197471ms step_avg:498.95ms
+step:2500/7000 val_loss:2.1260 val_bpb:1.2592 train_time:1247279ms step_avg:498.91ms
+step:2600/7000 train_loss:2.1654 train_time:1297122ms step_avg:498.89ms
+step:2800/7000 train_loss:2.1999 train_time:1397127ms step_avg:498.97ms
+step:3000/7000 train_loss:2.1277 train_time:1496756ms step_avg:498.92ms
+step:3000/7000 val_loss:2.1122 val_bpb:1.2509 train_time:1496764ms step_avg:498.92ms
+step:3200/7000 train_loss:2.1651 train_time:1596443ms step_avg:498.89ms
+step:3400/7000 train_loss:2.1145 train_time:1696169ms step_avg:498.87ms
+step:3500/7000 val_loss:2.1064 val_bpb:1.2475 train_time:1746216ms step_avg:498.92ms
+step:3600/7000 train_loss:2.1113 train_time:1796178ms step_avg:498.94ms
+step:3800/7000 train_loss:2.0975 train_time:1895885ms step_avg:498.92ms
+step:4000/7000 train_loss:2.1540 train_time:1995522ms step_avg:498.88ms
+step:4000/7000 val_loss:2.0859 val_bpb:1.2354 train_time:1995531ms step_avg:498.88ms
+step:4200/7000 train_loss:2.0917 train_time:2095260ms step_avg:498.87ms
+step:4400/7000 train_loss:2.0158 train_time:2194994ms step_avg:498.86ms
+step:4500/7000 val_loss:2.0651 val_bpb:1.2231 train_time:2244995ms step_avg:498.89ms
+step:4600/7000 train_loss:1.9621 train_time:2294720ms step_avg:498.85ms
+step:4800/7000 train_loss:2.2581 train_time:2394137ms step_avg:498.78ms
+step:5000/7000 train_loss:2.0617 train_time:2493859ms step_avg:498.77ms
+step:5000/7000 val_loss:2.0437 val_bpb:1.2104 train_time:2493868ms step_avg:498.77ms
+step:5200/7000 train_loss:2.0638 train_time:2593605ms step_avg:498.77ms
+step:5400/7000 train_loss:1.9975 train_time:2693386ms step_avg:498.78ms
+step:5500/7000 val_loss:2.0211 val_bpb:1.1970 train_time:2743156ms step_avg:498.76ms
+step:5600/7000 train_loss:2.0094 train_time:2792858ms step_avg:498.72ms
+step:5800/7000 train_loss:1.9736 train_time:2892326ms step_avg:498.68ms
+step:6000/7000 train_loss:1.9671 train_time:2992131ms step_avg:498.69ms
+step:6000/7000 val_loss:1.9981 val_bpb:1.1834 train_time:2992140ms step_avg:498.69ms
+step:6200/7000 train_loss:2.0387 train_time:3091840ms step_avg:498.68ms
+step:6400/7000 train_loss:1.9905 train_time:3191372ms step_avg:498.65ms
+late_qat:enabled step:6476 scale:0.1497
+step:6500/7000 val_loss:1.9657 val_bpb:1.1642 train_time:3241122ms step_avg:498.63ms
+step:6600/7000 train_loss:1.8821 train_time:3290897ms step_avg:498.62ms
+step:6800/7000 train_loss:2.0409 train_time:3390711ms step_avg:498.63ms
+step:7000/7000 train_loss:1.8150 train_time:3490351ms step_avg:498.62ms
+step:7000/7000 val_loss:1.9376 val_bpb:1.1476 train_time:3490359ms step_avg:498.62ms
+peak memory allocated: 21327 MiB reserved: 21512 MiB
+ema:applying EMA weights
+Serialized model: 107123447 bytes
+Code size: 76429 bytes
+Total submission size: 107199876 bytes
+Serialized model int6+zstd-22: 15626769 bytes (payload:27522422 raw_torch:27588053 payload_ratio:3.89x)
+Total submission size int6+zstd-22: 15703198 bytes
+final_eval_mode:sliding_window_ttt stride:64 chunk_tokens:32768 optimizer:sgd
+ttt_sliding:start chunks=1893 chunk_tokens=32768 total_windows=969088 stride=64 ttt_lr=0.002 ttt_epochs=10 ttt_optimizer=sgd freeze_blocks=2
+ttt_sliding:params unfrozen=22301260 frozen=4728848
+ ttt_chunk [1/1893] bpb=1.174449 time=1.3s
+ ttt_chunk [11/1893] bpb=1.131772 time=12.8s
+ ttt_chunk [21/1893] bpb=1.136036 time=24.3s
+ ttt_chunk [31/1893] bpb=1.138824 time=35.8s
+ ttt_chunk [41/1893] bpb=1.128172 time=47.4s
+ ttt_chunk [51/1893] bpb=1.125337 time=58.9s
+ ttt_chunk [61/1893] bpb=1.131113 time=70.5s
+ ttt_chunk [71/1893] bpb=1.128140 time=82.1s
+ ttt_chunk [81/1893] bpb=1.128327 time=93.6s
+ ttt_chunk [91/1893] bpb=1.127725 time=105.2s
+ ttt_chunk [101/1893] bpb=1.130981 time=116.8s
+ ttt_chunk [111/1893] bpb=1.132478 time=128.5s
+ ttt_chunk [121/1893] bpb=1.129591 time=140.1s
+ ttt_chunk [131/1893] bpb=1.129946 time=151.7s
+ ttt_chunk [141/1893] bpb=1.129857 time=163.3s
+ ttt_chunk [151/1893] bpb=1.133145 time=174.8s
+ ttt_chunk [161/1893] bpb=1.135286 time=186.4s
+ ttt_chunk [171/1893] bpb=1.136401 time=198.0s
+ ttt_chunk [181/1893] bpb=1.136577 time=209.6s
+ ttt_chunk [191/1893] bpb=1.140083 time=221.2s
+ ttt_chunk [201/1893] bpb=1.140433 time=232.8s
+ ttt_chunk [211/1893] bpb=1.138257 time=244.4s
+ ttt_chunk [221/1893] bpb=1.140323 time=255.9s
+ ttt_chunk [231/1893] bpb=1.139984 time=267.5s
+ ttt_chunk [241/1893] bpb=1.139897 time=279.1s
+ ttt_chunk [251/1893] bpb=1.138439 time=290.7s
+ ttt_chunk [261/1893] bpb=1.136828 time=302.3s
+ ttt_chunk [271/1893] bpb=1.135533 time=313.8s
+ ttt_chunk [281/1893] bpb=1.137994 time=325.4s
+ ttt_chunk [291/1893] bpb=1.138875 time=337.0s
+ ttt_chunk [301/1893] bpb=1.139468 time=348.6s
+ ttt_chunk [311/1893] bpb=1.141148 time=360.1s
+ ttt_chunk [321/1893] bpb=1.142669 time=371.7s
+ ttt_chunk [331/1893] bpb=1.142534 time=383.3s
+ ttt_chunk [341/1893] bpb=1.142923 time=394.8s
+ ttt_chunk [351/1893] bpb=1.144321 time=406.4s
+ ttt_chunk [361/1893] bpb=1.145631 time=418.0s
+ ttt_chunk [371/1893] bpb=1.145187 time=429.5s
+ ttt_chunk [381/1893] bpb=1.145098 time=441.1s
+ ttt_chunk [391/1893] bpb=1.144739 time=452.7s
+ ttt_chunk [401/1893] bpb=1.143393 time=464.2s
+ ttt_chunk [411/1893] bpb=1.142374 time=475.8s
+ ttt_chunk [421/1893] bpb=1.141834 time=487.4s
+ ttt_chunk [431/1893] bpb=1.142580 time=498.9s
+ ttt_chunk [441/1893] bpb=1.142468 time=510.5s
+ ttt_chunk [451/1893] bpb=1.142226 time=522.1s
+ ttt_chunk [461/1893] bpb=1.141511 time=533.6s
+ ttt_chunk [471/1893] bpb=1.141128 time=545.2s
+ ttt_chunk [481/1893] bpb=1.140896 time=556.7s
+ ttt_chunk [491/1893] bpb=1.140638 time=568.3s
+ ttt_chunk [501/1893] bpb=1.140157 time=579.8s
+ ttt_chunk [511/1893] bpb=1.139646 time=591.4s
+ ttt_chunk [521/1893] bpb=1.138782 time=603.0s
+ ttt_chunk [531/1893] bpb=1.138896 time=614.5s
+ ttt_chunk [541/1893] bpb=1.138847 time=626.1s
+ ttt_chunk [551/1893] bpb=1.137660 time=637.7s
+ ttt_chunk [561/1893] bpb=1.138154 time=649.2s
+ ttt_chunk [571/1893] bpb=1.137454 time=660.8s
+ ttt_chunk [581/1893] bpb=1.136930 time=672.3s
+ ttt_chunk [591/1893] bpb=1.136281 time=683.9s
+ ttt_chunk [601/1893] bpb=1.137008 time=695.5s
+ ttt_chunk [611/1893] bpb=1.136609 time=707.0s
+ ttt_chunk [621/1893] bpb=1.136580 time=718.6s
+ ttt_chunk [631/1893] bpb=1.136996 time=730.2s
+ ttt_chunk [641/1893] bpb=1.136793 time=741.7s
+ ttt_chunk [651/1893] bpb=1.136826 time=753.3s
+ ttt_chunk [661/1893] bpb=1.136755 time=764.9s
+ ttt_chunk [671/1893] bpb=1.136399 time=776.4s
+ ttt_chunk [681/1893] bpb=1.136668 time=788.1s
+ ttt_chunk [691/1893] bpb=1.137390 time=799.9s
+ ttt_chunk [701/1893] bpb=1.136588 time=811.4s
+ ttt_chunk [711/1893] bpb=1.137210 time=823.0s
+ ttt_chunk [721/1893] bpb=1.136866 time=834.5s
+ ttt_chunk [731/1893] bpb=1.137361 time=846.1s
+ ttt_chunk [741/1893] bpb=1.137336 time=857.6s
+ ttt_chunk [751/1893] bpb=1.136976 time=869.2s
+ ttt_chunk [761/1893] bpb=1.136875 time=880.8s
+ ttt_chunk [771/1893] bpb=1.136668 time=892.3s
+ ttt_chunk [781/1893] bpb=1.137227 time=903.9s
+ ttt_chunk [791/1893] bpb=1.136940 time=915.4s
+ ttt_chunk [801/1893] bpb=1.137045 time=927.0s
+ ttt_chunk [811/1893] bpb=1.136616 time=938.5s
+ ttt_chunk [821/1893] bpb=1.136474 time=950.2s
+ ttt_chunk [831/1893] bpb=1.136085 time=961.9s
+ ttt_chunk [841/1893] bpb=1.135570 time=973.7s
+ ttt_chunk [851/1893] bpb=1.135510 time=985.4s
+ ttt_chunk [861/1893] bpb=1.135696 time=997.0s
+ ttt_chunk [871/1893] bpb=1.135789 time=1008.7s
+ ttt_chunk [881/1893] bpb=1.135866 time=1020.3s
+ ttt_chunk [891/1893] bpb=1.135748 time=1031.9s
+ ttt_chunk [901/1893] bpb=1.135774 time=1043.5s
+ ttt_chunk [911/1893] bpb=1.135824 time=1055.2s
+ ttt_chunk [921/1893] bpb=1.136241 time=1066.8s
+ ttt_chunk [931/1893] bpb=1.136100 time=1078.5s
+ ttt_chunk [941/1893] bpb=1.136016 time=1090.1s
+ ttt_chunk [951/1893] bpb=1.136056 time=1101.7s
+ ttt_chunk [961/1893] bpb=1.135829 time=1113.3s
+ ttt_chunk [971/1893] bpb=1.136600 time=1124.9s
+ ttt_chunk [981/1893] bpb=1.136761 time=1136.5s
+ ttt_chunk [991/1893] bpb=1.136700 time=1148.1s
+ ttt_chunk [1001/1893] bpb=1.136864 time=1159.8s
+ ttt_chunk [1011/1893] bpb=1.137154 time=1171.4s
+ ttt_chunk [1021/1893] bpb=1.137333 time=1183.0s
+ ttt_chunk [1031/1893] bpb=1.137876 time=1194.6s
+ ttt_chunk [1041/1893] bpb=1.137535 time=1206.2s
+ ttt_chunk [1051/1893] bpb=1.137250 time=1217.8s
+ ttt_chunk [1061/1893] bpb=1.137518 time=1229.4s
+ ttt_chunk [1071/1893] bpb=1.138023 time=1241.1s
+ ttt_chunk [1081/1893] bpb=1.138054 time=1252.7s
+ ttt_chunk [1091/1893] bpb=1.138459 time=1264.3s
+ ttt_chunk [1101/1893] bpb=1.138604 time=1275.9s
+ ttt_chunk [1111/1893] bpb=1.138362 time=1287.6s
+ ttt_chunk [1121/1893] bpb=1.138294 time=1299.2s
+ ttt_chunk [1131/1893] bpb=1.138173 time=1310.8s
+ ttt_chunk [1141/1893] bpb=1.138040 time=1322.4s
+ ttt_chunk [1151/1893] bpb=1.138090 time=1334.0s
+ ttt_chunk [1161/1893] bpb=1.137486 time=1345.7s
+ ttt_chunk [1171/1893] bpb=1.138066 time=1357.3s
+ ttt_chunk [1181/1893] bpb=1.137575 time=1368.9s
+ ttt_chunk [1191/1893] bpb=1.137310 time=1380.5s
+ ttt_chunk [1201/1893] bpb=1.137891 time=1392.2s
+ ttt_chunk [1211/1893] bpb=1.137319 time=1403.8s
+ ttt_chunk [1221/1893] bpb=1.137033 time=1415.4s
+ ttt_chunk [1231/1893] bpb=1.136922 time=1427.0s
+ ttt_chunk [1241/1893] bpb=1.136730 time=1438.7s
+ ttt_chunk [1251/1893] bpb=1.136489 time=1450.3s
+ ttt_chunk [1261/1893] bpb=1.136446 time=1461.9s
+ ttt_chunk [1271/1893] bpb=1.136259 time=1473.5s
+ ttt_chunk [1281/1893] bpb=1.136060 time=1485.1s
+ ttt_chunk [1291/1893] bpb=1.135912 time=1496.8s
+ ttt_chunk [1301/1893] bpb=1.135555 time=1508.4s
+ ttt_chunk [1311/1893] bpb=1.135233 time=1520.0s
+ ttt_chunk [1321/1893] bpb=1.135061 time=1531.7s
+ ttt_chunk [1331/1893] bpb=1.134972 time=1543.3s
+ ttt_chunk [1341/1893] bpb=1.134865 time=1554.9s
+ ttt_chunk [1351/1893] bpb=1.134818 time=1566.6s
+ ttt_chunk [1361/1893] bpb=1.135040 time=1578.2s
+ ttt_chunk [1371/1893] bpb=1.134911 time=1589.8s
+ ttt_chunk [1381/1893] bpb=1.134864 time=1601.4s
+ ttt_chunk [1391/1893] bpb=1.134337 time=1613.0s
+ ttt_chunk [1401/1893] bpb=1.134377 time=1624.6s
+ ttt_chunk [1411/1893] bpb=1.134402 time=1636.3s
+ ttt_chunk [1421/1893] bpb=1.134667 time=1647.9s
+ ttt_chunk [1431/1893] bpb=1.134547 time=1659.5s
+ ttt_chunk [1441/1893] bpb=1.135215 time=1671.2s
+ ttt_chunk [1451/1893] bpb=1.135365 time=1682.8s
+ ttt_chunk [1461/1893] bpb=1.135118 time=1694.4s
+ ttt_chunk [1471/1893] bpb=1.136046 time=1706.0s
+ ttt_chunk [1481/1893] bpb=1.135855 time=1717.6s
+ ttt_chunk [1491/1893] bpb=1.135874 time=1729.3s
+ ttt_chunk [1501/1893] bpb=1.136058 time=1740.9s
+ ttt_chunk [1511/1893] bpb=1.136165 time=1752.5s
+ ttt_chunk [1521/1893] bpb=1.136196 time=1764.1s
+ ttt_chunk [1531/1893] bpb=1.136048 time=1775.8s
+ ttt_chunk [1541/1893] bpb=1.136026 time=1787.4s
+ ttt_chunk [1551/1893] bpb=1.136391 time=1799.0s
+ ttt_chunk [1561/1893] bpb=1.136546 time=1810.6s
+ ttt_chunk [1571/1893] bpb=1.136693 time=1822.3s
+ ttt_chunk [1581/1893] bpb=1.136814 time=1833.9s
+ ttt_chunk [1591/1893] bpb=1.136764 time=1845.5s
+ ttt_chunk [1601/1893] bpb=1.136940 time=1857.1s
+ ttt_chunk [1611/1893] bpb=1.137021 time=1868.8s
+ ttt_chunk [1621/1893] bpb=1.136896 time=1880.4s
+ ttt_chunk [1631/1893] bpb=1.137069 time=1892.0s
+ ttt_chunk [1641/1893] bpb=1.136961 time=1903.7s
+ ttt_chunk [1651/1893] bpb=1.136906 time=1915.4s
+ ttt_chunk [1661/1893] bpb=1.136796 time=1927.1s
+ ttt_chunk [1671/1893] bpb=1.137180 time=1938.6s
+ ttt_chunk [1681/1893] bpb=1.137454 time=1950.1s
+ ttt_chunk [1691/1893] bpb=1.137422 time=1961.7s
+ ttt_chunk [1701/1893] bpb=1.137408 time=1973.2s
+ ttt_chunk [1711/1893] bpb=1.137268 time=1984.7s
+ ttt_chunk [1721/1893] bpb=1.137122 time=1996.2s
+ ttt_chunk [1731/1893] bpb=1.137091 time=2007.8s
+ ttt_chunk [1741/1893] bpb=1.136882 time=2019.3s
+ ttt_chunk [1751/1893] bpb=1.136725 time=2030.8s
+ ttt_chunk [1761/1893] bpb=1.136813 time=2042.4s
+ ttt_chunk [1771/1893] bpb=1.136744 time=2053.9s
+ ttt_chunk [1781/1893] bpb=1.136720 time=2065.4s
+ ttt_chunk [1791/1893] bpb=1.136318 time=2077.0s
+ ttt_chunk [1801/1893] bpb=1.136328 time=2088.5s
+ ttt_chunk [1811/1893] bpb=1.136175 time=2100.0s
+ ttt_chunk [1821/1893] bpb=1.136196 time=2111.6s
+ ttt_chunk [1831/1893] bpb=1.135816 time=2123.1s
+ ttt_chunk [1841/1893] bpb=1.135838 time=2134.6s
+ ttt_chunk [1851/1893] bpb=1.135612 time=2146.1s
+ ttt_chunk [1861/1893] bpb=1.135142 time=2157.7s
+ ttt_chunk [1871/1893] bpb=1.134990 time=2169.2s
+ ttt_chunk [1881/1893] bpb=1.134613 time=2180.7s
+ ttt_chunk [1891/1893] bpb=1.134454 time=2192.2s
+ ttt_chunk [1893/1893] bpb=1.134470 time=2193.8s
+ttt_sliding:done val_loss=1.913751 val_bpb=1.133434 elapsed=2193.8s
+final_int6_zstd-22_roundtrip val_loss:1.9138 val_bpb:1.1334 eval:2194299ms
+final_int6_zstd-22_roundtrip_exact val_bpb:1.13343416
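
---

## Appendix: Illustrative Sketch of the Int6 Clip Search

The body of `quantize_state_dict_int6` is not reproduced above. For readers who want the shape of the idea, here is a minimal sketch of symmetric per-row int6 quantization with a small clip search that keeps the candidate minimizing per-row reconstruction MSE (the "GPTQ-lite clip search"). The function names, the shrink grid, and the 5-candidate count are illustrative assumptions, not the submission's exact code.

```python
import numpy as np

def quantize_row_int6(row: np.ndarray, n_candidates: int = 5):
    """Symmetric per-row int6 quantization with a small clip-scale search.

    Tries a few shrink factors for the clipping range and keeps the one
    with the lowest reconstruction MSE ("GPTQ-lite clip search" sketch).
    """
    qmax = 31  # symmetric int6 range [-31, 31]
    amax = float(np.abs(row).max())
    if amax == 0.0:
        return np.zeros_like(row, dtype=np.int8), 1.0
    best = None
    for shrink in np.linspace(1.0, 0.8, n_candidates):
        scale = shrink * amax / qmax
        q = np.clip(np.round(row / scale), -qmax, qmax)
        mse = float(np.mean((q * scale - row) ** 2))
        if best is None or mse < best[0]:
            best = (mse, q.astype(np.int8), float(scale))
    return best[1], best[2]

def dequantize_row(q: np.ndarray, scale: float) -> np.ndarray:
    """Reconstruct the float row from int6 codes and the per-row scale."""
    return q.astype(np.float32) * scale
```

In the real submission this scheme applies only to the large QAT-trained attention/MLP weights; smaller tensors stay int8 per-tensor.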
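
## Appendix: Relating `val_loss` to `val_bpb`

The log reports both the mean per-token cross-entropy in nats (`val_loss`) and bits per byte (`val_bpb`). They are related by `bpb = (loss / ln 2) × (tokens / bytes)`; the tokens-per-byte ratio is not logged directly, but the log's own pairs (1.9376 → 1.1476 and 1.9138 → 1.1334) both imply ≈ 0.4106 tokens per byte (≈ 2.44 bytes per token). The constant below is derived from those numbers, not read from the code.

```python
import math

def bits_per_byte(val_loss_nats: float, tokens_per_byte: float) -> float:
    """Convert mean per-token cross-entropy (nats) to bits per byte."""
    return (val_loss_nats / math.log(2)) * tokens_per_byte

# tokens_per_byte ≈ 0.4106 is implied by the log's own (val_loss, val_bpb) pairs.
```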
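
## Appendix: Why "Score-First" TTT Is Legal

The key property of `eval_val_sliding_ttt` is ordering: each chunk is scored with the current weights *before* any gradient step uses that chunk, so the reported loss never benefits from its own update. The toy below demonstrates that ordering on a linear least-squares probe with plain SGD; it is a minimal sketch of the principle only (the real routine uses the transformer, freezes the first blocks, and evaluates sliding windows).

```python
import numpy as np

def score_first_ttt(w: np.ndarray, chunks, lr: float = 0.1, epochs: int = 10):
    """Score-first test-time training on a linear least-squares probe.

    For each (X, y) chunk: 1) score with the CURRENT weights, then
    2) take gradient steps on that chunk. Later chunks benefit from
    earlier adaptation, but no chunk's score sees its own update.
    """
    total_se, total_n = 0.0, 0
    for X, y in chunks:
        pred = X @ w                         # 1) score before adapting
        total_se += float(np.sum((pred - y) ** 2))
        total_n += len(y)
        for _ in range(epochs):              # 2) only then adapt on the chunk
            grad = 2.0 * X.T @ (X @ w - y) / len(y)
            w = w - lr * grad
    return total_se / total_n, w
```

With `epochs=0` this reduces to ordinary frozen evaluation, which is why the pre-TTT and post-TTT bpb numbers in the log are directly comparable.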