From fe368b2b3581b3621b746ff75f76167c1cd78edb Mon Sep 17 00:00:00 2001 From: Noah Cylich Date: Sun, 8 Mar 2026 19:35:55 -0700 Subject: [PATCH 1/3] Learned matryoshka (topk) implementation Adds learned matryoshka FFN masking with saliency-initialized topk method,including warmup/learning/freeze phases, tau annealing, mask logits optimizer, and topk export support. Makes saliency topk the default matryoshka method. --- docs/learned_matryoshka_plan.md | 306 +++++++++++++++++ src/cli.py | 101 +++++- src/evaluate.py | 56 ++-- src/export.py | 59 +++- src/model.py | 95 +++--- src/run.py | 56 ++-- src/test.py | 105 ++---- src/train.py | 574 +++++++++++++++++++++++++++----- 8 files changed, 1102 insertions(+), 250 deletions(-) create mode 100644 docs/learned_matryoshka_plan.md diff --git a/docs/learned_matryoshka_plan.md b/docs/learned_matryoshka_plan.md new file mode 100644 index 0000000..4522be8 --- /dev/null +++ b/docs/learned_matryoshka_plan.md @@ -0,0 +1,306 @@ +# Learned Matryoshka (TopK FFN Masking) + +> **Note:** This document was ported from the `learned-mat-old` branch. The CLI args +> use the old `--mrl-*` naming — the current codebase uses `--mat-*` instead (e.g. +> `--mat-method topk`, `--mat-tau-start`, `--mat-mask-lr`). The architecture has also +> changed: masks are now **per-layer** (one mask per encoder/decoder block) and +> training uses the single-forward `forward_masked()` approach with heterogeneous +> `(n_blocks, batch, d_ff)` masks, rather than the N+1 `forward_with_aux()` approach +> described below. **The experimental results and lessons learned remain valid.** + +## Overview + +Extension of the FFN interior matryoshka (see `matryoshka_results.md`). Instead of fixed prefix masks (first k FFN neurons), we learn WHICH neurons to keep via differentiable top-k selection. The model discovers the optimal neuron selection rather than being constrained to a prefix. + +## How It Works + +The base FFN matryoshka uses `_make_ffn_mask(batch_size, d_ff, mrl_ff_slices)` to create a `(batch, d_ff)` mask where different batch items get different widths, each using a **prefix mask**: `arange < k`. + +Learned matryoshka replaces this with `_make_ffn_mask_topk(mask_logits, ...)` which uses **differentiable top-k** to select neurons: + +```python +def topk_mask(logits, k, tau, hard): + topk_vals = jax.lax.top_k(logits, k)[0] + threshold = topk_vals[-1] + y_soft = sigmoid((logits - threshold) / tau) + y_hard = (y_soft >= 0.5).astype(y_soft.dtype) + ste = y_hard - stop_gradient(y_soft) + y_soft + return where(hard, ste, y_soft) +``` + +- `mask_logits`: `(n_mrl, d_ff)` — learned per-width logit vectors +- `tau`: sigmoid temperature — controls softness of selection +- `hard`: boolean — use STE for exact binary masks +- Gradients flow through the soft sigmoid back to the logits + +The same heterogeneous-batch structure is used: batch is split into N sections (full + one per MRL width), each section gets its mask. The only difference is WHERE the mask comes from. + +## Training Schedule + +Two-phase (no warmup — MRL active from step 0): + +1. **Learning phase** (first 40% of steps): soft masks with tau annealing (1.0 → 0.2). Mask logits updated by their own Adam optimizer after each step. +2. **Freeze phase** (last 60% of steps): hard binary masks via STE. Mask logit optimizer frozen — the model adapts to the locked neuron selection. + +> **Current implementation note:** The codebase now supports a three-phase schedule +> with an explicit warmup phase (`--mat-warmup-frac`) before mask learning begins. + +## CLI Args + +> **Outdated naming.** See table below for current equivalents. +> +> | Old (`learned-mat-old`) | Current (`fixed-mrl`) | +> |---|---| +> | `--mrl-method topk` | `--mat-method topk` | +> | `--mrl-tau-start` | `--mat-tau-start` | +> | `--mrl-tau-end` | `--mat-tau-end` | +> | `--mrl-freeze-frac` | `--mat-freeze-frac` | +> | `--mrl-mask-lr` | `--mat-mask-lr` | +> | `--mrl-spread-lambda` | `--mat-spread-lambda` | +> | `--mrl-init-mode` | `--mat-init-mode` | +> | `--mrl-warmup-frac` | `--mat-warmup-frac` | +> | `--mrl-saliency-scale` | `--mat-saliency-scale` | +> | *(N/A)* | `--mat-gumbel` (new: per-item Gumbel noise) | + +``` +--mrl-method topk # Enable learned masks (default: prefix) +--mrl-tau-start 1.0 # Initial sigmoid temperature +--mrl-tau-end 0.2 # Final temperature before freeze +--mrl-freeze-frac 0.6 # Fraction of training with hard frozen masks +--mrl-mask-lr 0.003 # Mask logit optimizer learning rate +--mrl-spread-lambda 0.01 # Spread penalty weight +--mrl-init-mode shuffled_prefix # Mask logit initialization (default) +``` + +## Key Design Decisions + +1. **FFN interior masking**: Masks applied inside FeedForward on d_ff intermediate. d_model is constant. Only FFN params are reduced per sub-model. +2. **Shuffled init**: Randomly permuted ramp (same value distribution as prefix, different neuron order). FFN neurons are symmetric, so init order shouldn't matter. +3. **tau/hard as JAX array args**: Passed through pmap as replicated scalars, not Python globals (globals are captured at compile time and never update). +4. **Mask optimizer host-side**: `mask_opt_state` stored as plain numpy, not replicated across devices. +5. **Eval uses learned masks**: `forward_with_aux` accepts `mrl_ffn_masks` to evaluate with the learned neuron selection, not hardcoded prefix. +6. **Spread penalty**: `−λ · mean(var(logits))` encourages clear on/off decisions. λ=0.01 matches matformer-olmo R6. + +## Results + +All experiments use identical config matching the prefix baseline in `matryoshka_results.md`: +- 87.7M params: d=512, d_ff=2048, heads=8 (kv=8), enc=8, dec=4, memory=64 +- batch=32×8=256, 11038 total steps, seed=42, sparsity=50%, speech every 3 +- TopK defaults: shuffled_prefix init, tau 1.0→0.2, freeze=0.6, mask_lr=3e-3, spread λ=0.01 + +> **Important architectural difference:** All experiments below were run on the old +> `learned-mat` branch, which applied matryoshka masking **only to normal FeedForward +> layers** (encoder local_ffn and decoder FFN). The MLPMixer FeedForwards (token_mix and +> channel_mix) were **unmasked**. The current codebase (`fixed-mrl`) applies ffn_mask to +> ALL FeedForward instances including the mixer, so these results are not directly +> comparable. The mixer masking may change optimal hyperparameters. + +### E1: TopK vs Prefix Baseline (1 epoch) + +| Width | d_ff | Prefix Baseline | **TopK (E1)** | +|-------|------|-----------------|---------------| +| Full | 2048 | **4.66** | 4.93 | +| 256 | 1024 | **4.66** | 4.95 | +| 128 | 512 | **4.68** | 5.03 | +| 64 | 256 | **4.80** | 5.19 | + +TopK is +0.27 worse at full, +0.39 worse at d=64. The learned mask has extra overhead vs fixed prefix. + +### E2: Freeze Fraction Ablation + +| Width | d_ff | freeze=0.4 | freeze=0.5 | **freeze=0.6** | freeze=0.7 | +|-------|------|------------|------------|----------------|------------| +| Full | 2048 | 5.08 | 4.98 | **4.93** | 4.94 | +| 256 | 1024 | 5.14 | 5.01 | **4.95** | 4.96 | +| 128 | 512 | 5.25 | 5.12 | **5.03** | 5.07 | +| 64 | 256 | 5.43 | 5.29 | **5.19** | 5.24 | + +freeze=0.6 is best overall. Too little freeze (0.4) hurts most. 0.7 is slightly worse than 0.6 (less learning time). + +### E3: Tau Schedule Ablation + +| Width | d_ff | tau 0.5→0.1 | **tau 1.0→0.2 (default)** | tau 2.0→0.3 | +|-------|------|-------------|---------------------------|-------------| +| Full | 2048 | 4.79 | 4.93 | **4.76** | +| 256 | 1024 | 4.82 | 4.95 | **4.78** | +| 128 | 512 | 4.93 | 5.03 | **4.88** | +| 64 | 256 | 5.09 | 5.19 | **5.02** | + +Both alternatives beat the default! tau 2.0→0.3 (softer start, softer end) is best. tau 0.5→0.1 (sharper) is also strong. The default 1.0→0.2 is the worst of the three. + +### E4: Mask LR Ablation + +| Width | d_ff | lr=1e-3 | **lr=3e-3 (default)** | lr=1e-2 | +|-------|------|---------|-----------------------|---------| +| Full | 2048 | **4.80** | 4.93 | 5.10 | +| 256 | 1024 | **4.85** | 4.95 | 5.13 | +| 128 | 512 | **4.97** | 5.03 | 5.26 | +| 64 | 256 | **5.12** | 5.19 | 5.43 | + +Lower LR (1e-3) is best. Higher LR (1e-2) hurts significantly. Default 3e-3 is middle. + +### E5: Multi-Epoch (2 epochs) + +| Width | d_ff | TopK 1 epoch | **TopK 2 epochs** | +|-------|------|--------------|--------------------| +| Full | 2048 | 4.93 | **4.36** | +| 256 | 1024 | 4.95 | **4.37** | +| 128 | 512 | 5.03 | **4.44** | +| 64 | 256 | 5.19 | **4.55** | + +2 epochs dramatically improves everything. The 2-epoch d=64 (4.55) beats the 1-epoch prefix baseline d=64 (4.80). + +### Summary: Best Single-Epoch Config + +Based on the ablations, the optimal single-variable changes from default are: +- **tau 2.0→0.3** (E3): biggest win, −0.17 at d=64 +- **lr 1e-3** (E4): −0.07 at d=64 +- **freeze 0.6** (E2): already the default, confirmed best + +A combined run with tau 2.0→0.3 + lr 1e-3 could potentially close the gap to prefix baseline further. + +### C1-C5: Hyperparameter Combinations + +All use combined best base: tau 2.0→0.3 + lr 1e-3 + freeze 0.6 + shuffled_prefix init. + +| Width | d_ff | C1 (base) | C2 (f=0.7) | C3 (f=0.5) | **C4 (λ=0)** | C5 (λ=0.005) | +|-------|------|-----------|------------|------------|--------------|--------------| +| Full | 2048 | 5.03 | 5.10 | 4.99 | **4.82** | 5.03 | +| 256 | 1024 | 5.10 | 5.13 | 5.03 | **4.84** | 5.06 | +| 128 | 512 | 5.23 | 5.24 | 5.14 | **4.96** | 5.19 | +| 64 | 256 | 5.40 | 5.42 | 5.32 | **5.12** | 5.35 | + +**Key finding: Removing the spread penalty (C4, λ=0) is the single biggest improvement.** C4 achieves 4.82 full / 5.12 d=64, beating all previous single-epoch topk runs and approaching the prefix baseline (4.66 / 4.80). The spread penalty was actively hurting — it pushes logits apart but interferes with the tau annealing schedule. + +C1 (combined best with default spread) underperforms the individual E3/E4 ablations — the combination doesn't stack. This is because the spread penalty (λ=0.01) interacts poorly with tau 2.0→0.3. + +Freeze=0.5 (C3) beats 0.6 (C1) and 0.7 (C2) when combined with slow tau, suggesting the softer tau schedule benefits from more learning time. + +### Summary: Best Overall Single-Epoch TopK Config + +| Width | d_ff | Prefix Baseline | Best TopK (C4) | Gap | +|-------|------|-----------------|----------------|-----| +| Full | 2048 | **4.66** | 4.82 | +0.16 | +| 256 | 1024 | **4.66** | 4.84 | +0.18 | +| 128 | 512 | **4.68** | 4.96 | +0.28 | +| 64 | 256 | **4.80** | 5.12 | +0.32 | + +Best topk config: `--mrl-tau-start 2.0 --mrl-tau-end 0.3 --mrl-freeze-frac 0.6 --mrl-mask-lr 0.001 --mrl-spread-lambda 0.0` + +> **Current equivalent:** `--mat-method topk --mat-tau-start 2.0 --mat-tau-end 0.3 --mat-freeze-frac 0.6 --mat-mask-lr 0.001 --mat-spread-lambda 0.0` + +TopK is still ~0.16-0.32 behind prefix at 1 epoch, but 2-epoch topk (E5: 4.55 d=64) already beats 1-epoch prefix (4.80 d=64). + +--- + +## Full-Matryoshka Experiments (with mixer masking) + +All experiments below apply `ffn_mask` to ALL FeedForward layers including MLPMixer +token_mix and channel_mix. Baselines from `matryoshka_results.md`: +- **Prefix baseline** (full mat): 4.82 / 4.82 / 4.85 / 4.97 (full/2x/4x/8x) +- **No-matryoshka**: 4.52 (full only, no sub-models) + +### R1–R5: Re-validation of Prior Findings + +All use seed=42, 1 epoch, 88.5M params. Base topk defaults: tau 0.5→0.1, lr 3e-3, +λ=0.01, warmup 15%, freeze 20%, normal init. + +| Width | d_ff | Prefix | R1 (default) | R2 (C4) | R3 (tau 0.5→0.1) | R4 (lr 3e-3) | R5 (λ=0.01) | +|-------|------|--------|-------------|---------|-------------------|-------------|-------------| +| Full | 2048 | 4.82 | 4.69 | 4.72 | 4.70 | 4.72 | 4.74 | +| 2x | 1024 | 4.82 | 4.94 | 4.76 | 4.75 | 4.74 | 4.78 | +| 4x | 512 | 4.85 | 5.40 | 4.92 | 4.88 | 4.85 | 4.96 | +| 8x | 256 | 4.97 | 5.81 | 5.17 | 5.05 | **5.01** | 5.14 | + +**R1** = default topk (tau 0.5→0.1, lr 3e-3, λ=0.01, f=0.2, normal init). +**R2** = old C4 config (tau 2.0→0.3, lr 1e-3, λ=0, f=0.6). +**R3** = R2 but tau 0.5→0.1 (sharper). **R4** = R3 but lr 3e-3. **R5** = R4 but λ=0.01. + +**Key finding shifts from old experiments (without mixer masking):** +- **Tau 0.5→0.1 (sharper) is now better** than 2.0→0.3 (softer). Reversed from before. +- **lr 3e-3 is now slightly better** than 1e-3 at sub-models. Reversed from before. +- **Spread penalty (λ=0.01) still hurts.** Confirmed. +- **Best non-saliency 1-epoch config**: tau 0.5→0.1, lr 3e-3, λ=0, f=0.6 (R4). + +### N1–N3: Saliency Scale Sweep + +Saliency init with best R-series config (tau 0.5→0.1, lr 3e-3, λ=0, f=0.6, 10% warmup). + +| Width | d_ff | R4 (normal init) | N1 (scale=1.0) | N2 (scale=0.5) | N3 (scale=2.0) | +|-------|------|-----------------|----------------|----------------|----------------| +| Full | 2048 | 4.72 | **4.62** | 4.72 | 4.64 | +| 2x | 1024 | 4.74 | **4.63** | 4.74 | 4.64 | +| 4x | 512 | 4.85 | **4.72** | 4.78 | 4.73 | +| 8x | 256 | 5.01 | **4.88** | 4.94 | **4.88** | + +**Saliency init is transformative.** Scale=1.0 and 2.0 are essentially tied; scale=0.5 +is too weak. All saliency runs beat prefix at every width. + +### N7: Multi-Epoch (2 epochs, non-saliency) + +| Width | d_ff | R4 (1 epoch) | N7 (2 epochs) | +|-------|------|-------------|--------------| +| Full | 2048 | 4.72 | **4.49** | +| 2x | 1024 | 4.74 | **4.50** | +| 4x | 512 | 4.85 | **4.63** | +| 8x | 256 | 5.01 | **4.76** | + +### S1–S10: Saliency Deep Dive + +All use: tau 0.5→0.1, lr 3e-3, λ=0, f=0.6, saliency scale=1.0, seed=42. + +| ID | Experiment | Full | 2x | 4x | 8x | +|----|-----------|------|-----|-----|-----| +| S1 | Saliency-only (f=1.0, 10% warmup) | 4.67 | 4.67 | 4.73 | 4.88 | +| S2 | Saliency + freeze=0.5, 10% warmup | 4.61 | 4.63 | 4.73 | 4.93 | +| S3 | Saliency + Gumbel, 10% warmup | 4.66 | 4.66 | 4.72 | 4.86 | +| S4 | Saliency, **20% warmup** | 4.60 | 4.61 | 4.69 | 4.85 | +| S5 | Saliency, **30% warmup** | 4.60 | 4.61 | 4.68 | 4.84 | +| **S6** | **Saliency, 40% warmup** | **4.59** | **4.60** | **4.67** | **4.84** | +| S7 | Saliency, 50% warmup | 4.61 | 4.62 | 4.71 | 4.89 | +| S8 | Saliency + Gumbel, 40% warmup | 4.63 | 4.64 | 4.71 | 4.88 | +| S9 | Saliency-only (f=1.0), 40% warmup | **4.59** | **4.60** | **4.67** | **4.84** | +| **S10** | **Saliency, 40% warmup, 2 epochs** | **4.38** | **4.39** | **4.45** | **4.60** | + +### Summary: Best Configs + +| Config | Full | 2x | 4x | 8x | Notes | +|--------|------|-----|-----|-----|-------| +| No matryoshka | 4.52 | - | - | - | No sub-models | +| Prefix (static) | 4.82 | 4.82 | 4.85 | 4.97 | Simple, no learning | +| **Best topk 1ep (S9)** | **4.59** | **4.60** | **4.67** | **4.84** | Saliency-only, 40% warmup, no mask learning | +| Best topk 2ep (S10) | 4.38 | 4.39 | 4.45 | 4.60 | S9 config at 2 epochs | + +**Optimal config (S9 — saliency-only)**: `--mat-method topk --mat-init-mode saliency --mat-saliency-scale 1.0 --mat-warmup-frac 0.4 --mat-freeze-frac 1.0` + +S9 is the recommended default: saliency picks the neurons during 40% warmup, then hard masks are frozen for the remaining 60%. No mask optimizer, no tau annealing — simplest and best. Mask learning (S6, freeze=0.6) achieves identical results with more complexity. + +### Experiments Still Needed + +#### Export Verification + +| ID | Experiment | Details | +|----|-----------|---------| +| N9 | **Export correctness** | Extract learned hard masks, slice FFN weights, verify: loads, PPL matches masked eval, size matches `_estimate_mrl_params`, works for all 3 widths (256, 128, 64) | +| N10 | **Export + quantization** | INT4 quantization on exported sub-models (smaller d_ff may interact with group_size=32 alignment) | + +## Key Lessons Learned + +### From old experiments (without mixer masking) +1. **Spread penalty hurts.** λ=0 was the single biggest improvement (C4). The penalty fights the tau schedule. +2. **Multi-epoch is transformative.** 2 epochs closes the gap to prefix baseline entirely. +3. **Combinations don't always stack.** C1 (tau+lr combined) was worse than individual ablations due to spread penalty interaction. + +### New findings (with mixer masking) +4. **Saliency init is the most important hyperparameter.** It alone provides +0.1–0.13 PPL improvement over normal init at every width. +5. **Warmup fraction matters more than expected.** 40% warmup is optimal — gives the saliency estimator enough data to rank neurons accurately. Below 20% is noticeably worse; above 50% wastes too much training time on full-model-only. +6. **Mask learning adds almost nothing with good saliency.** S9 (saliency-only, freeze=1.0) matches S6 (saliency + learning). The gradient-importance ranking is already near-optimal. +7. **Gumbel noise hurts.** Per-item noise during soft phase adds variance without benefit. Saliency provides enough exploration. +8. **Tau and LR priors reversed with mixer masking.** Sharper tau (0.5→0.1) and higher LR (3e-3) now beat their softer/lower counterparts. More masked parameters may need faster convergence. +9. **Spread penalty still hurts.** Confirmed across all configs. +10. **2-epoch saliency (S10) is the best model overall.** 4.38 full PPL beats the no-matryoshka baseline (4.52) by 0.14, while providing 3 additional sub-models down to 8x compression. +11. **Saliency is text-only.** Speech warmup gradients are currently discarded — saliency ranks FFN neurons based on text steps only. This is acceptable since text is the primary task, but for speech-heavy deployments the saliency ranking may under-represent speech-relevant neurons. A future experiment could accumulate speech gradients into the same saliency buffer and compare sub-model speech PPL. + +## Historical Notes + +Earlier iterations (on the `learned-mat-old` branch) applied masks to d_model at the output logit projection, not inside FFN. This was fundamentally wrong — sub-models shared all internal computation. Evidence: shuffled_prefix gave different results than prefix with output-only masking (d=128: 4.63 vs 4.55 PPL), proving the model exploited contiguous prefix structure rather than learning genuine sub-networks. Moving to FFN interior masking fixed this. diff --git a/src/cli.py b/src/cli.py index 928ae95..6d1f19f 100644 --- a/src/cli.py +++ b/src/cli.py @@ -1,7 +1,90 @@ import argparse import sys -HELP = """Check the readme""" +HELP = """ + ┌───────────────────────────────────────────────────────────────────┐ + │ │ + │ ┌─┐┌─┐┌─┐┌┬┐┬ ┬┌─┐ ┌┐┌┌─┐┌─┐┌┬┐┬ ┌─┐ │ + │ │ ├─┤│ │ │ │└─┐ │││├┤ ├┤ │││ ├┤ │ + │ └─┘┴ ┴└─┘ ┴ └─┘└─┘ ┘└┘└─┘└─┘─┴┘┴─┘└─┘ │ + │ ...the tiny model to rule them all... │ + │ │ + │ train │ + │ --full Use full 1B config (~1.17B params) │ + │ --epochs INT Training epochs (default: 1) │ + │ --batch-size INT Batch size (default: 32) │ + │ --lr FLOAT AdamW learning rate (default: 3e-4) │ + │ --muon-lr FLOAT Muon learning rate (default: 0.02) │ + │ --d-model INT Model dim (default: 512) │ + │ --num-heads INT Attention heads (default: 8) │ + │ --num-kv-heads INT KV heads for GQA (default: num-heads)│ + │ --num-layers INT Encoder layers (default: 8) │ + │ --num-dec-layers INT Decoder layers (default: 4) │ + │ --max-enc-len INT Max encoder seq length (default: 256)│ + │ --max-dec-len INT Max decoder seq length (default: 256)│ + │ --max-samples INT Training samples (default: all) │ + │ --mat-factors INT [...] FFN shrink factors (default: 2 4 8) │ + │ --mat-method STR static-prefix|topk (default: topk) │ + │ --mat-init-mode STR saliency|prefix|normal (def: sal.) │ + │ --mat-warmup-frac FL Saliency warmup fraction (def: 0.4) │ + │ --mat-freeze-frac FL Mask freeze fraction (default: 1.0) │ + │ --mat-tau-start FLOAT TopK tau start (default: 0.5) │ + │ --mat-tau-end FLOAT TopK tau end (default: 0.1) │ + │ --mat-mask-lr FLOAT Mask logit LR (default: 3e-3) │ + │ --sparsity-ratio FLOAT Block prune ratio (default: 0.5) │ + │ --group-size INT Quant/prune group size (default: 32) │ + │ --prune-interval INT Steps between mask updates (def: 100)│ + │ --prune-start-frac FL Start pruning at frac (def: 0.33) │ + │ --prune-end-frac FL Lock mask at this frac (def: 0.67) │ + │ --activation STR drelu|swiglu|geglu (default: drelu) │ + │ --warmup-ratio FLOAT LR warmup ratio (default: 0.05) │ + │ --eval-every INT Val eval interval (default: 1000) │ + │ --wandb Enable W&B logging │ + │ --checkpoint PATH Resume from checkpoint │ + │ --checkpoint-dir DIR Checkpoint directory │ + │ --seed INT Random seed (default: 42) │ + │ --no-speech Disable speech (text-only training) │ + │ --speech-every INT Speech step every N text (default: 3) │ + │ --max-mel-len INT Max mel frames (default: 1024) │ + │ --n-mels INT Mel frequency bins (default: 80) │ + │ --max-speech-samples INT Max LibriSpeech samples │ + │ │ + │ run │ + │ --checkpoint PATH Path to model checkpoint (required) │ + │ --query STR Query text for tool-call generation │ + │ --tools STR Tools JSON for tool-call generation │ + │ --audio PATH [...] Audio files for voice-to-tool-call │ + │ --max-len INT Max tokens to generate (default: 512) │ + │ --seed INT Random seed (default: 0) │ + │ │ + │ test │ + │ --checkpoint PATH Path to model checkpoint (required) │ + │ --batch-size INT Batch size (default: 32) │ + │ --max-eval-samples INT Evaluation samples (default: 1000) │ + │ --max-gen-len INT Max generation length (default: 512) │ + │ --tool-call-samples INT Tool-call accuracy samples (def: 200) │ + │ --voice-tc-samples INT Voice-tool-call samples (default: 50) │ + │ --throughput-runs INT Throughput runs (default: 10) │ + │ │ + │ evaluate │ + │ --checkpoint PATH Path to model checkpoint (required) │ + │ --benchmarks [...] wikitext2 lambada hellaswag arc_easy │ + │ --max-samples INT Samples per benchmark (default: 500) │ + │ │ + │ tpu │ + │ create NAME Create TPU (auto-finds zone) │ + │ --type STR Accelerator (default: v6e-8) │ + │ --version STR TPU OS (auto-detected from --type) │ + │ connect NAME SSH config + connect (auto-zone) │ + │ claude NAME Install Claude Code on instance │ + │ stop NAME Stop instance (auto-zone) │ + │ start NAME Start stopped instance (auto-zone) │ + │ delete NAME Delete instance (auto-zone) │ + │ list List all TPU instances │ + │ --zone ZONE Override auto-detected zone │ + │ │ + └───────────────────────────────────────────────────────────────────┘ +""" def main(): if len(sys.argv) < 2 or sys.argv[1] in ("-h", "--help", "help"): @@ -47,6 +130,22 @@ def main(): help="Matryoshka FFN shrink factors, e.g. 2=half width (default: 2 4 8)") p.add_argument("--mat-shared-input", action="store_true", help="Each unique input is repeated across all mat widths (default: unique input per width)") + p.add_argument("--mat-method", choices=["static-prefix", "topk"], default="topk", + help="Matryoshka method: 'static-prefix' (fixed first-N masks), 'topk' (saliency-based masks, default)") + p.add_argument("--mat-tau-start", type=float, default=0.5) + p.add_argument("--mat-tau-end", type=float, default=0.1) + p.add_argument("--mat-init-mode", choices=["prefix", "shuffled_prefix", "saliency", "normal", "zeros"], default="saliency") + p.add_argument("--mat-init-value", type=float, default=0.5) + p.add_argument("--mat-spread-lambda", type=float, default=0.0) + p.add_argument("--mat-warmup-frac", type=float, default=0.4, + help="Fraction of total steps for vanilla warmup (no masks)") + p.add_argument("--mat-freeze-frac", type=float, default=1.0, + help="Fraction of total steps at end with frozen hard masks") + p.add_argument("--mat-mask-lr", type=float, default=3e-3, + help="Mask logit optimizer learning rate (default: 3e-3)") + p.add_argument("--mat-saliency-scale", type=float, default=1.0) + p.add_argument("--mat-gumbel", action="store_true", + help="Use Gumbel noise for per-item mask diversity during topk learning") p.add_argument("--no-speech", action="store_true", help="Disable speech training (text-only)") p.add_argument("--max-mel-len", type=int, default=1024, help="Max mel spectrogram frames (default: 1024)") diff --git a/src/evaluate.py b/src/evaluate.py index 911b853..a30dbc1 100644 --- a/src/evaluate.py +++ b/src/evaluate.py @@ -10,44 +10,46 @@ from .data import get_tokenizer from .model import ( EncoderDecoderTransformer, - TransformerConfig, make_causal_mask, make_padding_mask, ) from .run import load_checkpoint -def _make_p_encode(model): +def _make_p_encode(model, enc_ffn=None): """Create a pmap'd encode function.""" def _encode(params, src, src_mask): return model.apply( - {"params": params}, src, src_mask=src_mask, method="encode", + {"params": params}, src, src_mask=src_mask, ffn_mask=enc_ffn, method="encode", ) return jax.pmap(_encode, axis_name="batch") -def _make_p_decode(model): +def _make_p_decode(model, dec_ffn=None): """Create a pmap'd decode function.""" def _decode(params, dec_input, encoder_out, tgt_mask, _unused_cross_mask): return model.apply( {"params": params}, dec_input, encoder_out, - self_mask=tgt_mask, method="decode", + self_mask=tgt_mask, ffn_mask=dec_ffn, method="decode", ) return jax.pmap(_decode, axis_name="batch") def _shard_single(x, num_devices): """Replicate a single-sample batch across all devices for pmap.""" - return jnp.broadcast_to(x, (num_devices, *x.shape[1:])) + return jnp.broadcast_to(x[None], (num_devices, *x.shape)) def score_sequence(model, params, enc_tokens, dec_tokens, pad_id, sos_id=None, - p_encode=None, p_decode=None, num_devices=1): + p_encode=None, p_decode=None, num_devices=1, ffn_mask=None): """Compute average negative log-likelihood of dec_tokens given enc_tokens.""" sos = sos_id if sos_id is not None else pad_id enc_input = jnp.array([enc_tokens]) src_mask = make_padding_mask(enc_input, pad_id) + enc_ffn = ffn_mask["encoder"] if ffn_mask else None + dec_ffn = ffn_mask["decoder"] if ffn_mask else None + if p_encode is not None and num_devices > 1: enc_s = _shard_single(enc_input, num_devices) src_mask_s = _shard_single(src_mask, num_devices) @@ -55,7 +57,7 @@ def score_sequence(model, params, enc_tokens, dec_tokens, pad_id, sos_id=None, else: p = params if num_devices <= 1 else jax_utils.unreplicate(params) encoder_out = model.apply( - {"params": p}, enc_input, src_mask=src_mask, method="encode", + {"params": p}, enc_input, src_mask=src_mask, ffn_mask=enc_ffn, method="encode", ) dec_in = [sos] + list(dec_tokens[:-1]) @@ -71,7 +73,7 @@ def score_sequence(model, params, enc_tokens, dec_tokens, pad_id, sos_id=None, p = params if num_devices <= 1 else jax_utils.unreplicate(params) logits = model.apply( {"params": p}, dec_input, encoder_out, - self_mask=tgt_mask, method="decode", + self_mask=tgt_mask, ffn_mask=dec_ffn, method="decode", )[0] log_probs = jax.nn.log_softmax(logits if logits.ndim == 2 else logits[0]) @@ -81,7 +83,7 @@ def score_sequence(model, params, enc_tokens, dec_tokens, pad_id, sos_id=None, def eval_wikitext2(model, params, tokenizer, max_samples=500, max_len=256, - num_devices=1, p_encode=None, p_decode=None): + num_devices=1, p_encode=None, p_decode=None, ffn_mask=None): """Perplexity on WikiText-2 test split.""" ds = load_dataset("Salesforce/wikitext", "wikitext-2-raw-v1", split="test") @@ -91,6 +93,8 @@ def eval_wikitext2(model, params, tokenizer, max_samples=500, max_len=256, total_tokens = 0 evaluated = 0 + enc_ffn = ffn_mask["encoder"] if ffn_mask else None + dec_ffn = ffn_mask["decoder"] if ffn_mask else None single_params = jax_utils.unreplicate(params) if num_devices > 1 else params for example in ds: @@ -116,7 +120,7 @@ def eval_wikitext2(model, params, tokenizer, max_samples=500, max_len=256, encoder_out = p_encode(params, enc_s, src_mask_s)[0:1] else: encoder_out = model.apply( - {"params": single_params}, enc_input, src_mask=src_mask, method="encode", + {"params": single_params}, enc_input, src_mask=src_mask, ffn_mask=enc_ffn, method="encode", ) dec_in = [sos_id] + list(dec_tokens[:-1]) @@ -131,7 +135,7 @@ def eval_wikitext2(model, params, tokenizer, max_samples=500, max_len=256, else: logits = model.apply( {"params": single_params}, dec_input, encoder_out, - self_mask=tgt_mask, method="decode", + self_mask=tgt_mask, ffn_mask=dec_ffn, method="decode", ) log_probs = jax.nn.log_softmax(logits[0] if logits.ndim == 3 else logits[0]) @@ -150,7 +154,7 @@ def eval_wikitext2(model, params, tokenizer, max_samples=500, max_len=256, def eval_lambada(model, params, tokenizer, max_samples=500, - num_devices=1, p_encode=None, p_decode=None): + num_devices=1, p_encode=None, p_decode=None, ffn_mask=None): """Accuracy of predicting the final word on LAMBADA.""" ds = load_dataset("EleutherAI/lambada_openai", "default", split="test") @@ -159,6 +163,8 @@ def eval_lambada(model, params, tokenizer, max_samples=500, correct = 0 total = 0 + enc_ffn = ffn_mask["encoder"] if ffn_mask else None + dec_ffn = ffn_mask["decoder"] if ffn_mask else None single_params = jax_utils.unreplicate(params) if num_devices > 1 else params for example in ds: @@ -183,7 +189,7 @@ def eval_lambada(model, params, tokenizer, max_samples=500, encoder_out = p_encode(params, enc_s, src_mask_s)[0:1] else: encoder_out = model.apply( - {"params": single_params}, enc_input, src_mask=src_mask, method="encode", + {"params": single_params}, enc_input, src_mask=src_mask, ffn_mask=enc_ffn, method="encode", ) dec_in = jnp.array([[sos_id]]) @@ -197,7 +203,7 @@ def eval_lambada(model, params, tokenizer, max_samples=500, else: logits = model.apply( {"params": single_params}, dec_in, encoder_out, - self_mask=tgt_mask, method="decode", + self_mask=tgt_mask, ffn_mask=dec_ffn, method="decode", ) predicted = int(jnp.argmax(logits[0, 0] if logits.ndim == 3 else logits[0, 0])) @@ -213,7 +219,7 @@ def eval_lambada(model, params, tokenizer, max_samples=500, def eval_hellaswag(model, params, tokenizer, max_samples=500, - num_devices=1, p_encode=None, p_decode=None): + num_devices=1, p_encode=None, p_decode=None, ffn_mask=None): """Accuracy on HellaSwag by scoring each candidate ending.""" ds = load_dataset("Rowan/hellaswag", split="validation") @@ -242,7 +248,7 @@ def eval_hellaswag(model, params, tokenizer, max_samples=500, dec_tokens = dec_tokens[:64] score = score_sequence(model, params, enc_tokens, dec_tokens, pad_id, sos_id=sos_id, p_encode=p_encode, p_decode=p_decode, - num_devices=num_devices) + num_devices=num_devices, ffn_mask=ffn_mask) scores.append(score) predicted = int(np.argmax(scores)) @@ -258,7 +264,7 @@ def eval_hellaswag(model, params, tokenizer, max_samples=500, def eval_arc_easy(model, params, tokenizer, max_samples=500, - num_devices=1, p_encode=None, p_decode=None): + num_devices=1, p_encode=None, p_decode=None, ffn_mask=None): """Accuracy on ARC-Easy by scoring each answer choice.""" ds = load_dataset("allenai/ai2_arc", "ARC-Easy", split="test") @@ -288,7 +294,7 @@ def eval_arc_easy(model, params, tokenizer, max_samples=500, dec_tokens = dec_tokens[:64] score = score_sequence(model, params, enc_tokens, dec_tokens, pad_id, sos_id=sos_id, p_encode=p_encode, p_decode=p_decode, - num_devices=num_devices) + num_devices=num_devices, ffn_mask=ffn_mask) scores.append(score) predicted_idx = int(np.argmax(scores)) @@ -317,7 +323,7 @@ def main(args): print(f"Detected {num_devices} device(s) for data-parallel evaluation") print(f"Loading checkpoint: {args.checkpoint}") - params, config = load_checkpoint(args.checkpoint) + params, config, ffn_mask = load_checkpoint(args.checkpoint) model = EncoderDecoderTransformer(config) tokenizer = get_tokenizer() @@ -325,10 +331,15 @@ def main(args): print(f"Model parameters: {param_count:,}") # Replicate params across devices for pmap + enc_ffn = ffn_mask["encoder"] if ffn_mask else None + dec_ffn = ffn_mask["decoder"] if ffn_mask else None + if ffn_mask: + print(f"sub-model: topk FFN masking active") + if num_devices > 1: params = jax_utils.replicate(params) - p_encode = _make_p_encode(model) - p_decode = _make_p_decode(model) + p_encode = _make_p_encode(model, enc_ffn=enc_ffn) + p_decode = _make_p_decode(model, dec_ffn=dec_ffn) print(f"Params replicated across {num_devices} devices") else: p_encode = None @@ -349,6 +360,7 @@ def main(args): result = BENCHMARKS[name]( model, params, tokenizer, max_samples=args.max_samples, num_devices=num_devices, p_encode=p_encode, p_decode=p_decode, + ffn_mask=ffn_mask, ) results[name] = result diff --git a/src/export.py b/src/export.py index 8b5d00e..60c6f36 100644 --- a/src/export.py +++ b/src/export.py @@ -2,6 +2,8 @@ With FFN interior matryoshka, d_model stays constant — only FFN intermediate dimensions (gate_proj, up_proj, down_proj) are sliced. + +For topk-trained models, exports full model + mask indices per factor. """ import os @@ -19,10 +21,10 @@ def export_submodel(checkpoint_path, factor, output_path): - """Slice a full matryoshka checkpoint to a sub-model at given shrink factor. + """Export a matryoshka sub-model from a full checkpoint. - factor: how many times smaller the FFN width (e.g. 2 = half, 4 = quarter). - Attention, embeddings, and norms are unchanged. + For prefix-trained models: slices FFN weights to create a smaller d_ff. + For topk-trained models: saves full model + binary mask indices per factor. """ with open(checkpoint_path, "rb") as f: @@ -30,6 +32,57 @@ def export_submodel(checkpoint_path, factor, output_path): params = data["params"] config = TransformerConfig(**data["config"]) + mat_method = data.get("mat_method", "static-prefix") + if mat_method == "topk" and "mask_logits" in data: + return _export_topk(data, params, config, factor, output_path) + else: + return _export_prefix(params, config, factor, output_path) + + +def _export_topk(data, params, config, factor, output_path): + """TopK export: full model + per-layer binary mask indices for FFN masking.""" + mask_logits = np.asarray(data["mask_logits"]) # (n_mat, n_blocks, d_ff) + mat_factors = data.get("mat_factors", []) + if factor not in mat_factors: + raise ValueError(f"factor={factor} not found in mat_factors={mat_factors}") + + factor_logits = mask_logits[mat_factors.index(factor)] # (n_blocks, d_ff) + ff_w = config.d_ff // factor + per_layer_indices = [np.sort(np.argsort(-factor_logits[b])[:ff_w]) for b in range(factor_logits.shape[0])] + + params_np = jax.tree.map( + lambda x: np.asarray(x) if isinstance(x, jnp.ndarray) else x, params + ) + + os.makedirs(os.path.dirname(output_path) or ".", exist_ok=True) + with open(output_path, "wb") as f: + pickle.dump({ + "params": params_np, + "config": config.__dict__, + "mat_mask_indices": per_layer_indices, + "mat_factor": factor, + "mat_ff_width": ff_w, + }, f) + + n_blocks = len(per_layer_indices) + orig_count = sum(x.size for x in jax.tree.leaves(params)) + orig_bytes = sum(x.nbytes for x in jax.tree.leaves(params)) + + print(f"\n TopK export: {output_path}") + print(f" ─────────────────────────────────────") + print(f" d_ff (full) {config.d_ff:>12d}") + print(f" d_ff (masked) {ff_w:>12d}") + print(f" factor {str(factor)+'x':>12s}") + print(f" blocks {n_blocks:>12d} (per-layer masks)") + print(f" neurons/layer {ff_w:>12d}") + print(f" params (full) {orig_count:>12,d}") + print(f" size (MB) {orig_bytes / 1e6:>12.1f}") + print(f" Note: full weights kept; per-layer mask applied at FFN level") + print() + + +def _export_prefix(params, config, factor, output_path): + """Prefix export: slice FFN weights to a smaller d_ff.""" d_ff_new = config.d_ff // factor if d_ff_new == 0: raise ValueError(f"factor={factor} too large: would give d_ff=0") diff --git a/src/model.py b/src/model.py index 9a16661..688f190 100644 --- a/src/model.py +++ b/src/model.py @@ -231,10 +231,11 @@ def __call__(self, x, mask=None, rope=None, ffn_mask=None): mask = mask[..., :T_new] for i in range(cfg.num_encoder_layers): + block_mask = ffn_mask[i] if (ffn_mask is not None and ffn_mask.ndim == 3) else ffn_mask x, s = nn.remat(MemoryMixerBlock)( cfg.num_heads, cfg.num_kv_heads, cfg.d_model, cfg.d_ff, cfg.num_memory_slots, cfg.total_layers, dt, cfg.activation, name=f"block_{i}" - )(x, s, mask=mask, rope=rope, ffn_mask=ffn_mask) + )(x, s, mask=mask, rope=rope, ffn_mask=block_mask) s = ZCRMSNorm(dtype=dt, name="final_norm")(s) return s @@ -284,9 +285,10 @@ def __call__(self, x, encoder_out, self_mask=None, cross_mask=None, rope=None, f x = x.astype(dt) for i in range(cfg.num_decoder_layers): + block_mask = ffn_mask[i] if (ffn_mask is not None and ffn_mask.ndim == 3) else ffn_mask x = nn.remat(DecoderBlock)( cfg.num_heads, cfg.num_kv_heads, cfg.d_model, cfg.d_ff, cfg.total_layers, dt, cfg.activation, name=f"block_{i}" - )(x, encoder_out, self_mask=self_mask, cross_mask=cross_mask, rope=rope, ffn_mask=ffn_mask) + )(x, encoder_out, self_mask=self_mask, cross_mask=cross_mask, rope=rope, ffn_mask=block_mask) x = ZCRMSNorm(dtype=dt)(x) return x @@ -349,9 +351,8 @@ def encode_text(self, src, src_mask=None, ffn_mask=None): rope = self._rope(src.shape[1]) return self.encoder(x, mask=src_mask, rope=rope, ffn_mask=ffn_mask) - def encode(self, src, src_mask=None): - """Backward-compatible alias for encode_text.""" - return self.encode_text(src, src_mask=src_mask) + def encode(self, src, src_mask=None, ffn_mask=None): + return self.encode_text(src, src_mask=src_mask, ffn_mask=ffn_mask) def encode_speech(self, mel, src_mask=None, ffn_mask=None, deterministic=True): mel = self.spec_augment(mel, deterministic=deterministic) @@ -359,11 +360,13 @@ def encode_speech(self, mel, src_mask=None, ffn_mask=None, deterministic=True): rope = self._rope(x.shape[1]) return self.encoder(x, mask=src_mask, rope=rope, ffn_mask=ffn_mask) - def decode(self, tgt, encoder_out, self_mask=None): + def decode(self, tgt, encoder_out, self_mask=None, cross_mask=None, ffn_mask=None): """Decode from encoder memory slots. No cross_mask needed (fixed-size slots).""" x = self.embedding(tgt) * self.embed_scale rope = self._rope(tgt.shape[1]) - x = self.decoder(x, encoder_out, self_mask=self_mask, cross_mask=None, rope=rope) + x = self.decoder( + x, encoder_out, self_mask=self_mask, cross_mask=None, rope=rope, ffn_mask=ffn_mask + ) logits = x.astype(jnp.float32) @ self.embedding.embedding.T return logits @@ -385,18 +388,27 @@ def _slot_diversity(self, encoder_out): diag_sq = jnp.sum(jnp.diagonal(gram, axis1=1, axis2=2) ** 2) return (jnp.sum(gram ** 2) - diag_sq) / s.shape[0] + def _split_ffn_mask(self, ffn_mask): + """Split a (n_blocks, batch, d_ff) mask into encoder and decoder portions.""" + if ffn_mask is not None and ffn_mask.ndim == 3: + n_enc = self.config.num_encoder_layers + return ffn_mask[:n_enc], ffn_mask[n_enc:] + return ffn_mask, ffn_mask + def forward_masked(self, src, tgt, src_mask=None, tgt_mask=None, ffn_mask=None): """Single forward with per-batch-item FFN masking. Returns (logits, slot_div).""" - encoder_out = self.encode_text(src, src_mask=src_mask, ffn_mask=ffn_mask) - x_f32 = self._run_decoder(encoder_out, tgt, tgt_mask=tgt_mask, ffn_mask=ffn_mask) + enc_mask, dec_mask = self._split_ffn_mask(ffn_mask) + encoder_out = self.encode_text(src, src_mask=src_mask, ffn_mask=enc_mask) + x_f32 = self._run_decoder(encoder_out, tgt, tgt_mask=tgt_mask, ffn_mask=dec_mask) logits = x_f32 @ self.embedding.embedding.T slot_div = self._slot_diversity(encoder_out) return logits, slot_div def forward_speech_masked(self, mel, tgt, src_mask=None, tgt_mask=None, ffn_mask=None, deterministic=True): """Single speech forward with per-batch-item FFN masking. Returns (logits, slot_div).""" - encoder_out = self.encode_speech(mel, src_mask=src_mask, ffn_mask=ffn_mask, deterministic=deterministic) - x_f32 = self._run_decoder(encoder_out, tgt, tgt_mask=tgt_mask, ffn_mask=ffn_mask) + enc_mask, dec_mask = self._split_ffn_mask(ffn_mask) + encoder_out = self.encode_speech(mel, src_mask=src_mask, ffn_mask=enc_mask, deterministic=deterministic) + x_f32 = self._run_decoder(encoder_out, tgt, tgt_mask=tgt_mask, ffn_mask=dec_mask) logits = x_f32 @ self.embedding.embedding.T slot_div = self._slot_diversity(encoder_out) return logits, slot_div @@ -406,47 +418,46 @@ def _make_eval_ffn_mask(self, ff_width, B, dtype): mask = (jnp.arange(self.config.d_ff) < ff_width).astype(dtype) return jnp.broadcast_to(mask[None, :], (B, self.config.d_ff)) - def forward_with_aux(self, src, tgt, src_mask=None, tgt_mask=None, mat_ff_widths=None): - """Eval-only: separate per-width forwards for reporting per-width PPL. - - mat_ff_widths: list of FFN widths to evaluate (e.g. [1024, 512, 256]). - """ + def _eval_sub_models(self, encode_fn, src, tgt, src_mask, tgt_mask, B, dtype, mat_ff_widths, mat_ffn_masks): + """Run per-width forwards for matryoshka eval. Returns list of logit tensors.""" emb = self.embedding.embedding - B = src.shape[0] - - encoder_out = self.encode_text(src, src_mask=src_mask) - x_f32 = self._run_decoder(encoder_out, tgt, tgt_mask=tgt_mask) - logits = x_f32 @ emb.T - slot_div = self._slot_diversity(encoder_out) - + n_enc = self.config.num_encoder_layers + d_ff = self.config.d_ff mat_logits = [] - if mat_ff_widths is not None: + if mat_ffn_masks is not None: + for m in mat_ffn_masks: + if m.ndim == 2: + mask = jnp.broadcast_to(m[:, None, :], (m.shape[0], B, d_ff)) + enc_m, dec_m = mask[:n_enc], mask[n_enc:] + else: + enc_m = dec_m = jnp.broadcast_to(m[None, :], (B, d_ff)) + x_m = self._run_decoder(encode_fn(src, src_mask=src_mask, ffn_mask=enc_m), tgt, tgt_mask=tgt_mask, ffn_mask=dec_m) + mat_logits.append(x_m @ emb.T) + elif mat_ff_widths is not None: for ff_w in mat_ff_widths: - mask = self._make_eval_ffn_mask(ff_w, B, x_f32.dtype) - enc_m = self.encode_text(src, src_mask=src_mask, ffn_mask=mask) - x_m = self._run_decoder(enc_m, tgt, tgt_mask=tgt_mask, ffn_mask=mask) + mask = self._make_eval_ffn_mask(ff_w, B, dtype) + x_m = self._run_decoder(encode_fn(src, src_mask=src_mask, ffn_mask=mask), tgt, tgt_mask=tgt_mask, ffn_mask=mask) mat_logits.append(x_m @ emb.T) + return mat_logits + def forward_with_aux(self, src, tgt, src_mask=None, tgt_mask=None, cross_mask=None, mat_ff_widths=None, mat_ffn_masks=None): + """Eval-only: separate per-width forwards for reporting per-width PPL.""" + encoder_out = self.encode_text(src, src_mask=src_mask) + x_f32 = self._run_decoder(encoder_out, tgt, tgt_mask=tgt_mask) + logits = x_f32 @ self.embedding.embedding.T + slot_div = self._slot_diversity(encoder_out) + mat_logits = self._eval_sub_models(self.encode_text, src, tgt, src_mask, tgt_mask, src.shape[0], x_f32.dtype, mat_ff_widths, mat_ffn_masks) return logits, slot_div, mat_logits - def forward_speech_with_aux(self, mel, tgt, src_mask=None, tgt_mask=None, mat_ff_widths=None, deterministic=True): + def forward_speech_with_aux(self, mel, tgt, src_mask=None, tgt_mask=None, mat_ff_widths=None, mat_ffn_masks=None, deterministic=True): """Eval-only: separate per-width speech forwards for reporting per-width PPL.""" - emb = self.embedding.embedding - B = mel.shape[0] - - encoder_out = self.encode_speech(mel, src_mask=src_mask, deterministic=deterministic) + from functools import partial + encode_fn = partial(self.encode_speech, deterministic=deterministic) + encoder_out = encode_fn(mel, src_mask=src_mask) x_f32 = self._run_decoder(encoder_out, tgt, tgt_mask=tgt_mask) - logits = x_f32 @ emb.T + logits = x_f32 @ self.embedding.embedding.T slot_div = self._slot_diversity(encoder_out) - - mat_logits = [] - if mat_ff_widths is not None: - for ff_w in mat_ff_widths: - mask = self._make_eval_ffn_mask(ff_w, B, x_f32.dtype) - enc_m = self.encode_speech(mel, src_mask=src_mask, ffn_mask=mask, deterministic=deterministic) - x_m = self._run_decoder(enc_m, tgt, tgt_mask=tgt_mask, ffn_mask=mask) - mat_logits.append(x_m @ emb.T) - + mat_logits = self._eval_sub_models(encode_fn, mel, tgt, src_mask, tgt_mask, mel.shape[0], x_f32.dtype, mat_ff_widths, mat_ffn_masks) return logits, slot_div, mat_logits def init_all(self, src, tgt, mel): diff --git a/src/run.py b/src/run.py index a9a00ba..ed41692 100644 --- a/src/run.py +++ b/src/run.py @@ -21,10 +21,23 @@ def load_checkpoint(path): data = pickle.load(f) params = jax.tree.map(jnp.array, data["params"]) config = TransformerConfig(**data["config"]) - return params, config + ffn_mask = _build_ffn_mask(data, config) if "mat_mask_indices" in data else None + return params, config, ffn_mask -def generate(model, params, tokenizer, query, tools="[]", max_gen_len=512, seed=0, stream=True, task_token_id=None): +def _build_ffn_mask(data, config): + """Build (n_blocks, 1, d_ff) binary FFN mask from exported topk mask indices.""" + indices = data["mat_mask_indices"] + n_blocks = len(indices) + mask = np.zeros((n_blocks, 1, config.d_ff), dtype=np.float32) + for b, idx in enumerate(indices): + mask[b, 0, idx] = 1.0 + enc_mask = jnp.array(mask[:config.num_encoder_layers]) + dec_mask = jnp.array(mask[config.num_encoder_layers:]) + return {"encoder": enc_mask, "decoder": dec_mask} + + +def generate(model, params, tokenizer, query, tools="[]", max_gen_len=512, seed=0, stream=True, task_token_id=None, ffn_mask=None): """Generate tool-call output. Encoder: query only. @@ -37,9 +50,12 @@ def generate(model, params, tokenizer, query, tools="[]", max_gen_len=512, seed= eos_id = tokenizer.eos_token_id tool_call_id = task_token_id if task_token_id is not None else tokenizer.tool_call_token_id + enc_ffn = ffn_mask["encoder"] if ffn_mask else None + dec_ffn = ffn_mask["decoder"] if ffn_mask else None + src_mask = make_padding_mask(enc_input, pad_id) encoder_out = model.apply( - {"params": params}, enc_input, src_mask=src_mask, method="encode" + {"params": params}, enc_input, src_mask=src_mask, ffn_mask=enc_ffn, method="encode" ) # Build decoder prefix: [BOS, , tools_tokens...] @@ -61,6 +77,7 @@ def decode_step(dec_buffer, encoder_out): dec_buffer, encoder_out, self_mask=tgt_mask, + ffn_mask=dec_ffn, method="decode", ) return logits @@ -111,7 +128,7 @@ def load_audio(path, target_sr=16000): return audio, sr -def generate_from_audio(model, params, tokenizer, audio_array, sr=16000, tools="[]", max_gen_len=512, seed=0, stream=True): +def generate_from_audio(model, params, tokenizer, audio_array, sr=16000, tools="[]", max_gen_len=512, seed=0, stream=True, ffn_mask=None): """Generate tool-call output from audio using the speech encoder pathway. mel -> encode_speech -> decoder [BOS, , tools_tokens...] -> greedy decode. @@ -125,9 +142,12 @@ def generate_from_audio(model, params, tokenizer, audio_array, sr=16000, tools=" eos_id = tokenizer.eos_token_id tool_call_id = tokenizer.tool_call_token_id + enc_ffn = ffn_mask["encoder"] if ffn_mask else None + dec_ffn = ffn_mask["decoder"] if ffn_mask else None + src_mask = make_mel_padding_mask(mel_input) encoder_out = model.apply( - {"params": params}, mel_input, src_mask=src_mask, deterministic=True, method="encode_speech" + {"params": params}, mel_input, src_mask=src_mask, ffn_mask=enc_ffn, deterministic=True, method="encode_speech" ) # Build decoder prefix: [BOS, , tools_tokens...] @@ -147,6 +167,7 @@ def decode_step(dec_buffer, encoder_out): dec_buffer, encoder_out, self_mask=tgt_mask, + ffn_mask=dec_ffn, method="decode", ) return logits @@ -183,13 +204,15 @@ def decode_step(dec_buffer, encoder_out): def main(args): print(f"Loading checkpoint: {args.checkpoint}") - params, config = load_checkpoint(args.checkpoint) + params, config, ffn_mask = load_checkpoint(args.checkpoint) model = EncoderDecoderTransformer(config) tokenizer = get_tokenizer() param_count = sum(x.size for x in jax.tree.leaves(params)) print(f"Model parameters: {param_count:,}") + if ffn_mask: + print(f"TopK sub-model: per-layer FFN masking active") # --- Voice-to-tool-call mode --- audio_files = getattr(args, "audio", None) @@ -200,15 +223,11 @@ def main(args): print(f"Tools: {tools[:80]}{'...' if len(tools) > 80 else ''}") audio, sr = load_audio(audio_path) generate_from_audio( - model, - params, - tokenizer, - audio, - sr=sr, - tools=tools, + model, params, tokenizer, audio, + sr=sr, tools=tools, max_gen_len=args.max_len, seed=args.seed + i, - stream=True, + stream=True, ffn_mask=ffn_mask, ) return @@ -229,14 +248,9 @@ def main(args): print(f"\nQuery: {q}") print(f"Tools: {t[:80]}{'...' if len(t) > 80 else ''}") generate( - model, - params, - tokenizer, - q, - tools=t, - max_gen_len=args.max_len, - seed=args.seed + i, - stream=True, + model, params, tokenizer, q, + tools=t, max_gen_len=args.max_len, + seed=args.seed + i, stream=True, ffn_mask=ffn_mask, ) diff --git a/src/test.py b/src/test.py index b454a97..1e0e894 100644 --- a/src/test.py +++ b/src/test.py @@ -10,17 +10,19 @@ from .data import get_batches, get_tokenizer, load_tool_calls, prepare_tool_call_pairs, load_tool_call_audio from .model import ( EncoderDecoderTransformer, - TransformerConfig, make_causal_mask, make_padding_mask, ) from .run import load_checkpoint -def compute_perplexity(model, params, enc_inputs, dec_inputs, dec_targets, batch_size, pad_id, loss_mask=None): +def compute_perplexity(model, params, enc_inputs, dec_inputs, dec_targets, batch_size, pad_id, loss_mask=None, ffn_mask=None): total_loss = 0.0 total_tokens = 0 + enc_ffn = ffn_mask["encoder"] if ffn_mask else None + dec_ffn = ffn_mask["decoder"] if ffn_mask else None + for batch in get_batches(enc_inputs, dec_inputs, dec_targets, batch_size, shuffle=False, loss_mask=loss_mask): if loss_mask is not None: src, tgt_in, tgt_out, lm = batch @@ -34,9 +36,12 @@ def compute_perplexity(model, params, enc_inputs, dec_inputs, dec_targets, batch src_mask = make_padding_mask(src, pad_id) tgt_mask = make_causal_mask(tgt_in.shape[1]) & make_padding_mask(tgt_in, pad_id) + encoder_out = model.apply( + {"params": params}, src, src_mask=src_mask, ffn_mask=enc_ffn, method="encode", + ) logits = model.apply( - {"params": params}, src, tgt_in, - src_mask=src_mask, tgt_mask=tgt_mask, + {"params": params}, tgt_in, encoder_out, + self_mask=tgt_mask, ffn_mask=dec_ffn, method="decode", ) loss = optax.softmax_cross_entropy_with_integer_labels(logits, tgt_out) @@ -47,7 +52,7 @@ def compute_perplexity(model, params, enc_inputs, dec_inputs, dec_targets, batch return math.exp(avg_nll) -def measure_throughput(model, params, tokenizer, num_runs=10, prompt='What is the weather?', max_gen_len=64): +def measure_throughput(model, params, tokenizer, num_runs=10, prompt='What is the weather?', max_gen_len=64, ffn_mask=None): enc_tokens = tokenizer.encode(prompt) enc_input = jnp.array([enc_tokens]) pad_id = tokenizer.pad_token_id @@ -56,16 +61,19 @@ def measure_throughput(model, params, tokenizer, num_runs=10, prompt='What is th src_mask = make_padding_mask(enc_input, pad_id) tgt_mask = make_causal_mask(max_gen_len) + enc_ffn = ffn_mask["encoder"] if ffn_mask else None + dec_ffn = ffn_mask["decoder"] if ffn_mask else None + @jax.jit def decode_step(dec_buffer, encoder_out): logits = model.apply( {"params": params}, dec_buffer, encoder_out, - self_mask=tgt_mask, method="decode", + self_mask=tgt_mask, ffn_mask=dec_ffn, method="decode", ) return logits encoder_out = model.apply( - {"params": params}, enc_input, src_mask=src_mask, method="encode" + {"params": params}, enc_input, src_mask=src_mask, ffn_mask=enc_ffn, method="encode" ) dec_buffer = jnp.full((1, max_gen_len), pad_id, dtype=jnp.int32) dec_buffer = dec_buffer.at[0, 0].set(eos_id) @@ -81,7 +89,7 @@ def decode_step(dec_buffer, encoder_out): start = time.perf_counter() encoder_out = model.apply( - {"params": params}, enc_input, src_mask=src_mask, method="encode" + {"params": params}, enc_input, src_mask=src_mask, ffn_mask=enc_ffn, method="encode" ) logits = decode_step(dec_buffer, encoder_out) @@ -112,70 +120,7 @@ def decode_step(dec_buffer, encoder_out): } -def compute_repetition_rate(texts): - bigram_rep_rates = [] - for text in texts: - words = text.lower().split() - if len(words) < 2: - bigram_rep_rates.append(0.0) - continue - bigrams = [(words[i], words[i + 1]) for i in range(len(words) - 1)] - unique = len(set(bigrams)) - bigram_rep_rates.append(1.0 - unique / len(bigrams)) - return float(np.mean(bigram_rep_rates)) - - -def benchmark_generation_quality(model, params, tokenizer, prompts, max_gen_len=128, temperature=0.8): - from .run import generate - - generations = [] - for i, prompt in enumerate(prompts): - text = generate(model, params, tokenizer, prompt, max_gen_len, temperature, seed=i, stream=False) - generations.append(text) - - lengths = [len(tokenizer.encode(t)) for t in generations] - rep_rate = compute_repetition_rate(generations) - - return { - "avg_generation_length": float(np.mean(lengths)), - "min_generation_length": int(np.min(lengths)), - "max_generation_length": int(np.max(lengths)), - "bigram_repetition_rate": rep_rate, - "generations": list(zip(prompts, generations)), - } - - -def compute_wer(hypotheses, references): - """Compute word error rate using edit distance.""" - total_edits = 0 - total_ref_words = 0 - - for hyp, ref in zip(hypotheses, references): - hyp_words = hyp.lower().split() - ref_words = ref.lower().split() - n = len(ref_words) - m = len(hyp_words) - - # DP edit distance - d = [[0] * (m + 1) for _ in range(n + 1)] - for i in range(n + 1): - d[i][0] = i - for j in range(m + 1): - d[0][j] = j - for i in range(1, n + 1): - for j in range(1, m + 1): - if ref_words[i - 1] == hyp_words[j - 1]: - d[i][j] = d[i - 1][j - 1] - else: - d[i][j] = 1 + min(d[i - 1][j], d[i][j - 1], d[i - 1][j - 1]) - - total_edits += d[n][m] - total_ref_words += n - - return total_edits / max(total_ref_words, 1) - - -def benchmark_tool_calls(model, params, tokenizer, num_samples=200, max_gen_len=512): +def benchmark_tool_calls(model, params, tokenizer, num_samples=200, max_gen_len=512, ffn_mask=None): """Generate tool-call predictions and compute structured metrics.""" import json from .run import generate @@ -203,6 +148,7 @@ def benchmark_tool_calls(model, params, tokenizer, num_samples=200, max_gen_len= pred_text = generate( model, params, tokenizer, ex["query"], tools=ex["tools"], max_gen_len=max_gen_len, seed=i, stream=False, + ffn_mask=ffn_mask, ).strip() try: @@ -275,7 +221,7 @@ def call_key(c): } -def benchmark_voice_tool_calls(model, params, tokenizer, num_samples=100, max_gen_len=512): +def benchmark_voice_tool_calls(model, params, tokenizer, num_samples=100, max_gen_len=512, ffn_mask=None): """Generate tool-call predictions from audio and compute structured metrics.""" import json from .run import generate_from_audio @@ -303,6 +249,7 @@ def benchmark_voice_tool_calls(model, params, tokenizer, num_samples=100, max_ge pred_text = generate_from_audio( model, params, tokenizer, audio, sr=sr, tools=pair["tools"], max_gen_len=max_gen_len, seed=i, stream=False, + ffn_mask=ffn_mask, ).strip() try: @@ -376,7 +323,7 @@ def call_key(c): def main(args): - params, config = load_checkpoint(args.checkpoint) + params, config, ffn_mask = load_checkpoint(args.checkpoint) model = EncoderDecoderTransformer(config) tokenizer = get_tokenizer() @@ -384,19 +331,21 @@ def main(args): print(f"\ncheckpoint: {args.checkpoint}") print(f"parameters: {param_count:,}") print(f"config: d={config.d_model}, heads={config.num_heads}, layers={config.num_encoder_layers}/{config.num_decoder_layers}") + if ffn_mask: + print(f"sub-model: topk FFN masking active") print(f"\nevaluating tool-call perplexity ({args.max_eval_samples} samples)...") ds = load_tool_calls("validation", max_samples=args.max_eval_samples) enc_inputs, dec_inputs, dec_targets, loss_mask_arr = prepare_tool_call_pairs( ds, tokenizer, max_enc_len=args.max_enc_len, max_dec_len=args.max_dec_len ) - ppl = compute_perplexity(model, params, enc_inputs, dec_inputs, dec_targets, args.batch_size, config.pad_token_id, loss_mask=loss_mask_arr) + ppl = compute_perplexity(model, params, enc_inputs, dec_inputs, dec_targets, args.batch_size, config.pad_token_id, loss_mask=loss_mask_arr, ffn_mask=ffn_mask) tc_samples = getattr(args, "tool_call_samples", 200) tc = None if tc_samples > 0: print(f"\nevaluating tool-call accuracy ({tc_samples} samples)...") - tc = benchmark_tool_calls(model, params, tokenizer, num_samples=tc_samples, max_gen_len=args.max_gen_len) + tc = benchmark_tool_calls(model, params, tokenizer, num_samples=tc_samples, max_gen_len=args.max_gen_len, ffn_mask=ffn_mask) print(f"\n ─────────────────────────────────────") print(f" Tool-Call Metrics") @@ -426,7 +375,7 @@ def main(args): voice_tc_samples = getattr(args, "voice_tc_samples", 50) if voice_tc_samples > 0: print(f"\nevaluating voice-to-tool-call ({voice_tc_samples} samples)...") - vtc = benchmark_voice_tool_calls(model, params, tokenizer, num_samples=voice_tc_samples, max_gen_len=args.max_gen_len) + vtc = benchmark_voice_tool_calls(model, params, tokenizer, num_samples=voice_tc_samples, max_gen_len=args.max_gen_len, ffn_mask=ffn_mask) print(f"\n ─── Voice-Tool-Call Metrics ─────────") print(f" JSON parse rate {vtc['json_parse_rate']:>10.1%}") print(f" Exact match {vtc['exact_match']:>10.1%}") @@ -448,7 +397,7 @@ def main(args): print() print(f"\nmeasuring throughput ({args.throughput_runs} runs)...") - throughput = measure_throughput(model, params, tokenizer, num_runs=args.throughput_runs) + throughput = measure_throughput(model, params, tokenizer, num_runs=args.throughput_runs, ffn_mask=ffn_mask) print(f"avg tokens: {throughput['avg_tokens_generated']:.1f}") print(f"avg latency: {throughput['avg_latency_s']:.3f}s") print(f"tokens/sec: {throughput['tokens_per_second']:.1f}") diff --git a/src/train.py b/src/train.py index 9e14141..1e8a458 100644 --- a/src/train.py +++ b/src/train.py @@ -1,4 +1,3 @@ -import argparse import math import os import pickle @@ -224,8 +223,85 @@ def _maybe_quantize(path, leaf): _GROUP_SIZE = 32 _MAT_FACTORS = () -_MAT_FF_WIDTHS = () -_D_FF = 2048 +_MAT_FF_WIDTHS = () # precomputed d_ff widths per factor +_D_FF = 2048 # set in train() +_N_BLOCKS = 12 # num_encoder_layers + num_decoder_layers, set in train() +_MAT_SPREAD_LAMBDA = 0.01 +_MAT_GUMBEL = False + + +def topk_mask(logits, k, tau, hard): + """Differentiable top-k mask. logits: (d_ff,), k: int, tau/hard: JAX scalars. + + Learning phase (hard=False): returns soft sigmoid mask for gradient flow. + Freeze phase (hard=True): returns stop_gradient hard mask — no STE. + Ties at threshold may select slightly more than k neurons; this is negligible + and export uses exact argsort top-k for the final mask. + """ + if k >= logits.shape[0]: + return jnp.ones_like(logits) + topk_vals = jax.lax.top_k(logits, k)[0] + threshold = topk_vals[-1] + y_soft = jax.nn.sigmoid((logits - threshold) / tau) + y_hard = (y_soft >= 0.5).astype(y_soft.dtype) + frozen = jax.lax.stop_gradient(y_hard) + return jnp.where(hard, frozen, y_soft) + + +def _gumbel_sample(rng, shape): + """Sample from Gumbel(0, 1) distribution.""" + u = jax.random.uniform(rng, shape, minval=1e-20, maxval=1.0) + return -jnp.log(-jnp.log(u)) + + +def _make_ffn_mask_topk(batch_size, d_ff, mask_logits, mat_ff_widths, tau, hard, step_rng): + """Build (n_blocks, batch, d_ff) per-layer topk mask. + + mask_logits: (n_mat, n_blocks, d_ff). tau/hard: JAX scalars. + step_rng: PRNGKey for Gumbel noise (used only if _MAT_GUMBEL is True). + Returns (n_blocks, batch, d_ff) stacked mask. + """ + n_blocks = mask_logits.shape[1] + n_widths = 1 + len(mat_ff_widths) + per_width = batch_size // n_widths + remainder = batch_size - per_width * n_widths + + block_masks = [] + for b in range(n_blocks): + rows = [jnp.ones((per_width, d_ff), dtype=jnp.bfloat16)] + for i, ff_w in enumerate(mat_ff_widths): + logits_b = mask_logits[i, b] # (d_ff,) + if _MAT_GUMBEL: + sub_rng = jax.random.fold_in(step_rng, b * 1000 + i) + noise = _gumbel_sample(sub_rng, (per_width, d_ff)) + # Zero noise in hard/freeze mode so all items converge + noise = noise * (1.0 - hard.astype(jnp.float32)) + noisy = logits_b[None, :] + noise # (per_width, d_ff) + m = jax.vmap(lambda l: topk_mask(l, k=ff_w, tau=tau, hard=hard))(noisy) + else: + m = topk_mask(logits_b, k=ff_w, tau=tau, hard=hard) + m = jnp.broadcast_to(m[None, :], (per_width, d_ff)) + rows.append(m.astype(jnp.bfloat16)) + if remainder > 0: + rows.append(jnp.ones((remainder, d_ff), dtype=jnp.bfloat16)) + block_masks.append(jnp.concatenate(rows, axis=0)) + return jnp.stack(block_masks) # (n_blocks, batch, d_ff) + + +def _compute_ce(logits, tgt_out, slot_div, loss_mask=None): + """Shared CE + z-loss + slot-div computation.""" + pad_id = 0 + logits_f32 = logits.astype(jnp.float32) + if loss_mask is not None: + mask = loss_mask + else: + mask = (tgt_out != pad_id).astype(jnp.float32) + ce_loss = jnp.sum( + optax.softmax_cross_entropy_with_integer_labels(logits_f32, tgt_out) * mask + ) / jnp.maximum(jnp.sum(mask), 1.0) + z_loss = 1e-4 * jnp.mean(jax.nn.logsumexp(logits_f32, axis=-1) ** 2) + div_loss = 1e-4 * slot_div + return ce_loss + z_loss + div_loss def _text_loss_fn(state, params, src, tgt_in, tgt_out, causal_mask, ffn_mask, loss_mask): @@ -238,14 +314,7 @@ def _text_loss_fn(state, params, src, tgt_in, tgt_out, causal_mask, ffn_mask, lo ffn_mask=ffn_mask, method="forward_masked", ) - logits_f32 = logits.astype(jnp.float32) - mask = loss_mask - ce_loss = jnp.sum( - optax.softmax_cross_entropy_with_integer_labels(logits_f32, tgt_out) * mask - ) / jnp.maximum(jnp.sum(mask), 1.0) - z_loss = 1e-4 * jnp.mean(jax.nn.logsumexp(logits_f32, axis=-1) ** 2) - div_loss = 1e-4 * slot_div - return ce_loss + z_loss + div_loss + return _compute_ce(logits, tgt_out, slot_div, loss_mask=loss_mask) def _speech_loss_fn(state, params, mel, tgt_in, tgt_out, causal_mask, ffn_mask, rng, loss_mask): @@ -260,14 +329,58 @@ def _speech_loss_fn(state, params, mel, tgt_in, tgt_out, causal_mask, ffn_mask, method="forward_speech_masked", rngs={"specaugment": rng}, ) - logits_f32 = logits.astype(jnp.float32) - mask = loss_mask - ce_loss = jnp.sum( - optax.softmax_cross_entropy_with_integer_labels(logits_f32, tgt_out) * mask - ) / jnp.maximum(jnp.sum(mask), 1.0) - z_loss = 1e-4 * jnp.mean(jax.nn.logsumexp(logits_f32, axis=-1) ** 2) - div_loss = 1e-4 * slot_div - return ce_loss + z_loss + div_loss + return _compute_ce(logits, tgt_out, slot_div, loss_mask=loss_mask) + + +def _topk_loss(state, params, mask_logits, src, tgt_in, tgt_out, causal_mask, + tau, hard, step_rng, is_speech=False, spec_rng=None, loss_mask=None): + """Topk loss for text or speech. Builds masks inside for gradient flow.""" + ffn_mask = _make_ffn_mask_topk(src.shape[0], _D_FF, mask_logits, _MAT_FF_WIDTHS, tau, hard, step_rng) + if is_speech: + src_mask = make_mel_padding_mask(src) + logits, slot_div = state.apply_fn( + {"params": _quantize_params(params, group_size=_GROUP_SIZE)}, + src, tgt_in, src_mask=src_mask, + tgt_mask=causal_mask & make_padding_mask(tgt_in, 0), + ffn_mask=ffn_mask, deterministic=False, + method="forward_speech_masked", rngs={"specaugment": spec_rng}, + ) + else: + src_mask = make_padding_mask(src, 0) + logits, slot_div = state.apply_fn( + {"params": _quantize_params(params, group_size=_GROUP_SIZE)}, + src, tgt_in, src_mask=src_mask, + tgt_mask=causal_mask & make_padding_mask(tgt_in, 0), + ffn_mask=ffn_mask, method="forward_masked", + ) + loss = _compute_ce(logits, tgt_out, slot_div, loss_mask=loss_mask) + if _MAT_SPREAD_LAMBDA > 0: + spread = jnp.mean(jnp.var(mask_logits, axis=-1)) + loss = loss - _MAT_SPREAD_LAMBDA * spread + return loss + + +def _warmup_loss(state, params, src, tgt_in, tgt_out, causal_mask, is_speech=False, spec_rng=None, loss_mask=None): + """Full-model-only loss (no matryoshka) for topk warmup phase.""" + ffn_mask = jnp.ones((src.shape[0], _D_FF), dtype=jnp.bfloat16) + if is_speech: + src_mask = make_mel_padding_mask(src) + logits, slot_div = state.apply_fn( + {"params": _quantize_params(params, group_size=_GROUP_SIZE)}, + src, tgt_in, src_mask=src_mask, + tgt_mask=causal_mask & make_padding_mask(tgt_in, 0), + ffn_mask=ffn_mask, deterministic=False, + method="forward_speech_masked", rngs={"specaugment": spec_rng}, + ) + else: + src_mask = make_padding_mask(src, 0) + logits, slot_div = state.apply_fn( + {"params": _quantize_params(params, group_size=_GROUP_SIZE)}, + src, tgt_in, src_mask=src_mask, + tgt_mask=causal_mask & make_padding_mask(tgt_in, 0), + ffn_mask=ffn_mask, method="forward_masked", + ) + return _compute_ce(logits, tgt_out, slot_div, loss_mask=loss_mask) def _make_ffn_mask(batch_size, d_ff, mat_ff_widths): @@ -288,62 +401,137 @@ def _make_ffn_mask(batch_size, d_ff, mat_ff_widths): return jnp.concatenate(rows, axis=0) -def _train_step_text(state, ema_params, src, tgt_in, tgt_out, causal_mask, ffn_mask, loss_mask): - ema_decay = 0.999 - loss, grads = jax.value_and_grad( - lambda p: _text_loss_fn(state, p, src, tgt_in, tgt_out, causal_mask, ffn_mask, loss_mask) - )(state.params) +def _grad_step(state, ema_params, loss_fn, prune_mask=None): + """Shared body for all non-topk train steps: grad, pmean, apply, ema.""" + loss, grads = jax.value_and_grad(loss_fn)(state.params) grads = jax.lax.pmean(grads, axis_name="batch") loss = jax.lax.pmean(loss, axis_name="batch") - grad_norm = optax.global_norm(grads) - state = state.apply_gradients(grads=grads) - ema_params = jax.tree.map(lambda e, p: ema_decay * e + (1 - ema_decay) * p, ema_params, state.params) - return state, ema_params, loss, grad_norm + state, ema_params = _apply_and_ema(state, ema_params, grads, prune_mask) + return state, ema_params, loss, optax.global_norm(grads) + + +def _train_step_text(state, ema_params, src, tgt_in, tgt_out, causal_mask, ffn_mask, loss_mask): + return _grad_step(state, ema_params, + lambda p: _text_loss_fn(state, p, src, tgt_in, tgt_out, causal_mask, ffn_mask, loss_mask)) def _train_step_text_masked(state, ema_params, src, tgt_in, tgt_out, causal_mask, prune_mask, ffn_mask, loss_mask): - """Text training step with fused prune mask application.""" - ema_decay = 0.999 - loss, grads = jax.value_and_grad( - lambda p: _text_loss_fn(state, p, src, tgt_in, tgt_out, causal_mask, ffn_mask, loss_mask) - )(state.params) - grads = jax.lax.pmean(grads, axis_name="batch") - loss = jax.lax.pmean(loss, axis_name="batch") - grad_norm = optax.global_norm(grads) - state = state.apply_gradients(grads=grads) - masked_params = jax.tree.map(lambda w, m: w * m, state.params, prune_mask) - state = state.replace(params=masked_params) - ema_params = jax.tree.map(lambda e, p: ema_decay * e + (1 - ema_decay) * p, ema_params, masked_params) - return state, ema_params, loss, grad_norm + return _grad_step(state, ema_params, + lambda p: _text_loss_fn(state, p, src, tgt_in, tgt_out, causal_mask, ffn_mask, loss_mask), + prune_mask=prune_mask) def _train_step_speech(state, ema_params, mel, tgt_in, tgt_out, causal_mask, ffn_mask, rng, loss_mask): - ema_decay = 0.999 + return _grad_step(state, ema_params, + lambda p: _speech_loss_fn(state, p, mel, tgt_in, tgt_out, causal_mask, ffn_mask, rng, loss_mask)) + + +def _train_step_speech_masked(state, ema_params, mel, tgt_in, tgt_out, causal_mask, prune_mask, ffn_mask, rng, loss_mask): + return _grad_step(state, ema_params, + lambda p: _speech_loss_fn(state, p, mel, tgt_in, tgt_out, causal_mask, ffn_mask, rng, loss_mask), + prune_mask=prune_mask) + + +def _apply_and_ema(state, ema_params, grads, prune_mask=None): + """Apply gradients, optional prune mask, and EMA update. Returns (state, ema).""" + state = state.apply_gradients(grads=grads) + if prune_mask is not None: + params = jax.tree.map(lambda w, m: w * m, state.params, prune_mask) + state = state.replace(params=params) + else: + params = state.params + ema = jax.tree.map(lambda e, p: 0.999 * e + 0.001 * p, ema_params, params) + return state, ema + + +def _train_step_text_warmup(state, ema_params, src, tgt_in, tgt_out, causal_mask, loss_mask): + """Text warmup step (no matryoshka). Returns grads for saliency accumulation.""" loss, grads = jax.value_and_grad( - lambda p: _speech_loss_fn(state, p, mel, tgt_in, tgt_out, causal_mask, ffn_mask, rng, loss_mask) + lambda p: _warmup_loss(state, p, src, tgt_in, tgt_out, causal_mask, loss_mask=loss_mask) )(state.params) grads = jax.lax.pmean(grads, axis_name="batch") loss = jax.lax.pmean(loss, axis_name="batch") - grad_norm = optax.global_norm(grads) - state = state.apply_gradients(grads=grads) - ema_params = jax.tree.map(lambda e, p: ema_decay * e + (1 - ema_decay) * p, ema_params, state.params) - return state, ema_params, loss, grad_norm + state, ema_params = _apply_and_ema(state, ema_params, grads) + return state, ema_params, loss, optax.global_norm(grads), grads -def _train_step_speech_masked(state, ema_params, mel, tgt_in, tgt_out, causal_mask, prune_mask, ffn_mask, rng, loss_mask): - """Speech training step with fused prune mask application.""" - ema_decay = 0.999 +def _train_step_speech_warmup(state, ema_params, mel, tgt_in, tgt_out, causal_mask, spec_rng, loss_mask): + """Speech warmup step (no matryoshka).""" loss, grads = jax.value_and_grad( - lambda p: _speech_loss_fn(state, p, mel, tgt_in, tgt_out, causal_mask, ffn_mask, rng, loss_mask) + lambda p: _warmup_loss(state, p, mel, tgt_in, tgt_out, causal_mask, is_speech=True, spec_rng=spec_rng, loss_mask=loss_mask) )(state.params) grads = jax.lax.pmean(grads, axis_name="batch") loss = jax.lax.pmean(loss, axis_name="batch") - grad_norm = optax.global_norm(grads) - state = state.apply_gradients(grads=grads) - masked_params = jax.tree.map(lambda w, m: w * m, state.params, prune_mask) - state = state.replace(params=masked_params) - ema_params = jax.tree.map(lambda e, p: ema_decay * e + (1 - ema_decay) * p, ema_params, masked_params) - return state, ema_params, loss, grad_norm + state, ema_params = _apply_and_ema(state, ema_params, grads) + return state, ema_params, loss, optax.global_norm(grads), grads + + +def _topk_grad_step(state, ema_params, mask_logits, loss_fn, prune_mask=None): + """Shared body for all topk train steps: grad, pmean, apply, ema.""" + (loss, (p_grads, ml_grads)) = jax.value_and_grad(loss_fn, argnums=(0, 1))(state.params, mask_logits) + p_grads = jax.lax.pmean(p_grads, axis_name="batch") + ml_grads = jax.lax.pmean(ml_grads, axis_name="batch") + loss = jax.lax.pmean(loss, axis_name="batch") + state, ema_params = _apply_and_ema(state, ema_params, p_grads, prune_mask) + return state, ema_params, ml_grads, loss, optax.global_norm(p_grads) + + +def _train_step_text_topk(state, ema_params, mask_logits, src, tgt_in, tgt_out, causal_mask, tau, hard, step_rng, loss_mask): + return _topk_grad_step(state, ema_params, mask_logits, + lambda p, ml: _topk_loss(state, p, ml, src, tgt_in, tgt_out, causal_mask, tau, hard, step_rng, loss_mask=loss_mask)) + + +def _train_step_text_topk_masked(state, ema_params, mask_logits, src, tgt_in, tgt_out, causal_mask, prune_mask, tau, hard, step_rng, loss_mask): + return _topk_grad_step(state, ema_params, mask_logits, + lambda p, ml: _topk_loss(state, p, ml, src, tgt_in, tgt_out, causal_mask, tau, hard, step_rng, loss_mask=loss_mask), + prune_mask=prune_mask) + + +def _train_step_speech_topk(state, ema_params, mask_logits, mel, tgt_in, tgt_out, causal_mask, tau, hard, step_rng, spec_rng, loss_mask): + return _topk_grad_step(state, ema_params, mask_logits, + lambda p, ml: _topk_loss(state, p, ml, mel, tgt_in, tgt_out, causal_mask, tau, hard, step_rng, is_speech=True, spec_rng=spec_rng, loss_mask=loss_mask)) + + +def _train_step_speech_topk_masked(state, ema_params, mask_logits, mel, tgt_in, tgt_out, causal_mask, prune_mask, tau, hard, step_rng, spec_rng, loss_mask): + return _topk_grad_step(state, ema_params, mask_logits, + lambda p, ml: _topk_loss(state, p, ml, mel, tgt_in, tgt_out, causal_mask, tau, hard, step_rng, is_speech=True, spec_rng=spec_rng, loss_mask=loss_mask), + prune_mask=prune_mask) + + +def _extract_ffn_saliency(grads, d_ff, n_enc, n_dec): + """Extract per-layer per-FFN-neuron saliency from param gradients. + + Returns (n_blocks, d_ff) array of neuron importance scores per block. + """ + n_blocks = n_enc + n_dec + saliency = np.zeros((n_blocks, d_ff), dtype=np.float32) + for path, leaf in jax.tree_util.tree_leaves_with_path(grads): + path_str = "/".join(p.key if hasattr(p, "key") else str(p) for p in path) + if "down_proj" not in path_str or "kernel" not in path_str or leaf.ndim != 2: + continue + g = np.array(leaf) + if g.shape[0] != d_ff: + continue + neuron_sal = np.sum(g ** 2, axis=1) + for part in path: + name = part.key if hasattr(part, "key") else str(part) + if name.startswith("block_"): + block_idx = int(name.split("_")[1]) + if "encoder" in path_str: + saliency[block_idx] += neuron_sal + elif "decoder" in path_str: + saliency[n_enc + block_idx] += neuron_sal + break + return saliency + + +def _update_mask_logits(ml_grads, mask_logits, mask_tx, mask_opt_state): + """Host-side mask logit optimizer step. Returns (updated mask_logits, opt_state).""" + ml_grads_np = np.array(jax_utils.unreplicate(ml_grads)) + ml_np = np.array(jax_utils.unreplicate(mask_logits)) + updates, mask_opt_state = mask_tx.update(ml_grads_np, mask_opt_state, ml_np) + ml_np = optax.apply_updates(ml_np, updates) + return jax_utils.replicate(jnp.array(ml_np)), mask_opt_state def _make_p_train_step(): @@ -362,6 +550,22 @@ def _make_p_train_step_speech_masked(): return jax.pmap(_train_step_speech_masked, axis_name="batch", donate_argnums=(0, 1)) +def _make_p_train_step_topk(): + return jax.pmap(_train_step_text_topk, axis_name="batch", donate_argnums=(0, 1)) + + +def _make_p_train_step_topk_masked(): + return jax.pmap(_train_step_text_topk_masked, axis_name="batch", donate_argnums=(0, 1)) + + +def _make_p_train_step_speech_topk(): + return jax.pmap(_train_step_speech_topk, axis_name="batch", donate_argnums=(0, 1)) + + +def _make_p_train_step_speech_topk_masked(): + return jax.pmap(_train_step_speech_topk_masked, axis_name="batch", donate_argnums=(0, 1)) + + def _make_val_loss_fn(apply_fn): @jax.jit def val_loss_batch(params, src, tgt_in, tgt_out, causal_mask, loss_mask): @@ -377,18 +581,23 @@ def val_loss_batch(params, src, tgt_in, tgt_out, causal_mask, loss_mask): return val_loss_batch -def _make_mat_val_loss_fn(apply_fn, ff_width): - """Val loss for matryoshka sub-model at given FFN width.""" +def _make_mat_val_loss_fn(apply_fn, ff_width=None, ffn_mask=None): + """Val loss for matryoshka sub-model at given FFN width or with a topk ffn_mask.""" @jax.jit def val_loss_batch(params, src, tgt_in, tgt_out, causal_mask, loss_mask): pad_id = 0 src_mask = make_padding_mask(src, pad_id) tgt_mask = causal_mask & make_padding_mask(tgt_in, pad_id) + kwargs = {} + if ffn_mask is not None: + kwargs["mat_ffn_masks"] = [ffn_mask] + else: + kwargs["mat_ff_widths"] = (ff_width,) logits, _, mat_logits = apply_fn( {"params": params}, src, tgt_in, src_mask=src_mask, tgt_mask=tgt_mask, - mat_ff_widths=(ff_width,), method="forward_with_aux", + **kwargs, ) trunc_logits = mat_logits[0].astype(jnp.float32) loss = optax.softmax_cross_entropy_with_integer_labels(trunc_logits, tgt_out) @@ -541,9 +750,10 @@ def train(args): n_mels=n_mels, ) - global _GROUP_SIZE, _MAT_FACTORS, _MAT_FF_WIDTHS, _D_FF + global _GROUP_SIZE, _MAT_FACTORS, _MAT_FF_WIDTHS, _D_FF, _N_BLOCKS, _MAT_SPREAD_LAMBDA, _MAT_GUMBEL _GROUP_SIZE = getattr(args, "group_size", 32) _D_FF = config.d_ff + _N_BLOCKS = config.num_encoder_layers + config.num_decoder_layers mat_factors_raw = getattr(args, "mat_factors", None) if mat_factors_raw: _MAT_FACTORS = tuple(f for f in mat_factors_raw if f > 1) @@ -552,6 +762,25 @@ def train(args): _MAT_FACTORS = () _MAT_FF_WIDTHS = () n_widths = 1 + len(_MAT_FF_WIDTHS) if _MAT_FF_WIDTHS else 1 + + # Fallback defaults here are for programmatic callers; CLI defaults are in cli.py + mat_method = getattr(args, "mat_method", "static-prefix") + use_topk = mat_method == "topk" and _MAT_FF_WIDTHS + mat_tau_start = getattr(args, "mat_tau_start", 0.5) + mat_tau_end = getattr(args, "mat_tau_end", 0.1) + mat_warmup_frac = getattr(args, "mat_warmup_frac", 0.15) + mat_freeze_frac = getattr(args, "mat_freeze_frac", 0.2) + # These globals are read inside pmap-traced functions; must be set before pmap creation below + _MAT_SPREAD_LAMBDA = getattr(args, "mat_spread_lambda", 0.01) + _MAT_GUMBEL = getattr(args, "mat_gumbel", False) + + if use_topk: + p_train_step_warmup = jax.pmap(_train_step_text_warmup, axis_name="batch", donate_argnums=(0, 1)) + p_train_step_warmup_speech = jax.pmap(_train_step_speech_warmup, axis_name="batch", donate_argnums=(0, 1)) + p_train_step_topk = _make_p_train_step_topk() + p_train_step_topk_masked = _make_p_train_step_topk_masked() + p_train_step_speech_topk = _make_p_train_step_speech_topk() + p_train_step_speech_topk_masked = _make_p_train_step_speech_topk_masked() p_train_step = _make_p_train_step() p_train_step_masked = _make_p_train_step_masked() p_train_step_speech = _make_p_train_step_speech() @@ -562,6 +791,10 @@ def train(args): rng, init_rng = jax.random.split(rng) mat_shared_input = getattr(args, "mat_shared_input", False) + if use_topk and mat_shared_input: + raise ValueError("--mat-shared-input is incompatible with --mat-method topk") + # With shared input, each unique sample is repeated n_widths times, + # so we fetch smaller batches but take more steps per epoch. unique_batch_size = effective_batch_size // n_widths if (mat_shared_input and n_widths > 1) else effective_batch_size text_batches_per_epoch = len(enc_inputs) // unique_batch_size if not no_speech and speech_audio_arrays is not None: @@ -583,8 +816,55 @@ def train(args): print(f" Loaded checkpoint params into train state") ema_params = jax.tree.map(jnp.copy, state.params) + + # --- TopK mask logit init --- + mask_logits = None + mask_opt_state = None + mask_tx = None + use_saliency = False + saliency_accum = None + saliency_steps = 0 + if use_topk: + n_mat = len(_MAT_FF_WIDTHS) + n_blocks = _N_BLOCKS + init_mode = getattr(args, "mat_init_mode", "normal") + init_value = getattr(args, "mat_init_value", 0.5) + saliency_scale = getattr(args, "mat_saliency_scale", 1.0) + rng, mask_rng = jax.random.split(rng) + d_ff = config.d_ff + use_saliency = init_mode == "saliency" + if init_mode == "prefix": + positions = jnp.arange(d_ff, dtype=jnp.float32) + ramp = init_value * (1.0 - 2.0 * positions / max(1, d_ff - 1)) + mask_logits = jnp.broadcast_to(ramp[None, None, :], (n_mat, n_blocks, d_ff)).copy() + elif init_mode == "shuffled_prefix": + positions = jnp.arange(d_ff, dtype=jnp.float32) + ramp = init_value * (1.0 - 2.0 * positions / max(1, d_ff - 1)) + rows = [] + for i in range(n_mat * n_blocks): + rng, perm_rng = jax.random.split(rng) + perm = jax.random.permutation(perm_rng, d_ff) + rows.append(ramp[perm]) + mask_logits = jnp.stack(rows).reshape(n_mat, n_blocks, d_ff) + elif init_mode == "saliency": + mask_logits = jnp.zeros((n_mat, n_blocks, d_ff)) + saliency_accum = np.zeros((n_blocks, d_ff), dtype=np.float32) + elif init_mode == "normal": + mask_logits = jax.random.normal(mask_rng, (n_mat, n_blocks, d_ff)) * init_value + else: + mask_logits = jnp.zeros((n_mat, n_blocks, d_ff)) + mask_lr = getattr(args, "mat_mask_lr", 3e-3) + mask_tx = optax.adam(learning_rate=mask_lr) + mask_opt_state = mask_tx.init(np.array(mask_logits)) + if resume_checkpoint and "mask_logits" in ckpt_data: + mask_logits = jnp.array(ckpt_data["mask_logits"]) + use_saliency = False + print(f" Loaded mask logits from checkpoint") + state = jax_utils.replicate(state) ema_params = jax_utils.replicate(ema_params) + if use_topk: + mask_logits = jax_utils.replicate(mask_logits) param_count = sum(x.size for x in jax.tree.leaves(jax_utils.unreplicate(state).params)) decay_steps = max(1, int(total_steps * 0.15)) @@ -604,6 +884,16 @@ def train(args): print(f" max_mel_len {max_mel_len:>12}") else: print(f" Speech disabled") + if use_topk: + print(f" Mat method topk (learned)") + print(f" Mat tau {mat_tau_start:.2f} -> {mat_tau_end:.2f}") + print(f" Mat warmup {mat_warmup_frac*100:.0f}% / freeze {mat_freeze_frac*100:.0f}%") + print(f" Mat spread λ {_MAT_SPREAD_LAMBDA}") + print(f" Mat per-layer {n_blocks} blocks") + if _MAT_GUMBEL: + print(f" Mat Gumbel enabled") + else: + print(f" Mat method static-prefix") print(f" ─────────────────────────────────────") print(f" Devices {num_devices:>12}") print(f" Batch {args.batch_size:>7} x {num_devices} = {effective_batch_size}") @@ -651,6 +941,9 @@ def train(args): prune_end_frac = getattr(args, "prune_end_frac", 0.67) weight_prune_epoch = 0 if sparsity_ratio > 0 else -1 + mat_warmup_end = int(total_steps * mat_warmup_frac) if use_topk else 0 + # freeze_frac=1.0 → freeze_start=0 → no learning phase (saliency-only mode, intentional) + mat_freeze_start = int(total_steps * (1 - mat_freeze_frac)) if use_topk else 0 for epoch in range(args.epochs): if epoch == weight_prune_epoch and not gradual_sparsify_done: @@ -678,6 +971,19 @@ def train(args): pbar = tqdm(range(steps_this_epoch), desc=f"Epoch {epoch + 1}/{args.epochs}") for step_i in pbar: + # --- TopK phase tracking --- + topk_active = False + cur_tau = mat_tau_start + cur_hard = False + if use_topk: + topk_active = global_step >= mat_warmup_end + cur_hard = global_step >= mat_freeze_start + if topk_active and not cur_hard: + progress = min(1.0, (global_step - mat_warmup_end) / max(1, mat_freeze_start - mat_warmup_end)) + cur_tau = mat_tau_start * (mat_tau_end / mat_tau_start) ** progress + elif cur_hard: + cur_tau = 0.001 + t0 = time.perf_counter() do_speech = (step_i % 2 == 1) and speech_idx < len(speech_batch_list) @@ -697,7 +1003,7 @@ def train(args): src, tgt_in, tgt_out, lm = text_batches[text_idx] text_idx += 1 - if n_widths > 1 and mat_shared_input: + if not use_topk and n_widths > 1 and mat_shared_input: per_width = args.batch_size // n_widths def _tile_for_mat(arr): s = arr.reshape(num_devices, per_width, *arr.shape[1:]) @@ -712,14 +1018,41 @@ def _tile_for_mat(arr): tgt_out_b = shard_batch(tgt_out, num_devices) lm_b = shard_batch(lm, num_devices) - if prune_mask is not None: - state, ema_params, loss, grad_norm = p_train_step_masked( - state, ema_params, src_b, tgt_in_b, tgt_out_b, causal_mask, prune_mask, text_ffn_mask, lm_b, + if topk_active: + tau_arr = jax_utils.replicate(jnp.float32(cur_tau)) + hard_arr = jax_utils.replicate(jnp.bool_(cur_hard)) + rng, step_rng = jax.random.split(rng) + step_rngs = jax.random.split(step_rng, num_devices) + if prune_mask is not None: + state, ema_params, ml_grads, loss, grad_norm = p_train_step_topk_masked( + state, ema_params, mask_logits, src_b, tgt_in_b, tgt_out_b, causal_mask, prune_mask, tau_arr, hard_arr, step_rngs, lm_b, + ) + else: + state, ema_params, ml_grads, loss, grad_norm = p_train_step_topk( + state, ema_params, mask_logits, src_b, tgt_in_b, tgt_out_b, causal_mask, tau_arr, hard_arr, step_rngs, lm_b, + ) + if not cur_hard: + mask_logits, mask_opt_state = _update_mask_logits(ml_grads, mask_logits, mask_tx, mask_opt_state) + elif use_topk and not topk_active: + state, ema_params, loss, grad_norm, warmup_grads = p_train_step_warmup( + state, ema_params, src_b, tgt_in_b, tgt_out_b, causal_mask, lm_b, ) + if use_saliency and saliency_accum is not None: + grads_unr = jax_utils.unreplicate(warmup_grads) + step_saliency = _extract_ffn_saliency(grads_unr, config.d_ff, config.num_encoder_layers, config.num_decoder_layers) + saliency_steps += 1 + beta = 0.99 + saliency_accum = beta * saliency_accum + (1 - beta) * step_saliency + del grads_unr else: - state, ema_params, loss, grad_norm = p_train_step( - state, ema_params, src_b, tgt_in_b, tgt_out_b, causal_mask, text_ffn_mask, lm_b, - ) + if prune_mask is not None: + state, ema_params, loss, grad_norm = p_train_step_masked( + state, ema_params, src_b, tgt_in_b, tgt_out_b, causal_mask, prune_mask, text_ffn_mask, lm_b, + ) + else: + state, ema_params, loss, grad_norm = p_train_step( + state, ema_params, src_b, tgt_in_b, tgt_out_b, causal_mask, text_ffn_mask, lm_b, + ) text_loss_val = float(loss[0]) text_losses.append(text_loss_val) @@ -730,7 +1063,7 @@ def _tile_for_mat(arr): mel_batch, sp_tgt_in, sp_tgt_out, sp_lm = speech_batch_list[speech_idx] speech_idx += 1 - if n_widths > 1 and mat_shared_input: + if not use_topk and n_widths > 1 and mat_shared_input: per_width = args.batch_size // n_widths def _tile_sp(arr): s = arr.reshape(num_devices, per_width, *arr.shape[1:]) @@ -745,23 +1078,60 @@ def _tile_sp(arr): sp_tgt_out_b = shard_batch(sp_tgt_out, num_devices) sp_lm_b = shard_batch(sp_lm, num_devices) - rng, spec_rng = jax.random.split(rng) - spec_rngs = jax.random.split(spec_rng, num_devices) - - if prune_mask is not None: - state, ema_params, sp_loss, sp_grad_norm = p_train_step_speech_masked( - state, ema_params, mel_b, sp_tgt_in_b, sp_tgt_out_b, causal_mask, prune_mask, text_ffn_mask, spec_rngs, sp_lm_b, + if topk_active: + tau_arr = jax_utils.replicate(jnp.float32(cur_tau)) + hard_arr = jax_utils.replicate(jnp.bool_(cur_hard)) + rng, step_rng, spec_rng = jax.random.split(rng, 3) + step_rngs = jax.random.split(step_rng, num_devices) + spec_rngs = jax.random.split(spec_rng, num_devices) + if prune_mask is not None: + state, ema_params, ml_grads, sp_loss, sp_grad_norm = p_train_step_speech_topk_masked( + state, ema_params, mask_logits, mel_b, sp_tgt_in_b, sp_tgt_out_b, causal_mask, prune_mask, tau_arr, hard_arr, step_rngs, spec_rngs, sp_lm_b, + ) + else: + state, ema_params, ml_grads, sp_loss, sp_grad_norm = p_train_step_speech_topk( + state, ema_params, mask_logits, mel_b, sp_tgt_in_b, sp_tgt_out_b, causal_mask, tau_arr, hard_arr, step_rngs, spec_rngs, sp_lm_b, + ) + if not cur_hard: + mask_logits, mask_opt_state = _update_mask_logits(ml_grads, mask_logits, mask_tx, mask_opt_state) + elif use_topk and not topk_active: + # Speech warmup grads not used for saliency, text is the primary task + rng, spec_rng = jax.random.split(rng) + spec_rngs = jax.random.split(spec_rng, num_devices) + state, ema_params, sp_loss, sp_grad_norm, _ = p_train_step_warmup_speech( + state, ema_params, mel_b, sp_tgt_in_b, sp_tgt_out_b, causal_mask, spec_rngs, sp_lm_b, ) else: - state, ema_params, sp_loss, sp_grad_norm = p_train_step_speech( - state, ema_params, mel_b, sp_tgt_in_b, sp_tgt_out_b, causal_mask, text_ffn_mask, spec_rngs, sp_lm_b, - ) + speech_ffn_mask = text_ffn_mask + rng, spec_rng = jax.random.split(rng) + spec_rngs = jax.random.split(spec_rng, num_devices) + if prune_mask is not None: + state, ema_params, sp_loss, sp_grad_norm = p_train_step_speech_masked( + state, ema_params, mel_b, sp_tgt_in_b, sp_tgt_out_b, causal_mask, prune_mask, speech_ffn_mask, spec_rngs, sp_lm_b, + ) + else: + state, ema_params, sp_loss, sp_grad_norm = p_train_step_speech( + state, ema_params, mel_b, sp_tgt_in_b, sp_tgt_out_b, causal_mask, speech_ffn_mask, spec_rngs, sp_lm_b, + ) speech_loss_val = float(sp_loss[0]) speech_losses.append(speech_loss_val) step_grad_norm = float(sp_grad_norm[0]) text_loss_val = text_losses[-1] if text_losses else float("nan") global_step += 1 + # Saliency init: promote accumulated saliency into mask_logits at warmup end + if use_saliency and saliency_accum is not None and global_step >= mat_warmup_end: + print(f"\n Saliency init: {saliency_steps} warmup steps accumulated") + logits_np = np.zeros_like(saliency_accum) + for bl in range(n_blocks): + ranks = np.argsort(np.argsort(-saliency_accum[bl])).astype(np.float32) + logits_np[bl] = saliency_scale * (1.0 - 2.0 * ranks / max(1, config.d_ff - 1)) + mask_logits_np = np.broadcast_to(logits_np[None, :, :], (n_mat, n_blocks, config.d_ff)).copy() + mask_logits = jax_utils.replicate(jnp.array(mask_logits_np)) + mask_opt_state = mask_tx.init(mask_logits_np) + saliency_accum = None + print(f" Saliency logit range: [{logits_np.min():.3f}, {logits_np.max():.3f}]") + if epoch == weight_prune_epoch and not gradual_sparsify_done: epoch_step += 1 current_sparsity = _cubic_sparsity_schedule(epoch_step, t_start, t_end, sparsity_ratio) @@ -796,6 +1166,9 @@ def _tile_sp(arr): postfix["sparsification"] = f"{current_sparsity*100:.1f}%" else: postfix["sparsification"] = "done" + if use_topk and topk_active: + phase = "freeze" if cur_hard else "learn" + postfix["mat"] = f"{phase} τ={cur_tau:.3f}" pbar.set_postfix(**postfix) if use_wandb: @@ -811,6 +1184,9 @@ def _tile_sp(arr): log_dict["train/speech_loss"] = speech_loss_val if epoch == weight_prune_epoch and not gradual_sparsify_done: log_dict["train/scheduled_sparsity"] = current_sparsity + if use_topk: + log_dict["train/mat_tau"] = cur_tau + log_dict["train/mat_topk_active"] = int(topk_active) if global_step % eval_every == 0 or global_step == total_steps: log_dict["val/text_ppl"] = last_val_ppl wandb.log(log_dict) @@ -870,8 +1246,21 @@ def _tile_sp(arr): mat_results = {} if _MAT_FACTORS: apply_fn = jax_utils.unreplicate(state).apply_fn + topk_hard_masks = {} + if use_topk: + ml_unr = jax_utils.unreplicate(mask_logits) + n_blocks_eval = ml_unr.shape[1] + for i, ff_w in enumerate(_MAT_FF_WIDTHS): + # Per-layer hard masks: (n_blocks, d_ff) + block_masks = [] + for b in range(n_blocks_eval): + block_masks.append(topk_mask(ml_unr[i, b], k=ff_w, tau=jnp.float32(0.001), hard=jnp.bool_(True))) + topk_hard_masks[ff_w] = jnp.stack(block_masks) # (n_blocks, d_ff) for factor, ff_w in zip(_MAT_FACTORS, _MAT_FF_WIDTHS): - mat_vl_fn = _make_mat_val_loss_fn(apply_fn, ff_w) + if ff_w in topk_hard_masks: + mat_vl_fn = _make_mat_val_loss_fn(apply_fn, ffn_mask=topk_hard_masks[ff_w]) + else: + mat_vl_fn = _make_mat_val_loss_fn(apply_fn, ff_width=ff_w) mat_total_loss, mat_total_toks = 0.0, 0.0 for vb in get_batches(val_enc, val_dec_in, val_dec_tgt, args.batch_size, shuffle=False, loss_mask=val_loss_mask): vl, vt = mat_vl_fn(eval_params, vb[0], vb[1], vb[2], val_causal, vb[3]) @@ -892,16 +1281,34 @@ def _tile_sp(arr): ckpt_name = f"needle_{args.num_layers}_{args.d_model}_{global_step}.pkl" ckpt_path = os.path.join(args.checkpoint_dir, ckpt_name) + ckpt_data_out = {"params": params_np, "config": config.__dict__} + if use_topk: + ml_np = np.array(jax_utils.unreplicate(mask_logits)) + ckpt_data_out["mask_logits"] = ml_np + ckpt_data_out["mat_method"] = "topk" + ckpt_data_out["mat_factors"] = list(_MAT_FACTORS) with open(ckpt_path, "wb") as f: - pickle.dump({"params": params_np, "config": config.__dict__}, f) + pickle.dump(ckpt_data_out, f) from .test import measure_throughput, benchmark_tool_calls from .run import generate_from_audio eval_params_jnp = jax.tree.map(jnp.array, params_np) - del params_np + del params_np model = EncoderDecoderTransformer(config) - tp = measure_throughput(model, eval_params_jnp, tokenizer, num_runs=5) - tc_metrics = benchmark_tool_calls(model, eval_params_jnp, tokenizer, num_samples=20, max_gen_len=128) + # Build ffn_mask dict from topk_hard_masks for sub-model eval during training + eval_ffn_mask = None + if use_topk and topk_hard_masks: + # Use the smallest (most compressed) sub-model for generation eval + smallest_ff_w = _MAT_FF_WIDTHS[-1] + if smallest_ff_w in topk_hard_masks: + hm = topk_hard_masks[smallest_ff_w] # (n_blocks, d_ff) + n_enc = config.num_encoder_layers + eval_ffn_mask = { + "encoder": hm[:n_enc, None, :], # (n_enc, 1, d_ff) + "decoder": hm[n_enc:, None, :], # (n_dec, 1, d_ff) + } + tp = measure_throughput(model, eval_params_jnp, tokenizer, num_runs=5, ffn_mask=eval_ffn_mask) + tc_metrics = benchmark_tool_calls(model, eval_params_jnp, tokenizer, num_samples=20, max_gen_len=128, ffn_mask=eval_ffn_mask) voice_tc_samples = [] if not no_speech and val_speech_audio_arrays is not None: @@ -912,6 +1319,7 @@ def _tile_sp(arr): pred_text = generate_from_audio( model, eval_params_jnp, tokenizer, audio, sr=sr, tools=pair["tools"], max_gen_len=128, seed=i, stream=False, + ffn_mask=eval_ffn_mask, ).strip() voice_tc_samples.append((pair["query"][:80], pair["answers"][:120], pred_text[:120])) From 7d67e4b31554674655860a73a53b9bf6bc42995a Mon Sep 17 00:00:00 2001 From: Noah Cylich Date: Sun, 8 Mar 2026 20:08:23 -0700 Subject: [PATCH 2/3] cleaned the code --- src/export.py | 30 +++++--- src/model.py | 54 +++++++------ src/train.py | 205 ++++++++++++++++++-------------------------------- 3 files changed, 125 insertions(+), 164 deletions(-) diff --git a/src/export.py b/src/export.py index 60c6f36..59be10d 100644 --- a/src/export.py +++ b/src/export.py @@ -20,6 +20,19 @@ _FFN_KERNEL_NAMES = {"gate_proj", "up_proj", "down_proj"} +def _to_numpy(tree): + """Convert all JAX arrays in a pytree to numpy arrays.""" + return jax.tree.map( + lambda x: np.asarray(x) if isinstance(x, jnp.ndarray) else x, tree + ) + + +def _param_stats(tree): + """Return (param_count, total_bytes) for a pytree of arrays.""" + leaves = jax.tree.leaves(tree) + return sum(x.size for x in leaves), sum(x.nbytes for x in leaves) + + def export_submodel(checkpoint_path, factor, output_path): """Export a matryoshka sub-model from a full checkpoint. @@ -50,9 +63,7 @@ def _export_topk(data, params, config, factor, output_path): ff_w = config.d_ff // factor per_layer_indices = [np.sort(np.argsort(-factor_logits[b])[:ff_w]) for b in range(factor_logits.shape[0])] - params_np = jax.tree.map( - lambda x: np.asarray(x) if isinstance(x, jnp.ndarray) else x, params - ) + params_np = _to_numpy(params) os.makedirs(os.path.dirname(output_path) or ".", exist_ok=True) with open(output_path, "wb") as f: @@ -65,8 +76,7 @@ def _export_topk(data, params, config, factor, output_path): }, f) n_blocks = len(per_layer_indices) - orig_count = sum(x.size for x in jax.tree.leaves(params)) - orig_bytes = sum(x.nbytes for x in jax.tree.leaves(params)) + orig_count, orig_bytes = _param_stats(params) print(f"\n TopK export: {output_path}") print(f" ─────────────────────────────────────") @@ -118,18 +128,14 @@ def slice_leaf(key_path, leaf): new_config = replace(config, d_ff=d_ff_new) - sliced_np = jax.tree.map( - lambda x: np.asarray(x) if isinstance(x, jnp.ndarray) else x, sliced - ) + sliced_np = _to_numpy(sliced) os.makedirs(os.path.dirname(output_path) or ".", exist_ok=True) with open(output_path, "wb") as f: pickle.dump({"params": sliced_np, "config": new_config.__dict__}, f) - orig_count = sum(x.size for x in jax.tree.leaves(params)) - new_count = sum(x.size for x in jax.tree.leaves(sliced_np)) - orig_bytes = sum(x.nbytes for x in jax.tree.leaves(params)) - new_bytes = sum(x.nbytes for x in jax.tree.leaves(sliced_np)) + orig_count, orig_bytes = _param_stats(params) + new_count, new_bytes = _param_stats(sliced_np) print(f"\n Export complete: {output_path}") print(f" ─────────────────────────────────────") diff --git a/src/model.py b/src/model.py index 688f190..2d50f95 100644 --- a/src/model.py +++ b/src/model.py @@ -202,6 +202,13 @@ def __call__(self, x, s, mask=None, rope=None, ffn_mask=None): return x, s +def _get_block_mask(ffn_mask, block_idx): + """Extract per-block FFN mask: index into first dim if 3D, otherwise pass through.""" + if ffn_mask is not None and ffn_mask.ndim == 3: + return ffn_mask[block_idx] + return ffn_mask + + class MemoryMixerEncoder(nn.Module): """Encoder using MemoryMixer blocks. Output is the final memory slots S.""" config: TransformerConfig @@ -231,7 +238,7 @@ def __call__(self, x, mask=None, rope=None, ffn_mask=None): mask = mask[..., :T_new] for i in range(cfg.num_encoder_layers): - block_mask = ffn_mask[i] if (ffn_mask is not None and ffn_mask.ndim == 3) else ffn_mask + block_mask = _get_block_mask(ffn_mask, i) x, s = nn.remat(MemoryMixerBlock)( cfg.num_heads, cfg.num_kv_heads, cfg.d_model, cfg.d_ff, cfg.num_memory_slots, cfg.total_layers, dt, cfg.activation, name=f"block_{i}" @@ -285,7 +292,7 @@ def __call__(self, x, encoder_out, self_mask=None, cross_mask=None, rope=None, f x = x.astype(dt) for i in range(cfg.num_decoder_layers): - block_mask = ffn_mask[i] if (ffn_mask is not None and ffn_mask.ndim == 3) else ffn_mask + block_mask = _get_block_mask(ffn_mask, i) x = nn.remat(DecoderBlock)( cfg.num_heads, cfg.num_kv_heads, cfg.d_model, cfg.d_ff, cfg.total_layers, dt, cfg.activation, name=f"block_{i}" )(x, encoder_out, self_mask=self_mask, cross_mask=cross_mask, rope=rope, ffn_mask=block_mask) @@ -395,23 +402,24 @@ def _split_ffn_mask(self, ffn_mask): return ffn_mask[:n_enc], ffn_mask[n_enc:] return ffn_mask, ffn_mask - def forward_masked(self, src, tgt, src_mask=None, tgt_mask=None, ffn_mask=None): - """Single forward with per-batch-item FFN masking. Returns (logits, slot_div).""" - enc_mask, dec_mask = self._split_ffn_mask(ffn_mask) - encoder_out = self.encode_text(src, src_mask=src_mask, ffn_mask=enc_mask) + def _forward_masked_impl(self, encoder_out, tgt, tgt_mask=None, dec_mask=None): + """Shared masked forward: decoder + logits + slot diversity.""" x_f32 = self._run_decoder(encoder_out, tgt, tgt_mask=tgt_mask, ffn_mask=dec_mask) logits = x_f32 @ self.embedding.embedding.T slot_div = self._slot_diversity(encoder_out) return logits, slot_div + def forward_masked(self, src, tgt, src_mask=None, tgt_mask=None, ffn_mask=None): + """Single forward with per-batch-item FFN masking. Returns (logits, slot_div).""" + enc_mask, dec_mask = self._split_ffn_mask(ffn_mask) + encoder_out = self.encode_text(src, src_mask=src_mask, ffn_mask=enc_mask) + return self._forward_masked_impl(encoder_out, tgt, tgt_mask=tgt_mask, dec_mask=dec_mask) + def forward_speech_masked(self, mel, tgt, src_mask=None, tgt_mask=None, ffn_mask=None, deterministic=True): """Single speech forward with per-batch-item FFN masking. Returns (logits, slot_div).""" enc_mask, dec_mask = self._split_ffn_mask(ffn_mask) encoder_out = self.encode_speech(mel, src_mask=src_mask, ffn_mask=enc_mask, deterministic=deterministic) - x_f32 = self._run_decoder(encoder_out, tgt, tgt_mask=tgt_mask, ffn_mask=dec_mask) - logits = x_f32 @ self.embedding.embedding.T - slot_div = self._slot_diversity(encoder_out) - return logits, slot_div + return self._forward_masked_impl(encoder_out, tgt, tgt_mask=tgt_mask, dec_mask=dec_mask) def _make_eval_ffn_mask(self, ff_width, B, dtype): """Default prefix FFN mask for eval: first ff_width neurons active.""" @@ -440,25 +448,24 @@ def _eval_sub_models(self, encode_fn, src, tgt, src_mask, tgt_mask, B, dtype, ma mat_logits.append(x_m @ emb.T) return mat_logits - def forward_with_aux(self, src, tgt, src_mask=None, tgt_mask=None, cross_mask=None, mat_ff_widths=None, mat_ffn_masks=None): - """Eval-only: separate per-width forwards for reporting per-width PPL.""" - encoder_out = self.encode_text(src, src_mask=src_mask) + def _forward_with_aux_impl(self, encode_fn, src, tgt, src_mask=None, tgt_mask=None, mat_ff_widths=None, mat_ffn_masks=None): + """Eval-only: full forward + per-width sub-model forwards for reporting per-width PPL.""" + encoder_out = encode_fn(src, src_mask=src_mask) x_f32 = self._run_decoder(encoder_out, tgt, tgt_mask=tgt_mask) logits = x_f32 @ self.embedding.embedding.T slot_div = self._slot_diversity(encoder_out) - mat_logits = self._eval_sub_models(self.encode_text, src, tgt, src_mask, tgt_mask, src.shape[0], x_f32.dtype, mat_ff_widths, mat_ffn_masks) + mat_logits = self._eval_sub_models(encode_fn, src, tgt, src_mask, tgt_mask, src.shape[0], x_f32.dtype, mat_ff_widths, mat_ffn_masks) return logits, slot_div, mat_logits + def forward_with_aux(self, src, tgt, src_mask=None, tgt_mask=None, cross_mask=None, mat_ff_widths=None, mat_ffn_masks=None): + return self._forward_with_aux_impl(self.encode_text, src, tgt, src_mask=src_mask, tgt_mask=tgt_mask, + mat_ff_widths=mat_ff_widths, mat_ffn_masks=mat_ffn_masks) + def forward_speech_with_aux(self, mel, tgt, src_mask=None, tgt_mask=None, mat_ff_widths=None, mat_ffn_masks=None, deterministic=True): - """Eval-only: separate per-width speech forwards for reporting per-width PPL.""" from functools import partial encode_fn = partial(self.encode_speech, deterministic=deterministic) - encoder_out = encode_fn(mel, src_mask=src_mask) - x_f32 = self._run_decoder(encoder_out, tgt, tgt_mask=tgt_mask) - logits = x_f32 @ self.embedding.embedding.T - slot_div = self._slot_diversity(encoder_out) - mat_logits = self._eval_sub_models(encode_fn, mel, tgt, src_mask, tgt_mask, mel.shape[0], x_f32.dtype, mat_ff_widths, mat_ffn_masks) - return logits, slot_div, mat_logits + return self._forward_with_aux_impl(encode_fn, mel, tgt, src_mask=src_mask, tgt_mask=tgt_mask, + mat_ff_widths=mat_ff_widths, mat_ffn_masks=mat_ffn_masks) def init_all(self, src, tgt, mel): """Dummy forward through both text and speech pathways to initialize all params.""" @@ -472,6 +479,11 @@ def init_all(self, src, tgt, mel): return jnp.zeros(()) +def count_params(params): + """Count total number of parameters in a pytree.""" + return sum(x.size for x in jax.tree.leaves(params)) + + def make_causal_mask(seq_len): mask = jnp.tril(jnp.ones((seq_len, seq_len), dtype=jnp.bool_)) return mask[None, None, :, :] diff --git a/src/train.py b/src/train.py index 1e8a458..181e8f2 100644 --- a/src/train.py +++ b/src/train.py @@ -20,6 +20,7 @@ from .model import ( EncoderDecoderTransformer, TransformerConfig, + count_params, make_causal_mask, make_padding_mask, make_mel_padding_mask, @@ -332,27 +333,31 @@ def _speech_loss_fn(state, params, mel, tgt_in, tgt_out, causal_mask, ffn_mask, return _compute_ce(logits, tgt_out, slot_div, loss_mask=loss_mask) -def _topk_loss(state, params, mask_logits, src, tgt_in, tgt_out, causal_mask, - tau, hard, step_rng, is_speech=False, spec_rng=None, loss_mask=None): - """Topk loss for text or speech. Builds masks inside for gradient flow.""" - ffn_mask = _make_ffn_mask_topk(src.shape[0], _D_FF, mask_logits, _MAT_FF_WIDTHS, tau, hard, step_rng) +def _forward_masked(state, params, src, tgt_in, causal_mask, ffn_mask, is_speech=False, spec_rng=None): + """Dispatch forward_masked or forward_speech_masked based on is_speech.""" + q_params = _quantize_params(params, group_size=_GROUP_SIZE) + tgt_mask = causal_mask & make_padding_mask(tgt_in, 0) if is_speech: - src_mask = make_mel_padding_mask(src) - logits, slot_div = state.apply_fn( - {"params": _quantize_params(params, group_size=_GROUP_SIZE)}, - src, tgt_in, src_mask=src_mask, - tgt_mask=causal_mask & make_padding_mask(tgt_in, 0), + return state.apply_fn( + {"params": q_params}, src, tgt_in, + src_mask=make_mel_padding_mask(src), tgt_mask=tgt_mask, ffn_mask=ffn_mask, deterministic=False, method="forward_speech_masked", rngs={"specaugment": spec_rng}, ) else: - src_mask = make_padding_mask(src, 0) - logits, slot_div = state.apply_fn( - {"params": _quantize_params(params, group_size=_GROUP_SIZE)}, - src, tgt_in, src_mask=src_mask, - tgt_mask=causal_mask & make_padding_mask(tgt_in, 0), + return state.apply_fn( + {"params": q_params}, src, tgt_in, + src_mask=make_padding_mask(src, 0), tgt_mask=tgt_mask, ffn_mask=ffn_mask, method="forward_masked", ) + + +def _topk_loss(state, params, mask_logits, src, tgt_in, tgt_out, causal_mask, + tau, hard, step_rng, is_speech=False, spec_rng=None, loss_mask=None): + """Topk loss for text or speech. Builds masks inside for gradient flow.""" + ffn_mask = _make_ffn_mask_topk(src.shape[0], _D_FF, mask_logits, _MAT_FF_WIDTHS, tau, hard, step_rng) + logits, slot_div = _forward_masked(state, params, src, tgt_in, causal_mask, ffn_mask, + is_speech=is_speech, spec_rng=spec_rng) loss = _compute_ce(logits, tgt_out, slot_div, loss_mask=loss_mask) if _MAT_SPREAD_LAMBDA > 0: spread = jnp.mean(jnp.var(mask_logits, axis=-1)) @@ -363,23 +368,8 @@ def _topk_loss(state, params, mask_logits, src, tgt_in, tgt_out, causal_mask, def _warmup_loss(state, params, src, tgt_in, tgt_out, causal_mask, is_speech=False, spec_rng=None, loss_mask=None): """Full-model-only loss (no matryoshka) for topk warmup phase.""" ffn_mask = jnp.ones((src.shape[0], _D_FF), dtype=jnp.bfloat16) - if is_speech: - src_mask = make_mel_padding_mask(src) - logits, slot_div = state.apply_fn( - {"params": _quantize_params(params, group_size=_GROUP_SIZE)}, - src, tgt_in, src_mask=src_mask, - tgt_mask=causal_mask & make_padding_mask(tgt_in, 0), - ffn_mask=ffn_mask, deterministic=False, - method="forward_speech_masked", rngs={"specaugment": spec_rng}, - ) - else: - src_mask = make_padding_mask(src, 0) - logits, slot_div = state.apply_fn( - {"params": _quantize_params(params, group_size=_GROUP_SIZE)}, - src, tgt_in, src_mask=src_mask, - tgt_mask=causal_mask & make_padding_mask(tgt_in, 0), - ffn_mask=ffn_mask, method="forward_masked", - ) + logits, slot_div = _forward_masked(state, params, src, tgt_in, causal_mask, ffn_mask, + is_speech=is_speech, spec_rng=spec_rng) return _compute_ce(logits, tgt_out, slot_div, loss_mask=loss_mask) @@ -534,36 +524,9 @@ def _update_mask_logits(ml_grads, mask_logits, mask_tx, mask_opt_state): return jax_utils.replicate(jnp.array(ml_np)), mask_opt_state -def _make_p_train_step(): - return jax.pmap(_train_step_text, axis_name="batch", donate_argnums=(0, 1)) - - -def _make_p_train_step_masked(): - return jax.pmap(_train_step_text_masked, axis_name="batch", donate_argnums=(0, 1)) - - -def _make_p_train_step_speech(): - return jax.pmap(_train_step_speech, axis_name="batch", donate_argnums=(0, 1)) - - -def _make_p_train_step_speech_masked(): - return jax.pmap(_train_step_speech_masked, axis_name="batch", donate_argnums=(0, 1)) - - -def _make_p_train_step_topk(): - return jax.pmap(_train_step_text_topk, axis_name="batch", donate_argnums=(0, 1)) - - -def _make_p_train_step_topk_masked(): - return jax.pmap(_train_step_text_topk_masked, axis_name="batch", donate_argnums=(0, 1)) - - -def _make_p_train_step_speech_topk(): - return jax.pmap(_train_step_speech_topk, axis_name="batch", donate_argnums=(0, 1)) - - -def _make_p_train_step_speech_topk_masked(): - return jax.pmap(_train_step_speech_topk_masked, axis_name="batch", donate_argnums=(0, 1)) +def _pmap_train_step(step_fn): + """Create a pmap'd train step from any step function.""" + return jax.pmap(step_fn, axis_name="batch", donate_argnums=(0, 1)) def _make_val_loss_fn(apply_fn): @@ -623,6 +586,16 @@ def val_loss_batch(params, mel, tgt_in, tgt_out, causal_mask, loss_mask): return val_loss_batch +def _eval_val_ppl(loss_fn, params, batches, causal_mask): + """Evaluate validation perplexity over batches. Returns PPL (capped at exp(20)).""" + total_loss, total_toks = 0.0, 0.0 + for vb in batches: + vl, vt = loss_fn(params, vb[0], vb[1], vb[2], causal_mask, vb[3]) + total_loss += float(vl) + total_toks += float(vt) + return float(math.exp(min(total_loss / max(total_toks, 1), 20))) + + def _estimate_mat_params(config, matryoshka_factor): """Estimate parameter count of a sub-model at a given matryoshka factor. @@ -653,6 +626,12 @@ def shard_batch(batch, num_devices): return batch.reshape(num_devices, -1, *batch.shape[1:]) +def _tile_for_mat(arr, num_devices, per_width, n_widths): + """Tile a batch for shared-input matryoshka: repeat each sample across all widths.""" + s = arr.reshape(num_devices, per_width, *arr.shape[1:]) + return np.tile(s, (1, n_widths) + (1,) * (arr.ndim - 1)) + + def train(args): num_devices = jax.local_device_count() no_speech = getattr(args, "no_speech", False) @@ -775,16 +754,16 @@ def train(args): _MAT_GUMBEL = getattr(args, "mat_gumbel", False) if use_topk: - p_train_step_warmup = jax.pmap(_train_step_text_warmup, axis_name="batch", donate_argnums=(0, 1)) - p_train_step_warmup_speech = jax.pmap(_train_step_speech_warmup, axis_name="batch", donate_argnums=(0, 1)) - p_train_step_topk = _make_p_train_step_topk() - p_train_step_topk_masked = _make_p_train_step_topk_masked() - p_train_step_speech_topk = _make_p_train_step_speech_topk() - p_train_step_speech_topk_masked = _make_p_train_step_speech_topk_masked() - p_train_step = _make_p_train_step() - p_train_step_masked = _make_p_train_step_masked() - p_train_step_speech = _make_p_train_step_speech() - p_train_step_speech_masked = _make_p_train_step_speech_masked() + p_train_step_warmup = _pmap_train_step(_train_step_text_warmup) + p_train_step_warmup_speech = _pmap_train_step(_train_step_speech_warmup) + p_train_step_topk = _pmap_train_step(_train_step_text_topk) + p_train_step_topk_masked = _pmap_train_step(_train_step_text_topk_masked) + p_train_step_speech_topk = _pmap_train_step(_train_step_speech_topk) + p_train_step_speech_topk_masked = _pmap_train_step(_train_step_speech_topk_masked) + p_train_step = _pmap_train_step(_train_step_text) + p_train_step_masked = _pmap_train_step(_train_step_text_masked) + p_train_step_speech = _pmap_train_step(_train_step_speech) + p_train_step_speech_masked = _pmap_train_step(_train_step_speech_masked) np.random.seed(args.seed) rng = jax.random.PRNGKey(args.seed) @@ -866,7 +845,7 @@ def train(args): if use_topk: mask_logits = jax_utils.replicate(mask_logits) - param_count = sum(x.size for x in jax.tree.leaves(jax_utils.unreplicate(state).params)) + param_count = count_params(jax_utils.unreplicate(state).params) decay_steps = max(1, int(total_steps * 0.15)) stable_steps = total_steps - warmup_steps - decay_steps @@ -1004,14 +983,8 @@ def train(args): text_idx += 1 if not use_topk and n_widths > 1 and mat_shared_input: - per_width = args.batch_size // n_widths - def _tile_for_mat(arr): - s = arr.reshape(num_devices, per_width, *arr.shape[1:]) - return np.tile(s, (1, n_widths) + (1,) * (arr.ndim - 1)) - src_b = _tile_for_mat(src) - tgt_in_b = _tile_for_mat(tgt_in) - tgt_out_b = _tile_for_mat(tgt_out) - lm_b = _tile_for_mat(lm) + tile = lambda arr: _tile_for_mat(arr, num_devices, args.batch_size // n_widths, n_widths) + src_b, tgt_in_b, tgt_out_b, lm_b = tile(src), tile(tgt_in), tile(tgt_out), tile(lm) else: src_b = shard_batch(src, num_devices) tgt_in_b = shard_batch(tgt_in, num_devices) @@ -1064,14 +1037,8 @@ def _tile_for_mat(arr): speech_idx += 1 if not use_topk and n_widths > 1 and mat_shared_input: - per_width = args.batch_size // n_widths - def _tile_sp(arr): - s = arr.reshape(num_devices, per_width, *arr.shape[1:]) - return np.tile(s, (1, n_widths) + (1,) * (arr.ndim - 1)) - mel_b = _tile_sp(mel_batch) - sp_tgt_in_b = _tile_sp(sp_tgt_in) - sp_tgt_out_b = _tile_sp(sp_tgt_out) - sp_lm_b = _tile_sp(sp_lm) + tile = lambda arr: _tile_for_mat(arr, num_devices, args.batch_size // n_widths, n_widths) + mel_b, sp_tgt_in_b, sp_tgt_out_b, sp_lm_b = tile(mel_batch), tile(sp_tgt_in), tile(sp_tgt_out), tile(sp_lm) else: mel_b = shard_batch(mel_batch, num_devices) sp_tgt_in_b = shard_batch(sp_tgt_in, num_devices) @@ -1147,13 +1114,9 @@ def _tile_sp(arr): if global_step % eval_every == 0 or global_step == total_steps: _eval_params = jax_utils.unreplicate(ema_params) val_causal = make_causal_mask(args.max_dec_len) - total_loss, total_toks = 0.0, 0.0 - for vb in get_batches(val_enc, val_dec_in, val_dec_tgt, args.batch_size, shuffle=False, loss_mask=val_loss_mask): - vl, vt = val_loss_fn(_eval_params, vb[0], vb[1], vb[2], val_causal, vb[3]) - total_loss += float(vl) - total_toks += float(vt) - last_val_ppl = float(math.exp(min(total_loss / max(total_toks, 1), 20))) - + last_val_ppl = _eval_val_ppl(val_loss_fn, _eval_params, + get_batches(val_enc, val_dec_in, val_dec_tgt, args.batch_size, shuffle=False, loss_mask=val_loss_mask), + val_causal) del _eval_params postfix = { @@ -1203,7 +1166,7 @@ def _tile_sp(arr): params=jax.tree.map(lambda w, m: w * m, state.params, prune_mask)) ema_params = jax.tree.map(lambda w, m: w * m, ema_params, prune_mask) final_pruned = jax.tree.map(np.array, jax_utils.unreplicate(ema_params)) - total_p = sum(x.size for x in jax.tree.leaves(final_pruned)) + total_p = count_params(final_pruned) zero_p = sum(int(np.sum(np.abs(x) < 1e-6)) for x in jax.tree.leaves(final_pruned)) print(f"\n Gradual sparsification complete — mask locked.") print(f" Final sparsity: {zero_p/total_p*100:.2f}% ({zero_p:,}/{total_p:,} near-zero)") @@ -1215,33 +1178,21 @@ def _tile_sp(arr): eval_params = jax_utils.unreplicate(ema_params) val_causal = make_causal_mask(args.max_dec_len) - total_loss, total_toks = 0.0, 0.0 - for vb in get_batches(val_enc, val_dec_in, val_dec_tgt, args.batch_size, shuffle=False, loss_mask=val_loss_mask): - vl, vt = val_loss_fn(eval_params, vb[0], vb[1], vb[2], val_causal, vb[3]) - total_loss += float(vl) - total_toks += float(vt) - last_val_ppl = float(math.exp(min(total_loss / max(total_toks, 1), 20))) + val_batches = lambda: get_batches(val_enc, val_dec_in, val_dec_tgt, args.batch_size, shuffle=False, loss_mask=val_loss_mask) + last_val_ppl = _eval_val_ppl(val_loss_fn, eval_params, val_batches(), val_causal) speech_val_ppl = None if speech_vl_fn is not None and val_speech_audio_arrays is not None: - sp_total_loss, sp_total_toks = 0.0, 0.0 - for sp_batch in get_speech_batches(val_speech_audio_arrays, val_speech_dec_in, val_speech_dec_tgt, args.batch_size, - shuffle=False, loss_mask=val_speech_loss_mask, - n_mels=n_mels, max_mel_len=max_mel_len, augmenter=None): - vl, vt = speech_vl_fn(eval_params, sp_batch[0], sp_batch[1], sp_batch[2], val_causal, sp_batch[3]) - sp_total_loss += float(vl) - sp_total_toks += float(vt) - speech_val_loss = sp_total_loss / max(sp_total_toks, 1) - speech_val_ppl = float(math.exp(min(speech_val_loss, 20))) + speech_val_ppl = _eval_val_ppl( + speech_vl_fn, eval_params, + get_speech_batches(val_speech_audio_arrays, val_speech_dec_in, val_speech_dec_tgt, args.batch_size, + shuffle=False, loss_mask=val_speech_loss_mask, + n_mels=n_mels, max_mel_len=max_mel_len, augmenter=None), + val_causal) q_params = _quantize_params(eval_params, group_size=_GROUP_SIZE) - q_total_loss, q_total_toks = 0.0, 0.0 - for vb in get_batches(val_enc, val_dec_in, val_dec_tgt, args.batch_size, shuffle=False, loss_mask=val_loss_mask): - vl, vt = val_loss_fn(q_params, vb[0], vb[1], vb[2], val_causal, vb[3]) - q_total_loss += float(vl) - q_total_toks += float(vt) - quant_val_ppl = float(math.exp(min(q_total_loss / max(q_total_toks, 1), 20))) - del q_params + quant_val_ppl = _eval_val_ppl(val_loss_fn, q_params, val_batches(), val_causal) + del q_params mat_results = {} if _MAT_FACTORS: @@ -1251,30 +1202,22 @@ def _tile_sp(arr): ml_unr = jax_utils.unreplicate(mask_logits) n_blocks_eval = ml_unr.shape[1] for i, ff_w in enumerate(_MAT_FF_WIDTHS): - # Per-layer hard masks: (n_blocks, d_ff) - block_masks = [] - for b in range(n_blocks_eval): - block_masks.append(topk_mask(ml_unr[i, b], k=ff_w, tau=jnp.float32(0.001), hard=jnp.bool_(True))) - topk_hard_masks[ff_w] = jnp.stack(block_masks) # (n_blocks, d_ff) + block_masks = [topk_mask(ml_unr[i, b], k=ff_w, tau=jnp.float32(0.001), hard=jnp.bool_(True)) + for b in range(n_blocks_eval)] + topk_hard_masks[ff_w] = jnp.stack(block_masks) for factor, ff_w in zip(_MAT_FACTORS, _MAT_FF_WIDTHS): if ff_w in topk_hard_masks: mat_vl_fn = _make_mat_val_loss_fn(apply_fn, ffn_mask=topk_hard_masks[ff_w]) else: mat_vl_fn = _make_mat_val_loss_fn(apply_fn, ff_width=ff_w) - mat_total_loss, mat_total_toks = 0.0, 0.0 - for vb in get_batches(val_enc, val_dec_in, val_dec_tgt, args.batch_size, shuffle=False, loss_mask=val_loss_mask): - vl, vt = mat_vl_fn(eval_params, vb[0], vb[1], vb[2], val_causal, vb[3]) - mat_total_loss += float(vl) - mat_total_toks += float(vt) - avg_loss = mat_total_loss / max(mat_total_toks, 1) - mat_ppl = float(math.exp(min(avg_loss, 20))) + mat_ppl = _eval_val_ppl(mat_vl_fn, eval_params, val_batches(), val_causal) mat_params = _estimate_mat_params(config, factor) mat_results[factor] = (mat_ppl, mat_params, ff_w) del apply_fn params_np = jax.tree.map(np.array, eval_params) del eval_params - total_params = sum(x.size for x in jax.tree.leaves(params_np)) + total_params = count_params(params_np) near_zero = sum(int(np.sum(np.abs(x) < 1e-6)) for x in jax.tree.leaves(params_np)) sparsity = near_zero / total_params * 100 From f92c639932c6256304bff45982ecd0171e88b988 Mon Sep 17 00:00:00 2001 From: Noah Cylich Date: Mon, 9 Mar 2026 21:36:53 -0700 Subject: [PATCH 3/3] cleaned git diff --- src/cli.py | 87 ++----------------------------------------------- src/evaluate.py | 2 +- src/model.py | 4 +-- src/run.py | 23 +++++++++---- src/test.py | 63 +++++++++++++++++++++++++++++++++++ src/train.py | 76 +++++++++++++++++++++++------------------- 6 files changed, 127 insertions(+), 128 deletions(-) diff --git a/src/cli.py b/src/cli.py index b83b6a0..fac5aea 100644 --- a/src/cli.py +++ b/src/cli.py @@ -1,90 +1,7 @@ import argparse import sys -HELP = """ - ┌───────────────────────────────────────────────────────────────────┐ - │ │ - │ ┌─┐┌─┐┌─┐┌┬┐┬ ┬┌─┐ ┌┐┌┌─┐┌─┐┌┬┐┬ ┌─┐ │ - │ │ ├─┤│ │ │ │└─┐ │││├┤ ├┤ │││ ├┤ │ - │ └─┘┴ ┴└─┘ ┴ └─┘└─┘ ┘└┘└─┘└─┘─┴┘┴─┘└─┘ │ - │ ...the tiny model to rule them all... │ - │ │ - │ train │ - │ --full Use full 1B config (~1.17B params) │ - │ --epochs INT Training epochs (default: 1) │ - │ --batch-size INT Batch size (default: 32) │ - │ --lr FLOAT AdamW learning rate (default: 3e-4) │ - │ --muon-lr FLOAT Muon learning rate (default: 0.02) │ - │ --d-model INT Model dim (default: 512) │ - │ --num-heads INT Attention heads (default: 8) │ - │ --num-kv-heads INT KV heads for GQA (default: num-heads)│ - │ --num-layers INT Encoder layers (default: 8) │ - │ --num-dec-layers INT Decoder layers (default: 4) │ - │ --max-enc-len INT Max encoder seq length (default: 256)│ - │ --max-dec-len INT Max decoder seq length (default: 256)│ - │ --max-samples INT Training samples (default: all) │ - │ --mat-factors INT [...] FFN shrink factors (default: 2 4 8) │ - │ --mat-method STR static-prefix|topk (default: topk) │ - │ --mat-init-mode STR saliency|prefix|normal (def: sal.) │ - │ --mat-warmup-frac FL Saliency warmup fraction (def: 0.4) │ - │ --mat-freeze-frac FL Mask freeze fraction (default: 1.0) │ - │ --mat-tau-start FLOAT TopK tau start (default: 0.5) │ - │ --mat-tau-end FLOAT TopK tau end (default: 0.1) │ - │ --mat-mask-lr FLOAT Mask logit LR (default: 3e-3) │ - │ --sparsity-ratio FLOAT Block prune ratio (default: 0.5) │ - │ --group-size INT Quant/prune group size (default: 32) │ - │ --prune-interval INT Steps between mask updates (def: 100)│ - │ --prune-start-frac FL Start pruning at frac (def: 0.33) │ - │ --prune-end-frac FL Lock mask at this frac (def: 0.67) │ - │ --activation STR drelu|swiglu|geglu (default: drelu) │ - │ --warmup-ratio FLOAT LR warmup ratio (default: 0.05) │ - │ --eval-every INT Val eval interval (default: 1000) │ - │ --wandb Enable W&B logging │ - │ --checkpoint PATH Resume from checkpoint │ - │ --checkpoint-dir DIR Checkpoint directory │ - │ --seed INT Random seed (default: 42) │ - │ --no-speech Disable speech (text-only training) │ - │ --speech-every INT Speech step every N text (default: 3) │ - │ --max-mel-len INT Max mel frames (default: 1024) │ - │ --n-mels INT Mel frequency bins (default: 80) │ - │ --max-speech-samples INT Max LibriSpeech samples │ - │ │ - │ run │ - │ --checkpoint PATH Path to model checkpoint (required) │ - │ --query STR Query text for tool-call generation │ - │ --tools STR Tools JSON for tool-call generation │ - │ --audio PATH [...] Audio files for voice-to-tool-call │ - │ --max-len INT Max tokens to generate (default: 512) │ - │ --seed INT Random seed (default: 0) │ - │ │ - │ test │ - │ --checkpoint PATH Path to model checkpoint (required) │ - │ --batch-size INT Batch size (default: 32) │ - │ --max-eval-samples INT Evaluation samples (default: 1000) │ - │ --max-gen-len INT Max generation length (default: 512) │ - │ --tool-call-samples INT Tool-call accuracy samples (def: 200) │ - │ --voice-tc-samples INT Voice-tool-call samples (default: 50) │ - │ --throughput-runs INT Throughput runs (default: 10) │ - │ │ - │ evaluate │ - │ --checkpoint PATH Path to model checkpoint (required) │ - │ --benchmarks [...] wikitext2 lambada hellaswag arc_easy │ - │ --max-samples INT Samples per benchmark (default: 500) │ - │ │ - │ tpu │ - │ create NAME Create TPU (auto-finds zone) │ - │ --type STR Accelerator (default: v6e-8) │ - │ --version STR TPU OS (auto-detected from --type) │ - │ connect NAME SSH config + connect (auto-zone) │ - │ claude NAME Install Claude Code on instance │ - │ stop NAME Stop instance (auto-zone) │ - │ start NAME Start stopped instance (auto-zone) │ - │ delete NAME Delete instance (auto-zone) │ - │ list List all TPU instances │ - │ --zone ZONE Override auto-detected zone │ - │ │ - └───────────────────────────────────────────────────────────────────┘ -""" +HELP = """Check the readme""" def main(): if len(sys.argv) < 2 or sys.argv[1] in ("-h", "--help", "help"): @@ -147,7 +64,7 @@ def main(): p.add_argument("--mat-gumbel", action="store_true", help="Use Gumbel noise for per-item mask diversity during topk learning") p.add_argument("--dropout", type=float, default=0.0, - help="Dropout rate for residual connections (default: 0.1)") + help="Dropout rate for residual connections (default: 0.0)") p.add_argument("--no-speech", action="store_true", help="Disable speech training (text-only)") p.add_argument("--max-mel-len", type=int, default=1024, help="Max mel spectrogram frames (default: 1024)") diff --git a/src/evaluate.py b/src/evaluate.py index a30dbc1..8640e24 100644 --- a/src/evaluate.py +++ b/src/evaluate.py @@ -37,7 +37,7 @@ def _decode(params, dec_input, encoder_out, tgt_mask, _unused_cross_mask): def _shard_single(x, num_devices): """Replicate a single-sample batch across all devices for pmap.""" - return jnp.broadcast_to(x[None], (num_devices, *x.shape)) + return jnp.broadcast_to(x, (num_devices, *x.shape[1:])) def score_sequence(model, params, enc_tokens, dec_tokens, pad_id, sos_id=None, diff --git a/src/model.py b/src/model.py index 6592540..f1d498c 100644 --- a/src/model.py +++ b/src/model.py @@ -46,7 +46,7 @@ class TransformerConfig: activation: str = "drelu" num_memory_slots: int = 64 n_mels: int = 80 - dropout_rate: float = 0.1 + dropout_rate: float = 0.0 @property def jax_dtype(self): @@ -460,7 +460,7 @@ def _forward_with_aux_impl(self, encode_fn, src, tgt, src_mask=None, tgt_mask=No mat_logits = self._eval_sub_models(encode_fn, src, tgt, src_mask, tgt_mask, src.shape[0], x_f32.dtype, mat_ff_widths, mat_ffn_masks) return logits, slot_div, mat_logits - def forward_with_aux(self, src, tgt, src_mask=None, tgt_mask=None, cross_mask=None, mat_ff_widths=None, mat_ffn_masks=None): + def forward_with_aux(self, src, tgt, src_mask=None, tgt_mask=None, mat_ff_widths=None, mat_ffn_masks=None): return self._forward_with_aux_impl(self.encode_text, src, tgt, src_mask=src_mask, tgt_mask=tgt_mask, mat_ff_widths=mat_ff_widths, mat_ffn_masks=mat_ffn_masks) diff --git a/src/run.py b/src/run.py index e6edbfd..f01f9b0 100644 --- a/src/run.py +++ b/src/run.py @@ -218,11 +218,16 @@ def main(args): print(f"Tools: {tools[:80]}{'...' if len(tools) > 80 else ''}") audio, sr = load_audio(audio_path) generate_from_audio( - model, params, tokenizer, audio, - sr=sr, tools=tools, + model, + params, + tokenizer, + audio, + sr=sr, + tools=tools, max_gen_len=args.max_len, seed=args.seed + i, - stream=True, ffn_mask=ffn_mask, + stream=True, + ffn_mask=ffn_mask, ) return @@ -242,9 +247,15 @@ def main(args): print(f"\nQuery: {q}") print(f"Tools: {t[:80]}{'...' if len(t) > 80 else ''}") generate( - model, params, tokenizer, q, - tools=t, max_gen_len=args.max_len, - seed=args.seed + i, stream=True, ffn_mask=ffn_mask, + model, + params, + tokenizer, + q, + tools=t, + max_gen_len=args.max_len, + seed=args.seed + i, + stream=True, + ffn_mask=ffn_mask, ) diff --git a/src/test.py b/src/test.py index ce86ea2..d97219d 100644 --- a/src/test.py +++ b/src/test.py @@ -113,6 +113,69 @@ def measure_throughput(model, params, tokenizer, num_runs=10, prompt='What is th } +def compute_repetition_rate(texts): + bigram_rep_rates = [] + for text in texts: + words = text.lower().split() + if len(words) < 2: + bigram_rep_rates.append(0.0) + continue + bigrams = [(words[i], words[i + 1]) for i in range(len(words) - 1)] + unique = len(set(bigrams)) + bigram_rep_rates.append(1.0 - unique / len(bigrams)) + return float(np.mean(bigram_rep_rates)) + + +def benchmark_generation_quality(model, params, tokenizer, prompts, max_gen_len=128, temperature=0.8): + from .run import generate + + generations = [] + for i, prompt in enumerate(prompts): + text = generate(model, params, tokenizer, prompt, max_gen_len, temperature, seed=i, stream=False) + generations.append(text) + + lengths = [len(tokenizer.encode(t)) for t in generations] + rep_rate = compute_repetition_rate(generations) + + return { + "avg_generation_length": float(np.mean(lengths)), + "min_generation_length": int(np.min(lengths)), + "max_generation_length": int(np.max(lengths)), + "bigram_repetition_rate": rep_rate, + "generations": list(zip(prompts, generations)), + } + + +def compute_wer(hypotheses, references): + """Compute word error rate using edit distance.""" + total_edits = 0 + total_ref_words = 0 + + for hyp, ref in zip(hypotheses, references): + hyp_words = hyp.lower().split() + ref_words = ref.lower().split() + n = len(ref_words) + m = len(hyp_words) + + # DP edit distance + d = [[0] * (m + 1) for _ in range(n + 1)] + for i in range(n + 1): + d[i][0] = i + for j in range(m + 1): + d[0][j] = j + for i in range(1, n + 1): + for j in range(1, m + 1): + if ref_words[i - 1] == hyp_words[j - 1]: + d[i][j] = d[i - 1][j - 1] + else: + d[i][j] = 1 + min(d[i - 1][j], d[i][j - 1], d[i - 1][j - 1]) + + total_edits += d[n][m] + total_ref_words += n + + return total_edits / max(total_ref_words, 1) + + def benchmark_tool_calls(model, params, tokenizer, num_samples=200, max_gen_len=512, ffn_mask=None): """Generate tool-call predictions and compute structured metrics.""" import json diff --git a/src/train.py b/src/train.py index a639878..3045bba 100644 --- a/src/train.py +++ b/src/train.py @@ -337,31 +337,38 @@ def _speech_loss_fn(state, params, mel, tgt_in, tgt_out, causal_mask, ffn_mask, return _compute_ce(logits, tgt_out, slot_div, loss_mask=loss_mask) -def _forward_masked(state, params, src, tgt_in, causal_mask, ffn_mask, is_speech=False, spec_rng=None): +def _forward_masked(state, params, src, tgt_in, causal_mask, ffn_mask, is_speech=False, spec_rng=None, drop_rng=None): """Dispatch forward_masked or forward_speech_masked based on is_speech.""" q_params = _quantize_params(params, group_size=_GROUP_SIZE) tgt_mask = causal_mask & make_padding_mask(tgt_in, 0) + deterministic = drop_rng is None + rngs = {} + if drop_rng is not None: + rngs["dropout"] = drop_rng + if spec_rng is not None: + rngs["specaugment"] = spec_rng if is_speech: return state.apply_fn( {"params": q_params}, src, tgt_in, src_mask=make_mel_padding_mask(src), tgt_mask=tgt_mask, - ffn_mask=ffn_mask, deterministic=False, - method="forward_speech_masked", rngs={"specaugment": spec_rng}, + ffn_mask=ffn_mask, deterministic=deterministic, + method="forward_speech_masked", rngs=rngs, ) else: return state.apply_fn( {"params": q_params}, src, tgt_in, src_mask=make_padding_mask(src, 0), tgt_mask=tgt_mask, - ffn_mask=ffn_mask, method="forward_masked", + ffn_mask=ffn_mask, deterministic=deterministic, + method="forward_masked", rngs=rngs, ) def _topk_loss(state, params, mask_logits, src, tgt_in, tgt_out, causal_mask, - tau, hard, step_rng, is_speech=False, spec_rng=None, loss_mask=None): + tau, hard, step_rng, is_speech=False, spec_rng=None, loss_mask=None, drop_rng=None): """Topk loss for text or speech. Builds masks inside for gradient flow.""" ffn_mask = _make_ffn_mask_topk(src.shape[0], _D_FF, mask_logits, _MAT_FF_WIDTHS, tau, hard, step_rng) logits, slot_div = _forward_masked(state, params, src, tgt_in, causal_mask, ffn_mask, - is_speech=is_speech, spec_rng=spec_rng) + is_speech=is_speech, spec_rng=spec_rng, drop_rng=drop_rng) loss = _compute_ce(logits, tgt_out, slot_div, loss_mask=loss_mask) if _MAT_SPREAD_LAMBDA > 0: spread = jnp.mean(jnp.var(mask_logits, axis=-1)) @@ -369,11 +376,11 @@ def _topk_loss(state, params, mask_logits, src, tgt_in, tgt_out, causal_mask, return loss -def _warmup_loss(state, params, src, tgt_in, tgt_out, causal_mask, is_speech=False, spec_rng=None, loss_mask=None): +def _warmup_loss(state, params, src, tgt_in, tgt_out, causal_mask, is_speech=False, spec_rng=None, loss_mask=None, drop_rng=None): """Full-model-only loss (no matryoshka) for topk warmup phase.""" ffn_mask = jnp.ones((src.shape[0], _D_FF), dtype=jnp.bfloat16) logits, slot_div = _forward_masked(state, params, src, tgt_in, causal_mask, ffn_mask, - is_speech=is_speech, spec_rng=spec_rng) + is_speech=is_speech, spec_rng=spec_rng, drop_rng=drop_rng) return _compute_ce(logits, tgt_out, slot_div, loss_mask=loss_mask) @@ -438,10 +445,10 @@ def _apply_and_ema(state, ema_params, grads, prune_mask=None): return state, ema -def _train_step_text_warmup(state, ema_params, src, tgt_in, tgt_out, causal_mask, loss_mask): +def _train_step_text_warmup(state, ema_params, src, tgt_in, tgt_out, causal_mask, drop_rng, loss_mask): """Text warmup step (no matryoshka). Returns grads for saliency accumulation.""" loss, grads = jax.value_and_grad( - lambda p: _warmup_loss(state, p, src, tgt_in, tgt_out, causal_mask, loss_mask=loss_mask) + lambda p: _warmup_loss(state, p, src, tgt_in, tgt_out, causal_mask, loss_mask=loss_mask, drop_rng=drop_rng) )(state.params) grads = jax.lax.pmean(grads, axis_name="batch") loss = jax.lax.pmean(loss, axis_name="batch") @@ -449,10 +456,10 @@ def _train_step_text_warmup(state, ema_params, src, tgt_in, tgt_out, causal_mask return state, ema_params, loss, optax.global_norm(grads), grads -def _train_step_speech_warmup(state, ema_params, mel, tgt_in, tgt_out, causal_mask, spec_rng, loss_mask): +def _train_step_speech_warmup(state, ema_params, mel, tgt_in, tgt_out, causal_mask, spec_rng, drop_rng, loss_mask): """Speech warmup step (no matryoshka).""" loss, grads = jax.value_and_grad( - lambda p: _warmup_loss(state, p, mel, tgt_in, tgt_out, causal_mask, is_speech=True, spec_rng=spec_rng, loss_mask=loss_mask) + lambda p: _warmup_loss(state, p, mel, tgt_in, tgt_out, causal_mask, is_speech=True, spec_rng=spec_rng, loss_mask=loss_mask, drop_rng=drop_rng) )(state.params) grads = jax.lax.pmean(grads, axis_name="batch") loss = jax.lax.pmean(loss, axis_name="batch") @@ -470,25 +477,25 @@ def _topk_grad_step(state, ema_params, mask_logits, loss_fn, prune_mask=None): return state, ema_params, ml_grads, loss, optax.global_norm(p_grads) -def _train_step_text_topk(state, ema_params, mask_logits, src, tgt_in, tgt_out, causal_mask, tau, hard, step_rng, loss_mask): +def _train_step_text_topk(state, ema_params, mask_logits, src, tgt_in, tgt_out, causal_mask, tau, hard, step_rng, drop_rng, loss_mask): return _topk_grad_step(state, ema_params, mask_logits, - lambda p, ml: _topk_loss(state, p, ml, src, tgt_in, tgt_out, causal_mask, tau, hard, step_rng, loss_mask=loss_mask)) + lambda p, ml: _topk_loss(state, p, ml, src, tgt_in, tgt_out, causal_mask, tau, hard, step_rng, loss_mask=loss_mask, drop_rng=drop_rng)) -def _train_step_text_topk_masked(state, ema_params, mask_logits, src, tgt_in, tgt_out, causal_mask, prune_mask, tau, hard, step_rng, loss_mask): +def _train_step_text_topk_masked(state, ema_params, mask_logits, src, tgt_in, tgt_out, causal_mask, prune_mask, tau, hard, step_rng, drop_rng, loss_mask): return _topk_grad_step(state, ema_params, mask_logits, - lambda p, ml: _topk_loss(state, p, ml, src, tgt_in, tgt_out, causal_mask, tau, hard, step_rng, loss_mask=loss_mask), + lambda p, ml: _topk_loss(state, p, ml, src, tgt_in, tgt_out, causal_mask, tau, hard, step_rng, loss_mask=loss_mask, drop_rng=drop_rng), prune_mask=prune_mask) -def _train_step_speech_topk(state, ema_params, mask_logits, mel, tgt_in, tgt_out, causal_mask, tau, hard, step_rng, spec_rng, loss_mask): +def _train_step_speech_topk(state, ema_params, mask_logits, mel, tgt_in, tgt_out, causal_mask, tau, hard, step_rng, spec_rng, drop_rng, loss_mask): return _topk_grad_step(state, ema_params, mask_logits, - lambda p, ml: _topk_loss(state, p, ml, mel, tgt_in, tgt_out, causal_mask, tau, hard, step_rng, is_speech=True, spec_rng=spec_rng, loss_mask=loss_mask)) + lambda p, ml: _topk_loss(state, p, ml, mel, tgt_in, tgt_out, causal_mask, tau, hard, step_rng, is_speech=True, spec_rng=spec_rng, loss_mask=loss_mask, drop_rng=drop_rng)) -def _train_step_speech_topk_masked(state, ema_params, mask_logits, mel, tgt_in, tgt_out, causal_mask, prune_mask, tau, hard, step_rng, spec_rng, loss_mask): +def _train_step_speech_topk_masked(state, ema_params, mask_logits, mel, tgt_in, tgt_out, causal_mask, prune_mask, tau, hard, step_rng, spec_rng, drop_rng, loss_mask): return _topk_grad_step(state, ema_params, mask_logits, - lambda p, ml: _topk_loss(state, p, ml, mel, tgt_in, tgt_out, causal_mask, tau, hard, step_rng, is_speech=True, spec_rng=spec_rng, loss_mask=loss_mask), + lambda p, ml: _topk_loss(state, p, ml, mel, tgt_in, tgt_out, causal_mask, tau, hard, step_rng, is_speech=True, spec_rng=spec_rng, loss_mask=loss_mask, drop_rng=drop_rng), prune_mask=prune_mask) @@ -705,7 +712,7 @@ def train(args): activation=getattr(args, "activation", "drelu"), num_memory_slots=getattr(args, "num_memory_slots", 64), n_mels=n_mels, - dropout_rate=getattr(args, "dropout", 0.1), + dropout_rate=getattr(args, "dropout", 0.0), ) global _GROUP_SIZE, _MAT_FACTORS, _MAT_FF_WIDTHS, _D_FF, _N_BLOCKS, _MAT_SPREAD_LAMBDA, _MAT_GUMBEL @@ -983,17 +990,17 @@ def train(args): step_rngs = jax.random.split(step_rng, num_devices) if prune_mask is not None: state, ema_params, ml_grads, loss, grad_norm = p_train_step_topk_masked( - state, ema_params, mask_logits, src_b, tgt_in_b, tgt_out_b, causal_mask, prune_mask, tau_arr, hard_arr, step_rngs, lm_b, + state, ema_params, mask_logits, src_b, tgt_in_b, tgt_out_b, causal_mask, prune_mask, tau_arr, hard_arr, step_rngs, text_rngs, lm_b, ) else: state, ema_params, ml_grads, loss, grad_norm = p_train_step_topk( - state, ema_params, mask_logits, src_b, tgt_in_b, tgt_out_b, causal_mask, tau_arr, hard_arr, step_rngs, lm_b, + state, ema_params, mask_logits, src_b, tgt_in_b, tgt_out_b, causal_mask, tau_arr, hard_arr, step_rngs, text_rngs, lm_b, ) if not cur_hard: mask_logits, mask_opt_state = _update_mask_logits(ml_grads, mask_logits, mask_tx, mask_opt_state) elif use_topk and not topk_active: state, ema_params, loss, grad_norm, warmup_grads = p_train_step_warmup( - state, ema_params, src_b, tgt_in_b, tgt_out_b, causal_mask, lm_b, + state, ema_params, src_b, tgt_in_b, tgt_out_b, causal_mask, text_rngs, lm_b, ) if use_saliency and saliency_accum is not None: grads_unr = jax_utils.unreplicate(warmup_grads) @@ -1033,37 +1040,38 @@ def train(args): if topk_active: tau_arr = jax_utils.replicate(jnp.float32(cur_tau)) hard_arr = jax_utils.replicate(jnp.bool_(cur_hard)) - rng, step_rng, spec_rng = jax.random.split(rng, 3) + rng, step_rng, spec_rng, drop_rng = jax.random.split(rng, 4) step_rngs = jax.random.split(step_rng, num_devices) spec_rngs = jax.random.split(spec_rng, num_devices) + drop_rngs = jax.random.split(drop_rng, num_devices) if prune_mask is not None: state, ema_params, ml_grads, sp_loss, sp_grad_norm = p_train_step_speech_topk_masked( - state, ema_params, mask_logits, mel_b, sp_tgt_in_b, sp_tgt_out_b, causal_mask, prune_mask, tau_arr, hard_arr, step_rngs, spec_rngs, sp_lm_b, + state, ema_params, mask_logits, mel_b, sp_tgt_in_b, sp_tgt_out_b, causal_mask, prune_mask, tau_arr, hard_arr, step_rngs, spec_rngs, drop_rngs, sp_lm_b, ) else: state, ema_params, ml_grads, sp_loss, sp_grad_norm = p_train_step_speech_topk( - state, ema_params, mask_logits, mel_b, sp_tgt_in_b, sp_tgt_out_b, causal_mask, tau_arr, hard_arr, step_rngs, spec_rngs, sp_lm_b, + state, ema_params, mask_logits, mel_b, sp_tgt_in_b, sp_tgt_out_b, causal_mask, tau_arr, hard_arr, step_rngs, spec_rngs, drop_rngs, sp_lm_b, ) if not cur_hard: mask_logits, mask_opt_state = _update_mask_logits(ml_grads, mask_logits, mask_tx, mask_opt_state) elif use_topk and not topk_active: - # Speech warmup grads not used for saliency, text is the primary task - rng, spec_rng = jax.random.split(rng) + rng, spec_rng, drop_rng = jax.random.split(rng, 3) spec_rngs = jax.random.split(spec_rng, num_devices) + drop_rngs = jax.random.split(drop_rng, num_devices) state, ema_params, sp_loss, sp_grad_norm, _ = p_train_step_warmup_speech( - state, ema_params, mel_b, sp_tgt_in_b, sp_tgt_out_b, causal_mask, spec_rngs, sp_lm_b, + state, ema_params, mel_b, sp_tgt_in_b, sp_tgt_out_b, causal_mask, spec_rngs, drop_rngs, sp_lm_b, ) else: speech_ffn_mask = text_ffn_mask - rng, spec_rng = jax.random.split(rng) - spec_rngs = jax.random.split(spec_rng, num_devices) + rng, speech_rng = jax.random.split(rng) + speech_rngs = jax.random.split(speech_rng, num_devices) if prune_mask is not None: state, ema_params, sp_loss, sp_grad_norm = p_train_step_speech_masked( - state, ema_params, mel_b, sp_tgt_in_b, sp_tgt_out_b, causal_mask, prune_mask, speech_ffn_mask, spec_rngs, sp_lm_b, + state, ema_params, mel_b, sp_tgt_in_b, sp_tgt_out_b, causal_mask, prune_mask, speech_ffn_mask, speech_rngs, sp_lm_b, ) else: state, ema_params, sp_loss, sp_grad_norm = p_train_step_speech( - state, ema_params, mel_b, sp_tgt_in_b, sp_tgt_out_b, causal_mask, speech_ffn_mask, spec_rngs, sp_lm_b, + state, ema_params, mel_b, sp_tgt_in_b, sp_tgt_out_b, causal_mask, speech_ffn_mask, speech_rngs, sp_lm_b, ) speech_loss_val = float(sp_loss[0]) speech_losses.append(speech_loss_val)