22 changes: 22 additions & 0 deletions src/megatron/bridge/training/optim.py
@@ -12,17 +12,20 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
from typing import Optional, Union

from megatron.core.optimizer import (
MegatronOptimizer,
OptimizerConfig,
get_megatron_optimizer,
get_mup_config_overrides,
)
from megatron.core.optimizer.muon import get_megatron_muon_optimizer
from megatron.core.optimizer_param_scheduler import OptimizerParamScheduler
from megatron.core.process_groups_config import ProcessGroupCollection
from megatron.core.transformer.module import MegatronModule
from megatron.core.utils import get_model_config

from megatron.bridge.training.config import (
OptimizerConfigOverrideProvider,
@@ -31,6 +34,9 @@
)


G_LOGGER = logging.getLogger(__name__)


def setup_optimizer(
optimizer_config: OptimizerConfig,
scheduler_config: SchedulerConfig,
@@ -59,6 +65,22 @@ def setup_optimizer(
OptimizerConfigOverrideProviderContext(scheduler_config, optimizer_config, model)
)

# Apply μP optimizer scaling if enabled on the model config
Contributor:

Is it possible to do this in ConfigContainer.validate()? All of the overriding logic would be better kept in the same place. If that's not possible, it's okay to leave it here.

Author:

Good point. I think this can't live in ConfigContainer.validate() because the μP optimizer overrides require the post-DDP-wrapped model — specifically, get_model_config() needs to unwrap the DDP/FSDP shell to reach the underlying TransformerConfig. At validate() time the model hasn't been constructed yet (the config container is built before setup_model() is called). By the time setup_optimizer() is called, the model is fully wrapped and get_model_config() can safely retrieve use_mup and mup_width_mult.
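
To make the ordering concrete, here is a minimal, self-contained sketch of the kind of unwrapping involved. Everything in it (DummyTransformerConfig, DummyModel, DummyDDPWrapper, unwrap_model_config) is a hypothetical stand-in for illustration, not the Megatron-Core API; the real get_model_config() performs roughly this walk through the DDP/FSDP shell:

```python
# Illustrative sketch only. These dummy classes stand in for the real model and
# DDP wrapper; the point is that the config is only reachable through the
# wrapped model, which does not exist yet at ConfigContainer.validate() time.
from dataclasses import dataclass


@dataclass
class DummyTransformerConfig:
    use_mup: bool = True
    mup_width_mult: float = 2.0


class DummyModel:
    def __init__(self, config: DummyTransformerConfig) -> None:
        self.config = config


class DummyDDPWrapper:
    """Stands in for the DDP/FSDP shell that setup_optimizer receives."""

    def __init__(self, module: DummyModel) -> None:
        self.module = module


def unwrap_model_config(model) -> DummyTransformerConfig:
    # Walk .module until we reach an object that carries .config,
    # roughly what get_model_config() does in the real code path.
    while not hasattr(model, "config") and hasattr(model, "module"):
        model = model.module
    return model.config


wrapped = DummyDDPWrapper(DummyModel(DummyTransformerConfig()))
cfg = unwrap_model_config(wrapped)
assert cfg.use_mup and cfg.mup_width_mult == 2.0
```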

model_chunks = model if isinstance(model, list) else [model]
model_config = get_model_config(model_chunks[0])
if getattr(model_config, "use_mup", False):
mup_overrides = get_mup_config_overrides(
config=optimizer_config,
mup_width_mult=model_config.mup_width_mult,
optimizer_type=optimizer_config.optimizer,
)
if mup_overrides:
config_overrides = {**(config_overrides or {}), **mup_overrides}
G_LOGGER.info(
f"μP enabled (width_mult={model_config.mup_width_mult:.4g}): "
f"applied {len(mup_overrides)} optimizer param-group override(s)."
)

if hasattr(optimizer_config, "provide"):
optimizer = optimizer_config.provide(
model_chunks=model,
21 changes: 21 additions & 0 deletions tests/functional_tests/training/conftest.py
@@ -0,0 +1,21 @@
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import pytest


@pytest.fixture(scope="session", autouse=True)
def ensure_test_data(tmp_path_factory):
"""Override parent conftest fixture: training tests use MockGPTDatasetConfig and need no data."""
yield tmp_path_factory.mktemp("test_data")
131 changes: 131 additions & 0 deletions tests/functional_tests/training/test_pretrain.py
@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
import os

import pytest
@@ -168,6 +169,136 @@ def test_pretrain_with_checkpoint(self, tmp_path):
# Clean up manually.
clear_directories(tmp_path)

@pytest.mark.run_only_on("GPU")
def test_pretrain_with_mup(self, tmp_path, caplog):
"""
Test end to end training with μP (Maximal Update Parameterization) enabled.

Verifies that use_mup=True flows through the full training stack: the model
config's mup_width_mult is computed by finalize(), get_model_config() on the
DDP-wrapped model still returns use_mup=True, and setup_optimizer applies the
per-parameter-class LR overrides without error.

Uses mup_base_hidden_size=1024 with hidden_size=2048 (width_mult=2.0) so that
the LR scaling is non-trivial and any failure to apply overrides would be visible.
"""
initialize_distributed()
shared_base_dir = broadcast_path(tmp_path)

tensorboard_dir = os.path.join(shared_base_dir, "tensorboard")

if torch.distributed.get_rank() == 0:
os.makedirs(tensorboard_dir, exist_ok=True)

torch.distributed.barrier()

try:
global_batch_size = 8
micro_batch_size = 1
seq_length = 512
total_iters = 5

model_cfg = Llama32ModelProvider1B(
tensor_model_parallel_size=1,
pipeline_model_parallel_size=1,
context_parallel_size=1,
sequence_parallel=False,
attention_softmax_in_fp32=True,
pipeline_dtype=torch.bfloat16,
bf16=True,
seq_length=seq_length,
make_vocab_size_divisible_by=128,
vocab_size=None,
num_layers=1,
use_mup=True,
mup_base_hidden_size=1024, # width_mult = 2048/1024 = 2.0
)

cfg = ConfigContainer(
model=model_cfg,
train=TrainingConfig(
train_iters=total_iters,
global_batch_size=global_batch_size,
micro_batch_size=micro_batch_size,
exit_signal_handler=True,
),
validation=ValidationConfig(
eval_interval=5,
eval_iters=2,
),
optimizer=OptimizerConfig(
optimizer="adam",
bf16=True,
fp16=False,
adam_beta1=0.9,
adam_beta2=0.95,
adam_eps=1e-8,
use_distributed_optimizer=True,
clip_grad=1.0,
lr=3e-3,
weight_decay=0.01,
min_lr=1e-6,
),
scheduler=SchedulerConfig(
start_weight_decay=0.033,
end_weight_decay=0.033,
weight_decay_incr_style="constant",
lr_decay_style="cosine",
lr_warmup_iters=1,
lr_warmup_init=0.0,
lr_decay_iters=total_iters,
override_opt_param_scheduler=True,
),
ddp=DistributedDataParallelConfig(
check_for_nan_in_grad=True,
grad_reduce_in_fp32=True,
overlap_grad_reduce=True,
overlap_param_gather=True,
average_in_collective=True,
use_distributed_optimizer=True,
),
dataset=MockGPTDatasetConfig(
random_seed=1234,
reset_attention_mask=False,
reset_position_ids=False,
eod_mask_loss=False,
seq_length=seq_length,
num_dataset_builder_threads=1,
data_sharding=True,
dataloader_type="single",
num_workers=1,
),
logger=LoggerConfig(
log_interval=5,
tensorboard_dir=tensorboard_dir,
),
tokenizer=TokenizerConfig(
tokenizer_type="NullTokenizer",
vocab_size=10000,
),
checkpoint=CheckpointConfig(
save_interval=100,
ckpt_format="torch_dist",
),
rng=RNGConfig(seed=1234),
)

with caplog.at_level(logging.INFO, logger="megatron.bridge.training.optim"):
pretrain(cfg, forward_step)

# Assert μP optimizer overrides were applied (not just a smoke test)
mup_log_messages = [r.message for r in caplog.records if "μP enabled" in r.message]
assert mup_log_messages, (
"Expected μP optimizer override log message but found none. "
"Check that use_mup=True flows through setup_optimizer."
)
assert "width_mult=2" in mup_log_messages[0], (
f"Expected width_mult=2 in μP log, got: {mup_log_messages[0]}"
)

finally:
clear_directories(tmp_path)

Comment on lines +172 to +301
Contributor:

⚠️ Potential issue | 🟠 Major

test_pretrain_with_mup needs an explicit μP assertion.

Right now this is a smoke test only. Even if μP overrides stop being applied, the test can still pass as long as training does not crash. Please add at least one explicit check that μP-specific logic executed (for example, assert the μP optimizer override signal/log appears, or assert an observable μP-derived optimizer-group property).

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tests/functional_tests/training/test_pretrain.py` around lines 171 - 289, Add
an explicit μP assertion in test_pretrain_with_mup to verify μP logic executed:
after constructing model_cfg with use_mup=True and mup_base_hidden_size=1024 and
before/after calling pretrain(forward_step), inspect the optimizer or model
config returned by the training setup (e.g., from pretrain or the object created
by Llama32ModelProvider1B) and assert a μP-specific signal—such as an optimizer
param-group lr scaling or a flag indicating mup overrides were applied (look for
methods/names like setup_optimizer, get_model_config, use_mup, mup_width_mult,
or the optimizer param_groups returned by pretrain)—so the test fails if no μP
overrides were applied.
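
For reference, one possible shape of such a check, assuming the log line emitted by the optim.py change above is the observable signal (this mirrors the caplog-based assertion the test now contains; assert_mup_overrides_applied is a hypothetical helper name used only for this sketch):

```python
# Sketch only: fails the test if no μP override log was captured, assuming
# caplog was attached to the "megatron.bridge.training.optim" logger.
def assert_mup_overrides_applied(caplog, expected_width_mult: float) -> None:
    mup_messages = [r.message for r in caplog.records if "μP enabled" in r.message]
    assert mup_messages, (
        "Expected a μP optimizer override log message but found none; "
        "use_mup=True may not have reached setup_optimizer."
    )
    # optim.py formats the multiplier with :.4g, so match that here.
    assert f"width_mult={expected_width_mult:.4g}" in mup_messages[0], (
        f"Unexpected μP log message: {mup_messages[0]}"
    )
```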

@pytest.mark.run_only_on("GPU")
def test_pretrain_vpp(self, tmp_path):
"""
148 changes: 148 additions & 0 deletions tests/unit_tests/training/test_optim.py
@@ -0,0 +1,148 @@
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Tests for setup_optimizer in optim.py."""

from unittest.mock import MagicMock, patch

from megatron.core.optimizer import OptimizerConfig, ParamGroupOverride, ParamKey

from megatron.bridge.training.config import SchedulerConfig


class TestSetupOptimizerMuP:
"""Tests for μP optimizer scaling in setup_optimizer."""

def _make_optimizer_config(self, lr=1e-3, min_lr=1e-5, optimizer="adam"):
return OptimizerConfig(optimizer=optimizer, lr=lr, min_lr=min_lr, bf16=True)

def _make_scheduler_config(self):
cfg = SchedulerConfig(lr_decay_iters=1000, lr_decay_style="cosine")
cfg.lr_warmup_steps = 0
cfg.lr_decay_steps = 1000
cfg.wsd_decay_steps = None
return cfg

def _make_model_mock(self, use_mup=False, mup_width_mult=1.0):
model = MagicMock()
model_config = MagicMock()
model_config.use_mup = use_mup
model_config.mup_width_mult = mup_width_mult
return model, model_config

def _make_param_key(self):
"""Create a simple ParamKey instance for use in fake overrides."""
return ParamKey(name="*.weight")

@patch("megatron.bridge.training.optim._get_scheduler")
@patch("megatron.bridge.training.optim.get_megatron_optimizer")
@patch("megatron.bridge.training.optim.get_model_config")
def test_mup_disabled_skips_overrides(self, mock_get_model_config, mock_get_optimizer, _mock_get_scheduler):
"""When use_mup=False, get_mup_config_overrides is not called."""
from megatron.bridge.training.optim import setup_optimizer

model, model_config = self._make_model_mock(use_mup=False)
mock_get_model_config.return_value = model_config
mock_get_optimizer.return_value = MagicMock()

with patch("megatron.bridge.training.optim.get_mup_config_overrides") as mock_mup:
setup_optimizer(
optimizer_config=self._make_optimizer_config(),
scheduler_config=self._make_scheduler_config(),
model=model,
)
mock_mup.assert_not_called()

@patch("megatron.bridge.training.optim._get_scheduler")
@patch("megatron.bridge.training.optim.get_megatron_optimizer")
@patch("megatron.bridge.training.optim.get_model_config")
def test_mup_enabled_calls_overrides(self, mock_get_model_config, mock_get_optimizer, _mock_get_scheduler):
"""When use_mup=True, get_mup_config_overrides is called with correct args."""
from megatron.bridge.training.optim import setup_optimizer

model, model_config = self._make_model_mock(use_mup=True, mup_width_mult=2.0)
mock_get_model_config.return_value = model_config
mock_get_optimizer.return_value = MagicMock()

fake_overrides = {self._make_param_key(): ParamGroupOverride(lr_mult=0.5)}

with patch("megatron.bridge.training.optim.get_mup_config_overrides", return_value=fake_overrides) as mock_mup:
optimizer_config = self._make_optimizer_config(lr=1e-3, optimizer="adam")
setup_optimizer(
optimizer_config=optimizer_config,
scheduler_config=self._make_scheduler_config(),
model=model,
)
mock_mup.assert_called_once_with(
config=optimizer_config,
mup_width_mult=2.0,
optimizer_type="adam",
)

@patch("megatron.bridge.training.optim._get_scheduler")
@patch("megatron.bridge.training.optim.get_megatron_optimizer")
@patch("megatron.bridge.training.optim.get_model_config")
def test_mup_overrides_merged_with_existing(self, mock_get_model_config, mock_get_optimizer, _mock_get_scheduler):
"""μP overrides are merged with existing config_overrides."""
from megatron.bridge.training.optim import setup_optimizer

model, model_config = self._make_model_mock(use_mup=True, mup_width_mult=4.0)
mock_get_model_config.return_value = model_config

mup_key = ParamKey(name="*.weight")
existing_key = ParamKey(name="*.bias")
mup_overrides = {mup_key: ParamGroupOverride(lr_mult=0.25)}
existing_overrides = {existing_key: ParamGroupOverride(wd_mult=0.0)}

captured_overrides = {}

def capture_optimizer_call(**kwargs):
captured_overrides.update(kwargs.get("config_overrides") or {})
return MagicMock()

mock_get_optimizer.side_effect = capture_optimizer_call

with patch("megatron.bridge.training.optim.get_mup_config_overrides", return_value=mup_overrides):
with patch(
"megatron.bridge.training.optim.OptimizerConfigOverrideProvider.build_config_overrides",
return_value=existing_overrides,
):
setup_optimizer(
optimizer_config=self._make_optimizer_config(),
scheduler_config=self._make_scheduler_config(),
model=model,
)

assert mup_key in captured_overrides
assert existing_key in captured_overrides

@patch("megatron.bridge.training.optim._get_scheduler")
@patch("megatron.bridge.training.optim.get_megatron_optimizer")
@patch("megatron.bridge.training.optim.get_model_config")
def test_mup_model_list_uses_first_chunk(self, mock_get_model_config, mock_get_optimizer, _mock_get_scheduler):
"""When model is a list, get_model_config is called on the first chunk."""
from megatron.bridge.training.optim import setup_optimizer

model1, model_config = self._make_model_mock(use_mup=False)
model2 = MagicMock()
mock_get_model_config.return_value = model_config
mock_get_optimizer.return_value = MagicMock()

setup_optimizer(
optimizer_config=self._make_optimizer_config(),
scheduler_config=self._make_scheduler_config(),
model=[model1, model2],
)

mock_get_model_config.assert_called_once_with(model1)