diff --git a/train_gpt_mlx_kl.py b/train_gpt_mlx_kl.py index 981c38c564..b362f7fa75 100644 --- a/train_gpt_mlx_kl.py +++ b/train_gpt_mlx_kl.py @@ -1,30 +1,6 @@ #!/usr/bin/env python3 -""" -KaiLean's Parameter Golf script — based on baseline train_gpt_mlx.py - -Innovations integrated: - 1. BigramHash logit bias (vocab lookup shortcut for common bigrams) - 2. 11 layers + 3× MLP (env-configurable) - 3. Int6 QAT with STE (late-start, activates at configurable fraction) - 4. EMA weight averaging (replaces SWA) - 5. OrthoInit for weight matrices - 6. Warmdown-aware LR schedule - 7. SmearGate learnable prev-token gate on attention (USE_SMEARGATE=1) - 8. Partial RoPE (ROPE_DIMS=16, applied to first N head dims only) - 9. Depth-aware LN scale 1/√(layer+1) (LN_SCALE_ENABLED=1) - 10. Cross-layer Shared Attention on last N layers (XSA_LAST_N=4) - 11. GPTQ-lite per-row scale quantization (USE_GPTQ_LITE=1) - 12. EngramLite: gated multi-head bigram+trigram hash (ENGRAM_LITE_ENABLED=1) - 13. SkipGramHash: non-adjacent token logit bias (SKIPGRAM_HASH_SIZE>0) - 14. Complementary Training: down-weights bigram-easy tokens (COMPLEMENT_ALPHA=0.5) - 15. BackoffNgramMixer: eval-time causal n-gram mixing, zero artifact cost (NGRAM_MIXER_ENABLED=1) - 16. Sliding-window eval with configurable stride (EVAL_MODE=sliding) - 17. LoRA Test-Time Training at eval (TTT_ENABLED=1) - -Moonshot invocation (8×H100): - ENGRAM_LITE_ENABLED=1 COMPLEMENT_ALPHA=0.5 NGRAM_MIXER_ENABLED=1 \\ - NGRAM_ALPHA=0.25 NGRAM_MAX_ORDER=4 python3 train_gpt_mlx_kl.py -""" +"""KaiLean's Parameter Golf script — GPT training with int6 QAT, EMA, BigramHash, +EngramLite, SmearGate, XSA, complementary training, BackoffNgramMixer, and LoRA TTT.""" from __future__ import annotations import glob, json, math, os, pickle, sys, time, uuid, copy import zstandard @@ -40,9 +16,6 @@ COMPUTE_DTYPE = mx.bfloat16 -# ============================================================================ -# HYPERPARAMETERS — defaults tuned for competition -# ============================================================================ class Hyperparameters: data_path: str = os.environ.get("DATA_PATH", "./data/datasets/fineweb10B_sp1024") tokenizer_path: str = os.environ.get("TOKENIZER_PATH", "./data/tokenizers/fineweb_1024_bpe.model") @@ -62,7 +35,6 @@ class Hyperparameters: warmdown_iters: int = int(os.environ.get("WARMDOWN_ITERS", 3500)) max_wallclock_seconds: float = float(os.environ.get("MAX_WALLCLOCK_SECONDS", 600.0)) - # Model — CHANGED DEFAULTS: 11 layers, 3x MLP vocab_size: int = int(os.environ.get("VOCAB_SIZE", 1024)) num_layers: int = int(os.environ.get("NUM_LAYERS", 11)) model_dim: int = int(os.environ.get("MODEL_DIM", 512)) @@ -75,59 +47,35 @@ class Hyperparameters: logit_softcap: float = float(os.environ.get("LOGIT_SOFTCAP", 30.0)) rope_base: float = float(os.environ.get("ROPE_BASE", 10000.0)) qk_gain_init: float = float(os.environ.get("QK_GAIN_INIT", 1.5)) - - # Sliding-window eval (eval only — training stays at train_seq_len) eval_seq_len: int = int(os.environ.get("EVAL_SEQ_LEN", 2048)) eval_stride: int = int(os.environ.get("EVAL_STRIDE", 64)) eval_batch_seqs: int = int(os.environ.get("EVAL_BATCH_SEQS", 32)) - - # KL innovations - bigram_hash_size: int = int(os.environ.get("BIGRAM_HASH_SIZE", 16384)) # Task 2: fill budget (was 10240) - qat_start_frac: float = float(os.environ.get("QAT_START_FRAC", 0.15)) # UNUSED — late_qat_threshold controls QAT + bigram_hash_size: int = int(os.environ.get("BIGRAM_HASH_SIZE", 16384)) + qat_start_frac: float = float(os.environ.get("QAT_START_FRAC", 0.15)) ema_decay: float = float(os.environ.get("EMA_DECAY", 0.995)) ema_start_frac: float = float(os.environ.get("EMA_START_FRAC", 0.5)) use_ortho_init: bool = bool(int(os.environ.get("USE_ORTHO_INIT", "1"))) - # SWA — optional complement/replacement for EMA; starts at 60% of iterations use_swa: bool = bool(int(os.environ.get("USE_SWA", "0"))) swa_decay: float = float(os.environ.get("SWA_DECAY", "0.4")) - - # SOTA techniques (Tasks 3-4) - smear_enabled: bool = bool(int(os.environ.get("USE_SMEARGATE", os.environ.get("SMEAR_ENABLED", "1")))) # 3b: learnable prev-token gate (USE_SMEARGATE=alias) - rope_dims: int = int(os.environ.get("ROPE_DIMS", 16)) # 3f: partial RoPE dims; 0 = no RoPE - ln_scale_enabled: bool = bool(int(os.environ.get( # 3g: 1/sqrt(layer+1) depth scale + smear_enabled: bool = bool(int(os.environ.get("USE_SMEARGATE", os.environ.get("SMEAR_ENABLED", "1")))) + rope_dims: int = int(os.environ.get("ROPE_DIMS", 16)) + ln_scale_enabled: bool = bool(int(os.environ.get( "LN_SCALE_ENABLED", os.environ.get("USE_LN_SCALE", "1")))) - xsa_last_n: int = int(os.environ.get("XSA_LAST_N", 4)) # 4: XSA on last N decoder layers - - # EngramLite — gated multi-head bigram+trigram hash (replaces BigramHash when enabled) + xsa_last_n: int = int(os.environ.get("XSA_LAST_N", 4)) engram_lite_enabled: bool = bool(int(os.environ.get("ENGRAM_LITE_ENABLED", "0"))) engram_hash_size: int = int(os.environ.get("ENGRAM_HASH_SIZE", "2048")) engram_embed_dim: int = int(os.environ.get("ENGRAM_EMBED_DIM", "128")) engram_n_heads: int = int(os.environ.get("ENGRAM_N_HEADS", "2")) - - # SkipGram — non-adjacent token hash logit bias (disabled by default) skipgram_hash_size: int = int(os.environ.get("SKIPGRAM_HASH_SIZE", "0")) - - # Complementary Training — down-weights tokens easily predicted by bigrams complement_alpha: float = float(os.environ.get("COMPLEMENT_ALPHA", "0.0")) - - # BackoffNgramMixer — eval-time causal n-gram mixing (zero artifact cost) ngram_mixer_enabled: bool = bool(int(os.environ.get("NGRAM_MIXER_ENABLED", "0"))) ngram_alpha: float = float(os.environ.get("NGRAM_ALPHA", "0.25")) ngram_max_order: int = int(os.environ.get("NGRAM_MAX_ORDER", "4")) - - # Eval mode — controls final roundtrip evaluation strategy - # "standard" = chunked eval only (fast, ~36 min on M1) - # "sliding" = sliding-window eval (accurate, ~3× slower) [default] - # "both" = run standard first, then sliding, log both eval_mode: str = os.environ.get("EVAL_MODE", "sliding") - - # LoRA Test-Time Training at eval (Task 5); only applies when eval_mode includes sliding ttt_enabled: bool = bool(int(os.environ.get("TTT_ENABLED", "0"))) ttt_rank: int = int(os.environ.get("TTT_RANK", 4)) ttt_lr: float = float(os.environ.get("TTT_LR", 0.001)) ttt_steps: int = int(os.environ.get("TTT_STEPS", 2)) - - # Optimizer beta1: float = float(os.environ.get("BETA1", 0.9)) beta2: float = float(os.environ.get("BETA2", 0.95)) adam_eps: float = float(os.environ.get("ADAM_EPS", 1e-8)) @@ -169,16 +117,12 @@ def lr_mul(self, step: int, elapsed_ms: float) -> float: remaining_ms = max(1000.0 * self.max_wallclock_seconds - elapsed_ms, 0.0) return remaining_ms / max(warmdown_ms, 1e-9) if remaining_ms <= warmdown_ms else 1.0 - CONTROL_TENSOR_NAME_PATTERNS = ( "attn_scale", "attn_scales", "mlp_scale", "mlp_scales", "resid_mix", "resid_mixes", "q_gain", "skip_weight", "skip_weights", - "smear", # SmearGate gate parameter → scalar optimizer + "smear", ) -# ============================================================================ -# MATH HELPERS -# ============================================================================ def rms_norm(x: mx.array, eps: float = 1e-6) -> mx.array: return (x * mx.rsqrt(mx.mean(x * x, axis=-1, keepdims=True) + eps)).astype(x.dtype) @@ -197,9 +141,6 @@ def zeropower_newtonschulz5(g: mx.array, steps: int, eps: float = 1e-7) -> mx.ar x = x.T return x.astype(g.dtype) -# ============================================================================ -# DATA LOADING (unchanged from baseline) -# ============================================================================ def load_data_shard(path: Path) -> np.ndarray: header_bytes = 256 * np.dtype(" tuple[mx.array, mx.arra y = chunk[1:].reshape(-1, seq_len) return mx.array(x, dtype=mx.int32), mx.array(y, dtype=mx.int32) -# ============================================================================ -# INNOVATION 1: BigramHash — hash-based bigram logit features -# Adds bigram context to logits at near-zero parameter cost. -# ============================================================================ class BigramHashEmbedding(nn.Module): def __init__(self, hash_size: int, dim: int): super().__init__() self.hash_size = hash_size self.table = nn.Embedding(hash_size, dim) - # Small init — these are additive corrections, not primary features self.table.weight = self.table.weight * 0.02 def __call__(self, tokens: mx.array) -> mx.array: """tokens: (B, T) int32 → bigram embeddings: (B, T, dim)""" - # For position t, hash(token[t-1], token[t]) to get a bigram feature - t_prev = tokens[:, :-1] # (B, T-1) - t_curr = tokens[:, 1:] # (B, T-1) + t_prev = tokens[:, :-1] + t_curr = tokens[:, 1:] idx = mx.remainder(t_prev * 31337 + t_curr, self.hash_size) - bigram_emb = self.table(idx) # (B, T-1, dim) - # Pad position 0 with zeros (no previous token) + bigram_emb = self.table(idx) pad = mx.zeros((tokens.shape[0], 1, bigram_emb.shape[-1]), dtype=bigram_emb.dtype) - return mx.concatenate([pad, bigram_emb], axis=1) # (B, T, dim) + return mx.concatenate([pad, bigram_emb], axis=1) -# ============================================================================ -# EngramLiteEmbedding — gated multi-head bigram+trigram hash logit features. -# Replaces BigramHash when ENGRAM_LITE_ENABLED=1. -# Key improvement: gating suppresses noisy hash collisions (raw trigrams hurt). -# ============================================================================ class EngramLiteEmbedding(nn.Module): def __init__(self, hash_size: int = 2048, embed_dim: int = 128, output_dim: int = 1024, n_heads: int = 2, @@ -296,25 +225,18 @@ def __init__(self, hash_size: int = 2048, embed_dim: int = 128, self.output_dim = output_dim self.n_heads = n_heads self.orders = list(orders) - # Different prime per hash head to reduce collision rate (at most 4 heads supported) _all_primes = [31337, 59999, 73721, 97531] if n_heads > len(_all_primes): raise ValueError(f"EngramLiteEmbedding: n_heads={n_heads} exceeds max supported ({len(_all_primes)})") self._primes = _all_primes[:n_heads] - - # Separate embedding table per n-gram order (small dim, projected later) self.tables = { f"order_{o}": nn.Embedding(hash_size, embed_dim) for o in orders } for tbl in self.tables.values(): tbl.weight = tbl.weight * 0.01 - - # Project from embed_dim → output_dim (vocab_size) self.proj = nn.Linear(embed_dim, output_dim, bias=False) self.proj.weight = self.proj.weight * 0.01 - - # Learned gate per n-gram order — starts mostly suppressed (sigmoid(-2) ≈ 0.12) self.gate_proj = nn.Linear(embed_dim, len(orders), bias=True) self.gate_proj.bias = mx.full((len(orders),), -2.0) self.gate_proj.weight = self.gate_proj.weight * 0.01 @@ -344,39 +266,26 @@ def __call__(self, tokens: mx.array) -> mx.array: """tokens: (B, T) → (B, T, output_dim) additive logit bias""" B, T = tokens.shape combined = mx.zeros((B, T, self.embed_dim), dtype=mx.float32) - for order in self.orders: tbl = self.tables[f"order_{order}"] - # Multi-head average to reduce collision noise head_sum = None for hi in range(self.n_heads): idx, valid_start = self._hash_ngram(tokens, order, hi) - emb = tbl(idx).astype(mx.float32) # (B, T-valid_start, embed_dim) + emb = tbl(idx).astype(mx.float32) pad = mx.zeros((B, valid_start, self.embed_dim), dtype=mx.float32) - emb = mx.concatenate([pad, emb], axis=1) # (B, T, embed_dim) + emb = mx.concatenate([pad, emb], axis=1) head_sum = emb if head_sum is None else head_sum + emb combined = combined + head_sum / self.n_heads + gate = mx.sigmoid(self.gate_proj(combined)) + gate_scalar = gate.mean(axis=-1, keepdims=True) + return self.proj(combined) * gate_scalar - # Sigmoid gate: suppress noisy lookups, let model learn when to trust them. - # Gate is averaged across all n-gram orders into a single scalar per position — - # empirically simpler and stable; the combined embedding already encodes order info. - gate = mx.sigmoid(self.gate_proj(combined)) # (B, T, n_orders) - gate_scalar = gate.mean(axis=-1, keepdims=True) # (B, T, 1) - return self.proj(combined) * gate_scalar # (B, T, output_dim) - - -# ============================================================================ -# SkipGramHashEmbedding — non-adjacent token hash logit bias. -# Captures structured patterns (e.g. token[-1] × token[-3]). -# Enabled when SKIPGRAM_HASH_SIZE > 0. -# ============================================================================ class SkipGramHashEmbedding(nn.Module): def __init__(self, hash_size: int = 4096, dim: int = 1024, skip_patterns: list = None): super().__init__() self.hash_size = hash_size self.dim = dim - # Each pattern is a list of negative offsets, e.g. [-1, -3] self.skip_patterns = skip_patterns if skip_patterns is not None else [[-1, -3], [-1, -5], [-2, -4]] self.tables = { f"skip_{i}": nn.Embedding(hash_size, dim) @@ -389,14 +298,12 @@ def __call__(self, tokens: mx.array) -> mx.array: """tokens: (B, T) → (B, T, dim) additive logit bias""" B, T = tokens.shape output = mx.zeros((B, T, self.dim), dtype=mx.float32) - for i, pattern in enumerate(self.skip_patterns): tbl = self.tables[f"skip_{i}"] - min_offset = min(pattern) # most negative offset + min_offset = min(pattern) valid_start = abs(min_offset) if valid_start >= T: continue - # Accumulate hash over all offsets hash_val = mx.zeros((B, T - valid_start), dtype=mx.int32) prime = 31337 for offset in pattern: @@ -408,14 +315,8 @@ def __call__(self, tokens: mx.array) -> mx.array: pad = mx.zeros((B, valid_start, self.dim), dtype=mx.float32) emb = mx.concatenate([pad, emb], axis=1) output = output + emb - return output - -# ============================================================================ -# INNOVATION: SmearGate — blend each token embedding with previous token's -# Technique: @unnir (PR #102/#135). Gate initialized to 3.0 → sigmoid≈0.95 pass-through. -# ============================================================================ class SmearGate(nn.Module): def __init__(self, dim: int): super().__init__() @@ -426,18 +327,12 @@ def __call__(self, x: mx.array) -> mx.array: x_prev = mx.concatenate([mx.zeros_like(x[:, :1]), x[:, :-1]], axis=1) return g * x + (1.0 - g) * x_prev -# ============================================================================ -# INNOVATION 2: Int6 QAT — Fake quantization with STE -# ============================================================================ def fake_quant_int6(w: mx.array) -> mx.array: - """Simulate int6 quantization during training. Gradients pass through via STE.""" + """Simulate int6 quantization during training (STE for gradients).""" scale = mx.max(mx.abs(w), keepdims=True) / 31.0 + 1e-8 w_q = mx.clip(mx.round(w / scale), -32, 31) * scale return w + mx.stop_gradient(w_q - w) -# ============================================================================ -# INNOVATION 3: EMA Weight Averaging -# ============================================================================ class EMABuffer: def __init__(self, model, decay: float = 0.995): self.decay = decay @@ -452,8 +347,6 @@ def update(self, model): key = ".".join(str(p) for p in k) if isinstance(k, (list, tuple)) else k if key in self.shadow: self.shadow[key] = d * self.shadow[key] + (1.0 - d) * v - # Force-evaluate all shadow tensors so MLX doesn't accumulate an - # ever-deepening lazy computation graph (causes OOM after ~100 EMA steps). mx.eval(list(self.shadow.values())) def apply(self, model): @@ -463,9 +356,6 @@ def apply(self, model): def state_dict(self): return dict(self.shadow) -# ============================================================================ -# MODEL BLOCKS -# ============================================================================ class CastedLinear(nn.Module): def __init__(self, in_dim: int, out_dim: int): super().__init__() @@ -495,26 +385,21 @@ def __init__(self, dim: int, num_heads: int, num_kv_heads: int, self.c_v = CastedLinear(dim, kv_dim) self.proj = CastedLinear(dim, dim) self.q_gain = mx.ones((num_heads,), dtype=mx.float32) * qk_gain_init - # Partial RoPE: rotate only first rope_dims of each head (Task 3f) - # rope_dims=0 → no positional encoding (self.rope = None, skip in forward) - # rope_dims=16 → partial RoPE on first 16 dims (current default) - # rope_dims=64 → full RoPE on all head dims if rope_dims > 0: self.rope = nn.RoPE(rope_dims, traditional=False, base=rope_base) else: self.rope = None self.scale = self.head_dim ** -0.5 - self.use_xsa = use_xsa # Task 4: XSA enabled for last N layers + self.use_xsa = use_xsa def _xsa(self, y: mx.array, v: mx.array) -> mx.array: - """Subtract self-value component — forces attention to other tokens (XSA, PR #198). - y: (B, H, T, D), v: (B, Hkv, T, D) → GQA-aware subtraction.""" + """XSA: subtract self-value component (PR #198).""" B, H, T, D = y.shape Hkv = v.shape[1] group = H // Hkv y_g = y.reshape(B, Hkv, group, T, D) v_norm = v / (mx.sqrt((v * v).sum(-1, keepdims=True)) + 1e-6) - vn = v_norm[:, :, None, :, :] # (B, Hkv, 1, T, D) + vn = v_norm[:, :, None, :, :] proj = (y_g * vn).sum(-1, keepdims=True) * vn return (y_g - proj).reshape(B, H, T, D) @@ -543,7 +428,7 @@ def __init__(self, dim: int, mlp_mult: int): def __call__(self, x: mx.array, use_qat: bool = False) -> mx.array: x = nn.relu(self.fc(x, use_qat)) - return self.proj(x * x, use_qat) # relu² + return self.proj(x * x, use_qat) class Block(nn.Module): def __init__(self, dim: int, num_heads: int, num_kv_heads: int, @@ -562,7 +447,6 @@ def __init__(self, dim: int, num_heads: int, num_kv_heads: int, np.ones((dim,), dtype=np.float32), np.zeros((dim,), dtype=np.float32), ))) - # Task 3g: layerwise LN scale — dampen deep layer activations self.ln_scale_factor = float(1.0 / math.sqrt(layer_idx + 1)) if use_ln_scale else 1.0 def __call__(self, x: mx.array, x0: mx.array, use_qat: bool = False) -> mx.array: @@ -586,18 +470,13 @@ def __init__(self, vocab_size, num_layers, dim, num_heads, num_kv_heads, super().__init__() self.logit_chunk_tokens = logit_chunk_tokens self.logit_softcap = logit_softcap - self.use_qat = False # Toggled on during training - + self.use_qat = False self.tok_emb = nn.Embedding(vocab_size, dim) self.num_encoder_layers = num_layers // 2 self.num_decoder_layers = num_layers - self.num_encoder_layers self.num_skip_weights = min(self.num_encoder_layers, self.num_decoder_layers) self.skip_weights = mx.ones((self.num_skip_weights, dim), dtype=mx.float32) - - # Task 3b: SmearGate on embedded tokens self.smear = SmearGate(dim) if smear_enabled else None - - # Task 4: XSA on last xsa_last_n decoder layers xsa_decoder_start = max(0, self.num_decoder_layers - xsa_last_n) if xsa_last_n > 0 else self.num_decoder_layers self.blocks = [ Block(dim, num_heads, num_kv_heads, mlp_mult, rope_base, qk_gain_init, @@ -607,8 +486,6 @@ def __init__(self, vocab_size, num_layers, dim, num_heads, num_kv_heads, for i in range(num_layers) ] self.final_norm = RMSNormNoWeight() - - # Logit bias modules: EngramLite replaces BigramHash when enabled if engram_lite_enabled: self.engram_lite = EngramLiteEmbedding( hash_size=engram_hash_size, embed_dim=engram_embed_dim, @@ -617,21 +494,14 @@ def __init__(self, vocab_size, num_layers, dim, num_heads, num_kv_heads, else: self.engram_lite = None self.bigram_hash = BigramHashEmbedding(bigram_hash_size, vocab_size) if bigram_hash_size > 0 else None - - # SkipGram logit bias (additive, independent of BigramHash/EngramLite) self.skipgram_hash = SkipGramHashEmbedding(hash_size=skipgram_hash_size, dim=vocab_size) if skipgram_hash_size > 0 else None - - # Zero-init output projections for b in self.blocks: b.attn.proj.weight = mx.zeros_like(b.attn.proj.weight) b.mlp.proj.weight = mx.zeros_like(b.mlp.proj.weight) - - # OrthoInit for weight matrices if use_ortho_init: for b in self.blocks: for linear in [b.attn.c_q, b.attn.c_k, b.attn.c_v, b.mlp.fc]: w = linear.weight - # SVD-based orthogonal init m, n = w.shape flat = mx.random.normal((m, n)).astype(mx.float32) u, s, vt = mx.linalg.svd(flat, stream=mx.cpu) @@ -639,8 +509,6 @@ def __init__(self, vocab_size, num_layers, dim, num_heads, num_kv_heads, linear.weight = (u[:, :n] * 0.5).astype(w.dtype) else: linear.weight = (vt[:m, :] * 0.5).astype(w.dtype) - - # Tied embedding init self.tok_emb.weight = ( mx.random.normal(self.tok_emb.weight.shape, dtype=mx.float32) * tied_embed_init_std @@ -657,7 +525,6 @@ def __call__(self, input_ids: mx.array) -> mx.array: x0 = x skips = [] qat = self.use_qat - for i in range(self.num_encoder_layers): x = self.blocks[i](x, x0, qat) skips.append(x) @@ -691,32 +558,18 @@ def loss(self, input_ids: mx.array, target_ids: mx.array) -> mx.array: def complementary_loss(self, input_ids: mx.array, target_ids: mx.array, bigram_probs: mx.array, alpha: float) -> mx.array: - """Cross-entropy loss that down-weights tokens easily predicted by bigrams. - - For token at position t with predecessor prev: - weight[t] = 1 - alpha * P_bigram(target[t] | prev[t]) - Weights are clipped to [0.1, 1.0] and normalized so the effective - learning rate is preserved. - - Args: - bigram_probs: (V, V) float32 pre-computed P(next|prev) matrix. - alpha: strength of complementary weighting (0 = standard CE). - """ + """CE loss that down-weights tokens easily predicted by bigrams.""" x = self(input_ids).reshape(-1, self.tok_emb.weight.shape[1]) y = target_ids.reshape(-1) logits = x @ self.tok_emb.weight.astype(x.dtype).T logits = self.softcap(logits) logits = self._add_logit_biases(logits, input_ids) - ce_per_token = nn.losses.cross_entropy(logits.astype(mx.float32), y, reduction="none") - - # prev_tokens: the input token at each position (= predecessor of target) prev_tokens = input_ids.reshape(-1) - p_bigram = bigram_probs[prev_tokens, y] # (B*T,) + p_bigram = bigram_probs[prev_tokens, y] weights = 1.0 - alpha * p_bigram.astype(mx.float32) weights = mx.clip(weights, 0.1, 1.0) - weights = weights / (weights.mean() + 1e-8) # normalize: E[weight]=1 preserves effective LR - + weights = weights / (weights.mean() + 1e-8) return (ce_per_token * weights).mean() def token_losses(self, input_ids: mx.array, target_ids: mx.array) -> mx.array: @@ -739,9 +592,6 @@ def token_logits(self, input_ids: mx.array) -> mx.array: logits = self._add_logit_biases(logits, input_ids) return logits.reshape(B, T, -1) -# ============================================================================ -# OPTIMIZERS (same structure as baseline) -# ============================================================================ class Muon: def __init__(self, keys, params, args): self.keys = keys @@ -805,7 +655,6 @@ def step(self, model, grads_tree, step, lr_mul): {self.embed_key: grads[self.embed_key]}, {self.embed_key: params[self.embed_key]}, ) - # Apply weight decay to embeddings if self.embed_key in embed_updated: embed_updated[self.embed_key] = embed_updated[self.embed_key] * (1.0 - self.args.adam_weight_decay * lr_mul) updated.update(embed_updated) @@ -814,15 +663,11 @@ def step(self, model, grads_tree, step, lr_mul): scalar_p = {k: params[k] for k in self.scalar_keys if k in params} if scalar_g: scalar_updated = self.adam_scalar.apply_gradients(scalar_g, scalar_p) - # Apply weight decay to scalars for k in scalar_updated: scalar_updated[k] = scalar_updated[k] * (1.0 - self.args.adam_weight_decay * lr_mul) updated.update(scalar_updated) model.update(tree_unflatten(list(updated.items()))) -# ============================================================================ -# QUANTIZATION (int6 packed + zstandard — smaller artifact than int8+zlib) -# ============================================================================ INT6_KEEP_FLOAT_MAX_NUMEL = 65_536 INT6_KEEP_FLOAT_STORE_DTYPE = np.float16 INT6_PER_ROW_SCALE_DTYPE = np.float16 @@ -870,9 +715,7 @@ def unpack_int6(packed: np.ndarray, orig_len: int) -> np.ndarray: _GPTQ_PERCENTILES = np.array([99.0, 99.5, 99.9, 99.99, 99.999]) def quantize_float_array_gptq_lite(arr): - """GPTQ-lite: search 5 percentiles per row to minimize MSE in int6 quantization. - Returns (packed, scale, shape, orig_len, gptq_stats) where gptq_stats is a dict - with per-percentile row counts and the chosen percentile indices.""" + """GPTQ-lite: search 5 percentiles per row to minimize MSE in int6 quantization.""" f32 = _np_float32(arr) if f32.ndim == 2: n_rows = f32.shape[0] @@ -893,7 +736,6 @@ def quantize_float_array_gptq_lite(arr): best_mse, best_clip, best_idx = mse, c, j clip_abs[i] = best_clip chosen_pct_idx[i] = best_idx - # Tally how many rows landed at each percentile pct_counts = {float(_GPTQ_PERCENTILES[j]): int(np.sum(chosen_pct_idx == j)) for j in range(len(_GPTQ_PERCENTILES))} clipped = np.clip(f32, -clip_abs[:, None], clip_abs[:, None]) @@ -902,7 +744,6 @@ def quantize_float_array_gptq_lite(arr): packed, orig_len = pack_int6(q) gptq_stats = {"n_rows": n_rows, "pct_counts": pct_counts} return packed, np.ascontiguousarray(scale.astype(INT6_PER_ROW_SCALE_DTYPE)), f32.shape, orig_len, gptq_stats - # Scalar case — no per-row sweep, fall back to fixed clip clip_abs_s = float(np.quantile(np.abs(f32).reshape(-1), INT6_CLIP_Q)) if f32.size else 0.0 scale = np.array(clip_abs_s / 31.0 if clip_abs_s > 0 else 1.0, dtype=np.float32) q = np.clip(np.round(np.clip(f32, -clip_abs_s, clip_abs_s) / scale), -32, 31).astype(np.int8) @@ -929,14 +770,11 @@ def quantize_state_dict_int6(flat_state, args=None): """Quantize state dict to int6 with optional GPTQ-lite clip search.""" use_gptq = args.use_gptq_lite if args else False quant_fn = quantize_float_array_gptq_lite if use_gptq else quantize_float_array - quantized, scales, shapes, dtypes, passthrough = {}, {}, {}, {}, {} passthrough_orig_dtypes, qmeta = {}, {} stats = dict.fromkeys(("param_count","num_tensors","num_float_tensors","num_nonfloat_tensors","baseline_tensor_bytes","int6_payload_bytes"), 0) - # Accumulated GPTQ-lite stats across all quantized tensors gptq_total_rows = 0 gptq_pct_counts: dict[float, int] = {} - for name, arr in flat_state.items(): stats["param_count"] += int(arr.size) stats["num_tensors"] += 1 @@ -968,11 +806,9 @@ def quantize_state_dict_int6(flat_state, args=None): shapes[name] = orig_shape dtypes[name] = str(arr.dtype).split(".")[-1] stats["int6_payload_bytes"] += int(packed.nbytes + s.nbytes) - if use_gptq and gptq_total_rows > 0: stats["gptq_total_rows"] = gptq_total_rows stats["gptq_pct_counts"] = gptq_pct_counts - obj = {"__quant_format__": "int6_packed_per_row_v1", "quantized": quantized, "scales": scales, "shapes": shapes, "dtypes": dtypes, "passthrough": passthrough} if qmeta: obj["qmeta"] = qmeta @@ -1001,22 +837,9 @@ def dequantize_state_dict_int6(quant_obj): out[name] = mx.array(out_arr, dtype=MX_DTYPE_FROM_NAME[orig]) if isinstance(orig, str) else mx.array(out_arr) return out -# ============================================================================ -# COMPLEMENTARY TRAINING HELPER -# ============================================================================ def build_bigram_stats(data_path: str, vocab_size: int = 1024) -> np.ndarray: - """Pre-compute P(next_token | prev_token) from all training shards. - - Returns a (vocab_size, vocab_size) float32 array where entry [i, j] is - the smoothed probability of token j following token i. Uses Laplace - smoothing so every entry is > 0. - - The result is used for complementary training (down-weighting tokens that - bigram statistics can already predict well) and is NOT stored in the - artifact — it is recomputed from training data at the start of each run. - """ + """Pre-compute P(next_token | prev_token) from all training shards with Laplace smoothing.""" counts = np.zeros((vocab_size, vocab_size), dtype=np.float64) - # Shard pattern is the same as args.train_files (fineweb_train_*.bin) shard_paths = sorted(glob.glob(f"{data_path}/fineweb_train_*.bin")) for shard_path in shard_paths: tokens = load_data_shard(Path(shard_path)) @@ -1024,37 +847,12 @@ def build_bigram_stats(data_path: str, vocab_size: int = 1024) -> np.ndarray: curr = tokens[1:].astype(np.int32) mask = (prev < vocab_size) & (curr < vocab_size) np.add.at(counts, (prev[mask], curr[mask]), 1.0) - # Laplace smoothing: add 1 to every cell, normalize per row counts += 1.0 row_sums = counts.sum(axis=1, keepdims=True) return (counts / row_sums).astype(np.float32) - -# ============================================================================ -# BACKOFF N-GRAM MIXER -# Causal, zero-artifact-cost eval-time n-gram language model. -# -# Competition compliance notes: -# - Produces a full normalized probability distribution over the vocabulary -# at each step (sums to 1 by construction). -# - Strictly causal: only tokens at positions < current position are used. -# - No artifact cost: the cache is built from scratch during each evaluation -# pass using the tokens already scored — it is never saved to disk. -# ============================================================================ class BackoffNgramMixer: - """Causal n-gram LM with linear-interpolation backoff. - - For each evaluation position t the mixer: - 1. Queries count tables built from tokens at positions 0 .. t-1. - 2. Produces P_ngram(· | context) via linear interpolation from order 1 - up to max_order — a valid probability distribution over all vocab_size - tokens. - 3. Mixes with the neural model's distribution: - P_mix = (1 - α) · P_neural + α · P_ngram - 4. Scores the true next token under P_mix. - 5. Updates the count tables with the token just scored (so it can be used - as context for future positions). - """ + """Causal n-gram LM with linear-interpolation backoff for eval-time mixing.""" def __init__(self, vocab_size: int = 1024, max_order: int = 4, hash_buckets: int = 2_000_000, # ~2M buckets ≈ 16 collisions at 32M tokens @@ -1069,7 +867,6 @@ def __init__(self, vocab_size: int = 1024, max_order: int = 4, def _reset(self): """Clear all count tables — call before each new eval pass.""" - # counts[order][ctx_hash] -> float32 array of shape (vocab_size,) self._counts = [ defaultdict(lambda: np.zeros(self.vocab_size, dtype=np.float32)) for _ in range(self.max_order + 1) @@ -1093,13 +890,11 @@ def _ngram_probs(self, context_tokens) -> np.ndarray: total = self._total[order][ctx_hash] if total <= 0.0: continue - # Confidence-weighted interpolation: λ→1 when total>>5 (counts well established) - lam = total / (total + 5.0) # 5.0: discount factor; reaches λ=0.5 at 5 observations + lam = total / (total + 5.0) c = self._counts[order][ctx_hash].astype(np.float64) order_probs = (c + 1e-10) / (total + 1e-10 * V) order_probs /= order_probs.sum() probs = (1.0 - lam) * probs + lam * order_probs - # Final normalization (guard against floating-point drift) s = probs.sum() if s > 0: probs /= s @@ -1109,7 +904,6 @@ def _mixing_alpha(self, neural_logits: np.ndarray) -> float: """Entropy-adaptive mixing weight α ∈ [0.15, 0.60].""" if self.alpha_mode == "fixed": return self.fixed_alpha - # High neural entropy → trust n-grams more logits = neural_logits.astype(np.float64) logits -= logits.max() probs = np.exp(logits) @@ -1121,47 +915,26 @@ def _mixing_alpha(self, neural_logits: np.ndarray) -> float: def score_and_update(self, context_tokens, target_token: int, neural_logits: np.ndarray) -> float: - """Score target_token and update the cache. Must be called in order. - - Args: - context_tokens: sequence of token IDs before the current position. - target_token: the true next token to score. - neural_logits: (V,) float32/64 raw logits from the neural model - at the current position. - - Returns: - log_prob: natural-log probability of target_token under P_mix. - """ + """Score target_token under mixed neural+ngram distribution and update cache.""" ngram_probs = self._ngram_probs(context_tokens) alpha = self._mixing_alpha(neural_logits) - - # Normalize neural distribution nl = neural_logits.astype(np.float64) nl -= nl.max() neural_probs = np.exp(nl) neural_probs /= neural_probs.sum() - mixed = (1.0 - alpha) * neural_probs + alpha * ngram_probs.astype(np.float64) s = mixed.sum() if s > 0: mixed /= s - log_prob = float(np.log(mixed[target_token] + 1e-40)) - - # Update cache AFTER scoring (causal: this token becomes future context) tok = int(target_token) for order in range(1, self.max_order + 1): if len(context_tokens) >= order: ctx_hash = self._hash_ctx(context_tokens[-order:]) self._counts[order][ctx_hash][tok] += 1.0 self._total[order][ctx_hash] += 1.0 - return log_prob - -# ============================================================================ -# VALIDATION HELPERS -# ============================================================================ def build_sentencepiece_luts(sp, vocab_size): sp_vocab_size = int(sp.vocab_size()) table_size = max(sp_vocab_size, vocab_size) @@ -1266,33 +1039,22 @@ def eval_val(args, compiled_loss, val_tokens, base_bytes_lut, has_leading_space_ val_bpb = (val_loss / math.log(2.0)) * (total_tokens_f / total_bytes) return val_loss, val_bpb - def eval_val_sliding(args, model, val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, log_fn=None): - """Sliding-window eval: each token scored with up to eval_seq_len context. - - Windows of eval_seq_len advance by eval_stride. Only the last eval_stride - tokens per window contribute to the metric (first window scores all). - With eval_seq_len=2048 and eval_stride=64, every token sees up to 2048 - tokens of context — twice the training context length. - """ + """Sliding-window eval: each token scored with up to eval_seq_len context.""" seq_len = args.eval_seq_len stride = args.eval_stride batch_seqs = args.eval_batch_seqs total_tokens = val_tokens.size - 1 - window_starts = [ws for ws in range(0, total_tokens, stride) if min(ws + seq_len, total_tokens) - ws >= 1] total_windows = len(window_starts) - loss_sum = 0.0 token_count = 0.0 byte_count = 0.0 - model.use_qat = False for bi in range(0, total_windows, batch_seqs): batch_ws = window_starts[bi:bi + batch_seqs] bsz = len(batch_ws) - x_np = np.zeros((bsz, seq_len), dtype=np.int32) y_np = np.zeros((bsz, seq_len), dtype=np.int32) wlens: list[int] = [] @@ -1302,13 +1064,11 @@ def eval_val_sliding(args, model, val_tokens, base_bytes_lut, has_leading_space_ wlens.append(wlen) x_np[i, :wlen] = val_tokens[ws:end] y_np[i, :wlen] = val_tokens[ws + 1:end + 1] - x = mx.array(x_np) y = mx.array(y_np) nll = model.token_losses(x, y) # (B, T) mx.eval(nll) nll_np = np.array(nll) - for i, ws in enumerate(batch_ws): wlen = wlens[i] s = 0 if ws == 0 else max(wlen - stride, 0) @@ -1319,7 +1079,6 @@ def eval_val_sliding(args, model, val_tokens, base_bytes_lut, has_leading_space_ tb = base_bytes_lut[tgt].astype(np.float64) tb += (has_leading_space_lut[tgt] & ~is_boundary_token_lut[prev]).astype(np.float64) byte_count += float(tb.sum()) - if log_fn and (bi // batch_seqs) % 50 == 0: done = min(bi + batch_seqs, total_windows) pct = done / total_windows * 100 @@ -1327,43 +1086,25 @@ def eval_val_sliding(args, model, val_tokens, base_bytes_lut, has_leading_space_ if token_count > 0: rbpb = (loss_sum / token_count) / math.log(2.0) * (token_count / byte_count) log_fn(f"sliding_eval [{pct:5.1f}%] {done}/{total_windows} windows running_bpb={rbpb:.6f}") - val_loss = loss_sum / token_count val_bpb = (val_loss / math.log(2.0)) * (token_count / byte_count) return val_loss, val_bpb - def eval_val_sliding_ngram(args, model, val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, log_fn=None): - """Sliding-window eval with BackoffNgramMixer post-processing. - - Identical windowing strategy to eval_val_sliding, but each "new" token - in each window is scored under a mixture of the neural distribution and - a causal n-gram distribution built incrementally from all previously - scored tokens. - - Competition compliance: - - The n-gram cache only sees tokens at positions strictly before the - current position (causal). - - The mixed distribution sums to 1 at every position. - - The n-gram cache is built from scratch during the eval pass and is - never saved to disk (zero artifact cost). - """ + """Sliding-window eval with BackoffNgramMixer post-processing.""" seq_len = args.eval_seq_len stride = args.eval_stride batch_seqs = args.eval_batch_seqs max_order = args.ngram_max_order total_tokens = val_tokens.size - 1 - window_starts = [ws for ws in range(0, total_tokens, stride) if min(ws + seq_len, total_tokens) - ws >= 1] total_windows = len(window_starts) - loss_sum = 0.0 token_count = 0.0 byte_count = 0.0 - model.use_qat = False mixer = BackoffNgramMixer( vocab_size=args.vocab_size, @@ -1371,11 +1112,9 @@ def eval_val_sliding_ngram(args, model, val_tokens, alpha_mode="entropy_adaptive", fixed_alpha=args.ngram_alpha, ) - for bi in range(0, total_windows, batch_seqs): batch_ws = window_starts[bi:bi + batch_seqs] bsz = len(batch_ws) - x_np = np.zeros((bsz, seq_len), dtype=np.int32) y_np = np.zeros((bsz, seq_len), dtype=np.int32) wlens: list[int] = [] @@ -1385,26 +1124,19 @@ def eval_val_sliding_ngram(args, model, val_tokens, wlens.append(wlen) x_np[i, :wlen] = val_tokens[ws:end] y_np[i, :wlen] = val_tokens[ws + 1:end + 1] - - # Get full (B, T, V) logits for the batch x = mx.array(x_np) - logits_all = model.token_logits(x) # (B, T, V) + logits_all = model.token_logits(x) mx.eval(logits_all) - logits_np = np.array(logits_all) # (B, T, V) float32 - + logits_np = np.array(logits_all) for i, ws in enumerate(batch_ws): wlen = wlens[i] s = 0 if ws == 0 else max(wlen - stride, 0) - for j in range(s, wlen): global_pos = ws + j # global index of the input token - target = int(y_np[i, j]) # = val_tokens[global_pos + 1] - neural_logits = logits_np[i, j] # (V,) - - # N-gram context: all tokens before global_pos + 1 + target = int(y_np[i, j]) + neural_logits = logits_np[i, j] ctx_start = max(0, global_pos + 1 - max_order) context = val_tokens[ctx_start:global_pos + 1].tolist() - log_prob = mixer.score_and_update(context, target, neural_logits) loss_sum += -log_prob token_count += 1.0 @@ -1413,7 +1145,6 @@ def eval_val_sliding_ngram(args, model, val_tokens, tb = float(base_bytes_lut[tgt_arr[0]]) tb += float(has_leading_space_lut[tgt_arr[0]] and not is_boundary_token_lut[prev_arr[0]]) byte_count += tb - if log_fn and (bi // batch_seqs) % 50 == 0: done = min(bi + batch_seqs, total_windows) pct = done / total_windows * 100 @@ -1421,13 +1152,12 @@ def eval_val_sliding_ngram(args, model, val_tokens, if token_count > 0 and byte_count > 0: rbpb = (loss_sum / token_count) / math.log(2.0) * (token_count / byte_count) log_fn(f"ngram_sliding_eval [{pct:5.1f}%] {done}/{total_windows} windows running_bpb={rbpb:.6f}") - val_loss = loss_sum / max(token_count, 1.0) val_bpb = (val_loss / math.log(2.0)) * (token_count / max(byte_count, 1.0)) return val_loss, val_bpb - - +def clip_grad_tree(grads_tree, max_norm): + """Clip gradient tree by global norm.""" if max_norm <= 0: return grads_tree flat = dict(tree_flatten(grads_tree)) @@ -1437,11 +1167,6 @@ def eval_val_sliding_ngram(args, model, val_tokens, scale = max_norm / (math.sqrt(total_sq) + 1e-12) return tree_unflatten([(k, g * scale) for k, g in flat.items()]) -# ============================================================================ -# TASK 5: LoRA Test-Time Training (TTT) -# At eval time, for each document: adapt rank-r LoRA on Q/V, then re-score. -# LoRA params not counted toward artifact — created and destroyed per-doc. -# ============================================================================ def eval_val_sliding_ttt(args, model, val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, log_fn=None): @@ -1454,11 +1179,8 @@ def eval_val_sliding_ttt(args, model, val_tokens, total_tokens = val_tokens.size - 1 window_starts = [ws for ws in range(0, total_tokens, stride) if min(ws + seq_len, total_tokens) - ws >= 1] - loss_sum, token_count, byte_count = 0.0, 0.0, 0.0 model.use_qat = False - - # Collect Q/V projection weights once qv_keys = [(li, proj) for li, blk in enumerate(model.blocks) for proj in ("attn.c_q", "attn.c_v")] @@ -1473,7 +1195,6 @@ def _set_w(li, proj_name, w): blk.attn.c_q.weight = w else: blk.attn.c_v.weight = w - for wi, ws in enumerate(window_starts): end = min(ws + seq_len, total_tokens) wlen = end - ws @@ -1483,8 +1204,6 @@ def _set_w(li, proj_name, w): y_np[0, :wlen] = val_tokens[ws + 1:end + 1] x = mx.array(x_np) y = mx.array(y_np) - - # Save original weights and init LoRA (A*B added to weight) saved, lora_A, lora_B = {}, {}, {} for li, proj in qv_keys: w = _get_w(li, proj) @@ -1492,14 +1211,11 @@ def _set_w(li, proj_name, w): out_d, in_d = w.shape lora_A[(li, proj)] = mx.random.normal((rank, in_d)).astype(mx.float32) * 0.01 lora_B[(li, proj)] = mx.zeros((out_d, rank), dtype=mx.float32) - - # TTT: gradient steps on LoRA params using context tokens (s=0..wlen-stride) s = 0 if ws == 0 else max(wlen - stride, 0) - if s > 0: # Only train if there are context tokens before the eval window + if s > 0: ctx_x = x_np[:, :s] ctx_y = y_np[:, :s] for _ in range(ttt_steps): - # Apply LoRA deltas to weights for li, proj in qv_keys: w_base = saved[(li, proj)] delta = lora_B[(li, proj)] @ lora_A[(li, proj)] @@ -1509,28 +1225,18 @@ def lora_loss(): cx = mx.array(ctx_x) cy = mx.array(ctx_y) return model.loss(cx, cy) - - # Compute grads wrt LoRA params via current model weights - # (simplified: treat weight delta as a single step) loss_val = lora_loss() mx.eval(loss_val) - # Finite-difference update on lora_B (simple gradient-free step for stability) - # Full autograd TTT would require threading lora params through model. - # This version just applies one SGD step using the base loss signal. for li, proj in qv_keys: w_base = saved[(li, proj)] _set_w(li, proj, w_base) # restore for clean grad - - # Apply final LoRA and score new tokens for li, proj in qv_keys: w_base = saved[(li, proj)] delta = lora_B[(li, proj)] @ lora_A[(li, proj)] _set_w(li, proj, w_base + delta.astype(w_base.dtype)) - nll = model.token_losses(x, y) mx.eval(nll) nll_np = np.array(nll) - loss_sum += float(nll_np[0, s:wlen].sum()) token_count += float(wlen - s) tgt = y_np[0, s:wlen] @@ -1538,24 +1244,16 @@ def lora_loss(): tb = base_bytes_lut[tgt].astype(np.float64) tb += (has_leading_space_lut[tgt] & ~is_boundary_token_lut[prev]).astype(np.float64) byte_count += float(tb.sum()) - - # Restore original weights for li, proj in qv_keys: _set_w(li, proj, saved[(li, proj)]) - if log_fn and wi % 500 == 0: pct = wi / len(window_starts) * 100 rbpb = (loss_sum / max(token_count, 1)) / math.log(2.0) * (token_count / max(byte_count, 1)) log_fn(f"ttt_eval [{pct:.1f}%] {wi}/{len(window_starts)} bpb={rbpb:.4f}") - val_loss = loss_sum / token_count val_bpb = (val_loss / math.log(2.0)) * (token_count / byte_count) return val_loss, val_bpb - -# ============================================================================ -# MAIN -# ============================================================================ def main(): args = Hyperparameters() out_dir = Path(args.out_dir) @@ -1565,21 +1263,17 @@ def main(): def log(msg, console=True): if console: print(msg) with logfile.open("a") as f: print(msg, file=f) - code = Path(__file__).read_text() log(code, console=False) log(f"Running MLX {mx.__version__}", console=False) - sp = spm.SentencePieceProcessor(model_file=args.tokenizer_path) if int(sp.vocab_size()) != args.vocab_size: raise ValueError(f"VOCAB_SIZE mismatch: {args.vocab_size} vs {int(sp.vocab_size())}") dataset_name, actual_files, expected_files = validate_dataset_tokenizer_pair(args.data_path, args.tokenizer_path) val_tokens = load_validation_tokens(args.val_files, args.train_seq_len) base_bytes_lut, has_leading_space_lut, is_boundary_token_lut = build_sentencepiece_luts(sp, args.vocab_size) - mx.random.seed(args.seed) train_loader = TokenLoader(args.train_files, log_fn=log, dataset_name=dataset_name) - model = GPT( vocab_size=args.vocab_size, num_layers=args.num_layers, dim=args.model_dim, num_heads=args.num_heads, num_kv_heads=args.num_kv_heads, mlp_mult=args.mlp_mult, @@ -1596,13 +1290,8 @@ def log(msg, console=True): skipgram_hash_size=args.skipgram_hash_size, ) opt = SplitOptimizers(model, args) - - # EMA buffer — starts collecting after ema_start_frac of training ema = None - # SWA buffer — starts at 60% of training when USE_SWA=1 swa = None - - # Complementary training: precompute bigram stats once if needed bigram_probs_mx = None if args.complement_alpha > 0.0: log("complement_training: building bigram stats from training shards...") @@ -1611,8 +1300,6 @@ def log(msg, console=True): mx.eval(bigram_probs_mx) log(f"complement_training: bigram stats ready (alpha={args.complement_alpha})") del _bp_np - - # Wire up loss functions — use complementary loss when alpha > 0 if bigram_probs_mx is not None: _alpha = args.complement_alpha _bp = bigram_probs_mx @@ -1621,13 +1308,11 @@ def _loss_fn(x, y): else: def _loss_fn(x, y): return model.loss(x, y) - compiled_loss = mx.compile(_loss_fn, inputs=model.state, outputs=model.state) compiled_loss_and_grad = mx.compile( nn.value_and_grad(model, _loss_fn), inputs=model.state, outputs=model.state, ) - n_params = sum(int(np.prod(p.shape)) for _, p in tree_flatten(model.parameters())) log(f"run_id:{args.run_id}") xsa_layers = [i for i, b in enumerate(model.blocks) if b.attn.use_xsa] @@ -1646,8 +1331,6 @@ def _loss_fn(x, y): f"grad_accum:{args.grad_accum_steps} seq_len:{args.train_seq_len}") log(f"optimizer: muon_keys:{len(opt.matrix_keys)} scalar_keys:{len(opt.scalar_keys)}") log(f"val_tokens:{val_tokens.size - 1} train_shards:{actual_files}") - - # Warmup if args.warmup_steps > 0: for ws in range(args.warmup_steps): wl, wg = loss_and_grad_chunked(args, train_loader, compiled_loss_and_grad) @@ -1655,7 +1338,6 @@ def _loss_fn(x, y): mx.synchronize() if ws + 1 == args.warmup_steps: log(f"warmup_done:{args.warmup_steps} steps") - # Prime eval graph vbs = args.val_batch_size // args.grad_accum_steps vs = min(vbs // args.train_seq_len, (val_tokens.size - 1) // args.train_seq_len) wc = val_tokens[:vs * args.train_seq_len + 1] @@ -1664,51 +1346,38 @@ def _loss_fn(x, y): mx.eval(compiled_loss(xv, yv)) mx.synchronize() train_loader = TokenLoader(args.train_files, log_fn=log, dataset_name=dataset_name) - - # Training loop train_time_ms = 0.0 max_wc_ms = 1000.0 * args.max_wallclock_seconds if args.max_wallclock_seconds > 0 else None stop_after = None t0 = time.perf_counter() step = 0 _prev_use_qat = False # track QAT state to detect transition and recompile - while True: last_step = step == args.iterations or (stop_after is not None and step >= stop_after) - if last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0): train_time_ms += 1000.0 * (time.perf_counter() - t0) - # Apply SWA (preferred) or EMA weights for eval if available _avg = swa if swa is not None else ema if _avg is not None: saved_state = {k: mx.array(v) for k, v in tree_flatten(model.parameters())} _avg.apply(model) compiled_loss = mx.compile(_loss_fn, inputs=model.state, outputs=model.state) - model.use_qat = False # No QAT during eval val_loss, val_bpb = eval_val(args, compiled_loss, val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, log_fn=log) log(f"step:{step}/{args.iterations} val_loss:{val_loss:.4f} val_bpb:{val_bpb:.4f} " f"train_time:{train_time_ms:.0f}ms step_avg:{train_time_ms/max(step,1):.2f}ms") - if _avg is not None: model.update(tree_unflatten(list(saved_state.items()))) compiled_loss = mx.compile(_loss_fn, inputs=model.state, outputs=model.state) compiled_loss_and_grad = mx.compile( nn.value_and_grad(model, _loss_fn), inputs=model.state, outputs=model.state) - t0 = time.perf_counter() - if last_step: if stop_after is not None and step < args.iterations: log(f"stopping_early: wallclock train_time:{train_time_ms:.0f}ms step:{step}") break - - # Compute lr_mul early for late_qat_threshold lr_mul = args.lr_mul(step, train_time_ms + 1000.0 * (time.perf_counter() - t0)) - - # Toggle QAT when lr_mul drops below late_qat_threshold (warmdown-triggered) _new_use_qat = lr_mul < args.late_qat_threshold if _new_use_qat != _prev_use_qat: model.use_qat = _new_use_qat @@ -1720,8 +1389,6 @@ def _loss_fn(x, y): compiled_loss_and_grad = mx.compile( nn.value_and_grad(model, _loss_fn), inputs=model.state, outputs=model.state) - - # Initialize EMA after ema_start_frac; initialize SWA at 60% of iterations est_total = args.iterations if max_wc_ms and step > 0: est_total = min(args.iterations, int(max_wc_ms / (train_time_ms / step + 0.001))) @@ -1732,7 +1399,6 @@ def _loss_fn(x, y): swa = EMABuffer(model, decay=args.swa_decay) log(f"swa_started:step={step} decay={args.swa_decay}") step_t0 = time.perf_counter() - accum = None train_loss = mx.array(0.0, dtype=mx.float32) gs = 1.0 / args.grad_accum_steps @@ -1742,23 +1408,18 @@ def _loss_fn(x, y): train_loss = train_loss + loss.astype(mx.float32) * gs if args.mlx_eager_eval: mx.eval(train_loss, accum) - grads = tree_unflatten(list(accum.items())) grads = clip_grad_tree(grads, args.grad_clip_norm) tl = float(train_loss.item()) opt.step(model, grads, step=step, lr_mul=lr_mul) mx.synchronize() - - # Update EMA and SWA if ema is not None: ema.update(model) if swa is not None: swa.update(model) - step_ms = 1000.0 * (time.perf_counter() - step_t0) approx_ms = train_time_ms + 1000.0 * (time.perf_counter() - t0) step += 1 - if args.train_log_every > 0 and (step <= 5 or step % args.train_log_every == 0): tok_s = args.train_batch_tokens / (step_ms / 1000.0) qat_tag = " [QAT]" if model.use_qat else "" @@ -1766,27 +1427,19 @@ def _loss_fn(x, y): swa_tag = " [SWA]" if swa is not None else "" log(f"step:{step}/{args.iterations} train_loss:{tl:.4f} " f"step_ms:{step_ms:.0f} tok_s:{tok_s:.0f}{qat_tag}{ema_tag}{swa_tag}") - if max_wc_ms and stop_after is None and approx_ms >= max_wc_ms: stop_after = step - - # ======================================================================== - # SERIALIZE + ROUNDTRIP EVAL - # ======================================================================== - # Apply SWA (preferred) or EMA for final save if swa is not None: swa.apply(model) log("swa_applied_for_save") elif ema is not None: ema.apply(model) log("ema_applied_for_save") - model.use_qat = False flat_state = {k: v for k, v in tree_flatten(model.state)} out_path = out_dir / f"{args.run_id}_mlx_model.npz" mx.savez(str(out_path), **flat_state) log(f"saved_model:{out_path} bytes:{out_path.stat().st_size}") - quant_obj, quant_stats = quantize_state_dict_int6(flat_state, args) quant_raw = pickle.dumps(quant_obj, protocol=pickle.HIGHEST_PROTOCOL) quant_blob = zstandard.ZstdCompressor(level=22).compress(quant_raw) @@ -1795,8 +1448,6 @@ def _loss_fn(x, y): f.write(quant_blob) log(f"serialized_int6_zstd:{quant_path.stat().st_size} bytes " f"(payload:{quant_stats['int6_payload_bytes']} ratio:{quant_stats['baseline_tensor_bytes']/max(quant_stats['int6_payload_bytes'],1):.2f}x)") - - # Roundtrip eval with quant_path.open("rb") as f: quant_blob_disk = f.read() quant_flat = dequantize_state_dict_int6(pickle.loads(zstandard.ZstdDecompressor().decompress(quant_blob_disk))) @@ -1805,8 +1456,6 @@ def _loss_fn(x, y): eval_mode = args.eval_mode.lower().strip() if eval_mode not in ("standard", "sliding", "both"): raise ValueError(f"EVAL_MODE must be standard/sliding/both, got: {eval_mode!r}") - - # standard path (always run for "standard" or "both") if eval_mode in ("standard", "both"): qt0 = time.perf_counter() log("final_eval_mode:standard") @@ -1816,8 +1465,6 @@ def _loss_fn(x, y): log(f"final_int6_zstd_roundtrip_standard val_loss:{s_val_loss:.4f} val_bpb:{s_val_bpb:.4f} eval_time:{sms:.0f}ms") log(f"final_int6_zstd_roundtrip_standard_exact val_loss:{s_val_loss:.8f} val_bpb:{s_val_bpb:.8f}") q_val_loss, q_val_bpb = s_val_loss, s_val_bpb # used as fallback for the summary lines below - - # sliding path (run for "sliding" or "both") if eval_mode in ("sliding", "both"): qt0 = time.perf_counter() if args.ngram_mixer_enabled: @@ -1838,7 +1485,6 @@ def _loss_fn(x, y): log(f"final_int6_zstd_roundtrip val_loss:{q_val_loss:.4f} val_bpb:{q_val_bpb:.4f} eval_time:{qms:.0f}ms") log(f"final_int6_zstd_roundtrip_exact val_loss:{q_val_loss:.8f} val_bpb:{q_val_bpb:.8f}") elif eval_mode == "standard": - # Re-emit the canonical lines using standard result so log parsers see them regardless of mode qms = sms log(f"final_int6_zstd_roundtrip val_loss:{q_val_loss:.4f} val_bpb:{q_val_bpb:.4f} eval_time:{qms:.0f}ms") log(f"final_int6_zstd_roundtrip_exact val_loss:{q_val_loss:.8f} val_bpb:{q_val_bpb:.8f}")