diff --git a/src/megatron/bridge/models/common/unimodal.py b/src/megatron/bridge/models/common/unimodal.py
index c12c175c6f..027151bb75 100644
--- a/src/megatron/bridge/models/common/unimodal.py
+++ b/src/megatron/bridge/models/common/unimodal.py
@@ -12,14 +12,12 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from __future__ import annotations
-
 import logging
 
 
 logger = logging.getLogger(__name__)
 
-from typing import TYPE_CHECKING, Any, Callable
+from typing import Any, Callable
 
 import torch
 from megatron.core import tensor_parallel
@@ -36,9 +34,6 @@
 from megatron.core.utils import get_model_config
 
 
-if TYPE_CHECKING:
-    from megatron.bridge.training.mixed_precision import MixedPrecisionConfig
-
 try:
     from megatron.core.fp8_utils import correct_amax_history_if_needed
 except ImportError:
@@ -58,7 +53,6 @@ def unimodal_build_distributed_models(
     mixed_precision_wrapper: Callable[[Any, MegatronModule], MegatronModule] | None = Float16Module,
     pre_wrap_hook: Callable[[list[MegatronModule]], list[MegatronModule]] | None = None,
     model_type: ModelType = ModelType.encoder_or_decoder,
-    mixed_precision_config: MixedPrecisionConfig | None = None,
 ) -> list[MegatronModule]:
     """Build model stages and wrap for distributed training.
 
@@ -87,7 +81,6 @@ def unimodal_build_distributed_models(
             Pass ``None`` to skip.
         pre_wrap_hook: Hook applied to the model stage list before any wrapping.
         model_type: Deprecated flag, only used for backwards compatibility.
-        mixed_precision_config: Mixed-precision config for DDP wrapper.
 
     Returns:
         List of model stages, wrapped and ready for distributed training.
@@ -151,7 +144,6 @@ def unimodal_build_distributed_models(
             use_megatron_fsdp=use_megatron_fsdp,
             use_torch_fsdp2=use_torch_fsdp2,
             pg_collection=pg_collection,
-            mixed_precision_config=mixed_precision_config,
         )
 
     return model_list
@@ -204,7 +196,6 @@ def _ddp_wrap(
     overlap_param_gather_with_optimizer_step: bool,
     use_megatron_fsdp: bool = False,
     use_torch_fsdp2: bool = False,
-    mixed_precision_config: MixedPrecisionConfig | None = None,
     *,
     pg_collection: ProcessGroupCollection,
 ) -> list[MegatronModule]:
@@ -218,23 +209,15 @@ def _ddp_wrap(
             for overlapping parameter gathering with optimizer step
         use_megatron_fsdp: Whether to use Megatron FSDP.
         use_torch_fsdp2: Whether to use PyTorch FSDP v2 instead of DDP
-        mixed_precision_config: Mixed-precision config for DDP wrapper.
        pg_collection: Model communication process groups.
 
    Returns:
         list[MegatronModule]: List of DDP/FSDP wrapped model modules
     """
-    ddp_init_kwargs = {}
     if use_megatron_fsdp:
         DP = FullyShardedDataParallel
         if use_torch_fsdp2:
             raise ValueError("Using use_megatron_fsdp and use_torch_fsdp2 at the same time is not supported.")
-        if mixed_precision_config is not None:
-            # Also pass the mixed-precision arguments for Megatron-FSDP only.
-            mixed_precision_config.finalize()
-            ddp_init_kwargs["main_params_dtype"] = mixed_precision_config.megatron_fsdp_main_params_dtype
-            ddp_init_kwargs["main_grads_dtype"] = mixed_precision_config.megatron_fsdp_main_grads_dtype
-            ddp_init_kwargs["grad_comm_dtype"] = mixed_precision_config.megatron_fsdp_grad_comm_dtype
     elif use_torch_fsdp2:
         DP = TorchFullyShardedDataParallel
     else:
@@ -255,7 +238,6 @@ def _ddp_wrap(
             # model chunks is overlapped with compute anyway.
             disable_bucketing=(model_chunk_idx > 0) or overlap_param_gather_with_optimizer_step,
             pg_collection=pg_collection,
-            **ddp_init_kwargs,
         )
         for (model_chunk_idx, model_chunk) in enumerate(model)
     ]
