131 changes: 99 additions & 32 deletions megatron/core/optimizer/emerging_optimizers.py
@@ -24,18 +24,21 @@
try:
from emerging_optimizers import registry
from emerging_optimizers.orthogonalized_optimizers import (
AdaptiveMuon,
OrthogonalizedOptimizer,
get_muon_scale_factor,
)
from emerging_optimizers.orthogonalized_optimizers.muon_utils import newton_schulz_tp
from emerging_optimizers.scalar_optimizers import Lion # pylint: disable=unused-import

# It is necessary to import SOAP for the registry to work.
# It is necessary to import optimizers for the registry to work.
from emerging_optimizers.soap import SOAP # pylint: disable=unused-import

HAVE_EMERGING_OPTIMIZERS = True
except ImportError:
HAVE_EMERGING_OPTIMIZERS = False
OrthogonalizedOptimizer = object
AdaptiveMuon = object


logger = logging.getLogger(__name__)
@@ -46,6 +49,22 @@
# ===========================================================================


def _eopt_init_state_fn(opt, config=None):
"""Initialize emerging optimizer state for torch_dist checkpoint format."""
for group in opt.param_groups:
# Checkpoint init needs state for all parameters, including those without grads yet.
opt._init_group(group, skip_non_grad_params=False)


def _default_param_overrides_factory() -> Dict[ParamKey, Dict[str, Any]]:
"""Default param overrides: route non-linear/embedding params to Adam."""
return {
ParamKey(
predicate=ParamPredicate(name="nonlinear_or_embedding", fn=_is_nonlinear_or_embedding)
): {'optimizer': 'adam'}
}
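For context, a minimal sketch of the kind of predicate this override keys on. The real _is_nonlinear_or_embedding is defined elsewhere in this module and its exact signature is assumed here; the idea is that Muon-style orthogonalized updates only apply to 2D linear weights, so embeddings, norms, and biases fall back to Adam.

import torch

def _is_nonlinear_or_embedding_sketch(name: str, param: torch.Tensor) -> bool:
    # Hypothetical predicate for illustration only: anything that is not a plain 2D
    # linear weight (embeddings, layernorms, biases) is routed to Adam instead of Muon.
    return param.ndim != 2 or "embedding" in name or "norm" in name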


@dataclass
class EmergingOptimizerEntry:
"""Everything needed to create and configure an emerging optimizer.
Expand All @@ -59,9 +78,11 @@ class EmergingOptimizerEntry:
"""

optimizer_cls: type
init_state_fn: Callable
config_to_kwargs: Callable | None
default_param_overrides: Dict[ParamKey, Dict[str, Any]] = field(default_factory=dict)
init_state_fn: Callable = _eopt_init_state_fn
config_to_kwargs: Callable | None = None
default_param_overrides: Dict[ParamKey, Dict[str, Any]] = field(
default_factory=_default_param_overrides_factory
)


def _create_emerging_optimizer(config, param_groups, eopt_name, model_chunks, pg_collection):
@@ -166,7 +187,11 @@ def scaled_orthogonalize_fn(
self.qkv_split_shapes = qkv_split_shapes

weight_decay_method = "decoupled" if use_decoupled_weight_decay else "l2"
super().__init__(
# Use explicit class call instead of super() so that subclasses with
# multiple inheritance (e.g. TensorParallelAdaptiveMuon) don't route
# through an intermediate class that doesn't accept scaled_orthogonalize_fn.
OrthogonalizedOptimizer.__init__(
self,
params,
lr,
momentum,
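A minimal, self-contained illustration of the MRO pitfall the comment above describes (class names here are stand-ins, not the real optimizer hierarchy): in a diamond such as TensorParallelAdaptiveMuon(TensorParallelMuon, AdaptiveMuon), a cooperative super().__init__ call inside TensorParallelMuon would resolve to AdaptiveMuon's constructor, which does not accept scaled_orthogonalize_fn.

class Base:
    def __init__(self, scaled_fn=None):
        self.scaled_fn = scaled_fn

class Adaptive(Base):
    def __init__(self):  # does not accept scaled_fn
        super().__init__()

class TensorParallel(Base):
    def __init__(self, scaled_fn):
        # super().__init__(scaled_fn) would resolve to Adaptive.__init__ in the
        # diamond below and raise TypeError; the explicit base call avoids that.
        Base.__init__(self, scaled_fn)

class Combined(TensorParallel, Adaptive):
    pass

Combined(scaled_fn=lambda g: g)  # OK: TensorParallel bypasses the MRO for __init__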
@@ -229,11 +254,60 @@ def orthogonalize(self, p: torch.Tensor, grad: torch.Tensor, **kwargs: Any) -> t
return grad


def _eopt_init_state_fn(opt, config=None):
"""Initialize emerging optimizer state for torch_dist checkpoint format."""
for group in opt.param_groups:
# Checkpoint init needs state for all parameters, including those without grads yet.
opt._init_group(group, skip_non_grad_params=False)
class TensorParallelAdaptiveMuon(TensorParallelMuon, AdaptiveMuon):
"""Tensor Parallel Adaptive Muon optimizer."""

def __init__(
self,
params: ParamsT,
lr: float = 3e-4,
momentum: float = 0.95,
nesterov: bool = True,
weight_decay: float = 0.01,
use_decoupled_weight_decay: bool = True,
split_qkv: bool = False,
is_qkv_fn: Callable[[torch.Tensor], bool] | None = None,
qkv_split_shapes: tuple[int, int, int] | None = None,
fp32_matmul_prec: str = "medium",
coefficient_type: str = "quintic",
num_ns_steps: int = 5,
scale_mode: str = "spectral",
extra_scale_factor: float = 1.0,
pg_collection: Optional[ProcessGroupCollection] = None,
tp_mode: Literal["blockwise", "duplicated", "distributed"] = "duplicated",
moment2_method: Literal["adamuon", "normuon"] = "adamuon",
beta2: float = 0.95,
eps: float = 1e-8,
) -> None:
TensorParallelMuon.__init__(
self,
params,
lr=lr,
momentum=momentum,
nesterov=nesterov,
weight_decay=weight_decay,
use_decoupled_weight_decay=use_decoupled_weight_decay,
split_qkv=split_qkv,
is_qkv_fn=is_qkv_fn,
qkv_split_shapes=qkv_split_shapes,
fp32_matmul_prec=fp32_matmul_prec,
coefficient_type=coefficient_type,
num_ns_steps=num_ns_steps,
scale_mode=scale_mode,
extra_scale_factor=extra_scale_factor,
pg_collection=pg_collection,
tp_mode=tp_mode,
)
self.moment2_method = moment2_method

for group in self.param_groups:
group.setdefault("beta2", beta2)
group.setdefault("eps", eps)

@torch.no_grad() # type: ignore[misc]
def step(self, closure: Optional[Callable] = None) -> Optional[float]:
"""Step function"""
return AdaptiveMuon.step(self, closure)
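A minimal single-process usage sketch of the class above. It assumes the optimizer can be constructed and stepped without an explicit pg_collection (which defaults to None); in real Megatron training it is built through the emerging-optimizer entry below rather than instantiated directly.

import torch

linear = torch.nn.Linear(1024, 1024, bias=False)
opt = TensorParallelAdaptiveMuon(
    linear.parameters(),
    lr=3e-4,
    moment2_method="adamuon",  # or "normuon"
    beta2=0.95,
    eps=1e-8,
)
linear(torch.randn(8, 1024)).sum().backward()
opt.step()
opt.zero_grad()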


def _kwargs_from_config(optimizer_cls: type, prefix: str, config) -> Dict[str, Any]:
@@ -266,6 +340,13 @@ def _muon_config_to_kwargs(config, model_chunks, pg_collection) -> Dict[str, Any
return kwargs


def _adaptive_muon_config_to_kwargs(config, model_chunks, pg_collection) -> Dict[str, Any]:
"""Convert OptimizerConfig to TensorParallelAdaptiveMuon constructor kwargs."""
kwargs = _muon_config_to_kwargs(config, model_chunks, pg_collection)
kwargs.update(_kwargs_from_config(TensorParallelAdaptiveMuon, "adaptive_muon", config))
return kwargs
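For illustration, a sketch of the prefix convention this relies on (an assumed re-implementation, not the actual _kwargs_from_config): constructor parameters that have a matching adaptive_muon_<name> attribute on the config are picked up, which is how the adaptive_muon_moment2_method, adaptive_muon_beta2, and adaptive_muon_eps fields added to OptimizerConfig below reach the constructor as moment2_method, beta2, and eps.

import inspect
from typing import Any, Dict

def _kwargs_from_config_sketch(optimizer_cls: type, prefix: str, config) -> Dict[str, Any]:
    # Hypothetical illustration only: map config.<prefix>_<param> onto <param> for every
    # parameter of the optimizer's __init__ that has such an attribute on the config.
    params = inspect.signature(optimizer_cls.__init__).parameters
    return {
        name: getattr(config, f"{prefix}_{name}")
        for name in params
        if name != "self" and hasattr(config, f"{prefix}_{name}")
    }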


def _default_adam_based_eopt_config_to_kwargs(
eopt_name, config, model_chunks, pg_collection
) -> Dict[str, Any]:
@@ -280,34 +361,20 @@ def _default_adam_based_eopt_config_to_kwargs(
# -----------------------------------------------------------------------
_EMERGING_OPTIMIZERS = {
'muon': EmergingOptimizerEntry(
optimizer_cls=TensorParallelMuon,
init_state_fn=_eopt_init_state_fn,
config_to_kwargs=_muon_config_to_kwargs,
default_param_overrides={
ParamKey(
predicate=ParamPredicate(
name="nonlinear_or_embedding", fn=_is_nonlinear_or_embedding
)
): {'optimizer': 'adam'}
},
)
optimizer_cls=TensorParallelMuon, config_to_kwargs=_muon_config_to_kwargs
),
"adaptive_muon": EmergingOptimizerEntry(
optimizer_cls=TensorParallelAdaptiveMuon, config_to_kwargs=_adaptive_muon_config_to_kwargs
),
}

# Register remaining emerging optimizers from the registry with default configuration.
if HAVE_EMERGING_OPTIMIZERS:
for eopt_name in ["soap"]:
for eopt_name in registry.get_optimizer_name_list():
if eopt_name in _EMERGING_OPTIMIZERS:
# Skip names already registered above with local (e.g. tensor-parallel) implementations.
continue
_EMERGING_OPTIMIZERS[eopt_name] = EmergingOptimizerEntry(
optimizer_cls=registry.get_optimizer_cls(eopt_name),
init_state_fn=_eopt_init_state_fn,
config_to_kwargs=None,
default_param_overrides={
ParamKey(
predicate=ParamPredicate(
name="nonlinear_or_embedding", fn=_is_nonlinear_or_embedding
)
): {'optimizer': 'adam'}
},
optimizer_cls=registry.get_optimizer_cls(eopt_name)
)
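With the dataclass defaults introduced above, registering an optimizer straight from the upstream registry reduces to supplying its class; the checkpoint init-state function and the Adam fallback for non-linear/embedding params are filled in automatically. A hypothetical manual registration, for illustration (the name "my_optimizer" is a placeholder, not a real registry entry):

if HAVE_EMERGING_OPTIMIZERS:
    _EMERGING_OPTIMIZERS["my_optimizer"] = EmergingOptimizerEntry(
        optimizer_cls=registry.get_optimizer_cls("my_optimizer"),
    )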
9 changes: 9 additions & 0 deletions megatron/core/optimizer/optimizer_config.py
@@ -289,6 +289,15 @@ class OptimizerConfig:
soap_use_kl_shampoo: bool = True
"""Whether to use the KL-Shampoo preconditioner."""

adaptive_muon_moment2_method: str = "adamuon"
"""The method to use for the moment2 update in Adaptive Muon optimizer."""

adaptive_muon_beta2: float = 0.95
"""The beta2 parameter for the Adaptive Muon optimizer."""

adaptive_muon_eps: float = 1e-8
"""The eps parameter for the Adaptive Muon optimizer."""

#######################
# Distributed optimizer
#######################
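A sketch of wiring the new fields programmatically, assuming OptimizerConfig is importable from megatron.core.optimizer and that the remaining fields can stay at their defaults:

from megatron.core.optimizer import OptimizerConfig

config = OptimizerConfig(
    optimizer="adaptive_muon",
    lr=3e-4,
    adaptive_muon_moment2_method="normuon",  # default is "adamuon"
    adaptive_muon_beta2=0.95,
    adaptive_muon_eps=1e-8,
)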
2 changes: 1 addition & 1 deletion megatron/training/arguments.py
@@ -2256,7 +2256,7 @@ def _add_training_args(parser):
help='use FlashAttention implementation of attention. '
'https://arxiv.org/abs/2205.14135')
group.add_argument('--optimizer', type=str, default='adam',
choices=['adam', 'sgd', 'muon', 'dist_muon', 'soap'],
choices=['adam', 'sgd', 'muon', 'dist_muon', 'soap', 'adaptive_muon', 'lion'],
help='Optimizer function. '
'Note: dist_muon is deprecated; use --optimizer muon '
'with --use-distributed-optimizer instead.')
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -190,7 +190,7 @@ flash_mla = [
]
transformer-engine = { git = "https://github.com/NVIDIA/TransformerEngine.git", rev = "5671fd3675906cda1ade26c24a65d3dedd88eb89" }
nemo-run = { git = "https://github.com/NVIDIA-NeMo/Run.git", rev = "01a9a8ba360f7b2908728ad0516e0ad9d936966d" }
emerging_optimizers = { git = "https://github.com/NVIDIA-NeMo/Emerging-Optimizers.git", rev = "bc634ff8c0cf4fb5dbae0a531081281b499be3a0" }
emerging_optimizers = { git = "https://github.com/NVIDIA-NeMo/Emerging-Optimizers.git", rev = "v0.2.0" }
fast-hadamard-transform = { git = "https://github.com/Dao-AILab/fast-hadamard-transform.git", rev = "f134af63deb2df17e1171a9ec1ea4a7d8604d5ca" }

[tool.isort]