Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions .gitmodules
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
[submodule "optimizers/flash-muon"]
path = optimizers/flash-muon
url = https://github.com/nil0x9/flash-muon.git
1 change: 1 addition & 0 deletions configs/moe_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ class MoEModelConfig:
grad_clip: float = 1.0

# Technical
use_flash_muon: bool = True
use_amp: bool = True
vocab_size: Optional[int] = None
log_milestones: Tuple[int, ...] = (2000, 5000, 10000)
Expand Down
8 changes: 7 additions & 1 deletion optimizers/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,9 @@
import sys
from pathlib import Path

sys.path.insert(0, str(Path(__file__).parent / 'flash-muon'))

from flash_muon import Muon as FlashMuon
from .muon import Muon, zeropower_via_newtonschulz5

__all__ = ['Muon', 'zeropower_via_newtonschulz5']
__all__ = ["FlashMuon", "Muon", "zeropower_via_newtonschulz5"]
1 change: 1 addition & 0 deletions optimizers/flash-muon
Submodule flash-muon added at 80ac87
6 changes: 3 additions & 3 deletions training/evaluation.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from torch.utils.data import DataLoader
from torch.amp import autocast
from configs.moe_config import MoEModelConfig

from utils.helpers import unpack_batch

def evaluate_model(model: nn.Module, val_loader: DataLoader, config: MoEModelConfig):
"""Evaluate model performance"""
Expand All @@ -17,10 +17,10 @@ def evaluate_model(model: nn.Module, val_loader: DataLoader, config: MoEModelCon
device = next(model.parameters()).device

with torch.no_grad():
for i, (x, y) in enumerate(val_loader):
for i, batch in enumerate(val_loader):
if i >= config.eval_steps:
break
x, y = x.to(device), y.to(device)
x, y = unpack_batch(batch, device)

with autocast('cuda', dtype=torch.float16, enabled=config.use_amp):
# MoE model evaluation
Expand Down
36 changes: 18 additions & 18 deletions training/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,32 +6,32 @@
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader
from torch.amp import autocast, GradScaler
from functools import partial
from tqdm import tqdm
from configs.moe_config import MoEModelConfig
from models.moe_llm import MoEMinimalLLM
from optimizers.muon import Muon
from optimizers import Muon, FlashMuon
from training.evaluation import evaluate_model
from utils.helpers import set_seed
from utils.helpers import set_seed, unpack_batch

def _is_muon_param(name: str, p: nn.Parameter) -> bool:
"""Muon rule: 2-D weight matrix, no embeddings, no norm layers."""
return (
p.ndim == 2
and p.requires_grad
and "token_embedding" not in name
and "norm" not in name
)

def setup_muon_optimizer(model: nn.Module, config: MoEModelConfig):
    """Set up the hybrid optimizer pair: Muon for 2-D weights, AdamW for the rest.

    Parameters are routed by ``_is_muon_param`` (2-D, trainable, not an
    embedding or norm weight). When ``config.use_flash_muon`` is set, the
    FlashMuon implementation is used instead of the reference Muon.

    Returns:
        list: ``[muon_optimizer, adamw_optimizer]``.
    """
    muon_params, adamw_params = [], []
    for name, p in model.named_parameters():
        (muon_params if _is_muon_param(name, p) else adamw_params).append(p)
    print(f" Muon parameters: {sum(p.numel() for p in muon_params):,}")
    print(f" AdamW parameters: {sum(p.numel() for p in adamw_params):,}")

    # Explicit conditional instead of the fragile `cond and a or b` idiom,
    # which silently picks the wrong branch whenever the middle operand is falsy.
    if config.use_flash_muon:
        # world_size/rank may be absent on single-process configs; default to
        # a one-process layout — TODO confirm these are the distributed knobs
        # FlashMuon expects.
        MuonCls = partial(
            FlashMuon,
            world_size=getattr(config, 'world_size', 1),
            rank=getattr(config, 'rank', 0),
        )
    else:
        MuonCls = Muon
    muon_optimizer = MuonCls(muon_params, lr=config.muon_lr, momentum=0.95)
    # AdamW runs at a tenth of the Muon learning rate (original convention).
    adamw_optimizer = torch.optim.AdamW(
        adamw_params, lr=config.muon_lr * 0.1, weight_decay=config.weight_decay
    )

    return [muon_optimizer, adamw_optimizer]
Expand Down Expand Up @@ -88,11 +88,11 @@ def lr_lambda(step):
eval_times = []

while step < config.max_steps:
for batch_idx, (x, y) in enumerate(train_loader):
for batch_idx, batch in enumerate(train_loader):
if step >= config.max_steps:
break

x, y = x.to(device), y.to(device)
x, y = unpack_batch(batch, device)

# Forward pass
if config.use_amp:
Expand Down
9 changes: 9 additions & 0 deletions utils/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,3 +17,12 @@ def set_seed(seed: int = 42):
def count_parameters(model):
    """Return the total number of trainable parameters in *model*."""
    trainable = (p for p in model.parameters() if p.requires_grad)
    return sum(p.numel() for p in trainable)

def unpack_batch(batch, device):
    """Always return (input_ids, labels) no matter what the loader gives us.

    Args:
        batch: Either a mapping with ``"input_ids"``/``"labels"`` keys, or a
            sequence whose first two elements are the inputs and labels.
        device: Target device the two tensors are moved to via ``.to()``.

    Returns:
        tuple: ``(input_ids, labels)``, both on ``device``.

    Raises:
        TypeError: If ``batch`` is neither a mapping nor a sequence.
    """
    if isinstance(batch, dict):
        return batch["input_ids"].to(device), batch["labels"].to(device)
    elif isinstance(batch, (list, tuple)):
        return batch[0].to(device), batch[1].to(device)
    else:
        # Include the offending type so loader misconfigurations are debuggable.
        raise TypeError(f"Unknown batch format: {type(batch).__name__}")