Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
306 changes: 306 additions & 0 deletions docs/learned_matryoshka_plan.md

Large diffs are not rendered by default.

18 changes: 17 additions & 1 deletion src/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,8 +47,24 @@ def main():
help="Matryoshka FFN shrink factors, e.g. 2=half width (default: 2 4 8)")
p.add_argument("--mat-shared-input", action="store_true",
help="Each unique input is repeated across all mat widths (default: unique input per width)")
p.add_argument("--mat-method", choices=["static-prefix", "topk"], default="topk",
help="Matryoshka method: 'static-prefix' (fixed first-N masks), 'topk' (saliency-based masks, default)")
p.add_argument("--mat-tau-start", type=float, default=0.5)
p.add_argument("--mat-tau-end", type=float, default=0.1)
p.add_argument("--mat-init-mode", choices=["prefix", "shuffled_prefix", "saliency", "normal", "zeros"], default="saliency")
p.add_argument("--mat-init-value", type=float, default=0.5)
p.add_argument("--mat-spread-lambda", type=float, default=0.0)
p.add_argument("--mat-warmup-frac", type=float, default=0.4,
help="Fraction of total steps for vanilla warmup (no masks)")
p.add_argument("--mat-freeze-frac", type=float, default=1.0,
help="Fraction of total steps at end with frozen hard masks")
p.add_argument("--mat-mask-lr", type=float, default=3e-3,
help="Mask logit optimizer learning rate (default: 3e-3)")
p.add_argument("--mat-saliency-scale", type=float, default=1.0)
p.add_argument("--mat-gumbel", action="store_true",
help="Use Gumbel noise for per-item mask diversity during topk learning")
p.add_argument("--dropout", type=float, default=0.0,
help="Dropout rate for residual connections (default: 0.1)")
help="Dropout rate for residual connections (default: 0.0)")
p.add_argument("--no-speech", action="store_true", help="Disable speech training (text-only)")
p.add_argument("--max-mel-len", type=int, default=1024,
help="Max mel spectrogram frames (default: 1024)")
Expand Down
54 changes: 33 additions & 21 deletions src/evaluate.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,28 +10,27 @@
from .data import get_tokenizer
from .model import (
EncoderDecoderTransformer,
TransformerConfig,
make_causal_mask,
make_padding_mask,
)
from .run import load_checkpoint


def _make_p_encode(model):
def _make_p_encode(model, enc_ffn=None):
"""Create a pmap'd encode function."""
def _encode(params, src, src_mask):
return model.apply(
{"params": params}, src, src_mask=src_mask, method="encode",
{"params": params}, src, src_mask=src_mask, ffn_mask=enc_ffn, method="encode",
)
return jax.pmap(_encode, axis_name="batch")


def _make_p_decode(model):
def _make_p_decode(model, dec_ffn=None):
"""Create a pmap'd decode function."""
def _decode(params, dec_input, encoder_out, tgt_mask, _unused_cross_mask):
return model.apply(
{"params": params}, dec_input, encoder_out,
self_mask=tgt_mask, method="decode",
self_mask=tgt_mask, ffn_mask=dec_ffn, method="decode",
)
return jax.pmap(_decode, axis_name="batch")

Expand All @@ -42,20 +41,23 @@ def _shard_single(x, num_devices):


def score_sequence(model, params, enc_tokens, dec_tokens, pad_id, sos_id=None,
p_encode=None, p_decode=None, num_devices=1):
p_encode=None, p_decode=None, num_devices=1, ffn_mask=None):
"""Compute average negative log-likelihood of dec_tokens given enc_tokens."""
sos = sos_id if sos_id is not None else pad_id
enc_input = jnp.array([enc_tokens])
src_mask = make_padding_mask(enc_input, pad_id)

enc_ffn = ffn_mask["encoder"] if ffn_mask else None
dec_ffn = ffn_mask["decoder"] if ffn_mask else None

if p_encode is not None and num_devices > 1:
enc_s = _shard_single(enc_input, num_devices)
src_mask_s = _shard_single(src_mask, num_devices)
encoder_out = p_encode(params, enc_s, src_mask_s)[0:1]
else:
p = params if num_devices <= 1 else jax_utils.unreplicate(params)
encoder_out = model.apply(
{"params": p}, enc_input, src_mask=src_mask, method="encode",
{"params": p}, enc_input, src_mask=src_mask, ffn_mask=enc_ffn, method="encode",
)

