Add model merge #721
Changes from all commits
@@ -0,0 +1,20 @@
#!/usr/bin/env bash
set -euo pipefail

# Demo commands for model_merge.py. Replace paths with real checkpoint dirs.
CKPT_A="out/run_a"
CKPT_B="out/run_b"

echo "==> L2-normalized merge (default)"
python3 model_merge.py "${CKPT_A}" "${CKPT_B}" --out_dir out/merge_l2

echo "==> L2-normalized merge but skip final norm for wte/lm_head"
python3 model_merge.py "${CKPT_A}" "${CKPT_B}" \
    --out_dir out/merge_skip_final_norm \
    --skip_final_norm_wte_lm_head

echo "==> Simple averaging without any L2 normalization"
python3 model_merge.py "${CKPT_A}" "${CKPT_B}" \
    --out_dir out/merge_simple_avg \
    --no_l2_normalize \
    --simple_divisor 2.0
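
After running the demo, the merged checkpoint can be sanity-checked by loading it back. This is a minimal sketch, not part of the PR; it assumes the out/merge_l2 directory produced by the first command above.

import torch

# Load the merged checkpoint written by model_merge.py (path taken from the demo above).
ckpt = torch.load("out/merge_l2/ckpt.pt", map_location="cpu")

# The merge script drops optimizer/scheduler state and resets the training counters.
print(ckpt.get("iter_num"))        # expected: 0
print(ckpt.get("best_val_loss"))   # expected: 1e9

# The merged weights live under the "model" key (or the checkpoint is a raw state dict).
state_dict = ckpt.get("model", ckpt)
print(len(state_dict), "merged tensors")
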
@@ -0,0 +1,160 @@
import argparse
import datetime
import os
import shutil
from typing import Dict, Any

import torch


L2_NORM_DIM = -1
EPS = 1e-12

def parse_args() -> argparse.Namespace:
    parser = argparse.ArgumentParser(
        description=(
            "Merge two nanoGPT checkpoints by L2-normalizing vectors, "
            "adding them, and L2-normalizing again."
        )
    )
    parser.add_argument(
        "ckpt_dir_a",
        type=str,
        help="Directory containing ckpt.pt from the first training run",
    )
    parser.add_argument(
        "ckpt_dir_b",
        type=str,
        help="Directory containing ckpt.pt from the second training run",
    )
    parser.add_argument(
        "--out_dir",
        type=str,
        default=None,
        help="Directory to write the merged checkpoint (defaults to <ckpt_dir_a>_merge)",
    )
    parser.add_argument(
        "--skip_final_norm_wte_lm_head",
        action="store_true",
        help="Skip the final L2 normalization for wte/lm_head weights",
    )
    parser.add_argument(
        "--no_l2_normalize",
        action="store_true",
        help=(
            "Disable all L2 normalizations and instead add and divide by --simple_divisor."
        ),
    )
    parser.add_argument(
        "--simple_divisor",
        type=float,
        default=2.0,
        help=(
            "Divisor for simple merging (used only when --no_l2_normalize is set)."
        ),
    )
    return parser.parse_args()

def l2_normalize(tensor: torch.Tensor, dim: int = L2_NORM_DIM) -> torch.Tensor:
    if tensor.ndim == 0:
        return tensor
    if tensor.ndim == 1:
        dim = 0
    norm = tensor.norm(dim=dim, keepdim=True).clamp_min(EPS)
    return tensor / norm

def is_wte_or_lm_head(key: str) -> bool:
    parts = key.split(".")
    for part in parts:
        if part == "wte" or part.startswith("wte_"):
            return True
        if part == "lm_head" or part.startswith("lm_head_"):
            return True
    return False

def load_checkpoint(ckpt_dir: str) -> Dict[str, Any]:
    ckpt_path = os.path.join(ckpt_dir, "ckpt.pt")
    if not os.path.exists(ckpt_path):
        raise FileNotFoundError(f"Checkpoint not found: {ckpt_path}")
    return torch.load(ckpt_path, map_location="cpu", weights_only=True)

def main() -> None:
    args = parse_args()

    checkpoint_a = load_checkpoint(args.ckpt_dir_a)
    checkpoint_b = load_checkpoint(args.ckpt_dir_b)

    # Checkpoints may be full training checkpoints (with a "model" key) or raw state dicts.
    state_dict_a = checkpoint_a.get("model", checkpoint_a)
    state_dict_b = checkpoint_b.get("model", checkpoint_b)

    if state_dict_a.keys() != state_dict_b.keys():
        missing_a = sorted(set(state_dict_b.keys()) - set(state_dict_a.keys()))
        missing_b = sorted(set(state_dict_a.keys()) - set(state_dict_b.keys()))
        raise ValueError(
            "Checkpoint parameter keys do not match. "
            f"Missing in A: {missing_a[:5]}{'...' if len(missing_a) > 5 else ''}. "
            f"Missing in B: {missing_b[:5]}{'...' if len(missing_b) > 5 else ''}."
        )

    merged_state_dict = {}
    for key, tensor_a in state_dict_a.items():
        tensor_b = state_dict_b[key]

        # Non-floating tensors (e.g. integer buffers) must match exactly and are copied through.
        if not torch.is_floating_point(tensor_a):
            if tensor_a.shape != tensor_b.shape or not torch.equal(tensor_a, tensor_b):
                raise ValueError(f"Non-floating tensor mismatch for key {key}")
            merged_state_dict[key] = tensor_a
            continue

        if tensor_a.shape != tensor_b.shape:
            raise ValueError(
                f"Shape mismatch for {key}: {tensor_a.shape} vs {tensor_b.shape}"
            )

        # Simple averaging path: add the tensors and divide, with no normalization.
        if args.no_l2_normalize:
            merged = (tensor_a + tensor_b) / args.simple_divisor
            merged_state_dict[key] = merged
            continue

        # Default path: L2-normalize each tensor, add, then L2-normalize the sum.
        norm_a = l2_normalize(tensor_a)
        norm_b = l2_normalize(tensor_b)
        merged = norm_a + norm_b

        if args.skip_final_norm_wte_lm_head and is_wte_or_lm_head(key):
            merged_state_dict[key] = merged
        else:
            merged_state_dict[key] = l2_normalize(merged)

    if isinstance(checkpoint_a, dict) and "model" in checkpoint_a:
        checkpoint_a["model"] = merged_state_dict
    else:
        checkpoint_a = merged_state_dict

    # Drop training state that no longer applies to the merged weights.
    checkpoint_a.pop("optimizer", None)
    checkpoint_a.pop("scheduler", None)

    if isinstance(checkpoint_a, dict):
        checkpoint_a["iter_num"] = 0
        checkpoint_a["best_val_loss"] = 1e9
        checkpoint_a["best_iter"] = 0
        checkpoint_a["best_tokens"] = 0

    out_dir = args.out_dir or f"{args.ckpt_dir_a.rstrip('/').rstrip(os.sep)}_merge"
    os.makedirs(out_dir, exist_ok=True)
    torch.save(checkpoint_a, os.path.join(out_dir, "ckpt.pt"))

    meta_path = os.path.join(args.ckpt_dir_a, "meta.pkl")
    if os.path.exists(meta_path):
        shutil.copy2(meta_path, os.path.join(out_dir, "meta.pkl"))

    print(
        "✔ Merged checkpoint written to "
        f"{out_dir} at {datetime.datetime.now().isoformat(timespec='seconds')}"
    )


if __name__ == "__main__":
    main()
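
To make the default merge rule concrete, the following tiny example walks the same steps on a 2-element vector. It is illustrative only and not part of the PR.

import torch

a = torch.tensor([3.0, 4.0])        # ||a|| = 5
b = torch.tensor([0.0, 2.0])        # ||b|| = 2

a_hat = a / a.norm()                # [0.6, 0.8]
b_hat = b / b.norm()                # [0.0, 1.0]
merged = a_hat + b_hat              # [0.6, 1.8]
merged = merged / merged.norm()     # roughly [0.316, 0.949]
print(merged)

Both inputs contribute equally regardless of their original magnitudes, which is the point of normalizing before adding.
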
Review comment: The l2_normalize function has special handling for scalar (0-dim) and vector (1-dim) tensors but lacks documentation explaining this behavior. Add a docstring documenting that: (1) scalar tensors are returned unchanged, (2) 1-dim tensors normalize along dim=0, and (3) higher-dim tensors use the provided dim parameter (default=-1).
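
One way to address this comment, sketched from the existing behavior (the docstring wording is illustrative, not the author's):

def l2_normalize(tensor: torch.Tensor, dim: int = L2_NORM_DIM) -> torch.Tensor:
    """L2-normalize a tensor along one dimension.

    - 0-dim (scalar) tensors are returned unchanged.
    - 1-dim tensors are normalized along dim=0, ignoring the dim argument.
    - Higher-dim tensors are normalized along the provided dim (default -1).

    The norm is clamped to EPS to avoid division by zero.
    """
    if tensor.ndim == 0:
        return tensor
    if tensor.ndim == 1:
        dim = 0
    norm = tensor.norm(dim=dim, keepdim=True).clamp_min(EPS)
    return tensor / norm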