diff --git a/.gitmodules b/.gitmodules
new file mode 100644
index 00000000..1645593c
--- /dev/null
+++ b/.gitmodules
@@ -0,0 +1,3 @@
+[submodule "optimizers/flash-muon"]
+	path = optimizers/flash-muon
+	url = https://github.com/nil0x9/flash-muon.git
diff --git a/configs/moe_config.py b/configs/moe_config.py
index 7d316fa3..e0a1d864 100644
--- a/configs/moe_config.py
+++ b/configs/moe_config.py
@@ -33,6 +33,7 @@ class MoEModelConfig:
     grad_clip: float = 1.0
 
     # Technical
+    use_flash_muon: bool = True
     use_amp: bool = True
     vocab_size: Optional[int] = None
     log_milestones: Tuple[int, ...] = (2000, 5000, 10000)
diff --git a/optimizers/__init__.py b/optimizers/__init__.py
index 5dcfceef..8757135c 100644
--- a/optimizers/__init__.py
+++ b/optimizers/__init__.py
@@ -1,3 +1,9 @@
+import sys
+from pathlib import Path
+
+sys.path.insert(0, str(Path(__file__).parent / 'flash-muon'))
+
+from flash_muon import Muon as FlashMuon
 from .muon import Muon, zeropower_via_newtonschulz5
 
-__all__ = ['Muon', 'zeropower_via_newtonschulz5']
+__all__ = ["FlashMuon", "Muon", "zeropower_via_newtonschulz5"]
diff --git a/optimizers/flash-muon b/optimizers/flash-muon
new file mode 160000
index 00000000..80ac87fb
--- /dev/null
+++ b/optimizers/flash-muon
@@ -0,0 +1 @@
+Subproject commit 80ac87fb49afc792b84eccb393d051b1ed8eee32
diff --git a/training/evaluation.py b/training/evaluation.py
index afe197e1..cf4665b5 100644
--- a/training/evaluation.py
+++ b/training/evaluation.py
@@ -5,7 +5,7 @@
 from torch.utils.data import DataLoader
 from torch.amp import autocast
 from configs.moe_config import MoEModelConfig
-
+from utils.helpers import unpack_batch
 
 def evaluate_model(model: nn.Module, val_loader: DataLoader, config: MoEModelConfig):
     """Evaluate model performance"""
@@ -17,10 +17,10 @@ def evaluate_model(model: nn.Module, val_loader: DataLoader, config: MoEModelCon
     device = next(model.parameters()).device
 
     with torch.no_grad():
-        for i, (x, y) in enumerate(val_loader):
+        for i, batch in enumerate(val_loader):
             if i >= config.eval_steps:
                 break
-            x, y = x.to(device), y.to(device)
+            x, y = unpack_batch(batch, device)
 
             with autocast('cuda', dtype=torch.float16, enabled=config.use_amp):
                 # MoE model evaluation
diff --git a/training/trainer.py b/training/trainer.py
index 81cadff2..22a5e457 100644
--- a/training/trainer.py
+++ b/training/trainer.py
@@ -6,32 +6,32 @@
 import matplotlib.pyplot as plt
 from torch.utils.data import DataLoader
 from torch.amp import autocast, GradScaler
+from functools import partial
 from tqdm import tqdm
 
 from configs.moe_config import MoEModelConfig
 from models.moe_llm import MoEMinimalLLM
-from optimizers.muon import Muon
+from optimizers import Muon, FlashMuon
 from training.evaluation import evaluate_model
-from utils.helpers import set_seed
+from utils.helpers import set_seed, unpack_batch
+def _is_muon_param(name: str, p: nn.Parameter) -> bool:
+    """Muon rule: 2-D weight matrix, no embeddings, no norm layers."""
+    return (
+        p.ndim == 2
+        and p.requires_grad
+        and "token_embedding" not in name
+        and "norm" not in name
+    )
 
 def setup_muon_optimizer(model: nn.Module, config: MoEModelConfig):
     """Setup Muon optimizer with hybrid approach"""
-    muon_params = []
-    adamw_params = []
-
-    for name, param in model.named_parameters():
-        if (param.ndim == 2 and
-            'token_embedding' not in name and
-            'norm' not in name and
-            param.requires_grad):
-            muon_params.append(param)
-        else:
-            adamw_params.append(param)
-
+    muon_params, adamw_params = [], []
+    for name, p in model.named_parameters():
+        (muon_params if _is_muon_param(name, p) else adamw_params).append(p)
     print(f" Muon parameters: {sum(p.numel() for p in muon_params):,}")
     print(f" AdamW parameters: {sum(p.numel() for p in adamw_params):,}")
-
-    muon_optimizer = Muon(muon_params, lr=config.muon_lr, momentum=0.95)
+    MuonCls = partial(FlashMuon, world_size=getattr(config, 'world_size', 1), rank=getattr(config, 'rank', 0)) if config.use_flash_muon else Muon
+    muon_optimizer = MuonCls(muon_params, lr=config.muon_lr, momentum=0.95)
     adamw_optimizer = torch.optim.AdamW(adamw_params, lr=config.muon_lr*0.1, weight_decay=config.weight_decay)
 
     return [muon_optimizer, adamw_optimizer]
@@ -88,11 +88,11 @@ def lr_lambda(step):
     eval_times = []
 
     while step < config.max_steps:
-        for batch_idx, (x, y) in enumerate(train_loader):
+        for batch_idx, batch in enumerate(train_loader):
             if step >= config.max_steps:
                 break
 
-            x, y = x.to(device), y.to(device)
+            x, y = unpack_batch(batch, device)
 
             # Forward pass
             if config.use_amp:
diff --git a/utils/helpers.py b/utils/helpers.py
index 03e4cf2d..c1cf5f9a 100644
--- a/utils/helpers.py
+++ b/utils/helpers.py
@@ -17,3 +17,12 @@ def set_seed(seed: int = 42):
 def count_parameters(model):
     """Count the number of parameters in a model"""
     return sum(p.numel() for p in model.parameters() if p.requires_grad)
+
+def unpack_batch(batch, device):
+    """Always return (input_ids, labels) no matter what the loader gives us."""
+    if isinstance(batch, dict):
+        return batch["input_ids"].to(device), batch["labels"].to(device)
+    elif isinstance(batch, (list, tuple)):
+        return batch[0].to(device), batch[1].to(device)
+    else:
+        raise TypeError(f"Unknown batch format: {type(batch).__name__}")
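
For context, here is a minimal sketch (not part of the patch) of the two batch formats the new unpack_batch helper accepts. The tensor shapes and the CPU device are illustrative assumptions only.

# Illustrative only -- exercises utils.helpers.unpack_batch from this patch.
import torch

from utils.helpers import unpack_batch

device = torch.device("cpu")
ids = torch.zeros(2, 8, dtype=torch.long)

# Dict-style batch (e.g. from a collator that yields input_ids/labels).
x, y = unpack_batch({"input_ids": ids, "labels": ids}, device)

# Tuple-style batch, matching the original (x, y) DataLoader output.
x2, y2 = unpack_batch((ids, ids), device)

assert x.shape == x2.shape == (2, 8)

On the optimizer side, when use_flash_muon is enabled, setup_muon_optimizer constructs the optimizer via partial(FlashMuon, world_size=..., rank=...), with getattr defaults of 1 and 0 so single-process runs work even though MoEModelConfig defines neither field.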