dec_in = [sos] + list(dec_tokens[:-1])
Expand All @@ -71,7 +73,7 @@ def score_sequence(model, params, enc_tokens, dec_tokens, pad_id, sos_id=None,
p = params if num_devices <= 1 else jax_utils.unreplicate(params)
logits = model.apply(
{"params": p}, dec_input, encoder_out,
self_mask=tgt_mask, method="decode",
self_mask=tgt_mask, ffn_mask=dec_ffn, method="decode",
)[0]

log_probs = jax.nn.log_softmax(logits if logits.ndim == 2 else logits[0])
Expand All @@ -81,7 +83,7 @@ def score_sequence(model, params, enc_tokens, dec_tokens, pad_id, sos_id=None,


def eval_wikitext2(model, params, tokenizer, max_samples=500, max_len=256,
num_devices=1, p_encode=None, p_decode=None):
num_devices=1, p_encode=None, p_decode=None, ffn_mask=None):
"""Perplexity on WikiText-2 test split."""
ds = load_dataset("Salesforce/wikitext", "wikitext-2-raw-v1", split="test")

Expand All @@ -91,6 +93,8 @@ def eval_wikitext2(model, params, tokenizer, max_samples=500, max_len=256,
total_tokens = 0
evaluated = 0

enc_ffn = ffn_mask["encoder"] if ffn_mask else None
dec_ffn = ffn_mask["decoder"] if ffn_mask else None
single_params = jax_utils.unreplicate(params) if num_devices > 1 else params

for example in ds:
Expand All @@ -116,7 +120,7 @@ def eval_wikitext2(model, params, tokenizer, max_samples=500, max_len=256,
encoder_out = p_encode(params, enc_s, src_mask_s)[0:1]
else:
encoder_out = model.apply(
{"params": single_params}, enc_input, src_mask=src_mask, method="encode",
{"params": single_params}, enc_input, src_mask=src_mask, ffn_mask=enc_ffn, method="encode",
)

dec_in = [sos_id] + list(dec_tokens[:-1])
Expand All @@ -131,7 +135,7 @@ def eval_wikitext2(model, params, tokenizer, max_samples=500, max_len=256,
else:
logits = model.apply(
{"params": single_params}, dec_input, encoder_out,
self_mask=tgt_mask, method="decode",
self_mask=tgt_mask, ffn_mask=dec_ffn, method="decode",
)

log_probs = jax.nn.log_softmax(logits[0] if logits.ndim == 3 else logits[0])
Expand All @@ -150,7 +154,7 @@ def eval_wikitext2(model, params, tokenizer, max_samples=500, max_len=256,


def eval_lambada(model, params, tokenizer, max_samples=500,
num_devices=1, p_encode=None, p_decode=None):
num_devices=1, p_encode=None, p_decode=None, ffn_mask=None):
"""Accuracy of predicting the final word on LAMBADA."""
ds = load_dataset("EleutherAI/lambada_openai", "default", split="test")

Expand All @@ -159,6 +163,8 @@ def eval_lambada(model, params, tokenizer, max_samples=500,
correct = 0
total = 0

enc_ffn = ffn_mask["encoder"] if ffn_mask else None
dec_ffn = ffn_mask["decoder"] if ffn_mask else None
single_params = jax_utils.unreplicate(params) if num_devices > 1 else params

for example in ds:
Expand All @@ -183,7 +189,7 @@ def eval_lambada(model, params, tokenizer, max_samples=500,
encoder_out = p_encode(params, enc_s, src_mask_s)[0:1]
else:
encoder_out = model.apply(
{"params": single_params}, enc_input, src_mask=src_mask, method="encode",
{"params": single_params}, enc_input, src_mask=src_mask, ffn_mask=enc_ffn, method="encode",
)

dec_in = jnp.array([[sos_id]])
Expand All @@ -197,7 +203,7 @@ def eval_lambada(model, params, tokenizer, max_samples=500,
else:
logits = model.apply(
{"params": single_params}, dec_in, encoder_out,
self_mask=tgt_mask, method="decode",
self_mask=tgt_mask, ffn_mask=dec_ffn, method="decode",
)

predicted = int(jnp.argmax(logits[0, 0] if logits.ndim == 3 else logits[0, 0]))
Expand All @@ -213,7 +219,7 @@ def eval_lambada(model, params, tokenizer, max_samples=500,


def eval_hellaswag(model, params, tokenizer, max_samples=500,
num_devices=1, p_encode=None, p_decode=None):
num_devices=1, p_encode=None, p_decode=None, ffn_mask=None):
"""Accuracy on HellaSwag by scoring each candidate ending."""
ds = load_dataset("Rowan/hellaswag", split="validation")

Expand Down Expand Up @@ -242,7 +248,7 @@ def eval_hellaswag(model, params, tokenizer, max_samples=500,
dec_tokens = dec_tokens[:64]
score = score_sequence(model, params, enc_tokens, dec_tokens, pad_id,
sos_id=sos_id, p_encode=p_encode, p_decode=p_decode,
num_devices=num_devices)
num_devices=num_devices, ffn_mask=ffn_mask)
scores.append(score)

predicted = int(np.argmax(scores))
Expand All @@ -258,7 +264,7 @@ def eval_hellaswag(model, params, tokenizer, max_samples=500,


def eval_arc_easy(model, params, tokenizer, max_samples=500,
num_devices=1, p_encode=None, p_decode=None):
num_devices=1, p_encode=None, p_decode=None, ffn_mask=None):
"""Accuracy on ARC-Easy by scoring each answer choice."""
ds = load_dataset("allenai/ai2_arc", "ARC-Easy", split="test")

Expand Down Expand Up @@ -288,7 +294,7 @@ def eval_arc_easy(model, params, tokenizer, max_samples=500,
dec_tokens = dec_tokens[:64]
score = score_sequence(model, params, enc_tokens, dec_tokens, pad_id,
sos_id=sos_id, p_encode=p_encode, p_decode=p_decode,
num_devices=num_devices)
num_devices=num_devices, ffn_mask=ffn_mask)
scores.append(score)

predicted_idx = int(np.argmax(scores))
Expand Down Expand Up @@ -317,18 +323,23 @@ def main(args):
print(f"Detected {num_devices} device(s) for data-parallel evaluation")

print(f"Loading checkpoint: {args.checkpoint}")
params, config = load_checkpoint(args.checkpoint)
params, config, ffn_mask = load_checkpoint(args.checkpoint)
model = EncoderDecoderTransformer(config)
tokenizer = get_tokenizer()

param_count = sum(x.size for x in jax.tree.leaves(params))
print(f"Model parameters: {param_count:,}")

# Replicate params across devices for pmap
enc_ffn = ffn_mask["encoder"] if ffn_mask else None
dec_ffn = ffn_mask["decoder"] if ffn_mask else None
if ffn_mask:
print(f"sub-model: topk FFN masking active")

if num_devices > 1:
params = jax_utils.replicate(params)
p_encode = _make_p_encode(model)
p_decode = _make_p_decode(model)
p_encode = _make_p_encode(model, enc_ffn=enc_ffn)
p_decode = _make_p_decode(model, dec_ffn=dec_ffn)
print(f"Params replicated across {num_devices} devices")
else:
p_encode = None
Expand All @@ -349,6 +360,7 @@ def main(args):
result = BENCHMARKS[name](
model, params, tokenizer, max_samples=args.max_samples,
num_devices=num_devices, p_encode=p_encode, p_decode=p_decode,
ffn_mask=ffn_mask,
)
results[name] = result

Expand Down
79 changes: 69 additions & 10 deletions src/export.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@

With FFN interior matryoshka, d_model stays constant — only FFN intermediate
dimensions (gate_proj, up_proj, down_proj) are sliced.

For topk-trained models, exports full model + mask indices per factor.
"""

import os
Expand All @@ -18,18 +20,79 @@
_FFN_KERNEL_NAMES = {"gate_proj", "up_proj", "down_proj"}


def _to_numpy(tree):
    """Return a new pytree in which every JAX array leaf is a NumPy array.

    Leaves that are not JAX arrays are passed through unchanged.
    """
    def _leaf_to_numpy(leaf):
        if isinstance(leaf, jnp.ndarray):
            return np.asarray(leaf)
        return leaf

    return jax.tree.map(_leaf_to_numpy, tree)


def _param_stats(tree):
    """Return ``(param_count, total_bytes)`` for a pytree of arrays.

    Args:
        tree: Any pytree accepted by ``jax.tree.leaves``.

    Returns:
        A 2-tuple: the summed ``size`` of every leaf and the summed
        ``nbytes`` of every leaf. An empty tree yields ``(0, 0)``.
    """
    param_count = 0
    total_bytes = 0
    for leaf in jax.tree.leaves(tree):
        # Python scalars (and other non-array leaves) expose neither
        # ``size`` nor ``nbytes``; coerce them instead of raising
        # AttributeError. Array leaves are used as-is, so nothing is copied.
        if not hasattr(leaf, "nbytes"):
            leaf = np.asarray(leaf)
        param_count += int(leaf.size)
        total_bytes += int(leaf.nbytes)
    return param_count, total_bytes


def export_submodel(checkpoint_path, factor, output_path):
"""Slice a full matryoshka checkpoint to a sub-model at given shrink factor.
"""Export a matryoshka sub-model from a full checkpoint.

factor: how many times smaller the FFN width (e.g. 2 = half, 4 = quarter).
Attention, embeddings, and norms are unchanged.
For prefix-trained models: slices FFN weights to create a smaller d_ff.
For topk-trained models: saves full model + binary mask indices per factor.
"""

with open(checkpoint_path, "rb") as f:
data = pickle.load(f)
params = data["params"]
config = TransformerConfig(**data["config"])

mat_method = data.get("mat_method", "static-prefix")
if mat_method == "topk" and "mask_logits" in data:
return _export_topk(data, params, config, factor, output_path)
else:
return _export_prefix(params, config, factor, output_path)


def _export_topk(data, params, config, factor, output_path):
    """TopK export: full model + per-layer binary mask indices for FFN masking.

    The weights are written unsliced. For the requested ``factor`` the
    ``d_ff // factor`` highest-logit FFN units of every block are selected and
    their (ascending) indices are stored next to the parameters.

    Args:
        data: Unpickled checkpoint dict; reads ``"mask_logits"`` and
            ``"mat_factors"``.
        params: Parameter pytree taken from the checkpoint.
        config: Model config; reads ``d_ff`` and serialises ``__dict__``.
        factor: FFN shrink factor to export; must appear in ``mat_factors``.
        output_path: Destination pickle file; parent directories are created.

    Raises:
        ValueError: If ``factor`` is not one of the checkpoint's
            ``mat_factors``, or is so large that ``d_ff // factor`` is 0.
    """
    mask_logits = np.asarray(data["mask_logits"])  # expected (n_mat, n_blocks, d_ff)

    # Normalise to a list so that tuples / arrays stored in the checkpoint
    # also support the ``.index()`` lookup below.
    raw_factors = data.get("mat_factors")
    mat_factors = [] if raw_factors is None else list(raw_factors)
    if factor not in mat_factors:
        raise ValueError(f"factor={factor} not found in mat_factors={mat_factors}")

    ff_w = config.d_ff // factor
    # Same guard as the prefix export: an empty mask is never a valid sub-model.
    if ff_w == 0:
        raise ValueError(f"factor={factor} too large: would give d_ff=0")

    factor_logits = mask_logits[mat_factors.index(factor)]  # (n_blocks, d_ff)
    # Stable sort: tied logits (e.g. constant initialisation) always resolve
    # to the lowest unit index, so repeated exports give identical masks.
    per_layer_indices = [
        np.sort(np.argsort(-layer_logits, kind="stable")[:ff_w])
        for layer_logits in factor_logits
    ]

    params_np = _to_numpy(params)

    os.makedirs(os.path.dirname(output_path) or ".", exist_ok=True)
    with open(output_path, "wb") as f:
        pickle.dump({
            "params": params_np,
            "config": config.__dict__,
            "mat_mask_indices": per_layer_indices,
            "mat_factor": factor,
            "mat_ff_width": ff_w,
        }, f)

    n_blocks = len(per_layer_indices)
    # Count on the NumPy copy that was just written to disk.
    orig_count, orig_bytes = _param_stats(params_np)

    print(f"\n TopK export: {output_path}")
    print(" ─────────────────────────────────────")
    print(f" d_ff (full) {config.d_ff:>12d}")
    print(f" d_ff (masked) {ff_w:>12d}")
    print(f" factor {str(factor)+'x':>12s}")
    print(f" blocks {n_blocks:>12d} (per-layer masks)")
    print(f" neurons/layer {ff_w:>12d}")
    print(f" params (full) {orig_count:>12,d}")
    print(f" size (MB) {orig_bytes / 1e6:>12.1f}")
    print(" Note: full weights kept; per-layer mask applied at FFN level")
    print()


def _export_prefix(params, config, factor, output_path):
"""Prefix export: slice FFN weights to a smaller d_ff."""
d_ff_new = config.d_ff // factor
if d_ff_new == 0:
raise ValueError(f"factor={factor} too large: would give d_ff=0")
Expand Down Expand Up @@ -65,18 +128,14 @@ def slice_leaf(key_path, leaf):

new_config = replace(config, d_ff=d_ff_new)

sliced_np = jax.tree.map(
lambda x: np.asarray(x) if isinstance(x, jnp.ndarray) else x, sliced
)
sliced_np = _to_numpy(sliced)

os.makedirs(os.path.dirname(output_path) or ".", exist_ok=True)
with open(output_path, "wb") as f:
pickle.dump({"params": sliced_np, "config": new_config.__dict__}, f)

orig_count = sum(x.size for x in jax.tree.leaves(params))
new_count = sum(x.size for x in jax.tree.leaves(sliced_np))
orig_bytes = sum(x.nbytes for x in jax.tree.leaves(params))
new_bytes = sum(x.nbytes for x in jax.tree.leaves(sliced_np))
orig_count, orig_bytes = _param_stats(params)
new_count, new_bytes = _param_stats(sliced_np)

print(f"\n Export complete: {output_path}")
print(f" ─────────────────────────────────────")
Expand Down
Loading