-
Notifications
You must be signed in to change notification settings - Fork 199
Mup scaling #2666
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Mup scaling #2666
Changes from all commits
4d3cedd
f674b72
43c0eca
dcd56f5
180e9ca
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,21 @@ | ||
| # Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. | ||
| # | ||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||
| # you may not use this file except in compliance with the License. | ||
| # You may obtain a copy of the License at | ||
| # | ||
| # http://www.apache.org/licenses/LICENSE-2.0 | ||
| # | ||
| # Unless required by applicable law or agreed to in writing, software | ||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| # See the License for the specific language governing permissions and | ||
| # limitations under the License. | ||
|
|
||
| import pytest | ||
|
|
||
|
|
||
| @pytest.fixture(scope="session", autouse=True) | ||
| def ensure_test_data(tmp_path_factory): | ||
| """Override parent conftest fixture: training tests use MockGPTDatasetConfig and need no data.""" | ||
| yield tmp_path_factory.mktemp("test_data") |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -12,6 +12,7 @@ | |
| # See the License for the specific language governing permissions and | ||
| # limitations under the License. | ||
|
|
||
| import logging | ||
| import os | ||
|
|
||
| import pytest | ||
|
|
@@ -168,6 +169,136 @@ def test_pretrain_with_checkpoint(self, tmp_path): | |
| # Clean up manually. | ||
| clear_directories(tmp_path) | ||
|
|
||
| @pytest.mark.run_only_on("GPU") | ||
| def test_pretrain_with_mup(self, tmp_path, caplog): | ||
| """ | ||
| Test end to end training with μP (Maximal Update Parameterization) enabled. | ||
|
|
||
| Verifies that use_mup=True flows through the full training stack: the model | ||
| config's mup_width_mult is computed by finalize(), get_model_config() on the | ||
| DDP-wrapped model still returns use_mup=True, and setup_optimizer applies the | ||
| per-parameter-class LR overrides without error. | ||
|
|
||
| Uses mup_base_hidden_size=1024 with hidden_size=2048 (width_mult=2.0) so that | ||
| the LR scaling is non-trivial and any failure to apply overrides would be visible. | ||
| """ | ||
| initialize_distributed() | ||
| shared_base_dir = broadcast_path(tmp_path) | ||
|
|
||
| tensorboard_dir = os.path.join(shared_base_dir, "tensorboard") | ||
|
|
||
| if torch.distributed.get_rank() == 0: | ||
| os.makedirs(tensorboard_dir, exist_ok=True) | ||
|
|
||
| torch.distributed.barrier() | ||
|
|
||
| try: | ||
| global_batch_size = 8 | ||
| micro_batch_size = 1 | ||
| seq_length = 512 | ||
| total_iters = 5 | ||
|
|
||
| model_cfg = Llama32ModelProvider1B( | ||
| tensor_model_parallel_size=1, | ||
| pipeline_model_parallel_size=1, | ||
| context_parallel_size=1, | ||
| sequence_parallel=False, | ||
| attention_softmax_in_fp32=True, | ||
| pipeline_dtype=torch.bfloat16, | ||
| bf16=True, | ||
| seq_length=seq_length, | ||
| make_vocab_size_divisible_by=128, | ||
| vocab_size=None, | ||
| num_layers=1, | ||
| use_mup=True, | ||
| mup_base_hidden_size=1024, # width_mult = 2048/1024 = 2.0 | ||
| ) | ||
|
|
||
| cfg = ConfigContainer( | ||
| model=model_cfg, | ||
| train=TrainingConfig( | ||
| train_iters=total_iters, | ||
| global_batch_size=global_batch_size, | ||
| micro_batch_size=micro_batch_size, | ||
| exit_signal_handler=True, | ||
| ), | ||
| validation=ValidationConfig( | ||
| eval_interval=5, | ||
| eval_iters=2, | ||
| ), | ||
| optimizer=OptimizerConfig( | ||
| optimizer="adam", | ||
| bf16=True, | ||
| fp16=False, | ||
| adam_beta1=0.9, | ||
| adam_beta2=0.95, | ||
| adam_eps=1e-8, | ||
| use_distributed_optimizer=True, | ||
| clip_grad=1.0, | ||
| lr=3e-3, | ||
| weight_decay=0.01, | ||
| min_lr=1e-6, | ||
| ), | ||
| scheduler=SchedulerConfig( | ||
| start_weight_decay=0.033, | ||
| end_weight_decay=0.033, | ||
| weight_decay_incr_style="constant", | ||
| lr_decay_style="cosine", | ||
| lr_warmup_iters=1, | ||
| lr_warmup_init=0.0, | ||
| lr_decay_iters=total_iters, | ||
| override_opt_param_scheduler=True, | ||
| ), | ||
| ddp=DistributedDataParallelConfig( | ||
| check_for_nan_in_grad=True, | ||
| grad_reduce_in_fp32=True, | ||
| overlap_grad_reduce=True, | ||
| overlap_param_gather=True, | ||
| average_in_collective=True, | ||
| use_distributed_optimizer=True, | ||
| ), | ||
| dataset=MockGPTDatasetConfig( | ||
| random_seed=1234, | ||
| reset_attention_mask=False, | ||
| reset_position_ids=False, | ||
| eod_mask_loss=False, | ||
| seq_length=seq_length, | ||
| num_dataset_builder_threads=1, | ||
| data_sharding=True, | ||
| dataloader_type="single", | ||
| num_workers=1, | ||
| ), | ||
| logger=LoggerConfig( | ||
| log_interval=5, | ||
| tensorboard_dir=tensorboard_dir, | ||
| ), | ||
| tokenizer=TokenizerConfig( | ||
| tokenizer_type="NullTokenizer", | ||
| vocab_size=10000, | ||
| ), | ||
| checkpoint=CheckpointConfig( | ||
| save_interval=100, | ||
| ckpt_format="torch_dist", | ||
| ), | ||
| rng=RNGConfig(seed=1234), | ||
| ) | ||
|
|
||
| with caplog.at_level(logging.INFO, logger="megatron.bridge.training.optim"): | ||
| pretrain(cfg, forward_step) | ||
|
|
||
| # Assert μP optimizer overrides were applied (not just a smoke test) | ||
| mup_log_messages = [r.message for r in caplog.records if "μP enabled" in r.message] | ||
| assert mup_log_messages, ( | ||
| "Expected μP optimizer override log message but found none. " | ||
| "Check that use_mup=True flows through setup_optimizer." | ||
| ) | ||
| assert "width_mult=2" in mup_log_messages[0], ( | ||
| f"Expected width_mult=2 in μP log, got: {mup_log_messages[0]}" | ||
| ) | ||
|
|
||
| finally: | ||
| clear_directories(tmp_path) | ||
|
|
||
|
Comment on lines +172 to +301
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
Right now this is a smoke test only. Even if μP overrides stop being applied, the test can still pass as long as training does not crash. Please add at least one explicit check that μP-specific logic executed (for example, assert the μP optimizer override signal/log appears, or assert an observable μP-derived optimizer-group property). 🤖 Prompt for AI Agents |
||
| @pytest.mark.run_only_on("GPU") | ||
| def test_pretrain_vpp(self, tmp_path): | ||
| """ | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,148 @@ | ||
| # Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. | ||
| # | ||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||
| # you may not use this file except in compliance with the License. | ||
| # You may obtain a copy of the License at | ||
| # | ||
| # http://www.apache.org/licenses/LICENSE-2.0 | ||
| # | ||
| # Unless required by applicable law or agreed to in writing, software | ||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| # See the License for the specific language governing permissions and | ||
| # limitations under the License. | ||
|
|
||
| """Tests for setup_optimizer in optim.py.""" | ||
|
|
||
| from unittest.mock import MagicMock, patch | ||
|
|
||
| from megatron.core.optimizer import OptimizerConfig, ParamGroupOverride, ParamKey | ||
|
|
||
| from megatron.bridge.training.config import SchedulerConfig | ||
|
|
||
|
|
||
| class TestSetupOptimizerMuP: | ||
| """Tests for μP optimizer scaling in setup_optimizer.""" | ||
|
|
||
| def _make_optimizer_config(self, lr=1e-3, min_lr=1e-5, optimizer="adam"): | ||
| return OptimizerConfig(optimizer=optimizer, lr=lr, min_lr=min_lr, bf16=True) | ||
|
|
||
| def _make_scheduler_config(self): | ||
| cfg = SchedulerConfig(lr_decay_iters=1000, lr_decay_style="cosine") | ||
| cfg.lr_warmup_steps = 0 | ||
| cfg.lr_decay_steps = 1000 | ||
| cfg.wsd_decay_steps = None | ||
| return cfg | ||
|
|
||
| def _make_model_mock(self, use_mup=False, mup_width_mult=1.0): | ||
| model = MagicMock() | ||
| model_config = MagicMock() | ||
| model_config.use_mup = use_mup | ||
| model_config.mup_width_mult = mup_width_mult | ||
| return model, model_config | ||
|
|
||
| def _make_param_key(self): | ||
| """Create a simple ParamKey instance for use in fake overrides.""" | ||
| return ParamKey(name="*.weight") | ||
|
|
||
| @patch("megatron.bridge.training.optim._get_scheduler") | ||
| @patch("megatron.bridge.training.optim.get_megatron_optimizer") | ||
| @patch("megatron.bridge.training.optim.get_model_config") | ||
| def test_mup_disabled_skips_overrides(self, mock_get_model_config, mock_get_optimizer, _mock_get_scheduler): | ||
| """When use_mup=False, get_mup_config_overrides is not called.""" | ||
| from megatron.bridge.training.optim import setup_optimizer | ||
|
|
||
| model, model_config = self._make_model_mock(use_mup=False) | ||
| mock_get_model_config.return_value = model_config | ||
| mock_get_optimizer.return_value = MagicMock() | ||
|
|
||
| with patch("megatron.bridge.training.optim.get_mup_config_overrides") as mock_mup: | ||
| setup_optimizer( | ||
| optimizer_config=self._make_optimizer_config(), | ||
| scheduler_config=self._make_scheduler_config(), | ||
| model=model, | ||
| ) | ||
| mock_mup.assert_not_called() | ||
|
|
||
| @patch("megatron.bridge.training.optim._get_scheduler") | ||
| @patch("megatron.bridge.training.optim.get_megatron_optimizer") | ||
| @patch("megatron.bridge.training.optim.get_model_config") | ||
| def test_mup_enabled_calls_overrides(self, mock_get_model_config, mock_get_optimizer, _mock_get_scheduler): | ||
| """When use_mup=True, get_mup_config_overrides is called with correct args.""" | ||
| from megatron.bridge.training.optim import setup_optimizer | ||
|
|
||
| model, model_config = self._make_model_mock(use_mup=True, mup_width_mult=2.0) | ||
| mock_get_model_config.return_value = model_config | ||
| mock_get_optimizer.return_value = MagicMock() | ||
|
|
||
| fake_overrides = {self._make_param_key(): ParamGroupOverride(lr_mult=0.5)} | ||
|
|
||
| with patch("megatron.bridge.training.optim.get_mup_config_overrides", return_value=fake_overrides) as mock_mup: | ||
| optimizer_config = self._make_optimizer_config(lr=1e-3, optimizer="adam") | ||
| setup_optimizer( | ||
| optimizer_config=optimizer_config, | ||
| scheduler_config=self._make_scheduler_config(), | ||
| model=model, | ||
| ) | ||
| mock_mup.assert_called_once_with( | ||
| config=optimizer_config, | ||
| mup_width_mult=2.0, | ||
| optimizer_type="adam", | ||
| ) | ||
|
|
||
| @patch("megatron.bridge.training.optim._get_scheduler") | ||
| @patch("megatron.bridge.training.optim.get_megatron_optimizer") | ||
| @patch("megatron.bridge.training.optim.get_model_config") | ||
| def test_mup_overrides_merged_with_existing(self, mock_get_model_config, mock_get_optimizer, _mock_get_scheduler): | ||
| """μP overrides are merged with existing config_overrides.""" | ||
| from megatron.bridge.training.optim import setup_optimizer | ||
|
|
||
| model, model_config = self._make_model_mock(use_mup=True, mup_width_mult=4.0) | ||
| mock_get_model_config.return_value = model_config | ||
|
|
||
| mup_key = ParamKey(name="*.weight") | ||
| existing_key = ParamKey(name="*.bias") | ||
| mup_overrides = {mup_key: ParamGroupOverride(lr_mult=0.25)} | ||
| existing_overrides = {existing_key: ParamGroupOverride(wd_mult=0.0)} | ||
|
|
||
| captured_overrides = {} | ||
|
|
||
| def capture_optimizer_call(**kwargs): | ||
| captured_overrides.update(kwargs.get("config_overrides") or {}) | ||
| return MagicMock() | ||
|
|
||
| mock_get_optimizer.side_effect = capture_optimizer_call | ||
|
|
||
| with patch("megatron.bridge.training.optim.get_mup_config_overrides", return_value=mup_overrides): | ||
| with patch( | ||
| "megatron.bridge.training.optim.OptimizerConfigOverrideProvider.build_config_overrides", | ||
| return_value=existing_overrides, | ||
| ): | ||
| setup_optimizer( | ||
| optimizer_config=self._make_optimizer_config(), | ||
| scheduler_config=self._make_scheduler_config(), | ||
| model=model, | ||
| ) | ||
|
|
||
| assert mup_key in captured_overrides | ||
| assert existing_key in captured_overrides | ||
|
|
||
| @patch("megatron.bridge.training.optim._get_scheduler") | ||
| @patch("megatron.bridge.training.optim.get_megatron_optimizer") | ||
| @patch("megatron.bridge.training.optim.get_model_config") | ||
| def test_mup_model_list_uses_first_chunk(self, mock_get_model_config, mock_get_optimizer, _mock_get_scheduler): | ||
| """When model is a list, get_model_config is called on the first chunk.""" | ||
| from megatron.bridge.training.optim import setup_optimizer | ||
|
|
||
| model1, model_config = self._make_model_mock(use_mup=False) | ||
| model2 = MagicMock() | ||
| mock_get_model_config.return_value = model_config | ||
| mock_get_optimizer.return_value = MagicMock() | ||
|
|
||
| setup_optimizer( | ||
| optimizer_config=self._make_optimizer_config(), | ||
| scheduler_config=self._make_scheduler_config(), | ||
| model=[model1, model2], | ||
| ) | ||
|
|
||
| mock_get_model_config.assert_called_once_with(model1) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is it possible to do this in `ConfigContainer.validate()`? All the overriding logic is better put in the same place. If not possible, it's okay to leave it here.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Good point. I think this can't live in `ConfigContainer.validate()` because the μP optimizer overrides require the post-DDP-wrapped model — specifically, `get_model_config()` needs to unwrap the DDP/FSDP shell to reach the underlying `TransformerConfig`. At `validate()` time the model hasn't been constructed yet (the config container is built before `setup_model()` is called). By the time `setup_optimizer()` is called, the model is fully wrapped and `get_model_config()` can safely retrieve `use_mup` and `mup_width_mult`.