From 8dc11e2f070dc8bfdd2a494590b65aa338626d7e Mon Sep 17 00:00:00 2001 From: Alexandr Azizyan Date: Thu, 26 Mar 2026 17:57:01 +0400 Subject: [PATCH 01/10] feat: add modified training script with recurrence stabilization techniques --- .../train_gpt.py | 1320 +++++++++++++++++ 1 file changed, 1320 insertions(+) create mode 100644 records/track_non_record_16mb/2026-03-26_RecurrenceFix_3Loop_Birkhoff_OutputLN_TimestepScale/train_gpt.py diff --git a/records/track_non_record_16mb/2026-03-26_RecurrenceFix_3Loop_Birkhoff_OutputLN_TimestepScale/train_gpt.py b/records/track_non_record_16mb/2026-03-26_RecurrenceFix_3Loop_Birkhoff_OutputLN_TimestepScale/train_gpt.py new file mode 100644 index 0000000000..7a244120d8 --- /dev/null +++ b/records/track_non_record_16mb/2026-03-26_RecurrenceFix_3Loop_Birkhoff_OutputLN_TimestepScale/train_gpt.py @@ -0,0 +1,1320 @@ +""" +The `train_gpt.py` and `train_gpt_mlx.py` scripts are intended as good launching-off points for new participants, not SOTA configs. We'll accept PRs that tune, improve, or simplify these scripts without significantly increasing complexity, but competitive submissions should stay in the `/records` folder. + +Hard stop: To keep readable for newcomers, let's make sure `train_gpt.py` and `train_gpt_mlx.py` never are longer than 1500 lines. 
+""" + +from __future__ import annotations + +import copy +import glob +import io +import math +import os +import random +import subprocess +import sys +import time +import uuid +import zlib +from pathlib import Path + +import numpy as np +import sentencepiece as spm +import torch +import torch.distributed as dist +import torch.nn.functional as F +from torch import Tensor, nn +from torch.nn.parallel import DistributedDataParallel as DDP + +# Flash Attention 3 via HuggingFace kernels (Hopper-optimized) +try: + from kernels import get_kernel + cap = torch.cuda.get_device_capability() + repo = "varunneal/flash-attention-3" if cap == (9, 0) else "kernels-community/flash-attn3" + fa3 = get_kernel(repo).flash_attn_interface + FA3_AVAILABLE = True +except Exception: + FA3_AVAILABLE = False + +# ----------------------------- +# HYPERPARAMETERS +# ----------------------------- +# Default Simple Baseline run: +# - 9 transformer blocks at width 512 +# - 8 attention heads with 4 KV heads (GQA) and 2x MLP expansion +# - vocab size 1024, sequence length 1024, tied embeddings +# - 524,288 train tokens per step for 20,000 iterations with a ~10 minute cap + +class Hyperparameters: + # Data paths are shard globs produced by the existing preprocessing pipeline. + data_path = os.environ.get("DATA_PATH", "./data/datasets/fineweb10B_sp1024") + train_files = os.path.join(data_path, "fineweb_train_*.bin") + val_files = os.path.join(data_path, "fineweb_val_*.bin") + tokenizer_path = os.environ.get("TOKENIZER_PATH", "./data/tokenizers/fineweb_1024_bpe.model") + run_id = os.environ.get("RUN_ID", str(uuid.uuid4())) + seed = int(os.environ.get("SEED", 1337)) + + # Validation cadence and batch size. Validation always uses the full fineweb_val split. + val_batch_size = int(os.environ.get("VAL_BATCH_SIZE", 524_288)) + val_loss_every = int(os.environ.get("VAL_LOSS_EVERY", 1000)) + train_log_every = int(os.environ.get("TRAIN_LOG_EVERY", 200)) + + # Training length. 
+ iterations = int(os.environ.get("ITERATIONS", 20000)) + warmdown_iters = int(os.environ.get("WARMDOWN_ITERS", 1200)) + warmup_steps = int(os.environ.get("WARMUP_STEPS", 20)) + train_batch_tokens = int(os.environ.get("TRAIN_BATCH_TOKENS", 524_288)) + train_seq_len = int(os.environ.get("TRAIN_SEQ_LEN", 1024)) + max_wallclock_seconds = float(os.environ.get("MAX_WALLCLOCK_SECONDS", 600.0)) + qk_gain_init = float(os.environ.get("QK_GAIN_INIT", 1.5)) + + # Model shape. + vocab_size = int(os.environ.get("VOCAB_SIZE", 1024)) + num_layers = int(os.environ.get("NUM_LAYERS", 9)) + num_kv_heads = int(os.environ.get("NUM_KV_HEADS", 4)) + model_dim = int(os.environ.get("MODEL_DIM", 512)) + num_heads = int(os.environ.get("NUM_HEADS", 8)) + mlp_mult = int(os.environ.get("MLP_MULT", 2)) + tie_embeddings = bool(int(os.environ.get("TIE_EMBEDDINGS", "1"))) + rope_base = float(os.environ.get("ROPE_BASE", 10000.0)) + logit_softcap = float(os.environ.get("LOGIT_SOFTCAP", 30.0)) + + # Recurrence + stability features. + num_prelude = int(os.environ.get("NUM_PRELUDE", 0)) + num_coda = int(os.environ.get("NUM_CODA", 0)) + num_shared = int(os.environ.get("NUM_SHARED", 0)) + num_loops = int(os.environ.get("NUM_LOOPS", 2)) + use_peri_norm = bool(int(os.environ.get("USE_PERI_NORM", "0"))) + use_birkhoff_mix = bool(int(os.environ.get("USE_BIRKHOFF_MIX", "0"))) + use_timestep_scale = bool(int(os.environ.get("USE_TIMESTEP_SCALE", "0"))) + leaky_relu_slope = float(os.environ.get("LEAKY_RELU_SLOPE", "0.5")) + timestep_gamma_max = float(os.environ.get("TIMESTEP_GAMMA_MAX", "0.0")) + disable_compile = bool(int(os.environ.get("DISABLE_COMPILE", "0"))) + + # Optimizer hyperparameters. 
+ embed_lr = float(os.environ.get("EMBED_LR", 0.6)) + head_lr = float(os.environ.get("HEAD_LR", 0.008)) + tied_embed_lr = float(os.environ.get("TIED_EMBED_LR", 0.05)) + tied_embed_init_std = float(os.environ.get("TIED_EMBED_INIT_STD", 0.005)) + matrix_lr = float(os.environ.get("MATRIX_LR", 0.04)) + scalar_lr = float(os.environ.get("SCALAR_LR", 0.04)) + muon_momentum = float(os.environ.get("MUON_MOMENTUM", 0.95)) + muon_backend_steps = int(os.environ.get("MUON_BACKEND_STEPS", 5)) + muon_momentum_warmup_start = float(os.environ.get("MUON_MOMENTUM_WARMUP_START", 0.85)) + muon_momentum_warmup_steps = int(os.environ.get("MUON_MOMENTUM_WARMUP_STEPS", 500)) + beta1 = float(os.environ.get("BETA1", 0.9)) + beta2 = float(os.environ.get("BETA2", 0.95)) + adam_eps = float(os.environ.get("ADAM_EPS", 1e-8)) + grad_clip_norm = float(os.environ.get("GRAD_CLIP_NORM", 0.0)) + +# ----------------------------- +# MUON OPTIMIZER +# ----------------------------- +# +# As borrowed from modded-nanogpt +# Background on Muon: https://kellerjordan.github.io/posts/muon/ + +def zeropower_via_newtonschulz5(G: Tensor, steps: int = 10, eps: float = 1e-7) -> Tensor: + # Orthogonalize a 2D update matrix with a fast Newton-Schulz iteration. + # Muon uses this to normalize matrix-shaped gradients before applying them. 
+ a, b, c = (3.4445, -4.7750, 2.0315) + X = G.bfloat16() + X /= X.norm() + eps + transposed = G.size(0) > G.size(1) + if transposed: + X = X.T + for _ in range(steps): + A = X @ X.T + B = b * A + c * A @ A + X = a * X + B @ X + return X.T if transposed else X + + +class Muon(torch.optim.Optimizer): + def __init__(self, params, lr: float, momentum: float, backend_steps: int, nesterov: bool = True): + super().__init__( + params, + dict(lr=lr, momentum=momentum, backend_steps=backend_steps, nesterov=nesterov), + ) + + @torch.no_grad() + def step(self, closure=None): + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + distributed = dist.is_available() and dist.is_initialized() + world_size = dist.get_world_size() if distributed else 1 + rank = dist.get_rank() if distributed else 0 + + for group in self.param_groups: + params = group["params"] + if not params: + continue + lr = group["lr"] + momentum = group["momentum"] + backend_steps = group["backend_steps"] + nesterov = group["nesterov"] + + total_params = sum(int(p.numel()) for p in params) + updates_flat = torch.zeros(total_params, device=params[0].device, dtype=torch.bfloat16) + + curr = 0 + for i, p in enumerate(params): + if i % world_size == rank and p.grad is not None: + g = p.grad + state = self.state[p] + if "momentum_buffer" not in state: + state["momentum_buffer"] = torch.zeros_like(g) + buf = state["momentum_buffer"] + buf.mul_(momentum).add_(g) + if nesterov: + g = g.add(buf, alpha=momentum) + g = zeropower_via_newtonschulz5(g, steps=backend_steps) + # Scale correction from Muon reference implementations. 
+ g *= max(1, g.size(0) / g.size(1)) ** 0.5 + updates_flat[curr : curr + p.numel()] = g.reshape(-1) + curr += p.numel() + + if distributed: + dist.all_reduce(updates_flat, op=dist.ReduceOp.SUM) + + curr = 0 + for p in params: + g = updates_flat[curr : curr + p.numel()].view_as(p).to(dtype=p.dtype) + p.add_(g, alpha=-lr) + curr += p.numel() + + return loss + + +# ----------------------------- +# TOKENIZER-AGNOSTIC EVALUATION SETUP +# ----------------------------- +# +# It's common for small models have a large fraction of their parameters be embeddings, since the 2 * d_model * d_vocab vectors can be gigantic. +# Instead of locking the tokenizer, we let you bring your own and calculate our validation metrics on the average compression of the validation set. +# We calculate BPB (bits-per-byte) instead of validation loss, so we need methods to count the number of bits per token in the tokenizer. +# Note: Submissions that edit the tokenizer will be examined more carefully, since screwing this up might unjustly improve your score. 
def build_sentencepiece_luts(
    sp: spm.SentencePieceProcessor, vocab_size: int, device: torch.device
) -> tuple[Tensor, Tensor, Tensor]:
    """Build per-token lookup tables used by the bits-per-byte (BPB) metric.

    Returns three tensors sized max(sp.vocab_size(), vocab_size):
      - base_bytes (int16): UTF-8 byte length of each piece (1 for byte tokens,
        0 for control/unknown/unused ids).
      - has_leading_space (bool): True where the piece carries the SentencePiece
        word-start marker and therefore decodes with a leading space.
      - is_boundary_token (bool): True only for control/unknown/unused ids; used
        to decide whether the leading space of the NEXT token contributes a byte.
    """
    sp_vocab_size = int(sp.vocab_size())
    table_size = max(sp_vocab_size, vocab_size)
    base_bytes_np = np.zeros((table_size,), dtype=np.int16)
    has_leading_space_np = np.zeros((table_size,), dtype=np.bool_)
    is_boundary_token_np = np.ones((table_size,), dtype=np.bool_)
    for token_id in range(sp_vocab_size):
        # Control/unknown/unused ids keep base_bytes=0 and stay boundary tokens.
        if sp.is_control(token_id) or sp.is_unknown(token_id) or sp.is_unused(token_id):
            continue
        is_boundary_token_np[token_id] = False
        if sp.is_byte(token_id):
            base_bytes_np[token_id] = 1
            continue
        piece = sp.id_to_piece(token_id)
        if piece.startswith("▁"):
            has_leading_space_np[token_id] = True
            piece = piece[1:]
        base_bytes_np[token_id] = len(piece.encode("utf-8"))
    return (
        torch.tensor(base_bytes_np, dtype=torch.int16, device=device),
        torch.tensor(has_leading_space_np, dtype=torch.bool, device=device),
        torch.tensor(is_boundary_token_np, dtype=torch.bool, device=device),
    )


def load_validation_tokens(pattern: str, seq_len: int) -> Tensor:
    """Load and concatenate all validation shards matching *pattern*.

    Trims the stream to a whole number of seq_len sequences, keeping one extra
    token so (x, y) pairs can be built by shifting.

    Raises:
        FileNotFoundError: if *pattern* matches no files.
        ValueError: if the split is shorter than one sequence.
    """
    files = [Path(p) for p in sorted(glob.glob(pattern))]
    if not files:
        raise FileNotFoundError(f"No files found for pattern: {pattern}")
    # The export pipeline writes the fixed first-50k-doc validation set to fineweb_val_*.
    tokens = torch.cat([load_data_shard(file) for file in files]).contiguous()
    usable = ((tokens.numel() - 1) // seq_len) * seq_len
    if usable <= 0:
        raise ValueError(f"Validation split is too short for TRAIN_SEQ_LEN={seq_len}")
    return tokens[: usable + 1]


def eval_val(
    args: Hyperparameters,
    model: nn.Module,
    rank: int,
    world_size: int,
    device: torch.device,
    grad_accum_steps: int,
    val_tokens: Tensor,
    base_bytes_lut: Tensor,
    has_leading_space_lut: Tensor,
    is_boundary_token_lut: Tensor,
) -> tuple[float, float]:
    """Evaluate the model over the full validation stream; returns (val_loss, val_bpb).

    Each rank scores a disjoint contiguous slice of sequences; loss, token, and
    byte counts are all-reduced before the final ratios are formed. The model is
    switched back to train mode before returning.
    """
    # Validation computes two metrics:
    # - val_loss: token cross-entropy (natural log)
    # - val_bpb: tokenizer-agnostic compression metric used by the challenge
    local_batch_tokens = args.val_batch_size // (world_size * grad_accum_steps)
    if local_batch_tokens < args.train_seq_len:
        raise ValueError(
            "VAL_BATCH_SIZE must provide at least one sequence per rank; "
            f"got VAL_BATCH_SIZE={args.val_batch_size}, WORLD_SIZE={world_size}, "
            f"GRAD_ACCUM_STEPS={grad_accum_steps}, TRAIN_SEQ_LEN={args.train_seq_len}"
        )
    local_batch_seqs = local_batch_tokens // args.train_seq_len
    total_seqs = (val_tokens.numel() - 1) // args.train_seq_len
    # Integer split of [0, total_seqs) across ranks; boundaries are exact even
    # when total_seqs is not divisible by world_size.
    seq_start = (total_seqs * rank) // world_size
    seq_end = (total_seqs * (rank + 1)) // world_size
    # float64 accumulators so long validation runs don't lose precision.
    val_loss_sum = torch.zeros((), device=device, dtype=torch.float64)
    val_token_count = torch.zeros((), device=device, dtype=torch.float64)
    val_byte_count = torch.zeros((), device=device, dtype=torch.float64)

    model.eval()
    with torch.inference_mode():
        for batch_seq_start in range(seq_start, seq_end, local_batch_seqs):
            batch_seq_end = min(batch_seq_start + local_batch_seqs, seq_end)
            raw_start = batch_seq_start * args.train_seq_len
            raw_end = batch_seq_end * args.train_seq_len + 1  # +1 for the shifted targets
            local = val_tokens[raw_start:raw_end].to(device=device, dtype=torch.int64, non_blocking=True)
            x = local[:-1].reshape(-1, args.train_seq_len)
            y = local[1:].reshape(-1, args.train_seq_len)
            with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True):
                batch_loss = model(x, y).detach()
            # model returns mean loss, so re-weight by token count before summing.
            batch_token_count = float(y.numel())
            val_loss_sum += batch_loss.to(torch.float64) * batch_token_count
            val_token_count += batch_token_count
            # Byte accounting: each target token contributes its UTF-8 length,
            # plus one byte for its leading space unless the previous token was
            # a boundary (control/unknown/unused) token.
            prev_ids = x.reshape(-1)
            tgt_ids = y.reshape(-1)
            token_bytes = base_bytes_lut[tgt_ids].to(dtype=torch.int16)
            token_bytes += (has_leading_space_lut[tgt_ids] & ~is_boundary_token_lut[prev_ids]).to(dtype=torch.int16)
            val_byte_count += token_bytes.to(torch.float64).sum()

    if dist.is_available() and dist.is_initialized():
        dist.all_reduce(val_loss_sum, op=dist.ReduceOp.SUM)
        dist.all_reduce(val_token_count, op=dist.ReduceOp.SUM)
        dist.all_reduce(val_byte_count, op=dist.ReduceOp.SUM)

    val_loss = val_loss_sum / val_token_count
    # BPB = (nats/token -> bits/token) * (tokens/byte).
    bits_per_token = val_loss.item() / math.log(2.0)
    tokens_per_byte = val_token_count.item() / val_byte_count.item()
    model.train()
    return float(val_loss.item()), float(bits_per_token * tokens_per_byte)

# -----------------------------
# POST-TRAINING QUANTIZATION
# -----------------------------
#
# It's silly to export our model, which is trained in bf16 and fp32, at that same precision.
# Instead, we get approximately the same model (with a small hit) by quantizing the model to int8 & zlib compressing.
# We can then decompress the model and run in higher precision for evaluation, after closing in under the size limit.
+ +CONTROL_TENSOR_NAME_PATTERNS = tuple( + pattern + for pattern in os.environ.get( + "CONTROL_TENSOR_NAME_PATTERNS", + "attn_scale,attn_scales,mlp_scale,mlp_scales,resid_mix,resid_mixes,q_gain,skip_weight,skip_weights,ts_attn,ts_mlp,resid_mix_logit", + ).split(",") + if pattern +) +INT8_KEEP_FLOAT_FP32_NAME_PATTERNS = tuple( + pattern + for pattern in os.environ.get( + "INT8_KEEP_FLOAT_FP32_NAME_PATTERNS", + ",".join(CONTROL_TENSOR_NAME_PATTERNS), + ).split(",") + if pattern +) +INT8_KEEP_FLOAT_MAX_NUMEL = 65_536 +INT8_KEEP_FLOAT_STORE_DTYPE = torch.float16 +INT8_PER_ROW_SCALE_DTYPE = torch.float16 +INT8_CLIP_PERCENTILE = 99.99984 +INT8_CLIP_Q = INT8_CLIP_PERCENTILE / 100.0 + +def tensor_nbytes(t: Tensor) -> int: + return int(t.numel()) * int(t.element_size()) + +def keep_float_tensor(name: str, t: Tensor, passthrough_orig_dtypes: dict[str, str]) -> Tensor: + if any(pattern in name for pattern in INT8_KEEP_FLOAT_FP32_NAME_PATTERNS): + return t.float().contiguous() + if t.dtype in {torch.float32, torch.bfloat16}: + passthrough_orig_dtypes[name] = str(t.dtype).removeprefix("torch.") + return t.to(dtype=INT8_KEEP_FLOAT_STORE_DTYPE).contiguous() + return t + +def quantize_float_tensor(t: Tensor) -> tuple[Tensor, Tensor]: + t32 = t.float() + if t32.ndim == 2: + # Matrices get one scale per row, which usually tracks output-channel + # ranges much better than a single tensor-wide scale. + clip_abs = ( + torch.quantile(t32.abs(), INT8_CLIP_Q, dim=1) + if t32.numel() + else torch.empty((t32.shape[0],), dtype=torch.float32) + ) + clipped = torch.maximum(torch.minimum(t32, clip_abs[:, None]), -clip_abs[:, None]) + scale = (clip_abs / 127.0).clamp_min(1.0 / 127.0) + q = torch.clamp(torch.round(clipped / scale[:, None]), -127, 127).to(torch.int8).contiguous() + return q, scale.to(dtype=INT8_PER_ROW_SCALE_DTYPE).contiguous() + + # Vectors / scalars use a simpler per-tensor scale. 
+ clip_abs = float(torch.quantile(t32.abs().flatten(), INT8_CLIP_Q).item()) if t32.numel() else 0.0 + scale = torch.tensor(clip_abs / 127.0 if clip_abs > 0 else 1.0, dtype=torch.float32) + q = torch.clamp(torch.round(torch.clamp(t32, -clip_abs, clip_abs) / scale), -127, 127).to(torch.int8).contiguous() + return q, scale + +def quantize_state_dict_int8(state_dict: dict[str, Tensor]): + # Single supported clean-script export format: + # - per-row int8 for 2D float tensors + # - per-tensor int8 for other float tensors + # - exact passthrough for non-floats + # - passthrough for small float tensors, stored as fp16 to save bytes + quantized: dict[str, Tensor] = {} + scales: dict[str, Tensor] = {} + dtypes: dict[str, str] = {} + passthrough: dict[str, Tensor] = {} + passthrough_orig_dtypes: dict[str, str] = {} + qmeta: dict[str, dict[str, object]] = {} + stats = dict.fromkeys( + ("param_count", "num_tensors", "num_float_tensors", "num_nonfloat_tensors", "baseline_tensor_bytes", "int8_payload_bytes"), + 0, + ) + + for name, tensor in state_dict.items(): + t = tensor.detach().to("cpu").contiguous() + stats["param_count"] += int(t.numel()) + stats["num_tensors"] += 1 + stats["baseline_tensor_bytes"] += tensor_nbytes(t) + + if not t.is_floating_point(): + stats["num_nonfloat_tensors"] += 1 + passthrough[name] = t + stats["int8_payload_bytes"] += tensor_nbytes(t) + continue + + # Small float tensors are cheap enough to keep directly. We still downcast + # fp32/bf16 passthrough tensors to fp16 so metadata does not dominate size. 
+ if t.numel() <= INT8_KEEP_FLOAT_MAX_NUMEL: + kept = keep_float_tensor(name, t, passthrough_orig_dtypes) + passthrough[name] = kept + stats["int8_payload_bytes"] += tensor_nbytes(kept) + continue + + stats["num_float_tensors"] += 1 + q, s = quantize_float_tensor(t) + if s.ndim > 0: + qmeta[name] = {"scheme": "per_row", "axis": 0} + quantized[name] = q + scales[name] = s + dtypes[name] = str(t.dtype).removeprefix("torch.") + stats["int8_payload_bytes"] += tensor_nbytes(q) + tensor_nbytes(s) + + obj: dict[str, object] = { + "__quant_format__": "int8_clean_per_row_v1", + "quantized": quantized, + "scales": scales, + "dtypes": dtypes, + "passthrough": passthrough, + } + if qmeta: + obj["qmeta"] = qmeta + if passthrough_orig_dtypes: + obj["passthrough_orig_dtypes"] = passthrough_orig_dtypes + return obj, stats + +def dequantize_state_dict_int8(obj: dict[str, object]) -> dict[str, Tensor]: + out: dict[str, Tensor] = {} + qmeta = obj.get("qmeta", {}) + passthrough_orig_dtypes = obj.get("passthrough_orig_dtypes", {}) + for name, q in obj["quantized"].items(): + dtype = getattr(torch, obj["dtypes"][name]) + s = obj["scales"][name] + if qmeta.get(name, {}).get("scheme") == "per_row" or s.ndim > 0: + s = s.to(dtype=torch.float32) + # Broadcast the saved row scale back across trailing dimensions. + out[name] = (q.float() * s.view(q.shape[0], *([1] * (q.ndim - 1)))).to(dtype=dtype).contiguous() + else: + scale = float(s.item()) + out[name] = (q.float() * scale).to(dtype=dtype).contiguous() + for name, t in obj["passthrough"].items(): + # Restore small tensors, undoing the temporary fp16 storage cast if needed. 
+ out_t = t.detach().to("cpu").contiguous() + orig_dtype = passthrough_orig_dtypes.get(name) + if isinstance(orig_dtype, str): + out_t = out_t.to(dtype=getattr(torch, orig_dtype)).contiguous() + out[name] = out_t + return out + + +# ----------------------------- +# DATA LOADING +# ----------------------------- + +def load_data_shard(file: Path) -> Tensor: + header_bytes = 256 * np.dtype(" None: + self.file_idx = (self.file_idx + 1) % len(self.files) + self.tokens = load_data_shard(self.files[self.file_idx]) + self.pos = 0 + + def take(self, n: int) -> Tensor: + chunks: list[Tensor] = [] + remaining = n + while remaining > 0: + avail = self.tokens.numel() - self.pos + if avail <= 0: + self._advance_file() + continue + k = min(remaining, avail) + chunks.append(self.tokens[self.pos : self.pos + k]) + self.pos += k + remaining -= k + return chunks[0] if len(chunks) == 1 else torch.cat(chunks) + + +class DistributedTokenLoader: + # Each call consumes a contiguous chunk from the shared token stream, then slices out + # one disjoint span per rank. The extra "+1" token lets us build (x, y) by shifting. 
+ def __init__(self, pattern: str, rank: int, world_size: int, device: torch.device): + self.rank = rank + self.world_size = world_size + self.device = device + self.stream = TokenStream(pattern) + + def next_batch(self, global_tokens: int, seq_len: int, grad_accum_steps: int) -> tuple[Tensor, Tensor]: + local_tokens = global_tokens // (self.world_size * grad_accum_steps) + per_rank_span = local_tokens + 1 + chunk = self.stream.take(per_rank_span * self.world_size) + start = self.rank * per_rank_span + local = chunk[start : start + per_rank_span].to(dtype=torch.int64) + x = local[:-1].reshape(-1, seq_len) + y = local[1:].reshape(-1, seq_len) + return x.to(self.device, non_blocking=True), y.to(self.device, non_blocking=True) + +# ----------------------------- +# TRANSFORMER MODULES +# ----------------------------- + +class RMSNorm(nn.Module): + def __init__(self, eps: float | None = None): + super().__init__() + self.eps = eps + + def forward(self, x: Tensor) -> Tensor: + return F.rms_norm(x, (x.size(-1),), eps=self.eps) + + +class CastedLinear(nn.Linear): + # Keep weights in fp32 for optimizer/state quality, cast at matmul time for bf16 compute. + def forward(self, x: Tensor) -> Tensor: + bias = self.bias.to(x.dtype) if self.bias is not None else None + return F.linear(x, self.weight.to(x.dtype), bias) + + +def restore_low_dim_params_to_fp32(module: nn.Module) -> None: + # Keep small/control parameters in fp32 even when the model body runs in bf16. + with torch.no_grad(): + for name, param in module.named_parameters(): + if (param.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS)) and param.dtype != torch.float32: + param.data = param.data.float() + + +class Rotary(nn.Module): + # Caches cos/sin tables per sequence length on the current device. 
+ def __init__(self, dim: int, base: float = 10000.0): + super().__init__() + inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim)) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self._seq_len_cached = 0 + self._cos_cached: Tensor | None = None + self._sin_cached: Tensor | None = None + + def forward(self, seq_len: int, device: torch.device, dtype: torch.dtype) -> tuple[Tensor, Tensor]: + if ( + self._cos_cached is None + or self._sin_cached is None + or self._seq_len_cached != seq_len + or self._cos_cached.device != device + ): + t = torch.arange(seq_len, device=device, dtype=self.inv_freq.dtype) + freqs = torch.outer(t, self.inv_freq.to(device)) + self._cos_cached = freqs.cos()[None, None, :, :] + self._sin_cached = freqs.sin()[None, None, :, :] + self._seq_len_cached = seq_len + return self._cos_cached.to(dtype=dtype), self._sin_cached.to(dtype=dtype) + + +def apply_rotary_emb(x: Tensor, cos: Tensor, sin: Tensor) -> Tensor: + half = x.size(-1) // 2 + x1, x2 = x[..., :half], x[..., half:] + return torch.cat((x1 * cos + x2 * sin, x1 * (-sin) + x2 * cos), dim=-1) + + +class CausalSelfAttention(nn.Module): + def __init__( + self, + dim: int, + num_heads: int, + num_kv_heads: int, + rope_base: float, + qk_gain_init: float, + ): + super().__init__() + if dim % num_heads != 0: + raise ValueError("model_dim must be divisible by num_heads") + if num_heads % num_kv_heads != 0: + raise ValueError("num_heads must be divisible by num_kv_heads") + self.num_heads = num_heads + self.num_kv_heads = num_kv_heads + self.head_dim = dim // num_heads + if self.head_dim % 2 != 0: + raise ValueError("head_dim must be even for RoPE") + kv_dim = self.num_kv_heads * self.head_dim + self.c_q = CastedLinear(dim, dim, bias=False) + self.c_k = CastedLinear(dim, kv_dim, bias=False) + self.c_v = CastedLinear(dim, kv_dim, bias=False) + self.proj = CastedLinear(dim, dim, bias=False) + self.proj._zero_init = True + self.q_gain = 
nn.Parameter(torch.full((num_heads,), qk_gain_init, dtype=torch.float32)) + self.rotary = Rotary(self.head_dim, base=rope_base) + + def forward(self, x: Tensor) -> Tensor: + bsz, seqlen, dim = x.shape + q = self.c_q(x).reshape(bsz, seqlen, self.num_heads, self.head_dim).transpose(1, 2) + k = self.c_k(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) + v = self.c_v(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) + q = F.rms_norm(q, (q.size(-1),)) + k = F.rms_norm(k, (k.size(-1),)) + cos, sin = self.rotary(seqlen, x.device, q.dtype) + q = apply_rotary_emb(q, cos, sin) + k = apply_rotary_emb(k, cos, sin) + q = q * self.q_gain.to(dtype=q.dtype)[None, :, None, None] + if FA3_AVAILABLE: + # FA3 expects (B, T, H, D) — transpose from (B, H, T, D) + y = fa3.flash_attn_func( + q.transpose(1, 2), + k.transpose(1, 2), + v.transpose(1, 2), + causal=True, + ).transpose(1, 2) + else: + y = F.scaled_dot_product_attention( + q, k, v, + attn_mask=None, + is_causal=True, + enable_gqa=(self.num_kv_heads != self.num_heads), + ) + y = y.transpose(1, 2).contiguous().reshape(bsz, seqlen, dim) + return self.proj(y) + + +class MLP(nn.Module): + # leaky_relu^2 MLP (slope=0.5 matches SOTA, PRs #493/#518) + def __init__(self, dim: int, mlp_mult: int, leaky_relu_slope: float = 0.5): + super().__init__() + hidden = mlp_mult * dim + self.fc = CastedLinear(dim, hidden, bias=False) + self.proj = CastedLinear(hidden, dim, bias=False) + self.proj._zero_init = True + self.leaky_relu_slope = leaky_relu_slope + + def forward(self, x: Tensor) -> Tensor: + x = F.leaky_relu(self.fc(x), negative_slope=self.leaky_relu_slope) + return self.proj(x.square()) + + +class TimestepScaling(nn.Module): + """Learned per-iteration scale vectors (arXiv:2410.01405, ICML 2025).""" + def __init__(self, effective_layers: int, dim: int, gamma_max: float = 0.0): + super().__init__() + self.attn_gamma = nn.Parameter(torch.ones(effective_layers, dim, dtype=torch.float32)) + 
self.mlp_gamma = nn.Parameter(torch.ones(effective_layers, dim, dtype=torch.float32)) + self.gamma_max = gamma_max # 0 = uncapped + + def get(self, v: int) -> tuple[Tensor, Tensor]: + ag = self.attn_gamma[v] + mg = self.mlp_gamma[v] + if self.gamma_max > 0: + ag = ag.clamp(-self.gamma_max, self.gamma_max) + mg = mg.clamp(-self.gamma_max, self.gamma_max) + return ag, mg + + +class Block(nn.Module): + def __init__( + self, + dim: int, + num_heads: int, + num_kv_heads: int, + mlp_mult: int, + rope_base: float, + qk_gain_init: float, + use_peri_norm: bool = False, + use_birkhoff_mix: bool = False, + leaky_relu_slope: float = 0.5, + ): + super().__init__() + self.use_peri_norm = use_peri_norm + self.use_birkhoff_mix = use_birkhoff_mix + self.attn_norm = RMSNorm() + if use_peri_norm: + self.mlp_out_norm = RMSNorm() # output norm: bounds MLP contribution, preserves input magnitude + else: + self.mlp_norm = RMSNorm() + self.attn = CausalSelfAttention(dim, num_heads, num_kv_heads, rope_base, qk_gain_init) + self.mlp = MLP(dim, mlp_mult, leaky_relu_slope) + self.attn_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.mlp_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + if use_birkhoff_mix: + # Birkhoff B₂ parameterization: sigmoid(logit) ∈ (0,1) → doubly stochastic + self.resid_mix_logit = nn.Parameter(torch.zeros(dim, dtype=torch.float32)) + else: + self.resid_mix = nn.Parameter(torch.stack((torch.ones(dim), torch.zeros(dim))).float()) + + def forward(self, x: Tensor, x0: Tensor, + ts_attn_gamma: Tensor | None = None, + ts_mlp_gamma: Tensor | None = None) -> Tensor: + if self.use_birkhoff_mix: + alpha = torch.sigmoid(self.resid_mix_logit.to(dtype=x.dtype))[None, None, :] + x = alpha * x + (1 - alpha) * x0 + else: + mix = self.resid_mix.to(dtype=x.dtype) + x = mix[0][None, None, :] * x + mix[1][None, None, :] * x0 + attn_out = self.attn(self.attn_norm(x)) + attn_s = self.attn_scale.to(dtype=x.dtype)[None, None, :] + if ts_attn_gamma is not None: + 
attn_s = attn_s * ts_attn_gamma[None, None, :] + x = x + attn_s * attn_out + mlp_out = self.mlp_out_norm(self.mlp(x)) if self.use_peri_norm else self.mlp(self.mlp_norm(x)) + mlp_s = self.mlp_scale.to(dtype=x.dtype)[None, None, :] + if ts_mlp_gamma is not None: + mlp_s = mlp_s * ts_mlp_gamma[None, None, :] + x = x + mlp_s * mlp_out + return x + + +class GPT(nn.Module): + def __init__( + self, + vocab_size: int, + num_layers: int, + model_dim: int, + num_heads: int, + num_kv_heads: int, + mlp_mult: int, + tie_embeddings: bool, + tied_embed_init_std: float, + logit_softcap: float, + rope_base: float, + qk_gain_init: float, + num_prelude: int = 0, + num_coda: int = 0, + num_shared: int = 0, + num_loops: int = 2, + use_peri_norm: bool = False, + use_birkhoff_mix: bool = False, + use_timestep_scale: bool = False, + timestep_gamma_max: float = 0.0, + leaky_relu_slope: float = 0.5, + ): + super().__init__() + if logit_softcap <= 0.0: + raise ValueError(f"logit_softcap must be positive, got {logit_softcap}") + self.tie_embeddings = tie_embeddings + self.tied_embed_init_std = tied_embed_init_std + self.logit_softcap = logit_softcap + self.tok_emb = nn.Embedding(vocab_size, model_dim) + self.use_recurrence = num_shared > 0 + self.num_loops = num_loops if self.use_recurrence else 1 + self.num_prelude = num_prelude if self.use_recurrence else 0 + self.num_coda = num_coda if self.use_recurrence else 0 + + block_kwargs = dict( + dim=model_dim, num_heads=num_heads, num_kv_heads=num_kv_heads, + mlp_mult=mlp_mult, rope_base=rope_base, qk_gain_init=qk_gain_init, + use_peri_norm=use_peri_norm if self.use_recurrence else False, + use_birkhoff_mix=use_birkhoff_mix if self.use_recurrence else False, + leaky_relu_slope=leaky_relu_slope, + ) + + if self.use_recurrence: + self.prelude_blocks = nn.ModuleList([Block(**block_kwargs) for _ in range(self.num_prelude)]) + self.shared_blocks = nn.ModuleList([Block(**block_kwargs) for _ in range(num_shared)]) + self.coda_blocks = 
nn.ModuleList([Block(**block_kwargs) for _ in range(self.num_coda)]) + effective_layers = self.num_prelude + num_shared * self.num_loops + self.num_coda + self.timestep_scale = TimestepScaling(effective_layers, model_dim, gamma_max=timestep_gamma_max) if use_timestep_scale else None + else: + # Standard U-Net path + self.num_encoder_layers = num_layers // 2 + self.num_decoder_layers = num_layers - self.num_encoder_layers + self.num_skip_weights = min(self.num_encoder_layers, self.num_decoder_layers) + self.skip_weights = nn.Parameter(torch.ones(self.num_skip_weights, model_dim, dtype=torch.float32)) + self.blocks = nn.ModuleList([Block(**block_kwargs) for _ in range(num_layers)]) + self.timestep_scale = None + + self.final_norm = RMSNorm() + self.lm_head = None if tie_embeddings else CastedLinear(model_dim, vocab_size, bias=False) + if self.lm_head is not None: + self.lm_head._zero_init = True + self._init_weights() + + def _init_weights(self) -> None: + if self.tie_embeddings: + nn.init.normal_(self.tok_emb.weight, mean=0.0, std=self.tied_embed_init_std) + for module in self.modules(): + if isinstance(module, nn.Linear) and getattr(module, "_zero_init", False): + nn.init.zeros_(module.weight) + + def _get_ts(self, v: int) -> tuple[Tensor | None, Tensor | None]: + if self.timestep_scale is None: + return None, None + return self.timestep_scale.get(v) + + def forward(self, input_ids: Tensor, target_ids: Tensor) -> Tensor: + x = self.tok_emb(input_ids) + x = F.rms_norm(x, (x.size(-1),)) + x0 = x + + if self.use_recurrence: + v = 0 + for block in self.prelude_blocks: + ag, mg = self._get_ts(v) + x = block(x, x0, ag, mg) + v += 1 + for _loop in range(self.num_loops): + for block in self.shared_blocks: + ag, mg = self._get_ts(v) + x = block(x, x0, ag, mg) + v += 1 + for block in self.coda_blocks: + ag, mg = self._get_ts(v) + x = block(x, x0, ag, mg) + v += 1 + else: + skips: list[Tensor] = [] + for i in range(self.num_encoder_layers): + x = self.blocks[i](x, x0) + 
skips.append(x) + for i in range(self.num_decoder_layers): + if skips: + x = x + self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips.pop() + x = self.blocks[self.num_encoder_layers + i](x, x0) + + x = self.final_norm(x).reshape(-1, x.size(-1)) + targets = target_ids.reshape(-1) + if self.tie_embeddings: + logits_proj = F.linear(x, self.tok_emb.weight) + else: + if self.lm_head is None: + raise RuntimeError("lm_head is required when tie_embeddings=False") + logits_proj = self.lm_head(x) + logits = self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap) + return F.cross_entropy(logits.float(), targets, reduction="mean") + + +@torch.no_grad() +def recurrence_param_diagnostics(gpt: GPT) -> str: + """Log parameter norms for recurrence health monitoring. No forward pass.""" + if not gpt.use_recurrence: + return "" + parts: list[str] = [] + + # Birkhoff alpha stats per shared block + for i, block in enumerate(gpt.shared_blocks): + if hasattr(block, "resid_mix_logit"): + a = torch.sigmoid(block.resid_mix_logit) + parts.append(f"shared{i}_alpha:mean={a.mean().item():.3f},std={a.std().item():.3f}") + + # Effective MLP/attn contribution scale per virtual layer position: + # mlp_scale * ts_mlp_gamma[v] is the learned multiplier on MLP output. + # We report RMS (norm / sqrt(numel)) to make it scale-independent. 
+ v = 0 + all_blocks = ( + list(gpt.prelude_blocks) + + list(gpt.shared_blocks) * gpt.num_loops + + list(gpt.coda_blocks) + ) + mlp_norms: list[str] = [] + attn_norms: list[str] = [] + for block in all_blocks: + ms = block.mlp_scale.norm().item() + asc = block.attn_scale.norm().item() + d = block.mlp_scale.numel() ** 0.5 + if gpt.timestep_scale is not None: + ms_g = gpt.timestep_scale.mlp_gamma[v].norm().item() + as_g = gpt.timestep_scale.attn_gamma[v].norm().item() + mlp_norms.append(f"v{v}:{ms * ms_g / d:.4f}") + attn_norms.append(f"v{v}:{asc * as_g / d:.4f}") + else: + mlp_norms.append(f"v{v}:{ms / d:.4f}") + attn_norms.append(f"v{v}:{asc / d:.4f}") + v += 1 + parts.append("eff_mlp_scale:[" + " ".join(mlp_norms) + "]") + parts.append("eff_attn_scale:[" + " ".join(attn_norms) + "]") + return " ".join(parts) + + +# ----------------------------- +# TRAINING +# ----------------------------- + +def main() -> None: + global zeropower_via_newtonschulz5 + + code = Path(__file__).read_text(encoding="utf-8") + args = Hyperparameters() + zeropower_via_newtonschulz5 = torch.compile(zeropower_via_newtonschulz5) + + # ----------------------------- + # DISTRIBUTED + CUDA SETUP + # ----------------------------- + + distributed = "RANK" in os.environ and "WORLD_SIZE" in os.environ + rank = int(os.environ.get("RANK", "0")) + world_size = int(os.environ.get("WORLD_SIZE", "1")) + local_rank = int(os.environ.get("LOCAL_RANK", "0")) + if world_size <= 0: + raise ValueError(f"WORLD_SIZE must be positive, got {world_size}") + if 8 % world_size != 0: + raise ValueError(f"WORLD_SIZE={world_size} must divide 8 so grad_accum_steps stays integral") + grad_accum_steps = 8 // world_size + grad_scale = 1.0 / grad_accum_steps + if not torch.cuda.is_available(): + raise RuntimeError("CUDA is required") + device = torch.device("cuda", local_rank) + torch.cuda.set_device(device) + if distributed: + dist.init_process_group(backend="nccl", device_id=device) + dist.barrier() + master_process = rank 
== 0 + + # Fast math knobs + torch.backends.cuda.matmul.allow_tf32 = True + torch.backends.cudnn.allow_tf32 = True + from torch.backends.cuda import enable_cudnn_sdp, enable_flash_sdp, enable_math_sdp, enable_mem_efficient_sdp + + enable_cudnn_sdp(False) + enable_flash_sdp(True) + enable_mem_efficient_sdp(False) + enable_math_sdp(False) + + logfile = None + if master_process: + os.makedirs("logs", exist_ok=True) + logfile = f"logs/{args.run_id}.txt" + print(logfile) + + def log0(msg: str, console: bool = True) -> None: + if not master_process: + return + if console: + print(msg) + if logfile is not None: + with open(logfile, "a", encoding="utf-8") as f: + print(msg, file=f) + + log0(code, console=False) + log0("=" * 100, console=False) + log0(f"Running Python {sys.version}", console=False) + log0(f"Running PyTorch {torch.__version__}", console=False) + log0( + subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, check=False).stdout, + console=False, + ) + log0("=" * 100, console=False) + + # ----------------------------- + # TOKENIZER + VALIDATION METRIC SETUP + # ----------------------------- + + random.seed(args.seed) + np.random.seed(args.seed) + torch.manual_seed(args.seed) + torch.cuda.manual_seed_all(args.seed) + + if not args.tokenizer_path.endswith(".model"): + raise ValueError(f"Script only setup for SentencePiece .model file: {args.tokenizer_path}") + sp = spm.SentencePieceProcessor(model_file=args.tokenizer_path) + if int(sp.vocab_size()) != args.vocab_size: + raise ValueError( + f"VOCAB_SIZE={args.vocab_size} does not match tokenizer vocab_size={int(sp.vocab_size())}" + ) + dataset_dir = Path(args.data_path).resolve() + actual_train_files = len(list(dataset_dir.glob("fineweb_train_*.bin"))) + val_tokens = load_validation_tokens(args.val_files, args.train_seq_len) + base_bytes_lut, has_leading_space_lut, is_boundary_token_lut = build_sentencepiece_luts( + sp, args.vocab_size, device + ) + log0(f"val_bpb:enabled 
tokenizer_kind=sentencepiece tokenizer_path={args.tokenizer_path}") + log0(f"train_loader:dataset:{dataset_dir.name} train_shards:{actual_train_files}") + log0(f"val_loader:shards pattern={args.val_files} tokens:{val_tokens.numel() - 1}") + + # ----------------------------- + # MODEL + OPTIMIZER SETUP + # ----------------------------- + + base_model = GPT( + vocab_size=args.vocab_size, + num_layers=args.num_layers, + model_dim=args.model_dim, + num_heads=args.num_heads, + num_kv_heads=args.num_kv_heads, + mlp_mult=args.mlp_mult, + tie_embeddings=args.tie_embeddings, + tied_embed_init_std=args.tied_embed_init_std, + logit_softcap=args.logit_softcap, + rope_base=args.rope_base, + qk_gain_init=args.qk_gain_init, + num_prelude=args.num_prelude, + num_coda=args.num_coda, + num_shared=args.num_shared, + num_loops=args.num_loops, + use_peri_norm=args.use_peri_norm, + use_birkhoff_mix=args.use_birkhoff_mix, + use_timestep_scale=args.use_timestep_scale, + timestep_gamma_max=args.timestep_gamma_max, + leaky_relu_slope=args.leaky_relu_slope, + ).to(device).bfloat16() + for module in base_model.modules(): + if isinstance(module, CastedLinear): + module.float() + restore_low_dim_params_to_fp32(base_model) + if args.disable_compile: + compiled_model = base_model + else: + compiled_model = torch.compile(base_model, dynamic=False, fullgraph=True) + model: nn.Module = DDP(compiled_model, device_ids=[local_rank], broadcast_buffers=False) if distributed else compiled_model + + # Optimizer split: + # - token embedding (Adam) uses EMBED_LR + # - untied lm_head (Adam) uses HEAD_LR + # - matrix params in transformer blocks use MATRIX_LR via Muon + # - vectors/scalars use SCALAR_LR via Adam + if base_model.use_recurrence: + all_block_lists = [base_model.prelude_blocks, base_model.shared_blocks, base_model.coda_blocks] + block_named_params = [] + for bl in all_block_lists: + block_named_params.extend(bl.named_parameters()) + else: + block_named_params = 
list(base_model.blocks.named_parameters()) + matrix_params = [ + p + for name, p in block_named_params + if p.ndim == 2 and not any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) + ] + scalar_params = [ + p + for name, p in block_named_params + if p.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) + ] + if hasattr(base_model, "skip_weights") and base_model.skip_weights.numel() > 0: + scalar_params.append(base_model.skip_weights) + if base_model.timestep_scale is not None: + for p in base_model.timestep_scale.parameters(): + scalar_params.append(p) + token_lr = args.tied_embed_lr if args.tie_embeddings else args.embed_lr + optimizer_tok = torch.optim.Adam( + [{"params": [base_model.tok_emb.weight], "lr": token_lr, "base_lr": token_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=True, + ) + optimizer_muon = Muon( + matrix_params, + lr=args.matrix_lr, + momentum=args.muon_momentum, + backend_steps=args.muon_backend_steps, + ) + for group in optimizer_muon.param_groups: + group["base_lr"] = args.matrix_lr + optimizer_scalar = torch.optim.Adam( + [{"params": scalar_params, "lr": args.scalar_lr, "base_lr": args.scalar_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=True, + ) + optimizers: list[torch.optim.Optimizer] = [optimizer_tok, optimizer_muon, optimizer_scalar] + if base_model.lm_head is not None: + optimizer_head = torch.optim.Adam( + [{"params": [base_model.lm_head.weight], "lr": args.head_lr, "base_lr": args.head_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=True, + ) + optimizers.insert(1, optimizer_head) + + n_params = sum(p.numel() for p in base_model.parameters()) + log0(f"model_params:{n_params}") + log0(f"world_size:{world_size} grad_accum_steps:{grad_accum_steps}") + log0("sdp_backends:cudnn=False flash=True mem_efficient=False math=False") + log0(f"flash_attention_3:{FA3_AVAILABLE}") + log0(f"attention_mode:gqa num_heads:{args.num_heads} 
num_kv_heads:{args.num_kv_heads}") + log0( + f"tie_embeddings:{args.tie_embeddings} embed_lr:{token_lr} " + f"head_lr:{args.head_lr if base_model.lm_head is not None else 0.0} " + f"matrix_lr:{args.matrix_lr} scalar_lr:{args.scalar_lr}" + ) + log0( + f"train_batch_tokens:{args.train_batch_tokens} train_seq_len:{args.train_seq_len} " + f"iterations:{args.iterations} warmup_steps:{args.warmup_steps} " + f"max_wallclock_seconds:{args.max_wallclock_seconds:.3f}" + ) + log0(f"seed:{args.seed}") + if base_model.use_recurrence: + eff = base_model.num_prelude + len(base_model.shared_blocks) * base_model.num_loops + base_model.num_coda + log0(f"recurrence:enabled prelude:{base_model.num_prelude} shared:{len(base_model.shared_blocks)} " + f"loops:{base_model.num_loops} coda:{base_model.num_coda} effective_layers:{eff}") + log0(f"peri_norm:{args.use_peri_norm} birkhoff_mix:{args.use_birkhoff_mix}") + if base_model.timestep_scale is not None: + ts_params = sum(p.numel() for p in base_model.timestep_scale.parameters()) + log0(f"timestep_scale:enabled params:{ts_params}") + else: + log0("timestep_scale:disabled") + else: + log0(f"recurrence:disabled num_layers:{args.num_layers}") + compile_mode = "disabled" if args.disable_compile else "fullgraph=True" + log0(f"compile_mode:{compile_mode}") + + # ----------------------------- + # DATA LOADER & MODEL WARMUP + # ----------------------------- + + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + + def zero_grad_all() -> None: + for opt in optimizers: + opt.zero_grad(set_to_none=True) + + max_wallclock_ms = 1000.0 * args.max_wallclock_seconds if args.max_wallclock_seconds > 0 else None + + def lr_mul(step: int, elapsed_ms: float) -> float: + if args.warmdown_iters <= 0: + return 1.0 + if max_wallclock_ms is None: + warmdown_start = max(args.iterations - args.warmdown_iters, 0) + return max((args.iterations - step) / max(args.warmdown_iters, 1), 0.0) if warmdown_start <= step < args.iterations else 
1.0 + step_ms = elapsed_ms / max(step, 1) + warmdown_ms = args.warmdown_iters * step_ms + remaining_ms = max(max_wallclock_ms - elapsed_ms, 0.0) + return remaining_ms / max(warmdown_ms, 1e-9) if remaining_ms <= warmdown_ms else 1.0 + + # Warmup primes the compiled forward/backward/optimizer paths, then we restore the + # initial weights/optimizer state so measured training starts from the true init. + if args.warmup_steps > 0: + initial_model_state = {name: tensor.detach().cpu().clone() for name, tensor in base_model.state_dict().items()} + initial_optimizer_states = [copy.deepcopy(opt.state_dict()) for opt in optimizers] + model.train() + for warmup_step in range(args.warmup_steps): + zero_grad_all() + for micro_step in range(grad_accum_steps): + if distributed: + model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 + x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + warmup_loss = model(x, y) + (warmup_loss * grad_scale).backward() + for opt in optimizers: + opt.step() + zero_grad_all() + if args.warmup_steps <= 20 or (warmup_step + 1) % 10 == 0 or warmup_step + 1 == args.warmup_steps: + log0(f"warmup_step:{warmup_step + 1}/{args.warmup_steps}") + base_model.load_state_dict(initial_model_state, strict=True) + for opt, state in zip(optimizers, initial_optimizer_states, strict=True): + opt.load_state_dict(state) + zero_grad_all() + if distributed: + model.require_backward_grad_sync = True + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + + # ----------------------------- + # MAIN TRAINING LOOP + # ----------------------------- + + training_time_ms = 0.0 + stop_after_step: int | None = None + torch.cuda.synchronize() + t0 = time.perf_counter() + + step = 0 + while True: + last_step = step == args.iterations or (stop_after_step is not None and step >= stop_after_step) + + should_validate = 
last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0) + if should_validate: + torch.cuda.synchronize() + training_time_ms += 1000.0 * (time.perf_counter() - t0) + val_loss, val_bpb = eval_val( + args, + model, + rank, + world_size, + device, + grad_accum_steps, + val_tokens, + base_bytes_lut, + has_leading_space_lut, + is_boundary_token_lut, + ) + log0( + f"step:{step}/{args.iterations} val_loss:{val_loss:.4f} val_bpb:{val_bpb:.4f} " + f"train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms / max(step, 1):.2f}ms" + ) + torch.cuda.synchronize() + t0 = time.perf_counter() + + if last_step: + if stop_after_step is not None and step < args.iterations: + log0( + f"stopping_early: wallclock_cap train_time:{training_time_ms:.0f}ms " + f"step:{step}/{args.iterations}" + ) + break + + elapsed_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + scale = lr_mul(step, elapsed_ms) + zero_grad_all() + train_loss = torch.zeros((), device=device) + for micro_step in range(grad_accum_steps): + if distributed: + model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 + x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + loss = model(x, y) + train_loss += loss.detach() + (loss * grad_scale).backward() + train_loss /= grad_accum_steps + + frac = min(step / args.muon_momentum_warmup_steps, 1.0) if args.muon_momentum_warmup_steps > 0 else 1.0 + muon_momentum = (1 - frac) * args.muon_momentum_warmup_start + frac * args.muon_momentum + for group in optimizer_muon.param_groups: + group["momentum"] = muon_momentum + + for opt in optimizers: + for group in opt.param_groups: + group["lr"] = group["base_lr"] * scale + + if args.grad_clip_norm > 0: + torch.nn.utils.clip_grad_norm_(base_model.parameters(), args.grad_clip_norm) + for opt in optimizers: + opt.step() + zero_grad_all() + + step += 1 + approx_training_time_ms = 
training_time_ms + 1000.0 * (time.perf_counter() - t0) + should_log_train = ( + args.train_log_every > 0 + and (step <= 10 or step % args.train_log_every == 0 or stop_after_step is not None) + ) + if should_log_train: + log0( + f"step:{step}/{args.iterations} train_loss:{train_loss.item():.4f} " + f"train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms / step:.2f}ms" + ) + if base_model.use_recurrence and step % args.train_log_every == 0: + diag = recurrence_param_diagnostics(base_model) + if diag: + log0(f"step:{step} {diag}") + + # Needed to sync whether we've reached the wallclock cap. + reached_cap = max_wallclock_ms is not None and approx_training_time_ms >= max_wallclock_ms + if distributed and max_wallclock_ms is not None: + reached_cap_tensor = torch.tensor(int(reached_cap), device=device) + dist.all_reduce(reached_cap_tensor, op=dist.ReduceOp.MAX) + reached_cap = bool(reached_cap_tensor.item()) + if stop_after_step is None and reached_cap: + stop_after_step = step + + log0( + f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB " + f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB" + ) + + # ----------------------------- + # SERIALIZATION + ROUNDTRIP VALIDATION + # ----------------------------- + # Save the raw state (useful for debugging/loading in PyTorch directly), then always produce + # the compressed int8+zlib artifact and validate the round-tripped weights. 
+
+    # Rank 0 saves the raw (bf16/fp32) state dict first; sizes are logged so the
+    # run record shows the uncompressed footprint next to the compressed artifact.
+    if master_process:
+        torch.save(base_model.state_dict(), "final_model.pt")
+        model_bytes = os.path.getsize("final_model.pt")
+        code_bytes = len(code.encode("utf-8"))
+        log0(f"Serialized model: {model_bytes} bytes")
+        log0(f"Code size: {code_bytes} bytes")
+        log0(f"Total submission size: {model_bytes + code_bytes} bytes")
+
+    # NOTE(review): every rank quantizes + zlib-compresses the state dict, but only
+    # rank 0 writes the artifact — the duplicated work on other ranks appears to be
+    # a deliberate simplicity trade-off; confirm before "optimizing" it away.
+    quant_obj, quant_stats = quantize_state_dict_int8(base_model.state_dict())
+    quant_buf = io.BytesIO()
+    torch.save(quant_obj, quant_buf)
+    quant_raw = quant_buf.getvalue()
+    quant_blob = zlib.compress(quant_raw, level=9)
+    quant_raw_bytes = len(quant_raw)
+    if master_process:
+        with open("final_model.int8.ptz", "wb") as f:
+            f.write(quant_blob)
+        quant_file_bytes = os.path.getsize("final_model.int8.ptz")
+        code_bytes = len(code.encode("utf-8"))  # recomputed; same value as the earlier code_bytes
+        ratio = quant_stats["baseline_tensor_bytes"] / max(quant_stats["int8_payload_bytes"], 1)
+        log0(
+            f"Serialized model int8+zlib: {quant_file_bytes} bytes "
+            f"(payload:{quant_stats['int8_payload_bytes']} raw_torch:{quant_raw_bytes} payload_ratio:{ratio:.2f}x)"
+        )
+        log0(f"Total submission size int8+zlib: {quant_file_bytes + code_bytes} bytes")
+
+    # Round-trip validation: the barrier makes all ranks wait for rank 0's write,
+    # then every rank reloads the dequantized weights from disk and re-runs eval,
+    # so the logged loss/bpb reflect exactly what the int8+zlib artifact reproduces.
+    if distributed:
+        dist.barrier()
+    with open("final_model.int8.ptz", "rb") as f:
+        quant_blob_disk = f.read()
+    quant_state = torch.load(io.BytesIO(zlib.decompress(quant_blob_disk)), map_location="cpu")
+    base_model.load_state_dict(dequantize_state_dict_int8(quant_state), strict=True)
+    torch.cuda.synchronize()
+    t_qeval = time.perf_counter()
+    q_val_loss, q_val_bpb = eval_val(
+        args,
+        model,
+        rank,
+        world_size,
+        device,
+        grad_accum_steps,
+        val_tokens,
+        base_bytes_lut,
+        has_leading_space_lut,
+        is_boundary_token_lut,
+    )
+    torch.cuda.synchronize()
+    log0(
+        f"final_int8_zlib_roundtrip val_loss:{q_val_loss:.4f} val_bpb:{q_val_bpb:.4f} "
+        f"eval_time:{1000.0 * (time.perf_counter() - t_qeval):.0f}ms"
+    )
+    log0(f"final_int8_zlib_roundtrip_exact val_loss:{q_val_loss:.8f} val_bpb:{q_val_bpb:.8f}")
+
+    if distributed:
+        dist.destroy_process_group()
+ + +if __name__ == "__main__": + main() From 43f74389dec53518cdc170df318f37c7af45e7e3 Mon Sep 17 00:00:00 2001 From: Alexandr Azizyan Date: Thu, 26 Mar 2026 17:58:03 +0400 Subject: [PATCH 02/10] feat: add screening experiment scripts and logs (7 runs, 2000 steps) --- .../logs/s1_Ap.txt | 1400 ++++++++++++++++ .../logs/s1_Bp.txt | 1412 +++++++++++++++++ .../logs/s1_C.txt | 1412 +++++++++++++++++ .../logs/s1_Cp.txt | 1412 +++++++++++++++++ .../logs/s1_D.txt | 1412 +++++++++++++++++ .../logs/s1_E.txt | 1412 +++++++++++++++++ .../logs/s1_F.txt | 1412 +++++++++++++++++ .../scripts/run_screening.sh | 66 + 8 files changed, 9938 insertions(+) create mode 100644 records/track_non_record_16mb/2026-03-26_RecurrenceFix_3Loop_Birkhoff_OutputLN_TimestepScale/logs/s1_Ap.txt create mode 100644 records/track_non_record_16mb/2026-03-26_RecurrenceFix_3Loop_Birkhoff_OutputLN_TimestepScale/logs/s1_Bp.txt create mode 100644 records/track_non_record_16mb/2026-03-26_RecurrenceFix_3Loop_Birkhoff_OutputLN_TimestepScale/logs/s1_C.txt create mode 100644 records/track_non_record_16mb/2026-03-26_RecurrenceFix_3Loop_Birkhoff_OutputLN_TimestepScale/logs/s1_Cp.txt create mode 100644 records/track_non_record_16mb/2026-03-26_RecurrenceFix_3Loop_Birkhoff_OutputLN_TimestepScale/logs/s1_D.txt create mode 100644 records/track_non_record_16mb/2026-03-26_RecurrenceFix_3Loop_Birkhoff_OutputLN_TimestepScale/logs/s1_E.txt create mode 100644 records/track_non_record_16mb/2026-03-26_RecurrenceFix_3Loop_Birkhoff_OutputLN_TimestepScale/logs/s1_F.txt create mode 100644 records/track_non_record_16mb/2026-03-26_RecurrenceFix_3Loop_Birkhoff_OutputLN_TimestepScale/scripts/run_screening.sh diff --git a/records/track_non_record_16mb/2026-03-26_RecurrenceFix_3Loop_Birkhoff_OutputLN_TimestepScale/logs/s1_Ap.txt b/records/track_non_record_16mb/2026-03-26_RecurrenceFix_3Loop_Birkhoff_OutputLN_TimestepScale/logs/s1_Ap.txt new file mode 100644 index 0000000000..2d72ccf631 --- /dev/null +++ 
b/records/track_non_record_16mb/2026-03-26_RecurrenceFix_3Loop_Birkhoff_OutputLN_TimestepScale/logs/s1_Ap.txt @@ -0,0 +1,1400 @@ +""" +The `train_gpt.py` and `train_gpt_mlx.py` scripts are intended as good launching-off points for new participants, not SOTA configs. We'll accept PRs that tune, improve, or simplify these scripts without significantly increasing complexity, but competitive submissions should stay in the `/records` folder. + +Hard stop: To keep readable for newcomers, let's make sure `train_gpt.py` and `train_gpt_mlx.py` never are longer than 1500 lines. +""" + +from __future__ import annotations + +import copy +import glob +import io +import math +import os +import random +import subprocess +import sys +import time +import uuid +import zlib +from pathlib import Path + +import numpy as np +import sentencepiece as spm +import torch +import torch.distributed as dist +import torch.nn.functional as F +from torch import Tensor, nn +from torch.nn.parallel import DistributedDataParallel as DDP + +# Flash Attention 3 via HuggingFace kernels (Hopper-optimized) +try: + from kernels import get_kernel + cap = torch.cuda.get_device_capability() + repo = "varunneal/flash-attention-3" if cap == (9, 0) else "kernels-community/flash-attn3" + fa3 = get_kernel(repo).flash_attn_interface + FA3_AVAILABLE = True +except Exception: + FA3_AVAILABLE = False + +# ----------------------------- +# HYPERPARAMETERS +# ----------------------------- +# Default Simple Baseline run: +# - 9 transformer blocks at width 512 +# - 8 attention heads with 4 KV heads (GQA) and 2x MLP expansion +# - vocab size 1024, sequence length 1024, tied embeddings +# - 524,288 train tokens per step for 20,000 iterations with a ~10 minute cap + +class Hyperparameters: + # Data paths are shard globs produced by the existing preprocessing pipeline. 
+ data_path = os.environ.get("DATA_PATH", "./data/datasets/fineweb10B_sp1024") + train_files = os.path.join(data_path, "fineweb_train_*.bin") + val_files = os.path.join(data_path, "fineweb_val_*.bin") + tokenizer_path = os.environ.get("TOKENIZER_PATH", "./data/tokenizers/fineweb_1024_bpe.model") + run_id = os.environ.get("RUN_ID", str(uuid.uuid4())) + seed = int(os.environ.get("SEED", 1337)) + + # Validation cadence and batch size. Validation always uses the full fineweb_val split. + val_batch_size = int(os.environ.get("VAL_BATCH_SIZE", 524_288)) + val_loss_every = int(os.environ.get("VAL_LOSS_EVERY", 1000)) + train_log_every = int(os.environ.get("TRAIN_LOG_EVERY", 200)) + + # Training length. + iterations = int(os.environ.get("ITERATIONS", 20000)) + warmdown_iters = int(os.environ.get("WARMDOWN_ITERS", 1200)) + warmup_steps = int(os.environ.get("WARMUP_STEPS", 20)) + train_batch_tokens = int(os.environ.get("TRAIN_BATCH_TOKENS", 524_288)) + train_seq_len = int(os.environ.get("TRAIN_SEQ_LEN", 1024)) + max_wallclock_seconds = float(os.environ.get("MAX_WALLCLOCK_SECONDS", 600.0)) + qk_gain_init = float(os.environ.get("QK_GAIN_INIT", 1.5)) + + # Model shape. + vocab_size = int(os.environ.get("VOCAB_SIZE", 1024)) + num_layers = int(os.environ.get("NUM_LAYERS", 9)) + num_kv_heads = int(os.environ.get("NUM_KV_HEADS", 4)) + model_dim = int(os.environ.get("MODEL_DIM", 512)) + num_heads = int(os.environ.get("NUM_HEADS", 8)) + mlp_mult = int(os.environ.get("MLP_MULT", 2)) + tie_embeddings = bool(int(os.environ.get("TIE_EMBEDDINGS", "1"))) + rope_base = float(os.environ.get("ROPE_BASE", 10000.0)) + logit_softcap = float(os.environ.get("LOGIT_SOFTCAP", 30.0)) + + # Recurrence + stability features. 
+ num_prelude = int(os.environ.get("NUM_PRELUDE", 0)) + num_coda = int(os.environ.get("NUM_CODA", 0)) + num_shared = int(os.environ.get("NUM_SHARED", 0)) + num_loops = int(os.environ.get("NUM_LOOPS", 2)) + use_peri_norm = bool(int(os.environ.get("USE_PERI_NORM", "0"))) + use_birkhoff_mix = bool(int(os.environ.get("USE_BIRKHOFF_MIX", "0"))) + use_timestep_scale = bool(int(os.environ.get("USE_TIMESTEP_SCALE", "0"))) + leaky_relu_slope = float(os.environ.get("LEAKY_RELU_SLOPE", "0.5")) + disable_compile = bool(int(os.environ.get("DISABLE_COMPILE", "0"))) + + # Optimizer hyperparameters. + embed_lr = float(os.environ.get("EMBED_LR", 0.6)) + head_lr = float(os.environ.get("HEAD_LR", 0.008)) + tied_embed_lr = float(os.environ.get("TIED_EMBED_LR", 0.05)) + tied_embed_init_std = float(os.environ.get("TIED_EMBED_INIT_STD", 0.005)) + matrix_lr = float(os.environ.get("MATRIX_LR", 0.04)) + scalar_lr = float(os.environ.get("SCALAR_LR", 0.04)) + muon_momentum = float(os.environ.get("MUON_MOMENTUM", 0.95)) + muon_backend_steps = int(os.environ.get("MUON_BACKEND_STEPS", 5)) + muon_momentum_warmup_start = float(os.environ.get("MUON_MOMENTUM_WARMUP_START", 0.85)) + muon_momentum_warmup_steps = int(os.environ.get("MUON_MOMENTUM_WARMUP_STEPS", 500)) + beta1 = float(os.environ.get("BETA1", 0.9)) + beta2 = float(os.environ.get("BETA2", 0.95)) + adam_eps = float(os.environ.get("ADAM_EPS", 1e-8)) + grad_clip_norm = float(os.environ.get("GRAD_CLIP_NORM", 0.0)) + +# ----------------------------- +# MUON OPTIMIZER +# ----------------------------- +# +# As borrowed from modded-nanogpt +# Background on Muon: https://kellerjordan.github.io/posts/muon/ + +def zeropower_via_newtonschulz5(G: Tensor, steps: int = 10, eps: float = 1e-7) -> Tensor: + # Orthogonalize a 2D update matrix with a fast Newton-Schulz iteration. + # Muon uses this to normalize matrix-shaped gradients before applying them. 
+ a, b, c = (3.4445, -4.7750, 2.0315) + X = G.bfloat16() + X /= X.norm() + eps + transposed = G.size(0) > G.size(1) + if transposed: + X = X.T + for _ in range(steps): + A = X @ X.T + B = b * A + c * A @ A + X = a * X + B @ X + return X.T if transposed else X + + +class Muon(torch.optim.Optimizer): + def __init__(self, params, lr: float, momentum: float, backend_steps: int, nesterov: bool = True): + super().__init__( + params, + dict(lr=lr, momentum=momentum, backend_steps=backend_steps, nesterov=nesterov), + ) + + @torch.no_grad() + def step(self, closure=None): + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + distributed = dist.is_available() and dist.is_initialized() + world_size = dist.get_world_size() if distributed else 1 + rank = dist.get_rank() if distributed else 0 + + for group in self.param_groups: + params = group["params"] + if not params: + continue + lr = group["lr"] + momentum = group["momentum"] + backend_steps = group["backend_steps"] + nesterov = group["nesterov"] + + total_params = sum(int(p.numel()) for p in params) + updates_flat = torch.zeros(total_params, device=params[0].device, dtype=torch.bfloat16) + + curr = 0 + for i, p in enumerate(params): + if i % world_size == rank and p.grad is not None: + g = p.grad + state = self.state[p] + if "momentum_buffer" not in state: + state["momentum_buffer"] = torch.zeros_like(g) + buf = state["momentum_buffer"] + buf.mul_(momentum).add_(g) + if nesterov: + g = g.add(buf, alpha=momentum) + g = zeropower_via_newtonschulz5(g, steps=backend_steps) + # Scale correction from Muon reference implementations. 
+ g *= max(1, g.size(0) / g.size(1)) ** 0.5 + updates_flat[curr : curr + p.numel()] = g.reshape(-1) + curr += p.numel() + + if distributed: + dist.all_reduce(updates_flat, op=dist.ReduceOp.SUM) + + curr = 0 + for p in params: + g = updates_flat[curr : curr + p.numel()].view_as(p).to(dtype=p.dtype) + p.add_(g, alpha=-lr) + curr += p.numel() + + return loss + + +# ----------------------------- +# TOKENIZER-AGNOSTIC EVALUATION SETUP +# ----------------------------- +# +# It's common for small models have a large fraction of their parameters be embeddings, since the 2 * d_model * d_vocab vectors can be gigantic. +# Instead of locking the tokenizer, we let you bring your own and calculate our validation metrics on the average compression of the validation set. +# We calculate BPB (bits-per-byte) instead of validation loss, so we need methods to count the number of bits per token in the tokenizer. +# Note: Submissions that edit the tokenizer will be examined more carefully, since screwing this up might unjustly improve your score. 
+ +def build_sentencepiece_luts( + sp: spm.SentencePieceProcessor, vocab_size: int, device: torch.device +) -> tuple[Tensor, Tensor, Tensor]: + sp_vocab_size = int(sp.vocab_size()) + table_size = max(sp_vocab_size, vocab_size) + base_bytes_np = np.zeros((table_size,), dtype=np.int16) + has_leading_space_np = np.zeros((table_size,), dtype=np.bool_) + is_boundary_token_np = np.ones((table_size,), dtype=np.bool_) + for token_id in range(sp_vocab_size): + if sp.is_control(token_id) or sp.is_unknown(token_id) or sp.is_unused(token_id): + continue + is_boundary_token_np[token_id] = False + if sp.is_byte(token_id): + base_bytes_np[token_id] = 1 + continue + piece = sp.id_to_piece(token_id) + if piece.startswith("▁"): + has_leading_space_np[token_id] = True + piece = piece[1:] + base_bytes_np[token_id] = len(piece.encode("utf-8")) + return ( + torch.tensor(base_bytes_np, dtype=torch.int16, device=device), + torch.tensor(has_leading_space_np, dtype=torch.bool, device=device), + torch.tensor(is_boundary_token_np, dtype=torch.bool, device=device), + ) + + +def load_validation_tokens(pattern: str, seq_len: int) -> Tensor: + files = [Path(p) for p in sorted(glob.glob(pattern))] + if not files: + raise FileNotFoundError(f"No files found for pattern: {pattern}") + # The export pipeline writes the fixed first-50k-doc validation set to fineweb_val_*. 
+ tokens = torch.cat([load_data_shard(file) for file in files]).contiguous() + usable = ((tokens.numel() - 1) // seq_len) * seq_len + if usable <= 0: + raise ValueError(f"Validation split is too short for TRAIN_SEQ_LEN={seq_len}") + return tokens[: usable + 1] + + +def eval_val( + args: Hyperparameters, + model: nn.Module, + rank: int, + world_size: int, + device: torch.device, + grad_accum_steps: int, + val_tokens: Tensor, + base_bytes_lut: Tensor, + has_leading_space_lut: Tensor, + is_boundary_token_lut: Tensor, +) -> tuple[float, float]: + # Validation computes two metrics: + # - val_loss: token cross-entropy (natural log) + # - val_bpb: tokenizer-agnostic compression metric used by the challenge + local_batch_tokens = args.val_batch_size // (world_size * grad_accum_steps) + if local_batch_tokens < args.train_seq_len: + raise ValueError( + "VAL_BATCH_SIZE must provide at least one sequence per rank; " + f"got VAL_BATCH_SIZE={args.val_batch_size}, WORLD_SIZE={world_size}, " + f"GRAD_ACCUM_STEPS={grad_accum_steps}, TRAIN_SEQ_LEN={args.train_seq_len}" + ) + local_batch_seqs = local_batch_tokens // args.train_seq_len + total_seqs = (val_tokens.numel() - 1) // args.train_seq_len + seq_start = (total_seqs * rank) // world_size + seq_end = (total_seqs * (rank + 1)) // world_size + val_loss_sum = torch.zeros((), device=device, dtype=torch.float64) + val_token_count = torch.zeros((), device=device, dtype=torch.float64) + val_byte_count = torch.zeros((), device=device, dtype=torch.float64) + + model.eval() + with torch.inference_mode(): + for batch_seq_start in range(seq_start, seq_end, local_batch_seqs): + batch_seq_end = min(batch_seq_start + local_batch_seqs, seq_end) + raw_start = batch_seq_start * args.train_seq_len + raw_end = batch_seq_end * args.train_seq_len + 1 + local = val_tokens[raw_start:raw_end].to(device=device, dtype=torch.int64, non_blocking=True) + x = local[:-1].reshape(-1, args.train_seq_len) + y = local[1:].reshape(-1, args.train_seq_len) + with 
torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + batch_loss = model(x, y).detach() + batch_token_count = float(y.numel()) + val_loss_sum += batch_loss.to(torch.float64) * batch_token_count + val_token_count += batch_token_count + prev_ids = x.reshape(-1) + tgt_ids = y.reshape(-1) + token_bytes = base_bytes_lut[tgt_ids].to(dtype=torch.int16) + token_bytes += (has_leading_space_lut[tgt_ids] & ~is_boundary_token_lut[prev_ids]).to(dtype=torch.int16) + val_byte_count += token_bytes.to(torch.float64).sum() + + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(val_loss_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(val_token_count, op=dist.ReduceOp.SUM) + dist.all_reduce(val_byte_count, op=dist.ReduceOp.SUM) + + val_loss = val_loss_sum / val_token_count + bits_per_token = val_loss.item() / math.log(2.0) + tokens_per_byte = val_token_count.item() / val_byte_count.item() + model.train() + return float(val_loss.item()), float(bits_per_token * tokens_per_byte) + +# ----------------------------- +# POST-TRAINING QUANTIZATION +# ----------------------------- +# +# It's silly to export our model, which is trained in bf16 and fp32, at that same precision. +# Instead, we get approximately the same model (with a small hit) by quantizing the model to int8 & zlib compressing. +# We can then decompress the model and run in higher precision for evaluation, after closing in under the size limit. 
+
+# Substring patterns naming small "control" parameters (residual mixes, gates, scales)
+# that must stay fp32 both during training and in the exported checkpoint.
+CONTROL_TENSOR_NAME_PATTERNS = tuple(
+    pattern
+    for pattern in os.environ.get(
+        "CONTROL_TENSOR_NAME_PATTERNS",
+        "attn_scale,attn_scales,mlp_scale,mlp_scales,resid_mix,resid_mixes,q_gain,skip_weight,skip_weights,ts_attn,ts_mlp,resid_mix_logit",
+    ).split(",")
+    if pattern
+)
+INT8_KEEP_FLOAT_FP32_NAME_PATTERNS = tuple(
+    pattern
+    for pattern in os.environ.get(
+        "INT8_KEEP_FLOAT_FP32_NAME_PATTERNS",
+        ",".join(CONTROL_TENSOR_NAME_PATTERNS),
+    ).split(",")
+    if pattern
+)
+INT8_KEEP_FLOAT_MAX_NUMEL = 65_536
+INT8_KEEP_FLOAT_STORE_DTYPE = torch.float16
+INT8_PER_ROW_SCALE_DTYPE = torch.float16
+INT8_CLIP_PERCENTILE = 99.99984
+INT8_CLIP_Q = INT8_CLIP_PERCENTILE / 100.0
+
+def tensor_nbytes(t: Tensor) -> int:
+    """Return the payload size of *t* in bytes (numel * element size)."""
+    return int(t.numel()) * int(t.element_size())
+
+def keep_float_tensor(name: str, t: Tensor, passthrough_orig_dtypes: dict[str, str]) -> Tensor:
+    """Prepare a small float tensor for passthrough storage.
+
+    Control tensors stay fp32; other fp32/bf16 tensors are downcast to fp16 and
+    their original dtype is recorded in *passthrough_orig_dtypes* (mutated in place)
+    so dequantization can restore it.
+    """
+    if any(pattern in name for pattern in INT8_KEEP_FLOAT_FP32_NAME_PATTERNS):
+        return t.float().contiguous()
+    if t.dtype in {torch.float32, torch.bfloat16}:
+        passthrough_orig_dtypes[name] = str(t.dtype).removeprefix("torch.")
+        return t.to(dtype=INT8_KEEP_FLOAT_STORE_DTYPE).contiguous()
+    return t
+
+def quantize_float_tensor(t: Tensor) -> tuple[Tensor, Tensor]:
+    """Quantize a float tensor to int8; returns (int8 values, scale).
+
+    2D tensors get a per-row scale (1D tensor); everything else gets a single
+    per-tensor scalar scale. Outliers beyond the INT8_CLIP_Q quantile are clipped.
+    """
+    t32 = t.float()
+    if t32.ndim == 2:
+        # Matrices get one scale per row, which usually tracks output-channel
+        # ranges much better than a single tensor-wide scale.
+        # NOTE(review): the empty-tensor fallback uses torch.empty, which leaves
+        # clip_abs uninitialized — only safe because the result has zero rows in
+        # practice; confirm no (n, 0)-shaped tensors reach this path.
+        clip_abs = (
+            torch.quantile(t32.abs(), INT8_CLIP_Q, dim=1)
+            if t32.numel()
+            else torch.empty((t32.shape[0],), dtype=torch.float32)
+        )
+        clipped = torch.maximum(torch.minimum(t32, clip_abs[:, None]), -clip_abs[:, None])
+        # clamp_min keeps all-zero rows from producing a zero scale (division guard).
+        scale = (clip_abs / 127.0).clamp_min(1.0 / 127.0)
+        q = torch.clamp(torch.round(clipped / scale[:, None]), -127, 127).to(torch.int8).contiguous()
+        return q, scale.to(dtype=INT8_PER_ROW_SCALE_DTYPE).contiguous()
+
+    # Vectors / scalars use a simpler per-tensor scale.
+    clip_abs = float(torch.quantile(t32.abs().flatten(), INT8_CLIP_Q).item()) if t32.numel() else 0.0
+    scale = torch.tensor(clip_abs / 127.0 if clip_abs > 0 else 1.0, dtype=torch.float32)
+    q = torch.clamp(torch.round(torch.clamp(t32, -clip_abs, clip_abs) / scale), -127, 127).to(torch.int8).contiguous()
+    return q, scale
+
+def quantize_state_dict_int8(state_dict: dict[str, Tensor]):
+    # Single supported clean-script export format:
+    # - per-row int8 for 2D float tensors
+    # - per-tensor int8 for other float tensors
+    # - exact passthrough for non-floats
+    # - passthrough for small float tensors, stored as fp16 to save bytes
+    quantized: dict[str, Tensor] = {}
+    scales: dict[str, Tensor] = {}
+    dtypes: dict[str, str] = {}
+    passthrough: dict[str, Tensor] = {}
+    passthrough_orig_dtypes: dict[str, str] = {}
+    qmeta: dict[str, dict[str, object]] = {}
+    stats = dict.fromkeys(
+        ("param_count", "num_tensors", "num_float_tensors", "num_nonfloat_tensors", "baseline_tensor_bytes", "int8_payload_bytes"),
+        0,
+    )
+
+    for name, tensor in state_dict.items():
+        t = tensor.detach().to("cpu").contiguous()
+        stats["param_count"] += int(t.numel())
+        stats["num_tensors"] += 1
+        stats["baseline_tensor_bytes"] += tensor_nbytes(t)
+
+        if not t.is_floating_point():
+            stats["num_nonfloat_tensors"] += 1
+            passthrough[name] = t
+            stats["int8_payload_bytes"] += tensor_nbytes(t)
+            continue
+
+        # Small float tensors are cheap enough to keep directly. We still downcast
+        # fp32/bf16 passthrough tensors to fp16 so metadata does not dominate size.
+        if t.numel() <= INT8_KEEP_FLOAT_MAX_NUMEL:
+            kept = keep_float_tensor(name, t, passthrough_orig_dtypes)
+            passthrough[name] = kept
+            stats["int8_payload_bytes"] += tensor_nbytes(kept)
+            continue
+
+        stats["num_float_tensors"] += 1
+        q, s = quantize_float_tensor(t)
+        # A non-scalar scale means the per-row scheme was used; record it for the loader.
+        if s.ndim > 0:
+            qmeta[name] = {"scheme": "per_row", "axis": 0}
+        quantized[name] = q
+        scales[name] = s
+        dtypes[name] = str(t.dtype).removeprefix("torch.")
+        stats["int8_payload_bytes"] += tensor_nbytes(q) + tensor_nbytes(s)
+
+    obj: dict[str, object] = {
+        "__quant_format__": "int8_clean_per_row_v1",
+        "quantized": quantized,
+        "scales": scales,
+        "dtypes": dtypes,
+        "passthrough": passthrough,
+    }
+    if qmeta:
+        obj["qmeta"] = qmeta
+    if passthrough_orig_dtypes:
+        obj["passthrough_orig_dtypes"] = passthrough_orig_dtypes
+    return obj, stats
+
+def dequantize_state_dict_int8(obj: dict[str, object]) -> dict[str, Tensor]:
+    """Invert quantize_state_dict_int8: rebuild a float state_dict from the export object."""
+    out: dict[str, Tensor] = {}
+    qmeta = obj.get("qmeta", {})
+    passthrough_orig_dtypes = obj.get("passthrough_orig_dtypes", {})
+    for name, q in obj["quantized"].items():
+        dtype = getattr(torch, obj["dtypes"][name])
+        s = obj["scales"][name]
+        if qmeta.get(name, {}).get("scheme") == "per_row" or s.ndim > 0:
+            s = s.to(dtype=torch.float32)
+            # Broadcast the saved row scale back across trailing dimensions.
+            out[name] = (q.float() * s.view(q.shape[0], *([1] * (q.ndim - 1)))).to(dtype=dtype).contiguous()
+        else:
+            scale = float(s.item())
+            out[name] = (q.float() * scale).to(dtype=dtype).contiguous()
+    for name, t in obj["passthrough"].items():
+        # Restore small tensors, undoing the temporary fp16 storage cast if needed.
+        out_t = t.detach().to("cpu").contiguous()
+        orig_dtype = passthrough_orig_dtypes.get(name)
+        if isinstance(orig_dtype, str):
+            out_t = out_t.to(dtype=getattr(torch, orig_dtype)).contiguous()
+        out[name] = out_t
+    return out
+
+
+# -----------------------------
+# DATA LOADING
+# -----------------------------
+
+def load_data_shard(file: Path) -> Tensor:
+    # NOTE(review): the span below appears truncated by extraction — the rest of
+    # load_data_shard's body, the TokenStream class header and __init__, and the
+    # signature of _advance_file are missing. The text was most likely eaten after
+    # the "<" in a dtype string such as np.dtype("<u2"); restore from the original
+    # patch before relying on this region.
+    header_bytes = 256 * np.dtype(" None:
+        self.file_idx = (self.file_idx + 1) % len(self.files)
+        self.tokens = load_data_shard(self.files[self.file_idx])
+        self.pos = 0
+
+    def take(self, n: int) -> Tensor:
+        """Return the next *n* tokens from the stream, spanning shard files as needed."""
+        chunks: list[Tensor] = []
+        remaining = n
+        while remaining > 0:
+            avail = self.tokens.numel() - self.pos
+            if avail <= 0:
+                # Current shard exhausted: wrap to the next file and retry.
+                self._advance_file()
+                continue
+            k = min(remaining, avail)
+            chunks.append(self.tokens[self.pos : self.pos + k])
+            self.pos += k
+            remaining -= k
+        return chunks[0] if len(chunks) == 1 else torch.cat(chunks)
+
+
+class DistributedTokenLoader:
+    # Each call consumes a contiguous chunk from the shared token stream, then slices out
+    # one disjoint span per rank. The extra "+1" token lets us build (x, y) by shifting.
+    def __init__(self, pattern: str, rank: int, world_size: int, device: torch.device):
+        self.rank = rank
+        self.world_size = world_size
+        self.device = device
+        self.stream = TokenStream(pattern)
+
+    def next_batch(self, global_tokens: int, seq_len: int, grad_accum_steps: int) -> tuple[Tensor, Tensor]:
+        """Return one micro-batch (x, y) of shape (local_tokens // seq_len, seq_len) for this rank."""
+        local_tokens = global_tokens // (self.world_size * grad_accum_steps)
+        per_rank_span = local_tokens + 1
+        # Every rank consumes the same chunk from the shared stream, then keeps its own slice,
+        # so the stream position stays in lockstep across ranks without communication.
+        chunk = self.stream.take(per_rank_span * self.world_size)
+        start = self.rank * per_rank_span
+        local = chunk[start : start + per_rank_span].to(dtype=torch.int64)
+        x = local[:-1].reshape(-1, seq_len)
+        y = local[1:].reshape(-1, seq_len)
+        return x.to(self.device, non_blocking=True), y.to(self.device, non_blocking=True)
+
+# -----------------------------
+# TRANSFORMER MODULES
+# -----------------------------
+
+class RMSNorm(nn.Module):
+    # Parameter-free RMSNorm; eps=None defers to F.rms_norm's default epsilon.
+    def __init__(self, eps: float | None = None):
+        super().__init__()
+        self.eps = eps
+
+    def forward(self, x: Tensor) -> Tensor:
+        return F.rms_norm(x, (x.size(-1),), eps=self.eps)
+
+
+class CastedLinear(nn.Linear):
+    # Keep weights in fp32 for optimizer/state quality, cast at matmul time for bf16 compute.
+    def forward(self, x: Tensor) -> Tensor:
+        bias = self.bias.to(x.dtype) if self.bias is not None else None
+        return F.linear(x, self.weight.to(x.dtype), bias)
+
+
+def restore_low_dim_params_to_fp32(module: nn.Module) -> None:
+    # Keep small/control parameters in fp32 even when the model body runs in bf16.
+    with torch.no_grad():
+        for name, param in module.named_parameters():
+            if (param.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS)) and param.dtype != torch.float32:
+                param.data = param.data.float()
+
+
+class Rotary(nn.Module):
+    # Caches cos/sin tables per sequence length on the current device.
+    def __init__(self, dim: int, base: float = 10000.0):
+        super().__init__()
+        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim))
+        # Non-persistent buffer: excluded from state_dict, so checkpoints stay small.
+        self.register_buffer("inv_freq", inv_freq, persistent=False)
+        self._seq_len_cached = 0
+        self._cos_cached: Tensor | None = None
+        self._sin_cached: Tensor | None = None
+
+    def forward(self, seq_len: int, device: torch.device, dtype: torch.dtype) -> tuple[Tensor, Tensor]:
+        # Rebuild the tables only when the sequence length or device changes; the
+        # cache is kept in fp32 and cast to the requested dtype on the way out.
+        if (
+            self._cos_cached is None
+            or self._sin_cached is None
+            or self._seq_len_cached != seq_len
+            or self._cos_cached.device != device
+        ):
+            t = torch.arange(seq_len, device=device, dtype=self.inv_freq.dtype)
+            freqs = torch.outer(t, self.inv_freq.to(device))
+            self._cos_cached = freqs.cos()[None, None, :, :]
+            self._sin_cached = freqs.sin()[None, None, :, :]
+            self._seq_len_cached = seq_len
+        return self._cos_cached.to(dtype=dtype), self._sin_cached.to(dtype=dtype)
+
+
+def apply_rotary_emb(x: Tensor, cos: Tensor, sin: Tensor) -> Tensor:
+    # Rotate-half RoPE: pair each feature in the first half with its partner in the second half.
+    half = x.size(-1) // 2
+    x1, x2 = x[..., :half], x[..., half:]
+    return torch.cat((x1 * cos + x2 * sin, x1 * (-sin) + x2 * cos), dim=-1)
+
+
+class CausalSelfAttention(nn.Module):
+    def __init__(
+        self,
+        dim: int,
+        num_heads: int,
+        num_kv_heads: int,
+        rope_base: float,
+        qk_gain_init: float,
+    ):
+        super().__init__()
+        if dim % num_heads != 0:
+            raise ValueError("model_dim must be divisible by num_heads")
+        if num_heads % num_kv_heads != 0:
+            raise ValueError("num_heads must be divisible by num_kv_heads")
+        self.num_heads = num_heads
+        self.num_kv_heads = num_kv_heads
+        self.head_dim = dim // num_heads
+        if self.head_dim % 2 != 0:
+            raise ValueError("head_dim must be even for RoPE")
+        kv_dim = self.num_kv_heads * self.head_dim
+        self.c_q = CastedLinear(dim, dim, bias=False)
+        self.c_k = CastedLinear(dim, kv_dim, bias=False)
+        self.c_v = CastedLinear(dim, kv_dim, bias=False)
+        self.proj = CastedLinear(dim, dim, bias=False)
+        # Marker consumed by GPT._init_weights: output projection starts at zero.
+        self.proj._zero_init = True
+        self.q_gain = 
nn.Parameter(torch.full((num_heads,), qk_gain_init, dtype=torch.float32))
+        self.rotary = Rotary(self.head_dim, base=rope_base)
+
+    def forward(self, x: Tensor) -> Tensor:
+        bsz, seqlen, dim = x.shape
+        # (B, T, dim) -> (B, H, T, head_dim) for q; KV use num_kv_heads (GQA).
+        q = self.c_q(x).reshape(bsz, seqlen, self.num_heads, self.head_dim).transpose(1, 2)
+        k = self.c_k(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2)
+        v = self.c_v(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2)
+        # QK-norm before RoPE, then a learned per-head gain on q.
+        q = F.rms_norm(q, (q.size(-1),))
+        k = F.rms_norm(k, (k.size(-1),))
+        cos, sin = self.rotary(seqlen, x.device, q.dtype)
+        q = apply_rotary_emb(q, cos, sin)
+        k = apply_rotary_emb(k, cos, sin)
+        q = q * self.q_gain.to(dtype=q.dtype)[None, :, None, None]
+        if FA3_AVAILABLE:
+            # FA3 expects (B, T, H, D) — transpose from (B, H, T, D)
+            y = fa3.flash_attn_func(
+                q.transpose(1, 2),
+                k.transpose(1, 2),
+                v.transpose(1, 2),
+                causal=True,
+            ).transpose(1, 2)
+        else:
+            y = F.scaled_dot_product_attention(
+                q, k, v,
+                attn_mask=None,
+                is_causal=True,
+                enable_gqa=(self.num_kv_heads != self.num_heads),
+            )
+        y = y.transpose(1, 2).contiguous().reshape(bsz, seqlen, dim)
+        return self.proj(y)
+
+
+class MLP(nn.Module):
+    # leaky_relu^2 MLP (slope=0.5 matches SOTA, PRs #493/#518)
+    def __init__(self, dim: int, mlp_mult: int, leaky_relu_slope: float = 0.5):
+        super().__init__()
+        hidden = mlp_mult * dim
+        self.fc = CastedLinear(dim, hidden, bias=False)
+        self.proj = CastedLinear(hidden, dim, bias=False)
+        # Marker consumed by GPT._init_weights: output projection starts at zero.
+        self.proj._zero_init = True
+        self.leaky_relu_slope = leaky_relu_slope
+
+    def forward(self, x: Tensor) -> Tensor:
+        # Activation is leaky_relu followed by square (the "relu^2" family).
+        x = F.leaky_relu(self.fc(x), negative_slope=self.leaky_relu_slope)
+        return self.proj(x.square())
+
+
+class TimestepScaling(nn.Module):
+    """Learned per-iteration scale vectors (arXiv:2410.01405, ICML 2025)."""
+    def __init__(self, effective_layers: int, dim: int):
+        super().__init__()
+        # One (dim,) gamma per virtual layer position, for the attention and MLP branches.
+        self.attn_gamma = nn.Parameter(torch.ones(effective_layers, dim, dtype=torch.float32))
+        self.mlp_gamma = nn.Parameter(torch.ones(effective_layers, dim, dtype=torch.float32))
+
+
+class Block(nn.Module):
+    def __init__(
+        self,
+        dim: int,
+        num_heads: int,
+        num_kv_heads: int,
+        mlp_mult: int,
+        rope_base: float,
+        qk_gain_init: float,
+        use_peri_norm: bool = False,
+        use_birkhoff_mix: bool = False,
+        leaky_relu_slope: float = 0.5,
+    ):
+        super().__init__()
+        self.use_peri_norm = use_peri_norm
+        self.use_birkhoff_mix = use_birkhoff_mix
+        self.attn_norm = RMSNorm()
+        if use_peri_norm:
+            self.mlp_out_norm = RMSNorm()  # output norm: bounds MLP contribution, preserves input magnitude
+        else:
+            self.mlp_norm = RMSNorm()
+        self.attn = CausalSelfAttention(dim, num_heads, num_kv_heads, rope_base, qk_gain_init)
+        self.mlp = MLP(dim, mlp_mult, leaky_relu_slope)
+        # Learned per-channel gates on the two residual branches (fp32 control params).
+        self.attn_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32))
+        self.mlp_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32))
+        if use_birkhoff_mix:
+            # Birkhoff B₂ parameterization: sigmoid(logit) ∈ (0,1) → doubly stochastic
+            self.resid_mix_logit = nn.Parameter(torch.zeros(dim, dtype=torch.float32))
+        else:
+            # Unconstrained mix; initialized to identity (1·x + 0·x0).
+            self.resid_mix = nn.Parameter(torch.stack((torch.ones(dim), torch.zeros(dim))).float())
+
+    def forward(self, x: Tensor, x0: Tensor,
+                ts_attn_gamma: Tensor | None = None,
+                ts_mlp_gamma: Tensor | None = None) -> Tensor:
+        # Blend the running stream with the embedding stream x0 before the block body.
+        if self.use_birkhoff_mix:
+            alpha = torch.sigmoid(self.resid_mix_logit.to(dtype=x.dtype))[None, None, :]
+            x = alpha * x + (1 - alpha) * x0
+        else:
+            mix = self.resid_mix.to(dtype=x.dtype)
+            x = mix[0][None, None, :] * x + mix[1][None, None, :] * x0
+        attn_out = self.attn(self.attn_norm(x))
+        attn_s = self.attn_scale.to(dtype=x.dtype)[None, None, :]
+        if ts_attn_gamma is not None:
+            attn_s = attn_s * ts_attn_gamma[None, None, :]
+        x = x + attn_s * attn_out
+        # peri-norm: normalize the MLP *output*; otherwise pre-norm on the MLP input.
+        mlp_out = self.mlp_out_norm(self.mlp(x)) if self.use_peri_norm else self.mlp(self.mlp_norm(x))
+        mlp_s = self.mlp_scale.to(dtype=x.dtype)[None, None, :]
+        if ts_mlp_gamma is not None:
+            mlp_s = mlp_s * ts_mlp_gamma[None, None, 
:]
+        x = x + mlp_s * mlp_out
+        return x
+
+
+class GPT(nn.Module):
+    def __init__(
+        self,
+        vocab_size: int,
+        num_layers: int,
+        model_dim: int,
+        num_heads: int,
+        num_kv_heads: int,
+        mlp_mult: int,
+        tie_embeddings: bool,
+        tied_embed_init_std: float,
+        logit_softcap: float,
+        rope_base: float,
+        qk_gain_init: float,
+        num_prelude: int = 0,
+        num_coda: int = 0,
+        num_shared: int = 0,
+        num_loops: int = 2,
+        use_peri_norm: bool = False,
+        use_birkhoff_mix: bool = False,
+        use_timestep_scale: bool = False,
+        leaky_relu_slope: float = 0.5,
+    ):
+        super().__init__()
+        if logit_softcap <= 0.0:
+            raise ValueError(f"logit_softcap must be positive, got {logit_softcap}")
+        self.tie_embeddings = tie_embeddings
+        self.tied_embed_init_std = tied_embed_init_std
+        self.logit_softcap = logit_softcap
+        self.tok_emb = nn.Embedding(vocab_size, model_dim)
+        # num_shared > 0 selects the recurrent (prelude/shared-loop/coda) topology;
+        # otherwise the model falls back to a standard U-Net stack of num_layers blocks.
+        self.use_recurrence = num_shared > 0
+        self.num_loops = num_loops if self.use_recurrence else 1
+        self.num_prelude = num_prelude if self.use_recurrence else 0
+        self.num_coda = num_coda if self.use_recurrence else 0
+
+        block_kwargs = dict(
+            dim=model_dim, num_heads=num_heads, num_kv_heads=num_kv_heads,
+            mlp_mult=mlp_mult, rope_base=rope_base, qk_gain_init=qk_gain_init,
+            use_peri_norm=use_peri_norm if self.use_recurrence else False,
+            use_birkhoff_mix=use_birkhoff_mix if self.use_recurrence else False,
+            leaky_relu_slope=leaky_relu_slope,
+        )
+
+        if self.use_recurrence:
+            self.prelude_blocks = nn.ModuleList([Block(**block_kwargs) for _ in range(self.num_prelude)])
+            self.shared_blocks = nn.ModuleList([Block(**block_kwargs) for _ in range(num_shared)])
+            self.coda_blocks = nn.ModuleList([Block(**block_kwargs) for _ in range(self.num_coda)])
+            # "effective" depth counts the shared blocks once per loop iteration.
+            effective_layers = self.num_prelude + num_shared * self.num_loops + self.num_coda
+            self.timestep_scale = TimestepScaling(effective_layers, model_dim) if use_timestep_scale else None
+        else:
+            # Standard U-Net path
+            self.num_encoder_layers = num_layers // 2
+            self.num_decoder_layers = num_layers - self.num_encoder_layers
+            self.num_skip_weights = min(self.num_encoder_layers, self.num_decoder_layers)
+            self.skip_weights = nn.Parameter(torch.ones(self.num_skip_weights, model_dim, dtype=torch.float32))
+            self.blocks = nn.ModuleList([Block(**block_kwargs) for _ in range(num_layers)])
+            self.timestep_scale = None
+
+        self.final_norm = RMSNorm()
+        self.lm_head = None if tie_embeddings else CastedLinear(model_dim, vocab_size, bias=False)
+        if self.lm_head is not None:
+            self.lm_head._zero_init = True
+        self._init_weights()
+
+    def _init_weights(self) -> None:
+        # Tied embedding gets an explicit normal init; all _zero_init projections start at 0.
+        if self.tie_embeddings:
+            nn.init.normal_(self.tok_emb.weight, mean=0.0, std=self.tied_embed_init_std)
+        for module in self.modules():
+            if isinstance(module, nn.Linear) and getattr(module, "_zero_init", False):
+                nn.init.zeros_(module.weight)
+
+    def _get_ts(self, v: int) -> tuple[Tensor | None, Tensor | None]:
+        # Per-virtual-layer (attn_gamma, mlp_gamma) pair, or (None, None) when disabled.
+        if self.timestep_scale is None:
+            return None, None
+        return (self.timestep_scale.attn_gamma[v],
+                self.timestep_scale.mlp_gamma[v])
+
+    def forward(self, input_ids: Tensor, target_ids: Tensor) -> Tensor:
+        x = self.tok_emb(input_ids)
+        x = F.rms_norm(x, (x.size(-1),))
+        x0 = x
+
+        if self.use_recurrence:
+            # v indexes the virtual layer position across prelude, loops, and coda.
+            v = 0
+            for block in self.prelude_blocks:
+                ag, mg = self._get_ts(v)
+                x = block(x, x0, ag, mg)
+                v += 1
+            for _loop in range(self.num_loops):
+                for block in self.shared_blocks:
+                    ag, mg = self._get_ts(v)
+                    x = block(x, x0, ag, mg)
+                    v += 1
+            for block in self.coda_blocks:
+                ag, mg = self._get_ts(v)
+                x = block(x, x0, ag, mg)
+                v += 1
+        else:
+            # U-Net: encoder outputs are pushed, then popped (LIFO) into the decoder half.
+            skips: list[Tensor] = []
+            for i in range(self.num_encoder_layers):
+                x = self.blocks[i](x, x0)
+                skips.append(x)
+            for i in range(self.num_decoder_layers):
+                if skips:
+                    x = x + self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips.pop()
+                x = self.blocks[self.num_encoder_layers + i](x, x0)
+
+        x = self.final_norm(x).reshape(-1, x.size(-1))
+        targets = target_ids.reshape(-1)
+        if self.tie_embeddings:
+            logits_proj = F.linear(x, self.tok_emb.weight)
+        else:
+            if self.lm_head is None:
+                raise RuntimeError("lm_head is required when tie_embeddings=False")
+            logits_proj = self.lm_head(x)
+        # tanh soft-cap bounds logits to (-logit_softcap, logit_softcap).
+        logits = self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap)
+        return F.cross_entropy(logits.float(), targets, reduction="mean")
+
+
+@torch.no_grad()
+def recurrence_param_diagnostics(gpt: GPT) -> str:
+    """Log parameter norms for recurrence health monitoring. No forward pass."""
+    if not gpt.use_recurrence:
+        return ""
+    parts: list[str] = []
+
+    # Birkhoff alpha stats per shared block
+    for i, block in enumerate(gpt.shared_blocks):
+        if hasattr(block, "resid_mix_logit"):
+            a = torch.sigmoid(block.resid_mix_logit)
+            parts.append(f"shared{i}_alpha:mean={a.mean().item():.3f},std={a.std().item():.3f}")
+
+    # Effective MLP/attn contribution scale per virtual layer position:
+    # mlp_scale * ts_mlp_gamma[v] is the learned multiplier on MLP output.
+    # We report RMS (norm / sqrt(numel)) to make it scale-independent.
+    v = 0
+    all_blocks = (
+        list(gpt.prelude_blocks)
+        + list(gpt.shared_blocks) * gpt.num_loops
+        + list(gpt.coda_blocks)
+    )
+    mlp_norms: list[str] = []
+    attn_norms: list[str] = []
+    for block in all_blocks:
+        ms = block.mlp_scale.norm().item()
+        asc = block.attn_scale.norm().item()
+        d = block.mlp_scale.numel() ** 0.5
+        if gpt.timestep_scale is not None:
+            ms_g = gpt.timestep_scale.mlp_gamma[v].norm().item()
+            as_g = gpt.timestep_scale.attn_gamma[v].norm().item()
+            mlp_norms.append(f"v{v}:{ms * ms_g / d:.4f}")
+            attn_norms.append(f"v{v}:{asc * as_g / d:.4f}")
+        else:
+            mlp_norms.append(f"v{v}:{ms / d:.4f}")
+            attn_norms.append(f"v{v}:{asc / d:.4f}")
+        v += 1
+    parts.append("eff_mlp_scale:[" + " ".join(mlp_norms) + "]")
+    parts.append("eff_attn_scale:[" + " ".join(attn_norms) + "]")
+    return " ".join(parts)
+
+
+# -----------------------------
+# TRAINING
+# -----------------------------
+
+def main() -> None:
+    global zeropower_via_newtonschulz5
+
+    code = 
Path(__file__).read_text(encoding="utf-8")
+    args = Hyperparameters()
+    zeropower_via_newtonschulz5 = torch.compile(zeropower_via_newtonschulz5)
+
+    # -----------------------------
+    # DISTRIBUTED + CUDA SETUP
+    # -----------------------------
+
+    distributed = "RANK" in os.environ and "WORLD_SIZE" in os.environ
+    rank = int(os.environ.get("RANK", "0"))
+    world_size = int(os.environ.get("WORLD_SIZE", "1"))
+    local_rank = int(os.environ.get("LOCAL_RANK", "0"))
+    if world_size <= 0:
+        raise ValueError(f"WORLD_SIZE must be positive, got {world_size}")
+    # Global batch is split into 8 micro-steps total, shared across ranks.
+    if 8 % world_size != 0:
+        raise ValueError(f"WORLD_SIZE={world_size} must divide 8 so grad_accum_steps stays integral")
+    grad_accum_steps = 8 // world_size
+    grad_scale = 1.0 / grad_accum_steps
+    if not torch.cuda.is_available():
+        raise RuntimeError("CUDA is required")
+    device = torch.device("cuda", local_rank)
+    torch.cuda.set_device(device)
+    if distributed:
+        dist.init_process_group(backend="nccl", device_id=device)
+        dist.barrier()
+    master_process = rank == 0
+
+    # Fast math knobs
+    torch.backends.cuda.matmul.allow_tf32 = True
+    torch.backends.cudnn.allow_tf32 = True
+    from torch.backends.cuda import enable_cudnn_sdp, enable_flash_sdp, enable_math_sdp, enable_mem_efficient_sdp
+
+    enable_cudnn_sdp(False)
+    enable_flash_sdp(True)
+    enable_mem_efficient_sdp(False)
+    enable_math_sdp(False)
+
+    logfile = None
+    if master_process:
+        os.makedirs("logs", exist_ok=True)
+        logfile = f"logs/{args.run_id}.txt"
+        print(logfile)
+
+    # Rank-0-only logger; console=False writes to the logfile without printing.
+    def log0(msg: str, console: bool = True) -> None:
+        if not master_process:
+            return
+        if console:
+            print(msg)
+        if logfile is not None:
+            with open(logfile, "a", encoding="utf-8") as f:
+                print(msg, file=f)
+
+    # Embed the full script source plus environment info in the log for reproducibility.
+    log0(code, console=False)
+    log0("=" * 100, console=False)
+    log0(f"Running Python {sys.version}", console=False)
+    log0(f"Running PyTorch {torch.__version__}", console=False)
+    log0(
+        subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, check=False).stdout,
+        console=False,
+    )
+    log0("=" * 100, console=False)
+
+    # -----------------------------
+    # TOKENIZER + VALIDATION METRIC SETUP
+    # -----------------------------
+
+    random.seed(args.seed)
+    np.random.seed(args.seed)
+    torch.manual_seed(args.seed)
+    torch.cuda.manual_seed_all(args.seed)
+
+    if not args.tokenizer_path.endswith(".model"):
+        raise ValueError(f"Script only setup for SentencePiece .model file: {args.tokenizer_path}")
+    sp = spm.SentencePieceProcessor(model_file=args.tokenizer_path)
+    if int(sp.vocab_size()) != args.vocab_size:
+        raise ValueError(
+            f"VOCAB_SIZE={args.vocab_size} does not match tokenizer vocab_size={int(sp.vocab_size())}"
+        )
+    dataset_dir = Path(args.data_path).resolve()
+    actual_train_files = len(list(dataset_dir.glob("fineweb_train_*.bin")))
+    val_tokens = load_validation_tokens(args.val_files, args.train_seq_len)
+    base_bytes_lut, has_leading_space_lut, is_boundary_token_lut = build_sentencepiece_luts(
+        sp, args.vocab_size, device
+    )
+    log0(f"val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path={args.tokenizer_path}")
+    log0(f"train_loader:dataset:{dataset_dir.name} train_shards:{actual_train_files}")
+    log0(f"val_loader:shards pattern={args.val_files} tokens:{val_tokens.numel() - 1}")
+
+    # -----------------------------
+    # MODEL + OPTIMIZER SETUP
+    # -----------------------------
+
+    base_model = GPT(
+        vocab_size=args.vocab_size,
+        num_layers=args.num_layers,
+        model_dim=args.model_dim,
+        num_heads=args.num_heads,
+        num_kv_heads=args.num_kv_heads,
+        mlp_mult=args.mlp_mult,
+        tie_embeddings=args.tie_embeddings,
+        tied_embed_init_std=args.tied_embed_init_std,
+        logit_softcap=args.logit_softcap,
+        rope_base=args.rope_base,
+        qk_gain_init=args.qk_gain_init,
+        num_prelude=args.num_prelude,
+        num_coda=args.num_coda,
+        num_shared=args.num_shared,
+        num_loops=args.num_loops,
+        use_peri_norm=args.use_peri_norm,
+        use_birkhoff_mix=args.use_birkhoff_mix,
+        use_timestep_scale=args.use_timestep_scale,
+        leaky_relu_slope=args.leaky_relu_slope,
+    ).to(device).bfloat16()
+    # Body runs in bf16, but linear weights and small/control params are kept fp32
+    # (CastedLinear casts at matmul time).
+    for module in base_model.modules():
+        if isinstance(module, CastedLinear):
+            module.float()
+    restore_low_dim_params_to_fp32(base_model)
+    if args.disable_compile:
+        compiled_model = base_model
+    else:
+        compiled_model = torch.compile(base_model, dynamic=False, fullgraph=True)
+    model: nn.Module = DDP(compiled_model, device_ids=[local_rank], broadcast_buffers=False) if distributed else compiled_model
+
+    # Optimizer split:
+    # - token embedding (Adam) uses EMBED_LR
+    # - untied lm_head (Adam) uses HEAD_LR
+    # - matrix params in transformer blocks use MATRIX_LR via Muon
+    # - vectors/scalars use SCALAR_LR via Adam
+    if base_model.use_recurrence:
+        all_block_lists = [base_model.prelude_blocks, base_model.shared_blocks, base_model.coda_blocks]
+        block_named_params = []
+        for bl in all_block_lists:
+            block_named_params.extend(bl.named_parameters())
+    else:
+        block_named_params = list(base_model.blocks.named_parameters())
+    matrix_params = [
+        p
+        for name, p in block_named_params
+        if p.ndim == 2 and not any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS)
+    ]
+    scalar_params = [
+        p
+        for name, p in block_named_params
+        if p.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS)
+    ]
+    if hasattr(base_model, "skip_weights") and base_model.skip_weights.numel() > 0:
+        scalar_params.append(base_model.skip_weights)
+    if base_model.timestep_scale is not None:
+        for p in base_model.timestep_scale.parameters():
+            scalar_params.append(p)
+    token_lr = args.tied_embed_lr if args.tie_embeddings else args.embed_lr
+    # "base_lr" is stashed in each param group so lr_mul can rescale it every step.
+    optimizer_tok = torch.optim.Adam(
+        [{"params": [base_model.tok_emb.weight], "lr": token_lr, "base_lr": token_lr}],
+        betas=(args.beta1, args.beta2),
+        eps=args.adam_eps,
+        fused=True,
+    )
+    optimizer_muon = Muon(
+        matrix_params,
+        lr=args.matrix_lr,
+        momentum=args.muon_momentum,
+        backend_steps=args.muon_backend_steps,
+    )
+    for group in optimizer_muon.param_groups:
+        group["base_lr"] = args.matrix_lr
+    optimizer_scalar = torch.optim.Adam(
+        [{"params": scalar_params, "lr": args.scalar_lr, "base_lr": args.scalar_lr}],
+        betas=(args.beta1, args.beta2),
+        eps=args.adam_eps,
+        fused=True,
+    )
+    optimizers: list[torch.optim.Optimizer] = [optimizer_tok, optimizer_muon, optimizer_scalar]
+    if base_model.lm_head is not None:
+        optimizer_head = torch.optim.Adam(
+            [{"params": [base_model.lm_head.weight], "lr": args.head_lr, "base_lr": args.head_lr}],
+            betas=(args.beta1, args.beta2),
+            eps=args.adam_eps,
+            fused=True,
+        )
+        optimizers.insert(1, optimizer_head)
+
+    n_params = sum(p.numel() for p in base_model.parameters())
+    log0(f"model_params:{n_params}")
+    log0(f"world_size:{world_size} grad_accum_steps:{grad_accum_steps}")
+    log0("sdp_backends:cudnn=False flash=True mem_efficient=False math=False")
+    log0(f"flash_attention_3:{FA3_AVAILABLE}")
+    log0(f"attention_mode:gqa num_heads:{args.num_heads} num_kv_heads:{args.num_kv_heads}")
+    log0(
+        f"tie_embeddings:{args.tie_embeddings} embed_lr:{token_lr} "
+        f"head_lr:{args.head_lr if base_model.lm_head is not None else 0.0} "
+        f"matrix_lr:{args.matrix_lr} scalar_lr:{args.scalar_lr}"
+    )
+    log0(
+        f"train_batch_tokens:{args.train_batch_tokens} train_seq_len:{args.train_seq_len} "
+        f"iterations:{args.iterations} warmup_steps:{args.warmup_steps} "
+        f"max_wallclock_seconds:{args.max_wallclock_seconds:.3f}"
+    )
+    log0(f"seed:{args.seed}")
+    if base_model.use_recurrence:
+        eff = base_model.num_prelude + len(base_model.shared_blocks) * base_model.num_loops + base_model.num_coda
+        log0(f"recurrence:enabled prelude:{base_model.num_prelude} shared:{len(base_model.shared_blocks)} "
+             f"loops:{base_model.num_loops} coda:{base_model.num_coda} effective_layers:{eff}")
+        log0(f"peri_norm:{args.use_peri_norm} birkhoff_mix:{args.use_birkhoff_mix}")
+        if base_model.timestep_scale is not None:
+            ts_params = sum(p.numel() for p in base_model.timestep_scale.parameters())
+            log0(f"timestep_scale:enabled params:{ts_params}")
+        else:
+            log0("timestep_scale:disabled")
+    else:
+        log0(f"recurrence:disabled num_layers:{args.num_layers}")
+    compile_mode = "disabled" if args.disable_compile else "fullgraph=True"
+    log0(f"compile_mode:{compile_mode}")
+
+    # -----------------------------
+    # DATA LOADER & MODEL WARMUP
+    # -----------------------------
+
+    train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device)
+
+    def zero_grad_all() -> None:
+        for opt in optimizers:
+            opt.zero_grad(set_to_none=True)
+
+    max_wallclock_ms = 1000.0 * args.max_wallclock_seconds if args.max_wallclock_seconds > 0 else None
+
+    # LR multiplier: fixed linear warmdown over the last warmdown_iters steps, or —
+    # when a wallclock cap is set — a warmdown sized from the measured per-step time.
+    def lr_mul(step: int, elapsed_ms: float) -> float:
+        if args.warmdown_iters <= 0:
+            return 1.0
+        if max_wallclock_ms is None:
+            warmdown_start = max(args.iterations - args.warmdown_iters, 0)
+            return max((args.iterations - step) / max(args.warmdown_iters, 1), 0.0) if warmdown_start <= step < args.iterations else 1.0
+        step_ms = elapsed_ms / max(step, 1)
+        warmdown_ms = args.warmdown_iters * step_ms
+        remaining_ms = max(max_wallclock_ms - elapsed_ms, 0.0)
+        return remaining_ms / max(warmdown_ms, 1e-9) if remaining_ms <= warmdown_ms else 1.0
+
+    # Warmup primes the compiled forward/backward/optimizer paths, then we restore the
+    # initial weights/optimizer state so measured training starts from the true init.
+    if args.warmup_steps > 0:
+        initial_model_state = {name: tensor.detach().cpu().clone() for name, tensor in base_model.state_dict().items()}
+        initial_optimizer_states = [copy.deepcopy(opt.state_dict()) for opt in optimizers]
+        model.train()
+        for warmup_step in range(args.warmup_steps):
+            zero_grad_all()
+            for micro_step in range(grad_accum_steps):
+                if distributed:
+                    # Only sync grads on the last micro-step of the accumulation window.
+                    model.require_backward_grad_sync = micro_step == grad_accum_steps - 1
+                x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps)
+                with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True):
+                    warmup_loss = model(x, y)
+                (warmup_loss * grad_scale).backward()
+            for opt in optimizers:
+                opt.step()
+            zero_grad_all()
+            if args.warmup_steps <= 20 or (warmup_step + 1) % 10 == 0 or warmup_step + 1 == args.warmup_steps:
+                log0(f"warmup_step:{warmup_step + 1}/{args.warmup_steps}")
+        base_model.load_state_dict(initial_model_state, strict=True)
+        for opt, state in zip(optimizers, initial_optimizer_states, strict=True):
+            opt.load_state_dict(state)
+        zero_grad_all()
+        if distributed:
+            model.require_backward_grad_sync = True
+        # Fresh loader so measured training sees the data stream from the beginning.
+        train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device)
+
+    # -----------------------------
+    # MAIN TRAINING LOOP
+    # -----------------------------
+
+    training_time_ms = 0.0
+    stop_after_step: int | None = None
+    torch.cuda.synchronize()
+    t0 = time.perf_counter()
+
+    step = 0
+    while True:
+        last_step = step == args.iterations or (stop_after_step is not None and step >= stop_after_step)
+
+        should_validate = last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0)
+        if should_validate:
+            # Pause the training clock during validation.
+            torch.cuda.synchronize()
+            training_time_ms += 1000.0 * (time.perf_counter() - t0)
+            val_loss, val_bpb = eval_val(
+                args,
+                model,
+                rank,
+                world_size,
+                device,
+                grad_accum_steps,
+                val_tokens,
+                base_bytes_lut,
+                has_leading_space_lut,
+                is_boundary_token_lut,
+            )
+            log0(
+                f"step:{step}/{args.iterations} val_loss:{val_loss:.4f} val_bpb:{val_bpb:.4f} "
+                f"train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms / max(step, 1):.2f}ms"
+            )
+            torch.cuda.synchronize()
+            t0 = time.perf_counter()
+
+        if last_step:
+            if stop_after_step is not None and step < args.iterations:
+                log0(
+                    f"stopping_early: wallclock_cap train_time:{training_time_ms:.0f}ms "
+                    f"step:{step}/{args.iterations}"
+                )
+            break
+
+        elapsed_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0)
+        scale = lr_mul(step, elapsed_ms)
+        zero_grad_all()
+        train_loss = torch.zeros((), device=device)
+        for micro_step in range(grad_accum_steps):
+            if distributed:
+                model.require_backward_grad_sync = micro_step == grad_accum_steps - 1
+            x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps)
+            with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True):
+                loss = model(x, y)
+            train_loss += loss.detach()
+            (loss * grad_scale).backward()
+        train_loss /= grad_accum_steps
+
+        # Linearly ramp Muon momentum from warmup_start to its final value.
+        frac = min(step / args.muon_momentum_warmup_steps, 1.0) if args.muon_momentum_warmup_steps > 0 else 1.0
+        muon_momentum = (1 - frac) * args.muon_momentum_warmup_start + frac * args.muon_momentum
+        for group in optimizer_muon.param_groups:
+            group["momentum"] = muon_momentum
+
+        for opt in optimizers:
+            for group in opt.param_groups:
+                group["lr"] = group["base_lr"] * scale
+
+        if args.grad_clip_norm > 0:
+            torch.nn.utils.clip_grad_norm_(base_model.parameters(), args.grad_clip_norm)
+        for opt in optimizers:
+            opt.step()
+        zero_grad_all()
+
+        step += 1
+        approx_training_time_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0)
+        should_log_train = (
+            args.train_log_every > 0
+            and (step <= 10 or step % args.train_log_every == 0 or stop_after_step is not None)
+        )
+        if should_log_train:
+            log0(
+                f"step:{step}/{args.iterations} train_loss:{train_loss.item():.4f} "
+                f"train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms / step:.2f}ms"
+            )
+            if base_model.use_recurrence and step % args.train_log_every == 0:
+                diag = recurrence_param_diagnostics(base_model)
+                if diag:
+                    log0(f"step:{step} {diag}")
+
+        # Needed to sync whether we've reached the wallclock cap.
+        reached_cap = max_wallclock_ms is not None and approx_training_time_ms >= max_wallclock_ms
+        if distributed and max_wallclock_ms is not None:
+            reached_cap_tensor = torch.tensor(int(reached_cap), device=device)
+            dist.all_reduce(reached_cap_tensor, op=dist.ReduceOp.MAX)
+            reached_cap = bool(reached_cap_tensor.item())
+        if stop_after_step is None and reached_cap:
+            stop_after_step = step
+
+    log0(
+        f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB "
+        f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB"
+    )
+
+    # -----------------------------
+    # SERIALIZATION + ROUNDTRIP VALIDATION
+    # -----------------------------
+    # Save the raw state (useful for debugging/loading in PyTorch directly), then always produce
+    # the compressed int8+zlib artifact and validate the round-tripped weights.
+
+    if master_process:
+        torch.save(base_model.state_dict(), "final_model.pt")
+        model_bytes = os.path.getsize("final_model.pt")
+        code_bytes = len(code.encode("utf-8"))
+        log0(f"Serialized model: {model_bytes} bytes")
+        log0(f"Code size: {code_bytes} bytes")
+        log0(f"Total submission size: {model_bytes + code_bytes} bytes")
+
+    quant_obj, quant_stats = quantize_state_dict_int8(base_model.state_dict())
+    quant_buf = io.BytesIO()
+    torch.save(quant_obj, quant_buf)
+    quant_raw = quant_buf.getvalue()
+    quant_blob = zlib.compress(quant_raw, level=9)
+    quant_raw_bytes = len(quant_raw)
+    if master_process:
+        with open("final_model.int8.ptz", "wb") as f:
+            f.write(quant_blob)
+        quant_file_bytes = os.path.getsize("final_model.int8.ptz")
+        code_bytes = len(code.encode("utf-8"))
+        ratio = quant_stats["baseline_tensor_bytes"] / max(quant_stats["int8_payload_bytes"], 1)
+        log0(
+            f"Serialized model int8+zlib: {quant_file_bytes} bytes "
+            f"(payload:{quant_stats['int8_payload_bytes']} raw_torch:{quant_raw_bytes} payload_ratio:{ratio:.2f}x)"
+        )
+        log0(f"Total submission size int8+zlib: {quant_file_bytes + code_bytes} bytes")
+
+    # Barrier so non-master ranks wait for the artifact before reading it back.
+    # NOTE(review): all ranks read "final_model.int8.ptz" — assumes a shared
+    # filesystem (or single-node run); confirm for multi-node setups.
+    if distributed:
+        dist.barrier()
+    with open("final_model.int8.ptz", "rb") as f:
+        quant_blob_disk = f.read()
+    quant_state = torch.load(io.BytesIO(zlib.decompress(quant_blob_disk)), map_location="cpu")
+    base_model.load_state_dict(dequantize_state_dict_int8(quant_state), strict=True)
+    torch.cuda.synchronize()
+    t_qeval = time.perf_counter()
+    q_val_loss, q_val_bpb = eval_val(
+        args,
+        model,
+        rank,
+        world_size,
+        device,
+        grad_accum_steps,
+        val_tokens,
+        base_bytes_lut,
+        has_leading_space_lut,
+        is_boundary_token_lut,
+    )
+    torch.cuda.synchronize()
+    log0(
+        f"final_int8_zlib_roundtrip val_loss:{q_val_loss:.4f} val_bpb:{q_val_bpb:.4f} "
+        f"eval_time:{1000.0 * (time.perf_counter() - t_qeval):.0f}ms"
+    )
+    log0(f"final_int8_zlib_roundtrip_exact val_loss:{q_val_loss:.8f} val_bpb:{q_val_bpb:.8f}")
+
+    if distributed:
+        dist.destroy_process_group()
+ + +if __name__ == "__main__": + main() + +==================================================================================================== +Running Python 3.12.3 (main, Nov 6 2025, 13:44:16) [GCC 13.3.0] +Running PyTorch 2.9.1+cu128 +Wed Mar 25 13:13:20 2026 ++-----------------------------------------------------------------------------------------+ +| NVIDIA-SMI 565.57.01 Driver Version: 565.57.01 CUDA Version: 12.7 | +|-----------------------------------------+------------------------+----------------------+ +| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC | +| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. | +| | | MIG M. | +|=========================================+========================+======================| +| 0 NVIDIA H100 80GB HBM3 On | 00000000:BE:00.0 Off | 0 | +| N/A 34C P0 104W / 700W | 1184MiB / 81559MiB | 3% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ + ++-----------------------------------------------------------------------------------------+ +| Processes: | +| GPU GI CI PID Type Process name GPU Memory | +| ID ID Usage | +|=========================================================================================| ++-----------------------------------------------------------------------------------------+ + +==================================================================================================== +val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path=./data/tokenizers/fineweb_1024_bpe.model +train_loader:dataset:fineweb10B_sp1024 train_shards:10 +val_loader:shards pattern=./data/datasets/fineweb10B_sp1024/fineweb_val_*.bin tokens:62021632 +model_params:15222848 +world_size:1 grad_accum_steps:8 +sdp_backends:cudnn=False flash=True mem_efficient=False math=False +flash_attention_3:True +attention_mode:gqa num_heads:8 num_kv_heads:4 +tie_embeddings:True embed_lr:0.05 head_lr:0.0 matrix_lr:0.04 scalar_lr:0.04 
+train_batch_tokens:524288 train_seq_len:1024 iterations:2000 warmup_steps:20 max_wallclock_seconds:0.000 +seed:1337 +recurrence:disabled num_layers:8 +compile_mode:fullgraph=True +warmup_step:1/20 +warmup_step:2/20 +warmup_step:3/20 +warmup_step:4/20 +warmup_step:5/20 +warmup_step:6/20 +warmup_step:7/20 +warmup_step:8/20 +warmup_step:9/20 +warmup_step:10/20 +warmup_step:11/20 +warmup_step:12/20 +warmup_step:13/20 +warmup_step:14/20 +warmup_step:15/20 +warmup_step:16/20 +warmup_step:17/20 +warmup_step:18/20 +warmup_step:19/20 +warmup_step:20/20 +step:0/2000 val_loss:6.9369 val_bpb:4.1084 train_time:0ms step_avg:0.03ms +step:1/2000 train_loss:6.9369 train_time:309ms step_avg:308.95ms +step:2/2000 train_loss:16.9406 train_time:676ms step_avg:338.05ms +step:3/2000 train_loss:8.7636 train_time:994ms step_avg:331.34ms +step:4/2000 train_loss:6.5160 train_time:1377ms step_avg:344.25ms +step:5/2000 train_loss:6.5786 train_time:1709ms step_avg:341.88ms +step:6/2000 train_loss:6.6115 train_time:2063ms step_avg:343.79ms +step:7/2000 train_loss:6.2776 train_time:2418ms step_avg:345.42ms +step:8/2000 train_loss:6.1292 train_time:2799ms step_avg:349.93ms +step:9/2000 train_loss:6.0395 train_time:3188ms step_avg:354.19ms +step:10/2000 train_loss:5.9571 train_time:3516ms step_avg:351.61ms +step:200/2000 train_loss:2.7626 train_time:71390ms step_avg:356.95ms +step:400/2000 train_loss:2.3840 train_time:143891ms step_avg:359.73ms +step:500/2000 val_loss:2.4916 val_bpb:1.4757 train_time:179330ms step_avg:358.66ms +step:600/2000 train_loss:2.4989 train_time:215992ms step_avg:359.99ms +step:800/2000 train_loss:2.3534 train_time:287017ms step_avg:358.77ms +step:1000/2000 train_loss:2.3616 train_time:358788ms step_avg:358.79ms +step:1000/2000 val_loss:2.3313 val_bpb:1.3808 train_time:358792ms step_avg:358.79ms +step:1200/2000 train_loss:2.2821 train_time:430276ms step_avg:358.56ms +step:1400/2000 train_loss:2.3092 train_time:501281ms step_avg:358.06ms +step:1500/2000 val_loss:2.2491 
val_bpb:1.3321 train_time:537489ms step_avg:358.33ms +step:1600/2000 train_loss:2.1959 train_time:573472ms step_avg:358.42ms +step:1800/2000 train_loss:2.2346 train_time:645004ms step_avg:358.34ms +step:2000/2000 train_loss:2.2507 train_time:715765ms step_avg:357.88ms +step:2000/2000 val_loss:2.2007 val_bpb:1.3034 train_time:715769ms step_avg:357.88ms +peak memory allocated: 8570 MiB reserved: 9042 MiB +Serialized model: 59873165 bytes +Code size: 56522 bytes +Total submission size: 59929687 bytes +Serialized model int8+zlib: 13336026 bytes (payload:15329536 raw_torch:15369657 payload_ratio:3.90x) +Total submission size int8+zlib: 13392548 bytes +final_int8_zlib_roundtrip val_loss:2.2035 val_bpb:1.3050 eval_time:8873ms +final_int8_zlib_roundtrip_exact val_loss:2.20345363 val_bpb:1.30500936 diff --git a/records/track_non_record_16mb/2026-03-26_RecurrenceFix_3Loop_Birkhoff_OutputLN_TimestepScale/logs/s1_Bp.txt b/records/track_non_record_16mb/2026-03-26_RecurrenceFix_3Loop_Birkhoff_OutputLN_TimestepScale/logs/s1_Bp.txt new file mode 100644 index 0000000000..061fc9812f --- /dev/null +++ b/records/track_non_record_16mb/2026-03-26_RecurrenceFix_3Loop_Birkhoff_OutputLN_TimestepScale/logs/s1_Bp.txt @@ -0,0 +1,1412 @@ +""" +The `train_gpt.py` and `train_gpt_mlx.py` scripts are intended as good launching-off points for new participants, not SOTA configs. We'll accept PRs that tune, improve, or simplify these scripts without significantly increasing complexity, but competitive submissions should stay in the `/records` folder. + +Hard stop: To keep readable for newcomers, let's make sure `train_gpt.py` and `train_gpt_mlx.py` never are longer than 1500 lines. 
+""" + +from __future__ import annotations + +import copy +import glob +import io +import math +import os +import random +import subprocess +import sys +import time +import uuid +import zlib +from pathlib import Path + +import numpy as np +import sentencepiece as spm +import torch +import torch.distributed as dist +import torch.nn.functional as F +from torch import Tensor, nn +from torch.nn.parallel import DistributedDataParallel as DDP + +# Flash Attention 3 via HuggingFace kernels (Hopper-optimized) +try: + from kernels import get_kernel + cap = torch.cuda.get_device_capability() + repo = "varunneal/flash-attention-3" if cap == (9, 0) else "kernels-community/flash-attn3" + fa3 = get_kernel(repo).flash_attn_interface + FA3_AVAILABLE = True +except Exception: + FA3_AVAILABLE = False + +# ----------------------------- +# HYPERPARAMETERS +# ----------------------------- +# Default Simple Baseline run: +# - 9 transformer blocks at width 512 +# - 8 attention heads with 4 KV heads (GQA) and 2x MLP expansion +# - vocab size 1024, sequence length 1024, tied embeddings +# - 524,288 train tokens per step for 20,000 iterations with a ~10 minute cap + +class Hyperparameters: + # Data paths are shard globs produced by the existing preprocessing pipeline. + data_path = os.environ.get("DATA_PATH", "./data/datasets/fineweb10B_sp1024") + train_files = os.path.join(data_path, "fineweb_train_*.bin") + val_files = os.path.join(data_path, "fineweb_val_*.bin") + tokenizer_path = os.environ.get("TOKENIZER_PATH", "./data/tokenizers/fineweb_1024_bpe.model") + run_id = os.environ.get("RUN_ID", str(uuid.uuid4())) + seed = int(os.environ.get("SEED", 1337)) + + # Validation cadence and batch size. Validation always uses the full fineweb_val split. + val_batch_size = int(os.environ.get("VAL_BATCH_SIZE", 524_288)) + val_loss_every = int(os.environ.get("VAL_LOSS_EVERY", 1000)) + train_log_every = int(os.environ.get("TRAIN_LOG_EVERY", 200)) + + # Training length. 
+ iterations = int(os.environ.get("ITERATIONS", 20000)) + warmdown_iters = int(os.environ.get("WARMDOWN_ITERS", 1200)) + warmup_steps = int(os.environ.get("WARMUP_STEPS", 20)) + train_batch_tokens = int(os.environ.get("TRAIN_BATCH_TOKENS", 524_288)) + train_seq_len = int(os.environ.get("TRAIN_SEQ_LEN", 1024)) + max_wallclock_seconds = float(os.environ.get("MAX_WALLCLOCK_SECONDS", 600.0)) + qk_gain_init = float(os.environ.get("QK_GAIN_INIT", 1.5)) + + # Model shape. + vocab_size = int(os.environ.get("VOCAB_SIZE", 1024)) + num_layers = int(os.environ.get("NUM_LAYERS", 9)) + num_kv_heads = int(os.environ.get("NUM_KV_HEADS", 4)) + model_dim = int(os.environ.get("MODEL_DIM", 512)) + num_heads = int(os.environ.get("NUM_HEADS", 8)) + mlp_mult = int(os.environ.get("MLP_MULT", 2)) + tie_embeddings = bool(int(os.environ.get("TIE_EMBEDDINGS", "1"))) + rope_base = float(os.environ.get("ROPE_BASE", 10000.0)) + logit_softcap = float(os.environ.get("LOGIT_SOFTCAP", 30.0)) + + # Recurrence + stability features. + num_prelude = int(os.environ.get("NUM_PRELUDE", 0)) + num_coda = int(os.environ.get("NUM_CODA", 0)) + num_shared = int(os.environ.get("NUM_SHARED", 0)) + num_loops = int(os.environ.get("NUM_LOOPS", 2)) + use_peri_norm = bool(int(os.environ.get("USE_PERI_NORM", "0"))) + use_birkhoff_mix = bool(int(os.environ.get("USE_BIRKHOFF_MIX", "0"))) + use_timestep_scale = bool(int(os.environ.get("USE_TIMESTEP_SCALE", "0"))) + leaky_relu_slope = float(os.environ.get("LEAKY_RELU_SLOPE", "0.5")) + disable_compile = bool(int(os.environ.get("DISABLE_COMPILE", "0"))) + + # Optimizer hyperparameters. 
+ embed_lr = float(os.environ.get("EMBED_LR", 0.6)) + head_lr = float(os.environ.get("HEAD_LR", 0.008)) + tied_embed_lr = float(os.environ.get("TIED_EMBED_LR", 0.05)) + tied_embed_init_std = float(os.environ.get("TIED_EMBED_INIT_STD", 0.005)) + matrix_lr = float(os.environ.get("MATRIX_LR", 0.04)) + scalar_lr = float(os.environ.get("SCALAR_LR", 0.04)) + muon_momentum = float(os.environ.get("MUON_MOMENTUM", 0.95)) + muon_backend_steps = int(os.environ.get("MUON_BACKEND_STEPS", 5)) + muon_momentum_warmup_start = float(os.environ.get("MUON_MOMENTUM_WARMUP_START", 0.85)) + muon_momentum_warmup_steps = int(os.environ.get("MUON_MOMENTUM_WARMUP_STEPS", 500)) + beta1 = float(os.environ.get("BETA1", 0.9)) + beta2 = float(os.environ.get("BETA2", 0.95)) + adam_eps = float(os.environ.get("ADAM_EPS", 1e-8)) + grad_clip_norm = float(os.environ.get("GRAD_CLIP_NORM", 0.0)) + +# ----------------------------- +# MUON OPTIMIZER +# ----------------------------- +# +# As borrowed from modded-nanogpt +# Background on Muon: https://kellerjordan.github.io/posts/muon/ + +def zeropower_via_newtonschulz5(G: Tensor, steps: int = 10, eps: float = 1e-7) -> Tensor: + # Orthogonalize a 2D update matrix with a fast Newton-Schulz iteration. + # Muon uses this to normalize matrix-shaped gradients before applying them. 
+ a, b, c = (3.4445, -4.7750, 2.0315) + X = G.bfloat16() + X /= X.norm() + eps + transposed = G.size(0) > G.size(1) + if transposed: + X = X.T + for _ in range(steps): + A = X @ X.T + B = b * A + c * A @ A + X = a * X + B @ X + return X.T if transposed else X + + +class Muon(torch.optim.Optimizer): + def __init__(self, params, lr: float, momentum: float, backend_steps: int, nesterov: bool = True): + super().__init__( + params, + dict(lr=lr, momentum=momentum, backend_steps=backend_steps, nesterov=nesterov), + ) + + @torch.no_grad() + def step(self, closure=None): + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + distributed = dist.is_available() and dist.is_initialized() + world_size = dist.get_world_size() if distributed else 1 + rank = dist.get_rank() if distributed else 0 + + for group in self.param_groups: + params = group["params"] + if not params: + continue + lr = group["lr"] + momentum = group["momentum"] + backend_steps = group["backend_steps"] + nesterov = group["nesterov"] + + total_params = sum(int(p.numel()) for p in params) + updates_flat = torch.zeros(total_params, device=params[0].device, dtype=torch.bfloat16) + + curr = 0 + for i, p in enumerate(params): + if i % world_size == rank and p.grad is not None: + g = p.grad + state = self.state[p] + if "momentum_buffer" not in state: + state["momentum_buffer"] = torch.zeros_like(g) + buf = state["momentum_buffer"] + buf.mul_(momentum).add_(g) + if nesterov: + g = g.add(buf, alpha=momentum) + g = zeropower_via_newtonschulz5(g, steps=backend_steps) + # Scale correction from Muon reference implementations. 
+                    g *= max(1, g.size(0) / g.size(1)) ** 0.5
+                    updates_flat[curr : curr + p.numel()] = g.reshape(-1)
+                curr += p.numel()
+
+            if distributed:
+                dist.all_reduce(updates_flat, op=dist.ReduceOp.SUM)
+
+            curr = 0
+            for p in params:
+                g = updates_flat[curr : curr + p.numel()].view_as(p).to(dtype=p.dtype)
+                p.add_(g, alpha=-lr)
+                curr += p.numel()
+
+        return loss
+
+
+# -----------------------------
+# TOKENIZER-AGNOSTIC EVALUATION SETUP
+# -----------------------------
+#
+# It's common for small models to have a large fraction of their parameters be embeddings, since the 2 * d_model * d_vocab vectors can be gigantic.
+# Instead of locking the tokenizer, we let you bring your own and calculate our validation metrics on the average compression of the validation set.
+# We calculate BPB (bits-per-byte) instead of validation loss, so we need methods to count the number of bytes per token in the tokenizer.
+# Note: Submissions that edit the tokenizer will be examined more carefully, since screwing this up might unjustly improve your score.
+ +def build_sentencepiece_luts( + sp: spm.SentencePieceProcessor, vocab_size: int, device: torch.device +) -> tuple[Tensor, Tensor, Tensor]: + sp_vocab_size = int(sp.vocab_size()) + table_size = max(sp_vocab_size, vocab_size) + base_bytes_np = np.zeros((table_size,), dtype=np.int16) + has_leading_space_np = np.zeros((table_size,), dtype=np.bool_) + is_boundary_token_np = np.ones((table_size,), dtype=np.bool_) + for token_id in range(sp_vocab_size): + if sp.is_control(token_id) or sp.is_unknown(token_id) or sp.is_unused(token_id): + continue + is_boundary_token_np[token_id] = False + if sp.is_byte(token_id): + base_bytes_np[token_id] = 1 + continue + piece = sp.id_to_piece(token_id) + if piece.startswith("▁"): + has_leading_space_np[token_id] = True + piece = piece[1:] + base_bytes_np[token_id] = len(piece.encode("utf-8")) + return ( + torch.tensor(base_bytes_np, dtype=torch.int16, device=device), + torch.tensor(has_leading_space_np, dtype=torch.bool, device=device), + torch.tensor(is_boundary_token_np, dtype=torch.bool, device=device), + ) + + +def load_validation_tokens(pattern: str, seq_len: int) -> Tensor: + files = [Path(p) for p in sorted(glob.glob(pattern))] + if not files: + raise FileNotFoundError(f"No files found for pattern: {pattern}") + # The export pipeline writes the fixed first-50k-doc validation set to fineweb_val_*. 
+ tokens = torch.cat([load_data_shard(file) for file in files]).contiguous() + usable = ((tokens.numel() - 1) // seq_len) * seq_len + if usable <= 0: + raise ValueError(f"Validation split is too short for TRAIN_SEQ_LEN={seq_len}") + return tokens[: usable + 1] + + +def eval_val( + args: Hyperparameters, + model: nn.Module, + rank: int, + world_size: int, + device: torch.device, + grad_accum_steps: int, + val_tokens: Tensor, + base_bytes_lut: Tensor, + has_leading_space_lut: Tensor, + is_boundary_token_lut: Tensor, +) -> tuple[float, float]: + # Validation computes two metrics: + # - val_loss: token cross-entropy (natural log) + # - val_bpb: tokenizer-agnostic compression metric used by the challenge + local_batch_tokens = args.val_batch_size // (world_size * grad_accum_steps) + if local_batch_tokens < args.train_seq_len: + raise ValueError( + "VAL_BATCH_SIZE must provide at least one sequence per rank; " + f"got VAL_BATCH_SIZE={args.val_batch_size}, WORLD_SIZE={world_size}, " + f"GRAD_ACCUM_STEPS={grad_accum_steps}, TRAIN_SEQ_LEN={args.train_seq_len}" + ) + local_batch_seqs = local_batch_tokens // args.train_seq_len + total_seqs = (val_tokens.numel() - 1) // args.train_seq_len + seq_start = (total_seqs * rank) // world_size + seq_end = (total_seqs * (rank + 1)) // world_size + val_loss_sum = torch.zeros((), device=device, dtype=torch.float64) + val_token_count = torch.zeros((), device=device, dtype=torch.float64) + val_byte_count = torch.zeros((), device=device, dtype=torch.float64) + + model.eval() + with torch.inference_mode(): + for batch_seq_start in range(seq_start, seq_end, local_batch_seqs): + batch_seq_end = min(batch_seq_start + local_batch_seqs, seq_end) + raw_start = batch_seq_start * args.train_seq_len + raw_end = batch_seq_end * args.train_seq_len + 1 + local = val_tokens[raw_start:raw_end].to(device=device, dtype=torch.int64, non_blocking=True) + x = local[:-1].reshape(-1, args.train_seq_len) + y = local[1:].reshape(-1, args.train_seq_len) + with 
torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + batch_loss = model(x, y).detach() + batch_token_count = float(y.numel()) + val_loss_sum += batch_loss.to(torch.float64) * batch_token_count + val_token_count += batch_token_count + prev_ids = x.reshape(-1) + tgt_ids = y.reshape(-1) + token_bytes = base_bytes_lut[tgt_ids].to(dtype=torch.int16) + token_bytes += (has_leading_space_lut[tgt_ids] & ~is_boundary_token_lut[prev_ids]).to(dtype=torch.int16) + val_byte_count += token_bytes.to(torch.float64).sum() + + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(val_loss_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(val_token_count, op=dist.ReduceOp.SUM) + dist.all_reduce(val_byte_count, op=dist.ReduceOp.SUM) + + val_loss = val_loss_sum / val_token_count + bits_per_token = val_loss.item() / math.log(2.0) + tokens_per_byte = val_token_count.item() / val_byte_count.item() + model.train() + return float(val_loss.item()), float(bits_per_token * tokens_per_byte) + +# ----------------------------- +# POST-TRAINING QUANTIZATION +# ----------------------------- +# +# It's silly to export our model, which is trained in bf16 and fp32, at that same precision. +# Instead, we get approximately the same model (with a small hit) by quantizing the model to int8 & zlib compressing. +# We can then decompress the model and run in higher precision for evaluation, after closing in under the size limit. 
+ +CONTROL_TENSOR_NAME_PATTERNS = tuple( + pattern + for pattern in os.environ.get( + "CONTROL_TENSOR_NAME_PATTERNS", + "attn_scale,attn_scales,mlp_scale,mlp_scales,resid_mix,resid_mixes,q_gain,skip_weight,skip_weights,ts_attn,ts_mlp,resid_mix_logit", + ).split(",") + if pattern +) +INT8_KEEP_FLOAT_FP32_NAME_PATTERNS = tuple( + pattern + for pattern in os.environ.get( + "INT8_KEEP_FLOAT_FP32_NAME_PATTERNS", + ",".join(CONTROL_TENSOR_NAME_PATTERNS), + ).split(",") + if pattern +) +INT8_KEEP_FLOAT_MAX_NUMEL = 65_536 +INT8_KEEP_FLOAT_STORE_DTYPE = torch.float16 +INT8_PER_ROW_SCALE_DTYPE = torch.float16 +INT8_CLIP_PERCENTILE = 99.99984 +INT8_CLIP_Q = INT8_CLIP_PERCENTILE / 100.0 + +def tensor_nbytes(t: Tensor) -> int: + return int(t.numel()) * int(t.element_size()) + +def keep_float_tensor(name: str, t: Tensor, passthrough_orig_dtypes: dict[str, str]) -> Tensor: + if any(pattern in name for pattern in INT8_KEEP_FLOAT_FP32_NAME_PATTERNS): + return t.float().contiguous() + if t.dtype in {torch.float32, torch.bfloat16}: + passthrough_orig_dtypes[name] = str(t.dtype).removeprefix("torch.") + return t.to(dtype=INT8_KEEP_FLOAT_STORE_DTYPE).contiguous() + return t + +def quantize_float_tensor(t: Tensor) -> tuple[Tensor, Tensor]: + t32 = t.float() + if t32.ndim == 2: + # Matrices get one scale per row, which usually tracks output-channel + # ranges much better than a single tensor-wide scale. + clip_abs = ( + torch.quantile(t32.abs(), INT8_CLIP_Q, dim=1) + if t32.numel() + else torch.empty((t32.shape[0],), dtype=torch.float32) + ) + clipped = torch.maximum(torch.minimum(t32, clip_abs[:, None]), -clip_abs[:, None]) + scale = (clip_abs / 127.0).clamp_min(1.0 / 127.0) + q = torch.clamp(torch.round(clipped / scale[:, None]), -127, 127).to(torch.int8).contiguous() + return q, scale.to(dtype=INT8_PER_ROW_SCALE_DTYPE).contiguous() + + # Vectors / scalars use a simpler per-tensor scale. 
+ clip_abs = float(torch.quantile(t32.abs().flatten(), INT8_CLIP_Q).item()) if t32.numel() else 0.0 + scale = torch.tensor(clip_abs / 127.0 if clip_abs > 0 else 1.0, dtype=torch.float32) + q = torch.clamp(torch.round(torch.clamp(t32, -clip_abs, clip_abs) / scale), -127, 127).to(torch.int8).contiguous() + return q, scale + +def quantize_state_dict_int8(state_dict: dict[str, Tensor]): + # Single supported clean-script export format: + # - per-row int8 for 2D float tensors + # - per-tensor int8 for other float tensors + # - exact passthrough for non-floats + # - passthrough for small float tensors, stored as fp16 to save bytes + quantized: dict[str, Tensor] = {} + scales: dict[str, Tensor] = {} + dtypes: dict[str, str] = {} + passthrough: dict[str, Tensor] = {} + passthrough_orig_dtypes: dict[str, str] = {} + qmeta: dict[str, dict[str, object]] = {} + stats = dict.fromkeys( + ("param_count", "num_tensors", "num_float_tensors", "num_nonfloat_tensors", "baseline_tensor_bytes", "int8_payload_bytes"), + 0, + ) + + for name, tensor in state_dict.items(): + t = tensor.detach().to("cpu").contiguous() + stats["param_count"] += int(t.numel()) + stats["num_tensors"] += 1 + stats["baseline_tensor_bytes"] += tensor_nbytes(t) + + if not t.is_floating_point(): + stats["num_nonfloat_tensors"] += 1 + passthrough[name] = t + stats["int8_payload_bytes"] += tensor_nbytes(t) + continue + + # Small float tensors are cheap enough to keep directly. We still downcast + # fp32/bf16 passthrough tensors to fp16 so metadata does not dominate size. 
+ if t.numel() <= INT8_KEEP_FLOAT_MAX_NUMEL: + kept = keep_float_tensor(name, t, passthrough_orig_dtypes) + passthrough[name] = kept + stats["int8_payload_bytes"] += tensor_nbytes(kept) + continue + + stats["num_float_tensors"] += 1 + q, s = quantize_float_tensor(t) + if s.ndim > 0: + qmeta[name] = {"scheme": "per_row", "axis": 0} + quantized[name] = q + scales[name] = s + dtypes[name] = str(t.dtype).removeprefix("torch.") + stats["int8_payload_bytes"] += tensor_nbytes(q) + tensor_nbytes(s) + + obj: dict[str, object] = { + "__quant_format__": "int8_clean_per_row_v1", + "quantized": quantized, + "scales": scales, + "dtypes": dtypes, + "passthrough": passthrough, + } + if qmeta: + obj["qmeta"] = qmeta + if passthrough_orig_dtypes: + obj["passthrough_orig_dtypes"] = passthrough_orig_dtypes + return obj, stats + +def dequantize_state_dict_int8(obj: dict[str, object]) -> dict[str, Tensor]: + out: dict[str, Tensor] = {} + qmeta = obj.get("qmeta", {}) + passthrough_orig_dtypes = obj.get("passthrough_orig_dtypes", {}) + for name, q in obj["quantized"].items(): + dtype = getattr(torch, obj["dtypes"][name]) + s = obj["scales"][name] + if qmeta.get(name, {}).get("scheme") == "per_row" or s.ndim > 0: + s = s.to(dtype=torch.float32) + # Broadcast the saved row scale back across trailing dimensions. + out[name] = (q.float() * s.view(q.shape[0], *([1] * (q.ndim - 1)))).to(dtype=dtype).contiguous() + else: + scale = float(s.item()) + out[name] = (q.float() * scale).to(dtype=dtype).contiguous() + for name, t in obj["passthrough"].items(): + # Restore small tensors, undoing the temporary fp16 storage cast if needed. 
+ out_t = t.detach().to("cpu").contiguous() + orig_dtype = passthrough_orig_dtypes.get(name) + if isinstance(orig_dtype, str): + out_t = out_t.to(dtype=getattr(torch, orig_dtype)).contiguous() + out[name] = out_t + return out + + +# ----------------------------- +# DATA LOADING +# ----------------------------- + +def load_data_shard(file: Path) -> Tensor: + header_bytes = 256 * np.dtype(" None: + self.file_idx = (self.file_idx + 1) % len(self.files) + self.tokens = load_data_shard(self.files[self.file_idx]) + self.pos = 0 + + def take(self, n: int) -> Tensor: + chunks: list[Tensor] = [] + remaining = n + while remaining > 0: + avail = self.tokens.numel() - self.pos + if avail <= 0: + self._advance_file() + continue + k = min(remaining, avail) + chunks.append(self.tokens[self.pos : self.pos + k]) + self.pos += k + remaining -= k + return chunks[0] if len(chunks) == 1 else torch.cat(chunks) + + +class DistributedTokenLoader: + # Each call consumes a contiguous chunk from the shared token stream, then slices out + # one disjoint span per rank. The extra "+1" token lets us build (x, y) by shifting. 
+ def __init__(self, pattern: str, rank: int, world_size: int, device: torch.device): + self.rank = rank + self.world_size = world_size + self.device = device + self.stream = TokenStream(pattern) + + def next_batch(self, global_tokens: int, seq_len: int, grad_accum_steps: int) -> tuple[Tensor, Tensor]: + local_tokens = global_tokens // (self.world_size * grad_accum_steps) + per_rank_span = local_tokens + 1 + chunk = self.stream.take(per_rank_span * self.world_size) + start = self.rank * per_rank_span + local = chunk[start : start + per_rank_span].to(dtype=torch.int64) + x = local[:-1].reshape(-1, seq_len) + y = local[1:].reshape(-1, seq_len) + return x.to(self.device, non_blocking=True), y.to(self.device, non_blocking=True) + +# ----------------------------- +# TRANSFORMER MODULES +# ----------------------------- + +class RMSNorm(nn.Module): + def __init__(self, eps: float | None = None): + super().__init__() + self.eps = eps + + def forward(self, x: Tensor) -> Tensor: + return F.rms_norm(x, (x.size(-1),), eps=self.eps) + + +class CastedLinear(nn.Linear): + # Keep weights in fp32 for optimizer/state quality, cast at matmul time for bf16 compute. + def forward(self, x: Tensor) -> Tensor: + bias = self.bias.to(x.dtype) if self.bias is not None else None + return F.linear(x, self.weight.to(x.dtype), bias) + + +def restore_low_dim_params_to_fp32(module: nn.Module) -> None: + # Keep small/control parameters in fp32 even when the model body runs in bf16. + with torch.no_grad(): + for name, param in module.named_parameters(): + if (param.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS)) and param.dtype != torch.float32: + param.data = param.data.float() + + +class Rotary(nn.Module): + # Caches cos/sin tables per sequence length on the current device. 
+ def __init__(self, dim: int, base: float = 10000.0): + super().__init__() + inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim)) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self._seq_len_cached = 0 + self._cos_cached: Tensor | None = None + self._sin_cached: Tensor | None = None + + def forward(self, seq_len: int, device: torch.device, dtype: torch.dtype) -> tuple[Tensor, Tensor]: + if ( + self._cos_cached is None + or self._sin_cached is None + or self._seq_len_cached != seq_len + or self._cos_cached.device != device + ): + t = torch.arange(seq_len, device=device, dtype=self.inv_freq.dtype) + freqs = torch.outer(t, self.inv_freq.to(device)) + self._cos_cached = freqs.cos()[None, None, :, :] + self._sin_cached = freqs.sin()[None, None, :, :] + self._seq_len_cached = seq_len + return self._cos_cached.to(dtype=dtype), self._sin_cached.to(dtype=dtype) + + +def apply_rotary_emb(x: Tensor, cos: Tensor, sin: Tensor) -> Tensor: + half = x.size(-1) // 2 + x1, x2 = x[..., :half], x[..., half:] + return torch.cat((x1 * cos + x2 * sin, x1 * (-sin) + x2 * cos), dim=-1) + + +class CausalSelfAttention(nn.Module): + def __init__( + self, + dim: int, + num_heads: int, + num_kv_heads: int, + rope_base: float, + qk_gain_init: float, + ): + super().__init__() + if dim % num_heads != 0: + raise ValueError("model_dim must be divisible by num_heads") + if num_heads % num_kv_heads != 0: + raise ValueError("num_heads must be divisible by num_kv_heads") + self.num_heads = num_heads + self.num_kv_heads = num_kv_heads + self.head_dim = dim // num_heads + if self.head_dim % 2 != 0: + raise ValueError("head_dim must be even for RoPE") + kv_dim = self.num_kv_heads * self.head_dim + self.c_q = CastedLinear(dim, dim, bias=False) + self.c_k = CastedLinear(dim, kv_dim, bias=False) + self.c_v = CastedLinear(dim, kv_dim, bias=False) + self.proj = CastedLinear(dim, dim, bias=False) + self.proj._zero_init = True + self.q_gain = 
nn.Parameter(torch.full((num_heads,), qk_gain_init, dtype=torch.float32)) + self.rotary = Rotary(self.head_dim, base=rope_base) + + def forward(self, x: Tensor) -> Tensor: + bsz, seqlen, dim = x.shape + q = self.c_q(x).reshape(bsz, seqlen, self.num_heads, self.head_dim).transpose(1, 2) + k = self.c_k(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) + v = self.c_v(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) + q = F.rms_norm(q, (q.size(-1),)) + k = F.rms_norm(k, (k.size(-1),)) + cos, sin = self.rotary(seqlen, x.device, q.dtype) + q = apply_rotary_emb(q, cos, sin) + k = apply_rotary_emb(k, cos, sin) + q = q * self.q_gain.to(dtype=q.dtype)[None, :, None, None] + if FA3_AVAILABLE: + # FA3 expects (B, T, H, D) — transpose from (B, H, T, D) + y = fa3.flash_attn_func( + q.transpose(1, 2), + k.transpose(1, 2), + v.transpose(1, 2), + causal=True, + ).transpose(1, 2) + else: + y = F.scaled_dot_product_attention( + q, k, v, + attn_mask=None, + is_causal=True, + enable_gqa=(self.num_kv_heads != self.num_heads), + ) + y = y.transpose(1, 2).contiguous().reshape(bsz, seqlen, dim) + return self.proj(y) + + +class MLP(nn.Module): + # leaky_relu^2 MLP (slope=0.5 matches SOTA, PRs #493/#518) + def __init__(self, dim: int, mlp_mult: int, leaky_relu_slope: float = 0.5): + super().__init__() + hidden = mlp_mult * dim + self.fc = CastedLinear(dim, hidden, bias=False) + self.proj = CastedLinear(hidden, dim, bias=False) + self.proj._zero_init = True + self.leaky_relu_slope = leaky_relu_slope + + def forward(self, x: Tensor) -> Tensor: + x = F.leaky_relu(self.fc(x), negative_slope=self.leaky_relu_slope) + return self.proj(x.square()) + + +class TimestepScaling(nn.Module): + """Learned per-iteration scale vectors (arXiv:2410.01405, ICML 2025).""" + def __init__(self, effective_layers: int, dim: int): + super().__init__() + self.attn_gamma = nn.Parameter(torch.ones(effective_layers, dim, dtype=torch.float32)) + self.mlp_gamma = 
class TimestepScaling(nn.Module):
    """Learned per-iteration scale vectors (arXiv:2410.01405, ICML 2025)."""

    def __init__(self, effective_layers: int, dim: int):
        super().__init__()
        # One (dim,)-vector pair per virtual layer position, initialized to identity.
        self.attn_gamma = nn.Parameter(torch.ones(effective_layers, dim, dtype=torch.float32))
        self.mlp_gamma = nn.Parameter(torch.ones(effective_layers, dim, dtype=torch.float32))


class Block(nn.Module):
    """Transformer block: residual re-mix with the embedding stream, attention, then MLP."""

    def __init__(
        self,
        dim: int,
        num_heads: int,
        num_kv_heads: int,
        mlp_mult: int,
        rope_base: float,
        qk_gain_init: float,
        use_peri_norm: bool = False,
        use_birkhoff_mix: bool = False,
        leaky_relu_slope: float = 0.5,
    ):
        super().__init__()
        self.use_peri_norm = use_peri_norm
        self.use_birkhoff_mix = use_birkhoff_mix
        self.attn_norm = RMSNorm()
        if use_peri_norm:
            self.mlp_out_norm = RMSNorm()  # output norm: bounds MLP contribution, preserves input magnitude
        else:
            self.mlp_norm = RMSNorm()
        self.attn = CausalSelfAttention(dim, num_heads, num_kv_heads, rope_base, qk_gain_init)
        self.mlp = MLP(dim, mlp_mult, leaky_relu_slope)
        self.attn_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32))
        self.mlp_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32))
        if use_birkhoff_mix:
            # Birkhoff B₂ parameterization: sigmoid(logit) ∈ (0,1) → doubly stochastic
            self.resid_mix_logit = nn.Parameter(torch.zeros(dim, dtype=torch.float32))
        else:
            self.resid_mix = nn.Parameter(torch.stack((torch.ones(dim), torch.zeros(dim))).float())

    def _mix_residual(self, x: Tensor, x0: Tensor) -> Tensor:
        # Re-inject the (normalized) embedding stream x0 before the sub-layers.
        if self.use_birkhoff_mix:
            alpha = torch.sigmoid(self.resid_mix_logit.to(dtype=x.dtype))[None, None, :]
            return alpha * x + (1 - alpha) * x0
        mix = self.resid_mix.to(dtype=x.dtype)
        return mix[0][None, None, :] * x + mix[1][None, None, :] * x0

    def forward(self, x: Tensor, x0: Tensor,
                ts_attn_gamma: Tensor | None = None,
                ts_mlp_gamma: Tensor | None = None) -> Tensor:
        x = self._mix_residual(x, x0)

        # Attention sub-layer with a learned per-channel output scale, optionally
        # multiplied by the per-virtual-layer timestep gamma.
        attn_out = self.attn(self.attn_norm(x))
        scale_attn = self.attn_scale.to(dtype=x.dtype)[None, None, :]
        if ts_attn_gamma is not None:
            scale_attn = scale_attn * ts_attn_gamma[None, None, :]
        x = x + scale_attn * attn_out

        # MLP sub-layer; peri-norm normalizes the MLP *output* instead of its input.
        if self.use_peri_norm:
            mlp_out = self.mlp_out_norm(self.mlp(x))
        else:
            mlp_out = self.mlp(self.mlp_norm(x))
        scale_mlp = self.mlp_scale.to(dtype=x.dtype)[None, None, :]
        if ts_mlp_gamma is not None:
            scale_mlp = scale_mlp * ts_mlp_gamma[None, None, :]
        return x + scale_mlp * mlp_out
class GPT(nn.Module):
    """GPT-style language model with two depth layouts.

    - Recurrence (num_shared > 0): prelude blocks -> shared blocks repeated
      num_loops times -> coda blocks, optionally with per-virtual-layer
      timestep scaling.
    - Otherwise: a U-Net stack where the first half's outputs are added back
      (with learned per-channel weights) into the second half.

    forward() returns the mean cross-entropy loss over all target tokens.
    """

    def __init__(
        self,
        vocab_size: int,
        num_layers: int,
        model_dim: int,
        num_heads: int,
        num_kv_heads: int,
        mlp_mult: int,
        tie_embeddings: bool,
        tied_embed_init_std: float,
        logit_softcap: float,
        rope_base: float,
        qk_gain_init: float,
        num_prelude: int = 0,
        num_coda: int = 0,
        num_shared: int = 0,
        num_loops: int = 2,
        use_peri_norm: bool = False,
        use_birkhoff_mix: bool = False,
        use_timestep_scale: bool = False,
        leaky_relu_slope: float = 0.5,
    ):
        super().__init__()
        if logit_softcap <= 0.0:
            raise ValueError(f"logit_softcap must be positive, got {logit_softcap}")
        self.tie_embeddings = tie_embeddings
        self.tied_embed_init_std = tied_embed_init_std
        self.logit_softcap = logit_softcap
        self.tok_emb = nn.Embedding(vocab_size, model_dim)
        # Recurrence is switched on purely by num_shared; the loop/prelude/coda
        # settings are neutralized when it is off.
        self.use_recurrence = num_shared > 0
        self.num_loops = num_loops if self.use_recurrence else 1
        self.num_prelude = num_prelude if self.use_recurrence else 0
        self.num_coda = num_coda if self.use_recurrence else 0

        # Peri-norm / Birkhoff mixing only apply in the recurrent layout.
        block_kwargs = dict(
            dim=model_dim, num_heads=num_heads, num_kv_heads=num_kv_heads,
            mlp_mult=mlp_mult, rope_base=rope_base, qk_gain_init=qk_gain_init,
            use_peri_norm=use_peri_norm if self.use_recurrence else False,
            use_birkhoff_mix=use_birkhoff_mix if self.use_recurrence else False,
            leaky_relu_slope=leaky_relu_slope,
        )

        if self.use_recurrence:
            self.prelude_blocks = nn.ModuleList([Block(**block_kwargs) for _ in range(self.num_prelude)])
            self.shared_blocks = nn.ModuleList([Block(**block_kwargs) for _ in range(num_shared)])
            self.coda_blocks = nn.ModuleList([Block(**block_kwargs) for _ in range(self.num_coda)])
            # "Effective" depth counts each shared block once per loop.
            effective_layers = self.num_prelude + num_shared * self.num_loops + self.num_coda
            self.timestep_scale = TimestepScaling(effective_layers, model_dim) if use_timestep_scale else None
        else:
            # Standard U-Net path
            self.num_encoder_layers = num_layers // 2
            self.num_decoder_layers = num_layers - self.num_encoder_layers
            self.num_skip_weights = min(self.num_encoder_layers, self.num_decoder_layers)
            self.skip_weights = nn.Parameter(torch.ones(self.num_skip_weights, model_dim, dtype=torch.float32))
            self.blocks = nn.ModuleList([Block(**block_kwargs) for _ in range(num_layers)])
            self.timestep_scale = None

        self.final_norm = RMSNorm()
        self.lm_head = None if tie_embeddings else CastedLinear(model_dim, vocab_size, bias=False)
        if self.lm_head is not None:
            self.lm_head._zero_init = True
        self._init_weights()

    def _init_weights(self) -> None:
        # Tied embeddings double as the output head, so they get a dedicated std.
        if self.tie_embeddings:
            nn.init.normal_(self.tok_emb.weight, mean=0.0, std=self.tied_embed_init_std)
        # Projections flagged with _zero_init start at zero so each residual
        # branch contributes nothing at init.
        for module in self.modules():
            if isinstance(module, nn.Linear) and getattr(module, "_zero_init", False):
                nn.init.zeros_(module.weight)

    def _get_ts(self, v: int) -> tuple[Tensor | None, Tensor | None]:
        # Timestep-scaling gammas for virtual layer v; (None, None) when disabled.
        if self.timestep_scale is None:
            return None, None
        return (self.timestep_scale.attn_gamma[v],
                self.timestep_scale.mlp_gamma[v])

    def forward(self, input_ids: Tensor, target_ids: Tensor) -> Tensor:
        x = self.tok_emb(input_ids)
        x = F.rms_norm(x, (x.size(-1),))
        x0 = x  # normalized embedding stream, re-mixed into every block

        if self.use_recurrence:
            v = 0  # virtual-layer index into the timestep-scaling tables
            for block in self.prelude_blocks:
                ag, mg = self._get_ts(v)
                x = block(x, x0, ag, mg)
                v += 1
            for _loop in range(self.num_loops):
                # Shared blocks reuse weights across loops but get distinct
                # timestep gammas per visit (v keeps advancing).
                for block in self.shared_blocks:
                    ag, mg = self._get_ts(v)
                    x = block(x, x0, ag, mg)
                    v += 1
            for block in self.coda_blocks:
                ag, mg = self._get_ts(v)
                x = block(x, x0, ag, mg)
                v += 1
        else:
            skips: list[Tensor] = []
            for i in range(self.num_encoder_layers):
                x = self.blocks[i](x, x0)
                skips.append(x)
            for i in range(self.num_decoder_layers):
                # Decoder layer i consumes the deepest remaining encoder output;
                # with an odd layer count, the final decoder layer gets no skip.
                if skips:
                    x = x + self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips.pop()
                x = self.blocks[self.num_encoder_layers + i](x, x0)

        x = self.final_norm(x).reshape(-1, x.size(-1))
        targets = target_ids.reshape(-1)
        if self.tie_embeddings:
            logits_proj = F.linear(x, self.tok_emb.weight)
        else:
            if self.lm_head is None:
                raise RuntimeError("lm_head is required when tie_embeddings=False")
            logits_proj = self.lm_head(x)
        # tanh soft-cap keeps logits inside (-cap, cap) for loss stability.
        logits = self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap)
        return F.cross_entropy(logits.float(), targets, reduction="mean")


@torch.no_grad()
def recurrence_param_diagnostics(gpt: GPT) -> str:
    """Log parameter norms for recurrence health monitoring. No forward pass."""
    if not gpt.use_recurrence:
        return ""
    parts: list[str] = []

    # Birkhoff alpha stats per shared block
    for i, block in enumerate(gpt.shared_blocks):
        if hasattr(block, "resid_mix_logit"):
            a = torch.sigmoid(block.resid_mix_logit)
            parts.append(f"shared{i}_alpha:mean={a.mean().item():.3f},std={a.std().item():.3f}")

    # Effective MLP/attn contribution scale per virtual layer position:
    # mlp_scale * ts_mlp_gamma[v] is the learned multiplier on MLP output.
    # We report RMS (norm / sqrt(numel)) to make it scale-independent.
    v = 0
    # Shared blocks appear once per loop, mirroring the forward() visit order.
    all_blocks = (
        list(gpt.prelude_blocks)
        + list(gpt.shared_blocks) * gpt.num_loops
        + list(gpt.coda_blocks)
    )
    mlp_norms: list[str] = []
    attn_norms: list[str] = []
    for block in all_blocks:
        ms = block.mlp_scale.norm().item()
        asc = block.attn_scale.norm().item()
        d = block.mlp_scale.numel() ** 0.5
        if gpt.timestep_scale is not None:
            ms_g = gpt.timestep_scale.mlp_gamma[v].norm().item()
            as_g = gpt.timestep_scale.attn_gamma[v].norm().item()
            mlp_norms.append(f"v{v}:{ms * ms_g / d:.4f}")
            attn_norms.append(f"v{v}:{asc * as_g / d:.4f}")
        else:
            mlp_norms.append(f"v{v}:{ms / d:.4f}")
            attn_norms.append(f"v{v}:{asc / d:.4f}")
        v += 1
    parts.append("eff_mlp_scale:[" + " ".join(mlp_norms) + "]")
    parts.append("eff_attn_scale:[" + " ".join(attn_norms) + "]")
    return " ".join(parts)
def main() -> None:
    """Entry point: set up distributed training, train, then serialize + validate.

    Phases: distributed/CUDA setup -> tokenizer + validation data -> model and
    optimizer construction -> compile warmup (with state restore) -> timed
    training loop with optional wallclock cap -> int8+zlib serialization and a
    round-trip validation of the quantized weights.
    """
    global zeropower_via_newtonschulz5

    code = Path(__file__).read_text(encoding="utf-8")
    args = Hyperparameters()
    zeropower_via_newtonschulz5 = torch.compile(zeropower_via_newtonschulz5)

    # -----------------------------
    # DISTRIBUTED + CUDA SETUP
    # -----------------------------

    distributed = "RANK" in os.environ and "WORLD_SIZE" in os.environ
    rank = int(os.environ.get("RANK", "0"))
    world_size = int(os.environ.get("WORLD_SIZE", "1"))
    local_rank = int(os.environ.get("LOCAL_RANK", "0"))
    if world_size <= 0:
        raise ValueError(f"WORLD_SIZE must be positive, got {world_size}")
    # The global batch is always split into 8 micro-batches across all ranks.
    if 8 % world_size != 0:
        raise ValueError(f"WORLD_SIZE={world_size} must divide 8 so grad_accum_steps stays integral")
    grad_accum_steps = 8 // world_size
    grad_scale = 1.0 / grad_accum_steps
    if not torch.cuda.is_available():
        raise RuntimeError("CUDA is required")
    device = torch.device("cuda", local_rank)
    torch.cuda.set_device(device)
    if distributed:
        dist.init_process_group(backend="nccl", device_id=device)
        dist.barrier()
    master_process = rank == 0

    # Fast math knobs
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True
    from torch.backends.cuda import enable_cudnn_sdp, enable_flash_sdp, enable_math_sdp, enable_mem_efficient_sdp

    enable_cudnn_sdp(False)
    enable_flash_sdp(True)
    enable_mem_efficient_sdp(False)
    enable_math_sdp(False)

    logfile = None
    if master_process:
        os.makedirs("logs", exist_ok=True)
        logfile = f"logs/{args.run_id}.txt"
        print(logfile)

    def log0(msg: str, console: bool = True) -> None:
        # Rank-0-only logger; console=False entries go to the logfile only.
        if not master_process:
            return
        if console:
            print(msg)
        if logfile is not None:
            with open(logfile, "a", encoding="utf-8") as f:
                print(msg, file=f)

    # Record the full source, interpreter, framework, and GPU state for reproducibility.
    log0(code, console=False)
    log0("=" * 100, console=False)
    log0(f"Running Python {sys.version}", console=False)
    log0(f"Running PyTorch {torch.__version__}", console=False)
    log0(
        subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, check=False).stdout,
        console=False,
    )
    log0("=" * 100, console=False)

    # -----------------------------
    # TOKENIZER + VALIDATION METRIC SETUP
    # -----------------------------

    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)

    if not args.tokenizer_path.endswith(".model"):
        raise ValueError(f"Script only setup for SentencePiece .model file: {args.tokenizer_path}")
    sp = spm.SentencePieceProcessor(model_file=args.tokenizer_path)
    if int(sp.vocab_size()) != args.vocab_size:
        raise ValueError(
            f"VOCAB_SIZE={args.vocab_size} does not match tokenizer vocab_size={int(sp.vocab_size())}"
        )
    dataset_dir = Path(args.data_path).resolve()
    actual_train_files = len(list(dataset_dir.glob("fineweb_train_*.bin")))
    val_tokens = load_validation_tokens(args.val_files, args.train_seq_len)
    # LUTs map token ids to byte counts / boundary flags for bits-per-byte eval.
    base_bytes_lut, has_leading_space_lut, is_boundary_token_lut = build_sentencepiece_luts(
        sp, args.vocab_size, device
    )
    log0(f"val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path={args.tokenizer_path}")
    log0(f"train_loader:dataset:{dataset_dir.name} train_shards:{actual_train_files}")
    log0(f"val_loader:shards pattern={args.val_files} tokens:{val_tokens.numel() - 1}")

    # -----------------------------
    # MODEL + OPTIMIZER SETUP
    # -----------------------------

    base_model = GPT(
        vocab_size=args.vocab_size,
        num_layers=args.num_layers,
        model_dim=args.model_dim,
        num_heads=args.num_heads,
        num_kv_heads=args.num_kv_heads,
        mlp_mult=args.mlp_mult,
        tie_embeddings=args.tie_embeddings,
        tied_embed_init_std=args.tied_embed_init_std,
        logit_softcap=args.logit_softcap,
        rope_base=args.rope_base,
        qk_gain_init=args.qk_gain_init,
        num_prelude=args.num_prelude,
        num_coda=args.num_coda,
        num_shared=args.num_shared,
        num_loops=args.num_loops,
        use_peri_norm=args.use_peri_norm,
        use_birkhoff_mix=args.use_birkhoff_mix,
        use_timestep_scale=args.use_timestep_scale,
        leaky_relu_slope=args.leaky_relu_slope,
    ).to(device).bfloat16()
    # Model lives in bf16, but CastedLinear weights and small (low-dim) params
    # are kept in fp32 for optimizer precision.
    for module in base_model.modules():
        if isinstance(module, CastedLinear):
            module.float()
    restore_low_dim_params_to_fp32(base_model)
    if args.disable_compile:
        compiled_model = base_model
    else:
        compiled_model = torch.compile(base_model, dynamic=False, fullgraph=True)
    model: nn.Module = DDP(compiled_model, device_ids=[local_rank], broadcast_buffers=False) if distributed else compiled_model

    # Optimizer split:
    # - token embedding (Adam) uses EMBED_LR
    # - untied lm_head (Adam) uses HEAD_LR
    # - matrix params in transformer blocks use MATRIX_LR via Muon
    # - vectors/scalars use SCALAR_LR via Adam
    if base_model.use_recurrence:
        all_block_lists = [base_model.prelude_blocks, base_model.shared_blocks, base_model.coda_blocks]
        block_named_params = []
        for bl in all_block_lists:
            block_named_params.extend(bl.named_parameters())
    else:
        block_named_params = list(base_model.blocks.named_parameters())
    # 2-D params go to Muon unless their name marks them as control tensors.
    matrix_params = [
        p
        for name, p in block_named_params
        if p.ndim == 2 and not any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS)
    ]
    scalar_params = [
        p
        for name, p in block_named_params
        if p.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS)
    ]
    # Skip weights / timestep gammas live outside the blocks; fold them into the Adam group.
    if hasattr(base_model, "skip_weights") and base_model.skip_weights.numel() > 0:
        scalar_params.append(base_model.skip_weights)
    if base_model.timestep_scale is not None:
        for p in base_model.timestep_scale.parameters():
            scalar_params.append(p)
    token_lr = args.tied_embed_lr if args.tie_embeddings else args.embed_lr
    # "base_lr" is stored per group so the LR schedule can rescale in place.
    optimizer_tok = torch.optim.Adam(
        [{"params": [base_model.tok_emb.weight], "lr": token_lr, "base_lr": token_lr}],
        betas=(args.beta1, args.beta2),
        eps=args.adam_eps,
        fused=True,
    )
    optimizer_muon = Muon(
        matrix_params,
        lr=args.matrix_lr,
        momentum=args.muon_momentum,
        backend_steps=args.muon_backend_steps,
    )
    for group in optimizer_muon.param_groups:
        group["base_lr"] = args.matrix_lr
    optimizer_scalar = torch.optim.Adam(
        [{"params": scalar_params, "lr": args.scalar_lr, "base_lr": args.scalar_lr}],
        betas=(args.beta1, args.beta2),
        eps=args.adam_eps,
        fused=True,
    )
    optimizers: list[torch.optim.Optimizer] = [optimizer_tok, optimizer_muon, optimizer_scalar]
    if base_model.lm_head is not None:
        optimizer_head = torch.optim.Adam(
            [{"params": [base_model.lm_head.weight], "lr": args.head_lr, "base_lr": args.head_lr}],
            betas=(args.beta1, args.beta2),
            eps=args.adam_eps,
            fused=True,
        )
        optimizers.insert(1, optimizer_head)

    # Log run configuration for the record file.
    n_params = sum(p.numel() for p in base_model.parameters())
    log0(f"model_params:{n_params}")
    log0(f"world_size:{world_size} grad_accum_steps:{grad_accum_steps}")
    log0("sdp_backends:cudnn=False flash=True mem_efficient=False math=False")
    log0(f"flash_attention_3:{FA3_AVAILABLE}")
    log0(f"attention_mode:gqa num_heads:{args.num_heads} num_kv_heads:{args.num_kv_heads}")
    log0(
        f"tie_embeddings:{args.tie_embeddings} embed_lr:{token_lr} "
        f"head_lr:{args.head_lr if base_model.lm_head is not None else 0.0} "
        f"matrix_lr:{args.matrix_lr} scalar_lr:{args.scalar_lr}"
    )
    log0(
        f"train_batch_tokens:{args.train_batch_tokens} train_seq_len:{args.train_seq_len} "
        f"iterations:{args.iterations} warmup_steps:{args.warmup_steps} "
        f"max_wallclock_seconds:{args.max_wallclock_seconds:.3f}"
    )
    log0(f"seed:{args.seed}")
    if base_model.use_recurrence:
        eff = base_model.num_prelude + len(base_model.shared_blocks) * base_model.num_loops + base_model.num_coda
        log0(f"recurrence:enabled prelude:{base_model.num_prelude} shared:{len(base_model.shared_blocks)} "
             f"loops:{base_model.num_loops} coda:{base_model.num_coda} effective_layers:{eff}")
        log0(f"peri_norm:{args.use_peri_norm} birkhoff_mix:{args.use_birkhoff_mix}")
        if base_model.timestep_scale is not None:
            ts_params = sum(p.numel() for p in base_model.timestep_scale.parameters())
            log0(f"timestep_scale:enabled params:{ts_params}")
        else:
            log0("timestep_scale:disabled")
    else:
        log0(f"recurrence:disabled num_layers:{args.num_layers}")
    compile_mode = "disabled" if args.disable_compile else "fullgraph=True"
    log0(f"compile_mode:{compile_mode}")

    # -----------------------------
    # DATA LOADER & MODEL WARMUP
    # -----------------------------

    train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device)

    def zero_grad_all() -> None:
        for opt in optimizers:
            opt.zero_grad(set_to_none=True)

    max_wallclock_ms = 1000.0 * args.max_wallclock_seconds if args.max_wallclock_seconds > 0 else None

    def lr_mul(step: int, elapsed_ms: float) -> float:
        # LR multiplier: 1.0 during stable training, linearly decaying to 0
        # over the warmdown window. With a wallclock cap the warmdown window is
        # estimated from the measured per-step time instead of the step count.
        if args.warmdown_iters <= 0:
            return 1.0
        if max_wallclock_ms is None:
            warmdown_start = max(args.iterations - args.warmdown_iters, 0)
            return max((args.iterations - step) / max(args.warmdown_iters, 1), 0.0) if warmdown_start <= step < args.iterations else 1.0
        step_ms = elapsed_ms / max(step, 1)
        warmdown_ms = args.warmdown_iters * step_ms
        remaining_ms = max(max_wallclock_ms - elapsed_ms, 0.0)
        return remaining_ms / max(warmdown_ms, 1e-9) if remaining_ms <= warmdown_ms else 1.0

    # Warmup primes the compiled forward/backward/optimizer paths, then we restore the
    # initial weights/optimizer state so measured training starts from the true init.
    if args.warmup_steps > 0:
        initial_model_state = {name: tensor.detach().cpu().clone() for name, tensor in base_model.state_dict().items()}
        initial_optimizer_states = [copy.deepcopy(opt.state_dict()) for opt in optimizers]
        model.train()
        for warmup_step in range(args.warmup_steps):
            zero_grad_all()
            for micro_step in range(grad_accum_steps):
                if distributed:
                    # Only sync gradients on the final micro-step of each accumulation.
                    model.require_backward_grad_sync = micro_step == grad_accum_steps - 1
                x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps)
                with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True):
                    warmup_loss = model(x, y)
                (warmup_loss * grad_scale).backward()
            for opt in optimizers:
                opt.step()
            zero_grad_all()
            if args.warmup_steps <= 20 or (warmup_step + 1) % 10 == 0 or warmup_step + 1 == args.warmup_steps:
                log0(f"warmup_step:{warmup_step + 1}/{args.warmup_steps}")
        base_model.load_state_dict(initial_model_state, strict=True)
        for opt, state in zip(optimizers, initial_optimizer_states, strict=True):
            opt.load_state_dict(state)
        zero_grad_all()
        if distributed:
            model.require_backward_grad_sync = True
        # Fresh loader so measured training consumes the data stream from the start.
        train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device)

    # -----------------------------
    # MAIN TRAINING LOOP
    # -----------------------------

    training_time_ms = 0.0
    stop_after_step: int | None = None
    torch.cuda.synchronize()
    t0 = time.perf_counter()

    step = 0
    while True:
        last_step = step == args.iterations or (stop_after_step is not None and step >= stop_after_step)

        should_validate = last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0)
        if should_validate:
            # Pause the training clock while evaluating.
            torch.cuda.synchronize()
            training_time_ms += 1000.0 * (time.perf_counter() - t0)
            val_loss, val_bpb = eval_val(
                args,
                model,
                rank,
                world_size,
                device,
                grad_accum_steps,
                val_tokens,
                base_bytes_lut,
                has_leading_space_lut,
                is_boundary_token_lut,
            )
            log0(
                f"step:{step}/{args.iterations} val_loss:{val_loss:.4f} val_bpb:{val_bpb:.4f} "
                f"train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms / max(step, 1):.2f}ms"
            )
            torch.cuda.synchronize()
            t0 = time.perf_counter()

        if last_step:
            if stop_after_step is not None and step < args.iterations:
                log0(
                    f"stopping_early: wallclock_cap train_time:{training_time_ms:.0f}ms "
                    f"step:{step}/{args.iterations}"
                )
            break

        elapsed_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0)
        scale = lr_mul(step, elapsed_ms)
        zero_grad_all()
        train_loss = torch.zeros((), device=device)
        for micro_step in range(grad_accum_steps):
            if distributed:
                model.require_backward_grad_sync = micro_step == grad_accum_steps - 1
            x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps)
            with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True):
                loss = model(x, y)
            train_loss += loss.detach()
            (loss * grad_scale).backward()
        train_loss /= grad_accum_steps

        # Linearly ramp Muon momentum from its warmup-start value to its final value.
        frac = min(step / args.muon_momentum_warmup_steps, 1.0) if args.muon_momentum_warmup_steps > 0 else 1.0
        muon_momentum = (1 - frac) * args.muon_momentum_warmup_start + frac * args.muon_momentum
        for group in optimizer_muon.param_groups:
            group["momentum"] = muon_momentum

        for opt in optimizers:
            for group in opt.param_groups:
                group["lr"] = group["base_lr"] * scale

        if args.grad_clip_norm > 0:
            torch.nn.utils.clip_grad_norm_(base_model.parameters(), args.grad_clip_norm)
        for opt in optimizers:
            opt.step()
        zero_grad_all()

        step += 1
        approx_training_time_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0)
        should_log_train = (
            args.train_log_every > 0
            and (step <= 10 or step % args.train_log_every == 0 or stop_after_step is not None)
        )
        if should_log_train:
            log0(
                f"step:{step}/{args.iterations} train_loss:{train_loss.item():.4f} "
                f"train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms / step:.2f}ms"
            )
            if base_model.use_recurrence and step % args.train_log_every == 0:
                diag = recurrence_param_diagnostics(base_model)
                if diag:
                    log0(f"step:{step} {diag}")

        # Needed to sync whether we've reached the wallclock cap.
        reached_cap = max_wallclock_ms is not None and approx_training_time_ms >= max_wallclock_ms
        if distributed and max_wallclock_ms is not None:
            reached_cap_tensor = torch.tensor(int(reached_cap), device=device)
            dist.all_reduce(reached_cap_tensor, op=dist.ReduceOp.MAX)
            reached_cap = bool(reached_cap_tensor.item())
        if stop_after_step is None and reached_cap:
            stop_after_step = step

    log0(
        f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB "
        f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB"
    )

    # -----------------------------
    # SERIALIZATION + ROUNDTRIP VALIDATION
    # -----------------------------
    # Save the raw state (useful for debugging/loading in PyTorch directly), then always produce
    # the compressed int8+zlib artifact and validate the round-tripped weights.

    if master_process:
        torch.save(base_model.state_dict(), "final_model.pt")
        model_bytes = os.path.getsize("final_model.pt")
        code_bytes = len(code.encode("utf-8"))
        log0(f"Serialized model: {model_bytes} bytes")
        log0(f"Code size: {code_bytes} bytes")
        log0(f"Total submission size: {model_bytes + code_bytes} bytes")

    # Quantization runs on every rank; only rank 0 writes the artifact.
    quant_obj, quant_stats = quantize_state_dict_int8(base_model.state_dict())
    quant_buf = io.BytesIO()
    torch.save(quant_obj, quant_buf)
    quant_raw = quant_buf.getvalue()
    quant_blob = zlib.compress(quant_raw, level=9)
    quant_raw_bytes = len(quant_raw)
    if master_process:
        with open("final_model.int8.ptz", "wb") as f:
            f.write(quant_blob)
        quant_file_bytes = os.path.getsize("final_model.int8.ptz")
        code_bytes = len(code.encode("utf-8"))
        ratio = quant_stats["baseline_tensor_bytes"] / max(quant_stats["int8_payload_bytes"], 1)
        log0(
            f"Serialized model int8+zlib: {quant_file_bytes} bytes "
            f"(payload:{quant_stats['int8_payload_bytes']} raw_torch:{quant_raw_bytes} payload_ratio:{ratio:.2f}x)"
        )
        log0(f"Total submission size int8+zlib: {quant_file_bytes + code_bytes} bytes")

    # All ranks reload the quantized weights from disk (barrier ensures the file
    # exists) and re-evaluate, so the logged final metrics reflect what ships.
    if distributed:
        dist.barrier()
    with open("final_model.int8.ptz", "rb") as f:
        quant_blob_disk = f.read()
    quant_state = torch.load(io.BytesIO(zlib.decompress(quant_blob_disk)), map_location="cpu")
    base_model.load_state_dict(dequantize_state_dict_int8(quant_state), strict=True)
    torch.cuda.synchronize()
    t_qeval = time.perf_counter()
    q_val_loss, q_val_bpb = eval_val(
        args,
        model,
        rank,
        world_size,
        device,
        grad_accum_steps,
        val_tokens,
        base_bytes_lut,
        has_leading_space_lut,
        is_boundary_token_lut,
    )
    torch.cuda.synchronize()
    log0(
        f"final_int8_zlib_roundtrip val_loss:{q_val_loss:.4f} val_bpb:{q_val_bpb:.4f} "
        f"eval_time:{1000.0 * (time.perf_counter() - t_qeval):.0f}ms"
    )
    log0(f"final_int8_zlib_roundtrip_exact val_loss:{q_val_loss:.8f} val_bpb:{q_val_bpb:.8f}")

    if distributed:
        dist.destroy_process_group()
+ + +if __name__ == "__main__": + main() + +==================================================================================================== +Running Python 3.12.3 (main, Nov 6 2025, 13:44:16) [GCC 13.3.0] +Running PyTorch 2.9.1+cu128 +Wed Mar 25 13:27:43 2026 ++-----------------------------------------------------------------------------------------+ +| NVIDIA-SMI 565.57.01 Driver Version: 565.57.01 CUDA Version: 12.7 | +|-----------------------------------------+------------------------+----------------------+ +| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC | +| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. | +| | | MIG M. | +|=========================================+========================+======================| +| 0 NVIDIA H100 80GB HBM3 On | 00000000:BE:00.0 Off | 0 | +| N/A 43C P0 114W / 700W | 1184MiB / 81559MiB | 9% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ + ++-----------------------------------------------------------------------------------------+ +| Processes: | +| GPU GI CI PID Type Process name GPU Memory | +| ID ID Usage | +|=========================================================================================| ++-----------------------------------------------------------------------------------------+ + +==================================================================================================== +val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path=./data/tokenizers/fineweb_1024_bpe.model +train_loader:dataset:fineweb10B_sp1024 train_shards:10 +val_loader:shards pattern=./data/datasets/fineweb10B_sp1024/fineweb_val_*.bin tokens:62021632 +model_params:7872544 +world_size:1 grad_accum_steps:8 +sdp_backends:cudnn=False flash=True mem_efficient=False math=False +flash_attention_3:True +attention_mode:gqa num_heads:8 num_kv_heads:4 +tie_embeddings:True embed_lr:0.05 head_lr:0.0 matrix_lr:0.04 scalar_lr:0.04 
+train_batch_tokens:524288 train_seq_len:1024 iterations:2000 warmup_steps:20 max_wallclock_seconds:0.000 +seed:1337 +recurrence:enabled prelude:0 shared:4 loops:2 coda:0 effective_layers:8 +peri_norm:False birkhoff_mix:False +timestep_scale:disabled +compile_mode:fullgraph=True +warmup_step:1/20 +warmup_step:2/20 +warmup_step:3/20 +warmup_step:4/20 +warmup_step:5/20 +warmup_step:6/20 +warmup_step:7/20 +warmup_step:8/20 +warmup_step:9/20 +warmup_step:10/20 +warmup_step:11/20 +warmup_step:12/20 +warmup_step:13/20 +warmup_step:14/20 +warmup_step:15/20 +warmup_step:16/20 +warmup_step:17/20 +warmup_step:18/20 +warmup_step:19/20 +warmup_step:20/20 +step:0/2000 val_loss:6.9365 val_bpb:4.1082 train_time:0ms step_avg:0.02ms +step:1/2000 train_loss:6.9365 train_time:290ms step_avg:289.55ms +step:2/2000 train_loss:15.3054 train_time:671ms step_avg:335.38ms +step:3/2000 train_loss:7.4873 train_time:1001ms step_avg:333.59ms +step:4/2000 train_loss:7.0969 train_time:1369ms step_avg:342.33ms +step:5/2000 train_loss:7.3936 train_time:1701ms step_avg:340.15ms +step:6/2000 train_loss:7.1367 train_time:2065ms step_avg:344.18ms +step:7/2000 train_loss:6.5868 train_time:2417ms step_avg:345.22ms +step:8/2000 train_loss:6.2915 train_time:2778ms step_avg:347.27ms +step:9/2000 train_loss:6.1044 train_time:3177ms step_avg:352.99ms +step:10/2000 train_loss:5.9683 train_time:3575ms step_avg:357.51ms +step:200/2000 train_loss:2.7934 train_time:72774ms step_avg:363.87ms +step:200 eff_mlp_scale:[v0:1.3749 v1:1.4795 v2:1.5208 v3:1.4656 v4:1.3749 v5:1.4795 v6:1.5208 v7:1.4656] eff_attn_scale:[v0:0.6971 v1:1.6206 v2:1.7776 v3:2.0465 v4:0.6971 v5:1.6206 v6:1.7776 v7:2.0465] +step:400/2000 train_loss:2.4532 train_time:146554ms step_avg:366.39ms +step:400 eff_mlp_scale:[v0:1.1923 v1:1.3538 v2:1.4721 v3:1.5603 v4:1.1923 v5:1.3538 v6:1.4721 v7:1.5603] eff_attn_scale:[v0:0.3291 v1:1.7381 v2:2.0542 v3:2.3685 v4:0.3291 v5:1.7381 v6:2.0542 v7:2.3685] +step:500/2000 val_loss:2.5658 val_bpb:1.5196 
train_time:182468ms step_avg:364.94ms +step:600/2000 train_loss:2.5780 train_time:217821ms step_avg:363.03ms +step:600 eff_mlp_scale:[v0:1.1070 v1:1.2131 v2:1.3211 v3:1.4886 v4:1.1070 v5:1.2131 v6:1.3211 v7:1.4886] eff_attn_scale:[v0:0.1698 v1:1.9605 v2:2.3374 v3:2.6891 v4:0.1698 v5:1.9605 v6:2.3374 v7:2.6891] +step:800/2000 train_loss:2.4384 train_time:290721ms step_avg:363.40ms +step:800 eff_mlp_scale:[v0:1.0569 v1:1.1042 v2:1.1859 v3:1.4216 v4:1.0569 v5:1.1042 v6:1.1859 v7:1.4216] eff_attn_scale:[v0:0.1226 v1:2.1585 v2:2.5694 v3:2.9469 v4:0.1226 v5:2.1585 v6:2.5694 v7:2.9469] +step:1000/2000 train_loss:2.4488 train_time:363607ms step_avg:363.61ms +step:1000 eff_mlp_scale:[v0:1.0251 v1:1.0387 v2:1.0883 v3:1.3961 v4:1.0251 v5:1.0387 v6:1.0883 v7:1.3961] eff_attn_scale:[v0:0.1115 v1:2.3099 v2:2.7290 v3:3.1397 v4:0.1115 v5:2.3099 v6:2.7290 v7:3.1397] +step:1000/2000 val_loss:2.4201 val_bpb:1.4333 train_time:363616ms step_avg:363.62ms +step:1200/2000 train_loss:2.3727 train_time:436817ms step_avg:364.01ms +step:1200 eff_mlp_scale:[v0:1.0161 v1:1.0003 v2:1.0270 v3:1.4071 v4:1.0161 v5:1.0003 v6:1.0270 v7:1.4071] eff_attn_scale:[v0:0.1094 v1:2.4014 v2:2.8088 v3:3.2696 v4:0.1094 v5:2.4014 v6:2.8088 v7:3.2696] +step:1400/2000 train_loss:2.4037 train_time:509735ms step_avg:364.10ms +step:1400 eff_mlp_scale:[v0:1.0114 v1:0.9864 v2:0.9882 v3:1.4488 v4:1.0114 v5:0.9864 v6:0.9882 v7:1.4488] eff_attn_scale:[v0:0.1098 v1:2.4445 v2:2.8440 v3:3.3495 v4:0.1098 v5:2.4445 v6:2.8440 v7:3.3495] +step:1500/2000 val_loss:2.3448 val_bpb:1.3887 train_time:545811ms step_avg:363.87ms +step:1600/2000 train_loss:2.2905 train_time:581639ms step_avg:363.52ms +step:1600 eff_mlp_scale:[v0:1.0107 v1:0.9799 v2:0.9673 v3:1.4912 v4:1.0107 v5:0.9799 v6:0.9673 v7:1.4912] eff_attn_scale:[v0:0.1118 v1:2.4638 v2:2.8518 v3:3.3944 v4:0.1118 v5:2.4638 v6:2.8518 v7:3.3944] +step:1800/2000 train_loss:2.3306 train_time:652871ms step_avg:362.71ms +step:1800 eff_mlp_scale:[v0:1.0141 v1:0.9842 v2:0.9611 v3:1.5221 
v4:1.0141 v5:0.9842 v6:0.9611 v7:1.5221] eff_attn_scale:[v0:0.1126 v1:2.4638 v2:2.8444 v3:3.4095 v4:0.1126 v5:2.4638 v6:2.8444 v7:3.4095] +step:2000/2000 train_loss:2.3453 train_time:726725ms step_avg:363.36ms +step:2000 eff_mlp_scale:[v0:1.0167 v1:0.9887 v2:0.9613 v3:1.5336 v4:1.0167 v5:0.9887 v6:0.9613 v7:1.5336] eff_attn_scale:[v0:0.1122 v1:2.4573 v2:2.8348 v3:3.4125 v4:0.1122 v5:2.4573 v6:2.8348 v7:3.4125] +step:2000/2000 val_loss:2.2986 val_bpb:1.3613 train_time:726733ms step_avg:363.37ms +peak memory allocated: 8392 MiB reserved: 8768 MiB +Serialized model: 30457828 bytes +Code size: 56522 bytes +Total submission size: 30514350 bytes +Serialized model int8+zlib: 6889862 bytes (payload:7923840 raw_torch:7944285 payload_ratio:3.84x) +Total submission size int8+zlib: 6946384 bytes +final_int8_zlib_roundtrip val_loss:2.3025 val_bpb:1.3637 eval_time:8771ms +final_int8_zlib_roundtrip_exact val_loss:2.30249951 val_bpb:1.36366990 diff --git a/records/track_non_record_16mb/2026-03-26_RecurrenceFix_3Loop_Birkhoff_OutputLN_TimestepScale/logs/s1_C.txt b/records/track_non_record_16mb/2026-03-26_RecurrenceFix_3Loop_Birkhoff_OutputLN_TimestepScale/logs/s1_C.txt new file mode 100644 index 0000000000..64fab506ad --- /dev/null +++ b/records/track_non_record_16mb/2026-03-26_RecurrenceFix_3Loop_Birkhoff_OutputLN_TimestepScale/logs/s1_C.txt @@ -0,0 +1,1412 @@ +""" +The `train_gpt.py` and `train_gpt_mlx.py` scripts are intended as good launching-off points for new participants, not SOTA configs. We'll accept PRs that tune, improve, or simplify these scripts without significantly increasing complexity, but competitive submissions should stay in the `/records` folder. + +Hard stop: To keep readable for newcomers, let's make sure `train_gpt.py` and `train_gpt_mlx.py` never are longer than 1500 lines. 
+""" + +from __future__ import annotations + +import copy +import glob +import io +import math +import os +import random +import subprocess +import sys +import time +import uuid +import zlib +from pathlib import Path + +import numpy as np +import sentencepiece as spm +import torch +import torch.distributed as dist +import torch.nn.functional as F +from torch import Tensor, nn +from torch.nn.parallel import DistributedDataParallel as DDP + +# Flash Attention 3 via HuggingFace kernels (Hopper-optimized) +try: + from kernels import get_kernel + cap = torch.cuda.get_device_capability() + repo = "varunneal/flash-attention-3" if cap == (9, 0) else "kernels-community/flash-attn3" + fa3 = get_kernel(repo).flash_attn_interface + FA3_AVAILABLE = True +except Exception: + FA3_AVAILABLE = False + +# ----------------------------- +# HYPERPARAMETERS +# ----------------------------- +# Default Simple Baseline run: +# - 9 transformer blocks at width 512 +# - 8 attention heads with 4 KV heads (GQA) and 2x MLP expansion +# - vocab size 1024, sequence length 1024, tied embeddings +# - 524,288 train tokens per step for 20,000 iterations with a ~10 minute cap + +class Hyperparameters: + # Data paths are shard globs produced by the existing preprocessing pipeline. + data_path = os.environ.get("DATA_PATH", "./data/datasets/fineweb10B_sp1024") + train_files = os.path.join(data_path, "fineweb_train_*.bin") + val_files = os.path.join(data_path, "fineweb_val_*.bin") + tokenizer_path = os.environ.get("TOKENIZER_PATH", "./data/tokenizers/fineweb_1024_bpe.model") + run_id = os.environ.get("RUN_ID", str(uuid.uuid4())) + seed = int(os.environ.get("SEED", 1337)) + + # Validation cadence and batch size. Validation always uses the full fineweb_val split. + val_batch_size = int(os.environ.get("VAL_BATCH_SIZE", 524_288)) + val_loss_every = int(os.environ.get("VAL_LOSS_EVERY", 1000)) + train_log_every = int(os.environ.get("TRAIN_LOG_EVERY", 200)) + + # Training length. 
+ iterations = int(os.environ.get("ITERATIONS", 20000)) + warmdown_iters = int(os.environ.get("WARMDOWN_ITERS", 1200)) + warmup_steps = int(os.environ.get("WARMUP_STEPS", 20)) + train_batch_tokens = int(os.environ.get("TRAIN_BATCH_TOKENS", 524_288)) + train_seq_len = int(os.environ.get("TRAIN_SEQ_LEN", 1024)) + max_wallclock_seconds = float(os.environ.get("MAX_WALLCLOCK_SECONDS", 600.0)) + qk_gain_init = float(os.environ.get("QK_GAIN_INIT", 1.5)) + + # Model shape. + vocab_size = int(os.environ.get("VOCAB_SIZE", 1024)) + num_layers = int(os.environ.get("NUM_LAYERS", 9)) + num_kv_heads = int(os.environ.get("NUM_KV_HEADS", 4)) + model_dim = int(os.environ.get("MODEL_DIM", 512)) + num_heads = int(os.environ.get("NUM_HEADS", 8)) + mlp_mult = int(os.environ.get("MLP_MULT", 2)) + tie_embeddings = bool(int(os.environ.get("TIE_EMBEDDINGS", "1"))) + rope_base = float(os.environ.get("ROPE_BASE", 10000.0)) + logit_softcap = float(os.environ.get("LOGIT_SOFTCAP", 30.0)) + + # Recurrence + stability features. + num_prelude = int(os.environ.get("NUM_PRELUDE", 0)) + num_coda = int(os.environ.get("NUM_CODA", 0)) + num_shared = int(os.environ.get("NUM_SHARED", 0)) + num_loops = int(os.environ.get("NUM_LOOPS", 2)) + use_peri_norm = bool(int(os.environ.get("USE_PERI_NORM", "0"))) + use_birkhoff_mix = bool(int(os.environ.get("USE_BIRKHOFF_MIX", "0"))) + use_timestep_scale = bool(int(os.environ.get("USE_TIMESTEP_SCALE", "0"))) + leaky_relu_slope = float(os.environ.get("LEAKY_RELU_SLOPE", "0.5")) + disable_compile = bool(int(os.environ.get("DISABLE_COMPILE", "0"))) + + # Optimizer hyperparameters. 
+ embed_lr = float(os.environ.get("EMBED_LR", 0.6)) + head_lr = float(os.environ.get("HEAD_LR", 0.008)) + tied_embed_lr = float(os.environ.get("TIED_EMBED_LR", 0.05)) + tied_embed_init_std = float(os.environ.get("TIED_EMBED_INIT_STD", 0.005)) + matrix_lr = float(os.environ.get("MATRIX_LR", 0.04)) + scalar_lr = float(os.environ.get("SCALAR_LR", 0.04)) + muon_momentum = float(os.environ.get("MUON_MOMENTUM", 0.95)) + muon_backend_steps = int(os.environ.get("MUON_BACKEND_STEPS", 5)) + muon_momentum_warmup_start = float(os.environ.get("MUON_MOMENTUM_WARMUP_START", 0.85)) + muon_momentum_warmup_steps = int(os.environ.get("MUON_MOMENTUM_WARMUP_STEPS", 500)) + beta1 = float(os.environ.get("BETA1", 0.9)) + beta2 = float(os.environ.get("BETA2", 0.95)) + adam_eps = float(os.environ.get("ADAM_EPS", 1e-8)) + grad_clip_norm = float(os.environ.get("GRAD_CLIP_NORM", 0.0)) + +# ----------------------------- +# MUON OPTIMIZER +# ----------------------------- +# +# As borrowed from modded-nanogpt +# Background on Muon: https://kellerjordan.github.io/posts/muon/ + +def zeropower_via_newtonschulz5(G: Tensor, steps: int = 10, eps: float = 1e-7) -> Tensor: + # Orthogonalize a 2D update matrix with a fast Newton-Schulz iteration. + # Muon uses this to normalize matrix-shaped gradients before applying them. 
+ a, b, c = (3.4445, -4.7750, 2.0315) + X = G.bfloat16() + X /= X.norm() + eps + transposed = G.size(0) > G.size(1) + if transposed: + X = X.T + for _ in range(steps): + A = X @ X.T + B = b * A + c * A @ A + X = a * X + B @ X + return X.T if transposed else X + + +class Muon(torch.optim.Optimizer): + def __init__(self, params, lr: float, momentum: float, backend_steps: int, nesterov: bool = True): + super().__init__( + params, + dict(lr=lr, momentum=momentum, backend_steps=backend_steps, nesterov=nesterov), + ) + + @torch.no_grad() + def step(self, closure=None): + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + distributed = dist.is_available() and dist.is_initialized() + world_size = dist.get_world_size() if distributed else 1 + rank = dist.get_rank() if distributed else 0 + + for group in self.param_groups: + params = group["params"] + if not params: + continue + lr = group["lr"] + momentum = group["momentum"] + backend_steps = group["backend_steps"] + nesterov = group["nesterov"] + + total_params = sum(int(p.numel()) for p in params) + updates_flat = torch.zeros(total_params, device=params[0].device, dtype=torch.bfloat16) + + curr = 0 + for i, p in enumerate(params): + if i % world_size == rank and p.grad is not None: + g = p.grad + state = self.state[p] + if "momentum_buffer" not in state: + state["momentum_buffer"] = torch.zeros_like(g) + buf = state["momentum_buffer"] + buf.mul_(momentum).add_(g) + if nesterov: + g = g.add(buf, alpha=momentum) + g = zeropower_via_newtonschulz5(g, steps=backend_steps) + # Scale correction from Muon reference implementations. 
+ g *= max(1, g.size(0) / g.size(1)) ** 0.5 + updates_flat[curr : curr + p.numel()] = g.reshape(-1) + curr += p.numel() + + if distributed: + dist.all_reduce(updates_flat, op=dist.ReduceOp.SUM) + + curr = 0 + for p in params: + g = updates_flat[curr : curr + p.numel()].view_as(p).to(dtype=p.dtype) + p.add_(g, alpha=-lr) + curr += p.numel() + + return loss + + +# ----------------------------- +# TOKENIZER-AGNOSTIC EVALUATION SETUP +# ----------------------------- +# +# It's common for small models have a large fraction of their parameters be embeddings, since the 2 * d_model * d_vocab vectors can be gigantic. +# Instead of locking the tokenizer, we let you bring your own and calculate our validation metrics on the average compression of the validation set. +# We calculate BPB (bits-per-byte) instead of validation loss, so we need methods to count the number of bits per token in the tokenizer. +# Note: Submissions that edit the tokenizer will be examined more carefully, since screwing this up might unjustly improve your score. 
+ +def build_sentencepiece_luts( + sp: spm.SentencePieceProcessor, vocab_size: int, device: torch.device +) -> tuple[Tensor, Tensor, Tensor]: + sp_vocab_size = int(sp.vocab_size()) + table_size = max(sp_vocab_size, vocab_size) + base_bytes_np = np.zeros((table_size,), dtype=np.int16) + has_leading_space_np = np.zeros((table_size,), dtype=np.bool_) + is_boundary_token_np = np.ones((table_size,), dtype=np.bool_) + for token_id in range(sp_vocab_size): + if sp.is_control(token_id) or sp.is_unknown(token_id) or sp.is_unused(token_id): + continue + is_boundary_token_np[token_id] = False + if sp.is_byte(token_id): + base_bytes_np[token_id] = 1 + continue + piece = sp.id_to_piece(token_id) + if piece.startswith("▁"): + has_leading_space_np[token_id] = True + piece = piece[1:] + base_bytes_np[token_id] = len(piece.encode("utf-8")) + return ( + torch.tensor(base_bytes_np, dtype=torch.int16, device=device), + torch.tensor(has_leading_space_np, dtype=torch.bool, device=device), + torch.tensor(is_boundary_token_np, dtype=torch.bool, device=device), + ) + + +def load_validation_tokens(pattern: str, seq_len: int) -> Tensor: + files = [Path(p) for p in sorted(glob.glob(pattern))] + if not files: + raise FileNotFoundError(f"No files found for pattern: {pattern}") + # The export pipeline writes the fixed first-50k-doc validation set to fineweb_val_*. 
+ tokens = torch.cat([load_data_shard(file) for file in files]).contiguous() + usable = ((tokens.numel() - 1) // seq_len) * seq_len + if usable <= 0: + raise ValueError(f"Validation split is too short for TRAIN_SEQ_LEN={seq_len}") + return tokens[: usable + 1] + + +def eval_val( + args: Hyperparameters, + model: nn.Module, + rank: int, + world_size: int, + device: torch.device, + grad_accum_steps: int, + val_tokens: Tensor, + base_bytes_lut: Tensor, + has_leading_space_lut: Tensor, + is_boundary_token_lut: Tensor, +) -> tuple[float, float]: + # Validation computes two metrics: + # - val_loss: token cross-entropy (natural log) + # - val_bpb: tokenizer-agnostic compression metric used by the challenge + local_batch_tokens = args.val_batch_size // (world_size * grad_accum_steps) + if local_batch_tokens < args.train_seq_len: + raise ValueError( + "VAL_BATCH_SIZE must provide at least one sequence per rank; " + f"got VAL_BATCH_SIZE={args.val_batch_size}, WORLD_SIZE={world_size}, " + f"GRAD_ACCUM_STEPS={grad_accum_steps}, TRAIN_SEQ_LEN={args.train_seq_len}" + ) + local_batch_seqs = local_batch_tokens // args.train_seq_len + total_seqs = (val_tokens.numel() - 1) // args.train_seq_len + seq_start = (total_seqs * rank) // world_size + seq_end = (total_seqs * (rank + 1)) // world_size + val_loss_sum = torch.zeros((), device=device, dtype=torch.float64) + val_token_count = torch.zeros((), device=device, dtype=torch.float64) + val_byte_count = torch.zeros((), device=device, dtype=torch.float64) + + model.eval() + with torch.inference_mode(): + for batch_seq_start in range(seq_start, seq_end, local_batch_seqs): + batch_seq_end = min(batch_seq_start + local_batch_seqs, seq_end) + raw_start = batch_seq_start * args.train_seq_len + raw_end = batch_seq_end * args.train_seq_len + 1 + local = val_tokens[raw_start:raw_end].to(device=device, dtype=torch.int64, non_blocking=True) + x = local[:-1].reshape(-1, args.train_seq_len) + y = local[1:].reshape(-1, args.train_seq_len) + with 
torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + batch_loss = model(x, y).detach() + batch_token_count = float(y.numel()) + val_loss_sum += batch_loss.to(torch.float64) * batch_token_count + val_token_count += batch_token_count + prev_ids = x.reshape(-1) + tgt_ids = y.reshape(-1) + token_bytes = base_bytes_lut[tgt_ids].to(dtype=torch.int16) + token_bytes += (has_leading_space_lut[tgt_ids] & ~is_boundary_token_lut[prev_ids]).to(dtype=torch.int16) + val_byte_count += token_bytes.to(torch.float64).sum() + + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(val_loss_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(val_token_count, op=dist.ReduceOp.SUM) + dist.all_reduce(val_byte_count, op=dist.ReduceOp.SUM) + + val_loss = val_loss_sum / val_token_count + bits_per_token = val_loss.item() / math.log(2.0) + tokens_per_byte = val_token_count.item() / val_byte_count.item() + model.train() + return float(val_loss.item()), float(bits_per_token * tokens_per_byte) + +# ----------------------------- +# POST-TRAINING QUANTIZATION +# ----------------------------- +# +# It's silly to export our model, which is trained in bf16 and fp32, at that same precision. +# Instead, we get approximately the same model (with a small hit) by quantizing the model to int8 & zlib compressing. +# We can then decompress the model and run in higher precision for evaluation, after closing in under the size limit. 
+ +CONTROL_TENSOR_NAME_PATTERNS = tuple( + pattern + for pattern in os.environ.get( + "CONTROL_TENSOR_NAME_PATTERNS", + "attn_scale,attn_scales,mlp_scale,mlp_scales,resid_mix,resid_mixes,q_gain,skip_weight,skip_weights,ts_attn,ts_mlp,resid_mix_logit", + ).split(",") + if pattern +) +INT8_KEEP_FLOAT_FP32_NAME_PATTERNS = tuple( + pattern + for pattern in os.environ.get( + "INT8_KEEP_FLOAT_FP32_NAME_PATTERNS", + ",".join(CONTROL_TENSOR_NAME_PATTERNS), + ).split(",") + if pattern +) +INT8_KEEP_FLOAT_MAX_NUMEL = 65_536 +INT8_KEEP_FLOAT_STORE_DTYPE = torch.float16 +INT8_PER_ROW_SCALE_DTYPE = torch.float16 +INT8_CLIP_PERCENTILE = 99.99984 +INT8_CLIP_Q = INT8_CLIP_PERCENTILE / 100.0 + +def tensor_nbytes(t: Tensor) -> int: + return int(t.numel()) * int(t.element_size()) + +def keep_float_tensor(name: str, t: Tensor, passthrough_orig_dtypes: dict[str, str]) -> Tensor: + if any(pattern in name for pattern in INT8_KEEP_FLOAT_FP32_NAME_PATTERNS): + return t.float().contiguous() + if t.dtype in {torch.float32, torch.bfloat16}: + passthrough_orig_dtypes[name] = str(t.dtype).removeprefix("torch.") + return t.to(dtype=INT8_KEEP_FLOAT_STORE_DTYPE).contiguous() + return t + +def quantize_float_tensor(t: Tensor) -> tuple[Tensor, Tensor]: + t32 = t.float() + if t32.ndim == 2: + # Matrices get one scale per row, which usually tracks output-channel + # ranges much better than a single tensor-wide scale. + clip_abs = ( + torch.quantile(t32.abs(), INT8_CLIP_Q, dim=1) + if t32.numel() + else torch.empty((t32.shape[0],), dtype=torch.float32) + ) + clipped = torch.maximum(torch.minimum(t32, clip_abs[:, None]), -clip_abs[:, None]) + scale = (clip_abs / 127.0).clamp_min(1.0 / 127.0) + q = torch.clamp(torch.round(clipped / scale[:, None]), -127, 127).to(torch.int8).contiguous() + return q, scale.to(dtype=INT8_PER_ROW_SCALE_DTYPE).contiguous() + + # Vectors / scalars use a simpler per-tensor scale. 
+ clip_abs = float(torch.quantile(t32.abs().flatten(), INT8_CLIP_Q).item()) if t32.numel() else 0.0 + scale = torch.tensor(clip_abs / 127.0 if clip_abs > 0 else 1.0, dtype=torch.float32) + q = torch.clamp(torch.round(torch.clamp(t32, -clip_abs, clip_abs) / scale), -127, 127).to(torch.int8).contiguous() + return q, scale + +def quantize_state_dict_int8(state_dict: dict[str, Tensor]): + # Single supported clean-script export format: + # - per-row int8 for 2D float tensors + # - per-tensor int8 for other float tensors + # - exact passthrough for non-floats + # - passthrough for small float tensors, stored as fp16 to save bytes + quantized: dict[str, Tensor] = {} + scales: dict[str, Tensor] = {} + dtypes: dict[str, str] = {} + passthrough: dict[str, Tensor] = {} + passthrough_orig_dtypes: dict[str, str] = {} + qmeta: dict[str, dict[str, object]] = {} + stats = dict.fromkeys( + ("param_count", "num_tensors", "num_float_tensors", "num_nonfloat_tensors", "baseline_tensor_bytes", "int8_payload_bytes"), + 0, + ) + + for name, tensor in state_dict.items(): + t = tensor.detach().to("cpu").contiguous() + stats["param_count"] += int(t.numel()) + stats["num_tensors"] += 1 + stats["baseline_tensor_bytes"] += tensor_nbytes(t) + + if not t.is_floating_point(): + stats["num_nonfloat_tensors"] += 1 + passthrough[name] = t + stats["int8_payload_bytes"] += tensor_nbytes(t) + continue + + # Small float tensors are cheap enough to keep directly. We still downcast + # fp32/bf16 passthrough tensors to fp16 so metadata does not dominate size. 
+ if t.numel() <= INT8_KEEP_FLOAT_MAX_NUMEL: + kept = keep_float_tensor(name, t, passthrough_orig_dtypes) + passthrough[name] = kept + stats["int8_payload_bytes"] += tensor_nbytes(kept) + continue + + stats["num_float_tensors"] += 1 + q, s = quantize_float_tensor(t) + if s.ndim > 0: + qmeta[name] = {"scheme": "per_row", "axis": 0} + quantized[name] = q + scales[name] = s + dtypes[name] = str(t.dtype).removeprefix("torch.") + stats["int8_payload_bytes"] += tensor_nbytes(q) + tensor_nbytes(s) + + obj: dict[str, object] = { + "__quant_format__": "int8_clean_per_row_v1", + "quantized": quantized, + "scales": scales, + "dtypes": dtypes, + "passthrough": passthrough, + } + if qmeta: + obj["qmeta"] = qmeta + if passthrough_orig_dtypes: + obj["passthrough_orig_dtypes"] = passthrough_orig_dtypes + return obj, stats + +def dequantize_state_dict_int8(obj: dict[str, object]) -> dict[str, Tensor]: + out: dict[str, Tensor] = {} + qmeta = obj.get("qmeta", {}) + passthrough_orig_dtypes = obj.get("passthrough_orig_dtypes", {}) + for name, q in obj["quantized"].items(): + dtype = getattr(torch, obj["dtypes"][name]) + s = obj["scales"][name] + if qmeta.get(name, {}).get("scheme") == "per_row" or s.ndim > 0: + s = s.to(dtype=torch.float32) + # Broadcast the saved row scale back across trailing dimensions. + out[name] = (q.float() * s.view(q.shape[0], *([1] * (q.ndim - 1)))).to(dtype=dtype).contiguous() + else: + scale = float(s.item()) + out[name] = (q.float() * scale).to(dtype=dtype).contiguous() + for name, t in obj["passthrough"].items(): + # Restore small tensors, undoing the temporary fp16 storage cast if needed. 
+ out_t = t.detach().to("cpu").contiguous() + orig_dtype = passthrough_orig_dtypes.get(name) + if isinstance(orig_dtype, str): + out_t = out_t.to(dtype=getattr(torch, orig_dtype)).contiguous() + out[name] = out_t + return out + + +# ----------------------------- +# DATA LOADING +# ----------------------------- + +def load_data_shard(file: Path) -> Tensor: + header_bytes = 256 * np.dtype(" None: + self.file_idx = (self.file_idx + 1) % len(self.files) + self.tokens = load_data_shard(self.files[self.file_idx]) + self.pos = 0 + + def take(self, n: int) -> Tensor: + chunks: list[Tensor] = [] + remaining = n + while remaining > 0: + avail = self.tokens.numel() - self.pos + if avail <= 0: + self._advance_file() + continue + k = min(remaining, avail) + chunks.append(self.tokens[self.pos : self.pos + k]) + self.pos += k + remaining -= k + return chunks[0] if len(chunks) == 1 else torch.cat(chunks) + + +class DistributedTokenLoader: + # Each call consumes a contiguous chunk from the shared token stream, then slices out + # one disjoint span per rank. The extra "+1" token lets us build (x, y) by shifting. 
+ def __init__(self, pattern: str, rank: int, world_size: int, device: torch.device): + self.rank = rank + self.world_size = world_size + self.device = device + self.stream = TokenStream(pattern) + + def next_batch(self, global_tokens: int, seq_len: int, grad_accum_steps: int) -> tuple[Tensor, Tensor]: + local_tokens = global_tokens // (self.world_size * grad_accum_steps) + per_rank_span = local_tokens + 1 + chunk = self.stream.take(per_rank_span * self.world_size) + start = self.rank * per_rank_span + local = chunk[start : start + per_rank_span].to(dtype=torch.int64) + x = local[:-1].reshape(-1, seq_len) + y = local[1:].reshape(-1, seq_len) + return x.to(self.device, non_blocking=True), y.to(self.device, non_blocking=True) + +# ----------------------------- +# TRANSFORMER MODULES +# ----------------------------- + +class RMSNorm(nn.Module): + def __init__(self, eps: float | None = None): + super().__init__() + self.eps = eps + + def forward(self, x: Tensor) -> Tensor: + return F.rms_norm(x, (x.size(-1),), eps=self.eps) + + +class CastedLinear(nn.Linear): + # Keep weights in fp32 for optimizer/state quality, cast at matmul time for bf16 compute. + def forward(self, x: Tensor) -> Tensor: + bias = self.bias.to(x.dtype) if self.bias is not None else None + return F.linear(x, self.weight.to(x.dtype), bias) + + +def restore_low_dim_params_to_fp32(module: nn.Module) -> None: + # Keep small/control parameters in fp32 even when the model body runs in bf16. + with torch.no_grad(): + for name, param in module.named_parameters(): + if (param.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS)) and param.dtype != torch.float32: + param.data = param.data.float() + + +class Rotary(nn.Module): + # Caches cos/sin tables per sequence length on the current device. 
+ def __init__(self, dim: int, base: float = 10000.0): + super().__init__() + inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim)) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self._seq_len_cached = 0 + self._cos_cached: Tensor | None = None + self._sin_cached: Tensor | None = None + + def forward(self, seq_len: int, device: torch.device, dtype: torch.dtype) -> tuple[Tensor, Tensor]: + if ( + self._cos_cached is None + or self._sin_cached is None + or self._seq_len_cached != seq_len + or self._cos_cached.device != device + ): + t = torch.arange(seq_len, device=device, dtype=self.inv_freq.dtype) + freqs = torch.outer(t, self.inv_freq.to(device)) + self._cos_cached = freqs.cos()[None, None, :, :] + self._sin_cached = freqs.sin()[None, None, :, :] + self._seq_len_cached = seq_len + return self._cos_cached.to(dtype=dtype), self._sin_cached.to(dtype=dtype) + + +def apply_rotary_emb(x: Tensor, cos: Tensor, sin: Tensor) -> Tensor: + half = x.size(-1) // 2 + x1, x2 = x[..., :half], x[..., half:] + return torch.cat((x1 * cos + x2 * sin, x1 * (-sin) + x2 * cos), dim=-1) + + +class CausalSelfAttention(nn.Module): + def __init__( + self, + dim: int, + num_heads: int, + num_kv_heads: int, + rope_base: float, + qk_gain_init: float, + ): + super().__init__() + if dim % num_heads != 0: + raise ValueError("model_dim must be divisible by num_heads") + if num_heads % num_kv_heads != 0: + raise ValueError("num_heads must be divisible by num_kv_heads") + self.num_heads = num_heads + self.num_kv_heads = num_kv_heads + self.head_dim = dim // num_heads + if self.head_dim % 2 != 0: + raise ValueError("head_dim must be even for RoPE") + kv_dim = self.num_kv_heads * self.head_dim + self.c_q = CastedLinear(dim, dim, bias=False) + self.c_k = CastedLinear(dim, kv_dim, bias=False) + self.c_v = CastedLinear(dim, kv_dim, bias=False) + self.proj = CastedLinear(dim, dim, bias=False) + self.proj._zero_init = True + self.q_gain = 
nn.Parameter(torch.full((num_heads,), qk_gain_init, dtype=torch.float32)) + self.rotary = Rotary(self.head_dim, base=rope_base) + + def forward(self, x: Tensor) -> Tensor: + bsz, seqlen, dim = x.shape + q = self.c_q(x).reshape(bsz, seqlen, self.num_heads, self.head_dim).transpose(1, 2) + k = self.c_k(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) + v = self.c_v(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) + q = F.rms_norm(q, (q.size(-1),)) + k = F.rms_norm(k, (k.size(-1),)) + cos, sin = self.rotary(seqlen, x.device, q.dtype) + q = apply_rotary_emb(q, cos, sin) + k = apply_rotary_emb(k, cos, sin) + q = q * self.q_gain.to(dtype=q.dtype)[None, :, None, None] + if FA3_AVAILABLE: + # FA3 expects (B, T, H, D) — transpose from (B, H, T, D) + y = fa3.flash_attn_func( + q.transpose(1, 2), + k.transpose(1, 2), + v.transpose(1, 2), + causal=True, + ).transpose(1, 2) + else: + y = F.scaled_dot_product_attention( + q, k, v, + attn_mask=None, + is_causal=True, + enable_gqa=(self.num_kv_heads != self.num_heads), + ) + y = y.transpose(1, 2).contiguous().reshape(bsz, seqlen, dim) + return self.proj(y) + + +class MLP(nn.Module): + # leaky_relu^2 MLP (slope=0.5 matches SOTA, PRs #493/#518) + def __init__(self, dim: int, mlp_mult: int, leaky_relu_slope: float = 0.5): + super().__init__() + hidden = mlp_mult * dim + self.fc = CastedLinear(dim, hidden, bias=False) + self.proj = CastedLinear(hidden, dim, bias=False) + self.proj._zero_init = True + self.leaky_relu_slope = leaky_relu_slope + + def forward(self, x: Tensor) -> Tensor: + x = F.leaky_relu(self.fc(x), negative_slope=self.leaky_relu_slope) + return self.proj(x.square()) + + +class TimestepScaling(nn.Module): + """Learned per-iteration scale vectors (arXiv:2410.01405, ICML 2025).""" + def __init__(self, effective_layers: int, dim: int): + super().__init__() + self.attn_gamma = nn.Parameter(torch.ones(effective_layers, dim, dtype=torch.float32)) + self.mlp_gamma = 
nn.Parameter(torch.ones(effective_layers, dim, dtype=torch.float32)) + + +class Block(nn.Module): + def __init__( + self, + dim: int, + num_heads: int, + num_kv_heads: int, + mlp_mult: int, + rope_base: float, + qk_gain_init: float, + use_peri_norm: bool = False, + use_birkhoff_mix: bool = False, + leaky_relu_slope: float = 0.5, + ): + super().__init__() + self.use_peri_norm = use_peri_norm + self.use_birkhoff_mix = use_birkhoff_mix + self.attn_norm = RMSNorm() + if use_peri_norm: + self.mlp_out_norm = RMSNorm() # output norm: bounds MLP contribution, preserves input magnitude + else: + self.mlp_norm = RMSNorm() + self.attn = CausalSelfAttention(dim, num_heads, num_kv_heads, rope_base, qk_gain_init) + self.mlp = MLP(dim, mlp_mult, leaky_relu_slope) + self.attn_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.mlp_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + if use_birkhoff_mix: + # Birkhoff B₂ parameterization: sigmoid(logit) ∈ (0,1) → doubly stochastic + self.resid_mix_logit = nn.Parameter(torch.zeros(dim, dtype=torch.float32)) + else: + self.resid_mix = nn.Parameter(torch.stack((torch.ones(dim), torch.zeros(dim))).float()) + + def forward(self, x: Tensor, x0: Tensor, + ts_attn_gamma: Tensor | None = None, + ts_mlp_gamma: Tensor | None = None) -> Tensor: + if self.use_birkhoff_mix: + alpha = torch.sigmoid(self.resid_mix_logit.to(dtype=x.dtype))[None, None, :] + x = alpha * x + (1 - alpha) * x0 + else: + mix = self.resid_mix.to(dtype=x.dtype) + x = mix[0][None, None, :] * x + mix[1][None, None, :] * x0 + attn_out = self.attn(self.attn_norm(x)) + attn_s = self.attn_scale.to(dtype=x.dtype)[None, None, :] + if ts_attn_gamma is not None: + attn_s = attn_s * ts_attn_gamma[None, None, :] + x = x + attn_s * attn_out + mlp_out = self.mlp_out_norm(self.mlp(x)) if self.use_peri_norm else self.mlp(self.mlp_norm(x)) + mlp_s = self.mlp_scale.to(dtype=x.dtype)[None, None, :] + if ts_mlp_gamma is not None: + mlp_s = mlp_s * ts_mlp_gamma[None, None, 
:] + x = x + mlp_s * mlp_out + return x + + +class GPT(nn.Module): + def __init__( + self, + vocab_size: int, + num_layers: int, + model_dim: int, + num_heads: int, + num_kv_heads: int, + mlp_mult: int, + tie_embeddings: bool, + tied_embed_init_std: float, + logit_softcap: float, + rope_base: float, + qk_gain_init: float, + num_prelude: int = 0, + num_coda: int = 0, + num_shared: int = 0, + num_loops: int = 2, + use_peri_norm: bool = False, + use_birkhoff_mix: bool = False, + use_timestep_scale: bool = False, + leaky_relu_slope: float = 0.5, + ): + super().__init__() + if logit_softcap <= 0.0: + raise ValueError(f"logit_softcap must be positive, got {logit_softcap}") + self.tie_embeddings = tie_embeddings + self.tied_embed_init_std = tied_embed_init_std + self.logit_softcap = logit_softcap + self.tok_emb = nn.Embedding(vocab_size, model_dim) + self.use_recurrence = num_shared > 0 + self.num_loops = num_loops if self.use_recurrence else 1 + self.num_prelude = num_prelude if self.use_recurrence else 0 + self.num_coda = num_coda if self.use_recurrence else 0 + + block_kwargs = dict( + dim=model_dim, num_heads=num_heads, num_kv_heads=num_kv_heads, + mlp_mult=mlp_mult, rope_base=rope_base, qk_gain_init=qk_gain_init, + use_peri_norm=use_peri_norm if self.use_recurrence else False, + use_birkhoff_mix=use_birkhoff_mix if self.use_recurrence else False, + leaky_relu_slope=leaky_relu_slope, + ) + + if self.use_recurrence: + self.prelude_blocks = nn.ModuleList([Block(**block_kwargs) for _ in range(self.num_prelude)]) + self.shared_blocks = nn.ModuleList([Block(**block_kwargs) for _ in range(num_shared)]) + self.coda_blocks = nn.ModuleList([Block(**block_kwargs) for _ in range(self.num_coda)]) + effective_layers = self.num_prelude + num_shared * self.num_loops + self.num_coda + self.timestep_scale = TimestepScaling(effective_layers, model_dim) if use_timestep_scale else None + else: + # Standard U-Net path + self.num_encoder_layers = num_layers // 2 + self.num_decoder_layers = 
num_layers - self.num_encoder_layers + self.num_skip_weights = min(self.num_encoder_layers, self.num_decoder_layers) + self.skip_weights = nn.Parameter(torch.ones(self.num_skip_weights, model_dim, dtype=torch.float32)) + self.blocks = nn.ModuleList([Block(**block_kwargs) for _ in range(num_layers)]) + self.timestep_scale = None + + self.final_norm = RMSNorm() + self.lm_head = None if tie_embeddings else CastedLinear(model_dim, vocab_size, bias=False) + if self.lm_head is not None: + self.lm_head._zero_init = True + self._init_weights() + + def _init_weights(self) -> None: + if self.tie_embeddings: + nn.init.normal_(self.tok_emb.weight, mean=0.0, std=self.tied_embed_init_std) + for module in self.modules(): + if isinstance(module, nn.Linear) and getattr(module, "_zero_init", False): + nn.init.zeros_(module.weight) + + def _get_ts(self, v: int) -> tuple[Tensor | None, Tensor | None]: + if self.timestep_scale is None: + return None, None + return (self.timestep_scale.attn_gamma[v], + self.timestep_scale.mlp_gamma[v]) + + def forward(self, input_ids: Tensor, target_ids: Tensor) -> Tensor: + x = self.tok_emb(input_ids) + x = F.rms_norm(x, (x.size(-1),)) + x0 = x + + if self.use_recurrence: + v = 0 + for block in self.prelude_blocks: + ag, mg = self._get_ts(v) + x = block(x, x0, ag, mg) + v += 1 + for _loop in range(self.num_loops): + for block in self.shared_blocks: + ag, mg = self._get_ts(v) + x = block(x, x0, ag, mg) + v += 1 + for block in self.coda_blocks: + ag, mg = self._get_ts(v) + x = block(x, x0, ag, mg) + v += 1 + else: + skips: list[Tensor] = [] + for i in range(self.num_encoder_layers): + x = self.blocks[i](x, x0) + skips.append(x) + for i in range(self.num_decoder_layers): + if skips: + x = x + self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips.pop() + x = self.blocks[self.num_encoder_layers + i](x, x0) + + x = self.final_norm(x).reshape(-1, x.size(-1)) + targets = target_ids.reshape(-1) + if self.tie_embeddings: + logits_proj = F.linear(x, 
self.tok_emb.weight) + else: + if self.lm_head is None: + raise RuntimeError("lm_head is required when tie_embeddings=False") + logits_proj = self.lm_head(x) + logits = self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap) + return F.cross_entropy(logits.float(), targets, reduction="mean") + + +@torch.no_grad() +def recurrence_param_diagnostics(gpt: GPT) -> str: + """Log parameter norms for recurrence health monitoring. No forward pass.""" + if not gpt.use_recurrence: + return "" + parts: list[str] = [] + + # Birkhoff alpha stats per shared block + for i, block in enumerate(gpt.shared_blocks): + if hasattr(block, "resid_mix_logit"): + a = torch.sigmoid(block.resid_mix_logit) + parts.append(f"shared{i}_alpha:mean={a.mean().item():.3f},std={a.std().item():.3f}") + + # Effective MLP/attn contribution scale per virtual layer position: + # mlp_scale * ts_mlp_gamma[v] is the learned multiplier on MLP output. + # We report RMS (norm / sqrt(numel)) to make it scale-independent. + v = 0 + all_blocks = ( + list(gpt.prelude_blocks) + + list(gpt.shared_blocks) * gpt.num_loops + + list(gpt.coda_blocks) + ) + mlp_norms: list[str] = [] + attn_norms: list[str] = [] + for block in all_blocks: + ms = block.mlp_scale.norm().item() + asc = block.attn_scale.norm().item() + d = block.mlp_scale.numel() ** 0.5 + if gpt.timestep_scale is not None: + ms_g = gpt.timestep_scale.mlp_gamma[v].norm().item() + as_g = gpt.timestep_scale.attn_gamma[v].norm().item() + mlp_norms.append(f"v{v}:{ms * ms_g / d:.4f}") + attn_norms.append(f"v{v}:{asc * as_g / d:.4f}") + else: + mlp_norms.append(f"v{v}:{ms / d:.4f}") + attn_norms.append(f"v{v}:{asc / d:.4f}") + v += 1 + parts.append("eff_mlp_scale:[" + " ".join(mlp_norms) + "]") + parts.append("eff_attn_scale:[" + " ".join(attn_norms) + "]") + return " ".join(parts) + + +# ----------------------------- +# TRAINING +# ----------------------------- + +def main() -> None: + global zeropower_via_newtonschulz5 + + code = 
Path(__file__).read_text(encoding="utf-8") + args = Hyperparameters() + zeropower_via_newtonschulz5 = torch.compile(zeropower_via_newtonschulz5) + + # ----------------------------- + # DISTRIBUTED + CUDA SETUP + # ----------------------------- + + distributed = "RANK" in os.environ and "WORLD_SIZE" in os.environ + rank = int(os.environ.get("RANK", "0")) + world_size = int(os.environ.get("WORLD_SIZE", "1")) + local_rank = int(os.environ.get("LOCAL_RANK", "0")) + if world_size <= 0: + raise ValueError(f"WORLD_SIZE must be positive, got {world_size}") + if 8 % world_size != 0: + raise ValueError(f"WORLD_SIZE={world_size} must divide 8 so grad_accum_steps stays integral") + grad_accum_steps = 8 // world_size + grad_scale = 1.0 / grad_accum_steps + if not torch.cuda.is_available(): + raise RuntimeError("CUDA is required") + device = torch.device("cuda", local_rank) + torch.cuda.set_device(device) + if distributed: + dist.init_process_group(backend="nccl", device_id=device) + dist.barrier() + master_process = rank == 0 + + # Fast math knobs + torch.backends.cuda.matmul.allow_tf32 = True + torch.backends.cudnn.allow_tf32 = True + from torch.backends.cuda import enable_cudnn_sdp, enable_flash_sdp, enable_math_sdp, enable_mem_efficient_sdp + + enable_cudnn_sdp(False) + enable_flash_sdp(True) + enable_mem_efficient_sdp(False) + enable_math_sdp(False) + + logfile = None + if master_process: + os.makedirs("logs", exist_ok=True) + logfile = f"logs/{args.run_id}.txt" + print(logfile) + + def log0(msg: str, console: bool = True) -> None: + if not master_process: + return + if console: + print(msg) + if logfile is not None: + with open(logfile, "a", encoding="utf-8") as f: + print(msg, file=f) + + log0(code, console=False) + log0("=" * 100, console=False) + log0(f"Running Python {sys.version}", console=False) + log0(f"Running PyTorch {torch.__version__}", console=False) + log0( + subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, 
check=False).stdout, + console=False, + ) + log0("=" * 100, console=False) + + # ----------------------------- + # TOKENIZER + VALIDATION METRIC SETUP + # ----------------------------- + + random.seed(args.seed) + np.random.seed(args.seed) + torch.manual_seed(args.seed) + torch.cuda.manual_seed_all(args.seed) + + if not args.tokenizer_path.endswith(".model"): + raise ValueError(f"Script only setup for SentencePiece .model file: {args.tokenizer_path}") + sp = spm.SentencePieceProcessor(model_file=args.tokenizer_path) + if int(sp.vocab_size()) != args.vocab_size: + raise ValueError( + f"VOCAB_SIZE={args.vocab_size} does not match tokenizer vocab_size={int(sp.vocab_size())}" + ) + dataset_dir = Path(args.data_path).resolve() + actual_train_files = len(list(dataset_dir.glob("fineweb_train_*.bin"))) + val_tokens = load_validation_tokens(args.val_files, args.train_seq_len) + base_bytes_lut, has_leading_space_lut, is_boundary_token_lut = build_sentencepiece_luts( + sp, args.vocab_size, device + ) + log0(f"val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path={args.tokenizer_path}") + log0(f"train_loader:dataset:{dataset_dir.name} train_shards:{actual_train_files}") + log0(f"val_loader:shards pattern={args.val_files} tokens:{val_tokens.numel() - 1}") + + # ----------------------------- + # MODEL + OPTIMIZER SETUP + # ----------------------------- + + base_model = GPT( + vocab_size=args.vocab_size, + num_layers=args.num_layers, + model_dim=args.model_dim, + num_heads=args.num_heads, + num_kv_heads=args.num_kv_heads, + mlp_mult=args.mlp_mult, + tie_embeddings=args.tie_embeddings, + tied_embed_init_std=args.tied_embed_init_std, + logit_softcap=args.logit_softcap, + rope_base=args.rope_base, + qk_gain_init=args.qk_gain_init, + num_prelude=args.num_prelude, + num_coda=args.num_coda, + num_shared=args.num_shared, + num_loops=args.num_loops, + use_peri_norm=args.use_peri_norm, + use_birkhoff_mix=args.use_birkhoff_mix, + use_timestep_scale=args.use_timestep_scale, + 
leaky_relu_slope=args.leaky_relu_slope, + ).to(device).bfloat16() + for module in base_model.modules(): + if isinstance(module, CastedLinear): + module.float() + restore_low_dim_params_to_fp32(base_model) + if args.disable_compile: + compiled_model = base_model + else: + compiled_model = torch.compile(base_model, dynamic=False, fullgraph=True) + model: nn.Module = DDP(compiled_model, device_ids=[local_rank], broadcast_buffers=False) if distributed else compiled_model + + # Optimizer split: + # - token embedding (Adam) uses EMBED_LR + # - untied lm_head (Adam) uses HEAD_LR + # - matrix params in transformer blocks use MATRIX_LR via Muon + # - vectors/scalars use SCALAR_LR via Adam + if base_model.use_recurrence: + all_block_lists = [base_model.prelude_blocks, base_model.shared_blocks, base_model.coda_blocks] + block_named_params = [] + for bl in all_block_lists: + block_named_params.extend(bl.named_parameters()) + else: + block_named_params = list(base_model.blocks.named_parameters()) + matrix_params = [ + p + for name, p in block_named_params + if p.ndim == 2 and not any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) + ] + scalar_params = [ + p + for name, p in block_named_params + if p.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) + ] + if hasattr(base_model, "skip_weights") and base_model.skip_weights.numel() > 0: + scalar_params.append(base_model.skip_weights) + if base_model.timestep_scale is not None: + for p in base_model.timestep_scale.parameters(): + scalar_params.append(p) + token_lr = args.tied_embed_lr if args.tie_embeddings else args.embed_lr + optimizer_tok = torch.optim.Adam( + [{"params": [base_model.tok_emb.weight], "lr": token_lr, "base_lr": token_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=True, + ) + optimizer_muon = Muon( + matrix_params, + lr=args.matrix_lr, + momentum=args.muon_momentum, + backend_steps=args.muon_backend_steps, + ) + for group in optimizer_muon.param_groups: 
+ group["base_lr"] = args.matrix_lr + optimizer_scalar = torch.optim.Adam( + [{"params": scalar_params, "lr": args.scalar_lr, "base_lr": args.scalar_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=True, + ) + optimizers: list[torch.optim.Optimizer] = [optimizer_tok, optimizer_muon, optimizer_scalar] + if base_model.lm_head is not None: + optimizer_head = torch.optim.Adam( + [{"params": [base_model.lm_head.weight], "lr": args.head_lr, "base_lr": args.head_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=True, + ) + optimizers.insert(1, optimizer_head) + + n_params = sum(p.numel() for p in base_model.parameters()) + log0(f"model_params:{n_params}") + log0(f"world_size:{world_size} grad_accum_steps:{grad_accum_steps}") + log0("sdp_backends:cudnn=False flash=True mem_efficient=False math=False") + log0(f"flash_attention_3:{FA3_AVAILABLE}") + log0(f"attention_mode:gqa num_heads:{args.num_heads} num_kv_heads:{args.num_kv_heads}") + log0( + f"tie_embeddings:{args.tie_embeddings} embed_lr:{token_lr} " + f"head_lr:{args.head_lr if base_model.lm_head is not None else 0.0} " + f"matrix_lr:{args.matrix_lr} scalar_lr:{args.scalar_lr}" + ) + log0( + f"train_batch_tokens:{args.train_batch_tokens} train_seq_len:{args.train_seq_len} " + f"iterations:{args.iterations} warmup_steps:{args.warmup_steps} " + f"max_wallclock_seconds:{args.max_wallclock_seconds:.3f}" + ) + log0(f"seed:{args.seed}") + if base_model.use_recurrence: + eff = base_model.num_prelude + len(base_model.shared_blocks) * base_model.num_loops + base_model.num_coda + log0(f"recurrence:enabled prelude:{base_model.num_prelude} shared:{len(base_model.shared_blocks)} " + f"loops:{base_model.num_loops} coda:{base_model.num_coda} effective_layers:{eff}") + log0(f"peri_norm:{args.use_peri_norm} birkhoff_mix:{args.use_birkhoff_mix}") + if base_model.timestep_scale is not None: + ts_params = sum(p.numel() for p in base_model.timestep_scale.parameters()) + log0(f"timestep_scale:enabled 
params:{ts_params}") + else: + log0("timestep_scale:disabled") + else: + log0(f"recurrence:disabled num_layers:{args.num_layers}") + compile_mode = "disabled" if args.disable_compile else "fullgraph=True" + log0(f"compile_mode:{compile_mode}") + + # ----------------------------- + # DATA LOADER & MODEL WARMUP + # ----------------------------- + + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + + def zero_grad_all() -> None: + for opt in optimizers: + opt.zero_grad(set_to_none=True) + + max_wallclock_ms = 1000.0 * args.max_wallclock_seconds if args.max_wallclock_seconds > 0 else None + + def lr_mul(step: int, elapsed_ms: float) -> float: + if args.warmdown_iters <= 0: + return 1.0 + if max_wallclock_ms is None: + warmdown_start = max(args.iterations - args.warmdown_iters, 0) + return max((args.iterations - step) / max(args.warmdown_iters, 1), 0.0) if warmdown_start <= step < args.iterations else 1.0 + step_ms = elapsed_ms / max(step, 1) + warmdown_ms = args.warmdown_iters * step_ms + remaining_ms = max(max_wallclock_ms - elapsed_ms, 0.0) + return remaining_ms / max(warmdown_ms, 1e-9) if remaining_ms <= warmdown_ms else 1.0 + + # Warmup primes the compiled forward/backward/optimizer paths, then we restore the + # initial weights/optimizer state so measured training starts from the true init. 
+ if args.warmup_steps > 0: + initial_model_state = {name: tensor.detach().cpu().clone() for name, tensor in base_model.state_dict().items()} + initial_optimizer_states = [copy.deepcopy(opt.state_dict()) for opt in optimizers] + model.train() + for warmup_step in range(args.warmup_steps): + zero_grad_all() + for micro_step in range(grad_accum_steps): + if distributed: + model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 + x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + warmup_loss = model(x, y) + (warmup_loss * grad_scale).backward() + for opt in optimizers: + opt.step() + zero_grad_all() + if args.warmup_steps <= 20 or (warmup_step + 1) % 10 == 0 or warmup_step + 1 == args.warmup_steps: + log0(f"warmup_step:{warmup_step + 1}/{args.warmup_steps}") + base_model.load_state_dict(initial_model_state, strict=True) + for opt, state in zip(optimizers, initial_optimizer_states, strict=True): + opt.load_state_dict(state) + zero_grad_all() + if distributed: + model.require_backward_grad_sync = True + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + + # ----------------------------- + # MAIN TRAINING LOOP + # ----------------------------- + + training_time_ms = 0.0 + stop_after_step: int | None = None + torch.cuda.synchronize() + t0 = time.perf_counter() + + step = 0 + while True: + last_step = step == args.iterations or (stop_after_step is not None and step >= stop_after_step) + + should_validate = last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0) + if should_validate: + torch.cuda.synchronize() + training_time_ms += 1000.0 * (time.perf_counter() - t0) + val_loss, val_bpb = eval_val( + args, + model, + rank, + world_size, + device, + grad_accum_steps, + val_tokens, + base_bytes_lut, + has_leading_space_lut, + is_boundary_token_lut, + ) + log0( + 
f"step:{step}/{args.iterations} val_loss:{val_loss:.4f} val_bpb:{val_bpb:.4f} " + f"train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms / max(step, 1):.2f}ms" + ) + torch.cuda.synchronize() + t0 = time.perf_counter() + + if last_step: + if stop_after_step is not None and step < args.iterations: + log0( + f"stopping_early: wallclock_cap train_time:{training_time_ms:.0f}ms " + f"step:{step}/{args.iterations}" + ) + break + + elapsed_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + scale = lr_mul(step, elapsed_ms) + zero_grad_all() + train_loss = torch.zeros((), device=device) + for micro_step in range(grad_accum_steps): + if distributed: + model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 + x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + loss = model(x, y) + train_loss += loss.detach() + (loss * grad_scale).backward() + train_loss /= grad_accum_steps + + frac = min(step / args.muon_momentum_warmup_steps, 1.0) if args.muon_momentum_warmup_steps > 0 else 1.0 + muon_momentum = (1 - frac) * args.muon_momentum_warmup_start + frac * args.muon_momentum + for group in optimizer_muon.param_groups: + group["momentum"] = muon_momentum + + for opt in optimizers: + for group in opt.param_groups: + group["lr"] = group["base_lr"] * scale + + if args.grad_clip_norm > 0: + torch.nn.utils.clip_grad_norm_(base_model.parameters(), args.grad_clip_norm) + for opt in optimizers: + opt.step() + zero_grad_all() + + step += 1 + approx_training_time_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + should_log_train = ( + args.train_log_every > 0 + and (step <= 10 or step % args.train_log_every == 0 or stop_after_step is not None) + ) + if should_log_train: + log0( + f"step:{step}/{args.iterations} train_loss:{train_loss.item():.4f} " + f"train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms 
/ step:.2f}ms" + ) + if base_model.use_recurrence and step % args.train_log_every == 0: + diag = recurrence_param_diagnostics(base_model) + if diag: + log0(f"step:{step} {diag}") + + # Needed to sync whether we've reached the wallclock cap. + reached_cap = max_wallclock_ms is not None and approx_training_time_ms >= max_wallclock_ms + if distributed and max_wallclock_ms is not None: + reached_cap_tensor = torch.tensor(int(reached_cap), device=device) + dist.all_reduce(reached_cap_tensor, op=dist.ReduceOp.MAX) + reached_cap = bool(reached_cap_tensor.item()) + if stop_after_step is None and reached_cap: + stop_after_step = step + + log0( + f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB " + f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB" + ) + + # ----------------------------- + # SERIALIZATION + ROUNDTRIP VALIDATION + # ----------------------------- + # Save the raw state (useful for debugging/loading in PyTorch directly), then always produce + # the compressed int8+zlib artifact and validate the round-tripped weights. 
+ + if master_process: + torch.save(base_model.state_dict(), "final_model.pt") + model_bytes = os.path.getsize("final_model.pt") + code_bytes = len(code.encode("utf-8")) + log0(f"Serialized model: {model_bytes} bytes") + log0(f"Code size: {code_bytes} bytes") + log0(f"Total submission size: {model_bytes + code_bytes} bytes") + + quant_obj, quant_stats = quantize_state_dict_int8(base_model.state_dict()) + quant_buf = io.BytesIO() + torch.save(quant_obj, quant_buf) + quant_raw = quant_buf.getvalue() + quant_blob = zlib.compress(quant_raw, level=9) + quant_raw_bytes = len(quant_raw) + if master_process: + with open("final_model.int8.ptz", "wb") as f: + f.write(quant_blob) + quant_file_bytes = os.path.getsize("final_model.int8.ptz") + code_bytes = len(code.encode("utf-8")) + ratio = quant_stats["baseline_tensor_bytes"] / max(quant_stats["int8_payload_bytes"], 1) + log0( + f"Serialized model int8+zlib: {quant_file_bytes} bytes " + f"(payload:{quant_stats['int8_payload_bytes']} raw_torch:{quant_raw_bytes} payload_ratio:{ratio:.2f}x)" + ) + log0(f"Total submission size int8+zlib: {quant_file_bytes + code_bytes} bytes") + + if distributed: + dist.barrier() + with open("final_model.int8.ptz", "rb") as f: + quant_blob_disk = f.read() + quant_state = torch.load(io.BytesIO(zlib.decompress(quant_blob_disk)), map_location="cpu") + base_model.load_state_dict(dequantize_state_dict_int8(quant_state), strict=True) + torch.cuda.synchronize() + t_qeval = time.perf_counter() + q_val_loss, q_val_bpb = eval_val( + args, + model, + rank, + world_size, + device, + grad_accum_steps, + val_tokens, + base_bytes_lut, + has_leading_space_lut, + is_boundary_token_lut, + ) + torch.cuda.synchronize() + log0( + f"final_int8_zlib_roundtrip val_loss:{q_val_loss:.4f} val_bpb:{q_val_bpb:.4f} " + f"eval_time:{1000.0 * (time.perf_counter() - t_qeval):.0f}ms" + ) + log0(f"final_int8_zlib_roundtrip_exact val_loss:{q_val_loss:.8f} val_bpb:{q_val_bpb:.8f}") + + if distributed: + dist.destroy_process_group() 
+ + +if __name__ == "__main__": + main() + +==================================================================================================== +Running Python 3.12.3 (main, Nov 6 2025, 13:44:16) [GCC 13.3.0] +Running PyTorch 2.9.1+cu128 +Wed Mar 25 13:41:41 2026 ++-----------------------------------------------------------------------------------------+ +| NVIDIA-SMI 565.57.01 Driver Version: 565.57.01 CUDA Version: 12.7 | +|-----------------------------------------+------------------------+----------------------+ +| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC | +| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. | +| | | MIG M. | +|=========================================+========================+======================| +| 0 NVIDIA H100 80GB HBM3 On | 00000000:BE:00.0 Off | 0 | +| N/A 44C P0 111W / 700W | 1184MiB / 81559MiB | 1% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ + ++-----------------------------------------------------------------------------------------+ +| Processes: | +| GPU GI CI PID Type Process name GPU Memory | +| ID ID Usage | +|=========================================================================================| ++-----------------------------------------------------------------------------------------+ + +==================================================================================================== +val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path=./data/tokenizers/fineweb_1024_bpe.model +train_loader:dataset:fineweb10B_sp1024 train_shards:10 +val_loader:shards pattern=./data/datasets/fineweb10B_sp1024/fineweb_val_*.bin tokens:62021632 +model_params:7870496 +world_size:1 grad_accum_steps:8 +sdp_backends:cudnn=False flash=True mem_efficient=False math=False +flash_attention_3:True +attention_mode:gqa num_heads:8 num_kv_heads:4 +tie_embeddings:True embed_lr:0.05 head_lr:0.0 matrix_lr:0.04 scalar_lr:0.04 
+train_batch_tokens:524288 train_seq_len:1024 iterations:2000 warmup_steps:20 max_wallclock_seconds:0.000 +seed:1337 +recurrence:enabled prelude:0 shared:4 loops:2 coda:0 effective_layers:8 +peri_norm:True birkhoff_mix:True +timestep_scale:disabled +compile_mode:fullgraph=True +warmup_step:1/20 +warmup_step:2/20 +warmup_step:3/20 +warmup_step:4/20 +warmup_step:5/20 +warmup_step:6/20 +warmup_step:7/20 +warmup_step:8/20 +warmup_step:9/20 +warmup_step:10/20 +warmup_step:11/20 +warmup_step:12/20 +warmup_step:13/20 +warmup_step:14/20 +warmup_step:15/20 +warmup_step:16/20 +warmup_step:17/20 +warmup_step:18/20 +warmup_step:19/20 +warmup_step:20/20 +step:0/2000 val_loss:6.9365 val_bpb:4.1082 train_time:0ms step_avg:0.02ms +step:1/2000 train_loss:6.9365 train_time:295ms step_avg:295.20ms +step:2/2000 train_loss:9.4113 train_time:671ms step_avg:335.45ms +step:3/2000 train_loss:7.6620 train_time:1023ms step_avg:341.11ms +step:4/2000 train_loss:9.1263 train_time:1382ms step_avg:345.62ms +step:5/2000 train_loss:8.7319 train_time:1745ms step_avg:349.00ms +step:6/2000 train_loss:7.7896 train_time:2094ms step_avg:348.93ms +step:7/2000 train_loss:6.9395 train_time:2432ms step_avg:347.37ms +step:8/2000 train_loss:6.2371 train_time:2795ms step_avg:349.35ms +step:9/2000 train_loss:5.6578 train_time:3143ms step_avg:349.22ms +step:10/2000 train_loss:5.3392 train_time:3484ms step_avg:348.39ms +step:200/2000 train_loss:2.7839 train_time:69186ms step_avg:345.93ms +step:200 shared0_alpha:mean=0.441,std=0.070 shared1_alpha:mean=0.469,std=0.048 shared2_alpha:mean=0.520,std=0.048 shared3_alpha:mean=0.584,std=0.054 eff_mlp_scale:[v0:1.2092 v1:1.2996 v2:1.4409 v3:1.8575 v4:1.2092 v5:1.2996 v6:1.4409 v7:1.8575] eff_attn_scale:[v0:0.5533 v1:0.4801 v2:0.4578 v3:0.6056 v4:0.5533 v5:0.4801 v6:0.4578 v7:0.6056] +step:400/2000 train_loss:2.4491 train_time:138896ms step_avg:347.24ms +step:400 shared0_alpha:mean=0.418,std=0.097 shared1_alpha:mean=0.464,std=0.053 shared2_alpha:mean=0.541,std=0.052 
shared3_alpha:mean=0.624,std=0.058 eff_mlp_scale:[v0:1.5703 v1:1.5681 v2:1.6243 v3:2.2638 v4:1.5703 v5:1.5681 v6:1.6243 v7:2.2638] eff_attn_scale:[v0:0.2862 v1:0.2282 v2:0.2243 v3:0.3956 v4:0.2862 v5:0.2282 v6:0.2243 v7:0.3956] +step:500/2000 val_loss:2.5583 val_bpb:1.5152 train_time:174277ms step_avg:348.55ms +step:600/2000 train_loss:2.5712 train_time:208113ms step_avg:346.85ms +step:600 shared0_alpha:mean=0.400,std=0.110 shared1_alpha:mean=0.463,std=0.055 shared2_alpha:mean=0.555,std=0.054 shared3_alpha:mean=0.656,std=0.059 eff_mlp_scale:[v0:1.8208 v1:1.7482 v2:1.7596 v3:2.6260 v4:1.8208 v5:1.7482 v6:1.7596 v7:2.6260] eff_attn_scale:[v0:0.1592 v1:0.1400 v2:0.1394 v3:0.2973 v4:0.1592 v5:0.1400 v6:0.1394 v7:0.2973] +step:800/2000 train_loss:2.4318 train_time:277814ms step_avg:347.27ms +step:800 shared0_alpha:mean=0.390,std=0.116 shared1_alpha:mean=0.463,std=0.054 shared2_alpha:mean=0.559,std=0.056 shared3_alpha:mean=0.675,std=0.062 eff_mlp_scale:[v0:2.0144 v1:1.8876 v2:1.8803 v3:2.9446 v4:2.0144 v5:1.8876 v6:1.8803 v7:2.9446] eff_attn_scale:[v0:0.1112 v1:0.1048 v2:0.1077 v3:0.2454 v4:0.1112 v5:0.1048 v6:0.1077 v7:0.2454] +step:1000/2000 train_loss:2.4426 train_time:350106ms step_avg:350.11ms +step:1000 shared0_alpha:mean=0.384,std=0.118 shared1_alpha:mean=0.462,std=0.054 shared2_alpha:mean=0.561,std=0.057 shared3_alpha:mean=0.686,std=0.063 eff_mlp_scale:[v0:2.1762 v1:2.0127 v2:1.9935 v3:3.2101 v4:2.1762 v5:2.0127 v6:1.9935 v7:3.2101] eff_attn_scale:[v0:0.0922 v1:0.0931 v2:0.0958 v3:0.2220 v4:0.0922 v5:0.0931 v6:0.0958 v7:0.2220] +step:1000/2000 val_loss:2.4135 val_bpb:1.4294 train_time:350114ms step_avg:350.11ms +step:1200/2000 train_loss:2.3686 train_time:421961ms step_avg:351.63ms +step:1200 shared0_alpha:mean=0.379,std=0.118 shared1_alpha:mean=0.462,std=0.054 shared2_alpha:mean=0.561,std=0.058 shared3_alpha:mean=0.693,std=0.064 eff_mlp_scale:[v0:2.2918 v1:2.1146 v2:2.0897 v3:3.4150 v4:2.2918 v5:2.1146 v6:2.0897 v7:3.4150] eff_attn_scale:[v0:0.0856 v1:0.0888 
v2:0.0924 v3:0.2146 v4:0.0856 v5:0.0888 v6:0.0924 v7:0.2146] +step:1400/2000 train_loss:2.3984 train_time:493356ms step_avg:352.40ms +step:1400 shared0_alpha:mean=0.376,std=0.118 shared1_alpha:mean=0.463,std=0.054 shared2_alpha:mean=0.562,std=0.058 shared3_alpha:mean=0.698,std=0.064 eff_mlp_scale:[v0:2.3832 v1:2.1902 v2:2.1642 v3:3.5697 v4:2.3832 v5:2.1902 v6:2.1642 v7:3.5697] eff_attn_scale:[v0:0.0824 v1:0.0874 v2:0.0920 v3:0.2160 v4:0.0824 v5:0.0874 v6:0.0920 v7:0.2160] +step:1500/2000 val_loss:2.3382 val_bpb:1.3848 train_time:529158ms step_avg:352.77ms +step:1600/2000 train_loss:2.2849 train_time:564525ms step_avg:352.83ms +step:1600 shared0_alpha:mean=0.373,std=0.117 shared1_alpha:mean=0.463,std=0.054 shared2_alpha:mean=0.561,std=0.058 shared3_alpha:mean=0.700,std=0.064 eff_mlp_scale:[v0:2.4401 v1:2.2435 v2:2.2208 v3:3.6756 v4:2.4401 v5:2.2435 v6:2.2208 v7:3.6756] eff_attn_scale:[v0:0.0813 v1:0.0872 v2:0.0934 v3:0.2197 v4:0.0813 v5:0.0872 v6:0.0934 v7:0.2197] +step:1800/2000 train_loss:2.3237 train_time:634815ms step_avg:352.68ms +step:1800 shared0_alpha:mean=0.372,std=0.116 shared1_alpha:mean=0.464,std=0.053 shared2_alpha:mean=0.561,std=0.059 shared3_alpha:mean=0.703,std=0.064 eff_mlp_scale:[v0:2.4762 v1:2.2771 v2:2.2578 v3:3.7424 v4:2.4762 v5:2.2771 v6:2.2578 v7:3.7424] eff_attn_scale:[v0:0.0812 v1:0.0872 v2:0.0939 v3:0.2223 v4:0.0812 v5:0.0872 v6:0.0939 v7:0.2223] +step:2000/2000 train_loss:2.3380 train_time:705681ms step_avg:352.84ms +step:2000 shared0_alpha:mean=0.372,std=0.116 shared1_alpha:mean=0.464,std=0.053 shared2_alpha:mean=0.562,std=0.059 shared3_alpha:mean=0.703,std=0.064 eff_mlp_scale:[v0:2.4867 v1:2.2894 v2:2.2707 v3:3.7655 v4:2.4867 v5:2.2894 v6:2.2707 v7:3.7655] eff_attn_scale:[v0:0.0806 v1:0.0867 v2:0.0938 v3:0.2242 v4:0.0806 v5:0.0867 v6:0.0938 v7:0.2242] +step:2000/2000 val_loss:2.2907 val_bpb:1.3567 train_time:705691ms step_avg:352.85ms +peak memory allocated: 7944 MiB reserved: 8000 MiB +Serialized model: 30449700 bytes +Code size: 56522 
bytes +Total submission size: 30506222 bytes +Serialized model int8+zlib: 6892475 bytes (payload:7915648 raw_torch:7936093 payload_ratio:3.84x) +Total submission size int8+zlib: 6948997 bytes +final_int8_zlib_roundtrip val_loss:2.2942 val_bpb:1.3587 eval_time:8614ms +final_int8_zlib_roundtrip_exact val_loss:2.29415207 val_bpb:1.35872608 diff --git a/records/track_non_record_16mb/2026-03-26_RecurrenceFix_3Loop_Birkhoff_OutputLN_TimestepScale/logs/s1_Cp.txt b/records/track_non_record_16mb/2026-03-26_RecurrenceFix_3Loop_Birkhoff_OutputLN_TimestepScale/logs/s1_Cp.txt new file mode 100644 index 0000000000..7141c4f6a1 --- /dev/null +++ b/records/track_non_record_16mb/2026-03-26_RecurrenceFix_3Loop_Birkhoff_OutputLN_TimestepScale/logs/s1_Cp.txt @@ -0,0 +1,1412 @@ +""" +The `train_gpt.py` and `train_gpt_mlx.py` scripts are intended as good launching-off points for new participants, not SOTA configs. We'll accept PRs that tune, improve, or simplify these scripts without significantly increasing complexity, but competitive submissions should stay in the `/records` folder. + +Hard stop: To keep readable for newcomers, let's make sure `train_gpt.py` and `train_gpt_mlx.py` never are longer than 1500 lines. 
+""" + +from __future__ import annotations + +import copy +import glob +import io +import math +import os +import random +import subprocess +import sys +import time +import uuid +import zlib +from pathlib import Path + +import numpy as np +import sentencepiece as spm +import torch +import torch.distributed as dist +import torch.nn.functional as F +from torch import Tensor, nn +from torch.nn.parallel import DistributedDataParallel as DDP + +# Flash Attention 3 via HuggingFace kernels (Hopper-optimized) +try: + from kernels import get_kernel + cap = torch.cuda.get_device_capability() + repo = "varunneal/flash-attention-3" if cap == (9, 0) else "kernels-community/flash-attn3" + fa3 = get_kernel(repo).flash_attn_interface + FA3_AVAILABLE = True +except Exception: + FA3_AVAILABLE = False + +# ----------------------------- +# HYPERPARAMETERS +# ----------------------------- +# Default Simple Baseline run: +# - 9 transformer blocks at width 512 +# - 8 attention heads with 4 KV heads (GQA) and 2x MLP expansion +# - vocab size 1024, sequence length 1024, tied embeddings +# - 524,288 train tokens per step for 20,000 iterations with a ~10 minute cap + +class Hyperparameters: + # Data paths are shard globs produced by the existing preprocessing pipeline. + data_path = os.environ.get("DATA_PATH", "./data/datasets/fineweb10B_sp1024") + train_files = os.path.join(data_path, "fineweb_train_*.bin") + val_files = os.path.join(data_path, "fineweb_val_*.bin") + tokenizer_path = os.environ.get("TOKENIZER_PATH", "./data/tokenizers/fineweb_1024_bpe.model") + run_id = os.environ.get("RUN_ID", str(uuid.uuid4())) + seed = int(os.environ.get("SEED", 1337)) + + # Validation cadence and batch size. Validation always uses the full fineweb_val split. + val_batch_size = int(os.environ.get("VAL_BATCH_SIZE", 524_288)) + val_loss_every = int(os.environ.get("VAL_LOSS_EVERY", 1000)) + train_log_every = int(os.environ.get("TRAIN_LOG_EVERY", 200)) + + # Training length. 
+ iterations = int(os.environ.get("ITERATIONS", 20000)) + warmdown_iters = int(os.environ.get("WARMDOWN_ITERS", 1200)) + warmup_steps = int(os.environ.get("WARMUP_STEPS", 20)) + train_batch_tokens = int(os.environ.get("TRAIN_BATCH_TOKENS", 524_288)) + train_seq_len = int(os.environ.get("TRAIN_SEQ_LEN", 1024)) + max_wallclock_seconds = float(os.environ.get("MAX_WALLCLOCK_SECONDS", 600.0)) + qk_gain_init = float(os.environ.get("QK_GAIN_INIT", 1.5)) + + # Model shape. + vocab_size = int(os.environ.get("VOCAB_SIZE", 1024)) + num_layers = int(os.environ.get("NUM_LAYERS", 9)) + num_kv_heads = int(os.environ.get("NUM_KV_HEADS", 4)) + model_dim = int(os.environ.get("MODEL_DIM", 512)) + num_heads = int(os.environ.get("NUM_HEADS", 8)) + mlp_mult = int(os.environ.get("MLP_MULT", 2)) + tie_embeddings = bool(int(os.environ.get("TIE_EMBEDDINGS", "1"))) + rope_base = float(os.environ.get("ROPE_BASE", 10000.0)) + logit_softcap = float(os.environ.get("LOGIT_SOFTCAP", 30.0)) + + # Recurrence + stability features. + num_prelude = int(os.environ.get("NUM_PRELUDE", 0)) + num_coda = int(os.environ.get("NUM_CODA", 0)) + num_shared = int(os.environ.get("NUM_SHARED", 0)) + num_loops = int(os.environ.get("NUM_LOOPS", 2)) + use_peri_norm = bool(int(os.environ.get("USE_PERI_NORM", "0"))) + use_birkhoff_mix = bool(int(os.environ.get("USE_BIRKHOFF_MIX", "0"))) + use_timestep_scale = bool(int(os.environ.get("USE_TIMESTEP_SCALE", "0"))) + leaky_relu_slope = float(os.environ.get("LEAKY_RELU_SLOPE", "0.5")) + disable_compile = bool(int(os.environ.get("DISABLE_COMPILE", "0"))) + + # Optimizer hyperparameters. 
    # --- Learning rates & optimizer settings (one family per optimizer; see main) ---
    embed_lr = float(os.environ.get("EMBED_LR", 0.6))  # untied token embedding (Adam)
    head_lr = float(os.environ.get("HEAD_LR", 0.008))  # untied lm_head (Adam)
    tied_embed_lr = float(os.environ.get("TIED_EMBED_LR", 0.05))  # embedding LR when tied to the head
    tied_embed_init_std = float(os.environ.get("TIED_EMBED_INIT_STD", 0.005))
    matrix_lr = float(os.environ.get("MATRIX_LR", 0.04))  # 2D block weights (Muon)
    scalar_lr = float(os.environ.get("SCALAR_LR", 0.04))  # vectors/scalars/control params (Adam)
    muon_momentum = float(os.environ.get("MUON_MOMENTUM", 0.95))
    muon_backend_steps = int(os.environ.get("MUON_BACKEND_STEPS", 5))  # Newton-Schulz iterations
    muon_momentum_warmup_start = float(os.environ.get("MUON_MOMENTUM_WARMUP_START", 0.85))
    muon_momentum_warmup_steps = int(os.environ.get("MUON_MOMENTUM_WARMUP_STEPS", 500))
    beta1 = float(os.environ.get("BETA1", 0.9))
    beta2 = float(os.environ.get("BETA2", 0.95))
    adam_eps = float(os.environ.get("ADAM_EPS", 1e-8))
    grad_clip_norm = float(os.environ.get("GRAD_CLIP_NORM", 0.0))  # 0.0 disables clipping

# -----------------------------
# MUON OPTIMIZER
# -----------------------------
#
# As borrowed from modded-nanogpt
# Background on Muon: https://kellerjordan.github.io/posts/muon/

def zeropower_via_newtonschulz5(G: Tensor, steps: int = 10, eps: float = 1e-7) -> Tensor:
    # Orthogonalize a 2D update matrix with a fast Newton-Schulz iteration.
    # Muon uses this to normalize matrix-shaped gradients before applying them.
    # (a, b, c) are the quintic iteration coefficients used by the reference Muon.
    a, b, c = (3.4445, -4.7750, 2.0315)
    X = G.bfloat16()
    X /= X.norm() + eps  # pre-scale so the iteration starts from a contractive point
    transposed = G.size(0) > G.size(1)
    if transposed:
        # Iterate on the wide orientation so X @ X.T is the smaller Gram matrix.
        X = X.T
    for _ in range(steps):
        A = X @ X.T
        B = b * A + c * A @ A
        X = a * X + B @ X
    return X.T if transposed else X


class Muon(torch.optim.Optimizer):
    # Momentum SGD whose matrix updates are orthogonalized via Newton-Schulz.
    # Work is sharded across ranks (rank r handles params with i % world_size == r);
    # the flat update buffer is then combined with a single all_reduce SUM.
    def __init__(self, params, lr: float, momentum: float, backend_steps: int, nesterov: bool = True):
        super().__init__(
            params,
            dict(lr=lr, momentum=momentum, backend_steps=backend_steps, nesterov=nesterov),
        )

    @torch.no_grad()
    def step(self, closure=None):
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        distributed = dist.is_available() and dist.is_initialized()
        world_size = dist.get_world_size() if distributed else 1
        rank = dist.get_rank() if distributed else 0

        for group in self.param_groups:
            params = group["params"]
            if not params:
                continue
            lr = group["lr"]
            momentum = group["momentum"]
            backend_steps = group["backend_steps"]
            nesterov = group["nesterov"]

            # One flat bf16 buffer holding every param's update; each rank fills
            # only its own params' slices so the all_reduce reassembles the set.
            total_params = sum(int(p.numel()) for p in params)
            updates_flat = torch.zeros(total_params, device=params[0].device, dtype=torch.bfloat16)

            curr = 0
            for i, p in enumerate(params):
                if i % world_size == rank and p.grad is not None:
                    g = p.grad
                    state = self.state[p]
                    if "momentum_buffer" not in state:
                        state["momentum_buffer"] = torch.zeros_like(g)
                    buf = state["momentum_buffer"]
                    buf.mul_(momentum).add_(g)
                    if nesterov:
                        g = g.add(buf, alpha=momentum)
                    g = zeropower_via_newtonschulz5(g, steps=backend_steps)
                    # Scale correction from Muon reference implementations.
                    g *= max(1, g.size(0) / g.size(1)) ** 0.5
                    updates_flat[curr : curr + p.numel()] = g.reshape(-1)
                # curr advances for every param so slice offsets agree across ranks.
                curr += p.numel()

            if distributed:
                dist.all_reduce(updates_flat, op=dist.ReduceOp.SUM)

            curr = 0
            for p in params:
                g = updates_flat[curr : curr + p.numel()].view_as(p).to(dtype=p.dtype)
                p.add_(g, alpha=-lr)
                curr += p.numel()

        return loss


# -----------------------------
# TOKENIZER-AGNOSTIC EVALUATION SETUP
# -----------------------------
#
# It's common for small models to have a large fraction of their parameters be embeddings, since the 2 * d_model * d_vocab vectors can be gigantic.
# Instead of locking the tokenizer, we let you bring your own and calculate our validation metrics on the average compression of the validation set.
# We calculate BPB (bits-per-byte) instead of validation loss, so we need methods to count the number of bits per token in the tokenizer.
# Note: Submissions that edit the tokenizer will be examined more carefully, since screwing this up might unjustly improve your score.

def build_sentencepiece_luts(
    sp: spm.SentencePieceProcessor, vocab_size: int, device: torch.device
) -> tuple[Tensor, Tensor, Tensor]:
    # Build per-token lookup tables used to convert token counts into byte counts.
    # Returns three aligned tensors indexed by token id:
    #   - base_bytes: UTF-8 byte length of the piece (leading-space marker stripped)
    #   - has_leading_space: True when the piece starts with SentencePiece's "▁"
    #   - is_boundary_token: True for control/unknown/unused ids (they emit no text)
    sp_vocab_size = int(sp.vocab_size())
    table_size = max(sp_vocab_size, vocab_size)
    base_bytes_np = np.zeros((table_size,), dtype=np.int16)
    has_leading_space_np = np.zeros((table_size,), dtype=np.bool_)
    is_boundary_token_np = np.ones((table_size,), dtype=np.bool_)
    for token_id in range(sp_vocab_size):
        if sp.is_control(token_id) or sp.is_unknown(token_id) or sp.is_unused(token_id):
            continue
        is_boundary_token_np[token_id] = False
        if sp.is_byte(token_id):
            # Byte-fallback tokens always decode to exactly one byte.
            base_bytes_np[token_id] = 1
            continue
        piece = sp.id_to_piece(token_id)
        if piece.startswith("▁"):
            has_leading_space_np[token_id] = True
            piece = piece[1:]
        base_bytes_np[token_id] = len(piece.encode("utf-8"))
    return (
        torch.tensor(base_bytes_np, dtype=torch.int16, device=device),
        torch.tensor(has_leading_space_np, dtype=torch.bool, device=device),
        torch.tensor(is_boundary_token_np, dtype=torch.bool, device=device),
    )


def load_validation_tokens(pattern: str, seq_len: int) -> Tensor:
    # Concatenate all validation shards, trimmed to a whole number of seq_len
    # sequences plus one extra token for the x/y shift.
    files = [Path(p) for p in sorted(glob.glob(pattern))]
    if not files:
        raise FileNotFoundError(f"No files found for pattern: {pattern}")
    # The export pipeline writes the fixed first-50k-doc validation set to fineweb_val_*.
    tokens = torch.cat([load_data_shard(file) for file in files]).contiguous()
    usable = ((tokens.numel() - 1) // seq_len) * seq_len
    if usable <= 0:
        raise ValueError(f"Validation split is too short for TRAIN_SEQ_LEN={seq_len}")
    return tokens[: usable + 1]


def eval_val(
    args: Hyperparameters,
    model: nn.Module,
    rank: int,
    world_size: int,
    device: torch.device,
    grad_accum_steps: int,
    val_tokens: Tensor,
    base_bytes_lut: Tensor,
    has_leading_space_lut: Tensor,
    is_boundary_token_lut: Tensor,
) -> tuple[float, float]:
    # Validation computes two metrics:
    # - val_loss: token cross-entropy (natural log)
    # - val_bpb: tokenizer-agnostic compression metric used by the challenge
    local_batch_tokens = args.val_batch_size // (world_size * grad_accum_steps)
    if local_batch_tokens < args.train_seq_len:
        raise ValueError(
            "VAL_BATCH_SIZE must provide at least one sequence per rank; "
            f"got VAL_BATCH_SIZE={args.val_batch_size}, WORLD_SIZE={world_size}, "
            f"GRAD_ACCUM_STEPS={grad_accum_steps}, TRAIN_SEQ_LEN={args.train_seq_len}"
        )
    local_batch_seqs = local_batch_tokens // args.train_seq_len
    # Each rank evaluates a contiguous, disjoint slice of the validation sequences.
    total_seqs = (val_tokens.numel() - 1) // args.train_seq_len
    seq_start = (total_seqs * rank) // world_size
    seq_end = (total_seqs * (rank + 1)) // world_size
    val_loss_sum = torch.zeros((), device=device, dtype=torch.float64)
    val_token_count = torch.zeros((), device=device, dtype=torch.float64)
    val_byte_count = torch.zeros((), device=device, dtype=torch.float64)

    model.eval()
    with torch.inference_mode():
        for batch_seq_start in range(seq_start, seq_end, local_batch_seqs):
            batch_seq_end = min(batch_seq_start + local_batch_seqs, seq_end)
            raw_start = batch_seq_start * args.train_seq_len
            raw_end = batch_seq_end * args.train_seq_len + 1  # +1 for the shifted targets
            local = val_tokens[raw_start:raw_end].to(device=device, dtype=torch.int64, non_blocking=True)
            x = local[:-1].reshape(-1, args.train_seq_len)
            y = local[1:].reshape(-1, args.train_seq_len)
            with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True):
                batch_loss = model(x, y).detach()
            batch_token_count = float(y.numel())
            val_loss_sum += batch_loss.to(torch.float64) * batch_token_count
            val_token_count += batch_token_count
            # Byte accounting: each target contributes its piece's UTF-8 length,
            # plus one byte for its leading space unless the previous token was a
            # boundary (control/unknown/unused) token.
            prev_ids = x.reshape(-1)
            tgt_ids = y.reshape(-1)
            token_bytes = base_bytes_lut[tgt_ids].to(dtype=torch.int16)
            token_bytes += (has_leading_space_lut[tgt_ids] & ~is_boundary_token_lut[prev_ids]).to(dtype=torch.int16)
            val_byte_count += token_bytes.to(torch.float64).sum()

    if dist.is_available() and dist.is_initialized():
        dist.all_reduce(val_loss_sum, op=dist.ReduceOp.SUM)
        dist.all_reduce(val_token_count, op=dist.ReduceOp.SUM)
        dist.all_reduce(val_byte_count, op=dist.ReduceOp.SUM)

    val_loss = val_loss_sum / val_token_count
    # bits/token * tokens/byte = bits/byte.
    bits_per_token = val_loss.item() / math.log(2.0)
    tokens_per_byte = val_token_count.item() / val_byte_count.item()
    model.train()
    return float(val_loss.item()), float(bits_per_token * tokens_per_byte)

# -----------------------------
# POST-TRAINING QUANTIZATION
# -----------------------------
#
# It's silly to export our model, which is trained in bf16 and fp32, at that same precision.
# Instead, we get approximately the same model (with a small hit) by quantizing the model to int8 & zlib compressing.
# We can then decompress the model and run in higher precision for evaluation, after closing in under the size limit.

# Name fragments identifying small "control" tensors (learned scales/mixes) that
# must stay in full float precision through export.
CONTROL_TENSOR_NAME_PATTERNS = tuple(
    pattern
    for pattern in os.environ.get(
        "CONTROL_TENSOR_NAME_PATTERNS",
        "attn_scale,attn_scales,mlp_scale,mlp_scales,resid_mix,resid_mixes,q_gain,skip_weight,skip_weights,ts_attn,ts_mlp,resid_mix_logit",
    ).split(",")
    if pattern
)
INT8_KEEP_FLOAT_FP32_NAME_PATTERNS = tuple(
    pattern
    for pattern in os.environ.get(
        "INT8_KEEP_FLOAT_FP32_NAME_PATTERNS",
        ",".join(CONTROL_TENSOR_NAME_PATTERNS),
    ).split(",")
    if pattern
)
INT8_KEEP_FLOAT_MAX_NUMEL = 65_536          # float tensors at or below this size skip int8
INT8_KEEP_FLOAT_STORE_DTYPE = torch.float16  # storage dtype for small passthrough floats
INT8_PER_ROW_SCALE_DTYPE = torch.float16     # storage dtype for per-row scales
INT8_CLIP_PERCENTILE = 99.99984              # clip outliers before computing the scale
INT8_CLIP_Q = INT8_CLIP_PERCENTILE / 100.0

def tensor_nbytes(t: Tensor) -> int:
    # Storage footprint of t in bytes.
    return int(t.numel()) * int(t.element_size())

def keep_float_tensor(name: str, t: Tensor, passthrough_orig_dtypes: dict[str, str]) -> Tensor:
    # Decide how a small float tensor is stored:
    # - control tensors (matched by name) stay fp32 exactly;
    # - other fp32/bf16 tensors are downcast to fp16, recording the original
    #   dtype in passthrough_orig_dtypes so dequantization can restore it;
    # - anything else passes through unchanged.
    if any(pattern in name for pattern in INT8_KEEP_FLOAT_FP32_NAME_PATTERNS):
        return t.float().contiguous()
    if t.dtype in {torch.float32, torch.bfloat16}:
        passthrough_orig_dtypes[name] = str(t.dtype).removeprefix("torch.")
        return t.to(dtype=INT8_KEEP_FLOAT_STORE_DTYPE).contiguous()
    return t

def quantize_float_tensor(t: Tensor) -> tuple[Tensor, Tensor]:
    # Symmetric int8 quantization. Returns (q, scale) where dequantization is
    # q.float() * scale (broadcast per row for 2D inputs).
    t32 = t.float()
    if t32.ndim == 2:
        # Matrices get one scale per row, which usually tracks output-channel
        # ranges much better than a single tensor-wide scale.
        clip_abs = (
            torch.quantile(t32.abs(), INT8_CLIP_Q, dim=1)
            if t32.numel()
            # FIX: torch.empty left uninitialized garbage in the stored scales
            # for degenerate (zero-width) matrices; zeros + clamp_min below
            # makes them a deterministic 1/127.
            else torch.zeros((t32.shape[0],), dtype=torch.float32)
        )
        clipped = torch.maximum(torch.minimum(t32, clip_abs[:, None]), -clip_abs[:, None])
        scale = (clip_abs / 127.0).clamp_min(1.0 / 127.0)
        q = torch.clamp(torch.round(clipped / scale[:, None]), -127, 127).to(torch.int8).contiguous()
        return q, scale.to(dtype=INT8_PER_ROW_SCALE_DTYPE).contiguous()

    # Vectors / scalars use a simpler per-tensor scale.
    clip_abs = float(torch.quantile(t32.abs().flatten(), INT8_CLIP_Q).item()) if t32.numel() else 0.0
    scale = torch.tensor(clip_abs / 127.0 if clip_abs > 0 else 1.0, dtype=torch.float32)
    q = torch.clamp(torch.round(torch.clamp(t32, -clip_abs, clip_abs) / scale), -127, 127).to(torch.int8).contiguous()
    return q, scale

def quantize_state_dict_int8(state_dict: dict[str, Tensor]):
    # Single supported clean-script export format:
    # - per-row int8 for 2D float tensors
    # - per-tensor int8 for other float tensors
    # - exact passthrough for non-floats
    # - passthrough for small float tensors, stored as fp16 to save bytes
    # Returns (obj, stats): obj is the serializable payload consumed by
    # dequantize_state_dict_int8, stats counts tensors and payload bytes.
    quantized: dict[str, Tensor] = {}
    scales: dict[str, Tensor] = {}
    dtypes: dict[str, str] = {}
    passthrough: dict[str, Tensor] = {}
    passthrough_orig_dtypes: dict[str, str] = {}
    qmeta: dict[str, dict[str, object]] = {}
    stats = dict.fromkeys(
        ("param_count", "num_tensors", "num_float_tensors", "num_nonfloat_tensors", "baseline_tensor_bytes", "int8_payload_bytes"),
        0,
    )

    for name, tensor in state_dict.items():
        t = tensor.detach().to("cpu").contiguous()
        stats["param_count"] += int(t.numel())
        stats["num_tensors"] += 1
        stats["baseline_tensor_bytes"] += tensor_nbytes(t)

        if not t.is_floating_point():
            stats["num_nonfloat_tensors"] += 1
            passthrough[name] = t
            stats["int8_payload_bytes"] += tensor_nbytes(t)
            continue

        # Small float tensors are cheap enough to keep directly. We still downcast
        # fp32/bf16 passthrough tensors to fp16 so metadata does not dominate size.
        if t.numel() <= INT8_KEEP_FLOAT_MAX_NUMEL:
            kept = keep_float_tensor(name, t, passthrough_orig_dtypes)
            passthrough[name] = kept
            stats["int8_payload_bytes"] += tensor_nbytes(kept)
            continue

        stats["num_float_tensors"] += 1
        q, s = quantize_float_tensor(t)
        if s.ndim > 0:
            qmeta[name] = {"scheme": "per_row", "axis": 0}
        quantized[name] = q
        scales[name] = s
        dtypes[name] = str(t.dtype).removeprefix("torch.")
        stats["int8_payload_bytes"] += tensor_nbytes(q) + tensor_nbytes(s)

    obj: dict[str, object] = {
        "__quant_format__": "int8_clean_per_row_v1",
        "quantized": quantized,
        "scales": scales,
        "dtypes": dtypes,
        "passthrough": passthrough,
    }
    if qmeta:
        obj["qmeta"] = qmeta
    if passthrough_orig_dtypes:
        obj["passthrough_orig_dtypes"] = passthrough_orig_dtypes
    return obj, stats

def dequantize_state_dict_int8(obj: dict[str, object]) -> dict[str, Tensor]:
    # Inverse of quantize_state_dict_int8: rebuild a float state_dict from the
    # export payload, restoring original dtypes where they were recorded.
    out: dict[str, Tensor] = {}
    qmeta = obj.get("qmeta", {})
    passthrough_orig_dtypes = obj.get("passthrough_orig_dtypes", {})
    for name, q in obj["quantized"].items():
        dtype = getattr(torch, obj["dtypes"][name])
        s = obj["scales"][name]
        if qmeta.get(name, {}).get("scheme") == "per_row" or s.ndim > 0:
            s = s.to(dtype=torch.float32)
            # Broadcast the saved row scale back across trailing dimensions.
            out[name] = (q.float() * s.view(q.shape[0], *([1] * (q.ndim - 1)))).to(dtype=dtype).contiguous()
        else:
            scale = float(s.item())
            out[name] = (q.float() * scale).to(dtype=dtype).contiguous()
    for name, t in obj["passthrough"].items():
        # Restore small tensors, undoing the temporary fp16 storage cast if needed.
        out_t = t.detach().to("cpu").contiguous()
        orig_dtype = passthrough_orig_dtypes.get(name)
        if isinstance(orig_dtype, str):
            out_t = out_t.to(dtype=getattr(torch, orig_dtype)).contiguous()
        out[name] = out_t
    return out
+ out_t = t.detach().to("cpu").contiguous() + orig_dtype = passthrough_orig_dtypes.get(name) + if isinstance(orig_dtype, str): + out_t = out_t.to(dtype=getattr(torch, orig_dtype)).contiguous() + out[name] = out_t + return out + + +# ----------------------------- +# DATA LOADING +# ----------------------------- + +def load_data_shard(file: Path) -> Tensor: + header_bytes = 256 * np.dtype(" None: + self.file_idx = (self.file_idx + 1) % len(self.files) + self.tokens = load_data_shard(self.files[self.file_idx]) + self.pos = 0 + + def take(self, n: int) -> Tensor: + chunks: list[Tensor] = [] + remaining = n + while remaining > 0: + avail = self.tokens.numel() - self.pos + if avail <= 0: + self._advance_file() + continue + k = min(remaining, avail) + chunks.append(self.tokens[self.pos : self.pos + k]) + self.pos += k + remaining -= k + return chunks[0] if len(chunks) == 1 else torch.cat(chunks) + + +class DistributedTokenLoader: + # Each call consumes a contiguous chunk from the shared token stream, then slices out + # one disjoint span per rank. The extra "+1" token lets us build (x, y) by shifting. 
    def __init__(self, pattern: str, rank: int, world_size: int, device: torch.device):
        self.rank = rank
        self.world_size = world_size
        self.device = device
        self.stream = TokenStream(pattern)

    def next_batch(self, global_tokens: int, seq_len: int, grad_accum_steps: int) -> tuple[Tensor, Tensor]:
        # Every rank consumes the same chunk from its own stream (same files, same
        # position), then keeps only its own per-rank span. The +1 token lets us
        # build (x, y) by shifting without sharing tokens across ranks.
        local_tokens = global_tokens // (self.world_size * grad_accum_steps)
        per_rank_span = local_tokens + 1
        chunk = self.stream.take(per_rank_span * self.world_size)
        start = self.rank * per_rank_span
        local = chunk[start : start + per_rank_span].to(dtype=torch.int64)
        x = local[:-1].reshape(-1, seq_len)
        y = local[1:].reshape(-1, seq_len)
        return x.to(self.device, non_blocking=True), y.to(self.device, non_blocking=True)

# -----------------------------
# TRANSFORMER MODULES
# -----------------------------

class RMSNorm(nn.Module):
    # Parameter-free RMS norm; eps=None defers to F.rms_norm's default epsilon.
    def __init__(self, eps: float | None = None):
        super().__init__()
        self.eps = eps

    def forward(self, x: Tensor) -> Tensor:
        return F.rms_norm(x, (x.size(-1),), eps=self.eps)


class CastedLinear(nn.Linear):
    # Keep weights in fp32 for optimizer/state quality, cast at matmul time for bf16 compute.
    def forward(self, x: Tensor) -> Tensor:
        bias = self.bias.to(x.dtype) if self.bias is not None else None
        return F.linear(x, self.weight.to(x.dtype), bias)


def restore_low_dim_params_to_fp32(module: nn.Module) -> None:
    # Keep small/control parameters in fp32 even when the model body runs in bf16.
    # Matches either dimensionality (< 2D) or a control-tensor name fragment.
    with torch.no_grad():
        for name, param in module.named_parameters():
            if (param.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS)) and param.dtype != torch.float32:
                param.data = param.data.float()


class Rotary(nn.Module):
    # Caches cos/sin tables per sequence length on the current device.
    def __init__(self, dim: int, base: float = 10000.0):
        super().__init__()
        # Standard RoPE inverse frequencies, one per channel pair.
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim))
        self.register_buffer("inv_freq", inv_freq, persistent=False)
        self._seq_len_cached = 0
        self._cos_cached: Tensor | None = None
        self._sin_cached: Tensor | None = None

    def forward(self, seq_len: int, device: torch.device, dtype: torch.dtype) -> tuple[Tensor, Tensor]:
        # Rebuild the tables when the length or device changes; the requested
        # dtype is handled by the cheap .to() on return, so it is deliberately
        # not part of the cache key.
        if (
            self._cos_cached is None
            or self._sin_cached is None
            or self._seq_len_cached != seq_len
            or self._cos_cached.device != device
        ):
            t = torch.arange(seq_len, device=device, dtype=self.inv_freq.dtype)
            freqs = torch.outer(t, self.inv_freq.to(device))
            self._cos_cached = freqs.cos()[None, None, :, :]  # (1, 1, T, D/2)
            self._sin_cached = freqs.sin()[None, None, :, :]
            self._seq_len_cached = seq_len
        return self._cos_cached.to(dtype=dtype), self._sin_cached.to(dtype=dtype)


def apply_rotary_emb(x: Tensor, cos: Tensor, sin: Tensor) -> Tensor:
    # Rotate-half formulation of RoPE over the last dimension.
    half = x.size(-1) // 2
    x1, x2 = x[..., :half], x[..., half:]
    return torch.cat((x1 * cos + x2 * sin, x1 * (-sin) + x2 * cos), dim=-1)


class CausalSelfAttention(nn.Module):
    # Causal attention with grouped-query KV (num_kv_heads <= num_heads),
    # QK RMS-norm, RoPE, and a learned per-head query gain.
    def __init__(
        self,
        dim: int,
        num_heads: int,
        num_kv_heads: int,
        rope_base: float,
        qk_gain_init: float,
    ):
        super().__init__()
        if dim % num_heads != 0:
            raise ValueError("model_dim must be divisible by num_heads")
        if num_heads % num_kv_heads != 0:
            raise ValueError("num_heads must be divisible by num_kv_heads")
        self.num_heads = num_heads
        self.num_kv_heads = num_kv_heads
        self.head_dim = dim // num_heads
        if self.head_dim % 2 != 0:
            raise ValueError("head_dim must be even for RoPE")
        kv_dim = self.num_kv_heads * self.head_dim
        self.c_q = CastedLinear(dim, dim, bias=False)
        self.c_k = CastedLinear(dim, kv_dim, bias=False)
        self.c_v = CastedLinear(dim, kv_dim, bias=False)
        self.proj = CastedLinear(dim, dim, bias=False)
        self.proj._zero_init = True  # residual branch starts as identity
        self.q_gain = nn.Parameter(torch.full((num_heads,), qk_gain_init, dtype=torch.float32))
        self.rotary = Rotary(self.head_dim, base=rope_base)

    def forward(self, x: Tensor) -> Tensor:
        bsz, seqlen, dim = x.shape
        q = self.c_q(x).reshape(bsz, seqlen, self.num_heads, self.head_dim).transpose(1, 2)
        k = self.c_k(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2)
        v = self.c_v(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2)
        # QK-norm bounds logit magnitude; q_gain restores a learnable per-head scale.
        q = F.rms_norm(q, (q.size(-1),))
        k = F.rms_norm(k, (k.size(-1),))
        cos, sin = self.rotary(seqlen, x.device, q.dtype)
        q = apply_rotary_emb(q, cos, sin)
        k = apply_rotary_emb(k, cos, sin)
        q = q * self.q_gain.to(dtype=q.dtype)[None, :, None, None]
        if FA3_AVAILABLE:
            # FA3 expects (B, T, H, D) — transpose from (B, H, T, D)
            y = fa3.flash_attn_func(
                q.transpose(1, 2),
                k.transpose(1, 2),
                v.transpose(1, 2),
                causal=True,
            ).transpose(1, 2)
        else:
            y = F.scaled_dot_product_attention(
                q, k, v,
                attn_mask=None,
                is_causal=True,
                enable_gqa=(self.num_kv_heads != self.num_heads),
            )
        y = y.transpose(1, 2).contiguous().reshape(bsz, seqlen, dim)
        return self.proj(y)


class MLP(nn.Module):
    # leaky_relu^2 MLP (slope=0.5 matches SOTA, PRs #493/#518)
    def __init__(self, dim: int, mlp_mult: int, leaky_relu_slope: float = 0.5):
        super().__init__()
        hidden = mlp_mult * dim
        self.fc = CastedLinear(dim, hidden, bias=False)
        self.proj = CastedLinear(hidden, dim, bias=False)
        self.proj._zero_init = True  # residual branch starts as identity
        self.leaky_relu_slope = leaky_relu_slope

    def forward(self, x: Tensor) -> Tensor:
        # Squared leaky-ReLU activation: apply leaky_relu, then square.
        x = F.leaky_relu(self.fc(x), negative_slope=self.leaky_relu_slope)
        return self.proj(x.square())


class TimestepScaling(nn.Module):
    """Learned per-iteration scale vectors (arXiv:2410.01405, ICML 2025)."""
    # One (attn, mlp) scale vector of width `dim` per "virtual" layer position,
    # so looped shared blocks can behave differently on each pass.
    def __init__(self, effective_layers: int, dim: int):
        super().__init__()
        self.attn_gamma = nn.Parameter(torch.ones(effective_layers, dim, dtype=torch.float32))
        self.mlp_gamma = nn.Parameter(torch.ones(effective_layers, dim, dtype=torch.float32))


class Block(nn.Module):
    # One transformer block: learned residual mixing against the embedding
    # stream x0, then attention and MLP sub-blocks, each with a learned
    # per-channel output scale (optionally multiplied by a timestep gamma).
    def __init__(
        self,
        dim: int,
        num_heads: int,
        num_kv_heads: int,
        mlp_mult: int,
        rope_base: float,
        qk_gain_init: float,
        use_peri_norm: bool = False,
        use_birkhoff_mix: bool = False,
        leaky_relu_slope: float = 0.5,
    ):
        super().__init__()
        self.use_peri_norm = use_peri_norm
        self.use_birkhoff_mix = use_birkhoff_mix
        self.attn_norm = RMSNorm()
        if use_peri_norm:
            self.mlp_out_norm = RMSNorm()  # output norm: bounds MLP contribution, preserves input magnitude
        else:
            self.mlp_norm = RMSNorm()
        self.attn = CausalSelfAttention(dim, num_heads, num_kv_heads, rope_base, qk_gain_init)
        self.mlp = MLP(dim, mlp_mult, leaky_relu_slope)
        self.attn_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32))
        self.mlp_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32))
        if use_birkhoff_mix:
            # Birkhoff B₂ parameterization: sigmoid(logit) ∈ (0,1) → doubly stochastic
            self.resid_mix_logit = nn.Parameter(torch.zeros(dim, dtype=torch.float32))
        else:
            # Unconstrained mix: row 0 weights x, row 1 weights x0; init = identity.
            self.resid_mix = nn.Parameter(torch.stack((torch.ones(dim), torch.zeros(dim))).float())

    def forward(self, x: Tensor, x0: Tensor,
                ts_attn_gamma: Tensor | None = None,
                ts_mlp_gamma: Tensor | None = None) -> Tensor:
        if self.use_birkhoff_mix:
            # Per-channel convex combination of the running stream and x0.
            alpha = torch.sigmoid(self.resid_mix_logit.to(dtype=x.dtype))[None, None, :]
            x = alpha * x + (1 - alpha) * x0
        else:
            mix = self.resid_mix.to(dtype=x.dtype)
            x = mix[0][None, None, :] * x + mix[1][None, None, :] * x0
        attn_out = self.attn(self.attn_norm(x))
        attn_s = self.attn_scale.to(dtype=x.dtype)[None, None, :]
        if ts_attn_gamma is not None:
            attn_s = attn_s * ts_attn_gamma[None, None, :]
        x = x + attn_s * attn_out
        mlp_out = self.mlp_out_norm(self.mlp(x)) if self.use_peri_norm else self.mlp(self.mlp_norm(x))
        mlp_s = self.mlp_scale.to(dtype=x.dtype)[None, None, :]
        if ts_mlp_gamma is not None:
            mlp_s = mlp_s * ts_mlp_gamma[None, None, :]
        x = x + mlp_s * mlp_out
        return x


class GPT(nn.Module):
    # Decoder-only LM with two mutually exclusive body layouts:
    # - recurrent (num_shared > 0): prelude -> shared blocks looped num_loops times -> coda
    # - standard U-Net: encoder half, decoder half with weighted skip connections
    def __init__(
        self,
        vocab_size: int,
        num_layers: int,
        model_dim: int,
        num_heads: int,
        num_kv_heads: int,
        mlp_mult: int,
        tie_embeddings: bool,
        tied_embed_init_std: float,
        logit_softcap: float,
        rope_base: float,
        qk_gain_init: float,
        num_prelude: int = 0,
        num_coda: int = 0,
        num_shared: int = 0,
        num_loops: int = 2,
        use_peri_norm: bool = False,
        use_birkhoff_mix: bool = False,
        use_timestep_scale: bool = False,
        leaky_relu_slope: float = 0.5,
    ):
        super().__init__()
        if logit_softcap <= 0.0:
            raise ValueError(f"logit_softcap must be positive, got {logit_softcap}")
        self.tie_embeddings = tie_embeddings
        self.tied_embed_init_std = tied_embed_init_std
        self.logit_softcap = logit_softcap
        self.tok_emb = nn.Embedding(vocab_size, model_dim)
        self.use_recurrence = num_shared > 0
        self.num_loops = num_loops if self.use_recurrence else 1
        self.num_prelude = num_prelude if self.use_recurrence else 0
        self.num_coda = num_coda if self.use_recurrence else 0

        # Peri-norm / Birkhoff variants are only active in the recurrent layout.
        block_kwargs = dict(
            dim=model_dim, num_heads=num_heads, num_kv_heads=num_kv_heads,
            mlp_mult=mlp_mult, rope_base=rope_base, qk_gain_init=qk_gain_init,
            use_peri_norm=use_peri_norm if self.use_recurrence else False,
            use_birkhoff_mix=use_birkhoff_mix if self.use_recurrence else False,
            leaky_relu_slope=leaky_relu_slope,
        )

        if self.use_recurrence:
            self.prelude_blocks = nn.ModuleList([Block(**block_kwargs) for _ in range(self.num_prelude)])
            self.shared_blocks = nn.ModuleList([Block(**block_kwargs) for _ in range(num_shared)])
            self.coda_blocks = nn.ModuleList([Block(**block_kwargs) for _ in range(self.num_coda)])
            # "Virtual" depth: shared blocks count once per loop.
            effective_layers = self.num_prelude + num_shared * self.num_loops + self.num_coda
            self.timestep_scale = TimestepScaling(effective_layers, model_dim) if use_timestep_scale else None
        else:
            # Standard U-Net path
            self.num_encoder_layers = num_layers // 2
            self.num_decoder_layers = num_layers - self.num_encoder_layers
            self.num_skip_weights = min(self.num_encoder_layers, self.num_decoder_layers)
            self.skip_weights = nn.Parameter(torch.ones(self.num_skip_weights, model_dim, dtype=torch.float32))
            self.blocks = nn.ModuleList([Block(**block_kwargs) for _ in range(num_layers)])
            self.timestep_scale = None

        self.final_norm = RMSNorm()
        self.lm_head = None if tie_embeddings else CastedLinear(model_dim, vocab_size, bias=False)
        if self.lm_head is not None:
            self.lm_head._zero_init = True
        self._init_weights()

    def _init_weights(self) -> None:
        # Tied embeddings start small; all _zero_init linears start at exactly zero.
        if self.tie_embeddings:
            nn.init.normal_(self.tok_emb.weight, mean=0.0, std=self.tied_embed_init_std)
        for module in self.modules():
            if isinstance(module, nn.Linear) and getattr(module, "_zero_init", False):
                nn.init.zeros_(module.weight)

    def _get_ts(self, v: int) -> tuple[Tensor | None, Tensor | None]:
        # Timestep-scale vectors for virtual layer v, or (None, None) when disabled.
        if self.timestep_scale is None:
            return None, None
        return (self.timestep_scale.attn_gamma[v],
                self.timestep_scale.mlp_gamma[v])

    def forward(self, input_ids: Tensor, target_ids: Tensor) -> Tensor:
        # Returns the mean token cross-entropy (natural log) against target_ids.
        x = self.tok_emb(input_ids)
        x = F.rms_norm(x, (x.size(-1),))
        x0 = x  # normalized embedding stream, re-mixed into every block

        if self.use_recurrence:
            v = 0  # virtual layer index for timestep scaling
            for block in self.prelude_blocks:
                ag, mg = self._get_ts(v)
                x = block(x, x0, ag, mg)
                v += 1
            for _loop in range(self.num_loops):
                for block in self.shared_blocks:
                    ag, mg = self._get_ts(v)
                    x = block(x, x0, ag, mg)
                    v += 1
            for block in self.coda_blocks:
                ag, mg = self._get_ts(v)
                x = block(x, x0, ag, mg)
                v += 1
        else:
            skips: list[Tensor] = []
            for i in range(self.num_encoder_layers):
                x = self.blocks[i](x, x0)
                skips.append(x)
            for i in range(self.num_decoder_layers):
                if skips:
                    # LIFO pairing: last encoder output feeds the first decoder layer.
                    x = x + self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips.pop()
                x = self.blocks[self.num_encoder_layers + i](x, x0)

        x = self.final_norm(x).reshape(-1, x.size(-1))
        targets = target_ids.reshape(-1)
        if self.tie_embeddings:
            logits_proj = F.linear(x, self.tok_emb.weight)
        else:
            if self.lm_head is None:
                raise RuntimeError("lm_head is required when tie_embeddings=False")
            logits_proj = self.lm_head(x)
        # tanh soft cap keeps logits bounded in (-logit_softcap, logit_softcap).
        logits = self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap)
        return F.cross_entropy(logits.float(), targets, reduction="mean")


@torch.no_grad()
def recurrence_param_diagnostics(gpt: GPT) -> str:
    """Log parameter norms for recurrence health monitoring. No forward pass."""
    if not gpt.use_recurrence:
        return ""
    parts: list[str] = []

    # Birkhoff alpha stats per shared block
    for i, block in enumerate(gpt.shared_blocks):
        if hasattr(block, "resid_mix_logit"):
            a = torch.sigmoid(block.resid_mix_logit)
            parts.append(f"shared{i}_alpha:mean={a.mean().item():.3f},std={a.std().item():.3f}")

    # Effective MLP/attn contribution scale per virtual layer position:
    # mlp_scale * ts_mlp_gamma[v] is the learned multiplier on MLP output.
    # We report RMS (norm / sqrt(numel)) to make it scale-independent.
    v = 0
    all_blocks = (
        list(gpt.prelude_blocks)
        + list(gpt.shared_blocks) * gpt.num_loops  # same modules repeated per loop
        + list(gpt.coda_blocks)
    )
    mlp_norms: list[str] = []
    attn_norms: list[str] = []
    for block in all_blocks:
        ms = block.mlp_scale.norm().item()
        asc = block.attn_scale.norm().item()
        d = block.mlp_scale.numel() ** 0.5
        if gpt.timestep_scale is not None:
            ms_g = gpt.timestep_scale.mlp_gamma[v].norm().item()
            as_g = gpt.timestep_scale.attn_gamma[v].norm().item()
            mlp_norms.append(f"v{v}:{ms * ms_g / d:.4f}")
            attn_norms.append(f"v{v}:{asc * as_g / d:.4f}")
        else:
            mlp_norms.append(f"v{v}:{ms / d:.4f}")
            attn_norms.append(f"v{v}:{asc / d:.4f}")
        v += 1
    parts.append("eff_mlp_scale:[" + " ".join(mlp_norms) + "]")
    parts.append("eff_attn_scale:[" + " ".join(attn_norms) + "]")
    return " ".join(parts)


# -----------------------------
# TRAINING
# -----------------------------

def main() -> None:
    global zeropower_via_newtonschulz5

    code = 
Path(__file__).read_text(encoding="utf-8") + args = Hyperparameters() + zeropower_via_newtonschulz5 = torch.compile(zeropower_via_newtonschulz5) + + # ----------------------------- + # DISTRIBUTED + CUDA SETUP + # ----------------------------- + + distributed = "RANK" in os.environ and "WORLD_SIZE" in os.environ + rank = int(os.environ.get("RANK", "0")) + world_size = int(os.environ.get("WORLD_SIZE", "1")) + local_rank = int(os.environ.get("LOCAL_RANK", "0")) + if world_size <= 0: + raise ValueError(f"WORLD_SIZE must be positive, got {world_size}") + if 8 % world_size != 0: + raise ValueError(f"WORLD_SIZE={world_size} must divide 8 so grad_accum_steps stays integral") + grad_accum_steps = 8 // world_size + grad_scale = 1.0 / grad_accum_steps + if not torch.cuda.is_available(): + raise RuntimeError("CUDA is required") + device = torch.device("cuda", local_rank) + torch.cuda.set_device(device) + if distributed: + dist.init_process_group(backend="nccl", device_id=device) + dist.barrier() + master_process = rank == 0 + + # Fast math knobs + torch.backends.cuda.matmul.allow_tf32 = True + torch.backends.cudnn.allow_tf32 = True + from torch.backends.cuda import enable_cudnn_sdp, enable_flash_sdp, enable_math_sdp, enable_mem_efficient_sdp + + enable_cudnn_sdp(False) + enable_flash_sdp(True) + enable_mem_efficient_sdp(False) + enable_math_sdp(False) + + logfile = None + if master_process: + os.makedirs("logs", exist_ok=True) + logfile = f"logs/{args.run_id}.txt" + print(logfile) + + def log0(msg: str, console: bool = True) -> None: + if not master_process: + return + if console: + print(msg) + if logfile is not None: + with open(logfile, "a", encoding="utf-8") as f: + print(msg, file=f) + + log0(code, console=False) + log0("=" * 100, console=False) + log0(f"Running Python {sys.version}", console=False) + log0(f"Running PyTorch {torch.__version__}", console=False) + log0( + subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, 
check=False).stdout, + console=False, + ) + log0("=" * 100, console=False) + + # ----------------------------- + # TOKENIZER + VALIDATION METRIC SETUP + # ----------------------------- + + random.seed(args.seed) + np.random.seed(args.seed) + torch.manual_seed(args.seed) + torch.cuda.manual_seed_all(args.seed) + + if not args.tokenizer_path.endswith(".model"): + raise ValueError(f"Script only setup for SentencePiece .model file: {args.tokenizer_path}") + sp = spm.SentencePieceProcessor(model_file=args.tokenizer_path) + if int(sp.vocab_size()) != args.vocab_size: + raise ValueError( + f"VOCAB_SIZE={args.vocab_size} does not match tokenizer vocab_size={int(sp.vocab_size())}" + ) + dataset_dir = Path(args.data_path).resolve() + actual_train_files = len(list(dataset_dir.glob("fineweb_train_*.bin"))) + val_tokens = load_validation_tokens(args.val_files, args.train_seq_len) + base_bytes_lut, has_leading_space_lut, is_boundary_token_lut = build_sentencepiece_luts( + sp, args.vocab_size, device + ) + log0(f"val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path={args.tokenizer_path}") + log0(f"train_loader:dataset:{dataset_dir.name} train_shards:{actual_train_files}") + log0(f"val_loader:shards pattern={args.val_files} tokens:{val_tokens.numel() - 1}") + + # ----------------------------- + # MODEL + OPTIMIZER SETUP + # ----------------------------- + + base_model = GPT( + vocab_size=args.vocab_size, + num_layers=args.num_layers, + model_dim=args.model_dim, + num_heads=args.num_heads, + num_kv_heads=args.num_kv_heads, + mlp_mult=args.mlp_mult, + tie_embeddings=args.tie_embeddings, + tied_embed_init_std=args.tied_embed_init_std, + logit_softcap=args.logit_softcap, + rope_base=args.rope_base, + qk_gain_init=args.qk_gain_init, + num_prelude=args.num_prelude, + num_coda=args.num_coda, + num_shared=args.num_shared, + num_loops=args.num_loops, + use_peri_norm=args.use_peri_norm, + use_birkhoff_mix=args.use_birkhoff_mix, + use_timestep_scale=args.use_timestep_scale, + 
leaky_relu_slope=args.leaky_relu_slope, + ).to(device).bfloat16() + for module in base_model.modules(): + if isinstance(module, CastedLinear): + module.float() + restore_low_dim_params_to_fp32(base_model) + if args.disable_compile: + compiled_model = base_model + else: + compiled_model = torch.compile(base_model, dynamic=False, fullgraph=True) + model: nn.Module = DDP(compiled_model, device_ids=[local_rank], broadcast_buffers=False) if distributed else compiled_model + + # Optimizer split: + # - token embedding (Adam) uses EMBED_LR + # - untied lm_head (Adam) uses HEAD_LR + # - matrix params in transformer blocks use MATRIX_LR via Muon + # - vectors/scalars use SCALAR_LR via Adam + if base_model.use_recurrence: + all_block_lists = [base_model.prelude_blocks, base_model.shared_blocks, base_model.coda_blocks] + block_named_params = [] + for bl in all_block_lists: + block_named_params.extend(bl.named_parameters()) + else: + block_named_params = list(base_model.blocks.named_parameters()) + matrix_params = [ + p + for name, p in block_named_params + if p.ndim == 2 and not any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) + ] + scalar_params = [ + p + for name, p in block_named_params + if p.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) + ] + if hasattr(base_model, "skip_weights") and base_model.skip_weights.numel() > 0: + scalar_params.append(base_model.skip_weights) + if base_model.timestep_scale is not None: + for p in base_model.timestep_scale.parameters(): + scalar_params.append(p) + token_lr = args.tied_embed_lr if args.tie_embeddings else args.embed_lr + optimizer_tok = torch.optim.Adam( + [{"params": [base_model.tok_emb.weight], "lr": token_lr, "base_lr": token_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=True, + ) + optimizer_muon = Muon( + matrix_params, + lr=args.matrix_lr, + momentum=args.muon_momentum, + backend_steps=args.muon_backend_steps, + ) + for group in optimizer_muon.param_groups: 
+ group["base_lr"] = args.matrix_lr + optimizer_scalar = torch.optim.Adam( + [{"params": scalar_params, "lr": args.scalar_lr, "base_lr": args.scalar_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=True, + ) + optimizers: list[torch.optim.Optimizer] = [optimizer_tok, optimizer_muon, optimizer_scalar] + if base_model.lm_head is not None: + optimizer_head = torch.optim.Adam( + [{"params": [base_model.lm_head.weight], "lr": args.head_lr, "base_lr": args.head_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=True, + ) + optimizers.insert(1, optimizer_head) + + n_params = sum(p.numel() for p in base_model.parameters()) + log0(f"model_params:{n_params}") + log0(f"world_size:{world_size} grad_accum_steps:{grad_accum_steps}") + log0("sdp_backends:cudnn=False flash=True mem_efficient=False math=False") + log0(f"flash_attention_3:{FA3_AVAILABLE}") + log0(f"attention_mode:gqa num_heads:{args.num_heads} num_kv_heads:{args.num_kv_heads}") + log0( + f"tie_embeddings:{args.tie_embeddings} embed_lr:{token_lr} " + f"head_lr:{args.head_lr if base_model.lm_head is not None else 0.0} " + f"matrix_lr:{args.matrix_lr} scalar_lr:{args.scalar_lr}" + ) + log0( + f"train_batch_tokens:{args.train_batch_tokens} train_seq_len:{args.train_seq_len} " + f"iterations:{args.iterations} warmup_steps:{args.warmup_steps} " + f"max_wallclock_seconds:{args.max_wallclock_seconds:.3f}" + ) + log0(f"seed:{args.seed}") + if base_model.use_recurrence: + eff = base_model.num_prelude + len(base_model.shared_blocks) * base_model.num_loops + base_model.num_coda + log0(f"recurrence:enabled prelude:{base_model.num_prelude} shared:{len(base_model.shared_blocks)} " + f"loops:{base_model.num_loops} coda:{base_model.num_coda} effective_layers:{eff}") + log0(f"peri_norm:{args.use_peri_norm} birkhoff_mix:{args.use_birkhoff_mix}") + if base_model.timestep_scale is not None: + ts_params = sum(p.numel() for p in base_model.timestep_scale.parameters()) + log0(f"timestep_scale:enabled 
params:{ts_params}") + else: + log0("timestep_scale:disabled") + else: + log0(f"recurrence:disabled num_layers:{args.num_layers}") + compile_mode = "disabled" if args.disable_compile else "fullgraph=True" + log0(f"compile_mode:{compile_mode}") + + # ----------------------------- + # DATA LOADER & MODEL WARMUP + # ----------------------------- + + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + + def zero_grad_all() -> None: + for opt in optimizers: + opt.zero_grad(set_to_none=True) + + max_wallclock_ms = 1000.0 * args.max_wallclock_seconds if args.max_wallclock_seconds > 0 else None + + def lr_mul(step: int, elapsed_ms: float) -> float: + if args.warmdown_iters <= 0: + return 1.0 + if max_wallclock_ms is None: + warmdown_start = max(args.iterations - args.warmdown_iters, 0) + return max((args.iterations - step) / max(args.warmdown_iters, 1), 0.0) if warmdown_start <= step < args.iterations else 1.0 + step_ms = elapsed_ms / max(step, 1) + warmdown_ms = args.warmdown_iters * step_ms + remaining_ms = max(max_wallclock_ms - elapsed_ms, 0.0) + return remaining_ms / max(warmdown_ms, 1e-9) if remaining_ms <= warmdown_ms else 1.0 + + # Warmup primes the compiled forward/backward/optimizer paths, then we restore the + # initial weights/optimizer state so measured training starts from the true init. 
+ if args.warmup_steps > 0: + initial_model_state = {name: tensor.detach().cpu().clone() for name, tensor in base_model.state_dict().items()} + initial_optimizer_states = [copy.deepcopy(opt.state_dict()) for opt in optimizers] + model.train() + for warmup_step in range(args.warmup_steps): + zero_grad_all() + for micro_step in range(grad_accum_steps): + if distributed: + model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 + x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + warmup_loss = model(x, y) + (warmup_loss * grad_scale).backward() + for opt in optimizers: + opt.step() + zero_grad_all() + if args.warmup_steps <= 20 or (warmup_step + 1) % 10 == 0 or warmup_step + 1 == args.warmup_steps: + log0(f"warmup_step:{warmup_step + 1}/{args.warmup_steps}") + base_model.load_state_dict(initial_model_state, strict=True) + for opt, state in zip(optimizers, initial_optimizer_states, strict=True): + opt.load_state_dict(state) + zero_grad_all() + if distributed: + model.require_backward_grad_sync = True + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + + # ----------------------------- + # MAIN TRAINING LOOP + # ----------------------------- + + training_time_ms = 0.0 + stop_after_step: int | None = None + torch.cuda.synchronize() + t0 = time.perf_counter() + + step = 0 + while True: + last_step = step == args.iterations or (stop_after_step is not None and step >= stop_after_step) + + should_validate = last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0) + if should_validate: + torch.cuda.synchronize() + training_time_ms += 1000.0 * (time.perf_counter() - t0) + val_loss, val_bpb = eval_val( + args, + model, + rank, + world_size, + device, + grad_accum_steps, + val_tokens, + base_bytes_lut, + has_leading_space_lut, + is_boundary_token_lut, + ) + log0( + 
f"step:{step}/{args.iterations} val_loss:{val_loss:.4f} val_bpb:{val_bpb:.4f} " + f"train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms / max(step, 1):.2f}ms" + ) + torch.cuda.synchronize() + t0 = time.perf_counter() + + if last_step: + if stop_after_step is not None and step < args.iterations: + log0( + f"stopping_early: wallclock_cap train_time:{training_time_ms:.0f}ms " + f"step:{step}/{args.iterations}" + ) + break + + elapsed_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + scale = lr_mul(step, elapsed_ms) + zero_grad_all() + train_loss = torch.zeros((), device=device) + for micro_step in range(grad_accum_steps): + if distributed: + model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 + x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + loss = model(x, y) + train_loss += loss.detach() + (loss * grad_scale).backward() + train_loss /= grad_accum_steps + + frac = min(step / args.muon_momentum_warmup_steps, 1.0) if args.muon_momentum_warmup_steps > 0 else 1.0 + muon_momentum = (1 - frac) * args.muon_momentum_warmup_start + frac * args.muon_momentum + for group in optimizer_muon.param_groups: + group["momentum"] = muon_momentum + + for opt in optimizers: + for group in opt.param_groups: + group["lr"] = group["base_lr"] * scale + + if args.grad_clip_norm > 0: + torch.nn.utils.clip_grad_norm_(base_model.parameters(), args.grad_clip_norm) + for opt in optimizers: + opt.step() + zero_grad_all() + + step += 1 + approx_training_time_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + should_log_train = ( + args.train_log_every > 0 + and (step <= 10 or step % args.train_log_every == 0 or stop_after_step is not None) + ) + if should_log_train: + log0( + f"step:{step}/{args.iterations} train_loss:{train_loss.item():.4f} " + f"train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms 
/ step:.2f}ms" + ) + if base_model.use_recurrence and step % args.train_log_every == 0: + diag = recurrence_param_diagnostics(base_model) + if diag: + log0(f"step:{step} {diag}") + + # Needed to sync whether we've reached the wallclock cap. + reached_cap = max_wallclock_ms is not None and approx_training_time_ms >= max_wallclock_ms + if distributed and max_wallclock_ms is not None: + reached_cap_tensor = torch.tensor(int(reached_cap), device=device) + dist.all_reduce(reached_cap_tensor, op=dist.ReduceOp.MAX) + reached_cap = bool(reached_cap_tensor.item()) + if stop_after_step is None and reached_cap: + stop_after_step = step + + log0( + f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB " + f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB" + ) + + # ----------------------------- + # SERIALIZATION + ROUNDTRIP VALIDATION + # ----------------------------- + # Save the raw state (useful for debugging/loading in PyTorch directly), then always produce + # the compressed int8+zlib artifact and validate the round-tripped weights. 
+ + if master_process: + torch.save(base_model.state_dict(), "final_model.pt") + model_bytes = os.path.getsize("final_model.pt") + code_bytes = len(code.encode("utf-8")) + log0(f"Serialized model: {model_bytes} bytes") + log0(f"Code size: {code_bytes} bytes") + log0(f"Total submission size: {model_bytes + code_bytes} bytes") + + quant_obj, quant_stats = quantize_state_dict_int8(base_model.state_dict()) + quant_buf = io.BytesIO() + torch.save(quant_obj, quant_buf) + quant_raw = quant_buf.getvalue() + quant_blob = zlib.compress(quant_raw, level=9) + quant_raw_bytes = len(quant_raw) + if master_process: + with open("final_model.int8.ptz", "wb") as f: + f.write(quant_blob) + quant_file_bytes = os.path.getsize("final_model.int8.ptz") + code_bytes = len(code.encode("utf-8")) + ratio = quant_stats["baseline_tensor_bytes"] / max(quant_stats["int8_payload_bytes"], 1) + log0( + f"Serialized model int8+zlib: {quant_file_bytes} bytes " + f"(payload:{quant_stats['int8_payload_bytes']} raw_torch:{quant_raw_bytes} payload_ratio:{ratio:.2f}x)" + ) + log0(f"Total submission size int8+zlib: {quant_file_bytes + code_bytes} bytes") + + if distributed: + dist.barrier() + with open("final_model.int8.ptz", "rb") as f: + quant_blob_disk = f.read() + quant_state = torch.load(io.BytesIO(zlib.decompress(quant_blob_disk)), map_location="cpu") + base_model.load_state_dict(dequantize_state_dict_int8(quant_state), strict=True) + torch.cuda.synchronize() + t_qeval = time.perf_counter() + q_val_loss, q_val_bpb = eval_val( + args, + model, + rank, + world_size, + device, + grad_accum_steps, + val_tokens, + base_bytes_lut, + has_leading_space_lut, + is_boundary_token_lut, + ) + torch.cuda.synchronize() + log0( + f"final_int8_zlib_roundtrip val_loss:{q_val_loss:.4f} val_bpb:{q_val_bpb:.4f} " + f"eval_time:{1000.0 * (time.perf_counter() - t_qeval):.0f}ms" + ) + log0(f"final_int8_zlib_roundtrip_exact val_loss:{q_val_loss:.8f} val_bpb:{q_val_bpb:.8f}") + + if distributed: + dist.destroy_process_group() 
+ + +if __name__ == "__main__": + main() + +==================================================================================================== +Running Python 3.12.3 (main, Nov 6 2025, 13:44:16) [GCC 13.3.0] +Running PyTorch 2.9.1+cu128 +Wed Mar 25 13:55:18 2026 ++-----------------------------------------------------------------------------------------+ +| NVIDIA-SMI 565.57.01 Driver Version: 565.57.01 CUDA Version: 12.7 | +|-----------------------------------------+------------------------+----------------------+ +| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC | +| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. | +| | | MIG M. | +|=========================================+========================+======================| +| 0 NVIDIA H100 80GB HBM3 On | 00000000:BE:00.0 Off | 0 | +| N/A 44C P0 114W / 700W | 1184MiB / 81559MiB | 5% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ + ++-----------------------------------------------------------------------------------------+ +| Processes: | +| GPU GI CI PID Type Process name GPU Memory | +| ID ID Usage | +|=========================================================================================| ++-----------------------------------------------------------------------------------------+ + +==================================================================================================== +val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path=./data/tokenizers/fineweb_1024_bpe.model +train_loader:dataset:fineweb10B_sp1024 train_shards:10 +val_loader:shards pattern=./data/datasets/fineweb10B_sp1024/fineweb_val_*.bin tokens:62021632 +model_params:7870496 +world_size:1 grad_accum_steps:8 +sdp_backends:cudnn=False flash=True mem_efficient=False math=False +flash_attention_3:True +attention_mode:gqa num_heads:8 num_kv_heads:4 +tie_embeddings:True embed_lr:0.05 head_lr:0.0 matrix_lr:0.04 scalar_lr:0.04 
+train_batch_tokens:524288 train_seq_len:1024 iterations:2000 warmup_steps:20 max_wallclock_seconds:0.000 +seed:1337 +recurrence:enabled prelude:0 shared:4 loops:2 coda:0 effective_layers:8 +peri_norm:False birkhoff_mix:True +timestep_scale:disabled +compile_mode:fullgraph=True +warmup_step:1/20 +warmup_step:2/20 +warmup_step:3/20 +warmup_step:4/20 +warmup_step:5/20 +warmup_step:6/20 +warmup_step:7/20 +warmup_step:8/20 +warmup_step:9/20 +warmup_step:10/20 +warmup_step:11/20 +warmup_step:12/20 +warmup_step:13/20 +warmup_step:14/20 +warmup_step:15/20 +warmup_step:16/20 +warmup_step:17/20 +warmup_step:18/20 +warmup_step:19/20 +warmup_step:20/20 +step:0/2000 val_loss:6.9365 val_bpb:4.1082 train_time:0ms step_avg:0.02ms +step:1/2000 train_loss:6.9365 train_time:347ms step_avg:346.70ms +step:2/2000 train_loss:18.2700 train_time:680ms step_avg:339.85ms +step:3/2000 train_loss:10.1220 train_time:1057ms step_avg:352.46ms +step:4/2000 train_loss:6.5829 train_time:1383ms step_avg:345.85ms +step:5/2000 train_loss:6.2033 train_time:1766ms step_avg:353.14ms +step:6/2000 train_loss:6.3199 train_time:2071ms step_avg:345.16ms +step:7/2000 train_loss:6.1400 train_time:2463ms step_avg:351.90ms +step:8/2000 train_loss:6.1059 train_time:2800ms step_avg:350.00ms +step:9/2000 train_loss:6.0560 train_time:3192ms step_avg:354.67ms +step:10/2000 train_loss:5.9825 train_time:3566ms step_avg:356.55ms +step:200/2000 train_loss:2.8731 train_time:69351ms step_avg:346.76ms +step:200 shared0_alpha:mean=0.505,std=0.057 shared1_alpha:mean=0.478,std=0.049 shared2_alpha:mean=0.494,std=0.047 shared3_alpha:mean=0.531,std=0.049 eff_mlp_scale:[v0:0.9650 v1:0.9506 v2:0.9618 v3:0.9589 v4:0.9650 v5:0.9506 v6:0.9618 v7:0.9589] eff_attn_scale:[v0:0.6818 v1:0.7948 v2:0.7206 v3:1.0069 v4:0.6818 v5:0.7948 v6:0.7206 v7:1.0069] +step:400/2000 train_loss:2.4855 train_time:140452ms step_avg:351.13ms +step:400 shared0_alpha:mean=0.502,std=0.082 shared1_alpha:mean=0.461,std=0.054 shared2_alpha:mean=0.475,std=0.054 
shared3_alpha:mean=0.523,std=0.056 eff_mlp_scale:[v0:0.7822 v1:0.7088 v2:0.6497 v3:0.6039 v4:0.7822 v5:0.7088 v6:0.6497 v7:0.6039] eff_attn_scale:[v0:0.3659 v1:0.5980 v2:0.5223 v3:0.7817 v4:0.3659 v5:0.5980 v6:0.5223 v7:0.7817] +step:500/2000 val_loss:2.5889 val_bpb:1.5333 train_time:176384ms step_avg:352.77ms +step:600/2000 train_loss:2.5995 train_time:211512ms step_avg:352.52ms +step:600 shared0_alpha:mean=0.501,std=0.097 shared1_alpha:mean=0.464,std=0.057 shared2_alpha:mean=0.467,std=0.055 shared3_alpha:mean=0.517,std=0.062 eff_mlp_scale:[v0:0.5964 v1:0.4822 v2:0.3997 v3:0.4283 v4:0.5964 v5:0.4822 v6:0.3997 v7:0.4283] eff_attn_scale:[v0:0.2057 v1:0.5266 v2:0.4490 v3:0.6647 v4:0.2057 v5:0.5266 v6:0.4490 v7:0.6647] +step:800/2000 train_loss:2.4550 train_time:280269ms step_avg:350.34ms +step:800 shared0_alpha:mean=0.499,std=0.103 shared1_alpha:mean=0.474,std=0.063 shared2_alpha:mean=0.461,std=0.055 shared3_alpha:mean=0.509,std=0.066 eff_mlp_scale:[v0:0.4717 v1:0.3474 v2:0.2764 v3:0.3463 v4:0.4717 v5:0.3474 v6:0.2764 v7:0.3463] eff_attn_scale:[v0:0.1494 v1:0.4972 v2:0.4206 v3:0.5973 v4:0.1494 v5:0.4972 v6:0.4206 v7:0.5973] +step:1000/2000 train_loss:2.4634 train_time:346011ms step_avg:346.01ms +step:1000 shared0_alpha:mean=0.497,std=0.105 shared1_alpha:mean=0.483,std=0.066 shared2_alpha:mean=0.457,std=0.054 shared3_alpha:mean=0.501,std=0.068 eff_mlp_scale:[v0:0.3997 v1:0.2834 v2:0.2289 v3:0.3158 v4:0.3997 v5:0.2834 v6:0.2289 v7:0.3158] eff_attn_scale:[v0:0.1248 v1:0.4907 v2:0.4184 v3:0.5679 v4:0.1248 v5:0.4907 v6:0.4184 v7:0.5679] +step:1000/2000 val_loss:2.4357 val_bpb:1.4425 train_time:346023ms step_avg:346.02ms +step:1200/2000 train_loss:2.3851 train_time:415148ms step_avg:345.96ms +step:1200 shared0_alpha:mean=0.492,std=0.107 shared1_alpha:mean=0.490,std=0.069 shared2_alpha:mean=0.453,std=0.054 shared3_alpha:mean=0.493,std=0.068 eff_mlp_scale:[v0:0.3644 v1:0.2597 v2:0.2132 v3:0.3067 v4:0.3644 v5:0.2597 v6:0.2132 v7:0.3067] eff_attn_scale:[v0:0.1151 v1:0.4911 
v2:0.4271 v3:0.5622 v4:0.1151 v5:0.4911 v6:0.4271 v7:0.5622] +step:1400/2000 train_loss:2.4133 train_time:486448ms step_avg:347.46ms +step:1400 shared0_alpha:mean=0.486,std=0.107 shared1_alpha:mean=0.494,std=0.070 shared2_alpha:mean=0.450,std=0.053 shared3_alpha:mean=0.487,std=0.069 eff_mlp_scale:[v0:0.3472 v1:0.2516 v2:0.2123 v3:0.3056 v4:0.3472 v5:0.2516 v6:0.2123 v7:0.3056] eff_attn_scale:[v0:0.1105 v1:0.4947 v2:0.4372 v3:0.5642 v4:0.1105 v5:0.4947 v6:0.4372 v7:0.5642] +step:1500/2000 val_loss:2.3536 val_bpb:1.3939 train_time:520882ms step_avg:347.25ms +step:1600/2000 train_loss:2.2980 train_time:555585ms step_avg:347.24ms +step:1600 shared0_alpha:mean=0.480,std=0.108 shared1_alpha:mean=0.496,std=0.070 shared2_alpha:mean=0.448,std=0.052 shared3_alpha:mean=0.482,std=0.069 eff_mlp_scale:[v0:0.3407 v1:0.2506 v2:0.2140 v3:0.3073 v4:0.3407 v5:0.2506 v6:0.2140 v7:0.3073] eff_attn_scale:[v0:0.1080 v1:0.4969 v2:0.4438 v3:0.5650 v4:0.1080 v5:0.4969 v6:0.4438 v7:0.5650] +step:1800/2000 train_loss:2.3376 train_time:625284ms step_avg:347.38ms +step:1800 shared0_alpha:mean=0.476,std=0.107 shared1_alpha:mean=0.496,std=0.069 shared2_alpha:mean=0.446,std=0.052 shared3_alpha:mean=0.480,std=0.068 eff_mlp_scale:[v0:0.3387 v1:0.2520 v2:0.2175 v3:0.3093 v4:0.3387 v5:0.2520 v6:0.2175 v7:0.3093] eff_attn_scale:[v0:0.1080 v1:0.4949 v2:0.4465 v3:0.5652 v4:0.1080 v5:0.4949 v6:0.4465 v7:0.5652] +step:2000/2000 train_loss:2.3489 train_time:693883ms step_avg:346.94ms +step:2000 shared0_alpha:mean=0.475,std=0.107 shared1_alpha:mean=0.496,std=0.069 shared2_alpha:mean=0.445,std=0.052 shared3_alpha:mean=0.479,std=0.068 eff_mlp_scale:[v0:0.3389 v1:0.2536 v2:0.2190 v3:0.3100 v4:0.3389 v5:0.2536 v6:0.2190 v7:0.3100] eff_attn_scale:[v0:0.1074 v1:0.4913 v2:0.4451 v3:0.5635 v4:0.1074 v5:0.4913 v6:0.4451 v7:0.5635] +step:2000/2000 val_loss:2.3023 val_bpb:1.3636 train_time:693892ms step_avg:346.95ms +peak memory allocated: 8394 MiB reserved: 8752 MiB +Serialized model: 30449636 bytes +Code size: 56522 
bytes +Total submission size: 30506158 bytes +Serialized model int8+zlib: 6834075 bytes (payload:7915648 raw_torch:7936093 payload_ratio:3.84x) +Total submission size int8+zlib: 6890597 bytes +final_int8_zlib_roundtrip val_loss:2.3064 val_bpb:1.3660 eval_time:8793ms +final_int8_zlib_roundtrip_exact val_loss:2.30641545 val_bpb:1.36598914 diff --git a/records/track_non_record_16mb/2026-03-26_RecurrenceFix_3Loop_Birkhoff_OutputLN_TimestepScale/logs/s1_D.txt b/records/track_non_record_16mb/2026-03-26_RecurrenceFix_3Loop_Birkhoff_OutputLN_TimestepScale/logs/s1_D.txt new file mode 100644 index 0000000000..13a49456e2 --- /dev/null +++ b/records/track_non_record_16mb/2026-03-26_RecurrenceFix_3Loop_Birkhoff_OutputLN_TimestepScale/logs/s1_D.txt @@ -0,0 +1,1412 @@ +""" +The `train_gpt.py` and `train_gpt_mlx.py` scripts are intended as good launching-off points for new participants, not SOTA configs. We'll accept PRs that tune, improve, or simplify these scripts without significantly increasing complexity, but competitive submissions should stay in the `/records` folder. + +Hard stop: To keep readable for newcomers, let's make sure `train_gpt.py` and `train_gpt_mlx.py` never are longer than 1500 lines. 
+""" + +from __future__ import annotations + +import copy +import glob +import io +import math +import os +import random +import subprocess +import sys +import time +import uuid +import zlib +from pathlib import Path + +import numpy as np +import sentencepiece as spm +import torch +import torch.distributed as dist +import torch.nn.functional as F +from torch import Tensor, nn +from torch.nn.parallel import DistributedDataParallel as DDP + +# Flash Attention 3 via HuggingFace kernels (Hopper-optimized) +try: + from kernels import get_kernel + cap = torch.cuda.get_device_capability() + repo = "varunneal/flash-attention-3" if cap == (9, 0) else "kernels-community/flash-attn3" + fa3 = get_kernel(repo).flash_attn_interface + FA3_AVAILABLE = True +except Exception: + FA3_AVAILABLE = False + +# ----------------------------- +# HYPERPARAMETERS +# ----------------------------- +# Default Simple Baseline run: +# - 9 transformer blocks at width 512 +# - 8 attention heads with 4 KV heads (GQA) and 2x MLP expansion +# - vocab size 1024, sequence length 1024, tied embeddings +# - 524,288 train tokens per step for 20,000 iterations with a ~10 minute cap + +class Hyperparameters: + # Data paths are shard globs produced by the existing preprocessing pipeline. + data_path = os.environ.get("DATA_PATH", "./data/datasets/fineweb10B_sp1024") + train_files = os.path.join(data_path, "fineweb_train_*.bin") + val_files = os.path.join(data_path, "fineweb_val_*.bin") + tokenizer_path = os.environ.get("TOKENIZER_PATH", "./data/tokenizers/fineweb_1024_bpe.model") + run_id = os.environ.get("RUN_ID", str(uuid.uuid4())) + seed = int(os.environ.get("SEED", 1337)) + + # Validation cadence and batch size. Validation always uses the full fineweb_val split. + val_batch_size = int(os.environ.get("VAL_BATCH_SIZE", 524_288)) + val_loss_every = int(os.environ.get("VAL_LOSS_EVERY", 1000)) + train_log_every = int(os.environ.get("TRAIN_LOG_EVERY", 200)) + + # Training length. 
+ iterations = int(os.environ.get("ITERATIONS", 20000)) + warmdown_iters = int(os.environ.get("WARMDOWN_ITERS", 1200)) + warmup_steps = int(os.environ.get("WARMUP_STEPS", 20)) + train_batch_tokens = int(os.environ.get("TRAIN_BATCH_TOKENS", 524_288)) + train_seq_len = int(os.environ.get("TRAIN_SEQ_LEN", 1024)) + max_wallclock_seconds = float(os.environ.get("MAX_WALLCLOCK_SECONDS", 600.0)) + qk_gain_init = float(os.environ.get("QK_GAIN_INIT", 1.5)) + + # Model shape. + vocab_size = int(os.environ.get("VOCAB_SIZE", 1024)) + num_layers = int(os.environ.get("NUM_LAYERS", 9)) + num_kv_heads = int(os.environ.get("NUM_KV_HEADS", 4)) + model_dim = int(os.environ.get("MODEL_DIM", 512)) + num_heads = int(os.environ.get("NUM_HEADS", 8)) + mlp_mult = int(os.environ.get("MLP_MULT", 2)) + tie_embeddings = bool(int(os.environ.get("TIE_EMBEDDINGS", "1"))) + rope_base = float(os.environ.get("ROPE_BASE", 10000.0)) + logit_softcap = float(os.environ.get("LOGIT_SOFTCAP", 30.0)) + + # Recurrence + stability features. + num_prelude = int(os.environ.get("NUM_PRELUDE", 0)) + num_coda = int(os.environ.get("NUM_CODA", 0)) + num_shared = int(os.environ.get("NUM_SHARED", 0)) + num_loops = int(os.environ.get("NUM_LOOPS", 2)) + use_peri_norm = bool(int(os.environ.get("USE_PERI_NORM", "0"))) + use_birkhoff_mix = bool(int(os.environ.get("USE_BIRKHOFF_MIX", "0"))) + use_timestep_scale = bool(int(os.environ.get("USE_TIMESTEP_SCALE", "0"))) + leaky_relu_slope = float(os.environ.get("LEAKY_RELU_SLOPE", "0.5")) + disable_compile = bool(int(os.environ.get("DISABLE_COMPILE", "0"))) + + # Optimizer hyperparameters. 
    # Per-group learning rates (see optimizer split in main):
    # tok_emb uses embed_lr (or tied_embed_lr when embeddings are tied),
    # the untied lm_head uses head_lr, 2D block matrices use matrix_lr via
    # Muon, and vectors/scalars/control tensors use scalar_lr via Adam.
    embed_lr = float(os.environ.get("EMBED_LR", 0.6))
    head_lr = float(os.environ.get("HEAD_LR", 0.008))
    tied_embed_lr = float(os.environ.get("TIED_EMBED_LR", 0.05))
    tied_embed_init_std = float(os.environ.get("TIED_EMBED_INIT_STD", 0.005))
    matrix_lr = float(os.environ.get("MATRIX_LR", 0.04))
    scalar_lr = float(os.environ.get("SCALAR_LR", 0.04))
    muon_momentum = float(os.environ.get("MUON_MOMENTUM", 0.95))
    muon_backend_steps = int(os.environ.get("MUON_BACKEND_STEPS", 5))  # Newton-Schulz iterations
    # Muon momentum is linearly warmed from warmup_start to muon_momentum
    # over muon_momentum_warmup_steps steps (see the training loop).
    muon_momentum_warmup_start = float(os.environ.get("MUON_MOMENTUM_WARMUP_START", 0.85))
    muon_momentum_warmup_steps = int(os.environ.get("MUON_MOMENTUM_WARMUP_STEPS", 500))
    beta1 = float(os.environ.get("BETA1", 0.9))
    beta2 = float(os.environ.get("BETA2", 0.95))
    adam_eps = float(os.environ.get("ADAM_EPS", 1e-8))
    grad_clip_norm = float(os.environ.get("GRAD_CLIP_NORM", 0.0))  # <= 0 disables clipping

# -----------------------------
# MUON OPTIMIZER
# -----------------------------
#
# As borrowed from modded-nanogpt
# Background on Muon: https://kellerjordan.github.io/posts/muon/

def zeropower_via_newtonschulz5(G: Tensor, steps: int = 10, eps: float = 1e-7) -> Tensor:
    # Orthogonalize a 2D update matrix with a fast Newton-Schulz iteration.
    # Muon uses this to normalize matrix-shaped gradients before applying them.
    #
    # The quintic coefficients (a, b, c) are the tuned values from the
    # modded-nanogpt reference. The iteration runs in bfloat16 for speed, and
    # the input is first divided by its (Frobenius) norm so the spectral norm
    # is at most 1 before iterating.
    a, b, c = (3.4445, -4.7750, 2.0315)
    X = G.bfloat16()
    X /= X.norm() + eps
    # Transpose tall matrices so the Gram matrix X @ X.T is the smaller
    # (rows x rows) product; undone on return.
    transposed = G.size(0) > G.size(1)
    if transposed:
        X = X.T
    for _ in range(steps):
        A = X @ X.T
        B = b * A + c * A @ A
        X = a * X + B @ X
    return X.T if transposed else X


class Muon(torch.optim.Optimizer):
    """Momentum SGD whose 2D updates are orthogonalized via Newton-Schulz.

    Under torch.distributed, parameter i is processed only by rank
    i % world_size; the per-rank results are combined with a single
    all_reduce(SUM) over one flat bfloat16 buffer per param group.
    """

    def __init__(self, params, lr: float, momentum: float, backend_steps: int, nesterov: bool = True):
        super().__init__(
            params,
            dict(lr=lr, momentum=momentum, backend_steps=backend_steps, nesterov=nesterov),
        )

    @torch.no_grad()
    def step(self, closure=None):
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        # Fall back to single-process values when not running distributed.
        distributed = dist.is_available() and dist.is_initialized()
        world_size = dist.get_world_size() if distributed else 1
        rank = dist.get_rank() if distributed else 0

        for group in self.param_groups:
            params = group["params"]
            if not params:
                continue
            lr = group["lr"]
            momentum = group["momentum"]
            backend_steps = group["backend_steps"]
            nesterov = group["nesterov"]

            # One flat buffer holds every param's update. Slots this rank does
            # not own stay zero locally and are filled in by other ranks via
            # the SUM all_reduce below.
            total_params = sum(int(p.numel()) for p in params)
            updates_flat = torch.zeros(total_params, device=params[0].device, dtype=torch.bfloat16)

            curr = 0
            for i, p in enumerate(params):
                if i % world_size == rank and p.grad is not None:
                    g = p.grad
                    state = self.state[p]
                    if "momentum_buffer" not in state:
                        state["momentum_buffer"] = torch.zeros_like(g)
                    buf = state["momentum_buffer"]
                    buf.mul_(momentum).add_(g)
                    if nesterov:
                        g = g.add(buf, alpha=momentum)
                    g = zeropower_via_newtonschulz5(g, steps=backend_steps)
                    # Scale correction from Muon reference implementations.
                    g *= max(1, g.size(0) / g.size(1)) ** 0.5
                    updates_flat[curr : curr + p.numel()] = g.reshape(-1)
                # Offsets advance for every param (owned or not) so the flat
                # layout is identical across ranks.
                curr += p.numel()

            if distributed:
                dist.all_reduce(updates_flat, op=dist.ReduceOp.SUM)

            # Apply the gathered updates back to the parameters in order.
            curr = 0
            for p in params:
                g = updates_flat[curr : curr + p.numel()].view_as(p).to(dtype=p.dtype)
                p.add_(g, alpha=-lr)
                curr += p.numel()

        return loss


# -----------------------------
# TOKENIZER-AGNOSTIC EVALUATION SETUP
# -----------------------------
#
# It's common for small models to have a large fraction of their parameters be embeddings, since the 2 * d_model * d_vocab vectors can be gigantic.
# Instead of locking the tokenizer, we let you bring your own and calculate our validation metrics on the average compression of the validation set.
# We calculate BPB (bits-per-byte) instead of validation loss, so we need methods to count the number of bits per token in the tokenizer.
# Note: Submissions that edit the tokenizer will be examined more carefully, since screwing this up might unjustly improve your score.

def build_sentencepiece_luts(
    sp: spm.SentencePieceProcessor, vocab_size: int, device: torch.device
) -> tuple[Tensor, Tensor, Tensor]:
    """Build per-token-id lookup tables used by the BPB metric.

    Returns three aligned tensors of length max(sp.vocab_size(), vocab_size):
      - base_bytes (int16): UTF-8 byte length of each piece (1 for byte
        tokens; the leading SentencePiece "▁" marker is stripped first)
      - has_leading_space (bool): piece started with the "▁" marker
      - is_boundary_token (bool): control/unknown/unused ids, plus any
        padding ids beyond the SentencePiece vocab (initialized True and
        only cleared for real pieces)
    """
    sp_vocab_size = int(sp.vocab_size())
    table_size = max(sp_vocab_size, vocab_size)
    base_bytes_np = np.zeros((table_size,), dtype=np.int16)
    has_leading_space_np = np.zeros((table_size,), dtype=np.bool_)
    is_boundary_token_np = np.ones((table_size,), dtype=np.bool_)
    for token_id in range(sp_vocab_size):
        if sp.is_control(token_id) or sp.is_unknown(token_id) or sp.is_unused(token_id):
            continue
        is_boundary_token_np[token_id] = False
        if sp.is_byte(token_id):
            base_bytes_np[token_id] = 1
            continue
        piece = sp.id_to_piece(token_id)
        if piece.startswith("▁"):
            has_leading_space_np[token_id] = True
            piece = piece[1:]
        base_bytes_np[token_id] = len(piece.encode("utf-8"))
    return (
        torch.tensor(base_bytes_np, dtype=torch.int16, device=device),
        torch.tensor(has_leading_space_np, dtype=torch.bool, device=device),
        torch.tensor(is_boundary_token_np, dtype=torch.bool, device=device),
    )


def load_validation_tokens(pattern: str, seq_len: int) -> Tensor:
    """Concatenate all validation shards matching `pattern` and trim to a
    whole number of seq_len sequences plus one extra token for the shift."""
    files = [Path(p) for p in sorted(glob.glob(pattern))]
    if not files:
        raise FileNotFoundError(f"No files found for pattern: {pattern}")
    # The export pipeline writes the fixed first-50k-doc validation set to fineweb_val_*.
    tokens = torch.cat([load_data_shard(file) for file in files]).contiguous()
    usable = ((tokens.numel() - 1) // seq_len) * seq_len
    if usable <= 0:
        raise ValueError(f"Validation split is too short for TRAIN_SEQ_LEN={seq_len}")
    return tokens[: usable + 1]


def eval_val(
    args: Hyperparameters,
    model: nn.Module,
    rank: int,
    world_size: int,
    device: torch.device,
    grad_accum_steps: int,
    val_tokens: Tensor,
    base_bytes_lut: Tensor,
    has_leading_space_lut: Tensor,
    is_boundary_token_lut: Tensor,
) -> tuple[float, float]:
    # Validation computes two metrics:
    #  - val_loss: token cross-entropy (natural log)
    #  - val_bpb: tokenizer-agnostic compression metric used by the challenge
    #    (bits/token * tokens/byte = bits per byte of validation text)
    local_batch_tokens = args.val_batch_size // (world_size * grad_accum_steps)
    if local_batch_tokens < args.train_seq_len:
        raise ValueError(
            "VAL_BATCH_SIZE must provide at least one sequence per rank; "
            f"got VAL_BATCH_SIZE={args.val_batch_size}, WORLD_SIZE={world_size}, "
            f"GRAD_ACCUM_STEPS={grad_accum_steps}, TRAIN_SEQ_LEN={args.train_seq_len}"
        )
    local_batch_seqs = local_batch_tokens // args.train_seq_len
    # Each rank evaluates a disjoint contiguous span of sequences.
    total_seqs = (val_tokens.numel() - 1) // args.train_seq_len
    seq_start = (total_seqs * rank) // world_size
    seq_end = (total_seqs * (rank + 1)) // world_size
    # fp64 accumulators keep the reduction exact across many batches/ranks.
    val_loss_sum = torch.zeros((), device=device, dtype=torch.float64)
    val_token_count = torch.zeros((), device=device, dtype=torch.float64)
    val_byte_count = torch.zeros((), device=device, dtype=torch.float64)

    model.eval()
    with torch.inference_mode():
        for batch_seq_start in range(seq_start, seq_end, local_batch_seqs):
            batch_seq_end = min(batch_seq_start + local_batch_seqs, seq_end)
            raw_start = batch_seq_start * args.train_seq_len
            raw_end = batch_seq_end * args.train_seq_len + 1
            # The "+1" token lets us build (x, y) by shifting one position.
            local = val_tokens[raw_start:raw_end].to(device=device, dtype=torch.int64, non_blocking=True)
            x = local[:-1].reshape(-1, args.train_seq_len)
            y = local[1:].reshape(-1, args.train_seq_len)
            with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True):
                batch_loss = model(x, y).detach()
            batch_token_count = float(y.numel())
            val_loss_sum += batch_loss.to(torch.float64) * batch_token_count
            val_token_count += batch_token_count
            # Byte accounting for BPB: each target contributes its piece's
            # UTF-8 length, plus one byte for the leading space when the
            # previous token is not a boundary token.
            prev_ids = x.reshape(-1)
            tgt_ids = y.reshape(-1)
            token_bytes = base_bytes_lut[tgt_ids].to(dtype=torch.int16)
            token_bytes += (has_leading_space_lut[tgt_ids] & ~is_boundary_token_lut[prev_ids]).to(dtype=torch.int16)
            val_byte_count += token_bytes.to(torch.float64).sum()

    if dist.is_available() and dist.is_initialized():
        dist.all_reduce(val_loss_sum, op=dist.ReduceOp.SUM)
        dist.all_reduce(val_token_count, op=dist.ReduceOp.SUM)
        dist.all_reduce(val_byte_count, op=dist.ReduceOp.SUM)

    val_loss = val_loss_sum / val_token_count
    bits_per_token = val_loss.item() / math.log(2.0)
    tokens_per_byte = val_token_count.item() / val_byte_count.item()
    model.train()
    return float(val_loss.item()), float(bits_per_token * tokens_per_byte)

# -----------------------------
# POST-TRAINING QUANTIZATION
# -----------------------------
#
# It's silly to export our model, which is trained in bf16 and fp32, at that same precision.
# Instead, we get approximately the same model (with a small hit) by quantizing the model to int8 & zlib compressing.
# We can then decompress the model and run in higher precision for evaluation, after coming in under the size limit.

# Name substrings identifying small "control" tensors (per-layer scales,
# residual-mix logits, skip weights, ...) that must survive quantization at
# full precision; overridable via the environment.
CONTROL_TENSOR_NAME_PATTERNS = tuple(
    pattern
    for pattern in os.environ.get(
        "CONTROL_TENSOR_NAME_PATTERNS",
        "attn_scale,attn_scales,mlp_scale,mlp_scales,resid_mix,resid_mixes,q_gain,skip_weight,skip_weights,ts_attn,ts_mlp,resid_mix_logit",
    ).split(",")
    if pattern
)
# Tensors matching these patterns are stored as fp32 instead of fp16 when
# passed through; defaults to the control-tensor patterns above.
INT8_KEEP_FLOAT_FP32_NAME_PATTERNS = tuple(
    pattern
    for pattern in os.environ.get(
        "INT8_KEEP_FLOAT_FP32_NAME_PATTERNS",
        ",".join(CONTROL_TENSOR_NAME_PATTERNS),
    ).split(",")
    if pattern
)
INT8_KEEP_FLOAT_MAX_NUMEL = 65_536  # float tensors at or below this size skip int8
INT8_KEEP_FLOAT_STORE_DTYPE = torch.float16
INT8_PER_ROW_SCALE_DTYPE = torch.float16
INT8_CLIP_PERCENTILE = 99.99984  # clip |x| at this percentile before rounding
INT8_CLIP_Q = INT8_CLIP_PERCENTILE / 100.0

def tensor_nbytes(t: Tensor) -> int:
    # Storage size of a tensor's elements in bytes.
    return int(t.numel()) * int(t.element_size())

def keep_float_tensor(name: str, t: Tensor, passthrough_orig_dtypes: dict[str, str]) -> Tensor:
    # Decide how a small float tensor is passed through uncompressed:
    # control tensors stay fp32; other fp32/bf16 tensors are downcast to fp16
    # (recording the original dtype in `passthrough_orig_dtypes`, which this
    # function mutates); anything else is kept as-is.
    if any(pattern in name for pattern in INT8_KEEP_FLOAT_FP32_NAME_PATTERNS):
        return t.float().contiguous()
    if t.dtype in {torch.float32, torch.bfloat16}:
        passthrough_orig_dtypes[name] = str(t.dtype).removeprefix("torch.")
        return t.to(dtype=INT8_KEEP_FLOAT_STORE_DTYPE).contiguous()
    return t

def quantize_float_tensor(t: Tensor) -> tuple[Tensor, Tensor]:
    # Symmetric int8 quantization. Returns (int8 values, scale), where scale
    # is per-row (1-D, fp16) for matrices and per-tensor (0-D, fp32) otherwise.
    t32 = t.float()
    if t32.ndim == 2:
        # Matrices get one scale per row, which usually tracks output-channel
        # ranges much better than a single tensor-wide scale.
        clip_abs = (
            torch.quantile(t32.abs(), INT8_CLIP_Q, dim=1)
            if t32.numel()
            else torch.empty((t32.shape[0],), dtype=torch.float32)
        )
        clipped = torch.maximum(torch.minimum(t32, clip_abs[:, None]), -clip_abs[:, None])
        scale = (clip_abs / 127.0).clamp_min(1.0 / 127.0)
        q = torch.clamp(torch.round(clipped / scale[:, None]), -127, 127).to(torch.int8).contiguous()
        return q, scale.to(dtype=INT8_PER_ROW_SCALE_DTYPE).contiguous()

    # Vectors / scalars use a simpler per-tensor scale.
    clip_abs = float(torch.quantile(t32.abs().flatten(), INT8_CLIP_Q).item()) if t32.numel() else 0.0
    scale = torch.tensor(clip_abs / 127.0 if clip_abs > 0 else 1.0, dtype=torch.float32)
    q = torch.clamp(torch.round(torch.clamp(t32, -clip_abs, clip_abs) / scale), -127, 127).to(torch.int8).contiguous()
    return q, scale

def quantize_state_dict_int8(state_dict: dict[str, Tensor]):
    # Single supported clean-script export format:
    #  - per-row int8 for 2D float tensors
    #  - per-tensor int8 for other float tensors
    #  - exact passthrough for non-floats
    #  - passthrough for small float tensors, stored as fp16 to save bytes
    # Returns (serializable dict, stats dict of tensor/byte counters).
    quantized: dict[str, Tensor] = {}
    scales: dict[str, Tensor] = {}
    dtypes: dict[str, str] = {}
    passthrough: dict[str, Tensor] = {}
    passthrough_orig_dtypes: dict[str, str] = {}
    qmeta: dict[str, dict[str, object]] = {}
    stats = dict.fromkeys(
        ("param_count", "num_tensors", "num_float_tensors", "num_nonfloat_tensors", "baseline_tensor_bytes", "int8_payload_bytes"),
        0,
    )

    for name, tensor in state_dict.items():
        t = tensor.detach().to("cpu").contiguous()
        stats["param_count"] += int(t.numel())
        stats["num_tensors"] += 1
        stats["baseline_tensor_bytes"] += tensor_nbytes(t)

        # Non-float tensors (e.g. integer buffers) are stored exactly.
        if not t.is_floating_point():
            stats["num_nonfloat_tensors"] += 1
            passthrough[name] = t
            stats["int8_payload_bytes"] += tensor_nbytes(t)
            continue

        # Small float tensors are cheap enough to keep directly. We still downcast
        # fp32/bf16 passthrough tensors to fp16 so metadata does not dominate size.
        if t.numel() <= INT8_KEEP_FLOAT_MAX_NUMEL:
            kept = keep_float_tensor(name, t, passthrough_orig_dtypes)
            passthrough[name] = kept
            stats["int8_payload_bytes"] += tensor_nbytes(kept)
            continue

        stats["num_float_tensors"] += 1
        q, s = quantize_float_tensor(t)
        # A non-scalar scale means the tensor was quantized per-row (axis 0).
        if s.ndim > 0:
            qmeta[name] = {"scheme": "per_row", "axis": 0}
        quantized[name] = q
        scales[name] = s
        dtypes[name] = str(t.dtype).removeprefix("torch.")
        stats["int8_payload_bytes"] += tensor_nbytes(q) + tensor_nbytes(s)

    obj: dict[str, object] = {
        "__quant_format__": "int8_clean_per_row_v1",
        "quantized": quantized,
        "scales": scales,
        "dtypes": dtypes,
        "passthrough": passthrough,
    }
    # Optional sections are only written when non-empty to save bytes.
    if qmeta:
        obj["qmeta"] = qmeta
    if passthrough_orig_dtypes:
        obj["passthrough_orig_dtypes"] = passthrough_orig_dtypes
    return obj, stats

def dequantize_state_dict_int8(obj: dict[str, object]) -> dict[str, Tensor]:
    # Inverse of quantize_state_dict_int8: rebuilds a float state dict from
    # the serialized int8 export, restoring each tensor's recorded dtype.
    out: dict[str, Tensor] = {}
    qmeta = obj.get("qmeta", {})
    passthrough_orig_dtypes = obj.get("passthrough_orig_dtypes", {})
    for name, q in obj["quantized"].items():
        dtype = getattr(torch, obj["dtypes"][name])
        s = obj["scales"][name]
        if qmeta.get(name, {}).get("scheme") == "per_row" or s.ndim > 0:
            s = s.to(dtype=torch.float32)
            # Broadcast the saved row scale back across trailing dimensions.
            out[name] = (q.float() * s.view(q.shape[0], *([1] * (q.ndim - 1)))).to(dtype=dtype).contiguous()
        else:
            scale = float(s.item())
            out[name] = (q.float() * scale).to(dtype=dtype).contiguous()
    for name, t in obj["passthrough"].items():
        # Restore small tensors, undoing the temporary fp16 storage cast if needed.
+ out_t = t.detach().to("cpu").contiguous() + orig_dtype = passthrough_orig_dtypes.get(name) + if isinstance(orig_dtype, str): + out_t = out_t.to(dtype=getattr(torch, orig_dtype)).contiguous() + out[name] = out_t + return out + + +# ----------------------------- +# DATA LOADING +# ----------------------------- + +def load_data_shard(file: Path) -> Tensor: + header_bytes = 256 * np.dtype(" None: + self.file_idx = (self.file_idx + 1) % len(self.files) + self.tokens = load_data_shard(self.files[self.file_idx]) + self.pos = 0 + + def take(self, n: int) -> Tensor: + chunks: list[Tensor] = [] + remaining = n + while remaining > 0: + avail = self.tokens.numel() - self.pos + if avail <= 0: + self._advance_file() + continue + k = min(remaining, avail) + chunks.append(self.tokens[self.pos : self.pos + k]) + self.pos += k + remaining -= k + return chunks[0] if len(chunks) == 1 else torch.cat(chunks) + + +class DistributedTokenLoader: + # Each call consumes a contiguous chunk from the shared token stream, then slices out + # one disjoint span per rank. The extra "+1" token lets us build (x, y) by shifting. 
+ def __init__(self, pattern: str, rank: int, world_size: int, device: torch.device): + self.rank = rank + self.world_size = world_size + self.device = device + self.stream = TokenStream(pattern) + + def next_batch(self, global_tokens: int, seq_len: int, grad_accum_steps: int) -> tuple[Tensor, Tensor]: + local_tokens = global_tokens // (self.world_size * grad_accum_steps) + per_rank_span = local_tokens + 1 + chunk = self.stream.take(per_rank_span * self.world_size) + start = self.rank * per_rank_span + local = chunk[start : start + per_rank_span].to(dtype=torch.int64) + x = local[:-1].reshape(-1, seq_len) + y = local[1:].reshape(-1, seq_len) + return x.to(self.device, non_blocking=True), y.to(self.device, non_blocking=True) + +# ----------------------------- +# TRANSFORMER MODULES +# ----------------------------- + +class RMSNorm(nn.Module): + def __init__(self, eps: float | None = None): + super().__init__() + self.eps = eps + + def forward(self, x: Tensor) -> Tensor: + return F.rms_norm(x, (x.size(-1),), eps=self.eps) + + +class CastedLinear(nn.Linear): + # Keep weights in fp32 for optimizer/state quality, cast at matmul time for bf16 compute. + def forward(self, x: Tensor) -> Tensor: + bias = self.bias.to(x.dtype) if self.bias is not None else None + return F.linear(x, self.weight.to(x.dtype), bias) + + +def restore_low_dim_params_to_fp32(module: nn.Module) -> None: + # Keep small/control parameters in fp32 even when the model body runs in bf16. + with torch.no_grad(): + for name, param in module.named_parameters(): + if (param.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS)) and param.dtype != torch.float32: + param.data = param.data.float() + + +class Rotary(nn.Module): + # Caches cos/sin tables per sequence length on the current device. 
+ def __init__(self, dim: int, base: float = 10000.0): + super().__init__() + inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim)) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self._seq_len_cached = 0 + self._cos_cached: Tensor | None = None + self._sin_cached: Tensor | None = None + + def forward(self, seq_len: int, device: torch.device, dtype: torch.dtype) -> tuple[Tensor, Tensor]: + if ( + self._cos_cached is None + or self._sin_cached is None + or self._seq_len_cached != seq_len + or self._cos_cached.device != device + ): + t = torch.arange(seq_len, device=device, dtype=self.inv_freq.dtype) + freqs = torch.outer(t, self.inv_freq.to(device)) + self._cos_cached = freqs.cos()[None, None, :, :] + self._sin_cached = freqs.sin()[None, None, :, :] + self._seq_len_cached = seq_len + return self._cos_cached.to(dtype=dtype), self._sin_cached.to(dtype=dtype) + + +def apply_rotary_emb(x: Tensor, cos: Tensor, sin: Tensor) -> Tensor: + half = x.size(-1) // 2 + x1, x2 = x[..., :half], x[..., half:] + return torch.cat((x1 * cos + x2 * sin, x1 * (-sin) + x2 * cos), dim=-1) + + +class CausalSelfAttention(nn.Module): + def __init__( + self, + dim: int, + num_heads: int, + num_kv_heads: int, + rope_base: float, + qk_gain_init: float, + ): + super().__init__() + if dim % num_heads != 0: + raise ValueError("model_dim must be divisible by num_heads") + if num_heads % num_kv_heads != 0: + raise ValueError("num_heads must be divisible by num_kv_heads") + self.num_heads = num_heads + self.num_kv_heads = num_kv_heads + self.head_dim = dim // num_heads + if self.head_dim % 2 != 0: + raise ValueError("head_dim must be even for RoPE") + kv_dim = self.num_kv_heads * self.head_dim + self.c_q = CastedLinear(dim, dim, bias=False) + self.c_k = CastedLinear(dim, kv_dim, bias=False) + self.c_v = CastedLinear(dim, kv_dim, bias=False) + self.proj = CastedLinear(dim, dim, bias=False) + self.proj._zero_init = True + self.q_gain = 
nn.Parameter(torch.full((num_heads,), qk_gain_init, dtype=torch.float32)) + self.rotary = Rotary(self.head_dim, base=rope_base) + + def forward(self, x: Tensor) -> Tensor: + bsz, seqlen, dim = x.shape + q = self.c_q(x).reshape(bsz, seqlen, self.num_heads, self.head_dim).transpose(1, 2) + k = self.c_k(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) + v = self.c_v(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) + q = F.rms_norm(q, (q.size(-1),)) + k = F.rms_norm(k, (k.size(-1),)) + cos, sin = self.rotary(seqlen, x.device, q.dtype) + q = apply_rotary_emb(q, cos, sin) + k = apply_rotary_emb(k, cos, sin) + q = q * self.q_gain.to(dtype=q.dtype)[None, :, None, None] + if FA3_AVAILABLE: + # FA3 expects (B, T, H, D) — transpose from (B, H, T, D) + y = fa3.flash_attn_func( + q.transpose(1, 2), + k.transpose(1, 2), + v.transpose(1, 2), + causal=True, + ).transpose(1, 2) + else: + y = F.scaled_dot_product_attention( + q, k, v, + attn_mask=None, + is_causal=True, + enable_gqa=(self.num_kv_heads != self.num_heads), + ) + y = y.transpose(1, 2).contiguous().reshape(bsz, seqlen, dim) + return self.proj(y) + + +class MLP(nn.Module): + # leaky_relu^2 MLP (slope=0.5 matches SOTA, PRs #493/#518) + def __init__(self, dim: int, mlp_mult: int, leaky_relu_slope: float = 0.5): + super().__init__() + hidden = mlp_mult * dim + self.fc = CastedLinear(dim, hidden, bias=False) + self.proj = CastedLinear(hidden, dim, bias=False) + self.proj._zero_init = True + self.leaky_relu_slope = leaky_relu_slope + + def forward(self, x: Tensor) -> Tensor: + x = F.leaky_relu(self.fc(x), negative_slope=self.leaky_relu_slope) + return self.proj(x.square()) + + +class TimestepScaling(nn.Module): + """Learned per-iteration scale vectors (arXiv:2410.01405, ICML 2025).""" + def __init__(self, effective_layers: int, dim: int): + super().__init__() + self.attn_gamma = nn.Parameter(torch.ones(effective_layers, dim, dtype=torch.float32)) + self.mlp_gamma = 
nn.Parameter(torch.ones(effective_layers, dim, dtype=torch.float32)) + + +class Block(nn.Module): + def __init__( + self, + dim: int, + num_heads: int, + num_kv_heads: int, + mlp_mult: int, + rope_base: float, + qk_gain_init: float, + use_peri_norm: bool = False, + use_birkhoff_mix: bool = False, + leaky_relu_slope: float = 0.5, + ): + super().__init__() + self.use_peri_norm = use_peri_norm + self.use_birkhoff_mix = use_birkhoff_mix + self.attn_norm = RMSNorm() + if use_peri_norm: + self.mlp_out_norm = RMSNorm() # output norm: bounds MLP contribution, preserves input magnitude + else: + self.mlp_norm = RMSNorm() + self.attn = CausalSelfAttention(dim, num_heads, num_kv_heads, rope_base, qk_gain_init) + self.mlp = MLP(dim, mlp_mult, leaky_relu_slope) + self.attn_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.mlp_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + if use_birkhoff_mix: + # Birkhoff B₂ parameterization: sigmoid(logit) ∈ (0,1) → doubly stochastic + self.resid_mix_logit = nn.Parameter(torch.zeros(dim, dtype=torch.float32)) + else: + self.resid_mix = nn.Parameter(torch.stack((torch.ones(dim), torch.zeros(dim))).float()) + + def forward(self, x: Tensor, x0: Tensor, + ts_attn_gamma: Tensor | None = None, + ts_mlp_gamma: Tensor | None = None) -> Tensor: + if self.use_birkhoff_mix: + alpha = torch.sigmoid(self.resid_mix_logit.to(dtype=x.dtype))[None, None, :] + x = alpha * x + (1 - alpha) * x0 + else: + mix = self.resid_mix.to(dtype=x.dtype) + x = mix[0][None, None, :] * x + mix[1][None, None, :] * x0 + attn_out = self.attn(self.attn_norm(x)) + attn_s = self.attn_scale.to(dtype=x.dtype)[None, None, :] + if ts_attn_gamma is not None: + attn_s = attn_s * ts_attn_gamma[None, None, :] + x = x + attn_s * attn_out + mlp_out = self.mlp_out_norm(self.mlp(x)) if self.use_peri_norm else self.mlp(self.mlp_norm(x)) + mlp_s = self.mlp_scale.to(dtype=x.dtype)[None, None, :] + if ts_mlp_gamma is not None: + mlp_s = mlp_s * ts_mlp_gamma[None, None, 
:] + x = x + mlp_s * mlp_out + return x + + +class GPT(nn.Module): + def __init__( + self, + vocab_size: int, + num_layers: int, + model_dim: int, + num_heads: int, + num_kv_heads: int, + mlp_mult: int, + tie_embeddings: bool, + tied_embed_init_std: float, + logit_softcap: float, + rope_base: float, + qk_gain_init: float, + num_prelude: int = 0, + num_coda: int = 0, + num_shared: int = 0, + num_loops: int = 2, + use_peri_norm: bool = False, + use_birkhoff_mix: bool = False, + use_timestep_scale: bool = False, + leaky_relu_slope: float = 0.5, + ): + super().__init__() + if logit_softcap <= 0.0: + raise ValueError(f"logit_softcap must be positive, got {logit_softcap}") + self.tie_embeddings = tie_embeddings + self.tied_embed_init_std = tied_embed_init_std + self.logit_softcap = logit_softcap + self.tok_emb = nn.Embedding(vocab_size, model_dim) + self.use_recurrence = num_shared > 0 + self.num_loops = num_loops if self.use_recurrence else 1 + self.num_prelude = num_prelude if self.use_recurrence else 0 + self.num_coda = num_coda if self.use_recurrence else 0 + + block_kwargs = dict( + dim=model_dim, num_heads=num_heads, num_kv_heads=num_kv_heads, + mlp_mult=mlp_mult, rope_base=rope_base, qk_gain_init=qk_gain_init, + use_peri_norm=use_peri_norm if self.use_recurrence else False, + use_birkhoff_mix=use_birkhoff_mix if self.use_recurrence else False, + leaky_relu_slope=leaky_relu_slope, + ) + + if self.use_recurrence: + self.prelude_blocks = nn.ModuleList([Block(**block_kwargs) for _ in range(self.num_prelude)]) + self.shared_blocks = nn.ModuleList([Block(**block_kwargs) for _ in range(num_shared)]) + self.coda_blocks = nn.ModuleList([Block(**block_kwargs) for _ in range(self.num_coda)]) + effective_layers = self.num_prelude + num_shared * self.num_loops + self.num_coda + self.timestep_scale = TimestepScaling(effective_layers, model_dim) if use_timestep_scale else None + else: + # Standard U-Net path + self.num_encoder_layers = num_layers // 2 + self.num_decoder_layers = 
num_layers - self.num_encoder_layers + self.num_skip_weights = min(self.num_encoder_layers, self.num_decoder_layers) + self.skip_weights = nn.Parameter(torch.ones(self.num_skip_weights, model_dim, dtype=torch.float32)) + self.blocks = nn.ModuleList([Block(**block_kwargs) for _ in range(num_layers)]) + self.timestep_scale = None + + self.final_norm = RMSNorm() + self.lm_head = None if tie_embeddings else CastedLinear(model_dim, vocab_size, bias=False) + if self.lm_head is not None: + self.lm_head._zero_init = True + self._init_weights() + + def _init_weights(self) -> None: + if self.tie_embeddings: + nn.init.normal_(self.tok_emb.weight, mean=0.0, std=self.tied_embed_init_std) + for module in self.modules(): + if isinstance(module, nn.Linear) and getattr(module, "_zero_init", False): + nn.init.zeros_(module.weight) + + def _get_ts(self, v: int) -> tuple[Tensor | None, Tensor | None]: + if self.timestep_scale is None: + return None, None + return (self.timestep_scale.attn_gamma[v], + self.timestep_scale.mlp_gamma[v]) + + def forward(self, input_ids: Tensor, target_ids: Tensor) -> Tensor: + x = self.tok_emb(input_ids) + x = F.rms_norm(x, (x.size(-1),)) + x0 = x + + if self.use_recurrence: + v = 0 + for block in self.prelude_blocks: + ag, mg = self._get_ts(v) + x = block(x, x0, ag, mg) + v += 1 + for _loop in range(self.num_loops): + for block in self.shared_blocks: + ag, mg = self._get_ts(v) + x = block(x, x0, ag, mg) + v += 1 + for block in self.coda_blocks: + ag, mg = self._get_ts(v) + x = block(x, x0, ag, mg) + v += 1 + else: + skips: list[Tensor] = [] + for i in range(self.num_encoder_layers): + x = self.blocks[i](x, x0) + skips.append(x) + for i in range(self.num_decoder_layers): + if skips: + x = x + self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips.pop() + x = self.blocks[self.num_encoder_layers + i](x, x0) + + x = self.final_norm(x).reshape(-1, x.size(-1)) + targets = target_ids.reshape(-1) + if self.tie_embeddings: + logits_proj = F.linear(x, 
self.tok_emb.weight) + else: + if self.lm_head is None: + raise RuntimeError("lm_head is required when tie_embeddings=False") + logits_proj = self.lm_head(x) + logits = self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap) + return F.cross_entropy(logits.float(), targets, reduction="mean") + + +@torch.no_grad() +def recurrence_param_diagnostics(gpt: GPT) -> str: + """Log parameter norms for recurrence health monitoring. No forward pass.""" + if not gpt.use_recurrence: + return "" + parts: list[str] = [] + + # Birkhoff alpha stats per shared block + for i, block in enumerate(gpt.shared_blocks): + if hasattr(block, "resid_mix_logit"): + a = torch.sigmoid(block.resid_mix_logit) + parts.append(f"shared{i}_alpha:mean={a.mean().item():.3f},std={a.std().item():.3f}") + + # Effective MLP/attn contribution scale per virtual layer position: + # mlp_scale * ts_mlp_gamma[v] is the learned multiplier on MLP output. + # We report RMS (norm / sqrt(numel)) to make it scale-independent. + v = 0 + all_blocks = ( + list(gpt.prelude_blocks) + + list(gpt.shared_blocks) * gpt.num_loops + + list(gpt.coda_blocks) + ) + mlp_norms: list[str] = [] + attn_norms: list[str] = [] + for block in all_blocks: + ms = block.mlp_scale.norm().item() + asc = block.attn_scale.norm().item() + d = block.mlp_scale.numel() ** 0.5 + if gpt.timestep_scale is not None: + ms_g = gpt.timestep_scale.mlp_gamma[v].norm().item() + as_g = gpt.timestep_scale.attn_gamma[v].norm().item() + mlp_norms.append(f"v{v}:{ms * ms_g / d:.4f}") + attn_norms.append(f"v{v}:{asc * as_g / d:.4f}") + else: + mlp_norms.append(f"v{v}:{ms / d:.4f}") + attn_norms.append(f"v{v}:{asc / d:.4f}") + v += 1 + parts.append("eff_mlp_scale:[" + " ".join(mlp_norms) + "]") + parts.append("eff_attn_scale:[" + " ".join(attn_norms) + "]") + return " ".join(parts) + + +# ----------------------------- +# TRAINING +# ----------------------------- + +def main() -> None: + global zeropower_via_newtonschulz5 + + code = 
Path(__file__).read_text(encoding="utf-8") + args = Hyperparameters() + zeropower_via_newtonschulz5 = torch.compile(zeropower_via_newtonschulz5) + + # ----------------------------- + # DISTRIBUTED + CUDA SETUP + # ----------------------------- + + distributed = "RANK" in os.environ and "WORLD_SIZE" in os.environ + rank = int(os.environ.get("RANK", "0")) + world_size = int(os.environ.get("WORLD_SIZE", "1")) + local_rank = int(os.environ.get("LOCAL_RANK", "0")) + if world_size <= 0: + raise ValueError(f"WORLD_SIZE must be positive, got {world_size}") + if 8 % world_size != 0: + raise ValueError(f"WORLD_SIZE={world_size} must divide 8 so grad_accum_steps stays integral") + grad_accum_steps = 8 // world_size + grad_scale = 1.0 / grad_accum_steps + if not torch.cuda.is_available(): + raise RuntimeError("CUDA is required") + device = torch.device("cuda", local_rank) + torch.cuda.set_device(device) + if distributed: + dist.init_process_group(backend="nccl", device_id=device) + dist.barrier() + master_process = rank == 0 + + # Fast math knobs + torch.backends.cuda.matmul.allow_tf32 = True + torch.backends.cudnn.allow_tf32 = True + from torch.backends.cuda import enable_cudnn_sdp, enable_flash_sdp, enable_math_sdp, enable_mem_efficient_sdp + + enable_cudnn_sdp(False) + enable_flash_sdp(True) + enable_mem_efficient_sdp(False) + enable_math_sdp(False) + + logfile = None + if master_process: + os.makedirs("logs", exist_ok=True) + logfile = f"logs/{args.run_id}.txt" + print(logfile) + + def log0(msg: str, console: bool = True) -> None: + if not master_process: + return + if console: + print(msg) + if logfile is not None: + with open(logfile, "a", encoding="utf-8") as f: + print(msg, file=f) + + log0(code, console=False) + log0("=" * 100, console=False) + log0(f"Running Python {sys.version}", console=False) + log0(f"Running PyTorch {torch.__version__}", console=False) + log0( + subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, 
check=False).stdout, + console=False, + ) + log0("=" * 100, console=False) + + # ----------------------------- + # TOKENIZER + VALIDATION METRIC SETUP + # ----------------------------- + + random.seed(args.seed) + np.random.seed(args.seed) + torch.manual_seed(args.seed) + torch.cuda.manual_seed_all(args.seed) + + if not args.tokenizer_path.endswith(".model"): + raise ValueError(f"Script only setup for SentencePiece .model file: {args.tokenizer_path}") + sp = spm.SentencePieceProcessor(model_file=args.tokenizer_path) + if int(sp.vocab_size()) != args.vocab_size: + raise ValueError( + f"VOCAB_SIZE={args.vocab_size} does not match tokenizer vocab_size={int(sp.vocab_size())}" + ) + dataset_dir = Path(args.data_path).resolve() + actual_train_files = len(list(dataset_dir.glob("fineweb_train_*.bin"))) + val_tokens = load_validation_tokens(args.val_files, args.train_seq_len) + base_bytes_lut, has_leading_space_lut, is_boundary_token_lut = build_sentencepiece_luts( + sp, args.vocab_size, device + ) + log0(f"val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path={args.tokenizer_path}") + log0(f"train_loader:dataset:{dataset_dir.name} train_shards:{actual_train_files}") + log0(f"val_loader:shards pattern={args.val_files} tokens:{val_tokens.numel() - 1}") + + # ----------------------------- + # MODEL + OPTIMIZER SETUP + # ----------------------------- + + base_model = GPT( + vocab_size=args.vocab_size, + num_layers=args.num_layers, + model_dim=args.model_dim, + num_heads=args.num_heads, + num_kv_heads=args.num_kv_heads, + mlp_mult=args.mlp_mult, + tie_embeddings=args.tie_embeddings, + tied_embed_init_std=args.tied_embed_init_std, + logit_softcap=args.logit_softcap, + rope_base=args.rope_base, + qk_gain_init=args.qk_gain_init, + num_prelude=args.num_prelude, + num_coda=args.num_coda, + num_shared=args.num_shared, + num_loops=args.num_loops, + use_peri_norm=args.use_peri_norm, + use_birkhoff_mix=args.use_birkhoff_mix, + use_timestep_scale=args.use_timestep_scale, + 
leaky_relu_slope=args.leaky_relu_slope, + ).to(device).bfloat16() + for module in base_model.modules(): + if isinstance(module, CastedLinear): + module.float() + restore_low_dim_params_to_fp32(base_model) + if args.disable_compile: + compiled_model = base_model + else: + compiled_model = torch.compile(base_model, dynamic=False, fullgraph=True) + model: nn.Module = DDP(compiled_model, device_ids=[local_rank], broadcast_buffers=False) if distributed else compiled_model + + # Optimizer split: + # - token embedding (Adam) uses EMBED_LR + # - untied lm_head (Adam) uses HEAD_LR + # - matrix params in transformer blocks use MATRIX_LR via Muon + # - vectors/scalars use SCALAR_LR via Adam + if base_model.use_recurrence: + all_block_lists = [base_model.prelude_blocks, base_model.shared_blocks, base_model.coda_blocks] + block_named_params = [] + for bl in all_block_lists: + block_named_params.extend(bl.named_parameters()) + else: + block_named_params = list(base_model.blocks.named_parameters()) + matrix_params = [ + p + for name, p in block_named_params + if p.ndim == 2 and not any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) + ] + scalar_params = [ + p + for name, p in block_named_params + if p.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) + ] + if hasattr(base_model, "skip_weights") and base_model.skip_weights.numel() > 0: + scalar_params.append(base_model.skip_weights) + if base_model.timestep_scale is not None: + for p in base_model.timestep_scale.parameters(): + scalar_params.append(p) + token_lr = args.tied_embed_lr if args.tie_embeddings else args.embed_lr + optimizer_tok = torch.optim.Adam( + [{"params": [base_model.tok_emb.weight], "lr": token_lr, "base_lr": token_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=True, + ) + optimizer_muon = Muon( + matrix_params, + lr=args.matrix_lr, + momentum=args.muon_momentum, + backend_steps=args.muon_backend_steps, + ) + for group in optimizer_muon.param_groups: 
+ group["base_lr"] = args.matrix_lr + optimizer_scalar = torch.optim.Adam( + [{"params": scalar_params, "lr": args.scalar_lr, "base_lr": args.scalar_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=True, + ) + optimizers: list[torch.optim.Optimizer] = [optimizer_tok, optimizer_muon, optimizer_scalar] + if base_model.lm_head is not None: + optimizer_head = torch.optim.Adam( + [{"params": [base_model.lm_head.weight], "lr": args.head_lr, "base_lr": args.head_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=True, + ) + optimizers.insert(1, optimizer_head) + + n_params = sum(p.numel() for p in base_model.parameters()) + log0(f"model_params:{n_params}") + log0(f"world_size:{world_size} grad_accum_steps:{grad_accum_steps}") + log0("sdp_backends:cudnn=False flash=True mem_efficient=False math=False") + log0(f"flash_attention_3:{FA3_AVAILABLE}") + log0(f"attention_mode:gqa num_heads:{args.num_heads} num_kv_heads:{args.num_kv_heads}") + log0( + f"tie_embeddings:{args.tie_embeddings} embed_lr:{token_lr} " + f"head_lr:{args.head_lr if base_model.lm_head is not None else 0.0} " + f"matrix_lr:{args.matrix_lr} scalar_lr:{args.scalar_lr}" + ) + log0( + f"train_batch_tokens:{args.train_batch_tokens} train_seq_len:{args.train_seq_len} " + f"iterations:{args.iterations} warmup_steps:{args.warmup_steps} " + f"max_wallclock_seconds:{args.max_wallclock_seconds:.3f}" + ) + log0(f"seed:{args.seed}") + if base_model.use_recurrence: + eff = base_model.num_prelude + len(base_model.shared_blocks) * base_model.num_loops + base_model.num_coda + log0(f"recurrence:enabled prelude:{base_model.num_prelude} shared:{len(base_model.shared_blocks)} " + f"loops:{base_model.num_loops} coda:{base_model.num_coda} effective_layers:{eff}") + log0(f"peri_norm:{args.use_peri_norm} birkhoff_mix:{args.use_birkhoff_mix}") + if base_model.timestep_scale is not None: + ts_params = sum(p.numel() for p in base_model.timestep_scale.parameters()) + log0(f"timestep_scale:enabled 
params:{ts_params}") + else: + log0("timestep_scale:disabled") + else: + log0(f"recurrence:disabled num_layers:{args.num_layers}") + compile_mode = "disabled" if args.disable_compile else "fullgraph=True" + log0(f"compile_mode:{compile_mode}") + + # ----------------------------- + # DATA LOADER & MODEL WARMUP + # ----------------------------- + + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + + def zero_grad_all() -> None: + for opt in optimizers: + opt.zero_grad(set_to_none=True) + + max_wallclock_ms = 1000.0 * args.max_wallclock_seconds if args.max_wallclock_seconds > 0 else None + + def lr_mul(step: int, elapsed_ms: float) -> float: + if args.warmdown_iters <= 0: + return 1.0 + if max_wallclock_ms is None: + warmdown_start = max(args.iterations - args.warmdown_iters, 0) + return max((args.iterations - step) / max(args.warmdown_iters, 1), 0.0) if warmdown_start <= step < args.iterations else 1.0 + step_ms = elapsed_ms / max(step, 1) + warmdown_ms = args.warmdown_iters * step_ms + remaining_ms = max(max_wallclock_ms - elapsed_ms, 0.0) + return remaining_ms / max(warmdown_ms, 1e-9) if remaining_ms <= warmdown_ms else 1.0 + + # Warmup primes the compiled forward/backward/optimizer paths, then we restore the + # initial weights/optimizer state so measured training starts from the true init. 
+ if args.warmup_steps > 0: + initial_model_state = {name: tensor.detach().cpu().clone() for name, tensor in base_model.state_dict().items()} + initial_optimizer_states = [copy.deepcopy(opt.state_dict()) for opt in optimizers] + model.train() + for warmup_step in range(args.warmup_steps): + zero_grad_all() + for micro_step in range(grad_accum_steps): + if distributed: + model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 + x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + warmup_loss = model(x, y) + (warmup_loss * grad_scale).backward() + for opt in optimizers: + opt.step() + zero_grad_all() + if args.warmup_steps <= 20 or (warmup_step + 1) % 10 == 0 or warmup_step + 1 == args.warmup_steps: + log0(f"warmup_step:{warmup_step + 1}/{args.warmup_steps}") + base_model.load_state_dict(initial_model_state, strict=True) + for opt, state in zip(optimizers, initial_optimizer_states, strict=True): + opt.load_state_dict(state) + zero_grad_all() + if distributed: + model.require_backward_grad_sync = True + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + + # ----------------------------- + # MAIN TRAINING LOOP + # ----------------------------- + + training_time_ms = 0.0 + stop_after_step: int | None = None + torch.cuda.synchronize() + t0 = time.perf_counter() + + step = 0 + while True: + last_step = step == args.iterations or (stop_after_step is not None and step >= stop_after_step) + + should_validate = last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0) + if should_validate: + torch.cuda.synchronize() + training_time_ms += 1000.0 * (time.perf_counter() - t0) + val_loss, val_bpb = eval_val( + args, + model, + rank, + world_size, + device, + grad_accum_steps, + val_tokens, + base_bytes_lut, + has_leading_space_lut, + is_boundary_token_lut, + ) + log0( + 
f"step:{step}/{args.iterations} val_loss:{val_loss:.4f} val_bpb:{val_bpb:.4f} " + f"train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms / max(step, 1):.2f}ms" + ) + torch.cuda.synchronize() + t0 = time.perf_counter() + + if last_step: + if stop_after_step is not None and step < args.iterations: + log0( + f"stopping_early: wallclock_cap train_time:{training_time_ms:.0f}ms " + f"step:{step}/{args.iterations}" + ) + break + + elapsed_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + scale = lr_mul(step, elapsed_ms) + zero_grad_all() + train_loss = torch.zeros((), device=device) + for micro_step in range(grad_accum_steps): + if distributed: + model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 + x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + loss = model(x, y) + train_loss += loss.detach() + (loss * grad_scale).backward() + train_loss /= grad_accum_steps + + frac = min(step / args.muon_momentum_warmup_steps, 1.0) if args.muon_momentum_warmup_steps > 0 else 1.0 + muon_momentum = (1 - frac) * args.muon_momentum_warmup_start + frac * args.muon_momentum + for group in optimizer_muon.param_groups: + group["momentum"] = muon_momentum + + for opt in optimizers: + for group in opt.param_groups: + group["lr"] = group["base_lr"] * scale + + if args.grad_clip_norm > 0: + torch.nn.utils.clip_grad_norm_(base_model.parameters(), args.grad_clip_norm) + for opt in optimizers: + opt.step() + zero_grad_all() + + step += 1 + approx_training_time_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + should_log_train = ( + args.train_log_every > 0 + and (step <= 10 or step % args.train_log_every == 0 or stop_after_step is not None) + ) + if should_log_train: + log0( + f"step:{step}/{args.iterations} train_loss:{train_loss.item():.4f} " + f"train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms 
/ step:.2f}ms" + ) + if base_model.use_recurrence and step % args.train_log_every == 0: + diag = recurrence_param_diagnostics(base_model) + if diag: + log0(f"step:{step} {diag}") + + # Needed to sync whether we've reached the wallclock cap. + reached_cap = max_wallclock_ms is not None and approx_training_time_ms >= max_wallclock_ms + if distributed and max_wallclock_ms is not None: + reached_cap_tensor = torch.tensor(int(reached_cap), device=device) + dist.all_reduce(reached_cap_tensor, op=dist.ReduceOp.MAX) + reached_cap = bool(reached_cap_tensor.item()) + if stop_after_step is None and reached_cap: + stop_after_step = step + + log0( + f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB " + f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB" + ) + + # ----------------------------- + # SERIALIZATION + ROUNDTRIP VALIDATION + # ----------------------------- + # Save the raw state (useful for debugging/loading in PyTorch directly), then always produce + # the compressed int8+zlib artifact and validate the round-tripped weights. 
+ + if master_process: + torch.save(base_model.state_dict(), "final_model.pt") + model_bytes = os.path.getsize("final_model.pt") + code_bytes = len(code.encode("utf-8")) + log0(f"Serialized model: {model_bytes} bytes") + log0(f"Code size: {code_bytes} bytes") + log0(f"Total submission size: {model_bytes + code_bytes} bytes") + + quant_obj, quant_stats = quantize_state_dict_int8(base_model.state_dict()) + quant_buf = io.BytesIO() + torch.save(quant_obj, quant_buf) + quant_raw = quant_buf.getvalue() + quant_blob = zlib.compress(quant_raw, level=9) + quant_raw_bytes = len(quant_raw) + if master_process: + with open("final_model.int8.ptz", "wb") as f: + f.write(quant_blob) + quant_file_bytes = os.path.getsize("final_model.int8.ptz") + code_bytes = len(code.encode("utf-8")) + ratio = quant_stats["baseline_tensor_bytes"] / max(quant_stats["int8_payload_bytes"], 1) + log0( + f"Serialized model int8+zlib: {quant_file_bytes} bytes " + f"(payload:{quant_stats['int8_payload_bytes']} raw_torch:{quant_raw_bytes} payload_ratio:{ratio:.2f}x)" + ) + log0(f"Total submission size int8+zlib: {quant_file_bytes + code_bytes} bytes") + + if distributed: + dist.barrier() + with open("final_model.int8.ptz", "rb") as f: + quant_blob_disk = f.read() + quant_state = torch.load(io.BytesIO(zlib.decompress(quant_blob_disk)), map_location="cpu") + base_model.load_state_dict(dequantize_state_dict_int8(quant_state), strict=True) + torch.cuda.synchronize() + t_qeval = time.perf_counter() + q_val_loss, q_val_bpb = eval_val( + args, + model, + rank, + world_size, + device, + grad_accum_steps, + val_tokens, + base_bytes_lut, + has_leading_space_lut, + is_boundary_token_lut, + ) + torch.cuda.synchronize() + log0( + f"final_int8_zlib_roundtrip val_loss:{q_val_loss:.4f} val_bpb:{q_val_bpb:.4f} " + f"eval_time:{1000.0 * (time.perf_counter() - t_qeval):.0f}ms" + ) + log0(f"final_int8_zlib_roundtrip_exact val_loss:{q_val_loss:.8f} val_bpb:{q_val_bpb:.8f}") + + if distributed: + dist.destroy_process_group() 
+ + +if __name__ == "__main__": + main() + +==================================================================================================== +Running Python 3.12.3 (main, Nov 6 2025, 13:44:16) [GCC 13.3.0] +Running PyTorch 2.9.1+cu128 +Wed Mar 25 14:08:40 2026 ++-----------------------------------------------------------------------------------------+ +| NVIDIA-SMI 565.57.01 Driver Version: 565.57.01 CUDA Version: 12.7 | +|-----------------------------------------+------------------------+----------------------+ +| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC | +| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. | +| | | MIG M. | +|=========================================+========================+======================| +| 0 NVIDIA H100 80GB HBM3 On | 00000000:BE:00.0 Off | 0 | +| N/A 44C P0 121W / 700W | 1184MiB / 81559MiB | 1% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ + ++-----------------------------------------------------------------------------------------+ +| Processes: | +| GPU GI CI PID Type Process name GPU Memory | +| ID ID Usage | +|=========================================================================================| ++-----------------------------------------------------------------------------------------+ + +==================================================================================================== +val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path=./data/tokenizers/fineweb_1024_bpe.model +train_loader:dataset:fineweb10B_sp1024 train_shards:10 +val_loader:shards pattern=./data/datasets/fineweb10B_sp1024/fineweb_val_*.bin tokens:62021632 +model_params:7878688 +world_size:1 grad_accum_steps:8 +sdp_backends:cudnn=False flash=True mem_efficient=False math=False +flash_attention_3:True +attention_mode:gqa num_heads:8 num_kv_heads:4 +tie_embeddings:True embed_lr:0.05 head_lr:0.0 matrix_lr:0.04 scalar_lr:0.04 
+train_batch_tokens:524288 train_seq_len:1024 iterations:2000 warmup_steps:20 max_wallclock_seconds:0.000 +seed:1337 +recurrence:enabled prelude:0 shared:4 loops:2 coda:0 effective_layers:8 +peri_norm:True birkhoff_mix:True +timestep_scale:enabled params:8192 +compile_mode:fullgraph=True +warmup_step:1/20 +warmup_step:2/20 +warmup_step:3/20 +warmup_step:4/20 +warmup_step:5/20 +warmup_step:6/20 +warmup_step:7/20 +warmup_step:8/20 +warmup_step:9/20 +warmup_step:10/20 +warmup_step:11/20 +warmup_step:12/20 +warmup_step:13/20 +warmup_step:14/20 +warmup_step:15/20 +warmup_step:16/20 +warmup_step:17/20 +warmup_step:18/20 +warmup_step:19/20 +warmup_step:20/20 +step:0/2000 val_loss:6.9365 val_bpb:4.1082 train_time:0ms step_avg:0.02ms +step:1/2000 train_loss:6.9365 train_time:291ms step_avg:291.48ms +step:2/2000 train_loss:9.4113 train_time:649ms step_avg:324.46ms +step:3/2000 train_loss:7.6728 train_time:985ms step_avg:328.42ms +step:4/2000 train_loss:9.1962 train_time:1347ms step_avg:336.85ms +step:5/2000 train_loss:8.7846 train_time:1661ms step_avg:332.25ms +step:6/2000 train_loss:7.8362 train_time:1994ms step_avg:332.37ms +step:7/2000 train_loss:7.0229 train_time:2351ms step_avg:335.89ms +step:8/2000 train_loss:6.2804 train_time:2669ms step_avg:333.61ms +step:9/2000 train_loss:5.6827 train_time:3009ms step_avg:334.36ms +step:10/2000 train_loss:5.3594 train_time:3357ms step_avg:335.65ms +step:200/2000 train_loss:2.7647 train_time:70768ms step_avg:353.84ms +step:200 shared0_alpha:mean=0.451,std=0.059 shared1_alpha:mean=0.467,std=0.043 shared2_alpha:mean=0.511,std=0.043 shared3_alpha:mean=0.567,std=0.042 eff_mlp_scale:[v0:25.1685 v1:29.1578 v2:31.1889 v3:36.7026 v4:29.6679 v5:31.2195 v6:35.2844 v7:55.6458] eff_attn_scale:[v0:11.0568 v1:10.5124 v2:9.0145 v3:8.0936 v4:10.4333 v5:9.2082 v6:9.4401 v7:13.8619] +step:400/2000 train_loss:2.4320 train_time:143194ms step_avg:357.99ms +step:400 shared0_alpha:mean=0.448,std=0.071 shared1_alpha:mean=0.460,std=0.050 
shared2_alpha:mean=0.527,std=0.046 shared3_alpha:mean=0.603,std=0.044 eff_mlp_scale:[v0:33.0219 v1:39.2284 v2:36.9032 v3:41.1890 v4:40.0863 v5:38.0648 v6:39.4308 v7:68.7217] eff_attn_scale:[v0:5.8223 v1:5.6057 v2:5.1954 v3:4.7406 v4:5.4286 v5:4.6981 v6:4.5720 v7:9.0038] +step:500/2000 val_loss:2.5502 val_bpb:1.5104 train_time:178063ms step_avg:356.13ms +step:600/2000 train_loss:2.5637 train_time:214822ms step_avg:358.04ms +step:600 shared0_alpha:mean=0.452,std=0.076 shared1_alpha:mean=0.457,std=0.053 shared2_alpha:mean=0.536,std=0.047 shared3_alpha:mean=0.632,std=0.047 eff_mlp_scale:[v0:38.5629 v1:45.6997 v2:40.7924 v3:45.3334 v4:47.1324 v5:42.3342 v6:42.5507 v7:82.4683] eff_attn_scale:[v0:3.2159 v1:3.3784 v2:3.2837 v3:3.2140 v4:3.0989 v5:2.8378 v6:2.7142 v7:6.2666] +step:800/2000 train_loss:2.4250 train_time:285434ms step_avg:356.79ms +step:800 shared0_alpha:mean=0.456,std=0.078 shared1_alpha:mean=0.454,std=0.054 shared2_alpha:mean=0.537,std=0.048 shared3_alpha:mean=0.650,std=0.050 eff_mlp_scale:[v0:42.6849 v1:50.6140 v2:43.8185 v3:49.3427 v4:52.2748 v5:45.4417 v6:45.6367 v7:94.0352] eff_attn_scale:[v0:2.1032 v1:2.3460 v2:2.3893 v3:2.4899 v4:2.1334 v5:2.0353 v6:1.8992 v7:4.7556] +step:1000/2000 train_loss:2.4386 train_time:356052ms step_avg:356.05ms +step:1000 shared0_alpha:mean=0.457,std=0.080 shared1_alpha:mean=0.452,std=0.055 shared2_alpha:mean=0.538,std=0.049 shared3_alpha:mean=0.662,std=0.051 eff_mlp_scale:[v0:46.1172 v1:54.6434 v2:46.5893 v3:53.5488 v4:56.6695 v5:48.3384 v6:48.2733 v7:104.3655] eff_attn_scale:[v0:1.6404 v1:1.9416 v2:2.0487 v3:2.1097 v4:1.7533 v5:1.7041 v6:1.6097 v7:4.1180] +step:1000/2000 val_loss:2.4105 val_bpb:1.4276 train_time:356062ms step_avg:356.06ms +step:1200/2000 train_loss:2.3643 train_time:426775ms step_avg:355.65ms +step:1200 shared0_alpha:mean=0.457,std=0.081 shared1_alpha:mean=0.451,std=0.054 shared2_alpha:mean=0.537,std=0.050 shared3_alpha:mean=0.668,std=0.053 eff_mlp_scale:[v0:48.5537 v1:57.1165 v2:48.9482 v3:57.1080 
v4:59.7892 v5:50.4660 v6:50.6758 v7:112.5027] eff_attn_scale:[v0:1.4341 v1:1.7408 v2:1.8800 v3:1.9912 v4:1.5899 v5:1.5747 v6:1.4733 v7:3.9235] +step:1400/2000 train_loss:2.3950 train_time:494246ms step_avg:353.03ms +step:1400 shared0_alpha:mean=0.456,std=0.083 shared1_alpha:mean=0.450,std=0.054 shared2_alpha:mean=0.537,std=0.050 shared3_alpha:mean=0.673,std=0.054 eff_mlp_scale:[v0:50.1999 v1:59.6011 v2:50.9493 v3:59.7882 v4:62.2889 v5:52.8010 v6:52.9088 v7:118.3924] eff_attn_scale:[v0:1.3356 v1:1.6715 v2:1.8183 v3:1.9511 v4:1.5178 v5:1.5213 v6:1.4362 v7:3.9022] +step:1500/2000 val_loss:2.3366 val_bpb:1.3839 train_time:528947ms step_avg:352.63ms +step:1600/2000 train_loss:2.2818 train_time:564663ms step_avg:352.91ms +step:1600 shared0_alpha:mean=0.456,std=0.082 shared1_alpha:mean=0.450,std=0.054 shared2_alpha:mean=0.536,std=0.050 shared3_alpha:mean=0.675,std=0.054 eff_mlp_scale:[v0:51.5643 v1:60.8835 v2:52.0868 v3:61.9811 v4:64.0395 v5:53.9833 v6:54.0749 v7:122.1392] eff_attn_scale:[v0:1.2953 v1:1.6306 v2:1.7994 v3:1.9374 v4:1.5001 v5:1.5012 v6:1.4251 v7:3.9731] +step:1800/2000 train_loss:2.3240 train_time:639169ms step_avg:355.09ms +step:1800 shared0_alpha:mean=0.455,std=0.083 shared1_alpha:mean=0.451,std=0.054 shared2_alpha:mean=0.536,std=0.050 shared3_alpha:mean=0.677,std=0.054 eff_mlp_scale:[v0:52.1812 v1:61.6255 v2:53.2409 v3:63.4873 v4:64.8057 v5:54.6413 v6:55.2576 v7:125.1164] eff_attn_scale:[v0:1.2688 v1:1.6148 v2:1.7941 v3:1.9257 v4:1.4843 v5:1.4921 v6:1.4379 v7:4.0094] +step:2000/2000 train_loss:2.3373 train_time:713473ms step_avg:356.74ms +step:2000 shared0_alpha:mean=0.455,std=0.083 shared1_alpha:mean=0.451,std=0.054 shared2_alpha:mean=0.536,std=0.050 shared3_alpha:mean=0.677,std=0.054 eff_mlp_scale:[v0:52.3886 v1:61.9223 v2:53.5504 v3:63.9752 v4:65.0633 v5:54.9045 v6:55.5788 v7:126.0779] eff_attn_scale:[v0:1.2709 v1:1.6105 v2:1.8114 v3:1.9359 v4:1.4867 v5:1.4941 v6:1.4412 v7:4.0514] +step:2000/2000 val_loss:2.2903 val_bpb:1.3565 train_time:713482ms 
step_avg:356.74ms +peak memory allocated: 7944 MiB reserved: 8054 MiB +Serialized model: 30466726 bytes +Code size: 56522 bytes +Total submission size: 30523248 bytes +Serialized model int8+zlib: 6911191 bytes (payload:7932032 raw_torch:7953175 payload_ratio:3.84x) +Total submission size int8+zlib: 6967713 bytes +final_int8_zlib_roundtrip val_loss:2.2936 val_bpb:1.3584 eval_time:8654ms +final_int8_zlib_roundtrip_exact val_loss:2.29360682 val_bpb:1.35840315 diff --git a/records/track_non_record_16mb/2026-03-26_RecurrenceFix_3Loop_Birkhoff_OutputLN_TimestepScale/logs/s1_E.txt b/records/track_non_record_16mb/2026-03-26_RecurrenceFix_3Loop_Birkhoff_OutputLN_TimestepScale/logs/s1_E.txt new file mode 100644 index 0000000000..cc22a70c94 --- /dev/null +++ b/records/track_non_record_16mb/2026-03-26_RecurrenceFix_3Loop_Birkhoff_OutputLN_TimestepScale/logs/s1_E.txt @@ -0,0 +1,1412 @@ +""" +The `train_gpt.py` and `train_gpt_mlx.py` scripts are intended as good launching-off points for new participants, not SOTA configs. We'll accept PRs that tune, improve, or simplify these scripts without significantly increasing complexity, but competitive submissions should stay in the `/records` folder. + +Hard stop: To keep readable for newcomers, let's make sure `train_gpt.py` and `train_gpt_mlx.py` never are longer than 1500 lines. 
+""" + +from __future__ import annotations + +import copy +import glob +import io +import math +import os +import random +import subprocess +import sys +import time +import uuid +import zlib +from pathlib import Path + +import numpy as np +import sentencepiece as spm +import torch +import torch.distributed as dist +import torch.nn.functional as F +from torch import Tensor, nn +from torch.nn.parallel import DistributedDataParallel as DDP + +# Flash Attention 3 via HuggingFace kernels (Hopper-optimized) +try: + from kernels import get_kernel + cap = torch.cuda.get_device_capability() + repo = "varunneal/flash-attention-3" if cap == (9, 0) else "kernels-community/flash-attn3" + fa3 = get_kernel(repo).flash_attn_interface + FA3_AVAILABLE = True +except Exception: + FA3_AVAILABLE = False + +# ----------------------------- +# HYPERPARAMETERS +# ----------------------------- +# Default Simple Baseline run: +# - 9 transformer blocks at width 512 +# - 8 attention heads with 4 KV heads (GQA) and 2x MLP expansion +# - vocab size 1024, sequence length 1024, tied embeddings +# - 524,288 train tokens per step for 20,000 iterations with a ~10 minute cap + +class Hyperparameters: + # Data paths are shard globs produced by the existing preprocessing pipeline. + data_path = os.environ.get("DATA_PATH", "./data/datasets/fineweb10B_sp1024") + train_files = os.path.join(data_path, "fineweb_train_*.bin") + val_files = os.path.join(data_path, "fineweb_val_*.bin") + tokenizer_path = os.environ.get("TOKENIZER_PATH", "./data/tokenizers/fineweb_1024_bpe.model") + run_id = os.environ.get("RUN_ID", str(uuid.uuid4())) + seed = int(os.environ.get("SEED", 1337)) + + # Validation cadence and batch size. Validation always uses the full fineweb_val split. + val_batch_size = int(os.environ.get("VAL_BATCH_SIZE", 524_288)) + val_loss_every = int(os.environ.get("VAL_LOSS_EVERY", 1000)) + train_log_every = int(os.environ.get("TRAIN_LOG_EVERY", 200)) + + # Training length. 
+ iterations = int(os.environ.get("ITERATIONS", 20000)) + warmdown_iters = int(os.environ.get("WARMDOWN_ITERS", 1200)) + warmup_steps = int(os.environ.get("WARMUP_STEPS", 20)) + train_batch_tokens = int(os.environ.get("TRAIN_BATCH_TOKENS", 524_288)) + train_seq_len = int(os.environ.get("TRAIN_SEQ_LEN", 1024)) + max_wallclock_seconds = float(os.environ.get("MAX_WALLCLOCK_SECONDS", 600.0)) + qk_gain_init = float(os.environ.get("QK_GAIN_INIT", 1.5)) + + # Model shape. + vocab_size = int(os.environ.get("VOCAB_SIZE", 1024)) + num_layers = int(os.environ.get("NUM_LAYERS", 9)) + num_kv_heads = int(os.environ.get("NUM_KV_HEADS", 4)) + model_dim = int(os.environ.get("MODEL_DIM", 512)) + num_heads = int(os.environ.get("NUM_HEADS", 8)) + mlp_mult = int(os.environ.get("MLP_MULT", 2)) + tie_embeddings = bool(int(os.environ.get("TIE_EMBEDDINGS", "1"))) + rope_base = float(os.environ.get("ROPE_BASE", 10000.0)) + logit_softcap = float(os.environ.get("LOGIT_SOFTCAP", 30.0)) + + # Recurrence + stability features. + num_prelude = int(os.environ.get("NUM_PRELUDE", 0)) + num_coda = int(os.environ.get("NUM_CODA", 0)) + num_shared = int(os.environ.get("NUM_SHARED", 0)) + num_loops = int(os.environ.get("NUM_LOOPS", 2)) + use_peri_norm = bool(int(os.environ.get("USE_PERI_NORM", "0"))) + use_birkhoff_mix = bool(int(os.environ.get("USE_BIRKHOFF_MIX", "0"))) + use_timestep_scale = bool(int(os.environ.get("USE_TIMESTEP_SCALE", "0"))) + leaky_relu_slope = float(os.environ.get("LEAKY_RELU_SLOPE", "0.5")) + disable_compile = bool(int(os.environ.get("DISABLE_COMPILE", "0"))) + + # Optimizer hyperparameters. 
+ embed_lr = float(os.environ.get("EMBED_LR", 0.6)) + head_lr = float(os.environ.get("HEAD_LR", 0.008)) + tied_embed_lr = float(os.environ.get("TIED_EMBED_LR", 0.05)) + tied_embed_init_std = float(os.environ.get("TIED_EMBED_INIT_STD", 0.005)) + matrix_lr = float(os.environ.get("MATRIX_LR", 0.04)) + scalar_lr = float(os.environ.get("SCALAR_LR", 0.04)) + muon_momentum = float(os.environ.get("MUON_MOMENTUM", 0.95)) + muon_backend_steps = int(os.environ.get("MUON_BACKEND_STEPS", 5)) + muon_momentum_warmup_start = float(os.environ.get("MUON_MOMENTUM_WARMUP_START", 0.85)) + muon_momentum_warmup_steps = int(os.environ.get("MUON_MOMENTUM_WARMUP_STEPS", 500)) + beta1 = float(os.environ.get("BETA1", 0.9)) + beta2 = float(os.environ.get("BETA2", 0.95)) + adam_eps = float(os.environ.get("ADAM_EPS", 1e-8)) + grad_clip_norm = float(os.environ.get("GRAD_CLIP_NORM", 0.0)) + +# ----------------------------- +# MUON OPTIMIZER +# ----------------------------- +# +# As borrowed from modded-nanogpt +# Background on Muon: https://kellerjordan.github.io/posts/muon/ + +def zeropower_via_newtonschulz5(G: Tensor, steps: int = 10, eps: float = 1e-7) -> Tensor: + # Orthogonalize a 2D update matrix with a fast Newton-Schulz iteration. + # Muon uses this to normalize matrix-shaped gradients before applying them. 
+ a, b, c = (3.4445, -4.7750, 2.0315) + X = G.bfloat16() + X /= X.norm() + eps + transposed = G.size(0) > G.size(1) + if transposed: + X = X.T + for _ in range(steps): + A = X @ X.T + B = b * A + c * A @ A + X = a * X + B @ X + return X.T if transposed else X + + +class Muon(torch.optim.Optimizer): + def __init__(self, params, lr: float, momentum: float, backend_steps: int, nesterov: bool = True): + super().__init__( + params, + dict(lr=lr, momentum=momentum, backend_steps=backend_steps, nesterov=nesterov), + ) + + @torch.no_grad() + def step(self, closure=None): + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + distributed = dist.is_available() and dist.is_initialized() + world_size = dist.get_world_size() if distributed else 1 + rank = dist.get_rank() if distributed else 0 + + for group in self.param_groups: + params = group["params"] + if not params: + continue + lr = group["lr"] + momentum = group["momentum"] + backend_steps = group["backend_steps"] + nesterov = group["nesterov"] + + total_params = sum(int(p.numel()) for p in params) + updates_flat = torch.zeros(total_params, device=params[0].device, dtype=torch.bfloat16) + + curr = 0 + for i, p in enumerate(params): + if i % world_size == rank and p.grad is not None: + g = p.grad + state = self.state[p] + if "momentum_buffer" not in state: + state["momentum_buffer"] = torch.zeros_like(g) + buf = state["momentum_buffer"] + buf.mul_(momentum).add_(g) + if nesterov: + g = g.add(buf, alpha=momentum) + g = zeropower_via_newtonschulz5(g, steps=backend_steps) + # Scale correction from Muon reference implementations. 
+ g *= max(1, g.size(0) / g.size(1)) ** 0.5 + updates_flat[curr : curr + p.numel()] = g.reshape(-1) + curr += p.numel() + + if distributed: + dist.all_reduce(updates_flat, op=dist.ReduceOp.SUM) + + curr = 0 + for p in params: + g = updates_flat[curr : curr + p.numel()].view_as(p).to(dtype=p.dtype) + p.add_(g, alpha=-lr) + curr += p.numel() + + return loss + + +# ----------------------------- +# TOKENIZER-AGNOSTIC EVALUATION SETUP +# ----------------------------- +# +# It's common for small models have a large fraction of their parameters be embeddings, since the 2 * d_model * d_vocab vectors can be gigantic. +# Instead of locking the tokenizer, we let you bring your own and calculate our validation metrics on the average compression of the validation set. +# We calculate BPB (bits-per-byte) instead of validation loss, so we need methods to count the number of bits per token in the tokenizer. +# Note: Submissions that edit the tokenizer will be examined more carefully, since screwing this up might unjustly improve your score. 
def build_sentencepiece_luts(
    sp: spm.SentencePieceProcessor, vocab_size: int, device: torch.device
) -> tuple[Tensor, Tensor, Tensor]:
    """Precompute per-token lookup tables for the bits-per-byte metric.

    Returns three tensors on `device`, sized to cover both the tokenizer
    vocab and the model vocab:
      - base_bytes (int16): UTF-8 byte length of each piece (1 for byte tokens)
      - has_leading_space (bool): piece starts with the SentencePiece "▁" marker
      - is_boundary (bool): control/unknown/unused ids (and padding rows)
        that never carry text
    """
    sp_vocab = int(sp.vocab_size())
    n_rows = max(sp_vocab, vocab_size)
    base_bytes = np.zeros((n_rows,), dtype=np.int16)
    has_leading_space = np.zeros((n_rows,), dtype=np.bool_)
    is_boundary = np.ones((n_rows,), dtype=np.bool_)
    for tid in range(sp_vocab):
        if sp.is_control(tid) or sp.is_unknown(tid) or sp.is_unused(tid):
            continue  # stays marked as a boundary with zero bytes
        is_boundary[tid] = False
        if sp.is_byte(tid):
            base_bytes[tid] = 1
            continue
        piece = sp.id_to_piece(tid)
        if piece.startswith("▁"):
            has_leading_space[tid] = True
            piece = piece[1:]
        base_bytes[tid] = len(piece.encode("utf-8"))
    return (
        torch.tensor(base_bytes, dtype=torch.int16, device=device),
        torch.tensor(has_leading_space, dtype=torch.bool, device=device),
        torch.tensor(is_boundary, dtype=torch.bool, device=device),
    )


def load_validation_tokens(pattern: str, seq_len: int) -> Tensor:
    """Concatenate all validation shards matching `pattern` into one stream.

    Trims to a whole number of `seq_len` sequences plus one extra token so the
    caller can build shifted (x, y) pairs.
    """
    paths = sorted(glob.glob(pattern))
    if not paths:
        raise FileNotFoundError(f"No files found for pattern: {pattern}")
    # The export pipeline writes the fixed first-50k-doc validation set to fineweb_val_*.
    tokens = torch.cat([load_data_shard(Path(p)) for p in paths]).contiguous()
    usable = ((tokens.numel() - 1) // seq_len) * seq_len
    if usable <= 0:
        raise ValueError(f"Validation split is too short for TRAIN_SEQ_LEN={seq_len}")
    return tokens[: usable + 1]


def eval_val(
    args: Hyperparameters,
    model: nn.Module,
    rank: int,
    world_size: int,
    device: torch.device,
    grad_accum_steps: int,
    val_tokens: Tensor,
    base_bytes_lut: Tensor,
    has_leading_space_lut: Tensor,
    is_boundary_token_lut: Tensor,
) -> tuple[float, float]:
    """Run full-split validation and return (val_loss, val_bpb).

    val_loss is mean token cross-entropy in nats; val_bpb is the
    tokenizer-agnostic bits-per-byte compression metric used by the challenge.
    """
    per_rank_tokens = args.val_batch_size // (world_size * grad_accum_steps)
    if per_rank_tokens < args.train_seq_len:
        raise ValueError(
            "VAL_BATCH_SIZE must provide at least one sequence per rank; "
            f"got VAL_BATCH_SIZE={args.val_batch_size}, WORLD_SIZE={world_size}, "
            f"GRAD_ACCUM_STEPS={grad_accum_steps}, TRAIN_SEQ_LEN={args.train_seq_len}"
        )
    per_rank_seqs = per_rank_tokens // args.train_seq_len
    total_seqs = (val_tokens.numel() - 1) // args.train_seq_len
    # Disjoint, contiguous sequence span for this rank.
    first_seq = (total_seqs * rank) // world_size
    last_seq = (total_seqs * (rank + 1)) // world_size
    loss_sum = torch.zeros((), device=device, dtype=torch.float64)
    token_count = torch.zeros((), device=device, dtype=torch.float64)
    byte_count = torch.zeros((), device=device, dtype=torch.float64)

    model.eval()
    with torch.inference_mode():
        for lo in range(first_seq, last_seq, per_rank_seqs):
            hi = min(lo + per_rank_seqs, last_seq)
            # The trailing "+1" token lets us form targets by shifting inputs.
            flat = val_tokens[lo * args.train_seq_len : hi * args.train_seq_len + 1].to(
                device=device, dtype=torch.int64, non_blocking=True
            )
            x = flat[:-1].reshape(-1, args.train_seq_len)
            y = flat[1:].reshape(-1, args.train_seq_len)
            with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True):
                batch_loss = model(x, y).detach()
            n_tokens = float(y.numel())
            loss_sum += batch_loss.to(torch.float64) * n_tokens
            token_count += n_tokens
            # Byte accounting: each target token contributes its base UTF-8
            # bytes, plus one byte for a leading space unless the previous
            # token is a boundary token.
            prev_ids = x.reshape(-1)
            tgt_ids = y.reshape(-1)
            tok_bytes = base_bytes_lut[tgt_ids].to(dtype=torch.int16)
            tok_bytes += (has_leading_space_lut[tgt_ids] & ~is_boundary_token_lut[prev_ids]).to(dtype=torch.int16)
            byte_count += tok_bytes.to(torch.float64).sum()

    if dist.is_available() and dist.is_initialized():
        dist.all_reduce(loss_sum, op=dist.ReduceOp.SUM)
        dist.all_reduce(token_count, op=dist.ReduceOp.SUM)
        dist.all_reduce(byte_count, op=dist.ReduceOp.SUM)

    mean_loss = loss_sum / token_count
    bits_per_token = mean_loss.item() / math.log(2.0)
    tokens_per_byte = token_count.item() / byte_count.item()
    model.train()
    return float(mean_loss.item()), float(bits_per_token * tokens_per_byte)

# -----------------------------
# POST-TRAINING QUANTIZATION
# -----------------------------
#
# It's silly to export our model, which is trained in bf16 and fp32, at that same precision.
# Instead, we get approximately the same model (with a small hit) by quantizing the model to int8 & zlib compressing.
# We can then decompress the model and run in higher precision for evaluation, after closing in under the size limit.
# Substring patterns naming small "control" tensors (per-channel scales, mix
# logits, gains). These are kept in fp32 both at runtime (see
# restore_low_dim_params_to_fp32) and in the int8 export below.
CONTROL_TENSOR_NAME_PATTERNS = tuple(
    pattern
    for pattern in os.environ.get(
        "CONTROL_TENSOR_NAME_PATTERNS",
        "attn_scale,attn_scales,mlp_scale,mlp_scales,resid_mix,resid_mixes,q_gain,skip_weight,skip_weights,ts_attn,ts_mlp,resid_mix_logit",
    ).split(",")
    if pattern
)
INT8_KEEP_FLOAT_FP32_NAME_PATTERNS = tuple(
    pattern
    for pattern in os.environ.get(
        "INT8_KEEP_FLOAT_FP32_NAME_PATTERNS",
        ",".join(CONTROL_TENSOR_NAME_PATTERNS),
    ).split(",")
    if pattern
)
INT8_KEEP_FLOAT_MAX_NUMEL = 65_536  # float tensors at or below this size skip quantization
INT8_KEEP_FLOAT_STORE_DTYPE = torch.float16  # storage dtype for small passthrough floats
INT8_PER_ROW_SCALE_DTYPE = torch.float16  # storage dtype for per-row quant scales
INT8_CLIP_PERCENTILE = 99.99984  # clip |x| at this percentile before rounding to int8
INT8_CLIP_Q = INT8_CLIP_PERCENTILE / 100.0

def tensor_nbytes(t: Tensor) -> int:
    """Size of `t`'s payload in bytes (numel * element size)."""
    return int(t.numel()) * int(t.element_size())

def keep_float_tensor(name: str, t: Tensor, passthrough_orig_dtypes: dict[str, str]) -> Tensor:
    """Prepare a small float tensor for exact (non-int8) passthrough storage.

    Control tensors (matched by name pattern) are stored in full fp32; other
    fp32/bf16 tensors are downcast to fp16 for storage, recording the original
    dtype in `passthrough_orig_dtypes` (mutated in place) so loading can
    restore it.
    """
    if any(pattern in name for pattern in INT8_KEEP_FLOAT_FP32_NAME_PATTERNS):
        return t.float().contiguous()
    if t.dtype in {torch.float32, torch.bfloat16}:
        passthrough_orig_dtypes[name] = str(t.dtype).removeprefix("torch.")
        return t.to(dtype=INT8_KEEP_FLOAT_STORE_DTYPE).contiguous()
    return t

def quantize_float_tensor(t: Tensor) -> tuple[Tensor, Tensor]:
    """Quantize a float tensor to int8 with percentile clipping.

    Returns (q, scale): per-row scales of shape (rows,) for 2D tensors, a
    scalar scale tensor otherwise. Dequantization is q.float() * scale.
    """
    t32 = t.float()
    if t32.ndim == 2:
        # Matrices get one scale per row, which usually tracks output-channel
        # ranges much better than a single tensor-wide scale.
        clip_abs = (
            torch.quantile(t32.abs(), INT8_CLIP_Q, dim=1)
            if t32.numel()
            else torch.empty((t32.shape[0],), dtype=torch.float32)
        )
        clipped = torch.maximum(torch.minimum(t32, clip_abs[:, None]), -clip_abs[:, None])
        scale = (clip_abs / 127.0).clamp_min(1.0 / 127.0)
        q = torch.clamp(torch.round(clipped / scale[:, None]), -127, 127).to(torch.int8).contiguous()
        return q, scale.to(dtype=INT8_PER_ROW_SCALE_DTYPE).contiguous()

    # Vectors / scalars use a simpler per-tensor scale.
    clip_abs = float(torch.quantile(t32.abs().flatten(), INT8_CLIP_Q).item()) if t32.numel() else 0.0
    scale = torch.tensor(clip_abs / 127.0 if clip_abs > 0 else 1.0, dtype=torch.float32)
    q = torch.clamp(torch.round(torch.clamp(t32, -clip_abs, clip_abs) / scale), -127, 127).to(torch.int8).contiguous()
    return q, scale

def quantize_state_dict_int8(state_dict: dict[str, Tensor]):
    """Build the compact int8 export object plus byte-accounting stats for `state_dict`."""
    # Single supported clean-script export format:
    # - per-row int8 for 2D float tensors
    # - per-tensor int8 for other float tensors
    # - exact passthrough for non-floats
    # - passthrough for small float tensors, stored as fp16 to save bytes
    quantized: dict[str, Tensor] = {}
    scales: dict[str, Tensor] = {}
    dtypes: dict[str, str] = {}
    passthrough: dict[str, Tensor] = {}
    passthrough_orig_dtypes: dict[str, str] = {}
    qmeta: dict[str, dict[str, object]] = {}
    stats = dict.fromkeys(
        ("param_count", "num_tensors", "num_float_tensors", "num_nonfloat_tensors", "baseline_tensor_bytes", "int8_payload_bytes"),
        0,
    )

    for name, tensor in state_dict.items():
        t = tensor.detach().to("cpu").contiguous()
        stats["param_count"] += int(t.numel())
        stats["num_tensors"] += 1
        stats["baseline_tensor_bytes"] += tensor_nbytes(t)

        # Integer/bool tensors are stored exactly; int8 would change values.
        if not t.is_floating_point():
            stats["num_nonfloat_tensors"] += 1
            passthrough[name] = t
            stats["int8_payload_bytes"] += tensor_nbytes(t)
            continue

        # Small float tensors are cheap enough to keep directly. We still downcast
        # fp32/bf16 passthrough tensors to fp16 so metadata does not dominate size.
        if t.numel() <= INT8_KEEP_FLOAT_MAX_NUMEL:
            kept = keep_float_tensor(name, t, passthrough_orig_dtypes)
            passthrough[name] = kept
            stats["int8_payload_bytes"] += tensor_nbytes(kept)
            continue

        stats["num_float_tensors"] += 1
        q, s = quantize_float_tensor(t)
        # A non-scalar scale means the per-row scheme was used for this tensor.
        if s.ndim > 0:
            qmeta[name] = {"scheme": "per_row", "axis": 0}
        quantized[name] = q
        scales[name] = s
        dtypes[name] = str(t.dtype).removeprefix("torch.")
        stats["int8_payload_bytes"] += tensor_nbytes(q) + tensor_nbytes(s)

    obj: dict[str, object] = {
        "__quant_format__": "int8_clean_per_row_v1",
        "quantized": quantized,
        "scales": scales,
        "dtypes": dtypes,
        "passthrough": passthrough,
    }
    # Optional sections are omitted when empty to keep the export minimal.
    if qmeta:
        obj["qmeta"] = qmeta
    if passthrough_orig_dtypes:
        obj["passthrough_orig_dtypes"] = passthrough_orig_dtypes
    return obj, stats

def dequantize_state_dict_int8(obj: dict[str, object]) -> dict[str, Tensor]:
    """Invert quantize_state_dict_int8, reconstructing a float state dict."""
    out: dict[str, Tensor] = {}
    qmeta = obj.get("qmeta", {})
    passthrough_orig_dtypes = obj.get("passthrough_orig_dtypes", {})
    for name, q in obj["quantized"].items():
        dtype = getattr(torch, obj["dtypes"][name])
        s = obj["scales"][name]
        if qmeta.get(name, {}).get("scheme") == "per_row" or s.ndim > 0:
            s = s.to(dtype=torch.float32)
            # Broadcast the saved row scale back across trailing dimensions.
            out[name] = (q.float() * s.view(q.shape[0], *([1] * (q.ndim - 1)))).to(dtype=dtype).contiguous()
        else:
            scale = float(s.item())
            out[name] = (q.float() * scale).to(dtype=dtype).contiguous()
    for name, t in obj["passthrough"].items():
        # Restore small tensors, undoing the temporary fp16 storage cast if needed.
+ out_t = t.detach().to("cpu").contiguous() + orig_dtype = passthrough_orig_dtypes.get(name) + if isinstance(orig_dtype, str): + out_t = out_t.to(dtype=getattr(torch, orig_dtype)).contiguous() + out[name] = out_t + return out + + +# ----------------------------- +# DATA LOADING +# ----------------------------- + +def load_data_shard(file: Path) -> Tensor: + header_bytes = 256 * np.dtype(" None: + self.file_idx = (self.file_idx + 1) % len(self.files) + self.tokens = load_data_shard(self.files[self.file_idx]) + self.pos = 0 + + def take(self, n: int) -> Tensor: + chunks: list[Tensor] = [] + remaining = n + while remaining > 0: + avail = self.tokens.numel() - self.pos + if avail <= 0: + self._advance_file() + continue + k = min(remaining, avail) + chunks.append(self.tokens[self.pos : self.pos + k]) + self.pos += k + remaining -= k + return chunks[0] if len(chunks) == 1 else torch.cat(chunks) + + +class DistributedTokenLoader: + # Each call consumes a contiguous chunk from the shared token stream, then slices out + # one disjoint span per rank. The extra "+1" token lets us build (x, y) by shifting. 
    def __init__(self, pattern: str, rank: int, world_size: int, device: torch.device):
        """Create a loader over the shard glob `pattern` for this rank."""
        self.rank = rank
        self.world_size = world_size
        self.device = device
        self.stream = TokenStream(pattern)

    def next_batch(self, global_tokens: int, seq_len: int, grad_accum_steps: int) -> tuple[Tensor, Tensor]:
        """Return one (x, y) micro-batch of shape (-1, seq_len) for this rank.

        `global_tokens` is the per-step token budget across all ranks and
        accumulation steps; every rank consumes the same chunk from the shared
        stream and slices out its own disjoint span.
        """
        local_tokens = global_tokens // (self.world_size * grad_accum_steps)
        per_rank_span = local_tokens + 1  # +1 so targets can be inputs shifted by one
        chunk = self.stream.take(per_rank_span * self.world_size)
        start = self.rank * per_rank_span
        local = chunk[start : start + per_rank_span].to(dtype=torch.int64)
        x = local[:-1].reshape(-1, seq_len)
        y = local[1:].reshape(-1, seq_len)
        return x.to(self.device, non_blocking=True), y.to(self.device, non_blocking=True)

# -----------------------------
# TRANSFORMER MODULES
# -----------------------------

class RMSNorm(nn.Module):
    """Parameter-free RMS normalization over the last dimension."""

    def __init__(self, eps: float | None = None):
        super().__init__()
        self.eps = eps  # None defers to F.rms_norm's default epsilon

    def forward(self, x: Tensor) -> Tensor:
        return F.rms_norm(x, (x.size(-1),), eps=self.eps)


class CastedLinear(nn.Linear):
    # Keep weights in fp32 for optimizer/state quality, cast at matmul time for bf16 compute.
    def forward(self, x: Tensor) -> Tensor:
        bias = self.bias.to(x.dtype) if self.bias is not None else None
        return F.linear(x, self.weight.to(x.dtype), bias)


def restore_low_dim_params_to_fp32(module: nn.Module) -> None:
    """Upcast vectors/scalars and name-matched control parameters back to fp32."""
    # Keep small/control parameters in fp32 even when the model body runs in bf16.
    with torch.no_grad():
        for name, param in module.named_parameters():
            if (param.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS)) and param.dtype != torch.float32:
                param.data = param.data.float()


class Rotary(nn.Module):
    # Caches cos/sin tables per sequence length on the current device.
+    def __init__(self, dim: int, base: float = 10000.0):
+        super().__init__()
+        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim))
+        self.register_buffer("inv_freq", inv_freq, persistent=False)
+        self._seq_len_cached = 0
+        self._cos_cached: Tensor | None = None
+        self._sin_cached: Tensor | None = None
+
+    def forward(self, seq_len: int, device: torch.device, dtype: torch.dtype) -> tuple[Tensor, Tensor]:
+        # Rebuild the fp32 cos/sin tables only when seq_len or device changes;
+        # they are converted to the caller's dtype on every return.
+        if (
+            self._cos_cached is None
+            or self._sin_cached is None
+            or self._seq_len_cached != seq_len
+            or self._cos_cached.device != device
+        ):
+            t = torch.arange(seq_len, device=device, dtype=self.inv_freq.dtype)
+            freqs = torch.outer(t, self.inv_freq.to(device))
+            self._cos_cached = freqs.cos()[None, None, :, :]
+            self._sin_cached = freqs.sin()[None, None, :, :]
+            self._seq_len_cached = seq_len
+        return self._cos_cached.to(dtype=dtype), self._sin_cached.to(dtype=dtype)
+
+
+def apply_rotary_emb(x: Tensor, cos: Tensor, sin: Tensor) -> Tensor:
+    """Half-split RoPE: rotate (x[..., :d/2], x[..., d/2:]) by the (cos, sin) tables."""
+    half = x.size(-1) // 2
+    x1, x2 = x[..., :half], x[..., half:]
+    return torch.cat((x1 * cos + x2 * sin, x1 * (-sin) + x2 * cos), dim=-1)
+
+
+class CausalSelfAttention(nn.Module):
+    """Causal attention with GQA (num_kv_heads <= num_heads), QK RMS-norm, RoPE,
+    and a learned per-head query gain (initialized to qk_gain_init)."""
+    def __init__(
+        self,
+        dim: int,
+        num_heads: int,
+        num_kv_heads: int,
+        rope_base: float,
+        qk_gain_init: float,
+    ):
+        super().__init__()
+        if dim % num_heads != 0:
+            raise ValueError("model_dim must be divisible by num_heads")
+        if num_heads % num_kv_heads != 0:
+            raise ValueError("num_heads must be divisible by num_kv_heads")
+        self.num_heads = num_heads
+        self.num_kv_heads = num_kv_heads
+        self.head_dim = dim // num_heads
+        if self.head_dim % 2 != 0:
+            raise ValueError("head_dim must be even for RoPE")
+        kv_dim = self.num_kv_heads * self.head_dim
+        self.c_q = CastedLinear(dim, dim, bias=False)
+        self.c_k = CastedLinear(dim, kv_dim, bias=False)
+        self.c_v = CastedLinear(dim, kv_dim, bias=False)
+        self.proj = CastedLinear(dim, dim, bias=False)
+        # _zero_init marks the output projection for zero initialization (see GPT._init_weights).
+        self.proj._zero_init = True
+        self.q_gain = nn.Parameter(torch.full((num_heads,), qk_gain_init, dtype=torch.float32))
+        self.rotary = Rotary(self.head_dim, base=rope_base)
+
+    def forward(self, x: Tensor) -> Tensor:
+        bsz, seqlen, dim = x.shape
+        q = self.c_q(x).reshape(bsz, seqlen, self.num_heads, self.head_dim).transpose(1, 2)
+        k = self.c_k(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2)
+        v = self.c_v(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2)
+        q = F.rms_norm(q, (q.size(-1),))
+        k = F.rms_norm(k, (k.size(-1),))
+        cos, sin = self.rotary(seqlen, x.device, q.dtype)
+        q = apply_rotary_emb(q, cos, sin)
+        k = apply_rotary_emb(k, cos, sin)
+        # Per-head gain applied after QK-norm/RoPE: multiplying q scales each head's
+        # attention logits before softmax.
+        q = q * self.q_gain.to(dtype=q.dtype)[None, :, None, None]
+        if FA3_AVAILABLE:
+            # FA3 expects (B, T, H, D) — transpose from (B, H, T, D)
+            y = fa3.flash_attn_func(
+                q.transpose(1, 2),
+                k.transpose(1, 2),
+                v.transpose(1, 2),
+                causal=True,
+            ).transpose(1, 2)
+        else:
+            y = F.scaled_dot_product_attention(
+                q, k, v,
+                attn_mask=None,
+                is_causal=True,
+                enable_gqa=(self.num_kv_heads != self.num_heads),
+            )
+        y = y.transpose(1, 2).contiguous().reshape(bsz, seqlen, dim)
+        return self.proj(y)
+
+
+class MLP(nn.Module):
+    # leaky_relu^2 MLP (slope=0.5 matches SOTA, PRs #493/#518)
+    def __init__(self, dim: int, mlp_mult: int, leaky_relu_slope: float = 0.5):
+        super().__init__()
+        hidden = mlp_mult * dim
+        self.fc = CastedLinear(dim, hidden, bias=False)
+        self.proj = CastedLinear(hidden, dim, bias=False)
+        self.proj._zero_init = True
+        self.leaky_relu_slope = leaky_relu_slope
+
+    def forward(self, x: Tensor) -> Tensor:
+        # proj(leaky_relu(fc(x)) ** 2)
+        x = F.leaky_relu(self.fc(x), negative_slope=self.leaky_relu_slope)
+        return self.proj(x.square())
+
+
+class TimestepScaling(nn.Module):
+    """Learned per-iteration scale vectors (arXiv:2410.01405, ICML 2025)."""
+    # Rows are indexed by "virtual layer" position v = prelude, then the shared
+    # blocks unrolled num_loops times, then coda (see GPT.forward / GPT._get_ts).
+    def __init__(self, effective_layers: int, dim: int):
+        super().__init__()
+        self.attn_gamma = nn.Parameter(torch.ones(effective_layers, dim, dtype=torch.float32))
+        self.mlp_gamma = nn.Parameter(torch.ones(effective_layers, dim, dtype=torch.float32))
+
+
+class Block(nn.Module):
+    """Pre-norm transformer block whose residual stream is re-mixed each layer with
+    the normalized embedding stream x0 (Birkhoff-convex or free 2-vector mix)."""
+    def __init__(
+        self,
+        dim: int,
+        num_heads: int,
+        num_kv_heads: int,
+        mlp_mult: int,
+        rope_base: float,
+        qk_gain_init: float,
+        use_peri_norm: bool = False,
+        use_birkhoff_mix: bool = False,
+        leaky_relu_slope: float = 0.5,
+    ):
+        super().__init__()
+        self.use_peri_norm = use_peri_norm
+        self.use_birkhoff_mix = use_birkhoff_mix
+        self.attn_norm = RMSNorm()
+        if use_peri_norm:
+            self.mlp_out_norm = RMSNorm()  # output norm: bounds MLP contribution, preserves input magnitude
+        else:
+            self.mlp_norm = RMSNorm()
+        self.attn = CausalSelfAttention(dim, num_heads, num_kv_heads, rope_base, qk_gain_init)
+        self.mlp = MLP(dim, mlp_mult, leaky_relu_slope)
+        self.attn_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32))
+        self.mlp_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32))
+        if use_birkhoff_mix:
+            # Birkhoff B₂ parameterization: sigmoid(logit) ∈ (0,1) → doubly stochastic
+            self.resid_mix_logit = nn.Parameter(torch.zeros(dim, dtype=torch.float32))
+        else:
+            self.resid_mix = nn.Parameter(torch.stack((torch.ones(dim), torch.zeros(dim))).float())
+
+    def forward(self, x: Tensor, x0: Tensor,
+                ts_attn_gamma: Tensor | None = None,
+                ts_mlp_gamma: Tensor | None = None) -> Tensor:
+        if self.use_birkhoff_mix:
+            # Convex combination: alpha*x + (1-alpha)*x0, elementwise over channels.
+            alpha = torch.sigmoid(self.resid_mix_logit.to(dtype=x.dtype))[None, None, :]
+            x = alpha * x + (1 - alpha) * x0
+        else:
+            mix = self.resid_mix.to(dtype=x.dtype)
+            x = mix[0][None, None, :] * x + mix[1][None, None, :] * x0
+        attn_out = self.attn(self.attn_norm(x))
+        attn_s = self.attn_scale.to(dtype=x.dtype)[None, None, :]
+        if ts_attn_gamma is not None:
+            attn_s = attn_s * ts_attn_gamma[None, None, :]
+        x = x + attn_s * attn_out
+        mlp_out = self.mlp_out_norm(self.mlp(x)) if self.use_peri_norm else self.mlp(self.mlp_norm(x))
+        mlp_s = self.mlp_scale.to(dtype=x.dtype)[None, None, :]
+        if ts_mlp_gamma is not None:
+            mlp_s = mlp_s * ts_mlp_gamma[None, None, :]
+        x = x + mlp_s * mlp_out
+        return x
+
+
+class GPT(nn.Module):
+    """GPT that runs either a recurrent stack (prelude -> shared x num_loops -> coda,
+    enabled when num_shared > 0) or a standard U-Net skip stack, with a softcapped,
+    optionally tied LM head. forward() returns the mean cross-entropy loss."""
+    def __init__(
+        self,
+        vocab_size: int,
+        num_layers: int,
+        model_dim: int,
+        num_heads: int,
+        num_kv_heads: int,
+        mlp_mult: int,
+        tie_embeddings: bool,
+        tied_embed_init_std: float,
+        logit_softcap: float,
+        rope_base: float,
+        qk_gain_init: float,
+        num_prelude: int = 0,
+        num_coda: int = 0,
+        num_shared: int = 0,
+        num_loops: int = 2,
+        use_peri_norm: bool = False,
+        use_birkhoff_mix: bool = False,
+        use_timestep_scale: bool = False,
+        leaky_relu_slope: float = 0.5,
+    ):
+        super().__init__()
+        if logit_softcap <= 0.0:
+            raise ValueError(f"logit_softcap must be positive, got {logit_softcap}")
+        self.tie_embeddings = tie_embeddings
+        self.tied_embed_init_std = tied_embed_init_std
+        self.logit_softcap = logit_softcap
+        self.tok_emb = nn.Embedding(vocab_size, model_dim)
+        self.use_recurrence = num_shared > 0
+        self.num_loops = num_loops if self.use_recurrence else 1
+        self.num_prelude = num_prelude if self.use_recurrence else 0
+        self.num_coda = num_coda if self.use_recurrence else 0
+
+        block_kwargs = dict(
+            dim=model_dim, num_heads=num_heads, num_kv_heads=num_kv_heads,
+            mlp_mult=mlp_mult, rope_base=rope_base, qk_gain_init=qk_gain_init,
+            use_peri_norm=use_peri_norm if self.use_recurrence else False,
+            use_birkhoff_mix=use_birkhoff_mix if self.use_recurrence else False,
+            leaky_relu_slope=leaky_relu_slope,
+        )
+
+        if self.use_recurrence:
+            self.prelude_blocks = nn.ModuleList([Block(**block_kwargs) for _ in range(self.num_prelude)])
+            self.shared_blocks = nn.ModuleList([Block(**block_kwargs) for _ in range(num_shared)])
+            self.coda_blocks = nn.ModuleList([Block(**block_kwargs) for _ in range(self.num_coda)])
+            effective_layers = self.num_prelude + num_shared * self.num_loops + self.num_coda
+            self.timestep_scale = TimestepScaling(effective_layers, model_dim) if use_timestep_scale else None
+        else:
+            # Standard U-Net path
+            self.num_encoder_layers = num_layers // 2
+            self.num_decoder_layers = num_layers - self.num_encoder_layers
+            self.num_skip_weights = min(self.num_encoder_layers, self.num_decoder_layers)
+            self.skip_weights = nn.Parameter(torch.ones(self.num_skip_weights, model_dim, dtype=torch.float32))
+            self.blocks = nn.ModuleList([Block(**block_kwargs) for _ in range(num_layers)])
+            self.timestep_scale = None
+
+        self.final_norm = RMSNorm()
+        self.lm_head = None if tie_embeddings else CastedLinear(model_dim, vocab_size, bias=False)
+        if self.lm_head is not None:
+            self.lm_head._zero_init = True
+        self._init_weights()
+
+    def _init_weights(self) -> None:
+        # Tied embedding gets an explicit normal init; all modules flagged _zero_init
+        # (output projections, untied head) start at exactly zero.
+        if self.tie_embeddings:
+            nn.init.normal_(self.tok_emb.weight, mean=0.0, std=self.tied_embed_init_std)
+        for module in self.modules():
+            if isinstance(module, nn.Linear) and getattr(module, "_zero_init", False):
+                nn.init.zeros_(module.weight)
+
+    def _get_ts(self, v: int) -> tuple[Tensor | None, Tensor | None]:
+        # Timestep-scale row for virtual layer position v, or (None, None) when disabled.
+        if self.timestep_scale is None:
+            return None, None
+        return (self.timestep_scale.attn_gamma[v],
+                self.timestep_scale.mlp_gamma[v])
+
+    def forward(self, input_ids: Tensor, target_ids: Tensor) -> Tensor:
+        x = self.tok_emb(input_ids)
+        x = F.rms_norm(x, (x.size(-1),))
+        x0 = x
+
+        if self.use_recurrence:
+            # v counts virtual layer positions across the unrolled shared loops.
+            v = 0
+            for block in self.prelude_blocks:
+                ag, mg = self._get_ts(v)
+                x = block(x, x0, ag, mg)
+                v += 1
+            for _loop in range(self.num_loops):
+                for block in self.shared_blocks:
+                    ag, mg = self._get_ts(v)
+                    x = block(x, x0, ag, mg)
+                    v += 1
+            for block in self.coda_blocks:
+                ag, mg = self._get_ts(v)
+                x = block(x, x0, ag, mg)
+                v += 1
+        else:
+            skips: list[Tensor] = []
+            for i in range(self.num_encoder_layers):
+                x = self.blocks[i](x, x0)
+                skips.append(x)
+            for i in range(self.num_decoder_layers):
+                # LIFO skip connections; the guard covers decoder > encoder depth.
+                if skips:
+                    x = x + self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips.pop()
+                x = self.blocks[self.num_encoder_layers + i](x, x0)
+
+        x = self.final_norm(x).reshape(-1, x.size(-1))
+        targets = target_ids.reshape(-1)
+        if self.tie_embeddings:
+            logits_proj = F.linear(x, self.tok_emb.weight)
+        else:
+            if self.lm_head is None:
+                raise RuntimeError("lm_head is required when tie_embeddings=False")
+            logits_proj = self.lm_head(x)
+        # tanh softcap bounds logits to (-logit_softcap, +logit_softcap).
+        logits = self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap)
+        return F.cross_entropy(logits.float(), targets, reduction="mean")
+
+
+@torch.no_grad()
+def recurrence_param_diagnostics(gpt: GPT) -> str:
+    """Log parameter norms for recurrence health monitoring. No forward pass."""
+    if not gpt.use_recurrence:
+        return ""
+    parts: list[str] = []
+
+    # Birkhoff alpha stats per shared block
+    for i, block in enumerate(gpt.shared_blocks):
+        if hasattr(block, "resid_mix_logit"):
+            a = torch.sigmoid(block.resid_mix_logit)
+            parts.append(f"shared{i}_alpha:mean={a.mean().item():.3f},std={a.std().item():.3f}")
+
+    # Effective MLP/attn contribution scale per virtual layer position:
+    # mlp_scale * ts_mlp_gamma[v] is the learned multiplier on MLP output.
+    # We report RMS (norm / sqrt(numel)) to make it scale-independent.
+    v = 0
+    all_blocks = (
+        list(gpt.prelude_blocks)
+        + list(gpt.shared_blocks) * gpt.num_loops
+        + list(gpt.coda_blocks)
+    )
+    mlp_norms: list[str] = []
+    attn_norms: list[str] = []
+    for block in all_blocks:
+        ms = block.mlp_scale.norm().item()
+        asc = block.attn_scale.norm().item()
+        d = block.mlp_scale.numel() ** 0.5
+        if gpt.timestep_scale is not None:
+            # NOTE(review): norm(a)*norm(b)/sqrt(d) approximates the RMS of the
+            # elementwise product a*b; it is exact only for constant vectors.
+            ms_g = gpt.timestep_scale.mlp_gamma[v].norm().item()
+            as_g = gpt.timestep_scale.attn_gamma[v].norm().item()
+            mlp_norms.append(f"v{v}:{ms * ms_g / d:.4f}")
+            attn_norms.append(f"v{v}:{asc * as_g / d:.4f}")
+        else:
+            mlp_norms.append(f"v{v}:{ms / d:.4f}")
+            attn_norms.append(f"v{v}:{asc / d:.4f}")
+        v += 1
+    parts.append("eff_mlp_scale:[" + " ".join(mlp_norms) + "]")
+    parts.append("eff_attn_scale:[" + " ".join(attn_norms) + "]")
+    return " ".join(parts)
+
+
+# -----------------------------
+# TRAINING
+# -----------------------------
+
+def main() -> None:
+    global zeropower_via_newtonschulz5
+
+    # Own source text is logged and counted toward the submission-size metric.
+    code = Path(__file__).read_text(encoding="utf-8")
+    args = Hyperparameters()
+    zeropower_via_newtonschulz5 = torch.compile(zeropower_via_newtonschulz5)
+
+    # -----------------------------
+    # DISTRIBUTED + CUDA SETUP
+    # -----------------------------
+
+    distributed = "RANK" in os.environ and "WORLD_SIZE" in os.environ
+    rank = int(os.environ.get("RANK", "0"))
+    world_size = int(os.environ.get("WORLD_SIZE", "1"))
+    local_rank = int(os.environ.get("LOCAL_RANK", "0"))
+    if world_size <= 0:
+        raise ValueError(f"WORLD_SIZE must be positive, got {world_size}")
+    if 8 % world_size != 0:
+        raise ValueError(f"WORLD_SIZE={world_size} must divide 8 so grad_accum_steps stays integral")
+    # Global batch is always 8 micro-batches; accumulation shrinks as ranks grow.
+    grad_accum_steps = 8 // world_size
+    grad_scale = 1.0 / grad_accum_steps
+    if not torch.cuda.is_available():
+        raise RuntimeError("CUDA is required")
+    device = torch.device("cuda", local_rank)
+    torch.cuda.set_device(device)
+    if distributed:
+        dist.init_process_group(backend="nccl", device_id=device)
+        dist.barrier()
+    master_process = rank == 0
+
+    # Fast math knobs
+    torch.backends.cuda.matmul.allow_tf32 = True
+    torch.backends.cudnn.allow_tf32 = True
+    from torch.backends.cuda import enable_cudnn_sdp, enable_flash_sdp, enable_math_sdp, enable_mem_efficient_sdp
+
+    enable_cudnn_sdp(False)
+    enable_flash_sdp(True)
+    enable_mem_efficient_sdp(False)
+    enable_math_sdp(False)
+
+    logfile = None
+    if master_process:
+        os.makedirs("logs", exist_ok=True)
+        logfile = f"logs/{args.run_id}.txt"
+        print(logfile)
+
+    def log0(msg: str, console: bool = True) -> None:
+        # Rank-0-only logger; always appends to the run logfile, optionally echoes.
+        if not master_process:
+            return
+        if console:
+            print(msg)
+        if logfile is not None:
+            with open(logfile, "a", encoding="utf-8") as f:
+                print(msg, file=f)
+
+    log0(code, console=False)
+    log0("=" * 100, console=False)
+    log0(f"Running Python {sys.version}", console=False)
+    log0(f"Running PyTorch {torch.__version__}", console=False)
+    log0(
+        subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, check=False).stdout,
+        console=False,
+    )
+    log0("=" * 100, console=False)
+
+    # -----------------------------
+    # TOKENIZER + VALIDATION METRIC SETUP
+    # -----------------------------
+
+    random.seed(args.seed)
+    np.random.seed(args.seed)
+    torch.manual_seed(args.seed)
+    torch.cuda.manual_seed_all(args.seed)
+
+    if not args.tokenizer_path.endswith(".model"):
+        raise ValueError(f"Script only setup for SentencePiece .model file: {args.tokenizer_path}")
+    sp = spm.SentencePieceProcessor(model_file=args.tokenizer_path)
+    if int(sp.vocab_size()) != args.vocab_size:
+        raise ValueError(
+            f"VOCAB_SIZE={args.vocab_size} does not match tokenizer vocab_size={int(sp.vocab_size())}"
+        )
+    dataset_dir = Path(args.data_path).resolve()
+    actual_train_files = len(list(dataset_dir.glob("fineweb_train_*.bin")))
+    val_tokens = load_validation_tokens(args.val_files, args.train_seq_len)
+    base_bytes_lut, has_leading_space_lut, is_boundary_token_lut = build_sentencepiece_luts(
+        sp, args.vocab_size, device
+    )
+    log0(f"val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path={args.tokenizer_path}")
+    log0(f"train_loader:dataset:{dataset_dir.name} train_shards:{actual_train_files}")
+    log0(f"val_loader:shards pattern={args.val_files} tokens:{val_tokens.numel() - 1}")
+
+    # -----------------------------
+    # MODEL + OPTIMIZER SETUP
+    # -----------------------------
+
+    base_model = GPT(
+        vocab_size=args.vocab_size,
+        num_layers=args.num_layers,
+        model_dim=args.model_dim,
+        num_heads=args.num_heads,
+        num_kv_heads=args.num_kv_heads,
+        mlp_mult=args.mlp_mult,
+        tie_embeddings=args.tie_embeddings,
+        tied_embed_init_std=args.tied_embed_init_std,
+        logit_softcap=args.logit_softcap,
+        rope_base=args.rope_base,
+        qk_gain_init=args.qk_gain_init,
+        num_prelude=args.num_prelude,
+        num_coda=args.num_coda,
+        num_shared=args.num_shared,
+        num_loops=args.num_loops,
+        use_peri_norm=args.use_peri_norm,
+        use_birkhoff_mix=args.use_birkhoff_mix,
+        use_timestep_scale=args.use_timestep_scale,
+        leaky_relu_slope=args.leaky_relu_slope,
+    ).to(device).bfloat16()
+    # Model runs in bf16 except CastedLinear weights and low-dim (scale/gain)
+    # parameters, which are kept in fp32 for optimizer precision.
+    for module in base_model.modules():
+        if isinstance(module, CastedLinear):
+            module.float()
+    restore_low_dim_params_to_fp32(base_model)
+    if args.disable_compile:
+        compiled_model = base_model
+    else:
+        compiled_model = torch.compile(base_model, dynamic=False, fullgraph=True)
+    model: nn.Module = DDP(compiled_model, device_ids=[local_rank], broadcast_buffers=False) if distributed else compiled_model
+
+    # Optimizer split:
+    # - token embedding (Adam) uses EMBED_LR
+    # - untied lm_head (Adam) uses HEAD_LR
+    # - matrix params in transformer blocks use MATRIX_LR via Muon
+    # - vectors/scalars use SCALAR_LR via Adam
+    if base_model.use_recurrence:
+        all_block_lists = [base_model.prelude_blocks, base_model.shared_blocks, base_model.coda_blocks]
+        block_named_params = []
+        for bl in all_block_lists:
+            block_named_params.extend(bl.named_parameters())
+    else:
+        block_named_params = list(base_model.blocks.named_parameters())
+    matrix_params = [
+        p
+        for name, p in block_named_params
+        if p.ndim == 2 and not any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS)
+    ]
+    scalar_params = [
+        p
+        for name, p in block_named_params
+        if p.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS)
+    ]
+    if hasattr(base_model, "skip_weights") and base_model.skip_weights.numel() > 0:
+        scalar_params.append(base_model.skip_weights)
+    if base_model.timestep_scale is not None:
+        for p in base_model.timestep_scale.parameters():
+            scalar_params.append(p)
+    token_lr = args.tied_embed_lr if args.tie_embeddings else args.embed_lr
+    optimizer_tok = torch.optim.Adam(
+        [{"params": [base_model.tok_emb.weight], "lr": token_lr, "base_lr": token_lr}],
+        betas=(args.beta1, args.beta2),
+        eps=args.adam_eps,
+        fused=True,
+    )
+    optimizer_muon = Muon(
+        matrix_params,
+        lr=args.matrix_lr,
+        momentum=args.muon_momentum,
+        backend_steps=args.muon_backend_steps,
+    )
+    for group in optimizer_muon.param_groups:
+        group["base_lr"] = args.matrix_lr
+    optimizer_scalar = torch.optim.Adam(
+        [{"params": scalar_params, "lr": args.scalar_lr, "base_lr": args.scalar_lr}],
+        betas=(args.beta1, args.beta2),
+        eps=args.adam_eps,
+        fused=True,
+    )
+    optimizers: list[torch.optim.Optimizer] = [optimizer_tok, optimizer_muon, optimizer_scalar]
+    if base_model.lm_head is not None:
+        optimizer_head = torch.optim.Adam(
+            [{"params": [base_model.lm_head.weight], "lr": args.head_lr, "base_lr": args.head_lr}],
+            betas=(args.beta1, args.beta2),
+            eps=args.adam_eps,
+            fused=True,
+        )
+        optimizers.insert(1, optimizer_head)
+
+    n_params = sum(p.numel() for p in base_model.parameters())
+    log0(f"model_params:{n_params}")
+    log0(f"world_size:{world_size} grad_accum_steps:{grad_accum_steps}")
+    log0("sdp_backends:cudnn=False flash=True mem_efficient=False math=False")
+    log0(f"flash_attention_3:{FA3_AVAILABLE}")
+    log0(f"attention_mode:gqa num_heads:{args.num_heads} num_kv_heads:{args.num_kv_heads}")
+    log0(
+        f"tie_embeddings:{args.tie_embeddings} embed_lr:{token_lr} "
+        f"head_lr:{args.head_lr if base_model.lm_head is not None else 0.0} "
+        f"matrix_lr:{args.matrix_lr} scalar_lr:{args.scalar_lr}"
+    )
+    log0(
+        f"train_batch_tokens:{args.train_batch_tokens} train_seq_len:{args.train_seq_len} "
+        f"iterations:{args.iterations} warmup_steps:{args.warmup_steps} "
+        f"max_wallclock_seconds:{args.max_wallclock_seconds:.3f}"
+    )
+    log0(f"seed:{args.seed}")
+    if base_model.use_recurrence:
+        eff = base_model.num_prelude + len(base_model.shared_blocks) * base_model.num_loops + base_model.num_coda
+        log0(f"recurrence:enabled prelude:{base_model.num_prelude} shared:{len(base_model.shared_blocks)} "
+             f"loops:{base_model.num_loops} coda:{base_model.num_coda} effective_layers:{eff}")
+        log0(f"peri_norm:{args.use_peri_norm} birkhoff_mix:{args.use_birkhoff_mix}")
+        if base_model.timestep_scale is not None:
+            ts_params = sum(p.numel() for p in base_model.timestep_scale.parameters())
+            log0(f"timestep_scale:enabled params:{ts_params}")
+        else:
+            log0("timestep_scale:disabled")
+    else:
+        log0(f"recurrence:disabled num_layers:{args.num_layers}")
+    compile_mode = "disabled" if args.disable_compile else "fullgraph=True"
+    log0(f"compile_mode:{compile_mode}")
+
+    # -----------------------------
+    # DATA LOADER & MODEL WARMUP
+    # -----------------------------
+
+    train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device)
+
+    def zero_grad_all() -> None:
+        for opt in optimizers:
+            opt.zero_grad(set_to_none=True)
+
+    # None disables the wallclock-based warmdown sizing in lr_mul.
+    max_wallclock_ms = 1000.0 * args.max_wallclock_seconds if args.max_wallclock_seconds > 0 else None
+
+    def lr_mul(step: int, elapsed_ms: float) -> float:
+        # LR multiplier: 1.0 for most of training, then a linear decay to 0 over the
+        # final warmdown_iters steps. Under a wallclock cap, the warmdown window is
+        # instead sized from the measured per-step time so decay finishes at the cap.
+        if args.warmdown_iters <= 0:
+            return 1.0
+        if max_wallclock_ms is None:
+            warmdown_start = max(args.iterations - args.warmdown_iters, 0)
+            return max((args.iterations - step) / max(args.warmdown_iters, 1), 0.0) if warmdown_start <= step < args.iterations else 1.0
+        step_ms = elapsed_ms / max(step, 1)
+        warmdown_ms = args.warmdown_iters * step_ms
+        remaining_ms = max(max_wallclock_ms - elapsed_ms, 0.0)
+        return remaining_ms / max(warmdown_ms, 1e-9) if remaining_ms <= warmdown_ms else 1.0
+
+    # Warmup primes the compiled forward/backward/optimizer paths, then we restore the
+    # initial weights/optimizer state so measured training starts from the true init.
+    if args.warmup_steps > 0:
+        initial_model_state = {name: tensor.detach().cpu().clone() for name, tensor in base_model.state_dict().items()}
+        initial_optimizer_states = [copy.deepcopy(opt.state_dict()) for opt in optimizers]
+        model.train()
+        for warmup_step in range(args.warmup_steps):
+            zero_grad_all()
+            for micro_step in range(grad_accum_steps):
+                if distributed:
+                    # Only sync gradients on the last micro-batch of the accumulation window.
+                    model.require_backward_grad_sync = micro_step == grad_accum_steps - 1
+                x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps)
+                with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True):
+                    warmup_loss = model(x, y)
+                (warmup_loss * grad_scale).backward()
+            for opt in optimizers:
+                opt.step()
+            zero_grad_all()
+            if args.warmup_steps <= 20 or (warmup_step + 1) % 10 == 0 or warmup_step + 1 == args.warmup_steps:
+                log0(f"warmup_step:{warmup_step + 1}/{args.warmup_steps}")
+        # Roll back to the pre-warmup snapshot so the timed run starts from true init.
+        base_model.load_state_dict(initial_model_state, strict=True)
+        for opt, state in zip(optimizers, initial_optimizer_states, strict=True):
+            opt.load_state_dict(state)
+        zero_grad_all()
+        if distributed:
+            model.require_backward_grad_sync = True
+        # Fresh loader object discards the position advanced during warmup.
+        train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device)
+
+    # -----------------------------
+    # MAIN TRAINING LOOP
+    # -----------------------------
+
+    training_time_ms = 0.0
+    stop_after_step: int | None = None
+    torch.cuda.synchronize()
+    t0 = time.perf_counter()
+
+    step = 0
+    while True:
+        last_step = step == args.iterations or (stop_after_step is not None and step >= stop_after_step)
+
+        should_validate = last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0)
+        if should_validate:
+            # Pause the training clock while validating.
+            torch.cuda.synchronize()
+            training_time_ms += 1000.0 * (time.perf_counter() - t0)
+            val_loss, val_bpb = eval_val(
+                args,
+                model,
+                rank,
+                world_size,
+                device,
+                grad_accum_steps,
+                val_tokens,
+                base_bytes_lut,
+                has_leading_space_lut,
+                is_boundary_token_lut,
+            )
+            log0(
+                f"step:{step}/{args.iterations} val_loss:{val_loss:.4f} val_bpb:{val_bpb:.4f} "
+                f"train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms / max(step, 1):.2f}ms"
+            )
+            torch.cuda.synchronize()
+            t0 = time.perf_counter()
+
+        if last_step:
+            if stop_after_step is not None and step < args.iterations:
+                log0(
+                    f"stopping_early: wallclock_cap train_time:{training_time_ms:.0f}ms "
+                    f"step:{step}/{args.iterations}"
+                )
+            break
+
+        elapsed_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0)
+        scale = lr_mul(step, elapsed_ms)
+        zero_grad_all()
+        train_loss = torch.zeros((), device=device)
+        for micro_step in range(grad_accum_steps):
+            if distributed:
+                model.require_backward_grad_sync = micro_step == grad_accum_steps - 1
+            x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps)
+            with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True):
+                loss = model(x, y)
+            train_loss += loss.detach()
+            (loss * grad_scale).backward()
+        # Rank-local mean over micro-batches; not all-reduced, so the logged
+        # train_loss is rank 0's view only.
+        train_loss /= grad_accum_steps
+
+        # Linear Muon momentum warmup from warmup_start to the target momentum.
+        frac = min(step / args.muon_momentum_warmup_steps, 1.0) if args.muon_momentum_warmup_steps > 0 else 1.0
+        muon_momentum = (1 - frac) * args.muon_momentum_warmup_start + frac * args.muon_momentum
+        for group in optimizer_muon.param_groups:
+            group["momentum"] = muon_momentum
+
+        for opt in optimizers:
+            for group in opt.param_groups:
+                group["lr"] = group["base_lr"] * scale
+
+        if args.grad_clip_norm > 0:
+            torch.nn.utils.clip_grad_norm_(base_model.parameters(), args.grad_clip_norm)
+        for opt in optimizers:
+            opt.step()
+        zero_grad_all()
+
+        step += 1
+        approx_training_time_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0)
+        should_log_train = (
+            args.train_log_every > 0
+            and (step <= 10 or step % args.train_log_every == 0 or stop_after_step is not None)
+        )
+        if should_log_train:
+            log0(
+                f"step:{step}/{args.iterations} train_loss:{train_loss.item():.4f} "
+                f"train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms / step:.2f}ms"
+            )
+        if base_model.use_recurrence and step % args.train_log_every == 0:
+            diag = recurrence_param_diagnostics(base_model)
+            if diag:
+                log0(f"step:{step} {diag}")
+
+        # Needed to sync whether we've reached the wallclock cap.
+        reached_cap = max_wallclock_ms is not None and approx_training_time_ms >= max_wallclock_ms
+        if distributed and max_wallclock_ms is not None:
+            reached_cap_tensor = torch.tensor(int(reached_cap), device=device)
+            dist.all_reduce(reached_cap_tensor, op=dist.ReduceOp.MAX)
+            reached_cap = bool(reached_cap_tensor.item())
+        if stop_after_step is None and reached_cap:
+            stop_after_step = step
+
+    log0(
+        f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB "
+        f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB"
+    )
+
+    # -----------------------------
+    # SERIALIZATION + ROUNDTRIP VALIDATION
+    # -----------------------------
+    # Save the raw state (useful for debugging/loading in PyTorch directly), then always produce
+    # the compressed int8+zlib artifact and validate the round-tripped weights.
+
+    if master_process:
+        torch.save(base_model.state_dict(), "final_model.pt")
+        model_bytes = os.path.getsize("final_model.pt")
+        code_bytes = len(code.encode("utf-8"))
+        log0(f"Serialized model: {model_bytes} bytes")
+        log0(f"Code size: {code_bytes} bytes")
+        log0(f"Total submission size: {model_bytes + code_bytes} bytes")
+
+    # All ranks build the quantized object; only the master process writes it to disk.
+    quant_obj, quant_stats = quantize_state_dict_int8(base_model.state_dict())
+    quant_buf = io.BytesIO()
+    torch.save(quant_obj, quant_buf)
+    quant_raw = quant_buf.getvalue()
+    quant_blob = zlib.compress(quant_raw, level=9)
+    quant_raw_bytes = len(quant_raw)
+    if master_process:
+        with open("final_model.int8.ptz", "wb") as f:
+            f.write(quant_blob)
+        quant_file_bytes = os.path.getsize("final_model.int8.ptz")
+        code_bytes = len(code.encode("utf-8"))
+        ratio = quant_stats["baseline_tensor_bytes"] / max(quant_stats["int8_payload_bytes"], 1)
+        log0(
+            f"Serialized model int8+zlib: {quant_file_bytes} bytes "
+            f"(payload:{quant_stats['int8_payload_bytes']} raw_torch:{quant_raw_bytes} payload_ratio:{ratio:.2f}x)"
+        )
+        log0(f"Total submission size int8+zlib: {quant_file_bytes + code_bytes} bytes")
+
+    if distributed:
+        # Barrier ensures the master's file write completes before other ranks read it back.
+        dist.barrier()
+    with open("final_model.int8.ptz", "rb") as f:
+        quant_blob_disk = f.read()
+    quant_state = torch.load(io.BytesIO(zlib.decompress(quant_blob_disk)), map_location="cpu")
+    # Reload the dequantized weights into the live model so the final eval measures
+    # exactly the artifact being submitted.
+    base_model.load_state_dict(dequantize_state_dict_int8(quant_state), strict=True)
+    torch.cuda.synchronize()
+    t_qeval = time.perf_counter()
+    q_val_loss, q_val_bpb = eval_val(
+        args,
+        model,
+        rank,
+        world_size,
+        device,
+        grad_accum_steps,
+        val_tokens,
+        base_bytes_lut,
+        has_leading_space_lut,
+        is_boundary_token_lut,
+    )
+    torch.cuda.synchronize()
+    log0(
+        f"final_int8_zlib_roundtrip val_loss:{q_val_loss:.4f} val_bpb:{q_val_bpb:.4f} "
+        f"eval_time:{1000.0 * (time.perf_counter() - t_qeval):.0f}ms"
+    )
+    log0(f"final_int8_zlib_roundtrip_exact val_loss:{q_val_loss:.8f} val_bpb:{q_val_bpb:.8f}")
+
+    if distributed:
+        dist.destroy_process_group()
+ + +if __name__ == "__main__": + main() + +==================================================================================================== +Running Python 3.12.3 (main, Nov 6 2025, 13:44:16) [GCC 13.3.0] +Running PyTorch 2.9.1+cu128 +Wed Mar 25 14:22:31 2026 ++-----------------------------------------------------------------------------------------+ +| NVIDIA-SMI 565.57.01 Driver Version: 565.57.01 CUDA Version: 12.7 | +|-----------------------------------------+------------------------+----------------------+ +| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC | +| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. | +| | | MIG M. | +|=========================================+========================+======================| +| 0 NVIDIA H100 80GB HBM3 On | 00000000:BE:00.0 Off | 0 | +| N/A 45C P0 113W / 700W | 1184MiB / 81559MiB | 8% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ + ++-----------------------------------------------------------------------------------------+ +| Processes: | +| GPU GI CI PID Type Process name GPU Memory | +| ID ID Usage | +|=========================================================================================| ++-----------------------------------------------------------------------------------------+ + +==================================================================================================== +val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path=./data/tokenizers/fineweb_1024_bpe.model +train_loader:dataset:fineweb10B_sp1024 train_shards:10 +val_loader:shards pattern=./data/datasets/fineweb10B_sp1024/fineweb_val_*.bin tokens:62021632 +model_params:9715240 +world_size:1 grad_accum_steps:8 +sdp_backends:cudnn=False flash=True mem_efficient=False math=False +flash_attention_3:True +attention_mode:gqa num_heads:8 num_kv_heads:4 +tie_embeddings:True embed_lr:0.05 head_lr:0.0 matrix_lr:0.04 scalar_lr:0.04 
+train_batch_tokens:524288 train_seq_len:1024 iterations:2000 warmup_steps:20 max_wallclock_seconds:0.000 +seed:1337 +recurrence:enabled prelude:1 shared:3 loops:2 coda:1 effective_layers:8 +peri_norm:True birkhoff_mix:True +timestep_scale:enabled params:8192 +compile_mode:fullgraph=True +warmup_step:1/20 +warmup_step:2/20 +warmup_step:3/20 +warmup_step:4/20 +warmup_step:5/20 +warmup_step:6/20 +warmup_step:7/20 +warmup_step:8/20 +warmup_step:9/20 +warmup_step:10/20 +warmup_step:11/20 +warmup_step:12/20 +warmup_step:13/20 +warmup_step:14/20 +warmup_step:15/20 +warmup_step:16/20 +warmup_step:17/20 +warmup_step:18/20 +warmup_step:19/20 +warmup_step:20/20 +step:0/2000 val_loss:6.9352 val_bpb:4.1074 train_time:0ms step_avg:0.01ms +step:1/2000 train_loss:6.9352 train_time:296ms step_avg:295.64ms +step:2/2000 train_loss:9.4553 train_time:648ms step_avg:324.20ms +step:3/2000 train_loss:7.2992 train_time:1010ms step_avg:336.55ms +step:4/2000 train_loss:8.8532 train_time:1352ms step_avg:337.99ms +step:5/2000 train_loss:8.1834 train_time:1682ms step_avg:336.49ms +step:6/2000 train_loss:7.1723 train_time:2051ms step_avg:341.88ms +step:7/2000 train_loss:6.3969 train_time:2369ms step_avg:338.43ms +step:8/2000 train_loss:5.9703 train_time:2750ms step_avg:343.77ms +step:9/2000 train_loss:5.4559 train_time:3138ms step_avg:348.69ms +step:10/2000 train_loss:5.1654 train_time:3453ms step_avg:345.32ms +step:200/2000 train_loss:2.7169 train_time:72061ms step_avg:360.31ms +step:200 shared0_alpha:mean=0.479,std=0.048 shared1_alpha:mean=0.500,std=0.038 shared2_alpha:mean=0.523,std=0.040 eff_mlp_scale:[v0:36.4650 v1:27.4559 v2:30.2393 v3:31.4992 v4:31.2329 v5:33.4304 v6:34.2923 v7:58.1287] eff_attn_scale:[v0:15.2916 v1:8.9235 v2:11.1147 v3:9.3044 v4:9.8357 v5:10.8177 v6:9.3044 v7:15.8753] +step:400/2000 train_loss:2.4141 train_time:143285ms step_avg:358.21ms +step:400 shared0_alpha:mean=0.490,std=0.051 shared1_alpha:mean=0.518,std=0.040 shared2_alpha:mean=0.545,std=0.045 
eff_mlp_scale:[v0:46.2288 v1:35.3427 v2:38.5586 v3:38.1938 v4:41.0377 v5:40.0807 v6:37.0414 v7:72.7120] eff_attn_scale:[v0:6.4834 v1:4.9410 v2:5.8552 v3:5.0616 v4:5.5340 v5:5.4494 v6:4.6702 v7:9.2671] +step:500/2000 val_loss:2.5289 val_bpb:1.4978 train_time:178887ms step_avg:357.77ms +step:600/2000 train_loss:2.5428 train_time:213869ms step_avg:356.45ms +step:600 shared0_alpha:mean=0.497,std=0.053 shared1_alpha:mean=0.527,std=0.042 shared2_alpha:mean=0.559,std=0.048 eff_mlp_scale:[v0:52.3816 v1:40.3661 v2:44.0226 v3:42.8452 v4:47.9347 v5:43.8444 v6:38.7484 v7:86.7040] eff_attn_scale:[v0:3.0892 v1:2.9213 v2:3.4563 v3:3.2685 v4:3.3095 v5:3.1421 v6:2.8944 v7:6.0615] +step:800/2000 train_loss:2.4040 train_time:285182ms step_avg:356.48ms +step:800 shared0_alpha:mean=0.501,std=0.055 shared1_alpha:mean=0.531,std=0.044 shared2_alpha:mean=0.565,std=0.050 eff_mlp_scale:[v0:58.2959 v1:44.8684 v2:48.0435 v3:46.9211 v4:52.6304 v5:46.7500 v6:41.1000 v7:98.2646] eff_attn_scale:[v0:1.9333 v1:2.1156 v2:2.4856 v3:2.4671 v4:2.3894 v5:2.2302 v6:2.1217 v7:4.6103] +step:1000/2000 train_loss:2.4171 train_time:355579ms step_avg:355.58ms +step:1000 shared0_alpha:mean=0.503,std=0.056 shared1_alpha:mean=0.533,std=0.046 shared2_alpha:mean=0.569,std=0.052 eff_mlp_scale:[v0:63.1870 v1:48.3672 v2:51.4418 v3:50.2091 v4:57.0183 v5:49.9176 v6:43.6601 v7:107.7019] eff_attn_scale:[v0:1.4890 v1:1.7390 v2:2.0680 v3:2.1136 v4:1.9896 v5:1.8323 v6:1.8343 v7:3.9133] +step:1000/2000 val_loss:2.3898 val_bpb:1.4153 train_time:355589ms step_avg:355.59ms +step:1200/2000 train_loss:2.3423 train_time:424599ms step_avg:353.83ms +step:1200 shared0_alpha:mean=0.504,std=0.058 shared1_alpha:mean=0.535,std=0.046 shared2_alpha:mean=0.571,std=0.053 eff_mlp_scale:[v0:66.6423 v1:51.1894 v2:53.9289 v3:52.6932 v4:60.6989 v5:52.3657 v6:45.9664 v7:115.6749] eff_attn_scale:[v0:1.2909 v1:1.6083 v2:1.8758 v3:1.9875 v4:1.8481 v5:1.6882 v6:1.7244 v7:3.7206] +step:1400/2000 train_loss:2.3742 train_time:497840ms step_avg:355.60ms 
+step:1400 shared0_alpha:mean=0.504,std=0.059 shared1_alpha:mean=0.535,std=0.047 shared2_alpha:mean=0.572,std=0.054 eff_mlp_scale:[v0:68.8389 v1:53.4264 v2:55.9491 v3:55.1292 v4:63.3662 v5:54.7502 v6:48.4294 v7:121.9842] eff_attn_scale:[v0:1.2221 v1:1.5335 v2:1.8114 v3:1.9478 v4:1.7742 v5:1.6409 v6:1.7080 v7:3.7654] +step:1500/2000 val_loss:2.3121 val_bpb:1.3693 train_time:535969ms step_avg:357.31ms +step:1600/2000 train_loss:2.2579 train_time:572966ms step_avg:358.10ms +step:1600 shared0_alpha:mean=0.505,std=0.059 shared1_alpha:mean=0.536,std=0.048 shared2_alpha:mean=0.573,std=0.054 eff_mlp_scale:[v0:70.6981 v1:54.7869 v2:57.7289 v3:56.4746 v4:65.3228 v5:56.1027 v6:49.8535 v7:125.3435] eff_attn_scale:[v0:1.1746 v1:1.5009 v2:1.7788 v3:1.9207 v4:1.7601 v5:1.6235 v6:1.7170 v7:3.8943] +step:1800/2000 train_loss:2.2974 train_time:644736ms step_avg:358.19ms +step:1800 shared0_alpha:mean=0.506,std=0.059 shared1_alpha:mean=0.536,std=0.048 shared2_alpha:mean=0.574,std=0.054 eff_mlp_scale:[v0:71.3269 v1:56.0250 v2:58.5451 v3:57.8082 v4:66.7167 v5:57.3082 v6:51.0771 v7:127.9043] eff_attn_scale:[v0:1.1514 v1:1.5082 v2:1.7806 v3:1.9188 v4:1.7687 v5:1.6316 v6:1.7357 v7:3.9647] +step:2000/2000 train_loss:2.3128 train_time:715663ms step_avg:357.83ms +step:2000 shared0_alpha:mean=0.506,std=0.059 shared1_alpha:mean=0.536,std=0.048 shared2_alpha:mean=0.574,std=0.054 eff_mlp_scale:[v0:71.6224 v1:56.4033 v2:58.8769 v3:58.1747 v4:67.1673 v5:57.6330 v6:51.4009 v7:128.9102] eff_attn_scale:[v0:1.1474 v1:1.5013 v2:1.7832 v3:1.9260 v4:1.7686 v5:1.6340 v6:1.7495 v7:4.0299] +step:2000/2000 val_loss:2.2641 val_bpb:1.3409 train_time:715670ms step_avg:357.83ms +peak memory allocated: 8036 MiB reserved: 8052 MiB +Serialized model: 37816688 bytes +Code size: 56522 bytes +Total submission size: 37873210 bytes +Serialized model int8+zlib: 8505722 bytes (payload:9779360 raw_torch:9805479 payload_ratio:3.86x) +Total submission size int8+zlib: 8562244 bytes +final_int8_zlib_roundtrip val_loss:2.2672 
val_bpb:1.3428 eval_time:8678ms +final_int8_zlib_roundtrip_exact val_loss:2.26721900 val_bpb:1.34277480 diff --git a/records/track_non_record_16mb/2026-03-26_RecurrenceFix_3Loop_Birkhoff_OutputLN_TimestepScale/logs/s1_F.txt b/records/track_non_record_16mb/2026-03-26_RecurrenceFix_3Loop_Birkhoff_OutputLN_TimestepScale/logs/s1_F.txt new file mode 100644 index 0000000000..034cef4d14 --- /dev/null +++ b/records/track_non_record_16mb/2026-03-26_RecurrenceFix_3Loop_Birkhoff_OutputLN_TimestepScale/logs/s1_F.txt @@ -0,0 +1,1412 @@ +""" +The `train_gpt.py` and `train_gpt_mlx.py` scripts are intended as good launching-off points for new participants, not SOTA configs. We'll accept PRs that tune, improve, or simplify these scripts without significantly increasing complexity, but competitive submissions should stay in the `/records` folder. + +Hard stop: To keep readable for newcomers, let's make sure `train_gpt.py` and `train_gpt_mlx.py` never are longer than 1500 lines. +""" + +from __future__ import annotations + +import copy +import glob +import io +import math +import os +import random +import subprocess +import sys +import time +import uuid +import zlib +from pathlib import Path + +import numpy as np +import sentencepiece as spm +import torch +import torch.distributed as dist +import torch.nn.functional as F +from torch import Tensor, nn +from torch.nn.parallel import DistributedDataParallel as DDP + +# Flash Attention 3 via HuggingFace kernels (Hopper-optimized) +try: + from kernels import get_kernel + cap = torch.cuda.get_device_capability() + repo = "varunneal/flash-attention-3" if cap == (9, 0) else "kernels-community/flash-attn3" + fa3 = get_kernel(repo).flash_attn_interface + FA3_AVAILABLE = True +except Exception: + FA3_AVAILABLE = False + +# ----------------------------- +# HYPERPARAMETERS +# ----------------------------- +# Default Simple Baseline run: +# - 9 transformer blocks at width 512 +# - 8 attention heads with 4 KV heads (GQA) and 2x MLP expansion +# - 
vocab size 1024, sequence length 1024, tied embeddings +# - 524,288 train tokens per step for 20,000 iterations with a ~10 minute cap + +class Hyperparameters: + # Data paths are shard globs produced by the existing preprocessing pipeline. + data_path = os.environ.get("DATA_PATH", "./data/datasets/fineweb10B_sp1024") + train_files = os.path.join(data_path, "fineweb_train_*.bin") + val_files = os.path.join(data_path, "fineweb_val_*.bin") + tokenizer_path = os.environ.get("TOKENIZER_PATH", "./data/tokenizers/fineweb_1024_bpe.model") + run_id = os.environ.get("RUN_ID", str(uuid.uuid4())) + seed = int(os.environ.get("SEED", 1337)) + + # Validation cadence and batch size. Validation always uses the full fineweb_val split. + val_batch_size = int(os.environ.get("VAL_BATCH_SIZE", 524_288)) + val_loss_every = int(os.environ.get("VAL_LOSS_EVERY", 1000)) + train_log_every = int(os.environ.get("TRAIN_LOG_EVERY", 200)) + + # Training length. + iterations = int(os.environ.get("ITERATIONS", 20000)) + warmdown_iters = int(os.environ.get("WARMDOWN_ITERS", 1200)) + warmup_steps = int(os.environ.get("WARMUP_STEPS", 20)) + train_batch_tokens = int(os.environ.get("TRAIN_BATCH_TOKENS", 524_288)) + train_seq_len = int(os.environ.get("TRAIN_SEQ_LEN", 1024)) + max_wallclock_seconds = float(os.environ.get("MAX_WALLCLOCK_SECONDS", 600.0)) + qk_gain_init = float(os.environ.get("QK_GAIN_INIT", 1.5)) + + # Model shape. + vocab_size = int(os.environ.get("VOCAB_SIZE", 1024)) + num_layers = int(os.environ.get("NUM_LAYERS", 9)) + num_kv_heads = int(os.environ.get("NUM_KV_HEADS", 4)) + model_dim = int(os.environ.get("MODEL_DIM", 512)) + num_heads = int(os.environ.get("NUM_HEADS", 8)) + mlp_mult = int(os.environ.get("MLP_MULT", 2)) + tie_embeddings = bool(int(os.environ.get("TIE_EMBEDDINGS", "1"))) + rope_base = float(os.environ.get("ROPE_BASE", 10000.0)) + logit_softcap = float(os.environ.get("LOGIT_SOFTCAP", 30.0)) + + # Recurrence + stability features. 
+ num_prelude = int(os.environ.get("NUM_PRELUDE", 0)) + num_coda = int(os.environ.get("NUM_CODA", 0)) + num_shared = int(os.environ.get("NUM_SHARED", 0)) + num_loops = int(os.environ.get("NUM_LOOPS", 2)) + use_peri_norm = bool(int(os.environ.get("USE_PERI_NORM", "0"))) + use_birkhoff_mix = bool(int(os.environ.get("USE_BIRKHOFF_MIX", "0"))) + use_timestep_scale = bool(int(os.environ.get("USE_TIMESTEP_SCALE", "0"))) + leaky_relu_slope = float(os.environ.get("LEAKY_RELU_SLOPE", "0.5")) + disable_compile = bool(int(os.environ.get("DISABLE_COMPILE", "0"))) + + # Optimizer hyperparameters. + embed_lr = float(os.environ.get("EMBED_LR", 0.6)) + head_lr = float(os.environ.get("HEAD_LR", 0.008)) + tied_embed_lr = float(os.environ.get("TIED_EMBED_LR", 0.05)) + tied_embed_init_std = float(os.environ.get("TIED_EMBED_INIT_STD", 0.005)) + matrix_lr = float(os.environ.get("MATRIX_LR", 0.04)) + scalar_lr = float(os.environ.get("SCALAR_LR", 0.04)) + muon_momentum = float(os.environ.get("MUON_MOMENTUM", 0.95)) + muon_backend_steps = int(os.environ.get("MUON_BACKEND_STEPS", 5)) + muon_momentum_warmup_start = float(os.environ.get("MUON_MOMENTUM_WARMUP_START", 0.85)) + muon_momentum_warmup_steps = int(os.environ.get("MUON_MOMENTUM_WARMUP_STEPS", 500)) + beta1 = float(os.environ.get("BETA1", 0.9)) + beta2 = float(os.environ.get("BETA2", 0.95)) + adam_eps = float(os.environ.get("ADAM_EPS", 1e-8)) + grad_clip_norm = float(os.environ.get("GRAD_CLIP_NORM", 0.0)) + +# ----------------------------- +# MUON OPTIMIZER +# ----------------------------- +# +# As borrowed from modded-nanogpt +# Background on Muon: https://kellerjordan.github.io/posts/muon/ + +def zeropower_via_newtonschulz5(G: Tensor, steps: int = 10, eps: float = 1e-7) -> Tensor: + # Orthogonalize a 2D update matrix with a fast Newton-Schulz iteration. + # Muon uses this to normalize matrix-shaped gradients before applying them. 
+ a, b, c = (3.4445, -4.7750, 2.0315) + X = G.bfloat16() + X /= X.norm() + eps + transposed = G.size(0) > G.size(1) + if transposed: + X = X.T + for _ in range(steps): + A = X @ X.T + B = b * A + c * A @ A + X = a * X + B @ X + return X.T if transposed else X + + +class Muon(torch.optim.Optimizer): + def __init__(self, params, lr: float, momentum: float, backend_steps: int, nesterov: bool = True): + super().__init__( + params, + dict(lr=lr, momentum=momentum, backend_steps=backend_steps, nesterov=nesterov), + ) + + @torch.no_grad() + def step(self, closure=None): + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + distributed = dist.is_available() and dist.is_initialized() + world_size = dist.get_world_size() if distributed else 1 + rank = dist.get_rank() if distributed else 0 + + for group in self.param_groups: + params = group["params"] + if not params: + continue + lr = group["lr"] + momentum = group["momentum"] + backend_steps = group["backend_steps"] + nesterov = group["nesterov"] + + total_params = sum(int(p.numel()) for p in params) + updates_flat = torch.zeros(total_params, device=params[0].device, dtype=torch.bfloat16) + + curr = 0 + for i, p in enumerate(params): + if i % world_size == rank and p.grad is not None: + g = p.grad + state = self.state[p] + if "momentum_buffer" not in state: + state["momentum_buffer"] = torch.zeros_like(g) + buf = state["momentum_buffer"] + buf.mul_(momentum).add_(g) + if nesterov: + g = g.add(buf, alpha=momentum) + g = zeropower_via_newtonschulz5(g, steps=backend_steps) + # Scale correction from Muon reference implementations. 
+ g *= max(1, g.size(0) / g.size(1)) ** 0.5 + updates_flat[curr : curr + p.numel()] = g.reshape(-1) + curr += p.numel() + + if distributed: + dist.all_reduce(updates_flat, op=dist.ReduceOp.SUM) + + curr = 0 + for p in params: + g = updates_flat[curr : curr + p.numel()].view_as(p).to(dtype=p.dtype) + p.add_(g, alpha=-lr) + curr += p.numel() + + return loss + + +# ----------------------------- +# TOKENIZER-AGNOSTIC EVALUATION SETUP +# ----------------------------- +# +# It's common for small models have a large fraction of their parameters be embeddings, since the 2 * d_model * d_vocab vectors can be gigantic. +# Instead of locking the tokenizer, we let you bring your own and calculate our validation metrics on the average compression of the validation set. +# We calculate BPB (bits-per-byte) instead of validation loss, so we need methods to count the number of bits per token in the tokenizer. +# Note: Submissions that edit the tokenizer will be examined more carefully, since screwing this up might unjustly improve your score. 
+ +def build_sentencepiece_luts( + sp: spm.SentencePieceProcessor, vocab_size: int, device: torch.device +) -> tuple[Tensor, Tensor, Tensor]: + sp_vocab_size = int(sp.vocab_size()) + table_size = max(sp_vocab_size, vocab_size) + base_bytes_np = np.zeros((table_size,), dtype=np.int16) + has_leading_space_np = np.zeros((table_size,), dtype=np.bool_) + is_boundary_token_np = np.ones((table_size,), dtype=np.bool_) + for token_id in range(sp_vocab_size): + if sp.is_control(token_id) or sp.is_unknown(token_id) or sp.is_unused(token_id): + continue + is_boundary_token_np[token_id] = False + if sp.is_byte(token_id): + base_bytes_np[token_id] = 1 + continue + piece = sp.id_to_piece(token_id) + if piece.startswith("▁"): + has_leading_space_np[token_id] = True + piece = piece[1:] + base_bytes_np[token_id] = len(piece.encode("utf-8")) + return ( + torch.tensor(base_bytes_np, dtype=torch.int16, device=device), + torch.tensor(has_leading_space_np, dtype=torch.bool, device=device), + torch.tensor(is_boundary_token_np, dtype=torch.bool, device=device), + ) + + +def load_validation_tokens(pattern: str, seq_len: int) -> Tensor: + files = [Path(p) for p in sorted(glob.glob(pattern))] + if not files: + raise FileNotFoundError(f"No files found for pattern: {pattern}") + # The export pipeline writes the fixed first-50k-doc validation set to fineweb_val_*. 
+ tokens = torch.cat([load_data_shard(file) for file in files]).contiguous() + usable = ((tokens.numel() - 1) // seq_len) * seq_len + if usable <= 0: + raise ValueError(f"Validation split is too short for TRAIN_SEQ_LEN={seq_len}") + return tokens[: usable + 1] + + +def eval_val( + args: Hyperparameters, + model: nn.Module, + rank: int, + world_size: int, + device: torch.device, + grad_accum_steps: int, + val_tokens: Tensor, + base_bytes_lut: Tensor, + has_leading_space_lut: Tensor, + is_boundary_token_lut: Tensor, +) -> tuple[float, float]: + # Validation computes two metrics: + # - val_loss: token cross-entropy (natural log) + # - val_bpb: tokenizer-agnostic compression metric used by the challenge + local_batch_tokens = args.val_batch_size // (world_size * grad_accum_steps) + if local_batch_tokens < args.train_seq_len: + raise ValueError( + "VAL_BATCH_SIZE must provide at least one sequence per rank; " + f"got VAL_BATCH_SIZE={args.val_batch_size}, WORLD_SIZE={world_size}, " + f"GRAD_ACCUM_STEPS={grad_accum_steps}, TRAIN_SEQ_LEN={args.train_seq_len}" + ) + local_batch_seqs = local_batch_tokens // args.train_seq_len + total_seqs = (val_tokens.numel() - 1) // args.train_seq_len + seq_start = (total_seqs * rank) // world_size + seq_end = (total_seqs * (rank + 1)) // world_size + val_loss_sum = torch.zeros((), device=device, dtype=torch.float64) + val_token_count = torch.zeros((), device=device, dtype=torch.float64) + val_byte_count = torch.zeros((), device=device, dtype=torch.float64) + + model.eval() + with torch.inference_mode(): + for batch_seq_start in range(seq_start, seq_end, local_batch_seqs): + batch_seq_end = min(batch_seq_start + local_batch_seqs, seq_end) + raw_start = batch_seq_start * args.train_seq_len + raw_end = batch_seq_end * args.train_seq_len + 1 + local = val_tokens[raw_start:raw_end].to(device=device, dtype=torch.int64, non_blocking=True) + x = local[:-1].reshape(-1, args.train_seq_len) + y = local[1:].reshape(-1, args.train_seq_len) + with 
torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + batch_loss = model(x, y).detach() + batch_token_count = float(y.numel()) + val_loss_sum += batch_loss.to(torch.float64) * batch_token_count + val_token_count += batch_token_count + prev_ids = x.reshape(-1) + tgt_ids = y.reshape(-1) + token_bytes = base_bytes_lut[tgt_ids].to(dtype=torch.int16) + token_bytes += (has_leading_space_lut[tgt_ids] & ~is_boundary_token_lut[prev_ids]).to(dtype=torch.int16) + val_byte_count += token_bytes.to(torch.float64).sum() + + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(val_loss_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(val_token_count, op=dist.ReduceOp.SUM) + dist.all_reduce(val_byte_count, op=dist.ReduceOp.SUM) + + val_loss = val_loss_sum / val_token_count + bits_per_token = val_loss.item() / math.log(2.0) + tokens_per_byte = val_token_count.item() / val_byte_count.item() + model.train() + return float(val_loss.item()), float(bits_per_token * tokens_per_byte) + +# ----------------------------- +# POST-TRAINING QUANTIZATION +# ----------------------------- +# +# It's silly to export our model, which is trained in bf16 and fp32, at that same precision. +# Instead, we get approximately the same model (with a small hit) by quantizing the model to int8 & zlib compressing. +# We can then decompress the model and run in higher precision for evaluation, after closing in under the size limit. 
+ +CONTROL_TENSOR_NAME_PATTERNS = tuple( + pattern + for pattern in os.environ.get( + "CONTROL_TENSOR_NAME_PATTERNS", + "attn_scale,attn_scales,mlp_scale,mlp_scales,resid_mix,resid_mixes,q_gain,skip_weight,skip_weights,ts_attn,ts_mlp,resid_mix_logit", + ).split(",") + if pattern +) +INT8_KEEP_FLOAT_FP32_NAME_PATTERNS = tuple( + pattern + for pattern in os.environ.get( + "INT8_KEEP_FLOAT_FP32_NAME_PATTERNS", + ",".join(CONTROL_TENSOR_NAME_PATTERNS), + ).split(",") + if pattern +) +INT8_KEEP_FLOAT_MAX_NUMEL = 65_536 +INT8_KEEP_FLOAT_STORE_DTYPE = torch.float16 +INT8_PER_ROW_SCALE_DTYPE = torch.float16 +INT8_CLIP_PERCENTILE = 99.99984 +INT8_CLIP_Q = INT8_CLIP_PERCENTILE / 100.0 + +def tensor_nbytes(t: Tensor) -> int: + return int(t.numel()) * int(t.element_size()) + +def keep_float_tensor(name: str, t: Tensor, passthrough_orig_dtypes: dict[str, str]) -> Tensor: + if any(pattern in name for pattern in INT8_KEEP_FLOAT_FP32_NAME_PATTERNS): + return t.float().contiguous() + if t.dtype in {torch.float32, torch.bfloat16}: + passthrough_orig_dtypes[name] = str(t.dtype).removeprefix("torch.") + return t.to(dtype=INT8_KEEP_FLOAT_STORE_DTYPE).contiguous() + return t + +def quantize_float_tensor(t: Tensor) -> tuple[Tensor, Tensor]: + t32 = t.float() + if t32.ndim == 2: + # Matrices get one scale per row, which usually tracks output-channel + # ranges much better than a single tensor-wide scale. + clip_abs = ( + torch.quantile(t32.abs(), INT8_CLIP_Q, dim=1) + if t32.numel() + else torch.empty((t32.shape[0],), dtype=torch.float32) + ) + clipped = torch.maximum(torch.minimum(t32, clip_abs[:, None]), -clip_abs[:, None]) + scale = (clip_abs / 127.0).clamp_min(1.0 / 127.0) + q = torch.clamp(torch.round(clipped / scale[:, None]), -127, 127).to(torch.int8).contiguous() + return q, scale.to(dtype=INT8_PER_ROW_SCALE_DTYPE).contiguous() + + # Vectors / scalars use a simpler per-tensor scale. 
+ clip_abs = float(torch.quantile(t32.abs().flatten(), INT8_CLIP_Q).item()) if t32.numel() else 0.0 + scale = torch.tensor(clip_abs / 127.0 if clip_abs > 0 else 1.0, dtype=torch.float32) + q = torch.clamp(torch.round(torch.clamp(t32, -clip_abs, clip_abs) / scale), -127, 127).to(torch.int8).contiguous() + return q, scale + +def quantize_state_dict_int8(state_dict: dict[str, Tensor]): + # Single supported clean-script export format: + # - per-row int8 for 2D float tensors + # - per-tensor int8 for other float tensors + # - exact passthrough for non-floats + # - passthrough for small float tensors, stored as fp16 to save bytes + quantized: dict[str, Tensor] = {} + scales: dict[str, Tensor] = {} + dtypes: dict[str, str] = {} + passthrough: dict[str, Tensor] = {} + passthrough_orig_dtypes: dict[str, str] = {} + qmeta: dict[str, dict[str, object]] = {} + stats = dict.fromkeys( + ("param_count", "num_tensors", "num_float_tensors", "num_nonfloat_tensors", "baseline_tensor_bytes", "int8_payload_bytes"), + 0, + ) + + for name, tensor in state_dict.items(): + t = tensor.detach().to("cpu").contiguous() + stats["param_count"] += int(t.numel()) + stats["num_tensors"] += 1 + stats["baseline_tensor_bytes"] += tensor_nbytes(t) + + if not t.is_floating_point(): + stats["num_nonfloat_tensors"] += 1 + passthrough[name] = t + stats["int8_payload_bytes"] += tensor_nbytes(t) + continue + + # Small float tensors are cheap enough to keep directly. We still downcast + # fp32/bf16 passthrough tensors to fp16 so metadata does not dominate size. 
+ if t.numel() <= INT8_KEEP_FLOAT_MAX_NUMEL: + kept = keep_float_tensor(name, t, passthrough_orig_dtypes) + passthrough[name] = kept + stats["int8_payload_bytes"] += tensor_nbytes(kept) + continue + + stats["num_float_tensors"] += 1 + q, s = quantize_float_tensor(t) + if s.ndim > 0: + qmeta[name] = {"scheme": "per_row", "axis": 0} + quantized[name] = q + scales[name] = s + dtypes[name] = str(t.dtype).removeprefix("torch.") + stats["int8_payload_bytes"] += tensor_nbytes(q) + tensor_nbytes(s) + + obj: dict[str, object] = { + "__quant_format__": "int8_clean_per_row_v1", + "quantized": quantized, + "scales": scales, + "dtypes": dtypes, + "passthrough": passthrough, + } + if qmeta: + obj["qmeta"] = qmeta + if passthrough_orig_dtypes: + obj["passthrough_orig_dtypes"] = passthrough_orig_dtypes + return obj, stats + +def dequantize_state_dict_int8(obj: dict[str, object]) -> dict[str, Tensor]: + out: dict[str, Tensor] = {} + qmeta = obj.get("qmeta", {}) + passthrough_orig_dtypes = obj.get("passthrough_orig_dtypes", {}) + for name, q in obj["quantized"].items(): + dtype = getattr(torch, obj["dtypes"][name]) + s = obj["scales"][name] + if qmeta.get(name, {}).get("scheme") == "per_row" or s.ndim > 0: + s = s.to(dtype=torch.float32) + # Broadcast the saved row scale back across trailing dimensions. + out[name] = (q.float() * s.view(q.shape[0], *([1] * (q.ndim - 1)))).to(dtype=dtype).contiguous() + else: + scale = float(s.item()) + out[name] = (q.float() * scale).to(dtype=dtype).contiguous() + for name, t in obj["passthrough"].items(): + # Restore small tensors, undoing the temporary fp16 storage cast if needed. 
+ out_t = t.detach().to("cpu").contiguous() + orig_dtype = passthrough_orig_dtypes.get(name) + if isinstance(orig_dtype, str): + out_t = out_t.to(dtype=getattr(torch, orig_dtype)).contiguous() + out[name] = out_t + return out + + +# ----------------------------- +# DATA LOADING +# ----------------------------- + +def load_data_shard(file: Path) -> Tensor: + header_bytes = 256 * np.dtype(" None: + self.file_idx = (self.file_idx + 1) % len(self.files) + self.tokens = load_data_shard(self.files[self.file_idx]) + self.pos = 0 + + def take(self, n: int) -> Tensor: + chunks: list[Tensor] = [] + remaining = n + while remaining > 0: + avail = self.tokens.numel() - self.pos + if avail <= 0: + self._advance_file() + continue + k = min(remaining, avail) + chunks.append(self.tokens[self.pos : self.pos + k]) + self.pos += k + remaining -= k + return chunks[0] if len(chunks) == 1 else torch.cat(chunks) + + +class DistributedTokenLoader: + # Each call consumes a contiguous chunk from the shared token stream, then slices out + # one disjoint span per rank. The extra "+1" token lets us build (x, y) by shifting. 
+ def __init__(self, pattern: str, rank: int, world_size: int, device: torch.device): + self.rank = rank + self.world_size = world_size + self.device = device + self.stream = TokenStream(pattern) + + def next_batch(self, global_tokens: int, seq_len: int, grad_accum_steps: int) -> tuple[Tensor, Tensor]: + local_tokens = global_tokens // (self.world_size * grad_accum_steps) + per_rank_span = local_tokens + 1 + chunk = self.stream.take(per_rank_span * self.world_size) + start = self.rank * per_rank_span + local = chunk[start : start + per_rank_span].to(dtype=torch.int64) + x = local[:-1].reshape(-1, seq_len) + y = local[1:].reshape(-1, seq_len) + return x.to(self.device, non_blocking=True), y.to(self.device, non_blocking=True) + +# ----------------------------- +# TRANSFORMER MODULES +# ----------------------------- + +class RMSNorm(nn.Module): + def __init__(self, eps: float | None = None): + super().__init__() + self.eps = eps + + def forward(self, x: Tensor) -> Tensor: + return F.rms_norm(x, (x.size(-1),), eps=self.eps) + + +class CastedLinear(nn.Linear): + # Keep weights in fp32 for optimizer/state quality, cast at matmul time for bf16 compute. + def forward(self, x: Tensor) -> Tensor: + bias = self.bias.to(x.dtype) if self.bias is not None else None + return F.linear(x, self.weight.to(x.dtype), bias) + + +def restore_low_dim_params_to_fp32(module: nn.Module) -> None: + # Keep small/control parameters in fp32 even when the model body runs in bf16. + with torch.no_grad(): + for name, param in module.named_parameters(): + if (param.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS)) and param.dtype != torch.float32: + param.data = param.data.float() + + +class Rotary(nn.Module): + # Caches cos/sin tables per sequence length on the current device. 
+ def __init__(self, dim: int, base: float = 10000.0): + super().__init__() + inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim)) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self._seq_len_cached = 0 + self._cos_cached: Tensor | None = None + self._sin_cached: Tensor | None = None + + def forward(self, seq_len: int, device: torch.device, dtype: torch.dtype) -> tuple[Tensor, Tensor]: + if ( + self._cos_cached is None + or self._sin_cached is None + or self._seq_len_cached != seq_len + or self._cos_cached.device != device + ): + t = torch.arange(seq_len, device=device, dtype=self.inv_freq.dtype) + freqs = torch.outer(t, self.inv_freq.to(device)) + self._cos_cached = freqs.cos()[None, None, :, :] + self._sin_cached = freqs.sin()[None, None, :, :] + self._seq_len_cached = seq_len + return self._cos_cached.to(dtype=dtype), self._sin_cached.to(dtype=dtype) + + +def apply_rotary_emb(x: Tensor, cos: Tensor, sin: Tensor) -> Tensor: + half = x.size(-1) // 2 + x1, x2 = x[..., :half], x[..., half:] + return torch.cat((x1 * cos + x2 * sin, x1 * (-sin) + x2 * cos), dim=-1) + + +class CausalSelfAttention(nn.Module): + def __init__( + self, + dim: int, + num_heads: int, + num_kv_heads: int, + rope_base: float, + qk_gain_init: float, + ): + super().__init__() + if dim % num_heads != 0: + raise ValueError("model_dim must be divisible by num_heads") + if num_heads % num_kv_heads != 0: + raise ValueError("num_heads must be divisible by num_kv_heads") + self.num_heads = num_heads + self.num_kv_heads = num_kv_heads + self.head_dim = dim // num_heads + if self.head_dim % 2 != 0: + raise ValueError("head_dim must be even for RoPE") + kv_dim = self.num_kv_heads * self.head_dim + self.c_q = CastedLinear(dim, dim, bias=False) + self.c_k = CastedLinear(dim, kv_dim, bias=False) + self.c_v = CastedLinear(dim, kv_dim, bias=False) + self.proj = CastedLinear(dim, dim, bias=False) + self.proj._zero_init = True + self.q_gain = 
nn.Parameter(torch.full((num_heads,), qk_gain_init, dtype=torch.float32)) + self.rotary = Rotary(self.head_dim, base=rope_base) + + def forward(self, x: Tensor) -> Tensor: + bsz, seqlen, dim = x.shape + q = self.c_q(x).reshape(bsz, seqlen, self.num_heads, self.head_dim).transpose(1, 2) + k = self.c_k(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) + v = self.c_v(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) + q = F.rms_norm(q, (q.size(-1),)) + k = F.rms_norm(k, (k.size(-1),)) + cos, sin = self.rotary(seqlen, x.device, q.dtype) + q = apply_rotary_emb(q, cos, sin) + k = apply_rotary_emb(k, cos, sin) + q = q * self.q_gain.to(dtype=q.dtype)[None, :, None, None] + if FA3_AVAILABLE: + # FA3 expects (B, T, H, D) — transpose from (B, H, T, D) + y = fa3.flash_attn_func( + q.transpose(1, 2), + k.transpose(1, 2), + v.transpose(1, 2), + causal=True, + ).transpose(1, 2) + else: + y = F.scaled_dot_product_attention( + q, k, v, + attn_mask=None, + is_causal=True, + enable_gqa=(self.num_kv_heads != self.num_heads), + ) + y = y.transpose(1, 2).contiguous().reshape(bsz, seqlen, dim) + return self.proj(y) + + +class MLP(nn.Module): + # leaky_relu^2 MLP (slope=0.5 matches SOTA, PRs #493/#518) + def __init__(self, dim: int, mlp_mult: int, leaky_relu_slope: float = 0.5): + super().__init__() + hidden = mlp_mult * dim + self.fc = CastedLinear(dim, hidden, bias=False) + self.proj = CastedLinear(hidden, dim, bias=False) + self.proj._zero_init = True + self.leaky_relu_slope = leaky_relu_slope + + def forward(self, x: Tensor) -> Tensor: + x = F.leaky_relu(self.fc(x), negative_slope=self.leaky_relu_slope) + return self.proj(x.square()) + + +class TimestepScaling(nn.Module): + """Learned per-iteration scale vectors (arXiv:2410.01405, ICML 2025).""" + def __init__(self, effective_layers: int, dim: int): + super().__init__() + self.attn_gamma = nn.Parameter(torch.ones(effective_layers, dim, dtype=torch.float32)) + self.mlp_gamma = 
nn.Parameter(torch.ones(effective_layers, dim, dtype=torch.float32)) + + +class Block(nn.Module): + def __init__( + self, + dim: int, + num_heads: int, + num_kv_heads: int, + mlp_mult: int, + rope_base: float, + qk_gain_init: float, + use_peri_norm: bool = False, + use_birkhoff_mix: bool = False, + leaky_relu_slope: float = 0.5, + ): + super().__init__() + self.use_peri_norm = use_peri_norm + self.use_birkhoff_mix = use_birkhoff_mix + self.attn_norm = RMSNorm() + if use_peri_norm: + self.mlp_out_norm = RMSNorm() # output norm: bounds MLP contribution, preserves input magnitude + else: + self.mlp_norm = RMSNorm() + self.attn = CausalSelfAttention(dim, num_heads, num_kv_heads, rope_base, qk_gain_init) + self.mlp = MLP(dim, mlp_mult, leaky_relu_slope) + self.attn_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.mlp_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + if use_birkhoff_mix: + # Birkhoff B₂ parameterization: sigmoid(logit) ∈ (0,1) → doubly stochastic + self.resid_mix_logit = nn.Parameter(torch.zeros(dim, dtype=torch.float32)) + else: + self.resid_mix = nn.Parameter(torch.stack((torch.ones(dim), torch.zeros(dim))).float()) + + def forward(self, x: Tensor, x0: Tensor, + ts_attn_gamma: Tensor | None = None, + ts_mlp_gamma: Tensor | None = None) -> Tensor: + if self.use_birkhoff_mix: + alpha = torch.sigmoid(self.resid_mix_logit.to(dtype=x.dtype))[None, None, :] + x = alpha * x + (1 - alpha) * x0 + else: + mix = self.resid_mix.to(dtype=x.dtype) + x = mix[0][None, None, :] * x + mix[1][None, None, :] * x0 + attn_out = self.attn(self.attn_norm(x)) + attn_s = self.attn_scale.to(dtype=x.dtype)[None, None, :] + if ts_attn_gamma is not None: + attn_s = attn_s * ts_attn_gamma[None, None, :] + x = x + attn_s * attn_out + mlp_out = self.mlp_out_norm(self.mlp(x)) if self.use_peri_norm else self.mlp(self.mlp_norm(x)) + mlp_s = self.mlp_scale.to(dtype=x.dtype)[None, None, :] + if ts_mlp_gamma is not None: + mlp_s = mlp_s * ts_mlp_gamma[None, None, 
:] + x = x + mlp_s * mlp_out + return x + + +class GPT(nn.Module): + def __init__( + self, + vocab_size: int, + num_layers: int, + model_dim: int, + num_heads: int, + num_kv_heads: int, + mlp_mult: int, + tie_embeddings: bool, + tied_embed_init_std: float, + logit_softcap: float, + rope_base: float, + qk_gain_init: float, + num_prelude: int = 0, + num_coda: int = 0, + num_shared: int = 0, + num_loops: int = 2, + use_peri_norm: bool = False, + use_birkhoff_mix: bool = False, + use_timestep_scale: bool = False, + leaky_relu_slope: float = 0.5, + ): + super().__init__() + if logit_softcap <= 0.0: + raise ValueError(f"logit_softcap must be positive, got {logit_softcap}") + self.tie_embeddings = tie_embeddings + self.tied_embed_init_std = tied_embed_init_std + self.logit_softcap = logit_softcap + self.tok_emb = nn.Embedding(vocab_size, model_dim) + self.use_recurrence = num_shared > 0 + self.num_loops = num_loops if self.use_recurrence else 1 + self.num_prelude = num_prelude if self.use_recurrence else 0 + self.num_coda = num_coda if self.use_recurrence else 0 + + block_kwargs = dict( + dim=model_dim, num_heads=num_heads, num_kv_heads=num_kv_heads, + mlp_mult=mlp_mult, rope_base=rope_base, qk_gain_init=qk_gain_init, + use_peri_norm=use_peri_norm if self.use_recurrence else False, + use_birkhoff_mix=use_birkhoff_mix if self.use_recurrence else False, + leaky_relu_slope=leaky_relu_slope, + ) + + if self.use_recurrence: + self.prelude_blocks = nn.ModuleList([Block(**block_kwargs) for _ in range(self.num_prelude)]) + self.shared_blocks = nn.ModuleList([Block(**block_kwargs) for _ in range(num_shared)]) + self.coda_blocks = nn.ModuleList([Block(**block_kwargs) for _ in range(self.num_coda)]) + effective_layers = self.num_prelude + num_shared * self.num_loops + self.num_coda + self.timestep_scale = TimestepScaling(effective_layers, model_dim) if use_timestep_scale else None + else: + # Standard U-Net path + self.num_encoder_layers = num_layers // 2 + self.num_decoder_layers = 
num_layers - self.num_encoder_layers + self.num_skip_weights = min(self.num_encoder_layers, self.num_decoder_layers) + self.skip_weights = nn.Parameter(torch.ones(self.num_skip_weights, model_dim, dtype=torch.float32)) + self.blocks = nn.ModuleList([Block(**block_kwargs) for _ in range(num_layers)]) + self.timestep_scale = None + + self.final_norm = RMSNorm() + self.lm_head = None if tie_embeddings else CastedLinear(model_dim, vocab_size, bias=False) + if self.lm_head is not None: + self.lm_head._zero_init = True + self._init_weights() + + def _init_weights(self) -> None: + if self.tie_embeddings: + nn.init.normal_(self.tok_emb.weight, mean=0.0, std=self.tied_embed_init_std) + for module in self.modules(): + if isinstance(module, nn.Linear) and getattr(module, "_zero_init", False): + nn.init.zeros_(module.weight) + + def _get_ts(self, v: int) -> tuple[Tensor | None, Tensor | None]: + if self.timestep_scale is None: + return None, None + return (self.timestep_scale.attn_gamma[v], + self.timestep_scale.mlp_gamma[v]) + + def forward(self, input_ids: Tensor, target_ids: Tensor) -> Tensor: + x = self.tok_emb(input_ids) + x = F.rms_norm(x, (x.size(-1),)) + x0 = x + + if self.use_recurrence: + v = 0 + for block in self.prelude_blocks: + ag, mg = self._get_ts(v) + x = block(x, x0, ag, mg) + v += 1 + for _loop in range(self.num_loops): + for block in self.shared_blocks: + ag, mg = self._get_ts(v) + x = block(x, x0, ag, mg) + v += 1 + for block in self.coda_blocks: + ag, mg = self._get_ts(v) + x = block(x, x0, ag, mg) + v += 1 + else: + skips: list[Tensor] = [] + for i in range(self.num_encoder_layers): + x = self.blocks[i](x, x0) + skips.append(x) + for i in range(self.num_decoder_layers): + if skips: + x = x + self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips.pop() + x = self.blocks[self.num_encoder_layers + i](x, x0) + + x = self.final_norm(x).reshape(-1, x.size(-1)) + targets = target_ids.reshape(-1) + if self.tie_embeddings: + logits_proj = F.linear(x, 
self.tok_emb.weight) + else: + if self.lm_head is None: + raise RuntimeError("lm_head is required when tie_embeddings=False") + logits_proj = self.lm_head(x) + logits = self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap) + return F.cross_entropy(logits.float(), targets, reduction="mean") + + +@torch.no_grad() +def recurrence_param_diagnostics(gpt: GPT) -> str: + """Log parameter norms for recurrence health monitoring. No forward pass.""" + if not gpt.use_recurrence: + return "" + parts: list[str] = [] + + # Birkhoff alpha stats per shared block + for i, block in enumerate(gpt.shared_blocks): + if hasattr(block, "resid_mix_logit"): + a = torch.sigmoid(block.resid_mix_logit) + parts.append(f"shared{i}_alpha:mean={a.mean().item():.3f},std={a.std().item():.3f}") + + # Effective MLP/attn contribution scale per virtual layer position: + # mlp_scale * ts_mlp_gamma[v] is the learned multiplier on MLP output. + # We report RMS (norm / sqrt(numel)) to make it scale-independent. + v = 0 + all_blocks = ( + list(gpt.prelude_blocks) + + list(gpt.shared_blocks) * gpt.num_loops + + list(gpt.coda_blocks) + ) + mlp_norms: list[str] = [] + attn_norms: list[str] = [] + for block in all_blocks: + ms = block.mlp_scale.norm().item() + asc = block.attn_scale.norm().item() + d = block.mlp_scale.numel() ** 0.5 + if gpt.timestep_scale is not None: + ms_g = gpt.timestep_scale.mlp_gamma[v].norm().item() + as_g = gpt.timestep_scale.attn_gamma[v].norm().item() + mlp_norms.append(f"v{v}:{ms * ms_g / d:.4f}") + attn_norms.append(f"v{v}:{asc * as_g / d:.4f}") + else: + mlp_norms.append(f"v{v}:{ms / d:.4f}") + attn_norms.append(f"v{v}:{asc / d:.4f}") + v += 1 + parts.append("eff_mlp_scale:[" + " ".join(mlp_norms) + "]") + parts.append("eff_attn_scale:[" + " ".join(attn_norms) + "]") + return " ".join(parts) + + +# ----------------------------- +# TRAINING +# ----------------------------- + +def main() -> None: + global zeropower_via_newtonschulz5 + + code = 
Path(__file__).read_text(encoding="utf-8")
    args = Hyperparameters()
    # Compile the Newton-Schulz orthogonalizer once; Muon calls it per matrix param.
    zeropower_via_newtonschulz5 = torch.compile(zeropower_via_newtonschulz5)

    # -----------------------------
    # DISTRIBUTED + CUDA SETUP
    # -----------------------------

    distributed = "RANK" in os.environ and "WORLD_SIZE" in os.environ
    rank = int(os.environ.get("RANK", "0"))
    world_size = int(os.environ.get("WORLD_SIZE", "1"))
    local_rank = int(os.environ.get("LOCAL_RANK", "0"))
    if world_size <= 0:
        raise ValueError(f"WORLD_SIZE must be positive, got {world_size}")
    # The global batch is split into 8 micro-batches total, so world size must divide 8.
    if 8 % world_size != 0:
        raise ValueError(f"WORLD_SIZE={world_size} must divide 8 so grad_accum_steps stays integral")
    grad_accum_steps = 8 // world_size
    grad_scale = 1.0 / grad_accum_steps
    if not torch.cuda.is_available():
        raise RuntimeError("CUDA is required")
    device = torch.device("cuda", local_rank)
    torch.cuda.set_device(device)
    if distributed:
        dist.init_process_group(backend="nccl", device_id=device)
        dist.barrier()
    master_process = rank == 0

    # Fast math knobs
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True
    from torch.backends.cuda import enable_cudnn_sdp, enable_flash_sdp, enable_math_sdp, enable_mem_efficient_sdp

    # Pin SDPA to the flash backend only.
    enable_cudnn_sdp(False)
    enable_flash_sdp(True)
    enable_mem_efficient_sdp(False)
    enable_math_sdp(False)

    logfile = None
    if master_process:
        os.makedirs("logs", exist_ok=True)
        logfile = f"logs/{args.run_id}.txt"
        print(logfile)

    def log0(msg: str, console: bool = True) -> None:
        # Rank-0-only logger: optionally echo to stdout, always append to the run log.
        if not master_process:
            return
        if console:
            print(msg)
        if logfile is not None:
            with open(logfile, "a", encoding="utf-8") as f:
                print(msg, file=f)

    # Embed the full script source plus environment info in the log for reproducibility.
    log0(code, console=False)
    log0("=" * 100, console=False)
    log0(f"Running Python {sys.version}", console=False)
    log0(f"Running PyTorch {torch.__version__}", console=False)
    log0(
        subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, check=False).stdout,
        console=False,
    )
    log0("=" * 100, console=False)

    # -----------------------------
    # TOKENIZER + VALIDATION METRIC SETUP
    # -----------------------------

    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)

    if not args.tokenizer_path.endswith(".model"):
        raise ValueError(f"Script only setup for SentencePiece .model file: {args.tokenizer_path}")
    sp = spm.SentencePieceProcessor(model_file=args.tokenizer_path)
    if int(sp.vocab_size()) != args.vocab_size:
        raise ValueError(
            f"VOCAB_SIZE={args.vocab_size} does not match tokenizer vocab_size={int(sp.vocab_size())}"
        )
    dataset_dir = Path(args.data_path).resolve()
    actual_train_files = len(list(dataset_dir.glob("fineweb_train_*.bin")))
    # load_validation_tokens / build_sentencepiece_luts are defined earlier in
    # this file (outside this chunk); the LUTs drive the bytes-per-byte metric.
    val_tokens = load_validation_tokens(args.val_files, args.train_seq_len)
    base_bytes_lut, has_leading_space_lut, is_boundary_token_lut = build_sentencepiece_luts(
        sp, args.vocab_size, device
    )
    log0(f"val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path={args.tokenizer_path}")
    log0(f"train_loader:dataset:{dataset_dir.name} train_shards:{actual_train_files}")
    log0(f"val_loader:shards pattern={args.val_files} tokens:{val_tokens.numel() - 1}")

    # -----------------------------
    # MODEL + OPTIMIZER SETUP
    # -----------------------------

    base_model = GPT(
        vocab_size=args.vocab_size,
        num_layers=args.num_layers,
        model_dim=args.model_dim,
        num_heads=args.num_heads,
        num_kv_heads=args.num_kv_heads,
        mlp_mult=args.mlp_mult,
        tie_embeddings=args.tie_embeddings,
        tied_embed_init_std=args.tied_embed_init_std,
        logit_softcap=args.logit_softcap,
        rope_base=args.rope_base,
        qk_gain_init=args.qk_gain_init,
        num_prelude=args.num_prelude,
        num_coda=args.num_coda,
        num_shared=args.num_shared,
        num_loops=args.num_loops,
        use_peri_norm=args.use_peri_norm,
        use_birkhoff_mix=args.use_birkhoff_mix,
        use_timestep_scale=args.use_timestep_scale,
        leaky_relu_slope=args.leaky_relu_slope,
    ).to(device).bfloat16()
    # Model runs in bf16 except CastedLinear weights and low-dim params (fp32).
    for module in base_model.modules():
        if isinstance(module, CastedLinear):
            module.float()
    restore_low_dim_params_to_fp32(base_model)
    if args.disable_compile:
        compiled_model = base_model
    else:
        compiled_model = torch.compile(base_model, dynamic=False, fullgraph=True)
    model: nn.Module = DDP(compiled_model, device_ids=[local_rank], broadcast_buffers=False) if distributed else compiled_model

    # Optimizer split:
    # - token embedding (Adam) uses EMBED_LR
    # - untied lm_head (Adam) uses HEAD_LR
    # - matrix params in transformer blocks use MATRIX_LR via Muon
    # - vectors/scalars use SCALAR_LR via Adam
    if base_model.use_recurrence:
        all_block_lists = [base_model.prelude_blocks, base_model.shared_blocks, base_model.coda_blocks]
        block_named_params = []
        for bl in all_block_lists:
            block_named_params.extend(bl.named_parameters())
    else:
        block_named_params = list(base_model.blocks.named_parameters())
    # 2-D params go to Muon unless name-matched as control tensors
    # (CONTROL_TENSOR_NAME_PATTERNS is defined elsewhere in this file).
    matrix_params = [
        p
        for name, p in block_named_params
        if p.ndim == 2 and not any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS)
    ]
    scalar_params = [
        p
        for name, p in block_named_params
        if p.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS)
    ]
    if hasattr(base_model, "skip_weights") and base_model.skip_weights.numel() > 0:
        scalar_params.append(base_model.skip_weights)
    if base_model.timestep_scale is not None:
        for p in base_model.timestep_scale.parameters():
            scalar_params.append(p)
    token_lr = args.tied_embed_lr if args.tie_embeddings else args.embed_lr
    # `base_lr` is stashed per group so the warmdown multiplier can rescale `lr` each step.
    optimizer_tok = torch.optim.Adam(
        [{"params": [base_model.tok_emb.weight], "lr": token_lr, "base_lr": token_lr}],
        betas=(args.beta1, args.beta2),
        eps=args.adam_eps,
        fused=True,
    )
    optimizer_muon = Muon(
        matrix_params,
        lr=args.matrix_lr,
        momentum=args.muon_momentum,
        backend_steps=args.muon_backend_steps,
    )
    for group in optimizer_muon.param_groups:
        group["base_lr"] = args.matrix_lr
    optimizer_scalar = torch.optim.Adam(
        [{"params": scalar_params, "lr": args.scalar_lr, "base_lr": args.scalar_lr}],
        betas=(args.beta1, args.beta2),
        eps=args.adam_eps,
        fused=True,
    )
    optimizers: list[torch.optim.Optimizer] = [optimizer_tok, optimizer_muon, optimizer_scalar]
    if base_model.lm_head is not None:
        optimizer_head = torch.optim.Adam(
            [{"params": [base_model.lm_head.weight], "lr": args.head_lr, "base_lr": args.head_lr}],
            betas=(args.beta1, args.beta2),
            eps=args.adam_eps,
            fused=True,
        )
        optimizers.insert(1, optimizer_head)

    n_params = sum(p.numel() for p in base_model.parameters())
    log0(f"model_params:{n_params}")
    log0(f"world_size:{world_size} grad_accum_steps:{grad_accum_steps}")
    log0("sdp_backends:cudnn=False flash=True mem_efficient=False math=False")
    log0(f"flash_attention_3:{FA3_AVAILABLE}")
    log0(f"attention_mode:gqa num_heads:{args.num_heads} num_kv_heads:{args.num_kv_heads}")
    log0(
        f"tie_embeddings:{args.tie_embeddings} embed_lr:{token_lr} "
        f"head_lr:{args.head_lr if base_model.lm_head is not None else 0.0} "
        f"matrix_lr:{args.matrix_lr} scalar_lr:{args.scalar_lr}"
    )
    log0(
        f"train_batch_tokens:{args.train_batch_tokens} train_seq_len:{args.train_seq_len} "
        f"iterations:{args.iterations} warmup_steps:{args.warmup_steps} "
        f"max_wallclock_seconds:{args.max_wallclock_seconds:.3f}"
    )
    log0(f"seed:{args.seed}")
    if base_model.use_recurrence:
        eff = base_model.num_prelude + len(base_model.shared_blocks) * base_model.num_loops + base_model.num_coda
        log0(f"recurrence:enabled prelude:{base_model.num_prelude} shared:{len(base_model.shared_blocks)} "
             f"loops:{base_model.num_loops} coda:{base_model.num_coda} effective_layers:{eff}")
        log0(f"peri_norm:{args.use_peri_norm} birkhoff_mix:{args.use_birkhoff_mix}")
        if base_model.timestep_scale is not None:
            ts_params = sum(p.numel() for p in base_model.timestep_scale.parameters())
            log0(f"timestep_scale:enabled params:{ts_params}")
        else:
            log0("timestep_scale:disabled")
    else:
        log0(f"recurrence:disabled num_layers:{args.num_layers}")
    compile_mode = "disabled" if args.disable_compile else "fullgraph=True"
    log0(f"compile_mode:{compile_mode}")

    # -----------------------------
    # DATA LOADER & MODEL WARMUP
    # -----------------------------

    train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device)

    def zero_grad_all() -> None:
        for opt in optimizers:
            opt.zero_grad(set_to_none=True)

    max_wallclock_ms = 1000.0 * args.max_wallclock_seconds if args.max_wallclock_seconds > 0 else None

    def lr_mul(step: int, elapsed_ms: float) -> float:
        # LR multiplier: 1.0 during the main phase, then a linear warmdown either
        # over the last WARMDOWN_ITERS steps (no wallclock cap) or over the
        # projected remaining wallclock budget (cap set), using the observed
        # average step time as the projection.
        if args.warmdown_iters <= 0:
            return 1.0
        if max_wallclock_ms is None:
            warmdown_start = max(args.iterations - args.warmdown_iters, 0)
            return max((args.iterations - step) / max(args.warmdown_iters, 1), 0.0) if warmdown_start <= step < args.iterations else 1.0
        step_ms = elapsed_ms / max(step, 1)
        warmdown_ms = args.warmdown_iters * step_ms
        remaining_ms = max(max_wallclock_ms - elapsed_ms, 0.0)
        return remaining_ms / max(warmdown_ms, 1e-9) if remaining_ms <= warmdown_ms else 1.0

    # Warmup primes the compiled forward/backward/optimizer paths, then we restore the
    # initial weights/optimizer state so measured training starts from the true init.
    if args.warmup_steps > 0:
        initial_model_state = {name: tensor.detach().cpu().clone() for name, tensor in base_model.state_dict().items()}
        initial_optimizer_states = [copy.deepcopy(opt.state_dict()) for opt in optimizers]
        model.train()
        for warmup_step in range(args.warmup_steps):
            zero_grad_all()
            for micro_step in range(grad_accum_steps):
                if distributed:
                    # Only sync gradients on the last micro-batch of the accumulation window.
                    model.require_backward_grad_sync = micro_step == grad_accum_steps - 1
                x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps)
                with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True):
                    warmup_loss = model(x, y)
                (warmup_loss * grad_scale).backward()
            for opt in optimizers:
                opt.step()
            zero_grad_all()
            if args.warmup_steps <= 20 or (warmup_step + 1) % 10 == 0 or warmup_step + 1 == args.warmup_steps:
                log0(f"warmup_step:{warmup_step + 1}/{args.warmup_steps}")
        base_model.load_state_dict(initial_model_state, strict=True)
        for opt, state in zip(optimizers, initial_optimizer_states, strict=True):
            opt.load_state_dict(state)
        zero_grad_all()
        if distributed:
            model.require_backward_grad_sync = True
        # Fresh loader so measured training consumes data from the start.
        train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device)

    # -----------------------------
    # MAIN TRAINING LOOP
    # -----------------------------

    training_time_ms = 0.0
    stop_after_step: int | None = None
    torch.cuda.synchronize()
    t0 = time.perf_counter()

    step = 0
    while True:
        last_step = step == args.iterations or (stop_after_step is not None and step >= stop_after_step)

        should_validate = last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0)
        if should_validate:
            # Pause the training clock around validation so it is not billed
            # against the wallclock budget.
            torch.cuda.synchronize()
            training_time_ms += 1000.0 * (time.perf_counter() - t0)
            val_loss, val_bpb = eval_val(
                args,
                model,
                rank,
                world_size,
                device,
                grad_accum_steps,
                val_tokens,
                base_bytes_lut,
                has_leading_space_lut,
                is_boundary_token_lut,
            )
            log0(
                f"step:{step}/{args.iterations} val_loss:{val_loss:.4f} val_bpb:{val_bpb:.4f} "
                f"train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms / max(step, 1):.2f}ms"
            )
            torch.cuda.synchronize()
            t0 = time.perf_counter()

        if last_step:
            if stop_after_step is not None and step < args.iterations:
                log0(
                    f"stopping_early: wallclock_cap train_time:{training_time_ms:.0f}ms "
                    f"step:{step}/{args.iterations}"
                )
            break

        elapsed_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0)
        scale = lr_mul(step, elapsed_ms)
        zero_grad_all()
        train_loss = torch.zeros((), device=device)
        for micro_step in range(grad_accum_steps):
            if distributed:
                model.require_backward_grad_sync = micro_step == grad_accum_steps - 1
            x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps)
            with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True):
                loss = model(x, y)
            train_loss += loss.detach()
            (loss * grad_scale).backward()
        train_loss /= grad_accum_steps

        # Linear ramp of Muon momentum from warmup_start to its final value.
        frac = min(step / args.muon_momentum_warmup_steps, 1.0) if args.muon_momentum_warmup_steps > 0 else 1.0
        muon_momentum = (1 - frac) * args.muon_momentum_warmup_start + frac * args.muon_momentum
        for group in optimizer_muon.param_groups:
            group["momentum"] = muon_momentum

        for opt in optimizers:
            for group in opt.param_groups:
                group["lr"] = group["base_lr"] * scale

        if args.grad_clip_norm > 0:
            torch.nn.utils.clip_grad_norm_(base_model.parameters(), args.grad_clip_norm)
        for opt in optimizers:
            opt.step()
        zero_grad_all()

        step += 1
        approx_training_time_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0)
        should_log_train = (
            args.train_log_every > 0
            and (step <= 10 or step % args.train_log_every == 0 or stop_after_step is not None)
        )
        if should_log_train:
            log0(
                f"step:{step}/{args.iterations} train_loss:{train_loss.item():.4f} "
                f"train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms / step:.2f}ms"
            )
            if base_model.use_recurrence and step % args.train_log_every == 0:
                diag = recurrence_param_diagnostics(base_model)
                if diag:
                    log0(f"step:{step} {diag}")

        # Needed to sync whether we've reached the wallclock cap.
        reached_cap = max_wallclock_ms is not None and approx_training_time_ms >= max_wallclock_ms
        if distributed and max_wallclock_ms is not None:
            reached_cap_tensor = torch.tensor(int(reached_cap), device=device)
            dist.all_reduce(reached_cap_tensor, op=dist.ReduceOp.MAX)
            reached_cap = bool(reached_cap_tensor.item())
        if stop_after_step is None and reached_cap:
            stop_after_step = step

    log0(
        f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB "
        f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB"
    )

    # -----------------------------
    # SERIALIZATION + ROUNDTRIP VALIDATION
    # -----------------------------
    # Save the raw state (useful for debugging/loading in PyTorch directly), then always produce
    # the compressed int8+zlib artifact and validate the round-tripped weights.

    if master_process:
        torch.save(base_model.state_dict(), "final_model.pt")
        model_bytes = os.path.getsize("final_model.pt")
        code_bytes = len(code.encode("utf-8"))
        log0(f"Serialized model: {model_bytes} bytes")
        log0(f"Code size: {code_bytes} bytes")
        log0(f"Total submission size: {model_bytes + code_bytes} bytes")

    # Quantization runs on every rank (cheap) but only rank 0 writes the artifact.
    quant_obj, quant_stats = quantize_state_dict_int8(base_model.state_dict())
    quant_buf = io.BytesIO()
    torch.save(quant_obj, quant_buf)
    quant_raw = quant_buf.getvalue()
    quant_blob = zlib.compress(quant_raw, level=9)
    quant_raw_bytes = len(quant_raw)
    if master_process:
        with open("final_model.int8.ptz", "wb") as f:
            f.write(quant_blob)
        quant_file_bytes = os.path.getsize("final_model.int8.ptz")
        code_bytes = len(code.encode("utf-8"))
        ratio = quant_stats["baseline_tensor_bytes"] / max(quant_stats["int8_payload_bytes"], 1)
        log0(
            f"Serialized model int8+zlib: {quant_file_bytes} bytes "
            f"(payload:{quant_stats['int8_payload_bytes']} raw_torch:{quant_raw_bytes} payload_ratio:{ratio:.2f}x)"
        )
        log0(f"Total submission size int8+zlib: {quant_file_bytes + code_bytes} bytes")

    if distributed:
        # Barrier so non-master ranks wait for the artifact before reading it back.
        dist.barrier()
    # Roundtrip: reload the compressed int8 weights and re-evaluate to confirm
    # the quantized artifact reproduces the validation metrics.
    # NOTE(review): all ranks read the file — assumes a shared filesystem when
    # multi-node; confirm for the target cluster.
    with open("final_model.int8.ptz", "rb") as f:
        quant_blob_disk = f.read()
    quant_state = torch.load(io.BytesIO(zlib.decompress(quant_blob_disk)), map_location="cpu")
    base_model.load_state_dict(dequantize_state_dict_int8(quant_state), strict=True)
    torch.cuda.synchronize()
    t_qeval = time.perf_counter()
    q_val_loss, q_val_bpb = eval_val(
        args,
        model,
        rank,
        world_size,
        device,
        grad_accum_steps,
        val_tokens,
        base_bytes_lut,
        has_leading_space_lut,
        is_boundary_token_lut,
    )
    torch.cuda.synchronize()
    log0(
        f"final_int8_zlib_roundtrip val_loss:{q_val_loss:.4f} val_bpb:{q_val_bpb:.4f} "
        f"eval_time:{1000.0 * (time.perf_counter() - t_qeval):.0f}ms"
    )
    log0(f"final_int8_zlib_roundtrip_exact val_loss:{q_val_loss:.8f} val_bpb:{q_val_bpb:.8f}")

    if distributed:
        dist.destroy_process_group()
+ + +if __name__ == "__main__": + main() + +==================================================================================================== +Running Python 3.12.3 (main, Nov 6 2025, 13:44:16) [GCC 13.3.0] +Running PyTorch 2.9.1+cu128 +Wed Mar 25 14:36:13 2026 ++-----------------------------------------------------------------------------------------+ +| NVIDIA-SMI 565.57.01 Driver Version: 565.57.01 CUDA Version: 12.7 | +|-----------------------------------------+------------------------+----------------------+ +| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC | +| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. | +| | | MIG M. | +|=========================================+========================+======================| +| 0 NVIDIA H100 80GB HBM3 On | 00000000:BE:00.0 Off | 0 | +| N/A 44C P0 125W / 700W | 1184MiB / 81559MiB | 1% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ + ++-----------------------------------------------------------------------------------------+ +| Processes: | +| GPU GI CI PID Type Process name GPU Memory | +| ID ID Usage | +|=========================================================================================| ++-----------------------------------------------------------------------------------------+ + +==================================================================================================== +val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path=./data/tokenizers/fineweb_1024_bpe.model +train_loader:dataset:fineweb10B_sp1024 train_shards:10 +val_loader:shards pattern=./data/datasets/fineweb10B_sp1024/fineweb_val_*.bin tokens:62021632 +model_params:7878688 +world_size:1 grad_accum_steps:8 +sdp_backends:cudnn=False flash=True mem_efficient=False math=False +flash_attention_3:True +attention_mode:gqa num_heads:8 num_kv_heads:4 +tie_embeddings:True embed_lr:0.05 head_lr:0.0 matrix_lr:0.04 scalar_lr:0.04 
+train_batch_tokens:524288 train_seq_len:1024 iterations:2000 warmup_steps:20 max_wallclock_seconds:0.000 +seed:1337 +recurrence:enabled prelude:1 shared:2 loops:3 coda:1 effective_layers:8 +peri_norm:True birkhoff_mix:True +timestep_scale:enabled params:8192 +compile_mode:fullgraph=True +warmup_step:1/20 +warmup_step:2/20 +warmup_step:3/20 +warmup_step:4/20 +warmup_step:5/20 +warmup_step:6/20 +warmup_step:7/20 +warmup_step:8/20 +warmup_step:9/20 +warmup_step:10/20 +warmup_step:11/20 +warmup_step:12/20 +warmup_step:13/20 +warmup_step:14/20 +warmup_step:15/20 +warmup_step:16/20 +warmup_step:17/20 +warmup_step:18/20 +warmup_step:19/20 +warmup_step:20/20 +step:0/2000 val_loss:6.9365 val_bpb:4.1082 train_time:0ms step_avg:0.02ms +step:1/2000 train_loss:6.9365 train_time:298ms step_avg:298.19ms +step:2/2000 train_loss:9.4082 train_time:676ms step_avg:338.23ms +step:3/2000 train_loss:7.7045 train_time:1007ms step_avg:335.65ms +step:4/2000 train_loss:9.4127 train_time:1381ms step_avg:345.31ms +step:5/2000 train_loss:9.0550 train_time:1699ms step_avg:339.81ms +step:6/2000 train_loss:7.9648 train_time:2087ms step_avg:347.88ms +step:7/2000 train_loss:7.0051 train_time:2479ms step_avg:354.19ms +step:8/2000 train_loss:6.2667 train_time:2813ms step_avg:351.66ms +step:9/2000 train_loss:5.7315 train_time:3190ms step_avg:354.50ms +step:10/2000 train_loss:5.3680 train_time:3583ms step_avg:358.29ms +step:200/2000 train_loss:2.7398 train_time:68736ms step_avg:343.68ms +step:200 shared0_alpha:mean=0.496,std=0.043 shared1_alpha:mean=0.508,std=0.037 eff_mlp_scale:[v0:40.2916 v1:26.6734 v2:30.2281 v3:28.5377 v4:31.2756 v5:30.1152 v6:31.4253 v7:55.6959] eff_attn_scale:[v0:15.0683 v1:8.2236 v2:9.7563 v3:7.8548 v4:8.3789 v5:8.8874 v6:9.6032 v7:15.5578] +step:400/2000 train_loss:2.4364 train_time:139693ms step_avg:349.23ms +step:400 shared0_alpha:mean=0.508,std=0.047 shared1_alpha:mean=0.523,std=0.043 eff_mlp_scale:[v0:51.4735 v1:33.5732 v2:38.1169 v3:38.6254 v4:37.6384 v5:36.8327 v6:32.2159 
v7:70.9991] eff_attn_scale:[v0:6.4987 v1:4.6038 v2:5.1594 v3:4.4997 v4:4.5495 v5:4.9159 v6:4.6257 v7:9.2896] +step:500/2000 val_loss:2.5537 val_bpb:1.5124 train_time:175331ms step_avg:350.66ms +step:600/2000 train_loss:2.5673 train_time:211587ms step_avg:352.65ms +step:600 shared0_alpha:mean=0.511,std=0.052 shared1_alpha:mean=0.528,std=0.047 eff_mlp_scale:[v0:58.5086 v1:39.0374 v2:44.0241 v3:46.2797 v4:43.3519 v5:42.0403 v6:33.4381 v7:86.0336] eff_attn_scale:[v0:3.0470 v1:2.8121 v2:3.0556 v3:2.8695 v4:2.8899 v5:3.0608 v6:2.7243 v7:5.9753] +step:800/2000 train_loss:2.4278 train_time:283386ms step_avg:354.23ms +step:800 shared0_alpha:mean=0.511,std=0.056 shared1_alpha:mean=0.527,std=0.052 eff_mlp_scale:[v0:65.2341 v1:43.6184 v2:48.6029 v3:52.1930 v4:48.2532 v5:46.2281 v6:34.6164 v7:98.3803] eff_attn_scale:[v0:1.9327 v1:2.0897 v2:2.2566 v3:2.1683 v4:2.1808 v5:2.2626 v6:1.9688 v7:4.4771] +step:1000/2000 train_loss:2.4461 train_time:355607ms step_avg:355.61ms +step:1000 shared0_alpha:mean=0.510,std=0.060 shared1_alpha:mean=0.525,std=0.056 eff_mlp_scale:[v0:71.1827 v1:47.7607 v2:52.5514 v3:57.3128 v4:52.1890 v5:50.2949 v6:36.6048 v7:108.8261] eff_attn_scale:[v0:1.5093 v1:1.7716 v2:1.9197 v3:1.8708 v4:1.8921 v5:1.9133 v6:1.6642 v7:3.7793] +step:1000/2000 val_loss:2.4166 val_bpb:1.4313 train_time:355616ms step_avg:355.62ms +step:1200/2000 train_loss:2.3705 train_time:428372ms step_avg:356.98ms +step:1200 shared0_alpha:mean=0.511,std=0.062 shared1_alpha:mean=0.524,std=0.058 eff_mlp_scale:[v0:75.1625 v1:50.9809 v2:55.7047 v3:60.8547 v4:55.7047 v5:53.6005 v6:38.5073 v7:117.0229] eff_attn_scale:[v0:1.3183 v1:1.6361 v2:1.7786 v3:1.7308 v4:1.7654 v5:1.7984 v6:1.5596 v7:3.5325] +step:1400/2000 train_loss:2.4013 train_time:500785ms step_avg:357.70ms +step:1400 shared0_alpha:mean=0.511,std=0.064 shared1_alpha:mean=0.523,std=0.060 eff_mlp_scale:[v0:78.2572 v1:53.7877 v2:58.3106 v3:63.3040 v4:58.3106 v5:56.6840 v6:40.2803 v7:123.4736] eff_attn_scale:[v0:1.2333 v1:1.5842 v2:1.7134 
v3:1.6836 v4:1.7134 v5:1.7366 v6:1.5172 v7:3.5208] +step:1500/2000 val_loss:2.3434 val_bpb:1.3879 train_time:537683ms step_avg:358.46ms +step:1600/2000 train_loss:2.2890 train_time:574449ms step_avg:359.03ms +step:1600 shared0_alpha:mean=0.511,std=0.064 shared1_alpha:mean=0.523,std=0.060 eff_mlp_scale:[v0:79.9718 v1:55.7932 v2:60.1865 v3:65.5147 v4:59.7956 v5:58.7519 v6:41.4271 v7:127.5284] eff_attn_scale:[v0:1.1815 v1:1.5613 v2:1.6850 v3:1.6738 v4:1.7111 v5:1.7333 v6:1.5086 v7:3.5712] +step:1800/2000 train_loss:2.3302 train_time:648573ms step_avg:360.32ms +step:1800 shared0_alpha:mean=0.511,std=0.065 shared1_alpha:mean=0.523,std=0.061 eff_mlp_scale:[v0:81.2547 v1:56.6805 v2:61.2037 v3:66.5567 v4:61.2037 v5:59.6863 v6:42.3259 v7:130.2350] eff_attn_scale:[v0:1.1639 v1:1.5604 v2:1.7016 v3:1.7004 v4:1.7280 v5:1.7604 v6:1.5301 v7:3.5923] +step:2000/2000 train_loss:2.3441 train_time:720478ms step_avg:360.24ms +step:2000 shared0_alpha:mean=0.511,std=0.065 shared1_alpha:mean=0.523,std=0.061 eff_mlp_scale:[v0:81.5544 v1:57.0820 v2:61.6221 v3:67.0281 v4:61.6221 v5:60.1091 v6:42.6153 v7:131.3367] eff_attn_scale:[v0:1.1623 v1:1.5501 v2:1.7035 v3:1.6965 v4:1.7299 v5:1.7564 v6:1.5318 v7:3.6421] +step:2000/2000 val_loss:2.2967 val_bpb:1.3603 train_time:720488ms step_avg:360.24ms +peak memory allocated: 8065 MiB reserved: 8178 MiB +Serialized model: 30466726 bytes +Code size: 56522 bytes +Total submission size: 30523248 bytes +Serialized model int8+zlib: 6906433 bytes (payload:7932032 raw_torch:7953175 payload_ratio:3.84x) +Total submission size int8+zlib: 6962955 bytes +final_int8_zlib_roundtrip val_loss:2.3000 val_bpb:1.3622 eval_time:8592ms +final_int8_zlib_roundtrip_exact val_loss:2.30000746 val_bpb:1.36219397 diff --git a/records/track_non_record_16mb/2026-03-26_RecurrenceFix_3Loop_Birkhoff_OutputLN_TimestepScale/scripts/run_screening.sh b/records/track_non_record_16mb/2026-03-26_RecurrenceFix_3Loop_Birkhoff_OutputLN_TimestepScale/scripts/run_screening.sh new file mode 
100644 index 0000000000..35ba9f4e8e --- /dev/null +++ b/records/track_non_record_16mb/2026-03-26_RecurrenceFix_3Loop_Birkhoff_OutputLN_TimestepScale/scripts/run_screening.sh @@ -0,0 +1,66 @@ +#!/bin/bash +set -uo pipefail + +SCRIPT="train_gpt.py" +NGPU=${NGPU:-1} +COMMON="SEED=1337 ITERATIONS=2000 VAL_LOSS_EVERY=500 MAX_WALLCLOCK_SECONDS=0 TRAIN_LOG_EVERY=200" +DATA="DATA_PATH=${DATA_PATH:-./data/datasets/fineweb10B_sp1024} TOKENIZER_PATH=./data/tokenizers/fineweb_1024_bpe.model VOCAB_SIZE=1024" + +FAILS=0 +SUMMARY="" + +run_experiment() { + local name="$1"; shift + echo "" + echo "=== $name ===" + if "$@"; then + SUMMARY="${SUMMARY} PASS $name"$'\n' + else + SUMMARY="${SUMMARY} FAIL $name (exit $?)"$'\n' + FAILS=$((FAILS + 1)) + fi +} + +# --- Baselines (leaky_relu² matched) --- run first for early signal --- + +run_experiment "Run A': 8L standard (leaky_relu² baseline)" \ + env $COMMON $DATA RUN_ID=s1_Ap NUM_LAYERS=8 \ + torchrun --standalone --nproc_per_node=$NGPU $SCRIPT + +run_experiment "Run B': 4x2 bare recurrence (leaky_relu² baseline)" \ + env $COMMON $DATA RUN_ID=s1_Bp NUM_LAYERS=8 NUM_SHARED=4 NUM_LOOPS=2 \ + torchrun --standalone --nproc_per_node=$NGPU $SCRIPT + +# --- Experimental runs --- + +run_experiment "Run C: 4x2 recurrence + peri-norm + birkhoff mix" \ + env $COMMON $DATA RUN_ID=s1_C NUM_LAYERS=8 NUM_SHARED=4 NUM_LOOPS=2 \ + USE_PERI_NORM=1 USE_BIRKHOFF_MIX=1 \ + torchrun --standalone --nproc_per_node=$NGPU $SCRIPT + +run_experiment "Run C': 4x2 recurrence + birkhoff only (no peri-norm)" \ + env $COMMON $DATA RUN_ID=s1_Cp NUM_LAYERS=8 NUM_SHARED=4 NUM_LOOPS=2 \ + USE_PERI_NORM=0 USE_BIRKHOFF_MIX=1 \ + torchrun --standalone --nproc_per_node=$NGPU $SCRIPT + +run_experiment "Run D: 4x2 recurrence + peri-norm + birkhoff + timestep scaling" \ + env $COMMON $DATA RUN_ID=s1_D NUM_LAYERS=8 NUM_SHARED=4 NUM_LOOPS=2 \ + USE_PERI_NORM=1 USE_BIRKHOFF_MIX=1 USE_TIMESTEP_SCALE=1 \ + torchrun --standalone --nproc_per_node=$NGPU $SCRIPT + +run_experiment "Run 
E: 1 prelude + 3x2 shared + 1 coda + all fixes" \ + env $COMMON $DATA RUN_ID=s1_E NUM_LAYERS=8 NUM_PRELUDE=1 NUM_SHARED=3 NUM_LOOPS=2 NUM_CODA=1 \ + USE_PERI_NORM=1 USE_BIRKHOFF_MIX=1 USE_TIMESTEP_SCALE=1 \ + torchrun --standalone --nproc_per_node=$NGPU $SCRIPT + +run_experiment "Run F: 1 prelude + 2x3 shared + 1 coda + all fixes (3 loops!)" \ + env $COMMON $DATA RUN_ID=s1_F NUM_LAYERS=8 NUM_PRELUDE=1 NUM_SHARED=2 NUM_LOOPS=3 NUM_CODA=1 \ + USE_PERI_NORM=1 USE_BIRKHOFF_MIX=1 USE_TIMESTEP_SCALE=1 \ + torchrun --standalone --nproc_per_node=$NGPU $SCRIPT + +echo "" +echo "===============================" +echo " SCREENING SUMMARY" +echo "===============================" +echo "$SUMMARY" +echo "$FAILS run(s) failed." From 5603ee0aa3cacf380b62507b3c8ec6ddf57978ad Mon Sep 17 00:00:00 2001 From: Alexandr Azizyan Date: Thu, 26 Mar 2026 17:58:23 +0400 Subject: [PATCH 03/10] feat: add full-scale experiment scripts and logs (5 runs, 600s 8xH100) --- .../logs/s2_G.txt | 1688 +++++++++++++++++ .../logs/s2_H.txt | 1643 ++++++++++++++++ .../logs/s2_I.txt | 1643 ++++++++++++++++ .../logs/s2_J.txt | 1586 ++++++++++++++++ .../logs/s2_K.txt | 1587 ++++++++++++++++ .../scripts/run_fullscale.sh | 64 + 6 files changed, 8211 insertions(+) create mode 100644 records/track_non_record_16mb/2026-03-26_RecurrenceFix_3Loop_Birkhoff_OutputLN_TimestepScale/logs/s2_G.txt create mode 100644 records/track_non_record_16mb/2026-03-26_RecurrenceFix_3Loop_Birkhoff_OutputLN_TimestepScale/logs/s2_H.txt create mode 100644 records/track_non_record_16mb/2026-03-26_RecurrenceFix_3Loop_Birkhoff_OutputLN_TimestepScale/logs/s2_I.txt create mode 100644 records/track_non_record_16mb/2026-03-26_RecurrenceFix_3Loop_Birkhoff_OutputLN_TimestepScale/logs/s2_J.txt create mode 100644 records/track_non_record_16mb/2026-03-26_RecurrenceFix_3Loop_Birkhoff_OutputLN_TimestepScale/logs/s2_K.txt create mode 100755 
records/track_non_record_16mb/2026-03-26_RecurrenceFix_3Loop_Birkhoff_OutputLN_TimestepScale/scripts/run_fullscale.sh diff --git a/records/track_non_record_16mb/2026-03-26_RecurrenceFix_3Loop_Birkhoff_OutputLN_TimestepScale/logs/s2_G.txt b/records/track_non_record_16mb/2026-03-26_RecurrenceFix_3Loop_Birkhoff_OutputLN_TimestepScale/logs/s2_G.txt new file mode 100644 index 0000000000..0d90dddabb --- /dev/null +++ b/records/track_non_record_16mb/2026-03-26_RecurrenceFix_3Loop_Birkhoff_OutputLN_TimestepScale/logs/s2_G.txt @@ -0,0 +1,1688 @@ +""" +The `train_gpt.py` and `train_gpt_mlx.py` scripts are intended as good launching-off points for new participants, not SOTA configs. We'll accept PRs that tune, improve, or simplify these scripts without significantly increasing complexity, but competitive submissions should stay in the `/records` folder. + +Hard stop: To keep readable for newcomers, let's make sure `train_gpt.py` and `train_gpt_mlx.py` never are longer than 1500 lines. +""" + +from __future__ import annotations + +import copy +import glob +import io +import math +import os +import random +import subprocess +import sys +import time +import uuid +import zlib +from pathlib import Path + +import numpy as np +import sentencepiece as spm +import torch +import torch.distributed as dist +import torch.nn.functional as F +from torch import Tensor, nn +from torch.nn.parallel import DistributedDataParallel as DDP + +# Flash Attention 3 via HuggingFace kernels (Hopper-optimized) +try: + from kernels import get_kernel + cap = torch.cuda.get_device_capability() + repo = "varunneal/flash-attention-3" if cap == (9, 0) else "kernels-community/flash-attn3" + fa3 = get_kernel(repo).flash_attn_interface + FA3_AVAILABLE = True +except Exception: + FA3_AVAILABLE = False + +# ----------------------------- +# HYPERPARAMETERS +# ----------------------------- +# Default Simple Baseline run: +# - 9 transformer blocks at width 512 +# - 8 attention heads with 4 KV heads (GQA) and 2x MLP 
# expansion
# - vocab size 1024, sequence length 1024, tied embeddings
# - 524,288 train tokens per step for 20,000 iterations with a ~10 minute cap

class Hyperparameters:
    """Static run configuration.

    Every field consults an environment variable first so a run can be
    reconfigured without editing the file; the literal second argument is the
    baseline default.
    """

    # Data paths are shard globs produced by the existing preprocessing pipeline.
    data_path = os.environ.get("DATA_PATH", "./data/datasets/fineweb10B_sp1024")
    train_files = os.path.join(data_path, "fineweb_train_*.bin")
    val_files = os.path.join(data_path, "fineweb_val_*.bin")
    tokenizer_path = os.environ.get("TOKENIZER_PATH", "./data/tokenizers/fineweb_1024_bpe.model")
    run_id = os.environ.get("RUN_ID", str(uuid.uuid4()))
    seed = int(os.environ.get("SEED", 1337))

    # Validation cadence and batch size. Validation always uses the full fineweb_val split.
    val_batch_size = int(os.environ.get("VAL_BATCH_SIZE", 524_288))
    val_loss_every = int(os.environ.get("VAL_LOSS_EVERY", 1000))
    train_log_every = int(os.environ.get("TRAIN_LOG_EVERY", 200))

    # Training length and schedule.
    iterations = int(os.environ.get("ITERATIONS", 20000))
    warmdown_iters = int(os.environ.get("WARMDOWN_ITERS", 1200))
    warmup_steps = int(os.environ.get("WARMUP_STEPS", 20))
    train_batch_tokens = int(os.environ.get("TRAIN_BATCH_TOKENS", 524_288))
    train_seq_len = int(os.environ.get("TRAIN_SEQ_LEN", 1024))
    max_wallclock_seconds = float(os.environ.get("MAX_WALLCLOCK_SECONDS", 600.0))
    qk_gain_init = float(os.environ.get("QK_GAIN_INIT", 1.5))

    # Model shape.
    vocab_size = int(os.environ.get("VOCAB_SIZE", 1024))
    num_layers = int(os.environ.get("NUM_LAYERS", 9))
    num_kv_heads = int(os.environ.get("NUM_KV_HEADS", 4))
    model_dim = int(os.environ.get("MODEL_DIM", 512))
    num_heads = int(os.environ.get("NUM_HEADS", 8))
    mlp_mult = int(os.environ.get("MLP_MULT", 2))
    tie_embeddings = bool(int(os.environ.get("TIE_EMBEDDINGS", "1")))
    rope_base = float(os.environ.get("ROPE_BASE", 10000.0))
    logit_softcap = float(os.environ.get("LOGIT_SOFTCAP", 30.0))

    # Recurrence + stability features.
    num_prelude = int(os.environ.get("NUM_PRELUDE", 0))
    num_coda = int(os.environ.get("NUM_CODA", 0))
    num_shared = int(os.environ.get("NUM_SHARED", 0))
    num_loops = int(os.environ.get("NUM_LOOPS", 2))
    use_peri_norm = bool(int(os.environ.get("USE_PERI_NORM", "0")))
    use_birkhoff_mix = bool(int(os.environ.get("USE_BIRKHOFF_MIX", "0")))
    use_timestep_scale = bool(int(os.environ.get("USE_TIMESTEP_SCALE", "0")))
    leaky_relu_slope = float(os.environ.get("LEAKY_RELU_SLOPE", "0.5"))
    timestep_gamma_max = float(os.environ.get("TIMESTEP_GAMMA_MAX", "0.0"))
    disable_compile = bool(int(os.environ.get("DISABLE_COMPILE", "0")))

    # Optimizer hyperparameters.
    embed_lr = float(os.environ.get("EMBED_LR", 0.6))
    head_lr = float(os.environ.get("HEAD_LR", 0.008))
    tied_embed_lr = float(os.environ.get("TIED_EMBED_LR", 0.05))
    tied_embed_init_std = float(os.environ.get("TIED_EMBED_INIT_STD", 0.005))
    matrix_lr = float(os.environ.get("MATRIX_LR", 0.04))
    scalar_lr = float(os.environ.get("SCALAR_LR", 0.04))
    muon_momentum = float(os.environ.get("MUON_MOMENTUM", 0.95))
    muon_backend_steps = int(os.environ.get("MUON_BACKEND_STEPS", 5))
    muon_momentum_warmup_start = float(os.environ.get("MUON_MOMENTUM_WARMUP_START", 0.85))
    muon_momentum_warmup_steps = int(os.environ.get("MUON_MOMENTUM_WARMUP_STEPS", 500))
    beta1 = float(os.environ.get("BETA1", 0.9))
    beta2 = float(os.environ.get("BETA2", 0.95))
    adam_eps = float(os.environ.get("ADAM_EPS", 1e-8))
    grad_clip_norm = float(os.environ.get("GRAD_CLIP_NORM", 0.0))

# -----------------------------
# MUON OPTIMIZER
# -----------------------------
#
# As borrowed from modded-nanogpt
# Background on Muon: https://kellerjordan.github.io/posts/muon/

def zeropower_via_newtonschulz5(G: Tensor, steps: int = 10, eps: float = 1e-7) -> Tensor:
    """Approximately orthogonalize a 2D update matrix with a quintic Newton-Schulz iteration.

    Muon uses this to normalize matrix-shaped gradients before applying them.
    Runs in bfloat16; the iteration is applied with rows <= cols (transposing
    first if needed) and the result is transposed back.
    """
    coeff_a, coeff_b, coeff_c = 3.4445, -4.7750, 2.0315
    X = G.bfloat16()
    X = X / (X.norm() + eps)
    needs_transpose = G.size(0) > G.size(1)
    if needs_transpose:
        X = X.T
    for _ in range(steps):
        A = X @ X.T
        B = coeff_b * A + coeff_c * A @ A
        X = coeff_a * X + B @ X
    return X.T if needs_transpose else X


class Muon(torch.optim.Optimizer):
    """Momentum SGD whose matrix updates are orthogonalized via Newton-Schulz.

    In distributed runs each rank computes the update for a round-robin
    subset of parameters; a single flat bf16 buffer is summed across ranks so
    every rank applies the complete update set.
    """

    def __init__(self, params, lr: float, momentum: float, backend_steps: int, nesterov: bool = True):
        defaults = dict(lr=lr, momentum=momentum, backend_steps=backend_steps, nesterov=nesterov)
        super().__init__(params, defaults)

    @torch.no_grad()
    def step(self, closure=None):
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        is_dist = dist.is_available() and dist.is_initialized()
        world_size = dist.get_world_size() if is_dist else 1
        rank = dist.get_rank() if is_dist else 0

        for group in self.param_groups:
            params = group["params"]
            if not params:
                continue
            lr, momentum = group["lr"], group["momentum"]
            ns_steps, nesterov = group["backend_steps"], group["nesterov"]

            # Flat bf16 buffer holding one slot per parameter element; every
            # rank fills only its share, an all-reduce assembles the rest.
            flat = torch.zeros(
                sum(int(p.numel()) for p in params),
                device=params[0].device,
                dtype=torch.bfloat16,
            )

            offset = 0
            for idx, p in enumerate(params):
                if idx % world_size == rank and p.grad is not None:
                    g = p.grad
                    state = self.state[p]
                    if "momentum_buffer" not in state:
                        state["momentum_buffer"] = torch.zeros_like(g)
                    buf = state["momentum_buffer"]
                    buf.mul_(momentum).add_(g)
                    if nesterov:
                        g = g.add(buf, alpha=momentum)
                    g = zeropower_via_newtonschulz5(g, steps=ns_steps)
                    # Scale correction from Muon reference implementations.
                    g *= max(1, g.size(0) / g.size(1)) ** 0.5
                    flat[offset : offset + p.numel()] = g.reshape(-1)
                # Advance for every param so the layout matches the read-back
                # loop below regardless of which rank owned the update.
                offset += p.numel()

            if is_dist:
                dist.all_reduce(flat, op=dist.ReduceOp.SUM)

            offset = 0
            for p in params:
                upd = flat[offset : offset + p.numel()].view_as(p).to(dtype=p.dtype)
                p.add_(upd, alpha=-lr)
                offset += p.numel()

        return loss


# -----------------------------
# TOKENIZER-AGNOSTIC EVALUATION SETUP
# -----------------------------
#
# Small models spend a large fraction of their parameters on embeddings, since the
# 2 * d_model * d_vocab vectors can be gigantic. Instead of locking the tokenizer,
# we let you bring your own and compute validation metrics as the average
# compression of the validation set: BPB (bits-per-byte) rather than raw loss,
# which requires counting how many bytes each token covers.
# Note: Submissions that edit the tokenizer will be examined more carefully,
# since screwing this up might unjustly improve your score.
def build_sentencepiece_luts(
    sp: spm.SentencePieceProcessor, vocab_size: int, device: torch.device
) -> tuple[Tensor, Tensor, Tensor]:
    """Build per-token lookup tables that turn token counts into byte counts.

    Returns three aligned tables of length max(tokenizer vocab, model vocab):
      - base bytes: UTF-8 byte length of each piece (space marker stripped)
      - has leading space: piece started with the SentencePiece "▁" marker
      - is boundary token: control/unknown/unused ids that emit no bytes
    """
    sp_vocab = int(sp.vocab_size())
    table_size = max(sp_vocab, vocab_size)
    base_bytes = np.zeros((table_size,), dtype=np.int16)
    leading_space = np.zeros((table_size,), dtype=np.bool_)
    boundary = np.ones((table_size,), dtype=np.bool_)
    for tok in range(sp_vocab):
        if sp.is_control(tok) or sp.is_unknown(tok) or sp.is_unused(tok):
            continue
        boundary[tok] = False
        if sp.is_byte(tok):
            base_bytes[tok] = 1
            continue
        piece = sp.id_to_piece(tok)
        if piece.startswith("▁"):
            leading_space[tok] = True
            piece = piece[1:]
        base_bytes[tok] = len(piece.encode("utf-8"))
    return (
        torch.tensor(base_bytes, dtype=torch.int16, device=device),
        torch.tensor(leading_space, dtype=torch.bool, device=device),
        torch.tensor(boundary, dtype=torch.bool, device=device),
    )


def load_validation_tokens(pattern: str, seq_len: int) -> Tensor:
    """Concatenate all validation shards, trimmed to whole sequences plus one shift token."""
    shard_paths = [Path(p) for p in sorted(glob.glob(pattern))]
    if not shard_paths:
        raise FileNotFoundError(f"No files found for pattern: {pattern}")
    # The export pipeline writes the fixed first-50k-doc validation set to fineweb_val_*.
    tokens = torch.cat([load_data_shard(path) for path in shard_paths]).contiguous()
    usable = ((tokens.numel() - 1) // seq_len) * seq_len
    if usable <= 0:
        raise ValueError(f"Validation split is too short for TRAIN_SEQ_LEN={seq_len}")
    return tokens[: usable + 1]


def eval_val(
    args: Hyperparameters,
    model: nn.Module,
    rank: int,
    world_size: int,
    device: torch.device,
    grad_accum_steps: int,
    val_tokens: Tensor,
    base_bytes_lut: Tensor,
    has_leading_space_lut: Tensor,
    is_boundary_token_lut: Tensor,
) -> tuple[float, float]:
    """Run full-split validation and return (val_loss, val_bpb).

    - val_loss: token cross-entropy (natural log)
    - val_bpb: tokenizer-agnostic bits-per-byte compression metric
    """
    local_batch_tokens = args.val_batch_size // (world_size * grad_accum_steps)
    if local_batch_tokens < args.train_seq_len:
        raise ValueError(
            "VAL_BATCH_SIZE must provide at least one sequence per rank; "
            f"got VAL_BATCH_SIZE={args.val_batch_size}, WORLD_SIZE={world_size}, "
            f"GRAD_ACCUM_STEPS={grad_accum_steps}, TRAIN_SEQ_LEN={args.train_seq_len}"
        )
    seq_len = args.train_seq_len
    local_batch_seqs = local_batch_tokens // seq_len
    total_seqs = (val_tokens.numel() - 1) // seq_len
    # Each rank owns a disjoint contiguous slice of the validation sequences.
    first_seq = (total_seqs * rank) // world_size
    last_seq = (total_seqs * (rank + 1)) // world_size
    loss_sum = torch.zeros((), device=device, dtype=torch.float64)
    token_count = torch.zeros((), device=device, dtype=torch.float64)
    byte_count = torch.zeros((), device=device, dtype=torch.float64)

    model.eval()
    with torch.inference_mode():
        for chunk_start in range(first_seq, last_seq, local_batch_seqs):
            chunk_end = min(chunk_start + local_batch_seqs, last_seq)
            raw = val_tokens[chunk_start * seq_len : chunk_end * seq_len + 1].to(
                device=device, dtype=torch.int64, non_blocking=True
            )
            x = raw[:-1].reshape(-1, seq_len)
            y = raw[1:].reshape(-1, seq_len)
            with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True):
                batch_loss = model(x, y).detach()
            n_tokens = float(y.numel())
            loss_sum += batch_loss.to(torch.float64) * n_tokens
            token_count += n_tokens
            # Bytes attributed to each target token; the leading-space byte
            # only counts when the previous token is not a boundary token.
            prev_ids = x.reshape(-1)
            tgt_ids = y.reshape(-1)
            tok_bytes = base_bytes_lut[tgt_ids].to(dtype=torch.int16)
            tok_bytes += (has_leading_space_lut[tgt_ids] & ~is_boundary_token_lut[prev_ids]).to(dtype=torch.int16)
            byte_count += tok_bytes.to(torch.float64).sum()

    if dist.is_available() and dist.is_initialized():
        dist.all_reduce(loss_sum, op=dist.ReduceOp.SUM)
        dist.all_reduce(token_count, op=dist.ReduceOp.SUM)
        dist.all_reduce(byte_count, op=dist.ReduceOp.SUM)

    val_loss = loss_sum / token_count
    bits_per_token = val_loss.item() / math.log(2.0)
    tokens_per_byte = token_count.item() / byte_count.item()
    model.train()
    return float(val_loss.item()), float(bits_per_token * tokens_per_byte)

# -----------------------------
# POST-TRAINING QUANTIZATION
# -----------------------------
#
# It's silly to export our model, which is trained in bf16 and fp32, at that
# same precision. Instead, we get approximately the same model (with a small
# hit) by quantizing to int8 and zlib-compressing. We can then decompress the
# model and run in higher precision for evaluation, after closing in under the
# size limit.
# Substring patterns naming small "control" tensors (learned scales/mixes) that
# must stay in fp32 both during training and in the exported checkpoint.
CONTROL_TENSOR_NAME_PATTERNS = tuple(
    pattern
    for pattern in os.environ.get(
        "CONTROL_TENSOR_NAME_PATTERNS",
        "attn_scale,attn_scales,mlp_scale,mlp_scales,resid_mix,resid_mixes,q_gain,skip_weight,skip_weights,ts_attn,ts_mlp,resid_mix_logit",
    ).split(",")
    if pattern
)
INT8_KEEP_FLOAT_FP32_NAME_PATTERNS = tuple(
    pattern
    for pattern in os.environ.get(
        "INT8_KEEP_FLOAT_FP32_NAME_PATTERNS",
        ",".join(CONTROL_TENSOR_NAME_PATTERNS),
    ).split(",")
    if pattern
)
INT8_KEEP_FLOAT_MAX_NUMEL = 65_536  # tensors at or below this size skip int8 entirely
INT8_KEEP_FLOAT_STORE_DTYPE = torch.float16
INT8_PER_ROW_SCALE_DTYPE = torch.float16
INT8_CLIP_PERCENTILE = 99.99984  # clip extreme outliers before choosing the int8 scale
INT8_CLIP_Q = INT8_CLIP_PERCENTILE / 100.0

def tensor_nbytes(t: Tensor) -> int:
    """Raw storage size of `t` in bytes (numel * element size)."""
    return int(t.numel()) * int(t.element_size())

def keep_float_tensor(name: str, t: Tensor, passthrough_orig_dtypes: dict[str, str]) -> Tensor:
    """Prepare a small float tensor for float passthrough storage.

    Control tensors are stored in fp32; other fp32/bf16 tensors are downcast
    to fp16 to save bytes, recording the original dtype in
    `passthrough_orig_dtypes` so dequantization can restore it.
    """
    if any(pattern in name for pattern in INT8_KEEP_FLOAT_FP32_NAME_PATTERNS):
        return t.float().contiguous()
    if t.dtype in {torch.float32, torch.bfloat16}:
        passthrough_orig_dtypes[name] = str(t.dtype).removeprefix("torch.")
        return t.to(dtype=INT8_KEEP_FLOAT_STORE_DTYPE).contiguous()
    return t

def quantize_float_tensor(t: Tensor) -> tuple[Tensor, Tensor]:
    """Quantize a float tensor to int8 with a clipped-percentile scale.

    2D tensors get one scale per row (axis 0); anything else shares a single
    per-tensor scale. Returns (int8 values, scale tensor).
    """
    t32 = t.float()
    if t32.ndim == 2:
        # Matrices get one scale per row, which usually tracks output-channel
        # ranges much better than a single tensor-wide scale.
        if t32.numel():
            clip_abs = torch.quantile(t32.abs(), INT8_CLIP_Q, dim=1)
        else:
            # Fix: torch.empty here left the clip values uninitialized, so a
            # degenerate (n, 0) input produced nondeterministic garbage scales.
            clip_abs = torch.zeros((t32.shape[0],), dtype=torch.float32)
        clipped = torch.maximum(torch.minimum(t32, clip_abs[:, None]), -clip_abs[:, None])
        scale = (clip_abs / 127.0).clamp_min(1.0 / 127.0)
        q = torch.clamp(torch.round(clipped / scale[:, None]), -127, 127).to(torch.int8).contiguous()
        return q, scale.to(dtype=INT8_PER_ROW_SCALE_DTYPE).contiguous()

    # Vectors / scalars use a simpler per-tensor scale.
    clip_abs = float(torch.quantile(t32.abs().flatten(), INT8_CLIP_Q).item()) if t32.numel() else 0.0
    scale = torch.tensor(clip_abs / 127.0 if clip_abs > 0 else 1.0, dtype=torch.float32)
    q = torch.clamp(torch.round(torch.clamp(t32, -clip_abs, clip_abs) / scale), -127, 127).to(torch.int8).contiguous()
    return q, scale

def quantize_state_dict_int8(state_dict: dict[str, Tensor]):
    """Quantize a state dict for compact export; returns (payload dict, stats dict).

    Single supported clean-script export format:
      - per-row int8 for 2D float tensors
      - per-tensor int8 for other float tensors
      - exact passthrough for non-floats
      - passthrough for small float tensors, stored as fp16 to save bytes
    """
    quantized: dict[str, Tensor] = {}
    scales: dict[str, Tensor] = {}
    dtypes: dict[str, str] = {}
    passthrough: dict[str, Tensor] = {}
    passthrough_orig_dtypes: dict[str, str] = {}
    qmeta: dict[str, dict[str, object]] = {}
    stats = dict.fromkeys(
        ("param_count", "num_tensors", "num_float_tensors", "num_nonfloat_tensors", "baseline_tensor_bytes", "int8_payload_bytes"),
        0,
    )

    for name, tensor in state_dict.items():
        t = tensor.detach().to("cpu").contiguous()
        stats["param_count"] += int(t.numel())
        stats["num_tensors"] += 1
        stats["baseline_tensor_bytes"] += tensor_nbytes(t)

        if not t.is_floating_point():
            stats["num_nonfloat_tensors"] += 1
            passthrough[name] = t
            stats["int8_payload_bytes"] += tensor_nbytes(t)
            continue

        # Small float tensors are cheap enough to keep directly. We still downcast
        # fp32/bf16 passthrough tensors to fp16 so metadata does not dominate size.
        if t.numel() <= INT8_KEEP_FLOAT_MAX_NUMEL:
            kept = keep_float_tensor(name, t, passthrough_orig_dtypes)
            passthrough[name] = kept
            stats["int8_payload_bytes"] += tensor_nbytes(kept)
            continue

        stats["num_float_tensors"] += 1
        q, s = quantize_float_tensor(t)
        if s.ndim > 0:
            qmeta[name] = {"scheme": "per_row", "axis": 0}
        quantized[name] = q
        scales[name] = s
        dtypes[name] = str(t.dtype).removeprefix("torch.")
        stats["int8_payload_bytes"] += tensor_nbytes(q) + tensor_nbytes(s)

    obj: dict[str, object] = {
        "__quant_format__": "int8_clean_per_row_v1",
        "quantized": quantized,
        "scales": scales,
        "dtypes": dtypes,
        "passthrough": passthrough,
    }
    if qmeta:
        obj["qmeta"] = qmeta
    if passthrough_orig_dtypes:
        obj["passthrough_orig_dtypes"] = passthrough_orig_dtypes
    return obj, stats

def dequantize_state_dict_int8(obj: dict[str, object]) -> dict[str, Tensor]:
    """Invert `quantize_state_dict_int8`, restoring original dtypes where recorded."""
    out: dict[str, Tensor] = {}
    qmeta = obj.get("qmeta", {})
    passthrough_orig_dtypes = obj.get("passthrough_orig_dtypes", {})
    for name, q in obj["quantized"].items():
        dtype = getattr(torch, obj["dtypes"][name])
        s = obj["scales"][name]
        if qmeta.get(name, {}).get("scheme") == "per_row" or s.ndim > 0:
            s = s.to(dtype=torch.float32)
            # Broadcast the saved row scale back across trailing dimensions.
            out[name] = (q.float() * s.view(q.shape[0], *([1] * (q.ndim - 1)))).to(dtype=dtype).contiguous()
        else:
            scale = float(s.item())
            out[name] = (q.float() * scale).to(dtype=dtype).contiguous()
    for name, t in obj["passthrough"].items():
        # Restore small tensors, undoing the temporary fp16 storage cast if needed.
        out_t = t.detach().to("cpu").contiguous()
        orig_dtype = passthrough_orig_dtypes.get(name)
        if isinstance(orig_dtype, str):
            out_t = out_t.to(dtype=getattr(torch, orig_dtype)).contiguous()
        out[name] = out_t
    return out
+ out_t = t.detach().to("cpu").contiguous() + orig_dtype = passthrough_orig_dtypes.get(name) + if isinstance(orig_dtype, str): + out_t = out_t.to(dtype=getattr(torch, orig_dtype)).contiguous() + out[name] = out_t + return out + + +# ----------------------------- +# DATA LOADING +# ----------------------------- + +def load_data_shard(file: Path) -> Tensor: + header_bytes = 256 * np.dtype(" None: + self.file_idx = (self.file_idx + 1) % len(self.files) + self.tokens = load_data_shard(self.files[self.file_idx]) + self.pos = 0 + + def take(self, n: int) -> Tensor: + chunks: list[Tensor] = [] + remaining = n + while remaining > 0: + avail = self.tokens.numel() - self.pos + if avail <= 0: + self._advance_file() + continue + k = min(remaining, avail) + chunks.append(self.tokens[self.pos : self.pos + k]) + self.pos += k + remaining -= k + return chunks[0] if len(chunks) == 1 else torch.cat(chunks) + + +class DistributedTokenLoader: + # Each call consumes a contiguous chunk from the shared token stream, then slices out + # one disjoint span per rank. The extra "+1" token lets us build (x, y) by shifting. 
+ def __init__(self, pattern: str, rank: int, world_size: int, device: torch.device): + self.rank = rank + self.world_size = world_size + self.device = device + self.stream = TokenStream(pattern) + + def next_batch(self, global_tokens: int, seq_len: int, grad_accum_steps: int) -> tuple[Tensor, Tensor]: + local_tokens = global_tokens // (self.world_size * grad_accum_steps) + per_rank_span = local_tokens + 1 + chunk = self.stream.take(per_rank_span * self.world_size) + start = self.rank * per_rank_span + local = chunk[start : start + per_rank_span].to(dtype=torch.int64) + x = local[:-1].reshape(-1, seq_len) + y = local[1:].reshape(-1, seq_len) + return x.to(self.device, non_blocking=True), y.to(self.device, non_blocking=True) + +# ----------------------------- +# TRANSFORMER MODULES +# ----------------------------- + +class RMSNorm(nn.Module): + def __init__(self, eps: float | None = None): + super().__init__() + self.eps = eps + + def forward(self, x: Tensor) -> Tensor: + return F.rms_norm(x, (x.size(-1),), eps=self.eps) + + +class CastedLinear(nn.Linear): + # Keep weights in fp32 for optimizer/state quality, cast at matmul time for bf16 compute. + def forward(self, x: Tensor) -> Tensor: + bias = self.bias.to(x.dtype) if self.bias is not None else None + return F.linear(x, self.weight.to(x.dtype), bias) + + +def restore_low_dim_params_to_fp32(module: nn.Module) -> None: + # Keep small/control parameters in fp32 even when the model body runs in bf16. + with torch.no_grad(): + for name, param in module.named_parameters(): + if (param.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS)) and param.dtype != torch.float32: + param.data = param.data.float() + + +class Rotary(nn.Module): + # Caches cos/sin tables per sequence length on the current device. 
+ def __init__(self, dim: int, base: float = 10000.0): + super().__init__() + inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim)) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self._seq_len_cached = 0 + self._cos_cached: Tensor | None = None + self._sin_cached: Tensor | None = None + + def forward(self, seq_len: int, device: torch.device, dtype: torch.dtype) -> tuple[Tensor, Tensor]: + if ( + self._cos_cached is None + or self._sin_cached is None + or self._seq_len_cached != seq_len + or self._cos_cached.device != device + ): + t = torch.arange(seq_len, device=device, dtype=self.inv_freq.dtype) + freqs = torch.outer(t, self.inv_freq.to(device)) + self._cos_cached = freqs.cos()[None, None, :, :] + self._sin_cached = freqs.sin()[None, None, :, :] + self._seq_len_cached = seq_len + return self._cos_cached.to(dtype=dtype), self._sin_cached.to(dtype=dtype) + + +def apply_rotary_emb(x: Tensor, cos: Tensor, sin: Tensor) -> Tensor: + half = x.size(-1) // 2 + x1, x2 = x[..., :half], x[..., half:] + return torch.cat((x1 * cos + x2 * sin, x1 * (-sin) + x2 * cos), dim=-1) + + +class CausalSelfAttention(nn.Module): + def __init__( + self, + dim: int, + num_heads: int, + num_kv_heads: int, + rope_base: float, + qk_gain_init: float, + ): + super().__init__() + if dim % num_heads != 0: + raise ValueError("model_dim must be divisible by num_heads") + if num_heads % num_kv_heads != 0: + raise ValueError("num_heads must be divisible by num_kv_heads") + self.num_heads = num_heads + self.num_kv_heads = num_kv_heads + self.head_dim = dim // num_heads + if self.head_dim % 2 != 0: + raise ValueError("head_dim must be even for RoPE") + kv_dim = self.num_kv_heads * self.head_dim + self.c_q = CastedLinear(dim, dim, bias=False) + self.c_k = CastedLinear(dim, kv_dim, bias=False) + self.c_v = CastedLinear(dim, kv_dim, bias=False) + self.proj = CastedLinear(dim, dim, bias=False) + self.proj._zero_init = True + self.q_gain = 
nn.Parameter(torch.full((num_heads,), qk_gain_init, dtype=torch.float32)) + self.rotary = Rotary(self.head_dim, base=rope_base) + + def forward(self, x: Tensor) -> Tensor: + bsz, seqlen, dim = x.shape + q = self.c_q(x).reshape(bsz, seqlen, self.num_heads, self.head_dim).transpose(1, 2) + k = self.c_k(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) + v = self.c_v(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) + q = F.rms_norm(q, (q.size(-1),)) + k = F.rms_norm(k, (k.size(-1),)) + cos, sin = self.rotary(seqlen, x.device, q.dtype) + q = apply_rotary_emb(q, cos, sin) + k = apply_rotary_emb(k, cos, sin) + q = q * self.q_gain.to(dtype=q.dtype)[None, :, None, None] + if FA3_AVAILABLE: + # FA3 expects (B, T, H, D) — transpose from (B, H, T, D) + y = fa3.flash_attn_func( + q.transpose(1, 2), + k.transpose(1, 2), + v.transpose(1, 2), + causal=True, + ).transpose(1, 2) + else: + y = F.scaled_dot_product_attention( + q, k, v, + attn_mask=None, + is_causal=True, + enable_gqa=(self.num_kv_heads != self.num_heads), + ) + y = y.transpose(1, 2).contiguous().reshape(bsz, seqlen, dim) + return self.proj(y) + + +class MLP(nn.Module): + # leaky_relu^2 MLP (slope=0.5 matches SOTA, PRs #493/#518) + def __init__(self, dim: int, mlp_mult: int, leaky_relu_slope: float = 0.5): + super().__init__() + hidden = mlp_mult * dim + self.fc = CastedLinear(dim, hidden, bias=False) + self.proj = CastedLinear(hidden, dim, bias=False) + self.proj._zero_init = True + self.leaky_relu_slope = leaky_relu_slope + + def forward(self, x: Tensor) -> Tensor: + x = F.leaky_relu(self.fc(x), negative_slope=self.leaky_relu_slope) + return self.proj(x.square()) + + +class TimestepScaling(nn.Module): + """Learned per-iteration scale vectors (arXiv:2410.01405, ICML 2025).""" + def __init__(self, effective_layers: int, dim: int, gamma_max: float = 0.0): + super().__init__() + self.attn_gamma = nn.Parameter(torch.ones(effective_layers, dim, dtype=torch.float32)) + 
self.mlp_gamma = nn.Parameter(torch.ones(effective_layers, dim, dtype=torch.float32)) + self.gamma_max = gamma_max # 0 = uncapped + + def get(self, v: int) -> tuple[Tensor, Tensor]: + ag = self.attn_gamma[v] + mg = self.mlp_gamma[v] + if self.gamma_max > 0: + ag = ag.clamp(-self.gamma_max, self.gamma_max) + mg = mg.clamp(-self.gamma_max, self.gamma_max) + return ag, mg + + +class Block(nn.Module): + def __init__( + self, + dim: int, + num_heads: int, + num_kv_heads: int, + mlp_mult: int, + rope_base: float, + qk_gain_init: float, + use_peri_norm: bool = False, + use_birkhoff_mix: bool = False, + leaky_relu_slope: float = 0.5, + ): + super().__init__() + self.use_peri_norm = use_peri_norm + self.use_birkhoff_mix = use_birkhoff_mix + self.attn_norm = RMSNorm() + if use_peri_norm: + self.mlp_out_norm = RMSNorm() # output norm: bounds MLP contribution, preserves input magnitude + else: + self.mlp_norm = RMSNorm() + self.attn = CausalSelfAttention(dim, num_heads, num_kv_heads, rope_base, qk_gain_init) + self.mlp = MLP(dim, mlp_mult, leaky_relu_slope) + self.attn_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.mlp_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + if use_birkhoff_mix: + # Birkhoff B₂ parameterization: sigmoid(logit) ∈ (0,1) → doubly stochastic + self.resid_mix_logit = nn.Parameter(torch.zeros(dim, dtype=torch.float32)) + else: + self.resid_mix = nn.Parameter(torch.stack((torch.ones(dim), torch.zeros(dim))).float()) + + def forward(self, x: Tensor, x0: Tensor, + ts_attn_gamma: Tensor | None = None, + ts_mlp_gamma: Tensor | None = None) -> Tensor: + if self.use_birkhoff_mix: + alpha = torch.sigmoid(self.resid_mix_logit.to(dtype=x.dtype))[None, None, :] + x = alpha * x + (1 - alpha) * x0 + else: + mix = self.resid_mix.to(dtype=x.dtype) + x = mix[0][None, None, :] * x + mix[1][None, None, :] * x0 + attn_out = self.attn(self.attn_norm(x)) + attn_s = self.attn_scale.to(dtype=x.dtype)[None, None, :] + if ts_attn_gamma is not None: + 
attn_s = attn_s * ts_attn_gamma[None, None, :] + x = x + attn_s * attn_out + mlp_out = self.mlp_out_norm(self.mlp(x)) if self.use_peri_norm else self.mlp(self.mlp_norm(x)) + mlp_s = self.mlp_scale.to(dtype=x.dtype)[None, None, :] + if ts_mlp_gamma is not None: + mlp_s = mlp_s * ts_mlp_gamma[None, None, :] + x = x + mlp_s * mlp_out + return x + + +class GPT(nn.Module): + def __init__( + self, + vocab_size: int, + num_layers: int, + model_dim: int, + num_heads: int, + num_kv_heads: int, + mlp_mult: int, + tie_embeddings: bool, + tied_embed_init_std: float, + logit_softcap: float, + rope_base: float, + qk_gain_init: float, + num_prelude: int = 0, + num_coda: int = 0, + num_shared: int = 0, + num_loops: int = 2, + use_peri_norm: bool = False, + use_birkhoff_mix: bool = False, + use_timestep_scale: bool = False, + timestep_gamma_max: float = 0.0, + leaky_relu_slope: float = 0.5, + ): + super().__init__() + if logit_softcap <= 0.0: + raise ValueError(f"logit_softcap must be positive, got {logit_softcap}") + self.tie_embeddings = tie_embeddings + self.tied_embed_init_std = tied_embed_init_std + self.logit_softcap = logit_softcap + self.tok_emb = nn.Embedding(vocab_size, model_dim) + self.use_recurrence = num_shared > 0 + self.num_loops = num_loops if self.use_recurrence else 1 + self.num_prelude = num_prelude if self.use_recurrence else 0 + self.num_coda = num_coda if self.use_recurrence else 0 + + block_kwargs = dict( + dim=model_dim, num_heads=num_heads, num_kv_heads=num_kv_heads, + mlp_mult=mlp_mult, rope_base=rope_base, qk_gain_init=qk_gain_init, + use_peri_norm=use_peri_norm if self.use_recurrence else False, + use_birkhoff_mix=use_birkhoff_mix if self.use_recurrence else False, + leaky_relu_slope=leaky_relu_slope, + ) + + if self.use_recurrence: + self.prelude_blocks = nn.ModuleList([Block(**block_kwargs) for _ in range(self.num_prelude)]) + self.shared_blocks = nn.ModuleList([Block(**block_kwargs) for _ in range(num_shared)]) + self.coda_blocks = 
nn.ModuleList([Block(**block_kwargs) for _ in range(self.num_coda)]) + effective_layers = self.num_prelude + num_shared * self.num_loops + self.num_coda + self.timestep_scale = TimestepScaling(effective_layers, model_dim, gamma_max=timestep_gamma_max) if use_timestep_scale else None + else: + # Standard U-Net path + self.num_encoder_layers = num_layers // 2 + self.num_decoder_layers = num_layers - self.num_encoder_layers + self.num_skip_weights = min(self.num_encoder_layers, self.num_decoder_layers) + self.skip_weights = nn.Parameter(torch.ones(self.num_skip_weights, model_dim, dtype=torch.float32)) + self.blocks = nn.ModuleList([Block(**block_kwargs) for _ in range(num_layers)]) + self.timestep_scale = None + + self.final_norm = RMSNorm() + self.lm_head = None if tie_embeddings else CastedLinear(model_dim, vocab_size, bias=False) + if self.lm_head is not None: + self.lm_head._zero_init = True + self._init_weights() + + def _init_weights(self) -> None: + if self.tie_embeddings: + nn.init.normal_(self.tok_emb.weight, mean=0.0, std=self.tied_embed_init_std) + for module in self.modules(): + if isinstance(module, nn.Linear) and getattr(module, "_zero_init", False): + nn.init.zeros_(module.weight) + + def _get_ts(self, v: int) -> tuple[Tensor | None, Tensor | None]: + if self.timestep_scale is None: + return None, None + return self.timestep_scale.get(v) + + def forward(self, input_ids: Tensor, target_ids: Tensor) -> Tensor: + x = self.tok_emb(input_ids) + x = F.rms_norm(x, (x.size(-1),)) + x0 = x + + if self.use_recurrence: + v = 0 + for block in self.prelude_blocks: + ag, mg = self._get_ts(v) + x = block(x, x0, ag, mg) + v += 1 + for _loop in range(self.num_loops): + for block in self.shared_blocks: + ag, mg = self._get_ts(v) + x = block(x, x0, ag, mg) + v += 1 + for block in self.coda_blocks: + ag, mg = self._get_ts(v) + x = block(x, x0, ag, mg) + v += 1 + else: + skips: list[Tensor] = [] + for i in range(self.num_encoder_layers): + x = self.blocks[i](x, x0) + 
skips.append(x) + for i in range(self.num_decoder_layers): + if skips: + x = x + self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips.pop() + x = self.blocks[self.num_encoder_layers + i](x, x0) + + x = self.final_norm(x).reshape(-1, x.size(-1)) + targets = target_ids.reshape(-1) + if self.tie_embeddings: + logits_proj = F.linear(x, self.tok_emb.weight) + else: + if self.lm_head is None: + raise RuntimeError("lm_head is required when tie_embeddings=False") + logits_proj = self.lm_head(x) + logits = self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap) + return F.cross_entropy(logits.float(), targets, reduction="mean") + + +@torch.no_grad() +def recurrence_param_diagnostics(gpt: GPT) -> str: + """Log parameter norms for recurrence health monitoring. No forward pass.""" + if not gpt.use_recurrence: + return "" + parts: list[str] = [] + + # Birkhoff alpha stats per shared block + for i, block in enumerate(gpt.shared_blocks): + if hasattr(block, "resid_mix_logit"): + a = torch.sigmoid(block.resid_mix_logit) + parts.append(f"shared{i}_alpha:mean={a.mean().item():.3f},std={a.std().item():.3f}") + + # Effective MLP/attn contribution scale per virtual layer position: + # mlp_scale * ts_mlp_gamma[v] is the learned multiplier on MLP output. + # We report RMS (norm / sqrt(numel)) to make it scale-independent. 
+ v = 0 + all_blocks = ( + list(gpt.prelude_blocks) + + list(gpt.shared_blocks) * gpt.num_loops + + list(gpt.coda_blocks) + ) + mlp_norms: list[str] = [] + attn_norms: list[str] = [] + for block in all_blocks: + ms = block.mlp_scale.norm().item() + asc = block.attn_scale.norm().item() + d = block.mlp_scale.numel() ** 0.5 + if gpt.timestep_scale is not None: + ms_g = gpt.timestep_scale.mlp_gamma[v].norm().item() + as_g = gpt.timestep_scale.attn_gamma[v].norm().item() + mlp_norms.append(f"v{v}:{ms * ms_g / d:.4f}") + attn_norms.append(f"v{v}:{asc * as_g / d:.4f}") + else: + mlp_norms.append(f"v{v}:{ms / d:.4f}") + attn_norms.append(f"v{v}:{asc / d:.4f}") + v += 1 + parts.append("eff_mlp_scale:[" + " ".join(mlp_norms) + "]") + parts.append("eff_attn_scale:[" + " ".join(attn_norms) + "]") + return " ".join(parts) + + +# ----------------------------- +# TRAINING +# ----------------------------- + +def main() -> None: + global zeropower_via_newtonschulz5 + + code = Path(__file__).read_text(encoding="utf-8") + args = Hyperparameters() + zeropower_via_newtonschulz5 = torch.compile(zeropower_via_newtonschulz5) + + # ----------------------------- + # DISTRIBUTED + CUDA SETUP + # ----------------------------- + + distributed = "RANK" in os.environ and "WORLD_SIZE" in os.environ + rank = int(os.environ.get("RANK", "0")) + world_size = int(os.environ.get("WORLD_SIZE", "1")) + local_rank = int(os.environ.get("LOCAL_RANK", "0")) + if world_size <= 0: + raise ValueError(f"WORLD_SIZE must be positive, got {world_size}") + if 8 % world_size != 0: + raise ValueError(f"WORLD_SIZE={world_size} must divide 8 so grad_accum_steps stays integral") + grad_accum_steps = 8 // world_size + grad_scale = 1.0 / grad_accum_steps + if not torch.cuda.is_available(): + raise RuntimeError("CUDA is required") + device = torch.device("cuda", local_rank) + torch.cuda.set_device(device) + if distributed: + dist.init_process_group(backend="nccl", device_id=device) + dist.barrier() + master_process = rank 
== 0 + + # Fast math knobs + torch.backends.cuda.matmul.allow_tf32 = True + torch.backends.cudnn.allow_tf32 = True + from torch.backends.cuda import enable_cudnn_sdp, enable_flash_sdp, enable_math_sdp, enable_mem_efficient_sdp + + enable_cudnn_sdp(False) + enable_flash_sdp(True) + enable_mem_efficient_sdp(False) + enable_math_sdp(False) + + logfile = None + if master_process: + os.makedirs("logs", exist_ok=True) + logfile = f"logs/{args.run_id}.txt" + print(logfile) + + def log0(msg: str, console: bool = True) -> None: + if not master_process: + return + if console: + print(msg) + if logfile is not None: + with open(logfile, "a", encoding="utf-8") as f: + print(msg, file=f) + + log0(code, console=False) + log0("=" * 100, console=False) + log0(f"Running Python {sys.version}", console=False) + log0(f"Running PyTorch {torch.__version__}", console=False) + log0( + subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, check=False).stdout, + console=False, + ) + log0("=" * 100, console=False) + + # ----------------------------- + # TOKENIZER + VALIDATION METRIC SETUP + # ----------------------------- + + random.seed(args.seed) + np.random.seed(args.seed) + torch.manual_seed(args.seed) + torch.cuda.manual_seed_all(args.seed) + + if not args.tokenizer_path.endswith(".model"): + raise ValueError(f"Script only setup for SentencePiece .model file: {args.tokenizer_path}") + sp = spm.SentencePieceProcessor(model_file=args.tokenizer_path) + if int(sp.vocab_size()) != args.vocab_size: + raise ValueError( + f"VOCAB_SIZE={args.vocab_size} does not match tokenizer vocab_size={int(sp.vocab_size())}" + ) + dataset_dir = Path(args.data_path).resolve() + actual_train_files = len(list(dataset_dir.glob("fineweb_train_*.bin"))) + val_tokens = load_validation_tokens(args.val_files, args.train_seq_len) + base_bytes_lut, has_leading_space_lut, is_boundary_token_lut = build_sentencepiece_luts( + sp, args.vocab_size, device + ) + log0(f"val_bpb:enabled 
tokenizer_kind=sentencepiece tokenizer_path={args.tokenizer_path}") + log0(f"train_loader:dataset:{dataset_dir.name} train_shards:{actual_train_files}") + log0(f"val_loader:shards pattern={args.val_files} tokens:{val_tokens.numel() - 1}") + + # ----------------------------- + # MODEL + OPTIMIZER SETUP + # ----------------------------- + + base_model = GPT( + vocab_size=args.vocab_size, + num_layers=args.num_layers, + model_dim=args.model_dim, + num_heads=args.num_heads, + num_kv_heads=args.num_kv_heads, + mlp_mult=args.mlp_mult, + tie_embeddings=args.tie_embeddings, + tied_embed_init_std=args.tied_embed_init_std, + logit_softcap=args.logit_softcap, + rope_base=args.rope_base, + qk_gain_init=args.qk_gain_init, + num_prelude=args.num_prelude, + num_coda=args.num_coda, + num_shared=args.num_shared, + num_loops=args.num_loops, + use_peri_norm=args.use_peri_norm, + use_birkhoff_mix=args.use_birkhoff_mix, + use_timestep_scale=args.use_timestep_scale, + timestep_gamma_max=args.timestep_gamma_max, + leaky_relu_slope=args.leaky_relu_slope, + ).to(device).bfloat16() + for module in base_model.modules(): + if isinstance(module, CastedLinear): + module.float() + restore_low_dim_params_to_fp32(base_model) + if args.disable_compile: + compiled_model = base_model + else: + compiled_model = torch.compile(base_model, dynamic=False, fullgraph=True) + model: nn.Module = DDP(compiled_model, device_ids=[local_rank], broadcast_buffers=False) if distributed else compiled_model + + # Optimizer split: + # - token embedding (Adam) uses EMBED_LR + # - untied lm_head (Adam) uses HEAD_LR + # - matrix params in transformer blocks use MATRIX_LR via Muon + # - vectors/scalars use SCALAR_LR via Adam + if base_model.use_recurrence: + all_block_lists = [base_model.prelude_blocks, base_model.shared_blocks, base_model.coda_blocks] + block_named_params = [] + for bl in all_block_lists: + block_named_params.extend(bl.named_parameters()) + else: + block_named_params = 
list(base_model.blocks.named_parameters()) + matrix_params = [ + p + for name, p in block_named_params + if p.ndim == 2 and not any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) + ] + scalar_params = [ + p + for name, p in block_named_params + if p.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) + ] + if hasattr(base_model, "skip_weights") and base_model.skip_weights.numel() > 0: + scalar_params.append(base_model.skip_weights) + if base_model.timestep_scale is not None: + for p in base_model.timestep_scale.parameters(): + scalar_params.append(p) + token_lr = args.tied_embed_lr if args.tie_embeddings else args.embed_lr + optimizer_tok = torch.optim.Adam( + [{"params": [base_model.tok_emb.weight], "lr": token_lr, "base_lr": token_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=True, + ) + optimizer_muon = Muon( + matrix_params, + lr=args.matrix_lr, + momentum=args.muon_momentum, + backend_steps=args.muon_backend_steps, + ) + for group in optimizer_muon.param_groups: + group["base_lr"] = args.matrix_lr + optimizer_scalar = torch.optim.Adam( + [{"params": scalar_params, "lr": args.scalar_lr, "base_lr": args.scalar_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=True, + ) + optimizers: list[torch.optim.Optimizer] = [optimizer_tok, optimizer_muon, optimizer_scalar] + if base_model.lm_head is not None: + optimizer_head = torch.optim.Adam( + [{"params": [base_model.lm_head.weight], "lr": args.head_lr, "base_lr": args.head_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=True, + ) + optimizers.insert(1, optimizer_head) + + n_params = sum(p.numel() for p in base_model.parameters()) + log0(f"model_params:{n_params}") + log0(f"world_size:{world_size} grad_accum_steps:{grad_accum_steps}") + log0("sdp_backends:cudnn=False flash=True mem_efficient=False math=False") + log0(f"flash_attention_3:{FA3_AVAILABLE}") + log0(f"attention_mode:gqa num_heads:{args.num_heads} 
num_kv_heads:{args.num_kv_heads}") + log0( + f"tie_embeddings:{args.tie_embeddings} embed_lr:{token_lr} " + f"head_lr:{args.head_lr if base_model.lm_head is not None else 0.0} " + f"matrix_lr:{args.matrix_lr} scalar_lr:{args.scalar_lr}" + ) + log0( + f"train_batch_tokens:{args.train_batch_tokens} train_seq_len:{args.train_seq_len} " + f"iterations:{args.iterations} warmup_steps:{args.warmup_steps} " + f"max_wallclock_seconds:{args.max_wallclock_seconds:.3f}" + ) + log0(f"seed:{args.seed}") + if base_model.use_recurrence: + eff = base_model.num_prelude + len(base_model.shared_blocks) * base_model.num_loops + base_model.num_coda + log0(f"recurrence:enabled prelude:{base_model.num_prelude} shared:{len(base_model.shared_blocks)} " + f"loops:{base_model.num_loops} coda:{base_model.num_coda} effective_layers:{eff}") + log0(f"peri_norm:{args.use_peri_norm} birkhoff_mix:{args.use_birkhoff_mix}") + if base_model.timestep_scale is not None: + ts_params = sum(p.numel() for p in base_model.timestep_scale.parameters()) + log0(f"timestep_scale:enabled params:{ts_params}") + else: + log0("timestep_scale:disabled") + else: + log0(f"recurrence:disabled num_layers:{args.num_layers}") + compile_mode = "disabled" if args.disable_compile else "fullgraph=True" + log0(f"compile_mode:{compile_mode}") + + # ----------------------------- + # DATA LOADER & MODEL WARMUP + # ----------------------------- + + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + + def zero_grad_all() -> None: + for opt in optimizers: + opt.zero_grad(set_to_none=True) + + max_wallclock_ms = 1000.0 * args.max_wallclock_seconds if args.max_wallclock_seconds > 0 else None + + def lr_mul(step: int, elapsed_ms: float) -> float: + if args.warmdown_iters <= 0: + return 1.0 + if max_wallclock_ms is None: + warmdown_start = max(args.iterations - args.warmdown_iters, 0) + return max((args.iterations - step) / max(args.warmdown_iters, 1), 0.0) if warmdown_start <= step < args.iterations else 
1.0 + step_ms = elapsed_ms / max(step, 1) + warmdown_ms = args.warmdown_iters * step_ms + remaining_ms = max(max_wallclock_ms - elapsed_ms, 0.0) + return remaining_ms / max(warmdown_ms, 1e-9) if remaining_ms <= warmdown_ms else 1.0 + + # Warmup primes the compiled forward/backward/optimizer paths, then we restore the + # initial weights/optimizer state so measured training starts from the true init. + if args.warmup_steps > 0: + initial_model_state = {name: tensor.detach().cpu().clone() for name, tensor in base_model.state_dict().items()} + initial_optimizer_states = [copy.deepcopy(opt.state_dict()) for opt in optimizers] + model.train() + for warmup_step in range(args.warmup_steps): + zero_grad_all() + for micro_step in range(grad_accum_steps): + if distributed: + model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 + x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + warmup_loss = model(x, y) + (warmup_loss * grad_scale).backward() + for opt in optimizers: + opt.step() + zero_grad_all() + if args.warmup_steps <= 20 or (warmup_step + 1) % 10 == 0 or warmup_step + 1 == args.warmup_steps: + log0(f"warmup_step:{warmup_step + 1}/{args.warmup_steps}") + base_model.load_state_dict(initial_model_state, strict=True) + for opt, state in zip(optimizers, initial_optimizer_states, strict=True): + opt.load_state_dict(state) + zero_grad_all() + if distributed: + model.require_backward_grad_sync = True + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + + # ----------------------------- + # MAIN TRAINING LOOP + # ----------------------------- + + training_time_ms = 0.0 + stop_after_step: int | None = None + torch.cuda.synchronize() + t0 = time.perf_counter() + + step = 0 + while True: + last_step = step == args.iterations or (stop_after_step is not None and step >= stop_after_step) + + should_validate = 
last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0) + if should_validate: + torch.cuda.synchronize() + training_time_ms += 1000.0 * (time.perf_counter() - t0) + val_loss, val_bpb = eval_val( + args, + model, + rank, + world_size, + device, + grad_accum_steps, + val_tokens, + base_bytes_lut, + has_leading_space_lut, + is_boundary_token_lut, + ) + log0( + f"step:{step}/{args.iterations} val_loss:{val_loss:.4f} val_bpb:{val_bpb:.4f} " + f"train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms / max(step, 1):.2f}ms" + ) + torch.cuda.synchronize() + t0 = time.perf_counter() + + if last_step: + if stop_after_step is not None and step < args.iterations: + log0( + f"stopping_early: wallclock_cap train_time:{training_time_ms:.0f}ms " + f"step:{step}/{args.iterations}" + ) + break + + elapsed_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + scale = lr_mul(step, elapsed_ms) + zero_grad_all() + train_loss = torch.zeros((), device=device) + for micro_step in range(grad_accum_steps): + if distributed: + model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 + x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + loss = model(x, y) + train_loss += loss.detach() + (loss * grad_scale).backward() + train_loss /= grad_accum_steps + + frac = min(step / args.muon_momentum_warmup_steps, 1.0) if args.muon_momentum_warmup_steps > 0 else 1.0 + muon_momentum = (1 - frac) * args.muon_momentum_warmup_start + frac * args.muon_momentum + for group in optimizer_muon.param_groups: + group["momentum"] = muon_momentum + + for opt in optimizers: + for group in opt.param_groups: + group["lr"] = group["base_lr"] * scale + + if args.grad_clip_norm > 0: + torch.nn.utils.clip_grad_norm_(base_model.parameters(), args.grad_clip_norm) + for opt in optimizers: + opt.step() + zero_grad_all() + + step += 1 + approx_training_time_ms = 
training_time_ms + 1000.0 * (time.perf_counter() - t0) + should_log_train = ( + args.train_log_every > 0 + and (step <= 10 or step % args.train_log_every == 0 or stop_after_step is not None) + ) + if should_log_train: + log0( + f"step:{step}/{args.iterations} train_loss:{train_loss.item():.4f} " + f"train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms / step:.2f}ms" + ) + if base_model.use_recurrence and step % args.train_log_every == 0: + diag = recurrence_param_diagnostics(base_model) + if diag: + log0(f"step:{step} {diag}") + + # Needed to sync whether we've reached the wallclock cap. + reached_cap = max_wallclock_ms is not None and approx_training_time_ms >= max_wallclock_ms + if distributed and max_wallclock_ms is not None: + reached_cap_tensor = torch.tensor(int(reached_cap), device=device) + dist.all_reduce(reached_cap_tensor, op=dist.ReduceOp.MAX) + reached_cap = bool(reached_cap_tensor.item()) + if stop_after_step is None and reached_cap: + stop_after_step = step + + log0( + f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB " + f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB" + ) + + # ----------------------------- + # SERIALIZATION + ROUNDTRIP VALIDATION + # ----------------------------- + # Save the raw state (useful for debugging/loading in PyTorch directly), then always produce + # the compressed int8+zlib artifact and validate the round-tripped weights. 
+ + if master_process: + torch.save(base_model.state_dict(), "final_model.pt") + model_bytes = os.path.getsize("final_model.pt") + code_bytes = len(code.encode("utf-8")) + log0(f"Serialized model: {model_bytes} bytes") + log0(f"Code size: {code_bytes} bytes") + log0(f"Total submission size: {model_bytes + code_bytes} bytes") + + quant_obj, quant_stats = quantize_state_dict_int8(base_model.state_dict()) + quant_buf = io.BytesIO() + torch.save(quant_obj, quant_buf) + quant_raw = quant_buf.getvalue() + quant_blob = zlib.compress(quant_raw, level=9) + quant_raw_bytes = len(quant_raw) + if master_process: + with open("final_model.int8.ptz", "wb") as f: + f.write(quant_blob) + quant_file_bytes = os.path.getsize("final_model.int8.ptz") + code_bytes = len(code.encode("utf-8")) + ratio = quant_stats["baseline_tensor_bytes"] / max(quant_stats["int8_payload_bytes"], 1) + log0( + f"Serialized model int8+zlib: {quant_file_bytes} bytes " + f"(payload:{quant_stats['int8_payload_bytes']} raw_torch:{quant_raw_bytes} payload_ratio:{ratio:.2f}x)" + ) + log0(f"Total submission size int8+zlib: {quant_file_bytes + code_bytes} bytes") + + if distributed: + dist.barrier() + with open("final_model.int8.ptz", "rb") as f: + quant_blob_disk = f.read() + quant_state = torch.load(io.BytesIO(zlib.decompress(quant_blob_disk)), map_location="cpu") + base_model.load_state_dict(dequantize_state_dict_int8(quant_state), strict=True) + torch.cuda.synchronize() + t_qeval = time.perf_counter() + q_val_loss, q_val_bpb = eval_val( + args, + model, + rank, + world_size, + device, + grad_accum_steps, + val_tokens, + base_bytes_lut, + has_leading_space_lut, + is_boundary_token_lut, + ) + torch.cuda.synchronize() + log0( + f"final_int8_zlib_roundtrip val_loss:{q_val_loss:.4f} val_bpb:{q_val_bpb:.4f} " + f"eval_time:{1000.0 * (time.perf_counter() - t_qeval):.0f}ms" + ) + log0(f"final_int8_zlib_roundtrip_exact val_loss:{q_val_loss:.8f} val_bpb:{q_val_bpb:.8f}") + + if distributed: + dist.destroy_process_group() 
+ + +if __name__ == "__main__": + main() + +==================================================================================================== +Running Python 3.12.3 (main, Nov 6 2025, 13:44:16) [GCC 13.3.0] +Running PyTorch 2.9.1+cu128 +Wed Mar 25 18:43:05 2026 ++-----------------------------------------------------------------------------------------+ +| NVIDIA-SMI 580.126.09 Driver Version: 580.126.09 CUDA Version: 13.0 | ++-----------------------------------------+------------------------+----------------------+ +| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC | +| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. | +| | | MIG M. | +|=========================================+========================+======================| +| 0 NVIDIA H100 80GB HBM3 On | 00000000:19:00.0 Off | 0 | +| N/A 33C P0 115W / 700W | 1521MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 1 NVIDIA H100 80GB HBM3 On | 00000000:3B:00.0 Off | 0 | +| N/A 31C P0 117W / 700W | 1521MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 2 NVIDIA H100 80GB HBM3 On | 00000000:4C:00.0 Off | 0 | +| N/A 30C P0 114W / 700W | 1521MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 3 NVIDIA H100 80GB HBM3 On | 00000000:5D:00.0 Off | 0 | +| N/A 32C P0 115W / 700W | 1521MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 4 NVIDIA H100 80GB HBM3 On | 00000000:9B:00.0 Off | 0 | +| N/A 33C P0 117W / 700W | 1521MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 5 NVIDIA H100 80GB HBM3 On | 00000000:BB:00.0 Off | 0 | +| N/A 32C 
P0 117W / 700W | 1521MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 6 NVIDIA H100 80GB HBM3 On | 00000000:CB:00.0 Off | 0 | +| N/A 32C P0 115W / 700W | 1521MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 7 NVIDIA H100 80GB HBM3 On | 00000000:DB:00.0 Off | 0 | +| N/A 31C P0 114W / 700W | 1521MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ + ++-----------------------------------------------------------------------------------------+ +| Processes: | +| GPU GI CI PID Type Process name GPU Memory | +| ID ID Usage | +|=========================================================================================| +| No running processes found | ++-----------------------------------------------------------------------------------------+ + +==================================================================================================== +val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path=./data/tokenizers/fineweb_1024_bpe.model +train_loader:dataset:fineweb10B_sp1024 train_shards:80 +val_loader:shards pattern=./data/datasets/fineweb10B_sp1024/fineweb_val_*.bin tokens:62021632 +model_params:9707048 +world_size:8 grad_accum_steps:1 +sdp_backends:cudnn=False flash=True mem_efficient=False math=False +flash_attention_3:True +attention_mode:gqa num_heads:8 num_kv_heads:4 +tie_embeddings:True embed_lr:0.05 head_lr:0.0 matrix_lr:0.04 scalar_lr:0.04 +train_batch_tokens:524288 train_seq_len:1024 iterations:20000 warmup_steps:20 max_wallclock_seconds:600.000 +seed:1337 +recurrence:enabled prelude:1 shared:3 loops:2 coda:1 effective_layers:8 +peri_norm:True birkhoff_mix:True +timestep_scale:disabled +compile_mode:fullgraph=True +warmup_step:1/20 +warmup_step:2/20 +warmup_step:3/20 
+warmup_step:4/20 +warmup_step:5/20 +warmup_step:6/20 +warmup_step:7/20 +warmup_step:8/20 +warmup_step:9/20 +warmup_step:10/20 +warmup_step:11/20 +warmup_step:12/20 +warmup_step:13/20 +warmup_step:14/20 +warmup_step:15/20 +warmup_step:16/20 +warmup_step:17/20 +warmup_step:18/20 +warmup_step:19/20 +warmup_step:20/20 +step:0/20000 val_loss:6.9352 val_bpb:4.1074 train_time:0ms step_avg:0.02ms +step:1/20000 train_loss:6.9353 train_time:22ms step_avg:22.10ms +step:2/20000 train_loss:9.4937 train_time:58ms step_avg:28.77ms +step:3/20000 train_loss:7.3142 train_time:91ms step_avg:30.24ms +step:4/20000 train_loss:8.9487 train_time:126ms step_avg:31.42ms +step:5/20000 train_loss:8.0620 train_time:160ms step_avg:31.98ms +step:6/20000 train_loss:7.6401 train_time:195ms step_avg:32.44ms +step:7/20000 train_loss:6.3141 train_time:229ms step_avg:32.69ms +step:8/20000 train_loss:5.9700 train_time:264ms step_avg:32.96ms +step:9/20000 train_loss:5.3965 train_time:297ms step_avg:33.04ms +step:10/20000 train_loss:5.1330 train_time:333ms step_avg:33.28ms +step:200/20000 train_loss:2.8022 train_time:6957ms step_avg:34.78ms +step:200 shared0_alpha:mean=0.491,std=0.052 shared1_alpha:mean=0.508,std=0.042 shared2_alpha:mean=0.534,std=0.044 eff_mlp_scale:[v0:1.4084 v1:1.2608 v2:1.3505 v3:1.4336 v4:1.2608 v5:1.3505 v6:1.4336 v7:1.9251] eff_attn_scale:[v0:0.7424 v1:0.4957 v2:0.5418 v3:0.4623 v4:0.4957 v5:0.5418 v6:0.4623 v7:0.7527] +step:200/20000 val_loss:2.7773 val_bpb:1.6449 train_time:7003ms step_avg:35.01ms +step:400/20000 train_loss:2.3827 train_time:13913ms step_avg:34.78ms +step:400 shared0_alpha:mean=0.518,std=0.054 shared1_alpha:mean=0.532,std=0.044 shared2_alpha:mean=0.558,std=0.047 eff_mlp_scale:[v0:1.7554 v1:1.5701 v2:1.6003 v3:1.5795 v4:1.5701 v5:1.6003 v6:1.5795 v7:2.4462] eff_attn_scale:[v0:0.3185 v1:0.2576 v2:0.2570 v3:0.2151 v4:0.2576 v5:0.2570 v6:0.2151 v7:0.4801] +step:400/20000 val_loss:2.5899 val_bpb:1.5339 train_time:13919ms step_avg:34.80ms +step:600/20000 
train_loss:2.6007 train_time:20826ms step_avg:34.71ms +step:600 shared0_alpha:mean=0.533,std=0.053 shared1_alpha:mean=0.541,std=0.045 shared2_alpha:mean=0.572,std=0.047 eff_mlp_scale:[v0:1.9835 v1:1.7994 v2:1.7788 v3:1.6967 v4:1.7994 v5:1.7788 v6:1.6967 v7:2.9304] eff_attn_scale:[v0:0.1547 v1:0.1584 v2:0.1518 v3:0.1364 v4:0.1584 v5:0.1518 v6:0.1364 v7:0.3441] +step:600/20000 val_loss:2.4992 val_bpb:1.4802 train_time:20835ms step_avg:34.72ms +step:800/20000 train_loss:2.3559 train_time:27767ms step_avg:34.71ms +step:800 shared0_alpha:mean=0.541,std=0.053 shared1_alpha:mean=0.545,std=0.046 shared2_alpha:mean=0.578,std=0.049 eff_mlp_scale:[v0:2.1978 v1:2.0035 v2:1.9362 v3:1.8210 v4:2.0035 v5:1.9362 v6:1.8210 v7:3.3051] eff_attn_scale:[v0:0.1028 v1:0.1169 v2:0.1136 v3:0.1081 v4:0.1169 v5:0.1136 v6:0.1081 v7:0.2905] +step:800/20000 val_loss:2.4465 val_bpb:1.4490 train_time:27772ms step_avg:34.71ms +step:1000/20000 train_loss:2.4368 train_time:34683ms step_avg:34.68ms +step:1000 shared0_alpha:mean=0.545,std=0.054 shared1_alpha:mean=0.547,std=0.048 shared2_alpha:mean=0.580,std=0.050 eff_mlp_scale:[v0:2.3986 v1:2.1927 v2:2.0915 v3:1.9546 v4:2.1927 v5:2.0915 v6:1.9546 v7:3.6234] eff_attn_scale:[v0:0.0846 v1:0.1006 v2:0.0984 v3:0.0953 v4:0.1006 v5:0.0984 v6:0.0953 v7:0.2664] +step:1000/20000 val_loss:2.4048 val_bpb:1.4242 train_time:34688ms step_avg:34.69ms +step:1200/20000 train_loss:2.4609 train_time:41588ms step_avg:34.66ms +step:1200 shared0_alpha:mean=0.548,std=0.055 shared1_alpha:mean=0.548,std=0.048 shared2_alpha:mean=0.580,std=0.051 eff_mlp_scale:[v0:2.5822 v1:2.3689 v2:2.2370 v3:2.0841 v4:2.3689 v5:2.2370 v6:2.0841 v7:3.9041] eff_attn_scale:[v0:0.0724 v1:0.0906 v2:0.0885 v3:0.0870 v4:0.0906 v5:0.0885 v6:0.0870 v7:0.2534] +step:1200/20000 val_loss:2.3778 val_bpb:1.4083 train_time:41592ms step_avg:34.66ms +step:1400/20000 train_loss:2.5070 train_time:48498ms step_avg:34.64ms +step:1400 shared0_alpha:mean=0.548,std=0.055 shared1_alpha:mean=0.548,std=0.049 
shared2_alpha:mean=0.580,std=0.052 eff_mlp_scale:[v0:2.7593 v1:2.5298 v2:2.3766 v3:2.2070 v4:2.5298 v5:2.3766 v6:2.2070 v7:4.1589] eff_attn_scale:[v0:0.0679 v1:0.0850 v2:0.0830 v3:0.0831 v4:0.0850 v5:0.0830 v6:0.0831 v7:0.2448] +step:1400/20000 val_loss:2.3588 val_bpb:1.3970 train_time:48505ms step_avg:34.65ms +step:1600/20000 train_loss:2.1775 train_time:55408ms step_avg:34.63ms +step:1600 shared0_alpha:mean=0.548,std=0.055 shared1_alpha:mean=0.549,std=0.051 shared2_alpha:mean=0.579,std=0.053 eff_mlp_scale:[v0:2.9237 v1:2.6860 v2:2.5072 v3:2.3156 v4:2.6860 v5:2.5072 v6:2.3156 v7:4.3990] eff_attn_scale:[v0:0.0649 v1:0.0802 v2:0.0806 v3:0.0806 v4:0.0802 v5:0.0806 v6:0.0806 v7:0.2380] +step:1600/20000 val_loss:2.3481 val_bpb:1.3907 train_time:55412ms step_avg:34.63ms +step:1800/20000 train_loss:2.2843 train_time:62321ms step_avg:34.62ms +step:1800 shared0_alpha:mean=0.549,std=0.056 shared1_alpha:mean=0.549,std=0.051 shared2_alpha:mean=0.578,std=0.053 eff_mlp_scale:[v0:3.0765 v1:2.8330 v2:2.6350 v3:2.4312 v4:2.8330 v5:2.6350 v6:2.4312 v7:4.6170] eff_attn_scale:[v0:0.0620 v1:0.0774 v2:0.0783 v3:0.0785 v4:0.0774 v5:0.0783 v6:0.0785 v7:0.2297] +step:1800/20000 val_loss:2.3335 val_bpb:1.3820 train_time:62326ms step_avg:34.63ms +step:2000/20000 train_loss:2.3342 train_time:69240ms step_avg:34.62ms +step:2000 shared0_alpha:mean=0.549,std=0.055 shared1_alpha:mean=0.550,std=0.053 shared2_alpha:mean=0.576,std=0.054 eff_mlp_scale:[v0:3.2369 v1:2.9837 v2:2.7575 v3:2.5377 v4:2.9837 v5:2.7575 v6:2.5377 v7:4.8270] eff_attn_scale:[v0:0.0598 v1:0.0756 v2:0.0767 v3:0.0771 v4:0.0756 v5:0.0767 v6:0.0771 v7:0.2229] +step:2000/20000 val_loss:2.3197 val_bpb:1.3738 train_time:69242ms step_avg:34.62ms +step:2200/20000 train_loss:2.1637 train_time:76157ms step_avg:34.62ms +step:2200 shared0_alpha:mean=0.549,std=0.055 shared1_alpha:mean=0.552,std=0.053 shared2_alpha:mean=0.575,std=0.055 eff_mlp_scale:[v0:3.3799 v1:3.1209 v2:2.8755 v3:2.6468 v4:3.1209 v5:2.8755 v6:2.6468 v7:5.0210] 
eff_attn_scale:[v0:0.0583 v1:0.0747 v2:0.0766 v3:0.0780 v4:0.0747 v5:0.0766 v6:0.0780 v7:0.2220] +step:2200/20000 val_loss:2.3124 val_bpb:1.3695 train_time:76160ms step_avg:34.62ms +step:2400/20000 train_loss:2.2869 train_time:83073ms step_avg:34.61ms +step:2400 shared0_alpha:mean=0.548,std=0.056 shared1_alpha:mean=0.552,std=0.054 shared2_alpha:mean=0.572,std=0.055 eff_mlp_scale:[v0:3.5190 v1:3.2580 v2:2.9902 v3:2.7516 v4:3.2580 v5:2.9902 v6:2.7516 v7:5.2163] eff_attn_scale:[v0:0.0575 v1:0.0737 v2:0.0749 v3:0.0773 v4:0.0737 v5:0.0749 v6:0.0773 v7:0.2222] +step:2400/20000 val_loss:2.3035 val_bpb:1.3643 train_time:83077ms step_avg:34.62ms +step:2600/20000 train_loss:2.5088 train_time:89991ms step_avg:34.61ms +step:2600 shared0_alpha:mean=0.549,std=0.057 shared1_alpha:mean=0.554,std=0.055 shared2_alpha:mean=0.571,std=0.056 eff_mlp_scale:[v0:3.6587 v1:3.3925 v2:3.1036 v3:2.8535 v4:3.3925 v5:3.1036 v6:2.8535 v7:5.3891] eff_attn_scale:[v0:0.0600 v1:0.0754 v2:0.0779 v3:0.0807 v4:0.0754 v5:0.0779 v6:0.0807 v7:0.2240] +step:2600/20000 val_loss:2.3250 val_bpb:1.3770 train_time:89993ms step_avg:34.61ms +step:2800/20000 train_loss:2.3293 train_time:96895ms step_avg:34.61ms +step:2800 shared0_alpha:mean=0.548,std=0.057 shared1_alpha:mean=0.555,std=0.055 shared2_alpha:mean=0.569,std=0.056 eff_mlp_scale:[v0:3.8020 v1:3.5227 v2:3.2193 v3:2.9508 v4:3.5227 v5:3.2193 v6:2.9508 v7:5.5719] eff_attn_scale:[v0:0.0580 v1:0.0742 v2:0.0770 v3:0.0793 v4:0.0742 v5:0.0770 v6:0.0793 v7:0.2237] +step:2800/20000 val_loss:2.2927 val_bpb:1.3579 train_time:96899ms step_avg:34.61ms +step:3000/20000 train_loss:2.3146 train_time:103818ms step_avg:34.61ms +step:3000 shared0_alpha:mean=0.549,std=0.058 shared1_alpha:mean=0.556,std=0.056 shared2_alpha:mean=0.568,std=0.057 eff_mlp_scale:[v0:3.9287 v1:3.6407 v2:3.3236 v3:3.0458 v4:3.6407 v5:3.3236 v6:3.0458 v7:5.7380] eff_attn_scale:[v0:0.0565 v1:0.0729 v2:0.0764 v3:0.0795 v4:0.0729 v5:0.0764 v6:0.0795 v7:0.2250] +step:3000/20000 val_loss:2.2852 
val_bpb:1.3534 train_time:103820ms step_avg:34.61ms +step:3200/20000 train_loss:2.2808 train_time:110718ms step_avg:34.60ms +step:3200 shared0_alpha:mean=0.549,std=0.058 shared1_alpha:mean=0.558,std=0.057 shared2_alpha:mean=0.567,std=0.057 eff_mlp_scale:[v0:4.0574 v1:3.7598 v2:3.4281 v3:3.1385 v4:3.7598 v5:3.4281 v6:3.1385 v7:5.8908] eff_attn_scale:[v0:0.0558 v1:0.0734 v2:0.0771 v3:0.0791 v4:0.0734 v5:0.0771 v6:0.0791 v7:0.2251] +step:3200/20000 val_loss:2.2805 val_bpb:1.3507 train_time:110720ms step_avg:34.60ms +step:3400/20000 train_loss:2.2502 train_time:117628ms step_avg:34.60ms +step:3400 shared0_alpha:mean=0.548,std=0.059 shared1_alpha:mean=0.558,std=0.058 shared2_alpha:mean=0.565,std=0.058 eff_mlp_scale:[v0:4.1810 v1:3.8761 v2:3.5315 v3:3.2298 v4:3.8761 v5:3.5315 v6:3.2298 v7:6.0423] eff_attn_scale:[v0:0.0563 v1:0.0730 v2:0.0780 v3:0.0814 v4:0.0730 v5:0.0780 v6:0.0814 v7:0.2280] +step:3400/20000 val_loss:2.2789 val_bpb:1.3497 train_time:117631ms step_avg:34.60ms +step:3600/20000 train_loss:2.2185 train_time:124547ms step_avg:34.60ms +step:3600 shared0_alpha:mean=0.549,std=0.059 shared1_alpha:mean=0.560,std=0.058 shared2_alpha:mean=0.563,std=0.058 eff_mlp_scale:[v0:4.2993 v1:3.9852 v2:3.6294 v3:3.3170 v4:3.9852 v5:3.6294 v6:3.3170 v7:6.1840] eff_attn_scale:[v0:0.0554 v1:0.0733 v2:0.0791 v3:0.0828 v4:0.0733 v5:0.0791 v6:0.0828 v7:0.2299] +step:3600/20000 val_loss:2.2709 val_bpb:1.3450 train_time:124550ms step_avg:34.60ms +step:3800/20000 train_loss:2.3161 train_time:131473ms step_avg:34.60ms +step:3800 shared0_alpha:mean=0.549,std=0.059 shared1_alpha:mean=0.562,std=0.059 shared2_alpha:mean=0.562,std=0.058 eff_mlp_scale:[v0:4.4141 v1:4.0914 v2:3.7205 v3:3.4000 v4:4.0914 v5:3.7205 v6:3.4000 v7:6.3234] eff_attn_scale:[v0:0.0552 v1:0.0751 v2:0.0809 v3:0.0838 v4:0.0751 v5:0.0809 v6:0.0838 v7:0.2323] +step:3800/20000 val_loss:2.2682 val_bpb:1.3434 train_time:131477ms step_avg:34.60ms +step:4000/20000 train_loss:2.2547 train_time:138386ms step_avg:34.60ms +step:4000 
shared0_alpha:mean=0.549,std=0.060 shared1_alpha:mean=0.563,std=0.059 shared2_alpha:mean=0.560,std=0.058 eff_mlp_scale:[v0:4.5179 v1:4.1917 v2:3.8160 v3:3.4805 v4:4.1917 v5:3.8160 v6:3.4805 v7:6.4612] eff_attn_scale:[v0:0.0546 v1:0.0752 v2:0.0814 v3:0.0857 v4:0.0752 v5:0.0814 v6:0.0857 v7:0.2331] +step:4000/20000 val_loss:2.2633 val_bpb:1.3405 train_time:138389ms step_avg:34.60ms +step:4200/20000 train_loss:2.2648 train_time:145353ms step_avg:34.61ms +step:4200 shared0_alpha:mean=0.548,std=0.060 shared1_alpha:mean=0.564,std=0.060 shared2_alpha:mean=0.558,std=0.059 eff_mlp_scale:[v0:4.6153 v1:4.2876 v2:3.9019 v3:3.5570 v4:4.2876 v5:3.9019 v6:3.5570 v7:6.5924] eff_attn_scale:[v0:0.0544 v1:0.0772 v2:0.0829 v3:0.0881 v4:0.0772 v5:0.0829 v6:0.0881 v7:0.2358] +step:4200/20000 val_loss:2.2602 val_bpb:1.3386 train_time:145362ms step_avg:34.61ms +step:4400/20000 train_loss:2.2079 train_time:152280ms step_avg:34.61ms +step:4400 shared0_alpha:mean=0.548,std=0.060 shared1_alpha:mean=0.566,std=0.060 shared2_alpha:mean=0.557,std=0.059 eff_mlp_scale:[v0:4.7197 v1:4.3831 v2:3.9860 v3:3.6300 v4:4.3831 v5:3.9860 v6:3.6300 v7:6.7170] eff_attn_scale:[v0:0.0545 v1:0.0798 v2:0.0859 v3:0.0916 v4:0.0798 v5:0.0859 v6:0.0916 v7:0.2384] +step:4400/20000 val_loss:2.2612 val_bpb:1.3392 train_time:152287ms step_avg:34.61ms +step:4600/20000 train_loss:2.0684 train_time:159190ms step_avg:34.61ms +step:4600 shared0_alpha:mean=0.548,std=0.060 shared1_alpha:mean=0.567,std=0.061 shared2_alpha:mean=0.555,std=0.059 eff_mlp_scale:[v0:4.8144 v1:4.4714 v2:4.0692 v3:3.7098 v4:4.4714 v5:4.0692 v6:3.7098 v7:6.8445] eff_attn_scale:[v0:0.0547 v1:0.0831 v2:0.0908 v3:0.0956 v4:0.0831 v5:0.0908 v6:0.0956 v7:0.2422] +step:4600/20000 val_loss:2.2585 val_bpb:1.3376 train_time:159197ms step_avg:34.61ms +step:4800/20000 train_loss:2.3598 train_time:166112ms step_avg:34.61ms +step:4800 shared0_alpha:mean=0.547,std=0.061 shared1_alpha:mean=0.568,std=0.061 shared2_alpha:mean=0.554,std=0.060 eff_mlp_scale:[v0:4.8925 
v1:4.5538 v2:4.1453 v3:3.7736 v4:4.5538 v5:4.1453 v6:3.7736 v7:6.9643] eff_attn_scale:[v0:0.0541 v1:0.0866 v2:0.0936 v3:0.0987 v4:0.0866 v5:0.0936 v6:0.0987 v7:0.2438] +step:4800/20000 val_loss:2.2539 val_bpb:1.3349 train_time:166119ms step_avg:34.61ms +step:5000/20000 train_loss:2.1338 train_time:173033ms step_avg:34.61ms +step:5000 shared0_alpha:mean=0.546,std=0.060 shared1_alpha:mean=0.569,std=0.062 shared2_alpha:mean=0.552,std=0.060 eff_mlp_scale:[v0:4.9721 v1:4.6239 v2:4.2142 v3:3.8348 v4:4.6239 v5:4.2142 v6:3.8348 v7:7.0714] eff_attn_scale:[v0:0.0535 v1:0.0888 v2:0.0955 v3:0.1019 v4:0.0888 v5:0.0955 v6:0.1019 v7:0.2460] +step:5000/20000 val_loss:2.2489 val_bpb:1.3319 train_time:173038ms step_avg:34.61ms +step:5200/20000 train_loss:2.2686 train_time:179938ms step_avg:34.60ms +step:5200 shared0_alpha:mean=0.544,std=0.061 shared1_alpha:mean=0.570,std=0.062 shared2_alpha:mean=0.551,std=0.060 eff_mlp_scale:[v0:5.0491 v1:4.6923 v2:4.2822 v3:3.8973 v4:4.6923 v5:4.2822 v6:3.8973 v7:7.1794] eff_attn_scale:[v0:0.0546 v1:0.0935 v2:0.1018 v3:0.1094 v4:0.0935 v5:0.1018 v6:0.1094 v7:0.2498] +step:5200/20000 val_loss:2.2508 val_bpb:1.3330 train_time:179941ms step_avg:34.60ms +step:5400/20000 train_loss:2.2818 train_time:186837ms step_avg:34.60ms +step:5400 shared0_alpha:mean=0.543,std=0.062 shared1_alpha:mean=0.571,std=0.063 shared2_alpha:mean=0.550,std=0.060 eff_mlp_scale:[v0:5.1234 v1:4.7550 v2:4.3455 v3:3.9584 v4:4.7550 v5:4.3455 v6:3.9584 v7:7.2892] eff_attn_scale:[v0:0.0542 v1:0.0957 v2:0.1056 v3:0.1133 v4:0.0957 v5:0.1056 v6:0.1133 v7:0.2516] +step:5400/20000 val_loss:2.2448 val_bpb:1.3295 train_time:186840ms step_avg:34.60ms +step:5600/20000 train_loss:2.2856 train_time:193799ms step_avg:34.61ms +step:5600 shared0_alpha:mean=0.542,std=0.062 shared1_alpha:mean=0.572,std=0.063 shared2_alpha:mean=0.549,std=0.060 eff_mlp_scale:[v0:5.1908 v1:4.8169 v2:4.4046 v3:4.0119 v4:4.8169 v5:4.4046 v6:4.0119 v7:7.3997] eff_attn_scale:[v0:0.0535 v1:0.1005 v2:0.1101 v3:0.1180 
v4:0.1005 v5:0.1101 v6:0.1180 v7:0.2539] +step:5600/20000 val_loss:2.2457 val_bpb:1.3301 train_time:193804ms step_avg:34.61ms +step:5800/20000 train_loss:2.2476 train_time:200738ms step_avg:34.61ms +step:5800 shared0_alpha:mean=0.541,std=0.062 shared1_alpha:mean=0.573,std=0.064 shared2_alpha:mean=0.548,std=0.060 eff_mlp_scale:[v0:5.2639 v1:4.8770 v2:4.4600 v3:4.0628 v4:4.8770 v5:4.4600 v6:4.0628 v7:7.5122] eff_attn_scale:[v0:0.0549 v1:0.1042 v2:0.1141 v3:0.1227 v4:0.1042 v5:0.1141 v6:0.1227 v7:0.2583] +step:5800/20000 val_loss:2.2445 val_bpb:1.3293 train_time:200743ms step_avg:34.61ms +step:6000/20000 train_loss:2.3098 train_time:207669ms step_avg:34.61ms +step:6000 shared0_alpha:mean=0.540,std=0.062 shared1_alpha:mean=0.574,std=0.063 shared2_alpha:mean=0.547,std=0.060 eff_mlp_scale:[v0:5.3307 v1:4.9271 v2:4.5129 v3:4.1197 v4:4.9271 v5:4.5129 v6:4.1197 v7:7.6164] eff_attn_scale:[v0:0.0543 v1:0.1056 v2:0.1162 v3:0.1268 v4:0.1056 v5:0.1162 v6:0.1268 v7:0.2607] +step:6000/20000 val_loss:2.2398 val_bpb:1.3265 train_time:207674ms step_avg:34.61ms +step:6200/20000 train_loss:2.1860 train_time:214594ms step_avg:34.61ms +step:6200 shared0_alpha:mean=0.538,std=0.063 shared1_alpha:mean=0.575,std=0.064 shared2_alpha:mean=0.546,std=0.061 eff_mlp_scale:[v0:5.3893 v1:4.9819 v2:4.5645 v3:4.1654 v4:4.9819 v5:4.5645 v6:4.1654 v7:7.7404] eff_attn_scale:[v0:0.0543 v1:0.1067 v2:0.1167 v3:0.1283 v4:0.1067 v5:0.1167 v6:0.1283 v7:0.2556] +step:6200/20000 val_loss:2.2397 val_bpb:1.3265 train_time:214597ms step_avg:34.61ms +step:6400/20000 train_loss:2.2620 train_time:221493ms step_avg:34.61ms +step:6400 shared0_alpha:mean=0.538,std=0.064 shared1_alpha:mean=0.576,std=0.065 shared2_alpha:mean=0.545,std=0.061 eff_mlp_scale:[v0:5.4549 v1:5.0344 v2:4.6225 v3:4.2237 v4:5.0344 v5:4.6225 v6:4.2237 v7:7.8420] eff_attn_scale:[v0:0.0541 v1:0.1090 v2:0.1196 v3:0.1325 v4:0.1090 v5:0.1196 v6:0.1325 v7:0.2621] +step:6400/20000 val_loss:2.2372 val_bpb:1.3250 train_time:221498ms step_avg:34.61ms 
+step:6600/20000 train_loss:2.2288 train_time:228396ms step_avg:34.61ms +step:6600 shared0_alpha:mean=0.537,std=0.064 shared1_alpha:mean=0.576,std=0.065 shared2_alpha:mean=0.544,std=0.061 eff_mlp_scale:[v0:5.5168 v1:5.0877 v2:4.6782 v3:4.2755 v4:5.0877 v5:4.6782 v6:4.2755 v7:7.9403] eff_attn_scale:[v0:0.0542 v1:0.1096 v2:0.1202 v3:0.1328 v4:0.1096 v5:0.1202 v6:0.1328 v7:0.2606] +step:6600/20000 val_loss:2.2347 val_bpb:1.3235 train_time:228399ms step_avg:34.61ms +step:6800/20000 train_loss:2.2917 train_time:235303ms step_avg:34.60ms +step:6800 shared0_alpha:mean=0.535,std=0.064 shared1_alpha:mean=0.577,std=0.065 shared2_alpha:mean=0.544,std=0.062 eff_mlp_scale:[v0:5.5737 v1:5.1390 v2:4.7359 v3:4.3262 v4:5.1390 v5:4.7359 v6:4.3262 v7:8.0384] eff_attn_scale:[v0:0.0543 v1:0.1094 v2:0.1194 v3:0.1351 v4:0.1094 v5:0.1194 v6:0.1351 v7:0.2616] +step:6800/20000 val_loss:2.2349 val_bpb:1.3236 train_time:235305ms step_avg:34.60ms +step:7000/20000 train_loss:2.3219 train_time:242191ms step_avg:34.60ms +step:7000 shared0_alpha:mean=0.535,std=0.064 shared1_alpha:mean=0.578,std=0.066 shared2_alpha:mean=0.542,std=0.061 eff_mlp_scale:[v0:5.6303 v1:5.1920 v2:4.7878 v3:4.3830 v4:5.1920 v5:4.7878 v6:4.3830 v7:8.1406] eff_attn_scale:[v0:0.0545 v1:0.1117 v2:0.1220 v3:0.1379 v4:0.1117 v5:0.1220 v6:0.1379 v7:0.2641] +step:7000/20000 val_loss:2.2305 val_bpb:1.3210 train_time:242196ms step_avg:34.60ms +step:7200/20000 train_loss:2.2988 train_time:249093ms step_avg:34.60ms +step:7200 shared0_alpha:mean=0.533,std=0.064 shared1_alpha:mean=0.578,std=0.066 shared2_alpha:mean=0.541,std=0.061 eff_mlp_scale:[v0:5.6851 v1:5.2440 v2:4.8448 v3:4.4309 v4:5.2440 v5:4.8448 v6:4.4309 v7:8.2399] eff_attn_scale:[v0:0.0539 v1:0.1107 v2:0.1238 v3:0.1407 v4:0.1107 v5:0.1238 v6:0.1407 v7:0.2656] +step:7200/20000 val_loss:2.2306 val_bpb:1.3211 train_time:249095ms step_avg:34.60ms +step:7400/20000 train_loss:2.2177 train_time:255999ms step_avg:34.59ms +step:7400 shared0_alpha:mean=0.533,std=0.065 
shared1_alpha:mean=0.579,std=0.066 shared2_alpha:mean=0.540,std=0.061 eff_mlp_scale:[v0:5.7430 v1:5.2904 v2:4.9020 v3:4.4788 v4:5.2904 v5:4.9020 v6:4.4788 v7:8.3367] eff_attn_scale:[v0:0.0541 v1:0.1110 v2:0.1231 v3:0.1422 v4:0.1110 v5:0.1231 v6:0.1422 v7:0.2669] +step:7400/20000 val_loss:2.2288 val_bpb:1.3200 train_time:256004ms step_avg:34.60ms +step:7600/20000 train_loss:2.0980 train_time:262889ms step_avg:34.59ms +step:7600 shared0_alpha:mean=0.532,std=0.065 shared1_alpha:mean=0.580,std=0.066 shared2_alpha:mean=0.539,std=0.062 eff_mlp_scale:[v0:5.7964 v1:5.3423 v2:4.9552 v3:4.5253 v4:5.3423 v5:4.9552 v6:4.5253 v7:8.4350] eff_attn_scale:[v0:0.0545 v1:0.1127 v2:0.1247 v3:0.1454 v4:0.1127 v5:0.1247 v6:0.1454 v7:0.2695] +step:7600/20000 val_loss:2.2265 val_bpb:1.3186 train_time:262893ms step_avg:34.59ms +step:7800/20000 train_loss:2.2415 train_time:269816ms step_avg:34.59ms +step:7800 shared0_alpha:mean=0.531,std=0.065 shared1_alpha:mean=0.581,std=0.066 shared2_alpha:mean=0.538,std=0.062 eff_mlp_scale:[v0:5.8558 v1:5.3951 v2:5.0059 v3:4.5747 v4:5.3951 v5:5.0059 v6:4.5747 v7:8.5330] eff_attn_scale:[v0:0.0537 v1:0.1123 v2:0.1240 v3:0.1474 v4:0.1123 v5:0.1240 v6:0.1474 v7:0.2714] +step:7800/20000 val_loss:2.2247 val_bpb:1.3176 train_time:269827ms step_avg:34.59ms +step:8000/20000 train_loss:2.2113 train_time:276726ms step_avg:34.59ms +step:8000 shared0_alpha:mean=0.531,std=0.065 shared1_alpha:mean=0.582,std=0.067 shared2_alpha:mean=0.538,std=0.062 eff_mlp_scale:[v0:5.9098 v1:5.4464 v2:5.0608 v3:4.6229 v4:5.4464 v5:5.0608 v6:4.6229 v7:8.6223] eff_attn_scale:[v0:0.0537 v1:0.1129 v2:0.1259 v3:0.1485 v4:0.1129 v5:0.1259 v6:0.1485 v7:0.2729] +step:8000/20000 val_loss:2.2215 val_bpb:1.3157 train_time:276730ms step_avg:34.59ms +step:8200/20000 train_loss:2.2767 train_time:283629ms step_avg:34.59ms +step:8200 shared0_alpha:mean=0.529,std=0.065 shared1_alpha:mean=0.583,std=0.067 shared2_alpha:mean=0.537,std=0.062 eff_mlp_scale:[v0:5.9602 v1:5.4968 v2:5.1166 v3:4.6674 v4:5.4968 
v5:5.1166 v6:4.6674 v7:8.7195] eff_attn_scale:[v0:0.0537 v1:0.1132 v2:0.1267 v3:0.1498 v4:0.1132 v5:0.1267 v6:0.1498 v7:0.2753] +step:8200/20000 val_loss:2.2209 val_bpb:1.3153 train_time:283632ms step_avg:34.59ms +step:8400/20000 train_loss:2.2299 train_time:290626ms step_avg:34.60ms +step:8400 shared0_alpha:mean=0.529,std=0.066 shared1_alpha:mean=0.583,std=0.067 shared2_alpha:mean=0.536,std=0.062 eff_mlp_scale:[v0:6.0147 v1:5.5436 v2:5.1712 v3:4.7140 v4:5.5436 v5:5.1712 v6:4.7140 v7:8.8149] eff_attn_scale:[v0:0.0544 v1:0.1130 v2:0.1273 v3:0.1521 v4:0.1130 v5:0.1273 v6:0.1521 v7:0.2782] +step:8400/20000 val_loss:2.2209 val_bpb:1.3153 train_time:290634ms step_avg:34.60ms +step:8600/20000 train_loss:2.2381 train_time:297504ms step_avg:34.59ms +step:8600 shared0_alpha:mean=0.528,std=0.065 shared1_alpha:mean=0.584,std=0.067 shared2_alpha:mean=0.534,std=0.062 eff_mlp_scale:[v0:6.0733 v1:5.5949 v2:5.2215 v3:4.7595 v4:5.5949 v5:5.2215 v6:4.7595 v7:8.9087] eff_attn_scale:[v0:0.0545 v1:0.1141 v2:0.1295 v3:0.1541 v4:0.1141 v5:0.1295 v6:0.1541 v7:0.2807] +step:8600/20000 val_loss:2.2192 val_bpb:1.3143 train_time:297511ms step_avg:34.59ms +step:8800/20000 train_loss:2.2067 train_time:304423ms step_avg:34.59ms +step:8800 shared0_alpha:mean=0.527,std=0.065 shared1_alpha:mean=0.584,std=0.068 shared2_alpha:mean=0.534,std=0.062 eff_mlp_scale:[v0:6.1265 v1:5.6422 v2:5.2751 v3:4.8048 v4:5.6422 v5:5.2751 v6:4.8048 v7:8.9969] eff_attn_scale:[v0:0.0541 v1:0.1148 v2:0.1281 v3:0.1563 v4:0.1148 v5:0.1281 v6:0.1563 v7:0.2812] +step:8800/20000 val_loss:2.2175 val_bpb:1.3133 train_time:304431ms step_avg:34.59ms +step:9000/20000 train_loss:2.1269 train_time:311360ms step_avg:34.60ms +step:9000 shared0_alpha:mean=0.527,std=0.066 shared1_alpha:mean=0.585,std=0.068 shared2_alpha:mean=0.533,std=0.062 eff_mlp_scale:[v0:6.1770 v1:5.6893 v2:5.3312 v3:4.8488 v4:5.6893 v5:5.3312 v6:4.8488 v7:9.0858] eff_attn_scale:[v0:0.0542 v1:0.1150 v2:0.1290 v3:0.1591 v4:0.1150 v5:0.1290 v6:0.1591 v7:0.2842] 
+step:9000/20000 val_loss:2.2172 val_bpb:1.3132 train_time:311367ms step_avg:34.60ms +step:9200/20000 train_loss:2.1851 train_time:318266ms step_avg:34.59ms +step:9200 shared0_alpha:mean=0.526,std=0.066 shared1_alpha:mean=0.587,std=0.068 shared2_alpha:mean=0.532,std=0.062 eff_mlp_scale:[v0:6.2288 v1:5.7381 v2:5.3828 v3:4.8920 v4:5.7381 v5:5.3828 v6:4.8920 v7:9.1780] eff_attn_scale:[v0:0.0545 v1:0.1162 v2:0.1305 v3:0.1615 v4:0.1162 v5:0.1305 v6:0.1615 v7:0.2855] +step:9200/20000 val_loss:2.2161 val_bpb:1.3125 train_time:318270ms step_avg:34.59ms +step:9400/20000 train_loss:2.2447 train_time:325143ms step_avg:34.59ms +step:9400 shared0_alpha:mean=0.525,std=0.065 shared1_alpha:mean=0.587,std=0.068 shared2_alpha:mean=0.532,std=0.062 eff_mlp_scale:[v0:6.2792 v1:5.7868 v2:5.4383 v3:4.9344 v4:5.7868 v5:5.4383 v6:4.9344 v7:9.2644] eff_attn_scale:[v0:0.0533 v1:0.1146 v2:0.1288 v3:0.1617 v4:0.1146 v5:0.1288 v6:0.1617 v7:0.2861] +step:9400/20000 val_loss:2.2144 val_bpb:1.3115 train_time:325147ms step_avg:34.59ms +step:9600/20000 train_loss:2.2493 train_time:332033ms step_avg:34.59ms +step:9600 shared0_alpha:mean=0.524,std=0.065 shared1_alpha:mean=0.588,std=0.068 shared2_alpha:mean=0.530,std=0.062 eff_mlp_scale:[v0:6.3297 v1:5.8339 v2:5.4897 v3:4.9775 v4:5.8339 v5:5.4897 v6:4.9775 v7:9.3525] eff_attn_scale:[v0:0.0532 v1:0.1153 v2:0.1294 v3:0.1633 v4:0.1153 v5:0.1294 v6:0.1633 v7:0.2861] +step:9600/20000 val_loss:2.2135 val_bpb:1.3110 train_time:332035ms step_avg:34.59ms +step:9800/20000 train_loss:2.1816 train_time:338923ms step_avg:34.58ms +step:9800 shared0_alpha:mean=0.524,std=0.066 shared1_alpha:mean=0.589,std=0.068 shared2_alpha:mean=0.530,std=0.062 eff_mlp_scale:[v0:6.3835 v1:5.8791 v2:5.5466 v3:5.0202 v4:5.8791 v5:5.5466 v6:5.0202 v7:9.4385] eff_attn_scale:[v0:0.0547 v1:0.1163 v2:0.1303 v3:0.1650 v4:0.1163 v5:0.1303 v6:0.1650 v7:0.2883] +step:9800/20000 val_loss:2.2148 val_bpb:1.3117 train_time:338926ms step_avg:34.58ms +step:10000/20000 train_loss:2.2208 
train_time:345810ms step_avg:34.58ms +step:10000 shared0_alpha:mean=0.523,std=0.066 shared1_alpha:mean=0.590,std=0.068 shared2_alpha:mean=0.529,std=0.063 eff_mlp_scale:[v0:6.4385 v1:5.9260 v2:5.6036 v3:5.0616 v4:5.9260 v5:5.6036 v6:5.0616 v7:9.5277] eff_attn_scale:[v0:0.0538 v1:0.1169 v2:0.1316 v3:0.1673 v4:0.1169 v5:0.1316 v6:0.1673 v7:0.2904] +step:10000/20000 val_loss:2.2134 val_bpb:1.3109 train_time:345814ms step_avg:34.58ms +step:10200/20000 train_loss:2.1735 train_time:352702ms step_avg:34.58ms +step:10200 shared0_alpha:mean=0.523,std=0.066 shared1_alpha:mean=0.591,std=0.069 shared2_alpha:mean=0.528,std=0.062 eff_mlp_scale:[v0:6.4867 v1:5.9755 v2:5.6487 v3:5.1053 v4:5.9755 v5:5.6487 v6:5.1053 v7:9.6089] eff_attn_scale:[v0:0.0537 v1:0.1169 v2:0.1308 v3:0.1671 v4:0.1169 v5:0.1308 v6:0.1671 v7:0.2918] +step:10200/20000 val_loss:2.2108 val_bpb:1.3094 train_time:352705ms step_avg:34.58ms +step:10400/20000 train_loss:2.2074 train_time:359588ms step_avg:34.58ms +step:10400 shared0_alpha:mean=0.522,std=0.066 shared1_alpha:mean=0.592,std=0.069 shared2_alpha:mean=0.527,std=0.062 eff_mlp_scale:[v0:6.5364 v1:6.0211 v2:5.7044 v3:5.1454 v4:6.0211 v5:5.7044 v6:5.1454 v7:9.6992] eff_attn_scale:[v0:0.0538 v1:0.1155 v2:0.1310 v3:0.1688 v4:0.1155 v5:0.1310 v6:0.1688 v7:0.2930] +step:10400/20000 val_loss:2.2098 val_bpb:1.3087 train_time:359591ms step_avg:34.58ms +step:10600/20000 train_loss:2.0804 train_time:366486ms step_avg:34.57ms +step:10600 shared0_alpha:mean=0.522,std=0.066 shared1_alpha:mean=0.593,std=0.069 shared2_alpha:mean=0.526,std=0.063 eff_mlp_scale:[v0:6.5870 v1:6.0643 v2:5.7561 v3:5.1875 v4:6.0643 v5:5.7561 v6:5.1875 v7:9.7829] eff_attn_scale:[v0:0.0546 v1:0.1161 v2:0.1311 v3:0.1705 v4:0.1161 v5:0.1311 v6:0.1705 v7:0.2934] +step:10600/20000 val_loss:2.2106 val_bpb:1.3092 train_time:366491ms step_avg:34.57ms +step:10800/20000 train_loss:2.2835 train_time:373399ms step_avg:34.57ms +step:10800 shared0_alpha:mean=0.521,std=0.066 shared1_alpha:mean=0.593,std=0.069 
shared2_alpha:mean=0.525,std=0.063 eff_mlp_scale:[v0:6.6342 v1:6.1112 v2:5.8096 v3:5.2310 v4:6.1112 v5:5.8096 v6:5.2310 v7:9.8698] eff_attn_scale:[v0:0.0544 v1:0.1157 v2:0.1306 v3:0.1725 v4:0.1157 v5:0.1306 v6:0.1725 v7:0.2948] +step:10800/20000 val_loss:2.2081 val_bpb:1.3078 train_time:373404ms step_avg:34.57ms +step:11000/20000 train_loss:2.2171 train_time:380320ms step_avg:34.57ms +step:11000 shared0_alpha:mean=0.520,std=0.066 shared1_alpha:mean=0.594,std=0.069 shared2_alpha:mean=0.525,std=0.063 eff_mlp_scale:[v0:6.6789 v1:6.1560 v2:5.8645 v3:5.2685 v4:6.1560 v5:5.8645 v6:5.2685 v7:9.9543] eff_attn_scale:[v0:0.0543 v1:0.1168 v2:0.1319 v3:0.1735 v4:0.1168 v5:0.1319 v6:0.1735 v7:0.2954] +step:11000/20000 val_loss:2.2070 val_bpb:1.3071 train_time:380321ms step_avg:34.57ms +step:11200/20000 train_loss:2.1733 train_time:387218ms step_avg:34.57ms +step:11200 shared0_alpha:mean=0.520,std=0.066 shared1_alpha:mean=0.595,std=0.070 shared2_alpha:mean=0.523,std=0.063 eff_mlp_scale:[v0:6.7240 v1:6.2017 v2:5.9167 v3:5.3122 v4:6.2017 v5:5.9167 v6:5.3122 v7:10.0384] eff_attn_scale:[v0:0.0545 v1:0.1168 v2:0.1323 v3:0.1750 v4:0.1168 v5:0.1323 v6:0.1750 v7:0.2985] +step:11200/20000 val_loss:2.2067 val_bpb:1.3069 train_time:387220ms step_avg:34.57ms +step:11400/20000 train_loss:2.1591 train_time:394149ms step_avg:34.57ms +step:11400 shared0_alpha:mean=0.520,std=0.066 shared1_alpha:mean=0.596,std=0.070 shared2_alpha:mean=0.523,std=0.063 eff_mlp_scale:[v0:6.7761 v1:6.2496 v2:5.9733 v3:5.3501 v4:6.2496 v5:5.9733 v6:5.3501 v7:10.1188] eff_attn_scale:[v0:0.0545 v1:0.1185 v2:0.1342 v3:0.1771 v4:0.1185 v5:0.1342 v6:0.1771 v7:0.2980] +step:11400/20000 val_loss:2.2071 val_bpb:1.3071 train_time:394154ms step_avg:34.57ms +step:11600/20000 train_loss:2.1683 train_time:401067ms step_avg:34.57ms +step:11600 shared0_alpha:mean=0.519,std=0.066 shared1_alpha:mean=0.597,std=0.070 shared2_alpha:mean=0.522,std=0.063 eff_mlp_scale:[v0:6.8276 v1:6.2938 v2:6.0254 v3:5.3947 v4:6.2938 v5:6.0254 v6:5.3947 
v7:10.2032] eff_attn_scale:[v0:0.0546 v1:0.1172 v2:0.1338 v3:0.1781 v4:0.1172 v5:0.1338 v6:0.1781 v7:0.2995] +step:11600/20000 val_loss:2.2062 val_bpb:1.3067 train_time:401071ms step_avg:34.58ms +step:11800/20000 train_loss:2.1937 train_time:407976ms step_avg:34.57ms +step:11800 shared0_alpha:mean=0.518,std=0.066 shared1_alpha:mean=0.599,std=0.070 shared2_alpha:mean=0.521,std=0.063 eff_mlp_scale:[v0:6.8773 v1:6.3452 v2:6.0765 v3:5.4375 v4:6.3452 v5:6.0765 v6:5.4375 v7:10.2869] eff_attn_scale:[v0:0.0548 v1:0.1178 v2:0.1333 v3:0.1796 v4:0.1178 v5:0.1333 v6:0.1796 v7:0.3009] +step:11800/20000 val_loss:2.2036 val_bpb:1.3051 train_time:407979ms step_avg:34.57ms +step:12000/20000 train_loss:2.1722 train_time:414932ms step_avg:34.58ms +step:12000 shared0_alpha:mean=0.517,std=0.066 shared1_alpha:mean=0.600,std=0.070 shared2_alpha:mean=0.520,std=0.063 eff_mlp_scale:[v0:6.9210 v1:6.3858 v2:6.1318 v3:5.4755 v4:6.3858 v5:6.1318 v6:5.4755 v7:10.3730] eff_attn_scale:[v0:0.0546 v1:0.1175 v2:0.1348 v3:0.1812 v4:0.1175 v5:0.1348 v6:0.1812 v7:0.3018] +step:12000/20000 val_loss:2.2031 val_bpb:1.3048 train_time:414936ms step_avg:34.58ms +step:12200/20000 train_loss:2.3154 train_time:421869ms step_avg:34.58ms +step:12200 shared0_alpha:mean=0.517,std=0.066 shared1_alpha:mean=0.600,std=0.070 shared2_alpha:mean=0.519,std=0.063 eff_mlp_scale:[v0:6.9705 v1:6.4277 v2:6.1829 v3:5.5259 v4:6.4277 v5:6.1829 v6:5.5259 v7:10.4528] eff_attn_scale:[v0:0.0549 v1:0.1181 v2:0.1340 v3:0.1824 v4:0.1181 v5:0.1340 v6:0.1824 v7:0.3032] +step:12200/20000 val_loss:2.2024 val_bpb:1.3044 train_time:421872ms step_avg:34.58ms +step:12400/20000 train_loss:1.9592 train_time:428879ms step_avg:34.59ms +step:12400 shared0_alpha:mean=0.517,std=0.066 shared1_alpha:mean=0.601,std=0.071 shared2_alpha:mean=0.518,std=0.063 eff_mlp_scale:[v0:7.0207 v1:6.4722 v2:6.2376 v3:5.5636 v4:6.4722 v5:6.2376 v6:5.5636 v7:10.5328] eff_attn_scale:[v0:0.0555 v1:0.1191 v2:0.1342 v3:0.1842 v4:0.1191 v5:0.1342 v6:0.1842 v7:0.3019] 
+step:12400/20000 val_loss:2.2025 val_bpb:1.3044 train_time:428886ms step_avg:34.59ms +step:12600/20000 train_loss:2.1922 train_time:435808ms step_avg:34.59ms +step:12600 shared0_alpha:mean=0.516,std=0.066 shared1_alpha:mean=0.602,std=0.071 shared2_alpha:mean=0.518,std=0.063 eff_mlp_scale:[v0:7.0701 v1:6.5189 v2:6.2872 v3:5.6032 v4:6.5189 v5:6.2872 v6:5.6032 v7:10.6174] eff_attn_scale:[v0:0.0552 v1:0.1193 v2:0.1363 v3:0.1864 v4:0.1193 v5:0.1363 v6:0.1864 v7:0.3045] +step:12600/20000 val_loss:2.2039 val_bpb:1.3053 train_time:435815ms step_avg:34.59ms +step:12800/20000 train_loss:2.2109 train_time:442728ms step_avg:34.59ms +step:12800 shared0_alpha:mean=0.515,std=0.067 shared1_alpha:mean=0.604,std=0.071 shared2_alpha:mean=0.517,std=0.064 eff_mlp_scale:[v0:7.1159 v1:6.5671 v2:6.3429 v3:5.6439 v4:6.5671 v5:6.3429 v6:5.6439 v7:10.6963] eff_attn_scale:[v0:0.0549 v1:0.1199 v2:0.1365 v3:0.1871 v4:0.1199 v5:0.1365 v6:0.1871 v7:0.3058] +step:12800/20000 val_loss:2.2016 val_bpb:1.3039 train_time:442735ms step_avg:34.59ms +step:13000/20000 train_loss:2.2977 train_time:449628ms step_avg:34.59ms +step:13000 shared0_alpha:mean=0.514,std=0.066 shared1_alpha:mean=0.604,std=0.071 shared2_alpha:mean=0.516,std=0.063 eff_mlp_scale:[v0:7.1658 v1:6.6100 v2:6.3969 v3:5.6857 v4:6.6100 v5:6.3969 v6:5.6857 v7:10.7749] eff_attn_scale:[v0:0.0552 v1:0.1189 v2:0.1358 v3:0.1890 v4:0.1189 v5:0.1358 v6:0.1890 v7:0.3053] +step:13000/20000 val_loss:2.2029 val_bpb:1.3047 train_time:449634ms step_avg:34.59ms +step:13200/20000 train_loss:2.3070 train_time:456542ms step_avg:34.59ms +step:13200 shared0_alpha:mean=0.514,std=0.067 shared1_alpha:mean=0.606,std=0.071 shared2_alpha:mean=0.516,std=0.063 eff_mlp_scale:[v0:7.2096 v1:6.6534 v2:6.4500 v3:5.7254 v4:6.6534 v5:6.4500 v6:5.7254 v7:10.8539] eff_attn_scale:[v0:0.0549 v1:0.1198 v2:0.1366 v3:0.1897 v4:0.1198 v5:0.1366 v6:0.1897 v7:0.3055] +step:13200/20000 val_loss:2.1985 val_bpb:1.3021 train_time:456546ms step_avg:34.59ms +step:13400/20000 
train_loss:2.1811 train_time:463453ms step_avg:34.59ms +step:13400 shared0_alpha:mean=0.514,std=0.067 shared1_alpha:mean=0.607,std=0.072 shared2_alpha:mean=0.515,std=0.064 eff_mlp_scale:[v0:7.2505 v1:6.6995 v2:6.5038 v3:5.7678 v4:6.6995 v5:6.5038 v6:5.7678 v7:10.9355] eff_attn_scale:[v0:0.0554 v1:0.1194 v2:0.1385 v3:0.1923 v4:0.1194 v5:0.1385 v6:0.1923 v7:0.3074] +step:13400/20000 val_loss:2.1990 val_bpb:1.3024 train_time:463458ms step_avg:34.59ms +step:13600/20000 train_loss:2.0562 train_time:470369ms step_avg:34.59ms +step:13600 shared0_alpha:mean=0.514,std=0.066 shared1_alpha:mean=0.608,std=0.072 shared2_alpha:mean=0.514,std=0.063 eff_mlp_scale:[v0:7.2974 v1:6.7400 v2:6.5570 v3:5.8110 v4:6.7400 v5:6.5570 v6:5.8110 v7:11.0135] eff_attn_scale:[v0:0.0542 v1:0.1218 v2:0.1383 v3:0.1938 v4:0.1218 v5:0.1383 v6:0.1938 v7:0.3097] +step:13600/20000 val_loss:2.2004 val_bpb:1.3032 train_time:470373ms step_avg:34.59ms +step:13800/20000 train_loss:2.1285 train_time:477270ms step_avg:34.58ms +step:13800 shared0_alpha:mean=0.513,std=0.066 shared1_alpha:mean=0.609,std=0.072 shared2_alpha:mean=0.513,std=0.064 eff_mlp_scale:[v0:7.3464 v1:6.7849 v2:6.6091 v3:5.8470 v4:6.7849 v5:6.6091 v6:5.8470 v7:11.0878] eff_attn_scale:[v0:0.0543 v1:0.1221 v2:0.1385 v3:0.1949 v4:0.1221 v5:0.1385 v6:0.1949 v7:0.3092] +step:13800/20000 val_loss:2.1969 val_bpb:1.3012 train_time:477272ms step_avg:34.58ms +step:14000/20000 train_loss:2.1969 train_time:484182ms step_avg:34.58ms +step:14000 shared0_alpha:mean=0.513,std=0.067 shared1_alpha:mean=0.610,std=0.072 shared2_alpha:mean=0.512,std=0.064 eff_mlp_scale:[v0:7.3934 v1:6.8282 v2:6.6631 v3:5.8917 v4:6.8282 v5:6.6631 v6:5.8917 v7:11.1619] eff_attn_scale:[v0:0.0555 v1:0.1216 v2:0.1397 v3:0.1973 v4:0.1216 v5:0.1397 v6:0.1973 v7:0.3084] +step:14000/20000 val_loss:2.1975 val_bpb:1.3015 train_time:484185ms step_avg:34.58ms +step:14200/20000 train_loss:2.2875 train_time:491072ms step_avg:34.58ms +step:14200 shared0_alpha:mean=0.513,std=0.067 
shared1_alpha:mean=0.611,std=0.072 shared2_alpha:mean=0.511,std=0.063 eff_mlp_scale:[v0:7.4448 v1:6.8683 v2:6.7203 v3:5.9345 v4:6.8683 v5:6.7203 v6:5.9345 v7:11.2435] eff_attn_scale:[v0:0.0553 v1:0.1225 v2:0.1393 v3:0.1990 v4:0.1225 v5:0.1393 v6:0.1990 v7:0.3101] +step:14200/20000 val_loss:2.1976 val_bpb:1.3015 train_time:491076ms step_avg:34.58ms +step:14400/20000 train_loss:2.1775 train_time:497972ms step_avg:34.58ms +step:14400 shared0_alpha:mean=0.512,std=0.067 shared1_alpha:mean=0.612,std=0.072 shared2_alpha:mean=0.511,std=0.063 eff_mlp_scale:[v0:7.4888 v1:6.9123 v2:6.7717 v3:5.9716 v4:6.9123 v5:6.7717 v6:5.9716 v7:11.3167] eff_attn_scale:[v0:0.0552 v1:0.1227 v2:0.1393 v3:0.1989 v4:0.1227 v5:0.1393 v6:0.1989 v7:0.3105] +step:14400/20000 val_loss:2.1943 val_bpb:1.2996 train_time:497974ms step_avg:34.58ms +step:14600/20000 train_loss:2.2366 train_time:504920ms step_avg:34.58ms +step:14600 shared0_alpha:mean=0.511,std=0.067 shared1_alpha:mean=0.613,std=0.073 shared2_alpha:mean=0.510,std=0.064 eff_mlp_scale:[v0:7.5359 v1:6.9568 v2:6.8264 v3:6.0118 v4:6.9568 v5:6.8264 v6:6.0118 v7:11.3905] eff_attn_scale:[v0:0.0551 v1:0.1218 v2:0.1395 v3:0.2004 v4:0.1218 v5:0.1395 v6:0.2004 v7:0.3106] +step:14600/20000 val_loss:2.1942 val_bpb:1.2995 train_time:504927ms step_avg:34.58ms +step:14800/20000 train_loss:2.0248 train_time:511848ms step_avg:34.58ms +step:14800 shared0_alpha:mean=0.511,std=0.066 shared1_alpha:mean=0.614,std=0.073 shared2_alpha:mean=0.510,std=0.064 eff_mlp_scale:[v0:7.5770 v1:7.0013 v2:6.8793 v3:6.0525 v4:7.0013 v5:6.8793 v6:6.0525 v7:11.4662] eff_attn_scale:[v0:0.0548 v1:0.1216 v2:0.1396 v3:0.2011 v4:0.1216 v5:0.1396 v6:0.2011 v7:0.3101] +step:14800/20000 val_loss:2.1949 val_bpb:1.3000 train_time:511856ms step_avg:34.58ms +step:15000/20000 train_loss:2.1352 train_time:518760ms step_avg:34.58ms +step:15000 shared0_alpha:mean=0.510,std=0.067 shared1_alpha:mean=0.615,std=0.073 shared2_alpha:mean=0.508,std=0.064 eff_mlp_scale:[v0:7.6210 v1:7.0444 v2:6.9292 
v3:6.0954 v4:7.0444 v5:6.9292 v6:6.0954 v7:11.5410] eff_attn_scale:[v0:0.0547 v1:0.1213 v2:0.1408 v3:0.2016 v4:0.1213 v5:0.1408 v6:0.2016 v7:0.3121] +step:15000/20000 val_loss:2.1928 val_bpb:1.2987 train_time:518763ms step_avg:34.58ms +step:15200/20000 train_loss:2.2552 train_time:525664ms step_avg:34.58ms +step:15200 shared0_alpha:mean=0.510,std=0.066 shared1_alpha:mean=0.617,std=0.073 shared2_alpha:mean=0.508,std=0.064 eff_mlp_scale:[v0:7.6759 v1:7.0868 v2:6.9870 v3:6.1337 v4:7.0868 v5:6.9870 v6:6.1337 v7:11.6174] eff_attn_scale:[v0:0.0554 v1:0.1210 v2:0.1403 v3:0.2033 v4:0.1210 v5:0.1403 v6:0.2033 v7:0.3150] +step:15200/20000 val_loss:2.1935 val_bpb:1.2991 train_time:525666ms step_avg:34.58ms +step:15400/20000 train_loss:2.1518 train_time:532624ms step_avg:34.59ms +step:15400 shared0_alpha:mean=0.509,std=0.067 shared1_alpha:mean=0.618,std=0.073 shared2_alpha:mean=0.507,std=0.064 eff_mlp_scale:[v0:7.7175 v1:7.1304 v2:7.0385 v3:6.1778 v4:7.1304 v5:7.0385 v6:6.1778 v7:11.6876] eff_attn_scale:[v0:0.0550 v1:0.1235 v2:0.1415 v3:0.2044 v4:0.1235 v5:0.1415 v6:0.2044 v7:0.3151] +step:15400/20000 val_loss:2.1929 val_bpb:1.2987 train_time:532630ms step_avg:34.59ms +step:15600/20000 train_loss:2.1737 train_time:539537ms step_avg:34.59ms +step:15600 shared0_alpha:mean=0.509,std=0.066 shared1_alpha:mean=0.619,std=0.074 shared2_alpha:mean=0.506,std=0.064 eff_mlp_scale:[v0:7.7592 v1:7.1700 v2:7.0920 v3:6.2179 v4:7.1700 v5:7.0920 v6:6.2179 v7:11.7613] eff_attn_scale:[v0:0.0550 v1:0.1227 v2:0.1403 v3:0.2064 v4:0.1227 v5:0.1403 v6:0.2064 v7:0.3138] +step:15600/20000 val_loss:2.1909 val_bpb:1.2976 train_time:539539ms step_avg:34.59ms +step:15800/20000 train_loss:2.0250 train_time:546489ms step_avg:34.59ms +step:15800 shared0_alpha:mean=0.508,std=0.067 shared1_alpha:mean=0.619,std=0.074 shared2_alpha:mean=0.506,std=0.063 eff_mlp_scale:[v0:7.8031 v1:7.2150 v2:7.1462 v3:6.2596 v4:7.2150 v5:7.1462 v6:6.2596 v7:11.8317] eff_attn_scale:[v0:0.0555 v1:0.1242 v2:0.1417 v3:0.2074 v4:0.1242 
v5:0.1417 v6:0.2074 v7:0.3188] +step:15800/20000 val_loss:2.1942 val_bpb:1.2995 train_time:546492ms step_avg:34.59ms +step:16000/20000 train_loss:2.2350 train_time:553401ms step_avg:34.59ms +step:16000 shared0_alpha:mean=0.508,std=0.067 shared1_alpha:mean=0.621,std=0.074 shared2_alpha:mean=0.505,std=0.064 eff_mlp_scale:[v0:7.8523 v1:7.2552 v2:7.1978 v3:6.3010 v4:7.2552 v5:7.1978 v6:6.3010 v7:11.8975] eff_attn_scale:[v0:0.0552 v1:0.1227 v2:0.1416 v3:0.2074 v4:0.1227 v5:0.1416 v6:0.2074 v7:0.3202] +step:16000/20000 val_loss:2.1896 val_bpb:1.2968 train_time:553403ms step_avg:34.59ms +step:16200/20000 train_loss:2.1077 train_time:560332ms step_avg:34.59ms +step:16200 shared0_alpha:mean=0.507,std=0.066 shared1_alpha:mean=0.622,std=0.074 shared2_alpha:mean=0.503,std=0.064 eff_mlp_scale:[v0:7.8979 v1:7.2992 v2:7.2493 v3:6.3390 v4:7.2992 v5:7.2493 v6:6.3390 v7:11.9682] eff_attn_scale:[v0:0.0555 v1:0.1244 v2:0.1426 v3:0.2098 v4:0.1244 v5:0.1426 v6:0.2098 v7:0.3235] +step:16200/20000 val_loss:2.1927 val_bpb:1.2986 train_time:560335ms step_avg:34.59ms +step:16400/20000 train_loss:2.1231 train_time:567290ms step_avg:34.59ms +step:16400 shared0_alpha:mean=0.506,std=0.067 shared1_alpha:mean=0.622,std=0.075 shared2_alpha:mean=0.503,std=0.064 eff_mlp_scale:[v0:7.9329 v1:7.3337 v2:7.2931 v3:6.3724 v4:7.3337 v5:7.2931 v6:6.3724 v7:12.0308] eff_attn_scale:[v0:0.0558 v1:0.1249 v2:0.1449 v3:0.2134 v4:0.1249 v5:0.1449 v6:0.2134 v7:0.3265] +step:16400/20000 val_loss:2.1865 val_bpb:1.2950 train_time:567293ms step_avg:34.59ms +step:16600/20000 train_loss:2.0605 train_time:574260ms step_avg:34.59ms +step:16600 shared0_alpha:mean=0.505,std=0.067 shared1_alpha:mean=0.622,std=0.074 shared2_alpha:mean=0.502,std=0.064 eff_mlp_scale:[v0:7.9586 v1:7.3620 v2:7.3311 v3:6.4033 v4:7.3620 v5:7.3311 v6:6.4033 v7:12.0931] eff_attn_scale:[v0:0.0544 v1:0.1273 v2:0.1462 v3:0.2184 v4:0.1273 v5:0.1462 v6:0.2184 v7:0.3363] +step:16600/20000 val_loss:2.1731 val_bpb:1.2871 train_time:574268ms step_avg:34.59ms 
+step:16800/20000 train_loss:2.2734 train_time:581176ms step_avg:34.59ms +step:16800 shared0_alpha:mean=0.505,std=0.067 shared1_alpha:mean=0.622,std=0.075 shared2_alpha:mean=0.501,std=0.064 eff_mlp_scale:[v0:7.9809 v1:7.3799 v2:7.3543 v3:6.4164 v4:7.3799 v5:7.3543 v6:6.4164 v7:12.1448] eff_attn_scale:[v0:0.0546 v1:0.1281 v2:0.1482 v3:0.2230 v4:0.1281 v5:0.1482 v6:0.2230 v7:0.3440] +step:16800/20000 val_loss:2.1663 val_bpb:1.2830 train_time:581183ms step_avg:34.59ms +step:17000/20000 train_loss:2.1810 train_time:588080ms step_avg:34.59ms +step:17000 shared0_alpha:mean=0.505,std=0.066 shared1_alpha:mean=0.621,std=0.074 shared2_alpha:mean=0.501,std=0.063 eff_mlp_scale:[v0:7.9922 v1:7.3949 v2:7.3709 v3:6.4275 v4:7.3949 v5:7.3709 v6:6.4275 v7:12.1843] eff_attn_scale:[v0:0.0541 v1:0.1289 v2:0.1493 v3:0.2263 v4:0.1289 v5:0.1493 v6:0.2263 v7:0.3473] +step:17000/20000 val_loss:2.1584 val_bpb:1.2783 train_time:588086ms step_avg:34.59ms +step:17200/20000 train_loss:2.1683 train_time:595009ms step_avg:34.59ms +step:17200 shared0_alpha:mean=0.504,std=0.066 shared1_alpha:mean=0.621,std=0.074 shared2_alpha:mean=0.500,std=0.063 eff_mlp_scale:[v0:7.9966 v1:7.4033 v2:7.3783 v3:6.4359 v4:7.4033 v5:7.3783 v6:6.4359 v7:12.2067] eff_attn_scale:[v0:0.0536 v1:0.1284 v2:0.1490 v3:0.2281 v4:0.1284 v5:0.1490 v6:0.2281 v7:0.3491] +step:17200/20000 val_loss:2.1513 val_bpb:1.2741 train_time:595014ms step_avg:34.59ms +step:17346/20000 val_loss:2.1469 val_bpb:1.2715 train_time:600005ms step_avg:34.59ms +stopping_early: wallclock_cap train_time:600005ms step:17346/20000 +peak memory allocated: 7969 MiB reserved: 8166 MiB +Serialized model: 37799662 bytes +Code size: 57024 bytes +Total submission size: 37856686 bytes +Serialized model int8+zlib: 8999912 bytes (payload:9762976 raw_torch:9788397 payload_ratio:3.87x) +Total submission size int8+zlib: 9056936 bytes +final_int8_zlib_roundtrip val_loss:2.1707 val_bpb:1.2856 eval_time:1087ms +final_int8_zlib_roundtrip_exact val_loss:2.17073266 
val_bpb:1.28563015 diff --git a/records/track_non_record_16mb/2026-03-26_RecurrenceFix_3Loop_Birkhoff_OutputLN_TimestepScale/logs/s2_H.txt b/records/track_non_record_16mb/2026-03-26_RecurrenceFix_3Loop_Birkhoff_OutputLN_TimestepScale/logs/s2_H.txt new file mode 100644 index 0000000000..dd19fe1ef2 --- /dev/null +++ b/records/track_non_record_16mb/2026-03-26_RecurrenceFix_3Loop_Birkhoff_OutputLN_TimestepScale/logs/s2_H.txt @@ -0,0 +1,1643 @@ +""" +The `train_gpt.py` and `train_gpt_mlx.py` scripts are intended as good launching-off points for new participants, not SOTA configs. We'll accept PRs that tune, improve, or simplify these scripts without significantly increasing complexity, but competitive submissions should stay in the `/records` folder. + +Hard stop: To keep readable for newcomers, let's make sure `train_gpt.py` and `train_gpt_mlx.py` never are longer than 1500 lines. +""" + +from __future__ import annotations + +import copy +import glob +import io +import math +import os +import random +import subprocess +import sys +import time +import uuid +import zlib +from pathlib import Path + +import numpy as np +import sentencepiece as spm +import torch +import torch.distributed as dist +import torch.nn.functional as F +from torch import Tensor, nn +from torch.nn.parallel import DistributedDataParallel as DDP + +# Flash Attention 3 via HuggingFace kernels (Hopper-optimized) +try: + from kernels import get_kernel + cap = torch.cuda.get_device_capability() + repo = "varunneal/flash-attention-3" if cap == (9, 0) else "kernels-community/flash-attn3" + fa3 = get_kernel(repo).flash_attn_interface + FA3_AVAILABLE = True +except Exception: + FA3_AVAILABLE = False + +# ----------------------------- +# HYPERPARAMETERS +# ----------------------------- +# Default Simple Baseline run: +# - 9 transformer blocks at width 512 +# - 8 attention heads with 4 KV heads (GQA) and 2x MLP expansion +# - vocab size 1024, sequence length 1024, tied embeddings +# - 524,288 train tokens per 
step for 20,000 iterations with a ~10 minute cap + +class Hyperparameters: + # Data paths are shard globs produced by the existing preprocessing pipeline. + data_path = os.environ.get("DATA_PATH", "./data/datasets/fineweb10B_sp1024") + train_files = os.path.join(data_path, "fineweb_train_*.bin") + val_files = os.path.join(data_path, "fineweb_val_*.bin") + tokenizer_path = os.environ.get("TOKENIZER_PATH", "./data/tokenizers/fineweb_1024_bpe.model") + run_id = os.environ.get("RUN_ID", str(uuid.uuid4())) + seed = int(os.environ.get("SEED", 1337)) + + # Validation cadence and batch size. Validation always uses the full fineweb_val split. + val_batch_size = int(os.environ.get("VAL_BATCH_SIZE", 524_288)) + val_loss_every = int(os.environ.get("VAL_LOSS_EVERY", 1000)) + train_log_every = int(os.environ.get("TRAIN_LOG_EVERY", 200)) + + # Training length. + iterations = int(os.environ.get("ITERATIONS", 20000)) + warmdown_iters = int(os.environ.get("WARMDOWN_ITERS", 1200)) + warmup_steps = int(os.environ.get("WARMUP_STEPS", 20)) + train_batch_tokens = int(os.environ.get("TRAIN_BATCH_TOKENS", 524_288)) + train_seq_len = int(os.environ.get("TRAIN_SEQ_LEN", 1024)) + max_wallclock_seconds = float(os.environ.get("MAX_WALLCLOCK_SECONDS", 600.0)) + qk_gain_init = float(os.environ.get("QK_GAIN_INIT", 1.5)) + + # Model shape. + vocab_size = int(os.environ.get("VOCAB_SIZE", 1024)) + num_layers = int(os.environ.get("NUM_LAYERS", 9)) + num_kv_heads = int(os.environ.get("NUM_KV_HEADS", 4)) + model_dim = int(os.environ.get("MODEL_DIM", 512)) + num_heads = int(os.environ.get("NUM_HEADS", 8)) + mlp_mult = int(os.environ.get("MLP_MULT", 2)) + tie_embeddings = bool(int(os.environ.get("TIE_EMBEDDINGS", "1"))) + rope_base = float(os.environ.get("ROPE_BASE", 10000.0)) + logit_softcap = float(os.environ.get("LOGIT_SOFTCAP", 30.0)) + + # Recurrence + stability features. 
+ num_prelude = int(os.environ.get("NUM_PRELUDE", 0)) + num_coda = int(os.environ.get("NUM_CODA", 0)) + num_shared = int(os.environ.get("NUM_SHARED", 0)) + num_loops = int(os.environ.get("NUM_LOOPS", 2)) + use_peri_norm = bool(int(os.environ.get("USE_PERI_NORM", "0"))) + use_birkhoff_mix = bool(int(os.environ.get("USE_BIRKHOFF_MIX", "0"))) + use_timestep_scale = bool(int(os.environ.get("USE_TIMESTEP_SCALE", "0"))) + leaky_relu_slope = float(os.environ.get("LEAKY_RELU_SLOPE", "0.5")) + timestep_gamma_max = float(os.environ.get("TIMESTEP_GAMMA_MAX", "0.0")) + disable_compile = bool(int(os.environ.get("DISABLE_COMPILE", "0"))) + + # Optimizer hyperparameters. + embed_lr = float(os.environ.get("EMBED_LR", 0.6)) + head_lr = float(os.environ.get("HEAD_LR", 0.008)) + tied_embed_lr = float(os.environ.get("TIED_EMBED_LR", 0.05)) + tied_embed_init_std = float(os.environ.get("TIED_EMBED_INIT_STD", 0.005)) + matrix_lr = float(os.environ.get("MATRIX_LR", 0.04)) + scalar_lr = float(os.environ.get("SCALAR_LR", 0.04)) + muon_momentum = float(os.environ.get("MUON_MOMENTUM", 0.95)) + muon_backend_steps = int(os.environ.get("MUON_BACKEND_STEPS", 5)) + muon_momentum_warmup_start = float(os.environ.get("MUON_MOMENTUM_WARMUP_START", 0.85)) + muon_momentum_warmup_steps = int(os.environ.get("MUON_MOMENTUM_WARMUP_STEPS", 500)) + beta1 = float(os.environ.get("BETA1", 0.9)) + beta2 = float(os.environ.get("BETA2", 0.95)) + adam_eps = float(os.environ.get("ADAM_EPS", 1e-8)) + grad_clip_norm = float(os.environ.get("GRAD_CLIP_NORM", 0.0)) + +# ----------------------------- +# MUON OPTIMIZER +# ----------------------------- +# +# As borrowed from modded-nanogpt +# Background on Muon: https://kellerjordan.github.io/posts/muon/ + +def zeropower_via_newtonschulz5(G: Tensor, steps: int = 10, eps: float = 1e-7) -> Tensor: + # Orthogonalize a 2D update matrix with a fast Newton-Schulz iteration. + # Muon uses this to normalize matrix-shaped gradients before applying them. 
+ a, b, c = (3.4445, -4.7750, 2.0315) + X = G.bfloat16() + X /= X.norm() + eps + transposed = G.size(0) > G.size(1) + if transposed: + X = X.T + for _ in range(steps): + A = X @ X.T + B = b * A + c * A @ A + X = a * X + B @ X + return X.T if transposed else X + + +class Muon(torch.optim.Optimizer): + def __init__(self, params, lr: float, momentum: float, backend_steps: int, nesterov: bool = True): + super().__init__( + params, + dict(lr=lr, momentum=momentum, backend_steps=backend_steps, nesterov=nesterov), + ) + + @torch.no_grad() + def step(self, closure=None): + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + distributed = dist.is_available() and dist.is_initialized() + world_size = dist.get_world_size() if distributed else 1 + rank = dist.get_rank() if distributed else 0 + + for group in self.param_groups: + params = group["params"] + if not params: + continue + lr = group["lr"] + momentum = group["momentum"] + backend_steps = group["backend_steps"] + nesterov = group["nesterov"] + + total_params = sum(int(p.numel()) for p in params) + updates_flat = torch.zeros(total_params, device=params[0].device, dtype=torch.bfloat16) + + curr = 0 + for i, p in enumerate(params): + if i % world_size == rank and p.grad is not None: + g = p.grad + state = self.state[p] + if "momentum_buffer" not in state: + state["momentum_buffer"] = torch.zeros_like(g) + buf = state["momentum_buffer"] + buf.mul_(momentum).add_(g) + if nesterov: + g = g.add(buf, alpha=momentum) + g = zeropower_via_newtonschulz5(g, steps=backend_steps) + # Scale correction from Muon reference implementations. 
+ g *= max(1, g.size(0) / g.size(1)) ** 0.5 + updates_flat[curr : curr + p.numel()] = g.reshape(-1) + curr += p.numel() + + if distributed: + dist.all_reduce(updates_flat, op=dist.ReduceOp.SUM) + + curr = 0 + for p in params: + g = updates_flat[curr : curr + p.numel()].view_as(p).to(dtype=p.dtype) + p.add_(g, alpha=-lr) + curr += p.numel() + + return loss + + +# ----------------------------- +# TOKENIZER-AGNOSTIC EVALUATION SETUP +# ----------------------------- +# +# It's common for small models have a large fraction of their parameters be embeddings, since the 2 * d_model * d_vocab vectors can be gigantic. +# Instead of locking the tokenizer, we let you bring your own and calculate our validation metrics on the average compression of the validation set. +# We calculate BPB (bits-per-byte) instead of validation loss, so we need methods to count the number of bits per token in the tokenizer. +# Note: Submissions that edit the tokenizer will be examined more carefully, since screwing this up might unjustly improve your score. 

def build_sentencepiece_luts(
    sp: spm.SentencePieceProcessor, vocab_size: int, device: torch.device
) -> tuple[Tensor, Tensor, Tensor]:
    # Build three per-token-id lookup tables used by the BPB metric:
    #  - base_bytes: UTF-8 byte length of the piece (without the "▁" marker)
    #  - has_leading_space: piece started with the SentencePiece "▁" marker
    #  - is_boundary_token: control/unknown/unused ids (no byte contribution)
    sp_vocab_size = int(sp.vocab_size())
    table_size = max(sp_vocab_size, vocab_size)
    base_bytes_np = np.zeros((table_size,), dtype=np.int16)
    has_leading_space_np = np.zeros((table_size,), dtype=np.bool_)
    is_boundary_token_np = np.ones((table_size,), dtype=np.bool_)
    for token_id in range(sp_vocab_size):
        if sp.is_control(token_id) or sp.is_unknown(token_id) or sp.is_unused(token_id):
            continue
        is_boundary_token_np[token_id] = False
        if sp.is_byte(token_id):
            base_bytes_np[token_id] = 1
            continue
        piece = sp.id_to_piece(token_id)
        if piece.startswith("▁"):
            has_leading_space_np[token_id] = True
            piece = piece[1:]
        base_bytes_np[token_id] = len(piece.encode("utf-8"))
    return (
        torch.tensor(base_bytes_np, dtype=torch.int16, device=device),
        torch.tensor(has_leading_space_np, dtype=torch.bool, device=device),
        torch.tensor(is_boundary_token_np, dtype=torch.bool, device=device),
    )


def load_validation_tokens(pattern: str, seq_len: int) -> Tensor:
    # Concatenate every validation shard and trim so the stream splits into
    # whole seq_len sequences, keeping one extra token for the (x, y) shift.
    files = [Path(p) for p in sorted(glob.glob(pattern))]
    if not files:
        raise FileNotFoundError(f"No files found for pattern: {pattern}")
    # The export pipeline writes the fixed first-50k-doc validation set to fineweb_val_*.
    tokens = torch.cat([load_data_shard(file) for file in files]).contiguous()
    usable = ((tokens.numel() - 1) // seq_len) * seq_len
    if usable <= 0:
        raise ValueError(f"Validation split is too short for TRAIN_SEQ_LEN={seq_len}")
    return tokens[: usable + 1]


def eval_val(
    args: Hyperparameters,
    model: nn.Module,
    rank: int,
    world_size: int,
    device: torch.device,
    grad_accum_steps: int,
    val_tokens: Tensor,
    base_bytes_lut: Tensor,
    has_leading_space_lut: Tensor,
    is_boundary_token_lut: Tensor,
) -> tuple[float, float]:
    # Validation computes two metrics:
    # - val_loss: token cross-entropy (natural log)
    # - val_bpb: tokenizer-agnostic compression metric used by the challenge
    local_batch_tokens = args.val_batch_size // (world_size * grad_accum_steps)
    if local_batch_tokens < args.train_seq_len:
        raise ValueError(
            "VAL_BATCH_SIZE must provide at least one sequence per rank; "
            f"got VAL_BATCH_SIZE={args.val_batch_size}, WORLD_SIZE={world_size}, "
            f"GRAD_ACCUM_STEPS={grad_accum_steps}, TRAIN_SEQ_LEN={args.train_seq_len}"
        )
    local_batch_seqs = local_batch_tokens // args.train_seq_len
    # Each rank evaluates a disjoint contiguous span of the sequence index range.
    total_seqs = (val_tokens.numel() - 1) // args.train_seq_len
    seq_start = (total_seqs * rank) // world_size
    seq_end = (total_seqs * (rank + 1)) // world_size
    val_loss_sum = torch.zeros((), device=device, dtype=torch.float64)
    val_token_count = torch.zeros((), device=device, dtype=torch.float64)
    val_byte_count = torch.zeros((), device=device, dtype=torch.float64)

    model.eval()
    with torch.inference_mode():
        for batch_seq_start in range(seq_start, seq_end, local_batch_seqs):
            batch_seq_end = min(batch_seq_start + local_batch_seqs, seq_end)
            raw_start = batch_seq_start * args.train_seq_len
            raw_end = batch_seq_end * args.train_seq_len + 1  # +1 token for the shift
            local = val_tokens[raw_start:raw_end].to(device=device, dtype=torch.int64, non_blocking=True)
            x = local[:-1].reshape(-1, args.train_seq_len)
            y = local[1:].reshape(-1, args.train_seq_len)
            with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True):
                batch_loss = model(x, y).detach()
            batch_token_count = float(y.numel())
            val_loss_sum += batch_loss.to(torch.float64) * batch_token_count
            val_token_count += batch_token_count
            # Byte accounting: a token contributes its base UTF-8 bytes, plus one
            # byte for its leading space unless the previous id was a boundary token.
            prev_ids = x.reshape(-1)
            tgt_ids = y.reshape(-1)
            token_bytes = base_bytes_lut[tgt_ids].to(dtype=torch.int16)
            token_bytes += (has_leading_space_lut[tgt_ids] & ~is_boundary_token_lut[prev_ids]).to(dtype=torch.int16)
            val_byte_count += token_bytes.to(torch.float64).sum()

    if dist.is_available() and dist.is_initialized():
        dist.all_reduce(val_loss_sum, op=dist.ReduceOp.SUM)
        dist.all_reduce(val_token_count, op=dist.ReduceOp.SUM)
        dist.all_reduce(val_byte_count, op=dist.ReduceOp.SUM)

    val_loss = val_loss_sum / val_token_count
    bits_per_token = val_loss.item() / math.log(2.0)
    tokens_per_byte = val_token_count.item() / val_byte_count.item()
    model.train()
    # bits/token * tokens/byte = bits/byte (BPB)
    return float(val_loss.item()), float(bits_per_token * tokens_per_byte)

# -----------------------------
# POST-TRAINING QUANTIZATION
# -----------------------------
#
# It's silly to export our model, which is trained in bf16 and fp32, at that same precision.
# Instead, we get approximately the same model (with a small hit) by quantizing the model to int8 & zlib compressing.
# We can then decompress the model and run in higher precision for evaluation, after closing in under the size limit.
+ +CONTROL_TENSOR_NAME_PATTERNS = tuple( + pattern + for pattern in os.environ.get( + "CONTROL_TENSOR_NAME_PATTERNS", + "attn_scale,attn_scales,mlp_scale,mlp_scales,resid_mix,resid_mixes,q_gain,skip_weight,skip_weights,ts_attn,ts_mlp,resid_mix_logit", + ).split(",") + if pattern +) +INT8_KEEP_FLOAT_FP32_NAME_PATTERNS = tuple( + pattern + for pattern in os.environ.get( + "INT8_KEEP_FLOAT_FP32_NAME_PATTERNS", + ",".join(CONTROL_TENSOR_NAME_PATTERNS), + ).split(",") + if pattern +) +INT8_KEEP_FLOAT_MAX_NUMEL = 65_536 +INT8_KEEP_FLOAT_STORE_DTYPE = torch.float16 +INT8_PER_ROW_SCALE_DTYPE = torch.float16 +INT8_CLIP_PERCENTILE = 99.99984 +INT8_CLIP_Q = INT8_CLIP_PERCENTILE / 100.0 + +def tensor_nbytes(t: Tensor) -> int: + return int(t.numel()) * int(t.element_size()) + +def keep_float_tensor(name: str, t: Tensor, passthrough_orig_dtypes: dict[str, str]) -> Tensor: + if any(pattern in name for pattern in INT8_KEEP_FLOAT_FP32_NAME_PATTERNS): + return t.float().contiguous() + if t.dtype in {torch.float32, torch.bfloat16}: + passthrough_orig_dtypes[name] = str(t.dtype).removeprefix("torch.") + return t.to(dtype=INT8_KEEP_FLOAT_STORE_DTYPE).contiguous() + return t + +def quantize_float_tensor(t: Tensor) -> tuple[Tensor, Tensor]: + t32 = t.float() + if t32.ndim == 2: + # Matrices get one scale per row, which usually tracks output-channel + # ranges much better than a single tensor-wide scale. + clip_abs = ( + torch.quantile(t32.abs(), INT8_CLIP_Q, dim=1) + if t32.numel() + else torch.empty((t32.shape[0],), dtype=torch.float32) + ) + clipped = torch.maximum(torch.minimum(t32, clip_abs[:, None]), -clip_abs[:, None]) + scale = (clip_abs / 127.0).clamp_min(1.0 / 127.0) + q = torch.clamp(torch.round(clipped / scale[:, None]), -127, 127).to(torch.int8).contiguous() + return q, scale.to(dtype=INT8_PER_ROW_SCALE_DTYPE).contiguous() + + # Vectors / scalars use a simpler per-tensor scale. 
+ clip_abs = float(torch.quantile(t32.abs().flatten(), INT8_CLIP_Q).item()) if t32.numel() else 0.0 + scale = torch.tensor(clip_abs / 127.0 if clip_abs > 0 else 1.0, dtype=torch.float32) + q = torch.clamp(torch.round(torch.clamp(t32, -clip_abs, clip_abs) / scale), -127, 127).to(torch.int8).contiguous() + return q, scale + +def quantize_state_dict_int8(state_dict: dict[str, Tensor]): + # Single supported clean-script export format: + # - per-row int8 for 2D float tensors + # - per-tensor int8 for other float tensors + # - exact passthrough for non-floats + # - passthrough for small float tensors, stored as fp16 to save bytes + quantized: dict[str, Tensor] = {} + scales: dict[str, Tensor] = {} + dtypes: dict[str, str] = {} + passthrough: dict[str, Tensor] = {} + passthrough_orig_dtypes: dict[str, str] = {} + qmeta: dict[str, dict[str, object]] = {} + stats = dict.fromkeys( + ("param_count", "num_tensors", "num_float_tensors", "num_nonfloat_tensors", "baseline_tensor_bytes", "int8_payload_bytes"), + 0, + ) + + for name, tensor in state_dict.items(): + t = tensor.detach().to("cpu").contiguous() + stats["param_count"] += int(t.numel()) + stats["num_tensors"] += 1 + stats["baseline_tensor_bytes"] += tensor_nbytes(t) + + if not t.is_floating_point(): + stats["num_nonfloat_tensors"] += 1 + passthrough[name] = t + stats["int8_payload_bytes"] += tensor_nbytes(t) + continue + + # Small float tensors are cheap enough to keep directly. We still downcast + # fp32/bf16 passthrough tensors to fp16 so metadata does not dominate size. 
+ if t.numel() <= INT8_KEEP_FLOAT_MAX_NUMEL: + kept = keep_float_tensor(name, t, passthrough_orig_dtypes) + passthrough[name] = kept + stats["int8_payload_bytes"] += tensor_nbytes(kept) + continue + + stats["num_float_tensors"] += 1 + q, s = quantize_float_tensor(t) + if s.ndim > 0: + qmeta[name] = {"scheme": "per_row", "axis": 0} + quantized[name] = q + scales[name] = s + dtypes[name] = str(t.dtype).removeprefix("torch.") + stats["int8_payload_bytes"] += tensor_nbytes(q) + tensor_nbytes(s) + + obj: dict[str, object] = { + "__quant_format__": "int8_clean_per_row_v1", + "quantized": quantized, + "scales": scales, + "dtypes": dtypes, + "passthrough": passthrough, + } + if qmeta: + obj["qmeta"] = qmeta + if passthrough_orig_dtypes: + obj["passthrough_orig_dtypes"] = passthrough_orig_dtypes + return obj, stats + +def dequantize_state_dict_int8(obj: dict[str, object]) -> dict[str, Tensor]: + out: dict[str, Tensor] = {} + qmeta = obj.get("qmeta", {}) + passthrough_orig_dtypes = obj.get("passthrough_orig_dtypes", {}) + for name, q in obj["quantized"].items(): + dtype = getattr(torch, obj["dtypes"][name]) + s = obj["scales"][name] + if qmeta.get(name, {}).get("scheme") == "per_row" or s.ndim > 0: + s = s.to(dtype=torch.float32) + # Broadcast the saved row scale back across trailing dimensions. + out[name] = (q.float() * s.view(q.shape[0], *([1] * (q.ndim - 1)))).to(dtype=dtype).contiguous() + else: + scale = float(s.item()) + out[name] = (q.float() * scale).to(dtype=dtype).contiguous() + for name, t in obj["passthrough"].items(): + # Restore small tensors, undoing the temporary fp16 storage cast if needed. 
+ out_t = t.detach().to("cpu").contiguous() + orig_dtype = passthrough_orig_dtypes.get(name) + if isinstance(orig_dtype, str): + out_t = out_t.to(dtype=getattr(torch, orig_dtype)).contiguous() + out[name] = out_t + return out + + +# ----------------------------- +# DATA LOADING +# ----------------------------- + +def load_data_shard(file: Path) -> Tensor: + header_bytes = 256 * np.dtype(" None: + self.file_idx = (self.file_idx + 1) % len(self.files) + self.tokens = load_data_shard(self.files[self.file_idx]) + self.pos = 0 + + def take(self, n: int) -> Tensor: + chunks: list[Tensor] = [] + remaining = n + while remaining > 0: + avail = self.tokens.numel() - self.pos + if avail <= 0: + self._advance_file() + continue + k = min(remaining, avail) + chunks.append(self.tokens[self.pos : self.pos + k]) + self.pos += k + remaining -= k + return chunks[0] if len(chunks) == 1 else torch.cat(chunks) + + +class DistributedTokenLoader: + # Each call consumes a contiguous chunk from the shared token stream, then slices out + # one disjoint span per rank. The extra "+1" token lets us build (x, y) by shifting. 
+ def __init__(self, pattern: str, rank: int, world_size: int, device: torch.device): + self.rank = rank + self.world_size = world_size + self.device = device + self.stream = TokenStream(pattern) + + def next_batch(self, global_tokens: int, seq_len: int, grad_accum_steps: int) -> tuple[Tensor, Tensor]: + local_tokens = global_tokens // (self.world_size * grad_accum_steps) + per_rank_span = local_tokens + 1 + chunk = self.stream.take(per_rank_span * self.world_size) + start = self.rank * per_rank_span + local = chunk[start : start + per_rank_span].to(dtype=torch.int64) + x = local[:-1].reshape(-1, seq_len) + y = local[1:].reshape(-1, seq_len) + return x.to(self.device, non_blocking=True), y.to(self.device, non_blocking=True) + +# ----------------------------- +# TRANSFORMER MODULES +# ----------------------------- + +class RMSNorm(nn.Module): + def __init__(self, eps: float | None = None): + super().__init__() + self.eps = eps + + def forward(self, x: Tensor) -> Tensor: + return F.rms_norm(x, (x.size(-1),), eps=self.eps) + + +class CastedLinear(nn.Linear): + # Keep weights in fp32 for optimizer/state quality, cast at matmul time for bf16 compute. + def forward(self, x: Tensor) -> Tensor: + bias = self.bias.to(x.dtype) if self.bias is not None else None + return F.linear(x, self.weight.to(x.dtype), bias) + + +def restore_low_dim_params_to_fp32(module: nn.Module) -> None: + # Keep small/control parameters in fp32 even when the model body runs in bf16. + with torch.no_grad(): + for name, param in module.named_parameters(): + if (param.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS)) and param.dtype != torch.float32: + param.data = param.data.float() + + +class Rotary(nn.Module): + # Caches cos/sin tables per sequence length on the current device. 
+ def __init__(self, dim: int, base: float = 10000.0): + super().__init__() + inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim)) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self._seq_len_cached = 0 + self._cos_cached: Tensor | None = None + self._sin_cached: Tensor | None = None + + def forward(self, seq_len: int, device: torch.device, dtype: torch.dtype) -> tuple[Tensor, Tensor]: + if ( + self._cos_cached is None + or self._sin_cached is None + or self._seq_len_cached != seq_len + or self._cos_cached.device != device + ): + t = torch.arange(seq_len, device=device, dtype=self.inv_freq.dtype) + freqs = torch.outer(t, self.inv_freq.to(device)) + self._cos_cached = freqs.cos()[None, None, :, :] + self._sin_cached = freqs.sin()[None, None, :, :] + self._seq_len_cached = seq_len + return self._cos_cached.to(dtype=dtype), self._sin_cached.to(dtype=dtype) + + +def apply_rotary_emb(x: Tensor, cos: Tensor, sin: Tensor) -> Tensor: + half = x.size(-1) // 2 + x1, x2 = x[..., :half], x[..., half:] + return torch.cat((x1 * cos + x2 * sin, x1 * (-sin) + x2 * cos), dim=-1) + + +class CausalSelfAttention(nn.Module): + def __init__( + self, + dim: int, + num_heads: int, + num_kv_heads: int, + rope_base: float, + qk_gain_init: float, + ): + super().__init__() + if dim % num_heads != 0: + raise ValueError("model_dim must be divisible by num_heads") + if num_heads % num_kv_heads != 0: + raise ValueError("num_heads must be divisible by num_kv_heads") + self.num_heads = num_heads + self.num_kv_heads = num_kv_heads + self.head_dim = dim // num_heads + if self.head_dim % 2 != 0: + raise ValueError("head_dim must be even for RoPE") + kv_dim = self.num_kv_heads * self.head_dim + self.c_q = CastedLinear(dim, dim, bias=False) + self.c_k = CastedLinear(dim, kv_dim, bias=False) + self.c_v = CastedLinear(dim, kv_dim, bias=False) + self.proj = CastedLinear(dim, dim, bias=False) + self.proj._zero_init = True + self.q_gain = 
nn.Parameter(torch.full((num_heads,), qk_gain_init, dtype=torch.float32)) + self.rotary = Rotary(self.head_dim, base=rope_base) + + def forward(self, x: Tensor) -> Tensor: + bsz, seqlen, dim = x.shape + q = self.c_q(x).reshape(bsz, seqlen, self.num_heads, self.head_dim).transpose(1, 2) + k = self.c_k(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) + v = self.c_v(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) + q = F.rms_norm(q, (q.size(-1),)) + k = F.rms_norm(k, (k.size(-1),)) + cos, sin = self.rotary(seqlen, x.device, q.dtype) + q = apply_rotary_emb(q, cos, sin) + k = apply_rotary_emb(k, cos, sin) + q = q * self.q_gain.to(dtype=q.dtype)[None, :, None, None] + if FA3_AVAILABLE: + # FA3 expects (B, T, H, D) — transpose from (B, H, T, D) + y = fa3.flash_attn_func( + q.transpose(1, 2), + k.transpose(1, 2), + v.transpose(1, 2), + causal=True, + ).transpose(1, 2) + else: + y = F.scaled_dot_product_attention( + q, k, v, + attn_mask=None, + is_causal=True, + enable_gqa=(self.num_kv_heads != self.num_heads), + ) + y = y.transpose(1, 2).contiguous().reshape(bsz, seqlen, dim) + return self.proj(y) + + +class MLP(nn.Module): + # leaky_relu^2 MLP (slope=0.5 matches SOTA, PRs #493/#518) + def __init__(self, dim: int, mlp_mult: int, leaky_relu_slope: float = 0.5): + super().__init__() + hidden = mlp_mult * dim + self.fc = CastedLinear(dim, hidden, bias=False) + self.proj = CastedLinear(hidden, dim, bias=False) + self.proj._zero_init = True + self.leaky_relu_slope = leaky_relu_slope + + def forward(self, x: Tensor) -> Tensor: + x = F.leaky_relu(self.fc(x), negative_slope=self.leaky_relu_slope) + return self.proj(x.square()) + + +class TimestepScaling(nn.Module): + """Learned per-iteration scale vectors (arXiv:2410.01405, ICML 2025).""" + def __init__(self, effective_layers: int, dim: int, gamma_max: float = 0.0): + super().__init__() + self.attn_gamma = nn.Parameter(torch.ones(effective_layers, dim, dtype=torch.float32)) + 
self.mlp_gamma = nn.Parameter(torch.ones(effective_layers, dim, dtype=torch.float32)) + self.gamma_max = gamma_max # 0 = uncapped + + def get(self, v: int) -> tuple[Tensor, Tensor]: + ag = self.attn_gamma[v] + mg = self.mlp_gamma[v] + if self.gamma_max > 0: + ag = ag.clamp(-self.gamma_max, self.gamma_max) + mg = mg.clamp(-self.gamma_max, self.gamma_max) + return ag, mg + + +class Block(nn.Module): + def __init__( + self, + dim: int, + num_heads: int, + num_kv_heads: int, + mlp_mult: int, + rope_base: float, + qk_gain_init: float, + use_peri_norm: bool = False, + use_birkhoff_mix: bool = False, + leaky_relu_slope: float = 0.5, + ): + super().__init__() + self.use_peri_norm = use_peri_norm + self.use_birkhoff_mix = use_birkhoff_mix + self.attn_norm = RMSNorm() + if use_peri_norm: + self.mlp_out_norm = RMSNorm() # output norm: bounds MLP contribution, preserves input magnitude + else: + self.mlp_norm = RMSNorm() + self.attn = CausalSelfAttention(dim, num_heads, num_kv_heads, rope_base, qk_gain_init) + self.mlp = MLP(dim, mlp_mult, leaky_relu_slope) + self.attn_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.mlp_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + if use_birkhoff_mix: + # Birkhoff B₂ parameterization: sigmoid(logit) ∈ (0,1) → doubly stochastic + self.resid_mix_logit = nn.Parameter(torch.zeros(dim, dtype=torch.float32)) + else: + self.resid_mix = nn.Parameter(torch.stack((torch.ones(dim), torch.zeros(dim))).float()) + + def forward(self, x: Tensor, x0: Tensor, + ts_attn_gamma: Tensor | None = None, + ts_mlp_gamma: Tensor | None = None) -> Tensor: + if self.use_birkhoff_mix: + alpha = torch.sigmoid(self.resid_mix_logit.to(dtype=x.dtype))[None, None, :] + x = alpha * x + (1 - alpha) * x0 + else: + mix = self.resid_mix.to(dtype=x.dtype) + x = mix[0][None, None, :] * x + mix[1][None, None, :] * x0 + attn_out = self.attn(self.attn_norm(x)) + attn_s = self.attn_scale.to(dtype=x.dtype)[None, None, :] + if ts_attn_gamma is not None: + 
attn_s = attn_s * ts_attn_gamma[None, None, :] + x = x + attn_s * attn_out + mlp_out = self.mlp_out_norm(self.mlp(x)) if self.use_peri_norm else self.mlp(self.mlp_norm(x)) + mlp_s = self.mlp_scale.to(dtype=x.dtype)[None, None, :] + if ts_mlp_gamma is not None: + mlp_s = mlp_s * ts_mlp_gamma[None, None, :] + x = x + mlp_s * mlp_out + return x + + +class GPT(nn.Module): + def __init__( + self, + vocab_size: int, + num_layers: int, + model_dim: int, + num_heads: int, + num_kv_heads: int, + mlp_mult: int, + tie_embeddings: bool, + tied_embed_init_std: float, + logit_softcap: float, + rope_base: float, + qk_gain_init: float, + num_prelude: int = 0, + num_coda: int = 0, + num_shared: int = 0, + num_loops: int = 2, + use_peri_norm: bool = False, + use_birkhoff_mix: bool = False, + use_timestep_scale: bool = False, + timestep_gamma_max: float = 0.0, + leaky_relu_slope: float = 0.5, + ): + super().__init__() + if logit_softcap <= 0.0: + raise ValueError(f"logit_softcap must be positive, got {logit_softcap}") + self.tie_embeddings = tie_embeddings + self.tied_embed_init_std = tied_embed_init_std + self.logit_softcap = logit_softcap + self.tok_emb = nn.Embedding(vocab_size, model_dim) + self.use_recurrence = num_shared > 0 + self.num_loops = num_loops if self.use_recurrence else 1 + self.num_prelude = num_prelude if self.use_recurrence else 0 + self.num_coda = num_coda if self.use_recurrence else 0 + + block_kwargs = dict( + dim=model_dim, num_heads=num_heads, num_kv_heads=num_kv_heads, + mlp_mult=mlp_mult, rope_base=rope_base, qk_gain_init=qk_gain_init, + use_peri_norm=use_peri_norm if self.use_recurrence else False, + use_birkhoff_mix=use_birkhoff_mix if self.use_recurrence else False, + leaky_relu_slope=leaky_relu_slope, + ) + + if self.use_recurrence: + self.prelude_blocks = nn.ModuleList([Block(**block_kwargs) for _ in range(self.num_prelude)]) + self.shared_blocks = nn.ModuleList([Block(**block_kwargs) for _ in range(num_shared)]) + self.coda_blocks = 
nn.ModuleList([Block(**block_kwargs) for _ in range(self.num_coda)]) + effective_layers = self.num_prelude + num_shared * self.num_loops + self.num_coda + self.timestep_scale = TimestepScaling(effective_layers, model_dim, gamma_max=timestep_gamma_max) if use_timestep_scale else None + else: + # Standard U-Net path + self.num_encoder_layers = num_layers // 2 + self.num_decoder_layers = num_layers - self.num_encoder_layers + self.num_skip_weights = min(self.num_encoder_layers, self.num_decoder_layers) + self.skip_weights = nn.Parameter(torch.ones(self.num_skip_weights, model_dim, dtype=torch.float32)) + self.blocks = nn.ModuleList([Block(**block_kwargs) for _ in range(num_layers)]) + self.timestep_scale = None + + self.final_norm = RMSNorm() + self.lm_head = None if tie_embeddings else CastedLinear(model_dim, vocab_size, bias=False) + if self.lm_head is not None: + self.lm_head._zero_init = True + self._init_weights() + + def _init_weights(self) -> None: + if self.tie_embeddings: + nn.init.normal_(self.tok_emb.weight, mean=0.0, std=self.tied_embed_init_std) + for module in self.modules(): + if isinstance(module, nn.Linear) and getattr(module, "_zero_init", False): + nn.init.zeros_(module.weight) + + def _get_ts(self, v: int) -> tuple[Tensor | None, Tensor | None]: + if self.timestep_scale is None: + return None, None + return self.timestep_scale.get(v) + + def forward(self, input_ids: Tensor, target_ids: Tensor) -> Tensor: + x = self.tok_emb(input_ids) + x = F.rms_norm(x, (x.size(-1),)) + x0 = x + + if self.use_recurrence: + v = 0 + for block in self.prelude_blocks: + ag, mg = self._get_ts(v) + x = block(x, x0, ag, mg) + v += 1 + for _loop in range(self.num_loops): + for block in self.shared_blocks: + ag, mg = self._get_ts(v) + x = block(x, x0, ag, mg) + v += 1 + for block in self.coda_blocks: + ag, mg = self._get_ts(v) + x = block(x, x0, ag, mg) + v += 1 + else: + skips: list[Tensor] = [] + for i in range(self.num_encoder_layers): + x = self.blocks[i](x, x0) + 
skips.append(x) + for i in range(self.num_decoder_layers): + if skips: + x = x + self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips.pop() + x = self.blocks[self.num_encoder_layers + i](x, x0) + + x = self.final_norm(x).reshape(-1, x.size(-1)) + targets = target_ids.reshape(-1) + if self.tie_embeddings: + logits_proj = F.linear(x, self.tok_emb.weight) + else: + if self.lm_head is None: + raise RuntimeError("lm_head is required when tie_embeddings=False") + logits_proj = self.lm_head(x) + logits = self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap) + return F.cross_entropy(logits.float(), targets, reduction="mean") + + +@torch.no_grad() +def recurrence_param_diagnostics(gpt: GPT) -> str: + """Log parameter norms for recurrence health monitoring. No forward pass.""" + if not gpt.use_recurrence: + return "" + parts: list[str] = [] + + # Birkhoff alpha stats per shared block + for i, block in enumerate(gpt.shared_blocks): + if hasattr(block, "resid_mix_logit"): + a = torch.sigmoid(block.resid_mix_logit) + parts.append(f"shared{i}_alpha:mean={a.mean().item():.3f},std={a.std().item():.3f}") + + # Effective MLP/attn contribution scale per virtual layer position: + # mlp_scale * ts_mlp_gamma[v] is the learned multiplier on MLP output. + # We report RMS (norm / sqrt(numel)) to make it scale-independent. 
+ v = 0 + all_blocks = ( + list(gpt.prelude_blocks) + + list(gpt.shared_blocks) * gpt.num_loops + + list(gpt.coda_blocks) + ) + mlp_norms: list[str] = [] + attn_norms: list[str] = [] + for block in all_blocks: + ms = block.mlp_scale.norm().item() + asc = block.attn_scale.norm().item() + d = block.mlp_scale.numel() ** 0.5 + if gpt.timestep_scale is not None: + ms_g = gpt.timestep_scale.mlp_gamma[v].norm().item() + as_g = gpt.timestep_scale.attn_gamma[v].norm().item() + mlp_norms.append(f"v{v}:{ms * ms_g / d:.4f}") + attn_norms.append(f"v{v}:{asc * as_g / d:.4f}") + else: + mlp_norms.append(f"v{v}:{ms / d:.4f}") + attn_norms.append(f"v{v}:{asc / d:.4f}") + v += 1 + parts.append("eff_mlp_scale:[" + " ".join(mlp_norms) + "]") + parts.append("eff_attn_scale:[" + " ".join(attn_norms) + "]") + return " ".join(parts) + + +# ----------------------------- +# TRAINING +# ----------------------------- + +def main() -> None: + global zeropower_via_newtonschulz5 + + code = Path(__file__).read_text(encoding="utf-8") + args = Hyperparameters() + zeropower_via_newtonschulz5 = torch.compile(zeropower_via_newtonschulz5) + + # ----------------------------- + # DISTRIBUTED + CUDA SETUP + # ----------------------------- + + distributed = "RANK" in os.environ and "WORLD_SIZE" in os.environ + rank = int(os.environ.get("RANK", "0")) + world_size = int(os.environ.get("WORLD_SIZE", "1")) + local_rank = int(os.environ.get("LOCAL_RANK", "0")) + if world_size <= 0: + raise ValueError(f"WORLD_SIZE must be positive, got {world_size}") + if 8 % world_size != 0: + raise ValueError(f"WORLD_SIZE={world_size} must divide 8 so grad_accum_steps stays integral") + grad_accum_steps = 8 // world_size + grad_scale = 1.0 / grad_accum_steps + if not torch.cuda.is_available(): + raise RuntimeError("CUDA is required") + device = torch.device("cuda", local_rank) + torch.cuda.set_device(device) + if distributed: + dist.init_process_group(backend="nccl", device_id=device) + dist.barrier() + master_process = rank 
== 0 + + # Fast math knobs + torch.backends.cuda.matmul.allow_tf32 = True + torch.backends.cudnn.allow_tf32 = True + from torch.backends.cuda import enable_cudnn_sdp, enable_flash_sdp, enable_math_sdp, enable_mem_efficient_sdp + + enable_cudnn_sdp(False) + enable_flash_sdp(True) + enable_mem_efficient_sdp(False) + enable_math_sdp(False) + + logfile = None + if master_process: + os.makedirs("logs", exist_ok=True) + logfile = f"logs/{args.run_id}.txt" + print(logfile) + + def log0(msg: str, console: bool = True) -> None: + if not master_process: + return + if console: + print(msg) + if logfile is not None: + with open(logfile, "a", encoding="utf-8") as f: + print(msg, file=f) + + log0(code, console=False) + log0("=" * 100, console=False) + log0(f"Running Python {sys.version}", console=False) + log0(f"Running PyTorch {torch.__version__}", console=False) + log0( + subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, check=False).stdout, + console=False, + ) + log0("=" * 100, console=False) + + # ----------------------------- + # TOKENIZER + VALIDATION METRIC SETUP + # ----------------------------- + + random.seed(args.seed) + np.random.seed(args.seed) + torch.manual_seed(args.seed) + torch.cuda.manual_seed_all(args.seed) + + if not args.tokenizer_path.endswith(".model"): + raise ValueError(f"Script only setup for SentencePiece .model file: {args.tokenizer_path}") + sp = spm.SentencePieceProcessor(model_file=args.tokenizer_path) + if int(sp.vocab_size()) != args.vocab_size: + raise ValueError( + f"VOCAB_SIZE={args.vocab_size} does not match tokenizer vocab_size={int(sp.vocab_size())}" + ) + dataset_dir = Path(args.data_path).resolve() + actual_train_files = len(list(dataset_dir.glob("fineweb_train_*.bin"))) + val_tokens = load_validation_tokens(args.val_files, args.train_seq_len) + base_bytes_lut, has_leading_space_lut, is_boundary_token_lut = build_sentencepiece_luts( + sp, args.vocab_size, device + ) + log0(f"val_bpb:enabled 
tokenizer_kind=sentencepiece tokenizer_path={args.tokenizer_path}") + log0(f"train_loader:dataset:{dataset_dir.name} train_shards:{actual_train_files}") + log0(f"val_loader:shards pattern={args.val_files} tokens:{val_tokens.numel() - 1}") + + # ----------------------------- + # MODEL + OPTIMIZER SETUP + # ----------------------------- + + base_model = GPT( + vocab_size=args.vocab_size, + num_layers=args.num_layers, + model_dim=args.model_dim, + num_heads=args.num_heads, + num_kv_heads=args.num_kv_heads, + mlp_mult=args.mlp_mult, + tie_embeddings=args.tie_embeddings, + tied_embed_init_std=args.tied_embed_init_std, + logit_softcap=args.logit_softcap, + rope_base=args.rope_base, + qk_gain_init=args.qk_gain_init, + num_prelude=args.num_prelude, + num_coda=args.num_coda, + num_shared=args.num_shared, + num_loops=args.num_loops, + use_peri_norm=args.use_peri_norm, + use_birkhoff_mix=args.use_birkhoff_mix, + use_timestep_scale=args.use_timestep_scale, + timestep_gamma_max=args.timestep_gamma_max, + leaky_relu_slope=args.leaky_relu_slope, + ).to(device).bfloat16() + for module in base_model.modules(): + if isinstance(module, CastedLinear): + module.float() + restore_low_dim_params_to_fp32(base_model) + if args.disable_compile: + compiled_model = base_model + else: + compiled_model = torch.compile(base_model, dynamic=False, fullgraph=True) + model: nn.Module = DDP(compiled_model, device_ids=[local_rank], broadcast_buffers=False) if distributed else compiled_model + + # Optimizer split: + # - token embedding (Adam) uses EMBED_LR + # - untied lm_head (Adam) uses HEAD_LR + # - matrix params in transformer blocks use MATRIX_LR via Muon + # - vectors/scalars use SCALAR_LR via Adam + if base_model.use_recurrence: + all_block_lists = [base_model.prelude_blocks, base_model.shared_blocks, base_model.coda_blocks] + block_named_params = [] + for bl in all_block_lists: + block_named_params.extend(bl.named_parameters()) + else: + block_named_params = 
list(base_model.blocks.named_parameters()) + matrix_params = [ + p + for name, p in block_named_params + if p.ndim == 2 and not any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) + ] + scalar_params = [ + p + for name, p in block_named_params + if p.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) + ] + if hasattr(base_model, "skip_weights") and base_model.skip_weights.numel() > 0: + scalar_params.append(base_model.skip_weights) + if base_model.timestep_scale is not None: + for p in base_model.timestep_scale.parameters(): + scalar_params.append(p) + token_lr = args.tied_embed_lr if args.tie_embeddings else args.embed_lr + optimizer_tok = torch.optim.Adam( + [{"params": [base_model.tok_emb.weight], "lr": token_lr, "base_lr": token_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=True, + ) + optimizer_muon = Muon( + matrix_params, + lr=args.matrix_lr, + momentum=args.muon_momentum, + backend_steps=args.muon_backend_steps, + ) + for group in optimizer_muon.param_groups: + group["base_lr"] = args.matrix_lr + optimizer_scalar = torch.optim.Adam( + [{"params": scalar_params, "lr": args.scalar_lr, "base_lr": args.scalar_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=True, + ) + optimizers: list[torch.optim.Optimizer] = [optimizer_tok, optimizer_muon, optimizer_scalar] + if base_model.lm_head is not None: + optimizer_head = torch.optim.Adam( + [{"params": [base_model.lm_head.weight], "lr": args.head_lr, "base_lr": args.head_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=True, + ) + optimizers.insert(1, optimizer_head) + + n_params = sum(p.numel() for p in base_model.parameters()) + log0(f"model_params:{n_params}") + log0(f"world_size:{world_size} grad_accum_steps:{grad_accum_steps}") + log0("sdp_backends:cudnn=False flash=True mem_efficient=False math=False") + log0(f"flash_attention_3:{FA3_AVAILABLE}") + log0(f"attention_mode:gqa num_heads:{args.num_heads} 
num_kv_heads:{args.num_kv_heads}") + log0( + f"tie_embeddings:{args.tie_embeddings} embed_lr:{token_lr} " + f"head_lr:{args.head_lr if base_model.lm_head is not None else 0.0} " + f"matrix_lr:{args.matrix_lr} scalar_lr:{args.scalar_lr}" + ) + log0( + f"train_batch_tokens:{args.train_batch_tokens} train_seq_len:{args.train_seq_len} " + f"iterations:{args.iterations} warmup_steps:{args.warmup_steps} " + f"max_wallclock_seconds:{args.max_wallclock_seconds:.3f}" + ) + log0(f"seed:{args.seed}") + if base_model.use_recurrence: + eff = base_model.num_prelude + len(base_model.shared_blocks) * base_model.num_loops + base_model.num_coda + log0(f"recurrence:enabled prelude:{base_model.num_prelude} shared:{len(base_model.shared_blocks)} " + f"loops:{base_model.num_loops} coda:{base_model.num_coda} effective_layers:{eff}") + log0(f"peri_norm:{args.use_peri_norm} birkhoff_mix:{args.use_birkhoff_mix}") + if base_model.timestep_scale is not None: + ts_params = sum(p.numel() for p in base_model.timestep_scale.parameters()) + log0(f"timestep_scale:enabled params:{ts_params}") + else: + log0("timestep_scale:disabled") + else: + log0(f"recurrence:disabled num_layers:{args.num_layers}") + compile_mode = "disabled" if args.disable_compile else "fullgraph=True" + log0(f"compile_mode:{compile_mode}") + + # ----------------------------- + # DATA LOADER & MODEL WARMUP + # ----------------------------- + + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + + def zero_grad_all() -> None: + for opt in optimizers: + opt.zero_grad(set_to_none=True) + + max_wallclock_ms = 1000.0 * args.max_wallclock_seconds if args.max_wallclock_seconds > 0 else None + + def lr_mul(step: int, elapsed_ms: float) -> float: + if args.warmdown_iters <= 0: + return 1.0 + if max_wallclock_ms is None: + warmdown_start = max(args.iterations - args.warmdown_iters, 0) + return max((args.iterations - step) / max(args.warmdown_iters, 1), 0.0) if warmdown_start <= step < args.iterations else 
1.0 + step_ms = elapsed_ms / max(step, 1) + warmdown_ms = args.warmdown_iters * step_ms + remaining_ms = max(max_wallclock_ms - elapsed_ms, 0.0) + return remaining_ms / max(warmdown_ms, 1e-9) if remaining_ms <= warmdown_ms else 1.0 + + # Warmup primes the compiled forward/backward/optimizer paths, then we restore the + # initial weights/optimizer state so measured training starts from the true init. + if args.warmup_steps > 0: + initial_model_state = {name: tensor.detach().cpu().clone() for name, tensor in base_model.state_dict().items()} + initial_optimizer_states = [copy.deepcopy(opt.state_dict()) for opt in optimizers] + model.train() + for warmup_step in range(args.warmup_steps): + zero_grad_all() + for micro_step in range(grad_accum_steps): + if distributed: + model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 + x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + warmup_loss = model(x, y) + (warmup_loss * grad_scale).backward() + for opt in optimizers: + opt.step() + zero_grad_all() + if args.warmup_steps <= 20 or (warmup_step + 1) % 10 == 0 or warmup_step + 1 == args.warmup_steps: + log0(f"warmup_step:{warmup_step + 1}/{args.warmup_steps}") + base_model.load_state_dict(initial_model_state, strict=True) + for opt, state in zip(optimizers, initial_optimizer_states, strict=True): + opt.load_state_dict(state) + zero_grad_all() + if distributed: + model.require_backward_grad_sync = True + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + + # ----------------------------- + # MAIN TRAINING LOOP + # ----------------------------- + + training_time_ms = 0.0 + stop_after_step: int | None = None + torch.cuda.synchronize() + t0 = time.perf_counter() + + step = 0 + while True: + last_step = step == args.iterations or (stop_after_step is not None and step >= stop_after_step) + + should_validate = 
last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0) + if should_validate: + torch.cuda.synchronize() + training_time_ms += 1000.0 * (time.perf_counter() - t0) + val_loss, val_bpb = eval_val( + args, + model, + rank, + world_size, + device, + grad_accum_steps, + val_tokens, + base_bytes_lut, + has_leading_space_lut, + is_boundary_token_lut, + ) + log0( + f"step:{step}/{args.iterations} val_loss:{val_loss:.4f} val_bpb:{val_bpb:.4f} " + f"train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms / max(step, 1):.2f}ms" + ) + torch.cuda.synchronize() + t0 = time.perf_counter() + + if last_step: + if stop_after_step is not None and step < args.iterations: + log0( + f"stopping_early: wallclock_cap train_time:{training_time_ms:.0f}ms " + f"step:{step}/{args.iterations}" + ) + break + + elapsed_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + scale = lr_mul(step, elapsed_ms) + zero_grad_all() + train_loss = torch.zeros((), device=device) + for micro_step in range(grad_accum_steps): + if distributed: + model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 + x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + loss = model(x, y) + train_loss += loss.detach() + (loss * grad_scale).backward() + train_loss /= grad_accum_steps + + frac = min(step / args.muon_momentum_warmup_steps, 1.0) if args.muon_momentum_warmup_steps > 0 else 1.0 + muon_momentum = (1 - frac) * args.muon_momentum_warmup_start + frac * args.muon_momentum + for group in optimizer_muon.param_groups: + group["momentum"] = muon_momentum + + for opt in optimizers: + for group in opt.param_groups: + group["lr"] = group["base_lr"] * scale + + if args.grad_clip_norm > 0: + torch.nn.utils.clip_grad_norm_(base_model.parameters(), args.grad_clip_norm) + for opt in optimizers: + opt.step() + zero_grad_all() + + step += 1 + approx_training_time_ms = 
training_time_ms + 1000.0 * (time.perf_counter() - t0) + should_log_train = ( + args.train_log_every > 0 + and (step <= 10 or step % args.train_log_every == 0 or stop_after_step is not None) + ) + if should_log_train: + log0( + f"step:{step}/{args.iterations} train_loss:{train_loss.item():.4f} " + f"train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms / step:.2f}ms" + ) + if base_model.use_recurrence and step % args.train_log_every == 0: + diag = recurrence_param_diagnostics(base_model) + if diag: + log0(f"step:{step} {diag}") + + # Needed to sync whether we've reached the wallclock cap. + reached_cap = max_wallclock_ms is not None and approx_training_time_ms >= max_wallclock_ms + if distributed and max_wallclock_ms is not None: + reached_cap_tensor = torch.tensor(int(reached_cap), device=device) + dist.all_reduce(reached_cap_tensor, op=dist.ReduceOp.MAX) + reached_cap = bool(reached_cap_tensor.item()) + if stop_after_step is None and reached_cap: + stop_after_step = step + + log0( + f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB " + f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB" + ) + + # ----------------------------- + # SERIALIZATION + ROUNDTRIP VALIDATION + # ----------------------------- + # Save the raw state (useful for debugging/loading in PyTorch directly), then always produce + # the compressed int8+zlib artifact and validate the round-tripped weights. 
+ + if master_process: + torch.save(base_model.state_dict(), "final_model.pt") + model_bytes = os.path.getsize("final_model.pt") + code_bytes = len(code.encode("utf-8")) + log0(f"Serialized model: {model_bytes} bytes") + log0(f"Code size: {code_bytes} bytes") + log0(f"Total submission size: {model_bytes + code_bytes} bytes") + + quant_obj, quant_stats = quantize_state_dict_int8(base_model.state_dict()) + quant_buf = io.BytesIO() + torch.save(quant_obj, quant_buf) + quant_raw = quant_buf.getvalue() + quant_blob = zlib.compress(quant_raw, level=9) + quant_raw_bytes = len(quant_raw) + if master_process: + with open("final_model.int8.ptz", "wb") as f: + f.write(quant_blob) + quant_file_bytes = os.path.getsize("final_model.int8.ptz") + code_bytes = len(code.encode("utf-8")) + ratio = quant_stats["baseline_tensor_bytes"] / max(quant_stats["int8_payload_bytes"], 1) + log0( + f"Serialized model int8+zlib: {quant_file_bytes} bytes " + f"(payload:{quant_stats['int8_payload_bytes']} raw_torch:{quant_raw_bytes} payload_ratio:{ratio:.2f}x)" + ) + log0(f"Total submission size int8+zlib: {quant_file_bytes + code_bytes} bytes") + + if distributed: + dist.barrier() + with open("final_model.int8.ptz", "rb") as f: + quant_blob_disk = f.read() + quant_state = torch.load(io.BytesIO(zlib.decompress(quant_blob_disk)), map_location="cpu") + base_model.load_state_dict(dequantize_state_dict_int8(quant_state), strict=True) + torch.cuda.synchronize() + t_qeval = time.perf_counter() + q_val_loss, q_val_bpb = eval_val( + args, + model, + rank, + world_size, + device, + grad_accum_steps, + val_tokens, + base_bytes_lut, + has_leading_space_lut, + is_boundary_token_lut, + ) + torch.cuda.synchronize() + log0( + f"final_int8_zlib_roundtrip val_loss:{q_val_loss:.4f} val_bpb:{q_val_bpb:.4f} " + f"eval_time:{1000.0 * (time.perf_counter() - t_qeval):.0f}ms" + ) + log0(f"final_int8_zlib_roundtrip_exact val_loss:{q_val_loss:.8f} val_bpb:{q_val_bpb:.8f}") + + if distributed: + dist.destroy_process_group() 
+ + +if __name__ == "__main__": + main() + +==================================================================================================== +Running Python 3.12.3 (main, Nov 6 2025, 13:44:16) [GCC 13.3.0] +Running PyTorch 2.9.1+cu128 +Wed Mar 25 18:56:19 2026 ++-----------------------------------------------------------------------------------------+ +| NVIDIA-SMI 580.126.09 Driver Version: 580.126.09 CUDA Version: 13.0 | ++-----------------------------------------+------------------------+----------------------+ +| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC | +| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. | +| | | MIG M. | +|=========================================+========================+======================| +| 0 NVIDIA H100 80GB HBM3 On | 00000000:19:00.0 Off | 0 | +| N/A 41C P0 120W / 700W | 1521MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 1 NVIDIA H100 80GB HBM3 On | 00000000:3B:00.0 Off | 0 | +| N/A 34C P0 119W / 700W | 1521MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 2 NVIDIA H100 80GB HBM3 On | 00000000:4C:00.0 Off | 0 | +| N/A 32C P0 115W / 700W | 1521MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 3 NVIDIA H100 80GB HBM3 On | 00000000:5D:00.0 Off | 0 | +| N/A 39C P0 119W / 700W | 1521MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 4 NVIDIA H100 80GB HBM3 On | 00000000:9B:00.0 Off | 0 | +| N/A 41C P0 120W / 700W | 1521MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 5 NVIDIA H100 80GB HBM3 On | 00000000:BB:00.0 Off | 0 | +| N/A 34C 
P0 119W / 700W | 1521MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 6 NVIDIA H100 80GB HBM3 On | 00000000:CB:00.0 Off | 0 | +| N/A 40C P0 121W / 700W | 1521MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 7 NVIDIA H100 80GB HBM3 On | 00000000:DB:00.0 Off | 0 | +| N/A 33C P0 115W / 700W | 1521MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ + ++-----------------------------------------------------------------------------------------+ +| Processes: | +| GPU GI CI PID Type Process name GPU Memory | +| ID ID Usage | +|=========================================================================================| +| No running processes found | ++-----------------------------------------------------------------------------------------+ + +==================================================================================================== +val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path=./data/tokenizers/fineweb_1024_bpe.model +train_loader:dataset:fineweb10B_sp1024 train_shards:80 +val_loader:shards pattern=./data/datasets/fineweb10B_sp1024/fineweb_val_*.bin tokens:62021632 +model_params:11543600 +world_size:8 grad_accum_steps:1 +sdp_backends:cudnn=False flash=True mem_efficient=False math=False +flash_attention_3:True +attention_mode:gqa num_heads:8 num_kv_heads:4 +tie_embeddings:True embed_lr:0.05 head_lr:0.0 matrix_lr:0.04 scalar_lr:0.04 +train_batch_tokens:524288 train_seq_len:1024 iterations:20000 warmup_steps:20 max_wallclock_seconds:600.000 +seed:1337 +recurrence:enabled prelude:1 shared:4 loops:2 coda:1 effective_layers:10 +peri_norm:True birkhoff_mix:True +timestep_scale:disabled +compile_mode:fullgraph=True +warmup_step:1/20 +warmup_step:2/20 +warmup_step:3/20 
+warmup_step:4/20 +warmup_step:5/20 +warmup_step:6/20 +warmup_step:7/20 +warmup_step:8/20 +warmup_step:9/20 +warmup_step:10/20 +warmup_step:11/20 +warmup_step:12/20 +warmup_step:13/20 +warmup_step:14/20 +warmup_step:15/20 +warmup_step:16/20 +warmup_step:17/20 +warmup_step:18/20 +warmup_step:19/20 +warmup_step:20/20 +step:0/20000 val_loss:6.9379 val_bpb:4.1090 train_time:0ms step_avg:0.02ms +step:1/20000 train_loss:6.9377 train_time:27ms step_avg:26.61ms +step:2/20000 train_loss:9.6277 train_time:72ms step_avg:35.83ms +step:3/20000 train_loss:7.3691 train_time:116ms step_avg:38.67ms +step:4/20000 train_loss:9.0640 train_time:157ms step_avg:39.15ms +step:5/20000 train_loss:8.4299 train_time:197ms step_avg:39.47ms +step:6/20000 train_loss:8.0867 train_time:238ms step_avg:39.67ms +step:7/20000 train_loss:6.8166 train_time:279ms step_avg:39.86ms +step:8/20000 train_loss:6.1507 train_time:320ms step_avg:40.01ms +step:9/20000 train_loss:5.5943 train_time:361ms step_avg:40.15ms +step:10/20000 train_loss:5.3333 train_time:403ms step_avg:40.27ms +step:200/20000 train_loss:2.8022 train_time:8322ms step_avg:41.61ms +step:200 shared0_alpha:mean=0.464,std=0.056 shared1_alpha:mean=0.485,std=0.044 shared2_alpha:mean=0.495,std=0.043 shared3_alpha:mean=0.526,std=0.045 eff_mlp_scale:[v0:1.2901 v1:1.2417 v2:1.2140 v3:1.3137 v4:1.4145 v5:1.2417 v6:1.2140 v7:1.3137 v8:1.4145 v9:1.9314] eff_attn_scale:[v0:0.7182 v1:0.5813 v2:0.5307 v3:0.5101 v4:0.4794 v5:0.5813 v6:0.5307 v7:0.5101 v8:0.4794 v9:0.7829] +step:200/20000 val_loss:2.7816 val_bpb:1.6474 train_time:8474ms step_avg:42.37ms +step:400/20000 train_loss:2.3679 train_time:16827ms step_avg:42.07ms +step:400 shared0_alpha:mean=0.495,std=0.058 shared1_alpha:mean=0.516,std=0.044 shared2_alpha:mean=0.528,std=0.042 shared3_alpha:mean=0.563,std=0.046 eff_mlp_scale:[v0:1.6072 v1:1.5561 v2:1.4940 v3:1.5592 v4:1.5642 v5:1.5561 v6:1.4940 v7:1.5592 v8:1.5642 v9:2.4459] eff_attn_scale:[v0:0.3301 v1:0.2863 v2:0.2525 v3:0.2372 v4:0.2347 v5:0.2863 
v6:0.2525 v7:0.2372 v8:0.2347 v9:0.4861] +step:400/20000 val_loss:2.5732 val_bpb:1.5240 train_time:16840ms step_avg:42.10ms +step:600/20000 train_loss:2.5848 train_time:25206ms step_avg:42.01ms +step:600 shared0_alpha:mean=0.511,std=0.057 shared1_alpha:mean=0.532,std=0.045 shared2_alpha:mean=0.548,std=0.041 shared3_alpha:mean=0.587,std=0.046 eff_mlp_scale:[v0:1.8455 v1:1.7688 v2:1.6821 v3:1.7205 v4:1.6832 v5:1.7688 v6:1.6821 v7:1.7205 v8:1.6832 v9:2.9292] eff_attn_scale:[v0:0.1628 v1:0.1614 v2:0.1492 v3:0.1418 v4:0.1549 v5:0.1614 v6:0.1492 v7:0.1418 v8:0.1549 v9:0.3379] +step:600/20000 val_loss:2.4843 val_bpb:1.4713 train_time:25220ms step_avg:42.03ms +step:800/20000 train_loss:2.3457 train_time:33593ms step_avg:41.99ms +step:800 shared0_alpha:mean=0.520,std=0.056 shared1_alpha:mean=0.540,std=0.046 shared2_alpha:mean=0.559,std=0.042 shared3_alpha:mean=0.600,std=0.046 eff_mlp_scale:[v0:2.0676 v1:1.9528 v2:1.8465 v3:1.8647 v4:1.8096 v5:1.9528 v6:1.8465 v7:1.8647 v8:1.8096 v9:3.3125] eff_attn_scale:[v0:0.1095 v1:0.1158 v2:0.1099 v3:0.1065 v4:0.1235 v5:0.1158 v6:0.1099 v7:0.1065 v8:0.1235 v9:0.2745] +step:800/20000 val_loss:2.4307 val_bpb:1.4396 train_time:33606ms step_avg:42.01ms +step:1000/20000 train_loss:2.4212 train_time:41997ms step_avg:42.00ms +step:1000 shared0_alpha:mean=0.522,std=0.056 shared1_alpha:mean=0.544,std=0.047 shared2_alpha:mean=0.566,std=0.042 shared3_alpha:mean=0.609,std=0.047 eff_mlp_scale:[v0:2.2749 v1:2.1287 v2:2.0046 v3:2.0106 v4:1.9387 v5:2.1287 v6:2.0046 v7:2.0106 v8:1.9387 v9:3.6303] eff_attn_scale:[v0:0.0886 v1:0.0975 v2:0.0942 v3:0.0926 v4:0.1088 v5:0.0975 v6:0.0942 v7:0.0926 v8:0.1088 v9:0.2447] +step:1000/20000 val_loss:2.3880 val_bpb:1.4143 train_time:42012ms step_avg:42.01ms +step:1200/20000 train_loss:2.4419 train_time:50405ms step_avg:42.00ms +step:1200 shared0_alpha:mean=0.523,std=0.056 shared1_alpha:mean=0.548,std=0.048 shared2_alpha:mean=0.571,std=0.043 shared3_alpha:mean=0.616,std=0.048 eff_mlp_scale:[v0:2.4766 v1:2.2909 
v2:2.1511 v3:2.1496 v4:2.0741 v5:2.2909 v6:2.1511 v7:2.1496 v8:2.0741 v9:3.9160] eff_attn_scale:[v0:0.0795 v1:0.0865 v2:0.0868 v3:0.0848 v4:0.1006 v5:0.0865 v6:0.0868 v7:0.0848 v8:0.1006 v9:0.2293] +step:1200/20000 val_loss:2.3601 val_bpb:1.3978 train_time:50419ms step_avg:42.02ms +step:1400/20000 train_loss:2.4937 train_time:58820ms step_avg:42.01ms +step:1400 shared0_alpha:mean=0.522,std=0.056 shared1_alpha:mean=0.551,std=0.048 shared2_alpha:mean=0.576,std=0.044 shared3_alpha:mean=0.620,std=0.048 eff_mlp_scale:[v0:2.6656 v1:2.4508 v2:2.2937 v3:2.2831 v4:2.2025 v5:2.4508 v6:2.2937 v7:2.2831 v8:2.2025 v9:4.1706] eff_attn_scale:[v0:0.0740 v1:0.0806 v2:0.0808 v3:0.0807 v4:0.0944 v5:0.0806 v6:0.0808 v7:0.0807 v8:0.0944 v9:0.2233] +step:1400/20000 val_loss:2.3418 val_bpb:1.3869 train_time:58833ms step_avg:42.02ms +step:1600/20000 train_loss:2.1668 train_time:67225ms step_avg:42.02ms +step:1600 shared0_alpha:mean=0.522,std=0.057 shared1_alpha:mean=0.552,std=0.049 shared2_alpha:mean=0.579,std=0.044 shared3_alpha:mean=0.623,std=0.049 eff_mlp_scale:[v0:2.8483 v1:2.6024 v2:2.4274 v3:2.4089 v4:2.3277 v5:2.6024 v6:2.4274 v7:2.4089 v8:2.3277 v9:4.4020] eff_attn_scale:[v0:0.0700 v1:0.0776 v2:0.0783 v3:0.0777 v4:0.0918 v5:0.0776 v6:0.0783 v7:0.0777 v8:0.0918 v9:0.2222] +step:1600/20000 val_loss:2.3314 val_bpb:1.3808 train_time:67239ms step_avg:42.02ms +step:1800/20000 train_loss:2.2623 train_time:75627ms step_avg:42.02ms +step:1800 shared0_alpha:mean=0.520,std=0.057 shared1_alpha:mean=0.553,std=0.050 shared2_alpha:mean=0.582,std=0.045 shared3_alpha:mean=0.626,std=0.049 eff_mlp_scale:[v0:3.0185 v1:2.7397 v2:2.5600 v3:2.5353 v4:2.4539 v5:2.7397 v6:2.5600 v7:2.5353 v8:2.4539 v9:4.6232] eff_attn_scale:[v0:0.0671 v1:0.0743 v2:0.0770 v3:0.0755 v4:0.0886 v5:0.0743 v6:0.0770 v7:0.0755 v8:0.0886 v9:0.2183] +step:1800/20000 val_loss:2.3142 val_bpb:1.3706 train_time:75641ms step_avg:42.02ms +step:2000/20000 train_loss:2.3133 train_time:84026ms step_avg:42.01ms +step:2000 
shared0_alpha:mean=0.520,std=0.058 shared1_alpha:mean=0.554,std=0.050 shared2_alpha:mean=0.584,std=0.047 shared3_alpha:mean=0.629,std=0.050 eff_mlp_scale:[v0:3.1911 v1:2.8836 v2:2.6869 v3:2.6575 v4:2.5679 v5:2.8836 v6:2.6869 v7:2.6575 v8:2.5679 v9:4.8318] eff_attn_scale:[v0:0.0645 v1:0.0725 v2:0.0736 v3:0.0738 v4:0.0869 v5:0.0725 v6:0.0736 v7:0.0738 v8:0.0869 v9:0.2159] +step:2000/20000 val_loss:2.2996 val_bpb:1.3619 train_time:84040ms step_avg:42.02ms +step:2200/20000 train_loss:2.1418 train_time:92426ms step_avg:42.01ms +step:2200 shared0_alpha:mean=0.518,std=0.059 shared1_alpha:mean=0.555,std=0.051 shared2_alpha:mean=0.586,std=0.047 shared3_alpha:mean=0.631,std=0.050 eff_mlp_scale:[v0:3.3532 v1:3.0230 v2:2.8109 v3:2.7763 v4:2.6898 v5:3.0230 v6:2.8109 v7:2.7763 v8:2.6898 v9:5.0332] eff_attn_scale:[v0:0.0634 v1:0.0716 v2:0.0739 v3:0.0742 v4:0.0868 v5:0.0716 v6:0.0739 v7:0.0742 v8:0.0868 v9:0.2164] +step:2200/20000 val_loss:2.2937 val_bpb:1.3585 train_time:92440ms step_avg:42.02ms +step:2400/20000 train_loss:2.2714 train_time:100828ms step_avg:42.01ms +step:2400 shared0_alpha:mean=0.517,std=0.059 shared1_alpha:mean=0.556,std=0.051 shared2_alpha:mean=0.588,std=0.047 shared3_alpha:mean=0.632,std=0.050 eff_mlp_scale:[v0:3.5114 v1:3.1575 v2:2.9261 v3:2.8918 v4:2.8026 v5:3.1575 v6:2.9261 v7:2.8918 v8:2.8026 v9:5.2329] eff_attn_scale:[v0:0.0621 v1:0.0707 v2:0.0723 v3:0.0729 v4:0.0871 v5:0.0707 v6:0.0723 v7:0.0729 v8:0.0871 v9:0.2167] +step:2400/20000 val_loss:2.2840 val_bpb:1.3527 train_time:100841ms step_avg:42.02ms +step:2600/20000 train_loss:2.4844 train_time:109221ms step_avg:42.01ms +step:2600 shared0_alpha:mean=0.517,std=0.060 shared1_alpha:mean=0.556,std=0.052 shared2_alpha:mean=0.591,std=0.049 shared3_alpha:mean=0.634,std=0.051 eff_mlp_scale:[v0:3.6705 v1:3.2827 v2:3.0487 v3:3.0051 v4:2.9184 v5:3.2827 v6:3.0487 v7:3.0051 v8:2.9184 v9:5.4064] eff_attn_scale:[v0:0.0655 v1:0.0723 v2:0.0752 v3:0.0762 v4:0.0867 v5:0.0723 v6:0.0752 v7:0.0762 v8:0.0867 v9:0.2197] 
+step:2600/20000 val_loss:2.2994 val_bpb:1.3618 train_time:109235ms step_avg:42.01ms +step:2800/20000 train_loss:2.3100 train_time:117611ms step_avg:42.00ms +step:2800 shared0_alpha:mean=0.516,std=0.060 shared1_alpha:mean=0.557,std=0.053 shared2_alpha:mean=0.593,std=0.049 shared3_alpha:mean=0.635,std=0.052 eff_mlp_scale:[v0:3.8320 v1:3.4193 v2:3.1671 v3:3.1209 v4:3.0296 v5:3.4193 v6:3.1671 v7:3.1209 v8:3.0296 v9:5.5935] eff_attn_scale:[v0:0.0624 v1:0.0699 v2:0.0727 v3:0.0746 v4:0.0866 v5:0.0699 v6:0.0727 v7:0.0746 v8:0.0866 v9:0.2180] +step:2800/20000 val_loss:2.2714 val_bpb:1.3453 train_time:117625ms step_avg:42.01ms +step:3000/20000 train_loss:2.2891 train_time:126011ms step_avg:42.00ms +step:3000 shared0_alpha:mean=0.515,std=0.061 shared1_alpha:mean=0.557,std=0.053 shared2_alpha:mean=0.594,std=0.049 shared3_alpha:mean=0.636,std=0.052 eff_mlp_scale:[v0:3.9761 v1:3.5337 v2:3.2754 v3:3.2304 v4:3.1390 v5:3.5337 v6:3.2754 v7:3.2304 v8:3.1390 v9:5.7566] eff_attn_scale:[v0:0.0604 v1:0.0697 v2:0.0724 v3:0.0740 v4:0.0860 v5:0.0697 v6:0.0724 v7:0.0740 v8:0.0860 v9:0.2185] +step:3000/20000 val_loss:2.2637 val_bpb:1.3407 train_time:126024ms step_avg:42.01ms +step:3200/20000 train_loss:2.2588 train_time:134400ms step_avg:42.00ms +step:3200 shared0_alpha:mean=0.514,std=0.061 shared1_alpha:mean=0.557,std=0.053 shared2_alpha:mean=0.596,std=0.050 shared3_alpha:mean=0.637,std=0.053 eff_mlp_scale:[v0:4.1234 v1:3.6527 v2:3.3889 v3:3.3368 v4:3.2460 v5:3.6527 v6:3.3889 v7:3.3368 v8:3.2460 v9:5.9190] eff_attn_scale:[v0:0.0608 v1:0.0687 v2:0.0724 v3:0.0734 v4:0.0867 v5:0.0687 v6:0.0724 v7:0.0734 v8:0.0867 v9:0.2197] +step:3200/20000 val_loss:2.2590 val_bpb:1.3379 train_time:134413ms step_avg:42.00ms +step:3400/20000 train_loss:2.2262 train_time:142797ms step_avg:42.00ms +step:3400 shared0_alpha:mean=0.514,std=0.061 shared1_alpha:mean=0.558,std=0.054 shared2_alpha:mean=0.598,std=0.051 shared3_alpha:mean=0.639,std=0.053 eff_mlp_scale:[v0:4.2675 v1:3.7679 v2:3.4946 v3:3.4441 v4:3.3489 
v5:3.7679 v6:3.4946 v7:3.4441 v8:3.3489 v9:6.0748] eff_attn_scale:[v0:0.0599 v1:0.0696 v2:0.0722 v3:0.0741 v4:0.0868 v5:0.0696 v6:0.0722 v7:0.0741 v8:0.0868 v9:0.2211] +step:3400/20000 val_loss:2.2560 val_bpb:1.3362 train_time:142810ms step_avg:42.00ms +step:3600/20000 train_loss:2.1921 train_time:151267ms step_avg:42.02ms +step:3600 shared0_alpha:mean=0.513,std=0.062 shared1_alpha:mean=0.558,std=0.054 shared2_alpha:mean=0.600,std=0.051 shared3_alpha:mean=0.640,std=0.053 eff_mlp_scale:[v0:4.4002 v1:3.8864 v2:3.6015 v3:3.5464 v4:3.4553 v5:3.8864 v6:3.6015 v7:3.5464 v8:3.4553 v9:6.2243] eff_attn_scale:[v0:0.0599 v1:0.0703 v2:0.0729 v3:0.0756 v4:0.0880 v5:0.0703 v6:0.0729 v7:0.0756 v8:0.0880 v9:0.2244] +step:3600/20000 val_loss:2.2498 val_bpb:1.3324 train_time:151280ms step_avg:42.02ms +step:3800/20000 train_loss:2.2971 train_time:159655ms step_avg:42.01ms +step:3800 shared0_alpha:mean=0.512,std=0.064 shared1_alpha:mean=0.559,std=0.055 shared2_alpha:mean=0.602,std=0.052 shared3_alpha:mean=0.640,std=0.053 eff_mlp_scale:[v0:4.5359 v1:3.9977 v2:3.7034 v3:3.6453 v4:3.5530 v5:3.9977 v6:3.7034 v7:3.6453 v8:3.5530 v9:6.3718] eff_attn_scale:[v0:0.0604 v1:0.0707 v2:0.0733 v3:0.0767 v4:0.0892 v5:0.0707 v6:0.0733 v7:0.0767 v8:0.0892 v9:0.2257] +step:3800/20000 val_loss:2.2453 val_bpb:1.3298 train_time:159668ms step_avg:42.02ms +step:4000/20000 train_loss:2.2309 train_time:168042ms step_avg:42.01ms +step:4000 shared0_alpha:mean=0.511,std=0.063 shared1_alpha:mean=0.558,std=0.055 shared2_alpha:mean=0.602,std=0.052 shared3_alpha:mean=0.641,std=0.054 eff_mlp_scale:[v0:4.6753 v1:4.1069 v2:3.8022 v3:3.7480 v4:3.6533 v5:4.1069 v6:3.8022 v7:3.7480 v8:3.6533 v9:6.5172] eff_attn_scale:[v0:0.0597 v1:0.0702 v2:0.0725 v3:0.0763 v4:0.0896 v5:0.0702 v6:0.0725 v7:0.0763 v8:0.0896 v9:0.2268] +step:4000/20000 val_loss:2.2412 val_bpb:1.3273 train_time:168056ms step_avg:42.01ms +step:4200/20000 train_loss:2.2431 train_time:176495ms step_avg:42.02ms +step:4200 shared0_alpha:mean=0.510,std=0.064 
shared1_alpha:mean=0.559,std=0.056 shared2_alpha:mean=0.604,std=0.053 shared3_alpha:mean=0.642,std=0.053 eff_mlp_scale:[v0:4.8021 v1:4.2121 v2:3.8976 v3:3.8383 v4:3.7557 v5:4.2121 v6:3.8976 v7:3.8383 v8:3.7557 v9:6.6528] eff_attn_scale:[v0:0.0594 v1:0.0701 v2:0.0734 v3:0.0767 v4:0.0891 v5:0.0701 v6:0.0734 v7:0.0767 v8:0.0891 v9:0.2294] +step:4200/20000 val_loss:2.2369 val_bpb:1.3248 train_time:176507ms step_avg:42.03ms +step:4400/20000 train_loss:2.1889 train_time:184872ms step_avg:42.02ms +step:4400 shared0_alpha:mean=0.510,std=0.064 shared1_alpha:mean=0.559,std=0.057 shared2_alpha:mean=0.605,std=0.053 shared3_alpha:mean=0.643,std=0.055 eff_mlp_scale:[v0:4.9285 v1:4.3136 v2:3.9935 v3:3.9332 v4:3.8478 v5:4.3136 v6:3.9935 v7:3.9332 v8:3.8478 v9:6.7883] eff_attn_scale:[v0:0.0594 v1:0.0705 v2:0.0746 v3:0.0769 v4:0.0917 v5:0.0705 v6:0.0746 v7:0.0769 v8:0.0917 v9:0.2315] +step:4400/20000 val_loss:2.2387 val_bpb:1.3259 train_time:184885ms step_avg:42.02ms +step:4600/20000 train_loss:2.0442 train_time:193252ms step_avg:42.01ms +step:4600 shared0_alpha:mean=0.511,std=0.065 shared1_alpha:mean=0.559,std=0.057 shared2_alpha:mean=0.607,std=0.053 shared3_alpha:mean=0.643,std=0.054 eff_mlp_scale:[v0:5.0652 v1:4.4227 v2:4.0956 v3:4.0341 v4:3.9476 v5:4.4227 v6:4.0956 v7:4.0341 v8:3.9476 v9:6.9208] eff_attn_scale:[v0:0.0603 v1:0.0711 v2:0.0748 v3:0.0790 v4:0.0922 v5:0.0711 v6:0.0748 v7:0.0790 v8:0.0922 v9:0.2353] +step:4600/20000 val_loss:2.2352 val_bpb:1.3238 train_time:193267ms step_avg:42.01ms +step:4800/20000 train_loss:2.3370 train_time:201635ms step_avg:42.01ms +step:4800 shared0_alpha:mean=0.510,std=0.065 shared1_alpha:mean=0.560,std=0.057 shared2_alpha:mean=0.607,std=0.054 shared3_alpha:mean=0.644,std=0.055 eff_mlp_scale:[v0:5.1917 v1:4.5209 v2:4.1799 v3:4.1216 v4:4.0450 v5:4.5209 v6:4.1799 v7:4.1216 v8:4.0450 v9:7.0452] eff_attn_scale:[v0:0.0591 v1:0.0716 v2:0.0759 v3:0.0805 v4:0.0940 v5:0.0716 v6:0.0759 v7:0.0805 v8:0.0940 v9:0.2371] +step:4800/20000 val_loss:2.2308 
val_bpb:1.3212 train_time:201648ms step_avg:42.01ms +step:5000/20000 train_loss:2.1105 train_time:210019ms step_avg:42.00ms +step:5000 shared0_alpha:mean=0.509,std=0.066 shared1_alpha:mean=0.559,std=0.058 shared2_alpha:mean=0.608,std=0.055 shared3_alpha:mean=0.644,std=0.055 eff_mlp_scale:[v0:5.3103 v1:4.6188 v2:4.2735 v3:4.2090 v4:4.1337 v5:4.6188 v6:4.2735 v7:4.2090 v8:4.1337 v9:7.1662] eff_attn_scale:[v0:0.0602 v1:0.0723 v2:0.0762 v3:0.0813 v4:0.0945 v5:0.0723 v6:0.0762 v7:0.0813 v8:0.0945 v9:0.2400] +step:5000/20000 val_loss:2.2261 val_bpb:1.3184 train_time:210033ms step_avg:42.01ms +step:5200/20000 train_loss:2.2490 train_time:218399ms step_avg:42.00ms +step:5200 shared0_alpha:mean=0.509,std=0.067 shared1_alpha:mean=0.560,std=0.058 shared2_alpha:mean=0.609,std=0.054 shared3_alpha:mean=0.645,std=0.055 eff_mlp_scale:[v0:5.4376 v1:4.7135 v2:4.3587 v3:4.3000 v4:4.2222 v5:4.7135 v6:4.3587 v7:4.3000 v8:4.2222 v9:7.2883] eff_attn_scale:[v0:0.0586 v1:0.0733 v2:0.0782 v3:0.0837 v4:0.0968 v5:0.0733 v6:0.0782 v7:0.0837 v8:0.0968 v9:0.2422] +step:5200/20000 val_loss:2.2270 val_bpb:1.3189 train_time:218412ms step_avg:42.00ms +step:5400/20000 train_loss:2.2614 train_time:226776ms step_avg:42.00ms +step:5400 shared0_alpha:mean=0.508,std=0.067 shared1_alpha:mean=0.560,std=0.059 shared2_alpha:mean=0.610,std=0.055 shared3_alpha:mean=0.645,std=0.055 eff_mlp_scale:[v0:5.5550 v1:4.8093 v2:4.4453 v3:4.3808 v4:4.3111 v5:4.8093 v6:4.4453 v7:4.3808 v8:4.3111 v9:7.4064] eff_attn_scale:[v0:0.0608 v1:0.0738 v2:0.0777 v3:0.0843 v4:0.0976 v5:0.0738 v6:0.0777 v7:0.0843 v8:0.0976 v9:0.2451] +step:5400/20000 val_loss:2.2218 val_bpb:1.3159 train_time:226789ms step_avg:42.00ms +step:5600/20000 train_loss:2.2560 train_time:235150ms step_avg:41.99ms +step:5600 shared0_alpha:mean=0.508,std=0.067 shared1_alpha:mean=0.560,std=0.059 shared2_alpha:mean=0.611,std=0.055 shared3_alpha:mean=0.646,std=0.055 eff_mlp_scale:[v0:5.6646 v1:4.8966 v2:4.5245 v3:4.4625 v4:4.3961 v5:4.8966 v6:4.5245 v7:4.4625 
v8:4.3961 v9:7.5149] eff_attn_scale:[v0:0.0595 v1:0.0751 v2:0.0812 v3:0.0864 v4:0.0997 v5:0.0751 v6:0.0812 v7:0.0864 v8:0.0997 v9:0.2483] +step:5600/20000 val_loss:2.2216 val_bpb:1.3158 train_time:235164ms step_avg:41.99ms +step:5800/20000 train_loss:2.2172 train_time:243525ms step_avg:41.99ms +step:5800 shared0_alpha:mean=0.509,std=0.067 shared1_alpha:mean=0.561,std=0.060 shared2_alpha:mean=0.612,std=0.055 shared3_alpha:mean=0.646,std=0.056 eff_mlp_scale:[v0:5.7927 v1:4.9872 v2:4.6062 v3:4.5431 v4:4.4706 v5:4.9872 v6:4.6062 v7:4.5431 v8:4.4706 v9:7.6395] eff_attn_scale:[v0:0.0598 v1:0.0771 v2:0.0820 v3:0.0894 v4:0.1015 v5:0.0771 v6:0.0820 v7:0.0894 v8:0.1015 v9:0.2510] +step:5800/20000 val_loss:2.2206 val_bpb:1.3152 train_time:243537ms step_avg:41.99ms +step:6000/20000 train_loss:2.2889 train_time:251895ms step_avg:41.98ms +step:6000 shared0_alpha:mean=0.508,std=0.068 shared1_alpha:mean=0.561,std=0.060 shared2_alpha:mean=0.612,std=0.055 shared3_alpha:mean=0.646,std=0.056 eff_mlp_scale:[v0:5.9073 v1:5.0805 v2:4.6856 v3:4.6189 v4:4.5503 v5:5.0805 v6:4.6856 v7:4.6189 v8:4.5503 v9:7.7475] eff_attn_scale:[v0:0.0598 v1:0.0782 v2:0.0835 v3:0.0903 v4:0.1033 v5:0.0782 v6:0.0835 v7:0.0903 v8:0.1033 v9:0.2547] +step:6000/20000 val_loss:2.2160 val_bpb:1.3125 train_time:251907ms step_avg:41.98ms +step:6200/20000 train_loss:2.1628 train_time:260278ms step_avg:41.98ms +step:6200 shared0_alpha:mean=0.507,std=0.068 shared1_alpha:mean=0.560,std=0.061 shared2_alpha:mean=0.613,std=0.055 shared3_alpha:mean=0.647,std=0.056 eff_mlp_scale:[v0:6.0118 v1:5.1582 v2:4.7615 v3:4.6907 v4:4.6259 v5:5.1582 v6:4.7615 v7:4.6907 v8:4.6259 v9:7.8498] eff_attn_scale:[v0:0.0605 v1:0.0792 v2:0.0858 v3:0.0924 v4:0.1044 v5:0.0792 v6:0.0858 v7:0.0924 v8:0.1044 v9:0.2550] +step:6200/20000 val_loss:2.2158 val_bpb:1.3123 train_time:260292ms step_avg:41.98ms +step:6400/20000 train_loss:2.2367 train_time:268651ms step_avg:41.98ms +step:6400 shared0_alpha:mean=0.507,std=0.068 shared1_alpha:mean=0.561,std=0.060 
shared2_alpha:mean=0.614,std=0.056 shared3_alpha:mean=0.648,std=0.056 eff_mlp_scale:[v0:6.1236 v1:5.2414 v2:4.8321 v3:4.7699 v4:4.7049 v5:5.2414 v6:4.8321 v7:4.7699 v8:4.7049 v9:7.9627] eff_attn_scale:[v0:0.0589 v1:0.0810 v2:0.0870 v3:0.0937 v4:0.1087 v5:0.0810 v6:0.0870 v7:0.0937 v8:0.1087 v9:0.2610] +step:6400/20000 val_loss:2.2136 val_bpb:1.3110 train_time:268664ms step_avg:41.98ms +step:6600/20000 train_loss:2.2030 train_time:277028ms step_avg:41.97ms +step:6600 shared0_alpha:mean=0.505,std=0.069 shared1_alpha:mean=0.560,std=0.061 shared2_alpha:mean=0.614,std=0.056 shared3_alpha:mean=0.648,std=0.056 eff_mlp_scale:[v0:6.2243 v1:5.3100 v2:4.8974 v3:4.8315 v4:4.7805 v5:5.3100 v6:4.8974 v7:4.8315 v8:4.7805 v9:8.0691] eff_attn_scale:[v0:0.0601 v1:0.0823 v2:0.0884 v3:0.0961 v4:0.1096 v5:0.0823 v6:0.0884 v7:0.0961 v8:0.1096 v9:0.2635] +step:6600/20000 val_loss:2.2090 val_bpb:1.3083 train_time:277041ms step_avg:41.98ms +step:6800/20000 train_loss:2.2649 train_time:285402ms step_avg:41.97ms +step:6800 shared0_alpha:mean=0.504,std=0.069 shared1_alpha:mean=0.560,std=0.061 shared2_alpha:mean=0.615,std=0.056 shared3_alpha:mean=0.649,std=0.056 eff_mlp_scale:[v0:6.3201 v1:5.3770 v2:4.9601 v3:4.9003 v4:4.8488 v5:5.3770 v6:4.9601 v7:4.9003 v8:4.8488 v9:8.1726] eff_attn_scale:[v0:0.0594 v1:0.0842 v2:0.0888 v3:0.0975 v4:0.1111 v5:0.0842 v6:0.0888 v7:0.0975 v8:0.1111 v9:0.2659] +step:6800/20000 val_loss:2.2075 val_bpb:1.3074 train_time:285415ms step_avg:41.97ms +step:7000/20000 train_loss:2.2982 train_time:293777ms step_avg:41.97ms +step:7000 shared0_alpha:mean=0.503,std=0.068 shared1_alpha:mean=0.561,std=0.061 shared2_alpha:mean=0.616,std=0.057 shared3_alpha:mean=0.649,std=0.057 eff_mlp_scale:[v0:6.4226 v1:5.4445 v2:5.0240 v3:4.9649 v4:4.9226 v5:5.4445 v6:5.0240 v7:4.9649 v8:4.9226 v9:8.2793] eff_attn_scale:[v0:0.0595 v1:0.0871 v2:0.0940 v3:0.1026 v4:0.1146 v5:0.0871 v6:0.0940 v7:0.1026 v8:0.1146 v9:0.2688] +step:7000/20000 val_loss:2.2059 val_bpb:1.3065 train_time:293790ms 
step_avg:41.97ms +step:7200/20000 train_loss:2.2722 train_time:302148ms step_avg:41.97ms +step:7200 shared0_alpha:mean=0.501,std=0.069 shared1_alpha:mean=0.560,std=0.062 shared2_alpha:mean=0.616,std=0.057 shared3_alpha:mean=0.649,std=0.056 eff_mlp_scale:[v0:6.5156 v1:5.5060 v2:5.0792 v3:5.0270 v4:4.9866 v5:5.5060 v6:5.0792 v7:5.0270 v8:4.9866 v9:8.3818] eff_attn_scale:[v0:0.0589 v1:0.0897 v2:0.0962 v3:0.1035 v4:0.1172 v5:0.0897 v6:0.0962 v7:0.1035 v8:0.1172 v9:0.2723] +step:7200/20000 val_loss:2.2044 val_bpb:1.3056 train_time:302161ms step_avg:41.97ms +step:7400/20000 train_loss:2.1897 train_time:310519ms step_avg:41.96ms +step:7400 shared0_alpha:mean=0.499,std=0.069 shared1_alpha:mean=0.560,std=0.062 shared2_alpha:mean=0.617,std=0.057 shared3_alpha:mean=0.650,std=0.057 eff_mlp_scale:[v0:6.6133 v1:5.5623 v2:5.1312 v3:5.0851 v4:5.0530 v5:5.5623 v6:5.1312 v7:5.0851 v8:5.0530 v9:8.4850] eff_attn_scale:[v0:0.0595 v1:0.0926 v2:0.0975 v3:0.1056 v4:0.1187 v5:0.0926 v6:0.0975 v7:0.1056 v8:0.1187 v9:0.2768] +step:7400/20000 val_loss:2.2029 val_bpb:1.3047 train_time:310532ms step_avg:41.96ms +step:7600/20000 train_loss:2.0779 train_time:318885ms step_avg:41.96ms +step:7600 shared0_alpha:mean=0.496,std=0.070 shared1_alpha:mean=0.561,std=0.062 shared2_alpha:mean=0.618,std=0.057 shared3_alpha:mean=0.651,std=0.057 eff_mlp_scale:[v0:6.6993 v1:5.6149 v2:5.1843 v3:5.1408 v4:5.1199 v5:5.6149 v6:5.1843 v7:5.1408 v8:5.1199 v9:8.5855] eff_attn_scale:[v0:0.0586 v1:0.0960 v2:0.1001 v3:0.1092 v4:0.1217 v5:0.0960 v6:0.1001 v7:0.1092 v8:0.1217 v9:0.2814] +step:7600/20000 val_loss:2.2018 val_bpb:1.3041 train_time:318898ms step_avg:41.96ms +step:7800/20000 train_loss:2.2223 train_time:327261ms step_avg:41.96ms +step:7800 shared0_alpha:mean=0.495,std=0.070 shared1_alpha:mean=0.561,std=0.061 shared2_alpha:mean=0.618,std=0.057 shared3_alpha:mean=0.651,std=0.057 eff_mlp_scale:[v0:6.7876 v1:5.6698 v2:5.2348 v3:5.1983 v4:5.1846 v5:5.6698 v6:5.2348 v7:5.1983 v8:5.1846 v9:8.6880] 
eff_attn_scale:[v0:0.0595 v1:0.0978 v2:0.1039 v3:0.1108 v4:0.1240 v5:0.0978 v6:0.1039 v7:0.1108 v8:0.1240 v9:0.2847] +step:7800/20000 val_loss:2.1988 val_bpb:1.3023 train_time:327274ms step_avg:41.96ms +step:8000/20000 train_loss:2.1814 train_time:335631ms step_avg:41.95ms +step:8000 shared0_alpha:mean=0.493,std=0.070 shared1_alpha:mean=0.562,std=0.062 shared2_alpha:mean=0.620,std=0.057 shared3_alpha:mean=0.652,std=0.057 eff_mlp_scale:[v0:6.8739 v1:5.7163 v2:5.2860 v3:5.2538 v4:5.2461 v5:5.7163 v6:5.2860 v7:5.2538 v8:5.2461 v9:8.7894] eff_attn_scale:[v0:0.0591 v1:0.1017 v2:0.1061 v3:0.1138 v4:0.1249 v5:0.1017 v6:0.1061 v7:0.1138 v8:0.1249 v9:0.2864] +step:8000/20000 val_loss:2.1955 val_bpb:1.3003 train_time:335644ms step_avg:41.96ms +step:8200/20000 train_loss:2.2580 train_time:343997ms step_avg:41.95ms +step:8200 shared0_alpha:mean=0.491,std=0.070 shared1_alpha:mean=0.562,std=0.062 shared2_alpha:mean=0.620,std=0.057 shared3_alpha:mean=0.653,std=0.057 eff_mlp_scale:[v0:6.9539 v1:5.7618 v2:5.3337 v3:5.3094 v4:5.3077 v5:5.7618 v6:5.3337 v7:5.3094 v8:5.3077 v9:8.8871] eff_attn_scale:[v0:0.0592 v1:0.1040 v2:0.1073 v3:0.1143 v4:0.1272 v5:0.1040 v6:0.1073 v7:0.1143 v8:0.1272 v9:0.2890] +step:8200/20000 val_loss:2.1954 val_bpb:1.3003 train_time:344010ms step_avg:41.95ms +step:8400/20000 train_loss:2.2084 train_time:352439ms step_avg:41.96ms +step:8400 shared0_alpha:mean=0.489,std=0.070 shared1_alpha:mean=0.562,std=0.061 shared2_alpha:mean=0.621,std=0.058 shared3_alpha:mean=0.653,std=0.058 eff_mlp_scale:[v0:7.0353 v1:5.8055 v2:5.3788 v3:5.3625 v4:5.3698 v5:5.8055 v6:5.3788 v7:5.3625 v8:5.3698 v9:8.9822] eff_attn_scale:[v0:0.0596 v1:0.1062 v2:0.1096 v3:0.1166 v4:0.1296 v5:0.1062 v6:0.1096 v7:0.1166 v8:0.1296 v9:0.2923] +step:8400/20000 val_loss:2.1947 val_bpb:1.2998 train_time:352450ms step_avg:41.96ms +step:8600/20000 train_loss:2.2118 train_time:360811ms step_avg:41.95ms +step:8600 shared0_alpha:mean=0.487,std=0.070 shared1_alpha:mean=0.563,std=0.062 
shared2_alpha:mean=0.622,std=0.058 shared3_alpha:mean=0.654,std=0.057 eff_mlp_scale:[v0:7.1148 v1:5.8528 v2:5.4268 v3:5.4150 v4:5.4292 v5:5.8528 v6:5.4268 v7:5.4150 v8:5.4292 v9:9.0862] eff_attn_scale:[v0:0.0594 v1:0.1083 v2:0.1117 v3:0.1183 v4:0.1308 v5:0.1083 v6:0.1117 v7:0.1183 v8:0.1308 v9:0.2948] +step:8600/20000 val_loss:2.1920 val_bpb:1.2982 train_time:360824ms step_avg:41.96ms +step:8800/20000 train_loss:2.1823 train_time:369182ms step_avg:41.95ms +step:8800 shared0_alpha:mean=0.485,std=0.070 shared1_alpha:mean=0.563,std=0.062 shared2_alpha:mean=0.623,std=0.058 shared3_alpha:mean=0.654,std=0.058 eff_mlp_scale:[v0:7.1893 v1:5.8949 v2:5.4759 v3:5.4693 v4:5.4860 v5:5.8949 v6:5.4759 v7:5.4693 v8:5.4860 v9:9.1816] eff_attn_scale:[v0:0.0591 v1:0.1107 v2:0.1127 v3:0.1208 v4:0.1328 v5:0.1107 v6:0.1127 v7:0.1208 v8:0.1328 v9:0.2974] +step:8800/20000 val_loss:2.1910 val_bpb:1.2976 train_time:369195ms step_avg:41.95ms +step:9000/20000 train_loss:2.0963 train_time:377558ms step_avg:41.95ms +step:9000 shared0_alpha:mean=0.484,std=0.070 shared1_alpha:mean=0.564,std=0.062 shared2_alpha:mean=0.624,std=0.058 shared3_alpha:mean=0.655,std=0.058 eff_mlp_scale:[v0:7.2749 v1:5.9348 v2:5.5212 v3:5.5264 v4:5.5455 v5:5.9348 v6:5.5212 v7:5.5264 v8:5.5455 v9:9.2821] eff_attn_scale:[v0:0.0586 v1:0.1131 v2:0.1166 v3:0.1246 v4:0.1373 v5:0.1131 v6:0.1166 v7:0.1246 v8:0.1373 v9:0.3014] +step:9000/20000 val_loss:2.1921 val_bpb:1.2983 train_time:377572ms step_avg:41.95ms +step:9200/20000 train_loss:2.1541 train_time:385929ms step_avg:41.95ms +step:9200 shared0_alpha:mean=0.482,std=0.070 shared1_alpha:mean=0.564,std=0.062 shared2_alpha:mean=0.625,std=0.058 shared3_alpha:mean=0.655,std=0.058 eff_mlp_scale:[v0:7.3440 v1:5.9785 v2:5.5700 v3:5.5833 v4:5.6104 v5:5.9785 v6:5.5700 v7:5.5833 v8:5.6104 v9:9.3886] eff_attn_scale:[v0:0.0598 v1:0.1133 v2:0.1172 v3:0.1235 v4:0.1371 v5:0.1133 v6:0.1172 v7:0.1235 v8:0.1371 v9:0.3030] +step:9200/20000 val_loss:2.1896 val_bpb:1.2968 train_time:385943ms 
step_avg:41.95ms +step:9400/20000 train_loss:2.2160 train_time:394298ms step_avg:41.95ms +step:9400 shared0_alpha:mean=0.480,std=0.070 shared1_alpha:mean=0.564,std=0.062 shared2_alpha:mean=0.626,std=0.058 shared3_alpha:mean=0.656,std=0.058 eff_mlp_scale:[v0:7.4182 v1:6.0193 v2:5.6155 v3:5.6357 v4:5.6687 v5:6.0193 v6:5.6155 v7:5.6357 v8:5.6687 v9:9.4867] eff_attn_scale:[v0:0.0595 v1:0.1126 v2:0.1170 v3:0.1251 v4:0.1380 v5:0.1126 v6:0.1170 v7:0.1251 v8:0.1380 v9:0.3048] +step:9400/20000 val_loss:2.1879 val_bpb:1.2958 train_time:394311ms step_avg:41.95ms +step:9600/20000 train_loss:2.2207 train_time:402661ms step_avg:41.94ms +step:9600 shared0_alpha:mean=0.478,std=0.070 shared1_alpha:mean=0.564,std=0.062 shared2_alpha:mean=0.627,std=0.058 shared3_alpha:mean=0.656,std=0.058 eff_mlp_scale:[v0:7.4922 v1:6.0602 v2:5.6604 v3:5.6878 v4:5.7286 v5:6.0602 v6:5.6604 v7:5.6878 v8:5.7286 v9:9.5825] eff_attn_scale:[v0:0.0587 v1:0.1141 v2:0.1177 v3:0.1250 v4:0.1397 v5:0.1141 v6:0.1177 v7:0.1250 v8:0.1397 v9:0.3076] +step:9600/20000 val_loss:2.1867 val_bpb:1.2951 train_time:402674ms step_avg:41.95ms +step:9800/20000 train_loss:2.1468 train_time:411027ms step_avg:41.94ms +step:9800 shared0_alpha:mean=0.476,std=0.070 shared1_alpha:mean=0.564,std=0.062 shared2_alpha:mean=0.628,std=0.059 shared3_alpha:mean=0.657,std=0.059 eff_mlp_scale:[v0:7.5671 v1:6.0978 v2:5.7078 v3:5.7425 v4:5.7824 v5:6.0978 v6:5.7078 v7:5.7425 v8:5.7824 v9:9.6742] eff_attn_scale:[v0:0.0591 v1:0.1154 v2:0.1175 v3:0.1262 v4:0.1403 v5:0.1154 v6:0.1175 v7:0.1262 v8:0.1403 v9:0.3096] +step:9800/20000 val_loss:2.1868 val_bpb:1.2951 train_time:411041ms step_avg:41.94ms +step:10000/20000 train_loss:2.1912 train_time:419386ms step_avg:41.94ms +step:10000 shared0_alpha:mean=0.475,std=0.070 shared1_alpha:mean=0.564,std=0.062 shared2_alpha:mean=0.628,std=0.059 shared3_alpha:mean=0.657,std=0.058 eff_mlp_scale:[v0:7.6370 v1:6.1381 v2:5.7506 v3:5.7974 v4:5.8460 v5:6.1381 v6:5.7506 v7:5.7974 v8:5.8460 v9:9.7759] 
eff_attn_scale:[v0:0.0592 v1:0.1175 v2:0.1209 v3:0.1279 v4:0.1435 v5:0.1175 v6:0.1209 v7:0.1279 v8:0.1435 v9:0.3117] +step:10000/20000 val_loss:2.1866 val_bpb:1.2950 train_time:419409ms step_avg:41.94ms +step:10200/20000 train_loss:2.1453 train_time:427776ms step_avg:41.94ms +step:10200 shared0_alpha:mean=0.473,std=0.070 shared1_alpha:mean=0.564,std=0.062 shared2_alpha:mean=0.630,std=0.060 shared3_alpha:mean=0.657,std=0.058 eff_mlp_scale:[v0:7.7106 v1:6.1787 v2:5.7959 v3:5.8499 v4:5.9034 v5:6.1787 v6:5.7959 v7:5.8499 v8:5.9034 v9:9.8742] eff_attn_scale:[v0:0.0592 v1:0.1170 v2:0.1215 v3:0.1289 v4:0.1426 v5:0.1170 v6:0.1215 v7:0.1289 v8:0.1426 v9:0.3122] +step:10200/20000 val_loss:2.1834 val_bpb:1.2931 train_time:427790ms step_avg:41.94ms +step:10400/20000 train_loss:2.1746 train_time:436148ms step_avg:41.94ms +step:10400 shared0_alpha:mean=0.472,std=0.070 shared1_alpha:mean=0.565,std=0.062 shared2_alpha:mean=0.631,std=0.059 shared3_alpha:mean=0.658,std=0.058 eff_mlp_scale:[v0:7.7811 v1:6.2200 v2:5.8444 v3:5.9039 v4:5.9657 v5:6.2200 v6:5.8444 v7:5.9039 v8:5.9657 v9:9.9736] eff_attn_scale:[v0:0.0591 v1:0.1173 v2:0.1215 v3:0.1295 v4:0.1433 v5:0.1173 v6:0.1215 v7:0.1295 v8:0.1433 v9:0.3137] +step:10400/20000 val_loss:2.1825 val_bpb:1.2926 train_time:436162ms step_avg:41.94ms +step:10600/20000 train_loss:2.0497 train_time:444526ms step_avg:41.94ms +step:10600 shared0_alpha:mean=0.470,std=0.070 shared1_alpha:mean=0.566,std=0.062 shared2_alpha:mean=0.632,std=0.060 shared3_alpha:mean=0.658,std=0.058 eff_mlp_scale:[v0:7.8515 v1:6.2579 v2:5.8857 v3:5.9586 v4:6.0226 v5:6.2579 v6:5.8857 v7:5.9586 v8:6.0226 v9:10.0782] eff_attn_scale:[v0:0.0594 v1:0.1188 v2:0.1230 v3:0.1311 v4:0.1459 v5:0.1188 v6:0.1230 v7:0.1311 v8:0.1459 v9:0.3167] +step:10600/20000 val_loss:2.1832 val_bpb:1.2930 train_time:444540ms step_avg:41.94ms +step:10800/20000 train_loss:2.2623 train_time:452896ms step_avg:41.93ms +step:10800 shared0_alpha:mean=0.469,std=0.070 shared1_alpha:mean=0.566,std=0.062 
shared2_alpha:mean=0.633,std=0.060 shared3_alpha:mean=0.658,std=0.059 eff_mlp_scale:[v0:7.9226 v1:6.2948 v2:5.9265 v3:6.0132 v4:6.0886 v5:6.2948 v6:5.9265 v7:6.0132 v8:6.0886 v9:10.1759] eff_attn_scale:[v0:0.0595 v1:0.1184 v2:0.1228 v3:0.1307 v4:0.1465 v5:0.1184 v6:0.1228 v7:0.1307 v8:0.1465 v9:0.3170] +step:10800/20000 val_loss:2.1815 val_bpb:1.2920 train_time:452909ms step_avg:41.94ms +step:11000/20000 train_loss:2.1919 train_time:461264ms step_avg:41.93ms +step:11000 shared0_alpha:mean=0.468,std=0.070 shared1_alpha:mean=0.566,std=0.062 shared2_alpha:mean=0.634,std=0.060 shared3_alpha:mean=0.659,std=0.059 eff_mlp_scale:[v0:7.9871 v1:6.3340 v2:5.9798 v3:6.0676 v4:6.1386 v5:6.3340 v6:5.9798 v7:6.0676 v8:6.1386 v9:10.2796] eff_attn_scale:[v0:0.0597 v1:0.1194 v2:0.1246 v3:0.1318 v4:0.1471 v5:0.1194 v6:0.1246 v7:0.1318 v8:0.1471 v9:0.3184] +step:11000/20000 val_loss:2.1791 val_bpb:1.2906 train_time:461278ms step_avg:41.93ms +step:11200/20000 train_loss:2.1438 train_time:469638ms step_avg:41.93ms +step:11200 shared0_alpha:mean=0.467,std=0.070 shared1_alpha:mean=0.566,std=0.062 shared2_alpha:mean=0.634,std=0.060 shared3_alpha:mean=0.659,std=0.059 eff_mlp_scale:[v0:8.0598 v1:6.3739 v2:6.0242 v3:6.1239 v4:6.1992 v5:6.3739 v6:6.0242 v7:6.1239 v8:6.1992 v9:10.3695] eff_attn_scale:[v0:0.0604 v1:0.1186 v2:0.1254 v3:0.1334 v4:0.1488 v5:0.1186 v6:0.1254 v7:0.1334 v8:0.1488 v9:0.3200] +step:11200/20000 val_loss:2.1794 val_bpb:1.2908 train_time:469651ms step_avg:41.93ms +step:11400/20000 train_loss:2.1344 train_time:478013ms step_avg:41.93ms +step:11400 shared0_alpha:mean=0.465,std=0.069 shared1_alpha:mean=0.566,std=0.062 shared2_alpha:mean=0.635,std=0.061 shared3_alpha:mean=0.659,std=0.058 eff_mlp_scale:[v0:8.1316 v1:6.4157 v2:6.0697 v3:6.1732 v4:6.2626 v5:6.4157 v6:6.0697 v7:6.1732 v8:6.2626 v9:10.4635] eff_attn_scale:[v0:0.0595 v1:0.1188 v2:0.1261 v3:0.1337 v4:0.1497 v5:0.1188 v6:0.1261 v7:0.1337 v8:0.1497 v9:0.3199] +step:11400/20000 val_loss:2.1785 val_bpb:1.2902 
train_time:478026ms step_avg:41.93ms +step:11600/20000 train_loss:2.1357 train_time:486379ms step_avg:41.93ms +step:11600 shared0_alpha:mean=0.464,std=0.069 shared1_alpha:mean=0.567,std=0.062 shared2_alpha:mean=0.637,std=0.060 shared3_alpha:mean=0.660,std=0.059 eff_mlp_scale:[v0:8.2048 v1:6.4531 v2:6.1188 v3:6.2304 v4:6.3193 v5:6.4531 v6:6.1188 v7:6.2304 v8:6.3193 v9:10.5578] eff_attn_scale:[v0:0.0593 v1:0.1195 v2:0.1253 v3:0.1354 v4:0.1511 v5:0.1195 v6:0.1253 v7:0.1354 v8:0.1511 v9:0.3207] +step:11600/20000 val_loss:2.1774 val_bpb:1.2896 train_time:486392ms step_avg:41.93ms +step:11800/20000 train_loss:2.1693 train_time:494754ms step_avg:41.93ms +step:11800 shared0_alpha:mean=0.462,std=0.069 shared1_alpha:mean=0.567,std=0.062 shared2_alpha:mean=0.638,std=0.060 shared3_alpha:mean=0.660,std=0.059 eff_mlp_scale:[v0:8.2729 v1:6.4919 v2:6.1583 v3:6.2895 v4:6.3825 v5:6.4919 v6:6.1583 v7:6.2895 v8:6.3825 v9:10.6508] eff_attn_scale:[v0:0.0598 v1:0.1204 v2:0.1273 v3:0.1352 v4:0.1522 v5:0.1204 v6:0.1273 v7:0.1352 v8:0.1522 v9:0.3212] +step:11800/20000 val_loss:2.1761 val_bpb:1.2888 train_time:494767ms step_avg:41.93ms +step:12000/20000 train_loss:2.1488 train_time:503123ms step_avg:41.93ms +step:12000 shared0_alpha:mean=0.461,std=0.069 shared1_alpha:mean=0.567,std=0.062 shared2_alpha:mean=0.638,std=0.061 shared3_alpha:mean=0.660,std=0.059 eff_mlp_scale:[v0:8.3406 v1:6.5328 v2:6.2073 v3:6.3434 v4:6.4363 v5:6.5328 v6:6.2073 v7:6.3434 v8:6.4363 v9:10.7430] eff_attn_scale:[v0:0.0594 v1:0.1191 v2:0.1279 v3:0.1361 v4:0.1523 v5:0.1191 v6:0.1279 v7:0.1361 v8:0.1523 v9:0.3227] +step:12000/20000 val_loss:2.1741 val_bpb:1.2876 train_time:503136ms step_avg:41.93ms +step:12200/20000 train_loss:2.2909 train_time:511489ms step_avg:41.93ms +step:12200 shared0_alpha:mean=0.459,std=0.069 shared1_alpha:mean=0.567,std=0.062 shared2_alpha:mean=0.639,std=0.061 shared3_alpha:mean=0.660,std=0.059 eff_mlp_scale:[v0:8.4102 v1:6.5709 v2:6.2533 v3:6.4008 v4:6.4968 v5:6.5709 v6:6.2533 v7:6.4008 
v8:6.4968 v9:10.8363] eff_attn_scale:[v0:0.0599 v1:0.1210 v2:0.1273 v3:0.1365 v4:0.1550 v5:0.1210 v6:0.1273 v7:0.1365 v8:0.1550 v9:0.3253] +step:12200/20000 val_loss:2.1745 val_bpb:1.2878 train_time:511503ms step_avg:41.93ms +step:12400/20000 train_loss:1.9349 train_time:519918ms step_avg:41.93ms +step:12400 shared0_alpha:mean=0.458,std=0.069 shared1_alpha:mean=0.568,std=0.062 shared2_alpha:mean=0.641,std=0.061 shared3_alpha:mean=0.661,std=0.059 eff_mlp_scale:[v0:8.4814 v1:6.6057 v2:6.3046 v3:6.4565 v4:6.5567 v5:6.6057 v6:6.3046 v7:6.4565 v8:6.5567 v9:10.9320] eff_attn_scale:[v0:0.0593 v1:0.1213 v2:0.1286 v3:0.1375 v4:0.1550 v5:0.1213 v6:0.1286 v7:0.1375 v8:0.1550 v9:0.3254] +step:12400/20000 val_loss:2.1738 val_bpb:1.2874 train_time:519929ms step_avg:41.93ms +step:12600/20000 train_loss:2.1675 train_time:528283ms step_avg:41.93ms +step:12600 shared0_alpha:mean=0.457,std=0.069 shared1_alpha:mean=0.569,std=0.063 shared2_alpha:mean=0.641,std=0.062 shared3_alpha:mean=0.661,std=0.059 eff_mlp_scale:[v0:8.5543 v1:6.6524 v2:6.3536 v3:6.5142 v4:6.6156 v5:6.6524 v6:6.3536 v7:6.5142 v8:6.6156 v9:11.0299] eff_attn_scale:[v0:0.0608 v1:0.1205 v2:0.1294 v3:0.1389 v4:0.1575 v5:0.1205 v6:0.1294 v7:0.1389 v8:0.1575 v9:0.3262] +step:12600/20000 val_loss:2.1762 val_bpb:1.2888 train_time:528295ms step_avg:41.93ms +step:12800/20000 train_loss:2.1867 train_time:536653ms step_avg:41.93ms +step:12800 shared0_alpha:mean=0.456,std=0.069 shared1_alpha:mean=0.569,std=0.062 shared2_alpha:mean=0.642,std=0.062 shared3_alpha:mean=0.661,std=0.060 eff_mlp_scale:[v0:8.6217 v1:6.6955 v2:6.4021 v3:6.5754 v4:6.6836 v5:6.6955 v6:6.4021 v7:6.5754 v8:6.6836 v9:11.1198] eff_attn_scale:[v0:0.0599 v1:0.1222 v2:0.1304 v3:0.1384 v4:0.1585 v5:0.1222 v6:0.1304 v7:0.1384 v8:0.1585 v9:0.3281] +step:12800/20000 val_loss:2.1743 val_bpb:1.2878 train_time:536666ms step_avg:41.93ms +step:13000/20000 train_loss:2.2697 train_time:545032ms step_avg:41.93ms +step:13000 shared0_alpha:mean=0.455,std=0.069 
shared1_alpha:mean=0.568,std=0.063 shared2_alpha:mean=0.643,std=0.062 shared3_alpha:mean=0.661,std=0.060 eff_mlp_scale:[v0:8.6899 v1:6.7319 v2:6.4475 v3:6.6330 v4:6.7414 v5:6.7319 v6:6.4475 v7:6.6330 v8:6.7414 v9:11.2236] eff_attn_scale:[v0:0.0597 v1:0.1206 v2:0.1297 v3:0.1385 v4:0.1575 v5:0.1206 v6:0.1297 v7:0.1385 v8:0.1575 v9:0.3280] +step:13000/20000 val_loss:2.1754 val_bpb:1.2884 train_time:545045ms step_avg:41.93ms +step:13200/20000 train_loss:2.2744 train_time:553400ms step_avg:41.92ms +step:13200 shared0_alpha:mean=0.453,std=0.069 shared1_alpha:mean=0.569,std=0.062 shared2_alpha:mean=0.643,std=0.062 shared3_alpha:mean=0.661,std=0.059 eff_mlp_scale:[v0:8.7544 v1:6.7727 v2:6.4919 v3:6.6930 v4:6.8007 v5:6.7727 v6:6.4919 v7:6.6930 v8:6.8007 v9:11.3116] eff_attn_scale:[v0:0.0599 v1:0.1220 v2:0.1310 v3:0.1393 v4:0.1597 v5:0.1220 v6:0.1310 v7:0.1393 v8:0.1597 v9:0.3284] +step:13200/20000 val_loss:2.1676 val_bpb:1.2838 train_time:553411ms step_avg:41.93ms +step:13400/20000 train_loss:2.1446 train_time:561768ms step_avg:41.92ms +step:13400 shared0_alpha:mean=0.452,std=0.069 shared1_alpha:mean=0.569,std=0.062 shared2_alpha:mean=0.644,std=0.062 shared3_alpha:mean=0.661,std=0.060 eff_mlp_scale:[v0:8.8156 v1:6.8022 v2:6.5327 v3:6.7391 v4:6.8518 v5:6.8022 v6:6.5327 v7:6.7391 v8:6.8518 v9:11.4112] eff_attn_scale:[v0:0.0592 v1:0.1231 v2:0.1325 v3:0.1429 v4:0.1640 v5:0.1231 v6:0.1325 v7:0.1429 v8:0.1640 v9:0.3343] +step:13400/20000 val_loss:2.1599 val_bpb:1.2792 train_time:561782ms step_avg:41.92ms +step:13600/20000 train_loss:2.0140 train_time:570143ms step_avg:41.92ms +step:13600 shared0_alpha:mean=0.451,std=0.068 shared1_alpha:mean=0.569,std=0.062 shared2_alpha:mean=0.644,std=0.062 shared3_alpha:mean=0.660,std=0.060 eff_mlp_scale:[v0:8.8602 v1:6.8277 v2:6.5641 v3:6.7764 v4:6.8902 v5:6.8277 v6:6.5641 v7:6.7764 v8:6.8902 v9:11.4985] eff_attn_scale:[v0:0.0593 v1:0.1252 v2:0.1354 v3:0.1454 v4:0.1681 v5:0.1252 v6:0.1354 v7:0.1454 v8:0.1681 v9:0.3395] +step:13600/20000 
val_loss:2.1537 val_bpb:1.2756 train_time:570156ms step_avg:41.92ms +step:13800/20000 train_loss:2.0812 train_time:578515ms step_avg:41.92ms +step:13800 shared0_alpha:mean=0.450,std=0.069 shared1_alpha:mean=0.568,std=0.062 shared2_alpha:mean=0.643,std=0.062 shared3_alpha:mean=0.660,std=0.060 eff_mlp_scale:[v0:8.8918 v1:6.8480 v2:6.5838 v3:6.8064 v4:6.9193 v5:6.8480 v6:6.5838 v7:6.8064 v8:6.9193 v9:11.5699] eff_attn_scale:[v0:0.0587 v1:0.1258 v2:0.1376 v3:0.1478 v4:0.1695 v5:0.1258 v6:0.1376 v7:0.1478 v8:0.1695 v9:0.3444] +step:13800/20000 val_loss:2.1424 val_bpb:1.2689 train_time:578528ms step_avg:41.92ms +step:14000/20000 train_loss:2.1365 train_time:586886ms step_avg:41.92ms +step:14000 shared0_alpha:mean=0.450,std=0.069 shared1_alpha:mean=0.567,std=0.062 shared2_alpha:mean=0.642,std=0.062 shared3_alpha:mean=0.659,std=0.059 eff_mlp_scale:[v0:8.9111 v1:6.8614 v2:6.5992 v3:6.8242 v4:6.9400 v5:6.8614 v6:6.5992 v7:6.8242 v8:6.9400 v9:11.6202] eff_attn_scale:[v0:0.0588 v1:0.1266 v2:0.1384 v3:0.1491 v4:0.1716 v5:0.1266 v6:0.1384 v7:0.1491 v8:0.1716 v9:0.3485] +step:14000/20000 val_loss:2.1348 val_bpb:1.2644 train_time:586900ms step_avg:41.92ms +step:14200/20000 train_loss:2.2163 train_time:595261ms step_avg:41.92ms +step:14200 shared0_alpha:mean=0.450,std=0.068 shared1_alpha:mean=0.567,std=0.062 shared2_alpha:mean=0.642,std=0.062 shared3_alpha:mean=0.659,std=0.059 eff_mlp_scale:[v0:8.9182 v1:6.8684 v2:6.6030 v3:6.8352 v4:6.9498 v5:6.8684 v6:6.6030 v7:6.8352 v8:6.9498 v9:11.6491] eff_attn_scale:[v0:0.0579 v1:0.1264 v2:0.1381 v3:0.1494 v4:0.1718 v5:0.1264 v6:0.1381 v7:0.1494 v8:0.1718 v9:0.3497] +step:14200/20000 val_loss:2.1271 val_bpb:1.2598 train_time:595274ms step_avg:41.92ms +step:14313/20000 val_loss:2.1238 val_bpb:1.2578 train_time:600022ms step_avg:41.92ms +stopping_early: wallclock_cap train_time:600022ms step:14313/20000 +peak memory allocated: 9923 MiB reserved: 10012 MiB +Serialized model: 45149624 bytes +Code size: 57024 bytes +Total submission size: 
45206648 bytes +Serialized model int8+zlib: 10705962 bytes (payload:11610304 raw_torch:11640701 payload_ratio:3.89x) +Total submission size int8+zlib: 10762986 bytes +final_int8_zlib_roundtrip val_loss:2.1450 val_bpb:1.2704 eval_time:1341ms +final_int8_zlib_roundtrip_exact val_loss:2.14499442 val_bpb:1.27038652 diff --git a/records/track_non_record_16mb/2026-03-26_RecurrenceFix_3Loop_Birkhoff_OutputLN_TimestepScale/logs/s2_I.txt b/records/track_non_record_16mb/2026-03-26_RecurrenceFix_3Loop_Birkhoff_OutputLN_TimestepScale/logs/s2_I.txt new file mode 100644 index 0000000000..f2c4a690cf --- /dev/null +++ b/records/track_non_record_16mb/2026-03-26_RecurrenceFix_3Loop_Birkhoff_OutputLN_TimestepScale/logs/s2_I.txt @@ -0,0 +1,1643 @@ +""" +The `train_gpt.py` and `train_gpt_mlx.py` scripts are intended as good launching-off points for new participants, not SOTA configs. We'll accept PRs that tune, improve, or simplify these scripts without significantly increasing complexity, but competitive submissions should stay in the `/records` folder. + +Hard stop: To keep readable for newcomers, let's make sure `train_gpt.py` and `train_gpt_mlx.py` never are longer than 1500 lines. 
+""" + +from __future__ import annotations + +import copy +import glob +import io +import math +import os +import random +import subprocess +import sys +import time +import uuid +import zlib +from pathlib import Path + +import numpy as np +import sentencepiece as spm +import torch +import torch.distributed as dist +import torch.nn.functional as F +from torch import Tensor, nn +from torch.nn.parallel import DistributedDataParallel as DDP + +# Flash Attention 3 via HuggingFace kernels (Hopper-optimized) +try: + from kernels import get_kernel + cap = torch.cuda.get_device_capability() + repo = "varunneal/flash-attention-3" if cap == (9, 0) else "kernels-community/flash-attn3" + fa3 = get_kernel(repo).flash_attn_interface + FA3_AVAILABLE = True +except Exception: + FA3_AVAILABLE = False + +# ----------------------------- +# HYPERPARAMETERS +# ----------------------------- +# Default Simple Baseline run: +# - 9 transformer blocks at width 512 +# - 8 attention heads with 4 KV heads (GQA) and 2x MLP expansion +# - vocab size 1024, sequence length 1024, tied embeddings +# - 524,288 train tokens per step for 20,000 iterations with a ~10 minute cap + +class Hyperparameters: + # Data paths are shard globs produced by the existing preprocessing pipeline. + data_path = os.environ.get("DATA_PATH", "./data/datasets/fineweb10B_sp1024") + train_files = os.path.join(data_path, "fineweb_train_*.bin") + val_files = os.path.join(data_path, "fineweb_val_*.bin") + tokenizer_path = os.environ.get("TOKENIZER_PATH", "./data/tokenizers/fineweb_1024_bpe.model") + run_id = os.environ.get("RUN_ID", str(uuid.uuid4())) + seed = int(os.environ.get("SEED", 1337)) + + # Validation cadence and batch size. Validation always uses the full fineweb_val split. + val_batch_size = int(os.environ.get("VAL_BATCH_SIZE", 524_288)) + val_loss_every = int(os.environ.get("VAL_LOSS_EVERY", 1000)) + train_log_every = int(os.environ.get("TRAIN_LOG_EVERY", 200)) + + # Training length. 
+ iterations = int(os.environ.get("ITERATIONS", 20000)) + warmdown_iters = int(os.environ.get("WARMDOWN_ITERS", 1200)) + warmup_steps = int(os.environ.get("WARMUP_STEPS", 20)) + train_batch_tokens = int(os.environ.get("TRAIN_BATCH_TOKENS", 524_288)) + train_seq_len = int(os.environ.get("TRAIN_SEQ_LEN", 1024)) + max_wallclock_seconds = float(os.environ.get("MAX_WALLCLOCK_SECONDS", 600.0)) + qk_gain_init = float(os.environ.get("QK_GAIN_INIT", 1.5)) + + # Model shape. + vocab_size = int(os.environ.get("VOCAB_SIZE", 1024)) + num_layers = int(os.environ.get("NUM_LAYERS", 9)) + num_kv_heads = int(os.environ.get("NUM_KV_HEADS", 4)) + model_dim = int(os.environ.get("MODEL_DIM", 512)) + num_heads = int(os.environ.get("NUM_HEADS", 8)) + mlp_mult = int(os.environ.get("MLP_MULT", 2)) + tie_embeddings = bool(int(os.environ.get("TIE_EMBEDDINGS", "1"))) + rope_base = float(os.environ.get("ROPE_BASE", 10000.0)) + logit_softcap = float(os.environ.get("LOGIT_SOFTCAP", 30.0)) + + # Recurrence + stability features. + num_prelude = int(os.environ.get("NUM_PRELUDE", 0)) + num_coda = int(os.environ.get("NUM_CODA", 0)) + num_shared = int(os.environ.get("NUM_SHARED", 0)) + num_loops = int(os.environ.get("NUM_LOOPS", 2)) + use_peri_norm = bool(int(os.environ.get("USE_PERI_NORM", "0"))) + use_birkhoff_mix = bool(int(os.environ.get("USE_BIRKHOFF_MIX", "0"))) + use_timestep_scale = bool(int(os.environ.get("USE_TIMESTEP_SCALE", "0"))) + leaky_relu_slope = float(os.environ.get("LEAKY_RELU_SLOPE", "0.5")) + timestep_gamma_max = float(os.environ.get("TIMESTEP_GAMMA_MAX", "0.0")) + disable_compile = bool(int(os.environ.get("DISABLE_COMPILE", "0"))) + + # Optimizer hyperparameters. 
+ embed_lr = float(os.environ.get("EMBED_LR", 0.6)) + head_lr = float(os.environ.get("HEAD_LR", 0.008)) + tied_embed_lr = float(os.environ.get("TIED_EMBED_LR", 0.05)) + tied_embed_init_std = float(os.environ.get("TIED_EMBED_INIT_STD", 0.005)) + matrix_lr = float(os.environ.get("MATRIX_LR", 0.04)) + scalar_lr = float(os.environ.get("SCALAR_LR", 0.04)) + muon_momentum = float(os.environ.get("MUON_MOMENTUM", 0.95)) + muon_backend_steps = int(os.environ.get("MUON_BACKEND_STEPS", 5)) + muon_momentum_warmup_start = float(os.environ.get("MUON_MOMENTUM_WARMUP_START", 0.85)) + muon_momentum_warmup_steps = int(os.environ.get("MUON_MOMENTUM_WARMUP_STEPS", 500)) + beta1 = float(os.environ.get("BETA1", 0.9)) + beta2 = float(os.environ.get("BETA2", 0.95)) + adam_eps = float(os.environ.get("ADAM_EPS", 1e-8)) + grad_clip_norm = float(os.environ.get("GRAD_CLIP_NORM", 0.0)) + +# ----------------------------- +# MUON OPTIMIZER +# ----------------------------- +# +# As borrowed from modded-nanogpt +# Background on Muon: https://kellerjordan.github.io/posts/muon/ + +def zeropower_via_newtonschulz5(G: Tensor, steps: int = 10, eps: float = 1e-7) -> Tensor: + # Orthogonalize a 2D update matrix with a fast Newton-Schulz iteration. + # Muon uses this to normalize matrix-shaped gradients before applying them. 
+ a, b, c = (3.4445, -4.7750, 2.0315) + X = G.bfloat16() + X /= X.norm() + eps + transposed = G.size(0) > G.size(1) + if transposed: + X = X.T + for _ in range(steps): + A = X @ X.T + B = b * A + c * A @ A + X = a * X + B @ X + return X.T if transposed else X + + +class Muon(torch.optim.Optimizer): + def __init__(self, params, lr: float, momentum: float, backend_steps: int, nesterov: bool = True): + super().__init__( + params, + dict(lr=lr, momentum=momentum, backend_steps=backend_steps, nesterov=nesterov), + ) + + @torch.no_grad() + def step(self, closure=None): + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + distributed = dist.is_available() and dist.is_initialized() + world_size = dist.get_world_size() if distributed else 1 + rank = dist.get_rank() if distributed else 0 + + for group in self.param_groups: + params = group["params"] + if not params: + continue + lr = group["lr"] + momentum = group["momentum"] + backend_steps = group["backend_steps"] + nesterov = group["nesterov"] + + total_params = sum(int(p.numel()) for p in params) + updates_flat = torch.zeros(total_params, device=params[0].device, dtype=torch.bfloat16) + + curr = 0 + for i, p in enumerate(params): + if i % world_size == rank and p.grad is not None: + g = p.grad + state = self.state[p] + if "momentum_buffer" not in state: + state["momentum_buffer"] = torch.zeros_like(g) + buf = state["momentum_buffer"] + buf.mul_(momentum).add_(g) + if nesterov: + g = g.add(buf, alpha=momentum) + g = zeropower_via_newtonschulz5(g, steps=backend_steps) + # Scale correction from Muon reference implementations. 
+ g *= max(1, g.size(0) / g.size(1)) ** 0.5 + updates_flat[curr : curr + p.numel()] = g.reshape(-1) + curr += p.numel() + + if distributed: + dist.all_reduce(updates_flat, op=dist.ReduceOp.SUM) + + curr = 0 + for p in params: + g = updates_flat[curr : curr + p.numel()].view_as(p).to(dtype=p.dtype) + p.add_(g, alpha=-lr) + curr += p.numel() + + return loss + + +# ----------------------------- +# TOKENIZER-AGNOSTIC EVALUATION SETUP +# ----------------------------- +# +# It's common for small models have a large fraction of their parameters be embeddings, since the 2 * d_model * d_vocab vectors can be gigantic. +# Instead of locking the tokenizer, we let you bring your own and calculate our validation metrics on the average compression of the validation set. +# We calculate BPB (bits-per-byte) instead of validation loss, so we need methods to count the number of bits per token in the tokenizer. +# Note: Submissions that edit the tokenizer will be examined more carefully, since screwing this up might unjustly improve your score. 

def build_sentencepiece_luts(
    sp: spm.SentencePieceProcessor, vocab_size: int, device: torch.device
) -> tuple[Tensor, Tensor, Tensor]:
    """Build per-token lookup tables for the bits-per-byte (bpb) metric.

    Returns three tensors indexed by token id, sized to cover both the model
    vocab and the tokenizer vocab:
      - base_bytes (int16): UTF-8 byte length of each piece (1 for byte tokens,
        piece length with the leading "▁" stripped for normal pieces)
      - has_leading_space (bool): True for pieces that start with the
        SentencePiece word-boundary marker "▁"
      - is_boundary_token (bool): True for control/unknown/unused slots (and
        any padding ids past the tokenizer vocab), used so a leading space is
        not double-counted after a boundary token
    """
    sp_vocab_size = int(sp.vocab_size())
    # Table covers max(tokenizer vocab, model vocab) so lookups never go OOB
    # even if the model's VOCAB_SIZE pads past the tokenizer's.
    table_size = max(sp_vocab_size, vocab_size)
    base_bytes_np = np.zeros((table_size,), dtype=np.int16)
    has_leading_space_np = np.zeros((table_size,), dtype=np.bool_)
    # Initialized to True: ids outside [0, sp_vocab_size) stay "boundary".
    is_boundary_token_np = np.ones((table_size,), dtype=np.bool_)
    for token_id in range(sp_vocab_size):
        if sp.is_control(token_id) or sp.is_unknown(token_id) or sp.is_unused(token_id):
            continue
        is_boundary_token_np[token_id] = False
        if sp.is_byte(token_id):
            # Byte-fallback tokens each represent exactly one raw byte.
            base_bytes_np[token_id] = 1
            continue
        piece = sp.id_to_piece(token_id)
        if piece.startswith("▁"):
            # "▁" marks a preceding space; count that byte conditionally at
            # eval time (see eval_val) instead of baking it into base_bytes.
            has_leading_space_np[token_id] = True
            piece = piece[1:]
        base_bytes_np[token_id] = len(piece.encode("utf-8"))
    return (
        torch.tensor(base_bytes_np, dtype=torch.int16, device=device),
        torch.tensor(has_leading_space_np, dtype=torch.bool, device=device),
        torch.tensor(is_boundary_token_np, dtype=torch.bool, device=device),
    )


def load_validation_tokens(pattern: str, seq_len: int) -> Tensor:
    """Load and concatenate all validation shards matching *pattern*.

    Returns a 1-D token tensor trimmed to a multiple of *seq_len* plus one
    extra token, so (x, y) pairs can be built by shifting.

    Raises:
        FileNotFoundError: if the glob matches no files.
        ValueError: if the split is shorter than one sequence.
    """
    files = [Path(p) for p in sorted(glob.glob(pattern))]
    if not files:
        raise FileNotFoundError(f"No files found for pattern: {pattern}")
    # The export pipeline writes the fixed first-50k-doc validation set to fineweb_val_*.
    tokens = torch.cat([load_data_shard(file) for file in files]).contiguous()
    usable = ((tokens.numel() - 1) // seq_len) * seq_len
    if usable <= 0:
        raise ValueError(f"Validation split is too short for TRAIN_SEQ_LEN={seq_len}")
    # "+1" keeps the shift target for the final sequence.
    return tokens[: usable + 1]


def eval_val(
    args: Hyperparameters,
    model: nn.Module,
    rank: int,
    world_size: int,
    device: torch.device,
    grad_accum_steps: int,
    val_tokens: Tensor,
    base_bytes_lut: Tensor,
    has_leading_space_lut: Tensor,
    is_boundary_token_lut: Tensor,
) -> tuple[float, float]:
    """Evaluate the model over the full validation split.

    Each rank consumes a disjoint contiguous slice of sequences; per-rank
    sums are combined with all_reduce when torch.distributed is initialized.
    Assumes ``model(x, y)`` returns the mean token cross-entropy (as GPT's
    forward does) — the mean is re-weighted by token count before summing.

    Returns:
        (val_loss, val_bpb): mean token cross-entropy in nats, and
        bits-per-byte (tokenizer-agnostic compression metric).
    """
    # Validation computes two metrics:
    # - val_loss: token cross-entropy (natural log)
    # - val_bpb: tokenizer-agnostic compression metric used by the challenge
    local_batch_tokens = args.val_batch_size // (world_size * grad_accum_steps)
    if local_batch_tokens < args.train_seq_len:
        raise ValueError(
            "VAL_BATCH_SIZE must provide at least one sequence per rank; "
            f"got VAL_BATCH_SIZE={args.val_batch_size}, WORLD_SIZE={world_size}, "
            f"GRAD_ACCUM_STEPS={grad_accum_steps}, TRAIN_SEQ_LEN={args.train_seq_len}"
        )
    local_batch_seqs = local_batch_tokens // args.train_seq_len
    total_seqs = (val_tokens.numel() - 1) // args.train_seq_len
    # Proportional split: rank r owns sequences [seq_start, seq_end); the
    # integer arithmetic partitions total_seqs exactly across ranks.
    seq_start = (total_seqs * rank) // world_size
    seq_end = (total_seqs * (rank + 1)) // world_size
    # fp64 accumulators so summing many bf16-derived losses stays accurate.
    val_loss_sum = torch.zeros((), device=device, dtype=torch.float64)
    val_token_count = torch.zeros((), device=device, dtype=torch.float64)
    val_byte_count = torch.zeros((), device=device, dtype=torch.float64)

    model.eval()
    with torch.inference_mode():
        for batch_seq_start in range(seq_start, seq_end, local_batch_seqs):
            batch_seq_end = min(batch_seq_start + local_batch_seqs, seq_end)
            raw_start = batch_seq_start * args.train_seq_len
            # "+1" grabs the shift target for the last sequence in the batch.
            raw_end = batch_seq_end * args.train_seq_len + 1
            local = val_tokens[raw_start:raw_end].to(device=device, dtype=torch.int64, non_blocking=True)
            x = local[:-1].reshape(-1, args.train_seq_len)
            y = local[1:].reshape(-1, args.train_seq_len)
            with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True):
                batch_loss = model(x, y).detach()
            batch_token_count = float(y.numel())
            # Undo the mean reduction: weight each batch by its token count.
            val_loss_sum += batch_loss.to(torch.float64) * batch_token_count
            val_token_count += batch_token_count
            prev_ids = x.reshape(-1)
            tgt_ids = y.reshape(-1)
            token_bytes = base_bytes_lut[tgt_ids].to(dtype=torch.int16)
            # A leading-space piece costs one extra byte unless the previous
            # token was a boundary (control/unknown/unused) token.
            token_bytes += (has_leading_space_lut[tgt_ids] & ~is_boundary_token_lut[prev_ids]).to(dtype=torch.int16)
            val_byte_count += token_bytes.to(torch.float64).sum()

    # Combine per-rank partial sums; skipped in single-process runs.
    if dist.is_available() and dist.is_initialized():
        dist.all_reduce(val_loss_sum, op=dist.ReduceOp.SUM)
        dist.all_reduce(val_token_count, op=dist.ReduceOp.SUM)
        dist.all_reduce(val_byte_count, op=dist.ReduceOp.SUM)

    val_loss = val_loss_sum / val_token_count
    # bpb = (bits per token) * (tokens per byte) = total bits / total bytes.
    bits_per_token = val_loss.item() / math.log(2.0)
    tokens_per_byte = val_token_count.item() / val_byte_count.item()
    model.train()
    return float(val_loss.item()), float(bits_per_token * tokens_per_byte)

# -----------------------------
# POST-TRAINING QUANTIZATION
# -----------------------------
#
# It's silly to export our model, which is trained in bf16 and fp32, at that same precision.
# Instead, we get approximately the same model (with a small hit) by quantizing the model to int8 & zlib compressing.
# We can then decompress the model and run in higher precision for evaluation, after closing in under the size limit.
+ +CONTROL_TENSOR_NAME_PATTERNS = tuple( + pattern + for pattern in os.environ.get( + "CONTROL_TENSOR_NAME_PATTERNS", + "attn_scale,attn_scales,mlp_scale,mlp_scales,resid_mix,resid_mixes,q_gain,skip_weight,skip_weights,ts_attn,ts_mlp,resid_mix_logit", + ).split(",") + if pattern +) +INT8_KEEP_FLOAT_FP32_NAME_PATTERNS = tuple( + pattern + for pattern in os.environ.get( + "INT8_KEEP_FLOAT_FP32_NAME_PATTERNS", + ",".join(CONTROL_TENSOR_NAME_PATTERNS), + ).split(",") + if pattern +) +INT8_KEEP_FLOAT_MAX_NUMEL = 65_536 +INT8_KEEP_FLOAT_STORE_DTYPE = torch.float16 +INT8_PER_ROW_SCALE_DTYPE = torch.float16 +INT8_CLIP_PERCENTILE = 99.99984 +INT8_CLIP_Q = INT8_CLIP_PERCENTILE / 100.0 + +def tensor_nbytes(t: Tensor) -> int: + return int(t.numel()) * int(t.element_size()) + +def keep_float_tensor(name: str, t: Tensor, passthrough_orig_dtypes: dict[str, str]) -> Tensor: + if any(pattern in name for pattern in INT8_KEEP_FLOAT_FP32_NAME_PATTERNS): + return t.float().contiguous() + if t.dtype in {torch.float32, torch.bfloat16}: + passthrough_orig_dtypes[name] = str(t.dtype).removeprefix("torch.") + return t.to(dtype=INT8_KEEP_FLOAT_STORE_DTYPE).contiguous() + return t + +def quantize_float_tensor(t: Tensor) -> tuple[Tensor, Tensor]: + t32 = t.float() + if t32.ndim == 2: + # Matrices get one scale per row, which usually tracks output-channel + # ranges much better than a single tensor-wide scale. + clip_abs = ( + torch.quantile(t32.abs(), INT8_CLIP_Q, dim=1) + if t32.numel() + else torch.empty((t32.shape[0],), dtype=torch.float32) + ) + clipped = torch.maximum(torch.minimum(t32, clip_abs[:, None]), -clip_abs[:, None]) + scale = (clip_abs / 127.0).clamp_min(1.0 / 127.0) + q = torch.clamp(torch.round(clipped / scale[:, None]), -127, 127).to(torch.int8).contiguous() + return q, scale.to(dtype=INT8_PER_ROW_SCALE_DTYPE).contiguous() + + # Vectors / scalars use a simpler per-tensor scale. 
+ clip_abs = float(torch.quantile(t32.abs().flatten(), INT8_CLIP_Q).item()) if t32.numel() else 0.0 + scale = torch.tensor(clip_abs / 127.0 if clip_abs > 0 else 1.0, dtype=torch.float32) + q = torch.clamp(torch.round(torch.clamp(t32, -clip_abs, clip_abs) / scale), -127, 127).to(torch.int8).contiguous() + return q, scale + +def quantize_state_dict_int8(state_dict: dict[str, Tensor]): + # Single supported clean-script export format: + # - per-row int8 for 2D float tensors + # - per-tensor int8 for other float tensors + # - exact passthrough for non-floats + # - passthrough for small float tensors, stored as fp16 to save bytes + quantized: dict[str, Tensor] = {} + scales: dict[str, Tensor] = {} + dtypes: dict[str, str] = {} + passthrough: dict[str, Tensor] = {} + passthrough_orig_dtypes: dict[str, str] = {} + qmeta: dict[str, dict[str, object]] = {} + stats = dict.fromkeys( + ("param_count", "num_tensors", "num_float_tensors", "num_nonfloat_tensors", "baseline_tensor_bytes", "int8_payload_bytes"), + 0, + ) + + for name, tensor in state_dict.items(): + t = tensor.detach().to("cpu").contiguous() + stats["param_count"] += int(t.numel()) + stats["num_tensors"] += 1 + stats["baseline_tensor_bytes"] += tensor_nbytes(t) + + if not t.is_floating_point(): + stats["num_nonfloat_tensors"] += 1 + passthrough[name] = t + stats["int8_payload_bytes"] += tensor_nbytes(t) + continue + + # Small float tensors are cheap enough to keep directly. We still downcast + # fp32/bf16 passthrough tensors to fp16 so metadata does not dominate size. 
+ if t.numel() <= INT8_KEEP_FLOAT_MAX_NUMEL: + kept = keep_float_tensor(name, t, passthrough_orig_dtypes) + passthrough[name] = kept + stats["int8_payload_bytes"] += tensor_nbytes(kept) + continue + + stats["num_float_tensors"] += 1 + q, s = quantize_float_tensor(t) + if s.ndim > 0: + qmeta[name] = {"scheme": "per_row", "axis": 0} + quantized[name] = q + scales[name] = s + dtypes[name] = str(t.dtype).removeprefix("torch.") + stats["int8_payload_bytes"] += tensor_nbytes(q) + tensor_nbytes(s) + + obj: dict[str, object] = { + "__quant_format__": "int8_clean_per_row_v1", + "quantized": quantized, + "scales": scales, + "dtypes": dtypes, + "passthrough": passthrough, + } + if qmeta: + obj["qmeta"] = qmeta + if passthrough_orig_dtypes: + obj["passthrough_orig_dtypes"] = passthrough_orig_dtypes + return obj, stats + +def dequantize_state_dict_int8(obj: dict[str, object]) -> dict[str, Tensor]: + out: dict[str, Tensor] = {} + qmeta = obj.get("qmeta", {}) + passthrough_orig_dtypes = obj.get("passthrough_orig_dtypes", {}) + for name, q in obj["quantized"].items(): + dtype = getattr(torch, obj["dtypes"][name]) + s = obj["scales"][name] + if qmeta.get(name, {}).get("scheme") == "per_row" or s.ndim > 0: + s = s.to(dtype=torch.float32) + # Broadcast the saved row scale back across trailing dimensions. + out[name] = (q.float() * s.view(q.shape[0], *([1] * (q.ndim - 1)))).to(dtype=dtype).contiguous() + else: + scale = float(s.item()) + out[name] = (q.float() * scale).to(dtype=dtype).contiguous() + for name, t in obj["passthrough"].items(): + # Restore small tensors, undoing the temporary fp16 storage cast if needed. 
+ out_t = t.detach().to("cpu").contiguous() + orig_dtype = passthrough_orig_dtypes.get(name) + if isinstance(orig_dtype, str): + out_t = out_t.to(dtype=getattr(torch, orig_dtype)).contiguous() + out[name] = out_t + return out + + +# ----------------------------- +# DATA LOADING +# ----------------------------- + +def load_data_shard(file: Path) -> Tensor: + header_bytes = 256 * np.dtype(" None: + self.file_idx = (self.file_idx + 1) % len(self.files) + self.tokens = load_data_shard(self.files[self.file_idx]) + self.pos = 0 + + def take(self, n: int) -> Tensor: + chunks: list[Tensor] = [] + remaining = n + while remaining > 0: + avail = self.tokens.numel() - self.pos + if avail <= 0: + self._advance_file() + continue + k = min(remaining, avail) + chunks.append(self.tokens[self.pos : self.pos + k]) + self.pos += k + remaining -= k + return chunks[0] if len(chunks) == 1 else torch.cat(chunks) + + +class DistributedTokenLoader: + # Each call consumes a contiguous chunk from the shared token stream, then slices out + # one disjoint span per rank. The extra "+1" token lets us build (x, y) by shifting. 
+ def __init__(self, pattern: str, rank: int, world_size: int, device: torch.device): + self.rank = rank + self.world_size = world_size + self.device = device + self.stream = TokenStream(pattern) + + def next_batch(self, global_tokens: int, seq_len: int, grad_accum_steps: int) -> tuple[Tensor, Tensor]: + local_tokens = global_tokens // (self.world_size * grad_accum_steps) + per_rank_span = local_tokens + 1 + chunk = self.stream.take(per_rank_span * self.world_size) + start = self.rank * per_rank_span + local = chunk[start : start + per_rank_span].to(dtype=torch.int64) + x = local[:-1].reshape(-1, seq_len) + y = local[1:].reshape(-1, seq_len) + return x.to(self.device, non_blocking=True), y.to(self.device, non_blocking=True) + +# ----------------------------- +# TRANSFORMER MODULES +# ----------------------------- + +class RMSNorm(nn.Module): + def __init__(self, eps: float | None = None): + super().__init__() + self.eps = eps + + def forward(self, x: Tensor) -> Tensor: + return F.rms_norm(x, (x.size(-1),), eps=self.eps) + + +class CastedLinear(nn.Linear): + # Keep weights in fp32 for optimizer/state quality, cast at matmul time for bf16 compute. + def forward(self, x: Tensor) -> Tensor: + bias = self.bias.to(x.dtype) if self.bias is not None else None + return F.linear(x, self.weight.to(x.dtype), bias) + + +def restore_low_dim_params_to_fp32(module: nn.Module) -> None: + # Keep small/control parameters in fp32 even when the model body runs in bf16. + with torch.no_grad(): + for name, param in module.named_parameters(): + if (param.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS)) and param.dtype != torch.float32: + param.data = param.data.float() + + +class Rotary(nn.Module): + # Caches cos/sin tables per sequence length on the current device. 
+ def __init__(self, dim: int, base: float = 10000.0): + super().__init__() + inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim)) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self._seq_len_cached = 0 + self._cos_cached: Tensor | None = None + self._sin_cached: Tensor | None = None + + def forward(self, seq_len: int, device: torch.device, dtype: torch.dtype) -> tuple[Tensor, Tensor]: + if ( + self._cos_cached is None + or self._sin_cached is None + or self._seq_len_cached != seq_len + or self._cos_cached.device != device + ): + t = torch.arange(seq_len, device=device, dtype=self.inv_freq.dtype) + freqs = torch.outer(t, self.inv_freq.to(device)) + self._cos_cached = freqs.cos()[None, None, :, :] + self._sin_cached = freqs.sin()[None, None, :, :] + self._seq_len_cached = seq_len + return self._cos_cached.to(dtype=dtype), self._sin_cached.to(dtype=dtype) + + +def apply_rotary_emb(x: Tensor, cos: Tensor, sin: Tensor) -> Tensor: + half = x.size(-1) // 2 + x1, x2 = x[..., :half], x[..., half:] + return torch.cat((x1 * cos + x2 * sin, x1 * (-sin) + x2 * cos), dim=-1) + + +class CausalSelfAttention(nn.Module): + def __init__( + self, + dim: int, + num_heads: int, + num_kv_heads: int, + rope_base: float, + qk_gain_init: float, + ): + super().__init__() + if dim % num_heads != 0: + raise ValueError("model_dim must be divisible by num_heads") + if num_heads % num_kv_heads != 0: + raise ValueError("num_heads must be divisible by num_kv_heads") + self.num_heads = num_heads + self.num_kv_heads = num_kv_heads + self.head_dim = dim // num_heads + if self.head_dim % 2 != 0: + raise ValueError("head_dim must be even for RoPE") + kv_dim = self.num_kv_heads * self.head_dim + self.c_q = CastedLinear(dim, dim, bias=False) + self.c_k = CastedLinear(dim, kv_dim, bias=False) + self.c_v = CastedLinear(dim, kv_dim, bias=False) + self.proj = CastedLinear(dim, dim, bias=False) + self.proj._zero_init = True + self.q_gain = 
nn.Parameter(torch.full((num_heads,), qk_gain_init, dtype=torch.float32)) + self.rotary = Rotary(self.head_dim, base=rope_base) + + def forward(self, x: Tensor) -> Tensor: + bsz, seqlen, dim = x.shape + q = self.c_q(x).reshape(bsz, seqlen, self.num_heads, self.head_dim).transpose(1, 2) + k = self.c_k(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) + v = self.c_v(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) + q = F.rms_norm(q, (q.size(-1),)) + k = F.rms_norm(k, (k.size(-1),)) + cos, sin = self.rotary(seqlen, x.device, q.dtype) + q = apply_rotary_emb(q, cos, sin) + k = apply_rotary_emb(k, cos, sin) + q = q * self.q_gain.to(dtype=q.dtype)[None, :, None, None] + if FA3_AVAILABLE: + # FA3 expects (B, T, H, D) — transpose from (B, H, T, D) + y = fa3.flash_attn_func( + q.transpose(1, 2), + k.transpose(1, 2), + v.transpose(1, 2), + causal=True, + ).transpose(1, 2) + else: + y = F.scaled_dot_product_attention( + q, k, v, + attn_mask=None, + is_causal=True, + enable_gqa=(self.num_kv_heads != self.num_heads), + ) + y = y.transpose(1, 2).contiguous().reshape(bsz, seqlen, dim) + return self.proj(y) + + +class MLP(nn.Module): + # leaky_relu^2 MLP (slope=0.5 matches SOTA, PRs #493/#518) + def __init__(self, dim: int, mlp_mult: int, leaky_relu_slope: float = 0.5): + super().__init__() + hidden = mlp_mult * dim + self.fc = CastedLinear(dim, hidden, bias=False) + self.proj = CastedLinear(hidden, dim, bias=False) + self.proj._zero_init = True + self.leaky_relu_slope = leaky_relu_slope + + def forward(self, x: Tensor) -> Tensor: + x = F.leaky_relu(self.fc(x), negative_slope=self.leaky_relu_slope) + return self.proj(x.square()) + + +class TimestepScaling(nn.Module): + """Learned per-iteration scale vectors (arXiv:2410.01405, ICML 2025).""" + def __init__(self, effective_layers: int, dim: int, gamma_max: float = 0.0): + super().__init__() + self.attn_gamma = nn.Parameter(torch.ones(effective_layers, dim, dtype=torch.float32)) + 
self.mlp_gamma = nn.Parameter(torch.ones(effective_layers, dim, dtype=torch.float32)) + self.gamma_max = gamma_max # 0 = uncapped + + def get(self, v: int) -> tuple[Tensor, Tensor]: + ag = self.attn_gamma[v] + mg = self.mlp_gamma[v] + if self.gamma_max > 0: + ag = ag.clamp(-self.gamma_max, self.gamma_max) + mg = mg.clamp(-self.gamma_max, self.gamma_max) + return ag, mg + + +class Block(nn.Module): + def __init__( + self, + dim: int, + num_heads: int, + num_kv_heads: int, + mlp_mult: int, + rope_base: float, + qk_gain_init: float, + use_peri_norm: bool = False, + use_birkhoff_mix: bool = False, + leaky_relu_slope: float = 0.5, + ): + super().__init__() + self.use_peri_norm = use_peri_norm + self.use_birkhoff_mix = use_birkhoff_mix + self.attn_norm = RMSNorm() + if use_peri_norm: + self.mlp_out_norm = RMSNorm() # output norm: bounds MLP contribution, preserves input magnitude + else: + self.mlp_norm = RMSNorm() + self.attn = CausalSelfAttention(dim, num_heads, num_kv_heads, rope_base, qk_gain_init) + self.mlp = MLP(dim, mlp_mult, leaky_relu_slope) + self.attn_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.mlp_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + if use_birkhoff_mix: + # Birkhoff B₂ parameterization: sigmoid(logit) ∈ (0,1) → doubly stochastic + self.resid_mix_logit = nn.Parameter(torch.zeros(dim, dtype=torch.float32)) + else: + self.resid_mix = nn.Parameter(torch.stack((torch.ones(dim), torch.zeros(dim))).float()) + + def forward(self, x: Tensor, x0: Tensor, + ts_attn_gamma: Tensor | None = None, + ts_mlp_gamma: Tensor | None = None) -> Tensor: + if self.use_birkhoff_mix: + alpha = torch.sigmoid(self.resid_mix_logit.to(dtype=x.dtype))[None, None, :] + x = alpha * x + (1 - alpha) * x0 + else: + mix = self.resid_mix.to(dtype=x.dtype) + x = mix[0][None, None, :] * x + mix[1][None, None, :] * x0 + attn_out = self.attn(self.attn_norm(x)) + attn_s = self.attn_scale.to(dtype=x.dtype)[None, None, :] + if ts_attn_gamma is not None: + 
attn_s = attn_s * ts_attn_gamma[None, None, :] + x = x + attn_s * attn_out + mlp_out = self.mlp_out_norm(self.mlp(x)) if self.use_peri_norm else self.mlp(self.mlp_norm(x)) + mlp_s = self.mlp_scale.to(dtype=x.dtype)[None, None, :] + if ts_mlp_gamma is not None: + mlp_s = mlp_s * ts_mlp_gamma[None, None, :] + x = x + mlp_s * mlp_out + return x + + +class GPT(nn.Module): + def __init__( + self, + vocab_size: int, + num_layers: int, + model_dim: int, + num_heads: int, + num_kv_heads: int, + mlp_mult: int, + tie_embeddings: bool, + tied_embed_init_std: float, + logit_softcap: float, + rope_base: float, + qk_gain_init: float, + num_prelude: int = 0, + num_coda: int = 0, + num_shared: int = 0, + num_loops: int = 2, + use_peri_norm: bool = False, + use_birkhoff_mix: bool = False, + use_timestep_scale: bool = False, + timestep_gamma_max: float = 0.0, + leaky_relu_slope: float = 0.5, + ): + super().__init__() + if logit_softcap <= 0.0: + raise ValueError(f"logit_softcap must be positive, got {logit_softcap}") + self.tie_embeddings = tie_embeddings + self.tied_embed_init_std = tied_embed_init_std + self.logit_softcap = logit_softcap + self.tok_emb = nn.Embedding(vocab_size, model_dim) + self.use_recurrence = num_shared > 0 + self.num_loops = num_loops if self.use_recurrence else 1 + self.num_prelude = num_prelude if self.use_recurrence else 0 + self.num_coda = num_coda if self.use_recurrence else 0 + + block_kwargs = dict( + dim=model_dim, num_heads=num_heads, num_kv_heads=num_kv_heads, + mlp_mult=mlp_mult, rope_base=rope_base, qk_gain_init=qk_gain_init, + use_peri_norm=use_peri_norm if self.use_recurrence else False, + use_birkhoff_mix=use_birkhoff_mix if self.use_recurrence else False, + leaky_relu_slope=leaky_relu_slope, + ) + + if self.use_recurrence: + self.prelude_blocks = nn.ModuleList([Block(**block_kwargs) for _ in range(self.num_prelude)]) + self.shared_blocks = nn.ModuleList([Block(**block_kwargs) for _ in range(num_shared)]) + self.coda_blocks = 
nn.ModuleList([Block(**block_kwargs) for _ in range(self.num_coda)]) + effective_layers = self.num_prelude + num_shared * self.num_loops + self.num_coda + self.timestep_scale = TimestepScaling(effective_layers, model_dim, gamma_max=timestep_gamma_max) if use_timestep_scale else None + else: + # Standard U-Net path + self.num_encoder_layers = num_layers // 2 + self.num_decoder_layers = num_layers - self.num_encoder_layers + self.num_skip_weights = min(self.num_encoder_layers, self.num_decoder_layers) + self.skip_weights = nn.Parameter(torch.ones(self.num_skip_weights, model_dim, dtype=torch.float32)) + self.blocks = nn.ModuleList([Block(**block_kwargs) for _ in range(num_layers)]) + self.timestep_scale = None + + self.final_norm = RMSNorm() + self.lm_head = None if tie_embeddings else CastedLinear(model_dim, vocab_size, bias=False) + if self.lm_head is not None: + self.lm_head._zero_init = True + self._init_weights() + + def _init_weights(self) -> None: + if self.tie_embeddings: + nn.init.normal_(self.tok_emb.weight, mean=0.0, std=self.tied_embed_init_std) + for module in self.modules(): + if isinstance(module, nn.Linear) and getattr(module, "_zero_init", False): + nn.init.zeros_(module.weight) + + def _get_ts(self, v: int) -> tuple[Tensor | None, Tensor | None]: + if self.timestep_scale is None: + return None, None + return self.timestep_scale.get(v) + + def forward(self, input_ids: Tensor, target_ids: Tensor) -> Tensor: + x = self.tok_emb(input_ids) + x = F.rms_norm(x, (x.size(-1),)) + x0 = x + + if self.use_recurrence: + v = 0 + for block in self.prelude_blocks: + ag, mg = self._get_ts(v) + x = block(x, x0, ag, mg) + v += 1 + for _loop in range(self.num_loops): + for block in self.shared_blocks: + ag, mg = self._get_ts(v) + x = block(x, x0, ag, mg) + v += 1 + for block in self.coda_blocks: + ag, mg = self._get_ts(v) + x = block(x, x0, ag, mg) + v += 1 + else: + skips: list[Tensor] = [] + for i in range(self.num_encoder_layers): + x = self.blocks[i](x, x0) + 
skips.append(x) + for i in range(self.num_decoder_layers): + if skips: + x = x + self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips.pop() + x = self.blocks[self.num_encoder_layers + i](x, x0) + + x = self.final_norm(x).reshape(-1, x.size(-1)) + targets = target_ids.reshape(-1) + if self.tie_embeddings: + logits_proj = F.linear(x, self.tok_emb.weight) + else: + if self.lm_head is None: + raise RuntimeError("lm_head is required when tie_embeddings=False") + logits_proj = self.lm_head(x) + logits = self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap) + return F.cross_entropy(logits.float(), targets, reduction="mean") + + +@torch.no_grad() +def recurrence_param_diagnostics(gpt: GPT) -> str: + """Log parameter norms for recurrence health monitoring. No forward pass.""" + if not gpt.use_recurrence: + return "" + parts: list[str] = [] + + # Birkhoff alpha stats per shared block + for i, block in enumerate(gpt.shared_blocks): + if hasattr(block, "resid_mix_logit"): + a = torch.sigmoid(block.resid_mix_logit) + parts.append(f"shared{i}_alpha:mean={a.mean().item():.3f},std={a.std().item():.3f}") + + # Effective MLP/attn contribution scale per virtual layer position: + # mlp_scale * ts_mlp_gamma[v] is the learned multiplier on MLP output. + # We report RMS (norm / sqrt(numel)) to make it scale-independent. 
+ v = 0 + all_blocks = ( + list(gpt.prelude_blocks) + + list(gpt.shared_blocks) * gpt.num_loops + + list(gpt.coda_blocks) + ) + mlp_norms: list[str] = [] + attn_norms: list[str] = [] + for block in all_blocks: + ms = block.mlp_scale.norm().item() + asc = block.attn_scale.norm().item() + d = block.mlp_scale.numel() ** 0.5 + if gpt.timestep_scale is not None: + ms_g = gpt.timestep_scale.mlp_gamma[v].norm().item() + as_g = gpt.timestep_scale.attn_gamma[v].norm().item() + mlp_norms.append(f"v{v}:{ms * ms_g / d:.4f}") + attn_norms.append(f"v{v}:{asc * as_g / d:.4f}") + else: + mlp_norms.append(f"v{v}:{ms / d:.4f}") + attn_norms.append(f"v{v}:{asc / d:.4f}") + v += 1 + parts.append("eff_mlp_scale:[" + " ".join(mlp_norms) + "]") + parts.append("eff_attn_scale:[" + " ".join(attn_norms) + "]") + return " ".join(parts) + + +# ----------------------------- +# TRAINING +# ----------------------------- + +def main() -> None: + global zeropower_via_newtonschulz5 + + code = Path(__file__).read_text(encoding="utf-8") + args = Hyperparameters() + zeropower_via_newtonschulz5 = torch.compile(zeropower_via_newtonschulz5) + + # ----------------------------- + # DISTRIBUTED + CUDA SETUP + # ----------------------------- + + distributed = "RANK" in os.environ and "WORLD_SIZE" in os.environ + rank = int(os.environ.get("RANK", "0")) + world_size = int(os.environ.get("WORLD_SIZE", "1")) + local_rank = int(os.environ.get("LOCAL_RANK", "0")) + if world_size <= 0: + raise ValueError(f"WORLD_SIZE must be positive, got {world_size}") + if 8 % world_size != 0: + raise ValueError(f"WORLD_SIZE={world_size} must divide 8 so grad_accum_steps stays integral") + grad_accum_steps = 8 // world_size + grad_scale = 1.0 / grad_accum_steps + if not torch.cuda.is_available(): + raise RuntimeError("CUDA is required") + device = torch.device("cuda", local_rank) + torch.cuda.set_device(device) + if distributed: + dist.init_process_group(backend="nccl", device_id=device) + dist.barrier() + master_process = rank 
== 0 + + # Fast math knobs + torch.backends.cuda.matmul.allow_tf32 = True + torch.backends.cudnn.allow_tf32 = True + from torch.backends.cuda import enable_cudnn_sdp, enable_flash_sdp, enable_math_sdp, enable_mem_efficient_sdp + + enable_cudnn_sdp(False) + enable_flash_sdp(True) + enable_mem_efficient_sdp(False) + enable_math_sdp(False) + + logfile = None + if master_process: + os.makedirs("logs", exist_ok=True) + logfile = f"logs/{args.run_id}.txt" + print(logfile) + + def log0(msg: str, console: bool = True) -> None: + if not master_process: + return + if console: + print(msg) + if logfile is not None: + with open(logfile, "a", encoding="utf-8") as f: + print(msg, file=f) + + log0(code, console=False) + log0("=" * 100, console=False) + log0(f"Running Python {sys.version}", console=False) + log0(f"Running PyTorch {torch.__version__}", console=False) + log0( + subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, check=False).stdout, + console=False, + ) + log0("=" * 100, console=False) + + # ----------------------------- + # TOKENIZER + VALIDATION METRIC SETUP + # ----------------------------- + + random.seed(args.seed) + np.random.seed(args.seed) + torch.manual_seed(args.seed) + torch.cuda.manual_seed_all(args.seed) + + if not args.tokenizer_path.endswith(".model"): + raise ValueError(f"Script only setup for SentencePiece .model file: {args.tokenizer_path}") + sp = spm.SentencePieceProcessor(model_file=args.tokenizer_path) + if int(sp.vocab_size()) != args.vocab_size: + raise ValueError( + f"VOCAB_SIZE={args.vocab_size} does not match tokenizer vocab_size={int(sp.vocab_size())}" + ) + dataset_dir = Path(args.data_path).resolve() + actual_train_files = len(list(dataset_dir.glob("fineweb_train_*.bin"))) + val_tokens = load_validation_tokens(args.val_files, args.train_seq_len) + base_bytes_lut, has_leading_space_lut, is_boundary_token_lut = build_sentencepiece_luts( + sp, args.vocab_size, device + ) + log0(f"val_bpb:enabled 
tokenizer_kind=sentencepiece tokenizer_path={args.tokenizer_path}") + log0(f"train_loader:dataset:{dataset_dir.name} train_shards:{actual_train_files}") + log0(f"val_loader:shards pattern={args.val_files} tokens:{val_tokens.numel() - 1}") + + # ----------------------------- + # MODEL + OPTIMIZER SETUP + # ----------------------------- + + base_model = GPT( + vocab_size=args.vocab_size, + num_layers=args.num_layers, + model_dim=args.model_dim, + num_heads=args.num_heads, + num_kv_heads=args.num_kv_heads, + mlp_mult=args.mlp_mult, + tie_embeddings=args.tie_embeddings, + tied_embed_init_std=args.tied_embed_init_std, + logit_softcap=args.logit_softcap, + rope_base=args.rope_base, + qk_gain_init=args.qk_gain_init, + num_prelude=args.num_prelude, + num_coda=args.num_coda, + num_shared=args.num_shared, + num_loops=args.num_loops, + use_peri_norm=args.use_peri_norm, + use_birkhoff_mix=args.use_birkhoff_mix, + use_timestep_scale=args.use_timestep_scale, + timestep_gamma_max=args.timestep_gamma_max, + leaky_relu_slope=args.leaky_relu_slope, + ).to(device).bfloat16() + for module in base_model.modules(): + if isinstance(module, CastedLinear): + module.float() + restore_low_dim_params_to_fp32(base_model) + if args.disable_compile: + compiled_model = base_model + else: + compiled_model = torch.compile(base_model, dynamic=False, fullgraph=True) + model: nn.Module = DDP(compiled_model, device_ids=[local_rank], broadcast_buffers=False) if distributed else compiled_model + + # Optimizer split: + # - token embedding (Adam) uses EMBED_LR + # - untied lm_head (Adam) uses HEAD_LR + # - matrix params in transformer blocks use MATRIX_LR via Muon + # - vectors/scalars use SCALAR_LR via Adam + if base_model.use_recurrence: + all_block_lists = [base_model.prelude_blocks, base_model.shared_blocks, base_model.coda_blocks] + block_named_params = [] + for bl in all_block_lists: + block_named_params.extend(bl.named_parameters()) + else: + block_named_params = 
list(base_model.blocks.named_parameters()) + matrix_params = [ + p + for name, p in block_named_params + if p.ndim == 2 and not any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) + ] + scalar_params = [ + p + for name, p in block_named_params + if p.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) + ] + if hasattr(base_model, "skip_weights") and base_model.skip_weights.numel() > 0: + scalar_params.append(base_model.skip_weights) + if base_model.timestep_scale is not None: + for p in base_model.timestep_scale.parameters(): + scalar_params.append(p) + token_lr = args.tied_embed_lr if args.tie_embeddings else args.embed_lr + optimizer_tok = torch.optim.Adam( + [{"params": [base_model.tok_emb.weight], "lr": token_lr, "base_lr": token_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=True, + ) + optimizer_muon = Muon( + matrix_params, + lr=args.matrix_lr, + momentum=args.muon_momentum, + backend_steps=args.muon_backend_steps, + ) + for group in optimizer_muon.param_groups: + group["base_lr"] = args.matrix_lr + optimizer_scalar = torch.optim.Adam( + [{"params": scalar_params, "lr": args.scalar_lr, "base_lr": args.scalar_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=True, + ) + optimizers: list[torch.optim.Optimizer] = [optimizer_tok, optimizer_muon, optimizer_scalar] + if base_model.lm_head is not None: + optimizer_head = torch.optim.Adam( + [{"params": [base_model.lm_head.weight], "lr": args.head_lr, "base_lr": args.head_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=True, + ) + optimizers.insert(1, optimizer_head) + + n_params = sum(p.numel() for p in base_model.parameters()) + log0(f"model_params:{n_params}") + log0(f"world_size:{world_size} grad_accum_steps:{grad_accum_steps}") + log0("sdp_backends:cudnn=False flash=True mem_efficient=False math=False") + log0(f"flash_attention_3:{FA3_AVAILABLE}") + log0(f"attention_mode:gqa num_heads:{args.num_heads} 
num_kv_heads:{args.num_kv_heads}") + log0( + f"tie_embeddings:{args.tie_embeddings} embed_lr:{token_lr} " + f"head_lr:{args.head_lr if base_model.lm_head is not None else 0.0} " + f"matrix_lr:{args.matrix_lr} scalar_lr:{args.scalar_lr}" + ) + log0( + f"train_batch_tokens:{args.train_batch_tokens} train_seq_len:{args.train_seq_len} " + f"iterations:{args.iterations} warmup_steps:{args.warmup_steps} " + f"max_wallclock_seconds:{args.max_wallclock_seconds:.3f}" + ) + log0(f"seed:{args.seed}") + if base_model.use_recurrence: + eff = base_model.num_prelude + len(base_model.shared_blocks) * base_model.num_loops + base_model.num_coda + log0(f"recurrence:enabled prelude:{base_model.num_prelude} shared:{len(base_model.shared_blocks)} " + f"loops:{base_model.num_loops} coda:{base_model.num_coda} effective_layers:{eff}") + log0(f"peri_norm:{args.use_peri_norm} birkhoff_mix:{args.use_birkhoff_mix}") + if base_model.timestep_scale is not None: + ts_params = sum(p.numel() for p in base_model.timestep_scale.parameters()) + log0(f"timestep_scale:enabled params:{ts_params}") + else: + log0("timestep_scale:disabled") + else: + log0(f"recurrence:disabled num_layers:{args.num_layers}") + compile_mode = "disabled" if args.disable_compile else "fullgraph=True" + log0(f"compile_mode:{compile_mode}") + + # ----------------------------- + # DATA LOADER & MODEL WARMUP + # ----------------------------- + + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + + def zero_grad_all() -> None: + for opt in optimizers: + opt.zero_grad(set_to_none=True) + + max_wallclock_ms = 1000.0 * args.max_wallclock_seconds if args.max_wallclock_seconds > 0 else None + + def lr_mul(step: int, elapsed_ms: float) -> float: + if args.warmdown_iters <= 0: + return 1.0 + if max_wallclock_ms is None: + warmdown_start = max(args.iterations - args.warmdown_iters, 0) + return max((args.iterations - step) / max(args.warmdown_iters, 1), 0.0) if warmdown_start <= step < args.iterations else 
1.0 + step_ms = elapsed_ms / max(step, 1) + warmdown_ms = args.warmdown_iters * step_ms + remaining_ms = max(max_wallclock_ms - elapsed_ms, 0.0) + return remaining_ms / max(warmdown_ms, 1e-9) if remaining_ms <= warmdown_ms else 1.0 + + # Warmup primes the compiled forward/backward/optimizer paths, then we restore the + # initial weights/optimizer state so measured training starts from the true init. + if args.warmup_steps > 0: + initial_model_state = {name: tensor.detach().cpu().clone() for name, tensor in base_model.state_dict().items()} + initial_optimizer_states = [copy.deepcopy(opt.state_dict()) for opt in optimizers] + model.train() + for warmup_step in range(args.warmup_steps): + zero_grad_all() + for micro_step in range(grad_accum_steps): + if distributed: + model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 + x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + warmup_loss = model(x, y) + (warmup_loss * grad_scale).backward() + for opt in optimizers: + opt.step() + zero_grad_all() + if args.warmup_steps <= 20 or (warmup_step + 1) % 10 == 0 or warmup_step + 1 == args.warmup_steps: + log0(f"warmup_step:{warmup_step + 1}/{args.warmup_steps}") + base_model.load_state_dict(initial_model_state, strict=True) + for opt, state in zip(optimizers, initial_optimizer_states, strict=True): + opt.load_state_dict(state) + zero_grad_all() + if distributed: + model.require_backward_grad_sync = True + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + + # ----------------------------- + # MAIN TRAINING LOOP + # ----------------------------- + + training_time_ms = 0.0 + stop_after_step: int | None = None + torch.cuda.synchronize() + t0 = time.perf_counter() + + step = 0 + while True: + last_step = step == args.iterations or (stop_after_step is not None and step >= stop_after_step) + + should_validate = 
last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0) + if should_validate: + torch.cuda.synchronize() + training_time_ms += 1000.0 * (time.perf_counter() - t0) + val_loss, val_bpb = eval_val( + args, + model, + rank, + world_size, + device, + grad_accum_steps, + val_tokens, + base_bytes_lut, + has_leading_space_lut, + is_boundary_token_lut, + ) + log0( + f"step:{step}/{args.iterations} val_loss:{val_loss:.4f} val_bpb:{val_bpb:.4f} " + f"train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms / max(step, 1):.2f}ms" + ) + torch.cuda.synchronize() + t0 = time.perf_counter() + + if last_step: + if stop_after_step is not None and step < args.iterations: + log0( + f"stopping_early: wallclock_cap train_time:{training_time_ms:.0f}ms " + f"step:{step}/{args.iterations}" + ) + break + + elapsed_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + scale = lr_mul(step, elapsed_ms) + zero_grad_all() + train_loss = torch.zeros((), device=device) + for micro_step in range(grad_accum_steps): + if distributed: + model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 + x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + loss = model(x, y) + train_loss += loss.detach() + (loss * grad_scale).backward() + train_loss /= grad_accum_steps + + frac = min(step / args.muon_momentum_warmup_steps, 1.0) if args.muon_momentum_warmup_steps > 0 else 1.0 + muon_momentum = (1 - frac) * args.muon_momentum_warmup_start + frac * args.muon_momentum + for group in optimizer_muon.param_groups: + group["momentum"] = muon_momentum + + for opt in optimizers: + for group in opt.param_groups: + group["lr"] = group["base_lr"] * scale + + if args.grad_clip_norm > 0: + torch.nn.utils.clip_grad_norm_(base_model.parameters(), args.grad_clip_norm) + for opt in optimizers: + opt.step() + zero_grad_all() + + step += 1 + approx_training_time_ms = 
training_time_ms + 1000.0 * (time.perf_counter() - t0) + should_log_train = ( + args.train_log_every > 0 + and (step <= 10 or step % args.train_log_every == 0 or stop_after_step is not None) + ) + if should_log_train: + log0( + f"step:{step}/{args.iterations} train_loss:{train_loss.item():.4f} " + f"train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms / step:.2f}ms" + ) + if base_model.use_recurrence and step % args.train_log_every == 0: + diag = recurrence_param_diagnostics(base_model) + if diag: + log0(f"step:{step} {diag}") + + # Needed to sync whether we've reached the wallclock cap. + reached_cap = max_wallclock_ms is not None and approx_training_time_ms >= max_wallclock_ms + if distributed and max_wallclock_ms is not None: + reached_cap_tensor = torch.tensor(int(reached_cap), device=device) + dist.all_reduce(reached_cap_tensor, op=dist.ReduceOp.MAX) + reached_cap = bool(reached_cap_tensor.item()) + if stop_after_step is None and reached_cap: + stop_after_step = step + + log0( + f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB " + f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB" + ) + + # ----------------------------- + # SERIALIZATION + ROUNDTRIP VALIDATION + # ----------------------------- + # Save the raw state (useful for debugging/loading in PyTorch directly), then always produce + # the compressed int8+zlib artifact and validate the round-tripped weights. 
+ + if master_process: + torch.save(base_model.state_dict(), "final_model.pt") + model_bytes = os.path.getsize("final_model.pt") + code_bytes = len(code.encode("utf-8")) + log0(f"Serialized model: {model_bytes} bytes") + log0(f"Code size: {code_bytes} bytes") + log0(f"Total submission size: {model_bytes + code_bytes} bytes") + + quant_obj, quant_stats = quantize_state_dict_int8(base_model.state_dict()) + quant_buf = io.BytesIO() + torch.save(quant_obj, quant_buf) + quant_raw = quant_buf.getvalue() + quant_blob = zlib.compress(quant_raw, level=9) + quant_raw_bytes = len(quant_raw) + if master_process: + with open("final_model.int8.ptz", "wb") as f: + f.write(quant_blob) + quant_file_bytes = os.path.getsize("final_model.int8.ptz") + code_bytes = len(code.encode("utf-8")) + ratio = quant_stats["baseline_tensor_bytes"] / max(quant_stats["int8_payload_bytes"], 1) + log0( + f"Serialized model int8+zlib: {quant_file_bytes} bytes " + f"(payload:{quant_stats['int8_payload_bytes']} raw_torch:{quant_raw_bytes} payload_ratio:{ratio:.2f}x)" + ) + log0(f"Total submission size int8+zlib: {quant_file_bytes + code_bytes} bytes") + + if distributed: + dist.barrier() + with open("final_model.int8.ptz", "rb") as f: + quant_blob_disk = f.read() + quant_state = torch.load(io.BytesIO(zlib.decompress(quant_blob_disk)), map_location="cpu") + base_model.load_state_dict(dequantize_state_dict_int8(quant_state), strict=True) + torch.cuda.synchronize() + t_qeval = time.perf_counter() + q_val_loss, q_val_bpb = eval_val( + args, + model, + rank, + world_size, + device, + grad_accum_steps, + val_tokens, + base_bytes_lut, + has_leading_space_lut, + is_boundary_token_lut, + ) + torch.cuda.synchronize() + log0( + f"final_int8_zlib_roundtrip val_loss:{q_val_loss:.4f} val_bpb:{q_val_bpb:.4f} " + f"eval_time:{1000.0 * (time.perf_counter() - t_qeval):.0f}ms" + ) + log0(f"final_int8_zlib_roundtrip_exact val_loss:{q_val_loss:.8f} val_bpb:{q_val_bpb:.8f}") + + if distributed: + dist.destroy_process_group() 
+ + +if __name__ == "__main__": + main() + +==================================================================================================== +Running Python 3.12.3 (main, Nov 6 2025, 13:44:16) [GCC 13.3.0] +Running PyTorch 2.9.1+cu128 +Wed Mar 25 19:09:19 2026 ++-----------------------------------------------------------------------------------------+ +| NVIDIA-SMI 580.126.09 Driver Version: 580.126.09 CUDA Version: 13.0 | ++-----------------------------------------+------------------------+----------------------+ +| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC | +| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. | +| | | MIG M. | +|=========================================+========================+======================| +| 0 NVIDIA H100 80GB HBM3 On | 00000000:19:00.0 Off | 0 | +| N/A 41C P0 120W / 700W | 1521MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 1 NVIDIA H100 80GB HBM3 On | 00000000:3B:00.0 Off | 0 | +| N/A 34C P0 117W / 700W | 1521MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 2 NVIDIA H100 80GB HBM3 On | 00000000:4C:00.0 Off | 0 | +| N/A 32C P0 115W / 700W | 1521MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 3 NVIDIA H100 80GB HBM3 On | 00000000:5D:00.0 Off | 0 | +| N/A 39C P0 120W / 700W | 1521MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 4 NVIDIA H100 80GB HBM3 On | 00000000:9B:00.0 Off | 0 | +| N/A 40C P0 120W / 700W | 1521MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 5 NVIDIA H100 80GB HBM3 On | 00000000:BB:00.0 Off | 0 | +| N/A 33C 
P0 119W / 700W | 1521MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 6 NVIDIA H100 80GB HBM3 On | 00000000:CB:00.0 Off | 0 | +| N/A 39C P0 119W / 700W | 1521MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 7 NVIDIA H100 80GB HBM3 On | 00000000:DB:00.0 Off | 0 | +| N/A 33C P0 116W / 700W | 1521MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ + ++-----------------------------------------------------------------------------------------+ +| Processes: | +| GPU GI CI PID Type Process name GPU Memory | +| ID ID Usage | +|=========================================================================================| +| No running processes found | ++-----------------------------------------------------------------------------------------+ + +==================================================================================================== +val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path=./data/tokenizers/fineweb_1024_bpe.model +train_loader:dataset:fineweb10B_sp1024 train_shards:80 +val_loader:shards pattern=./data/datasets/fineweb10B_sp1024/fineweb_val_*.bin tokens:62021632 +model_params:11553840 +world_size:8 grad_accum_steps:1 +sdp_backends:cudnn=False flash=True mem_efficient=False math=False +flash_attention_3:True +attention_mode:gqa num_heads:8 num_kv_heads:4 +tie_embeddings:True embed_lr:0.05 head_lr:0.0 matrix_lr:0.04 scalar_lr:0.04 +train_batch_tokens:524288 train_seq_len:1024 iterations:20000 warmup_steps:20 max_wallclock_seconds:600.000 +seed:1337 +recurrence:enabled prelude:1 shared:4 loops:2 coda:1 effective_layers:10 +peri_norm:True birkhoff_mix:True +timestep_scale:enabled params:10240 +compile_mode:fullgraph=True +warmup_step:1/20 +warmup_step:2/20 +warmup_step:3/20 
+warmup_step:4/20 +warmup_step:5/20 +warmup_step:6/20 +warmup_step:7/20 +warmup_step:8/20 +warmup_step:9/20 +warmup_step:10/20 +warmup_step:11/20 +warmup_step:12/20 +warmup_step:13/20 +warmup_step:14/20 +warmup_step:15/20 +warmup_step:16/20 +warmup_step:17/20 +warmup_step:18/20 +warmup_step:19/20 +warmup_step:20/20 +step:0/20000 val_loss:6.9379 val_bpb:4.1090 train_time:0ms step_avg:0.02ms +step:1/20000 train_loss:6.9377 train_time:29ms step_avg:28.60ms +step:2/20000 train_loss:9.6278 train_time:73ms step_avg:36.40ms +step:3/20000 train_loss:7.3726 train_time:114ms step_avg:38.15ms +step:4/20000 train_loss:9.1107 train_time:155ms step_avg:38.82ms +step:5/20000 train_loss:8.5363 train_time:196ms step_avg:39.15ms +step:6/20000 train_loss:8.2122 train_time:238ms step_avg:39.71ms +step:7/20000 train_loss:6.8775 train_time:279ms step_avg:39.89ms +step:8/20000 train_loss:6.1732 train_time:320ms step_avg:40.05ms +step:9/20000 train_loss:5.6236 train_time:361ms step_avg:40.13ms +step:10/20000 train_loss:5.3649 train_time:402ms step_avg:40.19ms +step:200/20000 train_loss:2.7735 train_time:8324ms step_avg:41.62ms +step:200 shared0_alpha:mean=0.460,std=0.049 shared1_alpha:mean=0.484,std=0.041 shared2_alpha:mean=0.493,std=0.039 shared3_alpha:mean=0.519,std=0.042 eff_mlp_scale:[v0:32.4338 v1:26.8783 v2:27.3330 v3:28.8289 v4:29.8140 v5:30.3279 v6:28.7419 v7:31.4901 v8:33.3305 v9:58.4229] eff_attn_scale:[v0:14.2957 v1:10.8281 v2:10.0086 v3:10.3372 v4:9.4171 v5:11.1277 v6:10.0903 v7:10.0116 v8:9.4171 v9:16.9952] +step:200/20000 val_loss:2.7557 val_bpb:1.6321 train_time:8371ms step_avg:41.86ms +step:400/20000 train_loss:2.3574 train_time:16729ms step_avg:41.82ms +step:400 shared0_alpha:mean=0.473,std=0.051 shared1_alpha:mean=0.512,std=0.042 shared2_alpha:mean=0.520,std=0.041 shared3_alpha:mean=0.550,std=0.045 eff_mlp_scale:[v0:41.1913 v1:34.9090 v2:36.0754 v3:37.7209 v4:36.1858 v5:40.2285 v6:36.5585 v7:37.7209 v8:36.3480 v9:73.1532] eff_attn_scale:[v0:6.3461 v1:5.6364 v2:5.3996 
v3:5.5119 v4:5.5663 v5:5.8408 v6:5.3424 v7:4.8812 v8:4.9660 v9:9.8379] +step:400/20000 val_loss:2.5671 val_bpb:1.5204 train_time:16742ms step_avg:41.85ms +step:600/20000 train_loss:2.5794 train_time:25108ms step_avg:41.85ms +step:600 shared0_alpha:mean=0.480,std=0.052 shared1_alpha:mean=0.527,std=0.043 shared2_alpha:mean=0.536,std=0.043 shared3_alpha:mean=0.571,std=0.047 eff_mlp_scale:[v0:47.2879 v1:39.6627 v2:41.4816 v3:43.7143 v4:40.3192 v5:46.2435 v6:41.4816 v7:41.4502 v8:38.4712 v9:87.4475] eff_attn_scale:[v0:3.0079 v1:3.1515 v2:3.1426 v3:3.2746 v4:3.7014 v5:3.2938 v6:3.1015 v7:2.7844 v8:3.1017 v9:6.3364] +step:600/20000 val_loss:2.4839 val_bpb:1.4711 train_time:25122ms step_avg:41.87ms +step:800/20000 train_loss:2.3462 train_time:33504ms step_avg:41.88ms +step:800 shared0_alpha:mean=0.482,std=0.054 shared1_alpha:mean=0.536,std=0.044 shared2_alpha:mean=0.545,std=0.045 shared3_alpha:mean=0.583,std=0.048 eff_mlp_scale:[v0:52.6537 v1:43.9116 v2:46.0793 v3:48.0110 v4:43.6501 v5:50.6100 v6:44.9993 v7:44.0402 v8:40.8787 v9:99.0947] eff_attn_scale:[v0:1.9047 v1:2.1517 v2:2.2376 v3:2.3691 v4:2.8504 v5:2.2633 v6:2.2042 v7:1.9557 v8:2.2804 v9:4.7538] +step:800/20000 val_loss:2.4270 val_bpb:1.4374 train_time:33517ms step_avg:41.90ms +step:1000/20000 train_loss:2.4205 train_time:41907ms step_avg:41.91ms +step:1000 shared0_alpha:mean=0.483,std=0.056 shared1_alpha:mean=0.543,std=0.045 shared2_alpha:mean=0.551,std=0.046 shared3_alpha:mean=0.592,std=0.049 eff_mlp_scale:[v0:57.9590 v1:47.2381 v2:49.4761 v3:51.6986 v4:47.1052 v5:54.3720 v6:48.3601 v7:46.6775 v8:43.3582 v9:108.8084] eff_attn_scale:[v0:1.4433 v1:1.7031 v2:1.7943 v3:1.9042 v4:2.3928 v5:1.7651 v6:1.7870 v7:1.5682 v8:1.9034 v9:3.9250] +step:1000/20000 val_loss:2.3867 val_bpb:1.4136 train_time:41921ms step_avg:41.92ms +step:1200/20000 train_loss:2.4444 train_time:50309ms step_avg:41.92ms +step:1200 shared0_alpha:mean=0.483,std=0.057 shared1_alpha:mean=0.547,std=0.046 shared2_alpha:mean=0.556,std=0.047 
shared3_alpha:mean=0.598,std=0.052 eff_mlp_scale:[v0:62.2344 v1:50.3959 v2:52.3744 v3:54.9416 v4:49.7827 v5:57.9354 v6:50.8453 v7:49.2185 v8:45.7561 v9:117.2400] eff_attn_scale:[v0:1.1627 v1:1.4220 v2:1.5000 v3:1.6212 v4:2.0798 v5:1.4719 v6:1.4802 v7:1.3264 v8:1.6168 v9:3.3694] +step:1200/20000 val_loss:2.3599 val_bpb:1.3977 train_time:50322ms step_avg:41.94ms +step:1400/20000 train_loss:2.4909 train_time:58716ms step_avg:41.94ms +step:1400 shared0_alpha:mean=0.481,std=0.059 shared1_alpha:mean=0.550,std=0.048 shared2_alpha:mean=0.559,std=0.048 shared3_alpha:mean=0.602,std=0.053 eff_mlp_scale:[v0:66.4150 v1:53.2391 v2:55.6326 v3:57.8170 v4:51.9946 v5:60.9608 v6:53.6737 v7:51.5665 v8:48.2540 v9:125.0665] eff_attn_scale:[v0:1.0000 v1:1.2543 v2:1.3355 v3:1.4485 v4:1.8808 v5:1.2895 v6:1.3046 v7:1.1758 v8:1.4613 v9:3.0484] +step:1400/20000 val_loss:2.3404 val_bpb:1.3861 train_time:58729ms step_avg:41.95ms +step:1600/20000 train_loss:2.1624 train_time:67124ms step_avg:41.95ms +step:1600 shared0_alpha:mean=0.481,std=0.060 shared1_alpha:mean=0.552,std=0.048 shared2_alpha:mean=0.561,std=0.049 shared3_alpha:mean=0.605,std=0.054 eff_mlp_scale:[v0:70.7997 v1:55.4995 v2:58.2563 v3:60.4920 v4:54.1498 v5:63.7830 v6:55.8622 v7:53.3285 v8:50.3364 v9:131.1975] eff_attn_scale:[v0:0.8811 v1:1.1471 v2:1.1987 v3:1.3324 v4:1.7461 v5:1.1752 v6:1.1753 v7:1.0811 v8:1.3551 v9:2.8033] +step:1600/20000 val_loss:2.3293 val_bpb:1.3795 train_time:67139ms step_avg:41.96ms +step:1800/20000 train_loss:2.2635 train_time:75530ms step_avg:41.96ms +step:1800 shared0_alpha:mean=0.479,std=0.062 shared1_alpha:mean=0.554,std=0.049 shared2_alpha:mean=0.563,std=0.050 shared3_alpha:mean=0.608,std=0.055 eff_mlp_scale:[v0:74.4344 v1:58.2499 v2:60.6688 v3:63.2486 v4:56.3654 v5:66.2698 v6:58.2258 v7:55.5452 v8:52.4781 v9:137.3730] eff_attn_scale:[v0:0.7757 v1:1.0553 v2:1.1078 v3:1.2292 v4:1.6345 v5:1.0499 v6:1.0797 v7:0.9878 v8:1.2734 v9:2.5991] +step:1800/20000 val_loss:2.3153 val_bpb:1.3713 train_time:75543ms 
step_avg:41.97ms +step:2000/20000 train_loss:2.3186 train_time:83943ms step_avg:41.97ms +step:2000 shared0_alpha:mean=0.478,std=0.063 shared1_alpha:mean=0.556,std=0.049 shared2_alpha:mean=0.565,std=0.051 shared3_alpha:mean=0.610,std=0.056 eff_mlp_scale:[v0:77.5654 v1:60.5581 v2:63.3222 v3:65.5362 v4:58.5639 v5:68.7185 v6:60.4251 v7:57.2927 v8:54.6068 v9:142.6277] eff_attn_scale:[v0:0.7184 v1:0.9642 v2:1.0584 v3:1.1363 v4:1.5347 v5:0.9746 v6:1.0200 v7:0.9144 v8:1.1971 v9:2.4273] +step:2000/20000 val_loss:2.2997 val_bpb:1.3620 train_time:83956ms step_avg:41.98ms +step:2200/20000 train_loss:2.1454 train_time:92346ms step_avg:41.98ms +step:2200 shared0_alpha:mean=0.477,std=0.064 shared1_alpha:mean=0.557,std=0.051 shared2_alpha:mean=0.567,std=0.053 shared3_alpha:mean=0.612,std=0.057 eff_mlp_scale:[v0:81.4522 v1:62.7125 v2:65.6707 v3:67.9144 v4:60.7625 v5:70.9870 v6:62.7239 v7:59.5299 v8:56.7385 v9:148.5993] eff_attn_scale:[v0:0.6644 v1:0.9258 v2:0.9892 v3:1.1022 v4:1.4634 v5:0.9207 v6:0.9626 v7:0.8785 v8:1.1490 v9:2.3272] +step:2200/20000 val_loss:2.2933 val_bpb:1.3582 train_time:92359ms step_avg:41.98ms +step:2400/20000 train_loss:2.2678 train_time:100747ms step_avg:41.98ms +step:2400 shared0_alpha:mean=0.475,std=0.065 shared1_alpha:mean=0.558,std=0.052 shared2_alpha:mean=0.568,std=0.053 shared3_alpha:mean=0.613,std=0.057 eff_mlp_scale:[v0:85.1303 v1:65.0459 v2:67.4894 v3:70.0819 v4:63.0337 v5:73.4532 v6:64.4994 v7:61.1623 v8:58.5313 v9:153.8977] eff_attn_scale:[v0:0.6118 v1:0.8757 v2:0.9308 v3:1.0434 v4:1.4194 v5:0.8757 v6:0.8948 v7:0.8358 v8:1.1060 v9:2.2107] +step:2400/20000 val_loss:2.2829 val_bpb:1.3520 train_time:100760ms step_avg:41.98ms +step:2600/20000 train_loss:2.4745 train_time:109142ms step_avg:41.98ms +step:2600 shared0_alpha:mean=0.474,std=0.066 shared1_alpha:mean=0.560,std=0.052 shared2_alpha:mean=0.569,std=0.054 shared3_alpha:mean=0.614,std=0.058 eff_mlp_scale:[v0:87.7270 v1:66.8418 v2:70.3395 v3:72.4739 v4:64.7138 v5:75.8139 v6:66.8660 v7:62.9833 
v8:60.5655 v9:158.7239] eff_attn_scale:[v0:0.5847 v1:0.8530 v2:0.9109 v3:1.0152 v4:1.3466 v5:0.8287 v6:0.8552 v7:0.8101 v8:1.0416 v9:2.1540] +step:2600/20000 val_loss:2.2939 val_bpb:1.3586 train_time:109156ms step_avg:41.98ms +step:2800/20000 train_loss:2.3015 train_time:117532ms step_avg:41.98ms +step:2800 shared0_alpha:mean=0.472,std=0.066 shared1_alpha:mean=0.561,std=0.053 shared2_alpha:mean=0.571,std=0.055 shared3_alpha:mean=0.615,std=0.059 eff_mlp_scale:[v0:90.9391 v1:68.5997 v2:72.0679 v3:74.7095 v4:66.4964 v5:77.6858 v6:68.5524 v7:64.6609 v8:62.7086 v9:163.5023] eff_attn_scale:[v0:0.5367 v1:0.8110 v2:0.8773 v3:0.9663 v4:1.3259 v5:0.7827 v6:0.8325 v7:0.7631 v8:1.0389 v9:2.0614] +step:2800/20000 val_loss:2.2709 val_bpb:1.3449 train_time:117546ms step_avg:41.98ms +step:3000/20000 train_loss:2.2926 train_time:125924ms step_avg:41.97ms +step:3000 shared0_alpha:mean=0.471,std=0.068 shared1_alpha:mean=0.563,std=0.053 shared2_alpha:mean=0.572,std=0.055 shared3_alpha:mean=0.616,std=0.058 eff_mlp_scale:[v0:94.1667 v1:70.8351 v2:73.9319 v3:76.6717 v4:68.6385 v5:79.5745 v6:70.3689 v7:66.4784 v8:64.3752 v9:168.4960] eff_attn_scale:[v0:0.5138 v1:0.7725 v2:0.8430 v3:0.9486 v4:1.2902 v5:0.7633 v6:0.8087 v7:0.7481 v8:1.0041 v9:1.9954] +step:3000/20000 val_loss:2.2634 val_bpb:1.3405 train_time:125937ms step_avg:41.98ms +step:3200/20000 train_loss:2.2584 train_time:134319ms step_avg:41.97ms +step:3200 shared0_alpha:mean=0.469,std=0.068 shared1_alpha:mean=0.564,std=0.054 shared2_alpha:mean=0.574,std=0.056 shared3_alpha:mean=0.618,std=0.059 eff_mlp_scale:[v0:97.4713 v1:72.5887 v2:76.3698 v3:79.0680 v4:70.3904 v5:81.4297 v6:72.3027 v7:68.2860 v8:66.5038 v9:172.3404] eff_attn_scale:[v0:0.4985 v1:0.7607 v2:0.8160 v3:0.9272 v4:1.2746 v5:0.7425 v6:0.7776 v7:0.7330 v8:0.9938 v9:1.9378] +step:3200/20000 val_loss:2.2592 val_bpb:1.3380 train_time:134332ms step_avg:41.98ms +step:3400/20000 train_loss:2.2244 train_time:142713ms step_avg:41.97ms +step:3400 shared0_alpha:mean=0.467,std=0.069 
shared1_alpha:mean=0.565,std=0.055 shared2_alpha:mean=0.574,std=0.057 shared3_alpha:mean=0.618,std=0.060 eff_mlp_scale:[v0:100.6388 v1:74.4788 v2:78.1266 v3:81.5614 v4:72.5989 v5:83.9065 v6:74.4715 v7:70.6258 v8:68.2255 v9:176.3723] eff_attn_scale:[v0:0.4764 v1:0.7349 v2:0.8046 v3:0.9009 v4:1.2423 v5:0.7081 v6:0.7620 v7:0.7188 v8:0.9718 v9:1.8838] +step:3400/20000 val_loss:2.2583 val_bpb:1.3375 train_time:142726ms step_avg:41.98ms +step:3600/20000 train_loss:2.1922 train_time:151107ms step_avg:41.97ms +step:3600 shared0_alpha:mean=0.466,std=0.070 shared1_alpha:mean=0.566,std=0.055 shared2_alpha:mean=0.576,std=0.058 shared3_alpha:mean=0.619,std=0.060 eff_mlp_scale:[v0:103.2556 v1:76.2513 v2:80.0093 v3:83.4053 v4:73.8694 v5:85.7827 v6:75.8470 v7:71.8852 v8:70.3308 v9:180.9048] eff_attn_scale:[v0:0.4535 v1:0.7304 v2:0.7801 v3:0.8946 v4:1.2290 v5:0.6992 v6:0.7472 v7:0.7080 v8:0.9590 v9:1.8606] +step:3600/20000 val_loss:2.2495 val_bpb:1.3323 train_time:151120ms step_avg:41.98ms +step:3800/20000 train_loss:2.2950 train_time:159488ms step_avg:41.97ms +step:3800 shared0_alpha:mean=0.464,std=0.071 shared1_alpha:mean=0.567,std=0.055 shared2_alpha:mean=0.576,std=0.059 shared3_alpha:mean=0.619,std=0.061 eff_mlp_scale:[v0:105.8888 v1:78.0440 v2:82.3195 v3:85.8617 v4:76.2038 v5:87.6791 v6:78.1100 v7:73.7291 v8:72.1695 v9:184.7102] eff_attn_scale:[v0:0.4448 v1:0.7099 v2:0.7638 v3:0.8734 v4:1.1981 v5:0.6923 v6:0.7358 v7:0.6940 v8:0.9519 v9:1.8450] +step:3800/20000 val_loss:2.2457 val_bpb:1.3300 train_time:159501ms step_avg:41.97ms +step:4000/20000 train_loss:2.2279 train_time:167881ms step_avg:41.97ms +step:4000 shared0_alpha:mean=0.464,std=0.071 shared1_alpha:mean=0.568,std=0.056 shared2_alpha:mean=0.577,std=0.059 shared3_alpha:mean=0.619,std=0.060 eff_mlp_scale:[v0:108.6963 v1:79.8753 v2:84.2284 v3:87.9159 v4:77.5745 v5:89.6161 v6:79.9697 v7:75.6266 v8:73.9453 v9:189.0424] eff_attn_scale:[v0:0.4271 v1:0.6889 v2:0.7515 v3:0.8607 v4:1.1868 v5:0.6715 v6:0.7190 v7:0.6773 v8:0.9255 
v9:1.7928] +step:4000/20000 val_loss:2.2407 val_bpb:1.3271 train_time:167893ms step_avg:41.97ms +step:4200/20000 train_loss:2.2384 train_time:176328ms step_avg:41.98ms +step:4200 shared0_alpha:mean=0.462,std=0.072 shared1_alpha:mean=0.568,std=0.056 shared2_alpha:mean=0.578,std=0.060 shared3_alpha:mean=0.620,std=0.061 eff_mlp_scale:[v0:111.8727 v1:81.7018 v2:86.0845 v3:89.8335 v4:79.4021 v5:91.5454 v6:81.7803 v7:77.4097 v8:75.7303 v9:192.6684] eff_attn_scale:[v0:0.4143 v1:0.6758 v2:0.7305 v3:0.8495 v4:1.1672 v5:0.6502 v6:0.7031 v7:0.6721 v8:0.9197 v9:1.7639] +step:4200/20000 val_loss:2.2364 val_bpb:1.3245 train_time:176339ms step_avg:41.99ms +step:4400/20000 train_loss:2.1874 train_time:184703ms step_avg:41.98ms +step:4400 shared0_alpha:mean=0.460,std=0.074 shared1_alpha:mean=0.570,std=0.056 shared2_alpha:mean=0.580,std=0.060 shared3_alpha:mean=0.620,std=0.061 eff_mlp_scale:[v0:114.5383 v1:83.5632 v2:88.0133 v3:91.9787 v4:81.2585 v5:93.5112 v6:83.6610 v7:79.3922 v8:78.0081 v9:196.3895] eff_attn_scale:[v0:0.3973 v1:0.6725 v2:0.7118 v3:0.8297 v4:1.1723 v5:0.6510 v6:0.6848 v7:0.6675 v8:0.9292 v9:1.7499] +step:4400/20000 val_loss:2.2383 val_bpb:1.3257 train_time:184716ms step_avg:41.98ms +step:4600/20000 train_loss:2.0470 train_time:193090ms step_avg:41.98ms +step:4600 shared0_alpha:mean=0.459,std=0.074 shared1_alpha:mean=0.571,std=0.057 shared2_alpha:mean=0.581,std=0.061 shared3_alpha:mean=0.620,std=0.062 eff_mlp_scale:[v0:117.4054 v1:85.3231 v2:89.9985 v3:93.8894 v4:82.9846 v5:95.3611 v6:85.1073 v7:80.6862 v8:79.2339 v9:200.3763] eff_attn_scale:[v0:0.3932 v1:0.6511 v2:0.7141 v3:0.8264 v4:1.1470 v5:0.6219 v6:0.6827 v7:0.6555 v8:0.9122 v9:1.7280] +step:4600/20000 val_loss:2.2329 val_bpb:1.3224 train_time:193106ms step_avg:41.98ms +step:4800/20000 train_loss:2.3302 train_time:201463ms step_avg:41.97ms +step:4800 shared0_alpha:mean=0.458,std=0.074 shared1_alpha:mean=0.571,std=0.057 shared2_alpha:mean=0.582,std=0.062 shared3_alpha:mean=0.621,std=0.062 
eff_mlp_scale:[v0:119.8251 v1:87.0823 v2:91.8598 v3:96.0602 v4:84.8375 v5:97.2082 v6:86.9211 v7:82.6910 v8:81.5198 v9:203.1792] eff_attn_scale:[v0:0.3781 v1:0.6620 v2:0.7114 v3:0.8173 v4:1.1384 v5:0.6280 v6:0.6714 v7:0.6566 v8:0.9044 v9:1.7103] +step:4800/20000 val_loss:2.2288 val_bpb:1.3200 train_time:201476ms step_avg:41.97ms +step:5000/20000 train_loss:2.1016 train_time:209844ms step_avg:41.97ms +step:5000 shared0_alpha:mean=0.457,std=0.075 shared1_alpha:mean=0.572,std=0.057 shared2_alpha:mean=0.582,std=0.062 shared3_alpha:mean=0.621,std=0.063 eff_mlp_scale:[v0:123.1582 v1:88.4754 v2:93.8121 v3:98.0072 v4:86.1107 v5:98.7038 v6:88.3231 v7:84.0062 v8:83.2403 v9:206.8199] eff_attn_scale:[v0:0.3699 v1:0.6475 v2:0.7042 v3:0.8202 v4:1.1353 v5:0.6263 v6:0.6686 v7:0.6471 v8:0.9008 v9:1.6727] +step:5000/20000 val_loss:2.2233 val_bpb:1.3168 train_time:209857ms step_avg:41.97ms +step:5200/20000 train_loss:2.2431 train_time:218217ms step_avg:41.96ms +step:5200 shared0_alpha:mean=0.456,std=0.076 shared1_alpha:mean=0.573,std=0.058 shared2_alpha:mean=0.583,std=0.063 shared3_alpha:mean=0.621,std=0.062 eff_mlp_scale:[v0:125.7796 v1:90.4029 v2:95.8543 v3:100.7359 v4:88.1433 v5:100.7346 v6:90.3048 v7:86.0558 v8:85.2375 v9:210.6802] eff_attn_scale:[v0:0.3683 v1:0.6386 v2:0.7117 v3:0.8161 v4:1.1245 v5:0.6137 v6:0.6666 v7:0.6474 v8:0.8964 v9:1.6585] +step:5200/20000 val_loss:2.2249 val_bpb:1.3177 train_time:218230ms step_avg:41.97ms +step:5400/20000 train_loss:2.2559 train_time:226598ms step_avg:41.96ms +step:5400 shared0_alpha:mean=0.455,std=0.077 shared1_alpha:mean=0.574,std=0.058 shared2_alpha:mean=0.584,std=0.063 shared3_alpha:mean=0.622,std=0.063 eff_mlp_scale:[v0:128.4542 v1:92.2871 v2:97.3259 v3:102.2514 v4:89.6175 v5:102.7150 v6:92.2303 v7:87.9362 v8:87.1690 v9:213.5931] eff_attn_scale:[v0:0.3607 v1:0.6357 v2:0.6933 v3:0.8133 v4:1.1264 v5:0.6068 v6:0.6533 v7:0.6498 v8:0.9022 v9:1.6594] +step:5400/20000 val_loss:2.2196 val_bpb:1.3146 train_time:226611ms step_avg:41.97ms 
+step:5600/20000 train_loss:2.2544 train_time:234969ms step_avg:41.96ms +step:5600 shared0_alpha:mean=0.454,std=0.077 shared1_alpha:mean=0.575,std=0.058 shared2_alpha:mean=0.586,std=0.064 shared3_alpha:mean=0.622,std=0.064 eff_mlp_scale:[v0:131.2335 v1:93.3816 v2:99.3258 v3:104.2037 v4:91.5731 v5:104.3985 v6:94.1793 v7:89.7596 v8:89.0982 v9:217.2390] eff_attn_scale:[v0:0.3527 v1:0.6350 v2:0.6894 v3:0.8047 v4:1.1250 v5:0.5979 v6:0.6584 v7:0.6456 v8:0.9021 v9:1.6453] +step:5600/20000 val_loss:2.2199 val_bpb:1.3148 train_time:234982ms step_avg:41.96ms +step:5800/20000 train_loss:2.2195 train_time:243353ms step_avg:41.96ms +step:5800 shared0_alpha:mean=0.453,std=0.078 shared1_alpha:mean=0.575,std=0.059 shared2_alpha:mean=0.587,std=0.064 shared3_alpha:mean=0.622,std=0.064 eff_mlp_scale:[v0:133.8926 v1:95.1000 v2:101.3424 v3:106.0152 v4:93.4923 v5:106.1950 v6:95.6257 v7:90.9444 v8:90.4925 v9:221.0409] eff_attn_scale:[v0:0.3416 v1:0.6335 v2:0.6837 v3:0.8083 v4:1.1228 v5:0.6043 v6:0.6487 v7:0.6403 v8:0.8951 v9:1.6122] +step:5800/20000 val_loss:2.2170 val_bpb:1.3131 train_time:243366ms step_avg:41.96ms +step:6000/20000 train_loss:2.2895 train_time:251726ms step_avg:41.95ms +step:6000 shared0_alpha:mean=0.452,std=0.078 shared1_alpha:mean=0.575,std=0.059 shared2_alpha:mean=0.586,std=0.065 shared3_alpha:mean=0.621,std=0.064 eff_mlp_scale:[v0:136.6514 v1:97.0565 v2:102.7869 v3:108.7354 v4:94.8564 v5:107.7220 v6:97.5426 v7:92.9767 v8:92.3336 v9:222.7039] eff_attn_scale:[v0:0.3372 v1:0.6204 v2:0.6771 v3:0.8025 v4:1.1208 v5:0.5878 v6:0.6423 v7:0.6267 v8:0.8977 v9:1.6507] +step:6000/20000 val_loss:2.2137 val_bpb:1.3111 train_time:251739ms step_avg:41.96ms +step:6200/20000 train_loss:2.1663 train_time:260108ms step_avg:41.95ms +step:6200 shared0_alpha:mean=0.451,std=0.079 shared1_alpha:mean=0.577,std=0.059 shared2_alpha:mean=0.588,std=0.064 shared3_alpha:mean=0.622,std=0.065 eff_mlp_scale:[v0:139.2796 v1:98.3637 v2:104.9485 v3:111.0157 v4:96.3790 v5:109.6514 v6:99.1180 v7:95.0804 
v8:94.3392 v9:226.4932] eff_attn_scale:[v0:0.3339 v1:0.6266 v2:0.6778 v3:0.8038 v4:1.1213 v5:0.5897 v6:0.6343 v7:0.6367 v8:0.8886 v9:1.5627] +step:6200/20000 val_loss:2.2149 val_bpb:1.3118 train_time:260121ms step_avg:41.96ms +step:6400/20000 train_loss:2.2362 train_time:268481ms step_avg:41.95ms +step:6400 shared0_alpha:mean=0.451,std=0.079 shared1_alpha:mean=0.578,std=0.060 shared2_alpha:mean=0.589,std=0.064 shared3_alpha:mean=0.622,std=0.065 eff_mlp_scale:[v0:142.1748 v1:100.2448 v2:106.8843 v3:112.4559 v4:98.4252 v5:111.6239 v6:101.0056 v7:96.3908 v8:96.3639 v9:230.0691] eff_attn_scale:[v0:0.3275 v1:0.6163 v2:0.6746 v3:0.7938 v4:1.1108 v5:0.5877 v6:0.6398 v7:0.6314 v8:0.8887 v9:1.5966] +step:6400/20000 val_loss:2.2108 val_bpb:1.3094 train_time:268495ms step_avg:41.95ms +step:6600/20000 train_loss:2.2007 train_time:276863ms step_avg:41.95ms +step:6600 shared0_alpha:mean=0.450,std=0.080 shared1_alpha:mean=0.579,std=0.061 shared2_alpha:mean=0.590,std=0.065 shared3_alpha:mean=0.622,std=0.064 eff_mlp_scale:[v0:144.9055 v1:101.7217 v2:108.9018 v3:114.7212 v4:99.7096 v5:113.2064 v6:102.4324 v7:98.4870 v8:97.6323 v9:233.8402] eff_attn_scale:[v0:0.3225 v1:0.6182 v2:0.6853 v3:0.7940 v4:1.1192 v5:0.5856 v6:0.6372 v7:0.6361 v8:0.9060 v9:1.5772] +step:6600/20000 val_loss:2.2066 val_bpb:1.3069 train_time:276876ms step_avg:41.95ms +step:6800/20000 train_loss:2.2652 train_time:285244ms step_avg:41.95ms +step:6800 shared0_alpha:mean=0.449,std=0.080 shared1_alpha:mean=0.580,std=0.060 shared2_alpha:mean=0.590,std=0.066 shared3_alpha:mean=0.622,std=0.065 eff_mlp_scale:[v0:147.4638 v1:103.5015 v2:110.5698 v3:116.8049 v4:101.4442 v5:115.0629 v6:104.5784 v7:100.4304 v8:99.8673 v9:235.8078] eff_attn_scale:[v0:0.3154 v1:0.6084 v2:0.6645 v3:0.8014 v4:1.1305 v5:0.5721 v6:0.6213 v7:0.6338 v8:0.9065 v9:1.5636] +step:6800/20000 val_loss:2.2063 val_bpb:1.3067 train_time:285257ms step_avg:41.95ms +step:7000/20000 train_loss:2.2983 train_time:293623ms step_avg:41.95ms +step:7000 
shared0_alpha:mean=0.447,std=0.081 shared1_alpha:mean=0.581,std=0.060 shared2_alpha:mean=0.591,std=0.066 shared3_alpha:mean=0.622,std=0.065 eff_mlp_scale:[v0:149.9792 v1:104.9802 v2:112.5973 v3:118.3586 v4:103.4970 v5:116.6446 v6:106.5555 v7:102.3940 v8:101.9048 v9:239.6004] eff_attn_scale:[v0:0.3089 v1:0.6178 v2:0.6717 v3:0.8121 v4:1.1295 v5:0.5772 v6:0.6327 v7:0.6342 v8:0.9046 v9:1.5582] +step:7000/20000 val_loss:2.2038 val_bpb:1.3052 train_time:293636ms step_avg:41.95ms +step:7200/20000 train_loss:2.2695 train_time:302000ms step_avg:41.94ms +step:7200 shared0_alpha:mean=0.447,std=0.081 shared1_alpha:mean=0.581,std=0.061 shared2_alpha:mean=0.592,std=0.066 shared3_alpha:mean=0.622,std=0.065 eff_mlp_scale:[v0:152.7840 v1:106.1073 v2:113.9995 v3:120.6499 v4:104.9006 v5:117.8350 v6:107.9121 v7:103.9702 v8:103.8302 v9:243.3291] eff_attn_scale:[v0:0.3085 v1:0.6089 v2:0.6770 v3:0.8133 v4:1.1278 v5:0.5724 v6:0.6333 v7:0.6525 v8:0.9087 v9:1.5478] +step:7200/20000 val_loss:2.2053 val_bpb:1.3061 train_time:302013ms step_avg:41.95ms +step:7400/20000 train_loss:2.1883 train_time:310378ms step_avg:41.94ms +step:7400 shared0_alpha:mean=0.446,std=0.081 shared1_alpha:mean=0.582,std=0.062 shared2_alpha:mean=0.592,std=0.066 shared3_alpha:mean=0.622,std=0.066 eff_mlp_scale:[v0:154.9475 v1:108.1855 v2:115.9881 v3:122.4672 v4:106.7991 v5:120.0183 v6:109.2964 v7:105.6139 v8:105.7204 v9:245.1474] eff_attn_scale:[v0:0.2980 v1:0.6014 v2:0.6711 v3:0.8051 v4:1.1237 v5:0.5693 v6:0.6148 v7:0.6287 v8:0.9054 v9:1.5596] +step:7400/20000 val_loss:2.2011 val_bpb:1.3036 train_time:310391ms step_avg:41.94ms +step:7600/20000 train_loss:2.0699 train_time:318748ms step_avg:41.94ms +step:7600 shared0_alpha:mean=0.444,std=0.082 shared1_alpha:mean=0.583,std=0.061 shared2_alpha:mean=0.593,std=0.067 shared3_alpha:mean=0.622,std=0.066 eff_mlp_scale:[v0:157.7380 v1:109.5229 v2:118.2014 v3:124.4992 v4:108.3247 v5:121.4399 v6:111.4470 v7:107.5220 v8:107.7803 v9:248.8188] eff_attn_scale:[v0:0.3015 v1:0.6015 
v2:0.6636 v3:0.8072 v4:1.1394 v5:0.5772 v6:0.6205 v7:0.6349 v8:0.9190 v9:1.5569] +step:7600/20000 val_loss:2.1994 val_bpb:1.3026 train_time:318761ms step_avg:41.94ms +step:7800/20000 train_loss:2.2236 train_time:327130ms step_avg:41.94ms +step:7800 shared0_alpha:mean=0.444,std=0.082 shared1_alpha:mean=0.584,std=0.062 shared2_alpha:mean=0.594,std=0.067 shared3_alpha:mean=0.622,std=0.066 eff_mlp_scale:[v0:160.6460 v1:110.8519 v2:120.4601 v3:126.9106 v4:110.4620 v5:123.4228 v6:113.0734 v7:109.1888 v8:109.3629 v9:252.8788] eff_attn_scale:[v0:0.2943 v1:0.6005 v2:0.6670 v3:0.8239 v4:1.1389 v5:0.5725 v6:0.6240 v7:0.6464 v8:0.9240 v9:1.5660] +step:7800/20000 val_loss:2.1971 val_bpb:1.3012 train_time:327142ms step_avg:41.94ms +step:8000/20000 train_loss:2.1847 train_time:335498ms step_avg:41.94ms +step:8000 shared0_alpha:mean=0.443,std=0.083 shared1_alpha:mean=0.584,std=0.062 shared2_alpha:mean=0.595,std=0.067 shared3_alpha:mean=0.622,std=0.066 eff_mlp_scale:[v0:163.4999 v1:112.7380 v2:122.0465 v3:129.2139 v4:111.9447 v5:125.3923 v6:115.1707 v7:111.3316 v8:111.9447 v9:254.5000] eff_attn_scale:[v0:0.2938 v1:0.6067 v2:0.6662 v3:0.8295 v4:1.1500 v5:0.5666 v6:0.6273 v7:0.6462 v8:0.9297 v9:1.5429] +step:8000/20000 val_loss:2.1942 val_bpb:1.2995 train_time:335510ms step_avg:41.94ms +step:8200/20000 train_loss:2.2537 train_time:343874ms step_avg:41.94ms +step:8200 shared0_alpha:mean=0.443,std=0.083 shared1_alpha:mean=0.585,std=0.062 shared2_alpha:mean=0.595,std=0.068 shared3_alpha:mean=0.622,std=0.066 eff_mlp_scale:[v0:166.2311 v1:114.0382 v2:124.1352 v3:130.8591 v4:113.4892 v5:126.7735 v6:116.6293 v7:112.8296 v8:113.4892 v9:258.3551] eff_attn_scale:[v0:0.2895 v1:0.6162 v2:0.6702 v3:0.8182 v4:1.1583 v5:0.5716 v6:0.6269 v7:0.6491 v8:0.9374 v9:1.5568] +step:8200/20000 val_loss:2.1942 val_bpb:1.2995 train_time:343886ms step_avg:41.94ms +step:8400/20000 train_loss:2.2048 train_time:352310ms step_avg:41.94ms +step:8400 shared0_alpha:mean=0.442,std=0.084 
shared1_alpha:mean=0.586,std=0.063 shared2_alpha:mean=0.596,std=0.068 shared3_alpha:mean=0.621,std=0.066 eff_mlp_scale:[v0:168.4852 v1:115.3193 v2:125.6406 v3:132.3745 v4:115.6404 v5:128.1325 v6:118.6605 v7:114.8027 v8:115.6404 v9:259.9239] eff_attn_scale:[v0:0.2868 v1:0.6171 v2:0.6775 v3:0.8300 v4:1.1688 v5:0.5727 v6:0.6210 v7:0.6466 v8:0.9459 v9:1.5416] +step:8400/20000 val_loss:2.1934 val_bpb:1.2991 train_time:352323ms step_avg:41.94ms +step:8600/20000 train_loss:2.2054 train_time:360678ms step_avg:41.94ms +step:8600 shared0_alpha:mean=0.441,std=0.083 shared1_alpha:mean=0.587,std=0.063 shared2_alpha:mean=0.597,std=0.069 shared3_alpha:mean=0.621,std=0.066 eff_mlp_scale:[v0:171.4950 v1:116.6192 v2:127.6812 v3:134.7460 v4:116.9804 v5:130.0977 v6:120.0672 v7:116.4253 v8:117.5483 v9:263.5822] eff_attn_scale:[v0:0.2839 v1:0.6051 v2:0.6627 v3:0.8246 v4:1.1713 v5:0.5651 v6:0.6197 v7:0.6469 v8:0.9425 v9:1.5392] +step:8600/20000 val_loss:2.1905 val_bpb:1.2973 train_time:360691ms step_avg:41.94ms +step:8800/20000 train_loss:2.1802 train_time:369055ms step_avg:41.94ms +step:8800 shared0_alpha:mean=0.440,std=0.085 shared1_alpha:mean=0.588,std=0.062 shared2_alpha:mean=0.598,std=0.069 shared3_alpha:mean=0.621,std=0.066 eff_mlp_scale:[v0:174.3474 v1:118.6889 v2:129.3479 v3:136.3904 v4:118.6107 v5:131.6797 v6:122.2603 v7:118.5227 v8:119.1837 v9:265.1349] eff_attn_scale:[v0:0.2830 v1:0.6026 v2:0.6718 v3:0.8478 v4:1.1911 v5:0.5627 v6:0.6159 v7:0.6589 v8:0.9661 v9:1.5692] +step:8800/20000 val_loss:2.1904 val_bpb:1.2973 train_time:369069ms step_avg:41.94ms +step:9000/20000 train_loss:2.0941 train_time:377426ms step_avg:41.94ms +step:9000 shared0_alpha:mean=0.440,std=0.085 shared1_alpha:mean=0.589,std=0.063 shared2_alpha:mean=0.598,std=0.069 shared3_alpha:mean=0.621,std=0.067 eff_mlp_scale:[v0:177.4487 v1:119.8720 v2:131.4997 v3:138.7691 v4:120.7066 v5:132.9273 v6:124.3595 v7:120.7471 v8:121.2841 v9:269.0922] eff_attn_scale:[v0:0.2795 v1:0.6009 v2:0.6624 v3:0.8489 v4:1.2087 v5:0.5608 
v6:0.6194 v7:0.6633 v8:0.9714 v9:1.5581] +step:9000/20000 val_loss:2.1896 val_bpb:1.2968 train_time:377438ms step_avg:41.94ms +step:9200/20000 train_loss:2.1581 train_time:385799ms step_avg:41.93ms +step:9200 shared0_alpha:mean=0.439,std=0.085 shared1_alpha:mean=0.589,std=0.064 shared2_alpha:mean=0.599,std=0.069 shared3_alpha:mean=0.621,std=0.066 eff_mlp_scale:[v0:179.6053 v1:121.9023 v2:133.2203 v3:141.0824 v4:122.2778 v5:135.0486 v6:126.0192 v7:122.3118 v8:123.4423 v9:270.8719] eff_attn_scale:[v0:0.2759 v1:0.6093 v2:0.6736 v3:0.8503 v4:1.2270 v5:0.5687 v6:0.6218 v7:0.6655 v8:0.9816 v9:1.5471] +step:9200/20000 val_loss:2.1884 val_bpb:1.2961 train_time:385813ms step_avg:41.94ms +step:9400/20000 train_loss:2.2162 train_time:394173ms step_avg:41.93ms +step:9400 shared0_alpha:mean=0.438,std=0.085 shared1_alpha:mean=0.591,std=0.064 shared2_alpha:mean=0.599,std=0.069 shared3_alpha:mean=0.621,std=0.066 eff_mlp_scale:[v0:182.2212 v1:123.2062 v2:135.3925 v3:143.5050 v4:123.6642 v5:136.4283 v6:127.5349 v7:123.9639 v8:124.8364 v9:274.3891] eff_attn_scale:[v0:0.2729 v1:0.6089 v2:0.6704 v3:0.8654 v4:1.2214 v5:0.5646 v6:0.6272 v7:0.6736 v8:0.9882 v9:1.5588] +step:9400/20000 val_loss:2.1865 val_bpb:1.2949 train_time:394186ms step_avg:41.93ms +step:9600/20000 train_loss:2.2211 train_time:402548ms step_avg:41.93ms +step:9600 shared0_alpha:mean=0.437,std=0.085 shared1_alpha:mean=0.591,std=0.064 shared2_alpha:mean=0.600,std=0.070 shared3_alpha:mean=0.620,std=0.067 eff_mlp_scale:[v0:184.4105 v1:124.5875 v2:136.8696 v3:145.1870 v4:125.8572 v5:138.4978 v6:128.9616 v7:126.1159 v8:127.0390 v9:275.9173] eff_attn_scale:[v0:0.2700 v1:0.6013 v2:0.6613 v3:0.8573 v4:1.2390 v5:0.5615 v6:0.6227 v7:0.6673 v8:0.9979 v9:1.5360] +step:9600/20000 val_loss:2.1865 val_bpb:1.2950 train_time:402561ms step_avg:41.93ms +step:9800/20000 train_loss:2.1535 train_time:410918ms step_avg:41.93ms +step:9800 shared0_alpha:mean=0.438,std=0.086 shared1_alpha:mean=0.593,std=0.064 shared2_alpha:mean=0.601,std=0.070 
shared3_alpha:mean=0.621,std=0.067 eff_mlp_scale:[v0:188.5993 v1:126.5279 v2:139.2723 v3:147.0432 v4:127.3870 v5:139.9106 v6:131.2964 v7:127.8097 v8:129.1728 v9:279.5072] eff_attn_scale:[v0:0.2653 v1:0.5916 v2:0.6657 v3:0.8726 v4:1.2435 v5:0.5636 v6:0.6228 v7:0.6756 v8:1.0026 v9:1.5500] +step:9800/20000 val_loss:2.1881 val_bpb:1.2959 train_time:410932ms step_avg:41.93ms +step:10000/20000 train_loss:2.1883 train_time:419295ms step_avg:41.93ms +step:10000 shared0_alpha:mean=0.437,std=0.087 shared1_alpha:mean=0.593,std=0.064 shared2_alpha:mean=0.601,std=0.070 shared3_alpha:mean=0.619,std=0.067 eff_mlp_scale:[v0:190.6008 v1:127.9005 v2:141.4933 v3:149.3115 v4:129.4203 v5:141.9757 v6:132.8431 v7:129.3200 v8:130.6187 v9:283.5366] eff_attn_scale:[v0:0.2655 v1:0.6072 v2:0.6826 v3:0.8902 v4:1.2874 v5:0.5670 v6:0.6261 v7:0.6913 v8:1.0414 v9:1.5503] +step:10000/20000 val_loss:2.1851 val_bpb:1.2941 train_time:419308ms step_avg:41.93ms +step:10200/20000 train_loss:2.1412 train_time:427668ms step_avg:41.93ms +step:10200 shared0_alpha:mean=0.436,std=0.087 shared1_alpha:mean=0.594,std=0.065 shared2_alpha:mean=0.602,std=0.071 shared3_alpha:mean=0.619,std=0.067 eff_mlp_scale:[v0:193.8341 v1:129.8628 v2:143.1271 v3:151.7479 v4:130.9831 v5:143.4030 v6:134.4150 v7:131.5988 v8:132.7939 v9:285.0893] eff_attn_scale:[v0:0.2664 v1:0.6084 v2:0.6780 v3:0.8900 v4:1.2855 v5:0.5681 v6:0.6261 v7:0.6828 v8:1.0352 v9:1.5378] +step:10200/20000 val_loss:2.1814 val_bpb:1.2920 train_time:427681ms step_avg:41.93ms +step:10400/20000 train_loss:2.1755 train_time:436044ms step_avg:41.93ms +step:10400 shared0_alpha:mean=0.436,std=0.087 shared1_alpha:mean=0.594,std=0.065 shared2_alpha:mean=0.603,std=0.071 shared3_alpha:mean=0.619,std=0.067 eff_mlp_scale:[v0:195.9251 v1:131.2278 v2:145.3228 v3:153.7719 v4:132.4698 v5:145.4648 v6:136.5533 v7:133.4384 v8:134.2927 v9:286.7764] eff_attn_scale:[v0:0.2630 v1:0.6124 v2:0.6882 v3:0.8876 v4:1.3003 v5:0.5678 v6:0.6233 v7:0.6930 v8:1.0529 v9:1.5495] +step:10400/20000 
val_loss:2.1811 val_bpb:1.2918 train_time:436057ms step_avg:41.93ms +step:10600/20000 train_loss:2.0524 train_time:444420ms step_avg:41.93ms +step:10600 shared0_alpha:mean=0.435,std=0.088 shared1_alpha:mean=0.595,std=0.066 shared2_alpha:mean=0.603,std=0.071 shared3_alpha:mean=0.618,std=0.067 eff_mlp_scale:[v0:199.0776 v1:132.5853 v2:146.9832 v3:156.0367 v4:134.6098 v5:146.9020 v6:138.1516 v7:135.5728 v8:136.4454 v9:290.2270] eff_attn_scale:[v0:0.2623 v1:0.6137 v2:0.6832 v3:0.9047 v4:1.3264 v5:0.5731 v6:0.6270 v7:0.7037 v8:1.0658 v9:1.5493] +step:10600/20000 val_loss:2.1819 val_bpb:1.2922 train_time:444433ms step_avg:41.93ms +step:10800/20000 train_loss:2.2563 train_time:452797ms step_avg:41.93ms +step:10800 shared0_alpha:mean=0.435,std=0.088 shared1_alpha:mean=0.596,std=0.066 shared2_alpha:mean=0.604,std=0.071 shared3_alpha:mean=0.618,std=0.067 eff_mlp_scale:[v0:202.1908 v1:133.8393 v2:148.5420 v3:157.9083 v4:136.2487 v5:148.2239 v6:140.2896 v7:137.2836 v8:138.7147 v9:293.7970] eff_attn_scale:[v0:0.2592 v1:0.6173 v2:0.6848 v3:0.9091 v4:1.3417 v5:0.5645 v6:0.6242 v7:0.6985 v8:1.0734 v9:1.5514] +step:10800/20000 val_loss:2.1804 val_bpb:1.2914 train_time:452810ms step_avg:41.93ms +step:11000/20000 train_loss:2.1897 train_time:461171ms step_avg:41.92ms +step:11000 shared0_alpha:mean=0.434,std=0.088 shared1_alpha:mean=0.596,std=0.066 shared2_alpha:mean=0.604,std=0.071 shared3_alpha:mean=0.618,std=0.068 eff_mlp_scale:[v0:205.3773 v1:135.9211 v2:150.9607 v3:160.1435 v4:137.6326 v5:149.7649 v6:142.0054 v7:138.7478 v8:140.1125 v9:295.2768] eff_attn_scale:[v0:0.2569 v1:0.6203 v2:0.6910 v3:0.9157 v4:1.3568 v5:0.5676 v6:0.6306 v7:0.7133 v8:1.0925 v9:1.5525] +step:11000/20000 val_loss:2.1781 val_bpb:1.2900 train_time:461184ms step_avg:41.93ms +step:11200/20000 train_loss:2.1423 train_time:469538ms step_avg:41.92ms +step:11200 shared0_alpha:mean=0.433,std=0.089 shared1_alpha:mean=0.597,std=0.066 shared2_alpha:mean=0.605,std=0.071 shared3_alpha:mean=0.617,std=0.068 
eff_mlp_scale:[v0:208.4357 v1:137.3971 v2:153.2656 v3:162.6382 v4:140.0990 v5:151.9600 v6:143.6060 v7:140.4305 v8:141.9753 v9:296.7932] eff_attn_scale:[v0:0.2562 v1:0.6204 v2:0.6923 v3:0.9292 v4:1.3700 v5:0.5718 v6:0.6357 v7:0.7163 v8:1.0995 v9:1.5507] +step:11200/20000 val_loss:2.1779 val_bpb:1.2899 train_time:469552ms step_avg:41.92ms +step:11400/20000 train_loss:2.1270 train_time:477908ms step_avg:41.92ms +step:11400 shared0_alpha:mean=0.433,std=0.090 shared1_alpha:mean=0.598,std=0.066 shared2_alpha:mean=0.605,std=0.072 shared3_alpha:mean=0.617,std=0.068 eff_mlp_scale:[v0:209.9052 v1:138.7832 v2:155.0351 v3:164.6645 v4:141.5198 v5:153.4254 v6:145.9535 v7:142.9288 v8:144.0357 v9:300.2969] eff_attn_scale:[v0:0.2553 v1:0.6234 v2:0.7037 v3:0.9325 v4:1.3964 v5:0.5748 v6:0.6377 v7:0.7274 v8:1.1172 v9:1.5367] +step:11400/20000 val_loss:2.1785 val_bpb:1.2902 train_time:477921ms step_avg:41.92ms +step:11600/20000 train_loss:2.1360 train_time:486292ms step_avg:41.92ms +step:11600 shared0_alpha:mean=0.432,std=0.090 shared1_alpha:mean=0.598,std=0.066 shared2_alpha:mean=0.606,std=0.072 shared3_alpha:mean=0.617,std=0.068 eff_mlp_scale:[v0:212.7506 v1:140.6871 v2:157.3776 v3:167.0755 v4:143.0041 v5:154.7558 v6:147.5824 v7:144.5335 v8:146.1679 v9:301.6261] eff_attn_scale:[v0:0.2528 v1:0.6205 v2:0.6993 v3:0.9456 v4:1.3951 v5:0.5681 v6:0.6333 v7:0.7274 v8:1.1220 v9:1.5469] +step:11600/20000 val_loss:2.1761 val_bpb:1.2888 train_time:486309ms step_avg:41.92ms +step:11800/20000 train_loss:2.1652 train_time:494719ms step_avg:41.93ms +step:11800 shared0_alpha:mean=0.432,std=0.090 shared1_alpha:mean=0.600,std=0.067 shared2_alpha:mean=0.607,std=0.072 shared3_alpha:mean=0.616,std=0.068 eff_mlp_scale:[v0:215.6969 v1:142.2938 v2:159.0862 v3:168.9564 v4:145.3677 v5:157.1027 v6:149.8829 v7:146.9186 v8:147.9180 v9:305.1757] eff_attn_scale:[v0:0.2500 v1:0.6234 v2:0.6988 v3:0.9563 v4:1.4141 v5:0.5786 v6:0.6372 v7:0.7405 v8:1.1444 v9:1.5507] +step:11800/20000 val_loss:2.1747 val_bpb:1.2880 
train_time:494732ms step_avg:41.93ms +step:12000/20000 train_loss:2.1471 train_time:503094ms step_avg:41.92ms +step:12000 shared0_alpha:mean=0.432,std=0.091 shared1_alpha:mean=0.600,std=0.067 shared2_alpha:mean=0.607,std=0.073 shared3_alpha:mean=0.615,std=0.068 eff_mlp_scale:[v0:218.6501 v1:143.5129 v2:160.8276 v3:170.7656 v4:146.7324 v5:159.0278 v6:150.9000 v7:147.9072 v8:149.2954 v9:306.8359] eff_attn_scale:[v0:0.2492 v1:0.6292 v2:0.7046 v3:0.9730 v4:1.4264 v5:0.5724 v6:0.6430 v7:0.7469 v8:1.1387 v9:1.5190] +step:12000/20000 val_loss:2.1739 val_bpb:1.2875 train_time:503107ms step_avg:41.93ms +step:12200/20000 train_loss:2.2911 train_time:511461ms step_avg:41.92ms +step:12200 shared0_alpha:mean=0.432,std=0.091 shared1_alpha:mean=0.601,std=0.067 shared2_alpha:mean=0.607,std=0.073 shared3_alpha:mean=0.615,std=0.068 eff_mlp_scale:[v0:220.2967 v1:145.4833 v2:163.1192 v3:173.3579 v4:148.5036 v5:160.4213 v6:152.4665 v7:150.3338 v8:151.7319 v9:308.1884] eff_attn_scale:[v0:0.2527 v1:0.6230 v2:0.7048 v3:0.9810 v4:1.4532 v5:0.5704 v6:0.6387 v7:0.7542 v8:1.1517 v9:1.5373] +step:12200/20000 val_loss:2.1737 val_bpb:1.2874 train_time:511474ms step_avg:41.92ms +step:12400/20000 train_loss:1.9332 train_time:519892ms step_avg:41.93ms +step:12400 shared0_alpha:mean=0.431,std=0.092 shared1_alpha:mean=0.601,std=0.067 shared2_alpha:mean=0.608,std=0.073 shared3_alpha:mean=0.614,std=0.069 eff_mlp_scale:[v0:223.4694 v1:146.9867 v2:165.0745 v3:176.1563 v4:150.6810 v5:162.0120 v6:154.3380 v7:152.2591 v8:153.2789 v9:311.9421] eff_attn_scale:[v0:0.2465 v1:0.6349 v2:0.7225 v3:0.9867 v4:1.4718 v5:0.5813 v6:0.6511 v7:0.7625 v8:1.1738 v9:1.5320] +step:12400/20000 val_loss:2.1739 val_bpb:1.2875 train_time:519904ms step_avg:41.93ms +step:12600/20000 train_loss:2.1676 train_time:528270ms step_avg:41.93ms +step:12600 shared0_alpha:mean=0.431,std=0.092 shared1_alpha:mean=0.602,std=0.067 shared2_alpha:mean=0.609,std=0.073 shared3_alpha:mean=0.614,std=0.069 eff_mlp_scale:[v0:226.6108 v1:149.1907 
v2:167.6873 v3:177.3695 v4:152.4049 v5:163.6497 v6:156.1926 v7:153.9952 v8:155.6754 v9:313.4867] eff_attn_scale:[v0:0.2459 v1:0.6337 v2:0.7215 v3:0.9996 v4:1.4914 v5:0.5723 v6:0.6462 v7:0.7597 v8:1.1968 v9:1.5527] +step:12600/20000 val_loss:2.1748 val_bpb:1.2880 train_time:528283ms step_avg:41.93ms +step:12800/20000 train_loss:2.1854 train_time:536637ms step_avg:41.92ms +step:12800 shared0_alpha:mean=0.430,std=0.092 shared1_alpha:mean=0.603,std=0.067 shared2_alpha:mean=0.610,std=0.073 shared3_alpha:mean=0.613,std=0.069 eff_mlp_scale:[v0:229.9439 v1:150.6834 v2:169.4674 v3:180.0966 v4:154.0392 v5:165.2230 v6:158.5780 v7:156.5455 v8:157.9889 v9:314.8180] eff_attn_scale:[v0:0.2433 v1:0.6289 v2:0.7146 v3:0.9915 v4:1.4973 v5:0.5802 v6:0.6484 v7:0.7661 v8:1.2138 v9:1.5374] +step:12800/20000 val_loss:2.1732 val_bpb:1.2871 train_time:536650ms step_avg:41.93ms +step:13000/20000 train_loss:2.2741 train_time:545021ms step_avg:41.92ms +step:13000 shared0_alpha:mean=0.430,std=0.092 shared1_alpha:mean=0.603,std=0.067 shared2_alpha:mean=0.610,std=0.074 shared3_alpha:mean=0.613,std=0.069 eff_mlp_scale:[v0:233.3033 v1:152.2239 v2:171.3557 v3:182.5361 v4:155.6426 v5:167.5128 v6:160.3889 v7:158.1515 v8:159.6165 v9:318.3016] eff_attn_scale:[v0:0.2426 v1:0.6335 v2:0.7221 v3:1.0170 v4:1.5092 v5:0.5804 v6:0.6512 v7:0.7831 v8:1.2185 v9:1.5572] +step:13000/20000 val_loss:2.1736 val_bpb:1.2874 train_time:545034ms step_avg:41.93ms +step:13200/20000 train_loss:2.2730 train_time:553391ms step_avg:41.92ms +step:13200 shared0_alpha:mean=0.429,std=0.093 shared1_alpha:mean=0.604,std=0.068 shared2_alpha:mean=0.610,std=0.074 shared3_alpha:mean=0.612,std=0.069 eff_mlp_scale:[v0:234.7893 v1:153.5466 v2:172.9665 v3:183.5978 v4:157.0371 v5:168.9013 v6:161.9407 v7:159.7721 v8:161.0296 v9:319.7378] eff_attn_scale:[v0:0.2394 v1:0.6437 v2:0.7330 v3:1.0360 v4:1.5258 v5:0.5859 v6:0.6570 v7:0.7922 v8:1.2281 v9:1.5425] +step:13200/20000 val_loss:2.1665 val_bpb:1.2831 train_time:553404ms step_avg:41.92ms 
+step:13400/20000 train_loss:2.1402 train_time:561768ms step_avg:41.92ms +step:13400 shared0_alpha:mean=0.429,std=0.093 shared1_alpha:mean=0.604,std=0.068 shared2_alpha:mean=0.610,std=0.074 shared3_alpha:mean=0.611,std=0.069 eff_mlp_scale:[v0:237.8495 v1:154.9440 v2:175.5435 v3:185.1381 v4:158.8945 v5:170.3713 v6:163.7481 v7:161.1126 v8:163.5876 v9:323.5530] eff_attn_scale:[v0:0.2397 v1:0.6401 v2:0.7363 v3:1.0415 v4:1.5591 v5:0.5827 v6:0.6560 v7:0.7965 v8:1.2510 v9:1.5891] +step:13400/20000 val_loss:2.1596 val_bpb:1.2790 train_time:561782ms step_avg:41.92ms +step:13600/20000 train_loss:2.0052 train_time:570145ms step_avg:41.92ms +step:13600 shared0_alpha:mean=0.428,std=0.092 shared1_alpha:mean=0.604,std=0.068 shared2_alpha:mean=0.609,std=0.074 shared3_alpha:mean=0.610,std=0.069 eff_mlp_scale:[v0:238.8668 v1:155.6532 v2:176.1908 v3:187.4647 v4:159.6987 v5:171.1512 v6:164.3519 v7:162.6114 v8:164.4155 v9:325.3213] eff_attn_scale:[v0:0.2401 v1:0.6453 v2:0.7529 v3:1.0614 v4:1.5800 v5:0.5915 v6:0.6667 v7:0.8038 v8:1.2691 v9:1.5988] +step:13600/20000 val_loss:2.1549 val_bpb:1.2762 train_time:570158ms step_avg:41.92ms +step:13800/20000 train_loss:2.0860 train_time:578517ms step_avg:41.92ms +step:13800 shared0_alpha:mean=0.427,std=0.093 shared1_alpha:mean=0.603,std=0.068 shared2_alpha:mean=0.608,std=0.074 shared3_alpha:mean=0.609,std=0.069 eff_mlp_scale:[v0:239.3895 v1:156.7439 v2:176.7584 v3:188.1676 v4:160.4834 v5:171.6076 v6:164.8814 v7:163.2212 v8:165.2234 v9:328.9007] eff_attn_scale:[v0:0.2391 v1:0.6389 v2:0.7390 v3:1.0654 v4:1.6034 v5:0.5815 v6:0.6624 v7:0.8107 v8:1.2892 v9:1.5964] +step:13800/20000 val_loss:2.1425 val_bpb:1.2689 train_time:578530ms step_avg:41.92ms +step:14000/20000 train_loss:2.1419 train_time:586885ms step_avg:41.92ms +step:14000 shared0_alpha:mean=0.426,std=0.093 shared1_alpha:mean=0.602,std=0.068 shared2_alpha:mean=0.608,std=0.074 shared3_alpha:mean=0.609,std=0.069 eff_mlp_scale:[v0:239.8123 v1:157.1295 v2:177.0803 v3:188.6180 v4:161.6570 
v5:172.0297 v6:165.1816 v7:163.6119 v8:165.7324 v9:330.1108] eff_attn_scale:[v0:0.2399 v1:0.6445 v2:0.7383 v3:1.0651 v4:1.6075 v5:0.5867 v6:0.6613 v7:0.8105 v8:1.2988 v9:1.6073] +step:14000/20000 val_loss:2.1349 val_bpb:1.2644 train_time:586898ms step_avg:41.92ms +step:14200/20000 train_loss:2.2152 train_time:595264ms step_avg:41.92ms +step:14200 shared0_alpha:mean=0.426,std=0.092 shared1_alpha:mean=0.602,std=0.068 shared2_alpha:mean=0.607,std=0.074 shared3_alpha:mean=0.608,std=0.069 eff_mlp_scale:[v0:239.7280 v1:157.2351 v2:177.1942 v3:188.8905 v4:161.8565 v5:172.1453 v6:165.2878 v7:163.8482 v8:165.9369 v9:331.0954] eff_attn_scale:[v0:0.2400 v1:0.6392 v2:0.7392 v3:1.0679 v4:1.6048 v5:0.5815 v6:0.6621 v7:0.8074 v8:1.2967 v9:1.6061] +step:14200/20000 val_loss:2.1273 val_bpb:1.2599 train_time:595278ms step_avg:41.92ms +step:14313/20000 val_loss:2.1240 val_bpb:1.2580 train_time:600024ms step_avg:41.92ms +stopping_early: wallclock_cap train_time:600024ms step:14313/20000 +peak memory allocated: 9924 MiB reserved: 10120 MiB +Serialized model: 45170746 bytes +Code size: 57024 bytes +Total submission size: 45227770 bytes +Serialized model int8+zlib: 10717849 bytes (payload:11630784 raw_torch:11661879 payload_ratio:3.88x) +Total submission size int8+zlib: 10774873 bytes +final_int8_zlib_roundtrip val_loss:2.1389 val_bpb:1.2668 eval_time:1332ms +final_int8_zlib_roundtrip_exact val_loss:2.13886987 val_bpb:1.26675921 diff --git a/records/track_non_record_16mb/2026-03-26_RecurrenceFix_3Loop_Birkhoff_OutputLN_TimestepScale/logs/s2_J.txt b/records/track_non_record_16mb/2026-03-26_RecurrenceFix_3Loop_Birkhoff_OutputLN_TimestepScale/logs/s2_J.txt new file mode 100644 index 0000000000..aab193b830 --- /dev/null +++ b/records/track_non_record_16mb/2026-03-26_RecurrenceFix_3Loop_Birkhoff_OutputLN_TimestepScale/logs/s2_J.txt @@ -0,0 +1,1586 @@ +""" +The `train_gpt.py` and `train_gpt_mlx.py` scripts are intended as good launching-off points for new participants, not SOTA configs. 
We'll accept PRs that tune, improve, or simplify these scripts without significantly increasing complexity, but competitive submissions should stay in the `/records` folder. + +Hard stop: To keep readable for newcomers, let's make sure `train_gpt.py` and `train_gpt_mlx.py` never are longer than 1500 lines. +""" + +from __future__ import annotations + +import copy +import glob +import io +import math +import os +import random +import subprocess +import sys +import time +import uuid +import zlib +from pathlib import Path + +import numpy as np +import sentencepiece as spm +import torch +import torch.distributed as dist +import torch.nn.functional as F +from torch import Tensor, nn +from torch.nn.parallel import DistributedDataParallel as DDP + +# Flash Attention 3 via HuggingFace kernels (Hopper-optimized) +try: + from kernels import get_kernel + cap = torch.cuda.get_device_capability() + repo = "varunneal/flash-attention-3" if cap == (9, 0) else "kernels-community/flash-attn3" + fa3 = get_kernel(repo).flash_attn_interface + FA3_AVAILABLE = True +except Exception: + FA3_AVAILABLE = False + +# ----------------------------- +# HYPERPARAMETERS +# ----------------------------- +# Default Simple Baseline run: +# - 9 transformer blocks at width 512 +# - 8 attention heads with 4 KV heads (GQA) and 2x MLP expansion +# - vocab size 1024, sequence length 1024, tied embeddings +# - 524,288 train tokens per step for 20,000 iterations with a ~10 minute cap + +class Hyperparameters: + # Data paths are shard globs produced by the existing preprocessing pipeline. 
+ data_path = os.environ.get("DATA_PATH", "./data/datasets/fineweb10B_sp1024") + train_files = os.path.join(data_path, "fineweb_train_*.bin") + val_files = os.path.join(data_path, "fineweb_val_*.bin") + tokenizer_path = os.environ.get("TOKENIZER_PATH", "./data/tokenizers/fineweb_1024_bpe.model") + run_id = os.environ.get("RUN_ID", str(uuid.uuid4())) + seed = int(os.environ.get("SEED", 1337)) + + # Validation cadence and batch size. Validation always uses the full fineweb_val split. + val_batch_size = int(os.environ.get("VAL_BATCH_SIZE", 524_288)) + val_loss_every = int(os.environ.get("VAL_LOSS_EVERY", 1000)) + train_log_every = int(os.environ.get("TRAIN_LOG_EVERY", 200)) + + # Training length. + iterations = int(os.environ.get("ITERATIONS", 20000)) + warmdown_iters = int(os.environ.get("WARMDOWN_ITERS", 1200)) + warmup_steps = int(os.environ.get("WARMUP_STEPS", 20)) + train_batch_tokens = int(os.environ.get("TRAIN_BATCH_TOKENS", 524_288)) + train_seq_len = int(os.environ.get("TRAIN_SEQ_LEN", 1024)) + max_wallclock_seconds = float(os.environ.get("MAX_WALLCLOCK_SECONDS", 600.0)) + qk_gain_init = float(os.environ.get("QK_GAIN_INIT", 1.5)) + + # Model shape. + vocab_size = int(os.environ.get("VOCAB_SIZE", 1024)) + num_layers = int(os.environ.get("NUM_LAYERS", 9)) + num_kv_heads = int(os.environ.get("NUM_KV_HEADS", 4)) + model_dim = int(os.environ.get("MODEL_DIM", 512)) + num_heads = int(os.environ.get("NUM_HEADS", 8)) + mlp_mult = int(os.environ.get("MLP_MULT", 2)) + tie_embeddings = bool(int(os.environ.get("TIE_EMBEDDINGS", "1"))) + rope_base = float(os.environ.get("ROPE_BASE", 10000.0)) + logit_softcap = float(os.environ.get("LOGIT_SOFTCAP", 30.0)) + + # Recurrence + stability features. 
+ num_prelude = int(os.environ.get("NUM_PRELUDE", 0)) + num_coda = int(os.environ.get("NUM_CODA", 0)) + num_shared = int(os.environ.get("NUM_SHARED", 0)) + num_loops = int(os.environ.get("NUM_LOOPS", 2)) + use_peri_norm = bool(int(os.environ.get("USE_PERI_NORM", "0"))) + use_birkhoff_mix = bool(int(os.environ.get("USE_BIRKHOFF_MIX", "0"))) + use_timestep_scale = bool(int(os.environ.get("USE_TIMESTEP_SCALE", "0"))) + leaky_relu_slope = float(os.environ.get("LEAKY_RELU_SLOPE", "0.5")) + timestep_gamma_max = float(os.environ.get("TIMESTEP_GAMMA_MAX", "0.0")) + disable_compile = bool(int(os.environ.get("DISABLE_COMPILE", "0"))) + + # Optimizer hyperparameters. + embed_lr = float(os.environ.get("EMBED_LR", 0.6)) + head_lr = float(os.environ.get("HEAD_LR", 0.008)) + tied_embed_lr = float(os.environ.get("TIED_EMBED_LR", 0.05)) + tied_embed_init_std = float(os.environ.get("TIED_EMBED_INIT_STD", 0.005)) + matrix_lr = float(os.environ.get("MATRIX_LR", 0.04)) + scalar_lr = float(os.environ.get("SCALAR_LR", 0.04)) + muon_momentum = float(os.environ.get("MUON_MOMENTUM", 0.95)) + muon_backend_steps = int(os.environ.get("MUON_BACKEND_STEPS", 5)) + muon_momentum_warmup_start = float(os.environ.get("MUON_MOMENTUM_WARMUP_START", 0.85)) + muon_momentum_warmup_steps = int(os.environ.get("MUON_MOMENTUM_WARMUP_STEPS", 500)) + beta1 = float(os.environ.get("BETA1", 0.9)) + beta2 = float(os.environ.get("BETA2", 0.95)) + adam_eps = float(os.environ.get("ADAM_EPS", 1e-8)) + grad_clip_norm = float(os.environ.get("GRAD_CLIP_NORM", 0.0)) + +# ----------------------------- +# MUON OPTIMIZER +# ----------------------------- +# +# As borrowed from modded-nanogpt +# Background on Muon: https://kellerjordan.github.io/posts/muon/ + +def zeropower_via_newtonschulz5(G: Tensor, steps: int = 10, eps: float = 1e-7) -> Tensor: + # Orthogonalize a 2D update matrix with a fast Newton-Schulz iteration. + # Muon uses this to normalize matrix-shaped gradients before applying them. 
+ a, b, c = (3.4445, -4.7750, 2.0315) + X = G.bfloat16() + X /= X.norm() + eps + transposed = G.size(0) > G.size(1) + if transposed: + X = X.T + for _ in range(steps): + A = X @ X.T + B = b * A + c * A @ A + X = a * X + B @ X + return X.T if transposed else X + + +class Muon(torch.optim.Optimizer): + def __init__(self, params, lr: float, momentum: float, backend_steps: int, nesterov: bool = True): + super().__init__( + params, + dict(lr=lr, momentum=momentum, backend_steps=backend_steps, nesterov=nesterov), + ) + + @torch.no_grad() + def step(self, closure=None): + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + distributed = dist.is_available() and dist.is_initialized() + world_size = dist.get_world_size() if distributed else 1 + rank = dist.get_rank() if distributed else 0 + + for group in self.param_groups: + params = group["params"] + if not params: + continue + lr = group["lr"] + momentum = group["momentum"] + backend_steps = group["backend_steps"] + nesterov = group["nesterov"] + + total_params = sum(int(p.numel()) for p in params) + updates_flat = torch.zeros(total_params, device=params[0].device, dtype=torch.bfloat16) + + curr = 0 + for i, p in enumerate(params): + if i % world_size == rank and p.grad is not None: + g = p.grad + state = self.state[p] + if "momentum_buffer" not in state: + state["momentum_buffer"] = torch.zeros_like(g) + buf = state["momentum_buffer"] + buf.mul_(momentum).add_(g) + if nesterov: + g = g.add(buf, alpha=momentum) + g = zeropower_via_newtonschulz5(g, steps=backend_steps) + # Scale correction from Muon reference implementations. 
+ g *= max(1, g.size(0) / g.size(1)) ** 0.5 + updates_flat[curr : curr + p.numel()] = g.reshape(-1) + curr += p.numel() + + if distributed: + dist.all_reduce(updates_flat, op=dist.ReduceOp.SUM) + + curr = 0 + for p in params: + g = updates_flat[curr : curr + p.numel()].view_as(p).to(dtype=p.dtype) + p.add_(g, alpha=-lr) + curr += p.numel() + + return loss + + +# ----------------------------- +# TOKENIZER-AGNOSTIC EVALUATION SETUP +# ----------------------------- +# +# It's common for small models have a large fraction of their parameters be embeddings, since the 2 * d_model * d_vocab vectors can be gigantic. +# Instead of locking the tokenizer, we let you bring your own and calculate our validation metrics on the average compression of the validation set. +# We calculate BPB (bits-per-byte) instead of validation loss, so we need methods to count the number of bits per token in the tokenizer. +# Note: Submissions that edit the tokenizer will be examined more carefully, since screwing this up might unjustly improve your score. 
+ +def build_sentencepiece_luts( + sp: spm.SentencePieceProcessor, vocab_size: int, device: torch.device +) -> tuple[Tensor, Tensor, Tensor]: + sp_vocab_size = int(sp.vocab_size()) + table_size = max(sp_vocab_size, vocab_size) + base_bytes_np = np.zeros((table_size,), dtype=np.int16) + has_leading_space_np = np.zeros((table_size,), dtype=np.bool_) + is_boundary_token_np = np.ones((table_size,), dtype=np.bool_) + for token_id in range(sp_vocab_size): + if sp.is_control(token_id) or sp.is_unknown(token_id) or sp.is_unused(token_id): + continue + is_boundary_token_np[token_id] = False + if sp.is_byte(token_id): + base_bytes_np[token_id] = 1 + continue + piece = sp.id_to_piece(token_id) + if piece.startswith("▁"): + has_leading_space_np[token_id] = True + piece = piece[1:] + base_bytes_np[token_id] = len(piece.encode("utf-8")) + return ( + torch.tensor(base_bytes_np, dtype=torch.int16, device=device), + torch.tensor(has_leading_space_np, dtype=torch.bool, device=device), + torch.tensor(is_boundary_token_np, dtype=torch.bool, device=device), + ) + + +def load_validation_tokens(pattern: str, seq_len: int) -> Tensor: + files = [Path(p) for p in sorted(glob.glob(pattern))] + if not files: + raise FileNotFoundError(f"No files found for pattern: {pattern}") + # The export pipeline writes the fixed first-50k-doc validation set to fineweb_val_*. 
+ tokens = torch.cat([load_data_shard(file) for file in files]).contiguous() + usable = ((tokens.numel() - 1) // seq_len) * seq_len + if usable <= 0: + raise ValueError(f"Validation split is too short for TRAIN_SEQ_LEN={seq_len}") + return tokens[: usable + 1] + + +def eval_val( + args: Hyperparameters, + model: nn.Module, + rank: int, + world_size: int, + device: torch.device, + grad_accum_steps: int, + val_tokens: Tensor, + base_bytes_lut: Tensor, + has_leading_space_lut: Tensor, + is_boundary_token_lut: Tensor, +) -> tuple[float, float]: + # Validation computes two metrics: + # - val_loss: token cross-entropy (natural log) + # - val_bpb: tokenizer-agnostic compression metric used by the challenge + local_batch_tokens = args.val_batch_size // (world_size * grad_accum_steps) + if local_batch_tokens < args.train_seq_len: + raise ValueError( + "VAL_BATCH_SIZE must provide at least one sequence per rank; " + f"got VAL_BATCH_SIZE={args.val_batch_size}, WORLD_SIZE={world_size}, " + f"GRAD_ACCUM_STEPS={grad_accum_steps}, TRAIN_SEQ_LEN={args.train_seq_len}" + ) + local_batch_seqs = local_batch_tokens // args.train_seq_len + total_seqs = (val_tokens.numel() - 1) // args.train_seq_len + seq_start = (total_seqs * rank) // world_size + seq_end = (total_seqs * (rank + 1)) // world_size + val_loss_sum = torch.zeros((), device=device, dtype=torch.float64) + val_token_count = torch.zeros((), device=device, dtype=torch.float64) + val_byte_count = torch.zeros((), device=device, dtype=torch.float64) + + model.eval() + with torch.inference_mode(): + for batch_seq_start in range(seq_start, seq_end, local_batch_seqs): + batch_seq_end = min(batch_seq_start + local_batch_seqs, seq_end) + raw_start = batch_seq_start * args.train_seq_len + raw_end = batch_seq_end * args.train_seq_len + 1 + local = val_tokens[raw_start:raw_end].to(device=device, dtype=torch.int64, non_blocking=True) + x = local[:-1].reshape(-1, args.train_seq_len) + y = local[1:].reshape(-1, args.train_seq_len) + with 
torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + batch_loss = model(x, y).detach() + batch_token_count = float(y.numel()) + val_loss_sum += batch_loss.to(torch.float64) * batch_token_count + val_token_count += batch_token_count + prev_ids = x.reshape(-1) + tgt_ids = y.reshape(-1) + token_bytes = base_bytes_lut[tgt_ids].to(dtype=torch.int16) + token_bytes += (has_leading_space_lut[tgt_ids] & ~is_boundary_token_lut[prev_ids]).to(dtype=torch.int16) + val_byte_count += token_bytes.to(torch.float64).sum() + + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(val_loss_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(val_token_count, op=dist.ReduceOp.SUM) + dist.all_reduce(val_byte_count, op=dist.ReduceOp.SUM) + + val_loss = val_loss_sum / val_token_count + bits_per_token = val_loss.item() / math.log(2.0) + tokens_per_byte = val_token_count.item() / val_byte_count.item() + model.train() + return float(val_loss.item()), float(bits_per_token * tokens_per_byte) + +# ----------------------------- +# POST-TRAINING QUANTIZATION +# ----------------------------- +# +# It's silly to export our model, which is trained in bf16 and fp32, at that same precision. +# Instead, we get approximately the same model (with a small hit) by quantizing the model to int8 & zlib compressing. +# We can then decompress the model and run in higher precision for evaluation, after closing in under the size limit. 
# Substring patterns naming small "control" tensors (learned scales / residual
# mixes / gains) that must stay fp32 both during training and in the int8 export.
CONTROL_TENSOR_NAME_PATTERNS = tuple(
    pattern
    for pattern in os.environ.get(
        "CONTROL_TENSOR_NAME_PATTERNS",
        "attn_scale,attn_scales,mlp_scale,mlp_scales,resid_mix,resid_mixes,q_gain,skip_weight,skip_weights,ts_attn,ts_mlp,resid_mix_logit",
    ).split(",")
    if pattern
)
INT8_KEEP_FLOAT_FP32_NAME_PATTERNS = tuple(
    pattern
    for pattern in os.environ.get(
        "INT8_KEEP_FLOAT_FP32_NAME_PATTERNS",
        ",".join(CONTROL_TENSOR_NAME_PATTERNS),
    ).split(",")
    if pattern
)
INT8_KEEP_FLOAT_MAX_NUMEL = 65_536  # float tensors at or below this numel skip quantization
INT8_KEEP_FLOAT_STORE_DTYPE = torch.float16
INT8_PER_ROW_SCALE_DTYPE = torch.float16
INT8_CLIP_PERCENTILE = 99.99984  # clip |values| above this percentile before rounding
INT8_CLIP_Q = INT8_CLIP_PERCENTILE / 100.0

def tensor_nbytes(t: Tensor) -> int:
    """Storage footprint of *t* in bytes (numel * element size)."""
    return int(t.numel()) * int(t.element_size())

def keep_float_tensor(name: str, t: Tensor, passthrough_orig_dtypes: dict[str, str]) -> Tensor:
    """Return the storage form of a small float tensor kept outside int8 quantization.

    Control tensors stay fp32; other fp32/bf16 tensors are downcast to fp16, and the
    original dtype is recorded in *passthrough_orig_dtypes* so it can be restored.
    """
    if any(pattern in name for pattern in INT8_KEEP_FLOAT_FP32_NAME_PATTERNS):
        return t.float().contiguous()
    if t.dtype in {torch.float32, torch.bfloat16}:
        passthrough_orig_dtypes[name] = str(t.dtype).removeprefix("torch.")
        return t.to(dtype=INT8_KEEP_FLOAT_STORE_DTYPE).contiguous()
    return t

def quantize_float_tensor(t: Tensor) -> tuple[Tensor, Tensor]:
    """Quantize a float tensor to int8; returns (int8 values, scale).

    2D tensors get one scale per row (axis 0); everything else gets a single
    per-tensor scale. Values beyond the INT8_CLIP_Q percentile are clipped first.
    """
    t32 = t.float()
    if t32.ndim == 2:
        # Matrices get one scale per row, which usually tracks output-channel
        # ranges much better than a single tensor-wide scale.
        clip_abs = (
            torch.quantile(t32.abs(), INT8_CLIP_Q, dim=1)
            if t32.numel()
            # FIX: torch.empty() left clip_abs uninitialized for empty tensors,
            # making this (degenerate) path nondeterministic — zeros is the
            # well-defined bound for an empty row set.
            else torch.zeros((t32.shape[0],), dtype=torch.float32)
        )
        clipped = torch.maximum(torch.minimum(t32, clip_abs[:, None]), -clip_abs[:, None])
        # NOTE(review): clamp_min(1/127) floors the scale as if row maxima were >= 1;
        # rows with much smaller magnitude quantize to very few levels. Presumably
        # intentional for this model's weight ranges — confirm before reusing elsewhere.
        scale = (clip_abs / 127.0).clamp_min(1.0 / 127.0)
        q = torch.clamp(torch.round(clipped / scale[:, None]), -127, 127).to(torch.int8).contiguous()
        return q, scale.to(dtype=INT8_PER_ROW_SCALE_DTYPE).contiguous()

    # Vectors / scalars use a simpler per-tensor scale.
    clip_abs = float(torch.quantile(t32.abs().flatten(), INT8_CLIP_Q).item()) if t32.numel() else 0.0
    scale = torch.tensor(clip_abs / 127.0 if clip_abs > 0 else 1.0, dtype=torch.float32)
    q = torch.clamp(torch.round(torch.clamp(t32, -clip_abs, clip_abs) / scale), -127, 127).to(torch.int8).contiguous()
    return q, scale

def quantize_state_dict_int8(state_dict: dict[str, Tensor]):
    """Quantize a state dict into the single supported clean-script export format.

    - per-row int8 for 2D float tensors
    - per-tensor int8 for other float tensors
    - exact passthrough for non-floats
    - passthrough for small float tensors, stored as fp16 to save bytes

    Returns (obj, stats) where obj is the serializable container and stats counts
    tensors/bytes for logging.
    """
    quantized: dict[str, Tensor] = {}
    scales: dict[str, Tensor] = {}
    dtypes: dict[str, str] = {}
    passthrough: dict[str, Tensor] = {}
    passthrough_orig_dtypes: dict[str, str] = {}
    qmeta: dict[str, dict[str, object]] = {}
    stats = dict.fromkeys(
        ("param_count", "num_tensors", "num_float_tensors", "num_nonfloat_tensors", "baseline_tensor_bytes", "int8_payload_bytes"),
        0,
    )

    for name, tensor in state_dict.items():
        t = tensor.detach().to("cpu").contiguous()
        stats["param_count"] += int(t.numel())
        stats["num_tensors"] += 1
        stats["baseline_tensor_bytes"] += tensor_nbytes(t)

        if not t.is_floating_point():
            stats["num_nonfloat_tensors"] += 1
            passthrough[name] = t
            stats["int8_payload_bytes"] += tensor_nbytes(t)
            continue

        # Small float tensors are cheap enough to keep directly. We still downcast
        # fp32/bf16 passthrough tensors to fp16 so metadata does not dominate size.
        if t.numel() <= INT8_KEEP_FLOAT_MAX_NUMEL:
            kept = keep_float_tensor(name, t, passthrough_orig_dtypes)
            passthrough[name] = kept
            stats["int8_payload_bytes"] += tensor_nbytes(kept)
            continue

        stats["num_float_tensors"] += 1
        q, s = quantize_float_tensor(t)
        if s.ndim > 0:
            qmeta[name] = {"scheme": "per_row", "axis": 0}
        quantized[name] = q
        scales[name] = s
        dtypes[name] = str(t.dtype).removeprefix("torch.")
        stats["int8_payload_bytes"] += tensor_nbytes(q) + tensor_nbytes(s)

    obj: dict[str, object] = {
        "__quant_format__": "int8_clean_per_row_v1",
        "quantized": quantized,
        "scales": scales,
        "dtypes": dtypes,
        "passthrough": passthrough,
    }
    if qmeta:
        obj["qmeta"] = qmeta
    if passthrough_orig_dtypes:
        obj["passthrough_orig_dtypes"] = passthrough_orig_dtypes
    return obj, stats

def dequantize_state_dict_int8(obj: dict[str, object]) -> dict[str, Tensor]:
    """Invert quantize_state_dict_int8, restoring each tensor's original dtype."""
    out: dict[str, Tensor] = {}
    qmeta = obj.get("qmeta", {})
    passthrough_orig_dtypes = obj.get("passthrough_orig_dtypes", {})
    for name, q in obj["quantized"].items():
        dtype = getattr(torch, obj["dtypes"][name])
        s = obj["scales"][name]
        if qmeta.get(name, {}).get("scheme") == "per_row" or s.ndim > 0:
            s = s.to(dtype=torch.float32)
            # Broadcast the saved row scale back across trailing dimensions.
            out[name] = (q.float() * s.view(q.shape[0], *([1] * (q.ndim - 1)))).to(dtype=dtype).contiguous()
        else:
            scale = float(s.item())
            out[name] = (q.float() * scale).to(dtype=dtype).contiguous()
    for name, t in obj["passthrough"].items():
        # Restore small tensors, undoing the temporary fp16 storage cast if needed.
        out_t = t.detach().to("cpu").contiguous()
        orig_dtype = passthrough_orig_dtypes.get(name)
        if isinstance(orig_dtype, str):
            out_t = out_t.to(dtype=getattr(torch, orig_dtype)).contiguous()
        out[name] = out_t
    return out


# -----------------------------
# DATA LOADING
# -----------------------------

def load_data_shard(file: Path) -> Tensor:
    # NOTE(review): everything from the middle of this function through the
    # TokenStream constructor was corrupted in this copy of the file (an
    # angle-bracketed span such as np.dtype("<i4") and the text up to the next
    # ">" was stripped by a tag-eating filter). The bodies below are a
    # conservative reconstruction of the usual shard format (256 x int32
    # header, token count at header[2], uint16 tokens) — restore the
    # authoritative version from version control before relying on them.
    # Helpers referenced later in main() (load_validation_tokens,
    # build_sentencepiece_luts) may also have been lost in the same span.
    header_bytes = 256 * np.dtype("<i4").itemsize
    with file.open("rb") as f:
        header = np.frombuffer(f.read(header_bytes), dtype="<i4")
        num_tokens = int(header[2])
        tokens = np.frombuffer(f.read(), dtype="<u2", count=num_tokens)
    return torch.from_numpy(tokens.astype(np.int32))


class TokenStream:
    """Round-robin reader over a glob of token shards, yielding contiguous spans."""

    def __init__(self, pattern: str):
        # NOTE(review): reconstructed (see load_data_shard above) — verify.
        self.files = sorted(Path(p) for p in glob.glob(pattern))
        if not self.files:
            raise FileNotFoundError(f"no token shards match {pattern}")
        self.file_idx = 0
        self.tokens = load_data_shard(self.files[self.file_idx])
        self.pos = 0

    def _advance_file(self) -> None:
        # Wrap around to the first shard after the last one is exhausted.
        self.file_idx = (self.file_idx + 1) % len(self.files)
        self.tokens = load_data_shard(self.files[self.file_idx])
        self.pos = 0

    def take(self, n: int) -> Tensor:
        """Return the next *n* tokens, spanning shard boundaries if needed."""
        chunks: list[Tensor] = []
        remaining = n
        while remaining > 0:
            avail = self.tokens.numel() - self.pos
            if avail <= 0:
                self._advance_file()
                continue
            k = min(remaining, avail)
            chunks.append(self.tokens[self.pos : self.pos + k])
            self.pos += k
            remaining -= k
        return chunks[0] if len(chunks) == 1 else torch.cat(chunks)


class DistributedTokenLoader:
    # Each call consumes a contiguous chunk from the shared token stream, then slices out
    # one disjoint span per rank. The extra "+1" token lets us build (x, y) by shifting.

    def __init__(self, pattern: str, rank: int, world_size: int, device: torch.device):
        self.rank = rank
        self.world_size = world_size
        self.device = device
        self.stream = TokenStream(pattern)

    def next_batch(self, global_tokens: int, seq_len: int, grad_accum_steps: int) -> tuple[Tensor, Tensor]:
        """Return (x, y) int64 batches of shape (-1, seq_len) for this rank."""
        local_tokens = global_tokens // (self.world_size * grad_accum_steps)
        per_rank_span = local_tokens + 1
        chunk = self.stream.take(per_rank_span * self.world_size)
        start = self.rank * per_rank_span
        local = chunk[start : start + per_rank_span].to(dtype=torch.int64)
        # NOTE(review): assumes local_tokens is a multiple of seq_len — reshape
        # raises otherwise; confirmed only by the callers' batch-size choices.
        x = local[:-1].reshape(-1, seq_len)
        y = local[1:].reshape(-1, seq_len)
        return x.to(self.device, non_blocking=True), y.to(self.device, non_blocking=True)
+ def __init__(self, pattern: str, rank: int, world_size: int, device: torch.device): + self.rank = rank + self.world_size = world_size + self.device = device + self.stream = TokenStream(pattern) + + def next_batch(self, global_tokens: int, seq_len: int, grad_accum_steps: int) -> tuple[Tensor, Tensor]: + local_tokens = global_tokens // (self.world_size * grad_accum_steps) + per_rank_span = local_tokens + 1 + chunk = self.stream.take(per_rank_span * self.world_size) + start = self.rank * per_rank_span + local = chunk[start : start + per_rank_span].to(dtype=torch.int64) + x = local[:-1].reshape(-1, seq_len) + y = local[1:].reshape(-1, seq_len) + return x.to(self.device, non_blocking=True), y.to(self.device, non_blocking=True) + +# ----------------------------- +# TRANSFORMER MODULES +# ----------------------------- + +class RMSNorm(nn.Module): + def __init__(self, eps: float | None = None): + super().__init__() + self.eps = eps + + def forward(self, x: Tensor) -> Tensor: + return F.rms_norm(x, (x.size(-1),), eps=self.eps) + + +class CastedLinear(nn.Linear): + # Keep weights in fp32 for optimizer/state quality, cast at matmul time for bf16 compute. + def forward(self, x: Tensor) -> Tensor: + bias = self.bias.to(x.dtype) if self.bias is not None else None + return F.linear(x, self.weight.to(x.dtype), bias) + + +def restore_low_dim_params_to_fp32(module: nn.Module) -> None: + # Keep small/control parameters in fp32 even when the model body runs in bf16. + with torch.no_grad(): + for name, param in module.named_parameters(): + if (param.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS)) and param.dtype != torch.float32: + param.data = param.data.float() + + +class Rotary(nn.Module): + # Caches cos/sin tables per sequence length on the current device. 
+ def __init__(self, dim: int, base: float = 10000.0): + super().__init__() + inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim)) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self._seq_len_cached = 0 + self._cos_cached: Tensor | None = None + self._sin_cached: Tensor | None = None + + def forward(self, seq_len: int, device: torch.device, dtype: torch.dtype) -> tuple[Tensor, Tensor]: + if ( + self._cos_cached is None + or self._sin_cached is None + or self._seq_len_cached != seq_len + or self._cos_cached.device != device + ): + t = torch.arange(seq_len, device=device, dtype=self.inv_freq.dtype) + freqs = torch.outer(t, self.inv_freq.to(device)) + self._cos_cached = freqs.cos()[None, None, :, :] + self._sin_cached = freqs.sin()[None, None, :, :] + self._seq_len_cached = seq_len + return self._cos_cached.to(dtype=dtype), self._sin_cached.to(dtype=dtype) + + +def apply_rotary_emb(x: Tensor, cos: Tensor, sin: Tensor) -> Tensor: + half = x.size(-1) // 2 + x1, x2 = x[..., :half], x[..., half:] + return torch.cat((x1 * cos + x2 * sin, x1 * (-sin) + x2 * cos), dim=-1) + + +class CausalSelfAttention(nn.Module): + def __init__( + self, + dim: int, + num_heads: int, + num_kv_heads: int, + rope_base: float, + qk_gain_init: float, + ): + super().__init__() + if dim % num_heads != 0: + raise ValueError("model_dim must be divisible by num_heads") + if num_heads % num_kv_heads != 0: + raise ValueError("num_heads must be divisible by num_kv_heads") + self.num_heads = num_heads + self.num_kv_heads = num_kv_heads + self.head_dim = dim // num_heads + if self.head_dim % 2 != 0: + raise ValueError("head_dim must be even for RoPE") + kv_dim = self.num_kv_heads * self.head_dim + self.c_q = CastedLinear(dim, dim, bias=False) + self.c_k = CastedLinear(dim, kv_dim, bias=False) + self.c_v = CastedLinear(dim, kv_dim, bias=False) + self.proj = CastedLinear(dim, dim, bias=False) + self.proj._zero_init = True + self.q_gain = 
nn.Parameter(torch.full((num_heads,), qk_gain_init, dtype=torch.float32)) + self.rotary = Rotary(self.head_dim, base=rope_base) + + def forward(self, x: Tensor) -> Tensor: + bsz, seqlen, dim = x.shape + q = self.c_q(x).reshape(bsz, seqlen, self.num_heads, self.head_dim).transpose(1, 2) + k = self.c_k(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) + v = self.c_v(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) + q = F.rms_norm(q, (q.size(-1),)) + k = F.rms_norm(k, (k.size(-1),)) + cos, sin = self.rotary(seqlen, x.device, q.dtype) + q = apply_rotary_emb(q, cos, sin) + k = apply_rotary_emb(k, cos, sin) + q = q * self.q_gain.to(dtype=q.dtype)[None, :, None, None] + if FA3_AVAILABLE: + # FA3 expects (B, T, H, D) — transpose from (B, H, T, D) + y = fa3.flash_attn_func( + q.transpose(1, 2), + k.transpose(1, 2), + v.transpose(1, 2), + causal=True, + ).transpose(1, 2) + else: + y = F.scaled_dot_product_attention( + q, k, v, + attn_mask=None, + is_causal=True, + enable_gqa=(self.num_kv_heads != self.num_heads), + ) + y = y.transpose(1, 2).contiguous().reshape(bsz, seqlen, dim) + return self.proj(y) + + +class MLP(nn.Module): + # leaky_relu^2 MLP (slope=0.5 matches SOTA, PRs #493/#518) + def __init__(self, dim: int, mlp_mult: int, leaky_relu_slope: float = 0.5): + super().__init__() + hidden = mlp_mult * dim + self.fc = CastedLinear(dim, hidden, bias=False) + self.proj = CastedLinear(hidden, dim, bias=False) + self.proj._zero_init = True + self.leaky_relu_slope = leaky_relu_slope + + def forward(self, x: Tensor) -> Tensor: + x = F.leaky_relu(self.fc(x), negative_slope=self.leaky_relu_slope) + return self.proj(x.square()) + + +class TimestepScaling(nn.Module): + """Learned per-iteration scale vectors (arXiv:2410.01405, ICML 2025).""" + def __init__(self, effective_layers: int, dim: int, gamma_max: float = 0.0): + super().__init__() + self.attn_gamma = nn.Parameter(torch.ones(effective_layers, dim, dtype=torch.float32)) + 
self.mlp_gamma = nn.Parameter(torch.ones(effective_layers, dim, dtype=torch.float32)) + self.gamma_max = gamma_max # 0 = uncapped + + def get(self, v: int) -> tuple[Tensor, Tensor]: + ag = self.attn_gamma[v] + mg = self.mlp_gamma[v] + if self.gamma_max > 0: + ag = ag.clamp(-self.gamma_max, self.gamma_max) + mg = mg.clamp(-self.gamma_max, self.gamma_max) + return ag, mg + + +class Block(nn.Module): + def __init__( + self, + dim: int, + num_heads: int, + num_kv_heads: int, + mlp_mult: int, + rope_base: float, + qk_gain_init: float, + use_peri_norm: bool = False, + use_birkhoff_mix: bool = False, + leaky_relu_slope: float = 0.5, + ): + super().__init__() + self.use_peri_norm = use_peri_norm + self.use_birkhoff_mix = use_birkhoff_mix + self.attn_norm = RMSNorm() + if use_peri_norm: + self.mlp_out_norm = RMSNorm() # output norm: bounds MLP contribution, preserves input magnitude + else: + self.mlp_norm = RMSNorm() + self.attn = CausalSelfAttention(dim, num_heads, num_kv_heads, rope_base, qk_gain_init) + self.mlp = MLP(dim, mlp_mult, leaky_relu_slope) + self.attn_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.mlp_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + if use_birkhoff_mix: + # Birkhoff B₂ parameterization: sigmoid(logit) ∈ (0,1) → doubly stochastic + self.resid_mix_logit = nn.Parameter(torch.zeros(dim, dtype=torch.float32)) + else: + self.resid_mix = nn.Parameter(torch.stack((torch.ones(dim), torch.zeros(dim))).float()) + + def forward(self, x: Tensor, x0: Tensor, + ts_attn_gamma: Tensor | None = None, + ts_mlp_gamma: Tensor | None = None) -> Tensor: + if self.use_birkhoff_mix: + alpha = torch.sigmoid(self.resid_mix_logit.to(dtype=x.dtype))[None, None, :] + x = alpha * x + (1 - alpha) * x0 + else: + mix = self.resid_mix.to(dtype=x.dtype) + x = mix[0][None, None, :] * x + mix[1][None, None, :] * x0 + attn_out = self.attn(self.attn_norm(x)) + attn_s = self.attn_scale.to(dtype=x.dtype)[None, None, :] + if ts_attn_gamma is not None: + 
attn_s = attn_s * ts_attn_gamma[None, None, :] + x = x + attn_s * attn_out + mlp_out = self.mlp_out_norm(self.mlp(x)) if self.use_peri_norm else self.mlp(self.mlp_norm(x)) + mlp_s = self.mlp_scale.to(dtype=x.dtype)[None, None, :] + if ts_mlp_gamma is not None: + mlp_s = mlp_s * ts_mlp_gamma[None, None, :] + x = x + mlp_s * mlp_out + return x + + +class GPT(nn.Module): + def __init__( + self, + vocab_size: int, + num_layers: int, + model_dim: int, + num_heads: int, + num_kv_heads: int, + mlp_mult: int, + tie_embeddings: bool, + tied_embed_init_std: float, + logit_softcap: float, + rope_base: float, + qk_gain_init: float, + num_prelude: int = 0, + num_coda: int = 0, + num_shared: int = 0, + num_loops: int = 2, + use_peri_norm: bool = False, + use_birkhoff_mix: bool = False, + use_timestep_scale: bool = False, + timestep_gamma_max: float = 0.0, + leaky_relu_slope: float = 0.5, + ): + super().__init__() + if logit_softcap <= 0.0: + raise ValueError(f"logit_softcap must be positive, got {logit_softcap}") + self.tie_embeddings = tie_embeddings + self.tied_embed_init_std = tied_embed_init_std + self.logit_softcap = logit_softcap + self.tok_emb = nn.Embedding(vocab_size, model_dim) + self.use_recurrence = num_shared > 0 + self.num_loops = num_loops if self.use_recurrence else 1 + self.num_prelude = num_prelude if self.use_recurrence else 0 + self.num_coda = num_coda if self.use_recurrence else 0 + + block_kwargs = dict( + dim=model_dim, num_heads=num_heads, num_kv_heads=num_kv_heads, + mlp_mult=mlp_mult, rope_base=rope_base, qk_gain_init=qk_gain_init, + use_peri_norm=use_peri_norm if self.use_recurrence else False, + use_birkhoff_mix=use_birkhoff_mix if self.use_recurrence else False, + leaky_relu_slope=leaky_relu_slope, + ) + + if self.use_recurrence: + self.prelude_blocks = nn.ModuleList([Block(**block_kwargs) for _ in range(self.num_prelude)]) + self.shared_blocks = nn.ModuleList([Block(**block_kwargs) for _ in range(num_shared)]) + self.coda_blocks = 
nn.ModuleList([Block(**block_kwargs) for _ in range(self.num_coda)]) + effective_layers = self.num_prelude + num_shared * self.num_loops + self.num_coda + self.timestep_scale = TimestepScaling(effective_layers, model_dim, gamma_max=timestep_gamma_max) if use_timestep_scale else None + else: + # Standard U-Net path + self.num_encoder_layers = num_layers // 2 + self.num_decoder_layers = num_layers - self.num_encoder_layers + self.num_skip_weights = min(self.num_encoder_layers, self.num_decoder_layers) + self.skip_weights = nn.Parameter(torch.ones(self.num_skip_weights, model_dim, dtype=torch.float32)) + self.blocks = nn.ModuleList([Block(**block_kwargs) for _ in range(num_layers)]) + self.timestep_scale = None + + self.final_norm = RMSNorm() + self.lm_head = None if tie_embeddings else CastedLinear(model_dim, vocab_size, bias=False) + if self.lm_head is not None: + self.lm_head._zero_init = True + self._init_weights() + + def _init_weights(self) -> None: + if self.tie_embeddings: + nn.init.normal_(self.tok_emb.weight, mean=0.0, std=self.tied_embed_init_std) + for module in self.modules(): + if isinstance(module, nn.Linear) and getattr(module, "_zero_init", False): + nn.init.zeros_(module.weight) + + def _get_ts(self, v: int) -> tuple[Tensor | None, Tensor | None]: + if self.timestep_scale is None: + return None, None + return self.timestep_scale.get(v) + + def forward(self, input_ids: Tensor, target_ids: Tensor) -> Tensor: + x = self.tok_emb(input_ids) + x = F.rms_norm(x, (x.size(-1),)) + x0 = x + + if self.use_recurrence: + v = 0 + for block in self.prelude_blocks: + ag, mg = self._get_ts(v) + x = block(x, x0, ag, mg) + v += 1 + for _loop in range(self.num_loops): + for block in self.shared_blocks: + ag, mg = self._get_ts(v) + x = block(x, x0, ag, mg) + v += 1 + for block in self.coda_blocks: + ag, mg = self._get_ts(v) + x = block(x, x0, ag, mg) + v += 1 + else: + skips: list[Tensor] = [] + for i in range(self.num_encoder_layers): + x = self.blocks[i](x, x0) + 
skips.append(x) + for i in range(self.num_decoder_layers): + if skips: + x = x + self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips.pop() + x = self.blocks[self.num_encoder_layers + i](x, x0) + + x = self.final_norm(x).reshape(-1, x.size(-1)) + targets = target_ids.reshape(-1) + if self.tie_embeddings: + logits_proj = F.linear(x, self.tok_emb.weight) + else: + if self.lm_head is None: + raise RuntimeError("lm_head is required when tie_embeddings=False") + logits_proj = self.lm_head(x) + logits = self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap) + return F.cross_entropy(logits.float(), targets, reduction="mean") + + +@torch.no_grad() +def recurrence_param_diagnostics(gpt: GPT) -> str: + """Log parameter norms for recurrence health monitoring. No forward pass.""" + if not gpt.use_recurrence: + return "" + parts: list[str] = [] + + # Birkhoff alpha stats per shared block + for i, block in enumerate(gpt.shared_blocks): + if hasattr(block, "resid_mix_logit"): + a = torch.sigmoid(block.resid_mix_logit) + parts.append(f"shared{i}_alpha:mean={a.mean().item():.3f},std={a.std().item():.3f}") + + # Effective MLP/attn contribution scale per virtual layer position: + # mlp_scale * ts_mlp_gamma[v] is the learned multiplier on MLP output. + # We report RMS (norm / sqrt(numel)) to make it scale-independent. 
# -----------------------------
# TRAINING
# -----------------------------

def main() -> None:
    """Full training entry point: setup, warmup, timed training loop, int8 export.

    Phases: distributed/CUDA setup -> tokenizer + validation metric setup ->
    model + optimizer construction -> compiled-path warmup (state restored after)
    -> timed training loop with wallclock cap -> serialization + int8 round-trip
    validation.
    """
    global zeropower_via_newtonschulz5

    # The script's own source is logged and counted toward the submission size.
    code = Path(__file__).read_text(encoding="utf-8")
    args = Hyperparameters()
    zeropower_via_newtonschulz5 = torch.compile(zeropower_via_newtonschulz5)

    # -----------------------------
    # DISTRIBUTED + CUDA SETUP
    # -----------------------------

    distributed = "RANK" in os.environ and "WORLD_SIZE" in os.environ
    rank = int(os.environ.get("RANK", "0"))
    world_size = int(os.environ.get("WORLD_SIZE", "1"))
    local_rank = int(os.environ.get("LOCAL_RANK", "0"))
    if world_size <= 0:
        raise ValueError(f"WORLD_SIZE must be positive, got {world_size}")
    # The global batch is split into 8 micro-batches total, so world_size must divide 8.
    if 8 % world_size != 0:
        raise ValueError(f"WORLD_SIZE={world_size} must divide 8 so grad_accum_steps stays integral")
    grad_accum_steps = 8 // world_size
    grad_scale = 1.0 / grad_accum_steps
    if not torch.cuda.is_available():
        raise RuntimeError("CUDA is required")
    device = torch.device("cuda", local_rank)
    torch.cuda.set_device(device)
    if distributed:
        dist.init_process_group(backend="nccl", device_id=device)
        dist.barrier()
    master_process = rank == 0

    # Fast math knobs
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True
    from torch.backends.cuda import enable_cudnn_sdp, enable_flash_sdp, enable_math_sdp, enable_mem_efficient_sdp

    # Force the flash SDP backend only (matches the "sdp_backends" log line below).
    enable_cudnn_sdp(False)
    enable_flash_sdp(True)
    enable_mem_efficient_sdp(False)
    enable_math_sdp(False)

    logfile = None
    if master_process:
        os.makedirs("logs", exist_ok=True)
        logfile = f"logs/{args.run_id}.txt"
        print(logfile)

    def log0(msg: str, console: bool = True) -> None:
        # Rank-0-only logger; always appends to the logfile, optionally echoes to stdout.
        if not master_process:
            return
        if console:
            print(msg)
        if logfile is not None:
            with open(logfile, "a", encoding="utf-8") as f:
                print(msg, file=f)

    # Record full source, environment, and GPU info in the logfile only.
    log0(code, console=False)
    log0("=" * 100, console=False)
    log0(f"Running Python {sys.version}", console=False)
    log0(f"Running PyTorch {torch.__version__}", console=False)
    log0(
        subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, check=False).stdout,
        console=False,
    )
    log0("=" * 100, console=False)

    # -----------------------------
    # TOKENIZER + VALIDATION METRIC SETUP
    # -----------------------------

    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)

    if not args.tokenizer_path.endswith(".model"):
        raise ValueError(f"Script only setup for SentencePiece .model file: {args.tokenizer_path}")
    sp = spm.SentencePieceProcessor(model_file=args.tokenizer_path)
    if int(sp.vocab_size()) != args.vocab_size:
        raise ValueError(
            f"VOCAB_SIZE={args.vocab_size} does not match tokenizer vocab_size={int(sp.vocab_size())}"
        )
    dataset_dir = Path(args.data_path).resolve()
    actual_train_files = len(list(dataset_dir.glob("fineweb_train_*.bin")))
    val_tokens = load_validation_tokens(args.val_files, args.train_seq_len)
    # LUTs used by eval_val to convert token-level loss into bits-per-byte.
    base_bytes_lut, has_leading_space_lut, is_boundary_token_lut = build_sentencepiece_luts(
        sp, args.vocab_size, device
    )
    log0(f"val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path={args.tokenizer_path}")
    log0(f"train_loader:dataset:{dataset_dir.name} train_shards:{actual_train_files}")
    log0(f"val_loader:shards pattern={args.val_files} tokens:{val_tokens.numel() - 1}")

    # -----------------------------
    # MODEL + OPTIMIZER SETUP
    # -----------------------------

    base_model = GPT(
        vocab_size=args.vocab_size,
        num_layers=args.num_layers,
        model_dim=args.model_dim,
        num_heads=args.num_heads,
        num_kv_heads=args.num_kv_heads,
        mlp_mult=args.mlp_mult,
        tie_embeddings=args.tie_embeddings,
        tied_embed_init_std=args.tied_embed_init_std,
        logit_softcap=args.logit_softcap,
        rope_base=args.rope_base,
        qk_gain_init=args.qk_gain_init,
        num_prelude=args.num_prelude,
        num_coda=args.num_coda,
        num_shared=args.num_shared,
        num_loops=args.num_loops,
        use_peri_norm=args.use_peri_norm,
        use_birkhoff_mix=args.use_birkhoff_mix,
        use_timestep_scale=args.use_timestep_scale,
        timestep_gamma_max=args.timestep_gamma_max,
        leaky_relu_slope=args.leaky_relu_slope,
    ).to(device).bfloat16()
    # Model body runs bf16, but CastedLinear masters and small/control params stay fp32.
    for module in base_model.modules():
        if isinstance(module, CastedLinear):
            module.float()
    restore_low_dim_params_to_fp32(base_model)
    if args.disable_compile:
        compiled_model = base_model
    else:
        compiled_model = torch.compile(base_model, dynamic=False, fullgraph=True)
    model: nn.Module = DDP(compiled_model, device_ids=[local_rank], broadcast_buffers=False) if distributed else compiled_model

    # Optimizer split:
    # - token embedding (Adam) uses EMBED_LR
    # - untied lm_head (Adam) uses HEAD_LR
    # - matrix params in transformer blocks use MATRIX_LR via Muon
    # - vectors/scalars use SCALAR_LR via Adam
    if base_model.use_recurrence:
        all_block_lists = [base_model.prelude_blocks, base_model.shared_blocks, base_model.coda_blocks]
        block_named_params = []
        for bl in all_block_lists:
            block_named_params.extend(bl.named_parameters())
    else:
        block_named_params = list(base_model.blocks.named_parameters())
    matrix_params = [
        p
        for name, p in block_named_params
        if p.ndim == 2 and not any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS)
    ]
    scalar_params = [
        p
        for name, p in block_named_params
        if p.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS)
    ]
    # Skip weights and timestep-scale vectors also take the scalar LR.
    if hasattr(base_model, "skip_weights") and base_model.skip_weights.numel() > 0:
        scalar_params.append(base_model.skip_weights)
    if base_model.timestep_scale is not None:
        for p in base_model.timestep_scale.parameters():
            scalar_params.append(p)
    token_lr = args.tied_embed_lr if args.tie_embeddings else args.embed_lr
    # "base_lr" is stashed in each group so lr_mul() can rescale in place every step.
    optimizer_tok = torch.optim.Adam(
        [{"params": [base_model.tok_emb.weight], "lr": token_lr, "base_lr": token_lr}],
        betas=(args.beta1, args.beta2),
        eps=args.adam_eps,
        fused=True,
    )
    optimizer_muon = Muon(
        matrix_params,
        lr=args.matrix_lr,
        momentum=args.muon_momentum,
        backend_steps=args.muon_backend_steps,
    )
    for group in optimizer_muon.param_groups:
        group["base_lr"] = args.matrix_lr
    optimizer_scalar = torch.optim.Adam(
        [{"params": scalar_params, "lr": args.scalar_lr, "base_lr": args.scalar_lr}],
        betas=(args.beta1, args.beta2),
        eps=args.adam_eps,
        fused=True,
    )
    optimizers: list[torch.optim.Optimizer] = [optimizer_tok, optimizer_muon, optimizer_scalar]
    if base_model.lm_head is not None:
        optimizer_head = torch.optim.Adam(
            [{"params": [base_model.lm_head.weight], "lr": args.head_lr, "base_lr": args.head_lr}],
            betas=(args.beta1, args.beta2),
            eps=args.adam_eps,
            fused=True,
        )
        optimizers.insert(1, optimizer_head)

    n_params = sum(p.numel() for p in base_model.parameters())
    log0(f"model_params:{n_params}")
    log0(f"world_size:{world_size} grad_accum_steps:{grad_accum_steps}")
    log0("sdp_backends:cudnn=False flash=True mem_efficient=False math=False")
    log0(f"flash_attention_3:{FA3_AVAILABLE}")
    log0(f"attention_mode:gqa num_heads:{args.num_heads} num_kv_heads:{args.num_kv_heads}")
    log0(
        f"tie_embeddings:{args.tie_embeddings} embed_lr:{token_lr} "
        f"head_lr:{args.head_lr if base_model.lm_head is not None else 0.0} "
        f"matrix_lr:{args.matrix_lr} scalar_lr:{args.scalar_lr}"
    )
    log0(
        f"train_batch_tokens:{args.train_batch_tokens} train_seq_len:{args.train_seq_len} "
        f"iterations:{args.iterations} warmup_steps:{args.warmup_steps} "
        f"max_wallclock_seconds:{args.max_wallclock_seconds:.3f}"
    )
    log0(f"seed:{args.seed}")
    if base_model.use_recurrence:
        eff = base_model.num_prelude + len(base_model.shared_blocks) * base_model.num_loops + base_model.num_coda
        log0(f"recurrence:enabled prelude:{base_model.num_prelude} shared:{len(base_model.shared_blocks)} "
             f"loops:{base_model.num_loops} coda:{base_model.num_coda} effective_layers:{eff}")
        log0(f"peri_norm:{args.use_peri_norm} birkhoff_mix:{args.use_birkhoff_mix}")
        if base_model.timestep_scale is not None:
            ts_params = sum(p.numel() for p in base_model.timestep_scale.parameters())
            log0(f"timestep_scale:enabled params:{ts_params}")
        else:
            log0("timestep_scale:disabled")
    else:
        log0(f"recurrence:disabled num_layers:{args.num_layers}")
    compile_mode = "disabled" if args.disable_compile else "fullgraph=True"
    log0(f"compile_mode:{compile_mode}")

    # -----------------------------
    # DATA LOADER & MODEL WARMUP
    # -----------------------------

    train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device)

    def zero_grad_all() -> None:
        for opt in optimizers:
            opt.zero_grad(set_to_none=True)

    max_wallclock_ms = 1000.0 * args.max_wallclock_seconds if args.max_wallclock_seconds > 0 else None

    def lr_mul(step: int, elapsed_ms: float) -> float:
        # LR multiplier: 1.0 during the main run, then a linear warmdown whose
        # trigger is either an iteration count (no wallclock cap) or the
        # estimated remaining wallclock budget.
        if args.warmdown_iters <= 0:
            return 1.0
        if max_wallclock_ms is None:
            warmdown_start = max(args.iterations - args.warmdown_iters, 0)
            return max((args.iterations - step) / max(args.warmdown_iters, 1), 0.0) if warmdown_start <= step < args.iterations else 1.0
        # Wallclock mode: estimate per-step time so far and start decaying once the
        # remaining budget fits inside the warmdown window.
        step_ms = elapsed_ms / max(step, 1)
        warmdown_ms = args.warmdown_iters * step_ms
        remaining_ms = max(max_wallclock_ms - elapsed_ms, 0.0)
        return remaining_ms / max(warmdown_ms, 1e-9) if remaining_ms <= warmdown_ms else 1.0

    # Warmup primes the compiled forward/backward/optimizer paths, then we restore the
    # initial weights/optimizer state so measured training starts from the true init.
    if args.warmup_steps > 0:
        initial_model_state = {name: tensor.detach().cpu().clone() for name, tensor in base_model.state_dict().items()}
        initial_optimizer_states = [copy.deepcopy(opt.state_dict()) for opt in optimizers]
        model.train()
        for warmup_step in range(args.warmup_steps):
            zero_grad_all()
            for micro_step in range(grad_accum_steps):
                if distributed:
                    # Only sync grads on the last micro-step of each accumulation window.
                    model.require_backward_grad_sync = micro_step == grad_accum_steps - 1
                x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps)
                with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True):
                    warmup_loss = model(x, y)
                (warmup_loss * grad_scale).backward()
            for opt in optimizers:
                opt.step()
            zero_grad_all()
            if args.warmup_steps <= 20 or (warmup_step + 1) % 10 == 0 or warmup_step + 1 == args.warmup_steps:
                log0(f"warmup_step:{warmup_step + 1}/{args.warmup_steps}")
        base_model.load_state_dict(initial_model_state, strict=True)
        for opt, state in zip(optimizers, initial_optimizer_states, strict=True):
            opt.load_state_dict(state)
        zero_grad_all()
        if distributed:
            model.require_backward_grad_sync = True
        # Fresh loader so measured training consumes the token stream from the start.
        train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device)

    # -----------------------------
    # MAIN TRAINING LOOP
    # -----------------------------

    training_time_ms = 0.0
    stop_after_step: int | None = None
    torch.cuda.synchronize()
    t0 = time.perf_counter()

    step = 0
    while True:
        last_step = step == args.iterations or (stop_after_step is not None and step >= stop_after_step)

        should_validate = last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0)
        if should_validate:
            # Pause the training clock while validating.
            torch.cuda.synchronize()
            training_time_ms += 1000.0 * (time.perf_counter() - t0)
            val_loss, val_bpb = eval_val(
                args,
                model,
                rank,
                world_size,
                device,
                grad_accum_steps,
                val_tokens,
                base_bytes_lut,
                has_leading_space_lut,
                is_boundary_token_lut,
            )
            log0(
                f"step:{step}/{args.iterations} val_loss:{val_loss:.4f} val_bpb:{val_bpb:.4f} "
                f"train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms / max(step, 1):.2f}ms"
            )
            torch.cuda.synchronize()
            t0 = time.perf_counter()

        if last_step:
            if stop_after_step is not None and step < args.iterations:
                log0(
                    f"stopping_early: wallclock_cap train_time:{training_time_ms:.0f}ms "
                    f"step:{step}/{args.iterations}"
                )
            break

        elapsed_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0)
        scale = lr_mul(step, elapsed_ms)
        zero_grad_all()
        train_loss = torch.zeros((), device=device)
        for micro_step in range(grad_accum_steps):
            if distributed:
                model.require_backward_grad_sync = micro_step == grad_accum_steps - 1
            x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps)
            with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True):
                loss = model(x, y)
            train_loss += loss.detach()
            (loss * grad_scale).backward()
        train_loss /= grad_accum_steps

        # Linearly ramp Muon momentum from its warmup value to the target.
        frac = min(step / args.muon_momentum_warmup_steps, 1.0) if args.muon_momentum_warmup_steps > 0 else 1.0
        muon_momentum = (1 - frac) * args.muon_momentum_warmup_start + frac * args.muon_momentum
        for group in optimizer_muon.param_groups:
            group["momentum"] = muon_momentum

        for opt in optimizers:
            for group in opt.param_groups:
                group["lr"] = group["base_lr"] * scale

        if args.grad_clip_norm > 0:
            torch.nn.utils.clip_grad_norm_(base_model.parameters(), args.grad_clip_norm)
        for opt in optimizers:
            opt.step()
        zero_grad_all()

        step += 1
        approx_training_time_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0)
        should_log_train = (
            args.train_log_every > 0
            and (step <= 10 or step % args.train_log_every == 0 or stop_after_step is not None)
        )
        if should_log_train:
            log0(
                f"step:{step}/{args.iterations} train_loss:{train_loss.item():.4f} "
                f"train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms / step:.2f}ms"
            )
            # NOTE(review): this diagnostics check is nested under should_log_train
            # so TRAIN_LOG_EVERY=0 can never reach the modulo — confirm against the
            # original (pre-mangling) indentation.
            if base_model.use_recurrence and step % args.train_log_every == 0:
                diag = recurrence_param_diagnostics(base_model)
                if diag:
                    log0(f"step:{step} {diag}")

        # Needed to sync whether we've reached the wallclock cap.
        reached_cap = max_wallclock_ms is not None and approx_training_time_ms >= max_wallclock_ms
        if distributed and max_wallclock_ms is not None:
            reached_cap_tensor = torch.tensor(int(reached_cap), device=device)
            dist.all_reduce(reached_cap_tensor, op=dist.ReduceOp.MAX)
            reached_cap = bool(reached_cap_tensor.item())
        if stop_after_step is None and reached_cap:
            # Finish the current step count on every rank, then stop.
            stop_after_step = step

    log0(
        f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB "
        f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB"
    )

    # -----------------------------
    # SERIALIZATION + ROUNDTRIP VALIDATION
    # -----------------------------
    # Save the raw state (useful for debugging/loading in PyTorch directly), then always produce
    # the compressed int8+zlib artifact and validate the round-tripped weights.

    if master_process:
        torch.save(base_model.state_dict(), "final_model.pt")
        model_bytes = os.path.getsize("final_model.pt")
        code_bytes = len(code.encode("utf-8"))
        log0(f"Serialized model: {model_bytes} bytes")
        log0(f"Code size: {code_bytes} bytes")
        log0(f"Total submission size: {model_bytes + code_bytes} bytes")

    # Every rank quantizes (cheap enough); only rank 0 writes the artifact.
    quant_obj, quant_stats = quantize_state_dict_int8(base_model.state_dict())
    quant_buf = io.BytesIO()
    torch.save(quant_obj, quant_buf)
    quant_raw = quant_buf.getvalue()
    quant_blob = zlib.compress(quant_raw, level=9)
    quant_raw_bytes = len(quant_raw)
    if master_process:
        with open("final_model.int8.ptz", "wb") as f:
            f.write(quant_blob)
        quant_file_bytes = os.path.getsize("final_model.int8.ptz")
        code_bytes = len(code.encode("utf-8"))
        ratio = quant_stats["baseline_tensor_bytes"] / max(quant_stats["int8_payload_bytes"], 1)
        log0(
            f"Serialized model int8+zlib: {quant_file_bytes} bytes "
            f"(payload:{quant_stats['int8_payload_bytes']} raw_torch:{quant_raw_bytes} payload_ratio:{ratio:.2f}x)"
        )
        log0(f"Total submission size int8+zlib: {quant_file_bytes + code_bytes} bytes")

    # All ranks reload the on-disk artifact and re-run validation so the reported
    # numbers reflect exactly what the submission file decodes to.
    # NOTE(review): assumes a filesystem shared by all ranks (single node) — confirm.
    if distributed:
        dist.barrier()
    with open("final_model.int8.ptz", "rb") as f:
        quant_blob_disk = f.read()
    quant_state = torch.load(io.BytesIO(zlib.decompress(quant_blob_disk)), map_location="cpu")
    base_model.load_state_dict(dequantize_state_dict_int8(quant_state), strict=True)
    torch.cuda.synchronize()
    t_qeval = time.perf_counter()
    q_val_loss, q_val_bpb = eval_val(
        args,
        model,
        rank,
        world_size,
        device,
        grad_accum_steps,
        val_tokens,
        base_bytes_lut,
        has_leading_space_lut,
        is_boundary_token_lut,
    )
    torch.cuda.synchronize()
    log0(
        f"final_int8_zlib_roundtrip val_loss:{q_val_loss:.4f} val_bpb:{q_val_bpb:.4f} "
        f"eval_time:{1000.0 * (time.perf_counter() - t_qeval):.0f}ms"
    )
    log0(f"final_int8_zlib_roundtrip_exact val_loss:{q_val_loss:.8f} val_bpb:{q_val_bpb:.8f}")

    if distributed:
        dist.destroy_process_group()
+ + +if __name__ == "__main__": + main() + +==================================================================================================== +Running Python 3.12.3 (main, Nov 6 2025, 13:44:16) [GCC 13.3.0] +Running PyTorch 2.9.1+cu128 +Wed Mar 25 19:22:42 2026 ++-----------------------------------------------------------------------------------------+ +| NVIDIA-SMI 580.126.09 Driver Version: 580.126.09 CUDA Version: 13.0 | ++-----------------------------------------+------------------------+----------------------+ +| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC | +| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. | +| | | MIG M. | +|=========================================+========================+======================| +| 0 NVIDIA H100 80GB HBM3 On | 00000000:19:00.0 Off | 0 | +| N/A 41C P0 119W / 700W | 1521MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 1 NVIDIA H100 80GB HBM3 On | 00000000:3B:00.0 Off | 0 | +| N/A 34C P0 116W / 700W | 1521MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 2 NVIDIA H100 80GB HBM3 On | 00000000:4C:00.0 Off | 0 | +| N/A 32C P0 115W / 700W | 1521MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 3 NVIDIA H100 80GB HBM3 On | 00000000:5D:00.0 Off | 0 | +| N/A 39C P0 120W / 700W | 1521MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 4 NVIDIA H100 80GB HBM3 On | 00000000:9B:00.0 Off | 0 | +| N/A 41C P0 120W / 700W | 1521MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 5 NVIDIA H100 80GB HBM3 On | 00000000:BB:00.0 Off | 0 | +| N/A 34C 
P0 120W / 700W | 1521MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 6 NVIDIA H100 80GB HBM3 On | 00000000:CB:00.0 Off | 0 | +| N/A 39C P0 119W / 700W | 1521MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 7 NVIDIA H100 80GB HBM3 On | 00000000:DB:00.0 Off | 0 | +| N/A 33C P0 117W / 700W | 1521MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ + ++-----------------------------------------------------------------------------------------+ +| Processes: | +| GPU GI CI PID Type Process name GPU Memory | +| ID ID Usage | +|=========================================================================================| +| No running processes found | ++-----------------------------------------------------------------------------------------+ + +==================================================================================================== +val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path=./data/tokenizers/fineweb_1024_bpe.model +train_loader:dataset:fineweb10B_sp1024 train_shards:80 +val_loader:shards pattern=./data/datasets/fineweb10B_sp1024/fineweb_val_*.bin tokens:62021632 +model_params:11543600 +world_size:8 grad_accum_steps:1 +sdp_backends:cudnn=False flash=True mem_efficient=False math=False +flash_attention_3:True +attention_mode:gqa num_heads:8 num_kv_heads:4 +tie_embeddings:True embed_lr:0.05 head_lr:0.0 matrix_lr:0.04 scalar_lr:0.04 +train_batch_tokens:524288 train_seq_len:1024 iterations:20000 warmup_steps:20 max_wallclock_seconds:600.000 +seed:1337 +recurrence:enabled prelude:1 shared:4 loops:3 coda:1 effective_layers:14 +peri_norm:True birkhoff_mix:True +timestep_scale:disabled +compile_mode:fullgraph=True +warmup_step:1/20 +warmup_step:2/20 +warmup_step:3/20 
+warmup_step:4/20 +warmup_step:5/20 +warmup_step:6/20 +warmup_step:7/20 +warmup_step:8/20 +warmup_step:9/20 +warmup_step:10/20 +warmup_step:11/20 +warmup_step:12/20 +warmup_step:13/20 +warmup_step:14/20 +warmup_step:15/20 +warmup_step:16/20 +warmup_step:17/20 +warmup_step:18/20 +warmup_step:19/20 +warmup_step:20/20 +step:0/20000 val_loss:6.9379 val_bpb:4.1090 train_time:0ms step_avg:0.02ms +step:1/20000 train_loss:6.9377 train_time:31ms step_avg:31.39ms +step:2/20000 train_loss:9.6216 train_time:87ms step_avg:43.29ms +step:3/20000 train_loss:7.3784 train_time:145ms step_avg:48.48ms +step:4/20000 train_loss:9.1001 train_time:201ms step_avg:50.20ms +step:5/20000 train_loss:8.5016 train_time:260ms step_avg:51.95ms +step:6/20000 train_loss:8.2035 train_time:316ms step_avg:52.60ms +step:7/20000 train_loss:6.8744 train_time:373ms step_avg:53.29ms +step:8/20000 train_loss:6.1664 train_time:431ms step_avg:53.87ms +step:9/20000 train_loss:5.6426 train_time:487ms step_avg:54.13ms +step:10/20000 train_loss:5.4025 train_time:544ms step_avg:54.39ms +step:200/20000 train_loss:2.7835 train_time:11391ms step_avg:56.96ms +step:200 shared0_alpha:mean=0.445,std=0.054 shared1_alpha:mean=0.474,std=0.046 shared2_alpha:mean=0.490,std=0.043 shared3_alpha:mean=0.521,std=0.046 eff_mlp_scale:[v0:1.2272 v1:1.2578 v2:1.2182 v3:1.2957 v4:1.3907 v5:1.2578 v6:1.2182 v7:1.2957 v8:1.3907 v9:1.2578 v10:1.2182 v11:1.2957 v12:1.3907 v13:1.9321] eff_attn_scale:[v0:0.6601 v1:0.5806 v2:0.5296 v3:0.5195 v4:0.4854 v5:0.5806 v6:0.5296 v7:0.5195 v8:0.4854 v9:0.5806 v10:0.5296 v11:0.5195 v12:0.4854 v13:0.7676] +step:200/20000 val_loss:2.7730 val_bpb:1.6423 train_time:11447ms step_avg:57.24ms +step:400/20000 train_loss:2.3659 train_time:22892ms step_avg:57.23ms +step:400 shared0_alpha:mean=0.475,std=0.055 shared1_alpha:mean=0.504,std=0.045 shared2_alpha:mean=0.524,std=0.042 shared3_alpha:mean=0.560,std=0.049 eff_mlp_scale:[v0:1.4832 v1:1.5793 v2:1.4919 v3:1.5252 v4:1.5281 v5:1.5793 v6:1.4919 v7:1.5252 
v8:1.5281 v9:1.5793 v10:1.4919 v11:1.5252 v12:1.5281 v13:2.4540] eff_attn_scale:[v0:0.3504 v1:0.2926 v2:0.2516 v3:0.2395 v4:0.2326 v5:0.2926 v6:0.2516 v7:0.2395 v8:0.2326 v9:0.2926 v10:0.2516 v11:0.2395 v12:0.2326 v13:0.4918] +step:400/20000 val_loss:2.5667 val_bpb:1.5202 train_time:22916ms step_avg:57.29ms +step:600/20000 train_loss:2.5795 train_time:34387ms step_avg:57.31ms +step:600 shared0_alpha:mean=0.493,std=0.053 shared1_alpha:mean=0.515,std=0.046 shared2_alpha:mean=0.544,std=0.042 shared3_alpha:mean=0.587,std=0.048 eff_mlp_scale:[v0:1.7160 v1:1.7823 v2:1.6664 v3:1.6622 v4:1.6349 v5:1.7823 v6:1.6664 v7:1.6622 v8:1.6349 v9:1.7823 v10:1.6664 v11:1.6622 v12:1.6349 v13:2.9567] eff_attn_scale:[v0:0.1718 v1:0.1651 v2:0.1432 v3:0.1412 v4:0.1458 v5:0.1651 v6:0.1432 v7:0.1412 v8:0.1458 v9:0.1651 v10:0.1432 v11:0.1412 v12:0.1458 v13:0.3489] +step:600/20000 val_loss:2.4826 val_bpb:1.4704 train_time:34411ms step_avg:57.35ms +step:800/20000 train_loss:2.3365 train_time:45899ms step_avg:57.37ms +step:800 shared0_alpha:mean=0.501,std=0.053 shared1_alpha:mean=0.521,std=0.048 shared2_alpha:mean=0.557,std=0.042 shared3_alpha:mean=0.604,std=0.048 eff_mlp_scale:[v0:1.9629 v1:1.9632 v2:1.8147 v3:1.7813 v4:1.7491 v5:1.9632 v6:1.8147 v7:1.7813 v8:1.7491 v9:1.9632 v10:1.8147 v11:1.7813 v12:1.7491 v13:3.3591] eff_attn_scale:[v0:0.1187 v1:0.1200 v2:0.1067 v3:0.1048 v4:0.1135 v5:0.1200 v6:0.1067 v7:0.1048 v8:0.1135 v9:0.1200 v10:0.1067 v11:0.1048 v12:0.1135 v13:0.2837] +step:800/20000 val_loss:2.4216 val_bpb:1.4342 train_time:45922ms step_avg:57.40ms +step:1000/20000 train_loss:2.4180 train_time:57412ms step_avg:57.41ms +step:1000 shared0_alpha:mean=0.505,std=0.054 shared1_alpha:mean=0.525,std=0.049 shared2_alpha:mean=0.565,std=0.043 shared3_alpha:mean=0.615,std=0.048 eff_mlp_scale:[v0:2.1871 v1:2.1357 v2:1.9597 v3:1.8977 v4:1.8674 v5:2.1357 v6:1.9597 v7:1.8977 v8:1.8674 v9:2.1357 v10:1.9597 v11:1.8977 v12:1.8674 v13:3.7013] eff_attn_scale:[v0:0.1006 v1:0.1000 v2:0.0910 v3:0.0904 
v4:0.1001 v5:0.1000 v6:0.0910 v7:0.0904 v8:0.1001 v9:0.1000 v10:0.0910 v11:0.0904 v12:0.1001 v13:0.2512] +step:1000/20000 val_loss:2.3816 val_bpb:1.4105 train_time:57435ms step_avg:57.44ms +step:1200/20000 train_loss:2.4351 train_time:68924ms step_avg:57.44ms +step:1200 shared0_alpha:mean=0.506,std=0.054 shared1_alpha:mean=0.528,std=0.050 shared2_alpha:mean=0.572,std=0.044 shared3_alpha:mean=0.624,std=0.048 eff_mlp_scale:[v0:2.4218 v1:2.2936 v2:2.0907 v3:2.0166 v4:1.9839 v5:2.2936 v6:2.0907 v7:2.0166 v8:1.9839 v9:2.2936 v10:2.0907 v11:2.0166 v12:1.9839 v13:3.9956] eff_attn_scale:[v0:0.0882 v1:0.0887 v2:0.0823 v3:0.0815 v4:0.0914 v5:0.0887 v6:0.0823 v7:0.0815 v8:0.0914 v9:0.0887 v10:0.0823 v11:0.0815 v12:0.0914 v13:0.2331] +step:1200/20000 val_loss:2.3529 val_bpb:1.3935 train_time:68947ms step_avg:57.46ms +step:1400/20000 train_loss:2.4828 train_time:80427ms step_avg:57.45ms +step:1400 shared0_alpha:mean=0.507,std=0.054 shared1_alpha:mean=0.529,std=0.051 shared2_alpha:mean=0.577,std=0.044 shared3_alpha:mean=0.630,std=0.048 eff_mlp_scale:[v0:2.6433 v1:2.4430 v2:2.2191 v3:2.1255 v4:2.1032 v5:2.4430 v6:2.2191 v7:2.1255 v8:2.1032 v9:2.4430 v10:2.2191 v11:2.1255 v12:2.1032 v13:4.2657] eff_attn_scale:[v0:0.0810 v1:0.0821 v2:0.0770 v3:0.0766 v4:0.0865 v5:0.0821 v6:0.0770 v7:0.0766 v8:0.0865 v9:0.0821 v10:0.0770 v11:0.0766 v12:0.0865 v13:0.2271] +step:1400/20000 val_loss:2.3335 val_bpb:1.3820 train_time:80450ms step_avg:57.46ms +step:1600/20000 train_loss:2.1522 train_time:91929ms step_avg:57.46ms +step:1600 shared0_alpha:mean=0.507,std=0.053 shared1_alpha:mean=0.530,std=0.052 shared2_alpha:mean=0.582,std=0.046 shared3_alpha:mean=0.636,std=0.048 eff_mlp_scale:[v0:2.8451 v1:2.5923 v2:2.3427 v3:2.2329 v4:2.2144 v5:2.5923 v6:2.3427 v7:2.2329 v8:2.2144 v9:2.5923 v10:2.3427 v11:2.2329 v12:2.2144 v13:4.5034] eff_attn_scale:[v0:0.0778 v1:0.0779 v2:0.0748 v3:0.0745 v4:0.0830 v5:0.0779 v6:0.0748 v7:0.0745 v8:0.0830 v9:0.0779 v10:0.0748 v11:0.0745 v12:0.0830 v13:0.2234] 
+step:1600/20000 val_loss:2.3205 val_bpb:1.3743 train_time:91952ms step_avg:57.47ms +step:1800/20000 train_loss:2.2544 train_time:103419ms step_avg:57.45ms +step:1800 shared0_alpha:mean=0.506,std=0.054 shared1_alpha:mean=0.530,std=0.053 shared2_alpha:mean=0.586,std=0.047 shared3_alpha:mean=0.641,std=0.048 eff_mlp_scale:[v0:3.0402 v1:2.7345 v2:2.4604 v3:2.3383 v4:2.3313 v5:2.7345 v6:2.4604 v7:2.3383 v8:2.3313 v9:2.7345 v10:2.4604 v11:2.3383 v12:2.3313 v13:4.7378] eff_attn_scale:[v0:0.0741 v1:0.0754 v2:0.0727 v3:0.0728 v4:0.0816 v5:0.0754 v6:0.0727 v7:0.0728 v8:0.0816 v9:0.0754 v10:0.0727 v11:0.0728 v12:0.0816 v13:0.2206] +step:1800/20000 val_loss:2.3072 val_bpb:1.3665 train_time:103443ms step_avg:57.47ms +step:2000/20000 train_loss:2.3106 train_time:114907ms step_avg:57.45ms +step:2000 shared0_alpha:mean=0.507,std=0.054 shared1_alpha:mean=0.531,std=0.054 shared2_alpha:mean=0.591,std=0.047 shared3_alpha:mean=0.645,std=0.048 eff_mlp_scale:[v0:3.2303 v1:2.8718 v2:2.5796 v3:2.4458 v4:2.4412 v5:2.8718 v6:2.5796 v7:2.4458 v8:2.4412 v9:2.8718 v10:2.5796 v11:2.4458 v12:2.4412 v13:4.9627] eff_attn_scale:[v0:0.0727 v1:0.0725 v2:0.0714 v3:0.0710 v4:0.0798 v5:0.0725 v6:0.0714 v7:0.0710 v8:0.0798 v9:0.0725 v10:0.0714 v11:0.0710 v12:0.0798 v13:0.2181] +step:2000/20000 val_loss:2.2919 val_bpb:1.3574 train_time:114931ms step_avg:57.47ms +step:2200/20000 train_loss:2.1358 train_time:126394ms step_avg:57.45ms +step:2200 shared0_alpha:mean=0.506,std=0.054 shared1_alpha:mean=0.530,std=0.055 shared2_alpha:mean=0.594,std=0.048 shared3_alpha:mean=0.648,std=0.048 eff_mlp_scale:[v0:3.4249 v1:3.0052 v2:2.6858 v3:2.5473 v4:2.5500 v5:3.0052 v6:2.6858 v7:2.5473 v8:2.5500 v9:3.0052 v10:2.6858 v11:2.5473 v12:2.5500 v13:5.1766] eff_attn_scale:[v0:0.0715 v1:0.0729 v2:0.0719 v3:0.0715 v4:0.0798 v5:0.0729 v6:0.0719 v7:0.0715 v8:0.0798 v9:0.0729 v10:0.0719 v11:0.0715 v12:0.0798 v13:0.2170] +step:2200/20000 val_loss:2.2850 val_bpb:1.3533 train_time:126416ms step_avg:57.46ms +step:2400/20000 
train_loss:2.2598 train_time:137884ms step_avg:57.45ms +step:2400 shared0_alpha:mean=0.506,std=0.055 shared1_alpha:mean=0.531,std=0.055 shared2_alpha:mean=0.597,std=0.049 shared3_alpha:mean=0.651,std=0.048 eff_mlp_scale:[v0:3.6017 v1:3.1359 v2:2.7893 v3:2.6439 v4:2.6551 v5:3.1359 v6:2.7893 v7:2.6439 v8:2.6551 v9:3.1359 v10:2.7893 v11:2.6439 v12:2.6551 v13:5.3813] eff_attn_scale:[v0:0.0694 v1:0.0712 v2:0.0711 v3:0.0716 v4:0.0792 v5:0.0712 v6:0.0711 v7:0.0716 v8:0.0792 v9:0.0712 v10:0.0711 v11:0.0716 v12:0.0792 v13:0.2173] +step:2400/20000 val_loss:2.2740 val_bpb:1.3468 train_time:137908ms step_avg:57.46ms +step:2600/20000 train_loss:2.4758 train_time:149368ms step_avg:57.45ms +step:2600 shared0_alpha:mean=0.506,std=0.056 shared1_alpha:mean=0.530,std=0.056 shared2_alpha:mean=0.601,std=0.050 shared3_alpha:mean=0.654,std=0.048 eff_mlp_scale:[v0:3.7750 v1:3.2609 v2:2.8953 v3:2.7390 v4:2.7574 v5:3.2609 v6:2.8953 v7:2.7390 v8:2.7574 v9:3.2609 v10:2.8953 v11:2.7390 v12:2.7574 v13:5.5683] eff_attn_scale:[v0:0.0720 v1:0.0732 v2:0.0731 v3:0.0734 v4:0.0813 v5:0.0732 v6:0.0731 v7:0.0734 v8:0.0813 v9:0.0732 v10:0.0731 v11:0.0734 v12:0.0813 v13:0.2207] +step:2600/20000 val_loss:2.2912 val_bpb:1.3570 train_time:149390ms step_avg:57.46ms +step:2800/20000 train_loss:2.2965 train_time:160849ms step_avg:57.45ms +step:2800 shared0_alpha:mean=0.506,std=0.055 shared1_alpha:mean=0.530,std=0.057 shared2_alpha:mean=0.603,std=0.051 shared3_alpha:mean=0.656,std=0.048 eff_mlp_scale:[v0:3.9576 v1:3.3860 v2:2.9914 v3:2.8310 v4:2.8586 v5:3.3860 v6:2.9914 v7:2.8310 v8:2.8586 v9:3.3860 v10:2.9914 v11:2.8310 v12:2.8586 v13:5.7610] eff_attn_scale:[v0:0.0694 v1:0.0709 v2:0.0704 v3:0.0717 v4:0.0786 v5:0.0709 v6:0.0704 v7:0.0717 v8:0.0786 v9:0.0709 v10:0.0704 v11:0.0717 v12:0.0786 v13:0.2184] +step:2800/20000 val_loss:2.2620 val_bpb:1.3397 train_time:160871ms step_avg:57.45ms +step:3000/20000 train_loss:2.2842 train_time:172319ms step_avg:57.44ms +step:3000 shared0_alpha:mean=0.506,std=0.056 
shared1_alpha:mean=0.530,std=0.057 shared2_alpha:mean=0.606,std=0.051 shared3_alpha:mean=0.658,std=0.048 eff_mlp_scale:[v0:4.1204 v1:3.5013 v2:3.0896 v3:2.9218 v4:2.9550 v5:3.5013 v6:3.0896 v7:2.9218 v8:2.9550 v9:3.5013 v10:3.0896 v11:2.9218 v12:2.9550 v13:5.9427] eff_attn_scale:[v0:0.0667 v1:0.0697 v2:0.0701 v3:0.0712 v4:0.0796 v5:0.0697 v6:0.0701 v7:0.0712 v8:0.0796 v9:0.0697 v10:0.0701 v11:0.0712 v12:0.0796 v13:0.2191] +step:3000/20000 val_loss:2.2538 val_bpb:1.3348 train_time:172341ms step_avg:57.45ms +step:3200/20000 train_loss:2.2476 train_time:183786ms step_avg:57.43ms +step:3200 shared0_alpha:mean=0.506,std=0.056 shared1_alpha:mean=0.530,std=0.058 shared2_alpha:mean=0.609,std=0.051 shared3_alpha:mean=0.660,std=0.048 eff_mlp_scale:[v0:4.2886 v1:3.6180 v2:3.1834 v3:3.0081 v4:3.0508 v5:3.6180 v6:3.1834 v7:3.0081 v8:3.0508 v9:3.6180 v10:3.1834 v11:3.0081 v12:3.0508 v13:6.1149] eff_attn_scale:[v0:0.0675 v1:0.0701 v2:0.0703 v3:0.0726 v4:0.0805 v5:0.0701 v6:0.0703 v7:0.0726 v8:0.0805 v9:0.0701 v10:0.0703 v11:0.0726 v12:0.0805 v13:0.2211] +step:3200/20000 val_loss:2.2495 val_bpb:1.3323 train_time:183809ms step_avg:57.44ms +step:3400/20000 train_loss:2.2180 train_time:195257ms step_avg:57.43ms +step:3400 shared0_alpha:mean=0.505,std=0.056 shared1_alpha:mean=0.529,std=0.059 shared2_alpha:mean=0.612,std=0.052 shared3_alpha:mean=0.662,std=0.048 eff_mlp_scale:[v0:4.4506 v1:3.7345 v2:3.2819 v3:3.0991 v4:3.1431 v5:3.7345 v6:3.2819 v7:3.0991 v8:3.1431 v9:3.7345 v10:3.2819 v11:3.0991 v12:3.1431 v13:6.2812] eff_attn_scale:[v0:0.0669 v1:0.0698 v2:0.0701 v3:0.0736 v4:0.0813 v5:0.0698 v6:0.0701 v7:0.0736 v8:0.0813 v9:0.0698 v10:0.0701 v11:0.0736 v12:0.0813 v13:0.2224] +step:3400/20000 val_loss:2.2459 val_bpb:1.3301 train_time:195279ms step_avg:57.44ms +step:3600/20000 train_loss:2.1855 train_time:206729ms step_avg:57.42ms +step:3600 shared0_alpha:mean=0.505,std=0.057 shared1_alpha:mean=0.529,std=0.059 shared2_alpha:mean=0.614,std=0.053 shared3_alpha:mean=0.663,std=0.048 
eff_mlp_scale:[v0:4.6075 v1:3.8462 v2:3.3638 v3:3.1821 v4:3.2349 v5:3.8462 v6:3.3638 v7:3.1821 v8:3.2349 v9:3.8462 v10:3.3638 v11:3.1821 v12:3.2349 v13:6.4444] eff_attn_scale:[v0:0.0664 v1:0.0708 v2:0.0705 v3:0.0735 v4:0.0813 v5:0.0708 v6:0.0705 v7:0.0735 v8:0.0813 v9:0.0708 v10:0.0705 v11:0.0735 v12:0.0813 v13:0.2252] +step:3600/20000 val_loss:2.2387 val_bpb:1.3259 train_time:206752ms step_avg:57.43ms +step:3800/20000 train_loss:2.2841 train_time:218189ms step_avg:57.42ms +step:3800 shared0_alpha:mean=0.505,std=0.057 shared1_alpha:mean=0.529,std=0.060 shared2_alpha:mean=0.617,std=0.053 shared3_alpha:mean=0.665,std=0.047 eff_mlp_scale:[v0:4.7502 v1:3.9513 v2:3.4541 v3:3.2621 v4:3.3293 v5:3.9513 v6:3.4541 v7:3.2621 v8:3.3293 v9:3.9513 v10:3.4541 v11:3.2621 v12:3.3293 v13:6.5976] eff_attn_scale:[v0:0.0667 v1:0.0712 v2:0.0714 v3:0.0748 v4:0.0816 v5:0.0712 v6:0.0714 v7:0.0748 v8:0.0816 v9:0.0712 v10:0.0714 v11:0.0748 v12:0.0816 v13:0.2277] +step:3800/20000 val_loss:2.2360 val_bpb:1.3243 train_time:218212ms step_avg:57.42ms +step:4000/20000 train_loss:2.2193 train_time:229655ms step_avg:57.41ms +step:4000 shared0_alpha:mean=0.505,std=0.058 shared1_alpha:mean=0.528,std=0.061 shared2_alpha:mean=0.619,std=0.053 shared3_alpha:mean=0.666,std=0.047 eff_mlp_scale:[v0:4.9008 v1:4.0603 v2:3.5353 v3:3.3409 v4:3.4159 v5:4.0603 v6:3.5353 v7:3.3409 v8:3.4159 v9:4.0603 v10:3.5353 v11:3.3409 v12:3.4159 v13:6.7532] eff_attn_scale:[v0:0.0654 v1:0.0707 v2:0.0713 v3:0.0752 v4:0.0825 v5:0.0707 v6:0.0713 v7:0.0752 v8:0.0825 v9:0.0707 v10:0.0713 v11:0.0752 v12:0.0825 v13:0.2289] +step:4000/20000 val_loss:2.2306 val_bpb:1.3211 train_time:229680ms step_avg:57.42ms +step:4200/20000 train_loss:2.2331 train_time:241182ms step_avg:57.42ms +step:4200 shared0_alpha:mean=0.504,std=0.058 shared1_alpha:mean=0.528,std=0.061 shared2_alpha:mean=0.621,std=0.053 shared3_alpha:mean=0.668,std=0.048 eff_mlp_scale:[v0:5.0474 v1:4.1642 v2:3.6144 v3:3.4149 v4:3.5012 v5:4.1642 v6:3.6144 v7:3.4149 v8:3.5012 
v9:4.1642 v10:3.6144 v11:3.4149 v12:3.5012 v13:6.8981] eff_attn_scale:[v0:0.0650 v1:0.0709 v2:0.0722 v3:0.0760 v4:0.0839 v5:0.0709 v6:0.0722 v7:0.0760 v8:0.0839 v9:0.0709 v10:0.0722 v11:0.0760 v12:0.0839 v13:0.2334] +step:4200/20000 val_loss:2.2265 val_bpb:1.3187 train_time:241204ms step_avg:57.43ms +step:4400/20000 train_loss:2.1750 train_time:252637ms step_avg:57.42ms +step:4400 shared0_alpha:mean=0.504,std=0.059 shared1_alpha:mean=0.528,std=0.062 shared2_alpha:mean=0.623,std=0.053 shared3_alpha:mean=0.670,std=0.047 eff_mlp_scale:[v0:5.1907 v1:4.2607 v2:3.6949 v3:3.4868 v4:3.5843 v5:4.2607 v6:3.6949 v7:3.4868 v8:3.5843 v9:4.2607 v10:3.6949 v11:3.4868 v12:3.5843 v13:7.0434] eff_attn_scale:[v0:0.0662 v1:0.0718 v2:0.0726 v3:0.0776 v4:0.0853 v5:0.0718 v6:0.0726 v7:0.0776 v8:0.0853 v9:0.0718 v10:0.0726 v11:0.0776 v12:0.0853 v13:0.2365] +step:4400/20000 val_loss:2.2279 val_bpb:1.3195 train_time:252659ms step_avg:57.42ms +step:4600/20000 train_loss:2.0360 train_time:264095ms step_avg:57.41ms +step:4600 shared0_alpha:mean=0.504,std=0.058 shared1_alpha:mean=0.527,std=0.062 shared2_alpha:mean=0.625,std=0.053 shared3_alpha:mean=0.670,std=0.047 eff_mlp_scale:[v0:5.3476 v1:4.3643 v2:3.7734 v3:3.5598 v4:3.6669 v5:4.3643 v6:3.7734 v7:3.5598 v8:3.6669 v9:4.3643 v10:3.7734 v11:3.5598 v12:3.6669 v13:7.1847] eff_attn_scale:[v0:0.0653 v1:0.0721 v2:0.0734 v3:0.0796 v4:0.0876 v5:0.0721 v6:0.0734 v7:0.0796 v8:0.0876 v9:0.0721 v10:0.0734 v11:0.0796 v12:0.0876 v13:0.2406] +step:4600/20000 val_loss:2.2237 val_bpb:1.3170 train_time:264117ms step_avg:57.42ms +step:4800/20000 train_loss:2.3253 train_time:275547ms step_avg:57.41ms +step:4800 shared0_alpha:mean=0.503,std=0.060 shared1_alpha:mean=0.526,std=0.063 shared2_alpha:mean=0.626,std=0.054 shared3_alpha:mean=0.672,std=0.047 eff_mlp_scale:[v0:5.4785 v1:4.4603 v2:3.8470 v3:3.6289 v4:3.7423 v5:4.4603 v6:3.8470 v7:3.6289 v8:3.7423 v9:4.4603 v10:3.8470 v11:3.6289 v12:3.7423 v13:7.3152] eff_attn_scale:[v0:0.0640 v1:0.0726 v2:0.0748 v3:0.0810 
v4:0.0890 v5:0.0726 v6:0.0748 v7:0.0810 v8:0.0890 v9:0.0726 v10:0.0748 v11:0.0810 v12:0.0890 v13:0.2432] +step:4800/20000 val_loss:2.2191 val_bpb:1.3143 train_time:275570ms step_avg:57.41ms +step:5000/20000 train_loss:2.0975 train_time:287017ms step_avg:57.40ms +step:5000 shared0_alpha:mean=0.504,std=0.060 shared1_alpha:mean=0.525,std=0.063 shared2_alpha:mean=0.628,std=0.054 shared3_alpha:mean=0.673,std=0.047 eff_mlp_scale:[v0:5.6246 v1:4.5548 v2:3.9182 v3:3.6932 v4:3.8205 v5:4.5548 v6:3.9182 v7:3.6932 v8:3.8205 v9:4.5548 v10:3.9182 v11:3.6932 v12:3.8205 v13:7.4428] eff_attn_scale:[v0:0.0645 v1:0.0745 v2:0.0762 v3:0.0828 v4:0.0901 v5:0.0745 v6:0.0762 v7:0.0828 v8:0.0901 v9:0.0745 v10:0.0762 v11:0.0828 v12:0.0901 v13:0.2478] +step:5000/20000 val_loss:2.2141 val_bpb:1.3113 train_time:287041ms step_avg:57.41ms +step:5200/20000 train_loss:2.2291 train_time:298470ms step_avg:57.40ms +step:5200 shared0_alpha:mean=0.503,std=0.060 shared1_alpha:mean=0.525,std=0.064 shared2_alpha:mean=0.630,std=0.054 shared3_alpha:mean=0.674,std=0.047 eff_mlp_scale:[v0:5.7628 v1:4.6504 v2:3.9846 v3:3.7586 v4:3.8975 v5:4.6504 v6:3.9846 v7:3.7586 v8:3.8975 v9:4.6504 v10:3.9846 v11:3.7586 v12:3.8975 v13:7.5695] eff_attn_scale:[v0:0.0642 v1:0.0762 v2:0.0781 v3:0.0857 v4:0.0929 v5:0.0762 v6:0.0781 v7:0.0857 v8:0.0929 v9:0.0762 v10:0.0781 v11:0.0857 v12:0.0929 v13:0.2523] +step:5200/20000 val_loss:2.2152 val_bpb:1.3120 train_time:298494ms step_avg:57.40ms +step:5400/20000 train_loss:2.2444 train_time:309937ms step_avg:57.40ms +step:5400 shared0_alpha:mean=0.503,std=0.061 shared1_alpha:mean=0.524,std=0.064 shared2_alpha:mean=0.631,std=0.054 shared3_alpha:mean=0.675,std=0.047 eff_mlp_scale:[v0:5.8953 v1:4.7314 v2:4.0484 v3:3.8255 v4:3.9725 v5:4.7314 v6:4.0484 v7:3.8255 v8:3.9725 v9:4.7314 v10:4.0484 v11:3.8255 v12:3.9725 v13:7.6953] eff_attn_scale:[v0:0.0654 v1:0.0763 v2:0.0791 v3:0.0869 v4:0.0942 v5:0.0763 v6:0.0791 v7:0.0869 v8:0.0942 v9:0.0763 v10:0.0791 v11:0.0869 v12:0.0942 v13:0.2553] 
+step:5400/20000 val_loss:2.2099 val_bpb:1.3088 train_time:309961ms step_avg:57.40ms +step:5600/20000 train_loss:2.2454 train_time:321491ms step_avg:57.41ms +step:5600 shared0_alpha:mean=0.503,std=0.061 shared1_alpha:mean=0.523,std=0.064 shared2_alpha:mean=0.633,std=0.055 shared3_alpha:mean=0.677,std=0.047 eff_mlp_scale:[v0:6.0291 v1:4.8143 v2:4.1116 v3:3.8768 v4:4.0404 v5:4.8143 v6:4.1116 v7:3.8768 v8:4.0404 v9:4.8143 v10:4.1116 v11:3.8768 v12:4.0404 v13:7.8172] eff_attn_scale:[v0:0.0658 v1:0.0785 v2:0.0819 v3:0.0901 v4:0.0972 v5:0.0785 v6:0.0819 v7:0.0901 v8:0.0972 v9:0.0785 v10:0.0819 v11:0.0901 v12:0.0972 v13:0.2599] +step:5600/20000 val_loss:2.2099 val_bpb:1.3089 train_time:321514ms step_avg:57.41ms +step:5800/20000 train_loss:2.2088 train_time:332949ms step_avg:57.41ms +step:5800 shared0_alpha:mean=0.503,std=0.061 shared1_alpha:mean=0.523,std=0.065 shared2_alpha:mean=0.635,std=0.054 shared3_alpha:mean=0.678,std=0.047 eff_mlp_scale:[v0:6.1602 v1:4.8998 v2:4.1701 v3:3.9354 v4:4.1028 v5:4.8998 v6:4.1701 v7:3.9354 v8:4.1028 v9:4.8998 v10:4.1701 v11:3.9354 v12:4.1028 v13:7.9418] eff_attn_scale:[v0:0.0647 v1:0.0811 v2:0.0848 v3:0.0930 v4:0.0998 v5:0.0811 v6:0.0848 v7:0.0930 v8:0.0998 v9:0.0811 v10:0.0848 v11:0.0930 v12:0.0998 v13:0.2658] +step:5800/20000 val_loss:2.2081 val_bpb:1.3078 train_time:332971ms step_avg:57.41ms +step:6000/20000 train_loss:2.2765 train_time:344409ms step_avg:57.40ms +step:6000 shared0_alpha:mean=0.502,std=0.062 shared1_alpha:mean=0.521,std=0.064 shared2_alpha:mean=0.636,std=0.054 shared3_alpha:mean=0.679,std=0.047 eff_mlp_scale:[v0:6.2971 v1:4.9726 v2:4.2283 v3:3.9858 v4:4.1690 v5:4.9726 v6:4.2283 v7:3.9858 v8:4.1690 v9:4.9726 v10:4.2283 v11:3.9858 v12:4.1690 v13:8.0613] eff_attn_scale:[v0:0.0646 v1:0.0823 v2:0.0871 v3:0.0955 v4:0.1027 v5:0.0823 v6:0.0871 v7:0.0955 v8:0.1027 v9:0.0823 v10:0.0871 v11:0.0955 v12:0.1027 v13:0.2706] +step:6000/20000 val_loss:2.2043 val_bpb:1.3055 train_time:344433ms step_avg:57.41ms +step:6200/20000 
train_loss:2.1509 train_time:355882ms step_avg:57.40ms +step:6200 shared0_alpha:mean=0.501,std=0.062 shared1_alpha:mean=0.520,std=0.065 shared2_alpha:mean=0.637,std=0.054 shared3_alpha:mean=0.681,std=0.047 eff_mlp_scale:[v0:6.4114 v1:5.0481 v2:4.2785 v3:4.0379 v4:4.2320 v5:5.0481 v6:4.2785 v7:4.0379 v8:4.2320 v9:5.0481 v10:4.2785 v11:4.0379 v12:4.2320 v13:8.1688] eff_attn_scale:[v0:0.0649 v1:0.0856 v2:0.0888 v3:0.0980 v4:0.1043 v5:0.0856 v6:0.0888 v7:0.0980 v8:0.1043 v9:0.0856 v10:0.0888 v11:0.0980 v12:0.1043 v13:0.2722] +step:6200/20000 val_loss:2.2039 val_bpb:1.3053 train_time:355904ms step_avg:57.40ms +step:6400/20000 train_loss:2.2259 train_time:367341ms step_avg:57.40ms +step:6400 shared0_alpha:mean=0.500,std=0.062 shared1_alpha:mean=0.519,std=0.066 shared2_alpha:mean=0.639,std=0.055 shared3_alpha:mean=0.682,std=0.047 eff_mlp_scale:[v0:6.5351 v1:5.1226 v2:4.3308 v3:4.0836 v4:4.2879 v5:5.1226 v6:4.3308 v7:4.0836 v8:4.2879 v9:5.1226 v10:4.3308 v11:4.0836 v12:4.2879 v13:8.2928] eff_attn_scale:[v0:0.0635 v1:0.0877 v2:0.0915 v3:0.1007 v4:0.1079 v5:0.0877 v6:0.0915 v7:0.1007 v8:0.1079 v9:0.0877 v10:0.0915 v11:0.1007 v12:0.1079 v13:0.2803] +step:6400/20000 val_loss:2.2003 val_bpb:1.3032 train_time:367365ms step_avg:57.40ms +step:6600/20000 train_loss:2.1869 train_time:378798ms step_avg:57.39ms +step:6600 shared0_alpha:mean=0.499,std=0.062 shared1_alpha:mean=0.519,std=0.066 shared2_alpha:mean=0.640,std=0.055 shared3_alpha:mean=0.684,std=0.047 eff_mlp_scale:[v0:6.6441 v1:5.1860 v2:4.3720 v3:4.1323 v4:4.3483 v5:5.1860 v6:4.3720 v7:4.1323 v8:4.3483 v9:5.1860 v10:4.3720 v11:4.1323 v12:4.3483 v13:8.4074] eff_attn_scale:[v0:0.0646 v1:0.0909 v2:0.0951 v3:0.1033 v4:0.1097 v5:0.0909 v6:0.0951 v7:0.1033 v8:0.1097 v9:0.0909 v10:0.0951 v11:0.1033 v12:0.1097 v13:0.2823] +step:6600/20000 val_loss:2.1963 val_bpb:1.3008 train_time:378822ms step_avg:57.40ms +step:6800/20000 train_loss:2.2482 train_time:390261ms step_avg:57.39ms +step:6800 shared0_alpha:mean=0.497,std=0.062 
shared1_alpha:mean=0.519,std=0.065 shared2_alpha:mean=0.642,std=0.055 shared3_alpha:mean=0.685,std=0.047 eff_mlp_scale:[v0:6.7425 v1:5.2416 v2:4.4147 v3:4.1788 v4:4.4099 v5:5.2416 v6:4.4147 v7:4.1788 v8:4.4099 v9:5.2416 v10:4.4147 v11:4.1788 v12:4.4099 v13:8.5245] eff_attn_scale:[v0:0.0646 v1:0.0937 v2:0.0981 v3:0.1059 v4:0.1123 v5:0.0937 v6:0.0981 v7:0.1059 v8:0.1123 v9:0.0937 v10:0.0981 v11:0.1059 v12:0.1123 v13:0.2862] +step:6800/20000 val_loss:2.1960 val_bpb:1.3006 train_time:390285ms step_avg:57.39ms +step:7000/20000 train_loss:2.2844 train_time:401714ms step_avg:57.39ms +step:7000 shared0_alpha:mean=0.495,std=0.063 shared1_alpha:mean=0.518,std=0.066 shared2_alpha:mean=0.643,std=0.055 shared3_alpha:mean=0.687,std=0.048 eff_mlp_scale:[v0:6.8519 v1:5.2991 v2:4.4581 v3:4.2266 v4:4.4656 v5:5.2991 v6:4.4581 v7:4.2266 v8:4.4656 v9:5.2991 v10:4.4581 v11:4.2266 v12:4.4656 v13:8.6452] eff_attn_scale:[v0:0.0630 v1:0.0991 v2:0.0998 v3:0.1087 v4:0.1151 v5:0.0991 v6:0.0998 v7:0.1087 v8:0.1151 v9:0.0991 v10:0.0998 v11:0.1087 v12:0.1151 v13:0.2903] +step:7000/20000 val_loss:2.1929 val_bpb:1.2988 train_time:401738ms step_avg:57.39ms +step:7200/20000 train_loss:2.2607 train_time:413178ms step_avg:57.39ms +step:7200 shared0_alpha:mean=0.493,std=0.063 shared1_alpha:mean=0.518,std=0.066 shared2_alpha:mean=0.645,std=0.056 shared3_alpha:mean=0.688,std=0.048 eff_mlp_scale:[v0:6.9500 v1:5.3454 v2:4.4947 v3:4.2648 v4:4.5212 v5:5.3454 v6:4.4947 v7:4.2648 v8:4.5212 v9:5.3454 v10:4.4947 v11:4.2648 v12:4.5212 v13:8.7626] eff_attn_scale:[v0:0.0631 v1:0.1011 v2:0.1012 v3:0.1111 v4:0.1166 v5:0.1011 v6:0.1012 v7:0.1111 v8:0.1166 v9:0.1011 v10:0.1012 v11:0.1111 v12:0.1166 v13:0.2945] +step:7200/20000 val_loss:2.1924 val_bpb:1.2984 train_time:413195ms step_avg:57.39ms +step:7400/20000 train_loss:2.1764 train_time:424624ms step_avg:57.38ms +step:7400 shared0_alpha:mean=0.491,std=0.063 shared1_alpha:mean=0.518,std=0.066 shared2_alpha:mean=0.647,std=0.055 shared3_alpha:mean=0.690,std=0.047 
eff_mlp_scale:[v0:7.0455 v1:5.3944 v2:4.5337 v3:4.3121 v4:4.5744 v5:5.3944 v6:4.5337 v7:4.3121 v8:4.5744 v9:5.3944 v10:4.5337 v11:4.3121 v12:4.5744 v13:8.8736] eff_attn_scale:[v0:0.0645 v1:0.1028 v2:0.1050 v3:0.1124 v4:0.1170 v5:0.1028 v6:0.1050 v7:0.1124 v8:0.1170 v9:0.1028 v10:0.1050 v11:0.1124 v12:0.1170 v13:0.3000] +step:7400/20000 val_loss:2.1906 val_bpb:1.2974 train_time:424647ms step_avg:57.38ms +step:7600/20000 train_loss:2.0633 train_time:436076ms step_avg:57.38ms +step:7600 shared0_alpha:mean=0.489,std=0.063 shared1_alpha:mean=0.517,std=0.066 shared2_alpha:mean=0.648,std=0.055 shared3_alpha:mean=0.690,std=0.047 eff_mlp_scale:[v0:7.1421 v1:5.4362 v2:4.5743 v3:4.3527 v4:4.6272 v5:5.4362 v6:4.5743 v7:4.3527 v8:4.6272 v9:5.4362 v10:4.5743 v11:4.3527 v12:4.6272 v13:8.9904] eff_attn_scale:[v0:0.0626 v1:0.1057 v2:0.1074 v3:0.1144 v4:0.1194 v5:0.1057 v6:0.1074 v7:0.1144 v8:0.1194 v9:0.1057 v10:0.1074 v11:0.1144 v12:0.1194 v13:0.3038] +step:7600/20000 val_loss:2.1883 val_bpb:1.2960 train_time:436098ms step_avg:57.38ms +step:7800/20000 train_loss:2.2093 train_time:447541ms step_avg:57.38ms +step:7800 shared0_alpha:mean=0.487,std=0.063 shared1_alpha:mean=0.517,std=0.065 shared2_alpha:mean=0.649,std=0.055 shared3_alpha:mean=0.692,std=0.047 eff_mlp_scale:[v0:7.2288 v1:5.4790 v2:4.6127 v3:4.3997 v4:4.6860 v5:5.4790 v6:4.6127 v7:4.3997 v8:4.6860 v9:5.4790 v10:4.6127 v11:4.3997 v12:4.6860 v13:9.1032] eff_attn_scale:[v0:0.0637 v1:0.1072 v2:0.1084 v3:0.1156 v4:0.1204 v5:0.1072 v6:0.1084 v7:0.1156 v8:0.1204 v9:0.1072 v10:0.1084 v11:0.1156 v12:0.1204 v13:0.3068] +step:7800/20000 val_loss:2.1853 val_bpb:1.2943 train_time:447563ms step_avg:57.38ms +step:8000/20000 train_loss:2.1715 train_time:459000ms step_avg:57.37ms +step:8000 shared0_alpha:mean=0.485,std=0.063 shared1_alpha:mean=0.517,std=0.065 shared2_alpha:mean=0.651,std=0.055 shared3_alpha:mean=0.693,std=0.048 eff_mlp_scale:[v0:7.3141 v1:5.5184 v2:4.6498 v3:4.4485 v4:4.7372 v5:5.5184 v6:4.6498 v7:4.4485 v8:4.7372 
v9:5.5184 v10:4.6498 v11:4.4485 v12:4.7372 v13:9.2187] eff_attn_scale:[v0:0.0634 v1:0.1099 v2:0.1099 v3:0.1173 v4:0.1211 v5:0.1099 v6:0.1099 v7:0.1173 v8:0.1211 v9:0.1099 v10:0.1099 v11:0.1173 v12:0.1211 v13:0.3094] +step:8000/20000 val_loss:2.1830 val_bpb:1.2929 train_time:459021ms step_avg:57.38ms +step:8200/20000 train_loss:2.2420 train_time:470458ms step_avg:57.37ms +step:8200 shared0_alpha:mean=0.483,std=0.063 shared1_alpha:mean=0.517,std=0.066 shared2_alpha:mean=0.653,std=0.056 shared3_alpha:mean=0.695,std=0.047 eff_mlp_scale:[v0:7.4092 v1:5.5668 v2:4.6873 v3:4.4928 v4:4.7978 v5:5.5668 v6:4.6873 v7:4.4928 v8:4.7978 v9:5.5668 v10:4.6873 v11:4.4928 v12:4.7978 v13:9.3390] eff_attn_scale:[v0:0.0640 v1:0.1117 v2:0.1111 v3:0.1192 v4:0.1227 v5:0.1117 v6:0.1111 v7:0.1192 v8:0.1227 v9:0.1117 v10:0.1111 v11:0.1192 v12:0.1227 v13:0.3133] +step:8200/20000 val_loss:2.1822 val_bpb:1.2924 train_time:470482ms step_avg:57.38ms +step:8400/20000 train_loss:2.1937 train_time:481977ms step_avg:57.38ms +step:8400 shared0_alpha:mean=0.481,std=0.062 shared1_alpha:mean=0.517,std=0.066 shared2_alpha:mean=0.655,std=0.056 shared3_alpha:mean=0.696,std=0.047 eff_mlp_scale:[v0:7.4919 v1:5.6076 v2:4.7245 v3:4.5381 v4:4.8513 v5:5.6076 v6:4.7245 v7:4.5381 v8:4.8513 v9:5.6076 v10:4.7245 v11:4.5381 v12:4.8513 v13:9.4506] eff_attn_scale:[v0:0.0641 v1:0.1146 v2:0.1138 v3:0.1211 v4:0.1255 v5:0.1146 v6:0.1138 v7:0.1211 v8:0.1255 v9:0.1146 v10:0.1138 v11:0.1211 v12:0.1255 v13:0.3161] +step:8400/20000 val_loss:2.1826 val_bpb:1.2927 train_time:481998ms step_avg:57.38ms +step:8600/20000 train_loss:2.1968 train_time:493428ms step_avg:57.38ms +step:8600 shared0_alpha:mean=0.480,std=0.063 shared1_alpha:mean=0.516,std=0.066 shared2_alpha:mean=0.656,std=0.056 shared3_alpha:mean=0.697,std=0.048 eff_mlp_scale:[v0:7.5842 v1:5.6452 v2:4.7610 v3:4.5847 v4:4.9082 v5:5.6452 v6:4.7610 v7:4.5847 v8:4.9082 v9:5.6452 v10:4.7610 v11:4.5847 v12:4.9082 v13:9.5691] eff_attn_scale:[v0:0.0647 v1:0.1152 v2:0.1141 v3:0.1225 
v4:0.1251 v5:0.1152 v6:0.1141 v7:0.1225 v8:0.1251 v9:0.1152 v10:0.1141 v11:0.1225 v12:0.1251 v13:0.3181] +step:8600/20000 val_loss:2.1793 val_bpb:1.2907 train_time:493453ms step_avg:57.38ms +step:8800/20000 train_loss:2.1656 train_time:504889ms step_avg:57.37ms +step:8800 shared0_alpha:mean=0.478,std=0.063 shared1_alpha:mean=0.516,std=0.066 shared2_alpha:mean=0.657,std=0.056 shared3_alpha:mean=0.699,std=0.048 eff_mlp_scale:[v0:7.6625 v1:5.6841 v2:4.7996 v3:4.6281 v4:4.9624 v5:5.6841 v6:4.7996 v7:4.6281 v8:4.9624 v9:5.6841 v10:4.7996 v11:4.6281 v12:4.9624 v13:9.6825] eff_attn_scale:[v0:0.0647 v1:0.1175 v2:0.1154 v3:0.1258 v4:0.1274 v5:0.1175 v6:0.1154 v7:0.1258 v8:0.1274 v9:0.1175 v10:0.1154 v11:0.1258 v12:0.1274 v13:0.3224] +step:8800/20000 val_loss:2.1788 val_bpb:1.2904 train_time:504912ms step_avg:57.38ms +step:9000/20000 train_loss:2.0864 train_time:516349ms step_avg:57.37ms +step:9000 shared0_alpha:mean=0.476,std=0.063 shared1_alpha:mean=0.515,std=0.066 shared2_alpha:mean=0.659,std=0.056 shared3_alpha:mean=0.700,std=0.048 eff_mlp_scale:[v0:7.7516 v1:5.7250 v2:4.8396 v3:4.6751 v4:5.0149 v5:5.7250 v6:4.8396 v7:4.6751 v8:5.0149 v9:5.7250 v10:4.8396 v11:4.6751 v12:5.0149 v13:9.8016] eff_attn_scale:[v0:0.0650 v1:0.1190 v2:0.1164 v3:0.1256 v4:0.1296 v5:0.1190 v6:0.1164 v7:0.1256 v8:0.1296 v9:0.1190 v10:0.1164 v11:0.1256 v12:0.1296 v13:0.3243] +step:9000/20000 val_loss:2.1777 val_bpb:1.2897 train_time:516368ms step_avg:57.37ms +step:9200/20000 train_loss:2.1450 train_time:527808ms step_avg:57.37ms +step:9200 shared0_alpha:mean=0.475,std=0.062 shared1_alpha:mean=0.516,std=0.065 shared2_alpha:mean=0.660,std=0.056 shared3_alpha:mean=0.701,std=0.048 eff_mlp_scale:[v0:7.8317 v1:5.7715 v2:4.8791 v3:4.7231 v4:5.0707 v5:5.7715 v6:4.8791 v7:4.7231 v8:5.0707 v9:5.7715 v10:4.8791 v11:4.7231 v12:5.0707 v13:9.9149] eff_attn_scale:[v0:0.0639 v1:0.1202 v2:0.1183 v3:0.1271 v4:0.1294 v5:0.1202 v6:0.1183 v7:0.1271 v8:0.1294 v9:0.1202 v10:0.1183 v11:0.1271 v12:0.1294 v13:0.3259] 
+step:9200/20000 val_loss:2.1759 val_bpb:1.2887 train_time:527830ms step_avg:57.37ms +step:9400/20000 train_loss:2.1990 train_time:539264ms step_avg:57.37ms +step:9400 shared0_alpha:mean=0.473,std=0.063 shared1_alpha:mean=0.514,std=0.066 shared2_alpha:mean=0.661,std=0.056 shared3_alpha:mean=0.702,std=0.048 eff_mlp_scale:[v0:7.9080 v1:5.8017 v2:4.9106 v3:4.7674 v4:5.1213 v5:5.8017 v6:4.9106 v7:4.7674 v8:5.1213 v9:5.8017 v10:4.9106 v11:4.7674 v12:5.1213 v13:10.0210] eff_attn_scale:[v0:0.0630 v1:0.1210 v2:0.1193 v3:0.1293 v4:0.1292 v5:0.1210 v6:0.1193 v7:0.1293 v8:0.1292 v9:0.1210 v10:0.1193 v11:0.1293 v12:0.1292 v13:0.3287] +step:9400/20000 val_loss:2.1688 val_bpb:1.2845 train_time:539287ms step_avg:57.37ms +step:9600/20000 train_loss:2.1948 train_time:550723ms step_avg:57.37ms +step:9600 shared0_alpha:mean=0.472,std=0.062 shared1_alpha:mean=0.514,std=0.065 shared2_alpha:mean=0.662,std=0.056 shared3_alpha:mean=0.703,std=0.048 eff_mlp_scale:[v0:7.9682 v1:5.8351 v2:4.9413 v3:4.8047 v4:5.1669 v5:5.8351 v6:4.9413 v7:4.8047 v8:5.1669 v9:5.8351 v10:4.9413 v11:4.8047 v12:5.1669 v13:10.1222] eff_attn_scale:[v0:0.0631 v1:0.1242 v2:0.1214 v3:0.1324 v4:0.1343 v5:0.1242 v6:0.1214 v7:0.1324 v8:0.1343 v9:0.1242 v10:0.1214 v11:0.1324 v12:0.1343 v13:0.3338] +step:9600/20000 val_loss:2.1600 val_bpb:1.2793 train_time:550746ms step_avg:57.37ms +step:9800/20000 train_loss:2.1124 train_time:562175ms step_avg:57.36ms +step:9800 shared0_alpha:mean=0.471,std=0.062 shared1_alpha:mean=0.514,std=0.065 shared2_alpha:mean=0.662,std=0.056 shared3_alpha:mean=0.704,std=0.048 eff_mlp_scale:[v0:8.0155 v1:5.8620 v2:4.9632 v3:4.8406 v4:5.2018 v5:5.8620 v6:4.9632 v7:4.8406 v8:5.2018 v9:5.8620 v10:4.9632 v11:4.8406 v12:5.2018 v13:10.2185] eff_attn_scale:[v0:0.0627 v1:0.1273 v2:0.1252 v3:0.1363 v4:0.1381 v5:0.1273 v6:0.1252 v7:0.1363 v8:0.1381 v9:0.1273 v10:0.1252 v11:0.1363 v12:0.1381 v13:0.3390] +step:9800/20000 val_loss:2.1503 val_bpb:1.2735 train_time:562200ms step_avg:57.37ms +step:10000/20000 
train_loss:2.1440 train_time:573631ms step_avg:57.36ms +step:10000 shared0_alpha:mean=0.471,std=0.062 shared1_alpha:mean=0.513,std=0.065 shared2_alpha:mean=0.662,std=0.057 shared3_alpha:mean=0.704,std=0.048 eff_mlp_scale:[v0:8.0456 v1:5.8834 v2:4.9828 v3:4.8627 v4:5.2293 v5:5.8834 v6:4.9828 v7:4.8627 v8:5.2293 v9:5.8834 v10:4.9828 v11:4.8627 v12:5.2293 v13:10.2956] eff_attn_scale:[v0:0.0613 v1:0.1291 v2:0.1290 v3:0.1394 v4:0.1420 v5:0.1291 v6:0.1290 v7:0.1394 v8:0.1420 v9:0.1291 v10:0.1290 v11:0.1394 v12:0.1420 v13:0.3431] +step:10000/20000 val_loss:2.1414 val_bpb:1.2683 train_time:573655ms step_avg:57.37ms +step:10200/20000 train_loss:2.0848 train_time:585091ms step_avg:57.36ms +step:10200 shared0_alpha:mean=0.470,std=0.062 shared1_alpha:mean=0.513,std=0.065 shared2_alpha:mean=0.662,std=0.057 shared3_alpha:mean=0.704,std=0.048 eff_mlp_scale:[v0:8.0616 v1:5.8984 v2:4.9941 v3:4.8799 v4:5.2463 v5:5.8984 v6:4.9941 v7:4.8799 v8:5.2463 v9:5.8984 v10:4.9941 v11:4.8799 v12:5.2463 v13:10.3493] eff_attn_scale:[v0:0.0606 v1:0.1303 v2:0.1296 v3:0.1413 v4:0.1426 v5:0.1303 v6:0.1296 v7:0.1413 v8:0.1426 v9:0.1303 v10:0.1296 v11:0.1413 v12:0.1426 v13:0.3458] +step:10200/20000 val_loss:2.1312 val_bpb:1.2622 train_time:585113ms step_avg:57.36ms +step:10400/20000 train_loss:2.1149 train_time:596544ms step_avg:57.36ms +step:10400 shared0_alpha:mean=0.470,std=0.062 shared1_alpha:mean=0.513,std=0.065 shared2_alpha:mean=0.662,std=0.057 shared3_alpha:mean=0.704,std=0.048 eff_mlp_scale:[v0:8.0672 v1:5.9037 v2:4.9990 v3:4.8884 v4:5.2554 v5:5.9037 v6:4.9990 v7:4.8884 v8:5.2554 v9:5.9037 v10:4.9990 v11:4.8884 v12:5.2554 v13:10.3716] eff_attn_scale:[v0:0.0590 v1:0.1291 v2:0.1300 v3:0.1412 v4:0.1431 v5:0.1291 v6:0.1300 v7:0.1412 v8:0.1431 v9:0.1291 v10:0.1300 v11:0.1412 v12:0.1431 v13:0.3473] +step:10400/20000 val_loss:2.1232 val_bpb:1.2575 train_time:596568ms step_avg:57.36ms +step:10461/20000 val_loss:2.1218 val_bpb:1.2567 train_time:600045ms step_avg:57.36ms +stopping_early: wallclock_cap 
train_time:600045ms step:10461/20000 +peak memory allocated: 13735 MiB reserved: 13808 MiB +Serialized model: 45149624 bytes +Code size: 57024 bytes +Total submission size: 45206648 bytes +Serialized model int8+zlib: 10698550 bytes (payload:11610304 raw_torch:11640701 payload_ratio:3.89x) +Total submission size int8+zlib: 10755574 bytes +final_int8_zlib_roundtrip val_loss:2.1393 val_bpb:1.2670 eval_time:1834ms +final_int8_zlib_roundtrip_exact val_loss:2.13925806 val_bpb:1.26698912 diff --git a/records/track_non_record_16mb/2026-03-26_RecurrenceFix_3Loop_Birkhoff_OutputLN_TimestepScale/logs/s2_K.txt b/records/track_non_record_16mb/2026-03-26_RecurrenceFix_3Loop_Birkhoff_OutputLN_TimestepScale/logs/s2_K.txt new file mode 100644 index 0000000000..d4830cec70 --- /dev/null +++ b/records/track_non_record_16mb/2026-03-26_RecurrenceFix_3Loop_Birkhoff_OutputLN_TimestepScale/logs/s2_K.txt @@ -0,0 +1,1587 @@ +""" +The `train_gpt.py` and `train_gpt_mlx.py` scripts are intended as good launching-off points for new participants, not SOTA configs. We'll accept PRs that tune, improve, or simplify these scripts without significantly increasing complexity, but competitive submissions should stay in the `/records` folder. + +Hard stop: To keep readable for newcomers, let's make sure `train_gpt.py` and `train_gpt_mlx.py` never are longer than 1500 lines. 
"""

from __future__ import annotations

import copy
import glob
import io
import math
import os
import random
import subprocess
import sys
import time
import uuid
import zlib
from pathlib import Path

import numpy as np
import sentencepiece as spm
import torch
import torch.distributed as dist
import torch.nn.functional as F
from torch import Tensor, nn
from torch.nn.parallel import DistributedDataParallel as DDP

# Flash Attention 3 via HuggingFace kernels (Hopper-optimized)
# If the kernel package (or a CUDA device) is unavailable, FA3_AVAILABLE stays
# False and attention falls back to torch's scaled_dot_product_attention.
try:
    from kernels import get_kernel
    cap = torch.cuda.get_device_capability()
    # sm90 (Hopper) uses the tuned FA3 build; other capabilities use the community build.
    repo = "varunneal/flash-attention-3" if cap == (9, 0) else "kernels-community/flash-attn3"
    fa3 = get_kernel(repo).flash_attn_interface
    FA3_AVAILABLE = True
except Exception:
    FA3_AVAILABLE = False

# -----------------------------
# HYPERPARAMETERS
# -----------------------------
# Default Simple Baseline run:
# - 9 transformer blocks at width 512
# - 8 attention heads with 4 KV heads (GQA) and 2x MLP expansion
# - vocab size 1024, sequence length 1024, tied embeddings
# - 524,288 train tokens per step for 20,000 iterations with a ~10 minute cap

class Hyperparameters:
    """Run configuration. Every field reads an environment variable with a default,
    so a run can be reconfigured without editing the script."""

    # Data paths are shard globs produced by the existing preprocessing pipeline.
    data_path = os.environ.get("DATA_PATH", "./data/datasets/fineweb10B_sp1024")
    train_files = os.path.join(data_path, "fineweb_train_*.bin")
    val_files = os.path.join(data_path, "fineweb_val_*.bin")
    tokenizer_path = os.environ.get("TOKENIZER_PATH", "./data/tokenizers/fineweb_1024_bpe.model")
    run_id = os.environ.get("RUN_ID", str(uuid.uuid4()))
    seed = int(os.environ.get("SEED", 1337))

    # Validation cadence and batch size. Validation always uses the full fineweb_val split.
    val_batch_size = int(os.environ.get("VAL_BATCH_SIZE", 524_288))
    val_loss_every = int(os.environ.get("VAL_LOSS_EVERY", 1000))
    train_log_every = int(os.environ.get("TRAIN_LOG_EVERY", 200))

    # Training length.
    iterations = int(os.environ.get("ITERATIONS", 20000))
    warmdown_iters = int(os.environ.get("WARMDOWN_ITERS", 1200))
    warmup_steps = int(os.environ.get("WARMUP_STEPS", 20))
    train_batch_tokens = int(os.environ.get("TRAIN_BATCH_TOKENS", 524_288))
    train_seq_len = int(os.environ.get("TRAIN_SEQ_LEN", 1024))
    max_wallclock_seconds = float(os.environ.get("MAX_WALLCLOCK_SECONDS", 600.0))
    qk_gain_init = float(os.environ.get("QK_GAIN_INIT", 1.5))

    # Model shape.
    vocab_size = int(os.environ.get("VOCAB_SIZE", 1024))
    num_layers = int(os.environ.get("NUM_LAYERS", 9))
    num_kv_heads = int(os.environ.get("NUM_KV_HEADS", 4))
    model_dim = int(os.environ.get("MODEL_DIM", 512))
    num_heads = int(os.environ.get("NUM_HEADS", 8))
    mlp_mult = int(os.environ.get("MLP_MULT", 2))
    tie_embeddings = bool(int(os.environ.get("TIE_EMBEDDINGS", "1")))
    rope_base = float(os.environ.get("ROPE_BASE", 10000.0))
    logit_softcap = float(os.environ.get("LOGIT_SOFTCAP", 30.0))

    # Recurrence + stability features.
    # num_shared > 0 switches the model from the standard U-Net stack to a
    # prelude/shared-loop/coda recurrent layout (see GPT.__init__).
    num_prelude = int(os.environ.get("NUM_PRELUDE", 0))
    num_coda = int(os.environ.get("NUM_CODA", 0))
    num_shared = int(os.environ.get("NUM_SHARED", 0))
    num_loops = int(os.environ.get("NUM_LOOPS", 2))
    use_peri_norm = bool(int(os.environ.get("USE_PERI_NORM", "0")))
    use_birkhoff_mix = bool(int(os.environ.get("USE_BIRKHOFF_MIX", "0")))
    use_timestep_scale = bool(int(os.environ.get("USE_TIMESTEP_SCALE", "0")))
    leaky_relu_slope = float(os.environ.get("LEAKY_RELU_SLOPE", "0.5"))
    timestep_gamma_max = float(os.environ.get("TIMESTEP_GAMMA_MAX", "0.0"))
    disable_compile = bool(int(os.environ.get("DISABLE_COMPILE", "0")))

    # Optimizer hyperparameters.
    embed_lr = float(os.environ.get("EMBED_LR", 0.6))
    head_lr = float(os.environ.get("HEAD_LR", 0.008))
    tied_embed_lr = float(os.environ.get("TIED_EMBED_LR", 0.05))
    tied_embed_init_std = float(os.environ.get("TIED_EMBED_INIT_STD", 0.005))
    matrix_lr = float(os.environ.get("MATRIX_LR", 0.04))
    scalar_lr = float(os.environ.get("SCALAR_LR", 0.04))
    muon_momentum = float(os.environ.get("MUON_MOMENTUM", 0.95))
    muon_backend_steps = int(os.environ.get("MUON_BACKEND_STEPS", 5))
    muon_momentum_warmup_start = float(os.environ.get("MUON_MOMENTUM_WARMUP_START", 0.85))
    muon_momentum_warmup_steps = int(os.environ.get("MUON_MOMENTUM_WARMUP_STEPS", 500))
    beta1 = float(os.environ.get("BETA1", 0.9))
    beta2 = float(os.environ.get("BETA2", 0.95))
    adam_eps = float(os.environ.get("ADAM_EPS", 1e-8))
    grad_clip_norm = float(os.environ.get("GRAD_CLIP_NORM", 0.0))

# -----------------------------
# MUON OPTIMIZER
# -----------------------------
#
# As borrowed from modded-nanogpt
# Background on Muon: https://kellerjordan.github.io/posts/muon/

def zeropower_via_newtonschulz5(G: Tensor, steps: int = 10, eps: float = 1e-7) -> Tensor:
    # Orthogonalize a 2D update matrix with a fast Newton-Schulz iteration.
    # Muon uses this to normalize matrix-shaped gradients before applying them.
    # Runs in bf16; the iteration uses a fixed quintic polynomial (a, b, c below).
    a, b, c = (3.4445, -4.7750, 2.0315)
    X = G.bfloat16()
    # Normalize so the spectral norm is <= 1, which the iteration requires to converge.
    X /= X.norm() + eps
    # Work on the wide orientation (rows <= cols); transpose back at the end.
    transposed = G.size(0) > G.size(1)
    if transposed:
        X = X.T
    for _ in range(steps):
        A = X @ X.T
        B = b * A + c * A @ A
        X = a * X + B @ X
    return X.T if transposed else X


class Muon(torch.optim.Optimizer):
    """Momentum + Newton-Schulz orthogonalization for matrix-shaped parameters.

    In a distributed run each rank orthogonalizes a round-robin subset of the
    group's parameters (param index mod world_size), writes its updates into a
    shared flat bf16 buffer, and an all-reduce SUM reassembles the full update
    on every rank. Single-process runs degenerate to world_size=1, rank=0.
    """

    def __init__(self, params, lr: float, momentum: float, backend_steps: int, nesterov: bool = True):
        super().__init__(
            params,
            dict(lr=lr, momentum=momentum, backend_steps=backend_steps, nesterov=nesterov),
        )

    @torch.no_grad()
    def step(self, closure=None):
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        distributed = dist.is_available() and dist.is_initialized()
        world_size = dist.get_world_size() if distributed else 1
        rank = dist.get_rank() if distributed else 0

        for group in self.param_groups:
            params = group["params"]
            if not params:
                continue
            lr = group["lr"]
            momentum = group["momentum"]
            backend_steps = group["backend_steps"]
            nesterov = group["nesterov"]

            # Flat bf16 buffer spanning every parameter in the group; ranks fill
            # disjoint slices and the all-reduce below merges them.
            total_params = sum(int(p.numel()) for p in params)
            updates_flat = torch.zeros(total_params, device=params[0].device, dtype=torch.bfloat16)

            curr = 0
            for i, p in enumerate(params):
                if i % world_size == rank and p.grad is not None:
                    g = p.grad
                    state = self.state[p]
                    if "momentum_buffer" not in state:
                        state["momentum_buffer"] = torch.zeros_like(g)
                    buf = state["momentum_buffer"]
                    buf.mul_(momentum).add_(g)
                    if nesterov:
                        g = g.add(buf, alpha=momentum)
                    g = zeropower_via_newtonschulz5(g, steps=backend_steps)
                    # Scale correction from Muon reference implementations.
                    g *= max(1, g.size(0) / g.size(1)) ** 0.5
                    updates_flat[curr : curr + p.numel()] = g.reshape(-1)
                # Offset advances for every param so slices line up across ranks.
                curr += p.numel()

            if distributed:
                dist.all_reduce(updates_flat, op=dist.ReduceOp.SUM)

            curr = 0
            for p in params:
                g = updates_flat[curr : curr + p.numel()].view_as(p).to(dtype=p.dtype)
                p.add_(g, alpha=-lr)
                curr += p.numel()

        return loss


# -----------------------------
# TOKENIZER-AGNOSTIC EVALUATION SETUP
# -----------------------------
#
# It's common for small models to have a large fraction of their parameters be embeddings, since the 2 * d_model * d_vocab vectors can be gigantic.
# Instead of locking the tokenizer, we let you bring your own and calculate our validation metrics on the average compression of the validation set.
# We calculate BPB (bits-per-byte) instead of validation loss, so we need methods to count the number of bits per token in the tokenizer.
# Note: Submissions that edit the tokenizer will be examined more carefully, since screwing this up might unjustly improve your score.

def build_sentencepiece_luts(
    sp: spm.SentencePieceProcessor, vocab_size: int, device: torch.device
) -> tuple[Tensor, Tensor, Tensor]:
    """Build per-token lookup tables used to count bytes during BPB evaluation.

    Returns three tensors indexed by token id:
      - base byte count (UTF-8 length of the piece, or 1 for byte tokens)
      - whether the piece starts with the SentencePiece word marker "▁"
      - whether the token is a boundary token (control/unknown/unused ids)
    The tables are sized max(sp_vocab_size, vocab_size) so model ids past the
    tokenizer's range still index safely (they keep the zero/boundary defaults).
    """
    sp_vocab_size = int(sp.vocab_size())
    table_size = max(sp_vocab_size, vocab_size)
    base_bytes_np = np.zeros((table_size,), dtype=np.int16)
    has_leading_space_np = np.zeros((table_size,), dtype=np.bool_)
    is_boundary_token_np = np.ones((table_size,), dtype=np.bool_)
    for token_id in range(sp_vocab_size):
        if sp.is_control(token_id) or sp.is_unknown(token_id) or sp.is_unused(token_id):
            continue
        is_boundary_token_np[token_id] = False
        if sp.is_byte(token_id):
            base_bytes_np[token_id] = 1
            continue
        piece = sp.id_to_piece(token_id)
        if piece.startswith("▁"):
            # The marker stands for a leading space; count it separately so the
            # space byte is only charged when the previous token isn't a boundary.
            has_leading_space_np[token_id] = True
            piece = piece[1:]
        base_bytes_np[token_id] = len(piece.encode("utf-8"))
    return (
        torch.tensor(base_bytes_np, dtype=torch.int16, device=device),
        torch.tensor(has_leading_space_np, dtype=torch.bool, device=device),
        torch.tensor(is_boundary_token_np, dtype=torch.bool, device=device),
    )


def load_validation_tokens(pattern: str, seq_len: int) -> Tensor:
    """Concatenate all validation shards matching `pattern` and trim to a whole
    number of seq_len sequences (plus one extra token for the shifted targets)."""
    files = [Path(p) for p in sorted(glob.glob(pattern))]
    if not files:
        raise FileNotFoundError(f"No files found for pattern: {pattern}")
    # The export pipeline writes the fixed first-50k-doc validation set to fineweb_val_*.
    tokens = torch.cat([load_data_shard(file) for file in files]).contiguous()
    usable = ((tokens.numel() - 1) // seq_len) * seq_len
    if usable <= 0:
        raise ValueError(f"Validation split is too short for TRAIN_SEQ_LEN={seq_len}")
    return tokens[: usable + 1]


def eval_val(
    args: Hyperparameters,
    model: nn.Module,
    rank: int,
    world_size: int,
    device: torch.device,
    grad_accum_steps: int,
    val_tokens: Tensor,
    base_bytes_lut: Tensor,
    has_leading_space_lut: Tensor,
    is_boundary_token_lut: Tensor,
) -> tuple[float, float]:
    # Validation computes two metrics:
    # - val_loss: token cross-entropy (natural log)
    # - val_bpb: tokenizer-agnostic compression metric used by the challenge
    local_batch_tokens = args.val_batch_size // (world_size * grad_accum_steps)
    if local_batch_tokens < args.train_seq_len:
        raise ValueError(
            "VAL_BATCH_SIZE must provide at least one sequence per rank; "
            f"got VAL_BATCH_SIZE={args.val_batch_size}, WORLD_SIZE={world_size}, "
            f"GRAD_ACCUM_STEPS={grad_accum_steps}, TRAIN_SEQ_LEN={args.train_seq_len}"
        )
    local_batch_seqs = local_batch_tokens // args.train_seq_len
    # Each rank evaluates a disjoint contiguous span of validation sequences.
    total_seqs = (val_tokens.numel() - 1) // args.train_seq_len
    seq_start = (total_seqs * rank) // world_size
    seq_end = (total_seqs * (rank + 1)) // world_size
    val_loss_sum = torch.zeros((), device=device, dtype=torch.float64)
    val_token_count = torch.zeros((), device=device, dtype=torch.float64)
    val_byte_count = torch.zeros((), device=device, dtype=torch.float64)

    model.eval()
    with torch.inference_mode():
        for batch_seq_start in range(seq_start, seq_end, local_batch_seqs):
            batch_seq_end = min(batch_seq_start + local_batch_seqs, seq_end)
            raw_start = batch_seq_start * args.train_seq_len
            # "+1" so the shifted (x, y) pair can be built from one contiguous span.
            raw_end = batch_seq_end * args.train_seq_len + 1
            local = val_tokens[raw_start:raw_end].to(device=device, dtype=torch.int64, non_blocking=True)
            x = local[:-1].reshape(-1, args.train_seq_len)
            y = local[1:].reshape(-1, args.train_seq_len)
            with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True):
                batch_loss = model(x, y).detach()
            batch_token_count = float(y.numel())
            val_loss_sum += batch_loss.to(torch.float64) * batch_token_count
            val_token_count += batch_token_count
            # Byte accounting: base UTF-8 bytes per target token, plus one byte
            # for an implied leading space unless the previous token is a boundary.
            prev_ids = x.reshape(-1)
            tgt_ids = y.reshape(-1)
            token_bytes = base_bytes_lut[tgt_ids].to(dtype=torch.int16)
            token_bytes += (has_leading_space_lut[tgt_ids] & ~is_boundary_token_lut[prev_ids]).to(dtype=torch.int16)
            val_byte_count += token_bytes.to(torch.float64).sum()

    if dist.is_available() and dist.is_initialized():
        dist.all_reduce(val_loss_sum, op=dist.ReduceOp.SUM)
        dist.all_reduce(val_token_count, op=dist.ReduceOp.SUM)
        dist.all_reduce(val_byte_count, op=dist.ReduceOp.SUM)

    val_loss = val_loss_sum / val_token_count
    # bpb = (bits per token) * (tokens per byte)
    bits_per_token = val_loss.item() / math.log(2.0)
    tokens_per_byte = val_token_count.item() / val_byte_count.item()
    model.train()
    return float(val_loss.item()), float(bits_per_token * tokens_per_byte)

# -----------------------------
# POST-TRAINING QUANTIZATION
# -----------------------------
#
# It's silly to export our model, which is trained in bf16 and fp32, at that same precision.
# Instead, we get approximately the same model (with a small hit) by quantizing the model to int8 & zlib compressing.
# We can then decompress the model and run in higher precision for evaluation, after closing in under the size limit.

# Substrings identifying small "control" parameters (learned scales/mixes) that
# must stay in fp32 through training and export.
CONTROL_TENSOR_NAME_PATTERNS = tuple(
    pattern
    for pattern in os.environ.get(
        "CONTROL_TENSOR_NAME_PATTERNS",
        "attn_scale,attn_scales,mlp_scale,mlp_scales,resid_mix,resid_mixes,q_gain,skip_weight,skip_weights,ts_attn,ts_mlp,resid_mix_logit",
    ).split(",")
    if pattern
)
INT8_KEEP_FLOAT_FP32_NAME_PATTERNS = tuple(
    pattern
    for pattern in os.environ.get(
        "INT8_KEEP_FLOAT_FP32_NAME_PATTERNS",
        ",".join(CONTROL_TENSOR_NAME_PATTERNS),
    ).split(",")
    if pattern
)
# Float tensors at or below this element count skip int8 and are stored directly.
INT8_KEEP_FLOAT_MAX_NUMEL = 65_536
INT8_KEEP_FLOAT_STORE_DTYPE = torch.float16
INT8_PER_ROW_SCALE_DTYPE = torch.float16
# Per-row/per-tensor clipping quantile applied before rounding to int8.
INT8_CLIP_PERCENTILE = 99.99984
INT8_CLIP_Q = INT8_CLIP_PERCENTILE / 100.0

def tensor_nbytes(t: Tensor) -> int:
    """Raw storage size of `t` in bytes (numel * element size)."""
    return int(t.numel()) * int(t.element_size())

def keep_float_tensor(name: str, t: Tensor, passthrough_orig_dtypes: dict[str, str]) -> Tensor:
    """Prepare a small float tensor for passthrough storage.

    Control tensors (name matches INT8_KEEP_FLOAT_FP32_NAME_PATTERNS) are kept
    in fp32. Other fp32/bf16 tensors are downcast to fp16 for storage, with the
    original dtype recorded in `passthrough_orig_dtypes` (mutated in place) so
    dequantization can restore it. Anything else is returned unchanged.
    """
    if any(pattern in name for pattern in INT8_KEEP_FLOAT_FP32_NAME_PATTERNS):
        return t.float().contiguous()
    if t.dtype in {torch.float32, torch.bfloat16}:
        passthrough_orig_dtypes[name] = str(t.dtype).removeprefix("torch.")
        return t.to(dtype=INT8_KEEP_FLOAT_STORE_DTYPE).contiguous()
    return t

def quantize_float_tensor(t: Tensor) -> tuple[Tensor, Tensor]:
    """Quantize a float tensor to int8, returning (q, scale).

    2D tensors use a per-row scale (scale is 1D); everything else uses one
    scalar scale. Values are clipped at the INT8_CLIP_Q quantile of |t| before
    rounding, so rare outliers don't blow up the scale.
    """
    t32 = t.float()
    if t32.ndim == 2:
        # Matrices get one scale per row, which usually tracks output-channel
        # ranges much better than a single tensor-wide scale.
        clip_abs = (
            torch.quantile(t32.abs(), INT8_CLIP_Q, dim=1)
            if t32.numel()
            else torch.empty((t32.shape[0],), dtype=torch.float32)
        )
        clipped = torch.maximum(torch.minimum(t32, clip_abs[:, None]), -clip_abs[:, None])
        scale = (clip_abs / 127.0).clamp_min(1.0 / 127.0)
        q = torch.clamp(torch.round(clipped / scale[:, None]), -127, 127).to(torch.int8).contiguous()
        return q, scale.to(dtype=INT8_PER_ROW_SCALE_DTYPE).contiguous()

    # Vectors / scalars use a simpler per-tensor scale.
    clip_abs = float(torch.quantile(t32.abs().flatten(), INT8_CLIP_Q).item()) if t32.numel() else 0.0
    scale = torch.tensor(clip_abs / 127.0 if clip_abs > 0 else 1.0, dtype=torch.float32)
    q = torch.clamp(torch.round(torch.clamp(t32, -clip_abs, clip_abs) / scale), -127, 127).to(torch.int8).contiguous()
    return q, scale

def quantize_state_dict_int8(state_dict: dict[str, Tensor]):
    # Single supported clean-script export format:
    # - per-row int8 for 2D float tensors
    # - per-tensor int8 for other float tensors
    # - exact passthrough for non-floats
    # - passthrough for small float tensors, stored as fp16 to save bytes
    quantized: dict[str, Tensor] = {}
    scales: dict[str, Tensor] = {}
    dtypes: dict[str, str] = {}
    passthrough: dict[str, Tensor] = {}
    passthrough_orig_dtypes: dict[str, str] = {}
    qmeta: dict[str, dict[str, object]] = {}
    stats = dict.fromkeys(
        ("param_count", "num_tensors", "num_float_tensors", "num_nonfloat_tensors", "baseline_tensor_bytes", "int8_payload_bytes"),
        0,
    )

    for name, tensor in state_dict.items():
        t = tensor.detach().to("cpu").contiguous()
        stats["param_count"] += int(t.numel())
        stats["num_tensors"] += 1
        stats["baseline_tensor_bytes"] += tensor_nbytes(t)

        if not t.is_floating_point():
            stats["num_nonfloat_tensors"] += 1
            passthrough[name] = t
            stats["int8_payload_bytes"] += tensor_nbytes(t)
            continue

        # Small float tensors are cheap enough to keep directly. We still downcast
        # fp32/bf16 passthrough tensors to fp16 so metadata does not dominate size.
        if t.numel() <= INT8_KEEP_FLOAT_MAX_NUMEL:
            kept = keep_float_tensor(name, t, passthrough_orig_dtypes)
            passthrough[name] = kept
            stats["int8_payload_bytes"] += tensor_nbytes(kept)
            continue

        # NOTE(review): num_float_tensors only counts float tensors that are
        # actually quantized — small float passthroughs skip this counter.
        stats["num_float_tensors"] += 1
        q, s = quantize_float_tensor(t)
        if s.ndim > 0:
            qmeta[name] = {"scheme": "per_row", "axis": 0}
        quantized[name] = q
        scales[name] = s
        dtypes[name] = str(t.dtype).removeprefix("torch.")
        stats["int8_payload_bytes"] += tensor_nbytes(q) + tensor_nbytes(s)

    obj: dict[str, object] = {
        "__quant_format__": "int8_clean_per_row_v1",
        "quantized": quantized,
        "scales": scales,
        "dtypes": dtypes,
        "passthrough": passthrough,
    }
    if qmeta:
        obj["qmeta"] = qmeta
    if passthrough_orig_dtypes:
        obj["passthrough_orig_dtypes"] = passthrough_orig_dtypes
    return obj, stats

def dequantize_state_dict_int8(obj: dict[str, object]) -> dict[str, Tensor]:
    """Invert quantize_state_dict_int8: rebuild a float state dict from the export object."""
    out: dict[str, Tensor] = {}
    qmeta = obj.get("qmeta", {})
    passthrough_orig_dtypes = obj.get("passthrough_orig_dtypes", {})
    for name, q in obj["quantized"].items():
        dtype = getattr(torch, obj["dtypes"][name])
        s = obj["scales"][name]
        if qmeta.get(name, {}).get("scheme") == "per_row" or s.ndim > 0:
            s = s.to(dtype=torch.float32)
            # Broadcast the saved row scale back across trailing dimensions.
            out[name] = (q.float() * s.view(q.shape[0], *([1] * (q.ndim - 1)))).to(dtype=dtype).contiguous()
        else:
            scale = float(s.item())
            out[name] = (q.float() * scale).to(dtype=dtype).contiguous()
    for name, t in obj["passthrough"].items():
        # Restore small tensors, undoing the temporary fp16 storage cast if needed.
+ out_t = t.detach().to("cpu").contiguous() + orig_dtype = passthrough_orig_dtypes.get(name) + if isinstance(orig_dtype, str): + out_t = out_t.to(dtype=getattr(torch, orig_dtype)).contiguous() + out[name] = out_t + return out + + +# ----------------------------- +# DATA LOADING +# ----------------------------- + +def load_data_shard(file: Path) -> Tensor: + header_bytes = 256 * np.dtype(" None: + self.file_idx = (self.file_idx + 1) % len(self.files) + self.tokens = load_data_shard(self.files[self.file_idx]) + self.pos = 0 + + def take(self, n: int) -> Tensor: + chunks: list[Tensor] = [] + remaining = n + while remaining > 0: + avail = self.tokens.numel() - self.pos + if avail <= 0: + self._advance_file() + continue + k = min(remaining, avail) + chunks.append(self.tokens[self.pos : self.pos + k]) + self.pos += k + remaining -= k + return chunks[0] if len(chunks) == 1 else torch.cat(chunks) + + +class DistributedTokenLoader: + # Each call consumes a contiguous chunk from the shared token stream, then slices out + # one disjoint span per rank. The extra "+1" token lets us build (x, y) by shifting. 
    def __init__(self, pattern: str, rank: int, world_size: int, device: torch.device):
        """Wrap a shared TokenStream; each rank later slices its own span per batch."""
        self.rank = rank
        self.world_size = world_size
        self.device = device
        self.stream = TokenStream(pattern)

    def next_batch(self, global_tokens: int, seq_len: int, grad_accum_steps: int) -> tuple[Tensor, Tensor]:
        """Return one (x, y) micro-batch for this rank, shapes (seqs, seq_len)."""
        local_tokens = global_tokens // (self.world_size * grad_accum_steps)
        # "+1" so the shifted targets y can be built from the same span as x.
        per_rank_span = local_tokens + 1
        chunk = self.stream.take(per_rank_span * self.world_size)
        start = self.rank * per_rank_span
        local = chunk[start : start + per_rank_span].to(dtype=torch.int64)
        x = local[:-1].reshape(-1, seq_len)
        y = local[1:].reshape(-1, seq_len)
        return x.to(self.device, non_blocking=True), y.to(self.device, non_blocking=True)

# -----------------------------
# TRANSFORMER MODULES
# -----------------------------

class RMSNorm(nn.Module):
    """Parameter-free RMS normalization over the last dimension."""

    def __init__(self, eps: float | None = None):
        super().__init__()
        self.eps = eps

    def forward(self, x: Tensor) -> Tensor:
        return F.rms_norm(x, (x.size(-1),), eps=self.eps)


class CastedLinear(nn.Linear):
    # Keep weights in fp32 for optimizer/state quality, cast at matmul time for bf16 compute.
    def forward(self, x: Tensor) -> Tensor:
        bias = self.bias.to(x.dtype) if self.bias is not None else None
        return F.linear(x, self.weight.to(x.dtype), bias)


def restore_low_dim_params_to_fp32(module: nn.Module) -> None:
    # Keep small/control parameters in fp32 even when the model body runs in bf16.
    # Applies to anything below rank 2 or whose name matches a control pattern.
    with torch.no_grad():
        for name, param in module.named_parameters():
            if (param.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS)) and param.dtype != torch.float32:
                param.data = param.data.float()


class Rotary(nn.Module):
    # Caches cos/sin tables per sequence length on the current device.
    def __init__(self, dim: int, base: float = 10000.0):
        super().__init__()
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim))
        self.register_buffer("inv_freq", inv_freq, persistent=False)
        self._seq_len_cached = 0
        self._cos_cached: Tensor | None = None
        self._sin_cached: Tensor | None = None

    def forward(self, seq_len: int, device: torch.device, dtype: torch.dtype) -> tuple[Tensor, Tensor]:
        """Return cos/sin tables shaped (1, 1, seq_len, dim/2), cast to `dtype`."""
        # Rebuild the fp32 cache only when the length or device changes; the
        # dtype cast on return is cheap and keeps the cache precision high.
        if (
            self._cos_cached is None
            or self._sin_cached is None
            or self._seq_len_cached != seq_len
            or self._cos_cached.device != device
        ):
            t = torch.arange(seq_len, device=device, dtype=self.inv_freq.dtype)
            freqs = torch.outer(t, self.inv_freq.to(device))
            self._cos_cached = freqs.cos()[None, None, :, :]
            self._sin_cached = freqs.sin()[None, None, :, :]
            self._seq_len_cached = seq_len
        return self._cos_cached.to(dtype=dtype), self._sin_cached.to(dtype=dtype)


def apply_rotary_emb(x: Tensor, cos: Tensor, sin: Tensor) -> Tensor:
    """Apply RoPE using the split-halves layout: rotate (x1, x2) pairs by (cos, sin)."""
    half = x.size(-1) // 2
    x1, x2 = x[..., :half], x[..., half:]
    return torch.cat((x1 * cos + x2 * sin, x1 * (-sin) + x2 * cos), dim=-1)


class CausalSelfAttention(nn.Module):
    """Causal attention with GQA, QK RMS-norm, RoPE, and a learned per-head query gain."""

    def __init__(
        self,
        dim: int,
        num_heads: int,
        num_kv_heads: int,
        rope_base: float,
        qk_gain_init: float,
    ):
        super().__init__()
        if dim % num_heads != 0:
            raise ValueError("model_dim must be divisible by num_heads")
        if num_heads % num_kv_heads != 0:
            raise ValueError("num_heads must be divisible by num_kv_heads")
        self.num_heads = num_heads
        self.num_kv_heads = num_kv_heads
        self.head_dim = dim // num_heads
        if self.head_dim % 2 != 0:
            raise ValueError("head_dim must be even for RoPE")
        kv_dim = self.num_kv_heads * self.head_dim
        self.c_q = CastedLinear(dim, dim, bias=False)
        self.c_k = CastedLinear(dim, kv_dim, bias=False)
        self.c_v = CastedLinear(dim, kv_dim, bias=False)
        self.proj = CastedLinear(dim, dim, bias=False)
        # Zero-init the output projection so each block starts as identity.
        self.proj._zero_init = True
        self.q_gain = nn.Parameter(torch.full((num_heads,), qk_gain_init, dtype=torch.float32))
        self.rotary = Rotary(self.head_dim, base=rope_base)

    def forward(self, x: Tensor) -> Tensor:
        bsz, seqlen, dim = x.shape
        q = self.c_q(x).reshape(bsz, seqlen, self.num_heads, self.head_dim).transpose(1, 2)
        k = self.c_k(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2)
        v = self.c_v(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2)
        # QK norm stabilizes logit scale; q_gain then sets an effective learned temperature.
        q = F.rms_norm(q, (q.size(-1),))
        k = F.rms_norm(k, (k.size(-1),))
        cos, sin = self.rotary(seqlen, x.device, q.dtype)
        q = apply_rotary_emb(q, cos, sin)
        k = apply_rotary_emb(k, cos, sin)
        q = q * self.q_gain.to(dtype=q.dtype)[None, :, None, None]
        if FA3_AVAILABLE:
            # FA3 expects (B, T, H, D) — transpose from (B, H, T, D)
            y = fa3.flash_attn_func(
                q.transpose(1, 2),
                k.transpose(1, 2),
                v.transpose(1, 2),
                causal=True,
            ).transpose(1, 2)
        else:
            y = F.scaled_dot_product_attention(
                q, k, v,
                attn_mask=None,
                is_causal=True,
                enable_gqa=(self.num_kv_heads != self.num_heads),
            )
        y = y.transpose(1, 2).contiguous().reshape(bsz, seqlen, dim)
        return self.proj(y)


class MLP(nn.Module):
    # leaky_relu^2 MLP (slope=0.5 matches SOTA, PRs #493/#518)
    def __init__(self, dim: int, mlp_mult: int, leaky_relu_slope: float = 0.5):
        super().__init__()
        hidden = mlp_mult * dim
        self.fc = CastedLinear(dim, hidden, bias=False)
        self.proj = CastedLinear(hidden, dim, bias=False)
        # Zero-init output projection so the block starts as identity.
        self.proj._zero_init = True
        self.leaky_relu_slope = leaky_relu_slope

    def forward(self, x: Tensor) -> Tensor:
        # Squared leaky-relu: smooth, non-negative activation.
        x = F.leaky_relu(self.fc(x), negative_slope=self.leaky_relu_slope)
        return self.proj(x.square())


class TimestepScaling(nn.Module):
    """Learned per-iteration scale vectors (arXiv:2410.01405, ICML 2025)."""
    def __init__(self, effective_layers: int, dim: int, gamma_max: float = 0.0):
        super().__init__()
        # One (dim,) gamma per unrolled layer position, for attention and MLP branches.
        self.attn_gamma = nn.Parameter(torch.ones(effective_layers, dim, dtype=torch.float32))
        self.mlp_gamma = nn.Parameter(torch.ones(effective_layers, dim, dtype=torch.float32))
        self.gamma_max = gamma_max  # 0 = uncapped

    def get(self, v: int) -> tuple[Tensor, Tensor]:
        """Return (attn_gamma, mlp_gamma) for unrolled layer position v, optionally clamped."""
        ag = self.attn_gamma[v]
        mg = self.mlp_gamma[v]
        if self.gamma_max > 0:
            ag = ag.clamp(-self.gamma_max, self.gamma_max)
            mg = mg.clamp(-self.gamma_max, self.gamma_max)
        return ag, mg


class Block(nn.Module):
    """One transformer block: residual mix with x0, attention branch, MLP branch.

    Each branch output is scaled by a learned (dim,) vector, optionally further
    multiplied by a per-timestep gamma supplied by the caller.
    """

    def __init__(
        self,
        dim: int,
        num_heads: int,
        num_kv_heads: int,
        mlp_mult: int,
        rope_base: float,
        qk_gain_init: float,
        use_peri_norm: bool = False,
        use_birkhoff_mix: bool = False,
        leaky_relu_slope: float = 0.5,
    ):
        super().__init__()
        self.use_peri_norm = use_peri_norm
        self.use_birkhoff_mix = use_birkhoff_mix
        self.attn_norm = RMSNorm()
        if use_peri_norm:
            self.mlp_out_norm = RMSNorm()  # output norm: bounds MLP contribution, preserves input magnitude
        else:
            self.mlp_norm = RMSNorm()
        self.attn = CausalSelfAttention(dim, num_heads, num_kv_heads, rope_base, qk_gain_init)
        self.mlp = MLP(dim, mlp_mult, leaky_relu_slope)
        self.attn_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32))
        self.mlp_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32))
        if use_birkhoff_mix:
            # Birkhoff B₂ parameterization: sigmoid(logit) ∈ (0,1) → doubly stochastic
            self.resid_mix_logit = nn.Parameter(torch.zeros(dim, dtype=torch.float32))
        else:
            # Free two-coefficient mix, initialized to identity: 1*x + 0*x0.
            self.resid_mix = nn.Parameter(torch.stack((torch.ones(dim), torch.zeros(dim))).float())

    def forward(self, x: Tensor, x0: Tensor,
                ts_attn_gamma: Tensor | None = None,
                ts_mlp_gamma: Tensor | None = None) -> Tensor:
        # Mix the running state with the embedding stream x0 before the branches.
        if self.use_birkhoff_mix:
            alpha = torch.sigmoid(self.resid_mix_logit.to(dtype=x.dtype))[None, None, :]
            x = alpha * x + (1 - alpha) * x0
        else:
            mix = self.resid_mix.to(dtype=x.dtype)
            x = mix[0][None, None, :] * x + mix[1][None, None, :] * x0
        attn_out = self.attn(self.attn_norm(x))
        attn_s = self.attn_scale.to(dtype=x.dtype)[None, None, :]
        if ts_attn_gamma is not None:
            attn_s = attn_s * ts_attn_gamma[None, None, :]
        x = x + attn_s * attn_out
        # peri-norm normalizes the MLP *output*; otherwise the input is normalized.
        mlp_out = self.mlp_out_norm(self.mlp(x)) if self.use_peri_norm else self.mlp(self.mlp_norm(x))
        mlp_s = self.mlp_scale.to(dtype=x.dtype)[None, None, :]
        if ts_mlp_gamma is not None:
            mlp_s = mlp_s * ts_mlp_gamma[None, None, :]
        x = x + mlp_s * mlp_out
        return x


class GPT(nn.Module):
    """Decoder-only model with two layouts: a recurrent prelude/shared-loop/coda
    stack (num_shared > 0) or a standard U-Net stack with skip weights."""

    def __init__(
        self,
        vocab_size: int,
        num_layers: int,
        model_dim: int,
        num_heads: int,
        num_kv_heads: int,
        mlp_mult: int,
        tie_embeddings: bool,
        tied_embed_init_std: float,
        logit_softcap: float,
        rope_base: float,
        qk_gain_init: float,
        num_prelude: int = 0,
        num_coda: int = 0,
        num_shared: int = 0,
        num_loops: int = 2,
        use_peri_norm: bool = False,
        use_birkhoff_mix: bool = False,
        use_timestep_scale: bool = False,
        timestep_gamma_max: float = 0.0,
        leaky_relu_slope: float = 0.5,
    ):
        super().__init__()
        if logit_softcap <= 0.0:
            raise ValueError(f"logit_softcap must be positive, got {logit_softcap}")
        self.tie_embeddings = tie_embeddings
        self.tied_embed_init_std = tied_embed_init_std
        self.logit_softcap = logit_softcap
        self.tok_emb = nn.Embedding(vocab_size, model_dim)
        # Recurrence features only activate when shared blocks exist.
        self.use_recurrence = num_shared > 0
        self.num_loops = num_loops if self.use_recurrence else 1
        self.num_prelude = num_prelude if self.use_recurrence else 0
        self.num_coda = num_coda if self.use_recurrence else 0

        block_kwargs = dict(
            dim=model_dim, num_heads=num_heads, num_kv_heads=num_kv_heads,
            mlp_mult=mlp_mult, rope_base=rope_base, qk_gain_init=qk_gain_init,
            use_peri_norm=use_peri_norm if self.use_recurrence else False,
            use_birkhoff_mix=use_birkhoff_mix if self.use_recurrence else False,
            leaky_relu_slope=leaky_relu_slope,
        )

        if self.use_recurrence:
            self.prelude_blocks = nn.ModuleList([Block(**block_kwargs) for _ in range(self.num_prelude)])
            self.shared_blocks = nn.ModuleList([Block(**block_kwargs) for _ in range(num_shared)])
            self.coda_blocks = nn.ModuleList([Block(**block_kwargs) for _ in range(self.num_coda)])
            # Timestep scales index the *unrolled* depth: prelude + loops*shared + coda.
            effective_layers = self.num_prelude + num_shared * self.num_loops + self.num_coda
            self.timestep_scale = TimestepScaling(effective_layers, model_dim, gamma_max=timestep_gamma_max) if use_timestep_scale else None
        else:
            # Standard U-Net path
            self.num_encoder_layers = num_layers // 2
            self.num_decoder_layers = num_layers - self.num_encoder_layers
            self.num_skip_weights = min(self.num_encoder_layers, self.num_decoder_layers)
            self.skip_weights = nn.Parameter(torch.ones(self.num_skip_weights, model_dim, dtype=torch.float32))
            self.blocks = nn.ModuleList([Block(**block_kwargs) for _ in range(num_layers)])
            self.timestep_scale = None

        self.final_norm = RMSNorm()
        self.lm_head = None if tie_embeddings else CastedLinear(model_dim, vocab_size, bias=False)
        if self.lm_head is not None:
            self.lm_head._zero_init = True
        self._init_weights()

    def _init_weights(self) -> None:
        # Tied embeddings get a small-std init; all _zero_init projections start at zero.
        if self.tie_embeddings:
            nn.init.normal_(self.tok_emb.weight, mean=0.0, std=self.tied_embed_init_std)
        for module in self.modules():
            if isinstance(module, nn.Linear) and getattr(module, "_zero_init", False):
                nn.init.zeros_(module.weight)

    def _get_ts(self, v: int) -> tuple[Tensor | None, Tensor | None]:
        """Timestep gammas for unrolled position v, or (None, None) when disabled."""
        if self.timestep_scale is None:
            return None, None
        return self.timestep_scale.get(v)

    def forward(self, input_ids: Tensor, target_ids: Tensor) -> Tensor:
        x = self.tok_emb(input_ids)
        x = F.rms_norm(x, (x.size(-1),))
        # x0 is the normalized embedding stream every block can mix back in.
        x0 = x

        if self.use_recurrence:
            # v counts unrolled layer positions for per-timestep scaling.
            v = 0
            for block in self.prelude_blocks:
                ag, mg = self._get_ts(v)
                x = block(x, x0, ag, mg)
                v += 1
            for _loop in range(self.num_loops):
                for block in self.shared_blocks:
                    ag, mg = self._get_ts(v)
                    x = block(x, x0, ag, mg)
                    v += 1
            for block in self.coda_blocks:
                ag, mg = self._get_ts(v)
                x = block(x, x0, ag, mg)
                v += 1
        else:
            skips: list[Tensor] = []
            for i in range(self.num_encoder_layers):
                x = self.blocks[i](x, x0)
skips.append(x) + for i in range(self.num_decoder_layers): + if skips: + x = x + self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips.pop() + x = self.blocks[self.num_encoder_layers + i](x, x0) + + x = self.final_norm(x).reshape(-1, x.size(-1)) + targets = target_ids.reshape(-1) + if self.tie_embeddings: + logits_proj = F.linear(x, self.tok_emb.weight) + else: + if self.lm_head is None: + raise RuntimeError("lm_head is required when tie_embeddings=False") + logits_proj = self.lm_head(x) + logits = self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap) + return F.cross_entropy(logits.float(), targets, reduction="mean") + + +@torch.no_grad() +def recurrence_param_diagnostics(gpt: GPT) -> str: + """Log parameter norms for recurrence health monitoring. No forward pass.""" + if not gpt.use_recurrence: + return "" + parts: list[str] = [] + + # Birkhoff alpha stats per shared block + for i, block in enumerate(gpt.shared_blocks): + if hasattr(block, "resid_mix_logit"): + a = torch.sigmoid(block.resid_mix_logit) + parts.append(f"shared{i}_alpha:mean={a.mean().item():.3f},std={a.std().item():.3f}") + + # Effective MLP/attn contribution scale per virtual layer position: + # mlp_scale * ts_mlp_gamma[v] is the learned multiplier on MLP output. + # We report RMS (norm / sqrt(numel)) to make it scale-independent. 
+ v = 0 + all_blocks = ( + list(gpt.prelude_blocks) + + list(gpt.shared_blocks) * gpt.num_loops + + list(gpt.coda_blocks) + ) + mlp_norms: list[str] = [] + attn_norms: list[str] = [] + for block in all_blocks: + ms = block.mlp_scale.norm().item() + asc = block.attn_scale.norm().item() + d = block.mlp_scale.numel() ** 0.5 + if gpt.timestep_scale is not None: + ms_g = gpt.timestep_scale.mlp_gamma[v].norm().item() + as_g = gpt.timestep_scale.attn_gamma[v].norm().item() + mlp_norms.append(f"v{v}:{ms * ms_g / d:.4f}") + attn_norms.append(f"v{v}:{asc * as_g / d:.4f}") + else: + mlp_norms.append(f"v{v}:{ms / d:.4f}") + attn_norms.append(f"v{v}:{asc / d:.4f}") + v += 1 + parts.append("eff_mlp_scale:[" + " ".join(mlp_norms) + "]") + parts.append("eff_attn_scale:[" + " ".join(attn_norms) + "]") + return " ".join(parts) + + +# ----------------------------- +# TRAINING +# ----------------------------- + +def main() -> None: + global zeropower_via_newtonschulz5 + + code = Path(__file__).read_text(encoding="utf-8") + args = Hyperparameters() + zeropower_via_newtonschulz5 = torch.compile(zeropower_via_newtonschulz5) + + # ----------------------------- + # DISTRIBUTED + CUDA SETUP + # ----------------------------- + + distributed = "RANK" in os.environ and "WORLD_SIZE" in os.environ + rank = int(os.environ.get("RANK", "0")) + world_size = int(os.environ.get("WORLD_SIZE", "1")) + local_rank = int(os.environ.get("LOCAL_RANK", "0")) + if world_size <= 0: + raise ValueError(f"WORLD_SIZE must be positive, got {world_size}") + if 8 % world_size != 0: + raise ValueError(f"WORLD_SIZE={world_size} must divide 8 so grad_accum_steps stays integral") + grad_accum_steps = 8 // world_size + grad_scale = 1.0 / grad_accum_steps + if not torch.cuda.is_available(): + raise RuntimeError("CUDA is required") + device = torch.device("cuda", local_rank) + torch.cuda.set_device(device) + if distributed: + dist.init_process_group(backend="nccl", device_id=device) + dist.barrier() + master_process = rank 
== 0 + + # Fast math knobs + torch.backends.cuda.matmul.allow_tf32 = True + torch.backends.cudnn.allow_tf32 = True + from torch.backends.cuda import enable_cudnn_sdp, enable_flash_sdp, enable_math_sdp, enable_mem_efficient_sdp + + enable_cudnn_sdp(False) + enable_flash_sdp(True) + enable_mem_efficient_sdp(False) + enable_math_sdp(False) + + logfile = None + if master_process: + os.makedirs("logs", exist_ok=True) + logfile = f"logs/{args.run_id}.txt" + print(logfile) + + def log0(msg: str, console: bool = True) -> None: + if not master_process: + return + if console: + print(msg) + if logfile is not None: + with open(logfile, "a", encoding="utf-8") as f: + print(msg, file=f) + + log0(code, console=False) + log0("=" * 100, console=False) + log0(f"Running Python {sys.version}", console=False) + log0(f"Running PyTorch {torch.__version__}", console=False) + log0( + subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, check=False).stdout, + console=False, + ) + log0("=" * 100, console=False) + + # ----------------------------- + # TOKENIZER + VALIDATION METRIC SETUP + # ----------------------------- + + random.seed(args.seed) + np.random.seed(args.seed) + torch.manual_seed(args.seed) + torch.cuda.manual_seed_all(args.seed) + + if not args.tokenizer_path.endswith(".model"): + raise ValueError(f"Script only setup for SentencePiece .model file: {args.tokenizer_path}") + sp = spm.SentencePieceProcessor(model_file=args.tokenizer_path) + if int(sp.vocab_size()) != args.vocab_size: + raise ValueError( + f"VOCAB_SIZE={args.vocab_size} does not match tokenizer vocab_size={int(sp.vocab_size())}" + ) + dataset_dir = Path(args.data_path).resolve() + actual_train_files = len(list(dataset_dir.glob("fineweb_train_*.bin"))) + val_tokens = load_validation_tokens(args.val_files, args.train_seq_len) + base_bytes_lut, has_leading_space_lut, is_boundary_token_lut = build_sentencepiece_luts( + sp, args.vocab_size, device + ) + log0(f"val_bpb:enabled 
tokenizer_kind=sentencepiece tokenizer_path={args.tokenizer_path}") + log0(f"train_loader:dataset:{dataset_dir.name} train_shards:{actual_train_files}") + log0(f"val_loader:shards pattern={args.val_files} tokens:{val_tokens.numel() - 1}") + + # ----------------------------- + # MODEL + OPTIMIZER SETUP + # ----------------------------- + + base_model = GPT( + vocab_size=args.vocab_size, + num_layers=args.num_layers, + model_dim=args.model_dim, + num_heads=args.num_heads, + num_kv_heads=args.num_kv_heads, + mlp_mult=args.mlp_mult, + tie_embeddings=args.tie_embeddings, + tied_embed_init_std=args.tied_embed_init_std, + logit_softcap=args.logit_softcap, + rope_base=args.rope_base, + qk_gain_init=args.qk_gain_init, + num_prelude=args.num_prelude, + num_coda=args.num_coda, + num_shared=args.num_shared, + num_loops=args.num_loops, + use_peri_norm=args.use_peri_norm, + use_birkhoff_mix=args.use_birkhoff_mix, + use_timestep_scale=args.use_timestep_scale, + timestep_gamma_max=args.timestep_gamma_max, + leaky_relu_slope=args.leaky_relu_slope, + ).to(device).bfloat16() + for module in base_model.modules(): + if isinstance(module, CastedLinear): + module.float() + restore_low_dim_params_to_fp32(base_model) + if args.disable_compile: + compiled_model = base_model + else: + compiled_model = torch.compile(base_model, dynamic=False, fullgraph=True) + model: nn.Module = DDP(compiled_model, device_ids=[local_rank], broadcast_buffers=False) if distributed else compiled_model + + # Optimizer split: + # - token embedding (Adam) uses EMBED_LR + # - untied lm_head (Adam) uses HEAD_LR + # - matrix params in transformer blocks use MATRIX_LR via Muon + # - vectors/scalars use SCALAR_LR via Adam + if base_model.use_recurrence: + all_block_lists = [base_model.prelude_blocks, base_model.shared_blocks, base_model.coda_blocks] + block_named_params = [] + for bl in all_block_lists: + block_named_params.extend(bl.named_parameters()) + else: + block_named_params = 
list(base_model.blocks.named_parameters()) + matrix_params = [ + p + for name, p in block_named_params + if p.ndim == 2 and not any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) + ] + scalar_params = [ + p + for name, p in block_named_params + if p.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) + ] + if hasattr(base_model, "skip_weights") and base_model.skip_weights.numel() > 0: + scalar_params.append(base_model.skip_weights) + if base_model.timestep_scale is not None: + for p in base_model.timestep_scale.parameters(): + scalar_params.append(p) + token_lr = args.tied_embed_lr if args.tie_embeddings else args.embed_lr + optimizer_tok = torch.optim.Adam( + [{"params": [base_model.tok_emb.weight], "lr": token_lr, "base_lr": token_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=True, + ) + optimizer_muon = Muon( + matrix_params, + lr=args.matrix_lr, + momentum=args.muon_momentum, + backend_steps=args.muon_backend_steps, + ) + for group in optimizer_muon.param_groups: + group["base_lr"] = args.matrix_lr + optimizer_scalar = torch.optim.Adam( + [{"params": scalar_params, "lr": args.scalar_lr, "base_lr": args.scalar_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=True, + ) + optimizers: list[torch.optim.Optimizer] = [optimizer_tok, optimizer_muon, optimizer_scalar] + if base_model.lm_head is not None: + optimizer_head = torch.optim.Adam( + [{"params": [base_model.lm_head.weight], "lr": args.head_lr, "base_lr": args.head_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=True, + ) + optimizers.insert(1, optimizer_head) + + n_params = sum(p.numel() for p in base_model.parameters()) + log0(f"model_params:{n_params}") + log0(f"world_size:{world_size} grad_accum_steps:{grad_accum_steps}") + log0("sdp_backends:cudnn=False flash=True mem_efficient=False math=False") + log0(f"flash_attention_3:{FA3_AVAILABLE}") + log0(f"attention_mode:gqa num_heads:{args.num_heads} 
num_kv_heads:{args.num_kv_heads}") + log0( + f"tie_embeddings:{args.tie_embeddings} embed_lr:{token_lr} " + f"head_lr:{args.head_lr if base_model.lm_head is not None else 0.0} " + f"matrix_lr:{args.matrix_lr} scalar_lr:{args.scalar_lr}" + ) + log0( + f"train_batch_tokens:{args.train_batch_tokens} train_seq_len:{args.train_seq_len} " + f"iterations:{args.iterations} warmup_steps:{args.warmup_steps} " + f"max_wallclock_seconds:{args.max_wallclock_seconds:.3f}" + ) + log0(f"seed:{args.seed}") + if base_model.use_recurrence: + eff = base_model.num_prelude + len(base_model.shared_blocks) * base_model.num_loops + base_model.num_coda + log0(f"recurrence:enabled prelude:{base_model.num_prelude} shared:{len(base_model.shared_blocks)} " + f"loops:{base_model.num_loops} coda:{base_model.num_coda} effective_layers:{eff}") + log0(f"peri_norm:{args.use_peri_norm} birkhoff_mix:{args.use_birkhoff_mix}") + if base_model.timestep_scale is not None: + ts_params = sum(p.numel() for p in base_model.timestep_scale.parameters()) + log0(f"timestep_scale:enabled params:{ts_params}") + else: + log0("timestep_scale:disabled") + else: + log0(f"recurrence:disabled num_layers:{args.num_layers}") + compile_mode = "disabled" if args.disable_compile else "fullgraph=True" + log0(f"compile_mode:{compile_mode}") + + # ----------------------------- + # DATA LOADER & MODEL WARMUP + # ----------------------------- + + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + + def zero_grad_all() -> None: + for opt in optimizers: + opt.zero_grad(set_to_none=True) + + max_wallclock_ms = 1000.0 * args.max_wallclock_seconds if args.max_wallclock_seconds > 0 else None + + def lr_mul(step: int, elapsed_ms: float) -> float: + if args.warmdown_iters <= 0: + return 1.0 + if max_wallclock_ms is None: + warmdown_start = max(args.iterations - args.warmdown_iters, 0) + return max((args.iterations - step) / max(args.warmdown_iters, 1), 0.0) if warmdown_start <= step < args.iterations else 
1.0 + step_ms = elapsed_ms / max(step, 1) + warmdown_ms = args.warmdown_iters * step_ms + remaining_ms = max(max_wallclock_ms - elapsed_ms, 0.0) + return remaining_ms / max(warmdown_ms, 1e-9) if remaining_ms <= warmdown_ms else 1.0 + + # Warmup primes the compiled forward/backward/optimizer paths, then we restore the + # initial weights/optimizer state so measured training starts from the true init. + if args.warmup_steps > 0: + initial_model_state = {name: tensor.detach().cpu().clone() for name, tensor in base_model.state_dict().items()} + initial_optimizer_states = [copy.deepcopy(opt.state_dict()) for opt in optimizers] + model.train() + for warmup_step in range(args.warmup_steps): + zero_grad_all() + for micro_step in range(grad_accum_steps): + if distributed: + model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 + x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + warmup_loss = model(x, y) + (warmup_loss * grad_scale).backward() + for opt in optimizers: + opt.step() + zero_grad_all() + if args.warmup_steps <= 20 or (warmup_step + 1) % 10 == 0 or warmup_step + 1 == args.warmup_steps: + log0(f"warmup_step:{warmup_step + 1}/{args.warmup_steps}") + base_model.load_state_dict(initial_model_state, strict=True) + for opt, state in zip(optimizers, initial_optimizer_states, strict=True): + opt.load_state_dict(state) + zero_grad_all() + if distributed: + model.require_backward_grad_sync = True + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + + # ----------------------------- + # MAIN TRAINING LOOP + # ----------------------------- + + training_time_ms = 0.0 + stop_after_step: int | None = None + torch.cuda.synchronize() + t0 = time.perf_counter() + + step = 0 + while True: + last_step = step == args.iterations or (stop_after_step is not None and step >= stop_after_step) + + should_validate = 
last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0) + if should_validate: + torch.cuda.synchronize() + training_time_ms += 1000.0 * (time.perf_counter() - t0) + val_loss, val_bpb = eval_val( + args, + model, + rank, + world_size, + device, + grad_accum_steps, + val_tokens, + base_bytes_lut, + has_leading_space_lut, + is_boundary_token_lut, + ) + log0( + f"step:{step}/{args.iterations} val_loss:{val_loss:.4f} val_bpb:{val_bpb:.4f} " + f"train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms / max(step, 1):.2f}ms" + ) + torch.cuda.synchronize() + t0 = time.perf_counter() + + if last_step: + if stop_after_step is not None and step < args.iterations: + log0( + f"stopping_early: wallclock_cap train_time:{training_time_ms:.0f}ms " + f"step:{step}/{args.iterations}" + ) + break + + elapsed_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + scale = lr_mul(step, elapsed_ms) + zero_grad_all() + train_loss = torch.zeros((), device=device) + for micro_step in range(grad_accum_steps): + if distributed: + model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 + x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + loss = model(x, y) + train_loss += loss.detach() + (loss * grad_scale).backward() + train_loss /= grad_accum_steps + + frac = min(step / args.muon_momentum_warmup_steps, 1.0) if args.muon_momentum_warmup_steps > 0 else 1.0 + muon_momentum = (1 - frac) * args.muon_momentum_warmup_start + frac * args.muon_momentum + for group in optimizer_muon.param_groups: + group["momentum"] = muon_momentum + + for opt in optimizers: + for group in opt.param_groups: + group["lr"] = group["base_lr"] * scale + + if args.grad_clip_norm > 0: + torch.nn.utils.clip_grad_norm_(base_model.parameters(), args.grad_clip_norm) + for opt in optimizers: + opt.step() + zero_grad_all() + + step += 1 + approx_training_time_ms = 
training_time_ms + 1000.0 * (time.perf_counter() - t0) + should_log_train = ( + args.train_log_every > 0 + and (step <= 10 or step % args.train_log_every == 0 or stop_after_step is not None) + ) + if should_log_train: + log0( + f"step:{step}/{args.iterations} train_loss:{train_loss.item():.4f} " + f"train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms / step:.2f}ms" + ) + if base_model.use_recurrence and step % args.train_log_every == 0: + diag = recurrence_param_diagnostics(base_model) + if diag: + log0(f"step:{step} {diag}") + + # Needed to sync whether we've reached the wallclock cap. + reached_cap = max_wallclock_ms is not None and approx_training_time_ms >= max_wallclock_ms + if distributed and max_wallclock_ms is not None: + reached_cap_tensor = torch.tensor(int(reached_cap), device=device) + dist.all_reduce(reached_cap_tensor, op=dist.ReduceOp.MAX) + reached_cap = bool(reached_cap_tensor.item()) + if stop_after_step is None and reached_cap: + stop_after_step = step + + log0( + f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB " + f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB" + ) + + # ----------------------------- + # SERIALIZATION + ROUNDTRIP VALIDATION + # ----------------------------- + # Save the raw state (useful for debugging/loading in PyTorch directly), then always produce + # the compressed int8+zlib artifact and validate the round-tripped weights. 
+ + if master_process: + torch.save(base_model.state_dict(), "final_model.pt") + model_bytes = os.path.getsize("final_model.pt") + code_bytes = len(code.encode("utf-8")) + log0(f"Serialized model: {model_bytes} bytes") + log0(f"Code size: {code_bytes} bytes") + log0(f"Total submission size: {model_bytes + code_bytes} bytes") + + quant_obj, quant_stats = quantize_state_dict_int8(base_model.state_dict()) + quant_buf = io.BytesIO() + torch.save(quant_obj, quant_buf) + quant_raw = quant_buf.getvalue() + quant_blob = zlib.compress(quant_raw, level=9) + quant_raw_bytes = len(quant_raw) + if master_process: + with open("final_model.int8.ptz", "wb") as f: + f.write(quant_blob) + quant_file_bytes = os.path.getsize("final_model.int8.ptz") + code_bytes = len(code.encode("utf-8")) + ratio = quant_stats["baseline_tensor_bytes"] / max(quant_stats["int8_payload_bytes"], 1) + log0( + f"Serialized model int8+zlib: {quant_file_bytes} bytes " + f"(payload:{quant_stats['int8_payload_bytes']} raw_torch:{quant_raw_bytes} payload_ratio:{ratio:.2f}x)" + ) + log0(f"Total submission size int8+zlib: {quant_file_bytes + code_bytes} bytes") + + if distributed: + dist.barrier() + with open("final_model.int8.ptz", "rb") as f: + quant_blob_disk = f.read() + quant_state = torch.load(io.BytesIO(zlib.decompress(quant_blob_disk)), map_location="cpu") + base_model.load_state_dict(dequantize_state_dict_int8(quant_state), strict=True) + torch.cuda.synchronize() + t_qeval = time.perf_counter() + q_val_loss, q_val_bpb = eval_val( + args, + model, + rank, + world_size, + device, + grad_accum_steps, + val_tokens, + base_bytes_lut, + has_leading_space_lut, + is_boundary_token_lut, + ) + torch.cuda.synchronize() + log0( + f"final_int8_zlib_roundtrip val_loss:{q_val_loss:.4f} val_bpb:{q_val_bpb:.4f} " + f"eval_time:{1000.0 * (time.perf_counter() - t_qeval):.0f}ms" + ) + log0(f"final_int8_zlib_roundtrip_exact val_loss:{q_val_loss:.8f} val_bpb:{q_val_bpb:.8f}") + + if distributed: + dist.destroy_process_group() 
+ + +if __name__ == "__main__": + main() + +==================================================================================================== +Running Python 3.12.3 (main, Nov 6 2025, 13:44:16) [GCC 13.3.0] +Running PyTorch 2.9.1+cu128 +Wed Mar 25 21:36:33 2026 ++-----------------------------------------------------------------------------------------+ +| NVIDIA-SMI 580.126.09 Driver Version: 580.126.09 CUDA Version: 13.0 | ++-----------------------------------------+------------------------+----------------------+ +| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC | +| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. | +| | | MIG M. | +|=========================================+========================+======================| +| 0 NVIDIA H100 80GB HBM3 On | 00000000:0A:00.0 Off | 0 | +| N/A 35C P0 121W / 700W | 1521MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 1 NVIDIA H100 80GB HBM3 On | 00000000:18:00.0 Off | 0 | +| N/A 29C P0 114W / 700W | 1521MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 2 NVIDIA H100 80GB HBM3 On | 00000000:2F:00.0 Off | 0 | +| N/A 29C P0 114W / 700W | 1521MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 3 NVIDIA H100 80GB HBM3 On | 00000000:38:00.0 Off | 0 | +| N/A 32C P0 118W / 700W | 1521MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 4 NVIDIA H100 80GB HBM3 On | 00000000:87:00.0 Off | 0 | +| N/A 32C P0 115W / 700W | 1521MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 5 NVIDIA H100 80GB HBM3 On | 00000000:90:00.0 Off | 0 | +| N/A 30C 
P0 118W / 700W | 1521MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 6 NVIDIA H100 80GB HBM3 On | 00000000:BE:00.0 Off | 0 | +| N/A 31C P0 121W / 700W | 1521MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 7 NVIDIA H100 80GB HBM3 On | 00000000:C7:00.0 Off | 0 | +| N/A 32C P0 118W / 700W | 1521MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ + ++-----------------------------------------------------------------------------------------+ +| Processes: | +| GPU GI CI PID Type Process name GPU Memory | +| ID ID Usage | +|=========================================================================================| +| 0 N/A N/A 1046 C /usr/local/bin/python 1512MiB | +| 1 N/A N/A 1047 C /usr/local/bin/python 1512MiB | +| 2 N/A N/A 1048 C /usr/local/bin/python 1512MiB | +| 3 N/A N/A 1049 C /usr/local/bin/python 1512MiB | +| 4 N/A N/A 1050 C /usr/local/bin/python 1512MiB | +| 5 N/A N/A 1051 C /usr/local/bin/python 1512MiB | +| 6 N/A N/A 1052 C /usr/local/bin/python 1512MiB | +| 7 N/A N/A 1053 C /usr/local/bin/python 1512MiB | ++-----------------------------------------------------------------------------------------+ + +==================================================================================================== +val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path=./data/tokenizers/fineweb_1024_bpe.model +train_loader:dataset:fineweb10B_sp1024 train_shards:80 +val_loader:shards pattern=./data/datasets/fineweb10B_sp1024/fineweb_val_*.bin tokens:62021632 +model_params:11557936 +world_size:8 grad_accum_steps:1 +sdp_backends:cudnn=False flash=True mem_efficient=False math=False +flash_attention_3:True +attention_mode:gqa num_heads:8 num_kv_heads:4 +tie_embeddings:True embed_lr:0.05 
head_lr:0.0 matrix_lr:0.04 scalar_lr:0.04 +train_batch_tokens:524288 train_seq_len:1024 iterations:20000 warmup_steps:20 max_wallclock_seconds:600.000 +seed:1337 +recurrence:enabled prelude:1 shared:4 loops:3 coda:1 effective_layers:14 +peri_norm:True birkhoff_mix:True +timestep_scale:enabled params:14336 +compile_mode:fullgraph=True +warmup_step:1/20 +warmup_step:2/20 +warmup_step:3/20 +warmup_step:4/20 +warmup_step:5/20 +warmup_step:6/20 +warmup_step:7/20 +warmup_step:8/20 +warmup_step:9/20 +warmup_step:10/20 +warmup_step:11/20 +warmup_step:12/20 +warmup_step:13/20 +warmup_step:14/20 +warmup_step:15/20 +warmup_step:16/20 +warmup_step:17/20 +warmup_step:18/20 +warmup_step:19/20 +warmup_step:20/20 +step:0/20000 val_loss:6.9379 val_bpb:4.1090 train_time:0ms step_avg:0.01ms +step:1/20000 train_loss:6.9377 train_time:34ms step_avg:33.54ms +step:2/20000 train_loss:9.6217 train_time:88ms step_avg:44.00ms +step:3/20000 train_loss:7.3812 train_time:144ms step_avg:48.09ms +step:4/20000 train_loss:9.1484 train_time:201ms step_avg:50.14ms +step:5/20000 train_loss:8.6154 train_time:257ms step_avg:51.49ms +step:6/20000 train_loss:8.3173 train_time:314ms step_avg:52.32ms +step:7/20000 train_loss:6.9135 train_time:370ms step_avg:52.87ms +step:8/20000 train_loss:6.1894 train_time:427ms step_avg:53.36ms +step:9/20000 train_loss:5.6406 train_time:483ms step_avg:53.70ms +step:10/20000 train_loss:5.4038 train_time:540ms step_avg:54.00ms +step:200/20000 train_loss:2.7823 train_time:11733ms step_avg:58.67ms +step:200 shared0_alpha:mean=0.448,std=0.049 shared1_alpha:mean=0.476,std=0.042 shared2_alpha:mean=0.489,std=0.039 shared3_alpha:mean=0.516,std=0.040 eff_mlp_scale:[v0:27.5484 v1:26.7534 v2:27.3004 v3:28.6959 v4:29.6441 v5:31.4062 v6:28.1404 v7:28.1103 v8:29.0391 v9:30.2430 v10:28.0004 v11:30.8921 v12:32.8202 v13:58.7413] eff_attn_scale:[v0:12.9666 v1:13.2604 v2:10.8659 v3:10.9078 v4:9.5361 v5:11.1804 v6:9.5996 v7:9.4641 v8:8.8824 v9:11.2670 v10:10.2124 v11:9.9453 v12:9.4207 
v13:16.9741] +step:200/20000 val_loss:2.7597 val_bpb:1.6344 train_time:11795ms step_avg:58.98ms +step:400/20000 train_loss:2.3575 train_time:23561ms step_avg:58.90ms +step:400 shared0_alpha:mean=0.454,std=0.050 shared1_alpha:mean=0.497,std=0.043 shared2_alpha:mean=0.515,std=0.039 shared3_alpha:mean=0.550,std=0.043 eff_mlp_scale:[v0:34.2854 v1:33.6046 v2:34.5982 v3:37.3141 v4:36.5859 v5:42.7999 v6:38.2653 v7:36.6623 v8:34.8208 v9:38.9546 v10:35.2360 v11:36.8252 v12:35.6232 v13:73.9149] eff_attn_scale:[v0:6.4260 v1:6.6149 v2:5.7415 v3:5.8219 v4:5.4229 v5:6.0857 v6:5.4070 v7:5.2637 v8:5.3181 v9:5.9387 v10:5.1841 v11:4.7586 v12:4.7941 v13:9.9713] +step:400/20000 val_loss:2.5650 val_bpb:1.5192 train_time:23593ms step_avg:58.98ms +step:600/20000 train_loss:2.5758 train_time:35380ms step_avg:58.97ms +step:600 shared0_alpha:mean=0.454,std=0.050 shared1_alpha:mean=0.509,std=0.043 shared2_alpha:mean=0.532,std=0.040 shared3_alpha:mean=0.574,std=0.045 eff_mlp_scale:[v0:40.3946 v1:37.9918 v2:38.8994 v3:42.4052 v4:40.5967 v5:49.2288 v6:44.6498 v7:42.9203 v8:39.4368 v9:44.2346 v10:39.0685 v11:39.8300 v12:37.1170 v13:88.3020] eff_attn_scale:[v0:3.0049 v1:3.5136 v2:3.1467 v3:3.4297 v4:3.4763 v5:3.3945 v6:3.2047 v7:3.1516 v8:3.5922 v9:3.2953 v10:2.8958 v11:2.6696 v12:2.9162 v13:6.4813] +step:600/20000 val_loss:2.4776 val_bpb:1.4674 train_time:35411ms step_avg:59.02ms +step:800/20000 train_loss:2.3367 train_time:47210ms step_avg:59.01ms +step:800 shared0_alpha:mean=0.453,std=0.050 shared1_alpha:mean=0.517,std=0.045 shared2_alpha:mean=0.544,std=0.042 shared3_alpha:mean=0.590,std=0.046 eff_mlp_scale:[v0:47.3770 v1:41.6239 v2:42.4979 v3:45.9607 v4:43.3238 v5:54.2598 v6:49.1711 v7:47.3749 v8:43.1539 v9:47.9419 v10:41.6198 v11:41.5414 v12:38.7366 v13:99.1780] eff_attn_scale:[v0:1.9183 v1:2.3710 v2:2.2032 v3:2.3913 v4:2.6813 v5:2.3556 v6:2.2797 v7:2.2455 v8:2.8579 v9:2.2633 v10:2.0349 v11:1.8226 v12:2.1354 v13:4.8487] +step:800/20000 val_loss:2.4226 val_bpb:1.4348 train_time:47241ms 
step_avg:59.05ms +step:1000/20000 train_loss:2.4170 train_time:59041ms step_avg:59.04ms +step:1000 shared0_alpha:mean=0.449,std=0.051 shared1_alpha:mean=0.523,std=0.047 shared2_alpha:mean=0.554,std=0.043 shared3_alpha:mean=0.603,std=0.048 eff_mlp_scale:[v0:54.0667 v1:45.0744 v2:45.5208 v3:49.3566 v4:46.1017 v5:57.9254 v6:52.7463 v7:50.8083 v8:46.4510 v9:51.4040 v10:43.8951 v11:43.5500 v12:40.8629 v13:108.3864] eff_attn_scale:[v0:1.4628 v1:1.8692 v2:1.7604 v3:1.9248 v4:2.2411 v5:1.8034 v6:1.7998 v7:1.8108 v8:2.3981 v9:1.7639 v10:1.5830 v11:1.4246 v12:1.7486 v13:3.9293] +step:1000/20000 val_loss:2.3828 val_bpb:1.4112 train_time:59072ms step_avg:59.07ms +step:1200/20000 train_loss:2.4374 train_time:70896ms step_avg:59.08ms +step:1200 shared0_alpha:mean=0.447,std=0.052 shared1_alpha:mean=0.527,std=0.047 shared2_alpha:mean=0.561,std=0.044 shared3_alpha:mean=0.612,std=0.049 eff_mlp_scale:[v0:59.9127 v1:48.1806 v2:48.0975 v3:52.2504 v4:48.5522 v5:61.3565 v6:56.2371 v7:54.1033 v8:49.6232 v9:53.8836 v10:46.0626 v11:45.0243 v12:42.8401 v13:117.2016] eff_attn_scale:[v0:1.2315 v1:1.5553 v2:1.4982 v3:1.6681 v4:1.9516 v5:1.5137 v6:1.5342 v7:1.5631 v8:2.0957 v9:1.4900 v10:1.3364 v11:1.2073 v12:1.4997 v13:3.4107] +step:1200/20000 val_loss:2.3552 val_bpb:1.3949 train_time:70929ms step_avg:59.11ms +step:1400/20000 train_loss:2.4865 train_time:82745ms step_avg:59.10ms +step:1400 shared0_alpha:mean=0.443,std=0.052 shared1_alpha:mean=0.531,std=0.048 shared2_alpha:mean=0.567,std=0.045 shared3_alpha:mean=0.620,std=0.050 eff_mlp_scale:[v0:65.7036 v1:51.0674 v2:51.0106 v3:55.1941 v4:50.6375 v5:64.3369 v6:58.9456 v7:57.0843 v8:52.4589 v9:56.6969 v10:47.7988 v11:46.6881 v12:44.6265 v13:124.5906] eff_attn_scale:[v0:1.0368 v1:1.3702 v2:1.3014 v3:1.4703 v4:1.7817 v5:1.3202 v6:1.3401 v7:1.3669 v8:1.8930 v9:1.2925 v10:1.1518 v11:1.0510 v12:1.3363 v13:3.0387] +step:1400/20000 val_loss:2.3352 val_bpb:1.3830 train_time:82776ms step_avg:59.13ms +step:1600/20000 train_loss:2.1583 train_time:94592ms 
step_avg:59.12ms +step:1600 shared0_alpha:mean=0.440,std=0.054 shared1_alpha:mean=0.534,std=0.049 shared2_alpha:mean=0.571,std=0.046 shared3_alpha:mean=0.626,std=0.050 eff_mlp_scale:[v0:71.6690 v1:53.9213 v2:53.3864 v3:57.5377 v4:52.9029 v5:66.5846 v6:61.4519 v7:59.4556 v8:54.7527 v9:58.4147 v10:49.5456 v11:47.9481 v12:46.0588 v13:131.3244] eff_attn_scale:[v0:0.8817 v1:1.2447 v2:1.1989 v3:1.3474 v4:1.6413 v5:1.1715 v6:1.2308 v7:1.2737 v8:1.7483 v9:1.1349 v10:1.0504 v11:0.9526 v12:1.2250 v13:2.7730] +step:1600/20000 val_loss:2.3234 val_bpb:1.3760 train_time:94624ms step_avg:59.14ms +step:1800/20000 train_loss:2.2560 train_time:106476ms step_avg:59.15ms +step:1800 shared0_alpha:mean=0.437,std=0.055 shared1_alpha:mean=0.536,std=0.049 shared2_alpha:mean=0.575,std=0.047 shared3_alpha:mean=0.631,std=0.051 eff_mlp_scale:[v0:77.0559 v1:56.5370 v2:55.8204 v3:59.9762 v4:55.0234 v5:69.4242 v6:64.0178 v7:61.9235 v8:57.2847 v9:60.6942 v10:51.1362 v11:49.4609 v12:48.0513 v13:137.3129] eff_attn_scale:[v0:0.8088 v1:1.1540 v2:1.0999 v3:1.2589 v4:1.5357 v5:1.0727 v6:1.1151 v7:1.1628 v8:1.6388 v9:1.0523 v10:0.9529 v11:0.8646 v12:1.1288 v13:2.5785] +step:1800/20000 val_loss:2.3070 val_bpb:1.3663 train_time:106508ms step_avg:59.17ms +step:2000/20000 train_loss:2.3081 train_time:118327ms step_avg:59.16ms +step:2000 shared0_alpha:mean=0.434,std=0.055 shared1_alpha:mean=0.538,std=0.050 shared2_alpha:mean=0.579,std=0.048 shared3_alpha:mean=0.636,std=0.052 eff_mlp_scale:[v0:82.7543 v1:59.0391 v2:58.7532 v3:62.3953 v4:57.0309 v5:71.6903 v6:66.2958 v7:63.9749 v8:59.7102 v9:62.8344 v10:52.7985 v11:50.5481 v12:49.3758 v13:142.4716] eff_attn_scale:[v0:0.7283 v1:1.0662 v2:1.0360 v3:1.1851 v4:1.4729 v5:0.9834 v6:1.0360 v7:1.1012 v8:1.5511 v9:0.9640 v10:0.8796 v11:0.8049 v12:1.0601 v13:2.4340] +step:2000/20000 val_loss:2.2943 val_bpb:1.3588 train_time:118359ms step_avg:59.18ms +step:2200/20000 train_loss:2.1339 train_time:130184ms step_avg:59.17ms +step:2200 shared0_alpha:mean=0.431,std=0.056 
shared1_alpha:mean=0.540,std=0.051 shared2_alpha:mean=0.582,std=0.049 shared3_alpha:mean=0.640,std=0.053 eff_mlp_scale:[v0:87.8139 v1:61.0818 v2:60.7661 v3:64.6332 v4:59.1065 v5:73.8962 v6:68.4121 v7:66.2390 v8:61.8285 v9:64.4990 v10:54.3273 v11:52.1883 v12:51.3293 v13:148.2095] eff_attn_scale:[v0:0.6803 v1:1.0172 v2:0.9873 v3:1.1284 v4:1.4244 v5:0.9179 v6:0.9921 v7:1.0319 v8:1.5017 v9:0.8990 v10:0.8380 v11:0.7667 v12:1.0269 v13:2.3139] +step:2200/20000 val_loss:2.2862 val_bpb:1.3540 train_time:130215ms step_avg:59.19ms +step:2400/20000 train_loss:2.2593 train_time:142034ms step_avg:59.18ms +step:2400 shared0_alpha:mean=0.428,std=0.057 shared1_alpha:mean=0.541,std=0.052 shared2_alpha:mean=0.585,std=0.050 shared3_alpha:mean=0.644,std=0.054 eff_mlp_scale:[v0:93.0263 v1:63.7902 v2:62.8558 v3:66.7213 v4:61.5684 v5:76.3747 v6:70.6107 v7:68.3486 v8:64.3311 v9:66.8279 v10:55.9172 v11:53.7025 v12:52.8857 v13:154.3660] eff_attn_scale:[v0:0.6218 v1:0.9561 v2:0.9382 v3:1.0693 v4:1.3684 v5:0.8596 v6:0.9243 v7:0.9802 v8:1.4387 v9:0.8458 v10:0.7764 v11:0.7222 v12:0.9898 v13:2.1780] +step:2400/20000 val_loss:2.2751 val_bpb:1.3475 train_time:142066ms step_avg:59.19ms +step:2600/20000 train_loss:2.4727 train_time:153881ms step_avg:59.19ms +step:2600 shared0_alpha:mean=0.425,std=0.057 shared1_alpha:mean=0.543,std=0.052 shared2_alpha:mean=0.587,std=0.051 shared3_alpha:mean=0.647,std=0.054 eff_mlp_scale:[v0:97.4525 v1:65.8375 v2:64.8462 v3:69.2610 v4:63.6741 v5:78.1271 v6:72.6938 v7:70.4978 v8:66.4774 v9:68.4710 v10:57.4116 v11:54.8316 v12:54.4634 v13:158.4925] eff_attn_scale:[v0:0.5899 v1:0.9353 v2:0.9162 v3:1.0517 v4:1.3262 v5:0.8269 v6:0.8934 v7:0.9590 v8:1.3737 v9:0.8088 v10:0.7475 v11:0.6996 v12:0.9352 v13:2.1238] +step:2600/20000 val_loss:2.2874 val_bpb:1.3547 train_time:153913ms step_avg:59.20ms +step:2800/20000 train_loss:2.2975 train_time:165743ms step_avg:59.19ms +step:2800 shared0_alpha:mean=0.423,std=0.058 shared1_alpha:mean=0.545,std=0.053 
shared2_alpha:mean=0.590,std=0.052 shared3_alpha:mean=0.649,std=0.055 eff_mlp_scale:[v0:102.8439 v1:68.1906 v2:67.2278 v3:71.7107 v4:65.6281 v5:79.7033 v6:74.3264 v7:72.5445 v8:68.4639 v9:69.9617 v10:58.8765 v11:56.2846 v12:56.3106 v13:163.4298] eff_attn_scale:[v0:0.5667 v1:0.9016 v2:0.8768 v3:1.0077 v4:1.2826 v5:0.7999 v6:0.8543 v7:0.9177 v8:1.3507 v9:0.7734 v10:0.7239 v11:0.6613 v12:0.9214 v13:2.0292] +step:2800/20000 val_loss:2.2621 val_bpb:1.3398 train_time:165774ms step_avg:59.21ms +step:3000/20000 train_loss:2.2850 train_time:177586ms step_avg:59.20ms +step:3000 shared0_alpha:mean=0.420,std=0.058 shared1_alpha:mean=0.545,std=0.054 shared2_alpha:mean=0.592,std=0.052 shared3_alpha:mean=0.652,std=0.056 eff_mlp_scale:[v0:107.1784 v1:70.3300 v2:69.3097 v3:73.8286 v4:67.6902 v5:81.9770 v6:76.0716 v7:74.6724 v8:70.1517 v9:71.6739 v10:60.0120 v11:57.7973 v12:57.8444 v13:168.6794] eff_attn_scale:[v0:0.5273 v1:0.8750 v2:0.8527 v3:0.9793 v4:1.2626 v5:0.7700 v6:0.8307 v7:0.8778 v8:1.3090 v9:0.7394 v10:0.6989 v11:0.6352 v12:0.8916 v13:1.9865] +step:3000/20000 val_loss:2.2550 val_bpb:1.3355 train_time:177618ms step_avg:59.21ms +step:3200/20000 train_loss:2.2509 train_time:189434ms step_avg:59.20ms +step:3200 shared0_alpha:mean=0.417,std=0.059 shared1_alpha:mean=0.546,std=0.054 shared2_alpha:mean=0.595,std=0.053 shared3_alpha:mean=0.654,std=0.056 eff_mlp_scale:[v0:111.8766 v1:72.4473 v2:71.3405 v3:75.8253 v4:69.6745 v5:84.2200 v6:78.1756 v7:76.6773 v8:71.7481 v9:72.9001 v10:61.5152 v11:58.7859 v12:59.3062 v13:172.0224] eff_attn_scale:[v0:0.5131 v1:0.8464 v2:0.8337 v3:0.9520 v4:1.2488 v5:0.7401 v6:0.8033 v7:0.8647 v8:1.2951 v9:0.7103 v10:0.6774 v11:0.6201 v12:0.8891 v13:1.9186] +step:3200/20000 val_loss:2.2506 val_bpb:1.3329 train_time:189467ms step_avg:59.21ms +step:3400/20000 train_loss:2.2156 train_time:201279ms step_avg:59.20ms +step:3400 shared0_alpha:mean=0.415,std=0.059 shared1_alpha:mean=0.547,std=0.054 shared2_alpha:mean=0.597,std=0.054 
shared3_alpha:mean=0.657,std=0.057 eff_mlp_scale:[v0:116.7305 v1:74.6774 v2:73.4077 v3:77.6875 v4:71.9578 v5:86.1310 v6:79.8848 v7:78.5507 v8:74.0618 v9:74.6774 v10:63.0442 v11:60.4236 v12:61.0169 v13:176.9400] eff_attn_scale:[v0:0.4849 v1:0.8381 v2:0.8049 v3:0.9454 v4:1.2184 v5:0.7281 v6:0.7836 v7:0.8452 v8:1.2641 v9:0.6942 v10:0.6558 v11:0.6143 v12:0.8631 v13:1.8646] +step:3400/20000 val_loss:2.2473 val_bpb:1.3310 train_time:201311ms step_avg:59.21ms +step:3600/20000 train_loss:2.1850 train_time:213120ms step_avg:59.20ms +step:3600 shared0_alpha:mean=0.413,std=0.060 shared1_alpha:mean=0.548,std=0.055 shared2_alpha:mean=0.599,std=0.054 shared3_alpha:mean=0.659,std=0.057 eff_mlp_scale:[v0:121.4435 v1:77.4289 v2:75.4982 v3:80.2307 v4:73.4860 v5:88.0927 v6:82.0443 v7:80.2307 v8:76.0347 v9:76.5016 v10:64.1516 v11:61.4811 v12:62.4419 v13:180.8942] eff_attn_scale:[v0:0.4713 v1:0.8233 v2:0.7938 v3:0.9230 v4:1.1999 v5:0.7141 v6:0.7726 v7:0.8281 v8:1.2351 v9:0.6847 v10:0.6410 v11:0.5995 v12:0.8535 v13:1.8283] +step:3600/20000 val_loss:2.2399 val_bpb:1.3266 train_time:213152ms step_avg:59.21ms +step:3800/20000 train_loss:2.2832 train_time:224971ms step_avg:59.20ms +step:3800 shared0_alpha:mean=0.411,std=0.061 shared1_alpha:mean=0.549,std=0.055 shared2_alpha:mean=0.600,std=0.055 shared3_alpha:mean=0.661,std=0.057 eff_mlp_scale:[v0:126.1614 v1:79.1049 v2:77.7751 v3:82.4223 v4:75.6845 v5:89.8707 v6:83.9617 v7:81.9816 v8:77.8346 v9:77.7007 v10:65.8437 v11:63.0289 v12:64.0738 v13:184.9535] eff_attn_scale:[v0:0.4496 v1:0.8116 v2:0.7936 v3:0.8995 v4:1.2013 v5:0.7034 v6:0.7516 v7:0.8181 v8:1.2162 v9:0.6660 v10:0.6257 v11:0.5868 v12:0.8374 v13:1.7982] +step:3800/20000 val_loss:2.2354 val_bpb:1.3239 train_time:225002ms step_avg:59.21ms +step:4000/20000 train_loss:2.2207 train_time:236822ms step_avg:59.21ms +step:4000 shared0_alpha:mean=0.409,std=0.061 shared1_alpha:mean=0.550,std=0.055 shared2_alpha:mean=0.602,std=0.056 shared3_alpha:mean=0.663,std=0.057 eff_mlp_scale:[v0:130.4283 
v1:81.3678 v2:79.3461 v3:84.2017 v4:77.4413 v5:91.7753 v6:85.5869 v7:83.7562 v8:79.6166 v9:79.4755 v10:66.8647 v11:64.1537 v12:65.6946 v13:189.9913] eff_attn_scale:[v0:0.4358 v1:0.7930 v2:0.7767 v3:0.8925 v4:1.1616 v5:0.6856 v6:0.7350 v7:0.7910 v8:1.2061 v9:0.6526 v10:0.6139 v11:0.5795 v12:0.8205 v13:1.7570] +step:4000/20000 val_loss:2.2301 val_bpb:1.3208 train_time:236854ms step_avg:59.21ms +step:4200/20000 train_loss:2.2356 train_time:248914ms step_avg:59.27ms +step:4200 shared0_alpha:mean=0.407,std=0.061 shared1_alpha:mean=0.551,std=0.056 shared2_alpha:mean=0.604,std=0.056 shared3_alpha:mean=0.665,std=0.058 eff_mlp_scale:[v0:134.3793 v1:83.5345 v2:81.4617 v3:86.8770 v4:79.6370 v5:93.5586 v6:87.3126 v7:85.5265 v8:81.3969 v9:80.6705 v10:67.9598 v11:65.7204 v12:67.3175 v13:193.9883] eff_attn_scale:[v0:0.4229 v1:0.7838 v2:0.7523 v3:0.8786 v4:1.1727 v5:0.6701 v6:0.7157 v7:0.7908 v8:1.1975 v9:0.6295 v10:0.5896 v11:0.5648 v12:0.8115 v13:1.7432] +step:4200/20000 val_loss:2.2262 val_bpb:1.3185 train_time:248945ms step_avg:59.27ms +step:4400/20000 train_loss:2.1765 train_time:260750ms step_avg:59.26ms +step:4400 shared0_alpha:mean=0.405,std=0.061 shared1_alpha:mean=0.552,std=0.056 shared2_alpha:mean=0.606,std=0.057 shared3_alpha:mean=0.667,std=0.058 eff_mlp_scale:[v0:140.1213 v1:85.2745 v2:83.1819 v3:88.7375 v4:81.4193 v5:95.3918 v6:89.0910 v7:87.3723 v8:83.1990 v9:82.3838 v10:69.5455 v11:66.8944 v12:68.9617 v13:197.1806] eff_attn_scale:[v0:0.4059 v1:0.7894 v2:0.7553 v3:0.8802 v4:1.1561 v5:0.6626 v6:0.7063 v7:0.7801 v8:1.1757 v9:0.6339 v10:0.5838 v11:0.5590 v12:0.8117 v13:1.7199] +step:4400/20000 val_loss:2.2270 val_bpb:1.3190 train_time:260781ms step_avg:59.27ms +step:4600/20000 train_loss:2.0368 train_time:272605ms step_avg:59.26ms +step:4600 shared0_alpha:mean=0.403,std=0.062 shared1_alpha:mean=0.552,std=0.056 shared2_alpha:mean=0.607,std=0.057 shared3_alpha:mean=0.668,std=0.059 eff_mlp_scale:[v0:145.3811 v1:87.5979 v2:85.2514 v3:90.8275 v4:83.6477 v5:96.8443 
v6:90.7515 v7:88.9926 v8:85.4466 v9:83.7046 v10:70.5845 v11:67.8912 v12:70.6059 v13:202.0904] eff_attn_scale:[v0:0.3928 v1:0.7792 v2:0.7549 v3:0.8647 v4:1.1528 v5:0.6507 v6:0.7025 v7:0.7658 v8:1.1725 v9:0.6105 v10:0.5814 v11:0.5476 v12:0.8094 v13:1.6740] +step:4600/20000 val_loss:2.2220 val_bpb:1.3160 train_time:272637ms step_avg:59.27ms +step:4800/20000 train_loss:2.3243 train_time:284445ms step_avg:59.26ms +step:4800 shared0_alpha:mean=0.401,std=0.062 shared1_alpha:mean=0.553,std=0.056 shared2_alpha:mean=0.609,std=0.058 shared3_alpha:mean=0.670,std=0.059 eff_mlp_scale:[v0:149.6388 v1:89.8112 v2:86.9723 v3:93.2148 v4:85.5791 v5:98.6451 v6:92.5237 v7:90.8961 v8:86.9447 v9:85.3943 v10:72.1685 v11:69.5633 v12:72.3780 v13:205.0534] eff_attn_scale:[v0:0.3859 v1:0.7676 v2:0.7423 v3:0.8653 v4:1.1359 v5:0.6436 v6:0.6898 v7:0.7623 v8:1.1603 v9:0.6076 v10:0.5728 v11:0.5398 v12:0.7946 v13:1.6561] +step:4800/20000 val_loss:2.2187 val_bpb:1.3140 train_time:284477ms step_avg:59.27ms +step:5000/20000 train_loss:2.0935 train_time:296302ms step_avg:59.26ms +step:5000 shared0_alpha:mean=0.399,std=0.062 shared1_alpha:mean=0.553,std=0.057 shared2_alpha:mean=0.610,std=0.058 shared3_alpha:mean=0.672,std=0.059 eff_mlp_scale:[v0:153.8149 v1:91.5847 v2:89.0435 v3:95.4824 v4:87.7745 v5:100.4957 v6:93.7055 v7:92.2061 v8:88.6936 v9:86.6342 v10:73.1929 v11:70.6757 v12:73.9879 v13:208.6951] eff_attn_scale:[v0:0.3806 v1:0.7597 v2:0.7427 v3:0.8650 v4:1.1437 v5:0.6259 v6:0.6865 v7:0.7615 v8:1.1633 v9:0.5904 v10:0.5621 v11:0.5422 v12:0.8001 v13:1.6242] +step:5000/20000 val_loss:2.2133 val_bpb:1.3108 train_time:296334ms step_avg:59.27ms +step:5200/20000 train_loss:2.2323 train_time:308141ms step_avg:59.26ms +step:5200 shared0_alpha:mean=0.397,std=0.062 shared1_alpha:mean=0.553,std=0.058 shared2_alpha:mean=0.612,std=0.059 shared3_alpha:mean=0.674,std=0.060 eff_mlp_scale:[v0:159.4933 v1:93.7957 v2:90.9036 v3:97.4001 v4:89.7327 v5:101.7783 v6:95.6136 v7:94.5632 v8:91.1275 v9:87.8087 v10:74.4185 
v11:71.8680 v12:75.7846 v13:212.6315] eff_attn_scale:[v0:0.3687 v1:0.7652 v2:0.7401 v3:0.8644 v4:1.1541 v5:0.6304 v6:0.6878 v7:0.7615 v8:1.1688 v9:0.5947 v10:0.5631 v11:0.5392 v12:0.8054 v13:1.6234] +step:5200/20000 val_loss:2.2144 val_bpb:1.3115 train_time:308173ms step_avg:59.26ms +step:5400/20000 train_loss:2.2443 train_time:320001ms step_avg:59.26ms +step:5400 shared0_alpha:mean=0.395,std=0.062 shared1_alpha:mean=0.554,std=0.058 shared2_alpha:mean=0.613,std=0.059 shared3_alpha:mean=0.675,std=0.060 eff_mlp_scale:[v0:163.9131 v1:95.6766 v2:93.0526 v3:99.6529 v4:91.3985 v5:103.7336 v6:97.3254 v7:95.8384 v8:92.8046 v9:89.6339 v10:75.4865 v11:73.4285 v12:76.8684 v13:215.8871] eff_attn_scale:[v0:0.3644 v1:0.7672 v2:0.7476 v3:0.8665 v4:1.1519 v5:0.6249 v6:0.6761 v7:0.7551 v8:1.1715 v9:0.5893 v10:0.5528 v11:0.5364 v12:0.7990 v13:1.6067] +step:5400/20000 val_loss:2.2088 val_bpb:1.3082 train_time:320032ms step_avg:59.27ms +step:5600/20000 train_loss:2.2457 train_time:331844ms step_avg:59.26ms +step:5600 shared0_alpha:mean=0.393,std=0.063 shared1_alpha:mean=0.555,std=0.058 shared2_alpha:mean=0.615,std=0.059 shared3_alpha:mean=0.677,std=0.061 eff_mlp_scale:[v0:168.9937 v1:97.9664 v2:94.8258 v3:101.4387 v4:93.6504 v5:105.5803 v6:99.1361 v7:98.0734 v8:94.5964 v9:90.8600 v10:76.6269 v11:74.5165 v12:78.5150 v13:219.7498] eff_attn_scale:[v0:0.3544 v1:0.7693 v2:0.7419 v3:0.8706 v4:1.1515 v5:0.6273 v6:0.6704 v7:0.7515 v8:1.1515 v9:0.5839 v10:0.5475 v11:0.5338 v12:0.7904 v13:1.6007] +step:5600/20000 val_loss:2.2100 val_bpb:1.3089 train_time:331876ms step_avg:59.26ms +step:5800/20000 train_loss:2.2100 train_time:343701ms step_avg:59.26ms +step:5800 shared0_alpha:mean=0.392,std=0.063 shared1_alpha:mean=0.556,std=0.058 shared2_alpha:mean=0.617,std=0.060 shared3_alpha:mean=0.678,std=0.061 eff_mlp_scale:[v0:173.9413 v1:100.1436 v2:97.1412 v3:103.6861 v4:95.6212 v5:106.7858 v6:101.0075 v7:99.3255 v8:96.5774 v9:91.9686 v10:78.2929 v11:75.5843 v12:79.8437 v13:222.7330] 
eff_attn_scale:[v0:0.3424 v1:0.7614 v2:0.7388 v3:0.8595 v4:1.1536 v5:0.6169 v6:0.6673 v7:0.7339 v8:1.1439 v9:0.5779 v10:0.5481 v11:0.5271 v12:0.7901 v13:1.5714] +step:5800/20000 val_loss:2.2085 val_bpb:1.3080 train_time:343733ms step_avg:59.26ms +step:6000/20000 train_loss:2.2787 train_time:355541ms step_avg:59.26ms +step:6000 shared0_alpha:mean=0.390,std=0.063 shared1_alpha:mean=0.556,std=0.057 shared2_alpha:mean=0.618,std=0.060 shared3_alpha:mean=0.679,std=0.061 eff_mlp_scale:[v0:179.7115 v1:102.5137 v2:98.8222 v3:106.1559 v4:97.2881 v5:109.2105 v6:102.7167 v7:101.2639 v8:98.2513 v9:93.2411 v10:79.3499 v11:76.8040 v12:81.3945 v13:226.4459] eff_attn_scale:[v0:0.3415 v1:0.7450 v2:0.7313 v3:0.8596 v4:1.1518 v5:0.6067 v6:0.6684 v7:0.7385 v8:1.1567 v9:0.5683 v10:0.5465 v11:0.5246 v12:0.7922 v13:1.5915] +step:6000/20000 val_loss:2.2042 val_bpb:1.3054 train_time:355572ms step_avg:59.26ms +step:6200/20000 train_loss:2.1520 train_time:367387ms step_avg:59.26ms +step:6200 shared0_alpha:mean=0.388,std=0.063 shared1_alpha:mean=0.557,std=0.058 shared2_alpha:mean=0.620,std=0.060 shared3_alpha:mean=0.681,std=0.062 eff_mlp_scale:[v0:184.1409 v1:104.4329 v2:101.0519 v3:108.1300 v4:99.9027 v5:110.6677 v6:103.9952 v7:102.6988 v8:100.3901 v9:95.0807 v10:80.4491 v11:78.0116 v12:83.3335 v13:228.4138] eff_attn_scale:[v0:0.3336 v1:0.7624 v2:0.7430 v3:0.8663 v4:1.1573 v5:0.6139 v6:0.6565 v7:0.7402 v8:1.1622 v9:0.5630 v10:0.5347 v11:0.5287 v12:0.7878 v13:1.5432] +step:6200/20000 val_loss:2.2034 val_bpb:1.3050 train_time:367419ms step_avg:59.26ms +step:6400/20000 train_loss:2.2230 train_time:379237ms step_avg:59.26ms +step:6400 shared0_alpha:mean=0.386,std=0.064 shared1_alpha:mean=0.558,std=0.058 shared2_alpha:mean=0.621,std=0.061 shared3_alpha:mean=0.682,std=0.062 eff_mlp_scale:[v0:189.8302 v1:106.8324 v2:102.8722 v3:110.1204 v4:101.9474 v5:112.5930 v6:105.8396 v7:104.6393 v8:102.4399 v9:96.3587 v10:81.6053 v11:79.2269 v12:85.2024 v13:232.3741] eff_attn_scale:[v0:0.3267 v1:0.7589 
v2:0.7307 v3:0.8651 v4:1.1648 v5:0.5993 v6:0.6447 v7:0.7392 v8:1.1502 v9:0.5604 v10:0.5353 v11:0.5199 v12:0.7944 v13:1.5767] +step:6400/20000 val_loss:2.2000 val_bpb:1.3029 train_time:379269ms step_avg:59.26ms +step:6600/20000 train_loss:2.1894 train_time:391086ms step_avg:59.26ms +step:6600 shared0_alpha:mean=0.384,std=0.064 shared1_alpha:mean=0.558,std=0.059 shared2_alpha:mean=0.622,std=0.061 shared3_alpha:mean=0.684,std=0.063 eff_mlp_scale:[v0:195.4746 v1:108.5459 v2:104.6949 v3:112.5847 v4:103.8120 v5:114.3420 v6:107.1876 v7:106.5534 v8:104.3087 v9:97.4805 v10:82.7588 v11:80.9202 v12:86.9239 v13:236.5316] eff_attn_scale:[v0:0.3259 v1:0.7674 v2:0.7435 v3:0.8732 v4:1.1760 v5:0.6077 v6:0.6530 v7:0.7420 v8:1.1562 v9:0.5648 v10:0.5311 v11:0.5247 v12:0.8054 v13:1.5359] +step:6600/20000 val_loss:2.1963 val_bpb:1.3007 train_time:391118ms step_avg:59.26ms +step:6800/20000 train_loss:2.2551 train_time:402928ms step_avg:59.25ms +step:6800 shared0_alpha:mean=0.382,std=0.063 shared1_alpha:mean=0.558,std=0.059 shared2_alpha:mean=0.624,std=0.062 shared3_alpha:mean=0.685,std=0.063 eff_mlp_scale:[v0:199.3570 v1:110.9249 v2:106.4690 v3:114.4221 v4:106.2555 v5:115.7016 v6:108.4778 v7:108.3466 v8:105.7543 v9:98.7179 v10:83.8694 v11:82.0194 v12:88.2121 v13:238.4881] eff_attn_scale:[v0:0.3224 v1:0.7681 v2:0.7286 v3:0.8746 v4:1.1817 v5:0.6052 v6:0.6468 v7:0.7323 v8:1.1669 v9:0.5470 v10:0.5182 v11:0.5187 v12:0.7977 v13:1.5251] +step:6800/20000 val_loss:2.1956 val_bpb:1.3004 train_time:402961ms step_avg:59.26ms +step:7000/20000 train_loss:2.2832 train_time:414774ms step_avg:59.25ms +step:7000 shared0_alpha:mean=0.380,std=0.064 shared1_alpha:mean=0.559,std=0.059 shared2_alpha:mean=0.625,std=0.062 shared3_alpha:mean=0.686,std=0.063 eff_mlp_scale:[v0:205.1638 v1:112.9238 v2:108.7964 v3:116.8961 v4:108.2008 v5:117.2053 v6:110.3145 v7:109.7496 v8:108.2008 v9:100.0794 v10:85.0130 v11:83.2055 v12:89.9988 v13:242.5826] eff_attn_scale:[v0:0.3126 v1:0.7662 v2:0.7415 v3:0.8720 v4:1.1966 v5:0.6037 
v6:0.6586 v7:0.7369 v8:1.1671 v9:0.5495 v10:0.5245 v11:0.5240 v12:0.8076 v13:1.5226] +step:7000/20000 val_loss:2.1936 val_bpb:1.2992 train_time:414805ms step_avg:59.26ms +step:7200/20000 train_loss:2.2635 train_time:426617ms step_avg:59.25ms +step:7200 shared0_alpha:mean=0.379,std=0.063 shared1_alpha:mean=0.560,std=0.059 shared2_alpha:mean=0.627,std=0.062 shared3_alpha:mean=0.687,std=0.063 eff_mlp_scale:[v0:209.4189 v1:114.7228 v2:110.1191 v3:118.8122 v4:110.1611 v5:119.0317 v6:112.1584 v7:111.6115 v8:109.6511 v9:101.2577 v10:86.1580 v11:84.3516 v12:91.8009 v13:246.3847] eff_attn_scale:[v0:0.3114 v1:0.7758 v2:0.7512 v3:0.8862 v4:1.1983 v5:0.6082 v6:0.6597 v7:0.7536 v8:1.1687 v9:0.5536 v10:0.5326 v11:0.5342 v12:0.8087 v13:1.5231] +step:7200/20000 val_loss:2.1947 val_bpb:1.2998 train_time:426648ms step_avg:59.26ms +step:7400/20000 train_loss:2.1789 train_time:438467ms step_avg:59.25ms +step:7400 shared0_alpha:mean=0.376,std=0.063 shared1_alpha:mean=0.560,std=0.060 shared2_alpha:mean=0.628,std=0.063 shared3_alpha:mean=0.688,std=0.064 eff_mlp_scale:[v0:214.8811 v1:116.5179 v2:112.4040 v3:120.8861 v4:112.1007 v5:120.8534 v6:113.4305 v7:113.6226 v8:111.5865 v9:102.4273 v10:87.2542 v11:85.6061 v12:93.0745 v13:250.0835] eff_attn_scale:[v0:0.3001 v1:0.7651 v2:0.7407 v3:0.8777 v4:1.2095 v5:0.6036 v6:0.6472 v7:0.7436 v8:1.1749 v9:0.5459 v10:0.5146 v11:0.5140 v12:0.7997 v13:1.5224] +step:7400/20000 val_loss:2.1898 val_bpb:1.2969 train_time:438499ms step_avg:59.26ms +step:7600/20000 train_loss:2.0665 train_time:450304ms step_avg:59.25ms +step:7600 shared0_alpha:mean=0.375,std=0.064 shared1_alpha:mean=0.561,std=0.060 shared2_alpha:mean=0.630,std=0.063 shared3_alpha:mean=0.690,std=0.064 eff_mlp_scale:[v0:219.5673 v1:119.0424 v2:114.3795 v3:123.2943 v4:114.5205 v5:122.3188 v6:114.8971 v7:114.9354 v8:113.4841 v9:103.7525 v10:88.5018 v11:86.7240 v12:94.8292 v13:252.0920] eff_attn_scale:[v0:0.3055 v1:0.7820 v2:0.7397 v3:0.8893 v4:1.2225 v5:0.5991 v6:0.6463 v7:0.7418 v8:1.1827 
v9:0.5446 v10:0.5178 v11:0.5184 v12:0.8100 v13:1.5209] +step:7600/20000 val_loss:2.1886 val_bpb:1.2962 train_time:450336ms step_avg:59.25ms +step:7800/20000 train_loss:2.2108 train_time:462150ms step_avg:59.25ms +step:7800 shared0_alpha:mean=0.373,std=0.063 shared1_alpha:mean=0.562,std=0.060 shared2_alpha:mean=0.631,std=0.063 shared3_alpha:mean=0.691,std=0.065 eff_mlp_scale:[v0:225.3324 v1:120.9175 v2:116.2944 v3:125.4547 v4:116.5135 v5:124.2153 v6:116.8159 v7:116.4937 v8:115.4685 v9:104.9784 v10:89.1764 v11:88.0292 v12:96.1367 v13:256.1435] eff_attn_scale:[v0:0.3010 v1:0.7857 v2:0.7362 v3:0.9016 v4:1.2284 v5:0.6009 v6:0.6388 v7:0.7418 v8:1.1887 v9:0.5392 v10:0.5181 v11:0.5287 v12:0.8107 v13:1.5223] +step:7800/20000 val_loss:2.1854 val_bpb:1.2943 train_time:462181ms step_avg:59.25ms +step:8000/20000 train_loss:2.1762 train_time:473998ms step_avg:59.25ms +step:8000 shared0_alpha:mean=0.371,std=0.063 shared1_alpha:mean=0.562,std=0.060 shared2_alpha:mean=0.632,std=0.063 shared3_alpha:mean=0.692,std=0.065 eff_mlp_scale:[v0:230.1253 v1:122.9614 v2:118.8503 v3:126.9168 v4:119.1279 v5:126.2847 v6:118.3244 v7:118.4203 v8:117.0194 v9:106.3450 v10:90.9783 v11:89.2135 v12:98.0433 v13:258.0185] eff_attn_scale:[v0:0.2941 v1:0.7905 v2:0.7514 v3:0.9012 v4:1.2311 v5:0.6075 v6:0.6457 v7:0.7407 v8:1.1814 v9:0.5490 v10:0.5126 v11:0.5246 v12:0.8191 v13:1.5115] +step:8000/20000 val_loss:2.1836 val_bpb:1.2933 train_time:474029ms step_avg:59.25ms +step:8200/20000 train_loss:2.2414 train_time:485841ms step_avg:59.25ms +step:8200 shared0_alpha:mean=0.370,std=0.063 shared1_alpha:mean=0.563,std=0.060 shared2_alpha:mean=0.634,std=0.064 shared3_alpha:mean=0.693,std=0.066 eff_mlp_scale:[v0:236.0981 v1:125.4997 v2:120.2003 v3:130.1774 v4:121.1870 v5:127.1730 v6:119.6708 v7:119.9989 v8:119.0609 v9:107.6508 v10:91.6064 v11:90.5349 v12:99.9261 v13:262.0005] eff_attn_scale:[v0:0.2834 v1:0.8007 v2:0.7520 v3:0.9107 v4:1.2411 v5:0.6064 v6:0.6423 v7:0.7493 v8:1.2061 v9:0.5442 v10:0.5209 v11:0.5340 
v12:0.8257 v13:1.5124] +step:8200/20000 val_loss:2.1824 val_bpb:1.2925 train_time:485873ms step_avg:59.25ms +step:8400/20000 train_loss:2.1960 train_time:497938ms step_avg:59.28ms +step:8400 shared0_alpha:mean=0.368,std=0.064 shared1_alpha:mean=0.564,std=0.060 shared2_alpha:mean=0.635,std=0.064 shared3_alpha:mean=0.695,std=0.066 eff_mlp_scale:[v0:240.2487 v1:127.4610 v2:122.7811 v3:132.1472 v4:123.2464 v5:129.1455 v6:121.7134 v7:121.8991 v8:121.1030 v9:108.9314 v10:92.8866 v11:91.6940 v12:101.2764 v13:263.9616] eff_attn_scale:[v0:0.2900 v1:0.8182 v2:0.7521 v3:0.9300 v4:1.2661 v5:0.6078 v6:0.6391 v7:0.7523 v8:1.2058 v9:0.5455 v10:0.5105 v11:0.5332 v12:0.8290 v13:1.5126] +step:8400/20000 val_loss:2.1815 val_bpb:1.2920 train_time:497969ms step_avg:59.28ms +step:8600/20000 train_loss:2.1955 train_time:509769ms step_avg:59.28ms +step:8600 shared0_alpha:mean=0.366,std=0.064 shared1_alpha:mean=0.564,std=0.061 shared2_alpha:mean=0.636,std=0.064 shared3_alpha:mean=0.696,std=0.066 eff_mlp_scale:[v0:246.5324 v1:129.4005 v2:124.7777 v3:134.7354 v4:125.2409 v5:130.5306 v6:123.1642 v7:123.8696 v8:123.0816 v9:110.1882 v10:94.1211 v11:92.9022 v12:103.1078 v13:267.7187] eff_attn_scale:[v0:0.2869 v1:0.8100 v2:0.7461 v3:0.9212 v4:1.2635 v5:0.6153 v6:0.6445 v7:0.7526 v8:1.2086 v9:0.5413 v10:0.5156 v11:0.5346 v12:0.8290 v13:1.5093] +step:8600/20000 val_loss:2.1791 val_bpb:1.2906 train_time:509804ms step_avg:59.28ms +step:8800/20000 train_loss:2.1664 train_time:521614ms step_avg:59.27ms +step:8800 shared0_alpha:mean=0.364,std=0.064 shared1_alpha:mean=0.564,std=0.061 shared2_alpha:mean=0.638,std=0.064 shared3_alpha:mean=0.697,std=0.066 eff_mlp_scale:[v0:251.0507 v1:131.9416 v2:126.6586 v3:136.1598 v4:127.8389 v5:132.5104 v6:124.4935 v7:125.2233 v8:124.5749 v9:111.4679 v10:95.2646 v11:94.0542 v12:104.4471 v13:271.2946] eff_attn_scale:[v0:0.2766 v1:0.8323 v2:0.7588 v3:0.9297 v4:1.2828 v5:0.6154 v6:0.6414 v7:0.7637 v8:1.2121 v9:0.5522 v10:0.5085 v11:0.5354 v12:0.8384 v13:1.5304] 
+step:8800/20000 val_loss:2.1778 val_bpb:1.2898 train_time:521645ms step_avg:59.28ms +step:9000/20000 train_loss:2.0805 train_time:533470ms step_avg:59.27ms +step:9000 shared0_alpha:mean=0.363,std=0.064 shared1_alpha:mean=0.565,std=0.061 shared2_alpha:mean=0.639,std=0.065 shared3_alpha:mean=0.698,std=0.066 eff_mlp_scale:[v0:257.4152 v1:133.7777 v2:128.6963 v3:138.8611 v4:130.0868 v5:134.3494 v6:126.5150 v7:127.2894 v8:126.7935 v9:112.6248 v10:96.5222 v11:95.3293 v12:106.4845 v13:273.4311] eff_attn_scale:[v0:0.2766 v1:0.8403 v2:0.7611 v3:0.9508 v4:1.3217 v5:0.6273 v6:0.6512 v7:0.7699 v8:1.2397 v9:0.5484 v10:0.5139 v11:0.5427 v12:0.8555 v13:1.5236] +step:9000/20000 val_loss:2.1757 val_bpb:1.2886 train_time:533502ms step_avg:59.28ms +step:9200/20000 train_loss:2.1355 train_time:545326ms step_avg:59.27ms +step:9200 shared0_alpha:mean=0.361,std=0.063 shared1_alpha:mean=0.565,std=0.061 shared2_alpha:mean=0.639,std=0.065 shared3_alpha:mean=0.698,std=0.066 eff_mlp_scale:[v0:262.0444 v1:135.2486 v2:130.5472 v3:140.4741 v4:132.1480 v5:135.8241 v6:127.8047 v7:128.8142 v8:128.2776 v9:113.9541 v10:97.0877 v11:96.6107 v12:107.8195 v13:277.5444] eff_attn_scale:[v0:0.2734 v1:0.8502 v2:0.7689 v3:0.9572 v4:1.3303 v5:0.6336 v6:0.6512 v7:0.7759 v8:1.2633 v9:0.5614 v10:0.5139 v11:0.5482 v12:0.8663 v13:1.5439] +step:9200/20000 val_loss:2.1655 val_bpb:1.2825 train_time:545357ms step_avg:59.28ms +step:9400/20000 train_loss:2.1757 train_time:557172ms step_avg:59.27ms +step:9400 shared0_alpha:mean=0.360,std=0.063 shared1_alpha:mean=0.565,std=0.061 shared2_alpha:mean=0.640,std=0.065 shared3_alpha:mean=0.699,std=0.066 eff_mlp_scale:[v0:263.2965 v1:136.5867 v2:131.9275 v3:142.5522 v4:134.5448 v5:137.1655 v6:129.1675 v7:130.2536 v8:130.0786 v9:114.5939 v10:97.7036 v11:97.2709 v12:109.4223 v13:279.7477] eff_attn_scale:[v0:0.2731 v1:0.8550 v2:0.7636 v3:0.9639 v4:1.3521 v5:0.6342 v6:0.6539 v7:0.7821 v8:1.2793 v9:0.5580 v10:0.5090 v11:0.5538 v12:0.8841 v13:1.5618] +step:9400/20000 val_loss:2.1559 
val_bpb:1.2768 train_time:557203ms step_avg:59.28ms +step:9600/20000 train_loss:2.1789 train_time:569015ms step_avg:59.27ms +step:9600 shared0_alpha:mean=0.359,std=0.063 shared1_alpha:mean=0.565,std=0.061 shared2_alpha:mean=0.640,std=0.065 shared3_alpha:mean=0.699,std=0.067 eff_mlp_scale:[v0:264.5626 v1:137.2842 v2:132.6420 v3:143.3359 v4:135.3772 v5:137.8659 v6:129.8671 v7:130.9697 v8:130.8833 v9:115.1791 v10:98.7878 v11:97.8057 v12:110.0993 v13:281.6810] eff_attn_scale:[v0:0.2760 v1:0.8545 v2:0.7619 v3:0.9659 v4:1.3683 v5:0.6338 v6:0.6598 v7:0.7880 v8:1.2795 v9:0.5536 v10:0.5106 v11:0.5507 v12:0.8930 v13:1.5710] +step:9600/20000 val_loss:2.1464 val_bpb:1.2712 train_time:569047ms step_avg:59.28ms +step:9800/20000 train_loss:2.0988 train_time:580854ms step_avg:59.27ms +step:9800 shared0_alpha:mean=0.359,std=0.062 shared1_alpha:mean=0.564,std=0.061 shared2_alpha:mean=0.639,std=0.065 shared3_alpha:mean=0.699,std=0.067 eff_mlp_scale:[v0:265.2037 v1:138.5365 v2:133.1247 v3:143.8686 v4:136.1328 v5:138.5365 v6:130.8967 v7:131.4564 v8:131.6138 v9:115.7394 v10:99.1473 v11:98.1692 v12:110.7138 v13:285.5064] eff_attn_scale:[v0:0.2792 v1:0.8503 v2:0.7623 v3:0.9771 v4:1.3634 v5:0.6438 v6:0.6556 v7:0.7894 v8:1.2953 v9:0.5588 v10:0.5095 v11:0.5547 v12:0.8915 v13:1.5897] +step:9800/20000 val_loss:2.1375 val_bpb:1.2660 train_time:580886ms step_avg:59.27ms +step:10000/20000 train_loss:2.1327 train_time:592699ms step_avg:59.27ms +step:10000 shared0_alpha:mean=0.359,std=0.062 shared1_alpha:mean=0.564,std=0.061 shared2_alpha:mean=0.639,std=0.065 shared3_alpha:mean=0.698,std=0.067 eff_mlp_scale:[v0:265.1537 v1:138.7435 v2:133.3693 v3:144.2558 v4:136.6628 v5:138.7435 v6:131.1372 v7:131.8102 v8:132.1263 v9:115.9123 v10:99.3294 v11:98.4333 v12:111.1449 v13:286.6695] eff_attn_scale:[v0:0.2733 v1:0.8454 v2:0.7637 v3:0.9686 v4:1.3718 v5:0.6391 v6:0.6563 v7:0.7894 v8:1.3032 v9:0.5582 v10:0.5131 v11:0.5504 v12:0.8970 v13:1.5789] +step:10000/20000 val_loss:2.1286 val_bpb:1.2606 
train_time:592731ms step_avg:59.27ms +step:10122/20000 val_loss:2.1246 val_bpb:1.2583 train_time:600039ms step_avg:59.28ms +stopping_early: wallclock_cap train_time:600039ms step:10122/20000 +peak memory allocated: 13736 MiB reserved: 14068 MiB +Serialized model: 45178938 bytes +Code size: 57024 bytes +Total submission size: 45235962 bytes +Serialized model int8+zlib: 10717286 bytes (payload:11638976 raw_torch:11670071 payload_ratio:3.88x) +Total submission size int8+zlib: 10774310 bytes +final_int8_zlib_roundtrip val_loss:2.1374 val_bpb:1.2659 eval_time:1873ms +final_int8_zlib_roundtrip_exact val_loss:2.13735865 val_bpb:1.26586418 diff --git a/records/track_non_record_16mb/2026-03-26_RecurrenceFix_3Loop_Birkhoff_OutputLN_TimestepScale/scripts/run_fullscale.sh b/records/track_non_record_16mb/2026-03-26_RecurrenceFix_3Loop_Birkhoff_OutputLN_TimestepScale/scripts/run_fullscale.sh new file mode 100755 index 0000000000..3f6e2a3587 --- /dev/null +++ b/records/track_non_record_16mb/2026-03-26_RecurrenceFix_3Loop_Birkhoff_OutputLN_TimestepScale/scripts/run_fullscale.sh @@ -0,0 +1,64 @@ +#!/bin/bash +set -uo pipefail + +SCRIPT="train_gpt.py" +NGPU=${NGPU:-8} +COMMON="SEED=1337 MAX_WALLCLOCK_SECONDS=600 VAL_LOSS_EVERY=200 TRAIN_LOG_EVERY=200" +DATA="DATA_PATH=${DATA_PATH:-./data/datasets/fineweb10B_sp1024} TOKENIZER_PATH=./data/tokenizers/fineweb_1024_bpe.model VOCAB_SIZE=1024" + +FAILS=0 +SUMMARY="" + +run_experiment() { + local name="$1"; shift + echo "" + echo "=== $name ===" + if "$@"; then + SUMMARY="${SUMMARY} PASS $name"$'\n' + else + SUMMARY="${SUMMARY} FAIL $name (exit $?)"$'\n' + FAILS=$((FAILS + 1)) + fi +} + +# --- G: 1+3×2+1, peri+birk, NO timestep — isolate gamma risk --- + +run_experiment "Run G: 1+3x2+1 peri+birk (no timestep)" \ + env $COMMON $DATA RUN_ID=s2_G NUM_LAYERS=8 NUM_PRELUDE=1 NUM_SHARED=3 NUM_LOOPS=2 NUM_CODA=1 \ + USE_PERI_NORM=1 USE_BIRKHOFF_MIX=1 USE_TIMESTEP_SCALE=0 \ + torchrun --standalone --nproc_per_node=$NGPU $SCRIPT + +# --- H: 1+4×2+1, 
peri+birk, NO timestep — depth increase (10 eff layers) --- + +run_experiment "Run H: 1+4x2+1 peri+birk (no timestep, 10 eff layers)" \ + env $COMMON $DATA RUN_ID=s2_H NUM_LAYERS=10 NUM_PRELUDE=1 NUM_SHARED=4 NUM_LOOPS=2 NUM_CODA=1 \ + USE_PERI_NORM=1 USE_BIRKHOFF_MIX=1 USE_TIMESTEP_SCALE=0 \ + torchrun --standalone --nproc_per_node=$NGPU $SCRIPT + +# --- I: 1+4×2+1, peri+birk+ts capped — does capped timestep help at depth? --- + +run_experiment "Run I: 1+4x2+1 peri+birk+ts capped (GAMMA_MAX=4.0)" \ + env $COMMON $DATA RUN_ID=s2_I NUM_LAYERS=10 NUM_PRELUDE=1 NUM_SHARED=4 NUM_LOOPS=2 NUM_CODA=1 \ + USE_PERI_NORM=1 USE_BIRKHOFF_MIX=1 USE_TIMESTEP_SCALE=1 TIMESTEP_GAMMA_MAX=4.0 \ + torchrun --standalone --nproc_per_node=$NGPU $SCRIPT + +# --- J: 1+4×3+1, peri+birk, NO timestep — ambitious 14 eff layers, 3 loops --- + +run_experiment "Run J: 1+4x3+1 peri+birk (no timestep, 14 eff layers, 3 loops)" \ + env $COMMON $DATA RUN_ID=s2_J NUM_LAYERS=14 NUM_PRELUDE=1 NUM_SHARED=4 NUM_LOOPS=3 NUM_CODA=1 \ + USE_PERI_NORM=1 USE_BIRKHOFF_MIX=1 USE_TIMESTEP_SCALE=0 \ + torchrun --standalone --nproc_per_node=$NGPU $SCRIPT + +# --- K: 1+4×3+1, peri+birk+ts capped — 3 loops + timestep scaling (best combo) --- + +run_experiment "Run K: 1+4x3+1 peri+birk+ts capped (14 eff layers, 3 loops)" \ + env $COMMON $DATA RUN_ID=s2_K NUM_LAYERS=14 NUM_PRELUDE=1 NUM_SHARED=4 NUM_LOOPS=3 NUM_CODA=1 \ + USE_PERI_NORM=1 USE_BIRKHOFF_MIX=1 USE_TIMESTEP_SCALE=1 TIMESTEP_GAMMA_MAX=4.0 \ + torchrun --standalone --nproc_per_node=$NGPU $SCRIPT + +echo "" +echo "===============================" +echo " FULL-SCALE SUMMARY" +echo "===============================" +echo "$SUMMARY" +echo "$FAILS run(s) failed." 
From ab1db1bbc22a084877c76ac032506e3060d331a0 Mon Sep 17 00:00:00 2001 From: Alexandr Azizyan Date: Thu, 26 Mar 2026 17:58:37 +0400 Subject: [PATCH 04/10] docs: add research notes with theory and citations --- .../research_notes.md | 110 ++++++++++++++++++ 1 file changed, 110 insertions(+) create mode 100644 records/track_non_record_16mb/2026-03-26_RecurrenceFix_3Loop_Birkhoff_OutputLN_TimestepScale/research_notes.md diff --git a/records/track_non_record_16mb/2026-03-26_RecurrenceFix_3Loop_Birkhoff_OutputLN_TimestepScale/research_notes.md b/records/track_non_record_16mb/2026-03-26_RecurrenceFix_3Loop_Birkhoff_OutputLN_TimestepScale/research_notes.md new file mode 100644 index 0000000000..f3cbc09101 --- /dev/null +++ b/records/track_non_record_16mb/2026-03-26_RecurrenceFix_3Loop_Birkhoff_OutputLN_TimestepScale/research_notes.md @@ -0,0 +1,110 @@ +# Research Notes: Theoretical Foundations for Depth Recurrence Stabilization + +This document provides citations and theoretical grounding for the three techniques developed to stabilize depth recurrence in parameter-shared transformers within the 16MB parameter-golf competition. + +## 1. The Three Failure Modes of Depth Recurrence + +Depth recurrence (weight-shared looping) has been attempted 15+ times in this competition with no SOTA result. Three failure modes explain why. + +### 1a. Quantization Error Amplification + +When the same weight matrix $W$ is applied $k$ times in a forward pass, post-training quantization error $\epsilon$ from int8 rounding compounds across iterations. PR #363 measured ~900× amplification over 3 cycles. PR #579 confirmed empirically: "2 loops survive, 3+ catastrophic (+4.3 BPB)." PR #623 showed AWQ (activation-aware quantization) closes 63% of the gap but cannot eliminate compounding. + +GPTQ quantizes layer-by-layer, compensating downstream weights for each layer's rounding error (Frantar et al., 2023, §3). With shared weights, this compensation is impossible. 
The same quantized matrix must serve all iterations. Errors from early iterations propagate uncompensated through later ones. + +> Frantar, E., Ashkboos, S., Hoefler, T. & Alistarh, D. (2023). "GPTQ: Accurate Post-Training Quantization for Generative Pre-trained Transformers." ICLR 2023. [arXiv:2210.17323](https://arxiv.org/abs/2210.17323) + +### 1b. Per-Iteration Identity Collapse + +Shared weights produce identical transformations for identical inputs. Without per-iteration conditioning, all loop iterations compute the same function. This collapses depth recurrence to a single effective pass. PR #319 showed learned gating parameters collapse to identity. PR #484 found attention layers need per-iteration identity more than MLP. + +Dehghani et al. (2019, §2.1) addressed this with sinusoidal timestep embeddings added at each recurrence step. Xu & Sato (2025) formalized the limitation: without timestep encoding, looped transformers have strict approximation rate bounds (Lemma 4.1). + +> Dehghani, M., Gouws, S., Vinyals, O., Uszkoreit, J. & Kaiser, L. (2019). "Universal Transformers." ICLR 2019. [arXiv:1807.03819](https://arxiv.org/abs/1807.03819) + +### 1c. Residual Magnitude Erasure + +Standard pre-norm applies RMSNorm before every sub-layer, projecting inputs to unit magnitude. With shared weights, all loop iterations receive normalized (magnitude-erased) inputs. The network cannot distinguish iteration 1 from iteration 3. This makes identity collapse worse — pre-norm removes the magnitude signal that shared weights would need to behave differently per iteration. + +Xiong et al. (2020, Theorem 1) showed pre-norm yields well-behaved gradients $O(d\sqrt{\ln d / L})$ but did not analyze the cost of erasing magnitude for weight-shared architectures. Run C' confirmed: with Birkhoff mixing but standard pre-norm, all mixing alphas collapsed to ~0.48 (uniform) and MLP scale dropped to 0.2–0.3. + +> Xiong, R. et al. (2020). 
"On Layer Normalization in the Transformer Architecture." ICML 2020. [arXiv:2002.04745](https://arxiv.org/abs/2002.04745) + +## 2. Technique 1: Birkhoff-Constrained Residual Mixing + +**Problem.** Unconstrained residual mixing (learned 2-vector weighting of residual and skip streams) has unbounded spectral norm. Xie et al. (2025, §3.1) demonstrated this at scale: unconstrained Hyper-Connections (Zhu et al., 2025, §2) produce signal gain exceeding 3000× in a 27B-parameter model. With looping, this gain compounds exponentially across iterations. + +**Solution.** Constrain mixing to the Birkhoff polytope $B_n$ — the set of $n \times n$ doubly stochastic matrices. By the Birkhoff-von Neumann theorem (Birkhoff, 1946), every such matrix is a convex combination of permutation matrices. For $n=2$ streams (residual $x$ and skip $x_0$), $B_2$ has exactly 2 vertices ($I$ and swap). Any doubly stochastic mixing is $\alpha \cdot x + (1-\alpha) \cdot x_0$ with $\alpha \in [0, 1]$. The implementation parameterizes $\alpha = \sigma(\text{logit})$ — an exact parameterization of $B_2$ with spectral norm $\leq 1$ by construction. + +This follows the HC → mHC → mHC-lite simplification chain. mHC (Xie et al., 2025, §4.1) uses iterative Sinkhorn-Knopp projection. mHC-lite (Yang & Gao, 2026, Theorem 3.1, §3.2) bypasses Sinkhorn via explicit convex combinations of permutations. The $n=2$ case is the minimal instance where a single sigmoid suffices. + +**Result.** Q-gap reduced from +0.0024 (unconstrained, Run B') to +0.0019 (Birkhoff + peri-norm, Run C) at screening. Exponential 3-loop blowup eliminated (Run F: Q-gap +0.0019 at 3 loops). + +> Zhu, D. et al. (2025). "Hyper-Connections." ICLR 2025. [arXiv:2409.19606](https://arxiv.org/abs/2409.19606) +> Xie, Z. et al. (2025). "mHC: Manifold-Constrained Hyper-Connections." [arXiv:2512.24880](https://arxiv.org/abs/2512.24880) +> Yang, Y. & Gao, J. (2026). "mHC-lite: You Don't Need 20 Sinkhorn-Knopp Iterations." 
[arXiv:2601.05732](https://arxiv.org/abs/2601.05732) +> Birkhoff, G. (1946). "Three observations on linear algebra." Univ. Nac. Tucumán Rev. A, 5:147–151. + +## 3. Technique 2: Output-LN (Peri-LN Variant for Recurrence) + +**Problem.** Pre-norm erases input magnitude before every sub-layer, making all loop iterations indistinguishable to shared weights (§1c). + +**Failed attempt.** Removing MLP input norm entirely — as in MoEUT (Csordas et al., 2024, §2.4), which places normalization only before sigmoid/softmax gates — caused NaN at step 2 with leaky_relu². MoEUT uses ReLU on the main data path, which has bounded gradient. The quadratic activation $\text{leaky\_relu}(x)^2$ has no implicit magnitude limiter. + +**Fix.** Move normalization from MLP input to MLP output: $x + \text{Norm}(\text{MLP}(x))$. The MLP receives raw (unnormalized) activations, so weight matrices see different magnitude distributions across loop iterations. This produces different outputs per iteration. RMSNorm on the output bounds the contribution to the residual stream. + +**Relation to Peri-LN.** Kim et al. (2025) define full Peri-LN as dual normalization: $x + \text{Norm}(\text{Module}(\text{Norm}(x)))$. The implementation omits the input norm, which the paper terms "Output-LN" (§3.2). Output-LN was chosen over full Peri-LN because the input norm would erase magnitude — exactly the signal shared weights need to differentiate loop iterations. Proposition 3.1 in Kim et al. analyzes the full Peri-LN scheme (dual norm). The gradient bounds it establishes do not directly cover Output-LN alone, though the output norm's damping factor $\|a\|$ is present in both variants. + +**Result.** Run C (peri+birkhoff) vs Run C' (birkhoff only): −0.007 BPB. Alpha learned meaningful gradient (0.37→0.70 across layers) vs collapsed (0.45→0.50 uniform). Peri-norm is the main factor. + +> Kim, J. et al. (2025). "Peri-LN: Revisiting Normalization Layer in the Transformer Architecture." ICML 2025. 
[arXiv:2502.02732](https://arxiv.org/abs/2502.02732) +> Csordas, R., Irie, K., Schmidhuber, J., Potts, C. & Manning, C. (2024). "MoEUT: Mixture-of-Experts Universal Transformers." NeurIPS 2024. [arXiv:2405.16039](https://arxiv.org/abs/2405.16039) + +## 4. Technique 3: Capped Timestep Scaling + +**Problem.** Without per-iteration conditioning, looped transformers have strict approximation limitations. Xu & Sato (2025, Lemma 4.1) prove that weight-tied feed-forward networks cannot drive approximation error to zero for varying target sequences. With timestep encoding, the expressivity gap closes completely (Theorem 4.2). + +**Solution.** Learned per-iteration scale vectors $\gamma_{\text{attn}}^{(t)}, \gamma_{\text{mlp}}^{(t)}$ for attention and MLP residual contributions, clamped to $[-M, +M]$. This is a simplified FiLM conditioning (Perez et al., 2018) — scale-only, no shift — applied per loop iteration. Parameter cost: $2 \times \text{eff\_layers} \times 512 \approx 8\text{KB}$. + +**Choice of cap $M = 4.0$.** This is an empirical choice, not theoretically derived. Three reasons: (1) Gammas are multiplicative modifiers on residual contributions. Uncapped, they could reintroduce the spectral norm amplification that Birkhoff mixing prevents. (2) $M = 4$ is large enough for meaningful per-iteration differentiation — one iteration's MLP contribution can be 4× another's — while small enough to keep downstream activations at similar scales across iterations. (3) Values in $[-4, 4]$ have ~0.001 precision in float16, preserving fine-grained specialization through the quantization passthrough. No ablation over cap values (e.g., 2.0 vs 4.0 vs 8.0 vs uncapped) was performed. This is an open question for future work. Screening Run D used uncapped timestep scaling at 2000 steps without issues, but full-scale runs only tested $M = 4$. 
+ +**Surprising finding.** Timestep scaling has near-zero effect on pre-quantization BPB (Run H vs I: 1.2578 vs 1.2580) but reduces quantization gap by 26–30% (H vs I: +0.0126 → +0.0088; J vs K: +0.0103 → +0.0076). The mechanism: capped gammas are stored as float16 passthrough parameters that bypass int8 quantization entirely. They provide per-iteration specialization that survives the quantization pipeline. In short, timestep scaling helps quantization, not training. + +**Result.** Run K (best): post-quant 1.2659 BPB, Q-gap +0.0076 — vs prior results showing catastrophic failure at 3+ loops (PR #579, +4.3 BPB). + +> Xu, K. & Sato, I. (2025). "On Expressive Power of Looped Transformers: Theoretical Analysis and Enhancement via Timestep Encoding." ICML 2025. [arXiv:2410.01405](https://arxiv.org/abs/2410.01405) +> Perez, E., Strub, F., de Vries, H., Dumoulin, V. & Courville, A. (2018). "FiLM: Visual Reasoning with a General Conditioning Layer." AAAI 2018. [arXiv:1709.07871](https://arxiv.org/abs/1709.07871) + +## 5. Supporting Technique: Prelude-Recurrent-Coda Architecture + +First and last transformer layers perform fundamentally different functions — input encoding and output prediction — compared to middle layers that do iterative refinement. Forcing boundary layers into shared weights compromises both functions. Geiping et al. (2025) demonstrated this at scale with Huginn 3.5B: 2 prelude + 4 shared (×32 loops) + 2 coda layers, achieving 132 effective depth from 3.5B parameters. In this competition, PR #575 independently explored prefix + 2 tied (×3) + suffix. + +**Result.** Run E (1+3×2+1, all fixes) vs Run D (4×2, all fixes): −0.016 BPB — the largest single architectural improvement in the ablation. Boundary layers need unique parameters. + +> Geiping, J. et al. (2025). "Scaling up Test-Time Compute with Latent Reasoning: A Recurrent Depth Approach." [arXiv:2502.05171](https://arxiv.org/abs/2502.05171) + +## 6. 
Combined Effect and Key Results + +### Screening (2000 steps, 1×H100) + +| Run | Config | Post-Q BPB | Q-Gap | Δ vs B' (bare) | +|-----|--------|-----------|-------|-----------------| +| B' | 4×2 bare recurrence | 1.3637 | +0.0024 | — | +| C' | 4×2 + birkhoff only | 1.3660 | +0.0024 | +0.002 | +| C | 4×2 + peri + birkhoff | 1.3587 | +0.0020 | −0.005 | +| D | 4×2 + peri + birk + timestep | 1.3584 | +0.0019 | −0.005 | +| E | 1+3×2+1 all fixes | 1.3428 | +0.0019 | −0.021 | +| F | 1+2×3+1 all (3 loops) | 1.3622 | +0.0019 | −0.002 | + +Birkhoff alone hurts (C' > B'). Peri-norm is the main factor (C − C' = −0.007). Prelude-coda is the largest single win (E − D = −0.016). Three loops are viable for the first time (F: Q-gap +0.0019, not exponential). + +### Full-Scale (600s, 8×H100) + +| Run | Config | Eff. Layers | Pre-Q BPB | Post-Q BPB | Q-Gap | +|-----|--------|-------------|-----------|------------|-------| +| H | 1+4×2+1 peri+birk | 10 | 1.2578 | 1.2704 | +0.0126 | +| I | 1+4×2+1 peri+birk+ts(cap4) | 10 | 1.2580 | 1.2668 | +0.0088 | +| J | 1+4×3+1 peri+birk (3 loops) | 14 | 1.2567 | 1.2670 | +0.0103 | +| **K** | **1+4×3+1 peri+birk+ts(cap4)** | **14** | **1.2583** | **1.2659** | **+0.0076** | + +**Headline result.** Run K achieves 14 effective layers from 6 unique blocks with Q-gap +0.0076. This is the first viable 3-loop depth recurrence in competition history, vs prior results showing catastrophic failure at 3+ loops. Timestep scaling reduces Q-gap by 26–30% on both 2-loop and 3-loop configurations. It helps quantization, not training. 
From 52442c9168ec8d66914ecb22ba40275f5a6fe2a4 Mon Sep 17 00:00:00 2001 From: Alexandr Azizyan Date: Thu, 26 Mar 2026 17:58:53 +0400 Subject: [PATCH 05/10] chore: add submission metadata and primary run log --- .../submission.json | 18 + .../train_log.txt | 1587 +++++++++++++++++ 2 files changed, 1605 insertions(+) create mode 100644 records/track_non_record_16mb/2026-03-26_RecurrenceFix_3Loop_Birkhoff_OutputLN_TimestepScale/submission.json create mode 100644 records/track_non_record_16mb/2026-03-26_RecurrenceFix_3Loop_Birkhoff_OutputLN_TimestepScale/train_log.txt diff --git a/records/track_non_record_16mb/2026-03-26_RecurrenceFix_3Loop_Birkhoff_OutputLN_TimestepScale/submission.json b/records/track_non_record_16mb/2026-03-26_RecurrenceFix_3Loop_Birkhoff_OutputLN_TimestepScale/submission.json new file mode 100644 index 0000000000..0003769848 --- /dev/null +++ b/records/track_non_record_16mb/2026-03-26_RecurrenceFix_3Loop_Birkhoff_OutputLN_TimestepScale/submission.json @@ -0,0 +1,18 @@ +{ + "author": "Alexandr Azizyan", + "github_id": "aazizyan", + "name": "First Viable 3-Loop Recurrence: Birkhoff + Output-LN + Timestep Scaling (1+4x3+1, 14 eff layers)", + "blurb": "Three novel techniques enabling the first viable 3-loop depth recurrence in competition history. Birkhoff-constrained mixing (spectral norm ≤ 1) eliminates exponential quantization blowup. Output-LN preserves magnitude signal across loop iterations. Capped timestep scaling reduces Q-gap by 26-30%. 
Config: 1 prelude + 4 shared x 3 loops + 1 coda = 14 effective layers from 6 unique blocks.", + "date": "2026-03-26T00:00:00Z", + "track": "non-record-16mb", + "val_loss": 2.13735865, + "val_bpb": 1.26586418, + "pre_quant_val_loss": 2.1246, + "pre_quant_val_bpb": 1.2583, + "step_stop": 10122, + "wallclock_seconds": 600, + "bytes_total": 10774310, + "bytes_model_int8_zlib": 10717286, + "bytes_code": 57024, + "gpu": "8xH100" +} \ No newline at end of file diff --git a/records/track_non_record_16mb/2026-03-26_RecurrenceFix_3Loop_Birkhoff_OutputLN_TimestepScale/train_log.txt b/records/track_non_record_16mb/2026-03-26_RecurrenceFix_3Loop_Birkhoff_OutputLN_TimestepScale/train_log.txt new file mode 100644 index 0000000000..d4830cec70 --- /dev/null +++ b/records/track_non_record_16mb/2026-03-26_RecurrenceFix_3Loop_Birkhoff_OutputLN_TimestepScale/train_log.txt @@ -0,0 +1,1587 @@ +""" +The `train_gpt.py` and `train_gpt_mlx.py` scripts are intended as good launching-off points for new participants, not SOTA configs. We'll accept PRs that tune, improve, or simplify these scripts without significantly increasing complexity, but competitive submissions should stay in the `/records` folder. + +Hard stop: To keep readable for newcomers, let's make sure `train_gpt.py` and `train_gpt_mlx.py` never are longer than 1500 lines. 
+""" + +from __future__ import annotations + +import copy +import glob +import io +import math +import os +import random +import subprocess +import sys +import time +import uuid +import zlib +from pathlib import Path + +import numpy as np +import sentencepiece as spm +import torch +import torch.distributed as dist +import torch.nn.functional as F +from torch import Tensor, nn +from torch.nn.parallel import DistributedDataParallel as DDP + +# Flash Attention 3 via HuggingFace kernels (Hopper-optimized) +try: + from kernels import get_kernel + cap = torch.cuda.get_device_capability() + repo = "varunneal/flash-attention-3" if cap == (9, 0) else "kernels-community/flash-attn3" + fa3 = get_kernel(repo).flash_attn_interface + FA3_AVAILABLE = True +except Exception: + FA3_AVAILABLE = False + +# ----------------------------- +# HYPERPARAMETERS +# ----------------------------- +# Default Simple Baseline run: +# - 9 transformer blocks at width 512 +# - 8 attention heads with 4 KV heads (GQA) and 2x MLP expansion +# - vocab size 1024, sequence length 1024, tied embeddings +# - 524,288 train tokens per step for 20,000 iterations with a ~10 minute cap + +class Hyperparameters: + # Data paths are shard globs produced by the existing preprocessing pipeline. + data_path = os.environ.get("DATA_PATH", "./data/datasets/fineweb10B_sp1024") + train_files = os.path.join(data_path, "fineweb_train_*.bin") + val_files = os.path.join(data_path, "fineweb_val_*.bin") + tokenizer_path = os.environ.get("TOKENIZER_PATH", "./data/tokenizers/fineweb_1024_bpe.model") + run_id = os.environ.get("RUN_ID", str(uuid.uuid4())) + seed = int(os.environ.get("SEED", 1337)) + + # Validation cadence and batch size. Validation always uses the full fineweb_val split. + val_batch_size = int(os.environ.get("VAL_BATCH_SIZE", 524_288)) + val_loss_every = int(os.environ.get("VAL_LOSS_EVERY", 1000)) + train_log_every = int(os.environ.get("TRAIN_LOG_EVERY", 200)) + + # Training length. 
+ iterations = int(os.environ.get("ITERATIONS", 20000)) + warmdown_iters = int(os.environ.get("WARMDOWN_ITERS", 1200)) + warmup_steps = int(os.environ.get("WARMUP_STEPS", 20)) + train_batch_tokens = int(os.environ.get("TRAIN_BATCH_TOKENS", 524_288)) + train_seq_len = int(os.environ.get("TRAIN_SEQ_LEN", 1024)) + max_wallclock_seconds = float(os.environ.get("MAX_WALLCLOCK_SECONDS", 600.0)) + qk_gain_init = float(os.environ.get("QK_GAIN_INIT", 1.5)) + + # Model shape. + vocab_size = int(os.environ.get("VOCAB_SIZE", 1024)) + num_layers = int(os.environ.get("NUM_LAYERS", 9)) + num_kv_heads = int(os.environ.get("NUM_KV_HEADS", 4)) + model_dim = int(os.environ.get("MODEL_DIM", 512)) + num_heads = int(os.environ.get("NUM_HEADS", 8)) + mlp_mult = int(os.environ.get("MLP_MULT", 2)) + tie_embeddings = bool(int(os.environ.get("TIE_EMBEDDINGS", "1"))) + rope_base = float(os.environ.get("ROPE_BASE", 10000.0)) + logit_softcap = float(os.environ.get("LOGIT_SOFTCAP", 30.0)) + + # Recurrence + stability features. + num_prelude = int(os.environ.get("NUM_PRELUDE", 0)) + num_coda = int(os.environ.get("NUM_CODA", 0)) + num_shared = int(os.environ.get("NUM_SHARED", 0)) + num_loops = int(os.environ.get("NUM_LOOPS", 2)) + use_peri_norm = bool(int(os.environ.get("USE_PERI_NORM", "0"))) + use_birkhoff_mix = bool(int(os.environ.get("USE_BIRKHOFF_MIX", "0"))) + use_timestep_scale = bool(int(os.environ.get("USE_TIMESTEP_SCALE", "0"))) + leaky_relu_slope = float(os.environ.get("LEAKY_RELU_SLOPE", "0.5")) + timestep_gamma_max = float(os.environ.get("TIMESTEP_GAMMA_MAX", "0.0")) + disable_compile = bool(int(os.environ.get("DISABLE_COMPILE", "0"))) + + # Optimizer hyperparameters. 
+ embed_lr = float(os.environ.get("EMBED_LR", 0.6)) + head_lr = float(os.environ.get("HEAD_LR", 0.008)) + tied_embed_lr = float(os.environ.get("TIED_EMBED_LR", 0.05)) + tied_embed_init_std = float(os.environ.get("TIED_EMBED_INIT_STD", 0.005)) + matrix_lr = float(os.environ.get("MATRIX_LR", 0.04)) + scalar_lr = float(os.environ.get("SCALAR_LR", 0.04)) + muon_momentum = float(os.environ.get("MUON_MOMENTUM", 0.95)) + muon_backend_steps = int(os.environ.get("MUON_BACKEND_STEPS", 5)) + muon_momentum_warmup_start = float(os.environ.get("MUON_MOMENTUM_WARMUP_START", 0.85)) + muon_momentum_warmup_steps = int(os.environ.get("MUON_MOMENTUM_WARMUP_STEPS", 500)) + beta1 = float(os.environ.get("BETA1", 0.9)) + beta2 = float(os.environ.get("BETA2", 0.95)) + adam_eps = float(os.environ.get("ADAM_EPS", 1e-8)) + grad_clip_norm = float(os.environ.get("GRAD_CLIP_NORM", 0.0)) + +# ----------------------------- +# MUON OPTIMIZER +# ----------------------------- +# +# As borrowed from modded-nanogpt +# Background on Muon: https://kellerjordan.github.io/posts/muon/ + +def zeropower_via_newtonschulz5(G: Tensor, steps: int = 10, eps: float = 1e-7) -> Tensor: + # Orthogonalize a 2D update matrix with a fast Newton-Schulz iteration. + # Muon uses this to normalize matrix-shaped gradients before applying them. 
+ a, b, c = (3.4445, -4.7750, 2.0315) + X = G.bfloat16() + X /= X.norm() + eps + transposed = G.size(0) > G.size(1) + if transposed: + X = X.T + for _ in range(steps): + A = X @ X.T + B = b * A + c * A @ A + X = a * X + B @ X + return X.T if transposed else X + + +class Muon(torch.optim.Optimizer): + def __init__(self, params, lr: float, momentum: float, backend_steps: int, nesterov: bool = True): + super().__init__( + params, + dict(lr=lr, momentum=momentum, backend_steps=backend_steps, nesterov=nesterov), + ) + + @torch.no_grad() + def step(self, closure=None): + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + distributed = dist.is_available() and dist.is_initialized() + world_size = dist.get_world_size() if distributed else 1 + rank = dist.get_rank() if distributed else 0 + + for group in self.param_groups: + params = group["params"] + if not params: + continue + lr = group["lr"] + momentum = group["momentum"] + backend_steps = group["backend_steps"] + nesterov = group["nesterov"] + + total_params = sum(int(p.numel()) for p in params) + updates_flat = torch.zeros(total_params, device=params[0].device, dtype=torch.bfloat16) + + curr = 0 + for i, p in enumerate(params): + if i % world_size == rank and p.grad is not None: + g = p.grad + state = self.state[p] + if "momentum_buffer" not in state: + state["momentum_buffer"] = torch.zeros_like(g) + buf = state["momentum_buffer"] + buf.mul_(momentum).add_(g) + if nesterov: + g = g.add(buf, alpha=momentum) + g = zeropower_via_newtonschulz5(g, steps=backend_steps) + # Scale correction from Muon reference implementations. 
+ g *= max(1, g.size(0) / g.size(1)) ** 0.5 + updates_flat[curr : curr + p.numel()] = g.reshape(-1) + curr += p.numel() + + if distributed: + dist.all_reduce(updates_flat, op=dist.ReduceOp.SUM) + + curr = 0 + for p in params: + g = updates_flat[curr : curr + p.numel()].view_as(p).to(dtype=p.dtype) + p.add_(g, alpha=-lr) + curr += p.numel() + + return loss + + +# ----------------------------- +# TOKENIZER-AGNOSTIC EVALUATION SETUP +# ----------------------------- +# +# It's common for small models have a large fraction of their parameters be embeddings, since the 2 * d_model * d_vocab vectors can be gigantic. +# Instead of locking the tokenizer, we let you bring your own and calculate our validation metrics on the average compression of the validation set. +# We calculate BPB (bits-per-byte) instead of validation loss, so we need methods to count the number of bits per token in the tokenizer. +# Note: Submissions that edit the tokenizer will be examined more carefully, since screwing this up might unjustly improve your score. 
def build_sentencepiece_luts(
    sp: spm.SentencePieceProcessor, vocab_size: int, device: torch.device
) -> tuple[Tensor, Tensor, Tensor]:
    """Precompute per-token lookup tables for the BPB byte accounting in eval_val.

    Returns three tensors indexed by token id:
      - base_bytes: UTF-8 byte length of the piece (leading "▁" stripped; byte
        tokens count as exactly 1 byte)
      - has_leading_space: True when the piece starts with the SentencePiece
        word-boundary marker "▁"
      - is_boundary_token: True for control/unknown/unused ids (these keep the
        all-ones default and contribute 0 base bytes)
    """
    sp_vocab_size = int(sp.vocab_size())
    # The tables must cover both the tokenizer's ids and the model's vocab.
    table_size = max(sp_vocab_size, vocab_size)
    base_bytes_np = np.zeros((table_size,), dtype=np.int16)
    has_leading_space_np = np.zeros((table_size,), dtype=np.bool_)
    is_boundary_token_np = np.ones((table_size,), dtype=np.bool_)
    for token_id in range(sp_vocab_size):
        if sp.is_control(token_id) or sp.is_unknown(token_id) or sp.is_unused(token_id):
            continue
        is_boundary_token_np[token_id] = False
        if sp.is_byte(token_id):
            base_bytes_np[token_id] = 1
            continue
        piece = sp.id_to_piece(token_id)
        if piece.startswith("▁"):
            # The marker itself is not a byte of text; the space it denotes is
            # charged conditionally in eval_val (only after a non-boundary token).
            has_leading_space_np[token_id] = True
            piece = piece[1:]
        base_bytes_np[token_id] = len(piece.encode("utf-8"))
    return (
        torch.tensor(base_bytes_np, dtype=torch.int16, device=device),
        torch.tensor(has_leading_space_np, dtype=torch.bool, device=device),
        torch.tensor(is_boundary_token_np, dtype=torch.bool, device=device),
    )


def load_validation_tokens(pattern: str, seq_len: int) -> Tensor:
    """Concatenate all validation shards matching `pattern` into one token tensor.

    The result is trimmed to a whole number of seq_len-sized sequences plus one
    extra token, so callers can build shifted (x, y) pairs.

    Raises FileNotFoundError when no shard matches and ValueError when the
    split is shorter than a single sequence.
    """
    files = [Path(p) for p in sorted(glob.glob(pattern))]
    if not files:
        raise FileNotFoundError(f"No files found for pattern: {pattern}")
    # The export pipeline writes the fixed first-50k-doc validation set to fineweb_val_*.
    tokens = torch.cat([load_data_shard(file) for file in files]).contiguous()
    usable = ((tokens.numel() - 1) // seq_len) * seq_len
    if usable <= 0:
        raise ValueError(f"Validation split is too short for TRAIN_SEQ_LEN={seq_len}")
    return tokens[: usable + 1]


def eval_val(
    args: Hyperparameters,
    model: nn.Module,
    rank: int,
    world_size: int,
    device: torch.device,
    grad_accum_steps: int,
    val_tokens: Tensor,
    base_bytes_lut: Tensor,
    has_leading_space_lut: Tensor,
    is_boundary_token_lut: Tensor,
) -> tuple[float, float]:
    """Evaluate `model` on the full validation stream and return (val_loss, val_bpb).

    - val_loss: token cross-entropy (natural log), token-weighted mean over all ranks
    - val_bpb: tokenizer-agnostic compression metric used by the challenge,
      computed as bits-per-token * tokens-per-byte = bits-per-byte

    Sequences are split contiguously across ranks; partial sums are merged with
    all_reduce before the final division. The model is left back in train mode.
    NOTE(review): assumes `model(x, y)` returns the mean token loss — confirm
    against the model's forward.
    """
    # Validation computes two metrics:
    # - val_loss: token cross-entropy (natural log)
    # - val_bpb: tokenizer-agnostic compression metric used by the challenge
    local_batch_tokens = args.val_batch_size // (world_size * grad_accum_steps)
    if local_batch_tokens < args.train_seq_len:
        raise ValueError(
            "VAL_BATCH_SIZE must provide at least one sequence per rank; "
            f"got VAL_BATCH_SIZE={args.val_batch_size}, WORLD_SIZE={world_size}, "
            f"GRAD_ACCUM_STEPS={grad_accum_steps}, TRAIN_SEQ_LEN={args.train_seq_len}"
        )
    local_batch_seqs = local_batch_tokens // args.train_seq_len
    # Each rank owns the contiguous sequence range [seq_start, seq_end).
    total_seqs = (val_tokens.numel() - 1) // args.train_seq_len
    seq_start = (total_seqs * rank) // world_size
    seq_end = (total_seqs * (rank + 1)) // world_size
    # float64 accumulators keep the long reductions numerically stable.
    val_loss_sum = torch.zeros((), device=device, dtype=torch.float64)
    val_token_count = torch.zeros((), device=device, dtype=torch.float64)
    val_byte_count = torch.zeros((), device=device, dtype=torch.float64)

    model.eval()
    with torch.inference_mode():
        for batch_seq_start in range(seq_start, seq_end, local_batch_seqs):
            batch_seq_end = min(batch_seq_start + local_batch_seqs, seq_end)
            raw_start = batch_seq_start * args.train_seq_len
            # "+1" supplies the extra token needed for the shifted targets.
            raw_end = batch_seq_end * args.train_seq_len + 1
            local = val_tokens[raw_start:raw_end].to(device=device, dtype=torch.int64, non_blocking=True)
            x = local[:-1].reshape(-1, args.train_seq_len)
            y = local[1:].reshape(-1, args.train_seq_len)
            with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True):
                batch_loss = model(x, y).detach()
            batch_token_count = float(y.numel())
            # Re-weight the batch-mean loss by token count so the final division
            # yields a true per-token mean across uneven batches and ranks.
            val_loss_sum += batch_loss.to(torch.float64) * batch_token_count
            val_token_count += batch_token_count
            prev_ids = x.reshape(-1)
            tgt_ids = y.reshape(-1)
            # Byte accounting: base UTF-8 bytes of the target token, plus one
            # byte for its leading space unless the previous token was a
            # boundary (control/unknown/unused) token.
            token_bytes = base_bytes_lut[tgt_ids].to(dtype=torch.int16)
            token_bytes += (has_leading_space_lut[tgt_ids] & ~is_boundary_token_lut[prev_ids]).to(dtype=torch.int16)
            val_byte_count += token_bytes.to(torch.float64).sum()

    if dist.is_available() and dist.is_initialized():
        dist.all_reduce(val_loss_sum, op=dist.ReduceOp.SUM)
        dist.all_reduce(val_token_count, op=dist.ReduceOp.SUM)
        dist.all_reduce(val_byte_count, op=dist.ReduceOp.SUM)

    val_loss = val_loss_sum / val_token_count
    # bits/token * tokens/byte = bits/byte.
    bits_per_token = val_loss.item() / math.log(2.0)
    tokens_per_byte = val_token_count.item() / val_byte_count.item()
    model.train()
    return float(val_loss.item()), float(bits_per_token * tokens_per_byte)

# -----------------------------
# POST-TRAINING QUANTIZATION
# -----------------------------
#
# It's silly to export our model, which is trained in bf16 and fp32, at that same precision.
# Instead, we get approximately the same model (with a small hit) by quantizing the model to int8 & zlib compressing.
# We can then decompress the model and run in higher precision for evaluation, after closing in under the size limit.
+ +CONTROL_TENSOR_NAME_PATTERNS = tuple( + pattern + for pattern in os.environ.get( + "CONTROL_TENSOR_NAME_PATTERNS", + "attn_scale,attn_scales,mlp_scale,mlp_scales,resid_mix,resid_mixes,q_gain,skip_weight,skip_weights,ts_attn,ts_mlp,resid_mix_logit", + ).split(",") + if pattern +) +INT8_KEEP_FLOAT_FP32_NAME_PATTERNS = tuple( + pattern + for pattern in os.environ.get( + "INT8_KEEP_FLOAT_FP32_NAME_PATTERNS", + ",".join(CONTROL_TENSOR_NAME_PATTERNS), + ).split(",") + if pattern +) +INT8_KEEP_FLOAT_MAX_NUMEL = 65_536 +INT8_KEEP_FLOAT_STORE_DTYPE = torch.float16 +INT8_PER_ROW_SCALE_DTYPE = torch.float16 +INT8_CLIP_PERCENTILE = 99.99984 +INT8_CLIP_Q = INT8_CLIP_PERCENTILE / 100.0 + +def tensor_nbytes(t: Tensor) -> int: + return int(t.numel()) * int(t.element_size()) + +def keep_float_tensor(name: str, t: Tensor, passthrough_orig_dtypes: dict[str, str]) -> Tensor: + if any(pattern in name for pattern in INT8_KEEP_FLOAT_FP32_NAME_PATTERNS): + return t.float().contiguous() + if t.dtype in {torch.float32, torch.bfloat16}: + passthrough_orig_dtypes[name] = str(t.dtype).removeprefix("torch.") + return t.to(dtype=INT8_KEEP_FLOAT_STORE_DTYPE).contiguous() + return t + +def quantize_float_tensor(t: Tensor) -> tuple[Tensor, Tensor]: + t32 = t.float() + if t32.ndim == 2: + # Matrices get one scale per row, which usually tracks output-channel + # ranges much better than a single tensor-wide scale. + clip_abs = ( + torch.quantile(t32.abs(), INT8_CLIP_Q, dim=1) + if t32.numel() + else torch.empty((t32.shape[0],), dtype=torch.float32) + ) + clipped = torch.maximum(torch.minimum(t32, clip_abs[:, None]), -clip_abs[:, None]) + scale = (clip_abs / 127.0).clamp_min(1.0 / 127.0) + q = torch.clamp(torch.round(clipped / scale[:, None]), -127, 127).to(torch.int8).contiguous() + return q, scale.to(dtype=INT8_PER_ROW_SCALE_DTYPE).contiguous() + + # Vectors / scalars use a simpler per-tensor scale. 
+ clip_abs = float(torch.quantile(t32.abs().flatten(), INT8_CLIP_Q).item()) if t32.numel() else 0.0 + scale = torch.tensor(clip_abs / 127.0 if clip_abs > 0 else 1.0, dtype=torch.float32) + q = torch.clamp(torch.round(torch.clamp(t32, -clip_abs, clip_abs) / scale), -127, 127).to(torch.int8).contiguous() + return q, scale + +def quantize_state_dict_int8(state_dict: dict[str, Tensor]): + # Single supported clean-script export format: + # - per-row int8 for 2D float tensors + # - per-tensor int8 for other float tensors + # - exact passthrough for non-floats + # - passthrough for small float tensors, stored as fp16 to save bytes + quantized: dict[str, Tensor] = {} + scales: dict[str, Tensor] = {} + dtypes: dict[str, str] = {} + passthrough: dict[str, Tensor] = {} + passthrough_orig_dtypes: dict[str, str] = {} + qmeta: dict[str, dict[str, object]] = {} + stats = dict.fromkeys( + ("param_count", "num_tensors", "num_float_tensors", "num_nonfloat_tensors", "baseline_tensor_bytes", "int8_payload_bytes"), + 0, + ) + + for name, tensor in state_dict.items(): + t = tensor.detach().to("cpu").contiguous() + stats["param_count"] += int(t.numel()) + stats["num_tensors"] += 1 + stats["baseline_tensor_bytes"] += tensor_nbytes(t) + + if not t.is_floating_point(): + stats["num_nonfloat_tensors"] += 1 + passthrough[name] = t + stats["int8_payload_bytes"] += tensor_nbytes(t) + continue + + # Small float tensors are cheap enough to keep directly. We still downcast + # fp32/bf16 passthrough tensors to fp16 so metadata does not dominate size. 
+ if t.numel() <= INT8_KEEP_FLOAT_MAX_NUMEL: + kept = keep_float_tensor(name, t, passthrough_orig_dtypes) + passthrough[name] = kept + stats["int8_payload_bytes"] += tensor_nbytes(kept) + continue + + stats["num_float_tensors"] += 1 + q, s = quantize_float_tensor(t) + if s.ndim > 0: + qmeta[name] = {"scheme": "per_row", "axis": 0} + quantized[name] = q + scales[name] = s + dtypes[name] = str(t.dtype).removeprefix("torch.") + stats["int8_payload_bytes"] += tensor_nbytes(q) + tensor_nbytes(s) + + obj: dict[str, object] = { + "__quant_format__": "int8_clean_per_row_v1", + "quantized": quantized, + "scales": scales, + "dtypes": dtypes, + "passthrough": passthrough, + } + if qmeta: + obj["qmeta"] = qmeta + if passthrough_orig_dtypes: + obj["passthrough_orig_dtypes"] = passthrough_orig_dtypes + return obj, stats + +def dequantize_state_dict_int8(obj: dict[str, object]) -> dict[str, Tensor]: + out: dict[str, Tensor] = {} + qmeta = obj.get("qmeta", {}) + passthrough_orig_dtypes = obj.get("passthrough_orig_dtypes", {}) + for name, q in obj["quantized"].items(): + dtype = getattr(torch, obj["dtypes"][name]) + s = obj["scales"][name] + if qmeta.get(name, {}).get("scheme") == "per_row" or s.ndim > 0: + s = s.to(dtype=torch.float32) + # Broadcast the saved row scale back across trailing dimensions. + out[name] = (q.float() * s.view(q.shape[0], *([1] * (q.ndim - 1)))).to(dtype=dtype).contiguous() + else: + scale = float(s.item()) + out[name] = (q.float() * scale).to(dtype=dtype).contiguous() + for name, t in obj["passthrough"].items(): + # Restore small tensors, undoing the temporary fp16 storage cast if needed. 
+ out_t = t.detach().to("cpu").contiguous() + orig_dtype = passthrough_orig_dtypes.get(name) + if isinstance(orig_dtype, str): + out_t = out_t.to(dtype=getattr(torch, orig_dtype)).contiguous() + out[name] = out_t + return out + + +# ----------------------------- +# DATA LOADING +# ----------------------------- + +def load_data_shard(file: Path) -> Tensor: + header_bytes = 256 * np.dtype(" None: + self.file_idx = (self.file_idx + 1) % len(self.files) + self.tokens = load_data_shard(self.files[self.file_idx]) + self.pos = 0 + + def take(self, n: int) -> Tensor: + chunks: list[Tensor] = [] + remaining = n + while remaining > 0: + avail = self.tokens.numel() - self.pos + if avail <= 0: + self._advance_file() + continue + k = min(remaining, avail) + chunks.append(self.tokens[self.pos : self.pos + k]) + self.pos += k + remaining -= k + return chunks[0] if len(chunks) == 1 else torch.cat(chunks) + + +class DistributedTokenLoader: + # Each call consumes a contiguous chunk from the shared token stream, then slices out + # one disjoint span per rank. The extra "+1" token lets us build (x, y) by shifting. 
+    def __init__(self, pattern: str, rank: int, world_size: int, device: torch.device):
+        # Every rank constructs the same TokenStream over the same shard glob; the
+        # per-rank slicing happens in next_batch, so streams stay in lockstep.
+        self.rank = rank
+        self.world_size = world_size
+        self.device = device
+        self.stream = TokenStream(pattern)
+
+    def next_batch(self, global_tokens: int, seq_len: int, grad_accum_steps: int) -> tuple[Tensor, Tensor]:
+        """Return one (inputs, targets) micro-batch for this rank.
+
+        global_tokens is the batch size across all ranks and accumulation steps;
+        each rank takes a disjoint span of local_tokens + 1 so targets can be the
+        inputs shifted by one position. Assumes local_tokens is a multiple of
+        seq_len (the reshape below requires it).
+        """
+        local_tokens = global_tokens // (self.world_size * grad_accum_steps)
+        per_rank_span = local_tokens + 1
+        chunk = self.stream.take(per_rank_span * self.world_size)
+        start = self.rank * per_rank_span
+        local = chunk[start : start + per_rank_span].to(dtype=torch.int64)
+        # x drops the last token, y drops the first: y[i] is the next token after x[i].
+        x = local[:-1].reshape(-1, seq_len)
+        y = local[1:].reshape(-1, seq_len)
+        return x.to(self.device, non_blocking=True), y.to(self.device, non_blocking=True)
+
+# -----------------------------
+# TRANSFORMER MODULES
+# -----------------------------
+
+class RMSNorm(nn.Module):
+    """Weightless RMS normalization; eps=None defers to F.rms_norm's default."""
+    def __init__(self, eps: float | None = None):
+        super().__init__()
+        self.eps = eps
+
+    def forward(self, x: Tensor) -> Tensor:
+        return F.rms_norm(x, (x.size(-1),), eps=self.eps)
+
+
+class CastedLinear(nn.Linear):
+    # Keep weights in fp32 for optimizer/state quality, cast at matmul time for bf16 compute.
+    def forward(self, x: Tensor) -> Tensor:
+        bias = self.bias.to(x.dtype) if self.bias is not None else None
+        return F.linear(x, self.weight.to(x.dtype), bias)
+
+
+def restore_low_dim_params_to_fp32(module: nn.Module) -> None:
+    # Keep small/control parameters in fp32 even when the model body runs in bf16.
+    # "Control" params are matched by name against CONTROL_TENSOR_NAME_PATTERNS
+    # (defined elsewhere in this file); anything with ndim < 2 also qualifies.
+    with torch.no_grad():
+        for name, param in module.named_parameters():
+            if (param.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS)) and param.dtype != torch.float32:
+                param.data = param.data.float()
+
+
+class Rotary(nn.Module):
+    # Caches cos/sin tables per sequence length on the current device.
+    def __init__(self, dim: int, base: float = 10000.0):
+        super().__init__()
+        # Standard RoPE inverse frequencies over even indices of the head dim.
+        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim))
+        self.register_buffer("inv_freq", inv_freq, persistent=False)
+        self._seq_len_cached = 0
+        self._cos_cached: Tensor | None = None
+        self._sin_cached: Tensor | None = None
+
+    def forward(self, seq_len: int, device: torch.device, dtype: torch.dtype) -> tuple[Tensor, Tensor]:
+        """Return (cos, sin) tables shaped (1, 1, seq_len, dim/2), cast to `dtype`.
+
+        The cache is rebuilt when the sequence length or device changes; dtype is
+        not cached — the cast happens on every return, so the stored tables stay
+        in inv_freq's dtype (fp32).
+        """
+        if (
+            self._cos_cached is None
+            or self._sin_cached is None
+            or self._seq_len_cached != seq_len
+            or self._cos_cached.device != device
+        ):
+            t = torch.arange(seq_len, device=device, dtype=self.inv_freq.dtype)
+            freqs = torch.outer(t, self.inv_freq.to(device))
+            # Leading singleton dims broadcast over (batch, heads).
+            self._cos_cached = freqs.cos()[None, None, :, :]
+            self._sin_cached = freqs.sin()[None, None, :, :]
+            self._seq_len_cached = seq_len
+        return self._cos_cached.to(dtype=dtype), self._sin_cached.to(dtype=dtype)
+
+
+def apply_rotary_emb(x: Tensor, cos: Tensor, sin: Tensor) -> Tensor:
+    # Rotate-half RoPE: split the last dim in two and apply the 2D rotation.
+    half = x.size(-1) // 2
+    x1, x2 = x[..., :half], x[..., half:]
+    return torch.cat((x1 * cos + x2 * sin, x1 * (-sin) + x2 * cos), dim=-1)
+
+
+class CausalSelfAttention(nn.Module):
+    """Causal attention with GQA, QK RMS-norm, RoPE, and a learned per-head
+    query gain; uses FlashAttention-3 when available, SDPA otherwise."""
+    def __init__(
+        self,
+        dim: int,
+        num_heads: int,
+        num_kv_heads: int,
+        rope_base: float,
+        qk_gain_init: float,
+    ):
+        super().__init__()
+        if dim % num_heads != 0:
+            raise ValueError("model_dim must be divisible by num_heads")
+        if num_heads % num_kv_heads != 0:
+            raise ValueError("num_heads must be divisible by num_kv_heads")
+        self.num_heads = num_heads
+        self.num_kv_heads = num_kv_heads
+        self.head_dim = dim // num_heads
+        if self.head_dim % 2 != 0:
+            raise ValueError("head_dim must be even for RoPE")
+        kv_dim = self.num_kv_heads * self.head_dim
+        self.c_q = CastedLinear(dim, dim, bias=False)
+        self.c_k = CastedLinear(dim, kv_dim, bias=False)
+        self.c_v = CastedLinear(dim, kv_dim, bias=False)
+        self.proj = CastedLinear(dim, dim, bias=False)
+        # Output projection starts at zero (picked up by GPT._init_weights) so a
+        # fresh block is the identity on the residual stream.
+        self.proj._zero_init = True
+        # One learned scalar gain per head, applied to normalized queries.
+        self.q_gain = nn.Parameter(torch.full((num_heads,), qk_gain_init, dtype=torch.float32))
+        self.rotary = Rotary(self.head_dim, base=rope_base)
+
+    def forward(self, x: Tensor) -> Tensor:
+        bsz, seqlen, dim = x.shape
+        # Project and reshape to (B, H, T, D) / (B, KVH, T, D).
+        q = self.c_q(x).reshape(bsz, seqlen, self.num_heads, self.head_dim).transpose(1, 2)
+        k = self.c_k(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2)
+        v = self.c_v(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2)
+        # QK-norm before RoPE stabilizes attention logits.
+        q = F.rms_norm(q, (q.size(-1),))
+        k = F.rms_norm(k, (k.size(-1),))
+        cos, sin = self.rotary(seqlen, x.device, q.dtype)
+        q = apply_rotary_emb(q, cos, sin)
+        k = apply_rotary_emb(k, cos, sin)
+        q = q * self.q_gain.to(dtype=q.dtype)[None, :, None, None]
+        if FA3_AVAILABLE:
+            # FA3 expects (B, T, H, D) — transpose from (B, H, T, D)
+            y = fa3.flash_attn_func(
+                q.transpose(1, 2),
+                k.transpose(1, 2),
+                v.transpose(1, 2),
+                causal=True,
+            ).transpose(1, 2)
+        else:
+            y = F.scaled_dot_product_attention(
+                q, k, v,
+                attn_mask=None,
+                is_causal=True,
+                enable_gqa=(self.num_kv_heads != self.num_heads),
+            )
+        y = y.transpose(1, 2).contiguous().reshape(bsz, seqlen, dim)
+        return self.proj(y)
+
+
+class MLP(nn.Module):
+    # leaky_relu^2 MLP (slope=0.5 matches SOTA, PRs #493/#518)
+    def __init__(self, dim: int, mlp_mult: int, leaky_relu_slope: float = 0.5):
+        super().__init__()
+        hidden = mlp_mult * dim
+        self.fc = CastedLinear(dim, hidden, bias=False)
+        self.proj = CastedLinear(hidden, dim, bias=False)
+        # Zero-initialized projection (see GPT._init_weights) → block starts as identity.
+        self.proj._zero_init = True
+        self.leaky_relu_slope = leaky_relu_slope
+
+    def forward(self, x: Tensor) -> Tensor:
+        # Squared leaky-ReLU activation: leaky_relu(h)^2.
+        x = F.leaky_relu(self.fc(x), negative_slope=self.leaky_relu_slope)
+        return self.proj(x.square())
+
+
+class TimestepScaling(nn.Module):
+    """Learned per-iteration scale vectors (arXiv:2410.01405, ICML 2025)."""
+    def __init__(self, effective_layers: int, dim: int, gamma_max: float = 0.0):
+        super().__init__()
+        # One (dim,) gamma vector per virtual layer position, for attn and MLP paths.
+        self.attn_gamma = nn.Parameter(torch.ones(effective_layers, dim, dtype=torch.float32))
+        self.mlp_gamma = nn.Parameter(torch.ones(effective_layers, dim, dtype=torch.float32))
+        self.gamma_max = gamma_max  # 0 = uncapped
+
+    def get(self, v: int) -> tuple[Tensor, Tensor]:
+        # v indexes the virtual layer position (prelude + unrolled shared + coda).
+        ag = self.attn_gamma[v]
+        mg = self.mlp_gamma[v]
+        if self.gamma_max > 0:
+            ag = ag.clamp(-self.gamma_max, self.gamma_max)
+            mg = mg.clamp(-self.gamma_max, self.gamma_max)
+        return ag, mg
+
+
+class Block(nn.Module):
+    """Pre-norm transformer block with a learned residual mix against the
+    embedding stream x0, per-channel output scales, and optional peri-norm /
+    Birkhoff-mix variants used on the recurrent path."""
+    def __init__(
+        self,
+        dim: int,
+        num_heads: int,
+        num_kv_heads: int,
+        mlp_mult: int,
+        rope_base: float,
+        qk_gain_init: float,
+        use_peri_norm: bool = False,
+        use_birkhoff_mix: bool = False,
+        leaky_relu_slope: float = 0.5,
+    ):
+        super().__init__()
+        self.use_peri_norm = use_peri_norm
+        self.use_birkhoff_mix = use_birkhoff_mix
+        self.attn_norm = RMSNorm()
+        if use_peri_norm:
+            self.mlp_out_norm = RMSNorm()  # output norm: bounds MLP contribution, preserves input magnitude
+        else:
+            self.mlp_norm = RMSNorm()
+        self.attn = CausalSelfAttention(dim, num_heads, num_kv_heads, rope_base, qk_gain_init)
+        self.mlp = MLP(dim, mlp_mult, leaky_relu_slope)
+        self.attn_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32))
+        self.mlp_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32))
+        if use_birkhoff_mix:
+            # Birkhoff B₂ parameterization: sigmoid(logit) ∈ (0,1) → doubly stochastic
+            self.resid_mix_logit = nn.Parameter(torch.zeros(dim, dtype=torch.float32))
+        else:
+            # Unconstrained two-way mix; initialized to pass x through (1, 0).
+            self.resid_mix = nn.Parameter(torch.stack((torch.ones(dim), torch.zeros(dim))).float())
+
+    def forward(self, x: Tensor, x0: Tensor,
+                ts_attn_gamma: Tensor | None = None,
+                ts_mlp_gamma: Tensor | None = None) -> Tensor:
+        # Mix the running stream with the initial (post-embedding) stream x0.
+        if self.use_birkhoff_mix:
+            # Convex combination per channel: alpha*x + (1-alpha)*x0.
+            alpha = torch.sigmoid(self.resid_mix_logit.to(dtype=x.dtype))[None, None, :]
+            x = alpha * x + (1 - alpha) * x0
+        else:
+            mix = self.resid_mix.to(dtype=x.dtype)
+            x = mix[0][None, None, :] * x + mix[1][None, None, :] * x0
+        attn_out = self.attn(self.attn_norm(x))
+        attn_s = self.attn_scale.to(dtype=x.dtype)[None, None, :]
+        if ts_attn_gamma is not None:
+            # Timestep (virtual-layer) gamma multiplies the learned channel scale.
+            attn_s = attn_s * ts_attn_gamma[None, None, :]
+        x = x + attn_s * attn_out
+        # Peri-norm normalizes the MLP *output*; otherwise normalize its input.
+        mlp_out = self.mlp_out_norm(self.mlp(x)) if self.use_peri_norm else self.mlp(self.mlp_norm(x))
+        mlp_s = self.mlp_scale.to(dtype=x.dtype)[None, None, :]
+        if ts_mlp_gamma is not None:
+            mlp_s = mlp_s * ts_mlp_gamma[None, None, :]
+        x = x + mlp_s * mlp_out
+        return x
+
+
+class GPT(nn.Module):
+    """Decoder-only LM with two topologies: a recurrent prelude/shared/coda stack
+    (num_shared > 0; shared blocks are unrolled num_loops times) or a standard
+    U-Net-style encoder/decoder stack with learned skip weights. forward()
+    returns the mean cross-entropy loss over tanh-softcapped logits."""
+    def __init__(
+        self,
+        vocab_size: int,
+        num_layers: int,
+        model_dim: int,
+        num_heads: int,
+        num_kv_heads: int,
+        mlp_mult: int,
+        tie_embeddings: bool,
+        tied_embed_init_std: float,
+        logit_softcap: float,
+        rope_base: float,
+        qk_gain_init: float,
+        num_prelude: int = 0,
+        num_coda: int = 0,
+        num_shared: int = 0,
+        num_loops: int = 2,
+        use_peri_norm: bool = False,
+        use_birkhoff_mix: bool = False,
+        use_timestep_scale: bool = False,
+        timestep_gamma_max: float = 0.0,
+        leaky_relu_slope: float = 0.5,
+    ):
+        super().__init__()
+        if logit_softcap <= 0.0:
+            raise ValueError(f"logit_softcap must be positive, got {logit_softcap}")
+        self.tie_embeddings = tie_embeddings
+        self.tied_embed_init_std = tied_embed_init_std
+        self.logit_softcap = logit_softcap
+        self.tok_emb = nn.Embedding(vocab_size, model_dim)
+        # Recurrence is implied by having any shared blocks; otherwise the
+        # prelude/coda/loop knobs are neutralized.
+        self.use_recurrence = num_shared > 0
+        self.num_loops = num_loops if self.use_recurrence else 1
+        self.num_prelude = num_prelude if self.use_recurrence else 0
+        self.num_coda = num_coda if self.use_recurrence else 0
+
+        # Peri-norm / Birkhoff mixing are only enabled on the recurrent path.
+        block_kwargs = dict(
+            dim=model_dim, num_heads=num_heads, num_kv_heads=num_kv_heads,
+            mlp_mult=mlp_mult, rope_base=rope_base, qk_gain_init=qk_gain_init,
+            use_peri_norm=use_peri_norm if self.use_recurrence else False,
+            use_birkhoff_mix=use_birkhoff_mix if self.use_recurrence else False,
+            leaky_relu_slope=leaky_relu_slope,
+        )
+
+        if self.use_recurrence:
+            self.prelude_blocks = nn.ModuleList([Block(**block_kwargs) for _ in range(self.num_prelude)])
+            self.shared_blocks = nn.ModuleList([Block(**block_kwargs) for _ in range(num_shared)])
+            self.coda_blocks = nn.ModuleList([Block(**block_kwargs) for _ in range(self.num_coda)])
+            # Effective depth counts each unrolled pass through the shared blocks.
+            effective_layers = self.num_prelude + num_shared * self.num_loops + self.num_coda
+            self.timestep_scale = TimestepScaling(effective_layers, model_dim, gamma_max=timestep_gamma_max) if use_timestep_scale else None
+        else:
+            # Standard U-Net path
+            self.num_encoder_layers = num_layers // 2
+            self.num_decoder_layers = num_layers - self.num_encoder_layers
+            self.num_skip_weights = min(self.num_encoder_layers, self.num_decoder_layers)
+            self.skip_weights = nn.Parameter(torch.ones(self.num_skip_weights, model_dim, dtype=torch.float32))
+            self.blocks = nn.ModuleList([Block(**block_kwargs) for _ in range(num_layers)])
+            self.timestep_scale = None
+
+        self.final_norm = RMSNorm()
+        # Tied embeddings reuse tok_emb.weight for the output projection.
+        self.lm_head = None if tie_embeddings else CastedLinear(model_dim, vocab_size, bias=False)
+        if self.lm_head is not None:
+            self.lm_head._zero_init = True
+        self._init_weights()
+
+    def _init_weights(self) -> None:
+        # Tied embedding gets an explicit std; all _zero_init-tagged linears
+        # (block output projections, untied head) start at zero.
+        if self.tie_embeddings:
+            nn.init.normal_(self.tok_emb.weight, mean=0.0, std=self.tied_embed_init_std)
+        for module in self.modules():
+            if isinstance(module, nn.Linear) and getattr(module, "_zero_init", False):
+                nn.init.zeros_(module.weight)
+
+    def _get_ts(self, v: int) -> tuple[Tensor | None, Tensor | None]:
+        # Timestep gammas for virtual layer v, or (None, None) when disabled.
+        if self.timestep_scale is None:
+            return None, None
+        return self.timestep_scale.get(v)
+
+    def forward(self, input_ids: Tensor, target_ids: Tensor) -> Tensor:
+        x = self.tok_emb(input_ids)
+        x = F.rms_norm(x, (x.size(-1),))
+        # x0 is the normalized embedding stream that every Block remixes in.
+        x0 = x
+
+        if self.use_recurrence:
+            # v counts virtual layer positions across prelude, unrolled loops, coda.
+            v = 0
+            for block in self.prelude_blocks:
+                ag, mg = self._get_ts(v)
+                x = block(x, x0, ag, mg)
+                v += 1
+            for _loop in range(self.num_loops):
+                for block in self.shared_blocks:
+                    ag, mg = self._get_ts(v)
+                    x = block(x, x0, ag, mg)
+                    v += 1
+            for block in self.coda_blocks:
+                ag, mg = self._get_ts(v)
+                x = block(x, x0, ag, mg)
+                v += 1
+        else:
+            # U-Net: encoder activations are stacked, then popped (LIFO) into the
+            # decoder with learned per-channel skip weights.
+            skips: list[Tensor] = []
+            for i in range(self.num_encoder_layers):
+                x = self.blocks[i](x, x0)
+                skips.append(x)
+            for i in range(self.num_decoder_layers):
+                if skips:
+                    x = x + self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips.pop()
+                x = self.blocks[self.num_encoder_layers + i](x, x0)
+
+        x = self.final_norm(x).reshape(-1, x.size(-1))
+        targets = target_ids.reshape(-1)
+        if self.tie_embeddings:
+            logits_proj = F.linear(x, self.tok_emb.weight)
+        else:
+            if self.lm_head is None:
+                raise RuntimeError("lm_head is required when tie_embeddings=False")
+            logits_proj = self.lm_head(x)
+        # Tanh softcap bounds logits to (-softcap, softcap).
+        logits = self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap)
+        return F.cross_entropy(logits.float(), targets, reduction="mean")
+
+
+@torch.no_grad()
+def recurrence_param_diagnostics(gpt: GPT) -> str:
+    """Log parameter norms for recurrence health monitoring. No forward pass."""
+    if not gpt.use_recurrence:
+        return ""
+    parts: list[str] = []
+
+    # Birkhoff alpha stats per shared block
+    for i, block in enumerate(gpt.shared_blocks):
+        if hasattr(block, "resid_mix_logit"):
+            a = torch.sigmoid(block.resid_mix_logit)
+            parts.append(f"shared{i}_alpha:mean={a.mean().item():.3f},std={a.std().item():.3f}")
+
+    # Effective MLP/attn contribution scale per virtual layer position:
+    # mlp_scale * ts_mlp_gamma[v] is the learned multiplier on MLP output.
+    # We report RMS (norm / sqrt(numel)) to make it scale-independent.
+ v = 0 + all_blocks = ( + list(gpt.prelude_blocks) + + list(gpt.shared_blocks) * gpt.num_loops + + list(gpt.coda_blocks) + ) + mlp_norms: list[str] = [] + attn_norms: list[str] = [] + for block in all_blocks: + ms = block.mlp_scale.norm().item() + asc = block.attn_scale.norm().item() + d = block.mlp_scale.numel() ** 0.5 + if gpt.timestep_scale is not None: + ms_g = gpt.timestep_scale.mlp_gamma[v].norm().item() + as_g = gpt.timestep_scale.attn_gamma[v].norm().item() + mlp_norms.append(f"v{v}:{ms * ms_g / d:.4f}") + attn_norms.append(f"v{v}:{asc * as_g / d:.4f}") + else: + mlp_norms.append(f"v{v}:{ms / d:.4f}") + attn_norms.append(f"v{v}:{asc / d:.4f}") + v += 1 + parts.append("eff_mlp_scale:[" + " ".join(mlp_norms) + "]") + parts.append("eff_attn_scale:[" + " ".join(attn_norms) + "]") + return " ".join(parts) + + +# ----------------------------- +# TRAINING +# ----------------------------- + +def main() -> None: + global zeropower_via_newtonschulz5 + + code = Path(__file__).read_text(encoding="utf-8") + args = Hyperparameters() + zeropower_via_newtonschulz5 = torch.compile(zeropower_via_newtonschulz5) + + # ----------------------------- + # DISTRIBUTED + CUDA SETUP + # ----------------------------- + + distributed = "RANK" in os.environ and "WORLD_SIZE" in os.environ + rank = int(os.environ.get("RANK", "0")) + world_size = int(os.environ.get("WORLD_SIZE", "1")) + local_rank = int(os.environ.get("LOCAL_RANK", "0")) + if world_size <= 0: + raise ValueError(f"WORLD_SIZE must be positive, got {world_size}") + if 8 % world_size != 0: + raise ValueError(f"WORLD_SIZE={world_size} must divide 8 so grad_accum_steps stays integral") + grad_accum_steps = 8 // world_size + grad_scale = 1.0 / grad_accum_steps + if not torch.cuda.is_available(): + raise RuntimeError("CUDA is required") + device = torch.device("cuda", local_rank) + torch.cuda.set_device(device) + if distributed: + dist.init_process_group(backend="nccl", device_id=device) + dist.barrier() + master_process = rank 
== 0 + + # Fast math knobs + torch.backends.cuda.matmul.allow_tf32 = True + torch.backends.cudnn.allow_tf32 = True + from torch.backends.cuda import enable_cudnn_sdp, enable_flash_sdp, enable_math_sdp, enable_mem_efficient_sdp + + enable_cudnn_sdp(False) + enable_flash_sdp(True) + enable_mem_efficient_sdp(False) + enable_math_sdp(False) + + logfile = None + if master_process: + os.makedirs("logs", exist_ok=True) + logfile = f"logs/{args.run_id}.txt" + print(logfile) + + def log0(msg: str, console: bool = True) -> None: + if not master_process: + return + if console: + print(msg) + if logfile is not None: + with open(logfile, "a", encoding="utf-8") as f: + print(msg, file=f) + + log0(code, console=False) + log0("=" * 100, console=False) + log0(f"Running Python {sys.version}", console=False) + log0(f"Running PyTorch {torch.__version__}", console=False) + log0( + subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, check=False).stdout, + console=False, + ) + log0("=" * 100, console=False) + + # ----------------------------- + # TOKENIZER + VALIDATION METRIC SETUP + # ----------------------------- + + random.seed(args.seed) + np.random.seed(args.seed) + torch.manual_seed(args.seed) + torch.cuda.manual_seed_all(args.seed) + + if not args.tokenizer_path.endswith(".model"): + raise ValueError(f"Script only setup for SentencePiece .model file: {args.tokenizer_path}") + sp = spm.SentencePieceProcessor(model_file=args.tokenizer_path) + if int(sp.vocab_size()) != args.vocab_size: + raise ValueError( + f"VOCAB_SIZE={args.vocab_size} does not match tokenizer vocab_size={int(sp.vocab_size())}" + ) + dataset_dir = Path(args.data_path).resolve() + actual_train_files = len(list(dataset_dir.glob("fineweb_train_*.bin"))) + val_tokens = load_validation_tokens(args.val_files, args.train_seq_len) + base_bytes_lut, has_leading_space_lut, is_boundary_token_lut = build_sentencepiece_luts( + sp, args.vocab_size, device + ) + log0(f"val_bpb:enabled 
tokenizer_kind=sentencepiece tokenizer_path={args.tokenizer_path}") + log0(f"train_loader:dataset:{dataset_dir.name} train_shards:{actual_train_files}") + log0(f"val_loader:shards pattern={args.val_files} tokens:{val_tokens.numel() - 1}") + + # ----------------------------- + # MODEL + OPTIMIZER SETUP + # ----------------------------- + + base_model = GPT( + vocab_size=args.vocab_size, + num_layers=args.num_layers, + model_dim=args.model_dim, + num_heads=args.num_heads, + num_kv_heads=args.num_kv_heads, + mlp_mult=args.mlp_mult, + tie_embeddings=args.tie_embeddings, + tied_embed_init_std=args.tied_embed_init_std, + logit_softcap=args.logit_softcap, + rope_base=args.rope_base, + qk_gain_init=args.qk_gain_init, + num_prelude=args.num_prelude, + num_coda=args.num_coda, + num_shared=args.num_shared, + num_loops=args.num_loops, + use_peri_norm=args.use_peri_norm, + use_birkhoff_mix=args.use_birkhoff_mix, + use_timestep_scale=args.use_timestep_scale, + timestep_gamma_max=args.timestep_gamma_max, + leaky_relu_slope=args.leaky_relu_slope, + ).to(device).bfloat16() + for module in base_model.modules(): + if isinstance(module, CastedLinear): + module.float() + restore_low_dim_params_to_fp32(base_model) + if args.disable_compile: + compiled_model = base_model + else: + compiled_model = torch.compile(base_model, dynamic=False, fullgraph=True) + model: nn.Module = DDP(compiled_model, device_ids=[local_rank], broadcast_buffers=False) if distributed else compiled_model + + # Optimizer split: + # - token embedding (Adam) uses EMBED_LR + # - untied lm_head (Adam) uses HEAD_LR + # - matrix params in transformer blocks use MATRIX_LR via Muon + # - vectors/scalars use SCALAR_LR via Adam + if base_model.use_recurrence: + all_block_lists = [base_model.prelude_blocks, base_model.shared_blocks, base_model.coda_blocks] + block_named_params = [] + for bl in all_block_lists: + block_named_params.extend(bl.named_parameters()) + else: + block_named_params = 
list(base_model.blocks.named_parameters()) + matrix_params = [ + p + for name, p in block_named_params + if p.ndim == 2 and not any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) + ] + scalar_params = [ + p + for name, p in block_named_params + if p.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) + ] + if hasattr(base_model, "skip_weights") and base_model.skip_weights.numel() > 0: + scalar_params.append(base_model.skip_weights) + if base_model.timestep_scale is not None: + for p in base_model.timestep_scale.parameters(): + scalar_params.append(p) + token_lr = args.tied_embed_lr if args.tie_embeddings else args.embed_lr + optimizer_tok = torch.optim.Adam( + [{"params": [base_model.tok_emb.weight], "lr": token_lr, "base_lr": token_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=True, + ) + optimizer_muon = Muon( + matrix_params, + lr=args.matrix_lr, + momentum=args.muon_momentum, + backend_steps=args.muon_backend_steps, + ) + for group in optimizer_muon.param_groups: + group["base_lr"] = args.matrix_lr + optimizer_scalar = torch.optim.Adam( + [{"params": scalar_params, "lr": args.scalar_lr, "base_lr": args.scalar_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=True, + ) + optimizers: list[torch.optim.Optimizer] = [optimizer_tok, optimizer_muon, optimizer_scalar] + if base_model.lm_head is not None: + optimizer_head = torch.optim.Adam( + [{"params": [base_model.lm_head.weight], "lr": args.head_lr, "base_lr": args.head_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=True, + ) + optimizers.insert(1, optimizer_head) + + n_params = sum(p.numel() for p in base_model.parameters()) + log0(f"model_params:{n_params}") + log0(f"world_size:{world_size} grad_accum_steps:{grad_accum_steps}") + log0("sdp_backends:cudnn=False flash=True mem_efficient=False math=False") + log0(f"flash_attention_3:{FA3_AVAILABLE}") + log0(f"attention_mode:gqa num_heads:{args.num_heads} 
num_kv_heads:{args.num_kv_heads}") + log0( + f"tie_embeddings:{args.tie_embeddings} embed_lr:{token_lr} " + f"head_lr:{args.head_lr if base_model.lm_head is not None else 0.0} " + f"matrix_lr:{args.matrix_lr} scalar_lr:{args.scalar_lr}" + ) + log0( + f"train_batch_tokens:{args.train_batch_tokens} train_seq_len:{args.train_seq_len} " + f"iterations:{args.iterations} warmup_steps:{args.warmup_steps} " + f"max_wallclock_seconds:{args.max_wallclock_seconds:.3f}" + ) + log0(f"seed:{args.seed}") + if base_model.use_recurrence: + eff = base_model.num_prelude + len(base_model.shared_blocks) * base_model.num_loops + base_model.num_coda + log0(f"recurrence:enabled prelude:{base_model.num_prelude} shared:{len(base_model.shared_blocks)} " + f"loops:{base_model.num_loops} coda:{base_model.num_coda} effective_layers:{eff}") + log0(f"peri_norm:{args.use_peri_norm} birkhoff_mix:{args.use_birkhoff_mix}") + if base_model.timestep_scale is not None: + ts_params = sum(p.numel() for p in base_model.timestep_scale.parameters()) + log0(f"timestep_scale:enabled params:{ts_params}") + else: + log0("timestep_scale:disabled") + else: + log0(f"recurrence:disabled num_layers:{args.num_layers}") + compile_mode = "disabled" if args.disable_compile else "fullgraph=True" + log0(f"compile_mode:{compile_mode}") + + # ----------------------------- + # DATA LOADER & MODEL WARMUP + # ----------------------------- + + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + + def zero_grad_all() -> None: + for opt in optimizers: + opt.zero_grad(set_to_none=True) + + max_wallclock_ms = 1000.0 * args.max_wallclock_seconds if args.max_wallclock_seconds > 0 else None + + def lr_mul(step: int, elapsed_ms: float) -> float: + if args.warmdown_iters <= 0: + return 1.0 + if max_wallclock_ms is None: + warmdown_start = max(args.iterations - args.warmdown_iters, 0) + return max((args.iterations - step) / max(args.warmdown_iters, 1), 0.0) if warmdown_start <= step < args.iterations else 
1.0 + step_ms = elapsed_ms / max(step, 1) + warmdown_ms = args.warmdown_iters * step_ms + remaining_ms = max(max_wallclock_ms - elapsed_ms, 0.0) + return remaining_ms / max(warmdown_ms, 1e-9) if remaining_ms <= warmdown_ms else 1.0 + + # Warmup primes the compiled forward/backward/optimizer paths, then we restore the + # initial weights/optimizer state so measured training starts from the true init. + if args.warmup_steps > 0: + initial_model_state = {name: tensor.detach().cpu().clone() for name, tensor in base_model.state_dict().items()} + initial_optimizer_states = [copy.deepcopy(opt.state_dict()) for opt in optimizers] + model.train() + for warmup_step in range(args.warmup_steps): + zero_grad_all() + for micro_step in range(grad_accum_steps): + if distributed: + model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 + x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + warmup_loss = model(x, y) + (warmup_loss * grad_scale).backward() + for opt in optimizers: + opt.step() + zero_grad_all() + if args.warmup_steps <= 20 or (warmup_step + 1) % 10 == 0 or warmup_step + 1 == args.warmup_steps: + log0(f"warmup_step:{warmup_step + 1}/{args.warmup_steps}") + base_model.load_state_dict(initial_model_state, strict=True) + for opt, state in zip(optimizers, initial_optimizer_states, strict=True): + opt.load_state_dict(state) + zero_grad_all() + if distributed: + model.require_backward_grad_sync = True + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + + # ----------------------------- + # MAIN TRAINING LOOP + # ----------------------------- + + training_time_ms = 0.0 + stop_after_step: int | None = None + torch.cuda.synchronize() + t0 = time.perf_counter() + + step = 0 + while True: + last_step = step == args.iterations or (stop_after_step is not None and step >= stop_after_step) + + should_validate = 
last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0) + if should_validate: + torch.cuda.synchronize() + training_time_ms += 1000.0 * (time.perf_counter() - t0) + val_loss, val_bpb = eval_val( + args, + model, + rank, + world_size, + device, + grad_accum_steps, + val_tokens, + base_bytes_lut, + has_leading_space_lut, + is_boundary_token_lut, + ) + log0( + f"step:{step}/{args.iterations} val_loss:{val_loss:.4f} val_bpb:{val_bpb:.4f} " + f"train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms / max(step, 1):.2f}ms" + ) + torch.cuda.synchronize() + t0 = time.perf_counter() + + if last_step: + if stop_after_step is not None and step < args.iterations: + log0( + f"stopping_early: wallclock_cap train_time:{training_time_ms:.0f}ms " + f"step:{step}/{args.iterations}" + ) + break + + elapsed_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + scale = lr_mul(step, elapsed_ms) + zero_grad_all() + train_loss = torch.zeros((), device=device) + for micro_step in range(grad_accum_steps): + if distributed: + model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 + x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + loss = model(x, y) + train_loss += loss.detach() + (loss * grad_scale).backward() + train_loss /= grad_accum_steps + + frac = min(step / args.muon_momentum_warmup_steps, 1.0) if args.muon_momentum_warmup_steps > 0 else 1.0 + muon_momentum = (1 - frac) * args.muon_momentum_warmup_start + frac * args.muon_momentum + for group in optimizer_muon.param_groups: + group["momentum"] = muon_momentum + + for opt in optimizers: + for group in opt.param_groups: + group["lr"] = group["base_lr"] * scale + + if args.grad_clip_norm > 0: + torch.nn.utils.clip_grad_norm_(base_model.parameters(), args.grad_clip_norm) + for opt in optimizers: + opt.step() + zero_grad_all() + + step += 1 + approx_training_time_ms = 
training_time_ms + 1000.0 * (time.perf_counter() - t0) + should_log_train = ( + args.train_log_every > 0 + and (step <= 10 or step % args.train_log_every == 0 or stop_after_step is not None) + ) + if should_log_train: + log0( + f"step:{step}/{args.iterations} train_loss:{train_loss.item():.4f} " + f"train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms / step:.2f}ms" + ) + if base_model.use_recurrence and step % args.train_log_every == 0: + diag = recurrence_param_diagnostics(base_model) + if diag: + log0(f"step:{step} {diag}") + + # Needed to sync whether we've reached the wallclock cap. + reached_cap = max_wallclock_ms is not None and approx_training_time_ms >= max_wallclock_ms + if distributed and max_wallclock_ms is not None: + reached_cap_tensor = torch.tensor(int(reached_cap), device=device) + dist.all_reduce(reached_cap_tensor, op=dist.ReduceOp.MAX) + reached_cap = bool(reached_cap_tensor.item()) + if stop_after_step is None and reached_cap: + stop_after_step = step + + log0( + f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB " + f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB" + ) + + # ----------------------------- + # SERIALIZATION + ROUNDTRIP VALIDATION + # ----------------------------- + # Save the raw state (useful for debugging/loading in PyTorch directly), then always produce + # the compressed int8+zlib artifact and validate the round-tripped weights. 
+ + if master_process: + torch.save(base_model.state_dict(), "final_model.pt") + model_bytes = os.path.getsize("final_model.pt") + code_bytes = len(code.encode("utf-8")) + log0(f"Serialized model: {model_bytes} bytes") + log0(f"Code size: {code_bytes} bytes") + log0(f"Total submission size: {model_bytes + code_bytes} bytes") + + quant_obj, quant_stats = quantize_state_dict_int8(base_model.state_dict()) + quant_buf = io.BytesIO() + torch.save(quant_obj, quant_buf) + quant_raw = quant_buf.getvalue() + quant_blob = zlib.compress(quant_raw, level=9) + quant_raw_bytes = len(quant_raw) + if master_process: + with open("final_model.int8.ptz", "wb") as f: + f.write(quant_blob) + quant_file_bytes = os.path.getsize("final_model.int8.ptz") + code_bytes = len(code.encode("utf-8")) + ratio = quant_stats["baseline_tensor_bytes"] / max(quant_stats["int8_payload_bytes"], 1) + log0( + f"Serialized model int8+zlib: {quant_file_bytes} bytes " + f"(payload:{quant_stats['int8_payload_bytes']} raw_torch:{quant_raw_bytes} payload_ratio:{ratio:.2f}x)" + ) + log0(f"Total submission size int8+zlib: {quant_file_bytes + code_bytes} bytes") + + if distributed: + dist.barrier() + with open("final_model.int8.ptz", "rb") as f: + quant_blob_disk = f.read() + quant_state = torch.load(io.BytesIO(zlib.decompress(quant_blob_disk)), map_location="cpu") + base_model.load_state_dict(dequantize_state_dict_int8(quant_state), strict=True) + torch.cuda.synchronize() + t_qeval = time.perf_counter() + q_val_loss, q_val_bpb = eval_val( + args, + model, + rank, + world_size, + device, + grad_accum_steps, + val_tokens, + base_bytes_lut, + has_leading_space_lut, + is_boundary_token_lut, + ) + torch.cuda.synchronize() + log0( + f"final_int8_zlib_roundtrip val_loss:{q_val_loss:.4f} val_bpb:{q_val_bpb:.4f} " + f"eval_time:{1000.0 * (time.perf_counter() - t_qeval):.0f}ms" + ) + log0(f"final_int8_zlib_roundtrip_exact val_loss:{q_val_loss:.8f} val_bpb:{q_val_bpb:.8f}") + + if distributed: + dist.destroy_process_group() 
+ + +if __name__ == "__main__": + main() + +==================================================================================================== +Running Python 3.12.3 (main, Nov 6 2025, 13:44:16) [GCC 13.3.0] +Running PyTorch 2.9.1+cu128 +Wed Mar 25 21:36:33 2026 ++-----------------------------------------------------------------------------------------+ +| NVIDIA-SMI 580.126.09 Driver Version: 580.126.09 CUDA Version: 13.0 | ++-----------------------------------------+------------------------+----------------------+ +| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC | +| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. | +| | | MIG M. | +|=========================================+========================+======================| +| 0 NVIDIA H100 80GB HBM3 On | 00000000:0A:00.0 Off | 0 | +| N/A 35C P0 121W / 700W | 1521MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 1 NVIDIA H100 80GB HBM3 On | 00000000:18:00.0 Off | 0 | +| N/A 29C P0 114W / 700W | 1521MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 2 NVIDIA H100 80GB HBM3 On | 00000000:2F:00.0 Off | 0 | +| N/A 29C P0 114W / 700W | 1521MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 3 NVIDIA H100 80GB HBM3 On | 00000000:38:00.0 Off | 0 | +| N/A 32C P0 118W / 700W | 1521MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 4 NVIDIA H100 80GB HBM3 On | 00000000:87:00.0 Off | 0 | +| N/A 32C P0 115W / 700W | 1521MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 5 NVIDIA H100 80GB HBM3 On | 00000000:90:00.0 Off | 0 | +| N/A 30C 
P0 118W / 700W | 1521MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 6 NVIDIA H100 80GB HBM3 On | 00000000:BE:00.0 Off | 0 | +| N/A 31C P0 121W / 700W | 1521MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 7 NVIDIA H100 80GB HBM3 On | 00000000:C7:00.0 Off | 0 | +| N/A 32C P0 118W / 700W | 1521MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ + ++-----------------------------------------------------------------------------------------+ +| Processes: | +| GPU GI CI PID Type Process name GPU Memory | +| ID ID Usage | +|=========================================================================================| +| 0 N/A N/A 1046 C /usr/local/bin/python 1512MiB | +| 1 N/A N/A 1047 C /usr/local/bin/python 1512MiB | +| 2 N/A N/A 1048 C /usr/local/bin/python 1512MiB | +| 3 N/A N/A 1049 C /usr/local/bin/python 1512MiB | +| 4 N/A N/A 1050 C /usr/local/bin/python 1512MiB | +| 5 N/A N/A 1051 C /usr/local/bin/python 1512MiB | +| 6 N/A N/A 1052 C /usr/local/bin/python 1512MiB | +| 7 N/A N/A 1053 C /usr/local/bin/python 1512MiB | ++-----------------------------------------------------------------------------------------+ + +==================================================================================================== +val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path=./data/tokenizers/fineweb_1024_bpe.model +train_loader:dataset:fineweb10B_sp1024 train_shards:80 +val_loader:shards pattern=./data/datasets/fineweb10B_sp1024/fineweb_val_*.bin tokens:62021632 +model_params:11557936 +world_size:8 grad_accum_steps:1 +sdp_backends:cudnn=False flash=True mem_efficient=False math=False +flash_attention_3:True +attention_mode:gqa num_heads:8 num_kv_heads:4 +tie_embeddings:True embed_lr:0.05 
head_lr:0.0 matrix_lr:0.04 scalar_lr:0.04 +train_batch_tokens:524288 train_seq_len:1024 iterations:20000 warmup_steps:20 max_wallclock_seconds:600.000 +seed:1337 +recurrence:enabled prelude:1 shared:4 loops:3 coda:1 effective_layers:14 +peri_norm:True birkhoff_mix:True +timestep_scale:enabled params:14336 +compile_mode:fullgraph=True +warmup_step:1/20 +warmup_step:2/20 +warmup_step:3/20 +warmup_step:4/20 +warmup_step:5/20 +warmup_step:6/20 +warmup_step:7/20 +warmup_step:8/20 +warmup_step:9/20 +warmup_step:10/20 +warmup_step:11/20 +warmup_step:12/20 +warmup_step:13/20 +warmup_step:14/20 +warmup_step:15/20 +warmup_step:16/20 +warmup_step:17/20 +warmup_step:18/20 +warmup_step:19/20 +warmup_step:20/20 +step:0/20000 val_loss:6.9379 val_bpb:4.1090 train_time:0ms step_avg:0.01ms +step:1/20000 train_loss:6.9377 train_time:34ms step_avg:33.54ms +step:2/20000 train_loss:9.6217 train_time:88ms step_avg:44.00ms +step:3/20000 train_loss:7.3812 train_time:144ms step_avg:48.09ms +step:4/20000 train_loss:9.1484 train_time:201ms step_avg:50.14ms +step:5/20000 train_loss:8.6154 train_time:257ms step_avg:51.49ms +step:6/20000 train_loss:8.3173 train_time:314ms step_avg:52.32ms +step:7/20000 train_loss:6.9135 train_time:370ms step_avg:52.87ms +step:8/20000 train_loss:6.1894 train_time:427ms step_avg:53.36ms +step:9/20000 train_loss:5.6406 train_time:483ms step_avg:53.70ms +step:10/20000 train_loss:5.4038 train_time:540ms step_avg:54.00ms +step:200/20000 train_loss:2.7823 train_time:11733ms step_avg:58.67ms +step:200 shared0_alpha:mean=0.448,std=0.049 shared1_alpha:mean=0.476,std=0.042 shared2_alpha:mean=0.489,std=0.039 shared3_alpha:mean=0.516,std=0.040 eff_mlp_scale:[v0:27.5484 v1:26.7534 v2:27.3004 v3:28.6959 v4:29.6441 v5:31.4062 v6:28.1404 v7:28.1103 v8:29.0391 v9:30.2430 v10:28.0004 v11:30.8921 v12:32.8202 v13:58.7413] eff_attn_scale:[v0:12.9666 v1:13.2604 v2:10.8659 v3:10.9078 v4:9.5361 v5:11.1804 v6:9.5996 v7:9.4641 v8:8.8824 v9:11.2670 v10:10.2124 v11:9.9453 v12:9.4207 
v13:16.9741] +step:200/20000 val_loss:2.7597 val_bpb:1.6344 train_time:11795ms step_avg:58.98ms +step:400/20000 train_loss:2.3575 train_time:23561ms step_avg:58.90ms +step:400 shared0_alpha:mean=0.454,std=0.050 shared1_alpha:mean=0.497,std=0.043 shared2_alpha:mean=0.515,std=0.039 shared3_alpha:mean=0.550,std=0.043 eff_mlp_scale:[v0:34.2854 v1:33.6046 v2:34.5982 v3:37.3141 v4:36.5859 v5:42.7999 v6:38.2653 v7:36.6623 v8:34.8208 v9:38.9546 v10:35.2360 v11:36.8252 v12:35.6232 v13:73.9149] eff_attn_scale:[v0:6.4260 v1:6.6149 v2:5.7415 v3:5.8219 v4:5.4229 v5:6.0857 v6:5.4070 v7:5.2637 v8:5.3181 v9:5.9387 v10:5.1841 v11:4.7586 v12:4.7941 v13:9.9713] +step:400/20000 val_loss:2.5650 val_bpb:1.5192 train_time:23593ms step_avg:58.98ms +step:600/20000 train_loss:2.5758 train_time:35380ms step_avg:58.97ms +step:600 shared0_alpha:mean=0.454,std=0.050 shared1_alpha:mean=0.509,std=0.043 shared2_alpha:mean=0.532,std=0.040 shared3_alpha:mean=0.574,std=0.045 eff_mlp_scale:[v0:40.3946 v1:37.9918 v2:38.8994 v3:42.4052 v4:40.5967 v5:49.2288 v6:44.6498 v7:42.9203 v8:39.4368 v9:44.2346 v10:39.0685 v11:39.8300 v12:37.1170 v13:88.3020] eff_attn_scale:[v0:3.0049 v1:3.5136 v2:3.1467 v3:3.4297 v4:3.4763 v5:3.3945 v6:3.2047 v7:3.1516 v8:3.5922 v9:3.2953 v10:2.8958 v11:2.6696 v12:2.9162 v13:6.4813] +step:600/20000 val_loss:2.4776 val_bpb:1.4674 train_time:35411ms step_avg:59.02ms +step:800/20000 train_loss:2.3367 train_time:47210ms step_avg:59.01ms +step:800 shared0_alpha:mean=0.453,std=0.050 shared1_alpha:mean=0.517,std=0.045 shared2_alpha:mean=0.544,std=0.042 shared3_alpha:mean=0.590,std=0.046 eff_mlp_scale:[v0:47.3770 v1:41.6239 v2:42.4979 v3:45.9607 v4:43.3238 v5:54.2598 v6:49.1711 v7:47.3749 v8:43.1539 v9:47.9419 v10:41.6198 v11:41.5414 v12:38.7366 v13:99.1780] eff_attn_scale:[v0:1.9183 v1:2.3710 v2:2.2032 v3:2.3913 v4:2.6813 v5:2.3556 v6:2.2797 v7:2.2455 v8:2.8579 v9:2.2633 v10:2.0349 v11:1.8226 v12:2.1354 v13:4.8487] +step:800/20000 val_loss:2.4226 val_bpb:1.4348 train_time:47241ms 
step_avg:59.05ms +step:1000/20000 train_loss:2.4170 train_time:59041ms step_avg:59.04ms +step:1000 shared0_alpha:mean=0.449,std=0.051 shared1_alpha:mean=0.523,std=0.047 shared2_alpha:mean=0.554,std=0.043 shared3_alpha:mean=0.603,std=0.048 eff_mlp_scale:[v0:54.0667 v1:45.0744 v2:45.5208 v3:49.3566 v4:46.1017 v5:57.9254 v6:52.7463 v7:50.8083 v8:46.4510 v9:51.4040 v10:43.8951 v11:43.5500 v12:40.8629 v13:108.3864] eff_attn_scale:[v0:1.4628 v1:1.8692 v2:1.7604 v3:1.9248 v4:2.2411 v5:1.8034 v6:1.7998 v7:1.8108 v8:2.3981 v9:1.7639 v10:1.5830 v11:1.4246 v12:1.7486 v13:3.9293] +step:1000/20000 val_loss:2.3828 val_bpb:1.4112 train_time:59072ms step_avg:59.07ms +step:1200/20000 train_loss:2.4374 train_time:70896ms step_avg:59.08ms +step:1200 shared0_alpha:mean=0.447,std=0.052 shared1_alpha:mean=0.527,std=0.047 shared2_alpha:mean=0.561,std=0.044 shared3_alpha:mean=0.612,std=0.049 eff_mlp_scale:[v0:59.9127 v1:48.1806 v2:48.0975 v3:52.2504 v4:48.5522 v5:61.3565 v6:56.2371 v7:54.1033 v8:49.6232 v9:53.8836 v10:46.0626 v11:45.0243 v12:42.8401 v13:117.2016] eff_attn_scale:[v0:1.2315 v1:1.5553 v2:1.4982 v3:1.6681 v4:1.9516 v5:1.5137 v6:1.5342 v7:1.5631 v8:2.0957 v9:1.4900 v10:1.3364 v11:1.2073 v12:1.4997 v13:3.4107] +step:1200/20000 val_loss:2.3552 val_bpb:1.3949 train_time:70929ms step_avg:59.11ms +step:1400/20000 train_loss:2.4865 train_time:82745ms step_avg:59.10ms +step:1400 shared0_alpha:mean=0.443,std=0.052 shared1_alpha:mean=0.531,std=0.048 shared2_alpha:mean=0.567,std=0.045 shared3_alpha:mean=0.620,std=0.050 eff_mlp_scale:[v0:65.7036 v1:51.0674 v2:51.0106 v3:55.1941 v4:50.6375 v5:64.3369 v6:58.9456 v7:57.0843 v8:52.4589 v9:56.6969 v10:47.7988 v11:46.6881 v12:44.6265 v13:124.5906] eff_attn_scale:[v0:1.0368 v1:1.3702 v2:1.3014 v3:1.4703 v4:1.7817 v5:1.3202 v6:1.3401 v7:1.3669 v8:1.8930 v9:1.2925 v10:1.1518 v11:1.0510 v12:1.3363 v13:3.0387] +step:1400/20000 val_loss:2.3352 val_bpb:1.3830 train_time:82776ms step_avg:59.13ms +step:1600/20000 train_loss:2.1583 train_time:94592ms 
step_avg:59.12ms +step:1600 shared0_alpha:mean=0.440,std=0.054 shared1_alpha:mean=0.534,std=0.049 shared2_alpha:mean=0.571,std=0.046 shared3_alpha:mean=0.626,std=0.050 eff_mlp_scale:[v0:71.6690 v1:53.9213 v2:53.3864 v3:57.5377 v4:52.9029 v5:66.5846 v6:61.4519 v7:59.4556 v8:54.7527 v9:58.4147 v10:49.5456 v11:47.9481 v12:46.0588 v13:131.3244] eff_attn_scale:[v0:0.8817 v1:1.2447 v2:1.1989 v3:1.3474 v4:1.6413 v5:1.1715 v6:1.2308 v7:1.2737 v8:1.7483 v9:1.1349 v10:1.0504 v11:0.9526 v12:1.2250 v13:2.7730] +step:1600/20000 val_loss:2.3234 val_bpb:1.3760 train_time:94624ms step_avg:59.14ms +step:1800/20000 train_loss:2.2560 train_time:106476ms step_avg:59.15ms +step:1800 shared0_alpha:mean=0.437,std=0.055 shared1_alpha:mean=0.536,std=0.049 shared2_alpha:mean=0.575,std=0.047 shared3_alpha:mean=0.631,std=0.051 eff_mlp_scale:[v0:77.0559 v1:56.5370 v2:55.8204 v3:59.9762 v4:55.0234 v5:69.4242 v6:64.0178 v7:61.9235 v8:57.2847 v9:60.6942 v10:51.1362 v11:49.4609 v12:48.0513 v13:137.3129] eff_attn_scale:[v0:0.8088 v1:1.1540 v2:1.0999 v3:1.2589 v4:1.5357 v5:1.0727 v6:1.1151 v7:1.1628 v8:1.6388 v9:1.0523 v10:0.9529 v11:0.8646 v12:1.1288 v13:2.5785] +step:1800/20000 val_loss:2.3070 val_bpb:1.3663 train_time:106508ms step_avg:59.17ms +step:2000/20000 train_loss:2.3081 train_time:118327ms step_avg:59.16ms +step:2000 shared0_alpha:mean=0.434,std=0.055 shared1_alpha:mean=0.538,std=0.050 shared2_alpha:mean=0.579,std=0.048 shared3_alpha:mean=0.636,std=0.052 eff_mlp_scale:[v0:82.7543 v1:59.0391 v2:58.7532 v3:62.3953 v4:57.0309 v5:71.6903 v6:66.2958 v7:63.9749 v8:59.7102 v9:62.8344 v10:52.7985 v11:50.5481 v12:49.3758 v13:142.4716] eff_attn_scale:[v0:0.7283 v1:1.0662 v2:1.0360 v3:1.1851 v4:1.4729 v5:0.9834 v6:1.0360 v7:1.1012 v8:1.5511 v9:0.9640 v10:0.8796 v11:0.8049 v12:1.0601 v13:2.4340] +step:2000/20000 val_loss:2.2943 val_bpb:1.3588 train_time:118359ms step_avg:59.18ms +step:2200/20000 train_loss:2.1339 train_time:130184ms step_avg:59.17ms +step:2200 shared0_alpha:mean=0.431,std=0.056 
shared1_alpha:mean=0.540,std=0.051 shared2_alpha:mean=0.582,std=0.049 shared3_alpha:mean=0.640,std=0.053 eff_mlp_scale:[v0:87.8139 v1:61.0818 v2:60.7661 v3:64.6332 v4:59.1065 v5:73.8962 v6:68.4121 v7:66.2390 v8:61.8285 v9:64.4990 v10:54.3273 v11:52.1883 v12:51.3293 v13:148.2095] eff_attn_scale:[v0:0.6803 v1:1.0172 v2:0.9873 v3:1.1284 v4:1.4244 v5:0.9179 v6:0.9921 v7:1.0319 v8:1.5017 v9:0.8990 v10:0.8380 v11:0.7667 v12:1.0269 v13:2.3139] +step:2200/20000 val_loss:2.2862 val_bpb:1.3540 train_time:130215ms step_avg:59.19ms +step:2400/20000 train_loss:2.2593 train_time:142034ms step_avg:59.18ms +step:2400 shared0_alpha:mean=0.428,std=0.057 shared1_alpha:mean=0.541,std=0.052 shared2_alpha:mean=0.585,std=0.050 shared3_alpha:mean=0.644,std=0.054 eff_mlp_scale:[v0:93.0263 v1:63.7902 v2:62.8558 v3:66.7213 v4:61.5684 v5:76.3747 v6:70.6107 v7:68.3486 v8:64.3311 v9:66.8279 v10:55.9172 v11:53.7025 v12:52.8857 v13:154.3660] eff_attn_scale:[v0:0.6218 v1:0.9561 v2:0.9382 v3:1.0693 v4:1.3684 v5:0.8596 v6:0.9243 v7:0.9802 v8:1.4387 v9:0.8458 v10:0.7764 v11:0.7222 v12:0.9898 v13:2.1780] +step:2400/20000 val_loss:2.2751 val_bpb:1.3475 train_time:142066ms step_avg:59.19ms +step:2600/20000 train_loss:2.4727 train_time:153881ms step_avg:59.19ms +step:2600 shared0_alpha:mean=0.425,std=0.057 shared1_alpha:mean=0.543,std=0.052 shared2_alpha:mean=0.587,std=0.051 shared3_alpha:mean=0.647,std=0.054 eff_mlp_scale:[v0:97.4525 v1:65.8375 v2:64.8462 v3:69.2610 v4:63.6741 v5:78.1271 v6:72.6938 v7:70.4978 v8:66.4774 v9:68.4710 v10:57.4116 v11:54.8316 v12:54.4634 v13:158.4925] eff_attn_scale:[v0:0.5899 v1:0.9353 v2:0.9162 v3:1.0517 v4:1.3262 v5:0.8269 v6:0.8934 v7:0.9590 v8:1.3737 v9:0.8088 v10:0.7475 v11:0.6996 v12:0.9352 v13:2.1238] +step:2600/20000 val_loss:2.2874 val_bpb:1.3547 train_time:153913ms step_avg:59.20ms +step:2800/20000 train_loss:2.2975 train_time:165743ms step_avg:59.19ms +step:2800 shared0_alpha:mean=0.423,std=0.058 shared1_alpha:mean=0.545,std=0.053 
shared2_alpha:mean=0.590,std=0.052 shared3_alpha:mean=0.649,std=0.055 eff_mlp_scale:[v0:102.8439 v1:68.1906 v2:67.2278 v3:71.7107 v4:65.6281 v5:79.7033 v6:74.3264 v7:72.5445 v8:68.4639 v9:69.9617 v10:58.8765 v11:56.2846 v12:56.3106 v13:163.4298] eff_attn_scale:[v0:0.5667 v1:0.9016 v2:0.8768 v3:1.0077 v4:1.2826 v5:0.7999 v6:0.8543 v7:0.9177 v8:1.3507 v9:0.7734 v10:0.7239 v11:0.6613 v12:0.9214 v13:2.0292] +step:2800/20000 val_loss:2.2621 val_bpb:1.3398 train_time:165774ms step_avg:59.21ms +step:3000/20000 train_loss:2.2850 train_time:177586ms step_avg:59.20ms +step:3000 shared0_alpha:mean=0.420,std=0.058 shared1_alpha:mean=0.545,std=0.054 shared2_alpha:mean=0.592,std=0.052 shared3_alpha:mean=0.652,std=0.056 eff_mlp_scale:[v0:107.1784 v1:70.3300 v2:69.3097 v3:73.8286 v4:67.6902 v5:81.9770 v6:76.0716 v7:74.6724 v8:70.1517 v9:71.6739 v10:60.0120 v11:57.7973 v12:57.8444 v13:168.6794] eff_attn_scale:[v0:0.5273 v1:0.8750 v2:0.8527 v3:0.9793 v4:1.2626 v5:0.7700 v6:0.8307 v7:0.8778 v8:1.3090 v9:0.7394 v10:0.6989 v11:0.6352 v12:0.8916 v13:1.9865] +step:3000/20000 val_loss:2.2550 val_bpb:1.3355 train_time:177618ms step_avg:59.21ms +step:3200/20000 train_loss:2.2509 train_time:189434ms step_avg:59.20ms +step:3200 shared0_alpha:mean=0.417,std=0.059 shared1_alpha:mean=0.546,std=0.054 shared2_alpha:mean=0.595,std=0.053 shared3_alpha:mean=0.654,std=0.056 eff_mlp_scale:[v0:111.8766 v1:72.4473 v2:71.3405 v3:75.8253 v4:69.6745 v5:84.2200 v6:78.1756 v7:76.6773 v8:71.7481 v9:72.9001 v10:61.5152 v11:58.7859 v12:59.3062 v13:172.0224] eff_attn_scale:[v0:0.5131 v1:0.8464 v2:0.8337 v3:0.9520 v4:1.2488 v5:0.7401 v6:0.8033 v7:0.8647 v8:1.2951 v9:0.7103 v10:0.6774 v11:0.6201 v12:0.8891 v13:1.9186] +step:3200/20000 val_loss:2.2506 val_bpb:1.3329 train_time:189467ms step_avg:59.21ms +step:3400/20000 train_loss:2.2156 train_time:201279ms step_avg:59.20ms +step:3400 shared0_alpha:mean=0.415,std=0.059 shared1_alpha:mean=0.547,std=0.054 shared2_alpha:mean=0.597,std=0.054 
shared3_alpha:mean=0.657,std=0.057 eff_mlp_scale:[v0:116.7305 v1:74.6774 v2:73.4077 v3:77.6875 v4:71.9578 v5:86.1310 v6:79.8848 v7:78.5507 v8:74.0618 v9:74.6774 v10:63.0442 v11:60.4236 v12:61.0169 v13:176.9400] eff_attn_scale:[v0:0.4849 v1:0.8381 v2:0.8049 v3:0.9454 v4:1.2184 v5:0.7281 v6:0.7836 v7:0.8452 v8:1.2641 v9:0.6942 v10:0.6558 v11:0.6143 v12:0.8631 v13:1.8646] +step:3400/20000 val_loss:2.2473 val_bpb:1.3310 train_time:201311ms step_avg:59.21ms +step:3600/20000 train_loss:2.1850 train_time:213120ms step_avg:59.20ms +step:3600 shared0_alpha:mean=0.413,std=0.060 shared1_alpha:mean=0.548,std=0.055 shared2_alpha:mean=0.599,std=0.054 shared3_alpha:mean=0.659,std=0.057 eff_mlp_scale:[v0:121.4435 v1:77.4289 v2:75.4982 v3:80.2307 v4:73.4860 v5:88.0927 v6:82.0443 v7:80.2307 v8:76.0347 v9:76.5016 v10:64.1516 v11:61.4811 v12:62.4419 v13:180.8942] eff_attn_scale:[v0:0.4713 v1:0.8233 v2:0.7938 v3:0.9230 v4:1.1999 v5:0.7141 v6:0.7726 v7:0.8281 v8:1.2351 v9:0.6847 v10:0.6410 v11:0.5995 v12:0.8535 v13:1.8283] +step:3600/20000 val_loss:2.2399 val_bpb:1.3266 train_time:213152ms step_avg:59.21ms +step:3800/20000 train_loss:2.2832 train_time:224971ms step_avg:59.20ms +step:3800 shared0_alpha:mean=0.411,std=0.061 shared1_alpha:mean=0.549,std=0.055 shared2_alpha:mean=0.600,std=0.055 shared3_alpha:mean=0.661,std=0.057 eff_mlp_scale:[v0:126.1614 v1:79.1049 v2:77.7751 v3:82.4223 v4:75.6845 v5:89.8707 v6:83.9617 v7:81.9816 v8:77.8346 v9:77.7007 v10:65.8437 v11:63.0289 v12:64.0738 v13:184.9535] eff_attn_scale:[v0:0.4496 v1:0.8116 v2:0.7936 v3:0.8995 v4:1.2013 v5:0.7034 v6:0.7516 v7:0.8181 v8:1.2162 v9:0.6660 v10:0.6257 v11:0.5868 v12:0.8374 v13:1.7982] +step:3800/20000 val_loss:2.2354 val_bpb:1.3239 train_time:225002ms step_avg:59.21ms +step:4000/20000 train_loss:2.2207 train_time:236822ms step_avg:59.21ms +step:4000 shared0_alpha:mean=0.409,std=0.061 shared1_alpha:mean=0.550,std=0.055 shared2_alpha:mean=0.602,std=0.056 shared3_alpha:mean=0.663,std=0.057 eff_mlp_scale:[v0:130.4283 
v1:81.3678 v2:79.3461 v3:84.2017 v4:77.4413 v5:91.7753 v6:85.5869 v7:83.7562 v8:79.6166 v9:79.4755 v10:66.8647 v11:64.1537 v12:65.6946 v13:189.9913] eff_attn_scale:[v0:0.4358 v1:0.7930 v2:0.7767 v3:0.8925 v4:1.1616 v5:0.6856 v6:0.7350 v7:0.7910 v8:1.2061 v9:0.6526 v10:0.6139 v11:0.5795 v12:0.8205 v13:1.7570] +step:4000/20000 val_loss:2.2301 val_bpb:1.3208 train_time:236854ms step_avg:59.21ms +step:4200/20000 train_loss:2.2356 train_time:248914ms step_avg:59.27ms +step:4200 shared0_alpha:mean=0.407,std=0.061 shared1_alpha:mean=0.551,std=0.056 shared2_alpha:mean=0.604,std=0.056 shared3_alpha:mean=0.665,std=0.058 eff_mlp_scale:[v0:134.3793 v1:83.5345 v2:81.4617 v3:86.8770 v4:79.6370 v5:93.5586 v6:87.3126 v7:85.5265 v8:81.3969 v9:80.6705 v10:67.9598 v11:65.7204 v12:67.3175 v13:193.9883] eff_attn_scale:[v0:0.4229 v1:0.7838 v2:0.7523 v3:0.8786 v4:1.1727 v5:0.6701 v6:0.7157 v7:0.7908 v8:1.1975 v9:0.6295 v10:0.5896 v11:0.5648 v12:0.8115 v13:1.7432] +step:4200/20000 val_loss:2.2262 val_bpb:1.3185 train_time:248945ms step_avg:59.27ms +step:4400/20000 train_loss:2.1765 train_time:260750ms step_avg:59.26ms +step:4400 shared0_alpha:mean=0.405,std=0.061 shared1_alpha:mean=0.552,std=0.056 shared2_alpha:mean=0.606,std=0.057 shared3_alpha:mean=0.667,std=0.058 eff_mlp_scale:[v0:140.1213 v1:85.2745 v2:83.1819 v3:88.7375 v4:81.4193 v5:95.3918 v6:89.0910 v7:87.3723 v8:83.1990 v9:82.3838 v10:69.5455 v11:66.8944 v12:68.9617 v13:197.1806] eff_attn_scale:[v0:0.4059 v1:0.7894 v2:0.7553 v3:0.8802 v4:1.1561 v5:0.6626 v6:0.7063 v7:0.7801 v8:1.1757 v9:0.6339 v10:0.5838 v11:0.5590 v12:0.8117 v13:1.7199] +step:4400/20000 val_loss:2.2270 val_bpb:1.3190 train_time:260781ms step_avg:59.27ms +step:4600/20000 train_loss:2.0368 train_time:272605ms step_avg:59.26ms +step:4600 shared0_alpha:mean=0.403,std=0.062 shared1_alpha:mean=0.552,std=0.056 shared2_alpha:mean=0.607,std=0.057 shared3_alpha:mean=0.668,std=0.059 eff_mlp_scale:[v0:145.3811 v1:87.5979 v2:85.2514 v3:90.8275 v4:83.6477 v5:96.8443 
v6:90.7515 v7:88.9926 v8:85.4466 v9:83.7046 v10:70.5845 v11:67.8912 v12:70.6059 v13:202.0904] eff_attn_scale:[v0:0.3928 v1:0.7792 v2:0.7549 v3:0.8647 v4:1.1528 v5:0.6507 v6:0.7025 v7:0.7658 v8:1.1725 v9:0.6105 v10:0.5814 v11:0.5476 v12:0.8094 v13:1.6740] +step:4600/20000 val_loss:2.2220 val_bpb:1.3160 train_time:272637ms step_avg:59.27ms +step:4800/20000 train_loss:2.3243 train_time:284445ms step_avg:59.26ms +step:4800 shared0_alpha:mean=0.401,std=0.062 shared1_alpha:mean=0.553,std=0.056 shared2_alpha:mean=0.609,std=0.058 shared3_alpha:mean=0.670,std=0.059 eff_mlp_scale:[v0:149.6388 v1:89.8112 v2:86.9723 v3:93.2148 v4:85.5791 v5:98.6451 v6:92.5237 v7:90.8961 v8:86.9447 v9:85.3943 v10:72.1685 v11:69.5633 v12:72.3780 v13:205.0534] eff_attn_scale:[v0:0.3859 v1:0.7676 v2:0.7423 v3:0.8653 v4:1.1359 v5:0.6436 v6:0.6898 v7:0.7623 v8:1.1603 v9:0.6076 v10:0.5728 v11:0.5398 v12:0.7946 v13:1.6561] +step:4800/20000 val_loss:2.2187 val_bpb:1.3140 train_time:284477ms step_avg:59.27ms +step:5000/20000 train_loss:2.0935 train_time:296302ms step_avg:59.26ms +step:5000 shared0_alpha:mean=0.399,std=0.062 shared1_alpha:mean=0.553,std=0.057 shared2_alpha:mean=0.610,std=0.058 shared3_alpha:mean=0.672,std=0.059 eff_mlp_scale:[v0:153.8149 v1:91.5847 v2:89.0435 v3:95.4824 v4:87.7745 v5:100.4957 v6:93.7055 v7:92.2061 v8:88.6936 v9:86.6342 v10:73.1929 v11:70.6757 v12:73.9879 v13:208.6951] eff_attn_scale:[v0:0.3806 v1:0.7597 v2:0.7427 v3:0.8650 v4:1.1437 v5:0.6259 v6:0.6865 v7:0.7615 v8:1.1633 v9:0.5904 v10:0.5621 v11:0.5422 v12:0.8001 v13:1.6242] +step:5000/20000 val_loss:2.2133 val_bpb:1.3108 train_time:296334ms step_avg:59.27ms +step:5200/20000 train_loss:2.2323 train_time:308141ms step_avg:59.26ms +step:5200 shared0_alpha:mean=0.397,std=0.062 shared1_alpha:mean=0.553,std=0.058 shared2_alpha:mean=0.612,std=0.059 shared3_alpha:mean=0.674,std=0.060 eff_mlp_scale:[v0:159.4933 v1:93.7957 v2:90.9036 v3:97.4001 v4:89.7327 v5:101.7783 v6:95.6136 v7:94.5632 v8:91.1275 v9:87.8087 v10:74.4185 
v11:71.8680 v12:75.7846 v13:212.6315] eff_attn_scale:[v0:0.3687 v1:0.7652 v2:0.7401 v3:0.8644 v4:1.1541 v5:0.6304 v6:0.6878 v7:0.7615 v8:1.1688 v9:0.5947 v10:0.5631 v11:0.5392 v12:0.8054 v13:1.6234] +step:5200/20000 val_loss:2.2144 val_bpb:1.3115 train_time:308173ms step_avg:59.26ms +step:5400/20000 train_loss:2.2443 train_time:320001ms step_avg:59.26ms +step:5400 shared0_alpha:mean=0.395,std=0.062 shared1_alpha:mean=0.554,std=0.058 shared2_alpha:mean=0.613,std=0.059 shared3_alpha:mean=0.675,std=0.060 eff_mlp_scale:[v0:163.9131 v1:95.6766 v2:93.0526 v3:99.6529 v4:91.3985 v5:103.7336 v6:97.3254 v7:95.8384 v8:92.8046 v9:89.6339 v10:75.4865 v11:73.4285 v12:76.8684 v13:215.8871] eff_attn_scale:[v0:0.3644 v1:0.7672 v2:0.7476 v3:0.8665 v4:1.1519 v5:0.6249 v6:0.6761 v7:0.7551 v8:1.1715 v9:0.5893 v10:0.5528 v11:0.5364 v12:0.7990 v13:1.6067] +step:5400/20000 val_loss:2.2088 val_bpb:1.3082 train_time:320032ms step_avg:59.27ms +step:5600/20000 train_loss:2.2457 train_time:331844ms step_avg:59.26ms +step:5600 shared0_alpha:mean=0.393,std=0.063 shared1_alpha:mean=0.555,std=0.058 shared2_alpha:mean=0.615,std=0.059 shared3_alpha:mean=0.677,std=0.061 eff_mlp_scale:[v0:168.9937 v1:97.9664 v2:94.8258 v3:101.4387 v4:93.6504 v5:105.5803 v6:99.1361 v7:98.0734 v8:94.5964 v9:90.8600 v10:76.6269 v11:74.5165 v12:78.5150 v13:219.7498] eff_attn_scale:[v0:0.3544 v1:0.7693 v2:0.7419 v3:0.8706 v4:1.1515 v5:0.6273 v6:0.6704 v7:0.7515 v8:1.1515 v9:0.5839 v10:0.5475 v11:0.5338 v12:0.7904 v13:1.6007] +step:5600/20000 val_loss:2.2100 val_bpb:1.3089 train_time:331876ms step_avg:59.26ms +step:5800/20000 train_loss:2.2100 train_time:343701ms step_avg:59.26ms +step:5800 shared0_alpha:mean=0.392,std=0.063 shared1_alpha:mean=0.556,std=0.058 shared2_alpha:mean=0.617,std=0.060 shared3_alpha:mean=0.678,std=0.061 eff_mlp_scale:[v0:173.9413 v1:100.1436 v2:97.1412 v3:103.6861 v4:95.6212 v5:106.7858 v6:101.0075 v7:99.3255 v8:96.5774 v9:91.9686 v10:78.2929 v11:75.5843 v12:79.8437 v13:222.7330] 
eff_attn_scale:[v0:0.3424 v1:0.7614 v2:0.7388 v3:0.8595 v4:1.1536 v5:0.6169 v6:0.6673 v7:0.7339 v8:1.1439 v9:0.5779 v10:0.5481 v11:0.5271 v12:0.7901 v13:1.5714] +step:5800/20000 val_loss:2.2085 val_bpb:1.3080 train_time:343733ms step_avg:59.26ms +step:6000/20000 train_loss:2.2787 train_time:355541ms step_avg:59.26ms +step:6000 shared0_alpha:mean=0.390,std=0.063 shared1_alpha:mean=0.556,std=0.057 shared2_alpha:mean=0.618,std=0.060 shared3_alpha:mean=0.679,std=0.061 eff_mlp_scale:[v0:179.7115 v1:102.5137 v2:98.8222 v3:106.1559 v4:97.2881 v5:109.2105 v6:102.7167 v7:101.2639 v8:98.2513 v9:93.2411 v10:79.3499 v11:76.8040 v12:81.3945 v13:226.4459] eff_attn_scale:[v0:0.3415 v1:0.7450 v2:0.7313 v3:0.8596 v4:1.1518 v5:0.6067 v6:0.6684 v7:0.7385 v8:1.1567 v9:0.5683 v10:0.5465 v11:0.5246 v12:0.7922 v13:1.5915] +step:6000/20000 val_loss:2.2042 val_bpb:1.3054 train_time:355572ms step_avg:59.26ms +step:6200/20000 train_loss:2.1520 train_time:367387ms step_avg:59.26ms +step:6200 shared0_alpha:mean=0.388,std=0.063 shared1_alpha:mean=0.557,std=0.058 shared2_alpha:mean=0.620,std=0.060 shared3_alpha:mean=0.681,std=0.062 eff_mlp_scale:[v0:184.1409 v1:104.4329 v2:101.0519 v3:108.1300 v4:99.9027 v5:110.6677 v6:103.9952 v7:102.6988 v8:100.3901 v9:95.0807 v10:80.4491 v11:78.0116 v12:83.3335 v13:228.4138] eff_attn_scale:[v0:0.3336 v1:0.7624 v2:0.7430 v3:0.8663 v4:1.1573 v5:0.6139 v6:0.6565 v7:0.7402 v8:1.1622 v9:0.5630 v10:0.5347 v11:0.5287 v12:0.7878 v13:1.5432] +step:6200/20000 val_loss:2.2034 val_bpb:1.3050 train_time:367419ms step_avg:59.26ms +step:6400/20000 train_loss:2.2230 train_time:379237ms step_avg:59.26ms +step:6400 shared0_alpha:mean=0.386,std=0.064 shared1_alpha:mean=0.558,std=0.058 shared2_alpha:mean=0.621,std=0.061 shared3_alpha:mean=0.682,std=0.062 eff_mlp_scale:[v0:189.8302 v1:106.8324 v2:102.8722 v3:110.1204 v4:101.9474 v5:112.5930 v6:105.8396 v7:104.6393 v8:102.4399 v9:96.3587 v10:81.6053 v11:79.2269 v12:85.2024 v13:232.3741] eff_attn_scale:[v0:0.3267 v1:0.7589 
v2:0.7307 v3:0.8651 v4:1.1648 v5:0.5993 v6:0.6447 v7:0.7392 v8:1.1502 v9:0.5604 v10:0.5353 v11:0.5199 v12:0.7944 v13:1.5767] +step:6400/20000 val_loss:2.2000 val_bpb:1.3029 train_time:379269ms step_avg:59.26ms +step:6600/20000 train_loss:2.1894 train_time:391086ms step_avg:59.26ms +step:6600 shared0_alpha:mean=0.384,std=0.064 shared1_alpha:mean=0.558,std=0.059 shared2_alpha:mean=0.622,std=0.061 shared3_alpha:mean=0.684,std=0.063 eff_mlp_scale:[v0:195.4746 v1:108.5459 v2:104.6949 v3:112.5847 v4:103.8120 v5:114.3420 v6:107.1876 v7:106.5534 v8:104.3087 v9:97.4805 v10:82.7588 v11:80.9202 v12:86.9239 v13:236.5316] eff_attn_scale:[v0:0.3259 v1:0.7674 v2:0.7435 v3:0.8732 v4:1.1760 v5:0.6077 v6:0.6530 v7:0.7420 v8:1.1562 v9:0.5648 v10:0.5311 v11:0.5247 v12:0.8054 v13:1.5359] +step:6600/20000 val_loss:2.1963 val_bpb:1.3007 train_time:391118ms step_avg:59.26ms +step:6800/20000 train_loss:2.2551 train_time:402928ms step_avg:59.25ms +step:6800 shared0_alpha:mean=0.382,std=0.063 shared1_alpha:mean=0.558,std=0.059 shared2_alpha:mean=0.624,std=0.062 shared3_alpha:mean=0.685,std=0.063 eff_mlp_scale:[v0:199.3570 v1:110.9249 v2:106.4690 v3:114.4221 v4:106.2555 v5:115.7016 v6:108.4778 v7:108.3466 v8:105.7543 v9:98.7179 v10:83.8694 v11:82.0194 v12:88.2121 v13:238.4881] eff_attn_scale:[v0:0.3224 v1:0.7681 v2:0.7286 v3:0.8746 v4:1.1817 v5:0.6052 v6:0.6468 v7:0.7323 v8:1.1669 v9:0.5470 v10:0.5182 v11:0.5187 v12:0.7977 v13:1.5251] +step:6800/20000 val_loss:2.1956 val_bpb:1.3004 train_time:402961ms step_avg:59.26ms +step:7000/20000 train_loss:2.2832 train_time:414774ms step_avg:59.25ms +step:7000 shared0_alpha:mean=0.380,std=0.064 shared1_alpha:mean=0.559,std=0.059 shared2_alpha:mean=0.625,std=0.062 shared3_alpha:mean=0.686,std=0.063 eff_mlp_scale:[v0:205.1638 v1:112.9238 v2:108.7964 v3:116.8961 v4:108.2008 v5:117.2053 v6:110.3145 v7:109.7496 v8:108.2008 v9:100.0794 v10:85.0130 v11:83.2055 v12:89.9988 v13:242.5826] eff_attn_scale:[v0:0.3126 v1:0.7662 v2:0.7415 v3:0.8720 v4:1.1966 v5:0.6037 
v6:0.6586 v7:0.7369 v8:1.1671 v9:0.5495 v10:0.5245 v11:0.5240 v12:0.8076 v13:1.5226] +step:7000/20000 val_loss:2.1936 val_bpb:1.2992 train_time:414805ms step_avg:59.26ms +step:7200/20000 train_loss:2.2635 train_time:426617ms step_avg:59.25ms +step:7200 shared0_alpha:mean=0.379,std=0.063 shared1_alpha:mean=0.560,std=0.059 shared2_alpha:mean=0.627,std=0.062 shared3_alpha:mean=0.687,std=0.063 eff_mlp_scale:[v0:209.4189 v1:114.7228 v2:110.1191 v3:118.8122 v4:110.1611 v5:119.0317 v6:112.1584 v7:111.6115 v8:109.6511 v9:101.2577 v10:86.1580 v11:84.3516 v12:91.8009 v13:246.3847] eff_attn_scale:[v0:0.3114 v1:0.7758 v2:0.7512 v3:0.8862 v4:1.1983 v5:0.6082 v6:0.6597 v7:0.7536 v8:1.1687 v9:0.5536 v10:0.5326 v11:0.5342 v12:0.8087 v13:1.5231] +step:7200/20000 val_loss:2.1947 val_bpb:1.2998 train_time:426648ms step_avg:59.26ms +step:7400/20000 train_loss:2.1789 train_time:438467ms step_avg:59.25ms +step:7400 shared0_alpha:mean=0.376,std=0.063 shared1_alpha:mean=0.560,std=0.060 shared2_alpha:mean=0.628,std=0.063 shared3_alpha:mean=0.688,std=0.064 eff_mlp_scale:[v0:214.8811 v1:116.5179 v2:112.4040 v3:120.8861 v4:112.1007 v5:120.8534 v6:113.4305 v7:113.6226 v8:111.5865 v9:102.4273 v10:87.2542 v11:85.6061 v12:93.0745 v13:250.0835] eff_attn_scale:[v0:0.3001 v1:0.7651 v2:0.7407 v3:0.8777 v4:1.2095 v5:0.6036 v6:0.6472 v7:0.7436 v8:1.1749 v9:0.5459 v10:0.5146 v11:0.5140 v12:0.7997 v13:1.5224] +step:7400/20000 val_loss:2.1898 val_bpb:1.2969 train_time:438499ms step_avg:59.26ms +step:7600/20000 train_loss:2.0665 train_time:450304ms step_avg:59.25ms +step:7600 shared0_alpha:mean=0.375,std=0.064 shared1_alpha:mean=0.561,std=0.060 shared2_alpha:mean=0.630,std=0.063 shared3_alpha:mean=0.690,std=0.064 eff_mlp_scale:[v0:219.5673 v1:119.0424 v2:114.3795 v3:123.2943 v4:114.5205 v5:122.3188 v6:114.8971 v7:114.9354 v8:113.4841 v9:103.7525 v10:88.5018 v11:86.7240 v12:94.8292 v13:252.0920] eff_attn_scale:[v0:0.3055 v1:0.7820 v2:0.7397 v3:0.8893 v4:1.2225 v5:0.5991 v6:0.6463 v7:0.7418 v8:1.1827 
v9:0.5446 v10:0.5178 v11:0.5184 v12:0.8100 v13:1.5209] +step:7600/20000 val_loss:2.1886 val_bpb:1.2962 train_time:450336ms step_avg:59.25ms +step:7800/20000 train_loss:2.2108 train_time:462150ms step_avg:59.25ms +step:7800 shared0_alpha:mean=0.373,std=0.063 shared1_alpha:mean=0.562,std=0.060 shared2_alpha:mean=0.631,std=0.063 shared3_alpha:mean=0.691,std=0.065 eff_mlp_scale:[v0:225.3324 v1:120.9175 v2:116.2944 v3:125.4547 v4:116.5135 v5:124.2153 v6:116.8159 v7:116.4937 v8:115.4685 v9:104.9784 v10:89.1764 v11:88.0292 v12:96.1367 v13:256.1435] eff_attn_scale:[v0:0.3010 v1:0.7857 v2:0.7362 v3:0.9016 v4:1.2284 v5:0.6009 v6:0.6388 v7:0.7418 v8:1.1887 v9:0.5392 v10:0.5181 v11:0.5287 v12:0.8107 v13:1.5223] +step:7800/20000 val_loss:2.1854 val_bpb:1.2943 train_time:462181ms step_avg:59.25ms +step:8000/20000 train_loss:2.1762 train_time:473998ms step_avg:59.25ms +step:8000 shared0_alpha:mean=0.371,std=0.063 shared1_alpha:mean=0.562,std=0.060 shared2_alpha:mean=0.632,std=0.063 shared3_alpha:mean=0.692,std=0.065 eff_mlp_scale:[v0:230.1253 v1:122.9614 v2:118.8503 v3:126.9168 v4:119.1279 v5:126.2847 v6:118.3244 v7:118.4203 v8:117.0194 v9:106.3450 v10:90.9783 v11:89.2135 v12:98.0433 v13:258.0185] eff_attn_scale:[v0:0.2941 v1:0.7905 v2:0.7514 v3:0.9012 v4:1.2311 v5:0.6075 v6:0.6457 v7:0.7407 v8:1.1814 v9:0.5490 v10:0.5126 v11:0.5246 v12:0.8191 v13:1.5115] +step:8000/20000 val_loss:2.1836 val_bpb:1.2933 train_time:474029ms step_avg:59.25ms +step:8200/20000 train_loss:2.2414 train_time:485841ms step_avg:59.25ms +step:8200 shared0_alpha:mean=0.370,std=0.063 shared1_alpha:mean=0.563,std=0.060 shared2_alpha:mean=0.634,std=0.064 shared3_alpha:mean=0.693,std=0.066 eff_mlp_scale:[v0:236.0981 v1:125.4997 v2:120.2003 v3:130.1774 v4:121.1870 v5:127.1730 v6:119.6708 v7:119.9989 v8:119.0609 v9:107.6508 v10:91.6064 v11:90.5349 v12:99.9261 v13:262.0005] eff_attn_scale:[v0:0.2834 v1:0.8007 v2:0.7520 v3:0.9107 v4:1.2411 v5:0.6064 v6:0.6423 v7:0.7493 v8:1.2061 v9:0.5442 v10:0.5209 v11:0.5340 
v12:0.8257 v13:1.5124] +step:8200/20000 val_loss:2.1824 val_bpb:1.2925 train_time:485873ms step_avg:59.25ms +step:8400/20000 train_loss:2.1960 train_time:497938ms step_avg:59.28ms +step:8400 shared0_alpha:mean=0.368,std=0.064 shared1_alpha:mean=0.564,std=0.060 shared2_alpha:mean=0.635,std=0.064 shared3_alpha:mean=0.695,std=0.066 eff_mlp_scale:[v0:240.2487 v1:127.4610 v2:122.7811 v3:132.1472 v4:123.2464 v5:129.1455 v6:121.7134 v7:121.8991 v8:121.1030 v9:108.9314 v10:92.8866 v11:91.6940 v12:101.2764 v13:263.9616] eff_attn_scale:[v0:0.2900 v1:0.8182 v2:0.7521 v3:0.9300 v4:1.2661 v5:0.6078 v6:0.6391 v7:0.7523 v8:1.2058 v9:0.5455 v10:0.5105 v11:0.5332 v12:0.8290 v13:1.5126] +step:8400/20000 val_loss:2.1815 val_bpb:1.2920 train_time:497969ms step_avg:59.28ms +step:8600/20000 train_loss:2.1955 train_time:509769ms step_avg:59.28ms +step:8600 shared0_alpha:mean=0.366,std=0.064 shared1_alpha:mean=0.564,std=0.061 shared2_alpha:mean=0.636,std=0.064 shared3_alpha:mean=0.696,std=0.066 eff_mlp_scale:[v0:246.5324 v1:129.4005 v2:124.7777 v3:134.7354 v4:125.2409 v5:130.5306 v6:123.1642 v7:123.8696 v8:123.0816 v9:110.1882 v10:94.1211 v11:92.9022 v12:103.1078 v13:267.7187] eff_attn_scale:[v0:0.2869 v1:0.8100 v2:0.7461 v3:0.9212 v4:1.2635 v5:0.6153 v6:0.6445 v7:0.7526 v8:1.2086 v9:0.5413 v10:0.5156 v11:0.5346 v12:0.8290 v13:1.5093] +step:8600/20000 val_loss:2.1791 val_bpb:1.2906 train_time:509804ms step_avg:59.28ms +step:8800/20000 train_loss:2.1664 train_time:521614ms step_avg:59.27ms +step:8800 shared0_alpha:mean=0.364,std=0.064 shared1_alpha:mean=0.564,std=0.061 shared2_alpha:mean=0.638,std=0.064 shared3_alpha:mean=0.697,std=0.066 eff_mlp_scale:[v0:251.0507 v1:131.9416 v2:126.6586 v3:136.1598 v4:127.8389 v5:132.5104 v6:124.4935 v7:125.2233 v8:124.5749 v9:111.4679 v10:95.2646 v11:94.0542 v12:104.4471 v13:271.2946] eff_attn_scale:[v0:0.2766 v1:0.8323 v2:0.7588 v3:0.9297 v4:1.2828 v5:0.6154 v6:0.6414 v7:0.7637 v8:1.2121 v9:0.5522 v10:0.5085 v11:0.5354 v12:0.8384 v13:1.5304] 
+step:8800/20000 val_loss:2.1778 val_bpb:1.2898 train_time:521645ms step_avg:59.28ms +step:9000/20000 train_loss:2.0805 train_time:533470ms step_avg:59.27ms +step:9000 shared0_alpha:mean=0.363,std=0.064 shared1_alpha:mean=0.565,std=0.061 shared2_alpha:mean=0.639,std=0.065 shared3_alpha:mean=0.698,std=0.066 eff_mlp_scale:[v0:257.4152 v1:133.7777 v2:128.6963 v3:138.8611 v4:130.0868 v5:134.3494 v6:126.5150 v7:127.2894 v8:126.7935 v9:112.6248 v10:96.5222 v11:95.3293 v12:106.4845 v13:273.4311] eff_attn_scale:[v0:0.2766 v1:0.8403 v2:0.7611 v3:0.9508 v4:1.3217 v5:0.6273 v6:0.6512 v7:0.7699 v8:1.2397 v9:0.5484 v10:0.5139 v11:0.5427 v12:0.8555 v13:1.5236] +step:9000/20000 val_loss:2.1757 val_bpb:1.2886 train_time:533502ms step_avg:59.28ms +step:9200/20000 train_loss:2.1355 train_time:545326ms step_avg:59.27ms +step:9200 shared0_alpha:mean=0.361,std=0.063 shared1_alpha:mean=0.565,std=0.061 shared2_alpha:mean=0.639,std=0.065 shared3_alpha:mean=0.698,std=0.066 eff_mlp_scale:[v0:262.0444 v1:135.2486 v2:130.5472 v3:140.4741 v4:132.1480 v5:135.8241 v6:127.8047 v7:128.8142 v8:128.2776 v9:113.9541 v10:97.0877 v11:96.6107 v12:107.8195 v13:277.5444] eff_attn_scale:[v0:0.2734 v1:0.8502 v2:0.7689 v3:0.9572 v4:1.3303 v5:0.6336 v6:0.6512 v7:0.7759 v8:1.2633 v9:0.5614 v10:0.5139 v11:0.5482 v12:0.8663 v13:1.5439] +step:9200/20000 val_loss:2.1655 val_bpb:1.2825 train_time:545357ms step_avg:59.28ms +step:9400/20000 train_loss:2.1757 train_time:557172ms step_avg:59.27ms +step:9400 shared0_alpha:mean=0.360,std=0.063 shared1_alpha:mean=0.565,std=0.061 shared2_alpha:mean=0.640,std=0.065 shared3_alpha:mean=0.699,std=0.066 eff_mlp_scale:[v0:263.2965 v1:136.5867 v2:131.9275 v3:142.5522 v4:134.5448 v5:137.1655 v6:129.1675 v7:130.2536 v8:130.0786 v9:114.5939 v10:97.7036 v11:97.2709 v12:109.4223 v13:279.7477] eff_attn_scale:[v0:0.2731 v1:0.8550 v2:0.7636 v3:0.9639 v4:1.3521 v5:0.6342 v6:0.6539 v7:0.7821 v8:1.2793 v9:0.5580 v10:0.5090 v11:0.5538 v12:0.8841 v13:1.5618] +step:9400/20000 val_loss:2.1559 
val_bpb:1.2768 train_time:557203ms step_avg:59.28ms +step:9600/20000 train_loss:2.1789 train_time:569015ms step_avg:59.27ms +step:9600 shared0_alpha:mean=0.359,std=0.063 shared1_alpha:mean=0.565,std=0.061 shared2_alpha:mean=0.640,std=0.065 shared3_alpha:mean=0.699,std=0.067 eff_mlp_scale:[v0:264.5626 v1:137.2842 v2:132.6420 v3:143.3359 v4:135.3772 v5:137.8659 v6:129.8671 v7:130.9697 v8:130.8833 v9:115.1791 v10:98.7878 v11:97.8057 v12:110.0993 v13:281.6810] eff_attn_scale:[v0:0.2760 v1:0.8545 v2:0.7619 v3:0.9659 v4:1.3683 v5:0.6338 v6:0.6598 v7:0.7880 v8:1.2795 v9:0.5536 v10:0.5106 v11:0.5507 v12:0.8930 v13:1.5710] +step:9600/20000 val_loss:2.1464 val_bpb:1.2712 train_time:569047ms step_avg:59.28ms +step:9800/20000 train_loss:2.0988 train_time:580854ms step_avg:59.27ms +step:9800 shared0_alpha:mean=0.359,std=0.062 shared1_alpha:mean=0.564,std=0.061 shared2_alpha:mean=0.639,std=0.065 shared3_alpha:mean=0.699,std=0.067 eff_mlp_scale:[v0:265.2037 v1:138.5365 v2:133.1247 v3:143.8686 v4:136.1328 v5:138.5365 v6:130.8967 v7:131.4564 v8:131.6138 v9:115.7394 v10:99.1473 v11:98.1692 v12:110.7138 v13:285.5064] eff_attn_scale:[v0:0.2792 v1:0.8503 v2:0.7623 v3:0.9771 v4:1.3634 v5:0.6438 v6:0.6556 v7:0.7894 v8:1.2953 v9:0.5588 v10:0.5095 v11:0.5547 v12:0.8915 v13:1.5897] +step:9800/20000 val_loss:2.1375 val_bpb:1.2660 train_time:580886ms step_avg:59.27ms +step:10000/20000 train_loss:2.1327 train_time:592699ms step_avg:59.27ms +step:10000 shared0_alpha:mean=0.359,std=0.062 shared1_alpha:mean=0.564,std=0.061 shared2_alpha:mean=0.639,std=0.065 shared3_alpha:mean=0.698,std=0.067 eff_mlp_scale:[v0:265.1537 v1:138.7435 v2:133.3693 v3:144.2558 v4:136.6628 v5:138.7435 v6:131.1372 v7:131.8102 v8:132.1263 v9:115.9123 v10:99.3294 v11:98.4333 v12:111.1449 v13:286.6695] eff_attn_scale:[v0:0.2733 v1:0.8454 v2:0.7637 v3:0.9686 v4:1.3718 v5:0.6391 v6:0.6563 v7:0.7894 v8:1.3032 v9:0.5582 v10:0.5131 v11:0.5504 v12:0.8970 v13:1.5789] +step:10000/20000 val_loss:2.1286 val_bpb:1.2606 
train_time:592731ms step_avg:59.27ms +step:10122/20000 val_loss:2.1246 val_bpb:1.2583 train_time:600039ms step_avg:59.28ms +stopping_early: wallclock_cap train_time:600039ms step:10122/20000 +peak memory allocated: 13736 MiB reserved: 14068 MiB +Serialized model: 45178938 bytes +Code size: 57024 bytes +Total submission size: 45235962 bytes +Serialized model int8+zlib: 10717286 bytes (payload:11638976 raw_torch:11670071 payload_ratio:3.88x) +Total submission size int8+zlib: 10774310 bytes +final_int8_zlib_roundtrip val_loss:2.1374 val_bpb:1.2659 eval_time:1873ms +final_int8_zlib_roundtrip_exact val_loss:2.13735865 val_bpb:1.26586418 From 37e64221609c9a0dc56bc54ceea410a8c75702ae Mon Sep 17 00:00:00 2001 From: Alexandr Azizyan Date: Thu, 26 Mar 2026 17:59:16 +0400 Subject: [PATCH 06/10] docs: add README for PR submission --- .../README.md | 100 ++++++++++++++++++ 1 file changed, 100 insertions(+) create mode 100644 records/track_non_record_16mb/2026-03-26_RecurrenceFix_3Loop_Birkhoff_OutputLN_TimestepScale/README.md diff --git a/records/track_non_record_16mb/2026-03-26_RecurrenceFix_3Loop_Birkhoff_OutputLN_TimestepScale/README.md b/records/track_non_record_16mb/2026-03-26_RecurrenceFix_3Loop_Birkhoff_OutputLN_TimestepScale/README.md new file mode 100644 index 0000000000..8475411d08 --- /dev/null +++ b/records/track_non_record_16mb/2026-03-26_RecurrenceFix_3Loop_Birkhoff_OutputLN_TimestepScale/README.md @@ -0,0 +1,100 @@ +# Non-Record: First Viable 3-Loop Recurrence — Birkhoff + Output-LN + Timestep Scaling + +## Result + +**1.2659 BPB** post-quantization (10.7 MB model + 57KB code). Config: 1 prelude + 4 shared × 3 loops + 1 coda = **14 effective layers from 6 unique blocks**. Q-gap +0.0076. Depth recurrence at 3 loops has never worked in this competition — prior attempts produced +4.3 BPB blowup (PR #579). This submission brings Q-gap down to +0.0076 using three new techniques. 
+ +## Technique Summary + +| Technique | What it does | Delta | Verdict | +|-----------|-------------|-------|---------| +| **Output-LN** | Moves RMSNorm from MLP input to output, letting shared weights see different magnitudes per iteration | −0.007 BPB (screening) | Critical — nothing works without it | +| **Birkhoff mixing** | Constrains residual mixing to convex combination (spectral norm ≤ 1), preventing signal blowup across loops | Enables 3-loop stability (Q-gap +0.0076 vs prior +4.3 BPB blowup) | Required for 3-loop stability, but hurts alone | +| **Timestep scaling** | Per-iteration learned scale vectors (capped ±4.0), stored as float16 passthrough | Q-gap −26–30% | Helps quantization, not training | +| **Prelude-coda** | Unique first/last layers, shared middle blocks (Huginn-style) | −0.016 BPB (screening) | Biggest single BPB win | +| **LeakyReLU(0.5)²** | Preserves negative signal through quadratic activation | Adopted from SOTA | Necessary with Output-LN to avoid dead neurons | + +## Key Findings + +- **Timestep scaling helps quantization, not training.** Pre-quant BPB is unchanged (1.2578 vs 1.2580), but Q-gap drops 26–30%. The gammas are float16 passthrough params that bypass int8 quantization entirely. +- **Birkhoff alone hurts.** Run C' (Birkhoff only) is +0.002 BPB worse than bare recurrence. It only helps when paired with Output-LN. +- **Q-gap scales with training duration.** Screening (2000 steps) shows Q-gap +0.0019. Full-scale (10k steps) shows +0.0076–0.0126. Screening underestimates the quantization problem by 4–7×. +- **Output-LN is the critical technique.** Without it, mixing alphas collapse to ~0.48 (uniform) and MLP scale drops to 0.2–0.3. With it, alphas learn a meaningful gradient (0.37→0.70 across layers). +- **Prelude-coda gives the biggest single improvement** (−0.016 BPB at screening). Boundary layers need unique parameters. 
+- **3 loops are viable for the first time.** Q-gap +0.0076 at 3 loops, vs prior results showing catastrophic failure (+4.3 BPB, PR #579). + +## Techniques Applicable to Non-Recurrent Submissions + +Output-LN could benefit any submission using quadratic activations (relu², leaky_relu²) — it lets the MLP see unnormalized inputs while bounding its output, which may improve gradient flow in deeper networks. Birkhoff mixing is a drop-in replacement for learned residual mixing with fewer parameters and bounded spectral norm. Per-layer scaling vectors (the non-recurrent version of timestep scaling) add ~4KB of float16 params that survive quantization and could reduce Q-gap on any deep submission. + +## Screening Results (2000 steps, 1×H100) + +| Run | Config | Post-Q BPB | Q-Gap | Δ vs B' (bare) | +|-----|--------|-----------|-------|-----------------| +| B' | 4×2 bare recurrence | 1.3637 | +0.0024 | — | +| C' | 4×2 + birkhoff only | 1.3660 | +0.0024 | +0.002 | +| C | 4×2 + peri + birkhoff | 1.3587 | +0.0020 | −0.005 | +| D | 4×2 + peri + birk + timestep | 1.3584 | +0.0019 | −0.005 | +| E | 1+3×2+1 all fixes | 1.3428 | +0.0019 | −0.021 | +| F | 1+2×3+1 all (3 loops) | 1.3622 | +0.0019 | −0.002 | + +## Full-Scale Results (600s, 8×H100) + +| Run | Config | Eff. 
Layers | Pre-Q BPB | Post-Q BPB | Q-Gap | +|-----|--------|-------------|-----------|------------|-------| +| H | 1+4×2+1 peri+birk | 10 | 1.2578 | 1.2704 | +0.0126 | +| I | 1+4×2+1 peri+birk+ts(cap4) | 10 | 1.2580 | 1.2668 | +0.0088 | +| J | 1+4×3+1 peri+birk (3 loops) | 14 | 1.2567 | 1.2670 | +0.0103 | +| **K** | **1+4×3+1 peri+birk+ts(cap4)** | **14** | **1.2583** | **1.2659** | **+0.0076** | + +## How to Reproduce + +The submitted run (Run K) from repo root: + +```bash +# 8×H100 (600s wallclock cap) +SEED=1337 MAX_WALLCLOCK_SECONDS=600 VAL_LOSS_EVERY=200 TRAIN_LOG_EVERY=200 \ + DATA_PATH=./data/datasets/fineweb10B_sp1024 \ + TOKENIZER_PATH=./data/tokenizers/fineweb_1024_bpe.model VOCAB_SIZE=1024 \ + NUM_LAYERS=14 NUM_PRELUDE=1 NUM_SHARED=4 NUM_LOOPS=3 NUM_CODA=1 \ + USE_PERI_NORM=1 USE_BIRKHOFF_MIX=1 USE_TIMESTEP_SCALE=1 TIMESTEP_GAMMA_MAX=4.0 \ + torchrun --standalone --nproc_per_node=8 train_gpt.py +``` + +Full ablation via bundled scripts: + +```bash +# Screening (7 runs, on 1×H100) +bash scripts/run_screening.sh + +# Full-scale (5 runs, on 8×H100) +bash scripts/run_fullscale.sh +``` + +### Environment Variables + +| Variable | Run K Value | Description | +|----------|-------------|-------------| +| `NUM_PRELUDE` | 1 | Unique prefix layers | +| `NUM_CODA` | 1 | Unique suffix layers | +| `NUM_SHARED` | 4 | Shared blocks in the recurrent loop | +| `NUM_LOOPS` | 3 | Loop iterations over shared blocks | +| `USE_PERI_NORM` | 1 | Output-LN (norm on MLP output, not input) | +| `USE_BIRKHOFF_MIX` | 1 | Sigmoid-constrained residual mixing | +| `USE_TIMESTEP_SCALE` | 1 | Per-iteration learned scale vectors | +| `TIMESTEP_GAMMA_MAX` | 4.0 | Cap for timestep gammas (0 = uncapped) | +| `LEAKY_RELU_SLOPE` | 0.5 | Negative slope for leaky relu² (0.0 = relu²) | + +## Files +``` +├── train_gpt.py # Modified training script +├── train_log.txt # Run K log (primary submission) +├── submission.json # Competition metadata +├── research_notes.md # Theory + citations +├── logs/ # 
All 12 run logs (s1_Ap–F screening, s2_G–K full-scale) +└── scripts/ # run_screening.sh, run_fullscale.sh +``` + +## Links + +See [research_notes.md](research_notes.md) for theory, citations, and detailed technique descriptions. From a3a861317bac7e9b6c3a7226c80cac8495a32774 Mon Sep 17 00:00:00 2001 From: Alexandr Azizyan Date: Thu, 26 Mar 2026 18:29:41 +0400 Subject: [PATCH 07/10] docs: polish README and research notes for PR submission --- .../README.md | 6 +++--- .../research_notes.md | 8 ++++---- 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/records/track_non_record_16mb/2026-03-26_RecurrenceFix_3Loop_Birkhoff_OutputLN_TimestepScale/README.md b/records/track_non_record_16mb/2026-03-26_RecurrenceFix_3Loop_Birkhoff_OutputLN_TimestepScale/README.md index 8475411d08..ced702ffc1 100644 --- a/records/track_non_record_16mb/2026-03-26_RecurrenceFix_3Loop_Birkhoff_OutputLN_TimestepScale/README.md +++ b/records/track_non_record_16mb/2026-03-26_RecurrenceFix_3Loop_Birkhoff_OutputLN_TimestepScale/README.md @@ -2,14 +2,14 @@ ## Result -**1.2659 BPB** post-quantization (10.7 MB model + 57KB code). Config: 1 prelude + 4 shared × 3 loops + 1 coda = **14 effective layers from 6 unique blocks**. Q-gap +0.0076. Depth recurrence at 3 loops has never worked in this competition — prior attempts produced +4.3 BPB blowup (PR #579). This submission brings Q-gap down to +0.0076 using three new techniques. +**1.2659 BPB** post-quantization (10.7 MB model + 57KB code). Config: 1 prelude + 4 shared × 3 loops + 1 coda = **14 effective layers from 6 unique blocks**. Q-gap +0.0076. Depth recurrence at 3 loops has never worked in this competition — prior 3-loop attempts failed catastrophically (PR #363 measured ~900× quantization amplification over 3 cycles). This submission brings Q-gap down to +0.0076 using three new techniques. 
## Technique Summary | Technique | What it does | Delta | Verdict | |-----------|-------------|-------|---------| | **Output-LN** | Moves RMSNorm from MLP input to output, letting shared weights see different magnitudes per iteration | −0.007 BPB (screening) | Critical — nothing works without it | -| **Birkhoff mixing** | Constrains residual mixing to convex combination (spectral norm ≤ 1), preventing signal blowup across loops | Enables 3-loop stability (Q-gap +0.0076 vs prior +4.3 BPB blowup) | Required for 3-loop stability, but hurts alone | +| **Birkhoff mixing** | Constrains residual mixing to convex combination (spectral norm ≤ 1), preventing signal blowup across loops | Enables 3-loop stability (Q-gap +0.0076 vs prior catastrophic 3-loop failure) | Required for 3-loop stability, but hurts alone | | **Timestep scaling** | Per-iteration learned scale vectors (capped ±4.0), stored as float16 passthrough | Q-gap −26–30% | Helps quantization, not training | | **Prelude-coda** | Unique first/last layers, shared middle blocks (Huginn-style) | −0.016 BPB (screening) | Biggest single BPB win | | **LeakyReLU(0.5)²** | Preserves negative signal through quadratic activation | Adopted from SOTA | Necessary with Output-LN to avoid dead neurons | @@ -21,7 +21,7 @@ - **Q-gap scales with training duration.** Screening (2000 steps) shows Q-gap +0.0019. Full-scale (10k steps) shows +0.0076–0.0126. Screening underestimates the quantization problem by 4–7×. - **Output-LN is the critical technique.** Without it, mixing alphas collapse to ~0.48 (uniform) and MLP scale drops to 0.2–0.3. With it, alphas learn a meaningful gradient (0.37→0.70 across layers). - **Prelude-coda gives the biggest single improvement** (−0.016 BPB at screening). Boundary layers need unique parameters. -- **3 loops are viable for the first time.** Q-gap +0.0076 at 3 loops, vs prior results showing catastrophic failure (+4.3 BPB, PR #579). 
+- **3 loops are viable for the first time.** Q-gap +0.0076 at 3 loops, vs prior 3-loop attempts that failed catastrophically (PR #363 measured ~900× quantization amplification over 3 cycles). ## Techniques Applicable to Non-Recurrent Submissions diff --git a/records/track_non_record_16mb/2026-03-26_RecurrenceFix_3Loop_Birkhoff_OutputLN_TimestepScale/research_notes.md b/records/track_non_record_16mb/2026-03-26_RecurrenceFix_3Loop_Birkhoff_OutputLN_TimestepScale/research_notes.md index f3cbc09101..8ea3b565f4 100644 --- a/records/track_non_record_16mb/2026-03-26_RecurrenceFix_3Loop_Birkhoff_OutputLN_TimestepScale/research_notes.md +++ b/records/track_non_record_16mb/2026-03-26_RecurrenceFix_3Loop_Birkhoff_OutputLN_TimestepScale/research_notes.md @@ -8,7 +8,7 @@ Depth recurrence (weight-shared looping) has been attempted 15+ times in this co ### 1a. Quantization Error Amplification -When the same weight matrix $W$ is applied $k$ times in a forward pass, post-training quantization error $\epsilon$ from int8 rounding compounds across iterations. PR #363 measured ~900× amplification over 3 cycles. PR #579 confirmed empirically: "2 loops survive, 3+ catastrophic (+4.3 BPB)." PR #623 showed AWQ (activation-aware quantization) closes 63% of the gap but cannot eliminate compounding. +When the same weight matrix $W$ is applied $k$ times in a forward pass, post-training quantization error $\epsilon$ from int8 rounding compounds across iterations. PR #363 measured ~900× amplification over 3 cycles. GPTQ quantizes layer-by-layer, compensating downstream weights for each layer's rounding error (Frantar et al., 2023, §3). With shared weights, this compensation is impossible. The same quantized matrix must serve all iterations. Errors from early iterations propagate uncompensated through later ones. @@ -16,7 +16,7 @@ GPTQ quantizes layer-by-layer, compensating downstream weights for each layer's ### 1b. 
Per-Iteration Identity Collapse -Shared weights produce identical transformations for identical inputs. Without per-iteration conditioning, all loop iterations compute the same function. This collapses depth recurrence to a single effective pass. PR #319 showed learned gating parameters collapse to identity. PR #484 found attention layers need per-iteration identity more than MLP. +Shared weights produce identical transformations for identical inputs. Without per-iteration conditioning, all loop iterations compute the same function. This collapses depth recurrence to a single effective pass. Dehghani et al. (2019, §2.1) addressed this with sinusoidal timestep embeddings added at each recurrence step. Xu & Sato (2025) formalized the limitation: without timestep encoding, looped transformers have strict approximation rate bounds (Lemma 4.1). @@ -70,14 +70,14 @@ This follows the HC → mHC → mHC-lite simplification chain. mHC (Xie et al., **Surprising finding.** Timestep scaling has near-zero effect on pre-quantization BPB (Run H vs I: 1.2578 vs 1.2580) but reduces quantization gap by 26–30% (H vs I: +0.0126 → +0.0088; J vs K: +0.0103 → +0.0076). The mechanism: capped gammas are stored as float16 passthrough parameters that bypass int8 quantization entirely. They provide per-iteration specialization that survives the quantization pipeline. In short, timestep scaling helps quantization, not training. -**Result.** Run K (best): post-quant 1.2659 BPB, Q-gap +0.0076 — vs prior results showing catastrophic failure at 3+ loops (PR #579, +4.3 BPB). +**Result.** Run K (best): post-quant 1.2659 BPB, Q-gap +0.0076 — vs prior 3-loop attempts that failed catastrophically (PR #363 measured ~900× quantization amplification over 3 cycles). > Xu, K. & Sato, I. (2025). "On Expressive Power of Looped Transformers: Theoretical Analysis and Enhancement via Timestep Encoding." ICML 2025. [arXiv:2410.01405](https://arxiv.org/abs/2410.01405) > Perez, E., Strub, F., de Vries, H., Dumoulin, V. 
& Courville, A. (2018). "FiLM: Visual Reasoning with a General Conditioning Layer." AAAI 2018. [arXiv:1709.07871](https://arxiv.org/abs/1709.07871) ## 5. Supporting Technique: Prelude-Recurrent-Coda Architecture -First and last transformer layers perform fundamentally different functions — input encoding and output prediction — compared to middle layers that do iterative refinement. Forcing boundary layers into shared weights compromises both functions. Geiping et al. (2025) demonstrated this at scale with Huginn 3.5B: 2 prelude + 4 shared (×32 loops) + 2 coda layers, achieving 132 effective depth from 3.5B parameters. In this competition, PR #575 independently explored prefix + 2 tied (×3) + suffix. +First and last transformer layers perform fundamentally different functions — input encoding and output prediction — compared to middle layers that do iterative refinement. Forcing boundary layers into shared weights compromises both functions. Geiping et al. (2025) demonstrated this at scale with Huginn 3.5B: 2 prelude + 4 shared (×32 loops) + 2 coda layers, achieving 132 effective depth from 3.5B parameters. **Result.** Run E (1+3×2+1, all fixes) vs Run D (4×2, all fixes): −0.016 BPB — the largest single architectural improvement in the ablation. Boundary layers need unique parameters. 
From 1cffe4d3daa3ce32a8a46477984abc5375280db8 Mon Sep 17 00:00:00 2001 From: Alexandr Azizyan Date: Tue, 31 Mar 2026 19:01:37 +0400 Subject: [PATCH 08/10] feat: add FiLM bias and attention-only sharing ablations (Series 3) --- .../README.md | 35 +- .../logs/s3_L.txt | 1828 +++++++++++++++++ .../logs/s3_M.txt | 1578 ++++++++++++++ .../logs/s3_N.txt | 1828 +++++++++++++++++ .../logs/s3_O.txt | 1771 ++++++++++++++++ .../research_notes.md | 77 + .../scripts/run_fullscale2.sh | 61 + .../train_gpt.py | 252 ++- 8 files changed, 7396 insertions(+), 34 deletions(-) create mode 100644 records/track_non_record_16mb/2026-03-26_RecurrenceFix_3Loop_Birkhoff_OutputLN_TimestepScale/logs/s3_L.txt create mode 100644 records/track_non_record_16mb/2026-03-26_RecurrenceFix_3Loop_Birkhoff_OutputLN_TimestepScale/logs/s3_M.txt create mode 100644 records/track_non_record_16mb/2026-03-26_RecurrenceFix_3Loop_Birkhoff_OutputLN_TimestepScale/logs/s3_N.txt create mode 100644 records/track_non_record_16mb/2026-03-26_RecurrenceFix_3Loop_Birkhoff_OutputLN_TimestepScale/logs/s3_O.txt create mode 100755 records/track_non_record_16mb/2026-03-26_RecurrenceFix_3Loop_Birkhoff_OutputLN_TimestepScale/scripts/run_fullscale2.sh diff --git a/records/track_non_record_16mb/2026-03-26_RecurrenceFix_3Loop_Birkhoff_OutputLN_TimestepScale/README.md b/records/track_non_record_16mb/2026-03-26_RecurrenceFix_3Loop_Birkhoff_OutputLN_TimestepScale/README.md index ced702ffc1..1e426cd848 100644 --- a/records/track_non_record_16mb/2026-03-26_RecurrenceFix_3Loop_Birkhoff_OutputLN_TimestepScale/README.md +++ b/records/track_non_record_16mb/2026-03-26_RecurrenceFix_3Loop_Birkhoff_OutputLN_TimestepScale/README.md @@ -13,6 +13,7 @@ | **Timestep scaling** | Per-iteration learned scale vectors (capped ±4.0), stored as float16 passthrough | Q-gap −26–30% | Helps quantization, not training | | **Prelude-coda** | Unique first/last layers, shared middle blocks (Huginn-style) | −0.016 BPB (screening) | Biggest single BPB win | | 
**LeakyReLU(0.5)²** | Preserves negative signal through quadratic activation | Adopted from SOTA | Necessary with Output-LN to avoid dead neurons | +| **FiLM bias** | Per-iteration shift vectors (β), completing scale+shift FiLM conditioning | −0.003 BPB | Clean free win, additive with gammas | ## Key Findings @@ -22,6 +23,9 @@ - **Output-LN is the critical technique.** Without it, mixing alphas collapse to ~0.48 (uniform) and MLP scale drops to 0.2–0.3. With it, alphas learn a meaningful gradient (0.37→0.70 across layers). - **Prelude-coda gives the biggest single improvement** (−0.016 BPB at screening). Boundary layers need unique parameters. - **3 loops are viable for the first time.** Q-gap +0.0076 at 3 loops, vs prior 3-loop attempts that failed catastrophically (PR #363 measured ~900× quantization amplification over 3 cycles). +- **Attention-only sharing (shared attention, unique MLPs per iteration) gives −0.026 post-Q BPB** — the largest improvement found — but costs +3.9MB artifact (14.65MB total), leaving insufficient room for SOTA features. +- **FiLM bias adds −0.003 BPB at both 2 and 3 loops** with zero artifact/throughput cost. Additive with timestep scaling gammas. +- **ALBERT (Lan et al., 2020) found attention sharing is nearly free while FFN sharing causes most degradation.** s3_L confirms: the model needs per-iteration MLP differentiation, not per-iteration attention differentiation. ## Techniques Applicable to Non-Recurrent Submissions @@ -47,6 +51,22 @@ Output-LN could benefit any submission using quadratic activations (relu², leak | J | 1+4×3+1 peri+birk (3 loops) | 14 | 1.2567 | 1.2670 | +0.0103 | | **K** | **1+4×3+1 peri+birk+ts(cap4)** | **14** | **1.2583** | **1.2659** | **+0.0076** | +## Series 3: Follow-up Ablations (600s, 8×H100) + +| Run | Config | Eff. 
Layers | Pre-Q BPB | Post-Q BPB | Q-Gap | +|-----|--------|-------------|-----------|------------|-------| +| L | 1+4×2+1 shared-attn + unique MLPs | 10 | 1.2333 | 1.2406 | +0.0073 | +| N | 1+4×2+1 full share + FiLM bias | 10 | 1.2555 | 1.2641 | +0.0086 | +| O | 1+4×3+1 full share + FiLM bias | 14 | 1.2547 | 1.2625 | +0.0078 | + +Run M (1+4×3+1 attn-only, 3 loops) crashed during torch.compile with 12 UniqueMLP modules. Works without compile (verified via smoke test). + +## Next Direction + +Run s3_L validated that per-iteration MLP differentiation is critical (−0.026 BPB), but unique MLPs per loop iteration are too expensive: 12 unique MLP modules add ~4MB to the artifact (14.65MB total), leaving only ~1.35MB headroom — insufficient for integrating SOTA features. The 3-loop variant (s3_M) also crashes torch.compile(fullgraph=True). + +The planned approach achieves per-iteration differentiation at ~110KB instead of ~12MB: per-iteration **unique input norms** (24KB) control what the shared MLP sees at each iteration, **learned depth embeddings** (14KB) provide positional identity, and **FiLM gammas + betas** (28KB) modulate residual contributions — all stored as FP16 passthrough. This leaves ~4.8MB headroom for SOTA feature integration while preserving the per-iteration specialization that s3_L showed is essential. 
+ ## How to Reproduce The submitted run (Run K) from repo root: @@ -69,6 +89,9 @@ bash scripts/run_screening.sh # Full-scale (5 runs, on 8×H100) bash scripts/run_fullscale.sh + +# Follow-up ablations (4 runs, on 8×H100) +bash scripts/run_fullscale2.sh ``` ### Environment Variables @@ -86,13 +109,21 @@ bash scripts/run_fullscale.sh | `LEAKY_RELU_SLOPE` | 0.5 | Negative slope for leaky relu² (0.0 = relu²) | ## Files + ``` ├── train_gpt.py # Modified training script ├── train_log.txt # Run K log (primary submission) ├── submission.json # Competition metadata ├── research_notes.md # Theory + citations -├── logs/ # All 12 run logs (s1_Ap–F screening, s2_G–K full-scale) -└── scripts/ # run_screening.sh, run_fullscale.sh +├── ablation2_report.md # Full comparison report (Series 1–3) +├── logs/ # All run logs +│ ├── s1_Ap–F.txt # Screening (7 runs) +│ ├── s2_G–K.txt # Full-scale (5 runs) +│ └── s3_L–O.txt # Follow-up ablations (4 runs) +└── scripts/ + ├── run_screening.sh # Series 1 screening + ├── run_fullscale.sh # Series 2 full-scale + └── run_fullscale2.sh # Series 3 follow-up ablations ``` ## Links diff --git a/records/track_non_record_16mb/2026-03-26_RecurrenceFix_3Loop_Birkhoff_OutputLN_TimestepScale/logs/s3_L.txt b/records/track_non_record_16mb/2026-03-26_RecurrenceFix_3Loop_Birkhoff_OutputLN_TimestepScale/logs/s3_L.txt new file mode 100644 index 0000000000..d86fa23743 --- /dev/null +++ b/records/track_non_record_16mb/2026-03-26_RecurrenceFix_3Loop_Birkhoff_OutputLN_TimestepScale/logs/s3_L.txt @@ -0,0 +1,1828 @@ +""" +The `train_gpt.py` and `train_gpt_mlx.py` scripts are intended as good launching-off points for new participants, not SOTA configs. We'll accept PRs that tune, improve, or simplify these scripts without significantly increasing complexity, but competitive submissions should stay in the `/records` folder. + +Hard stop: To keep readable for newcomers, let's make sure `train_gpt.py` and `train_gpt_mlx.py` never are longer than 1500 lines. 
+""" + +from __future__ import annotations + +import copy +import glob +import io +import math +import os +import random +import subprocess +import sys +import time +import uuid +import zlib +from pathlib import Path + +import numpy as np +import sentencepiece as spm +import torch +import torch.distributed as dist +import torch.nn.functional as F +from torch import Tensor, nn +from torch.nn.parallel import DistributedDataParallel as DDP + +# Flash Attention 3 via HuggingFace kernels (Hopper-optimized) +try: + from kernels import get_kernel + cap = torch.cuda.get_device_capability() + repo = "varunneal/flash-attention-3" if cap == (9, 0) else "kernels-community/flash-attn3" + fa3 = get_kernel(repo).flash_attn_interface + FA3_AVAILABLE = True +except Exception: + FA3_AVAILABLE = False + +# ----------------------------- +# HYPERPARAMETERS +# ----------------------------- +# Default Simple Baseline run: +# - 9 transformer blocks at width 512 +# - 8 attention heads with 4 KV heads (GQA) and 2x MLP expansion +# - vocab size 1024, sequence length 1024, tied embeddings +# - 524,288 train tokens per step for 20,000 iterations with a ~10 minute cap + +class Hyperparameters: + # Data paths are shard globs produced by the existing preprocessing pipeline. + data_path = os.environ.get("DATA_PATH", "./data/datasets/fineweb10B_sp1024") + train_files = os.path.join(data_path, "fineweb_train_*.bin") + val_files = os.path.join(data_path, "fineweb_val_*.bin") + tokenizer_path = os.environ.get("TOKENIZER_PATH", "./data/tokenizers/fineweb_1024_bpe.model") + run_id = os.environ.get("RUN_ID", str(uuid.uuid4())) + seed = int(os.environ.get("SEED", 1337)) + + # Validation cadence and batch size. Validation always uses the full fineweb_val split. + val_batch_size = int(os.environ.get("VAL_BATCH_SIZE", 524_288)) + val_loss_every = int(os.environ.get("VAL_LOSS_EVERY", 1000)) + train_log_every = int(os.environ.get("TRAIN_LOG_EVERY", 200)) + + # Training length. 
+ iterations = int(os.environ.get("ITERATIONS", 20000)) + warmdown_iters = int(os.environ.get("WARMDOWN_ITERS", 1200)) + warmup_steps = int(os.environ.get("WARMUP_STEPS", 20)) + train_batch_tokens = int(os.environ.get("TRAIN_BATCH_TOKENS", 524_288)) + train_seq_len = int(os.environ.get("TRAIN_SEQ_LEN", 1024)) + max_wallclock_seconds = float(os.environ.get("MAX_WALLCLOCK_SECONDS", 600.0)) + qk_gain_init = float(os.environ.get("QK_GAIN_INIT", 1.5)) + + # Model shape. + vocab_size = int(os.environ.get("VOCAB_SIZE", 1024)) + num_layers = int(os.environ.get("NUM_LAYERS", 9)) + num_kv_heads = int(os.environ.get("NUM_KV_HEADS", 4)) + model_dim = int(os.environ.get("MODEL_DIM", 512)) + num_heads = int(os.environ.get("NUM_HEADS", 8)) + mlp_mult = int(os.environ.get("MLP_MULT", 2)) + tie_embeddings = bool(int(os.environ.get("TIE_EMBEDDINGS", "1"))) + rope_base = float(os.environ.get("ROPE_BASE", 10000.0)) + logit_softcap = float(os.environ.get("LOGIT_SOFTCAP", 30.0)) + + # Recurrence + stability features. + num_prelude = int(os.environ.get("NUM_PRELUDE", 0)) + num_coda = int(os.environ.get("NUM_CODA", 0)) + num_shared = int(os.environ.get("NUM_SHARED", 0)) + num_loops = int(os.environ.get("NUM_LOOPS", 2)) + use_peri_norm = bool(int(os.environ.get("USE_PERI_NORM", "0"))) + use_birkhoff_mix = bool(int(os.environ.get("USE_BIRKHOFF_MIX", "0"))) + use_timestep_scale = bool(int(os.environ.get("USE_TIMESTEP_SCALE", "0"))) + leaky_relu_slope = float(os.environ.get("LEAKY_RELU_SLOPE", "0.5")) + timestep_gamma_max = float(os.environ.get("TIMESTEP_GAMMA_MAX", "0.0")) + use_timestep_bias = bool(int(os.environ.get("USE_TIMESTEP_BIAS", "0"))) + share_attn_only = bool(int(os.environ.get("SHARE_ATTN_ONLY", "0"))) + disable_compile = bool(int(os.environ.get("DISABLE_COMPILE", "0"))) + + # Optimizer hyperparameters. 
+ embed_lr = float(os.environ.get("EMBED_LR", 0.6)) + head_lr = float(os.environ.get("HEAD_LR", 0.008)) + tied_embed_lr = float(os.environ.get("TIED_EMBED_LR", 0.05)) + tied_embed_init_std = float(os.environ.get("TIED_EMBED_INIT_STD", 0.005)) + matrix_lr = float(os.environ.get("MATRIX_LR", 0.04)) + scalar_lr = float(os.environ.get("SCALAR_LR", 0.04)) + muon_momentum = float(os.environ.get("MUON_MOMENTUM", 0.95)) + muon_backend_steps = int(os.environ.get("MUON_BACKEND_STEPS", 5)) + muon_momentum_warmup_start = float(os.environ.get("MUON_MOMENTUM_WARMUP_START", 0.85)) + muon_momentum_warmup_steps = int(os.environ.get("MUON_MOMENTUM_WARMUP_STEPS", 500)) + beta1 = float(os.environ.get("BETA1", 0.9)) + beta2 = float(os.environ.get("BETA2", 0.95)) + adam_eps = float(os.environ.get("ADAM_EPS", 1e-8)) + grad_clip_norm = float(os.environ.get("GRAD_CLIP_NORM", 0.0)) + +# ----------------------------- +# MUON OPTIMIZER +# ----------------------------- +# +# As borrowed from modded-nanogpt +# Background on Muon: https://kellerjordan.github.io/posts/muon/ + +def zeropower_via_newtonschulz5(G: Tensor, steps: int = 10, eps: float = 1e-7) -> Tensor: + # Orthogonalize a 2D update matrix with a fast Newton-Schulz iteration. + # Muon uses this to normalize matrix-shaped gradients before applying them. 
+ a, b, c = (3.4445, -4.7750, 2.0315) + X = G.bfloat16() + X /= X.norm() + eps + transposed = G.size(0) > G.size(1) + if transposed: + X = X.T + for _ in range(steps): + A = X @ X.T + B = b * A + c * A @ A + X = a * X + B @ X + return X.T if transposed else X + + +class Muon(torch.optim.Optimizer): + def __init__(self, params, lr: float, momentum: float, backend_steps: int, nesterov: bool = True): + super().__init__( + params, + dict(lr=lr, momentum=momentum, backend_steps=backend_steps, nesterov=nesterov), + ) + + @torch.no_grad() + def step(self, closure=None): + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + distributed = dist.is_available() and dist.is_initialized() + world_size = dist.get_world_size() if distributed else 1 + rank = dist.get_rank() if distributed else 0 + + for group in self.param_groups: + params = group["params"] + if not params: + continue + lr = group["lr"] + momentum = group["momentum"] + backend_steps = group["backend_steps"] + nesterov = group["nesterov"] + + total_params = sum(int(p.numel()) for p in params) + updates_flat = torch.zeros(total_params, device=params[0].device, dtype=torch.bfloat16) + + curr = 0 + for i, p in enumerate(params): + if i % world_size == rank and p.grad is not None: + g = p.grad + state = self.state[p] + if "momentum_buffer" not in state: + state["momentum_buffer"] = torch.zeros_like(g) + buf = state["momentum_buffer"] + buf.mul_(momentum).add_(g) + if nesterov: + g = g.add(buf, alpha=momentum) + g = zeropower_via_newtonschulz5(g, steps=backend_steps) + # Scale correction from Muon reference implementations. 
+ g *= max(1, g.size(0) / g.size(1)) ** 0.5 + updates_flat[curr : curr + p.numel()] = g.reshape(-1) + curr += p.numel() + + if distributed: + dist.all_reduce(updates_flat, op=dist.ReduceOp.SUM) + + curr = 0 + for p in params: + g = updates_flat[curr : curr + p.numel()].view_as(p).to(dtype=p.dtype) + p.add_(g, alpha=-lr) + curr += p.numel() + + return loss + + +# ----------------------------- +# TOKENIZER-AGNOSTIC EVALUATION SETUP +# ----------------------------- +# +# It's common for small models have a large fraction of their parameters be embeddings, since the 2 * d_model * d_vocab vectors can be gigantic. +# Instead of locking the tokenizer, we let you bring your own and calculate our validation metrics on the average compression of the validation set. +# We calculate BPB (bits-per-byte) instead of validation loss, so we need methods to count the number of bits per token in the tokenizer. +# Note: Submissions that edit the tokenizer will be examined more carefully, since screwing this up might unjustly improve your score. 
+
+def build_sentencepiece_luts(
+    sp: spm.SentencePieceProcessor, vocab_size: int, device: torch.device
+) -> tuple[Tensor, Tensor, Tensor]:
+    """Precompute per-token lookup tables for the bits-per-byte metric.
+
+    Returns three tensors indexed by token id:
+      - base_bytes (int16): UTF-8 byte length of the decoded piece; byte tokens
+        count as 1, and a leading SentencePiece "▁" marker is stripped first.
+      - has_leading_space (bool): the piece carried the "▁" space marker.
+      - is_boundary_token (bool): control/unknown/unused ids (plus any ids above
+        the tokenizer's vocab), which contribute no bytes themselves.
+    Tables have max(sp.vocab_size(), vocab_size) rows so any model token id can
+    index them without bounds errors.
+    """
+    sp_vocab_size = int(sp.vocab_size())
+    table_size = max(sp_vocab_size, vocab_size)
+    base_bytes_np = np.zeros((table_size,), dtype=np.int16)
+    has_leading_space_np = np.zeros((table_size,), dtype=np.bool_)
+    is_boundary_token_np = np.ones((table_size,), dtype=np.bool_)
+    for token_id in range(sp_vocab_size):
+        if sp.is_control(token_id) or sp.is_unknown(token_id) or sp.is_unused(token_id):
+            continue
+        is_boundary_token_np[token_id] = False
+        if sp.is_byte(token_id):
+            base_bytes_np[token_id] = 1
+            continue
+        piece = sp.id_to_piece(token_id)
+        if piece.startswith("▁"):
+            has_leading_space_np[token_id] = True
+            piece = piece[1:]
+        base_bytes_np[token_id] = len(piece.encode("utf-8"))
+    return (
+        torch.tensor(base_bytes_np, dtype=torch.int16, device=device),
+        torch.tensor(has_leading_space_np, dtype=torch.bool, device=device),
+        torch.tensor(is_boundary_token_np, dtype=torch.bool, device=device),
+    )
+
+
+def load_validation_tokens(pattern: str, seq_len: int) -> Tensor:
+    """Concatenate all validation shards and trim to whole sequences.
+
+    Keeps usable + 1 tokens so (x, y) pairs can be built by a one-token shift.
+    Raises FileNotFoundError when the glob matches nothing, ValueError when the
+    split is shorter than one sequence.
+    """
+    files = [Path(p) for p in sorted(glob.glob(pattern))]
+    if not files:
+        raise FileNotFoundError(f"No files found for pattern: {pattern}")
+    # The export pipeline writes the fixed first-50k-doc validation set to fineweb_val_*.
+    tokens = torch.cat([load_data_shard(file) for file in files]).contiguous()
+    usable = ((tokens.numel() - 1) // seq_len) * seq_len
+    if usable <= 0:
+        raise ValueError(f"Validation split is too short for TRAIN_SEQ_LEN={seq_len}")
+    return tokens[: usable + 1]
+
+
+def eval_val(
+    args: Hyperparameters,
+    model: nn.Module,
+    rank: int,
+    world_size: int,
+    device: torch.device,
+    grad_accum_steps: int,
+    val_tokens: Tensor,
+    base_bytes_lut: Tensor,
+    has_leading_space_lut: Tensor,
+    is_boundary_token_lut: Tensor,
+) -> tuple[float, float]:
+    """Evaluate on the full validation split; returns (val_loss, val_bpb).
+
+    Sequences are partitioned contiguously across ranks; per-rank partial sums
+    are merged with all-reduce, so every rank returns the same numbers. The
+    model is restored to train() mode before returning.
+    """
+    # Validation computes two metrics:
+    #  - val_loss: token cross-entropy (natural log)
+    #  - val_bpb: tokenizer-agnostic compression metric used by the challenge
+    local_batch_tokens = args.val_batch_size // (world_size * grad_accum_steps)
+    if local_batch_tokens < args.train_seq_len:
+        raise ValueError(
+            "VAL_BATCH_SIZE must provide at least one sequence per rank; "
+            f"got VAL_BATCH_SIZE={args.val_batch_size}, WORLD_SIZE={world_size}, "
+            f"GRAD_ACCUM_STEPS={grad_accum_steps}, TRAIN_SEQ_LEN={args.train_seq_len}"
+        )
+    local_batch_seqs = local_batch_tokens // args.train_seq_len
+    total_seqs = (val_tokens.numel() - 1) // args.train_seq_len
+    # Contiguous [seq_start, seq_end) slice of sequences owned by this rank.
+    seq_start = (total_seqs * rank) // world_size
+    seq_end = (total_seqs * (rank + 1)) // world_size
+    # float64 accumulators: cheap at this size and immune to bf16 summation drift.
+    val_loss_sum = torch.zeros((), device=device, dtype=torch.float64)
+    val_token_count = torch.zeros((), device=device, dtype=torch.float64)
+    val_byte_count = torch.zeros((), device=device, dtype=torch.float64)
+
+    model.eval()
+    with torch.inference_mode():
+        for batch_seq_start in range(seq_start, seq_end, local_batch_seqs):
+            batch_seq_end = min(batch_seq_start + local_batch_seqs, seq_end)
+            raw_start = batch_seq_start * args.train_seq_len
+            # "+1" extra token allows the one-token shift that builds y from x.
+            raw_end = batch_seq_end * args.train_seq_len + 1
+            local = val_tokens[raw_start:raw_end].to(device=device, dtype=torch.int64, non_blocking=True)
+            x = local[:-1].reshape(-1, args.train_seq_len)
+            y = local[1:].reshape(-1, args.train_seq_len)
+            with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True):
+                batch_loss = model(x, y).detach()
+            batch_token_count = float(y.numel())
+            # Model returns a mean loss; re-weight by token count before summing.
+            val_loss_sum += batch_loss.to(torch.float64) * batch_token_count
+            val_token_count += batch_token_count
+            prev_ids = x.reshape(-1)
+            tgt_ids = y.reshape(-1)
+            # Bytes attributable to each target token: its base UTF-8 length plus
+            # one space byte when it carries "▁" and the previous token is not a
+            # boundary (control/unknown/unused) token.
+            token_bytes = base_bytes_lut[tgt_ids].to(dtype=torch.int16)
+            token_bytes += (has_leading_space_lut[tgt_ids] & ~is_boundary_token_lut[prev_ids]).to(dtype=torch.int16)
+            val_byte_count += token_bytes.to(torch.float64).sum()
+
+    if dist.is_available() and dist.is_initialized():
+        dist.all_reduce(val_loss_sum, op=dist.ReduceOp.SUM)
+        dist.all_reduce(val_token_count, op=dist.ReduceOp.SUM)
+        dist.all_reduce(val_byte_count, op=dist.ReduceOp.SUM)
+
+    val_loss = val_loss_sum / val_token_count
+    # bits/byte = (nats/token / ln 2) * tokens/byte.
+    bits_per_token = val_loss.item() / math.log(2.0)
+    tokens_per_byte = val_token_count.item() / val_byte_count.item()
+    model.train()
+    return float(val_loss.item()), float(bits_per_token * tokens_per_byte)
+
+# -----------------------------
+# POST-TRAINING QUANTIZATION
+# -----------------------------
+#
+# It's silly to export our model, which is trained in bf16 and fp32, at that same precision.
+# Instead, we get approximately the same model (with a small hit) by quantizing the model to int8 & zlib compressing.
+# We can then decompress the model and run in higher precision for evaluation, after coming in under the size limit.
+
+# Substring patterns naming small "control" tensors (gains, scales, residual
+# mixes). These stay in full fp32 on export because they are tiny but sensitive.
+CONTROL_TENSOR_NAME_PATTERNS = tuple(
+    pattern
+    for pattern in os.environ.get(
+        "CONTROL_TENSOR_NAME_PATTERNS",
+        "attn_scale,attn_scales,mlp_scale,mlp_scales,resid_mix,resid_mixes,q_gain,skip_weight,skip_weights,ts_attn,ts_mlp,resid_mix_logit,ts_attn_beta,ts_mlp_beta",
+    ).split(",")
+    if pattern
+)
+# Tensors matching these patterns bypass the fp16 downcast and stay fp32.
+INT8_KEEP_FLOAT_FP32_NAME_PATTERNS = tuple(
+    pattern
+    for pattern in os.environ.get(
+        "INT8_KEEP_FLOAT_FP32_NAME_PATTERNS",
+        ",".join(CONTROL_TENSOR_NAME_PATTERNS),
+    ).split(",")
+    if pattern
+)
+INT8_KEEP_FLOAT_MAX_NUMEL = 65_536  # floats at or below this size skip int8 quantization
+INT8_KEEP_FLOAT_STORE_DTYPE = torch.float16
+INT8_PER_ROW_SCALE_DTYPE = torch.float16
+# Symmetric clipping threshold: this quantile of |x| becomes the max representable value.
+INT8_CLIP_PERCENTILE = 99.99984
+INT8_CLIP_Q = INT8_CLIP_PERCENTILE / 100.0
+
+def tensor_nbytes(t: Tensor) -> int:
+    """Raw storage size of t in bytes (numel * element_size)."""
+    return int(t.numel()) * int(t.element_size())
+
+def keep_float_tensor(name: str, t: Tensor, passthrough_orig_dtypes: dict[str, str]) -> Tensor:
+    """Prepare a small float tensor for passthrough storage.
+
+    Control tensors stay fp32; other fp32/bf16 tensors are downcast to fp16 and
+    their original dtype recorded in passthrough_orig_dtypes for exact restore.
+    Anything else passes through untouched.
+    """
+    if any(pattern in name for pattern in INT8_KEEP_FLOAT_FP32_NAME_PATTERNS):
+        return t.float().contiguous()
+    if t.dtype in {torch.float32, torch.bfloat16}:
+        passthrough_orig_dtypes[name] = str(t.dtype).removeprefix("torch.")
+        return t.to(dtype=INT8_KEEP_FLOAT_STORE_DTYPE).contiguous()
+    return t
+
+def quantize_float_tensor(t: Tensor) -> tuple[Tensor, Tensor]:
+    """Symmetric int8 quantization; returns (int8 values, scale).
+
+    2D tensors get a per-row scale vector (stored fp16); everything else gets
+    a single per-tensor fp32 scalar. Values are clipped at the INT8_CLIP_Q
+    quantile of |x| before rounding into [-127, 127].
+    """
+    t32 = t.float()
+    if t32.ndim == 2:
+        # Matrices get one scale per row, which usually tracks output-channel
+        # ranges much better than a single tensor-wide scale.
+        clip_abs = (
+            torch.quantile(t32.abs(), INT8_CLIP_Q, dim=1)
+            if t32.numel()
+            else torch.empty((t32.shape[0],), dtype=torch.float32)
+        )
+        clipped = torch.maximum(torch.minimum(t32, clip_abs[:, None]), -clip_abs[:, None])
+        # clamp_min guards all-zero rows against a zero (divide-by-zero) scale.
+        scale = (clip_abs / 127.0).clamp_min(1.0 / 127.0)
+        q = torch.clamp(torch.round(clipped / scale[:, None]), -127, 127).to(torch.int8).contiguous()
+        return q, scale.to(dtype=INT8_PER_ROW_SCALE_DTYPE).contiguous()
+
+    # Vectors / scalars use a simpler per-tensor scale.
+    clip_abs = float(torch.quantile(t32.abs().flatten(), INT8_CLIP_Q).item()) if t32.numel() else 0.0
+    scale = torch.tensor(clip_abs / 127.0 if clip_abs > 0 else 1.0, dtype=torch.float32)
+    q = torch.clamp(torch.round(torch.clamp(t32, -clip_abs, clip_abs) / scale), -127, 127).to(torch.int8).contiguous()
+    return q, scale
+
+def quantize_state_dict_int8(state_dict: dict[str, Tensor]) -> tuple[dict[str, object], dict[str, int]]:
+    # Single supported clean-script export format:
+    #  - per-row int8 for 2D float tensors
+    #  - per-tensor int8 for other float tensors
+    #  - exact passthrough for non-floats
+    #  - passthrough for small float tensors, stored as fp16 to save bytes
+    # Returns (serializable payload dict, byte/tensor-count stats dict).
+    quantized: dict[str, Tensor] = {}
+    scales: dict[str, Tensor] = {}
+    dtypes: dict[str, str] = {}
+    passthrough: dict[str, Tensor] = {}
+    passthrough_orig_dtypes: dict[str, str] = {}
+    qmeta: dict[str, dict[str, object]] = {}
+    stats = dict.fromkeys(
+        ("param_count", "num_tensors", "num_float_tensors", "num_nonfloat_tensors", "baseline_tensor_bytes", "int8_payload_bytes"),
+        0,
+    )
+
+    for name, tensor in state_dict.items():
+        t = tensor.detach().to("cpu").contiguous()
+        stats["param_count"] += int(t.numel())
+        stats["num_tensors"] += 1
+        stats["baseline_tensor_bytes"] += tensor_nbytes(t)
+
+        # Integer tensors (buffers, ids) are stored exactly as-is.
+        if not t.is_floating_point():
+            stats["num_nonfloat_tensors"] += 1
+            passthrough[name] = t
+            stats["int8_payload_bytes"] += tensor_nbytes(t)
+            continue
+
+        # Small float tensors are cheap enough to keep directly. We still downcast
+        # fp32/bf16 passthrough tensors to fp16 so metadata does not dominate size.
+        if t.numel() <= INT8_KEEP_FLOAT_MAX_NUMEL:
+            kept = keep_float_tensor(name, t, passthrough_orig_dtypes)
+            passthrough[name] = kept
+            stats["int8_payload_bytes"] += tensor_nbytes(kept)
+            continue
+
+        stats["num_float_tensors"] += 1
+        q, s = quantize_float_tensor(t)
+        # A non-scalar scale marks the per-row scheme for the dequantizer.
+        if s.ndim > 0:
+            qmeta[name] = {"scheme": "per_row", "axis": 0}
+        quantized[name] = q
+        scales[name] = s
+        dtypes[name] = str(t.dtype).removeprefix("torch.")
+        stats["int8_payload_bytes"] += tensor_nbytes(q) + tensor_nbytes(s)
+
+    obj: dict[str, object] = {
+        "__quant_format__": "int8_clean_per_row_v1",
+        "quantized": quantized,
+        "scales": scales,
+        "dtypes": dtypes,
+        "passthrough": passthrough,
+    }
+    if qmeta:
+        obj["qmeta"] = qmeta
+    if passthrough_orig_dtypes:
+        obj["passthrough_orig_dtypes"] = passthrough_orig_dtypes
+    return obj, stats
+
+def dequantize_state_dict_int8(obj: dict[str, object]) -> dict[str, Tensor]:
+    """Invert quantize_state_dict_int8: rebuild a float state dict from payload."""
+    out: dict[str, Tensor] = {}
+    qmeta = obj.get("qmeta", {})
+    passthrough_orig_dtypes = obj.get("passthrough_orig_dtypes", {})
+    for name, q in obj["quantized"].items():
+        dtype = getattr(torch, obj["dtypes"][name])
+        s = obj["scales"][name]
+        if qmeta.get(name, {}).get("scheme") == "per_row" or s.ndim > 0:
+            s = s.to(dtype=torch.float32)
+            # Broadcast the saved row scale back across trailing dimensions.
+            out[name] = (q.float() * s.view(q.shape[0], *([1] * (q.ndim - 1)))).to(dtype=dtype).contiguous()
+        else:
+            scale = float(s.item())
+            out[name] = (q.float() * scale).to(dtype=dtype).contiguous()
+    for name, t in obj["passthrough"].items():
+        # Restore small tensors, undoing the temporary fp16 storage cast if needed.
+        out_t = t.detach().to("cpu").contiguous()
+        orig_dtype = passthrough_orig_dtypes.get(name)
+        if isinstance(orig_dtype, str):
+            out_t = out_t.to(dtype=getattr(torch, orig_dtype)).contiguous()
+        out[name] = out_t
+    return out
+
+
+# -----------------------------
+# DATA LOADING
+# -----------------------------
+
+# NOTE(review): the patch text below is truncated — everything between the
+# '"<' of np.dtype("<...") and the '> None:' of TokenStream._advance_file's
+# signature was eaten (likely by tag stripping), fusing load_data_shard's
+# first line with _advance_file's def line and dropping the shard-header
+# parsing body plus the TokenStream class head. Recover this region from the
+# original source file; it is reproduced verbatim here.
+def load_data_shard(file: Path) -> Tensor:
+    header_bytes = 256 * np.dtype(" None:
+        self.file_idx = (self.file_idx + 1) % len(self.files)
+        self.tokens = load_data_shard(self.files[self.file_idx])
+        self.pos = 0
+
+    def take(self, n: int) -> Tensor:
+        """Return the next n tokens, transparently wrapping across shard files."""
+        chunks: list[Tensor] = []
+        remaining = n
+        while remaining > 0:
+            avail = self.tokens.numel() - self.pos
+            if avail <= 0:
+                self._advance_file()
+                continue
+            k = min(remaining, avail)
+            chunks.append(self.tokens[self.pos : self.pos + k])
+            self.pos += k
+            remaining -= k
+        # Single-chunk fast path avoids a copy; multi-chunk falls back to cat.
+        return chunks[0] if len(chunks) == 1 else torch.cat(chunks)
+
+
+class DistributedTokenLoader:
+    # Each call consumes a contiguous chunk from the shared token stream, then slices out
+    # one disjoint span per rank. The extra "+1" token lets us build (x, y) by shifting.
+    def __init__(self, pattern: str, rank: int, world_size: int, device: torch.device):
+        # Every rank constructs the same TokenStream, so all ranks stay in sync
+        # by consuming identical chunk boundaries from the shared shard order.
+        self.rank = rank
+        self.world_size = world_size
+        self.device = device
+        self.stream = TokenStream(pattern)
+
+    def next_batch(self, global_tokens: int, seq_len: int, grad_accum_steps: int) -> tuple[Tensor, Tensor]:
+        """Return this rank's (x, y) micro-batch, each shaped (seqs, seq_len).
+
+        Consumes (local_tokens + 1) * world_size tokens from the stream; each
+        rank keeps its own disjoint span, shifted by one token to build targets.
+        """
+        local_tokens = global_tokens // (self.world_size * grad_accum_steps)
+        per_rank_span = local_tokens + 1
+        chunk = self.stream.take(per_rank_span * self.world_size)
+        start = self.rank * per_rank_span
+        local = chunk[start : start + per_rank_span].to(dtype=torch.int64)
+        x = local[:-1].reshape(-1, seq_len)
+        y = local[1:].reshape(-1, seq_len)
+        return x.to(self.device, non_blocking=True), y.to(self.device, non_blocking=True)
+
+# -----------------------------
+# TRANSFORMER MODULES
+# -----------------------------
+
+class RMSNorm(nn.Module):
+    """Parameter-free RMSNorm over the last dimension (thin F.rms_norm wrapper)."""
+
+    def __init__(self, eps: float | None = None):
+        super().__init__()
+        # eps=None defers to F.rms_norm's default epsilon.
+        self.eps = eps
+
+    def forward(self, x: Tensor) -> Tensor:
+        return F.rms_norm(x, (x.size(-1),), eps=self.eps)
+
+
+class CastedLinear(nn.Linear):
+    # Keep weights in fp32 for optimizer/state quality, cast at matmul time for bf16 compute.
+    def forward(self, x: Tensor) -> Tensor:
+        # Cast weight (and bias, when present) to the activation dtype per call.
+        bias = self.bias.to(x.dtype) if self.bias is not None else None
+        return F.linear(x, self.weight.to(x.dtype), bias)
+
+
+def restore_low_dim_params_to_fp32(module: nn.Module) -> None:
+    # Keep small/control parameters in fp32 even when the model body runs in bf16.
+    # Applies to any parameter with ndim < 2 or whose name matches one of the
+    # CONTROL_TENSOR_NAME_PATTERNS substrings.
+    with torch.no_grad():
+        for name, param in module.named_parameters():
+            if (param.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS)) and param.dtype != torch.float32:
+                param.data = param.data.float()
+
+
+class Rotary(nn.Module):
+    # Caches cos/sin tables per sequence length on the current device.
+ def __init__(self, dim: int, base: float = 10000.0): + super().__init__() + inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim)) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self._seq_len_cached = 0 + self._cos_cached: Tensor | None = None + self._sin_cached: Tensor | None = None + + def forward(self, seq_len: int, device: torch.device, dtype: torch.dtype) -> tuple[Tensor, Tensor]: + if ( + self._cos_cached is None + or self._sin_cached is None + or self._seq_len_cached != seq_len + or self._cos_cached.device != device + ): + t = torch.arange(seq_len, device=device, dtype=self.inv_freq.dtype) + freqs = torch.outer(t, self.inv_freq.to(device)) + self._cos_cached = freqs.cos()[None, None, :, :] + self._sin_cached = freqs.sin()[None, None, :, :] + self._seq_len_cached = seq_len + return self._cos_cached.to(dtype=dtype), self._sin_cached.to(dtype=dtype) + + +def apply_rotary_emb(x: Tensor, cos: Tensor, sin: Tensor) -> Tensor: + half = x.size(-1) // 2 + x1, x2 = x[..., :half], x[..., half:] + return torch.cat((x1 * cos + x2 * sin, x1 * (-sin) + x2 * cos), dim=-1) + + +class CausalSelfAttention(nn.Module): + def __init__( + self, + dim: int, + num_heads: int, + num_kv_heads: int, + rope_base: float, + qk_gain_init: float, + ): + super().__init__() + if dim % num_heads != 0: + raise ValueError("model_dim must be divisible by num_heads") + if num_heads % num_kv_heads != 0: + raise ValueError("num_heads must be divisible by num_kv_heads") + self.num_heads = num_heads + self.num_kv_heads = num_kv_heads + self.head_dim = dim // num_heads + if self.head_dim % 2 != 0: + raise ValueError("head_dim must be even for RoPE") + kv_dim = self.num_kv_heads * self.head_dim + self.c_q = CastedLinear(dim, dim, bias=False) + self.c_k = CastedLinear(dim, kv_dim, bias=False) + self.c_v = CastedLinear(dim, kv_dim, bias=False) + self.proj = CastedLinear(dim, dim, bias=False) + self.proj._zero_init = True + self.q_gain = 
nn.Parameter(torch.full((num_heads,), qk_gain_init, dtype=torch.float32)) + self.rotary = Rotary(self.head_dim, base=rope_base) + + def forward(self, x: Tensor) -> Tensor: + bsz, seqlen, dim = x.shape + q = self.c_q(x).reshape(bsz, seqlen, self.num_heads, self.head_dim).transpose(1, 2) + k = self.c_k(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) + v = self.c_v(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) + q = F.rms_norm(q, (q.size(-1),)) + k = F.rms_norm(k, (k.size(-1),)) + cos, sin = self.rotary(seqlen, x.device, q.dtype) + q = apply_rotary_emb(q, cos, sin) + k = apply_rotary_emb(k, cos, sin) + q = q * self.q_gain.to(dtype=q.dtype)[None, :, None, None] + if FA3_AVAILABLE: + # FA3 expects (B, T, H, D) — transpose from (B, H, T, D) + y = fa3.flash_attn_func( + q.transpose(1, 2), + k.transpose(1, 2), + v.transpose(1, 2), + causal=True, + ).transpose(1, 2) + else: + y = F.scaled_dot_product_attention( + q, k, v, + attn_mask=None, + is_causal=True, + enable_gqa=(self.num_kv_heads != self.num_heads), + ) + y = y.transpose(1, 2).contiguous().reshape(bsz, seqlen, dim) + return self.proj(y) + + +class MLP(nn.Module): + # leaky_relu^2 MLP (slope=0.5 matches SOTA, PRs #493/#518) + def __init__(self, dim: int, mlp_mult: int, leaky_relu_slope: float = 0.5): + super().__init__() + hidden = mlp_mult * dim + self.fc = CastedLinear(dim, hidden, bias=False) + self.proj = CastedLinear(hidden, dim, bias=False) + self.proj._zero_init = True + self.leaky_relu_slope = leaky_relu_slope + + def forward(self, x: Tensor) -> Tensor: + x = F.leaky_relu(self.fc(x), negative_slope=self.leaky_relu_slope) + return self.proj(x.square()) + + +class TimestepScaling(nn.Module): + """Learned per-iteration scale vectors (arXiv:2410.01405, ICML 2025).""" + def __init__(self, effective_layers: int, dim: int, gamma_max: float = 0.0, use_bias: bool = False): + super().__init__() + self.attn_gamma = nn.Parameter(torch.ones(effective_layers, dim, 
dtype=torch.float32)) + self.mlp_gamma = nn.Parameter(torch.ones(effective_layers, dim, dtype=torch.float32)) + self.gamma_max = gamma_max # 0 = uncapped + if use_bias: + self.attn_beta = nn.Parameter(torch.zeros(effective_layers, dim, dtype=torch.float32)) + self.mlp_beta = nn.Parameter(torch.zeros(effective_layers, dim, dtype=torch.float32)) + else: + self.attn_beta = None + self.mlp_beta = None + + def get(self, v: int) -> tuple[Tensor, Tensor, Tensor | None, Tensor | None]: + ag = self.attn_gamma[v] + mg = self.mlp_gamma[v] + if self.gamma_max > 0: + ag = ag.clamp(-self.gamma_max, self.gamma_max) + mg = mg.clamp(-self.gamma_max, self.gamma_max) + ab = self.attn_beta[v] if self.attn_beta is not None else None + mb = self.mlp_beta[v] if self.mlp_beta is not None else None + return ag, mg, ab, mb + + +class SharedAttnLayer(nn.Module): + """Shared attention layer (mixing + attention only, no MLP) for attn-only sharing mode.""" + def __init__( + self, + dim: int, + num_heads: int, + num_kv_heads: int, + rope_base: float, + qk_gain_init: float, + use_birkhoff_mix: bool = False, + ): + super().__init__() + self.use_birkhoff_mix = use_birkhoff_mix + self.attn_norm = RMSNorm() + self.attn = CausalSelfAttention(dim, num_heads, num_kv_heads, rope_base, qk_gain_init) + self.attn_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + if use_birkhoff_mix: + self.resid_mix_logit = nn.Parameter(torch.zeros(dim, dtype=torch.float32)) + else: + self.resid_mix = nn.Parameter(torch.stack((torch.ones(dim), torch.zeros(dim))).float()) + + def forward(self, x: Tensor, x0: Tensor, + ts_attn_gamma: Tensor | None = None, + ts_attn_beta: Tensor | None = None) -> Tensor: + if self.use_birkhoff_mix: + alpha = torch.sigmoid(self.resid_mix_logit.to(dtype=x.dtype))[None, None, :] + x = alpha * x + (1 - alpha) * x0 + else: + mix = self.resid_mix.to(dtype=x.dtype) + x = mix[0][None, None, :] * x + mix[1][None, None, :] * x0 + attn_out = self.attn(self.attn_norm(x)) + attn_s = 
self.attn_scale.to(dtype=x.dtype)[None, None, :] + if ts_attn_gamma is not None: + attn_s = attn_s * ts_attn_gamma[None, None, :] + x = x + attn_s * attn_out + if ts_attn_beta is not None: + x = x + ts_attn_beta[None, None, :] + return x + + +class UniqueMLP(nn.Module): + """Unique MLP per virtual shared position for attn-only sharing mode.""" + def __init__( + self, + dim: int, + mlp_mult: int, + use_peri_norm: bool = False, + leaky_relu_slope: float = 0.5, + ): + super().__init__() + self.use_peri_norm = use_peri_norm + if use_peri_norm: + self.mlp_out_norm = RMSNorm() + else: + self.mlp_norm = RMSNorm() + self.mlp = MLP(dim, mlp_mult, leaky_relu_slope) + self.mlp_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + + def forward(self, x: Tensor, + ts_mlp_gamma: Tensor | None = None, + ts_mlp_beta: Tensor | None = None) -> Tensor: + mlp_out = self.mlp_out_norm(self.mlp(x)) if self.use_peri_norm else self.mlp(self.mlp_norm(x)) + mlp_s = self.mlp_scale.to(dtype=x.dtype)[None, None, :] + if ts_mlp_gamma is not None: + mlp_s = mlp_s * ts_mlp_gamma[None, None, :] + out = mlp_s * mlp_out + if ts_mlp_beta is not None: + out = out + ts_mlp_beta[None, None, :] + return out + + +class Block(nn.Module): + def __init__( + self, + dim: int, + num_heads: int, + num_kv_heads: int, + mlp_mult: int, + rope_base: float, + qk_gain_init: float, + use_peri_norm: bool = False, + use_birkhoff_mix: bool = False, + leaky_relu_slope: float = 0.5, + ): + super().__init__() + self.use_peri_norm = use_peri_norm + self.use_birkhoff_mix = use_birkhoff_mix + self.attn_norm = RMSNorm() + if use_peri_norm: + self.mlp_out_norm = RMSNorm() # output norm: bounds MLP contribution, preserves input magnitude + else: + self.mlp_norm = RMSNorm() + self.attn = CausalSelfAttention(dim, num_heads, num_kv_heads, rope_base, qk_gain_init) + self.mlp = MLP(dim, mlp_mult, leaky_relu_slope) + self.attn_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.mlp_scale = 
nn.Parameter(torch.ones(dim, dtype=torch.float32)) + if use_birkhoff_mix: + self.resid_mix_logit = nn.Parameter(torch.zeros(dim, dtype=torch.float32)) + else: + self.resid_mix = nn.Parameter(torch.stack((torch.ones(dim), torch.zeros(dim))).float()) + + def forward(self, x: Tensor, x0: Tensor, + ts_attn_gamma: Tensor | None = None, + ts_mlp_gamma: Tensor | None = None, + ts_attn_beta: Tensor | None = None, + ts_mlp_beta: Tensor | None = None) -> Tensor: + if self.use_birkhoff_mix: + alpha = torch.sigmoid(self.resid_mix_logit.to(dtype=x.dtype))[None, None, :] + x = alpha * x + (1 - alpha) * x0 + else: + mix = self.resid_mix.to(dtype=x.dtype) + x = mix[0][None, None, :] * x + mix[1][None, None, :] * x0 + attn_out = self.attn(self.attn_norm(x)) + attn_s = self.attn_scale.to(dtype=x.dtype)[None, None, :] + if ts_attn_gamma is not None: + attn_s = attn_s * ts_attn_gamma[None, None, :] + x = x + attn_s * attn_out + if ts_attn_beta is not None: + x = x + ts_attn_beta[None, None, :] + mlp_out = self.mlp_out_norm(self.mlp(x)) if self.use_peri_norm else self.mlp(self.mlp_norm(x)) + mlp_s = self.mlp_scale.to(dtype=x.dtype)[None, None, :] + if ts_mlp_gamma is not None: + mlp_s = mlp_s * ts_mlp_gamma[None, None, :] + x = x + mlp_s * mlp_out + if ts_mlp_beta is not None: + x = x + ts_mlp_beta[None, None, :] + return x + + +class GPT(nn.Module): + def __init__( + self, + vocab_size: int, + num_layers: int, + model_dim: int, + num_heads: int, + num_kv_heads: int, + mlp_mult: int, + tie_embeddings: bool, + tied_embed_init_std: float, + logit_softcap: float, + rope_base: float, + qk_gain_init: float, + num_prelude: int = 0, + num_coda: int = 0, + num_shared: int = 0, + num_loops: int = 2, + use_peri_norm: bool = False, + use_birkhoff_mix: bool = False, + use_timestep_scale: bool = False, + use_timestep_bias: bool = False, + share_attn_only: bool = False, + timestep_gamma_max: float = 0.0, + leaky_relu_slope: float = 0.5, + ): + super().__init__() + if logit_softcap <= 0.0: + raise 
ValueError(f"logit_softcap must be positive, got {logit_softcap}") + self.tie_embeddings = tie_embeddings + self.tied_embed_init_std = tied_embed_init_std + self.logit_softcap = logit_softcap + self.tok_emb = nn.Embedding(vocab_size, model_dim) + self.use_recurrence = num_shared > 0 + self.num_loops = num_loops if self.use_recurrence else 1 + self.num_prelude = num_prelude if self.use_recurrence else 0 + self.num_coda = num_coda if self.use_recurrence else 0 + + block_kwargs = dict( + dim=model_dim, num_heads=num_heads, num_kv_heads=num_kv_heads, + mlp_mult=mlp_mult, rope_base=rope_base, qk_gain_init=qk_gain_init, + use_peri_norm=use_peri_norm if self.use_recurrence else False, + use_birkhoff_mix=use_birkhoff_mix if self.use_recurrence else False, + leaky_relu_slope=leaky_relu_slope, + ) + + self.share_attn_only = share_attn_only if self.use_recurrence else False + + if self.use_recurrence: + self.prelude_blocks = nn.ModuleList([Block(**block_kwargs) for _ in range(self.num_prelude)]) + self.coda_blocks = nn.ModuleList([Block(**block_kwargs) for _ in range(self.num_coda)]) + if self.share_attn_only: + shared_attn_kwargs = dict( + dim=model_dim, num_heads=num_heads, num_kv_heads=num_kv_heads, + rope_base=rope_base, qk_gain_init=qk_gain_init, + use_birkhoff_mix=use_birkhoff_mix, + ) + unique_mlp_kwargs = dict( + dim=model_dim, mlp_mult=mlp_mult, + use_peri_norm=use_peri_norm, + leaky_relu_slope=leaky_relu_slope, + ) + self.shared_attn_layers = nn.ModuleList([SharedAttnLayer(**shared_attn_kwargs) for _ in range(num_shared)]) + self.unique_mlps = nn.ModuleList([UniqueMLP(**unique_mlp_kwargs) for _ in range(num_shared * self.num_loops)]) + self.shared_blocks = nn.ModuleList() # empty — keeps diagnostics safe + else: + self.shared_blocks = nn.ModuleList([Block(**block_kwargs) for _ in range(num_shared)]) + self.shared_attn_layers = nn.ModuleList() + self.unique_mlps = nn.ModuleList() + effective_layers = self.num_prelude + num_shared * self.num_loops + self.num_coda + 
self.timestep_scale = TimestepScaling(effective_layers, model_dim, gamma_max=timestep_gamma_max, use_bias=use_timestep_bias) if use_timestep_scale else None + else: + # Standard U-Net path + self.num_encoder_layers = num_layers // 2 + self.num_decoder_layers = num_layers - self.num_encoder_layers + self.num_skip_weights = min(self.num_encoder_layers, self.num_decoder_layers) + self.skip_weights = nn.Parameter(torch.ones(self.num_skip_weights, model_dim, dtype=torch.float32)) + self.blocks = nn.ModuleList([Block(**block_kwargs) for _ in range(num_layers)]) + self.timestep_scale = None + + self.final_norm = RMSNorm() + self.lm_head = None if tie_embeddings else CastedLinear(model_dim, vocab_size, bias=False) + if self.lm_head is not None: + self.lm_head._zero_init = True + self._init_weights() + + def _init_weights(self) -> None: + if self.tie_embeddings: + nn.init.normal_(self.tok_emb.weight, mean=0.0, std=self.tied_embed_init_std) + for module in self.modules(): + if isinstance(module, nn.Linear) and getattr(module, "_zero_init", False): + nn.init.zeros_(module.weight) + + def _get_ts(self, v: int) -> tuple[Tensor | None, Tensor | None, Tensor | None, Tensor | None]: + if self.timestep_scale is None: + return None, None, None, None + return self.timestep_scale.get(v) + + def forward(self, input_ids: Tensor, target_ids: Tensor) -> Tensor: + x = self.tok_emb(input_ids) + x = F.rms_norm(x, (x.size(-1),)) + x0 = x + + if self.use_recurrence: + v = 0 + for block in self.prelude_blocks: + ag, mg, ab, mb = self._get_ts(v) + x = block(x, x0, ag, mg, ab, mb) + v += 1 + if self.share_attn_only: + vid = 0 + for _loop in range(self.num_loops): + for attn_layer in self.shared_attn_layers: + ag, mg, ab, mb = self._get_ts(v) + x = attn_layer(x, x0, ag, ab) + x = x + self.unique_mlps[vid](x, mg, mb) + vid += 1 + v += 1 + else: + for _loop in range(self.num_loops): + for block in self.shared_blocks: + ag, mg, ab, mb = self._get_ts(v) + x = block(x, x0, ag, mg, ab, mb) + v += 1 + 
for block in self.coda_blocks: + ag, mg, ab, mb = self._get_ts(v) + x = block(x, x0, ag, mg, ab, mb) + v += 1 + else: + skips: list[Tensor] = [] + for i in range(self.num_encoder_layers): + x = self.blocks[i](x, x0) + skips.append(x) + for i in range(self.num_decoder_layers): + if skips: + x = x + self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips.pop() + x = self.blocks[self.num_encoder_layers + i](x, x0) + + x = self.final_norm(x).reshape(-1, x.size(-1)) + targets = target_ids.reshape(-1) + if self.tie_embeddings: + logits_proj = F.linear(x, self.tok_emb.weight) + else: + if self.lm_head is None: + raise RuntimeError("lm_head is required when tie_embeddings=False") + logits_proj = self.lm_head(x) + logits = self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap) + return F.cross_entropy(logits.float(), targets, reduction="mean") + + +@torch.no_grad() +def recurrence_param_diagnostics(gpt: GPT) -> str: + """Log parameter norms for recurrence health monitoring. No forward pass.""" + if not gpt.use_recurrence: + return "" + parts: list[str] = [] + + # Birkhoff alpha stats per shared block/layer + if gpt.share_attn_only: + for i, layer in enumerate(gpt.shared_attn_layers): + if hasattr(layer, "resid_mix_logit"): + a = torch.sigmoid(layer.resid_mix_logit) + parts.append(f"shared{i}_alpha:mean={a.mean().item():.3f},std={a.std().item():.3f}") + else: + for i, block in enumerate(gpt.shared_blocks): + if hasattr(block, "resid_mix_logit"): + a = torch.sigmoid(block.resid_mix_logit) + parts.append(f"shared{i}_alpha:mean={a.mean().item():.3f},std={a.std().item():.3f}") + + # Effective MLP/attn contribution scale per virtual layer position: + # mlp_scale * ts_mlp_gamma[v] is the learned multiplier on MLP output. + # We report RMS (norm / sqrt(numel)) to make it scale-independent. 
+ v = 0 + effective_count = gpt.num_prelude + len(gpt.shared_blocks if not gpt.share_attn_only else gpt.shared_attn_layers) * gpt.num_loops + gpt.num_coda + mlp_norms: list[str] = [] + attn_norms: list[str] = [] + + # Prelude blocks + for block in gpt.prelude_blocks: + ms = block.mlp_scale.norm().item() + asc = block.attn_scale.norm().item() + d = block.mlp_scale.numel() ** 0.5 + if gpt.timestep_scale is not None: + ms_g = gpt.timestep_scale.mlp_gamma[v].norm().item() + as_g = gpt.timestep_scale.attn_gamma[v].norm().item() + mlp_norms.append(f"v{v}:{ms * ms_g / d:.4f}") + attn_norms.append(f"v{v}:{asc * as_g / d:.4f}") + else: + mlp_norms.append(f"v{v}:{ms / d:.4f}") + attn_norms.append(f"v{v}:{asc / d:.4f}") + v += 1 + + # Shared positions + if gpt.share_attn_only: + vid = 0 + for _loop in range(gpt.num_loops): + for layer in gpt.shared_attn_layers: + asc = layer.attn_scale.norm().item() + ms = gpt.unique_mlps[vid].mlp_scale.norm().item() + d = layer.attn_scale.numel() ** 0.5 + if gpt.timestep_scale is not None: + ms_g = gpt.timestep_scale.mlp_gamma[v].norm().item() + as_g = gpt.timestep_scale.attn_gamma[v].norm().item() + mlp_norms.append(f"v{v}:{ms * ms_g / d:.4f}") + attn_norms.append(f"v{v}:{asc * as_g / d:.4f}") + else: + mlp_norms.append(f"v{v}:{ms / d:.4f}") + attn_norms.append(f"v{v}:{asc / d:.4f}") + vid += 1 + v += 1 + else: + for _loop in range(gpt.num_loops): + for block in gpt.shared_blocks: + ms = block.mlp_scale.norm().item() + asc = block.attn_scale.norm().item() + d = block.mlp_scale.numel() ** 0.5 + if gpt.timestep_scale is not None: + ms_g = gpt.timestep_scale.mlp_gamma[v].norm().item() + as_g = gpt.timestep_scale.attn_gamma[v].norm().item() + mlp_norms.append(f"v{v}:{ms * ms_g / d:.4f}") + attn_norms.append(f"v{v}:{asc * as_g / d:.4f}") + else: + mlp_norms.append(f"v{v}:{ms / d:.4f}") + attn_norms.append(f"v{v}:{asc / d:.4f}") + v += 1 + + # Coda blocks + for block in gpt.coda_blocks: + ms = block.mlp_scale.norm().item() + asc = 
block.attn_scale.norm().item() + d = block.mlp_scale.numel() ** 0.5 + if gpt.timestep_scale is not None: + ms_g = gpt.timestep_scale.mlp_gamma[v].norm().item() + as_g = gpt.timestep_scale.attn_gamma[v].norm().item() + mlp_norms.append(f"v{v}:{ms * ms_g / d:.4f}") + attn_norms.append(f"v{v}:{asc * as_g / d:.4f}") + else: + mlp_norms.append(f"v{v}:{ms / d:.4f}") + attn_norms.append(f"v{v}:{asc / d:.4f}") + v += 1 + + parts.append("eff_mlp_scale:[" + " ".join(mlp_norms) + "]") + parts.append("eff_attn_scale:[" + " ".join(attn_norms) + "]") + if gpt.timestep_scale is not None and gpt.timestep_scale.attn_beta is not None: + attn_bias_norms: list[str] = [] + mlp_bias_norms: list[str] = [] + for vi in range(effective_count): + ab_rms = gpt.timestep_scale.attn_beta[vi].norm().item() / gpt.timestep_scale.attn_beta[vi].numel() ** 0.5 + mb_rms = gpt.timestep_scale.mlp_beta[vi].norm().item() / gpt.timestep_scale.mlp_beta[vi].numel() ** 0.5 + attn_bias_norms.append(f"v{vi}:{ab_rms:.4f}") + mlp_bias_norms.append(f"v{vi}:{mb_rms:.4f}") + parts.append("eff_attn_bias:[" + " ".join(attn_bias_norms) + "]") + parts.append("eff_mlp_bias:[" + " ".join(mlp_bias_norms) + "]") + return " ".join(parts) + + +# ----------------------------- +# TRAINING +# ----------------------------- + +def main() -> None: + global zeropower_via_newtonschulz5 + + code = Path(__file__).read_text(encoding="utf-8") + args = Hyperparameters() + zeropower_via_newtonschulz5 = torch.compile(zeropower_via_newtonschulz5) + + # ----------------------------- + # DISTRIBUTED + CUDA SETUP + # ----------------------------- + + distributed = "RANK" in os.environ and "WORLD_SIZE" in os.environ + rank = int(os.environ.get("RANK", "0")) + world_size = int(os.environ.get("WORLD_SIZE", "1")) + local_rank = int(os.environ.get("LOCAL_RANK", "0")) + if world_size <= 0: + raise ValueError(f"WORLD_SIZE must be positive, got {world_size}") + if 8 % world_size != 0: + raise ValueError(f"WORLD_SIZE={world_size} must divide 8 so 
grad_accum_steps stays integral") + grad_accum_steps = 8 // world_size + grad_scale = 1.0 / grad_accum_steps + if not torch.cuda.is_available(): + raise RuntimeError("CUDA is required") + device = torch.device("cuda", local_rank) + torch.cuda.set_device(device) + if distributed: + dist.init_process_group(backend="nccl", device_id=device) + dist.barrier() + master_process = rank == 0 + + # Fast math knobs + torch.backends.cuda.matmul.allow_tf32 = True + torch.backends.cudnn.allow_tf32 = True + from torch.backends.cuda import enable_cudnn_sdp, enable_flash_sdp, enable_math_sdp, enable_mem_efficient_sdp + + enable_cudnn_sdp(False) + enable_flash_sdp(True) + enable_mem_efficient_sdp(False) + enable_math_sdp(False) + + logfile = None + if master_process: + os.makedirs("logs", exist_ok=True) + logfile = f"logs/{args.run_id}.txt" + print(logfile) + + def log0(msg: str, console: bool = True) -> None: + if not master_process: + return + if console: + print(msg) + if logfile is not None: + with open(logfile, "a", encoding="utf-8") as f: + print(msg, file=f) + + log0(code, console=False) + log0("=" * 100, console=False) + log0(f"Running Python {sys.version}", console=False) + log0(f"Running PyTorch {torch.__version__}", console=False) + log0( + subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, check=False).stdout, + console=False, + ) + log0("=" * 100, console=False) + + # ----------------------------- + # TOKENIZER + VALIDATION METRIC SETUP + # ----------------------------- + + random.seed(args.seed) + np.random.seed(args.seed) + torch.manual_seed(args.seed) + torch.cuda.manual_seed_all(args.seed) + + if not args.tokenizer_path.endswith(".model"): + raise ValueError(f"Script only setup for SentencePiece .model file: {args.tokenizer_path}") + sp = spm.SentencePieceProcessor(model_file=args.tokenizer_path) + if int(sp.vocab_size()) != args.vocab_size: + raise ValueError( + f"VOCAB_SIZE={args.vocab_size} does not match tokenizer 
vocab_size={int(sp.vocab_size())}" + ) + dataset_dir = Path(args.data_path).resolve() + actual_train_files = len(list(dataset_dir.glob("fineweb_train_*.bin"))) + val_tokens = load_validation_tokens(args.val_files, args.train_seq_len) + base_bytes_lut, has_leading_space_lut, is_boundary_token_lut = build_sentencepiece_luts( + sp, args.vocab_size, device + ) + log0(f"val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path={args.tokenizer_path}") + log0(f"train_loader:dataset:{dataset_dir.name} train_shards:{actual_train_files}") + log0(f"val_loader:shards pattern={args.val_files} tokens:{val_tokens.numel() - 1}") + + # ----------------------------- + # MODEL + OPTIMIZER SETUP + # ----------------------------- + + base_model = GPT( + vocab_size=args.vocab_size, + num_layers=args.num_layers, + model_dim=args.model_dim, + num_heads=args.num_heads, + num_kv_heads=args.num_kv_heads, + mlp_mult=args.mlp_mult, + tie_embeddings=args.tie_embeddings, + tied_embed_init_std=args.tied_embed_init_std, + logit_softcap=args.logit_softcap, + rope_base=args.rope_base, + qk_gain_init=args.qk_gain_init, + num_prelude=args.num_prelude, + num_coda=args.num_coda, + num_shared=args.num_shared, + num_loops=args.num_loops, + use_peri_norm=args.use_peri_norm, + use_birkhoff_mix=args.use_birkhoff_mix, + use_timestep_scale=args.use_timestep_scale, + use_timestep_bias=args.use_timestep_bias, + share_attn_only=args.share_attn_only, + timestep_gamma_max=args.timestep_gamma_max, + leaky_relu_slope=args.leaky_relu_slope, + ).to(device).bfloat16() + for module in base_model.modules(): + if isinstance(module, CastedLinear): + module.float() + restore_low_dim_params_to_fp32(base_model) + if args.disable_compile: + compiled_model = base_model + else: + compiled_model = torch.compile(base_model, dynamic=False, fullgraph=True) + model: nn.Module = DDP(compiled_model, device_ids=[local_rank], broadcast_buffers=False) if distributed else compiled_model + + # Optimizer split: + # - token embedding (Adam) 
uses EMBED_LR + # - untied lm_head (Adam) uses HEAD_LR + # - matrix params in transformer blocks use MATRIX_LR via Muon + # - vectors/scalars use SCALAR_LR via Adam + if base_model.use_recurrence: + all_block_lists = [base_model.prelude_blocks, base_model.shared_blocks, base_model.coda_blocks] + block_named_params = [] + for bl in all_block_lists: + block_named_params.extend(bl.named_parameters()) + if base_model.share_attn_only: + block_named_params.extend(base_model.shared_attn_layers.named_parameters()) + block_named_params.extend(base_model.unique_mlps.named_parameters()) + else: + block_named_params = list(base_model.blocks.named_parameters()) + matrix_params = [ + p + for name, p in block_named_params + if p.ndim == 2 and not any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) + ] + scalar_params = [ + p + for name, p in block_named_params + if p.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) + ] + if hasattr(base_model, "skip_weights") and base_model.skip_weights.numel() > 0: + scalar_params.append(base_model.skip_weights) + if base_model.timestep_scale is not None: + for p in base_model.timestep_scale.parameters(): + scalar_params.append(p) + token_lr = args.tied_embed_lr if args.tie_embeddings else args.embed_lr + optimizer_tok = torch.optim.Adam( + [{"params": [base_model.tok_emb.weight], "lr": token_lr, "base_lr": token_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=True, + ) + optimizer_muon = Muon( + matrix_params, + lr=args.matrix_lr, + momentum=args.muon_momentum, + backend_steps=args.muon_backend_steps, + ) + for group in optimizer_muon.param_groups: + group["base_lr"] = args.matrix_lr + optimizer_scalar = torch.optim.Adam( + [{"params": scalar_params, "lr": args.scalar_lr, "base_lr": args.scalar_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=True, + ) + optimizers: list[torch.optim.Optimizer] = [optimizer_tok, optimizer_muon, optimizer_scalar] + if 
base_model.lm_head is not None: + optimizer_head = torch.optim.Adam( + [{"params": [base_model.lm_head.weight], "lr": args.head_lr, "base_lr": args.head_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=True, + ) + optimizers.insert(1, optimizer_head) + + n_params = sum(p.numel() for p in base_model.parameters()) + log0(f"model_params:{n_params}") + log0(f"world_size:{world_size} grad_accum_steps:{grad_accum_steps}") + log0("sdp_backends:cudnn=False flash=True mem_efficient=False math=False") + log0(f"flash_attention_3:{FA3_AVAILABLE}") + log0(f"attention_mode:gqa num_heads:{args.num_heads} num_kv_heads:{args.num_kv_heads}") + log0( + f"tie_embeddings:{args.tie_embeddings} embed_lr:{token_lr} " + f"head_lr:{args.head_lr if base_model.lm_head is not None else 0.0} " + f"matrix_lr:{args.matrix_lr} scalar_lr:{args.scalar_lr}" + ) + log0( + f"train_batch_tokens:{args.train_batch_tokens} train_seq_len:{args.train_seq_len} " + f"iterations:{args.iterations} warmup_steps:{args.warmup_steps} " + f"max_wallclock_seconds:{args.max_wallclock_seconds:.3f}" + ) + log0(f"seed:{args.seed}") + if base_model.use_recurrence: + num_shared = len(base_model.shared_attn_layers) if base_model.share_attn_only else len(base_model.shared_blocks) + eff = base_model.num_prelude + num_shared * base_model.num_loops + base_model.num_coda + shared_label = f"shared_attn:{num_shared}" if base_model.share_attn_only else f"shared:{num_shared}" + log0(f"recurrence:enabled prelude:{base_model.num_prelude} {shared_label} " + f"loops:{base_model.num_loops} coda:{base_model.num_coda} effective_layers:{eff}") + log0(f"peri_norm:{args.use_peri_norm} birkhoff_mix:{args.use_birkhoff_mix}") + if base_model.timestep_scale is not None: + ts_params = sum(p.numel() for p in base_model.timestep_scale.parameters()) + log0(f"timestep_scale:enabled params:{ts_params}") + else: + log0("timestep_scale:disabled") + else: + log0(f"recurrence:disabled num_layers:{args.num_layers}") + compile_mode = 
"disabled" if args.disable_compile else "fullgraph=True" + log0(f"compile_mode:{compile_mode}") + + # ----------------------------- + # DATA LOADER & MODEL WARMUP + # ----------------------------- + + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + + def zero_grad_all() -> None: + for opt in optimizers: + opt.zero_grad(set_to_none=True) + + max_wallclock_ms = 1000.0 * args.max_wallclock_seconds if args.max_wallclock_seconds > 0 else None + + def lr_mul(step: int, elapsed_ms: float) -> float: + if args.warmdown_iters <= 0: + return 1.0 + if max_wallclock_ms is None: + warmdown_start = max(args.iterations - args.warmdown_iters, 0) + return max((args.iterations - step) / max(args.warmdown_iters, 1), 0.0) if warmdown_start <= step < args.iterations else 1.0 + step_ms = elapsed_ms / max(step, 1) + warmdown_ms = args.warmdown_iters * step_ms + remaining_ms = max(max_wallclock_ms - elapsed_ms, 0.0) + return remaining_ms / max(warmdown_ms, 1e-9) if remaining_ms <= warmdown_ms else 1.0 + + # Warmup primes the compiled forward/backward/optimizer paths, then we restore the + # initial weights/optimizer state so measured training starts from the true init. 
+ if args.warmup_steps > 0: + initial_model_state = {name: tensor.detach().cpu().clone() for name, tensor in base_model.state_dict().items()} + initial_optimizer_states = [copy.deepcopy(opt.state_dict()) for opt in optimizers] + model.train() + for warmup_step in range(args.warmup_steps): + zero_grad_all() + for micro_step in range(grad_accum_steps): + if distributed: + model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 + x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + warmup_loss = model(x, y) + (warmup_loss * grad_scale).backward() + for opt in optimizers: + opt.step() + zero_grad_all() + if args.warmup_steps <= 20 or (warmup_step + 1) % 10 == 0 or warmup_step + 1 == args.warmup_steps: + log0(f"warmup_step:{warmup_step + 1}/{args.warmup_steps}") + base_model.load_state_dict(initial_model_state, strict=True) + for opt, state in zip(optimizers, initial_optimizer_states, strict=True): + opt.load_state_dict(state) + zero_grad_all() + if distributed: + model.require_backward_grad_sync = True + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + + # ----------------------------- + # MAIN TRAINING LOOP + # ----------------------------- + + training_time_ms = 0.0 + stop_after_step: int | None = None + torch.cuda.synchronize() + t0 = time.perf_counter() + + step = 0 + while True: + last_step = step == args.iterations or (stop_after_step is not None and step >= stop_after_step) + + should_validate = last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0) + if should_validate: + torch.cuda.synchronize() + training_time_ms += 1000.0 * (time.perf_counter() - t0) + val_loss, val_bpb = eval_val( + args, + model, + rank, + world_size, + device, + grad_accum_steps, + val_tokens, + base_bytes_lut, + has_leading_space_lut, + is_boundary_token_lut, + ) + log0( + 
f"step:{step}/{args.iterations} val_loss:{val_loss:.4f} val_bpb:{val_bpb:.4f} " + f"train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms / max(step, 1):.2f}ms" + ) + torch.cuda.synchronize() + t0 = time.perf_counter() + + if last_step: + if stop_after_step is not None and step < args.iterations: + log0( + f"stopping_early: wallclock_cap train_time:{training_time_ms:.0f}ms " + f"step:{step}/{args.iterations}" + ) + break + + elapsed_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + scale = lr_mul(step, elapsed_ms) + zero_grad_all() + train_loss = torch.zeros((), device=device) + for micro_step in range(grad_accum_steps): + if distributed: + model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 + x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + loss = model(x, y) + train_loss += loss.detach() + (loss * grad_scale).backward() + train_loss /= grad_accum_steps + + frac = min(step / args.muon_momentum_warmup_steps, 1.0) if args.muon_momentum_warmup_steps > 0 else 1.0 + muon_momentum = (1 - frac) * args.muon_momentum_warmup_start + frac * args.muon_momentum + for group in optimizer_muon.param_groups: + group["momentum"] = muon_momentum + + for opt in optimizers: + for group in opt.param_groups: + group["lr"] = group["base_lr"] * scale + + if args.grad_clip_norm > 0: + torch.nn.utils.clip_grad_norm_(base_model.parameters(), args.grad_clip_norm) + for opt in optimizers: + opt.step() + zero_grad_all() + + step += 1 + approx_training_time_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + should_log_train = ( + args.train_log_every > 0 + and (step <= 10 or step % args.train_log_every == 0 or stop_after_step is not None) + ) + if should_log_train: + log0( + f"step:{step}/{args.iterations} train_loss:{train_loss.item():.4f} " + f"train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms 
/ step:.2f}ms" + ) + if base_model.use_recurrence and step % args.train_log_every == 0: + diag = recurrence_param_diagnostics(base_model) + if diag: + log0(f"step:{step} {diag}") + + # Needed to sync whether we've reached the wallclock cap. + reached_cap = max_wallclock_ms is not None and approx_training_time_ms >= max_wallclock_ms + if distributed and max_wallclock_ms is not None: + reached_cap_tensor = torch.tensor(int(reached_cap), device=device) + dist.all_reduce(reached_cap_tensor, op=dist.ReduceOp.MAX) + reached_cap = bool(reached_cap_tensor.item()) + if stop_after_step is None and reached_cap: + stop_after_step = step + + log0( + f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB " + f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB" + ) + + # ----------------------------- + # SERIALIZATION + ROUNDTRIP VALIDATION + # ----------------------------- + # Save the raw state (useful for debugging/loading in PyTorch directly), then always produce + # the compressed int8+zlib artifact and validate the round-tripped weights. 
+ + if master_process: + torch.save(base_model.state_dict(), "final_model.pt") + model_bytes = os.path.getsize("final_model.pt") + code_bytes = len(code.encode("utf-8")) + log0(f"Serialized model: {model_bytes} bytes") + log0(f"Code size: {code_bytes} bytes") + log0(f"Total submission size: {model_bytes + code_bytes} bytes") + + quant_obj, quant_stats = quantize_state_dict_int8(base_model.state_dict()) + quant_buf = io.BytesIO() + torch.save(quant_obj, quant_buf) + quant_raw = quant_buf.getvalue() + quant_blob = zlib.compress(quant_raw, level=9) + quant_raw_bytes = len(quant_raw) + if master_process: + with open("final_model.int8.ptz", "wb") as f: + f.write(quant_blob) + quant_file_bytes = os.path.getsize("final_model.int8.ptz") + code_bytes = len(code.encode("utf-8")) + ratio = quant_stats["baseline_tensor_bytes"] / max(quant_stats["int8_payload_bytes"], 1) + log0( + f"Serialized model int8+zlib: {quant_file_bytes} bytes " + f"(payload:{quant_stats['int8_payload_bytes']} raw_torch:{quant_raw_bytes} payload_ratio:{ratio:.2f}x)" + ) + log0(f"Total submission size int8+zlib: {quant_file_bytes + code_bytes} bytes") + + if distributed: + dist.barrier() + with open("final_model.int8.ptz", "rb") as f: + quant_blob_disk = f.read() + quant_state = torch.load(io.BytesIO(zlib.decompress(quant_blob_disk)), map_location="cpu") + base_model.load_state_dict(dequantize_state_dict_int8(quant_state), strict=True) + torch.cuda.synchronize() + t_qeval = time.perf_counter() + q_val_loss, q_val_bpb = eval_val( + args, + model, + rank, + world_size, + device, + grad_accum_steps, + val_tokens, + base_bytes_lut, + has_leading_space_lut, + is_boundary_token_lut, + ) + torch.cuda.synchronize() + log0( + f"final_int8_zlib_roundtrip val_loss:{q_val_loss:.4f} val_bpb:{q_val_bpb:.4f} " + f"eval_time:{1000.0 * (time.perf_counter() - t_qeval):.0f}ms" + ) + log0(f"final_int8_zlib_roundtrip_exact val_loss:{q_val_loss:.8f} val_bpb:{q_val_bpb:.8f}") + + if distributed: + dist.destroy_process_group() 
+ + +if __name__ == "__main__": + main() + +==================================================================================================== +Running Python 3.12.3 (main, Nov 6 2025, 13:44:16) [GCC 13.3.0] +Running PyTorch 2.9.1+cu128 +Mon Mar 30 15:28:45 2026 ++-----------------------------------------------------------------------------------------+ +| NVIDIA-SMI 580.126.09 Driver Version: 580.126.09 CUDA Version: 13.0 | ++-----------------------------------------+------------------------+----------------------+ +| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC | +| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. | +| | | MIG M. | +|=========================================+========================+======================| +| 0 NVIDIA H100 80GB HBM3 On | 00000000:19:00.0 Off | 0 | +| N/A 34C P0 118W / 700W | 1521MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 1 NVIDIA H100 80GB HBM3 On | 00000000:3B:00.0 Off | 0 | +| N/A 32C P0 115W / 700W | 1521MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 2 NVIDIA H100 80GB HBM3 On | 00000000:4C:00.0 Off | 0 | +| N/A 31C P0 119W / 700W | 1521MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 3 NVIDIA H100 80GB HBM3 On | 00000000:5D:00.0 Off | 0 | +| N/A 33C P0 118W / 700W | 1521MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 4 NVIDIA H100 80GB HBM3 On | 00000000:9B:00.0 Off | 0 | +| N/A 34C P0 119W / 700W | 1521MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 5 NVIDIA H100 80GB HBM3 On | 00000000:BB:00.0 Off | 0 | +| N/A 32C 
P0 117W / 700W | 1521MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 6 NVIDIA H100 80GB HBM3 On | 00000000:CB:00.0 Off | 0 | +| N/A 32C P0 116W / 700W | 1521MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 7 NVIDIA H100 80GB HBM3 On | 00000000:DB:00.0 Off | 0 | +| N/A 30C P0 118W / 700W | 1521MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ + ++-----------------------------------------------------------------------------------------+ +| Processes: | +| GPU GI CI PID Type Process name GPU Memory | +| ID ID Usage | +|=========================================================================================| +| No running processes found | ++-----------------------------------------------------------------------------------------+ + +==================================================================================================== +val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path=./data/tokenizers/fineweb_1024_bpe.model +train_loader:dataset:fineweb10B_sp1024 train_shards:80 +val_loader:shards pattern=./data/datasets/fineweb10B_sp1024/fineweb_val_*.bin tokens:62021632 +model_params:15750192 +world_size:8 grad_accum_steps:1 +sdp_backends:cudnn=False flash=True mem_efficient=False math=False +flash_attention_3:True +attention_mode:gqa num_heads:8 num_kv_heads:4 +tie_embeddings:True embed_lr:0.05 head_lr:0.0 matrix_lr:0.04 scalar_lr:0.04 +train_batch_tokens:524288 train_seq_len:1024 iterations:20000 warmup_steps:20 max_wallclock_seconds:600.000 +seed:1337 +recurrence:enabled prelude:1 shared_attn:4 loops:2 coda:1 effective_layers:10 +peri_norm:True birkhoff_mix:True +timestep_scale:enabled params:10240 +compile_mode:fullgraph=True +warmup_step:1/20 +warmup_step:2/20 
+warmup_step:3/20 +warmup_step:4/20 +warmup_step:5/20 +warmup_step:6/20 +warmup_step:7/20 +warmup_step:8/20 +warmup_step:9/20 +warmup_step:10/20 +warmup_step:11/20 +warmup_step:12/20 +warmup_step:13/20 +warmup_step:14/20 +warmup_step:15/20 +warmup_step:16/20 +warmup_step:17/20 +warmup_step:18/20 +warmup_step:19/20 +warmup_step:20/20 +step:0/20000 val_loss:6.9370 val_bpb:4.1085 train_time:0ms step_avg:0.02ms +step:1/20000 train_loss:6.9360 train_time:32ms step_avg:32.34ms +step:2/20000 train_loss:9.5197 train_time:77ms step_avg:38.44ms +step:3/20000 train_loss:7.3187 train_time:120ms step_avg:39.90ms +step:4/20000 train_loss:8.9283 train_time:161ms step_avg:40.33ms +step:5/20000 train_loss:8.8027 train_time:203ms step_avg:40.59ms +step:6/20000 train_loss:8.1629 train_time:244ms step_avg:40.72ms +step:7/20000 train_loss:6.6962 train_time:286ms step_avg:40.86ms +step:8/20000 train_loss:6.0630 train_time:328ms step_avg:41.04ms +step:9/20000 train_loss:5.5920 train_time:370ms step_avg:41.10ms +step:10/20000 train_loss:5.2862 train_time:410ms step_avg:41.00ms +step:200/20000 train_loss:2.7705 train_time:8457ms step_avg:42.28ms +step:200 shared0_alpha:mean=0.484,std=0.041 shared1_alpha:mean=0.486,std=0.037 shared2_alpha:mean=0.498,std=0.039 shared3_alpha:mean=0.527,std=0.041 eff_mlp_scale:[v0:28.3880 v1:24.0799 v2:29.2892 v3:27.9932 v4:26.7853 v5:25.5424 v6:31.0116 v7:35.1224 v8:36.9525 v9:55.7187] eff_attn_scale:[v0:13.5735 v1:8.6653 v2:11.3481 v3:10.1426 v4:9.1944 v5:10.3333 v6:10.7553 v7:9.9011 v8:9.6639 v9:15.2824] +step:200/20000 val_loss:2.7662 val_bpb:1.6383 train_time:8503ms step_avg:42.51ms +step:400/20000 train_loss:2.3561 train_time:16988ms step_avg:42.47ms +step:400 shared0_alpha:mean=0.510,std=0.042 shared1_alpha:mean=0.524,std=0.040 shared2_alpha:mean=0.530,std=0.043 shared3_alpha:mean=0.559,std=0.045 eff_mlp_scale:[v0:31.8850 v1:27.8652 v2:36.7028 v3:36.8398 v4:36.4799 v5:34.9039 v6:38.9597 v7:41.9912 v8:42.0571 v9:70.8682] eff_attn_scale:[v0:6.1725 
v1:4.5493 v2:5.7020 v3:5.3801 v4:4.7731 v5:5.6867 v6:5.3368 v7:4.8935 v8:4.9050 v9:8.9557] +step:400/20000 val_loss:2.5595 val_bpb:1.5159 train_time:17000ms step_avg:42.50ms +step:600/20000 train_loss:2.5661 train_time:25492ms step_avg:42.49ms +step:600 shared0_alpha:mean=0.524,std=0.046 shared1_alpha:mean=0.545,std=0.044 shared2_alpha:mean=0.551,std=0.046 shared3_alpha:mean=0.579,std=0.048 eff_mlp_scale:[v0:34.2921 v1:30.8115 v2:41.2858 v3:41.1313 v4:42.7592 v5:40.7883 v6:44.1080 v7:45.8696 v8:45.3634 v9:84.9876] eff_attn_scale:[v0:2.9338 v1:2.5483 v2:3.2053 v3:3.3108 v4:2.9506 v5:3.2849 v6:3.1069 v7:2.9364 v8:3.0883 v9:5.9348] +step:600/20000 val_loss:2.4710 val_bpb:1.4635 train_time:25504ms step_avg:42.51ms +step:800/20000 train_loss:2.3250 train_time:34015ms step_avg:42.52ms +step:800 shared0_alpha:mean=0.533,std=0.051 shared1_alpha:mean=0.558,std=0.048 shared2_alpha:mean=0.562,std=0.049 shared3_alpha:mean=0.589,std=0.051 eff_mlp_scale:[v0:37.1737 v1:34.2931 v2:45.0908 v3:44.2544 v4:47.6881 v5:45.3110 v6:47.7324 v7:48.9970 v8:48.8171 v9:97.2529] eff_attn_scale:[v0:1.8170 v1:1.7692 v2:2.2562 v3:2.4720 v4:2.1590 v5:2.3749 v6:2.2404 v7:2.1630 v8:2.2899 v9:4.5124] +step:800/20000 val_loss:2.4096 val_bpb:1.4271 train_time:34027ms step_avg:42.53ms +step:1000/20000 train_loss:2.4029 train_time:42561ms step_avg:42.56ms +step:1000 shared0_alpha:mean=0.538,std=0.055 shared1_alpha:mean=0.567,std=0.050 shared2_alpha:mean=0.570,std=0.051 shared3_alpha:mean=0.595,std=0.053 eff_mlp_scale:[v0:40.1535 v1:37.3806 v2:49.0057 v3:47.2219 v4:52.1159 v5:49.0925 v6:51.2211 v7:52.0917 v8:52.0288 v9:106.7244] eff_attn_scale:[v0:1.3829 v1:1.4232 v2:1.8084 v3:2.0602 v4:1.7670 v5:1.9300 v6:1.8362 v7:1.7773 v8:1.8985 v9:3.7188] +step:1000/20000 val_loss:2.3672 val_bpb:1.4020 train_time:42574ms step_avg:42.57ms +step:1200/20000 train_loss:2.4194 train_time:51106ms step_avg:42.59ms +step:1200 shared0_alpha:mean=0.539,std=0.058 shared1_alpha:mean=0.574,std=0.052 
shared2_alpha:mean=0.576,std=0.054 shared3_alpha:mean=0.598,std=0.055 eff_mlp_scale:[v0:42.5497 v1:40.1006 v2:52.1604 v3:50.0465 v4:55.9540 v5:52.6658 v6:54.4203 v7:55.2098 v8:55.1591 v9:116.8508] eff_attn_scale:[v0:1.1190 v1:1.1772 v2:1.5333 v3:1.7682 v4:1.5126 v5:1.6406 v6:1.5782 v7:1.5405 v8:1.6476 v9:3.2626] +step:1200/20000 val_loss:2.3369 val_bpb:1.3840 train_time:51118ms step_avg:42.60ms +step:1400/20000 train_loss:2.4671 train_time:59651ms step_avg:42.61ms +step:1400 shared0_alpha:mean=0.542,std=0.063 shared1_alpha:mean=0.579,std=0.054 shared2_alpha:mean=0.580,std=0.056 shared3_alpha:mean=0.600,std=0.057 eff_mlp_scale:[v0:45.8068 v1:42.8228 v2:55.7806 v3:53.0418 v4:59.4204 v5:55.7295 v6:57.4700 v7:57.8427 v8:58.4009 v9:125.4551] eff_attn_scale:[v0:0.9616 v1:1.0284 v2:1.3442 v3:1.5946 v4:1.3630 v5:1.4633 v6:1.4229 v7:1.3850 v8:1.5051 v9:2.9855] +step:1400/20000 val_loss:2.3156 val_bpb:1.3714 train_time:59663ms step_avg:42.62ms +step:1600/20000 train_loss:2.1379 train_time:68203ms step_avg:42.63ms +step:1600 shared0_alpha:mean=0.542,std=0.065 shared1_alpha:mean=0.583,std=0.055 shared2_alpha:mean=0.583,std=0.057 shared3_alpha:mean=0.601,std=0.058 eff_mlp_scale:[v0:48.1216 v1:45.2145 v2:59.0290 v3:55.5124 v4:62.6978 v5:58.7741 v6:59.9760 v7:60.2620 v8:60.9913 v9:132.8601] eff_attn_scale:[v0:0.8449 v1:0.9510 v2:1.2181 v3:1.4752 v4:1.2428 v5:1.3505 v6:1.2993 v7:1.2741 v8:1.3789 v9:2.7726] +step:1600/20000 val_loss:2.3014 val_bpb:1.3630 train_time:68215ms step_avg:42.63ms +step:1800/20000 train_loss:2.2380 train_time:76760ms step_avg:42.64ms +step:1800 shared0_alpha:mean=0.543,std=0.068 shared1_alpha:mean=0.587,std=0.055 shared2_alpha:mean=0.586,std=0.059 shared3_alpha:mean=0.600,std=0.059 eff_mlp_scale:[v0:50.3525 v1:47.5188 v2:62.2251 v3:58.4682 v4:65.8700 v5:61.3062 v6:62.6660 v7:63.3270 v8:64.1538 v9:140.4091] eff_attn_scale:[v0:0.7518 v1:0.8570 v2:1.1300 v3:1.3748 v4:1.1462 v5:1.2349 v6:1.2031 v7:1.1919 v8:1.2895 v9:2.5929] +step:1800/20000 val_loss:2.2853 
val_bpb:1.3535 train_time:76771ms step_avg:42.65ms +step:2000/20000 train_loss:2.2872 train_time:85309ms step_avg:42.65ms +step:2000 shared0_alpha:mean=0.543,std=0.069 shared1_alpha:mean=0.590,std=0.056 shared2_alpha:mean=0.587,std=0.061 shared3_alpha:mean=0.599,std=0.060 eff_mlp_scale:[v0:53.0668 v1:50.1805 v2:64.9040 v3:61.0745 v4:68.7187 v5:63.8040 v6:65.1654 v7:65.9291 v8:66.6957 v9:146.1840] eff_attn_scale:[v0:0.6805 v1:0.7953 v2:1.0555 v3:1.3069 v4:1.0770 v5:1.1568 v6:1.1262 v7:1.1284 v8:1.2217 v9:2.4250] +step:2000/20000 val_loss:2.2709 val_bpb:1.3450 train_time:85322ms step_avg:42.66ms +step:2200/20000 train_loss:2.1115 train_time:93865ms step_avg:42.67ms +step:2200 shared0_alpha:mean=0.543,std=0.072 shared1_alpha:mean=0.594,std=0.057 shared2_alpha:mean=0.590,std=0.062 shared3_alpha:mean=0.599,std=0.060 eff_mlp_scale:[v0:54.9497 v1:51.8206 v2:68.0756 v3:63.5284 v4:71.4948 v5:66.5054 v6:68.2298 v7:68.6605 v8:69.4061 v9:153.7187] eff_attn_scale:[v0:0.6316 v1:0.7569 v2:0.9871 v3:1.2402 v4:1.0359 v5:1.0927 v6:1.0778 v7:1.0986 v8:1.1838 v9:2.3459] +step:2200/20000 val_loss:2.2623 val_bpb:1.3398 train_time:93878ms step_avg:42.67ms +step:2400/20000 train_loss:2.2323 train_time:102402ms step_avg:42.67ms +step:2400 shared0_alpha:mean=0.542,std=0.073 shared1_alpha:mean=0.597,std=0.058 shared2_alpha:mean=0.591,std=0.063 shared3_alpha:mean=0.597,std=0.061 eff_mlp_scale:[v0:57.7260 v1:54.1711 v2:70.8756 v3:66.0999 v4:74.2740 v5:68.3935 v6:70.3284 v7:71.3192 v8:72.1222 v9:160.2638] eff_attn_scale:[v0:0.5943 v1:0.7100 v2:0.9295 v3:1.1891 v4:0.9896 v5:1.0407 v6:1.0334 v7:1.0453 v8:1.1286 v9:2.2392] +step:2400/20000 val_loss:2.2511 val_bpb:1.3333 train_time:102414ms step_avg:42.67ms +step:2600/20000 train_loss:2.4541 train_time:110942ms step_avg:42.67ms +step:2600 shared0_alpha:mean=0.542,std=0.074 shared1_alpha:mean=0.600,std=0.058 shared2_alpha:mean=0.593,std=0.064 shared3_alpha:mean=0.595,std=0.062 eff_mlp_scale:[v0:59.6185 v1:56.3688 v2:73.7807 v3:68.6433 v4:76.9368 
v5:71.1393 v6:72.8048 v7:73.4723 v8:74.4310 v9:166.4584] eff_attn_scale:[v0:0.5791 v1:0.6938 v2:0.9218 v3:1.1346 v4:0.9490 v5:1.0070 v6:0.9991 v7:0.9941 v8:1.1081 v9:2.1999] +step:2600/20000 val_loss:2.2690 val_bpb:1.3438 train_time:110955ms step_avg:42.67ms +step:2800/20000 train_loss:2.2707 train_time:119476ms step_avg:42.67ms +step:2800 shared0_alpha:mean=0.541,std=0.075 shared1_alpha:mean=0.603,std=0.059 shared2_alpha:mean=0.595,std=0.065 shared3_alpha:mean=0.594,std=0.062 eff_mlp_scale:[v0:62.0038 v1:58.1266 v2:76.5693 v3:71.1723 v4:79.8341 v5:73.5304 v6:75.7056 v7:76.0881 v8:76.9053 v9:171.8438] eff_attn_scale:[v0:0.5219 v1:0.6557 v2:0.8758 v3:1.1346 v4:0.9195 v5:0.9812 v6:0.9659 v7:0.9887 v8:1.0755 v9:2.0916] +step:2800/20000 val_loss:2.2379 val_bpb:1.3254 train_time:119488ms step_avg:42.67ms +step:3000/20000 train_loss:2.2609 train_time:128013ms step_avg:42.67ms +step:3000 shared0_alpha:mean=0.541,std=0.077 shared1_alpha:mean=0.605,std=0.060 shared2_alpha:mean=0.596,std=0.066 shared3_alpha:mean=0.593,std=0.062 eff_mlp_scale:[v0:64.2917 v1:60.4402 v2:79.2767 v3:74.0053 v4:82.2812 v5:75.5075 v6:77.7761 v7:78.2755 v8:79.0867 v9:178.0808] eff_attn_scale:[v0:0.4896 v1:0.6337 v2:0.8487 v3:1.0928 v4:0.9089 v5:0.9413 v6:0.9223 v7:0.9556 v8:1.0471 v9:2.0182] +step:3000/20000 val_loss:2.2292 val_bpb:1.3203 train_time:128025ms step_avg:42.68ms +step:3200/20000 train_loss:2.2182 train_time:136546ms step_avg:42.67ms +step:3200 shared0_alpha:mean=0.540,std=0.078 shared1_alpha:mean=0.609,std=0.060 shared2_alpha:mean=0.598,std=0.067 shared3_alpha:mean=0.591,std=0.063 eff_mlp_scale:[v0:66.3024 v1:62.1723 v2:82.0718 v3:76.0492 v4:85.2207 v5:77.5691 v6:80.5491 v7:81.0574 v8:81.8298 v9:182.3725] eff_attn_scale:[v0:0.4617 v1:0.6190 v2:0.8144 v3:1.0615 v4:0.8840 v5:0.9082 v6:0.8969 v7:0.9407 v8:1.0261 v9:1.9801] +step:3200/20000 val_loss:2.2248 val_bpb:1.3177 train_time:136559ms step_avg:42.67ms +step:3400/20000 train_loss:2.1909 train_time:145079ms step_avg:42.67ms +step:3400 
shared0_alpha:mean=0.539,std=0.079 shared1_alpha:mean=0.612,std=0.061 shared2_alpha:mean=0.599,std=0.068 shared3_alpha:mean=0.590,std=0.063 eff_mlp_scale:[v0:68.2101 v1:63.9502 v2:84.4184 v3:78.1660 v4:87.4333 v5:79.6008 v6:82.7450 v7:83.4339 v8:84.1822 v9:187.6665] eff_attn_scale:[v0:0.4486 v1:0.5971 v2:0.7965 v3:1.0424 v4:0.8726 v5:0.8889 v6:0.8828 v7:0.9179 v8:1.0137 v9:1.9442] +step:3400/20000 val_loss:2.2205 val_bpb:1.3151 train_time:145091ms step_avg:42.67ms +step:3600/20000 train_loss:2.1534 train_time:153606ms step_avg:42.67ms +step:3600 shared0_alpha:mean=0.539,std=0.079 shared1_alpha:mean=0.615,std=0.062 shared2_alpha:mean=0.600,std=0.069 shared3_alpha:mean=0.588,std=0.063 eff_mlp_scale:[v0:70.0995 v1:65.6907 v2:86.6731 v3:80.9776 v4:90.4612 v5:81.5523 v6:84.9834 v7:85.6536 v8:86.3169 v9:192.9353] eff_attn_scale:[v0:0.4310 v1:0.5933 v2:0.7943 v3:1.0361 v4:0.8578 v5:0.8833 v6:0.8762 v7:0.9163 v8:1.0086 v9:1.8904] +step:3600/20000 val_loss:2.2142 val_bpb:1.3114 train_time:153618ms step_avg:42.67ms +step:3800/20000 train_loss:2.2556 train_time:162129ms step_avg:42.67ms +step:3800 shared0_alpha:mean=0.538,std=0.079 shared1_alpha:mean=0.617,std=0.063 shared2_alpha:mean=0.602,std=0.070 shared3_alpha:mean=0.587,std=0.063 eff_mlp_scale:[v0:72.5787 v1:67.3238 v2:89.6872 v3:83.3735 v4:92.7995 v5:83.6552 v6:87.4015 v7:88.1265 v8:88.7095 v9:197.4801] eff_attn_scale:[v0:0.4194 v1:0.5755 v2:0.7810 v3:1.0139 v4:0.8445 v5:0.8676 v6:0.8614 v7:0.9058 v8:1.0000 v9:1.8876] +step:3800/20000 val_loss:2.2098 val_bpb:1.3088 train_time:162141ms step_avg:42.67ms +step:4000/20000 train_loss:2.1973 train_time:170654ms step_avg:42.66ms +step:4000 shared0_alpha:mean=0.537,std=0.081 shared1_alpha:mean=0.621,std=0.063 shared2_alpha:mean=0.603,std=0.071 shared3_alpha:mean=0.585,std=0.064 eff_mlp_scale:[v0:74.5182 v1:69.0459 v2:92.1823 v3:85.5387 v4:95.2490 v5:85.7896 v6:90.2188 v7:91.0132 v8:91.0947 v9:202.8772] eff_attn_scale:[v0:0.4051 v1:0.5732 v2:0.7514 v3:1.0043 v4:0.8398 v5:0.8533 
v6:0.8500 v7:0.8916 v8:0.9841 v9:1.8405] +step:4000/20000 val_loss:2.2045 val_bpb:1.3056 train_time:170667ms step_avg:42.67ms +step:4200/20000 train_loss:2.2077 train_time:179233ms step_avg:42.67ms +step:4200 shared0_alpha:mean=0.536,std=0.082 shared1_alpha:mean=0.623,std=0.063 shared2_alpha:mean=0.605,std=0.071 shared3_alpha:mean=0.584,std=0.064 eff_mlp_scale:[v0:76.3813 v1:70.6879 v2:95.2320 v3:87.6971 v4:97.7672 v5:87.7398 v6:92.4711 v7:93.5434 v8:93.5187 v9:207.3148] eff_attn_scale:[v0:0.3893 v1:0.5588 v2:0.7485 v3:0.9875 v4:0.8316 v5:0.8513 v6:0.8374 v7:0.8806 v8:0.9694 v9:1.8191] +step:4200/20000 val_loss:2.2000 val_bpb:1.3030 train_time:179246ms step_avg:42.68ms +step:4400/20000 train_loss:2.1472 train_time:187750ms step_avg:42.67ms +step:4400 shared0_alpha:mean=0.536,std=0.083 shared1_alpha:mean=0.626,std=0.064 shared2_alpha:mean=0.606,std=0.072 shared3_alpha:mean=0.582,std=0.064 eff_mlp_scale:[v0:78.5042 v1:72.5004 v2:97.7045 v3:89.8764 v4:100.1539 v5:89.8615 v6:95.0158 v7:95.9755 v8:96.4816 v9:212.4293] eff_attn_scale:[v0:0.3738 v1:0.5432 v2:0.7366 v3:0.9837 v4:0.8272 v5:0.8330 v6:0.8252 v7:0.8818 v8:0.9702 v9:1.8051] +step:4400/20000 val_loss:2.1999 val_bpb:1.3029 train_time:187763ms step_avg:42.67ms +step:4600/20000 train_loss:2.0112 train_time:196273ms step_avg:42.67ms +step:4600 shared0_alpha:mean=0.535,std=0.084 shared1_alpha:mean=0.629,std=0.065 shared2_alpha:mean=0.608,std=0.072 shared3_alpha:mean=0.580,std=0.064 eff_mlp_scale:[v0:80.2572 v1:74.3306 v2:100.3648 v3:92.9644 v4:102.5015 v5:91.2381 v6:97.3078 v7:98.3552 v8:98.8009 v9:217.0316] eff_attn_scale:[v0:0.3696 v1:0.5448 v2:0.7332 v3:0.9826 v4:0.8280 v5:0.8236 v6:0.8306 v7:0.8813 v8:0.9814 v9:1.7749] +step:4600/20000 val_loss:2.1964 val_bpb:1.3009 train_time:196286ms step_avg:42.67ms +step:4800/20000 train_loss:2.2957 train_time:204793ms step_avg:42.67ms +step:4800 shared0_alpha:mean=0.535,std=0.084 shared1_alpha:mean=0.632,std=0.066 shared2_alpha:mean=0.609,std=0.073 
shared3_alpha:mean=0.578,std=0.065 eff_mlp_scale:[v0:82.1790 v1:76.0158 v2:102.7894 v3:94.7438 v4:105.6761 v5:93.3021 v6:99.7023 v7:101.7526 v8:101.2999 v9:221.4375] eff_attn_scale:[v0:0.3580 v1:0.5393 v2:0.7271 v3:0.9672 v4:0.8114 v5:0.8164 v6:0.8198 v7:0.8670 v8:0.9687 v9:1.7523] +step:4800/20000 val_loss:2.1912 val_bpb:1.2977 train_time:204806ms step_avg:42.67ms +step:5000/20000 train_loss:2.0626 train_time:213309ms step_avg:42.66ms +step:5000 shared0_alpha:mean=0.534,std=0.085 shared1_alpha:mean=0.635,std=0.066 shared2_alpha:mean=0.610,std=0.073 shared3_alpha:mean=0.575,std=0.064 eff_mlp_scale:[v0:84.2924 v1:77.9286 v2:105.2526 v3:97.1999 v4:108.1718 v5:95.4974 v6:102.1231 v7:103.5784 v8:103.5997 v9:225.2929] eff_attn_scale:[v0:0.3527 v1:0.5316 v2:0.7225 v3:0.9758 v4:0.8140 v5:0.8113 v6:0.8053 v7:0.8607 v8:0.9656 v9:1.7034] +step:5000/20000 val_loss:2.1867 val_bpb:1.2951 train_time:213321ms step_avg:42.66ms +step:5200/20000 train_loss:2.2053 train_time:221831ms step_avg:42.66ms +step:5200 shared0_alpha:mean=0.534,std=0.085 shared1_alpha:mean=0.637,std=0.067 shared2_alpha:mean=0.611,std=0.074 shared3_alpha:mean=0.574,std=0.065 eff_mlp_scale:[v0:86.3702 v1:79.1191 v2:107.9870 v3:99.5870 v4:110.6327 v5:97.0781 v6:104.7342 v7:106.2682 v8:106.2758 v9:229.5533] eff_attn_scale:[v0:0.3491 v1:0.5318 v2:0.7129 v3:0.9752 v4:0.8102 v5:0.8051 v6:0.8187 v7:0.8741 v8:0.9611 v9:1.7191] +step:5200/20000 val_loss:2.1868 val_bpb:1.2952 train_time:221844ms step_avg:42.66ms +step:5400/20000 train_loss:2.2219 train_time:230344ms step_avg:42.66ms +step:5400 shared0_alpha:mean=0.533,std=0.086 shared1_alpha:mean=0.641,std=0.067 shared2_alpha:mean=0.612,std=0.074 shared3_alpha:mean=0.572,std=0.065 eff_mlp_scale:[v0:88.4106 v1:80.4525 v2:110.6037 v3:102.0018 v4:113.1708 v5:99.4256 v6:107.1525 v7:109.5096 v8:108.5379 v9:233.7632] eff_attn_scale:[v0:0.3380 v1:0.5248 v2:0.7216 v3:0.9459 v4:0.8050 v5:0.7882 v6:0.8142 v7:0.8608 v8:0.9559 v9:1.6910] +step:5400/20000 val_loss:2.1831 
val_bpb:1.2929 train_time:230356ms step_avg:42.66ms +step:5600/20000 train_loss:2.2165 train_time:238864ms step_avg:42.65ms +step:5600 shared0_alpha:mean=0.533,std=0.087 shared1_alpha:mean=0.643,std=0.068 shared2_alpha:mean=0.614,std=0.075 shared3_alpha:mean=0.570,std=0.065 eff_mlp_scale:[v0:90.4609 v1:82.3534 v2:113.1843 v3:104.6682 v4:115.8157 v5:101.4593 v6:109.7839 v7:111.8817 v8:111.2600 v9:238.2242] eff_attn_scale:[v0:0.3330 v1:0.5183 v2:0.7065 v3:0.9554 v4:0.8147 v5:0.7932 v6:0.8074 v7:0.8604 v8:0.9615 v9:1.6870] +step:5600/20000 val_loss:2.1821 val_bpb:1.2924 train_time:238876ms step_avg:42.66ms +step:5800/20000 train_loss:2.1811 train_time:247383ms step_avg:42.65ms +step:5800 shared0_alpha:mean=0.532,std=0.088 shared1_alpha:mean=0.647,std=0.068 shared2_alpha:mean=0.615,std=0.076 shared3_alpha:mean=0.567,std=0.065 eff_mlp_scale:[v0:92.0874 v1:84.2741 v2:115.8135 v3:107.6628 v4:118.3128 v5:102.8635 v6:111.9944 v7:114.3712 v8:113.7482 v9:242.2198] eff_attn_scale:[v0:0.3232 v1:0.5207 v2:0.7000 v3:0.9613 v4:0.8164 v5:0.7946 v6:0.8000 v7:0.8612 v8:0.9745 v9:1.7030] +step:5800/20000 val_loss:2.1806 val_bpb:1.2915 train_time:247396ms step_avg:42.65ms +step:6000/20000 train_loss:2.2463 train_time:255894ms step_avg:42.65ms +step:6000 shared0_alpha:mean=0.532,std=0.088 shared1_alpha:mean=0.650,std=0.069 shared2_alpha:mean=0.616,std=0.076 shared3_alpha:mean=0.565,std=0.066 eff_mlp_scale:[v0:94.0400 v1:85.5876 v2:119.1415 v3:110.2268 v4:120.8064 v5:104.6090 v6:114.5083 v7:117.0273 v8:115.7046 v9:246.0522] eff_attn_scale:[v0:0.3158 v1:0.5158 v2:0.7080 v3:0.9592 v4:0.8229 v5:0.7778 v6:0.8039 v7:0.8642 v8:0.9643 v9:1.6915] +step:6000/20000 val_loss:2.1755 val_bpb:1.2885 train_time:255907ms step_avg:42.65ms +step:6200/20000 train_loss:2.1207 train_time:264410ms step_avg:42.65ms +step:6200 shared0_alpha:mean=0.531,std=0.088 shared1_alpha:mean=0.653,std=0.069 shared2_alpha:mean=0.617,std=0.077 shared3_alpha:mean=0.562,std=0.066 eff_mlp_scale:[v0:95.9563 v1:86.7445 
v2:121.7121 v3:111.9978 v4:123.5617 v5:106.9548 v6:117.2904 v7:119.6312 v8:118.3502 v9:250.3509] eff_attn_scale:[v0:0.3149 v1:0.5174 v2:0.7096 v3:0.9677 v4:0.8208 v5:0.7803 v6:0.8011 v7:0.8624 v8:0.9568 v9:1.6351] +step:6200/20000 val_loss:2.1753 val_bpb:1.2883 train_time:264423ms step_avg:42.65ms +step:6400/20000 train_loss:2.1966 train_time:272926ms step_avg:42.64ms +step:6400 shared0_alpha:mean=0.530,std=0.089 shared1_alpha:mean=0.657,std=0.070 shared2_alpha:mean=0.618,std=0.077 shared3_alpha:mean=0.560,std=0.066 eff_mlp_scale:[v0:97.7203 v1:88.6236 v2:124.2881 v3:115.3028 v4:125.9675 v5:108.9628 v6:120.0107 v7:123.3458 v8:120.8623 v9:252.5516] eff_attn_scale:[v0:0.3073 v1:0.5205 v2:0.7006 v3:0.9486 v4:0.8206 v5:0.7828 v6:0.7955 v7:0.8683 v8:0.9675 v9:1.6556] +step:6400/20000 val_loss:2.1723 val_bpb:1.2866 train_time:272938ms step_avg:42.65ms +step:6600/20000 train_loss:2.1551 train_time:281444ms step_avg:42.64ms +step:6600 shared0_alpha:mean=0.530,std=0.089 shared1_alpha:mean=0.660,std=0.070 shared2_alpha:mean=0.620,std=0.077 shared3_alpha:mean=0.558,std=0.065 eff_mlp_scale:[v0:100.0658 v1:90.0812 v2:126.3587 v3:117.8195 v4:127.9309 v5:110.6250 v6:122.5566 v7:126.1000 v8:123.3664 v9:256.8854] eff_attn_scale:[v0:0.3046 v1:0.5137 v2:0.6977 v3:0.9612 v4:0.8222 v5:0.7766 v6:0.8019 v7:0.8806 v8:0.9684 v9:1.6362] +step:6600/20000 val_loss:2.1688 val_bpb:1.2845 train_time:281456ms step_avg:42.64ms +step:6800/20000 train_loss:2.2248 train_time:289951ms step_avg:42.64ms +step:6800 shared0_alpha:mean=0.530,std=0.090 shared1_alpha:mean=0.662,std=0.071 shared2_alpha:mean=0.621,std=0.078 shared3_alpha:mean=0.556,std=0.066 eff_mlp_scale:[v0:101.9769 v1:91.3168 v2:129.2413 v3:119.9645 v4:130.6474 v5:112.2710 v6:125.3418 v7:128.8518 v8:126.0640 v9:261.0441] eff_attn_scale:[v0:0.2840 v1:0.5062 v2:0.6924 v3:0.9508 v4:0.8232 v5:0.7664 v6:0.7869 v7:0.8657 v8:0.9646 v9:1.6177] +step:6800/20000 val_loss:2.1673 val_bpb:1.2836 train_time:289964ms step_avg:42.64ms +step:7000/20000 
train_loss:2.2561 train_time:298467ms step_avg:42.64ms +step:7000 shared0_alpha:mean=0.530,std=0.090 shared1_alpha:mean=0.666,std=0.072 shared2_alpha:mean=0.622,std=0.079 shared3_alpha:mean=0.554,std=0.066 eff_mlp_scale:[v0:103.5721 v1:93.4563 v2:131.7646 v3:122.5624 v4:133.3015 v5:114.6403 v6:128.0356 v7:131.6172 v8:128.7146 v9:265.4260] eff_attn_scale:[v0:0.2906 v1:0.5046 v2:0.6980 v3:0.9660 v4:0.8338 v5:0.7702 v6:0.7977 v7:0.8855 v8:0.9761 v9:1.6183] +step:7000/20000 val_loss:2.1650 val_bpb:1.2822 train_time:298479ms step_avg:42.64ms +step:7200/20000 train_loss:2.2310 train_time:306975ms step_avg:42.64ms +step:7200 shared0_alpha:mean=0.530,std=0.091 shared1_alpha:mean=0.669,std=0.072 shared2_alpha:mean=0.623,std=0.079 shared3_alpha:mean=0.551,std=0.066 eff_mlp_scale:[v0:105.7434 v1:94.5495 v2:134.6851 v3:125.7721 v4:135.4604 v5:116.1258 v6:130.6383 v7:135.2298 v8:131.2471 v9:269.7646] eff_attn_scale:[v0:0.2861 v1:0.5102 v2:0.6994 v3:0.9815 v4:0.8313 v5:0.7766 v6:0.8045 v7:0.9048 v8:0.9894 v9:1.6052] +step:7200/20000 val_loss:2.1665 val_bpb:1.2831 train_time:306987ms step_avg:42.64ms +step:7400/20000 train_loss:2.1537 train_time:315487ms step_avg:42.63ms +step:7400 shared0_alpha:mean=0.530,std=0.091 shared1_alpha:mean=0.672,std=0.074 shared2_alpha:mean=0.625,std=0.080 shared3_alpha:mean=0.549,std=0.066 eff_mlp_scale:[v0:107.4044 v1:95.9543 v2:137.6167 v3:128.3425 v4:138.1341 v5:118.3082 v6:133.7125 v7:138.1902 v8:133.7159 v9:271.8210] eff_attn_scale:[v0:0.2806 v1:0.5033 v2:0.6912 v3:0.9738 v4:0.8297 v5:0.7724 v6:0.7997 v7:0.8876 v8:0.9773 v9:1.6051] +step:7400/20000 val_loss:2.1621 val_bpb:1.2805 train_time:315499ms step_avg:42.64ms +step:7600/20000 train_loss:2.0340 train_time:324000ms step_avg:42.63ms +step:7600 shared0_alpha:mean=0.530,std=0.091 shared1_alpha:mean=0.675,std=0.073 shared2_alpha:mean=0.626,std=0.080 shared3_alpha:mean=0.547,std=0.066 eff_mlp_scale:[v0:109.5411 v1:98.1257 v2:140.5602 v3:131.1342 v4:141.0223 v5:119.7408 v6:136.2682 v7:141.1938 
v8:135.7513 v9:276.3030] eff_attn_scale:[v0:0.2794 v1:0.5048 v2:0.7003 v3:0.9756 v4:0.8538 v5:0.7644 v6:0.8101 v7:0.8892 v8:1.0030 v9:1.6024] +step:7600/20000 val_loss:2.1600 val_bpb:1.2793 train_time:324010ms step_avg:42.63ms +step:7800/20000 train_loss:2.1786 train_time:332512ms step_avg:42.63ms +step:7800 shared0_alpha:mean=0.529,std=0.092 shared1_alpha:mean=0.679,std=0.074 shared2_alpha:mean=0.627,std=0.081 shared3_alpha:mean=0.544,std=0.066 eff_mlp_scale:[v0:111.2017 v1:99.5799 v2:143.4963 v3:133.8642 v4:143.6751 v5:122.0978 v6:139.1756 v7:144.3930 v8:138.9148 v9:280.7726] eff_attn_scale:[v0:0.2719 v1:0.5039 v2:0.7026 v3:0.9814 v4:0.8534 v5:0.7702 v6:0.8075 v7:0.9047 v8:1.0025 v9:1.6027] +step:7800/20000 val_loss:2.1567 val_bpb:1.2773 train_time:332523ms step_avg:42.63ms +step:8000/20000 train_loss:2.1474 train_time:341024ms step_avg:42.63ms +step:8000 shared0_alpha:mean=0.529,std=0.092 shared1_alpha:mean=0.682,std=0.074 shared2_alpha:mean=0.628,std=0.081 shared3_alpha:mean=0.541,std=0.066 eff_mlp_scale:[v0:113.5248 v1:100.9212 v2:146.5628 v3:136.8268 v4:146.4019 v5:124.5063 v6:142.0919 v7:147.4374 v8:140.8351 v9:284.9535] eff_attn_scale:[v0:0.2696 v1:0.5030 v2:0.7054 v3:0.9866 v4:0.8512 v5:0.7627 v6:0.8107 v7:0.9095 v8:0.9999 v9:1.5866] +step:8000/20000 val_loss:2.1549 val_bpb:1.2763 train_time:341036ms step_avg:42.63ms +step:8200/20000 train_loss:2.2189 train_time:349532ms step_avg:42.63ms +step:8200 shared0_alpha:mean=0.529,std=0.092 shared1_alpha:mean=0.686,std=0.075 shared2_alpha:mean=0.629,std=0.081 shared3_alpha:mean=0.538,std=0.066 eff_mlp_scale:[v0:115.6790 v1:102.9386 v2:149.2333 v3:139.5241 v4:148.4178 v5:126.4754 v6:144.8927 v7:150.4119 v8:143.6807 v9:287.1323] eff_attn_scale:[v0:0.2696 v1:0.5106 v2:0.7020 v3:0.9907 v4:0.8589 v5:0.7731 v6:0.8114 v7:0.9085 v8:1.0132 v9:1.6141] +step:8200/20000 val_loss:2.1537 val_bpb:1.2756 train_time:349544ms step_avg:42.63ms +step:8400/20000 train_loss:2.1698 train_time:358110ms step_avg:42.63ms +step:8400 
shared0_alpha:mean=0.529,std=0.093 shared1_alpha:mean=0.689,std=0.075 shared2_alpha:mean=0.630,std=0.082 shared3_alpha:mean=0.536,std=0.066 eff_mlp_scale:[v0:117.1696 v1:104.5319 v2:152.2676 v3:142.3639 v4:151.1524 v5:127.9611 v6:147.8988 v7:154.2814 v8:147.1329 v9:291.3692] eff_attn_scale:[v0:0.2668 v1:0.4978 v2:0.7079 v3:0.9866 v4:0.8710 v5:0.7743 v6:0.8175 v7:0.9151 v8:1.0214 v9:1.6191] +step:8400/20000 val_loss:2.1524 val_bpb:1.2748 train_time:358123ms step_avg:42.63ms +step:8600/20000 train_loss:2.1671 train_time:366614ms step_avg:42.63ms +step:8600 shared0_alpha:mean=0.530,std=0.093 shared1_alpha:mean=0.693,std=0.076 shared2_alpha:mean=0.632,std=0.082 shared3_alpha:mean=0.533,std=0.066 eff_mlp_scale:[v0:119.6620 v1:105.9419 v2:155.1287 v3:145.4451 v4:153.8885 v5:130.4333 v6:150.9129 v7:157.3671 v8:149.8370 v9:295.3305] eff_attn_scale:[v0:0.2627 v1:0.5019 v2:0.7162 v3:1.0106 v4:0.8763 v5:0.7703 v6:0.8264 v7:0.9229 v8:1.0155 v9:1.6057] +step:8600/20000 val_loss:2.1503 val_bpb:1.2736 train_time:366626ms step_avg:42.63ms +step:8800/20000 train_loss:2.1362 train_time:375132ms step_avg:42.63ms +step:8800 shared0_alpha:mean=0.529,std=0.094 shared1_alpha:mean=0.696,std=0.076 shared2_alpha:mean=0.633,std=0.082 shared3_alpha:mean=0.530,std=0.066 eff_mlp_scale:[v0:121.2455 v1:107.3397 v2:158.1906 v3:148.2097 v4:155.8999 v5:132.0103 v6:153.9370 v7:160.6911 v8:151.7160 v9:297.2757] eff_attn_scale:[v0:0.2593 v1:0.5064 v2:0.7042 v3:1.0127 v4:0.8857 v5:0.7709 v6:0.8133 v7:0.9404 v8:1.0367 v9:1.6071] +step:8800/20000 val_loss:2.1494 val_bpb:1.2730 train_time:375145ms step_avg:42.63ms +step:9000/20000 train_loss:2.0551 train_time:383640ms step_avg:42.63ms +step:9000 shared0_alpha:mean=0.530,std=0.094 shared1_alpha:mean=0.700,std=0.076 shared2_alpha:mean=0.635,std=0.083 shared3_alpha:mean=0.528,std=0.066 eff_mlp_scale:[v0:122.9368 v1:109.4864 v2:161.2221 v3:151.8921 v4:158.8706 v5:134.3784 v6:157.0095 v7:164.0049 v8:154.6675 v9:301.7382] eff_attn_scale:[v0:0.2583 v1:0.5056 
v2:0.7132 v3:1.0269 v4:0.9121 v5:0.7656 v6:0.8336 v7:0.9388 v8:1.0544 v9:1.6356] +step:9000/20000 val_loss:2.1490 val_bpb:1.2727 train_time:383652ms step_avg:42.63ms +step:9200/20000 train_loss:2.1168 train_time:392146ms step_avg:42.62ms +step:9200 shared0_alpha:mean=0.530,std=0.095 shared1_alpha:mean=0.703,std=0.077 shared2_alpha:mean=0.636,std=0.083 shared3_alpha:mean=0.525,std=0.066 eff_mlp_scale:[v0:125.2359 v1:111.1469 v2:164.0712 v3:154.8471 v4:161.5470 v5:136.0026 v6:159.7852 v7:168.2285 v8:157.4959 v9:303.8700] eff_attn_scale:[v0:0.2561 v1:0.5113 v2:0.7228 v3:1.0310 v4:0.9194 v5:0.7825 v6:0.8333 v7:0.9525 v8:1.0621 v9:1.6123] +step:9200/20000 val_loss:2.1471 val_bpb:1.2716 train_time:392159ms step_avg:42.63ms +step:9400/20000 train_loss:2.1659 train_time:400667ms step_avg:42.62ms +step:9400 shared0_alpha:mean=0.530,std=0.095 shared1_alpha:mean=0.707,std=0.077 shared2_alpha:mean=0.638,std=0.084 shared3_alpha:mean=0.522,std=0.065 eff_mlp_scale:[v0:127.1187 v1:112.7082 v2:166.9327 v3:157.7378 v4:164.4733 v5:138.4737 v6:162.8718 v7:171.3794 v8:160.2144 v9:308.3118] eff_attn_scale:[v0:0.2521 v1:0.5074 v2:0.7123 v3:1.0271 v4:0.9348 v5:0.7796 v6:0.8409 v7:0.9642 v8:1.0737 v9:1.6239] +step:9400/20000 val_loss:2.1453 val_bpb:1.2706 train_time:400679ms step_avg:42.63ms +step:9600/20000 train_loss:2.1761 train_time:409174ms step_avg:42.62ms +step:9600 shared0_alpha:mean=0.530,std=0.095 shared1_alpha:mean=0.710,std=0.078 shared2_alpha:mean=0.639,std=0.083 shared3_alpha:mean=0.519,std=0.065 eff_mlp_scale:[v0:128.7033 v1:114.0261 v2:170.1002 v3:160.8656 v4:166.5059 v5:140.2586 v6:166.7156 v7:174.8982 v8:162.9963 v9:312.3873] eff_attn_scale:[v0:0.2491 v1:0.5019 v2:0.7115 v3:1.0427 v4:0.9261 v5:0.7640 v6:0.8354 v7:0.9693 v8:1.0698 v9:1.6113] +step:9600/20000 val_loss:2.1457 val_bpb:1.2708 train_time:409186ms step_avg:42.62ms +step:9800/20000 train_loss:2.1050 train_time:417683ms step_avg:42.62ms +step:9800 shared0_alpha:mean=0.530,std=0.097 
shared1_alpha:mean=0.714,std=0.079 shared2_alpha:mean=0.640,std=0.085 shared3_alpha:mean=0.516,std=0.066 eff_mlp_scale:[v0:130.9903 v1:116.3724 v2:173.9150 v3:164.0603 v4:168.3621 v5:142.0792 v6:169.6515 v7:179.2270 v8:165.9562 v9:314.4527] eff_attn_scale:[v0:0.2460 v1:0.5120 v2:0.7273 v3:1.0403 v4:0.9431 v5:0.7909 v6:0.8485 v7:0.9720 v8:1.0940 v9:1.6141] +step:9800/20000 val_loss:2.1447 val_bpb:1.2702 train_time:417695ms step_avg:42.62ms +step:10000/20000 train_loss:2.1473 train_time:426188ms step_avg:42.62ms +step:10000 shared0_alpha:mean=0.531,std=0.097 shared1_alpha:mean=0.718,std=0.079 shared2_alpha:mean=0.641,std=0.085 shared3_alpha:mean=0.513,std=0.065 eff_mlp_scale:[v0:133.4993 v1:117.8749 v2:176.8462 v3:167.9897 v4:172.2010 v5:144.5386 v6:173.6187 v7:182.8824 v8:168.9445 v9:319.0671] eff_attn_scale:[v0:0.2439 v1:0.5171 v2:0.7357 v3:1.0651 v4:0.9612 v5:0.7933 v6:0.8622 v7:0.9909 v8:1.1187 v9:1.6126] +step:10000/20000 val_loss:2.1439 val_bpb:1.2698 train_time:426199ms step_avg:42.62ms +step:10200/20000 train_loss:2.0959 train_time:434702ms step_avg:42.62ms +step:10200 shared0_alpha:mean=0.532,std=0.097 shared1_alpha:mean=0.721,std=0.080 shared2_alpha:mean=0.642,std=0.085 shared3_alpha:mean=0.509,std=0.065 eff_mlp_scale:[v0:135.2895 v1:119.8998 v2:180.8499 v3:171.0312 v4:174.5362 v5:147.1871 v6:176.9959 v7:187.0349 v8:171.6876 v9:320.9614] eff_attn_scale:[v0:0.2457 v1:0.5134 v2:0.7404 v3:1.0757 v4:0.9609 v5:0.7950 v6:0.8662 v7:0.9959 v8:1.1066 v9:1.6390] +step:10200/20000 val_loss:2.1402 val_bpb:1.2676 train_time:434714ms step_avg:42.62ms +step:10400/20000 train_loss:2.1306 train_time:443206ms step_avg:42.62ms +step:10400 shared0_alpha:mean=0.532,std=0.098 shared1_alpha:mean=0.726,std=0.080 shared2_alpha:mean=0.644,std=0.085 shared3_alpha:mean=0.507,std=0.065 eff_mlp_scale:[v0:136.9241 v1:121.2368 v2:183.2520 v3:174.2560 v4:176.8566 v5:149.0685 v6:180.3395 v7:190.8543 v8:174.4648 v9:325.2477] eff_attn_scale:[v0:0.2393 v1:0.5192 v2:0.7533 v3:1.0728 v4:0.9598 
v5:0.7923 v6:0.8710 v7:1.0034 v8:1.1225 v9:1.6231] +step:10400/20000 val_loss:2.1400 val_bpb:1.2674 train_time:443217ms step_avg:42.62ms +step:10600/20000 train_loss:2.0117 train_time:451713ms step_avg:42.61ms +step:10600 shared0_alpha:mean=0.533,std=0.098 shared1_alpha:mean=0.729,std=0.080 shared2_alpha:mean=0.645,std=0.086 shared3_alpha:mean=0.503,std=0.065 eff_mlp_scale:[v0:139.3970 v1:122.8613 v2:186.3927 v3:177.7040 v4:179.9691 v5:151.2466 v6:184.7009 v7:195.0799 v8:177.4702 v9:329.6817] eff_attn_scale:[v0:0.2370 v1:0.5219 v2:0.7503 v3:1.0824 v4:0.9930 v5:0.8071 v6:0.8881 v7:1.0232 v8:1.1475 v9:1.6096] +step:10600/20000 val_loss:2.1407 val_bpb:1.2678 train_time:451725ms step_avg:42.62ms +step:10800/20000 train_loss:2.2213 train_time:460224ms step_avg:42.61ms +step:10800 shared0_alpha:mean=0.533,std=0.098 shared1_alpha:mean=0.732,std=0.082 shared2_alpha:mean=0.646,std=0.086 shared3_alpha:mean=0.501,std=0.064 eff_mlp_scale:[v0:140.8454 v1:125.0080 v2:189.6674 v3:182.0608 v4:181.9680 v5:153.9291 v6:188.0721 v7:199.7270 v8:180.2983 v9:331.6027] eff_attn_scale:[v0:0.2339 v1:0.5242 v2:0.7481 v3:1.0797 v4:1.0045 v5:0.7990 v6:0.8910 v7:1.0103 v8:1.1480 v9:1.6369] +step:10800/20000 val_loss:2.1383 val_bpb:1.2664 train_time:460235ms step_avg:42.61ms +step:11000/20000 train_loss:2.1461 train_time:468732ms step_avg:42.61ms +step:11000 shared0_alpha:mean=0.534,std=0.099 shared1_alpha:mean=0.736,std=0.082 shared2_alpha:mean=0.647,std=0.087 shared3_alpha:mean=0.497,std=0.064 eff_mlp_scale:[v0:143.8485 v1:126.6024 v2:193.1592 v3:185.2109 v4:185.0573 v5:155.7508 v6:191.2416 v7:204.0561 v8:182.5521 v9:336.0238] eff_attn_scale:[v0:0.2355 v1:0.5254 v2:0.7531 v3:1.1074 v4:1.0188 v5:0.8018 v6:0.8858 v7:1.0368 v8:1.1635 v9:1.6224] +step:11000/20000 val_loss:2.1366 val_bpb:1.2654 train_time:468743ms step_avg:42.61ms +step:11200/20000 train_loss:2.1011 train_time:477233ms step_avg:42.61ms +step:11200 shared0_alpha:mean=0.535,std=0.100 shared1_alpha:mean=0.740,std=0.082 
shared2_alpha:mean=0.649,std=0.088 shared3_alpha:mean=0.494,std=0.065 eff_mlp_scale:[v0:145.6090 v1:129.0358 v2:197.5520 v3:188.8528 v4:188.0731 v5:158.4482 v6:195.1418 v7:207.8885 v8:186.4106 v9:337.9784] eff_attn_scale:[v0:0.2336 v1:0.5225 v2:0.7689 v3:1.1180 v4:1.0305 v5:0.8026 v6:0.9026 v7:1.0471 v8:1.1762 v9:1.6517] +step:11200/20000 val_loss:2.1363 val_bpb:1.2652 train_time:477246ms step_avg:42.61ms +step:11400/20000 train_loss:2.0855 train_time:485736ms step_avg:42.61ms +step:11400 shared0_alpha:mean=0.536,std=0.100 shared1_alpha:mean=0.744,std=0.083 shared2_alpha:mean=0.650,std=0.088 shared3_alpha:mean=0.491,std=0.064 eff_mlp_scale:[v0:148.0383 v1:130.3857 v2:200.9540 v3:193.1590 v4:191.1929 v5:161.2432 v6:199.5130 v7:212.0066 v8:189.4103 v9:342.1046] eff_attn_scale:[v0:0.2308 v1:0.5386 v2:0.7737 v3:1.1352 v4:1.0334 v5:0.8217 v6:0.9026 v7:1.0691 v8:1.1850 v9:1.6541] +step:11400/20000 val_loss:2.1364 val_bpb:1.2653 train_time:485748ms step_avg:42.61ms +step:11600/20000 train_loss:2.0953 train_time:494246ms step_avg:42.61ms +step:11600 shared0_alpha:mean=0.536,std=0.101 shared1_alpha:mean=0.747,std=0.082 shared2_alpha:mean=0.652,std=0.088 shared3_alpha:mean=0.488,std=0.064 eff_mlp_scale:[v0:149.7553 v1:132.4760 v2:204.5622 v3:196.6609 v4:193.0643 v5:162.8436 v6:203.0969 v7:215.7832 v8:192.3099 v9:344.1545] eff_attn_scale:[v0:0.2369 v1:0.5359 v2:0.7749 v3:1.1370 v4:1.0486 v5:0.8186 v6:0.9137 v7:1.0821 v8:1.2130 v9:1.6221] +step:11600/20000 val_loss:2.1356 val_bpb:1.2648 train_time:494257ms step_avg:42.61ms +step:11800/20000 train_loss:2.1265 train_time:502783ms step_avg:42.61ms +step:11800 shared0_alpha:mean=0.537,std=0.102 shared1_alpha:mean=0.751,std=0.083 shared2_alpha:mean=0.653,std=0.088 shared3_alpha:mean=0.485,std=0.064 eff_mlp_scale:[v0:151.6786 v1:134.0959 v2:207.9402 v3:201.2269 v4:196.3660 v5:165.5384 v6:206.7773 v7:221.6175 v8:195.4164 v9:348.3466] eff_attn_scale:[v0:0.2298 v1:0.5452 v2:0.7877 v3:1.1440 v4:1.0491 v5:0.8220 v6:0.9336 v7:1.0832 
v8:1.2127 v9:1.6391] +step:11800/20000 val_loss:2.1329 val_bpb:1.2632 train_time:502796ms step_avg:42.61ms +step:12000/20000 train_loss:2.1004 train_time:511299ms step_avg:42.61ms +step:12000 shared0_alpha:mean=0.538,std=0.102 shared1_alpha:mean=0.755,std=0.084 shared2_alpha:mean=0.654,std=0.089 shared3_alpha:mean=0.482,std=0.064 eff_mlp_scale:[v0:154.1958 v1:136.2558 v2:211.6706 v3:204.5753 v4:199.2398 v5:167.7950 v6:210.5409 v7:225.4635 v8:197.4685 v9:352.7194] eff_attn_scale:[v0:0.2264 v1:0.5388 v2:0.7884 v3:1.1720 v4:1.0612 v5:0.8167 v6:0.9286 v7:1.0938 v8:1.2200 v9:1.6185] +step:12000/20000 val_loss:2.1313 val_bpb:1.2623 train_time:511310ms step_avg:42.61ms +step:12200/20000 train_loss:2.2481 train_time:519812ms step_avg:42.61ms +step:12200 shared0_alpha:mean=0.538,std=0.102 shared1_alpha:mean=0.758,std=0.084 shared2_alpha:mean=0.656,std=0.089 shared3_alpha:mean=0.478,std=0.064 eff_mlp_scale:[v0:156.0455 v1:137.5728 v2:215.3837 v3:208.3083 v4:201.6977 v5:170.0221 v6:214.5259 v7:231.3655 v8:201.3298 v9:354.4663] eff_attn_scale:[v0:0.2206 v1:0.5374 v2:0.7902 v3:1.1706 v4:1.0677 v5:0.8219 v6:0.9356 v7:1.0981 v8:1.2333 v9:1.6526] +step:12200/20000 val_loss:2.1310 val_bpb:1.2621 train_time:519823ms step_avg:42.61ms +step:12400/20000 train_loss:1.8937 train_time:528378ms step_avg:42.61ms +step:12400 shared0_alpha:mean=0.540,std=0.102 shared1_alpha:mean=0.763,std=0.084 shared2_alpha:mean=0.658,std=0.090 shared3_alpha:mean=0.476,std=0.063 eff_mlp_scale:[v0:158.5418 v1:139.9026 v2:218.8049 v3:212.6121 v4:204.6887 v5:172.7634 v6:218.4762 v7:235.3858 v8:202.8567 v9:358.5690] eff_attn_scale:[v0:0.2232 v1:0.5510 v2:0.7948 v3:1.1875 v4:1.0922 v5:0.8265 v6:0.9469 v7:1.1200 v8:1.2589 v9:1.6405] +step:12400/20000 val_loss:2.1311 val_bpb:1.2621 train_time:528390ms step_avg:42.61ms +step:12600/20000 train_loss:2.1253 train_time:536891ms step_avg:42.61ms +step:12600 shared0_alpha:mean=0.541,std=0.104 shared1_alpha:mean=0.767,std=0.085 shared2_alpha:mean=0.660,std=0.090 
shared3_alpha:mean=0.473,std=0.063 eff_mlp_scale:[v0:161.0704 v1:142.2011 v2:222.4779 v3:216.2944 v4:208.1218 v5:174.9844 v6:224.0638 v7:241.8932 v8:206.0310 v9:360.5451] eff_attn_scale:[v0:0.2259 v1:0.5585 v2:0.8068 v3:1.1926 v4:1.1059 v5:0.8528 v6:0.9593 v7:1.1474 v8:1.2796 v9:1.6610] +step:12600/20000 val_loss:2.1315 val_bpb:1.2624 train_time:536904ms step_avg:42.61ms +step:12800/20000 train_loss:2.1448 train_time:545397ms step_avg:42.61ms +step:12800 shared0_alpha:mean=0.542,std=0.104 shared1_alpha:mean=0.770,std=0.085 shared2_alpha:mean=0.660,std=0.090 shared3_alpha:mean=0.469,std=0.063 eff_mlp_scale:[v0:162.7280 v1:143.6815 v2:226.2649 v3:220.0968 v4:209.9233 v5:177.8617 v6:227.8007 v7:246.1311 v8:210.9755 v9:364.9839] eff_attn_scale:[v0:0.2175 v1:0.5459 v2:0.8094 v3:1.2027 v4:1.0981 v5:0.8435 v6:0.9723 v7:1.1460 v8:1.2889 v9:1.6455] +step:12800/20000 val_loss:2.1302 val_bpb:1.2616 train_time:545409ms step_avg:42.61ms +step:13000/20000 train_loss:2.2203 train_time:553910ms step_avg:42.61ms +step:13000 shared0_alpha:mean=0.543,std=0.105 shared1_alpha:mean=0.773,std=0.085 shared2_alpha:mean=0.662,std=0.091 shared3_alpha:mean=0.465,std=0.063 eff_mlp_scale:[v0:164.8956 v1:146.2567 v2:230.3892 v3:223.7054 v4:212.7496 v5:180.5472 v6:231.9168 v7:250.3686 v8:214.1154 v9:367.1068] eff_attn_scale:[v0:0.2149 v1:0.5481 v2:0.8159 v3:1.2123 v4:1.1263 v5:0.8436 v6:0.9585 v7:1.1494 v8:1.3141 v9:1.6651] +step:13000/20000 val_loss:2.1278 val_bpb:1.2602 train_time:553921ms step_avg:42.61ms +step:13200/20000 train_loss:2.2260 train_time:562411ms step_avg:42.61ms +step:13200 shared0_alpha:mean=0.543,std=0.105 shared1_alpha:mean=0.776,std=0.085 shared2_alpha:mean=0.662,std=0.091 shared3_alpha:mean=0.462,std=0.062 eff_mlp_scale:[v0:166.5168 v1:147.2642 v2:233.6956 v3:229.0826 v4:215.5730 v5:182.4341 v6:235.5651 v7:254.3570 v8:217.0851 v9:371.3537] eff_attn_scale:[v0:0.2199 v1:0.5701 v2:0.8196 v3:1.2387 v4:1.1397 v5:0.8728 v6:0.9835 v7:1.1750 v8:1.3287 v9:1.6681] +step:13200/20000 
val_loss:2.1160 val_bpb:1.2532 train_time:562422ms step_avg:42.61ms +step:13400/20000 train_loss:2.0886 train_time:570909ms step_avg:42.61ms +step:13400 shared0_alpha:mean=0.542,std=0.105 shared1_alpha:mean=0.776,std=0.086 shared2_alpha:mean=0.662,std=0.091 shared3_alpha:mean=0.459,std=0.062 eff_mlp_scale:[v0:167.3332 v1:147.7490 v2:236.7683 v3:230.6978 v4:216.4044 v5:184.3049 v6:237.6637 v7:258.1160 v8:218.2668 v9:373.5216] eff_attn_scale:[v0:0.2180 v1:0.5551 v2:0.8276 v3:1.2448 v4:1.1472 v5:0.8654 v6:0.9821 v7:1.1867 v8:1.3552 v9:1.7137] +step:13400/20000 val_loss:2.1083 val_bpb:1.2487 train_time:570922ms step_avg:42.61ms +step:13600/20000 train_loss:1.9538 train_time:579413ms step_avg:42.60ms +step:13600 shared0_alpha:mean=0.542,std=0.105 shared1_alpha:mean=0.777,std=0.086 shared2_alpha:mean=0.661,std=0.091 shared3_alpha:mean=0.457,std=0.062 eff_mlp_scale:[v0:167.6753 v1:148.2091 v2:237.5709 v3:233.3946 v4:217.1757 v5:184.2325 v6:238.6256 v7:259.4693 v8:219.1671 v9:375.4492] eff_attn_scale:[v0:0.2182 v1:0.5618 v2:0.8246 v3:1.2547 v4:1.1635 v5:0.8647 v6:0.9895 v7:1.2019 v8:1.3675 v9:1.7123] +step:13600/20000 val_loss:2.1018 val_bpb:1.2448 train_time:579425ms step_avg:42.60ms +step:13800/20000 train_loss:2.0379 train_time:587925ms step_avg:42.60ms +step:13800 shared0_alpha:mean=0.541,std=0.105 shared1_alpha:mean=0.776,std=0.086 shared2_alpha:mean=0.660,std=0.091 shared3_alpha:mean=0.455,std=0.062 eff_mlp_scale:[v0:167.6242 v1:148.1596 v2:238.2047 v3:232.2239 v4:217.5854 v5:184.3320 v6:239.3322 v7:260.1901 v8:219.8273 v9:376.7473] eff_attn_scale:[v0:0.2178 v1:0.5522 v2:0.8171 v3:1.2409 v4:1.1518 v5:0.8512 v6:0.9816 v7:1.1885 v8:1.3606 v9:1.7156] +step:13800/20000 val_loss:2.0914 val_bpb:1.2387 train_time:587936ms step_avg:42.60ms +step:14000/20000 train_loss:2.0888 train_time:596429ms step_avg:42.60ms +step:14000 shared0_alpha:mean=0.540,std=0.105 shared1_alpha:mean=0.776,std=0.086 shared2_alpha:mean=0.660,std=0.091 shared3_alpha:mean=0.454,std=0.062 
eff_mlp_scale:[v0:167.5780 v1:148.1671 v2:238.3079 v3:232.4329 v4:217.8412 v5:184.4296 v6:239.5521 v7:260.4415 v8:220.0159 v9:377.4643] eff_attn_scale:[v0:0.2184 v1:0.5498 v2:0.8140 v3:1.2422 v4:1.1505 v5:0.8587 v6:0.9838 v7:1.1897 v8:1.3662 v9:1.7104] +step:14000/20000 val_loss:2.0843 val_bpb:1.2345 train_time:596441ms step_avg:42.60ms +step:14085/20000 val_loss:2.0823 val_bpb:1.2333 train_time:600029ms step_avg:42.60ms +stopping_early: wallclock_cap train_time:600029ms step:14085/20000 +peak memory allocated: 9959 MiB reserved: 10282 MiB +Serialized model: 61961414 bytes +Code size: 66161 bytes +Total submission size: 62027575 bytes +Serialized model int8+zlib: 14584241 bytes (payload:15845568 raw_torch:15883087 payload_ratio:3.91x) +Total submission size int8+zlib: 14650402 bytes +final_int8_zlib_roundtrip val_loss:2.0947 val_bpb:1.2406 eval_time:1336ms +final_int8_zlib_roundtrip_exact val_loss:2.09472386 val_bpb:1.24061346 diff --git a/records/track_non_record_16mb/2026-03-26_RecurrenceFix_3Loop_Birkhoff_OutputLN_TimestepScale/logs/s3_M.txt b/records/track_non_record_16mb/2026-03-26_RecurrenceFix_3Loop_Birkhoff_OutputLN_TimestepScale/logs/s3_M.txt new file mode 100644 index 0000000000..469bb36358 --- /dev/null +++ b/records/track_non_record_16mb/2026-03-26_RecurrenceFix_3Loop_Birkhoff_OutputLN_TimestepScale/logs/s3_M.txt @@ -0,0 +1,1578 @@ +""" +The `train_gpt.py` and `train_gpt_mlx.py` scripts are intended as good launching-off points for new participants, not SOTA configs. We'll accept PRs that tune, improve, or simplify these scripts without significantly increasing complexity, but competitive submissions should stay in the `/records` folder. + +Hard stop: To keep readable for newcomers, let's make sure `train_gpt.py` and `train_gpt_mlx.py` never are longer than 1500 lines. 
"""

from __future__ import annotations

import copy
import glob
import io
import math
import os
import random
import subprocess
import sys
import time
import uuid
import zlib
from pathlib import Path

import numpy as np
import sentencepiece as spm
import torch
import torch.distributed as dist
import torch.nn.functional as F
from torch import Tensor, nn
from torch.nn.parallel import DistributedDataParallel as DDP

# Flash Attention 3 via HuggingFace kernels (Hopper-optimized)
try:
    from kernels import get_kernel
    cap = torch.cuda.get_device_capability()
    # SM 9.0 (Hopper) gets the dedicated FA3 build; everything else falls back to the community kernel.
    repo = "varunneal/flash-attention-3" if cap == (9, 0) else "kernels-community/flash-attn3"
    fa3 = get_kernel(repo).flash_attn_interface
    FA3_AVAILABLE = True
except Exception:
    # Any failure (no CUDA device, missing `kernels` package, download error) degrades
    # to the SDPA path guarded by FA3_AVAILABLE below.
    FA3_AVAILABLE = False

# -----------------------------
# HYPERPARAMETERS
# -----------------------------
# Default Simple Baseline run:
# - 9 transformer blocks at width 512
# - 8 attention heads with 4 KV heads (GQA) and 2x MLP expansion
# - vocab size 1024, sequence length 1024, tied embeddings
# - 524,288 train tokens per step for 20,000 iterations with a ~10 minute cap

class Hyperparameters:
    """Run configuration. Every field reads an environment-variable override first."""

    # Data paths are shard globs produced by the existing preprocessing pipeline.
    data_path = os.environ.get("DATA_PATH", "./data/datasets/fineweb10B_sp1024")
    train_files = os.path.join(data_path, "fineweb_train_*.bin")
    val_files = os.path.join(data_path, "fineweb_val_*.bin")
    tokenizer_path = os.environ.get("TOKENIZER_PATH", "./data/tokenizers/fineweb_1024_bpe.model")
    run_id = os.environ.get("RUN_ID", str(uuid.uuid4()))
    seed = int(os.environ.get("SEED", 1337))

    # Validation cadence and batch size. Validation always uses the full fineweb_val split.
    val_batch_size = int(os.environ.get("VAL_BATCH_SIZE", 524_288))
    val_loss_every = int(os.environ.get("VAL_LOSS_EVERY", 1000))
    train_log_every = int(os.environ.get("TRAIN_LOG_EVERY", 200))

    # Training length.
    iterations = int(os.environ.get("ITERATIONS", 20000))
    warmdown_iters = int(os.environ.get("WARMDOWN_ITERS", 1200))
    warmup_steps = int(os.environ.get("WARMUP_STEPS", 20))
    train_batch_tokens = int(os.environ.get("TRAIN_BATCH_TOKENS", 524_288))
    train_seq_len = int(os.environ.get("TRAIN_SEQ_LEN", 1024))
    max_wallclock_seconds = float(os.environ.get("MAX_WALLCLOCK_SECONDS", 600.0))
    qk_gain_init = float(os.environ.get("QK_GAIN_INIT", 1.5))

    # Model shape.
    vocab_size = int(os.environ.get("VOCAB_SIZE", 1024))
    num_layers = int(os.environ.get("NUM_LAYERS", 9))
    num_kv_heads = int(os.environ.get("NUM_KV_HEADS", 4))
    model_dim = int(os.environ.get("MODEL_DIM", 512))
    num_heads = int(os.environ.get("NUM_HEADS", 8))
    mlp_mult = int(os.environ.get("MLP_MULT", 2))
    tie_embeddings = bool(int(os.environ.get("TIE_EMBEDDINGS", "1")))
    rope_base = float(os.environ.get("ROPE_BASE", 10000.0))
    logit_softcap = float(os.environ.get("LOGIT_SOFTCAP", 30.0))

    # Recurrence + stability features.
    # num_shared > 0 turns on weight-shared ("looped") blocks; the remaining flags
    # select the stabilization tricks applied around the shared stack.
    num_prelude = int(os.environ.get("NUM_PRELUDE", 0))
    num_coda = int(os.environ.get("NUM_CODA", 0))
    num_shared = int(os.environ.get("NUM_SHARED", 0))
    num_loops = int(os.environ.get("NUM_LOOPS", 2))
    use_peri_norm = bool(int(os.environ.get("USE_PERI_NORM", "0")))
    use_birkhoff_mix = bool(int(os.environ.get("USE_BIRKHOFF_MIX", "0")))
    use_timestep_scale = bool(int(os.environ.get("USE_TIMESTEP_SCALE", "0")))
    leaky_relu_slope = float(os.environ.get("LEAKY_RELU_SLOPE", "0.5"))
    timestep_gamma_max = float(os.environ.get("TIMESTEP_GAMMA_MAX", "0.0"))
    use_timestep_bias = bool(int(os.environ.get("USE_TIMESTEP_BIAS", "0")))
    share_attn_only = bool(int(os.environ.get("SHARE_ATTN_ONLY", "0")))
    disable_compile = bool(int(os.environ.get("DISABLE_COMPILE", "0")))

    # Optimizer hyperparameters.
    embed_lr = float(os.environ.get("EMBED_LR", 0.6))
    head_lr = float(os.environ.get("HEAD_LR", 0.008))
    tied_embed_lr = float(os.environ.get("TIED_EMBED_LR", 0.05))
    tied_embed_init_std = float(os.environ.get("TIED_EMBED_INIT_STD", 0.005))
    matrix_lr = float(os.environ.get("MATRIX_LR", 0.04))
    scalar_lr = float(os.environ.get("SCALAR_LR", 0.04))
    muon_momentum = float(os.environ.get("MUON_MOMENTUM", 0.95))
    muon_backend_steps = int(os.environ.get("MUON_BACKEND_STEPS", 5))
    muon_momentum_warmup_start = float(os.environ.get("MUON_MOMENTUM_WARMUP_START", 0.85))
    muon_momentum_warmup_steps = int(os.environ.get("MUON_MOMENTUM_WARMUP_STEPS", 500))
    beta1 = float(os.environ.get("BETA1", 0.9))
    beta2 = float(os.environ.get("BETA2", 0.95))
    adam_eps = float(os.environ.get("ADAM_EPS", 1e-8))
    grad_clip_norm = float(os.environ.get("GRAD_CLIP_NORM", 0.0))

# -----------------------------
# MUON OPTIMIZER
# -----------------------------
#
# As borrowed from modded-nanogpt
# Background on Muon: https://kellerjordan.github.io/posts/muon/

def zeropower_via_newtonschulz5(G: Tensor, steps: int = 10, eps: float = 1e-7) -> Tensor:
    # Orthogonalize a 2D update matrix with a fast Newton-Schulz iteration.
    # Muon uses this to normalize matrix-shaped gradients before applying them.
    # The quintic coefficients (a, b, c) come from the modded-nanogpt reference.
    a, b, c = (3.4445, -4.7750, 2.0315)
    X = G.bfloat16()
    X /= X.norm() + eps
    # Iterate on the orientation where rows <= cols; transpose back at the end.
    transposed = G.size(0) > G.size(1)
    if transposed:
        X = X.T
    for _ in range(steps):
        A = X @ X.T
        B = b * A + c * A @ A
        X = a * X + B @ X
    return X.T if transposed else X


class Muon(torch.optim.Optimizer):
    """Momentum + Newton-Schulz orthogonalization for matrix parameters.

    Work is sharded round-robin across ranks: rank r orthogonalizes params with
    index i where i % world_size == r, then an all_reduce shares all updates.
    """

    def __init__(self, params, lr: float, momentum: float, backend_steps: int, nesterov: bool = True):
        super().__init__(
            params,
            dict(lr=lr, momentum=momentum, backend_steps=backend_steps, nesterov=nesterov),
        )

    @torch.no_grad()
    def step(self, closure=None):
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        distributed = dist.is_available() and dist.is_initialized()
        world_size = dist.get_world_size() if distributed else 1
        rank = dist.get_rank() if distributed else 0

        for group in self.param_groups:
            params = group["params"]
            if not params:
                continue
            lr = group["lr"]
            momentum = group["momentum"]
            backend_steps = group["backend_steps"]
            nesterov = group["nesterov"]

            # One flat bf16 buffer holding a fixed slot per parameter; ranks that do
            # not own a param leave its slot zeroed so the SUM all_reduce merges cleanly.
            total_params = sum(int(p.numel()) for p in params)
            updates_flat = torch.zeros(total_params, device=params[0].device, dtype=torch.bfloat16)

            curr = 0
            for i, p in enumerate(params):
                if i % world_size == rank and p.grad is not None:
                    g = p.grad
                    state = self.state[p]
                    if "momentum_buffer" not in state:
                        state["momentum_buffer"] = torch.zeros_like(g)
                    buf = state["momentum_buffer"]
                    buf.mul_(momentum).add_(g)
                    if nesterov:
                        g = g.add(buf, alpha=momentum)
                    g = zeropower_via_newtonschulz5(g, steps=backend_steps)
                    # Scale correction from Muon reference implementations.
                    g *= max(1, g.size(0) / g.size(1)) ** 0.5
                    updates_flat[curr : curr + p.numel()] = g.reshape(-1)
                # Slot offsets advance for every param (owned or not) so the apply
                # loop below stays aligned with this packing.
                curr += p.numel()

            if distributed:
                dist.all_reduce(updates_flat, op=dist.ReduceOp.SUM)

            curr = 0
            for p in params:
                g = updates_flat[curr : curr + p.numel()].view_as(p).to(dtype=p.dtype)
                p.add_(g, alpha=-lr)
                curr += p.numel()

        return loss


# -----------------------------
# TOKENIZER-AGNOSTIC EVALUATION SETUP
# -----------------------------
#
# It's common for small models have a large fraction of their parameters be embeddings, since the 2 * d_model * d_vocab vectors can be gigantic.
# Instead of locking the tokenizer, we let you bring your own and calculate our validation metrics on the average compression of the validation set.
# We calculate BPB (bits-per-byte) instead of validation loss, so we need methods to count the number of bits per token in the tokenizer.
# Note: Submissions that edit the tokenizer will be examined more carefully, since screwing this up might unjustly improve your score.
def build_sentencepiece_luts(
    sp: spm.SentencePieceProcessor, vocab_size: int, device: torch.device
) -> tuple[Tensor, Tensor, Tensor]:
    """Build per-token lookup tables used to convert token counts into byte counts.

    Returns three tensors indexed by token id:
      - base_bytes: UTF-8 byte length of the piece (leading "▁" marker stripped)
      - has_leading_space: True when the piece starts with the "▁" word-boundary marker
      - is_boundary_token: True for control/unknown/unused tokens
    """
    sp_vocab_size = int(sp.vocab_size())
    # Tables are sized to cover both the tokenizer's and the model's vocab.
    table_size = max(sp_vocab_size, vocab_size)
    base_bytes_np = np.zeros((table_size,), dtype=np.int16)
    has_leading_space_np = np.zeros((table_size,), dtype=np.bool_)
    is_boundary_token_np = np.ones((table_size,), dtype=np.bool_)
    for token_id in range(sp_vocab_size):
        if sp.is_control(token_id) or sp.is_unknown(token_id) or sp.is_unused(token_id):
            continue
        is_boundary_token_np[token_id] = False
        if sp.is_byte(token_id):
            # Byte-fallback pieces always encode exactly one byte.
            base_bytes_np[token_id] = 1
            continue
        piece = sp.id_to_piece(token_id)
        if piece.startswith("▁"):
            has_leading_space_np[token_id] = True
            piece = piece[1:]
        base_bytes_np[token_id] = len(piece.encode("utf-8"))
    return (
        torch.tensor(base_bytes_np, dtype=torch.int16, device=device),
        torch.tensor(has_leading_space_np, dtype=torch.bool, device=device),
        torch.tensor(is_boundary_token_np, dtype=torch.bool, device=device),
    )


def load_validation_tokens(pattern: str, seq_len: int) -> Tensor:
    """Concatenate all validation shards and trim to a whole number of sequences (+1 shift token)."""
    files = [Path(p) for p in sorted(glob.glob(pattern))]
    if not files:
        raise FileNotFoundError(f"No files found for pattern: {pattern}")
    # The export pipeline writes the fixed first-50k-doc validation set to fineweb_val_*.
    tokens = torch.cat([load_data_shard(file) for file in files]).contiguous()
    usable = ((tokens.numel() - 1) // seq_len) * seq_len
    if usable <= 0:
        raise ValueError(f"Validation split is too short for TRAIN_SEQ_LEN={seq_len}")
    return tokens[: usable + 1]


def eval_val(
    args: Hyperparameters,
    model: nn.Module,
    rank: int,
    world_size: int,
    device: torch.device,
    grad_accum_steps: int,
    val_tokens: Tensor,
    base_bytes_lut: Tensor,
    has_leading_space_lut: Tensor,
    is_boundary_token_lut: Tensor,
) -> tuple[float, float]:
    """Run full validation; returns (val_loss, val_bpb) reduced across all ranks."""
    # Validation computes two metrics:
    # - val_loss: token cross-entropy (natural log)
    # - val_bpb: tokenizer-agnostic compression metric used by the challenge
    local_batch_tokens = args.val_batch_size // (world_size * grad_accum_steps)
    if local_batch_tokens < args.train_seq_len:
        raise ValueError(
            "VAL_BATCH_SIZE must provide at least one sequence per rank; "
            f"got VAL_BATCH_SIZE={args.val_batch_size}, WORLD_SIZE={world_size}, "
            f"GRAD_ACCUM_STEPS={grad_accum_steps}, TRAIN_SEQ_LEN={args.train_seq_len}"
        )
    local_batch_seqs = local_batch_tokens // args.train_seq_len
    # Each rank evaluates a disjoint contiguous span of sequences.
    total_seqs = (val_tokens.numel() - 1) // args.train_seq_len
    seq_start = (total_seqs * rank) // world_size
    seq_end = (total_seqs * (rank + 1)) // world_size
    val_loss_sum = torch.zeros((), device=device, dtype=torch.float64)
    val_token_count = torch.zeros((), device=device, dtype=torch.float64)
    val_byte_count = torch.zeros((), device=device, dtype=torch.float64)

    model.eval()
    with torch.inference_mode():
        for batch_seq_start in range(seq_start, seq_end, local_batch_seqs):
            batch_seq_end = min(batch_seq_start + local_batch_seqs, seq_end)
            raw_start = batch_seq_start * args.train_seq_len
            # "+1" so targets can be formed by shifting the same span by one token.
            raw_end = batch_seq_end * args.train_seq_len + 1
            local = val_tokens[raw_start:raw_end].to(device=device, dtype=torch.int64, non_blocking=True)
            x = local[:-1].reshape(-1, args.train_seq_len)
            y = local[1:].reshape(-1, args.train_seq_len)
            with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True):
                batch_loss = model(x, y).detach()
            batch_token_count = float(y.numel())
            val_loss_sum += batch_loss.to(torch.float64) * batch_token_count
            val_token_count += batch_token_count
            # Byte accounting: each target token contributes its piece's bytes, plus one
            # byte for the leading space unless the previous token was a boundary token.
            prev_ids = x.reshape(-1)
            tgt_ids = y.reshape(-1)
            token_bytes = base_bytes_lut[tgt_ids].to(dtype=torch.int16)
            token_bytes += (has_leading_space_lut[tgt_ids] & ~is_boundary_token_lut[prev_ids]).to(dtype=torch.int16)
            val_byte_count += token_bytes.to(torch.float64).sum()

    if dist.is_available() and dist.is_initialized():
        dist.all_reduce(val_loss_sum, op=dist.ReduceOp.SUM)
        dist.all_reduce(val_token_count, op=dist.ReduceOp.SUM)
        dist.all_reduce(val_byte_count, op=dist.ReduceOp.SUM)

    val_loss = val_loss_sum / val_token_count
    # bpb = (bits per token) * (tokens per byte)
    bits_per_token = val_loss.item() / math.log(2.0)
    tokens_per_byte = val_token_count.item() / val_byte_count.item()
    model.train()
    return float(val_loss.item()), float(bits_per_token * tokens_per_byte)

# -----------------------------
# POST-TRAINING QUANTIZATION
# -----------------------------
#
# It's silly to export our model, which is trained in bf16 and fp32, at that same precision.
# Instead, we get approximately the same model (with a small hit) by quantizing the model to int8 & zlib compressing.
# We can then decompress the model and run in higher precision for evaluation, after closing in under the size limit.
# Parameter-name substrings treated as "control" tensors (small per-channel gains/mixes):
# kept in fp32 during training and excluded from int8 quantization at export.
CONTROL_TENSOR_NAME_PATTERNS = tuple(
    pattern
    for pattern in os.environ.get(
        "CONTROL_TENSOR_NAME_PATTERNS",
        "attn_scale,attn_scales,mlp_scale,mlp_scales,resid_mix,resid_mixes,q_gain,skip_weight,skip_weights,ts_attn,ts_mlp,resid_mix_logit,ts_attn_beta,ts_mlp_beta",
    ).split(",")
    if pattern
)
INT8_KEEP_FLOAT_FP32_NAME_PATTERNS = tuple(
    pattern
    for pattern in os.environ.get(
        "INT8_KEEP_FLOAT_FP32_NAME_PATTERNS",
        ",".join(CONTROL_TENSOR_NAME_PATTERNS),
    ).split(",")
    if pattern
)
INT8_KEEP_FLOAT_MAX_NUMEL = 65_536
INT8_KEEP_FLOAT_STORE_DTYPE = torch.float16
INT8_PER_ROW_SCALE_DTYPE = torch.float16
INT8_CLIP_PERCENTILE = 99.99984
INT8_CLIP_Q = INT8_CLIP_PERCENTILE / 100.0

def tensor_nbytes(t: Tensor) -> int:
    """Storage size of a tensor's elements in bytes."""
    return int(t.numel()) * int(t.element_size())

def keep_float_tensor(name: str, t: Tensor, passthrough_orig_dtypes: dict[str, str]) -> Tensor:
    """Return the float tensor to store for a small/control param, recording any dtype downcast."""
    if any(pattern in name for pattern in INT8_KEEP_FLOAT_FP32_NAME_PATTERNS):
        # Control tensors stay full fp32.
        return t.float().contiguous()
    if t.dtype in {torch.float32, torch.bfloat16}:
        # Remember the original dtype so dequantize can restore it exactly.
        passthrough_orig_dtypes[name] = str(t.dtype).removeprefix("torch.")
        return t.to(dtype=INT8_KEEP_FLOAT_STORE_DTYPE).contiguous()
    return t

def quantize_float_tensor(t: Tensor) -> tuple[Tensor, Tensor]:
    """Quantize a float tensor to int8; returns (q, scale) with per-row or per-tensor scale."""
    t32 = t.float()
    if t32.ndim == 2:
        # Matrices get one scale per row, which usually tracks output-channel
        # ranges much better than a single tensor-wide scale.
        clip_abs = (
            torch.quantile(t32.abs(), INT8_CLIP_Q, dim=1)
            if t32.numel()
            else torch.empty((t32.shape[0],), dtype=torch.float32)
        )
        clipped = torch.maximum(torch.minimum(t32, clip_abs[:, None]), -clip_abs[:, None])
        scale = (clip_abs / 127.0).clamp_min(1.0 / 127.0)
        q = torch.clamp(torch.round(clipped / scale[:, None]), -127, 127).to(torch.int8).contiguous()
        return q, scale.to(dtype=INT8_PER_ROW_SCALE_DTYPE).contiguous()

    # Vectors / scalars use a simpler per-tensor scale.
    clip_abs = float(torch.quantile(t32.abs().flatten(), INT8_CLIP_Q).item()) if t32.numel() else 0.0
    scale = torch.tensor(clip_abs / 127.0 if clip_abs > 0 else 1.0, dtype=torch.float32)
    q = torch.clamp(torch.round(torch.clamp(t32, -clip_abs, clip_abs) / scale), -127, 127).to(torch.int8).contiguous()
    return q, scale

def quantize_state_dict_int8(state_dict: dict[str, Tensor]):
    """Quantize a state dict to the int8 export format; returns (payload_obj, stats)."""
    # Single supported clean-script export format:
    # - per-row int8 for 2D float tensors
    # - per-tensor int8 for other float tensors
    # - exact passthrough for non-floats
    # - passthrough for small float tensors, stored as fp16 to save bytes
    quantized: dict[str, Tensor] = {}
    scales: dict[str, Tensor] = {}
    dtypes: dict[str, str] = {}
    passthrough: dict[str, Tensor] = {}
    passthrough_orig_dtypes: dict[str, str] = {}
    qmeta: dict[str, dict[str, object]] = {}
    stats = dict.fromkeys(
        ("param_count", "num_tensors", "num_float_tensors", "num_nonfloat_tensors", "baseline_tensor_bytes", "int8_payload_bytes"),
        0,
    )

    for name, tensor in state_dict.items():
        t = tensor.detach().to("cpu").contiguous()
        stats["param_count"] += int(t.numel())
        stats["num_tensors"] += 1
        stats["baseline_tensor_bytes"] += tensor_nbytes(t)

        if not t.is_floating_point():
            stats["num_nonfloat_tensors"] += 1
            passthrough[name] = t
            stats["int8_payload_bytes"] += tensor_nbytes(t)
            continue

        # Small float tensors are cheap enough to keep directly. We still downcast
        # fp32/bf16 passthrough tensors to fp16 so metadata does not dominate size.
        if t.numel() <= INT8_KEEP_FLOAT_MAX_NUMEL:
            kept = keep_float_tensor(name, t, passthrough_orig_dtypes)
            passthrough[name] = kept
            stats["int8_payload_bytes"] += tensor_nbytes(kept)
            continue

        stats["num_float_tensors"] += 1
        q, s = quantize_float_tensor(t)
        if s.ndim > 0:
            # A non-scalar scale marks the per-row scheme for the decoder.
            qmeta[name] = {"scheme": "per_row", "axis": 0}
        quantized[name] = q
        scales[name] = s
        dtypes[name] = str(t.dtype).removeprefix("torch.")
        stats["int8_payload_bytes"] += tensor_nbytes(q) + tensor_nbytes(s)

    obj: dict[str, object] = {
        "__quant_format__": "int8_clean_per_row_v1",
        "quantized": quantized,
        "scales": scales,
        "dtypes": dtypes,
        "passthrough": passthrough,
    }
    if qmeta:
        obj["qmeta"] = qmeta
    if passthrough_orig_dtypes:
        obj["passthrough_orig_dtypes"] = passthrough_orig_dtypes
    return obj, stats

def dequantize_state_dict_int8(obj: dict[str, object]) -> dict[str, Tensor]:
    """Invert quantize_state_dict_int8, restoring original dtypes for all tensors."""
    out: dict[str, Tensor] = {}
    qmeta = obj.get("qmeta", {})
    passthrough_orig_dtypes = obj.get("passthrough_orig_dtypes", {})
    for name, q in obj["quantized"].items():
        dtype = getattr(torch, obj["dtypes"][name])
        s = obj["scales"][name]
        if qmeta.get(name, {}).get("scheme") == "per_row" or s.ndim > 0:
            s = s.to(dtype=torch.float32)
            # Broadcast the saved row scale back across trailing dimensions.
            out[name] = (q.float() * s.view(q.shape[0], *([1] * (q.ndim - 1)))).to(dtype=dtype).contiguous()
        else:
            scale = float(s.item())
            out[name] = (q.float() * scale).to(dtype=dtype).contiguous()
    for name, t in obj["passthrough"].items():
        # Restore small tensors, undoing the temporary fp16 storage cast if needed.
        out_t = t.detach().to("cpu").contiguous()
        orig_dtype = passthrough_orig_dtypes.get(name)
        if isinstance(orig_dtype, str):
            out_t = out_t.to(dtype=getattr(torch, orig_dtype)).contiguous()
        out[name] = out_t
    return out


# -----------------------------
# DATA LOADING
# -----------------------------

def load_data_shard(file: Path) -> Tensor:
    # NOTE(review): the span below is garbled in this copy — everything between the
    # "<" in the np.dtype string and the "->" of a later method signature was lost in
    # extraction. The missing text should contain the rest of load_data_shard (header
    # parse + token read) and the TokenStream class header/__init__. Restore it from
    # the original script; do not guess the shard format here.
    header_bytes = 256 * np.dtype(" None:
        self.file_idx = (self.file_idx + 1) % len(self.files)
        self.tokens = load_data_shard(self.files[self.file_idx])
        self.pos = 0

    def take(self, n: int) -> Tensor:
        # Pull n tokens from the stream, crossing shard files as needed.
        chunks: list[Tensor] = []
        remaining = n
        while remaining > 0:
            avail = self.tokens.numel() - self.pos
            if avail <= 0:
                self._advance_file()
                continue
            k = min(remaining, avail)
            chunks.append(self.tokens[self.pos : self.pos + k])
            self.pos += k
            remaining -= k
        return chunks[0] if len(chunks) == 1 else torch.cat(chunks)


class DistributedTokenLoader:
    # Each call consumes a contiguous chunk from the shared token stream, then slices out
    # one disjoint span per rank. The extra "+1" token lets us build (x, y) by shifting.
    def __init__(self, pattern: str, rank: int, world_size: int, device: torch.device):
        self.rank = rank
        self.world_size = world_size
        self.device = device
        self.stream = TokenStream(pattern)

    def next_batch(self, global_tokens: int, seq_len: int, grad_accum_steps: int) -> tuple[Tensor, Tensor]:
        """Return (x, y) int64 tensors of shape (local_tokens // seq_len, seq_len) for this rank."""
        local_tokens = global_tokens // (self.world_size * grad_accum_steps)
        per_rank_span = local_tokens + 1
        chunk = self.stream.take(per_rank_span * self.world_size)
        start = self.rank * per_rank_span
        local = chunk[start : start + per_rank_span].to(dtype=torch.int64)
        x = local[:-1].reshape(-1, seq_len)
        y = local[1:].reshape(-1, seq_len)
        return x.to(self.device, non_blocking=True), y.to(self.device, non_blocking=True)

# -----------------------------
# TRANSFORMER MODULES
# -----------------------------

class RMSNorm(nn.Module):
    """Parameter-free RMS normalization over the last dimension."""

    def __init__(self, eps: float | None = None):
        super().__init__()
        self.eps = eps

    def forward(self, x: Tensor) -> Tensor:
        return F.rms_norm(x, (x.size(-1),), eps=self.eps)


class CastedLinear(nn.Linear):
    # Keep weights in fp32 for optimizer/state quality, cast at matmul time for bf16 compute.
    def forward(self, x: Tensor) -> Tensor:
        bias = self.bias.to(x.dtype) if self.bias is not None else None
        return F.linear(x, self.weight.to(x.dtype), bias)


def restore_low_dim_params_to_fp32(module: nn.Module) -> None:
    # Keep small/control parameters in fp32 even when the model body runs in bf16.
    with torch.no_grad():
        for name, param in module.named_parameters():
            if (param.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS)) and param.dtype != torch.float32:
                param.data = param.data.float()


class Rotary(nn.Module):
    # Caches cos/sin tables per sequence length on the current device.
    def __init__(self, dim: int, base: float = 10000.0):
        super().__init__()
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim))
        self.register_buffer("inv_freq", inv_freq, persistent=False)
        self._seq_len_cached = 0
        self._cos_cached: Tensor | None = None
        self._sin_cached: Tensor | None = None

    def forward(self, seq_len: int, device: torch.device, dtype: torch.dtype) -> tuple[Tensor, Tensor]:
        # Rebuild the cache on a new seq_len or device; dtype conversion happens at return.
        if (
            self._cos_cached is None
            or self._sin_cached is None
            or self._seq_len_cached != seq_len
            or self._cos_cached.device != device
        ):
            t = torch.arange(seq_len, device=device, dtype=self.inv_freq.dtype)
            freqs = torch.outer(t, self.inv_freq.to(device))
            self._cos_cached = freqs.cos()[None, None, :, :]
            self._sin_cached = freqs.sin()[None, None, :, :]
            self._seq_len_cached = seq_len
        return self._cos_cached.to(dtype=dtype), self._sin_cached.to(dtype=dtype)


def apply_rotary_emb(x: Tensor, cos: Tensor, sin: Tensor) -> Tensor:
    """Apply rotary position embedding using the half-split (non-interleaved) layout."""
    half = x.size(-1) // 2
    x1, x2 = x[..., :half], x[..., half:]
    return torch.cat((x1 * cos + x2 * sin, x1 * (-sin) + x2 * cos), dim=-1)


class CausalSelfAttention(nn.Module):
    """GQA causal attention with QK RMS-norm, RoPE, and a learned per-head query gain."""

    def __init__(
        self,
        dim: int,
        num_heads: int,
        num_kv_heads: int,
        rope_base: float,
        qk_gain_init: float,
    ):
        super().__init__()
        if dim % num_heads != 0:
            raise ValueError("model_dim must be divisible by num_heads")
        if num_heads % num_kv_heads != 0:
            raise ValueError("num_heads must be divisible by num_kv_heads")
        self.num_heads = num_heads
        self.num_kv_heads = num_kv_heads
        self.head_dim = dim // num_heads
        if self.head_dim % 2 != 0:
            raise ValueError("head_dim must be even for RoPE")
        kv_dim = self.num_kv_heads * self.head_dim
        self.c_q = CastedLinear(dim, dim, bias=False)
        self.c_k = CastedLinear(dim, kv_dim, bias=False)
        self.c_v = CastedLinear(dim, kv_dim, bias=False)
        self.proj = CastedLinear(dim, dim, bias=False)
        # Zero-init the output projection so the block starts as an identity residual.
        self.proj._zero_init = True
        self.q_gain = nn.Parameter(torch.full((num_heads,), qk_gain_init, dtype=torch.float32))
        self.rotary = Rotary(self.head_dim, base=rope_base)

    def forward(self, x: Tensor) -> Tensor:
        bsz, seqlen, dim = x.shape
        q = self.c_q(x).reshape(bsz, seqlen, self.num_heads, self.head_dim).transpose(1, 2)
        k = self.c_k(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2)
        v = self.c_v(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2)
        # QK-norm stabilizes attention logits; q_gain then sets the effective temperature per head.
        q = F.rms_norm(q, (q.size(-1),))
        k = F.rms_norm(k, (k.size(-1),))
        cos, sin = self.rotary(seqlen, x.device, q.dtype)
        q = apply_rotary_emb(q, cos, sin)
        k = apply_rotary_emb(k, cos, sin)
        q = q * self.q_gain.to(dtype=q.dtype)[None, :, None, None]
        if FA3_AVAILABLE:
            # FA3 expects (B, T, H, D) — transpose from (B, H, T, D)
            y = fa3.flash_attn_func(
                q.transpose(1, 2),
                k.transpose(1, 2),
                v.transpose(1, 2),
                causal=True,
            ).transpose(1, 2)
        else:
            y = F.scaled_dot_product_attention(
                q, k, v,
                attn_mask=None,
                is_causal=True,
                enable_gqa=(self.num_kv_heads != self.num_heads),
            )
        y = y.transpose(1, 2).contiguous().reshape(bsz, seqlen, dim)
        return self.proj(y)


class MLP(nn.Module):
    # leaky_relu^2 MLP (slope=0.5 matches SOTA, PRs #493/#518)
    def __init__(self, dim: int, mlp_mult: int, leaky_relu_slope: float = 0.5):
        super().__init__()
        hidden = mlp_mult * dim
        self.fc = CastedLinear(dim, hidden, bias=False)
        self.proj = CastedLinear(hidden, dim, bias=False)
        self.proj._zero_init = True
        self.leaky_relu_slope = leaky_relu_slope

    def forward(self, x: Tensor) -> Tensor:
        # Activation is leaky_relu followed by squaring ("relu^2" family).
        x = F.leaky_relu(self.fc(x), negative_slope=self.leaky_relu_slope)
        return self.proj(x.square())


class TimestepScaling(nn.Module):
    """Learned per-iteration scale vectors (arXiv:2410.01405, ICML 2025)."""
    def __init__(self, effective_layers: int, dim: int, gamma_max: float = 0.0, use_bias: bool = False):
        super().__init__()
        # One (dim,) gamma per virtual layer position, separately for attn and MLP branches.
        self.attn_gamma = nn.Parameter(torch.ones(effective_layers, dim, dtype=torch.float32))
        self.mlp_gamma = nn.Parameter(torch.ones(effective_layers, dim, dtype=torch.float32))
        self.gamma_max = gamma_max  # 0 = uncapped
        if use_bias:
            self.attn_beta = nn.Parameter(torch.zeros(effective_layers, dim, dtype=torch.float32))
            self.mlp_beta = nn.Parameter(torch.zeros(effective_layers, dim, dtype=torch.float32))
        else:
            self.attn_beta = None
        
            self.mlp_beta = None

    def get(self, v: int) -> tuple[Tensor, Tensor, Tensor | None, Tensor | None]:
        """Return (attn_gamma, mlp_gamma, attn_beta, mlp_beta) for virtual layer v."""
        ag = self.attn_gamma[v]
        mg = self.mlp_gamma[v]
        if self.gamma_max > 0:
            ag = ag.clamp(-self.gamma_max, self.gamma_max)
            mg = mg.clamp(-self.gamma_max, self.gamma_max)
        ab = self.attn_beta[v] if self.attn_beta is not None else None
        mb = self.mlp_beta[v] if self.mlp_beta is not None else None
        return ag, mg, ab, mb


class SharedAttnLayer(nn.Module):
    """Shared attention layer (mixing + attention only, no MLP) for attn-only sharing mode."""
    def __init__(
        self,
        dim: int,
        num_heads: int,
        num_kv_heads: int,
        rope_base: float,
        qk_gain_init: float,
        use_birkhoff_mix: bool = False,
    ):
        super().__init__()
        self.use_birkhoff_mix = use_birkhoff_mix
        self.attn_norm = RMSNorm()
        self.attn = CausalSelfAttention(dim, num_heads, num_kv_heads, rope_base, qk_gain_init)
        self.attn_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32))
        if use_birkhoff_mix:
            # Convex (sigmoid-gated) mix of the running state and the embedding x0.
            self.resid_mix_logit = nn.Parameter(torch.zeros(dim, dtype=torch.float32))
        else:
            # Unconstrained two-coefficient mix, initialized to pass x through unchanged.
            self.resid_mix = nn.Parameter(torch.stack((torch.ones(dim), torch.zeros(dim))).float())

    def forward(self, x: Tensor, x0: Tensor,
                ts_attn_gamma: Tensor | None = None,
                ts_attn_beta: Tensor | None = None) -> Tensor:
        if self.use_birkhoff_mix:
            alpha = torch.sigmoid(self.resid_mix_logit.to(dtype=x.dtype))[None, None, :]
            x = alpha * x + (1 - alpha) * x0
        else:
            mix = self.resid_mix.to(dtype=x.dtype)
            x = mix[0][None, None, :] * x + mix[1][None, None, :] * x0
        attn_out = self.attn(self.attn_norm(x))
        attn_s = self.attn_scale.to(dtype=x.dtype)[None, None, :]
        if ts_attn_gamma is not None:
            attn_s = attn_s * ts_attn_gamma[None, None, :]
        x = x + attn_s * attn_out
        if ts_attn_beta is not None:
            x = x + ts_attn_beta[None, None, :]
        return x


class UniqueMLP(nn.Module):
    """Unique MLP per virtual shared position for attn-only sharing mode."""
    def __init__(
        self,
        dim: int,
        mlp_mult: int,
        use_peri_norm: bool = False,
        leaky_relu_slope: float = 0.5,
    ):
        super().__init__()
        self.use_peri_norm = use_peri_norm
        if use_peri_norm:
            self.mlp_out_norm = RMSNorm()
        else:
            self.mlp_norm = RMSNorm()
        self.mlp = MLP(dim, mlp_mult, leaky_relu_slope)
        self.mlp_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32))

    def forward(self, x: Tensor,
                ts_mlp_gamma: Tensor | None = None,
                ts_mlp_beta: Tensor | None = None) -> Tensor:
        # Returns only the scaled MLP contribution; the caller adds the residual.
        mlp_out = self.mlp_out_norm(self.mlp(x)) if self.use_peri_norm else self.mlp(self.mlp_norm(x))
        mlp_s = self.mlp_scale.to(dtype=x.dtype)[None, None, :]
        if ts_mlp_gamma is not None:
            mlp_s = mlp_s * ts_mlp_gamma[None, None, :]
        out = mlp_s * mlp_out
        if ts_mlp_beta is not None:
            out = out + ts_mlp_beta[None, None, :]
        return out


class Block(nn.Module):
    """Standard transformer block: residual mix with x0, attention, then MLP."""

    def __init__(
        self,
        dim: int,
        num_heads: int,
        num_kv_heads: int,
        mlp_mult: int,
        rope_base: float,
        qk_gain_init: float,
        use_peri_norm: bool = False,
        use_birkhoff_mix: bool = False,
        leaky_relu_slope: float = 0.5,
    ):
        super().__init__()
        self.use_peri_norm = use_peri_norm
        self.use_birkhoff_mix = use_birkhoff_mix
        self.attn_norm = RMSNorm()
        if use_peri_norm:
            self.mlp_out_norm = RMSNorm()  # output norm: bounds MLP contribution, preserves input magnitude
        else:
            self.mlp_norm = RMSNorm()
        self.attn = CausalSelfAttention(dim, num_heads, num_kv_heads, rope_base, qk_gain_init)
        self.mlp = MLP(dim, mlp_mult, leaky_relu_slope)
        self.attn_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32))
        self.mlp_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32))
        if use_birkhoff_mix:
            self.resid_mix_logit = nn.Parameter(torch.zeros(dim, dtype=torch.float32))
        else:
            self.resid_mix = nn.Parameter(torch.stack((torch.ones(dim), torch.zeros(dim))).float())

    def forward(self, x: Tensor, x0: Tensor,
                ts_attn_gamma: Tensor | None = None,
                ts_mlp_gamma: Tensor | None = None,
                ts_attn_beta: Tensor | None = None,
                ts_mlp_beta: Tensor | None = None) -> Tensor:
        if self.use_birkhoff_mix:
            alpha = torch.sigmoid(self.resid_mix_logit.to(dtype=x.dtype))[None, None, :]
            x = alpha * x + (1 - alpha) * x0
        else:
            mix = self.resid_mix.to(dtype=x.dtype)
            x = mix[0][None, None, :] * x + mix[1][None, None, :] * x0
        attn_out = self.attn(self.attn_norm(x))
        attn_s = self.attn_scale.to(dtype=x.dtype)[None, None, :]
        if ts_attn_gamma is not None:
            attn_s = attn_s * ts_attn_gamma[None, None, :]
        x = x + attn_s * attn_out
        if ts_attn_beta is not None:
            x = x + ts_attn_beta[None, None, :]
        mlp_out = self.mlp_out_norm(self.mlp(x)) if self.use_peri_norm else self.mlp(self.mlp_norm(x))
        mlp_s = self.mlp_scale.to(dtype=x.dtype)[None, None, :]
        if ts_mlp_gamma is not None:
            mlp_s = mlp_s * ts_mlp_gamma[None, None, :]
        x = x + mlp_s * mlp_out
        if ts_mlp_beta is not None:
            x = x + ts_mlp_beta[None, None, :]
        return x


class GPT(nn.Module):
    """GPT with two topologies: a recurrent (weight-shared) stack when num_shared > 0,
    otherwise a U-Net-style encoder/decoder stack with learned skip weights."""

    def __init__(
        self,
        vocab_size: int,
        num_layers: int,
        model_dim: int,
        num_heads: int,
        num_kv_heads: int,
        mlp_mult: int,
        tie_embeddings: bool,
        tied_embed_init_std: float,
        logit_softcap: float,
        rope_base: float,
        qk_gain_init: float,
        num_prelude: int = 0,
        num_coda: int = 0,
        num_shared: int = 0,
        num_loops: int = 2,
        use_peri_norm: bool = False,
        use_birkhoff_mix: bool = False,
        use_timestep_scale: bool = False,
        use_timestep_bias: bool = False,
        share_attn_only: bool = False,
        timestep_gamma_max: float = 0.0,
        leaky_relu_slope: float = 0.5,
    ):
        super().__init__()
        if logit_softcap <= 0.0:
            raise ValueError(f"logit_softcap must be positive, got {logit_softcap}")
        self.tie_embeddings = tie_embeddings
        self.tied_embed_init_std = tied_embed_init_std
        self.logit_softcap = logit_softcap
        self.tok_emb = nn.Embedding(vocab_size, model_dim)
        # Recurrence (and all its stabilizers) only engages when shared blocks exist.
        self.use_recurrence = num_shared > 0
        self.num_loops = num_loops if self.use_recurrence else 1
        self.num_prelude = num_prelude if self.use_recurrence else 0
        self.num_coda = num_coda if self.use_recurrence else 0

        block_kwargs = dict(
            dim=model_dim, num_heads=num_heads, num_kv_heads=num_kv_heads,
            mlp_mult=mlp_mult, rope_base=rope_base, qk_gain_init=qk_gain_init,
            use_peri_norm=use_peri_norm if self.use_recurrence else False,
            use_birkhoff_mix=use_birkhoff_mix if self.use_recurrence else False,
            leaky_relu_slope=leaky_relu_slope,
        )

        self.share_attn_only = share_attn_only if self.use_recurrence else False

        if self.use_recurrence:
            self.prelude_blocks = nn.ModuleList([Block(**block_kwargs) for _ in range(self.num_prelude)])
            self.coda_blocks = nn.ModuleList([Block(**block_kwargs) for _ in range(self.num_coda)])
            if self.share_attn_only:
                # Attention weights are shared across loops; each virtual position gets its own MLP.
                shared_attn_kwargs = dict(
                    dim=model_dim, num_heads=num_heads, num_kv_heads=num_kv_heads,
                    rope_base=rope_base, qk_gain_init=qk_gain_init,
                    use_birkhoff_mix=use_birkhoff_mix,
                )
                unique_mlp_kwargs = dict(
                    dim=model_dim, mlp_mult=mlp_mult,
                    use_peri_norm=use_peri_norm,
                    leaky_relu_slope=leaky_relu_slope,
                )
                self.shared_attn_layers = nn.ModuleList([SharedAttnLayer(**shared_attn_kwargs) for _ in range(num_shared)])
                self.unique_mlps = nn.ModuleList([UniqueMLP(**unique_mlp_kwargs) for _ in range(num_shared * self.num_loops)])
                self.shared_blocks = nn.ModuleList()  # empty — keeps diagnostics safe
            else:
                self.shared_blocks = nn.ModuleList([Block(**block_kwargs) for _ in range(num_shared)])
                self.shared_attn_layers = nn.ModuleList()
                self.unique_mlps = nn.ModuleList()
            # Virtual depth: each loop through the shared stack counts as distinct layers.
            effective_layers = self.num_prelude + num_shared * self.num_loops + self.num_coda
            self.timestep_scale = TimestepScaling(effective_layers, model_dim, gamma_max=timestep_gamma_max, use_bias=use_timestep_bias) if use_timestep_scale else None
        else:
            # Standard U-Net path
            self.num_encoder_layers = num_layers // 2
            self.num_decoder_layers = num_layers - self.num_encoder_layers
            self.num_skip_weights = min(self.num_encoder_layers, self.num_decoder_layers)
            self.skip_weights = nn.Parameter(torch.ones(self.num_skip_weights, model_dim, dtype=torch.float32))
            self.blocks = nn.ModuleList([Block(**block_kwargs) for _ in range(num_layers)])
            self.timestep_scale = None

        self.final_norm = RMSNorm()
        self.lm_head = None if tie_embeddings else CastedLinear(model_dim, vocab_size, bias=False)
        if self.lm_head is not None:
            self.lm_head._zero_init = True
        self._init_weights()

    def _init_weights(self) -> None:
        # Tied embeddings start small; all _zero_init projections start at exactly zero.
        if self.tie_embeddings:
            nn.init.normal_(self.tok_emb.weight, mean=0.0, std=self.tied_embed_init_std)
        for module in self.modules():
            if isinstance(module, nn.Linear) and getattr(module, "_zero_init", False):
                nn.init.zeros_(module.weight)

    def _get_ts(self, v: int) -> tuple[Tensor | None, Tensor | None, Tensor | None, Tensor | None]:
        # Timestep gammas/betas for virtual layer v, or all-None when disabled.
        if self.timestep_scale is None:
            return None, None, None, None
        return self.timestep_scale.get(v)

    def forward(self, input_ids: Tensor, target_ids: Tensor) -> Tensor:
        """Return the mean cross-entropy loss over all target tokens."""
        x = self.tok_emb(input_ids)
        x = F.rms_norm(x, (x.size(-1),))
        x0 = x

        if self.use_recurrence:
            # v indexes the virtual (unrolled) layer position for timestep scaling.
            v = 0
            for block in self.prelude_blocks:
                ag, mg, ab, mb = self._get_ts(v)
                x = block(x, x0, ag, mg, ab, mb)
                v += 1
            if self.share_attn_only:
                vid = 0
                for _loop in range(self.num_loops):
                    for attn_layer in self.shared_attn_layers:
                        ag, mg, ab, mb = self._get_ts(v)
                        x = attn_layer(x, x0, ag, ab)
                        x = x + self.unique_mlps[vid](x, mg, mb)
                        vid += 1
                        v += 1
            else:
                for _loop in range(self.num_loops):
                    for block in self.shared_blocks:
                        ag, mg, ab, mb = self._get_ts(v)
                        x = block(x, x0, ag, mg, ab, mb)
                        v += 1
            for block in self.coda_blocks:
                ag, mg, ab, mb = self._get_ts(v)
                x = block(x, x0, ag, mg, ab, mb)
                v += 1
        else:
            # Encoder outputs feed the decoder in LIFO order via learned skip weights.
            skips: list[Tensor] = []
            for i in range(self.num_encoder_layers):
                x = self.blocks[i](x, x0)
                skips.append(x)
            for i in range(self.num_decoder_layers):
                if skips:
                    x = x + self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips.pop()
                x = self.blocks[self.num_encoder_layers + i](x, x0)

        x = self.final_norm(x).reshape(-1, x.size(-1))
        targets = target_ids.reshape(-1)
        if self.tie_embeddings:
            logits_proj = F.linear(x, self.tok_emb.weight)
        else:
            if self.lm_head is None:
                raise RuntimeError("lm_head is required when tie_embeddings=False")
            logits_proj = self.lm_head(x)
        # tanh softcap keeps logits in (-logit_softcap, logit_softcap).
        logits = self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap)
        return F.cross_entropy(logits.float(), targets, reduction="mean")


@torch.no_grad()
def recurrence_param_diagnostics(gpt: GPT) -> str:
    """Log parameter norms for recurrence health monitoring. No forward pass."""
    if not gpt.use_recurrence:
        return ""
    parts: list[str] = []

    # Birkhoff alpha stats per shared block/layer
    if gpt.share_attn_only:
        for i, layer in enumerate(gpt.shared_attn_layers):
            if hasattr(layer, "resid_mix_logit"):
                a = torch.sigmoid(layer.resid_mix_logit)
                parts.append(f"shared{i}_alpha:mean={a.mean().item():.3f},std={a.std().item():.3f}")
    else:
        for i, block in enumerate(gpt.shared_blocks):
            if hasattr(block, "resid_mix_logit"):
                a = torch.sigmoid(block.resid_mix_logit)
                parts.append(f"shared{i}_alpha:mean={a.mean().item():.3f},std={a.std().item():.3f}")

    # Effective MLP/attn contribution scale per virtual layer position:
    # mlp_scale * ts_mlp_gamma[v] is the learned multiplier on MLP output.
    # We report RMS (norm / sqrt(numel)) to make it scale-independent.
    v = 0
    effective_count = gpt.num_prelude + len(gpt.shared_blocks if not gpt.share_attn_only else gpt.shared_attn_layers) * gpt.num_loops + gpt.num_coda
    mlp_norms: list[str] = []
    attn_norms: list[str] = []

    # Prelude blocks
    for block in gpt.prelude_blocks:
        ms = block.mlp_scale.norm().item()
        asc = block.attn_scale.norm().item()
        d = block.mlp_scale.numel() ** 0.5
        if gpt.timestep_scale is not None:
            ms_g = gpt.timestep_scale.mlp_gamma[v].norm().item()
            as_g = gpt.timestep_scale.attn_gamma[v].norm().item()
            mlp_norms.append(f"v{v}:{ms * ms_g / d:.4f}")
            attn_norms.append(f"v{v}:{asc * as_g / d:.4f}")
        else:
            mlp_norms.append(f"v{v}:{ms / d:.4f}")
            attn_norms.append(f"v{v}:{asc / d:.4f}")
        v += 1

    # Shared positions
    if gpt.share_attn_only:
        vid = 0
        for _loop in range(gpt.num_loops):
            for layer in gpt.shared_attn_layers:
                asc = layer.attn_scale.norm().item()
                ms = gpt.unique_mlps[vid].mlp_scale.norm().item()
                d = layer.attn_scale.numel() ** 0.5
                if gpt.timestep_scale is not None:
                    ms_g = gpt.timestep_scale.mlp_gamma[v].norm().item()
                    as_g = gpt.timestep_scale.attn_gamma[v].norm().item()
                    mlp_norms.append(f"v{v}:{ms * ms_g / d:.4f}")
                    attn_norms.append(f"v{v}:{asc * as_g / d:.4f}")
                else:
                    mlp_norms.append(f"v{v}:{ms / d:.4f}")
                    attn_norms.append(f"v{v}:{asc / d:.4f}")
                vid += 1
                v += 1
    else:
        for _loop in range(gpt.num_loops):
            for block in gpt.shared_blocks:
                ms = block.mlp_scale.norm().item()
                asc = block.attn_scale.norm().item()
                d = block.mlp_scale.numel() ** 0.5
                if gpt.timestep_scale is not None:
                    ms_g = gpt.timestep_scale.mlp_gamma[v].norm().item()
                    as_g = gpt.timestep_scale.attn_gamma[v].norm().item()
                    mlp_norms.append(f"v{v}:{ms * ms_g / d:.4f}")
                    attn_norms.append(f"v{v}:{asc * as_g / d:.4f}")
                else:
                    mlp_norms.append(f"v{v}:{ms / d:.4f}")
                    attn_norms.append(f"v{v}:{asc / d:.4f}")
                v += 1

    # Coda blocks
    for block in gpt.coda_blocks:
        ms = block.mlp_scale.norm().item()
        asc = block.attn_scale.norm().item()
        d = block.mlp_scale.numel() ** 0.5
        if gpt.timestep_scale is not None:
            ms_g = gpt.timestep_scale.mlp_gamma[v].norm().item()
            as_g = gpt.timestep_scale.attn_gamma[v].norm().item()
            mlp_norms.append(f"v{v}:{ms * ms_g / d:.4f}")
            attn_norms.append(f"v{v}:{asc * as_g / d:.4f}")
        else:
            mlp_norms.append(f"v{v}:{ms / d:.4f}")
            attn_norms.append(f"v{v}:{asc / d:.4f}")
        v += 1

    parts.append("eff_mlp_scale:[" + " ".join(mlp_norms) + "]")
    parts.append("eff_attn_scale:[" + " ".join(attn_norms) + "]")
    if gpt.timestep_scale is not None and gpt.timestep_scale.attn_beta is not None:
        attn_bias_norms: list[str] = []
        mlp_bias_norms: list[str] = []
        for vi in range(effective_count):
            ab_rms = gpt.timestep_scale.attn_beta[vi].norm().item() / gpt.timestep_scale.attn_beta[vi].numel() ** 0.5
            mb_rms = gpt.timestep_scale.mlp_beta[vi].norm().item() / gpt.timestep_scale.mlp_beta[vi].numel() ** 0.5
            attn_bias_norms.append(f"v{vi}:{ab_rms:.4f}")
            mlp_bias_norms.append(f"v{vi}:{mb_rms:.4f}")
        parts.append("eff_attn_bias:[" + " ".join(attn_bias_norms) + "]")
        parts.append("eff_mlp_bias:[" + " ".join(mlp_bias_norms) + "]")
    return " ".join(parts)


# -----------------------------
# TRAINING
# -----------------------------

def main() -> None:
    # NOTE(review): main() continues beyond this excerpt; only its opening is visible here.
    global zeropower_via_newtonschulz5

    code = Path(__file__).read_text(encoding="utf-8")
    args = Hyperparameters()
    zeropower_via_newtonschulz5 = torch.compile(zeropower_via_newtonschulz5)

    # -----------------------------
    # DISTRIBUTED + CUDA SETUP
    # -----------------------------

    distributed = "RANK" in os.environ and "WORLD_SIZE" in os.environ
    rank = int(os.environ.get("RANK", "0"))
    world_size = int(os.environ.get("WORLD_SIZE", "1"))
    local_rank = int(os.environ.get("LOCAL_RANK", "0"))
    if world_size <= 0:
        raise ValueError(f"WORLD_SIZE must be positive, got {world_size}")
    if 8 % world_size != 0:
        raise ValueError(f"WORLD_SIZE={world_size} must divide 8 so 
grad_accum_steps stays integral") + grad_accum_steps = 8 // world_size + grad_scale = 1.0 / grad_accum_steps + if not torch.cuda.is_available(): + raise RuntimeError("CUDA is required") + device = torch.device("cuda", local_rank) + torch.cuda.set_device(device) + if distributed: + dist.init_process_group(backend="nccl", device_id=device) + dist.barrier() + master_process = rank == 0 + + # Fast math knobs + torch.backends.cuda.matmul.allow_tf32 = True + torch.backends.cudnn.allow_tf32 = True + from torch.backends.cuda import enable_cudnn_sdp, enable_flash_sdp, enable_math_sdp, enable_mem_efficient_sdp + + enable_cudnn_sdp(False) + enable_flash_sdp(True) + enable_mem_efficient_sdp(False) + enable_math_sdp(False) + + logfile = None + if master_process: + os.makedirs("logs", exist_ok=True) + logfile = f"logs/{args.run_id}.txt" + print(logfile) + + def log0(msg: str, console: bool = True) -> None: + if not master_process: + return + if console: + print(msg) + if logfile is not None: + with open(logfile, "a", encoding="utf-8") as f: + print(msg, file=f) + + log0(code, console=False) + log0("=" * 100, console=False) + log0(f"Running Python {sys.version}", console=False) + log0(f"Running PyTorch {torch.__version__}", console=False) + log0( + subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, check=False).stdout, + console=False, + ) + log0("=" * 100, console=False) + + # ----------------------------- + # TOKENIZER + VALIDATION METRIC SETUP + # ----------------------------- + + random.seed(args.seed) + np.random.seed(args.seed) + torch.manual_seed(args.seed) + torch.cuda.manual_seed_all(args.seed) + + if not args.tokenizer_path.endswith(".model"): + raise ValueError(f"Script only setup for SentencePiece .model file: {args.tokenizer_path}") + sp = spm.SentencePieceProcessor(model_file=args.tokenizer_path) + if int(sp.vocab_size()) != args.vocab_size: + raise ValueError( + f"VOCAB_SIZE={args.vocab_size} does not match tokenizer 
vocab_size={int(sp.vocab_size())}" + ) + dataset_dir = Path(args.data_path).resolve() + actual_train_files = len(list(dataset_dir.glob("fineweb_train_*.bin"))) + val_tokens = load_validation_tokens(args.val_files, args.train_seq_len) + base_bytes_lut, has_leading_space_lut, is_boundary_token_lut = build_sentencepiece_luts( + sp, args.vocab_size, device + ) + log0(f"val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path={args.tokenizer_path}") + log0(f"train_loader:dataset:{dataset_dir.name} train_shards:{actual_train_files}") + log0(f"val_loader:shards pattern={args.val_files} tokens:{val_tokens.numel() - 1}") + + # ----------------------------- + # MODEL + OPTIMIZER SETUP + # ----------------------------- + + base_model = GPT( + vocab_size=args.vocab_size, + num_layers=args.num_layers, + model_dim=args.model_dim, + num_heads=args.num_heads, + num_kv_heads=args.num_kv_heads, + mlp_mult=args.mlp_mult, + tie_embeddings=args.tie_embeddings, + tied_embed_init_std=args.tied_embed_init_std, + logit_softcap=args.logit_softcap, + rope_base=args.rope_base, + qk_gain_init=args.qk_gain_init, + num_prelude=args.num_prelude, + num_coda=args.num_coda, + num_shared=args.num_shared, + num_loops=args.num_loops, + use_peri_norm=args.use_peri_norm, + use_birkhoff_mix=args.use_birkhoff_mix, + use_timestep_scale=args.use_timestep_scale, + use_timestep_bias=args.use_timestep_bias, + share_attn_only=args.share_attn_only, + timestep_gamma_max=args.timestep_gamma_max, + leaky_relu_slope=args.leaky_relu_slope, + ).to(device).bfloat16() + for module in base_model.modules(): + if isinstance(module, CastedLinear): + module.float() + restore_low_dim_params_to_fp32(base_model) + if args.disable_compile: + compiled_model = base_model + else: + compiled_model = torch.compile(base_model, dynamic=False, fullgraph=True) + model: nn.Module = DDP(compiled_model, device_ids=[local_rank], broadcast_buffers=False) if distributed else compiled_model + + # Optimizer split: + # - token embedding (Adam) 
uses EMBED_LR + # - untied lm_head (Adam) uses HEAD_LR + # - matrix params in transformer blocks use MATRIX_LR via Muon + # - vectors/scalars use SCALAR_LR via Adam + if base_model.use_recurrence: + all_block_lists = [base_model.prelude_blocks, base_model.shared_blocks, base_model.coda_blocks] + block_named_params = [] + for bl in all_block_lists: + block_named_params.extend(bl.named_parameters()) + if base_model.share_attn_only: + block_named_params.extend(base_model.shared_attn_layers.named_parameters()) + block_named_params.extend(base_model.unique_mlps.named_parameters()) + else: + block_named_params = list(base_model.blocks.named_parameters()) + matrix_params = [ + p + for name, p in block_named_params + if p.ndim == 2 and not any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) + ] + scalar_params = [ + p + for name, p in block_named_params + if p.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) + ] + if hasattr(base_model, "skip_weights") and base_model.skip_weights.numel() > 0: + scalar_params.append(base_model.skip_weights) + if base_model.timestep_scale is not None: + for p in base_model.timestep_scale.parameters(): + scalar_params.append(p) + token_lr = args.tied_embed_lr if args.tie_embeddings else args.embed_lr + optimizer_tok = torch.optim.Adam( + [{"params": [base_model.tok_emb.weight], "lr": token_lr, "base_lr": token_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=True, + ) + optimizer_muon = Muon( + matrix_params, + lr=args.matrix_lr, + momentum=args.muon_momentum, + backend_steps=args.muon_backend_steps, + ) + for group in optimizer_muon.param_groups: + group["base_lr"] = args.matrix_lr + optimizer_scalar = torch.optim.Adam( + [{"params": scalar_params, "lr": args.scalar_lr, "base_lr": args.scalar_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=True, + ) + optimizers: list[torch.optim.Optimizer] = [optimizer_tok, optimizer_muon, optimizer_scalar] + if 
base_model.lm_head is not None: + optimizer_head = torch.optim.Adam( + [{"params": [base_model.lm_head.weight], "lr": args.head_lr, "base_lr": args.head_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=True, + ) + optimizers.insert(1, optimizer_head) + + n_params = sum(p.numel() for p in base_model.parameters()) + log0(f"model_params:{n_params}") + log0(f"world_size:{world_size} grad_accum_steps:{grad_accum_steps}") + log0("sdp_backends:cudnn=False flash=True mem_efficient=False math=False") + log0(f"flash_attention_3:{FA3_AVAILABLE}") + log0(f"attention_mode:gqa num_heads:{args.num_heads} num_kv_heads:{args.num_kv_heads}") + log0( + f"tie_embeddings:{args.tie_embeddings} embed_lr:{token_lr} " + f"head_lr:{args.head_lr if base_model.lm_head is not None else 0.0} " + f"matrix_lr:{args.matrix_lr} scalar_lr:{args.scalar_lr}" + ) + log0( + f"train_batch_tokens:{args.train_batch_tokens} train_seq_len:{args.train_seq_len} " + f"iterations:{args.iterations} warmup_steps:{args.warmup_steps} " + f"max_wallclock_seconds:{args.max_wallclock_seconds:.3f}" + ) + log0(f"seed:{args.seed}") + if base_model.use_recurrence: + num_shared = len(base_model.shared_attn_layers) if base_model.share_attn_only else len(base_model.shared_blocks) + eff = base_model.num_prelude + num_shared * base_model.num_loops + base_model.num_coda + shared_label = f"shared_attn:{num_shared}" if base_model.share_attn_only else f"shared:{num_shared}" + log0(f"recurrence:enabled prelude:{base_model.num_prelude} {shared_label} " + f"loops:{base_model.num_loops} coda:{base_model.num_coda} effective_layers:{eff}") + log0(f"peri_norm:{args.use_peri_norm} birkhoff_mix:{args.use_birkhoff_mix}") + if base_model.timestep_scale is not None: + ts_params = sum(p.numel() for p in base_model.timestep_scale.parameters()) + log0(f"timestep_scale:enabled params:{ts_params}") + else: + log0("timestep_scale:disabled") + else: + log0(f"recurrence:disabled num_layers:{args.num_layers}") + compile_mode = 
"disabled" if args.disable_compile else "fullgraph=True" + log0(f"compile_mode:{compile_mode}") + + # ----------------------------- + # DATA LOADER & MODEL WARMUP + # ----------------------------- + + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + + def zero_grad_all() -> None: + for opt in optimizers: + opt.zero_grad(set_to_none=True) + + max_wallclock_ms = 1000.0 * args.max_wallclock_seconds if args.max_wallclock_seconds > 0 else None + + def lr_mul(step: int, elapsed_ms: float) -> float: + if args.warmdown_iters <= 0: + return 1.0 + if max_wallclock_ms is None: + warmdown_start = max(args.iterations - args.warmdown_iters, 0) + return max((args.iterations - step) / max(args.warmdown_iters, 1), 0.0) if warmdown_start <= step < args.iterations else 1.0 + step_ms = elapsed_ms / max(step, 1) + warmdown_ms = args.warmdown_iters * step_ms + remaining_ms = max(max_wallclock_ms - elapsed_ms, 0.0) + return remaining_ms / max(warmdown_ms, 1e-9) if remaining_ms <= warmdown_ms else 1.0 + + # Warmup primes the compiled forward/backward/optimizer paths, then we restore the + # initial weights/optimizer state so measured training starts from the true init. 
+ if args.warmup_steps > 0: + initial_model_state = {name: tensor.detach().cpu().clone() for name, tensor in base_model.state_dict().items()} + initial_optimizer_states = [copy.deepcopy(opt.state_dict()) for opt in optimizers] + model.train() + for warmup_step in range(args.warmup_steps): + zero_grad_all() + for micro_step in range(grad_accum_steps): + if distributed: + model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 + x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + warmup_loss = model(x, y) + (warmup_loss * grad_scale).backward() + for opt in optimizers: + opt.step() + zero_grad_all() + if args.warmup_steps <= 20 or (warmup_step + 1) % 10 == 0 or warmup_step + 1 == args.warmup_steps: + log0(f"warmup_step:{warmup_step + 1}/{args.warmup_steps}") + base_model.load_state_dict(initial_model_state, strict=True) + for opt, state in zip(optimizers, initial_optimizer_states, strict=True): + opt.load_state_dict(state) + zero_grad_all() + if distributed: + model.require_backward_grad_sync = True + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + + # ----------------------------- + # MAIN TRAINING LOOP + # ----------------------------- + + training_time_ms = 0.0 + stop_after_step: int | None = None + torch.cuda.synchronize() + t0 = time.perf_counter() + + step = 0 + while True: + last_step = step == args.iterations or (stop_after_step is not None and step >= stop_after_step) + + should_validate = last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0) + if should_validate: + torch.cuda.synchronize() + training_time_ms += 1000.0 * (time.perf_counter() - t0) + val_loss, val_bpb = eval_val( + args, + model, + rank, + world_size, + device, + grad_accum_steps, + val_tokens, + base_bytes_lut, + has_leading_space_lut, + is_boundary_token_lut, + ) + log0( + 
f"step:{step}/{args.iterations} val_loss:{val_loss:.4f} val_bpb:{val_bpb:.4f} " + f"train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms / max(step, 1):.2f}ms" + ) + torch.cuda.synchronize() + t0 = time.perf_counter() + + if last_step: + if stop_after_step is not None and step < args.iterations: + log0( + f"stopping_early: wallclock_cap train_time:{training_time_ms:.0f}ms " + f"step:{step}/{args.iterations}" + ) + break + + elapsed_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + scale = lr_mul(step, elapsed_ms) + zero_grad_all() + train_loss = torch.zeros((), device=device) + for micro_step in range(grad_accum_steps): + if distributed: + model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 + x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + loss = model(x, y) + train_loss += loss.detach() + (loss * grad_scale).backward() + train_loss /= grad_accum_steps + + frac = min(step / args.muon_momentum_warmup_steps, 1.0) if args.muon_momentum_warmup_steps > 0 else 1.0 + muon_momentum = (1 - frac) * args.muon_momentum_warmup_start + frac * args.muon_momentum + for group in optimizer_muon.param_groups: + group["momentum"] = muon_momentum + + for opt in optimizers: + for group in opt.param_groups: + group["lr"] = group["base_lr"] * scale + + if args.grad_clip_norm > 0: + torch.nn.utils.clip_grad_norm_(base_model.parameters(), args.grad_clip_norm) + for opt in optimizers: + opt.step() + zero_grad_all() + + step += 1 + approx_training_time_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + should_log_train = ( + args.train_log_every > 0 + and (step <= 10 or step % args.train_log_every == 0 or stop_after_step is not None) + ) + if should_log_train: + log0( + f"step:{step}/{args.iterations} train_loss:{train_loss.item():.4f} " + f"train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms 
/ step:.2f}ms" + ) + if base_model.use_recurrence and step % args.train_log_every == 0: + diag = recurrence_param_diagnostics(base_model) + if diag: + log0(f"step:{step} {diag}") + + # Needed to sync whether we've reached the wallclock cap. + reached_cap = max_wallclock_ms is not None and approx_training_time_ms >= max_wallclock_ms + if distributed and max_wallclock_ms is not None: + reached_cap_tensor = torch.tensor(int(reached_cap), device=device) + dist.all_reduce(reached_cap_tensor, op=dist.ReduceOp.MAX) + reached_cap = bool(reached_cap_tensor.item()) + if stop_after_step is None and reached_cap: + stop_after_step = step + + log0( + f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB " + f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB" + ) + + # ----------------------------- + # SERIALIZATION + ROUNDTRIP VALIDATION + # ----------------------------- + # Save the raw state (useful for debugging/loading in PyTorch directly), then always produce + # the compressed int8+zlib artifact and validate the round-tripped weights. 
+ + if master_process: + torch.save(base_model.state_dict(), "final_model.pt") + model_bytes = os.path.getsize("final_model.pt") + code_bytes = len(code.encode("utf-8")) + log0(f"Serialized model: {model_bytes} bytes") + log0(f"Code size: {code_bytes} bytes") + log0(f"Total submission size: {model_bytes + code_bytes} bytes") + + quant_obj, quant_stats = quantize_state_dict_int8(base_model.state_dict()) + quant_buf = io.BytesIO() + torch.save(quant_obj, quant_buf) + quant_raw = quant_buf.getvalue() + quant_blob = zlib.compress(quant_raw, level=9) + quant_raw_bytes = len(quant_raw) + if master_process: + with open("final_model.int8.ptz", "wb") as f: + f.write(quant_blob) + quant_file_bytes = os.path.getsize("final_model.int8.ptz") + code_bytes = len(code.encode("utf-8")) + ratio = quant_stats["baseline_tensor_bytes"] / max(quant_stats["int8_payload_bytes"], 1) + log0( + f"Serialized model int8+zlib: {quant_file_bytes} bytes " + f"(payload:{quant_stats['int8_payload_bytes']} raw_torch:{quant_raw_bytes} payload_ratio:{ratio:.2f}x)" + ) + log0(f"Total submission size int8+zlib: {quant_file_bytes + code_bytes} bytes") + + if distributed: + dist.barrier() + with open("final_model.int8.ptz", "rb") as f: + quant_blob_disk = f.read() + quant_state = torch.load(io.BytesIO(zlib.decompress(quant_blob_disk)), map_location="cpu") + base_model.load_state_dict(dequantize_state_dict_int8(quant_state), strict=True) + torch.cuda.synchronize() + t_qeval = time.perf_counter() + q_val_loss, q_val_bpb = eval_val( + args, + model, + rank, + world_size, + device, + grad_accum_steps, + val_tokens, + base_bytes_lut, + has_leading_space_lut, + is_boundary_token_lut, + ) + torch.cuda.synchronize() + log0( + f"final_int8_zlib_roundtrip val_loss:{q_val_loss:.4f} val_bpb:{q_val_bpb:.4f} " + f"eval_time:{1000.0 * (time.perf_counter() - t_qeval):.0f}ms" + ) + log0(f"final_int8_zlib_roundtrip_exact val_loss:{q_val_loss:.8f} val_bpb:{q_val_bpb:.8f}") + + if distributed: + dist.destroy_process_group() 
+ + +if __name__ == "__main__": + main() + +==================================================================================================== +Running Python 3.12.3 (main, Nov 6 2025, 13:44:16) [GCC 13.3.0] +Running PyTorch 2.9.1+cu128 +Mon Mar 30 15:42:21 2026 ++-----------------------------------------------------------------------------------------+ +| NVIDIA-SMI 580.126.09 Driver Version: 580.126.09 CUDA Version: 13.0 | ++-----------------------------------------+------------------------+----------------------+ +| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC | +| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. | +| | | MIG M. | +|=========================================+========================+======================| +| 0 NVIDIA H100 80GB HBM3 On | 00000000:19:00.0 Off | 0 | +| N/A 44C P0 126W / 700W | 1521MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 1 NVIDIA H100 80GB HBM3 On | 00000000:3B:00.0 Off | 0 | +| N/A 36C P0 116W / 700W | 1521MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 2 NVIDIA H100 80GB HBM3 On | 00000000:4C:00.0 Off | 0 | +| N/A 35C P0 120W / 700W | 1521MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 3 NVIDIA H100 80GB HBM3 On | 00000000:5D:00.0 Off | 0 | +| N/A 43C P0 123W / 700W | 1521MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 4 NVIDIA H100 80GB HBM3 On | 00000000:9B:00.0 Off | 0 | +| N/A 44C P0 125W / 700W | 1521MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 5 NVIDIA H100 80GB HBM3 On | 00000000:BB:00.0 Off | 0 | +| N/A 35C 
P0 119W / 700W | 1521MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 6 NVIDIA H100 80GB HBM3 On | 00000000:CB:00.0 Off | 0 | +| N/A 42C P0 121W / 700W | 1521MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 7 NVIDIA H100 80GB HBM3 On | 00000000:DB:00.0 Off | 0 | +| N/A 34C P0 119W / 700W | 1521MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ + ++-----------------------------------------------------------------------------------------+ +| Processes: | +| GPU GI CI PID Type Process name GPU Memory | +| ID ID Usage | +|=========================================================================================| +| No running processes found | ++-----------------------------------------------------------------------------------------+ + +==================================================================================================== +val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path=./data/tokenizers/fineweb_1024_bpe.model +train_loader:dataset:fineweb10B_sp1024 train_shards:80 +val_loader:shards pattern=./data/datasets/fineweb10B_sp1024/fineweb_val_*.bin tokens:62021632 +model_params:19950640 +world_size:8 grad_accum_steps:1 +sdp_backends:cudnn=False flash=True mem_efficient=False math=False +flash_attention_3:True +attention_mode:gqa num_heads:8 num_kv_heads:4 +tie_embeddings:True embed_lr:0.05 head_lr:0.0 matrix_lr:0.04 scalar_lr:0.04 +train_batch_tokens:524288 train_seq_len:1024 iterations:20000 warmup_steps:20 max_wallclock_seconds:600.000 +seed:1337 +recurrence:enabled prelude:1 shared_attn:4 loops:3 coda:1 effective_layers:14 +peri_norm:True birkhoff_mix:True +timestep_scale:enabled params:14336 +compile_mode:fullgraph=True +warmup_step:1/20 diff --git 
a/records/track_non_record_16mb/2026-03-26_RecurrenceFix_3Loop_Birkhoff_OutputLN_TimestepScale/logs/s3_N.txt b/records/track_non_record_16mb/2026-03-26_RecurrenceFix_3Loop_Birkhoff_OutputLN_TimestepScale/logs/s3_N.txt new file mode 100644 index 0000000000..d2222dec84 --- /dev/null +++ b/records/track_non_record_16mb/2026-03-26_RecurrenceFix_3Loop_Birkhoff_OutputLN_TimestepScale/logs/s3_N.txt @@ -0,0 +1,1828 @@ +""" +The `train_gpt.py` and `train_gpt_mlx.py` scripts are intended as good launching-off points for new participants, not SOTA configs. We'll accept PRs that tune, improve, or simplify these scripts without significantly increasing complexity, but competitive submissions should stay in the `/records` folder. + +Hard stop: To keep readable for newcomers, let's make sure `train_gpt.py` and `train_gpt_mlx.py` never are longer than 1500 lines. +""" + +from __future__ import annotations + +import copy +import glob +import io +import math +import os +import random +import subprocess +import sys +import time +import uuid +import zlib +from pathlib import Path + +import numpy as np +import sentencepiece as spm +import torch +import torch.distributed as dist +import torch.nn.functional as F +from torch import Tensor, nn +from torch.nn.parallel import DistributedDataParallel as DDP + +# Flash Attention 3 via HuggingFace kernels (Hopper-optimized) +try: + from kernels import get_kernel + cap = torch.cuda.get_device_capability() + repo = "varunneal/flash-attention-3" if cap == (9, 0) else "kernels-community/flash-attn3" + fa3 = get_kernel(repo).flash_attn_interface + FA3_AVAILABLE = True +except Exception: + FA3_AVAILABLE = False + +# ----------------------------- +# HYPERPARAMETERS +# ----------------------------- +# Default Simple Baseline run: +# - 9 transformer blocks at width 512 +# - 8 attention heads with 4 KV heads (GQA) and 2x MLP expansion +# - vocab size 1024, sequence length 1024, tied embeddings +# - 524,288 train tokens per step for 20,000 iterations with 
a ~10 minute cap + +class Hyperparameters: + # Data paths are shard globs produced by the existing preprocessing pipeline. + data_path = os.environ.get("DATA_PATH", "./data/datasets/fineweb10B_sp1024") + train_files = os.path.join(data_path, "fineweb_train_*.bin") + val_files = os.path.join(data_path, "fineweb_val_*.bin") + tokenizer_path = os.environ.get("TOKENIZER_PATH", "./data/tokenizers/fineweb_1024_bpe.model") + run_id = os.environ.get("RUN_ID", str(uuid.uuid4())) + seed = int(os.environ.get("SEED", 1337)) + + # Validation cadence and batch size. Validation always uses the full fineweb_val split. + val_batch_size = int(os.environ.get("VAL_BATCH_SIZE", 524_288)) + val_loss_every = int(os.environ.get("VAL_LOSS_EVERY", 1000)) + train_log_every = int(os.environ.get("TRAIN_LOG_EVERY", 200)) + + # Training length. + iterations = int(os.environ.get("ITERATIONS", 20000)) + warmdown_iters = int(os.environ.get("WARMDOWN_ITERS", 1200)) + warmup_steps = int(os.environ.get("WARMUP_STEPS", 20)) + train_batch_tokens = int(os.environ.get("TRAIN_BATCH_TOKENS", 524_288)) + train_seq_len = int(os.environ.get("TRAIN_SEQ_LEN", 1024)) + max_wallclock_seconds = float(os.environ.get("MAX_WALLCLOCK_SECONDS", 600.0)) + qk_gain_init = float(os.environ.get("QK_GAIN_INIT", 1.5)) + + # Model shape. + vocab_size = int(os.environ.get("VOCAB_SIZE", 1024)) + num_layers = int(os.environ.get("NUM_LAYERS", 9)) + num_kv_heads = int(os.environ.get("NUM_KV_HEADS", 4)) + model_dim = int(os.environ.get("MODEL_DIM", 512)) + num_heads = int(os.environ.get("NUM_HEADS", 8)) + mlp_mult = int(os.environ.get("MLP_MULT", 2)) + tie_embeddings = bool(int(os.environ.get("TIE_EMBEDDINGS", "1"))) + rope_base = float(os.environ.get("ROPE_BASE", 10000.0)) + logit_softcap = float(os.environ.get("LOGIT_SOFTCAP", 30.0)) + + # Recurrence + stability features. 
+ num_prelude = int(os.environ.get("NUM_PRELUDE", 0)) + num_coda = int(os.environ.get("NUM_CODA", 0)) + num_shared = int(os.environ.get("NUM_SHARED", 0)) + num_loops = int(os.environ.get("NUM_LOOPS", 2)) + use_peri_norm = bool(int(os.environ.get("USE_PERI_NORM", "0"))) + use_birkhoff_mix = bool(int(os.environ.get("USE_BIRKHOFF_MIX", "0"))) + use_timestep_scale = bool(int(os.environ.get("USE_TIMESTEP_SCALE", "0"))) + leaky_relu_slope = float(os.environ.get("LEAKY_RELU_SLOPE", "0.5")) + timestep_gamma_max = float(os.environ.get("TIMESTEP_GAMMA_MAX", "0.0")) + use_timestep_bias = bool(int(os.environ.get("USE_TIMESTEP_BIAS", "0"))) + share_attn_only = bool(int(os.environ.get("SHARE_ATTN_ONLY", "0"))) + disable_compile = bool(int(os.environ.get("DISABLE_COMPILE", "0"))) + + # Optimizer hyperparameters. + embed_lr = float(os.environ.get("EMBED_LR", 0.6)) + head_lr = float(os.environ.get("HEAD_LR", 0.008)) + tied_embed_lr = float(os.environ.get("TIED_EMBED_LR", 0.05)) + tied_embed_init_std = float(os.environ.get("TIED_EMBED_INIT_STD", 0.005)) + matrix_lr = float(os.environ.get("MATRIX_LR", 0.04)) + scalar_lr = float(os.environ.get("SCALAR_LR", 0.04)) + muon_momentum = float(os.environ.get("MUON_MOMENTUM", 0.95)) + muon_backend_steps = int(os.environ.get("MUON_BACKEND_STEPS", 5)) + muon_momentum_warmup_start = float(os.environ.get("MUON_MOMENTUM_WARMUP_START", 0.85)) + muon_momentum_warmup_steps = int(os.environ.get("MUON_MOMENTUM_WARMUP_STEPS", 500)) + beta1 = float(os.environ.get("BETA1", 0.9)) + beta2 = float(os.environ.get("BETA2", 0.95)) + adam_eps = float(os.environ.get("ADAM_EPS", 1e-8)) + grad_clip_norm = float(os.environ.get("GRAD_CLIP_NORM", 0.0)) + +# ----------------------------- +# MUON OPTIMIZER +# ----------------------------- +# +# As borrowed from modded-nanogpt +# Background on Muon: https://kellerjordan.github.io/posts/muon/ + +def zeropower_via_newtonschulz5(G: Tensor, steps: int = 10, eps: float = 1e-7) -> Tensor: + # Orthogonalize a 2D update matrix 
with a fast Newton-Schulz iteration. + # Muon uses this to normalize matrix-shaped gradients before applying them. + a, b, c = (3.4445, -4.7750, 2.0315) + X = G.bfloat16() + X /= X.norm() + eps + transposed = G.size(0) > G.size(1) + if transposed: + X = X.T + for _ in range(steps): + A = X @ X.T + B = b * A + c * A @ A + X = a * X + B @ X + return X.T if transposed else X + + +class Muon(torch.optim.Optimizer): + def __init__(self, params, lr: float, momentum: float, backend_steps: int, nesterov: bool = True): + super().__init__( + params, + dict(lr=lr, momentum=momentum, backend_steps=backend_steps, nesterov=nesterov), + ) + + @torch.no_grad() + def step(self, closure=None): + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + distributed = dist.is_available() and dist.is_initialized() + world_size = dist.get_world_size() if distributed else 1 + rank = dist.get_rank() if distributed else 0 + + for group in self.param_groups: + params = group["params"] + if not params: + continue + lr = group["lr"] + momentum = group["momentum"] + backend_steps = group["backend_steps"] + nesterov = group["nesterov"] + + total_params = sum(int(p.numel()) for p in params) + updates_flat = torch.zeros(total_params, device=params[0].device, dtype=torch.bfloat16) + + curr = 0 + for i, p in enumerate(params): + if i % world_size == rank and p.grad is not None: + g = p.grad + state = self.state[p] + if "momentum_buffer" not in state: + state["momentum_buffer"] = torch.zeros_like(g) + buf = state["momentum_buffer"] + buf.mul_(momentum).add_(g) + if nesterov: + g = g.add(buf, alpha=momentum) + g = zeropower_via_newtonschulz5(g, steps=backend_steps) + # Scale correction from Muon reference implementations. 
+ g *= max(1, g.size(0) / g.size(1)) ** 0.5 + updates_flat[curr : curr + p.numel()] = g.reshape(-1) + curr += p.numel() + + if distributed: + dist.all_reduce(updates_flat, op=dist.ReduceOp.SUM) + + curr = 0 + for p in params: + g = updates_flat[curr : curr + p.numel()].view_as(p).to(dtype=p.dtype) + p.add_(g, alpha=-lr) + curr += p.numel() + + return loss + + +# ----------------------------- +# TOKENIZER-AGNOSTIC EVALUATION SETUP +# ----------------------------- +# +# It's common for small models have a large fraction of their parameters be embeddings, since the 2 * d_model * d_vocab vectors can be gigantic. +# Instead of locking the tokenizer, we let you bring your own and calculate our validation metrics on the average compression of the validation set. +# We calculate BPB (bits-per-byte) instead of validation loss, so we need methods to count the number of bits per token in the tokenizer. +# Note: Submissions that edit the tokenizer will be examined more carefully, since screwing this up might unjustly improve your score. 

def build_sentencepiece_luts(
    sp: spm.SentencePieceProcessor, vocab_size: int, device: torch.device
) -> tuple[Tensor, Tensor, Tensor]:
    # Build per-token lookup tables used to convert token counts into byte counts:
    #  - base_bytes:       UTF-8 byte length of each piece (leading "▁" stripped)
    #  - has_leading_space: piece carried the SentencePiece word-start marker "▁"
    #  - is_boundary_token: control/unknown/unused ids, which contribute no bytes
    sp_vocab_size = int(sp.vocab_size())
    table_size = max(sp_vocab_size, vocab_size)
    base_bytes_np = np.zeros((table_size,), dtype=np.int16)
    has_leading_space_np = np.zeros((table_size,), dtype=np.bool_)
    is_boundary_token_np = np.ones((table_size,), dtype=np.bool_)
    for token_id in range(sp_vocab_size):
        if sp.is_control(token_id) or sp.is_unknown(token_id) or sp.is_unused(token_id):
            continue
        is_boundary_token_np[token_id] = False
        if sp.is_byte(token_id):
            base_bytes_np[token_id] = 1
            continue
        piece = sp.id_to_piece(token_id)
        if piece.startswith("▁"):
            has_leading_space_np[token_id] = True
            piece = piece[1:]
        base_bytes_np[token_id] = len(piece.encode("utf-8"))
    return (
        torch.tensor(base_bytes_np, dtype=torch.int16, device=device),
        torch.tensor(has_leading_space_np, dtype=torch.bool, device=device),
        torch.tensor(is_boundary_token_np, dtype=torch.bool, device=device),
    )


def load_validation_tokens(pattern: str, seq_len: int) -> Tensor:
    # Concatenate all validation shards and trim to a whole number of seq_len
    # windows, keeping one extra token so (x, y) pairs can be built by shifting.
    files = [Path(p) for p in sorted(glob.glob(pattern))]
    if not files:
        raise FileNotFoundError(f"No files found for pattern: {pattern}")
    # The export pipeline writes the fixed first-50k-doc validation set to fineweb_val_*.
    tokens = torch.cat([load_data_shard(file) for file in files]).contiguous()
    usable = ((tokens.numel() - 1) // seq_len) * seq_len
    if usable <= 0:
        raise ValueError(f"Validation split is too short for TRAIN_SEQ_LEN={seq_len}")
    return tokens[: usable + 1]


def eval_val(
    args: Hyperparameters,
    model: nn.Module,
    rank: int,
    world_size: int,
    device: torch.device,
    grad_accum_steps: int,
    val_tokens: Tensor,
    base_bytes_lut: Tensor,
    has_leading_space_lut: Tensor,
    is_boundary_token_lut: Tensor,
) -> tuple[float, float]:
    # Validation computes two metrics:
    # - val_loss: token cross-entropy (natural log)
    # - val_bpb: tokenizer-agnostic compression metric used by the challenge
    local_batch_tokens = args.val_batch_size // (world_size * grad_accum_steps)
    if local_batch_tokens < args.train_seq_len:
        raise ValueError(
            "VAL_BATCH_SIZE must provide at least one sequence per rank; "
            f"got VAL_BATCH_SIZE={args.val_batch_size}, WORLD_SIZE={world_size}, "
            f"GRAD_ACCUM_STEPS={grad_accum_steps}, TRAIN_SEQ_LEN={args.train_seq_len}"
        )
    # Split the validation sequences into one contiguous, disjoint span per rank.
    local_batch_seqs = local_batch_tokens // args.train_seq_len
    total_seqs = (val_tokens.numel() - 1) // args.train_seq_len
    seq_start = (total_seqs * rank) // world_size
    seq_end = (total_seqs * (rank + 1)) // world_size
    val_loss_sum = torch.zeros((), device=device, dtype=torch.float64)
    val_token_count = torch.zeros((), device=device, dtype=torch.float64)
    val_byte_count = torch.zeros((), device=device, dtype=torch.float64)

    model.eval()
    with torch.inference_mode():
        for batch_seq_start in range(seq_start, seq_end, local_batch_seqs):
            batch_seq_end = min(batch_seq_start + local_batch_seqs, seq_end)
            raw_start = batch_seq_start * args.train_seq_len
            raw_end = batch_seq_end * args.train_seq_len + 1
            local = val_tokens[raw_start:raw_end].to(device=device, dtype=torch.int64, non_blocking=True)
            x = local[:-1].reshape(-1, args.train_seq_len)
            y = local[1:].reshape(-1, args.train_seq_len)
            with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True):
                batch_loss = model(x, y).detach()
            # Weight the mean batch loss by token count so the final average is exact.
            batch_token_count = float(y.numel())
            val_loss_sum += batch_loss.to(torch.float64) * batch_token_count
            val_token_count += batch_token_count
            # Byte accounting: each target contributes its piece bytes, plus one
            # byte for the leading space unless the previous token was a boundary.
            prev_ids = x.reshape(-1)
            tgt_ids = y.reshape(-1)
            token_bytes = base_bytes_lut[tgt_ids].to(dtype=torch.int16)
            token_bytes += (has_leading_space_lut[tgt_ids] & ~is_boundary_token_lut[prev_ids]).to(dtype=torch.int16)
            val_byte_count += token_bytes.to(torch.float64).sum()

    if dist.is_available() and dist.is_initialized():
        dist.all_reduce(val_loss_sum, op=dist.ReduceOp.SUM)
        dist.all_reduce(val_token_count, op=dist.ReduceOp.SUM)
        dist.all_reduce(val_byte_count, op=dist.ReduceOp.SUM)

    # bpb = bits/token * tokens/byte
    val_loss = val_loss_sum / val_token_count
    bits_per_token = val_loss.item() / math.log(2.0)
    tokens_per_byte = val_token_count.item() / val_byte_count.item()
    model.train()
    return float(val_loss.item()), float(bits_per_token * tokens_per_byte)

# -----------------------------
# POST-TRAINING QUANTIZATION
# -----------------------------
#
# It's silly to export our model, which is trained in bf16 and fp32, at that same precision.
# Instead, we get approximately the same model (with a small hit) by quantizing the model to int8 & zlib compressing.
# We can then decompress the model and run in higher precision for evaluation, after closing in under the size limit.

# Name substrings marking small "control" tensors (scales, mixes, gains) that
# must stay fp32 both during training and in the int8 export.
CONTROL_TENSOR_NAME_PATTERNS = tuple(
    pattern
    for pattern in os.environ.get(
        "CONTROL_TENSOR_NAME_PATTERNS",
        "attn_scale,attn_scales,mlp_scale,mlp_scales,resid_mix,resid_mixes,q_gain,skip_weight,skip_weights,ts_attn,ts_mlp,resid_mix_logit,ts_attn_beta,ts_mlp_beta",
    ).split(",")
    if pattern
)
INT8_KEEP_FLOAT_FP32_NAME_PATTERNS = tuple(
    pattern
    for pattern in os.environ.get(
        "INT8_KEEP_FLOAT_FP32_NAME_PATTERNS",
        ",".join(CONTROL_TENSOR_NAME_PATTERNS),
    ).split(",")
    if pattern
)
INT8_KEEP_FLOAT_MAX_NUMEL = 65_536
INT8_KEEP_FLOAT_STORE_DTYPE = torch.float16
INT8_PER_ROW_SCALE_DTYPE = torch.float16
INT8_CLIP_PERCENTILE = 99.99984
INT8_CLIP_Q = INT8_CLIP_PERCENTILE / 100.0

def tensor_nbytes(t: Tensor) -> int:
    # Raw storage size of a tensor in bytes.
    return int(t.numel()) * int(t.element_size())

def keep_float_tensor(name: str, t: Tensor, passthrough_orig_dtypes: dict[str, str]) -> Tensor:
    # Control tensors stay fp32; other fp32/bf16 tensors are downcast to fp16
    # for storage (original dtype recorded so dequantize can restore it).
    if any(pattern in name for pattern in INT8_KEEP_FLOAT_FP32_NAME_PATTERNS):
        return t.float().contiguous()
    if t.dtype in {torch.float32, torch.bfloat16}:
        passthrough_orig_dtypes[name] = str(t.dtype).removeprefix("torch.")
        return t.to(dtype=INT8_KEEP_FLOAT_STORE_DTYPE).contiguous()
    return t

def quantize_float_tensor(t: Tensor) -> tuple[Tensor, Tensor]:
    # Symmetric int8 quantization with percentile clipping. Returns (q, scale);
    # scale is per-row (1D tensor) for matrices, a scalar tensor otherwise.
    t32 = t.float()
    if t32.ndim == 2:
        # Matrices get one scale per row, which usually tracks output-channel
        # ranges much better than a single tensor-wide scale.
        clip_abs = (
            torch.quantile(t32.abs(), INT8_CLIP_Q, dim=1)
            if t32.numel()
            else torch.empty((t32.shape[0],), dtype=torch.float32)
        )
        clipped = torch.maximum(torch.minimum(t32, clip_abs[:, None]), -clip_abs[:, None])
        scale = (clip_abs / 127.0).clamp_min(1.0 / 127.0)
        q = torch.clamp(torch.round(clipped / scale[:, None]), -127, 127).to(torch.int8).contiguous()
        return q, scale.to(dtype=INT8_PER_ROW_SCALE_DTYPE).contiguous()

    # Vectors / scalars use a simpler per-tensor scale.
    clip_abs = float(torch.quantile(t32.abs().flatten(), INT8_CLIP_Q).item()) if t32.numel() else 0.0
    scale = torch.tensor(clip_abs / 127.0 if clip_abs > 0 else 1.0, dtype=torch.float32)
    q = torch.clamp(torch.round(torch.clamp(t32, -clip_abs, clip_abs) / scale), -127, 127).to(torch.int8).contiguous()
    return q, scale

def quantize_state_dict_int8(state_dict: dict[str, Tensor]):
    # Single supported clean-script export format:
    # - per-row int8 for 2D float tensors
    # - per-tensor int8 for other float tensors
    # - exact passthrough for non-floats
    # - passthrough for small float tensors, stored as fp16 to save bytes
    # Returns (obj, stats): the serializable export dict and size bookkeeping.
    quantized: dict[str, Tensor] = {}
    scales: dict[str, Tensor] = {}
    dtypes: dict[str, str] = {}
    passthrough: dict[str, Tensor] = {}
    passthrough_orig_dtypes: dict[str, str] = {}
    qmeta: dict[str, dict[str, object]] = {}
    stats = dict.fromkeys(
        ("param_count", "num_tensors", "num_float_tensors", "num_nonfloat_tensors", "baseline_tensor_bytes", "int8_payload_bytes"),
        0,
    )

    for name, tensor in state_dict.items():
        t = tensor.detach().to("cpu").contiguous()
        stats["param_count"] += int(t.numel())
        stats["num_tensors"] += 1
        stats["baseline_tensor_bytes"] += tensor_nbytes(t)

        if not t.is_floating_point():
            stats["num_nonfloat_tensors"] += 1
            passthrough[name] = t
            stats["int8_payload_bytes"] += tensor_nbytes(t)
            continue

        # Small float tensors are cheap enough to keep directly. We still downcast
        # fp32/bf16 passthrough tensors to fp16 so metadata does not dominate size.
        if t.numel() <= INT8_KEEP_FLOAT_MAX_NUMEL:
            kept = keep_float_tensor(name, t, passthrough_orig_dtypes)
            passthrough[name] = kept
            stats["int8_payload_bytes"] += tensor_nbytes(kept)
            continue

        stats["num_float_tensors"] += 1
        q, s = quantize_float_tensor(t)
        if s.ndim > 0:
            qmeta[name] = {"scheme": "per_row", "axis": 0}
        quantized[name] = q
        scales[name] = s
        dtypes[name] = str(t.dtype).removeprefix("torch.")
        stats["int8_payload_bytes"] += tensor_nbytes(q) + tensor_nbytes(s)

    obj: dict[str, object] = {
        "__quant_format__": "int8_clean_per_row_v1",
        "quantized": quantized,
        "scales": scales,
        "dtypes": dtypes,
        "passthrough": passthrough,
    }
    if qmeta:
        obj["qmeta"] = qmeta
    if passthrough_orig_dtypes:
        obj["passthrough_orig_dtypes"] = passthrough_orig_dtypes
    return obj, stats

def dequantize_state_dict_int8(obj: dict[str, object]) -> dict[str, Tensor]:
    # Inverse of quantize_state_dict_int8: rebuild a float state_dict from the
    # int8 payload plus scales, restoring each tensor's recorded dtype.
    out: dict[str, Tensor] = {}
    qmeta = obj.get("qmeta", {})
    passthrough_orig_dtypes = obj.get("passthrough_orig_dtypes", {})
    for name, q in obj["quantized"].items():
        dtype = getattr(torch, obj["dtypes"][name])
        s = obj["scales"][name]
        if qmeta.get(name, {}).get("scheme") == "per_row" or s.ndim > 0:
            s = s.to(dtype=torch.float32)
            # Broadcast the saved row scale back across trailing dimensions.
            out[name] = (q.float() * s.view(q.shape[0], *([1] * (q.ndim - 1)))).to(dtype=dtype).contiguous()
        else:
            scale = float(s.item())
            out[name] = (q.float() * scale).to(dtype=dtype).contiguous()
    for name, t in obj["passthrough"].items():
        # Restore small tensors, undoing the temporary fp16 storage cast if needed.
+ out_t = t.detach().to("cpu").contiguous() + orig_dtype = passthrough_orig_dtypes.get(name) + if isinstance(orig_dtype, str): + out_t = out_t.to(dtype=getattr(torch, orig_dtype)).contiguous() + out[name] = out_t + return out + + +# ----------------------------- +# DATA LOADING +# ----------------------------- + +def load_data_shard(file: Path) -> Tensor: + header_bytes = 256 * np.dtype(" None: + self.file_idx = (self.file_idx + 1) % len(self.files) + self.tokens = load_data_shard(self.files[self.file_idx]) + self.pos = 0 + + def take(self, n: int) -> Tensor: + chunks: list[Tensor] = [] + remaining = n + while remaining > 0: + avail = self.tokens.numel() - self.pos + if avail <= 0: + self._advance_file() + continue + k = min(remaining, avail) + chunks.append(self.tokens[self.pos : self.pos + k]) + self.pos += k + remaining -= k + return chunks[0] if len(chunks) == 1 else torch.cat(chunks) + + +class DistributedTokenLoader: + # Each call consumes a contiguous chunk from the shared token stream, then slices out + # one disjoint span per rank. The extra "+1" token lets us build (x, y) by shifting. 
    def __init__(self, pattern: str, rank: int, world_size: int, device: torch.device):
        self.rank = rank
        self.world_size = world_size
        self.device = device
        self.stream = TokenStream(pattern)

    def next_batch(self, global_tokens: int, seq_len: int, grad_accum_steps: int) -> tuple[Tensor, Tensor]:
        # Take one contiguous chunk covering all ranks, slice this rank's span,
        # and shift by one token to form (input, target) pairs.
        local_tokens = global_tokens // (self.world_size * grad_accum_steps)
        per_rank_span = local_tokens + 1
        chunk = self.stream.take(per_rank_span * self.world_size)
        start = self.rank * per_rank_span
        local = chunk[start : start + per_rank_span].to(dtype=torch.int64)
        x = local[:-1].reshape(-1, seq_len)
        y = local[1:].reshape(-1, seq_len)
        return x.to(self.device, non_blocking=True), y.to(self.device, non_blocking=True)

# -----------------------------
# TRANSFORMER MODULES
# -----------------------------

class RMSNorm(nn.Module):
    # Weightless RMSNorm; eps=None defers to F.rms_norm's default.
    def __init__(self, eps: float | None = None):
        super().__init__()
        self.eps = eps

    def forward(self, x: Tensor) -> Tensor:
        return F.rms_norm(x, (x.size(-1),), eps=self.eps)


class CastedLinear(nn.Linear):
    # Keep weights in fp32 for optimizer/state quality, cast at matmul time for bf16 compute.
    def forward(self, x: Tensor) -> Tensor:
        bias = self.bias.to(x.dtype) if self.bias is not None else None
        return F.linear(x, self.weight.to(x.dtype), bias)


def restore_low_dim_params_to_fp32(module: nn.Module) -> None:
    # Keep small/control parameters in fp32 even when the model body runs in bf16.
    with torch.no_grad():
        for name, param in module.named_parameters():
            if (param.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS)) and param.dtype != torch.float32:
                param.data = param.data.float()


class Rotary(nn.Module):
    # Caches cos/sin tables per sequence length on the current device.
    def __init__(self, dim: int, base: float = 10000.0):
        super().__init__()
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim))
        self.register_buffer("inv_freq", inv_freq, persistent=False)
        self._seq_len_cached = 0
        self._cos_cached: Tensor | None = None
        self._sin_cached: Tensor | None = None

    def forward(self, seq_len: int, device: torch.device, dtype: torch.dtype) -> tuple[Tensor, Tensor]:
        # Rebuild the cache on seq_len or device change; the cache stays in the
        # inv_freq dtype and only the returned views are cast to `dtype`.
        if (
            self._cos_cached is None
            or self._sin_cached is None
            or self._seq_len_cached != seq_len
            or self._cos_cached.device != device
        ):
            t = torch.arange(seq_len, device=device, dtype=self.inv_freq.dtype)
            freqs = torch.outer(t, self.inv_freq.to(device))
            self._cos_cached = freqs.cos()[None, None, :, :]
            self._sin_cached = freqs.sin()[None, None, :, :]
            self._seq_len_cached = seq_len
        return self._cos_cached.to(dtype=dtype), self._sin_cached.to(dtype=dtype)


def apply_rotary_emb(x: Tensor, cos: Tensor, sin: Tensor) -> Tensor:
    # Half-split RoPE: rotate the (first-half, second-half) channel pairs.
    half = x.size(-1) // 2
    x1, x2 = x[..., :half], x[..., half:]
    return torch.cat((x1 * cos + x2 * sin, x1 * (-sin) + x2 * cos), dim=-1)


class CausalSelfAttention(nn.Module):
    # GQA attention with QK RMS-norm, RoPE, and a learned per-head query gain.
    def __init__(
        self,
        dim: int,
        num_heads: int,
        num_kv_heads: int,
        rope_base: float,
        qk_gain_init: float,
    ):
        super().__init__()
        if dim % num_heads != 0:
            raise ValueError("model_dim must be divisible by num_heads")
        if num_heads % num_kv_heads != 0:
            raise ValueError("num_heads must be divisible by num_kv_heads")
        self.num_heads = num_heads
        self.num_kv_heads = num_kv_heads
        self.head_dim = dim // num_heads
        if self.head_dim % 2 != 0:
            raise ValueError("head_dim must be even for RoPE")
        kv_dim = self.num_kv_heads * self.head_dim
        self.c_q = CastedLinear(dim, dim, bias=False)
        self.c_k = CastedLinear(dim, kv_dim, bias=False)
        self.c_v = CastedLinear(dim, kv_dim, bias=False)
        self.proj = CastedLinear(dim, dim, bias=False)
        self.proj._zero_init = True
        self.q_gain = nn.Parameter(torch.full((num_heads,), qk_gain_init, dtype=torch.float32))
        self.rotary = Rotary(self.head_dim, base=rope_base)

    def forward(self, x: Tensor) -> Tensor:
        bsz, seqlen, dim = x.shape
        q = self.c_q(x).reshape(bsz, seqlen, self.num_heads, self.head_dim).transpose(1, 2)
        k = self.c_k(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2)
        v = self.c_v(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2)
        # QK-norm keeps logits bounded; the per-head q_gain is the learned scale.
        q = F.rms_norm(q, (q.size(-1),))
        k = F.rms_norm(k, (k.size(-1),))
        cos, sin = self.rotary(seqlen, x.device, q.dtype)
        q = apply_rotary_emb(q, cos, sin)
        k = apply_rotary_emb(k, cos, sin)
        q = q * self.q_gain.to(dtype=q.dtype)[None, :, None, None]
        if FA3_AVAILABLE:
            # FA3 expects (B, T, H, D) — transpose from (B, H, T, D)
            y = fa3.flash_attn_func(
                q.transpose(1, 2),
                k.transpose(1, 2),
                v.transpose(1, 2),
                causal=True,
            ).transpose(1, 2)
        else:
            y = F.scaled_dot_product_attention(
                q, k, v,
                attn_mask=None,
                is_causal=True,
                enable_gqa=(self.num_kv_heads != self.num_heads),
            )
        y = y.transpose(1, 2).contiguous().reshape(bsz, seqlen, dim)
        return self.proj(y)


class MLP(nn.Module):
    # leaky_relu^2 MLP (slope=0.5 matches SOTA, PRs #493/#518)
    def __init__(self, dim: int, mlp_mult: int, leaky_relu_slope: float = 0.5):
        super().__init__()
        hidden = mlp_mult * dim
        self.fc = CastedLinear(dim, hidden, bias=False)
        self.proj = CastedLinear(hidden, dim, bias=False)
        self.proj._zero_init = True
        self.leaky_relu_slope = leaky_relu_slope

    def forward(self, x: Tensor) -> Tensor:
        x = F.leaky_relu(self.fc(x), negative_slope=self.leaky_relu_slope)
        return self.proj(x.square())


class TimestepScaling(nn.Module):
    """Learned per-iteration scale vectors (arXiv:2410.01405, ICML 2025)."""
    def __init__(self, effective_layers: int, dim: int, gamma_max: float = 0.0, use_bias: bool = False):
        super().__init__()
        # One gamma (and optional beta) vector per virtual (unrolled) layer.
        self.attn_gamma = nn.Parameter(torch.ones(effective_layers, dim, dtype=torch.float32))
        self.mlp_gamma = nn.Parameter(torch.ones(effective_layers, dim, dtype=torch.float32))
        self.gamma_max = gamma_max  # 0 = uncapped
        if use_bias:
            self.attn_beta = nn.Parameter(torch.zeros(effective_layers, dim, dtype=torch.float32))
            self.mlp_beta = nn.Parameter(torch.zeros(effective_layers, dim, dtype=torch.float32))
        else:
            self.attn_beta = None
            self.mlp_beta = None

    def get(self, v: int) -> tuple[Tensor, Tensor, Tensor | None, Tensor | None]:
        # v indexes the virtual layer position in the unrolled computation.
        ag = self.attn_gamma[v]
        mg = self.mlp_gamma[v]
        if self.gamma_max > 0:
            ag = ag.clamp(-self.gamma_max, self.gamma_max)
            mg = mg.clamp(-self.gamma_max, self.gamma_max)
        ab = self.attn_beta[v] if self.attn_beta is not None else None
        mb = self.mlp_beta[v] if self.mlp_beta is not None else None
        return ag, mg, ab, mb


class SharedAttnLayer(nn.Module):
    """Shared attention layer (mixing + attention only, no MLP) for attn-only sharing mode."""
    def __init__(
        self,
        dim: int,
        num_heads: int,
        num_kv_heads: int,
        rope_base: float,
        qk_gain_init: float,
        use_birkhoff_mix: bool = False,
    ):
        super().__init__()
        self.use_birkhoff_mix = use_birkhoff_mix
        self.attn_norm = RMSNorm()
        self.attn = CausalSelfAttention(dim, num_heads, num_kv_heads, rope_base, qk_gain_init)
        self.attn_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32))
        if use_birkhoff_mix:
            self.resid_mix_logit = nn.Parameter(torch.zeros(dim, dtype=torch.float32))
        else:
            self.resid_mix = nn.Parameter(torch.stack((torch.ones(dim), torch.zeros(dim))).float())

    def forward(self, x: Tensor, x0: Tensor,
                ts_attn_gamma: Tensor | None = None,
                ts_attn_beta: Tensor | None = None) -> Tensor:
        if self.use_birkhoff_mix:
            # Birkhoff mix: sigmoid keeps the x/x0 weights a per-channel convex combination.
            alpha = torch.sigmoid(self.resid_mix_logit.to(dtype=x.dtype))[None, None, :]
            x = alpha * x + (1 - alpha) * x0
        else:
            mix = self.resid_mix.to(dtype=x.dtype)
            x = mix[0][None, None, :] * x + mix[1][None, None, :] * x0
        attn_out = self.attn(self.attn_norm(x))
        attn_s = self.attn_scale.to(dtype=x.dtype)[None, None, :]
        if ts_attn_gamma is not None:
            attn_s = attn_s * ts_attn_gamma[None, None, :]
        x = x + attn_s * attn_out
        if ts_attn_beta is not None:
            x = x + ts_attn_beta[None, None, :]
        return x


class UniqueMLP(nn.Module):
    """Unique MLP per virtual shared position for attn-only sharing mode."""
    def __init__(
        self,
        dim: int,
        mlp_mult: int,
        use_peri_norm: bool = False,
        leaky_relu_slope: float = 0.5,
    ):
        super().__init__()
        self.use_peri_norm = use_peri_norm
        if use_peri_norm:
            self.mlp_out_norm = RMSNorm()
        else:
            self.mlp_norm = RMSNorm()
        self.mlp = MLP(dim, mlp_mult, leaky_relu_slope)
        self.mlp_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32))

    def forward(self, x: Tensor,
                ts_mlp_gamma: Tensor | None = None,
                ts_mlp_beta: Tensor | None = None) -> Tensor:
        # Returns only the (scaled) MLP contribution; the caller adds the residual.
        mlp_out = self.mlp_out_norm(self.mlp(x)) if self.use_peri_norm else self.mlp(self.mlp_norm(x))
        mlp_s = self.mlp_scale.to(dtype=x.dtype)[None, None, :]
        if ts_mlp_gamma is not None:
            mlp_s = mlp_s * ts_mlp_gamma[None, None, :]
        out = mlp_s * mlp_out
        if ts_mlp_beta is not None:
            out = out + ts_mlp_beta[None, None, :]
        return out


class Block(nn.Module):
    # Full transformer block: x0-mixing, attention, then MLP, each with a
    # learned per-channel output scale (optionally modulated by timestep scaling).
    def __init__(
        self,
        dim: int,
        num_heads: int,
        num_kv_heads: int,
        mlp_mult: int,
        rope_base: float,
        qk_gain_init: float,
        use_peri_norm: bool = False,
        use_birkhoff_mix: bool = False,
        leaky_relu_slope: float = 0.5,
    ):
        super().__init__()
        self.use_peri_norm = use_peri_norm
        self.use_birkhoff_mix = use_birkhoff_mix
        self.attn_norm = RMSNorm()
        if use_peri_norm:
            self.mlp_out_norm = RMSNorm()  # output norm: bounds MLP contribution, preserves input magnitude
        else:
            self.mlp_norm = RMSNorm()
        self.attn = CausalSelfAttention(dim, num_heads, num_kv_heads, rope_base, qk_gain_init)
        self.mlp = MLP(dim, mlp_mult, leaky_relu_slope)
        self.attn_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32))
        self.mlp_scale = 
nn.Parameter(torch.ones(dim, dtype=torch.float32))
        if use_birkhoff_mix:
            self.resid_mix_logit = nn.Parameter(torch.zeros(dim, dtype=torch.float32))
        else:
            self.resid_mix = nn.Parameter(torch.stack((torch.ones(dim), torch.zeros(dim))).float())

    def forward(self, x: Tensor, x0: Tensor,
                ts_attn_gamma: Tensor | None = None,
                ts_mlp_gamma: Tensor | None = None,
                ts_attn_beta: Tensor | None = None,
                ts_mlp_beta: Tensor | None = None) -> Tensor:
        if self.use_birkhoff_mix:
            # Birkhoff mix: sigmoid keeps the x/x0 weights a per-channel convex combination.
            alpha = torch.sigmoid(self.resid_mix_logit.to(dtype=x.dtype))[None, None, :]
            x = alpha * x + (1 - alpha) * x0
        else:
            mix = self.resid_mix.to(dtype=x.dtype)
            x = mix[0][None, None, :] * x + mix[1][None, None, :] * x0
        attn_out = self.attn(self.attn_norm(x))
        attn_s = self.attn_scale.to(dtype=x.dtype)[None, None, :]
        if ts_attn_gamma is not None:
            attn_s = attn_s * ts_attn_gamma[None, None, :]
        x = x + attn_s * attn_out
        if ts_attn_beta is not None:
            x = x + ts_attn_beta[None, None, :]
        mlp_out = self.mlp_out_norm(self.mlp(x)) if self.use_peri_norm else self.mlp(self.mlp_norm(x))
        mlp_s = self.mlp_scale.to(dtype=x.dtype)[None, None, :]
        if ts_mlp_gamma is not None:
            mlp_s = mlp_s * ts_mlp_gamma[None, None, :]
        x = x + mlp_s * mlp_out
        if ts_mlp_beta is not None:
            x = x + ts_mlp_beta[None, None, :]
        return x


class GPT(nn.Module):
    # Two topologies:
    #  - recurrence path (num_shared > 0): prelude blocks, then shared blocks
    #    (or shared-attention + per-position unique MLPs) unrolled num_loops
    #    times, then coda blocks; optional TimestepScaling per virtual layer.
    #  - standard U-Net path: encoder/decoder blocks with learned skip weights.
    def __init__(
        self,
        vocab_size: int,
        num_layers: int,
        model_dim: int,
        num_heads: int,
        num_kv_heads: int,
        mlp_mult: int,
        tie_embeddings: bool,
        tied_embed_init_std: float,
        logit_softcap: float,
        rope_base: float,
        qk_gain_init: float,
        num_prelude: int = 0,
        num_coda: int = 0,
        num_shared: int = 0,
        num_loops: int = 2,
        use_peri_norm: bool = False,
        use_birkhoff_mix: bool = False,
        use_timestep_scale: bool = False,
        use_timestep_bias: bool = False,
        share_attn_only: bool = False,
        timestep_gamma_max: float = 0.0,
        leaky_relu_slope: float = 0.5,
    ):
        super().__init__()
        if logit_softcap <= 0.0:
            raise ValueError(f"logit_softcap must be positive, got {logit_softcap}")
        self.tie_embeddings = tie_embeddings
        self.tied_embed_init_std = tied_embed_init_std
        self.logit_softcap = logit_softcap
        self.tok_emb = nn.Embedding(vocab_size, model_dim)
        self.use_recurrence = num_shared > 0
        self.num_loops = num_loops if self.use_recurrence else 1
        self.num_prelude = num_prelude if self.use_recurrence else 0
        self.num_coda = num_coda if self.use_recurrence else 0

        block_kwargs = dict(
            dim=model_dim, num_heads=num_heads, num_kv_heads=num_kv_heads,
            mlp_mult=mlp_mult, rope_base=rope_base, qk_gain_init=qk_gain_init,
            use_peri_norm=use_peri_norm if self.use_recurrence else False,
            use_birkhoff_mix=use_birkhoff_mix if self.use_recurrence else False,
            leaky_relu_slope=leaky_relu_slope,
        )

        self.share_attn_only = share_attn_only if self.use_recurrence else False

        if self.use_recurrence:
            self.prelude_blocks = nn.ModuleList([Block(**block_kwargs) for _ in range(self.num_prelude)])
            self.coda_blocks = nn.ModuleList([Block(**block_kwargs) for _ in range(self.num_coda)])
            if self.share_attn_only:
                # Attention is shared across loops; every virtual position
                # (num_shared * num_loops) gets its own MLP.
                shared_attn_kwargs = dict(
                    dim=model_dim, num_heads=num_heads, num_kv_heads=num_kv_heads,
                    rope_base=rope_base, qk_gain_init=qk_gain_init,
                    use_birkhoff_mix=use_birkhoff_mix,
                )
                unique_mlp_kwargs = dict(
                    dim=model_dim, mlp_mult=mlp_mult,
                    use_peri_norm=use_peri_norm,
                    leaky_relu_slope=leaky_relu_slope,
                )
                self.shared_attn_layers = nn.ModuleList([SharedAttnLayer(**shared_attn_kwargs) for _ in range(num_shared)])
                self.unique_mlps = nn.ModuleList([UniqueMLP(**unique_mlp_kwargs) for _ in range(num_shared * self.num_loops)])
                self.shared_blocks = nn.ModuleList()  # empty — keeps diagnostics safe
            else:
                self.shared_blocks = nn.ModuleList([Block(**block_kwargs) for _ in range(num_shared)])
                self.shared_attn_layers = nn.ModuleList()
                self.unique_mlps = nn.ModuleList()
            # Virtual depth of the unrolled network; must match forward()'s v counter.
            effective_layers = self.num_prelude + num_shared * self.num_loops + self.num_coda
            self.timestep_scale = TimestepScaling(effective_layers, model_dim, gamma_max=timestep_gamma_max, use_bias=use_timestep_bias) if use_timestep_scale else None
        else:
            # Standard U-Net path
            self.num_encoder_layers = num_layers // 2
            self.num_decoder_layers = num_layers - self.num_encoder_layers
            self.num_skip_weights = min(self.num_encoder_layers, self.num_decoder_layers)
            self.skip_weights = nn.Parameter(torch.ones(self.num_skip_weights, model_dim, dtype=torch.float32))
            self.blocks = nn.ModuleList([Block(**block_kwargs) for _ in range(num_layers)])
            self.timestep_scale = None

        self.final_norm = RMSNorm()
        self.lm_head = None if tie_embeddings else CastedLinear(model_dim, vocab_size, bias=False)
        if self.lm_head is not None:
            self.lm_head._zero_init = True
        self._init_weights()

    def _init_weights(self) -> None:
        # Tied embedding gets a normal init; all _zero_init projections start at 0.
        if self.tie_embeddings:
            nn.init.normal_(self.tok_emb.weight, mean=0.0, std=self.tied_embed_init_std)
        for module in self.modules():
            if isinstance(module, nn.Linear) and getattr(module, "_zero_init", False):
                nn.init.zeros_(module.weight)

    def _get_ts(self, v: int) -> tuple[Tensor | None, Tensor | None, Tensor | None, Tensor | None]:
        # (attn_gamma, mlp_gamma, attn_beta, mlp_beta) for virtual layer v, or Nones.
        if self.timestep_scale is None:
            return None, None, None, None
        return self.timestep_scale.get(v)

    def forward(self, input_ids: Tensor, target_ids: Tensor) -> Tensor:
        # Returns the mean token cross-entropy over target_ids.
        x = self.tok_emb(input_ids)
        x = F.rms_norm(x, (x.size(-1),))
        x0 = x

        if self.use_recurrence:
            # v counts virtual layer positions for TimestepScaling lookups.
            v = 0
            for block in self.prelude_blocks:
                ag, mg, ab, mb = self._get_ts(v)
                x = block(x, x0, ag, mg, ab, mb)
                v += 1
            if self.share_attn_only:
                vid = 0
                for _loop in range(self.num_loops):
                    for attn_layer in self.shared_attn_layers:
                        ag, mg, ab, mb = self._get_ts(v)
                        x = attn_layer(x, x0, ag, ab)
                        x = x + self.unique_mlps[vid](x, mg, mb)
                        vid += 1
                        v += 1
            else:
                for _loop in range(self.num_loops):
                    for block in self.shared_blocks:
                        ag, mg, ab, mb = self._get_ts(v)
                        x = block(x, x0, ag, mg, ab, mb)
                        v += 1
            for block in self.coda_blocks:
                ag, mg, ab, mb = self._get_ts(v)
                x = block(x, x0, ag, mg, ab, mb)
                v += 1
        else:
            # U-Net: encoder activations are pushed, then popped (LIFO) into the
            # decoder with learned per-channel skip weights.
            skips: list[Tensor] = []
            for i in range(self.num_encoder_layers):
                x = self.blocks[i](x, x0)
                skips.append(x)
            for i in range(self.num_decoder_layers):
                if skips:
                    x = x + self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips.pop()
                x = self.blocks[self.num_encoder_layers + i](x, x0)

        x = self.final_norm(x).reshape(-1, x.size(-1))
        targets = target_ids.reshape(-1)
        if self.tie_embeddings:
            logits_proj = F.linear(x, self.tok_emb.weight)
        else:
            if self.lm_head is None:
                raise RuntimeError("lm_head is required when tie_embeddings=False")
            logits_proj = self.lm_head(x)
        # tanh soft-capping bounds logits to (-logit_softcap, logit_softcap).
        logits = self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap)
        return F.cross_entropy(logits.float(), targets, reduction="mean")


@torch.no_grad()
def recurrence_param_diagnostics(gpt: GPT) -> str:
    """Log parameter norms for recurrence health monitoring. No forward pass."""
    if not gpt.use_recurrence:
        return ""
    parts: list[str] = []

    # Birkhoff alpha stats per shared block/layer
    if gpt.share_attn_only:
        for i, layer in enumerate(gpt.shared_attn_layers):
            if hasattr(layer, "resid_mix_logit"):
                a = torch.sigmoid(layer.resid_mix_logit)
                parts.append(f"shared{i}_alpha:mean={a.mean().item():.3f},std={a.std().item():.3f}")
    else:
        for i, block in enumerate(gpt.shared_blocks):
            if hasattr(block, "resid_mix_logit"):
                a = torch.sigmoid(block.resid_mix_logit)
                parts.append(f"shared{i}_alpha:mean={a.mean().item():.3f},std={a.std().item():.3f}")

    # Effective MLP/attn contribution scale per virtual layer position:
    # mlp_scale * ts_mlp_gamma[v] is the learned multiplier on MLP output.
    # We report RMS (norm / sqrt(numel)) to make it scale-independent.
    # Walk virtual positions in the same order as GPT.forward (prelude,
    # shared x num_loops, coda) so v lines up with TimestepScaling rows.
    v = 0
    effective_count = gpt.num_prelude + len(gpt.shared_blocks if not gpt.share_attn_only else gpt.shared_attn_layers) * gpt.num_loops + gpt.num_coda
    mlp_norms: list[str] = []
    attn_norms: list[str] = []

    # Prelude blocks
    for block in gpt.prelude_blocks:
        ms = block.mlp_scale.norm().item()
        asc = block.attn_scale.norm().item()
        d = block.mlp_scale.numel() ** 0.5
        if gpt.timestep_scale is not None:
            ms_g = gpt.timestep_scale.mlp_gamma[v].norm().item()
            as_g = gpt.timestep_scale.attn_gamma[v].norm().item()
            mlp_norms.append(f"v{v}:{ms * ms_g / d:.4f}")
            attn_norms.append(f"v{v}:{asc * as_g / d:.4f}")
        else:
            mlp_norms.append(f"v{v}:{ms / d:.4f}")
            attn_norms.append(f"v{v}:{asc / d:.4f}")
        v += 1

    # Shared positions
    if gpt.share_attn_only:
        vid = 0
        for _loop in range(gpt.num_loops):
            for layer in gpt.shared_attn_layers:
                asc = layer.attn_scale.norm().item()
                ms = gpt.unique_mlps[vid].mlp_scale.norm().item()
                d = layer.attn_scale.numel() ** 0.5
                if gpt.timestep_scale is not None:
                    ms_g = gpt.timestep_scale.mlp_gamma[v].norm().item()
                    as_g = gpt.timestep_scale.attn_gamma[v].norm().item()
                    mlp_norms.append(f"v{v}:{ms * ms_g / d:.4f}")
                    attn_norms.append(f"v{v}:{asc * as_g / d:.4f}")
                else:
                    mlp_norms.append(f"v{v}:{ms / d:.4f}")
                    attn_norms.append(f"v{v}:{asc / d:.4f}")
                vid += 1
                v += 1
    else:
        for _loop in range(gpt.num_loops):
            for block in gpt.shared_blocks:
                ms = block.mlp_scale.norm().item()
                asc = block.attn_scale.norm().item()
                d = block.mlp_scale.numel() ** 0.5
                if gpt.timestep_scale is not None:
                    ms_g = gpt.timestep_scale.mlp_gamma[v].norm().item()
                    as_g = gpt.timestep_scale.attn_gamma[v].norm().item()
                    mlp_norms.append(f"v{v}:{ms * ms_g / d:.4f}")
                    attn_norms.append(f"v{v}:{asc * as_g / d:.4f}")
                else:
                    mlp_norms.append(f"v{v}:{ms / d:.4f}")
                    attn_norms.append(f"v{v}:{asc / d:.4f}")
                v += 1

    # Coda blocks
    for block in gpt.coda_blocks:
        ms = block.mlp_scale.norm().item()
        asc = block.attn_scale.norm().item()
        d = block.mlp_scale.numel() ** 0.5
        if gpt.timestep_scale is not None:
            ms_g = gpt.timestep_scale.mlp_gamma[v].norm().item()
            as_g = gpt.timestep_scale.attn_gamma[v].norm().item()
            mlp_norms.append(f"v{v}:{ms * ms_g / d:.4f}")
            attn_norms.append(f"v{v}:{asc * as_g / d:.4f}")
        else:
            mlp_norms.append(f"v{v}:{ms / d:.4f}")
            attn_norms.append(f"v{v}:{asc / d:.4f}")
        v += 1

    parts.append("eff_mlp_scale:[" + " ".join(mlp_norms) + "]")
    parts.append("eff_attn_scale:[" + " ".join(attn_norms) + "]")
    if gpt.timestep_scale is not None and gpt.timestep_scale.attn_beta is not None:
        attn_bias_norms: list[str] = []
        mlp_bias_norms: list[str] = []
        for vi in range(effective_count):
            ab_rms = gpt.timestep_scale.attn_beta[vi].norm().item() / gpt.timestep_scale.attn_beta[vi].numel() ** 0.5
            mb_rms = gpt.timestep_scale.mlp_beta[vi].norm().item() / gpt.timestep_scale.mlp_beta[vi].numel() ** 0.5
            attn_bias_norms.append(f"v{vi}:{ab_rms:.4f}")
            mlp_bias_norms.append(f"v{vi}:{mb_rms:.4f}")
        parts.append("eff_attn_bias:[" + " ".join(attn_bias_norms) + "]")
        parts.append("eff_mlp_bias:[" + " ".join(mlp_bias_norms) + "]")
    return " ".join(parts)


# -----------------------------
# TRAINING
# -----------------------------

def main() -> None:
    global zeropower_via_newtonschulz5

    # Full source is logged below so each run's record is self-describing.
    code = Path(__file__).read_text(encoding="utf-8")
    args = Hyperparameters()
    zeropower_via_newtonschulz5 = torch.compile(zeropower_via_newtonschulz5)

    # -----------------------------
    # DISTRIBUTED + CUDA SETUP
    # -----------------------------

    distributed = "RANK" in os.environ and "WORLD_SIZE" in os.environ
    rank = int(os.environ.get("RANK", "0"))
    world_size = int(os.environ.get("WORLD_SIZE", "1"))
    local_rank = int(os.environ.get("LOCAL_RANK", "0"))
    if world_size <= 0:
        raise ValueError(f"WORLD_SIZE must be positive, got {world_size}")
    if 8 % world_size != 0:
        raise ValueError(f"WORLD_SIZE={world_size} must divide 8 so 
grad_accum_steps stays integral") + grad_accum_steps = 8 // world_size + grad_scale = 1.0 / grad_accum_steps + if not torch.cuda.is_available(): + raise RuntimeError("CUDA is required") + device = torch.device("cuda", local_rank) + torch.cuda.set_device(device) + if distributed: + dist.init_process_group(backend="nccl", device_id=device) + dist.barrier() + master_process = rank == 0 + + # Fast math knobs + torch.backends.cuda.matmul.allow_tf32 = True + torch.backends.cudnn.allow_tf32 = True + from torch.backends.cuda import enable_cudnn_sdp, enable_flash_sdp, enable_math_sdp, enable_mem_efficient_sdp + + enable_cudnn_sdp(False) + enable_flash_sdp(True) + enable_mem_efficient_sdp(False) + enable_math_sdp(False) + + logfile = None + if master_process: + os.makedirs("logs", exist_ok=True) + logfile = f"logs/{args.run_id}.txt" + print(logfile) + + def log0(msg: str, console: bool = True) -> None: + if not master_process: + return + if console: + print(msg) + if logfile is not None: + with open(logfile, "a", encoding="utf-8") as f: + print(msg, file=f) + + log0(code, console=False) + log0("=" * 100, console=False) + log0(f"Running Python {sys.version}", console=False) + log0(f"Running PyTorch {torch.__version__}", console=False) + log0( + subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, check=False).stdout, + console=False, + ) + log0("=" * 100, console=False) + + # ----------------------------- + # TOKENIZER + VALIDATION METRIC SETUP + # ----------------------------- + + random.seed(args.seed) + np.random.seed(args.seed) + torch.manual_seed(args.seed) + torch.cuda.manual_seed_all(args.seed) + + if not args.tokenizer_path.endswith(".model"): + raise ValueError(f"Script only setup for SentencePiece .model file: {args.tokenizer_path}") + sp = spm.SentencePieceProcessor(model_file=args.tokenizer_path) + if int(sp.vocab_size()) != args.vocab_size: + raise ValueError( + f"VOCAB_SIZE={args.vocab_size} does not match tokenizer 
vocab_size={int(sp.vocab_size())}" + ) + dataset_dir = Path(args.data_path).resolve() + actual_train_files = len(list(dataset_dir.glob("fineweb_train_*.bin"))) + val_tokens = load_validation_tokens(args.val_files, args.train_seq_len) + base_bytes_lut, has_leading_space_lut, is_boundary_token_lut = build_sentencepiece_luts( + sp, args.vocab_size, device + ) + log0(f"val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path={args.tokenizer_path}") + log0(f"train_loader:dataset:{dataset_dir.name} train_shards:{actual_train_files}") + log0(f"val_loader:shards pattern={args.val_files} tokens:{val_tokens.numel() - 1}") + + # ----------------------------- + # MODEL + OPTIMIZER SETUP + # ----------------------------- + + base_model = GPT( + vocab_size=args.vocab_size, + num_layers=args.num_layers, + model_dim=args.model_dim, + num_heads=args.num_heads, + num_kv_heads=args.num_kv_heads, + mlp_mult=args.mlp_mult, + tie_embeddings=args.tie_embeddings, + tied_embed_init_std=args.tied_embed_init_std, + logit_softcap=args.logit_softcap, + rope_base=args.rope_base, + qk_gain_init=args.qk_gain_init, + num_prelude=args.num_prelude, + num_coda=args.num_coda, + num_shared=args.num_shared, + num_loops=args.num_loops, + use_peri_norm=args.use_peri_norm, + use_birkhoff_mix=args.use_birkhoff_mix, + use_timestep_scale=args.use_timestep_scale, + use_timestep_bias=args.use_timestep_bias, + share_attn_only=args.share_attn_only, + timestep_gamma_max=args.timestep_gamma_max, + leaky_relu_slope=args.leaky_relu_slope, + ).to(device).bfloat16() + for module in base_model.modules(): + if isinstance(module, CastedLinear): + module.float() + restore_low_dim_params_to_fp32(base_model) + if args.disable_compile: + compiled_model = base_model + else: + compiled_model = torch.compile(base_model, dynamic=False, fullgraph=True) + model: nn.Module = DDP(compiled_model, device_ids=[local_rank], broadcast_buffers=False) if distributed else compiled_model + + # Optimizer split: + # - token embedding (Adam) 
uses EMBED_LR + # - untied lm_head (Adam) uses HEAD_LR + # - matrix params in transformer blocks use MATRIX_LR via Muon + # - vectors/scalars use SCALAR_LR via Adam + if base_model.use_recurrence: + all_block_lists = [base_model.prelude_blocks, base_model.shared_blocks, base_model.coda_blocks] + block_named_params = [] + for bl in all_block_lists: + block_named_params.extend(bl.named_parameters()) + if base_model.share_attn_only: + block_named_params.extend(base_model.shared_attn_layers.named_parameters()) + block_named_params.extend(base_model.unique_mlps.named_parameters()) + else: + block_named_params = list(base_model.blocks.named_parameters()) + matrix_params = [ + p + for name, p in block_named_params + if p.ndim == 2 and not any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) + ] + scalar_params = [ + p + for name, p in block_named_params + if p.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) + ] + if hasattr(base_model, "skip_weights") and base_model.skip_weights.numel() > 0: + scalar_params.append(base_model.skip_weights) + if base_model.timestep_scale is not None: + for p in base_model.timestep_scale.parameters(): + scalar_params.append(p) + token_lr = args.tied_embed_lr if args.tie_embeddings else args.embed_lr + optimizer_tok = torch.optim.Adam( + [{"params": [base_model.tok_emb.weight], "lr": token_lr, "base_lr": token_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=True, + ) + optimizer_muon = Muon( + matrix_params, + lr=args.matrix_lr, + momentum=args.muon_momentum, + backend_steps=args.muon_backend_steps, + ) + for group in optimizer_muon.param_groups: + group["base_lr"] = args.matrix_lr + optimizer_scalar = torch.optim.Adam( + [{"params": scalar_params, "lr": args.scalar_lr, "base_lr": args.scalar_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=True, + ) + optimizers: list[torch.optim.Optimizer] = [optimizer_tok, optimizer_muon, optimizer_scalar] + if 
base_model.lm_head is not None: + optimizer_head = torch.optim.Adam( + [{"params": [base_model.lm_head.weight], "lr": args.head_lr, "base_lr": args.head_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=True, + ) + optimizers.insert(1, optimizer_head) + + n_params = sum(p.numel() for p in base_model.parameters()) + log0(f"model_params:{n_params}") + log0(f"world_size:{world_size} grad_accum_steps:{grad_accum_steps}") + log0("sdp_backends:cudnn=False flash=True mem_efficient=False math=False") + log0(f"flash_attention_3:{FA3_AVAILABLE}") + log0(f"attention_mode:gqa num_heads:{args.num_heads} num_kv_heads:{args.num_kv_heads}") + log0( + f"tie_embeddings:{args.tie_embeddings} embed_lr:{token_lr} " + f"head_lr:{args.head_lr if base_model.lm_head is not None else 0.0} " + f"matrix_lr:{args.matrix_lr} scalar_lr:{args.scalar_lr}" + ) + log0( + f"train_batch_tokens:{args.train_batch_tokens} train_seq_len:{args.train_seq_len} " + f"iterations:{args.iterations} warmup_steps:{args.warmup_steps} " + f"max_wallclock_seconds:{args.max_wallclock_seconds:.3f}" + ) + log0(f"seed:{args.seed}") + if base_model.use_recurrence: + num_shared = len(base_model.shared_attn_layers) if base_model.share_attn_only else len(base_model.shared_blocks) + eff = base_model.num_prelude + num_shared * base_model.num_loops + base_model.num_coda + shared_label = f"shared_attn:{num_shared}" if base_model.share_attn_only else f"shared:{num_shared}" + log0(f"recurrence:enabled prelude:{base_model.num_prelude} {shared_label} " + f"loops:{base_model.num_loops} coda:{base_model.num_coda} effective_layers:{eff}") + log0(f"peri_norm:{args.use_peri_norm} birkhoff_mix:{args.use_birkhoff_mix}") + if base_model.timestep_scale is not None: + ts_params = sum(p.numel() for p in base_model.timestep_scale.parameters()) + log0(f"timestep_scale:enabled params:{ts_params}") + else: + log0("timestep_scale:disabled") + else: + log0(f"recurrence:disabled num_layers:{args.num_layers}") + compile_mode = 
"disabled" if args.disable_compile else "fullgraph=True" + log0(f"compile_mode:{compile_mode}") + + # ----------------------------- + # DATA LOADER & MODEL WARMUP + # ----------------------------- + + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + + def zero_grad_all() -> None: + for opt in optimizers: + opt.zero_grad(set_to_none=True) + + max_wallclock_ms = 1000.0 * args.max_wallclock_seconds if args.max_wallclock_seconds > 0 else None + + def lr_mul(step: int, elapsed_ms: float) -> float: + if args.warmdown_iters <= 0: + return 1.0 + if max_wallclock_ms is None: + warmdown_start = max(args.iterations - args.warmdown_iters, 0) + return max((args.iterations - step) / max(args.warmdown_iters, 1), 0.0) if warmdown_start <= step < args.iterations else 1.0 + step_ms = elapsed_ms / max(step, 1) + warmdown_ms = args.warmdown_iters * step_ms + remaining_ms = max(max_wallclock_ms - elapsed_ms, 0.0) + return remaining_ms / max(warmdown_ms, 1e-9) if remaining_ms <= warmdown_ms else 1.0 + + # Warmup primes the compiled forward/backward/optimizer paths, then we restore the + # initial weights/optimizer state so measured training starts from the true init. 
+ if args.warmup_steps > 0: + initial_model_state = {name: tensor.detach().cpu().clone() for name, tensor in base_model.state_dict().items()} + initial_optimizer_states = [copy.deepcopy(opt.state_dict()) for opt in optimizers] + model.train() + for warmup_step in range(args.warmup_steps): + zero_grad_all() + for micro_step in range(grad_accum_steps): + if distributed: + model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 + x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + warmup_loss = model(x, y) + (warmup_loss * grad_scale).backward() + for opt in optimizers: + opt.step() + zero_grad_all() + if args.warmup_steps <= 20 or (warmup_step + 1) % 10 == 0 or warmup_step + 1 == args.warmup_steps: + log0(f"warmup_step:{warmup_step + 1}/{args.warmup_steps}") + base_model.load_state_dict(initial_model_state, strict=True) + for opt, state in zip(optimizers, initial_optimizer_states, strict=True): + opt.load_state_dict(state) + zero_grad_all() + if distributed: + model.require_backward_grad_sync = True + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + + # ----------------------------- + # MAIN TRAINING LOOP + # ----------------------------- + + training_time_ms = 0.0 + stop_after_step: int | None = None + torch.cuda.synchronize() + t0 = time.perf_counter() + + step = 0 + while True: + last_step = step == args.iterations or (stop_after_step is not None and step >= stop_after_step) + + should_validate = last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0) + if should_validate: + torch.cuda.synchronize() + training_time_ms += 1000.0 * (time.perf_counter() - t0) + val_loss, val_bpb = eval_val( + args, + model, + rank, + world_size, + device, + grad_accum_steps, + val_tokens, + base_bytes_lut, + has_leading_space_lut, + is_boundary_token_lut, + ) + log0( + 
f"step:{step}/{args.iterations} val_loss:{val_loss:.4f} val_bpb:{val_bpb:.4f} " + f"train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms / max(step, 1):.2f}ms" + ) + torch.cuda.synchronize() + t0 = time.perf_counter() + + if last_step: + if stop_after_step is not None and step < args.iterations: + log0( + f"stopping_early: wallclock_cap train_time:{training_time_ms:.0f}ms " + f"step:{step}/{args.iterations}" + ) + break + + elapsed_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + scale = lr_mul(step, elapsed_ms) + zero_grad_all() + train_loss = torch.zeros((), device=device) + for micro_step in range(grad_accum_steps): + if distributed: + model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 + x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + loss = model(x, y) + train_loss += loss.detach() + (loss * grad_scale).backward() + train_loss /= grad_accum_steps + + frac = min(step / args.muon_momentum_warmup_steps, 1.0) if args.muon_momentum_warmup_steps > 0 else 1.0 + muon_momentum = (1 - frac) * args.muon_momentum_warmup_start + frac * args.muon_momentum + for group in optimizer_muon.param_groups: + group["momentum"] = muon_momentum + + for opt in optimizers: + for group in opt.param_groups: + group["lr"] = group["base_lr"] * scale + + if args.grad_clip_norm > 0: + torch.nn.utils.clip_grad_norm_(base_model.parameters(), args.grad_clip_norm) + for opt in optimizers: + opt.step() + zero_grad_all() + + step += 1 + approx_training_time_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + should_log_train = ( + args.train_log_every > 0 + and (step <= 10 or step % args.train_log_every == 0 or stop_after_step is not None) + ) + if should_log_train: + log0( + f"step:{step}/{args.iterations} train_loss:{train_loss.item():.4f} " + f"train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms 
/ step:.2f}ms" + ) + if base_model.use_recurrence and step % args.train_log_every == 0: + diag = recurrence_param_diagnostics(base_model) + if diag: + log0(f"step:{step} {diag}") + + # Needed to sync whether we've reached the wallclock cap. + reached_cap = max_wallclock_ms is not None and approx_training_time_ms >= max_wallclock_ms + if distributed and max_wallclock_ms is not None: + reached_cap_tensor = torch.tensor(int(reached_cap), device=device) + dist.all_reduce(reached_cap_tensor, op=dist.ReduceOp.MAX) + reached_cap = bool(reached_cap_tensor.item()) + if stop_after_step is None and reached_cap: + stop_after_step = step + + log0( + f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB " + f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB" + ) + + # ----------------------------- + # SERIALIZATION + ROUNDTRIP VALIDATION + # ----------------------------- + # Save the raw state (useful for debugging/loading in PyTorch directly), then always produce + # the compressed int8+zlib artifact and validate the round-tripped weights. 
+ + if master_process: + torch.save(base_model.state_dict(), "final_model.pt") + model_bytes = os.path.getsize("final_model.pt") + code_bytes = len(code.encode("utf-8")) + log0(f"Serialized model: {model_bytes} bytes") + log0(f"Code size: {code_bytes} bytes") + log0(f"Total submission size: {model_bytes + code_bytes} bytes") + + quant_obj, quant_stats = quantize_state_dict_int8(base_model.state_dict()) + quant_buf = io.BytesIO() + torch.save(quant_obj, quant_buf) + quant_raw = quant_buf.getvalue() + quant_blob = zlib.compress(quant_raw, level=9) + quant_raw_bytes = len(quant_raw) + if master_process: + with open("final_model.int8.ptz", "wb") as f: + f.write(quant_blob) + quant_file_bytes = os.path.getsize("final_model.int8.ptz") + code_bytes = len(code.encode("utf-8")) + ratio = quant_stats["baseline_tensor_bytes"] / max(quant_stats["int8_payload_bytes"], 1) + log0( + f"Serialized model int8+zlib: {quant_file_bytes} bytes " + f"(payload:{quant_stats['int8_payload_bytes']} raw_torch:{quant_raw_bytes} payload_ratio:{ratio:.2f}x)" + ) + log0(f"Total submission size int8+zlib: {quant_file_bytes + code_bytes} bytes") + + if distributed: + dist.barrier() + with open("final_model.int8.ptz", "rb") as f: + quant_blob_disk = f.read() + quant_state = torch.load(io.BytesIO(zlib.decompress(quant_blob_disk)), map_location="cpu") + base_model.load_state_dict(dequantize_state_dict_int8(quant_state), strict=True) + torch.cuda.synchronize() + t_qeval = time.perf_counter() + q_val_loss, q_val_bpb = eval_val( + args, + model, + rank, + world_size, + device, + grad_accum_steps, + val_tokens, + base_bytes_lut, + has_leading_space_lut, + is_boundary_token_lut, + ) + torch.cuda.synchronize() + log0( + f"final_int8_zlib_roundtrip val_loss:{q_val_loss:.4f} val_bpb:{q_val_bpb:.4f} " + f"eval_time:{1000.0 * (time.perf_counter() - t_qeval):.0f}ms" + ) + log0(f"final_int8_zlib_roundtrip_exact val_loss:{q_val_loss:.8f} val_bpb:{q_val_bpb:.8f}") + + if distributed: + dist.destroy_process_group() 
+ + +if __name__ == "__main__": + main() + +==================================================================================================== +Running Python 3.12.3 (main, Nov 6 2025, 13:44:16) [GCC 13.3.0] +Running PyTorch 2.9.1+cu128 +Mon Mar 30 15:43:42 2026 ++-----------------------------------------------------------------------------------------+ +| NVIDIA-SMI 580.126.09 Driver Version: 580.126.09 CUDA Version: 13.0 | ++-----------------------------------------+------------------------+----------------------+ +| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC | +| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. | +| | | MIG M. | +|=========================================+========================+======================| +| 0 NVIDIA H100 80GB HBM3 On | 00000000:19:00.0 Off | 0 | +| N/A 39C P0 121W / 700W | 1521MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 1 NVIDIA H100 80GB HBM3 On | 00000000:3B:00.0 Off | 0 | +| N/A 34C P0 116W / 700W | 1521MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 2 NVIDIA H100 80GB HBM3 On | 00000000:4C:00.0 Off | 0 | +| N/A 33C P0 121W / 700W | 1521MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 3 NVIDIA H100 80GB HBM3 On | 00000000:5D:00.0 Off | 0 | +| N/A 38C P0 120W / 700W | 1521MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 4 NVIDIA H100 80GB HBM3 On | 00000000:9B:00.0 Off | 0 | +| N/A 39C P0 121W / 700W | 1521MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 5 NVIDIA H100 80GB HBM3 On | 00000000:BB:00.0 Off | 0 | +| N/A 34C 
P0 118W / 700W | 1521MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 6 NVIDIA H100 80GB HBM3 On | 00000000:CB:00.0 Off | 0 | +| N/A 38C P0 118W / 700W | 1521MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 7 NVIDIA H100 80GB HBM3 On | 00000000:DB:00.0 Off | 0 | +| N/A 33C P0 119W / 700W | 1521MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ + ++-----------------------------------------------------------------------------------------+ +| Processes: | +| GPU GI CI PID Type Process name GPU Memory | +| ID ID Usage | +|=========================================================================================| +| No running processes found | ++-----------------------------------------------------------------------------------------+ + +==================================================================================================== +val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path=./data/tokenizers/fineweb_1024_bpe.model +train_loader:dataset:fineweb10B_sp1024 train_shards:80 +val_loader:shards pattern=./data/datasets/fineweb10B_sp1024/fineweb_val_*.bin tokens:62021632 +model_params:11564080 +world_size:8 grad_accum_steps:1 +sdp_backends:cudnn=False flash=True mem_efficient=False math=False +flash_attention_3:True +attention_mode:gqa num_heads:8 num_kv_heads:4 +tie_embeddings:True embed_lr:0.05 head_lr:0.0 matrix_lr:0.04 scalar_lr:0.04 +train_batch_tokens:524288 train_seq_len:1024 iterations:20000 warmup_steps:20 max_wallclock_seconds:600.000 +seed:1337 +recurrence:enabled prelude:1 shared:4 loops:2 coda:1 effective_layers:10 +peri_norm:True birkhoff_mix:True +timestep_scale:enabled params:20480 +compile_mode:fullgraph=True +warmup_step:1/20 +warmup_step:2/20 +warmup_step:3/20 
+warmup_step:4/20 +warmup_step:5/20 +warmup_step:6/20 +warmup_step:7/20 +warmup_step:8/20 +warmup_step:9/20 +warmup_step:10/20 +warmup_step:11/20 +warmup_step:12/20 +warmup_step:13/20 +warmup_step:14/20 +warmup_step:15/20 +warmup_step:16/20 +warmup_step:17/20 +warmup_step:18/20 +warmup_step:19/20 +warmup_step:20/20 +step:0/20000 val_loss:6.9379 val_bpb:4.1090 train_time:0ms step_avg:0.02ms +step:1/20000 train_loss:6.9377 train_time:24ms step_avg:23.55ms +step:2/20000 train_loss:9.3764 train_time:61ms step_avg:30.50ms +step:3/20000 train_loss:8.1030 train_time:102ms step_avg:34.04ms +step:4/20000 train_loss:9.5206 train_time:143ms step_avg:35.84ms +step:5/20000 train_loss:8.7272 train_time:185ms step_avg:36.98ms +step:6/20000 train_loss:8.2118 train_time:226ms step_avg:37.65ms +step:7/20000 train_loss:7.0584 train_time:268ms step_avg:38.22ms +step:8/20000 train_loss:6.8365 train_time:309ms step_avg:38.63ms +step:9/20000 train_loss:5.9441 train_time:350ms step_avg:38.90ms +step:10/20000 train_loss:5.5813 train_time:392ms step_avg:39.23ms +step:200/20000 train_loss:2.7764 train_time:8413ms step_avg:42.06ms +step:200 shared0_alpha:mean=0.462,std=0.046 shared1_alpha:mean=0.477,std=0.038 shared2_alpha:mean=0.491,std=0.036 shared3_alpha:mean=0.517,std=0.039 eff_mlp_scale:[v0:32.9357 v1:27.5678 v2:27.9938 v3:30.1772 v4:31.9283 v5:31.6737 v6:29.8600 v7:32.8932 v8:35.3885 v9:59.2032] eff_attn_scale:[v0:14.7348 v1:10.3625 v2:10.0586 v3:9.1697 v4:9.6930 v5:10.3625 v6:9.5796 v7:9.2474 v8:10.3854 v9:14.9619] eff_attn_bias:[v0:0.1409 v1:0.1167 v2:0.1195 v3:0.1229 v4:0.1353 v5:0.1333 v6:0.1271 v7:0.1195 v8:0.1119 v9:0.1126] eff_mlp_bias:[v0:0.1181 v1:0.1091 v2:0.1132 v3:0.1181 v4:0.1271 v5:0.1139 v6:0.1091 v7:0.1022 v8:0.1153 v9:0.1975] +step:200/20000 val_loss:2.7728 val_bpb:1.6422 train_time:8472ms step_avg:42.36ms +step:400/20000 train_loss:2.3642 train_time:16926ms step_avg:42.31ms +step:400 shared0_alpha:mean=0.475,std=0.049 shared1_alpha:mean=0.495,std=0.042 
shared2_alpha:mean=0.518,std=0.038 shared3_alpha:mean=0.549,std=0.042 eff_mlp_scale:[v0:41.8934 v1:35.1468 v2:36.8554 v3:39.7362 v4:38.8016 v5:42.4478 v6:38.3362 v7:40.2457 v8:39.1375 v9:74.8137] eff_attn_scale:[v0:6.5544 v1:5.6344 v2:5.5612 v3:5.1170 v4:5.4696 v5:5.9572 v6:5.2789 v7:4.7152 v8:5.0848 v9:9.0718] eff_attn_bias:[v0:0.1878 v1:0.1588 v2:0.1637 v3:0.1719 v4:0.1975 v5:0.1782 v6:0.1713 v7:0.1512 v8:0.1388 v9:0.1326] eff_mlp_bias:[v0:0.2113 v1:0.1457 v2:0.1422 v3:0.1505 v4:0.1671 v5:0.1498 v6:0.1429 v7:0.1422 v8:0.1457 v9:0.2762] +step:400/20000 val_loss:2.5662 val_bpb:1.5199 train_time:16951ms step_avg:42.38ms +step:600/20000 train_loss:2.5771 train_time:25416ms step_avg:42.36ms +step:600 shared0_alpha:mean=0.485,std=0.050 shared1_alpha:mean=0.504,std=0.044 shared2_alpha:mean=0.535,std=0.038 shared3_alpha:mean=0.571,std=0.043 eff_mlp_scale:[v0:48.1545 v1:39.7386 v2:42.0919 v3:45.8033 v4:42.8111 v5:48.6298 v6:42.9651 v7:44.3663 v8:41.4188 v9:89.7017] eff_attn_scale:[v0:3.1433 v1:3.2360 v2:3.3679 v3:3.2036 v4:3.5105 v5:3.4642 v6:3.2026 v7:2.7738 v8:3.0641 v9:6.1145] eff_attn_bias:[v0:0.2583 v1:0.2141 v2:0.2141 v3:0.2251 v4:0.2610 v5:0.2306 v6:0.2196 v7:0.1851 v8:0.1643 v9:0.1643] eff_mlp_bias:[v0:0.3315 v1:0.1975 v2:0.1864 v3:0.1947 v4:0.2224 v5:0.2003 v6:0.1892 v7:0.1892 v8:0.1809 v9:0.3011] +step:600/20000 val_loss:2.4810 val_bpb:1.4694 train_time:25442ms step_avg:42.40ms +step:800/20000 train_loss:2.3450 train_time:33923ms step_avg:42.40ms +step:800 shared0_alpha:mean=0.490,std=0.051 shared1_alpha:mean=0.512,std=0.045 shared2_alpha:mean=0.546,std=0.040 shared3_alpha:mean=0.583,std=0.044 eff_mlp_scale:[v0:53.4033 v1:43.8357 v2:46.2040 v3:49.4329 v4:45.7032 v5:53.5137 v6:46.0228 v7:47.3887 v8:44.0902 v9:102.5373] eff_attn_scale:[v0:1.9779 v1:2.3017 v2:2.4438 v3:2.3335 v4:2.6428 v5:2.4865 v6:2.2921 v7:1.9604 v8:2.2220 v9:4.7257] eff_attn_bias:[v0:0.3163 v1:0.2652 v2:0.2583 v3:0.2735 v4:0.3135 v5:0.2804 v6:0.2638 v7:0.2196 v8:0.1906 v9:0.2016] 
eff_mlp_bias:[v0:0.4364 v1:0.2472 v2:0.2348 v3:0.2306 v4:0.2748 v5:0.2527 v6:0.2334 v7:0.2320 v8:0.2141 v9:0.3190] +step:800/20000 val_loss:2.4251 val_bpb:1.4363 train_time:33946ms step_avg:42.43ms +step:1000/20000 train_loss:2.4229 train_time:42445ms step_avg:42.44ms +step:1000 shared0_alpha:mean=0.494,std=0.052 shared1_alpha:mean=0.515,std=0.047 shared2_alpha:mean=0.554,std=0.040 shared3_alpha:mean=0.591,std=0.045 eff_mlp_scale:[v0:58.7795 v1:47.1495 v2:49.8036 v3:53.1304 v4:48.2799 v5:57.3652 v6:49.0546 v7:50.0725 v8:46.8057 v9:113.1224] eff_attn_scale:[v0:1.4807 v1:1.8640 v2:1.9793 v3:1.9541 v4:2.1812 v5:2.0199 v6:1.8677 v7:1.6047 v8:1.8202 v9:3.8923] eff_attn_bias:[v0:0.3522 v1:0.3066 v2:0.2983 v3:0.3135 v4:0.3563 v5:0.3218 v6:0.2997 v7:0.2472 v8:0.2182 v9:0.2403] eff_mlp_bias:[v0:0.5248 v1:0.2928 v2:0.2776 v3:0.2638 v4:0.3190 v5:0.2969 v6:0.2707 v7:0.2721 v8:0.2431 v9:0.3287] +step:1000/20000 val_loss:2.3854 val_bpb:1.4128 train_time:42460ms step_avg:42.46ms +step:1200/20000 train_loss:2.4418 train_time:50967ms step_avg:42.47ms +step:1200 shared0_alpha:mean=0.495,std=0.054 shared1_alpha:mean=0.518,std=0.048 shared2_alpha:mean=0.559,std=0.041 shared3_alpha:mean=0.597,std=0.046 eff_mlp_scale:[v0:63.1393 v1:50.5603 v2:52.9076 v3:56.0283 v4:50.7355 v5:61.0769 v6:51.3741 v7:52.8938 v8:49.5996 v9:122.5207] eff_attn_scale:[v0:1.2060 v1:1.5849 v2:1.7103 v3:1.6985 v4:1.8290 v5:1.6928 v6:1.5944 v7:1.3759 v8:1.5287 v9:3.3487] eff_attn_bias:[v0:0.3839 v1:0.3439 v2:0.3315 v3:0.3480 v4:0.3950 v5:0.3591 v6:0.3356 v7:0.2748 v8:0.2431 v9:0.2762] eff_mlp_bias:[v0:0.6049 v1:0.3342 v2:0.3163 v3:0.2928 v4:0.3591 v5:0.3397 v6:0.3038 v7:0.3038 v8:0.2679 v9:0.3370] +step:1200/20000 val_loss:2.3580 val_bpb:1.3966 train_time:50981ms step_avg:42.48ms +step:1400/20000 train_loss:2.4856 train_time:59491ms step_avg:42.49ms +step:1400 shared0_alpha:mean=0.496,std=0.055 shared1_alpha:mean=0.520,std=0.048 shared2_alpha:mean=0.562,std=0.043 shared3_alpha:mean=0.601,std=0.047 
eff_mlp_scale:[v0:67.2519 v1:53.3913 v2:55.3175 v3:58.4239 v4:53.0812 v5:64.1523 v6:54.1405 v7:55.2226 v8:52.3062 v9:130.5288] eff_attn_scale:[v0:1.0276 v1:1.4109 v2:1.5221 v3:1.5174 v4:1.6408 v5:1.4995 v6:1.4130 v7:1.2302 v8:1.3869 v9:3.0568] eff_attn_bias:[v0:0.4088 v1:0.3757 v2:0.3618 v3:0.3784 v4:0.4281 v5:0.3922 v6:0.3674 v7:0.2997 v8:0.2652 v9:0.3107] eff_mlp_bias:[v0:0.6767 v1:0.3729 v2:0.3522 v3:0.3204 v4:0.3895 v5:0.3784 v6:0.3328 v7:0.3342 v8:0.2914 v9:0.3466] +step:1400/20000 val_loss:2.3377 val_bpb:1.3845 train_time:59505ms step_avg:42.50ms +step:1600/20000 train_loss:2.1589 train_time:68011ms step_avg:42.51ms +step:1600 shared0_alpha:mean=0.496,std=0.056 shared1_alpha:mean=0.522,std=0.050 shared2_alpha:mean=0.565,std=0.044 shared3_alpha:mean=0.603,std=0.047 eff_mlp_scale:[v0:70.8960 v1:56.1940 v2:58.4105 v3:61.1004 v4:55.3384 v5:66.7568 v6:56.0101 v7:57.4344 v8:54.5479 v9:137.5262] eff_attn_scale:[v0:0.9057 v1:1.2853 v2:1.4083 v3:1.3897 v4:1.4900 v5:1.3642 v6:1.3086 v7:1.1153 v8:1.2541 v9:2.8152] eff_attn_bias:[v0:0.4254 v1:0.4088 v2:0.3867 v3:0.4060 v4:0.4585 v5:0.4226 v6:0.3922 v7:0.3232 v8:0.2859 v9:0.3453] eff_mlp_bias:[v0:0.7403 v1:0.4060 v2:0.3895 v3:0.3453 v4:0.4226 v5:0.4143 v6:0.3591 v7:0.3646 v8:0.3135 v9:0.3563] +step:1600/20000 val_loss:2.3239 val_bpb:1.3763 train_time:68026ms step_avg:42.52ms +step:1800/20000 train_loss:2.2601 train_time:76533ms step_avg:42.52ms +step:1800 shared0_alpha:mean=0.495,std=0.058 shared1_alpha:mean=0.522,std=0.051 shared2_alpha:mean=0.566,std=0.044 shared3_alpha:mean=0.606,std=0.049 eff_mlp_scale:[v0:74.4192 v1:58.6089 v2:60.7007 v3:62.9121 v4:57.3198 v5:69.8135 v6:58.2564 v7:59.6009 v8:57.3198 v9:144.9208] eff_attn_scale:[v0:0.7988 v1:1.1850 v2:1.3080 v3:1.3022 v4:1.3571 v5:1.2613 v6:1.2005 v7:1.0372 v8:1.1607 v9:2.6190] eff_attn_bias:[v0:0.4392 v1:0.4337 v2:0.4116 v3:0.4309 v4:0.4861 v5:0.4502 v6:0.4198 v7:0.3453 v8:0.3094 v9:0.3812] eff_mlp_bias:[v0:0.7955 v1:0.4392 v2:0.4171 v3:0.3729 v4:0.4530 v5:0.4502 
v6:0.3867 v7:0.3922 v8:0.3328 v9:0.3646] +step:1800/20000 val_loss:2.3106 val_bpb:1.3685 train_time:76548ms step_avg:42.53ms +step:2000/20000 train_loss:2.3106 train_time:85055ms step_avg:42.53ms +step:2000 shared0_alpha:mean=0.496,std=0.059 shared1_alpha:mean=0.523,std=0.051 shared2_alpha:mean=0.568,std=0.046 shared3_alpha:mean=0.607,std=0.050 eff_mlp_scale:[v0:78.0648 v1:61.4887 v2:62.9866 v3:65.7977 v4:59.5880 v5:72.4688 v6:60.5003 v7:62.0017 v8:59.5880 v9:151.1933] eff_attn_scale:[v0:0.7183 v1:1.0985 v2:1.2271 v3:1.2163 v4:1.2587 v5:1.1886 v6:1.1287 v7:0.9664 v8:1.0813 v9:2.4434] eff_attn_bias:[v0:0.4530 v1:0.4558 v2:0.4364 v3:0.4558 v4:0.5165 v5:0.4751 v6:0.4419 v7:0.3701 v8:0.3273 v9:0.4143] eff_mlp_bias:[v0:0.8563 v1:0.4723 v2:0.4502 v3:0.4005 v4:0.4778 v5:0.4806 v6:0.4116 v7:0.4198 v8:0.3508 v9:0.3729] +step:2000/20000 val_loss:2.2959 val_bpb:1.3597 train_time:85070ms step_avg:42.54ms +step:2200/20000 train_loss:2.1399 train_time:93572ms step_avg:42.53ms +step:2200 shared0_alpha:mean=0.497,std=0.059 shared1_alpha:mean=0.525,std=0.053 shared2_alpha:mean=0.570,std=0.047 shared3_alpha:mean=0.609,std=0.050 eff_mlp_scale:[v0:81.0634 v1:63.8234 v2:65.8390 v3:68.3027 v4:61.5537 v5:74.9814 v6:62.8847 v7:64.4365 v8:61.9725 v9:157.5611] eff_attn_scale:[v0:0.6593 v1:1.0463 v2:1.1605 v3:1.1621 v4:1.1996 v5:1.1068 v6:1.0699 v7:0.9418 v8:1.0363 v9:2.3458] eff_attn_bias:[v0:0.4613 v1:0.4806 v2:0.4530 v3:0.4751 v4:0.5414 v5:0.4999 v6:0.4640 v7:0.3867 v8:0.3425 v9:0.4475] eff_mlp_bias:[v0:0.9060 v1:0.5055 v2:0.4751 v3:0.4254 v4:0.5027 v5:0.5110 v6:0.4364 v7:0.4447 v8:0.3646 v9:0.3839] +step:2200/20000 val_loss:2.2888 val_bpb:1.3556 train_time:93587ms step_avg:42.54ms +step:2400/20000 train_loss:2.2592 train_time:102087ms step_avg:42.54ms +step:2400 shared0_alpha:mean=0.496,std=0.060 shared1_alpha:mean=0.525,std=0.054 shared2_alpha:mean=0.571,std=0.048 shared3_alpha:mean=0.609,std=0.051 eff_mlp_scale:[v0:84.0586 v1:65.7476 v2:67.5126 v3:70.6702 v4:63.4980 v5:77.5369 
v6:64.5216 v7:66.3078 v8:64.3504 v9:163.9369] eff_attn_scale:[v0:0.6312 v1:0.9820 v2:1.1067 v3:1.1131 v4:1.1313 v5:1.0407 v6:1.0233 v7:0.8926 v8:0.9728 v9:2.2404] eff_attn_bias:[v0:0.4696 v1:0.4999 v2:0.4723 v3:0.4944 v4:0.5635 v5:0.5193 v6:0.4806 v7:0.4088 v8:0.3618 v9:0.4778] eff_mlp_bias:[v0:0.9502 v1:0.5303 v2:0.5027 v3:0.4475 v4:0.5276 v5:0.5386 v6:0.4613 v7:0.4696 v8:0.3839 v9:0.3977] +step:2400/20000 val_loss:2.2775 val_bpb:1.3489 train_time:102102ms step_avg:42.54ms +step:2600/20000 train_loss:2.4832 train_time:110592ms step_avg:42.54ms +step:2600 shared0_alpha:mean=0.497,std=0.060 shared1_alpha:mean=0.526,std=0.054 shared2_alpha:mean=0.572,std=0.049 shared3_alpha:mean=0.609,std=0.053 eff_mlp_scale:[v0:87.1887 v1:68.6699 v2:70.3462 v3:73.0543 v4:65.5069 v5:80.1917 v6:66.8723 v7:68.6268 v8:66.8084 v9:169.2790] eff_attn_scale:[v0:0.6224 v1:0.9722 v2:1.1046 v3:1.0902 v4:1.0999 v5:1.0202 v6:0.9991 v7:0.8775 v8:0.9482 v9:2.1496] eff_attn_bias:[v0:0.4778 v1:0.5220 v2:0.4889 v3:0.5193 v4:0.5856 v5:0.5386 v6:0.4999 v7:0.4226 v8:0.3757 v9:0.5027] eff_mlp_bias:[v0:0.9999 v1:0.5552 v2:0.5276 v3:0.4751 v4:0.5497 v5:0.5662 v6:0.4778 v7:0.4917 v8:0.3950 v9:0.4088] +step:2600/20000 val_loss:2.3034 val_bpb:1.3642 train_time:110607ms step_avg:42.54ms +step:2800/20000 train_loss:2.3015 train_time:119103ms step_avg:42.54ms +step:2800 shared0_alpha:mean=0.495,std=0.061 shared1_alpha:mean=0.526,std=0.056 shared2_alpha:mean=0.572,std=0.048 shared3_alpha:mean=0.609,std=0.052 eff_mlp_scale:[v0:90.2368 v1:70.9658 v2:72.1151 v3:74.9927 v4:67.1409 v5:82.6378 v6:68.5972 v7:70.9512 v8:68.8962 v9:174.4148] eff_attn_scale:[v0:0.5599 v1:0.9016 v2:1.0467 v3:1.0421 v4:1.0543 v5:0.9682 v6:0.9555 v7:0.8399 v8:0.9067 v9:2.1050] eff_attn_bias:[v0:0.4889 v1:0.5414 v2:0.5055 v3:0.5359 v4:0.6049 v5:0.5552 v6:0.5138 v7:0.4419 v8:0.3895 v9:0.5303] eff_mlp_bias:[v0:1.0386 v1:0.5800 v2:0.5469 v3:0.4944 v4:0.5718 v5:0.5911 v6:0.4999 v7:0.5165 v8:0.4116 v9:0.4198] +step:2800/20000 val_loss:2.2648 
val_bpb:1.3414 train_time:119118ms step_avg:42.54ms +step:3000/20000 train_loss:2.2862 train_time:127605ms step_avg:42.53ms +step:3000 shared0_alpha:mean=0.495,std=0.061 shared1_alpha:mean=0.527,std=0.056 shared2_alpha:mean=0.574,std=0.050 shared3_alpha:mean=0.610,std=0.053 eff_mlp_scale:[v0:92.6723 v1:72.7632 v2:74.4261 v3:77.3959 v4:68.9692 v5:84.5754 v6:70.8607 v7:72.8432 v8:71.1940 v9:179.7163] eff_attn_scale:[v0:0.5310 v1:0.8852 v2:1.0094 v3:1.0132 v4:1.0210 v5:0.9358 v6:0.9143 v7:0.8045 v8:0.8856 v9:1.9969] eff_attn_bias:[v0:0.4944 v1:0.5607 v2:0.5220 v3:0.5497 v4:0.6242 v5:0.5690 v6:0.5276 v7:0.4585 v8:0.4060 v9:0.5552] eff_mlp_bias:[v0:1.0828 v1:0.6104 v2:0.5718 v3:0.5165 v4:0.5966 v5:0.6160 v6:0.5193 v7:0.5386 v8:0.4254 v9:0.4337] +step:3000/20000 val_loss:2.2569 val_bpb:1.3367 train_time:127619ms step_avg:42.54ms +step:3200/20000 train_loss:2.2511 train_time:136108ms step_avg:42.53ms +step:3200 shared0_alpha:mean=0.495,std=0.062 shared1_alpha:mean=0.528,std=0.056 shared2_alpha:mean=0.575,std=0.049 shared3_alpha:mean=0.610,std=0.054 eff_mlp_scale:[v0:95.7158 v1:74.6197 v2:76.7844 v3:79.4133 v4:70.8797 v5:87.0563 v6:72.7194 v7:74.7963 v8:73.5885 v9:184.4669] eff_attn_scale:[v0:0.5073 v1:0.8447 v2:0.9977 v3:0.9878 v4:0.9953 v5:0.8993 v6:0.9027 v7:0.7903 v8:0.8709 v9:1.9539] eff_attn_bias:[v0:0.5027 v1:0.5828 v2:0.5359 v3:0.5690 v4:0.6463 v5:0.5883 v6:0.5441 v7:0.4723 v8:0.4171 v9:0.5800] eff_mlp_bias:[v0:1.1214 v1:0.6381 v2:0.5966 v3:0.5386 v4:0.6160 v5:0.6408 v6:0.5386 v7:0.5552 v8:0.4364 v9:0.4447] +step:3200/20000 val_loss:2.2522 val_bpb:1.3339 train_time:136122ms step_avg:42.54ms +step:3400/20000 train_loss:2.2188 train_time:144605ms step_avg:42.53ms +step:3400 shared0_alpha:mean=0.495,std=0.063 shared1_alpha:mean=0.529,std=0.057 shared2_alpha:mean=0.576,std=0.050 shared3_alpha:mean=0.610,std=0.055 eff_mlp_scale:[v0:98.2270 v1:77.0657 v2:78.4994 v3:81.9594 v4:72.3891 v5:89.1829 v6:74.8483 v7:77.2760 v8:75.5962 v9:188.7226] eff_attn_scale:[v0:0.4811 
v1:0.8204 v2:0.9651 v3:0.9684 v4:0.9710 v5:0.8892 v6:0.8821 v7:0.7687 v8:0.8439 v9:1.9072] eff_attn_bias:[v0:0.5138 v1:0.6021 v2:0.5552 v3:0.5883 v4:0.6657 v5:0.6021 v6:0.5607 v7:0.4861 v8:0.4309 v9:0.6077] eff_mlp_bias:[v0:1.1656 v1:0.6657 v2:0.6187 v3:0.5607 v4:0.6408 v5:0.6629 v6:0.5580 v7:0.5773 v8:0.4502 v9:0.4585] +step:3400/20000 val_loss:2.2489 val_bpb:1.3319 train_time:144620ms step_avg:42.54ms +step:3600/20000 train_loss:2.1891 train_time:153121ms step_avg:42.53ms +step:3600 shared0_alpha:mean=0.494,std=0.064 shared1_alpha:mean=0.530,std=0.058 shared2_alpha:mean=0.577,std=0.051 shared3_alpha:mean=0.611,std=0.055 eff_mlp_scale:[v0:101.1877 v1:78.3690 v2:80.3507 v3:83.8093 v4:74.2892 v5:91.1040 v6:76.1947 v7:79.0743 v8:78.0036 v9:193.7955] eff_attn_scale:[v0:0.4631 v1:0.8037 v2:0.9510 v3:0.9570 v4:0.9671 v5:0.8570 v6:0.8585 v7:0.7576 v8:0.8347 v9:1.8671] eff_attn_bias:[v0:0.5193 v1:0.6187 v2:0.5662 v3:0.6021 v4:0.6850 v5:0.6160 v6:0.5745 v7:0.4999 v8:0.4447 v9:0.6270] eff_mlp_bias:[v0:1.2043 v1:0.6878 v2:0.6408 v3:0.5828 v4:0.6602 v5:0.6878 v6:0.5773 v7:0.5966 v8:0.4668 v9:0.4723] +step:3600/20000 val_loss:2.2430 val_bpb:1.3284 train_time:153136ms step_avg:42.54ms +step:3800/20000 train_loss:2.2852 train_time:161618ms step_avg:42.53ms +step:3800 shared0_alpha:mean=0.494,std=0.065 shared1_alpha:mean=0.531,std=0.058 shared2_alpha:mean=0.577,std=0.051 shared3_alpha:mean=0.611,std=0.056 eff_mlp_scale:[v0:103.6674 v1:80.7269 v2:82.5741 v3:85.7050 v4:75.7054 v5:93.1083 v6:77.9089 v7:80.9170 v8:80.4076 v9:198.8463] eff_attn_scale:[v0:0.4489 v1:0.7831 v2:0.9318 v3:0.9297 v4:0.9378 v5:0.8407 v6:0.8503 v7:0.7477 v8:0.8124 v9:1.8339] eff_attn_bias:[v0:0.5248 v1:0.6325 v2:0.5828 v3:0.6187 v4:0.7043 v5:0.6298 v6:0.5911 v7:0.5165 v8:0.4558 v9:0.6463] eff_mlp_bias:[v0:1.2374 v1:0.7182 v2:0.6657 v3:0.6021 v4:0.6822 v5:0.7126 v6:0.5939 v7:0.6160 v8:0.4806 v9:0.4861] +step:3800/20000 val_loss:2.2381 val_bpb:1.3255 train_time:161633ms step_avg:42.53ms +step:4000/20000 
train_loss:2.2229 train_time:170115ms step_avg:42.53ms +step:4000 shared0_alpha:mean=0.493,std=0.065 shared1_alpha:mean=0.533,std=0.059 shared2_alpha:mean=0.579,std=0.051 shared3_alpha:mean=0.612,std=0.057 eff_mlp_scale:[v0:106.1655 v1:82.2494 v2:84.4490 v3:87.6807 v4:77.6246 v5:95.7904 v6:79.7312 v7:82.8365 v8:82.3869 v9:202.9972] eff_attn_scale:[v0:0.4311 v1:0.7685 v2:0.9209 v3:0.9158 v4:0.9324 v5:0.8207 v6:0.8348 v7:0.7258 v8:0.8028 v9:1.7960] eff_attn_bias:[v0:0.5331 v1:0.6463 v2:0.5994 v3:0.6298 v4:0.7237 v5:0.6436 v6:0.6049 v7:0.5303 v8:0.4668 v9:0.6684] eff_mlp_bias:[v0:1.2816 v1:0.7347 v2:0.6850 v3:0.6215 v4:0.7043 v5:0.7347 v6:0.6132 v7:0.6353 v8:0.4944 v9:0.4999] +step:4000/20000 val_loss:2.2334 val_bpb:1.3228 train_time:170128ms step_avg:42.53ms +step:4200/20000 train_loss:2.2376 train_time:178670ms step_avg:42.54ms +step:4200 shared0_alpha:mean=0.493,std=0.065 shared1_alpha:mean=0.534,std=0.058 shared2_alpha:mean=0.579,std=0.052 shared3_alpha:mean=0.612,std=0.057 eff_mlp_scale:[v0:108.5371 v1:84.0727 v2:86.4137 v3:89.2180 v4:79.4430 v5:97.2408 v6:81.6394 v7:85.2963 v8:84.2578 v9:207.1766] eff_attn_scale:[v0:0.4185 v1:0.7615 v2:0.8969 v3:0.9067 v4:0.9110 v5:0.8136 v6:0.8167 v7:0.7273 v8:0.7837 v9:1.7728] eff_attn_bias:[v0:0.5386 v1:0.6629 v2:0.6104 v3:0.6463 v4:0.7403 v5:0.6574 v6:0.6187 v7:0.5441 v8:0.4778 v9:0.6878] eff_mlp_bias:[v0:1.3148 v1:0.7568 v2:0.7043 v3:0.6408 v4:0.7237 v5:0.7568 v6:0.6298 v7:0.6546 v8:0.5110 v9:0.5165] +step:4200/20000 val_loss:2.2292 val_bpb:1.3202 train_time:178683ms step_avg:42.54ms +step:4400/20000 train_loss:2.1816 train_time:187160ms step_avg:42.54ms +step:4400 shared0_alpha:mean=0.492,std=0.066 shared1_alpha:mean=0.535,std=0.059 shared2_alpha:mean=0.580,std=0.052 shared3_alpha:mean=0.612,std=0.058 eff_mlp_scale:[v0:110.9901 v1:86.0447 v2:88.3308 v3:91.7975 v4:81.4642 v5:99.8733 v6:83.5040 v7:87.3316 v8:86.8301 v9:211.3171] eff_attn_scale:[v0:0.4103 v1:0.7485 v2:0.8904 v3:0.8978 v4:0.9022 v5:0.8097 v6:0.8059 v7:0.7183 
v8:0.7845 v9:1.7136] eff_attn_bias:[v0:0.5469 v1:0.6767 v2:0.6242 v3:0.6602 v4:0.7568 v5:0.6684 v6:0.6325 v7:0.5580 v8:0.4917 v9:0.7071] eff_mlp_bias:[v0:1.3534 v1:0.7789 v2:0.7292 v3:0.6602 v4:0.7458 v5:0.7789 v6:0.6463 v7:0.6712 v8:0.5220 v9:0.5303] +step:4400/20000 val_loss:2.2302 val_bpb:1.3209 train_time:187175ms step_avg:42.54ms +step:4600/20000 train_loss:2.0392 train_time:195657ms step_avg:42.53ms +step:4600 shared0_alpha:mean=0.492,std=0.067 shared1_alpha:mean=0.536,std=0.059 shared2_alpha:mean=0.581,std=0.053 shared3_alpha:mean=0.612,std=0.058 eff_mlp_scale:[v0:113.5122 v1:88.3819 v2:90.3280 v3:93.8184 v4:82.9336 v5:101.8200 v6:85.4454 v7:89.3031 v8:88.8574 v9:216.3856] eff_attn_scale:[v0:0.3940 v1:0.7361 v2:0.8797 v3:0.8939 v4:0.8978 v5:0.7827 v6:0.7957 v7:0.7161 v8:0.7850 v9:1.6936] eff_attn_bias:[v0:0.5552 v1:0.6933 v2:0.6353 v3:0.6740 v4:0.7789 v5:0.6822 v6:0.6436 v7:0.5718 v8:0.5027 v9:0.7237] eff_mlp_bias:[v0:1.3921 v1:0.8065 v2:0.7458 v3:0.6822 v4:0.7623 v5:0.8010 v6:0.6657 v7:0.6878 v8:0.5359 v9:0.5441] +step:4600/20000 val_loss:2.2258 val_bpb:1.3182 train_time:195674ms step_avg:42.54ms +step:4800/20000 train_loss:2.3292 train_time:204157ms step_avg:42.53ms +step:4800 shared0_alpha:mean=0.492,std=0.067 shared1_alpha:mean=0.537,std=0.059 shared2_alpha:mean=0.581,std=0.053 shared3_alpha:mean=0.612,std=0.059 eff_mlp_scale:[v0:116.0095 v1:89.7244 v2:92.2093 v3:95.9295 v4:84.9303 v5:103.2873 v6:86.7853 v7:91.3614 v8:90.9254 v9:218.7925] eff_attn_scale:[v0:0.3892 v1:0.7210 v2:0.8763 v3:0.8903 v4:0.8944 v5:0.7811 v6:0.7877 v7:0.7123 v8:0.7722 v9:1.6600] eff_attn_bias:[v0:0.5607 v1:0.7043 v2:0.6491 v3:0.6878 v4:0.7900 v5:0.6961 v6:0.6574 v7:0.5856 v8:0.5165 v9:0.7403] eff_mlp_bias:[v0:1.4253 v1:0.8231 v2:0.7679 v3:0.7016 v4:0.7734 v5:0.8176 v6:0.6822 v7:0.7071 v8:0.5497 v9:0.5552] +step:4800/20000 val_loss:2.2222 val_bpb:1.3161 train_time:204170ms step_avg:42.54ms +step:5000/20000 train_loss:2.0966 train_time:212651ms step_avg:42.53ms +step:5000 
shared0_alpha:mean=0.491,std=0.068 shared1_alpha:mean=0.537,std=0.060 shared2_alpha:mean=0.582,std=0.054 shared3_alpha:mean=0.611,std=0.059 eff_mlp_scale:[v0:118.3944 v1:91.3857 v2:94.1895 v3:97.4031 v4:86.1391 v5:105.5662 v6:88.7076 v7:92.7893 v8:92.6877 v9:222.4789] eff_attn_scale:[v0:0.3782 v1:0.7167 v2:0.8698 v3:0.8733 v4:0.8898 v5:0.7719 v6:0.7912 v7:0.7015 v8:0.7725 v9:1.6387] eff_attn_bias:[v0:0.5662 v1:0.7182 v2:0.6629 v3:0.7016 v4:0.8065 v5:0.7071 v6:0.6684 v7:0.5994 v8:0.5276 v9:0.7568] eff_mlp_bias:[v0:1.4584 v1:0.8452 v2:0.7900 v3:0.7237 v4:0.8010 v5:0.8397 v6:0.6961 v7:0.7237 v8:0.5635 v9:0.5690] +step:5000/20000 val_loss:2.2167 val_bpb:1.3128 train_time:212665ms step_avg:42.53ms +step:5200/20000 train_loss:2.2379 train_time:221147ms step_avg:42.53ms +step:5200 shared0_alpha:mean=0.491,std=0.069 shared1_alpha:mean=0.539,std=0.060 shared2_alpha:mean=0.582,std=0.054 shared3_alpha:mean=0.612,std=0.060 eff_mlp_scale:[v0:120.9297 v1:93.7707 v2:96.7463 v3:100.1323 v4:88.3324 v5:107.5449 v6:90.6997 v7:94.9441 v8:95.4806 v9:226.7539] eff_attn_scale:[v0:0.3728 v1:0.7086 v2:0.8648 v3:0.8766 v4:0.8920 v5:0.7634 v6:0.7911 v7:0.7003 v8:0.7701 v9:1.6072] eff_attn_bias:[v0:0.5745 v1:0.7292 v2:0.6712 v3:0.7182 v4:0.8231 v5:0.7237 v6:0.6822 v7:0.6104 v8:0.5359 v9:0.7734] eff_mlp_bias:[v0:1.4916 v1:0.8618 v2:0.8065 v3:0.7403 v4:0.8176 v5:0.8563 v6:0.7126 v7:0.7403 v8:0.5745 v9:0.5800] +step:5200/20000 val_loss:2.2187 val_bpb:1.3140 train_time:221162ms step_avg:42.53ms +step:5400/20000 train_loss:2.2484 train_time:229636ms step_avg:42.53ms +step:5400 shared0_alpha:mean=0.491,std=0.069 shared1_alpha:mean=0.540,std=0.060 shared2_alpha:mean=0.583,std=0.054 shared3_alpha:mean=0.612,std=0.060 eff_mlp_scale:[v0:123.4227 v1:95.2113 v2:98.3279 v3:101.8444 v4:89.6270 v5:109.6534 v6:92.7237 v7:97.1197 v8:97.3535 v9:230.8102] eff_attn_scale:[v0:0.3684 v1:0.7092 v2:0.8543 v3:0.8585 v4:0.8708 v5:0.7595 v6:0.7718 v7:0.6934 v8:0.7554 v9:1.6088] eff_attn_bias:[v0:0.5828 v1:0.7403 
v2:0.6822 v3:0.7292 v4:0.8397 v5:0.7347 v6:0.6905 v7:0.6242 v8:0.5469 v9:0.7900] eff_mlp_bias:[v0:1.5247 v1:0.8839 v2:0.8231 v3:0.7623 v4:0.8342 v5:0.8728 v6:0.7237 v7:0.7568 v8:0.5856 v9:0.5939] +step:5400/20000 val_loss:2.2140 val_bpb:1.3113 train_time:229651ms step_avg:42.53ms +step:5600/20000 train_loss:2.2519 train_time:238123ms step_avg:42.52ms +step:5600 shared0_alpha:mean=0.491,std=0.068 shared1_alpha:mean=0.541,std=0.062 shared2_alpha:mean=0.584,std=0.055 shared3_alpha:mean=0.612,std=0.061 eff_mlp_scale:[v0:125.9326 v1:97.1565 v2:100.1363 v3:103.9969 v4:90.9645 v5:111.7299 v6:94.4876 v7:99.2216 v8:99.8011 v9:234.8268] eff_attn_scale:[v0:0.3537 v1:0.7023 v2:0.8580 v3:0.8810 v4:0.8918 v5:0.7616 v6:0.7893 v7:0.6991 v8:0.7651 v9:1.5759] eff_attn_bias:[v0:0.5911 v1:0.7513 v2:0.6961 v3:0.7403 v4:0.8507 v5:0.7458 v6:0.7043 v7:0.6353 v8:0.5580 v9:0.8065] eff_mlp_bias:[v0:1.5468 v1:0.9005 v2:0.8397 v3:0.7789 v4:0.8563 v5:0.8949 v6:0.7403 v7:0.7734 v8:0.5966 v9:0.6077] +step:5600/20000 val_loss:2.2149 val_bpb:1.3118 train_time:238139ms step_avg:42.52ms +step:5800/20000 train_loss:2.2098 train_time:246627ms step_avg:42.52ms +step:5800 shared0_alpha:mean=0.490,std=0.070 shared1_alpha:mean=0.543,std=0.062 shared2_alpha:mean=0.585,std=0.055 shared3_alpha:mean=0.612,std=0.061 eff_mlp_scale:[v0:128.4237 v1:98.5463 v2:102.0411 v3:105.6529 v4:92.9493 v5:113.2466 v6:95.8254 v7:101.3625 v8:101.8767 v9:238.7473] eff_attn_scale:[v0:0.3503 v1:0.6995 v2:0.8492 v3:0.8687 v4:0.8740 v5:0.7540 v6:0.7667 v7:0.6940 v8:0.7533 v9:1.5771] eff_attn_bias:[v0:0.5994 v1:0.7679 v2:0.7126 v3:0.7568 v4:0.8673 v5:0.7568 v6:0.7126 v7:0.6491 v8:0.5662 v9:0.8231] eff_mlp_bias:[v0:1.5910 v1:0.9170 v2:0.8618 v3:0.7955 v4:0.8728 v5:0.9060 v6:0.7568 v7:0.7900 v8:0.6104 v9:0.6187] +step:5800/20000 val_loss:2.2115 val_bpb:1.3097 train_time:246641ms step_avg:42.52ms +step:6000/20000 train_loss:2.2821 train_time:255117ms step_avg:42.52ms +step:6000 shared0_alpha:mean=0.489,std=0.070 
shared1_alpha:mean=0.544,std=0.062 shared2_alpha:mean=0.585,std=0.055 shared3_alpha:mean=0.611,std=0.060 eff_mlp_scale:[v0:130.8991 v1:100.5298 v2:104.0305 v3:107.8229 v4:94.8494 v5:115.3621 v6:97.7573 v7:103.4884 v8:103.8574 v9:240.9193] eff_attn_scale:[v0:0.3420 v1:0.6882 v2:0.8500 v3:0.8707 v4:0.8543 v5:0.7376 v6:0.7626 v7:0.6909 v8:0.7445 v9:1.5878] eff_attn_bias:[v0:0.6049 v1:0.7844 v2:0.7237 v3:0.7734 v4:0.8839 v5:0.7679 v6:0.7292 v7:0.6602 v8:0.5773 v9:0.8397] eff_mlp_bias:[v0:1.6131 v1:0.9391 v2:0.8839 v3:0.8176 v4:0.8894 v5:0.9226 v6:0.7679 v7:0.8065 v8:0.6215 v9:0.6270] +step:6000/20000 val_loss:2.2073 val_bpb:1.3073 train_time:255129ms step_avg:42.52ms +step:6200/20000 train_loss:2.1528 train_time:263700ms step_avg:42.53ms +step:6200 shared0_alpha:mean=0.488,std=0.070 shared1_alpha:mean=0.545,std=0.062 shared2_alpha:mean=0.586,std=0.056 shared3_alpha:mean=0.611,std=0.061 eff_mlp_scale:[v0:133.2908 v1:102.5332 v2:106.1389 v3:110.1315 v4:96.5427 v5:117.4975 v6:99.8023 v7:105.7481 v8:106.1970 v9:245.0684] eff_attn_scale:[v0:0.3305 v1:0.6880 v2:0.8396 v3:0.8581 v4:0.8719 v5:0.7465 v6:0.7575 v7:0.6846 v8:0.7521 v9:1.5290] eff_attn_bias:[v0:0.6132 v1:0.8010 v2:0.7403 v3:0.7844 v4:0.9005 v5:0.7844 v6:0.7403 v7:0.6712 v8:0.5883 v9:0.8507] eff_mlp_bias:[v0:1.6462 v1:0.9557 v2:0.9005 v3:0.8342 v4:0.9060 v5:0.9391 v6:0.7844 v7:0.8176 v8:0.6353 v9:0.6408] +step:6200/20000 val_loss:2.2080 val_bpb:1.3077 train_time:263716ms step_avg:42.53ms +step:6400/20000 train_loss:2.2315 train_time:272189ms step_avg:42.53ms +step:6400 shared0_alpha:mean=0.488,std=0.072 shared1_alpha:mean=0.547,std=0.063 shared2_alpha:mean=0.587,std=0.056 shared3_alpha:mean=0.611,std=0.062 eff_mlp_scale:[v0:135.7257 v1:104.3874 v2:108.1399 v3:112.2449 v4:98.4547 v5:119.4594 v6:101.2147 v7:107.8215 v8:108.1919 v9:247.0472] eff_attn_scale:[v0:0.3308 v1:0.6883 v2:0.8349 v3:0.8527 v4:0.8622 v5:0.7423 v6:0.7533 v7:0.6887 v8:0.7479 v9:1.5556] eff_attn_bias:[v0:0.6187 v1:0.8121 v2:0.7458 v3:0.8010 
v4:0.9115 v5:0.7955 v6:0.7513 v7:0.6850 v8:0.5994 v9:0.8673] eff_mlp_bias:[v0:1.6794 v1:0.9778 v2:0.9170 v3:0.8507 v4:0.9226 v5:0.9557 v6:0.7955 v7:0.8397 v8:0.6436 v9:0.6491] +step:6400/20000 val_loss:2.2049 val_bpb:1.3059 train_time:272204ms step_avg:42.53ms +step:6600/20000 train_loss:2.1921 train_time:280677ms step_avg:42.53ms +step:6600 shared0_alpha:mean=0.487,std=0.071 shared1_alpha:mean=0.548,std=0.063 shared2_alpha:mean=0.587,std=0.056 shared3_alpha:mean=0.611,std=0.062 eff_mlp_scale:[v0:138.2906 v1:105.7762 v2:110.3418 v3:113.7470 v4:99.9497 v5:121.5301 v6:103.3445 v7:109.8439 v8:110.3270 v9:251.1176] eff_attn_scale:[v0:0.3257 v1:0.6828 v2:0.8431 v3:0.8565 v4:0.8674 v5:0.7412 v6:0.7607 v7:0.6909 v8:0.7523 v9:1.5309] eff_attn_bias:[v0:0.6242 v1:0.8231 v2:0.7568 v3:0.8065 v4:0.9281 v5:0.8065 v6:0.7623 v7:0.6933 v8:0.6104 v9:0.8839] eff_mlp_bias:[v0:1.7015 v1:0.9944 v2:0.9281 v3:0.8673 v4:0.9391 v5:0.9667 v6:0.8121 v7:0.8507 v8:0.6546 v9:0.6602] +step:6600/20000 val_loss:2.2006 val_bpb:1.3033 train_time:280692ms step_avg:42.53ms +step:6800/20000 train_loss:2.2571 train_time:289163ms step_avg:42.52ms +step:6800 shared0_alpha:mean=0.488,std=0.072 shared1_alpha:mean=0.549,std=0.064 shared2_alpha:mean=0.588,std=0.057 shared3_alpha:mean=0.611,std=0.062 eff_mlp_scale:[v0:140.5878 v1:106.9655 v2:111.9076 v3:115.9790 v4:101.4399 v5:122.8122 v6:104.8455 v7:112.0380 v8:112.4660 v9:255.2744] eff_attn_scale:[v0:0.3227 v1:0.6599 v2:0.8301 v3:0.8449 v4:0.8559 v5:0.7171 v6:0.7485 v7:0.6806 v8:0.7459 v9:1.4914] eff_attn_bias:[v0:0.6353 v1:0.8397 v2:0.7734 v3:0.8231 v4:0.9447 v5:0.8231 v6:0.7734 v7:0.7071 v8:0.6215 v9:0.9005] eff_mlp_bias:[v0:1.7346 v1:1.0054 v2:0.9447 v3:0.8839 v4:0.9557 v5:0.9833 v6:0.8231 v7:0.8673 v8:0.6657 v9:0.6712] +step:6800/20000 val_loss:2.2018 val_bpb:1.3040 train_time:289176ms step_avg:42.53ms +step:7000/20000 train_loss:2.2906 train_time:297652ms step_avg:42.52ms +step:7000 shared0_alpha:mean=0.487,std=0.072 shared1_alpha:mean=0.550,std=0.064 
shared2_alpha:mean=0.588,std=0.057 shared3_alpha:mean=0.610,std=0.062 eff_mlp_scale:[v0:142.5612 v1:108.8100 v2:113.5587 v3:118.2004 v4:102.8796 v5:124.7612 v6:106.9756 v7:114.2225 v8:114.5578 v9:259.4164] eff_attn_scale:[v0:0.3150 v1:0.6891 v2:0.8424 v3:0.8635 v4:0.8824 v5:0.7293 v6:0.7601 v7:0.6832 v8:0.7474 v9:1.4954] eff_attn_bias:[v0:0.6436 v1:0.8507 v2:0.7789 v3:0.8342 v4:0.9557 v5:0.8286 v6:0.7844 v7:0.7182 v8:0.6298 v9:0.9115] eff_mlp_bias:[v0:1.7678 v1:1.0220 v2:0.9612 v3:0.8949 v4:0.9667 v5:0.9999 v6:0.8342 v7:0.8784 v8:0.6795 v9:0.6795] +step:7000/20000 val_loss:2.1979 val_bpb:1.3017 train_time:297667ms step_avg:42.52ms +step:7200/20000 train_loss:2.2674 train_time:306138ms step_avg:42.52ms +step:7200 shared0_alpha:mean=0.486,std=0.073 shared1_alpha:mean=0.552,std=0.065 shared2_alpha:mean=0.589,std=0.058 shared3_alpha:mean=0.610,std=0.063 eff_mlp_scale:[v0:144.8833 v1:110.1690 v2:115.5278 v3:119.6396 v4:105.0249 v5:126.8091 v6:108.8946 v7:115.6325 v8:116.8192 v9:261.5519] eff_attn_scale:[v0:0.3159 v1:0.6691 v2:0.8499 v3:0.8750 v4:0.8796 v5:0.7316 v6:0.7771 v7:0.7029 v8:0.7629 v9:1.4777] eff_attn_bias:[v0:0.6491 v1:0.8618 v2:0.7900 v3:0.8507 v4:0.9723 v5:0.8452 v6:0.7955 v7:0.7292 v8:0.6408 v9:0.9226] eff_mlp_bias:[v0:1.7899 v1:1.0386 v2:0.9723 v3:0.9170 v4:0.9833 v5:1.0109 v6:0.8507 v7:0.8894 v8:0.6905 v9:0.6878] +step:7200/20000 val_loss:2.2006 val_bpb:1.3033 train_time:306151ms step_avg:42.52ms +step:7400/20000 train_loss:2.1845 train_time:314623ms step_avg:42.52ms +step:7400 shared0_alpha:mean=0.486,std=0.074 shared1_alpha:mean=0.553,std=0.065 shared2_alpha:mean=0.589,std=0.058 shared3_alpha:mean=0.610,std=0.063 eff_mlp_scale:[v0:147.2649 v1:112.2504 v2:117.0819 v3:122.0544 v4:106.4398 v5:129.0301 v6:110.9491 v7:118.0052 v8:118.8955 v9:265.5865] eff_attn_scale:[v0:0.3044 v1:0.6628 v2:0.8286 v3:0.8592 v4:0.8641 v5:0.7202 v6:0.7520 v7:0.6817 v8:0.7441 v9:1.4811] eff_attn_bias:[v0:0.6574 v1:0.8784 v2:0.8010 v3:0.8618 v4:0.9833 v5:0.8563 v6:0.8065 
v7:0.7403 v8:0.6491 v9:0.9336] eff_mlp_bias:[v0:1.8230 v1:1.0496 v2:0.9944 v3:0.9336 v4:0.9999 v5:1.0275 v6:0.8618 v7:0.9005 v8:0.7016 v9:0.6988] +step:7400/20000 val_loss:2.1954 val_bpb:1.3002 train_time:314637ms step_avg:42.52ms +step:7600/20000 train_loss:2.0705 train_time:323105ms step_avg:42.51ms +step:7600 shared0_alpha:mean=0.485,std=0.074 shared1_alpha:mean=0.554,std=0.065 shared2_alpha:mean=0.589,std=0.058 shared3_alpha:mean=0.609,std=0.063 eff_mlp_scale:[v0:149.5663 v1:113.5419 v2:119.1398 v3:123.6507 v4:107.9163 v5:130.4277 v6:112.3961 v7:120.1512 v8:121.0489 v9:267.8277] eff_attn_scale:[v0:0.3006 v1:0.6627 v2:0.8306 v3:0.8563 v4:0.8709 v5:0.7201 v6:0.7537 v7:0.6832 v8:0.7548 v9:1.4787] eff_attn_bias:[v0:0.6684 v1:0.8839 v2:0.8176 v3:0.8728 v4:0.9944 v5:0.8618 v6:0.8176 v7:0.7513 v8:0.6602 v9:0.9502] eff_mlp_bias:[v0:1.8451 v1:1.0662 v2:1.0054 v3:0.9502 v4:1.0165 v5:1.0386 v6:0.8728 v7:0.9170 v8:0.7126 v9:0.7071] +step:7600/20000 val_loss:2.1938 val_bpb:1.2993 train_time:323121ms step_avg:42.52ms +step:7800/20000 train_loss:2.2106 train_time:331593ms step_avg:42.51ms +step:7800 shared0_alpha:mean=0.485,std=0.074 shared1_alpha:mean=0.555,std=0.065 shared2_alpha:mean=0.590,std=0.059 shared3_alpha:mean=0.608,std=0.064 eff_mlp_scale:[v0:151.9361 v1:115.5007 v2:121.4911 v3:126.0109 v4:110.0907 v5:132.5034 v6:114.6785 v7:122.4779 v8:123.3477 v9:271.9074] eff_attn_scale:[v0:0.3004 v1:0.6730 v2:0.8295 v3:0.8635 v4:0.8729 v5:0.7173 v6:0.7528 v7:0.6936 v8:0.7571 v9:1.4751] eff_attn_bias:[v0:0.6795 v1:0.9005 v2:0.8231 v3:0.8839 v4:1.0109 v5:0.8728 v6:0.8231 v7:0.7623 v8:0.6684 v9:0.9612] eff_mlp_bias:[v0:1.8672 v1:1.0828 v2:1.0165 v3:0.9612 v4:1.0275 v5:1.0551 v6:0.8839 v7:0.9281 v8:0.7237 v9:0.7182] +step:7800/20000 val_loss:2.1912 val_bpb:1.2978 train_time:331606ms step_avg:42.51ms +step:8000/20000 train_loss:2.1795 train_time:340078ms step_avg:42.51ms +step:8000 shared0_alpha:mean=0.485,std=0.075 shared1_alpha:mean=0.558,std=0.065 
shared2_alpha:mean=0.591,std=0.059 shared3_alpha:mean=0.609,std=0.064 eff_mlp_scale:[v0:154.6129 v1:117.5615 v2:123.8094 v3:127.8608 v4:111.6704 v5:134.1028 v6:116.9311 v7:124.8873 v8:125.6292 v9:273.8535] eff_attn_scale:[v0:0.3005 v1:0.6743 v2:0.8341 v3:0.8611 v4:0.8542 v5:0.7190 v6:0.7526 v7:0.6955 v8:0.7492 v9:1.4625] eff_attn_bias:[v0:0.6822 v1:0.9115 v2:0.8342 v3:0.8949 v4:1.0275 v5:0.8839 v6:0.8397 v7:0.7734 v8:0.6795 v9:0.9723] eff_mlp_bias:[v0:1.9003 v1:1.0993 v2:1.0330 v3:0.9833 v4:1.0441 v5:1.0662 v6:0.8949 v7:0.9391 v8:0.7347 v9:0.7237] +step:8000/20000 val_loss:2.1900 val_bpb:1.2970 train_time:340092ms step_avg:42.51ms +step:8200/20000 train_loss:2.2476 train_time:348570ms step_avg:42.51ms +step:8200 shared0_alpha:mean=0.484,std=0.075 shared1_alpha:mean=0.558,std=0.066 shared2_alpha:mean=0.591,std=0.060 shared3_alpha:mean=0.608,std=0.064 eff_mlp_scale:[v0:156.1599 v1:118.9358 v2:125.1992 v3:129.5206 v4:113.8751 v5:136.1815 v6:118.2758 v7:126.5224 v8:127.9628 v9:278.3010] eff_attn_scale:[v0:0.2907 v1:0.6674 v2:0.8398 v3:0.8756 v4:0.8656 v5:0.7160 v6:0.7621 v7:0.6948 v8:0.7598 v9:1.4642] eff_attn_bias:[v0:0.6878 v1:0.9170 v2:0.8452 v3:0.9060 v4:1.0386 v5:0.8949 v6:0.8452 v7:0.7789 v8:0.6878 v9:0.9833] eff_mlp_bias:[v0:1.9224 v1:1.1104 v2:1.0441 v3:0.9944 v4:1.0551 v5:1.0828 v6:0.9115 v7:0.9502 v8:0.7458 v9:0.7292] +step:8200/20000 val_loss:2.1870 val_bpb:1.2953 train_time:348585ms step_avg:42.51ms +step:8400/20000 train_loss:2.1999 train_time:357121ms step_avg:42.51ms +step:8400 shared0_alpha:mean=0.484,std=0.076 shared1_alpha:mean=0.560,std=0.066 shared2_alpha:mean=0.592,std=0.059 shared3_alpha:mean=0.608,std=0.063 eff_mlp_scale:[v0:158.9732 v1:120.2939 v2:127.4437 v3:131.8435 v4:114.6447 v5:138.2483 v6:119.8785 v7:128.8195 v8:129.4185 v9:282.3048] eff_attn_scale:[v0:0.2874 v1:0.6655 v2:0.8374 v3:0.8714 v4:0.8682 v5:0.7139 v6:0.7600 v7:0.6961 v8:0.7579 v9:1.4612] eff_attn_bias:[v0:0.6961 v1:0.9281 v2:0.8563 v3:0.9226 v4:1.0496 v5:0.9060 v6:0.8563 
v7:0.7955 v8:0.6961 v9:0.9944] eff_mlp_bias:[v0:1.9445 v1:1.1270 v2:1.0551 v3:1.0054 v4:1.0717 v5:1.0938 v6:0.9226 v7:0.9667 v8:0.7568 v9:0.7403] +step:8400/20000 val_loss:2.1881 val_bpb:1.2959 train_time:357136ms step_avg:42.52ms +step:8600/20000 train_loss:2.2028 train_time:365598ms step_avg:42.51ms +step:8600 shared0_alpha:mean=0.484,std=0.077 shared1_alpha:mean=0.561,std=0.067 shared2_alpha:mean=0.592,std=0.060 shared3_alpha:mean=0.608,std=0.064 eff_mlp_scale:[v0:160.9742 v1:121.6602 v2:129.0998 v3:134.0649 v4:116.7374 v5:139.7285 v6:122.0579 v7:131.0180 v8:131.6273 v9:284.1735] eff_attn_scale:[v0:0.2908 v1:0.6628 v2:0.8569 v3:0.8717 v4:0.8777 v5:0.7155 v6:0.7693 v7:0.7002 v8:0.7668 v9:1.4600] eff_attn_bias:[v0:0.7071 v1:0.9391 v2:0.8673 v3:0.9336 v4:1.0607 v5:0.9170 v6:0.8728 v7:0.8065 v8:0.7071 v9:1.0054] eff_mlp_bias:[v0:1.9666 v1:1.1380 v2:1.0662 v3:1.0220 v4:1.0883 v5:1.1049 v6:0.9336 v7:0.9778 v8:0.7679 v9:0.7513] +step:8600/20000 val_loss:2.1851 val_bpb:1.2941 train_time:365613ms step_avg:42.51ms +step:8800/20000 train_loss:2.1713 train_time:374088ms step_avg:42.51ms +step:8800 shared0_alpha:mean=0.484,std=0.077 shared1_alpha:mean=0.563,std=0.067 shared2_alpha:mean=0.592,std=0.060 shared3_alpha:mean=0.607,std=0.064 eff_mlp_scale:[v0:163.5798 v1:122.8591 v2:131.2678 v3:135.7807 v4:118.3099 v5:141.6208 v6:124.1722 v7:133.3231 v8:133.9244 v9:288.1069] eff_attn_scale:[v0:0.2836 v1:0.6669 v2:0.8477 v3:0.8815 v4:0.8747 v5:0.7243 v6:0.7605 v7:0.7052 v8:0.7546 v9:1.4490] eff_attn_bias:[v0:0.7126 v1:0.9557 v2:0.8784 v3:0.9447 v4:1.0772 v5:0.9281 v6:0.8784 v7:0.8176 v8:0.7182 v9:1.0220] eff_mlp_bias:[v0:1.9887 v1:1.1546 v2:1.0828 v3:1.0386 v4:1.1049 v5:1.1214 v6:0.9447 v7:0.9888 v8:0.7789 v9:0.7568] +step:8800/20000 val_loss:2.1837 val_bpb:1.2933 train_time:374102ms step_avg:42.51ms +step:9000/20000 train_loss:2.0897 train_time:382567ms step_avg:42.51ms +step:9000 shared0_alpha:mean=0.483,std=0.078 shared1_alpha:mean=0.564,std=0.067 
shared2_alpha:mean=0.593,std=0.060 shared3_alpha:mean=0.607,std=0.064 eff_mlp_scale:[v0:166.3799 v1:124.9419 v2:132.9020 v3:138.2411 v4:119.9367 v5:143.8356 v6:125.7503 v7:135.7614 v8:136.2917 v9:292.3562] eff_attn_scale:[v0:0.2862 v1:0.6707 v2:0.8447 v3:0.8961 v4:0.8873 v5:0.7373 v6:0.7621 v7:0.7082 v8:0.7751 v9:1.4701] eff_attn_bias:[v0:0.7237 v1:0.9667 v2:0.8894 v3:0.9557 v4:1.0883 v5:0.9391 v6:0.8949 v7:0.8286 v8:0.7292 v9:1.0330] eff_mlp_bias:[v0:2.0108 v1:1.1656 v2:1.0993 v3:1.0496 v4:1.1159 v5:1.1325 v6:0.9557 v7:0.9999 v8:0.7844 v9:0.7679] +step:9000/20000 val_loss:2.1839 val_bpb:1.2934 train_time:382582ms step_avg:42.51ms +step:9200/20000 train_loss:2.1529 train_time:391051ms step_avg:42.51ms +step:9200 shared0_alpha:mean=0.483,std=0.078 shared1_alpha:mean=0.566,std=0.067 shared2_alpha:mean=0.594,std=0.061 shared3_alpha:mean=0.607,std=0.064 eff_mlp_scale:[v0:168.1808 v1:126.9594 v2:135.1729 v3:139.9743 v4:122.0191 v5:145.3593 v6:127.9637 v7:138.0996 v8:138.4917 v9:296.4727] eff_attn_scale:[v0:0.2805 v1:0.6722 v2:0.8589 v3:0.8891 v4:0.8906 v5:0.7164 v6:0.7711 v7:0.7170 v8:0.7738 v9:1.4566] eff_attn_bias:[v0:0.7292 v1:0.9778 v2:0.9005 v3:0.9612 v4:1.0993 v5:0.9502 v6:0.9060 v7:0.8342 v8:0.7347 v9:1.0441] eff_mlp_bias:[v0:2.0329 v1:1.1822 v2:1.1159 v3:1.0662 v4:1.1325 v5:1.1435 v6:0.9667 v7:1.0109 v8:0.7955 v9:0.7734] +step:9200/20000 val_loss:2.1820 val_bpb:1.2923 train_time:391067ms step_avg:42.51ms +step:9400/20000 train_loss:2.2031 train_time:399551ms step_avg:42.51ms +step:9400 shared0_alpha:mean=0.483,std=0.078 shared1_alpha:mean=0.567,std=0.067 shared2_alpha:mean=0.594,std=0.061 shared3_alpha:mean=0.606,std=0.065 eff_mlp_scale:[v0:170.1099 v1:127.7997 v2:136.6171 v3:142.2142 v4:123.5960 v5:147.5562 v6:129.3631 v7:140.3264 v8:140.8133 v9:298.2728] eff_attn_scale:[v0:0.2773 v1:0.6657 v2:0.8483 v3:0.8968 v4:0.8984 v5:0.7186 v6:0.7800 v7:0.7223 v8:0.7855 v9:1.4635] eff_attn_bias:[v0:0.7347 v1:0.9888 v2:0.9060 v3:0.9778 v4:1.1104 v5:0.9612 v6:0.9170 
v7:0.8452 v8:0.7458 v9:1.0551] eff_mlp_bias:[v0:2.0550 v1:1.1877 v2:1.1214 v3:1.0772 v4:1.1490 v5:1.1546 v6:0.9778 v7:1.0220 v8:0.8065 v9:0.7844] +step:9400/20000 val_loss:2.1811 val_bpb:1.2918 train_time:399564ms step_avg:42.51ms +step:9600/20000 train_loss:2.2155 train_time:408036ms step_avg:42.50ms +step:9600 shared0_alpha:mean=0.482,std=0.079 shared1_alpha:mean=0.568,std=0.068 shared2_alpha:mean=0.595,std=0.061 shared3_alpha:mean=0.605,std=0.065 eff_mlp_scale:[v0:171.7990 v1:129.1932 v2:138.2863 v3:143.8646 v4:124.5939 v5:149.6903 v6:131.5852 v7:142.5970 v8:143.1900 v9:302.4401] eff_attn_scale:[v0:0.2770 v1:0.6554 v2:0.8582 v3:0.8843 v4:0.9023 v5:0.7170 v6:0.7709 v7:0.7170 v8:0.7804 v9:1.4496] eff_attn_bias:[v0:0.7458 v1:1.0054 v2:0.9226 v3:0.9888 v4:1.1214 v5:0.9778 v6:0.9281 v7:0.8563 v8:0.7513 v9:1.0607] eff_mlp_bias:[v0:2.0661 v1:1.2043 v2:1.1325 v3:1.0883 v4:1.1601 v5:1.1656 v6:0.9888 v7:1.0330 v8:0.8176 v9:0.7900] +step:9600/20000 val_loss:2.1812 val_bpb:1.2918 train_time:408049ms step_avg:42.51ms +step:9800/20000 train_loss:2.1458 train_time:416519ms step_avg:42.50ms +step:9800 shared0_alpha:mean=0.482,std=0.080 shared1_alpha:mean=0.570,std=0.069 shared2_alpha:mean=0.595,std=0.061 shared3_alpha:mean=0.606,std=0.065 eff_mlp_scale:[v0:174.4554 v1:131.2239 v2:140.3382 v3:145.6686 v4:126.9126 v5:151.2199 v6:133.5970 v7:145.0298 v8:145.6682 v9:304.0365] eff_attn_scale:[v0:0.2633 v1:0.6795 v2:0.8484 v3:0.8909 v4:0.9139 v5:0.7420 v6:0.7757 v7:0.7224 v8:0.7960 v9:1.4157] eff_attn_bias:[v0:0.7513 v1:1.0165 v2:0.9336 v3:0.9944 v4:1.1380 v5:0.9888 v6:0.9391 v7:0.8673 v8:0.7679 v9:1.0717] eff_mlp_bias:[v0:2.0882 v1:1.2209 v2:1.1490 v3:1.0993 v4:1.1711 v5:1.1767 v6:0.9999 v7:1.0441 v8:0.8286 v9:0.8010] +step:9800/20000 val_loss:2.1818 val_bpb:1.2922 train_time:416534ms step_avg:42.50ms +step:10000/20000 train_loss:2.1822 train_time:425002ms step_avg:42.50ms +step:10000 shared0_alpha:mean=0.483,std=0.080 shared1_alpha:mean=0.572,std=0.069 
shared2_alpha:mean=0.596,std=0.062 shared3_alpha:mean=0.605,std=0.065 eff_mlp_scale:[v0:177.2707 v1:132.6599 v2:142.8461 v3:147.9492 v4:128.4827 v5:153.4077 v6:135.4255 v7:146.6626 v8:148.0071 v9:306.1425] eff_attn_scale:[v0:0.2687 v1:0.6843 v2:0.8721 v3:0.9221 v4:0.9368 v5:0.7473 v6:0.7883 v7:0.7416 v8:0.8122 v9:1.4451] eff_attn_bias:[v0:0.7568 v1:1.0275 v2:0.9447 v3:1.0109 v4:1.1546 v5:0.9999 v6:0.9502 v7:0.8784 v8:0.7734 v9:1.0883] eff_mlp_bias:[v0:2.1103 v1:1.2374 v2:1.1656 v3:1.1159 v4:1.1877 v5:1.1877 v6:1.0109 v7:1.0496 v8:0.8342 v9:0.8121] +step:10000/20000 val_loss:2.1807 val_bpb:1.2916 train_time:425017ms step_avg:42.50ms +step:10200/20000 train_loss:2.1343 train_time:433485ms step_avg:42.50ms +step:10200 shared0_alpha:mean=0.482,std=0.081 shared1_alpha:mean=0.573,std=0.069 shared2_alpha:mean=0.596,std=0.062 shared3_alpha:mean=0.604,std=0.065 eff_mlp_scale:[v0:178.9596 v1:134.0866 v2:145.2686 v3:150.6213 v4:130.1080 v5:154.9586 v6:137.7870 v7:149.3228 v8:150.4175 v9:310.0467] eff_attn_scale:[v0:0.2685 v1:0.6722 v2:0.8736 v3:0.9179 v4:0.9252 v5:0.7345 v6:0.7950 v7:0.7285 v8:0.8015 v9:1.4454] eff_attn_bias:[v0:0.7623 v1:1.0441 v2:0.9612 v3:1.0220 v4:1.1656 v5:1.0109 v6:0.9612 v7:0.8894 v8:0.7844 v9:1.0938] eff_mlp_bias:[v0:2.1324 v1:1.2485 v2:1.1767 v3:1.1270 v4:1.2043 v5:1.1988 v6:1.0220 v7:1.0662 v8:0.8452 v9:0.8176] +step:10200/20000 val_loss:2.1761 val_bpb:1.2888 train_time:433499ms step_avg:42.50ms +step:10400/20000 train_loss:2.1704 train_time:441971ms step_avg:42.50ms +step:10400 shared0_alpha:mean=0.482,std=0.081 shared1_alpha:mean=0.575,std=0.070 shared2_alpha:mean=0.597,std=0.062 shared3_alpha:mean=0.604,std=0.066 eff_mlp_scale:[v0:181.8062 v1:135.6558 v2:146.8773 v3:152.4303 v4:131.7468 v5:157.3098 v6:139.3451 v7:151.7761 v8:152.8519 v9:314.1966] eff_attn_scale:[v0:0.2651 v1:0.6849 v2:0.8717 v3:0.9236 v4:0.9328 v5:0.7297 v6:0.7979 v7:0.7428 v8:0.8131 v9:1.4456] eff_attn_bias:[v0:0.7734 v1:1.0551 v2:0.9667 v3:1.0386 v4:1.1767 v5:1.0220 v6:0.9778 
v7:0.9005 v8:0.7955 v9:1.1049] eff_mlp_bias:[v0:2.1545 v1:1.2595 v2:1.1932 v3:1.1435 v4:1.2153 v5:1.2098 v6:1.0330 v7:1.0717 v8:0.8563 v9:0.8231] +step:10400/20000 val_loss:2.1756 val_bpb:1.2885 train_time:441985ms step_avg:42.50ms +step:10600/20000 train_loss:2.0433 train_time:450458ms step_avg:42.50ms +step:10600 shared0_alpha:mean=0.481,std=0.082 shared1_alpha:mean=0.576,std=0.070 shared2_alpha:mean=0.597,std=0.062 shared3_alpha:mean=0.604,std=0.065 eff_mlp_scale:[v0:183.3548 v1:136.8969 v2:147.9396 v3:154.2241 v4:133.9851 v5:158.6468 v6:141.6174 v7:154.2241 v8:155.2424 v9:316.0419] eff_attn_scale:[v0:0.2655 v1:0.6733 v2:0.8825 v3:0.9192 v4:0.9386 v5:0.7358 v6:0.7932 v7:0.7422 v8:0.8181 v9:1.4558] eff_attn_bias:[v0:0.7789 v1:1.0662 v2:0.9778 v3:1.0441 v4:1.1877 v5:1.0330 v6:0.9833 v7:0.9115 v8:0.8010 v9:1.1159] eff_mlp_bias:[v0:2.1655 v1:1.2706 v2:1.2043 v3:1.1546 v4:1.2264 v5:1.2209 v6:1.0441 v7:1.0828 v8:0.8673 v9:0.8286] +step:10600/20000 val_loss:2.1758 val_bpb:1.2886 train_time:450471ms step_avg:42.50ms +step:10800/20000 train_loss:2.2536 train_time:458946ms step_avg:42.50ms +step:10800 shared0_alpha:mean=0.481,std=0.081 shared1_alpha:mean=0.577,std=0.071 shared2_alpha:mean=0.597,std=0.063 shared3_alpha:mean=0.603,std=0.065 eff_mlp_scale:[v0:185.7749 v1:138.3009 v2:150.7885 v3:156.0175 v4:135.5514 v5:160.8150 v6:143.1536 v7:156.0175 v8:156.9543 v9:317.8026] eff_attn_scale:[v0:0.2591 v1:0.6667 v2:0.8858 v3:0.9139 v4:0.9381 v5:0.7289 v6:0.8017 v7:0.7340 v8:0.8091 v9:1.4577] eff_attn_bias:[v0:0.7789 v1:1.0772 v2:0.9888 v3:1.0551 v4:1.1988 v5:1.0441 v6:0.9944 v7:0.9226 v8:0.8121 v9:1.1214] eff_mlp_bias:[v0:2.1876 v1:1.2816 v2:1.2153 v3:1.1656 v4:1.2430 v5:1.2319 v6:1.0496 v7:1.0938 v8:0.8728 v9:0.8397] +step:10800/20000 val_loss:2.1749 val_bpb:1.2881 train_time:458961ms step_avg:42.50ms +step:11000/20000 train_loss:2.1828 train_time:467423ms step_avg:42.49ms +step:11000 shared0_alpha:mean=0.481,std=0.082 shared1_alpha:mean=0.579,std=0.071 
shared2_alpha:mean=0.598,std=0.063 shared3_alpha:mean=0.603,std=0.065 eff_mlp_scale:[v0:187.9642 v1:140.4087 v2:152.4396 v3:158.1533 v4:136.9488 v5:162.4082 v6:145.3941 v7:158.8235 v8:159.1215 v9:321.8124] eff_attn_scale:[v0:0.2627 v1:0.6835 v2:0.8971 v3:0.9595 v4:0.9544 v5:0.7415 v6:0.8169 v7:0.7656 v8:0.8288 v9:1.4696] eff_attn_bias:[v0:0.7900 v1:1.0883 v2:0.9999 v3:1.0717 v4:1.2043 v5:1.0551 v6:1.0054 v7:0.9336 v8:0.8176 v9:1.1325] eff_mlp_bias:[v0:2.1987 v1:1.2982 v2:1.2264 v3:1.1767 v4:1.2540 v5:1.2430 v6:1.0607 v7:1.0993 v8:0.8839 v9:0.8452] +step:11000/20000 val_loss:2.1732 val_bpb:1.2871 train_time:467436ms step_avg:42.49ms +step:11200/20000 train_loss:2.1371 train_time:475904ms step_avg:42.49ms +step:11200 shared0_alpha:mean=0.481,std=0.082 shared1_alpha:mean=0.580,std=0.071 shared2_alpha:mean=0.598,std=0.063 shared3_alpha:mean=0.602,std=0.066 eff_mlp_scale:[v0:189.7691 v1:141.8777 v2:154.2266 v3:160.6670 v4:139.3189 v5:164.0054 v6:147.1283 v7:161.3421 v8:161.6625 v9:323.6354] eff_attn_scale:[v0:0.2565 v1:0.6806 v2:0.9102 v3:0.9416 v4:0.9501 v5:0.7388 v6:0.8101 v7:0.7642 v8:0.8301 v9:1.4622] eff_attn_bias:[v0:0.7955 v1:1.0993 v2:1.0165 v3:1.0828 v4:1.2209 v5:1.0662 v6:1.0165 v7:0.9447 v8:0.8286 v9:1.1435] eff_mlp_bias:[v0:2.2208 v1:1.3093 v2:1.2374 v3:1.1877 v4:1.2651 v5:1.2540 v6:1.0662 v7:1.1104 v8:0.8949 v9:0.8563] +step:11200/20000 val_loss:2.1721 val_bpb:1.2864 train_time:475917ms step_avg:42.49ms +step:11400/20000 train_loss:2.1227 train_time:484377ms step_avg:42.49ms +step:11400 shared0_alpha:mean=0.481,std=0.083 shared1_alpha:mean=0.581,std=0.071 shared2_alpha:mean=0.599,std=0.063 shared3_alpha:mean=0.602,std=0.065 eff_mlp_scale:[v0:192.3726 v1:143.2182 v2:156.7461 v3:163.0905 v4:140.2192 v5:166.1070 v6:149.5917 v7:163.7701 v8:164.6914 v9:325.2281] eff_attn_scale:[v0:0.2554 v1:0.6928 v2:0.9116 v3:0.9574 v4:0.9693 v5:0.7426 v6:0.8209 v7:0.7639 v8:0.8381 v9:1.4437] eff_attn_bias:[v0:0.8010 v1:1.1104 v2:1.0275 v3:1.0883 v4:1.2319 v5:1.0828 v6:1.0275 
v7:0.9612 v8:0.8397 v9:1.1546] eff_mlp_bias:[v0:2.2318 v1:1.3203 v2:1.2485 v3:1.1988 v4:1.2816 v5:1.2595 v6:1.0772 v7:1.1214 v8:0.9005 v9:0.8618] +step:11400/20000 val_loss:2.1726 val_bpb:1.2867 train_time:484392ms step_avg:42.49ms +step:11600/20000 train_loss:2.1330 train_time:492865ms step_avg:42.49ms +step:11600 shared0_alpha:mean=0.481,std=0.083 shared1_alpha:mean=0.583,std=0.072 shared2_alpha:mean=0.599,std=0.064 shared3_alpha:mean=0.602,std=0.066 eff_mlp_scale:[v0:194.9825 v1:144.6659 v2:158.2664 v3:164.3005 v4:141.8622 v5:168.3386 v6:151.0725 v7:166.3543 v8:166.5049 v9:329.2878] eff_attn_scale:[v0:0.2562 v1:0.6943 v2:0.9088 v3:0.9578 v4:0.9790 v5:0.7536 v6:0.8239 v7:0.7642 v8:0.8471 v9:1.4596] eff_attn_bias:[v0:0.8065 v1:1.1270 v2:1.0330 v3:1.0993 v4:1.2430 v5:1.0938 v6:1.0386 v7:0.9723 v8:0.8507 v9:1.1601] eff_mlp_bias:[v0:2.2539 v1:1.3313 v2:1.2595 v3:1.2153 v4:1.2927 v5:1.2706 v6:1.0883 v7:1.1325 v8:0.9115 v9:0.8673] +step:11600/20000 val_loss:2.1716 val_bpb:1.2861 train_time:492878ms step_avg:42.49ms +step:11800/20000 train_loss:2.1629 train_time:501341ms step_avg:42.49ms +step:11800 shared0_alpha:mean=0.481,std=0.084 shared1_alpha:mean=0.584,std=0.072 shared2_alpha:mean=0.600,std=0.064 shared3_alpha:mean=0.601,std=0.066 eff_mlp_scale:[v0:196.3139 v1:146.0694 v2:160.2879 v3:167.0083 v4:144.4191 v5:170.5244 v6:153.6917 v7:168.3886 v8:169.2726 v9:333.1068] eff_attn_scale:[v0:0.2556 v1:0.6956 v2:0.9245 v3:0.9638 v4:0.9823 v5:0.7638 v6:0.8331 v7:0.7740 v8:0.8595 v9:1.4539] eff_attn_bias:[v0:0.8176 v1:1.1380 v2:1.0441 v3:1.1104 v4:1.2595 v5:1.1049 v6:1.0441 v7:0.9833 v8:0.8563 v9:1.1711] eff_mlp_bias:[v0:2.2760 v1:1.3424 v2:1.2761 v3:1.2264 v4:1.3093 v5:1.2816 v6:1.0993 v7:1.1435 v8:0.9226 v9:0.8784] +step:11800/20000 val_loss:2.1690 val_bpb:1.2846 train_time:501355ms step_avg:42.49ms +step:12000/20000 train_loss:2.1383 train_time:509822ms step_avg:42.49ms +step:12000 shared0_alpha:mean=0.480,std=0.085 shared1_alpha:mean=0.586,std=0.072 
shared2_alpha:mean=0.599,std=0.064 shared3_alpha:mean=0.600,std=0.066 eff_mlp_scale:[v0:198.7707 v1:147.5491 v2:161.8137 v3:168.2198 v4:145.2074 v5:171.4760 v6:155.1819 v7:171.0003 v8:170.8720 v9:334.6900] eff_attn_scale:[v0:0.2494 v1:0.6875 v2:0.9460 v3:0.9727 v4:0.9919 v5:0.7504 v6:0.8392 v7:0.7822 v8:0.8596 v9:1.4289] eff_attn_bias:[v0:0.8176 v1:1.1490 v2:1.0551 v3:1.1214 v4:1.2706 v5:1.1159 v6:1.0551 v7:0.9888 v8:0.8673 v9:1.1822] eff_mlp_bias:[v0:2.2870 v1:1.3534 v2:1.2872 v3:1.2374 v4:1.3203 v5:1.2927 v6:1.1049 v7:1.1490 v8:0.9281 v9:0.8839] +step:12000/20000 val_loss:2.1682 val_bpb:1.2841 train_time:509836ms step_avg:42.49ms +step:12200/20000 train_loss:2.2803 train_time:518305ms step_avg:42.48ms +step:12200 shared0_alpha:mean=0.480,std=0.085 shared1_alpha:mean=0.588,std=0.072 shared2_alpha:mean=0.600,std=0.064 shared3_alpha:mean=0.600,std=0.066 eff_mlp_scale:[v0:201.4770 v1:148.9287 v2:163.5850 v3:170.7183 v4:146.9012 v5:173.6388 v6:157.5757 v7:173.5170 v8:173.4251 v9:336.4637] eff_attn_scale:[v0:0.2493 v1:0.6969 v2:0.9369 v3:0.9748 v4:0.9996 v5:0.7463 v6:0.8407 v7:0.7889 v8:0.8663 v9:1.4636] eff_attn_bias:[v0:0.8231 v1:1.1601 v2:1.0662 v3:1.1325 v4:1.2816 v5:1.1270 v6:1.0662 v7:0.9999 v8:0.8784 v9:1.1932] eff_mlp_bias:[v0:2.2981 v1:1.3645 v2:1.2982 v3:1.2485 v4:1.3313 v5:1.3037 v6:1.1104 v7:1.1601 v8:0.9336 v9:0.8894] +step:12200/20000 val_loss:2.1679 val_bpb:1.2840 train_time:518320ms step_avg:42.49ms +step:12400/20000 train_loss:1.9281 train_time:526837ms step_avg:42.49ms +step:12400 shared0_alpha:mean=0.481,std=0.085 shared1_alpha:mean=0.589,std=0.073 shared2_alpha:mean=0.601,std=0.065 shared3_alpha:mean=0.599,std=0.066 eff_mlp_scale:[v0:202.6757 v1:150.3434 v2:165.9824 v3:173.3130 v4:149.3008 v5:175.8480 v6:159.9344 v7:176.1311 v8:176.6955 v9:340.5024] eff_attn_scale:[v0:0.2523 v1:0.7055 v2:0.9441 v3:0.9888 v4:1.0126 v5:0.7646 v6:0.8569 v7:0.7921 v8:0.8731 v9:1.4610] eff_attn_bias:[v0:0.8286 v1:1.1711 v2:1.0772 v3:1.1490 v4:1.2927 v5:1.1325 v6:1.0772 
v7:1.0165 v8:0.8839 v9:1.1988] eff_mlp_bias:[v0:2.3202 v1:1.3755 v2:1.3093 v3:1.2595 v4:1.3479 v5:1.3148 v6:1.1214 v7:1.1711 v8:0.9447 v9:0.9005] +step:12400/20000 val_loss:2.1679 val_bpb:1.2840 train_time:526851ms step_avg:42.49ms +step:12600/20000 train_loss:2.1625 train_time:535327ms step_avg:42.49ms +step:12600 shared0_alpha:mean=0.480,std=0.085 shared1_alpha:mean=0.591,std=0.073 shared2_alpha:mean=0.602,std=0.065 shared3_alpha:mean=0.599,std=0.067 eff_mlp_scale:[v0:205.5572 v1:151.9222 v2:168.0296 v3:175.3704 v4:150.3135 v5:178.2554 v6:161.9318 v7:178.9204 v8:179.2730 v9:342.1450] eff_attn_scale:[v0:0.2487 v1:0.7194 v2:0.9572 v3:1.0003 v4:1.0228 v5:0.7744 v6:0.8595 v7:0.8063 v8:0.8819 v9:1.4682] eff_attn_bias:[v0:0.8342 v1:1.1822 v2:1.0828 v3:1.1601 v4:1.3093 v5:1.1435 v6:1.0883 v7:1.0275 v8:0.8894 v9:1.2098] eff_mlp_bias:[v0:2.3312 v1:1.3866 v2:1.3203 v3:1.2706 v4:1.3590 v5:1.3258 v6:1.1380 v7:1.1822 v8:0.9502 v9:0.9060] +step:12600/20000 val_loss:2.1689 val_bpb:1.2845 train_time:535342ms step_avg:42.49ms +step:12800/20000 train_loss:2.1800 train_time:543796ms step_avg:42.48ms +step:12800 shared0_alpha:mean=0.481,std=0.086 shared1_alpha:mean=0.592,std=0.074 shared2_alpha:mean=0.602,std=0.065 shared3_alpha:mean=0.598,std=0.066 eff_mlp_scale:[v0:206.7098 v1:153.3298 v2:170.5374 v3:177.4912 v4:152.1298 v5:179.1110 v6:164.3981 v7:181.0696 v8:182.0001 v9:346.1584] eff_attn_scale:[v0:0.2457 v1:0.7043 v2:0.9714 v3:1.0048 v4:1.0228 v5:0.7684 v6:0.8675 v7:0.8110 v8:0.8872 v9:1.4530] eff_attn_bias:[v0:0.8397 v1:1.1932 v2:1.0883 v3:1.1711 v4:1.3203 v5:1.1601 v6:1.1049 v7:1.0441 v8:0.9005 v9:1.2153] eff_mlp_bias:[v0:2.3423 v1:1.3976 v2:1.3313 v3:1.2816 v4:1.3755 v5:1.3369 v6:1.1490 v7:1.1932 v8:0.9667 v9:0.9115] +step:12800/20000 val_loss:2.1671 val_bpb:1.2835 train_time:543812ms step_avg:42.49ms +step:13000/20000 train_loss:2.2595 train_time:552284ms step_avg:42.48ms +step:13000 shared0_alpha:mean=0.480,std=0.086 shared1_alpha:mean=0.593,std=0.074 
shared2_alpha:mean=0.602,std=0.065 shared3_alpha:mean=0.597,std=0.067 eff_mlp_scale:[v0:209.5090 v1:154.8514 v2:172.3224 v3:179.4544 v4:154.6274 v5:181.4558 v6:166.8301 v7:183.7786 v8:183.3140 v9:347.7237] eff_attn_scale:[v0:0.2436 v1:0.7115 v2:0.9804 v3:1.0115 v4:1.0333 v5:0.7708 v6:0.8813 v7:0.8215 v8:0.8917 v9:1.4966] eff_attn_bias:[v0:0.8452 v1:1.2043 v2:1.1049 v3:1.1822 v4:1.3258 v5:1.1711 v6:1.1104 v7:1.0496 v8:0.9060 v9:1.2209] eff_mlp_bias:[v0:2.3644 v1:1.4087 v2:1.3424 v3:1.2982 v4:1.3811 v5:1.3424 v6:1.1601 v7:1.2043 v8:0.9723 v9:0.9170] +step:13000/20000 val_loss:2.1662 val_bpb:1.2829 train_time:552297ms step_avg:42.48ms +step:13200/20000 train_loss:2.2642 train_time:560758ms step_avg:42.48ms +step:13200 shared0_alpha:mean=0.479,std=0.086 shared1_alpha:mean=0.594,std=0.074 shared2_alpha:mean=0.602,std=0.065 shared3_alpha:mean=0.596,std=0.066 eff_mlp_scale:[v0:210.2067 v1:156.1116 v2:173.6324 v3:181.1249 v4:156.1897 v5:183.4995 v6:168.1202 v7:185.4719 v8:185.7391 v9:349.7395] eff_attn_scale:[v0:0.2425 v1:0.7174 v2:0.9833 v3:1.0273 v4:1.0480 v5:0.7822 v6:0.8844 v7:0.8312 v8:0.9058 v9:1.4839] eff_attn_bias:[v0:0.8507 v1:1.2098 v2:1.1159 v3:1.1877 v4:1.3369 v5:1.1822 v6:1.1159 v7:1.0607 v8:0.9170 v9:1.2319] eff_mlp_bias:[v0:2.3754 v1:1.4253 v2:1.3534 v3:1.3037 v4:1.3921 v5:1.3534 v6:1.1656 v7:1.2098 v8:0.9778 v9:0.9281] +step:13200/20000 val_loss:2.1547 val_bpb:1.2761 train_time:560773ms step_avg:42.48ms +step:13400/20000 train_loss:2.1306 train_time:569249ms step_avg:42.48ms +step:13400 shared0_alpha:mean=0.479,std=0.086 shared1_alpha:mean=0.594,std=0.075 shared2_alpha:mean=0.601,std=0.065 shared3_alpha:mean=0.596,std=0.066 eff_mlp_scale:[v0:211.1086 v1:156.7696 v2:174.3620 v3:182.3893 v4:157.0688 v5:184.2730 v6:168.8267 v7:188.2258 v8:188.1996 v9:354.3450] eff_attn_scale:[v0:0.2455 v1:0.7267 v2:0.9914 v3:1.0432 v4:1.0631 v5:0.7919 v6:0.8918 v7:0.8398 v8:0.9196 v9:1.4994] eff_attn_bias:[v0:0.8563 v1:1.2153 v2:1.1214 v3:1.1988 v4:1.3424 v5:1.1877 v6:1.1270 
v7:1.0662 v8:0.9170 v9:1.2319] eff_mlp_bias:[v0:2.3754 v1:1.4253 v2:1.3590 v3:1.3093 v4:1.3976 v5:1.3534 v6:1.1711 v7:1.2153 v8:0.9833 v9:0.9336] +step:13400/20000 val_loss:2.1475 val_bpb:1.2718 train_time:569262ms step_avg:42.48ms +step:13600/20000 train_loss:1.9936 train_time:577733ms step_avg:42.48ms +step:13600 shared0_alpha:mean=0.478,std=0.086 shared1_alpha:mean=0.593,std=0.074 shared2_alpha:mean=0.600,std=0.065 shared3_alpha:mean=0.595,std=0.066 eff_mlp_scale:[v0:213.1614 v1:156.9691 v2:175.8217 v3:183.0316 v4:157.6895 v5:184.5075 v6:169.5672 v7:188.8886 v8:188.9433 v9:356.3730] eff_attn_scale:[v0:0.2452 v1:0.7320 v2:1.0023 v3:1.0541 v4:1.0794 v5:0.7879 v6:0.9015 v7:0.8444 v8:0.9244 v9:1.5127] eff_attn_bias:[v0:0.8563 v1:1.2153 v2:1.1270 v3:1.2043 v4:1.3479 v5:1.1932 v6:1.1325 v7:1.0662 v8:0.9226 v9:1.2374] eff_mlp_bias:[v0:2.3754 v1:1.4253 v2:1.3590 v3:1.3148 v4:1.4032 v5:1.3590 v6:1.1767 v7:1.2153 v8:0.9888 v9:0.9391] +step:13600/20000 val_loss:2.1417 val_bpb:1.2684 train_time:577746ms step_avg:42.48ms +step:13800/20000 train_loss:2.0734 train_time:586215ms step_avg:42.48ms +step:13800 shared0_alpha:mean=0.477,std=0.086 shared1_alpha:mean=0.593,std=0.074 shared2_alpha:mean=0.600,std=0.065 shared3_alpha:mean=0.594,std=0.066 eff_mlp_scale:[v0:213.3222 v1:157.2034 v2:176.0597 v3:183.3946 v4:158.1785 v5:184.7830 v6:169.7967 v7:189.2633 v8:189.5292 v9:357.9752] eff_attn_scale:[v0:0.2460 v1:0.7284 v2:0.9978 v3:1.0403 v4:1.0704 v5:0.7891 v6:0.8970 v7:0.8417 v8:0.9213 v9:1.5100] eff_attn_bias:[v0:0.8563 v1:1.2153 v2:1.1270 v3:1.2043 v4:1.3479 v5:1.1932 v6:1.1325 v7:1.0662 v8:0.9226 v9:1.2374] eff_mlp_bias:[v0:2.3754 v1:1.4253 v2:1.3590 v3:1.3148 v4:1.4032 v5:1.3590 v6:1.1767 v7:1.2153 v8:0.9888 v9:0.9447] +step:13800/20000 val_loss:2.1306 val_bpb:1.2618 train_time:586229ms step_avg:42.48ms +step:14000/20000 train_loss:2.1271 train_time:594686ms step_avg:42.48ms +step:14000 shared0_alpha:mean=0.477,std=0.085 shared1_alpha:mean=0.592,std=0.074 
shared2_alpha:mean=0.599,std=0.065 shared3_alpha:mean=0.594,std=0.066 eff_mlp_scale:[v0:213.2602 v1:157.2287 v2:176.0651 v3:183.5446 v4:158.3914 v5:184.8127 v6:169.8019 v7:189.4180 v8:189.7843 v9:358.9992] eff_attn_scale:[v0:0.2431 v1:0.7220 v2:0.9900 v3:1.0454 v4:1.0810 v5:0.7825 v6:0.8894 v7:0.8458 v8:0.9258 v9:1.5127] eff_attn_bias:[v0:0.8563 v1:1.2153 v2:1.1325 v3:1.2043 v4:1.3479 v5:1.1932 v6:1.1325 v7:1.0662 v8:0.9226 v9:1.2374] eff_mlp_bias:[v0:2.3754 v1:1.4253 v2:1.3590 v3:1.3148 v4:1.4032 v5:1.3645 v6:1.1767 v7:1.2153 v8:0.9888 v9:0.9447] +step:14000/20000 val_loss:2.1234 val_bpb:1.2576 train_time:594701ms step_avg:42.48ms +step:14125/20000 val_loss:2.1198 val_bpb:1.2555 train_time:600019ms step_avg:42.48ms +stopping_early: wallclock_cap train_time:600019ms step:14125/20000 +peak memory allocated: 9924 MiB reserved: 10358 MiB +Serialized model: 45191996 bytes +Code size: 66161 bytes +Total submission size: 45258157 bytes +Serialized model int8+zlib: 10733408 bytes (payload:11651264 raw_torch:11683059 payload_ratio:3.88x) +Total submission size int8+zlib: 10799569 bytes +final_int8_zlib_roundtrip val_loss:2.1344 val_bpb:1.2641 eval_time:1343ms +final_int8_zlib_roundtrip_exact val_loss:2.13444568 val_bpb:1.26413896 diff --git a/records/track_non_record_16mb/2026-03-26_RecurrenceFix_3Loop_Birkhoff_OutputLN_TimestepScale/logs/s3_O.txt b/records/track_non_record_16mb/2026-03-26_RecurrenceFix_3Loop_Birkhoff_OutputLN_TimestepScale/logs/s3_O.txt new file mode 100644 index 0000000000..5a434099e4 --- /dev/null +++ b/records/track_non_record_16mb/2026-03-26_RecurrenceFix_3Loop_Birkhoff_OutputLN_TimestepScale/logs/s3_O.txt @@ -0,0 +1,1771 @@ +""" +The `train_gpt.py` and `train_gpt_mlx.py` scripts are intended as good launching-off points for new participants, not SOTA configs. We'll accept PRs that tune, improve, or simplify these scripts without significantly increasing complexity, but competitive submissions should stay in the `/records` folder. 
+ +Hard stop: To keep readable for newcomers, let's make sure `train_gpt.py` and `train_gpt_mlx.py` never are longer than 1500 lines. +""" + +from __future__ import annotations + +import copy +import glob +import io +import math +import os +import random +import subprocess +import sys +import time +import uuid +import zlib +from pathlib import Path + +import numpy as np +import sentencepiece as spm +import torch +import torch.distributed as dist +import torch.nn.functional as F +from torch import Tensor, nn +from torch.nn.parallel import DistributedDataParallel as DDP + +# Flash Attention 3 via HuggingFace kernels (Hopper-optimized) +try: + from kernels import get_kernel + cap = torch.cuda.get_device_capability() + repo = "varunneal/flash-attention-3" if cap == (9, 0) else "kernels-community/flash-attn3" + fa3 = get_kernel(repo).flash_attn_interface + FA3_AVAILABLE = True +except Exception: + FA3_AVAILABLE = False + +# ----------------------------- +# HYPERPARAMETERS +# ----------------------------- +# Default Simple Baseline run: +# - 9 transformer blocks at width 512 +# - 8 attention heads with 4 KV heads (GQA) and 2x MLP expansion +# - vocab size 1024, sequence length 1024, tied embeddings +# - 524,288 train tokens per step for 20,000 iterations with a ~10 minute cap + +class Hyperparameters: + # Data paths are shard globs produced by the existing preprocessing pipeline. + data_path = os.environ.get("DATA_PATH", "./data/datasets/fineweb10B_sp1024") + train_files = os.path.join(data_path, "fineweb_train_*.bin") + val_files = os.path.join(data_path, "fineweb_val_*.bin") + tokenizer_path = os.environ.get("TOKENIZER_PATH", "./data/tokenizers/fineweb_1024_bpe.model") + run_id = os.environ.get("RUN_ID", str(uuid.uuid4())) + seed = int(os.environ.get("SEED", 1337)) + + # Validation cadence and batch size. Validation always uses the full fineweb_val split. 
+ val_batch_size = int(os.environ.get("VAL_BATCH_SIZE", 524_288)) + val_loss_every = int(os.environ.get("VAL_LOSS_EVERY", 1000)) + train_log_every = int(os.environ.get("TRAIN_LOG_EVERY", 200)) + + # Training length. + iterations = int(os.environ.get("ITERATIONS", 20000)) + warmdown_iters = int(os.environ.get("WARMDOWN_ITERS", 1200)) + warmup_steps = int(os.environ.get("WARMUP_STEPS", 20)) + train_batch_tokens = int(os.environ.get("TRAIN_BATCH_TOKENS", 524_288)) + train_seq_len = int(os.environ.get("TRAIN_SEQ_LEN", 1024)) + max_wallclock_seconds = float(os.environ.get("MAX_WALLCLOCK_SECONDS", 600.0)) + qk_gain_init = float(os.environ.get("QK_GAIN_INIT", 1.5)) + + # Model shape. + vocab_size = int(os.environ.get("VOCAB_SIZE", 1024)) + num_layers = int(os.environ.get("NUM_LAYERS", 9)) + num_kv_heads = int(os.environ.get("NUM_KV_HEADS", 4)) + model_dim = int(os.environ.get("MODEL_DIM", 512)) + num_heads = int(os.environ.get("NUM_HEADS", 8)) + mlp_mult = int(os.environ.get("MLP_MULT", 2)) + tie_embeddings = bool(int(os.environ.get("TIE_EMBEDDINGS", "1"))) + rope_base = float(os.environ.get("ROPE_BASE", 10000.0)) + logit_softcap = float(os.environ.get("LOGIT_SOFTCAP", 30.0)) + + # Recurrence + stability features. 
+ num_prelude = int(os.environ.get("NUM_PRELUDE", 0)) + num_coda = int(os.environ.get("NUM_CODA", 0)) + num_shared = int(os.environ.get("NUM_SHARED", 0)) + num_loops = int(os.environ.get("NUM_LOOPS", 2)) + use_peri_norm = bool(int(os.environ.get("USE_PERI_NORM", "0"))) + use_birkhoff_mix = bool(int(os.environ.get("USE_BIRKHOFF_MIX", "0"))) + use_timestep_scale = bool(int(os.environ.get("USE_TIMESTEP_SCALE", "0"))) + leaky_relu_slope = float(os.environ.get("LEAKY_RELU_SLOPE", "0.5")) + timestep_gamma_max = float(os.environ.get("TIMESTEP_GAMMA_MAX", "0.0")) + use_timestep_bias = bool(int(os.environ.get("USE_TIMESTEP_BIAS", "0"))) + share_attn_only = bool(int(os.environ.get("SHARE_ATTN_ONLY", "0"))) + disable_compile = bool(int(os.environ.get("DISABLE_COMPILE", "0"))) + + # Optimizer hyperparameters. + embed_lr = float(os.environ.get("EMBED_LR", 0.6)) + head_lr = float(os.environ.get("HEAD_LR", 0.008)) + tied_embed_lr = float(os.environ.get("TIED_EMBED_LR", 0.05)) + tied_embed_init_std = float(os.environ.get("TIED_EMBED_INIT_STD", 0.005)) + matrix_lr = float(os.environ.get("MATRIX_LR", 0.04)) + scalar_lr = float(os.environ.get("SCALAR_LR", 0.04)) + muon_momentum = float(os.environ.get("MUON_MOMENTUM", 0.95)) + muon_backend_steps = int(os.environ.get("MUON_BACKEND_STEPS", 5)) + muon_momentum_warmup_start = float(os.environ.get("MUON_MOMENTUM_WARMUP_START", 0.85)) + muon_momentum_warmup_steps = int(os.environ.get("MUON_MOMENTUM_WARMUP_STEPS", 500)) + beta1 = float(os.environ.get("BETA1", 0.9)) + beta2 = float(os.environ.get("BETA2", 0.95)) + adam_eps = float(os.environ.get("ADAM_EPS", 1e-8)) + grad_clip_norm = float(os.environ.get("GRAD_CLIP_NORM", 0.0)) + +# ----------------------------- +# MUON OPTIMIZER +# ----------------------------- +# +# As borrowed from modded-nanogpt +# Background on Muon: https://kellerjordan.github.io/posts/muon/ + +def zeropower_via_newtonschulz5(G: Tensor, steps: int = 10, eps: float = 1e-7) -> Tensor: + # Orthogonalize a 2D update matrix 
with a fast Newton-Schulz iteration. + # Muon uses this to normalize matrix-shaped gradients before applying them. + a, b, c = (3.4445, -4.7750, 2.0315) + X = G.bfloat16() + X /= X.norm() + eps + transposed = G.size(0) > G.size(1) + if transposed: + X = X.T + for _ in range(steps): + A = X @ X.T + B = b * A + c * A @ A + X = a * X + B @ X + return X.T if transposed else X + + +class Muon(torch.optim.Optimizer): + def __init__(self, params, lr: float, momentum: float, backend_steps: int, nesterov: bool = True): + super().__init__( + params, + dict(lr=lr, momentum=momentum, backend_steps=backend_steps, nesterov=nesterov), + ) + + @torch.no_grad() + def step(self, closure=None): + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + distributed = dist.is_available() and dist.is_initialized() + world_size = dist.get_world_size() if distributed else 1 + rank = dist.get_rank() if distributed else 0 + + for group in self.param_groups: + params = group["params"] + if not params: + continue + lr = group["lr"] + momentum = group["momentum"] + backend_steps = group["backend_steps"] + nesterov = group["nesterov"] + + total_params = sum(int(p.numel()) for p in params) + updates_flat = torch.zeros(total_params, device=params[0].device, dtype=torch.bfloat16) + + curr = 0 + for i, p in enumerate(params): + if i % world_size == rank and p.grad is not None: + g = p.grad + state = self.state[p] + if "momentum_buffer" not in state: + state["momentum_buffer"] = torch.zeros_like(g) + buf = state["momentum_buffer"] + buf.mul_(momentum).add_(g) + if nesterov: + g = g.add(buf, alpha=momentum) + g = zeropower_via_newtonschulz5(g, steps=backend_steps) + # Scale correction from Muon reference implementations. 
+ g *= max(1, g.size(0) / g.size(1)) ** 0.5 + updates_flat[curr : curr + p.numel()] = g.reshape(-1) + curr += p.numel() + + if distributed: + dist.all_reduce(updates_flat, op=dist.ReduceOp.SUM) + + curr = 0 + for p in params: + g = updates_flat[curr : curr + p.numel()].view_as(p).to(dtype=p.dtype) + p.add_(g, alpha=-lr) + curr += p.numel() + + return loss + + +# ----------------------------- +# TOKENIZER-AGNOSTIC EVALUATION SETUP +# ----------------------------- +# +# It's common for small models have a large fraction of their parameters be embeddings, since the 2 * d_model * d_vocab vectors can be gigantic. +# Instead of locking the tokenizer, we let you bring your own and calculate our validation metrics on the average compression of the validation set. +# We calculate BPB (bits-per-byte) instead of validation loss, so we need methods to count the number of bits per token in the tokenizer. +# Note: Submissions that edit the tokenizer will be examined more carefully, since screwing this up might unjustly improve your score. 
+ +def build_sentencepiece_luts( + sp: spm.SentencePieceProcessor, vocab_size: int, device: torch.device +) -> tuple[Tensor, Tensor, Tensor]: + sp_vocab_size = int(sp.vocab_size()) + table_size = max(sp_vocab_size, vocab_size) + base_bytes_np = np.zeros((table_size,), dtype=np.int16) + has_leading_space_np = np.zeros((table_size,), dtype=np.bool_) + is_boundary_token_np = np.ones((table_size,), dtype=np.bool_) + for token_id in range(sp_vocab_size): + if sp.is_control(token_id) or sp.is_unknown(token_id) or sp.is_unused(token_id): + continue + is_boundary_token_np[token_id] = False + if sp.is_byte(token_id): + base_bytes_np[token_id] = 1 + continue + piece = sp.id_to_piece(token_id) + if piece.startswith("▁"): + has_leading_space_np[token_id] = True + piece = piece[1:] + base_bytes_np[token_id] = len(piece.encode("utf-8")) + return ( + torch.tensor(base_bytes_np, dtype=torch.int16, device=device), + torch.tensor(has_leading_space_np, dtype=torch.bool, device=device), + torch.tensor(is_boundary_token_np, dtype=torch.bool, device=device), + ) + + +def load_validation_tokens(pattern: str, seq_len: int) -> Tensor: + files = [Path(p) for p in sorted(glob.glob(pattern))] + if not files: + raise FileNotFoundError(f"No files found for pattern: {pattern}") + # The export pipeline writes the fixed first-50k-doc validation set to fineweb_val_*. 
+ tokens = torch.cat([load_data_shard(file) for file in files]).contiguous() + usable = ((tokens.numel() - 1) // seq_len) * seq_len + if usable <= 0: + raise ValueError(f"Validation split is too short for TRAIN_SEQ_LEN={seq_len}") + return tokens[: usable + 1] + + +def eval_val( + args: Hyperparameters, + model: nn.Module, + rank: int, + world_size: int, + device: torch.device, + grad_accum_steps: int, + val_tokens: Tensor, + base_bytes_lut: Tensor, + has_leading_space_lut: Tensor, + is_boundary_token_lut: Tensor, +) -> tuple[float, float]: + # Validation computes two metrics: + # - val_loss: token cross-entropy (natural log) + # - val_bpb: tokenizer-agnostic compression metric used by the challenge + local_batch_tokens = args.val_batch_size // (world_size * grad_accum_steps) + if local_batch_tokens < args.train_seq_len: + raise ValueError( + "VAL_BATCH_SIZE must provide at least one sequence per rank; " + f"got VAL_BATCH_SIZE={args.val_batch_size}, WORLD_SIZE={world_size}, " + f"GRAD_ACCUM_STEPS={grad_accum_steps}, TRAIN_SEQ_LEN={args.train_seq_len}" + ) + local_batch_seqs = local_batch_tokens // args.train_seq_len + total_seqs = (val_tokens.numel() - 1) // args.train_seq_len + seq_start = (total_seqs * rank) // world_size + seq_end = (total_seqs * (rank + 1)) // world_size + val_loss_sum = torch.zeros((), device=device, dtype=torch.float64) + val_token_count = torch.zeros((), device=device, dtype=torch.float64) + val_byte_count = torch.zeros((), device=device, dtype=torch.float64) + + model.eval() + with torch.inference_mode(): + for batch_seq_start in range(seq_start, seq_end, local_batch_seqs): + batch_seq_end = min(batch_seq_start + local_batch_seqs, seq_end) + raw_start = batch_seq_start * args.train_seq_len + raw_end = batch_seq_end * args.train_seq_len + 1 + local = val_tokens[raw_start:raw_end].to(device=device, dtype=torch.int64, non_blocking=True) + x = local[:-1].reshape(-1, args.train_seq_len) + y = local[1:].reshape(-1, args.train_seq_len) + with 
torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + batch_loss = model(x, y).detach() + batch_token_count = float(y.numel()) + val_loss_sum += batch_loss.to(torch.float64) * batch_token_count + val_token_count += batch_token_count + prev_ids = x.reshape(-1) + tgt_ids = y.reshape(-1) + token_bytes = base_bytes_lut[tgt_ids].to(dtype=torch.int16) + token_bytes += (has_leading_space_lut[tgt_ids] & ~is_boundary_token_lut[prev_ids]).to(dtype=torch.int16) + val_byte_count += token_bytes.to(torch.float64).sum() + + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(val_loss_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(val_token_count, op=dist.ReduceOp.SUM) + dist.all_reduce(val_byte_count, op=dist.ReduceOp.SUM) + + val_loss = val_loss_sum / val_token_count + bits_per_token = val_loss.item() / math.log(2.0) + tokens_per_byte = val_token_count.item() / val_byte_count.item() + model.train() + return float(val_loss.item()), float(bits_per_token * tokens_per_byte) + +# ----------------------------- +# POST-TRAINING QUANTIZATION +# ----------------------------- +# +# It's silly to export our model, which is trained in bf16 and fp32, at that same precision. +# Instead, we get approximately the same model (with a small hit) by quantizing the model to int8 & zlib compressing. +# We can then decompress the model and run in higher precision for evaluation, after closing in under the size limit. 
+ +CONTROL_TENSOR_NAME_PATTERNS = tuple( + pattern + for pattern in os.environ.get( + "CONTROL_TENSOR_NAME_PATTERNS", + "attn_scale,attn_scales,mlp_scale,mlp_scales,resid_mix,resid_mixes,q_gain,skip_weight,skip_weights,ts_attn,ts_mlp,resid_mix_logit,ts_attn_beta,ts_mlp_beta", + ).split(",") + if pattern +) +INT8_KEEP_FLOAT_FP32_NAME_PATTERNS = tuple( + pattern + for pattern in os.environ.get( + "INT8_KEEP_FLOAT_FP32_NAME_PATTERNS", + ",".join(CONTROL_TENSOR_NAME_PATTERNS), + ).split(",") + if pattern +) +INT8_KEEP_FLOAT_MAX_NUMEL = 65_536 +INT8_KEEP_FLOAT_STORE_DTYPE = torch.float16 +INT8_PER_ROW_SCALE_DTYPE = torch.float16 +INT8_CLIP_PERCENTILE = 99.99984 +INT8_CLIP_Q = INT8_CLIP_PERCENTILE / 100.0 + +def tensor_nbytes(t: Tensor) -> int: + return int(t.numel()) * int(t.element_size()) + +def keep_float_tensor(name: str, t: Tensor, passthrough_orig_dtypes: dict[str, str]) -> Tensor: + if any(pattern in name for pattern in INT8_KEEP_FLOAT_FP32_NAME_PATTERNS): + return t.float().contiguous() + if t.dtype in {torch.float32, torch.bfloat16}: + passthrough_orig_dtypes[name] = str(t.dtype).removeprefix("torch.") + return t.to(dtype=INT8_KEEP_FLOAT_STORE_DTYPE).contiguous() + return t + +def quantize_float_tensor(t: Tensor) -> tuple[Tensor, Tensor]: + t32 = t.float() + if t32.ndim == 2: + # Matrices get one scale per row, which usually tracks output-channel + # ranges much better than a single tensor-wide scale. + clip_abs = ( + torch.quantile(t32.abs(), INT8_CLIP_Q, dim=1) + if t32.numel() + else torch.empty((t32.shape[0],), dtype=torch.float32) + ) + clipped = torch.maximum(torch.minimum(t32, clip_abs[:, None]), -clip_abs[:, None]) + scale = (clip_abs / 127.0).clamp_min(1.0 / 127.0) + q = torch.clamp(torch.round(clipped / scale[:, None]), -127, 127).to(torch.int8).contiguous() + return q, scale.to(dtype=INT8_PER_ROW_SCALE_DTYPE).contiguous() + + # Vectors / scalars use a simpler per-tensor scale. 
+ clip_abs = float(torch.quantile(t32.abs().flatten(), INT8_CLIP_Q).item()) if t32.numel() else 0.0 + scale = torch.tensor(clip_abs / 127.0 if clip_abs > 0 else 1.0, dtype=torch.float32) + q = torch.clamp(torch.round(torch.clamp(t32, -clip_abs, clip_abs) / scale), -127, 127).to(torch.int8).contiguous() + return q, scale + +def quantize_state_dict_int8(state_dict: dict[str, Tensor]): + # Single supported clean-script export format: + # - per-row int8 for 2D float tensors + # - per-tensor int8 for other float tensors + # - exact passthrough for non-floats + # - passthrough for small float tensors, stored as fp16 to save bytes + quantized: dict[str, Tensor] = {} + scales: dict[str, Tensor] = {} + dtypes: dict[str, str] = {} + passthrough: dict[str, Tensor] = {} + passthrough_orig_dtypes: dict[str, str] = {} + qmeta: dict[str, dict[str, object]] = {} + stats = dict.fromkeys( + ("param_count", "num_tensors", "num_float_tensors", "num_nonfloat_tensors", "baseline_tensor_bytes", "int8_payload_bytes"), + 0, + ) + + for name, tensor in state_dict.items(): + t = tensor.detach().to("cpu").contiguous() + stats["param_count"] += int(t.numel()) + stats["num_tensors"] += 1 + stats["baseline_tensor_bytes"] += tensor_nbytes(t) + + if not t.is_floating_point(): + stats["num_nonfloat_tensors"] += 1 + passthrough[name] = t + stats["int8_payload_bytes"] += tensor_nbytes(t) + continue + + # Small float tensors are cheap enough to keep directly. We still downcast + # fp32/bf16 passthrough tensors to fp16 so metadata does not dominate size. 
+ if t.numel() <= INT8_KEEP_FLOAT_MAX_NUMEL: + kept = keep_float_tensor(name, t, passthrough_orig_dtypes) + passthrough[name] = kept + stats["int8_payload_bytes"] += tensor_nbytes(kept) + continue + + stats["num_float_tensors"] += 1 + q, s = quantize_float_tensor(t) + if s.ndim > 0: + qmeta[name] = {"scheme": "per_row", "axis": 0} + quantized[name] = q + scales[name] = s + dtypes[name] = str(t.dtype).removeprefix("torch.") + stats["int8_payload_bytes"] += tensor_nbytes(q) + tensor_nbytes(s) + + obj: dict[str, object] = { + "__quant_format__": "int8_clean_per_row_v1", + "quantized": quantized, + "scales": scales, + "dtypes": dtypes, + "passthrough": passthrough, + } + if qmeta: + obj["qmeta"] = qmeta + if passthrough_orig_dtypes: + obj["passthrough_orig_dtypes"] = passthrough_orig_dtypes + return obj, stats + +def dequantize_state_dict_int8(obj: dict[str, object]) -> dict[str, Tensor]: + out: dict[str, Tensor] = {} + qmeta = obj.get("qmeta", {}) + passthrough_orig_dtypes = obj.get("passthrough_orig_dtypes", {}) + for name, q in obj["quantized"].items(): + dtype = getattr(torch, obj["dtypes"][name]) + s = obj["scales"][name] + if qmeta.get(name, {}).get("scheme") == "per_row" or s.ndim > 0: + s = s.to(dtype=torch.float32) + # Broadcast the saved row scale back across trailing dimensions. + out[name] = (q.float() * s.view(q.shape[0], *([1] * (q.ndim - 1)))).to(dtype=dtype).contiguous() + else: + scale = float(s.item()) + out[name] = (q.float() * scale).to(dtype=dtype).contiguous() + for name, t in obj["passthrough"].items(): + # Restore small tensors, undoing the temporary fp16 storage cast if needed. 
+ out_t = t.detach().to("cpu").contiguous() + orig_dtype = passthrough_orig_dtypes.get(name) + if isinstance(orig_dtype, str): + out_t = out_t.to(dtype=getattr(torch, orig_dtype)).contiguous() + out[name] = out_t + return out + + +# ----------------------------- +# DATA LOADING +# ----------------------------- + +def load_data_shard(file: Path) -> Tensor: + header_bytes = 256 * np.dtype(" None: + self.file_idx = (self.file_idx + 1) % len(self.files) + self.tokens = load_data_shard(self.files[self.file_idx]) + self.pos = 0 + + def take(self, n: int) -> Tensor: + chunks: list[Tensor] = [] + remaining = n + while remaining > 0: + avail = self.tokens.numel() - self.pos + if avail <= 0: + self._advance_file() + continue + k = min(remaining, avail) + chunks.append(self.tokens[self.pos : self.pos + k]) + self.pos += k + remaining -= k + return chunks[0] if len(chunks) == 1 else torch.cat(chunks) + + +class DistributedTokenLoader: + # Each call consumes a contiguous chunk from the shared token stream, then slices out + # one disjoint span per rank. The extra "+1" token lets us build (x, y) by shifting. 
+ def __init__(self, pattern: str, rank: int, world_size: int, device: torch.device): + self.rank = rank + self.world_size = world_size + self.device = device + self.stream = TokenStream(pattern) + + def next_batch(self, global_tokens: int, seq_len: int, grad_accum_steps: int) -> tuple[Tensor, Tensor]: + local_tokens = global_tokens // (self.world_size * grad_accum_steps) + per_rank_span = local_tokens + 1 + chunk = self.stream.take(per_rank_span * self.world_size) + start = self.rank * per_rank_span + local = chunk[start : start + per_rank_span].to(dtype=torch.int64) + x = local[:-1].reshape(-1, seq_len) + y = local[1:].reshape(-1, seq_len) + return x.to(self.device, non_blocking=True), y.to(self.device, non_blocking=True) + +# ----------------------------- +# TRANSFORMER MODULES +# ----------------------------- + +class RMSNorm(nn.Module): + def __init__(self, eps: float | None = None): + super().__init__() + self.eps = eps + + def forward(self, x: Tensor) -> Tensor: + return F.rms_norm(x, (x.size(-1),), eps=self.eps) + + +class CastedLinear(nn.Linear): + # Keep weights in fp32 for optimizer/state quality, cast at matmul time for bf16 compute. + def forward(self, x: Tensor) -> Tensor: + bias = self.bias.to(x.dtype) if self.bias is not None else None + return F.linear(x, self.weight.to(x.dtype), bias) + + +def restore_low_dim_params_to_fp32(module: nn.Module) -> None: + # Keep small/control parameters in fp32 even when the model body runs in bf16. + with torch.no_grad(): + for name, param in module.named_parameters(): + if (param.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS)) and param.dtype != torch.float32: + param.data = param.data.float() + + +class Rotary(nn.Module): + # Caches cos/sin tables per sequence length on the current device. 
+ def __init__(self, dim: int, base: float = 10000.0): + super().__init__() + inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim)) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self._seq_len_cached = 0 + self._cos_cached: Tensor | None = None + self._sin_cached: Tensor | None = None + + def forward(self, seq_len: int, device: torch.device, dtype: torch.dtype) -> tuple[Tensor, Tensor]: + if ( + self._cos_cached is None + or self._sin_cached is None + or self._seq_len_cached != seq_len + or self._cos_cached.device != device + ): + t = torch.arange(seq_len, device=device, dtype=self.inv_freq.dtype) + freqs = torch.outer(t, self.inv_freq.to(device)) + self._cos_cached = freqs.cos()[None, None, :, :] + self._sin_cached = freqs.sin()[None, None, :, :] + self._seq_len_cached = seq_len + return self._cos_cached.to(dtype=dtype), self._sin_cached.to(dtype=dtype) + + +def apply_rotary_emb(x: Tensor, cos: Tensor, sin: Tensor) -> Tensor: + half = x.size(-1) // 2 + x1, x2 = x[..., :half], x[..., half:] + return torch.cat((x1 * cos + x2 * sin, x1 * (-sin) + x2 * cos), dim=-1) + + +class CausalSelfAttention(nn.Module): + def __init__( + self, + dim: int, + num_heads: int, + num_kv_heads: int, + rope_base: float, + qk_gain_init: float, + ): + super().__init__() + if dim % num_heads != 0: + raise ValueError("model_dim must be divisible by num_heads") + if num_heads % num_kv_heads != 0: + raise ValueError("num_heads must be divisible by num_kv_heads") + self.num_heads = num_heads + self.num_kv_heads = num_kv_heads + self.head_dim = dim // num_heads + if self.head_dim % 2 != 0: + raise ValueError("head_dim must be even for RoPE") + kv_dim = self.num_kv_heads * self.head_dim + self.c_q = CastedLinear(dim, dim, bias=False) + self.c_k = CastedLinear(dim, kv_dim, bias=False) + self.c_v = CastedLinear(dim, kv_dim, bias=False) + self.proj = CastedLinear(dim, dim, bias=False) + self.proj._zero_init = True + self.q_gain = 
nn.Parameter(torch.full((num_heads,), qk_gain_init, dtype=torch.float32)) + self.rotary = Rotary(self.head_dim, base=rope_base) + + def forward(self, x: Tensor) -> Tensor: + bsz, seqlen, dim = x.shape + q = self.c_q(x).reshape(bsz, seqlen, self.num_heads, self.head_dim).transpose(1, 2) + k = self.c_k(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) + v = self.c_v(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) + q = F.rms_norm(q, (q.size(-1),)) + k = F.rms_norm(k, (k.size(-1),)) + cos, sin = self.rotary(seqlen, x.device, q.dtype) + q = apply_rotary_emb(q, cos, sin) + k = apply_rotary_emb(k, cos, sin) + q = q * self.q_gain.to(dtype=q.dtype)[None, :, None, None] + if FA3_AVAILABLE: + # FA3 expects (B, T, H, D) — transpose from (B, H, T, D) + y = fa3.flash_attn_func( + q.transpose(1, 2), + k.transpose(1, 2), + v.transpose(1, 2), + causal=True, + ).transpose(1, 2) + else: + y = F.scaled_dot_product_attention( + q, k, v, + attn_mask=None, + is_causal=True, + enable_gqa=(self.num_kv_heads != self.num_heads), + ) + y = y.transpose(1, 2).contiguous().reshape(bsz, seqlen, dim) + return self.proj(y) + + +class MLP(nn.Module): + # leaky_relu^2 MLP (slope=0.5 matches SOTA, PRs #493/#518) + def __init__(self, dim: int, mlp_mult: int, leaky_relu_slope: float = 0.5): + super().__init__() + hidden = mlp_mult * dim + self.fc = CastedLinear(dim, hidden, bias=False) + self.proj = CastedLinear(hidden, dim, bias=False) + self.proj._zero_init = True + self.leaky_relu_slope = leaky_relu_slope + + def forward(self, x: Tensor) -> Tensor: + x = F.leaky_relu(self.fc(x), negative_slope=self.leaky_relu_slope) + return self.proj(x.square()) + + +class TimestepScaling(nn.Module): + """Learned per-iteration scale vectors (arXiv:2410.01405, ICML 2025).""" + def __init__(self, effective_layers: int, dim: int, gamma_max: float = 0.0, use_bias: bool = False): + super().__init__() + self.attn_gamma = nn.Parameter(torch.ones(effective_layers, dim, 
dtype=torch.float32)) + self.mlp_gamma = nn.Parameter(torch.ones(effective_layers, dim, dtype=torch.float32)) + self.gamma_max = gamma_max # 0 = uncapped + if use_bias: + self.attn_beta = nn.Parameter(torch.zeros(effective_layers, dim, dtype=torch.float32)) + self.mlp_beta = nn.Parameter(torch.zeros(effective_layers, dim, dtype=torch.float32)) + else: + self.attn_beta = None + self.mlp_beta = None + + def get(self, v: int) -> tuple[Tensor, Tensor, Tensor | None, Tensor | None]: + ag = self.attn_gamma[v] + mg = self.mlp_gamma[v] + if self.gamma_max > 0: + ag = ag.clamp(-self.gamma_max, self.gamma_max) + mg = mg.clamp(-self.gamma_max, self.gamma_max) + ab = self.attn_beta[v] if self.attn_beta is not None else None + mb = self.mlp_beta[v] if self.mlp_beta is not None else None + return ag, mg, ab, mb + + +class SharedAttnLayer(nn.Module): + """Shared attention layer (mixing + attention only, no MLP) for attn-only sharing mode.""" + def __init__( + self, + dim: int, + num_heads: int, + num_kv_heads: int, + rope_base: float, + qk_gain_init: float, + use_birkhoff_mix: bool = False, + ): + super().__init__() + self.use_birkhoff_mix = use_birkhoff_mix + self.attn_norm = RMSNorm() + self.attn = CausalSelfAttention(dim, num_heads, num_kv_heads, rope_base, qk_gain_init) + self.attn_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + if use_birkhoff_mix: + self.resid_mix_logit = nn.Parameter(torch.zeros(dim, dtype=torch.float32)) + else: + self.resid_mix = nn.Parameter(torch.stack((torch.ones(dim), torch.zeros(dim))).float()) + + def forward(self, x: Tensor, x0: Tensor, + ts_attn_gamma: Tensor | None = None, + ts_attn_beta: Tensor | None = None) -> Tensor: + if self.use_birkhoff_mix: + alpha = torch.sigmoid(self.resid_mix_logit.to(dtype=x.dtype))[None, None, :] + x = alpha * x + (1 - alpha) * x0 + else: + mix = self.resid_mix.to(dtype=x.dtype) + x = mix[0][None, None, :] * x + mix[1][None, None, :] * x0 + attn_out = self.attn(self.attn_norm(x)) + attn_s = 
self.attn_scale.to(dtype=x.dtype)[None, None, :] + if ts_attn_gamma is not None: + attn_s = attn_s * ts_attn_gamma[None, None, :] + x = x + attn_s * attn_out + if ts_attn_beta is not None: + x = x + ts_attn_beta[None, None, :] + return x + + +class UniqueMLP(nn.Module): + """Unique MLP per virtual shared position for attn-only sharing mode.""" + def __init__( + self, + dim: int, + mlp_mult: int, + use_peri_norm: bool = False, + leaky_relu_slope: float = 0.5, + ): + super().__init__() + self.use_peri_norm = use_peri_norm + if use_peri_norm: + self.mlp_out_norm = RMSNorm() + else: + self.mlp_norm = RMSNorm() + self.mlp = MLP(dim, mlp_mult, leaky_relu_slope) + self.mlp_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + + def forward(self, x: Tensor, + ts_mlp_gamma: Tensor | None = None, + ts_mlp_beta: Tensor | None = None) -> Tensor: + mlp_out = self.mlp_out_norm(self.mlp(x)) if self.use_peri_norm else self.mlp(self.mlp_norm(x)) + mlp_s = self.mlp_scale.to(dtype=x.dtype)[None, None, :] + if ts_mlp_gamma is not None: + mlp_s = mlp_s * ts_mlp_gamma[None, None, :] + out = mlp_s * mlp_out + if ts_mlp_beta is not None: + out = out + ts_mlp_beta[None, None, :] + return out + + +class Block(nn.Module): + def __init__( + self, + dim: int, + num_heads: int, + num_kv_heads: int, + mlp_mult: int, + rope_base: float, + qk_gain_init: float, + use_peri_norm: bool = False, + use_birkhoff_mix: bool = False, + leaky_relu_slope: float = 0.5, + ): + super().__init__() + self.use_peri_norm = use_peri_norm + self.use_birkhoff_mix = use_birkhoff_mix + self.attn_norm = RMSNorm() + if use_peri_norm: + self.mlp_out_norm = RMSNorm() # output norm: bounds MLP contribution, preserves input magnitude + else: + self.mlp_norm = RMSNorm() + self.attn = CausalSelfAttention(dim, num_heads, num_kv_heads, rope_base, qk_gain_init) + self.mlp = MLP(dim, mlp_mult, leaky_relu_slope) + self.attn_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.mlp_scale = 
nn.Parameter(torch.ones(dim, dtype=torch.float32)) + if use_birkhoff_mix: + self.resid_mix_logit = nn.Parameter(torch.zeros(dim, dtype=torch.float32)) + else: + self.resid_mix = nn.Parameter(torch.stack((torch.ones(dim), torch.zeros(dim))).float()) + + def forward(self, x: Tensor, x0: Tensor, + ts_attn_gamma: Tensor | None = None, + ts_mlp_gamma: Tensor | None = None, + ts_attn_beta: Tensor | None = None, + ts_mlp_beta: Tensor | None = None) -> Tensor: + if self.use_birkhoff_mix: + alpha = torch.sigmoid(self.resid_mix_logit.to(dtype=x.dtype))[None, None, :] + x = alpha * x + (1 - alpha) * x0 + else: + mix = self.resid_mix.to(dtype=x.dtype) + x = mix[0][None, None, :] * x + mix[1][None, None, :] * x0 + attn_out = self.attn(self.attn_norm(x)) + attn_s = self.attn_scale.to(dtype=x.dtype)[None, None, :] + if ts_attn_gamma is not None: + attn_s = attn_s * ts_attn_gamma[None, None, :] + x = x + attn_s * attn_out + if ts_attn_beta is not None: + x = x + ts_attn_beta[None, None, :] + mlp_out = self.mlp_out_norm(self.mlp(x)) if self.use_peri_norm else self.mlp(self.mlp_norm(x)) + mlp_s = self.mlp_scale.to(dtype=x.dtype)[None, None, :] + if ts_mlp_gamma is not None: + mlp_s = mlp_s * ts_mlp_gamma[None, None, :] + x = x + mlp_s * mlp_out + if ts_mlp_beta is not None: + x = x + ts_mlp_beta[None, None, :] + return x + + +class GPT(nn.Module): + def __init__( + self, + vocab_size: int, + num_layers: int, + model_dim: int, + num_heads: int, + num_kv_heads: int, + mlp_mult: int, + tie_embeddings: bool, + tied_embed_init_std: float, + logit_softcap: float, + rope_base: float, + qk_gain_init: float, + num_prelude: int = 0, + num_coda: int = 0, + num_shared: int = 0, + num_loops: int = 2, + use_peri_norm: bool = False, + use_birkhoff_mix: bool = False, + use_timestep_scale: bool = False, + use_timestep_bias: bool = False, + share_attn_only: bool = False, + timestep_gamma_max: float = 0.0, + leaky_relu_slope: float = 0.5, + ): + super().__init__() + if logit_softcap <= 0.0: + raise 
ValueError(f"logit_softcap must be positive, got {logit_softcap}") + self.tie_embeddings = tie_embeddings + self.tied_embed_init_std = tied_embed_init_std + self.logit_softcap = logit_softcap + self.tok_emb = nn.Embedding(vocab_size, model_dim) + self.use_recurrence = num_shared > 0 + self.num_loops = num_loops if self.use_recurrence else 1 + self.num_prelude = num_prelude if self.use_recurrence else 0 + self.num_coda = num_coda if self.use_recurrence else 0 + + block_kwargs = dict( + dim=model_dim, num_heads=num_heads, num_kv_heads=num_kv_heads, + mlp_mult=mlp_mult, rope_base=rope_base, qk_gain_init=qk_gain_init, + use_peri_norm=use_peri_norm if self.use_recurrence else False, + use_birkhoff_mix=use_birkhoff_mix if self.use_recurrence else False, + leaky_relu_slope=leaky_relu_slope, + ) + + self.share_attn_only = share_attn_only if self.use_recurrence else False + + if self.use_recurrence: + self.prelude_blocks = nn.ModuleList([Block(**block_kwargs) for _ in range(self.num_prelude)]) + self.coda_blocks = nn.ModuleList([Block(**block_kwargs) for _ in range(self.num_coda)]) + if self.share_attn_only: + shared_attn_kwargs = dict( + dim=model_dim, num_heads=num_heads, num_kv_heads=num_kv_heads, + rope_base=rope_base, qk_gain_init=qk_gain_init, + use_birkhoff_mix=use_birkhoff_mix, + ) + unique_mlp_kwargs = dict( + dim=model_dim, mlp_mult=mlp_mult, + use_peri_norm=use_peri_norm, + leaky_relu_slope=leaky_relu_slope, + ) + self.shared_attn_layers = nn.ModuleList([SharedAttnLayer(**shared_attn_kwargs) for _ in range(num_shared)]) + self.unique_mlps = nn.ModuleList([UniqueMLP(**unique_mlp_kwargs) for _ in range(num_shared * self.num_loops)]) + self.shared_blocks = nn.ModuleList() # empty — keeps diagnostics safe + else: + self.shared_blocks = nn.ModuleList([Block(**block_kwargs) for _ in range(num_shared)]) + self.shared_attn_layers = nn.ModuleList() + self.unique_mlps = nn.ModuleList() + effective_layers = self.num_prelude + num_shared * self.num_loops + self.num_coda + 
            self.timestep_scale = TimestepScaling(effective_layers, model_dim, gamma_max=timestep_gamma_max, use_bias=use_timestep_bias) if use_timestep_scale else None
        else:
            # Standard U-Net path
            self.num_encoder_layers = num_layers // 2
            self.num_decoder_layers = num_layers - self.num_encoder_layers
            self.num_skip_weights = min(self.num_encoder_layers, self.num_decoder_layers)
            self.skip_weights = nn.Parameter(torch.ones(self.num_skip_weights, model_dim, dtype=torch.float32))
            self.blocks = nn.ModuleList([Block(**block_kwargs) for _ in range(num_layers)])
            self.timestep_scale = None

        self.final_norm = RMSNorm()
        # Tied embeddings reuse tok_emb.weight as the output projection; otherwise a
        # separate lm_head is created and tagged for zero initialization.
        self.lm_head = None if tie_embeddings else CastedLinear(model_dim, vocab_size, bias=False)
        if self.lm_head is not None:
            self.lm_head._zero_init = True
        self._init_weights()

    def _init_weights(self) -> None:
        # The tied embedding doubles as the output head, so it gets a real init;
        # every _zero_init-tagged linear starts at zero so residual branches
        # begin as identity-like contributions.
        if self.tie_embeddings:
            nn.init.normal_(self.tok_emb.weight, mean=0.0, std=self.tied_embed_init_std)
        for module in self.modules():
            if isinstance(module, nn.Linear) and getattr(module, "_zero_init", False):
                nn.init.zeros_(module.weight)

    def _get_ts(self, v: int) -> tuple[Tensor | None, Tensor | None, Tensor | None, Tensor | None]:
        # Returns (attn_gamma, mlp_gamma, attn_beta, mlp_beta) for virtual layer v,
        # or all Nones when timestep scaling is disabled.
        if self.timestep_scale is None:
            return None, None, None, None
        return self.timestep_scale.get(v)

    def forward(self, input_ids: Tensor, target_ids: Tensor) -> Tensor:
        x = self.tok_emb(input_ids)
        x = F.rms_norm(x, (x.size(-1),))
        x0 = x  # normalized embedding, re-mixed into every block's residual stream

        if self.use_recurrence:
            # v indexes the "virtual" layer position: prelude blocks first, then each
            # loop over the shared stack, then (on the next chunk) coda blocks.
            # Timestep scale/bias vectors are looked up per virtual position.
            v = 0
            for block in self.prelude_blocks:
                ag, mg, ab, mb = self._get_ts(v)
                x = block(x, x0, ag, mg, ab, mb)
                v += 1
            if self.share_attn_only:
                # Attention weights are shared across loops; each virtual position
                # still gets its own unique MLP, indexed separately by vid.
                vid = 0
                for _loop in range(self.num_loops):
                    for attn_layer in self.shared_attn_layers:
                        ag, mg, ab, mb = self._get_ts(v)
                        x = attn_layer(x, x0, ag, ab)
                        x = x + self.unique_mlps[vid](x, mg, mb)
                        vid += 1
                        v += 1
            else:
                # Fully shared blocks: the same weights run num_loops times, but the
                # per-virtual-layer timestep scales differ across iterations.
                for _loop in range(self.num_loops):
                    for block in self.shared_blocks:
                        ag, mg, ab, mb = self._get_ts(v)
                        x = block(x, x0, ag, mg, ab, mb)
                        v += 1
            for block in self.coda_blocks:
                ag, mg, ab, mb = self._get_ts(v)
                x = block(x, x0, ag, mg, ab, mb)
                v += 1
        else:
            # U-Net path: encoder activations are stacked and popped in reverse so
            # decoder layer i is skip-connected to its mirror encoder layer, with a
            # learned per-channel skip weight.
            skips: list[Tensor] = []
            for i in range(self.num_encoder_layers):
                x = self.blocks[i](x, x0)
                skips.append(x)
            for i in range(self.num_decoder_layers):
                if skips:
                    x = x + self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips.pop()
                x = self.blocks[self.num_encoder_layers + i](x, x0)

        x = self.final_norm(x).reshape(-1, x.size(-1))
        targets = target_ids.reshape(-1)
        if self.tie_embeddings:
            # Tied head: project with the embedding matrix directly.
            logits_proj = F.linear(x, self.tok_emb.weight)
        else:
            if self.lm_head is None:
                raise RuntimeError("lm_head is required when tie_embeddings=False")
            logits_proj = self.lm_head(x)
        # tanh soft-cap bounds logit magnitude to ±logit_softcap before the loss.
        logits = self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap)
        return F.cross_entropy(logits.float(), targets, reduction="mean")


@torch.no_grad()
def recurrence_param_diagnostics(gpt: GPT) -> str:
    """Log parameter norms for recurrence health monitoring. No forward pass."""
    if not gpt.use_recurrence:
        return ""
    parts: list[str] = []

    # Birkhoff alpha stats per shared block/layer.
    # alpha = sigmoid(resid_mix_logit) exists only when birkhoff mixing is enabled,
    # hence the hasattr guards.
    if gpt.share_attn_only:
        for i, layer in enumerate(gpt.shared_attn_layers):
            if hasattr(layer, "resid_mix_logit"):
                a = torch.sigmoid(layer.resid_mix_logit)
                parts.append(f"shared{i}_alpha:mean={a.mean().item():.3f},std={a.std().item():.3f}")
    else:
        for i, block in enumerate(gpt.shared_blocks):
            if hasattr(block, "resid_mix_logit"):
                a = torch.sigmoid(block.resid_mix_logit)
                parts.append(f"shared{i}_alpha:mean={a.mean().item():.3f},std={a.std().item():.3f}")

    # Effective MLP/attn contribution scale per virtual layer position:
    # mlp_scale * ts_mlp_gamma[v] is the learned multiplier on MLP output.
    # We report RMS (norm / sqrt(numel)) to make it scale-independent.
+ v = 0 + effective_count = gpt.num_prelude + len(gpt.shared_blocks if not gpt.share_attn_only else gpt.shared_attn_layers) * gpt.num_loops + gpt.num_coda + mlp_norms: list[str] = [] + attn_norms: list[str] = [] + + # Prelude blocks + for block in gpt.prelude_blocks: + ms = block.mlp_scale.norm().item() + asc = block.attn_scale.norm().item() + d = block.mlp_scale.numel() ** 0.5 + if gpt.timestep_scale is not None: + ms_g = gpt.timestep_scale.mlp_gamma[v].norm().item() + as_g = gpt.timestep_scale.attn_gamma[v].norm().item() + mlp_norms.append(f"v{v}:{ms * ms_g / d:.4f}") + attn_norms.append(f"v{v}:{asc * as_g / d:.4f}") + else: + mlp_norms.append(f"v{v}:{ms / d:.4f}") + attn_norms.append(f"v{v}:{asc / d:.4f}") + v += 1 + + # Shared positions + if gpt.share_attn_only: + vid = 0 + for _loop in range(gpt.num_loops): + for layer in gpt.shared_attn_layers: + asc = layer.attn_scale.norm().item() + ms = gpt.unique_mlps[vid].mlp_scale.norm().item() + d = layer.attn_scale.numel() ** 0.5 + if gpt.timestep_scale is not None: + ms_g = gpt.timestep_scale.mlp_gamma[v].norm().item() + as_g = gpt.timestep_scale.attn_gamma[v].norm().item() + mlp_norms.append(f"v{v}:{ms * ms_g / d:.4f}") + attn_norms.append(f"v{v}:{asc * as_g / d:.4f}") + else: + mlp_norms.append(f"v{v}:{ms / d:.4f}") + attn_norms.append(f"v{v}:{asc / d:.4f}") + vid += 1 + v += 1 + else: + for _loop in range(gpt.num_loops): + for block in gpt.shared_blocks: + ms = block.mlp_scale.norm().item() + asc = block.attn_scale.norm().item() + d = block.mlp_scale.numel() ** 0.5 + if gpt.timestep_scale is not None: + ms_g = gpt.timestep_scale.mlp_gamma[v].norm().item() + as_g = gpt.timestep_scale.attn_gamma[v].norm().item() + mlp_norms.append(f"v{v}:{ms * ms_g / d:.4f}") + attn_norms.append(f"v{v}:{asc * as_g / d:.4f}") + else: + mlp_norms.append(f"v{v}:{ms / d:.4f}") + attn_norms.append(f"v{v}:{asc / d:.4f}") + v += 1 + + # Coda blocks + for block in gpt.coda_blocks: + ms = block.mlp_scale.norm().item() + asc = 
block.attn_scale.norm().item() + d = block.mlp_scale.numel() ** 0.5 + if gpt.timestep_scale is not None: + ms_g = gpt.timestep_scale.mlp_gamma[v].norm().item() + as_g = gpt.timestep_scale.attn_gamma[v].norm().item() + mlp_norms.append(f"v{v}:{ms * ms_g / d:.4f}") + attn_norms.append(f"v{v}:{asc * as_g / d:.4f}") + else: + mlp_norms.append(f"v{v}:{ms / d:.4f}") + attn_norms.append(f"v{v}:{asc / d:.4f}") + v += 1 + + parts.append("eff_mlp_scale:[" + " ".join(mlp_norms) + "]") + parts.append("eff_attn_scale:[" + " ".join(attn_norms) + "]") + if gpt.timestep_scale is not None and gpt.timestep_scale.attn_beta is not None: + attn_bias_norms: list[str] = [] + mlp_bias_norms: list[str] = [] + for vi in range(effective_count): + ab_rms = gpt.timestep_scale.attn_beta[vi].norm().item() / gpt.timestep_scale.attn_beta[vi].numel() ** 0.5 + mb_rms = gpt.timestep_scale.mlp_beta[vi].norm().item() / gpt.timestep_scale.mlp_beta[vi].numel() ** 0.5 + attn_bias_norms.append(f"v{vi}:{ab_rms:.4f}") + mlp_bias_norms.append(f"v{vi}:{mb_rms:.4f}") + parts.append("eff_attn_bias:[" + " ".join(attn_bias_norms) + "]") + parts.append("eff_mlp_bias:[" + " ".join(mlp_bias_norms) + "]") + return " ".join(parts) + + +# ----------------------------- +# TRAINING +# ----------------------------- + +def main() -> None: + global zeropower_via_newtonschulz5 + + code = Path(__file__).read_text(encoding="utf-8") + args = Hyperparameters() + zeropower_via_newtonschulz5 = torch.compile(zeropower_via_newtonschulz5) + + # ----------------------------- + # DISTRIBUTED + CUDA SETUP + # ----------------------------- + + distributed = "RANK" in os.environ and "WORLD_SIZE" in os.environ + rank = int(os.environ.get("RANK", "0")) + world_size = int(os.environ.get("WORLD_SIZE", "1")) + local_rank = int(os.environ.get("LOCAL_RANK", "0")) + if world_size <= 0: + raise ValueError(f"WORLD_SIZE must be positive, got {world_size}") + if 8 % world_size != 0: + raise ValueError(f"WORLD_SIZE={world_size} must divide 8 so 
grad_accum_steps stays integral") + grad_accum_steps = 8 // world_size + grad_scale = 1.0 / grad_accum_steps + if not torch.cuda.is_available(): + raise RuntimeError("CUDA is required") + device = torch.device("cuda", local_rank) + torch.cuda.set_device(device) + if distributed: + dist.init_process_group(backend="nccl", device_id=device) + dist.barrier() + master_process = rank == 0 + + # Fast math knobs + torch.backends.cuda.matmul.allow_tf32 = True + torch.backends.cudnn.allow_tf32 = True + from torch.backends.cuda import enable_cudnn_sdp, enable_flash_sdp, enable_math_sdp, enable_mem_efficient_sdp + + enable_cudnn_sdp(False) + enable_flash_sdp(True) + enable_mem_efficient_sdp(False) + enable_math_sdp(False) + + logfile = None + if master_process: + os.makedirs("logs", exist_ok=True) + logfile = f"logs/{args.run_id}.txt" + print(logfile) + + def log0(msg: str, console: bool = True) -> None: + if not master_process: + return + if console: + print(msg) + if logfile is not None: + with open(logfile, "a", encoding="utf-8") as f: + print(msg, file=f) + + log0(code, console=False) + log0("=" * 100, console=False) + log0(f"Running Python {sys.version}", console=False) + log0(f"Running PyTorch {torch.__version__}", console=False) + log0( + subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, check=False).stdout, + console=False, + ) + log0("=" * 100, console=False) + + # ----------------------------- + # TOKENIZER + VALIDATION METRIC SETUP + # ----------------------------- + + random.seed(args.seed) + np.random.seed(args.seed) + torch.manual_seed(args.seed) + torch.cuda.manual_seed_all(args.seed) + + if not args.tokenizer_path.endswith(".model"): + raise ValueError(f"Script only setup for SentencePiece .model file: {args.tokenizer_path}") + sp = spm.SentencePieceProcessor(model_file=args.tokenizer_path) + if int(sp.vocab_size()) != args.vocab_size: + raise ValueError( + f"VOCAB_SIZE={args.vocab_size} does not match tokenizer 
vocab_size={int(sp.vocab_size())}" + ) + dataset_dir = Path(args.data_path).resolve() + actual_train_files = len(list(dataset_dir.glob("fineweb_train_*.bin"))) + val_tokens = load_validation_tokens(args.val_files, args.train_seq_len) + base_bytes_lut, has_leading_space_lut, is_boundary_token_lut = build_sentencepiece_luts( + sp, args.vocab_size, device + ) + log0(f"val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path={args.tokenizer_path}") + log0(f"train_loader:dataset:{dataset_dir.name} train_shards:{actual_train_files}") + log0(f"val_loader:shards pattern={args.val_files} tokens:{val_tokens.numel() - 1}") + + # ----------------------------- + # MODEL + OPTIMIZER SETUP + # ----------------------------- + + base_model = GPT( + vocab_size=args.vocab_size, + num_layers=args.num_layers, + model_dim=args.model_dim, + num_heads=args.num_heads, + num_kv_heads=args.num_kv_heads, + mlp_mult=args.mlp_mult, + tie_embeddings=args.tie_embeddings, + tied_embed_init_std=args.tied_embed_init_std, + logit_softcap=args.logit_softcap, + rope_base=args.rope_base, + qk_gain_init=args.qk_gain_init, + num_prelude=args.num_prelude, + num_coda=args.num_coda, + num_shared=args.num_shared, + num_loops=args.num_loops, + use_peri_norm=args.use_peri_norm, + use_birkhoff_mix=args.use_birkhoff_mix, + use_timestep_scale=args.use_timestep_scale, + use_timestep_bias=args.use_timestep_bias, + share_attn_only=args.share_attn_only, + timestep_gamma_max=args.timestep_gamma_max, + leaky_relu_slope=args.leaky_relu_slope, + ).to(device).bfloat16() + for module in base_model.modules(): + if isinstance(module, CastedLinear): + module.float() + restore_low_dim_params_to_fp32(base_model) + if args.disable_compile: + compiled_model = base_model + else: + compiled_model = torch.compile(base_model, dynamic=False, fullgraph=True) + model: nn.Module = DDP(compiled_model, device_ids=[local_rank], broadcast_buffers=False) if distributed else compiled_model + + # Optimizer split: + # - token embedding (Adam) 
uses EMBED_LR + # - untied lm_head (Adam) uses HEAD_LR + # - matrix params in transformer blocks use MATRIX_LR via Muon + # - vectors/scalars use SCALAR_LR via Adam + if base_model.use_recurrence: + all_block_lists = [base_model.prelude_blocks, base_model.shared_blocks, base_model.coda_blocks] + block_named_params = [] + for bl in all_block_lists: + block_named_params.extend(bl.named_parameters()) + if base_model.share_attn_only: + block_named_params.extend(base_model.shared_attn_layers.named_parameters()) + block_named_params.extend(base_model.unique_mlps.named_parameters()) + else: + block_named_params = list(base_model.blocks.named_parameters()) + matrix_params = [ + p + for name, p in block_named_params + if p.ndim == 2 and not any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) + ] + scalar_params = [ + p + for name, p in block_named_params + if p.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) + ] + if hasattr(base_model, "skip_weights") and base_model.skip_weights.numel() > 0: + scalar_params.append(base_model.skip_weights) + if base_model.timestep_scale is not None: + for p in base_model.timestep_scale.parameters(): + scalar_params.append(p) + token_lr = args.tied_embed_lr if args.tie_embeddings else args.embed_lr + optimizer_tok = torch.optim.Adam( + [{"params": [base_model.tok_emb.weight], "lr": token_lr, "base_lr": token_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=True, + ) + optimizer_muon = Muon( + matrix_params, + lr=args.matrix_lr, + momentum=args.muon_momentum, + backend_steps=args.muon_backend_steps, + ) + for group in optimizer_muon.param_groups: + group["base_lr"] = args.matrix_lr + optimizer_scalar = torch.optim.Adam( + [{"params": scalar_params, "lr": args.scalar_lr, "base_lr": args.scalar_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=True, + ) + optimizers: list[torch.optim.Optimizer] = [optimizer_tok, optimizer_muon, optimizer_scalar] + if 
base_model.lm_head is not None: + optimizer_head = torch.optim.Adam( + [{"params": [base_model.lm_head.weight], "lr": args.head_lr, "base_lr": args.head_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=True, + ) + optimizers.insert(1, optimizer_head) + + n_params = sum(p.numel() for p in base_model.parameters()) + log0(f"model_params:{n_params}") + log0(f"world_size:{world_size} grad_accum_steps:{grad_accum_steps}") + log0("sdp_backends:cudnn=False flash=True mem_efficient=False math=False") + log0(f"flash_attention_3:{FA3_AVAILABLE}") + log0(f"attention_mode:gqa num_heads:{args.num_heads} num_kv_heads:{args.num_kv_heads}") + log0( + f"tie_embeddings:{args.tie_embeddings} embed_lr:{token_lr} " + f"head_lr:{args.head_lr if base_model.lm_head is not None else 0.0} " + f"matrix_lr:{args.matrix_lr} scalar_lr:{args.scalar_lr}" + ) + log0( + f"train_batch_tokens:{args.train_batch_tokens} train_seq_len:{args.train_seq_len} " + f"iterations:{args.iterations} warmup_steps:{args.warmup_steps} " + f"max_wallclock_seconds:{args.max_wallclock_seconds:.3f}" + ) + log0(f"seed:{args.seed}") + if base_model.use_recurrence: + num_shared = len(base_model.shared_attn_layers) if base_model.share_attn_only else len(base_model.shared_blocks) + eff = base_model.num_prelude + num_shared * base_model.num_loops + base_model.num_coda + shared_label = f"shared_attn:{num_shared}" if base_model.share_attn_only else f"shared:{num_shared}" + log0(f"recurrence:enabled prelude:{base_model.num_prelude} {shared_label} " + f"loops:{base_model.num_loops} coda:{base_model.num_coda} effective_layers:{eff}") + log0(f"peri_norm:{args.use_peri_norm} birkhoff_mix:{args.use_birkhoff_mix}") + if base_model.timestep_scale is not None: + ts_params = sum(p.numel() for p in base_model.timestep_scale.parameters()) + log0(f"timestep_scale:enabled params:{ts_params}") + else: + log0("timestep_scale:disabled") + else: + log0(f"recurrence:disabled num_layers:{args.num_layers}") + compile_mode = 
"disabled" if args.disable_compile else "fullgraph=True" + log0(f"compile_mode:{compile_mode}") + + # ----------------------------- + # DATA LOADER & MODEL WARMUP + # ----------------------------- + + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + + def zero_grad_all() -> None: + for opt in optimizers: + opt.zero_grad(set_to_none=True) + + max_wallclock_ms = 1000.0 * args.max_wallclock_seconds if args.max_wallclock_seconds > 0 else None + + def lr_mul(step: int, elapsed_ms: float) -> float: + if args.warmdown_iters <= 0: + return 1.0 + if max_wallclock_ms is None: + warmdown_start = max(args.iterations - args.warmdown_iters, 0) + return max((args.iterations - step) / max(args.warmdown_iters, 1), 0.0) if warmdown_start <= step < args.iterations else 1.0 + step_ms = elapsed_ms / max(step, 1) + warmdown_ms = args.warmdown_iters * step_ms + remaining_ms = max(max_wallclock_ms - elapsed_ms, 0.0) + return remaining_ms / max(warmdown_ms, 1e-9) if remaining_ms <= warmdown_ms else 1.0 + + # Warmup primes the compiled forward/backward/optimizer paths, then we restore the + # initial weights/optimizer state so measured training starts from the true init. 
    if args.warmup_steps > 0:
        # Snapshot model and optimizer state so the timed run starts from the true init
        # after the compile/allocator warmup below.
        initial_model_state = {name: tensor.detach().cpu().clone() for name, tensor in base_model.state_dict().items()}
        initial_optimizer_states = [copy.deepcopy(opt.state_dict()) for opt in optimizers]
        model.train()
        for warmup_step in range(args.warmup_steps):
            zero_grad_all()
            for micro_step in range(grad_accum_steps):
                if distributed:
                    # Only sync grads on the last micro-step of the accumulation window.
                    model.require_backward_grad_sync = micro_step == grad_accum_steps - 1
                x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps)
                with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True):
                    warmup_loss = model(x, y)
                (warmup_loss * grad_scale).backward()
            for opt in optimizers:
                opt.step()
            zero_grad_all()
            if args.warmup_steps <= 20 or (warmup_step + 1) % 10 == 0 or warmup_step + 1 == args.warmup_steps:
                log0(f"warmup_step:{warmup_step + 1}/{args.warmup_steps}")
        # Roll everything back: weights, optimizer state, grad sync flag, and the
        # data loader (so measured training consumes the stream from the start).
        base_model.load_state_dict(initial_model_state, strict=True)
        for opt, state in zip(optimizers, initial_optimizer_states, strict=True):
            opt.load_state_dict(state)
        zero_grad_all()
        if distributed:
            model.require_backward_grad_sync = True
        train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device)

    # -----------------------------
    # MAIN TRAINING LOOP
    # -----------------------------

    training_time_ms = 0.0
    stop_after_step: int | None = None  # set once the wallclock cap is hit
    torch.cuda.synchronize()
    t0 = time.perf_counter()

    step = 0
    while True:
        last_step = step == args.iterations or (stop_after_step is not None and step >= stop_after_step)

        should_validate = last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0)
        if should_validate:
            # Pause the training clock around validation so eval time is not billed.
            torch.cuda.synchronize()
            training_time_ms += 1000.0 * (time.perf_counter() - t0)
            val_loss, val_bpb = eval_val(
                args,
                model,
                rank,
                world_size,
                device,
                grad_accum_steps,
                val_tokens,
                base_bytes_lut,
                has_leading_space_lut,
                is_boundary_token_lut,
            )
            log0(
                f"step:{step}/{args.iterations} val_loss:{val_loss:.4f} val_bpb:{val_bpb:.4f} "
                f"train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms / max(step, 1):.2f}ms"
            )
            torch.cuda.synchronize()
            t0 = time.perf_counter()

        if last_step:
            if stop_after_step is not None and step < args.iterations:
                log0(
                    f"stopping_early: wallclock_cap train_time:{training_time_ms:.0f}ms "
                    f"step:{step}/{args.iterations}"
                )
            break

        # LR schedule scale is wallclock- or iteration-based (see lr_mul above).
        elapsed_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0)
        scale = lr_mul(step, elapsed_ms)
        zero_grad_all()
        train_loss = torch.zeros((), device=device)
        for micro_step in range(grad_accum_steps):
            if distributed:
                model.require_backward_grad_sync = micro_step == grad_accum_steps - 1
            x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps)
            with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True):
                loss = model(x, y)
            train_loss += loss.detach()
            (loss * grad_scale).backward()
        train_loss /= grad_accum_steps

        # Linearly ramp Muon momentum from warmup_start to its final value.
        frac = min(step / args.muon_momentum_warmup_steps, 1.0) if args.muon_momentum_warmup_steps > 0 else 1.0
        muon_momentum = (1 - frac) * args.muon_momentum_warmup_start + frac * args.muon_momentum
        for group in optimizer_muon.param_groups:
            group["momentum"] = muon_momentum

        for opt in optimizers:
            for group in opt.param_groups:
                group["lr"] = group["base_lr"] * scale

        if args.grad_clip_norm > 0:
            torch.nn.utils.clip_grad_norm_(base_model.parameters(), args.grad_clip_norm)
        for opt in optimizers:
            opt.step()
        zero_grad_all()

        step += 1
        approx_training_time_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0)
        should_log_train = (
            args.train_log_every > 0
            and (step <= 10 or step % args.train_log_every == 0 or stop_after_step is not None)
        )
        if should_log_train:
            log0(
                f"step:{step}/{args.iterations} train_loss:{train_loss.item():.4f} "
                f"train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms / step:.2f}ms"
            )
            # Diagnostics only on the regular cadence, not the early steps<=10 logs.
            if base_model.use_recurrence and step % args.train_log_every == 0:
                diag = recurrence_param_diagnostics(base_model)
                if diag:
                    log0(f"step:{step} {diag}")

        # Needed to sync whether we've reached the wallclock cap.
        reached_cap = max_wallclock_ms is not None and approx_training_time_ms >= max_wallclock_ms
        if distributed and max_wallclock_ms is not None:
            # MAX reduce: stop everywhere as soon as any rank crosses the cap.
            reached_cap_tensor = torch.tensor(int(reached_cap), device=device)
            dist.all_reduce(reached_cap_tensor, op=dist.ReduceOp.MAX)
            reached_cap = bool(reached_cap_tensor.item())
        if stop_after_step is None and reached_cap:
            stop_after_step = step

    log0(
        f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB "
        f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB"
    )

    # -----------------------------
    # SERIALIZATION + ROUNDTRIP VALIDATION
    # -----------------------------
    # Save the raw state (useful for debugging/loading in PyTorch directly), then always produce
    # the compressed int8+zlib artifact and validate the round-tripped weights.
+ + if master_process: + torch.save(base_model.state_dict(), "final_model.pt") + model_bytes = os.path.getsize("final_model.pt") + code_bytes = len(code.encode("utf-8")) + log0(f"Serialized model: {model_bytes} bytes") + log0(f"Code size: {code_bytes} bytes") + log0(f"Total submission size: {model_bytes + code_bytes} bytes") + + quant_obj, quant_stats = quantize_state_dict_int8(base_model.state_dict()) + quant_buf = io.BytesIO() + torch.save(quant_obj, quant_buf) + quant_raw = quant_buf.getvalue() + quant_blob = zlib.compress(quant_raw, level=9) + quant_raw_bytes = len(quant_raw) + if master_process: + with open("final_model.int8.ptz", "wb") as f: + f.write(quant_blob) + quant_file_bytes = os.path.getsize("final_model.int8.ptz") + code_bytes = len(code.encode("utf-8")) + ratio = quant_stats["baseline_tensor_bytes"] / max(quant_stats["int8_payload_bytes"], 1) + log0( + f"Serialized model int8+zlib: {quant_file_bytes} bytes " + f"(payload:{quant_stats['int8_payload_bytes']} raw_torch:{quant_raw_bytes} payload_ratio:{ratio:.2f}x)" + ) + log0(f"Total submission size int8+zlib: {quant_file_bytes + code_bytes} bytes") + + if distributed: + dist.barrier() + with open("final_model.int8.ptz", "rb") as f: + quant_blob_disk = f.read() + quant_state = torch.load(io.BytesIO(zlib.decompress(quant_blob_disk)), map_location="cpu") + base_model.load_state_dict(dequantize_state_dict_int8(quant_state), strict=True) + torch.cuda.synchronize() + t_qeval = time.perf_counter() + q_val_loss, q_val_bpb = eval_val( + args, + model, + rank, + world_size, + device, + grad_accum_steps, + val_tokens, + base_bytes_lut, + has_leading_space_lut, + is_boundary_token_lut, + ) + torch.cuda.synchronize() + log0( + f"final_int8_zlib_roundtrip val_loss:{q_val_loss:.4f} val_bpb:{q_val_bpb:.4f} " + f"eval_time:{1000.0 * (time.perf_counter() - t_qeval):.0f}ms" + ) + log0(f"final_int8_zlib_roundtrip_exact val_loss:{q_val_loss:.8f} val_bpb:{q_val_bpb:.8f}") + + if distributed: + dist.destroy_process_group() 
+ + +if __name__ == "__main__": + main() + +==================================================================================================== +Running Python 3.12.3 (main, Nov 6 2025, 13:44:16) [GCC 13.3.0] +Running PyTorch 2.9.1+cu128 +Mon Mar 30 15:57:05 2026 ++-----------------------------------------------------------------------------------------+ +| NVIDIA-SMI 580.126.09 Driver Version: 580.126.09 CUDA Version: 13.0 | ++-----------------------------------------+------------------------+----------------------+ +| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC | +| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. | +| | | MIG M. | +|=========================================+========================+======================| +| 0 NVIDIA H100 80GB HBM3 On | 00000000:19:00.0 Off | 0 | +| N/A 44C P0 125W / 700W | 1521MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 1 NVIDIA H100 80GB HBM3 On | 00000000:3B:00.0 Off | 0 | +| N/A 36C P0 116W / 700W | 1521MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 2 NVIDIA H100 80GB HBM3 On | 00000000:4C:00.0 Off | 0 | +| N/A 35C P0 121W / 700W | 1521MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 3 NVIDIA H100 80GB HBM3 On | 00000000:5D:00.0 Off | 0 | +| N/A 43C P0 125W / 700W | 1521MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 4 NVIDIA H100 80GB HBM3 On | 00000000:9B:00.0 Off | 0 | +| N/A 44C P0 124W / 700W | 1521MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 5 NVIDIA H100 80GB HBM3 On | 00000000:BB:00.0 Off | 0 | +| N/A 35C 
P0 122W / 700W | 1521MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 6 NVIDIA H100 80GB HBM3 On | 00000000:CB:00.0 Off | 0 | +| N/A 42C P0 123W / 700W | 1521MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 7 NVIDIA H100 80GB HBM3 On | 00000000:DB:00.0 Off | 0 | +| N/A 34C P0 120W / 700W | 1521MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ + ++-----------------------------------------------------------------------------------------+ +| Processes: | +| GPU GI CI PID Type Process name GPU Memory | +| ID ID Usage | +|=========================================================================================| +| No running processes found | ++-----------------------------------------------------------------------------------------+ + +==================================================================================================== +val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path=./data/tokenizers/fineweb_1024_bpe.model +train_loader:dataset:fineweb10B_sp1024 train_shards:80 +val_loader:shards pattern=./data/datasets/fineweb10B_sp1024/fineweb_val_*.bin tokens:62021632 +model_params:11572272 +world_size:8 grad_accum_steps:1 +sdp_backends:cudnn=False flash=True mem_efficient=False math=False +flash_attention_3:True +attention_mode:gqa num_heads:8 num_kv_heads:4 +tie_embeddings:True embed_lr:0.05 head_lr:0.0 matrix_lr:0.04 scalar_lr:0.04 +train_batch_tokens:524288 train_seq_len:1024 iterations:20000 warmup_steps:20 max_wallclock_seconds:600.000 +seed:1337 +recurrence:enabled prelude:1 shared:4 loops:3 coda:1 effective_layers:14 +peri_norm:True birkhoff_mix:True +timestep_scale:enabled params:28672 +compile_mode:fullgraph=True +warmup_step:1/20 +warmup_step:2/20 +warmup_step:3/20 
+warmup_step:4/20 +warmup_step:5/20 +warmup_step:6/20 +warmup_step:7/20 +warmup_step:8/20 +warmup_step:9/20 +warmup_step:10/20 +warmup_step:11/20 +warmup_step:12/20 +warmup_step:13/20 +warmup_step:14/20 +warmup_step:15/20 +warmup_step:16/20 +warmup_step:17/20 +warmup_step:18/20 +warmup_step:19/20 +warmup_step:20/20 +step:0/20000 val_loss:6.9379 val_bpb:4.1090 train_time:0ms step_avg:0.02ms +step:1/20000 train_loss:6.9377 train_time:36ms step_avg:36.42ms +step:2/20000 train_loss:9.3717 train_time:90ms step_avg:45.17ms +step:3/20000 train_loss:8.1178 train_time:147ms step_avg:48.92ms +step:4/20000 train_loss:9.5936 train_time:203ms step_avg:50.84ms +step:5/20000 train_loss:8.8191 train_time:265ms step_avg:52.99ms +step:6/20000 train_loss:8.3817 train_time:320ms step_avg:53.37ms +step:7/20000 train_loss:7.4714 train_time:374ms step_avg:53.49ms +step:8/20000 train_loss:6.6967 train_time:435ms step_avg:54.36ms +step:9/20000 train_loss:6.0114 train_time:492ms step_avg:54.69ms +step:10/20000 train_loss:5.6634 train_time:549ms step_avg:54.89ms +step:200/20000 train_loss:2.7813 train_time:11531ms step_avg:57.66ms +step:200 shared0_alpha:mean=0.457,std=0.048 shared1_alpha:mean=0.470,std=0.040 shared2_alpha:mean=0.485,std=0.037 shared3_alpha:mean=0.510,std=0.040 eff_mlp_scale:[v0:35.3722 v1:27.0276 v2:28.3051 v3:30.3633 v4:32.1912 v5:30.9510 v6:28.5910 v7:29.4703 v8:31.2535 v9:30.3697 v10:29.4487 v11:32.0006 v12:34.8477 v13:59.3214] eff_attn_scale:[v0:15.1524 v1:10.5036 v2:11.4879 v3:9.7114 v4:10.1986 v5:9.6956 v6:9.8812 v7:8.4595 v8:9.4758 v9:10.0188 v10:9.6804 v11:8.9148 v12:10.1986 v13:15.3478] eff_attn_bias:[v0:0.1595 v1:0.1202 v2:0.1181 v3:0.1229 v4:0.1347 v5:0.1312 v6:0.1340 v7:0.1264 v8:0.1381 v9:0.1333 v10:0.1319 v11:0.1181 v12:0.1132 v13:0.1167] eff_mlp_bias:[v0:0.1257 v1:0.1132 v2:0.1132 v3:0.1208 v4:0.1229 v5:0.1243 v6:0.1167 v7:0.1181 v8:0.1277 v9:0.1188 v10:0.1112 v11:0.1001 v12:0.1195 v13:0.2003] +step:200/20000 val_loss:2.7669 val_bpb:1.6387 train_time:11595ms 
step_avg:57.98ms +step:400/20000 train_loss:2.3535 train_time:23196ms step_avg:57.99ms +step:400 shared0_alpha:mean=0.466,std=0.049 shared1_alpha:mean=0.485,std=0.043 shared2_alpha:mean=0.511,std=0.039 shared3_alpha:mean=0.544,std=0.043 eff_mlp_scale:[v0:43.2181 v1:32.4929 v2:35.7139 v3:39.2530 v4:39.3779 v5:43.0994 v6:38.6360 v7:39.2530 v8:38.0430 v9:40.0690 v10:36.8503 v11:38.5848 v12:38.2099 v13:74.8537] eff_attn_scale:[v0:6.5913 v1:5.6074 v2:5.9730 v3:5.3355 v4:5.6246 v5:5.7519 v6:5.6442 v7:4.9507 v8:5.3835 v9:5.8675 v10:5.0414 v11:4.4634 v12:4.9818 v13:9.3774] eff_attn_bias:[v0:0.2044 v1:0.1568 v2:0.1685 v3:0.1782 v4:0.2030 v5:0.1864 v6:0.1961 v7:0.1768 v8:0.1947 v9:0.1678 v10:0.1685 v11:0.1450 v12:0.1374 v13:0.1340] eff_mlp_bias:[v0:0.2486 v1:0.1568 v2:0.1478 v3:0.1616 v4:0.1726 v5:0.1678 v6:0.1450 v7:0.1443 v8:0.1609 v9:0.1547 v10:0.1443 v11:0.1381 v12:0.1471 v13:0.2721] +step:400/20000 val_loss:2.5627 val_bpb:1.5178 train_time:23222ms step_avg:58.05ms +step:600/20000 train_loss:2.5734 train_time:34841ms step_avg:58.07ms +step:600 shared0_alpha:mean=0.470,std=0.049 shared1_alpha:mean=0.492,std=0.045 shared2_alpha:mean=0.529,std=0.039 shared3_alpha:mean=0.569,std=0.044 eff_mlp_scale:[v0:49.4292 v1:36.4139 v2:40.2403 v3:44.1479 v4:42.9119 v5:50.4748 v6:44.8636 v7:45.3791 v8:42.7396 v9:45.7878 v10:40.0690 v11:41.8613 v12:40.1546 v13:89.9028] eff_attn_scale:[v0:3.1026 v1:3.1592 v2:3.4962 v3:3.1990 v4:3.4655 v5:3.4022 v6:3.3997 v7:3.0725 v8:3.4655 v9:3.4832 v10:2.9554 v11:2.5664 v12:2.9815 v13:6.2330] eff_attn_bias:[v0:0.2679 v1:0.2141 v2:0.2224 v3:0.2362 v4:0.2721 v5:0.2555 v6:0.2652 v7:0.2389 v8:0.2596 v9:0.2168 v10:0.2099 v11:0.1733 v12:0.1595 v13:0.1650] eff_mlp_bias:[v0:0.3895 v1:0.2154 v2:0.2016 v3:0.2085 v4:0.2362 v5:0.2265 v6:0.1864 v7:0.1823 v8:0.2154 v9:0.2072 v10:0.1878 v11:0.1864 v12:0.1823 v13:0.2942] +step:600/20000 val_loss:2.4809 val_bpb:1.4694 train_time:34864ms step_avg:58.11ms +step:800/20000 train_loss:2.3380 train_time:46506ms 
step_avg:58.13ms +step:800 shared0_alpha:mean=0.472,std=0.049 shared1_alpha:mean=0.496,std=0.047 shared2_alpha:mean=0.541,std=0.041 shared3_alpha:mean=0.585,std=0.046 eff_mlp_scale:[v0:55.0148 v1:40.1024 v2:43.7408 v3:48.2729 v4:45.5405 v5:56.3704 v6:49.2305 v7:50.0876 v8:45.8935 v9:50.3172 v10:42.5012 v11:43.9174 v12:42.0102 v13:102.6170] eff_attn_scale:[v0:1.9625 v1:2.2375 v2:2.4457 v3:2.2949 v4:2.5156 v5:2.4306 v6:2.4457 v7:2.2372 v8:2.5781 v9:2.4789 v10:2.0636 v11:1.7825 v12:2.1093 v13:4.7665] eff_attn_bias:[v0:0.3190 v1:0.2679 v2:0.2624 v3:0.2776 v4:0.3246 v5:0.3121 v6:0.3259 v7:0.2914 v8:0.3121 v9:0.2638 v10:0.2472 v11:0.2003 v12:0.1851 v13:0.2016] eff_mlp_bias:[v0:0.5138 v1:0.2665 v2:0.2486 v3:0.2486 v4:0.2817 v5:0.2845 v6:0.2293 v7:0.2154 v8:0.2665 v9:0.2555 v10:0.2265 v11:0.2293 v12:0.2154 v13:0.3080] +step:800/20000 val_loss:2.4221 val_bpb:1.4345 train_time:46530ms step_avg:58.16ms +step:1000/20000 train_loss:2.4156 train_time:58180ms step_avg:58.18ms +step:1000 shared0_alpha:mean=0.472,std=0.050 shared1_alpha:mean=0.499,std=0.050 shared2_alpha:mean=0.549,std=0.041 shared3_alpha:mean=0.597,std=0.046 eff_mlp_scale:[v0:61.1837 v1:43.7341 v2:46.6425 v3:51.2242 v4:47.7570 v5:61.1885 v6:53.2016 v7:53.4514 v8:48.4806 v9:54.1283 v10:44.8205 v11:45.8420 v12:44.3200 v13:113.1708] eff_attn_scale:[v0:1.4818 v1:1.7766 v2:1.9695 v3:1.8842 v4:2.1097 v5:1.9724 v6:1.9561 v7:1.8206 v8:2.1373 v9:2.0004 v10:1.6479 v11:1.4323 v12:1.7029 v13:3.9794] eff_attn_bias:[v0:0.3522 v1:0.3094 v2:0.3052 v3:0.3176 v4:0.3674 v5:0.3618 v6:0.3729 v7:0.3356 v8:0.3618 v9:0.3066 v10:0.2817 v11:0.2251 v12:0.2099 v13:0.2403] eff_mlp_bias:[v0:0.6187 v1:0.3149 v2:0.2928 v3:0.2845 v4:0.3204 v5:0.3342 v6:0.2665 v7:0.2486 v8:0.3121 v9:0.3011 v10:0.2638 v11:0.2652 v12:0.2444 v13:0.3218] +step:1000/20000 val_loss:2.3809 val_bpb:1.4101 train_time:58203ms step_avg:58.20ms +step:1200/20000 train_loss:2.4352 train_time:69943ms step_avg:58.29ms +step:1200 shared0_alpha:mean=0.472,std=0.051 
shared1_alpha:mean=0.500,std=0.051 shared2_alpha:mean=0.555,std=0.042 shared3_alpha:mean=0.605,std=0.048 eff_mlp_scale:[v0:66.3767 v1:46.7991 v2:49.4562 v3:54.6259 v4:49.8861 v5:64.9539 v6:56.1495 v7:56.5226 v8:51.3642 v9:57.6919 v10:46.4813 v11:47.7976 v12:46.3756 v13:122.7304] eff_attn_scale:[v0:1.2074 v1:1.5691 v2:1.7080 v3:1.5976 v4:1.7973 v5:1.6977 v6:1.6832 v7:1.5626 v8:1.8350 v9:1.7234 v10:1.4109 v11:1.2128 v12:1.4391 v13:3.4414] eff_attn_bias:[v0:0.3839 v1:0.3494 v2:0.3356 v3:0.3494 v4:0.4088 v5:0.4060 v6:0.4171 v7:0.3729 v8:0.4005 v9:0.3425 v10:0.3135 v11:0.2472 v12:0.2348 v13:0.2748] eff_mlp_bias:[v0:0.7043 v1:0.3563 v2:0.3315 v3:0.3190 v4:0.3563 v5:0.3812 v6:0.2997 v7:0.2790 v8:0.3494 v9:0.3411 v10:0.2928 v11:0.2955 v12:0.2707 v13:0.3301] +step:1200/20000 val_loss:2.3528 val_bpb:1.3935 train_time:69973ms step_avg:58.31ms +step:1400/20000 train_loss:2.4801 train_time:81618ms step_avg:58.30ms +step:1400 shared0_alpha:mean=0.472,std=0.052 shared1_alpha:mean=0.502,std=0.052 shared2_alpha:mean=0.560,std=0.043 shared3_alpha:mean=0.612,std=0.049 eff_mlp_scale:[v0:72.2830 v1:49.7547 v2:52.4096 v3:57.0797 v4:52.0364 v5:68.1289 v6:59.2457 v7:59.0081 v8:53.5447 v9:60.2837 v10:48.4219 v11:49.3662 v12:48.6427 v13:131.3614] eff_attn_scale:[v0:1.0232 v1:1.4309 v2:1.4908 v3:1.4505 v4:1.6296 v5:1.5097 v6:1.4621 v7:1.4115 v8:1.6415 v9:1.5218 v10:1.2213 v11:1.0935 v12:1.2906 v13:3.1542] eff_attn_bias:[v0:0.4005 v1:0.3812 v2:0.3618 v3:0.3784 v4:0.4419 v5:0.4419 v6:0.4530 v7:0.4060 v8:0.4364 v9:0.3729 v10:0.3411 v11:0.2665 v12:0.2555 v13:0.3107] eff_mlp_bias:[v0:0.7844 v1:0.3895 v2:0.3674 v3:0.3522 v4:0.3839 v5:0.4226 v6:0.3328 v7:0.3080 v8:0.3839 v9:0.3757 v10:0.3218 v11:0.3218 v12:0.2942 v13:0.3411] +step:1400/20000 val_loss:2.3323 val_bpb:1.3813 train_time:81644ms step_avg:58.32ms +step:1600/20000 train_loss:2.1539 train_time:93293ms step_avg:58.31ms +step:1600 shared0_alpha:mean=0.472,std=0.052 shared1_alpha:mean=0.502,std=0.054 shared2_alpha:mean=0.564,std=0.043 
shared3_alpha:mean=0.616,std=0.049 eff_mlp_scale:[v0:77.2068 v1:52.6197 v2:54.6162 v3:60.0016 v4:54.0761 v5:71.1419 v6:61.1547 v7:61.5703 v8:55.6101 v9:62.7227 v10:49.6161 v11:50.9818 v12:50.6244 v13:137.8789] eff_attn_scale:[v0:0.9117 v1:1.3274 v2:1.3623 v3:1.3267 v4:1.4827 v5:1.3918 v6:1.3184 v7:1.2947 v8:1.4941 v9:1.4093 v10:1.1041 v11:0.9803 v12:1.1602 v13:2.9223] eff_attn_bias:[v0:0.4198 v1:0.4143 v2:0.3867 v3:0.4005 v4:0.4696 v5:0.4751 v6:0.4861 v7:0.4364 v8:0.4640 v9:0.4005 v10:0.3591 v11:0.2831 v12:0.2721 v13:0.3466] eff_mlp_bias:[v0:0.8673 v1:0.4281 v2:0.4005 v3:0.3784 v4:0.4143 v5:0.4613 v6:0.3618 v7:0.3342 v8:0.4116 v9:0.4088 v10:0.3439 v11:0.3453 v12:0.3163 v13:0.3536] +step:1600/20000 val_loss:2.3174 val_bpb:1.3725 train_time:93316ms step_avg:58.32ms +step:1800/20000 train_loss:2.2514 train_time:104964ms step_avg:58.31ms +step:1800 shared0_alpha:mean=0.471,std=0.053 shared1_alpha:mean=0.503,std=0.055 shared2_alpha:mean=0.568,std=0.044 shared3_alpha:mean=0.620,std=0.050 eff_mlp_scale:[v0:82.5675 v1:55.2673 v2:57.4786 v3:62.0375 v4:56.2972 v5:74.1181 v6:63.3438 v7:63.6282 v8:57.8610 v9:65.1211 v10:51.6134 v11:52.4933 v12:52.7786 v13:145.2590] eff_attn_scale:[v0:0.8132 v1:1.2234 v2:1.2578 v3:1.2500 v4:1.4027 v5:1.2907 v6:1.2102 v7:1.1934 v8:1.3917 v9:1.2851 v10:0.9988 v11:0.9105 v12:1.0684 v13:2.7214] eff_attn_bias:[v0:0.4281 v1:0.4447 v2:0.4088 v3:0.4254 v4:0.4944 v5:0.4999 v6:0.5138 v7:0.4613 v8:0.4972 v9:0.4281 v10:0.3757 v11:0.2997 v12:0.2914 v13:0.3784] eff_mlp_bias:[v0:0.9391 v1:0.4640 v2:0.4281 v3:0.4088 v4:0.4419 v5:0.4999 v6:0.3922 v7:0.3591 v8:0.4419 v9:0.4447 v10:0.3646 v11:0.3674 v12:0.3397 v13:0.3674] +step:1800/20000 val_loss:2.3022 val_bpb:1.3635 train_time:104988ms step_avg:58.33ms +step:2000/20000 train_loss:2.3080 train_time:116622ms step_avg:58.31ms +step:2000 shared0_alpha:mean=0.471,std=0.053 shared1_alpha:mean=0.504,std=0.056 shared2_alpha:mean=0.571,std=0.044 shared3_alpha:mean=0.623,std=0.051 eff_mlp_scale:[v0:86.8814 v1:58.4550 
v2:59.8295 v3:64.5323 v4:58.3652 v5:76.7768 v6:65.3766 v7:66.1456 v8:59.9533 v9:68.0521 v10:53.0937 v11:54.4491 v12:54.7918 v13:151.7888] eff_attn_scale:[v0:0.7393 v1:1.1545 v2:1.1705 v3:1.1791 v4:1.3250 v5:1.2090 v6:1.1346 v7:1.1341 v8:1.3037 v9:1.2035 v10:0.9292 v11:0.8543 v12:1.0004 v13:2.5772] eff_attn_bias:[v0:0.4364 v1:0.4751 v2:0.4309 v3:0.4530 v4:0.5220 v5:0.5303 v6:0.5359 v7:0.4861 v8:0.5248 v9:0.4502 v10:0.3950 v11:0.3149 v12:0.3094 v13:0.4088] eff_mlp_bias:[v0:1.0054 v1:0.5027 v2:0.4558 v3:0.4364 v4:0.4640 v5:0.5359 v6:0.4143 v7:0.3839 v8:0.4668 v9:0.4751 v10:0.3867 v11:0.3895 v12:0.3591 v13:0.3812] +step:2000/20000 val_loss:2.2873 val_bpb:1.3547 train_time:116650ms step_avg:58.33ms +step:2200/20000 train_loss:2.1336 train_time:128285ms step_avg:58.31ms +step:2200 shared0_alpha:mean=0.471,std=0.054 shared1_alpha:mean=0.505,std=0.057 shared2_alpha:mean=0.574,std=0.044 shared3_alpha:mean=0.626,std=0.051 eff_mlp_scale:[v0:91.7298 v1:60.6538 v2:61.6989 v3:66.7003 v4:60.5595 v5:79.6911 v6:67.3079 v7:67.9279 v8:61.7706 v9:70.3938 v10:54.0867 v11:56.0610 v12:56.9259 v13:158.3064] eff_attn_scale:[v0:0.6852 v1:1.1173 v2:1.1291 v3:1.1335 v4:1.2591 v5:1.1764 v6:1.0839 v7:1.0842 v8:1.2434 v9:1.1442 v10:0.8782 v11:0.8181 v12:0.9469 v13:2.4655] eff_attn_bias:[v0:0.4447 v1:0.5027 v2:0.4502 v3:0.4723 v4:0.5497 v5:0.5552 v6:0.5580 v7:0.5055 v8:0.5469 v9:0.4751 v10:0.4088 v11:0.3287 v12:0.3246 v13:0.4392] eff_mlp_bias:[v0:1.0717 v1:0.5441 v2:0.4834 v3:0.4613 v4:0.4861 v5:0.5745 v6:0.4392 v7:0.4033 v8:0.4889 v9:0.5055 v10:0.4060 v11:0.4088 v12:0.3757 v13:0.3950] +step:2200/20000 val_loss:2.2794 val_bpb:1.3500 train_time:128309ms step_avg:58.32ms +step:2400/20000 train_loss:2.2557 train_time:139928ms step_avg:58.30ms +step:2400 shared0_alpha:mean=0.471,std=0.055 shared1_alpha:mean=0.505,std=0.058 shared2_alpha:mean=0.576,std=0.045 shared3_alpha:mean=0.629,std=0.052 eff_mlp_scale:[v0:95.9495 v1:62.7890 v2:64.1400 v3:69.2629 v4:62.2709 v5:82.0742 v6:69.4174 v7:70.5071 
v8:63.9096 v9:72.2074 v10:55.6151 v11:57.6499 v12:58.9935 v13:164.0564] eff_attn_scale:[v0:0.6314 v1:1.0753 v2:1.0769 v3:1.0942 v4:1.2235 v5:1.1067 v6:1.0185 v7:1.0316 v8:1.1827 v9:1.0963 v10:0.8235 v11:0.7809 v12:0.9074 v13:2.4041] eff_attn_bias:[v0:0.4558 v1:0.5331 v2:0.4640 v3:0.4944 v4:0.5690 v5:0.5828 v6:0.5773 v7:0.5276 v8:0.5690 v9:0.4944 v10:0.4198 v11:0.3425 v12:0.3384 v13:0.4668] eff_mlp_bias:[v0:1.1270 v1:0.5800 v2:0.5055 v3:0.4861 v4:0.5082 v5:0.6104 v6:0.4613 v7:0.4254 v8:0.5138 v9:0.5359 v10:0.4198 v11:0.4254 v12:0.3895 v13:0.4116] +step:2400/20000 val_loss:2.2681 val_bpb:1.3433 train_time:139953ms step_avg:58.31ms +step:2600/20000 train_loss:2.4696 train_time:151578ms step_avg:58.30ms +step:2600 shared0_alpha:mean=0.471,std=0.055 shared1_alpha:mean=0.507,std=0.059 shared2_alpha:mean=0.579,std=0.046 shared3_alpha:mean=0.631,std=0.053 eff_mlp_scale:[v0:100.9354 v1:65.5263 v2:66.3584 v3:71.3651 v4:64.1200 v5:84.6382 v6:71.7166 v7:72.2046 v8:66.2019 v9:74.6272 v10:57.2908 v11:59.1910 v12:61.2055 v13:170.1495] eff_attn_scale:[v0:0.6169 v1:1.0437 v2:1.0567 v3:1.0594 v4:1.1669 v5:1.0592 v6:0.9707 v7:0.9884 v8:1.1518 v9:1.0489 v10:0.7842 v11:0.7472 v12:0.8802 v13:2.2889] eff_attn_bias:[v0:0.4585 v1:0.5580 v2:0.4751 v3:0.5110 v4:0.5939 v5:0.6077 v6:0.5939 v7:0.5469 v8:0.5939 v9:0.5193 v10:0.4337 v11:0.3536 v12:0.3536 v13:0.4944] eff_mlp_bias:[v0:1.1767 v1:0.6187 v2:0.5331 v3:0.5138 v4:0.5276 v5:0.6463 v6:0.4806 v7:0.4475 v8:0.5359 v9:0.5662 v10:0.4364 v11:0.4419 v12:0.4060 v13:0.4281] +step:2600/20000 val_loss:2.2826 val_bpb:1.3519 train_time:151602ms step_avg:58.31ms +step:2800/20000 train_loss:2.2895 train_time:163223ms step_avg:58.29ms +step:2800 shared0_alpha:mean=0.470,std=0.055 shared1_alpha:mean=0.508,std=0.059 shared2_alpha:mean=0.580,std=0.045 shared3_alpha:mean=0.632,std=0.053 eff_mlp_scale:[v0:104.8148 v1:68.0631 v2:68.5914 v3:74.0276 v4:66.1570 v5:86.9185 v6:73.5798 v7:74.4531 v8:67.8425 v9:75.8812 v10:58.6145 v11:60.8388 v12:62.7859 v13:175.6817] 
eff_attn_scale:[v0:0.5619 v1:1.0170 v2:1.0163 v3:1.0339 v4:1.1477 v5:1.0425 v6:0.9414 v7:0.9637 v8:1.1131 v9:1.0170 v10:0.7541 v11:0.7298 v12:0.8410 v13:2.2070] eff_attn_bias:[v0:0.4668 v1:0.5856 v2:0.4944 v3:0.5276 v4:0.6132 v5:0.6298 v6:0.6104 v7:0.5635 v8:0.6132 v9:0.5359 v10:0.4475 v11:0.3646 v12:0.3646 v13:0.5220] eff_mlp_bias:[v0:1.2374 v1:0.6574 v2:0.5607 v3:0.5331 v4:0.5469 v5:0.6795 v6:0.4972 v7:0.4668 v8:0.5524 v9:0.5939 v10:0.4530 v11:0.4640 v12:0.4226 v13:0.4419] +step:2800/20000 val_loss:2.2552 val_bpb:1.3357 train_time:163247ms step_avg:58.30ms +step:3000/20000 train_loss:2.2763 train_time:174869ms step_avg:58.29ms +step:3000 shared0_alpha:mean=0.469,std=0.055 shared1_alpha:mean=0.508,std=0.060 shared2_alpha:mean=0.582,std=0.046 shared3_alpha:mean=0.634,std=0.053 eff_mlp_scale:[v0:109.7113 v1:70.3661 v2:70.4094 v3:76.1981 v4:67.8648 v5:89.0061 v6:75.0195 v7:76.1981 v8:69.5721 v9:78.2881 v10:59.5127 v11:61.9917 v12:64.8770 v13:180.8476] eff_attn_scale:[v0:0.5309 v1:0.9894 v2:0.9810 v3:1.0080 v4:1.1000 v5:1.0094 v6:0.9030 v7:0.9525 v8:1.0709 v9:0.9794 v10:0.7197 v11:0.7074 v12:0.8093 v13:2.1481] eff_attn_bias:[v0:0.4723 v1:0.6104 v2:0.5055 v3:0.5469 v4:0.6381 v5:0.6463 v6:0.6242 v7:0.5800 v8:0.6353 v9:0.5580 v10:0.4585 v11:0.3757 v12:0.3784 v13:0.5497] eff_mlp_bias:[v0:1.2872 v1:0.6905 v2:0.5828 v3:0.5607 v4:0.5690 v5:0.7126 v6:0.5193 v7:0.4861 v8:0.5718 v9:0.6215 v10:0.4668 v11:0.4806 v12:0.4364 v13:0.4613] +step:3000/20000 val_loss:2.2472 val_bpb:1.3309 train_time:174891ms step_avg:58.30ms +step:3200/20000 train_loss:2.2420 train_time:186507ms step_avg:58.28ms +step:3200 shared0_alpha:mean=0.468,std=0.055 shared1_alpha:mean=0.509,std=0.061 shared2_alpha:mean=0.583,std=0.046 shared3_alpha:mean=0.636,std=0.054 eff_mlp_scale:[v0:113.5696 v1:72.0965 v2:72.5024 v3:78.3906 v4:69.1935 v5:91.4165 v6:76.7424 v7:77.9551 v8:71.3558 v9:80.1072 v10:61.0547 v11:63.5835 v12:66.5987 v13:185.1360] eff_attn_scale:[v0:0.5115 v1:0.9724 v2:0.9685 v3:0.9899 v4:1.0614 
v5:0.9724 v6:0.8776 v7:0.9214 v8:1.0235 v9:0.9528 v10:0.7003 v11:0.6934 v12:0.7819 v13:2.1071] eff_attn_bias:[v0:0.4751 v1:0.6353 v2:0.5165 v3:0.5607 v4:0.6519 v5:0.6684 v6:0.6408 v7:0.5966 v8:0.6546 v9:0.5718 v10:0.4696 v11:0.3839 v12:0.3895 v13:0.5718] eff_mlp_bias:[v0:1.3479 v1:0.7292 v2:0.5994 v3:0.5828 v4:0.5883 v5:0.7403 v6:0.5386 v7:0.5027 v8:0.5939 v9:0.6463 v10:0.4806 v11:0.4972 v12:0.4475 v13:0.4751] +step:3200/20000 val_loss:2.2433 val_bpb:1.3286 train_time:186530ms step_avg:58.29ms +step:3400/20000 train_loss:2.2098 train_time:198149ms step_avg:58.28ms +step:3400 shared0_alpha:mean=0.468,std=0.056 shared1_alpha:mean=0.509,std=0.062 shared2_alpha:mean=0.584,std=0.046 shared3_alpha:mean=0.637,std=0.055 eff_mlp_scale:[v0:117.8257 v1:74.7840 v2:74.6499 v3:80.5769 v4:71.5085 v5:93.3610 v6:78.9401 v7:80.1366 v8:73.2633 v9:81.9290 v10:62.6373 v11:65.1660 v12:68.8763 v13:190.3966] eff_attn_scale:[v0:0.4906 v1:0.9619 v2:0.9355 v3:0.9632 v4:1.0553 v5:0.9619 v6:0.8513 v7:0.8915 v8:1.0130 v9:0.9280 v10:0.6739 v11:0.6765 v12:0.7739 v13:2.0351] eff_attn_bias:[v0:0.4751 v1:0.6602 v2:0.5303 v3:0.5773 v4:0.6740 v5:0.6850 v6:0.6574 v7:0.6132 v8:0.6740 v9:0.5911 v10:0.4778 v11:0.3950 v12:0.4005 v13:0.5966] eff_mlp_bias:[v0:1.3976 v1:0.7623 v2:0.6215 v3:0.6077 v4:0.6077 v5:0.7734 v6:0.5552 v7:0.5220 v8:0.6132 v9:0.6712 v10:0.4944 v11:0.5165 v12:0.4613 v13:0.4917] +step:3400/20000 val_loss:2.2411 val_bpb:1.3273 train_time:198171ms step_avg:58.29ms +step:3600/20000 train_loss:2.1758 train_time:209773ms step_avg:58.27ms +step:3600 shared0_alpha:mean=0.467,std=0.055 shared1_alpha:mean=0.510,std=0.062 shared2_alpha:mean=0.586,std=0.047 shared3_alpha:mean=0.639,std=0.055 eff_mlp_scale:[v0:121.9862 v1:77.3978 v2:76.6493 v3:82.6708 v4:73.3026 v5:95.1849 v6:80.5467 v7:81.7818 v8:75.0796 v9:83.6473 v10:63.6579 v11:66.6700 v12:70.6370 v13:194.5999] eff_attn_scale:[v0:0.4580 v1:0.9376 v2:0.9304 v3:0.9511 v4:1.0396 v5:0.9424 v6:0.8374 v7:0.8842 v8:1.0068 v9:0.9086 v10:0.6557 v11:0.6564 
v12:0.7539 v13:2.0003] eff_attn_bias:[v0:0.4834 v1:0.6822 v2:0.5386 v3:0.5966 v4:0.6933 v5:0.7071 v6:0.6684 v7:0.6270 v8:0.6905 v9:0.6049 v10:0.4889 v11:0.4033 v12:0.4116 v13:0.6187] eff_mlp_bias:[v0:1.4474 v1:0.7955 v2:0.6463 v3:0.6325 v4:0.6270 v5:0.8065 v6:0.5745 v7:0.5386 v8:0.6298 v9:0.6961 v10:0.5082 v11:0.5303 v12:0.4723 v13:0.5082] +step:3600/20000 val_loss:2.2322 val_bpb:1.3220 train_time:209799ms step_avg:58.28ms +step:3800/20000 train_loss:2.2743 train_time:221414ms step_avg:58.27ms +step:3800 shared0_alpha:mean=0.466,std=0.056 shared1_alpha:mean=0.511,std=0.063 shared2_alpha:mean=0.588,std=0.047 shared3_alpha:mean=0.640,std=0.055 eff_mlp_scale:[v0:125.8974 v1:79.2326 v2:78.6812 v3:84.7338 v4:74.9775 v5:97.7040 v6:82.1781 v7:83.3888 v8:76.7734 v9:85.0656 v10:65.1305 v11:67.6974 v12:72.2837 v13:199.0901] eff_attn_scale:[v0:0.4541 v1:0.9260 v2:0.9206 v3:0.9283 v4:1.0213 v5:0.9213 v6:0.8281 v7:0.8750 v8:0.9797 v9:0.8879 v10:0.6475 v11:0.6440 v12:0.7348 v13:1.9766] eff_attn_bias:[v0:0.4861 v1:0.7071 v2:0.5552 v3:0.6104 v4:0.7071 v5:0.7292 v6:0.6822 v7:0.6408 v8:0.7071 v9:0.6187 v10:0.4972 v11:0.4116 v12:0.4254 v13:0.6381] eff_mlp_bias:[v0:1.4916 v1:0.8286 v2:0.6657 v3:0.6519 v4:0.6436 v5:0.8342 v6:0.5911 v7:0.5552 v8:0.6436 v9:0.7182 v10:0.5248 v11:0.5497 v12:0.4861 v13:0.5248] +step:3800/20000 val_loss:2.2279 val_bpb:1.3195 train_time:221439ms step_avg:58.27ms +step:4000/20000 train_loss:2.2130 train_time:233046ms step_avg:58.26ms +step:4000 shared0_alpha:mean=0.465,std=0.056 shared1_alpha:mean=0.511,std=0.064 shared2_alpha:mean=0.589,std=0.048 shared3_alpha:mean=0.641,std=0.055 eff_mlp_scale:[v0:129.6003 v1:81.0736 v2:80.3315 v3:86.6120 v4:77.3645 v5:99.7451 v6:83.8626 v7:85.7050 v8:78.7297 v9:87.4612 v10:66.6487 v11:69.3803 v12:74.1789 v13:204.0163] eff_attn_scale:[v0:0.4408 v1:0.9129 v2:0.8926 v3:0.9195 v4:1.0111 v5:0.9034 v6:0.8103 v7:0.8535 v8:0.9696 v9:0.8656 v10:0.6326 v11:0.6335 v12:0.7295 v13:1.9228] eff_attn_bias:[v0:0.4917 v1:0.7292 v2:0.5662 
v3:0.6298 v4:0.7292 v5:0.7458 v6:0.6933 v7:0.6546 v8:0.7237 v9:0.6353 v10:0.5055 v11:0.4198 v12:0.4364 v13:0.6574] eff_mlp_bias:[v0:1.5357 v1:0.8507 v2:0.6905 v3:0.6740 v4:0.6602 v5:0.8563 v6:0.6049 v7:0.5718 v8:0.6629 v9:0.7458 v10:0.5386 v11:0.5635 v12:0.4999 v13:0.5386] +step:4000/20000 val_loss:2.2235 val_bpb:1.3169 train_time:233071ms step_avg:58.27ms +step:4200/20000 train_loss:2.2275 train_time:244731ms step_avg:58.27ms +step:4200 shared0_alpha:mean=0.464,std=0.057 shared1_alpha:mean=0.512,std=0.064 shared2_alpha:mean=0.591,std=0.048 shared3_alpha:mean=0.643,std=0.056 eff_mlp_scale:[v0:133.5793 v1:83.4057 v2:82.3119 v3:88.9017 v4:78.7151 v5:101.7749 v6:85.4264 v7:87.5269 v8:80.0961 v9:89.3633 v10:67.6292 v11:71.0297 v12:76.4135 v13:208.1438] eff_attn_scale:[v0:0.4222 v1:0.8918 v2:0.8931 v3:0.9215 v4:1.0029 v5:0.8824 v6:0.7896 v7:0.8385 v8:0.9484 v9:0.8544 v10:0.6213 v11:0.6245 v12:0.7125 v13:1.9006] eff_attn_bias:[v0:0.4917 v1:0.7513 v2:0.5773 v3:0.6436 v4:0.7458 v5:0.7623 v6:0.7071 v7:0.6712 v8:0.7403 v9:0.6519 v10:0.5138 v11:0.4309 v12:0.4502 v13:0.6767] eff_mlp_bias:[v0:1.5799 v1:0.8839 v2:0.7126 v3:0.6961 v4:0.6767 v5:0.8839 v6:0.6215 v7:0.5883 v8:0.6767 v9:0.7679 v10:0.5497 v11:0.5773 v12:0.5138 v13:0.5552] +step:4200/20000 val_loss:2.2197 val_bpb:1.3146 train_time:244754ms step_avg:58.27ms +step:4400/20000 train_loss:2.1672 train_time:256350ms step_avg:58.26ms +step:4400 shared0_alpha:mean=0.464,std=0.057 shared1_alpha:mean=0.513,std=0.065 shared2_alpha:mean=0.592,std=0.048 shared3_alpha:mean=0.644,std=0.057 eff_mlp_scale:[v0:138.2028 v1:85.1583 v2:84.3040 v3:90.6235 v4:80.9846 v5:103.6927 v6:86.9945 v7:89.2364 v8:81.9154 v9:90.6685 v10:68.6091 v11:72.1289 v12:78.1920 v13:212.3267] eff_attn_scale:[v0:0.4096 v1:0.8973 v2:0.8827 v3:0.9006 v4:0.9955 v5:0.8880 v6:0.7923 v7:0.8266 v8:0.9322 v9:0.8459 v10:0.6157 v11:0.6221 v12:0.7059 v13:1.8692] eff_attn_bias:[v0:0.4972 v1:0.7789 v2:0.5939 v3:0.6574 v4:0.7623 v5:0.7734 v6:0.7182 v7:0.6850 v8:0.7568 v9:0.6684 
v10:0.5220 v11:0.4364 v12:0.4613 v13:0.6961] eff_mlp_bias:[v0:1.6241 v1:0.9170 v2:0.7347 v3:0.7126 v4:0.6905 v5:0.9115 v6:0.6353 v7:0.6077 v8:0.6961 v9:0.7900 v10:0.5635 v11:0.5911 v12:0.5248 v13:0.5662] +step:4400/20000 val_loss:2.2203 val_bpb:1.3150 train_time:256373ms step_avg:58.27ms +step:4600/20000 train_loss:2.0329 train_time:267978ms step_avg:58.26ms +step:4600 shared0_alpha:mean=0.464,std=0.057 shared1_alpha:mean=0.513,std=0.065 shared2_alpha:mean=0.593,std=0.049 shared3_alpha:mean=0.644,std=0.057 eff_mlp_scale:[v0:141.6863 v1:88.0264 v2:86.0561 v3:92.8671 v4:82.8886 v5:105.7329 v6:88.7737 v7:91.0004 v8:83.3596 v9:92.5795 v10:70.2037 v11:73.7337 v12:80.0629 v13:216.3305] eff_attn_scale:[v0:0.3943 v1:0.8868 v2:0.8717 v3:0.8974 v4:0.9763 v5:0.8637 v6:0.7696 v7:0.8197 v8:0.9271 v9:0.8268 v10:0.5953 v11:0.6083 v12:0.6897 v13:1.8340] eff_attn_bias:[v0:0.4917 v1:0.8010 v2:0.6021 v3:0.6712 v4:0.7844 v5:0.7844 v6:0.7292 v7:0.6988 v8:0.7679 v9:0.6822 v10:0.5331 v11:0.4475 v12:0.4723 v13:0.7182] eff_mlp_bias:[v0:1.6573 v1:0.9447 v2:0.7568 v3:0.7347 v4:0.7043 v5:0.9336 v6:0.6491 v7:0.6187 v8:0.7126 v9:0.8065 v10:0.5773 v11:0.6104 v12:0.5386 v13:0.5828] +step:4600/20000 val_loss:2.2161 val_bpb:1.3125 train_time:268002ms step_avg:58.26ms +step:4800/20000 train_loss:2.3192 train_time:279599ms step_avg:58.25ms +step:4800 shared0_alpha:mean=0.463,std=0.058 shared1_alpha:mean=0.513,std=0.066 shared2_alpha:mean=0.594,std=0.049 shared3_alpha:mean=0.645,std=0.057 eff_mlp_scale:[v0:146.1482 v1:89.8214 v2:88.1127 v3:95.2277 v4:84.6897 v5:108.1940 v6:90.3955 v7:92.8705 v8:85.1655 v9:94.4146 v10:71.6772 v11:74.9564 v12:81.8350 v13:220.5214] eff_attn_scale:[v0:0.3853 v1:0.8754 v2:0.8745 v3:0.9010 v4:0.9753 v5:0.8435 v6:0.7636 v7:0.8148 v8:0.9216 v9:0.8116 v10:0.6015 v11:0.6079 v12:0.6890 v13:1.8059] eff_attn_bias:[v0:0.4972 v1:0.8176 v2:0.6132 v3:0.6850 v4:0.7955 v5:0.8065 v6:0.7458 v7:0.7071 v8:0.7844 v9:0.6961 v10:0.5386 v11:0.4530 v12:0.4806 v13:0.7347] eff_mlp_bias:[v0:1.7015 
v1:0.9723 v2:0.7734 v3:0.7513 v4:0.7182 v5:0.9612 v6:0.6602 v7:0.6325 v8:0.7292 v9:0.8286 v10:0.5883 v11:0.6215 v12:0.5469 v13:0.5966] +step:4800/20000 val_loss:2.2111 val_bpb:1.3095 train_time:279622ms step_avg:58.25ms +step:5000/20000 train_loss:2.0818 train_time:291224ms step_avg:58.24ms +step:5000 shared0_alpha:mean=0.462,std=0.058 shared1_alpha:mean=0.514,std=0.066 shared2_alpha:mean=0.595,std=0.049 shared3_alpha:mean=0.646,std=0.058 eff_mlp_scale:[v0:150.2954 v1:91.7647 v2:89.7718 v3:96.9222 v4:86.5062 v5:110.3238 v6:92.0736 v7:94.0716 v8:86.9868 v9:95.8889 v10:72.7382 v11:76.4925 v12:83.6227 v13:224.0480] eff_attn_scale:[v0:0.3817 v1:0.8782 v2:0.8624 v3:0.9010 v4:0.9747 v5:0.8644 v6:0.7647 v7:0.8105 v8:0.9124 v9:0.8138 v10:0.5905 v11:0.5992 v12:0.6721 v13:1.7829] eff_attn_bias:[v0:0.5027 v1:0.8397 v2:0.6270 v3:0.6988 v4:0.8176 v5:0.8176 v6:0.7513 v7:0.7237 v8:0.8010 v9:0.7126 v10:0.5469 v11:0.4640 v12:0.4889 v13:0.7513] eff_mlp_bias:[v0:1.7346 v1:0.9999 v2:0.7955 v3:0.7734 v4:0.7403 v5:0.9833 v6:0.6767 v7:0.6519 v8:0.7403 v9:0.8452 v10:0.6021 v11:0.6353 v12:0.5607 v13:0.6077] +step:5000/20000 val_loss:2.2067 val_bpb:1.3069 train_time:291248ms step_avg:58.25ms +step:5200/20000 train_loss:2.2266 train_time:302850ms step_avg:58.24ms +step:5200 shared0_alpha:mean=0.462,std=0.058 shared1_alpha:mean=0.515,std=0.066 shared2_alpha:mean=0.596,std=0.049 shared3_alpha:mean=0.647,std=0.058 eff_mlp_scale:[v0:154.6990 v1:94.1810 v2:92.0682 v3:99.3439 v4:87.9857 v5:112.3928 v6:93.4632 v7:95.9844 v8:88.9579 v9:97.8233 v10:73.9336 v11:77.7474 v12:85.5552 v13:228.1575] eff_attn_scale:[v0:0.3698 v1:0.8814 v2:0.8589 v3:0.8960 v4:0.9607 v5:0.8449 v6:0.7536 v7:0.8103 v8:0.8946 v9:0.8038 v10:0.5852 v11:0.5916 v12:0.6786 v13:1.7610] eff_attn_bias:[v0:0.4999 v1:0.8618 v2:0.6353 v3:0.7182 v4:0.8286 v5:0.8342 v6:0.7623 v7:0.7347 v8:0.8121 v9:0.7237 v10:0.5552 v11:0.4723 v12:0.4999 v13:0.7734] eff_mlp_bias:[v0:1.7788 v1:1.0330 v2:0.8176 v3:0.7955 v4:0.7568 v5:1.0054 v6:0.6933 v7:0.6684 
v8:0.7568 v9:0.8673 v10:0.6160 v11:0.6519 v12:0.5718 v13:0.6187] +step:5200/20000 val_loss:2.2079 val_bpb:1.3077 train_time:302876ms step_avg:58.25ms +step:5400/20000 train_loss:2.2403 train_time:314473ms step_avg:58.24ms +step:5400 shared0_alpha:mean=0.460,std=0.058 shared1_alpha:mean=0.515,std=0.067 shared2_alpha:mean=0.598,std=0.050 shared3_alpha:mean=0.648,std=0.058 eff_mlp_scale:[v0:157.8568 v1:96.0263 v2:94.3077 v3:101.1196 v4:90.4247 v5:114.3920 v6:95.2461 v7:97.7328 v8:90.9161 v9:99.6995 v10:75.0708 v11:78.8636 v12:87.4760 v13:232.1821] eff_attn_scale:[v0:0.3578 v1:0.8820 v2:0.8590 v3:0.8994 v4:0.9654 v5:0.8409 v6:0.7537 v7:0.8047 v8:0.8990 v9:0.7997 v10:0.5853 v11:0.5939 v12:0.6687 v13:1.7455] eff_attn_bias:[v0:0.5055 v1:0.8839 v2:0.6491 v3:0.7347 v4:0.8507 v5:0.8563 v6:0.7734 v7:0.7458 v8:0.8231 v9:0.7347 v10:0.5635 v11:0.4806 v12:0.5110 v13:0.7844] eff_mlp_bias:[v0:1.8120 v1:1.0607 v2:0.8342 v3:0.8121 v4:0.7734 v5:1.0275 v6:0.7071 v7:0.6822 v8:0.7734 v9:0.8839 v10:0.6270 v11:0.6657 v12:0.5828 v13:0.6325] +step:5400/20000 val_loss:2.2023 val_bpb:1.3043 train_time:314497ms step_avg:58.24ms +step:5600/20000 train_loss:2.2394 train_time:326102ms step_avg:58.23ms +step:5600 shared0_alpha:mean=0.460,std=0.058 shared1_alpha:mean=0.516,std=0.068 shared2_alpha:mean=0.598,std=0.050 shared3_alpha:mean=0.649,std=0.059 eff_mlp_scale:[v0:162.9100 v1:98.4870 v2:95.9774 v3:103.5931 v4:91.8259 v5:115.9605 v6:96.9229 v7:99.6839 v8:92.3222 v9:101.1345 v10:76.1200 v11:80.6267 v12:89.3441 v13:236.1867] eff_attn_scale:[v0:0.3546 v1:0.8712 v2:0.8550 v3:0.8878 v4:0.9622 v5:0.8440 v6:0.7455 v7:0.7982 v8:0.9045 v9:0.7895 v10:0.5812 v11:0.5848 v12:0.6651 v13:1.7124] eff_attn_bias:[v0:0.5082 v1:0.9060 v2:0.6629 v3:0.7458 v4:0.8618 v5:0.8673 v6:0.7844 v7:0.7623 v8:0.8397 v9:0.7513 v10:0.5690 v11:0.4861 v12:0.5193 v13:0.8010] eff_mlp_bias:[v0:1.8562 v1:1.0938 v2:0.8507 v3:0.8286 v4:0.7900 v5:1.0496 v6:0.7237 v7:0.6988 v8:0.7900 v9:0.9005 v10:0.6381 v11:0.6767 v12:0.5966 v13:0.6463] 
+step:5600/20000 val_loss:2.2029 val_bpb:1.3047 train_time:326122ms step_avg:58.24ms +step:5800/20000 train_loss:2.2014 train_time:337723ms step_avg:58.23ms +step:5800 shared0_alpha:mean=0.459,std=0.058 shared1_alpha:mean=0.516,std=0.068 shared2_alpha:mean=0.599,std=0.051 shared3_alpha:mean=0.650,std=0.059 eff_mlp_scale:[v0:167.3165 v1:101.0337 v2:98.1600 v3:105.4956 v4:93.8135 v5:118.1399 v6:98.6365 v7:102.0448 v8:94.3151 v9:103.1720 v10:77.6703 v11:81.8330 v12:91.3051 v13:240.0715] eff_attn_scale:[v0:0.3462 v1:0.8691 v2:0.8590 v3:0.8740 v4:0.9678 v5:0.8421 v6:0.7459 v7:0.7937 v8:0.8926 v9:0.7880 v10:0.5741 v11:0.5784 v12:0.6673 v13:1.7264] eff_attn_bias:[v0:0.5082 v1:0.9226 v2:0.6712 v3:0.7623 v4:0.8784 v5:0.8784 v6:0.7955 v7:0.7679 v8:0.8507 v9:0.7623 v10:0.5745 v11:0.4917 v12:0.5303 v13:0.8176] eff_mlp_bias:[v0:1.8893 v1:1.1214 v2:0.8673 v3:0.8452 v4:0.8065 v5:1.0662 v6:0.7347 v7:0.7126 v8:0.8010 v9:0.9226 v10:0.6519 v11:0.6905 v12:0.6104 v13:0.6574] +step:5800/20000 val_loss:2.2017 val_bpb:1.3040 train_time:337748ms step_avg:58.23ms +step:6000/20000 train_loss:2.2710 train_time:349346ms step_avg:58.22ms +step:6000 shared0_alpha:mean=0.458,std=0.059 shared1_alpha:mean=0.517,std=0.069 shared2_alpha:mean=0.600,std=0.051 shared3_alpha:mean=0.650,std=0.060 eff_mlp_scale:[v0:171.6976 v1:102.7154 v2:99.8527 v3:107.4257 v4:95.7057 v5:119.9242 v6:100.3327 v7:103.4470 v8:96.2120 v9:104.3287 v10:78.7300 v11:83.5534 v12:93.1738 v13:241.9763] eff_attn_scale:[v0:0.3419 v1:0.8684 v2:0.8307 v3:0.8766 v4:0.9526 v5:0.8234 v6:0.7320 v7:0.7839 v8:0.8908 v9:0.7784 v10:0.5593 v11:0.5816 v12:0.6527 v13:1.7257] eff_attn_bias:[v0:0.5055 v1:0.9447 v2:0.6850 v3:0.7789 v4:0.8949 v5:0.8894 v6:0.8065 v7:0.7844 v8:0.8618 v9:0.7789 v10:0.5828 v11:0.4999 v12:0.5414 v13:0.8342] eff_mlp_bias:[v0:1.9114 v1:1.1435 v2:0.8839 v3:0.8618 v4:0.8231 v5:1.0883 v6:0.7513 v7:0.7292 v8:0.8176 v9:0.9391 v10:0.6629 v11:0.7071 v12:0.6187 v13:0.6657] +step:6000/20000 val_loss:2.1962 val_bpb:1.3007 
train_time:349369ms step_avg:58.23ms +step:6200/20000 train_loss:2.1454 train_time:360965ms step_avg:58.22ms +step:6200 shared0_alpha:mean=0.457,std=0.059 shared1_alpha:mean=0.517,std=0.069 shared2_alpha:mean=0.602,std=0.051 shared3_alpha:mean=0.651,std=0.060 eff_mlp_scale:[v0:175.0116 v1:105.2896 v2:101.6530 v3:109.9377 v4:98.0033 v5:122.1142 v6:102.1371 v7:105.4197 v8:97.4929 v9:106.3751 v10:80.3543 v11:84.8377 v12:94.9407 v13:245.9515] eff_attn_scale:[v0:0.3382 v1:0.8550 v2:0.8500 v3:0.8855 v4:0.9613 v5:0.8192 v6:0.7256 v7:0.7923 v8:0.8735 v9:0.7744 v10:0.5597 v11:0.5762 v12:0.6585 v13:1.6425] eff_attn_bias:[v0:0.5082 v1:0.9667 v2:0.6933 v3:0.7900 v4:0.9115 v5:0.9060 v6:0.8176 v7:0.7955 v8:0.8784 v9:0.7900 v10:0.5911 v11:0.5082 v12:0.5497 v13:0.8507] eff_mlp_bias:[v0:1.9445 v1:1.1767 v2:0.9005 v3:0.8784 v4:0.8397 v5:1.1049 v6:0.7679 v7:0.7403 v8:0.8286 v9:0.9557 v10:0.6767 v11:0.7182 v12:0.6298 v13:0.6767] +step:6200/20000 val_loss:2.1965 val_bpb:1.3009 train_time:360988ms step_avg:58.22ms +step:6400/20000 train_loss:2.2175 train_time:372582ms step_avg:58.22ms +step:6400 shared0_alpha:mean=0.456,std=0.058 shared1_alpha:mean=0.518,std=0.069 shared2_alpha:mean=0.603,std=0.051 shared3_alpha:mean=0.652,std=0.060 eff_mlp_scale:[v0:180.4244 v1:107.2598 v2:103.9580 v3:112.3079 v4:100.1625 v5:124.7717 v6:103.4699 v7:107.2490 v8:99.1299 v9:107.8071 v10:81.5070 v11:86.5074 v12:96.5484 v13:249.7608] eff_attn_scale:[v0:0.3334 v1:0.8564 v2:0.8440 v3:0.8799 v4:0.9625 v5:0.8293 v6:0.7205 v7:0.7752 v8:0.8834 v9:0.7752 v10:0.5558 v11:0.5699 v12:0.6592 v13:1.6938] eff_attn_bias:[v0:0.5110 v1:0.9888 v2:0.7043 v3:0.8065 v4:0.9226 v5:0.9226 v6:0.8286 v7:0.8065 v8:0.8894 v9:0.8010 v10:0.5966 v11:0.5138 v12:0.5607 v13:0.8673] eff_mlp_bias:[v0:1.9777 v1:1.1988 v2:0.9170 v3:0.8949 v4:0.8563 v5:1.1270 v6:0.7844 v7:0.7568 v8:0.8452 v9:0.9723 v10:0.6878 v11:0.7292 v12:0.6381 v13:0.6878] +step:6400/20000 val_loss:2.1930 val_bpb:1.2988 train_time:372606ms step_avg:58.22ms +step:6600/20000 
train_loss:2.1773 train_time:384207ms step_avg:58.21ms +step:6600 shared0_alpha:mean=0.455,std=0.059 shared1_alpha:mean=0.518,std=0.070 shared2_alpha:mean=0.604,std=0.051 shared3_alpha:mean=0.653,std=0.061 eff_mlp_scale:[v0:184.5248 v1:109.6390 v2:106.1417 v3:114.3012 v4:102.0784 v5:126.1675 v6:105.1589 v7:109.1985 v8:101.5576 v9:109.6390 v10:82.5546 v11:87.7670 v12:98.4328 v13:251.9682] eff_attn_scale:[v0:0.3302 v1:0.8617 v2:0.8455 v3:0.8881 v4:0.9554 v5:0.8123 v6:0.7295 v7:0.7824 v8:0.8852 v9:0.7675 v10:0.5595 v11:0.5794 v12:0.6530 v13:1.6739] eff_attn_bias:[v0:0.5110 v1:1.0109 v2:0.7237 v3:0.8176 v4:0.9391 v5:0.9336 v6:0.8397 v7:0.8176 v8:0.9005 v9:0.8121 v10:0.6077 v11:0.5248 v12:0.5690 v13:0.8839] eff_mlp_bias:[v0:2.0108 v1:1.2319 v2:0.9281 v3:0.9170 v4:0.8673 v5:1.1490 v6:0.7955 v7:0.7679 v8:0.8563 v9:0.9888 v10:0.7016 v11:0.7458 v12:0.6491 v13:0.7016] +step:6600/20000 val_loss:2.1896 val_bpb:1.2968 train_time:384230ms step_avg:58.22ms +step:6800/20000 train_loss:2.2448 train_time:395823ms step_avg:58.21ms +step:6800 shared0_alpha:mean=0.454,std=0.059 shared1_alpha:mean=0.518,std=0.070 shared2_alpha:mean=0.604,std=0.052 shared3_alpha:mean=0.653,std=0.061 eff_mlp_scale:[v0:189.8520 v1:111.0026 v2:107.3284 v3:116.3316 v4:104.1774 v5:128.2080 v6:106.8338 v7:110.6695 v8:103.1251 v9:111.0026 v10:83.5875 v11:89.0503 v12:100.4943 v13:255.8501] eff_attn_scale:[v0:0.3232 v1:0.8452 v2:0.8427 v3:0.8791 v4:0.9610 v5:0.8049 v6:0.7153 v7:0.7740 v8:0.8689 v9:0.7602 v10:0.5467 v11:0.5637 v12:0.6451 v13:1.6261] eff_attn_bias:[v0:0.5082 v1:1.0330 v2:0.7403 v3:0.8397 v4:0.9612 v5:0.9447 v6:0.8507 v7:0.8342 v8:0.9115 v9:0.8286 v10:0.6160 v11:0.5331 v12:0.5800 v13:0.8949] eff_mlp_bias:[v0:2.0440 v1:1.2595 v2:0.9447 v3:0.9336 v4:0.8839 v5:1.1656 v6:0.8121 v7:0.7844 v8:0.8673 v9:1.0054 v10:0.7126 v11:0.7568 v12:0.6602 v13:0.7126] +step:6800/20000 val_loss:2.1883 val_bpb:1.2960 train_time:395848ms step_avg:58.21ms +step:7000/20000 train_loss:2.2819 train_time:407445ms 
step_avg:58.21ms +step:7000 shared0_alpha:mean=0.453,std=0.059 shared1_alpha:mean=0.519,std=0.071 shared2_alpha:mean=0.605,std=0.052 shared3_alpha:mean=0.654,std=0.061 eff_mlp_scale:[v0:193.2339 v1:112.9361 v2:110.1127 v3:118.3858 v4:106.0918 v5:130.2678 v6:108.1197 v7:112.6742 v8:104.5004 v9:112.3770 v10:84.7021 v11:90.3471 v12:102.3786 v13:259.7267] eff_attn_scale:[v0:0.3151 v1:0.8584 v2:0.8519 v3:0.8886 v4:0.9675 v5:0.8137 v6:0.7161 v7:0.7833 v8:0.8796 v9:0.7556 v10:0.5474 v11:0.5727 v12:0.6509 v13:1.6247] eff_attn_bias:[v0:0.5138 v1:1.0551 v2:0.7458 v3:0.8507 v4:0.9667 v5:0.9612 v6:0.8563 v7:0.8397 v8:0.9226 v9:0.8397 v10:0.6215 v11:0.5414 v12:0.5883 v13:0.9060] eff_mlp_bias:[v0:2.0661 v1:1.2872 v2:0.9612 v3:0.9447 v4:0.9005 v5:1.1767 v6:0.8231 v7:0.8010 v8:0.8839 v9:1.0165 v10:0.7237 v11:0.7679 v12:0.6712 v13:0.7237] +step:7000/20000 val_loss:2.1869 val_bpb:1.2952 train_time:407468ms step_avg:58.21ms +step:7200/20000 train_loss:2.2551 train_time:419064ms step_avg:58.20ms +step:7200 shared0_alpha:mean=0.452,std=0.060 shared1_alpha:mean=0.519,std=0.071 shared2_alpha:mean=0.606,std=0.053 shared3_alpha:mean=0.655,std=0.061 eff_mlp_scale:[v0:196.9550 v1:115.5197 v2:111.8127 v3:120.9402 v4:107.5480 v5:132.4250 v6:109.8071 v7:114.6576 v8:106.4779 v9:114.3927 v10:85.7398 v11:92.1449 v12:104.3377 v13:263.6886] eff_attn_scale:[v0:0.3146 v1:0.8481 v2:0.8424 v3:0.8945 v4:0.9760 v5:0.8170 v6:0.7109 v7:0.7965 v8:0.8836 v9:0.7549 v10:0.5506 v11:0.5793 v12:0.6550 v13:1.6263] eff_attn_bias:[v0:0.5193 v1:1.0717 v2:0.7568 v3:0.8673 v4:0.9833 v5:0.9778 v6:0.8673 v7:0.8507 v8:0.9336 v9:0.8507 v10:0.6298 v11:0.5497 v12:0.5994 v13:0.9226] eff_mlp_bias:[v0:2.0992 v1:1.3093 v2:0.9723 v3:0.9612 v4:0.9115 v5:1.1932 v6:0.8342 v7:0.8121 v8:0.9005 v9:1.0330 v10:0.7347 v11:0.7844 v12:0.6822 v13:0.7347] +step:7200/20000 val_loss:2.1875 val_bpb:1.2956 train_time:419088ms step_avg:58.21ms +step:7400/20000 train_loss:2.1738 train_time:430685ms step_avg:58.20ms +step:7400 
shared0_alpha:mean=0.451,std=0.060 shared1_alpha:mean=0.520,std=0.071 shared2_alpha:mean=0.608,std=0.053 shared3_alpha:mean=0.656,std=0.062 eff_mlp_scale:[v0:202.5407 v1:117.4239 v2:113.8027 v3:122.8780 v4:109.5443 v5:134.4418 v6:111.7796 v7:116.5496 v8:107.9254 v9:115.7221 v10:87.5017 v11:93.3451 v12:106.3066 v13:265.6677] eff_attn_scale:[v0:0.3048 v1:0.8437 v2:0.8441 v3:0.8701 v4:0.9609 v5:0.7953 v6:0.7095 v7:0.7702 v8:0.8780 v9:0.7470 v10:0.5382 v11:0.5579 v12:0.6465 v13:1.6149] eff_attn_bias:[v0:0.5138 v1:1.0883 v2:0.7679 v3:0.8784 v4:0.9944 v5:0.9944 v6:0.8784 v7:0.8673 v8:0.9502 v9:0.8618 v10:0.6325 v11:0.5552 v12:0.6077 v13:0.9336] eff_mlp_bias:[v0:2.1213 v1:1.3369 v2:0.9833 v3:0.9778 v4:0.9281 v5:1.2153 v6:0.8507 v7:0.8231 v8:0.9115 v9:1.0496 v10:0.7458 v11:0.7955 v12:0.6933 v13:0.7458] +step:7400/20000 val_loss:2.1835 val_bpb:1.2932 train_time:430709ms step_avg:58.20ms +step:7600/20000 train_loss:2.0562 train_time:442307ms step_avg:58.20ms +step:7600 shared0_alpha:mean=0.449,std=0.060 shared1_alpha:mean=0.520,std=0.072 shared2_alpha:mean=0.608,std=0.053 shared3_alpha:mean=0.656,std=0.063 eff_mlp_scale:[v0:206.5195 v1:119.8887 v2:115.5711 v3:125.0262 v4:111.6028 v5:136.4447 v6:113.0255 v7:118.1098 v8:109.9696 v9:117.0342 v10:88.5876 v11:94.7007 v12:107.7919 v13:269.5365] eff_attn_scale:[v0:0.3014 v1:0.8572 v2:0.8527 v3:0.8816 v4:0.9618 v5:0.8086 v6:0.7209 v7:0.7766 v8:0.8700 v9:0.7467 v10:0.5479 v11:0.5667 v12:0.6427 v13:1.5986] eff_attn_bias:[v0:0.5193 v1:1.1159 v2:0.7844 v3:0.8949 v4:1.0109 v5:0.9999 v6:0.8839 v7:0.8728 v8:0.9612 v9:0.8784 v10:0.6436 v11:0.5635 v12:0.6160 v13:0.9447] eff_mlp_bias:[v0:2.1434 v1:1.3590 v2:0.9999 v3:0.9888 v4:0.9391 v5:1.2264 v6:0.8618 v7:0.8342 v8:0.9226 v9:1.0607 v10:0.7568 v11:0.8065 v12:0.7043 v13:0.7568] +step:7600/20000 val_loss:2.1820 val_bpb:1.2923 train_time:442330ms step_avg:58.20ms +step:7800/20000 train_loss:2.2068 train_time:453933ms step_avg:58.20ms +step:7800 shared0_alpha:mean=0.448,std=0.061 
shared1_alpha:mean=0.521,std=0.072 shared2_alpha:mean=0.609,std=0.053 shared3_alpha:mean=0.657,std=0.062 eff_mlp_scale:[v0:211.6009 v1:122.6287 v2:117.5524 v3:127.6484 v4:113.8560 v5:138.1731 v6:114.9858 v7:120.1397 v8:112.2059 v9:119.1743 v10:89.8326 v11:96.5408 v12:110.0058 v13:273.7930] eff_attn_scale:[v0:0.3015 v1:0.8658 v2:0.8442 v3:0.8944 v4:0.9657 v5:0.8212 v6:0.7015 v7:0.7805 v8:0.8744 v9:0.7453 v10:0.5383 v11:0.5695 v12:0.6438 v13:1.6115] eff_attn_bias:[v0:0.5165 v1:1.1325 v2:0.7955 v3:0.9115 v4:1.0220 v5:1.0165 v6:0.9005 v7:0.8894 v8:0.9723 v9:0.8839 v10:0.6491 v11:0.5718 v12:0.6270 v13:0.9612] eff_mlp_bias:[v0:2.1766 v1:1.3866 v2:1.0165 v3:1.0109 v4:0.9502 v5:1.2430 v6:0.8784 v7:0.8452 v8:0.9336 v9:1.0772 v10:0.7623 v11:0.8176 v12:0.7126 v13:0.7679] +step:7800/20000 val_loss:2.1797 val_bpb:1.2910 train_time:453958ms step_avg:58.20ms +step:8000/20000 train_loss:2.1687 train_time:465549ms step_avg:58.19ms +step:8000 shared0_alpha:mean=0.448,std=0.060 shared1_alpha:mean=0.522,std=0.072 shared2_alpha:mean=0.610,std=0.053 shared3_alpha:mean=0.658,std=0.062 eff_mlp_scale:[v0:215.6147 v1:124.8222 v2:120.1046 v3:129.7142 v4:116.0495 v5:140.4976 v6:116.9984 v7:122.1476 v8:113.8285 v9:120.7583 v10:91.1138 v11:97.8261 v12:112.1627 v13:275.4493] eff_attn_scale:[v0:0.2947 v1:0.8576 v2:0.8426 v3:0.8855 v4:0.9739 v5:0.8090 v6:0.7117 v7:0.7758 v8:0.8730 v9:0.7338 v10:0.5440 v11:0.5692 v12:0.6493 v13:1.5922] eff_attn_bias:[v0:0.5193 v1:1.1546 v2:0.8065 v3:0.9281 v4:1.0386 v5:1.0275 v6:0.9060 v7:0.8949 v8:0.9833 v9:0.8949 v10:0.6546 v11:0.5800 v12:0.6325 v13:0.9778] eff_mlp_bias:[v0:2.1987 v1:1.4087 v2:1.0275 v3:1.0220 v4:0.9667 v5:1.2651 v6:0.8894 v7:0.8618 v8:0.9502 v9:1.0883 v10:0.7734 v11:0.8286 v12:0.7292 v13:0.7734] +step:8000/20000 val_loss:2.1774 val_bpb:1.2896 train_time:465573ms step_avg:58.20ms +step:8200/20000 train_loss:2.2357 train_time:477170ms step_avg:58.19ms +step:8200 shared0_alpha:mean=0.447,std=0.060 shared1_alpha:mean=0.522,std=0.073 
shared2_alpha:mean=0.611,std=0.053 shared3_alpha:mean=0.658,std=0.063 eff_mlp_scale:[v0:221.3299 v1:126.6775 v2:121.4890 v3:131.8829 v4:118.5773 v5:141.8554 v6:118.3605 v7:124.2533 v8:115.7806 v9:122.5911 v10:92.2899 v11:99.1847 v12:114.1026 v13:279.6976] eff_attn_scale:[v0:0.2930 v1:0.8523 v2:0.8616 v3:0.8877 v4:0.9753 v5:0.8079 v6:0.7173 v7:0.7826 v8:0.8747 v9:0.7369 v10:0.5441 v11:0.5638 v12:0.6517 v13:1.5887] eff_attn_bias:[v0:0.5165 v1:1.1767 v2:0.8176 v3:0.9391 v4:1.0441 v5:1.0441 v6:0.9170 v7:0.9060 v8:0.9944 v9:0.9115 v10:0.6629 v11:0.5856 v12:0.6408 v13:0.9833] eff_mlp_bias:[v0:2.2208 v1:1.4253 v2:1.0441 v3:1.0386 v4:0.9723 v5:1.2706 v6:0.9005 v7:0.8728 v8:0.9612 v9:1.1104 v10:0.7844 v11:0.8452 v12:0.7347 v13:0.7844] +step:8200/20000 val_loss:2.1772 val_bpb:1.2894 train_time:477194ms step_avg:58.19ms +step:8400/20000 train_loss:2.1883 train_time:488852ms step_avg:58.20ms +step:8400 shared0_alpha:mean=0.445,std=0.060 shared1_alpha:mean=0.523,std=0.073 shared2_alpha:mean=0.612,std=0.053 shared3_alpha:mean=0.658,std=0.063 eff_mlp_scale:[v0:224.8670 v1:128.9008 v2:123.9450 v3:134.6189 v4:120.0607 v5:144.2041 v6:120.2686 v7:125.8275 v8:117.2424 v9:124.7807 v10:93.4839 v11:100.5521 v12:116.1150 v13:281.5247] eff_attn_scale:[v0:0.2935 v1:0.8637 v2:0.8595 v3:0.8973 v4:0.9673 v5:0.8234 v6:0.7148 v7:0.7872 v8:0.8714 v9:0.7474 v10:0.5454 v11:0.5671 v12:0.6492 v13:1.5912] eff_attn_bias:[v0:0.5165 v1:1.2043 v2:0.8286 v3:0.9502 v4:1.0607 v5:1.0607 v6:0.9281 v7:0.9170 v8:1.0054 v9:0.9226 v10:0.6712 v11:0.5966 v12:0.6519 v13:0.9944] eff_mlp_bias:[v0:2.2429 v1:1.4474 v2:1.0551 v3:1.0551 v4:0.9888 v5:1.2872 v6:0.9115 v7:0.8894 v8:0.9723 v9:1.1214 v10:0.7955 v11:0.8563 v12:0.7458 v13:0.7955] +step:8400/20000 val_loss:2.1756 val_bpb:1.2885 train_time:488876ms step_avg:58.20ms +step:8600/20000 train_loss:2.1883 train_time:500471ms step_avg:58.19ms +step:8600 shared0_alpha:mean=0.444,std=0.061 shared1_alpha:mean=0.523,std=0.074 shared2_alpha:mean=0.613,std=0.054 
shared3_alpha:mean=0.660,std=0.063 eff_mlp_scale:[v0:230.6999 v1:131.4266 v2:125.7292 v3:136.9581 v4:122.1853 v5:146.2269 v6:121.5030 v7:128.0863 v8:119.3438 v9:126.0985 v10:94.5611 v11:102.5800 v12:117.6389 v13:285.3134] eff_attn_scale:[v0:0.2891 v1:0.8740 v2:0.8632 v3:0.8987 v4:0.9849 v5:0.8119 v6:0.7118 v7:0.7728 v8:0.8803 v9:0.7498 v10:0.5400 v11:0.5712 v12:0.6537 v13:1.5910] eff_attn_bias:[v0:0.5138 v1:1.2264 v2:0.8397 v3:0.9667 v4:1.0717 v5:1.0717 v6:0.9336 v7:0.9281 v8:1.0165 v9:0.9336 v10:0.6795 v11:0.6049 v12:0.6629 v13:1.0109] eff_mlp_bias:[v0:2.2650 v1:1.4695 v2:1.0717 v3:1.0717 v4:0.9999 v5:1.3037 v6:0.9226 v7:0.9005 v8:0.9888 v9:1.1325 v10:0.8065 v11:0.8673 v12:0.7568 v13:0.8010] +step:8600/20000 val_loss:2.1731 val_bpb:1.2870 train_time:500498ms step_avg:58.20ms +step:8800/20000 train_loss:2.1608 train_time:512098ms step_avg:58.19ms +step:8800 shared0_alpha:mean=0.444,std=0.060 shared1_alpha:mean=0.524,std=0.074 shared2_alpha:mean=0.614,std=0.054 shared3_alpha:mean=0.660,std=0.064 eff_mlp_scale:[v0:234.8463 v1:133.4327 v2:127.5774 v3:138.9896 v4:124.9512 v5:148.3247 v6:123.3248 v7:130.0586 v8:120.9390 v9:127.4758 v10:96.2146 v11:103.8236 v12:119.7926 v13:288.9956] eff_attn_scale:[v0:0.2804 v1:0.8632 v2:0.8643 v3:0.9012 v4:0.9901 v5:0.8145 v6:0.7086 v7:0.7748 v8:0.8815 v9:0.7304 v10:0.5407 v11:0.5685 v12:0.6383 v13:1.5695] eff_attn_bias:[v0:0.5165 v1:1.2430 v2:0.8507 v3:0.9833 v4:1.0828 v5:1.0883 v6:0.9447 v7:0.9391 v8:1.0330 v9:0.9447 v10:0.6822 v11:0.6104 v12:0.6684 v13:1.0220] eff_mlp_bias:[v0:2.2870 v1:1.4916 v2:1.0938 v3:1.0883 v4:1.0165 v5:1.3203 v6:0.9336 v7:0.9115 v8:1.0054 v9:1.1490 v10:0.8176 v11:0.8784 v12:0.7679 v13:0.8121] +step:8800/20000 val_loss:2.1720 val_bpb:1.2864 train_time:512121ms step_avg:58.20ms +step:9000/20000 train_loss:2.0794 train_time:523713ms step_avg:58.19ms +step:9000 shared0_alpha:mean=0.443,std=0.060 shared1_alpha:mean=0.524,std=0.075 shared2_alpha:mean=0.614,std=0.054 shared3_alpha:mean=0.660,std=0.064 
eff_mlp_scale:[v0:240.6299 v1:136.0186 v2:130.2843 v3:141.1545 v4:126.4911 v5:150.3994 v6:124.9228 v7:131.5942 v8:123.0256 v9:129.4274 v10:97.0430 v11:105.1629 v12:121.8704 v13:290.9856] eff_attn_scale:[v0:0.2836 v1:0.8679 v2:0.8595 v3:0.9083 v4:0.9938 v5:0.8100 v6:0.7156 v7:0.7816 v8:0.8927 v9:0.7433 v10:0.5470 v11:0.5746 v12:0.6508 v13:1.5885] eff_attn_bias:[v0:0.5138 v1:1.2651 v2:0.8673 v3:0.9999 v4:1.0993 v5:1.0993 v6:0.9557 v7:0.9502 v8:1.0441 v9:0.9557 v10:0.6905 v11:0.6187 v12:0.6822 v13:1.0330] eff_mlp_bias:[v0:2.3202 v1:1.5137 v2:1.1049 v3:1.1049 v4:1.0330 v5:1.3369 v6:0.9447 v7:0.9281 v8:1.0165 v9:1.1656 v10:0.8286 v11:0.8894 v12:0.7789 v13:0.8176] +step:9000/20000 val_loss:2.1725 val_bpb:1.2867 train_time:523739ms step_avg:58.19ms +step:9200/20000 train_loss:2.1380 train_time:535332ms step_avg:58.19ms +step:9200 shared0_alpha:mean=0.442,std=0.060 shared1_alpha:mean=0.526,std=0.075 shared2_alpha:mean=0.615,std=0.054 shared3_alpha:mean=0.661,std=0.064 eff_mlp_scale:[v0:244.8898 v1:137.6786 v2:131.6315 v3:143.4879 v4:128.5707 v5:152.7750 v6:126.2368 v7:133.8464 v8:125.0801 v9:131.0362 v10:98.1841 v11:106.6234 v12:123.9165 v13:295.1303] eff_attn_scale:[v0:0.2825 v1:0.8834 v2:0.8700 v3:0.9166 v4:1.0091 v5:0.8209 v6:0.7209 v7:0.7929 v8:0.8945 v9:0.7451 v10:0.5469 v11:0.5755 v12:0.6566 v13:1.5747] eff_attn_bias:[v0:0.5165 v1:1.2816 v2:0.8728 v3:1.0109 v4:1.1104 v5:1.1159 v6:0.9667 v7:0.9612 v8:1.0496 v9:0.9723 v10:0.7016 v11:0.6270 v12:0.6933 v13:1.0441] eff_mlp_bias:[v0:2.3312 v1:1.5357 v2:1.1214 v3:1.1159 v4:1.0496 v5:1.3534 v6:0.9557 v7:0.9391 v8:1.0330 v9:1.1822 v10:0.8397 v11:0.8949 v12:0.7844 v13:0.8286] +step:9200/20000 val_loss:2.1674 val_bpb:1.2837 train_time:535357ms step_avg:58.19ms +step:9400/20000 train_loss:2.1829 train_time:546947ms step_avg:58.19ms +step:9400 shared0_alpha:mean=0.440,std=0.060 shared1_alpha:mean=0.525,std=0.076 shared2_alpha:mean=0.615,std=0.054 shared3_alpha:mean=0.661,std=0.064 eff_mlp_scale:[v0:247.9157 v1:139.5834 
v2:133.5230 v3:145.7725 v4:130.8546 v5:154.1486 v6:128.0952 v7:134.9110 v8:126.7471 v9:132.3008 v10:99.3281 v11:108.0432 v12:125.5735 v13:297.0215] eff_attn_scale:[v0:0.2732 v1:0.8742 v2:0.8829 v3:0.9242 v4:1.0209 v5:0.8162 v6:0.7246 v7:0.8001 v8:0.9005 v9:0.7404 v10:0.5497 v11:0.5776 v12:0.6642 v13:1.6102] eff_attn_bias:[v0:0.5165 v1:1.2927 v2:0.8894 v3:1.0220 v4:1.1214 v5:1.1270 v6:0.9723 v7:0.9723 v8:1.0607 v9:0.9778 v10:0.7071 v11:0.6353 v12:0.7016 v13:1.0496] eff_mlp_bias:[v0:2.3533 v1:1.5468 v2:1.1270 v3:1.1270 v4:1.0607 v5:1.3645 v6:0.9667 v7:0.9502 v8:1.0386 v9:1.1877 v10:0.8507 v11:0.9060 v12:0.7900 v13:0.8342] +step:9400/20000 val_loss:2.1584 val_bpb:1.2783 train_time:546972ms step_avg:58.19ms +step:9600/20000 train_loss:2.1824 train_time:558559ms step_avg:58.18ms +step:9600 shared0_alpha:mean=0.439,std=0.060 shared1_alpha:mean=0.525,std=0.076 shared2_alpha:mean=0.615,std=0.054 shared3_alpha:mean=0.660,std=0.065 eff_mlp_scale:[v0:251.2592 v1:140.2566 v2:134.2815 v3:147.2460 v4:132.7694 v5:155.5019 v6:128.8229 v7:136.8928 v8:128.6204 v9:132.9389 v10:100.4382 v11:109.2842 v12:127.4349 v13:299.0197] eff_attn_scale:[v0:0.2763 v1:0.8644 v2:0.8721 v3:0.9138 v4:1.0056 v5:0.8154 v6:0.7192 v7:0.7905 v8:0.8948 v9:0.7396 v10:0.5373 v11:0.5823 v12:0.6601 v13:1.6127] eff_attn_bias:[v0:0.5165 v1:1.3037 v2:0.8949 v3:1.0275 v4:1.1270 v5:1.1380 v6:0.9778 v7:0.9778 v8:1.0662 v9:0.9888 v10:0.7126 v11:0.6408 v12:0.7071 v13:1.0607] eff_mlp_bias:[v0:2.3533 v1:1.5578 v2:1.1380 v3:1.1325 v4:1.0662 v5:1.3755 v6:0.9667 v7:0.9612 v8:1.0496 v9:1.1932 v10:0.8563 v11:0.9115 v12:0.7955 v13:0.8452] +step:9600/20000 val_loss:2.1493 val_bpb:1.2729 train_time:558584ms step_avg:58.19ms +step:9800/20000 train_loss:2.1043 train_time:570183ms step_avg:58.18ms +step:9800 shared0_alpha:mean=0.438,std=0.060 shared1_alpha:mean=0.525,std=0.076 shared2_alpha:mean=0.616,std=0.054 shared3_alpha:mean=0.661,std=0.064 eff_mlp_scale:[v0:252.0228 v1:141.4094 v2:135.4686 v3:148.0677 v4:134.3881 v5:156.1013 
v6:129.9841 v7:137.6567 v8:129.6098 v9:134.0634 v10:100.9159 v11:109.8940 v12:128.4153 v13:302.9092] eff_attn_scale:[v0:0.2790 v1:0.8832 v2:0.8726 v3:0.9179 v4:1.0180 v5:0.8246 v6:0.7154 v7:0.7984 v8:0.9019 v9:0.7390 v10:0.5376 v11:0.5764 v12:0.6608 v13:1.6386] eff_attn_bias:[v0:0.5138 v1:1.3093 v2:0.9005 v3:1.0330 v4:1.1325 v5:1.1380 v6:0.9778 v7:0.9778 v8:1.0717 v9:0.9888 v10:0.7126 v11:0.6436 v12:0.7126 v13:1.0607] eff_mlp_bias:[v0:2.3533 v1:1.5578 v2:1.1380 v3:1.1380 v4:1.0717 v5:1.3755 v6:0.9723 v7:0.9612 v8:1.0496 v9:1.1988 v10:0.8563 v11:0.9115 v12:0.8010 v13:0.8507] +step:9800/20000 val_loss:2.1406 val_bpb:1.2678 train_time:570207ms step_avg:58.18ms +step:10000/20000 train_loss:2.1384 train_time:581839ms step_avg:58.18ms +step:10000 shared0_alpha:mean=0.438,std=0.060 shared1_alpha:mean=0.525,std=0.076 shared2_alpha:mean=0.615,std=0.054 shared3_alpha:mean=0.660,std=0.064 eff_mlp_scale:[v0:252.3439 v1:141.8799 v2:135.8113 v3:148.6687 v4:135.3221 v5:156.6207 v6:130.3129 v7:138.2154 v8:130.5107 v9:134.5096 v10:101.1712 v11:110.3400 v12:129.3078 v13:304.4849] eff_attn_scale:[v0:0.2740 v1:0.8758 v2:0.8684 v3:0.9213 v4:1.0317 v5:0.8259 v6:0.7188 v7:0.8008 v8:0.9100 v9:0.7442 v10:0.5360 v11:0.5726 v12:0.6622 v13:1.6512] eff_attn_bias:[v0:0.5138 v1:1.3148 v2:0.9005 v3:1.0330 v4:1.1380 v5:1.1380 v6:0.9833 v7:0.9778 v8:1.0717 v9:0.9944 v10:0.7126 v11:0.6436 v12:0.7126 v13:1.0662] eff_mlp_bias:[v0:2.3533 v1:1.5689 v2:1.1435 v3:1.1380 v4:1.0772 v5:1.3811 v6:0.9723 v7:0.9667 v8:1.0551 v9:1.2043 v10:0.8618 v11:0.9115 v12:0.8065 v13:0.8563] +step:10000/20000 val_loss:2.1310 val_bpb:1.2621 train_time:581861ms step_avg:58.19ms +step:10200/20000 train_loss:2.0822 train_time:593454ms step_avg:58.18ms +step:10200 shared0_alpha:mean=0.438,std=0.060 shared1_alpha:mean=0.524,std=0.076 shared2_alpha:mean=0.615,std=0.054 shared3_alpha:mean=0.660,std=0.064 eff_mlp_scale:[v0:252.1265 v1:142.0524 v2:135.9660 v3:148.8238 v4:135.7316 v5:156.8111 v6:130.4613 v7:138.3597 v8:130.9056 
v9:134.6730 v10:101.2864 v11:110.4552 v12:129.6991 v13:305.4364] eff_attn_scale:[v0:0.2759 v1:0.8731 v2:0.8651 v3:0.9253 v4:1.0270 v5:0.8276 v6:0.7195 v7:0.7999 v8:0.9099 v9:0.7412 v10:0.5324 v11:0.5751 v12:0.6577 v13:1.6450] eff_attn_bias:[v0:0.5138 v1:1.3148 v2:0.9005 v3:1.0330 v4:1.1380 v5:1.1380 v6:0.9833 v7:0.9778 v8:1.0717 v9:0.9888 v10:0.7126 v11:0.6463 v12:0.7126 v13:1.0662] eff_mlp_bias:[v0:2.3533 v1:1.5578 v2:1.1380 v3:1.1380 v4:1.0772 v5:1.3811 v6:0.9723 v7:0.9667 v8:1.0496 v9:1.2043 v10:0.8618 v11:0.9170 v12:0.8065 v13:0.8563] +step:10200/20000 val_loss:2.1217 val_bpb:1.2566 train_time:593477ms step_avg:58.18ms +step:10313/20000 val_loss:2.1185 val_bpb:1.2547 train_time:600052ms step_avg:58.18ms +stopping_early: wallclock_cap train_time:600052ms step:10313/20000 +peak memory allocated: 13736 MiB reserved: 14154 MiB +Serialized model: 45208380 bytes +Code size: 66161 bytes +Total submission size: 45274541 bytes +Serialized model int8+zlib: 10736208 bytes (payload:11667648 raw_torch:11699443 payload_ratio:3.87x) +Total submission size int8+zlib: 10802369 bytes +final_int8_zlib_roundtrip val_loss:2.1317 val_bpb:1.2625 eval_time:1858ms +final_int8_zlib_roundtrip_exact val_loss:2.13174511 val_bpb:1.26253953 diff --git a/records/track_non_record_16mb/2026-03-26_RecurrenceFix_3Loop_Birkhoff_OutputLN_TimestepScale/research_notes.md b/records/track_non_record_16mb/2026-03-26_RecurrenceFix_3Loop_Birkhoff_OutputLN_TimestepScale/research_notes.md index 8ea3b565f4..215c3fce5f 100644 --- a/records/track_non_record_16mb/2026-03-26_RecurrenceFix_3Loop_Birkhoff_OutputLN_TimestepScale/research_notes.md +++ b/records/track_non_record_16mb/2026-03-26_RecurrenceFix_3Loop_Birkhoff_OutputLN_TimestepScale/research_notes.md @@ -108,3 +108,80 @@ Birkhoff alone hurts (C' > B'). 
Peri-norm is the main factor (C − C' = −0.00 | **K** | **1+4×3+1 peri+birk+ts(cap4)** | **14** | **1.2583** | **1.2659** | **+0.0076** | **Headline result.** Run K achieves 14 effective layers from 6 unique blocks with Q-gap +0.0076. This is the first viable 3-loop depth recurrence in competition history, vs prior results showing catastrophic failure at 3+ loops. Timestep scaling reduces Q-gap by 26–30% on both 2-loop and 3-loop configurations. It helps quantization, not training. + +## 7. Technique 4: FiLM Bias (Per-Iteration Shift Vectors) + +**Problem.** Capped timestep scaling (§4) provides per-iteration scale vectors $\gamma^{(t)}$ but no shift. Standard FiLM conditioning (Perez et al., 2018) uses both scale and shift: $\text{FiLM}(x) = \gamma \odot x + \beta$. The missing $\beta$ limits per-iteration expressivity — scaling alone cannot shift the operating point of downstream layers. + +**Solution.** Add per-iteration bias vectors $\beta_{\text{attn}}^{(t)}, \beta_{\text{mlp}}^{(t)}$ alongside existing gammas. Initialized to zeros (no effect at initialization), not clamped (unlike gammas which are capped at ±4.0), stored as FP16 passthrough parameters that bypass int8 quantization. Parameter cost: $2 \times \text{eff\_layers} \times 512 = 14{,}336$ FP16 params $\approx 28\text{KB}$ additional at 14 effective layers (consistent with the +0.03MB artifact overhead below). + +**Result.** FiLM bias gives a consistent −0.003 post-Q BPB improvement at both loop counts: + +| Comparison | Without bias | With bias | Delta | +|------------|-------------|-----------|-------| +| 2 loops (s2_I vs s3_N) | 1.2668 | 1.2641 | −0.0027 | +| 3 loops (s2_K vs s3_O) | 1.2659 | 1.2625 | −0.0034 | + +The effect is independent of loop count (~0.003 in both cases), confirming that bias provides additive benefit on top of gammas. No throughput penalty: step_avg is 42.48ms (s3_N) vs 41.92ms (s2_I) at 2 loops, and 58.18ms (s3_O) vs 59.28ms (s2_K) at 3 loops. Negligible artifact overhead (+0.03MB). + +> Perez, E., Strub, F., de Vries, H., Dumoulin, V. & Courville, A. (2018). 
"FiLM: Visual Reasoning with a General Conditioning Layer." AAAI 2018. [arXiv:1709.07871](https://arxiv.org/abs/1709.07871) + +## 8. Attention-Only Sharing: Validating Per-Iteration MLP Differentiation + +**Hypothesis.** If shared weights are the bottleneck, which component benefits more from being unique per iteration — attention or MLP? ALBERT (Lan et al., 2020, §4.4) found that sharing attention parameters across layers has negligible effect on downstream tasks, while sharing FFN parameters causes most of the degradation. This suggests that attention weights learn position-agnostic patterns, while FFN weights need layer-specific specialization. + +**Experiment.** s3_L uses attention-only sharing: 4 `SharedAttnLayer` modules (shared across loop iterations) paired with 8 `UniqueMLP` modules (one per virtual position per loop). This gives each iteration distinct feedforward capacity while reusing attention weights. + +**Result.** s3_L achieves **1.2406 post-Q BPB** — the best result in the entire ablation series — beating full sharing (s2_I: 1.2668) by −0.026 BPB. This is a massive improvement, larger than any other single technique. + +| Metric | s2_I (full share) | s3_L (attn-only share) | Delta | +|--------|-------------------|------------------------|-------| +| Post-Q BPB | 1.2668 | 1.2406 | −0.0262 | +| Q-gap | 0.0088 | 0.0073 | −0.0015 | +| Params | 11.55M | 15.75M | +4.20M | +| Artifact | 10.77MB | 14.65MB | +3.88MB | +| step_avg | 41.92ms | 42.60ms | +0.68ms | + +**Diagnostics confirm per-iteration specialization.** Unique MLPs develop aggressive per-position scales (148–260 range vs 157–177 for full sharing). Shared attention alphas differentiate more (0.45–0.78 vs 0.43–0.61), suggesting that unique MLPs enable the shared attention to learn more distinct mixing behaviors. + +**Abandoned.** Despite the BPB win, attention-only sharing is impractical for competition use: +1. 
**Artifact cost:** 14.65MB leaves only ~1.35MB headroom — insufficient for integrating SOTA features. +2. **torch.compile limitation:** The 3-loop variant (s3_M, 12 UniqueMLP modules) crashes `torch.compile(fullgraph=True)` during AOT autograd tracing with `RuntimeError: tensor does not have a device`. The 2-loop variant (8 modules) compiles fine. The model works without compile (verified via smoke test), but the throughput penalty (~3× slower) makes it uncompetitive. + +**Takeaway.** The concept — per-iteration MLP differentiation — is validated. The implementation — full unique MLP copies — is too expensive. A cheaper mechanism is needed. + +> Lan, Z., Chen, M., Goodman, S., Gimpel, K., Sharma, P. & Soricut, R. (2020). "ALBERT: A Lite BERT for Self-supervised Learning of Language Representations." ICLR 2020. [arXiv:1909.11942](https://arxiv.org/abs/1909.11942) + +## 9. Toward Cheap Per-Iteration Specialization + +s3_L proved that per-iteration MLP differentiation is essential (−0.026 BPB). But unique MLPs cost ~4MB in artifact size. The question is: can we achieve most of the differentiation at a fraction of the parameter cost? + +### Key insight: control the input, not the weights + +A unique MLP per iteration gives each loop a distinct feedforward function $f_v(x)$. But the same effect can be approximated by giving the shared MLP a distinct *input* per iteration: $f(\text{transform}_v(x))$. If the per-iteration transform is cheap, the total parameter cost drops dramatically. + +### Evidence from literature + +**MoEUT** (Csordás et al., 2024, §2.4): For Mixture-of-Experts Universal Transformers, "peri-layernorm" (normalization placement around sub-layers) is critical for competitive performance. The paper finds that normalization controls what the shared weights see, which is more important than the weights themselves being unique. This aligns with the Output-LN finding (§3): normalization placement is the key lever for recurrence. 
+ +**BitFit** (Ben-Zaken et al., 2022, §3): When fine-tuning BERT by training only bias terms, LayerNorm parameters change more than any other component — even more than attention or FFN biases. This suggests that normalization parameters have outsized influence on layer behavior, making them an efficient target for per-iteration specialization. + +**Relaxed Recursive Transformers** (Bae et al., 2025, §3.2): Per-iteration LoRA adapters on shared transformer weights recover 99.7% of non-shared performance at 1/3 the parameters. The paper demonstrates that low-rank per-iteration corrections are sufficient — full unique copies are overkill. + +### Planned parameter-efficient stack + +| Component | Per-iteration cost | Total (14 virtual positions) | Role | +|-----------|-------------------|---------------------------|------| +| Unique input norms (attn_in + mlp_in) | 2 × 512 = 1024 params | 14 × 2 × 512 = 14,336 params = 28KB FP16 | Control what shared weights see | +| Depth embeddings | 512 params | 14 × 512 = 7,168 params = 14KB FP16 | Positional identity per iteration | +| Timestep gammas | 2 × 512 = 1024 params | 14 × 2 × 512 = 14,336 params = 28KB FP16 | Per-iteration scale | +| Timestep betas | 2 × 512 = 1024 params | 14 × 2 × 512 = 14,336 params = 28KB FP16 | Per-iteration shift | +| **Total** | | **~110KB FP16 passthrough** | | + +This leaves ~4.8MB headroom (from an estimated ~11.2MB artifact) for SOTA feature integration — vs only ~1.35MB with unique MLPs. + +**Depth embedding subsumes Q/K bias.** Adding a depth embedding $e_v$ to the input before attention: $W_q(x + e_v) = W_q \cdot x + W_q \cdot e_v$. The second term acts as a learned per-iteration query/key bias, providing positional differentiation within the attention mechanism without additional parameters beyond the embedding itself. + +> Csordás, R., Irie, K., Schmidhuber, J., Potts, C. & Manning, C. (2024). "MoEUT: Mixture-of-Experts Universal Transformers." NeurIPS 2024. 
[arXiv:2405.16039](https://arxiv.org/abs/2405.16039) +> Ben-Zaken, E., Goldberg, Y. & Ravfogel, S. (2022). "BitFit: Simple Parameter-efficient Fine-tuning for Transformer-based Masked Language-models." ACL 2022. [arXiv:2106.10199](https://arxiv.org/abs/2106.10199) +> Bae, S., Ko, J., Song, H. & Yun, S.-Y. (2025). "Relaxed Recursive Transformers: Effective Parameter Sharing with Layer-wise LoRA." ICLR 2025. [arXiv:2410.20672](https://arxiv.org/abs/2410.20672) diff --git a/records/track_non_record_16mb/2026-03-26_RecurrenceFix_3Loop_Birkhoff_OutputLN_TimestepScale/scripts/run_fullscale2.sh b/records/track_non_record_16mb/2026-03-26_RecurrenceFix_3Loop_Birkhoff_OutputLN_TimestepScale/scripts/run_fullscale2.sh new file mode 100755 index 0000000000..2498c2a39a --- /dev/null +++ b/records/track_non_record_16mb/2026-03-26_RecurrenceFix_3Loop_Birkhoff_OutputLN_TimestepScale/scripts/run_fullscale2.sh @@ -0,0 +1,61 @@ +#!/bin/bash +set -uo pipefail + +SCRIPT="train_gpt.py" +NGPU=${NGPU:-8} +COMMON="SEED=1337 MAX_WALLCLOCK_SECONDS=600 VAL_LOSS_EVERY=200 TRAIN_LOG_EVERY=200" +DATA="DATA_PATH=${DATA_PATH:-./data/datasets/fineweb10B_sp1024} TOKENIZER_PATH=./data/tokenizers/fineweb_1024_bpe.model VOCAB_SIZE=1024" + +FAILS=0 +SUMMARY="" + +run_experiment() { + local name="$1"; shift + echo "" + echo "=== $name ===" + if "$@"; then + SUMMARY="${SUMMARY} PASS $name"$'\n' + else + SUMMARY="${SUMMARY} FAIL $name (exit $?)"$'\n' + FAILS=$((FAILS + 1)) + fi +} + +# --- L: 1+4×2+1, attn-only sharing, 2 loops (compare vs Run I) --- + +run_experiment "Run L: 1+4x2+1 attn-only sharing (2 loops)" \ + env $COMMON $DATA RUN_ID=s3_L NUM_LAYERS=10 NUM_PRELUDE=1 NUM_SHARED=4 NUM_LOOPS=2 NUM_CODA=1 \ + USE_PERI_NORM=1 USE_BIRKHOFF_MIX=1 USE_TIMESTEP_SCALE=1 TIMESTEP_GAMMA_MAX=4.0 \ + SHARE_ATTN_ONLY=1 \ + torchrun --standalone --nproc_per_node=$NGPU $SCRIPT + +# --- M: 1+4×3+1, attn-only sharing, 3 loops (compare vs Run K) --- + +run_experiment "Run M: 1+4x3+1 attn-only sharing (3 loops)" \ + env 
$COMMON $DATA RUN_ID=s3_M NUM_LAYERS=14 NUM_PRELUDE=1 NUM_SHARED=4 NUM_LOOPS=3 NUM_CODA=1 \ + USE_PERI_NORM=1 USE_BIRKHOFF_MIX=1 USE_TIMESTEP_SCALE=1 TIMESTEP_GAMMA_MAX=4.0 \ + SHARE_ATTN_ONLY=1 \ + torchrun --standalone --nproc_per_node=$NGPU $SCRIPT + +# --- N: 1+4×2+1, full sharing + FiLM bias, 2 loops (compare vs Run I) --- + +run_experiment "Run N: 1+4x2+1 full sharing + FiLM bias (2 loops)" \ + env $COMMON $DATA RUN_ID=s3_N NUM_LAYERS=10 NUM_PRELUDE=1 NUM_SHARED=4 NUM_LOOPS=2 NUM_CODA=1 \ + USE_PERI_NORM=1 USE_BIRKHOFF_MIX=1 USE_TIMESTEP_SCALE=1 TIMESTEP_GAMMA_MAX=4.0 \ + USE_TIMESTEP_BIAS=1 \ + torchrun --standalone --nproc_per_node=$NGPU $SCRIPT + +# --- O: 1+4×3+1, full sharing + FiLM bias, 3 loops (compare vs Run K) --- + +run_experiment "Run O: 1+4x3+1 full sharing + FiLM bias (3 loops)" \ + env $COMMON $DATA RUN_ID=s3_O NUM_LAYERS=14 NUM_PRELUDE=1 NUM_SHARED=4 NUM_LOOPS=3 NUM_CODA=1 \ + USE_PERI_NORM=1 USE_BIRKHOFF_MIX=1 USE_TIMESTEP_SCALE=1 TIMESTEP_GAMMA_MAX=4.0 \ + USE_TIMESTEP_BIAS=1 \ + torchrun --standalone --nproc_per_node=$NGPU $SCRIPT + +echo "" +echo "===============================" +echo " FULL-SCALE 2 SUMMARY" +echo "===============================" +echo "$SUMMARY" +echo "$FAILS run(s) failed." 
diff --git a/records/track_non_record_16mb/2026-03-26_RecurrenceFix_3Loop_Birkhoff_OutputLN_TimestepScale/train_gpt.py b/records/track_non_record_16mb/2026-03-26_RecurrenceFix_3Loop_Birkhoff_OutputLN_TimestepScale/train_gpt.py index 7a244120d8..4036333b31 100644 --- a/records/track_non_record_16mb/2026-03-26_RecurrenceFix_3Loop_Birkhoff_OutputLN_TimestepScale/train_gpt.py +++ b/records/track_non_record_16mb/2026-03-26_RecurrenceFix_3Loop_Birkhoff_OutputLN_TimestepScale/train_gpt.py @@ -90,6 +90,8 @@ class Hyperparameters: use_timestep_scale = bool(int(os.environ.get("USE_TIMESTEP_SCALE", "0"))) leaky_relu_slope = float(os.environ.get("LEAKY_RELU_SLOPE", "0.5")) timestep_gamma_max = float(os.environ.get("TIMESTEP_GAMMA_MAX", "0.0")) + use_timestep_bias = bool(int(os.environ.get("USE_TIMESTEP_BIAS", "0"))) + share_attn_only = bool(int(os.environ.get("SHARE_ATTN_ONLY", "0"))) disable_compile = bool(int(os.environ.get("DISABLE_COMPILE", "0"))) # Optimizer hyperparameters. @@ -311,7 +313,7 @@ def eval_val( pattern for pattern in os.environ.get( "CONTROL_TENSOR_NAME_PATTERNS", - "attn_scale,attn_scales,mlp_scale,mlp_scales,resid_mix,resid_mixes,q_gain,skip_weight,skip_weights,ts_attn,ts_mlp,resid_mix_logit", + "attn_scale,attn_scales,mlp_scale,mlp_scales,resid_mix,resid_mixes,q_gain,skip_weight,skip_weights,ts_attn,ts_mlp,resid_mix_logit,ts_attn_beta,ts_mlp_beta", ).split(",") if pattern ) @@ -649,19 +651,98 @@ def forward(self, x: Tensor) -> Tensor: class TimestepScaling(nn.Module): """Learned per-iteration scale vectors (arXiv:2410.01405, ICML 2025).""" - def __init__(self, effective_layers: int, dim: int, gamma_max: float = 0.0): + def __init__(self, effective_layers: int, dim: int, gamma_max: float = 0.0, use_bias: bool = False): super().__init__() self.attn_gamma = nn.Parameter(torch.ones(effective_layers, dim, dtype=torch.float32)) self.mlp_gamma = nn.Parameter(torch.ones(effective_layers, dim, dtype=torch.float32)) self.gamma_max = gamma_max # 0 = uncapped + if 
use_bias: + self.attn_beta = nn.Parameter(torch.zeros(effective_layers, dim, dtype=torch.float32)) + self.mlp_beta = nn.Parameter(torch.zeros(effective_layers, dim, dtype=torch.float32)) + else: + self.attn_beta = None + self.mlp_beta = None - def get(self, v: int) -> tuple[Tensor, Tensor]: + def get(self, v: int) -> tuple[Tensor, Tensor, Tensor | None, Tensor | None]: ag = self.attn_gamma[v] mg = self.mlp_gamma[v] if self.gamma_max > 0: ag = ag.clamp(-self.gamma_max, self.gamma_max) mg = mg.clamp(-self.gamma_max, self.gamma_max) - return ag, mg + ab = self.attn_beta[v] if self.attn_beta is not None else None + mb = self.mlp_beta[v] if self.mlp_beta is not None else None + return ag, mg, ab, mb + + +class SharedAttnLayer(nn.Module): + """Shared attention layer (mixing + attention only, no MLP) for attn-only sharing mode.""" + def __init__( + self, + dim: int, + num_heads: int, + num_kv_heads: int, + rope_base: float, + qk_gain_init: float, + use_birkhoff_mix: bool = False, + ): + super().__init__() + self.use_birkhoff_mix = use_birkhoff_mix + self.attn_norm = RMSNorm() + self.attn = CausalSelfAttention(dim, num_heads, num_kv_heads, rope_base, qk_gain_init) + self.attn_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + if use_birkhoff_mix: + self.resid_mix_logit = nn.Parameter(torch.zeros(dim, dtype=torch.float32)) + else: + self.resid_mix = nn.Parameter(torch.stack((torch.ones(dim), torch.zeros(dim))).float()) + + def forward(self, x: Tensor, x0: Tensor, + ts_attn_gamma: Tensor | None = None, + ts_attn_beta: Tensor | None = None) -> Tensor: + if self.use_birkhoff_mix: + alpha = torch.sigmoid(self.resid_mix_logit.to(dtype=x.dtype))[None, None, :] + x = alpha * x + (1 - alpha) * x0 + else: + mix = self.resid_mix.to(dtype=x.dtype) + x = mix[0][None, None, :] * x + mix[1][None, None, :] * x0 + attn_out = self.attn(self.attn_norm(x)) + attn_s = self.attn_scale.to(dtype=x.dtype)[None, None, :] + if ts_attn_gamma is not None: + attn_s = attn_s * 
ts_attn_gamma[None, None, :] + x = x + attn_s * attn_out + if ts_attn_beta is not None: + x = x + ts_attn_beta[None, None, :] + return x + + +class UniqueMLP(nn.Module): + """Unique MLP per virtual shared position for attn-only sharing mode.""" + def __init__( + self, + dim: int, + mlp_mult: int, + use_peri_norm: bool = False, + leaky_relu_slope: float = 0.5, + ): + super().__init__() + self.use_peri_norm = use_peri_norm + if use_peri_norm: + self.mlp_out_norm = RMSNorm() + else: + self.mlp_norm = RMSNorm() + self.mlp = MLP(dim, mlp_mult, leaky_relu_slope) + self.mlp_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + + def forward(self, x: Tensor, + ts_mlp_gamma: Tensor | None = None, + ts_mlp_beta: Tensor | None = None) -> Tensor: + mlp_out = self.mlp_out_norm(self.mlp(x)) if self.use_peri_norm else self.mlp(self.mlp_norm(x)) + mlp_s = self.mlp_scale.to(dtype=x.dtype)[None, None, :] + if ts_mlp_gamma is not None: + mlp_s = mlp_s * ts_mlp_gamma[None, None, :] + out = mlp_s * mlp_out + if ts_mlp_beta is not None: + out = out + ts_mlp_beta[None, None, :] + return out class Block(nn.Module): @@ -690,14 +771,15 @@ def __init__( self.attn_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) self.mlp_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) if use_birkhoff_mix: - # Birkhoff B₂ parameterization: sigmoid(logit) ∈ (0,1) → doubly stochastic self.resid_mix_logit = nn.Parameter(torch.zeros(dim, dtype=torch.float32)) else: self.resid_mix = nn.Parameter(torch.stack((torch.ones(dim), torch.zeros(dim))).float()) def forward(self, x: Tensor, x0: Tensor, ts_attn_gamma: Tensor | None = None, - ts_mlp_gamma: Tensor | None = None) -> Tensor: + ts_mlp_gamma: Tensor | None = None, + ts_attn_beta: Tensor | None = None, + ts_mlp_beta: Tensor | None = None) -> Tensor: if self.use_birkhoff_mix: alpha = torch.sigmoid(self.resid_mix_logit.to(dtype=x.dtype))[None, None, :] x = alpha * x + (1 - alpha) * x0 @@ -709,11 +791,15 @@ def forward(self, x: Tensor, x0: 
Tensor, if ts_attn_gamma is not None: attn_s = attn_s * ts_attn_gamma[None, None, :] x = x + attn_s * attn_out + if ts_attn_beta is not None: + x = x + ts_attn_beta[None, None, :] mlp_out = self.mlp_out_norm(self.mlp(x)) if self.use_peri_norm else self.mlp(self.mlp_norm(x)) mlp_s = self.mlp_scale.to(dtype=x.dtype)[None, None, :] if ts_mlp_gamma is not None: mlp_s = mlp_s * ts_mlp_gamma[None, None, :] x = x + mlp_s * mlp_out + if ts_mlp_beta is not None: + x = x + ts_mlp_beta[None, None, :] return x @@ -738,6 +824,8 @@ def __init__( use_peri_norm: bool = False, use_birkhoff_mix: bool = False, use_timestep_scale: bool = False, + use_timestep_bias: bool = False, + share_attn_only: bool = False, timestep_gamma_max: float = 0.0, leaky_relu_slope: float = 0.5, ): @@ -761,12 +849,31 @@ def __init__( leaky_relu_slope=leaky_relu_slope, ) + self.share_attn_only = share_attn_only if self.use_recurrence else False + if self.use_recurrence: self.prelude_blocks = nn.ModuleList([Block(**block_kwargs) for _ in range(self.num_prelude)]) - self.shared_blocks = nn.ModuleList([Block(**block_kwargs) for _ in range(num_shared)]) self.coda_blocks = nn.ModuleList([Block(**block_kwargs) for _ in range(self.num_coda)]) + if self.share_attn_only: + shared_attn_kwargs = dict( + dim=model_dim, num_heads=num_heads, num_kv_heads=num_kv_heads, + rope_base=rope_base, qk_gain_init=qk_gain_init, + use_birkhoff_mix=use_birkhoff_mix, + ) + unique_mlp_kwargs = dict( + dim=model_dim, mlp_mult=mlp_mult, + use_peri_norm=use_peri_norm, + leaky_relu_slope=leaky_relu_slope, + ) + self.shared_attn_layers = nn.ModuleList([SharedAttnLayer(**shared_attn_kwargs) for _ in range(num_shared)]) + self.unique_mlps = nn.ModuleList([UniqueMLP(**unique_mlp_kwargs) for _ in range(num_shared * self.num_loops)]) + self.shared_blocks = nn.ModuleList() # empty — keeps diagnostics safe + else: + self.shared_blocks = nn.ModuleList([Block(**block_kwargs) for _ in range(num_shared)]) + self.shared_attn_layers = nn.ModuleList() + 
self.unique_mlps = nn.ModuleList() effective_layers = self.num_prelude + num_shared * self.num_loops + self.num_coda - self.timestep_scale = TimestepScaling(effective_layers, model_dim, gamma_max=timestep_gamma_max) if use_timestep_scale else None + self.timestep_scale = TimestepScaling(effective_layers, model_dim, gamma_max=timestep_gamma_max, use_bias=use_timestep_bias) if use_timestep_scale else None else: # Standard U-Net path self.num_encoder_layers = num_layers // 2 @@ -789,9 +896,9 @@ def _init_weights(self) -> None: if isinstance(module, nn.Linear) and getattr(module, "_zero_init", False): nn.init.zeros_(module.weight) - def _get_ts(self, v: int) -> tuple[Tensor | None, Tensor | None]: + def _get_ts(self, v: int) -> tuple[Tensor | None, Tensor | None, Tensor | None, Tensor | None]: if self.timestep_scale is None: - return None, None + return None, None, None, None return self.timestep_scale.get(v) def forward(self, input_ids: Tensor, target_ids: Tensor) -> Tensor: @@ -802,17 +909,27 @@ def forward(self, input_ids: Tensor, target_ids: Tensor) -> Tensor: if self.use_recurrence: v = 0 for block in self.prelude_blocks: - ag, mg = self._get_ts(v) - x = block(x, x0, ag, mg) + ag, mg, ab, mb = self._get_ts(v) + x = block(x, x0, ag, mg, ab, mb) v += 1 - for _loop in range(self.num_loops): - for block in self.shared_blocks: - ag, mg = self._get_ts(v) - x = block(x, x0, ag, mg) - v += 1 + if self.share_attn_only: + vid = 0 + for _loop in range(self.num_loops): + for attn_layer in self.shared_attn_layers: + ag, mg, ab, mb = self._get_ts(v) + x = attn_layer(x, x0, ag, ab) + x = x + self.unique_mlps[vid](x, mg, mb) + vid += 1 + v += 1 + else: + for _loop in range(self.num_loops): + for block in self.shared_blocks: + ag, mg, ab, mb = self._get_ts(v) + x = block(x, x0, ag, mg, ab, mb) + v += 1 for block in self.coda_blocks: - ag, mg = self._get_ts(v) - x = block(x, x0, ag, mg) + ag, mg, ab, mb = self._get_ts(v) + x = block(x, x0, ag, mg, ab, mb) v += 1 else: skips: 
list[Tensor] = [] @@ -843,24 +960,77 @@ def recurrence_param_diagnostics(gpt: GPT) -> str: return "" parts: list[str] = [] - # Birkhoff alpha stats per shared block - for i, block in enumerate(gpt.shared_blocks): - if hasattr(block, "resid_mix_logit"): - a = torch.sigmoid(block.resid_mix_logit) - parts.append(f"shared{i}_alpha:mean={a.mean().item():.3f},std={a.std().item():.3f}") + # Birkhoff alpha stats per shared block/layer + if gpt.share_attn_only: + for i, layer in enumerate(gpt.shared_attn_layers): + if hasattr(layer, "resid_mix_logit"): + a = torch.sigmoid(layer.resid_mix_logit) + parts.append(f"shared{i}_alpha:mean={a.mean().item():.3f},std={a.std().item():.3f}") + else: + for i, block in enumerate(gpt.shared_blocks): + if hasattr(block, "resid_mix_logit"): + a = torch.sigmoid(block.resid_mix_logit) + parts.append(f"shared{i}_alpha:mean={a.mean().item():.3f},std={a.std().item():.3f}") # Effective MLP/attn contribution scale per virtual layer position: # mlp_scale * ts_mlp_gamma[v] is the learned multiplier on MLP output. # We report RMS (norm / sqrt(numel)) to make it scale-independent. 
v = 0 - all_blocks = ( - list(gpt.prelude_blocks) - + list(gpt.shared_blocks) * gpt.num_loops - + list(gpt.coda_blocks) - ) + effective_count = gpt.num_prelude + len(gpt.shared_blocks if not gpt.share_attn_only else gpt.shared_attn_layers) * gpt.num_loops + gpt.num_coda mlp_norms: list[str] = [] attn_norms: list[str] = [] - for block in all_blocks: + + # Prelude blocks + for block in gpt.prelude_blocks: + ms = block.mlp_scale.norm().item() + asc = block.attn_scale.norm().item() + d = block.mlp_scale.numel() ** 0.5 + if gpt.timestep_scale is not None: + ms_g = gpt.timestep_scale.mlp_gamma[v].norm().item() + as_g = gpt.timestep_scale.attn_gamma[v].norm().item() + mlp_norms.append(f"v{v}:{ms * ms_g / d:.4f}") + attn_norms.append(f"v{v}:{asc * as_g / d:.4f}") + else: + mlp_norms.append(f"v{v}:{ms / d:.4f}") + attn_norms.append(f"v{v}:{asc / d:.4f}") + v += 1 + + # Shared positions + if gpt.share_attn_only: + vid = 0 + for _loop in range(gpt.num_loops): + for layer in gpt.shared_attn_layers: + asc = layer.attn_scale.norm().item() + ms = gpt.unique_mlps[vid].mlp_scale.norm().item() + d = layer.attn_scale.numel() ** 0.5 + if gpt.timestep_scale is not None: + ms_g = gpt.timestep_scale.mlp_gamma[v].norm().item() + as_g = gpt.timestep_scale.attn_gamma[v].norm().item() + mlp_norms.append(f"v{v}:{ms * ms_g / d:.4f}") + attn_norms.append(f"v{v}:{asc * as_g / d:.4f}") + else: + mlp_norms.append(f"v{v}:{ms / d:.4f}") + attn_norms.append(f"v{v}:{asc / d:.4f}") + vid += 1 + v += 1 + else: + for _loop in range(gpt.num_loops): + for block in gpt.shared_blocks: + ms = block.mlp_scale.norm().item() + asc = block.attn_scale.norm().item() + d = block.mlp_scale.numel() ** 0.5 + if gpt.timestep_scale is not None: + ms_g = gpt.timestep_scale.mlp_gamma[v].norm().item() + as_g = gpt.timestep_scale.attn_gamma[v].norm().item() + mlp_norms.append(f"v{v}:{ms * ms_g / d:.4f}") + attn_norms.append(f"v{v}:{asc * as_g / d:.4f}") + else: + mlp_norms.append(f"v{v}:{ms / d:.4f}") + 
attn_norms.append(f"v{v}:{asc / d:.4f}") + v += 1 + + # Coda blocks + for block in gpt.coda_blocks: ms = block.mlp_scale.norm().item() asc = block.attn_scale.norm().item() d = block.mlp_scale.numel() ** 0.5 @@ -873,8 +1043,19 @@ def recurrence_param_diagnostics(gpt: GPT) -> str: mlp_norms.append(f"v{v}:{ms / d:.4f}") attn_norms.append(f"v{v}:{asc / d:.4f}") v += 1 + parts.append("eff_mlp_scale:[" + " ".join(mlp_norms) + "]") parts.append("eff_attn_scale:[" + " ".join(attn_norms) + "]") + if gpt.timestep_scale is not None and gpt.timestep_scale.attn_beta is not None: + attn_bias_norms: list[str] = [] + mlp_bias_norms: list[str] = [] + for vi in range(effective_count): + ab_rms = gpt.timestep_scale.attn_beta[vi].norm().item() / gpt.timestep_scale.attn_beta[vi].numel() ** 0.5 + mb_rms = gpt.timestep_scale.mlp_beta[vi].norm().item() / gpt.timestep_scale.mlp_beta[vi].numel() ** 0.5 + attn_bias_norms.append(f"v{vi}:{ab_rms:.4f}") + mlp_bias_norms.append(f"v{vi}:{mb_rms:.4f}") + parts.append("eff_attn_bias:[" + " ".join(attn_bias_norms) + "]") + parts.append("eff_mlp_bias:[" + " ".join(mlp_bias_norms) + "]") return " ".join(parts) @@ -996,6 +1177,8 @@ def log0(msg: str, console: bool = True) -> None: use_peri_norm=args.use_peri_norm, use_birkhoff_mix=args.use_birkhoff_mix, use_timestep_scale=args.use_timestep_scale, + use_timestep_bias=args.use_timestep_bias, + share_attn_only=args.share_attn_only, timestep_gamma_max=args.timestep_gamma_max, leaky_relu_slope=args.leaky_relu_slope, ).to(device).bfloat16() @@ -1019,6 +1202,9 @@ def log0(msg: str, console: bool = True) -> None: block_named_params = [] for bl in all_block_lists: block_named_params.extend(bl.named_parameters()) + if base_model.share_attn_only: + block_named_params.extend(base_model.shared_attn_layers.named_parameters()) + block_named_params.extend(base_model.unique_mlps.named_parameters()) else: block_named_params = list(base_model.blocks.named_parameters()) matrix_params = [ @@ -1085,8 +1271,10 @@ def 
log0(msg: str, console: bool = True) -> None: ) log0(f"seed:{args.seed}") if base_model.use_recurrence: - eff = base_model.num_prelude + len(base_model.shared_blocks) * base_model.num_loops + base_model.num_coda - log0(f"recurrence:enabled prelude:{base_model.num_prelude} shared:{len(base_model.shared_blocks)} " + num_shared = len(base_model.shared_attn_layers) if base_model.share_attn_only else len(base_model.shared_blocks) + eff = base_model.num_prelude + num_shared * base_model.num_loops + base_model.num_coda + shared_label = f"shared_attn:{num_shared}" if base_model.share_attn_only else f"shared:{num_shared}" + log0(f"recurrence:enabled prelude:{base_model.num_prelude} {shared_label} " f"loops:{base_model.num_loops} coda:{base_model.num_coda} effective_layers:{eff}") log0(f"peri_norm:{args.use_peri_norm} birkhoff_mix:{args.use_birkhoff_mix}") if base_model.timestep_scale is not None: From 6a9eeaa1ec570d634f42157a8776cbd324d2e6c2 Mon Sep 17 00:00:00 2001 From: Alexandr Azizyan Date: Thu, 2 Apr 2026 20:32:17 +0400 Subject: [PATCH 09/10] feat: add depth embeddings + unique norms ablations (Series 4, negative result) --- .../README.md | 15 +- .../logs/s4_P.txt | 1739 +++++++++++++++++ .../logs/s4_Q.txt | 1688 ++++++++++++++++ .../logs/s4_R.txt | 1697 ++++++++++++++++ .../logs/s4_S.txt | 1700 ++++++++++++++++ .../research_notes.md | 51 + .../scripts/run_fullscale3.sh | 61 + .../train_gpt.py | 266 +-- 8 files changed, 7050 insertions(+), 167 deletions(-) create mode 100644 records/track_non_record_16mb/2026-03-26_RecurrenceFix_3Loop_Birkhoff_OutputLN_TimestepScale/logs/s4_P.txt create mode 100644 records/track_non_record_16mb/2026-03-26_RecurrenceFix_3Loop_Birkhoff_OutputLN_TimestepScale/logs/s4_Q.txt create mode 100644 records/track_non_record_16mb/2026-03-26_RecurrenceFix_3Loop_Birkhoff_OutputLN_TimestepScale/logs/s4_R.txt create mode 100644 records/track_non_record_16mb/2026-03-26_RecurrenceFix_3Loop_Birkhoff_OutputLN_TimestepScale/logs/s4_S.txt create mode 100755 
records/track_non_record_16mb/2026-03-26_RecurrenceFix_3Loop_Birkhoff_OutputLN_TimestepScale/scripts/run_fullscale3.sh diff --git a/records/track_non_record_16mb/2026-03-26_RecurrenceFix_3Loop_Birkhoff_OutputLN_TimestepScale/README.md b/records/track_non_record_16mb/2026-03-26_RecurrenceFix_3Loop_Birkhoff_OutputLN_TimestepScale/README.md index 1e426cd848..c62ee0c7cc 100644 --- a/records/track_non_record_16mb/2026-03-26_RecurrenceFix_3Loop_Birkhoff_OutputLN_TimestepScale/README.md +++ b/records/track_non_record_16mb/2026-03-26_RecurrenceFix_3Loop_Birkhoff_OutputLN_TimestepScale/README.md @@ -26,6 +26,8 @@ - **Attention-only sharing (shared attention, unique MLPs per iteration) gives −0.026 post-Q BPB** — the largest improvement found — but costs +3.9MB artifact (14.65MB total), leaving insufficient room for SOTA features. - **FiLM bias adds −0.003 BPB at both 2 and 3 loops** with zero artifact/throughput cost. Additive with timestep scaling gammas. - **ALBERT (Lan et al., 2020) found attention sharing is nearly free while FFN sharing causes most degradation.** s3_L confirms: the model needs per-iteration MLP differentiation, not per-iteration attention differentiation. +- **Learned depth embeddings and unique input norms both hurt BPB despite reducing Q-gap.** The throughput overhead (6–15%) costs training steps that outweigh any specialization benefit. Learned depth embeddings remained near zero even after full training (RMS 0.006–0.010), suggesting they need far more steps to become useful. FiLM bias alone (s3_O: 1.2625) remains the best full-sharing config. +- **The 0.026 BPB gap between full sharing and unique MLPs (s3_L) cannot be closed by cheap per-iteration input controls.** The MLP genuinely needs different weights per iteration, not just different inputs. 
## Techniques Applicable to Non-Recurrent Submissions @@ -61,11 +63,20 @@ Output-LN could benefit any submission using quadratic activations (relu², leak Run M (1+4×3+1 attn-only, 3 loops) crashed during torch.compile with 12 UniqueMLP modules. Works without compile (verified via smoke test). +## Series 4: Depth Embeddings + Unique Norms (600s, 8×H100) + +| Run | Config | Eff. Layers | Pre-Q BPB | Post-Q BPB | Q-Gap | +|-----|--------|-------------|-----------|------------|-------| +| P | 1+4×2+1 learned depth+norms+bias | 10 | 1.2579 | 1.2663 | +0.0084 | +| Q | 1+4×3+1 learned depth+norms+bias | 14 | 1.2574 | 1.2643 | +0.0069 | +| R | 1+4×3+1 learned depth only+bias | 14 | 1.2566 | 1.2639 | +0.0073 | +| S | 1+4×3+1 norms only+bias | 14 | 1.2560 | 1.2629 | +0.0069 | + ## Next Direction -Run s3_L validated that per-iteration MLP differentiation is critical (−0.026 BPB), but unique MLPs per loop iteration are too expensive: 12 unique MLP modules add ~4MB to the artifact (14.65MB total), leaving only ~1.35MB headroom — insufficient for integrating SOTA features. The 3-loop variant (s3_M) also crashes torch.compile(fullgraph=True). +The cheap learned specialization approach (Series 4: learned depth embeddings + unique input norms) was tested and did not improve over FiLM bias alone. Learned depth embeddings remained near zero (RMS 0.006–0.010) after full training, and throughput overhead (6–15%) cost more training steps than the specialization recovered. FiLM bias alone (s3_O: 1.2625) remains the best full-sharing configuration. -The planned approach achieves per-iteration differentiation at ~110KB instead of ~12MB: per-iteration **unique input norms** (24KB) control what the shared MLP sees at each iteration, **learned depth embeddings** (14KB) provide positional identity, and **FiLM gammas + betas** (28KB) modulate residual contributions — all stored as FP16 passthrough. 
This leaves ~4.8MB headroom for SOTA feature integration while preserving the per-iteration specialization that s3_L showed is essential. +Two next steps: (1) Replace learned depth embeddings with **sinusoidal depth encodings** (Universal Transformer style, Dehghani et al., 2019) — zero parameter cost, zero artifact cost, zero throughput overhead, and full-strength iteration identity signal from step 0 instead of slowly-learned near-zero values. (2) **Graft existing techniques** (Output-LN, Birkhoff mixing, FiLM scale+shift) onto the SOTA stack for a competitive submission. ## How to Reproduce diff --git a/records/track_non_record_16mb/2026-03-26_RecurrenceFix_3Loop_Birkhoff_OutputLN_TimestepScale/logs/s4_P.txt b/records/track_non_record_16mb/2026-03-26_RecurrenceFix_3Loop_Birkhoff_OutputLN_TimestepScale/logs/s4_P.txt new file mode 100644 index 0000000000..0fb746c18e --- /dev/null +++ b/records/track_non_record_16mb/2026-03-26_RecurrenceFix_3Loop_Birkhoff_OutputLN_TimestepScale/logs/s4_P.txt @@ -0,0 +1,1739 @@ +""" +The `train_gpt.py` and `train_gpt_mlx.py` scripts are intended as good launching-off points for new participants, not SOTA configs. We'll accept PRs that tune, improve, or simplify these scripts without significantly increasing complexity, but competitive submissions should stay in the `/records` folder. + +Hard stop: To keep readable for newcomers, let's make sure `train_gpt.py` and `train_gpt_mlx.py` never are longer than 1500 lines. 
+""" + +from __future__ import annotations + +import copy +import glob +import io +import math +import os +import random +import subprocess +import sys +import time +import uuid +import zlib +from pathlib import Path + +import numpy as np +import sentencepiece as spm +import torch +import torch.distributed as dist +import torch.nn.functional as F +from torch import Tensor, nn +from torch.nn.parallel import DistributedDataParallel as DDP + +# Flash Attention 3 via HuggingFace kernels (Hopper-optimized) +try: + from kernels import get_kernel + cap = torch.cuda.get_device_capability() + repo = "varunneal/flash-attention-3" if cap == (9, 0) else "kernels-community/flash-attn3" + fa3 = get_kernel(repo).flash_attn_interface + FA3_AVAILABLE = True +except Exception: + FA3_AVAILABLE = False + +# ----------------------------- +# HYPERPARAMETERS +# ----------------------------- +# Default Simple Baseline run: +# - 9 transformer blocks at width 512 +# - 8 attention heads with 4 KV heads (GQA) and 2x MLP expansion +# - vocab size 1024, sequence length 1024, tied embeddings +# - 524,288 train tokens per step for 20,000 iterations with a ~10 minute cap + +class Hyperparameters: + # Data paths are shard globs produced by the existing preprocessing pipeline. + data_path = os.environ.get("DATA_PATH", "./data/datasets/fineweb10B_sp1024") + train_files = os.path.join(data_path, "fineweb_train_*.bin") + val_files = os.path.join(data_path, "fineweb_val_*.bin") + tokenizer_path = os.environ.get("TOKENIZER_PATH", "./data/tokenizers/fineweb_1024_bpe.model") + run_id = os.environ.get("RUN_ID", str(uuid.uuid4())) + seed = int(os.environ.get("SEED", 1337)) + + # Validation cadence and batch size. Validation always uses the full fineweb_val split. + val_batch_size = int(os.environ.get("VAL_BATCH_SIZE", 524_288)) + val_loss_every = int(os.environ.get("VAL_LOSS_EVERY", 1000)) + train_log_every = int(os.environ.get("TRAIN_LOG_EVERY", 200)) + + # Training length. 
+ iterations = int(os.environ.get("ITERATIONS", 20000)) + warmdown_iters = int(os.environ.get("WARMDOWN_ITERS", 1200)) + warmup_steps = int(os.environ.get("WARMUP_STEPS", 20)) + train_batch_tokens = int(os.environ.get("TRAIN_BATCH_TOKENS", 524_288)) + train_seq_len = int(os.environ.get("TRAIN_SEQ_LEN", 1024)) + max_wallclock_seconds = float(os.environ.get("MAX_WALLCLOCK_SECONDS", 600.0)) + qk_gain_init = float(os.environ.get("QK_GAIN_INIT", 1.5)) + + # Model shape. + vocab_size = int(os.environ.get("VOCAB_SIZE", 1024)) + num_layers = int(os.environ.get("NUM_LAYERS", 9)) + num_kv_heads = int(os.environ.get("NUM_KV_HEADS", 4)) + model_dim = int(os.environ.get("MODEL_DIM", 512)) + num_heads = int(os.environ.get("NUM_HEADS", 8)) + mlp_mult = int(os.environ.get("MLP_MULT", 2)) + tie_embeddings = bool(int(os.environ.get("TIE_EMBEDDINGS", "1"))) + rope_base = float(os.environ.get("ROPE_BASE", 10000.0)) + logit_softcap = float(os.environ.get("LOGIT_SOFTCAP", 30.0)) + + # Recurrence + stability features. + num_prelude = int(os.environ.get("NUM_PRELUDE", 0)) + num_coda = int(os.environ.get("NUM_CODA", 0)) + num_shared = int(os.environ.get("NUM_SHARED", 0)) + num_loops = int(os.environ.get("NUM_LOOPS", 2)) + use_peri_norm = bool(int(os.environ.get("USE_PERI_NORM", "0"))) + use_birkhoff_mix = bool(int(os.environ.get("USE_BIRKHOFF_MIX", "0"))) + use_timestep_scale = bool(int(os.environ.get("USE_TIMESTEP_SCALE", "0"))) + leaky_relu_slope = float(os.environ.get("LEAKY_RELU_SLOPE", "0.5")) + timestep_gamma_max = float(os.environ.get("TIMESTEP_GAMMA_MAX", "0.0")) + use_timestep_bias = bool(int(os.environ.get("USE_TIMESTEP_BIAS", "0"))) + use_depth_embed = bool(int(os.environ.get("USE_DEPTH_EMBED", "0"))) + use_unique_norms = bool(int(os.environ.get("USE_UNIQUE_NORMS", "0"))) + disable_compile = bool(int(os.environ.get("DISABLE_COMPILE", "0"))) + + # Optimizer hyperparameters. 
    # --- Optimizer hyperparameters (all overridable via environment variables) ---
    embed_lr = float(os.environ.get("EMBED_LR", 0.6))
    head_lr = float(os.environ.get("HEAD_LR", 0.008))
    tied_embed_lr = float(os.environ.get("TIED_EMBED_LR", 0.05))
    tied_embed_init_std = float(os.environ.get("TIED_EMBED_INIT_STD", 0.005))
    matrix_lr = float(os.environ.get("MATRIX_LR", 0.04))
    scalar_lr = float(os.environ.get("SCALAR_LR", 0.04))
    muon_momentum = float(os.environ.get("MUON_MOMENTUM", 0.95))
    muon_backend_steps = int(os.environ.get("MUON_BACKEND_STEPS", 5))
    muon_momentum_warmup_start = float(os.environ.get("MUON_MOMENTUM_WARMUP_START", 0.85))
    muon_momentum_warmup_steps = int(os.environ.get("MUON_MOMENTUM_WARMUP_STEPS", 500))
    beta1 = float(os.environ.get("BETA1", 0.9))
    beta2 = float(os.environ.get("BETA2", 0.95))
    adam_eps = float(os.environ.get("ADAM_EPS", 1e-8))
    # NOTE(review): 0.0 presumably disables gradient clipping — confirm at the call site.
    grad_clip_norm = float(os.environ.get("GRAD_CLIP_NORM", 0.0))

# -----------------------------
# MUON OPTIMIZER
# -----------------------------
#
# As borrowed from modded-nanogpt
# Background on Muon: https://kellerjordan.github.io/posts/muon/

def zeropower_via_newtonschulz5(G: Tensor, steps: int = 10, eps: float = 1e-7) -> Tensor:
    # Orthogonalize a 2D update matrix with a fast Newton-Schulz iteration.
    # Muon uses this to normalize matrix-shaped gradients before applying them.
    a, b, c = (3.4445, -4.7750, 2.0315)  # quintic iteration coefficients from the Muon reference
    X = G.bfloat16()
    X /= X.norm() + eps  # normalize so the iteration starts inside its convergence range
    transposed = G.size(0) > G.size(1)
    if transposed:
        # Iterate in the wide orientation so A = X @ X.T is the smaller Gram matrix.
        X = X.T
    for _ in range(steps):
        A = X @ X.T
        B = b * A + c * A @ A
        X = a * X + B @ X
    return X.T if transposed else X


class Muon(torch.optim.Optimizer):
    """Momentum SGD whose 2D updates are orthogonalized via Newton-Schulz.

    Parameters are sharded round-robin across ranks by index; each rank computes
    updates only for its share, writes them into disjoint slots of one flat bf16
    buffer, and a single SUM all_reduce merges the buffers.
    """

    def __init__(self, params, lr: float, momentum: float, backend_steps: int, nesterov: bool = True):
        super().__init__(
            params,
            dict(lr=lr, momentum=momentum, backend_steps=backend_steps, nesterov=nesterov),
        )

    @torch.no_grad()
    def step(self, closure=None):
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        distributed = dist.is_available() and dist.is_initialized()
        world_size = dist.get_world_size() if distributed else 1
        rank = dist.get_rank() if distributed else 0

        for group in self.param_groups:
            params = group["params"]
            if not params:
                continue
            lr = group["lr"]
            momentum = group["momentum"]
            backend_steps = group["backend_steps"]
            nesterov = group["nesterov"]

            # One flat buffer covering every param in the group; slots owned by
            # other ranks stay zero locally so the SUM all_reduce reassembles all updates.
            total_params = sum(int(p.numel()) for p in params)
            updates_flat = torch.zeros(total_params, device=params[0].device, dtype=torch.bfloat16)

            curr = 0
            for i, p in enumerate(params):
                if i % world_size == rank and p.grad is not None:
                    g = p.grad
                    state = self.state[p]
                    if "momentum_buffer" not in state:
                        state["momentum_buffer"] = torch.zeros_like(g)
                    buf = state["momentum_buffer"]
                    buf.mul_(momentum).add_(g)
                    if nesterov:
                        g = g.add(buf, alpha=momentum)
                    g = zeropower_via_newtonschulz5(g, steps=backend_steps)
                    # Scale correction from Muon reference implementations.
                    g *= max(1, g.size(0) / g.size(1)) ** 0.5
                    updates_flat[curr : curr + p.numel()] = g.reshape(-1)
                # Offset advances for every param (owned or not) so slot layout
                # is identical on all ranks.
                curr += p.numel()

            if distributed:
                dist.all_reduce(updates_flat, op=dist.ReduceOp.SUM)

            curr = 0
            for p in params:
                g = updates_flat[curr : curr + p.numel()].view_as(p).to(dtype=p.dtype)
                p.add_(g, alpha=-lr)
                curr += p.numel()

        return loss


# -----------------------------
# TOKENIZER-AGNOSTIC EVALUATION SETUP
# -----------------------------
#
# It's common for small models to have a large fraction of their parameters be embeddings, since the 2 * d_model * d_vocab vectors can be gigantic.
# Instead of locking the tokenizer, we let you bring your own and calculate our validation metrics on the average compression of the validation set.
# We calculate BPB (bits-per-byte) instead of validation loss, so we need methods to count the number of bits per token in the tokenizer.
# Note: Submissions that edit the tokenizer will be examined more carefully, since screwing this up might unjustly improve your score.

def build_sentencepiece_luts(
    sp: spm.SentencePieceProcessor, vocab_size: int, device: torch.device
) -> tuple[Tensor, Tensor, Tensor]:
    """Build per-token-id lookup tables used by the BPB metric.

    Returns three aligned tensors indexed by token id:
      - base_bytes: UTF-8 byte length of the decoded piece (leading-space marker stripped)
      - has_leading_space: piece started with the SentencePiece "▁" marker
      - is_boundary_token: control/unknown/unused ids, which contribute no bytes
    """
    sp_vocab_size = int(sp.vocab_size())
    # Table is sized to the larger of tokenizer and model vocab so any model id indexes safely.
    table_size = max(sp_vocab_size, vocab_size)
    base_bytes_np = np.zeros((table_size,), dtype=np.int16)
    has_leading_space_np = np.zeros((table_size,), dtype=np.bool_)
    is_boundary_token_np = np.ones((table_size,), dtype=np.bool_)
    for token_id in range(sp_vocab_size):
        if sp.is_control(token_id) or sp.is_unknown(token_id) or sp.is_unused(token_id):
            continue
        is_boundary_token_np[token_id] = False
        if sp.is_byte(token_id):
            # Byte-fallback tokens always decode to exactly one byte.
            base_bytes_np[token_id] = 1
            continue
        piece = sp.id_to_piece(token_id)
        if piece.startswith("▁"):
            has_leading_space_np[token_id] = True
            piece = piece[1:]
        base_bytes_np[token_id] = len(piece.encode("utf-8"))
    return (
        torch.tensor(base_bytes_np, dtype=torch.int16, device=device),
        torch.tensor(has_leading_space_np, dtype=torch.bool, device=device),
        torch.tensor(is_boundary_token_np, dtype=torch.bool, device=device),
    )


def load_validation_tokens(pattern: str, seq_len: int) -> Tensor:
    """Concatenate all validation shards, trimmed to whole sequences plus one shift token."""
    files = [Path(p) for p in sorted(glob.glob(pattern))]
    if not files:
        raise FileNotFoundError(f"No files found for pattern: {pattern}")
    # The export pipeline writes the fixed first-50k-doc validation set to fineweb_val_*.
    tokens = torch.cat([load_data_shard(file) for file in files]).contiguous()
    usable = ((tokens.numel() - 1) // seq_len) * seq_len
    if usable <= 0:
        raise ValueError(f"Validation split is too short for TRAIN_SEQ_LEN={seq_len}")
    # Keep one extra token so (x, y) pairs can be built by shifting.
    return tokens[: usable + 1]


def eval_val(
    args: Hyperparameters,
    model: nn.Module,
    rank: int,
    world_size: int,
    device: torch.device,
    grad_accum_steps: int,
    val_tokens: Tensor,
    base_bytes_lut: Tensor,
    has_leading_space_lut: Tensor,
    is_boundary_token_lut: Tensor,
) -> tuple[float, float]:
    # Validation computes two metrics:
    # - val_loss: token cross-entropy (natural log)
    # - val_bpb: tokenizer-agnostic compression metric used by the challenge
    local_batch_tokens = args.val_batch_size // (world_size * grad_accum_steps)
    if local_batch_tokens < args.train_seq_len:
        raise ValueError(
            "VAL_BATCH_SIZE must provide at least one sequence per rank; "
            f"got VAL_BATCH_SIZE={args.val_batch_size}, WORLD_SIZE={world_size}, "
            f"GRAD_ACCUM_STEPS={grad_accum_steps}, TRAIN_SEQ_LEN={args.train_seq_len}"
        )
    local_batch_seqs = local_batch_tokens // args.train_seq_len
    total_seqs = (val_tokens.numel() - 1) // args.train_seq_len
    # Each rank evaluates a disjoint, contiguous slice of the validation sequences.
    seq_start = (total_seqs * rank) // world_size
    seq_end = (total_seqs * (rank + 1)) // world_size
    val_loss_sum = torch.zeros((), device=device, dtype=torch.float64)
    val_token_count = torch.zeros((), device=device, dtype=torch.float64)
    val_byte_count = torch.zeros((), device=device, dtype=torch.float64)

    model.eval()
    with torch.inference_mode():
        for batch_seq_start in range(seq_start, seq_end, local_batch_seqs):
            batch_seq_end = min(batch_seq_start + local_batch_seqs, seq_end)
            raw_start = batch_seq_start * args.train_seq_len
            raw_end = batch_seq_end * args.train_seq_len + 1  # +1 for the shifted targets
            local = val_tokens[raw_start:raw_end].to(device=device, dtype=torch.int64, non_blocking=True)
            x = local[:-1].reshape(-1, args.train_seq_len)
            y = local[1:].reshape(-1, args.train_seq_len)
            with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True):
                batch_loss = model(x, y).detach()
            batch_token_count = float(y.numel())
            val_loss_sum += batch_loss.to(torch.float64) * batch_token_count
            val_token_count += batch_token_count
            # Byte accounting for BPB: each target token contributes its piece bytes,
            # plus one byte for its leading space unless the previous token is a boundary token.
            prev_ids = x.reshape(-1)
            tgt_ids = y.reshape(-1)
            token_bytes = base_bytes_lut[tgt_ids].to(dtype=torch.int16)
            token_bytes += (has_leading_space_lut[tgt_ids] & ~is_boundary_token_lut[prev_ids]).to(dtype=torch.int16)
            val_byte_count += token_bytes.to(torch.float64).sum()

    if dist.is_available() and dist.is_initialized():
        dist.all_reduce(val_loss_sum, op=dist.ReduceOp.SUM)
        dist.all_reduce(val_token_count, op=dist.ReduceOp.SUM)
        dist.all_reduce(val_byte_count, op=dist.ReduceOp.SUM)

    val_loss = val_loss_sum / val_token_count
    # bpb = (bits per token) * (tokens per byte)
    bits_per_token = val_loss.item() / math.log(2.0)
    tokens_per_byte = val_token_count.item() / val_byte_count.item()
    model.train()
    return float(val_loss.item()), float(bits_per_token * tokens_per_byte)

# -----------------------------
# POST-TRAINING QUANTIZATION
# -----------------------------
#
# It's silly to export our model, which is trained in bf16 and fp32, at that same precision.
# Instead, we get approximately the same model (with a small hit) by quantizing the model to int8 & zlib compressing.
# We can then decompress the model and run in higher precision for evaluation, after closing in under the size limit.

# Substring patterns naming small "control" tensors (gates, scales, mixes) that
# must survive export in exact fp32 rather than being downcast or quantized.
CONTROL_TENSOR_NAME_PATTERNS = tuple(
    pattern
    for pattern in os.environ.get(
        "CONTROL_TENSOR_NAME_PATTERNS",
        "attn_scale,attn_scales,mlp_scale,mlp_scales,resid_mix,resid_mixes,q_gain,skip_weight,skip_weights,ts_attn,ts_mlp,resid_mix_logit,ts_attn_beta,ts_mlp_beta,depth_embed,unique_attn_gain,unique_mlp_gain",
    ).split(",")
    if pattern
)
INT8_KEEP_FLOAT_FP32_NAME_PATTERNS = tuple(
    pattern
    for pattern in os.environ.get(
        "INT8_KEEP_FLOAT_FP32_NAME_PATTERNS",
        ",".join(CONTROL_TENSOR_NAME_PATTERNS),
    ).split(",")
    if pattern
)
INT8_KEEP_FLOAT_MAX_NUMEL = 65_536  # float tensors at or below this size are passed through unquantized
INT8_KEEP_FLOAT_STORE_DTYPE = torch.float16
INT8_PER_ROW_SCALE_DTYPE = torch.float16
INT8_CLIP_PERCENTILE = 99.99984  # clip extreme outliers before computing int8 scales
INT8_CLIP_Q = INT8_CLIP_PERCENTILE / 100.0

def tensor_nbytes(t: Tensor) -> int:
    """Storage size of t in bytes (numel * element_size)."""
    return int(t.numel()) * int(t.element_size())

def keep_float_tensor(name: str, t: Tensor, passthrough_orig_dtypes: dict[str, str]) -> Tensor:
    # Control tensors are kept in exact fp32; other fp32/bf16 tensors are stored
    # as fp16 (original dtype recorded in passthrough_orig_dtypes for restore).
    if any(pattern in name for pattern in INT8_KEEP_FLOAT_FP32_NAME_PATTERNS):
        return t.float().contiguous()
    if t.dtype in {torch.float32, torch.bfloat16}:
        passthrough_orig_dtypes[name] = str(t.dtype).removeprefix("torch.")
        return t.to(dtype=INT8_KEEP_FLOAT_STORE_DTYPE).contiguous()
    return t

def quantize_float_tensor(t: Tensor) -> tuple[Tensor, Tensor]:
    """Quantize a float tensor to int8; returns (int8 values, scale tensor).

    2D tensors get a per-row scale vector; everything else a scalar scale.
    """
    t32 = t.float()
    if t32.ndim == 2:
        # Matrices get one scale per row, which usually tracks output-channel
        # ranges much better than a single tensor-wide scale.
        clip_abs = (
            torch.quantile(t32.abs(), INT8_CLIP_Q, dim=1)
            if t32.numel()
            else torch.empty((t32.shape[0],), dtype=torch.float32)
        )
        clipped = torch.maximum(torch.minimum(t32, clip_abs[:, None]), -clip_abs[:, None])
        # clamp_min keeps the scale away from zero for all-zero / tiny rows.
        scale = (clip_abs / 127.0).clamp_min(1.0 / 127.0)
        q = torch.clamp(torch.round(clipped / scale[:, None]), -127, 127).to(torch.int8).contiguous()
        return q, scale.to(dtype=INT8_PER_ROW_SCALE_DTYPE).contiguous()

    # Vectors / scalars use a simpler per-tensor scale.
    clip_abs = float(torch.quantile(t32.abs().flatten(), INT8_CLIP_Q).item()) if t32.numel() else 0.0
    scale = torch.tensor(clip_abs / 127.0 if clip_abs > 0 else 1.0, dtype=torch.float32)
    q = torch.clamp(torch.round(torch.clamp(t32, -clip_abs, clip_abs) / scale), -127, 127).to(torch.int8).contiguous()
    return q, scale

def quantize_state_dict_int8(state_dict: dict[str, Tensor]):
    # Single supported clean-script export format:
    # - per-row int8 for 2D float tensors
    # - per-tensor int8 for other float tensors
    # - exact passthrough for non-floats
    # - passthrough for small float tensors, stored as fp16 to save bytes
    quantized: dict[str, Tensor] = {}
    scales: dict[str, Tensor] = {}
    dtypes: dict[str, str] = {}
    passthrough: dict[str, Tensor] = {}
    passthrough_orig_dtypes: dict[str, str] = {}
    qmeta: dict[str, dict[str, object]] = {}
    stats = dict.fromkeys(
        ("param_count", "num_tensors", "num_float_tensors", "num_nonfloat_tensors", "baseline_tensor_bytes", "int8_payload_bytes"),
        0,
    )

    for name, tensor in state_dict.items():
        t = tensor.detach().to("cpu").contiguous()
        stats["param_count"] += int(t.numel())
        stats["num_tensors"] += 1
        stats["baseline_tensor_bytes"] += tensor_nbytes(t)

        if not t.is_floating_point():
            stats["num_nonfloat_tensors"] += 1
            passthrough[name] = t
            stats["int8_payload_bytes"] += tensor_nbytes(t)
            continue

        # Small float tensors are cheap enough to keep directly. We still downcast
        # fp32/bf16 passthrough tensors to fp16 so metadata does not dominate size.
        if t.numel() <= INT8_KEEP_FLOAT_MAX_NUMEL:
            kept = keep_float_tensor(name, t, passthrough_orig_dtypes)
            passthrough[name] = kept
            stats["int8_payload_bytes"] += tensor_nbytes(kept)
            continue

        stats["num_float_tensors"] += 1
        q, s = quantize_float_tensor(t)
        if s.ndim > 0:
            # Non-scalar scale means the per-row scheme was used.
            qmeta[name] = {"scheme": "per_row", "axis": 0}
        quantized[name] = q
        scales[name] = s
        dtypes[name] = str(t.dtype).removeprefix("torch.")
        stats["int8_payload_bytes"] += tensor_nbytes(q) + tensor_nbytes(s)

    obj: dict[str, object] = {
        "__quant_format__": "int8_clean_per_row_v1",
        "quantized": quantized,
        "scales": scales,
        "dtypes": dtypes,
        "passthrough": passthrough,
    }
    if qmeta:
        obj["qmeta"] = qmeta
    if passthrough_orig_dtypes:
        obj["passthrough_orig_dtypes"] = passthrough_orig_dtypes
    return obj, stats

def dequantize_state_dict_int8(obj: dict[str, object]) -> dict[str, Tensor]:
    """Invert quantize_state_dict_int8: rebuild a float state dict from the export object."""
    out: dict[str, Tensor] = {}
    qmeta = obj.get("qmeta", {})
    passthrough_orig_dtypes = obj.get("passthrough_orig_dtypes", {})
    for name, q in obj["quantized"].items():
        dtype = getattr(torch, obj["dtypes"][name])
        s = obj["scales"][name]
        if qmeta.get(name, {}).get("scheme") == "per_row" or s.ndim > 0:
            s = s.to(dtype=torch.float32)
            # Broadcast the saved row scale back across trailing dimensions.
            out[name] = (q.float() * s.view(q.shape[0], *([1] * (q.ndim - 1)))).to(dtype=dtype).contiguous()
        else:
            scale = float(s.item())
            out[name] = (q.float() * scale).to(dtype=dtype).contiguous()
    for name, t in obj["passthrough"].items():
        # Restore small tensors, undoing the temporary fp16 storage cast if needed.
+ out_t = t.detach().to("cpu").contiguous() + orig_dtype = passthrough_orig_dtypes.get(name) + if isinstance(orig_dtype, str): + out_t = out_t.to(dtype=getattr(torch, orig_dtype)).contiguous() + out[name] = out_t + return out + + +# ----------------------------- +# DATA LOADING +# ----------------------------- + +def load_data_shard(file: Path) -> Tensor: + header_bytes = 256 * np.dtype(" None: + self.file_idx = (self.file_idx + 1) % len(self.files) + self.tokens = load_data_shard(self.files[self.file_idx]) + self.pos = 0 + + def take(self, n: int) -> Tensor: + chunks: list[Tensor] = [] + remaining = n + while remaining > 0: + avail = self.tokens.numel() - self.pos + if avail <= 0: + self._advance_file() + continue + k = min(remaining, avail) + chunks.append(self.tokens[self.pos : self.pos + k]) + self.pos += k + remaining -= k + return chunks[0] if len(chunks) == 1 else torch.cat(chunks) + + +class DistributedTokenLoader: + # Each call consumes a contiguous chunk from the shared token stream, then slices out + # one disjoint span per rank. The extra "+1" token lets us build (x, y) by shifting. 
    def __init__(self, pattern: str, rank: int, world_size: int, device: torch.device):
        self.rank = rank
        self.world_size = world_size
        self.device = device
        self.stream = TokenStream(pattern)

    def next_batch(self, global_tokens: int, seq_len: int, grad_accum_steps: int) -> tuple[Tensor, Tensor]:
        """Return one (x, y) micro-batch of shape (local_tokens // seq_len, seq_len) for this rank."""
        local_tokens = global_tokens // (self.world_size * grad_accum_steps)
        per_rank_span = local_tokens + 1  # extra token so targets can be built by shifting
        chunk = self.stream.take(per_rank_span * self.world_size)
        start = self.rank * per_rank_span
        local = chunk[start : start + per_rank_span].to(dtype=torch.int64)
        x = local[:-1].reshape(-1, seq_len)
        y = local[1:].reshape(-1, seq_len)
        return x.to(self.device, non_blocking=True), y.to(self.device, non_blocking=True)

# -----------------------------
# TRANSFORMER MODULES
# -----------------------------

class RMSNorm(nn.Module):
    # Parameter-free RMS normalization over the last dimension.
    def __init__(self, eps: float | None = None):
        super().__init__()
        self.eps = eps

    def forward(self, x: Tensor) -> Tensor:
        return F.rms_norm(x, (x.size(-1),), eps=self.eps)


class CastedLinear(nn.Linear):
    # Keep weights in fp32 for optimizer/state quality, cast at matmul time for bf16 compute.
    def forward(self, x: Tensor) -> Tensor:
        bias = self.bias.to(x.dtype) if self.bias is not None else None
        return F.linear(x, self.weight.to(x.dtype), bias)


def restore_low_dim_params_to_fp32(module: nn.Module) -> None:
    # Keep small/control parameters in fp32 even when the model body runs in bf16.
    with torch.no_grad():
        for name, param in module.named_parameters():
            if (param.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS)) and param.dtype != torch.float32:
                param.data = param.data.float()


class Rotary(nn.Module):
    # Caches cos/sin tables per sequence length on the current device.
    def __init__(self, dim: int, base: float = 10000.0):
        super().__init__()
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim))
        self.register_buffer("inv_freq", inv_freq, persistent=False)
        self._seq_len_cached = 0
        self._cos_cached: Tensor | None = None
        self._sin_cached: Tensor | None = None

    def forward(self, seq_len: int, device: torch.device, dtype: torch.dtype) -> tuple[Tensor, Tensor]:
        # Rebuild the tables only when seq_len or device changes; dtype is handled
        # by the cheap .to() on return.
        if (
            self._cos_cached is None
            or self._sin_cached is None
            or self._seq_len_cached != seq_len
            or self._cos_cached.device != device
        ):
            t = torch.arange(seq_len, device=device, dtype=self.inv_freq.dtype)
            freqs = torch.outer(t, self.inv_freq.to(device))
            self._cos_cached = freqs.cos()[None, None, :, :]
            self._sin_cached = freqs.sin()[None, None, :, :]
            self._seq_len_cached = seq_len
        return self._cos_cached.to(dtype=dtype), self._sin_cached.to(dtype=dtype)


def apply_rotary_emb(x: Tensor, cos: Tensor, sin: Tensor) -> Tensor:
    # Rotate the two halves of the head dimension by the cached angles.
    half = x.size(-1) // 2
    x1, x2 = x[..., :half], x[..., half:]
    return torch.cat((x1 * cos + x2 * sin, x1 * (-sin) + x2 * cos), dim=-1)


class CausalSelfAttention(nn.Module):
    def __init__(
        self,
        dim: int,
        num_heads: int,
        num_kv_heads: int,
        rope_base: float,
        qk_gain_init: float,
    ):
        super().__init__()
        if dim % num_heads != 0:
            raise ValueError("model_dim must be divisible by num_heads")
        if num_heads % num_kv_heads != 0:
            raise ValueError("num_heads must be divisible by num_kv_heads")
        self.num_heads = num_heads
        self.num_kv_heads = num_kv_heads
        self.head_dim = dim // num_heads
        if self.head_dim % 2 != 0:
            raise ValueError("head_dim must be even for RoPE")
        kv_dim = self.num_kv_heads * self.head_dim
        self.c_q = CastedLinear(dim, dim, bias=False)
        self.c_k = CastedLinear(dim, kv_dim, bias=False)
        self.c_v = CastedLinear(dim, kv_dim, bias=False)
        self.proj = CastedLinear(dim, dim, bias=False)
        self.proj._zero_init = True  # zero-init output proj so the block starts as identity
        # Learned per-head gain applied to normalized queries.
        self.q_gain = nn.Parameter(torch.full((num_heads,), qk_gain_init, dtype=torch.float32))
        self.rotary = Rotary(self.head_dim, base=rope_base)

    def forward(self, x: Tensor) -> Tensor:
        bsz, seqlen, dim = x.shape
        q = self.c_q(x).reshape(bsz, seqlen, self.num_heads, self.head_dim).transpose(1, 2)
        k = self.c_k(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2)
        v = self.c_v(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2)
        # QK-norm: normalize per head before RoPE.
        q = F.rms_norm(q, (q.size(-1),))
        k = F.rms_norm(k, (k.size(-1),))
        cos, sin = self.rotary(seqlen, x.device, q.dtype)
        q = apply_rotary_emb(q, cos, sin)
        k = apply_rotary_emb(k, cos, sin)
        q = q * self.q_gain.to(dtype=q.dtype)[None, :, None, None]
        if FA3_AVAILABLE:
            # FA3 expects (B, T, H, D) — transpose from (B, H, T, D)
            y = fa3.flash_attn_func(
                q.transpose(1, 2),
                k.transpose(1, 2),
                v.transpose(1, 2),
                causal=True,
            ).transpose(1, 2)
        else:
            y = F.scaled_dot_product_attention(
                q, k, v,
                attn_mask=None,
                is_causal=True,
                enable_gqa=(self.num_kv_heads != self.num_heads),
            )
        y = y.transpose(1, 2).contiguous().reshape(bsz, seqlen, dim)
        return self.proj(y)


class MLP(nn.Module):
    # leaky_relu^2 MLP (slope=0.5 matches SOTA, PRs #493/#518)
    def __init__(self, dim: int, mlp_mult: int, leaky_relu_slope: float = 0.5):
        super().__init__()
        hidden = mlp_mult * dim
        self.fc = CastedLinear(dim, hidden, bias=False)
        self.proj = CastedLinear(hidden, dim, bias=False)
        self.proj._zero_init = True
        self.leaky_relu_slope = leaky_relu_slope

    def forward(self, x: Tensor) -> Tensor:
        # Squared leaky-ReLU activation.
        x = F.leaky_relu(self.fc(x), negative_slope=self.leaky_relu_slope)
        return self.proj(x.square())


class TimestepScaling(nn.Module):
    """Learned per-iteration scale vectors (arXiv:2410.01405, ICML 2025)."""
    def __init__(self, effective_layers: int, dim: int, gamma_max: float = 0.0, use_bias: bool = False):
        super().__init__()
        # One (dim,) gamma row per virtual (unrolled) layer position.
        self.attn_gamma = nn.Parameter(torch.ones(effective_layers, dim, dtype=torch.float32))
        self.mlp_gamma = nn.Parameter(torch.ones(effective_layers, dim, dtype=torch.float32))
        self.gamma_max = gamma_max  # 0 = uncapped
        if use_bias:
            self.attn_beta = nn.Parameter(torch.zeros(effective_layers, dim, dtype=torch.float32))
            self.mlp_beta = nn.Parameter(torch.zeros(effective_layers, dim, dtype=torch.float32))
        else:
            self.attn_beta = None
            self.mlp_beta = None

    def get(self, v: int) -> tuple[Tensor, Tensor, Tensor | None, Tensor | None]:
        """Return (attn_gamma, mlp_gamma, attn_beta, mlp_beta) for virtual layer v."""
        ag = self.attn_gamma[v]
        mg = self.mlp_gamma[v]
        if self.gamma_max > 0:
            ag = ag.clamp(-self.gamma_max, self.gamma_max)
            mg = mg.clamp(-self.gamma_max, self.gamma_max)
        ab = self.attn_beta[v] if self.attn_beta is not None else None
        mb = self.mlp_beta[v] if self.mlp_beta is not None else None
        return ag, mg, ab, mb


class Block(nn.Module):
    def __init__(
        self,
        dim: int,
        num_heads: int,
        num_kv_heads: int,
        mlp_mult: int,
        rope_base: float,
        qk_gain_init: float,
        use_peri_norm: bool = False,
        use_birkhoff_mix: bool = False,
        leaky_relu_slope: float = 0.5,
    ):
        super().__init__()
        self.use_peri_norm = use_peri_norm
        self.use_birkhoff_mix = use_birkhoff_mix
        self.attn_norm = RMSNorm()
        if use_peri_norm:
            self.mlp_out_norm = RMSNorm()  # output norm: bounds MLP contribution, preserves input magnitude
        else:
            self.mlp_norm = RMSNorm()
        self.attn = CausalSelfAttention(dim, num_heads, num_kv_heads, rope_base, qk_gain_init)
        self.mlp = MLP(dim, mlp_mult, leaky_relu_slope)
        self.attn_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32))
        self.mlp_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32))
        if use_birkhoff_mix:
            # sigmoid(logit) gives a convex mix between the residual and the embedding stream.
            self.resid_mix_logit = nn.Parameter(torch.zeros(dim, dtype=torch.float32))
        else:
            # Unconstrained pair of mixing vectors, initialized to (1, 0) = pure residual.
            self.resid_mix = nn.Parameter(torch.stack((torch.ones(dim), torch.zeros(dim))).float())

    def forward(self, x: Tensor, x0: Tensor,
                ts_attn_gamma: Tensor | None = None,
                ts_mlp_gamma: Tensor | None = None,
                ts_attn_beta: Tensor | None = None,
                ts_mlp_beta: Tensor | None = None,
                depth_emb: Tensor | None = None,
                ext_attn_gain: Tensor | None = None,
                ext_mlp_gain: Tensor | None = None) -> Tensor:
        if self.use_birkhoff_mix:
            # Convex (doubly-stochastic-style) mix of residual stream and embedding stream.
            alpha = torch.sigmoid(self.resid_mix_logit.to(dtype=x.dtype))[None, None, :]
            x = alpha * x + (1 - alpha) * x0
        else:
            mix = self.resid_mix.to(dtype=x.dtype)
            x = mix[0][None, None, :] * x + mix[1][None, None, :] * x0
        if depth_emb is not None:
            x = x + depth_emb
        attn_normed = self.attn_norm(x)
        if ext_attn_gain is not None:
            attn_normed = attn_normed * ext_attn_gain[None, None, :]
        attn_out = self.attn(attn_normed)
        attn_s = self.attn_scale.to(dtype=x.dtype)[None, None, :]
        if ts_attn_gamma is not None:
            attn_s = attn_s * ts_attn_gamma[None, None, :]
        x = x + attn_s * attn_out
        if ts_attn_beta is not None:
            x = x + ts_attn_beta[None, None, :]
        if self.use_peri_norm:
            # NOTE(review): without an external gain the MLP sees the un-normalized
            # residual here; the asymmetry with the ext_mlp_gain branch (which does
            # normalize) looks deliberate since mlp_out_norm bounds the update — confirm.
            if ext_mlp_gain is not None:
                m_input = F.rms_norm(x, (x.size(-1),)) * ext_mlp_gain[None, None, :]
            else:
                m_input = x
            mlp_out = self.mlp_out_norm(self.mlp(m_input))
        else:
            mlp_normed = self.mlp_norm(x)
            if ext_mlp_gain is not None:
                mlp_normed = mlp_normed * ext_mlp_gain[None, None, :]
            mlp_out = self.mlp(mlp_normed)
        mlp_s = self.mlp_scale.to(dtype=x.dtype)[None, None, :]
        if ts_mlp_gamma is not None:
            mlp_s = mlp_s * ts_mlp_gamma[None, None, :]
        x = x + mlp_s * mlp_out
        if ts_mlp_beta is not None:
            x = x + ts_mlp_beta[None, None, :]
        return x


class GPT(nn.Module):
    """GPT with either weight-shared recurrence (prelude / shared x loops / coda) or a U-Net skip stack."""
    def __init__(
        self,
        vocab_size: int,
        num_layers: int,
        model_dim: int,
        num_heads: int,
        num_kv_heads: int,
        mlp_mult: int,
        tie_embeddings: bool,
        tied_embed_init_std: float,
        logit_softcap: float,
        rope_base: float,
        qk_gain_init: float,
        num_prelude: int = 0,
        num_coda: int = 0,
        num_shared: int = 0,
        num_loops: int = 2,
        use_peri_norm: bool = False,
        use_birkhoff_mix: bool = False,
        use_timestep_scale: bool = False,
        use_timestep_bias: bool = False,
        use_depth_embed: bool = False,
        use_unique_norms: bool = False,
        timestep_gamma_max: float = 0.0,
        leaky_relu_slope: float = 0.5,
    ):
        super().__init__()
        if logit_softcap <= 0.0:
            raise ValueError(f"logit_softcap must be positive, got {logit_softcap}")
        self.tie_embeddings = tie_embeddings
        self.tied_embed_init_std = tied_embed_init_std
        self.logit_softcap = logit_softcap
        self.tok_emb = nn.Embedding(vocab_size, model_dim)
        # Recurrence is enabled by requesting at least one shared block.
        self.use_recurrence = num_shared > 0
        self.num_loops = num_loops if self.use_recurrence else 1
        self.num_prelude = num_prelude if self.use_recurrence else 0
        self.num_coda = num_coda if self.use_recurrence else 0

        block_kwargs = dict(
            dim=model_dim, num_heads=num_heads, num_kv_heads=num_kv_heads,
            mlp_mult=mlp_mult, rope_base=rope_base, qk_gain_init=qk_gain_init,
            use_peri_norm=use_peri_norm if self.use_recurrence else False,
            use_birkhoff_mix=use_birkhoff_mix if self.use_recurrence else False,
            leaky_relu_slope=leaky_relu_slope,
        )

        if self.use_recurrence:
            self.prelude_blocks = nn.ModuleList([Block(**block_kwargs) for _ in range(self.num_prelude)])
            self.coda_blocks = nn.ModuleList([Block(**block_kwargs) for _ in range(self.num_coda)])
            self.shared_blocks = nn.ModuleList([Block(**block_kwargs) for _ in range(num_shared)])
            # Number of virtual layers the stream passes through once loops are unrolled.
            effective_layers = self.num_prelude + num_shared * self.num_loops + self.num_coda
            self.timestep_scale = TimestepScaling(effective_layers, model_dim, gamma_max=timestep_gamma_max, use_bias=use_timestep_bias) if use_timestep_scale else None
            self.use_depth_embed = use_depth_embed
            if self.use_depth_embed:
                self.depth_embeddings = nn.Parameter(torch.zeros(effective_layers, model_dim, dtype=torch.float32))
            self.use_unique_norms = use_unique_norms
            if self.use_unique_norms:
                # One gain pair per unrolled shared position (loops x shared blocks).
                num_unique = num_shared * self.num_loops
                self.unique_attn_gains = nn.Parameter(torch.ones(num_unique, model_dim, dtype=torch.float32))
                self.unique_mlp_gains = nn.Parameter(torch.ones(num_unique, model_dim, dtype=torch.float32))
        else:
            # Standard U-Net path
            self.num_encoder_layers = num_layers // 2
            self.num_decoder_layers = num_layers - self.num_encoder_layers
            self.num_skip_weights = min(self.num_encoder_layers, self.num_decoder_layers)
            self.skip_weights = nn.Parameter(torch.ones(self.num_skip_weights, model_dim, dtype=torch.float32))
            self.blocks = nn.ModuleList([Block(**block_kwargs) for _ in range(num_layers)])
            self.timestep_scale = None
            self.use_depth_embed = False
            self.use_unique_norms = False

        self.final_norm = RMSNorm()
        self.lm_head = None if tie_embeddings else CastedLinear(model_dim, vocab_size, bias=False)
        if self.lm_head is not None:
            self.lm_head._zero_init = True
        self._init_weights()

    def _init_weights(self) -> None:
        # Tied embeddings get a small normal init; all _zero_init-marked linears start at zero.
        if self.tie_embeddings:
            nn.init.normal_(self.tok_emb.weight, mean=0.0, std=self.tied_embed_init_std)
        for module in self.modules():
            if isinstance(module, nn.Linear) and getattr(module, "_zero_init", False):
                nn.init.zeros_(module.weight)

    def _get_ts(self, v: int) -> tuple[Tensor | None, Tensor | None, Tensor | None, Tensor | None]:
        # Per-virtual-layer timestep scales, or all-None when disabled.
        if self.timestep_scale is None:
            return None, None, None, None
        return self.timestep_scale.get(v)

    def forward(self, input_ids: Tensor, target_ids: Tensor) -> Tensor:
        """Return mean next-token cross-entropy (natural log) for the batch."""
        x = self.tok_emb(input_ids)
        x = F.rms_norm(x, (x.size(-1),))
        x0 = x  # embedding stream, re-injected by each block's residual mix

        if self.use_recurrence:
            v = 0  # virtual (unrolled) layer index into per-depth parameters
            for block in self.prelude_blocks:
                ag, mg, ab, mb = self._get_ts(v)
                de = self.depth_embeddings[v] if self.use_depth_embed else None
                x = block(x, x0, ag, mg, ab, mb, de)
                v += 1
            uid = 0  # index into the per-unrolled-position unique gains
            for _loop in range(self.num_loops):
                for block in self.shared_blocks:
                    ag, mg, ab, mb = self._get_ts(v)
                    de = self.depth_embeddings[v] if self.use_depth_embed else None
                    if self.use_unique_norms:
                        ag_n = self.unique_attn_gains[uid].to(dtype=x.dtype)
                        mg_n = self.unique_mlp_gains[uid].to(dtype=x.dtype)
                        x = block(x, x0, ag, mg, ab, mb, de, ag_n, mg_n)
                    else:
                        x = block(x, x0, ag, mg, ab, mb, de)
                    uid += 1
                    v += 1
            for block in self.coda_blocks:
                ag, mg, ab, mb = self._get_ts(v)
                de = self.depth_embeddings[v] if self.use_depth_embed else None
                x = block(x, x0, ag, mg, ab, mb, de)
                v += 1
        else:
            skips: list[Tensor] = []
            for i in range(self.num_encoder_layers):
                x = self.blocks[i](x, x0)
                skips.append(x)
            for i in range(self.num_decoder_layers):
                if skips:
                    # Pair decoder layer i with the deepest unconsumed encoder output.
                    x = x + self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips.pop()
                x = self.blocks[self.num_encoder_layers + i](x, x0)

        x = self.final_norm(x).reshape(-1, x.size(-1))
        targets = target_ids.reshape(-1)
        if self.tie_embeddings:
            logits_proj = F.linear(x, self.tok_emb.weight)
        else:
            if self.lm_head is None:
                raise RuntimeError("lm_head is required when tie_embeddings=False")
            logits_proj = self.lm_head(x)
        # tanh soft-cap bounds logit magnitude before the loss.
        logits = self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap)
        return F.cross_entropy(logits.float(), targets, reduction="mean")


@torch.no_grad()
def recurrence_param_diagnostics(gpt: GPT) -> str:
    """Log parameter norms for recurrence health monitoring. No forward pass."""
    if not gpt.use_recurrence:
        return ""
    parts: list[str] = []

    # Birkhoff alpha stats per shared block
    for i, block in enumerate(gpt.shared_blocks):
        if hasattr(block, "resid_mix_logit"):
            a = torch.sigmoid(block.resid_mix_logit)
            parts.append(f"shared{i}_alpha:mean={a.mean().item():.3f},std={a.std().item():.3f}")

    # Effective MLP/attn contribution scale per virtual layer position:
    # mlp_scale * ts_mlp_gamma[v] is the learned multiplier on MLP output.
    # We report RMS (norm / sqrt(numel)) to make it scale-independent.
    v = 0  # virtual layer index, walked in the same order as GPT.forward
    effective_count = gpt.num_prelude + len(gpt.shared_blocks) * gpt.num_loops + gpt.num_coda
    mlp_norms: list[str] = []
    attn_norms: list[str] = []

    # NOTE(review): the three loops below (prelude / shared / coda) duplicate the
    # same per-block report logic and could share a helper.
    # Prelude blocks
    for block in gpt.prelude_blocks:
        ms = block.mlp_scale.norm().item()
        asc = block.attn_scale.norm().item()
        d = block.mlp_scale.numel() ** 0.5
        if gpt.timestep_scale is not None:
            ms_g = gpt.timestep_scale.mlp_gamma[v].norm().item()
            as_g = gpt.timestep_scale.attn_gamma[v].norm().item()
            mlp_norms.append(f"v{v}:{ms * ms_g / d:.4f}")
            attn_norms.append(f"v{v}:{asc * as_g / d:.4f}")
        else:
            mlp_norms.append(f"v{v}:{ms / d:.4f}")
            attn_norms.append(f"v{v}:{asc / d:.4f}")
        v += 1

    # Shared positions
    for _loop in range(gpt.num_loops):
        for block in gpt.shared_blocks:
            ms = block.mlp_scale.norm().item()
            asc = block.attn_scale.norm().item()
            d = block.mlp_scale.numel() ** 0.5
            if gpt.timestep_scale is not None:
                ms_g = gpt.timestep_scale.mlp_gamma[v].norm().item()
                as_g = gpt.timestep_scale.attn_gamma[v].norm().item()
                mlp_norms.append(f"v{v}:{ms * ms_g / d:.4f}")
                attn_norms.append(f"v{v}:{asc * as_g / d:.4f}")
            else:
                mlp_norms.append(f"v{v}:{ms / d:.4f}")
                attn_norms.append(f"v{v}:{asc / d:.4f}")
            v += 1

    # Coda blocks
    for block in gpt.coda_blocks:
        ms = block.mlp_scale.norm().item()
        asc = block.attn_scale.norm().item()
        d = block.mlp_scale.numel() ** 0.5
        if gpt.timestep_scale is not None:
            ms_g = gpt.timestep_scale.mlp_gamma[v].norm().item()
            as_g = gpt.timestep_scale.attn_gamma[v].norm().item()
            mlp_norms.append(f"v{v}:{ms * ms_g / d:.4f}")
            attn_norms.append(f"v{v}:{asc * as_g / d:.4f}")
        else:
            mlp_norms.append(f"v{v}:{ms / d:.4f}")
            attn_norms.append(f"v{v}:{asc / d:.4f}")
        v += 1

    parts.append("eff_mlp_scale:[" + " ".join(mlp_norms) + "]")
    parts.append("eff_attn_scale:[" + " ".join(attn_norms) + "]")
    if gpt.timestep_scale is not None and gpt.timestep_scale.attn_beta is not None:
        attn_bias_norms: list[str] = []
        mlp_bias_norms: 
list[str] = []
    # RMS (norm / sqrt(numel)) of the per-effective-layer attention/MLP bias
    # vectors held by gpt.timestep_scale, formatted one entry per virtual layer.
    for vi in range(effective_count):
        ab_rms = gpt.timestep_scale.attn_beta[vi].norm().item() / gpt.timestep_scale.attn_beta[vi].numel() ** 0.5
        mb_rms = gpt.timestep_scale.mlp_beta[vi].norm().item() / gpt.timestep_scale.mlp_beta[vi].numel() ** 0.5
        attn_bias_norms.append(f"v{vi}:{ab_rms:.4f}")
        mlp_bias_norms.append(f"v{vi}:{mb_rms:.4f}")
    parts.append("eff_attn_bias:[" + " ".join(attn_bias_norms) + "]")
    parts.append("eff_mlp_bias:[" + " ".join(mlp_bias_norms) + "]")
    if gpt.use_unique_norms:
        # Per-unique-block norm-gain RMS values (one row per entry of the gain tensors).
        un_attn: list[str] = []
        un_mlp: list[str] = []
        for ui in range(gpt.unique_attn_gains.size(0)):
            an_rms = gpt.unique_attn_gains[ui].norm().item() / gpt.unique_attn_gains[ui].numel() ** 0.5
            un_attn.append(f"u{ui}:{an_rms:.4f}")
            mn_rms = gpt.unique_mlp_gains[ui].norm().item() / gpt.unique_mlp_gains[ui].numel() ** 0.5
            un_mlp.append(f"u{ui}:{mn_rms:.4f}")
        parts.append("unique_attn_gain_rms:[" + " ".join(un_attn) + "]")
        parts.append("unique_mlp_gain_rms:[" + " ".join(un_mlp) + "]")
    if gpt.use_depth_embed:
        # RMS of each depth-embedding vector, indexed by effective (virtual) layer.
        de_norms: list[str] = []
        for vi in range(effective_count):
            de_rms = gpt.depth_embeddings[vi].norm().item() / gpt.depth_embeddings[vi].numel() ** 0.5
            de_norms.append(f"v{vi}:{de_rms:.4f}")
        parts.append("depth_emb_rms:[" + " ".join(de_norms) + "]")
    return " ".join(parts)


# -----------------------------
# TRAINING
# -----------------------------

def main() -> None:
    """Run one full training job end to end.

    Stages, in order: distributed/CUDA setup, tokenizer + validation-metric
    setup, model and optimizer construction, compiled-path warmup (with full
    state restore), the timed training loop with periodic validation and a
    wallclock cap, and finally serialization plus an int8+zlib round-trip
    validation of the saved weights.
    """
    global zeropower_via_newtonschulz5

    # Own source is logged (below) and counted toward the submission size.
    code = Path(__file__).read_text(encoding="utf-8")
    args = Hyperparameters()
    zeropower_via_newtonschulz5 = torch.compile(zeropower_via_newtonschulz5)

    # -----------------------------
    # DISTRIBUTED + CUDA SETUP
    # -----------------------------

    # torchrun sets RANK/WORLD_SIZE; absence of either means single-process mode.
    distributed = "RANK" in os.environ and "WORLD_SIZE" in os.environ
    rank = int(os.environ.get("RANK", "0"))
    world_size = int(os.environ.get("WORLD_SIZE", "1"))
    local_rank = int(os.environ.get("LOCAL_RANK", "0"))
    if world_size <= 0:
        raise ValueError(f"WORLD_SIZE must be positive, got {world_size}")
    # The global batch is split into 8 micro-steps total, so world_size must divide 8.
    if 8 % world_size != 0:
        raise ValueError(f"WORLD_SIZE={world_size} must divide 8 so grad_accum_steps stays integral")
    grad_accum_steps = 8 // world_size
    # Scale each micro-step loss so accumulated gradients average over micro-steps.
    grad_scale = 1.0 / grad_accum_steps
    if not torch.cuda.is_available():
        raise RuntimeError("CUDA is required")
    device = torch.device("cuda", local_rank)
    torch.cuda.set_device(device)
    if distributed:
        dist.init_process_group(backend="nccl", device_id=device)
        dist.barrier()
    master_process = rank == 0

    # Fast math knobs
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True
    from torch.backends.cuda import enable_cudnn_sdp, enable_flash_sdp, enable_math_sdp, enable_mem_efficient_sdp

    # Force SDPA to the flash backend only.
    enable_cudnn_sdp(False)
    enable_flash_sdp(True)
    enable_mem_efficient_sdp(False)
    enable_math_sdp(False)

    logfile = None
    if master_process:
        os.makedirs("logs", exist_ok=True)
        logfile = f"logs/{args.run_id}.txt"
        print(logfile)

    def log0(msg: str, console: bool = True) -> None:
        """Log `msg` on rank 0 only: always to the logfile, optionally to stdout."""
        if not master_process:
            return
        if console:
            print(msg)
        if logfile is not None:
            # Reopen per call (append mode) so every line is flushed to disk.
            with open(logfile, "a", encoding="utf-8") as f:
                print(msg, file=f)

    # Record the full script, environment, and GPU inventory in the logfile only.
    log0(code, console=False)
    log0("=" * 100, console=False)
    log0(f"Running Python {sys.version}", console=False)
    log0(f"Running PyTorch {torch.__version__}", console=False)
    log0(
        subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, check=False).stdout,
        console=False,
    )
    log0("=" * 100, console=False)

    # -----------------------------
    # TOKENIZER + VALIDATION METRIC SETUP
    # -----------------------------

    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)

    if not args.tokenizer_path.endswith(".model"):
        raise ValueError(f"Script only setup for SentencePiece .model file: {args.tokenizer_path}")
    sp = spm.SentencePieceProcessor(model_file=args.tokenizer_path)
    if int(sp.vocab_size()) != args.vocab_size:
        raise ValueError(
            f"VOCAB_SIZE={args.vocab_size} does not match tokenizer vocab_size={int(sp.vocab_size())}"
        )
    dataset_dir = Path(args.data_path).resolve()
    actual_train_files = len(list(dataset_dir.glob("fineweb_train_*.bin")))
    val_tokens = load_validation_tokens(args.val_files, args.train_seq_len)
    # Lookup tables used by the bits-per-byte (bpb) validation metric.
    base_bytes_lut, has_leading_space_lut, is_boundary_token_lut = build_sentencepiece_luts(
        sp, args.vocab_size, device
    )
    log0(f"val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path={args.tokenizer_path}")
    log0(f"train_loader:dataset:{dataset_dir.name} train_shards:{actual_train_files}")
    log0(f"val_loader:shards pattern={args.val_files} tokens:{val_tokens.numel() - 1}")

    # -----------------------------
    # MODEL + OPTIMIZER SETUP
    # -----------------------------

    base_model = GPT(
        vocab_size=args.vocab_size,
        num_layers=args.num_layers,
        model_dim=args.model_dim,
        num_heads=args.num_heads,
        num_kv_heads=args.num_kv_heads,
        mlp_mult=args.mlp_mult,
        tie_embeddings=args.tie_embeddings,
        tied_embed_init_std=args.tied_embed_init_std,
        logit_softcap=args.logit_softcap,
        rope_base=args.rope_base,
        qk_gain_init=args.qk_gain_init,
        num_prelude=args.num_prelude,
        num_coda=args.num_coda,
        num_shared=args.num_shared,
        num_loops=args.num_loops,
        use_peri_norm=args.use_peri_norm,
        use_birkhoff_mix=args.use_birkhoff_mix,
        use_timestep_scale=args.use_timestep_scale,
        use_timestep_bias=args.use_timestep_bias,
        use_depth_embed=args.use_depth_embed,
        use_unique_norms=args.use_unique_norms,
        timestep_gamma_max=args.timestep_gamma_max,
        leaky_relu_slope=args.leaky_relu_slope,
    ).to(device).bfloat16()
    # Mixed precision layout: whole model bf16, but CastedLinear modules and
    # low-dimensional params are kept/restored in fp32.
    for module in base_model.modules():
        if isinstance(module, CastedLinear):
            module.float()
    restore_low_dim_params_to_fp32(base_model)
    if args.disable_compile:
        compiled_model = base_model
    else:
        compiled_model = torch.compile(base_model, dynamic=False, fullgraph=True)
    model: nn.Module = DDP(compiled_model, device_ids=[local_rank], broadcast_buffers=False) if distributed else compiled_model

    # Optimizer split:
    # - token embedding (Adam) uses EMBED_LR
    # - untied lm_head (Adam) uses HEAD_LR
    # - matrix params in transformer blocks use MATRIX_LR via Muon
    # - vectors/scalars use SCALAR_LR via Adam
    if base_model.use_recurrence:
        # With recurrence, block params come from prelude/shared/coda lists,
        # plus the per-unique-block norm gains when enabled.
        all_block_lists = [base_model.prelude_blocks, base_model.shared_blocks, base_model.coda_blocks]
        block_named_params = []
        for bl in all_block_lists:
            block_named_params.extend(bl.named_parameters())
        if base_model.use_unique_norms:
            block_named_params.extend([("unique_attn_gains", base_model.unique_attn_gains)])
            block_named_params.extend([("unique_mlp_gains", base_model.unique_mlp_gains)])
    else:
        block_named_params = list(base_model.blocks.named_parameters())
    # 2-D weights go to Muon, except control tensors (matched by name pattern),
    # which join the scalar/vector Adam group.
    matrix_params = [
        p
        for name, p in block_named_params
        if p.ndim == 2 and not any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS)
    ]
    scalar_params = [
        p
        for name, p in block_named_params
        if p.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS)
    ]
    if hasattr(base_model, "skip_weights") and base_model.skip_weights.numel() > 0:
        scalar_params.append(base_model.skip_weights)
    if base_model.timestep_scale is not None:
        for p in base_model.timestep_scale.parameters():
            scalar_params.append(p)
    if base_model.use_depth_embed:
        scalar_params.append(base_model.depth_embeddings)
    token_lr = args.tied_embed_lr if args.tie_embeddings else args.embed_lr
    # "base_lr" is stashed in each param group so the schedule can rescale "lr"
    # every step without losing the original value.
    optimizer_tok = torch.optim.Adam(
        [{"params": [base_model.tok_emb.weight], "lr": token_lr, "base_lr": token_lr}],
        betas=(args.beta1, args.beta2),
        eps=args.adam_eps,
        fused=True,
    )
    optimizer_muon = Muon(
        matrix_params,
        lr=args.matrix_lr,
        momentum=args.muon_momentum,
        backend_steps=args.muon_backend_steps,
    )
    for group in optimizer_muon.param_groups:
        group["base_lr"] = args.matrix_lr
    optimizer_scalar = torch.optim.Adam(
        [{"params": scalar_params, "lr": args.scalar_lr, "base_lr": args.scalar_lr}],
        betas=(args.beta1, args.beta2),
        eps=args.adam_eps,
        fused=True,
    )
    optimizers: list[torch.optim.Optimizer] = [optimizer_tok, optimizer_muon, optimizer_scalar]
    if base_model.lm_head is not None:
        # Untied output head gets its own Adam group with HEAD_LR.
        optimizer_head = torch.optim.Adam(
            [{"params": [base_model.lm_head.weight], "lr": args.head_lr, "base_lr": args.head_lr}],
            betas=(args.beta1, args.beta2),
            eps=args.adam_eps,
            fused=True,
        )
        optimizers.insert(1, optimizer_head)

    n_params = sum(p.numel() for p in base_model.parameters())
    log0(f"model_params:{n_params}")
    log0(f"world_size:{world_size} grad_accum_steps:{grad_accum_steps}")
    log0("sdp_backends:cudnn=False flash=True mem_efficient=False math=False")
    log0(f"flash_attention_3:{FA3_AVAILABLE}")
    log0(f"attention_mode:gqa num_heads:{args.num_heads} num_kv_heads:{args.num_kv_heads}")
    log0(
        f"tie_embeddings:{args.tie_embeddings} embed_lr:{token_lr} "
        f"head_lr:{args.head_lr if base_model.lm_head is not None else 0.0} "
        f"matrix_lr:{args.matrix_lr} scalar_lr:{args.scalar_lr}"
    )
    log0(
        f"train_batch_tokens:{args.train_batch_tokens} train_seq_len:{args.train_seq_len} "
        f"iterations:{args.iterations} warmup_steps:{args.warmup_steps} "
        f"max_wallclock_seconds:{args.max_wallclock_seconds:.3f}"
    )
    log0(f"seed:{args.seed}")
    if base_model.use_recurrence:
        num_shared = len(base_model.shared_blocks)
        # Effective depth = prelude + shared-blocks unrolled over loops + coda.
        eff = base_model.num_prelude + num_shared * base_model.num_loops + base_model.num_coda
        log0(f"recurrence:enabled prelude:{base_model.num_prelude} shared:{num_shared} "
             f"loops:{base_model.num_loops} coda:{base_model.num_coda} effective_layers:{eff}")
        log0(f"peri_norm:{args.use_peri_norm} birkhoff_mix:{args.use_birkhoff_mix}")
        if base_model.timestep_scale is not None:
            ts_params = sum(p.numel() for p in base_model.timestep_scale.parameters())
            log0(f"timestep_scale:enabled params:{ts_params}")
        else:
            log0("timestep_scale:disabled")
        log0(f"depth_embed:{'enabled' if base_model.use_depth_embed else 'disabled'}")
        log0(f"unique_norms:{'enabled' if base_model.use_unique_norms else 'disabled'}")
    else:
        log0(f"recurrence:disabled num_layers:{args.num_layers}")
    compile_mode = "disabled" if args.disable_compile else "fullgraph=True"
    log0(f"compile_mode:{compile_mode}")

    # -----------------------------
    # DATA LOADER & MODEL WARMUP
    # -----------------------------

    train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device)

    def zero_grad_all() -> None:
        """Clear gradients on every optimizer (freeing memory via set_to_none)."""
        for opt in optimizers:
            opt.zero_grad(set_to_none=True)

    # None disables the wallclock cap entirely.
    max_wallclock_ms = 1000.0 * args.max_wallclock_seconds if args.max_wallclock_seconds > 0 else None

    def lr_mul(step: int, elapsed_ms: float) -> float:
        """Return the LR multiplier for this step (constant 1.0, then a linear warmdown).

        Without a wallclock cap the warmdown is step-based over the final
        `warmdown_iters` iterations; with a cap, the remaining wall time is
        estimated from the observed per-step time and the warmdown ramps down
        over the last `warmdown_iters`-worth of projected steps.
        """
        if args.warmdown_iters <= 0:
            return 1.0
        if max_wallclock_ms is None:
            warmdown_start = max(args.iterations - args.warmdown_iters, 0)
            return max((args.iterations - step) / max(args.warmdown_iters, 1), 0.0) if warmdown_start <= step < args.iterations else 1.0
        step_ms = elapsed_ms / max(step, 1)
        warmdown_ms = args.warmdown_iters * step_ms
        remaining_ms = max(max_wallclock_ms - elapsed_ms, 0.0)
        return remaining_ms / max(warmdown_ms, 1e-9) if remaining_ms <= warmdown_ms else 1.0

    # Warmup primes the compiled forward/backward/optimizer paths, then we restore the
    # initial weights/optimizer state so measured training starts from the true init.
    if args.warmup_steps > 0:
        initial_model_state = {name: tensor.detach().cpu().clone() for name, tensor in base_model.state_dict().items()}
        initial_optimizer_states = [copy.deepcopy(opt.state_dict()) for opt in optimizers]
        model.train()
        for warmup_step in range(args.warmup_steps):
            zero_grad_all()
            for micro_step in range(grad_accum_steps):
                if distributed:
                    # Only all-reduce gradients on the final micro-step.
                    model.require_backward_grad_sync = micro_step == grad_accum_steps - 1
                x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps)
                with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True):
                    warmup_loss = model(x, y)
                (warmup_loss * grad_scale).backward()
            for opt in optimizers:
                opt.step()
            zero_grad_all()
            if args.warmup_steps <= 20 or (warmup_step + 1) % 10 == 0 or warmup_step + 1 == args.warmup_steps:
                log0(f"warmup_step:{warmup_step + 1}/{args.warmup_steps}")
        # Roll everything back so the measured run starts from the true init,
        # including a fresh data loader so no training tokens are skipped.
        base_model.load_state_dict(initial_model_state, strict=True)
        for opt, state in zip(optimizers, initial_optimizer_states, strict=True):
            opt.load_state_dict(state)
        zero_grad_all()
        if distributed:
            model.require_backward_grad_sync = True
        train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device)

    # -----------------------------
    # MAIN TRAINING LOOP
    # -----------------------------

    training_time_ms = 0.0
    stop_after_step: int | None = None
    torch.cuda.synchronize()
    t0 = time.perf_counter()

    step = 0
    while True:
        last_step = step == args.iterations or (stop_after_step is not None and step >= stop_after_step)

        should_validate = last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0)
        if should_validate:
            # Pause the training clock (synchronize, bank elapsed time) so
            # validation time is excluded from train_time.
            torch.cuda.synchronize()
            training_time_ms += 1000.0 * (time.perf_counter() - t0)
            val_loss, val_bpb = eval_val(
                args,
                model,
                rank,
                world_size,
                device,
                grad_accum_steps,
                val_tokens,
                base_bytes_lut,
                has_leading_space_lut,
                is_boundary_token_lut,
            )
            log0(
                f"step:{step}/{args.iterations} val_loss:{val_loss:.4f} val_bpb:{val_bpb:.4f} "
                f"train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms / max(step, 1):.2f}ms"
            )
            torch.cuda.synchronize()
            t0 = time.perf_counter()

        if last_step:
            if stop_after_step is not None and step < args.iterations:
                log0(
                    f"stopping_early: wallclock_cap train_time:{training_time_ms:.0f}ms "
                    f"step:{step}/{args.iterations}"
                )
            break

        elapsed_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0)
        scale = lr_mul(step, elapsed_ms)
        zero_grad_all()
        train_loss = torch.zeros((), device=device)
        for micro_step in range(grad_accum_steps):
            if distributed:
                # Skip gradient all-reduce on all but the final micro-step.
                model.require_backward_grad_sync = micro_step == grad_accum_steps - 1
            x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps)
            with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True):
                loss = model(x, y)
            train_loss += loss.detach()
            (loss * grad_scale).backward()
        train_loss /= grad_accum_steps

        # Linearly ramp Muon momentum from warmup_start to its final value.
        frac = min(step / args.muon_momentum_warmup_steps, 1.0) if args.muon_momentum_warmup_steps > 0 else 1.0
        muon_momentum = (1 - frac) * args.muon_momentum_warmup_start + frac * args.muon_momentum
        for group in optimizer_muon.param_groups:
            group["momentum"] = muon_momentum

        # Apply the schedule multiplier on top of each group's stored base_lr.
        for opt in optimizers:
            for group in opt.param_groups:
                group["lr"] = group["base_lr"] * scale

        if args.grad_clip_norm > 0:
            torch.nn.utils.clip_grad_norm_(base_model.parameters(), args.grad_clip_norm)
        for opt in optimizers:
            opt.step()
        zero_grad_all()

        step += 1
        # Approximate (no cuda sync) so logging stays cheap.
        approx_training_time_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0)
        should_log_train = (
            args.train_log_every > 0
            and (step <= 10 or step % args.train_log_every == 0 or stop_after_step is not None)
        )
        if should_log_train:
            log0(
                f"step:{step}/{args.iterations} train_loss:{train_loss.item():.4f} "
                f"train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms / step:.2f}ms"
            )
            if base_model.use_recurrence and step % args.train_log_every == 0:
                diag = recurrence_param_diagnostics(base_model)
                if diag:
                    log0(f"step:{step} {diag}")

        # Needed to sync whether we've reached the wallclock cap.
        reached_cap = max_wallclock_ms is not None and approx_training_time_ms >= max_wallclock_ms
        if distributed and max_wallclock_ms is not None:
            # MAX all-reduce: if any rank hit the cap, every rank stops at the same step.
            reached_cap_tensor = torch.tensor(int(reached_cap), device=device)
            dist.all_reduce(reached_cap_tensor, op=dist.ReduceOp.MAX)
            reached_cap = bool(reached_cap_tensor.item())
        if stop_after_step is None and reached_cap:
            stop_after_step = step

    log0(
        f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB "
        f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB"
    )

    # -----------------------------
    # SERIALIZATION + ROUNDTRIP VALIDATION
    # -----------------------------
    # Save the raw state (useful for debugging/loading in PyTorch directly), then always produce
    # the compressed int8+zlib artifact and validate the round-tripped weights.

    if master_process:
        torch.save(base_model.state_dict(), "final_model.pt")
        model_bytes = os.path.getsize("final_model.pt")
        code_bytes = len(code.encode("utf-8"))
        log0(f"Serialized model: {model_bytes} bytes")
        log0(f"Code size: {code_bytes} bytes")
        log0(f"Total submission size: {model_bytes + code_bytes} bytes")

    # Quantization runs on every rank (cheap, keeps ranks in lockstep); only
    # rank 0 writes the artifact to disk.
    quant_obj, quant_stats = quantize_state_dict_int8(base_model.state_dict())
    quant_buf = io.BytesIO()
    torch.save(quant_obj, quant_buf)
    quant_raw = quant_buf.getvalue()
    quant_blob = zlib.compress(quant_raw, level=9)
    quant_raw_bytes = len(quant_raw)
    if master_process:
        with open("final_model.int8.ptz", "wb") as f:
            f.write(quant_blob)
        quant_file_bytes = os.path.getsize("final_model.int8.ptz")
        code_bytes = len(code.encode("utf-8"))
        ratio = quant_stats["baseline_tensor_bytes"] / max(quant_stats["int8_payload_bytes"], 1)
        log0(
            f"Serialized model int8+zlib: {quant_file_bytes} bytes "
            f"(payload:{quant_stats['int8_payload_bytes']} raw_torch:{quant_raw_bytes} payload_ratio:{ratio:.2f}x)"
        )
        log0(f"Total submission size int8+zlib: {quant_file_bytes + code_bytes} bytes")

    # Round-trip: every rank reloads the on-disk compressed artifact (barrier
    # ensures rank 0 has finished writing) and re-runs validation on it, so the
    # reported final numbers reflect the weights as actually shipped.
    if distributed:
        dist.barrier()
    with open("final_model.int8.ptz", "rb") as f:
        quant_blob_disk = f.read()
    quant_state = torch.load(io.BytesIO(zlib.decompress(quant_blob_disk)), map_location="cpu")
    base_model.load_state_dict(dequantize_state_dict_int8(quant_state), strict=True)
    torch.cuda.synchronize()
    t_qeval = time.perf_counter()
    q_val_loss, q_val_bpb = eval_val(
        args,
        model,
        rank,
        world_size,
        device,
        grad_accum_steps,
        val_tokens,
        base_bytes_lut,
        has_leading_space_lut,
        is_boundary_token_lut,
    )
    torch.cuda.synchronize()
    log0(
        f"final_int8_zlib_roundtrip val_loss:{q_val_loss:.4f} val_bpb:{q_val_bpb:.4f} "
        f"eval_time:{1000.0 * (time.perf_counter() - t_qeval):.0f}ms"
    )
    log0(f"final_int8_zlib_roundtrip_exact val_loss:{q_val_loss:.8f} val_bpb:{q_val_bpb:.8f}")

    if distributed:
        dist.destroy_process_group()
+ + +if __name__ == "__main__": + main() + +==================================================================================================== +Running Python 3.12.3 (main, Nov 6 2025, 13:44:16) [GCC 13.3.0] +Running PyTorch 2.9.1+cu128 +Thu Apr 2 14:33:55 2026 ++-----------------------------------------------------------------------------------------+ +| NVIDIA-SMI 580.126.09 Driver Version: 580.126.09 CUDA Version: 13.0 | ++-----------------------------------------+------------------------+----------------------+ +| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC | +| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. | +| | | MIG M. | +|=========================================+========================+======================| +| 0 NVIDIA H100 80GB HBM3 On | 00000000:19:00.0 Off | 0 | +| N/A 33C P0 118W / 700W | 1521MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 1 NVIDIA H100 80GB HBM3 On | 00000000:3B:00.0 Off | 0 | +| N/A 31C P0 114W / 700W | 1521MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 2 NVIDIA H100 80GB HBM3 On | 00000000:4C:00.0 Off | 0 | +| N/A 30C P0 119W / 700W | 1521MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 3 NVIDIA H100 80GB HBM3 On | 00000000:5D:00.0 Off | 0 | +| N/A 32C P0 116W / 700W | 1521MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 4 NVIDIA H100 80GB HBM3 On | 00000000:9B:00.0 Off | 0 | +| N/A 33C P0 117W / 700W | 1521MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 5 NVIDIA H100 80GB HBM3 On | 00000000:BB:00.0 Off | 0 | +| N/A 31C 
P0 116W / 700W | 1521MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 6 NVIDIA H100 80GB HBM3 On | 00000000:CB:00.0 Off | 0 | +| N/A 31C P0 114W / 700W | 1521MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 7 NVIDIA H100 80GB HBM3 On | 00000000:DB:00.0 Off | 0 | +| N/A 29C P0 117W / 700W | 1521MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ + ++-----------------------------------------------------------------------------------------+ +| Processes: | +| GPU GI CI PID Type Process name GPU Memory | +| ID ID Usage | +|=========================================================================================| +| No running processes found | ++-----------------------------------------------------------------------------------------+ + +==================================================================================================== +val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path=./data/tokenizers/fineweb_1024_bpe.model +train_loader:dataset:fineweb10B_sp1024 train_shards:80 +val_loader:shards pattern=./data/datasets/fineweb10B_sp1024/fineweb_val_*.bin tokens:62021632 +model_params:11577392 +world_size:8 grad_accum_steps:1 +sdp_backends:cudnn=False flash=True mem_efficient=False math=False +flash_attention_3:True +attention_mode:gqa num_heads:8 num_kv_heads:4 +tie_embeddings:True embed_lr:0.05 head_lr:0.0 matrix_lr:0.04 scalar_lr:0.04 +train_batch_tokens:524288 train_seq_len:1024 iterations:20000 warmup_steps:20 max_wallclock_seconds:600.000 +seed:1337 +recurrence:enabled prelude:1 shared:4 loops:2 coda:1 effective_layers:10 +peri_norm:True birkhoff_mix:True +timestep_scale:enabled params:20480 +depth_embed:enabled +unique_norms:enabled +compile_mode:fullgraph=True 
+warmup_step:1/20 +warmup_step:2/20 +warmup_step:3/20 +warmup_step:4/20 +warmup_step:5/20 +warmup_step:6/20 +warmup_step:7/20 +warmup_step:8/20 +warmup_step:9/20 +warmup_step:10/20 +warmup_step:11/20 +warmup_step:12/20 +warmup_step:13/20 +warmup_step:14/20 +warmup_step:15/20 +warmup_step:16/20 +warmup_step:17/20 +warmup_step:18/20 +warmup_step:19/20 +warmup_step:20/20 +step:0/20000 val_loss:6.9379 val_bpb:4.1090 train_time:0ms step_avg:0.02ms +step:1/20000 train_loss:6.9377 train_time:30ms step_avg:29.61ms +step:2/20000 train_loss:9.2669 train_time:82ms step_avg:41.06ms +step:3/20000 train_loss:8.5810 train_time:131ms step_avg:43.65ms +step:4/20000 train_loss:10.0418 train_time:180ms step_avg:45.04ms +step:5/20000 train_loss:9.5696 train_time:228ms step_avg:45.54ms +step:6/20000 train_loss:9.1372 train_time:275ms step_avg:45.91ms +step:7/20000 train_loss:7.5269 train_time:323ms step_avg:46.17ms +step:8/20000 train_loss:6.6249 train_time:371ms step_avg:46.38ms +step:9/20000 train_loss:6.0913 train_time:419ms step_avg:46.52ms +step:10/20000 train_loss:5.7263 train_time:466ms step_avg:46.63ms +step:200/20000 train_loss:2.7949 train_time:9633ms step_avg:48.16ms +step:200 shared0_alpha:mean=0.472,std=0.047 shared1_alpha:mean=0.483,std=0.043 shared2_alpha:mean=0.490,std=0.043 shared3_alpha:mean=0.509,std=0.043 eff_mlp_scale:[v0:40.6555 v1:28.7900 v2:28.5995 v3:29.9926 v4:32.8930 v5:32.6684 v6:30.7882 v7:32.7055 v8:36.7439 v9:59.1495] eff_attn_scale:[v0:14.9938 v1:10.1264 v2:10.8946 v3:10.9638 v4:11.1103 v5:10.4157 v6:10.7679 v7:11.3539 v8:11.1957 v9:15.7552] eff_attn_bias:[v0:0.1188 v1:0.1091 v2:0.1119 v3:0.1126 v4:0.1229 v5:0.1271 v6:0.1229 v7:0.1153 v8:0.1119 v9:0.1112] eff_mlp_bias:[v0:0.1022 v1:0.1036 v2:0.1050 v3:0.1036 v4:0.1153 v5:0.1098 v6:0.1063 v7:0.0974 v8:0.1119 v9:0.1892] unique_attn_gain_rms:[u0:0.8351 u1:0.8589 u2:0.8332 u3:0.8520 u4:0.8881 u5:0.8958 u6:0.8891 u7:0.8850] unique_mlp_gain_rms:[u0:1.0965 u1:1.0875 u2:1.1017 u3:1.0979 u4:1.1109 u5:1.1090 
u6:1.1194 u7:1.1381] depth_emb_rms:[v0:0.1108 v1:0.1033 v2:0.1050 v3:0.1060 v4:0.1045 v5:0.1169 v6:0.1116 v7:0.1084 v8:0.0991 v9:0.1142] +step:200/20000 val_loss:2.7822 val_bpb:1.6478 train_time:9685ms step_avg:48.42ms +step:400/20000 train_loss:2.3701 train_time:19343ms step_avg:48.36ms +step:400 shared0_alpha:mean=0.488,std=0.056 shared1_alpha:mean=0.500,std=0.053 shared2_alpha:mean=0.514,std=0.052 shared3_alpha:mean=0.530,std=0.049 eff_mlp_scale:[v0:49.7893 v1:35.9976 v2:38.5890 v3:40.3292 v4:40.7557 v5:44.8666 v6:41.1389 v7:41.0157 v8:41.6228 v9:75.0628] eff_attn_scale:[v0:6.6436 v1:6.6028 v2:6.9594 v3:7.1632 v4:7.1316 v5:7.0236 v6:6.7975 v7:6.7074 v8:6.4374 v9:9.4545] eff_attn_bias:[v0:0.1402 v1:0.1422 v2:0.1436 v3:0.1512 v4:0.1685 v5:0.1609 v6:0.1581 v7:0.1367 v8:0.1250 v9:0.1195] eff_mlp_bias:[v0:0.1492 v1:0.1277 v2:0.1257 v3:0.1208 v4:0.1409 v5:0.1367 v6:0.1326 v7:0.1181 v8:0.1312 v9:0.2514] unique_attn_gain_rms:[u0:0.7840 u1:0.8050 u2:0.7877 u3:0.7950 u4:0.8463 u5:0.8510 u6:0.8306 u7:0.8175] unique_mlp_gain_rms:[u0:1.1658 u1:1.1520 u2:1.1600 u3:1.1521 u4:1.1640 u5:1.1595 u6:1.1638 u7:1.1902] depth_emb_rms:[v0:0.1747 v1:0.1529 v2:0.1311 v3:0.1290 v4:0.1230 v5:0.1445 v6:0.1412 v7:0.1370 v8:0.1211 v9:0.1349] +step:400/20000 val_loss:2.5765 val_bpb:1.5260 train_time:19365ms step_avg:48.41ms +step:600/20000 train_loss:2.5806 train_time:29043ms step_avg:48.41ms +step:600 shared0_alpha:mean=0.497,std=0.063 shared1_alpha:mean=0.510,std=0.059 shared2_alpha:mean=0.530,std=0.055 shared3_alpha:mean=0.545,std=0.052 eff_mlp_scale:[v0:55.9643 v1:41.6480 v2:45.5967 v3:47.7033 v4:45.9031 v5:52.7667 v6:47.8030 v7:46.0520 v8:45.3566 v9:91.0404] eff_attn_scale:[v0:3.2498 v1:4.3918 v2:4.7713 v3:4.9939 v4:5.1825 v5:4.8488 v6:4.6182 v7:4.5048 v8:4.3988 v9:6.4050] eff_attn_bias:[v0:0.1664 v1:0.1851 v2:0.1823 v3:0.1906 v4:0.2113 v5:0.1975 v6:0.1947 v7:0.1602 v8:0.1395 v9:0.1443] eff_mlp_bias:[v0:0.2141 v1:0.1595 v2:0.1554 v3:0.1464 v4:0.1754 v5:0.1733 v6:0.1643 v7:0.1436 v8:0.1574 
v9:0.2707] unique_attn_gain_rms:[u0:0.7372 u1:0.7642 u2:0.7505 u3:0.7525 u4:0.8110 u5:0.8147 u6:0.7910 u7:0.7707] unique_mlp_gain_rms:[u0:1.2517 u1:1.2257 u2:1.2288 u3:1.2177 u4:1.2242 u5:1.2177 u6:1.2142 u7:1.2453] depth_emb_rms:[v0:0.2596 v1:0.2200 v2:0.1655 v3:0.1615 v4:0.1494 v5:0.1823 v6:0.1813 v7:0.1709 v8:0.1495 v9:0.1632] +step:600/20000 val_loss:2.4849 val_bpb:1.4717 train_time:29062ms step_avg:48.44ms +step:800/20000 train_loss:2.3424 train_time:38763ms step_avg:48.45ms +step:800 shared0_alpha:mean=0.503,std=0.069 shared1_alpha:mean=0.516,std=0.063 shared2_alpha:mean=0.541,std=0.057 shared3_alpha:mean=0.553,std=0.054 eff_mlp_scale:[v0:62.3674 v1:46.1882 v2:51.1317 v3:52.9606 v4:49.4821 v5:59.3279 v6:53.0685 v7:50.2742 v8:48.7266 v9:104.1231] eff_attn_scale:[v0:2.1113 v1:3.4616 v2:3.9465 v3:4.1283 v4:4.3100 v5:3.8778 v6:3.6804 v7:3.5258 v8:3.5546 v9:4.8749] eff_attn_bias:[v0:0.1837 v1:0.2237 v2:0.2168 v3:0.2279 v4:0.2514 v5:0.2375 v6:0.2279 v7:0.1851 v8:0.1568 v9:0.1795] eff_mlp_bias:[v0:0.2693 v1:0.1892 v2:0.1878 v3:0.1719 v4:0.2085 v5:0.2099 v6:0.1933 v7:0.1685 v8:0.1851 v9:0.2817] unique_attn_gain_rms:[u0:0.7073 u1:0.7441 u2:0.7282 u3:0.7235 u4:0.7819 u5:0.7842 u6:0.7589 u7:0.7346] unique_mlp_gain_rms:[u0:1.3320 u1:1.3017 u2:1.2900 u3:1.2814 u4:1.2846 u5:1.2779 u6:1.2662 u7:1.3038] depth_emb_rms:[v0:0.3300 v1:0.2791 v2:0.1980 v3:0.1959 v4:0.1765 v5:0.2191 v6:0.2203 v7:0.2031 v8:0.1772 v9:0.1925] +step:800/20000 val_loss:2.4287 val_bpb:1.4384 train_time:38782ms step_avg:48.48ms +step:1000/20000 train_loss:2.4233 train_time:48498ms step_avg:48.50ms +step:1000 shared0_alpha:mean=0.505,std=0.073 shared1_alpha:mean=0.519,std=0.065 shared2_alpha:mean=0.549,std=0.059 shared3_alpha:mean=0.558,std=0.056 eff_mlp_scale:[v0:67.6645 v1:50.3798 v2:56.0183 v3:57.6805 v4:53.1302 v5:64.6852 v6:56.8243 v7:53.7025 v8:52.3489 v9:114.9404] eff_attn_scale:[v0:1.6239 v1:3.1796 v2:3.6790 v3:3.7726 v4:4.0154 v5:3.5078 v6:3.2841 v7:3.1681 v8:3.2376 v9:4.0823] 
eff_attn_bias:[v0:0.1989 v1:0.2596 v2:0.2472 v3:0.2596 v4:0.2831 v5:0.2679 v6:0.2583 v7:0.2044 v8:0.1719 v9:0.2154] eff_mlp_bias:[v0:0.3176 v1:0.2141 v2:0.2182 v3:0.1920 v4:0.2348 v5:0.2403 v6:0.2182 v7:0.1892 v8:0.2085 v9:0.2900] unique_attn_gain_rms:[u0:0.6926 u1:0.7381 u2:0.7177 u3:0.7053 u4:0.7678 u5:0.7612 u6:0.7365 u7:0.7061] unique_mlp_gain_rms:[u0:1.4107 u1:1.3727 u2:1.3516 u3:1.3438 u4:1.3468 u5:1.3384 u6:1.3218 u7:1.3647] depth_emb_rms:[v0:0.3876 v1:0.3306 v2:0.2253 v3:0.2274 v4:0.1993 v5:0.2488 v6:0.2547 v7:0.2307 v8:0.1992 v9:0.2184] +step:1000/20000 val_loss:2.3884 val_bpb:1.4146 train_time:48518ms step_avg:48.52ms +step:1200/20000 train_loss:2.4401 train_time:58251ms step_avg:48.54ms +step:1200 shared0_alpha:mean=0.506,std=0.078 shared1_alpha:mean=0.522,std=0.067 shared2_alpha:mean=0.555,std=0.060 shared3_alpha:mean=0.561,std=0.058 eff_mlp_scale:[v0:73.0825 v1:53.6898 v2:60.2730 v3:61.3909 v4:55.8912 v5:68.8769 v6:60.6887 v7:56.8889 v8:55.4891 v9:124.6020] eff_attn_scale:[v0:1.3227 v1:3.0337 v2:3.7061 v3:3.5933 v4:3.8835 v5:3.3510 v6:3.1942 v7:2.9146 v8:3.0864 v9:3.5750] eff_attn_bias:[v0:0.2113 v1:0.2886 v2:0.2748 v3:0.2831 v4:0.3080 v5:0.2969 v6:0.2859 v7:0.2210 v8:0.1878 v9:0.2500] eff_mlp_bias:[v0:0.3536 v1:0.2362 v2:0.2458 v3:0.2099 v4:0.2569 v5:0.2679 v6:0.2417 v7:0.2044 v8:0.2293 v9:0.2983] unique_attn_gain_rms:[u0:0.6829 u1:0.7375 u2:0.7120 u3:0.6939 u4:0.7578 u5:0.7443 u6:0.7150 u7:0.6819] unique_mlp_gain_rms:[u0:1.4868 u1:1.4417 u2:1.4099 u3:1.4064 u4:1.4091 u5:1.3987 u6:1.3748 u7:1.4260] depth_emb_rms:[v0:0.4390 v1:0.3692 v2:0.2513 v3:0.2583 v4:0.2196 v5:0.2743 v6:0.2848 v7:0.2573 v8:0.2168 v9:0.2425] +step:1200/20000 val_loss:2.3585 val_bpb:1.3968 train_time:58269ms step_avg:48.56ms +step:1400/20000 train_loss:2.4847 train_time:68001ms step_avg:48.57ms +step:1400 shared0_alpha:mean=0.506,std=0.080 shared1_alpha:mean=0.524,std=0.069 shared2_alpha:mean=0.560,std=0.062 shared3_alpha:mean=0.564,std=0.060 eff_mlp_scale:[v0:77.7293 v1:56.6593 
v2:63.4124 v3:64.2843 v4:58.4440 v5:72.9104 v6:63.4124 v7:60.0827 v8:58.4440 v9:133.4049] eff_attn_scale:[v0:1.1481 v1:3.0694 v2:4.0140 v3:3.6986 v4:3.9936 v5:3.3883 v6:3.2195 v7:2.8989 v8:3.1084 v9:3.2964] eff_attn_bias:[v0:0.2279 v1:0.3163 v2:0.3011 v3:0.3052 v4:0.3356 v5:0.3232 v6:0.3121 v7:0.2348 v8:0.2016 v9:0.2831] eff_mlp_bias:[v0:0.3867 v1:0.2569 v2:0.2721 v3:0.2279 v4:0.2762 v5:0.2914 v6:0.2652 v7:0.2168 v8:0.2500 v9:0.3094] unique_attn_gain_rms:[u0:0.6821 u1:0.7515 u2:0.7152 u3:0.6879 u4:0.7480 u5:0.7269 u6:0.6968 u7:0.6641] unique_mlp_gain_rms:[u0:1.5577 u1:1.5094 u2:1.4660 u3:1.4670 u4:1.4703 u5:1.4557 u6:1.4269 u7:1.4835] depth_emb_rms:[v0:0.4889 v1:0.4055 v2:0.2738 v3:0.2873 v4:0.2382 v5:0.2958 v6:0.3112 v7:0.2834 v8:0.2312 v9:0.2649] +step:1400/20000 val_loss:2.3375 val_bpb:1.3844 train_time:68020ms step_avg:48.59ms +step:1600/20000 train_loss:2.1526 train_time:77747ms step_avg:48.59ms +step:1600 shared0_alpha:mean=0.506,std=0.082 shared1_alpha:mean=0.524,std=0.071 shared2_alpha:mean=0.565,std=0.063 shared3_alpha:mean=0.566,std=0.061 eff_mlp_scale:[v0:81.8085 v1:59.2769 v2:66.3715 v3:67.1901 v4:60.4867 v5:76.3414 v6:65.9377 v7:62.4825 v8:61.3268 v9:139.9791] eff_attn_scale:[v0:1.0341 v1:3.0735 v2:4.5054 v3:4.0216 v4:4.2100 v5:3.5823 v6:3.4674 v7:3.0058 v8:3.1996 v9:3.0448] eff_attn_bias:[v0:0.2389 v1:0.3439 v2:0.3218 v3:0.3204 v4:0.3536 v5:0.3439 v6:0.3342 v7:0.2458 v8:0.2154 v9:0.3135] eff_mlp_bias:[v0:0.4088 v1:0.2721 v2:0.2983 v3:0.2389 v4:0.2900 v5:0.3107 v6:0.2873 v7:0.2279 v8:0.2707 v9:0.3190] unique_attn_gain_rms:[u0:0.6828 u1:0.7674 u2:0.7285 u3:0.6891 u4:0.7484 u5:0.7198 u6:0.6819 u7:0.6469] unique_mlp_gain_rms:[u0:1.6329 u1:1.5699 u2:1.5180 u3:1.5265 u4:1.5306 u5:1.5147 u6:1.4793 u7:1.5465] depth_emb_rms:[v0:0.5321 v1:0.4388 v2:0.2936 v3:0.3160 v4:0.2529 v5:0.3119 v6:0.3335 v7:0.3076 v8:0.2446 v9:0.2866] +step:1600/20000 val_loss:2.3229 val_bpb:1.3758 train_time:77767ms step_avg:48.60ms +step:1800/20000 train_loss:2.2607 train_time:87494ms 
step_avg:48.61ms +step:1800 shared0_alpha:mean=0.506,std=0.085 shared1_alpha:mean=0.523,std=0.072 shared2_alpha:mean=0.570,std=0.064 shared3_alpha:mean=0.568,std=0.062 eff_mlp_scale:[v0:85.8113 v1:61.1269 v2:68.2712 v3:69.5868 v4:62.9207 v5:78.9176 v6:68.2712 v7:64.3678 v8:64.2048 v9:147.2664] eff_attn_scale:[v0:0.9300 v1:3.1121 v2:5.1135 v3:4.4506 v4:4.6175 v5:3.8902 v6:3.8881 v7:3.2725 v8:3.4686 v9:2.8574] eff_attn_bias:[v0:0.2569 v1:0.3674 v2:0.3439 v3:0.3356 v4:0.3729 v5:0.3646 v6:0.3591 v7:0.2583 v8:0.2306 v9:0.3466] eff_mlp_bias:[v0:0.4337 v1:0.2859 v2:0.3218 v3:0.2500 v4:0.3011 v5:0.3287 v6:0.3080 v7:0.2375 v8:0.2886 v9:0.3315] unique_attn_gain_rms:[u0:0.6962 u1:0.7913 u2:0.7460 u3:0.6952 u4:0.7552 u5:0.7178 u6:0.6764 u7:0.6373] unique_mlp_gain_rms:[u0:1.7035 u1:1.6324 u2:1.5689 u3:1.5823 u4:1.5887 u5:1.5695 u6:1.5311 u7:1.6039] depth_emb_rms:[v0:0.5786 v1:0.4713 v2:0.3103 v3:0.3422 v4:0.2669 v5:0.3267 v6:0.3553 v7:0.3311 v8:0.2571 v9:0.3079] +step:1800/20000 val_loss:2.3075 val_bpb:1.3666 train_time:87513ms step_avg:48.62ms +step:2000/20000 train_loss:2.3073 train_time:97238ms step_avg:48.62ms +step:2000 shared0_alpha:mean=0.504,std=0.087 shared1_alpha:mean=0.522,std=0.073 shared2_alpha:mean=0.574,std=0.065 shared3_alpha:mean=0.570,std=0.064 eff_mlp_scale:[v0:89.1827 v1:63.5627 v2:70.5116 v3:71.9880 v4:64.2242 v5:82.1211 v6:70.0653 v7:66.6882 v8:66.3940 v9:153.4650] eff_attn_scale:[v0:0.8444 v1:3.2641 v2:5.7345 v3:4.9677 v4:5.1339 v5:4.3375 v6:4.3822 v7:3.6909 v8:3.8914 v9:2.7098] eff_attn_bias:[v0:0.2762 v1:0.3950 v2:0.3674 v3:0.3466 v4:0.3922 v5:0.3812 v6:0.3784 v7:0.2665 v8:0.2458 v9:0.3757] eff_mlp_bias:[v0:0.4585 v1:0.2983 v2:0.3453 v3:0.2596 v4:0.3094 v5:0.3439 v6:0.3218 v7:0.2444 v8:0.3025 v9:0.3453] unique_attn_gain_rms:[u0:0.7203 u1:0.8219 u2:0.7660 u3:0.7045 u4:0.7731 u5:0.7250 u6:0.6763 u7:0.6345] unique_mlp_gain_rms:[u0:1.7764 u1:1.6954 u2:1.6238 u3:1.6443 u4:1.6519 u5:1.6274 u6:1.5840 u7:1.6624] depth_emb_rms:[v0:0.6278 v1:0.5055 v2:0.3269 
v3:0.3684 v4:0.2779 v5:0.3389 v6:0.3736 v7:0.3492 v8:0.2670 v9:0.3259] +step:2000/20000 val_loss:2.2932 val_bpb:1.3582 train_time:97256ms step_avg:48.63ms +step:2200/20000 train_loss:2.1337 train_time:106972ms step_avg:48.62ms +step:2200 shared0_alpha:mean=0.504,std=0.088 shared1_alpha:mean=0.522,std=0.073 shared2_alpha:mean=0.578,std=0.066 shared3_alpha:mean=0.572,std=0.065 eff_mlp_scale:[v0:92.2659 v1:65.3028 v2:72.1104 v3:73.8636 v4:66.1712 v5:84.5648 v6:71.6597 v7:68.4917 v8:68.3769 v9:159.9031] eff_attn_scale:[v0:0.7969 v1:3.5001 v2:6.4877 v3:5.5247 v4:5.8189 v5:4.7896 v6:4.9865 v7:4.1435 v8:4.4144 v9:2.5670] eff_attn_bias:[v0:0.2955 v1:0.4116 v2:0.3839 v3:0.3563 v4:0.4060 v5:0.3977 v6:0.3950 v7:0.2776 v8:0.2596 v9:0.4033] eff_mlp_bias:[v0:0.4751 v1:0.3094 v2:0.3618 v3:0.2665 v4:0.3190 v5:0.3563 v6:0.3397 v7:0.2527 v8:0.3190 v9:0.3618] unique_attn_gain_rms:[u0:0.7480 u1:0.8500 u2:0.7922 u3:0.7182 u4:0.7942 u5:0.7336 u6:0.6833 u7:0.6363] unique_mlp_gain_rms:[u0:1.8424 u1:1.7562 u2:1.6765 u3:1.7036 u4:1.7080 u5:1.6813 u6:1.6364 u7:1.7192] depth_emb_rms:[v0:0.6713 v1:0.5351 v2:0.3402 v3:0.3902 v4:0.2885 v5:0.3516 v6:0.3885 v7:0.3691 v8:0.2771 v9:0.3437] +step:2200/20000 val_loss:2.2855 val_bpb:1.3536 train_time:106994ms step_avg:48.63ms +step:2400/20000 train_loss:2.2615 train_time:116707ms step_avg:48.63ms +step:2400 shared0_alpha:mean=0.503,std=0.089 shared1_alpha:mean=0.520,std=0.074 shared2_alpha:mean=0.580,std=0.067 shared3_alpha:mean=0.573,std=0.066 eff_mlp_scale:[v0:95.6171 v1:67.0737 v2:73.8010 v3:75.4959 v4:67.8739 v5:86.5774 v6:72.8899 v7:70.0710 v8:70.5531 v9:166.2090] eff_attn_scale:[v0:0.7542 v1:3.8024 v2:7.1108 v3:6.0789 v4:6.3597 v5:5.3040 v6:5.5024 v7:4.5982 v8:4.8555 v9:2.4572] eff_attn_bias:[v0:0.3094 v1:0.4309 v2:0.4033 v3:0.3674 v4:0.4198 v5:0.4143 v6:0.4116 v7:0.2859 v8:0.2721 v9:0.4309] eff_mlp_bias:[v0:0.4917 v1:0.3135 v2:0.3812 v3:0.2762 v4:0.3246 v5:0.3701 v6:0.3536 v7:0.2583 v8:0.3342 v9:0.3784] unique_attn_gain_rms:[u0:0.7838 u1:0.8750 
u2:0.8174 u3:0.7327 u4:0.8160 u5:0.7458 u6:0.6888 u7:0.6447] unique_mlp_gain_rms:[u0:1.9095 u1:1.8134 u2:1.7255 u3:1.7605 u4:1.7640 u5:1.7353 u6:1.6847 u7:1.7774] depth_emb_rms:[v0:0.7178 v1:0.5672 v2:0.3501 v3:0.4117 v4:0.2992 v5:0.3625 v6:0.4056 v7:0.3855 v8:0.2862 v9:0.3617] +step:2400/20000 val_loss:2.2758 val_bpb:1.3479 train_time:116727ms step_avg:48.64ms +step:2600/20000 train_loss:2.4696 train_time:126437ms step_avg:48.63ms +step:2600 shared0_alpha:mean=0.502,std=0.091 shared1_alpha:mean=0.520,std=0.075 shared2_alpha:mean=0.583,std=0.068 shared3_alpha:mean=0.574,std=0.067 eff_mlp_scale:[v0:99.0970 v1:69.3206 v2:75.9790 v3:77.1177 v4:69.8343 v5:89.0577 v6:74.5975 v7:71.6419 v8:73.0086 v9:170.6139] eff_attn_scale:[v0:0.7391 v1:4.0757 v2:7.7027 v3:6.4745 v4:6.9427 v5:5.7614 v6:5.9976 v7:4.9907 v8:5.3384 v9:2.3774] eff_attn_bias:[v0:0.3232 v1:0.4502 v2:0.4198 v3:0.3784 v4:0.4364 v5:0.4281 v6:0.4281 v7:0.2928 v8:0.2845 v9:0.4558] eff_mlp_bias:[v0:0.5082 v1:0.3246 v2:0.3977 v3:0.2845 v4:0.3328 v5:0.3812 v6:0.3646 v7:0.2638 v8:0.3494 v9:0.3922] unique_attn_gain_rms:[u0:0.8201 u1:0.9021 u2:0.8416 u3:0.7505 u4:0.8374 u5:0.7591 u6:0.6981 u7:0.6530] unique_mlp_gain_rms:[u0:1.9754 u1:1.8705 u2:1.7783 u3:1.8148 u4:1.8219 u5:1.7886 u6:1.7341 u7:1.8343] depth_emb_rms:[v0:0.7706 v1:0.5974 v2:0.3659 v3:0.4335 v4:0.3106 v5:0.3739 v6:0.4192 v7:0.4005 v8:0.2954 v9:0.3796] +step:2600/20000 val_loss:2.2836 val_bpb:1.3525 train_time:126455ms step_avg:48.64ms +step:2800/20000 train_loss:2.2907 train_time:136163ms step_avg:48.63ms +step:2800 shared0_alpha:mean=0.501,std=0.091 shared1_alpha:mean=0.519,std=0.074 shared2_alpha:mean=0.584,std=0.068 shared3_alpha:mean=0.575,std=0.068 eff_mlp_scale:[v0:102.3190 v1:70.8252 v2:77.2924 v3:78.7641 v4:71.5285 v5:90.7145 v6:75.4411 v7:72.7762 v8:74.7381 v9:175.6511] eff_attn_scale:[v0:0.6962 v1:4.3764 v2:8.2651 v3:7.0220 v4:7.5496 v5:6.2632 v6:6.5202 v7:5.5213 v8:5.8655 v9:2.2464] eff_attn_bias:[v0:0.3425 v1:0.4696 v2:0.4309 v3:0.3867 
v4:0.4475 v5:0.4419 v6:0.4419 v7:0.3011 v8:0.2955 v9:0.4806] eff_mlp_bias:[v0:0.5220 v1:0.3287 v2:0.4143 v3:0.2914 v4:0.3384 v5:0.3895 v6:0.3812 v7:0.2693 v8:0.3646 v9:0.4088] unique_attn_gain_rms:[u0:0.8624 u1:0.9276 u2:0.8638 u3:0.7667 u4:0.8656 u5:0.7732 u6:0.7120 u7:0.6646] unique_mlp_gain_rms:[u0:2.0357 u1:1.9338 u2:1.8288 u3:1.8750 u4:1.8798 u5:1.8454 u6:1.7874 u7:1.8915] depth_emb_rms:[v0:0.8186 v1:0.6299 v2:0.3761 v3:0.4521 v4:0.3191 v5:0.3830 v6:0.4327 v7:0.4179 v8:0.3033 v9:0.3985] +step:2800/20000 val_loss:2.2600 val_bpb:1.3385 train_time:136181ms step_avg:48.64ms +step:3000/20000 train_loss:2.2851 train_time:145881ms step_avg:48.63ms +step:3000 shared0_alpha:mean=0.500,std=0.092 shared1_alpha:mean=0.518,std=0.075 shared2_alpha:mean=0.587,std=0.069 shared3_alpha:mean=0.577,std=0.069 eff_mlp_scale:[v0:105.1503 v1:72.5116 v2:78.9522 v3:80.9580 v4:72.9109 v5:92.5992 v6:77.0835 v7:74.4441 v8:76.6261 v9:180.8038] eff_attn_scale:[v0:0.6785 v1:4.4623 v2:8.7825 v3:7.4196 v4:8.0247 v5:6.6266 v6:6.9502 v7:5.8714 v8:6.2281 v9:2.1613] eff_attn_bias:[v0:0.3591 v1:0.4861 v2:0.4475 v3:0.3950 v4:0.4585 v5:0.4585 v6:0.4530 v7:0.3094 v8:0.3066 v9:0.5027] eff_mlp_bias:[v0:0.5386 v1:0.3356 v2:0.4254 v3:0.2983 v4:0.3453 v5:0.3977 v6:0.3895 v7:0.2748 v8:0.3784 v9:0.4254] unique_attn_gain_rms:[u0:0.8906 u1:0.9499 u2:0.8876 u3:0.7804 u4:0.8890 u5:0.7878 u6:0.7223 u7:0.6783] unique_mlp_gain_rms:[u0:2.1016 u1:1.9865 u2:1.8786 u3:1.9287 u4:1.9332 u5:1.8992 u6:1.8380 u7:1.9472] depth_emb_rms:[v0:0.8685 v1:0.6625 v2:0.3868 v3:0.4689 v4:0.3285 v5:0.3923 v6:0.4443 v7:0.4306 v8:0.3105 v9:0.4147] +step:3000/20000 val_loss:2.2542 val_bpb:1.3351 train_time:145901ms step_avg:48.63ms +step:3200/20000 train_loss:2.2487 train_time:155597ms step_avg:48.62ms +step:3200 shared0_alpha:mean=0.498,std=0.093 shared1_alpha:mean=0.517,std=0.076 shared2_alpha:mean=0.589,std=0.070 shared3_alpha:mean=0.578,std=0.069 eff_mlp_scale:[v0:107.8831 v1:73.6591 v2:80.4970 v3:82.6649 v4:74.3805 v5:94.4220 
v6:78.6140 v7:76.0893 v8:78.6174 v9:184.9304] eff_attn_scale:[v0:0.6632 v1:4.4757 v2:9.0789 v3:7.7735 v4:8.3808 v5:6.8483 v6:7.2760 v7:6.1889 v8:6.6245 v9:2.1006] eff_attn_bias:[v0:0.3812 v1:0.4999 v2:0.4613 v3:0.4060 v4:0.4723 v5:0.4696 v6:0.4668 v7:0.3163 v8:0.3163 v9:0.5276] eff_mlp_bias:[v0:0.5497 v1:0.3397 v2:0.4392 v3:0.3038 v4:0.3508 v5:0.4060 v6:0.4005 v7:0.2804 v8:0.3895 v9:0.4419] unique_attn_gain_rms:[u0:0.9182 u1:0.9735 u2:0.9079 u3:0.7999 u4:0.9136 u5:0.8006 u6:0.7339 u7:0.6934] unique_mlp_gain_rms:[u0:2.1645 u1:2.0444 u2:1.9277 u3:1.9819 u4:1.9869 u5:1.9490 u6:1.8872 u7:2.0004] depth_emb_rms:[v0:0.9214 v1:0.6932 v2:0.3956 v3:0.4841 v4:0.3366 v5:0.4011 v6:0.4555 v7:0.4426 v8:0.3181 v9:0.4298] +step:3200/20000 val_loss:2.2483 val_bpb:1.3316 train_time:155615ms step_avg:48.63ms +step:3400/20000 train_loss:2.2170 train_time:165305ms step_avg:48.62ms +step:3400 shared0_alpha:mean=0.497,std=0.094 shared1_alpha:mean=0.516,std=0.076 shared2_alpha:mean=0.591,std=0.070 shared3_alpha:mean=0.578,std=0.070 eff_mlp_scale:[v0:110.7713 v1:75.7765 v2:81.4650 v3:84.3716 v4:76.2595 v5:96.2162 v6:79.5705 v7:77.2616 v8:80.5491 v9:189.8891] eff_attn_scale:[v0:0.6482 v1:4.6367 v2:9.4709 v3:8.1745 v4:8.8643 v5:7.1883 v6:7.6622 v7:6.6072 v8:6.9648 v9:2.0507] eff_attn_bias:[v0:0.4033 v1:0.5193 v2:0.4751 v3:0.4171 v4:0.4834 v5:0.4861 v6:0.4806 v7:0.3246 v8:0.3259 v9:0.5497] eff_mlp_bias:[v0:0.5662 v1:0.3453 v2:0.4502 v3:0.3094 v4:0.3563 v5:0.4116 v6:0.4116 v7:0.2859 v8:0.4033 v9:0.4585] unique_attn_gain_rms:[u0:0.9440 u1:0.9960 u2:0.9305 u3:0.8124 u4:0.9408 u5:0.8159 u6:0.7504 u7:0.7099] unique_mlp_gain_rms:[u0:2.2272 u1:2.1000 u2:1.9789 u3:2.0370 u4:2.0420 u5:2.0011 u6:1.9378 u7:2.0569] depth_emb_rms:[v0:0.9751 v1:0.7289 v2:0.4056 v3:0.5021 v4:0.3446 v5:0.4097 v6:0.4659 v7:0.4560 v8:0.3255 v9:0.4452] +step:3400/20000 val_loss:2.2467 val_bpb:1.3306 train_time:165323ms step_avg:48.62ms +step:3600/20000 train_loss:2.1846 train_time:175011ms step_avg:48.61ms +step:3600 
shared0_alpha:mean=0.496,std=0.095 shared1_alpha:mean=0.515,std=0.076 shared2_alpha:mean=0.593,std=0.071 shared3_alpha:mean=0.579,std=0.071 eff_mlp_scale:[v0:113.5676 v1:76.3604 v2:83.0389 v3:85.4906 v4:77.5941 v5:97.9623 v6:80.6528 v7:78.8042 v8:82.4136 v9:194.0409] eff_attn_scale:[v0:0.6422 v1:4.7889 v2:9.6982 v3:8.5331 v4:9.3037 v5:7.5615 v6:7.9592 v7:6.8704 v8:7.2868 v9:2.0110] eff_attn_bias:[v0:0.4226 v1:0.5331 v2:0.4861 v3:0.4254 v4:0.4944 v5:0.4972 v6:0.4944 v7:0.3315 v8:0.3342 v9:0.5690] eff_mlp_bias:[v0:0.5745 v1:0.3522 v2:0.4640 v3:0.3149 v4:0.3646 v5:0.4226 v6:0.4198 v7:0.2886 v8:0.4143 v9:0.4778] unique_attn_gain_rms:[u0:0.9686 u1:1.0176 u2:0.9511 u3:0.8295 u4:0.9637 u5:0.8284 u6:0.7637 u7:0.7255] unique_mlp_gain_rms:[u0:2.2849 u1:2.1547 u2:2.0257 u3:2.0939 u4:2.0968 u5:2.0519 u6:1.9835 u7:2.1101] depth_emb_rms:[v0:1.0286 v1:0.7656 v2:0.4164 v3:0.5176 v4:0.3537 v5:0.4197 v6:0.4769 v7:0.4679 v8:0.3313 v9:0.4608] +step:3600/20000 val_loss:2.2401 val_bpb:1.3267 train_time:175031ms step_avg:48.62ms +step:3800/20000 train_loss:2.2837 train_time:184720ms step_avg:48.61ms +step:3800 shared0_alpha:mean=0.495,std=0.095 shared1_alpha:mean=0.514,std=0.077 shared2_alpha:mean=0.595,std=0.072 shared3_alpha:mean=0.580,std=0.072 eff_mlp_scale:[v0:116.6345 v1:78.5741 v2:84.5674 v3:87.1908 v4:79.1364 v5:99.8651 v6:81.6844 v7:79.9651 v8:84.5099 v9:198.0203] eff_attn_scale:[v0:0.6359 v1:4.8469 v2:10.0927 v3:8.7294 v4:9.6024 v5:7.8230 v6:8.3197 v7:7.1364 v8:7.5826 v9:1.9771] eff_attn_bias:[v0:0.4530 v1:0.5497 v2:0.4944 v3:0.4337 v4:0.5027 v5:0.5082 v6:0.5055 v7:0.3384 v8:0.3411 v9:0.5911] eff_mlp_bias:[v0:0.5828 v1:0.3591 v2:0.4723 v3:0.3218 v4:0.3701 v5:0.4281 v6:0.4281 v7:0.2942 v8:0.4281 v9:0.4917] unique_attn_gain_rms:[u0:0.9972 u1:1.0399 u2:0.9678 u3:0.8414 u4:0.9875 u5:0.8428 u6:0.7766 u7:0.7382] unique_mlp_gain_rms:[u0:2.3435 u1:2.2058 u2:2.0694 u3:2.1460 u4:2.1488 u5:2.1020 u6:2.0312 u7:2.1639] depth_emb_rms:[v0:1.0868 v1:0.8023 v2:0.4272 v3:0.5313 v4:0.3615 
v5:0.4289 v6:0.4869 v7:0.4811 v8:0.3380 v9:0.4768] +step:3800/20000 val_loss:2.2360 val_bpb:1.3243 train_time:184740ms step_avg:48.62ms +step:4000/20000 train_loss:2.2226 train_time:194423ms step_avg:48.61ms +step:4000 shared0_alpha:mean=0.493,std=0.096 shared1_alpha:mean=0.514,std=0.077 shared2_alpha:mean=0.597,std=0.072 shared3_alpha:mean=0.581,std=0.073 eff_mlp_scale:[v0:118.8960 v1:79.7477 v2:85.8312 v3:88.9888 v4:81.1315 v5:101.7294 v6:83.4066 v7:81.6946 v8:86.5732 v9:202.1544] eff_attn_scale:[v0:0.6315 v1:4.8965 v2:10.3469 v3:8.9839 v4:9.8916 v5:8.0749 v6:8.5880 v7:7.4004 v8:7.8393 v9:1.9145] eff_attn_bias:[v0:0.4751 v1:0.5607 v2:0.5055 v3:0.4419 v4:0.5138 v5:0.5248 v6:0.5165 v7:0.3466 v8:0.3480 v9:0.6132] eff_mlp_bias:[v0:0.5911 v1:0.3618 v2:0.4806 v3:0.3273 v4:0.3757 v5:0.4309 v6:0.4392 v7:0.3011 v8:0.4392 v9:0.5055] unique_attn_gain_rms:[u0:1.0195 u1:1.0554 u2:0.9898 u3:0.8545 u4:1.0102 u5:0.8581 u6:0.7888 u7:0.7538] unique_mlp_gain_rms:[u0:2.3998 u1:2.2614 u2:2.1152 u3:2.1976 u4:2.1990 u5:2.1497 u6:2.0795 u7:2.2146] depth_emb_rms:[v0:1.1432 v1:0.8398 v2:0.4363 v3:0.5446 v4:0.3695 v5:0.4371 v6:0.4967 v7:0.4943 v8:0.3458 v9:0.4925] +step:4000/20000 val_loss:2.2303 val_bpb:1.3209 train_time:194443ms step_avg:48.61ms +step:4200/20000 train_loss:2.2354 train_time:204194ms step_avg:48.62ms +step:4200 shared0_alpha:mean=0.492,std=0.096 shared1_alpha:mean=0.513,std=0.077 shared2_alpha:mean=0.599,std=0.073 shared3_alpha:mean=0.582,std=0.075 eff_mlp_scale:[v0:121.7812 v1:80.7171 v2:87.2617 v3:90.3511 v4:82.5943 v5:103.3385 v6:84.3368 v7:83.4765 v8:88.6012 v9:206.8924] eff_attn_scale:[v0:0.6261 v1:5.0086 v2:10.6083 v3:9.2195 v4:10.1915 v5:8.1749 v6:8.7937 v7:7.6175 v8:8.0711 v9:1.8665] eff_attn_bias:[v0:0.4999 v1:0.5773 v2:0.5165 v3:0.4502 v4:0.5220 v5:0.5359 v6:0.5248 v7:0.3522 v8:0.3563 v9:0.6325] eff_mlp_bias:[v0:0.5966 v1:0.3701 v2:0.4889 v3:0.3328 v4:0.3839 v5:0.4392 v6:0.4447 v7:0.3038 v8:0.4502 v9:0.5220] unique_attn_gain_rms:[u0:1.0428 u1:1.0777 u2:1.0060 
u3:0.8708 u4:1.0280 u5:0.8702 u6:0.8026 u7:0.7695] unique_mlp_gain_rms:[u0:2.4577 u1:2.3096 u2:2.1622 u3:2.2486 u4:2.2506 u5:2.1989 u6:2.1253 u7:2.2676] depth_emb_rms:[v0:1.2031 v1:0.8785 v2:0.4475 v3:0.5581 v4:0.3780 v5:0.4461 v6:0.5063 v7:0.5039 v8:0.3498 v9:0.5064] +step:4200/20000 val_loss:2.2263 val_bpb:1.3185 train_time:204211ms step_avg:48.62ms +step:4400/20000 train_loss:2.1719 train_time:213990ms step_avg:48.63ms +step:4400 shared0_alpha:mean=0.491,std=0.097 shared1_alpha:mean=0.512,std=0.077 shared2_alpha:mean=0.600,std=0.074 shared3_alpha:mean=0.582,std=0.075 eff_mlp_scale:[v0:124.6550 v1:82.4633 v2:88.2715 v3:92.7288 v4:84.4679 v5:105.2834 v6:85.8195 v7:84.7948 v8:90.5375 v9:210.8144] eff_attn_scale:[v0:0.6124 v1:5.0562 v2:10.8728 v3:9.3987 v4:10.5625 v5:8.4270 v6:9.0371 v7:7.8101 v8:8.4290 v9:1.8599] eff_attn_bias:[v0:0.5303 v1:0.5911 v2:0.5248 v3:0.4585 v4:0.5331 v5:0.5469 v6:0.5359 v7:0.3591 v8:0.3646 v9:0.6491] eff_mlp_bias:[v0:0.6077 v1:0.3729 v2:0.4972 v3:0.3397 v4:0.3839 v5:0.4447 v6:0.4502 v7:0.3080 v8:0.4613 v9:0.5359] unique_attn_gain_rms:[u0:1.0665 u1:1.0936 u2:1.0235 u3:0.8814 u4:1.0506 u5:0.8843 u6:0.8129 u7:0.7863] unique_mlp_gain_rms:[u0:2.5119 u1:2.3632 u2:2.2078 u3:2.2989 u4:2.3033 u5:2.2477 u6:2.1737 u7:2.3200] depth_emb_rms:[v0:1.2648 v1:0.9164 v2:0.4550 v3:0.5712 v4:0.3866 v5:0.4518 v6:0.5166 v7:0.5145 v8:0.3575 v9:0.5211] +step:4400/20000 val_loss:2.2267 val_bpb:1.3188 train_time:214020ms step_avg:48.64ms +step:4600/20000 train_loss:2.0387 train_time:223703ms step_avg:48.63ms +step:4600 shared0_alpha:mean=0.490,std=0.097 shared1_alpha:mean=0.511,std=0.078 shared2_alpha:mean=0.602,std=0.074 shared3_alpha:mean=0.582,std=0.076 eff_mlp_scale:[v0:127.9106 v1:84.0413 v2:89.8026 v3:93.8920 v4:85.9572 v5:107.0091 v6:86.8421 v7:85.9011 v8:92.6086 v9:213.8765] eff_attn_scale:[v0:0.6265 v1:5.1240 v2:10.9386 v3:9.6818 v4:10.7573 v5:8.6084 v6:9.2339 v7:8.1017 v8:8.5988 v9:1.8293] eff_attn_bias:[v0:0.5580 v1:0.6077 v2:0.5331 v3:0.4668 v4:0.5414 
v5:0.5607 v6:0.5441 v7:0.3674 v8:0.3729 v9:0.6684] eff_mlp_bias:[v0:0.6104 v1:0.3784 v2:0.5082 v3:0.3439 v4:0.3895 v5:0.4502 v6:0.4558 v7:0.3135 v8:0.4696 v9:0.5497] unique_attn_gain_rms:[u0:1.0917 u1:1.1091 u2:1.0450 u3:0.8917 u4:1.0690 u5:0.8986 u6:0.8282 u7:0.8015] unique_mlp_gain_rms:[u0:2.5686 u1:2.4128 u2:2.2518 u3:2.3526 u4:2.3554 u5:2.2981 u6:2.2192 u7:2.3709] depth_emb_rms:[v0:1.3252 v1:0.9544 v2:0.4616 v3:0.5846 v4:0.3934 v5:0.4608 v6:0.5256 v7:0.5225 v8:0.3644 v9:0.5360] +step:4600/20000 val_loss:2.2224 val_bpb:1.3162 train_time:223721ms step_avg:48.64ms +step:4800/20000 train_loss:2.3251 train_time:233402ms step_avg:48.63ms +step:4800 shared0_alpha:mean=0.489,std=0.097 shared1_alpha:mean=0.511,std=0.078 shared2_alpha:mean=0.604,std=0.075 shared3_alpha:mean=0.583,std=0.077 eff_mlp_scale:[v0:130.2163 v1:85.2024 v2:90.9613 v3:95.7443 v4:87.6040 v5:108.8697 v6:87.9789 v7:87.1777 v8:94.8611 v9:218.6697] eff_attn_scale:[v0:0.6116 v1:5.1827 v2:11.2136 v3:9.7725 v4:11.1914 v5:8.8847 v6:9.4165 v7:8.2795 v8:8.9170 v9:1.7932] eff_attn_bias:[v0:0.5911 v1:0.6187 v2:0.5469 v3:0.4778 v4:0.5497 v5:0.5718 v6:0.5552 v7:0.3757 v8:0.3812 v9:0.6822] eff_mlp_bias:[v0:0.6187 v1:0.3839 v2:0.5165 v3:0.3508 v4:0.3950 v5:0.4558 v6:0.4613 v7:0.3190 v8:0.4778 v9:0.5635] unique_attn_gain_rms:[u0:1.1143 u1:1.1234 u2:1.0627 u3:0.9081 u4:1.0885 u5:0.9138 u6:0.8411 u7:0.8176] unique_mlp_gain_rms:[u0:2.6222 u1:2.4650 u2:2.2965 u3:2.4040 u4:2.4002 u5:2.3450 u6:2.2632 u7:2.4238] depth_emb_rms:[v0:1.3804 v1:0.9953 v2:0.4723 v3:0.5976 v4:0.4023 v5:0.4674 v6:0.5351 v7:0.5319 v8:0.3713 v9:0.5494] +step:4800/20000 val_loss:2.2192 val_bpb:1.3143 train_time:233422ms step_avg:48.63ms +step:5000/20000 train_loss:2.0931 train_time:243112ms step_avg:48.62ms +step:5000 shared0_alpha:mean=0.489,std=0.098 shared1_alpha:mean=0.510,std=0.078 shared2_alpha:mean=0.605,std=0.075 shared3_alpha:mean=0.583,std=0.078 eff_mlp_scale:[v0:133.3678 v1:86.1280 v2:92.5903 v3:97.0831 v4:89.0071 v5:109.9056 v6:89.0869 
v7:88.9505 v8:96.3371 v9:220.4123] eff_attn_scale:[v0:0.6159 v1:5.2205 v2:11.3658 v3:9.9640 v4:11.4157 v5:9.0687 v6:9.5560 v7:8.4866 v8:9.1838 v9:1.7439] eff_attn_bias:[v0:0.6215 v1:0.6325 v2:0.5524 v3:0.4861 v4:0.5607 v5:0.5800 v6:0.5635 v7:0.3812 v8:0.3867 v9:0.7016] eff_mlp_bias:[v0:0.6270 v1:0.3895 v2:0.5248 v3:0.3563 v4:0.3977 v5:0.4613 v6:0.4696 v7:0.3218 v8:0.4861 v9:0.5745] unique_attn_gain_rms:[u0:1.1381 u1:1.1404 u2:1.0801 u3:0.9197 u4:1.1083 u5:0.9216 u6:0.8547 u7:0.8314] unique_mlp_gain_rms:[u0:2.6776 u1:2.5168 u2:2.3405 u3:2.4510 u4:2.4542 u5:2.3913 u6:2.3100 u7:2.4766] depth_emb_rms:[v0:1.4369 v1:1.0356 v2:0.4798 v3:0.6093 v4:0.4104 v5:0.4747 v6:0.5438 v7:0.5426 v8:0.3773 v9:0.5642] +step:5000/20000 val_loss:2.2146 val_bpb:1.3116 train_time:243130ms step_avg:48.63ms +step:5200/20000 train_loss:2.2350 train_time:252812ms step_avg:48.62ms +step:5200 shared0_alpha:mean=0.487,std=0.100 shared1_alpha:mean=0.509,std=0.078 shared2_alpha:mean=0.607,std=0.076 shared3_alpha:mean=0.583,std=0.078 eff_mlp_scale:[v0:136.4211 v1:87.8183 v2:93.5570 v3:99.0456 v4:90.4863 v5:111.7687 v6:90.0360 v7:90.3214 v8:98.4237 v9:224.4988] eff_attn_scale:[v0:0.6066 v1:5.2665 v2:11.5819 v3:10.2285 v4:11.6304 v5:9.2690 v6:9.8226 v7:8.6977 v8:9.3710 v9:1.7294] eff_attn_bias:[v0:0.6491 v1:0.6463 v2:0.5635 v3:0.4917 v4:0.5690 v5:0.5939 v6:0.5745 v7:0.3867 v8:0.3922 v9:0.7182] eff_mlp_bias:[v0:0.6325 v1:0.3922 v2:0.5331 v3:0.3618 v4:0.4033 v5:0.4640 v6:0.4751 v7:0.3259 v8:0.4917 v9:0.5883] unique_attn_gain_rms:[u0:1.1592 u1:1.1609 u2:1.0965 u3:0.9323 u4:1.1298 u5:0.9390 u6:0.8687 u7:0.8473] unique_mlp_gain_rms:[u0:2.7339 u1:2.5662 u2:2.3849 u3:2.4997 u4:2.5040 u5:2.4424 u6:2.3517 u7:2.5276] depth_emb_rms:[v0:1.4943 v1:1.0738 v2:0.4880 v3:0.6227 v4:0.4179 v5:0.4824 v6:0.5523 v7:0.5520 v8:0.3823 v9:0.5768] +step:5200/20000 val_loss:2.2149 val_bpb:1.3118 train_time:252830ms step_avg:48.62ms +step:5400/20000 train_loss:2.2436 train_time:262506ms step_avg:48.61ms +step:5400 
shared0_alpha:mean=0.487,std=0.100 shared1_alpha:mean=0.509,std=0.078 shared2_alpha:mean=0.608,std=0.077 shared3_alpha:mean=0.584,std=0.079 eff_mlp_scale:[v0:139.1265 v1:88.8652 v2:95.2003 v3:99.9668 v4:92.6685 v5:112.9552 v6:91.6556 v7:92.1973 v8:100.7034 v9:228.3549] eff_attn_scale:[v0:0.6113 v1:5.3095 v2:11.7452 v3:10.3237 v4:11.9465 v5:9.4660 v6:9.9723 v7:8.8840 v8:9.6173 v9:1.7381] eff_attn_bias:[v0:0.6795 v1:0.6574 v2:0.5718 v3:0.4999 v4:0.5773 v5:0.6021 v6:0.5856 v7:0.3922 v8:0.3977 v9:0.7347] eff_mlp_bias:[v0:0.6353 v1:0.3977 v2:0.5414 v3:0.3646 v4:0.4060 v5:0.4723 v6:0.4806 v7:0.3301 v8:0.5027 v9:0.6049] unique_attn_gain_rms:[u0:1.1798 u1:1.1754 u2:1.1103 u3:0.9449 u4:1.1457 u5:0.9520 u6:0.8808 u7:0.8639] unique_mlp_gain_rms:[u0:2.7852 u1:2.6156 u2:2.4291 u3:2.5479 u4:2.5537 u5:2.4891 u6:2.3958 u7:2.5798] depth_emb_rms:[v0:1.5523 v1:1.1169 v2:0.4992 v3:0.6337 v4:0.4258 v5:0.4898 v6:0.5629 v7:0.5608 v8:0.3891 v9:0.5915] +step:5400/20000 val_loss:2.2089 val_bpb:1.3082 train_time:262527ms step_avg:48.62ms +step:5600/20000 train_loss:2.2461 train_time:272202ms step_avg:48.61ms +step:5600 shared0_alpha:mean=0.485,std=0.101 shared1_alpha:mean=0.509,std=0.079 shared2_alpha:mean=0.610,std=0.077 shared3_alpha:mean=0.584,std=0.080 eff_mlp_scale:[v0:142.2023 v1:90.0212 v2:95.7173 v3:101.8773 v4:94.1614 v5:114.8174 v6:92.1534 v7:93.5181 v8:102.2787 v9:232.0882] eff_attn_scale:[v0:0.6125 v1:5.3298 v2:11.8853 v3:10.4953 v4:12.3720 v5:9.5632 v6:10.1768 v7:9.0770 v8:9.8518 v9:1.7056] eff_attn_bias:[v0:0.7126 v1:0.6740 v2:0.5800 v3:0.5082 v4:0.5883 v5:0.6160 v6:0.5939 v7:0.3977 v8:0.4060 v9:0.7513] eff_mlp_bias:[v0:0.6408 v1:0.4033 v2:0.5497 v3:0.3701 v4:0.4116 v5:0.4751 v6:0.4889 v7:0.3356 v8:0.5082 v9:0.6187] unique_attn_gain_rms:[u0:1.1985 u1:1.1932 u2:1.1241 u3:0.9587 u4:1.1618 u5:0.9654 u6:0.8949 u7:0.8793] unique_mlp_gain_rms:[u0:2.8405 u1:2.6650 u2:2.4717 u3:2.5970 u4:2.6029 u5:2.5348 u6:2.4401 u7:2.6304] depth_emb_rms:[v0:1.6042 v1:1.1603 v2:0.5055 v3:0.6443 
v4:0.4341 v5:0.4972 v6:0.5701 v7:0.5698 v8:0.3970 v9:0.6041] +step:5600/20000 val_loss:2.2120 val_bpb:1.3101 train_time:272220ms step_avg:48.61ms +step:5800/20000 train_loss:2.2122 train_time:281894ms step_avg:48.60ms +step:5800 shared0_alpha:mean=0.485,std=0.101 shared1_alpha:mean=0.509,std=0.079 shared2_alpha:mean=0.611,std=0.078 shared3_alpha:mean=0.585,std=0.081 eff_mlp_scale:[v0:144.8752 v1:91.6951 v2:97.2118 v3:103.7639 v4:95.5435 v5:116.1110 v6:93.1186 v7:94.2829 v8:104.2789 v9:236.3495] eff_attn_scale:[v0:0.6106 v1:5.3505 v2:12.0434 v3:10.6628 v4:12.5296 v5:9.6674 v6:10.3229 v7:9.2316 v8:10.1319 v9:1.6551] eff_attn_bias:[v0:0.7458 v1:0.6878 v2:0.5883 v3:0.5165 v4:0.5966 v5:0.6242 v6:0.6021 v7:0.4033 v8:0.4116 v9:0.7734] eff_mlp_bias:[v0:0.6491 v1:0.4088 v2:0.5552 v3:0.3757 v4:0.4171 v5:0.4806 v6:0.4972 v7:0.3411 v8:0.5193 v9:0.6325] unique_attn_gain_rms:[u0:1.2241 u1:1.2056 u2:1.1434 u3:0.9687 u4:1.1802 u5:0.9814 u6:0.9083 u7:0.8971] unique_mlp_gain_rms:[u0:2.8932 u1:2.7111 u2:2.5144 u3:2.6471 u4:2.6524 u5:2.5830 u6:2.4880 u7:2.6795] depth_emb_rms:[v0:1.6620 v1:1.2013 v2:0.5142 v3:0.6544 v4:0.4424 v5:0.5037 v6:0.5788 v7:0.5808 v8:0.4046 v9:0.6205] +step:5800/20000 val_loss:2.2093 val_bpb:1.3085 train_time:281913ms step_avg:48.61ms +step:6000/20000 train_loss:2.2767 train_time:291589ms step_avg:48.60ms +step:6000 shared0_alpha:mean=0.483,std=0.101 shared1_alpha:mean=0.507,std=0.079 shared2_alpha:mean=0.612,std=0.078 shared3_alpha:mean=0.584,std=0.081 eff_mlp_scale:[v0:147.6123 v1:92.8069 v2:98.2599 v3:105.0680 v4:97.0981 v5:117.9193 v6:94.1443 v7:96.0470 v8:105.9252 v9:238.1409] eff_attn_scale:[v0:0.6067 v1:5.4026 v2:12.1833 v3:10.8911 v4:12.8450 v5:9.8842 v6:10.5942 v7:9.5116 v8:10.3387 v9:1.6774] eff_attn_bias:[v0:0.7789 v1:0.7016 v2:0.5939 v3:0.5220 v4:0.6021 v5:0.6353 v6:0.6104 v7:0.4088 v8:0.4171 v9:0.7844] eff_mlp_bias:[v0:0.6546 v1:0.4116 v2:0.5635 v3:0.3812 v4:0.4198 v5:0.4834 v6:0.4999 v7:0.3439 v8:0.5276 v9:0.6436] unique_attn_gain_rms:[u0:1.2440 
u1:1.2179 u2:1.1571 u3:0.9794 u4:1.1992 u5:0.9969 u6:0.9246 u7:0.9143] unique_mlp_gain_rms:[u0:2.9459 u1:2.7617 u2:2.5592 u3:2.6934 u4:2.7012 u5:2.6311 u6:2.5327 u7:2.7278] depth_emb_rms:[v0:1.7169 v1:1.2452 v2:0.5231 v3:0.6683 v4:0.4491 v5:0.5111 v6:0.5859 v7:0.5879 v8:0.4099 v9:0.6351] +step:6000/20000 val_loss:2.2042 val_bpb:1.3054 train_time:291607ms step_avg:48.60ms +step:6200/20000 train_loss:2.1520 train_time:301282ms step_avg:48.59ms +step:6200 shared0_alpha:mean=0.483,std=0.102 shared1_alpha:mean=0.507,std=0.079 shared2_alpha:mean=0.614,std=0.080 shared3_alpha:mean=0.585,std=0.082 eff_mlp_scale:[v0:150.8841 v1:94.5368 v2:100.0450 v3:107.1605 v4:99.3346 v5:119.8199 v6:95.8981 v7:97.5161 v8:108.2635 v9:241.8229] eff_attn_scale:[v0:0.6109 v1:5.3950 v2:12.2795 v3:11.0415 v4:13.0404 v5:10.0502 v6:10.6119 v7:9.6522 v8:10.5113 v9:1.6087] eff_attn_bias:[v0:0.8176 v1:0.7126 v2:0.6077 v3:0.5276 v4:0.6104 v5:0.6436 v6:0.6187 v7:0.4143 v8:0.4226 v9:0.8065] eff_mlp_bias:[v0:0.6602 v1:0.4171 v2:0.5690 v3:0.3895 v4:0.4226 v5:0.4889 v6:0.5055 v7:0.3494 v8:0.5331 v9:0.6574] unique_attn_gain_rms:[u0:1.2587 u1:1.2382 u2:1.1746 u3:0.9913 u4:1.2163 u5:1.0131 u6:0.9349 u7:0.9301] unique_mlp_gain_rms:[u0:2.9992 u1:2.8080 u2:2.5981 u3:2.7402 u4:2.7488 u5:2.6790 u6:2.5784 u7:2.7791] depth_emb_rms:[v0:1.7673 v1:1.2866 v2:0.5314 v3:0.6788 v4:0.4583 v5:0.5169 v6:0.5934 v7:0.5955 v8:0.4169 v9:0.6480] +step:6200/20000 val_loss:2.2046 val_bpb:1.3057 train_time:301300ms step_avg:48.60ms +step:6400/20000 train_loss:2.2252 train_time:310978ms step_avg:48.59ms +step:6400 shared0_alpha:mean=0.481,std=0.102 shared1_alpha:mean=0.507,std=0.080 shared2_alpha:mean=0.615,std=0.080 shared3_alpha:mean=0.585,std=0.083 eff_mlp_scale:[v0:153.7910 v1:95.0154 v2:100.5523 v3:108.5527 v4:100.2802 v5:120.9789 v6:96.3844 v7:98.8316 v8:110.4209 v9:245.4327] eff_attn_scale:[v0:0.6073 v1:5.4276 v2:12.5060 v3:11.1562 v4:13.2383 v5:10.1728 v6:10.8181 v7:9.9002 v8:10.8314 v9:1.6440] eff_attn_bias:[v0:0.8507 
v1:0.7237 v2:0.6104 v3:0.5359 v4:0.6215 v5:0.6546 v6:0.6270 v7:0.4198 v8:0.4281 v9:0.8176] eff_mlp_bias:[v0:0.6657 v1:0.4171 v2:0.5745 v3:0.3895 v4:0.4281 v5:0.4972 v6:0.5110 v7:0.3536 v8:0.5441 v9:0.6684] unique_attn_gain_rms:[u0:1.2761 u1:1.2561 u2:1.1888 u3:1.0062 u4:1.2323 u5:1.0230 u6:0.9501 u7:0.9471] unique_mlp_gain_rms:[u0:3.0524 u1:2.8572 u2:2.6430 u3:2.7875 u4:2.7997 u5:2.7270 u6:2.6240 u7:2.8296] depth_emb_rms:[v0:1.8201 v1:1.3318 v2:0.5379 v3:0.6884 v4:0.4639 v5:0.5248 v6:0.6044 v7:0.6041 v8:0.4241 v9:0.6631] +step:6400/20000 val_loss:2.2019 val_bpb:1.3041 train_time:310996ms step_avg:48.59ms +step:6600/20000 train_loss:2.1863 train_time:320676ms step_avg:48.59ms +step:6600 shared0_alpha:mean=0.481,std=0.103 shared1_alpha:mean=0.507,std=0.080 shared2_alpha:mean=0.617,std=0.080 shared3_alpha:mean=0.586,std=0.083 eff_mlp_scale:[v0:156.9544 v1:96.7807 v2:102.0338 v3:109.9688 v4:101.8312 v5:122.3664 v6:97.3246 v7:100.1696 v8:112.0712 v9:247.4256] eff_attn_scale:[v0:0.6120 v1:5.4435 v2:12.5633 v3:11.3035 v4:13.5848 v5:10.3271 v6:10.9447 v7:9.9649 v8:10.9817 v9:1.6133] eff_attn_bias:[v0:0.8894 v1:0.7403 v2:0.6215 v3:0.5414 v4:0.6325 v5:0.6657 v6:0.6353 v7:0.4254 v8:0.4337 v9:0.8342] eff_mlp_bias:[v0:0.6712 v1:0.4226 v2:0.5800 v3:0.3950 v4:0.4309 v5:0.4999 v6:0.5165 v7:0.3591 v8:0.5497 v9:0.6795] unique_attn_gain_rms:[u0:1.2944 u1:1.2657 u2:1.2025 u3:1.0194 u4:1.2461 u5:1.0365 u6:0.9627 u7:0.9626] unique_mlp_gain_rms:[u0:3.1019 u1:2.9019 u2:2.6855 u3:2.8351 u4:2.8489 u5:2.7722 u6:2.6650 u7:2.8757] depth_emb_rms:[v0:1.8713 v1:1.3748 v2:0.5462 v3:0.6989 v4:0.4713 v5:0.5317 v6:0.6114 v7:0.6125 v8:0.4301 v9:0.6762] +step:6600/20000 val_loss:2.1969 val_bpb:1.3011 train_time:320697ms step_avg:48.59ms +step:6800/20000 train_loss:2.2561 train_time:330374ms step_avg:48.58ms +step:6800 shared0_alpha:mean=0.480,std=0.103 shared1_alpha:mean=0.506,std=0.079 shared2_alpha:mean=0.617,std=0.081 shared3_alpha:mean=0.586,std=0.084 eff_mlp_scale:[v0:159.4841 v1:97.1769 
v2:103.0853 v3:111.9752 v4:103.4837 v5:123.4258 v6:98.3518 v7:101.5462 v8:114.4069 v9:251.2101] eff_attn_scale:[v0:0.5966 v1:5.4668 v2:12.7210 v3:11.4803 v4:13.8265 v5:10.4963 v6:11.0921 v7:10.1297 v8:11.1928 v9:1.6094] eff_attn_bias:[v0:0.9226 v1:0.7458 v2:0.6298 v3:0.5524 v4:0.6408 v5:0.6767 v6:0.6436 v7:0.4309 v8:0.4419 v9:0.8507] eff_mlp_bias:[v0:0.6822 v1:0.4281 v2:0.5883 v3:0.4005 v4:0.4364 v5:0.5027 v6:0.5220 v7:0.3646 v8:0.5580 v9:0.6905] unique_attn_gain_rms:[u0:1.3115 u1:1.2794 u2:1.2165 u3:1.0281 u4:1.2587 u5:1.0495 u6:0.9720 u7:0.9782] unique_mlp_gain_rms:[u0:3.1537 u1:2.9487 u2:2.7260 u3:2.8767 u4:2.8969 u5:2.8157 u6:2.7097 u7:2.9220] depth_emb_rms:[v0:1.9198 v1:1.4244 v2:0.5544 v3:0.7101 v4:0.4779 v5:0.5386 v6:0.6181 v7:0.6214 v8:0.4372 v9:0.6900] +step:6800/20000 val_loss:2.1974 val_bpb:1.3014 train_time:330393ms step_avg:48.59ms +step:7000/20000 train_loss:2.2867 train_time:340064ms step_avg:48.58ms +step:7000 shared0_alpha:mean=0.480,std=0.103 shared1_alpha:mean=0.506,std=0.080 shared2_alpha:mean=0.619,std=0.082 shared3_alpha:mean=0.587,std=0.085 eff_mlp_scale:[v0:162.1062 v1:99.0626 v2:104.1837 v3:112.9347 v4:105.0102 v5:125.5168 v6:99.4240 v7:103.5235 v8:116.0334 v9:254.9077] eff_attn_scale:[v0:0.6006 v1:5.5304 v2:12.9196 v3:11.7111 v4:14.0801 v5:10.6210 v6:11.2753 v7:10.2755 v8:11.4140 v9:1.5821] eff_attn_bias:[v0:0.9557 v1:0.7623 v2:0.6353 v3:0.5580 v4:0.6491 v5:0.6850 v6:0.6519 v7:0.4337 v8:0.4447 v9:0.8618] eff_mlp_bias:[v0:0.6850 v1:0.4309 v2:0.5939 v3:0.4060 v4:0.4419 v5:0.5110 v6:0.5276 v7:0.3674 v8:0.5635 v9:0.6988] unique_attn_gain_rms:[u0:1.3283 u1:1.2948 u2:1.2321 u3:1.0412 u4:1.2795 u5:1.0585 u6:0.9859 u7:0.9923] unique_mlp_gain_rms:[u0:3.2016 u1:2.9926 u2:2.7689 u3:2.9188 u4:2.9446 u5:2.8608 u6:2.7538 u7:2.9698] depth_emb_rms:[v0:1.9667 v1:1.4696 v2:0.5593 v3:0.7210 v4:0.4852 v5:0.5461 v6:0.6284 v7:0.6295 v8:0.4430 v9:0.7032] +step:7000/20000 val_loss:2.1946 val_bpb:1.2998 train_time:340089ms step_avg:48.58ms +step:7200/20000 
train_loss:2.2660 train_time:349764ms step_avg:48.58ms +step:7200 shared0_alpha:mean=0.478,std=0.104 shared1_alpha:mean=0.505,std=0.081 shared2_alpha:mean=0.620,std=0.082 shared3_alpha:mean=0.587,std=0.086 eff_mlp_scale:[v0:164.8772 v1:100.0444 v2:105.2753 v3:115.0734 v4:106.6417 v5:127.1750 v6:100.4900 v7:105.0184 v8:118.3606 v9:258.6734] eff_attn_scale:[v0:0.6021 v1:5.5130 v2:13.0341 v3:11.8122 v4:14.2479 v5:10.7110 v6:11.4542 v7:10.5166 v8:11.6344 v9:1.5773] eff_attn_bias:[v0:0.9944 v1:0.7679 v2:0.6436 v3:0.5662 v4:0.6574 v5:0.6961 v6:0.6629 v7:0.4419 v8:0.4502 v9:0.8728] eff_mlp_bias:[v0:0.6878 v1:0.4364 v2:0.5994 v3:0.4116 v4:0.4475 v5:0.5138 v6:0.5359 v7:0.3701 v8:0.5745 v9:0.7071] unique_attn_gain_rms:[u0:1.3408 u1:1.3060 u2:1.2421 u3:1.0519 u4:1.2977 u5:1.0710 u6:1.0004 u7:1.0083] unique_mlp_gain_rms:[u0:3.2532 u1:3.0388 u2:2.8074 u3:2.9634 u4:2.9884 u5:2.9035 u6:2.7965 u7:3.0147] depth_emb_rms:[v0:2.0159 v1:1.5170 v2:0.5678 v3:0.7316 v4:0.4936 v5:0.5527 v6:0.6358 v7:0.6388 v8:0.4497 v9:0.7169] +step:7200/20000 val_loss:2.1948 val_bpb:1.2999 train_time:349782ms step_avg:48.58ms +step:7400/20000 train_loss:2.1818 train_time:359455ms step_avg:48.58ms +step:7400 shared0_alpha:mean=0.478,std=0.105 shared1_alpha:mean=0.505,std=0.081 shared2_alpha:mean=0.622,std=0.083 shared3_alpha:mean=0.588,std=0.087 eff_mlp_scale:[v0:168.4570 v1:100.5391 v2:106.3334 v3:116.3814 v4:108.2729 v5:128.3719 v6:102.0586 v7:106.2613 v8:120.1060 v9:260.3794] eff_attn_scale:[v0:0.5972 v1:5.4753 v2:13.1562 v3:11.9692 v4:14.5235 v5:10.8873 v6:11.5711 v7:10.6649 v8:11.7207 v9:1.5673] eff_attn_bias:[v0:1.0275 v1:0.7844 v2:0.6519 v3:0.5773 v4:0.6657 v5:0.7043 v6:0.6657 v7:0.4475 v8:0.4558 v9:0.8894] eff_mlp_bias:[v0:0.6933 v1:0.4392 v2:0.6077 v3:0.4143 v4:0.4502 v5:0.5193 v6:0.5386 v7:0.3729 v8:0.5800 v9:0.7182] unique_attn_gain_rms:[u0:1.3540 u1:1.3173 u2:1.2556 u3:1.0585 u4:1.3112 u5:1.0835 u6:1.0141 u7:1.0215] unique_mlp_gain_rms:[u0:3.3008 u1:3.0849 u2:2.8471 u3:3.0069 u4:3.0349 
u5:2.9478 u6:2.8384 u7:3.0626] depth_emb_rms:[v0:2.0659 v1:1.5715 v2:0.5764 v3:0.7421 v4:0.4995 v5:0.5595 v6:0.6434 v7:0.6466 v8:0.4556 v9:0.7300] +step:7400/20000 val_loss:2.1922 val_bpb:1.2984 train_time:359475ms step_avg:48.58ms +step:7600/20000 train_loss:2.0613 train_time:369149ms step_avg:48.57ms +step:7600 shared0_alpha:mean=0.477,std=0.105 shared1_alpha:mean=0.505,std=0.080 shared2_alpha:mean=0.622,std=0.083 shared3_alpha:mean=0.587,std=0.087 eff_mlp_scale:[v0:170.9100 v1:101.6276 v2:107.3166 v3:117.9926 v4:110.4106 v5:129.6037 v6:103.0240 v7:107.7817 v8:122.3469 v9:264.1725] eff_attn_scale:[v0:0.6015 v1:5.5675 v2:13.3280 v3:12.1448 v4:14.8009 v5:11.0078 v6:11.6520 v7:10.7524 v8:12.0473 v9:1.5833] eff_attn_bias:[v0:1.0607 v1:0.7955 v2:0.6574 v3:0.5800 v4:0.6740 v5:0.7126 v6:0.6740 v7:0.4558 v8:0.4613 v9:0.9005] eff_mlp_bias:[v0:0.6988 v1:0.4364 v2:0.6104 v3:0.4198 v4:0.4558 v5:0.5276 v6:0.5441 v7:0.3784 v8:0.5856 v9:0.7237] unique_attn_gain_rms:[u0:1.3737 u1:1.3314 u2:1.2690 u3:1.0692 u4:1.3275 u5:1.0939 u6:1.0289 u7:1.0384] unique_mlp_gain_rms:[u0:3.3480 u1:3.1297 u2:2.8854 u3:3.0476 u4:3.0805 u5:2.9908 u6:2.8813 u7:3.1102] depth_emb_rms:[v0:2.1114 v1:1.6164 v2:0.5805 v3:0.7514 v4:0.5071 v5:0.5661 v6:0.6533 v7:0.6551 v8:0.4616 v9:0.7429] +step:7600/20000 val_loss:2.1899 val_bpb:1.2970 train_time:369169ms step_avg:48.57ms +step:7800/20000 train_loss:2.2120 train_time:378848ms step_avg:48.57ms +step:7800 shared0_alpha:mean=0.476,std=0.105 shared1_alpha:mean=0.504,std=0.082 shared2_alpha:mean=0.624,std=0.084 shared3_alpha:mean=0.588,std=0.087 eff_mlp_scale:[v0:173.3739 v1:103.3365 v2:108.3865 v3:119.7336 v4:111.4783 v5:131.4670 v6:104.0726 v7:109.4216 v8:124.1326 v9:267.9849] eff_attn_scale:[v0:0.5967 v1:5.5699 v2:13.4266 v3:12.2567 v4:15.0303 v5:11.1398 v6:11.7382 v7:10.9296 v8:12.1633 v9:1.5843] eff_attn_bias:[v0:1.0993 v1:0.8065 v2:0.6657 v3:0.5883 v4:0.6795 v5:0.7237 v6:0.6822 v7:0.4558 v8:0.4668 v9:0.9170] eff_mlp_bias:[v0:0.7043 v1:0.4419 v2:0.6160 
v3:0.4254 v4:0.4585 v5:0.5303 v6:0.5497 v7:0.3812 v8:0.5939 v9:0.7347] unique_attn_gain_rms:[u0:1.3833 u1:1.3420 u2:1.2814 u3:1.0807 u4:1.3386 u5:1.1026 u6:1.0358 u7:1.0537] unique_mlp_gain_rms:[u0:3.3948 u1:3.1735 u2:2.9259 u3:3.0885 u4:3.1253 u5:3.0354 u6:2.9238 u7:3.1516] depth_emb_rms:[v0:2.1583 v1:1.6679 v2:0.5888 v3:0.7612 v4:0.5151 v5:0.5735 v6:0.6590 v7:0.6623 v8:0.4674 v9:0.7545] +step:7800/20000 val_loss:2.1875 val_bpb:1.2956 train_time:378866ms step_avg:48.57ms +step:8000/20000 train_loss:2.1754 train_time:388547ms step_avg:48.57ms +step:8000 shared0_alpha:mean=0.476,std=0.106 shared1_alpha:mean=0.505,std=0.081 shared2_alpha:mean=0.625,std=0.085 shared3_alpha:mean=0.589,std=0.088 eff_mlp_scale:[v0:176.1272 v1:104.5553 v2:109.6001 v3:122.2008 v4:113.8374 v5:133.4379 v6:104.7169 v7:111.1970 v8:126.0125 v9:269.6346] eff_attn_scale:[v0:0.6014 v1:5.5178 v2:13.4949 v3:12.3021 v4:15.1303 v5:11.1631 v6:11.9595 v7:10.9700 v8:12.2442 v9:1.5622] eff_attn_bias:[v0:1.1270 v1:0.8176 v2:0.6712 v3:0.5966 v4:0.6850 v5:0.7292 v6:0.6905 v7:0.4640 v8:0.4696 v9:0.9281] eff_mlp_bias:[v0:0.7071 v1:0.4419 v2:0.6215 v3:0.4337 v4:0.4613 v5:0.5331 v6:0.5552 v7:0.3867 v8:0.6021 v9:0.7458] unique_attn_gain_rms:[u0:1.3963 u1:1.3534 u2:1.2897 u3:1.0938 u4:1.3558 u5:1.1182 u6:1.0538 u7:1.0683] unique_mlp_gain_rms:[u0:3.4391 u1:3.2173 u2:2.9655 u3:3.1377 u4:3.1762 u5:3.0812 u6:2.9654 u7:3.1959] depth_emb_rms:[v0:2.2010 v1:1.7192 v2:0.5954 v3:0.7714 v4:0.5261 v5:0.5783 v6:0.6652 v7:0.6712 v8:0.4728 v9:0.7699] +step:8000/20000 val_loss:2.1862 val_bpb:1.2948 train_time:388565ms step_avg:48.57ms +step:8200/20000 train_loss:2.2452 train_time:398241ms step_avg:48.57ms +step:8200 shared0_alpha:mean=0.475,std=0.106 shared1_alpha:mean=0.504,std=0.081 shared2_alpha:mean=0.626,std=0.085 shared3_alpha:mean=0.588,std=0.089 eff_mlp_scale:[v0:179.6297 v1:105.6286 v2:110.0553 v3:123.6545 v4:114.8300 v5:134.6474 v6:105.6967 v7:113.1555 v8:128.3394 v9:273.5665] eff_attn_scale:[v0:0.5899 v1:5.6054 
v2:13.6359 v3:12.4469 v4:15.3908 v5:11.3398 v6:12.0126 v7:11.1077 v8:12.4718 v9:1.5496] eff_attn_bias:[v0:1.1601 v1:0.8231 v2:0.6767 v3:0.6021 v4:0.6961 v5:0.7403 v6:0.6961 v7:0.4696 v8:0.4751 v9:0.9447] eff_mlp_bias:[v0:0.7126 v1:0.4502 v2:0.6270 v3:0.4337 v4:0.4668 v5:0.5386 v6:0.5552 v7:0.3922 v8:0.6077 v9:0.7513] unique_attn_gain_rms:[u0:1.4071 u1:1.3631 u2:1.3055 u3:1.1054 u4:1.3722 u5:1.1268 u6:1.0649 u7:1.0795] unique_mlp_gain_rms:[u0:3.4895 u1:3.2581 u2:3.0057 u3:3.1770 u4:3.2194 u5:3.1221 u6:3.0039 u7:3.2408] depth_emb_rms:[v0:2.2474 v1:1.7723 v2:0.6056 v3:0.7810 v4:0.5293 v5:0.5869 v6:0.6746 v7:0.6786 v8:0.4782 v9:0.7827] +step:8200/20000 val_loss:2.1843 val_bpb:1.2937 train_time:398259ms step_avg:48.57ms +step:8400/20000 train_loss:2.1999 train_time:408002ms step_avg:48.57ms +step:8400 shared0_alpha:mean=0.474,std=0.107 shared1_alpha:mean=0.504,std=0.082 shared2_alpha:mean=0.627,std=0.086 shared3_alpha:mean=0.589,std=0.089 eff_mlp_scale:[v0:181.3307 v1:106.8405 v2:111.6136 v3:125.2924 v4:117.1543 v5:136.0319 v6:106.6895 v7:114.1161 v8:130.1715 v9:277.4233] eff_attn_scale:[v0:0.5969 v1:5.6496 v2:13.7890 v3:12.5582 v4:15.5806 v5:11.4284 v6:12.1572 v7:11.2865 v8:12.7152 v9:1.5518] eff_attn_bias:[v0:1.1932 v1:0.8342 v2:0.6850 v3:0.6049 v4:0.7043 v5:0.7513 v6:0.7043 v7:0.4751 v8:0.4806 v9:0.9557] eff_mlp_bias:[v0:0.7182 v1:0.4530 v2:0.6298 v3:0.4392 v4:0.4696 v5:0.5441 v6:0.5607 v7:0.3950 v8:0.6187 v9:0.7623] unique_attn_gain_rms:[u0:1.4166 u1:1.3752 u2:1.3156 u3:1.1173 u4:1.3846 u5:1.1392 u6:1.0770 u7:1.0935] unique_mlp_gain_rms:[u0:3.5358 u1:3.3041 u2:3.0412 u3:3.2186 u4:3.2625 u5:3.1636 u6:3.0431 u7:3.2857] depth_emb_rms:[v0:2.2884 v1:1.8291 v2:0.6130 v3:0.7893 v4:0.5367 v5:0.5927 v6:0.6815 v7:0.6861 v8:0.4846 v9:0.7969] +step:8400/20000 val_loss:2.1839 val_bpb:1.2934 train_time:408020ms step_avg:48.57ms +step:8600/20000 train_loss:2.1997 train_time:417691ms step_avg:48.57ms +step:8600 shared0_alpha:mean=0.474,std=0.107 shared1_alpha:mean=0.503,std=0.082 
shared2_alpha:mean=0.628,std=0.087 shared3_alpha:mean=0.589,std=0.090 eff_mlp_scale:[v0:184.0021 v1:107.8754 v2:112.7912 v3:127.0246 v4:118.7767 v5:137.1894 v6:107.8394 v7:116.3403 v8:131.9046 v9:279.0145] eff_attn_scale:[v0:0.5953 v1:5.6400 v2:13.9253 v3:12.7529 v4:15.8435 v5:11.5393 v6:12.2774 v7:11.4696 v8:12.9464 v9:1.5359] eff_attn_bias:[v0:1.2264 v1:0.8452 v2:0.6905 v3:0.6104 v4:0.7126 v5:0.7568 v6:0.7126 v7:0.4806 v8:0.4834 v9:0.9667] eff_mlp_bias:[v0:0.7237 v1:0.4530 v2:0.6325 v3:0.4447 v4:0.4723 v5:0.5469 v6:0.5635 v7:0.4005 v8:0.6242 v9:0.7679] unique_attn_gain_rms:[u0:1.4253 u1:1.3875 u2:1.3300 u3:1.1282 u4:1.4045 u5:1.1466 u6:1.0879 u7:1.1108] unique_mlp_gain_rms:[u0:3.5838 u1:3.3484 u2:3.0795 u3:3.2588 u4:3.3103 u5:3.2089 u6:3.0856 u7:3.3295] depth_emb_rms:[v0:2.3279 v1:1.8808 v2:0.6195 v3:0.7969 v4:0.5436 v5:0.5981 v6:0.6886 v7:0.6918 v8:0.4911 v9:0.8100] +step:8600/20000 val_loss:2.1816 val_bpb:1.2921 train_time:417708ms step_avg:48.57ms +step:8800/20000 train_loss:2.1668 train_time:427391ms step_avg:48.57ms +step:8800 shared0_alpha:mean=0.472,std=0.107 shared1_alpha:mean=0.504,std=0.082 shared2_alpha:mean=0.629,std=0.087 shared3_alpha:mean=0.589,std=0.091 eff_mlp_scale:[v0:186.7122 v1:109.1324 v2:113.1820 v3:128.5047 v4:120.5025 v5:138.6277 v6:108.7651 v7:117.7462 v8:134.3824 v9:280.7902] eff_attn_scale:[v0:0.5903 v1:5.6978 v2:14.0718 v3:12.9457 v4:16.1940 v5:11.8540 v6:12.4163 v7:11.6511 v8:13.1748 v9:1.5311] eff_attn_bias:[v0:1.2595 v1:0.8507 v2:0.6988 v3:0.6215 v4:0.7182 v5:0.7623 v6:0.7237 v7:0.4861 v8:0.4861 v9:0.9778] eff_mlp_bias:[v0:0.7292 v1:0.4558 v2:0.6408 v3:0.4502 v4:0.4778 v5:0.5497 v6:0.5690 v7:0.4060 v8:0.6298 v9:0.7789] unique_attn_gain_rms:[u0:1.4336 u1:1.3925 u2:1.3428 u3:1.1357 u4:1.4188 u5:1.1566 u6:1.0985 u7:1.1225] unique_mlp_gain_rms:[u0:3.6296 u1:3.3921 u2:3.1197 u3:3.3020 u4:3.3544 u5:3.2517 u6:3.1268 u7:3.3736] depth_emb_rms:[v0:2.3725 v1:1.9385 v2:0.6266 v3:0.8099 v4:0.5514 v5:0.6064 v6:0.6964 v7:0.7006 v8:0.4972 
v9:0.8227] +step:8800/20000 val_loss:2.1806 val_bpb:1.2915 train_time:427409ms step_avg:48.57ms +step:9000/20000 train_loss:2.0842 train_time:437077ms step_avg:48.56ms +step:9000 shared0_alpha:mean=0.472,std=0.108 shared1_alpha:mean=0.503,std=0.083 shared2_alpha:mean=0.630,std=0.088 shared3_alpha:mean=0.589,std=0.091 eff_mlp_scale:[v0:189.6999 v1:110.3083 v2:114.3000 v3:130.3690 v4:122.1903 v5:139.9610 v6:109.8612 v7:118.9013 v8:136.8277 v9:284.5742] eff_attn_scale:[v0:0.5968 v1:5.6860 v2:14.1970 v3:13.1279 v4:16.3940 v5:11.9637 v6:12.6102 v7:11.7417 v8:13.3375 v9:1.5499] eff_attn_bias:[v0:1.2927 v1:0.8618 v2:0.7043 v3:0.6270 v4:0.7237 v5:0.7734 v6:0.7237 v7:0.4917 v8:0.4944 v9:0.9833] eff_mlp_bias:[v0:0.7347 v1:0.4613 v2:0.6463 v3:0.4558 v4:0.4834 v5:0.5524 v6:0.5718 v7:0.4060 v8:0.6381 v9:0.7900] unique_attn_gain_rms:[u0:1.4414 u1:1.4047 u2:1.3546 u3:1.1450 u4:1.4352 u5:1.1685 u6:1.1119 u7:1.1360] unique_mlp_gain_rms:[u0:3.6795 u1:3.4351 u2:3.1580 u3:3.3455 u4:3.3965 u5:3.2940 u6:3.1671 u7:3.4186] depth_emb_rms:[v0:2.4110 v1:1.9932 v2:0.6363 v3:0.8192 v4:0.5580 v5:0.6138 v6:0.7014 v7:0.7065 v8:0.5015 v9:0.8353] +step:9000/20000 val_loss:2.1800 val_bpb:1.2911 train_time:437096ms step_avg:48.57ms +step:9200/20000 train_loss:2.1476 train_time:446763ms step_avg:48.56ms +step:9200 shared0_alpha:mean=0.471,std=0.108 shared1_alpha:mean=0.502,std=0.083 shared2_alpha:mean=0.631,std=0.088 shared3_alpha:mean=0.589,std=0.092 eff_mlp_scale:[v0:192.0719 v1:111.3497 v2:115.2168 v3:132.3010 v4:123.2303 v5:141.1224 v6:110.7640 v7:120.7170 v8:137.9922 v9:288.4535] eff_attn_scale:[v0:0.5970 v1:5.7745 v2:14.3457 v3:13.2262 v4:16.7160 v5:12.0769 v6:12.5840 v7:11.9118 v8:13.6169 v9:1.5385] eff_attn_bias:[v0:1.3203 v1:0.8728 v2:0.7126 v3:0.6325 v4:0.7347 v5:0.7844 v6:0.7292 v7:0.4972 v8:0.4972 v9:0.9999] eff_mlp_bias:[v0:0.7403 v1:0.4613 v2:0.6491 v3:0.4585 v4:0.4889 v5:0.5580 v6:0.5800 v7:0.4116 v8:0.6436 v9:0.7955] unique_attn_gain_rms:[u0:1.4519 u1:1.4183 u2:1.3681 u3:1.1572 
u4:1.4489 u5:1.1781 u6:1.1211 u7:1.1494] unique_mlp_gain_rms:[u0:3.7255 u1:3.4764 u2:3.1963 u3:3.3864 u4:3.4441 u5:3.3359 u6:3.2115 u7:3.4623] depth_emb_rms:[v0:2.4490 v1:2.0482 v2:0.6417 v3:0.8278 v4:0.5654 v5:0.6217 v6:0.7097 v7:0.7133 v8:0.5086 v9:0.8491] +step:9200/20000 val_loss:2.1793 val_bpb:1.2907 train_time:446783ms step_avg:48.56ms +step:9400/20000 train_loss:2.2040 train_time:456457ms step_avg:48.56ms +step:9400 shared0_alpha:mean=0.471,std=0.108 shared1_alpha:mean=0.502,std=0.083 shared2_alpha:mean=0.632,std=0.089 shared3_alpha:mean=0.589,std=0.092 eff_mlp_scale:[v0:194.6637 v1:111.9235 v2:116.7688 v3:134.0852 v4:124.9494 v5:143.0466 v6:111.7405 v7:122.3988 v8:139.8397 v9:289.9404] eff_attn_scale:[v0:0.5904 v1:5.7451 v2:14.5289 v3:13.4275 v4:16.9732 v5:12.1505 v6:12.7550 v7:12.0185 v8:13.8441 v9:1.5366] eff_attn_bias:[v0:1.3534 v1:0.8839 v2:0.7182 v3:0.6381 v4:0.7458 v5:0.7900 v6:0.7403 v7:0.5027 v8:0.5027 v9:1.0109] eff_mlp_bias:[v0:0.7458 v1:0.4640 v2:0.6546 v3:0.4640 v4:0.4917 v5:0.5607 v6:0.5828 v7:0.4143 v8:0.6491 v9:0.8065] unique_attn_gain_rms:[u0:1.4558 u1:1.4301 u2:1.3785 u3:1.1688 u4:1.4650 u5:1.1880 u6:1.1329 u7:1.1612] unique_mlp_gain_rms:[u0:3.7735 u1:3.5180 u2:3.2365 u3:3.4237 u4:3.4895 u5:3.3776 u6:3.2484 u7:3.5035] depth_emb_rms:[v0:2.4896 v1:2.1035 v2:0.6491 v3:0.8371 v4:0.5719 v5:0.6274 v6:0.7167 v7:0.7204 v8:0.5146 v9:0.8615] +step:9400/20000 val_loss:2.1769 val_bpb:1.2893 train_time:456477ms step_avg:48.56ms +step:9600/20000 train_loss:2.2098 train_time:466149ms step_avg:48.56ms +step:9600 shared0_alpha:mean=0.470,std=0.109 shared1_alpha:mean=0.501,std=0.084 shared2_alpha:mean=0.632,std=0.089 shared3_alpha:mean=0.589,std=0.093 eff_mlp_scale:[v0:197.0683 v1:113.0623 v2:117.6707 v3:136.2467 v4:126.7263 v5:144.3348 v6:112.6277 v7:124.4262 v8:142.4037 v9:291.7816] eff_attn_scale:[v0:0.5904 v1:5.7668 v2:14.6872 v3:13.4815 v4:17.2205 v5:12.3289 v6:12.9794 v7:12.1500 v8:13.9677 v9:1.5067] eff_attn_bias:[v0:1.3811 v1:0.8894 v2:0.7237 
v3:0.6436 v4:0.7513 v5:0.7955 v6:0.7458 v7:0.5082 v8:0.5055 v9:1.0165] eff_mlp_bias:[v0:0.7513 v1:0.4668 v2:0.6574 v3:0.4640 v4:0.4944 v5:0.5635 v6:0.5856 v7:0.4171 v8:0.6546 v9:0.8121] unique_attn_gain_rms:[u0:1.4661 u1:1.4369 u2:1.3914 u3:1.1775 u4:1.4759 u5:1.1978 u6:1.1429 u7:1.1750] unique_mlp_gain_rms:[u0:3.8202 u1:3.5614 u2:3.2735 u3:3.4611 u4:3.5316 u5:3.4206 u6:3.2904 u7:3.5471] depth_emb_rms:[v0:2.5295 v1:2.1571 v2:0.6574 v3:0.8440 v4:0.5758 v5:0.6331 v6:0.7253 v7:0.7277 v8:0.5191 v9:0.8741] +step:9600/20000 val_loss:2.1773 val_bpb:1.2895 train_time:466168ms step_avg:48.56ms +step:9800/20000 train_loss:2.1400 train_time:475839ms step_avg:48.56ms +step:9800 shared0_alpha:mean=0.470,std=0.109 shared1_alpha:mean=0.501,std=0.084 shared2_alpha:mean=0.634,std=0.090 shared3_alpha:mean=0.590,std=0.093 eff_mlp_scale:[v0:199.7873 v1:114.0976 v2:118.6960 v3:138.1059 v4:128.3465 v5:145.4896 v6:113.6332 v7:126.1785 v8:144.1430 v9:295.5243] eff_attn_scale:[v0:0.5900 v1:5.7594 v2:14.6687 v3:13.6441 v4:17.3764 v5:12.4509 v6:13.0483 v7:12.2211 v8:14.1907 v9:1.5013] eff_attn_bias:[v0:1.4142 v1:0.8949 v2:0.7292 v3:0.6519 v4:0.7623 v5:0.8065 v6:0.7513 v7:0.5138 v8:0.5110 v9:1.0275] eff_mlp_bias:[v0:0.7513 v1:0.4696 v2:0.6629 v3:0.4668 v4:0.4999 v5:0.5662 v6:0.5911 v7:0.4198 v8:0.6602 v9:0.8176] unique_attn_gain_rms:[u0:1.4724 u1:1.4498 u2:1.4023 u3:1.1854 u4:1.4909 u5:1.2086 u6:1.1503 u7:1.1887] unique_mlp_gain_rms:[u0:3.8681 u1:3.6038 u2:3.3090 u3:3.5040 u4:3.5768 u5:3.4585 u6:3.3272 u7:3.5887] depth_emb_rms:[v0:2.5646 v1:2.2040 v2:0.6624 v3:0.8540 v4:0.5832 v5:0.6402 v6:0.7311 v7:0.7338 v8:0.5241 v9:0.8856] +step:9800/20000 val_loss:2.1781 val_bpb:1.2900 train_time:475858ms step_avg:48.56ms +step:10000/20000 train_loss:2.1775 train_time:485523ms step_avg:48.55ms +step:10000 shared0_alpha:mean=0.468,std=0.109 shared1_alpha:mean=0.500,std=0.084 shared2_alpha:mean=0.635,std=0.090 shared3_alpha:mean=0.590,std=0.094 eff_mlp_scale:[v0:201.8576 v1:115.3079 v2:119.6456 v3:140.0821 
v4:130.0175 v5:146.8658 v6:114.5663 v7:128.0389 v8:145.9380 v9:297.3471] eff_attn_scale:[v0:0.5869 v1:5.7606 v2:14.9569 v3:13.7940 v4:17.7521 v5:12.6598 v6:13.3143 v7:12.4400 v8:14.5156 v9:1.5308] eff_attn_bias:[v0:1.4474 v1:0.9060 v2:0.7347 v3:0.6574 v4:0.7679 v5:0.8121 v6:0.7568 v7:0.5165 v8:0.5138 v9:1.0330] eff_mlp_bias:[v0:0.7568 v1:0.4751 v2:0.6712 v3:0.4723 v4:0.5027 v5:0.5718 v6:0.5966 v7:0.4226 v8:0.6657 v9:0.8286] unique_attn_gain_rms:[u0:1.4759 u1:1.4586 u2:1.4148 u3:1.1992 u4:1.5078 u5:1.2169 u6:1.1621 u7:1.2021] unique_mlp_gain_rms:[u0:3.9103 u1:3.6465 u2:3.3458 u3:3.5447 u4:3.6219 u5:3.5002 u6:3.3692 u7:3.6324] depth_emb_rms:[v0:2.6024 v1:2.2466 v2:0.6715 v3:0.8636 v4:0.5907 v5:0.6484 v6:0.7391 v7:0.7422 v8:0.5297 v9:0.8981] +step:10000/20000 val_loss:2.1768 val_bpb:1.2892 train_time:485543ms step_avg:48.55ms +step:10200/20000 train_loss:2.1329 train_time:495216ms step_avg:48.55ms +step:10200 shared0_alpha:mean=0.468,std=0.110 shared1_alpha:mean=0.501,std=0.084 shared2_alpha:mean=0.636,std=0.091 shared3_alpha:mean=0.590,std=0.094 eff_mlp_scale:[v0:204.2620 v1:116.5338 v2:120.6789 v3:142.1342 v4:131.7836 v5:148.2603 v6:115.5798 v7:129.9695 v8:147.8384 v9:301.1838] eff_attn_scale:[v0:0.5810 v1:5.7510 v2:15.1330 v3:13.8611 v4:17.9264 v5:12.7801 v6:13.3936 v7:12.5856 v8:14.7571 v9:1.5172] eff_attn_bias:[v0:1.4695 v1:0.9170 v2:0.7403 v3:0.6629 v4:0.7734 v5:0.8231 v6:0.7679 v7:0.5248 v8:0.5220 v9:1.0441] eff_mlp_bias:[v0:0.7623 v1:0.4778 v2:0.6740 v3:0.4778 v4:0.5082 v5:0.5745 v6:0.5994 v7:0.4254 v8:0.6712 v9:0.8342] unique_attn_gain_rms:[u0:1.4793 u1:1.4665 u2:1.4216 u3:1.2055 u4:1.5193 u5:1.2284 u6:1.1724 u7:1.2153] unique_mlp_gain_rms:[u0:3.9551 u1:3.6901 u2:3.3868 u3:3.5814 u4:3.6672 u5:3.5389 u6:3.4093 u7:3.6736] depth_emb_rms:[v0:2.6418 v1:2.2944 v2:0.6792 v3:0.8726 v4:0.5963 v5:0.6561 v6:0.7466 v7:0.7503 v8:0.5357 v9:0.9098] +step:10200/20000 val_loss:2.1727 val_bpb:1.2868 train_time:495234ms step_avg:48.55ms +step:10400/20000 train_loss:2.1654 
train_time:504904ms step_avg:48.55ms +step:10400 shared0_alpha:mean=0.467,std=0.110 shared1_alpha:mean=0.501,std=0.084 shared2_alpha:mean=0.637,std=0.091 shared3_alpha:mean=0.591,std=0.095 eff_mlp_scale:[v0:206.6398 v1:117.8370 v2:121.6206 v3:144.2783 v4:133.6202 v5:150.3649 v6:115.9374 v7:131.9855 v8:149.8166 v9:302.7210] eff_attn_scale:[v0:0.5813 v1:5.8030 v2:15.2212 v3:14.0587 v4:18.3149 v5:12.8206 v6:13.4716 v7:12.6871 v8:14.9121 v9:1.5011] eff_attn_bias:[v0:1.5026 v1:0.9226 v2:0.7403 v3:0.6684 v4:0.7789 v5:0.8286 v6:0.7734 v7:0.5248 v8:0.5248 v9:1.0551] eff_mlp_bias:[v0:0.7623 v1:0.4806 v2:0.6767 v3:0.4834 v4:0.5110 v5:0.5773 v6:0.6021 v7:0.4309 v8:0.6767 v9:0.8397] unique_attn_gain_rms:[u0:1.4839 u1:1.4772 u2:1.4334 u3:1.2127 u4:1.5315 u5:1.2341 u6:1.1790 u7:1.2270] unique_mlp_gain_rms:[u0:4.0018 u1:3.7295 u2:3.4178 u3:3.6214 u4:3.7107 u5:3.5808 u6:3.4483 u7:3.7137] depth_emb_rms:[v0:2.6829 v1:2.3358 v2:0.6857 v3:0.8818 v4:0.6039 v5:0.6618 v6:0.7518 v7:0.7555 v8:0.5412 v9:0.9225] +step:10400/20000 val_loss:2.1724 val_bpb:1.2866 train_time:504925ms step_avg:48.55ms +step:10600/20000 train_loss:2.0390 train_time:514598ms step_avg:48.55ms +step:10600 shared0_alpha:mean=0.466,std=0.111 shared1_alpha:mean=0.501,std=0.084 shared2_alpha:mean=0.638,std=0.092 shared3_alpha:mean=0.590,std=0.095 eff_mlp_scale:[v0:208.1788 v1:118.4041 v2:122.2290 v3:146.0307 v4:135.3981 v5:151.7053 v6:117.0885 v7:134.2434 v8:152.4079 v9:306.5494] eff_attn_scale:[v0:0.5764 v1:5.7648 v2:15.2089 v3:13.9966 v4:18.3527 v5:12.8107 v6:13.4608 v7:12.7945 v8:15.1258 v9:1.5128] eff_attn_bias:[v0:1.5357 v1:0.9281 v2:0.7458 v3:0.6740 v4:0.7900 v5:0.8342 v6:0.7789 v7:0.5331 v8:0.5303 v9:1.0607] eff_mlp_bias:[v0:0.7679 v1:0.4834 v2:0.6795 v3:0.4889 v4:0.5138 v5:0.5828 v6:0.6049 v7:0.4364 v8:0.6822 v9:0.8452] unique_attn_gain_rms:[u0:1.4836 u1:1.4810 u2:1.4488 u3:1.2215 u4:1.5468 u5:1.2471 u6:1.1966 u7:1.2402] unique_mlp_gain_rms:[u0:4.0492 u1:3.7700 u2:3.4563 u3:3.6625 u4:3.7529 u5:3.6179 u6:3.4874 
u7:3.7537] depth_emb_rms:[v0:2.7192 v1:2.3789 v2:0.6923 v3:0.8891 v4:0.6116 v5:0.6685 v6:0.7594 v7:0.7621 v8:0.5474 v9:0.9342] +step:10600/20000 val_loss:2.1725 val_bpb:1.2867 train_time:514616ms step_avg:48.55ms +step:10800/20000 train_loss:2.2512 train_time:524287ms step_avg:48.55ms +step:10800 shared0_alpha:mean=0.466,std=0.111 shared1_alpha:mean=0.500,std=0.085 shared2_alpha:mean=0.639,std=0.093 shared3_alpha:mean=0.591,std=0.096 eff_mlp_scale:[v0:210.5479 v1:119.3132 v2:123.2981 v3:148.4347 v4:137.1214 v5:152.6962 v6:118.1367 v7:136.5070 v8:154.2615 v9:307.8055] eff_attn_scale:[v0:0.5901 v1:5.7809 v2:15.2936 v3:14.1461 v4:18.5976 v5:12.9818 v6:13.6236 v7:12.8522 v8:15.2439 v9:1.5189] eff_attn_bias:[v0:1.5578 v1:0.9336 v2:0.7513 v3:0.6795 v4:0.7955 v5:0.8397 v6:0.7844 v7:0.5386 v8:0.5359 v9:1.0717] eff_mlp_bias:[v0:0.7734 v1:0.4889 v2:0.6850 v3:0.4917 v4:0.5193 v5:0.5856 v6:0.6104 v7:0.4337 v8:0.6878 v9:0.8563] unique_attn_gain_rms:[u0:1.4922 u1:1.4916 u2:1.4572 u3:1.2296 u4:1.5629 u5:1.2519 u6:1.2008 u7:1.2478] unique_mlp_gain_rms:[u0:4.0946 u1:3.8076 u2:3.4957 u3:3.6988 u4:3.7981 u5:3.6573 u6:3.5248 u7:3.7955] depth_emb_rms:[v0:2.7598 v1:2.4244 v2:0.7005 v3:0.8988 v4:0.6169 v5:0.6747 v6:0.7668 v7:0.7704 v8:0.5513 v9:0.9462] +step:10800/20000 val_loss:2.1713 val_bpb:1.2860 train_time:524305ms step_avg:48.55ms +step:11000/20000 train_loss:2.1793 train_time:533975ms step_avg:48.54ms +step:11000 shared0_alpha:mean=0.466,std=0.111 shared1_alpha:mean=0.499,std=0.085 shared2_alpha:mean=0.640,std=0.093 shared3_alpha:mean=0.591,std=0.097 eff_mlp_scale:[v0:213.6691 v1:119.9817 v2:124.4433 v3:151.3931 v4:138.7015 v5:154.1734 v6:118.6821 v7:138.6654 v8:156.6430 v9:311.6816] eff_attn_scale:[v0:0.5860 v1:5.8328 v2:15.4418 v3:14.2474 v4:18.9388 v5:13.0220 v6:13.7556 v7:13.0312 v8:15.5422 v9:1.5060] eff_attn_bias:[v0:1.5910 v1:0.9502 v2:0.7568 v3:0.6878 v4:0.8010 v5:0.8507 v6:0.7955 v7:0.5414 v8:0.5386 v9:1.0828] eff_mlp_bias:[v0:0.7734 v1:0.4917 v2:0.6878 v3:0.4972 
v4:0.5220 v5:0.5883 v6:0.6160 v7:0.4392 v8:0.6933 v9:0.8673] unique_attn_gain_rms:[u0:1.4925 u1:1.5029 u2:1.4712 u3:1.2383 u4:1.5744 u5:1.2606 u6:1.2145 u7:1.2604] unique_mlp_gain_rms:[u0:4.1407 u1:3.8494 u2:3.5314 u3:3.7378 u4:3.8391 u5:3.6991 u6:3.5615 u7:3.8380] depth_emb_rms:[v0:2.7958 v1:2.4631 v2:0.7075 v3:0.9063 v4:0.6246 v5:0.6803 v6:0.7732 v7:0.7780 v8:0.5576 v9:0.9574] +step:11000/20000 val_loss:2.1689 val_bpb:1.2845 train_time:533994ms step_avg:48.54ms +step:11200/20000 train_loss:2.1349 train_time:543663ms step_avg:48.54ms +step:11200 shared0_alpha:mean=0.465,std=0.111 shared1_alpha:mean=0.500,std=0.085 shared2_alpha:mean=0.641,std=0.094 shared3_alpha:mean=0.591,std=0.097 eff_mlp_scale:[v0:215.8313 v1:121.8494 v2:125.3600 v3:153.7078 v4:139.9292 v5:155.5923 v6:120.1607 v7:140.1653 v8:158.0295 v9:313.1032] eff_attn_scale:[v0:0.5848 v1:5.8484 v2:15.5169 v3:14.3857 v4:19.1142 v5:13.2699 v6:13.9117 v7:13.0780 v8:15.7900 v9:1.5185] eff_attn_bias:[v0:1.6131 v1:0.9557 v2:0.7623 v3:0.6933 v4:0.8065 v5:0.8563 v6:0.8010 v7:0.5469 v8:0.5414 v9:1.0883] eff_mlp_bias:[v0:0.7789 v1:0.4972 v2:0.6961 v3:0.5027 v4:0.5220 v5:0.5883 v6:0.6187 v7:0.4447 v8:0.6988 v9:0.8728] unique_attn_gain_rms:[u0:1.4938 u1:1.5127 u2:1.4804 u3:1.2475 u4:1.5946 u5:1.2709 u6:1.2307 u7:1.2725] unique_mlp_gain_rms:[u0:4.1834 u1:3.8876 u2:3.5680 u3:3.7793 u4:3.8794 u5:3.7377 u6:3.6037 u7:3.8763] depth_emb_rms:[v0:2.8304 v1:2.5025 v2:0.7159 v3:0.9164 v4:0.6314 v5:0.6852 v6:0.7801 v7:0.7839 v8:0.5621 v9:0.9701] +step:11200/20000 val_loss:2.1676 val_bpb:1.2837 train_time:543683ms step_avg:48.54ms +step:11400/20000 train_loss:2.1173 train_time:553357ms step_avg:48.54ms +step:11400 shared0_alpha:mean=0.464,std=0.111 shared1_alpha:mean=0.499,std=0.085 shared2_alpha:mean=0.641,std=0.094 shared3_alpha:mean=0.591,std=0.098 eff_mlp_scale:[v0:217.1286 v1:122.4559 v2:126.4491 v3:155.4205 v4:141.8289 v5:156.9948 v6:121.2287 v7:143.0964 v8:160.0841 v9:317.1282] eff_attn_scale:[v0:0.5803 v1:5.8919 v2:15.6171 
v3:14.5270 v4:19.4517 v5:13.4377 v6:14.0913 v7:13.2944 v8:15.9819 v9:1.5541] eff_attn_bias:[v0:1.6352 v1:0.9667 v2:0.7679 v3:0.6961 v4:0.8176 v5:0.8618 v6:0.8065 v7:0.5497 v8:0.5469 v9:1.0938] eff_mlp_bias:[v0:0.7789 v1:0.4999 v2:0.6988 v3:0.5027 v4:0.5303 v5:0.5966 v6:0.6215 v7:0.4475 v8:0.7043 v9:0.8784] unique_attn_gain_rms:[u0:1.4959 u1:1.5173 u2:1.4890 u3:1.2539 u4:1.6013 u5:1.2743 u6:1.2319 u7:1.2834] unique_mlp_gain_rms:[u0:4.2189 u1:3.9232 u2:3.5984 u3:3.8070 u4:3.9136 u5:3.7709 u6:3.6350 u7:3.9107] depth_emb_rms:[v0:2.8607 v1:2.5358 v2:0.7216 v3:0.9229 v4:0.6369 v5:0.6944 v6:0.7883 v7:0.7902 v8:0.5686 v9:0.9807] +step:11400/20000 val_loss:2.1619 val_bpb:1.2804 train_time:553375ms step_avg:48.54ms +step:11600/20000 train_loss:2.1131 train_time:563047ms step_avg:48.54ms +step:11600 shared0_alpha:mean=0.465,std=0.111 shared1_alpha:mean=0.499,std=0.086 shared2_alpha:mean=0.641,std=0.094 shared3_alpha:mean=0.590,std=0.098 eff_mlp_scale:[v0:218.8497 v1:123.8324 v2:127.5133 v3:158.1082 v4:143.9129 v5:157.9495 v6:121.6908 v7:144.9325 v8:162.3451 v9:319.0909] eff_attn_scale:[v0:0.5916 v1:5.8788 v2:15.8021 v3:14.7183 v4:19.7788 v5:13.4866 v6:14.0865 v7:13.3884 v8:16.1633 v9:1.5484] eff_attn_bias:[v0:1.6462 v1:0.9667 v2:0.7734 v3:0.6988 v4:0.8231 v5:0.8673 v6:0.8121 v7:0.5524 v8:0.5497 v9:1.0993] eff_mlp_bias:[v0:0.7789 v1:0.5027 v2:0.6988 v3:0.5082 v4:0.5303 v5:0.5966 v6:0.6242 v7:0.4475 v8:0.7071 v9:0.8894] unique_attn_gain_rms:[u0:1.4974 u1:1.5232 u2:1.4907 u3:1.2610 u4:1.6051 u5:1.2768 u6:1.2376 u7:1.2908] unique_mlp_gain_rms:[u0:4.2442 u1:3.9428 u2:3.6211 u3:3.8307 u4:3.9388 u5:3.7946 u6:3.6573 u7:3.9339] depth_emb_rms:[v0:2.8830 v1:2.5616 v2:0.7272 v3:0.9286 v4:0.6403 v5:0.6994 v6:0.7942 v7:0.7937 v8:0.5736 v9:0.9912] +step:11600/20000 val_loss:2.1526 val_bpb:1.2749 train_time:563066ms step_avg:48.54ms +step:11800/20000 train_loss:2.1363 train_time:572743ms step_avg:48.54ms +step:11800 shared0_alpha:mean=0.465,std=0.111 shared1_alpha:mean=0.498,std=0.086 
shared2_alpha:mean=0.641,std=0.094 shared3_alpha:mean=0.590,std=0.098 eff_mlp_scale:[v0:219.8162 v1:125.1367 v2:128.1562 v3:160.7606 v4:145.2331 v5:159.4381 v6:122.8895 v7:147.4224 v8:164.5498 v9:323.2051] eff_attn_scale:[v0:0.5856 v1:5.9239 v2:15.9211 v3:14.7951 v4:19.8998 v5:13.5798 v6:14.2835 v7:13.4582 v8:16.3692 v9:1.5763] eff_attn_bias:[v0:1.6573 v1:0.9723 v2:0.7734 v3:0.7016 v4:0.8176 v5:0.8673 v6:0.8121 v7:0.5552 v8:0.5524 v9:1.1049] eff_mlp_bias:[v0:0.7789 v1:0.5082 v2:0.7016 v3:0.5082 v4:0.5331 v5:0.5994 v6:0.6270 v7:0.4502 v8:0.7126 v9:0.8949] unique_attn_gain_rms:[u0:1.4962 u1:1.5250 u2:1.4942 u3:1.2640 u4:1.6084 u5:1.2772 u6:1.2376 u7:1.2929] unique_mlp_gain_rms:[u0:4.2619 u1:3.9589 u2:3.6342 u3:3.8435 u4:3.9537 u5:3.8102 u6:3.6742 u7:3.9506] depth_emb_rms:[v0:2.8964 v1:2.5813 v2:0.7313 v3:0.9329 v4:0.6444 v5:0.7043 v6:0.7964 v7:0.7982 v8:0.5760 v9:0.9986] +step:11800/20000 val_loss:2.1437 val_bpb:1.2696 train_time:572761ms step_avg:48.54ms +step:12000/20000 train_loss:2.1086 train_time:582432ms step_avg:48.54ms +step:12000 shared0_alpha:mean=0.465,std=0.111 shared1_alpha:mean=0.498,std=0.085 shared2_alpha:mean=0.641,std=0.094 shared3_alpha:mean=0.589,std=0.098 eff_mlp_scale:[v0:220.4730 v1:125.6878 v2:128.6463 v3:161.2644 v4:146.0937 v5:160.1403 v6:123.3595 v7:148.5330 v8:165.5249 v9:325.0007] eff_attn_scale:[v0:0.5850 v1:5.9210 v2:15.9494 v3:14.8360 v4:19.9685 v5:13.5040 v6:14.3088 v7:13.4954 v8:16.3183 v9:1.5809] eff_attn_bias:[v0:1.6573 v1:0.9723 v2:0.7734 v3:0.7016 v4:0.8176 v5:0.8673 v6:0.8121 v7:0.5580 v8:0.5524 v9:1.1104] eff_mlp_bias:[v0:0.7789 v1:0.5082 v2:0.7016 v3:0.5110 v4:0.5359 v5:0.6021 v6:0.6270 v7:0.4530 v8:0.7126 v9:0.8949] unique_attn_gain_rms:[u0:1.4951 u1:1.5228 u2:1.4950 u3:1.2622 u4:1.6080 u5:1.2760 u6:1.2375 u7:1.2949] unique_mlp_gain_rms:[u0:4.2695 u1:3.9685 u2:3.6436 u3:3.8510 u4:3.9624 u5:3.8188 u6:3.6806 u7:3.9584] depth_emb_rms:[v0:2.9029 v1:2.5933 v2:0.7336 v3:0.9358 v4:0.6486 v5:0.7065 v6:0.7999 v7:0.8001 v8:0.5780 
v9:1.0044] +step:12000/20000 val_loss:2.1360 val_bpb:1.2651 train_time:582450ms step_avg:48.54ms +step:12200/20000 train_loss:2.2462 train_time:592120ms step_avg:48.53ms +step:12200 shared0_alpha:mean=0.466,std=0.111 shared1_alpha:mean=0.498,std=0.085 shared2_alpha:mean=0.641,std=0.094 shared3_alpha:mean=0.589,std=0.098 eff_mlp_scale:[v0:220.6471 v1:126.0874 v2:129.0463 v3:162.2385 v4:146.6564 v5:160.6495 v6:123.7430 v7:149.4302 v8:166.1624 v9:326.0877] eff_attn_scale:[v0:0.5855 v1:5.9176 v2:15.9139 v3:14.8173 v4:19.9628 v5:13.4962 v6:14.1861 v7:13.4784 v8:16.3137 v9:1.5850] eff_attn_bias:[v0:1.6573 v1:0.9723 v2:0.7734 v3:0.7016 v4:0.8231 v5:0.8728 v6:0.8121 v7:0.5552 v8:0.5524 v9:1.1104] eff_mlp_bias:[v0:0.7789 v1:0.5082 v2:0.7016 v3:0.5110 v4:0.5359 v5:0.6021 v6:0.6270 v7:0.4530 v8:0.7126 v9:0.9005] unique_attn_gain_rms:[u0:1.4927 u1:1.5196 u2:1.4961 u3:1.2622 u4:1.6072 u5:1.2762 u6:1.2361 u7:1.2936] unique_mlp_gain_rms:[u0:4.2725 u1:3.9708 u2:3.6456 u3:3.8541 u4:3.9659 u5:3.8218 u6:3.6839 u7:3.9604] depth_emb_rms:[v0:2.9056 v1:2.5973 v2:0.7353 v3:0.9368 v4:0.6505 v5:0.7085 v6:0.8017 v7:0.8012 v8:0.5785 v9:1.0075] +step:12200/20000 val_loss:2.1285 val_bpb:1.2606 train_time:592138ms step_avg:48.54ms +step:12363/20000 val_loss:2.1239 val_bpb:1.2579 train_time:600044ms step_avg:48.54ms +stopping_early: wallclock_cap train_time:600044ms step:12363/20000 +peak memory allocated: 11709 MiB reserved: 11886 MiB +Serialized model: 45246079 bytes +Code size: 63793 bytes +Total submission size: 45309872 bytes +Serialized model int8+zlib: 10777147 bytes (payload:11704512 raw_torch:11737197 payload_ratio:3.86x) +Total submission size int8+zlib: 10840940 bytes +final_int8_zlib_roundtrip val_loss:2.1382 val_bpb:1.2663 eval_time:1531ms +final_int8_zlib_roundtrip_exact val_loss:2.13815924 val_bpb:1.26633834 diff --git a/records/track_non_record_16mb/2026-03-26_RecurrenceFix_3Loop_Birkhoff_OutputLN_TimestepScale/logs/s4_Q.txt 
b/records/track_non_record_16mb/2026-03-26_RecurrenceFix_3Loop_Birkhoff_OutputLN_TimestepScale/logs/s4_Q.txt new file mode 100644 index 0000000000..6a60e09f70 --- /dev/null +++ b/records/track_non_record_16mb/2026-03-26_RecurrenceFix_3Loop_Birkhoff_OutputLN_TimestepScale/logs/s4_Q.txt @@ -0,0 +1,1688 @@ +""" +The `train_gpt.py` and `train_gpt_mlx.py` scripts are intended as good launching-off points for new participants, not SOTA configs. We'll accept PRs that tune, improve, or simplify these scripts without significantly increasing complexity, but competitive submissions should stay in the `/records` folder. + +Hard stop: To keep readable for newcomers, let's make sure `train_gpt.py` and `train_gpt_mlx.py` never are longer than 1500 lines. +""" + +from __future__ import annotations + +import copy +import glob +import io +import math +import os +import random +import subprocess +import sys +import time +import uuid +import zlib +from pathlib import Path + +import numpy as np +import sentencepiece as spm +import torch +import torch.distributed as dist +import torch.nn.functional as F +from torch import Tensor, nn +from torch.nn.parallel import DistributedDataParallel as DDP + +# Flash Attention 3 via HuggingFace kernels (Hopper-optimized) +try: + from kernels import get_kernel + cap = torch.cuda.get_device_capability() + repo = "varunneal/flash-attention-3" if cap == (9, 0) else "kernels-community/flash-attn3" + fa3 = get_kernel(repo).flash_attn_interface + FA3_AVAILABLE = True +except Exception: + FA3_AVAILABLE = False + +# ----------------------------- +# HYPERPARAMETERS +# ----------------------------- +# Default Simple Baseline run: +# - 9 transformer blocks at width 512 +# - 8 attention heads with 4 KV heads (GQA) and 2x MLP expansion +# - vocab size 1024, sequence length 1024, tied embeddings +# - 524,288 train tokens per step for 20,000 iterations with a ~10 minute cap + +class Hyperparameters: + # Data paths are shard globs produced by the existing 
preprocessing pipeline. + data_path = os.environ.get("DATA_PATH", "./data/datasets/fineweb10B_sp1024") + train_files = os.path.join(data_path, "fineweb_train_*.bin") + val_files = os.path.join(data_path, "fineweb_val_*.bin") + tokenizer_path = os.environ.get("TOKENIZER_PATH", "./data/tokenizers/fineweb_1024_bpe.model") + run_id = os.environ.get("RUN_ID", str(uuid.uuid4())) + seed = int(os.environ.get("SEED", 1337)) + + # Validation cadence and batch size. Validation always uses the full fineweb_val split. + val_batch_size = int(os.environ.get("VAL_BATCH_SIZE", 524_288)) + val_loss_every = int(os.environ.get("VAL_LOSS_EVERY", 1000)) + train_log_every = int(os.environ.get("TRAIN_LOG_EVERY", 200)) + + # Training length. + iterations = int(os.environ.get("ITERATIONS", 20000)) + warmdown_iters = int(os.environ.get("WARMDOWN_ITERS", 1200)) + warmup_steps = int(os.environ.get("WARMUP_STEPS", 20)) + train_batch_tokens = int(os.environ.get("TRAIN_BATCH_TOKENS", 524_288)) + train_seq_len = int(os.environ.get("TRAIN_SEQ_LEN", 1024)) + max_wallclock_seconds = float(os.environ.get("MAX_WALLCLOCK_SECONDS", 600.0)) + qk_gain_init = float(os.environ.get("QK_GAIN_INIT", 1.5)) + + # Model shape. + vocab_size = int(os.environ.get("VOCAB_SIZE", 1024)) + num_layers = int(os.environ.get("NUM_LAYERS", 9)) + num_kv_heads = int(os.environ.get("NUM_KV_HEADS", 4)) + model_dim = int(os.environ.get("MODEL_DIM", 512)) + num_heads = int(os.environ.get("NUM_HEADS", 8)) + mlp_mult = int(os.environ.get("MLP_MULT", 2)) + tie_embeddings = bool(int(os.environ.get("TIE_EMBEDDINGS", "1"))) + rope_base = float(os.environ.get("ROPE_BASE", 10000.0)) + logit_softcap = float(os.environ.get("LOGIT_SOFTCAP", 30.0)) + + # Recurrence + stability features. 
+ num_prelude = int(os.environ.get("NUM_PRELUDE", 0)) + num_coda = int(os.environ.get("NUM_CODA", 0)) + num_shared = int(os.environ.get("NUM_SHARED", 0)) + num_loops = int(os.environ.get("NUM_LOOPS", 2)) + use_peri_norm = bool(int(os.environ.get("USE_PERI_NORM", "0"))) + use_birkhoff_mix = bool(int(os.environ.get("USE_BIRKHOFF_MIX", "0"))) + use_timestep_scale = bool(int(os.environ.get("USE_TIMESTEP_SCALE", "0"))) + leaky_relu_slope = float(os.environ.get("LEAKY_RELU_SLOPE", "0.5")) + timestep_gamma_max = float(os.environ.get("TIMESTEP_GAMMA_MAX", "0.0")) + use_timestep_bias = bool(int(os.environ.get("USE_TIMESTEP_BIAS", "0"))) + use_depth_embed = bool(int(os.environ.get("USE_DEPTH_EMBED", "0"))) + use_unique_norms = bool(int(os.environ.get("USE_UNIQUE_NORMS", "0"))) + disable_compile = bool(int(os.environ.get("DISABLE_COMPILE", "0"))) + + # Optimizer hyperparameters. + embed_lr = float(os.environ.get("EMBED_LR", 0.6)) + head_lr = float(os.environ.get("HEAD_LR", 0.008)) + tied_embed_lr = float(os.environ.get("TIED_EMBED_LR", 0.05)) + tied_embed_init_std = float(os.environ.get("TIED_EMBED_INIT_STD", 0.005)) + matrix_lr = float(os.environ.get("MATRIX_LR", 0.04)) + scalar_lr = float(os.environ.get("SCALAR_LR", 0.04)) + muon_momentum = float(os.environ.get("MUON_MOMENTUM", 0.95)) + muon_backend_steps = int(os.environ.get("MUON_BACKEND_STEPS", 5)) + muon_momentum_warmup_start = float(os.environ.get("MUON_MOMENTUM_WARMUP_START", 0.85)) + muon_momentum_warmup_steps = int(os.environ.get("MUON_MOMENTUM_WARMUP_STEPS", 500)) + beta1 = float(os.environ.get("BETA1", 0.9)) + beta2 = float(os.environ.get("BETA2", 0.95)) + adam_eps = float(os.environ.get("ADAM_EPS", 1e-8)) + grad_clip_norm = float(os.environ.get("GRAD_CLIP_NORM", 0.0)) + +# ----------------------------- +# MUON OPTIMIZER +# ----------------------------- +# +# As borrowed from modded-nanogpt +# Background on Muon: https://kellerjordan.github.io/posts/muon/ + +def zeropower_via_newtonschulz5(G: Tensor, steps: int = 
10, eps: float = 1e-7) -> Tensor: + # Orthogonalize a 2D update matrix with a fast Newton-Schulz iteration. + # Muon uses this to normalize matrix-shaped gradients before applying them. + a, b, c = (3.4445, -4.7750, 2.0315) + X = G.bfloat16() + X /= X.norm() + eps + transposed = G.size(0) > G.size(1) + if transposed: + X = X.T + for _ in range(steps): + A = X @ X.T + B = b * A + c * A @ A + X = a * X + B @ X + return X.T if transposed else X + + +class Muon(torch.optim.Optimizer): + def __init__(self, params, lr: float, momentum: float, backend_steps: int, nesterov: bool = True): + super().__init__( + params, + dict(lr=lr, momentum=momentum, backend_steps=backend_steps, nesterov=nesterov), + ) + + @torch.no_grad() + def step(self, closure=None): + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + distributed = dist.is_available() and dist.is_initialized() + world_size = dist.get_world_size() if distributed else 1 + rank = dist.get_rank() if distributed else 0 + + for group in self.param_groups: + params = group["params"] + if not params: + continue + lr = group["lr"] + momentum = group["momentum"] + backend_steps = group["backend_steps"] + nesterov = group["nesterov"] + + total_params = sum(int(p.numel()) for p in params) + updates_flat = torch.zeros(total_params, device=params[0].device, dtype=torch.bfloat16) + + curr = 0 + for i, p in enumerate(params): + if i % world_size == rank and p.grad is not None: + g = p.grad + state = self.state[p] + if "momentum_buffer" not in state: + state["momentum_buffer"] = torch.zeros_like(g) + buf = state["momentum_buffer"] + buf.mul_(momentum).add_(g) + if nesterov: + g = g.add(buf, alpha=momentum) + g = zeropower_via_newtonschulz5(g, steps=backend_steps) + # Scale correction from Muon reference implementations. 
+ g *= max(1, g.size(0) / g.size(1)) ** 0.5 + updates_flat[curr : curr + p.numel()] = g.reshape(-1) + curr += p.numel() + + if distributed: + dist.all_reduce(updates_flat, op=dist.ReduceOp.SUM) + + curr = 0 + for p in params: + g = updates_flat[curr : curr + p.numel()].view_as(p).to(dtype=p.dtype) + p.add_(g, alpha=-lr) + curr += p.numel() + + return loss + + +# ----------------------------- +# TOKENIZER-AGNOSTIC EVALUATION SETUP +# ----------------------------- +# +# It's common for small models have a large fraction of their parameters be embeddings, since the 2 * d_model * d_vocab vectors can be gigantic. +# Instead of locking the tokenizer, we let you bring your own and calculate our validation metrics on the average compression of the validation set. +# We calculate BPB (bits-per-byte) instead of validation loss, so we need methods to count the number of bits per token in the tokenizer. +# Note: Submissions that edit the tokenizer will be examined more carefully, since screwing this up might unjustly improve your score. 
+ +def build_sentencepiece_luts( + sp: spm.SentencePieceProcessor, vocab_size: int, device: torch.device +) -> tuple[Tensor, Tensor, Tensor]: + sp_vocab_size = int(sp.vocab_size()) + table_size = max(sp_vocab_size, vocab_size) + base_bytes_np = np.zeros((table_size,), dtype=np.int16) + has_leading_space_np = np.zeros((table_size,), dtype=np.bool_) + is_boundary_token_np = np.ones((table_size,), dtype=np.bool_) + for token_id in range(sp_vocab_size): + if sp.is_control(token_id) or sp.is_unknown(token_id) or sp.is_unused(token_id): + continue + is_boundary_token_np[token_id] = False + if sp.is_byte(token_id): + base_bytes_np[token_id] = 1 + continue + piece = sp.id_to_piece(token_id) + if piece.startswith("▁"): + has_leading_space_np[token_id] = True + piece = piece[1:] + base_bytes_np[token_id] = len(piece.encode("utf-8")) + return ( + torch.tensor(base_bytes_np, dtype=torch.int16, device=device), + torch.tensor(has_leading_space_np, dtype=torch.bool, device=device), + torch.tensor(is_boundary_token_np, dtype=torch.bool, device=device), + ) + + +def load_validation_tokens(pattern: str, seq_len: int) -> Tensor: + files = [Path(p) for p in sorted(glob.glob(pattern))] + if not files: + raise FileNotFoundError(f"No files found for pattern: {pattern}") + # The export pipeline writes the fixed first-50k-doc validation set to fineweb_val_*. 
+    tokens = torch.cat([load_data_shard(file) for file in files]).contiguous()
+    usable = ((tokens.numel() - 1) // seq_len) * seq_len
+    if usable <= 0:
+        raise ValueError(f"Validation split is too short for TRAIN_SEQ_LEN={seq_len}")
+    # Keep usable+1 tokens so x = tokens[:-1], y = tokens[1:] align exactly.
+    return tokens[: usable + 1]
+
+
+def eval_val(
+    args: Hyperparameters,
+    model: nn.Module,
+    rank: int,
+    world_size: int,
+    device: torch.device,
+    grad_accum_steps: int,
+    val_tokens: Tensor,
+    base_bytes_lut: Tensor,
+    has_leading_space_lut: Tensor,
+    is_boundary_token_lut: Tensor,
+) -> tuple[float, float]:
+    # Validation computes two metrics:
+    # - val_loss: token cross-entropy (natural log)
+    # - val_bpb: tokenizer-agnostic compression metric used by the challenge
+    # Returns (val_loss, val_bpb) where val_bpb = bits_per_token * tokens_per_byte.
+    local_batch_tokens = args.val_batch_size // (world_size * grad_accum_steps)
+    if local_batch_tokens < args.train_seq_len:
+        raise ValueError(
+            "VAL_BATCH_SIZE must provide at least one sequence per rank; "
+            f"got VAL_BATCH_SIZE={args.val_batch_size}, WORLD_SIZE={world_size}, "
+            f"GRAD_ACCUM_STEPS={grad_accum_steps}, TRAIN_SEQ_LEN={args.train_seq_len}"
+        )
+    local_batch_seqs = local_batch_tokens // args.train_seq_len
+    # Partition the full validation sequence range across ranks (contiguous spans).
+    total_seqs = (val_tokens.numel() - 1) // args.train_seq_len
+    seq_start = (total_seqs * rank) // world_size
+    seq_end = (total_seqs * (rank + 1)) // world_size
+    # fp64 accumulators on-device so the all_reduce below is exact enough.
+    val_loss_sum = torch.zeros((), device=device, dtype=torch.float64)
+    val_token_count = torch.zeros((), device=device, dtype=torch.float64)
+    val_byte_count = torch.zeros((), device=device, dtype=torch.float64)
+
+    model.eval()
+    with torch.inference_mode():
+        for batch_seq_start in range(seq_start, seq_end, local_batch_seqs):
+            batch_seq_end = min(batch_seq_start + local_batch_seqs, seq_end)
+            raw_start = batch_seq_start * args.train_seq_len
+            raw_end = batch_seq_end * args.train_seq_len + 1
+            local = val_tokens[raw_start:raw_end].to(device=device, dtype=torch.int64, non_blocking=True)
+            x = local[:-1].reshape(-1, args.train_seq_len)
+            y = local[1:].reshape(-1, args.train_seq_len)
+            with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True):
+                batch_loss = model(x, y).detach()
+            # model() returns a mean loss; convert back to a sum for exact aggregation.
+            batch_token_count = float(y.numel())
+            val_loss_sum += batch_loss.to(torch.float64) * batch_token_count
+            val_token_count += batch_token_count
+            # Byte accounting: each target token contributes its piece's UTF-8 bytes,
+            # plus one space byte when it starts a word and the previous token emits text.
+            prev_ids = x.reshape(-1)
+            tgt_ids = y.reshape(-1)
+            token_bytes = base_bytes_lut[tgt_ids].to(dtype=torch.int16)
+            token_bytes += (has_leading_space_lut[tgt_ids] & ~is_boundary_token_lut[prev_ids]).to(dtype=torch.int16)
+            val_byte_count += token_bytes.to(torch.float64).sum()
+
+    if dist.is_available() and dist.is_initialized():
+        dist.all_reduce(val_loss_sum, op=dist.ReduceOp.SUM)
+        dist.all_reduce(val_token_count, op=dist.ReduceOp.SUM)
+        dist.all_reduce(val_byte_count, op=dist.ReduceOp.SUM)
+
+    val_loss = val_loss_sum / val_token_count
+    bits_per_token = val_loss.item() / math.log(2.0)
+    tokens_per_byte = val_token_count.item() / val_byte_count.item()
+    model.train()
+    return float(val_loss.item()), float(bits_per_token * tokens_per_byte)
+
+# -----------------------------
+# POST-TRAINING QUANTIZATION
+# -----------------------------
+#
+# It's silly to export our model, which is trained in bf16 and fp32, at that same precision.
+# Instead, we get approximately the same model (with a small hit) by quantizing the model to int8 & zlib compressing.
+# We can then decompress the model and run in higher precision for evaluation, after closing in under the size limit. 
+
+# Parameter-name substrings treated as "control" tensors: they stay fp32 at runtime
+# and are exported unquantized. Overridable via the CONTROL_TENSOR_NAME_PATTERNS env var.
+CONTROL_TENSOR_NAME_PATTERNS = tuple(
+    pattern
+    for pattern in os.environ.get(
+        "CONTROL_TENSOR_NAME_PATTERNS",
+        "attn_scale,attn_scales,mlp_scale,mlp_scales,resid_mix,resid_mixes,q_gain,skip_weight,skip_weights,ts_attn,ts_mlp,resid_mix_logit,ts_attn_beta,ts_mlp_beta,depth_embed,unique_attn_gain,unique_mlp_gain",
+    ).split(",")
+    if pattern
+)
+# Name patterns kept as fp32 floats in the int8 export (defaults to the control list).
+INT8_KEEP_FLOAT_FP32_NAME_PATTERNS = tuple(
+    pattern
+    for pattern in os.environ.get(
+        "INT8_KEEP_FLOAT_FP32_NAME_PATTERNS",
+        ",".join(CONTROL_TENSOR_NAME_PATTERNS),
+    ).split(",")
+    if pattern
+)
+INT8_KEEP_FLOAT_MAX_NUMEL = 65_536          # float tensors at or below this size skip quantization
+INT8_KEEP_FLOAT_STORE_DTYPE = torch.float16  # storage dtype for small passthrough floats
+INT8_PER_ROW_SCALE_DTYPE = torch.float16     # storage dtype for per-row scales
+INT8_CLIP_PERCENTILE = 99.99984              # clip outliers above this |value| percentile before rounding
+INT8_CLIP_Q = INT8_CLIP_PERCENTILE / 100.0
+
+def tensor_nbytes(t: Tensor) -> int:
+    # Raw payload size of a tensor in bytes (numel * element size).
+    return int(t.numel()) * int(t.element_size())
+
+def keep_float_tensor(name: str, t: Tensor, passthrough_orig_dtypes: dict[str, str]) -> Tensor:
+    # Decide how a small float tensor is stored in the export:
+    # control tensors stay fp32; other fp32/bf16 tensors are downcast to fp16
+    # (recording the original dtype in `passthrough_orig_dtypes`, which is MUTATED here);
+    # anything else (e.g. already fp16) passes through untouched.
+    if any(pattern in name for pattern in INT8_KEEP_FLOAT_FP32_NAME_PATTERNS):
+        return t.float().contiguous()
+    if t.dtype in {torch.float32, torch.bfloat16}:
+        passthrough_orig_dtypes[name] = str(t.dtype).removeprefix("torch.")
+        return t.to(dtype=INT8_KEEP_FLOAT_STORE_DTYPE).contiguous()
+    return t
+
+def quantize_float_tensor(t: Tensor) -> tuple[Tensor, Tensor]:
+    # Symmetric int8 quantization. Returns (q, scale) where dequant is q * scale;
+    # 2D tensors get a per-row scale vector, everything else a scalar scale.
+    t32 = t.float()
+    if t32.ndim == 2:
+        # Matrices get one scale per row, which usually tracks output-channel
+        # ranges much better than a single tensor-wide scale.
+        # NOTE(review): the numel()==0 branch returns torch.empty(), i.e. uninitialized
+        # scale values — harmless only because q is then also empty per row; confirm.
+        clip_abs = (
+            torch.quantile(t32.abs(), INT8_CLIP_Q, dim=1)
+            if t32.numel()
+            else torch.empty((t32.shape[0],), dtype=torch.float32)
+        )
+        clipped = torch.maximum(torch.minimum(t32, clip_abs[:, None]), -clip_abs[:, None])
+        # clamp_min keeps near-zero rows from producing degenerate/zero scales.
+        scale = (clip_abs / 127.0).clamp_min(1.0 / 127.0)
+        q = torch.clamp(torch.round(clipped / scale[:, None]), -127, 127).to(torch.int8).contiguous()
+        return q, scale.to(dtype=INT8_PER_ROW_SCALE_DTYPE).contiguous()
+
+    # Vectors / scalars use a simpler per-tensor scale. 
+    clip_abs = float(torch.quantile(t32.abs().flatten(), INT8_CLIP_Q).item()) if t32.numel() else 0.0
+    # A clip of 0 (empty or all-zero tensor) falls back to scale 1.0 so division is safe.
+    scale = torch.tensor(clip_abs / 127.0 if clip_abs > 0 else 1.0, dtype=torch.float32)
+    q = torch.clamp(torch.round(torch.clamp(t32, -clip_abs, clip_abs) / scale), -127, 127).to(torch.int8).contiguous()
+    return q, scale
+
+def quantize_state_dict_int8(state_dict: dict[str, Tensor]):
+    # Single supported clean-script export format:
+    # - per-row int8 for 2D float tensors
+    # - per-tensor int8 for other float tensors
+    # - exact passthrough for non-floats
+    # - passthrough for small float tensors, stored as fp16 to save bytes
+    # Returns (obj, stats): `obj` is the serializable container consumed by
+    # dequantize_state_dict_int8; `stats` carries tensor/byte counters for logging.
+    quantized: dict[str, Tensor] = {}
+    scales: dict[str, Tensor] = {}
+    dtypes: dict[str, str] = {}
+    passthrough: dict[str, Tensor] = {}
+    passthrough_orig_dtypes: dict[str, str] = {}
+    qmeta: dict[str, dict[str, object]] = {}
+    stats = dict.fromkeys(
+        ("param_count", "num_tensors", "num_float_tensors", "num_nonfloat_tensors", "baseline_tensor_bytes", "int8_payload_bytes"),
+        0,
+    )
+
+    for name, tensor in state_dict.items():
+        t = tensor.detach().to("cpu").contiguous()
+        stats["param_count"] += int(t.numel())
+        stats["num_tensors"] += 1
+        stats["baseline_tensor_bytes"] += tensor_nbytes(t)
+
+        if not t.is_floating_point():
+            # Integer/bool tensors are stored exactly as-is.
+            stats["num_nonfloat_tensors"] += 1
+            passthrough[name] = t
+            stats["int8_payload_bytes"] += tensor_nbytes(t)
+            continue
+
+        # Small float tensors are cheap enough to keep directly. We still downcast
+        # fp32/bf16 passthrough tensors to fp16 so metadata does not dominate size. 
+        if t.numel() <= INT8_KEEP_FLOAT_MAX_NUMEL:
+            # keep_float_tensor also records the original dtype (for fp32/bf16 downcasts)
+            # into passthrough_orig_dtypes as a side effect.
+            kept = keep_float_tensor(name, t, passthrough_orig_dtypes)
+            passthrough[name] = kept
+            stats["int8_payload_bytes"] += tensor_nbytes(kept)
+            continue
+
+        stats["num_float_tensors"] += 1
+        q, s = quantize_float_tensor(t)
+        if s.ndim > 0:
+            # Vector scale => the 2D per-row path was taken.
+            qmeta[name] = {"scheme": "per_row", "axis": 0}
+        quantized[name] = q
+        scales[name] = s
+        dtypes[name] = str(t.dtype).removeprefix("torch.")
+        stats["int8_payload_bytes"] += tensor_nbytes(q) + tensor_nbytes(s)
+
+    obj: dict[str, object] = {
+        "__quant_format__": "int8_clean_per_row_v1",
+        "quantized": quantized,
+        "scales": scales,
+        "dtypes": dtypes,
+        "passthrough": passthrough,
+    }
+    # Optional sections are omitted when empty to keep the payload minimal.
+    if qmeta:
+        obj["qmeta"] = qmeta
+    if passthrough_orig_dtypes:
+        obj["passthrough_orig_dtypes"] = passthrough_orig_dtypes
+    return obj, stats
+
+def dequantize_state_dict_int8(obj: dict[str, object]) -> dict[str, Tensor]:
+    # Inverse of quantize_state_dict_int8: rebuild a flat name -> Tensor state dict
+    # from the int8 payload, restoring each tensor's recorded dtype.
+    out: dict[str, Tensor] = {}
+    qmeta = obj.get("qmeta", {})
+    passthrough_orig_dtypes = obj.get("passthrough_orig_dtypes", {})
+    for name, q in obj["quantized"].items():
+        dtype = getattr(torch, obj["dtypes"][name])
+        s = obj["scales"][name]
+        # `s.ndim > 0` is a fallback detection for per-row scales when qmeta is absent.
+        if qmeta.get(name, {}).get("scheme") == "per_row" or s.ndim > 0:
+            s = s.to(dtype=torch.float32)
+            # Broadcast the saved row scale back across trailing dimensions.
+            out[name] = (q.float() * s.view(q.shape[0], *([1] * (q.ndim - 1)))).to(dtype=dtype).contiguous()
+        else:
+            scale = float(s.item())
+            out[name] = (q.float() * scale).to(dtype=dtype).contiguous()
+    for name, t in obj["passthrough"].items():
+        # Restore small tensors, undoing the temporary fp16 storage cast if needed. 
+        out_t = t.detach().to("cpu").contiguous()
+        orig_dtype = passthrough_orig_dtypes.get(name)
+        if isinstance(orig_dtype, str):
+            out_t = out_t.to(dtype=getattr(torch, orig_dtype)).contiguous()
+        out[name] = out_t
+    return out
+
+
+# -----------------------------
+# DATA LOADING
+# -----------------------------
+
+# NOTE(review): the next lines are corrupted — everything between the opening of the
+# dtype string (likely np.dtype("<i4") or similar) and the "-> None:" of
+# TokenStream._advance_file has been stripped, apparently by angle-bracket removal.
+# The rest of load_data_shard's body, the `class TokenStream:` header, and its
+# __init__ (which must set self.files, self.file_idx, self.tokens, self.pos, per the
+# surviving methods below) are missing. Restore this span from version control.
+def load_data_shard(file: Path) -> Tensor:
+    header_bytes = 256 * np.dtype(" None:
+        # Cycle to the next shard (wrapping around) and reset the read cursor.
+        self.file_idx = (self.file_idx + 1) % len(self.files)
+        self.tokens = load_data_shard(self.files[self.file_idx])
+        self.pos = 0
+
+    def take(self, n: int) -> Tensor:
+        # Return exactly n tokens from the stream, spanning shard boundaries as needed.
+        chunks: list[Tensor] = []
+        remaining = n
+        while remaining > 0:
+            avail = self.tokens.numel() - self.pos
+            if avail <= 0:
+                self._advance_file()
+                continue
+            k = min(remaining, avail)
+            chunks.append(self.tokens[self.pos : self.pos + k])
+            self.pos += k
+            remaining -= k
+        # Single-chunk fast path avoids a copy via torch.cat.
+        return chunks[0] if len(chunks) == 1 else torch.cat(chunks)
+
+
+class DistributedTokenLoader:
+    # Each call consumes a contiguous chunk from the shared token stream, then slices out
+    # one disjoint span per rank. The extra "+1" token lets us build (x, y) by shifting. 
+    def __init__(self, pattern: str, rank: int, world_size: int, device: torch.device):
+        # Every rank builds the same TokenStream over the same shard glob; disjointness
+        # comes from slicing per-rank spans out of the shared chunk in next_batch.
+        self.rank = rank
+        self.world_size = world_size
+        self.device = device
+        self.stream = TokenStream(pattern)
+
+    def next_batch(self, global_tokens: int, seq_len: int, grad_accum_steps: int) -> tuple[Tensor, Tensor]:
+        # One micro-batch: this rank's (inputs, shifted targets), each shaped
+        # (local_tokens // seq_len, seq_len). All ranks must call this in lockstep
+        # so their shared streams stay in sync.
+        local_tokens = global_tokens // (self.world_size * grad_accum_steps)
+        per_rank_span = local_tokens + 1
+        chunk = self.stream.take(per_rank_span * self.world_size)
+        start = self.rank * per_rank_span
+        local = chunk[start : start + per_rank_span].to(dtype=torch.int64)
+        x = local[:-1].reshape(-1, seq_len)
+        y = local[1:].reshape(-1, seq_len)
+        return x.to(self.device, non_blocking=True), y.to(self.device, non_blocking=True)
+
+# -----------------------------
+# TRANSFORMER MODULES
+# -----------------------------
+
+class RMSNorm(nn.Module):
+    # Weight-free RMS norm; eps=None defers to F.rms_norm's default.
+    def __init__(self, eps: float | None = None):
+        super().__init__()
+        self.eps = eps
+
+    def forward(self, x: Tensor) -> Tensor:
+        return F.rms_norm(x, (x.size(-1),), eps=self.eps)
+
+
+class CastedLinear(nn.Linear):
+    # Keep weights in fp32 for optimizer/state quality, cast at matmul time for bf16 compute.
+    def forward(self, x: Tensor) -> Tensor:
+        bias = self.bias.to(x.dtype) if self.bias is not None else None
+        return F.linear(x, self.weight.to(x.dtype), bias)
+
+
+def restore_low_dim_params_to_fp32(module: nn.Module) -> None:
+    # Keep small/control parameters in fp32 even when the model body runs in bf16.
+    # Intended to run after a module-wide .bfloat16() cast; mutates params in place.
+    with torch.no_grad():
+        for name, param in module.named_parameters():
+            if (param.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS)) and param.dtype != torch.float32:
+                param.data = param.data.float()
+
+
+class Rotary(nn.Module):
+    # Caches cos/sin tables per sequence length on the current device. 
+    def __init__(self, dim: int, base: float = 10000.0):
+        super().__init__()
+        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim))
+        self.register_buffer("inv_freq", inv_freq, persistent=False)
+        self._seq_len_cached = 0
+        self._cos_cached: Tensor | None = None
+        self._sin_cached: Tensor | None = None
+
+    def forward(self, seq_len: int, device: torch.device, dtype: torch.dtype) -> tuple[Tensor, Tensor]:
+        # Returns (cos, sin) shaped (1, 1, seq_len, dim/2), cast to `dtype`.
+        # Cache is keyed on seq_len and device only; the .to(dtype=...) on return
+        # handles dtype changes without invalidating the cache.
+        if (
+            self._cos_cached is None
+            or self._sin_cached is None
+            or self._seq_len_cached != seq_len
+            or self._cos_cached.device != device
+        ):
+            t = torch.arange(seq_len, device=device, dtype=self.inv_freq.dtype)
+            freqs = torch.outer(t, self.inv_freq.to(device))
+            self._cos_cached = freqs.cos()[None, None, :, :]
+            self._sin_cached = freqs.sin()[None, None, :, :]
+            self._seq_len_cached = seq_len
+        return self._cos_cached.to(dtype=dtype), self._sin_cached.to(dtype=dtype)
+
+
+def apply_rotary_emb(x: Tensor, cos: Tensor, sin: Tensor) -> Tensor:
+    # Rotate the two halves of the head dimension (half-split RoPE convention).
+    half = x.size(-1) // 2
+    x1, x2 = x[..., :half], x[..., half:]
+    return torch.cat((x1 * cos + x2 * sin, x1 * (-sin) + x2 * cos), dim=-1)
+
+
+class CausalSelfAttention(nn.Module):
+    # GQA attention with QK RMS-norm, RoPE, a learned per-head query gain, and a
+    # zero-initialized output projection (flagged via _zero_init for _init_weights).
+    def __init__(
+        self,
+        dim: int,
+        num_heads: int,
+        num_kv_heads: int,
+        rope_base: float,
+        qk_gain_init: float,
+    ):
+        super().__init__()
+        if dim % num_heads != 0:
+            raise ValueError("model_dim must be divisible by num_heads")
+        if num_heads % num_kv_heads != 0:
+            raise ValueError("num_heads must be divisible by num_kv_heads")
+        self.num_heads = num_heads
+        self.num_kv_heads = num_kv_heads
+        self.head_dim = dim // num_heads
+        if self.head_dim % 2 != 0:
+            raise ValueError("head_dim must be even for RoPE")
+        kv_dim = self.num_kv_heads * self.head_dim
+        self.c_q = CastedLinear(dim, dim, bias=False)
+        self.c_k = CastedLinear(dim, kv_dim, bias=False)
+        self.c_v = CastedLinear(dim, kv_dim, bias=False)
+        self.proj = CastedLinear(dim, dim, bias=False)
+        self.proj._zero_init = True
+        # Per-head query gain; name matches CONTROL_TENSOR_NAME_PATTERNS so it stays fp32.
+        self.q_gain = nn.Parameter(torch.full((num_heads,), qk_gain_init, dtype=torch.float32))
+        self.rotary = Rotary(self.head_dim, base=rope_base)
+
+    def forward(self, x: Tensor) -> Tensor:
+        bsz, seqlen, dim = x.shape
+        q = self.c_q(x).reshape(bsz, seqlen, self.num_heads, self.head_dim).transpose(1, 2)
+        k = self.c_k(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2)
+        v = self.c_v(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2)
+        # QK-norm stabilizes attention logits; gain is applied to q only.
+        q = F.rms_norm(q, (q.size(-1),))
+        k = F.rms_norm(k, (k.size(-1),))
+        cos, sin = self.rotary(seqlen, x.device, q.dtype)
+        q = apply_rotary_emb(q, cos, sin)
+        k = apply_rotary_emb(k, cos, sin)
+        q = q * self.q_gain.to(dtype=q.dtype)[None, :, None, None]
+        if FA3_AVAILABLE:
+            # FA3 expects (B, T, H, D) — transpose from (B, H, T, D)
+            y = fa3.flash_attn_func(
+                q.transpose(1, 2),
+                k.transpose(1, 2),
+                v.transpose(1, 2),
+                causal=True,
+            ).transpose(1, 2)
+        else:
+            y = F.scaled_dot_product_attention(
+                q, k, v,
+                attn_mask=None,
+                is_causal=True,
+                enable_gqa=(self.num_kv_heads != self.num_heads),
+            )
+        y = y.transpose(1, 2).contiguous().reshape(bsz, seqlen, dim)
+        return self.proj(y)
+
+
+class MLP(nn.Module):
+    # leaky_relu^2 MLP (slope=0.5 matches SOTA, PRs #493/#518)
+    def __init__(self, dim: int, mlp_mult: int, leaky_relu_slope: float = 0.5):
+        super().__init__()
+        hidden = mlp_mult * dim
+        self.fc = CastedLinear(dim, hidden, bias=False)
+        self.proj = CastedLinear(hidden, dim, bias=False)
+        self.proj._zero_init = True
+        self.leaky_relu_slope = leaky_relu_slope
+
+    def forward(self, x: Tensor) -> Tensor:
+        # Activation is leaky_relu(x) squared (not a plain ReLU^2).
+        x = F.leaky_relu(self.fc(x), negative_slope=self.leaky_relu_slope)
+        return self.proj(x.square())
+
+
+class TimestepScaling(nn.Module):
+    """Learned per-iteration scale vectors (arXiv:2410.01405, ICML 2025)."""
+    # One (gamma, optional beta) pair per *virtual* layer position v, so recurrent
+    # loops over shared blocks can still differentiate depth positions.
+    def __init__(self, effective_layers: int, dim: int, gamma_max: float = 0.0, use_bias: bool = False):
+        super().__init__()
+        self.attn_gamma = nn.Parameter(torch.ones(effective_layers, dim, dtype=torch.float32))
+        self.mlp_gamma = nn.Parameter(torch.ones(effective_layers, dim, dtype=torch.float32))
+        self.gamma_max = gamma_max # 0 = uncapped
+        if use_bias:
+            self.attn_beta = nn.Parameter(torch.zeros(effective_layers, dim, dtype=torch.float32))
+            self.mlp_beta = nn.Parameter(torch.zeros(effective_layers, dim, dtype=torch.float32))
+        else:
+            self.attn_beta = None
+            self.mlp_beta = None
+
+    def get(self, v: int) -> tuple[Tensor, Tensor, Tensor | None, Tensor | None]:
+        # Returns (attn_gamma, mlp_gamma, attn_beta, mlp_beta) for virtual layer v;
+        # gammas are symmetrically clamped when gamma_max > 0.
+        ag = self.attn_gamma[v]
+        mg = self.mlp_gamma[v]
+        if self.gamma_max > 0:
+            ag = ag.clamp(-self.gamma_max, self.gamma_max)
+            mg = mg.clamp(-self.gamma_max, self.gamma_max)
+        ab = self.attn_beta[v] if self.attn_beta is not None else None
+        mb = self.mlp_beta[v] if self.mlp_beta is not None else None
+        return ag, mg, ab, mb
+
+
+class Block(nn.Module):
+    # Pre-norm transformer block with learned residual mixing against the embedding
+    # stream x0: either a sigmoid-gated convex ("Birkhoff") mix or a free 2-vector mix.
+    def __init__(
+        self,
+        dim: int,
+        num_heads: int,
+        num_kv_heads: int,
+        mlp_mult: int,
+        rope_base: float,
+        qk_gain_init: float,
+        use_peri_norm: bool = False,
+        use_birkhoff_mix: bool = False,
+        leaky_relu_slope: float = 0.5,
+    ):
+        super().__init__()
+        self.use_peri_norm = use_peri_norm
+        self.use_birkhoff_mix = use_birkhoff_mix
+        self.attn_norm = RMSNorm()
+        if use_peri_norm:
+            self.mlp_out_norm = RMSNorm() # output norm: bounds MLP contribution, preserves input magnitude
+        else:
+            self.mlp_norm = RMSNorm()
+        self.attn = CausalSelfAttention(dim, num_heads, num_kv_heads, rope_base, qk_gain_init)
+        self.mlp = MLP(dim, mlp_mult, leaky_relu_slope)
+        self.attn_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32))
+        self.mlp_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32))
+        if use_birkhoff_mix:
+            # sigmoid(0) = 0.5 => starts as an even x/x0 blend.
+            self.resid_mix_logit = nn.Parameter(torch.zeros(dim, dtype=torch.float32))
+        else:
+            # Row 0 scales x, row 1 scales x0; init (1, 0) = identity pass-through.
+            self.resid_mix = nn.Parameter(torch.stack((torch.ones(dim), torch.zeros(dim))).float())
+
+    def forward(self, x: Tensor, x0: Tensor,
+                ts_attn_gamma: Tensor | None = None,
+                ts_mlp_gamma: Tensor | None = None,
+                ts_attn_beta: Tensor | None = None,
+                ts_mlp_beta: Tensor | None = None,
+                depth_emb: Tensor | None = None,
+                ext_attn_gain: Tensor | None = None,
+                ext_mlp_gain: Tensor | None = None) -> Tensor:
+        # Optional per-virtual-layer modulation: ts_* come from TimestepScaling,
+        # depth_emb is an additive depth embedding, ext_*_gain are per-loop-iteration
+        # gains applied after normalization (GPT's "unique norms").
+        if self.use_birkhoff_mix:
+            alpha = torch.sigmoid(self.resid_mix_logit.to(dtype=x.dtype))[None, None, :]
+            x = alpha * x + (1 - alpha) * x0
+        else:
+            mix = self.resid_mix.to(dtype=x.dtype)
+            x = mix[0][None, None, :] * x + mix[1][None, None, :] * x0
+        if depth_emb is not None:
+            x = x + depth_emb
+        attn_normed = self.attn_norm(x)
+        if ext_attn_gain is not None:
+            attn_normed = attn_normed * ext_attn_gain[None, None, :]
+        attn_out = self.attn(attn_normed)
+        attn_s = self.attn_scale.to(dtype=x.dtype)[None, None, :]
+        if ts_attn_gamma is not None:
+            attn_s = attn_s * ts_attn_gamma[None, None, :]
+        x = x + attn_s * attn_out
+        if ts_attn_beta is not None:
+            x = x + ts_attn_beta[None, None, :]
+        if self.use_peri_norm:
+            # Peri-norm: MLP input normalization only happens when an external gain is
+            # supplied; the MLP output is always re-normalized (mlp_out_norm).
+            if ext_mlp_gain is not None:
+                m_input = F.rms_norm(x, (x.size(-1),)) * ext_mlp_gain[None, None, :]
+            else:
+                m_input = x
+            mlp_out = self.mlp_out_norm(self.mlp(m_input))
+        else:
+            mlp_normed = self.mlp_norm(x)
+            if ext_mlp_gain is not None:
+                mlp_normed = mlp_normed * ext_mlp_gain[None, None, :]
+            mlp_out = self.mlp(mlp_normed)
+        mlp_s = self.mlp_scale.to(dtype=x.dtype)[None, None, :]
+        if ts_mlp_gamma is not None:
+            mlp_s = mlp_s * ts_mlp_gamma[None, None, :]
+        x = x + mlp_s * mlp_out
+        if ts_mlp_beta is not None:
+            x = x + ts_mlp_beta[None, None, :]
+        return x
+
+
+class GPT(nn.Module):
+    # Two architectures behind one class:
+    # - recurrence (num_shared > 0): prelude blocks, then num_loops passes over the
+    #   shared blocks, then coda blocks, with optional per-virtual-layer modulation;
+    # - otherwise: a standard U-Net-style encoder/decoder stack with learned skips.
+    def __init__(
+        self,
+        vocab_size: int,
+        num_layers: int,
+        model_dim: int,
+        num_heads: int,
+        num_kv_heads: int,
+        mlp_mult: int,
+        tie_embeddings: bool,
+        tied_embed_init_std: float,
+        logit_softcap: float,
+        rope_base: float,
+        qk_gain_init: float,
+        num_prelude: int = 0,
+        num_coda: int = 0,
+        num_shared: int = 0,
+        num_loops: int = 2,
+        use_peri_norm: bool = False,
+        use_birkhoff_mix: bool = False,
+        use_timestep_scale: bool = False,
+        use_timestep_bias: bool = False,
+        use_depth_embed: bool = False,
+        use_unique_norms: bool = False,
+        timestep_gamma_max: float = 0.0,
+        leaky_relu_slope: float = 0.5,
+    ):
+        super().__init__()
+        if logit_softcap <= 0.0:
+            raise ValueError(f"logit_softcap must be positive, got {logit_softcap}")
+        self.tie_embeddings = tie_embeddings
+        self.tied_embed_init_std = tied_embed_init_std
+        self.logit_softcap = logit_softcap
+        self.tok_emb = nn.Embedding(vocab_size, model_dim)
+        # Recurrence options are all silently neutralized when num_shared == 0.
+        self.use_recurrence = num_shared > 0
+        self.num_loops = num_loops if self.use_recurrence else 1
+        self.num_prelude = num_prelude if self.use_recurrence else 0
+        self.num_coda = num_coda if self.use_recurrence else 0
+
+        block_kwargs = dict(
+            dim=model_dim, num_heads=num_heads, num_kv_heads=num_kv_heads,
+            mlp_mult=mlp_mult, rope_base=rope_base, qk_gain_init=qk_gain_init,
+            use_peri_norm=use_peri_norm if self.use_recurrence else False,
+            use_birkhoff_mix=use_birkhoff_mix if self.use_recurrence else False,
+            leaky_relu_slope=leaky_relu_slope,
+        )
+
+        if self.use_recurrence:
+            self.prelude_blocks = nn.ModuleList([Block(**block_kwargs) for _ in range(self.num_prelude)])
+            self.coda_blocks = nn.ModuleList([Block(**block_kwargs) for _ in range(self.num_coda)])
+            self.shared_blocks = nn.ModuleList([Block(**block_kwargs) for _ in range(num_shared)])
+            # "effective_layers" counts virtual depth positions: shared blocks are
+            # visited num_loops times, each visit getting its own modulation slot.
+            effective_layers = self.num_prelude + num_shared * self.num_loops + self.num_coda
+            self.timestep_scale = TimestepScaling(effective_layers, model_dim, gamma_max=timestep_gamma_max, use_bias=use_timestep_bias) if use_timestep_scale else None
+            self.use_depth_embed = use_depth_embed
+            if self.use_depth_embed:
+                self.depth_embeddings = nn.Parameter(torch.zeros(effective_layers, model_dim, dtype=torch.float32))
+            self.use_unique_norms = use_unique_norms
+            if self.use_unique_norms:
+                # One gain pair per (loop, shared block) visit.
+                num_unique = num_shared * self.num_loops
+                self.unique_attn_gains = nn.Parameter(torch.ones(num_unique, model_dim, dtype=torch.float32))
+                self.unique_mlp_gains = nn.Parameter(torch.ones(num_unique, model_dim, dtype=torch.float32))
+        else:
+            # Standard U-Net path
+            self.num_encoder_layers = num_layers // 2
+            self.num_decoder_layers = num_layers - self.num_encoder_layers
+            self.num_skip_weights = min(self.num_encoder_layers, self.num_decoder_layers)
+            self.skip_weights = nn.Parameter(torch.ones(self.num_skip_weights, model_dim, dtype=torch.float32))
+            self.blocks = nn.ModuleList([Block(**block_kwargs) for _ in range(num_layers)])
+            self.timestep_scale = None
+            self.use_depth_embed = False
+            self.use_unique_norms = False
+
+        self.final_norm = RMSNorm()
+        self.lm_head = None if tie_embeddings else CastedLinear(model_dim, vocab_size, bias=False)
+        if self.lm_head is not None:
+            self.lm_head._zero_init = True
+        self._init_weights()
+
+    def _init_weights(self) -> None:
+        # Tied embedding gets a custom std; all _zero_init-flagged linears (output
+        # projections and untied lm_head) start at exactly zero.
+        if self.tie_embeddings:
+            nn.init.normal_(self.tok_emb.weight, mean=0.0, std=self.tied_embed_init_std)
+        for module in self.modules():
+            if isinstance(module, nn.Linear) and getattr(module, "_zero_init", False):
+                nn.init.zeros_(module.weight)
+
+    def _get_ts(self, v: int) -> tuple[Tensor | None, Tensor | None, Tensor | None, Tensor | None]:
+        # Timestep modulation for virtual layer v, or all-None when disabled.
+        if self.timestep_scale is None:
+            return None, None, None, None
+        return self.timestep_scale.get(v)
+
+    def forward(self, input_ids: Tensor, target_ids: Tensor) -> Tensor:
+        # Returns the mean cross-entropy over all target tokens (see eval_val, which
+        # converts this back to a sum for exact distributed aggregation).
+        x = self.tok_emb(input_ids)
+        x = F.rms_norm(x, (x.size(-1),))
+        x0 = x
+
+        if self.use_recurrence:
+            # v indexes virtual layer positions; uid indexes unique (loop, block) visits.
+            v = 0
+            for block in self.prelude_blocks:
+                ag, mg, ab, mb = self._get_ts(v)
+                de = self.depth_embeddings[v] if self.use_depth_embed else None
+                x = block(x, x0, ag, mg, ab, mb, de)
+                v += 1
+            uid = 0
+            for _loop in range(self.num_loops):
+                for block in self.shared_blocks:
+                    ag, mg, ab, mb = self._get_ts(v)
+                    de = self.depth_embeddings[v] if self.use_depth_embed else None
+                    if self.use_unique_norms:
+                        ag_n = self.unique_attn_gains[uid].to(dtype=x.dtype)
+                        mg_n = self.unique_mlp_gains[uid].to(dtype=x.dtype)
+                        x = block(x, x0, ag, mg, ab, mb, de, ag_n, mg_n)
+                    else:
+                        x = block(x, x0, ag, mg, ab, mb, de)
+                    uid += 1
+                    v += 1
+            for block in self.coda_blocks:
+                ag, mg, ab, mb = self._get_ts(v)
+                de = self.depth_embeddings[v] if self.use_depth_embed else None
+                x = block(x, x0, ag, mg, ab, mb, de)
+                v += 1
+        else:
+            # U-Net: encoder activations are pushed, then popped (LIFO) into the
+            # matching decoder layers with learned per-channel skip weights.
+            skips: list[Tensor] = []
+            for i in range(self.num_encoder_layers):
+                x = self.blocks[i](x, x0)
+                skips.append(x)
+            for i in range(self.num_decoder_layers):
+                if skips:
+                    x = x + self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips.pop()
+                x = self.blocks[self.num_encoder_layers + i](x, x0)
+
+        x = self.final_norm(x).reshape(-1, x.size(-1))
+        targets = target_ids.reshape(-1)
+        if self.tie_embeddings:
+            logits_proj = F.linear(x, self.tok_emb.weight)
+        else:
+            if self.lm_head is None:
+                raise RuntimeError("lm_head is required when tie_embeddings=False")
+            logits_proj = self.lm_head(x)
+        # tanh soft-capping keeps logits in (-softcap, softcap).
+        logits = self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap)
+        return F.cross_entropy(logits.float(), targets, reduction="mean")
+
+
+@torch.no_grad()
+def recurrence_param_diagnostics(gpt: GPT) -> str:
+    """Log parameter norms for recurrence health monitoring. No forward pass."""
+    if not gpt.use_recurrence:
+        return ""
+    parts: list[str] = []
+
+    # Birkhoff alpha stats per shared block
+    for i, block in enumerate(gpt.shared_blocks):
+        if hasattr(block, "resid_mix_logit"):
+            a = torch.sigmoid(block.resid_mix_logit)
+            parts.append(f"shared{i}_alpha:mean={a.mean().item():.3f},std={a.std().item():.3f}")
+
+    # Effective MLP/attn contribution scale per virtual layer position:
+    # mlp_scale * ts_mlp_gamma[v] is the learned multiplier on MLP output.
+    # We report RMS (norm / sqrt(numel)) to make it scale-independent. 
+    v = 0
+    effective_count = gpt.num_prelude + len(gpt.shared_blocks) * gpt.num_loops + gpt.num_coda
+    mlp_norms: list[str] = []
+    attn_norms: list[str] = []
+
+    # Prelude blocks
+    for block in gpt.prelude_blocks:
+        ms = block.mlp_scale.norm().item()
+        asc = block.attn_scale.norm().item()
+        d = block.mlp_scale.numel() ** 0.5
+        if gpt.timestep_scale is not None:
+            # Report the combined block-scale * timestep-gamma multiplier, as RMS.
+            ms_g = gpt.timestep_scale.mlp_gamma[v].norm().item()
+            as_g = gpt.timestep_scale.attn_gamma[v].norm().item()
+            mlp_norms.append(f"v{v}:{ms * ms_g / d:.4f}")
+            attn_norms.append(f"v{v}:{asc * as_g / d:.4f}")
+        else:
+            mlp_norms.append(f"v{v}:{ms / d:.4f}")
+            attn_norms.append(f"v{v}:{asc / d:.4f}")
+        v += 1
+
+    # Shared positions
+    for _loop in range(gpt.num_loops):
+        for block in gpt.shared_blocks:
+            ms = block.mlp_scale.norm().item()
+            asc = block.attn_scale.norm().item()
+            d = block.mlp_scale.numel() ** 0.5
+            if gpt.timestep_scale is not None:
+                ms_g = gpt.timestep_scale.mlp_gamma[v].norm().item()
+                as_g = gpt.timestep_scale.attn_gamma[v].norm().item()
+                mlp_norms.append(f"v{v}:{ms * ms_g / d:.4f}")
+                attn_norms.append(f"v{v}:{asc * as_g / d:.4f}")
+            else:
+                mlp_norms.append(f"v{v}:{ms / d:.4f}")
+                attn_norms.append(f"v{v}:{asc / d:.4f}")
+            v += 1
+
+    # Coda blocks
+    for block in gpt.coda_blocks:
+        ms = block.mlp_scale.norm().item()
+        asc = block.attn_scale.norm().item()
+        d = block.mlp_scale.numel() ** 0.5
+        if gpt.timestep_scale is not None:
+            ms_g = gpt.timestep_scale.mlp_gamma[v].norm().item()
+            as_g = gpt.timestep_scale.attn_gamma[v].norm().item()
+            mlp_norms.append(f"v{v}:{ms * ms_g / d:.4f}")
+            attn_norms.append(f"v{v}:{asc * as_g / d:.4f}")
+        else:
+            mlp_norms.append(f"v{v}:{ms / d:.4f}")
+            attn_norms.append(f"v{v}:{asc / d:.4f}")
+        v += 1
+
+    parts.append("eff_mlp_scale:[" + " ".join(mlp_norms) + "]")
+    parts.append("eff_attn_scale:[" + " ".join(attn_norms) + "]")
+    if gpt.timestep_scale is not None and gpt.timestep_scale.attn_beta is not None:
+        attn_bias_norms: list[str] = []
+        mlp_bias_norms: list[str] = []
+        for vi in range(effective_count):
+            ab_rms = gpt.timestep_scale.attn_beta[vi].norm().item() / gpt.timestep_scale.attn_beta[vi].numel() ** 0.5
+            mb_rms = gpt.timestep_scale.mlp_beta[vi].norm().item() / gpt.timestep_scale.mlp_beta[vi].numel() ** 0.5
+            attn_bias_norms.append(f"v{vi}:{ab_rms:.4f}")
+            mlp_bias_norms.append(f"v{vi}:{mb_rms:.4f}")
+        parts.append("eff_attn_bias:[" + " ".join(attn_bias_norms) + "]")
+        parts.append("eff_mlp_bias:[" + " ".join(mlp_bias_norms) + "]")
+    if gpt.use_unique_norms:
+        un_attn: list[str] = []
+        un_mlp: list[str] = []
+        for ui in range(gpt.unique_attn_gains.size(0)):
+            an_rms = gpt.unique_attn_gains[ui].norm().item() / gpt.unique_attn_gains[ui].numel() ** 0.5
+            un_attn.append(f"u{ui}:{an_rms:.4f}")
+            mn_rms = gpt.unique_mlp_gains[ui].norm().item() / gpt.unique_mlp_gains[ui].numel() ** 0.5
+            un_mlp.append(f"u{ui}:{mn_rms:.4f}")
+        parts.append("unique_attn_gain_rms:[" + " ".join(un_attn) + "]")
+        parts.append("unique_mlp_gain_rms:[" + " ".join(un_mlp) + "]")
+    if gpt.use_depth_embed:
+        de_norms: list[str] = []
+        for vi in range(effective_count):
+            de_rms = gpt.depth_embeddings[vi].norm().item() / gpt.depth_embeddings[vi].numel() ** 0.5
+            de_norms.append(f"v{vi}:{de_rms:.4f}")
+        parts.append("depth_emb_rms:[" + " ".join(de_norms) + "]")
+    return " ".join(parts)
+
+
+# -----------------------------
+# TRAINING
+# -----------------------------
+
+def main() -> None:
+    # Entry point: distributed setup, tokenizer/data checks, model + optimizer
+    # construction, then the training loop (continues below this section).
+    # NOTE: zeropower_via_newtonschulz5, Muon, and the Hyperparameters fields used
+    # here are defined elsewhere in this file.
+    global zeropower_via_newtonschulz5
+
+    code = Path(__file__).read_text(encoding="utf-8")
+    args = Hyperparameters()
+    zeropower_via_newtonschulz5 = torch.compile(zeropower_via_newtonschulz5)
+
+    # -----------------------------
+    # DISTRIBUTED + CUDA SETUP
+    # -----------------------------
+
+    distributed = "RANK" in os.environ and "WORLD_SIZE" in os.environ
+    rank = int(os.environ.get("RANK", "0"))
+    world_size = int(os.environ.get("WORLD_SIZE", "1"))
+    local_rank = int(os.environ.get("LOCAL_RANK", "0"))
+    if world_size <= 0:
+        raise ValueError(f"WORLD_SIZE must be positive, got {world_size}")
+    # The global batch is fixed at 8 micro-batches total, split between data
+    # parallelism (world_size) and gradient accumulation.
+    if 8 % world_size != 0:
+        raise ValueError(f"WORLD_SIZE={world_size} must divide 8 so grad_accum_steps stays integral")
+    grad_accum_steps = 8 // world_size
+    grad_scale = 1.0 / grad_accum_steps
+    if not torch.cuda.is_available():
+        raise RuntimeError("CUDA is required")
+    device = torch.device("cuda", local_rank)
+    torch.cuda.set_device(device)
+    if distributed:
+        dist.init_process_group(backend="nccl", device_id=device)
+        dist.barrier()
+    master_process = rank == 0
+
+    # Fast math knobs
+    torch.backends.cuda.matmul.allow_tf32 = True
+    torch.backends.cudnn.allow_tf32 = True
+    from torch.backends.cuda import enable_cudnn_sdp, enable_flash_sdp, enable_math_sdp, enable_mem_efficient_sdp
+
+    # Force the flash SDP backend only (deterministic backend choice).
+    enable_cudnn_sdp(False)
+    enable_flash_sdp(True)
+    enable_mem_efficient_sdp(False)
+    enable_math_sdp(False)
+
+    logfile = None
+    if master_process:
+        os.makedirs("logs", exist_ok=True)
+        logfile = f"logs/{args.run_id}.txt"
+        print(logfile)
+
+    def log0(msg: str, console: bool = True) -> None:
+        # Rank-0-only logging; always appends to the logfile, optionally echoes.
+        if not master_process:
+            return
+        if console:
+            print(msg)
+        if logfile is not None:
+            with open(logfile, "a", encoding="utf-8") as f:
+                print(msg, file=f)
+
+    # Log the full source and environment for reproducibility (file only).
+    log0(code, console=False)
+    log0("=" * 100, console=False)
+    log0(f"Running Python {sys.version}", console=False)
+    log0(f"Running PyTorch {torch.__version__}", console=False)
+    log0(
+        subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, check=False).stdout,
+        console=False,
+    )
+    log0("=" * 100, console=False)
+
+    # -----------------------------
+    # TOKENIZER + VALIDATION METRIC SETUP
+    # -----------------------------
+
+    random.seed(args.seed)
+    np.random.seed(args.seed)
+    torch.manual_seed(args.seed)
+    torch.cuda.manual_seed_all(args.seed)
+
+    if not args.tokenizer_path.endswith(".model"):
+        raise ValueError(f"Script only setup for SentencePiece .model file: {args.tokenizer_path}")
+    sp = spm.SentencePieceProcessor(model_file=args.tokenizer_path)
+    if int(sp.vocab_size()) != args.vocab_size:
+        raise ValueError(
+            f"VOCAB_SIZE={args.vocab_size} does not match tokenizer vocab_size={int(sp.vocab_size())}"
+        )
+    dataset_dir = Path(args.data_path).resolve()
+    actual_train_files = len(list(dataset_dir.glob("fineweb_train_*.bin")))
+    val_tokens = load_validation_tokens(args.val_files, args.train_seq_len)
+    base_bytes_lut, has_leading_space_lut, is_boundary_token_lut = build_sentencepiece_luts(
+        sp, args.vocab_size, device
+    )
+    log0(f"val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path={args.tokenizer_path}")
+    log0(f"train_loader:dataset:{dataset_dir.name} train_shards:{actual_train_files}")
+    log0(f"val_loader:shards pattern={args.val_files} tokens:{val_tokens.numel() - 1}")
+
+    # -----------------------------
+    # MODEL + OPTIMIZER SETUP
+    # -----------------------------
+
+    base_model = GPT(
+        vocab_size=args.vocab_size,
+        num_layers=args.num_layers,
+        model_dim=args.model_dim,
+        num_heads=args.num_heads,
+        num_kv_heads=args.num_kv_heads,
+        mlp_mult=args.mlp_mult,
+        tie_embeddings=args.tie_embeddings,
+        tied_embed_init_std=args.tied_embed_init_std,
+        logit_softcap=args.logit_softcap,
+        rope_base=args.rope_base,
+        qk_gain_init=args.qk_gain_init,
+        num_prelude=args.num_prelude,
+        num_coda=args.num_coda,
+        num_shared=args.num_shared,
+        num_loops=args.num_loops,
+        use_peri_norm=args.use_peri_norm,
+        use_birkhoff_mix=args.use_birkhoff_mix,
+        use_timestep_scale=args.use_timestep_scale,
+        use_timestep_bias=args.use_timestep_bias,
+        use_depth_embed=args.use_depth_embed,
+        use_unique_norms=args.use_unique_norms,
+        timestep_gamma_max=args.timestep_gamma_max,
+        leaky_relu_slope=args.leaky_relu_slope,
+    ).to(device).bfloat16()
+    # Mixed precision layout: body bf16, CastedLinear weights fp32 (cast at matmul),
+    # and small/control params restored to fp32.
+    for module in base_model.modules():
+        if isinstance(module, CastedLinear):
+            module.float()
+    restore_low_dim_params_to_fp32(base_model)
+    if args.disable_compile:
+        compiled_model = base_model
+    else:
+        compiled_model = torch.compile(base_model, dynamic=False, fullgraph=True)
+    model: nn.Module = DDP(compiled_model, device_ids=[local_rank], broadcast_buffers=False) if distributed else compiled_model
+
+    # Optimizer split:
+    # - token embedding (Adam) uses EMBED_LR
+    # - untied lm_head (Adam) uses HEAD_LR
+    # - matrix params in transformer blocks use MATRIX_LR via Muon
+    # - vectors/scalars use SCALAR_LR via Adam
+    if base_model.use_recurrence:
+        all_block_lists = [base_model.prelude_blocks, base_model.shared_blocks, base_model.coda_blocks]
+        block_named_params = []
+        for bl in all_block_lists:
+            block_named_params.extend(bl.named_parameters())
+        if base_model.use_unique_norms:
+            block_named_params.extend([("unique_attn_gains", base_model.unique_attn_gains)])
+            block_named_params.extend([("unique_mlp_gains", base_model.unique_mlp_gains)])
+    else:
+        block_named_params = list(base_model.blocks.named_parameters())
+    # Control tensors are routed to the Adam "scalar" group even when 2D.
+    matrix_params = [
+        p
+        for name, p in block_named_params
+        if p.ndim == 2 and not any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS)
+    ]
+    scalar_params = [
+        p
+        for name, p in block_named_params
+        if p.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS)
+    ]
+    if hasattr(base_model, "skip_weights") and base_model.skip_weights.numel() > 0:
+        scalar_params.append(base_model.skip_weights)
+    if base_model.timestep_scale is not None:
+        for p in base_model.timestep_scale.parameters():
+            scalar_params.append(p)
+    if base_model.use_depth_embed:
+        scalar_params.append(base_model.depth_embeddings)
+    token_lr = args.tied_embed_lr if args.tie_embeddings else args.embed_lr
+    # "base_lr" is stored per-group so the LR schedule can rescale from a fixed base.
+    optimizer_tok = torch.optim.Adam(
+        [{"params": [base_model.tok_emb.weight], "lr": token_lr, "base_lr": token_lr}],
+        betas=(args.beta1, args.beta2),
+        eps=args.adam_eps,
+        fused=True,
+    )
+    optimizer_muon = Muon(
+        matrix_params,
+        lr=args.matrix_lr,
+        momentum=args.muon_momentum,
+        backend_steps=args.muon_backend_steps,
+    )
+    for group in optimizer_muon.param_groups:
+        group["base_lr"] = args.matrix_lr
+    optimizer_scalar = torch.optim.Adam(
+        [{"params": scalar_params, "lr": args.scalar_lr, "base_lr": args.scalar_lr}],
+        betas=(args.beta1, args.beta2),
+        eps=args.adam_eps,
+        fused=True,
+    )
+    optimizers: list[torch.optim.Optimizer] = [optimizer_tok, optimizer_muon, optimizer_scalar]
+    if base_model.lm_head is not None:
+        optimizer_head = torch.optim.Adam(
+            [{"params": [base_model.lm_head.weight], "lr": args.head_lr, "base_lr": args.head_lr}],
+            betas=(args.beta1, args.beta2),
+            eps=args.adam_eps,
+            fused=True,
+        )
+        optimizers.insert(1, optimizer_head)
+
+    n_params = sum(p.numel() for p in base_model.parameters())
+    log0(f"model_params:{n_params}")
+    log0(f"world_size:{world_size} grad_accum_steps:{grad_accum_steps}")
+    log0("sdp_backends:cudnn=False flash=True mem_efficient=False math=False")
+    log0(f"flash_attention_3:{FA3_AVAILABLE}")
+    log0(f"attention_mode:gqa num_heads:{args.num_heads} num_kv_heads:{args.num_kv_heads}")
+    log0(
+        f"tie_embeddings:{args.tie_embeddings} embed_lr:{token_lr} "
+        f"head_lr:{args.head_lr if base_model.lm_head is not None else 0.0} "
+        f"matrix_lr:{args.matrix_lr} scalar_lr:{args.scalar_lr}"
+    )
+    log0(
+        f"train_batch_tokens:{args.train_batch_tokens} train_seq_len:{args.train_seq_len} "
+        f"iterations:{args.iterations} warmup_steps:{args.warmup_steps} "
+        f"max_wallclock_seconds:{args.max_wallclock_seconds:.3f}"
+    )
+    log0(f"seed:{args.seed}")
+    if base_model.use_recurrence:
+        num_shared = len(base_model.shared_blocks)
+        eff = base_model.num_prelude + num_shared * base_model.num_loops + base_model.num_coda
+        log0(f"recurrence:enabled prelude:{base_model.num_prelude} shared:{num_shared} "
+             f"loops:{base_model.num_loops} coda:{base_model.num_coda} effective_layers:{eff}")
+        log0(f"peri_norm:{args.use_peri_norm} birkhoff_mix:{args.use_birkhoff_mix}")
+        if base_model.timestep_scale is not None:
+            ts_params = sum(p.numel() for p in base_model.timestep_scale.parameters())
+            log0(f"timestep_scale:enabled params:{ts_params}")
+        else:
+            log0("timestep_scale:disabled")
+        log0(f"depth_embed:{'enabled' if base_model.use_depth_embed else 'disabled'}")
+        log0(f"unique_norms:{'enabled' if base_model.use_unique_norms else 'disabled'}")
+    else:
+        log0(f"recurrence:disabled num_layers:{args.num_layers}")
+    compile_mode = "disabled" if args.disable_compile else "fullgraph=True"
+    log0(f"compile_mode:{compile_mode}")
+
+    # -----------------------------
+    # DATA LOADER & MODEL WARMUP
+    # -----------------------------
+
+    train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device)
+
+    def zero_grad_all() -> None:
+        for opt in optimizers:
+            opt.zero_grad(set_to_none=True)
+
+    max_wallclock_ms = 1000.0 * args.max_wallclock_seconds if args.max_wallclock_seconds > 0 else None
+
+    def lr_mul(step: int, elapsed_ms: float) -> float:
+        # LR multiplier for the linear warmdown. With no wallclock cap, warmdown is
+        # scheduled over the last warmdown_iters steps; with a cap, the remaining
+        # wallclock budget (estimated from the average step time so far) drives it.
+        if args.warmdown_iters <= 0:
+            return 1.0
+        if max_wallclock_ms is None:
+            warmdown_start = max(args.iterations - args.warmdown_iters, 0)
+            return max((args.iterations - step) / max(args.warmdown_iters, 1), 0.0) if warmdown_start <= step < args.iterations else 1.0
+        step_ms = elapsed_ms / max(step, 1)
+        warmdown_ms = args.warmdown_iters * step_ms
+        remaining_ms = max(max_wallclock_ms - elapsed_ms, 0.0)
+        return remaining_ms / max(warmdown_ms, 1e-9) if remaining_ms <= warmdown_ms else 1.0
+
+    # Warmup primes the compiled forward/backward/optimizer paths, then we restore the
+    # initial weights/optimizer state so measured training starts from the true init. 
+ if args.warmup_steps > 0: + initial_model_state = {name: tensor.detach().cpu().clone() for name, tensor in base_model.state_dict().items()} + initial_optimizer_states = [copy.deepcopy(opt.state_dict()) for opt in optimizers] + model.train() + for warmup_step in range(args.warmup_steps): + zero_grad_all() + for micro_step in range(grad_accum_steps): + if distributed: + model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 + x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + warmup_loss = model(x, y) + (warmup_loss * grad_scale).backward() + for opt in optimizers: + opt.step() + zero_grad_all() + if args.warmup_steps <= 20 or (warmup_step + 1) % 10 == 0 or warmup_step + 1 == args.warmup_steps: + log0(f"warmup_step:{warmup_step + 1}/{args.warmup_steps}") + base_model.load_state_dict(initial_model_state, strict=True) + for opt, state in zip(optimizers, initial_optimizer_states, strict=True): + opt.load_state_dict(state) + zero_grad_all() + if distributed: + model.require_backward_grad_sync = True + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + + # ----------------------------- + # MAIN TRAINING LOOP + # ----------------------------- + + training_time_ms = 0.0 + stop_after_step: int | None = None + torch.cuda.synchronize() + t0 = time.perf_counter() + + step = 0 + while True: + last_step = step == args.iterations or (stop_after_step is not None and step >= stop_after_step) + + should_validate = last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0) + if should_validate: + torch.cuda.synchronize() + training_time_ms += 1000.0 * (time.perf_counter() - t0) + val_loss, val_bpb = eval_val( + args, + model, + rank, + world_size, + device, + grad_accum_steps, + val_tokens, + base_bytes_lut, + has_leading_space_lut, + is_boundary_token_lut, + ) + log0( + 
f"step:{step}/{args.iterations} val_loss:{val_loss:.4f} val_bpb:{val_bpb:.4f} " + f"train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms / max(step, 1):.2f}ms" + ) + torch.cuda.synchronize() + t0 = time.perf_counter() + + if last_step: + if stop_after_step is not None and step < args.iterations: + log0( + f"stopping_early: wallclock_cap train_time:{training_time_ms:.0f}ms " + f"step:{step}/{args.iterations}" + ) + break + + elapsed_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + scale = lr_mul(step, elapsed_ms) + zero_grad_all() + train_loss = torch.zeros((), device=device) + for micro_step in range(grad_accum_steps): + if distributed: + model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 + x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + loss = model(x, y) + train_loss += loss.detach() + (loss * grad_scale).backward() + train_loss /= grad_accum_steps + + frac = min(step / args.muon_momentum_warmup_steps, 1.0) if args.muon_momentum_warmup_steps > 0 else 1.0 + muon_momentum = (1 - frac) * args.muon_momentum_warmup_start + frac * args.muon_momentum + for group in optimizer_muon.param_groups: + group["momentum"] = muon_momentum + + for opt in optimizers: + for group in opt.param_groups: + group["lr"] = group["base_lr"] * scale + + if args.grad_clip_norm > 0: + torch.nn.utils.clip_grad_norm_(base_model.parameters(), args.grad_clip_norm) + for opt in optimizers: + opt.step() + zero_grad_all() + + step += 1 + approx_training_time_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + should_log_train = ( + args.train_log_every > 0 + and (step <= 10 or step % args.train_log_every == 0 or stop_after_step is not None) + ) + if should_log_train: + log0( + f"step:{step}/{args.iterations} train_loss:{train_loss.item():.4f} " + f"train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms 
/ step:.2f}ms" + ) + if base_model.use_recurrence and step % args.train_log_every == 0: + diag = recurrence_param_diagnostics(base_model) + if diag: + log0(f"step:{step} {diag}") + + # Needed to sync whether we've reached the wallclock cap. + reached_cap = max_wallclock_ms is not None and approx_training_time_ms >= max_wallclock_ms + if distributed and max_wallclock_ms is not None: + reached_cap_tensor = torch.tensor(int(reached_cap), device=device) + dist.all_reduce(reached_cap_tensor, op=dist.ReduceOp.MAX) + reached_cap = bool(reached_cap_tensor.item()) + if stop_after_step is None and reached_cap: + stop_after_step = step + + log0( + f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB " + f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB" + ) + + # ----------------------------- + # SERIALIZATION + ROUNDTRIP VALIDATION + # ----------------------------- + # Save the raw state (useful for debugging/loading in PyTorch directly), then always produce + # the compressed int8+zlib artifact and validate the round-tripped weights. 
+ + if master_process: + torch.save(base_model.state_dict(), "final_model.pt") + model_bytes = os.path.getsize("final_model.pt") + code_bytes = len(code.encode("utf-8")) + log0(f"Serialized model: {model_bytes} bytes") + log0(f"Code size: {code_bytes} bytes") + log0(f"Total submission size: {model_bytes + code_bytes} bytes") + + quant_obj, quant_stats = quantize_state_dict_int8(base_model.state_dict()) + quant_buf = io.BytesIO() + torch.save(quant_obj, quant_buf) + quant_raw = quant_buf.getvalue() + quant_blob = zlib.compress(quant_raw, level=9) + quant_raw_bytes = len(quant_raw) + if master_process: + with open("final_model.int8.ptz", "wb") as f: + f.write(quant_blob) + quant_file_bytes = os.path.getsize("final_model.int8.ptz") + code_bytes = len(code.encode("utf-8")) + ratio = quant_stats["baseline_tensor_bytes"] / max(quant_stats["int8_payload_bytes"], 1) + log0( + f"Serialized model int8+zlib: {quant_file_bytes} bytes " + f"(payload:{quant_stats['int8_payload_bytes']} raw_torch:{quant_raw_bytes} payload_ratio:{ratio:.2f}x)" + ) + log0(f"Total submission size int8+zlib: {quant_file_bytes + code_bytes} bytes") + + if distributed: + dist.barrier() + with open("final_model.int8.ptz", "rb") as f: + quant_blob_disk = f.read() + quant_state = torch.load(io.BytesIO(zlib.decompress(quant_blob_disk)), map_location="cpu") + base_model.load_state_dict(dequantize_state_dict_int8(quant_state), strict=True) + torch.cuda.synchronize() + t_qeval = time.perf_counter() + q_val_loss, q_val_bpb = eval_val( + args, + model, + rank, + world_size, + device, + grad_accum_steps, + val_tokens, + base_bytes_lut, + has_leading_space_lut, + is_boundary_token_lut, + ) + torch.cuda.synchronize() + log0( + f"final_int8_zlib_roundtrip val_loss:{q_val_loss:.4f} val_bpb:{q_val_bpb:.4f} " + f"eval_time:{1000.0 * (time.perf_counter() - t_qeval):.0f}ms" + ) + log0(f"final_int8_zlib_roundtrip_exact val_loss:{q_val_loss:.8f} val_bpb:{q_val_bpb:.8f}") + + if distributed: + dist.destroy_process_group() 
+ + +if __name__ == "__main__": + main() + +==================================================================================================== +Running Python 3.12.3 (main, Nov 6 2025, 13:44:16) [GCC 13.3.0] +Running PyTorch 2.9.1+cu128 +Thu Apr 2 14:47:57 2026 ++-----------------------------------------------------------------------------------------+ +| NVIDIA-SMI 580.126.09 Driver Version: 580.126.09 CUDA Version: 13.0 | ++-----------------------------------------+------------------------+----------------------+ +| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC | +| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. | +| | | MIG M. | +|=========================================+========================+======================| +| 0 NVIDIA H100 80GB HBM3 On | 00000000:19:00.0 Off | 0 | +| N/A 45C P0 126W / 700W | 1521MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 1 NVIDIA H100 80GB HBM3 On | 00000000:3B:00.0 Off | 0 | +| N/A 36C P0 117W / 700W | 1521MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 2 NVIDIA H100 80GB HBM3 On | 00000000:4C:00.0 Off | 0 | +| N/A 34C P0 123W / 700W | 1521MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 3 NVIDIA H100 80GB HBM3 On | 00000000:5D:00.0 Off | 0 | +| N/A 43C P0 125W / 700W | 1521MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 4 NVIDIA H100 80GB HBM3 On | 00000000:9B:00.0 Off | 0 | +| N/A 44C P0 125W / 700W | 1521MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 5 NVIDIA H100 80GB HBM3 On | 00000000:BB:00.0 Off | 0 | +| N/A 35C 
P0 119W / 700W | 1521MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 6 NVIDIA H100 80GB HBM3 On | 00000000:CB:00.0 Off | 0 | +| N/A 43C P0 122W / 700W | 1521MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 7 NVIDIA H100 80GB HBM3 On | 00000000:DB:00.0 Off | 0 | +| N/A 34C P0 119W / 700W | 1521MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ + ++-----------------------------------------------------------------------------------------+ +| Processes: | +| GPU GI CI PID Type Process name GPU Memory | +| ID ID Usage | +|=========================================================================================| +| No running processes found | ++-----------------------------------------------------------------------------------------+ + +==================================================================================================== +val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path=./data/tokenizers/fineweb_1024_bpe.model +train_loader:dataset:fineweb10B_sp1024 train_shards:80 +val_loader:shards pattern=./data/datasets/fineweb10B_sp1024/fineweb_val_*.bin tokens:62021632 +model_params:11591728 +world_size:8 grad_accum_steps:1 +sdp_backends:cudnn=False flash=True mem_efficient=False math=False +flash_attention_3:True +attention_mode:gqa num_heads:8 num_kv_heads:4 +tie_embeddings:True embed_lr:0.05 head_lr:0.0 matrix_lr:0.04 scalar_lr:0.04 +train_batch_tokens:524288 train_seq_len:1024 iterations:20000 warmup_steps:20 max_wallclock_seconds:600.000 +seed:1337 +recurrence:enabled prelude:1 shared:4 loops:3 coda:1 effective_layers:14 +peri_norm:True birkhoff_mix:True +timestep_scale:enabled params:28672 +depth_embed:enabled +unique_norms:enabled +compile_mode:fullgraph=True 
+warmup_step:1/20 +warmup_step:2/20 +warmup_step:3/20 +warmup_step:4/20 +warmup_step:5/20 +warmup_step:6/20 +warmup_step:7/20 +warmup_step:8/20 +warmup_step:9/20 +warmup_step:10/20 +warmup_step:11/20 +warmup_step:12/20 +warmup_step:13/20 +warmup_step:14/20 +warmup_step:15/20 +warmup_step:16/20 +warmup_step:17/20 +warmup_step:18/20 +warmup_step:19/20 +warmup_step:20/20 +step:0/20000 val_loss:6.9379 val_bpb:4.1090 train_time:0ms step_avg:0.02ms +step:1/20000 train_loss:6.9377 train_time:44ms step_avg:43.53ms +step:2/20000 train_loss:9.2622 train_time:108ms step_avg:54.07ms +step:3/20000 train_loss:8.5793 train_time:174ms step_avg:57.88ms +step:4/20000 train_loss:10.0479 train_time:240ms step_avg:59.88ms +step:5/20000 train_loss:9.6318 train_time:305ms step_avg:61.08ms +step:6/20000 train_loss:9.1912 train_time:371ms step_avg:61.89ms +step:7/20000 train_loss:7.6010 train_time:437ms step_avg:62.46ms +step:8/20000 train_loss:6.6930 train_time:503ms step_avg:62.90ms +step:9/20000 train_loss:6.1373 train_time:570ms step_avg:63.32ms +step:10/20000 train_loss:5.7567 train_time:636ms step_avg:63.56ms +step:200/20000 train_loss:2.8021 train_time:13306ms step_avg:66.53ms +step:200 shared0_alpha:mean=0.459,std=0.049 shared1_alpha:mean=0.470,std=0.044 shared2_alpha:mean=0.483,std=0.042 shared3_alpha:mean=0.502,std=0.044 eff_mlp_scale:[v0:37.6316 v1:26.6396 v2:28.4342 v3:29.8879 v4:32.7531 v5:31.8196 v6:28.7229 v7:29.2931 v8:32.7531 v9:31.5236 v10:30.0219 v11:31.8209 v12:36.4279 v13:59.1815] eff_attn_scale:[v0:14.3606 v1:11.1265 v2:11.9308 v3:10.9466 v4:11.7137 v5:10.5827 v6:10.9082 v7:10.1588 v8:10.9553 v9:10.5827 v10:10.9934 v11:10.5734 v12:10.9553 v13:15.3549] eff_attn_bias:[v0:0.1333 v1:0.1119 v2:0.1132 v3:0.1105 v4:0.1229 v5:0.1298 v6:0.1236 v7:0.1146 v8:0.1257 v9:0.1277 v10:0.1243 v11:0.1105 v12:0.1098 v13:0.1132] eff_mlp_bias:[v0:0.1084 v1:0.1029 v2:0.1029 v3:0.1008 v4:0.1181 v5:0.1126 v6:0.1070 v7:0.1063 v8:0.1188 v9:0.1119 v10:0.1043 v11:0.0974 v12:0.1132 v13:0.1906] 
unique_attn_gain_rms:[u0:0.8607 u1:0.8856 u2:0.8440 u3:0.8609 u4:0.8469 u5:0.8540 u6:0.8223 u7:0.8527 u8:0.8632 u9:0.9022 u10:0.8839 u11:0.8919] unique_mlp_gain_rms:[u0:1.1060 u1:1.0951 u2:1.1218 u3:1.1062 u4:1.1143 u5:1.0984 u6:1.1041 u7:1.1090 u8:1.1171 u9:1.1143 u10:1.1232 u11:1.1406] depth_emb_rms:[v0:0.1282 v1:0.1091 v2:0.1033 v3:0.1032 v4:0.1015 v5:0.1192 v6:0.1137 v7:0.1086 v8:0.1071 v9:0.1195 v10:0.1135 v11:0.1061 v12:0.0988 v13:0.1148] +step:200/20000 val_loss:2.7847 val_bpb:1.6493 train_time:13372ms step_avg:66.86ms +step:400/20000 train_loss:2.3801 train_time:26717ms step_avg:66.79ms +step:400 shared0_alpha:mean=0.470,std=0.058 shared1_alpha:mean=0.487,std=0.053 shared2_alpha:mean=0.507,std=0.049 shared3_alpha:mean=0.526,std=0.050 eff_mlp_scale:[v0:45.9944 v1:32.4993 v2:36.9000 v3:39.4617 v4:40.6860 v5:45.0796 v6:40.2545 v7:40.3085 v8:40.8584 v9:43.5071 v10:39.2482 v11:39.8005 v12:40.8584 v13:74.2397] eff_attn_scale:[v0:6.4965 v1:7.0757 v2:7.1285 v3:6.7871 v4:7.1961 v5:6.9459 v6:7.1285 v7:6.7564 v8:7.0746 v9:7.1082 v10:6.6532 v11:6.1729 v12:6.1030 v13:9.2116] eff_attn_bias:[v0:0.1595 v1:0.1464 v2:0.1533 v3:0.1526 v4:0.1761 v5:0.1761 v6:0.1685 v7:0.1561 v8:0.1706 v9:0.1595 v10:0.1526 v11:0.1277 v12:0.1229 v13:0.1202] eff_mlp_bias:[v0:0.1692 v1:0.1312 v2:0.1298 v3:0.1277 v4:0.1505 v5:0.1409 v6:0.1277 v7:0.1250 v8:0.1416 v9:0.1374 v10:0.1250 v11:0.1174 v12:0.1347 v13:0.2527] unique_attn_gain_rms:[u0:0.7978 u1:0.8138 u2:0.7868 u3:0.7917 u4:0.7973 u5:0.8116 u6:0.7924 u7:0.8073 u8:0.8300 u9:0.8529 u10:0.8307 u11:0.8180] unique_mlp_gain_rms:[u0:1.2085 u1:1.1879 u2:1.2047 u3:1.1837 u4:1.1804 u5:1.1666 u6:1.1662 u7:1.1670 u8:1.1764 u9:1.1718 u10:1.1683 u11:1.1919] depth_emb_rms:[v0:0.1706 v1:0.1730 v2:0.1346 v3:0.1324 v4:0.1298 v5:0.1538 v6:0.1445 v7:0.1309 v8:0.1270 v9:0.1450 v10:0.1417 v11:0.1291 v12:0.1213 v13:0.1392] +step:400/20000 val_loss:2.5793 val_bpb:1.5276 train_time:26748ms step_avg:66.87ms +step:600/20000 train_loss:2.5800 train_time:40125ms 
step_avg:66.87ms +step:600 shared0_alpha:mean=0.475,std=0.064 shared1_alpha:mean=0.494,std=0.059 shared2_alpha:mean=0.523,std=0.051 shared3_alpha:mean=0.544,std=0.052 eff_mlp_scale:[v0:53.5640 v1:37.3130 v2:43.0464 v3:45.5735 v4:45.1132 v5:54.0658 v6:48.4725 v7:48.1053 v8:46.3714 v9:51.7813 v10:44.8551 v11:44.3075 v12:43.4956 v13:90.7580] eff_attn_scale:[v0:3.1844 v1:4.4703 v2:4.8041 v3:4.7165 v4:5.0756 v5:4.6679 v6:4.8765 v7:4.7400 v8:5.0992 v9:4.7914 v10:4.3454 v11:4.0090 v12:4.0368 v13:6.2970] eff_attn_bias:[v0:0.1892 v1:0.1864 v2:0.1878 v3:0.1961 v4:0.2251 v5:0.2265 v6:0.2210 v7:0.2030 v8:0.2154 v9:0.1961 v10:0.1823 v11:0.1464 v12:0.1374 v13:0.1485] eff_mlp_bias:[v0:0.2458 v1:0.1678 v2:0.1643 v3:0.1602 v4:0.1878 v5:0.1795 v6:0.1561 v7:0.1498 v8:0.1761 v9:0.1719 v10:0.1512 v11:0.1436 v12:0.1637 v13:0.2693] unique_attn_gain_rms:[u0:0.7416 u1:0.7682 u2:0.7542 u3:0.7442 u4:0.7674 u5:0.7787 u6:0.7645 u7:0.7748 u8:0.7989 u9:0.8133 u10:0.7947 u11:0.7634] unique_mlp_gain_rms:[u0:1.3178 u1:1.2970 u2:1.3011 u3:1.2687 u4:1.2576 u5:1.2442 u6:1.2383 u7:1.2309 u8:1.2460 u9:1.2340 u10:1.2209 u11:1.2488] depth_emb_rms:[v0:0.2503 v1:0.2520 v2:0.1731 v3:0.1680 v4:0.1633 v5:0.1930 v6:0.1860 v7:0.1614 v8:0.1540 v9:0.1822 v10:0.1791 v11:0.1580 v12:0.1497 v13:0.1692] +step:600/20000 val_loss:2.4808 val_bpb:1.4693 train_time:40157ms step_avg:66.93ms +step:800/20000 train_loss:2.3381 train_time:53571ms step_avg:66.96ms +step:800 shared0_alpha:mean=0.478,std=0.068 shared1_alpha:mean=0.499,std=0.063 shared2_alpha:mean=0.534,std=0.053 shared3_alpha:mean=0.554,std=0.053 eff_mlp_scale:[v0:60.8383 v1:42.4494 v2:47.8238 v3:50.4777 v4:48.5510 v5:60.7569 v6:55.0353 v7:53.8680 v8:50.7747 v9:57.5380 v10:48.9625 v11:47.6525 v12:45.9566 v13:103.2965] eff_attn_scale:[v0:2.1082 v1:3.5075 v2:3.9126 v3:3.8369 v4:4.2011 v5:3.7822 v6:3.8716 v7:3.8369 v8:4.1808 v9:3.8456 v10:3.3800 v11:3.0936 v12:3.1661 v13:4.8931] eff_attn_bias:[v0:0.2058 v1:0.2251 v2:0.2210 v3:0.2320 v4:0.2583 v5:0.2679 v6:0.2679 
v7:0.2458 v8:0.2569 v9:0.2334 v10:0.2085 v11:0.1650 v12:0.1554 v13:0.1837] eff_mlp_bias:[v0:0.3107 v1:0.2058 v2:0.1920 v3:0.1851 v4:0.2182 v5:0.2182 v6:0.1878 v7:0.1754 v8:0.2099 v9:0.2058 v10:0.1761 v11:0.1685 v12:0.1920 v13:0.2804] unique_attn_gain_rms:[u0:0.7149 u1:0.7461 u2:0.7394 u3:0.7258 u4:0.7519 u5:0.7569 u6:0.7469 u7:0.7528 u8:0.7809 u9:0.7875 u10:0.7655 u11:0.7220] unique_mlp_gain_rms:[u0:1.4183 u1:1.3990 u2:1.3872 u3:1.3440 u4:1.3333 u5:1.3172 u6:1.3051 u7:1.2968 u8:1.3132 u9:1.2984 u10:1.2751 u11:1.3091] depth_emb_rms:[v0:0.3159 v1:0.3198 v2:0.2137 v3:0.1982 v4:0.1893 v5:0.2260 v6:0.2270 v7:0.1953 v8:0.1815 v9:0.2190 v10:0.2154 v11:0.1856 v12:0.1772 v13:0.1994] +step:800/20000 val_loss:2.4259 val_bpb:1.4367 train_time:53602ms step_avg:67.00ms +step:1000/20000 train_loss:2.4182 train_time:67033ms step_avg:67.03ms +step:1000 shared0_alpha:mean=0.480,std=0.073 shared1_alpha:mean=0.502,std=0.068 shared2_alpha:mean=0.543,std=0.055 shared3_alpha:mean=0.562,std=0.054 eff_mlp_scale:[v0:69.2018 v1:46.7660 v2:52.3542 v3:54.6228 v4:51.8236 v5:67.4105 v6:60.2270 v7:58.5244 v8:54.1099 v9:63.1973 v10:52.3542 v11:50.7211 v12:48.5846 v13:114.0105] eff_attn_scale:[v0:1.5716 v1:3.1981 v2:3.5650 v3:3.5983 v4:3.8464 v5:3.3725 v6:3.4518 v7:3.4322 v8:3.6963 v9:3.3725 v10:2.9425 v11:2.6757 v12:2.7957 v13:4.1415] eff_attn_bias:[v0:0.2224 v1:0.2541 v2:0.2527 v3:0.2638 v4:0.2900 v5:0.3107 v6:0.3094 v7:0.2831 v8:0.2873 v9:0.2638 v10:0.2293 v11:0.1809 v12:0.1719 v13:0.2210] eff_mlp_bias:[v0:0.3674 v1:0.2375 v2:0.2237 v3:0.2085 v4:0.2444 v5:0.2500 v6:0.2113 v7:0.1975 v8:0.2389 v9:0.2334 v10:0.1975 v11:0.1878 v12:0.2168 v13:0.2873] unique_attn_gain_rms:[u0:0.7045 u1:0.7395 u2:0.7437 u3:0.7168 u4:0.7423 u5:0.7454 u6:0.7377 u7:0.7317 u8:0.7696 u9:0.7626 u10:0.7430 u11:0.6886] unique_mlp_gain_rms:[u0:1.5158 u1:1.4913 u2:1.4718 u3:1.4215 u4:1.4076 u5:1.3906 u6:1.3703 u7:1.3666 u8:1.3803 u9:1.3633 u10:1.3309 u11:1.3715] depth_emb_rms:[v0:0.3761 v1:0.3772 v2:0.2467 v3:0.2317 v4:0.2154 
v5:0.2555 v6:0.2627 v7:0.2223 v8:0.2053 v9:0.2501 v10:0.2464 v11:0.2091 v12:0.1991 v13:0.2266] +step:1000/20000 val_loss:2.3857 val_bpb:1.4129 train_time:67064ms step_avg:67.06ms +step:1200/20000 train_loss:2.4346 train_time:80476ms step_avg:67.06ms +step:1200 shared0_alpha:mean=0.481,std=0.076 shared1_alpha:mean=0.505,std=0.070 shared2_alpha:mean=0.549,std=0.057 shared3_alpha:mean=0.568,std=0.056 eff_mlp_scale:[v0:75.6965 v1:50.5703 v2:55.6886 v3:57.9694 v4:54.6312 v5:71.6231 v6:64.1629 v7:61.9673 v8:57.3628 v9:66.8482 v10:54.4779 v11:53.1719 v12:51.1192 v13:123.7605] eff_attn_scale:[v0:1.3316 v1:3.0961 v2:3.6910 v3:3.5712 v4:3.7455 v5:3.2837 v6:3.3588 v7:3.3047 v8:3.5465 v9:3.1711 v10:2.7682 v11:2.4697 v12:2.6056 v13:3.5715] eff_attn_bias:[v0:0.2417 v1:0.2873 v2:0.2804 v3:0.2928 v4:0.3176 v5:0.3411 v6:0.3425 v7:0.3121 v8:0.3135 v9:0.2942 v10:0.2541 v11:0.1961 v12:0.1892 v13:0.2555] eff_mlp_bias:[v0:0.4088 v1:0.2652 v2:0.2527 v3:0.2320 v4:0.2638 v5:0.2762 v6:0.2362 v7:0.2154 v8:0.2624 v9:0.2596 v10:0.2168 v11:0.2058 v12:0.2389 v13:0.2928] unique_attn_gain_rms:[u0:0.7035 u1:0.7565 u2:0.7565 u3:0.7171 u4:0.7426 u5:0.7423 u6:0.7310 u7:0.7251 u8:0.7563 u9:0.7397 u10:0.7183 u11:0.6599] unique_mlp_gain_rms:[u0:1.6082 u1:1.5825 u2:1.5506 u3:1.4951 u4:1.4771 u5:1.4607 u6:1.4310 u7:1.4282 u8:1.4463 u9:1.4273 u10:1.3866 u11:1.4324] depth_emb_rms:[v0:0.4279 v1:0.4263 v2:0.2790 v3:0.2624 v4:0.2401 v5:0.2782 v6:0.2918 v7:0.2480 v8:0.2257 v9:0.2764 v10:0.2743 v11:0.2312 v12:0.2182 v13:0.2516] +step:1200/20000 val_loss:2.3546 val_bpb:1.3945 train_time:80508ms step_avg:67.09ms +step:1400/20000 train_loss:2.4822 train_time:93909ms step_avg:67.08ms +step:1400 shared0_alpha:mean=0.481,std=0.079 shared1_alpha:mean=0.506,std=0.073 shared2_alpha:mean=0.555,std=0.058 shared3_alpha:mean=0.574,std=0.058 eff_mlp_scale:[v0:81.2627 v1:53.6776 v2:58.6477 v3:61.4063 v4:56.6035 v5:75.7278 v6:67.7339 v7:65.0907 v8:59.7924 v9:70.8277 v10:56.9956 v11:56.0845 v12:53.4146 v13:132.7506] 
eff_attn_scale:[v0:1.1331 v1:3.2252 v2:4.1843 v3:3.7868 v4:3.8709 v5:3.4528 v6:3.5156 v7:3.2890 v8:3.5272 v9:3.1682 v10:2.8469 v11:2.4179 v12:2.5686 v13:3.2790] eff_attn_bias:[v0:0.2624 v1:0.3246 v2:0.3094 v3:0.3107 v4:0.3397 v5:0.3701 v6:0.3729 v7:0.3370 v8:0.3356 v9:0.3176 v10:0.2762 v11:0.2058 v12:0.2044 v13:0.2900] eff_mlp_bias:[v0:0.4419 v1:0.2859 v2:0.2817 v3:0.2555 v4:0.2831 v5:0.3038 v6:0.2610 v7:0.2306 v8:0.2804 v9:0.2817 v10:0.2362 v11:0.2196 v12:0.2583 v13:0.3025] unique_attn_gain_rms:[u0:0.7178 u1:0.7933 u2:0.7759 u3:0.7226 u4:0.7483 u5:0.7405 u6:0.7238 u7:0.7161 u8:0.7464 u9:0.7205 u10:0.6981 u11:0.6377] unique_mlp_gain_rms:[u0:1.7037 u1:1.6699 u2:1.6219 u3:1.5678 u4:1.5414 u5:1.5263 u6:1.4896 u7:1.4897 u8:1.5097 u9:1.4895 u10:1.4424 u11:1.4932] depth_emb_rms:[v0:0.4820 v1:0.4706 v2:0.3102 v3:0.2957 v4:0.2644 v5:0.2997 v6:0.3201 v7:0.2745 v8:0.2424 v9:0.2978 v10:0.2989 v11:0.2531 v12:0.2329 v13:0.2744] +step:1400/20000 val_loss:2.3326 val_bpb:1.3815 train_time:93941ms step_avg:67.10ms +step:1600/20000 train_loss:2.1487 train_time:107339ms step_avg:67.09ms +step:1600 shared0_alpha:mean=0.483,std=0.081 shared1_alpha:mean=0.508,std=0.075 shared2_alpha:mean=0.561,std=0.059 shared3_alpha:mean=0.578,std=0.059 eff_mlp_scale:[v0:85.4559 v1:56.4494 v2:61.3697 v3:63.2395 v4:58.7407 v5:79.6665 v6:70.6172 v7:67.8160 v8:61.9816 v9:73.7485 v10:59.2680 v11:57.8308 v12:55.4999 v13:140.1054] eff_attn_scale:[v0:1.0217 v1:3.4798 v2:4.9081 v3:4.4406 v4:4.3375 v5:3.9198 v6:3.9674 v7:3.6626 v8:3.7529 v9:3.3998 v10:3.0880 v11:2.5999 v12:2.6968 v13:3.0157] eff_attn_bias:[v0:0.2817 v1:0.3536 v2:0.3342 v3:0.3259 v4:0.3591 v5:0.3895 v6:0.4005 v7:0.3563 v8:0.3563 v9:0.3384 v10:0.2969 v11:0.2141 v12:0.2182 v13:0.3218] eff_mlp_bias:[v0:0.4751 v1:0.3038 v2:0.3052 v3:0.2707 v4:0.2928 v5:0.3232 v6:0.2831 v7:0.2431 v8:0.2969 v9:0.2983 v10:0.2527 v11:0.2306 v12:0.2790 v13:0.3135] unique_attn_gain_rms:[u0:0.7455 u1:0.8402 u2:0.8234 u3:0.7424 u4:0.7673 u5:0.7538 u6:0.7336 u7:0.7144 
u8:0.7421 u9:0.7118 u10:0.6786 u11:0.6161] unique_mlp_gain_rms:[u0:1.7899 u1:1.7497 u2:1.6931 u3:1.6353 u4:1.6068 u5:1.5900 u6:1.5486 u7:1.5504 u8:1.5713 u9:1.5486 u10:1.4979 u11:1.5537] depth_emb_rms:[v0:0.5266 v1:0.5120 v2:0.3379 v3:0.3243 v4:0.2811 v5:0.3140 v6:0.3427 v7:0.2996 v8:0.2568 v9:0.3170 v10:0.3189 v11:0.2731 v12:0.2462 v13:0.2966] +step:1600/20000 val_loss:2.3195 val_bpb:1.3737 train_time:107371ms step_avg:67.11ms +step:1800/20000 train_loss:2.2499 train_time:120769ms step_avg:67.09ms +step:1800 shared0_alpha:mean=0.482,std=0.083 shared1_alpha:mean=0.509,std=0.076 shared2_alpha:mean=0.566,std=0.060 shared3_alpha:mean=0.583,std=0.060 eff_mlp_scale:[v0:90.6851 v1:58.7722 v2:63.3582 v3:65.3548 v4:60.9047 v5:82.3737 v6:72.7131 v7:70.4146 v8:64.1969 v9:76.8204 v10:60.8069 v11:59.8735 v12:57.6126 v13:147.4395] eff_attn_scale:[v0:0.9086 v1:3.7952 v2:5.8306 v3:5.1421 v4:5.0788 v5:4.6146 v6:4.6645 v7:4.1997 v8:4.2256 v9:3.8383 v10:3.5657 v11:2.9091 v12:3.0067 v13:2.8339] eff_attn_bias:[v0:0.3025 v1:0.3729 v2:0.3591 v3:0.3411 v4:0.3784 v5:0.4143 v6:0.4254 v7:0.3701 v8:0.3757 v9:0.3591 v10:0.3149 v11:0.2224 v12:0.2320 v13:0.3522] eff_mlp_bias:[v0:0.4999 v1:0.3163 v2:0.3259 v3:0.2817 v4:0.3080 v5:0.3384 v6:0.3038 v7:0.2541 v8:0.3107 v9:0.3163 v10:0.2679 v11:0.2403 v12:0.2955 v13:0.3246] unique_attn_gain_rms:[u0:0.7810 u1:0.8889 u2:0.8642 u3:0.7738 u4:0.7992 u5:0.7809 u6:0.7492 u7:0.7237 u8:0.7437 u9:0.7095 u10:0.6694 u11:0.6048] unique_mlp_gain_rms:[u0:1.8723 u1:1.8269 u2:1.7603 u3:1.7023 u4:1.6703 u5:1.6518 u6:1.6020 u7:1.6126 u8:1.6334 u9:1.6100 u10:1.5553 u11:1.6134] depth_emb_rms:[v0:0.5745 v1:0.5499 v2:0.3591 v3:0.3521 v4:0.2964 v5:0.3314 v6:0.3619 v7:0.3233 v8:0.2696 v9:0.3336 v10:0.3388 v11:0.2914 v12:0.2580 v13:0.3175] +step:1800/20000 val_loss:2.3015 val_bpb:1.3631 train_time:120800ms step_avg:67.11ms +step:2000/20000 train_loss:2.3006 train_time:134192ms step_avg:67.10ms +step:2000 shared0_alpha:mean=0.482,std=0.084 shared1_alpha:mean=0.509,std=0.078 
shared2_alpha:mean=0.570,std=0.061 shared3_alpha:mean=0.587,std=0.062 eff_mlp_scale:[v0:95.0374 v1:60.9236 v2:65.5670 v3:67.7429 v4:62.4254 v5:84.8244 v6:74.5664 v7:72.0034 v8:65.7548 v9:78.7320 v10:61.7101 v11:60.9260 v12:59.5122 v13:153.9182] eff_attn_scale:[v0:0.8580 v1:4.1187 v2:6.7320 v3:6.0303 v4:5.8183 v5:5.3681 v6:5.5035 v7:4.9278 v8:4.8926 v9:4.3732 v10:4.2259 v11:3.4427 v12:3.4381 v13:2.6700] eff_attn_bias:[v0:0.3218 v1:0.3977 v2:0.3784 v3:0.3536 v4:0.3950 v5:0.4337 v6:0.4475 v7:0.3839 v8:0.3895 v9:0.3757 v10:0.3301 v11:0.2320 v12:0.2472 v13:0.3812] eff_mlp_bias:[v0:0.5303 v1:0.3301 v2:0.3494 v3:0.2914 v4:0.3176 v5:0.3536 v6:0.3232 v7:0.2638 v8:0.3204 v9:0.3273 v10:0.2831 v11:0.2486 v12:0.3121 v13:0.3370] unique_attn_gain_rms:[u0:0.8205 u1:0.9371 u2:0.9082 u3:0.8010 u4:0.8448 u5:0.8088 u6:0.7753 u7:0.7389 u8:0.7637 u9:0.7226 u10:0.6735 u11:0.6068] unique_mlp_gain_rms:[u0:1.9521 u1:1.9059 u2:1.8253 u3:1.7673 u4:1.7312 u5:1.7112 u6:1.6566 u7:1.6707 u8:1.6942 u9:1.6729 u10:1.6092 u11:1.6726] depth_emb_rms:[v0:0.6163 v1:0.5958 v2:0.3793 v3:0.3796 v4:0.3118 v5:0.3450 v6:0.3812 v7:0.3458 v8:0.2813 v9:0.3471 v10:0.3532 v11:0.3083 v12:0.2690 v13:0.3377] +step:2000/20000 val_loss:2.2881 val_bpb:1.3551 train_time:134224ms step_avg:67.11ms +step:2200/20000 train_loss:2.1329 train_time:147608ms step_avg:67.09ms +step:2200 shared0_alpha:mean=0.482,std=0.085 shared1_alpha:mean=0.511,std=0.078 shared2_alpha:mean=0.574,std=0.062 shared3_alpha:mean=0.591,std=0.063 eff_mlp_scale:[v0:99.4654 v1:63.1540 v2:67.0352 v3:69.3479 v4:64.0140 v5:87.3709 v6:76.1174 v7:73.6552 v8:67.3832 v9:81.1980 v10:62.7103 v11:62.4562 v12:61.4871 v13:161.1490] eff_attn_scale:[v0:0.8092 v1:4.4556 v2:7.5533 v3:6.7653 v4:6.6320 v5:6.0893 v6:6.2856 v7:5.6459 v8:5.6062 v9:5.0744 v10:4.9123 v11:4.0397 v12:3.9840 v13:2.5420] eff_attn_bias:[v0:0.3466 v1:0.4171 v2:0.3922 v3:0.3646 v4:0.4116 v5:0.4502 v6:0.4640 v7:0.3950 v8:0.4033 v9:0.3895 v10:0.3466 v11:0.2403 v12:0.2596 v13:0.4088] 
eff_mlp_bias:[v0:0.5580 v1:0.3384 v2:0.3646 v3:0.2969 v4:0.3259 v5:0.3674 v6:0.3384 v7:0.2693 v8:0.3287 v9:0.3384 v10:0.2969 v11:0.2569 v12:0.3287 v13:0.3522] unique_attn_gain_rms:[u0:0.8631 u1:0.9808 u2:0.9480 u3:0.8267 u4:0.8830 u5:0.8406 u6:0.8022 u7:0.7595 u8:0.7911 u9:0.7402 u10:0.6832 u11:0.6205] unique_mlp_gain_rms:[u0:2.0367 u1:1.9772 u2:1.8912 u3:1.8296 u4:1.7926 u5:1.7680 u6:1.7083 u7:1.7287 u8:1.7545 u9:1.7295 u10:1.6634 u11:1.7322] depth_emb_rms:[v0:0.6600 v1:0.6393 v2:0.3944 v3:0.4021 v4:0.3232 v5:0.3559 v6:0.3981 v7:0.3640 v8:0.2896 v9:0.3594 v10:0.3680 v11:0.3258 v12:0.2798 v13:0.3558] +step:2200/20000 val_loss:2.2789 val_bpb:1.3497 train_time:147639ms step_avg:67.11ms +step:2400/20000 train_loss:2.2474 train_time:161011ms step_avg:67.09ms +step:2400 shared0_alpha:mean=0.481,std=0.087 shared1_alpha:mean=0.510,std=0.079 shared2_alpha:mean=0.577,std=0.062 shared3_alpha:mean=0.593,std=0.064 eff_mlp_scale:[v0:104.2430 v1:65.2321 v2:68.9117 v3:70.7440 v4:65.4826 v5:89.6941 v6:77.6347 v7:75.5181 v8:68.8843 v9:82.9790 v10:64.1141 v11:63.7998 v12:62.9313 v13:166.9092] eff_attn_scale:[v0:0.7620 v1:4.7349 v2:8.2241 v3:7.3591 v4:7.3701 v5:6.8015 v6:6.9933 v7:6.3226 v8:6.3027 v9:5.7028 v10:5.5107 v11:4.5865 v12:4.4475 v13:2.4273] eff_attn_bias:[v0:0.3784 v1:0.4364 v2:0.4060 v3:0.3812 v4:0.4281 v5:0.4668 v6:0.4834 v7:0.4088 v8:0.4198 v9:0.4033 v10:0.3591 v11:0.2458 v12:0.2721 v13:0.4337] eff_mlp_bias:[v0:0.5718 v1:0.3480 v2:0.3867 v3:0.3094 v4:0.3356 v5:0.3784 v6:0.3536 v7:0.2762 v8:0.3370 v9:0.3480 v10:0.3094 v11:0.2652 v12:0.3439 v13:0.3674] unique_attn_gain_rms:[u0:0.9159 u1:1.0170 u2:0.9793 u3:0.8499 u4:0.9263 u5:0.8723 u6:0.8334 u7:0.7771 u8:0.8170 u9:0.7552 u10:0.6942 u11:0.6393] unique_mlp_gain_rms:[u0:2.1155 u1:2.0487 u2:1.9491 u3:1.8955 u4:1.8540 u5:1.8247 u6:1.7631 u7:1.7847 u8:1.8119 u9:1.7913 u10:1.7148 u11:1.7897] depth_emb_rms:[v0:0.7051 v1:0.6846 v2:0.4116 v3:0.4277 v4:0.3386 v5:0.3680 v6:0.4147 v7:0.3815 v8:0.2991 v9:0.3714 v10:0.3799 v11:0.3406 
v12:0.2889 v13:0.3747] +step:2400/20000 val_loss:2.2659 val_bpb:1.3420 train_time:161043ms step_avg:67.10ms +step:2600/20000 train_loss:2.4590 train_time:174414ms step_avg:67.08ms +step:2600 shared0_alpha:mean=0.480,std=0.088 shared1_alpha:mean=0.510,std=0.079 shared2_alpha:mean=0.580,std=0.063 shared3_alpha:mean=0.596,std=0.065 eff_mlp_scale:[v0:108.7369 v1:66.8617 v2:70.7068 v3:72.6276 v4:67.2445 v5:91.5714 v6:79.4903 v7:77.0027 v8:71.1240 v9:84.7884 v10:64.9976 v11:64.7523 v12:64.6582 v13:172.9921] eff_attn_scale:[v0:0.7201 v1:4.9883 v2:8.9002 v3:7.9247 v4:8.0545 v5:7.3598 v6:7.7214 v7:6.9205 v8:6.9269 v9:6.2967 v10:6.1594 v11:5.0750 v12:4.9401 v13:2.3883] eff_attn_bias:[v0:0.3977 v1:0.4585 v2:0.4143 v3:0.3950 v4:0.4447 v5:0.4861 v6:0.4999 v7:0.4143 v8:0.4309 v9:0.4198 v10:0.3757 v11:0.2527 v12:0.2845 v13:0.4613] eff_mlp_bias:[v0:0.5911 v1:0.3563 v2:0.4005 v3:0.3163 v4:0.3439 v5:0.3922 v6:0.3674 v7:0.2804 v8:0.3439 v9:0.3563 v10:0.3218 v11:0.2679 v12:0.3563 v13:0.3812] unique_attn_gain_rms:[u0:0.9685 u1:1.0551 u2:1.0132 u3:0.8698 u4:0.9675 u5:0.8997 u6:0.8614 u7:0.7974 u8:0.8502 u9:0.7707 u10:0.7064 u11:0.6637] unique_mlp_gain_rms:[u0:2.1877 u1:2.1225 u2:2.0126 u3:1.9549 u4:1.9144 u5:1.8825 u6:1.8136 u7:1.8436 u8:1.8664 u9:1.8520 u10:1.7664 u11:1.8449] depth_emb_rms:[v0:0.7551 v1:0.7345 v2:0.4249 v3:0.4481 v4:0.3520 v5:0.3811 v6:0.4301 v7:0.3974 v8:0.3058 v9:0.3810 v10:0.3921 v11:0.3546 v12:0.2961 v13:0.3913] +step:2600/20000 val_loss:2.2736 val_bpb:1.3466 train_time:174446ms step_avg:67.09ms +step:2800/20000 train_loss:2.2838 train_time:187816ms step_avg:67.08ms +step:2800 shared0_alpha:mean=0.480,std=0.089 shared1_alpha:mean=0.509,std=0.079 shared2_alpha:mean=0.582,std=0.064 shared3_alpha:mean=0.599,std=0.066 eff_mlp_scale:[v0:112.9033 v1:68.8840 v2:72.6508 v3:74.5388 v4:68.5782 v5:93.3110 v6:80.6247 v7:78.5083 v8:72.0505 v9:85.9829 v10:66.0059 v11:66.1587 v12:66.4080 v13:178.3525] eff_attn_scale:[v0:0.6965 v1:5.1834 v2:9.4315 v3:8.4979 v4:8.8436 v5:7.9875 
v6:8.3291 v7:7.5854 v8:7.7665 v9:6.8828 v10:6.7674 v11:5.6463 v12:5.4139 v13:2.2703] eff_attn_bias:[v0:0.4226 v1:0.4751 v2:0.4337 v3:0.4005 v4:0.4613 v5:0.4999 v6:0.5193 v7:0.4254 v8:0.4419 v9:0.4337 v10:0.3867 v11:0.2583 v12:0.2969 v13:0.4834] eff_mlp_bias:[v0:0.6077 v1:0.3646 v2:0.4143 v3:0.3218 v4:0.3494 v5:0.4033 v6:0.3812 v7:0.2873 v8:0.3508 v9:0.3618 v10:0.3342 v11:0.2748 v12:0.3701 v13:0.4005] unique_attn_gain_rms:[u0:1.0169 u1:1.0886 u2:1.0489 u3:0.8961 u4:1.0101 u5:0.9323 u6:0.8916 u7:0.8236 u8:0.8809 u9:0.7924 u10:0.7263 u11:0.6885] unique_mlp_gain_rms:[u0:2.2604 u1:2.1912 u2:2.0731 u3:2.0197 u4:1.9738 u5:1.9361 u6:1.8683 u7:1.8994 u8:1.9264 u9:1.9113 u10:1.8200 u11:1.9038] depth_emb_rms:[v0:0.7988 v1:0.7854 v2:0.4417 v3:0.4668 v4:0.3625 v5:0.3908 v6:0.4438 v7:0.4170 v8:0.3147 v9:0.3929 v10:0.4024 v11:0.3693 v12:0.3031 v13:0.4087] +step:2800/20000 val_loss:2.2519 val_bpb:1.3337 train_time:187847ms step_avg:67.09ms +step:3000/20000 train_loss:2.2761 train_time:201214ms step_avg:67.07ms +step:3000 shared0_alpha:mean=0.479,std=0.089 shared1_alpha:mean=0.508,std=0.080 shared2_alpha:mean=0.585,std=0.064 shared3_alpha:mean=0.601,std=0.066 eff_mlp_scale:[v0:117.8284 v1:70.4042 v2:73.9419 v3:75.9263 v4:70.4076 v5:96.0058 v6:81.9597 v7:79.9225 v8:73.9280 v9:88.1284 v10:66.8149 v11:67.0461 v12:68.2074 v13:183.6767] eff_attn_scale:[v0:0.6799 v1:5.3059 v2:10.0450 v3:8.9872 v4:9.4522 v5:8.5128 v6:8.9642 v7:8.1594 v8:8.4479 v9:7.4341 v10:7.2794 v11:6.1196 v12:5.8190 v13:2.1570] eff_attn_bias:[v0:0.4502 v1:0.4944 v2:0.4447 v3:0.4143 v4:0.4778 v5:0.5165 v6:0.5331 v7:0.4364 v8:0.4530 v9:0.4447 v10:0.4005 v11:0.2652 v12:0.3094 v13:0.5082] eff_mlp_bias:[v0:0.6242 v1:0.3729 v2:0.4281 v3:0.3315 v4:0.3563 v5:0.4088 v6:0.3922 v7:0.2942 v8:0.3563 v9:0.3701 v10:0.3453 v11:0.2804 v12:0.3839 v13:0.4171] unique_attn_gain_rms:[u0:1.0574 u1:1.1225 u2:1.0827 u3:0.9159 u4:1.0491 u5:0.9618 u6:0.9173 u7:0.8466 u8:0.9103 u9:0.8112 u10:0.7415 u11:0.7145] unique_mlp_gain_rms:[u0:2.3298 
u1:2.2602 u2:2.1350 u3:2.0779 u4:2.0307 u5:1.9895 u6:1.9161 u7:1.9540 u8:1.9850 u9:1.9668 u10:1.8723 u11:1.9600] depth_emb_rms:[v0:0.8479 v1:0.8331 v2:0.4569 v3:0.4859 v4:0.3767 v5:0.4022 v6:0.4542 v7:0.4311 v8:0.3220 v9:0.4024 v10:0.4149 v11:0.3828 v12:0.3111 v13:0.4256] +step:3000/20000 val_loss:2.2450 val_bpb:1.3296 train_time:201245ms step_avg:67.08ms +step:3200/20000 train_loss:2.2420 train_time:214604ms step_avg:67.06ms +step:3200 shared0_alpha:mean=0.478,std=0.091 shared1_alpha:mean=0.508,std=0.081 shared2_alpha:mean=0.587,std=0.065 shared3_alpha:mean=0.603,std=0.067 eff_mlp_scale:[v0:122.4764 v1:72.4408 v2:75.4662 v3:77.4303 v4:72.2509 v5:97.2493 v6:84.0011 v7:81.4585 v8:75.8188 v9:89.8067 v10:67.8298 v11:68.0312 v12:70.0209 v13:187.8929] eff_attn_scale:[v0:0.6479 v1:5.4616 v2:10.4125 v3:9.3653 v4:9.9548 v5:8.8454 v6:9.4363 v7:8.6355 v8:9.0945 v9:7.8956 v10:7.7768 v11:6.5679 v12:6.2678 v13:2.1350] eff_attn_bias:[v0:0.4806 v1:0.5110 v2:0.4585 v3:0.4254 v4:0.4972 v5:0.5303 v6:0.5469 v7:0.4447 v8:0.4613 v9:0.4585 v10:0.4116 v11:0.2707 v12:0.3176 v13:0.5303] eff_mlp_bias:[v0:0.6381 v1:0.3784 v2:0.4392 v3:0.3411 v4:0.3646 v5:0.4171 v6:0.4033 v7:0.2969 v8:0.3618 v9:0.3757 v10:0.3522 v11:0.2859 v12:0.3977 v13:0.4337] unique_attn_gain_rms:[u0:1.1080 u1:1.1512 u2:1.1083 u3:0.9351 u4:1.0847 u5:0.9881 u6:0.9422 u7:0.8708 u8:0.9426 u9:0.8336 u10:0.7578 u11:0.7457] unique_mlp_gain_rms:[u0:2.3995 u1:2.3307 u2:2.1950 u3:2.1389 u4:2.0860 u5:2.0445 u6:1.9644 u7:2.0096 u8:2.0386 u9:2.0217 u10:1.9229 u11:2.0150] depth_emb_rms:[v0:0.8952 v1:0.8801 v2:0.4690 v3:0.5022 v4:0.3894 v5:0.4138 v6:0.4657 v7:0.4434 v8:0.3280 v9:0.4115 v10:0.4253 v11:0.3932 v12:0.3183 v13:0.4410] +step:3200/20000 val_loss:2.2409 val_bpb:1.3272 train_time:214635ms step_avg:67.07ms +step:3400/20000 train_loss:2.2057 train_time:227985ms step_avg:67.05ms +step:3400 shared0_alpha:mean=0.477,std=0.091 shared1_alpha:mean=0.508,std=0.081 shared2_alpha:mean=0.589,std=0.066 shared3_alpha:mean=0.606,std=0.068 
eff_mlp_scale:[v0:126.3864 v1:74.4708 v2:76.7665 v3:79.3538 v4:74.2928 v5:99.4609 v6:84.8948 v7:82.9608 v8:77.9168 v9:90.9643 v10:69.0899 v11:69.4346 v12:71.5748 v13:192.2300] eff_attn_scale:[v0:0.6372 v1:5.6592 v2:11.0233 v3:9.7833 v4:10.5879 v5:9.3711 v6:10.0151 v7:9.1562 v8:9.8225 v9:8.3975 v10:8.3011 v11:7.0553 v12:6.6972 v13:2.0659] eff_attn_bias:[v0:0.5082 v1:0.5276 v2:0.4723 v3:0.4392 v4:0.5110 v5:0.5441 v6:0.5607 v7:0.4558 v8:0.4723 v9:0.4723 v10:0.4226 v11:0.2776 v12:0.3273 v13:0.5524] eff_mlp_bias:[v0:0.6546 v1:0.3839 v2:0.4530 v3:0.3494 v4:0.3701 v5:0.4254 v6:0.4143 v7:0.3011 v8:0.3674 v9:0.3812 v10:0.3646 v11:0.2928 v12:0.4060 v13:0.4475] unique_attn_gain_rms:[u0:1.1526 u1:1.1890 u2:1.1412 u3:0.9577 u4:1.1257 u5:1.0161 u6:0.9722 u7:0.8926 u8:0.9734 u9:0.8533 u10:0.7763 u11:0.7742] unique_mlp_gain_rms:[u0:2.4685 u1:2.3917 u2:2.2459 u3:2.1972 u4:2.1396 u5:2.1002 u6:2.0121 u7:2.0665 u8:2.0949 u9:2.0756 u10:1.9742 u11:2.0713] depth_emb_rms:[v0:0.9443 v1:0.9302 v2:0.4799 v3:0.5219 v4:0.4009 v5:0.4227 v6:0.4780 v7:0.4580 v8:0.3342 v9:0.4225 v10:0.4357 v11:0.4053 v12:0.3260 v13:0.4554] +step:3400/20000 val_loss:2.2366 val_bpb:1.3246 train_time:228017ms step_avg:67.06ms +step:3600/20000 train_loss:2.1742 train_time:241375ms step_avg:67.05ms +step:3600 shared0_alpha:mean=0.476,std=0.092 shared1_alpha:mean=0.507,std=0.082 shared2_alpha:mean=0.591,std=0.066 shared3_alpha:mean=0.607,std=0.068 eff_mlp_scale:[v0:131.4689 v1:75.9506 v2:78.6264 v3:80.2451 v4:76.1025 v5:101.0998 v6:86.3527 v7:83.8719 v8:79.7923 v9:92.5491 v10:69.9911 v11:70.2711 v12:73.7964 v13:197.3706] eff_attn_scale:[v0:0.6218 v1:5.7133 v2:11.3997 v3:10.1052 v4:11.1041 v5:9.7499 v6:10.5707 v7:9.7190 v8:10.4471 v9:8.8805 v10:8.8434 v11:7.4662 v12:7.0633 v13:2.0424] eff_attn_bias:[v0:0.5441 v1:0.5441 v2:0.4806 v3:0.4475 v4:0.5248 v5:0.5552 v6:0.5745 v7:0.4640 v8:0.4806 v9:0.4861 v10:0.4337 v11:0.2817 v12:0.3384 v13:0.5718] eff_mlp_bias:[v0:0.6657 v1:0.3922 v2:0.4613 v3:0.3563 v4:0.3784 v5:0.4337 
v6:0.4254 v7:0.3052 v8:0.3729 v9:0.3867 v10:0.3729 v11:0.2969 v12:0.4198 v13:0.4640] unique_attn_gain_rms:[u0:1.1935 u1:1.2158 u2:1.1628 u3:0.9782 u4:1.1609 u5:1.0417 u6:0.9996 u7:0.9137 u8:1.0046 u9:0.8744 u10:0.7930 u11:0.8041] unique_mlp_gain_rms:[u0:2.5351 u1:2.4546 u2:2.3016 u3:2.2592 u4:2.1959 u5:2.1543 u6:2.0600 u7:2.1221 u8:2.1457 u9:2.1272 u10:2.0232 u11:2.1276] depth_emb_rms:[v0:0.9979 v1:0.9827 v2:0.4945 v3:0.5353 v4:0.4103 v5:0.4357 v6:0.4908 v7:0.4741 v8:0.3425 v9:0.4307 v10:0.4464 v11:0.4167 v12:0.3322 v13:0.4697] +step:3600/20000 val_loss:2.2294 val_bpb:1.3204 train_time:241406ms step_avg:67.06ms +step:3800/20000 train_loss:2.2702 train_time:254758ms step_avg:67.04ms +step:3800 shared0_alpha:mean=0.475,std=0.093 shared1_alpha:mean=0.506,std=0.082 shared2_alpha:mean=0.592,std=0.067 shared3_alpha:mean=0.609,std=0.069 eff_mlp_scale:[v0:136.2708 v1:76.9222 v2:80.1373 v3:82.3251 v4:78.5488 v5:102.2256 v6:87.9221 v7:85.5266 v8:82.3117 v9:94.1285 v10:70.9787 v11:71.8058 v12:76.1971 v13:201.7538] eff_attn_scale:[v0:0.6281 v1:5.8432 v2:11.7394 v3:10.4439 v4:11.5610 v5:10.1703 v6:10.9662 v7:10.1155 v8:11.0201 v9:9.3491 v10:9.2791 v11:7.8493 v12:7.4369 v13:1.9870] eff_attn_bias:[v0:0.5718 v1:0.5635 v2:0.4944 v3:0.4585 v4:0.5386 v5:0.5690 v6:0.5856 v7:0.4723 v8:0.4944 v9:0.4972 v10:0.4419 v11:0.2900 v12:0.3480 v13:0.5939] eff_mlp_bias:[v0:0.6740 v1:0.3977 v2:0.4751 v3:0.3674 v4:0.3812 v5:0.4419 v6:0.4337 v7:0.3121 v8:0.3812 v9:0.3922 v10:0.3812 v11:0.3011 v12:0.4337 v13:0.4778] unique_attn_gain_rms:[u0:1.2346 u1:1.2342 u2:1.1911 u3:0.9965 u4:1.1951 u5:1.0672 u6:1.0268 u7:0.9361 u8:1.0328 u9:0.8909 u10:0.8122 u11:0.8301] unique_mlp_gain_rms:[u0:2.6020 u1:2.5162 u2:2.3569 u3:2.3136 u4:2.2521 u5:2.2054 u6:2.1095 u7:2.1758 u8:2.1994 u9:2.1812 u10:2.0718 u11:2.1811] depth_emb_rms:[v0:1.0520 v1:1.0311 v2:0.5073 v3:0.5524 v4:0.4241 v5:0.4422 v6:0.5016 v7:0.4864 v8:0.3522 v9:0.4416 v10:0.4553 v11:0.4294 v12:0.3399 v13:0.4864] +step:3800/20000 val_loss:2.2251 
val_bpb:1.3178 train_time:254790ms step_avg:67.05ms +step:4000/20000 train_loss:2.2099 train_time:268136ms step_avg:67.03ms +step:4000 shared0_alpha:mean=0.474,std=0.092 shared1_alpha:mean=0.506,std=0.082 shared2_alpha:mean=0.594,std=0.068 shared3_alpha:mean=0.611,std=0.070 eff_mlp_scale:[v0:140.4284 v1:79.2023 v2:81.4463 v3:83.2872 v4:81.2711 v5:104.2405 v6:88.8087 v7:86.9684 v8:84.6373 v9:96.0647 v10:72.2433 v11:73.1639 v12:78.8666 v13:206.1023] eff_attn_scale:[v0:0.6266 v1:5.8846 v2:12.0569 v3:10.6809 v4:12.1823 v5:10.4900 v6:11.3435 v7:10.4806 v8:11.6254 v9:9.7224 v10:9.5599 v11:8.2109 v12:7.8315 v13:1.9324] eff_attn_bias:[v0:0.6104 v1:0.5773 v2:0.4999 v3:0.4696 v4:0.5497 v5:0.5828 v6:0.5966 v7:0.4806 v8:0.4999 v9:0.5082 v10:0.4530 v11:0.2942 v12:0.3591 v13:0.6160] eff_mlp_bias:[v0:0.6878 v1:0.4060 v2:0.4834 v3:0.3757 v4:0.3867 v5:0.4502 v6:0.4447 v7:0.3135 v8:0.3839 v9:0.3977 v10:0.3867 v11:0.3066 v12:0.4447 v13:0.4944] unique_attn_gain_rms:[u0:1.2720 u1:1.2624 u2:1.2184 u3:1.0158 u4:1.2322 u5:1.0929 u6:1.0504 u7:0.9566 u8:1.0595 u9:0.9111 u10:0.8316 u11:0.8563] unique_mlp_gain_rms:[u0:2.6663 u1:2.5759 u2:2.4125 u3:2.3695 u4:2.3036 u5:2.2572 u6:2.1545 u7:2.2239 u8:2.2568 u9:2.2320 u10:2.1215 u11:2.2314] depth_emb_rms:[v0:1.1087 v1:1.0821 v2:0.5202 v3:0.5688 v4:0.4369 v5:0.4532 v6:0.5137 v7:0.5002 v8:0.3575 v9:0.4503 v10:0.4655 v11:0.4378 v12:0.3466 v13:0.5018] +step:4000/20000 val_loss:2.2189 val_bpb:1.3142 train_time:268167ms step_avg:67.04ms +step:4200/20000 train_loss:2.2242 train_time:281580ms step_avg:67.04ms +step:4200 shared0_alpha:mean=0.474,std=0.094 shared1_alpha:mean=0.506,std=0.083 shared2_alpha:mean=0.595,std=0.069 shared3_alpha:mean=0.613,std=0.071 eff_mlp_scale:[v0:145.2065 v1:80.7775 v2:82.8967 v3:85.2676 v4:84.2773 v5:106.5028 v6:90.3064 v7:88.0481 v8:87.7272 v9:97.2417 v10:73.1714 v11:74.1457 v12:81.8130 v13:210.0611] eff_attn_scale:[v0:0.6126 v1:6.0127 v2:12.4378 v3:10.9874 v4:12.6167 v5:10.8617 v6:11.7104 v7:10.8517 v8:12.1178 v9:10.0859 
v10:10.0375 v11:8.5457 v12:8.1617 v13:1.8999] eff_attn_bias:[v0:0.6381 v1:0.5939 v2:0.5110 v3:0.4806 v4:0.5607 v5:0.5911 v6:0.6077 v7:0.4917 v8:0.5110 v9:0.5193 v10:0.4640 v11:0.2983 v12:0.3674 v13:0.6353] eff_mlp_bias:[v0:0.6961 v1:0.4116 v2:0.4944 v3:0.3839 v4:0.3950 v5:0.4558 v6:0.4502 v7:0.3190 v8:0.3895 v9:0.4033 v10:0.3922 v11:0.3107 v12:0.4558 v13:0.5082] unique_attn_gain_rms:[u0:1.3098 u1:1.2868 u2:1.2433 u3:1.0272 u4:1.2643 u5:1.1166 u6:1.0741 u7:0.9782 u8:1.0887 u9:0.9334 u10:0.8507 u11:0.8861] unique_mlp_gain_rms:[u0:2.7277 u1:2.6368 u2:2.4648 u3:2.4232 u4:2.3591 u5:2.3074 u6:2.2030 u7:2.2753 u8:2.3047 u9:2.2826 u10:2.1667 u11:2.2846] depth_emb_rms:[v0:1.1682 v1:1.1325 v2:0.5320 v3:0.5844 v4:0.4479 v5:0.4639 v6:0.5242 v7:0.5123 v8:0.3653 v9:0.4585 v10:0.4757 v11:0.4472 v12:0.3529 v13:0.5163] +step:4200/20000 val_loss:2.2158 val_bpb:1.3123 train_time:281610ms step_avg:67.05ms +step:4400/20000 train_loss:2.1644 train_time:294946ms step_avg:67.03ms +step:4400 shared0_alpha:mean=0.473,std=0.094 shared1_alpha:mean=0.506,std=0.083 shared2_alpha:mean=0.597,std=0.069 shared3_alpha:mean=0.615,std=0.072 eff_mlp_scale:[v0:150.3239 v1:82.3338 v2:84.0132 v3:86.6764 v4:87.4132 v5:108.2249 v6:91.9478 v7:89.4724 v8:90.9501 v9:98.9041 v10:74.2117 v11:75.0263 v12:84.8868 v13:214.0000] eff_attn_scale:[v0:0.6051 v1:6.1025 v2:12.8097 v3:11.3290 v4:13.2078 v5:11.2207 v6:12.1433 v7:11.2599 v8:12.8409 v9:10.4989 v10:10.3662 v11:8.8422 v12:8.5484 v13:1.8815] eff_attn_bias:[v0:0.6712 v1:0.6104 v2:0.5193 v3:0.4889 v4:0.5718 v5:0.6049 v6:0.6132 v7:0.4999 v8:0.5193 v9:0.5303 v10:0.4723 v11:0.3038 v12:0.3757 v13:0.6546] eff_mlp_bias:[v0:0.7016 v1:0.4198 v2:0.5055 v3:0.3895 v4:0.4033 v5:0.4640 v6:0.4613 v7:0.3232 v8:0.3950 v9:0.4088 v10:0.3977 v11:0.3163 v12:0.4640 v13:0.5220] unique_attn_gain_rms:[u0:1.3434 u1:1.3134 u2:1.2692 u3:1.0471 u4:1.2952 u5:1.1455 u6:1.0996 u7:0.9982 u8:1.1206 u9:0.9526 u10:0.8705 u11:0.9136] unique_mlp_gain_rms:[u0:2.7931 u1:2.6967 u2:2.5179 u3:2.4761 
u4:2.4135 u5:2.3630 u6:2.2466 u7:2.3242 u8:2.3578 u9:2.3354 u10:2.2156 u11:2.3385] depth_emb_rms:[v0:1.2283 v1:1.1812 v2:0.5458 v3:0.5987 v4:0.4581 v5:0.4737 v6:0.5352 v7:0.5242 v8:0.3724 v9:0.4673 v10:0.4818 v11:0.4566 v12:0.3588 v13:0.5316] +step:4400/20000 val_loss:2.2162 val_bpb:1.3126 train_time:294979ms step_avg:67.04ms +step:4600/20000 train_loss:2.0237 train_time:308326ms step_avg:67.03ms +step:4600 shared0_alpha:mean=0.472,std=0.095 shared1_alpha:mean=0.505,std=0.084 shared2_alpha:mean=0.598,std=0.070 shared3_alpha:mean=0.615,std=0.072 eff_mlp_scale:[v0:154.9539 v1:84.4015 v2:85.8127 v3:88.1340 v4:90.2530 v5:109.4094 v6:93.3154 v7:90.9468 v8:94.4026 v9:100.0314 v10:75.0275 v11:75.9453 v12:88.1783 v13:218.7572] eff_attn_scale:[v0:0.6012 v1:6.1621 v2:13.1273 v3:11.5456 v4:13.7307 v5:11.4629 v6:12.5992 v7:11.6156 v8:13.4289 v9:10.8003 v10:10.7885 v11:9.1665 v12:8.9400 v13:1.8473] eff_attn_bias:[v0:0.7016 v1:0.6242 v2:0.5303 v3:0.5027 v4:0.5856 v5:0.6160 v6:0.6187 v7:0.5055 v8:0.5276 v9:0.5414 v10:0.4806 v11:0.3094 v12:0.3812 v13:0.6712] eff_mlp_bias:[v0:0.7182 v1:0.4254 v2:0.5193 v3:0.4005 v4:0.4060 v5:0.4668 v6:0.4668 v7:0.3259 v8:0.3977 v9:0.4116 v10:0.4033 v11:0.3204 v12:0.4751 v13:0.5359] unique_attn_gain_rms:[u0:1.3796 u1:1.3332 u2:1.2948 u3:1.0645 u4:1.3263 u5:1.1698 u6:1.1229 u7:1.0218 u8:1.1475 u9:0.9748 u10:0.8897 u11:0.9438] unique_mlp_gain_rms:[u0:2.8534 u1:2.7549 u2:2.5681 u3:2.5310 u4:2.4667 u5:2.4113 u6:2.3005 u7:2.3709 u8:2.4097 u9:2.3858 u10:2.2668 u11:2.3911] depth_emb_rms:[v0:1.2916 v1:1.2372 v2:0.5589 v3:0.6156 v4:0.4706 v5:0.4829 v6:0.5416 v7:0.5325 v8:0.3782 v9:0.4753 v10:0.4915 v11:0.4655 v12:0.3661 v13:0.5461] +step:4600/20000 val_loss:2.2123 val_bpb:1.3102 train_time:308357ms step_avg:67.03ms +step:4800/20000 train_loss:2.3092 train_time:321703ms step_avg:67.02ms +step:4800 shared0_alpha:mean=0.471,std=0.095 shared1_alpha:mean=0.505,std=0.084 shared2_alpha:mean=0.599,std=0.070 shared3_alpha:mean=0.617,std=0.073 
eff_mlp_scale:[v0:160.8699 v1:85.5185 v2:86.7074 v3:89.6100 v4:93.6992 v5:111.2265 v6:94.2472 v7:91.9682 v8:97.4259 v9:101.7828 v10:75.8690 v11:77.3476 v12:91.0373 v13:222.6999] eff_attn_scale:[v0:0.6054 v1:6.2332 v2:13.4047 v3:11.8250 v4:14.1188 v5:11.7292 v6:12.9451 v7:11.9675 v8:14.0416 v9:11.1930 v10:11.1833 v11:9.5455 v12:9.3354 v13:1.8286] eff_attn_bias:[v0:0.7347 v1:0.6381 v2:0.5386 v3:0.5110 v4:0.5966 v5:0.6298 v6:0.6298 v7:0.5165 v8:0.5359 v9:0.5524 v10:0.4917 v11:0.3135 v12:0.3895 v13:0.6850] eff_mlp_bias:[v0:0.7237 v1:0.4281 v2:0.5303 v3:0.4060 v4:0.4143 v5:0.4751 v6:0.4778 v7:0.3301 v8:0.4033 v9:0.4171 v10:0.4088 v11:0.3259 v12:0.4861 v13:0.5469] unique_attn_gain_rms:[u0:1.4119 u1:1.3531 u2:1.3161 u3:1.0834 u4:1.3574 u5:1.1955 u6:1.1432 u7:1.0428 u8:1.1735 u9:0.9906 u10:0.9088 u11:0.9704] unique_mlp_gain_rms:[u0:2.9156 u1:2.8194 u2:2.6168 u3:2.5803 u4:2.5217 u5:2.4557 u6:2.3404 u7:2.4236 u8:2.4602 u9:2.4363 u10:2.3147 u11:2.4421] depth_emb_rms:[v0:1.3469 v1:1.2946 v2:0.5690 v3:0.6286 v4:0.4798 v5:0.4921 v6:0.5554 v7:0.5469 v8:0.3847 v9:0.4848 v10:0.5008 v11:0.4747 v12:0.3714 v13:0.5596] +step:4800/20000 val_loss:2.2081 val_bpb:1.3078 train_time:321734ms step_avg:67.03ms +step:5000/20000 train_loss:2.0792 train_time:335081ms step_avg:67.02ms +step:5000 shared0_alpha:mean=0.470,std=0.096 shared1_alpha:mean=0.504,std=0.084 shared2_alpha:mean=0.600,std=0.071 shared3_alpha:mean=0.618,std=0.073 eff_mlp_scale:[v0:164.6056 v1:87.0563 v2:88.1845 v3:90.6317 v4:97.3823 v5:112.9094 v6:95.2961 v7:93.4788 v8:100.6648 v9:102.8847 v10:76.8058 v11:78.2944 v12:94.0997 v13:224.6166] eff_attn_scale:[v0:0.6099 v1:6.2109 v2:13.6880 v3:12.0773 v4:14.6677 v5:12.0824 v6:13.2240 v7:12.3666 v8:14.5888 v9:11.5393 v10:11.4453 v11:9.8354 v12:9.6602 v13:1.7871] eff_attn_bias:[v0:0.7679 v1:0.6546 v2:0.5469 v3:0.5220 v4:0.6077 v5:0.6381 v6:0.6408 v7:0.5220 v8:0.5414 v9:0.5607 v10:0.4999 v11:0.3204 v12:0.3950 v13:0.7016] eff_mlp_bias:[v0:0.7292 v1:0.4364 v2:0.5359 v3:0.4116 v4:0.4171 
v5:0.4834 v6:0.4834 v7:0.3370 v8:0.4088 v9:0.4198 v10:0.4171 v11:0.3301 v12:0.4944 v13:0.5607] unique_attn_gain_rms:[u0:1.4405 u1:1.3737 u2:1.3397 u3:1.0967 u4:1.3890 u5:1.2152 u6:1.1726 u7:1.0625 u8:1.2028 u9:1.0135 u10:0.9271 u11:0.9970] unique_mlp_gain_rms:[u0:2.9797 u1:2.8726 u2:2.6721 u3:2.6279 u4:2.5726 u5:2.5079 u6:2.3870 u7:2.4716 u8:2.5154 u9:2.4887 u10:2.3614 u11:2.4938] depth_emb_rms:[v0:1.4052 v1:1.3486 v2:0.5817 v3:0.6424 v4:0.4903 v5:0.5010 v6:0.5660 v7:0.5579 v8:0.3936 v9:0.4932 v10:0.5069 v11:0.4848 v12:0.3784 v13:0.5738] +step:5000/20000 val_loss:2.2022 val_bpb:1.3043 train_time:335112ms step_avg:67.02ms +step:5200/20000 train_loss:2.2214 train_time:348457ms step_avg:67.01ms +step:5200 shared0_alpha:mean=0.469,std=0.096 shared1_alpha:mean=0.503,std=0.084 shared2_alpha:mean=0.602,std=0.072 shared3_alpha:mean=0.619,std=0.075 eff_mlp_scale:[v0:170.3900 v1:88.2601 v2:89.6883 v3:92.7666 v4:100.4412 v5:114.8445 v6:96.8443 v7:94.6793 v8:104.3691 v9:104.7424 v10:77.7617 v11:79.3776 v12:97.6356 v13:230.3858] eff_attn_scale:[v0:0.6007 v1:6.3472 v2:13.9224 v3:12.3356 v4:15.1099 v5:12.4200 v6:13.6096 v7:12.5546 v8:15.1907 v9:11.8024 v10:11.7324 v11:10.0729 v12:10.0598 v13:1.7561] eff_attn_bias:[v0:0.8010 v1:0.6712 v2:0.5524 v3:0.5303 v4:0.6187 v5:0.6519 v6:0.6519 v7:0.5331 v8:0.5497 v9:0.5718 v10:0.5082 v11:0.3259 v12:0.4033 v13:0.7126] eff_mlp_bias:[v0:0.7403 v1:0.4392 v2:0.5469 v3:0.4198 v4:0.4254 v5:0.4889 v6:0.4917 v7:0.3411 v8:0.4143 v9:0.4281 v10:0.4226 v11:0.3328 v12:0.5055 v13:0.5718] unique_attn_gain_rms:[u0:1.4740 u1:1.3931 u2:1.3638 u3:1.1115 u4:1.4202 u5:1.2409 u6:1.1946 u7:1.0846 u8:1.2253 u9:1.0333 u10:0.9504 u11:1.0213] unique_mlp_gain_rms:[u0:3.0401 u1:2.9312 u2:2.7226 u3:2.6790 u4:2.6253 u5:2.5560 u6:2.4356 u7:2.5253 u8:2.5715 u9:2.5419 u10:2.4066 u11:2.5443] depth_emb_rms:[v0:1.4637 v1:1.4018 v2:0.5908 v3:0.6570 v4:0.5022 v5:0.5110 v6:0.5758 v7:0.5697 v8:0.4009 v9:0.5020 v10:0.5173 v11:0.4921 v12:0.3829 v13:0.5856] +step:5200/20000 
val_loss:2.2041 val_bpb:1.3054 train_time:348488ms step_avg:67.02ms +step:5400/20000 train_loss:2.2374 train_time:361828ms step_avg:67.01ms +step:5400 shared0_alpha:mean=0.468,std=0.097 shared1_alpha:mean=0.503,std=0.084 shared2_alpha:mean=0.603,std=0.072 shared3_alpha:mean=0.621,std=0.075 eff_mlp_scale:[v0:175.1869 v1:89.7173 v2:91.0651 v3:94.1565 v4:104.1415 v5:115.8849 v6:97.7751 v7:95.5977 v8:107.5937 v9:105.7383 v10:78.6035 v11:80.2252 v12:101.2647 v13:234.3855] eff_attn_scale:[v0:0.6055 v1:6.4525 v2:14.3625 v3:12.6056 v4:15.6985 v5:12.5599 v6:13.9658 v7:12.9763 v8:15.6985 v9:12.0768 v10:12.0614 v11:10.4552 v12:10.4106 v13:1.7733] eff_attn_bias:[v0:0.8342 v1:0.6850 v2:0.5607 v3:0.5414 v4:0.6298 v5:0.6602 v6:0.6629 v7:0.5386 v8:0.5552 v9:0.5828 v10:0.5193 v11:0.3301 v12:0.4088 v13:0.7292] eff_mlp_bias:[v0:0.7513 v1:0.4447 v2:0.5552 v3:0.4281 v4:0.4309 v5:0.4917 v6:0.4972 v7:0.3453 v8:0.4171 v9:0.4309 v10:0.4309 v11:0.3370 v12:0.5138 v13:0.5856] unique_attn_gain_rms:[u0:1.4961 u1:1.4252 u2:1.3851 u3:1.1310 u4:1.4454 u5:1.2632 u6:1.2192 u7:1.1025 u8:1.2532 u9:1.0516 u10:0.9689 u11:1.0485] unique_mlp_gain_rms:[u0:3.1008 u1:2.9892 u2:2.7706 u3:2.7314 u4:2.6753 u5:2.6072 u6:2.4825 u7:2.5715 u8:2.6189 u9:2.5909 u10:2.4563 u11:2.5978] depth_emb_rms:[v0:1.5242 v1:1.4599 v2:0.6018 v3:0.6706 v4:0.5131 v5:0.5215 v6:0.5836 v7:0.5809 v8:0.4076 v9:0.5090 v10:0.5262 v11:0.5011 v12:0.3892 v13:0.5984] +step:5400/20000 val_loss:2.1985 val_bpb:1.3021 train_time:361859ms step_avg:67.01ms +step:5600/20000 train_loss:2.2331 train_time:375205ms step_avg:67.00ms +step:5600 shared0_alpha:mean=0.467,std=0.098 shared1_alpha:mean=0.503,std=0.084 shared2_alpha:mean=0.604,std=0.073 shared3_alpha:mean=0.622,std=0.076 eff_mlp_scale:[v0:180.4128 v1:91.2860 v2:92.4975 v3:95.2382 v4:107.9700 v5:117.5978 v6:99.2421 v7:97.1719 v8:111.5100 v9:107.3953 v10:79.4901 v11:81.7018 v12:104.4300 v13:236.3976] eff_attn_scale:[v0:0.5962 v1:6.4495 v2:14.5829 v3:12.7937 v4:16.1981 v5:12.7604 v6:14.1822 
v7:13.3174 v8:16.2825 v9:12.3443 v10:12.3393 v11:10.5492 v12:10.7566 v13:1.7377] eff_attn_bias:[v0:0.8673 v1:0.7016 v2:0.5718 v3:0.5524 v4:0.6463 v5:0.6740 v6:0.6712 v7:0.5441 v8:0.5635 v9:0.5911 v10:0.5248 v11:0.3342 v12:0.4143 v13:0.7458] eff_mlp_bias:[v0:0.7568 v1:0.4502 v2:0.5635 v3:0.4337 v4:0.4337 v5:0.4972 v6:0.5055 v7:0.3480 v8:0.4226 v9:0.4364 v10:0.4364 v11:0.3425 v12:0.5248 v13:0.5966] unique_attn_gain_rms:[u0:1.5184 u1:1.4435 u2:1.4043 u3:1.1487 u4:1.4780 u5:1.2822 u6:1.2391 u7:1.1235 u8:1.2805 u9:1.0707 u10:0.9828 u11:1.0718] unique_mlp_gain_rms:[u0:3.1596 u1:3.0472 u2:2.8233 u3:2.7841 u4:2.7305 u5:2.6553 u6:2.5264 u7:2.6209 u8:2.6758 u9:2.6433 u10:2.5029 u11:2.6472] depth_emb_rms:[v0:1.5798 v1:1.5151 v2:0.6135 v3:0.6817 v4:0.5226 v5:0.5296 v6:0.5909 v7:0.5917 v8:0.4130 v9:0.5165 v10:0.5343 v11:0.5096 v12:0.3955 v13:0.6117] +step:5600/20000 val_loss:2.1982 val_bpb:1.3019 train_time:375236ms step_avg:67.01ms +step:5800/20000 train_loss:2.1946 train_time:388579ms step_avg:67.00ms +step:5800 shared0_alpha:mean=0.466,std=0.099 shared1_alpha:mean=0.503,std=0.086 shared2_alpha:mean=0.604,std=0.073 shared3_alpha:mean=0.623,std=0.076 eff_mlp_scale:[v0:186.6605 v1:92.9960 v2:93.9131 v3:96.6201 v4:110.8836 v5:118.9484 v6:100.6903 v7:98.5623 v8:114.4994 v9:109.2162 v10:80.3586 v11:82.0543 v12:107.8705 v13:240.2290] eff_attn_scale:[v0:0.5966 v1:6.5518 v2:14.7781 v3:13.0900 v4:16.7923 v5:13.1036 v6:14.4550 v7:13.5440 v8:16.8784 v9:12.6831 v10:12.5977 v11:10.8201 v12:11.1087 v13:1.7058] eff_attn_bias:[v0:0.9005 v1:0.7182 v2:0.5828 v3:0.5607 v4:0.6546 v5:0.6878 v6:0.6822 v7:0.5524 v8:0.5690 v9:0.6021 v10:0.5331 v11:0.3397 v12:0.4198 v13:0.7568] eff_mlp_bias:[v0:0.7679 v1:0.4585 v2:0.5718 v3:0.4392 v4:0.4364 v5:0.5027 v6:0.5110 v7:0.3536 v8:0.4254 v9:0.4392 v10:0.4419 v11:0.3466 v12:0.5331 v13:0.6132] unique_attn_gain_rms:[u0:1.5496 u1:1.4574 u2:1.4292 u3:1.1614 u4:1.5014 u5:1.3026 u6:1.2651 u7:1.1403 u8:1.3061 u9:1.0903 u10:1.0060 u11:1.0985] 
unique_mlp_gain_rms:[u0:3.2190 u1:3.1023 u2:2.8708 u3:2.8366 u4:2.7812 u5:2.7029 u6:2.5731 u7:2.6740 u8:2.7230 u9:2.6904 u10:2.5498 u11:2.7009] depth_emb_rms:[v0:1.6376 v1:1.5766 v2:0.6255 v3:0.6953 v4:0.5322 v5:0.5357 v6:0.6010 v7:0.6011 v8:0.4220 v9:0.5248 v10:0.5432 v11:0.5173 v12:0.4015 v13:0.6261] +step:5800/20000 val_loss:2.1953 val_bpb:1.3002 train_time:388611ms step_avg:67.00ms +step:6000/20000 train_loss:2.2669 train_time:401949ms step_avg:66.99ms +step:6000 shared0_alpha:mean=0.465,std=0.099 shared1_alpha:mean=0.502,std=0.085 shared2_alpha:mean=0.605,std=0.074 shared3_alpha:mean=0.624,std=0.077 eff_mlp_scale:[v0:191.6657 v1:94.0286 v2:95.4397 v3:98.1194 v4:114.4309 v5:120.6610 v6:101.7699 v7:99.5839 v8:118.1222 v9:110.3342 v10:81.3185 v11:83.4747 v12:110.7396 v13:244.0273] eff_attn_scale:[v0:0.5916 v1:6.5112 v2:15.0567 v3:13.2640 v4:17.1669 v5:13.3744 v6:14.8112 v7:13.7976 v8:17.3421 v9:12.8816 v10:12.9291 v11:11.0533 v12:11.2987 v13:1.7161] eff_attn_bias:[v0:0.9336 v1:0.7347 v2:0.5911 v3:0.5690 v4:0.6657 v5:0.7016 v6:0.6850 v7:0.5580 v8:0.5773 v9:0.6104 v10:0.5386 v11:0.3425 v12:0.4281 v13:0.7734] eff_mlp_bias:[v0:0.7734 v1:0.4613 v2:0.5745 v3:0.4475 v4:0.4447 v5:0.5082 v6:0.5193 v7:0.3591 v8:0.4309 v9:0.4447 v10:0.4475 v11:0.3494 v12:0.5414 v13:0.6215] unique_attn_gain_rms:[u0:1.5758 u1:1.4753 u2:1.4487 u3:1.1774 u4:1.5345 u5:1.3230 u6:1.2874 u7:1.1579 u8:1.3301 u9:1.1115 u10:1.0269 u11:1.1220] unique_mlp_gain_rms:[u0:3.2767 u1:3.1601 u2:2.9142 u3:2.8844 u4:2.8311 u5:2.7553 u6:2.6127 u7:2.7173 u8:2.7724 u9:2.7385 u10:2.5986 u11:2.7484] depth_emb_rms:[v0:1.6899 v1:1.6357 v2:0.6349 v3:0.7052 v4:0.5441 v5:0.5463 v6:0.6105 v7:0.6126 v8:0.4292 v9:0.5336 v10:0.5509 v11:0.5246 v12:0.4066 v13:0.6375] +step:6000/20000 val_loss:2.1929 val_bpb:1.2987 train_time:401980ms step_avg:67.00ms +step:6200/20000 train_loss:2.1374 train_time:415326ms step_avg:66.99ms +step:6200 shared0_alpha:mean=0.465,std=0.099 shared1_alpha:mean=0.502,std=0.086 
shared2_alpha:mean=0.606,std=0.074 shared3_alpha:mean=0.625,std=0.078 eff_mlp_scale:[v0:195.9977 v1:95.6756 v2:96.5008 v3:99.7806 v4:118.4333 v5:122.4647 v6:103.3587 v7:100.7636 v8:121.5831 v9:111.5304 v10:82.2951 v11:84.5432 v12:114.0236 v13:246.1414] eff_attn_scale:[v0:0.6004 v1:6.5547 v2:15.2350 v3:13.5409 v4:17.6516 v5:13.6055 v6:15.0703 v7:14.1599 v8:17.9204 v9:13.1803 v10:13.0939 v11:11.3744 v12:11.7379 v13:1.6740] eff_attn_bias:[v0:0.9667 v1:0.7458 v2:0.5994 v3:0.5773 v4:0.6767 v5:0.7071 v6:0.6933 v7:0.5662 v8:0.5828 v9:0.6215 v10:0.5469 v11:0.3480 v12:0.4337 v13:0.7844] eff_mlp_bias:[v0:0.7844 v1:0.4723 v2:0.5856 v3:0.4530 v4:0.4502 v5:0.5138 v6:0.5248 v7:0.3646 v8:0.4364 v9:0.4475 v10:0.4502 v11:0.3536 v12:0.5497 v13:0.6325] unique_attn_gain_rms:[u0:1.6004 u1:1.4963 u2:1.4647 u3:1.1925 u4:1.5644 u5:1.3435 u6:1.3058 u7:1.1748 u8:1.3546 u9:1.1328 u10:1.0451 u11:1.1492] unique_mlp_gain_rms:[u0:3.3334 u1:3.2166 u2:2.9668 u3:2.9295 u4:2.8819 u5:2.8012 u6:2.6533 u7:2.7653 u8:2.8220 u9:2.7874 u10:2.6459 u11:2.7958] depth_emb_rms:[v0:1.7446 v1:1.6981 v2:0.6485 v3:0.7218 v4:0.5535 v5:0.5568 v6:0.6193 v7:0.6213 v8:0.4354 v9:0.5415 v10:0.5579 v11:0.5309 v12:0.4110 v13:0.6503] +step:6200/20000 val_loss:2.1933 val_bpb:1.2990 train_time:415358ms step_avg:66.99ms +step:6400/20000 train_loss:2.2148 train_time:428695ms step_avg:66.98ms +step:6400 shared0_alpha:mean=0.463,std=0.100 shared1_alpha:mean=0.502,std=0.086 shared2_alpha:mean=0.607,std=0.075 shared3_alpha:mean=0.626,std=0.078 eff_mlp_scale:[v0:201.3470 v1:96.6114 v2:98.0988 v3:100.7911 v4:121.4938 v5:124.0579 v6:104.0143 v7:102.2733 v8:124.7079 v9:112.5304 v10:83.3101 v11:85.4748 v12:117.6369 v13:250.1441] eff_attn_scale:[v0:0.5912 v1:6.5872 v2:15.5507 v3:13.5973 v4:18.0965 v5:13.7442 v6:15.3012 v7:14.2966 v8:18.5535 v9:13.3881 v10:13.3054 v11:11.4995 v12:12.0644 v13:1.6847] eff_attn_bias:[v0:0.9944 v1:0.7568 v2:0.6077 v3:0.5856 v4:0.6878 v5:0.7182 v6:0.7043 v7:0.5745 v8:0.5911 v9:0.6298 v10:0.5552 v11:0.3536 
v12:0.4392 v13:0.8010] eff_mlp_bias:[v0:0.7955 v1:0.4751 v2:0.5911 v3:0.4585 v4:0.4558 v5:0.5193 v6:0.5303 v7:0.3674 v8:0.4392 v9:0.4530 v10:0.4558 v11:0.3591 v12:0.5552 v13:0.6436] unique_attn_gain_rms:[u0:1.6224 u1:1.5136 u2:1.4815 u3:1.2067 u4:1.5903 u5:1.3632 u6:1.3283 u7:1.1948 u8:1.3766 u9:1.1499 u10:1.0630 u11:1.1726] unique_mlp_gain_rms:[u0:3.3922 u1:3.2745 u2:3.0112 u3:2.9786 u4:2.9333 u5:2.8475 u6:2.7009 u7:2.8120 u8:2.8740 u9:2.8379 u10:2.6945 u11:2.8487] depth_emb_rms:[v0:1.7970 v1:1.7633 v2:0.6590 v3:0.7339 v4:0.5625 v5:0.5664 v6:0.6280 v7:0.6304 v8:0.4423 v9:0.5491 v10:0.5662 v11:0.5402 v12:0.4174 v13:0.6626] +step:6400/20000 val_loss:2.1887 val_bpb:1.2963 train_time:428727ms step_avg:66.99ms +step:6600/20000 train_loss:2.1765 train_time:442068ms step_avg:66.98ms +step:6600 shared0_alpha:mean=0.463,std=0.101 shared1_alpha:mean=0.502,std=0.087 shared2_alpha:mean=0.609,std=0.076 shared3_alpha:mean=0.628,std=0.079 eff_mlp_scale:[v0:206.4972 v1:98.8238 v2:99.0504 v3:102.2136 v4:124.7575 v5:125.3240 v6:104.9934 v7:103.2060 v8:128.0406 v9:113.7302 v10:84.1928 v11:86.3358 v12:120.8178 v13:253.9664] eff_attn_scale:[v0:0.5917 v1:6.6428 v2:15.7364 v3:13.8038 v4:18.5908 v5:14.0038 v6:15.4853 v7:14.5097 v8:18.9626 v9:13.7165 v10:13.5601 v11:11.6862 v12:12.3629 v13:1.6633] eff_attn_bias:[v0:1.0165 v1:0.7734 v2:0.6160 v3:0.5939 v4:0.6988 v5:0.7292 v6:0.7126 v7:0.5773 v8:0.5966 v9:0.6381 v10:0.5607 v11:0.3591 v12:0.4475 v13:0.8176] eff_mlp_bias:[v0:0.8010 v1:0.4806 v2:0.6021 v3:0.4640 v4:0.4585 v5:0.5248 v6:0.5386 v7:0.3701 v8:0.4419 v9:0.4558 v10:0.4613 v11:0.3646 v12:0.5662 v13:0.6546] unique_attn_gain_rms:[u0:1.6440 u1:1.5326 u2:1.4999 u3:1.2232 u4:1.6204 u5:1.3812 u6:1.3456 u7:1.2150 u8:1.4010 u9:1.1656 u10:1.0811 u11:1.1952] unique_mlp_gain_rms:[u0:3.4453 u1:3.3234 u2:3.0608 u3:3.0286 u4:2.9807 u5:2.8975 u6:2.7448 u7:2.8634 u8:2.9251 u9:2.8866 u10:2.7385 u11:2.8952] depth_emb_rms:[v0:1.8499 v1:1.8320 v2:0.6718 v3:0.7476 v4:0.5712 v5:0.5744 v6:0.6363 v7:0.6409 
v8:0.4487 v9:0.5549 v10:0.5730 v11:0.5477 v12:0.4232 v13:0.6752] +step:6600/20000 val_loss:2.1852 val_bpb:1.2942 train_time:442100ms step_avg:66.98ms +step:6800/20000 train_loss:2.2466 train_time:455439ms step_avg:66.98ms +step:6800 shared0_alpha:mean=0.461,std=0.100 shared1_alpha:mean=0.501,std=0.087 shared2_alpha:mean=0.609,std=0.076 shared3_alpha:mean=0.628,std=0.079 eff_mlp_scale:[v0:210.4566 v1:99.8884 v2:100.0387 v3:103.1215 v4:128.4731 v5:126.5253 v6:106.0111 v7:104.1178 v8:131.8187 v9:115.4266 v10:85.1075 v11:87.1800 v12:123.7892 v13:257.9493] eff_attn_scale:[v0:0.5859 v1:6.7024 v2:15.8890 v3:14.1308 v4:19.0840 v5:14.2743 v6:15.8044 v7:14.7659 v8:19.6537 v9:13.9844 v10:13.7761 v11:11.9080 v12:12.7227 v13:1.6468] eff_attn_bias:[v0:1.0496 v1:0.7900 v2:0.6242 v3:0.6049 v4:0.7126 v5:0.7403 v6:0.7182 v7:0.5856 v8:0.6049 v9:0.6463 v10:0.5690 v11:0.3591 v12:0.4530 v13:0.8286] eff_mlp_bias:[v0:0.8176 v1:0.4861 v2:0.6077 v3:0.4723 v4:0.4668 v5:0.5303 v6:0.5469 v7:0.3729 v8:0.4475 v9:0.4613 v10:0.4696 v11:0.3701 v12:0.5745 v13:0.6657] unique_attn_gain_rms:[u0:1.6642 u1:1.5462 u2:1.5201 u3:1.2416 u4:1.6476 u5:1.4024 u6:1.3641 u7:1.2298 u8:1.4226 u9:1.1790 u10:1.0948 u11:1.2157] unique_mlp_gain_rms:[u0:3.5019 u1:3.3759 u2:3.1087 u3:3.0757 u4:3.0320 u5:2.9419 u6:2.7900 u7:2.9067 u8:2.9728 u9:2.9338 u10:2.7833 u11:2.9446] depth_emb_rms:[v0:1.8984 v1:1.9036 v2:0.6806 v3:0.7607 v4:0.5820 v5:0.5845 v6:0.6455 v7:0.6514 v8:0.4560 v9:0.5647 v10:0.5814 v11:0.5566 v12:0.4279 v13:0.6887] +step:6800/20000 val_loss:2.1837 val_bpb:1.2933 train_time:455470ms step_avg:66.98ms +step:7000/20000 train_loss:2.2750 train_time:468808ms step_avg:66.97ms +step:7000 shared0_alpha:mean=0.460,std=0.101 shared1_alpha:mean=0.501,std=0.087 shared2_alpha:mean=0.610,std=0.077 shared3_alpha:mean=0.629,std=0.080 eff_mlp_scale:[v0:216.3601 v1:100.9060 v2:101.6218 v3:104.5856 v4:131.7294 v5:128.2230 v6:107.6290 v7:105.0860 v8:135.1420 v9:116.5157 v10:86.1032 v11:88.0721 v12:127.6341 v13:260.0441] 
eff_attn_scale:[v0:0.5898 v1:6.7260 v2:16.3439 v3:14.3240 v4:19.5029 v5:14.4700 v6:16.0872 v7:15.0442 v8:20.1787 v9:14.1792 v10:14.0335 v11:12.0834 v12:13.1306 v13:1.6449] eff_attn_bias:[v0:1.0828 v1:0.8010 v2:0.6325 v3:0.6104 v4:0.7182 v5:0.7458 v6:0.7292 v7:0.5939 v8:0.6104 v9:0.6574 v10:0.5745 v11:0.3674 v12:0.4558 v13:0.8397] eff_mlp_bias:[v0:0.8286 v1:0.4889 v2:0.6187 v3:0.4751 v4:0.4696 v5:0.5359 v6:0.5524 v7:0.3784 v8:0.4530 v9:0.4640 v10:0.4723 v11:0.3757 v12:0.5828 v13:0.6740] unique_attn_gain_rms:[u0:1.6811 u1:1.5679 u2:1.5436 u3:1.2544 u4:1.6776 u5:1.4185 u6:1.3866 u7:1.2498 u8:1.4503 u9:1.1959 u10:1.1131 u11:1.2430] unique_mlp_gain_rms:[u0:3.5549 u1:3.4270 u2:3.1550 u3:3.1228 u4:3.0782 u5:2.9876 u6:2.8267 u7:2.9495 u8:3.0212 u9:2.9817 u10:2.8302 u11:2.9915] depth_emb_rms:[v0:1.9507 v1:1.9755 v2:0.6902 v3:0.7737 v4:0.5890 v5:0.5929 v6:0.6539 v7:0.6606 v8:0.4638 v9:0.5716 v10:0.5882 v11:0.5638 v12:0.4348 v13:0.6998] +step:7000/20000 val_loss:2.1816 val_bpb:1.2920 train_time:468839ms step_avg:66.98ms +step:7200/20000 train_loss:2.2503 train_time:482177ms step_avg:66.97ms +step:7200 shared0_alpha:mean=0.460,std=0.101 shared1_alpha:mean=0.501,std=0.087 shared2_alpha:mean=0.610,std=0.078 shared3_alpha:mean=0.631,std=0.081 eff_mlp_scale:[v0:220.1476 v1:101.8805 v2:102.5046 v3:105.7019 v4:135.0623 v5:129.3099 v6:108.5343 v7:106.7086 v8:138.5433 v9:117.5544 v10:86.9279 v11:89.0916 v12:130.8851 v13:264.0561] eff_attn_scale:[v0:0.5853 v1:6.7123 v2:16.4054 v3:14.4472 v4:19.9782 v5:14.6649 v6:16.3196 v7:15.2543 v8:20.6671 v9:14.3001 v10:14.1722 v11:12.2680 v12:13.4828 v13:1.6256] eff_attn_bias:[v0:1.1159 v1:0.8121 v2:0.6381 v3:0.6187 v4:0.7347 v5:0.7568 v6:0.7347 v7:0.6021 v8:0.6160 v9:0.6684 v10:0.5828 v11:0.3701 v12:0.4640 v13:0.8563] eff_mlp_bias:[v0:0.8397 v1:0.4944 v2:0.6242 v3:0.4834 v4:0.4723 v5:0.5414 v6:0.5580 v7:0.3812 v8:0.4558 v9:0.4668 v10:0.4751 v11:0.3784 v12:0.5883 v13:0.6850] unique_attn_gain_rms:[u0:1.6977 u1:1.5777 u2:1.5575 u3:1.2684 u4:1.6996 
u5:1.4384 u6:1.4069 u7:1.2664 u8:1.4683 u9:1.2151 u10:1.1297 u11:1.2658] unique_mlp_gain_rms:[u0:3.6118 u1:3.4803 u2:3.1998 u3:3.1694 u4:3.1261 u5:3.0323 u6:2.8740 u7:2.9946 u8:3.0692 u9:3.0243 u10:2.8757 u11:3.0401] depth_emb_rms:[v0:1.9990 v1:2.0508 v2:0.7002 v3:0.7841 v4:0.5997 v5:0.6002 v6:0.6619 v7:0.6696 v8:0.4677 v9:0.5791 v10:0.5948 v11:0.5716 v12:0.4405 v13:0.7109] +step:7200/20000 val_loss:2.1836 val_bpb:1.2933 train_time:482208ms step_avg:66.97ms +step:7400/20000 train_loss:2.1706 train_time:495549ms step_avg:66.97ms +step:7400 shared0_alpha:mean=0.459,std=0.101 shared1_alpha:mean=0.500,std=0.087 shared2_alpha:mean=0.611,std=0.078 shared3_alpha:mean=0.632,std=0.081 eff_mlp_scale:[v0:224.8569 v1:103.6028 v2:104.1825 v3:107.1893 v4:138.3795 v5:131.1926 v6:110.2514 v7:108.2005 v8:142.6374 v9:119.3684 v10:87.9988 v11:90.5042 v12:134.1217 v13:265.7776] eff_attn_scale:[v0:0.5806 v1:6.7193 v2:16.6461 v3:14.6038 v4:20.4912 v5:14.9073 v6:16.4727 v7:15.4152 v8:21.1909 v9:14.6135 v10:14.3919 v11:12.4133 v12:13.6941 v13:1.6149] eff_attn_bias:[v0:1.1435 v1:0.8231 v2:0.6463 v3:0.6270 v4:0.7403 v5:0.7679 v6:0.7458 v7:0.6104 v8:0.6270 v9:0.6740 v10:0.5883 v11:0.3729 v12:0.4696 v13:0.8673] eff_mlp_bias:[v0:0.8563 v1:0.4999 v2:0.6325 v3:0.4917 v4:0.4778 v5:0.5469 v6:0.5662 v7:0.3839 v8:0.4640 v9:0.4723 v10:0.4778 v11:0.3812 v12:0.5966 v13:0.6961] unique_attn_gain_rms:[u0:1.7088 u1:1.5986 u2:1.5735 u3:1.2840 u4:1.7289 u5:1.4561 u6:1.4251 u7:1.2814 u8:1.4991 u9:1.2307 u10:1.1456 u11:1.2842] unique_mlp_gain_rms:[u0:3.6651 u1:3.5302 u2:3.2428 u3:3.2123 u4:3.1753 u5:3.0746 u6:2.9163 u7:3.0371 u8:3.1178 u9:3.0738 u10:2.9201 u11:3.0843] depth_emb_rms:[v0:2.0477 v1:2.1320 v2:0.7103 v3:0.7986 v4:0.6108 v5:0.6100 v6:0.6743 v7:0.6805 v8:0.4751 v9:0.5885 v10:0.6027 v11:0.5772 v12:0.4458 v13:0.7219] +step:7400/20000 val_loss:2.1791 val_bpb:1.2906 train_time:495580ms step_avg:66.97ms +step:7600/20000 train_loss:2.0486 train_time:509017ms step_avg:66.98ms +step:7600 
shared0_alpha:mean=0.458,std=0.102 shared1_alpha:mean=0.500,std=0.088 shared2_alpha:mean=0.612,std=0.079 shared3_alpha:mean=0.633,std=0.082 eff_mlp_scale:[v0:229.3473 v1:104.6783 v2:105.1785 v3:108.1413 v4:142.2222 v5:132.4039 v6:111.2758 v7:109.1568 v8:145.8319 v9:120.5215 v10:88.9190 v11:91.3871 v12:137.8905 v13:269.5320] eff_attn_scale:[v0:0.5915 v1:6.6970 v2:16.7678 v3:14.8717 v4:21.0383 v5:15.0130 v6:16.7678 v7:15.6071 v8:21.6482 v9:14.7922 v10:14.5845 v11:12.5838 v12:14.0256 v13:1.6112] eff_attn_bias:[v0:1.1711 v1:0.8397 v2:0.6546 v3:0.6353 v4:0.7568 v5:0.7734 v6:0.7458 v7:0.6160 v8:0.6325 v9:0.6822 v10:0.5966 v11:0.3784 v12:0.4751 v13:0.8839] eff_mlp_bias:[v0:0.8728 v1:0.4999 v2:0.6381 v3:0.4972 v4:0.4834 v5:0.5497 v6:0.5718 v7:0.3867 v8:0.4668 v9:0.4751 v10:0.4834 v11:0.3867 v12:0.6021 v13:0.7043] unique_attn_gain_rms:[u0:1.7290 u1:1.6080 u2:1.5931 u3:1.2947 u4:1.7529 u5:1.4742 u6:1.4454 u7:1.2987 u8:1.5182 u9:1.2484 u10:1.1651 u11:1.3037] unique_mlp_gain_rms:[u0:3.7185 u1:3.5830 u2:3.2893 u3:3.2603 u4:3.2195 u5:3.1163 u6:2.9611 u7:3.0858 u8:3.1701 u9:3.1226 u10:2.9614 u11:3.1333] depth_emb_rms:[v0:2.0973 v1:2.2141 v2:0.7181 v3:0.8077 v4:0.6197 v5:0.6193 v6:0.6797 v7:0.6880 v8:0.4811 v9:0.5937 v10:0.6104 v11:0.5862 v12:0.4515 v13:0.7332] +step:7600/20000 val_loss:2.1773 val_bpb:1.2895 train_time:509048ms step_avg:66.98ms +step:7800/20000 train_loss:2.1997 train_time:522389ms step_avg:66.97ms +step:7800 shared0_alpha:mean=0.458,std=0.102 shared1_alpha:mean=0.499,std=0.088 shared2_alpha:mean=0.612,std=0.079 shared3_alpha:mean=0.633,std=0.083 eff_mlp_scale:[v0:232.5761 v1:106.2553 v2:106.6845 v3:109.7986 v4:145.8485 v5:134.0976 v6:112.2995 v7:110.3093 v8:150.2682 v9:121.5969 v10:89.8396 v11:92.4351 v12:141.4289 v13:273.5926] eff_attn_scale:[v0:0.5816 v1:6.7191 v2:17.1089 v3:15.0060 v4:21.5785 v5:15.2944 v6:17.0207 v7:15.9129 v8:22.3047 v9:15.0717 v10:14.8160 v11:12.7798 v12:14.5240 v13:1.6334] eff_attn_bias:[v0:1.2043 v1:0.8563 v2:0.6657 v3:0.6463 v4:0.7623 
v5:0.7789 v6:0.7568 v7:0.6215 v8:0.6381 v9:0.6905 v10:0.6021 v11:0.3812 v12:0.4806 v13:0.8894] eff_mlp_bias:[v0:0.8894 v1:0.5055 v2:0.6491 v3:0.5055 v4:0.4889 v5:0.5524 v6:0.5773 v7:0.3895 v8:0.4723 v9:0.4751 v10:0.4861 v11:0.3895 v12:0.6104 v13:0.7126] unique_attn_gain_rms:[u0:1.7412 u1:1.6236 u2:1.6057 u3:1.3111 u4:1.7812 u5:1.4953 u6:1.4626 u7:1.3141 u8:1.5416 u9:1.2633 u10:1.1782 u11:1.3275] unique_mlp_gain_rms:[u0:3.7653 u1:3.6280 u2:3.3328 u3:3.3074 u4:3.2691 u5:3.1612 u6:3.0030 u7:3.1272 u8:3.2189 u9:3.1688 u10:3.0065 u11:3.1762] depth_emb_rms:[v0:2.1412 v1:2.2946 v2:0.7297 v3:0.8216 v4:0.6302 v5:0.6267 v6:0.6862 v7:0.6976 v8:0.4884 v9:0.6023 v10:0.6160 v11:0.5913 v12:0.4560 v13:0.7450] +step:7800/20000 val_loss:2.1747 val_bpb:1.2880 train_time:522419ms step_avg:66.98ms +step:8000/20000 train_loss:2.1583 train_time:535763ms step_avg:66.97ms +step:8000 shared0_alpha:mean=0.457,std=0.102 shared1_alpha:mean=0.499,std=0.087 shared2_alpha:mean=0.614,std=0.079 shared3_alpha:mean=0.635,std=0.083 eff_mlp_scale:[v0:236.5387 v1:108.0674 v2:107.8843 v3:111.0637 v4:150.5066 v5:135.5131 v6:114.0492 v7:111.5778 v8:154.2692 v9:122.9338 v10:90.9311 v11:93.5814 v12:145.2389 v13:277.8133] eff_attn_scale:[v0:0.5794 v1:6.7329 v2:17.3548 v3:15.2788 v4:22.1015 v5:15.5604 v6:17.1768 v7:16.1972 v8:22.9475 v9:15.3360 v10:15.1298 v11:13.0246 v12:14.9106 v13:1.6255] eff_attn_bias:[v0:1.2319 v1:0.8673 v2:0.6712 v3:0.6574 v4:0.7734 v5:0.7900 v6:0.7623 v7:0.6270 v8:0.6436 v9:0.6961 v10:0.6049 v11:0.3867 v12:0.4889 v13:0.9005] eff_mlp_bias:[v0:0.8949 v1:0.5110 v2:0.6519 v3:0.5165 v4:0.4917 v5:0.5580 v6:0.5800 v7:0.3950 v8:0.4751 v9:0.4778 v10:0.4889 v11:0.3895 v12:0.6187 v13:0.7237] unique_attn_gain_rms:[u0:1.7512 u1:1.6381 u2:1.6204 u3:1.3228 u4:1.7981 u5:1.5086 u6:1.4773 u7:1.3292 u8:1.5611 u9:1.2805 u10:1.1871 u11:1.3465] unique_mlp_gain_rms:[u0:3.8082 u1:3.6679 u2:3.3692 u3:3.3457 u4:3.3121 u5:3.1970 u6:3.0384 u7:3.1673 u8:3.2595 u9:3.2099 u10:3.0428 u11:3.2157] 
depth_emb_rms:[v0:2.1791 v1:2.3638 v2:0.7388 v3:0.8303 v4:0.6424 v5:0.6353 v6:0.6953 v7:0.7039 v8:0.4945 v9:0.6110 v10:0.6202 v11:0.5975 v12:0.4602 v13:0.7554] +step:8000/20000 val_loss:2.1646 val_bpb:1.2820 train_time:535795ms step_avg:66.97ms +step:8200/20000 train_loss:2.2183 train_time:549142ms step_avg:66.97ms +step:8200 shared0_alpha:mean=0.457,std=0.103 shared1_alpha:mean=0.499,std=0.088 shared2_alpha:mean=0.614,std=0.080 shared3_alpha:mean=0.634,std=0.084 eff_mlp_scale:[v0:239.6668 v1:109.3551 v2:109.5404 v3:112.0673 v4:153.6738 v5:137.5572 v6:114.7074 v7:112.5837 v8:156.7320 v9:124.3195 v10:91.4559 v11:93.9919 v12:148.3220 v13:280.3171] eff_attn_scale:[v0:0.5764 v1:6.7191 v2:17.5320 v3:15.4470 v4:22.5962 v5:15.8540 v6:17.5320 v7:16.4600 v8:23.5646 v9:15.6275 v10:15.3742 v11:13.3368 v12:15.1717 v13:1.6557] eff_attn_bias:[v0:1.2430 v1:0.8673 v2:0.6767 v3:0.6602 v4:0.7789 v5:0.8010 v6:0.7623 v7:0.6325 v8:0.6491 v9:0.7016 v10:0.6077 v11:0.3895 v12:0.4917 v13:0.9060] eff_mlp_bias:[v0:0.9005 v1:0.5110 v2:0.6519 v3:0.5165 v4:0.4944 v5:0.5635 v6:0.5856 v7:0.3977 v8:0.4806 v9:0.4806 v10:0.4917 v11:0.3950 v12:0.6242 v13:0.7292] unique_attn_gain_rms:[u0:1.7576 u1:1.6468 u2:1.6268 u3:1.3316 u4:1.8117 u5:1.5171 u6:1.4843 u7:1.3384 u8:1.5708 u9:1.2876 u10:1.1934 u11:1.3598] unique_mlp_gain_rms:[u0:3.8403 u1:3.6946 u2:3.3972 u3:3.3735 u4:3.3396 u5:3.2258 u6:3.0609 u7:3.1910 u8:3.2877 u9:3.2366 u10:3.0693 u11:3.2441] depth_emb_rms:[v0:2.2065 v1:2.4172 v2:0.7451 v3:0.8367 v4:0.6471 v5:0.6401 v6:0.7035 v7:0.7136 v8:0.5002 v9:0.6198 v10:0.6263 v11:0.6017 v12:0.4661 v13:0.7650] +step:8200/20000 val_loss:2.1557 val_bpb:1.2767 train_time:549173ms step_avg:66.97ms +step:8400/20000 train_loss:2.1625 train_time:562577ms step_avg:66.97ms +step:8400 shared0_alpha:mean=0.457,std=0.103 shared1_alpha:mean=0.499,std=0.088 shared2_alpha:mean=0.614,std=0.080 shared3_alpha:mean=0.635,std=0.084 eff_mlp_scale:[v0:240.8477 v1:109.9236 v2:110.1829 v3:113.2649 v4:155.7668 v5:138.2723 
v6:115.8999 v7:113.2649 v8:158.8667 v9:124.9658 v10:91.9923 v11:94.5606 v12:151.1171 v13:282.6439] eff_attn_scale:[v0:0.5841 v1:6.7621 v2:17.7993 v3:15.6660 v4:22.9443 v5:16.1075 v6:17.7993 v7:16.6025 v8:23.9276 v9:15.7276 v10:15.5289 v11:13.5374 v12:15.4054 v13:1.6799] eff_attn_bias:[v0:1.2540 v1:0.8728 v2:0.6767 v3:0.6602 v4:0.7789 v5:0.8010 v6:0.7679 v7:0.6353 v8:0.6519 v9:0.7071 v10:0.6104 v11:0.3922 v12:0.4944 v13:0.9115] eff_mlp_bias:[v0:0.9005 v1:0.5165 v2:0.6574 v3:0.5193 v4:0.4944 v5:0.5662 v6:0.5911 v7:0.4005 v8:0.4806 v9:0.4834 v10:0.4944 v11:0.3977 v12:0.6298 v13:0.7347] unique_attn_gain_rms:[u0:1.7620 u1:1.6473 u2:1.6333 u3:1.3377 u4:1.8209 u5:1.5264 u6:1.4902 u7:1.3448 u8:1.5778 u9:1.2902 u10:1.1972 u11:1.3708] unique_mlp_gain_rms:[u0:3.8581 u1:3.7131 u2:3.4148 u3:3.3913 u4:3.3594 u5:3.2455 u6:3.0787 u7:3.2079 u8:3.3091 u9:3.2549 u10:3.0851 u11:3.2597] depth_emb_rms:[v0:2.2219 v1:2.4470 v2:0.7514 v3:0.8452 v4:0.6521 v5:0.6455 v6:0.7081 v7:0.7180 v8:0.5046 v9:0.6222 v10:0.6298 v11:0.6061 v12:0.4697 v13:0.7720] +step:8400/20000 val_loss:2.1462 val_bpb:1.2711 train_time:562608ms step_avg:66.98ms +step:8600/20000 train_loss:2.1546 train_time:575948ms step_avg:66.97ms +step:8600 shared0_alpha:mean=0.457,std=0.102 shared1_alpha:mean=0.499,std=0.087 shared2_alpha:mean=0.614,std=0.080 shared3_alpha:mean=0.635,std=0.084 eff_mlp_scale:[v0:241.7603 v1:110.4992 v2:111.3030 v3:114.0114 v4:158.2322 v5:138.9963 v6:116.5285 v7:114.0114 v8:161.3656 v9:125.6201 v10:93.0138 v11:95.7068 v12:152.7489 v13:284.5168] eff_attn_scale:[v0:0.5914 v1:6.7739 v2:17.8120 v3:15.7398 v4:23.1425 v5:16.1355 v6:17.8120 v7:16.7663 v8:24.2445 v9:15.7550 v10:15.5400 v11:13.5157 v12:15.5385 v13:1.7088] eff_attn_bias:[v0:1.2595 v1:0.8784 v2:0.6795 v3:0.6629 v4:0.7789 v5:0.8010 v6:0.7679 v7:0.6381 v8:0.6546 v9:0.7071 v10:0.6132 v11:0.3922 v12:0.4972 v13:0.9170] eff_mlp_bias:[v0:0.9060 v1:0.5193 v2:0.6602 v3:0.5193 v4:0.4972 v5:0.5690 v6:0.5911 v7:0.4005 v8:0.4834 v9:0.4834 v10:0.4944 
v11:0.4005 v12:0.6298 v13:0.7403] unique_attn_gain_rms:[u0:1.7612 u1:1.6497 u2:1.6340 u3:1.3404 u4:1.8216 u5:1.5275 u6:1.4939 u7:1.3485 u8:1.5807 u9:1.2904 u10:1.1977 u11:1.3780] unique_mlp_gain_rms:[u0:3.8672 u1:3.7229 u2:3.4217 u3:3.3996 u4:3.3651 u5:3.2561 u6:3.0864 u7:3.2186 u8:3.3175 u9:3.2639 u10:3.0962 u11:3.2702] depth_emb_rms:[v0:2.2282 v1:2.4656 v2:0.7550 v3:0.8480 v4:0.6541 v5:0.6465 v6:0.7116 v7:0.7221 v8:0.5076 v9:0.6248 v10:0.6347 v11:0.6098 v12:0.4712 v13:0.7771] +step:8600/20000 val_loss:2.1362 val_bpb:1.2652 train_time:575978ms step_avg:66.97ms +step:8800/20000 train_loss:2.1182 train_time:589318ms step_avg:66.97ms +step:8800 shared0_alpha:mean=0.457,std=0.102 shared1_alpha:mean=0.499,std=0.087 shared2_alpha:mean=0.614,std=0.080 shared3_alpha:mean=0.635,std=0.084 eff_mlp_scale:[v0:242.0745 v1:110.9861 v2:111.7642 v3:114.4149 v4:158.9466 v5:139.6089 v6:117.0113 v7:114.4149 v8:162.0941 v9:126.1737 v10:93.3992 v11:96.0456 v12:153.4386 v13:285.8616] eff_attn_scale:[v0:0.5859 v1:6.7636 v2:17.7491 v3:15.7356 v4:23.2204 v5:16.1109 v6:17.8401 v7:16.7618 v8:24.3262 v9:15.7310 v10:15.5646 v11:13.5121 v12:15.4803 v13:1.7189] eff_attn_bias:[v0:1.2595 v1:0.8784 v2:0.6795 v3:0.6629 v4:0.7789 v5:0.8010 v6:0.7679 v7:0.6381 v8:0.6546 v9:0.7071 v10:0.6104 v11:0.3950 v12:0.4972 v13:0.9170] eff_mlp_bias:[v0:0.9005 v1:0.5193 v2:0.6602 v3:0.5220 v4:0.4972 v5:0.5690 v6:0.5911 v7:0.4033 v8:0.4834 v9:0.4834 v10:0.4972 v11:0.4005 v12:0.6325 v13:0.7403] unique_attn_gain_rms:[u0:1.7584 u1:1.6487 u2:1.6362 u3:1.3410 u4:1.8240 u5:1.5271 u6:1.4930 u7:1.3489 u8:1.5805 u9:1.2916 u10:1.1962 u11:1.3792] unique_mlp_gain_rms:[u0:3.8699 u1:3.7243 u2:3.4252 u3:3.4020 u4:3.3674 u5:3.2594 u6:3.0892 u7:3.2230 u8:3.3219 u9:3.2680 u10:3.0998 u11:3.2744] depth_emb_rms:[v0:2.2305 v1:2.4721 v2:0.7561 v3:0.8506 v4:0.6559 v5:0.6497 v6:0.7141 v7:0.7224 v8:0.5093 v9:0.6276 v10:0.6371 v11:0.6115 v12:0.4726 v13:0.7789] +step:8800/20000 val_loss:2.1285 val_bpb:1.2606 train_time:589348ms 
step_avg:66.97ms +step:8961/20000 val_loss:2.1231 val_bpb:1.2574 train_time:600056ms step_avg:66.96ms +stopping_early: wallclock_cap train_time:600056ms step:8961/20000 +peak memory allocated: 16289 MiB reserved: 16472 MiB +Serialized model: 45287039 bytes +Code size: 63793 bytes +Total submission size: 45350832 bytes +Serialized model int8+zlib: 10806903 bytes (payload:11745472 raw_torch:11778157 payload_ratio:3.85x) +Total submission size int8+zlib: 10870696 bytes +final_int8_zlib_roundtrip val_loss:2.1348 val_bpb:1.2643 eval_time:2126ms +final_int8_zlib_roundtrip_exact val_loss:2.13479541 val_bpb:1.26434609 diff --git a/records/track_non_record_16mb/2026-03-26_RecurrenceFix_3Loop_Birkhoff_OutputLN_TimestepScale/logs/s4_R.txt b/records/track_non_record_16mb/2026-03-26_RecurrenceFix_3Loop_Birkhoff_OutputLN_TimestepScale/logs/s4_R.txt new file mode 100644 index 0000000000..91f1738b4f --- /dev/null +++ b/records/track_non_record_16mb/2026-03-26_RecurrenceFix_3Loop_Birkhoff_OutputLN_TimestepScale/logs/s4_R.txt @@ -0,0 +1,1697 @@ +""" +The `train_gpt.py` and `train_gpt_mlx.py` scripts are intended as good launching-off points for new participants, not SOTA configs. We'll accept PRs that tune, improve, or simplify these scripts without significantly increasing complexity, but competitive submissions should stay in the `/records` folder. + +Hard stop: To keep readable for newcomers, let's make sure `train_gpt.py` and `train_gpt_mlx.py` never are longer than 1500 lines. 
+""" + +from __future__ import annotations + +import copy +import glob +import io +import math +import os +import random +import subprocess +import sys +import time +import uuid +import zlib +from pathlib import Path + +import numpy as np +import sentencepiece as spm +import torch +import torch.distributed as dist +import torch.nn.functional as F +from torch import Tensor, nn +from torch.nn.parallel import DistributedDataParallel as DDP + +# Flash Attention 3 via HuggingFace kernels (Hopper-optimized) +try: + from kernels import get_kernel + cap = torch.cuda.get_device_capability() + repo = "varunneal/flash-attention-3" if cap == (9, 0) else "kernels-community/flash-attn3" + fa3 = get_kernel(repo).flash_attn_interface + FA3_AVAILABLE = True +except Exception: + FA3_AVAILABLE = False + +# ----------------------------- +# HYPERPARAMETERS +# ----------------------------- +# Default Simple Baseline run: +# - 9 transformer blocks at width 512 +# - 8 attention heads with 4 KV heads (GQA) and 2x MLP expansion +# - vocab size 1024, sequence length 1024, tied embeddings +# - 524,288 train tokens per step for 20,000 iterations with a ~10 minute cap + +class Hyperparameters: + # Data paths are shard globs produced by the existing preprocessing pipeline. + data_path = os.environ.get("DATA_PATH", "./data/datasets/fineweb10B_sp1024") + train_files = os.path.join(data_path, "fineweb_train_*.bin") + val_files = os.path.join(data_path, "fineweb_val_*.bin") + tokenizer_path = os.environ.get("TOKENIZER_PATH", "./data/tokenizers/fineweb_1024_bpe.model") + run_id = os.environ.get("RUN_ID", str(uuid.uuid4())) + seed = int(os.environ.get("SEED", 1337)) + + # Validation cadence and batch size. Validation always uses the full fineweb_val split. + val_batch_size = int(os.environ.get("VAL_BATCH_SIZE", 524_288)) + val_loss_every = int(os.environ.get("VAL_LOSS_EVERY", 1000)) + train_log_every = int(os.environ.get("TRAIN_LOG_EVERY", 200)) + + # Training length. 
+ iterations = int(os.environ.get("ITERATIONS", 20000)) + warmdown_iters = int(os.environ.get("WARMDOWN_ITERS", 1200)) + warmup_steps = int(os.environ.get("WARMUP_STEPS", 20)) + train_batch_tokens = int(os.environ.get("TRAIN_BATCH_TOKENS", 524_288)) + train_seq_len = int(os.environ.get("TRAIN_SEQ_LEN", 1024)) + max_wallclock_seconds = float(os.environ.get("MAX_WALLCLOCK_SECONDS", 600.0)) + qk_gain_init = float(os.environ.get("QK_GAIN_INIT", 1.5)) + + # Model shape. + vocab_size = int(os.environ.get("VOCAB_SIZE", 1024)) + num_layers = int(os.environ.get("NUM_LAYERS", 9)) + num_kv_heads = int(os.environ.get("NUM_KV_HEADS", 4)) + model_dim = int(os.environ.get("MODEL_DIM", 512)) + num_heads = int(os.environ.get("NUM_HEADS", 8)) + mlp_mult = int(os.environ.get("MLP_MULT", 2)) + tie_embeddings = bool(int(os.environ.get("TIE_EMBEDDINGS", "1"))) + rope_base = float(os.environ.get("ROPE_BASE", 10000.0)) + logit_softcap = float(os.environ.get("LOGIT_SOFTCAP", 30.0)) + + # Recurrence + stability features. + num_prelude = int(os.environ.get("NUM_PRELUDE", 0)) + num_coda = int(os.environ.get("NUM_CODA", 0)) + num_shared = int(os.environ.get("NUM_SHARED", 0)) + num_loops = int(os.environ.get("NUM_LOOPS", 2)) + use_peri_norm = bool(int(os.environ.get("USE_PERI_NORM", "0"))) + use_birkhoff_mix = bool(int(os.environ.get("USE_BIRKHOFF_MIX", "0"))) + use_timestep_scale = bool(int(os.environ.get("USE_TIMESTEP_SCALE", "0"))) + leaky_relu_slope = float(os.environ.get("LEAKY_RELU_SLOPE", "0.5")) + timestep_gamma_max = float(os.environ.get("TIMESTEP_GAMMA_MAX", "0.0")) + use_timestep_bias = bool(int(os.environ.get("USE_TIMESTEP_BIAS", "0"))) + use_depth_embed = bool(int(os.environ.get("USE_DEPTH_EMBED", "0"))) + use_unique_norms = bool(int(os.environ.get("USE_UNIQUE_NORMS", "0"))) + disable_compile = bool(int(os.environ.get("DISABLE_COMPILE", "0"))) + + # Optimizer hyperparameters. 
+ embed_lr = float(os.environ.get("EMBED_LR", 0.6)) + head_lr = float(os.environ.get("HEAD_LR", 0.008)) + tied_embed_lr = float(os.environ.get("TIED_EMBED_LR", 0.05)) + tied_embed_init_std = float(os.environ.get("TIED_EMBED_INIT_STD", 0.005)) + matrix_lr = float(os.environ.get("MATRIX_LR", 0.04)) + scalar_lr = float(os.environ.get("SCALAR_LR", 0.04)) + muon_momentum = float(os.environ.get("MUON_MOMENTUM", 0.95)) + muon_backend_steps = int(os.environ.get("MUON_BACKEND_STEPS", 5)) + muon_momentum_warmup_start = float(os.environ.get("MUON_MOMENTUM_WARMUP_START", 0.85)) + muon_momentum_warmup_steps = int(os.environ.get("MUON_MOMENTUM_WARMUP_STEPS", 500)) + beta1 = float(os.environ.get("BETA1", 0.9)) + beta2 = float(os.environ.get("BETA2", 0.95)) + adam_eps = float(os.environ.get("ADAM_EPS", 1e-8)) + grad_clip_norm = float(os.environ.get("GRAD_CLIP_NORM", 0.0)) + +# ----------------------------- +# MUON OPTIMIZER +# ----------------------------- +# +# As borrowed from modded-nanogpt +# Background on Muon: https://kellerjordan.github.io/posts/muon/ + +def zeropower_via_newtonschulz5(G: Tensor, steps: int = 10, eps: float = 1e-7) -> Tensor: + # Orthogonalize a 2D update matrix with a fast Newton-Schulz iteration. + # Muon uses this to normalize matrix-shaped gradients before applying them. 
+ a, b, c = (3.4445, -4.7750, 2.0315) + X = G.bfloat16() + X /= X.norm() + eps + transposed = G.size(0) > G.size(1) + if transposed: + X = X.T + for _ in range(steps): + A = X @ X.T + B = b * A + c * A @ A + X = a * X + B @ X + return X.T if transposed else X + + +class Muon(torch.optim.Optimizer): + def __init__(self, params, lr: float, momentum: float, backend_steps: int, nesterov: bool = True): + super().__init__( + params, + dict(lr=lr, momentum=momentum, backend_steps=backend_steps, nesterov=nesterov), + ) + + @torch.no_grad() + def step(self, closure=None): + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + distributed = dist.is_available() and dist.is_initialized() + world_size = dist.get_world_size() if distributed else 1 + rank = dist.get_rank() if distributed else 0 + + for group in self.param_groups: + params = group["params"] + if not params: + continue + lr = group["lr"] + momentum = group["momentum"] + backend_steps = group["backend_steps"] + nesterov = group["nesterov"] + + total_params = sum(int(p.numel()) for p in params) + updates_flat = torch.zeros(total_params, device=params[0].device, dtype=torch.bfloat16) + + curr = 0 + for i, p in enumerate(params): + if i % world_size == rank and p.grad is not None: + g = p.grad + state = self.state[p] + if "momentum_buffer" not in state: + state["momentum_buffer"] = torch.zeros_like(g) + buf = state["momentum_buffer"] + buf.mul_(momentum).add_(g) + if nesterov: + g = g.add(buf, alpha=momentum) + g = zeropower_via_newtonschulz5(g, steps=backend_steps) + # Scale correction from Muon reference implementations. 
+ g *= max(1, g.size(0) / g.size(1)) ** 0.5 + updates_flat[curr : curr + p.numel()] = g.reshape(-1) + curr += p.numel() + + if distributed: + dist.all_reduce(updates_flat, op=dist.ReduceOp.SUM) + + curr = 0 + for p in params: + g = updates_flat[curr : curr + p.numel()].view_as(p).to(dtype=p.dtype) + p.add_(g, alpha=-lr) + curr += p.numel() + + return loss + + +# ----------------------------- +# TOKENIZER-AGNOSTIC EVALUATION SETUP +# ----------------------------- +# +# It's common for small models have a large fraction of their parameters be embeddings, since the 2 * d_model * d_vocab vectors can be gigantic. +# Instead of locking the tokenizer, we let you bring your own and calculate our validation metrics on the average compression of the validation set. +# We calculate BPB (bits-per-byte) instead of validation loss, so we need methods to count the number of bits per token in the tokenizer. +# Note: Submissions that edit the tokenizer will be examined more carefully, since screwing this up might unjustly improve your score. 

def build_sentencepiece_luts(
    sp: spm.SentencePieceProcessor, vocab_size: int, device: torch.device
) -> tuple[Tensor, Tensor, Tensor]:
    """Build per-token lookup tables used to count UTF-8 bytes during eval.

    Returns three tensors of length max(sp.vocab_size(), vocab_size):
    - base_bytes (int16): UTF-8 byte length of the token's piece (space marker stripped)
    - has_leading_space (bool): piece starts with the SentencePiece "▁" marker
    - is_boundary_token (bool): control/unknown/unused ids, which suppress the
      leading-space byte of the following token
    """
    sp_vocab_size = int(sp.vocab_size())
    table_size = max(sp_vocab_size, vocab_size)
    base_bytes_np = np.zeros((table_size,), dtype=np.int16)
    has_leading_space_np = np.zeros((table_size,), dtype=np.bool_)
    is_boundary_token_np = np.ones((table_size,), dtype=np.bool_)
    for token_id in range(sp_vocab_size):
        # Control/unknown/unused tokens contribute no bytes and act as boundaries.
        if sp.is_control(token_id) or sp.is_unknown(token_id) or sp.is_unused(token_id):
            continue
        is_boundary_token_np[token_id] = False
        if sp.is_byte(token_id):
            # Byte-fallback tokens are exactly one byte.
            base_bytes_np[token_id] = 1
            continue
        piece = sp.id_to_piece(token_id)
        if piece.startswith("▁"):
            # The "▁" marker encodes a leading space; count it separately so it
            # can be dropped after boundary tokens.
            has_leading_space_np[token_id] = True
            piece = piece[1:]
        base_bytes_np[token_id] = len(piece.encode("utf-8"))
    return (
        torch.tensor(base_bytes_np, dtype=torch.int16, device=device),
        torch.tensor(has_leading_space_np, dtype=torch.bool, device=device),
        torch.tensor(is_boundary_token_np, dtype=torch.bool, device=device),
    )


def load_validation_tokens(pattern: str, seq_len: int) -> Tensor:
    """Concatenate all validation shards matching `pattern` and trim so the
    result holds a whole number of seq_len sequences plus one shift token.

    Raises FileNotFoundError if the glob matches nothing, ValueError if the
    split is shorter than one sequence.
    """
    files = [Path(p) for p in sorted(glob.glob(pattern))]
    if not files:
        raise FileNotFoundError(f"No files found for pattern: {pattern}")
    # The export pipeline writes the fixed first-50k-doc validation set to fineweb_val_*.
    tokens = torch.cat([load_data_shard(file) for file in files]).contiguous()
    usable = ((tokens.numel() - 1) // seq_len) * seq_len
    if usable <= 0:
        raise ValueError(f"Validation split is too short for TRAIN_SEQ_LEN={seq_len}")
    # "+1" so (x, y) pairs can be built by shifting one position.
    return tokens[: usable + 1]


def eval_val(
    args: Hyperparameters,
    model: nn.Module,
    rank: int,
    world_size: int,
    device: torch.device,
    grad_accum_steps: int,
    val_tokens: Tensor,
    base_bytes_lut: Tensor,
    has_leading_space_lut: Tensor,
    is_boundary_token_lut: Tensor,
) -> tuple[float, float]:
    # Validation computes two metrics:
    # - val_loss: token cross-entropy (natural log)
    # - val_bpb: tokenizer-agnostic compression metric used by the challenge
    # Each rank evaluates a disjoint contiguous slice of the validation
    # sequences; sums are all-reduced at the end. The model is restored to
    # train mode before returning.
    local_batch_tokens = args.val_batch_size // (world_size * grad_accum_steps)
    if local_batch_tokens < args.train_seq_len:
        raise ValueError(
            "VAL_BATCH_SIZE must provide at least one sequence per rank; "
            f"got VAL_BATCH_SIZE={args.val_batch_size}, WORLD_SIZE={world_size}, "
            f"GRAD_ACCUM_STEPS={grad_accum_steps}, TRAIN_SEQ_LEN={args.train_seq_len}"
        )
    local_batch_seqs = local_batch_tokens // args.train_seq_len
    total_seqs = (val_tokens.numel() - 1) // args.train_seq_len
    # Balanced split of sequences across ranks (differs by at most one).
    seq_start = (total_seqs * rank) // world_size
    seq_end = (total_seqs * (rank + 1)) // world_size
    # Accumulate in float64 so many bf16 batch losses sum accurately.
    val_loss_sum = torch.zeros((), device=device, dtype=torch.float64)
    val_token_count = torch.zeros((), device=device, dtype=torch.float64)
    val_byte_count = torch.zeros((), device=device, dtype=torch.float64)

    model.eval()
    with torch.inference_mode():
        for batch_seq_start in range(seq_start, seq_end, local_batch_seqs):
            batch_seq_end = min(batch_seq_start + local_batch_seqs, seq_end)
            raw_start = batch_seq_start * args.train_seq_len
            raw_end = batch_seq_end * args.train_seq_len + 1
            local = val_tokens[raw_start:raw_end].to(device=device, dtype=torch.int64, non_blocking=True)
            x = local[:-1].reshape(-1, args.train_seq_len)
            y = local[1:].reshape(-1, args.train_seq_len)
            with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True):
                batch_loss = model(x, y).detach()
            batch_token_count = float(y.numel())
            val_loss_sum += batch_loss.to(torch.float64) * batch_token_count
            val_token_count += batch_token_count
            # Byte accounting for BPB: each target token contributes its base
            # UTF-8 bytes, plus one byte for a leading space unless the
            # preceding token is a boundary (control/unknown/unused) token.
            prev_ids = x.reshape(-1)
            tgt_ids = y.reshape(-1)
            token_bytes = base_bytes_lut[tgt_ids].to(dtype=torch.int16)
            token_bytes += (has_leading_space_lut[tgt_ids] & ~is_boundary_token_lut[prev_ids]).to(dtype=torch.int16)
            val_byte_count += token_bytes.to(torch.float64).sum()

    if dist.is_available() and dist.is_initialized():
        dist.all_reduce(val_loss_sum, op=dist.ReduceOp.SUM)
        dist.all_reduce(val_token_count, op=dist.ReduceOp.SUM)
        dist.all_reduce(val_byte_count, op=dist.ReduceOp.SUM)

    val_loss = val_loss_sum / val_token_count
    # bits/token * tokens/byte = bits per byte (BPB).
    bits_per_token = val_loss.item() / math.log(2.0)
    tokens_per_byte = val_token_count.item() / val_byte_count.item()
    model.train()
    return float(val_loss.item()), float(bits_per_token * tokens_per_byte)

# -----------------------------
# POST-TRAINING QUANTIZATION
# -----------------------------
#
# It's silly to export our model, which is trained in bf16 and fp32, at that same precision.
# Instead, we get approximately the same model (with a small hit) by quantizing the model to int8 & zlib compressing.
# We can then decompress the model and run in higher precision for evaluation, after closing in under the size limit.
+ +CONTROL_TENSOR_NAME_PATTERNS = tuple( + pattern + for pattern in os.environ.get( + "CONTROL_TENSOR_NAME_PATTERNS", + "attn_scale,attn_scales,mlp_scale,mlp_scales,resid_mix,resid_mixes,q_gain,skip_weight,skip_weights,ts_attn,ts_mlp,resid_mix_logit,ts_attn_beta,ts_mlp_beta,depth_embed,unique_attn_gain,unique_mlp_gain", + ).split(",") + if pattern +) +INT8_KEEP_FLOAT_FP32_NAME_PATTERNS = tuple( + pattern + for pattern in os.environ.get( + "INT8_KEEP_FLOAT_FP32_NAME_PATTERNS", + ",".join(CONTROL_TENSOR_NAME_PATTERNS), + ).split(",") + if pattern +) +INT8_KEEP_FLOAT_MAX_NUMEL = 65_536 +INT8_KEEP_FLOAT_STORE_DTYPE = torch.float16 +INT8_PER_ROW_SCALE_DTYPE = torch.float16 +INT8_CLIP_PERCENTILE = 99.99984 +INT8_CLIP_Q = INT8_CLIP_PERCENTILE / 100.0 + +def tensor_nbytes(t: Tensor) -> int: + return int(t.numel()) * int(t.element_size()) + +def keep_float_tensor(name: str, t: Tensor, passthrough_orig_dtypes: dict[str, str]) -> Tensor: + if any(pattern in name for pattern in INT8_KEEP_FLOAT_FP32_NAME_PATTERNS): + return t.float().contiguous() + if t.dtype in {torch.float32, torch.bfloat16}: + passthrough_orig_dtypes[name] = str(t.dtype).removeprefix("torch.") + return t.to(dtype=INT8_KEEP_FLOAT_STORE_DTYPE).contiguous() + return t + +def quantize_float_tensor(t: Tensor) -> tuple[Tensor, Tensor]: + t32 = t.float() + if t32.ndim == 2: + # Matrices get one scale per row, which usually tracks output-channel + # ranges much better than a single tensor-wide scale. + clip_abs = ( + torch.quantile(t32.abs(), INT8_CLIP_Q, dim=1) + if t32.numel() + else torch.empty((t32.shape[0],), dtype=torch.float32) + ) + clipped = torch.maximum(torch.minimum(t32, clip_abs[:, None]), -clip_abs[:, None]) + scale = (clip_abs / 127.0).clamp_min(1.0 / 127.0) + q = torch.clamp(torch.round(clipped / scale[:, None]), -127, 127).to(torch.int8).contiguous() + return q, scale.to(dtype=INT8_PER_ROW_SCALE_DTYPE).contiguous() + + # Vectors / scalars use a simpler per-tensor scale. 
+ clip_abs = float(torch.quantile(t32.abs().flatten(), INT8_CLIP_Q).item()) if t32.numel() else 0.0 + scale = torch.tensor(clip_abs / 127.0 if clip_abs > 0 else 1.0, dtype=torch.float32) + q = torch.clamp(torch.round(torch.clamp(t32, -clip_abs, clip_abs) / scale), -127, 127).to(torch.int8).contiguous() + return q, scale + +def quantize_state_dict_int8(state_dict: dict[str, Tensor]): + # Single supported clean-script export format: + # - per-row int8 for 2D float tensors + # - per-tensor int8 for other float tensors + # - exact passthrough for non-floats + # - passthrough for small float tensors, stored as fp16 to save bytes + quantized: dict[str, Tensor] = {} + scales: dict[str, Tensor] = {} + dtypes: dict[str, str] = {} + passthrough: dict[str, Tensor] = {} + passthrough_orig_dtypes: dict[str, str] = {} + qmeta: dict[str, dict[str, object]] = {} + stats = dict.fromkeys( + ("param_count", "num_tensors", "num_float_tensors", "num_nonfloat_tensors", "baseline_tensor_bytes", "int8_payload_bytes"), + 0, + ) + + for name, tensor in state_dict.items(): + t = tensor.detach().to("cpu").contiguous() + stats["param_count"] += int(t.numel()) + stats["num_tensors"] += 1 + stats["baseline_tensor_bytes"] += tensor_nbytes(t) + + if not t.is_floating_point(): + stats["num_nonfloat_tensors"] += 1 + passthrough[name] = t + stats["int8_payload_bytes"] += tensor_nbytes(t) + continue + + # Small float tensors are cheap enough to keep directly. We still downcast + # fp32/bf16 passthrough tensors to fp16 so metadata does not dominate size. 
+ if t.numel() <= INT8_KEEP_FLOAT_MAX_NUMEL: + kept = keep_float_tensor(name, t, passthrough_orig_dtypes) + passthrough[name] = kept + stats["int8_payload_bytes"] += tensor_nbytes(kept) + continue + + stats["num_float_tensors"] += 1 + q, s = quantize_float_tensor(t) + if s.ndim > 0: + qmeta[name] = {"scheme": "per_row", "axis": 0} + quantized[name] = q + scales[name] = s + dtypes[name] = str(t.dtype).removeprefix("torch.") + stats["int8_payload_bytes"] += tensor_nbytes(q) + tensor_nbytes(s) + + obj: dict[str, object] = { + "__quant_format__": "int8_clean_per_row_v1", + "quantized": quantized, + "scales": scales, + "dtypes": dtypes, + "passthrough": passthrough, + } + if qmeta: + obj["qmeta"] = qmeta + if passthrough_orig_dtypes: + obj["passthrough_orig_dtypes"] = passthrough_orig_dtypes + return obj, stats + +def dequantize_state_dict_int8(obj: dict[str, object]) -> dict[str, Tensor]: + out: dict[str, Tensor] = {} + qmeta = obj.get("qmeta", {}) + passthrough_orig_dtypes = obj.get("passthrough_orig_dtypes", {}) + for name, q in obj["quantized"].items(): + dtype = getattr(torch, obj["dtypes"][name]) + s = obj["scales"][name] + if qmeta.get(name, {}).get("scheme") == "per_row" or s.ndim > 0: + s = s.to(dtype=torch.float32) + # Broadcast the saved row scale back across trailing dimensions. + out[name] = (q.float() * s.view(q.shape[0], *([1] * (q.ndim - 1)))).to(dtype=dtype).contiguous() + else: + scale = float(s.item()) + out[name] = (q.float() * scale).to(dtype=dtype).contiguous() + for name, t in obj["passthrough"].items(): + # Restore small tensors, undoing the temporary fp16 storage cast if needed. 
+ out_t = t.detach().to("cpu").contiguous() + orig_dtype = passthrough_orig_dtypes.get(name) + if isinstance(orig_dtype, str): + out_t = out_t.to(dtype=getattr(torch, orig_dtype)).contiguous() + out[name] = out_t + return out + + +# ----------------------------- +# DATA LOADING +# ----------------------------- + +def load_data_shard(file: Path) -> Tensor: + header_bytes = 256 * np.dtype(" None: + self.file_idx = (self.file_idx + 1) % len(self.files) + self.tokens = load_data_shard(self.files[self.file_idx]) + self.pos = 0 + + def take(self, n: int) -> Tensor: + chunks: list[Tensor] = [] + remaining = n + while remaining > 0: + avail = self.tokens.numel() - self.pos + if avail <= 0: + self._advance_file() + continue + k = min(remaining, avail) + chunks.append(self.tokens[self.pos : self.pos + k]) + self.pos += k + remaining -= k + return chunks[0] if len(chunks) == 1 else torch.cat(chunks) + + +class DistributedTokenLoader: + # Each call consumes a contiguous chunk from the shared token stream, then slices out + # one disjoint span per rank. The extra "+1" token lets us build (x, y) by shifting. 
    def __init__(self, pattern: str, rank: int, world_size: int, device: torch.device):
        """Wrap a shared TokenStream; each rank later slices its own disjoint span."""
        self.rank = rank
        self.world_size = world_size
        self.device = device
        self.stream = TokenStream(pattern)

    def next_batch(self, global_tokens: int, seq_len: int, grad_accum_steps: int) -> tuple[Tensor, Tensor]:
        """Return one (inputs, targets) micro-batch for this rank.

        global_tokens is the whole-step budget across all ranks and accumulation
        steps; each rank consumes local_tokens (+1 so targets can be the inputs
        shifted by one). local_tokens is assumed divisible by seq_len — the
        reshape below raises otherwise.
        """
        local_tokens = global_tokens // (self.world_size * grad_accum_steps)
        per_rank_span = local_tokens + 1
        chunk = self.stream.take(per_rank_span * self.world_size)
        start = self.rank * per_rank_span
        local = chunk[start : start + per_rank_span].to(dtype=torch.int64)
        # Shift-by-one: x predicts y at every position.
        x = local[:-1].reshape(-1, seq_len)
        y = local[1:].reshape(-1, seq_len)
        return x.to(self.device, non_blocking=True), y.to(self.device, non_blocking=True)

# -----------------------------
# TRANSFORMER MODULES
# -----------------------------

class RMSNorm(nn.Module):
    """Weightless RMS normalization; eps=None defers to F.rms_norm's default."""

    def __init__(self, eps: float | None = None):
        super().__init__()
        self.eps = eps

    def forward(self, x: Tensor) -> Tensor:
        return F.rms_norm(x, (x.size(-1),), eps=self.eps)


class CastedLinear(nn.Linear):
    # Keep weights in fp32 for optimizer/state quality, cast at matmul time for bf16 compute.
    def forward(self, x: Tensor) -> Tensor:
        bias = self.bias.to(x.dtype) if self.bias is not None else None
        return F.linear(x, self.weight.to(x.dtype), bias)


def restore_low_dim_params_to_fp32(module: nn.Module) -> None:
    # Keep small/control parameters in fp32 even when the model body runs in bf16.
    # "Control" parameters are matched by name against CONTROL_TENSOR_NAME_PATTERNS
    # (defined elsewhere in this file).
    with torch.no_grad():
        for name, param in module.named_parameters():
            if (param.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS)) and param.dtype != torch.float32:
                param.data = param.data.float()


class Rotary(nn.Module):
    # Caches cos/sin tables per sequence length on the current device.
    def __init__(self, dim: int, base: float = 10000.0):
        super().__init__()
        # Standard RoPE inverse frequencies: base^(-2i/dim) for each even index i.
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim))
        self.register_buffer("inv_freq", inv_freq, persistent=False)
        self._seq_len_cached = 0
        self._cos_cached: Tensor | None = None
        self._sin_cached: Tensor | None = None

    def forward(self, seq_len: int, device: torch.device, dtype: torch.dtype) -> tuple[Tensor, Tensor]:
        """Return (cos, sin) tables shaped (1, 1, seq_len, dim//2), cast to `dtype`.

        Tables are recomputed only when the sequence length or device changes;
        the cache itself stays in inv_freq's dtype and is cast on the way out.
        """
        if (
            self._cos_cached is None
            or self._sin_cached is None
            or self._seq_len_cached != seq_len
            or self._cos_cached.device != device
        ):
            t = torch.arange(seq_len, device=device, dtype=self.inv_freq.dtype)
            freqs = torch.outer(t, self.inv_freq.to(device))
            self._cos_cached = freqs.cos()[None, None, :, :]
            self._sin_cached = freqs.sin()[None, None, :, :]
            self._seq_len_cached = seq_len
        return self._cos_cached.to(dtype=dtype), self._sin_cached.to(dtype=dtype)


def apply_rotary_emb(x: Tensor, cos: Tensor, sin: Tensor) -> Tensor:
    """Rotate paired halves of the head dimension by the position-dependent angle.

    NOTE(review): the sign convention here is the transpose of the usual
    (x1*cos - x2*sin, x1*sin + x2*cos); it is consistent between q and k, so
    relative-position behavior is preserved.
    """
    half = x.size(-1) // 2
    x1, x2 = x[..., :half], x[..., half:]
    return torch.cat((x1 * cos + x2 * sin, x1 * (-sin) + x2 * cos), dim=-1)


class CausalSelfAttention(nn.Module):
    """Causal attention with GQA, QK RMS-norm, RoPE, and a learned per-head query gain."""

    def __init__(
        self,
        dim: int,
        num_heads: int,
        num_kv_heads: int,
        rope_base: float,
        qk_gain_init: float,
    ):
        super().__init__()
        if dim % num_heads != 0:
            raise ValueError("model_dim must be divisible by num_heads")
        if num_heads % num_kv_heads != 0:
            raise ValueError("num_heads must be divisible by num_kv_heads")
        self.num_heads = num_heads
        self.num_kv_heads = num_kv_heads
        self.head_dim = dim // num_heads
        if self.head_dim % 2 != 0:
            raise ValueError("head_dim must be even for RoPE")
        kv_dim = self.num_kv_heads * self.head_dim
        self.c_q = CastedLinear(dim, dim, bias=False)
        self.c_k = CastedLinear(dim, kv_dim, bias=False)
        self.c_v = CastedLinear(dim, kv_dim, bias=False)
        self.proj = CastedLinear(dim, dim, bias=False)
        # Output projection starts at zero so each block initially contributes identity.
        self.proj._zero_init = True
        self.q_gain = nn.Parameter(torch.full((num_heads,), qk_gain_init, dtype=torch.float32))
        self.rotary = Rotary(self.head_dim, base=rope_base)

    def forward(self, x: Tensor) -> Tensor:
        bsz, seqlen, dim = x.shape
        # Project and reshape to (B, H, T, D).
        q = self.c_q(x).reshape(bsz, seqlen, self.num_heads, self.head_dim).transpose(1, 2)
        k = self.c_k(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2)
        v = self.c_v(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2)
        # QK-norm: normalize per head before RoPE for attention-logit stability.
        q = F.rms_norm(q, (q.size(-1),))
        k = F.rms_norm(k, (k.size(-1),))
        cos, sin = self.rotary(seqlen, x.device, q.dtype)
        q = apply_rotary_emb(q, cos, sin)
        k = apply_rotary_emb(k, cos, sin)
        # Per-head learned gain applied to queries only (scales attention logits).
        q = q * self.q_gain.to(dtype=q.dtype)[None, :, None, None]
        if FA3_AVAILABLE:
            # FA3 expects (B, T, H, D) — transpose from (B, H, T, D)
            y = fa3.flash_attn_func(
                q.transpose(1, 2),
                k.transpose(1, 2),
                v.transpose(1, 2),
                causal=True,
            ).transpose(1, 2)
        else:
            y = F.scaled_dot_product_attention(
                q, k, v,
                attn_mask=None,
                is_causal=True,
                enable_gqa=(self.num_kv_heads != self.num_heads),
            )
        y = y.transpose(1, 2).contiguous().reshape(bsz, seqlen, dim)
        return self.proj(y)


class MLP(nn.Module):
    # leaky_relu^2 MLP (slope=0.5 matches SOTA, PRs #493/#518)
    def __init__(self, dim: int, mlp_mult: int, leaky_relu_slope: float = 0.5):
        super().__init__()
        hidden = mlp_mult * dim
        self.fc = CastedLinear(dim, hidden, bias=False)
        self.proj = CastedLinear(hidden, dim, bias=False)
        # Zero-init output projection: the MLP contributes nothing at step 0.
        self.proj._zero_init = True
        self.leaky_relu_slope = leaky_relu_slope

    def forward(self, x: Tensor) -> Tensor:
        # Activation is leaky_relu followed by squaring (a smooth ReLU^2 variant).
        x = F.leaky_relu(self.fc(x), negative_slope=self.leaky_relu_slope)
        return self.proj(x.square())


class TimestepScaling(nn.Module):
    """Learned per-iteration scale vectors (arXiv:2410.01405, ICML 2025)."""
    def __init__(self, effective_layers: int, dim: int, gamma_max: float = 0.0, use_bias: bool = False):
        super().__init__()
        # One (dim,) gain per "virtual layer" position (prelude + loops*shared + coda).
        self.attn_gamma = nn.Parameter(torch.ones(effective_layers, dim, dtype=torch.float32))
        self.mlp_gamma = nn.Parameter(torch.ones(effective_layers, dim, dtype=torch.float32))
        self.gamma_max = gamma_max  # 0 = uncapped
        if use_bias:
            self.attn_beta = nn.Parameter(torch.zeros(effective_layers, dim, dtype=torch.float32))
            self.mlp_beta = nn.Parameter(torch.zeros(effective_layers, dim, dtype=torch.float32))
        else:
            self.attn_beta = None
            self.mlp_beta = None

    def get(self, v: int) -> tuple[Tensor, Tensor, Tensor | None, Tensor | None]:
        """Return (attn_gamma, mlp_gamma, attn_beta, mlp_beta) for virtual layer v."""
        ag = self.attn_gamma[v]
        mg = self.mlp_gamma[v]
        # Optional symmetric clamp keeps learned gains bounded.
        if self.gamma_max > 0:
            ag = ag.clamp(-self.gamma_max, self.gamma_max)
            mg = mg.clamp(-self.gamma_max, self.gamma_max)
        ab = self.attn_beta[v] if self.attn_beta is not None else None
        mb = self.mlp_beta[v] if self.mlp_beta is not None else None
        return ag, mg, ab, mb


class Block(nn.Module):
    """Pre-norm transformer block with learned residual mixing against x0.

    Two residual-mix variants: "Birkhoff" (convex per-channel blend of x and x0
    via a sigmoid) or a free two-vector linear mix. Optionally applies external
    per-position gains/biases supplied by the caller (timestep scaling,
    unique-norm gains, depth embeddings).
    """

    def __init__(
        self,
        dim: int,
        num_heads: int,
        num_kv_heads: int,
        mlp_mult: int,
        rope_base: float,
        qk_gain_init: float,
        use_peri_norm: bool = False,
        use_birkhoff_mix: bool = False,
        leaky_relu_slope: float = 0.5,
    ):
        super().__init__()
        self.use_peri_norm = use_peri_norm
        self.use_birkhoff_mix = use_birkhoff_mix
        self.attn_norm = RMSNorm()
        if use_peri_norm:
            self.mlp_out_norm = RMSNorm()  # output norm: bounds MLP contribution, preserves input magnitude
        else:
            self.mlp_norm = RMSNorm()
        self.attn = CausalSelfAttention(dim, num_heads, num_kv_heads, rope_base, qk_gain_init)
        self.mlp = MLP(dim, mlp_mult, leaky_relu_slope)
        self.attn_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32))
        self.mlp_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32))
        if use_birkhoff_mix:
            # sigmoid(0) = 0.5 at init: equal blend of running state and embedding stream.
            self.resid_mix_logit = nn.Parameter(torch.zeros(dim, dtype=torch.float32))
        else:
            # Free mix initialized to identity on x (weight 1) and zero on x0.
            self.resid_mix = nn.Parameter(torch.stack((torch.ones(dim), torch.zeros(dim))).float())

    def forward(self, x: Tensor, x0: Tensor,
                ts_attn_gamma: Tensor | None = None,
                ts_mlp_gamma: Tensor | None = None,
                ts_attn_beta: Tensor | None = None,
                ts_mlp_beta: Tensor | None = None,
                depth_emb: Tensor | None = None,
                ext_attn_gain: Tensor | None = None,
                ext_mlp_gain: Tensor | None = None) -> Tensor:
        if self.use_birkhoff_mix:
            # Convex combination: alpha*x + (1-alpha)*x0, alpha in (0, 1) per channel.
            alpha = torch.sigmoid(self.resid_mix_logit.to(dtype=x.dtype))[None, None, :]
            x = alpha * x + (1 - alpha) * x0
        else:
            mix = self.resid_mix.to(dtype=x.dtype)
            x = mix[0][None, None, :] * x + mix[1][None, None, :] * x0
        if depth_emb is not None:
            x = x + depth_emb
        attn_normed = self.attn_norm(x)
        if ext_attn_gain is not None:
            attn_normed = attn_normed * ext_attn_gain[None, None, :]
        attn_out = self.attn(attn_normed)
        attn_s = self.attn_scale.to(dtype=x.dtype)[None, None, :]
        if ts_attn_gamma is not None:
            attn_s = attn_s * ts_attn_gamma[None, None, :]
        x = x + attn_s * attn_out
        if ts_attn_beta is not None:
            x = x + ts_attn_beta[None, None, :]
        if self.use_peri_norm:
            # NOTE(review): without an external gain the MLP input is NOT normalized
            # here (m_input = x); only the MLP *output* is normed — confirm intended.
            if ext_mlp_gain is not None:
                m_input = F.rms_norm(x, (x.size(-1),)) * ext_mlp_gain[None, None, :]
            else:
                m_input = x
            mlp_out = self.mlp_out_norm(self.mlp(m_input))
        else:
            mlp_normed = self.mlp_norm(x)
            if ext_mlp_gain is not None:
                mlp_normed = mlp_normed * ext_mlp_gain[None, None, :]
            mlp_out = self.mlp(mlp_normed)
        mlp_s = self.mlp_scale.to(dtype=x.dtype)[None, None, :]
        if ts_mlp_gamma is not None:
            mlp_s = mlp_s * ts_mlp_gamma[None, None, :]
        x = x + mlp_s * mlp_out
        if ts_mlp_beta is not None:
            x = x + ts_mlp_beta[None, None, :]
        return x


class GPT(nn.Module):
    """GPT with two topologies: a weight-shared recurrent stack or a U-Net stack.

    Recurrent mode (num_shared > 0): prelude blocks, then num_loops passes over
    the shared blocks, then coda blocks — each pass position ("virtual layer" v)
    optionally gets its own timestep gains/biases, depth embedding, and unique
    norm gains. Otherwise: a standard encoder/decoder stack with learned skips.
    """

    def __init__(
        self,
        vocab_size: int,
        num_layers: int,
        model_dim: int,
        num_heads: int,
        num_kv_heads: int,
        mlp_mult: int,
        tie_embeddings: bool,
        tied_embed_init_std: float,
        logit_softcap: float,
        rope_base: float,
        qk_gain_init: float,
        num_prelude: int = 0,
        num_coda: int = 0,
        num_shared: int = 0,
        num_loops: int = 2,
        use_peri_norm: bool = False,
        use_birkhoff_mix: bool = False,
        use_timestep_scale: bool = False,
        use_timestep_bias: bool = False,
        use_depth_embed: bool = False,
        use_unique_norms: bool = False,
        timestep_gamma_max: float = 0.0,
        leaky_relu_slope: float = 0.5,
    ):
        super().__init__()
        if logit_softcap <= 0.0:
            raise ValueError(f"logit_softcap must be positive, got {logit_softcap}")
        self.tie_embeddings = tie_embeddings
        self.tied_embed_init_std = tied_embed_init_std
        self.logit_softcap = logit_softcap
        self.tok_emb = nn.Embedding(vocab_size, model_dim)
        # num_shared > 0 selects the recurrent topology; recurrence-only options
        # are force-disabled otherwise.
        self.use_recurrence = num_shared > 0
        self.num_loops = num_loops if self.use_recurrence else 1
        self.num_prelude = num_prelude if self.use_recurrence else 0
        self.num_coda = num_coda if self.use_recurrence else 0

        block_kwargs = dict(
            dim=model_dim, num_heads=num_heads, num_kv_heads=num_kv_heads,
            mlp_mult=mlp_mult, rope_base=rope_base, qk_gain_init=qk_gain_init,
            use_peri_norm=use_peri_norm if self.use_recurrence else False,
            use_birkhoff_mix=use_birkhoff_mix if self.use_recurrence else False,
            leaky_relu_slope=leaky_relu_slope,
        )

        if self.use_recurrence:
            self.prelude_blocks = nn.ModuleList([Block(**block_kwargs) for _ in range(self.num_prelude)])
            self.coda_blocks = nn.ModuleList([Block(**block_kwargs) for _ in range(self.num_coda)])
            self.shared_blocks = nn.ModuleList([Block(**block_kwargs) for _ in range(num_shared)])
            # Virtual depth: shared blocks are counted once per loop.
            effective_layers = self.num_prelude + num_shared * self.num_loops + self.num_coda
            self.timestep_scale = TimestepScaling(effective_layers, model_dim, gamma_max=timestep_gamma_max, use_bias=use_timestep_bias) if use_timestep_scale else None
            self.use_depth_embed = use_depth_embed
            if self.use_depth_embed:
                self.depth_embeddings = nn.Parameter(torch.zeros(effective_layers, model_dim, dtype=torch.float32))
            self.use_unique_norms = use_unique_norms
            if self.use_unique_norms:
                # One gain pair per (loop, shared block) occurrence.
                num_unique = num_shared * self.num_loops
                self.unique_attn_gains = nn.Parameter(torch.ones(num_unique, model_dim, dtype=torch.float32))
                self.unique_mlp_gains = nn.Parameter(torch.ones(num_unique, model_dim, dtype=torch.float32))
        else:
            # Standard U-Net path
            self.num_encoder_layers = num_layers // 2
            self.num_decoder_layers = num_layers - self.num_encoder_layers
            self.num_skip_weights = min(self.num_encoder_layers, self.num_decoder_layers)
            self.skip_weights = nn.Parameter(torch.ones(self.num_skip_weights, model_dim, dtype=torch.float32))
            self.blocks = nn.ModuleList([Block(**block_kwargs) for _ in range(num_layers)])
            self.timestep_scale = None
            self.use_depth_embed = False
            self.use_unique_norms = False

        self.final_norm = RMSNorm()
        self.lm_head = None if tie_embeddings else CastedLinear(model_dim, vocab_size, bias=False)
        if self.lm_head is not None:
            self.lm_head._zero_init = True
        self._init_weights()

    def _init_weights(self) -> None:
        # Tied embedding doubles as the output matrix, so it gets a real init;
        # all _zero_init projections start at zero (identity blocks at step 0).
        if self.tie_embeddings:
            nn.init.normal_(self.tok_emb.weight, mean=0.0, std=self.tied_embed_init_std)
        for module in self.modules():
            if isinstance(module, nn.Linear) and getattr(module, "_zero_init", False):
                nn.init.zeros_(module.weight)

    def _get_ts(self, v: int) -> tuple[Tensor | None, Tensor | None, Tensor | None, Tensor | None]:
        # Timestep gains/biases for virtual layer v, or all-None when disabled.
        if self.timestep_scale is None:
            return None, None, None, None
        return self.timestep_scale.get(v)

    def forward(self, input_ids: Tensor, target_ids: Tensor) -> Tensor:
        """Return mean cross-entropy of softcapped logits against target_ids."""
        x = self.tok_emb(input_ids)
        x = F.rms_norm(x, (x.size(-1),))
        x0 = x  # normalized embedding stream, re-mixed into every block

        if self.use_recurrence:
            v = 0  # virtual layer index across prelude + loops + coda
            for block in self.prelude_blocks:
                ag, mg, ab, mb = self._get_ts(v)
                de = self.depth_embeddings[v] if self.use_depth_embed else None
                x = block(x, x0, ag, mg, ab, mb, de)
                v += 1
            uid = 0  # flat index into the per-occurrence unique-gain tables
            for _loop in range(self.num_loops):
                for block in self.shared_blocks:
                    ag, mg, ab, mb = self._get_ts(v)
                    de = self.depth_embeddings[v] if self.use_depth_embed else None
                    if self.use_unique_norms:
                        ag_n = self.unique_attn_gains[uid].to(dtype=x.dtype)
                        mg_n = self.unique_mlp_gains[uid].to(dtype=x.dtype)
                        x = block(x, x0, ag, mg, ab, mb, de, ag_n, mg_n)
                    else:
                        x = block(x, x0, ag, mg, ab, mb, de)
                    uid += 1
                    v += 1
            for block in self.coda_blocks:
                ag, mg, ab, mb = self._get_ts(v)
                de = self.depth_embeddings[v] if self.use_depth_embed else None
                x = block(x, x0, ag, mg, ab, mb, de)
                v += 1
        else:
            # U-Net: encoder activations feed the decoder LIFO with learned weights.
            skips: list[Tensor] = []
            for i in range(self.num_encoder_layers):
                x = self.blocks[i](x, x0)
                skips.append(x)
            for i in range(self.num_decoder_layers):
                if skips:
                    x = x + self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips.pop()
                x = self.blocks[self.num_encoder_layers + i](x, x0)

        x = self.final_norm(x).reshape(-1, x.size(-1))
        targets = target_ids.reshape(-1)
        if self.tie_embeddings:
            logits_proj = F.linear(x, self.tok_emb.weight)
        else:
            if self.lm_head is None:
                raise RuntimeError("lm_head is required when tie_embeddings=False")
            logits_proj = self.lm_head(x)
        # tanh softcap keeps logits in (-cap, cap) for numerical stability.
        logits = self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap)
        return F.cross_entropy(logits.float(), targets, reduction="mean")


@torch.no_grad()
def recurrence_param_diagnostics(gpt: GPT) -> str:
    """Log parameter norms for recurrence health monitoring. No forward pass."""
    if not gpt.use_recurrence:
        return ""
    parts: list[str] = []

    # Birkhoff alpha stats per shared block
    for i, block in enumerate(gpt.shared_blocks):
        if hasattr(block, "resid_mix_logit"):
            a = torch.sigmoid(block.resid_mix_logit)
            parts.append(f"shared{i}_alpha:mean={a.mean().item():.3f},std={a.std().item():.3f}")

    # Effective MLP/attn contribution scale per virtual layer position:
    # mlp_scale * ts_mlp_gamma[v] is the learned multiplier on MLP output.
    # We report RMS (norm / sqrt(numel)) to make it scale-independent.
    v = 0
    effective_count = gpt.num_prelude + len(gpt.shared_blocks) * gpt.num_loops + gpt.num_coda
    mlp_norms: list[str] = []
    attn_norms: list[str] = []

    # Prelude blocks
    for block in gpt.prelude_blocks:
        ms = block.mlp_scale.norm().item()
        asc = block.attn_scale.norm().item()
        d = block.mlp_scale.numel() ** 0.5
        if gpt.timestep_scale is not None:
            # Effective multiplier is the product of the block scale and the
            # per-virtual-layer timestep gamma (reported as product of norms / sqrt(dim)).
            ms_g = gpt.timestep_scale.mlp_gamma[v].norm().item()
            as_g = gpt.timestep_scale.attn_gamma[v].norm().item()
            mlp_norms.append(f"v{v}:{ms * ms_g / d:.4f}")
            attn_norms.append(f"v{v}:{asc * as_g / d:.4f}")
        else:
            mlp_norms.append(f"v{v}:{ms / d:.4f}")
            attn_norms.append(f"v{v}:{asc / d:.4f}")
        v += 1

    # Shared positions
    for _loop in range(gpt.num_loops):
        for block in gpt.shared_blocks:
            ms = block.mlp_scale.norm().item()
            asc = block.attn_scale.norm().item()
            d = block.mlp_scale.numel() ** 0.5
            if gpt.timestep_scale is not None:
                ms_g = gpt.timestep_scale.mlp_gamma[v].norm().item()
                as_g = gpt.timestep_scale.attn_gamma[v].norm().item()
                mlp_norms.append(f"v{v}:{ms * ms_g / d:.4f}")
                attn_norms.append(f"v{v}:{asc * as_g / d:.4f}")
            else:
                mlp_norms.append(f"v{v}:{ms / d:.4f}")
                attn_norms.append(f"v{v}:{asc / d:.4f}")
            v += 1

    # Coda blocks
    for block in gpt.coda_blocks:
        ms = block.mlp_scale.norm().item()
        asc = block.attn_scale.norm().item()
        d = block.mlp_scale.numel() ** 0.5
        if gpt.timestep_scale is not None:
            ms_g = gpt.timestep_scale.mlp_gamma[v].norm().item()
            as_g = gpt.timestep_scale.attn_gamma[v].norm().item()
            mlp_norms.append(f"v{v}:{ms * ms_g / d:.4f}")
            attn_norms.append(f"v{v}:{asc * as_g / d:.4f}")
        else:
            mlp_norms.append(f"v{v}:{ms / d:.4f}")
            attn_norms.append(f"v{v}:{asc / d:.4f}")
        v += 1

    parts.append("eff_mlp_scale:[" + " ".join(mlp_norms) + "]")
    parts.append("eff_attn_scale:[" + " ".join(attn_norms) + "]")
    if gpt.timestep_scale is not None and gpt.timestep_scale.attn_beta is not None:
        attn_bias_norms: list[str] = []
        mlp_bias_norms: list[str] = []
        for vi in range(effective_count):
            # RMS of the per-virtual-layer additive biases.
            ab_rms = gpt.timestep_scale.attn_beta[vi].norm().item() / gpt.timestep_scale.attn_beta[vi].numel() ** 0.5
            mb_rms = gpt.timestep_scale.mlp_beta[vi].norm().item() / gpt.timestep_scale.mlp_beta[vi].numel() ** 0.5
            attn_bias_norms.append(f"v{vi}:{ab_rms:.4f}")
            mlp_bias_norms.append(f"v{vi}:{mb_rms:.4f}")
        parts.append("eff_attn_bias:[" + " ".join(attn_bias_norms) + "]")
        parts.append("eff_mlp_bias:[" + " ".join(mlp_bias_norms) + "]")
    if gpt.use_unique_norms:
        un_attn: list[str] = []
        un_mlp: list[str] = []
        for ui in range(gpt.unique_attn_gains.size(0)):
            an_rms = gpt.unique_attn_gains[ui].norm().item() / gpt.unique_attn_gains[ui].numel() ** 0.5
            un_attn.append(f"u{ui}:{an_rms:.4f}")
            mn_rms = gpt.unique_mlp_gains[ui].norm().item() / gpt.unique_mlp_gains[ui].numel() ** 0.5
            un_mlp.append(f"u{ui}:{mn_rms:.4f}")
        parts.append("unique_attn_gain_rms:[" + " ".join(un_attn) + "]")
        parts.append("unique_mlp_gain_rms:[" + " ".join(un_mlp) + "]")
    if gpt.use_depth_embed:
        de_norms: list[str] = []
        for vi in range(effective_count):
            de_rms = gpt.depth_embeddings[vi].norm().item() / gpt.depth_embeddings[vi].numel() ** 0.5
            de_norms.append(f"v{vi}:{de_rms:.4f}")
        parts.append("depth_emb_rms:[" + " ".join(de_norms) + "]")
    return " ".join(parts)


# -----------------------------
# TRAINING
# -----------------------------

def main() -> None:
    """Entry point: distributed setup, model/optimizer build, warmup, train, serialize."""
    global zeropower_via_newtonschulz5

    # The full source is embedded in the log and counted toward submission size.
    code = Path(__file__).read_text(encoding="utf-8")
    args = Hyperparameters()
    zeropower_via_newtonschulz5 = torch.compile(zeropower_via_newtonschulz5)

    # -----------------------------
    # DISTRIBUTED + CUDA SETUP
    # -----------------------------

    distributed = "RANK" in os.environ and "WORLD_SIZE" in os.environ
    rank = int(os.environ.get("RANK", "0"))
    world_size = int(os.environ.get("WORLD_SIZE", "1"))
    local_rank = int(os.environ.get("LOCAL_RANK", "0"))
    if world_size <= 0:
        raise ValueError(f"WORLD_SIZE must be positive, got {world_size}")
    # Global batch is fixed for 8 "virtual" workers; fewer real ranks accumulate more.
    if 8 % world_size != 0:
        raise ValueError(f"WORLD_SIZE={world_size} must divide 8 so grad_accum_steps stays integral")
    grad_accum_steps = 8 // world_size
    grad_scale = 1.0 / grad_accum_steps
    if not torch.cuda.is_available():
        raise RuntimeError("CUDA is required")
    device = torch.device("cuda", local_rank)
    torch.cuda.set_device(device)
    if distributed:
        dist.init_process_group(backend="nccl", device_id=device)
        dist.barrier()
    master_process = rank == 0

    # Fast math knobs
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True
    from torch.backends.cuda import enable_cudnn_sdp, enable_flash_sdp, enable_math_sdp, enable_mem_efficient_sdp

    # Pin SDPA to the flash backend only (fallback path when FA3 is unavailable).
    enable_cudnn_sdp(False)
    enable_flash_sdp(True)
    enable_mem_efficient_sdp(False)
    enable_math_sdp(False)

    logfile = None
    if master_process:
        os.makedirs("logs", exist_ok=True)
        logfile = f"logs/{args.run_id}.txt"
        print(logfile)

    def log0(msg: str, console: bool = True) -> None:
        # Rank-0-only logging; always appended to the logfile, optionally echoed.
        if not master_process:
            return
        if console:
            print(msg)
        if logfile is not None:
            with open(logfile, "a", encoding="utf-8") as f:
                print(msg, file=f)

    log0(code, console=False)
    log0("=" * 100, console=False)
    log0(f"Running Python {sys.version}", console=False)
    log0(f"Running PyTorch {torch.__version__}", console=False)
    log0(
        subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, check=False).stdout,
        console=False,
    )
    log0("=" * 100, console=False)

    # -----------------------------
    # TOKENIZER + VALIDATION METRIC SETUP
    # -----------------------------

    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)

    if not args.tokenizer_path.endswith(".model"):
        raise ValueError(f"Script only setup for SentencePiece .model file: {args.tokenizer_path}")
    sp = spm.SentencePieceProcessor(model_file=args.tokenizer_path)
    if int(sp.vocab_size()) != args.vocab_size:
        raise ValueError(
            f"VOCAB_SIZE={args.vocab_size} does not match tokenizer vocab_size={int(sp.vocab_size())}"
        )
    dataset_dir = Path(args.data_path).resolve()
    actual_train_files = len(list(dataset_dir.glob("fineweb_train_*.bin")))
    val_tokens = load_validation_tokens(args.val_files, args.train_seq_len)
    # LUTs used by the bits-per-byte validation metric.
    base_bytes_lut, has_leading_space_lut, is_boundary_token_lut = build_sentencepiece_luts(
        sp, args.vocab_size, device
    )
    log0(f"val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path={args.tokenizer_path}")
    log0(f"train_loader:dataset:{dataset_dir.name} train_shards:{actual_train_files}")
    log0(f"val_loader:shards pattern={args.val_files} tokens:{val_tokens.numel() - 1}")

    # -----------------------------
    # MODEL + OPTIMIZER SETUP
    # -----------------------------

    base_model = GPT(
        vocab_size=args.vocab_size,
        num_layers=args.num_layers,
        model_dim=args.model_dim,
        num_heads=args.num_heads,
        num_kv_heads=args.num_kv_heads,
        mlp_mult=args.mlp_mult,
        tie_embeddings=args.tie_embeddings,
        tied_embed_init_std=args.tied_embed_init_std,
        logit_softcap=args.logit_softcap,
        rope_base=args.rope_base,
        qk_gain_init=args.qk_gain_init,
        num_prelude=args.num_prelude,
        num_coda=args.num_coda,
        num_shared=args.num_shared,
        num_loops=args.num_loops,
        use_peri_norm=args.use_peri_norm,
        use_birkhoff_mix=args.use_birkhoff_mix,
        use_timestep_scale=args.use_timestep_scale,
        use_timestep_bias=args.use_timestep_bias,
        use_depth_embed=args.use_depth_embed,
        use_unique_norms=args.use_unique_norms,
        timestep_gamma_max=args.timestep_gamma_max,
        leaky_relu_slope=args.leaky_relu_slope,
    ).to(device).bfloat16()
    # CastedLinear weights stay fp32 (cast at matmul time); small params restored to fp32 too.
    for module in base_model.modules():
        if isinstance(module, CastedLinear):
            module.float()
    restore_low_dim_params_to_fp32(base_model)
    if args.disable_compile:
        compiled_model = base_model
    else:
        compiled_model = torch.compile(base_model, dynamic=False, fullgraph=True)
    model: nn.Module = DDP(compiled_model, device_ids=[local_rank], broadcast_buffers=False) if distributed else compiled_model

    # Optimizer split:
    # - token embedding (Adam) uses EMBED_LR
    # - untied lm_head (Adam) uses HEAD_LR
    # - matrix params in transformer blocks use MATRIX_LR via Muon
    # - vectors/scalars use SCALAR_LR via Adam
    if base_model.use_recurrence:
        all_block_lists = [base_model.prelude_blocks, base_model.shared_blocks, base_model.coda_blocks]
        block_named_params = []
        for bl in all_block_lists:
            block_named_params.extend(bl.named_parameters())
        if base_model.use_unique_norms:
            block_named_params.extend([("unique_attn_gains", base_model.unique_attn_gains)])
            block_named_params.extend([("unique_mlp_gains", base_model.unique_mlp_gains)])
    else:
        block_named_params = list(base_model.blocks.named_parameters())
    # 2-D non-control params go to Muon; everything else to the scalar Adam group.
    matrix_params = [
        p
        for name, p in block_named_params
        if p.ndim == 2 and not any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS)
    ]
    scalar_params = [
        p
        for name, p in block_named_params
        if p.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS)
    ]
    if hasattr(base_model, "skip_weights") and base_model.skip_weights.numel() > 0:
        scalar_params.append(base_model.skip_weights)
    if base_model.timestep_scale is not None:
        for p in base_model.timestep_scale.parameters():
            scalar_params.append(p)
    if base_model.use_depth_embed:
        scalar_params.append(base_model.depth_embeddings)
    token_lr = args.tied_embed_lr if args.tie_embeddings else args.embed_lr
    # "base_lr" is stored per group so the warmdown multiplier can rescale "lr" each step.
    optimizer_tok = torch.optim.Adam(
        [{"params": [base_model.tok_emb.weight], "lr": token_lr, "base_lr": token_lr}],
        betas=(args.beta1, args.beta2),
        eps=args.adam_eps,
        fused=True,
    )
    optimizer_muon = Muon(
        matrix_params,
        lr=args.matrix_lr,
        momentum=args.muon_momentum,
        backend_steps=args.muon_backend_steps,
    )
    for group in optimizer_muon.param_groups:
        group["base_lr"] = args.matrix_lr
    optimizer_scalar = torch.optim.Adam(
        [{"params": scalar_params, "lr": args.scalar_lr, "base_lr": args.scalar_lr}],
        betas=(args.beta1, args.beta2),
        eps=args.adam_eps,
        fused=True,
    )
    optimizers: list[torch.optim.Optimizer] = [optimizer_tok, optimizer_muon, optimizer_scalar]
    if base_model.lm_head is not None:
        optimizer_head = torch.optim.Adam(
            [{"params": [base_model.lm_head.weight], "lr": args.head_lr, "base_lr": args.head_lr}],
            betas=(args.beta1, args.beta2),
            eps=args.adam_eps,
            fused=True,
        )
        optimizers.insert(1, optimizer_head)

    n_params = sum(p.numel() for p in base_model.parameters())
    log0(f"model_params:{n_params}")
    log0(f"world_size:{world_size} grad_accum_steps:{grad_accum_steps}")
    log0("sdp_backends:cudnn=False flash=True mem_efficient=False math=False")
    log0(f"flash_attention_3:{FA3_AVAILABLE}")
    log0(f"attention_mode:gqa num_heads:{args.num_heads} num_kv_heads:{args.num_kv_heads}")
    log0(
        f"tie_embeddings:{args.tie_embeddings} embed_lr:{token_lr} "
        f"head_lr:{args.head_lr if base_model.lm_head is not None else 0.0} "
        f"matrix_lr:{args.matrix_lr} scalar_lr:{args.scalar_lr}"
    )
    log0(
        f"train_batch_tokens:{args.train_batch_tokens} train_seq_len:{args.train_seq_len} "
        f"iterations:{args.iterations} warmup_steps:{args.warmup_steps} "
        f"max_wallclock_seconds:{args.max_wallclock_seconds:.3f}"
    )
    log0(f"seed:{args.seed}")
    if base_model.use_recurrence:
        num_shared = len(base_model.shared_blocks)
        eff = base_model.num_prelude + num_shared * base_model.num_loops + base_model.num_coda
        log0(f"recurrence:enabled prelude:{base_model.num_prelude} shared:{num_shared} "
             f"loops:{base_model.num_loops} coda:{base_model.num_coda} effective_layers:{eff}")
        log0(f"peri_norm:{args.use_peri_norm} birkhoff_mix:{args.use_birkhoff_mix}")
        if base_model.timestep_scale is not None:
            ts_params = sum(p.numel() for p in base_model.timestep_scale.parameters())
            log0(f"timestep_scale:enabled params:{ts_params}")
        else:
            log0("timestep_scale:disabled")
        log0(f"depth_embed:{'enabled' if base_model.use_depth_embed else 'disabled'}")
        log0(f"unique_norms:{'enabled' if base_model.use_unique_norms else 'disabled'}")
    else:
        log0(f"recurrence:disabled num_layers:{args.num_layers}")
    compile_mode = "disabled" if args.disable_compile else "fullgraph=True"
    log0(f"compile_mode:{compile_mode}")

    # -----------------------------
    # DATA LOADER & MODEL WARMUP
    # -----------------------------

    train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device)

    def zero_grad_all() -> None:
        for opt in optimizers:
            opt.zero_grad(set_to_none=True)

    max_wallclock_ms = 1000.0 * args.max_wallclock_seconds if args.max_wallclock_seconds > 0 else None

    def lr_mul(step: int, elapsed_ms: float) -> float:
        # LR multiplier in [0, 1]: flat, then a linear warmdown whose start is
        # either a fixed iteration count or estimated from remaining wallclock.
        if args.warmdown_iters <= 0:
            return 1.0
        if max_wallclock_ms is None:
            warmdown_start = max(args.iterations - args.warmdown_iters, 0)
            return max((args.iterations - step) / max(args.warmdown_iters, 1), 0.0) if warmdown_start <= step < args.iterations else 1.0
        step_ms = elapsed_ms / max(step, 1)
        warmdown_ms = args.warmdown_iters * step_ms
        remaining_ms = max(max_wallclock_ms - elapsed_ms, 0.0)
        return remaining_ms / max(warmdown_ms, 1e-9) if remaining_ms <= warmdown_ms else 1.0

    # Warmup primes the compiled forward/backward/optimizer paths, then we restore the
    # initial weights/optimizer state so measured training starts from the true init.
    if args.warmup_steps > 0:
        initial_model_state = {name: tensor.detach().cpu().clone() for name, tensor in base_model.state_dict().items()}
        initial_optimizer_states = [copy.deepcopy(opt.state_dict()) for opt in optimizers]
        model.train()
        for warmup_step in range(args.warmup_steps):
            zero_grad_all()
            for micro_step in range(grad_accum_steps):
                if distributed:
                    # Only sync gradients on the final micro step of the accumulation window.
                    model.require_backward_grad_sync = micro_step == grad_accum_steps - 1
                x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps)
                with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True):
                    warmup_loss = model(x, y)
                (warmup_loss * grad_scale).backward()
            for opt in optimizers:
                opt.step()
            zero_grad_all()
            if args.warmup_steps <= 20 or (warmup_step + 1) % 10 == 0 or warmup_step + 1 == args.warmup_steps:
                log0(f"warmup_step:{warmup_step + 1}/{args.warmup_steps}")
        base_model.load_state_dict(initial_model_state, strict=True)
        for opt, state in zip(optimizers, initial_optimizer_states, strict=True):
            opt.load_state_dict(state)
        zero_grad_all()
        if distributed:
            model.require_backward_grad_sync = True
        # Fresh loader so measured training consumes the stream from the start.
        train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device)

    # -----------------------------
    # MAIN TRAINING LOOP
    # -----------------------------

    training_time_ms = 0.0
    stop_after_step: int | None = None
    torch.cuda.synchronize()
    t0 = time.perf_counter()

    step = 0
    while True:
        last_step = step == args.iterations or (stop_after_step is not None and step >= stop_after_step)

        should_validate = last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0)
        if should_validate:
            # Validation time is excluded from the training clock: accumulate
            # elapsed time before, restart the timer after.
            torch.cuda.synchronize()
            training_time_ms += 1000.0 * (time.perf_counter() - t0)
            val_loss, val_bpb = eval_val(
                args,
                model,
                rank,
                world_size,
                device,
                grad_accum_steps,
                val_tokens,
                base_bytes_lut,
                has_leading_space_lut,
                is_boundary_token_lut,
            )
            log0(
                f"step:{step}/{args.iterations} val_loss:{val_loss:.4f} val_bpb:{val_bpb:.4f} "
                f"train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms / max(step, 1):.2f}ms"
            )
            torch.cuda.synchronize()
            t0 = time.perf_counter()

        if last_step:
            if stop_after_step is not None and step < args.iterations:
                log0(
                    f"stopping_early: wallclock_cap train_time:{training_time_ms:.0f}ms "
                    f"step:{step}/{args.iterations}"
                )
            break

        elapsed_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0)
        scale = lr_mul(step, elapsed_ms)
        zero_grad_all()
        train_loss = torch.zeros((), device=device)
        for micro_step in range(grad_accum_steps):
            if distributed:
                model.require_backward_grad_sync = micro_step == grad_accum_steps - 1
            x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps)
            with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True):
                loss = model(x, y)
            train_loss += loss.detach()
            (loss * grad_scale).backward()
        train_loss /= grad_accum_steps

        # Linear Muon momentum warmup over the first muon_momentum_warmup_steps steps.
        frac = min(step / args.muon_momentum_warmup_steps, 1.0) if args.muon_momentum_warmup_steps > 0 else 1.0
        muon_momentum = (1 - frac) * args.muon_momentum_warmup_start + frac * args.muon_momentum
        for group in optimizer_muon.param_groups:
            group["momentum"] = muon_momentum

        for opt in optimizers:
            for group in opt.param_groups:
                group["lr"] = group["base_lr"] * scale

        if args.grad_clip_norm > 0:
            torch.nn.utils.clip_grad_norm_(base_model.parameters(), args.grad_clip_norm)
        for opt in optimizers:
            opt.step()
        zero_grad_all()

        step += 1
        approx_training_time_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0)
        should_log_train = (
            args.train_log_every > 0
            and (step <= 10 or step % args.train_log_every == 0 or stop_after_step is not None)
        )
        if should_log_train:
            log0(
                f"step:{step}/{args.iterations} train_loss:{train_loss.item():.4f} "
                f"train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms / step:.2f}ms"
            )
            if base_model.use_recurrence and step % args.train_log_every == 0:
                diag = recurrence_param_diagnostics(base_model)
                if diag:
                    log0(f"step:{step} {diag}")

        # Needed to sync whether we've reached the wallclock cap.
        reached_cap = max_wallclock_ms is not None and approx_training_time_ms >= max_wallclock_ms
        if distributed and max_wallclock_ms is not None:
            reached_cap_tensor = torch.tensor(int(reached_cap), device=device)
            dist.all_reduce(reached_cap_tensor, op=dist.ReduceOp.MAX)
            reached_cap = bool(reached_cap_tensor.item())
        if stop_after_step is None and reached_cap:
            stop_after_step = step

    log0(
        f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB "
        f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB"
    )

    # -----------------------------
    # SERIALIZATION + ROUNDTRIP VALIDATION
    # -----------------------------
    # Save the raw state (useful for debugging/loading in PyTorch directly), then always produce
    # the compressed int8+zlib artifact and validate the round-tripped weights.

    if master_process:
        torch.save(base_model.state_dict(), "final_model.pt")
        model_bytes = os.path.getsize("final_model.pt")
        code_bytes = len(code.encode("utf-8"))
        log0(f"Serialized model: {model_bytes} bytes")
        log0(f"Code size: {code_bytes} bytes")
        log0(f"Total submission size: {model_bytes + code_bytes} bytes")

    # Quantization runs on every rank (cheap, keeps ranks in lockstep); only
    # rank 0 writes the artifact.
    quant_obj, quant_stats = quantize_state_dict_int8(base_model.state_dict())
    quant_buf = io.BytesIO()
    torch.save(quant_obj, quant_buf)
    quant_raw = quant_buf.getvalue()
    quant_blob = zlib.compress(quant_raw, level=9)
    quant_raw_bytes = len(quant_raw)
    if master_process:
        with open("final_model.int8.ptz", "wb") as f:
            f.write(quant_blob)
        quant_file_bytes = os.path.getsize("final_model.int8.ptz")
        code_bytes = len(code.encode("utf-8"))
        ratio = quant_stats["baseline_tensor_bytes"] / max(quant_stats["int8_payload_bytes"], 1)
        log0(
            f"Serialized model int8+zlib: {quant_file_bytes} bytes "
            f"(payload:{quant_stats['int8_payload_bytes']} raw_torch:{quant_raw_bytes} payload_ratio:{ratio:.2f}x)"
        )
        log0(f"Total submission size int8+zlib: {quant_file_bytes + code_bytes} bytes")

    if distributed:
        # All ranks read the artifact rank 0 just wrote — assumes a shared filesystem.
        dist.barrier()
    with open("final_model.int8.ptz", "rb") as f:
        quant_blob_disk = f.read()
    quant_state = torch.load(io.BytesIO(zlib.decompress(quant_blob_disk)), map_location="cpu")
    base_model.load_state_dict(dequantize_state_dict_int8(quant_state), strict=True)
    torch.cuda.synchronize()
    t_qeval = time.perf_counter()
    q_val_loss, q_val_bpb = eval_val(
        args,
        model,
        rank,
        world_size,
        device,
        grad_accum_steps,
        val_tokens,
        base_bytes_lut,
        has_leading_space_lut,
        is_boundary_token_lut,
    )
    torch.cuda.synchronize()
    log0(
        f"final_int8_zlib_roundtrip val_loss:{q_val_loss:.4f} val_bpb:{q_val_bpb:.4f} "
        f"eval_time:{1000.0 * (time.perf_counter() - t_qeval):.0f}ms"
    )
    log0(f"final_int8_zlib_roundtrip_exact val_loss:{q_val_loss:.8f} val_bpb:{q_val_bpb:.8f}")

    if distributed:
        dist.destroy_process_group()
+ + +if __name__ == "__main__": + main() + +==================================================================================================== +Running Python 3.12.3 (main, Nov 6 2025, 13:44:16) [GCC 13.3.0] +Running PyTorch 2.9.1+cu128 +Thu Apr 2 15:02:21 2026 ++-----------------------------------------------------------------------------------------+ +| NVIDIA-SMI 580.126.09 Driver Version: 580.126.09 CUDA Version: 13.0 | ++-----------------------------------------+------------------------+----------------------+ +| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC | +| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. | +| | | MIG M. | +|=========================================+========================+======================| +| 0 NVIDIA H100 80GB HBM3 On | 00000000:19:00.0 Off | 0 | +| N/A 45C P0 125W / 700W | 1521MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 1 NVIDIA H100 80GB HBM3 On | 00000000:3B:00.0 Off | 0 | +| N/A 36C P0 116W / 700W | 1521MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 2 NVIDIA H100 80GB HBM3 On | 00000000:4C:00.0 Off | 0 | +| N/A 34C P0 120W / 700W | 1521MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 3 NVIDIA H100 80GB HBM3 On | 00000000:5D:00.0 Off | 0 | +| N/A 43C P0 123W / 700W | 1521MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 4 NVIDIA H100 80GB HBM3 On | 00000000:9B:00.0 Off | 0 | +| N/A 44C P0 127W / 700W | 1521MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 5 NVIDIA H100 80GB HBM3 On | 00000000:BB:00.0 Off | 0 | +| N/A 35C 
P0 119W / 700W | 1521MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 6 NVIDIA H100 80GB HBM3 On | 00000000:CB:00.0 Off | 0 | +| N/A 43C P0 121W / 700W | 1521MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 7 NVIDIA H100 80GB HBM3 On | 00000000:DB:00.0 Off | 0 | +| N/A 34C P0 119W / 700W | 1521MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ + ++-----------------------------------------------------------------------------------------+ +| Processes: | +| GPU GI CI PID Type Process name GPU Memory | +| ID ID Usage | +|=========================================================================================| +| No running processes found | ++-----------------------------------------------------------------------------------------+ + +==================================================================================================== +val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path=./data/tokenizers/fineweb_1024_bpe.model +train_loader:dataset:fineweb10B_sp1024 train_shards:80 +val_loader:shards pattern=./data/datasets/fineweb10B_sp1024/fineweb_val_*.bin tokens:62021632 +model_params:11579440 +world_size:8 grad_accum_steps:1 +sdp_backends:cudnn=False flash=True mem_efficient=False math=False +flash_attention_3:True +attention_mode:gqa num_heads:8 num_kv_heads:4 +tie_embeddings:True embed_lr:0.05 head_lr:0.0 matrix_lr:0.04 scalar_lr:0.04 +train_batch_tokens:524288 train_seq_len:1024 iterations:20000 warmup_steps:20 max_wallclock_seconds:600.000 +seed:1337 +recurrence:enabled prelude:1 shared:4 loops:3 coda:1 effective_layers:14 +peri_norm:True birkhoff_mix:True +timestep_scale:enabled params:28672 +depth_embed:enabled +unique_norms:disabled +compile_mode:fullgraph=True 
+warmup_step:1/20 +warmup_step:2/20 +warmup_step:3/20 +warmup_step:4/20 +warmup_step:5/20 +warmup_step:6/20 +warmup_step:7/20 +warmup_step:8/20 +warmup_step:9/20 +warmup_step:10/20 +warmup_step:11/20 +warmup_step:12/20 +warmup_step:13/20 +warmup_step:14/20 +warmup_step:15/20 +warmup_step:16/20 +warmup_step:17/20 +warmup_step:18/20 +warmup_step:19/20 +warmup_step:20/20 +step:0/20000 val_loss:6.9379 val_bpb:4.1090 train_time:0ms step_avg:0.02ms +step:1/20000 train_loss:6.9377 train_time:39ms step_avg:39.49ms +step:2/20000 train_loss:9.2613 train_time:99ms step_avg:49.29ms +step:3/20000 train_loss:8.5831 train_time:160ms step_avg:53.19ms +step:4/20000 train_loss:10.0218 train_time:221ms step_avg:55.26ms +step:5/20000 train_loss:9.5785 train_time:283ms step_avg:56.57ms +step:6/20000 train_loss:9.1687 train_time:344ms step_avg:57.33ms +step:7/20000 train_loss:7.6094 train_time:407ms step_avg:58.12ms +step:8/20000 train_loss:6.6694 train_time:468ms step_avg:58.50ms +step:9/20000 train_loss:6.0951 train_time:530ms step_avg:58.84ms +step:10/20000 train_loss:5.7431 train_time:591ms step_avg:59.05ms +step:200/20000 train_loss:2.7843 train_time:12424ms step_avg:62.12ms +step:200 shared0_alpha:mean=0.457,std=0.046 shared1_alpha:mean=0.468,std=0.038 shared2_alpha:mean=0.480,std=0.036 shared3_alpha:mean=0.505,std=0.039 eff_mlp_scale:[v0:38.6956 v1:26.8032 v2:28.2976 v3:30.6838 v4:33.4271 v5:32.1342 v6:28.5849 v7:29.7857 v8:32.6312 v9:31.5419 v10:29.7340 v11:32.1805 v12:36.1331 v13:58.6540] eff_attn_scale:[v0:15.3561 v1:9.1967 v2:10.7374 v3:9.7527 v4:10.2931 v5:8.4148 v6:9.3658 v7:8.6098 v8:9.4288 v9:8.5265 v10:9.3266 v11:9.0288 v12:9.7038 v13:15.2747] eff_attn_bias:[v0:0.1340 v1:0.1098 v2:0.1091 v3:0.1112 v4:0.1250 v5:0.1264 v6:0.1174 v7:0.1146 v8:0.1236 v9:0.1257 v10:0.1229 v11:0.1139 v12:0.1119 v13:0.1126] eff_mlp_bias:[v0:0.1057 v1:0.1036 v2:0.0994 v3:0.1029 v4:0.1167 v5:0.1084 v6:0.1050 v7:0.1043 v8:0.1174 v9:0.1098 v10:0.1050 v11:0.1001 v12:0.1160 v13:0.1864] 
depth_emb_rms:[v0:0.1296 v1:0.1069 v2:0.1041 v3:0.1004 v4:0.1032 v5:0.1178 v6:0.1098 v7:0.1060 v8:0.1048 v9:0.1180 v10:0.1120 v11:0.1072 v12:0.1020 v13:0.1179] +step:200/20000 val_loss:2.7761 val_bpb:1.6441 train_time:12485ms step_avg:62.42ms +step:400/20000 train_loss:2.3659 train_time:24958ms step_avg:62.39ms +step:400 shared0_alpha:mean=0.468,std=0.048 shared1_alpha:mean=0.482,std=0.041 shared2_alpha:mean=0.505,std=0.039 shared3_alpha:mean=0.536,std=0.045 eff_mlp_scale:[v0:46.8796 v1:31.5451 v2:36.4508 v3:40.1127 v4:41.0147 v5:44.4734 v6:39.7495 v7:40.1127 v8:39.8283 v9:42.4049 v10:38.1001 v11:39.2700 v12:39.4894 v13:74.2800] eff_attn_scale:[v0:6.7533 v1:5.3036 v2:5.8570 v3:5.3215 v4:5.6249 v5:5.2481 v6:5.5821 v7:5.0373 v8:5.4426 v9:5.4147 v10:5.1421 v11:4.5207 v12:4.6874 v13:8.8097] eff_attn_bias:[v0:0.1568 v1:0.1422 v2:0.1478 v3:0.1692 v4:0.1823 v5:0.1706 v6:0.1650 v7:0.1623 v8:0.1733 v9:0.1588 v10:0.1512 v11:0.1291 v12:0.1257 v13:0.1222] eff_mlp_bias:[v0:0.1809 v1:0.1291 v2:0.1298 v3:0.1333 v4:0.1478 v5:0.1409 v6:0.1264 v7:0.1257 v8:0.1422 v9:0.1388 v10:0.1298 v11:0.1250 v12:0.1402 v13:0.2486] depth_emb_rms:[v0:0.1700 v1:0.1864 v2:0.1324 v3:0.1326 v4:0.1344 v5:0.1508 v6:0.1445 v7:0.1297 v8:0.1271 v9:0.1448 v10:0.1429 v11:0.1340 v12:0.1300 v13:0.1443] +step:400/20000 val_loss:2.5659 val_bpb:1.5197 train_time:24985ms step_avg:62.46ms +step:600/20000 train_loss:2.5846 train_time:37480ms step_avg:62.47ms +step:600 shared0_alpha:mean=0.474,std=0.047 shared1_alpha:mean=0.491,std=0.043 shared2_alpha:mean=0.524,std=0.040 shared3_alpha:mean=0.558,std=0.047 eff_mlp_scale:[v0:53.4933 v1:35.7252 v2:41.4469 v3:45.3333 v4:44.7265 v5:52.4714 v6:47.0668 v7:46.5778 v8:44.3770 v9:49.4943 v10:42.3250 v11:42.6667 v12:41.4069 v13:90.0951] eff_attn_scale:[v0:3.3362 v1:3.0608 v2:3.4289 v3:3.2435 v4:3.5282 v5:3.2209 v6:3.4289 v7:3.1508 v8:3.5472 v9:3.3809 v10:3.0197 v11:2.6689 v12:2.8074 v13:5.9203] eff_attn_bias:[v0:0.1851 v1:0.1823 v2:0.1892 v3:0.2141 v4:0.2348 v5:0.2210 v6:0.2237 
v7:0.2141 v8:0.2210 v9:0.1961 v10:0.1823 v11:0.1464 v12:0.1395 v13:0.1533] eff_mlp_bias:[v0:0.2679 v1:0.1609 v2:0.1637 v3:0.1650 v4:0.1892 v5:0.1864 v6:0.1588 v7:0.1526 v8:0.1795 v9:0.1795 v10:0.1609 v11:0.1581 v12:0.1685 v13:0.2638] depth_emb_rms:[v0:0.2441 v1:0.2752 v2:0.1673 v3:0.1687 v4:0.1680 v5:0.1957 v6:0.1924 v7:0.1643 v8:0.1550 v9:0.1847 v10:0.1863 v11:0.1680 v12:0.1652 v13:0.1739] +step:600/20000 val_loss:2.4892 val_bpb:1.4742 train_time:37507ms step_avg:62.51ms +step:800/20000 train_loss:2.3434 train_time:50029ms step_avg:62.54ms +step:800 shared0_alpha:mean=0.475,std=0.047 shared1_alpha:mean=0.497,std=0.046 shared2_alpha:mean=0.537,std=0.042 shared3_alpha:mean=0.572,std=0.050 eff_mlp_scale:[v0:59.1090 v1:40.0983 v2:45.5213 v3:49.6694 v4:47.7353 v5:58.6804 v6:51.7122 v7:51.1411 v8:48.0942 v9:54.3772 v10:45.1571 v11:45.2544 v12:43.4283 v13:103.2616] eff_attn_scale:[v0:2.1123 v1:2.2542 v2:2.4265 v3:2.4114 v4:2.6283 v5:2.4187 v6:2.4732 v7:2.3366 v8:2.7060 v9:2.5504 v10:2.1310 v11:1.8947 v12:2.0373 v13:4.5444] eff_attn_bias:[v0:0.2044 v1:0.2182 v2:0.2168 v3:0.2514 v4:0.2762 v5:0.2679 v6:0.2748 v7:0.2583 v8:0.2610 v9:0.2320 v10:0.2099 v11:0.1650 v12:0.1574 v13:0.1920] eff_mlp_bias:[v0:0.3439 v1:0.1892 v2:0.1961 v3:0.1920 v4:0.2251 v5:0.2306 v6:0.1878 v7:0.1823 v8:0.2127 v9:0.2154 v10:0.1892 v11:0.1892 v12:0.1933 v13:0.2735] depth_emb_rms:[v0:0.3053 v1:0.3542 v2:0.1983 v3:0.2030 v4:0.1962 v5:0.2334 v6:0.2390 v7:0.1964 v8:0.1863 v9:0.2212 v10:0.2260 v11:0.1986 v12:0.1982 v13:0.2001] +step:800/20000 val_loss:2.4265 val_bpb:1.4371 train_time:50056ms step_avg:62.57ms +step:1000/20000 train_loss:2.4221 train_time:62570ms step_avg:62.57ms +step:1000 shared0_alpha:mean=0.475,std=0.048 shared1_alpha:mean=0.500,std=0.046 shared2_alpha:mean=0.547,std=0.043 shared3_alpha:mean=0.582,std=0.051 eff_mlp_scale:[v0:65.7357 v1:44.3223 v2:49.2417 v3:53.4368 v4:50.8610 v5:63.8404 v6:56.3836 v7:54.9527 v8:51.2296 v9:58.9608 v10:47.9261 v11:47.9415 v12:45.7012 v13:114.1758] 
eff_attn_scale:[v0:1.6259 v1:1.8768 v2:2.0141 v3:1.9709 v4:2.1924 v5:1.9931 v6:1.9869 v7:1.9444 v8:2.2476 v9:2.1241 v10:1.6943 v11:1.5476 v12:1.6547 v13:3.7794] eff_attn_bias:[v0:0.2196 v1:0.2444 v2:0.2458 v3:0.2859 v4:0.3080 v5:0.3025 v6:0.3176 v7:0.2942 v8:0.2914 v9:0.2638 v10:0.2306 v11:0.1809 v12:0.1754 v13:0.2348] eff_mlp_bias:[v0:0.4088 v1:0.2141 v2:0.2237 v3:0.2210 v4:0.2458 v5:0.2665 v6:0.2168 v7:0.2058 v8:0.2389 v9:0.2486 v10:0.2127 v11:0.2141 v12:0.2141 v13:0.2859] depth_emb_rms:[v0:0.3618 v1:0.4252 v2:0.2262 v3:0.2343 v4:0.2272 v5:0.2584 v6:0.2792 v7:0.2280 v8:0.2133 v9:0.2506 v10:0.2611 v11:0.2244 v12:0.2261 v13:0.2230] +step:1000/20000 val_loss:2.3861 val_bpb:1.4132 train_time:62598ms step_avg:62.60ms +step:1200/20000 train_loss:2.4363 train_time:75121ms step_avg:62.60ms +step:1200 shared0_alpha:mean=0.476,std=0.048 shared1_alpha:mean=0.504,std=0.047 shared2_alpha:mean=0.555,std=0.044 shared3_alpha:mean=0.590,std=0.051 eff_mlp_scale:[v0:72.5459 v1:47.6534 v2:52.8417 v3:57.1090 v4:53.1888 v5:68.4362 v6:59.7844 v7:58.2745 v8:53.9433 v9:62.9781 v10:50.1418 v11:50.1160 v12:48.2849 v13:124.6918] eff_attn_scale:[v0:1.3351 v1:1.6068 v2:1.7141 v3:1.7237 v4:1.8943 v5:1.7122 v6:1.7017 v7:1.6870 v8:1.9448 v9:1.7912 v10:1.4346 v11:1.3203 v12:1.4081 v13:3.2801] eff_attn_bias:[v0:0.2348 v1:0.2721 v2:0.2679 v3:0.3121 v4:0.3384 v5:0.3301 v6:0.3536 v7:0.3246 v8:0.3190 v9:0.2928 v10:0.2486 v11:0.1961 v12:0.1933 v13:0.2735] eff_mlp_bias:[v0:0.4696 v1:0.2403 v2:0.2514 v3:0.2444 v4:0.2652 v5:0.2983 v6:0.2403 v7:0.2279 v8:0.2596 v9:0.2762 v10:0.2320 v11:0.2362 v12:0.2362 v13:0.2942] depth_emb_rms:[v0:0.4156 v1:0.4889 v2:0.2548 v3:0.2632 v4:0.2522 v5:0.2798 v6:0.3129 v7:0.2540 v8:0.2371 v9:0.2745 v10:0.2931 v11:0.2478 v12:0.2509 v13:0.2461] +step:1200/20000 val_loss:2.3528 val_bpb:1.3934 train_time:75148ms step_avg:62.62ms +step:1400/20000 train_loss:2.4833 train_time:87655ms step_avg:62.61ms +step:1400 shared0_alpha:mean=0.474,std=0.048 shared1_alpha:mean=0.507,std=0.049 
shared2_alpha:mean=0.560,std=0.045 shared3_alpha:mean=0.595,std=0.053 eff_mlp_scale:[v0:78.2452 v1:50.8255 v2:56.0025 v3:59.9916 v4:55.9421 v5:72.3618 v6:63.1014 v7:61.5808 v8:56.3279 v9:66.3316 v10:52.0587 v11:52.4430 v12:50.9266 v13:133.7384] eff_attn_scale:[v0:1.1608 v1:1.4695 v2:1.5332 v3:1.5607 v4:1.7436 v5:1.5380 v6:1.5215 v7:1.5145 v8:1.7678 v9:1.5878 v10:1.2698 v11:1.1792 v12:1.2774 v13:2.9704] eff_attn_bias:[v0:0.2472 v1:0.2942 v2:0.2873 v3:0.3356 v4:0.3646 v5:0.3536 v6:0.3839 v7:0.3494 v8:0.3439 v9:0.3121 v10:0.2652 v11:0.2085 v12:0.2085 v13:0.3107] eff_mlp_bias:[v0:0.5220 v1:0.2610 v2:0.2748 v3:0.2679 v4:0.2817 v5:0.3273 v6:0.2624 v7:0.2486 v8:0.2804 v9:0.2997 v10:0.2486 v11:0.2555 v12:0.2514 v13:0.3052] depth_emb_rms:[v0:0.4640 v1:0.5449 v2:0.2785 v3:0.2898 v4:0.2772 v5:0.3006 v6:0.3452 v7:0.2793 v8:0.2599 v9:0.2985 v10:0.3193 v11:0.2665 v12:0.2715 v13:0.2644] +step:1400/20000 val_loss:2.3306 val_bpb:1.3803 train_time:87683ms step_avg:62.63ms +step:1600/20000 train_loss:2.1520 train_time:100185ms step_avg:62.62ms +step:1600 shared0_alpha:mean=0.473,std=0.049 shared1_alpha:mean=0.509,std=0.049 shared2_alpha:mean=0.565,std=0.045 shared3_alpha:mean=0.598,std=0.054 eff_mlp_scale:[v0:84.0843 v1:54.3045 v2:59.1122 v3:63.1090 v4:58.1901 v5:75.6306 v6:65.9483 v7:64.3227 v8:58.5833 v9:69.0349 v10:53.8846 v11:54.6136 v12:53.0788 v13:141.4585] eff_attn_scale:[v0:1.0353 v1:1.3492 v2:1.4371 v3:1.4582 v4:1.6103 v5:1.3846 v6:1.4258 v7:1.4192 v8:1.6219 v9:1.4494 v10:1.1645 v11:1.0742 v12:1.1643 v13:2.7388] eff_attn_bias:[v0:0.2583 v1:0.3107 v2:0.3066 v3:0.3591 v4:0.3922 v5:0.3729 v6:0.4060 v7:0.3729 v8:0.3646 v9:0.3315 v10:0.2776 v11:0.2196 v12:0.2237 v13:0.3466] eff_mlp_bias:[v0:0.5773 v1:0.2776 v2:0.2969 v3:0.2914 v4:0.3011 v5:0.3522 v6:0.2831 v7:0.2679 v8:0.3011 v9:0.3204 v10:0.2638 v11:0.2721 v12:0.2679 v13:0.3163] depth_emb_rms:[v0:0.5115 v1:0.6054 v2:0.2977 v3:0.3153 v4:0.3025 v5:0.3233 v6:0.3730 v7:0.3026 v8:0.2821 v9:0.3234 v10:0.3441 v11:0.2850 v12:0.2904 
v13:0.2823] +step:1600/20000 val_loss:2.3182 val_bpb:1.3729 train_time:100212ms step_avg:62.63ms +step:1800/20000 train_loss:2.2545 train_time:112713ms step_avg:62.62ms +step:1800 shared0_alpha:mean=0.472,std=0.050 shared1_alpha:mean=0.513,std=0.050 shared2_alpha:mean=0.569,std=0.046 shared3_alpha:mean=0.602,std=0.054 eff_mlp_scale:[v0:89.2811 v1:57.5202 v2:62.1651 v3:65.9594 v4:60.9152 v5:79.5397 v6:68.7088 v7:66.7839 v8:60.9152 v9:71.9003 v10:56.0304 v11:56.8900 v12:55.7054 v13:149.2826] eff_attn_scale:[v0:0.9440 v1:1.2637 v2:1.3389 v3:1.3677 v4:1.5134 v5:1.2981 v6:1.2953 v7:1.3193 v8:1.5246 v9:1.3498 v10:1.0667 v11:0.9962 v12:1.0874 v13:2.5630] eff_attn_bias:[v0:0.2790 v1:0.3301 v2:0.3246 v3:0.3784 v4:0.4143 v5:0.3922 v6:0.4309 v7:0.3977 v8:0.3867 v9:0.3466 v10:0.2900 v11:0.2306 v12:0.2389 v13:0.3839] eff_mlp_bias:[v0:0.6242 v1:0.2983 v2:0.3218 v3:0.3121 v4:0.3190 v5:0.3757 v6:0.3011 v7:0.2845 v8:0.3204 v9:0.3411 v10:0.2776 v11:0.2886 v12:0.2845 v13:0.3301] depth_emb_rms:[v0:0.5580 v1:0.6585 v2:0.3220 v3:0.3412 v4:0.3254 v5:0.3448 v6:0.4009 v7:0.3239 v8:0.3016 v9:0.3456 v10:0.3672 v11:0.3032 v12:0.3103 v13:0.2998] +step:1800/20000 val_loss:2.3031 val_bpb:1.3640 train_time:112741ms step_avg:62.63ms +step:2000/20000 train_loss:2.3071 train_time:125239ms step_avg:62.62ms +step:2000 shared0_alpha:mean=0.471,std=0.050 shared1_alpha:mean=0.515,std=0.051 shared2_alpha:mean=0.573,std=0.048 shared3_alpha:mean=0.605,std=0.055 eff_mlp_scale:[v0:95.4208 v1:60.3451 v2:64.8785 v3:69.1411 v4:63.2601 v5:82.7459 v6:71.1168 v7:69.5602 v8:63.2601 v9:74.5171 v10:57.8084 v11:58.6652 v12:57.9544 v13:156.8431] eff_attn_scale:[v0:0.8723 v1:1.1896 v2:1.2594 v3:1.3078 v4:1.4275 v5:1.2228 v6:1.2222 v7:1.2395 v8:1.4384 v9:1.2394 v10:1.0043 v11:0.9349 v12:1.0189 v13:2.3864] eff_attn_bias:[v0:0.2983 v1:0.3522 v2:0.3384 v3:0.3950 v4:0.4392 v5:0.4060 v6:0.4502 v7:0.4143 v8:0.4088 v9:0.3591 v10:0.3025 v11:0.2431 v12:0.2527 v13:0.4171] eff_mlp_bias:[v0:0.6795 v1:0.3204 v2:0.3439 v3:0.3301 
v4:0.3384 v5:0.3977 v6:0.3190 v7:0.3025 v8:0.3384 v9:0.3591 v10:0.2928 v11:0.3038 v12:0.2997 v13:0.3439] depth_emb_rms:[v0:0.6046 v1:0.7163 v2:0.3458 v3:0.3670 v4:0.3466 v5:0.3674 v6:0.4270 v7:0.3453 v8:0.3218 v9:0.3684 v10:0.3895 v11:0.3198 v12:0.3295 v13:0.3174] +step:2000/20000 val_loss:2.2886 val_bpb:1.3554 train_time:125266ms step_avg:62.63ms +step:2200/20000 train_loss:2.1332 train_time:137757ms step_avg:62.62ms +step:2200 shared0_alpha:mean=0.470,std=0.051 shared1_alpha:mean=0.518,std=0.051 shared2_alpha:mean=0.577,std=0.048 shared3_alpha:mean=0.607,std=0.056 eff_mlp_scale:[v0:100.7329 v1:63.2257 v2:67.7605 v3:72.0593 v4:65.7137 v5:86.0055 v6:73.6895 v7:72.0593 v8:65.2978 v9:77.1725 v10:60.1374 v11:60.9733 v12:60.3069 v13:163.4882] eff_attn_scale:[v0:0.8225 v1:1.1794 v2:1.2198 v3:1.2571 v4:1.3683 v5:1.1685 v6:1.1727 v7:1.1850 v8:1.3790 v9:1.1848 v10:0.9633 v11:0.9016 v12:0.9781 v13:2.2630] eff_attn_bias:[v0:0.3176 v1:0.3701 v2:0.3536 v3:0.4143 v4:0.4585 v5:0.4171 v6:0.4696 v7:0.4337 v8:0.4281 v9:0.3729 v10:0.3135 v11:0.2527 v12:0.2665 v13:0.4502] eff_mlp_bias:[v0:0.7237 v1:0.3411 v2:0.3646 v3:0.3494 v4:0.3563 v5:0.4171 v6:0.3370 v7:0.3218 v8:0.3536 v9:0.3757 v10:0.3066 v11:0.3176 v12:0.3135 v13:0.3618] depth_emb_rms:[v0:0.6512 v1:0.7673 v2:0.3699 v3:0.3907 v4:0.3698 v5:0.3897 v6:0.4498 v7:0.3646 v8:0.3426 v9:0.3895 v10:0.4111 v11:0.3369 v12:0.3462 v13:0.3335] +step:2200/20000 val_loss:2.2810 val_bpb:1.3509 train_time:137785ms step_avg:62.63ms +step:2400/20000 train_loss:2.2567 train_time:150273ms step_avg:62.61ms +step:2400 shared0_alpha:mean=0.467,std=0.051 shared1_alpha:mean=0.520,std=0.052 shared2_alpha:mean=0.580,std=0.049 shared3_alpha:mean=0.608,std=0.057 eff_mlp_scale:[v0:105.5002 v1:66.0419 v2:70.7784 v3:74.4072 v4:67.6635 v5:88.6848 v6:76.3549 v7:74.4072 v8:67.6635 v9:79.7220 v10:61.3413 v11:62.7270 v12:62.5888 v13:170.3139] eff_attn_scale:[v0:0.7755 v1:1.1510 v2:1.1752 v3:1.2043 v4:1.3166 v5:1.1135 v6:1.1242 v7:1.1391 v8:1.3218 v9:1.1456 v10:0.9044 
v11:0.8531 v12:0.9300 v13:2.1861] eff_attn_bias:[v0:0.3480 v1:0.3867 v2:0.3646 v3:0.4309 v4:0.4778 v5:0.4309 v6:0.4889 v7:0.4502 v8:0.4447 v9:0.3839 v10:0.3246 v11:0.2624 v12:0.2790 v13:0.4806] eff_mlp_bias:[v0:0.7679 v1:0.3591 v2:0.3839 v3:0.3674 v4:0.3729 v5:0.4364 v6:0.3536 v7:0.3384 v8:0.3729 v9:0.3922 v10:0.3204 v11:0.3328 v12:0.3287 v13:0.3784] depth_emb_rms:[v0:0.6938 v1:0.8176 v2:0.3903 v3:0.4120 v4:0.3882 v5:0.4111 v6:0.4713 v7:0.3849 v8:0.3611 v9:0.4099 v10:0.4313 v11:0.3533 v12:0.3634 v13:0.3494] +step:2400/20000 val_loss:2.2693 val_bpb:1.3440 train_time:150301ms step_avg:62.63ms +step:2600/20000 train_loss:2.4716 train_time:162787ms step_avg:62.61ms +step:2600 shared0_alpha:mean=0.466,std=0.052 shared1_alpha:mean=0.523,std=0.053 shared2_alpha:mean=0.583,std=0.050 shared3_alpha:mean=0.610,std=0.058 eff_mlp_scale:[v0:111.3693 v1:68.9932 v2:73.0569 v3:77.2129 v4:70.5239 v5:91.9910 v6:78.2753 v7:76.7742 v8:70.0939 v9:81.9295 v10:63.0551 v11:64.4904 v12:64.9336 v13:176.0035] eff_attn_scale:[v0:0.7408 v1:1.1179 v2:1.1509 v3:1.1923 v4:1.2776 v5:1.0865 v6:1.0856 v7:1.1174 v8:1.2776 v9:1.0604 v10:0.8745 v11:0.8281 v12:0.9155 v13:2.1231] eff_attn_bias:[v0:0.3757 v1:0.4060 v2:0.3812 v3:0.4475 v4:0.4999 v5:0.4447 v6:0.5027 v7:0.4668 v8:0.4613 v9:0.3922 v10:0.3315 v11:0.2707 v12:0.2900 v13:0.5055] eff_mlp_bias:[v0:0.8065 v1:0.3812 v2:0.4060 v3:0.3867 v4:0.3950 v5:0.4530 v6:0.3674 v7:0.3536 v8:0.3895 v9:0.4088 v10:0.3315 v11:0.3453 v12:0.3397 v13:0.3950] depth_emb_rms:[v0:0.7415 v1:0.8657 v2:0.4137 v3:0.4372 v4:0.4100 v5:0.4353 v6:0.4938 v7:0.4016 v8:0.3792 v9:0.4320 v10:0.4507 v11:0.3676 v12:0.3792 v13:0.3647] +step:2600/20000 val_loss:2.2869 val_bpb:1.3545 train_time:162815ms step_avg:62.62ms +step:2800/20000 train_loss:2.2910 train_time:175299ms step_avg:62.61ms +step:2800 shared0_alpha:mean=0.464,std=0.051 shared1_alpha:mean=0.525,std=0.054 shared2_alpha:mean=0.585,std=0.050 shared3_alpha:mean=0.612,std=0.057 eff_mlp_scale:[v0:117.0871 v1:71.8433 v2:76.4383 
v3:80.0945 v4:72.3902 v5:94.6584 v6:81.2985 v7:79.2045 v8:71.9541 v9:83.9790 v10:64.9504 v11:66.7454 v12:67.1572 v13:181.6837] eff_attn_scale:[v0:0.7093 v1:1.0688 v2:1.1146 v3:1.1460 v4:1.2418 v5:1.0533 v6:1.0453 v7:1.0634 v8:1.2367 v9:1.0430 v10:0.8372 v11:0.7915 v12:0.8754 v13:2.0156] eff_attn_bias:[v0:0.3977 v1:0.4226 v2:0.3895 v3:0.4640 v4:0.5193 v5:0.4558 v6:0.5193 v7:0.4778 v8:0.4778 v9:0.4033 v10:0.3411 v11:0.2790 v12:0.3011 v13:0.5331] eff_mlp_bias:[v0:0.8507 v1:0.3950 v2:0.4254 v3:0.4060 v4:0.4143 v5:0.4696 v6:0.3839 v7:0.3674 v8:0.4060 v9:0.4254 v10:0.3439 v11:0.3591 v12:0.3522 v13:0.4116] depth_emb_rms:[v0:0.7855 v1:0.9215 v2:0.4320 v3:0.4580 v4:0.4309 v5:0.4601 v6:0.5141 v7:0.4191 v8:0.3942 v9:0.4512 v10:0.4715 v11:0.3835 v12:0.3962 v13:0.3807] +step:2800/20000 val_loss:2.2570 val_bpb:1.3367 train_time:175328ms step_avg:62.62ms +step:3000/20000 train_loss:2.2778 train_time:187805ms step_avg:62.60ms +step:3000 shared0_alpha:mean=0.462,std=0.053 shared1_alpha:mean=0.527,std=0.055 shared2_alpha:mean=0.588,std=0.051 shared3_alpha:mean=0.613,std=0.058 eff_mlp_scale:[v0:122.1320 v1:74.2075 v2:78.5717 v3:82.4292 v4:74.8204 v5:97.3052 v6:83.0360 v7:81.0779 v8:73.9349 v9:86.4935 v10:66.5181 v11:68.0153 v12:69.5077 v13:187.2736] eff_attn_scale:[v0:0.6847 v1:1.0415 v2:1.0908 v3:1.1294 v4:1.2240 v5:1.0061 v6:1.0084 v7:1.0380 v8:1.2140 v9:1.0061 v10:0.8096 v11:0.7689 v12:0.8513 v13:1.9485] eff_attn_bias:[v0:0.4309 v1:0.4447 v2:0.4033 v3:0.4806 v4:0.5386 v5:0.4640 v6:0.5359 v7:0.4944 v8:0.4944 v9:0.4143 v10:0.3508 v11:0.2873 v12:0.3135 v13:0.5580] eff_mlp_bias:[v0:0.8894 v1:0.4143 v2:0.4419 v3:0.4226 v4:0.4309 v5:0.4861 v6:0.3977 v7:0.3812 v8:0.4198 v9:0.4392 v10:0.3591 v11:0.3729 v12:0.3646 v13:0.4337] depth_emb_rms:[v0:0.8377 v1:0.9719 v2:0.4541 v3:0.4801 v4:0.4517 v5:0.4816 v6:0.5349 v7:0.4369 v8:0.4118 v9:0.4705 v10:0.4902 v11:0.4000 v12:0.4133 v13:0.3966] +step:3000/20000 val_loss:2.2498 val_bpb:1.3324 train_time:187833ms step_avg:62.61ms +step:3200/20000 
train_loss:2.2457 train_time:200308ms step_avg:62.60ms +step:3200 shared0_alpha:mean=0.461,std=0.052 shared1_alpha:mean=0.529,std=0.055 shared2_alpha:mean=0.590,std=0.051 shared3_alpha:mean=0.614,std=0.058 eff_mlp_scale:[v0:127.3951 v1:76.9729 v2:81.4804 v3:84.9515 v4:76.8976 v5:99.3198 v6:85.5544 v7:83.5813 v8:75.9982 v9:88.3946 v10:67.9003 v11:70.3362 v12:71.5013 v13:191.7119] eff_attn_scale:[v0:0.6622 v1:1.0377 v2:1.0601 v3:1.1232 v4:1.1892 v5:0.9974 v6:0.9785 v7:1.0224 v8:1.1743 v9:0.9874 v10:0.7867 v11:0.7584 v12:0.8275 v13:1.9033] eff_attn_bias:[v0:0.4585 v1:0.4558 v2:0.4143 v3:0.4972 v4:0.5552 v5:0.4751 v6:0.5469 v7:0.5055 v8:0.5138 v9:0.4226 v10:0.3563 v11:0.2955 v12:0.3232 v13:0.5856] eff_mlp_bias:[v0:0.9281 v1:0.4281 v2:0.4640 v3:0.4392 v4:0.4502 v5:0.5055 v6:0.4116 v7:0.3977 v8:0.4337 v9:0.4558 v10:0.3701 v11:0.3839 v12:0.3729 v13:0.4530] depth_emb_rms:[v0:0.8830 v1:1.0186 v2:0.4723 v3:0.5014 v4:0.4711 v5:0.5040 v6:0.5581 v7:0.4543 v8:0.4299 v9:0.4892 v10:0.5086 v11:0.4138 v12:0.4281 v13:0.4118] +step:3200/20000 val_loss:2.2449 val_bpb:1.3295 train_time:200336ms step_avg:62.61ms +step:3400/20000 train_loss:2.2130 train_time:212818ms step_avg:62.59ms +step:3400 shared0_alpha:mean=0.459,std=0.053 shared1_alpha:mean=0.531,std=0.055 shared2_alpha:mean=0.592,std=0.052 shared3_alpha:mean=0.615,std=0.059 eff_mlp_scale:[v0:132.3164 v1:79.6543 v2:84.3633 v3:87.8727 v4:79.4248 v5:102.3407 v6:88.0313 v7:86.0228 v8:78.0554 v9:90.7454 v10:69.6914 v11:72.1481 v12:73.9472 v13:196.3649] eff_attn_scale:[v0:0.6533 v1:1.0262 v2:1.0440 v3:1.0930 v4:1.1898 v5:0.9816 v6:0.9585 v7:0.9937 v8:1.1603 v9:0.9617 v10:0.7735 v11:0.7382 v12:0.8162 v13:1.8439] eff_attn_bias:[v0:0.4889 v1:0.4751 v2:0.4254 v3:0.5082 v4:0.5745 v5:0.4834 v6:0.5635 v7:0.5220 v8:0.5303 v9:0.4337 v10:0.3646 v11:0.3052 v12:0.3356 v13:0.6104] eff_mlp_bias:[v0:0.9667 v1:0.4475 v2:0.4834 v3:0.4558 v4:0.4640 v5:0.5220 v6:0.4254 v7:0.4116 v8:0.4530 v9:0.4723 v10:0.3812 v11:0.3977 v12:0.3812 v13:0.4696] 
depth_emb_rms:[v0:0.9387 v1:1.0701 v2:0.4942 v3:0.5256 v4:0.4898 v5:0.5243 v6:0.5790 v7:0.4726 v8:0.4473 v9:0.5096 v10:0.5260 v11:0.4285 v12:0.4443 v13:0.4251] +step:3400/20000 val_loss:2.2411 val_bpb:1.3273 train_time:212845ms step_avg:62.60ms +step:3600/20000 train_loss:2.1748 train_time:225320ms step_avg:62.59ms +step:3600 shared0_alpha:mean=0.457,std=0.053 shared1_alpha:mean=0.533,std=0.056 shared2_alpha:mean=0.594,std=0.052 shared3_alpha:mean=0.615,std=0.059 eff_mlp_scale:[v0:137.3354 v1:81.9493 v2:86.7331 v3:90.7998 v4:81.4154 v5:104.8544 v6:90.4436 v7:87.9916 v8:80.0277 v9:92.1293 v10:71.4273 v11:73.9504 v12:76.3270 v13:201.7179] eff_attn_scale:[v0:0.6323 v1:1.0033 v2:1.0301 v3:1.0917 v4:1.1811 v5:0.9495 v6:0.9501 v7:0.9830 v8:1.1421 v9:0.9348 v10:0.7526 v11:0.7325 v12:0.8004 v13:1.8331] eff_attn_bias:[v0:0.5165 v1:0.4889 v2:0.4309 v3:0.5248 v4:0.5939 v5:0.4972 v6:0.5800 v7:0.5359 v8:0.5441 v9:0.4419 v10:0.3729 v11:0.3135 v12:0.3466 v13:0.6325] eff_mlp_bias:[v0:0.9999 v1:0.4640 v2:0.4999 v3:0.4751 v4:0.4806 v5:0.5386 v6:0.4392 v7:0.4254 v8:0.4640 v9:0.4889 v10:0.3922 v11:0.4088 v12:0.3922 v13:0.4861] depth_emb_rms:[v0:0.9887 v1:1.1188 v2:0.5122 v3:0.5446 v4:0.5110 v5:0.5456 v6:0.6005 v7:0.4903 v8:0.4634 v9:0.5286 v10:0.5462 v11:0.4431 v12:0.4586 v13:0.4401] +step:3600/20000 val_loss:2.2333 val_bpb:1.3227 train_time:225348ms step_avg:62.60ms +step:3800/20000 train_loss:2.2758 train_time:237816ms step_avg:62.58ms +step:3800 shared0_alpha:mean=0.455,std=0.053 shared1_alpha:mean=0.535,std=0.057 shared2_alpha:mean=0.595,std=0.052 shared3_alpha:mean=0.616,std=0.060 eff_mlp_scale:[v0:143.1202 v1:84.5629 v2:89.6118 v3:93.3521 v4:84.1420 v5:107.7661 v6:92.4268 v7:90.5089 v8:82.7318 v9:94.3598 v10:72.7216 v11:75.8190 v12:78.5012 v13:206.3245] eff_attn_scale:[v0:0.6269 v1:1.0028 v2:1.0230 v3:1.0729 v4:1.1562 v5:0.9493 v6:0.9249 v7:0.9745 v8:1.1271 v9:0.9152 v10:0.7380 v11:0.7121 v12:0.7919 v13:1.7924] eff_attn_bias:[v0:0.5469 v1:0.5055 v2:0.4475 v3:0.5386 v4:0.6132 
v5:0.5027 v6:0.5911 v7:0.5524 v8:0.5607 v9:0.4502 v10:0.3812 v11:0.3204 v12:0.3591 v13:0.6574] eff_mlp_bias:[v0:1.0386 v1:0.4806 v2:0.5165 v3:0.4917 v4:0.4944 v5:0.5552 v6:0.4530 v7:0.4392 v8:0.4806 v9:0.5027 v10:0.4033 v11:0.4226 v12:0.4033 v13:0.5027] depth_emb_rms:[v0:1.0427 v1:1.1664 v2:0.5338 v3:0.5663 v4:0.5321 v5:0.5646 v6:0.6217 v7:0.5074 v8:0.4810 v9:0.5477 v10:0.5631 v11:0.4569 v12:0.4747 v13:0.4553] +step:3800/20000 val_loss:2.2301 val_bpb:1.3208 train_time:237844ms step_avg:62.59ms +step:4000/20000 train_loss:2.2174 train_time:250316ms step_avg:62.58ms +step:4000 shared0_alpha:mean=0.453,std=0.054 shared1_alpha:mean=0.537,std=0.057 shared2_alpha:mean=0.598,std=0.053 shared3_alpha:mean=0.617,std=0.061 eff_mlp_scale:[v0:147.4537 v1:87.0037 v2:92.0857 v3:95.8297 v4:86.1765 v5:109.9268 v6:94.9337 v7:92.4756 v8:84.2721 v9:96.3813 v10:74.5229 v11:77.6220 v12:80.9393 v13:211.6888] eff_attn_scale:[v0:0.6304 v1:0.9980 v2:1.0020 v3:1.0551 v4:1.1576 v5:0.9253 v6:0.9143 v7:0.9487 v8:1.1138 v9:0.9011 v10:0.7296 v11:0.6988 v12:0.7733 v13:1.7277] eff_attn_bias:[v0:0.5800 v1:0.5220 v2:0.4585 v3:0.5524 v4:0.6298 v5:0.5138 v6:0.6049 v7:0.5635 v8:0.5773 v9:0.4585 v10:0.3867 v11:0.3287 v12:0.3674 v13:0.6767] eff_mlp_bias:[v0:1.0662 v1:0.4999 v2:0.5331 v3:0.5082 v4:0.5110 v5:0.5718 v6:0.4640 v7:0.4502 v8:0.4944 v9:0.5138 v10:0.4143 v11:0.4364 v12:0.4116 v13:0.5193] depth_emb_rms:[v0:1.0984 v1:1.2109 v2:0.5543 v3:0.5877 v4:0.5499 v5:0.5843 v6:0.6410 v7:0.5219 v8:0.4954 v9:0.5657 v10:0.5795 v11:0.4688 v12:0.4887 v13:0.4703] +step:4000/20000 val_loss:2.2243 val_bpb:1.3174 train_time:250344ms step_avg:62.59ms +step:4200/20000 train_loss:2.2266 train_time:262885ms step_avg:62.59ms +step:4200 shared0_alpha:mean=0.451,std=0.054 shared1_alpha:mean=0.539,std=0.058 shared2_alpha:mean=0.599,std=0.053 shared3_alpha:mean=0.617,std=0.061 eff_mlp_scale:[v0:152.3157 v1:89.5307 v2:94.5620 v3:98.4519 v4:88.8245 v5:112.1767 v6:96.9621 v7:95.0570 v8:86.4108 v9:98.4838 v10:76.3216 v11:79.5375 
v12:83.0316 v13:216.0071] eff_attn_scale:[v0:0.6110 v1:0.9647 v2:1.0075 v3:1.0635 v4:1.1355 v5:0.9219 v6:0.9017 v7:0.9525 v8:1.0782 v9:0.8744 v10:0.7131 v11:0.6936 v12:0.7538 v13:1.7109] eff_attn_bias:[v0:0.6077 v1:0.5359 v2:0.4668 v3:0.5662 v4:0.6463 v5:0.5220 v6:0.6160 v7:0.5745 v8:0.5939 v9:0.4668 v10:0.3922 v11:0.3370 v12:0.3812 v13:0.6988] eff_mlp_bias:[v0:1.0993 v1:0.5165 v2:0.5497 v3:0.5248 v4:0.5248 v5:0.5856 v6:0.4806 v7:0.4640 v8:0.5110 v9:0.5276 v10:0.4254 v11:0.4475 v12:0.4171 v13:0.5359] depth_emb_rms:[v0:1.1546 v1:1.2589 v2:0.5723 v3:0.6051 v4:0.5694 v5:0.6042 v6:0.6611 v7:0.5398 v8:0.5136 v9:0.5875 v10:0.5969 v11:0.4837 v12:0.5024 v13:0.4815] +step:4200/20000 val_loss:2.2205 val_bpb:1.3151 train_time:262912ms step_avg:62.60ms +step:4400/20000 train_loss:2.1706 train_time:275384ms step_avg:62.59ms +step:4400 shared0_alpha:mean=0.450,std=0.054 shared1_alpha:mean=0.541,std=0.058 shared2_alpha:mean=0.601,std=0.053 shared3_alpha:mean=0.618,std=0.061 eff_mlp_scale:[v0:158.5937 v1:92.0819 v2:97.6452 v3:101.6867 v4:90.9773 v5:114.4371 v6:99.1026 v7:97.2656 v8:88.5317 v9:100.5982 v10:77.7275 v11:81.5459 v12:85.5969 v13:221.2793] eff_attn_scale:[v0:0.6025 v1:0.9820 v2:0.9919 v3:1.0600 v4:1.1288 v5:0.8962 v6:0.8777 v7:0.9447 v8:1.0809 v9:0.8724 v10:0.7085 v11:0.6913 v12:0.7557 v13:1.6816] eff_attn_bias:[v0:0.6408 v1:0.5580 v2:0.4778 v3:0.5800 v4:0.6629 v5:0.5359 v6:0.6270 v7:0.5856 v8:0.6077 v9:0.4751 v10:0.4005 v11:0.3439 v12:0.3895 v13:0.7182] eff_mlp_bias:[v0:1.1325 v1:0.5359 v2:0.5690 v3:0.5441 v4:0.5414 v5:0.6021 v6:0.4944 v7:0.4751 v8:0.5220 v9:0.5441 v10:0.4337 v11:0.4558 v12:0.4281 v13:0.5497] depth_emb_rms:[v0:1.2117 v1:1.3061 v2:0.5941 v3:0.6252 v4:0.5907 v5:0.6258 v6:0.6825 v7:0.5573 v8:0.5271 v9:0.6051 v10:0.6145 v11:0.4972 v12:0.5158 v13:0.4947] +step:4400/20000 val_loss:2.2211 val_bpb:1.3155 train_time:275412ms step_avg:62.59ms +step:4600/20000 train_loss:2.0328 train_time:287982ms step_avg:62.60ms +step:4600 shared0_alpha:mean=0.447,std=0.054 
shared1_alpha:mean=0.543,std=0.059 shared2_alpha:mean=0.602,std=0.054 shared3_alpha:mean=0.618,std=0.062 eff_mlp_scale:[v0:163.2717 v1:95.1636 v2:100.3145 v3:104.2890 v4:93.0246 v5:117.2071 v6:101.7897 v7:99.3229 v8:90.5505 v9:102.1530 v10:79.1698 v11:83.4312 v12:87.5816 v13:225.7745] eff_attn_scale:[v0:0.6026 v1:0.9734 v2:1.0004 v3:1.0610 v4:1.1338 v5:0.8884 v6:0.8856 v7:0.9370 v8:1.0910 v9:0.8553 v10:0.6975 v11:0.6844 v12:0.7575 v13:1.6593] eff_attn_bias:[v0:0.6740 v1:0.5718 v2:0.4861 v3:0.5939 v4:0.6767 v5:0.5441 v6:0.6381 v7:0.5966 v8:0.6187 v9:0.4806 v10:0.4060 v11:0.3494 v12:0.3977 v13:0.7347] eff_mlp_bias:[v0:1.1656 v1:0.5524 v2:0.5856 v3:0.5607 v4:0.5580 v5:0.6160 v6:0.5055 v7:0.4889 v8:0.5331 v9:0.5580 v10:0.4447 v11:0.4668 v12:0.4364 v13:0.5662] depth_emb_rms:[v0:1.2670 v1:1.3527 v2:0.6128 v3:0.6487 v4:0.6102 v5:0.6470 v6:0.7015 v7:0.5712 v8:0.5424 v9:0.6214 v10:0.6321 v11:0.5108 v12:0.5306 v13:0.5071] +step:4600/20000 val_loss:2.2169 val_bpb:1.3129 train_time:288011ms step_avg:62.61ms +step:4800/20000 train_loss:2.3154 train_time:300485ms step_avg:62.60ms +step:4800 shared0_alpha:mean=0.445,std=0.054 shared1_alpha:mean=0.544,std=0.060 shared2_alpha:mean=0.604,std=0.054 shared3_alpha:mean=0.618,std=0.062 eff_mlp_scale:[v0:168.9203 v1:97.0230 v2:102.8233 v3:106.3246 v4:95.8284 v5:119.2462 v6:103.8168 v7:101.3093 v8:92.8181 v9:104.0694 v10:80.9672 v11:84.7588 v12:89.8078 v13:230.1413] eff_attn_scale:[v0:0.5949 v1:0.9596 v2:0.9924 v3:1.0478 v4:1.1378 v5:0.8847 v6:0.8701 v7:0.9289 v8:1.0855 v9:0.8332 v10:0.6888 v11:0.6726 v12:0.7474 v13:1.6303] eff_attn_bias:[v0:0.7016 v1:0.5911 v2:0.4972 v3:0.6104 v4:0.6933 v5:0.5524 v6:0.6491 v7:0.6132 v8:0.6381 v9:0.4917 v10:0.4143 v11:0.3591 v12:0.4088 v13:0.7513] eff_mlp_bias:[v0:1.1932 v1:0.5690 v2:0.6021 v3:0.5773 v4:0.5718 v5:0.6298 v6:0.5220 v7:0.5027 v8:0.5469 v9:0.5718 v10:0.4558 v11:0.4778 v12:0.4447 v13:0.5773] depth_emb_rms:[v0:1.3260 v1:1.3973 v2:0.6347 v3:0.6698 v4:0.6298 v5:0.6652 v6:0.7194 v7:0.5897 
v8:0.5593 v9:0.6416 v10:0.6489 v11:0.5230 v12:0.5454 v13:0.5194] +step:4800/20000 val_loss:2.2122 val_bpb:1.3102 train_time:300513ms step_avg:62.61ms +step:5000/20000 train_loss:2.0839 train_time:312991ms step_avg:62.60ms +step:5000 shared0_alpha:mean=0.444,std=0.054 shared1_alpha:mean=0.547,std=0.060 shared2_alpha:mean=0.605,std=0.054 shared3_alpha:mean=0.619,std=0.062 eff_mlp_scale:[v0:175.2470 v1:100.2023 v2:105.7626 v3:109.6332 v4:97.8573 v5:122.1044 v6:106.2638 v7:103.5425 v8:94.8151 v9:105.6778 v10:82.2041 v11:86.7929 v12:91.7729 v13:232.2790] eff_attn_scale:[v0:0.5950 v1:0.9806 v2:0.9808 v3:1.0553 v4:1.1349 v5:0.8763 v6:0.8678 v7:0.9189 v8:1.0827 v9:0.8479 v10:0.6825 v11:0.6550 v12:0.7408 v13:1.6238] eff_attn_bias:[v0:0.7347 v1:0.6049 v2:0.5110 v3:0.6242 v4:0.7071 v5:0.5580 v6:0.6602 v7:0.6242 v8:0.6519 v9:0.4999 v10:0.4198 v11:0.3646 v12:0.4171 v13:0.7734] eff_mlp_bias:[v0:1.2209 v1:0.5883 v2:0.6215 v3:0.5939 v4:0.5856 v5:0.6436 v6:0.5331 v7:0.5138 v8:0.5607 v9:0.5856 v10:0.4668 v11:0.4861 v12:0.4530 v13:0.5911] depth_emb_rms:[v0:1.3832 v1:1.4395 v2:0.6578 v3:0.6919 v4:0.6492 v5:0.6844 v6:0.7412 v7:0.6046 v8:0.5751 v9:0.6599 v10:0.6648 v11:0.5369 v12:0.5585 v13:0.5315] +step:5000/20000 val_loss:2.2073 val_bpb:1.3073 train_time:313018ms step_avg:62.60ms +step:5200/20000 train_loss:2.2277 train_time:325497ms step_avg:62.60ms +step:5200 shared0_alpha:mean=0.441,std=0.055 shared1_alpha:mean=0.549,std=0.060 shared2_alpha:mean=0.607,std=0.055 shared3_alpha:mean=0.619,std=0.063 eff_mlp_scale:[v0:180.0119 v1:102.9014 v2:108.4747 v3:112.5330 v4:100.7819 v5:124.4775 v6:108.4747 v7:106.3668 v8:97.1825 v9:107.8805 v10:84.1439 v11:88.8959 v12:94.6116 v13:238.4944] eff_attn_scale:[v0:0.5901 v1:0.9639 v2:0.9968 v3:1.0615 v4:1.1320 v5:0.8750 v6:0.8570 v7:0.9197 v8:1.0751 v9:0.8282 v10:0.6765 v11:0.6589 v12:0.7389 v13:1.5955] eff_attn_bias:[v0:0.7679 v1:0.6187 v2:0.5220 v3:0.6325 v4:0.7237 v5:0.5718 v6:0.6740 v7:0.6353 v8:0.6629 v9:0.5082 v10:0.4281 v11:0.3729 v12:0.4254 
v13:0.7900] eff_mlp_bias:[v0:1.2540 v1:0.6077 v2:0.6381 v3:0.6077 v4:0.5966 v5:0.6602 v6:0.5441 v7:0.5248 v8:0.5745 v9:0.5966 v10:0.4751 v11:0.4999 v12:0.4613 v13:0.6049] depth_emb_rms:[v0:1.4399 v1:1.4885 v2:0.6779 v3:0.7114 v4:0.6665 v5:0.7046 v6:0.7626 v7:0.6196 v8:0.5887 v9:0.6780 v10:0.6824 v11:0.5509 v12:0.5734 v13:0.5435] +step:5200/20000 val_loss:2.2091 val_bpb:1.3084 train_time:325526ms step_avg:62.60ms +step:5400/20000 train_loss:2.2396 train_time:337993ms step_avg:62.59ms +step:5400 shared0_alpha:mean=0.440,std=0.054 shared1_alpha:mean=0.551,std=0.061 shared2_alpha:mean=0.608,std=0.055 shared3_alpha:mean=0.619,std=0.063 eff_mlp_scale:[v0:185.7888 v1:105.5934 v2:111.7939 v3:115.2201 v4:103.4276 v5:127.3825 v6:110.7682 v7:108.4730 v8:99.2697 v9:110.0630 v10:85.6403 v11:90.3077 v12:96.6710 v13:241.0050] eff_attn_scale:[v0:0.5952 v1:0.9744 v2:0.9926 v3:1.0586 v4:1.1375 v5:0.8751 v6:0.8657 v7:0.9212 v8:1.0709 v9:0.8325 v10:0.6662 v11:0.6553 v12:0.7330 v13:1.5946] eff_attn_bias:[v0:0.7955 v1:0.6353 v2:0.5331 v3:0.6436 v4:0.7403 v5:0.5800 v6:0.6822 v7:0.6463 v8:0.6795 v9:0.5138 v10:0.4337 v11:0.3784 v12:0.4364 v13:0.8065] eff_mlp_bias:[v0:1.2816 v1:0.6215 v2:0.6519 v3:0.6242 v4:0.6132 v5:0.6740 v6:0.5552 v7:0.5386 v8:0.5856 v9:0.6077 v10:0.4861 v11:0.5082 v12:0.4696 v13:0.6187] depth_emb_rms:[v0:1.4954 v1:1.5339 v2:0.6992 v3:0.7306 v4:0.6869 v5:0.7230 v6:0.7804 v7:0.6356 v8:0.6043 v9:0.6947 v10:0.6995 v11:0.5633 v12:0.5864 v13:0.5542] +step:5400/20000 val_loss:2.2036 val_bpb:1.3051 train_time:338021ms step_avg:62.60ms +step:5600/20000 train_loss:2.2397 train_time:350499ms step_avg:62.59ms +step:5600 shared0_alpha:mean=0.437,std=0.054 shared1_alpha:mean=0.553,std=0.061 shared2_alpha:mean=0.609,std=0.055 shared3_alpha:mean=0.620,std=0.064 eff_mlp_scale:[v0:191.4395 v1:108.0368 v2:113.7932 v3:118.0972 v4:105.5219 v5:129.4191 v6:112.7587 v7:111.2738 v8:100.7971 v9:111.4130 v10:86.8966 v11:92.3783 v12:98.6971 v13:245.1215] eff_attn_scale:[v0:0.5859 v1:0.9708 
v2:0.9895 v3:1.0668 v4:1.1324 v5:0.8671 v6:0.8585 v7:0.9118 v8:1.0708 v9:0.8341 v10:0.6778 v11:0.6519 v12:0.7344 v13:1.5592] eff_attn_bias:[v0:0.8286 v1:0.6463 v2:0.5441 v3:0.6602 v4:0.7568 v5:0.5911 v6:0.6905 v7:0.6602 v8:0.6961 v9:0.5248 v10:0.4419 v11:0.3867 v12:0.4447 v13:0.8286] eff_mlp_bias:[v0:1.3037 v1:0.6408 v2:0.6684 v3:0.6463 v4:0.6270 v5:0.6905 v6:0.5690 v7:0.5469 v8:0.5966 v9:0.6187 v10:0.4944 v11:0.5165 v12:0.4751 v13:0.6325] depth_emb_rms:[v0:1.5485 v1:1.5785 v2:0.7215 v3:0.7500 v4:0.7092 v5:0.7435 v6:0.8008 v7:0.6529 v8:0.6193 v9:0.7135 v10:0.7148 v11:0.5765 v12:0.5991 v13:0.5655] +step:5600/20000 val_loss:2.2040 val_bpb:1.3053 train_time:350527ms step_avg:62.59ms +step:5800/20000 train_loss:2.2021 train_time:362996ms step_avg:62.59ms +step:5800 shared0_alpha:mean=0.436,std=0.055 shared1_alpha:mean=0.556,std=0.062 shared2_alpha:mean=0.611,std=0.056 shared3_alpha:mean=0.620,std=0.064 eff_mlp_scale:[v0:197.5922 v1:111.3737 v2:117.2165 v3:121.4985 v4:107.7518 v5:131.2619 v6:115.1233 v7:113.0095 v8:103.5054 v9:113.0784 v10:88.9589 v11:93.9093 v12:101.3822 v13:249.5258] eff_attn_scale:[v0:0.5814 v1:0.9795 v2:0.9996 v3:1.0681 v4:1.1339 v5:0.8681 v6:0.8561 v7:0.9142 v8:1.0727 v9:0.8170 v10:0.6589 v11:0.6517 v12:0.7245 v13:1.5489] eff_attn_bias:[v0:0.8618 v1:0.6657 v2:0.5524 v3:0.6740 v4:0.7679 v5:0.5966 v6:0.7016 v7:0.6740 v8:0.7126 v9:0.5331 v10:0.4502 v11:0.3950 v12:0.4558 v13:0.8452] eff_mlp_bias:[v0:1.3369 v1:0.6574 v2:0.6850 v3:0.6574 v4:0.6381 v5:0.7071 v6:0.5856 v7:0.5607 v8:0.6132 v9:0.6298 v10:0.5055 v11:0.5248 v12:0.4834 v13:0.6436] depth_emb_rms:[v0:1.6081 v1:1.6288 v2:0.7439 v3:0.7715 v4:0.7248 v5:0.7585 v6:0.8214 v7:0.6724 v8:0.6371 v9:0.7339 v10:0.7327 v11:0.5915 v12:0.6143 v13:0.5763] +step:5800/20000 val_loss:2.2019 val_bpb:1.3041 train_time:363024ms step_avg:62.59ms +step:6000/20000 train_loss:2.2710 train_time:375503ms step_avg:62.58ms +step:6000 shared0_alpha:mean=0.433,std=0.055 shared1_alpha:mean=0.558,std=0.063 
shared2_alpha:mean=0.612,std=0.056 shared3_alpha:mean=0.619,std=0.064 eff_mlp_scale:[v0:203.7929 v1:113.4618 v2:120.1349 v3:123.7735 v4:110.3890 v5:134.0913 v6:118.0179 v7:115.2005 v8:105.0303 v9:115.1810 v10:90.4981 v11:95.9111 v12:102.8869 v13:253.4706] eff_attn_scale:[v0:0.5857 v1:0.9739 v2:0.9982 v3:1.0653 v4:1.1325 v5:0.8544 v6:0.8505 v7:0.9118 v8:1.0573 v9:0.8039 v10:0.6580 v11:0.6455 v12:0.7284 v13:1.5626] eff_attn_bias:[v0:0.8949 v1:0.6795 v2:0.5607 v3:0.6878 v4:0.7844 v5:0.6021 v6:0.7126 v7:0.6822 v8:0.7237 v9:0.5386 v10:0.4558 v11:0.4005 v12:0.4668 v13:0.8618] eff_mlp_bias:[v0:1.3590 v1:0.6740 v2:0.7043 v3:0.6740 v4:0.6519 v5:0.7237 v6:0.5966 v7:0.5745 v8:0.6270 v9:0.6408 v10:0.5138 v11:0.5359 v12:0.4917 v13:0.6519] depth_emb_rms:[v0:1.6597 v1:1.6741 v2:0.7658 v3:0.7919 v4:0.7433 v5:0.7805 v6:0.8420 v7:0.6882 v8:0.6515 v9:0.7507 v10:0.7483 v11:0.6033 v12:0.6268 v13:0.5866] +step:6000/20000 val_loss:2.1975 val_bpb:1.3015 train_time:375530ms step_avg:62.59ms +step:6200/20000 train_loss:2.1443 train_time:388020ms step_avg:62.58ms +step:6200 shared0_alpha:mean=0.431,std=0.055 shared1_alpha:mean=0.560,std=0.062 shared2_alpha:mean=0.613,std=0.056 shared3_alpha:mean=0.620,std=0.064 eff_mlp_scale:[v0:209.6806 v1:115.6477 v2:122.8893 v3:126.9099 v4:112.8256 v5:136.4643 v6:120.2178 v7:118.2323 v8:107.9438 v9:117.3824 v10:92.4341 v11:98.1654 v12:105.7740 v13:257.6370] eff_attn_scale:[v0:0.5755 v1:0.9676 v2:1.0033 v3:1.0681 v4:1.1378 v5:0.8611 v6:0.8472 v7:0.9006 v8:1.0670 v9:0.8055 v10:0.6510 v11:0.6472 v12:0.7365 v13:1.5257] eff_attn_bias:[v0:0.9281 v1:0.6988 v2:0.5690 v3:0.7043 v4:0.8010 v5:0.6132 v6:0.7237 v7:0.6905 v8:0.7347 v9:0.5441 v10:0.4613 v11:0.4088 v12:0.4751 v13:0.8784] eff_mlp_bias:[v0:1.3921 v1:0.6905 v2:0.7237 v3:0.6878 v4:0.6684 v5:0.7347 v6:0.6077 v7:0.5856 v8:0.6381 v9:0.6519 v10:0.5248 v11:0.5441 v12:0.4999 v13:0.6629] depth_emb_rms:[v0:1.7115 v1:1.7211 v2:0.7882 v3:0.8140 v4:0.7647 v5:0.8025 v6:0.8583 v7:0.7030 v8:0.6659 v9:0.7670 v10:0.7648 
v11:0.6158 v12:0.6402 v13:0.5967] +step:6200/20000 val_loss:2.1970 val_bpb:1.3012 train_time:388048ms step_avg:62.59ms +step:6400/20000 train_loss:2.2179 train_time:400513ms step_avg:62.58ms +step:6400 shared0_alpha:mean=0.429,std=0.055 shared1_alpha:mean=0.563,std=0.063 shared2_alpha:mean=0.615,std=0.057 shared3_alpha:mean=0.620,std=0.065 eff_mlp_scale:[v0:216.0285 v1:118.9685 v2:126.4943 v3:130.3977 v4:115.7399 v5:138.7966 v6:122.7103 v7:119.9878 v8:109.7060 v9:118.9685 v10:94.0599 v11:100.2637 v12:108.0605 v13:259.8743] eff_attn_scale:[v0:0.5912 v1:0.9818 v2:0.9974 v3:1.0673 v4:1.1582 v5:0.8696 v6:0.8447 v7:0.8961 v8:1.0688 v9:0.7994 v10:0.6605 v11:0.6440 v12:0.7298 v13:1.5294] eff_attn_bias:[v0:0.9502 v1:0.7126 v2:0.5800 v3:0.7182 v4:0.8176 v5:0.6242 v6:0.7292 v7:0.7043 v8:0.7513 v9:0.5524 v10:0.4723 v11:0.4171 v12:0.4861 v13:0.8949] eff_mlp_bias:[v0:1.4253 v1:0.7043 v2:0.7403 v3:0.7043 v4:0.6822 v5:0.7513 v6:0.6215 v7:0.5966 v8:0.6546 v9:0.6657 v10:0.5359 v11:0.5524 v12:0.5082 v13:0.6740] depth_emb_rms:[v0:1.7647 v1:1.7697 v2:0.8103 v3:0.8333 v4:0.7828 v5:0.8215 v6:0.8805 v7:0.7209 v8:0.6817 v9:0.7868 v10:0.7846 v11:0.6302 v12:0.6545 v13:0.6068] +step:6400/20000 val_loss:2.1941 val_bpb:1.2995 train_time:400541ms step_avg:62.58ms +step:6600/20000 train_loss:2.1835 train_time:413011ms step_avg:62.58ms +step:6600 shared0_alpha:mean=0.426,std=0.055 shared1_alpha:mean=0.565,std=0.064 shared2_alpha:mean=0.616,std=0.057 shared3_alpha:mean=0.620,std=0.064 eff_mlp_scale:[v0:220.1345 v1:121.1145 v2:128.7675 v3:133.4154 v4:118.1146 v5:141.1043 v6:124.9481 v7:122.3436 v8:112.0148 v9:121.1145 v10:95.4843 v11:101.8607 v12:110.3512 v13:264.0340] eff_attn_scale:[v0:0.5849 v1:0.9785 v2:1.0001 v3:1.0848 v4:1.1516 v5:0.8620 v6:0.8483 v7:0.9070 v8:1.0808 v9:0.8014 v10:0.6519 v11:0.6518 v12:0.7221 v13:1.5002] eff_attn_bias:[v0:0.9833 v1:0.7347 v2:0.5939 v3:0.7347 v4:0.8342 v5:0.6325 v6:0.7458 v7:0.7182 v8:0.7623 v9:0.5580 v10:0.4778 v11:0.4198 v12:0.4944 v13:0.9060] 
eff_mlp_bias:[v0:1.4474 v1:0.7292 v2:0.7568 v3:0.7182 v4:0.6905 v5:0.7679 v6:0.6325 v7:0.6104 v8:0.6657 v9:0.6740 v10:0.5441 v11:0.5607 v12:0.5138 v13:0.6850] depth_emb_rms:[v0:1.8148 v1:1.8170 v2:0.8340 v3:0.8525 v4:0.8022 v5:0.8381 v6:0.8985 v7:0.7366 v8:0.6937 v9:0.8028 v10:0.8006 v11:0.6423 v12:0.6673 v13:0.6168] +step:6600/20000 val_loss:2.1908 val_bpb:1.2975 train_time:413039ms step_avg:62.58ms +step:6800/20000 train_loss:2.2509 train_time:425515ms step_avg:62.58ms +step:6800 shared0_alpha:mean=0.424,std=0.055 shared1_alpha:mean=0.567,std=0.064 shared2_alpha:mean=0.617,std=0.058 shared3_alpha:mean=0.619,std=0.065 eff_mlp_scale:[v0:226.6244 v1:122.9771 v2:131.5042 v3:135.8294 v4:120.4627 v5:143.0792 v6:127.1023 v7:124.6500 v8:114.2994 v9:122.3859 v10:96.8399 v11:103.9681 v12:112.6186 v13:268.2435] eff_attn_scale:[v0:0.5770 v1:0.9725 v2:1.0062 v3:1.0841 v4:1.1490 v5:0.8526 v6:0.8497 v7:0.9034 v8:1.0639 v9:0.7743 v10:0.6440 v11:0.6414 v12:0.7235 v13:1.4641] eff_attn_bias:[v0:1.0165 v1:0.7458 v2:0.6049 v3:0.7458 v4:0.8507 v5:0.6463 v6:0.7568 v7:0.7292 v8:0.7734 v9:0.5662 v10:0.4834 v11:0.4281 v12:0.5082 v13:0.9226] eff_mlp_bias:[v0:1.4695 v1:0.7403 v2:0.7734 v3:0.7347 v4:0.7016 v5:0.7844 v6:0.6463 v7:0.6187 v8:0.6795 v9:0.6878 v10:0.5580 v11:0.5690 v12:0.5220 v13:0.6961] depth_emb_rms:[v0:1.8648 v1:1.8608 v2:0.8544 v3:0.8729 v4:0.8200 v5:0.8557 v6:0.9186 v7:0.7524 v8:0.7074 v9:0.8207 v10:0.8180 v11:0.6565 v12:0.6806 v13:0.6270] +step:6800/20000 val_loss:2.1890 val_bpb:1.2965 train_time:425543ms step_avg:62.58ms +step:7000/20000 train_loss:2.2846 train_time:438031ms step_avg:62.58ms +step:7000 shared0_alpha:mean=0.423,std=0.055 shared1_alpha:mean=0.569,std=0.064 shared2_alpha:mean=0.618,std=0.057 shared3_alpha:mean=0.620,std=0.065 eff_mlp_scale:[v0:232.6423 v1:125.5857 v2:134.9927 v3:139.0073 v4:123.4228 v5:145.2270 v6:129.4375 v7:127.1408 v8:116.6289 v9:123.8001 v10:98.3280 v11:105.6682 v12:114.9304 v13:272.4831] eff_attn_scale:[v0:0.5800 v1:0.9999 v2:1.0102 
v3:1.1039 v4:1.1746 v5:0.8550 v6:0.8456 v7:0.9154 v8:1.0752 v9:0.7990 v10:0.6453 v11:0.6486 v12:0.7247 v13:1.4643] eff_attn_bias:[v0:1.0496 v1:0.7623 v2:0.6160 v3:0.7568 v4:0.8618 v5:0.6546 v6:0.7623 v7:0.7403 v8:0.7900 v9:0.5718 v10:0.4917 v11:0.4364 v12:0.5165 v13:0.9336] eff_mlp_bias:[v0:1.5026 v1:0.7568 v2:0.7900 v3:0.7513 v4:0.7126 v5:0.7955 v6:0.6574 v7:0.6298 v8:0.6878 v9:0.7016 v10:0.5662 v11:0.5773 v12:0.5276 v13:0.7043] depth_emb_rms:[v0:1.9140 v1:1.9098 v2:0.8759 v3:0.8919 v4:0.8389 v5:0.8739 v6:0.9394 v7:0.7699 v8:0.7219 v9:0.8362 v10:0.8351 v11:0.6700 v12:0.6928 v13:0.6362] +step:7000/20000 val_loss:2.1872 val_bpb:1.2954 train_time:438058ms step_avg:62.58ms +step:7200/20000 train_loss:2.2597 train_time:450528ms step_avg:62.57ms +step:7200 shared0_alpha:mean=0.420,std=0.054 shared1_alpha:mean=0.571,std=0.065 shared2_alpha:mean=0.619,std=0.058 shared3_alpha:mean=0.620,std=0.066 eff_mlp_scale:[v0:239.2463 v1:128.4012 v2:137.9497 v3:142.0183 v4:126.4626 v5:147.6014 v6:131.7813 v7:129.4705 v8:119.0236 v9:125.4012 v10:99.8173 v11:107.7970 v12:117.3069 v13:274.8089] eff_attn_scale:[v0:0.5763 v1:0.9879 v2:1.0198 v3:1.1097 v4:1.1805 v5:0.8501 v6:0.8581 v7:0.9179 v8:1.0853 v9:0.7766 v10:0.6559 v11:0.6530 v12:0.7378 v13:1.4582] eff_attn_bias:[v0:1.0772 v1:0.7789 v2:0.6242 v3:0.7734 v4:0.8784 v5:0.6629 v6:0.7734 v7:0.7458 v8:0.7955 v9:0.5800 v10:0.4972 v11:0.4419 v12:0.5276 v13:0.9447] eff_mlp_bias:[v0:1.5247 v1:0.7734 v2:0.8065 v3:0.7679 v4:0.7237 v5:0.8121 v6:0.6657 v7:0.6408 v8:0.7016 v9:0.7126 v10:0.5745 v11:0.5856 v12:0.5359 v13:0.7182] depth_emb_rms:[v0:1.9638 v1:1.9481 v2:0.8949 v3:0.9129 v4:0.8601 v5:0.8949 v6:0.9611 v7:0.7834 v8:0.7359 v9:0.8542 v10:0.8508 v11:0.6825 v12:0.7060 v13:0.6455] +step:7200/20000 val_loss:2.1884 val_bpb:1.2961 train_time:450556ms step_avg:62.58ms +step:7400/20000 train_loss:2.1758 train_time:463025ms step_avg:62.57ms +step:7400 shared0_alpha:mean=0.418,std=0.054 shared1_alpha:mean=0.574,std=0.065 
shared2_alpha:mean=0.620,std=0.058 shared3_alpha:mean=0.620,std=0.066 eff_mlp_scale:[v0:245.6747 v1:130.5944 v2:140.7573 v3:145.2399 v4:128.8931 v5:149.9417 v6:133.9739 v7:131.9839 v8:121.3792 v9:126.9668 v10:101.7523 v11:109.5063 v12:119.6452 v13:278.7200] eff_attn_scale:[v0:0.5664 v1:0.9734 v2:1.0129 v3:1.1182 v4:1.1687 v5:0.8409 v6:0.8485 v7:0.9174 v8:1.0659 v9:0.7769 v10:0.6442 v11:0.6481 v12:0.7153 v13:1.4477] eff_attn_bias:[v0:1.1159 v1:0.7955 v2:0.6381 v3:0.7900 v4:0.8949 v5:0.6712 v6:0.7844 v7:0.7568 v8:0.8121 v9:0.5856 v10:0.5027 v11:0.4475 v12:0.5359 v13:0.9612] eff_mlp_bias:[v0:1.5578 v1:0.7955 v2:0.8176 v3:0.7789 v4:0.7347 v5:0.8231 v6:0.6795 v7:0.6519 v8:0.7126 v9:0.7237 v10:0.5856 v11:0.5939 v12:0.5441 v13:0.7237] depth_emb_rms:[v0:2.0112 v1:1.9994 v2:0.9204 v3:0.9329 v4:0.8782 v5:0.9100 v6:0.9798 v7:0.8012 v8:0.7511 v9:0.8726 v10:0.8668 v11:0.6964 v12:0.7192 v13:0.6558] +step:7400/20000 val_loss:2.1843 val_bpb:1.2936 train_time:463052ms step_avg:62.57ms +step:7600/20000 train_loss:2.0532 train_time:475517ms step_avg:62.57ms +step:7600 shared0_alpha:mean=0.416,std=0.055 shared1_alpha:mean=0.576,std=0.066 shared2_alpha:mean=0.621,std=0.058 shared3_alpha:mean=0.619,std=0.066 eff_mlp_scale:[v0:251.9731 v1:133.4637 v2:143.8218 v3:148.2004 v4:131.2767 v5:152.3558 v6:136.4024 v7:134.2522 v8:123.1084 v9:129.1977 v10:103.3006 v11:111.5862 v12:121.3580 v13:282.8690] eff_attn_scale:[v0:0.5705 v1:0.9969 v2:1.0278 v3:1.1022 v4:1.1912 v5:0.8492 v6:0.8401 v7:0.9079 v8:1.0816 v9:0.7846 v10:0.6480 v11:0.6414 v12:0.7290 v13:1.4345] eff_attn_bias:[v0:1.1435 v1:0.8121 v2:0.6519 v3:0.8010 v4:0.9060 v5:0.6850 v6:0.7900 v7:0.7623 v8:0.8231 v9:0.5911 v10:0.5082 v11:0.4558 v12:0.5441 v13:0.9778] eff_mlp_bias:[v0:1.5799 v1:0.8121 v2:0.8342 v3:0.7955 v4:0.7513 v5:0.8342 v6:0.6878 v7:0.6602 v8:0.7182 v9:0.7347 v10:0.5939 v11:0.6021 v12:0.5552 v13:0.7347] depth_emb_rms:[v0:2.0587 v1:2.0476 v2:0.9410 v3:0.9539 v4:0.8977 v5:0.9341 v6:1.0021 v7:0.8167 v8:0.7654 v9:0.8871 v10:0.8834 
v11:0.7085 v12:0.7316 v13:0.6644] +step:7600/20000 val_loss:2.1833 val_bpb:1.2931 train_time:475544ms step_avg:62.57ms +step:7800/20000 train_loss:2.2018 train_time:488012ms step_avg:62.57ms +step:7800 shared0_alpha:mean=0.414,std=0.055 shared1_alpha:mean=0.578,std=0.066 shared2_alpha:mean=0.622,std=0.059 shared3_alpha:mean=0.619,std=0.066 eff_mlp_scale:[v0:258.5976 v1:135.4757 v2:146.7667 v3:151.6017 v4:134.4523 v5:154.4791 v6:139.2844 v7:136.9116 v8:126.1965 v9:130.5716 v10:104.7511 v11:113.4075 v12:123.8376 v13:285.0793] eff_attn_scale:[v0:0.5705 v1:0.9883 v2:1.0291 v3:1.1107 v4:1.2098 v5:0.8497 v6:0.8412 v7:0.9158 v8:1.0907 v9:0.7758 v10:0.6443 v11:0.6483 v12:0.7287 v13:1.4480] eff_attn_bias:[v0:1.1767 v1:0.8286 v2:0.6602 v3:0.8176 v4:0.9226 v5:0.6933 v6:0.8010 v7:0.7789 v8:0.8342 v9:0.6021 v10:0.5165 v11:0.4640 v12:0.5552 v13:0.9888] eff_mlp_bias:[v0:1.6020 v1:0.8286 v2:0.8507 v3:0.8065 v4:0.7623 v5:0.8507 v6:0.6988 v7:0.6712 v8:0.7347 v9:0.7458 v10:0.6049 v11:0.6104 v12:0.5607 v13:0.7458] depth_emb_rms:[v0:2.1065 v1:2.0928 v2:0.9658 v3:0.9752 v4:0.9171 v5:0.9554 v6:1.0232 v7:0.8345 v8:0.7802 v9:0.9069 v10:0.9005 v11:0.7207 v12:0.7464 v13:0.6741] +step:7800/20000 val_loss:2.1814 val_bpb:1.2920 train_time:488040ms step_avg:62.57ms +step:8000/20000 train_loss:2.1666 train_time:500515ms step_avg:62.56ms +step:8000 shared0_alpha:mean=0.412,std=0.054 shared1_alpha:mean=0.581,std=0.066 shared2_alpha:mean=0.623,std=0.059 shared3_alpha:mean=0.620,std=0.067 eff_mlp_scale:[v0:264.2507 v1:139.0632 v2:150.2956 v3:154.5905 v4:137.3695 v5:156.9869 v6:142.1400 v7:140.3206 v8:127.8548 v9:132.2646 v10:107.1876 v11:115.9429 v12:126.0707 v13:288.9572] eff_attn_scale:[v0:0.5661 v1:0.9901 v2:1.0399 v3:1.1315 v4:1.2092 v5:0.8473 v6:0.8554 v7:0.9299 v8:1.0807 v9:0.7782 v10:0.6393 v11:0.6551 v12:0.7331 v13:1.4277] eff_attn_bias:[v0:1.2098 v1:0.8452 v2:0.6684 v3:0.8286 v4:0.9336 v5:0.7071 v6:0.8121 v7:0.7844 v8:0.8452 v9:0.6049 v10:0.5220 v11:0.4696 v12:0.5662 v13:1.0054] 
eff_mlp_bias:[v0:1.6241 v1:0.8397 v2:0.8618 v3:0.8231 v4:0.7734 v5:0.8618 v6:0.7126 v7:0.6822 v8:0.7458 v9:0.7568 v10:0.6104 v11:0.6160 v12:0.5662 v13:0.7568] depth_emb_rms:[v0:2.1502 v1:2.1357 v2:0.9814 v3:0.9941 v4:0.9363 v5:0.9719 v6:1.0441 v7:0.8517 v8:0.7953 v9:0.9248 v10:0.9162 v11:0.7306 v12:0.7568 v13:0.6839] +step:8000/20000 val_loss:2.1783 val_bpb:1.2901 train_time:500542ms step_avg:62.57ms +step:8200/20000 train_loss:2.2367 train_time:513008ms step_avg:62.56ms +step:8200 shared0_alpha:mean=0.410,std=0.054 shared1_alpha:mean=0.583,std=0.067 shared2_alpha:mean=0.624,std=0.059 shared3_alpha:mean=0.619,std=0.067 eff_mlp_scale:[v0:271.2469 v1:141.2655 v2:152.5379 v3:157.0669 v4:140.1793 v5:159.3126 v6:143.7376 v7:142.0796 v8:130.5533 v9:133.7977 v10:108.5366 v11:117.5005 v12:128.7484 v13:293.3033] eff_attn_scale:[v0:0.5698 v1:1.0034 v2:1.0381 v3:1.1392 v4:1.2115 v5:0.8515 v6:0.8538 v7:0.9287 v8:1.1066 v9:0.7825 v10:0.6471 v11:0.6496 v12:0.7298 v13:1.4365] eff_attn_bias:[v0:1.2430 v1:0.8563 v2:0.6795 v3:0.8452 v4:0.9447 v5:0.7182 v6:0.8176 v7:0.8010 v8:0.8563 v9:0.6132 v10:0.5303 v11:0.4778 v12:0.5745 v13:1.0165] eff_mlp_bias:[v0:1.6462 v1:0.8563 v2:0.8784 v3:0.8397 v4:0.7844 v5:0.8728 v6:0.7237 v7:0.6905 v8:0.7568 v9:0.7679 v10:0.6187 v11:0.6242 v12:0.5745 v13:0.7623] depth_emb_rms:[v0:2.1939 v1:2.1835 v2:1.0049 v3:1.0144 v4:0.9547 v5:0.9885 v6:1.0638 v7:0.8653 v8:0.8067 v9:0.9411 v10:0.9323 v11:0.7432 v12:0.7701 v13:0.6923] +step:8200/20000 val_loss:2.1771 val_bpb:1.2894 train_time:513035ms step_avg:62.57ms +step:8400/20000 train_loss:2.1849 train_time:525569ms step_avg:62.57ms +step:8400 shared0_alpha:mean=0.407,std=0.054 shared1_alpha:mean=0.586,std=0.068 shared2_alpha:mean=0.625,std=0.059 shared3_alpha:mean=0.619,std=0.067 eff_mlp_scale:[v0:278.4039 v1:143.5037 v2:156.2263 v3:160.9950 v4:143.2888 v5:161.6766 v6:146.1662 v7:144.6534 v8:132.9672 v9:135.9838 v10:110.0685 v11:119.8384 v12:131.1457 v13:295.3743] eff_attn_scale:[v0:0.5653 v1:1.0246 v2:1.0474 
v3:1.1578 v4:1.2199 v5:0.8630 v6:0.8541 v7:0.9447 v8:1.1056 v9:0.7661 v10:0.6473 v11:0.6622 v12:0.7291 v13:1.4306] eff_attn_bias:[v0:1.2761 v1:0.8728 v2:0.6933 v3:0.8563 v4:0.9612 v5:0.7292 v6:0.8231 v7:0.8065 v8:0.8728 v9:0.6187 v10:0.5359 v11:0.4834 v12:0.5883 v13:1.0275] eff_mlp_bias:[v0:1.6683 v1:0.8728 v2:0.8949 v3:0.8507 v4:0.7955 v5:0.8894 v6:0.7347 v7:0.7016 v8:0.7679 v9:0.7789 v10:0.6270 v11:0.6353 v12:0.5800 v13:0.7734] depth_emb_rms:[v0:2.2366 v1:2.2287 v2:1.0298 v3:1.0380 v4:0.9752 v5:1.0105 v6:1.0839 v7:0.8810 v8:0.8212 v9:0.9569 v10:0.9492 v11:0.7564 v12:0.7840 v13:0.7006] +step:8400/20000 val_loss:2.1755 val_bpb:1.2885 train_time:525596ms step_avg:62.57ms +step:8600/20000 train_loss:2.1818 train_time:538071ms step_avg:62.57ms +step:8600 shared0_alpha:mean=0.406,std=0.054 shared1_alpha:mean=0.588,std=0.068 shared2_alpha:mean=0.626,std=0.059 shared3_alpha:mean=0.618,std=0.067 eff_mlp_scale:[v0:282.9090 v1:146.2481 v2:158.7466 v3:165.1429 v4:145.7237 v5:162.6379 v6:148.6012 v7:147.4053 v8:134.7026 v9:137.4227 v10:111.6001 v11:121.7164 v12:132.8657 v13:299.3512] eff_attn_scale:[v0:0.5703 v1:1.0259 v2:1.0501 v3:1.1511 v4:1.2382 v5:0.8549 v6:0.8645 v7:0.9347 v8:1.1124 v9:0.7809 v10:0.6472 v11:0.6538 v12:0.7448 v13:1.4479] eff_attn_bias:[v0:1.2982 v1:0.8894 v2:0.7043 v3:0.8728 v4:0.9723 v5:0.7347 v6:0.8342 v7:0.8176 v8:0.8784 v9:0.6270 v10:0.5414 v11:0.4917 v12:0.5966 v13:1.0386] eff_mlp_bias:[v0:1.6904 v1:0.8894 v2:0.9060 v3:0.8618 v4:0.8065 v5:0.9005 v6:0.7458 v7:0.7126 v8:0.7789 v9:0.7900 v10:0.6353 v11:0.6408 v12:0.5883 v13:0.7844] depth_emb_rms:[v0:2.2754 v1:2.2711 v2:1.0474 v3:1.0558 v4:0.9904 v5:1.0256 v6:1.1031 v7:0.8965 v8:0.8350 v9:0.9738 v10:0.9660 v11:0.7701 v12:0.7961 v13:0.7110] +step:8600/20000 val_loss:2.1656 val_bpb:1.2826 train_time:538099ms step_avg:62.57ms +step:8800/20000 train_loss:2.1448 train_time:550568ms step_avg:62.56ms +step:8800 shared0_alpha:mean=0.403,std=0.054 shared1_alpha:mean=0.589,std=0.068 
shared2_alpha:mean=0.626,std=0.059 shared3_alpha:mean=0.618,std=0.067 eff_mlp_scale:[v0:287.0613 v1:147.6006 v2:161.2275 v3:166.2868 v4:147.6644 v5:164.7045 v6:150.3987 v7:149.0423 v8:137.1610 v9:138.0984 v10:113.0999 v11:123.1754 v12:135.3075 v13:301.7915] eff_attn_scale:[v0:0.5667 v1:1.0330 v2:1.0487 v3:1.1666 v4:1.2449 v5:0.8748 v6:0.8552 v7:0.9351 v8:1.1195 v9:0.7818 v10:0.6436 v11:0.6574 v12:0.7479 v13:1.4613] eff_attn_bias:[v0:1.3203 v1:0.9005 v2:0.7126 v3:0.8784 v4:0.9833 v5:0.7458 v6:0.8397 v7:0.8231 v8:0.8894 v9:0.6325 v10:0.5469 v11:0.4999 v12:0.6021 v13:1.0441] eff_mlp_bias:[v0:1.7015 v1:0.8949 v2:0.9170 v3:0.8728 v4:0.8121 v5:0.9060 v6:0.7513 v7:0.7182 v8:0.7844 v9:0.7955 v10:0.6408 v11:0.6463 v12:0.5911 v13:0.7900] depth_emb_rms:[v0:2.3014 v1:2.3008 v2:1.0632 v3:1.0694 v4:1.0057 v5:1.0393 v6:1.1214 v7:0.9083 v8:0.8475 v9:0.9886 v10:0.9793 v11:0.7798 v12:0.8064 v13:0.7176] +step:8800/20000 val_loss:2.1567 val_bpb:1.2773 train_time:550595ms step_avg:62.57ms +step:9000/20000 train_loss:2.0542 train_time:563066ms step_avg:62.56ms +step:9000 shared0_alpha:mean=0.402,std=0.054 shared1_alpha:mean=0.590,std=0.069 shared2_alpha:mean=0.626,std=0.059 shared3_alpha:mean=0.617,std=0.067 eff_mlp_scale:[v0:288.7072 v1:148.9598 v2:162.0365 v3:167.4318 v4:149.7376 v5:165.5109 v6:151.7581 v7:150.0685 v8:138.5073 v9:138.7745 v10:113.6674 v11:124.0235 v12:136.6356 v13:306.3378] eff_attn_scale:[v0:0.5659 v1:1.0453 v2:1.0571 v3:1.1683 v4:1.2781 v5:0.8892 v6:0.8711 v7:0.9580 v8:1.1183 v9:0.7899 v10:0.6443 v11:0.6636 v12:0.7504 v13:1.4916] eff_attn_bias:[v0:1.3313 v1:0.9060 v2:0.7182 v3:0.8839 v4:0.9888 v5:0.7513 v6:0.8397 v7:0.8286 v8:0.8949 v9:0.6353 v10:0.5497 v11:0.4999 v12:0.6077 v13:1.0496] eff_mlp_bias:[v0:1.7125 v1:0.9060 v2:0.9226 v3:0.8784 v4:0.8176 v5:0.9115 v6:0.7568 v7:0.7182 v8:0.7900 v9:0.8065 v10:0.6463 v11:0.6491 v12:0.5939 v13:0.7955] depth_emb_rms:[v0:2.3175 v1:2.3198 v2:1.0739 v3:1.0821 v4:1.0179 v5:1.0518 v6:1.1315 v7:0.9176 v8:0.8554 v9:0.9992 v10:0.9903 
v11:0.7874 v12:0.8134 v13:0.7231] +step:9000/20000 val_loss:2.1480 val_bpb:1.2722 train_time:563094ms step_avg:62.57ms +step:9200/20000 train_loss:2.1100 train_time:575559ms step_avg:62.56ms +step:9200 shared0_alpha:mean=0.402,std=0.054 shared1_alpha:mean=0.590,std=0.068 shared2_alpha:mean=0.626,std=0.059 shared3_alpha:mean=0.617,std=0.067 eff_mlp_scale:[v0:289.1200 v1:149.5231 v2:163.9326 v3:168.3982 v4:151.4951 v5:166.1368 v6:152.3966 v7:150.9347 v8:139.5515 v9:139.2993 v10:114.1457 v11:125.3631 v12:137.6656 v13:308.3846] eff_attn_scale:[v0:0.5701 v1:1.0448 v2:1.0659 v3:1.1788 v4:1.2762 v5:0.8928 v6:0.8700 v7:0.9496 v8:1.1349 v9:0.7931 v10:0.6514 v11:0.6642 v12:0.7550 v13:1.4971] eff_attn_bias:[v0:1.3313 v1:0.9115 v2:0.7182 v3:0.8839 v4:0.9944 v5:0.7568 v6:0.8452 v7:0.8342 v8:0.8949 v9:0.6381 v10:0.5497 v11:0.4999 v12:0.6104 v13:1.0496] eff_mlp_bias:[v0:1.7125 v1:0.9060 v2:0.9281 v3:0.8839 v4:0.8231 v5:0.9170 v6:0.7623 v7:0.7237 v8:0.7955 v9:0.8065 v10:0.6491 v11:0.6519 v12:0.5966 v13:0.8010] depth_emb_rms:[v0:2.3253 v1:2.3284 v2:1.0779 v3:1.0906 v4:1.0235 v5:1.0597 v6:1.1400 v7:0.9257 v8:0.8611 v9:1.0093 v10:0.9971 v11:0.7926 v12:0.8184 v13:0.7257] +step:9200/20000 val_loss:2.1378 val_bpb:1.2661 train_time:575587ms step_avg:62.56ms +step:9400/20000 train_loss:2.1479 train_time:588055ms step_avg:62.56ms +step:9400 shared0_alpha:mean=0.401,std=0.054 shared1_alpha:mean=0.589,std=0.069 shared2_alpha:mean=0.625,std=0.059 shared3_alpha:mean=0.616,std=0.067 eff_mlp_scale:[v0:288.9544 v1:149.7192 v2:164.2695 v3:168.6555 v4:152.1174 v5:166.3547 v6:152.7098 v7:151.1653 v8:140.1248 v9:139.4820 v10:114.3802 v11:125.5546 v12:138.2312 v13:309.5997] eff_attn_scale:[v0:0.5781 v1:1.0362 v2:1.0576 v3:1.1804 v4:1.2722 v5:0.8889 v6:0.8662 v7:0.9594 v8:1.1401 v9:0.7938 v10:0.6473 v11:0.6631 v12:0.7584 v13:1.4976] eff_attn_bias:[v0:1.3313 v1:0.9115 v2:0.7237 v3:0.8839 v4:0.9944 v5:0.7568 v6:0.8452 v7:0.8286 v8:0.8949 v9:0.6381 v10:0.5524 v11:0.4999 v12:0.6104 v13:1.0496] 
eff_mlp_bias:[v0:1.7125 v1:0.9060 v2:0.9281 v3:0.8839 v4:0.8231 v5:0.9170 v6:0.7623 v7:0.7237 v8:0.7955 v9:0.8065 v10:0.6491 v11:0.6519 v12:0.5966 v13:0.8010] depth_emb_rms:[v0:2.3291 v1:2.3318 v2:1.0790 v3:1.0928 v4:1.0269 v5:1.0627 v6:1.1446 v7:0.9279 v8:0.8637 v9:1.0117 v10:0.9998 v11:0.7941 v12:0.8207 v13:0.7295] +step:9400/20000 val_loss:2.1284 val_bpb:1.2606 train_time:588083ms step_avg:62.56ms +step:9592/20000 val_loss:2.1217 val_bpb:1.2566 train_time:600079ms step_avg:62.56ms +stopping_early: wallclock_cap train_time:600079ms step:9592/20000 +peak memory allocated: 15080 MiB reserved: 15356 MiB +Serialized model: 45237245 bytes +Code size: 63793 bytes +Total submission size: 45301038 bytes +Serialized model int8+zlib: 10765827 bytes (payload:11696320 raw_torch:11728369 payload_ratio:3.87x) +Total submission size int8+zlib: 10829620 bytes +final_int8_zlib_roundtrip val_loss:2.1340 val_bpb:1.2639 eval_time:2121ms +final_int8_zlib_roundtrip_exact val_loss:2.13400958 val_bpb:1.26388068 diff --git a/records/track_non_record_16mb/2026-03-26_RecurrenceFix_3Loop_Birkhoff_OutputLN_TimestepScale/logs/s4_S.txt b/records/track_non_record_16mb/2026-03-26_RecurrenceFix_3Loop_Birkhoff_OutputLN_TimestepScale/logs/s4_S.txt new file mode 100644 index 0000000000..5708479602 --- /dev/null +++ b/records/track_non_record_16mb/2026-03-26_RecurrenceFix_3Loop_Birkhoff_OutputLN_TimestepScale/logs/s4_S.txt @@ -0,0 +1,1700 @@ +""" +The `train_gpt.py` and `train_gpt_mlx.py` scripts are intended as good launching-off points for new participants, not SOTA configs. We'll accept PRs that tune, improve, or simplify these scripts without significantly increasing complexity, but competitive submissions should stay in the `/records` folder. + +Hard stop: To keep readable for newcomers, let's make sure `train_gpt.py` and `train_gpt_mlx.py` never are longer than 1500 lines. 
+""" + +from __future__ import annotations + +import copy +import glob +import io +import math +import os +import random +import subprocess +import sys +import time +import uuid +import zlib +from pathlib import Path + +import numpy as np +import sentencepiece as spm +import torch +import torch.distributed as dist +import torch.nn.functional as F +from torch import Tensor, nn +from torch.nn.parallel import DistributedDataParallel as DDP + +# Flash Attention 3 via HuggingFace kernels (Hopper-optimized) +try: + from kernels import get_kernel + cap = torch.cuda.get_device_capability() + repo = "varunneal/flash-attention-3" if cap == (9, 0) else "kernels-community/flash-attn3" + fa3 = get_kernel(repo).flash_attn_interface + FA3_AVAILABLE = True +except Exception: + FA3_AVAILABLE = False + +# ----------------------------- +# HYPERPARAMETERS +# ----------------------------- +# Default Simple Baseline run: +# - 9 transformer blocks at width 512 +# - 8 attention heads with 4 KV heads (GQA) and 2x MLP expansion +# - vocab size 1024, sequence length 1024, tied embeddings +# - 524,288 train tokens per step for 20,000 iterations with a ~10 minute cap + +class Hyperparameters: + # Data paths are shard globs produced by the existing preprocessing pipeline. + data_path = os.environ.get("DATA_PATH", "./data/datasets/fineweb10B_sp1024") + train_files = os.path.join(data_path, "fineweb_train_*.bin") + val_files = os.path.join(data_path, "fineweb_val_*.bin") + tokenizer_path = os.environ.get("TOKENIZER_PATH", "./data/tokenizers/fineweb_1024_bpe.model") + run_id = os.environ.get("RUN_ID", str(uuid.uuid4())) + seed = int(os.environ.get("SEED", 1337)) + + # Validation cadence and batch size. Validation always uses the full fineweb_val split. + val_batch_size = int(os.environ.get("VAL_BATCH_SIZE", 524_288)) + val_loss_every = int(os.environ.get("VAL_LOSS_EVERY", 1000)) + train_log_every = int(os.environ.get("TRAIN_LOG_EVERY", 200)) + + # Training length. 
+ iterations = int(os.environ.get("ITERATIONS", 20000)) + warmdown_iters = int(os.environ.get("WARMDOWN_ITERS", 1200)) + warmup_steps = int(os.environ.get("WARMUP_STEPS", 20)) + train_batch_tokens = int(os.environ.get("TRAIN_BATCH_TOKENS", 524_288)) + train_seq_len = int(os.environ.get("TRAIN_SEQ_LEN", 1024)) + max_wallclock_seconds = float(os.environ.get("MAX_WALLCLOCK_SECONDS", 600.0)) + qk_gain_init = float(os.environ.get("QK_GAIN_INIT", 1.5)) + + # Model shape. + vocab_size = int(os.environ.get("VOCAB_SIZE", 1024)) + num_layers = int(os.environ.get("NUM_LAYERS", 9)) + num_kv_heads = int(os.environ.get("NUM_KV_HEADS", 4)) + model_dim = int(os.environ.get("MODEL_DIM", 512)) + num_heads = int(os.environ.get("NUM_HEADS", 8)) + mlp_mult = int(os.environ.get("MLP_MULT", 2)) + tie_embeddings = bool(int(os.environ.get("TIE_EMBEDDINGS", "1"))) + rope_base = float(os.environ.get("ROPE_BASE", 10000.0)) + logit_softcap = float(os.environ.get("LOGIT_SOFTCAP", 30.0)) + + # Recurrence + stability features. + num_prelude = int(os.environ.get("NUM_PRELUDE", 0)) + num_coda = int(os.environ.get("NUM_CODA", 0)) + num_shared = int(os.environ.get("NUM_SHARED", 0)) + num_loops = int(os.environ.get("NUM_LOOPS", 2)) + use_peri_norm = bool(int(os.environ.get("USE_PERI_NORM", "0"))) + use_birkhoff_mix = bool(int(os.environ.get("USE_BIRKHOFF_MIX", "0"))) + use_timestep_scale = bool(int(os.environ.get("USE_TIMESTEP_SCALE", "0"))) + leaky_relu_slope = float(os.environ.get("LEAKY_RELU_SLOPE", "0.5")) + timestep_gamma_max = float(os.environ.get("TIMESTEP_GAMMA_MAX", "0.0")) + use_timestep_bias = bool(int(os.environ.get("USE_TIMESTEP_BIAS", "0"))) + use_depth_embed = bool(int(os.environ.get("USE_DEPTH_EMBED", "0"))) + use_unique_norms = bool(int(os.environ.get("USE_UNIQUE_NORMS", "0"))) + disable_compile = bool(int(os.environ.get("DISABLE_COMPILE", "0"))) + + # Optimizer hyperparameters. 
+ embed_lr = float(os.environ.get("EMBED_LR", 0.6)) + head_lr = float(os.environ.get("HEAD_LR", 0.008)) + tied_embed_lr = float(os.environ.get("TIED_EMBED_LR", 0.05)) + tied_embed_init_std = float(os.environ.get("TIED_EMBED_INIT_STD", 0.005)) + matrix_lr = float(os.environ.get("MATRIX_LR", 0.04)) + scalar_lr = float(os.environ.get("SCALAR_LR", 0.04)) + muon_momentum = float(os.environ.get("MUON_MOMENTUM", 0.95)) + muon_backend_steps = int(os.environ.get("MUON_BACKEND_STEPS", 5)) + muon_momentum_warmup_start = float(os.environ.get("MUON_MOMENTUM_WARMUP_START", 0.85)) + muon_momentum_warmup_steps = int(os.environ.get("MUON_MOMENTUM_WARMUP_STEPS", 500)) + beta1 = float(os.environ.get("BETA1", 0.9)) + beta2 = float(os.environ.get("BETA2", 0.95)) + adam_eps = float(os.environ.get("ADAM_EPS", 1e-8)) + grad_clip_norm = float(os.environ.get("GRAD_CLIP_NORM", 0.0)) + +# ----------------------------- +# MUON OPTIMIZER +# ----------------------------- +# +# As borrowed from modded-nanogpt +# Background on Muon: https://kellerjordan.github.io/posts/muon/ + +def zeropower_via_newtonschulz5(G: Tensor, steps: int = 10, eps: float = 1e-7) -> Tensor: + # Orthogonalize a 2D update matrix with a fast Newton-Schulz iteration. + # Muon uses this to normalize matrix-shaped gradients before applying them. 
+ a, b, c = (3.4445, -4.7750, 2.0315) + X = G.bfloat16() + X /= X.norm() + eps + transposed = G.size(0) > G.size(1) + if transposed: + X = X.T + for _ in range(steps): + A = X @ X.T + B = b * A + c * A @ A + X = a * X + B @ X + return X.T if transposed else X + + +class Muon(torch.optim.Optimizer): + def __init__(self, params, lr: float, momentum: float, backend_steps: int, nesterov: bool = True): + super().__init__( + params, + dict(lr=lr, momentum=momentum, backend_steps=backend_steps, nesterov=nesterov), + ) + + @torch.no_grad() + def step(self, closure=None): + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + distributed = dist.is_available() and dist.is_initialized() + world_size = dist.get_world_size() if distributed else 1 + rank = dist.get_rank() if distributed else 0 + + for group in self.param_groups: + params = group["params"] + if not params: + continue + lr = group["lr"] + momentum = group["momentum"] + backend_steps = group["backend_steps"] + nesterov = group["nesterov"] + + total_params = sum(int(p.numel()) for p in params) + updates_flat = torch.zeros(total_params, device=params[0].device, dtype=torch.bfloat16) + + curr = 0 + for i, p in enumerate(params): + if i % world_size == rank and p.grad is not None: + g = p.grad + state = self.state[p] + if "momentum_buffer" not in state: + state["momentum_buffer"] = torch.zeros_like(g) + buf = state["momentum_buffer"] + buf.mul_(momentum).add_(g) + if nesterov: + g = g.add(buf, alpha=momentum) + g = zeropower_via_newtonschulz5(g, steps=backend_steps) + # Scale correction from Muon reference implementations. 
+ g *= max(1, g.size(0) / g.size(1)) ** 0.5 + updates_flat[curr : curr + p.numel()] = g.reshape(-1) + curr += p.numel() + + if distributed: + dist.all_reduce(updates_flat, op=dist.ReduceOp.SUM) + + curr = 0 + for p in params: + g = updates_flat[curr : curr + p.numel()].view_as(p).to(dtype=p.dtype) + p.add_(g, alpha=-lr) + curr += p.numel() + + return loss + + +# ----------------------------- +# TOKENIZER-AGNOSTIC EVALUATION SETUP +# ----------------------------- +# +# It's common for small models have a large fraction of their parameters be embeddings, since the 2 * d_model * d_vocab vectors can be gigantic. +# Instead of locking the tokenizer, we let you bring your own and calculate our validation metrics on the average compression of the validation set. +# We calculate BPB (bits-per-byte) instead of validation loss, so we need methods to count the number of bits per token in the tokenizer. +# Note: Submissions that edit the tokenizer will be examined more carefully, since screwing this up might unjustly improve your score. 

def build_sentencepiece_luts(
    sp: spm.SentencePieceProcessor, vocab_size: int, device: torch.device
) -> tuple[Tensor, Tensor, Tensor]:
    """Precompute per-token lookup tables for the bytes-per-token accounting in eval_val.

    Returns three device tensors indexed by token id (tables sized to cover both the
    SentencePiece vocab and the model vocab):
      - base_bytes (int16): UTF-8 byte length of the piece, with any leading "▁" stripped
      - has_leading_space (bool): True when the piece starts with the "▁" marker
      - is_boundary_token (bool): True for control/unknown/unused ids (and ids past
        the SentencePiece vocab, which keep their initial True value)
    """
    sp_vocab_size = int(sp.vocab_size())
    # Size tables to whichever vocab is larger so either id space can index safely.
    table_size = max(sp_vocab_size, vocab_size)
    base_bytes_np = np.zeros((table_size,), dtype=np.int16)
    has_leading_space_np = np.zeros((table_size,), dtype=np.bool_)
    is_boundary_token_np = np.ones((table_size,), dtype=np.bool_)
    for token_id in range(sp_vocab_size):
        # Control/unknown/unused ids contribute no bytes and stay marked as boundaries.
        if sp.is_control(token_id) or sp.is_unknown(token_id) or sp.is_unused(token_id):
            continue
        is_boundary_token_np[token_id] = False
        if sp.is_byte(token_id):
            # Byte-fallback pieces decode to exactly one byte.
            base_bytes_np[token_id] = 1
            continue
        piece = sp.id_to_piece(token_id)
        if piece.startswith("▁"):
            # The leading-space byte is charged conditionally in eval_val (only when
            # the previous token is not a boundary), so it is tracked separately.
            has_leading_space_np[token_id] = True
            piece = piece[1:]
        base_bytes_np[token_id] = len(piece.encode("utf-8"))
    return (
        torch.tensor(base_bytes_np, dtype=torch.int16, device=device),
        torch.tensor(has_leading_space_np, dtype=torch.bool, device=device),
        torch.tensor(is_boundary_token_np, dtype=torch.bool, device=device),
    )


def load_validation_tokens(pattern: str, seq_len: int) -> Tensor:
    """Concatenate every validation shard matching `pattern` into one token stream."""
    files = [Path(p) for p in sorted(glob.glob(pattern))]
    if not files:
        raise FileNotFoundError(f"No files found for pattern: {pattern}")
    # The export pipeline writes the fixed first-50k-doc validation set to fineweb_val_*.
+ tokens = torch.cat([load_data_shard(file) for file in files]).contiguous() + usable = ((tokens.numel() - 1) // seq_len) * seq_len + if usable <= 0: + raise ValueError(f"Validation split is too short for TRAIN_SEQ_LEN={seq_len}") + return tokens[: usable + 1] + + +def eval_val( + args: Hyperparameters, + model: nn.Module, + rank: int, + world_size: int, + device: torch.device, + grad_accum_steps: int, + val_tokens: Tensor, + base_bytes_lut: Tensor, + has_leading_space_lut: Tensor, + is_boundary_token_lut: Tensor, +) -> tuple[float, float]: + # Validation computes two metrics: + # - val_loss: token cross-entropy (natural log) + # - val_bpb: tokenizer-agnostic compression metric used by the challenge + local_batch_tokens = args.val_batch_size // (world_size * grad_accum_steps) + if local_batch_tokens < args.train_seq_len: + raise ValueError( + "VAL_BATCH_SIZE must provide at least one sequence per rank; " + f"got VAL_BATCH_SIZE={args.val_batch_size}, WORLD_SIZE={world_size}, " + f"GRAD_ACCUM_STEPS={grad_accum_steps}, TRAIN_SEQ_LEN={args.train_seq_len}" + ) + local_batch_seqs = local_batch_tokens // args.train_seq_len + total_seqs = (val_tokens.numel() - 1) // args.train_seq_len + seq_start = (total_seqs * rank) // world_size + seq_end = (total_seqs * (rank + 1)) // world_size + val_loss_sum = torch.zeros((), device=device, dtype=torch.float64) + val_token_count = torch.zeros((), device=device, dtype=torch.float64) + val_byte_count = torch.zeros((), device=device, dtype=torch.float64) + + model.eval() + with torch.inference_mode(): + for batch_seq_start in range(seq_start, seq_end, local_batch_seqs): + batch_seq_end = min(batch_seq_start + local_batch_seqs, seq_end) + raw_start = batch_seq_start * args.train_seq_len + raw_end = batch_seq_end * args.train_seq_len + 1 + local = val_tokens[raw_start:raw_end].to(device=device, dtype=torch.int64, non_blocking=True) + x = local[:-1].reshape(-1, args.train_seq_len) + y = local[1:].reshape(-1, args.train_seq_len) + with 
torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + batch_loss = model(x, y).detach() + batch_token_count = float(y.numel()) + val_loss_sum += batch_loss.to(torch.float64) * batch_token_count + val_token_count += batch_token_count + prev_ids = x.reshape(-1) + tgt_ids = y.reshape(-1) + token_bytes = base_bytes_lut[tgt_ids].to(dtype=torch.int16) + token_bytes += (has_leading_space_lut[tgt_ids] & ~is_boundary_token_lut[prev_ids]).to(dtype=torch.int16) + val_byte_count += token_bytes.to(torch.float64).sum() + + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(val_loss_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(val_token_count, op=dist.ReduceOp.SUM) + dist.all_reduce(val_byte_count, op=dist.ReduceOp.SUM) + + val_loss = val_loss_sum / val_token_count + bits_per_token = val_loss.item() / math.log(2.0) + tokens_per_byte = val_token_count.item() / val_byte_count.item() + model.train() + return float(val_loss.item()), float(bits_per_token * tokens_per_byte) + +# ----------------------------- +# POST-TRAINING QUANTIZATION +# ----------------------------- +# +# It's silly to export our model, which is trained in bf16 and fp32, at that same precision. +# Instead, we get approximately the same model (with a small hit) by quantizing the model to int8 & zlib compressing. +# We can then decompress the model and run in higher precision for evaluation, after closing in under the size limit. 

# Substrings that mark small "control" parameters (learned scales/gains/mixes).
# These stay in full fp32 through quantized export because they are tiny but
# sensitive; the list can be overridden via the environment variable.
CONTROL_TENSOR_NAME_PATTERNS = tuple(
    pattern
    for pattern in os.environ.get(
        "CONTROL_TENSOR_NAME_PATTERNS",
        "attn_scale,attn_scales,mlp_scale,mlp_scales,resid_mix,resid_mixes,q_gain,skip_weight,skip_weights,ts_attn,ts_mlp,resid_mix_logit,ts_attn_beta,ts_mlp_beta,depth_embed,unique_attn_gain,unique_mlp_gain",
    ).split(",")
    if pattern
)
# Name substrings exempted from the fp16 downcast in keep_float_tensor
# (defaults to the control-tensor patterns above).
INT8_KEEP_FLOAT_FP32_NAME_PATTERNS = tuple(
    pattern
    for pattern in os.environ.get(
        "INT8_KEEP_FLOAT_FP32_NAME_PATTERNS",
        ",".join(CONTROL_TENSOR_NAME_PATTERNS),
    ).split(",")
    if pattern
)
INT8_KEEP_FLOAT_MAX_NUMEL = 65_536  # float tensors at or below this size skip int8 quantization
INT8_KEEP_FLOAT_STORE_DTYPE = torch.float16  # storage dtype for small kept-float tensors
INT8_PER_ROW_SCALE_DTYPE = torch.float16  # storage dtype for per-row scales
INT8_CLIP_PERCENTILE = 99.99984  # symmetric clip point; outliers beyond this saturate
INT8_CLIP_Q = INT8_CLIP_PERCENTILE / 100.0

def tensor_nbytes(t: Tensor) -> int:
    """Size of `t`'s payload in bytes (numel * element size)."""
    return int(t.numel()) * int(t.element_size())

def keep_float_tensor(name: str, t: Tensor, passthrough_orig_dtypes: dict[str, str]) -> Tensor:
    """Prepare a small float tensor for passthrough storage.

    Control tensors (matched by name substring) are stored as fp32. Other fp32/bf16
    tensors are downcast to fp16 to save bytes, with the original dtype recorded in
    `passthrough_orig_dtypes` (mutated in place) so dequantization can restore it.
    Any other dtype is returned unchanged.
    """
    if any(pattern in name for pattern in INT8_KEEP_FLOAT_FP32_NAME_PATTERNS):
        return t.float().contiguous()
    if t.dtype in {torch.float32, torch.bfloat16}:
        passthrough_orig_dtypes[name] = str(t.dtype).removeprefix("torch.")
        return t.to(dtype=INT8_KEEP_FLOAT_STORE_DTYPE).contiguous()
    return t

def quantize_float_tensor(t: Tensor) -> tuple[Tensor, Tensor]:
    """Symmetric int8 quantization: per-row scales for matrices, per-tensor otherwise.

    Returns (int8 values in [-127, 127], scale). For 2D inputs the scale is one
    fp16 value per row; magnitudes beyond the clip percentile are saturated.
    """
    t32 = t.float()
    if t32.ndim == 2:
        # Matrices get one scale per row, which usually tracks output-channel
        # ranges much better than a single tensor-wide scale.
        clip_abs = (
            torch.quantile(t32.abs(), INT8_CLIP_Q, dim=1)
            if t32.numel()
            # zeros (not torch.empty's uninitialized memory) so empty inputs still
            # produce deterministic scales after the clamp below.
            else torch.zeros((t32.shape[0],), dtype=torch.float32)
        )
        clipped = torch.maximum(torch.minimum(t32, clip_abs[:, None]), -clip_abs[:, None])
        # clamp_min guards against all-zero rows producing a zero divisor.
        scale = (clip_abs / 127.0).clamp_min(1.0 / 127.0)
        q = torch.clamp(torch.round(clipped / scale[:, None]), -127, 127).to(torch.int8).contiguous()
        return q, scale.to(dtype=INT8_PER_ROW_SCALE_DTYPE).contiguous()

    # Vectors / scalars use a simpler per-tensor scale.
+ clip_abs = float(torch.quantile(t32.abs().flatten(), INT8_CLIP_Q).item()) if t32.numel() else 0.0 + scale = torch.tensor(clip_abs / 127.0 if clip_abs > 0 else 1.0, dtype=torch.float32) + q = torch.clamp(torch.round(torch.clamp(t32, -clip_abs, clip_abs) / scale), -127, 127).to(torch.int8).contiguous() + return q, scale + +def quantize_state_dict_int8(state_dict: dict[str, Tensor]): + # Single supported clean-script export format: + # - per-row int8 for 2D float tensors + # - per-tensor int8 for other float tensors + # - exact passthrough for non-floats + # - passthrough for small float tensors, stored as fp16 to save bytes + quantized: dict[str, Tensor] = {} + scales: dict[str, Tensor] = {} + dtypes: dict[str, str] = {} + passthrough: dict[str, Tensor] = {} + passthrough_orig_dtypes: dict[str, str] = {} + qmeta: dict[str, dict[str, object]] = {} + stats = dict.fromkeys( + ("param_count", "num_tensors", "num_float_tensors", "num_nonfloat_tensors", "baseline_tensor_bytes", "int8_payload_bytes"), + 0, + ) + + for name, tensor in state_dict.items(): + t = tensor.detach().to("cpu").contiguous() + stats["param_count"] += int(t.numel()) + stats["num_tensors"] += 1 + stats["baseline_tensor_bytes"] += tensor_nbytes(t) + + if not t.is_floating_point(): + stats["num_nonfloat_tensors"] += 1 + passthrough[name] = t + stats["int8_payload_bytes"] += tensor_nbytes(t) + continue + + # Small float tensors are cheap enough to keep directly. We still downcast + # fp32/bf16 passthrough tensors to fp16 so metadata does not dominate size. 
+ if t.numel() <= INT8_KEEP_FLOAT_MAX_NUMEL: + kept = keep_float_tensor(name, t, passthrough_orig_dtypes) + passthrough[name] = kept + stats["int8_payload_bytes"] += tensor_nbytes(kept) + continue + + stats["num_float_tensors"] += 1 + q, s = quantize_float_tensor(t) + if s.ndim > 0: + qmeta[name] = {"scheme": "per_row", "axis": 0} + quantized[name] = q + scales[name] = s + dtypes[name] = str(t.dtype).removeprefix("torch.") + stats["int8_payload_bytes"] += tensor_nbytes(q) + tensor_nbytes(s) + + obj: dict[str, object] = { + "__quant_format__": "int8_clean_per_row_v1", + "quantized": quantized, + "scales": scales, + "dtypes": dtypes, + "passthrough": passthrough, + } + if qmeta: + obj["qmeta"] = qmeta + if passthrough_orig_dtypes: + obj["passthrough_orig_dtypes"] = passthrough_orig_dtypes + return obj, stats + +def dequantize_state_dict_int8(obj: dict[str, object]) -> dict[str, Tensor]: + out: dict[str, Tensor] = {} + qmeta = obj.get("qmeta", {}) + passthrough_orig_dtypes = obj.get("passthrough_orig_dtypes", {}) + for name, q in obj["quantized"].items(): + dtype = getattr(torch, obj["dtypes"][name]) + s = obj["scales"][name] + if qmeta.get(name, {}).get("scheme") == "per_row" or s.ndim > 0: + s = s.to(dtype=torch.float32) + # Broadcast the saved row scale back across trailing dimensions. + out[name] = (q.float() * s.view(q.shape[0], *([1] * (q.ndim - 1)))).to(dtype=dtype).contiguous() + else: + scale = float(s.item()) + out[name] = (q.float() * scale).to(dtype=dtype).contiguous() + for name, t in obj["passthrough"].items(): + # Restore small tensors, undoing the temporary fp16 storage cast if needed. 
+ out_t = t.detach().to("cpu").contiguous() + orig_dtype = passthrough_orig_dtypes.get(name) + if isinstance(orig_dtype, str): + out_t = out_t.to(dtype=getattr(torch, orig_dtype)).contiguous() + out[name] = out_t + return out + + +# ----------------------------- +# DATA LOADING +# ----------------------------- + +def load_data_shard(file: Path) -> Tensor: + header_bytes = 256 * np.dtype(" None: + self.file_idx = (self.file_idx + 1) % len(self.files) + self.tokens = load_data_shard(self.files[self.file_idx]) + self.pos = 0 + + def take(self, n: int) -> Tensor: + chunks: list[Tensor] = [] + remaining = n + while remaining > 0: + avail = self.tokens.numel() - self.pos + if avail <= 0: + self._advance_file() + continue + k = min(remaining, avail) + chunks.append(self.tokens[self.pos : self.pos + k]) + self.pos += k + remaining -= k + return chunks[0] if len(chunks) == 1 else torch.cat(chunks) + + +class DistributedTokenLoader: + # Each call consumes a contiguous chunk from the shared token stream, then slices out + # one disjoint span per rank. The extra "+1" token lets us build (x, y) by shifting. 
+ def __init__(self, pattern: str, rank: int, world_size: int, device: torch.device): + self.rank = rank + self.world_size = world_size + self.device = device + self.stream = TokenStream(pattern) + + def next_batch(self, global_tokens: int, seq_len: int, grad_accum_steps: int) -> tuple[Tensor, Tensor]: + local_tokens = global_tokens // (self.world_size * grad_accum_steps) + per_rank_span = local_tokens + 1 + chunk = self.stream.take(per_rank_span * self.world_size) + start = self.rank * per_rank_span + local = chunk[start : start + per_rank_span].to(dtype=torch.int64) + x = local[:-1].reshape(-1, seq_len) + y = local[1:].reshape(-1, seq_len) + return x.to(self.device, non_blocking=True), y.to(self.device, non_blocking=True) + +# ----------------------------- +# TRANSFORMER MODULES +# ----------------------------- + +class RMSNorm(nn.Module): + def __init__(self, eps: float | None = None): + super().__init__() + self.eps = eps + + def forward(self, x: Tensor) -> Tensor: + return F.rms_norm(x, (x.size(-1),), eps=self.eps) + + +class CastedLinear(nn.Linear): + # Keep weights in fp32 for optimizer/state quality, cast at matmul time for bf16 compute. + def forward(self, x: Tensor) -> Tensor: + bias = self.bias.to(x.dtype) if self.bias is not None else None + return F.linear(x, self.weight.to(x.dtype), bias) + + +def restore_low_dim_params_to_fp32(module: nn.Module) -> None: + # Keep small/control parameters in fp32 even when the model body runs in bf16. + with torch.no_grad(): + for name, param in module.named_parameters(): + if (param.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS)) and param.dtype != torch.float32: + param.data = param.data.float() + + +class Rotary(nn.Module): + # Caches cos/sin tables per sequence length on the current device. 
    def __init__(self, dim: int, base: float = 10000.0):
        super().__init__()
        # Standard RoPE inverse frequencies: base^(-2i/dim) for each even index i.
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim))
        # Non-persistent: derived purely from (dim, base), so excluded from checkpoints.
        self.register_buffer("inv_freq", inv_freq, persistent=False)
        self._seq_len_cached = 0
        self._cos_cached: Tensor | None = None
        self._sin_cached: Tensor | None = None

    def forward(self, seq_len: int, device: torch.device, dtype: torch.dtype) -> tuple[Tensor, Tensor]:
        # Rebuild the tables when missing, or when seq_len/device changed since the
        # last call. (The cache is not keyed on dtype: tables are stored in
        # inv_freq's dtype and cast to the requested dtype on every return.)
        if (
            self._cos_cached is None
            or self._sin_cached is None
            or self._seq_len_cached != seq_len
            or self._cos_cached.device != device
        ):
            t = torch.arange(seq_len, device=device, dtype=self.inv_freq.dtype)
            freqs = torch.outer(t, self.inv_freq.to(device))
            # Shaped [1, 1, seq_len, head_dim // 2] to broadcast over (batch, heads).
            self._cos_cached = freqs.cos()[None, None, :, :]
            self._sin_cached = freqs.sin()[None, None, :, :]
            self._seq_len_cached = seq_len
        return self._cos_cached.to(dtype=dtype), self._sin_cached.to(dtype=dtype)


def apply_rotary_emb(x: Tensor, cos: Tensor, sin: Tensor) -> Tensor:
    """Rotate x by the cached RoPE angles using the split-half convention.

    The last dim is split into two halves forming (x1, x2) pairs, each rotated by a
    position-dependent angle. NOTE(review): the signs here are the transpose of the
    common (x1*cos - x2*sin, x1*sin + x2*cos) form — i.e. a rotation by the negated
    angle — which is harmless as long as q and k use the same convention (both go
    through this function).
    """
    half = x.size(-1) // 2
    x1, x2 = x[..., :half], x[..., half:]
    return torch.cat((x1 * cos + x2 * sin, x1 * (-sin) + x2 * cos), dim=-1)


class CausalSelfAttention(nn.Module):
    # Grouped-query attention: num_heads query heads share num_kv_heads K/V heads.
    def __init__(
        self,
        dim: int,
        num_heads: int,
        num_kv_heads: int,
        rope_base: float,
        qk_gain_init: float,
    ):
        super().__init__()
        if dim % num_heads != 0:
            raise ValueError("model_dim must be divisible by num_heads")
        if num_heads % num_kv_heads != 0:
            raise ValueError("num_heads must be divisible by num_kv_heads")
        self.num_heads = num_heads
        self.num_kv_heads = num_kv_heads
        self.head_dim = dim // num_heads
        if self.head_dim % 2 != 0:
            raise ValueError("head_dim must be even for RoPE")
        kv_dim = self.num_kv_heads * self.head_dim
        self.c_q = CastedLinear(dim, dim, bias=False)
        self.c_k = CastedLinear(dim, kv_dim, bias=False)
        self.c_v = CastedLinear(dim, kv_dim, bias=False)
        self.proj = CastedLinear(dim, dim, bias=False)
        # Marker consumed by GPT._init_weights: the output projection starts at zero
        # so each block initially contributes nothing to the residual stream.
        self.proj._zero_init = True
        self.q_gain = 
nn.Parameter(torch.full((num_heads,), qk_gain_init, dtype=torch.float32))
        self.rotary = Rotary(self.head_dim, base=rope_base)

    def forward(self, x: Tensor) -> Tensor:
        bsz, seqlen, dim = x.shape
        # Project into (B, H, T, D) / (B, KVH, T, D) head layouts.
        q = self.c_q(x).reshape(bsz, seqlen, self.num_heads, self.head_dim).transpose(1, 2)
        k = self.c_k(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2)
        v = self.c_v(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2)
        # QK-norm: RMS-normalize queries and keys per head before RoPE.
        q = F.rms_norm(q, (q.size(-1),))
        k = F.rms_norm(k, (k.size(-1),))
        cos, sin = self.rotary(seqlen, x.device, q.dtype)
        q = apply_rotary_emb(q, cos, sin)
        k = apply_rotary_emb(k, cos, sin)
        # Learned per-head query gain; with unit-RMS q/k this acts as a learned
        # attention temperature (initialized from qk_gain_init).
        q = q * self.q_gain.to(dtype=q.dtype)[None, :, None, None]
        if FA3_AVAILABLE:
            # FA3 expects (B, T, H, D) — transpose from (B, H, T, D)
            y = fa3.flash_attn_func(
                q.transpose(1, 2),
                k.transpose(1, 2),
                v.transpose(1, 2),
                causal=True,
            ).transpose(1, 2)
        else:
            # enable_gqa lets SDPA broadcast the smaller K/V head count.
            y = F.scaled_dot_product_attention(
                q, k, v,
                attn_mask=None,
                is_causal=True,
                enable_gqa=(self.num_kv_heads != self.num_heads),
            )
        y = y.transpose(1, 2).contiguous().reshape(bsz, seqlen, dim)
        return self.proj(y)


class MLP(nn.Module):
    # leaky_relu^2 MLP (slope=0.5 matches SOTA, PRs #493/#518)
    def __init__(self, dim: int, mlp_mult: int, leaky_relu_slope: float = 0.5):
        super().__init__()
        hidden = mlp_mult * dim
        self.fc = CastedLinear(dim, hidden, bias=False)
        self.proj = CastedLinear(hidden, dim, bias=False)
        # Zero-initialized output projection (see GPT._init_weights).
        self.proj._zero_init = True
        self.leaky_relu_slope = leaky_relu_slope

    def forward(self, x: Tensor) -> Tensor:
        # Activation is leaky_relu(x)^2: non-negative, quadratic for x > 0,
        # with a damped slope^2 * x^2 response on the negative side.
        x = F.leaky_relu(self.fc(x), negative_slope=self.leaky_relu_slope)
        return self.proj(x.square())


class TimestepScaling(nn.Module):
    """Learned per-iteration scale vectors (arXiv:2410.01405, ICML 2025)."""
    def __init__(self, effective_layers: int, dim: int, gamma_max: float = 0.0, use_bias: bool = False):
        super().__init__()
        # One gamma vector per virtual layer (prelude + shared*loops + coda positions).
        self.attn_gamma = nn.Parameter(torch.ones(effective_layers, dim, 
dtype=torch.float32)) + self.mlp_gamma = nn.Parameter(torch.ones(effective_layers, dim, dtype=torch.float32)) + self.gamma_max = gamma_max # 0 = uncapped + if use_bias: + self.attn_beta = nn.Parameter(torch.zeros(effective_layers, dim, dtype=torch.float32)) + self.mlp_beta = nn.Parameter(torch.zeros(effective_layers, dim, dtype=torch.float32)) + else: + self.attn_beta = None + self.mlp_beta = None + + def get(self, v: int) -> tuple[Tensor, Tensor, Tensor | None, Tensor | None]: + ag = self.attn_gamma[v] + mg = self.mlp_gamma[v] + if self.gamma_max > 0: + ag = ag.clamp(-self.gamma_max, self.gamma_max) + mg = mg.clamp(-self.gamma_max, self.gamma_max) + ab = self.attn_beta[v] if self.attn_beta is not None else None + mb = self.mlp_beta[v] if self.mlp_beta is not None else None + return ag, mg, ab, mb + + +class Block(nn.Module): + def __init__( + self, + dim: int, + num_heads: int, + num_kv_heads: int, + mlp_mult: int, + rope_base: float, + qk_gain_init: float, + use_peri_norm: bool = False, + use_birkhoff_mix: bool = False, + leaky_relu_slope: float = 0.5, + ): + super().__init__() + self.use_peri_norm = use_peri_norm + self.use_birkhoff_mix = use_birkhoff_mix + self.attn_norm = RMSNorm() + if use_peri_norm: + self.mlp_out_norm = RMSNorm() # output norm: bounds MLP contribution, preserves input magnitude + else: + self.mlp_norm = RMSNorm() + self.attn = CausalSelfAttention(dim, num_heads, num_kv_heads, rope_base, qk_gain_init) + self.mlp = MLP(dim, mlp_mult, leaky_relu_slope) + self.attn_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.mlp_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + if use_birkhoff_mix: + self.resid_mix_logit = nn.Parameter(torch.zeros(dim, dtype=torch.float32)) + else: + self.resid_mix = nn.Parameter(torch.stack((torch.ones(dim), torch.zeros(dim))).float()) + + def forward(self, x: Tensor, x0: Tensor, + ts_attn_gamma: Tensor | None = None, + ts_mlp_gamma: Tensor | None = None, + ts_attn_beta: Tensor | None = None, + 
ts_mlp_beta: Tensor | None = None,
                depth_emb: Tensor | None = None,
                ext_attn_gain: Tensor | None = None,
                ext_mlp_gain: Tensor | None = None) -> Tensor:
        # Residual-stream mix with x0 (the normalized token embedding, see GPT.forward).
        if self.use_birkhoff_mix:
            # Convex (Birkhoff-style) combination: per-channel alpha in (0, 1)
            # keeps the mix a weighted average of x and x0.
            alpha = torch.sigmoid(self.resid_mix_logit.to(dtype=x.dtype))[None, None, :]
            x = alpha * x + (1 - alpha) * x0
        else:
            # Unconstrained learned two-way mix, initialized to (1, 0) == identity.
            mix = self.resid_mix.to(dtype=x.dtype)
            x = mix[0][None, None, :] * x + mix[1][None, None, :] * x0
        if depth_emb is not None:
            x = x + depth_emb
        # --- attention sub-block ---
        attn_normed = self.attn_norm(x)
        if ext_attn_gain is not None:
            # Per-virtual-layer gain supplied by the model (use_unique_norms path).
            attn_normed = attn_normed * ext_attn_gain[None, None, :]
        attn_out = self.attn(attn_normed)
        attn_s = self.attn_scale.to(dtype=x.dtype)[None, None, :]
        if ts_attn_gamma is not None:
            # Timestep scaling multiplies the block's own learned residual scale.
            attn_s = attn_s * ts_attn_gamma[None, None, :]
        x = x + attn_s * attn_out
        if ts_attn_beta is not None:
            x = x + ts_attn_beta[None, None, :]
        # --- MLP sub-block ---
        if self.use_peri_norm:
            # Peri-norm: normalize the MLP *output* (mlp_out_norm) instead of its
            # input. Note that without ext_mlp_gain the MLP consumes the raw
            # (un-normalized) residual stream here.
            if ext_mlp_gain is not None:
                m_input = F.rms_norm(x, (x.size(-1),)) * ext_mlp_gain[None, None, :]
            else:
                m_input = x
            mlp_out = self.mlp_out_norm(self.mlp(m_input))
        else:
            mlp_normed = self.mlp_norm(x)
            if ext_mlp_gain is not None:
                mlp_normed = mlp_normed * ext_mlp_gain[None, None, :]
            mlp_out = self.mlp(mlp_normed)
        mlp_s = self.mlp_scale.to(dtype=x.dtype)[None, None, :]
        if ts_mlp_gamma is not None:
            mlp_s = mlp_s * ts_mlp_gamma[None, None, :]
        x = x + mlp_s * mlp_out
        if ts_mlp_beta is not None:
            x = x + ts_mlp_beta[None, None, :]
        return x


class GPT(nn.Module):
    def __init__(
        self,
        vocab_size: int,
        num_layers: int,
        model_dim: int,
        num_heads: int,
        num_kv_heads: int,
        mlp_mult: int,
        tie_embeddings: bool,
        tied_embed_init_std: float,
        logit_softcap: float,
        rope_base: float,
        qk_gain_init: float,
        num_prelude: int = 0,
        num_coda: int = 0,
        num_shared: int = 0,
        num_loops: int = 2,
        use_peri_norm: bool = False,
        use_birkhoff_mix: bool = False,
        use_timestep_scale: bool = False,
        use_timestep_bias: bool = False,
        use_depth_embed: bool = False,
        
use_unique_norms: bool = False, + timestep_gamma_max: float = 0.0, + leaky_relu_slope: float = 0.5, + ): + super().__init__() + if logit_softcap <= 0.0: + raise ValueError(f"logit_softcap must be positive, got {logit_softcap}") + self.tie_embeddings = tie_embeddings + self.tied_embed_init_std = tied_embed_init_std + self.logit_softcap = logit_softcap + self.tok_emb = nn.Embedding(vocab_size, model_dim) + self.use_recurrence = num_shared > 0 + self.num_loops = num_loops if self.use_recurrence else 1 + self.num_prelude = num_prelude if self.use_recurrence else 0 + self.num_coda = num_coda if self.use_recurrence else 0 + + block_kwargs = dict( + dim=model_dim, num_heads=num_heads, num_kv_heads=num_kv_heads, + mlp_mult=mlp_mult, rope_base=rope_base, qk_gain_init=qk_gain_init, + use_peri_norm=use_peri_norm if self.use_recurrence else False, + use_birkhoff_mix=use_birkhoff_mix if self.use_recurrence else False, + leaky_relu_slope=leaky_relu_slope, + ) + + if self.use_recurrence: + self.prelude_blocks = nn.ModuleList([Block(**block_kwargs) for _ in range(self.num_prelude)]) + self.coda_blocks = nn.ModuleList([Block(**block_kwargs) for _ in range(self.num_coda)]) + self.shared_blocks = nn.ModuleList([Block(**block_kwargs) for _ in range(num_shared)]) + effective_layers = self.num_prelude + num_shared * self.num_loops + self.num_coda + self.timestep_scale = TimestepScaling(effective_layers, model_dim, gamma_max=timestep_gamma_max, use_bias=use_timestep_bias) if use_timestep_scale else None + self.use_depth_embed = use_depth_embed + if self.use_depth_embed: + self.depth_embeddings = nn.Parameter(torch.zeros(effective_layers, model_dim, dtype=torch.float32)) + self.use_unique_norms = use_unique_norms + if self.use_unique_norms: + num_unique = num_shared * self.num_loops + self.unique_attn_gains = nn.Parameter(torch.ones(num_unique, model_dim, dtype=torch.float32)) + self.unique_mlp_gains = nn.Parameter(torch.ones(num_unique, model_dim, dtype=torch.float32)) + else: + # 
Standard U-Net path + self.num_encoder_layers = num_layers // 2 + self.num_decoder_layers = num_layers - self.num_encoder_layers + self.num_skip_weights = min(self.num_encoder_layers, self.num_decoder_layers) + self.skip_weights = nn.Parameter(torch.ones(self.num_skip_weights, model_dim, dtype=torch.float32)) + self.blocks = nn.ModuleList([Block(**block_kwargs) for _ in range(num_layers)]) + self.timestep_scale = None + self.use_depth_embed = False + self.use_unique_norms = False + + self.final_norm = RMSNorm() + self.lm_head = None if tie_embeddings else CastedLinear(model_dim, vocab_size, bias=False) + if self.lm_head is not None: + self.lm_head._zero_init = True + self._init_weights() + + def _init_weights(self) -> None: + if self.tie_embeddings: + nn.init.normal_(self.tok_emb.weight, mean=0.0, std=self.tied_embed_init_std) + for module in self.modules(): + if isinstance(module, nn.Linear) and getattr(module, "_zero_init", False): + nn.init.zeros_(module.weight) + + def _get_ts(self, v: int) -> tuple[Tensor | None, Tensor | None, Tensor | None, Tensor | None]: + if self.timestep_scale is None: + return None, None, None, None + return self.timestep_scale.get(v) + + def forward(self, input_ids: Tensor, target_ids: Tensor) -> Tensor: + x = self.tok_emb(input_ids) + x = F.rms_norm(x, (x.size(-1),)) + x0 = x + + if self.use_recurrence: + v = 0 + for block in self.prelude_blocks: + ag, mg, ab, mb = self._get_ts(v) + de = self.depth_embeddings[v] if self.use_depth_embed else None + x = block(x, x0, ag, mg, ab, mb, de) + v += 1 + uid = 0 + for _loop in range(self.num_loops): + for block in self.shared_blocks: + ag, mg, ab, mb = self._get_ts(v) + de = self.depth_embeddings[v] if self.use_depth_embed else None + if self.use_unique_norms: + ag_n = self.unique_attn_gains[uid].to(dtype=x.dtype) + mg_n = self.unique_mlp_gains[uid].to(dtype=x.dtype) + x = block(x, x0, ag, mg, ab, mb, de, ag_n, mg_n) + else: + x = block(x, x0, ag, mg, ab, mb, de) + uid += 1 + v += 1 + for block 
in self.coda_blocks: + ag, mg, ab, mb = self._get_ts(v) + de = self.depth_embeddings[v] if self.use_depth_embed else None + x = block(x, x0, ag, mg, ab, mb, de) + v += 1 + else: + skips: list[Tensor] = [] + for i in range(self.num_encoder_layers): + x = self.blocks[i](x, x0) + skips.append(x) + for i in range(self.num_decoder_layers): + if skips: + x = x + self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips.pop() + x = self.blocks[self.num_encoder_layers + i](x, x0) + + x = self.final_norm(x).reshape(-1, x.size(-1)) + targets = target_ids.reshape(-1) + if self.tie_embeddings: + logits_proj = F.linear(x, self.tok_emb.weight) + else: + if self.lm_head is None: + raise RuntimeError("lm_head is required when tie_embeddings=False") + logits_proj = self.lm_head(x) + logits = self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap) + return F.cross_entropy(logits.float(), targets, reduction="mean") + + +@torch.no_grad() +def recurrence_param_diagnostics(gpt: GPT) -> str: + """Log parameter norms for recurrence health monitoring. No forward pass.""" + if not gpt.use_recurrence: + return "" + parts: list[str] = [] + + # Birkhoff alpha stats per shared block + for i, block in enumerate(gpt.shared_blocks): + if hasattr(block, "resid_mix_logit"): + a = torch.sigmoid(block.resid_mix_logit) + parts.append(f"shared{i}_alpha:mean={a.mean().item():.3f},std={a.std().item():.3f}") + + # Effective MLP/attn contribution scale per virtual layer position: + # mlp_scale * ts_mlp_gamma[v] is the learned multiplier on MLP output. + # We report RMS (norm / sqrt(numel)) to make it scale-independent. 
+ v = 0 + effective_count = gpt.num_prelude + len(gpt.shared_blocks) * gpt.num_loops + gpt.num_coda + mlp_norms: list[str] = [] + attn_norms: list[str] = [] + + # Prelude blocks + for block in gpt.prelude_blocks: + ms = block.mlp_scale.norm().item() + asc = block.attn_scale.norm().item() + d = block.mlp_scale.numel() ** 0.5 + if gpt.timestep_scale is not None: + ms_g = gpt.timestep_scale.mlp_gamma[v].norm().item() + as_g = gpt.timestep_scale.attn_gamma[v].norm().item() + mlp_norms.append(f"v{v}:{ms * ms_g / d:.4f}") + attn_norms.append(f"v{v}:{asc * as_g / d:.4f}") + else: + mlp_norms.append(f"v{v}:{ms / d:.4f}") + attn_norms.append(f"v{v}:{asc / d:.4f}") + v += 1 + + # Shared positions + for _loop in range(gpt.num_loops): + for block in gpt.shared_blocks: + ms = block.mlp_scale.norm().item() + asc = block.attn_scale.norm().item() + d = block.mlp_scale.numel() ** 0.5 + if gpt.timestep_scale is not None: + ms_g = gpt.timestep_scale.mlp_gamma[v].norm().item() + as_g = gpt.timestep_scale.attn_gamma[v].norm().item() + mlp_norms.append(f"v{v}:{ms * ms_g / d:.4f}") + attn_norms.append(f"v{v}:{asc * as_g / d:.4f}") + else: + mlp_norms.append(f"v{v}:{ms / d:.4f}") + attn_norms.append(f"v{v}:{asc / d:.4f}") + v += 1 + + # Coda blocks + for block in gpt.coda_blocks: + ms = block.mlp_scale.norm().item() + asc = block.attn_scale.norm().item() + d = block.mlp_scale.numel() ** 0.5 + if gpt.timestep_scale is not None: + ms_g = gpt.timestep_scale.mlp_gamma[v].norm().item() + as_g = gpt.timestep_scale.attn_gamma[v].norm().item() + mlp_norms.append(f"v{v}:{ms * ms_g / d:.4f}") + attn_norms.append(f"v{v}:{asc * as_g / d:.4f}") + else: + mlp_norms.append(f"v{v}:{ms / d:.4f}") + attn_norms.append(f"v{v}:{asc / d:.4f}") + v += 1 + + parts.append("eff_mlp_scale:[" + " ".join(mlp_norms) + "]") + parts.append("eff_attn_scale:[" + " ".join(attn_norms) + "]") + if gpt.timestep_scale is not None and gpt.timestep_scale.attn_beta is not None: + attn_bias_norms: list[str] = [] + mlp_bias_norms: 
list[str] = [] + for vi in range(effective_count): + ab_rms = gpt.timestep_scale.attn_beta[vi].norm().item() / gpt.timestep_scale.attn_beta[vi].numel() ** 0.5 + mb_rms = gpt.timestep_scale.mlp_beta[vi].norm().item() / gpt.timestep_scale.mlp_beta[vi].numel() ** 0.5 + attn_bias_norms.append(f"v{vi}:{ab_rms:.4f}") + mlp_bias_norms.append(f"v{vi}:{mb_rms:.4f}") + parts.append("eff_attn_bias:[" + " ".join(attn_bias_norms) + "]") + parts.append("eff_mlp_bias:[" + " ".join(mlp_bias_norms) + "]") + if gpt.use_unique_norms: + un_attn: list[str] = [] + un_mlp: list[str] = [] + for ui in range(gpt.unique_attn_gains.size(0)): + an_rms = gpt.unique_attn_gains[ui].norm().item() / gpt.unique_attn_gains[ui].numel() ** 0.5 + un_attn.append(f"u{ui}:{an_rms:.4f}") + mn_rms = gpt.unique_mlp_gains[ui].norm().item() / gpt.unique_mlp_gains[ui].numel() ** 0.5 + un_mlp.append(f"u{ui}:{mn_rms:.4f}") + parts.append("unique_attn_gain_rms:[" + " ".join(un_attn) + "]") + parts.append("unique_mlp_gain_rms:[" + " ".join(un_mlp) + "]") + if gpt.use_depth_embed: + de_norms: list[str] = [] + for vi in range(effective_count): + de_rms = gpt.depth_embeddings[vi].norm().item() / gpt.depth_embeddings[vi].numel() ** 0.5 + de_norms.append(f"v{vi}:{de_rms:.4f}") + parts.append("depth_emb_rms:[" + " ".join(de_norms) + "]") + return " ".join(parts) + + +# ----------------------------- +# TRAINING +# ----------------------------- + +def main() -> None: + global zeropower_via_newtonschulz5 + + code = Path(__file__).read_text(encoding="utf-8") + args = Hyperparameters() + zeropower_via_newtonschulz5 = torch.compile(zeropower_via_newtonschulz5) + + # ----------------------------- + # DISTRIBUTED + CUDA SETUP + # ----------------------------- + + distributed = "RANK" in os.environ and "WORLD_SIZE" in os.environ + rank = int(os.environ.get("RANK", "0")) + world_size = int(os.environ.get("WORLD_SIZE", "1")) + local_rank = int(os.environ.get("LOCAL_RANK", "0")) + if world_size <= 0: + raise ValueError(f"WORLD_SIZE 
must be positive, got {world_size}") + if 8 % world_size != 0: + raise ValueError(f"WORLD_SIZE={world_size} must divide 8 so grad_accum_steps stays integral") + grad_accum_steps = 8 // world_size + grad_scale = 1.0 / grad_accum_steps + if not torch.cuda.is_available(): + raise RuntimeError("CUDA is required") + device = torch.device("cuda", local_rank) + torch.cuda.set_device(device) + if distributed: + dist.init_process_group(backend="nccl", device_id=device) + dist.barrier() + master_process = rank == 0 + + # Fast math knobs + torch.backends.cuda.matmul.allow_tf32 = True + torch.backends.cudnn.allow_tf32 = True + from torch.backends.cuda import enable_cudnn_sdp, enable_flash_sdp, enable_math_sdp, enable_mem_efficient_sdp + + enable_cudnn_sdp(False) + enable_flash_sdp(True) + enable_mem_efficient_sdp(False) + enable_math_sdp(False) + + logfile = None + if master_process: + os.makedirs("logs", exist_ok=True) + logfile = f"logs/{args.run_id}.txt" + print(logfile) + + def log0(msg: str, console: bool = True) -> None: + if not master_process: + return + if console: + print(msg) + if logfile is not None: + with open(logfile, "a", encoding="utf-8") as f: + print(msg, file=f) + + log0(code, console=False) + log0("=" * 100, console=False) + log0(f"Running Python {sys.version}", console=False) + log0(f"Running PyTorch {torch.__version__}", console=False) + log0( + subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, check=False).stdout, + console=False, + ) + log0("=" * 100, console=False) + + # ----------------------------- + # TOKENIZER + VALIDATION METRIC SETUP + # ----------------------------- + + random.seed(args.seed) + np.random.seed(args.seed) + torch.manual_seed(args.seed) + torch.cuda.manual_seed_all(args.seed) + + if not args.tokenizer_path.endswith(".model"): + raise ValueError(f"Script only setup for SentencePiece .model file: {args.tokenizer_path}") + sp = spm.SentencePieceProcessor(model_file=args.tokenizer_path) + if 
int(sp.vocab_size()) != args.vocab_size: + raise ValueError( + f"VOCAB_SIZE={args.vocab_size} does not match tokenizer vocab_size={int(sp.vocab_size())}" + ) + dataset_dir = Path(args.data_path).resolve() + actual_train_files = len(list(dataset_dir.glob("fineweb_train_*.bin"))) + val_tokens = load_validation_tokens(args.val_files, args.train_seq_len) + base_bytes_lut, has_leading_space_lut, is_boundary_token_lut = build_sentencepiece_luts( + sp, args.vocab_size, device + ) + log0(f"val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path={args.tokenizer_path}") + log0(f"train_loader:dataset:{dataset_dir.name} train_shards:{actual_train_files}") + log0(f"val_loader:shards pattern={args.val_files} tokens:{val_tokens.numel() - 1}") + + # ----------------------------- + # MODEL + OPTIMIZER SETUP + # ----------------------------- + + base_model = GPT( + vocab_size=args.vocab_size, + num_layers=args.num_layers, + model_dim=args.model_dim, + num_heads=args.num_heads, + num_kv_heads=args.num_kv_heads, + mlp_mult=args.mlp_mult, + tie_embeddings=args.tie_embeddings, + tied_embed_init_std=args.tied_embed_init_std, + logit_softcap=args.logit_softcap, + rope_base=args.rope_base, + qk_gain_init=args.qk_gain_init, + num_prelude=args.num_prelude, + num_coda=args.num_coda, + num_shared=args.num_shared, + num_loops=args.num_loops, + use_peri_norm=args.use_peri_norm, + use_birkhoff_mix=args.use_birkhoff_mix, + use_timestep_scale=args.use_timestep_scale, + use_timestep_bias=args.use_timestep_bias, + use_depth_embed=args.use_depth_embed, + use_unique_norms=args.use_unique_norms, + timestep_gamma_max=args.timestep_gamma_max, + leaky_relu_slope=args.leaky_relu_slope, + ).to(device).bfloat16() + for module in base_model.modules(): + if isinstance(module, CastedLinear): + module.float() + restore_low_dim_params_to_fp32(base_model) + if args.disable_compile: + compiled_model = base_model + else: + compiled_model = torch.compile(base_model, dynamic=False, fullgraph=True) + model: 
nn.Module = DDP(compiled_model, device_ids=[local_rank], broadcast_buffers=False) if distributed else compiled_model + + # Optimizer split: + # - token embedding (Adam) uses EMBED_LR + # - untied lm_head (Adam) uses HEAD_LR + # - matrix params in transformer blocks use MATRIX_LR via Muon + # - vectors/scalars use SCALAR_LR via Adam + if base_model.use_recurrence: + all_block_lists = [base_model.prelude_blocks, base_model.shared_blocks, base_model.coda_blocks] + block_named_params = [] + for bl in all_block_lists: + block_named_params.extend(bl.named_parameters()) + if base_model.use_unique_norms: + block_named_params.extend([("unique_attn_gains", base_model.unique_attn_gains)]) + block_named_params.extend([("unique_mlp_gains", base_model.unique_mlp_gains)]) + else: + block_named_params = list(base_model.blocks.named_parameters()) + matrix_params = [ + p + for name, p in block_named_params + if p.ndim == 2 and not any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) + ] + scalar_params = [ + p + for name, p in block_named_params + if p.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) + ] + if hasattr(base_model, "skip_weights") and base_model.skip_weights.numel() > 0: + scalar_params.append(base_model.skip_weights) + if base_model.timestep_scale is not None: + for p in base_model.timestep_scale.parameters(): + scalar_params.append(p) + if base_model.use_depth_embed: + scalar_params.append(base_model.depth_embeddings) + token_lr = args.tied_embed_lr if args.tie_embeddings else args.embed_lr + optimizer_tok = torch.optim.Adam( + [{"params": [base_model.tok_emb.weight], "lr": token_lr, "base_lr": token_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=True, + ) + optimizer_muon = Muon( + matrix_params, + lr=args.matrix_lr, + momentum=args.muon_momentum, + backend_steps=args.muon_backend_steps, + ) + for group in optimizer_muon.param_groups: + group["base_lr"] = args.matrix_lr + optimizer_scalar = 
torch.optim.Adam( + [{"params": scalar_params, "lr": args.scalar_lr, "base_lr": args.scalar_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=True, + ) + optimizers: list[torch.optim.Optimizer] = [optimizer_tok, optimizer_muon, optimizer_scalar] + if base_model.lm_head is not None: + optimizer_head = torch.optim.Adam( + [{"params": [base_model.lm_head.weight], "lr": args.head_lr, "base_lr": args.head_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=True, + ) + optimizers.insert(1, optimizer_head) + + n_params = sum(p.numel() for p in base_model.parameters()) + log0(f"model_params:{n_params}") + log0(f"world_size:{world_size} grad_accum_steps:{grad_accum_steps}") + log0("sdp_backends:cudnn=False flash=True mem_efficient=False math=False") + log0(f"flash_attention_3:{FA3_AVAILABLE}") + log0(f"attention_mode:gqa num_heads:{args.num_heads} num_kv_heads:{args.num_kv_heads}") + log0( + f"tie_embeddings:{args.tie_embeddings} embed_lr:{token_lr} " + f"head_lr:{args.head_lr if base_model.lm_head is not None else 0.0} " + f"matrix_lr:{args.matrix_lr} scalar_lr:{args.scalar_lr}" + ) + log0( + f"train_batch_tokens:{args.train_batch_tokens} train_seq_len:{args.train_seq_len} " + f"iterations:{args.iterations} warmup_steps:{args.warmup_steps} " + f"max_wallclock_seconds:{args.max_wallclock_seconds:.3f}" + ) + log0(f"seed:{args.seed}") + if base_model.use_recurrence: + num_shared = len(base_model.shared_blocks) + eff = base_model.num_prelude + num_shared * base_model.num_loops + base_model.num_coda + log0(f"recurrence:enabled prelude:{base_model.num_prelude} shared:{num_shared} " + f"loops:{base_model.num_loops} coda:{base_model.num_coda} effective_layers:{eff}") + log0(f"peri_norm:{args.use_peri_norm} birkhoff_mix:{args.use_birkhoff_mix}") + if base_model.timestep_scale is not None: + ts_params = sum(p.numel() for p in base_model.timestep_scale.parameters()) + log0(f"timestep_scale:enabled params:{ts_params}") + else: + 
log0("timestep_scale:disabled") + log0(f"depth_embed:{'enabled' if base_model.use_depth_embed else 'disabled'}") + log0(f"unique_norms:{'enabled' if base_model.use_unique_norms else 'disabled'}") + else: + log0(f"recurrence:disabled num_layers:{args.num_layers}") + compile_mode = "disabled" if args.disable_compile else "fullgraph=True" + log0(f"compile_mode:{compile_mode}") + + # ----------------------------- + # DATA LOADER & MODEL WARMUP + # ----------------------------- + + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + + def zero_grad_all() -> None: + for opt in optimizers: + opt.zero_grad(set_to_none=True) + + max_wallclock_ms = 1000.0 * args.max_wallclock_seconds if args.max_wallclock_seconds > 0 else None + + def lr_mul(step: int, elapsed_ms: float) -> float: + if args.warmdown_iters <= 0: + return 1.0 + if max_wallclock_ms is None: + warmdown_start = max(args.iterations - args.warmdown_iters, 0) + return max((args.iterations - step) / max(args.warmdown_iters, 1), 0.0) if warmdown_start <= step < args.iterations else 1.0 + step_ms = elapsed_ms / max(step, 1) + warmdown_ms = args.warmdown_iters * step_ms + remaining_ms = max(max_wallclock_ms - elapsed_ms, 0.0) + return remaining_ms / max(warmdown_ms, 1e-9) if remaining_ms <= warmdown_ms else 1.0 + + # Warmup primes the compiled forward/backward/optimizer paths, then we restore the + # initial weights/optimizer state so measured training starts from the true init. 
+ if args.warmup_steps > 0: + initial_model_state = {name: tensor.detach().cpu().clone() for name, tensor in base_model.state_dict().items()} + initial_optimizer_states = [copy.deepcopy(opt.state_dict()) for opt in optimizers] + model.train() + for warmup_step in range(args.warmup_steps): + zero_grad_all() + for micro_step in range(grad_accum_steps): + if distributed: + model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 + x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + warmup_loss = model(x, y) + (warmup_loss * grad_scale).backward() + for opt in optimizers: + opt.step() + zero_grad_all() + if args.warmup_steps <= 20 or (warmup_step + 1) % 10 == 0 or warmup_step + 1 == args.warmup_steps: + log0(f"warmup_step:{warmup_step + 1}/{args.warmup_steps}") + base_model.load_state_dict(initial_model_state, strict=True) + for opt, state in zip(optimizers, initial_optimizer_states, strict=True): + opt.load_state_dict(state) + zero_grad_all() + if distributed: + model.require_backward_grad_sync = True + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + + # ----------------------------- + # MAIN TRAINING LOOP + # ----------------------------- + + training_time_ms = 0.0 + stop_after_step: int | None = None + torch.cuda.synchronize() + t0 = time.perf_counter() + + step = 0 + while True: + last_step = step == args.iterations or (stop_after_step is not None and step >= stop_after_step) + + should_validate = last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0) + if should_validate: + torch.cuda.synchronize() + training_time_ms += 1000.0 * (time.perf_counter() - t0) + val_loss, val_bpb = eval_val( + args, + model, + rank, + world_size, + device, + grad_accum_steps, + val_tokens, + base_bytes_lut, + has_leading_space_lut, + is_boundary_token_lut, + ) + log0( + 
f"step:{step}/{args.iterations} val_loss:{val_loss:.4f} val_bpb:{val_bpb:.4f} " + f"train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms / max(step, 1):.2f}ms" + ) + torch.cuda.synchronize() + t0 = time.perf_counter() + + if last_step: + if stop_after_step is not None and step < args.iterations: + log0( + f"stopping_early: wallclock_cap train_time:{training_time_ms:.0f}ms " + f"step:{step}/{args.iterations}" + ) + break + + elapsed_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + scale = lr_mul(step, elapsed_ms) + zero_grad_all() + train_loss = torch.zeros((), device=device) + for micro_step in range(grad_accum_steps): + if distributed: + model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 + x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + loss = model(x, y) + train_loss += loss.detach() + (loss * grad_scale).backward() + train_loss /= grad_accum_steps + + frac = min(step / args.muon_momentum_warmup_steps, 1.0) if args.muon_momentum_warmup_steps > 0 else 1.0 + muon_momentum = (1 - frac) * args.muon_momentum_warmup_start + frac * args.muon_momentum + for group in optimizer_muon.param_groups: + group["momentum"] = muon_momentum + + for opt in optimizers: + for group in opt.param_groups: + group["lr"] = group["base_lr"] * scale + + if args.grad_clip_norm > 0: + torch.nn.utils.clip_grad_norm_(base_model.parameters(), args.grad_clip_norm) + for opt in optimizers: + opt.step() + zero_grad_all() + + step += 1 + approx_training_time_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + should_log_train = ( + args.train_log_every > 0 + and (step <= 10 or step % args.train_log_every == 0 or stop_after_step is not None) + ) + if should_log_train: + log0( + f"step:{step}/{args.iterations} train_loss:{train_loss.item():.4f} " + f"train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms 
/ step:.2f}ms" + ) + if base_model.use_recurrence and step % args.train_log_every == 0: + diag = recurrence_param_diagnostics(base_model) + if diag: + log0(f"step:{step} {diag}") + + # Needed to sync whether we've reached the wallclock cap. + reached_cap = max_wallclock_ms is not None and approx_training_time_ms >= max_wallclock_ms + if distributed and max_wallclock_ms is not None: + reached_cap_tensor = torch.tensor(int(reached_cap), device=device) + dist.all_reduce(reached_cap_tensor, op=dist.ReduceOp.MAX) + reached_cap = bool(reached_cap_tensor.item()) + if stop_after_step is None and reached_cap: + stop_after_step = step + + log0( + f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB " + f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB" + ) + + # ----------------------------- + # SERIALIZATION + ROUNDTRIP VALIDATION + # ----------------------------- + # Save the raw state (useful for debugging/loading in PyTorch directly), then always produce + # the compressed int8+zlib artifact and validate the round-tripped weights. 
+ + if master_process: + torch.save(base_model.state_dict(), "final_model.pt") + model_bytes = os.path.getsize("final_model.pt") + code_bytes = len(code.encode("utf-8")) + log0(f"Serialized model: {model_bytes} bytes") + log0(f"Code size: {code_bytes} bytes") + log0(f"Total submission size: {model_bytes + code_bytes} bytes") + + quant_obj, quant_stats = quantize_state_dict_int8(base_model.state_dict()) + quant_buf = io.BytesIO() + torch.save(quant_obj, quant_buf) + quant_raw = quant_buf.getvalue() + quant_blob = zlib.compress(quant_raw, level=9) + quant_raw_bytes = len(quant_raw) + if master_process: + with open("final_model.int8.ptz", "wb") as f: + f.write(quant_blob) + quant_file_bytes = os.path.getsize("final_model.int8.ptz") + code_bytes = len(code.encode("utf-8")) + ratio = quant_stats["baseline_tensor_bytes"] / max(quant_stats["int8_payload_bytes"], 1) + log0( + f"Serialized model int8+zlib: {quant_file_bytes} bytes " + f"(payload:{quant_stats['int8_payload_bytes']} raw_torch:{quant_raw_bytes} payload_ratio:{ratio:.2f}x)" + ) + log0(f"Total submission size int8+zlib: {quant_file_bytes + code_bytes} bytes") + + if distributed: + dist.barrier() + with open("final_model.int8.ptz", "rb") as f: + quant_blob_disk = f.read() + quant_state = torch.load(io.BytesIO(zlib.decompress(quant_blob_disk)), map_location="cpu") + base_model.load_state_dict(dequantize_state_dict_int8(quant_state), strict=True) + torch.cuda.synchronize() + t_qeval = time.perf_counter() + q_val_loss, q_val_bpb = eval_val( + args, + model, + rank, + world_size, + device, + grad_accum_steps, + val_tokens, + base_bytes_lut, + has_leading_space_lut, + is_boundary_token_lut, + ) + torch.cuda.synchronize() + log0( + f"final_int8_zlib_roundtrip val_loss:{q_val_loss:.4f} val_bpb:{q_val_bpb:.4f} " + f"eval_time:{1000.0 * (time.perf_counter() - t_qeval):.0f}ms" + ) + log0(f"final_int8_zlib_roundtrip_exact val_loss:{q_val_loss:.8f} val_bpb:{q_val_bpb:.8f}") + + if distributed: + dist.destroy_process_group() 
+ + +if __name__ == "__main__": + main() + +==================================================================================================== +Running Python 3.12.3 (main, Nov 6 2025, 13:44:16) [GCC 13.3.0] +Running PyTorch 2.9.1+cu128 +Thu Apr 2 15:16:42 2026 ++-----------------------------------------------------------------------------------------+ +| NVIDIA-SMI 580.126.09 Driver Version: 580.126.09 CUDA Version: 13.0 | ++-----------------------------------------+------------------------+----------------------+ +| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC | +| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. | +| | | MIG M. | +|=========================================+========================+======================| +| 0 NVIDIA H100 80GB HBM3 On | 00000000:19:00.0 Off | 0 | +| N/A 45C P0 125W / 700W | 1521MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 1 NVIDIA H100 80GB HBM3 On | 00000000:3B:00.0 Off | 0 | +| N/A 36C P0 116W / 700W | 1521MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 2 NVIDIA H100 80GB HBM3 On | 00000000:4C:00.0 Off | 0 | +| N/A 35C P0 123W / 700W | 1521MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 3 NVIDIA H100 80GB HBM3 On | 00000000:5D:00.0 Off | 0 | +| N/A 44C P0 126W / 700W | 1521MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 4 NVIDIA H100 80GB HBM3 On | 00000000:9B:00.0 Off | 0 | +| N/A 45C P0 126W / 700W | 1521MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 5 NVIDIA H100 80GB HBM3 On | 00000000:BB:00.0 Off | 0 | +| N/A 35C 
P0 121W / 700W | 1521MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 6 NVIDIA H100 80GB HBM3 On | 00000000:CB:00.0 Off | 0 | +| N/A 43C P0 122W / 700W | 1521MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 7 NVIDIA H100 80GB HBM3 On | 00000000:DB:00.0 Off | 0 | +| N/A 34C P0 119W / 700W | 1521MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ + ++-----------------------------------------------------------------------------------------+ +| Processes: | +| GPU GI CI PID Type Process name GPU Memory | +| ID ID Usage | +|=========================================================================================| +| No running processes found | ++-----------------------------------------------------------------------------------------+ + +==================================================================================================== +val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path=./data/tokenizers/fineweb_1024_bpe.model +train_loader:dataset:fineweb10B_sp1024 train_shards:80 +val_loader:shards pattern=./data/datasets/fineweb10B_sp1024/fineweb_val_*.bin tokens:62021632 +model_params:11584560 +world_size:8 grad_accum_steps:1 +sdp_backends:cudnn=False flash=True mem_efficient=False math=False +flash_attention_3:True +attention_mode:gqa num_heads:8 num_kv_heads:4 +tie_embeddings:True embed_lr:0.05 head_lr:0.0 matrix_lr:0.04 scalar_lr:0.04 +train_batch_tokens:524288 train_seq_len:1024 iterations:20000 warmup_steps:20 max_wallclock_seconds:600.000 +seed:1337 +recurrence:enabled prelude:1 shared:4 loops:3 coda:1 effective_layers:14 +peri_norm:True birkhoff_mix:True +timestep_scale:enabled params:28672 +depth_embed:disabled +unique_norms:enabled +compile_mode:fullgraph=True 
+warmup_step:1/20 +warmup_step:2/20 +warmup_step:3/20 +warmup_step:4/20 +warmup_step:5/20 +warmup_step:6/20 +warmup_step:7/20 +warmup_step:8/20 +warmup_step:9/20 +warmup_step:10/20 +warmup_step:11/20 +warmup_step:12/20 +warmup_step:13/20 +warmup_step:14/20 +warmup_step:15/20 +warmup_step:16/20 +warmup_step:17/20 +warmup_step:18/20 +warmup_step:19/20 +warmup_step:20/20 +step:0/20000 val_loss:6.9379 val_bpb:4.1090 train_time:0ms step_avg:0.02ms +step:1/20000 train_loss:6.9377 train_time:44ms step_avg:44.22ms +step:2/20000 train_loss:9.3723 train_time:101ms step_avg:50.72ms +step:3/20000 train_loss:8.1273 train_time:162ms step_avg:54.12ms +step:4/20000 train_loss:9.5869 train_time:227ms step_avg:56.74ms +step:5/20000 train_loss:8.8318 train_time:287ms step_avg:57.46ms +step:6/20000 train_loss:8.4175 train_time:347ms step_avg:57.86ms +step:7/20000 train_loss:7.5439 train_time:409ms step_avg:58.48ms +step:8/20000 train_loss:6.6940 train_time:472ms step_avg:59.01ms +step:9/20000 train_loss:6.0419 train_time:531ms step_avg:58.99ms +step:10/20000 train_loss:5.7016 train_time:593ms step_avg:59.27ms +step:200/20000 train_loss:2.7785 train_time:12293ms step_avg:61.46ms +step:200 shared0_alpha:mean=0.459,std=0.052 shared1_alpha:mean=0.473,std=0.043 shared2_alpha:mean=0.483,std=0.043 shared3_alpha:mean=0.507,std=0.044 eff_mlp_scale:[v0:35.6677 v1:26.8995 v2:27.6057 v3:30.0749 v4:32.3163 v5:31.1391 v6:28.0304 v7:29.3341 v8:31.5281 v9:30.7005 v10:28.8798 v11:31.5564 v12:35.4692 v13:59.2673] eff_attn_scale:[v0:14.5629 v1:11.6750 v2:11.8347 v3:11.1855 v4:11.5008 v5:11.0696 v6:11.0629 v7:10.4652 v8:10.9429 v9:11.3290 v10:11.0629 v11:11.0160 v12:11.4150 v13:15.7221] eff_attn_bias:[v0:0.1643 v1:0.1215 v2:0.1222 v3:0.1195 v4:0.1353 v5:0.1367 v6:0.1340 v7:0.1264 v8:0.1374 v9:0.1367 v10:0.1298 v11:0.1160 v12:0.1132 v13:0.1188] eff_mlp_bias:[v0:0.1222 v1:0.1139 v2:0.1139 v3:0.1250 v4:0.1271 v5:0.1250 v6:0.1181 v7:0.1243 v8:0.1284 v9:0.1188 v10:0.1077 v11:0.1029 v12:0.1195 v13:0.1961] 
unique_attn_gain_rms:[u0:0.8801 u1:0.8928 u2:0.8597 u3:0.8440 u4:0.8574 u5:0.8582 u6:0.8285 u7:0.8435 u8:0.8930 u9:0.9009 u10:0.8857 u11:0.8788] unique_mlp_gain_rms:[u0:1.1011 u1:1.0895 u2:1.1097 u3:1.1019 u4:1.1015 u5:1.0941 u6:1.1026 u7:1.0987 u8:1.1185 u9:1.1145 u10:1.1245 u11:1.1394] +step:200/20000 val_loss:2.7660 val_bpb:1.6382 train_time:12354ms step_avg:61.77ms +step:400/20000 train_loss:2.3635 train_time:24693ms step_avg:61.73ms +step:400 shared0_alpha:mean=0.469,std=0.061 shared1_alpha:mean=0.489,std=0.052 shared2_alpha:mean=0.507,std=0.050 shared3_alpha:mean=0.535,std=0.049 eff_mlp_scale:[v0:43.8384 v1:32.7964 v2:35.2771 v3:39.0784 v4:39.3739 v5:43.9575 v6:38.8702 v7:39.5816 v8:39.2042 v9:41.8970 v10:37.4003 v11:38.9107 v12:39.7133 v13:75.0169] eff_attn_scale:[v0:6.5474 v1:7.2248 v2:7.1589 v3:6.7523 v4:7.0217 v5:7.3912 v6:7.1272 v7:6.6909 v8:6.9910 v9:7.4245 v10:6.5570 v11:6.1998 v12:6.2551 v13:9.4516] eff_attn_bias:[v0:0.2072 v1:0.1664 v2:0.1699 v3:0.1754 v4:0.1947 v5:0.1892 v6:0.1878 v7:0.1719 v8:0.1864 v9:0.1678 v10:0.1630 v11:0.1416 v12:0.1360 v13:0.1340] eff_mlp_bias:[v0:0.2279 v1:0.1512 v2:0.1485 v3:0.1623 v4:0.1699 v5:0.1574 v6:0.1422 v7:0.1457 v8:0.1581 v9:0.1464 v10:0.1367 v11:0.1312 v12:0.1443 v13:0.2693] unique_attn_gain_rms:[u0:0.7977 u1:0.8251 u2:0.7979 u3:0.7747 u4:0.8084 u5:0.8172 u6:0.7886 u7:0.8044 u8:0.8473 u9:0.8523 u10:0.8311 u11:0.8070] unique_mlp_gain_rms:[u0:1.1993 u1:1.1854 u2:1.1919 u3:1.1783 u4:1.1719 u5:1.1597 u6:1.1575 u7:1.1569 u8:1.1780 u9:1.1666 u10:1.1685 u11:1.1880] +step:400/20000 val_loss:2.5650 val_bpb:1.5192 train_time:24719ms step_avg:61.80ms +step:600/20000 train_loss:2.5820 train_time:37086ms step_avg:61.81ms +step:600 shared0_alpha:mean=0.476,std=0.066 shared1_alpha:mean=0.497,std=0.058 shared2_alpha:mean=0.523,std=0.053 shared3_alpha:mean=0.554,std=0.050 eff_mlp_scale:[v0:50.4095 v1:37.5619 v2:40.2017 v3:44.4413 v4:43.3103 v5:51.8096 v6:46.1445 v7:46.5745 v8:44.1906 v9:48.4790 v10:42.1244 v11:42.8414 v12:41.9018 
v13:90.8821] eff_attn_scale:[v0:3.0899 v1:4.6705 v2:4.8356 v3:4.8769 v4:5.0717 v5:5.0811 v6:5.0565 v7:4.7585 v8:5.0477 v9:5.0811 v10:4.4183 v11:4.0009 v12:4.1387 v13:6.4956] eff_attn_bias:[v0:0.2610 v1:0.2237 v2:0.2141 v3:0.2320 v4:0.2486 v5:0.2472 v6:0.2472 v7:0.2293 v8:0.2444 v9:0.2141 v10:0.2016 v11:0.1713 v12:0.1595 v13:0.1637] eff_mlp_bias:[v0:0.3466 v1:0.2099 v2:0.1947 v3:0.2003 v4:0.2224 v5:0.2058 v6:0.1795 v7:0.1837 v8:0.2072 v9:0.1878 v10:0.1754 v11:0.1678 v12:0.1795 v13:0.2817] unique_attn_gain_rms:[u0:0.7361 u1:0.7815 u2:0.7676 u3:0.7433 u4:0.7740 u5:0.7936 u6:0.7625 u7:0.7731 u8:0.8108 u9:0.8209 u10:0.7913 u11:0.7569] unique_mlp_gain_rms:[u0:1.3095 u1:1.2884 u2:1.2755 u3:1.2586 u4:1.2493 u5:1.2322 u6:1.2188 u7:1.2190 u8:1.2468 u9:1.2269 u10:1.2203 u11:1.2402] +step:600/20000 val_loss:2.4815 val_bpb:1.4697 train_time:37113ms step_avg:61.85ms +step:800/20000 train_loss:2.3365 train_time:49510ms step_avg:61.89ms +step:800 shared0_alpha:mean=0.480,std=0.070 shared1_alpha:mean=0.502,std=0.060 shared2_alpha:mean=0.533,std=0.055 shared3_alpha:mean=0.567,std=0.051 eff_mlp_scale:[v0:57.1354 v1:42.2276 v2:44.0571 v3:48.2342 v4:46.0004 v5:57.9899 v6:51.9179 v7:51.5480 v8:48.1736 v9:53.7088 v10:45.7024 v11:45.4727 v12:44.3705 v13:103.6657] eff_attn_scale:[v0:2.0493 v1:3.6577 v2:4.1966 v3:4.2121 v4:4.3366 v5:4.1367 v6:4.2397 v7:3.9463 v8:4.2121 v9:3.9843 v10:3.5725 v11:3.1284 v12:3.2577 v13:5.0813] eff_attn_bias:[v0:0.3052 v1:0.2762 v2:0.2583 v3:0.2790 v4:0.2955 v5:0.3011 v6:0.2983 v7:0.2817 v8:0.2955 v9:0.2541 v10:0.2389 v11:0.2003 v12:0.1851 v13:0.1975] eff_mlp_bias:[v0:0.4502 v1:0.2638 v2:0.2375 v3:0.2389 v4:0.2707 v5:0.2527 v6:0.2168 v7:0.2224 v8:0.2569 v9:0.2306 v10:0.2141 v11:0.2030 v12:0.2154 v13:0.2928] unique_attn_gain_rms:[u0:0.7073 u1:0.7766 u2:0.7686 u3:0.7249 u4:0.7584 u5:0.7768 u6:0.7486 u7:0.7511 u8:0.7831 u9:0.7967 u10:0.7603 u11:0.7180] unique_mlp_gain_rms:[u0:1.4130 u1:1.3814 u2:1.3605 u3:1.3367 u4:1.3210 u5:1.3006 u6:1.2828 u7:1.2842 u8:1.3080 
u9:1.2901 u10:1.2758 u11:1.2962] +step:800/20000 val_loss:2.4235 val_bpb:1.4354 train_time:49537ms step_avg:61.92ms +step:1000/20000 train_loss:2.4168 train_time:61939ms step_avg:61.94ms +step:1000 shared0_alpha:mean=0.482,std=0.074 shared1_alpha:mean=0.505,std=0.064 shared2_alpha:mean=0.540,std=0.056 shared3_alpha:mean=0.577,std=0.053 eff_mlp_scale:[v0:62.3978 v1:45.6620 v2:47.2734 v3:51.1610 v4:48.7781 v5:62.6337 v6:55.7488 v7:55.7086 v8:51.3846 v9:57.7846 v10:48.5918 v11:48.1292 v12:46.7302 v13:114.1700] eff_attn_scale:[v0:1.5138 v1:3.4767 v2:4.3609 v3:4.2604 v4:4.2271 v5:3.8653 v6:4.1512 v7:3.7523 v8:3.8944 v9:3.5994 v10:3.3545 v11:2.8728 v12:2.9551 v13:4.2356] eff_attn_bias:[v0:0.3411 v1:0.3176 v2:0.2914 v3:0.3149 v4:0.3342 v5:0.3466 v6:0.3397 v7:0.3259 v8:0.3411 v9:0.2886 v10:0.2721 v11:0.2279 v12:0.2099 v13:0.2306] eff_mlp_bias:[v0:0.5303 v1:0.3025 v2:0.2762 v3:0.2721 v4:0.3066 v5:0.2942 v6:0.2472 v7:0.2527 v8:0.2983 v9:0.2665 v10:0.2444 v11:0.2334 v12:0.2486 v13:0.3038] unique_attn_gain_rms:[u0:0.7041 u1:0.7984 u2:0.7866 u3:0.7261 u4:0.7514 u5:0.7745 u6:0.7436 u7:0.7377 u8:0.7675 u9:0.7786 u10:0.7349 u11:0.6882] unique_mlp_gain_rms:[u0:1.5083 u1:1.4752 u2:1.4415 u3:1.4065 u4:1.3884 u5:1.3694 u6:1.3428 u7:1.3508 u8:1.3718 u9:1.3538 u10:1.3339 u11:1.3551] +step:1000/20000 val_loss:2.3836 val_bpb:1.4117 train_time:61966ms step_avg:61.97ms +step:1200/20000 train_loss:2.4374 train_time:74481ms step_avg:62.07ms +step:1200 shared0_alpha:mean=0.483,std=0.077 shared1_alpha:mean=0.506,std=0.066 shared2_alpha:mean=0.545,std=0.058 shared3_alpha:mean=0.585,std=0.055 eff_mlp_scale:[v0:68.1810 v1:48.3162 v2:49.6499 v3:53.2952 v4:50.5253 v5:66.3570 v6:59.2719 v7:58.3157 v8:53.9443 v9:61.3802 v10:50.8045 v11:50.2056 v12:49.0058 v13:123.5359] eff_attn_scale:[v0:1.2055 v1:3.5762 v2:4.9226 v3:4.7608 v4:4.4796 v5:4.0667 v6:4.4282 v7:3.8880 v8:3.9148 v9:3.5353 v10:3.3534 v11:2.8565 v12:2.8825 v13:3.6667] eff_attn_bias:[v0:0.3674 v1:0.3591 v2:0.3149 v3:0.3453 v4:0.3701 v5:0.3812 
v6:0.3784 v7:0.3646 v8:0.3757 v9:0.3204 v10:0.2969 v11:0.2527 v12:0.2348 v13:0.2638] eff_mlp_bias:[v0:0.5966 v1:0.3370 v2:0.3025 v3:0.2997 v4:0.3342 v5:0.3287 v6:0.2748 v7:0.2790 v8:0.3342 v9:0.2969 v10:0.2693 v11:0.2596 v12:0.2804 v13:0.3121] unique_attn_gain_rms:[u0:0.7287 u1:0.8440 u2:0.8298 u3:0.7493 u4:0.7672 u5:0.7838 u6:0.7453 u7:0.7327 u8:0.7539 u9:0.7588 u10:0.7138 u11:0.6614] unique_mlp_gain_rms:[u0:1.6000 u1:1.5610 u2:1.5149 u3:1.4798 u4:1.4519 u5:1.4360 u6:1.4034 u7:1.4133 u8:1.4335 u9:1.4155 u10:1.3906 u11:1.4127] +step:1200/20000 val_loss:2.3514 val_bpb:1.3926 train_time:74508ms step_avg:62.09ms +step:1400/20000 train_loss:2.4817 train_time:86908ms step_avg:62.08ms +step:1400 shared0_alpha:mean=0.485,std=0.080 shared1_alpha:mean=0.509,std=0.067 shared2_alpha:mean=0.550,std=0.059 shared3_alpha:mean=0.592,std=0.056 eff_mlp_scale:[v0:72.8331 v1:51.1460 v2:52.1633 v3:55.3544 v4:52.2239 v5:69.3218 v6:61.5762 v7:60.8506 v8:56.0923 v9:63.8268 v10:52.9477 v11:52.2138 v12:51.0634 v13:132.0159] eff_attn_scale:[v0:1.0523 v1:3.9309 v2:5.6888 v3:5.5653 v4:5.0608 v5:4.6693 v6:4.9806 v7:4.3970 v8:4.1903 v9:3.8441 v10:3.6326 v11:3.1438 v12:2.9960 v13:3.3811] eff_attn_bias:[v0:0.3812 v1:0.3895 v2:0.3411 v3:0.3701 v4:0.3950 v5:0.4088 v6:0.4088 v7:0.3950 v8:0.4060 v9:0.3480 v10:0.3190 v11:0.2721 v12:0.2569 v13:0.2955] eff_mlp_bias:[v0:0.6602 v1:0.3674 v2:0.3246 v3:0.3218 v4:0.3591 v5:0.3591 v6:0.2997 v7:0.3025 v8:0.3591 v9:0.3218 v10:0.2873 v11:0.2817 v12:0.3080 v13:0.3232] unique_attn_gain_rms:[u0:0.7655 u1:0.9014 u2:0.8747 u3:0.7805 u4:0.8005 u5:0.8040 u6:0.7555 u7:0.7366 u8:0.7513 u9:0.7490 u10:0.6988 u11:0.6431] unique_mlp_gain_rms:[u0:1.6861 u1:1.6453 u2:1.5851 u3:1.5494 u4:1.5147 u5:1.4981 u6:1.4601 u7:1.4731 u8:1.4936 u9:1.4742 u10:1.4446 u11:1.4699] +step:1400/20000 val_loss:2.3313 val_bpb:1.3807 train_time:86936ms step_avg:62.10ms +step:1600/20000 train_loss:2.1507 train_time:99326ms step_avg:62.08ms +step:1600 shared0_alpha:mean=0.486,std=0.081 
shared1_alpha:mean=0.510,std=0.068 shared2_alpha:mean=0.553,std=0.060 shared3_alpha:mean=0.597,std=0.058 eff_mlp_scale:[v0:77.0572 v1:53.2558 v2:54.2990 v3:57.1017 v4:54.0702 v5:71.8633 v6:63.8112 v7:63.0498 v8:57.5965 v9:65.8747 v10:53.9027 v11:53.1363 v12:52.8948 v13:138.4714] eff_attn_scale:[v0:0.9192 v1:4.2253 v2:6.5284 v3:6.3522 v4:5.7784 v5:5.3165 v6:5.7124 v7:5.0863 v8:4.7357 v9:4.3878 v10:4.2286 v11:3.6134 v12:3.3454 v13:3.1579] eff_attn_bias:[v0:0.3977 v1:0.4198 v2:0.3591 v3:0.3922 v4:0.4198 v5:0.4337 v6:0.4281 v7:0.4171 v8:0.4309 v9:0.3701 v10:0.3356 v11:0.2900 v12:0.2790 v13:0.3259] eff_mlp_bias:[v0:0.7043 v1:0.3922 v2:0.3439 v3:0.3439 v4:0.3784 v5:0.3895 v6:0.3176 v7:0.3204 v8:0.3812 v9:0.3480 v10:0.3025 v11:0.3025 v12:0.3384 v13:0.3356] unique_attn_gain_rms:[u0:0.8040 u1:0.9544 u2:0.9131 u3:0.8067 u4:0.8387 u5:0.8298 u6:0.7767 u7:0.7542 u8:0.7644 u9:0.7530 u10:0.6935 u11:0.6336] unique_mlp_gain_rms:[u0:1.7707 u1:1.7290 u2:1.6575 u3:1.6151 u4:1.5768 u5:1.5639 u6:1.5175 u7:1.5335 u8:1.5540 u9:1.5322 u10:1.5008 u11:1.5301] +step:1600/20000 val_loss:2.3169 val_bpb:1.3722 train_time:99353ms step_avg:62.10ms +step:1800/20000 train_loss:2.2570 train_time:111730ms step_avg:62.07ms +step:1800 shared0_alpha:mean=0.486,std=0.083 shared1_alpha:mean=0.510,std=0.069 shared2_alpha:mean=0.555,std=0.061 shared3_alpha:mean=0.601,std=0.059 eff_mlp_scale:[v0:81.3569 v1:55.4127 v2:56.0650 v3:58.7604 v4:55.4168 v5:73.5950 v6:65.2757 v7:63.9569 v8:59.4037 v9:67.5342 v10:55.2641 v11:54.3634 v12:55.0182 v13:145.6469] eff_attn_scale:[v0:0.8010 v1:4.5068 v2:7.0850 v3:7.0971 v4:6.4064 v5:5.9598 v6:6.2775 v7:5.7664 v8:5.3619 v9:4.9747 v10:4.6886 v11:4.1400 v12:3.7603 v13:2.9589] eff_attn_bias:[v0:0.4060 v1:0.4447 v2:0.3757 v3:0.4171 v4:0.4447 v5:0.4558 v6:0.4502 v7:0.4392 v8:0.4558 v9:0.3922 v10:0.3508 v11:0.3066 v12:0.2997 v13:0.3563] eff_mlp_bias:[v0:0.7458 v1:0.4171 v2:0.3591 v3:0.3618 v4:0.3922 v5:0.4171 v6:0.3342 v7:0.3384 v8:0.4005 v9:0.3674 v10:0.3163 v11:0.3218 v12:0.3646 
v13:0.3480] unique_attn_gain_rms:[u0:0.8530 u1:0.9992 u2:0.9464 u3:0.8319 u4:0.8777 u5:0.8586 u6:0.8002 u7:0.7701 u8:0.7864 u9:0.7621 u10:0.6961 u11:0.6356] unique_mlp_gain_rms:[u0:1.8499 u1:1.8065 u2:1.7254 u3:1.6817 u4:1.6410 u5:1.6197 u6:1.5707 u7:1.5950 u8:1.6088 u9:1.5871 u10:1.5574 u11:1.5866] +step:1800/20000 val_loss:2.3022 val_bpb:1.3635 train_time:111758ms step_avg:62.09ms +step:2000/20000 train_loss:2.3032 train_time:124129ms step_avg:62.06ms +step:2000 shared0_alpha:mean=0.486,std=0.084 shared1_alpha:mean=0.511,std=0.070 shared2_alpha:mean=0.557,std=0.062 shared3_alpha:mean=0.604,std=0.061 eff_mlp_scale:[v0:85.7770 v1:57.2593 v2:57.9329 v3:60.5938 v4:57.4779 v5:75.6173 v6:67.2508 v7:65.8453 v8:61.5256 v9:69.4980 v10:56.7175 v11:55.7463 v12:56.6683 v13:152.5784] eff_attn_scale:[v0:0.7492 v1:4.7465 v2:7.6457 v3:7.7071 v4:7.1487 v5:6.4142 v6:6.8812 v7:6.5012 v8:6.0317 v9:5.5419 v10:5.2701 v11:4.7711 v12:4.2942 v13:2.7855] eff_attn_bias:[v0:0.4171 v1:0.4696 v2:0.3922 v3:0.4309 v4:0.4640 v5:0.4751 v6:0.4723 v7:0.4613 v8:0.4751 v9:0.4143 v10:0.3618 v11:0.3218 v12:0.3190 v13:0.3867] eff_mlp_bias:[v0:0.7844 v1:0.4419 v2:0.3784 v3:0.3757 v4:0.4033 v5:0.4447 v6:0.3494 v7:0.3536 v8:0.4198 v9:0.3922 v10:0.3273 v11:0.3397 v12:0.3922 v13:0.3618] unique_attn_gain_rms:[u0:0.9076 u1:1.0399 u2:0.9754 u3:0.8571 u4:0.9188 u5:0.8889 u6:0.8261 u7:0.7896 u8:0.8094 u9:0.7728 u10:0.7024 u11:0.6469] unique_mlp_gain_rms:[u0:1.9346 u1:1.8826 u2:1.7939 u3:1.7454 u4:1.7007 u5:1.6760 u6:1.6275 u7:1.6513 u8:1.6653 u9:1.6449 u10:1.6075 u11:1.6419] +step:2000/20000 val_loss:2.2850 val_bpb:1.3533 train_time:124158ms step_avg:62.08ms +step:2200/20000 train_loss:2.1260 train_time:136519ms step_avg:62.05ms +step:2200 shared0_alpha:mean=0.487,std=0.085 shared1_alpha:mean=0.512,std=0.070 shared2_alpha:mean=0.559,std=0.062 shared3_alpha:mean=0.608,std=0.062 eff_mlp_scale:[v0:90.0882 v1:59.2065 v2:59.2162 v3:61.4091 v4:59.2311 v5:77.7638 v6:69.0175 v7:67.1027 v8:63.3444 v9:71.1362 v10:57.5827 
v11:56.9356 v12:58.4085 v13:158.0963] eff_attn_scale:[v0:0.6839 v1:5.0144 v2:8.2573 v3:8.4018 v4:7.7752 v5:7.0093 v6:7.4545 v7:7.1777 v8:6.6983 v9:6.2005 v10:5.8489 v11:5.3972 v12:4.8070 v13:2.6813] eff_attn_bias:[v0:0.4171 v1:0.4917 v2:0.4088 v3:0.4475 v4:0.4834 v5:0.4972 v6:0.4889 v7:0.4778 v8:0.4944 v9:0.4337 v10:0.3757 v11:0.3342 v12:0.3370 v13:0.4116] eff_mlp_bias:[v0:0.8231 v1:0.4640 v2:0.3895 v3:0.3950 v4:0.4171 v5:0.4723 v6:0.3618 v7:0.3674 v8:0.4337 v9:0.4116 v10:0.3384 v11:0.3536 v12:0.4143 v13:0.3784] unique_attn_gain_rms:[u0:0.9487 u1:1.0803 u2:1.0052 u3:0.8770 u4:0.9562 u5:0.9182 u6:0.8486 u7:0.8082 u8:0.8368 u9:0.7903 u10:0.7128 u11:0.6652] unique_mlp_gain_rms:[u0:2.0127 u1:1.9539 u2:1.8585 u3:1.8063 u4:1.7567 u5:1.7340 u6:1.6809 u7:1.7076 u8:1.7270 u9:1.6991 u10:1.6627 u11:1.6974] +step:2200/20000 val_loss:2.2774 val_bpb:1.3488 train_time:136547ms step_avg:62.07ms +step:2400/20000 train_loss:2.2503 train_time:148901ms step_avg:62.04ms +step:2400 shared0_alpha:mean=0.487,std=0.086 shared1_alpha:mean=0.512,std=0.071 shared2_alpha:mean=0.560,std=0.064 shared3_alpha:mean=0.610,std=0.063 eff_mlp_scale:[v0:94.0930 v1:61.0268 v2:60.9092 v3:63.6205 v4:60.3936 v5:79.7358 v6:70.7864 v7:68.5460 v8:64.9752 v9:72.6086 v10:58.8514 v11:57.8741 v12:59.9771 v13:164.1867] eff_attn_scale:[v0:0.6263 v1:5.2008 v2:8.6560 v3:8.8990 v4:8.2428 v5:7.4536 v6:7.9446 v7:7.7433 v8:7.2056 v9:6.7583 v10:6.3438 v11:5.8942 v12:5.2131 v13:2.5763] eff_attn_bias:[v0:0.4254 v1:0.5138 v2:0.4171 v3:0.4640 v4:0.4972 v5:0.5165 v6:0.5027 v7:0.4972 v8:0.5110 v9:0.4502 v10:0.3867 v11:0.3466 v12:0.3536 v13:0.4392] eff_mlp_bias:[v0:0.8507 v1:0.4861 v2:0.4033 v3:0.4116 v4:0.4309 v5:0.4944 v6:0.3757 v7:0.3784 v8:0.4447 v9:0.4281 v10:0.3480 v11:0.3701 v12:0.4392 v13:0.3922] unique_attn_gain_rms:[u0:0.9886 u1:1.1125 u2:1.0329 u3:0.8992 u4:0.9883 u5:0.9437 u6:0.8695 u7:0.8290 u8:0.8626 u9:0.8057 u10:0.7216 u11:0.6828] unique_mlp_gain_rms:[u0:2.0867 u1:2.0197 u2:1.9218 u3:1.8687 u4:1.8154 u5:1.7856 
u6:1.7310 u7:1.7613 u8:1.7809 u9:1.7533 u10:1.7163 u11:1.7513] +step:2400/20000 val_loss:2.2660 val_bpb:1.3420 train_time:148928ms step_avg:62.05ms +step:2600/20000 train_loss:2.4554 train_time:161277ms step_avg:62.03ms +step:2600 shared0_alpha:mean=0.487,std=0.088 shared1_alpha:mean=0.512,std=0.071 shared2_alpha:mean=0.562,std=0.065 shared3_alpha:mean=0.611,std=0.064 eff_mlp_scale:[v0:97.7932 v1:62.5160 v2:62.6442 v3:64.9160 v4:62.0411 v5:81.8555 v6:72.1860 v7:69.8777 v8:66.6836 v9:73.7599 v10:60.1550 v11:59.1273 v12:61.6190 v13:169.5521] eff_attn_scale:[v0:0.5965 v1:5.3171 v2:8.9521 v3:9.4340 v4:8.6700 v5:7.8477 v6:8.2867 v7:8.2923 v8:7.6566 v9:7.1938 v10:6.7141 v11:6.4296 v12:5.5454 v13:2.4726] eff_attn_bias:[v0:0.4337 v1:0.5386 v2:0.4281 v3:0.4778 v4:0.5138 v5:0.5359 v6:0.5220 v7:0.5110 v8:0.5248 v9:0.4640 v10:0.3977 v11:0.3563 v12:0.3646 v13:0.4640] eff_mlp_bias:[v0:0.8839 v1:0.5138 v2:0.4143 v3:0.4281 v4:0.4447 v5:0.5165 v6:0.3867 v7:0.3867 v8:0.4585 v9:0.4475 v10:0.3563 v11:0.3812 v12:0.4558 v13:0.4060] unique_attn_gain_rms:[u0:1.0354 u1:1.1471 u2:1.0583 u3:0.9144 u4:1.0261 u5:0.9670 u6:0.8905 u7:0.8464 u8:0.8884 u9:0.8220 u10:0.7343 u11:0.7010] unique_mlp_gain_rms:[u0:2.1571 u1:2.0896 u2:1.9861 u3:1.9318 u4:1.8710 u5:1.8404 u6:1.7853 u7:1.8185 u8:1.8374 u9:1.8085 u10:1.7676 u11:1.8066] +step:2600/20000 val_loss:2.2718 val_bpb:1.3455 train_time:161304ms step_avg:62.04ms +step:2800/20000 train_loss:2.2854 train_time:173647ms step_avg:62.02ms +step:2800 shared0_alpha:mean=0.487,std=0.088 shared1_alpha:mean=0.512,std=0.072 shared2_alpha:mean=0.562,std=0.065 shared3_alpha:mean=0.614,std=0.065 eff_mlp_scale:[v0:101.7621 v1:64.3826 v2:64.2009 v3:66.2011 v4:63.5100 v5:83.4253 v6:73.3724 v7:71.6138 v8:68.1987 v9:75.2642 v10:60.8658 v11:59.9557 v12:63.0838 v13:174.3746] eff_attn_scale:[v0:0.5567 v1:5.4599 v2:9.4096 v3:9.7780 v4:9.0818 v5:8.2336 v6:8.7242 v7:8.7878 v8:8.2086 v9:7.6497 v10:7.1974 v11:6.9003 v12:5.9090 v13:2.3951] eff_attn_bias:[v0:0.4364 v1:0.5607 
v2:0.4364 v3:0.4889 v4:0.5303 v5:0.5497 v6:0.5359 v7:0.5276 v8:0.5414 v9:0.4778 v10:0.4088 v11:0.3674 v12:0.3812 v13:0.4889] eff_mlp_bias:[v0:0.9115 v1:0.5386 v2:0.4254 v3:0.4419 v4:0.4530 v5:0.5359 v6:0.3950 v7:0.3950 v8:0.4723 v9:0.4640 v10:0.3646 v11:0.3950 v12:0.4778 v13:0.4226] unique_attn_gain_rms:[u0:1.0817 u1:1.1744 u2:1.0809 u3:0.9305 u4:1.0617 u5:0.9931 u6:0.9117 u7:0.8659 u8:0.9176 u9:0.8393 u10:0.7491 u11:0.7244] unique_mlp_gain_rms:[u0:2.2243 u1:2.1586 u2:2.0484 u3:1.9918 u4:1.9278 u5:1.8956 u6:1.8357 u7:1.8745 u8:1.8953 u9:1.8635 u10:1.8236 u11:1.8627] +step:2800/20000 val_loss:2.2531 val_bpb:1.3344 train_time:173674ms step_avg:62.03ms +step:3000/20000 train_loss:2.2739 train_time:186010ms step_avg:62.00ms +step:3000 shared0_alpha:mean=0.486,std=0.089 shared1_alpha:mean=0.512,std=0.073 shared2_alpha:mean=0.564,std=0.065 shared3_alpha:mean=0.615,std=0.065 eff_mlp_scale:[v0:104.8985 v1:65.8132 v2:65.3638 v3:67.6042 v4:64.8209 v5:84.5516 v6:74.5817 v7:73.0629 v8:70.0065 v9:76.7820 v10:61.5928 v11:61.3057 v12:64.8209 v13:179.4815] eff_attn_scale:[v0:0.5227 v1:5.5853 v2:9.6196 v3:10.2131 v4:9.4947 v5:8.5561 v6:9.0500 v7:9.1981 v8:8.5990 v9:8.0808 v10:7.5312 v11:7.3268 v12:6.2104 v13:2.3091] eff_attn_bias:[v0:0.4419 v1:0.5828 v2:0.4475 v3:0.5055 v4:0.5414 v5:0.5635 v6:0.5524 v7:0.5359 v8:0.5524 v9:0.4917 v10:0.4198 v11:0.3784 v12:0.3950 v13:0.5138] eff_mlp_bias:[v0:0.9391 v1:0.5552 v2:0.4364 v3:0.4558 v4:0.4640 v5:0.5524 v6:0.4060 v7:0.4033 v8:0.4778 v9:0.4778 v10:0.3729 v11:0.4060 v12:0.4972 v13:0.4364] unique_attn_gain_rms:[u0:1.1173 u1:1.2029 u2:1.1004 u3:0.9536 u4:1.0901 u5:1.0130 u6:0.9312 u7:0.8851 u8:0.9452 u9:0.8543 u10:0.7614 u11:0.7456] unique_mlp_gain_rms:[u0:2.2940 u1:2.2214 u2:2.1026 u3:2.0504 u4:1.9812 u5:1.9483 u6:1.8882 u7:1.9294 u8:1.9513 u9:1.9162 u10:1.8747 u11:1.9153] +step:3000/20000 val_loss:2.2440 val_bpb:1.3290 train_time:186039ms step_avg:62.01ms +step:3200/20000 train_loss:2.2387 train_time:198381ms step_avg:61.99ms +step:3200 
shared0_alpha:mean=0.485,std=0.091 shared1_alpha:mean=0.512,std=0.073 shared2_alpha:mean=0.565,std=0.066 shared3_alpha:mean=0.618,std=0.066 eff_mlp_scale:[v0:108.8118 v1:67.1556 v2:66.7870 v3:68.9604 v4:66.4066 v5:86.0143 v6:76.0865 v7:74.0373 v8:71.2124 v9:78.1949 v10:62.9827 v11:62.1913 v12:66.4066 v13:183.5314] eff_attn_scale:[v0:0.5080 v1:5.6459 v2:9.9098 v3:10.4958 v4:9.8397 v5:8.7091 v6:9.3307 v7:9.5888 v8:8.9840 v9:8.4087 v10:7.8184 v11:7.7099 v12:6.5088 v13:2.2206] eff_attn_bias:[v0:0.4447 v1:0.5994 v2:0.4585 v3:0.5165 v4:0.5552 v5:0.5800 v6:0.5662 v7:0.5497 v8:0.5662 v9:0.5055 v10:0.4281 v11:0.3867 v12:0.4088 v13:0.5359] eff_mlp_bias:[v0:0.9667 v1:0.5800 v2:0.4447 v3:0.4668 v4:0.4723 v5:0.5690 v6:0.4143 v7:0.4143 v8:0.4861 v9:0.4917 v10:0.3812 v11:0.4143 v12:0.5138 v13:0.4502] unique_attn_gain_rms:[u0:1.1544 u1:1.2301 u2:1.1244 u3:0.9710 u4:1.1134 u5:1.0353 u6:0.9486 u7:0.9038 u8:0.9693 u9:0.8711 u10:0.7772 u11:0.7708] unique_mlp_gain_rms:[u0:2.3587 u1:2.2814 u2:2.1639 u3:2.1080 u4:2.0353 u5:1.9993 u6:1.9348 u7:1.9831 u8:2.0043 u9:1.9695 u10:1.9256 u11:1.9684] +step:3200/20000 val_loss:2.2392 val_bpb:1.3262 train_time:198406ms step_avg:62.00ms +step:3400/20000 train_loss:2.2054 train_time:210745ms step_avg:61.98ms +step:3400 shared0_alpha:mean=0.485,std=0.091 shared1_alpha:mean=0.512,std=0.074 shared2_alpha:mean=0.565,std=0.066 shared3_alpha:mean=0.620,std=0.067 eff_mlp_scale:[v0:112.7855 v1:68.4877 v2:68.0438 v3:70.1929 v4:68.1581 v5:87.9234 v6:77.3998 v7:75.2978 v8:73.0266 v9:79.1310 v10:63.7910 v11:63.3863 v12:68.1581 v13:188.5526] eff_attn_scale:[v0:0.4852 v1:5.7560 v2:10.1801 v3:10.7700 v4:10.0916 v5:9.0146 v6:9.5927 v7:9.9771 v8:9.4064 v9:8.7710 v10:8.1571 v11:8.0609 v12:6.7900 v13:2.1644] eff_attn_bias:[v0:0.4530 v1:0.6215 v2:0.4723 v3:0.5303 v4:0.5690 v5:0.5939 v6:0.5773 v7:0.5635 v8:0.5773 v9:0.5165 v10:0.4392 v11:0.3977 v12:0.4226 v13:0.5552] eff_mlp_bias:[v0:0.9888 v1:0.5994 v2:0.4558 v3:0.4861 v4:0.4834 v5:0.5828 v6:0.4226 v7:0.4254 v8:0.4972 
v9:0.5055 v10:0.3895 v11:0.4254 v12:0.5359 v13:0.4640] unique_attn_gain_rms:[u0:1.1898 u1:1.2562 u2:1.1465 u3:0.9859 u4:1.1447 u5:1.0607 u6:0.9694 u7:0.9229 u8:0.9921 u9:0.8909 u10:0.7922 u11:0.7959] unique_mlp_gain_rms:[u0:2.4219 u1:2.3438 u2:2.2231 u3:2.1642 u4:2.0935 u5:2.0492 u6:1.9875 u7:2.0374 u8:2.0603 u9:2.0223 u10:1.9769 u11:2.0203] +step:3400/20000 val_loss:2.2358 val_bpb:1.3242 train_time:210772ms step_avg:61.99ms +step:3600/20000 train_loss:2.1751 train_time:223104ms step_avg:61.97ms +step:3600 shared0_alpha:mean=0.484,std=0.093 shared1_alpha:mean=0.511,std=0.074 shared2_alpha:mean=0.567,std=0.067 shared3_alpha:mean=0.621,std=0.068 eff_mlp_scale:[v0:116.2885 v1:69.8590 v2:69.3857 v3:71.4204 v4:69.9204 v5:88.9538 v6:78.3802 v7:76.1247 v8:74.8507 v9:80.5708 v10:65.1027 v11:64.1500 v12:69.9204 v13:193.4515] eff_attn_scale:[v0:0.4617 v1:5.8362 v2:10.3651 v3:10.9825 v4:10.4872 v5:9.3256 v6:9.9000 v7:10.2458 v8:9.7245 v9:9.1403 v10:8.5047 v11:8.4043 v12:6.9915 v13:2.1283] eff_attn_bias:[v0:0.4530 v1:0.6408 v2:0.4806 v3:0.5441 v4:0.5773 v5:0.6077 v6:0.5911 v7:0.5745 v8:0.5883 v9:0.5331 v10:0.4475 v11:0.4088 v12:0.4337 v13:0.5745] eff_mlp_bias:[v0:1.0165 v1:0.6132 v2:0.4668 v3:0.4944 v4:0.4917 v5:0.5994 v6:0.4309 v7:0.4337 v8:0.5082 v9:0.5193 v10:0.3950 v11:0.4337 v12:0.5497 v13:0.4806] unique_attn_gain_rms:[u0:1.2256 u1:1.2771 u2:1.1634 u3:1.0022 u4:1.1740 u5:1.0826 u6:0.9857 u7:0.9422 u8:1.0158 u9:0.9085 u10:0.8078 u11:0.8167] unique_mlp_gain_rms:[u0:2.4880 u1:2.4071 u2:2.2780 u3:2.2178 u4:2.1461 u5:2.1009 u6:2.0390 u7:2.0892 u8:2.1173 u9:2.0740 u10:2.0286 u11:2.0725] +step:3600/20000 val_loss:2.2310 val_bpb:1.3213 train_time:223131ms step_avg:61.98ms +step:3800/20000 train_loss:2.2713 train_time:235471ms step_avg:61.97ms +step:3800 shared0_alpha:mean=0.483,std=0.094 shared1_alpha:mean=0.512,std=0.075 shared2_alpha:mean=0.568,std=0.068 shared3_alpha:mean=0.622,std=0.069 eff_mlp_scale:[v0:119.6374 v1:71.2338 v2:70.6466 v3:72.7448 v4:71.3041 v5:90.4482 
v6:79.6928 v7:77.4797 v8:76.2999 v9:81.5440 v10:65.9081 v11:64.9968 v12:71.3041 v13:197.5608] eff_attn_scale:[v0:0.4573 v1:5.8685 v2:10.6585 v3:11.2210 v4:10.7260 v5:9.4392 v6:10.1222 v7:10.5409 v8:10.0799 v9:9.3771 v10:8.7145 v11:8.7048 v12:7.2369 v13:2.1076] eff_attn_bias:[v0:0.4585 v1:0.6546 v2:0.4944 v3:0.5524 v4:0.5883 v5:0.6242 v6:0.6021 v7:0.5856 v8:0.5994 v9:0.5441 v10:0.4585 v11:0.4171 v12:0.4447 v13:0.5939] eff_mlp_bias:[v0:1.0330 v1:0.6298 v2:0.4751 v3:0.5082 v4:0.4999 v5:0.6160 v6:0.4364 v7:0.4392 v8:0.5165 v9:0.5331 v10:0.4033 v11:0.4447 v12:0.5662 v13:0.4944] unique_attn_gain_rms:[u0:1.2575 u1:1.3013 u2:1.1863 u3:1.0218 u4:1.1966 u5:1.1013 u6:1.0036 u7:0.9562 u8:1.0383 u9:0.9227 u10:0.8204 u11:0.8397] unique_mlp_gain_rms:[u0:2.5491 u1:2.4677 u2:2.3339 u3:2.2705 u4:2.1970 u5:2.1486 u6:2.0880 u7:2.1413 u8:2.1672 u9:2.1242 u10:2.0778 u11:2.1209] +step:3800/20000 val_loss:2.2252 val_bpb:1.3179 train_time:235498ms step_avg:61.97ms +step:4000/20000 train_loss:2.2127 train_time:247833ms step_avg:61.96ms +step:4000 shared0_alpha:mean=0.484,std=0.094 shared1_alpha:mean=0.513,std=0.075 shared2_alpha:mean=0.569,std=0.068 shared3_alpha:mean=0.623,std=0.069 eff_mlp_scale:[v0:123.2553 v1:72.5711 v2:71.8998 v3:73.7564 v4:73.1016 v5:91.8920 v6:80.9955 v7:78.9628 v8:77.6992 v9:82.9384 v10:66.7022 v11:66.3808 v12:73.1016 v13:200.7818] eff_attn_scale:[v0:0.4372 v1:5.9225 v2:10.7673 v3:11.5070 v4:11.0148 v5:9.6514 v6:10.3610 v7:10.8179 v8:10.4212 v9:9.6514 v10:8.9389 v11:8.9575 v12:7.4861 v13:2.0649] eff_attn_bias:[v0:0.4613 v1:0.6712 v2:0.5055 v3:0.5662 v4:0.5966 v5:0.6353 v6:0.6104 v7:0.5939 v8:0.6104 v9:0.5552 v10:0.4696 v11:0.4254 v12:0.4530 v13:0.6132] eff_mlp_bias:[v0:1.0551 v1:0.6491 v2:0.4834 v3:0.5248 v4:0.5110 v5:0.6298 v6:0.4475 v7:0.4502 v8:0.5276 v9:0.5469 v10:0.4116 v11:0.4530 v12:0.5800 v13:0.5082] unique_attn_gain_rms:[u0:1.2890 u1:1.3136 u2:1.2034 u3:1.0347 u4:1.2183 u5:1.1190 u6:1.0230 u7:0.9755 u8:1.0563 u9:0.9378 u10:0.8341 u11:0.8621] 
unique_mlp_gain_rms:[u0:2.6144 u1:2.5246 u2:2.3881 u3:2.3255 u4:2.2479 u5:2.2002 u6:2.1378 u7:2.1910 u8:2.2210 u9:2.1762 u10:2.1294 u11:2.1726] +step:4000/20000 val_loss:2.2203 val_bpb:1.3150 train_time:247860ms step_avg:61.96ms +step:4200/20000 train_loss:2.2212 train_time:260258ms step_avg:61.97ms +step:4200 shared0_alpha:mean=0.483,std=0.095 shared1_alpha:mean=0.513,std=0.075 shared2_alpha:mean=0.570,std=0.069 shared3_alpha:mean=0.625,std=0.070 eff_mlp_scale:[v0:126.8470 v1:73.9694 v2:73.1358 v3:75.6450 v4:74.3710 v5:93.4100 v6:81.8424 v7:80.0176 v8:79.4840 v9:84.4010 v10:67.9118 v11:67.3372 v12:74.8358 v13:205.5421] eff_attn_scale:[v0:0.4230 v1:5.9103 v2:10.9037 v3:11.6246 v4:11.2429 v5:9.8086 v6:10.5629 v7:11.0019 v8:10.6442 v9:9.8715 v10:9.1318 v11:9.2028 v12:7.6172 v13:2.0086] eff_attn_bias:[v0:0.4613 v1:0.6878 v2:0.5193 v3:0.5800 v4:0.6077 v5:0.6491 v6:0.6242 v7:0.6049 v8:0.6187 v9:0.5690 v10:0.4778 v11:0.4337 v12:0.4613 v13:0.6325] eff_mlp_bias:[v0:1.0772 v1:0.6657 v2:0.4917 v3:0.5359 v4:0.5193 v5:0.6436 v6:0.4530 v7:0.4613 v8:0.5359 v9:0.5580 v10:0.4171 v11:0.4613 v12:0.5939 v13:0.5193] unique_attn_gain_rms:[u0:1.3109 u1:1.3308 u2:1.2238 u3:1.0545 u4:1.2473 u5:1.1368 u6:1.0420 u7:0.9888 u8:1.0768 u9:0.9543 u10:0.8522 u11:0.8818] unique_mlp_gain_rms:[u0:2.6730 u1:2.5838 u2:2.4405 u3:2.3798 u4:2.3015 u5:2.2488 u6:2.1866 u7:2.2436 u8:2.2740 u9:2.2260 u10:2.1802 u11:2.2217] +step:4200/20000 val_loss:2.2155 val_bpb:1.3122 train_time:260289ms step_avg:61.97ms +step:4400/20000 train_loss:2.1625 train_time:272608ms step_avg:61.96ms +step:4400 shared0_alpha:mean=0.482,std=0.096 shared1_alpha:mean=0.512,std=0.077 shared2_alpha:mean=0.570,std=0.069 shared3_alpha:mean=0.626,std=0.071 eff_mlp_scale:[v0:130.3133 v1:74.9624 v2:73.9965 v3:76.1634 v4:76.1794 v5:95.0161 v6:83.6292 v7:81.4464 v8:81.3520 v9:85.4667 v10:68.7423 v11:68.6792 v12:76.6496 v13:209.4651] eff_attn_scale:[v0:0.4094 v1:5.9594 v2:11.1800 v3:11.7255 v4:11.5631 v5:10.0168 v6:10.7659 v7:11.2370 v8:10.9545 
v9:10.1436 v10:9.3857 v11:9.4223 v12:7.8440 v13:1.9880] eff_attn_bias:[v0:0.4668 v1:0.7016 v2:0.5248 v3:0.5883 v4:0.6187 v5:0.6602 v6:0.6353 v7:0.6160 v8:0.6298 v9:0.5800 v10:0.4889 v11:0.4419 v12:0.4696 v13:0.6491] eff_mlp_bias:[v0:1.0993 v1:0.6822 v2:0.4999 v3:0.5497 v4:0.5303 v5:0.6574 v6:0.4613 v7:0.4723 v8:0.5441 v9:0.5718 v10:0.4226 v11:0.4723 v12:0.6077 v13:0.5331] unique_attn_gain_rms:[u0:1.3405 u1:1.3510 u2:1.2398 u3:1.0679 u4:1.2677 u5:1.1566 u6:1.0568 u7:0.9991 u8:1.0974 u9:0.9695 u10:0.8660 u11:0.9043] unique_mlp_gain_rms:[u0:2.7358 u1:2.6393 u2:2.4963 u3:2.4307 u4:2.3506 u5:2.2969 u6:2.2361 u7:2.2898 u8:2.3276 u9:2.2776 u10:2.2299 u11:2.2690] +step:4400/20000 val_loss:2.2180 val_bpb:1.3136 train_time:272635ms step_avg:61.96ms +step:4600/20000 train_loss:2.0294 train_time:284966ms step_avg:61.95ms +step:4600 shared0_alpha:mean=0.481,std=0.095 shared1_alpha:mean=0.511,std=0.077 shared2_alpha:mean=0.571,std=0.069 shared3_alpha:mean=0.626,std=0.072 eff_mlp_scale:[v0:134.2302 v1:76.3127 v2:75.2388 v3:77.3208 v4:77.5399 v5:95.9909 v6:84.4787 v7:82.1810 v8:82.7727 v9:86.8717 v10:69.5189 v11:69.3678 v12:78.0156 v13:213.3284] eff_attn_scale:[v0:0.3989 v1:6.0712 v2:11.3593 v3:12.0547 v4:11.7640 v5:10.1612 v6:10.9412 v7:11.5583 v8:11.3512 v9:10.3529 v10:9.6171 v11:9.7147 v12:8.0835 v13:1.9556] eff_attn_bias:[v0:0.4696 v1:0.7237 v2:0.5331 v3:0.5994 v4:0.6298 v5:0.6740 v6:0.6436 v7:0.6270 v8:0.6381 v9:0.5911 v10:0.4972 v11:0.4502 v12:0.4806 v13:0.6657] eff_mlp_bias:[v0:1.1214 v1:0.7016 v2:0.5082 v3:0.5580 v4:0.5441 v5:0.6712 v6:0.4668 v7:0.4806 v8:0.5524 v9:0.5800 v10:0.4281 v11:0.4806 v12:0.6215 v13:0.5469] unique_attn_gain_rms:[u0:1.3717 u1:1.3705 u2:1.2671 u3:1.0829 u4:1.2937 u5:1.1748 u6:1.0799 u7:1.0235 u8:1.1141 u9:0.9850 u10:0.8832 u11:0.9246] unique_mlp_gain_rms:[u0:2.7944 u1:2.6928 u2:2.5489 u3:2.4825 u4:2.4040 u5:2.3458 u6:2.2842 u7:2.3386 u8:2.3766 u9:2.3288 u10:2.2772 u11:2.3199] +step:4600/20000 val_loss:2.2132 val_bpb:1.3108 train_time:284993ms 
step_avg:61.95ms +step:4800/20000 train_loss:2.3166 train_time:297316ms step_avg:61.94ms +step:4800 shared0_alpha:mean=0.480,std=0.096 shared1_alpha:mean=0.511,std=0.077 shared2_alpha:mean=0.571,std=0.070 shared3_alpha:mean=0.627,std=0.072 eff_mlp_scale:[v0:137.8871 v1:77.6688 v2:76.1969 v3:78.9299 v4:78.8753 v5:97.4478 v6:85.5001 v7:83.8352 v8:84.6466 v9:87.7995 v10:70.8809 v11:70.4572 v12:79.8372 v13:217.3005] eff_attn_scale:[v0:0.3863 v1:6.0970 v2:11.5036 v3:12.2153 v4:12.0765 v5:10.3328 v6:11.1529 v7:11.7153 v8:11.5879 v9:10.5895 v10:9.7500 v11:9.9294 v12:8.2372 v13:1.8977] eff_attn_bias:[v0:0.4696 v1:0.7292 v2:0.5441 v3:0.6049 v4:0.6408 v5:0.6850 v6:0.6546 v7:0.6408 v8:0.6491 v9:0.5994 v10:0.5055 v11:0.4585 v12:0.4889 v13:0.6822] eff_mlp_bias:[v0:1.1380 v1:0.7126 v2:0.5165 v3:0.5718 v4:0.5497 v5:0.6850 v6:0.4751 v7:0.4889 v8:0.5607 v9:0.5911 v10:0.4309 v11:0.4861 v12:0.6325 v13:0.5607] unique_attn_gain_rms:[u0:1.4009 u1:1.3858 u2:1.2862 u3:1.0979 u4:1.3185 u5:1.1915 u6:1.0946 u7:1.0362 u8:1.1346 u9:1.0005 u10:0.8987 u11:0.9478] unique_mlp_gain_rms:[u0:2.8565 u1:2.7512 u2:2.5991 u3:2.5357 u4:2.4530 u5:2.3905 u6:2.3283 u7:2.3849 u8:2.4287 u9:2.3774 u10:2.3271 u11:2.3677] +step:4800/20000 val_loss:2.2084 val_bpb:1.3080 train_time:297344ms step_avg:61.95ms +step:5000/20000 train_loss:2.0787 train_time:309662ms step_avg:61.93ms +step:5000 shared0_alpha:mean=0.480,std=0.097 shared1_alpha:mean=0.511,std=0.078 shared2_alpha:mean=0.572,std=0.070 shared3_alpha:mean=0.628,std=0.073 eff_mlp_scale:[v0:141.6967 v1:79.1546 v2:77.4670 v3:80.1774 v4:80.7125 v5:99.0647 v6:86.8164 v7:84.6566 v8:86.0609 v9:88.8669 v10:71.6792 v11:71.2190 v12:81.1987 v13:219.2956] eff_attn_scale:[v0:0.3852 v1:6.1282 v2:11.7371 v3:12.3515 v4:12.3996 v5:10.5148 v6:11.3129 v7:11.8488 v8:11.9037 v9:10.7728 v10:9.8281 v11:9.9817 v12:8.4318 v13:1.8831] eff_attn_bias:[v0:0.4696 v1:0.7513 v2:0.5552 v3:0.6160 v4:0.6519 v5:0.6961 v6:0.6602 v7:0.6491 v8:0.6546 v9:0.6132 v10:0.5110 v11:0.4668 v12:0.4944 
v13:0.6988] eff_mlp_bias:[v0:1.1546 v1:0.7292 v2:0.5276 v3:0.5828 v4:0.5580 v5:0.6988 v6:0.4806 v7:0.4944 v8:0.5690 v9:0.6021 v10:0.4364 v11:0.4944 v12:0.6463 v13:0.5718] unique_attn_gain_rms:[u0:1.4199 u1:1.4008 u2:1.3045 u3:1.1149 u4:1.3403 u5:1.2110 u6:1.1140 u7:1.0530 u8:1.1559 u9:1.0148 u10:0.9110 u11:0.9678] unique_mlp_gain_rms:[u0:2.9158 u1:2.8039 u2:2.6532 u3:2.5809 u4:2.5035 u5:2.4372 u6:2.3799 u7:2.4356 u8:2.4785 u9:2.4280 u10:2.3736 u11:2.4180] +step:5000/20000 val_loss:2.2030 val_bpb:1.3047 train_time:309689ms step_avg:61.94ms +step:5200/20000 train_loss:2.2238 train_time:322022ms step_avg:61.93ms +step:5200 shared0_alpha:mean=0.478,std=0.098 shared1_alpha:mean=0.511,std=0.079 shared2_alpha:mean=0.573,std=0.071 shared3_alpha:mean=0.629,std=0.073 eff_mlp_scale:[v0:145.6966 v1:81.3124 v2:79.5202 v3:81.8880 v4:83.2467 v5:100.9058 v6:88.5055 v7:86.8646 v8:88.1725 v9:90.6193 v10:72.7812 v11:72.3872 v12:83.2467 v13:223.3147] eff_attn_scale:[v0:0.3740 v1:6.1749 v2:11.3039 v3:12.0549 v4:11.8619 v5:10.3661 v6:10.9571 v7:11.7004 v8:11.4433 v9:10.6861 v10:9.7088 v11:9.9276 v12:8.4080 v13:1.8482] eff_attn_bias:[v0:0.4723 v1:0.7623 v2:0.5635 v3:0.6298 v4:0.6657 v5:0.7071 v6:0.6684 v7:0.6574 v8:0.6684 v9:0.6242 v10:0.5220 v11:0.4751 v12:0.5027 v13:0.7126] eff_mlp_bias:[v0:1.1656 v1:0.7458 v2:0.5359 v3:0.5994 v4:0.5690 v5:0.7126 v6:0.4861 v7:0.5055 v8:0.5800 v9:0.6132 v10:0.4419 v11:0.5027 v12:0.6629 v13:0.5856] unique_attn_gain_rms:[u0:1.4373 u1:1.4296 u2:1.3219 u3:1.1425 u4:1.3558 u5:1.2361 u6:1.1312 u7:1.0849 u8:1.1715 u9:1.0401 u10:0.9319 u11:0.9887] unique_mlp_gain_rms:[u0:2.9715 u1:2.8623 u2:2.7058 u3:2.6423 u4:2.5561 u5:2.4881 u6:2.4305 u7:2.4941 u8:2.5364 u9:2.4813 u10:2.4253 u11:2.4686] +step:5200/20000 val_loss:2.2053 val_bpb:1.3061 train_time:322049ms step_avg:61.93ms +step:5400/20000 train_loss:2.2366 train_time:334380ms step_avg:61.92ms +step:5400 shared0_alpha:mean=0.478,std=0.099 shared1_alpha:mean=0.510,std=0.079 shared2_alpha:mean=0.573,std=0.072 
shared3_alpha:mean=0.630,std=0.074 eff_mlp_scale:[v0:149.1297 v1:82.1096 v2:80.2128 v3:82.6239 v4:84.0222 v5:101.7766 v6:89.6762 v7:87.1636 v8:89.4912 v9:91.9431 v10:73.4533 v11:73.5443 v12:84.5194 v13:227.2429] eff_attn_scale:[v0:0.3687 v1:6.2892 v2:11.7455 v3:12.3596 v4:12.2988 v5:10.6333 v6:11.3209 v7:12.0704 v8:12.0128 v9:10.9575 v10:9.9766 v11:10.2635 v12:8.6163 v13:1.8325] eff_attn_bias:[v0:0.4751 v1:0.7789 v2:0.5745 v3:0.6381 v4:0.6712 v5:0.7182 v6:0.6795 v7:0.6684 v8:0.6767 v9:0.6353 v10:0.5303 v11:0.4806 v12:0.5110 v13:0.7292] eff_mlp_bias:[v0:1.1822 v1:0.7568 v2:0.5441 v3:0.6021 v4:0.5745 v5:0.7237 v6:0.4944 v7:0.5110 v8:0.5856 v9:0.6215 v10:0.4475 v11:0.5110 v12:0.6740 v13:0.5966] unique_attn_gain_rms:[u0:1.4576 u1:1.4435 u2:1.3385 u3:1.1618 u4:1.3778 u5:1.2532 u6:1.1501 u7:1.1034 u8:1.1909 u9:1.0525 u10:0.9465 u11:1.0065] unique_mlp_gain_rms:[u0:3.0255 u1:2.9185 u2:2.7557 u3:2.6931 u4:2.6037 u5:2.5332 u6:2.4789 u7:2.5429 u8:2.5850 u9:2.5313 u10:2.4731 u11:2.5166] +step:5400/20000 val_loss:2.1990 val_bpb:1.3023 train_time:334407ms step_avg:61.93ms +step:5600/20000 train_loss:2.2341 train_time:346731ms step_avg:61.92ms +step:5600 shared0_alpha:mean=0.477,std=0.099 shared1_alpha:mean=0.510,std=0.079 shared2_alpha:mean=0.574,std=0.073 shared3_alpha:mean=0.631,std=0.075 eff_mlp_scale:[v0:152.1452 v1:83.0236 v2:81.0667 v3:83.7987 v4:85.3349 v5:103.2853 v6:90.1244 v7:88.8358 v8:91.3586 v9:92.9073 v10:74.2734 v11:74.6404 v12:85.8369 v13:231.1595] eff_attn_scale:[v0:0.3572 v1:6.2418 v2:11.9137 v3:12.6021 v4:12.6455 v5:10.7281 v6:11.4856 v7:12.1651 v8:12.2822 v9:11.0532 v10:10.1302 v11:10.4168 v12:8.8301 v13:1.8288] eff_attn_bias:[v0:0.4751 v1:0.7900 v2:0.5856 v3:0.6491 v4:0.6767 v5:0.7347 v6:0.6878 v7:0.6767 v8:0.6850 v9:0.6463 v10:0.5386 v11:0.4889 v12:0.5165 v13:0.7513] eff_mlp_bias:[v0:1.2043 v1:0.7679 v2:0.5552 v3:0.6104 v4:0.5856 v5:0.7403 v6:0.4972 v7:0.5193 v8:0.5966 v9:0.6298 v10:0.4530 v11:0.5165 v12:0.6850 v13:0.6104] unique_attn_gain_rms:[u0:1.4772 
u1:1.4635 u2:1.3610 u3:1.1762 u4:1.4024 u5:1.2676 u6:1.1690 u7:1.1181 u8:1.2071 u9:1.0684 u10:0.9591 u11:1.0268] unique_mlp_gain_rms:[u0:3.0811 u1:2.9725 u2:2.8074 u3:2.7391 u4:2.6543 u5:2.5747 u6:2.5222 u7:2.5844 u8:2.6345 u9:2.5765 u10:2.5220 u11:2.5625] +step:5600/20000 val_loss:2.1988 val_bpb:1.3023 train_time:346758ms step_avg:61.92ms +step:5800/20000 train_loss:2.1938 train_time:359078ms step_avg:61.91ms +step:5800 shared0_alpha:mean=0.476,std=0.100 shared1_alpha:mean=0.510,std=0.079 shared2_alpha:mean=0.574,std=0.073 shared3_alpha:mean=0.632,std=0.075 eff_mlp_scale:[v0:156.6907 v1:84.2915 v2:82.3189 v3:84.5036 v4:86.6795 v5:104.1248 v6:91.4149 v7:89.0962 v8:92.7622 v9:93.7123 v10:75.0421 v11:75.3184 v12:87.6933 v13:235.0589] eff_attn_scale:[v0:0.3544 v1:6.3447 v2:12.2031 v3:12.7438 v4:12.9624 v5:10.9800 v6:11.7699 v7:12.3755 v8:12.6661 v9:11.3745 v10:10.3257 v11:10.6812 v12:9.0737 v13:1.8250] eff_attn_bias:[v0:0.4806 v1:0.8065 v2:0.5939 v3:0.6574 v4:0.6878 v5:0.7403 v6:0.6933 v7:0.6822 v8:0.6933 v9:0.6546 v10:0.5441 v11:0.4917 v12:0.5220 v13:0.7679] eff_mlp_bias:[v0:1.2209 v1:0.7900 v2:0.5607 v3:0.6187 v4:0.5911 v5:0.7458 v6:0.5055 v7:0.5276 v8:0.6049 v9:0.6408 v10:0.4558 v11:0.5248 v12:0.6961 v13:0.6242] unique_attn_gain_rms:[u0:1.4981 u1:1.4749 u2:1.3758 u3:1.1877 u4:1.4253 u5:1.2824 u6:1.1852 u7:1.1320 u8:1.2292 u9:1.0856 u10:0.9781 u11:1.0493] unique_mlp_gain_rms:[u0:3.1413 u1:3.0286 u2:2.8593 u3:2.7881 u4:2.7010 u5:2.6224 u6:2.5673 u7:2.6332 u8:2.6853 u9:2.6242 u10:2.5689 u11:2.6112] +step:5800/20000 val_loss:2.1983 val_bpb:1.3020 train_time:359105ms step_avg:61.91ms +step:6000/20000 train_loss:2.2687 train_time:371433ms step_avg:61.91ms +step:6000 shared0_alpha:mean=0.476,std=0.100 shared1_alpha:mean=0.509,std=0.080 shared2_alpha:mean=0.574,std=0.074 shared3_alpha:mean=0.632,std=0.076 eff_mlp_scale:[v0:159.8100 v1:85.6167 v2:83.0541 v3:86.1219 v4:88.4901 v5:105.5276 v6:92.1810 v7:90.7521 v8:94.1166 v9:95.0744 v10:75.7527 v11:76.3985 v12:89.0016 
v13:236.6819] eff_attn_scale:[v0:0.3456 v1:6.3618 v2:12.2533 v3:12.9017 v4:13.2976 v5:11.1413 v6:11.8908 v7:12.6052 v8:12.9220 v9:11.5369 v10:10.4407 v11:10.8256 v12:9.2407 v13:1.7937] eff_attn_bias:[v0:0.4806 v1:0.8231 v2:0.6021 v3:0.6657 v4:0.6961 v5:0.7513 v6:0.7016 v7:0.6905 v8:0.6988 v9:0.6629 v10:0.5497 v11:0.4972 v12:0.5331 v13:0.7789] eff_mlp_bias:[v0:1.2374 v1:0.8010 v2:0.5690 v3:0.6353 v4:0.6021 v5:0.7623 v6:0.5110 v7:0.5359 v8:0.6104 v9:0.6519 v10:0.4585 v11:0.5359 v12:0.7071 v13:0.6325] unique_attn_gain_rms:[u0:1.5151 u1:1.4887 u2:1.3938 u3:1.2038 u4:1.4516 u5:1.2958 u6:1.2047 u7:1.1455 u8:1.2432 u9:1.1023 u10:0.9914 u11:1.0676] unique_mlp_gain_rms:[u0:3.1931 u1:3.0809 u2:2.9068 u3:2.8374 u4:2.7484 u5:2.6655 u6:2.6152 u7:2.6796 u8:2.7365 u9:2.6706 u10:2.6161 u11:2.6590] +step:6000/20000 val_loss:2.1936 val_bpb:1.2992 train_time:371460ms step_avg:61.91ms +step:6200/20000 train_loss:2.1392 train_time:383784ms step_avg:61.90ms +step:6200 shared0_alpha:mean=0.475,std=0.100 shared1_alpha:mean=0.510,std=0.081 shared2_alpha:mean=0.576,std=0.074 shared3_alpha:mean=0.634,std=0.076 eff_mlp_scale:[v0:164.1944 v1:86.6975 v2:84.4264 v3:87.1937 v4:89.9855 v5:107.2443 v6:93.1443 v7:91.8564 v8:95.6743 v9:96.2192 v10:76.6261 v11:77.4019 v12:91.0199 v13:240.4917] eff_attn_scale:[v0:0.3448 v1:6.3944 v2:12.4395 v3:13.1625 v4:13.5897 v5:11.1323 v6:12.0736 v7:12.7885 v8:13.2080 v9:11.7287 v10:10.6102 v11:10.9937 v12:9.4670 v13:1.7426] eff_attn_bias:[v0:0.4834 v1:0.8286 v2:0.6077 v3:0.6740 v4:0.7043 v5:0.7623 v6:0.7071 v7:0.6988 v8:0.7043 v9:0.6740 v10:0.5607 v11:0.5082 v12:0.5414 v13:0.7955] eff_mlp_bias:[v0:1.2540 v1:0.8121 v2:0.5773 v3:0.6408 v4:0.6104 v5:0.7679 v6:0.5165 v7:0.5414 v8:0.6187 v9:0.6629 v10:0.4668 v11:0.5441 v12:0.7237 v13:0.6408] unique_attn_gain_rms:[u0:1.5398 u1:1.5061 u2:1.4085 u3:1.2153 u4:1.4738 u5:1.3090 u6:1.2202 u7:1.1589 u8:1.2614 u9:1.1121 u10:1.0058 u11:1.0896] unique_mlp_gain_rms:[u0:3.2494 u1:3.1277 u2:2.9577 u3:2.8821 u4:2.7982 u5:2.7074 
u6:2.6625 u7:2.7265 u8:2.7854 u9:2.7182 u10:2.6633 u11:2.7033] +step:6200/20000 val_loss:2.1926 val_bpb:1.2986 train_time:383809ms step_avg:61.90ms +step:6400/20000 train_loss:2.2146 train_time:396138ms step_avg:61.90ms +step:6400 shared0_alpha:mean=0.474,std=0.101 shared1_alpha:mean=0.509,std=0.081 shared2_alpha:mean=0.576,std=0.075 shared3_alpha:mean=0.634,std=0.077 eff_mlp_scale:[v0:168.4754 v1:87.4411 v2:85.3122 v3:88.7401 v4:91.4439 v5:108.0450 v6:94.5352 v7:92.9658 v8:97.1918 v9:96.9893 v10:77.4727 v11:78.4106 v12:92.4890 v13:244.1227] eff_attn_scale:[v0:0.3373 v1:6.4145 v2:12.6493 v3:13.4060 v4:13.7854 v5:11.3591 v6:12.3534 v7:13.1030 v8:13.5531 v9:11.8936 v10:10.8000 v11:11.2853 v12:9.6421 v13:1.7888] eff_attn_bias:[v0:0.4806 v1:0.8452 v2:0.6215 v3:0.6795 v4:0.7126 v5:0.7734 v6:0.7182 v7:0.7071 v8:0.7126 v9:0.6822 v10:0.5635 v11:0.5138 v12:0.5469 v13:0.8121] eff_mlp_bias:[v0:1.2706 v1:0.8231 v2:0.5828 v3:0.6491 v4:0.6215 v5:0.7789 v6:0.5220 v7:0.5497 v8:0.6215 v9:0.6712 v10:0.4668 v11:0.5497 v12:0.7347 v13:0.6519] unique_attn_gain_rms:[u0:1.5643 u1:1.5210 u2:1.4240 u3:1.2310 u4:1.4963 u5:1.3244 u6:1.2365 u7:1.1714 u8:1.2799 u9:1.1277 u10:1.0217 u11:1.1039] unique_mlp_gain_rms:[u0:3.3013 u1:3.1841 u2:3.0065 u3:2.9308 u4:2.8447 u5:2.7518 u6:2.7072 u7:2.7717 u8:2.8369 u9:2.7683 u10:2.7098 u11:2.7512] +step:6400/20000 val_loss:2.1891 val_bpb:1.2965 train_time:396167ms step_avg:61.90ms +step:6600/20000 train_loss:2.1743 train_time:408498ms step_avg:61.89ms +step:6600 shared0_alpha:mean=0.473,std=0.101 shared1_alpha:mean=0.509,std=0.082 shared2_alpha:mean=0.577,std=0.074 shared3_alpha:mean=0.635,std=0.078 eff_mlp_scale:[v0:172.1954 v1:88.9605 v2:86.1402 v3:89.6068 v4:93.3375 v5:109.1788 v6:95.4026 v7:93.8514 v8:99.1382 v9:98.0587 v10:78.2672 v11:79.2313 v12:94.3922 v13:248.1197] eff_attn_scale:[v0:0.3346 v1:6.4810 v2:12.7923 v3:13.4097 v4:14.1682 v5:11.4845 v6:12.3460 v7:13.1811 v8:13.9320 v9:12.0218 v10:10.9329 v11:11.4287 v12:9.8390 v13:1.7762] 
eff_attn_bias:[v0:0.4834 v1:0.8563 v2:0.6325 v3:0.6878 v4:0.7182 v5:0.7844 v6:0.7237 v7:0.7126 v8:0.7237 v9:0.6933 v10:0.5718 v11:0.5220 v12:0.5552 v13:0.8286] eff_mlp_bias:[v0:1.2816 v1:0.8342 v2:0.5856 v3:0.6602 v4:0.6298 v5:0.7900 v6:0.5276 v7:0.5580 v8:0.6298 v9:0.6795 v10:0.4723 v11:0.5580 v12:0.7403 v13:0.6629] unique_attn_gain_rms:[u0:1.5795 u1:1.5348 u2:1.4393 u3:1.2462 u4:1.5124 u5:1.3393 u6:1.2494 u7:1.1897 u8:1.2918 u9:1.1413 u10:1.0360 u11:1.1267] unique_mlp_gain_rms:[u0:3.3556 u1:3.2387 u2:3.0583 u3:2.9810 u4:2.8922 u5:2.7976 u6:2.7566 u7:2.8165 u8:2.8828 u9:2.8118 u10:2.7573 u11:2.7965] +step:6600/20000 val_loss:2.1859 val_bpb:1.2946 train_time:408522ms step_avg:61.90ms +step:6800/20000 train_loss:2.2460 train_time:420849ms step_avg:61.89ms +step:6800 shared0_alpha:mean=0.473,std=0.101 shared1_alpha:mean=0.509,std=0.082 shared2_alpha:mean=0.577,std=0.075 shared3_alpha:mean=0.636,std=0.078 eff_mlp_scale:[v0:175.7576 v1:89.2385 v2:87.6380 v3:90.7489 v4:94.7037 v5:110.5340 v6:96.9612 v7:95.5002 v8:100.5562 v9:98.8722 v10:79.2471 v11:80.2962 v12:95.7678 v13:251.8846] eff_attn_scale:[v0:0.3236 v1:6.5613 v2:12.8235 v3:13.6790 v4:14.3835 v5:11.6345 v6:12.5253 v7:13.3716 v8:14.2237 v9:12.2433 v10:11.0342 v11:11.6041 v12:10.0684 v13:1.7567] eff_attn_bias:[v0:0.4861 v1:0.8673 v2:0.6381 v3:0.6988 v4:0.7292 v5:0.7900 v6:0.7347 v7:0.7237 v8:0.7237 v9:0.7043 v10:0.5800 v11:0.5276 v12:0.5635 v13:0.8397] eff_mlp_bias:[v0:1.2982 v1:0.8452 v2:0.5966 v3:0.6712 v4:0.6381 v5:0.8065 v6:0.5359 v7:0.5635 v8:0.6381 v9:0.6850 v10:0.4778 v11:0.5635 v12:0.7568 v13:0.6712] unique_attn_gain_rms:[u0:1.5950 u1:1.5456 u2:1.4570 u3:1.2573 u4:1.5334 u5:1.3510 u6:1.2657 u7:1.2014 u8:1.3119 u9:1.1569 u10:1.0482 u11:1.1406] unique_mlp_gain_rms:[u0:3.4106 u1:3.2862 u2:3.1047 u3:3.0241 u4:2.9368 u5:2.8393 u6:2.7989 u7:2.8607 u8:2.9268 u9:2.8573 u10:2.8017 u11:2.8419] +step:6800/20000 val_loss:2.1845 val_bpb:1.2938 train_time:420876ms step_avg:61.89ms +step:7000/20000 train_loss:2.2725 
train_time:433207ms step_avg:61.89ms +step:7000 shared0_alpha:mean=0.472,std=0.102 shared1_alpha:mean=0.509,std=0.082 shared2_alpha:mean=0.577,std=0.076 shared3_alpha:mean=0.637,std=0.079 eff_mlp_scale:[v0:179.5586 v1:90.6377 v2:88.4449 v3:91.8288 v4:96.2682 v5:111.5150 v6:97.3362 v7:96.6116 v8:102.1841 v9:100.3125 v10:80.0216 v11:81.3068 v12:97.3438 v13:255.6471] eff_attn_scale:[v0:0.3125 v1:6.6138 v2:13.1198 v3:13.7300 v4:14.6850 v5:11.8640 v6:12.7428 v7:13.4986 v8:14.5228 v9:12.4095 v10:11.2348 v11:11.7245 v12:10.2633 v13:1.7393] eff_attn_bias:[v0:0.4889 v1:0.8784 v2:0.6463 v3:0.7071 v4:0.7347 v5:0.8010 v6:0.7403 v7:0.7292 v8:0.7347 v9:0.7071 v10:0.5856 v11:0.5359 v12:0.5690 v13:0.8507] eff_mlp_bias:[v0:1.3148 v1:0.8563 v2:0.5994 v3:0.6822 v4:0.6463 v5:0.8121 v6:0.5414 v7:0.5690 v8:0.6463 v9:0.6905 v10:0.4806 v11:0.5718 v12:0.7623 v13:0.6850] unique_attn_gain_rms:[u0:1.6117 u1:1.5626 u2:1.4747 u3:1.2698 u4:1.5506 u5:1.3692 u6:1.2782 u7:1.2178 u8:1.3277 u9:1.1745 u10:1.0625 u11:1.1591] unique_mlp_gain_rms:[u0:3.4631 u1:3.3378 u2:3.1530 u3:3.0719 u4:2.9796 u5:2.8845 u6:2.8426 u7:2.9050 u8:2.9752 u9:2.9029 u10:2.8464 u11:2.8879] +step:7000/20000 val_loss:2.1825 val_bpb:1.2926 train_time:433234ms step_avg:61.89ms +step:7200/20000 train_loss:2.2487 train_time:445564ms step_avg:61.88ms +step:7200 shared0_alpha:mean=0.472,std=0.102 shared1_alpha:mean=0.509,std=0.082 shared2_alpha:mean=0.578,std=0.076 shared3_alpha:mean=0.638,std=0.079 eff_mlp_scale:[v0:183.9828 v1:91.6182 v2:89.7220 v3:93.5406 v4:97.7726 v5:112.6034 v6:98.1775 v7:97.8801 v8:104.2907 v9:101.3431 v10:80.7968 v11:82.4508 v12:98.8589 v13:257.7057] eff_attn_scale:[v0:0.3197 v1:6.6603 v2:13.2257 v3:13.9266 v4:14.9696 v5:11.9473 v6:12.7723 v7:13.6154 v8:14.8051 v9:12.6339 v10:11.3363 v11:11.9037 v12:10.4458 v13:1.7230] eff_attn_bias:[v0:0.4889 v1:0.8894 v2:0.6519 v3:0.7126 v4:0.7458 v5:0.8065 v6:0.7513 v7:0.7403 v8:0.7458 v9:0.7182 v10:0.5939 v11:0.5414 v12:0.5745 v13:0.8618] eff_mlp_bias:[v0:1.3313 v1:0.8673 
v2:0.6049 v3:0.6905 v4:0.6519 v5:0.8231 v6:0.5441 v7:0.5745 v8:0.6546 v9:0.6961 v10:0.4861 v11:0.5773 v12:0.7734 v13:0.6961] unique_attn_gain_rms:[u0:1.6319 u1:1.5758 u2:1.4859 u3:1.2811 u4:1.5661 u5:1.3845 u6:1.2903 u7:1.2283 u8:1.3423 u9:1.1874 u10:1.0780 u11:1.1767] unique_mlp_gain_rms:[u0:3.5196 u1:3.3851 u2:3.1983 u3:3.1180 u4:3.0253 u5:2.9268 u6:2.8831 u7:2.9505 u8:3.0211 u9:2.9476 u10:2.8899 u11:2.9325] +step:7200/20000 val_loss:2.1832 val_bpb:1.2930 train_time:445592ms step_avg:61.89ms +step:7400/20000 train_loss:2.1698 train_time:457919ms step_avg:61.88ms +step:7400 shared0_alpha:mean=0.471,std=0.102 shared1_alpha:mean=0.509,std=0.083 shared2_alpha:mean=0.579,std=0.076 shared3_alpha:mean=0.638,std=0.079 eff_mlp_scale:[v0:187.7943 v1:92.9347 v2:90.5776 v3:94.5608 v4:99.8069 v5:113.4728 v6:99.5411 v7:98.9251 v8:105.8391 v9:102.1768 v10:81.6142 v11:83.4074 v12:100.9036 v13:261.3064] eff_attn_scale:[v0:0.3086 v1:6.7327 v2:13.3446 v3:14.0796 v4:15.1177 v5:12.0844 v6:12.8897 v7:13.8449 v8:15.0347 v9:12.7750 v10:11.3732 v11:11.9677 v12:10.6322 v13:1.7028] eff_attn_bias:[v0:0.4889 v1:0.9060 v2:0.6602 v3:0.7237 v4:0.7568 v5:0.8176 v6:0.7568 v7:0.7458 v8:0.7513 v9:0.7292 v10:0.6021 v11:0.5497 v12:0.5856 v13:0.8784] eff_mlp_bias:[v0:1.3424 v1:0.8784 v2:0.6132 v3:0.7016 v4:0.6629 v5:0.8286 v6:0.5524 v7:0.5828 v8:0.6629 v9:0.7071 v10:0.4889 v11:0.5883 v12:0.7844 v13:0.7043] unique_attn_gain_rms:[u0:1.6479 u1:1.5901 u2:1.5021 u3:1.2935 u4:1.5854 u5:1.3947 u6:1.3048 u7:1.2410 u8:1.3560 u9:1.2033 u10:1.0880 u11:1.1950] unique_mlp_gain_rms:[u0:3.5719 u1:3.4327 u2:3.2459 u3:3.1618 u4:3.0700 u5:2.9709 u6:2.9283 u7:2.9952 u8:3.0658 u9:2.9931 u10:2.9350 u11:2.9783] +step:7400/20000 val_loss:2.1797 val_bpb:1.2910 train_time:457946ms step_avg:61.88ms +step:7600/20000 train_loss:2.0529 train_time:470266ms step_avg:61.88ms +step:7600 shared0_alpha:mean=0.471,std=0.103 shared1_alpha:mean=0.509,std=0.083 shared2_alpha:mean=0.579,std=0.077 shared3_alpha:mean=0.639,std=0.080 
eff_mlp_scale:[v0:192.4914 v1:93.9263 v2:91.3707 v3:95.6566 v4:101.2036 v5:115.0855 v6:100.3657 v7:100.5370 v8:107.2869 v9:103.2157 v10:82.3757 v11:84.4316 v12:102.3097 v13:265.1308] eff_attn_scale:[v0:0.3119 v1:6.7562 v2:13.4544 v3:14.2901 v4:15.4300 v5:12.1957 v6:13.0721 v7:14.0533 v8:15.3457 v9:12.9580 v10:11.5432 v11:12.1585 v12:10.8769 v13:1.7004] eff_attn_bias:[v0:0.4917 v1:0.9170 v2:0.6657 v3:0.7292 v4:0.7679 v5:0.8286 v6:0.7623 v7:0.7568 v8:0.7568 v9:0.7347 v10:0.6077 v11:0.5552 v12:0.5883 v13:0.8894] eff_mlp_bias:[v0:1.3590 v1:0.8894 v2:0.6187 v3:0.7126 v4:0.6740 v5:0.8397 v6:0.5552 v7:0.5911 v8:0.6712 v9:0.7182 v10:0.4917 v11:0.5939 v12:0.7955 v13:0.7126] unique_attn_gain_rms:[u0:1.6619 u1:1.6026 u2:1.5186 u3:1.3121 u4:1.6028 u5:1.4079 u6:1.3219 u7:1.2535 u8:1.3715 u9:1.2160 u10:1.1040 u11:1.2105] unique_mlp_gain_rms:[u0:3.6205 u1:3.4839 u2:3.2931 u3:3.2055 u4:3.1132 u5:3.0137 u6:2.9695 u7:3.0407 u8:3.1141 u9:3.0357 u10:2.9793 u11:3.0222] +step:7600/20000 val_loss:2.1795 val_bpb:1.2908 train_time:470294ms step_avg:61.88ms +step:7800/20000 train_loss:2.2005 train_time:482617ms step_avg:61.87ms +step:7800 shared0_alpha:mean=0.470,std=0.104 shared1_alpha:mean=0.509,std=0.084 shared2_alpha:mean=0.580,std=0.077 shared3_alpha:mean=0.640,std=0.081 eff_mlp_scale:[v0:196.2220 v1:94.8053 v2:92.8046 v3:96.8025 v4:102.7437 v5:116.0458 v6:101.3712 v7:101.7164 v8:108.8859 v9:104.1304 v10:83.2862 v11:85.9921 v12:103.8604 v13:267.0926] eff_attn_scale:[v0:0.3047 v1:6.8068 v2:13.7498 v3:14.4345 v4:15.7633 v5:12.3569 v6:13.2863 v7:14.2750 v8:15.6777 v9:13.1248 v10:11.7414 v11:12.4408 v12:11.1371 v13:1.7278] eff_attn_bias:[v0:0.4944 v1:0.9281 v2:0.6712 v3:0.7347 v4:0.7734 v5:0.8397 v6:0.7679 v7:0.7623 v8:0.7623 v9:0.7458 v10:0.6132 v11:0.5607 v12:0.5966 v13:0.9005] eff_mlp_bias:[v0:1.3755 v1:0.9005 v2:0.6270 v3:0.7182 v4:0.6795 v5:0.8452 v6:0.5635 v7:0.5966 v8:0.6767 v9:0.7237 v10:0.4944 v11:0.5994 v12:0.8065 v13:0.7237] unique_attn_gain_rms:[u0:1.6815 u1:1.6140 u2:1.5338 
u3:1.3236 u4:1.6178 u5:1.4204 u6:1.3365 u7:1.2672 u8:1.3845 u9:1.2255 u10:1.1181 u11:1.2271] unique_mlp_gain_rms:[u0:3.6749 u1:3.5302 u2:3.3378 u3:3.2452 u4:3.1569 u5:3.0550 u6:3.0151 u7:3.0836 u8:3.1582 u9:3.0799 u10:3.0229 u11:3.0662] +step:7800/20000 val_loss:2.1749 val_bpb:1.2881 train_time:482641ms step_avg:61.88ms +step:8000/20000 train_loss:2.1628 train_time:494965ms step_avg:61.87ms +step:8000 shared0_alpha:mean=0.469,std=0.104 shared1_alpha:mean=0.509,std=0.084 shared2_alpha:mean=0.580,std=0.078 shared3_alpha:mean=0.640,std=0.082 eff_mlp_scale:[v0:200.9914 v1:96.2641 v2:93.7161 v3:98.6042 v4:104.3431 v5:117.5983 v6:102.8008 v7:103.0637 v8:110.5473 v9:105.6303 v10:84.1532 v11:86.7123 v12:105.4711 v13:270.7445] eff_attn_scale:[v0:0.3076 v1:6.8133 v2:13.6782 v3:14.5951 v4:16.0399 v5:12.5085 v6:13.3691 v7:14.4347 v8:15.9532 v9:13.2073 v10:11.8235 v11:12.5101 v12:11.2713 v13:1.6949] eff_attn_bias:[v0:0.4944 v1:0.9391 v2:0.6822 v3:0.7458 v4:0.7844 v5:0.8452 v6:0.7734 v7:0.7679 v8:0.7679 v9:0.7513 v10:0.6160 v11:0.5662 v12:0.6021 v13:0.9115] eff_mlp_bias:[v0:1.3866 v1:0.9060 v2:0.6325 v3:0.7292 v4:0.6850 v5:0.8563 v6:0.5662 v7:0.6049 v8:0.6822 v9:0.7347 v10:0.4999 v11:0.6077 v12:0.8176 v13:0.7347] unique_attn_gain_rms:[u0:1.7004 u1:1.6233 u2:1.5463 u3:1.3391 u4:1.6370 u5:1.4320 u6:1.3502 u7:1.2816 u8:1.4021 u9:1.2370 u10:1.1296 u11:1.2467] unique_mlp_gain_rms:[u0:3.7251 u1:3.5767 u2:3.3832 u3:3.2946 u4:3.2056 u5:3.0994 u6:3.0573 u7:3.1303 u8:3.2057 u9:3.1216 u10:3.0647 u11:3.1120] +step:8000/20000 val_loss:2.1739 val_bpb:1.2875 train_time:494992ms step_avg:61.87ms +step:8200/20000 train_loss:2.2298 train_time:507316ms step_avg:61.87ms +step:8200 shared0_alpha:mean=0.468,std=0.104 shared1_alpha:mean=0.509,std=0.084 shared2_alpha:mean=0.581,std=0.078 shared3_alpha:mean=0.641,std=0.082 eff_mlp_scale:[v0:205.4538 v1:97.2557 v2:94.6590 v3:99.7208 v4:105.8586 v5:118.6938 v6:103.3081 v7:104.7068 v8:112.6882 v9:106.1447 v10:85.0490 v11:87.7543 v12:106.9968 v13:274.8280] 
eff_attn_scale:[v0:0.2969 v1:6.9014 v2:13.9446 v3:14.6836 v4:16.3361 v5:12.6056 v6:13.5551 v7:14.5223 v8:16.2483 v9:13.3803 v10:11.9970 v11:12.6667 v12:11.4177 v13:1.6807] eff_attn_bias:[v0:0.4999 v1:0.9502 v2:0.6933 v3:0.7568 v4:0.7900 v5:0.8563 v6:0.7789 v7:0.7734 v8:0.7789 v9:0.7623 v10:0.6215 v11:0.5745 v12:0.6104 v13:0.9281] eff_mlp_bias:[v0:1.3976 v1:0.9170 v2:0.6353 v3:0.7347 v4:0.6933 v5:0.8618 v6:0.5718 v7:0.6104 v8:0.6878 v9:0.7403 v10:0.5055 v11:0.6132 v12:0.8286 v13:0.7458] unique_attn_gain_rms:[u0:1.7164 u1:1.6340 u2:1.5565 u3:1.3485 u4:1.6526 u5:1.4448 u6:1.3613 u7:1.2925 u8:1.4144 u9:1.2497 u10:1.1410 u11:1.2605] unique_mlp_gain_rms:[u0:3.7774 u1:3.6214 u2:3.4318 u3:3.3357 u4:3.2537 u5:3.1386 u6:3.0972 u7:3.1722 u8:3.2516 u9:3.1654 u10:3.1066 u11:3.1525] +step:8200/20000 val_loss:2.1715 val_bpb:1.2861 train_time:507343ms step_avg:61.87ms +step:8400/20000 train_loss:2.1888 train_time:519732ms step_avg:61.87ms +step:8400 shared0_alpha:mean=0.467,std=0.105 shared1_alpha:mean=0.509,std=0.084 shared2_alpha:mean=0.581,std=0.078 shared3_alpha:mean=0.641,std=0.083 eff_mlp_scale:[v0:209.8242 v1:98.2429 v2:95.5665 v3:100.9536 v4:107.4673 v5:119.7828 v6:104.7370 v7:105.4739 v8:114.3636 v9:107.1741 v10:85.9133 v11:88.8994 v12:109.1914 v13:276.6565] eff_attn_scale:[v0:0.2981 v1:6.9949 v2:14.0907 v3:14.9323 v4:16.5971 v5:12.8536 v6:13.7758 v7:14.7691 v8:16.5079 v9:13.5637 v10:12.2014 v11:12.8924 v12:11.6894 v13:1.6841] eff_attn_bias:[v0:0.4972 v1:0.9612 v2:0.6988 v3:0.7623 v4:0.8010 v5:0.8618 v6:0.7844 v7:0.7844 v8:0.7844 v9:0.7679 v10:0.6298 v11:0.5800 v12:0.6160 v13:0.9391] eff_mlp_bias:[v0:1.4087 v1:0.9226 v2:0.6408 v3:0.7458 v4:0.7043 v5:0.8784 v6:0.5800 v7:0.6187 v8:0.6961 v9:0.7513 v10:0.5110 v11:0.6215 v12:0.8342 v13:0.7513] unique_attn_gain_rms:[u0:1.7364 u1:1.6470 u2:1.5733 u3:1.3620 u4:1.6639 u5:1.4561 u6:1.3773 u7:1.2999 u8:1.4339 u9:1.2647 u10:1.1560 u11:1.2756] unique_mlp_gain_rms:[u0:3.8261 u1:3.6667 u2:3.4787 u3:3.3780 u4:3.2950 u5:3.1773 u6:3.1398 
u7:3.2127 u8:3.2987 u9:3.2069 u10:3.1517 u11:3.1949] +step:8400/20000 val_loss:2.1717 val_bpb:1.2862 train_time:519758ms step_avg:61.88ms +step:8600/20000 train_loss:2.1799 train_time:532083ms step_avg:61.87ms +step:8600 shared0_alpha:mean=0.467,std=0.105 shared1_alpha:mean=0.509,std=0.085 shared2_alpha:mean=0.581,std=0.079 shared3_alpha:mean=0.641,std=0.082 eff_mlp_scale:[v0:213.7842 v1:99.1448 v2:96.8693 v3:102.1643 v4:109.5035 v5:120.7668 v6:105.5876 v7:107.2219 v8:115.8767 v9:108.1100 v10:86.6981 v11:90.0260 v12:110.0829 v13:280.2598] eff_attn_scale:[v0:0.2972 v1:7.0551 v2:14.2450 v3:15.1629 v4:16.9693 v5:13.0359 v6:13.9285 v7:14.9981 v8:16.8785 v9:13.8238 v10:12.3457 v11:13.1027 v12:11.8876 v13:1.6699] eff_attn_bias:[v0:0.5027 v1:0.9723 v2:0.7043 v3:0.7734 v4:0.8065 v5:0.8728 v6:0.7900 v7:0.7900 v8:0.7900 v9:0.7734 v10:0.6325 v11:0.5828 v12:0.6215 v13:0.9447] eff_mlp_bias:[v0:1.4363 v1:0.9336 v2:0.6436 v3:0.7568 v4:0.7071 v5:0.8839 v6:0.5856 v7:0.6242 v8:0.6988 v9:0.7568 v10:0.5138 v11:0.6242 v12:0.8397 v13:0.7623] unique_attn_gain_rms:[u0:1.7511 u1:1.6527 u2:1.5882 u3:1.3700 u4:1.6801 u5:1.4715 u6:1.3872 u7:1.3116 u8:1.4430 u9:1.2765 u10:1.1703 u11:1.2919] unique_mlp_gain_rms:[u0:3.8749 u1:3.7170 u2:3.5225 u3:3.4271 u4:3.3362 u5:3.2167 u6:3.1806 u7:3.2540 u8:3.3394 u9:3.2481 u10:3.1937 u11:3.2389] +step:8600/20000 val_loss:2.1660 val_bpb:1.2828 train_time:532111ms step_avg:61.87ms +step:8800/20000 train_loss:2.1465 train_time:544430ms step_avg:61.87ms +step:8800 shared0_alpha:mean=0.466,std=0.105 shared1_alpha:mean=0.508,std=0.085 shared2_alpha:mean=0.581,std=0.079 shared3_alpha:mean=0.641,std=0.083 eff_mlp_scale:[v0:218.5015 v1:100.2315 v2:97.8654 v3:103.5712 v4:111.2234 v5:121.9748 v6:106.6295 v7:108.6732 v8:117.6627 v9:109.2470 v10:87.6406 v11:91.3263 v12:112.3942 v13:282.2262] eff_attn_scale:[v0:0.2898 v1:7.0479 v2:14.4904 v3:15.4139 v4:17.3578 v5:13.2284 v6:14.0901 v7:15.2472 v8:17.3578 v9:14.0236 v10:12.5690 v11:13.3309 v12:12.1874 v13:1.7146] 
eff_attn_bias:[v0:0.4972 v1:0.9778 v2:0.7126 v3:0.7789 v4:0.8121 v5:0.8784 v6:0.7955 v7:0.7955 v8:0.7955 v9:0.7789 v10:0.6408 v11:0.5883 v12:0.6270 v13:0.9557] eff_mlp_bias:[v0:1.4474 v1:0.9447 v2:0.6519 v3:0.7623 v4:0.7182 v5:0.8894 v6:0.5883 v7:0.6298 v8:0.7071 v9:0.7679 v10:0.5193 v11:0.6298 v12:0.8507 v13:0.7734] unique_attn_gain_rms:[u0:1.7622 u1:1.6593 u2:1.5995 u3:1.3829 u4:1.6976 u5:1.4740 u6:1.3976 u7:1.3237 u8:1.4531 u9:1.2800 u10:1.1787 u11:1.3060] unique_mlp_gain_rms:[u0:3.9091 u1:3.7562 u2:3.5547 u3:3.4607 u4:3.3689 u5:3.2478 u6:3.2131 u7:3.2863 u8:3.3722 u9:3.2798 u10:3.2248 u11:3.2692] +step:8800/20000 val_loss:2.1579 val_bpb:1.2780 train_time:544456ms step_avg:61.87ms +step:9000/20000 train_loss:2.0555 train_time:556791ms step_avg:61.87ms +step:9000 shared0_alpha:mean=0.466,std=0.104 shared1_alpha:mean=0.509,std=0.084 shared2_alpha:mean=0.582,std=0.079 shared3_alpha:mean=0.642,std=0.084 eff_mlp_scale:[v0:221.2928 v1:101.2735 v2:98.9843 v3:104.9388 v4:113.1252 v5:123.1273 v6:107.8047 v7:109.5684 v8:119.0480 v9:110.3349 v10:88.6939 v11:92.5930 v12:113.7175 v13:286.6594] eff_attn_scale:[v0:0.2904 v1:7.0265 v2:14.7348 v3:15.5765 v4:17.6180 v5:13.3977 v6:14.3300 v7:15.4923 v8:17.7117 v9:14.1986 v10:12.7918 v11:13.5558 v12:12.4638 v13:1.7337] eff_attn_bias:[v0:0.4972 v1:0.9833 v2:0.7126 v3:0.7844 v4:0.8121 v5:0.8839 v6:0.8010 v7:0.8010 v8:0.7955 v9:0.7789 v10:0.6436 v11:0.5939 v12:0.6298 v13:0.9612] eff_mlp_bias:[v0:1.4474 v1:0.9502 v2:0.6546 v3:0.7679 v4:0.7237 v5:0.8949 v6:0.5911 v7:0.6353 v8:0.7126 v9:0.7679 v10:0.5220 v11:0.6381 v12:0.8563 v13:0.7789] unique_attn_gain_rms:[u0:1.7678 u1:1.6594 u2:1.6085 u3:1.3857 u4:1.7035 u5:1.4802 u6:1.4055 u7:1.3284 u8:1.4574 u9:1.2878 u10:1.1864 u11:1.3180] unique_mlp_gain_rms:[u0:3.9347 u1:3.7792 u2:3.5799 u3:3.4799 u4:3.3955 u5:3.2696 u6:3.2356 u7:3.3103 u8:3.3994 u9:3.3038 u10:3.2485 u11:3.2922] +step:9000/20000 val_loss:2.1502 val_bpb:1.2735 train_time:556817ms step_avg:61.87ms +step:9200/20000 
train_loss:2.1079 train_time:569139ms step_avg:61.86ms +step:9200 shared0_alpha:mean=0.466,std=0.104 shared1_alpha:mean=0.509,std=0.085 shared2_alpha:mean=0.582,std=0.079 shared3_alpha:mean=0.642,std=0.084 eff_mlp_scale:[v0:222.6093 v1:102.1077 v2:99.5407 v3:106.4178 v4:114.2671 v5:124.6789 v6:108.4107 v7:111.0898 v8:120.8479 v9:111.2437 v10:89.1924 v11:93.4400 v12:115.4636 v13:288.8534] eff_attn_scale:[v0:0.2901 v1:7.1061 v2:14.8157 v3:15.7078 v4:17.8935 v5:13.4796 v6:14.4901 v7:15.6229 v8:17.9882 v9:14.3587 v10:12.8620 v11:13.7550 v12:12.5917 v13:1.7671] eff_attn_bias:[v0:0.4972 v1:0.9833 v2:0.7182 v3:0.7844 v4:0.8176 v5:0.8894 v6:0.8065 v7:0.8010 v8:0.8010 v9:0.7844 v10:0.6436 v11:0.5966 v12:0.6298 v13:0.9612] eff_mlp_bias:[v0:1.4474 v1:0.9557 v2:0.6602 v3:0.7734 v4:0.7237 v5:0.9005 v6:0.5966 v7:0.6381 v8:0.7182 v9:0.7734 v10:0.5248 v11:0.6381 v12:0.8563 v13:0.7844] unique_attn_gain_rms:[u0:1.7707 u1:1.6634 u2:1.6115 u3:1.3922 u4:1.7037 u5:1.4839 u6:1.4061 u7:1.3326 u8:1.4630 u9:1.2865 u10:1.1879 u11:1.3235] unique_mlp_gain_rms:[u0:3.9502 u1:3.7919 u2:3.5952 u3:3.4972 u4:3.4104 u5:3.2866 u6:3.2495 u7:3.3245 u8:3.4162 u9:3.3178 u10:3.2630 u11:3.3066] +step:9200/20000 val_loss:2.1402 val_bpb:1.2676 train_time:569165ms step_avg:61.87ms +step:9400/20000 train_loss:2.1505 train_time:581492ms step_avg:61.86ms +step:9400 shared0_alpha:mean=0.467,std=0.104 shared1_alpha:mean=0.509,std=0.084 shared2_alpha:mean=0.582,std=0.079 shared3_alpha:mean=0.641,std=0.084 eff_mlp_scale:[v0:223.0637 v1:102.5551 v2:100.1531 v3:107.1373 v4:115.2038 v5:125.2251 v6:109.5734 v7:111.8409 v8:121.8386 v9:111.7310 v10:89.7411 v11:94.0718 v12:116.4101 v13:290.5843] eff_attn_scale:[v0:0.2871 v1:7.0568 v2:14.7478 v3:15.7777 v4:17.9724 v5:13.4554 v6:14.4219 v7:15.6924 v8:18.0675 v9:14.2598 v10:12.8737 v11:13.8162 v12:12.6472 v13:1.7714] eff_attn_bias:[v0:0.4944 v1:0.9833 v2:0.7182 v3:0.7844 v4:0.8231 v5:0.8894 v6:0.8065 v7:0.8010 v8:0.8010 v9:0.7844 v10:0.6463 v11:0.5966 v12:0.6325 v13:0.9667] 
eff_mlp_bias:[v0:1.4474 v1:0.9557 v2:0.6629 v3:0.7734 v4:0.7292 v5:0.9005 v6:0.5966 v7:0.6408 v8:0.7182 v9:0.7734 v10:0.5248 v11:0.6381 v12:0.8618 v13:0.7900] unique_attn_gain_rms:[u0:1.7682 u1:1.6661 u2:1.6120 u3:1.3908 u4:1.7038 u5:1.4812 u6:1.4088 u7:1.3334 u8:1.4590 u9:1.2860 u10:1.1880 u11:1.3229] unique_mlp_gain_rms:[u0:3.9580 u1:3.7982 u2:3.6017 u3:3.5009 u4:3.4170 u5:3.2937 u6:3.2562 u7:3.3334 u8:3.4238 u9:3.3244 u10:3.2703 u11:3.3143] +step:9400/20000 val_loss:2.1313 val_bpb:1.2623 train_time:581519ms step_avg:61.86ms +step:9600/20000 train_loss:2.1575 train_time:593944ms step_avg:61.87ms +step:9600 shared0_alpha:mean=0.467,std=0.104 shared1_alpha:mean=0.509,std=0.084 shared2_alpha:mean=0.582,std=0.079 shared3_alpha:mean=0.641,std=0.084 eff_mlp_scale:[v0:223.2951 v1:102.9151 v2:100.4668 v3:107.6653 v4:115.7006 v5:125.6647 v6:109.9167 v7:112.3921 v8:122.3640 v9:112.1233 v10:90.0223 v11:94.5354 v12:116.9122 v13:291.7530] eff_attn_scale:[v0:0.2903 v1:7.0599 v2:14.6930 v3:15.7568 v4:17.9826 v5:13.3883 v6:14.4495 v7:15.6717 v8:18.0778 v9:14.2662 v10:12.8260 v11:13.7127 v12:12.6544 v13:1.7743] eff_attn_bias:[v0:0.4944 v1:0.9833 v2:0.7237 v3:0.7844 v4:0.8176 v5:0.8894 v6:0.8065 v7:0.8010 v8:0.8010 v9:0.7844 v10:0.6463 v11:0.5966 v12:0.6325 v13:0.9667] eff_mlp_bias:[v0:1.4474 v1:0.9557 v2:0.6629 v3:0.7734 v4:0.7292 v5:0.9005 v6:0.5966 v7:0.6408 v8:0.7182 v9:0.7734 v10:0.5248 v11:0.6408 v12:0.8618 v13:0.7900] unique_attn_gain_rms:[u0:1.7658 u1:1.6631 u2:1.6109 u3:1.3939 u4:1.7024 u5:1.4798 u6:1.4090 u7:1.3330 u8:1.4604 u9:1.2835 u10:1.1884 u11:1.3219] unique_mlp_gain_rms:[u0:3.9603 u1:3.8007 u2:3.6025 u3:3.5026 u4:3.4203 u5:3.2961 u6:3.2579 u7:3.3354 u8:3.4261 u9:3.3264 u10:3.2734 u11:3.3164] +step:9600/20000 val_loss:2.1234 val_bpb:1.2576 train_time:593971ms step_avg:61.87ms +step:9699/20000 val_loss:2.1208 val_bpb:1.2560 train_time:600048ms step_avg:61.87ms +stopping_early: wallclock_cap train_time:600048ms step:9699/20000 +peak memory allocated: 14443 MiB 
reserved: 14588 MiB +Serialized model: 45258046 bytes +Code size: 63793 bytes +Total submission size: 45321839 bytes +Serialized model int8+zlib: 10780366 bytes (payload:11716800 raw_torch:11749167 payload_ratio:3.86x) +Total submission size int8+zlib: 10844159 bytes +final_int8_zlib_roundtrip val_loss:2.1324 val_bpb:1.2629 eval_time:1887ms +final_int8_zlib_roundtrip_exact val_loss:2.13243085 val_bpb:1.26294566 diff --git a/records/track_non_record_16mb/2026-03-26_RecurrenceFix_3Loop_Birkhoff_OutputLN_TimestepScale/research_notes.md b/records/track_non_record_16mb/2026-03-26_RecurrenceFix_3Loop_Birkhoff_OutputLN_TimestepScale/research_notes.md index 215c3fce5f..4cd3b47ebb 100644 --- a/records/track_non_record_16mb/2026-03-26_RecurrenceFix_3Loop_Birkhoff_OutputLN_TimestepScale/research_notes.md +++ b/records/track_non_record_16mb/2026-03-26_RecurrenceFix_3Loop_Birkhoff_OutputLN_TimestepScale/research_notes.md @@ -185,3 +185,54 @@ This leaves ~4.8MB headroom (from an estimated ~11.2MB artifact) for SOTA featur > Csordás, R., Irie, K., Schmidhuber, J., Potts, C. & Manning, C. (2024). "MoEUT: Mixture-of-Experts Universal Transformers." NeurIPS 2024. [arXiv:2405.16039](https://arxiv.org/abs/2405.16039) > Ben-Zaken, E., Goldberg, Y. & Ravfogel, S. (2022). "BitFit: Simple Parameter-efficient Fine-tuning for Transformer-based Masked Language-models." ACL 2022. [arXiv:2106.10199](https://arxiv.org/abs/2106.10199) > Bae, S., Ko, J., Song, H. & Yun, S.-Y. (2025). "Relaxed Recursive Transformers: Effective Parameter Sharing with Layer-wise LoRA." ICLR 2025. [arXiv:2410.20672](https://arxiv.org/abs/2410.20672) + +## 10. Series 4 Results: Learned Depth Embeddings and Unique Input Norms + +**Hypothesis.** Per-iteration MLP differentiation is critical (s3_L: −0.026 BPB), but unique MLPs are too expensive (~4MB). Cheap per-iteration input controls — learned depth embeddings and unique input norms — should recover much of the benefit at ~110KB total cost (§9). 
+ +**Experiment.** Four runs tested combinations of learned depth embeddings (512-dim vectors added to input before attention and MLP, initialized to zeros) and unique input norms (per-iteration RMSNorm parameters for attention and MLP inputs), all with FiLM bias enabled: + +| Run | Config | Eff. Layers | Pre-Q BPB | Post-Q BPB | Q-Gap | +|-----|--------|-------------|-----------|------------|-------| +| P | 1+4×2+1 learned depth+norms+bias | 10 | 1.2579 | 1.2663 | +0.0084 | +| Q | 1+4×3+1 learned depth+norms+bias | 14 | 1.2574 | 1.2643 | +0.0069 | +| R | 1+4×3+1 learned depth only+bias | 14 | 1.2566 | 1.2639 | +0.0073 | +| S | 1+4×3+1 norms only+bias | 14 | 1.2560 | 1.2629 | +0.0069 | + +**Baseline comparison.** FiLM bias alone (s3_O: 1.2625, Q-gap +0.0078) outperforms all Series 4 runs on post-Q BPB. The best Series 4 result (Run S: 1.2629) is +0.0004 worse than s3_O despite having additional parameters. + +**Negative result: neither technique improves over FiLM bias alone.** This was unexpected given the §9 analysis. Three factors explain the failure: + +### 10a. Learned Depth Embeddings Remained Near Zero + +Depth embeddings were initialized to zeros and learned during training. After full-scale training (600s, 8×H100), embedding RMS values were 0.006–0.010 — essentially still near initialization. The embeddings did not learn meaningful per-iteration identity within the training budget. This contrasts with the theoretical expectation: Xu & Sato (2025, Theorem 4.2) prove that timestep encoding closes the expressivity gap, but their analysis assumes the encodings carry sufficient signal. Near-zero learned embeddings provide negligible signal. + +The root cause is likely the interaction of initialization (zeros), learning rate, and training duration. With ~10k optimization steps and a cosine-decayed learning rate, small gradients on the depth embeddings never accumulated enough to produce meaningful values. 
This is a fundamental limitation of learned embeddings in short-training regimes. + +### 10b. Throughput Penalty as Primary Harm Mechanism + +The additional per-iteration parameters introduced 6–15% throughput overhead, costing 600–1700 training steps compared to the FiLM-bias-only baseline. In a wallclock-capped competition (600s), fewer steps means less training. The marginal specialization benefit of near-zero depth embeddings and barely-differentiated norms was overwhelmed by the lost training steps. + +### 10c. Unique Input Norms Failed to Differentiate + +Unique per-iteration RMSNorm parameters were expected to control what the shared MLP sees at each iteration (§9, "control the input, not the weights"). In practice, the MLP gains barely moved from 1.0 across iterations, indicating that the norms did not learn meaningfully different scaling. The Output-LN architecture already provides magnitude differentiation by letting the MLP see unnormalized inputs (§3). Adding per-iteration input norms on top of this provided no additional differentiation — the mechanism was redundant with Output-LN. + +### 10d. Positive Finding: Q-Gap Improvement + +Despite hurting BPB, the additional per-iteration parameters did reduce quantization gap: Q-gap 0.0069–0.0073 across Series 4 runs, vs 0.0078 for FiLM bias alone (s3_O). The extra FP16 passthrough parameters provide more degrees of freedom that survive int8 quantization. However, this Q-gap benefit is insufficient to overcome the BPB regression from throughput loss. + +### 10e. Implication: The MLP Needs Different Weights, Not Different Inputs + +The 0.026 BPB gap between full sharing (s2_I: 1.2668) and unique MLPs (s3_L: 1.2406) cannot be closed by cheap per-iteration input controls. Runs P–S demonstrate that even with depth embeddings, unique norms, and FiLM conditioning combined, the shared MLP produces nearly identical outputs across iterations. 
The MLP genuinely needs different weights per iteration — not just different inputs — to achieve the specialization that s3_L demonstrated. + +### 10f. Next Test: Sinusoidal Depth Encodings + +The natural next experiment is replacing learned depth embeddings with **sinusoidal depth encodings** following the Universal Transformer (Dehghani et al., 2019, §2.1). Sinusoidal encodings address all three failure modes of learned embeddings: + +1. **Full-strength from step 0:** Fixed sinusoidal patterns provide immediate per-iteration identity without needing to be learned. No dependence on learning rate, initialization, or training duration. +2. **Zero parameter cost:** Computed analytically, not stored. Zero artifact overhead. +3. **Zero throughput overhead:** No additional parameters to backpropagate through. No training step penalty. + +The Universal Transformer adds sinusoidal timestep embeddings $T_t$ at each recurrence step, where $T_t$ uses the same sinusoidal formula as positional encodings but indexed by iteration count rather than sequence position. This provides orthogonal identity signals across iterations with bounded magnitude. + +> Dehghani, M., Gouws, S., Vinyals, O., Uszkoreit, J. & Kaiser, L. (2019). "Universal Transformers." ICLR 2019. 
[arXiv:1807.03819](https://arxiv.org/abs/1807.03819) diff --git a/records/track_non_record_16mb/2026-03-26_RecurrenceFix_3Loop_Birkhoff_OutputLN_TimestepScale/scripts/run_fullscale3.sh b/records/track_non_record_16mb/2026-03-26_RecurrenceFix_3Loop_Birkhoff_OutputLN_TimestepScale/scripts/run_fullscale3.sh new file mode 100755 index 0000000000..a09917de50 --- /dev/null +++ b/records/track_non_record_16mb/2026-03-26_RecurrenceFix_3Loop_Birkhoff_OutputLN_TimestepScale/scripts/run_fullscale3.sh @@ -0,0 +1,61 @@ +#!/bin/bash +set -uo pipefail + +SCRIPT="train_gpt.py" +NGPU=${NGPU:-8} +COMMON="SEED=1337 MAX_WALLCLOCK_SECONDS=600 VAL_LOSS_EVERY=200 TRAIN_LOG_EVERY=200" +DATA="DATA_PATH=${DATA_PATH:-./data/datasets/fineweb10B_sp1024} TOKENIZER_PATH=./data/tokenizers/fineweb_1024_bpe.model VOCAB_SIZE=1024" + +FAILS=0 +SUMMARY="" + +run_experiment() { + local name="$1"; shift + echo "" + echo "=== $name ===" + if "$@"; then + SUMMARY="${SUMMARY} PASS $name"$'\n' + else + SUMMARY="${SUMMARY} FAIL $name (exit $?)"$'\n' + FAILS=$((FAILS + 1)) + fi +} + +# --- P: 1+4×2+1, full passthrough stack (depth embed + unique norms), 2 loops (compare vs s3_N) --- + +run_experiment "Run P: 1+4x2+1 full passthrough stack (2 loops)" \ + env $COMMON $DATA RUN_ID=s4_P NUM_PRELUDE=1 NUM_SHARED=4 NUM_LOOPS=2 NUM_CODA=1 \ + USE_PERI_NORM=1 USE_BIRKHOFF_MIX=1 USE_TIMESTEP_SCALE=1 TIMESTEP_GAMMA_MAX=4.0 \ + USE_TIMESTEP_BIAS=1 USE_DEPTH_EMBED=1 USE_UNIQUE_NORMS=1 \ + torchrun --standalone --nproc_per_node=$NGPU $SCRIPT + +# --- Q: 1+4×3+1, full passthrough stack (depth embed + unique norms), 3 loops (compare vs s3_O) --- + +run_experiment "Run Q: 1+4x3+1 full passthrough stack (3 loops)" \ + env $COMMON $DATA RUN_ID=s4_Q NUM_PRELUDE=1 NUM_SHARED=4 NUM_LOOPS=3 NUM_CODA=1 \ + USE_PERI_NORM=1 USE_BIRKHOFF_MIX=1 USE_TIMESTEP_SCALE=1 TIMESTEP_GAMMA_MAX=4.0 \ + USE_TIMESTEP_BIAS=1 USE_DEPTH_EMBED=1 USE_UNIQUE_NORMS=1 \ + torchrun --standalone --nproc_per_node=$NGPU $SCRIPT + +# --- R: 1+4×3+1, depth embed 
only (no unique norms), 3 loops (isolate depth embed, compare vs s3_O) --- + +run_experiment "Run R: 1+4x3+1 depth embed only (3 loops)" \ + env $COMMON $DATA RUN_ID=s4_R NUM_PRELUDE=1 NUM_SHARED=4 NUM_LOOPS=3 NUM_CODA=1 \ + USE_PERI_NORM=1 USE_BIRKHOFF_MIX=1 USE_TIMESTEP_SCALE=1 TIMESTEP_GAMMA_MAX=4.0 \ + USE_TIMESTEP_BIAS=1 USE_DEPTH_EMBED=1 USE_UNIQUE_NORMS=0 \ + torchrun --standalone --nproc_per_node=$NGPU $SCRIPT + +# --- S: 1+4×3+1, unique norms only (no depth embed), 3 loops (isolate unique norms, compare vs s3_O) --- + +run_experiment "Run S: 1+4x3+1 unique norms only (3 loops)" \ + env $COMMON $DATA RUN_ID=s4_S NUM_PRELUDE=1 NUM_SHARED=4 NUM_LOOPS=3 NUM_CODA=1 \ + USE_PERI_NORM=1 USE_BIRKHOFF_MIX=1 USE_TIMESTEP_SCALE=1 TIMESTEP_GAMMA_MAX=4.0 \ + USE_TIMESTEP_BIAS=1 USE_DEPTH_EMBED=0 USE_UNIQUE_NORMS=1 \ + torchrun --standalone --nproc_per_node=$NGPU $SCRIPT + +echo "" +echo "===============================" +echo " FULL-SCALE 3 SUMMARY" +echo "===============================" +echo "$SUMMARY" +echo "$FAILS run(s) failed." 
diff --git a/records/track_non_record_16mb/2026-03-26_RecurrenceFix_3Loop_Birkhoff_OutputLN_TimestepScale/train_gpt.py b/records/track_non_record_16mb/2026-03-26_RecurrenceFix_3Loop_Birkhoff_OutputLN_TimestepScale/train_gpt.py index 4036333b31..1597f32bd0 100644 --- a/records/track_non_record_16mb/2026-03-26_RecurrenceFix_3Loop_Birkhoff_OutputLN_TimestepScale/train_gpt.py +++ b/records/track_non_record_16mb/2026-03-26_RecurrenceFix_3Loop_Birkhoff_OutputLN_TimestepScale/train_gpt.py @@ -91,7 +91,8 @@ class Hyperparameters: leaky_relu_slope = float(os.environ.get("LEAKY_RELU_SLOPE", "0.5")) timestep_gamma_max = float(os.environ.get("TIMESTEP_GAMMA_MAX", "0.0")) use_timestep_bias = bool(int(os.environ.get("USE_TIMESTEP_BIAS", "0"))) - share_attn_only = bool(int(os.environ.get("SHARE_ATTN_ONLY", "0"))) + use_depth_embed = bool(int(os.environ.get("USE_DEPTH_EMBED", "0"))) + use_unique_norms = bool(int(os.environ.get("USE_UNIQUE_NORMS", "0"))) disable_compile = bool(int(os.environ.get("DISABLE_COMPILE", "0"))) # Optimizer hyperparameters. 
@@ -313,7 +314,7 @@ def eval_val( pattern for pattern in os.environ.get( "CONTROL_TENSOR_NAME_PATTERNS", - "attn_scale,attn_scales,mlp_scale,mlp_scales,resid_mix,resid_mixes,q_gain,skip_weight,skip_weights,ts_attn,ts_mlp,resid_mix_logit,ts_attn_beta,ts_mlp_beta", + "attn_scale,attn_scales,mlp_scale,mlp_scales,resid_mix,resid_mixes,q_gain,skip_weight,skip_weights,ts_attn,ts_mlp,resid_mix_logit,ts_attn_beta,ts_mlp_beta,depth_embed,unique_attn_gain,unique_mlp_gain", ).split(",") if pattern ) @@ -674,77 +675,6 @@ def get(self, v: int) -> tuple[Tensor, Tensor, Tensor | None, Tensor | None]: return ag, mg, ab, mb -class SharedAttnLayer(nn.Module): - """Shared attention layer (mixing + attention only, no MLP) for attn-only sharing mode.""" - def __init__( - self, - dim: int, - num_heads: int, - num_kv_heads: int, - rope_base: float, - qk_gain_init: float, - use_birkhoff_mix: bool = False, - ): - super().__init__() - self.use_birkhoff_mix = use_birkhoff_mix - self.attn_norm = RMSNorm() - self.attn = CausalSelfAttention(dim, num_heads, num_kv_heads, rope_base, qk_gain_init) - self.attn_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) - if use_birkhoff_mix: - self.resid_mix_logit = nn.Parameter(torch.zeros(dim, dtype=torch.float32)) - else: - self.resid_mix = nn.Parameter(torch.stack((torch.ones(dim), torch.zeros(dim))).float()) - - def forward(self, x: Tensor, x0: Tensor, - ts_attn_gamma: Tensor | None = None, - ts_attn_beta: Tensor | None = None) -> Tensor: - if self.use_birkhoff_mix: - alpha = torch.sigmoid(self.resid_mix_logit.to(dtype=x.dtype))[None, None, :] - x = alpha * x + (1 - alpha) * x0 - else: - mix = self.resid_mix.to(dtype=x.dtype) - x = mix[0][None, None, :] * x + mix[1][None, None, :] * x0 - attn_out = self.attn(self.attn_norm(x)) - attn_s = self.attn_scale.to(dtype=x.dtype)[None, None, :] - if ts_attn_gamma is not None: - attn_s = attn_s * ts_attn_gamma[None, None, :] - x = x + attn_s * attn_out - if ts_attn_beta is not None: - x = x + 
ts_attn_beta[None, None, :] - return x - - -class UniqueMLP(nn.Module): - """Unique MLP per virtual shared position for attn-only sharing mode.""" - def __init__( - self, - dim: int, - mlp_mult: int, - use_peri_norm: bool = False, - leaky_relu_slope: float = 0.5, - ): - super().__init__() - self.use_peri_norm = use_peri_norm - if use_peri_norm: - self.mlp_out_norm = RMSNorm() - else: - self.mlp_norm = RMSNorm() - self.mlp = MLP(dim, mlp_mult, leaky_relu_slope) - self.mlp_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) - - def forward(self, x: Tensor, - ts_mlp_gamma: Tensor | None = None, - ts_mlp_beta: Tensor | None = None) -> Tensor: - mlp_out = self.mlp_out_norm(self.mlp(x)) if self.use_peri_norm else self.mlp(self.mlp_norm(x)) - mlp_s = self.mlp_scale.to(dtype=x.dtype)[None, None, :] - if ts_mlp_gamma is not None: - mlp_s = mlp_s * ts_mlp_gamma[None, None, :] - out = mlp_s * mlp_out - if ts_mlp_beta is not None: - out = out + ts_mlp_beta[None, None, :] - return out - - class Block(nn.Module): def __init__( self, @@ -779,21 +709,39 @@ def forward(self, x: Tensor, x0: Tensor, ts_attn_gamma: Tensor | None = None, ts_mlp_gamma: Tensor | None = None, ts_attn_beta: Tensor | None = None, - ts_mlp_beta: Tensor | None = None) -> Tensor: + ts_mlp_beta: Tensor | None = None, + depth_emb: Tensor | None = None, + ext_attn_gain: Tensor | None = None, + ext_mlp_gain: Tensor | None = None) -> Tensor: if self.use_birkhoff_mix: alpha = torch.sigmoid(self.resid_mix_logit.to(dtype=x.dtype))[None, None, :] x = alpha * x + (1 - alpha) * x0 else: mix = self.resid_mix.to(dtype=x.dtype) x = mix[0][None, None, :] * x + mix[1][None, None, :] * x0 - attn_out = self.attn(self.attn_norm(x)) + if depth_emb is not None: + x = x + depth_emb + attn_normed = self.attn_norm(x) + if ext_attn_gain is not None: + attn_normed = attn_normed * ext_attn_gain[None, None, :] + attn_out = self.attn(attn_normed) attn_s = self.attn_scale.to(dtype=x.dtype)[None, None, :] if ts_attn_gamma is not 
None: attn_s = attn_s * ts_attn_gamma[None, None, :] x = x + attn_s * attn_out if ts_attn_beta is not None: x = x + ts_attn_beta[None, None, :] - mlp_out = self.mlp_out_norm(self.mlp(x)) if self.use_peri_norm else self.mlp(self.mlp_norm(x)) + if self.use_peri_norm: + if ext_mlp_gain is not None: + m_input = F.rms_norm(x, (x.size(-1),)) * ext_mlp_gain[None, None, :] + else: + m_input = x + mlp_out = self.mlp_out_norm(self.mlp(m_input)) + else: + mlp_normed = self.mlp_norm(x) + if ext_mlp_gain is not None: + mlp_normed = mlp_normed * ext_mlp_gain[None, None, :] + mlp_out = self.mlp(mlp_normed) mlp_s = self.mlp_scale.to(dtype=x.dtype)[None, None, :] if ts_mlp_gamma is not None: mlp_s = mlp_s * ts_mlp_gamma[None, None, :] @@ -825,7 +773,8 @@ def __init__( use_birkhoff_mix: bool = False, use_timestep_scale: bool = False, use_timestep_bias: bool = False, - share_attn_only: bool = False, + use_depth_embed: bool = False, + use_unique_norms: bool = False, timestep_gamma_max: float = 0.0, leaky_relu_slope: float = 0.5, ): @@ -849,31 +798,20 @@ def __init__( leaky_relu_slope=leaky_relu_slope, ) - self.share_attn_only = share_attn_only if self.use_recurrence else False - if self.use_recurrence: self.prelude_blocks = nn.ModuleList([Block(**block_kwargs) for _ in range(self.num_prelude)]) self.coda_blocks = nn.ModuleList([Block(**block_kwargs) for _ in range(self.num_coda)]) - if self.share_attn_only: - shared_attn_kwargs = dict( - dim=model_dim, num_heads=num_heads, num_kv_heads=num_kv_heads, - rope_base=rope_base, qk_gain_init=qk_gain_init, - use_birkhoff_mix=use_birkhoff_mix, - ) - unique_mlp_kwargs = dict( - dim=model_dim, mlp_mult=mlp_mult, - use_peri_norm=use_peri_norm, - leaky_relu_slope=leaky_relu_slope, - ) - self.shared_attn_layers = nn.ModuleList([SharedAttnLayer(**shared_attn_kwargs) for _ in range(num_shared)]) - self.unique_mlps = nn.ModuleList([UniqueMLP(**unique_mlp_kwargs) for _ in range(num_shared * self.num_loops)]) - self.shared_blocks = nn.ModuleList() # 
empty — keeps diagnostics safe - else: - self.shared_blocks = nn.ModuleList([Block(**block_kwargs) for _ in range(num_shared)]) - self.shared_attn_layers = nn.ModuleList() - self.unique_mlps = nn.ModuleList() + self.shared_blocks = nn.ModuleList([Block(**block_kwargs) for _ in range(num_shared)]) effective_layers = self.num_prelude + num_shared * self.num_loops + self.num_coda self.timestep_scale = TimestepScaling(effective_layers, model_dim, gamma_max=timestep_gamma_max, use_bias=use_timestep_bias) if use_timestep_scale else None + self.use_depth_embed = use_depth_embed + if self.use_depth_embed: + self.depth_embeddings = nn.Parameter(torch.zeros(effective_layers, model_dim, dtype=torch.float32)) + self.use_unique_norms = use_unique_norms + if self.use_unique_norms: + num_unique = num_shared * self.num_loops + self.unique_attn_gains = nn.Parameter(torch.ones(num_unique, model_dim, dtype=torch.float32)) + self.unique_mlp_gains = nn.Parameter(torch.ones(num_unique, model_dim, dtype=torch.float32)) else: # Standard U-Net path self.num_encoder_layers = num_layers // 2 @@ -882,6 +820,8 @@ def __init__( self.skip_weights = nn.Parameter(torch.ones(self.num_skip_weights, model_dim, dtype=torch.float32)) self.blocks = nn.ModuleList([Block(**block_kwargs) for _ in range(num_layers)]) self.timestep_scale = None + self.use_depth_embed = False + self.use_unique_norms = False self.final_norm = RMSNorm() self.lm_head = None if tie_embeddings else CastedLinear(model_dim, vocab_size, bias=False) @@ -910,26 +850,26 @@ def forward(self, input_ids: Tensor, target_ids: Tensor) -> Tensor: v = 0 for block in self.prelude_blocks: ag, mg, ab, mb = self._get_ts(v) - x = block(x, x0, ag, mg, ab, mb) + de = self.depth_embeddings[v] if self.use_depth_embed else None + x = block(x, x0, ag, mg, ab, mb, de) v += 1 - if self.share_attn_only: - vid = 0 - for _loop in range(self.num_loops): - for attn_layer in self.shared_attn_layers: - ag, mg, ab, mb = self._get_ts(v) - x = attn_layer(x, x0, ag, 
ab) - x = x + self.unique_mlps[vid](x, mg, mb) - vid += 1 - v += 1 - else: - for _loop in range(self.num_loops): - for block in self.shared_blocks: - ag, mg, ab, mb = self._get_ts(v) - x = block(x, x0, ag, mg, ab, mb) - v += 1 + uid = 0 + for _loop in range(self.num_loops): + for block in self.shared_blocks: + ag, mg, ab, mb = self._get_ts(v) + de = self.depth_embeddings[v] if self.use_depth_embed else None + if self.use_unique_norms: + ag_n = self.unique_attn_gains[uid].to(dtype=x.dtype) + mg_n = self.unique_mlp_gains[uid].to(dtype=x.dtype) + x = block(x, x0, ag, mg, ab, mb, de, ag_n, mg_n) + else: + x = block(x, x0, ag, mg, ab, mb, de) + uid += 1 + v += 1 for block in self.coda_blocks: ag, mg, ab, mb = self._get_ts(v) - x = block(x, x0, ag, mg, ab, mb) + de = self.depth_embeddings[v] if self.use_depth_embed else None + x = block(x, x0, ag, mg, ab, mb, de) v += 1 else: skips: list[Tensor] = [] @@ -960,23 +900,17 @@ def recurrence_param_diagnostics(gpt: GPT) -> str: return "" parts: list[str] = [] - # Birkhoff alpha stats per shared block/layer - if gpt.share_attn_only: - for i, layer in enumerate(gpt.shared_attn_layers): - if hasattr(layer, "resid_mix_logit"): - a = torch.sigmoid(layer.resid_mix_logit) - parts.append(f"shared{i}_alpha:mean={a.mean().item():.3f},std={a.std().item():.3f}") - else: - for i, block in enumerate(gpt.shared_blocks): - if hasattr(block, "resid_mix_logit"): - a = torch.sigmoid(block.resid_mix_logit) - parts.append(f"shared{i}_alpha:mean={a.mean().item():.3f},std={a.std().item():.3f}") + # Birkhoff alpha stats per shared block + for i, block in enumerate(gpt.shared_blocks): + if hasattr(block, "resid_mix_logit"): + a = torch.sigmoid(block.resid_mix_logit) + parts.append(f"shared{i}_alpha:mean={a.mean().item():.3f},std={a.std().item():.3f}") # Effective MLP/attn contribution scale per virtual layer position: # mlp_scale * ts_mlp_gamma[v] is the learned multiplier on MLP output. 
# We report RMS (norm / sqrt(numel)) to make it scale-independent. v = 0 - effective_count = gpt.num_prelude + len(gpt.shared_blocks if not gpt.share_attn_only else gpt.shared_attn_layers) * gpt.num_loops + gpt.num_coda + effective_count = gpt.num_prelude + len(gpt.shared_blocks) * gpt.num_loops + gpt.num_coda mlp_norms: list[str] = [] attn_norms: list[str] = [] @@ -996,38 +930,20 @@ def recurrence_param_diagnostics(gpt: GPT) -> str: v += 1 # Shared positions - if gpt.share_attn_only: - vid = 0 - for _loop in range(gpt.num_loops): - for layer in gpt.shared_attn_layers: - asc = layer.attn_scale.norm().item() - ms = gpt.unique_mlps[vid].mlp_scale.norm().item() - d = layer.attn_scale.numel() ** 0.5 - if gpt.timestep_scale is not None: - ms_g = gpt.timestep_scale.mlp_gamma[v].norm().item() - as_g = gpt.timestep_scale.attn_gamma[v].norm().item() - mlp_norms.append(f"v{v}:{ms * ms_g / d:.4f}") - attn_norms.append(f"v{v}:{asc * as_g / d:.4f}") - else: - mlp_norms.append(f"v{v}:{ms / d:.4f}") - attn_norms.append(f"v{v}:{asc / d:.4f}") - vid += 1 - v += 1 - else: - for _loop in range(gpt.num_loops): - for block in gpt.shared_blocks: - ms = block.mlp_scale.norm().item() - asc = block.attn_scale.norm().item() - d = block.mlp_scale.numel() ** 0.5 - if gpt.timestep_scale is not None: - ms_g = gpt.timestep_scale.mlp_gamma[v].norm().item() - as_g = gpt.timestep_scale.attn_gamma[v].norm().item() - mlp_norms.append(f"v{v}:{ms * ms_g / d:.4f}") - attn_norms.append(f"v{v}:{asc * as_g / d:.4f}") - else: - mlp_norms.append(f"v{v}:{ms / d:.4f}") - attn_norms.append(f"v{v}:{asc / d:.4f}") - v += 1 + for _loop in range(gpt.num_loops): + for block in gpt.shared_blocks: + ms = block.mlp_scale.norm().item() + asc = block.attn_scale.norm().item() + d = block.mlp_scale.numel() ** 0.5 + if gpt.timestep_scale is not None: + ms_g = gpt.timestep_scale.mlp_gamma[v].norm().item() + as_g = gpt.timestep_scale.attn_gamma[v].norm().item() + mlp_norms.append(f"v{v}:{ms * ms_g / d:.4f}") + 
attn_norms.append(f"v{v}:{asc * as_g / d:.4f}") + else: + mlp_norms.append(f"v{v}:{ms / d:.4f}") + attn_norms.append(f"v{v}:{asc / d:.4f}") + v += 1 # Coda blocks for block in gpt.coda_blocks: @@ -1056,6 +972,22 @@ def recurrence_param_diagnostics(gpt: GPT) -> str: mlp_bias_norms.append(f"v{vi}:{mb_rms:.4f}") parts.append("eff_attn_bias:[" + " ".join(attn_bias_norms) + "]") parts.append("eff_mlp_bias:[" + " ".join(mlp_bias_norms) + "]") + if gpt.use_unique_norms: + un_attn: list[str] = [] + un_mlp: list[str] = [] + for ui in range(gpt.unique_attn_gains.size(0)): + an_rms = gpt.unique_attn_gains[ui].norm().item() / gpt.unique_attn_gains[ui].numel() ** 0.5 + un_attn.append(f"u{ui}:{an_rms:.4f}") + mn_rms = gpt.unique_mlp_gains[ui].norm().item() / gpt.unique_mlp_gains[ui].numel() ** 0.5 + un_mlp.append(f"u{ui}:{mn_rms:.4f}") + parts.append("unique_attn_gain_rms:[" + " ".join(un_attn) + "]") + parts.append("unique_mlp_gain_rms:[" + " ".join(un_mlp) + "]") + if gpt.use_depth_embed: + de_norms: list[str] = [] + for vi in range(effective_count): + de_rms = gpt.depth_embeddings[vi].norm().item() / gpt.depth_embeddings[vi].numel() ** 0.5 + de_norms.append(f"v{vi}:{de_rms:.4f}") + parts.append("depth_emb_rms:[" + " ".join(de_norms) + "]") return " ".join(parts) @@ -1178,7 +1110,8 @@ def log0(msg: str, console: bool = True) -> None: use_birkhoff_mix=args.use_birkhoff_mix, use_timestep_scale=args.use_timestep_scale, use_timestep_bias=args.use_timestep_bias, - share_attn_only=args.share_attn_only, + use_depth_embed=args.use_depth_embed, + use_unique_norms=args.use_unique_norms, timestep_gamma_max=args.timestep_gamma_max, leaky_relu_slope=args.leaky_relu_slope, ).to(device).bfloat16() @@ -1202,9 +1135,9 @@ def log0(msg: str, console: bool = True) -> None: block_named_params = [] for bl in all_block_lists: block_named_params.extend(bl.named_parameters()) - if base_model.share_attn_only: - block_named_params.extend(base_model.shared_attn_layers.named_parameters()) - 
block_named_params.extend(base_model.unique_mlps.named_parameters()) + if base_model.use_unique_norms: + block_named_params.extend([("unique_attn_gains", base_model.unique_attn_gains)]) + block_named_params.extend([("unique_mlp_gains", base_model.unique_mlp_gains)]) else: block_named_params = list(base_model.blocks.named_parameters()) matrix_params = [ @@ -1222,6 +1155,8 @@ def log0(msg: str, console: bool = True) -> None: if base_model.timestep_scale is not None: for p in base_model.timestep_scale.parameters(): scalar_params.append(p) + if base_model.use_depth_embed: + scalar_params.append(base_model.depth_embeddings) token_lr = args.tied_embed_lr if args.tie_embeddings else args.embed_lr optimizer_tok = torch.optim.Adam( [{"params": [base_model.tok_emb.weight], "lr": token_lr, "base_lr": token_lr}], @@ -1271,10 +1206,9 @@ def log0(msg: str, console: bool = True) -> None: ) log0(f"seed:{args.seed}") if base_model.use_recurrence: - num_shared = len(base_model.shared_attn_layers) if base_model.share_attn_only else len(base_model.shared_blocks) + num_shared = len(base_model.shared_blocks) eff = base_model.num_prelude + num_shared * base_model.num_loops + base_model.num_coda - shared_label = f"shared_attn:{num_shared}" if base_model.share_attn_only else f"shared:{num_shared}" - log0(f"recurrence:enabled prelude:{base_model.num_prelude} {shared_label} " + log0(f"recurrence:enabled prelude:{base_model.num_prelude} shared:{num_shared} " f"loops:{base_model.num_loops} coda:{base_model.num_coda} effective_layers:{eff}") log0(f"peri_norm:{args.use_peri_norm} birkhoff_mix:{args.use_birkhoff_mix}") if base_model.timestep_scale is not None: @@ -1282,6 +1216,8 @@ def log0(msg: str, console: bool = True) -> None: log0(f"timestep_scale:enabled params:{ts_params}") else: log0("timestep_scale:disabled") + log0(f"depth_embed:{'enabled' if base_model.use_depth_embed else 'disabled'}") + log0(f"unique_norms:{'enabled' if base_model.use_unique_norms else 'disabled'}") else: 
log0(f"recurrence:disabled num_layers:{args.num_layers}") compile_mode = "disabled" if args.disable_compile else "fullgraph=True" From 5e31104282abe6b33a237d02d1e62e63cb0589b9 Mon Sep 17 00:00:00 2001 From: Alexandr Azizyan Date: Thu, 2 Apr 2026 22:21:31 +0400 Subject: [PATCH 10/10] feat: add sinusoidal depth encoding, unique norms, complete ablation series 4-5 --- .../README.md | 19 +- .../logs/s5_T.txt | 1716 +++++++++++++++++ .../research_notes.md | 39 + .../scripts/run_fullscale4.sh | 37 + .../train_gpt.py | 34 +- .../train_log.txt | 577 +++--- 6 files changed, 2182 insertions(+), 240 deletions(-) create mode 100644 records/track_non_record_16mb/2026-03-26_RecurrenceFix_3Loop_Birkhoff_OutputLN_TimestepScale/logs/s5_T.txt create mode 100755 records/track_non_record_16mb/2026-03-26_RecurrenceFix_3Loop_Birkhoff_OutputLN_TimestepScale/scripts/run_fullscale4.sh diff --git a/records/track_non_record_16mb/2026-03-26_RecurrenceFix_3Loop_Birkhoff_OutputLN_TimestepScale/README.md b/records/track_non_record_16mb/2026-03-26_RecurrenceFix_3Loop_Birkhoff_OutputLN_TimestepScale/README.md index c62ee0c7cc..071c6e29af 100644 --- a/records/track_non_record_16mb/2026-03-26_RecurrenceFix_3Loop_Birkhoff_OutputLN_TimestepScale/README.md +++ b/records/track_non_record_16mb/2026-03-26_RecurrenceFix_3Loop_Birkhoff_OutputLN_TimestepScale/README.md @@ -28,6 +28,7 @@ - **ALBERT (Lan et al., 2020) found attention sharing is nearly free while FFN sharing causes most degradation.** s3_L confirms: the model needs per-iteration MLP differentiation, not per-iteration attention differentiation. - **Learned depth embeddings and unique input norms both hurt BPB despite reducing Q-gap.** The throughput overhead (6–15%) costs training steps that outweigh any specialization benefit. Learned depth embeddings remained near zero even after full training (RMS 0.006–0.010), suggesting they need far more steps to become useful. FiLM bias alone (s3_O: 1.2625) remains the best full-sharing config. 
- **The 0.026 BPB gap between full sharing and unique MLPs (s3_L) cannot be closed by cheap per-iteration input controls.** The MLP genuinely needs different weights per iteration, not just different inputs. +- **Sinusoidal depth encoding (UT-style) is neutral on BPB (1.2624 vs 1.2625 without) but strictly better than learned depth embeddings (1.2624 vs 1.2639)** due to zero throughput overhead. Keep it on since it's truly free — zero params, zero artifact cost, zero throughput penalty, marginal Q-gap benefit (0.0073 vs 0.0078). ## Techniques Applicable to Non-Recurrent Submissions @@ -72,11 +73,17 @@ Run M (1+4×3+1 attn-only, 3 loops) crashed during torch.compile with 12 UniqueM | R | 1+4×3+1 learned depth only+bias | 14 | 1.2566 | 1.2639 | +0.0073 | | S | 1+4×3+1 norms only+bias | 14 | 1.2560 | 1.2629 | +0.0069 | +## Series 5: Sinusoidal Depth Encoding (600s, 8×H100) + +| Run | Config | Eff. Layers | Pre-Q BPB | Post-Q BPB | Q-Gap | +|-----|--------|-------------|-----------|------------|-------| +| T | 1+4×3+1 sinusoidal depth+bias | 14 | 1.2551 | 1.2624 | +0.0073 | + ## Next Direction -The cheap learned specialization approach (Series 4: learned depth embeddings + unique input norms) was tested and did not improve over FiLM bias alone. Learned depth embeddings remained near zero (RMS 0.006–0.010) after full training, and throughput overhead (6–15%) cost more training steps than the specialization recovered. FiLM bias alone (s3_O: 1.2625) remains the best full-sharing configuration. +The depth encoding question is resolved: sinusoidal depth encoding (UT-style) is free and marginally helpful — keep it on. The best full-sharing config is now s5_T (1.2624 post-Q) with the complete stack: Output-LN + Birkhoff mixing + FiLM scale+shift (gammas+betas) + sinusoidal depth encoding. 
-Two next steps: (1) Replace learned depth embeddings with **sinusoidal depth encodings** (Universal Transformer style, Dehghani et al., 2019) — zero parameter cost, zero artifact cost, zero throughput overhead, and full-strength iteration identity signal from step 0 instead of slowly-learned near-zero values. (2) **Graft existing techniques** (Output-LN, Birkhoff mixing, FiLM scale+shift) onto the SOTA stack for a competitive submission. +The remaining path is **grafting the validated technique stack onto the SOTA stack** for a competitive submission. ## How to Reproduce @@ -130,11 +137,15 @@ bash scripts/run_fullscale2.sh ├── logs/ # All run logs │ ├── s1_Ap–F.txt # Screening (7 runs) │ ├── s2_G–K.txt # Full-scale (5 runs) -│ └── s3_L–O.txt # Follow-up ablations (4 runs) +│ ├── s3_L–O.txt # Follow-up ablations (4 runs) +│ ├── s4_P–S.txt # Depth embeddings + unique norms (4 runs) +│ └── s5_T.txt # Sinusoidal depth encoding (1 run) └── scripts/ ├── run_screening.sh # Series 1 screening ├── run_fullscale.sh # Series 2 full-scale - └── run_fullscale2.sh # Series 3 follow-up ablations + ├── run_fullscale2.sh # Series 3 follow-up ablations + ├── run_fullscale3.sh # Series 4 depth embeddings + unique norms + └── run_fullscale4.sh # Series 5 sinusoidal depth encoding ``` ## Links diff --git a/records/track_non_record_16mb/2026-03-26_RecurrenceFix_3Loop_Birkhoff_OutputLN_TimestepScale/logs/s5_T.txt b/records/track_non_record_16mb/2026-03-26_RecurrenceFix_3Loop_Birkhoff_OutputLN_TimestepScale/logs/s5_T.txt new file mode 100644 index 0000000000..c7c56b7b2d --- /dev/null +++ b/records/track_non_record_16mb/2026-03-26_RecurrenceFix_3Loop_Birkhoff_OutputLN_TimestepScale/logs/s5_T.txt @@ -0,0 +1,1716 @@ +""" +The `train_gpt.py` and `train_gpt_mlx.py` scripts are intended as good launching-off points for new participants, not SOTA configs. 
We'll accept PRs that tune, improve, or simplify these scripts without significantly increasing complexity, but competitive submissions should stay in the `/records` folder. + +Hard stop: To keep readable for newcomers, let's make sure `train_gpt.py` and `train_gpt_mlx.py` never are longer than 1500 lines. +""" + +from __future__ import annotations + +import copy +import glob +import io +import math +import os +import random +import subprocess +import sys +import time +import uuid +import zlib +from pathlib import Path + +import numpy as np +import sentencepiece as spm +import torch +import torch.distributed as dist +import torch.nn.functional as F +from torch import Tensor, nn +from torch.nn.parallel import DistributedDataParallel as DDP + +# Flash Attention 3 via HuggingFace kernels (Hopper-optimized) +try: + from kernels import get_kernel + cap = torch.cuda.get_device_capability() + repo = "varunneal/flash-attention-3" if cap == (9, 0) else "kernels-community/flash-attn3" + fa3 = get_kernel(repo).flash_attn_interface + FA3_AVAILABLE = True +except Exception: + FA3_AVAILABLE = False + +# ----------------------------- +# HYPERPARAMETERS +# ----------------------------- +# Default Simple Baseline run: +# - 9 transformer blocks at width 512 +# - 8 attention heads with 4 KV heads (GQA) and 2x MLP expansion +# - vocab size 1024, sequence length 1024, tied embeddings +# - 524,288 train tokens per step for 20,000 iterations with a ~10 minute cap + +class Hyperparameters: + # Data paths are shard globs produced by the existing preprocessing pipeline. 
+ data_path = os.environ.get("DATA_PATH", "./data/datasets/fineweb10B_sp1024") + train_files = os.path.join(data_path, "fineweb_train_*.bin") + val_files = os.path.join(data_path, "fineweb_val_*.bin") + tokenizer_path = os.environ.get("TOKENIZER_PATH", "./data/tokenizers/fineweb_1024_bpe.model") + run_id = os.environ.get("RUN_ID", str(uuid.uuid4())) + seed = int(os.environ.get("SEED", 1337)) + + # Validation cadence and batch size. Validation always uses the full fineweb_val split. + val_batch_size = int(os.environ.get("VAL_BATCH_SIZE", 524_288)) + val_loss_every = int(os.environ.get("VAL_LOSS_EVERY", 1000)) + train_log_every = int(os.environ.get("TRAIN_LOG_EVERY", 200)) + + # Training length. + iterations = int(os.environ.get("ITERATIONS", 20000)) + warmdown_iters = int(os.environ.get("WARMDOWN_ITERS", 1200)) + warmup_steps = int(os.environ.get("WARMUP_STEPS", 20)) + train_batch_tokens = int(os.environ.get("TRAIN_BATCH_TOKENS", 524_288)) + train_seq_len = int(os.environ.get("TRAIN_SEQ_LEN", 1024)) + max_wallclock_seconds = float(os.environ.get("MAX_WALLCLOCK_SECONDS", 600.0)) + qk_gain_init = float(os.environ.get("QK_GAIN_INIT", 1.5)) + + # Model shape. + vocab_size = int(os.environ.get("VOCAB_SIZE", 1024)) + num_layers = int(os.environ.get("NUM_LAYERS", 9)) + num_kv_heads = int(os.environ.get("NUM_KV_HEADS", 4)) + model_dim = int(os.environ.get("MODEL_DIM", 512)) + num_heads = int(os.environ.get("NUM_HEADS", 8)) + mlp_mult = int(os.environ.get("MLP_MULT", 2)) + tie_embeddings = bool(int(os.environ.get("TIE_EMBEDDINGS", "1"))) + rope_base = float(os.environ.get("ROPE_BASE", 10000.0)) + logit_softcap = float(os.environ.get("LOGIT_SOFTCAP", 30.0)) + + # Recurrence + stability features. 
+ num_prelude = int(os.environ.get("NUM_PRELUDE", 0)) + num_coda = int(os.environ.get("NUM_CODA", 0)) + num_shared = int(os.environ.get("NUM_SHARED", 0)) + num_loops = int(os.environ.get("NUM_LOOPS", 2)) + use_peri_norm = bool(int(os.environ.get("USE_PERI_NORM", "0"))) + use_birkhoff_mix = bool(int(os.environ.get("USE_BIRKHOFF_MIX", "0"))) + use_timestep_scale = bool(int(os.environ.get("USE_TIMESTEP_SCALE", "0"))) + leaky_relu_slope = float(os.environ.get("LEAKY_RELU_SLOPE", "0.5")) + timestep_gamma_max = float(os.environ.get("TIMESTEP_GAMMA_MAX", "0.0")) + use_timestep_bias = bool(int(os.environ.get("USE_TIMESTEP_BIAS", "0"))) + use_depth_embed = bool(int(os.environ.get("USE_DEPTH_EMBED", "0"))) + depth_enc_base = float(os.environ.get("DEPTH_ENC_BASE", 10000.0)) + use_unique_norms = bool(int(os.environ.get("USE_UNIQUE_NORMS", "0"))) + disable_compile = bool(int(os.environ.get("DISABLE_COMPILE", "0"))) + + # Optimizer hyperparameters. + embed_lr = float(os.environ.get("EMBED_LR", 0.6)) + head_lr = float(os.environ.get("HEAD_LR", 0.008)) + tied_embed_lr = float(os.environ.get("TIED_EMBED_LR", 0.05)) + tied_embed_init_std = float(os.environ.get("TIED_EMBED_INIT_STD", 0.005)) + matrix_lr = float(os.environ.get("MATRIX_LR", 0.04)) + scalar_lr = float(os.environ.get("SCALAR_LR", 0.04)) + muon_momentum = float(os.environ.get("MUON_MOMENTUM", 0.95)) + muon_backend_steps = int(os.environ.get("MUON_BACKEND_STEPS", 5)) + muon_momentum_warmup_start = float(os.environ.get("MUON_MOMENTUM_WARMUP_START", 0.85)) + muon_momentum_warmup_steps = int(os.environ.get("MUON_MOMENTUM_WARMUP_STEPS", 500)) + beta1 = float(os.environ.get("BETA1", 0.9)) + beta2 = float(os.environ.get("BETA2", 0.95)) + adam_eps = float(os.environ.get("ADAM_EPS", 1e-8)) + grad_clip_norm = float(os.environ.get("GRAD_CLIP_NORM", 0.0)) + +# ----------------------------- +# MUON OPTIMIZER +# ----------------------------- +# +# As borrowed from modded-nanogpt +# Background on Muon: 
https://kellerjordan.github.io/posts/muon/ + +def zeropower_via_newtonschulz5(G: Tensor, steps: int = 10, eps: float = 1e-7) -> Tensor: + # Orthogonalize a 2D update matrix with a fast Newton-Schulz iteration. + # Muon uses this to normalize matrix-shaped gradients before applying them. + a, b, c = (3.4445, -4.7750, 2.0315) + X = G.bfloat16() + X /= X.norm() + eps + transposed = G.size(0) > G.size(1) + if transposed: + X = X.T + for _ in range(steps): + A = X @ X.T + B = b * A + c * A @ A + X = a * X + B @ X + return X.T if transposed else X + + +class Muon(torch.optim.Optimizer): + def __init__(self, params, lr: float, momentum: float, backend_steps: int, nesterov: bool = True): + super().__init__( + params, + dict(lr=lr, momentum=momentum, backend_steps=backend_steps, nesterov=nesterov), + ) + + @torch.no_grad() + def step(self, closure=None): + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + distributed = dist.is_available() and dist.is_initialized() + world_size = dist.get_world_size() if distributed else 1 + rank = dist.get_rank() if distributed else 0 + + for group in self.param_groups: + params = group["params"] + if not params: + continue + lr = group["lr"] + momentum = group["momentum"] + backend_steps = group["backend_steps"] + nesterov = group["nesterov"] + + total_params = sum(int(p.numel()) for p in params) + updates_flat = torch.zeros(total_params, device=params[0].device, dtype=torch.bfloat16) + + curr = 0 + for i, p in enumerate(params): + if i % world_size == rank and p.grad is not None: + g = p.grad + state = self.state[p] + if "momentum_buffer" not in state: + state["momentum_buffer"] = torch.zeros_like(g) + buf = state["momentum_buffer"] + buf.mul_(momentum).add_(g) + if nesterov: + g = g.add(buf, alpha=momentum) + g = zeropower_via_newtonschulz5(g, steps=backend_steps) + # Scale correction from Muon reference implementations. 
+ g *= max(1, g.size(0) / g.size(1)) ** 0.5 + updates_flat[curr : curr + p.numel()] = g.reshape(-1) + curr += p.numel() + + if distributed: + dist.all_reduce(updates_flat, op=dist.ReduceOp.SUM) + + curr = 0 + for p in params: + g = updates_flat[curr : curr + p.numel()].view_as(p).to(dtype=p.dtype) + p.add_(g, alpha=-lr) + curr += p.numel() + + return loss + + +# ----------------------------- +# TOKENIZER-AGNOSTIC EVALUATION SETUP +# ----------------------------- +# +# It's common for small models have a large fraction of their parameters be embeddings, since the 2 * d_model * d_vocab vectors can be gigantic. +# Instead of locking the tokenizer, we let you bring your own and calculate our validation metrics on the average compression of the validation set. +# We calculate BPB (bits-per-byte) instead of validation loss, so we need methods to count the number of bits per token in the tokenizer. +# Note: Submissions that edit the tokenizer will be examined more carefully, since screwing this up might unjustly improve your score. 
def build_sentencepiece_luts(
    sp: spm.SentencePieceProcessor, vocab_size: int, device: torch.device
) -> tuple[Tensor, Tensor, Tensor]:
    """Precompute per-token lookup tables for the bytes-per-byte (bpb) metric.

    Returns three aligned tables of length max(sp.vocab_size(), vocab_size):
      - base byte count of each piece (raw byte tokens count as 1 byte),
      - whether the piece carries a leading "▁" word-boundary marker,
      - whether the id is a control/unknown/unused slot (padding ids beyond the
        tokenizer vocab stay flagged as boundary tokens).
    """
    n_pieces = int(sp.vocab_size())
    n_rows = max(n_pieces, vocab_size)
    byte_counts = np.zeros((n_rows,), dtype=np.int16)
    leading_space = np.zeros((n_rows,), dtype=np.bool_)
    boundary = np.ones((n_rows,), dtype=np.bool_)
    for tid in range(n_pieces):
        if sp.is_control(tid) or sp.is_unknown(tid) or sp.is_unused(tid):
            continue
        boundary[tid] = False
        if sp.is_byte(tid):
            byte_counts[tid] = 1
        else:
            piece = sp.id_to_piece(tid)
            if piece.startswith("▁"):
                leading_space[tid] = True
                piece = piece[1:]
            byte_counts[tid] = len(piece.encode("utf-8"))
    return (
        torch.tensor(byte_counts, dtype=torch.int16, device=device),
        torch.tensor(leading_space, dtype=torch.bool, device=device),
        torch.tensor(boundary, dtype=torch.bool, device=device),
    )


def load_validation_tokens(pattern: str, seq_len: int) -> Tensor:
    """Concatenate all shards matching *pattern*, trimmed to whole seq_len windows
    plus one extra token so (x, y) pairs can be built by shifting."""
    shard_paths = [Path(p) for p in sorted(glob.glob(pattern))]
    if not shard_paths:
        raise FileNotFoundError(f"No files found for pattern: {pattern}")
    # The export pipeline writes the fixed first-50k-doc validation set to fineweb_val_*.
    stream = torch.cat([load_data_shard(path) for path in shard_paths]).contiguous()
    usable = ((stream.numel() - 1) // seq_len) * seq_len
    if usable <= 0:
        raise ValueError(f"Validation split is too short for TRAIN_SEQ_LEN={seq_len}")
    return stream[: usable + 1]


def eval_val(
    args: Hyperparameters,
    model: nn.Module,
    rank: int,
    world_size: int,
    device: torch.device,
    grad_accum_steps: int,
    val_tokens: Tensor,
    base_bytes_lut: Tensor,
    has_leading_space_lut: Tensor,
    is_boundary_token_lut: Tensor,
) -> tuple[float, float]:
    """Run full-split validation and return (val_loss, val_bpb).

    Two metrics:
      - val_loss: token cross-entropy (natural log)
      - val_bpb: tokenizer-agnostic compression metric used by the challenge
    """
    per_rank_tokens = args.val_batch_size // (world_size * grad_accum_steps)
    if per_rank_tokens < args.train_seq_len:
        raise ValueError(
            "VAL_BATCH_SIZE must provide at least one sequence per rank; "
            f"got VAL_BATCH_SIZE={args.val_batch_size}, WORLD_SIZE={world_size}, "
            f"GRAD_ACCUM_STEPS={grad_accum_steps}, TRAIN_SEQ_LEN={args.train_seq_len}"
        )
    seqs_per_batch = per_rank_tokens // args.train_seq_len
    n_seqs = (val_tokens.numel() - 1) // args.train_seq_len
    # Disjoint, contiguous slice of validation sequences per rank.
    lo = (n_seqs * rank) // world_size
    hi = (n_seqs * (rank + 1)) // world_size
    loss_sum = torch.zeros((), device=device, dtype=torch.float64)
    token_count = torch.zeros((), device=device, dtype=torch.float64)
    byte_count = torch.zeros((), device=device, dtype=torch.float64)

    model.eval()
    with torch.inference_mode():
        for first_seq in range(lo, hi, seqs_per_batch):
            last_seq = min(first_seq + seqs_per_batch, hi)
            span = val_tokens[first_seq * args.train_seq_len : last_seq * args.train_seq_len + 1]
            span = span.to(device=device, dtype=torch.int64, non_blocking=True)
            x = span[:-1].reshape(-1, args.train_seq_len)
            y = span[1:].reshape(-1, args.train_seq_len)
            with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True):
                batch_loss = model(x, y).detach()
            n_batch_tokens = float(y.numel())
            loss_sum += batch_loss.to(torch.float64) * n_batch_tokens
            token_count += n_batch_tokens
            prev_ids = x.reshape(-1)
            tgt_ids = y.reshape(-1)
            # A target piece costs its base bytes, plus one byte for its leading
            # space unless the previous token was a boundary token.
            n_bytes = base_bytes_lut[tgt_ids].to(dtype=torch.int16)
            n_bytes += (has_leading_space_lut[tgt_ids] & ~is_boundary_token_lut[prev_ids]).to(dtype=torch.int16)
            byte_count += n_bytes.to(torch.float64).sum()

    if dist.is_available() and dist.is_initialized():
        for metric in (loss_sum, token_count, byte_count):
            dist.all_reduce(metric, op=dist.ReduceOp.SUM)

    mean_loss = loss_sum / token_count
    bits_per_token = mean_loss.item() / math.log(2.0)
    tokens_per_byte = token_count.item() / byte_count.item()
    model.train()
    return float(mean_loss.item()), float(bits_per_token * tokens_per_byte)

# -----------------------------
# POST-TRAINING QUANTIZATION
# -----------------------------
#
# It's silly to export our model, which is trained in bf16 and fp32, at that same precision.
# Instead, we get approximately the same model (with a small hit) by quantizing the model to int8 & zlib compressing.
# We can then decompress the model and run in higher precision for evaluation, after closing in under the size limit.
+ +CONTROL_TENSOR_NAME_PATTERNS = tuple( + pattern + for pattern in os.environ.get( + "CONTROL_TENSOR_NAME_PATTERNS", + "attn_scale,attn_scales,mlp_scale,mlp_scales,resid_mix,resid_mixes,q_gain,skip_weight,skip_weights,ts_attn,ts_mlp,resid_mix_logit,ts_attn_beta,ts_mlp_beta,unique_attn_gain,unique_mlp_gain", + ).split(",") + if pattern +) +INT8_KEEP_FLOAT_FP32_NAME_PATTERNS = tuple( + pattern + for pattern in os.environ.get( + "INT8_KEEP_FLOAT_FP32_NAME_PATTERNS", + ",".join(CONTROL_TENSOR_NAME_PATTERNS), + ).split(",") + if pattern +) +INT8_KEEP_FLOAT_MAX_NUMEL = 65_536 +INT8_KEEP_FLOAT_STORE_DTYPE = torch.float16 +INT8_PER_ROW_SCALE_DTYPE = torch.float16 +INT8_CLIP_PERCENTILE = 99.99984 +INT8_CLIP_Q = INT8_CLIP_PERCENTILE / 100.0 + +def tensor_nbytes(t: Tensor) -> int: + return int(t.numel()) * int(t.element_size()) + +def keep_float_tensor(name: str, t: Tensor, passthrough_orig_dtypes: dict[str, str]) -> Tensor: + if any(pattern in name for pattern in INT8_KEEP_FLOAT_FP32_NAME_PATTERNS): + return t.float().contiguous() + if t.dtype in {torch.float32, torch.bfloat16}: + passthrough_orig_dtypes[name] = str(t.dtype).removeprefix("torch.") + return t.to(dtype=INT8_KEEP_FLOAT_STORE_DTYPE).contiguous() + return t + +def quantize_float_tensor(t: Tensor) -> tuple[Tensor, Tensor]: + t32 = t.float() + if t32.ndim == 2: + # Matrices get one scale per row, which usually tracks output-channel + # ranges much better than a single tensor-wide scale. + clip_abs = ( + torch.quantile(t32.abs(), INT8_CLIP_Q, dim=1) + if t32.numel() + else torch.empty((t32.shape[0],), dtype=torch.float32) + ) + clipped = torch.maximum(torch.minimum(t32, clip_abs[:, None]), -clip_abs[:, None]) + scale = (clip_abs / 127.0).clamp_min(1.0 / 127.0) + q = torch.clamp(torch.round(clipped / scale[:, None]), -127, 127).to(torch.int8).contiguous() + return q, scale.to(dtype=INT8_PER_ROW_SCALE_DTYPE).contiguous() + + # Vectors / scalars use a simpler per-tensor scale. 
+ clip_abs = float(torch.quantile(t32.abs().flatten(), INT8_CLIP_Q).item()) if t32.numel() else 0.0 + scale = torch.tensor(clip_abs / 127.0 if clip_abs > 0 else 1.0, dtype=torch.float32) + q = torch.clamp(torch.round(torch.clamp(t32, -clip_abs, clip_abs) / scale), -127, 127).to(torch.int8).contiguous() + return q, scale + +def quantize_state_dict_int8(state_dict: dict[str, Tensor]): + # Single supported clean-script export format: + # - per-row int8 for 2D float tensors + # - per-tensor int8 for other float tensors + # - exact passthrough for non-floats + # - passthrough for small float tensors, stored as fp16 to save bytes + quantized: dict[str, Tensor] = {} + scales: dict[str, Tensor] = {} + dtypes: dict[str, str] = {} + passthrough: dict[str, Tensor] = {} + passthrough_orig_dtypes: dict[str, str] = {} + qmeta: dict[str, dict[str, object]] = {} + stats = dict.fromkeys( + ("param_count", "num_tensors", "num_float_tensors", "num_nonfloat_tensors", "baseline_tensor_bytes", "int8_payload_bytes"), + 0, + ) + + for name, tensor in state_dict.items(): + t = tensor.detach().to("cpu").contiguous() + stats["param_count"] += int(t.numel()) + stats["num_tensors"] += 1 + stats["baseline_tensor_bytes"] += tensor_nbytes(t) + + if not t.is_floating_point(): + stats["num_nonfloat_tensors"] += 1 + passthrough[name] = t + stats["int8_payload_bytes"] += tensor_nbytes(t) + continue + + # Small float tensors are cheap enough to keep directly. We still downcast + # fp32/bf16 passthrough tensors to fp16 so metadata does not dominate size. 
+ if t.numel() <= INT8_KEEP_FLOAT_MAX_NUMEL: + kept = keep_float_tensor(name, t, passthrough_orig_dtypes) + passthrough[name] = kept + stats["int8_payload_bytes"] += tensor_nbytes(kept) + continue + + stats["num_float_tensors"] += 1 + q, s = quantize_float_tensor(t) + if s.ndim > 0: + qmeta[name] = {"scheme": "per_row", "axis": 0} + quantized[name] = q + scales[name] = s + dtypes[name] = str(t.dtype).removeprefix("torch.") + stats["int8_payload_bytes"] += tensor_nbytes(q) + tensor_nbytes(s) + + obj: dict[str, object] = { + "__quant_format__": "int8_clean_per_row_v1", + "quantized": quantized, + "scales": scales, + "dtypes": dtypes, + "passthrough": passthrough, + } + if qmeta: + obj["qmeta"] = qmeta + if passthrough_orig_dtypes: + obj["passthrough_orig_dtypes"] = passthrough_orig_dtypes + return obj, stats + +def dequantize_state_dict_int8(obj: dict[str, object]) -> dict[str, Tensor]: + out: dict[str, Tensor] = {} + qmeta = obj.get("qmeta", {}) + passthrough_orig_dtypes = obj.get("passthrough_orig_dtypes", {}) + for name, q in obj["quantized"].items(): + dtype = getattr(torch, obj["dtypes"][name]) + s = obj["scales"][name] + if qmeta.get(name, {}).get("scheme") == "per_row" or s.ndim > 0: + s = s.to(dtype=torch.float32) + # Broadcast the saved row scale back across trailing dimensions. + out[name] = (q.float() * s.view(q.shape[0], *([1] * (q.ndim - 1)))).to(dtype=dtype).contiguous() + else: + scale = float(s.item()) + out[name] = (q.float() * scale).to(dtype=dtype).contiguous() + for name, t in obj["passthrough"].items(): + # Restore small tensors, undoing the temporary fp16 storage cast if needed. 
+ out_t = t.detach().to("cpu").contiguous() + orig_dtype = passthrough_orig_dtypes.get(name) + if isinstance(orig_dtype, str): + out_t = out_t.to(dtype=getattr(torch, orig_dtype)).contiguous() + out[name] = out_t + return out + + +# ----------------------------- +# DATA LOADING +# ----------------------------- + +def load_data_shard(file: Path) -> Tensor: + header_bytes = 256 * np.dtype(" None: + self.file_idx = (self.file_idx + 1) % len(self.files) + self.tokens = load_data_shard(self.files[self.file_idx]) + self.pos = 0 + + def take(self, n: int) -> Tensor: + chunks: list[Tensor] = [] + remaining = n + while remaining > 0: + avail = self.tokens.numel() - self.pos + if avail <= 0: + self._advance_file() + continue + k = min(remaining, avail) + chunks.append(self.tokens[self.pos : self.pos + k]) + self.pos += k + remaining -= k + return chunks[0] if len(chunks) == 1 else torch.cat(chunks) + + +class DistributedTokenLoader: + # Each call consumes a contiguous chunk from the shared token stream, then slices out + # one disjoint span per rank. The extra "+1" token lets us build (x, y) by shifting. 
    def __init__(self, pattern: str, rank: int, world_size: int, device: torch.device):
        # One loader per rank; every rank reads the same shard stream in lockstep.
        self.rank = rank
        self.world_size = world_size
        self.device = device
        self.stream = TokenStream(pattern)

    def next_batch(self, global_tokens: int, seq_len: int, grad_accum_steps: int) -> tuple[Tensor, Tensor]:
        """Return one (inputs, targets) micro-batch shaped (-1, seq_len) for this rank.

        All ranks consume the same contiguous chunk and slice out disjoint spans,
        so no communication is needed to stay aligned.
        NOTE(review): assumes global_tokens // (world_size * grad_accum_steps) is
        a multiple of seq_len — confirm against the Hyperparameters defaults.
        """
        local_tokens = global_tokens // (self.world_size * grad_accum_steps)
        per_rank_span = local_tokens + 1  # +1 so targets are inputs shifted by one
        chunk = self.stream.take(per_rank_span * self.world_size)
        start = self.rank * per_rank_span
        local = chunk[start : start + per_rank_span].to(dtype=torch.int64)
        x = local[:-1].reshape(-1, seq_len)
        y = local[1:].reshape(-1, seq_len)
        return x.to(self.device, non_blocking=True), y.to(self.device, non_blocking=True)

# -----------------------------
# TRANSFORMER MODULES
# -----------------------------

class RMSNorm(nn.Module):
    """Parameter-free RMS normalization over the last dimension."""

    def __init__(self, eps: float | None = None):
        super().__init__()
        self.eps = eps

    def forward(self, x: Tensor) -> Tensor:
        return F.rms_norm(x, (x.size(-1),), eps=self.eps)


class CastedLinear(nn.Linear):
    # Keep weights in fp32 for optimizer/state quality, cast at matmul time for bf16 compute.
    def forward(self, x: Tensor) -> Tensor:
        bias = self.bias.to(x.dtype) if self.bias is not None else None
        return F.linear(x, self.weight.to(x.dtype), bias)


def restore_low_dim_params_to_fp32(module: nn.Module) -> None:
    # Keep small/control parameters in fp32 even when the model body runs in bf16.
    with torch.no_grad():
        for name, param in module.named_parameters():
            if (param.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS)) and param.dtype != torch.float32:
                param.data = param.data.float()


class Rotary(nn.Module):
    # Caches cos/sin tables per sequence length on the current device.
    def __init__(self, dim: int, base: float = 10000.0):
        super().__init__()
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim))
        self.register_buffer("inv_freq", inv_freq, persistent=False)
        self._seq_len_cached = 0
        self._cos_cached: Tensor | None = None
        self._sin_cached: Tensor | None = None

    def forward(self, seq_len: int, device: torch.device, dtype: torch.dtype) -> tuple[Tensor, Tensor]:
        """Return (cos, sin) tables shaped (1, 1, seq_len, dim/2), cast to *dtype*.

        The cache is keyed on (seq_len, device); tables are kept in fp32 and only
        converted at return time, so a dtype change never forces a recompute.
        """
        if (
            self._cos_cached is None
            or self._sin_cached is None
            or self._seq_len_cached != seq_len
            or self._cos_cached.device != device
        ):
            t = torch.arange(seq_len, device=device, dtype=self.inv_freq.dtype)
            freqs = torch.outer(t, self.inv_freq.to(device))
            self._cos_cached = freqs.cos()[None, None, :, :]
            self._sin_cached = freqs.sin()[None, None, :, :]
            self._seq_len_cached = seq_len
        return self._cos_cached.to(dtype=dtype), self._sin_cached.to(dtype=dtype)


class DepthEncoding(nn.Module):
    """Deterministic sinusoidal depth encodings (Universal Transformer style)."""
    def __init__(self, depth: int, dim: int, base: float = 10000.0):
        super().__init__()
        t = torch.arange(depth, dtype=torch.float32)
        # Note: full-dim frequency sweep with sin only (no cos half), by design.
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, dtype=torch.float32) / dim))
        self.register_buffer('encodings', torch.sin(torch.outer(t, inv_freq)), persistent=False)
        self.encodings: Tensor

    def get(self, v: int, dtype: torch.dtype) -> Tensor:
        # v is the virtual (unrolled) layer index.
        return self.encodings[v].to(dtype=dtype)


def apply_rotary_emb(x: Tensor, cos: Tensor, sin: Tensor) -> Tensor:
    """Apply rotary position embedding with the halved-dimension layout."""
    half = x.size(-1) // 2
    x1, x2 = x[..., :half], x[..., half:]
    return torch.cat((x1 * cos + x2 * sin, x1 * (-sin) + x2 * cos), dim=-1)


class CausalSelfAttention(nn.Module):
    """Causal attention with GQA, QK-RMSNorm, RoPE and a learned per-head Q gain."""

    def __init__(
        self,
        dim: int,
        num_heads: int,
        num_kv_heads: int,
        rope_base: float,
        qk_gain_init: float,
    ):
        super().__init__()
        if dim % num_heads != 0:
            raise ValueError("model_dim must be divisible by num_heads")
        if num_heads % num_kv_heads != 0:
            raise ValueError("num_heads must be divisible by num_kv_heads")
        self.num_heads = num_heads
        self.num_kv_heads = num_kv_heads
        self.head_dim = dim // num_heads
        if self.head_dim % 2 != 0:
            raise ValueError("head_dim must be even for RoPE")
        kv_dim = self.num_kv_heads * self.head_dim
        self.c_q = CastedLinear(dim, dim, bias=False)
        self.c_k = CastedLinear(dim, kv_dim, bias=False)
        self.c_v = CastedLinear(dim, kv_dim, bias=False)
        self.proj = CastedLinear(dim, dim, bias=False)
        # Zero-init the output projection so each block starts as identity.
        self.proj._zero_init = True
        self.q_gain = nn.Parameter(torch.full((num_heads,), qk_gain_init, dtype=torch.float32))
        self.rotary = Rotary(self.head_dim, base=rope_base)

    def forward(self, x: Tensor) -> Tensor:
        bsz, seqlen, dim = x.shape
        q = self.c_q(x).reshape(bsz, seqlen, self.num_heads, self.head_dim).transpose(1, 2)
        k = self.c_k(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2)
        v = self.c_v(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2)
        # QK norm before RoPE; q_gain applied after so it rescales logits per head.
        q = F.rms_norm(q, (q.size(-1),))
        k = F.rms_norm(k, (k.size(-1),))
        cos, sin = self.rotary(seqlen, x.device, q.dtype)
        q = apply_rotary_emb(q, cos, sin)
        k = apply_rotary_emb(k, cos, sin)
        q = q * self.q_gain.to(dtype=q.dtype)[None, :, None, None]
        if FA3_AVAILABLE:
            # FA3 expects (B, T, H, D) — transpose from (B, H, T, D)
            y = fa3.flash_attn_func(
                q.transpose(1, 2),
                k.transpose(1, 2),
                v.transpose(1, 2),
                causal=True,
            ).transpose(1, 2)
        else:
            y = F.scaled_dot_product_attention(
                q, k, v,
                attn_mask=None,
                is_causal=True,
                enable_gqa=(self.num_kv_heads != self.num_heads),
            )
        y = y.transpose(1, 2).contiguous().reshape(bsz, seqlen, dim)
        return self.proj(y)


class MLP(nn.Module):
    # leaky_relu^2 MLP (slope=0.5 matches SOTA, PRs #493/#518)
    def __init__(self, dim: int, mlp_mult: int, leaky_relu_slope: float = 0.5):
        super().__init__()
        hidden = mlp_mult * dim
        self.fc = CastedLinear(dim, hidden, bias=False)
        self.proj = CastedLinear(hidden, dim, bias=False)
        # Zero-init so the MLP branch contributes nothing at step 0.
        self.proj._zero_init = True
        self.leaky_relu_slope = leaky_relu_slope

    def forward(self, x: Tensor) -> Tensor:
        x = F.leaky_relu(self.fc(x), negative_slope=self.leaky_relu_slope)
        return self.proj(x.square())


class TimestepScaling(nn.Module):
    """Learned per-iteration scale vectors (arXiv:2410.01405, ICML 2025)."""
    def __init__(self, effective_layers: int, dim: int, gamma_max: float = 0.0, use_bias: bool = False):
        super().__init__()
        self.attn_gamma = nn.Parameter(torch.ones(effective_layers, dim, dtype=torch.float32))
        self.mlp_gamma = nn.Parameter(torch.ones(effective_layers, dim, dtype=torch.float32))
        self.gamma_max = gamma_max  # 0 = uncapped
        if use_bias:
            self.attn_beta = nn.Parameter(torch.zeros(effective_layers, dim, dtype=torch.float32))
            self.mlp_beta = nn.Parameter(torch.zeros(effective_layers, dim, dtype=torch.float32))
        else:
            self.attn_beta = None
            self.mlp_beta = None

    def get(self, v: int) -> tuple[Tensor, Tensor, Tensor | None, Tensor | None]:
        # v indexes the virtual (unrolled) layer position.
        ag = self.attn_gamma[v]
        mg = self.mlp_gamma[v]
        if self.gamma_max > 0:
            ag = ag.clamp(-self.gamma_max, self.gamma_max)
            mg = mg.clamp(-self.gamma_max, self.gamma_max)
        ab = self.attn_beta[v] if self.attn_beta is not None else None
        mb = self.mlp_beta[v] if self.mlp_beta is not None else None
        return ag, mg, ab, mb


class Block(nn.Module):
    """One transformer block: x0-residual mix, attention, then MLP sublayer."""

    def __init__(
        self,
        dim: int,
        num_heads: int,
        num_kv_heads: int,
        mlp_mult: int,
        rope_base: float,
        qk_gain_init: float,
        use_peri_norm: bool = False,
        use_birkhoff_mix: bool = False,
        leaky_relu_slope: float = 0.5,
    ):
        super().__init__()
        self.use_peri_norm = use_peri_norm
        self.use_birkhoff_mix = use_birkhoff_mix
        self.attn_norm = RMSNorm()
        if use_peri_norm:
            self.mlp_out_norm = RMSNorm()  # output norm: bounds MLP contribution, preserves input magnitude
        else:
            self.mlp_norm = RMSNorm()
        self.attn = CausalSelfAttention(dim, num_heads, num_kv_heads, rope_base, qk_gain_init)
        self.mlp = MLP(dim, mlp_mult, leaky_relu_slope)
        self.attn_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32))
        self.mlp_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32))
        if use_birkhoff_mix:
            # Convex (sigmoid-gated) mix between residual stream and embeddings.
            self.resid_mix_logit = nn.Parameter(torch.zeros(dim, dtype=torch.float32))
        else:
            # Unconstrained affine mix: row 0 weights x, row 1 weights x0.
            self.resid_mix = nn.Parameter(torch.stack((torch.ones(dim), torch.zeros(dim))).float())

    def forward(self, x: Tensor, x0: Tensor,
                ts_attn_gamma: Tensor | None = None,
                ts_mlp_gamma: Tensor | None = None,
                ts_attn_beta: Tensor | None = None,
                ts_mlp_beta: Tensor | None = None,
                depth_emb: Tensor | None = None,
                ext_attn_gain: Tensor | None = None,
                ext_mlp_gain: Tensor | None = None) -> Tensor:
        """x0 is the (normalized) token embedding; ts_*/depth_emb/ext_* are the
        optional per-virtual-layer modulations owned by the parent GPT."""
        if self.use_birkhoff_mix:
            alpha = torch.sigmoid(self.resid_mix_logit.to(dtype=x.dtype))[None, None, :]
            x = alpha * x + (1 - alpha) * x0
        else:
            mix = self.resid_mix.to(dtype=x.dtype)
            x = mix[0][None, None, :] * x + mix[1][None, None, :] * x0
        if depth_emb is not None:
            x = x + depth_emb
        attn_normed = self.attn_norm(x)
        if ext_attn_gain is not None:
            attn_normed = attn_normed * ext_attn_gain[None, None, :]
        attn_out = self.attn(attn_normed)
        attn_s = self.attn_scale.to(dtype=x.dtype)[None, None, :]
        if ts_attn_gamma is not None:
            attn_s = attn_s * ts_attn_gamma[None, None, :]
        x = x + attn_s * attn_out
        if ts_attn_beta is not None:
            x = x + ts_attn_beta[None, None, :]
        if self.use_peri_norm:
            if ext_mlp_gain is not None:
                m_input = F.rms_norm(x, (x.size(-1),)) * ext_mlp_gain[None, None, :]
            else:
                # Peri-norm path: no pre-norm here — the output norm below
                # bounds the MLP contribution instead.
                m_input = x
            mlp_out = self.mlp_out_norm(self.mlp(m_input))
        else:
            mlp_normed = self.mlp_norm(x)
            if ext_mlp_gain is not None:
                mlp_normed = mlp_normed * ext_mlp_gain[None, None, :]
            mlp_out = self.mlp(mlp_normed)
        mlp_s = self.mlp_scale.to(dtype=x.dtype)[None, None, :]
        if ts_mlp_gamma is not None:
            mlp_s = mlp_s * ts_mlp_gamma[None, None, :]
        x = x + mlp_s * mlp_out
        if ts_mlp_beta is not None:
            x = x + ts_mlp_beta[None, None, :]
        return x

class GPT(nn.Module):
    """GPT with two topologies: a recurrent prelude/shared/coda stack (when
    num_shared > 0) or a standard U-Net skip-connection stack otherwise."""

    def __init__(
        self,
        vocab_size: int,
        num_layers: int,
        model_dim: int,
        num_heads: int,
        num_kv_heads: int,
        mlp_mult: int,
        tie_embeddings: bool,
        tied_embed_init_std: float,
        logit_softcap: float,
        rope_base: float,
        qk_gain_init: float,
        num_prelude: int = 0,
        num_coda: int = 0,
        num_shared: int = 0,
        num_loops: int = 2,
        use_peri_norm: bool = False,
        use_birkhoff_mix: bool = False,
        use_timestep_scale: bool = False,
        use_timestep_bias: bool = False,
        use_depth_embed: bool = False,
        depth_enc_base: float = 10000.0,
        use_unique_norms: bool = False,
        timestep_gamma_max: float = 0.0,
        leaky_relu_slope: float = 0.5,
    ):
        super().__init__()
        if logit_softcap <= 0.0:
            raise ValueError(f"logit_softcap must be positive, got {logit_softcap}")
        self.tie_embeddings = tie_embeddings
        self.tied_embed_init_std = tied_embed_init_std
        self.logit_softcap = logit_softcap
        self.tok_emb = nn.Embedding(vocab_size, model_dim)
        # Recurrence is implied by having any shared blocks; all recurrence-only
        # options are forced off otherwise.
        self.use_recurrence = num_shared > 0
        self.num_loops = num_loops if self.use_recurrence else 1
        self.num_prelude = num_prelude if self.use_recurrence else 0
        self.num_coda = num_coda if self.use_recurrence else 0

        block_kwargs = dict(
            dim=model_dim, num_heads=num_heads, num_kv_heads=num_kv_heads,
            mlp_mult=mlp_mult, rope_base=rope_base, qk_gain_init=qk_gain_init,
            use_peri_norm=use_peri_norm if self.use_recurrence else False,
            use_birkhoff_mix=use_birkhoff_mix if self.use_recurrence else False,
            leaky_relu_slope=leaky_relu_slope,
        )

        if self.use_recurrence:
            self.prelude_blocks = nn.ModuleList([Block(**block_kwargs) for _ in range(self.num_prelude)])
            self.coda_blocks = nn.ModuleList([Block(**block_kwargs) for _ in range(self.num_coda)])
            self.shared_blocks = nn.ModuleList([Block(**block_kwargs) for _ in range(num_shared)])
            # "Effective" depth counts each shared block once per loop iteration.
            effective_layers = self.num_prelude + num_shared * self.num_loops + self.num_coda
            self.timestep_scale = TimestepScaling(effective_layers, model_dim, gamma_max=timestep_gamma_max, use_bias=use_timestep_bias) if use_timestep_scale else None
            self.use_depth_embed = use_depth_embed
            if self.use_depth_embed:
                self.depth_enc = DepthEncoding(effective_layers, model_dim, base=depth_enc_base)
            self.use_unique_norms = use_unique_norms
            if self.use_unique_norms:
                # Per-unrolled-position gains for the shared blocks only.
                num_unique = num_shared * self.num_loops
                self.unique_attn_gains = nn.Parameter(torch.ones(num_unique, model_dim, dtype=torch.float32))
                self.unique_mlp_gains = nn.Parameter(torch.ones(num_unique, model_dim, dtype=torch.float32))
        else:
            # Standard U-Net path
            self.num_encoder_layers = num_layers // 2
            self.num_decoder_layers = num_layers - self.num_encoder_layers
            self.num_skip_weights = min(self.num_encoder_layers, self.num_decoder_layers)
            self.skip_weights = nn.Parameter(torch.ones(self.num_skip_weights, model_dim, dtype=torch.float32))
            self.blocks = nn.ModuleList([Block(**block_kwargs) for _ in range(num_layers)])
            self.timestep_scale = None
            self.use_depth_embed = False
            self.use_unique_norms = False

        self.final_norm = RMSNorm()
        self.lm_head = None if tie_embeddings else CastedLinear(model_dim, vocab_size, bias=False)
        if self.lm_head is not None:
            self.lm_head._zero_init = True
        self._init_weights()

    def _init_weights(self) -> None:
        # Tied embedding gets an explicit init std; flagged linears start at zero
        # so every residual branch is an exact identity at step 0.
        if self.tie_embeddings:
            nn.init.normal_(self.tok_emb.weight, mean=0.0, std=self.tied_embed_init_std)
        for module in self.modules():
            if isinstance(module, nn.Linear) and getattr(module, "_zero_init", False):
                nn.init.zeros_(module.weight)

    def _get_ts(self, v: int) -> tuple[Tensor | None, Tensor | None, Tensor | None, Tensor | None]:
        # Timestep modulations for virtual layer v, or all-None when disabled.
        if self.timestep_scale is None:
            return None, None, None, None
        return self.timestep_scale.get(v)

    def forward(self, input_ids: Tensor, target_ids: Tensor) -> Tensor:
        """Return the mean cross-entropy of softcapped logits against target_ids."""
        x = self.tok_emb(input_ids)
        x = F.rms_norm(x, (x.size(-1),))
        x0 = x  # normalized embeddings, re-mixed into every block

        if self.use_recurrence:
            # v counts virtual (unrolled) layers; uid counts unique shared slots.
            v = 0
            for block in self.prelude_blocks:
                ag, mg, ab, mb = self._get_ts(v)
                de = self.depth_enc.get(v, x.dtype) if self.use_depth_embed else None
                x = block(x, x0, ag, mg, ab, mb, de)
                v += 1
            uid = 0
            for _loop in range(self.num_loops):
                for block in self.shared_blocks:
                    ag, mg, ab, mb = self._get_ts(v)
                    de = self.depth_enc.get(v, x.dtype) if self.use_depth_embed else None
                    if self.use_unique_norms:
                        ag_n = self.unique_attn_gains[uid].to(dtype=x.dtype)
                        mg_n = self.unique_mlp_gains[uid].to(dtype=x.dtype)
                        x = block(x, x0, ag, mg, ab, mb, de, ag_n, mg_n)
                    else:
                        x = block(x, x0, ag, mg, ab, mb, de)
                    uid += 1
                    v += 1
            for block in self.coda_blocks:
                ag, mg, ab, mb = self._get_ts(v)
                de = self.depth_enc.get(v, x.dtype) if self.use_depth_embed else None
                x = block(x, x0, ag, mg, ab, mb, de)
                v += 1
        else:
            # U-Net: encoder activations feed back into the decoder in reverse.
            skips: list[Tensor] = []
            for i in range(self.num_encoder_layers):
                x = self.blocks[i](x, x0)
                skips.append(x)
            for i in range(self.num_decoder_layers):
                if skips:
                    x = x + self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips.pop()
                x = self.blocks[self.num_encoder_layers + i](x, x0)

        x = self.final_norm(x).reshape(-1, x.size(-1))
        targets = target_ids.reshape(-1)
        if self.tie_embeddings:
            logits_proj = F.linear(x, self.tok_emb.weight)
        else:
            if self.lm_head is None:
                raise RuntimeError("lm_head is required when tie_embeddings=False")
            logits_proj = self.lm_head(x)
        # tanh softcap bounds logits to (-cap, cap) before the loss.
        logits = self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap)
        return F.cross_entropy(logits.float(), targets, reduction="mean")


@torch.no_grad()
def recurrence_param_diagnostics(gpt: GPT) -> str:
    """Log parameter norms for recurrence health monitoring. 
No forward pass."""
    if not gpt.use_recurrence:
        return ""
    parts: list[str] = []

    # Birkhoff alpha stats per shared block
    for i, block in enumerate(gpt.shared_blocks):
        if hasattr(block, "resid_mix_logit"):
            a = torch.sigmoid(block.resid_mix_logit)
            parts.append(f"shared{i}_alpha:mean={a.mean().item():.3f},std={a.std().item():.3f}")

    # Effective MLP/attn contribution scale per virtual layer position:
    # mlp_scale * ts_mlp_gamma[v] is the learned multiplier on MLP output.
    # We report RMS (norm / sqrt(numel)) to make it scale-independent.
    v = 0
    effective_count = gpt.num_prelude + len(gpt.shared_blocks) * gpt.num_loops + gpt.num_coda
    mlp_norms: list[str] = []
    attn_norms: list[str] = []

    # Prelude blocks
    for block in gpt.prelude_blocks:
        ms = block.mlp_scale.norm().item()
        asc = block.attn_scale.norm().item()
        d = block.mlp_scale.numel() ** 0.5
        if gpt.timestep_scale is not None:
            ms_g = gpt.timestep_scale.mlp_gamma[v].norm().item()
            as_g = gpt.timestep_scale.attn_gamma[v].norm().item()
            mlp_norms.append(f"v{v}:{ms * ms_g / d:.4f}")
            attn_norms.append(f"v{v}:{asc * as_g / d:.4f}")
        else:
            mlp_norms.append(f"v{v}:{ms / d:.4f}")
            attn_norms.append(f"v{v}:{asc / d:.4f}")
        v += 1

    # Shared positions (same block parameters appear once per loop iteration,
    # but each virtual position v has its own timestep gamma).
    for _loop in range(gpt.num_loops):
        for block in gpt.shared_blocks:
            ms = block.mlp_scale.norm().item()
            asc = block.attn_scale.norm().item()
            d = block.mlp_scale.numel() ** 0.5
            if gpt.timestep_scale is not None:
                ms_g = gpt.timestep_scale.mlp_gamma[v].norm().item()
                as_g = gpt.timestep_scale.attn_gamma[v].norm().item()
                mlp_norms.append(f"v{v}:{ms * ms_g / d:.4f}")
                attn_norms.append(f"v{v}:{asc * as_g / d:.4f}")
            else:
                mlp_norms.append(f"v{v}:{ms / d:.4f}")
                attn_norms.append(f"v{v}:{asc / d:.4f}")
            v += 1

    # Coda blocks
    for block in gpt.coda_blocks:
        ms = block.mlp_scale.norm().item()
        asc = block.attn_scale.norm().item()
        d = block.mlp_scale.numel() ** 0.5
        if gpt.timestep_scale is not None:
            ms_g = gpt.timestep_scale.mlp_gamma[v].norm().item()
            as_g = gpt.timestep_scale.attn_gamma[v].norm().item()
            mlp_norms.append(f"v{v}:{ms * ms_g / d:.4f}")
            attn_norms.append(f"v{v}:{asc * as_g / d:.4f}")
        else:
            mlp_norms.append(f"v{v}:{ms / d:.4f}")
            attn_norms.append(f"v{v}:{asc / d:.4f}")
        v += 1

    parts.append("eff_mlp_scale:[" + " ".join(mlp_norms) + "]")
    parts.append("eff_attn_scale:[" + " ".join(attn_norms) + "]")
    if gpt.timestep_scale is not None and gpt.timestep_scale.attn_beta is not None:
        attn_bias_norms: list[str] = []
        mlp_bias_norms: list[str] = []
        for vi in range(effective_count):
            ab_rms = gpt.timestep_scale.attn_beta[vi].norm().item() / gpt.timestep_scale.attn_beta[vi].numel() ** 0.5
            mb_rms = gpt.timestep_scale.mlp_beta[vi].norm().item() / gpt.timestep_scale.mlp_beta[vi].numel() ** 0.5
            attn_bias_norms.append(f"v{vi}:{ab_rms:.4f}")
            mlp_bias_norms.append(f"v{vi}:{mb_rms:.4f}")
        parts.append("eff_attn_bias:[" + " ".join(attn_bias_norms) + "]")
        parts.append("eff_mlp_bias:[" + " ".join(mlp_bias_norms) + "]")
    if gpt.use_unique_norms:
        un_attn: list[str] = []
        un_mlp: list[str] = []
        for ui in range(gpt.unique_attn_gains.size(0)):
            an_rms = gpt.unique_attn_gains[ui].norm().item() / gpt.unique_attn_gains[ui].numel() ** 0.5
            un_attn.append(f"u{ui}:{an_rms:.4f}")
            mn_rms = gpt.unique_mlp_gains[ui].norm().item() / gpt.unique_mlp_gains[ui].numel() ** 0.5
            un_mlp.append(f"u{ui}:{mn_rms:.4f}")
        parts.append("unique_attn_gain_rms:[" + " ".join(un_attn) + "]")
        parts.append("unique_mlp_gain_rms:[" + " ".join(un_mlp) + "]")
    if gpt.use_depth_embed:
        parts.append("depth_encoding:sinusoidal")
    return " ".join(parts)


# -----------------------------
# TRAINING
# -----------------------------

def main() -> None:
    # zeropower_via_newtonschulz5 is rebound to its compiled form so Muon uses it.
    global zeropower_via_newtonschulz5

    # Full script source is logged below for reproducibility of each run.
    code = Path(__file__).read_text(encoding="utf-8")
    args = Hyperparameters()
    zeropower_via_newtonschulz5 = torch.compile(zeropower_via_newtonschulz5)

    # -----------------------------
    # DISTRIBUTED + CUDA SETUP
    # -----------------------------

    distributed = "RANK" in os.environ and "WORLD_SIZE" in os.environ
    rank = int(os.environ.get("RANK", "0"))
    world_size = int(os.environ.get("WORLD_SIZE", "1"))
    local_rank = int(os.environ.get("LOCAL_RANK", "0"))
    if world_size <= 0:
        raise ValueError(f"WORLD_SIZE must be positive, got {world_size}")
    if 8 % world_size != 0:
        raise ValueError(f"WORLD_SIZE={world_size} must divide 8 so grad_accum_steps stays integral")
    # Global batch is always split 8 ways: world_size ranks x grad_accum_steps.
    grad_accum_steps = 8 // world_size
    grad_scale = 1.0 / grad_accum_steps
    if not torch.cuda.is_available():
        raise RuntimeError("CUDA is required")
    device = torch.device("cuda", local_rank)
    torch.cuda.set_device(device)
    if distributed:
        dist.init_process_group(backend="nccl", device_id=device)
        dist.barrier()
    master_process = rank == 0

    # Fast math knobs
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True
    from torch.backends.cuda import enable_cudnn_sdp, enable_flash_sdp, enable_math_sdp, enable_mem_efficient_sdp

    enable_cudnn_sdp(False)
    enable_flash_sdp(True)
    enable_mem_efficient_sdp(False)
    enable_math_sdp(False)

    logfile = None
    if master_process:
        os.makedirs("logs", exist_ok=True)
        logfile = f"logs/{args.run_id}.txt"
        print(logfile)

    def log0(msg: str, console: bool = True) -> None:
        # Rank-0-only logging; always appends to the logfile, optionally echoes.
        if not master_process:
            return
        if console:
            print(msg)
        if logfile is not None:
            with open(logfile, "a", encoding="utf-8") as f:
                print(msg, file=f)

    log0(code, console=False)
    log0("=" * 100, console=False)
    log0(f"Running Python {sys.version}", console=False)
    log0(f"Running PyTorch {torch.__version__}", console=False)
    log0(
        subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, check=False).stdout,
        console=False,
    )
    log0("=" * 100, console=False)

    # -----------------------------
    # TOKENIZER + VALIDATION METRIC SETUP
    # -----------------------------

    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)

    if not args.tokenizer_path.endswith(".model"):
        raise ValueError(f"Script only setup for SentencePiece .model file: {args.tokenizer_path}")
    sp = spm.SentencePieceProcessor(model_file=args.tokenizer_path)
    if int(sp.vocab_size()) != args.vocab_size:
        raise ValueError(
            f"VOCAB_SIZE={args.vocab_size} does not match tokenizer vocab_size={int(sp.vocab_size())}"
        )
    dataset_dir = Path(args.data_path).resolve()
    actual_train_files = len(list(dataset_dir.glob("fineweb_train_*.bin")))
    val_tokens = load_validation_tokens(args.val_files, args.train_seq_len)
    base_bytes_lut, has_leading_space_lut, is_boundary_token_lut = build_sentencepiece_luts(
        sp, args.vocab_size, device
    )
    log0(f"val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path={args.tokenizer_path}")
    log0(f"train_loader:dataset:{dataset_dir.name} train_shards:{actual_train_files}")
    log0(f"val_loader:shards pattern={args.val_files} tokens:{val_tokens.numel() - 1}")

    # -----------------------------
    # MODEL + OPTIMIZER SETUP
    # -----------------------------

    base_model = GPT(
        vocab_size=args.vocab_size,
        num_layers=args.num_layers,
        model_dim=args.model_dim,
        num_heads=args.num_heads,
        num_kv_heads=args.num_kv_heads,
        mlp_mult=args.mlp_mult,
        tie_embeddings=args.tie_embeddings,
        tied_embed_init_std=args.tied_embed_init_std,
        logit_softcap=args.logit_softcap,
        rope_base=args.rope_base,
        qk_gain_init=args.qk_gain_init,
        num_prelude=args.num_prelude,
        num_coda=args.num_coda,
        num_shared=args.num_shared,
        num_loops=args.num_loops,
        use_peri_norm=args.use_peri_norm,
        use_birkhoff_mix=args.use_birkhoff_mix,
        use_timestep_scale=args.use_timestep_scale,
        use_timestep_bias=args.use_timestep_bias,
        use_depth_embed=args.use_depth_embed,
        depth_enc_base=args.depth_enc_base,
        use_unique_norms=args.use_unique_norms,
        timestep_gamma_max=args.timestep_gamma_max,
        leaky_relu_slope=args.leaky_relu_slope,
    ).to(device).bfloat16()
    # CastedLinear weights go back to fp32 (they cast at matmul time), as do
    # all small/control parameters.
    for module in base_model.modules():
        if isinstance(module, CastedLinear):
            module.float()
    restore_low_dim_params_to_fp32(base_model)
    if args.disable_compile:
        compiled_model = base_model
    else:
        compiled_model = torch.compile(base_model, dynamic=False, fullgraph=True)
    model: nn.Module = DDP(compiled_model, device_ids=[local_rank], broadcast_buffers=False) if distributed else compiled_model

    # Optimizer split:
    #  - token embedding (Adam) uses EMBED_LR
    #  - untied lm_head (Adam) uses HEAD_LR
    #  - matrix params in transformer blocks use MATRIX_LR via Muon
    #  - vectors/scalars use SCALAR_LR via Adam
    if base_model.use_recurrence:
        all_block_lists = [base_model.prelude_blocks, base_model.shared_blocks, base_model.coda_blocks]
        block_named_params = []
        for bl in all_block_lists:
            block_named_params.extend(bl.named_parameters())
        if base_model.use_unique_norms:
            block_named_params.extend([("unique_attn_gains", base_model.unique_attn_gains)])
            block_named_params.extend([("unique_mlp_gains", base_model.unique_mlp_gains)])
    else:
        block_named_params = list(base_model.blocks.named_parameters())
    # Control tensors are routed to Adam even when 2-D (Muon is for matrices only).
    matrix_params = [
        p
        for name, p in block_named_params
        if p.ndim == 2 and not any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS)
    ]
    scalar_params = [
        p
        for name, p in block_named_params
        if p.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS)
    ]
    if hasattr(base_model, "skip_weights") and base_model.skip_weights.numel() > 0:
        scalar_params.append(base_model.skip_weights)
    if base_model.timestep_scale is not None:
        for p in base_model.timestep_scale.parameters():
            scalar_params.append(p)
    token_lr = args.tied_embed_lr if args.tie_embeddings else args.embed_lr
    optimizer_tok = torch.optim.Adam(
        [{"params": [base_model.tok_emb.weight], "lr": token_lr, "base_lr": token_lr}],
        betas=(args.beta1, args.beta2),
        eps=args.adam_eps,
        fused=True,
    )
    optimizer_muon = Muon(
        matrix_params,
        lr=args.matrix_lr,
        momentum=args.muon_momentum,
        backend_steps=args.muon_backend_steps,
    )
    # "base_lr" is stashed in each group so the LR schedule can scale from it.
    for group in optimizer_muon.param_groups:
        group["base_lr"] = args.matrix_lr
    optimizer_scalar = torch.optim.Adam(
        [{"params": scalar_params, "lr": args.scalar_lr, "base_lr": args.scalar_lr}],
        betas=(args.beta1, args.beta2),
        eps=args.adam_eps,
        fused=True,
    )
    optimizers: list[torch.optim.Optimizer] = [optimizer_tok, optimizer_muon, optimizer_scalar]
    if base_model.lm_head is not None:
        optimizer_head = torch.optim.Adam(
            [{"params": [base_model.lm_head.weight], "lr": args.head_lr, "base_lr": args.head_lr}],
            betas=(args.beta1, args.beta2),
            eps=args.adam_eps,
            fused=True,
        )
        optimizers.insert(1, optimizer_head)

    n_params = sum(p.numel() for p in base_model.parameters())
    log0(f"model_params:{n_params}")
    log0(f"world_size:{world_size} grad_accum_steps:{grad_accum_steps}")
    log0("sdp_backends:cudnn=False flash=True mem_efficient=False math=False")
    log0(f"flash_attention_3:{FA3_AVAILABLE}")
    log0(f"attention_mode:gqa num_heads:{args.num_heads} num_kv_heads:{args.num_kv_heads}")
    log0(
        f"tie_embeddings:{args.tie_embeddings} embed_lr:{token_lr} "
        f"head_lr:{args.head_lr if base_model.lm_head is not None else 0.0} "
        f"matrix_lr:{args.matrix_lr} scalar_lr:{args.scalar_lr}"
    )
    log0(
        f"train_batch_tokens:{args.train_batch_tokens} train_seq_len:{args.train_seq_len} "
        f"iterations:{args.iterations} warmup_steps:{args.warmup_steps} "
        f"max_wallclock_seconds:{args.max_wallclock_seconds:.3f}"
    )
    log0(f"seed:{args.seed}")
    if base_model.use_recurrence:
        num_shared = len(base_model.shared_blocks)
        eff = base_model.num_prelude + num_shared * base_model.num_loops + base_model.num_coda
        log0(f"recurrence:enabled prelude:{base_model.num_prelude} shared:{num_shared} "
             f"loops:{base_model.num_loops} coda:{base_model.num_coda} effective_layers:{eff}")
        log0(f"peri_norm:{args.use_peri_norm} birkhoff_mix:{args.use_birkhoff_mix}")
        if base_model.timestep_scale is not None:
            ts_params = sum(p.numel() for p in base_model.timestep_scale.parameters())
            log0(f"timestep_scale:enabled params:{ts_params}")
        else:
            log0("timestep_scale:disabled")
        log0(f"depth_embed:{'enabled' if base_model.use_depth_embed else 'disabled'}")
        log0(f"unique_norms:{'enabled' if base_model.use_unique_norms else 'disabled'}")
    else:
        log0(f"recurrence:disabled num_layers:{args.num_layers}")
    compile_mode = "disabled" if args.disable_compile else "fullgraph=True"
    log0(f"compile_mode:{compile_mode}")

    # -----------------------------
    # DATA LOADER & MODEL WARMUP
    # -----------------------------

    train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device)

    def zero_grad_all() -> None:
        for opt in optimizers:
            opt.zero_grad(set_to_none=True)

    max_wallclock_ms = 1000.0 * args.max_wallclock_seconds if args.max_wallclock_seconds > 0 else None

    def lr_mul(step: int, elapsed_ms: float) -> float:
        # LR multiplier: 1.0 until warmdown, then a linear ramp to 0 — either on a
        # fixed iteration schedule, or dynamically sized from the measured
        # step time when a wallclock budget is set.
        if args.warmdown_iters <= 0:
            return 1.0
        if max_wallclock_ms is None:
            warmdown_start = max(args.iterations - args.warmdown_iters, 0)
            return max((args.iterations - step) / max(args.warmdown_iters, 1), 0.0) if warmdown_start <= step < args.iterations else 1.0
        step_ms = elapsed_ms / max(step, 1)
        warmdown_ms = args.warmdown_iters * step_ms
        remaining_ms = max(max_wallclock_ms - elapsed_ms, 0.0)
        return remaining_ms / max(warmdown_ms, 1e-9) if remaining_ms <= warmdown_ms else 1.0

    # Warmup primes the compiled forward/backward/optimizer paths, then we restore the
    # initial weights/optimizer state so measured training starts from the true init.
+ if args.warmup_steps > 0: + initial_model_state = {name: tensor.detach().cpu().clone() for name, tensor in base_model.state_dict().items()} + initial_optimizer_states = [copy.deepcopy(opt.state_dict()) for opt in optimizers] + model.train() + for warmup_step in range(args.warmup_steps): + zero_grad_all() + for micro_step in range(grad_accum_steps): + if distributed: + model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 + x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + warmup_loss = model(x, y) + (warmup_loss * grad_scale).backward() + for opt in optimizers: + opt.step() + zero_grad_all() + if args.warmup_steps <= 20 or (warmup_step + 1) % 10 == 0 or warmup_step + 1 == args.warmup_steps: + log0(f"warmup_step:{warmup_step + 1}/{args.warmup_steps}") + base_model.load_state_dict(initial_model_state, strict=True) + for opt, state in zip(optimizers, initial_optimizer_states, strict=True): + opt.load_state_dict(state) + zero_grad_all() + if distributed: + model.require_backward_grad_sync = True + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + + # ----------------------------- + # MAIN TRAINING LOOP + # ----------------------------- + + training_time_ms = 0.0 + stop_after_step: int | None = None + torch.cuda.synchronize() + t0 = time.perf_counter() + + step = 0 + while True: + last_step = step == args.iterations or (stop_after_step is not None and step >= stop_after_step) + + should_validate = last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0) + if should_validate: + torch.cuda.synchronize() + training_time_ms += 1000.0 * (time.perf_counter() - t0) + val_loss, val_bpb = eval_val( + args, + model, + rank, + world_size, + device, + grad_accum_steps, + val_tokens, + base_bytes_lut, + has_leading_space_lut, + is_boundary_token_lut, + ) + log0( + 
f"step:{step}/{args.iterations} val_loss:{val_loss:.4f} val_bpb:{val_bpb:.4f} " + f"train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms / max(step, 1):.2f}ms" + ) + torch.cuda.synchronize() + t0 = time.perf_counter() + + if last_step: + if stop_after_step is not None and step < args.iterations: + log0( + f"stopping_early: wallclock_cap train_time:{training_time_ms:.0f}ms " + f"step:{step}/{args.iterations}" + ) + break + + elapsed_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + scale = lr_mul(step, elapsed_ms) + zero_grad_all() + train_loss = torch.zeros((), device=device) + for micro_step in range(grad_accum_steps): + if distributed: + model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 + x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + loss = model(x, y) + train_loss += loss.detach() + (loss * grad_scale).backward() + train_loss /= grad_accum_steps + + frac = min(step / args.muon_momentum_warmup_steps, 1.0) if args.muon_momentum_warmup_steps > 0 else 1.0 + muon_momentum = (1 - frac) * args.muon_momentum_warmup_start + frac * args.muon_momentum + for group in optimizer_muon.param_groups: + group["momentum"] = muon_momentum + + for opt in optimizers: + for group in opt.param_groups: + group["lr"] = group["base_lr"] * scale + + if args.grad_clip_norm > 0: + torch.nn.utils.clip_grad_norm_(base_model.parameters(), args.grad_clip_norm) + for opt in optimizers: + opt.step() + zero_grad_all() + + step += 1 + approx_training_time_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + should_log_train = ( + args.train_log_every > 0 + and (step <= 10 or step % args.train_log_every == 0 or stop_after_step is not None) + ) + if should_log_train: + log0( + f"step:{step}/{args.iterations} train_loss:{train_loss.item():.4f} " + f"train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms 
/ step:.2f}ms" + ) + if base_model.use_recurrence and step % args.train_log_every == 0: + diag = recurrence_param_diagnostics(base_model) + if diag: + log0(f"step:{step} {diag}") + + # Needed to sync whether we've reached the wallclock cap. + reached_cap = max_wallclock_ms is not None and approx_training_time_ms >= max_wallclock_ms + if distributed and max_wallclock_ms is not None: + reached_cap_tensor = torch.tensor(int(reached_cap), device=device) + dist.all_reduce(reached_cap_tensor, op=dist.ReduceOp.MAX) + reached_cap = bool(reached_cap_tensor.item()) + if stop_after_step is None and reached_cap: + stop_after_step = step + + log0( + f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB " + f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB" + ) + + # ----------------------------- + # SERIALIZATION + ROUNDTRIP VALIDATION + # ----------------------------- + # Save the raw state (useful for debugging/loading in PyTorch directly), then always produce + # the compressed int8+zlib artifact and validate the round-tripped weights. 
+ + if master_process: + torch.save(base_model.state_dict(), "final_model.pt") + model_bytes = os.path.getsize("final_model.pt") + code_bytes = len(code.encode("utf-8")) + log0(f"Serialized model: {model_bytes} bytes") + log0(f"Code size: {code_bytes} bytes") + log0(f"Total submission size: {model_bytes + code_bytes} bytes") + + quant_obj, quant_stats = quantize_state_dict_int8(base_model.state_dict()) + quant_buf = io.BytesIO() + torch.save(quant_obj, quant_buf) + quant_raw = quant_buf.getvalue() + quant_blob = zlib.compress(quant_raw, level=9) + quant_raw_bytes = len(quant_raw) + if master_process: + with open("final_model.int8.ptz", "wb") as f: + f.write(quant_blob) + quant_file_bytes = os.path.getsize("final_model.int8.ptz") + code_bytes = len(code.encode("utf-8")) + ratio = quant_stats["baseline_tensor_bytes"] / max(quant_stats["int8_payload_bytes"], 1) + log0( + f"Serialized model int8+zlib: {quant_file_bytes} bytes " + f"(payload:{quant_stats['int8_payload_bytes']} raw_torch:{quant_raw_bytes} payload_ratio:{ratio:.2f}x)" + ) + log0(f"Total submission size int8+zlib: {quant_file_bytes + code_bytes} bytes") + + if distributed: + dist.barrier() + with open("final_model.int8.ptz", "rb") as f: + quant_blob_disk = f.read() + quant_state = torch.load(io.BytesIO(zlib.decompress(quant_blob_disk)), map_location="cpu") + base_model.load_state_dict(dequantize_state_dict_int8(quant_state), strict=True) + torch.cuda.synchronize() + t_qeval = time.perf_counter() + q_val_loss, q_val_bpb = eval_val( + args, + model, + rank, + world_size, + device, + grad_accum_steps, + val_tokens, + base_bytes_lut, + has_leading_space_lut, + is_boundary_token_lut, + ) + torch.cuda.synchronize() + log0( + f"final_int8_zlib_roundtrip val_loss:{q_val_loss:.4f} val_bpb:{q_val_bpb:.4f} " + f"eval_time:{1000.0 * (time.perf_counter() - t_qeval):.0f}ms" + ) + log0(f"final_int8_zlib_roundtrip_exact val_loss:{q_val_loss:.8f} val_bpb:{q_val_bpb:.8f}") + + if distributed: + dist.destroy_process_group() 
+ + +if __name__ == "__main__": + main() + +==================================================================================================== +Running Python 3.12.3 (main, Nov 6 2025, 13:44:16) [GCC 13.3.0] +Running PyTorch 2.9.1+cu128 +Thu Apr 2 17:29:56 2026 ++-----------------------------------------------------------------------------------------+ +| NVIDIA-SMI 580.126.09 Driver Version: 580.126.09 CUDA Version: 13.0 | ++-----------------------------------------+------------------------+----------------------+ +| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC | +| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. | +| | | MIG M. | +|=========================================+========================+======================| +| 0 NVIDIA H100 80GB HBM3 On | 00000000:19:00.0 Off | 0 | +| N/A 33C P0 117W / 700W | 1521MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 1 NVIDIA H100 80GB HBM3 On | 00000000:3B:00.0 Off | 0 | +| N/A 30C P0 113W / 700W | 1521MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 2 NVIDIA H100 80GB HBM3 On | 00000000:4C:00.0 Off | 0 | +| N/A 29C P0 116W / 700W | 1521MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 3 NVIDIA H100 80GB HBM3 On | 00000000:5D:00.0 Off | 0 | +| N/A 31C P0 113W / 700W | 1521MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 4 NVIDIA H100 80GB HBM3 On | 00000000:9B:00.0 Off | 0 | +| N/A 33C P0 117W / 700W | 1521MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 5 NVIDIA H100 80GB HBM3 On | 00000000:BB:00.0 Off | 0 | +| N/A 31C 
P0 115W / 700W | 1521MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 6 NVIDIA H100 80GB HBM3 On | 00000000:CB:00.0 Off | 0 | +| N/A 31C P0 116W / 700W | 1521MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 7 NVIDIA H100 80GB HBM3 On | 00000000:DB:00.0 Off | 0 | +| N/A 29C P0 117W / 700W | 1521MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ + ++-----------------------------------------------------------------------------------------+ +| Processes: | +| GPU GI CI PID Type Process name GPU Memory | +| ID ID Usage | +|=========================================================================================| +| No running processes found | ++-----------------------------------------------------------------------------------------+ + +==================================================================================================== +val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path=./data/tokenizers/fineweb_1024_bpe.model +train_loader:dataset:fineweb10B_sp1024 train_shards:80 +val_loader:shards pattern=./data/datasets/fineweb10B_sp1024/fineweb_val_*.bin tokens:62021632 +model_params:11572272 +world_size:8 grad_accum_steps:1 +sdp_backends:cudnn=False flash=True mem_efficient=False math=False +flash_attention_3:True +attention_mode:gqa num_heads:8 num_kv_heads:4 +tie_embeddings:True embed_lr:0.05 head_lr:0.0 matrix_lr:0.04 scalar_lr:0.04 +train_batch_tokens:524288 train_seq_len:1024 iterations:20000 warmup_steps:20 max_wallclock_seconds:600.000 +seed:1337 +recurrence:enabled prelude:1 shared:4 loops:3 coda:1 effective_layers:14 +peri_norm:True birkhoff_mix:True +timestep_scale:enabled params:28672 +depth_embed:enabled +unique_norms:disabled +compile_mode:fullgraph=True 
+warmup_step:1/20 +warmup_step:2/20 +warmup_step:3/20 +warmup_step:4/20 +warmup_step:5/20 +warmup_step:6/20 +warmup_step:7/20 +warmup_step:8/20 +warmup_step:9/20 +warmup_step:10/20 +warmup_step:11/20 +warmup_step:12/20 +warmup_step:13/20 +warmup_step:14/20 +warmup_step:15/20 +warmup_step:16/20 +warmup_step:17/20 +warmup_step:18/20 +warmup_step:19/20 +warmup_step:20/20 +step:0/20000 val_loss:6.9332 val_bpb:4.1062 train_time:0ms step_avg:0.01ms +step:1/20000 train_loss:6.9329 train_time:38ms step_avg:38.33ms +step:2/20000 train_loss:9.5771 train_time:96ms step_avg:47.94ms +step:3/20000 train_loss:12.0056 train_time:154ms step_avg:51.33ms +step:4/20000 train_loss:11.0951 train_time:211ms step_avg:52.78ms +step:5/20000 train_loss:9.6599 train_time:271ms step_avg:54.14ms +step:6/20000 train_loss:8.8458 train_time:328ms step_avg:54.59ms +step:7/20000 train_loss:7.6174 train_time:385ms step_avg:54.93ms +step:8/20000 train_loss:7.0656 train_time:442ms step_avg:55.30ms +step:9/20000 train_loss:6.6004 train_time:500ms step_avg:55.52ms +step:10/20000 train_loss:6.0300 train_time:557ms step_avg:55.74ms +step:200/20000 train_loss:2.8222 train_time:11696ms step_avg:58.48ms +step:200 shared0_alpha:mean=0.473,std=0.047 shared1_alpha:mean=0.489,std=0.042 shared2_alpha:mean=0.494,std=0.047 shared3_alpha:mean=0.516,std=0.047 eff_mlp_scale:[v0:42.1731 v1:27.6985 v2:28.2939 v3:30.5156 v4:32.0949 v5:30.6609 v6:28.7337 v7:30.6697 v8:33.5249 v9:31.6977 v10:31.3725 v11:34.2144 v12:35.4315 v13:50.2958] eff_attn_scale:[v0:14.9320 v1:10.1316 v2:10.8969 v3:9.6005 v4:9.3961 v5:10.0503 v6:9.6550 v7:8.9120 v8:9.2401 v9:10.2130 v10:9.9354 v11:9.2180 v12:9.8639 v13:16.0045] eff_attn_bias:[v0:0.1630 v1:0.1450 v2:0.1581 v3:0.1616 v4:0.1650 v5:0.1692 v6:0.1733 v7:0.1795 v8:0.1851 v9:0.1961 v10:0.1961 v11:0.1878 v12:0.1864 v13:0.1485] eff_mlp_bias:[v0:0.1609 v1:0.1595 v2:0.1540 v3:0.1512 v4:0.1581 v5:0.1643 v6:0.1630 v7:0.1733 v8:0.1906 v9:0.1906 v10:0.1754 v11:0.1733 v12:0.1547 v13:0.1864] 
depth_encoding:sinusoidal +step:200/20000 val_loss:2.8059 val_bpb:1.6618 train_time:11751ms step_avg:58.75ms +step:400/20000 train_loss:2.3856 train_time:23480ms step_avg:58.70ms +step:400 shared0_alpha:mean=0.483,std=0.049 shared1_alpha:mean=0.509,std=0.044 shared2_alpha:mean=0.517,std=0.049 shared3_alpha:mean=0.545,std=0.050 eff_mlp_scale:[v0:50.5521 v1:32.5033 v2:35.1304 v3:37.9512 v4:38.0848 v5:41.3063 v6:38.4290 v7:38.7945 v8:39.2541 v9:40.9677 v10:38.4290 v11:39.6379 v12:37.9178 v13:66.6226] eff_attn_scale:[v0:6.5055 v1:5.4401 v2:5.9484 v3:5.0682 v4:4.9102 v5:5.8673 v6:5.6164 v7:4.9169 v8:5.0114 v9:5.6679 v10:5.2014 v11:4.4378 v12:4.6318 v13:9.4584] eff_attn_bias:[v0:0.2168 v1:0.1947 v2:0.2085 v3:0.2058 v4:0.2113 v5:0.2141 v6:0.2141 v7:0.2003 v8:0.2058 v9:0.2182 v10:0.2182 v11:0.1989 v12:0.2003 v13:0.1595] eff_mlp_bias:[v0:0.2955 v1:0.2293 v2:0.1947 v3:0.1878 v4:0.1906 v5:0.1906 v6:0.1733 v7:0.1782 v8:0.1989 v9:0.2044 v10:0.1823 v11:0.1761 v12:0.1692 v13:0.2417] depth_encoding:sinusoidal +step:400/20000 val_loss:2.5853 val_bpb:1.5312 train_time:23505ms step_avg:58.76ms +step:600/20000 train_loss:2.5924 train_time:35261ms step_avg:58.77ms +step:600 shared0_alpha:mean=0.488,std=0.049 shared1_alpha:mean=0.521,std=0.047 shared2_alpha:mean=0.532,std=0.050 shared3_alpha:mean=0.566,std=0.052 eff_mlp_scale:[v0:54.6022 v1:36.1924 v2:39.3814 v3:42.9007 v4:41.9468 v5:48.4979 v6:44.9576 v7:44.3131 v8:43.1502 v9:46.6882 v10:42.3437 v11:42.3711 v12:39.5400 v13:83.2264] eff_attn_scale:[v0:3.0099 v1:3.1470 v2:3.4999 v3:3.0893 v4:3.0057 v5:3.4597 v6:3.4027 v7:3.1254 v8:3.1677 v9:3.2447 v10:3.0333 v11:2.6737 v12:2.7177 v13:6.2262] eff_attn_bias:[v0:0.2735 v1:0.2444 v2:0.2472 v3:0.2472 v4:0.2610 v5:0.2665 v6:0.2569 v7:0.2237 v8:0.2293 v9:0.2403 v10:0.2306 v11:0.2030 v12:0.2099 v13:0.1823] eff_mlp_bias:[v0:0.4392 v1:0.2886 v2:0.2417 v3:0.2334 v4:0.2500 v5:0.2320 v6:0.1892 v7:0.1906 v8:0.2168 v9:0.2182 v10:0.1892 v11:0.1837 v12:0.1864 v13:0.2500] depth_encoding:sinusoidal 
+step:600/20000 val_loss:2.4888 val_bpb:1.4740 train_time:35285ms step_avg:58.81ms +step:800/20000 train_loss:2.3501 train_time:47075ms step_avg:58.84ms +step:800 shared0_alpha:mean=0.491,std=0.048 shared1_alpha:mean=0.529,std=0.049 shared2_alpha:mean=0.540,std=0.052 shared3_alpha:mean=0.579,std=0.054 eff_mlp_scale:[v0:59.9201 v1:40.1580 v2:43.4057 v3:46.3101 v4:44.5735 v5:54.1754 v6:49.9165 v7:48.3078 v8:46.1591 v9:50.7658 v10:45.0334 v11:44.4940 v12:41.4022 v13:96.0354] eff_attn_scale:[v0:1.9268 v1:2.2402 v2:2.4450 v3:2.2894 v4:2.2408 v5:2.4269 v6:2.4606 v7:2.3638 v8:2.3872 v9:2.2869 v10:2.1335 v11:1.9326 v12:1.9186 v13:4.7429] eff_attn_bias:[v0:0.3135 v1:0.2817 v2:0.2831 v3:0.2817 v4:0.3094 v5:0.3066 v6:0.2942 v7:0.2514 v8:0.2569 v9:0.2610 v10:0.2444 v11:0.2085 v12:0.2210 v13:0.2154] eff_mlp_bias:[v0:0.5635 v1:0.3370 v2:0.2845 v3:0.2748 v4:0.2997 v5:0.2721 v6:0.2099 v7:0.2044 v8:0.2389 v9:0.2362 v10:0.2003 v11:0.1975 v12:0.2072 v13:0.2569] depth_encoding:sinusoidal +step:800/20000 val_loss:2.4288 val_bpb:1.4385 train_time:47100ms step_avg:58.88ms +step:1000/20000 train_loss:2.4207 train_time:58897ms step_avg:58.90ms +step:1000 shared0_alpha:mean=0.492,std=0.050 shared1_alpha:mean=0.534,std=0.051 shared2_alpha:mean=0.546,std=0.053 shared3_alpha:mean=0.588,std=0.054 eff_mlp_scale:[v0:65.1131 v1:43.7501 v2:46.8400 v3:49.6567 v4:47.2571 v5:58.8567 v6:54.4912 v7:51.8968 v8:49.4216 v9:54.5405 v10:47.7731 v11:47.0432 v12:43.6497 v13:107.1387] eff_attn_scale:[v0:1.4494 v1:1.7948 v2:1.9658 v3:1.8958 v4:1.8762 v5:1.9716 v6:1.9931 v7:2.0026 v8:1.9544 v9:1.8220 v10:1.7132 v11:1.6021 v12:1.5635 v13:3.8871] eff_attn_bias:[v0:0.3466 v1:0.3094 v2:0.3135 v3:0.3135 v4:0.3480 v5:0.3439 v6:0.3356 v7:0.2790 v8:0.2859 v9:0.2831 v10:0.2583 v11:0.2154 v12:0.2306 v13:0.2458] eff_mlp_bias:[v0:0.6712 v1:0.3784 v2:0.3246 v3:0.3107 v4:0.3384 v5:0.3149 v6:0.2306 v7:0.2196 v8:0.2652 v9:0.2555 v10:0.2141 v11:0.2127 v12:0.2251 v13:0.2638] depth_encoding:sinusoidal +step:1000/20000 
val_loss:2.3864 val_bpb:1.4134 train_time:58922ms step_avg:58.92ms +step:1200/20000 train_loss:2.4329 train_time:70714ms step_avg:58.93ms +step:1200 shared0_alpha:mean=0.491,std=0.051 shared1_alpha:mean=0.540,std=0.052 shared2_alpha:mean=0.551,std=0.054 shared3_alpha:mean=0.594,std=0.056 eff_mlp_scale:[v0:70.1553 v1:47.0161 v2:49.9708 v3:52.2430 v4:49.7842 v5:62.9572 v6:57.9814 v7:54.9123 v8:51.6281 v9:57.7108 v10:49.5894 v11:48.8109 v12:45.9121 v13:116.7423] eff_attn_scale:[v0:1.1874 v1:1.5626 v2:1.6931 v3:1.6728 v4:1.5858 v5:1.6429 v6:1.7182 v7:1.7595 v8:1.7059 v9:1.5441 v10:1.4548 v11:1.3754 v12:1.3335 v13:3.3954] eff_attn_bias:[v0:0.3784 v1:0.3439 v2:0.3411 v3:0.3384 v4:0.3812 v5:0.3784 v6:0.3701 v7:0.3066 v8:0.3149 v9:0.3066 v10:0.2721 v11:0.2237 v12:0.2431 v13:0.2748] eff_mlp_bias:[v0:0.7679 v1:0.4198 v2:0.3591 v3:0.3397 v4:0.3757 v5:0.3522 v6:0.2541 v7:0.2362 v8:0.2914 v9:0.2748 v10:0.2293 v11:0.2279 v12:0.2444 v13:0.2721] depth_encoding:sinusoidal +step:1200/20000 val_loss:2.3567 val_bpb:1.3957 train_time:70739ms step_avg:58.95ms +step:1400/20000 train_loss:2.4819 train_time:82526ms step_avg:58.95ms +step:1400 shared0_alpha:mean=0.491,std=0.052 shared1_alpha:mean=0.543,std=0.054 shared2_alpha:mean=0.554,std=0.055 shared3_alpha:mean=0.598,std=0.057 eff_mlp_scale:[v0:74.8502 v1:50.2471 v2:53.1146 v3:55.3364 v4:51.9688 v5:66.5826 v6:60.9256 v7:57.6745 v8:53.8518 v9:60.7928 v10:51.9430 v11:51.0498 v12:48.2030 v13:125.4454] eff_attn_scale:[v0:0.9967 v1:1.3821 v2:1.5224 v3:1.4748 v4:1.4563 v5:1.4576 v6:1.5224 v7:1.5561 v8:1.5245 v9:1.3646 v10:1.2922 v11:1.2077 v12:1.1889 v13:3.1047] eff_attn_bias:[v0:0.3950 v1:0.3701 v2:0.3646 v3:0.3674 v4:0.4088 v5:0.4060 v6:0.4033 v7:0.3328 v8:0.3425 v9:0.3273 v10:0.2886 v11:0.2306 v12:0.2541 v13:0.3038] eff_mlp_bias:[v0:0.8507 v1:0.4530 v2:0.3895 v3:0.3674 v4:0.4033 v5:0.3867 v6:0.2790 v7:0.2555 v8:0.3163 v9:0.2969 v10:0.2444 v11:0.2444 v12:0.2624 v13:0.2831] depth_encoding:sinusoidal +step:1400/20000 val_loss:2.3334 
val_bpb:1.3820 train_time:82551ms step_avg:58.96ms +step:1600/20000 train_loss:2.1555 train_time:94337ms step_avg:58.96ms +step:1600 shared0_alpha:mean=0.492,std=0.053 shared1_alpha:mean=0.547,std=0.055 shared2_alpha:mean=0.557,std=0.056 shared3_alpha:mean=0.599,std=0.059 eff_mlp_scale:[v0:79.5614 v1:53.7939 v2:55.9060 v3:58.2668 v4:54.0765 v5:69.6156 v6:63.4394 v7:60.2486 v8:55.9941 v9:62.8650 v10:53.1305 v11:52.7176 v12:50.2412 v13:132.0925] eff_attn_scale:[v0:0.8713 v1:1.2774 v2:1.4277 v3:1.3605 v4:1.3307 v5:1.3443 v6:1.4107 v7:1.4438 v8:1.3796 v9:1.2551 v10:1.1831 v11:1.0995 v12:1.0863 v13:2.9157] eff_attn_bias:[v0:0.4116 v1:0.4033 v2:0.3839 v3:0.3922 v4:0.4337 v5:0.4309 v6:0.4309 v7:0.3563 v8:0.3674 v9:0.3439 v10:0.3025 v11:0.2403 v12:0.2665 v13:0.3342] eff_mlp_bias:[v0:0.9226 v1:0.4889 v2:0.4171 v3:0.3922 v4:0.4337 v5:0.4171 v6:0.3025 v7:0.2748 v8:0.3397 v9:0.3176 v10:0.2610 v11:0.2610 v12:0.2804 v13:0.2969] depth_encoding:sinusoidal +step:1600/20000 val_loss:2.3219 val_bpb:1.3752 train_time:94362ms step_avg:58.98ms +step:1800/20000 train_loss:2.2543 train_time:106138ms step_avg:58.97ms +step:1800 shared0_alpha:mean=0.491,std=0.054 shared1_alpha:mean=0.551,std=0.055 shared2_alpha:mean=0.559,std=0.057 shared3_alpha:mean=0.601,std=0.060 eff_mlp_scale:[v0:84.1182 v1:56.3085 v2:58.5366 v3:60.7906 v4:56.2029 v5:72.6423 v6:66.2069 v7:62.8035 v8:58.1544 v9:65.7649 v10:55.3070 v11:54.7518 v12:52.3000 v13:139.4565] eff_attn_scale:[v0:0.7733 v1:1.1822 v2:1.3352 v3:1.2542 v4:1.2462 v5:1.2034 v6:1.3187 v7:1.3183 v8:1.3040 v9:1.1239 v10:1.0879 v11:1.0034 v12:0.9938 v13:2.7729] eff_attn_bias:[v0:0.4254 v1:0.4254 v2:0.4033 v3:0.4143 v4:0.4613 v5:0.4558 v6:0.4558 v7:0.3784 v8:0.3895 v9:0.3646 v10:0.3176 v11:0.2514 v12:0.2790 v13:0.3618] eff_mlp_bias:[v0:0.9888 v1:0.5220 v2:0.4475 v3:0.4198 v4:0.4585 v5:0.4530 v6:0.3273 v7:0.2914 v8:0.3618 v9:0.3397 v10:0.2817 v11:0.2776 v12:0.2983 v13:0.3135] depth_encoding:sinusoidal +step:1800/20000 val_loss:2.3035 val_bpb:1.3642 
train_time:106163ms step_avg:58.98ms +step:2000/20000 train_loss:2.3066 train_time:117933ms step_avg:58.97ms +step:2000 shared0_alpha:mean=0.491,std=0.055 shared1_alpha:mean=0.554,std=0.056 shared2_alpha:mean=0.561,std=0.058 shared3_alpha:mean=0.602,std=0.061 eff_mlp_scale:[v0:88.9970 v1:59.4835 v2:61.5244 v3:63.3619 v4:57.7963 v5:75.6665 v6:68.9073 v7:65.4058 v8:60.1715 v9:68.2311 v10:57.0126 v11:56.4125 v12:54.2335 v13:146.0011] eff_attn_scale:[v0:0.7016 v1:1.1244 v2:1.2590 v3:1.1668 v4:1.1665 v5:1.1400 v6:1.2218 v7:1.2339 v8:1.2169 v9:1.0571 v10:1.0093 v11:0.9345 v12:0.9190 v13:2.5910] eff_attn_bias:[v0:0.4309 v1:0.4502 v2:0.4226 v3:0.4337 v4:0.4861 v5:0.4778 v6:0.4834 v7:0.3977 v8:0.4116 v9:0.3812 v10:0.3315 v11:0.2610 v12:0.2900 v13:0.3895] eff_mlp_bias:[v0:1.0551 v1:0.5524 v2:0.4806 v3:0.4447 v4:0.4834 v5:0.4861 v6:0.3522 v7:0.3121 v8:0.3839 v9:0.3618 v10:0.3025 v11:0.2942 v12:0.3149 v13:0.3328] depth_encoding:sinusoidal +step:2000/20000 val_loss:2.2881 val_bpb:1.3551 train_time:117958ms step_avg:58.98ms +step:2200/20000 train_loss:2.1276 train_time:129731ms step_avg:58.97ms +step:2200 shared0_alpha:mean=0.491,std=0.055 shared1_alpha:mean=0.557,std=0.056 shared2_alpha:mean=0.563,std=0.058 shared3_alpha:mean=0.604,std=0.062 eff_mlp_scale:[v0:93.2070 v1:61.6985 v2:64.0873 v3:65.4437 v4:60.1313 v5:78.1219 v6:71.1619 v7:67.1005 v8:62.5527 v9:70.1321 v10:58.6773 v11:58.4023 v12:56.4992 v13:153.3788] eff_attn_scale:[v0:0.6433 v1:1.0833 v2:1.2209 v3:1.1398 v4:1.1340 v5:1.0631 v6:1.1739 v7:1.1858 v8:1.1844 v9:0.9922 v10:0.9600 v11:0.8996 v12:0.8971 v13:2.4914] eff_attn_bias:[v0:0.4392 v1:0.4723 v2:0.4447 v3:0.4558 v4:0.5055 v5:0.4944 v6:0.5055 v7:0.4171 v8:0.4337 v9:0.3977 v10:0.3494 v11:0.2693 v12:0.3025 v13:0.4171] eff_mlp_bias:[v0:1.1159 v1:0.5800 v2:0.5110 v3:0.4668 v4:0.5082 v5:0.5138 v6:0.3757 v7:0.3301 v8:0.4060 v9:0.3839 v10:0.3204 v11:0.3107 v12:0.3301 v13:0.3508] depth_encoding:sinusoidal +step:2200/20000 val_loss:2.2811 val_bpb:1.3510 train_time:129756ms 
step_avg:58.98ms +step:2400/20000 train_loss:2.2517 train_time:141522ms step_avg:58.97ms +step:2400 shared0_alpha:mean=0.490,std=0.057 shared1_alpha:mean=0.559,std=0.058 shared2_alpha:mean=0.564,std=0.059 shared3_alpha:mean=0.604,std=0.063 eff_mlp_scale:[v0:97.3088 v1:64.4124 v2:66.1193 v3:68.0254 v4:61.7178 v5:81.0785 v6:73.2787 v7:69.7050 v8:63.7614 v9:72.5202 v10:60.2233 v11:60.0471 v12:58.4480 v13:159.8243] eff_attn_scale:[v0:0.6018 v1:1.0217 v2:1.1559 v3:1.0927 v4:1.0802 v5:1.0070 v6:1.1152 v7:1.1426 v8:1.1342 v9:0.9333 v10:0.9166 v11:0.8482 v12:0.8543 v13:2.4043] eff_attn_bias:[v0:0.4475 v1:0.4972 v2:0.4613 v3:0.4723 v4:0.5276 v5:0.5138 v6:0.5276 v7:0.4364 v8:0.4558 v9:0.4143 v10:0.3646 v11:0.2776 v12:0.3135 v13:0.4419] eff_mlp_bias:[v0:1.1822 v1:0.6049 v2:0.5386 v3:0.4889 v4:0.5303 v5:0.5441 v6:0.4005 v7:0.3466 v8:0.4254 v9:0.4060 v10:0.3397 v11:0.3246 v12:0.3425 v13:0.3701] depth_encoding:sinusoidal +step:2400/20000 val_loss:2.2693 val_bpb:1.3440 train_time:141547ms step_avg:58.98ms +step:2600/20000 train_loss:2.4645 train_time:153292ms step_avg:58.96ms +step:2600 shared0_alpha:mean=0.490,std=0.058 shared1_alpha:mean=0.562,std=0.058 shared2_alpha:mean=0.565,std=0.060 shared3_alpha:mean=0.604,std=0.063 eff_mlp_scale:[v0:101.6481 v1:67.0256 v2:69.3391 v3:70.7851 v4:63.5850 v5:83.4400 v6:75.7594 v7:71.6379 v8:66.4942 v9:74.3209 v10:62.4908 v11:62.2568 v12:60.6759 v13:165.2093] eff_attn_scale:[v0:0.5697 v1:0.9969 v2:1.1324 v3:1.0532 v4:1.0312 v5:1.0017 v6:1.0975 v7:1.0772 v8:1.0689 v9:0.9110 v10:0.8930 v11:0.8079 v12:0.8052 v13:2.3357] eff_attn_bias:[v0:0.4613 v1:0.5193 v2:0.4834 v3:0.4889 v4:0.5497 v5:0.5276 v6:0.5497 v7:0.4530 v8:0.4723 v9:0.4254 v10:0.3757 v11:0.2886 v12:0.3273 v13:0.4640] eff_mlp_bias:[v0:1.2374 v1:0.6381 v2:0.5690 v3:0.5138 v4:0.5497 v5:0.5745 v6:0.4198 v7:0.3646 v8:0.4475 v9:0.4226 v10:0.3591 v11:0.3397 v12:0.3563 v13:0.3867] depth_encoding:sinusoidal +step:2600/20000 val_loss:2.2779 val_bpb:1.3491 train_time:153317ms step_avg:58.97ms 
+step:2800/20000 train_loss:2.2945 train_time:165082ms step_avg:58.96ms +step:2800 shared0_alpha:mean=0.489,std=0.058 shared1_alpha:mean=0.564,std=0.058 shared2_alpha:mean=0.566,std=0.060 shared3_alpha:mean=0.604,std=0.064 eff_mlp_scale:[v0:105.7923 v1:69.8320 v2:71.9417 v3:72.7933 v4:65.5945 v5:86.0182 v6:77.5757 v7:73.6547 v8:67.6969 v9:76.7689 v10:63.7074 v11:63.3172 v12:62.2307 v13:170.3488] eff_attn_scale:[v0:0.5409 v1:0.9771 v2:1.0912 v3:1.0090 v4:1.0242 v5:0.9392 v6:1.0372 v7:1.0185 v8:1.0431 v9:0.8680 v10:0.8504 v11:0.7722 v12:0.7930 v13:2.2121] eff_attn_bias:[v0:0.4613 v1:0.5386 v2:0.4999 v3:0.5055 v4:0.5718 v5:0.5441 v6:0.5690 v7:0.4696 v8:0.4917 v9:0.4392 v10:0.3867 v11:0.2983 v12:0.3370 v13:0.4889] eff_mlp_bias:[v0:1.2982 v1:0.6602 v2:0.5966 v3:0.5359 v4:0.5718 v5:0.6021 v6:0.4419 v7:0.3812 v8:0.4640 v9:0.4447 v10:0.3757 v11:0.3563 v12:0.3701 v13:0.4033] depth_encoding:sinusoidal +step:2800/20000 val_loss:2.2549 val_bpb:1.3355 train_time:165107ms step_avg:58.97ms +step:3000/20000 train_loss:2.2781 train_time:176850ms step_avg:58.95ms +step:3000 shared0_alpha:mean=0.489,std=0.058 shared1_alpha:mean=0.567,std=0.059 shared2_alpha:mean=0.567,std=0.061 shared3_alpha:mean=0.604,std=0.065 eff_mlp_scale:[v0:109.7013 v1:71.9646 v2:74.0094 v3:74.9589 v4:67.7205 v5:87.8528 v6:79.7024 v7:75.3947 v8:69.8501 v9:78.0395 v10:65.2509 v11:64.9353 v12:64.3132 v13:176.7194] eff_attn_scale:[v0:0.5126 v1:0.9590 v2:1.0788 v3:0.9688 v4:0.9963 v5:0.9123 v6:1.0202 v7:1.0111 v8:1.0197 v9:0.8374 v10:0.8299 v11:0.7571 v12:0.7718 v13:2.1389] eff_attn_bias:[v0:0.4640 v1:0.5635 v2:0.5165 v3:0.5220 v4:0.5883 v5:0.5580 v6:0.5883 v7:0.4861 v8:0.5082 v9:0.4530 v10:0.3977 v11:0.3052 v12:0.3494 v13:0.5138] eff_mlp_bias:[v0:1.3590 v1:0.6878 v2:0.6270 v3:0.5524 v4:0.5911 v5:0.6270 v6:0.4640 v7:0.3950 v8:0.4834 v9:0.4640 v10:0.3950 v11:0.3701 v12:0.3839 v13:0.4226] depth_encoding:sinusoidal +step:3000/20000 val_loss:2.2483 val_bpb:1.3316 train_time:176875ms step_avg:58.96ms +step:3200/20000 
train_loss:2.2395 train_time:188622ms step_avg:58.94ms +step:3200 shared0_alpha:mean=0.487,std=0.059 shared1_alpha:mean=0.569,std=0.059 shared2_alpha:mean=0.568,std=0.061 shared3_alpha:mean=0.604,std=0.065 eff_mlp_scale:[v0:114.2085 v1:74.3596 v2:76.6599 v3:77.5110 v4:69.5039 v5:90.4630 v6:81.9773 v7:77.0705 v8:71.6624 v9:80.5168 v10:66.9112 v11:66.9413 v12:66.0503 v13:181.8108] eff_attn_scale:[v0:0.4854 v1:0.9360 v2:1.0429 v3:0.9650 v4:0.9903 v5:0.8804 v6:0.9946 v7:0.9839 v8:1.0042 v9:0.8109 v10:0.8015 v11:0.7532 v12:0.7532 v13:2.0938] eff_attn_bias:[v0:0.4751 v1:0.5828 v2:0.5303 v3:0.5331 v4:0.6104 v5:0.5745 v6:0.6104 v7:0.5027 v8:0.5248 v9:0.4640 v10:0.4116 v11:0.3135 v12:0.3618 v13:0.5331] eff_mlp_bias:[v0:1.4087 v1:0.7182 v2:0.6491 v3:0.5718 v4:0.6077 v5:0.6491 v6:0.4861 v7:0.4116 v8:0.5027 v9:0.4861 v10:0.4116 v11:0.3867 v12:0.3977 v13:0.4392] depth_encoding:sinusoidal +step:3200/20000 val_loss:2.2434 val_bpb:1.3287 train_time:188646ms step_avg:58.95ms +step:3400/20000 train_loss:2.2143 train_time:200387ms step_avg:58.94ms +step:3400 shared0_alpha:mean=0.487,std=0.060 shared1_alpha:mean=0.572,std=0.060 shared2_alpha:mean=0.569,std=0.062 shared3_alpha:mean=0.604,std=0.066 eff_mlp_scale:[v0:118.7020 v1:77.2060 v2:78.9053 v3:79.2049 v4:71.1050 v5:92.5513 v6:83.8369 v7:79.2049 v8:72.8499 v9:82.4809 v10:68.5938 v11:68.5256 v12:68.0514 v13:186.2671] eff_attn_scale:[v0:0.4577 v1:0.9144 v2:1.0188 v3:0.9579 v4:0.9676 v5:0.8598 v6:0.9572 v7:0.9718 v8:0.9768 v9:0.7825 v10:0.7866 v11:0.7254 v12:0.7326 v13:2.0363] eff_attn_bias:[v0:0.4723 v1:0.6049 v2:0.5552 v3:0.5497 v4:0.6325 v5:0.5883 v6:0.6270 v7:0.5193 v8:0.5414 v9:0.4778 v10:0.4254 v11:0.3218 v12:0.3729 v13:0.5524] eff_mlp_bias:[v0:1.4584 v1:0.7403 v2:0.6767 v3:0.5939 v4:0.6242 v5:0.6740 v6:0.5082 v7:0.4281 v8:0.5165 v9:0.5027 v10:0.4309 v11:0.4033 v12:0.4116 v13:0.4558] depth_encoding:sinusoidal +step:3400/20000 val_loss:2.2412 val_bpb:1.3274 train_time:200411ms step_avg:58.94ms +step:3600/20000 train_loss:2.1808 
train_time:212152ms step_avg:58.93ms +step:3600 shared0_alpha:mean=0.485,std=0.060 shared1_alpha:mean=0.574,std=0.060 shared2_alpha:mean=0.569,std=0.062 shared3_alpha:mean=0.603,std=0.066 eff_mlp_scale:[v0:122.4656 v1:79.4080 v2:81.2694 v3:81.4648 v4:72.8739 v5:94.4181 v6:85.8096 v7:81.0147 v8:75.0822 v9:84.2500 v10:70.3730 v11:70.2127 v12:69.7823 v13:191.2313] eff_attn_scale:[v0:0.4533 v1:0.9054 v2:1.0046 v3:0.9258 v4:0.9416 v5:0.8423 v6:0.9527 v7:0.9396 v8:0.9597 v9:0.7702 v10:0.7640 v11:0.7058 v12:0.7153 v13:2.0134] eff_attn_bias:[v0:0.4778 v1:0.6215 v2:0.5773 v3:0.5690 v4:0.6463 v5:0.6021 v6:0.6463 v7:0.5303 v8:0.5607 v9:0.4917 v10:0.4337 v11:0.3315 v12:0.3839 v13:0.5718] eff_mlp_bias:[v0:1.5137 v1:0.7679 v2:0.7043 v3:0.6160 v4:0.6436 v5:0.6988 v6:0.5303 v7:0.4447 v8:0.5359 v9:0.5248 v10:0.4502 v11:0.4171 v12:0.4254 v13:0.4723] depth_encoding:sinusoidal +step:3600/20000 val_loss:2.2334 val_bpb:1.3227 train_time:212177ms step_avg:58.94ms +step:3800/20000 train_loss:2.2765 train_time:223908ms step_avg:58.92ms +step:3800 shared0_alpha:mean=0.486,std=0.060 shared1_alpha:mean=0.576,std=0.060 shared2_alpha:mean=0.570,std=0.063 shared3_alpha:mean=0.603,std=0.066 eff_mlp_scale:[v0:126.0984 v1:81.8079 v2:83.4886 v3:84.1340 v4:74.6995 v5:96.9937 v6:88.0759 v7:82.7696 v8:76.9360 v9:86.2167 v10:71.5617 v11:71.8549 v12:72.0157 v13:195.6981] eff_attn_scale:[v0:0.4313 v1:0.8895 v2:0.9993 v3:0.9230 v4:0.9478 v5:0.8311 v6:0.9427 v7:0.9321 v8:0.9433 v9:0.7503 v10:0.7589 v11:0.7037 v12:0.7086 v13:1.9990] eff_attn_bias:[v0:0.4778 v1:0.6436 v2:0.5939 v3:0.5828 v4:0.6629 v5:0.6132 v6:0.6574 v7:0.5441 v8:0.5800 v9:0.5027 v10:0.4447 v11:0.3384 v12:0.3950 v13:0.5911] eff_mlp_bias:[v0:1.5578 v1:0.7955 v2:0.7292 v3:0.6325 v4:0.6574 v5:0.7182 v6:0.5497 v7:0.4613 v8:0.5552 v9:0.5414 v10:0.4668 v11:0.4309 v12:0.4392 v13:0.4889] depth_encoding:sinusoidal +step:3800/20000 val_loss:2.2287 val_bpb:1.3200 train_time:223939ms step_avg:58.93ms +step:4000/20000 train_loss:2.2149 train_time:235662ms 
step_avg:58.92ms +step:4000 shared0_alpha:mean=0.485,std=0.061 shared1_alpha:mean=0.579,std=0.061 shared2_alpha:mean=0.570,std=0.063 shared3_alpha:mean=0.603,std=0.067 eff_mlp_scale:[v0:129.9269 v1:84.2808 v2:85.7783 v3:86.0343 v4:76.3379 v5:99.1539 v6:90.4149 v7:85.1141 v8:78.5964 v9:88.2469 v10:73.7229 v11:73.6122 v12:73.6277 v13:200.7250] eff_attn_scale:[v0:0.4180 v1:0.8840 v2:0.9935 v3:0.9250 v4:0.9492 v5:0.8129 v6:0.9279 v7:0.9114 v8:0.9402 v9:0.7329 v10:0.7404 v11:0.6847 v12:0.7051 v13:1.9278] eff_attn_bias:[v0:0.4834 v1:0.6629 v2:0.6132 v3:0.5994 v4:0.6822 v5:0.6270 v6:0.6767 v7:0.5552 v8:0.5939 v9:0.5138 v10:0.4558 v11:0.3480 v12:0.4060 v13:0.6104] eff_mlp_bias:[v0:1.6020 v1:0.8176 v2:0.7568 v3:0.6519 v4:0.6795 v5:0.7403 v6:0.5690 v7:0.4751 v8:0.5690 v9:0.5580 v10:0.4834 v11:0.4447 v12:0.4502 v13:0.5055] depth_encoding:sinusoidal +step:4000/20000 val_loss:2.2243 val_bpb:1.3174 train_time:235693ms step_avg:58.92ms +step:4200/20000 train_loss:2.2312 train_time:247477ms step_avg:58.92ms +step:4200 shared0_alpha:mean=0.484,std=0.061 shared1_alpha:mean=0.581,std=0.061 shared2_alpha:mean=0.570,std=0.063 shared3_alpha:mean=0.603,std=0.067 eff_mlp_scale:[v0:133.8872 v1:86.5654 v2:88.4886 v3:88.3419 v4:78.0057 v5:101.0763 v6:91.7659 v7:86.9471 v8:79.8304 v9:89.5676 v10:74.9110 v11:74.8582 v12:75.2687 v13:204.7848] eff_attn_scale:[v0:0.4050 v1:0.8744 v2:0.9886 v3:0.8961 v4:0.9314 v5:0.8041 v6:0.9054 v7:0.9006 v8:0.9225 v9:0.7207 v10:0.7299 v11:0.6765 v12:0.6830 v13:1.9229] eff_attn_bias:[v0:0.4834 v1:0.6822 v2:0.6325 v3:0.6104 v4:0.7043 v5:0.6381 v6:0.6905 v7:0.5718 v8:0.6104 v9:0.5248 v10:0.4668 v11:0.3536 v12:0.4143 v13:0.6298] eff_mlp_bias:[v0:1.6352 v1:0.8397 v2:0.7844 v3:0.6740 v4:0.6961 v5:0.7623 v6:0.5883 v7:0.4917 v8:0.5856 v9:0.5773 v10:0.4999 v11:0.4585 v12:0.4613 v13:0.5220] depth_encoding:sinusoidal +step:4200/20000 val_loss:2.2204 val_bpb:1.3150 train_time:247508ms step_avg:58.93ms +step:4400/20000 train_loss:2.1674 train_time:259233ms step_avg:58.92ms 
+step:4400 shared0_alpha:mean=0.484,std=0.062 shared1_alpha:mean=0.583,std=0.062 shared2_alpha:mean=0.571,std=0.064 shared3_alpha:mean=0.602,std=0.067 eff_mlp_scale:[v0:138.3447 v1:88.9576 v2:91.3003 v3:90.6148 v4:80.3670 v5:103.1100 v6:93.6656 v7:88.2673 v8:81.7527 v9:91.4848 v10:76.6355 v11:76.5296 v12:77.1339 v13:209.9459] eff_attn_scale:[v0:0.3903 v1:0.8642 v2:0.9837 v3:0.8862 v4:0.9046 v5:0.7813 v6:0.9005 v7:0.8952 v8:0.9002 v9:0.7027 v10:0.7112 v11:0.6714 v12:0.6763 v13:1.8581] eff_attn_bias:[v0:0.4861 v1:0.7016 v2:0.6463 v3:0.6270 v4:0.7237 v5:0.6491 v6:0.7071 v7:0.5856 v8:0.6325 v9:0.5359 v10:0.4778 v11:0.3618 v12:0.4254 v13:0.6436] eff_mlp_bias:[v0:1.6794 v1:0.8618 v2:0.8065 v3:0.6961 v4:0.7126 v5:0.7844 v6:0.6104 v7:0.5055 v8:0.6049 v9:0.5911 v10:0.5138 v11:0.4723 v12:0.4723 v13:0.5331] depth_encoding:sinusoidal +step:4400/20000 val_loss:2.2216 val_bpb:1.3158 train_time:259263ms step_avg:58.92ms +step:4600/20000 train_loss:2.0304 train_time:271001ms step_avg:58.91ms +step:4600 shared0_alpha:mean=0.483,std=0.062 shared1_alpha:mean=0.586,std=0.062 shared2_alpha:mean=0.571,std=0.065 shared3_alpha:mean=0.601,std=0.068 eff_mlp_scale:[v0:142.5881 v1:91.9992 v2:93.7048 v3:92.9953 v4:82.0128 v5:105.7990 v6:95.6171 v7:90.6230 v8:83.4107 v9:93.5325 v10:77.9279 v11:78.2868 v12:78.7509 v13:214.2951] eff_attn_scale:[v0:0.3849 v1:0.8685 v2:0.9893 v3:0.8841 v4:0.9260 v5:0.7903 v6:0.8886 v7:0.8841 v8:0.9171 v9:0.7078 v10:0.7099 v11:0.6631 v12:0.6823 v13:1.8677] eff_attn_bias:[v0:0.4861 v1:0.7182 v2:0.6574 v3:0.6408 v4:0.7403 v5:0.6602 v6:0.7237 v7:0.5966 v8:0.6463 v9:0.5497 v10:0.4889 v11:0.3701 v12:0.4364 v13:0.6602] eff_mlp_bias:[v0:1.7125 v1:0.8894 v2:0.8286 v3:0.7126 v4:0.7292 v5:0.8065 v6:0.6325 v7:0.5193 v8:0.6215 v9:0.6077 v10:0.5303 v11:0.4861 v12:0.4861 v13:0.5524] depth_encoding:sinusoidal +step:4600/20000 val_loss:2.2170 val_bpb:1.3130 train_time:271032ms step_avg:58.92ms +step:4800/20000 train_loss:2.3192 train_time:282759ms step_avg:58.91ms +step:4800 
shared0_alpha:mean=0.481,std=0.063 shared1_alpha:mean=0.588,std=0.063 shared2_alpha:mean=0.571,std=0.065 shared3_alpha:mean=0.601,std=0.069 eff_mlp_scale:[v0:146.2769 v1:93.7678 v2:96.1486 v3:94.8808 v4:83.7838 v5:107.6784 v6:98.0813 v7:92.4848 v8:84.7251 v9:95.3134 v10:79.7212 v11:80.0257 v12:80.4889 v13:218.0837] eff_attn_scale:[v0:0.3777 v1:0.8454 v2:0.9594 v3:0.8779 v4:0.9276 v5:0.7643 v6:0.8779 v7:0.8867 v8:0.9055 v9:0.6874 v10:0.6969 v11:0.6606 v12:0.6714 v13:1.8193] eff_attn_bias:[v0:0.4861 v1:0.7347 v2:0.6740 v3:0.6546 v4:0.7568 v5:0.6767 v6:0.7403 v7:0.6104 v8:0.6602 v9:0.5580 v10:0.5027 v11:0.3757 v12:0.4447 v13:0.6767] eff_mlp_bias:[v0:1.7567 v1:0.9115 v2:0.8563 v3:0.7292 v4:0.7458 v5:0.8231 v6:0.6491 v7:0.5331 v8:0.6353 v9:0.6270 v10:0.5497 v11:0.5027 v12:0.4999 v13:0.5635] depth_encoding:sinusoidal +step:4800/20000 val_loss:2.2128 val_bpb:1.3106 train_time:282790ms step_avg:58.91ms +step:5000/20000 train_loss:2.0898 train_time:294519ms step_avg:58.90ms +step:5000 shared0_alpha:mean=0.481,std=0.063 shared1_alpha:mean=0.590,std=0.063 shared2_alpha:mean=0.571,std=0.065 shared3_alpha:mean=0.600,std=0.069 eff_mlp_scale:[v0:151.4040 v1:95.6262 v2:98.5672 v3:96.7358 v4:85.5539 v5:109.6583 v6:100.0311 v7:94.3174 v8:86.5045 v9:96.6656 v10:81.0008 v11:81.2580 v12:81.7515 v13:221.9333] eff_attn_scale:[v0:0.3686 v1:0.8479 v2:0.9673 v3:0.8527 v4:0.9195 v5:0.7623 v6:0.8810 v7:0.8658 v8:0.9020 v9:0.6852 v10:0.6948 v11:0.6472 v12:0.6655 v13:1.7874] eff_attn_bias:[v0:0.4889 v1:0.7513 v2:0.6933 v3:0.6684 v4:0.7734 v5:0.6850 v6:0.7568 v7:0.6215 v8:0.6795 v9:0.5718 v10:0.5110 v11:0.3839 v12:0.4558 v13:0.6933] eff_mlp_bias:[v0:1.7899 v1:0.9336 v2:0.8784 v3:0.7458 v4:0.7623 v5:0.8452 v6:0.6657 v7:0.5469 v8:0.6546 v9:0.6408 v10:0.5662 v11:0.5138 v12:0.5110 v13:0.5745] depth_encoding:sinusoidal +step:5000/20000 val_loss:2.2072 val_bpb:1.3072 train_time:294549ms step_avg:58.91ms +step:5200/20000 train_loss:2.2253 train_time:306278ms step_avg:58.90ms +step:5200 
shared0_alpha:mean=0.481,std=0.064 shared1_alpha:mean=0.593,std=0.063 shared2_alpha:mean=0.572,std=0.065 shared3_alpha:mean=0.599,std=0.069 eff_mlp_scale:[v0:155.1819 v1:98.6015 v2:101.2260 v3:99.5478 v4:87.4473 v5:111.7134 v6:102.2135 v7:96.1320 v8:88.4082 v9:98.6015 v10:82.9559 v11:82.9565 v12:84.0839 v13:226.0120] eff_attn_scale:[v0:0.3579 v1:0.8480 v2:0.9673 v3:0.8740 v4:0.9194 v5:0.7576 v6:0.8679 v7:0.8566 v8:0.8889 v9:0.6801 v10:0.6916 v11:0.6392 v12:0.6667 v13:1.7798] eff_attn_bias:[v0:0.4917 v1:0.7734 v2:0.7126 v3:0.6850 v4:0.7900 v5:0.6988 v6:0.7734 v7:0.6353 v8:0.6933 v9:0.5800 v10:0.5220 v11:0.3922 v12:0.4613 v13:0.7071] eff_mlp_bias:[v0:1.8230 v1:0.9557 v2:0.9005 v3:0.7679 v4:0.7844 v5:0.8618 v6:0.6878 v7:0.5635 v8:0.6684 v9:0.6574 v10:0.5828 v11:0.5276 v12:0.5220 v13:0.5883] depth_encoding:sinusoidal +step:5200/20000 val_loss:2.2086 val_bpb:1.3081 train_time:306309ms step_avg:58.91ms +step:5400/20000 train_loss:2.2355 train_time:318037ms step_avg:58.90ms +step:5400 shared0_alpha:mean=0.480,std=0.064 shared1_alpha:mean=0.595,std=0.063 shared2_alpha:mean=0.572,std=0.066 shared3_alpha:mean=0.598,std=0.069 eff_mlp_scale:[v0:159.4047 v1:100.5377 v2:103.8090 v3:102.0535 v4:89.2185 v5:113.7664 v6:104.3080 v7:98.1094 v8:90.1882 v9:100.0086 v10:84.3448 v11:84.3051 v12:85.8243 v13:228.3482] eff_attn_scale:[v0:0.3527 v1:0.8505 v2:0.9617 v3:0.8711 v4:0.9155 v5:0.7560 v6:0.8623 v7:0.8492 v8:0.8894 v9:0.6701 v10:0.6908 v11:0.6434 v12:0.6540 v13:1.7649] eff_attn_bias:[v0:0.4999 v1:0.7900 v2:0.7237 v3:0.7016 v4:0.8065 v5:0.7126 v6:0.7844 v7:0.6463 v8:0.7071 v9:0.5911 v10:0.5331 v11:0.4005 v12:0.4723 v13:0.7237] eff_mlp_bias:[v0:1.8562 v1:0.9778 v2:0.9226 v3:0.7844 v4:0.8010 v5:0.8784 v6:0.7043 v7:0.5745 v8:0.6850 v9:0.6740 v10:0.5994 v11:0.5414 v12:0.5331 v13:0.5994] depth_encoding:sinusoidal +step:5400/20000 val_loss:2.2039 val_bpb:1.3053 train_time:318068ms step_avg:58.90ms +step:5600/20000 train_loss:2.2378 train_time:329787ms step_avg:58.89ms +step:5600 
shared0_alpha:mean=0.479,std=0.064 shared1_alpha:mean=0.597,std=0.064 shared2_alpha:mean=0.571,std=0.066 shared3_alpha:mean=0.598,std=0.070 eff_mlp_scale:[v0:163.0470 v1:102.9921 v2:105.7195 v3:104.0865 v4:90.8663 v5:115.7994 v6:106.2229 v7:100.1023 v8:91.8434 v9:101.3912 v10:86.0859 v11:86.1577 v12:86.9581 v13:232.2047] eff_attn_scale:[v0:0.3429 v1:0.8505 v2:0.9585 v3:0.8656 v4:0.9370 v5:0.7603 v6:0.8636 v7:0.8612 v8:0.8801 v9:0.6701 v10:0.6782 v11:0.6383 v12:0.6568 v13:1.7430] eff_attn_bias:[v0:0.4972 v1:0.8121 v2:0.7458 v3:0.7126 v4:0.8286 v5:0.7237 v6:0.8010 v7:0.6602 v8:0.7237 v9:0.6021 v10:0.5441 v11:0.4060 v12:0.4834 v13:0.7403] eff_mlp_bias:[v0:1.8893 v1:1.0054 v2:0.9502 v3:0.8010 v4:0.8231 v5:0.9005 v6:0.7182 v7:0.5856 v8:0.6988 v9:0.6878 v10:0.6132 v11:0.5524 v12:0.5441 v13:0.6132] depth_encoding:sinusoidal +step:5600/20000 val_loss:2.2037 val_bpb:1.3052 train_time:329817ms step_avg:58.90ms +step:5800/20000 train_loss:2.2066 train_time:341545ms step_avg:58.89ms +step:5800 shared0_alpha:mean=0.479,std=0.064 shared1_alpha:mean=0.599,std=0.064 shared2_alpha:mean=0.571,std=0.067 shared3_alpha:mean=0.597,std=0.070 eff_mlp_scale:[v0:168.1565 v1:105.4822 v2:108.8509 v3:106.4394 v4:92.7424 v5:117.8602 v6:108.3422 v7:101.4187 v8:93.2357 v9:103.3295 v10:87.4876 v11:87.3607 v12:89.2892 v13:236.3641] eff_attn_scale:[v0:0.3342 v1:0.8420 v2:0.9565 v3:0.8563 v4:0.9415 v5:0.7442 v6:0.8577 v7:0.8520 v8:0.8933 v9:0.6676 v10:0.6781 v11:0.6358 v12:0.6568 v13:1.6983] eff_attn_bias:[v0:0.4972 v1:0.8286 v2:0.7623 v3:0.7292 v4:0.8452 v5:0.7347 v6:0.8121 v7:0.6684 v8:0.7347 v9:0.6160 v10:0.5524 v11:0.4143 v12:0.4917 v13:0.7568] eff_mlp_bias:[v0:1.9224 v1:1.0220 v2:0.9723 v3:0.8176 v4:0.8397 v5:0.9226 v6:0.7347 v7:0.5994 v8:0.7182 v9:0.7043 v10:0.6270 v11:0.5662 v12:0.5552 v13:0.6242] depth_encoding:sinusoidal +step:5800/20000 val_loss:2.2023 val_bpb:1.3043 train_time:341575ms step_avg:58.89ms +step:6000/20000 train_loss:2.2704 train_time:353300ms step_avg:58.88ms +step:6000 
shared0_alpha:mean=0.478,std=0.064 shared1_alpha:mean=0.602,std=0.065 shared2_alpha:mean=0.571,std=0.067 shared3_alpha:mean=0.596,std=0.070 eff_mlp_scale:[v0:171.7213 v1:108.0624 v2:111.3780 v3:108.5207 v4:94.4264 v5:120.0090 v6:110.3515 v7:103.4496 v8:94.9234 v9:104.8042 v10:88.7944 v11:89.2507 v12:90.4505 v13:239.8397] eff_attn_scale:[v0:0.3340 v1:0.8365 v2:0.9637 v3:0.8631 v4:0.9384 v5:0.7445 v6:0.8561 v7:0.8459 v8:0.8862 v9:0.6525 v10:0.6723 v11:0.6226 v12:0.6473 v13:1.7154] eff_attn_bias:[v0:0.4999 v1:0.8507 v2:0.7789 v3:0.7403 v4:0.8673 v5:0.7513 v6:0.8286 v7:0.6822 v8:0.7458 v9:0.6270 v10:0.5635 v11:0.4254 v12:0.4999 v13:0.7734] eff_mlp_bias:[v0:1.9445 v1:1.0441 v2:0.9944 v3:0.8342 v4:0.8563 v5:0.9447 v6:0.7568 v7:0.6160 v8:0.7347 v9:0.7237 v10:0.6436 v11:0.5800 v12:0.5662 v13:0.6353] depth_encoding:sinusoidal +step:6000/20000 val_loss:2.1985 val_bpb:1.3021 train_time:353330ms step_avg:58.89ms +step:6200/20000 train_loss:2.1470 train_time:365057ms step_avg:58.88ms +step:6200 shared0_alpha:mean=0.478,std=0.065 shared1_alpha:mean=0.604,std=0.065 shared2_alpha:mean=0.571,std=0.067 shared3_alpha:mean=0.595,std=0.070 eff_mlp_scale:[v0:175.7898 v1:109.9845 v2:113.7156 v3:111.2663 v4:96.1211 v5:122.0226 v6:112.6771 v7:105.6261 v8:96.6217 v9:106.7014 v10:90.8686 v11:91.2692 v12:92.1160 v13:244.0034] eff_attn_scale:[v0:0.3192 v1:0.8359 v2:0.9590 v3:0.8515 v4:0.9179 v5:0.7309 v6:0.8559 v7:0.8515 v8:0.8746 v9:0.6427 v10:0.6767 v11:0.6224 v12:0.6494 v13:1.6625] eff_attn_bias:[v0:0.5027 v1:0.8728 v2:0.8010 v3:0.7568 v4:0.8784 v5:0.7623 v6:0.8397 v7:0.6933 v8:0.7568 v9:0.6353 v10:0.5745 v11:0.4309 v12:0.5110 v13:0.7844] eff_mlp_bias:[v0:1.9777 v1:1.0717 v2:1.0165 v3:0.8507 v4:0.8728 v5:0.9557 v6:0.7679 v7:0.6270 v8:0.7458 v9:0.7347 v10:0.6574 v11:0.5911 v12:0.5773 v13:0.6463] depth_encoding:sinusoidal +step:6200/20000 val_loss:2.1973 val_bpb:1.3014 train_time:365088ms step_avg:58.89ms +step:6400/20000 train_loss:2.2205 train_time:376823ms step_avg:58.88ms +step:6400 
shared0_alpha:mean=0.477,std=0.065 shared1_alpha:mean=0.606,std=0.065 shared2_alpha:mean=0.571,std=0.068 shared3_alpha:mean=0.595,std=0.071 eff_mlp_scale:[v0:180.9867 v1:112.6642 v2:116.4174 v3:113.8787 v4:98.5493 v5:124.2620 v6:114.8442 v7:107.6671 v8:98.0439 v9:108.2460 v10:92.8193 v11:92.6558 v12:94.0008 v13:247.6979] eff_attn_scale:[v0:0.3199 v1:0.8530 v2:0.9636 v3:0.8780 v4:0.9439 v5:0.7384 v6:0.8516 v7:0.8476 v8:0.8827 v9:0.6451 v10:0.6678 v11:0.6259 v12:0.6511 v13:1.6925] eff_attn_bias:[v0:0.5027 v1:0.8894 v2:0.8121 v3:0.7679 v4:0.8949 v5:0.7734 v6:0.8563 v7:0.7043 v8:0.7734 v9:0.6463 v10:0.5828 v11:0.4392 v12:0.5193 v13:0.8010] eff_mlp_bias:[v0:2.0108 v1:1.0883 v2:1.0330 v3:0.8673 v4:0.8839 v5:0.9778 v6:0.7844 v7:0.6381 v8:0.7679 v9:0.7513 v10:0.6712 v11:0.6049 v12:0.5883 v13:0.6546] depth_encoding:sinusoidal +step:6400/20000 val_loss:2.1936 val_bpb:1.2992 train_time:376848ms step_avg:58.88ms +step:6600/20000 train_loss:2.1762 train_time:388580ms step_avg:58.88ms +step:6600 shared0_alpha:mean=0.476,std=0.066 shared1_alpha:mean=0.609,std=0.066 shared2_alpha:mean=0.571,std=0.068 shared3_alpha:mean=0.593,std=0.071 eff_mlp_scale:[v0:184.6921 v1:115.1496 v2:118.9919 v3:115.8578 v4:100.3595 v5:125.7189 v6:116.8765 v7:109.5952 v8:99.8501 v9:109.5868 v10:94.1358 v11:94.4606 v12:95.2651 v13:251.7797] eff_attn_scale:[v0:0.3123 v1:0.8518 v2:0.9686 v3:0.8629 v4:0.9508 v5:0.7411 v6:0.8559 v7:0.8455 v8:0.8810 v9:0.6474 v10:0.6757 v11:0.6287 v12:0.6542 v13:1.6763] eff_attn_bias:[v0:0.5027 v1:0.9115 v2:0.8286 v3:0.7844 v4:0.9115 v5:0.7900 v6:0.8728 v7:0.7126 v8:0.7900 v9:0.6602 v10:0.5966 v11:0.4447 v12:0.5331 v13:0.8121] eff_mlp_bias:[v0:2.0329 v1:1.1049 v2:1.0551 v3:0.8894 v4:0.9060 v5:0.9944 v6:0.8010 v7:0.6491 v8:0.7789 v9:0.7679 v10:0.6850 v11:0.6187 v12:0.5994 v13:0.6684] depth_encoding:sinusoidal +step:6600/20000 val_loss:2.1906 val_bpb:1.2974 train_time:388605ms step_avg:58.88ms +step:6800/20000 train_loss:2.2449 train_time:400393ms step_avg:58.88ms +step:6800 
shared0_alpha:mean=0.475,std=0.066 shared1_alpha:mean=0.611,std=0.066 shared2_alpha:mean=0.570,std=0.068 shared3_alpha:mean=0.592,std=0.071 eff_mlp_scale:[v0:188.5722 v1:117.1606 v2:121.4109 v3:118.2960 v4:102.3443 v5:127.8116 v6:118.7483 v7:111.4611 v8:101.3158 v9:111.5549 v10:95.8507 v11:95.6883 v12:97.2014 v13:253.8321] eff_attn_scale:[v0:0.3050 v1:0.8442 v2:0.9597 v3:0.8727 v4:0.9510 v5:0.7345 v6:0.8398 v7:0.8383 v8:0.8772 v9:0.6374 v10:0.6576 v11:0.6147 v12:0.6557 v13:1.6447] eff_attn_bias:[v0:0.5082 v1:0.9281 v2:0.8452 v3:0.8010 v4:0.9336 v5:0.8065 v6:0.8839 v7:0.7237 v8:0.8010 v9:0.6712 v10:0.6049 v11:0.4502 v12:0.5414 v13:0.8286] eff_mlp_bias:[v0:2.0661 v1:1.1270 v2:1.0772 v3:0.9005 v4:0.9226 v5:1.0054 v6:0.8176 v7:0.6629 v8:0.7955 v9:0.7844 v10:0.7016 v11:0.6298 v12:0.6104 v13:0.6795] depth_encoding:sinusoidal +step:6800/20000 val_loss:2.1889 val_bpb:1.2964 train_time:400432ms step_avg:58.89ms +step:7000/20000 train_loss:2.2807 train_time:412163ms step_avg:58.88ms +step:7000 shared0_alpha:mean=0.474,std=0.066 shared1_alpha:mean=0.613,std=0.067 shared2_alpha:mean=0.571,std=0.068 shared3_alpha:mean=0.591,std=0.071 eff_mlp_scale:[v0:193.8277 v1:119.3186 v2:124.8006 v3:120.4276 v4:103.6154 v5:130.0629 v6:121.0351 v7:113.5309 v8:103.0973 v9:113.0982 v10:97.3660 v11:97.6153 v12:98.9527 v13:257.8575] eff_attn_scale:[v0:0.3045 v1:0.8598 v2:0.9760 v3:0.8848 v4:0.9615 v5:0.7364 v6:0.8506 v7:0.8456 v8:0.8828 v9:0.6470 v10:0.6581 v11:0.6277 v12:0.6555 v13:1.6466] eff_attn_bias:[v0:0.5082 v1:0.9502 v2:0.8618 v3:0.8176 v4:0.9447 v5:0.8176 v6:0.9005 v7:0.7403 v8:0.8121 v9:0.6795 v10:0.6160 v11:0.4585 v12:0.5497 v13:0.8397] eff_mlp_bias:[v0:2.0992 v1:1.1490 v2:1.0938 v3:0.9170 v4:0.9336 v5:1.0275 v6:0.8342 v7:0.6767 v8:0.8065 v9:0.7955 v10:0.7126 v11:0.6408 v12:0.6187 v13:0.6905] depth_encoding:sinusoidal +step:7000/20000 val_loss:2.1884 val_bpb:1.2961 train_time:412186ms step_avg:58.88ms +step:7200/20000 train_loss:2.2563 train_time:423926ms step_avg:58.88ms +step:7200 
shared0_alpha:mean=0.474,std=0.067 shared1_alpha:mean=0.616,std=0.067 shared2_alpha:mean=0.571,std=0.069 shared3_alpha:mean=0.590,std=0.071 eff_mlp_scale:[v0:197.6152 v1:121.8241 v2:126.7557 v3:122.5043 v4:106.0428 v5:131.5017 v6:122.9639 v7:115.5499 v8:104.4757 v9:114.4235 v10:99.1295 v11:98.9664 v12:100.2967 v13:261.8293] eff_attn_scale:[v0:0.2982 v1:0.8658 v2:0.9792 v3:0.8818 v4:0.9702 v5:0.7427 v6:0.8451 v7:0.8556 v8:0.8824 v9:0.6451 v10:0.6662 v11:0.6286 v12:0.6629 v13:1.6167] eff_attn_bias:[v0:0.5055 v1:0.9667 v2:0.8784 v3:0.8286 v4:0.9667 v5:0.8286 v6:0.9115 v7:0.7513 v8:0.8231 v9:0.6905 v10:0.6242 v11:0.4640 v12:0.5607 v13:0.8563] eff_mlp_bias:[v0:2.1213 v1:1.1656 v2:1.1159 v3:0.9336 v4:0.9502 v5:1.0386 v6:0.8452 v7:0.6878 v8:0.8231 v9:0.8065 v10:0.7237 v11:0.6519 v12:0.6298 v13:0.7016] depth_encoding:sinusoidal +step:7200/20000 val_loss:2.1875 val_bpb:1.2955 train_time:423951ms step_avg:58.88ms +step:7400/20000 train_loss:2.1726 train_time:435685ms step_avg:58.88ms +step:7400 shared0_alpha:mean=0.473,std=0.067 shared1_alpha:mean=0.618,std=0.067 shared2_alpha:mean=0.570,std=0.069 shared3_alpha:mean=0.590,std=0.071 eff_mlp_scale:[v0:203.0925 v1:123.9202 v2:129.5425 v3:125.1102 v4:107.3104 v5:133.6732 v6:125.1697 v7:117.5604 v8:105.7323 v9:115.8883 v10:100.5730 v11:100.8431 v12:102.0501 v13:263.7550] eff_attn_scale:[v0:0.2876 v1:0.8517 v2:0.9711 v3:0.8805 v4:0.9727 v5:0.7306 v6:0.8337 v7:0.8458 v8:0.8768 v9:0.6262 v10:0.6563 v11:0.6246 v12:0.6499 v13:1.6352] eff_attn_bias:[v0:0.5110 v1:0.9888 v2:0.8894 v3:0.8452 v4:0.9833 v5:0.8452 v6:0.9281 v7:0.7623 v8:0.8397 v9:0.7043 v10:0.6353 v11:0.4723 v12:0.5690 v13:0.8673] eff_mlp_bias:[v0:2.1434 v1:1.1877 v2:1.1270 v3:0.9502 v4:0.9667 v5:1.0551 v6:0.8618 v7:0.6988 v8:0.8397 v9:0.8231 v10:0.7403 v11:0.6629 v12:0.6381 v13:0.7126] depth_encoding:sinusoidal +step:7400/20000 val_loss:2.1845 val_bpb:1.2938 train_time:435709ms step_avg:58.88ms +step:7600/20000 train_loss:2.0546 train_time:447444ms step_avg:58.87ms 
+step:7600 shared0_alpha:mean=0.472,std=0.067 shared1_alpha:mean=0.621,std=0.068 shared2_alpha:mean=0.570,std=0.069 shared3_alpha:mean=0.589,std=0.071 eff_mlp_scale:[v0:207.0811 v1:126.0600 v2:131.8895 v3:127.7551 v4:109.1114 v5:135.8904 v6:126.9230 v7:119.0569 v8:107.5224 v9:117.3862 v10:102.0902 v11:102.2041 v12:103.8148 v13:267.7385] eff_attn_scale:[v0:0.2909 v1:0.8606 v2:0.9846 v3:0.8806 v4:0.9690 v5:0.7269 v6:0.8452 v7:0.8416 v8:0.8817 v9:0.6267 v10:0.6609 v11:0.6203 v12:0.6547 v13:1.6012] eff_attn_bias:[v0:0.5110 v1:1.0054 v2:0.9115 v3:0.8507 v4:0.9944 v5:0.8563 v6:0.9391 v7:0.7734 v8:0.8507 v9:0.7182 v10:0.6463 v11:0.4778 v12:0.5800 v13:0.8839] eff_mlp_bias:[v0:2.1766 v1:1.2043 v2:1.1490 v3:0.9667 v4:0.9778 v5:1.0717 v6:0.8784 v7:0.7071 v8:0.8507 v9:0.8342 v10:0.7513 v11:0.6767 v12:0.6491 v13:0.7182] depth_encoding:sinusoidal +step:7600/20000 val_loss:2.1839 val_bpb:1.2934 train_time:447469ms step_avg:58.88ms +step:7800/20000 train_loss:2.2072 train_time:459195ms step_avg:58.87ms +step:7800 shared0_alpha:mean=0.472,std=0.067 shared1_alpha:mean=0.623,std=0.068 shared2_alpha:mean=0.569,std=0.069 shared3_alpha:mean=0.588,std=0.071 eff_mlp_scale:[v0:210.8913 v1:129.2975 v2:134.7026 v3:130.3263 v4:110.8714 v5:138.0338 v6:129.1364 v7:120.4697 v8:109.2722 v9:119.3963 v10:104.0884 v11:104.0420 v12:105.0080 v13:271.7089] eff_attn_scale:[v0:0.2896 v1:0.8537 v2:0.9893 v3:0.8843 v4:0.9900 v5:0.7335 v6:0.8461 v7:0.8539 v8:0.8976 v9:0.6216 v10:0.6581 v11:0.6198 v12:0.6600 v13:1.6186] eff_attn_bias:[v0:0.5138 v1:1.0275 v2:0.9226 v3:0.8673 v4:1.0109 v5:0.8728 v6:0.9502 v7:0.7844 v8:0.8618 v9:0.7292 v10:0.6546 v11:0.4861 v12:0.5883 v13:0.8949] eff_mlp_bias:[v0:2.1987 v1:1.2264 v2:1.1656 v3:0.9778 v4:0.9944 v5:1.0883 v6:0.8894 v7:0.7182 v8:0.8673 v9:0.8452 v10:0.7679 v11:0.6822 v12:0.6574 v13:0.7292] depth_encoding:sinusoidal +step:7800/20000 val_loss:2.1800 val_bpb:1.2911 train_time:459220ms step_avg:58.87ms +step:8000/20000 train_loss:2.1668 train_time:470948ms 
step_avg:58.87ms +step:8000 shared0_alpha:mean=0.471,std=0.067 shared1_alpha:mean=0.625,std=0.068 shared2_alpha:mean=0.569,std=0.070 shared3_alpha:mean=0.587,std=0.072 eff_mlp_scale:[v0:216.5491 v1:131.1973 v2:137.8954 v3:132.7072 v4:113.2817 v5:139.9829 v6:131.7042 v7:123.3071 v8:110.5973 v9:120.6547 v10:105.8136 v11:105.6128 v12:107.3760 v13:275.4622] eff_attn_scale:[v0:0.2796 v1:0.8794 v2:0.9894 v3:0.8871 v4:0.9918 v5:0.7448 v6:0.8417 v7:0.8481 v8:0.8949 v9:0.6354 v10:0.6581 v11:0.6231 v12:0.6656 v13:1.6173] eff_attn_bias:[v0:0.5110 v1:1.0441 v2:0.9447 v3:0.8784 v4:1.0275 v5:0.8839 v6:0.9612 v7:0.7955 v8:0.8784 v9:0.7403 v10:0.6657 v11:0.4944 v12:0.5994 v13:0.9005] eff_mlp_bias:[v0:2.2208 v1:1.2540 v2:1.1877 v3:0.9944 v4:1.0109 v5:1.0993 v6:0.9005 v7:0.7292 v8:0.8784 v9:0.8563 v10:0.7789 v11:0.6961 v12:0.6740 v13:0.7403] depth_encoding:sinusoidal +step:8000/20000 val_loss:2.1791 val_bpb:1.2906 train_time:470973ms step_avg:58.87ms +step:8200/20000 train_loss:2.2366 train_time:482692ms step_avg:58.86ms +step:8200 shared0_alpha:mean=0.470,std=0.068 shared1_alpha:mean=0.627,std=0.069 shared2_alpha:mean=0.569,std=0.069 shared3_alpha:mean=0.586,std=0.072 eff_mlp_scale:[v0:220.5306 v1:133.9416 v2:140.5703 v3:135.3293 v4:115.5924 v5:141.6122 v6:133.7686 v7:125.3049 v8:112.3515 v9:122.1406 v10:107.1282 v11:107.4837 v12:108.5705 v13:279.6792] eff_attn_scale:[v0:0.2763 v1:0.8807 v2:0.9846 v3:0.8918 v4:0.9995 v5:0.7495 v6:0.8503 v7:0.8570 v8:0.9070 v9:0.6309 v10:0.6668 v11:0.6308 v12:0.6693 v13:1.6274] eff_attn_bias:[v0:0.5110 v1:1.0662 v2:0.9612 v3:0.8894 v4:1.0386 v5:0.9005 v6:0.9778 v7:0.8065 v8:0.8949 v9:0.7513 v10:0.6767 v11:0.4999 v12:0.6077 v13:0.9170] eff_mlp_bias:[v0:2.2429 v1:1.2761 v2:1.2043 v3:1.0109 v4:1.0275 v5:1.1159 v6:0.9170 v7:0.7403 v8:0.8949 v9:0.8728 v10:0.7900 v11:0.7071 v12:0.6822 v13:0.7458] depth_encoding:sinusoidal +step:8200/20000 val_loss:2.1763 val_bpb:1.2889 train_time:482722ms step_avg:58.87ms +step:8400/20000 train_loss:2.1837 
train_time:494511ms step_avg:58.87ms +step:8400 shared0_alpha:mean=0.470,std=0.068 shared1_alpha:mean=0.630,std=0.069 shared2_alpha:mean=0.569,std=0.070 shared3_alpha:mean=0.585,std=0.072 eff_mlp_scale:[v0:226.2229 v1:136.6276 v2:144.1857 v3:138.1543 v4:116.8718 v5:144.3500 v6:136.1753 v7:126.9222 v8:113.6103 v9:124.1529 v10:108.7114 v11:108.9510 v12:110.3487 v13:281.5219] eff_attn_scale:[v0:0.2767 v1:0.8880 v2:1.0091 v3:0.9013 v4:1.0253 v5:0.7442 v6:0.8560 v7:0.8532 v8:0.9144 v9:0.6300 v10:0.6667 v11:0.6344 v12:0.6747 v13:1.6118] eff_attn_bias:[v0:0.5138 v1:1.0883 v2:0.9778 v3:0.9005 v4:1.0607 v5:0.9115 v6:0.9944 v7:0.8231 v8:0.9060 v9:0.7623 v10:0.6878 v11:0.5082 v12:0.6160 v13:0.9226] eff_mlp_bias:[v0:2.2650 v1:1.2982 v2:1.2209 v3:1.0275 v4:1.0441 v5:1.1270 v6:0.9336 v7:0.7513 v8:0.9060 v9:0.8839 v10:0.8065 v11:0.7182 v12:0.6905 v13:0.7568] depth_encoding:sinusoidal +step:8400/20000 val_loss:2.1760 val_bpb:1.2887 train_time:494542ms step_avg:58.87ms +step:8600/20000 train_loss:2.1928 train_time:506266ms step_avg:58.87ms +step:8600 shared0_alpha:mean=0.469,std=0.068 shared1_alpha:mean=0.633,std=0.069 shared2_alpha:mean=0.568,std=0.070 shared3_alpha:mean=0.584,std=0.072 eff_mlp_scale:[v0:230.2102 v1:138.7728 v2:147.1916 v3:140.4101 v4:119.1224 v5:146.5489 v6:138.5333 v7:129.6528 v8:115.8287 v9:125.6133 v10:110.8266 v11:110.4031 v12:112.5350 v13:285.2443] eff_attn_scale:[v0:0.2729 v1:0.8861 v2:1.0086 v3:0.9193 v4:1.0285 v5:0.7559 v6:0.8555 v7:0.8580 v8:0.9172 v9:0.6257 v10:0.6664 v11:0.6304 v12:0.6723 v13:1.5990] eff_attn_bias:[v0:0.5110 v1:1.1049 v2:0.9944 v3:0.9170 v4:1.0717 v5:0.9226 v6:1.0109 v7:0.8286 v8:0.9170 v9:0.7679 v10:0.6988 v11:0.5138 v12:0.6242 v13:0.9391] eff_mlp_bias:[v0:2.2870 v1:1.3148 v2:1.2374 v3:1.0441 v4:1.0551 v5:1.1435 v6:0.9502 v7:0.7623 v8:0.9170 v9:0.8949 v10:0.8176 v11:0.7292 v12:0.7016 v13:0.7623] depth_encoding:sinusoidal +step:8600/20000 val_loss:2.1735 val_bpb:1.2873 train_time:506297ms step_avg:58.87ms +step:8800/20000 
train_loss:2.1619 train_time:518030ms step_avg:58.87ms +step:8800 shared0_alpha:mean=0.468,std=0.069 shared1_alpha:mean=0.635,std=0.070 shared2_alpha:mean=0.568,std=0.070 shared3_alpha:mean=0.582,std=0.072 eff_mlp_scale:[v0:235.7040 v1:141.3207 v2:150.1160 v3:143.1708 v4:121.6550 v5:148.5371 v6:140.8064 v7:131.1924 v8:117.2312 v9:126.8880 v10:112.2961 v11:112.3691 v12:113.9134 v13:287.0195] eff_attn_scale:[v0:0.2688 v1:0.8927 v2:1.0101 v3:0.9047 v4:1.0514 v5:0.7645 v6:0.8530 v7:0.8525 v8:0.9127 v9:0.6364 v10:0.6555 v11:0.6263 v12:0.6756 v13:1.5983] eff_attn_bias:[v0:0.5138 v1:1.1270 v2:1.0165 v3:0.9281 v4:1.0883 v5:0.9391 v6:1.0165 v7:0.8397 v8:0.9336 v9:0.7844 v10:0.7071 v11:0.5193 v12:0.6325 v13:0.9447] eff_mlp_bias:[v0:2.3091 v1:1.3369 v2:1.2540 v3:1.0551 v4:1.0717 v5:1.1601 v6:0.9667 v7:0.7734 v8:0.9336 v9:0.9060 v10:0.8286 v11:0.7403 v12:0.7126 v13:0.7679] depth_encoding:sinusoidal +step:8800/20000 val_loss:2.1720 val_bpb:1.2864 train_time:518054ms step_avg:58.87ms +step:9000/20000 train_loss:2.0805 train_time:529792ms step_avg:58.87ms +step:9000 shared0_alpha:mean=0.467,std=0.069 shared1_alpha:mean=0.638,std=0.070 shared2_alpha:mean=0.568,std=0.070 shared3_alpha:mean=0.581,std=0.072 eff_mlp_scale:[v0:240.5350 v1:144.2761 v2:152.5839 v3:146.1572 v4:122.9263 v5:150.9443 v6:143.1941 v7:133.4979 v8:118.4765 v9:128.5148 v10:113.8511 v11:113.9336 v12:115.6953 v13:291.2618] eff_attn_scale:[v0:0.2667 v1:0.9120 v2:1.0282 v3:0.9273 v4:1.0704 v5:0.7737 v6:0.8606 v7:0.8831 v8:0.9310 v9:0.6440 v10:0.6704 v11:0.6403 v12:0.6881 v13:1.6274] eff_attn_bias:[v0:0.5138 v1:1.1435 v2:1.0330 v3:0.9502 v4:1.1104 v5:0.9502 v6:1.0330 v7:0.8507 v8:0.9502 v9:0.7955 v10:0.7182 v11:0.5276 v12:0.6408 v13:0.9612] eff_mlp_bias:[v0:2.3312 v1:1.3534 v2:1.2706 v3:1.0717 v4:1.0883 v5:1.1711 v6:0.9778 v7:0.7844 v8:0.9447 v9:0.9170 v10:0.8397 v11:0.7513 v12:0.7182 v13:0.7734] depth_encoding:sinusoidal +step:9000/20000 val_loss:2.1731 val_bpb:1.2870 train_time:529816ms step_avg:58.87ms 
+step:9200/20000 train_loss:2.1320 train_time:541541ms step_avg:58.86ms +step:9200 shared0_alpha:mean=0.466,std=0.069 shared1_alpha:mean=0.640,std=0.071 shared2_alpha:mean=0.567,std=0.070 shared3_alpha:mean=0.580,std=0.072 eff_mlp_scale:[v0:243.9950 v1:146.3367 v2:155.0233 v3:148.2915 v4:124.9611 v5:152.4341 v6:145.5562 v7:134.9684 v8:120.4782 v9:129.8738 v10:115.3799 v11:115.8527 v12:117.1160 v13:295.3407] eff_attn_scale:[v0:0.2652 v1:0.9201 v2:1.0197 v3:0.9299 v4:1.0786 v5:0.7703 v6:0.8663 v7:0.8810 v8:0.9472 v9:0.6376 v10:0.6587 v11:0.6407 v12:0.6934 v13:1.6424] eff_attn_bias:[v0:0.5193 v1:1.1601 v2:1.0496 v3:0.9612 v4:1.1214 v5:0.9612 v6:1.0496 v7:0.8618 v8:0.9612 v9:0.8121 v10:0.7347 v11:0.5331 v12:0.6463 v13:0.9667] eff_mlp_bias:[v0:2.3533 v1:1.3700 v2:1.2872 v3:1.0828 v4:1.0993 v5:1.1822 v6:0.9944 v7:0.7955 v8:0.9612 v9:0.9281 v10:0.8507 v11:0.7623 v12:0.7292 v13:0.7844] depth_encoding:sinusoidal +step:9200/20000 val_loss:2.1636 val_bpb:1.2814 train_time:541571ms step_avg:58.87ms +step:9400/20000 train_loss:2.1763 train_time:553298ms step_avg:58.86ms +step:9400 shared0_alpha:mean=0.466,std=0.069 shared1_alpha:mean=0.641,std=0.071 shared2_alpha:mean=0.566,std=0.070 shared3_alpha:mean=0.579,std=0.072 eff_mlp_scale:[v0:247.1050 v1:147.7496 v2:157.3117 v3:149.2161 v4:126.6447 v5:153.8803 v6:147.1818 v7:136.3928 v8:122.1217 v9:131.1967 v10:116.7920 v11:116.5751 v12:118.7294 v13:297.4357] eff_attn_scale:[v0:0.2637 v1:0.9329 v2:1.0400 v3:0.9449 v4:1.0872 v5:0.7781 v6:0.8842 v7:0.8784 v8:0.9679 v9:0.6449 v10:0.6780 v11:0.6433 v12:0.7019 v13:1.6673] eff_attn_bias:[v0:0.5110 v1:1.1711 v2:1.0607 v3:0.9667 v4:1.1325 v5:0.9667 v6:1.0551 v7:0.8673 v8:0.9667 v9:0.8176 v10:0.7403 v11:0.5359 v12:0.6519 v13:0.9778] eff_mlp_bias:[v0:2.3533 v1:1.3811 v2:1.2927 v3:1.0883 v4:1.1104 v5:1.1932 v6:0.9999 v7:0.8010 v8:0.9667 v9:0.9391 v10:0.8618 v11:0.7679 v12:0.7347 v13:0.7955] depth_encoding:sinusoidal +step:9400/20000 val_loss:2.1536 val_bpb:1.2755 train_time:553329ms 
step_avg:58.86ms +step:9600/20000 train_loss:2.1820 train_time:565046ms step_avg:58.86ms +step:9600 shared0_alpha:mean=0.465,std=0.069 shared1_alpha:mean=0.642,std=0.071 shared2_alpha:mean=0.565,std=0.070 shared3_alpha:mean=0.578,std=0.072 eff_mlp_scale:[v0:247.7165 v1:148.3111 v2:158.0710 v3:151.3469 v4:128.3050 v5:155.0805 v6:147.8922 v7:137.8547 v8:123.1728 v9:131.6953 v10:117.3558 v11:117.9098 v12:119.7514 v13:301.5698] eff_attn_scale:[v0:0.2630 v1:0.9361 v2:1.0344 v3:0.9398 v4:1.0961 v5:0.7845 v6:0.8704 v7:0.8819 v8:0.9677 v9:0.6501 v10:0.6699 v11:0.6503 v12:0.7063 v13:1.6560] eff_attn_bias:[v0:0.5110 v1:1.1767 v2:1.0662 v3:0.9723 v4:1.1380 v5:0.9778 v6:1.0607 v7:0.8673 v8:0.9723 v9:0.8176 v10:0.7458 v11:0.5414 v12:0.6574 v13:0.9778] eff_mlp_bias:[v0:2.3533 v1:1.3811 v2:1.2982 v3:1.0938 v4:1.1159 v5:1.1988 v6:1.0054 v7:0.8065 v8:0.9723 v9:0.9447 v10:0.8673 v11:0.7679 v12:0.7403 v13:0.8010] depth_encoding:sinusoidal +step:9600/20000 val_loss:2.1442 val_bpb:1.2699 train_time:565077ms step_avg:58.86ms +step:9800/20000 train_loss:2.0977 train_time:576802ms step_avg:58.86ms +step:9800 shared0_alpha:mean=0.464,std=0.069 shared1_alpha:mean=0.642,std=0.071 shared2_alpha:mean=0.565,std=0.070 shared3_alpha:mean=0.578,std=0.072 eff_mlp_scale:[v0:247.9621 v1:148.8215 v2:158.8952 v3:152.0676 v4:129.1139 v5:155.6142 v6:148.6633 v7:138.5112 v8:123.9493 v9:132.1486 v10:117.9676 v11:118.4713 v12:120.5063 v13:303.2700] eff_attn_scale:[v0:0.2721 v1:0.9413 v2:1.0305 v3:0.9389 v4:1.1016 v5:0.7801 v6:0.8794 v7:0.8853 v8:0.9639 v9:0.6450 v10:0.6687 v11:0.6438 v12:0.7022 v13:1.6871] eff_attn_bias:[v0:0.5138 v1:1.1767 v2:1.0662 v3:0.9778 v4:1.1380 v5:0.9778 v6:1.0607 v7:0.8728 v8:0.9778 v9:0.8231 v10:0.7513 v11:0.5441 v12:0.6602 v13:0.9778] eff_mlp_bias:[v0:2.3533 v1:1.3811 v2:1.3037 v3:1.0938 v4:1.1214 v5:1.2043 v6:1.0109 v7:0.8121 v8:0.9778 v9:0.9447 v10:0.8673 v11:0.7734 v12:0.7458 v13:0.8065] depth_encoding:sinusoidal +step:9800/20000 val_loss:2.1359 val_bpb:1.2650 
train_time:576833ms step_avg:58.86ms +step:10000/20000 train_loss:2.1303 train_time:588562ms step_avg:58.86ms +step:10000 shared0_alpha:mean=0.464,std=0.069 shared1_alpha:mean=0.642,std=0.071 shared2_alpha:mean=0.564,std=0.070 shared3_alpha:mean=0.577,std=0.072 eff_mlp_scale:[v0:247.8048 v1:149.0514 v2:160.5002 v3:152.4073 v4:129.8024 v5:155.8546 v6:149.0359 v7:138.8206 v8:124.6103 v9:132.3527 v10:118.2633 v11:118.7360 v12:121.1489 v13:304.5426] eff_attn_scale:[v0:0.2631 v1:0.9419 v2:1.0281 v3:0.9489 v4:1.1151 v5:0.7886 v6:0.8681 v7:0.8947 v8:0.9850 v9:0.6484 v10:0.6625 v11:0.6507 v12:0.7109 v13:1.6958] eff_attn_bias:[v0:0.5110 v1:1.1822 v2:1.0717 v3:0.9778 v4:1.1435 v5:0.9778 v6:1.0607 v7:0.8728 v8:0.9723 v9:0.8231 v10:0.7513 v11:0.5441 v12:0.6602 v13:0.9833] eff_mlp_bias:[v0:2.3533 v1:1.3866 v2:1.3037 v3:1.0938 v4:1.1214 v5:1.2043 v6:1.0109 v7:0.8121 v8:0.9778 v9:0.9447 v10:0.8673 v11:0.7734 v12:0.7458 v13:0.8065] depth_encoding:sinusoidal +step:10000/20000 val_loss:2.1263 val_bpb:1.2593 train_time:588586ms step_avg:58.86ms +step:10195/20000 val_loss:2.1192 val_bpb:1.2551 train_time:600044ms step_avg:58.86ms +stopping_early: wallclock_cap train_time:600044ms step:10195/20000 +peak memory allocated: 13736 MiB reserved: 14196 MiB +Serialized model: 45208316 bytes +Code size: 64182 bytes +Total submission size: 45272498 bytes +Serialized model int8+zlib: 10728555 bytes (payload:11667648 raw_torch:11699443 payload_ratio:3.87x) +Total submission size int8+zlib: 10792737 bytes +final_int8_zlib_roundtrip val_loss:2.1316 val_bpb:1.2624 eval_time:1871ms +final_int8_zlib_roundtrip_exact val_loss:2.13156221 val_bpb:1.26243121 diff --git a/records/track_non_record_16mb/2026-03-26_RecurrenceFix_3Loop_Birkhoff_OutputLN_TimestepScale/research_notes.md b/records/track_non_record_16mb/2026-03-26_RecurrenceFix_3Loop_Birkhoff_OutputLN_TimestepScale/research_notes.md index 4cd3b47ebb..b6981aff27 100644 --- 
a/records/track_non_record_16mb/2026-03-26_RecurrenceFix_3Loop_Birkhoff_OutputLN_TimestepScale/research_notes.md +++ b/records/track_non_record_16mb/2026-03-26_RecurrenceFix_3Loop_Birkhoff_OutputLN_TimestepScale/research_notes.md @@ -236,3 +236,42 @@ The natural next experiment is replacing learned depth embeddings with **sinusoi The Universal Transformer adds sinusoidal timestep embeddings $T_t$ at each recurrence step, where $T_t$ uses the same sinusoidal formula as positional encodings but indexed by iteration count rather than sequence position. This provides orthogonal identity signals across iterations with bounded magnitude. > Dehghani, M., Gouws, S., Vinyals, O., Uszkoreit, J. & Kaiser, L. (2019). "Universal Transformers." ICLR 2019. [arXiv:1807.03819](https://arxiv.org/abs/1807.03819) + +## 11. Series 5: Sinusoidal Depth Encoding + +**Experiment.** Run T replaces learned depth embeddings (Series 4) with deterministic sinusoidal depth encodings following the Universal Transformer (Dehghani et al., 2019, §2.1). The encoding uses the same sinusoidal formula as positional encodings but indexed by iteration count rather than sequence position: $\text{enc}(v, i) = \sin(v / b^{i/d})$ where $v$ is the virtual layer index, $b = 10000$ is the base frequency, $i$ is the dimension index, and $d$ is the model dimension. These are computed analytically at initialization and stored as a non-persistent buffer — zero learnable parameters, zero artifact cost. + +| Run | Config | Eff. Layers | Pre-Q BPB | Post-Q BPB | Q-Gap | Steps | step_avg | Artifact | +|-----|--------|-------------|-----------|------------|-------|-------|----------|----------| +| T | 1+4×3+1 sinusoidal depth+bias | 14 | 1.2551 | 1.2624 | +0.0073 | 10195 | 58.86ms | 10.73MB | + +### 11a. 
Comparison vs FiLM Bias Alone (s3_O) + +| Metric | s3_O (no depth) | s5_T (sinusoidal) | Delta | +|--------|-----------------|-------------------|-------| +| Post-Q BPB | 1.2625 | 1.2624 | −0.0001 | +| Pre-Q BPB | 1.2547 | 1.2551 | +0.0004 | +| Q-gap | +0.0078 | +0.0073 | −0.0005 | +| step_avg | 58.18ms | 58.86ms | +0.68ms | + +BPB is essentially identical (−0.0001 post-Q), confirming that per-iteration identity signaling provides no meaningful BPB improvement when FiLM gammas and betas are already present — they already provide sufficient per-iteration differentiation. The Q-gap improves marginally (0.0078 → 0.0073), likely because the sinusoidal encoding adds a fixed per-iteration bias to the input that slightly reduces the variance seen by shared weights, making quantization more stable. + +### 11b. Comparison vs Learned Depth Embeddings (s4_R) + +| Metric | s4_R (learned depth) | s5_T (sinusoidal) | Delta | +|--------|---------------------|-------------------|-------| +| Post-Q BPB | 1.2639 | 1.2624 | −0.0015 | +| Pre-Q BPB | 1.2566 | 1.2551 | −0.0015 | +| Q-gap | +0.0073 | +0.0073 | 0.0000 | +| step_avg | 62.56ms | 58.86ms | −3.70ms | +| Steps | 9592 | 10195 | +603 | + +Sinusoidal beats learned by 0.0015 BPB. The mechanism is entirely throughput: sinusoidal encoding has zero backpropagation cost (non-persistent buffer, no gradients), saving 3.70ms per step and enabling 603 additional training steps within the 600s wallclock cap. Q-gap is identical (0.0073), confirming that both provide equivalent per-iteration identity signal for quantization purposes — the difference is purely in training efficiency. + +### 11c. Conclusion + +Sinusoidal depth encoding is free and should be kept on as default. It doesn't help BPB but provides marginal Q-gap benefit and costs nothing. The model already gets per-iteration identity from FiLM gammas/betas — depth encoding is redundant for differentiation but harmless. + +This resolves the depth encoding question from §10f. 
The validated technique stack for SOTA integration is: **Output-LN + Birkhoff mixing + FiLM scale+shift (gammas+betas) + sinusoidal depth encoding**. The best full-sharing configuration is s5_T (1.2624 post-Q BPB, Q-gap +0.0073, 10.73MB artifact). + +> Dehghani, M., Gouws, S., Vinyals, O., Uszkoreit, J. & Kaiser, L. (2019). "Universal Transformers." ICLR 2019. [arXiv:1807.03819](https://arxiv.org/abs/1807.03819) diff --git a/records/track_non_record_16mb/2026-03-26_RecurrenceFix_3Loop_Birkhoff_OutputLN_TimestepScale/scripts/run_fullscale4.sh b/records/track_non_record_16mb/2026-03-26_RecurrenceFix_3Loop_Birkhoff_OutputLN_TimestepScale/scripts/run_fullscale4.sh new file mode 100755 index 0000000000..d844ba5e2d --- /dev/null +++ b/records/track_non_record_16mb/2026-03-26_RecurrenceFix_3Loop_Birkhoff_OutputLN_TimestepScale/scripts/run_fullscale4.sh @@ -0,0 +1,37 @@ +#!/bin/bash +set -uo pipefail + +SCRIPT="train_gpt.py" +NGPU=${NGPU:-8} +COMMON="SEED=1337 MAX_WALLCLOCK_SECONDS=600 VAL_LOSS_EVERY=200 TRAIN_LOG_EVERY=200" +DATA="DATA_PATH=${DATA_PATH:-./data/datasets/fineweb10B_sp1024} TOKENIZER_PATH=./data/tokenizers/fineweb_1024_bpe.model VOCAB_SIZE=1024" + +FAILS=0 +SUMMARY="" + +run_experiment() { + local name="$1"; shift + echo "" + echo "=== $name ===" + if "$@"; then + SUMMARY="${SUMMARY} PASS $name"$'\n' + else + SUMMARY="${SUMMARY} FAIL $name (exit $?)"$'\n' + FAILS=$((FAILS + 1)) + fi +} + +# --- T: 1+4×3+1, sinusoidal depth encoding + FiLM bias, 3 loops (compare vs s3_O and s4_R) --- + +run_experiment "Run T: 1+4x3+1 sinusoidal depth enc + FiLM bias (3 loops)" \ + env $COMMON $DATA RUN_ID=s5_T NUM_LAYERS=14 NUM_PRELUDE=1 NUM_SHARED=4 NUM_LOOPS=3 NUM_CODA=1 \ + USE_PERI_NORM=1 USE_BIRKHOFF_MIX=1 USE_TIMESTEP_SCALE=1 TIMESTEP_GAMMA_MAX=4.0 \ + USE_TIMESTEP_BIAS=1 USE_DEPTH_EMBED=1 USE_UNIQUE_NORMS=0 \ + torchrun --standalone --nproc_per_node=$NGPU $SCRIPT + +echo "" +echo "===============================" +echo " FULL-SCALE 4 SUMMARY" +echo 
"===============================" +echo "$SUMMARY" +echo "$FAILS run(s) failed." diff --git a/records/track_non_record_16mb/2026-03-26_RecurrenceFix_3Loop_Birkhoff_OutputLN_TimestepScale/train_gpt.py b/records/track_non_record_16mb/2026-03-26_RecurrenceFix_3Loop_Birkhoff_OutputLN_TimestepScale/train_gpt.py index 1597f32bd0..70fb524932 100644 --- a/records/track_non_record_16mb/2026-03-26_RecurrenceFix_3Loop_Birkhoff_OutputLN_TimestepScale/train_gpt.py +++ b/records/track_non_record_16mb/2026-03-26_RecurrenceFix_3Loop_Birkhoff_OutputLN_TimestepScale/train_gpt.py @@ -92,6 +92,7 @@ class Hyperparameters: timestep_gamma_max = float(os.environ.get("TIMESTEP_GAMMA_MAX", "0.0")) use_timestep_bias = bool(int(os.environ.get("USE_TIMESTEP_BIAS", "0"))) use_depth_embed = bool(int(os.environ.get("USE_DEPTH_EMBED", "0"))) + depth_enc_base = float(os.environ.get("DEPTH_ENC_BASE", 10000.0)) use_unique_norms = bool(int(os.environ.get("USE_UNIQUE_NORMS", "0"))) disable_compile = bool(int(os.environ.get("DISABLE_COMPILE", "0"))) @@ -314,7 +315,7 @@ def eval_val( pattern for pattern in os.environ.get( "CONTROL_TENSOR_NAME_PATTERNS", - "attn_scale,attn_scales,mlp_scale,mlp_scales,resid_mix,resid_mixes,q_gain,skip_weight,skip_weights,ts_attn,ts_mlp,resid_mix_logit,ts_attn_beta,ts_mlp_beta,depth_embed,unique_attn_gain,unique_mlp_gain", + "attn_scale,attn_scales,mlp_scale,mlp_scales,resid_mix,resid_mixes,q_gain,skip_weight,skip_weights,ts_attn,ts_mlp,resid_mix_logit,ts_attn_beta,ts_mlp_beta,unique_attn_gain,unique_mlp_gain", ).split(",") if pattern ) @@ -571,6 +572,19 @@ def forward(self, seq_len: int, device: torch.device, dtype: torch.dtype) -> tup return self._cos_cached.to(dtype=dtype), self._sin_cached.to(dtype=dtype) +class DepthEncoding(nn.Module): + """Deterministic sinusoidal depth encodings (Universal Transformer style).""" + def __init__(self, depth: int, dim: int, base: float = 10000.0): + super().__init__() + t = torch.arange(depth, dtype=torch.float32) + inv_freq = 1.0 / 
(base ** (torch.arange(0, dim, dtype=torch.float32) / dim)) + self.register_buffer('encodings', torch.sin(torch.outer(t, inv_freq)), persistent=False) + self.encodings: Tensor + + def get(self, v: int, dtype: torch.dtype) -> Tensor: + return self.encodings[v].to(dtype=dtype) + + def apply_rotary_emb(x: Tensor, cos: Tensor, sin: Tensor) -> Tensor: half = x.size(-1) // 2 x1, x2 = x[..., :half], x[..., half:] @@ -774,6 +788,7 @@ def __init__( use_timestep_scale: bool = False, use_timestep_bias: bool = False, use_depth_embed: bool = False, + depth_enc_base: float = 10000.0, use_unique_norms: bool = False, timestep_gamma_max: float = 0.0, leaky_relu_slope: float = 0.5, @@ -806,7 +821,7 @@ def __init__( self.timestep_scale = TimestepScaling(effective_layers, model_dim, gamma_max=timestep_gamma_max, use_bias=use_timestep_bias) if use_timestep_scale else None self.use_depth_embed = use_depth_embed if self.use_depth_embed: - self.depth_embeddings = nn.Parameter(torch.zeros(effective_layers, model_dim, dtype=torch.float32)) + self.depth_enc = DepthEncoding(effective_layers, model_dim, base=depth_enc_base) self.use_unique_norms = use_unique_norms if self.use_unique_norms: num_unique = num_shared * self.num_loops @@ -850,14 +865,14 @@ def forward(self, input_ids: Tensor, target_ids: Tensor) -> Tensor: v = 0 for block in self.prelude_blocks: ag, mg, ab, mb = self._get_ts(v) - de = self.depth_embeddings[v] if self.use_depth_embed else None + de = self.depth_enc.get(v, x.dtype) if self.use_depth_embed else None x = block(x, x0, ag, mg, ab, mb, de) v += 1 uid = 0 for _loop in range(self.num_loops): for block in self.shared_blocks: ag, mg, ab, mb = self._get_ts(v) - de = self.depth_embeddings[v] if self.use_depth_embed else None + de = self.depth_enc.get(v, x.dtype) if self.use_depth_embed else None if self.use_unique_norms: ag_n = self.unique_attn_gains[uid].to(dtype=x.dtype) mg_n = self.unique_mlp_gains[uid].to(dtype=x.dtype) @@ -868,7 +883,7 @@ def forward(self, input_ids: 
Tensor, target_ids: Tensor) -> Tensor: v += 1 for block in self.coda_blocks: ag, mg, ab, mb = self._get_ts(v) - de = self.depth_embeddings[v] if self.use_depth_embed else None + de = self.depth_enc.get(v, x.dtype) if self.use_depth_embed else None x = block(x, x0, ag, mg, ab, mb, de) v += 1 else: @@ -983,11 +998,7 @@ def recurrence_param_diagnostics(gpt: GPT) -> str: parts.append("unique_attn_gain_rms:[" + " ".join(un_attn) + "]") parts.append("unique_mlp_gain_rms:[" + " ".join(un_mlp) + "]") if gpt.use_depth_embed: - de_norms: list[str] = [] - for vi in range(effective_count): - de_rms = gpt.depth_embeddings[vi].norm().item() / gpt.depth_embeddings[vi].numel() ** 0.5 - de_norms.append(f"v{vi}:{de_rms:.4f}") - parts.append("depth_emb_rms:[" + " ".join(de_norms) + "]") + parts.append("depth_encoding:sinusoidal") return " ".join(parts) @@ -1111,6 +1122,7 @@ def log0(msg: str, console: bool = True) -> None: use_timestep_scale=args.use_timestep_scale, use_timestep_bias=args.use_timestep_bias, use_depth_embed=args.use_depth_embed, + depth_enc_base=args.depth_enc_base, use_unique_norms=args.use_unique_norms, timestep_gamma_max=args.timestep_gamma_max, leaky_relu_slope=args.leaky_relu_slope, @@ -1155,8 +1167,6 @@ def log0(msg: str, console: bool = True) -> None: if base_model.timestep_scale is not None: for p in base_model.timestep_scale.parameters(): scalar_params.append(p) - if base_model.use_depth_embed: - scalar_params.append(base_model.depth_embeddings) token_lr = args.tied_embed_lr if args.tie_embeddings else args.embed_lr optimizer_tok = torch.optim.Adam( [{"params": [base_model.tok_emb.weight], "lr": token_lr, "base_lr": token_lr}], diff --git a/records/track_non_record_16mb/2026-03-26_RecurrenceFix_3Loop_Birkhoff_OutputLN_TimestepScale/train_log.txt b/records/track_non_record_16mb/2026-03-26_RecurrenceFix_3Loop_Birkhoff_OutputLN_TimestepScale/train_log.txt index d4830cec70..c7c56b7b2d 100644 --- 
a/records/track_non_record_16mb/2026-03-26_RecurrenceFix_3Loop_Birkhoff_OutputLN_TimestepScale/train_log.txt +++ b/records/track_non_record_16mb/2026-03-26_RecurrenceFix_3Loop_Birkhoff_OutputLN_TimestepScale/train_log.txt @@ -90,6 +90,10 @@ class Hyperparameters: use_timestep_scale = bool(int(os.environ.get("USE_TIMESTEP_SCALE", "0"))) leaky_relu_slope = float(os.environ.get("LEAKY_RELU_SLOPE", "0.5")) timestep_gamma_max = float(os.environ.get("TIMESTEP_GAMMA_MAX", "0.0")) + use_timestep_bias = bool(int(os.environ.get("USE_TIMESTEP_BIAS", "0"))) + use_depth_embed = bool(int(os.environ.get("USE_DEPTH_EMBED", "0"))) + depth_enc_base = float(os.environ.get("DEPTH_ENC_BASE", 10000.0)) + use_unique_norms = bool(int(os.environ.get("USE_UNIQUE_NORMS", "0"))) disable_compile = bool(int(os.environ.get("DISABLE_COMPILE", "0"))) # Optimizer hyperparameters. @@ -311,7 +315,7 @@ CONTROL_TENSOR_NAME_PATTERNS = tuple( pattern for pattern in os.environ.get( "CONTROL_TENSOR_NAME_PATTERNS", - "attn_scale,attn_scales,mlp_scale,mlp_scales,resid_mix,resid_mixes,q_gain,skip_weight,skip_weights,ts_attn,ts_mlp,resid_mix_logit", + "attn_scale,attn_scales,mlp_scale,mlp_scales,resid_mix,resid_mixes,q_gain,skip_weight,skip_weights,ts_attn,ts_mlp,resid_mix_logit,ts_attn_beta,ts_mlp_beta,unique_attn_gain,unique_mlp_gain", ).split(",") if pattern ) @@ -568,6 +572,19 @@ class Rotary(nn.Module): return self._cos_cached.to(dtype=dtype), self._sin_cached.to(dtype=dtype) +class DepthEncoding(nn.Module): + """Deterministic sinusoidal depth encodings (Universal Transformer style).""" + def __init__(self, depth: int, dim: int, base: float = 10000.0): + super().__init__() + t = torch.arange(depth, dtype=torch.float32) + inv_freq = 1.0 / (base ** (torch.arange(0, dim, dtype=torch.float32) / dim)) + self.register_buffer('encodings', torch.sin(torch.outer(t, inv_freq)), persistent=False) + self.encodings: Tensor + + def get(self, v: int, dtype: torch.dtype) -> Tensor: + return 
self.encodings[v].to(dtype=dtype) + + def apply_rotary_emb(x: Tensor, cos: Tensor, sin: Tensor) -> Tensor: half = x.size(-1) // 2 x1, x2 = x[..., :half], x[..., half:] @@ -649,19 +666,27 @@ class MLP(nn.Module): class TimestepScaling(nn.Module): """Learned per-iteration scale vectors (arXiv:2410.01405, ICML 2025).""" - def __init__(self, effective_layers: int, dim: int, gamma_max: float = 0.0): + def __init__(self, effective_layers: int, dim: int, gamma_max: float = 0.0, use_bias: bool = False): super().__init__() self.attn_gamma = nn.Parameter(torch.ones(effective_layers, dim, dtype=torch.float32)) self.mlp_gamma = nn.Parameter(torch.ones(effective_layers, dim, dtype=torch.float32)) self.gamma_max = gamma_max # 0 = uncapped + if use_bias: + self.attn_beta = nn.Parameter(torch.zeros(effective_layers, dim, dtype=torch.float32)) + self.mlp_beta = nn.Parameter(torch.zeros(effective_layers, dim, dtype=torch.float32)) + else: + self.attn_beta = None + self.mlp_beta = None - def get(self, v: int) -> tuple[Tensor, Tensor]: + def get(self, v: int) -> tuple[Tensor, Tensor, Tensor | None, Tensor | None]: ag = self.attn_gamma[v] mg = self.mlp_gamma[v] if self.gamma_max > 0: ag = ag.clamp(-self.gamma_max, self.gamma_max) mg = mg.clamp(-self.gamma_max, self.gamma_max) - return ag, mg + ab = self.attn_beta[v] if self.attn_beta is not None else None + mb = self.mlp_beta[v] if self.mlp_beta is not None else None + return ag, mg, ab, mb class Block(nn.Module): @@ -690,30 +715,53 @@ class Block(nn.Module): self.attn_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) self.mlp_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) if use_birkhoff_mix: - # Birkhoff B₂ parameterization: sigmoid(logit) ∈ (0,1) → doubly stochastic self.resid_mix_logit = nn.Parameter(torch.zeros(dim, dtype=torch.float32)) else: self.resid_mix = nn.Parameter(torch.stack((torch.ones(dim), torch.zeros(dim))).float()) def forward(self, x: Tensor, x0: Tensor, ts_attn_gamma: Tensor | None = None, - 
ts_mlp_gamma: Tensor | None = None) -> Tensor: + ts_mlp_gamma: Tensor | None = None, + ts_attn_beta: Tensor | None = None, + ts_mlp_beta: Tensor | None = None, + depth_emb: Tensor | None = None, + ext_attn_gain: Tensor | None = None, + ext_mlp_gain: Tensor | None = None) -> Tensor: if self.use_birkhoff_mix: alpha = torch.sigmoid(self.resid_mix_logit.to(dtype=x.dtype))[None, None, :] x = alpha * x + (1 - alpha) * x0 else: mix = self.resid_mix.to(dtype=x.dtype) x = mix[0][None, None, :] * x + mix[1][None, None, :] * x0 - attn_out = self.attn(self.attn_norm(x)) + if depth_emb is not None: + x = x + depth_emb + attn_normed = self.attn_norm(x) + if ext_attn_gain is not None: + attn_normed = attn_normed * ext_attn_gain[None, None, :] + attn_out = self.attn(attn_normed) attn_s = self.attn_scale.to(dtype=x.dtype)[None, None, :] if ts_attn_gamma is not None: attn_s = attn_s * ts_attn_gamma[None, None, :] x = x + attn_s * attn_out - mlp_out = self.mlp_out_norm(self.mlp(x)) if self.use_peri_norm else self.mlp(self.mlp_norm(x)) + if ts_attn_beta is not None: + x = x + ts_attn_beta[None, None, :] + if self.use_peri_norm: + if ext_mlp_gain is not None: + m_input = F.rms_norm(x, (x.size(-1),)) * ext_mlp_gain[None, None, :] + else: + m_input = x + mlp_out = self.mlp_out_norm(self.mlp(m_input)) + else: + mlp_normed = self.mlp_norm(x) + if ext_mlp_gain is not None: + mlp_normed = mlp_normed * ext_mlp_gain[None, None, :] + mlp_out = self.mlp(mlp_normed) mlp_s = self.mlp_scale.to(dtype=x.dtype)[None, None, :] if ts_mlp_gamma is not None: mlp_s = mlp_s * ts_mlp_gamma[None, None, :] x = x + mlp_s * mlp_out + if ts_mlp_beta is not None: + x = x + ts_mlp_beta[None, None, :] return x @@ -738,6 +786,10 @@ class GPT(nn.Module): use_peri_norm: bool = False, use_birkhoff_mix: bool = False, use_timestep_scale: bool = False, + use_timestep_bias: bool = False, + use_depth_embed: bool = False, + depth_enc_base: float = 10000.0, + use_unique_norms: bool = False, timestep_gamma_max: float = 0.0, 
leaky_relu_slope: float = 0.5, ): @@ -763,10 +815,18 @@ class GPT(nn.Module): if self.use_recurrence: self.prelude_blocks = nn.ModuleList([Block(**block_kwargs) for _ in range(self.num_prelude)]) - self.shared_blocks = nn.ModuleList([Block(**block_kwargs) for _ in range(num_shared)]) self.coda_blocks = nn.ModuleList([Block(**block_kwargs) for _ in range(self.num_coda)]) + self.shared_blocks = nn.ModuleList([Block(**block_kwargs) for _ in range(num_shared)]) effective_layers = self.num_prelude + num_shared * self.num_loops + self.num_coda - self.timestep_scale = TimestepScaling(effective_layers, model_dim, gamma_max=timestep_gamma_max) if use_timestep_scale else None + self.timestep_scale = TimestepScaling(effective_layers, model_dim, gamma_max=timestep_gamma_max, use_bias=use_timestep_bias) if use_timestep_scale else None + self.use_depth_embed = use_depth_embed + if self.use_depth_embed: + self.depth_enc = DepthEncoding(effective_layers, model_dim, base=depth_enc_base) + self.use_unique_norms = use_unique_norms + if self.use_unique_norms: + num_unique = num_shared * self.num_loops + self.unique_attn_gains = nn.Parameter(torch.ones(num_unique, model_dim, dtype=torch.float32)) + self.unique_mlp_gains = nn.Parameter(torch.ones(num_unique, model_dim, dtype=torch.float32)) else: # Standard U-Net path self.num_encoder_layers = num_layers // 2 @@ -775,6 +835,8 @@ class GPT(nn.Module): self.skip_weights = nn.Parameter(torch.ones(self.num_skip_weights, model_dim, dtype=torch.float32)) self.blocks = nn.ModuleList([Block(**block_kwargs) for _ in range(num_layers)]) self.timestep_scale = None + self.use_depth_embed = False + self.use_unique_norms = False self.final_norm = RMSNorm() self.lm_head = None if tie_embeddings else CastedLinear(model_dim, vocab_size, bias=False) @@ -789,9 +851,9 @@ class GPT(nn.Module): if isinstance(module, nn.Linear) and getattr(module, "_zero_init", False): nn.init.zeros_(module.weight) - def _get_ts(self, v: int) -> tuple[Tensor | None, Tensor | 
None]: + def _get_ts(self, v: int) -> tuple[Tensor | None, Tensor | None, Tensor | None, Tensor | None]: if self.timestep_scale is None: - return None, None + return None, None, None, None return self.timestep_scale.get(v) def forward(self, input_ids: Tensor, target_ids: Tensor) -> Tensor: @@ -802,17 +864,27 @@ class GPT(nn.Module): if self.use_recurrence: v = 0 for block in self.prelude_blocks: - ag, mg = self._get_ts(v) - x = block(x, x0, ag, mg) + ag, mg, ab, mb = self._get_ts(v) + de = self.depth_enc.get(v, x.dtype) if self.use_depth_embed else None + x = block(x, x0, ag, mg, ab, mb, de) v += 1 + uid = 0 for _loop in range(self.num_loops): for block in self.shared_blocks: - ag, mg = self._get_ts(v) - x = block(x, x0, ag, mg) + ag, mg, ab, mb = self._get_ts(v) + de = self.depth_enc.get(v, x.dtype) if self.use_depth_embed else None + if self.use_unique_norms: + ag_n = self.unique_attn_gains[uid].to(dtype=x.dtype) + mg_n = self.unique_mlp_gains[uid].to(dtype=x.dtype) + x = block(x, x0, ag, mg, ab, mb, de, ag_n, mg_n) + else: + x = block(x, x0, ag, mg, ab, mb, de) + uid += 1 v += 1 for block in self.coda_blocks: - ag, mg = self._get_ts(v) - x = block(x, x0, ag, mg) + ag, mg, ab, mb = self._get_ts(v) + de = self.depth_enc.get(v, x.dtype) if self.use_depth_embed else None + x = block(x, x0, ag, mg, ab, mb, de) v += 1 else: skips: list[Tensor] = [] @@ -853,14 +925,43 @@ def recurrence_param_diagnostics(gpt: GPT) -> str: # mlp_scale * ts_mlp_gamma[v] is the learned multiplier on MLP output. # We report RMS (norm / sqrt(numel)) to make it scale-independent. 
v = 0 - all_blocks = ( - list(gpt.prelude_blocks) - + list(gpt.shared_blocks) * gpt.num_loops - + list(gpt.coda_blocks) - ) + effective_count = gpt.num_prelude + len(gpt.shared_blocks) * gpt.num_loops + gpt.num_coda mlp_norms: list[str] = [] attn_norms: list[str] = [] - for block in all_blocks: + + # Prelude blocks + for block in gpt.prelude_blocks: + ms = block.mlp_scale.norm().item() + asc = block.attn_scale.norm().item() + d = block.mlp_scale.numel() ** 0.5 + if gpt.timestep_scale is not None: + ms_g = gpt.timestep_scale.mlp_gamma[v].norm().item() + as_g = gpt.timestep_scale.attn_gamma[v].norm().item() + mlp_norms.append(f"v{v}:{ms * ms_g / d:.4f}") + attn_norms.append(f"v{v}:{asc * as_g / d:.4f}") + else: + mlp_norms.append(f"v{v}:{ms / d:.4f}") + attn_norms.append(f"v{v}:{asc / d:.4f}") + v += 1 + + # Shared positions + for _loop in range(gpt.num_loops): + for block in gpt.shared_blocks: + ms = block.mlp_scale.norm().item() + asc = block.attn_scale.norm().item() + d = block.mlp_scale.numel() ** 0.5 + if gpt.timestep_scale is not None: + ms_g = gpt.timestep_scale.mlp_gamma[v].norm().item() + as_g = gpt.timestep_scale.attn_gamma[v].norm().item() + mlp_norms.append(f"v{v}:{ms * ms_g / d:.4f}") + attn_norms.append(f"v{v}:{asc * as_g / d:.4f}") + else: + mlp_norms.append(f"v{v}:{ms / d:.4f}") + attn_norms.append(f"v{v}:{asc / d:.4f}") + v += 1 + + # Coda blocks + for block in gpt.coda_blocks: ms = block.mlp_scale.norm().item() asc = block.attn_scale.norm().item() d = block.mlp_scale.numel() ** 0.5 @@ -873,8 +974,31 @@ def recurrence_param_diagnostics(gpt: GPT) -> str: mlp_norms.append(f"v{v}:{ms / d:.4f}") attn_norms.append(f"v{v}:{asc / d:.4f}") v += 1 + parts.append("eff_mlp_scale:[" + " ".join(mlp_norms) + "]") parts.append("eff_attn_scale:[" + " ".join(attn_norms) + "]") + if gpt.timestep_scale is not None and gpt.timestep_scale.attn_beta is not None: + attn_bias_norms: list[str] = [] + mlp_bias_norms: list[str] = [] + for vi in range(effective_count): + ab_rms 
= gpt.timestep_scale.attn_beta[vi].norm().item() / gpt.timestep_scale.attn_beta[vi].numel() ** 0.5 + mb_rms = gpt.timestep_scale.mlp_beta[vi].norm().item() / gpt.timestep_scale.mlp_beta[vi].numel() ** 0.5 + attn_bias_norms.append(f"v{vi}:{ab_rms:.4f}") + mlp_bias_norms.append(f"v{vi}:{mb_rms:.4f}") + parts.append("eff_attn_bias:[" + " ".join(attn_bias_norms) + "]") + parts.append("eff_mlp_bias:[" + " ".join(mlp_bias_norms) + "]") + if gpt.use_unique_norms: + un_attn: list[str] = [] + un_mlp: list[str] = [] + for ui in range(gpt.unique_attn_gains.size(0)): + an_rms = gpt.unique_attn_gains[ui].norm().item() / gpt.unique_attn_gains[ui].numel() ** 0.5 + un_attn.append(f"u{ui}:{an_rms:.4f}") + mn_rms = gpt.unique_mlp_gains[ui].norm().item() / gpt.unique_mlp_gains[ui].numel() ** 0.5 + un_mlp.append(f"u{ui}:{mn_rms:.4f}") + parts.append("unique_attn_gain_rms:[" + " ".join(un_attn) + "]") + parts.append("unique_mlp_gain_rms:[" + " ".join(un_mlp) + "]") + if gpt.use_depth_embed: + parts.append("depth_encoding:sinusoidal") return " ".join(parts) @@ -996,6 +1120,10 @@ def main() -> None: use_peri_norm=args.use_peri_norm, use_birkhoff_mix=args.use_birkhoff_mix, use_timestep_scale=args.use_timestep_scale, + use_timestep_bias=args.use_timestep_bias, + use_depth_embed=args.use_depth_embed, + depth_enc_base=args.depth_enc_base, + use_unique_norms=args.use_unique_norms, timestep_gamma_max=args.timestep_gamma_max, leaky_relu_slope=args.leaky_relu_slope, ).to(device).bfloat16() @@ -1019,6 +1147,9 @@ def main() -> None: block_named_params = [] for bl in all_block_lists: block_named_params.extend(bl.named_parameters()) + if base_model.use_unique_norms: + block_named_params.extend([("unique_attn_gains", base_model.unique_attn_gains)]) + block_named_params.extend([("unique_mlp_gains", base_model.unique_mlp_gains)]) else: block_named_params = list(base_model.blocks.named_parameters()) matrix_params = [ @@ -1085,8 +1216,9 @@ def main() -> None: ) log0(f"seed:{args.seed}") if 
base_model.use_recurrence: - eff = base_model.num_prelude + len(base_model.shared_blocks) * base_model.num_loops + base_model.num_coda - log0(f"recurrence:enabled prelude:{base_model.num_prelude} shared:{len(base_model.shared_blocks)} " + num_shared = len(base_model.shared_blocks) + eff = base_model.num_prelude + num_shared * base_model.num_loops + base_model.num_coda + log0(f"recurrence:enabled prelude:{base_model.num_prelude} shared:{num_shared} " f"loops:{base_model.num_loops} coda:{base_model.num_coda} effective_layers:{eff}") log0(f"peri_norm:{args.use_peri_norm} birkhoff_mix:{args.use_birkhoff_mix}") if base_model.timestep_scale is not None: @@ -1094,6 +1226,8 @@ def main() -> None: log0(f"timestep_scale:enabled params:{ts_params}") else: log0("timestep_scale:disabled") + log0(f"depth_embed:{'enabled' if base_model.use_depth_embed else 'disabled'}") + log0(f"unique_norms:{'enabled' if base_model.use_unique_norms else 'disabled'}") else: log0(f"recurrence:disabled num_layers:{args.num_layers}") compile_mode = "disabled" if args.disable_compile else "fullgraph=True" @@ -1322,7 +1456,7 @@ if __name__ == "__main__": ==================================================================================================== Running Python 3.12.3 (main, Nov 6 2025, 13:44:16) [GCC 13.3.0] Running PyTorch 2.9.1+cu128 -Wed Mar 25 21:36:33 2026 +Thu Apr 2 17:29:56 2026 +-----------------------------------------------------------------------------------------+ | NVIDIA-SMI 580.126.09 Driver Version: 580.126.09 CUDA Version: 13.0 | +-----------------------------------------+------------------------+----------------------+ @@ -1330,36 +1464,36 @@ Wed Mar 25 21:36:33 2026 | Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. | | | | MIG M. 
| |=========================================+========================+======================| -| 0 NVIDIA H100 80GB HBM3 On | 00000000:0A:00.0 Off | 0 | -| N/A 35C P0 121W / 700W | 1521MiB / 81559MiB | 0% Default | +| 0 NVIDIA H100 80GB HBM3 On | 00000000:19:00.0 Off | 0 | +| N/A 33C P0 117W / 700W | 1521MiB / 81559MiB | 0% Default | | | | Disabled | +-----------------------------------------+------------------------+----------------------+ -| 1 NVIDIA H100 80GB HBM3 On | 00000000:18:00.0 Off | 0 | -| N/A 29C P0 114W / 700W | 1521MiB / 81559MiB | 0% Default | +| 1 NVIDIA H100 80GB HBM3 On | 00000000:3B:00.0 Off | 0 | +| N/A 30C P0 113W / 700W | 1521MiB / 81559MiB | 0% Default | | | | Disabled | +-----------------------------------------+------------------------+----------------------+ -| 2 NVIDIA H100 80GB HBM3 On | 00000000:2F:00.0 Off | 0 | -| N/A 29C P0 114W / 700W | 1521MiB / 81559MiB | 0% Default | +| 2 NVIDIA H100 80GB HBM3 On | 00000000:4C:00.0 Off | 0 | +| N/A 29C P0 116W / 700W | 1521MiB / 81559MiB | 0% Default | | | | Disabled | +-----------------------------------------+------------------------+----------------------+ -| 3 NVIDIA H100 80GB HBM3 On | 00000000:38:00.0 Off | 0 | -| N/A 32C P0 118W / 700W | 1521MiB / 81559MiB | 0% Default | +| 3 NVIDIA H100 80GB HBM3 On | 00000000:5D:00.0 Off | 0 | +| N/A 31C P0 113W / 700W | 1521MiB / 81559MiB | 0% Default | | | | Disabled | +-----------------------------------------+------------------------+----------------------+ -| 4 NVIDIA H100 80GB HBM3 On | 00000000:87:00.0 Off | 0 | -| N/A 32C P0 115W / 700W | 1521MiB / 81559MiB | 0% Default | +| 4 NVIDIA H100 80GB HBM3 On | 00000000:9B:00.0 Off | 0 | +| N/A 33C P0 117W / 700W | 1521MiB / 81559MiB | 0% Default | | | | Disabled | +-----------------------------------------+------------------------+----------------------+ -| 5 NVIDIA H100 80GB HBM3 On | 00000000:90:00.0 Off | 0 | -| N/A 30C P0 118W / 700W | 1521MiB / 81559MiB | 0% Default | +| 5 NVIDIA H100 80GB HBM3 On 
| 00000000:BB:00.0 Off | 0 | +| N/A 31C P0 115W / 700W | 1521MiB / 81559MiB | 0% Default | | | | Disabled | +-----------------------------------------+------------------------+----------------------+ -| 6 NVIDIA H100 80GB HBM3 On | 00000000:BE:00.0 Off | 0 | -| N/A 31C P0 121W / 700W | 1521MiB / 81559MiB | 0% Default | +| 6 NVIDIA H100 80GB HBM3 On | 00000000:CB:00.0 Off | 0 | +| N/A 31C P0 116W / 700W | 1521MiB / 81559MiB | 0% Default | | | | Disabled | +-----------------------------------------+------------------------+----------------------+ -| 7 NVIDIA H100 80GB HBM3 On | 00000000:C7:00.0 Off | 0 | -| N/A 32C P0 118W / 700W | 1521MiB / 81559MiB | 0% Default | +| 7 NVIDIA H100 80GB HBM3 On | 00000000:DB:00.0 Off | 0 | +| N/A 29C P0 117W / 700W | 1521MiB / 81559MiB | 0% Default | | | | Disabled | +-----------------------------------------+------------------------+----------------------+ @@ -1368,21 +1502,14 @@ Wed Mar 25 21:36:33 2026 | GPU GI CI PID Type Process name GPU Memory | | ID ID Usage | |=========================================================================================| -| 0 N/A N/A 1046 C /usr/local/bin/python 1512MiB | -| 1 N/A N/A 1047 C /usr/local/bin/python 1512MiB | -| 2 N/A N/A 1048 C /usr/local/bin/python 1512MiB | -| 3 N/A N/A 1049 C /usr/local/bin/python 1512MiB | -| 4 N/A N/A 1050 C /usr/local/bin/python 1512MiB | -| 5 N/A N/A 1051 C /usr/local/bin/python 1512MiB | -| 6 N/A N/A 1052 C /usr/local/bin/python 1512MiB | -| 7 N/A N/A 1053 C /usr/local/bin/python 1512MiB | +| No running processes found | +-----------------------------------------------------------------------------------------+ ==================================================================================================== val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path=./data/tokenizers/fineweb_1024_bpe.model train_loader:dataset:fineweb10B_sp1024 train_shards:80 val_loader:shards pattern=./data/datasets/fineweb10B_sp1024/fineweb_val_*.bin tokens:62021632 
-model_params:11557936 +model_params:11572272 world_size:8 grad_accum_steps:1 sdp_backends:cudnn=False flash=True mem_efficient=False math=False flash_attention_3:True @@ -1392,7 +1519,9 @@ train_batch_tokens:524288 train_seq_len:1024 iterations:20000 warmup_steps:20 ma seed:1337 recurrence:enabled prelude:1 shared:4 loops:3 coda:1 effective_layers:14 peri_norm:True birkhoff_mix:True -timestep_scale:enabled params:14336 +timestep_scale:enabled params:28672 +depth_embed:enabled +unique_norms:disabled compile_mode:fullgraph=True warmup_step:1/20 warmup_step:2/20 @@ -1414,174 +1543,174 @@ warmup_step:17/20 warmup_step:18/20 warmup_step:19/20 warmup_step:20/20 -step:0/20000 val_loss:6.9379 val_bpb:4.1090 train_time:0ms step_avg:0.01ms -step:1/20000 train_loss:6.9377 train_time:34ms step_avg:33.54ms -step:2/20000 train_loss:9.6217 train_time:88ms step_avg:44.00ms -step:3/20000 train_loss:7.3812 train_time:144ms step_avg:48.09ms -step:4/20000 train_loss:9.1484 train_time:201ms step_avg:50.14ms -step:5/20000 train_loss:8.6154 train_time:257ms step_avg:51.49ms -step:6/20000 train_loss:8.3173 train_time:314ms step_avg:52.32ms -step:7/20000 train_loss:6.9135 train_time:370ms step_avg:52.87ms -step:8/20000 train_loss:6.1894 train_time:427ms step_avg:53.36ms -step:9/20000 train_loss:5.6406 train_time:483ms step_avg:53.70ms -step:10/20000 train_loss:5.4038 train_time:540ms step_avg:54.00ms -step:200/20000 train_loss:2.7823 train_time:11733ms step_avg:58.67ms -step:200 shared0_alpha:mean=0.448,std=0.049 shared1_alpha:mean=0.476,std=0.042 shared2_alpha:mean=0.489,std=0.039 shared3_alpha:mean=0.516,std=0.040 eff_mlp_scale:[v0:27.5484 v1:26.7534 v2:27.3004 v3:28.6959 v4:29.6441 v5:31.4062 v6:28.1404 v7:28.1103 v8:29.0391 v9:30.2430 v10:28.0004 v11:30.8921 v12:32.8202 v13:58.7413] eff_attn_scale:[v0:12.9666 v1:13.2604 v2:10.8659 v3:10.9078 v4:9.5361 v5:11.1804 v6:9.5996 v7:9.4641 v8:8.8824 v9:11.2670 v10:10.2124 v11:9.9453 v12:9.4207 v13:16.9741] -step:200/20000 val_loss:2.7597 
val_bpb:1.6344 train_time:11795ms step_avg:58.98ms -step:400/20000 train_loss:2.3575 train_time:23561ms step_avg:58.90ms -step:400 shared0_alpha:mean=0.454,std=0.050 shared1_alpha:mean=0.497,std=0.043 shared2_alpha:mean=0.515,std=0.039 shared3_alpha:mean=0.550,std=0.043 eff_mlp_scale:[v0:34.2854 v1:33.6046 v2:34.5982 v3:37.3141 v4:36.5859 v5:42.7999 v6:38.2653 v7:36.6623 v8:34.8208 v9:38.9546 v10:35.2360 v11:36.8252 v12:35.6232 v13:73.9149] eff_attn_scale:[v0:6.4260 v1:6.6149 v2:5.7415 v3:5.8219 v4:5.4229 v5:6.0857 v6:5.4070 v7:5.2637 v8:5.3181 v9:5.9387 v10:5.1841 v11:4.7586 v12:4.7941 v13:9.9713] -step:400/20000 val_loss:2.5650 val_bpb:1.5192 train_time:23593ms step_avg:58.98ms -step:600/20000 train_loss:2.5758 train_time:35380ms step_avg:58.97ms -step:600 shared0_alpha:mean=0.454,std=0.050 shared1_alpha:mean=0.509,std=0.043 shared2_alpha:mean=0.532,std=0.040 shared3_alpha:mean=0.574,std=0.045 eff_mlp_scale:[v0:40.3946 v1:37.9918 v2:38.8994 v3:42.4052 v4:40.5967 v5:49.2288 v6:44.6498 v7:42.9203 v8:39.4368 v9:44.2346 v10:39.0685 v11:39.8300 v12:37.1170 v13:88.3020] eff_attn_scale:[v0:3.0049 v1:3.5136 v2:3.1467 v3:3.4297 v4:3.4763 v5:3.3945 v6:3.2047 v7:3.1516 v8:3.5922 v9:3.2953 v10:2.8958 v11:2.6696 v12:2.9162 v13:6.4813] -step:600/20000 val_loss:2.4776 val_bpb:1.4674 train_time:35411ms step_avg:59.02ms -step:800/20000 train_loss:2.3367 train_time:47210ms step_avg:59.01ms -step:800 shared0_alpha:mean=0.453,std=0.050 shared1_alpha:mean=0.517,std=0.045 shared2_alpha:mean=0.544,std=0.042 shared3_alpha:mean=0.590,std=0.046 eff_mlp_scale:[v0:47.3770 v1:41.6239 v2:42.4979 v3:45.9607 v4:43.3238 v5:54.2598 v6:49.1711 v7:47.3749 v8:43.1539 v9:47.9419 v10:41.6198 v11:41.5414 v12:38.7366 v13:99.1780] eff_attn_scale:[v0:1.9183 v1:2.3710 v2:2.2032 v3:2.3913 v4:2.6813 v5:2.3556 v6:2.2797 v7:2.2455 v8:2.8579 v9:2.2633 v10:2.0349 v11:1.8226 v12:2.1354 v13:4.8487] -step:800/20000 val_loss:2.4226 val_bpb:1.4348 train_time:47241ms step_avg:59.05ms -step:1000/20000 train_loss:2.4170 
train_time:59041ms step_avg:59.04ms -step:1000 shared0_alpha:mean=0.449,std=0.051 shared1_alpha:mean=0.523,std=0.047 shared2_alpha:mean=0.554,std=0.043 shared3_alpha:mean=0.603,std=0.048 eff_mlp_scale:[v0:54.0667 v1:45.0744 v2:45.5208 v3:49.3566 v4:46.1017 v5:57.9254 v6:52.7463 v7:50.8083 v8:46.4510 v9:51.4040 v10:43.8951 v11:43.5500 v12:40.8629 v13:108.3864] eff_attn_scale:[v0:1.4628 v1:1.8692 v2:1.7604 v3:1.9248 v4:2.2411 v5:1.8034 v6:1.7998 v7:1.8108 v8:2.3981 v9:1.7639 v10:1.5830 v11:1.4246 v12:1.7486 v13:3.9293] -step:1000/20000 val_loss:2.3828 val_bpb:1.4112 train_time:59072ms step_avg:59.07ms -step:1200/20000 train_loss:2.4374 train_time:70896ms step_avg:59.08ms -step:1200 shared0_alpha:mean=0.447,std=0.052 shared1_alpha:mean=0.527,std=0.047 shared2_alpha:mean=0.561,std=0.044 shared3_alpha:mean=0.612,std=0.049 eff_mlp_scale:[v0:59.9127 v1:48.1806 v2:48.0975 v3:52.2504 v4:48.5522 v5:61.3565 v6:56.2371 v7:54.1033 v8:49.6232 v9:53.8836 v10:46.0626 v11:45.0243 v12:42.8401 v13:117.2016] eff_attn_scale:[v0:1.2315 v1:1.5553 v2:1.4982 v3:1.6681 v4:1.9516 v5:1.5137 v6:1.5342 v7:1.5631 v8:2.0957 v9:1.4900 v10:1.3364 v11:1.2073 v12:1.4997 v13:3.4107] -step:1200/20000 val_loss:2.3552 val_bpb:1.3949 train_time:70929ms step_avg:59.11ms -step:1400/20000 train_loss:2.4865 train_time:82745ms step_avg:59.10ms -step:1400 shared0_alpha:mean=0.443,std=0.052 shared1_alpha:mean=0.531,std=0.048 shared2_alpha:mean=0.567,std=0.045 shared3_alpha:mean=0.620,std=0.050 eff_mlp_scale:[v0:65.7036 v1:51.0674 v2:51.0106 v3:55.1941 v4:50.6375 v5:64.3369 v6:58.9456 v7:57.0843 v8:52.4589 v9:56.6969 v10:47.7988 v11:46.6881 v12:44.6265 v13:124.5906] eff_attn_scale:[v0:1.0368 v1:1.3702 v2:1.3014 v3:1.4703 v4:1.7817 v5:1.3202 v6:1.3401 v7:1.3669 v8:1.8930 v9:1.2925 v10:1.1518 v11:1.0510 v12:1.3363 v13:3.0387] -step:1400/20000 val_loss:2.3352 val_bpb:1.3830 train_time:82776ms step_avg:59.13ms -step:1600/20000 train_loss:2.1583 train_time:94592ms step_avg:59.12ms -step:1600 
shared0_alpha:mean=0.440,std=0.054 shared1_alpha:mean=0.534,std=0.049 shared2_alpha:mean=0.571,std=0.046 shared3_alpha:mean=0.626,std=0.050 eff_mlp_scale:[v0:71.6690 v1:53.9213 v2:53.3864 v3:57.5377 v4:52.9029 v5:66.5846 v6:61.4519 v7:59.4556 v8:54.7527 v9:58.4147 v10:49.5456 v11:47.9481 v12:46.0588 v13:131.3244] eff_attn_scale:[v0:0.8817 v1:1.2447 v2:1.1989 v3:1.3474 v4:1.6413 v5:1.1715 v6:1.2308 v7:1.2737 v8:1.7483 v9:1.1349 v10:1.0504 v11:0.9526 v12:1.2250 v13:2.7730] -step:1600/20000 val_loss:2.3234 val_bpb:1.3760 train_time:94624ms step_avg:59.14ms -step:1800/20000 train_loss:2.2560 train_time:106476ms step_avg:59.15ms -step:1800 shared0_alpha:mean=0.437,std=0.055 shared1_alpha:mean=0.536,std=0.049 shared2_alpha:mean=0.575,std=0.047 shared3_alpha:mean=0.631,std=0.051 eff_mlp_scale:[v0:77.0559 v1:56.5370 v2:55.8204 v3:59.9762 v4:55.0234 v5:69.4242 v6:64.0178 v7:61.9235 v8:57.2847 v9:60.6942 v10:51.1362 v11:49.4609 v12:48.0513 v13:137.3129] eff_attn_scale:[v0:0.8088 v1:1.1540 v2:1.0999 v3:1.2589 v4:1.5357 v5:1.0727 v6:1.1151 v7:1.1628 v8:1.6388 v9:1.0523 v10:0.9529 v11:0.8646 v12:1.1288 v13:2.5785] -step:1800/20000 val_loss:2.3070 val_bpb:1.3663 train_time:106508ms step_avg:59.17ms -step:2000/20000 train_loss:2.3081 train_time:118327ms step_avg:59.16ms -step:2000 shared0_alpha:mean=0.434,std=0.055 shared1_alpha:mean=0.538,std=0.050 shared2_alpha:mean=0.579,std=0.048 shared3_alpha:mean=0.636,std=0.052 eff_mlp_scale:[v0:82.7543 v1:59.0391 v2:58.7532 v3:62.3953 v4:57.0309 v5:71.6903 v6:66.2958 v7:63.9749 v8:59.7102 v9:62.8344 v10:52.7985 v11:50.5481 v12:49.3758 v13:142.4716] eff_attn_scale:[v0:0.7283 v1:1.0662 v2:1.0360 v3:1.1851 v4:1.4729 v5:0.9834 v6:1.0360 v7:1.1012 v8:1.5511 v9:0.9640 v10:0.8796 v11:0.8049 v12:1.0601 v13:2.4340] -step:2000/20000 val_loss:2.2943 val_bpb:1.3588 train_time:118359ms step_avg:59.18ms -step:2200/20000 train_loss:2.1339 train_time:130184ms step_avg:59.17ms -step:2200 shared0_alpha:mean=0.431,std=0.056 
shared1_alpha:mean=0.540,std=0.051 shared2_alpha:mean=0.582,std=0.049 shared3_alpha:mean=0.640,std=0.053 eff_mlp_scale:[v0:87.8139 v1:61.0818 v2:60.7661 v3:64.6332 v4:59.1065 v5:73.8962 v6:68.4121 v7:66.2390 v8:61.8285 v9:64.4990 v10:54.3273 v11:52.1883 v12:51.3293 v13:148.2095] eff_attn_scale:[v0:0.6803 v1:1.0172 v2:0.9873 v3:1.1284 v4:1.4244 v5:0.9179 v6:0.9921 v7:1.0319 v8:1.5017 v9:0.8990 v10:0.8380 v11:0.7667 v12:1.0269 v13:2.3139] -step:2200/20000 val_loss:2.2862 val_bpb:1.3540 train_time:130215ms step_avg:59.19ms -step:2400/20000 train_loss:2.2593 train_time:142034ms step_avg:59.18ms -step:2400 shared0_alpha:mean=0.428,std=0.057 shared1_alpha:mean=0.541,std=0.052 shared2_alpha:mean=0.585,std=0.050 shared3_alpha:mean=0.644,std=0.054 eff_mlp_scale:[v0:93.0263 v1:63.7902 v2:62.8558 v3:66.7213 v4:61.5684 v5:76.3747 v6:70.6107 v7:68.3486 v8:64.3311 v9:66.8279 v10:55.9172 v11:53.7025 v12:52.8857 v13:154.3660] eff_attn_scale:[v0:0.6218 v1:0.9561 v2:0.9382 v3:1.0693 v4:1.3684 v5:0.8596 v6:0.9243 v7:0.9802 v8:1.4387 v9:0.8458 v10:0.7764 v11:0.7222 v12:0.9898 v13:2.1780] -step:2400/20000 val_loss:2.2751 val_bpb:1.3475 train_time:142066ms step_avg:59.19ms -step:2600/20000 train_loss:2.4727 train_time:153881ms step_avg:59.19ms -step:2600 shared0_alpha:mean=0.425,std=0.057 shared1_alpha:mean=0.543,std=0.052 shared2_alpha:mean=0.587,std=0.051 shared3_alpha:mean=0.647,std=0.054 eff_mlp_scale:[v0:97.4525 v1:65.8375 v2:64.8462 v3:69.2610 v4:63.6741 v5:78.1271 v6:72.6938 v7:70.4978 v8:66.4774 v9:68.4710 v10:57.4116 v11:54.8316 v12:54.4634 v13:158.4925] eff_attn_scale:[v0:0.5899 v1:0.9353 v2:0.9162 v3:1.0517 v4:1.3262 v5:0.8269 v6:0.8934 v7:0.9590 v8:1.3737 v9:0.8088 v10:0.7475 v11:0.6996 v12:0.9352 v13:2.1238] -step:2600/20000 val_loss:2.2874 val_bpb:1.3547 train_time:153913ms step_avg:59.20ms -step:2800/20000 train_loss:2.2975 train_time:165743ms step_avg:59.19ms -step:2800 shared0_alpha:mean=0.423,std=0.058 shared1_alpha:mean=0.545,std=0.053 
shared2_alpha:mean=0.590,std=0.052 shared3_alpha:mean=0.649,std=0.055 eff_mlp_scale:[v0:102.8439 v1:68.1906 v2:67.2278 v3:71.7107 v4:65.6281 v5:79.7033 v6:74.3264 v7:72.5445 v8:68.4639 v9:69.9617 v10:58.8765 v11:56.2846 v12:56.3106 v13:163.4298] eff_attn_scale:[v0:0.5667 v1:0.9016 v2:0.8768 v3:1.0077 v4:1.2826 v5:0.7999 v6:0.8543 v7:0.9177 v8:1.3507 v9:0.7734 v10:0.7239 v11:0.6613 v12:0.9214 v13:2.0292] -step:2800/20000 val_loss:2.2621 val_bpb:1.3398 train_time:165774ms step_avg:59.21ms -step:3000/20000 train_loss:2.2850 train_time:177586ms step_avg:59.20ms -step:3000 shared0_alpha:mean=0.420,std=0.058 shared1_alpha:mean=0.545,std=0.054 shared2_alpha:mean=0.592,std=0.052 shared3_alpha:mean=0.652,std=0.056 eff_mlp_scale:[v0:107.1784 v1:70.3300 v2:69.3097 v3:73.8286 v4:67.6902 v5:81.9770 v6:76.0716 v7:74.6724 v8:70.1517 v9:71.6739 v10:60.0120 v11:57.7973 v12:57.8444 v13:168.6794] eff_attn_scale:[v0:0.5273 v1:0.8750 v2:0.8527 v3:0.9793 v4:1.2626 v5:0.7700 v6:0.8307 v7:0.8778 v8:1.3090 v9:0.7394 v10:0.6989 v11:0.6352 v12:0.8916 v13:1.9865] -step:3000/20000 val_loss:2.2550 val_bpb:1.3355 train_time:177618ms step_avg:59.21ms -step:3200/20000 train_loss:2.2509 train_time:189434ms step_avg:59.20ms -step:3200 shared0_alpha:mean=0.417,std=0.059 shared1_alpha:mean=0.546,std=0.054 shared2_alpha:mean=0.595,std=0.053 shared3_alpha:mean=0.654,std=0.056 eff_mlp_scale:[v0:111.8766 v1:72.4473 v2:71.3405 v3:75.8253 v4:69.6745 v5:84.2200 v6:78.1756 v7:76.6773 v8:71.7481 v9:72.9001 v10:61.5152 v11:58.7859 v12:59.3062 v13:172.0224] eff_attn_scale:[v0:0.5131 v1:0.8464 v2:0.8337 v3:0.9520 v4:1.2488 v5:0.7401 v6:0.8033 v7:0.8647 v8:1.2951 v9:0.7103 v10:0.6774 v11:0.6201 v12:0.8891 v13:1.9186] -step:3200/20000 val_loss:2.2506 val_bpb:1.3329 train_time:189467ms step_avg:59.21ms -step:3400/20000 train_loss:2.2156 train_time:201279ms step_avg:59.20ms -step:3400 shared0_alpha:mean=0.415,std=0.059 shared1_alpha:mean=0.547,std=0.054 shared2_alpha:mean=0.597,std=0.054 
shared3_alpha:mean=0.657,std=0.057 eff_mlp_scale:[v0:116.7305 v1:74.6774 v2:73.4077 v3:77.6875 v4:71.9578 v5:86.1310 v6:79.8848 v7:78.5507 v8:74.0618 v9:74.6774 v10:63.0442 v11:60.4236 v12:61.0169 v13:176.9400] eff_attn_scale:[v0:0.4849 v1:0.8381 v2:0.8049 v3:0.9454 v4:1.2184 v5:0.7281 v6:0.7836 v7:0.8452 v8:1.2641 v9:0.6942 v10:0.6558 v11:0.6143 v12:0.8631 v13:1.8646] -step:3400/20000 val_loss:2.2473 val_bpb:1.3310 train_time:201311ms step_avg:59.21ms -step:3600/20000 train_loss:2.1850 train_time:213120ms step_avg:59.20ms -step:3600 shared0_alpha:mean=0.413,std=0.060 shared1_alpha:mean=0.548,std=0.055 shared2_alpha:mean=0.599,std=0.054 shared3_alpha:mean=0.659,std=0.057 eff_mlp_scale:[v0:121.4435 v1:77.4289 v2:75.4982 v3:80.2307 v4:73.4860 v5:88.0927 v6:82.0443 v7:80.2307 v8:76.0347 v9:76.5016 v10:64.1516 v11:61.4811 v12:62.4419 v13:180.8942] eff_attn_scale:[v0:0.4713 v1:0.8233 v2:0.7938 v3:0.9230 v4:1.1999 v5:0.7141 v6:0.7726 v7:0.8281 v8:1.2351 v9:0.6847 v10:0.6410 v11:0.5995 v12:0.8535 v13:1.8283] -step:3600/20000 val_loss:2.2399 val_bpb:1.3266 train_time:213152ms step_avg:59.21ms -step:3800/20000 train_loss:2.2832 train_time:224971ms step_avg:59.20ms -step:3800 shared0_alpha:mean=0.411,std=0.061 shared1_alpha:mean=0.549,std=0.055 shared2_alpha:mean=0.600,std=0.055 shared3_alpha:mean=0.661,std=0.057 eff_mlp_scale:[v0:126.1614 v1:79.1049 v2:77.7751 v3:82.4223 v4:75.6845 v5:89.8707 v6:83.9617 v7:81.9816 v8:77.8346 v9:77.7007 v10:65.8437 v11:63.0289 v12:64.0738 v13:184.9535] eff_attn_scale:[v0:0.4496 v1:0.8116 v2:0.7936 v3:0.8995 v4:1.2013 v5:0.7034 v6:0.7516 v7:0.8181 v8:1.2162 v9:0.6660 v10:0.6257 v11:0.5868 v12:0.8374 v13:1.7982] -step:3800/20000 val_loss:2.2354 val_bpb:1.3239 train_time:225002ms step_avg:59.21ms -step:4000/20000 train_loss:2.2207 train_time:236822ms step_avg:59.21ms -step:4000 shared0_alpha:mean=0.409,std=0.061 shared1_alpha:mean=0.550,std=0.055 shared2_alpha:mean=0.602,std=0.056 shared3_alpha:mean=0.663,std=0.057 eff_mlp_scale:[v0:130.4283 
v1:81.3678 v2:79.3461 v3:84.2017 v4:77.4413 v5:91.7753 v6:85.5869 v7:83.7562 v8:79.6166 v9:79.4755 v10:66.8647 v11:64.1537 v12:65.6946 v13:189.9913] eff_attn_scale:[v0:0.4358 v1:0.7930 v2:0.7767 v3:0.8925 v4:1.1616 v5:0.6856 v6:0.7350 v7:0.7910 v8:1.2061 v9:0.6526 v10:0.6139 v11:0.5795 v12:0.8205 v13:1.7570] -step:4000/20000 val_loss:2.2301 val_bpb:1.3208 train_time:236854ms step_avg:59.21ms -step:4200/20000 train_loss:2.2356 train_time:248914ms step_avg:59.27ms -step:4200 shared0_alpha:mean=0.407,std=0.061 shared1_alpha:mean=0.551,std=0.056 shared2_alpha:mean=0.604,std=0.056 shared3_alpha:mean=0.665,std=0.058 eff_mlp_scale:[v0:134.3793 v1:83.5345 v2:81.4617 v3:86.8770 v4:79.6370 v5:93.5586 v6:87.3126 v7:85.5265 v8:81.3969 v9:80.6705 v10:67.9598 v11:65.7204 v12:67.3175 v13:193.9883] eff_attn_scale:[v0:0.4229 v1:0.7838 v2:0.7523 v3:0.8786 v4:1.1727 v5:0.6701 v6:0.7157 v7:0.7908 v8:1.1975 v9:0.6295 v10:0.5896 v11:0.5648 v12:0.8115 v13:1.7432] -step:4200/20000 val_loss:2.2262 val_bpb:1.3185 train_time:248945ms step_avg:59.27ms -step:4400/20000 train_loss:2.1765 train_time:260750ms step_avg:59.26ms -step:4400 shared0_alpha:mean=0.405,std=0.061 shared1_alpha:mean=0.552,std=0.056 shared2_alpha:mean=0.606,std=0.057 shared3_alpha:mean=0.667,std=0.058 eff_mlp_scale:[v0:140.1213 v1:85.2745 v2:83.1819 v3:88.7375 v4:81.4193 v5:95.3918 v6:89.0910 v7:87.3723 v8:83.1990 v9:82.3838 v10:69.5455 v11:66.8944 v12:68.9617 v13:197.1806] eff_attn_scale:[v0:0.4059 v1:0.7894 v2:0.7553 v3:0.8802 v4:1.1561 v5:0.6626 v6:0.7063 v7:0.7801 v8:1.1757 v9:0.6339 v10:0.5838 v11:0.5590 v12:0.8117 v13:1.7199] -step:4400/20000 val_loss:2.2270 val_bpb:1.3190 train_time:260781ms step_avg:59.27ms -step:4600/20000 train_loss:2.0368 train_time:272605ms step_avg:59.26ms -step:4600 shared0_alpha:mean=0.403,std=0.062 shared1_alpha:mean=0.552,std=0.056 shared2_alpha:mean=0.607,std=0.057 shared3_alpha:mean=0.668,std=0.059 eff_mlp_scale:[v0:145.3811 v1:87.5979 v2:85.2514 v3:90.8275 v4:83.6477 v5:96.8443 
v6:90.7515 v7:88.9926 v8:85.4466 v9:83.7046 v10:70.5845 v11:67.8912 v12:70.6059 v13:202.0904] eff_attn_scale:[v0:0.3928 v1:0.7792 v2:0.7549 v3:0.8647 v4:1.1528 v5:0.6507 v6:0.7025 v7:0.7658 v8:1.1725 v9:0.6105 v10:0.5814 v11:0.5476 v12:0.8094 v13:1.6740] -step:4600/20000 val_loss:2.2220 val_bpb:1.3160 train_time:272637ms step_avg:59.27ms -step:4800/20000 train_loss:2.3243 train_time:284445ms step_avg:59.26ms -step:4800 shared0_alpha:mean=0.401,std=0.062 shared1_alpha:mean=0.553,std=0.056 shared2_alpha:mean=0.609,std=0.058 shared3_alpha:mean=0.670,std=0.059 eff_mlp_scale:[v0:149.6388 v1:89.8112 v2:86.9723 v3:93.2148 v4:85.5791 v5:98.6451 v6:92.5237 v7:90.8961 v8:86.9447 v9:85.3943 v10:72.1685 v11:69.5633 v12:72.3780 v13:205.0534] eff_attn_scale:[v0:0.3859 v1:0.7676 v2:0.7423 v3:0.8653 v4:1.1359 v5:0.6436 v6:0.6898 v7:0.7623 v8:1.1603 v9:0.6076 v10:0.5728 v11:0.5398 v12:0.7946 v13:1.6561] -step:4800/20000 val_loss:2.2187 val_bpb:1.3140 train_time:284477ms step_avg:59.27ms -step:5000/20000 train_loss:2.0935 train_time:296302ms step_avg:59.26ms -step:5000 shared0_alpha:mean=0.399,std=0.062 shared1_alpha:mean=0.553,std=0.057 shared2_alpha:mean=0.610,std=0.058 shared3_alpha:mean=0.672,std=0.059 eff_mlp_scale:[v0:153.8149 v1:91.5847 v2:89.0435 v3:95.4824 v4:87.7745 v5:100.4957 v6:93.7055 v7:92.2061 v8:88.6936 v9:86.6342 v10:73.1929 v11:70.6757 v12:73.9879 v13:208.6951] eff_attn_scale:[v0:0.3806 v1:0.7597 v2:0.7427 v3:0.8650 v4:1.1437 v5:0.6259 v6:0.6865 v7:0.7615 v8:1.1633 v9:0.5904 v10:0.5621 v11:0.5422 v12:0.8001 v13:1.6242] -step:5000/20000 val_loss:2.2133 val_bpb:1.3108 train_time:296334ms step_avg:59.27ms -step:5200/20000 train_loss:2.2323 train_time:308141ms step_avg:59.26ms -step:5200 shared0_alpha:mean=0.397,std=0.062 shared1_alpha:mean=0.553,std=0.058 shared2_alpha:mean=0.612,std=0.059 shared3_alpha:mean=0.674,std=0.060 eff_mlp_scale:[v0:159.4933 v1:93.7957 v2:90.9036 v3:97.4001 v4:89.7327 v5:101.7783 v6:95.6136 v7:94.5632 v8:91.1275 v9:87.8087 v10:74.4185 
v11:71.8680 v12:75.7846 v13:212.6315] eff_attn_scale:[v0:0.3687 v1:0.7652 v2:0.7401 v3:0.8644 v4:1.1541 v5:0.6304 v6:0.6878 v7:0.7615 v8:1.1688 v9:0.5947 v10:0.5631 v11:0.5392 v12:0.8054 v13:1.6234] -step:5200/20000 val_loss:2.2144 val_bpb:1.3115 train_time:308173ms step_avg:59.26ms -step:5400/20000 train_loss:2.2443 train_time:320001ms step_avg:59.26ms -step:5400 shared0_alpha:mean=0.395,std=0.062 shared1_alpha:mean=0.554,std=0.058 shared2_alpha:mean=0.613,std=0.059 shared3_alpha:mean=0.675,std=0.060 eff_mlp_scale:[v0:163.9131 v1:95.6766 v2:93.0526 v3:99.6529 v4:91.3985 v5:103.7336 v6:97.3254 v7:95.8384 v8:92.8046 v9:89.6339 v10:75.4865 v11:73.4285 v12:76.8684 v13:215.8871] eff_attn_scale:[v0:0.3644 v1:0.7672 v2:0.7476 v3:0.8665 v4:1.1519 v5:0.6249 v6:0.6761 v7:0.7551 v8:1.1715 v9:0.5893 v10:0.5528 v11:0.5364 v12:0.7990 v13:1.6067] -step:5400/20000 val_loss:2.2088 val_bpb:1.3082 train_time:320032ms step_avg:59.27ms -step:5600/20000 train_loss:2.2457 train_time:331844ms step_avg:59.26ms -step:5600 shared0_alpha:mean=0.393,std=0.063 shared1_alpha:mean=0.555,std=0.058 shared2_alpha:mean=0.615,std=0.059 shared3_alpha:mean=0.677,std=0.061 eff_mlp_scale:[v0:168.9937 v1:97.9664 v2:94.8258 v3:101.4387 v4:93.6504 v5:105.5803 v6:99.1361 v7:98.0734 v8:94.5964 v9:90.8600 v10:76.6269 v11:74.5165 v12:78.5150 v13:219.7498] eff_attn_scale:[v0:0.3544 v1:0.7693 v2:0.7419 v3:0.8706 v4:1.1515 v5:0.6273 v6:0.6704 v7:0.7515 v8:1.1515 v9:0.5839 v10:0.5475 v11:0.5338 v12:0.7904 v13:1.6007] -step:5600/20000 val_loss:2.2100 val_bpb:1.3089 train_time:331876ms step_avg:59.26ms -step:5800/20000 train_loss:2.2100 train_time:343701ms step_avg:59.26ms -step:5800 shared0_alpha:mean=0.392,std=0.063 shared1_alpha:mean=0.556,std=0.058 shared2_alpha:mean=0.617,std=0.060 shared3_alpha:mean=0.678,std=0.061 eff_mlp_scale:[v0:173.9413 v1:100.1436 v2:97.1412 v3:103.6861 v4:95.6212 v5:106.7858 v6:101.0075 v7:99.3255 v8:96.5774 v9:91.9686 v10:78.2929 v11:75.5843 v12:79.8437 v13:222.7330] 
eff_attn_scale:[v0:0.3424 v1:0.7614 v2:0.7388 v3:0.8595 v4:1.1536 v5:0.6169 v6:0.6673 v7:0.7339 v8:1.1439 v9:0.5779 v10:0.5481 v11:0.5271 v12:0.7901 v13:1.5714] -step:5800/20000 val_loss:2.2085 val_bpb:1.3080 train_time:343733ms step_avg:59.26ms -step:6000/20000 train_loss:2.2787 train_time:355541ms step_avg:59.26ms -step:6000 shared0_alpha:mean=0.390,std=0.063 shared1_alpha:mean=0.556,std=0.057 shared2_alpha:mean=0.618,std=0.060 shared3_alpha:mean=0.679,std=0.061 eff_mlp_scale:[v0:179.7115 v1:102.5137 v2:98.8222 v3:106.1559 v4:97.2881 v5:109.2105 v6:102.7167 v7:101.2639 v8:98.2513 v9:93.2411 v10:79.3499 v11:76.8040 v12:81.3945 v13:226.4459] eff_attn_scale:[v0:0.3415 v1:0.7450 v2:0.7313 v3:0.8596 v4:1.1518 v5:0.6067 v6:0.6684 v7:0.7385 v8:1.1567 v9:0.5683 v10:0.5465 v11:0.5246 v12:0.7922 v13:1.5915] -step:6000/20000 val_loss:2.2042 val_bpb:1.3054 train_time:355572ms step_avg:59.26ms -step:6200/20000 train_loss:2.1520 train_time:367387ms step_avg:59.26ms -step:6200 shared0_alpha:mean=0.388,std=0.063 shared1_alpha:mean=0.557,std=0.058 shared2_alpha:mean=0.620,std=0.060 shared3_alpha:mean=0.681,std=0.062 eff_mlp_scale:[v0:184.1409 v1:104.4329 v2:101.0519 v3:108.1300 v4:99.9027 v5:110.6677 v6:103.9952 v7:102.6988 v8:100.3901 v9:95.0807 v10:80.4491 v11:78.0116 v12:83.3335 v13:228.4138] eff_attn_scale:[v0:0.3336 v1:0.7624 v2:0.7430 v3:0.8663 v4:1.1573 v5:0.6139 v6:0.6565 v7:0.7402 v8:1.1622 v9:0.5630 v10:0.5347 v11:0.5287 v12:0.7878 v13:1.5432] -step:6200/20000 val_loss:2.2034 val_bpb:1.3050 train_time:367419ms step_avg:59.26ms -step:6400/20000 train_loss:2.2230 train_time:379237ms step_avg:59.26ms -step:6400 shared0_alpha:mean=0.386,std=0.064 shared1_alpha:mean=0.558,std=0.058 shared2_alpha:mean=0.621,std=0.061 shared3_alpha:mean=0.682,std=0.062 eff_mlp_scale:[v0:189.8302 v1:106.8324 v2:102.8722 v3:110.1204 v4:101.9474 v5:112.5930 v6:105.8396 v7:104.6393 v8:102.4399 v9:96.3587 v10:81.6053 v11:79.2269 v12:85.2024 v13:232.3741] eff_attn_scale:[v0:0.3267 v1:0.7589 
v2:0.7307 v3:0.8651 v4:1.1648 v5:0.5993 v6:0.6447 v7:0.7392 v8:1.1502 v9:0.5604 v10:0.5353 v11:0.5199 v12:0.7944 v13:1.5767] -step:6400/20000 val_loss:2.2000 val_bpb:1.3029 train_time:379269ms step_avg:59.26ms -step:6600/20000 train_loss:2.1894 train_time:391086ms step_avg:59.26ms -step:6600 shared0_alpha:mean=0.384,std=0.064 shared1_alpha:mean=0.558,std=0.059 shared2_alpha:mean=0.622,std=0.061 shared3_alpha:mean=0.684,std=0.063 eff_mlp_scale:[v0:195.4746 v1:108.5459 v2:104.6949 v3:112.5847 v4:103.8120 v5:114.3420 v6:107.1876 v7:106.5534 v8:104.3087 v9:97.4805 v10:82.7588 v11:80.9202 v12:86.9239 v13:236.5316] eff_attn_scale:[v0:0.3259 v1:0.7674 v2:0.7435 v3:0.8732 v4:1.1760 v5:0.6077 v6:0.6530 v7:0.7420 v8:1.1562 v9:0.5648 v10:0.5311 v11:0.5247 v12:0.8054 v13:1.5359] -step:6600/20000 val_loss:2.1963 val_bpb:1.3007 train_time:391118ms step_avg:59.26ms -step:6800/20000 train_loss:2.2551 train_time:402928ms step_avg:59.25ms -step:6800 shared0_alpha:mean=0.382,std=0.063 shared1_alpha:mean=0.558,std=0.059 shared2_alpha:mean=0.624,std=0.062 shared3_alpha:mean=0.685,std=0.063 eff_mlp_scale:[v0:199.3570 v1:110.9249 v2:106.4690 v3:114.4221 v4:106.2555 v5:115.7016 v6:108.4778 v7:108.3466 v8:105.7543 v9:98.7179 v10:83.8694 v11:82.0194 v12:88.2121 v13:238.4881] eff_attn_scale:[v0:0.3224 v1:0.7681 v2:0.7286 v3:0.8746 v4:1.1817 v5:0.6052 v6:0.6468 v7:0.7323 v8:1.1669 v9:0.5470 v10:0.5182 v11:0.5187 v12:0.7977 v13:1.5251] -step:6800/20000 val_loss:2.1956 val_bpb:1.3004 train_time:402961ms step_avg:59.26ms -step:7000/20000 train_loss:2.2832 train_time:414774ms step_avg:59.25ms -step:7000 shared0_alpha:mean=0.380,std=0.064 shared1_alpha:mean=0.559,std=0.059 shared2_alpha:mean=0.625,std=0.062 shared3_alpha:mean=0.686,std=0.063 eff_mlp_scale:[v0:205.1638 v1:112.9238 v2:108.7964 v3:116.8961 v4:108.2008 v5:117.2053 v6:110.3145 v7:109.7496 v8:108.2008 v9:100.0794 v10:85.0130 v11:83.2055 v12:89.9988 v13:242.5826] eff_attn_scale:[v0:0.3126 v1:0.7662 v2:0.7415 v3:0.8720 v4:1.1966 v5:0.6037 
v6:0.6586 v7:0.7369 v8:1.1671 v9:0.5495 v10:0.5245 v11:0.5240 v12:0.8076 v13:1.5226] -step:7000/20000 val_loss:2.1936 val_bpb:1.2992 train_time:414805ms step_avg:59.26ms -step:7200/20000 train_loss:2.2635 train_time:426617ms step_avg:59.25ms -step:7200 shared0_alpha:mean=0.379,std=0.063 shared1_alpha:mean=0.560,std=0.059 shared2_alpha:mean=0.627,std=0.062 shared3_alpha:mean=0.687,std=0.063 eff_mlp_scale:[v0:209.4189 v1:114.7228 v2:110.1191 v3:118.8122 v4:110.1611 v5:119.0317 v6:112.1584 v7:111.6115 v8:109.6511 v9:101.2577 v10:86.1580 v11:84.3516 v12:91.8009 v13:246.3847] eff_attn_scale:[v0:0.3114 v1:0.7758 v2:0.7512 v3:0.8862 v4:1.1983 v5:0.6082 v6:0.6597 v7:0.7536 v8:1.1687 v9:0.5536 v10:0.5326 v11:0.5342 v12:0.8087 v13:1.5231] -step:7200/20000 val_loss:2.1947 val_bpb:1.2998 train_time:426648ms step_avg:59.26ms -step:7400/20000 train_loss:2.1789 train_time:438467ms step_avg:59.25ms -step:7400 shared0_alpha:mean=0.376,std=0.063 shared1_alpha:mean=0.560,std=0.060 shared2_alpha:mean=0.628,std=0.063 shared3_alpha:mean=0.688,std=0.064 eff_mlp_scale:[v0:214.8811 v1:116.5179 v2:112.4040 v3:120.8861 v4:112.1007 v5:120.8534 v6:113.4305 v7:113.6226 v8:111.5865 v9:102.4273 v10:87.2542 v11:85.6061 v12:93.0745 v13:250.0835] eff_attn_scale:[v0:0.3001 v1:0.7651 v2:0.7407 v3:0.8777 v4:1.2095 v5:0.6036 v6:0.6472 v7:0.7436 v8:1.1749 v9:0.5459 v10:0.5146 v11:0.5140 v12:0.7997 v13:1.5224] -step:7400/20000 val_loss:2.1898 val_bpb:1.2969 train_time:438499ms step_avg:59.26ms -step:7600/20000 train_loss:2.0665 train_time:450304ms step_avg:59.25ms -step:7600 shared0_alpha:mean=0.375,std=0.064 shared1_alpha:mean=0.561,std=0.060 shared2_alpha:mean=0.630,std=0.063 shared3_alpha:mean=0.690,std=0.064 eff_mlp_scale:[v0:219.5673 v1:119.0424 v2:114.3795 v3:123.2943 v4:114.5205 v5:122.3188 v6:114.8971 v7:114.9354 v8:113.4841 v9:103.7525 v10:88.5018 v11:86.7240 v12:94.8292 v13:252.0920] eff_attn_scale:[v0:0.3055 v1:0.7820 v2:0.7397 v3:0.8893 v4:1.2225 v5:0.5991 v6:0.6463 v7:0.7418 v8:1.1827 
v9:0.5446 v10:0.5178 v11:0.5184 v12:0.8100 v13:1.5209] -step:7600/20000 val_loss:2.1886 val_bpb:1.2962 train_time:450336ms step_avg:59.25ms -step:7800/20000 train_loss:2.2108 train_time:462150ms step_avg:59.25ms -step:7800 shared0_alpha:mean=0.373,std=0.063 shared1_alpha:mean=0.562,std=0.060 shared2_alpha:mean=0.631,std=0.063 shared3_alpha:mean=0.691,std=0.065 eff_mlp_scale:[v0:225.3324 v1:120.9175 v2:116.2944 v3:125.4547 v4:116.5135 v5:124.2153 v6:116.8159 v7:116.4937 v8:115.4685 v9:104.9784 v10:89.1764 v11:88.0292 v12:96.1367 v13:256.1435] eff_attn_scale:[v0:0.3010 v1:0.7857 v2:0.7362 v3:0.9016 v4:1.2284 v5:0.6009 v6:0.6388 v7:0.7418 v8:1.1887 v9:0.5392 v10:0.5181 v11:0.5287 v12:0.8107 v13:1.5223] -step:7800/20000 val_loss:2.1854 val_bpb:1.2943 train_time:462181ms step_avg:59.25ms -step:8000/20000 train_loss:2.1762 train_time:473998ms step_avg:59.25ms -step:8000 shared0_alpha:mean=0.371,std=0.063 shared1_alpha:mean=0.562,std=0.060 shared2_alpha:mean=0.632,std=0.063 shared3_alpha:mean=0.692,std=0.065 eff_mlp_scale:[v0:230.1253 v1:122.9614 v2:118.8503 v3:126.9168 v4:119.1279 v5:126.2847 v6:118.3244 v7:118.4203 v8:117.0194 v9:106.3450 v10:90.9783 v11:89.2135 v12:98.0433 v13:258.0185] eff_attn_scale:[v0:0.2941 v1:0.7905 v2:0.7514 v3:0.9012 v4:1.2311 v5:0.6075 v6:0.6457 v7:0.7407 v8:1.1814 v9:0.5490 v10:0.5126 v11:0.5246 v12:0.8191 v13:1.5115] -step:8000/20000 val_loss:2.1836 val_bpb:1.2933 train_time:474029ms step_avg:59.25ms -step:8200/20000 train_loss:2.2414 train_time:485841ms step_avg:59.25ms -step:8200 shared0_alpha:mean=0.370,std=0.063 shared1_alpha:mean=0.563,std=0.060 shared2_alpha:mean=0.634,std=0.064 shared3_alpha:mean=0.693,std=0.066 eff_mlp_scale:[v0:236.0981 v1:125.4997 v2:120.2003 v3:130.1774 v4:121.1870 v5:127.1730 v6:119.6708 v7:119.9989 v8:119.0609 v9:107.6508 v10:91.6064 v11:90.5349 v12:99.9261 v13:262.0005] eff_attn_scale:[v0:0.2834 v1:0.8007 v2:0.7520 v3:0.9107 v4:1.2411 v5:0.6064 v6:0.6423 v7:0.7493 v8:1.2061 v9:0.5442 v10:0.5209 v11:0.5340 
v12:0.8257 v13:1.5124] -step:8200/20000 val_loss:2.1824 val_bpb:1.2925 train_time:485873ms step_avg:59.25ms -step:8400/20000 train_loss:2.1960 train_time:497938ms step_avg:59.28ms -step:8400 shared0_alpha:mean=0.368,std=0.064 shared1_alpha:mean=0.564,std=0.060 shared2_alpha:mean=0.635,std=0.064 shared3_alpha:mean=0.695,std=0.066 eff_mlp_scale:[v0:240.2487 v1:127.4610 v2:122.7811 v3:132.1472 v4:123.2464 v5:129.1455 v6:121.7134 v7:121.8991 v8:121.1030 v9:108.9314 v10:92.8866 v11:91.6940 v12:101.2764 v13:263.9616] eff_attn_scale:[v0:0.2900 v1:0.8182 v2:0.7521 v3:0.9300 v4:1.2661 v5:0.6078 v6:0.6391 v7:0.7523 v8:1.2058 v9:0.5455 v10:0.5105 v11:0.5332 v12:0.8290 v13:1.5126] -step:8400/20000 val_loss:2.1815 val_bpb:1.2920 train_time:497969ms step_avg:59.28ms -step:8600/20000 train_loss:2.1955 train_time:509769ms step_avg:59.28ms -step:8600 shared0_alpha:mean=0.366,std=0.064 shared1_alpha:mean=0.564,std=0.061 shared2_alpha:mean=0.636,std=0.064 shared3_alpha:mean=0.696,std=0.066 eff_mlp_scale:[v0:246.5324 v1:129.4005 v2:124.7777 v3:134.7354 v4:125.2409 v5:130.5306 v6:123.1642 v7:123.8696 v8:123.0816 v9:110.1882 v10:94.1211 v11:92.9022 v12:103.1078 v13:267.7187] eff_attn_scale:[v0:0.2869 v1:0.8100 v2:0.7461 v3:0.9212 v4:1.2635 v5:0.6153 v6:0.6445 v7:0.7526 v8:1.2086 v9:0.5413 v10:0.5156 v11:0.5346 v12:0.8290 v13:1.5093] -step:8600/20000 val_loss:2.1791 val_bpb:1.2906 train_time:509804ms step_avg:59.28ms -step:8800/20000 train_loss:2.1664 train_time:521614ms step_avg:59.27ms -step:8800 shared0_alpha:mean=0.364,std=0.064 shared1_alpha:mean=0.564,std=0.061 shared2_alpha:mean=0.638,std=0.064 shared3_alpha:mean=0.697,std=0.066 eff_mlp_scale:[v0:251.0507 v1:131.9416 v2:126.6586 v3:136.1598 v4:127.8389 v5:132.5104 v6:124.4935 v7:125.2233 v8:124.5749 v9:111.4679 v10:95.2646 v11:94.0542 v12:104.4471 v13:271.2946] eff_attn_scale:[v0:0.2766 v1:0.8323 v2:0.7588 v3:0.9297 v4:1.2828 v5:0.6154 v6:0.6414 v7:0.7637 v8:1.2121 v9:0.5522 v10:0.5085 v11:0.5354 v12:0.8384 v13:1.5304] 
-step:8800/20000 val_loss:2.1778 val_bpb:1.2898 train_time:521645ms step_avg:59.28ms -step:9000/20000 train_loss:2.0805 train_time:533470ms step_avg:59.27ms -step:9000 shared0_alpha:mean=0.363,std=0.064 shared1_alpha:mean=0.565,std=0.061 shared2_alpha:mean=0.639,std=0.065 shared3_alpha:mean=0.698,std=0.066 eff_mlp_scale:[v0:257.4152 v1:133.7777 v2:128.6963 v3:138.8611 v4:130.0868 v5:134.3494 v6:126.5150 v7:127.2894 v8:126.7935 v9:112.6248 v10:96.5222 v11:95.3293 v12:106.4845 v13:273.4311] eff_attn_scale:[v0:0.2766 v1:0.8403 v2:0.7611 v3:0.9508 v4:1.3217 v5:0.6273 v6:0.6512 v7:0.7699 v8:1.2397 v9:0.5484 v10:0.5139 v11:0.5427 v12:0.8555 v13:1.5236] -step:9000/20000 val_loss:2.1757 val_bpb:1.2886 train_time:533502ms step_avg:59.28ms -step:9200/20000 train_loss:2.1355 train_time:545326ms step_avg:59.27ms -step:9200 shared0_alpha:mean=0.361,std=0.063 shared1_alpha:mean=0.565,std=0.061 shared2_alpha:mean=0.639,std=0.065 shared3_alpha:mean=0.698,std=0.066 eff_mlp_scale:[v0:262.0444 v1:135.2486 v2:130.5472 v3:140.4741 v4:132.1480 v5:135.8241 v6:127.8047 v7:128.8142 v8:128.2776 v9:113.9541 v10:97.0877 v11:96.6107 v12:107.8195 v13:277.5444] eff_attn_scale:[v0:0.2734 v1:0.8502 v2:0.7689 v3:0.9572 v4:1.3303 v5:0.6336 v6:0.6512 v7:0.7759 v8:1.2633 v9:0.5614 v10:0.5139 v11:0.5482 v12:0.8663 v13:1.5439] -step:9200/20000 val_loss:2.1655 val_bpb:1.2825 train_time:545357ms step_avg:59.28ms -step:9400/20000 train_loss:2.1757 train_time:557172ms step_avg:59.27ms -step:9400 shared0_alpha:mean=0.360,std=0.063 shared1_alpha:mean=0.565,std=0.061 shared2_alpha:mean=0.640,std=0.065 shared3_alpha:mean=0.699,std=0.066 eff_mlp_scale:[v0:263.2965 v1:136.5867 v2:131.9275 v3:142.5522 v4:134.5448 v5:137.1655 v6:129.1675 v7:130.2536 v8:130.0786 v9:114.5939 v10:97.7036 v11:97.2709 v12:109.4223 v13:279.7477] eff_attn_scale:[v0:0.2731 v1:0.8550 v2:0.7636 v3:0.9639 v4:1.3521 v5:0.6342 v6:0.6539 v7:0.7821 v8:1.2793 v9:0.5580 v10:0.5090 v11:0.5538 v12:0.8841 v13:1.5618] -step:9400/20000 val_loss:2.1559 
val_bpb:1.2768 train_time:557203ms step_avg:59.28ms -step:9600/20000 train_loss:2.1789 train_time:569015ms step_avg:59.27ms -step:9600 shared0_alpha:mean=0.359,std=0.063 shared1_alpha:mean=0.565,std=0.061 shared2_alpha:mean=0.640,std=0.065 shared3_alpha:mean=0.699,std=0.067 eff_mlp_scale:[v0:264.5626 v1:137.2842 v2:132.6420 v3:143.3359 v4:135.3772 v5:137.8659 v6:129.8671 v7:130.9697 v8:130.8833 v9:115.1791 v10:98.7878 v11:97.8057 v12:110.0993 v13:281.6810] eff_attn_scale:[v0:0.2760 v1:0.8545 v2:0.7619 v3:0.9659 v4:1.3683 v5:0.6338 v6:0.6598 v7:0.7880 v8:1.2795 v9:0.5536 v10:0.5106 v11:0.5507 v12:0.8930 v13:1.5710] -step:9600/20000 val_loss:2.1464 val_bpb:1.2712 train_time:569047ms step_avg:59.28ms -step:9800/20000 train_loss:2.0988 train_time:580854ms step_avg:59.27ms -step:9800 shared0_alpha:mean=0.359,std=0.062 shared1_alpha:mean=0.564,std=0.061 shared2_alpha:mean=0.639,std=0.065 shared3_alpha:mean=0.699,std=0.067 eff_mlp_scale:[v0:265.2037 v1:138.5365 v2:133.1247 v3:143.8686 v4:136.1328 v5:138.5365 v6:130.8967 v7:131.4564 v8:131.6138 v9:115.7394 v10:99.1473 v11:98.1692 v12:110.7138 v13:285.5064] eff_attn_scale:[v0:0.2792 v1:0.8503 v2:0.7623 v3:0.9771 v4:1.3634 v5:0.6438 v6:0.6556 v7:0.7894 v8:1.2953 v9:0.5588 v10:0.5095 v11:0.5547 v12:0.8915 v13:1.5897] -step:9800/20000 val_loss:2.1375 val_bpb:1.2660 train_time:580886ms step_avg:59.27ms -step:10000/20000 train_loss:2.1327 train_time:592699ms step_avg:59.27ms -step:10000 shared0_alpha:mean=0.359,std=0.062 shared1_alpha:mean=0.564,std=0.061 shared2_alpha:mean=0.639,std=0.065 shared3_alpha:mean=0.698,std=0.067 eff_mlp_scale:[v0:265.1537 v1:138.7435 v2:133.3693 v3:144.2558 v4:136.6628 v5:138.7435 v6:131.1372 v7:131.8102 v8:132.1263 v9:115.9123 v10:99.3294 v11:98.4333 v12:111.1449 v13:286.6695] eff_attn_scale:[v0:0.2733 v1:0.8454 v2:0.7637 v3:0.9686 v4:1.3718 v5:0.6391 v6:0.6563 v7:0.7894 v8:1.3032 v9:0.5582 v10:0.5131 v11:0.5504 v12:0.8970 v13:1.5789] -step:10000/20000 val_loss:2.1286 val_bpb:1.2606 
train_time:592731ms step_avg:59.27ms -step:10122/20000 val_loss:2.1246 val_bpb:1.2583 train_time:600039ms step_avg:59.28ms -stopping_early: wallclock_cap train_time:600039ms step:10122/20000 -peak memory allocated: 13736 MiB reserved: 14068 MiB -Serialized model: 45178938 bytes -Code size: 57024 bytes -Total submission size: 45235962 bytes -Serialized model int8+zlib: 10717286 bytes (payload:11638976 raw_torch:11670071 payload_ratio:3.88x) -Total submission size int8+zlib: 10774310 bytes -final_int8_zlib_roundtrip val_loss:2.1374 val_bpb:1.2659 eval_time:1873ms -final_int8_zlib_roundtrip_exact val_loss:2.13735865 val_bpb:1.26586418 +step:0/20000 val_loss:6.9332 val_bpb:4.1062 train_time:0ms step_avg:0.01ms +step:1/20000 train_loss:6.9329 train_time:38ms step_avg:38.33ms +step:2/20000 train_loss:9.5771 train_time:96ms step_avg:47.94ms +step:3/20000 train_loss:12.0056 train_time:154ms step_avg:51.33ms +step:4/20000 train_loss:11.0951 train_time:211ms step_avg:52.78ms +step:5/20000 train_loss:9.6599 train_time:271ms step_avg:54.14ms +step:6/20000 train_loss:8.8458 train_time:328ms step_avg:54.59ms +step:7/20000 train_loss:7.6174 train_time:385ms step_avg:54.93ms +step:8/20000 train_loss:7.0656 train_time:442ms step_avg:55.30ms +step:9/20000 train_loss:6.6004 train_time:500ms step_avg:55.52ms +step:10/20000 train_loss:6.0300 train_time:557ms step_avg:55.74ms +step:200/20000 train_loss:2.8222 train_time:11696ms step_avg:58.48ms +step:200 shared0_alpha:mean=0.473,std=0.047 shared1_alpha:mean=0.489,std=0.042 shared2_alpha:mean=0.494,std=0.047 shared3_alpha:mean=0.516,std=0.047 eff_mlp_scale:[v0:42.1731 v1:27.6985 v2:28.2939 v3:30.5156 v4:32.0949 v5:30.6609 v6:28.7337 v7:30.6697 v8:33.5249 v9:31.6977 v10:31.3725 v11:34.2144 v12:35.4315 v13:50.2958] eff_attn_scale:[v0:14.9320 v1:10.1316 v2:10.8969 v3:9.6005 v4:9.3961 v5:10.0503 v6:9.6550 v7:8.9120 v8:9.2401 v9:10.2130 v10:9.9354 v11:9.2180 v12:9.8639 v13:16.0045] eff_attn_bias:[v0:0.1630 v1:0.1450 v2:0.1581 v3:0.1616 
v4:0.1650 v5:0.1692 v6:0.1733 v7:0.1795 v8:0.1851 v9:0.1961 v10:0.1961 v11:0.1878 v12:0.1864 v13:0.1485] eff_mlp_bias:[v0:0.1609 v1:0.1595 v2:0.1540 v3:0.1512 v4:0.1581 v5:0.1643 v6:0.1630 v7:0.1733 v8:0.1906 v9:0.1906 v10:0.1754 v11:0.1733 v12:0.1547 v13:0.1864] depth_encoding:sinusoidal +step:200/20000 val_loss:2.8059 val_bpb:1.6618 train_time:11751ms step_avg:58.75ms +step:400/20000 train_loss:2.3856 train_time:23480ms step_avg:58.70ms +step:400 shared0_alpha:mean=0.483,std=0.049 shared1_alpha:mean=0.509,std=0.044 shared2_alpha:mean=0.517,std=0.049 shared3_alpha:mean=0.545,std=0.050 eff_mlp_scale:[v0:50.5521 v1:32.5033 v2:35.1304 v3:37.9512 v4:38.0848 v5:41.3063 v6:38.4290 v7:38.7945 v8:39.2541 v9:40.9677 v10:38.4290 v11:39.6379 v12:37.9178 v13:66.6226] eff_attn_scale:[v0:6.5055 v1:5.4401 v2:5.9484 v3:5.0682 v4:4.9102 v5:5.8673 v6:5.6164 v7:4.9169 v8:5.0114 v9:5.6679 v10:5.2014 v11:4.4378 v12:4.6318 v13:9.4584] eff_attn_bias:[v0:0.2168 v1:0.1947 v2:0.2085 v3:0.2058 v4:0.2113 v5:0.2141 v6:0.2141 v7:0.2003 v8:0.2058 v9:0.2182 v10:0.2182 v11:0.1989 v12:0.2003 v13:0.1595] eff_mlp_bias:[v0:0.2955 v1:0.2293 v2:0.1947 v3:0.1878 v4:0.1906 v5:0.1906 v6:0.1733 v7:0.1782 v8:0.1989 v9:0.2044 v10:0.1823 v11:0.1761 v12:0.1692 v13:0.2417] depth_encoding:sinusoidal +step:400/20000 val_loss:2.5853 val_bpb:1.5312 train_time:23505ms step_avg:58.76ms +step:600/20000 train_loss:2.5924 train_time:35261ms step_avg:58.77ms +step:600 shared0_alpha:mean=0.488,std=0.049 shared1_alpha:mean=0.521,std=0.047 shared2_alpha:mean=0.532,std=0.050 shared3_alpha:mean=0.566,std=0.052 eff_mlp_scale:[v0:54.6022 v1:36.1924 v2:39.3814 v3:42.9007 v4:41.9468 v5:48.4979 v6:44.9576 v7:44.3131 v8:43.1502 v9:46.6882 v10:42.3437 v11:42.3711 v12:39.5400 v13:83.2264] eff_attn_scale:[v0:3.0099 v1:3.1470 v2:3.4999 v3:3.0893 v4:3.0057 v5:3.4597 v6:3.4027 v7:3.1254 v8:3.1677 v9:3.2447 v10:3.0333 v11:2.6737 v12:2.7177 v13:6.2262] eff_attn_bias:[v0:0.2735 v1:0.2444 v2:0.2472 v3:0.2472 v4:0.2610 v5:0.2665 v6:0.2569 
v7:0.2237 v8:0.2293 v9:0.2403 v10:0.2306 v11:0.2030 v12:0.2099 v13:0.1823] eff_mlp_bias:[v0:0.4392 v1:0.2886 v2:0.2417 v3:0.2334 v4:0.2500 v5:0.2320 v6:0.1892 v7:0.1906 v8:0.2168 v9:0.2182 v10:0.1892 v11:0.1837 v12:0.1864 v13:0.2500] depth_encoding:sinusoidal +step:600/20000 val_loss:2.4888 val_bpb:1.4740 train_time:35285ms step_avg:58.81ms +step:800/20000 train_loss:2.3501 train_time:47075ms step_avg:58.84ms +step:800 shared0_alpha:mean=0.491,std=0.048 shared1_alpha:mean=0.529,std=0.049 shared2_alpha:mean=0.540,std=0.052 shared3_alpha:mean=0.579,std=0.054 eff_mlp_scale:[v0:59.9201 v1:40.1580 v2:43.4057 v3:46.3101 v4:44.5735 v5:54.1754 v6:49.9165 v7:48.3078 v8:46.1591 v9:50.7658 v10:45.0334 v11:44.4940 v12:41.4022 v13:96.0354] eff_attn_scale:[v0:1.9268 v1:2.2402 v2:2.4450 v3:2.2894 v4:2.2408 v5:2.4269 v6:2.4606 v7:2.3638 v8:2.3872 v9:2.2869 v10:2.1335 v11:1.9326 v12:1.9186 v13:4.7429] eff_attn_bias:[v0:0.3135 v1:0.2817 v2:0.2831 v3:0.2817 v4:0.3094 v5:0.3066 v6:0.2942 v7:0.2514 v8:0.2569 v9:0.2610 v10:0.2444 v11:0.2085 v12:0.2210 v13:0.2154] eff_mlp_bias:[v0:0.5635 v1:0.3370 v2:0.2845 v3:0.2748 v4:0.2997 v5:0.2721 v6:0.2099 v7:0.2044 v8:0.2389 v9:0.2362 v10:0.2003 v11:0.1975 v12:0.2072 v13:0.2569] depth_encoding:sinusoidal +step:800/20000 val_loss:2.4288 val_bpb:1.4385 train_time:47100ms step_avg:58.88ms +step:1000/20000 train_loss:2.4207 train_time:58897ms step_avg:58.90ms +step:1000 shared0_alpha:mean=0.492,std=0.050 shared1_alpha:mean=0.534,std=0.051 shared2_alpha:mean=0.546,std=0.053 shared3_alpha:mean=0.588,std=0.054 eff_mlp_scale:[v0:65.1131 v1:43.7501 v2:46.8400 v3:49.6567 v4:47.2571 v5:58.8567 v6:54.4912 v7:51.8968 v8:49.4216 v9:54.5405 v10:47.7731 v11:47.0432 v12:43.6497 v13:107.1387] eff_attn_scale:[v0:1.4494 v1:1.7948 v2:1.9658 v3:1.8958 v4:1.8762 v5:1.9716 v6:1.9931 v7:2.0026 v8:1.9544 v9:1.8220 v10:1.7132 v11:1.6021 v12:1.5635 v13:3.8871] eff_attn_bias:[v0:0.3466 v1:0.3094 v2:0.3135 v3:0.3135 v4:0.3480 v5:0.3439 v6:0.3356 v7:0.2790 v8:0.2859 v9:0.2831 
v10:0.2583 v11:0.2154 v12:0.2306 v13:0.2458] eff_mlp_bias:[v0:0.6712 v1:0.3784 v2:0.3246 v3:0.3107 v4:0.3384 v5:0.3149 v6:0.2306 v7:0.2196 v8:0.2652 v9:0.2555 v10:0.2141 v11:0.2127 v12:0.2251 v13:0.2638] depth_encoding:sinusoidal +step:1000/20000 val_loss:2.3864 val_bpb:1.4134 train_time:58922ms step_avg:58.92ms +step:1200/20000 train_loss:2.4329 train_time:70714ms step_avg:58.93ms +step:1200 shared0_alpha:mean=0.491,std=0.051 shared1_alpha:mean=0.540,std=0.052 shared2_alpha:mean=0.551,std=0.054 shared3_alpha:mean=0.594,std=0.056 eff_mlp_scale:[v0:70.1553 v1:47.0161 v2:49.9708 v3:52.2430 v4:49.7842 v5:62.9572 v6:57.9814 v7:54.9123 v8:51.6281 v9:57.7108 v10:49.5894 v11:48.8109 v12:45.9121 v13:116.7423] eff_attn_scale:[v0:1.1874 v1:1.5626 v2:1.6931 v3:1.6728 v4:1.5858 v5:1.6429 v6:1.7182 v7:1.7595 v8:1.7059 v9:1.5441 v10:1.4548 v11:1.3754 v12:1.3335 v13:3.3954] eff_attn_bias:[v0:0.3784 v1:0.3439 v2:0.3411 v3:0.3384 v4:0.3812 v5:0.3784 v6:0.3701 v7:0.3066 v8:0.3149 v9:0.3066 v10:0.2721 v11:0.2237 v12:0.2431 v13:0.2748] eff_mlp_bias:[v0:0.7679 v1:0.4198 v2:0.3591 v3:0.3397 v4:0.3757 v5:0.3522 v6:0.2541 v7:0.2362 v8:0.2914 v9:0.2748 v10:0.2293 v11:0.2279 v12:0.2444 v13:0.2721] depth_encoding:sinusoidal +step:1200/20000 val_loss:2.3567 val_bpb:1.3957 train_time:70739ms step_avg:58.95ms +step:1400/20000 train_loss:2.4819 train_time:82526ms step_avg:58.95ms +step:1400 shared0_alpha:mean=0.491,std=0.052 shared1_alpha:mean=0.543,std=0.054 shared2_alpha:mean=0.554,std=0.055 shared3_alpha:mean=0.598,std=0.057 eff_mlp_scale:[v0:74.8502 v1:50.2471 v2:53.1146 v3:55.3364 v4:51.9688 v5:66.5826 v6:60.9256 v7:57.6745 v8:53.8518 v9:60.7928 v10:51.9430 v11:51.0498 v12:48.2030 v13:125.4454] eff_attn_scale:[v0:0.9967 v1:1.3821 v2:1.5224 v3:1.4748 v4:1.4563 v5:1.4576 v6:1.5224 v7:1.5561 v8:1.5245 v9:1.3646 v10:1.2922 v11:1.2077 v12:1.1889 v13:3.1047] eff_attn_bias:[v0:0.3950 v1:0.3701 v2:0.3646 v3:0.3674 v4:0.4088 v5:0.4060 v6:0.4033 v7:0.3328 v8:0.3425 v9:0.3273 v10:0.2886 v11:0.2306 
v12:0.2541 v13:0.3038] eff_mlp_bias:[v0:0.8507 v1:0.4530 v2:0.3895 v3:0.3674 v4:0.4033 v5:0.3867 v6:0.2790 v7:0.2555 v8:0.3163 v9:0.2969 v10:0.2444 v11:0.2444 v12:0.2624 v13:0.2831] depth_encoding:sinusoidal +step:1400/20000 val_loss:2.3334 val_bpb:1.3820 train_time:82551ms step_avg:58.96ms +step:1600/20000 train_loss:2.1555 train_time:94337ms step_avg:58.96ms +step:1600 shared0_alpha:mean=0.492,std=0.053 shared1_alpha:mean=0.547,std=0.055 shared2_alpha:mean=0.557,std=0.056 shared3_alpha:mean=0.599,std=0.059 eff_mlp_scale:[v0:79.5614 v1:53.7939 v2:55.9060 v3:58.2668 v4:54.0765 v5:69.6156 v6:63.4394 v7:60.2486 v8:55.9941 v9:62.8650 v10:53.1305 v11:52.7176 v12:50.2412 v13:132.0925] eff_attn_scale:[v0:0.8713 v1:1.2774 v2:1.4277 v3:1.3605 v4:1.3307 v5:1.3443 v6:1.4107 v7:1.4438 v8:1.3796 v9:1.2551 v10:1.1831 v11:1.0995 v12:1.0863 v13:2.9157] eff_attn_bias:[v0:0.4116 v1:0.4033 v2:0.3839 v3:0.3922 v4:0.4337 v5:0.4309 v6:0.4309 v7:0.3563 v8:0.3674 v9:0.3439 v10:0.3025 v11:0.2403 v12:0.2665 v13:0.3342] eff_mlp_bias:[v0:0.9226 v1:0.4889 v2:0.4171 v3:0.3922 v4:0.4337 v5:0.4171 v6:0.3025 v7:0.2748 v8:0.3397 v9:0.3176 v10:0.2610 v11:0.2610 v12:0.2804 v13:0.2969] depth_encoding:sinusoidal +step:1600/20000 val_loss:2.3219 val_bpb:1.3752 train_time:94362ms step_avg:58.98ms +step:1800/20000 train_loss:2.2543 train_time:106138ms step_avg:58.97ms +step:1800 shared0_alpha:mean=0.491,std=0.054 shared1_alpha:mean=0.551,std=0.055 shared2_alpha:mean=0.559,std=0.057 shared3_alpha:mean=0.601,std=0.060 eff_mlp_scale:[v0:84.1182 v1:56.3085 v2:58.5366 v3:60.7906 v4:56.2029 v5:72.6423 v6:66.2069 v7:62.8035 v8:58.1544 v9:65.7649 v10:55.3070 v11:54.7518 v12:52.3000 v13:139.4565] eff_attn_scale:[v0:0.7733 v1:1.1822 v2:1.3352 v3:1.2542 v4:1.2462 v5:1.2034 v6:1.3187 v7:1.3183 v8:1.3040 v9:1.1239 v10:1.0879 v11:1.0034 v12:0.9938 v13:2.7729] eff_attn_bias:[v0:0.4254 v1:0.4254 v2:0.4033 v3:0.4143 v4:0.4613 v5:0.4558 v6:0.4558 v7:0.3784 v8:0.3895 v9:0.3646 v10:0.3176 v11:0.2514 v12:0.2790 v13:0.3618] 
eff_mlp_bias:[v0:0.9888 v1:0.5220 v2:0.4475 v3:0.4198 v4:0.4585 v5:0.4530 v6:0.3273 v7:0.2914 v8:0.3618 v9:0.3397 v10:0.2817 v11:0.2776 v12:0.2983 v13:0.3135] depth_encoding:sinusoidal +step:1800/20000 val_loss:2.3035 val_bpb:1.3642 train_time:106163ms step_avg:58.98ms +step:2000/20000 train_loss:2.3066 train_time:117933ms step_avg:58.97ms +step:2000 shared0_alpha:mean=0.491,std=0.055 shared1_alpha:mean=0.554,std=0.056 shared2_alpha:mean=0.561,std=0.058 shared3_alpha:mean=0.602,std=0.061 eff_mlp_scale:[v0:88.9970 v1:59.4835 v2:61.5244 v3:63.3619 v4:57.7963 v5:75.6665 v6:68.9073 v7:65.4058 v8:60.1715 v9:68.2311 v10:57.0126 v11:56.4125 v12:54.2335 v13:146.0011] eff_attn_scale:[v0:0.7016 v1:1.1244 v2:1.2590 v3:1.1668 v4:1.1665 v5:1.1400 v6:1.2218 v7:1.2339 v8:1.2169 v9:1.0571 v10:1.0093 v11:0.9345 v12:0.9190 v13:2.5910] eff_attn_bias:[v0:0.4309 v1:0.4502 v2:0.4226 v3:0.4337 v4:0.4861 v5:0.4778 v6:0.4834 v7:0.3977 v8:0.4116 v9:0.3812 v10:0.3315 v11:0.2610 v12:0.2900 v13:0.3895] eff_mlp_bias:[v0:1.0551 v1:0.5524 v2:0.4806 v3:0.4447 v4:0.4834 v5:0.4861 v6:0.3522 v7:0.3121 v8:0.3839 v9:0.3618 v10:0.3025 v11:0.2942 v12:0.3149 v13:0.3328] depth_encoding:sinusoidal +step:2000/20000 val_loss:2.2881 val_bpb:1.3551 train_time:117958ms step_avg:58.98ms +step:2200/20000 train_loss:2.1276 train_time:129731ms step_avg:58.97ms +step:2200 shared0_alpha:mean=0.491,std=0.055 shared1_alpha:mean=0.557,std=0.056 shared2_alpha:mean=0.563,std=0.058 shared3_alpha:mean=0.604,std=0.062 eff_mlp_scale:[v0:93.2070 v1:61.6985 v2:64.0873 v3:65.4437 v4:60.1313 v5:78.1219 v6:71.1619 v7:67.1005 v8:62.5527 v9:70.1321 v10:58.6773 v11:58.4023 v12:56.4992 v13:153.3788] eff_attn_scale:[v0:0.6433 v1:1.0833 v2:1.2209 v3:1.1398 v4:1.1340 v5:1.0631 v6:1.1739 v7:1.1858 v8:1.1844 v9:0.9922 v10:0.9600 v11:0.8996 v12:0.8971 v13:2.4914] eff_attn_bias:[v0:0.4392 v1:0.4723 v2:0.4447 v3:0.4558 v4:0.5055 v5:0.4944 v6:0.5055 v7:0.4171 v8:0.4337 v9:0.3977 v10:0.3494 v11:0.2693 v12:0.3025 v13:0.4171] 
eff_mlp_bias:[v0:1.1159 v1:0.5800 v2:0.5110 v3:0.4668 v4:0.5082 v5:0.5138 v6:0.3757 v7:0.3301 v8:0.4060 v9:0.3839 v10:0.3204 v11:0.3107 v12:0.3301 v13:0.3508] depth_encoding:sinusoidal +step:2200/20000 val_loss:2.2811 val_bpb:1.3510 train_time:129756ms step_avg:58.98ms +step:2400/20000 train_loss:2.2517 train_time:141522ms step_avg:58.97ms +step:2400 shared0_alpha:mean=0.490,std=0.057 shared1_alpha:mean=0.559,std=0.058 shared2_alpha:mean=0.564,std=0.059 shared3_alpha:mean=0.604,std=0.063 eff_mlp_scale:[v0:97.3088 v1:64.4124 v2:66.1193 v3:68.0254 v4:61.7178 v5:81.0785 v6:73.2787 v7:69.7050 v8:63.7614 v9:72.5202 v10:60.2233 v11:60.0471 v12:58.4480 v13:159.8243] eff_attn_scale:[v0:0.6018 v1:1.0217 v2:1.1559 v3:1.0927 v4:1.0802 v5:1.0070 v6:1.1152 v7:1.1426 v8:1.1342 v9:0.9333 v10:0.9166 v11:0.8482 v12:0.8543 v13:2.4043] eff_attn_bias:[v0:0.4475 v1:0.4972 v2:0.4613 v3:0.4723 v4:0.5276 v5:0.5138 v6:0.5276 v7:0.4364 v8:0.4558 v9:0.4143 v10:0.3646 v11:0.2776 v12:0.3135 v13:0.4419] eff_mlp_bias:[v0:1.1822 v1:0.6049 v2:0.5386 v3:0.4889 v4:0.5303 v5:0.5441 v6:0.4005 v7:0.3466 v8:0.4254 v9:0.4060 v10:0.3397 v11:0.3246 v12:0.3425 v13:0.3701] depth_encoding:sinusoidal +step:2400/20000 val_loss:2.2693 val_bpb:1.3440 train_time:141547ms step_avg:58.98ms +step:2600/20000 train_loss:2.4645 train_time:153292ms step_avg:58.96ms +step:2600 shared0_alpha:mean=0.490,std=0.058 shared1_alpha:mean=0.562,std=0.058 shared2_alpha:mean=0.565,std=0.060 shared3_alpha:mean=0.604,std=0.063 eff_mlp_scale:[v0:101.6481 v1:67.0256 v2:69.3391 v3:70.7851 v4:63.5850 v5:83.4400 v6:75.7594 v7:71.6379 v8:66.4942 v9:74.3209 v10:62.4908 v11:62.2568 v12:60.6759 v13:165.2093] eff_attn_scale:[v0:0.5697 v1:0.9969 v2:1.1324 v3:1.0532 v4:1.0312 v5:1.0017 v6:1.0975 v7:1.0772 v8:1.0689 v9:0.9110 v10:0.8930 v11:0.8079 v12:0.8052 v13:2.3357] eff_attn_bias:[v0:0.4613 v1:0.5193 v2:0.4834 v3:0.4889 v4:0.5497 v5:0.5276 v6:0.5497 v7:0.4530 v8:0.4723 v9:0.4254 v10:0.3757 v11:0.2886 v12:0.3273 v13:0.4640] 
eff_mlp_bias:[v0:1.2374 v1:0.6381 v2:0.5690 v3:0.5138 v4:0.5497 v5:0.5745 v6:0.4198 v7:0.3646 v8:0.4475 v9:0.4226 v10:0.3591 v11:0.3397 v12:0.3563 v13:0.3867] depth_encoding:sinusoidal +step:2600/20000 val_loss:2.2779 val_bpb:1.3491 train_time:153317ms step_avg:58.97ms +step:2800/20000 train_loss:2.2945 train_time:165082ms step_avg:58.96ms +step:2800 shared0_alpha:mean=0.489,std=0.058 shared1_alpha:mean=0.564,std=0.058 shared2_alpha:mean=0.566,std=0.060 shared3_alpha:mean=0.604,std=0.064 eff_mlp_scale:[v0:105.7923 v1:69.8320 v2:71.9417 v3:72.7933 v4:65.5945 v5:86.0182 v6:77.5757 v7:73.6547 v8:67.6969 v9:76.7689 v10:63.7074 v11:63.3172 v12:62.2307 v13:170.3488] eff_attn_scale:[v0:0.5409 v1:0.9771 v2:1.0912 v3:1.0090 v4:1.0242 v5:0.9392 v6:1.0372 v7:1.0185 v8:1.0431 v9:0.8680 v10:0.8504 v11:0.7722 v12:0.7930 v13:2.2121] eff_attn_bias:[v0:0.4613 v1:0.5386 v2:0.4999 v3:0.5055 v4:0.5718 v5:0.5441 v6:0.5690 v7:0.4696 v8:0.4917 v9:0.4392 v10:0.3867 v11:0.2983 v12:0.3370 v13:0.4889] eff_mlp_bias:[v0:1.2982 v1:0.6602 v2:0.5966 v3:0.5359 v4:0.5718 v5:0.6021 v6:0.4419 v7:0.3812 v8:0.4640 v9:0.4447 v10:0.3757 v11:0.3563 v12:0.3701 v13:0.4033] depth_encoding:sinusoidal +step:2800/20000 val_loss:2.2549 val_bpb:1.3355 train_time:165107ms step_avg:58.97ms +step:3000/20000 train_loss:2.2781 train_time:176850ms step_avg:58.95ms +step:3000 shared0_alpha:mean=0.489,std=0.058 shared1_alpha:mean=0.567,std=0.059 shared2_alpha:mean=0.567,std=0.061 shared3_alpha:mean=0.604,std=0.065 eff_mlp_scale:[v0:109.7013 v1:71.9646 v2:74.0094 v3:74.9589 v4:67.7205 v5:87.8528 v6:79.7024 v7:75.3947 v8:69.8501 v9:78.0395 v10:65.2509 v11:64.9353 v12:64.3132 v13:176.7194] eff_attn_scale:[v0:0.5126 v1:0.9590 v2:1.0788 v3:0.9688 v4:0.9963 v5:0.9123 v6:1.0202 v7:1.0111 v8:1.0197 v9:0.8374 v10:0.8299 v11:0.7571 v12:0.7718 v13:2.1389] eff_attn_bias:[v0:0.4640 v1:0.5635 v2:0.5165 v3:0.5220 v4:0.5883 v5:0.5580 v6:0.5883 v7:0.4861 v8:0.5082 v9:0.4530 v10:0.3977 v11:0.3052 v12:0.3494 v13:0.5138] 
eff_mlp_bias:[v0:1.3590 v1:0.6878 v2:0.6270 v3:0.5524 v4:0.5911 v5:0.6270 v6:0.4640 v7:0.3950 v8:0.4834 v9:0.4640 v10:0.3950 v11:0.3701 v12:0.3839 v13:0.4226] depth_encoding:sinusoidal +step:3000/20000 val_loss:2.2483 val_bpb:1.3316 train_time:176875ms step_avg:58.96ms +step:3200/20000 train_loss:2.2395 train_time:188622ms step_avg:58.94ms +step:3200 shared0_alpha:mean=0.487,std=0.059 shared1_alpha:mean=0.569,std=0.059 shared2_alpha:mean=0.568,std=0.061 shared3_alpha:mean=0.604,std=0.065 eff_mlp_scale:[v0:114.2085 v1:74.3596 v2:76.6599 v3:77.5110 v4:69.5039 v5:90.4630 v6:81.9773 v7:77.0705 v8:71.6624 v9:80.5168 v10:66.9112 v11:66.9413 v12:66.0503 v13:181.8108] eff_attn_scale:[v0:0.4854 v1:0.9360 v2:1.0429 v3:0.9650 v4:0.9903 v5:0.8804 v6:0.9946 v7:0.9839 v8:1.0042 v9:0.8109 v10:0.8015 v11:0.7532 v12:0.7532 v13:2.0938] eff_attn_bias:[v0:0.4751 v1:0.5828 v2:0.5303 v3:0.5331 v4:0.6104 v5:0.5745 v6:0.6104 v7:0.5027 v8:0.5248 v9:0.4640 v10:0.4116 v11:0.3135 v12:0.3618 v13:0.5331] eff_mlp_bias:[v0:1.4087 v1:0.7182 v2:0.6491 v3:0.5718 v4:0.6077 v5:0.6491 v6:0.4861 v7:0.4116 v8:0.5027 v9:0.4861 v10:0.4116 v11:0.3867 v12:0.3977 v13:0.4392] depth_encoding:sinusoidal +step:3200/20000 val_loss:2.2434 val_bpb:1.3287 train_time:188646ms step_avg:58.95ms +step:3400/20000 train_loss:2.2143 train_time:200387ms step_avg:58.94ms +step:3400 shared0_alpha:mean=0.487,std=0.060 shared1_alpha:mean=0.572,std=0.060 shared2_alpha:mean=0.569,std=0.062 shared3_alpha:mean=0.604,std=0.066 eff_mlp_scale:[v0:118.7020 v1:77.2060 v2:78.9053 v3:79.2049 v4:71.1050 v5:92.5513 v6:83.8369 v7:79.2049 v8:72.8499 v9:82.4809 v10:68.5938 v11:68.5256 v12:68.0514 v13:186.2671] eff_attn_scale:[v0:0.4577 v1:0.9144 v2:1.0188 v3:0.9579 v4:0.9676 v5:0.8598 v6:0.9572 v7:0.9718 v8:0.9768 v9:0.7825 v10:0.7866 v11:0.7254 v12:0.7326 v13:2.0363] eff_attn_bias:[v0:0.4723 v1:0.6049 v2:0.5552 v3:0.5497 v4:0.6325 v5:0.5883 v6:0.6270 v7:0.5193 v8:0.5414 v9:0.4778 v10:0.4254 v11:0.3218 v12:0.3729 v13:0.5524] 
eff_mlp_bias:[v0:1.4584 v1:0.7403 v2:0.6767 v3:0.5939 v4:0.6242 v5:0.6740 v6:0.5082 v7:0.4281 v8:0.5165 v9:0.5027 v10:0.4309 v11:0.4033 v12:0.4116 v13:0.4558] depth_encoding:sinusoidal +step:3400/20000 val_loss:2.2412 val_bpb:1.3274 train_time:200411ms step_avg:58.94ms +step:3600/20000 train_loss:2.1808 train_time:212152ms step_avg:58.93ms +step:3600 shared0_alpha:mean=0.485,std=0.060 shared1_alpha:mean=0.574,std=0.060 shared2_alpha:mean=0.569,std=0.062 shared3_alpha:mean=0.603,std=0.066 eff_mlp_scale:[v0:122.4656 v1:79.4080 v2:81.2694 v3:81.4648 v4:72.8739 v5:94.4181 v6:85.8096 v7:81.0147 v8:75.0822 v9:84.2500 v10:70.3730 v11:70.2127 v12:69.7823 v13:191.2313] eff_attn_scale:[v0:0.4533 v1:0.9054 v2:1.0046 v3:0.9258 v4:0.9416 v5:0.8423 v6:0.9527 v7:0.9396 v8:0.9597 v9:0.7702 v10:0.7640 v11:0.7058 v12:0.7153 v13:2.0134] eff_attn_bias:[v0:0.4778 v1:0.6215 v2:0.5773 v3:0.5690 v4:0.6463 v5:0.6021 v6:0.6463 v7:0.5303 v8:0.5607 v9:0.4917 v10:0.4337 v11:0.3315 v12:0.3839 v13:0.5718] eff_mlp_bias:[v0:1.5137 v1:0.7679 v2:0.7043 v3:0.6160 v4:0.6436 v5:0.6988 v6:0.5303 v7:0.4447 v8:0.5359 v9:0.5248 v10:0.4502 v11:0.4171 v12:0.4254 v13:0.4723] depth_encoding:sinusoidal +step:3600/20000 val_loss:2.2334 val_bpb:1.3227 train_time:212177ms step_avg:58.94ms +step:3800/20000 train_loss:2.2765 train_time:223908ms step_avg:58.92ms +step:3800 shared0_alpha:mean=0.486,std=0.060 shared1_alpha:mean=0.576,std=0.060 shared2_alpha:mean=0.570,std=0.063 shared3_alpha:mean=0.603,std=0.066 eff_mlp_scale:[v0:126.0984 v1:81.8079 v2:83.4886 v3:84.1340 v4:74.6995 v5:96.9937 v6:88.0759 v7:82.7696 v8:76.9360 v9:86.2167 v10:71.5617 v11:71.8549 v12:72.0157 v13:195.6981] eff_attn_scale:[v0:0.4313 v1:0.8895 v2:0.9993 v3:0.9230 v4:0.9478 v5:0.8311 v6:0.9427 v7:0.9321 v8:0.9433 v9:0.7503 v10:0.7589 v11:0.7037 v12:0.7086 v13:1.9990] eff_attn_bias:[v0:0.4778 v1:0.6436 v2:0.5939 v3:0.5828 v4:0.6629 v5:0.6132 v6:0.6574 v7:0.5441 v8:0.5800 v9:0.5027 v10:0.4447 v11:0.3384 v12:0.3950 v13:0.5911] 
eff_mlp_bias:[v0:1.5578 v1:0.7955 v2:0.7292 v3:0.6325 v4:0.6574 v5:0.7182 v6:0.5497 v7:0.4613 v8:0.5552 v9:0.5414 v10:0.4668 v11:0.4309 v12:0.4392 v13:0.4889] depth_encoding:sinusoidal +step:3800/20000 val_loss:2.2287 val_bpb:1.3200 train_time:223939ms step_avg:58.93ms +step:4000/20000 train_loss:2.2149 train_time:235662ms step_avg:58.92ms +step:4000 shared0_alpha:mean=0.485,std=0.061 shared1_alpha:mean=0.579,std=0.061 shared2_alpha:mean=0.570,std=0.063 shared3_alpha:mean=0.603,std=0.067 eff_mlp_scale:[v0:129.9269 v1:84.2808 v2:85.7783 v3:86.0343 v4:76.3379 v5:99.1539 v6:90.4149 v7:85.1141 v8:78.5964 v9:88.2469 v10:73.7229 v11:73.6122 v12:73.6277 v13:200.7250] eff_attn_scale:[v0:0.4180 v1:0.8840 v2:0.9935 v3:0.9250 v4:0.9492 v5:0.8129 v6:0.9279 v7:0.9114 v8:0.9402 v9:0.7329 v10:0.7404 v11:0.6847 v12:0.7051 v13:1.9278] eff_attn_bias:[v0:0.4834 v1:0.6629 v2:0.6132 v3:0.5994 v4:0.6822 v5:0.6270 v6:0.6767 v7:0.5552 v8:0.5939 v9:0.5138 v10:0.4558 v11:0.3480 v12:0.4060 v13:0.6104] eff_mlp_bias:[v0:1.6020 v1:0.8176 v2:0.7568 v3:0.6519 v4:0.6795 v5:0.7403 v6:0.5690 v7:0.4751 v8:0.5690 v9:0.5580 v10:0.4834 v11:0.4447 v12:0.4502 v13:0.5055] depth_encoding:sinusoidal +step:4000/20000 val_loss:2.2243 val_bpb:1.3174 train_time:235693ms step_avg:58.92ms +step:4200/20000 train_loss:2.2312 train_time:247477ms step_avg:58.92ms +step:4200 shared0_alpha:mean=0.484,std=0.061 shared1_alpha:mean=0.581,std=0.061 shared2_alpha:mean=0.570,std=0.063 shared3_alpha:mean=0.603,std=0.067 eff_mlp_scale:[v0:133.8872 v1:86.5654 v2:88.4886 v3:88.3419 v4:78.0057 v5:101.0763 v6:91.7659 v7:86.9471 v8:79.8304 v9:89.5676 v10:74.9110 v11:74.8582 v12:75.2687 v13:204.7848] eff_attn_scale:[v0:0.4050 v1:0.8744 v2:0.9886 v3:0.8961 v4:0.9314 v5:0.8041 v6:0.9054 v7:0.9006 v8:0.9225 v9:0.7207 v10:0.7299 v11:0.6765 v12:0.6830 v13:1.9229] eff_attn_bias:[v0:0.4834 v1:0.6822 v2:0.6325 v3:0.6104 v4:0.7043 v5:0.6381 v6:0.6905 v7:0.5718 v8:0.6104 v9:0.5248 v10:0.4668 v11:0.3536 v12:0.4143 v13:0.6298] 
eff_mlp_bias:[v0:1.6352 v1:0.8397 v2:0.7844 v3:0.6740 v4:0.6961 v5:0.7623 v6:0.5883 v7:0.4917 v8:0.5856 v9:0.5773 v10:0.4999 v11:0.4585 v12:0.4613 v13:0.5220] depth_encoding:sinusoidal +step:4200/20000 val_loss:2.2204 val_bpb:1.3150 train_time:247508ms step_avg:58.93ms +step:4400/20000 train_loss:2.1674 train_time:259233ms step_avg:58.92ms +step:4400 shared0_alpha:mean=0.484,std=0.062 shared1_alpha:mean=0.583,std=0.062 shared2_alpha:mean=0.571,std=0.064 shared3_alpha:mean=0.602,std=0.067 eff_mlp_scale:[v0:138.3447 v1:88.9576 v2:91.3003 v3:90.6148 v4:80.3670 v5:103.1100 v6:93.6656 v7:88.2673 v8:81.7527 v9:91.4848 v10:76.6355 v11:76.5296 v12:77.1339 v13:209.9459] eff_attn_scale:[v0:0.3903 v1:0.8642 v2:0.9837 v3:0.8862 v4:0.9046 v5:0.7813 v6:0.9005 v7:0.8952 v8:0.9002 v9:0.7027 v10:0.7112 v11:0.6714 v12:0.6763 v13:1.8581] eff_attn_bias:[v0:0.4861 v1:0.7016 v2:0.6463 v3:0.6270 v4:0.7237 v5:0.6491 v6:0.7071 v7:0.5856 v8:0.6325 v9:0.5359 v10:0.4778 v11:0.3618 v12:0.4254 v13:0.6436] eff_mlp_bias:[v0:1.6794 v1:0.8618 v2:0.8065 v3:0.6961 v4:0.7126 v5:0.7844 v6:0.6104 v7:0.5055 v8:0.6049 v9:0.5911 v10:0.5138 v11:0.4723 v12:0.4723 v13:0.5331] depth_encoding:sinusoidal +step:4400/20000 val_loss:2.2216 val_bpb:1.3158 train_time:259263ms step_avg:58.92ms +step:4600/20000 train_loss:2.0304 train_time:271001ms step_avg:58.91ms +step:4600 shared0_alpha:mean=0.483,std=0.062 shared1_alpha:mean=0.586,std=0.062 shared2_alpha:mean=0.571,std=0.065 shared3_alpha:mean=0.601,std=0.068 eff_mlp_scale:[v0:142.5881 v1:91.9992 v2:93.7048 v3:92.9953 v4:82.0128 v5:105.7990 v6:95.6171 v7:90.6230 v8:83.4107 v9:93.5325 v10:77.9279 v11:78.2868 v12:78.7509 v13:214.2951] eff_attn_scale:[v0:0.3849 v1:0.8685 v2:0.9893 v3:0.8841 v4:0.9260 v5:0.7903 v6:0.8886 v7:0.8841 v8:0.9171 v9:0.7078 v10:0.7099 v11:0.6631 v12:0.6823 v13:1.8677] eff_attn_bias:[v0:0.4861 v1:0.7182 v2:0.6574 v3:0.6408 v4:0.7403 v5:0.6602 v6:0.7237 v7:0.5966 v8:0.6463 v9:0.5497 v10:0.4889 v11:0.3701 v12:0.4364 v13:0.6602] 
eff_mlp_bias:[v0:1.7125 v1:0.8894 v2:0.8286 v3:0.7126 v4:0.7292 v5:0.8065 v6:0.6325 v7:0.5193 v8:0.6215 v9:0.6077 v10:0.5303 v11:0.4861 v12:0.4861 v13:0.5524] depth_encoding:sinusoidal +step:4600/20000 val_loss:2.2170 val_bpb:1.3130 train_time:271032ms step_avg:58.92ms +step:4800/20000 train_loss:2.3192 train_time:282759ms step_avg:58.91ms +step:4800 shared0_alpha:mean=0.481,std=0.063 shared1_alpha:mean=0.588,std=0.063 shared2_alpha:mean=0.571,std=0.065 shared3_alpha:mean=0.601,std=0.069 eff_mlp_scale:[v0:146.2769 v1:93.7678 v2:96.1486 v3:94.8808 v4:83.7838 v5:107.6784 v6:98.0813 v7:92.4848 v8:84.7251 v9:95.3134 v10:79.7212 v11:80.0257 v12:80.4889 v13:218.0837] eff_attn_scale:[v0:0.3777 v1:0.8454 v2:0.9594 v3:0.8779 v4:0.9276 v5:0.7643 v6:0.8779 v7:0.8867 v8:0.9055 v9:0.6874 v10:0.6969 v11:0.6606 v12:0.6714 v13:1.8193] eff_attn_bias:[v0:0.4861 v1:0.7347 v2:0.6740 v3:0.6546 v4:0.7568 v5:0.6767 v6:0.7403 v7:0.6104 v8:0.6602 v9:0.5580 v10:0.5027 v11:0.3757 v12:0.4447 v13:0.6767] eff_mlp_bias:[v0:1.7567 v1:0.9115 v2:0.8563 v3:0.7292 v4:0.7458 v5:0.8231 v6:0.6491 v7:0.5331 v8:0.6353 v9:0.6270 v10:0.5497 v11:0.5027 v12:0.4999 v13:0.5635] depth_encoding:sinusoidal +step:4800/20000 val_loss:2.2128 val_bpb:1.3106 train_time:282790ms step_avg:58.91ms +step:5000/20000 train_loss:2.0898 train_time:294519ms step_avg:58.90ms +step:5000 shared0_alpha:mean=0.481,std=0.063 shared1_alpha:mean=0.590,std=0.063 shared2_alpha:mean=0.571,std=0.065 shared3_alpha:mean=0.600,std=0.069 eff_mlp_scale:[v0:151.4040 v1:95.6262 v2:98.5672 v3:96.7358 v4:85.5539 v5:109.6583 v6:100.0311 v7:94.3174 v8:86.5045 v9:96.6656 v10:81.0008 v11:81.2580 v12:81.7515 v13:221.9333] eff_attn_scale:[v0:0.3686 v1:0.8479 v2:0.9673 v3:0.8527 v4:0.9195 v5:0.7623 v6:0.8810 v7:0.8658 v8:0.9020 v9:0.6852 v10:0.6948 v11:0.6472 v12:0.6655 v13:1.7874] eff_attn_bias:[v0:0.4889 v1:0.7513 v2:0.6933 v3:0.6684 v4:0.7734 v5:0.6850 v6:0.7568 v7:0.6215 v8:0.6795 v9:0.5718 v10:0.5110 v11:0.3839 v12:0.4558 v13:0.6933] 
eff_mlp_bias:[v0:1.7899 v1:0.9336 v2:0.8784 v3:0.7458 v4:0.7623 v5:0.8452 v6:0.6657 v7:0.5469 v8:0.6546 v9:0.6408 v10:0.5662 v11:0.5138 v12:0.5110 v13:0.5745] depth_encoding:sinusoidal +step:5000/20000 val_loss:2.2072 val_bpb:1.3072 train_time:294549ms step_avg:58.91ms +step:5200/20000 train_loss:2.2253 train_time:306278ms step_avg:58.90ms +step:5200 shared0_alpha:mean=0.481,std=0.064 shared1_alpha:mean=0.593,std=0.063 shared2_alpha:mean=0.572,std=0.065 shared3_alpha:mean=0.599,std=0.069 eff_mlp_scale:[v0:155.1819 v1:98.6015 v2:101.2260 v3:99.5478 v4:87.4473 v5:111.7134 v6:102.2135 v7:96.1320 v8:88.4082 v9:98.6015 v10:82.9559 v11:82.9565 v12:84.0839 v13:226.0120] eff_attn_scale:[v0:0.3579 v1:0.8480 v2:0.9673 v3:0.8740 v4:0.9194 v5:0.7576 v6:0.8679 v7:0.8566 v8:0.8889 v9:0.6801 v10:0.6916 v11:0.6392 v12:0.6667 v13:1.7798] eff_attn_bias:[v0:0.4917 v1:0.7734 v2:0.7126 v3:0.6850 v4:0.7900 v5:0.6988 v6:0.7734 v7:0.6353 v8:0.6933 v9:0.5800 v10:0.5220 v11:0.3922 v12:0.4613 v13:0.7071] eff_mlp_bias:[v0:1.8230 v1:0.9557 v2:0.9005 v3:0.7679 v4:0.7844 v5:0.8618 v6:0.6878 v7:0.5635 v8:0.6684 v9:0.6574 v10:0.5828 v11:0.5276 v12:0.5220 v13:0.5883] depth_encoding:sinusoidal +step:5200/20000 val_loss:2.2086 val_bpb:1.3081 train_time:306309ms step_avg:58.91ms +step:5400/20000 train_loss:2.2355 train_time:318037ms step_avg:58.90ms +step:5400 shared0_alpha:mean=0.480,std=0.064 shared1_alpha:mean=0.595,std=0.063 shared2_alpha:mean=0.572,std=0.066 shared3_alpha:mean=0.598,std=0.069 eff_mlp_scale:[v0:159.4047 v1:100.5377 v2:103.8090 v3:102.0535 v4:89.2185 v5:113.7664 v6:104.3080 v7:98.1094 v8:90.1882 v9:100.0086 v10:84.3448 v11:84.3051 v12:85.8243 v13:228.3482] eff_attn_scale:[v0:0.3527 v1:0.8505 v2:0.9617 v3:0.8711 v4:0.9155 v5:0.7560 v6:0.8623 v7:0.8492 v8:0.8894 v9:0.6701 v10:0.6908 v11:0.6434 v12:0.6540 v13:1.7649] eff_attn_bias:[v0:0.4999 v1:0.7900 v2:0.7237 v3:0.7016 v4:0.8065 v5:0.7126 v6:0.7844 v7:0.6463 v8:0.7071 v9:0.5911 v10:0.5331 v11:0.4005 v12:0.4723 v13:0.7237] 
eff_mlp_bias:[v0:1.8562 v1:0.9778 v2:0.9226 v3:0.7844 v4:0.8010 v5:0.8784 v6:0.7043 v7:0.5745 v8:0.6850 v9:0.6740 v10:0.5994 v11:0.5414 v12:0.5331 v13:0.5994] depth_encoding:sinusoidal +step:5400/20000 val_loss:2.2039 val_bpb:1.3053 train_time:318068ms step_avg:58.90ms +step:5600/20000 train_loss:2.2378 train_time:329787ms step_avg:58.89ms +step:5600 shared0_alpha:mean=0.479,std=0.064 shared1_alpha:mean=0.597,std=0.064 shared2_alpha:mean=0.571,std=0.066 shared3_alpha:mean=0.598,std=0.070 eff_mlp_scale:[v0:163.0470 v1:102.9921 v2:105.7195 v3:104.0865 v4:90.8663 v5:115.7994 v6:106.2229 v7:100.1023 v8:91.8434 v9:101.3912 v10:86.0859 v11:86.1577 v12:86.9581 v13:232.2047] eff_attn_scale:[v0:0.3429 v1:0.8505 v2:0.9585 v3:0.8656 v4:0.9370 v5:0.7603 v6:0.8636 v7:0.8612 v8:0.8801 v9:0.6701 v10:0.6782 v11:0.6383 v12:0.6568 v13:1.7430] eff_attn_bias:[v0:0.4972 v1:0.8121 v2:0.7458 v3:0.7126 v4:0.8286 v5:0.7237 v6:0.8010 v7:0.6602 v8:0.7237 v9:0.6021 v10:0.5441 v11:0.4060 v12:0.4834 v13:0.7403] eff_mlp_bias:[v0:1.8893 v1:1.0054 v2:0.9502 v3:0.8010 v4:0.8231 v5:0.9005 v6:0.7182 v7:0.5856 v8:0.6988 v9:0.6878 v10:0.6132 v11:0.5524 v12:0.5441 v13:0.6132] depth_encoding:sinusoidal +step:5600/20000 val_loss:2.2037 val_bpb:1.3052 train_time:329817ms step_avg:58.90ms +step:5800/20000 train_loss:2.2066 train_time:341545ms step_avg:58.89ms +step:5800 shared0_alpha:mean=0.479,std=0.064 shared1_alpha:mean=0.599,std=0.064 shared2_alpha:mean=0.571,std=0.067 shared3_alpha:mean=0.597,std=0.070 eff_mlp_scale:[v0:168.1565 v1:105.4822 v2:108.8509 v3:106.4394 v4:92.7424 v5:117.8602 v6:108.3422 v7:101.4187 v8:93.2357 v9:103.3295 v10:87.4876 v11:87.3607 v12:89.2892 v13:236.3641] eff_attn_scale:[v0:0.3342 v1:0.8420 v2:0.9565 v3:0.8563 v4:0.9415 v5:0.7442 v6:0.8577 v7:0.8520 v8:0.8933 v9:0.6676 v10:0.6781 v11:0.6358 v12:0.6568 v13:1.6983] eff_attn_bias:[v0:0.4972 v1:0.8286 v2:0.7623 v3:0.7292 v4:0.8452 v5:0.7347 v6:0.8121 v7:0.6684 v8:0.7347 v9:0.6160 v10:0.5524 v11:0.4143 v12:0.4917 v13:0.7568] 
eff_mlp_bias:[v0:1.9224 v1:1.0220 v2:0.9723 v3:0.8176 v4:0.8397 v5:0.9226 v6:0.7347 v7:0.5994 v8:0.7182 v9:0.7043 v10:0.6270 v11:0.5662 v12:0.5552 v13:0.6242] depth_encoding:sinusoidal +step:5800/20000 val_loss:2.2023 val_bpb:1.3043 train_time:341575ms step_avg:58.89ms +step:6000/20000 train_loss:2.2704 train_time:353300ms step_avg:58.88ms +step:6000 shared0_alpha:mean=0.478,std=0.064 shared1_alpha:mean=0.602,std=0.065 shared2_alpha:mean=0.571,std=0.067 shared3_alpha:mean=0.596,std=0.070 eff_mlp_scale:[v0:171.7213 v1:108.0624 v2:111.3780 v3:108.5207 v4:94.4264 v5:120.0090 v6:110.3515 v7:103.4496 v8:94.9234 v9:104.8042 v10:88.7944 v11:89.2507 v12:90.4505 v13:239.8397] eff_attn_scale:[v0:0.3340 v1:0.8365 v2:0.9637 v3:0.8631 v4:0.9384 v5:0.7445 v6:0.8561 v7:0.8459 v8:0.8862 v9:0.6525 v10:0.6723 v11:0.6226 v12:0.6473 v13:1.7154] eff_attn_bias:[v0:0.4999 v1:0.8507 v2:0.7789 v3:0.7403 v4:0.8673 v5:0.7513 v6:0.8286 v7:0.6822 v8:0.7458 v9:0.6270 v10:0.5635 v11:0.4254 v12:0.4999 v13:0.7734] eff_mlp_bias:[v0:1.9445 v1:1.0441 v2:0.9944 v3:0.8342 v4:0.8563 v5:0.9447 v6:0.7568 v7:0.6160 v8:0.7347 v9:0.7237 v10:0.6436 v11:0.5800 v12:0.5662 v13:0.6353] depth_encoding:sinusoidal +step:6000/20000 val_loss:2.1985 val_bpb:1.3021 train_time:353330ms step_avg:58.89ms +step:6200/20000 train_loss:2.1470 train_time:365057ms step_avg:58.88ms +step:6200 shared0_alpha:mean=0.478,std=0.065 shared1_alpha:mean=0.604,std=0.065 shared2_alpha:mean=0.571,std=0.067 shared3_alpha:mean=0.595,std=0.070 eff_mlp_scale:[v0:175.7898 v1:109.9845 v2:113.7156 v3:111.2663 v4:96.1211 v5:122.0226 v6:112.6771 v7:105.6261 v8:96.6217 v9:106.7014 v10:90.8686 v11:91.2692 v12:92.1160 v13:244.0034] eff_attn_scale:[v0:0.3192 v1:0.8359 v2:0.9590 v3:0.8515 v4:0.9179 v5:0.7309 v6:0.8559 v7:0.8515 v8:0.8746 v9:0.6427 v10:0.6767 v11:0.6224 v12:0.6494 v13:1.6625] eff_attn_bias:[v0:0.5027 v1:0.8728 v2:0.8010 v3:0.7568 v4:0.8784 v5:0.7623 v6:0.8397 v7:0.6933 v8:0.7568 v9:0.6353 v10:0.5745 v11:0.4309 v12:0.5110 v13:0.7844] 
eff_mlp_bias:[v0:1.9777 v1:1.0717 v2:1.0165 v3:0.8507 v4:0.8728 v5:0.9557 v6:0.7679 v7:0.6270 v8:0.7458 v9:0.7347 v10:0.6574 v11:0.5911 v12:0.5773 v13:0.6463] depth_encoding:sinusoidal +step:6200/20000 val_loss:2.1973 val_bpb:1.3014 train_time:365088ms step_avg:58.89ms +step:6400/20000 train_loss:2.2205 train_time:376823ms step_avg:58.88ms +step:6400 shared0_alpha:mean=0.477,std=0.065 shared1_alpha:mean=0.606,std=0.065 shared2_alpha:mean=0.571,std=0.068 shared3_alpha:mean=0.595,std=0.071 eff_mlp_scale:[v0:180.9867 v1:112.6642 v2:116.4174 v3:113.8787 v4:98.5493 v5:124.2620 v6:114.8442 v7:107.6671 v8:98.0439 v9:108.2460 v10:92.8193 v11:92.6558 v12:94.0008 v13:247.6979] eff_attn_scale:[v0:0.3199 v1:0.8530 v2:0.9636 v3:0.8780 v4:0.9439 v5:0.7384 v6:0.8516 v7:0.8476 v8:0.8827 v9:0.6451 v10:0.6678 v11:0.6259 v12:0.6511 v13:1.6925] eff_attn_bias:[v0:0.5027 v1:0.8894 v2:0.8121 v3:0.7679 v4:0.8949 v5:0.7734 v6:0.8563 v7:0.7043 v8:0.7734 v9:0.6463 v10:0.5828 v11:0.4392 v12:0.5193 v13:0.8010] eff_mlp_bias:[v0:2.0108 v1:1.0883 v2:1.0330 v3:0.8673 v4:0.8839 v5:0.9778 v6:0.7844 v7:0.6381 v8:0.7679 v9:0.7513 v10:0.6712 v11:0.6049 v12:0.5883 v13:0.6546] depth_encoding:sinusoidal +step:6400/20000 val_loss:2.1936 val_bpb:1.2992 train_time:376848ms step_avg:58.88ms +step:6600/20000 train_loss:2.1762 train_time:388580ms step_avg:58.88ms +step:6600 shared0_alpha:mean=0.476,std=0.066 shared1_alpha:mean=0.609,std=0.066 shared2_alpha:mean=0.571,std=0.068 shared3_alpha:mean=0.593,std=0.071 eff_mlp_scale:[v0:184.6921 v1:115.1496 v2:118.9919 v3:115.8578 v4:100.3595 v5:125.7189 v6:116.8765 v7:109.5952 v8:99.8501 v9:109.5868 v10:94.1358 v11:94.4606 v12:95.2651 v13:251.7797] eff_attn_scale:[v0:0.3123 v1:0.8518 v2:0.9686 v3:0.8629 v4:0.9508 v5:0.7411 v6:0.8559 v7:0.8455 v8:0.8810 v9:0.6474 v10:0.6757 v11:0.6287 v12:0.6542 v13:1.6763] eff_attn_bias:[v0:0.5027 v1:0.9115 v2:0.8286 v3:0.7844 v4:0.9115 v5:0.7900 v6:0.8728 v7:0.7126 v8:0.7900 v9:0.6602 v10:0.5966 v11:0.4447 v12:0.5331 v13:0.8121] 
eff_mlp_bias:[v0:2.0329 v1:1.1049 v2:1.0551 v3:0.8894 v4:0.9060 v5:0.9944 v6:0.8010 v7:0.6491 v8:0.7789 v9:0.7679 v10:0.6850 v11:0.6187 v12:0.5994 v13:0.6684] depth_encoding:sinusoidal +step:6600/20000 val_loss:2.1906 val_bpb:1.2974 train_time:388605ms step_avg:58.88ms +step:6800/20000 train_loss:2.2449 train_time:400393ms step_avg:58.88ms +step:6800 shared0_alpha:mean=0.475,std=0.066 shared1_alpha:mean=0.611,std=0.066 shared2_alpha:mean=0.570,std=0.068 shared3_alpha:mean=0.592,std=0.071 eff_mlp_scale:[v0:188.5722 v1:117.1606 v2:121.4109 v3:118.2960 v4:102.3443 v5:127.8116 v6:118.7483 v7:111.4611 v8:101.3158 v9:111.5549 v10:95.8507 v11:95.6883 v12:97.2014 v13:253.8321] eff_attn_scale:[v0:0.3050 v1:0.8442 v2:0.9597 v3:0.8727 v4:0.9510 v5:0.7345 v6:0.8398 v7:0.8383 v8:0.8772 v9:0.6374 v10:0.6576 v11:0.6147 v12:0.6557 v13:1.6447] eff_attn_bias:[v0:0.5082 v1:0.9281 v2:0.8452 v3:0.8010 v4:0.9336 v5:0.8065 v6:0.8839 v7:0.7237 v8:0.8010 v9:0.6712 v10:0.6049 v11:0.4502 v12:0.5414 v13:0.8286] eff_mlp_bias:[v0:2.0661 v1:1.1270 v2:1.0772 v3:0.9005 v4:0.9226 v5:1.0054 v6:0.8176 v7:0.6629 v8:0.7955 v9:0.7844 v10:0.7016 v11:0.6298 v12:0.6104 v13:0.6795] depth_encoding:sinusoidal +step:6800/20000 val_loss:2.1889 val_bpb:1.2964 train_time:400432ms step_avg:58.89ms +step:7000/20000 train_loss:2.2807 train_time:412163ms step_avg:58.88ms +step:7000 shared0_alpha:mean=0.474,std=0.066 shared1_alpha:mean=0.613,std=0.067 shared2_alpha:mean=0.571,std=0.068 shared3_alpha:mean=0.591,std=0.071 eff_mlp_scale:[v0:193.8277 v1:119.3186 v2:124.8006 v3:120.4276 v4:103.6154 v5:130.0629 v6:121.0351 v7:113.5309 v8:103.0973 v9:113.0982 v10:97.3660 v11:97.6153 v12:98.9527 v13:257.8575] eff_attn_scale:[v0:0.3045 v1:0.8598 v2:0.9760 v3:0.8848 v4:0.9615 v5:0.7364 v6:0.8506 v7:0.8456 v8:0.8828 v9:0.6470 v10:0.6581 v11:0.6277 v12:0.6555 v13:1.6466] eff_attn_bias:[v0:0.5082 v1:0.9502 v2:0.8618 v3:0.8176 v4:0.9447 v5:0.8176 v6:0.9005 v7:0.7403 v8:0.8121 v9:0.6795 v10:0.6160 v11:0.4585 v12:0.5497 v13:0.8397] 
eff_mlp_bias:[v0:2.0992 v1:1.1490 v2:1.0938 v3:0.9170 v4:0.9336 v5:1.0275 v6:0.8342 v7:0.6767 v8:0.8065 v9:0.7955 v10:0.7126 v11:0.6408 v12:0.6187 v13:0.6905] depth_encoding:sinusoidal +step:7000/20000 val_loss:2.1884 val_bpb:1.2961 train_time:412186ms step_avg:58.88ms +step:7200/20000 train_loss:2.2563 train_time:423926ms step_avg:58.88ms +step:7200 shared0_alpha:mean=0.474,std=0.067 shared1_alpha:mean=0.616,std=0.067 shared2_alpha:mean=0.571,std=0.069 shared3_alpha:mean=0.590,std=0.071 eff_mlp_scale:[v0:197.6152 v1:121.8241 v2:126.7557 v3:122.5043 v4:106.0428 v5:131.5017 v6:122.9639 v7:115.5499 v8:104.4757 v9:114.4235 v10:99.1295 v11:98.9664 v12:100.2967 v13:261.8293] eff_attn_scale:[v0:0.2982 v1:0.8658 v2:0.9792 v3:0.8818 v4:0.9702 v5:0.7427 v6:0.8451 v7:0.8556 v8:0.8824 v9:0.6451 v10:0.6662 v11:0.6286 v12:0.6629 v13:1.6167] eff_attn_bias:[v0:0.5055 v1:0.9667 v2:0.8784 v3:0.8286 v4:0.9667 v5:0.8286 v6:0.9115 v7:0.7513 v8:0.8231 v9:0.6905 v10:0.6242 v11:0.4640 v12:0.5607 v13:0.8563] eff_mlp_bias:[v0:2.1213 v1:1.1656 v2:1.1159 v3:0.9336 v4:0.9502 v5:1.0386 v6:0.8452 v7:0.6878 v8:0.8231 v9:0.8065 v10:0.7237 v11:0.6519 v12:0.6298 v13:0.7016] depth_encoding:sinusoidal +step:7200/20000 val_loss:2.1875 val_bpb:1.2955 train_time:423951ms step_avg:58.88ms +step:7400/20000 train_loss:2.1726 train_time:435685ms step_avg:58.88ms +step:7400 shared0_alpha:mean=0.473,std=0.067 shared1_alpha:mean=0.618,std=0.067 shared2_alpha:mean=0.570,std=0.069 shared3_alpha:mean=0.590,std=0.071 eff_mlp_scale:[v0:203.0925 v1:123.9202 v2:129.5425 v3:125.1102 v4:107.3104 v5:133.6732 v6:125.1697 v7:117.5604 v8:105.7323 v9:115.8883 v10:100.5730 v11:100.8431 v12:102.0501 v13:263.7550] eff_attn_scale:[v0:0.2876 v1:0.8517 v2:0.9711 v3:0.8805 v4:0.9727 v5:0.7306 v6:0.8337 v7:0.8458 v8:0.8768 v9:0.6262 v10:0.6563 v11:0.6246 v12:0.6499 v13:1.6352] eff_attn_bias:[v0:0.5110 v1:0.9888 v2:0.8894 v3:0.8452 v4:0.9833 v5:0.8452 v6:0.9281 v7:0.7623 v8:0.8397 v9:0.7043 v10:0.6353 v11:0.4723 v12:0.5690 
v13:0.8673] eff_mlp_bias:[v0:2.1434 v1:1.1877 v2:1.1270 v3:0.9502 v4:0.9667 v5:1.0551 v6:0.8618 v7:0.6988 v8:0.8397 v9:0.8231 v10:0.7403 v11:0.6629 v12:0.6381 v13:0.7126] depth_encoding:sinusoidal +step:7400/20000 val_loss:2.1845 val_bpb:1.2938 train_time:435709ms step_avg:58.88ms +step:7600/20000 train_loss:2.0546 train_time:447444ms step_avg:58.87ms +step:7600 shared0_alpha:mean=0.472,std=0.067 shared1_alpha:mean=0.621,std=0.068 shared2_alpha:mean=0.570,std=0.069 shared3_alpha:mean=0.589,std=0.071 eff_mlp_scale:[v0:207.0811 v1:126.0600 v2:131.8895 v3:127.7551 v4:109.1114 v5:135.8904 v6:126.9230 v7:119.0569 v8:107.5224 v9:117.3862 v10:102.0902 v11:102.2041 v12:103.8148 v13:267.7385] eff_attn_scale:[v0:0.2909 v1:0.8606 v2:0.9846 v3:0.8806 v4:0.9690 v5:0.7269 v6:0.8452 v7:0.8416 v8:0.8817 v9:0.6267 v10:0.6609 v11:0.6203 v12:0.6547 v13:1.6012] eff_attn_bias:[v0:0.5110 v1:1.0054 v2:0.9115 v3:0.8507 v4:0.9944 v5:0.8563 v6:0.9391 v7:0.7734 v8:0.8507 v9:0.7182 v10:0.6463 v11:0.4778 v12:0.5800 v13:0.8839] eff_mlp_bias:[v0:2.1766 v1:1.2043 v2:1.1490 v3:0.9667 v4:0.9778 v5:1.0717 v6:0.8784 v7:0.7071 v8:0.8507 v9:0.8342 v10:0.7513 v11:0.6767 v12:0.6491 v13:0.7182] depth_encoding:sinusoidal +step:7600/20000 val_loss:2.1839 val_bpb:1.2934 train_time:447469ms step_avg:58.88ms +step:7800/20000 train_loss:2.2072 train_time:459195ms step_avg:58.87ms +step:7800 shared0_alpha:mean=0.472,std=0.067 shared1_alpha:mean=0.623,std=0.068 shared2_alpha:mean=0.569,std=0.069 shared3_alpha:mean=0.588,std=0.071 eff_mlp_scale:[v0:210.8913 v1:129.2975 v2:134.7026 v3:130.3263 v4:110.8714 v5:138.0338 v6:129.1364 v7:120.4697 v8:109.2722 v9:119.3963 v10:104.0884 v11:104.0420 v12:105.0080 v13:271.7089] eff_attn_scale:[v0:0.2896 v1:0.8537 v2:0.9893 v3:0.8843 v4:0.9900 v5:0.7335 v6:0.8461 v7:0.8539 v8:0.8976 v9:0.6216 v10:0.6581 v11:0.6198 v12:0.6600 v13:1.6186] eff_attn_bias:[v0:0.5138 v1:1.0275 v2:0.9226 v3:0.8673 v4:1.0109 v5:0.8728 v6:0.9502 v7:0.7844 v8:0.8618 v9:0.7292 v10:0.6546 v11:0.4861 
v12:0.5883 v13:0.8949] eff_mlp_bias:[v0:2.1987 v1:1.2264 v2:1.1656 v3:0.9778 v4:0.9944 v5:1.0883 v6:0.8894 v7:0.7182 v8:0.8673 v9:0.8452 v10:0.7679 v11:0.6822 v12:0.6574 v13:0.7292] depth_encoding:sinusoidal +step:7800/20000 val_loss:2.1800 val_bpb:1.2911 train_time:459220ms step_avg:58.87ms +step:8000/20000 train_loss:2.1668 train_time:470948ms step_avg:58.87ms +step:8000 shared0_alpha:mean=0.471,std=0.067 shared1_alpha:mean=0.625,std=0.068 shared2_alpha:mean=0.569,std=0.070 shared3_alpha:mean=0.587,std=0.072 eff_mlp_scale:[v0:216.5491 v1:131.1973 v2:137.8954 v3:132.7072 v4:113.2817 v5:139.9829 v6:131.7042 v7:123.3071 v8:110.5973 v9:120.6547 v10:105.8136 v11:105.6128 v12:107.3760 v13:275.4622] eff_attn_scale:[v0:0.2796 v1:0.8794 v2:0.9894 v3:0.8871 v4:0.9918 v5:0.7448 v6:0.8417 v7:0.8481 v8:0.8949 v9:0.6354 v10:0.6581 v11:0.6231 v12:0.6656 v13:1.6173] eff_attn_bias:[v0:0.5110 v1:1.0441 v2:0.9447 v3:0.8784 v4:1.0275 v5:0.8839 v6:0.9612 v7:0.7955 v8:0.8784 v9:0.7403 v10:0.6657 v11:0.4944 v12:0.5994 v13:0.9005] eff_mlp_bias:[v0:2.2208 v1:1.2540 v2:1.1877 v3:0.9944 v4:1.0109 v5:1.0993 v6:0.9005 v7:0.7292 v8:0.8784 v9:0.8563 v10:0.7789 v11:0.6961 v12:0.6740 v13:0.7403] depth_encoding:sinusoidal +step:8000/20000 val_loss:2.1791 val_bpb:1.2906 train_time:470973ms step_avg:58.87ms +step:8200/20000 train_loss:2.2366 train_time:482692ms step_avg:58.86ms +step:8200 shared0_alpha:mean=0.470,std=0.068 shared1_alpha:mean=0.627,std=0.069 shared2_alpha:mean=0.569,std=0.069 shared3_alpha:mean=0.586,std=0.072 eff_mlp_scale:[v0:220.5306 v1:133.9416 v2:140.5703 v3:135.3293 v4:115.5924 v5:141.6122 v6:133.7686 v7:125.3049 v8:112.3515 v9:122.1406 v10:107.1282 v11:107.4837 v12:108.5705 v13:279.6792] eff_attn_scale:[v0:0.2763 v1:0.8807 v2:0.9846 v3:0.8918 v4:0.9995 v5:0.7495 v6:0.8503 v7:0.8570 v8:0.9070 v9:0.6309 v10:0.6668 v11:0.6308 v12:0.6693 v13:1.6274] eff_attn_bias:[v0:0.5110 v1:1.0662 v2:0.9612 v3:0.8894 v4:1.0386 v5:0.9005 v6:0.9778 v7:0.8065 v8:0.8949 v9:0.7513 v10:0.6767 
v11:0.4999 v12:0.6077 v13:0.9170] eff_mlp_bias:[v0:2.2429 v1:1.2761 v2:1.2043 v3:1.0109 v4:1.0275 v5:1.1159 v6:0.9170 v7:0.7403 v8:0.8949 v9:0.8728 v10:0.7900 v11:0.7071 v12:0.6822 v13:0.7458] depth_encoding:sinusoidal +step:8200/20000 val_loss:2.1763 val_bpb:1.2889 train_time:482722ms step_avg:58.87ms +step:8400/20000 train_loss:2.1837 train_time:494511ms step_avg:58.87ms +step:8400 shared0_alpha:mean=0.470,std=0.068 shared1_alpha:mean=0.630,std=0.069 shared2_alpha:mean=0.569,std=0.070 shared3_alpha:mean=0.585,std=0.072 eff_mlp_scale:[v0:226.2229 v1:136.6276 v2:144.1857 v3:138.1543 v4:116.8718 v5:144.3500 v6:136.1753 v7:126.9222 v8:113.6103 v9:124.1529 v10:108.7114 v11:108.9510 v12:110.3487 v13:281.5219] eff_attn_scale:[v0:0.2767 v1:0.8880 v2:1.0091 v3:0.9013 v4:1.0253 v5:0.7442 v6:0.8560 v7:0.8532 v8:0.9144 v9:0.6300 v10:0.6667 v11:0.6344 v12:0.6747 v13:1.6118] eff_attn_bias:[v0:0.5138 v1:1.0883 v2:0.9778 v3:0.9005 v4:1.0607 v5:0.9115 v6:0.9944 v7:0.8231 v8:0.9060 v9:0.7623 v10:0.6878 v11:0.5082 v12:0.6160 v13:0.9226] eff_mlp_bias:[v0:2.2650 v1:1.2982 v2:1.2209 v3:1.0275 v4:1.0441 v5:1.1270 v6:0.9336 v7:0.7513 v8:0.9060 v9:0.8839 v10:0.8065 v11:0.7182 v12:0.6905 v13:0.7568] depth_encoding:sinusoidal +step:8400/20000 val_loss:2.1760 val_bpb:1.2887 train_time:494542ms step_avg:58.87ms +step:8600/20000 train_loss:2.1928 train_time:506266ms step_avg:58.87ms +step:8600 shared0_alpha:mean=0.469,std=0.068 shared1_alpha:mean=0.633,std=0.069 shared2_alpha:mean=0.568,std=0.070 shared3_alpha:mean=0.584,std=0.072 eff_mlp_scale:[v0:230.2102 v1:138.7728 v2:147.1916 v3:140.4101 v4:119.1224 v5:146.5489 v6:138.5333 v7:129.6528 v8:115.8287 v9:125.6133 v10:110.8266 v11:110.4031 v12:112.5350 v13:285.2443] eff_attn_scale:[v0:0.2729 v1:0.8861 v2:1.0086 v3:0.9193 v4:1.0285 v5:0.7559 v6:0.8555 v7:0.8580 v8:0.9172 v9:0.6257 v10:0.6664 v11:0.6304 v12:0.6723 v13:1.5990] eff_attn_bias:[v0:0.5110 v1:1.1049 v2:0.9944 v3:0.9170 v4:1.0717 v5:0.9226 v6:1.0109 v7:0.8286 v8:0.9170 v9:0.7679 
v10:0.6988 v11:0.5138 v12:0.6242 v13:0.9391] eff_mlp_bias:[v0:2.2870 v1:1.3148 v2:1.2374 v3:1.0441 v4:1.0551 v5:1.1435 v6:0.9502 v7:0.7623 v8:0.9170 v9:0.8949 v10:0.8176 v11:0.7292 v12:0.7016 v13:0.7623] depth_encoding:sinusoidal +step:8600/20000 val_loss:2.1735 val_bpb:1.2873 train_time:506297ms step_avg:58.87ms +step:8800/20000 train_loss:2.1619 train_time:518030ms step_avg:58.87ms +step:8800 shared0_alpha:mean=0.468,std=0.069 shared1_alpha:mean=0.635,std=0.070 shared2_alpha:mean=0.568,std=0.070 shared3_alpha:mean=0.582,std=0.072 eff_mlp_scale:[v0:235.7040 v1:141.3207 v2:150.1160 v3:143.1708 v4:121.6550 v5:148.5371 v6:140.8064 v7:131.1924 v8:117.2312 v9:126.8880 v10:112.2961 v11:112.3691 v12:113.9134 v13:287.0195] eff_attn_scale:[v0:0.2688 v1:0.8927 v2:1.0101 v3:0.9047 v4:1.0514 v5:0.7645 v6:0.8530 v7:0.8525 v8:0.9127 v9:0.6364 v10:0.6555 v11:0.6263 v12:0.6756 v13:1.5983] eff_attn_bias:[v0:0.5138 v1:1.1270 v2:1.0165 v3:0.9281 v4:1.0883 v5:0.9391 v6:1.0165 v7:0.8397 v8:0.9336 v9:0.7844 v10:0.7071 v11:0.5193 v12:0.6325 v13:0.9447] eff_mlp_bias:[v0:2.3091 v1:1.3369 v2:1.2540 v3:1.0551 v4:1.0717 v5:1.1601 v6:0.9667 v7:0.7734 v8:0.9336 v9:0.9060 v10:0.8286 v11:0.7403 v12:0.7126 v13:0.7679] depth_encoding:sinusoidal +step:8800/20000 val_loss:2.1720 val_bpb:1.2864 train_time:518054ms step_avg:58.87ms +step:9000/20000 train_loss:2.0805 train_time:529792ms step_avg:58.87ms +step:9000 shared0_alpha:mean=0.467,std=0.069 shared1_alpha:mean=0.638,std=0.070 shared2_alpha:mean=0.568,std=0.070 shared3_alpha:mean=0.581,std=0.072 eff_mlp_scale:[v0:240.5350 v1:144.2761 v2:152.5839 v3:146.1572 v4:122.9263 v5:150.9443 v6:143.1941 v7:133.4979 v8:118.4765 v9:128.5148 v10:113.8511 v11:113.9336 v12:115.6953 v13:291.2618] eff_attn_scale:[v0:0.2667 v1:0.9120 v2:1.0282 v3:0.9273 v4:1.0704 v5:0.7737 v6:0.8606 v7:0.8831 v8:0.9310 v9:0.6440 v10:0.6704 v11:0.6403 v12:0.6881 v13:1.6274] eff_attn_bias:[v0:0.5138 v1:1.1435 v2:1.0330 v3:0.9502 v4:1.1104 v5:0.9502 v6:1.0330 v7:0.8507 v8:0.9502 
v9:0.7955 v10:0.7182 v11:0.5276 v12:0.6408 v13:0.9612] eff_mlp_bias:[v0:2.3312 v1:1.3534 v2:1.2706 v3:1.0717 v4:1.0883 v5:1.1711 v6:0.9778 v7:0.7844 v8:0.9447 v9:0.9170 v10:0.8397 v11:0.7513 v12:0.7182 v13:0.7734] depth_encoding:sinusoidal +step:9000/20000 val_loss:2.1731 val_bpb:1.2870 train_time:529816ms step_avg:58.87ms +step:9200/20000 train_loss:2.1320 train_time:541541ms step_avg:58.86ms +step:9200 shared0_alpha:mean=0.466,std=0.069 shared1_alpha:mean=0.640,std=0.071 shared2_alpha:mean=0.567,std=0.070 shared3_alpha:mean=0.580,std=0.072 eff_mlp_scale:[v0:243.9950 v1:146.3367 v2:155.0233 v3:148.2915 v4:124.9611 v5:152.4341 v6:145.5562 v7:134.9684 v8:120.4782 v9:129.8738 v10:115.3799 v11:115.8527 v12:117.1160 v13:295.3407] eff_attn_scale:[v0:0.2652 v1:0.9201 v2:1.0197 v3:0.9299 v4:1.0786 v5:0.7703 v6:0.8663 v7:0.8810 v8:0.9472 v9:0.6376 v10:0.6587 v11:0.6407 v12:0.6934 v13:1.6424] eff_attn_bias:[v0:0.5193 v1:1.1601 v2:1.0496 v3:0.9612 v4:1.1214 v5:0.9612 v6:1.0496 v7:0.8618 v8:0.9612 v9:0.8121 v10:0.7347 v11:0.5331 v12:0.6463 v13:0.9667] eff_mlp_bias:[v0:2.3533 v1:1.3700 v2:1.2872 v3:1.0828 v4:1.0993 v5:1.1822 v6:0.9944 v7:0.7955 v8:0.9612 v9:0.9281 v10:0.8507 v11:0.7623 v12:0.7292 v13:0.7844] depth_encoding:sinusoidal +step:9200/20000 val_loss:2.1636 val_bpb:1.2814 train_time:541571ms step_avg:58.87ms +step:9400/20000 train_loss:2.1763 train_time:553298ms step_avg:58.86ms +step:9400 shared0_alpha:mean=0.466,std=0.069 shared1_alpha:mean=0.641,std=0.071 shared2_alpha:mean=0.566,std=0.070 shared3_alpha:mean=0.579,std=0.072 eff_mlp_scale:[v0:247.1050 v1:147.7496 v2:157.3117 v3:149.2161 v4:126.6447 v5:153.8803 v6:147.1818 v7:136.3928 v8:122.1217 v9:131.1967 v10:116.7920 v11:116.5751 v12:118.7294 v13:297.4357] eff_attn_scale:[v0:0.2637 v1:0.9329 v2:1.0400 v3:0.9449 v4:1.0872 v5:0.7781 v6:0.8842 v7:0.8784 v8:0.9679 v9:0.6449 v10:0.6780 v11:0.6433 v12:0.7019 v13:1.6673] eff_attn_bias:[v0:0.5110 v1:1.1711 v2:1.0607 v3:0.9667 v4:1.1325 v5:0.9667 v6:1.0551 v7:0.8673 
v8:0.9667 v9:0.8176 v10:0.7403 v11:0.5359 v12:0.6519 v13:0.9778] eff_mlp_bias:[v0:2.3533 v1:1.3811 v2:1.2927 v3:1.0883 v4:1.1104 v5:1.1932 v6:0.9999 v7:0.8010 v8:0.9667 v9:0.9391 v10:0.8618 v11:0.7679 v12:0.7347 v13:0.7955] depth_encoding:sinusoidal +step:9400/20000 val_loss:2.1536 val_bpb:1.2755 train_time:553329ms step_avg:58.86ms +step:9600/20000 train_loss:2.1820 train_time:565046ms step_avg:58.86ms +step:9600 shared0_alpha:mean=0.465,std=0.069 shared1_alpha:mean=0.642,std=0.071 shared2_alpha:mean=0.565,std=0.070 shared3_alpha:mean=0.578,std=0.072 eff_mlp_scale:[v0:247.7165 v1:148.3111 v2:158.0710 v3:151.3469 v4:128.3050 v5:155.0805 v6:147.8922 v7:137.8547 v8:123.1728 v9:131.6953 v10:117.3558 v11:117.9098 v12:119.7514 v13:301.5698] eff_attn_scale:[v0:0.2630 v1:0.9361 v2:1.0344 v3:0.9398 v4:1.0961 v5:0.7845 v6:0.8704 v7:0.8819 v8:0.9677 v9:0.6501 v10:0.6699 v11:0.6503 v12:0.7063 v13:1.6560] eff_attn_bias:[v0:0.5110 v1:1.1767 v2:1.0662 v3:0.9723 v4:1.1380 v5:0.9778 v6:1.0607 v7:0.8673 v8:0.9723 v9:0.8176 v10:0.7458 v11:0.5414 v12:0.6574 v13:0.9778] eff_mlp_bias:[v0:2.3533 v1:1.3811 v2:1.2982 v3:1.0938 v4:1.1159 v5:1.1988 v6:1.0054 v7:0.8065 v8:0.9723 v9:0.9447 v10:0.8673 v11:0.7679 v12:0.7403 v13:0.8010] depth_encoding:sinusoidal +step:9600/20000 val_loss:2.1442 val_bpb:1.2699 train_time:565077ms step_avg:58.86ms +step:9800/20000 train_loss:2.0977 train_time:576802ms step_avg:58.86ms +step:9800 shared0_alpha:mean=0.464,std=0.069 shared1_alpha:mean=0.642,std=0.071 shared2_alpha:mean=0.565,std=0.070 shared3_alpha:mean=0.578,std=0.072 eff_mlp_scale:[v0:247.9621 v1:148.8215 v2:158.8952 v3:152.0676 v4:129.1139 v5:155.6142 v6:148.6633 v7:138.5112 v8:123.9493 v9:132.1486 v10:117.9676 v11:118.4713 v12:120.5063 v13:303.2700] eff_attn_scale:[v0:0.2721 v1:0.9413 v2:1.0305 v3:0.9389 v4:1.1016 v5:0.7801 v6:0.8794 v7:0.8853 v8:0.9639 v9:0.6450 v10:0.6687 v11:0.6438 v12:0.7022 v13:1.6871] eff_attn_bias:[v0:0.5138 v1:1.1767 v2:1.0662 v3:0.9778 v4:1.1380 v5:0.9778 v6:1.0607 
v7:0.8728 v8:0.9778 v9:0.8231 v10:0.7513 v11:0.5441 v12:0.6602 v13:0.9778] eff_mlp_bias:[v0:2.3533 v1:1.3811 v2:1.3037 v3:1.0938 v4:1.1214 v5:1.2043 v6:1.0109 v7:0.8121 v8:0.9778 v9:0.9447 v10:0.8673 v11:0.7734 v12:0.7458 v13:0.8065] depth_encoding:sinusoidal +step:9800/20000 val_loss:2.1359 val_bpb:1.2650 train_time:576833ms step_avg:58.86ms +step:10000/20000 train_loss:2.1303 train_time:588562ms step_avg:58.86ms +step:10000 shared0_alpha:mean=0.464,std=0.069 shared1_alpha:mean=0.642,std=0.071 shared2_alpha:mean=0.564,std=0.070 shared3_alpha:mean=0.577,std=0.072 eff_mlp_scale:[v0:247.8048 v1:149.0514 v2:160.5002 v3:152.4073 v4:129.8024 v5:155.8546 v6:149.0359 v7:138.8206 v8:124.6103 v9:132.3527 v10:118.2633 v11:118.7360 v12:121.1489 v13:304.5426] eff_attn_scale:[v0:0.2631 v1:0.9419 v2:1.0281 v3:0.9489 v4:1.1151 v5:0.7886 v6:0.8681 v7:0.8947 v8:0.9850 v9:0.6484 v10:0.6625 v11:0.6507 v12:0.7109 v13:1.6958] eff_attn_bias:[v0:0.5110 v1:1.1822 v2:1.0717 v3:0.9778 v4:1.1435 v5:0.9778 v6:1.0607 v7:0.8728 v8:0.9723 v9:0.8231 v10:0.7513 v11:0.5441 v12:0.6602 v13:0.9833] eff_mlp_bias:[v0:2.3533 v1:1.3866 v2:1.3037 v3:1.0938 v4:1.1214 v5:1.2043 v6:1.0109 v7:0.8121 v8:0.9778 v9:0.9447 v10:0.8673 v11:0.7734 v12:0.7458 v13:0.8065] depth_encoding:sinusoidal +step:10000/20000 val_loss:2.1263 val_bpb:1.2593 train_time:588586ms step_avg:58.86ms +step:10195/20000 val_loss:2.1192 val_bpb:1.2551 train_time:600044ms step_avg:58.86ms +stopping_early: wallclock_cap train_time:600044ms step:10195/20000 +peak memory allocated: 13736 MiB reserved: 14196 MiB +Serialized model: 45208316 bytes +Code size: 64182 bytes +Total submission size: 45272498 bytes +Serialized model int8+zlib: 10728555 bytes (payload:11667648 raw_torch:11699443 payload_ratio:3.87x) +Total submission size int8+zlib: 10792737 bytes +final_int8_zlib_roundtrip val_loss:2.1316 val_bpb:1.2624 eval_time:1871ms +final_int8_zlib_roundtrip_exact val_loss:2.13156221 val_bpb:1.26243121