diff --git a/src/megatron/bridge/models/model_provider.py b/src/megatron/bridge/models/model_provider.py
index 1a196daf05..f9ae1ccbff 100644
--- a/src/megatron/bridge/models/model_provider.py
+++ b/src/megatron/bridge/models/model_provider.py
@@ -12,12 +12,10 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from __future__ import annotations
-
 import abc
 import os
 from pathlib import Path
-from typing import TYPE_CHECKING, Any, Callable, Generic, TypedDict, TypeVar, Union
+from typing import Any, Callable, Generic, TypedDict, TypeVar, Union
 
 
 from megatron.bridge.models.common.unimodal import _ddp_wrap, _print_num_params
@@ -57,10 +55,6 @@
 from megatron.bridge.utils.instantiate_utils import InstantiationMode
 
 
-if TYPE_CHECKING:
-    from megatron.bridge.training.mixed_precision import MixedPrecisionConfig
-
-
 try:
     from megatron.core.fp8_utils import correct_amax_history_if_needed
 except ImportError:
@@ -129,7 +123,6 @@ def provide_distributed_model(
         | None = None,
         post_wrap_hook: Callable[[list[MegatronModule]], list[MegatronModule]] | None = None,
         mixed_precision_wrapper: Callable[[Any, MegatronModule], MegatronModule] | None = Float16Module,
-        mixed_precision_config: MixedPrecisionConfig | None = None,
         pg_collection: ProcessGroupCollection | None = None,
     ) -> list[ModelT]:
         """Instantiate and wrap the model for distributed training.
@@ -158,8 +151,6 @@ def provide_distributed_model(
                 this will override all hooks registered via `register_post_wrap_hook`.
             mixed_precision_wrapper: A module wrapper (e.g., `Float16Module`) applied when fp16/bf16
                 is enabled. If None, no mixed precision wrapper is applied.
-            mixed_precision_config: Optional MixedPrecisionConfig. Used to configure mixed-precision
-                features on models during initialization.
             pg_collection: Optional pre-initialized ProcessGroupCollection. If provided, skips model
                 parallel initialization and uses the provided collection directly. This is used
                 when `use_decentralized_pg=True` in the distributed config.
@@ -216,7 +207,6 @@ def composed_pre_wrap_hook(model: list[MegatronModule]) -> list[MegatronModule]:
             init_model_with_meta_device=init_model_with_meta_device,
             pre_wrap_hook=final_pre_wrap_hook,
             mixed_precision_wrapper=mixed_precision_wrapper,
-            mixed_precision_config=mixed_precision_config,
             pg_collection=pg_collection,
         )
 
@@ -501,7 +491,6 @@ def get_model(
     ]
     | None = None,
     mixed_precision_wrapper: Callable[[Any, MegatronModule], MegatronModule] | None = Float16Module,
-    mixed_precision_config: MixedPrecisionConfig | None = None,
     *,
     pg_collection: ProcessGroupCollection,
 ) -> list[MegatronModule]:
@@ -536,8 +525,6 @@ def get_model(
             hooks will be executed in order.
         mixed_precision_wrapper: Wrapper class/function applied when fp16/bf16 is enabled.
             Defaults to Megatron-Core's `Float16Module`. If None, the wrapper is not applied.
-        mixed_precision_config: MixedPrecisionConfig that controls mixed-precision features
-            for the created model.
 
     Returns:
         list[MegatronModule]: List of model modules. Contains multiple modules
@@ -620,7 +607,6 @@ def get_model(
         overlap_param_gather_with_optimizer_step,
         use_megatron_fsdp=use_megatron_fsdp,
         use_torch_fsdp2=use_torch_fsdp2,
-        mixed_precision_config=mixed_precision_config,
         pg_collection=pg_collection,
     )
 
diff --git a/src/megatron/bridge/training/mixed_precision.py b/src/megatron/bridge/training/mixed_precision.py
index c30ad943b6..284f29117a 100644
--- a/src/megatron/bridge/training/mixed_precision.py
+++ b/src/megatron/bridge/training/mixed_precision.py
@@ -66,17 +66,6 @@ class MixedPrecisionConfig:
     num_layers_at_start_in_bf16: int = 0
     num_layers_at_end_in_bf16: int = 0
     reuse_grad_buf_for_mxfp8_param_ag: bool = False
-    # Megatron-FSDP MixedPrecisionPolicy
-    megatron_fsdp_main_params_dtype: str | Optional[torch.dtype] = torch.float32
-    megatron_fsdp_main_grads_dtype: str | Optional[torch.dtype] = None
-    megatron_fsdp_grad_comm_dtype: str | Optional[torch.dtype] = None
-
-    def __post_init__(self):
-        if self.grad_reduce_in_fp32:
-            if self.megatron_fsdp_main_grads_dtype is None:
-                object.__setattr__(self, "megatron_fsdp_main_grads_dtype", torch.float32)
-            if self.megatron_fsdp_grad_comm_dtype is None:
-                object.__setattr__(self, "megatron_fsdp_grad_comm_dtype", torch.float32)
 
     def __setattr__(self, name, value) -> None:
         # Use object.__setattr__ to avoid recursion
@@ -89,35 +78,6 @@ def __setattr__(self, name, value) -> None:
         elif name == "fp8_param" and hasattr(self, "fp8_param_gather"):
             if self.fp8_param_gather != value:
                 object.__setattr__(self, "fp8_param_gather", value)
-        if (
-            name == "grad_reduce_in_fp32"
-            and hasattr(self, "megatron_fsdp_main_grads_dtype")
-            and hasattr(self, "megatron_fsdp_grad_comm_dtype")
-        ):
-            if value:
-                # Legacy argument for Megatron-FSDP - Gradients used to be reduced in
-                # the same data-type as the main gradient data-type. Recommend using
-                # the new Megatron-FSDP mixed-precision arguments to control this!
-                object.__setattr__(self, "megatron_fsdp_main_grads_dtype", torch.float32)
-                object.__setattr__(self, "megatron_fsdp_grad_comm_dtype", torch.float32)
-            else:
-                # Default back to "auto".
-                object.__setattr__(self, "megatron_fsdp_main_grads_dtype", None)
-                object.__setattr__(self, "megatron_fsdp_grad_comm_dtype", None)
-        if name in (
-            "megatron_fsdp_main_params_dtype",
-            "megatron_fsdp_main_grads_dtype",
-            "megatron_fsdp_grad_comm_dtype",
-        ) and isinstance(value, str):
-            # Map string options to torch.dtype or None.
-            if value == "fp32":
-                object.__setattr__(self, name, torch.float32)
-            elif value == "bf16":
-                object.__setattr__(self, name, torch.bfloat16)
-            elif value == "fp16":
-                object.__setattr__(self, name, torch.float16)
-            elif value == "auto":
-                object.__setattr__(self, name, None)
 
     def finalize(self):
         # If fp8_param is None, initialize it from fp8_param_gather
@@ -137,23 +97,6 @@ def finalize(self):
         if self.fp4 and not is_te_min_version("2.7.0.dev0"):
             raise ValueError("fp4 requires Transformer Engine >= 2.7.0.dev0 for NVFP4BlockScaling support.")
 
-        if self.grad_reduce_in_fp32:
-            self.megatron_fsdp_main_grads_dtype = torch.float32
-            self.megatron_fsdp_grad_comm_dtype = torch.float32
-        else:
-            self.megatron_fsdp_main_grads_dtype = None
-            self.megatron_fsdp_grad_comm_dtype = None
-        for mfsdp_mp_arg in (
-            self.megatron_fsdp_main_params_dtype,
-            self.megatron_fsdp_main_grads_dtype,
-            self.megatron_fsdp_grad_comm_dtype,
-        ):
-            if isinstance(mfsdp_mp_arg, str):
-                raise ValueError(
-                    f"[MixedPrecisionConfig] Could not map {mfsdp_mp_arg} to torch.dtype or 'auto'. "
-                    "Options: 'fp32', 'fp16', 'bf16', 'auto'"
-                )
-
     def setup(
         self,
         model_config: GPTModelProvider | T5ModelProvider,
diff --git a/src/megatron/bridge/training/setup.py b/src/megatron/bridge/training/setup.py
index 1b5d81cf31..671de9b692 100644
--- a/src/megatron/bridge/training/setup.py
+++ b/src/megatron/bridge/training/setup.py
@@ -225,7 +225,6 @@ def modelopt_pre_wrap_hook(model):
         overlap_param_gather_with_optimizer_step=cfg.optimizer.overlap_param_gather_with_optimizer_step,
         data_parallel_random_init=cfg.rng.data_parallel_random_init,
         pg_collection=pg_collection,
-        mixed_precision_config=cfg.mixed_precision,
     )
 
     cfg.model.timers = timers
diff --git a/tests/unit_tests/models/common/test_unimodal.py b/tests/unit_tests/models/common/test_unimodal.py
index f50335f5ee..896bbd2d6f 100644
--- a/tests/unit_tests/models/common/test_unimodal.py
+++ b/tests/unit_tests/models/common/test_unimodal.py
@@ -497,70 +497,6 @@ def test_returns_list_of_wrapped_modules(
         assert isinstance(result, list)
         assert len(result) == 2
 
-    @patch("megatron.bridge.models.common.unimodal.TorchFullyShardedDataParallel")
-    @patch("megatron.bridge.models.common.unimodal.FullyShardedDataParallel")
-    @patch("megatron.bridge.models.common.unimodal.DistributedDataParallel")
-    @patch("megatron.bridge.models.common.unimodal.get_model_config")
-    @patch("torch.cuda.stream", new_callable=MagicMock)
-    @patch("torch.cuda.current_stream")
-    @patch("torch.cuda.Stream")
-    def test_mixed_precision_config_forwarded_to_megatron_fsdp(
-        self, mock_stream, mock_curr, mock_ctx, mock_cfg, mock_ddp, mock_fsdp, mock_torch_fsdp
-    ):
-        mock_ctx.return_value.__enter__ = Mock(return_value=None)
-        mock_ctx.return_value.__exit__ = Mock(return_value=False)
-
-        mp_cfg = Mock()
-        mp_cfg.megatron_fsdp_main_params_dtype = torch.float32
-        mp_cfg.megatron_fsdp_main_grads_dtype = torch.bfloat16
-        mp_cfg.megatron_fsdp_grad_comm_dtype = torch.bfloat16
-
-        _ddp_wrap(
-            self.model,
-            False,
-            self.ddp_config,
-            False,
-            use_megatron_fsdp=True,
-            pg_collection=self.pg,
-            mixed_precision_config=mp_cfg,
-        )
-
-        mp_cfg.finalize.assert_called_once()
-        for call in mock_fsdp.call_args_list:
-            assert call.kwargs["main_params_dtype"] == torch.float32
-            assert call.kwargs["main_grads_dtype"] == torch.bfloat16
-            assert call.kwargs["grad_comm_dtype"] == torch.bfloat16
-
-    @patch("megatron.bridge.models.common.unimodal.TorchFullyShardedDataParallel")
-    @patch("megatron.bridge.models.common.unimodal.FullyShardedDataParallel")
-    @patch("megatron.bridge.models.common.unimodal.DistributedDataParallel")
-    @patch("megatron.bridge.models.common.unimodal.get_model_config")
-    @patch("torch.cuda.stream", new_callable=MagicMock)
-    @patch("torch.cuda.current_stream")
-    @patch("torch.cuda.Stream")
-    def test_mixed_precision_config_ignored_for_standard_ddp(
-        self, mock_stream, mock_curr, mock_ctx, mock_cfg, mock_ddp, mock_fsdp, mock_torch_fsdp
-    ):
-        mock_ctx.return_value.__enter__ = Mock(return_value=None)
-        mock_ctx.return_value.__exit__ = Mock(return_value=False)
-
-        mp_cfg = Mock()
-        _ddp_wrap(
-            self.model,
-            False,
-            self.ddp_config,
-            False,
-            use_megatron_fsdp=False,
-            pg_collection=self.pg,
-            mixed_precision_config=mp_cfg,
-        )
-
-        mp_cfg.finalize.assert_not_called()
-        for call in mock_ddp.call_args_list:
-            assert "main_params_dtype" not in call.kwargs
-            assert "main_grads_dtype" not in call.kwargs
-            assert "grad_comm_dtype" not in call.kwargs
-
 
 # =============================================================================
 # Section 6 — TestUnimodalBuildDistributedModels
@@ -900,37 +836,3 @@ def test_returns_final_model_list(self):
             assert result is ddp_result
         finally:
             self._stop_patches()
-
-    def test_mixed_precision_config_forwarded_to_ddp_wrap(self):
-        mocks = self._standard_patches()
-        try:
-            ddp_config = Mock()
-            mp_cfg = Mock()
-            unimodal_build_distributed_models(
-                Mock(),
-                self.transformer_config,
-                self.pg,
-                ddp_config=ddp_config,
-                wrap_with_ddp=True,
-                mixed_precision_config=mp_cfg,
-            )
-            ddp_call_kwargs = mocks["ddp"].call_args.kwargs
-            assert ddp_call_kwargs["mixed_precision_config"] is mp_cfg
-        finally:
-            self._stop_patches()
-
-    def test_mixed_precision_config_none_by_default(self):
-        mocks = self._standard_patches()
-        try:
-            ddp_config = Mock()
-            unimodal_build_distributed_models(
-                Mock(),
-                self.transformer_config,
-                self.pg,
-                ddp_config=ddp_config,
-                wrap_with_ddp=True,
-            )
-            ddp_call_kwargs = mocks["ddp"].call_args.kwargs
-            assert ddp_call_kwargs["mixed_precision_config"] is None
-        finally:
-            self._stop_patches()
diff --git a/tests/unit_tests/training/test_mixed_precision.py b/tests/unit_tests/training/test_mixed_precision.py
index a4bed29323..c495523308 100644
--- a/tests/unit_tests/training/test_mixed_precision.py
+++ b/tests/unit_tests/training/test_mixed_precision.py
@@ -730,222 +730,6 @@ def test_recipe_with_setup(self):
         assert ddp_config.bf16 is True
 
 
-class TestGradReduceInFp32AffectsFsdpDtypes:
-    """Tests that grad_reduce_in_fp32 propagates to Megatron-FSDP dtype fields."""
-
-    def test_grad_reduce_in_fp32_sets_fsdp_grads_and_comm_to_fp32_on_init(self):
-        config = MixedPrecisionConfig(grad_reduce_in_fp32=True)
-        assert config.megatron_fsdp_main_grads_dtype == torch.float32
-        assert config.megatron_fsdp_grad_comm_dtype == torch.float32
-
-    def test_grad_reduce_in_fp32_sets_fsdp_grads_and_comm_to_fp32_via_setattr(self):
-        config = MixedPrecisionConfig(grad_reduce_in_fp32=False)
-        config.megatron_fsdp_main_grads_dtype = torch.bfloat16
-        config.megatron_fsdp_grad_comm_dtype = torch.bfloat16
-
-        config.grad_reduce_in_fp32 = True
-
-        assert config.megatron_fsdp_main_grads_dtype == torch.float32
-        assert config.megatron_fsdp_grad_comm_dtype == torch.float32
-
-    def test_grad_reduce_in_fp32_false_preserves_custom_fsdp_dtypes(self):
-        config = MixedPrecisionConfig(
-            grad_reduce_in_fp32=False,
-            megatron_fsdp_main_grads_dtype=torch.bfloat16,
-            megatron_fsdp_grad_comm_dtype=torch.bfloat16,
-        )
-        assert config.megatron_fsdp_main_grads_dtype == torch.bfloat16
-        assert config.megatron_fsdp_grad_comm_dtype == torch.bfloat16
-
-    def test_finalize_enforces_grad_reduce_in_fp32(self):
-        config = MixedPrecisionConfig(grad_reduce_in_fp32=False)
-        config.megatron_fsdp_main_grads_dtype = torch.bfloat16
-        config.megatron_fsdp_grad_comm_dtype = torch.bfloat16
-
-        config.grad_reduce_in_fp32 = True
-        config.finalize()
-
-        assert config.megatron_fsdp_main_grads_dtype == torch.float32
-        assert config.megatron_fsdp_grad_comm_dtype == torch.float32
-
-    def test_grad_reduce_in_fp32_false_sets_fsdp_grads_and_comm_to_none_on_init(self):
-        """When grad_reduce_in_fp32=False at init, fsdp grad dtypes default to None (auto)."""
-        config = MixedPrecisionConfig(grad_reduce_in_fp32=False)
-        assert config.megatron_fsdp_main_grads_dtype is None
-        assert config.megatron_fsdp_grad_comm_dtype is None
-
-    def test_grad_reduce_in_fp32_false_sets_fsdp_grads_and_comm_to_none_via_setattr(self):
-        """Switching grad_reduce_in_fp32 from True to False resets fsdp grad dtypes to None."""
-        config = MixedPrecisionConfig(grad_reduce_in_fp32=True)
-        assert config.megatron_fsdp_main_grads_dtype == torch.float32
-        assert config.megatron_fsdp_grad_comm_dtype == torch.float32
-
-        config.grad_reduce_in_fp32 = False
-
-        assert config.megatron_fsdp_main_grads_dtype is None
-        assert config.megatron_fsdp_grad_comm_dtype is None
-
-    def test_finalize_enforces_grad_reduce_in_fp32_false(self):
-        """finalize() overrides custom fsdp grad dtypes to None when grad_reduce_in_fp32=False."""
-        config = MixedPrecisionConfig(
-            grad_reduce_in_fp32=False,
-            megatron_fsdp_main_grads_dtype=torch.bfloat16,
-            megatron_fsdp_grad_comm_dtype=torch.bfloat16,
-        )
-        assert config.megatron_fsdp_main_grads_dtype == torch.bfloat16
-        assert config.megatron_fsdp_grad_comm_dtype == torch.bfloat16
-
-        config.finalize()
-
-        assert config.megatron_fsdp_main_grads_dtype is None
-        assert config.megatron_fsdp_grad_comm_dtype is None
-
-    def test_grad_reduce_in_fp32_toggle_round_trip(self):
-        """Toggling grad_reduce_in_fp32 True->False->True updates fsdp dtypes correctly."""
-        config = MixedPrecisionConfig(grad_reduce_in_fp32=True)
-        assert config.megatron_fsdp_main_grads_dtype == torch.float32
-        assert config.megatron_fsdp_grad_comm_dtype == torch.float32
-
-        config.grad_reduce_in_fp32 = False
-        assert config.megatron_fsdp_main_grads_dtype is None
-        assert config.megatron_fsdp_grad_comm_dtype is None
-
-        config.grad_reduce_in_fp32 = True
-        assert config.megatron_fsdp_main_grads_dtype == torch.float32
-        assert config.megatron_fsdp_grad_comm_dtype == torch.float32
-
-    def test_grad_reduce_in_fp32_does_not_touch_main_params_dtype(self):
-        config = MixedPrecisionConfig(
-            grad_reduce_in_fp32=True,
-            megatron_fsdp_main_params_dtype=torch.bfloat16,
-        )
-        assert config.megatron_fsdp_main_params_dtype == torch.bfloat16
-
-    def test_grad_reduce_in_fp32_false_does_not_touch_main_params_dtype(self):
-        config = MixedPrecisionConfig(
-            grad_reduce_in_fp32=False,
-            megatron_fsdp_main_params_dtype=torch.bfloat16,
-        )
-        assert config.megatron_fsdp_main_params_dtype == torch.bfloat16
-
-
-class TestFsdpDtypeStringConversion:
-    """Tests that string values for megatron_fsdp_* fields are converted to torch.dtype."""
-
-    @pytest.mark.parametrize(
-        "str_val, expected_dtype",
-        [
-            ("fp32", torch.float32),
-            ("bf16", torch.bfloat16),
-            ("fp16", torch.float16),
-            ("auto", None),
-        ],
-    )
-    def test_main_params_dtype_string_conversion(self, str_val, expected_dtype):
-        config = MixedPrecisionConfig(megatron_fsdp_main_params_dtype=str_val)
-        assert config.megatron_fsdp_main_params_dtype is expected_dtype
-
-    @pytest.mark.parametrize(
-        "str_val, expected_dtype",
-        [
-            ("fp32", torch.float32),
-            ("bf16", torch.bfloat16),
-            ("fp16", torch.float16),
-            ("auto", None),
-        ],
-    )
-    def test_main_grads_dtype_string_conversion(self, str_val, expected_dtype):
-        config = MixedPrecisionConfig(grad_reduce_in_fp32=False, megatron_fsdp_main_grads_dtype=str_val)
-        assert config.megatron_fsdp_main_grads_dtype is expected_dtype
-
-    @pytest.mark.parametrize(
-        "str_val, expected_dtype",
-        [
-            ("fp32", torch.float32),
-            ("bf16", torch.bfloat16),
-            ("fp16", torch.float16),
-            ("auto", None),
-        ],
-    )
-    def test_grad_comm_dtype_string_conversion(self, str_val, expected_dtype):
-        config = MixedPrecisionConfig(grad_reduce_in_fp32=False, megatron_fsdp_grad_comm_dtype=str_val)
-        assert config.megatron_fsdp_grad_comm_dtype is expected_dtype
-
-    def test_setattr_string_conversion(self):
-        config = MixedPrecisionConfig(grad_reduce_in_fp32=False)
-        config.megatron_fsdp_main_params_dtype = "bf16"
-        config.megatron_fsdp_main_grads_dtype = "fp16"
-        config.megatron_fsdp_grad_comm_dtype = "auto"
-
-        assert config.megatron_fsdp_main_params_dtype == torch.bfloat16
-        assert config.megatron_fsdp_main_grads_dtype == torch.float16
-        assert config.megatron_fsdp_grad_comm_dtype is None
-
-    def test_finalize_rejects_unknown_string(self):
-        config = MixedPrecisionConfig(grad_reduce_in_fp32=False)
-        object.__setattr__(config, "megatron_fsdp_main_params_dtype", "int8")
-
-        with pytest.raises(ValueError, match="Could not map int8 to torch.dtype"):
-            config.finalize()
-
-    def test_torch_dtype_values_are_not_modified(self):
-        config = MixedPrecisionConfig(
-            megatron_fsdp_main_params_dtype=torch.bfloat16,
-            megatron_fsdp_main_grads_dtype=torch.float16,
-            megatron_fsdp_grad_comm_dtype=torch.float32,
-        )
-        assert config.megatron_fsdp_main_params_dtype == torch.bfloat16
-        assert config.megatron_fsdp_main_grads_dtype == torch.float16
-        assert config.megatron_fsdp_grad_comm_dtype == torch.float32
-
-
-class TestProvideDistributedModelMixedPrecisionConfig:
-    """Test that mixed_precision_config is forwarded through provide_distributed_model."""
-
-    @patch("megatron.bridge.models.model_provider.ProcessGroupCollection.use_mpu_process_groups")
-    @patch("megatron.bridge.models.model_provider.get_model")
-    @patch("megatron.bridge.models.model_provider.torch.distributed")
-    @patch("megatron.bridge.models.model_provider.parallel_state.is_initialized", return_value=True)
-    def test_mixed_precision_config_forwarded_to_get_model(self, mock_ps_init, mock_dist, mock_get_model, mock_use_pg):
-        from megatron.core.transformer.module import MegatronModule
-        from megatron.core.transformer.transformer_config import TransformerConfig
-
-        from megatron.bridge.models.model_provider import ModelProviderMixin
-
-        class _Module(MegatronModule):
-            def __init__(self):
-                super().__init__(TransformerConfig(num_layers=1, hidden_size=1, num_attention_heads=1))
-
-        class _Provider(ModelProviderMixin):
-            def provide(self, pre_process=None, post_process=None, vp_stage=None):
-                return _Module()
-
-        mock_dist.is_initialized.return_value = True
-        mock_get_model.return_value = [_Module()]
-
-        class _PG:
-            def __init__(self):
-                self.pp = object()
-                self.tp = object()
-                self.cp = object()
-                self.dp = object()
-                self.dp_cp = object()
-                self.expt_dp = object()
-
-        mock_use_pg.return_value = _PG()
-
-        mp_config = MixedPrecisionConfig(bf16=True, params_dtype=torch.bfloat16)
-        provider = _Provider()
-        provider.provide_distributed_model(
-            ddp_config=DistributedDataParallelConfig(),
-            mixed_precision_config=mp_config,
-            wrap_with_ddp=False,
-        )
-
-        mock_get_model.assert_called_once()
-        assert mock_get_model.call_args.kwargs["mixed_precision_config"] is mp_config
-
-
 class TestRegisterAndGetMixedPrecisionConfig:
     """Tests for the `register` decorator and `get_mixed_precision_config` helper."""
 