Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 1 addition & 19 deletions src/megatron/bridge/models/common/unimodal.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,14 +12,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

import logging


logger = logging.getLogger(__name__)

from typing import TYPE_CHECKING, Any, Callable
from typing import Any, Callable

import torch
from megatron.core import tensor_parallel
Expand All @@ -36,9 +34,6 @@
from megatron.core.utils import get_model_config


if TYPE_CHECKING:
from megatron.bridge.training.mixed_precision import MixedPrecisionConfig

try:
from megatron.core.fp8_utils import correct_amax_history_if_needed
except ImportError:
Expand All @@ -58,7 +53,6 @@ def unimodal_build_distributed_models(
mixed_precision_wrapper: Callable[[Any, MegatronModule], MegatronModule] | None = Float16Module,
pre_wrap_hook: Callable[[list[MegatronModule]], list[MegatronModule]] | None = None,
model_type: ModelType = ModelType.encoder_or_decoder,
mixed_precision_config: MixedPrecisionConfig | None = None,
) -> list[MegatronModule]:
"""Build model stages and wrap for distributed training.

Expand Down Expand Up @@ -87,7 +81,6 @@ def unimodal_build_distributed_models(
Pass ``None`` to skip.
pre_wrap_hook: Hook applied to the model stage list before any wrapping.
model_type: Deprecated flag, only used for backwards compatibility.
mixed_precision_config: Mixed-precision config for DDP wrapper.

Returns:
List of model stages, wrapped and ready for distributed training.
Expand Down Expand Up @@ -151,7 +144,6 @@ def unimodal_build_distributed_models(
use_megatron_fsdp=use_megatron_fsdp,
use_torch_fsdp2=use_torch_fsdp2,
pg_collection=pg_collection,
mixed_precision_config=mixed_precision_config,
)

return model_list
Expand Down Expand Up @@ -204,7 +196,6 @@ def _ddp_wrap(
overlap_param_gather_with_optimizer_step: bool,
use_megatron_fsdp: bool = False,
use_torch_fsdp2: bool = False,
mixed_precision_config: MixedPrecisionConfig | None = None,
*,
pg_collection: ProcessGroupCollection,
) -> list[MegatronModule]:
Expand All @@ -218,23 +209,15 @@ def _ddp_wrap(
for overlapping parameter gathering with optimizer step
use_megatron_fsdp: Whether to use Megatron FSDP.
use_torch_fsdp2: Whether to use PyTorch FSDP v2 instead of DDP
mixed_precision_config: Mixed-precision config for DDP wrapper.
pg_collection: Model communication process groups.

Returns:
list[MegatronModule]: List of DDP/FSDP wrapped model modules
"""
ddp_init_kwargs = {}
if use_megatron_fsdp:
DP = FullyShardedDataParallel
if use_torch_fsdp2:
raise ValueError("Using use_megatron_fsdp and use_torch_fsdp2 at the same time is not supported.")
if mixed_precision_config is not None:
# Also pass the mixed-precision arguments for Megatron-FSDP only.
mixed_precision_config.finalize()
ddp_init_kwargs["main_params_dtype"] = mixed_precision_config.megatron_fsdp_main_params_dtype
ddp_init_kwargs["main_grads_dtype"] = mixed_precision_config.megatron_fsdp_main_grads_dtype
ddp_init_kwargs["grad_comm_dtype"] = mixed_precision_config.megatron_fsdp_grad_comm_dtype
elif use_torch_fsdp2:
DP = TorchFullyShardedDataParallel
else:
Expand All @@ -255,7 +238,6 @@ def _ddp_wrap(
# model chunks is overlapped with compute anyway.
disable_bucketing=(model_chunk_idx > 0) or overlap_param_gather_with_optimizer_step,
pg_collection=pg_collection,
**ddp_init_kwargs,
)
for (model_chunk_idx, model_chunk) in enumerate(model)
]
Expand Down
16 changes: 1 addition & 15 deletions src/megatron/bridge/models/model_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

import abc
import os
from pathlib import Path
from typing import TYPE_CHECKING, Any, Callable, Generic, TypedDict, TypeVar, Union
from typing import Any, Callable, Generic, TypedDict, TypeVar, Union

from megatron.bridge.models.common.unimodal import _ddp_wrap, _print_num_params

Expand Down Expand Up @@ -57,10 +55,6 @@
from megatron.bridge.utils.instantiate_utils import InstantiationMode


if TYPE_CHECKING:
from megatron.bridge.training.mixed_precision import MixedPrecisionConfig


try:
from megatron.core.fp8_utils import correct_amax_history_if_needed
except ImportError:
Expand Down Expand Up @@ -129,7 +123,6 @@ def provide_distributed_model(
| None = None,
post_wrap_hook: Callable[[list[MegatronModule]], list[MegatronModule]] | None = None,
mixed_precision_wrapper: Callable[[Any, MegatronModule], MegatronModule] | None = Float16Module,
mixed_precision_config: MixedPrecisionConfig | None = None,
pg_collection: ProcessGroupCollection | None = None,
) -> list[ModelT]:
"""Instantiate and wrap the model for distributed training.
Expand Down Expand Up @@ -158,8 +151,6 @@ def provide_distributed_model(
this will override all hooks registered via `register_post_wrap_hook`.
mixed_precision_wrapper: A module wrapper (e.g., `Float16Module`) applied when fp16/bf16
is enabled. If None, no mixed precision wrapper is applied.
mixed_precision_config: Optional MixedPrecisionConfig. Used to configure mixed-precision
features on models during initialization.
pg_collection: Optional pre-initialized ProcessGroupCollection. If provided, skips
model parallel initialization and uses the provided collection directly.
This is used when `use_decentralized_pg=True` in the distributed config.
Expand Down Expand Up @@ -216,7 +207,6 @@ def composed_pre_wrap_hook(model: list[MegatronModule]) -> list[MegatronModule]:
init_model_with_meta_device=init_model_with_meta_device,
pre_wrap_hook=final_pre_wrap_hook,
mixed_precision_wrapper=mixed_precision_wrapper,
mixed_precision_config=mixed_precision_config,
pg_collection=pg_collection,
)

Expand Down Expand Up @@ -501,7 +491,6 @@ def get_model(
]
| None = None,
mixed_precision_wrapper: Callable[[Any, MegatronModule], MegatronModule] | None = Float16Module,
mixed_precision_config: MixedPrecisionConfig | None = None,
*,
pg_collection: ProcessGroupCollection,
) -> list[MegatronModule]:
Expand Down Expand Up @@ -536,8 +525,6 @@ def get_model(
hooks will be executed in order.
mixed_precision_wrapper: Wrapper class/function applied when fp16/bf16 is enabled. Defaults
to Megatron-Core's `Float16Module`. If None, the wrapper is not applied.
mixed_precision_config: MixedPrecisionConfig that controls mixed-precision features
for the created model.

Returns:
list[MegatronModule]: List of model modules. Contains multiple modules
Expand Down Expand Up @@ -620,7 +607,6 @@ def get_model(
overlap_param_gather_with_optimizer_step,
use_megatron_fsdp=use_megatron_fsdp,
use_torch_fsdp2=use_torch_fsdp2,
mixed_precision_config=mixed_precision_config,
pg_collection=pg_collection,
)

Expand Down
57 changes: 0 additions & 57 deletions src/megatron/bridge/training/mixed_precision.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,17 +66,6 @@ class MixedPrecisionConfig:
num_layers_at_start_in_bf16: int = 0
num_layers_at_end_in_bf16: int = 0
reuse_grad_buf_for_mxfp8_param_ag: bool = False
# Megatron-FSDP MixedPrecisionPolicy
megatron_fsdp_main_params_dtype: str | Optional[torch.dtype] = torch.float32
megatron_fsdp_main_grads_dtype: str | Optional[torch.dtype] = None
megatron_fsdp_grad_comm_dtype: str | Optional[torch.dtype] = None

def __post_init__(self):
    """Default the unset Megatron-FSDP gradient dtypes when reducing in fp32.

    When ``grad_reduce_in_fp32`` is truthy, any of
    ``megatron_fsdp_main_grads_dtype`` / ``megatron_fsdp_grad_comm_dtype``
    that is still ``None`` is set to ``torch.float32``. Values that were
    given explicitly are left alone, and nothing happens when
    ``grad_reduce_in_fp32`` is falsy.
    """
    if not self.grad_reduce_in_fp32:
        return
    for dtype_field in ("megatron_fsdp_main_grads_dtype", "megatron_fsdp_grad_comm_dtype"):
        if getattr(self, dtype_field) is None:
            # Write through object.__setattr__ so the class's own
            # __setattr__ override is not triggered.
            object.__setattr__(self, dtype_field, torch.float32)

def __setattr__(self, name: str, value) -> None:
# Use object.__setattr__ to avoid recursion
Expand All @@ -89,35 +78,6 @@ def __setattr__(self, name: str, value) -> None:
elif name == "fp8_param" and hasattr(self, "fp8_param_gather"):
if self.fp8_param_gather != value:
object.__setattr__(self, "fp8_param_gather", value)
if (
name == "grad_reduce_in_fp32"
and hasattr(self, "megatron_fsdp_main_grads_dtype")
and hasattr(self, "megatron_fsdp_grad_comm_dtype")
):
if value:
# Legacy argument for Megatron-FSDP - Gradients used to be reduced in
# the same data-type as the main gradient data-type. Recommend using
# the new Megatron-FSDP mixed-precision arguments to control this!
object.__setattr__(self, "megatron_fsdp_main_grads_dtype", torch.float32)
object.__setattr__(self, "megatron_fsdp_grad_comm_dtype", torch.float32)
else:
# Default back to "auto".
object.__setattr__(self, "megatron_fsdp_main_grads_dtype", None)
object.__setattr__(self, "megatron_fsdp_grad_comm_dtype", None)
if name in (
"megatron_fsdp_main_params_dtype",
"megatron_fsdp_main_grads_dtype",
"megatron_fsdp_grad_comm_dtype",
) and isinstance(value, str):
# Map string options to torch.dtype or None.
if value == "fp32":
object.__setattr__(self, name, torch.float32)
elif value == "bf16":
object.__setattr__(self, name, torch.bfloat16)
elif value == "fp16":
object.__setattr__(self, name, torch.float16)
elif value == "auto":
object.__setattr__(self, name, None)

def finalize(self):
# If fp8_param is None, initialize it from fp8_param_gather
Expand All @@ -137,23 +97,6 @@ def finalize(self):
if self.fp4 and not is_te_min_version("2.7.0.dev0"):
raise ValueError("fp4 requires Transformer Engine >= 2.7.0.dev0 for NVFP4BlockScaling support.")

if self.grad_reduce_in_fp32:
self.megatron_fsdp_main_grads_dtype = torch.float32
self.megatron_fsdp_grad_comm_dtype = torch.float32
else:
self.megatron_fsdp_main_grads_dtype = None
self.megatron_fsdp_grad_comm_dtype = None
for mfsdp_mp_arg in (
self.megatron_fsdp_main_params_dtype,
self.megatron_fsdp_main_grads_dtype,
self.megatron_fsdp_grad_comm_dtype,
):
if isinstance(mfsdp_mp_arg, str):
raise ValueError(
f"[MixedPrecisionConfig] Could not map {mfsdp_mp_arg} to torch.dtype or 'auto'. "
"Options: 'fp32', 'fp16', 'bf16', 'auto'"
)

def setup(
self,
model_config: GPTModelProvider | T5ModelProvider,
Expand Down
1 change: 0 additions & 1 deletion src/megatron/bridge/training/setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -225,7 +225,6 @@ def modelopt_pre_wrap_hook(model):
overlap_param_gather_with_optimizer_step=cfg.optimizer.overlap_param_gather_with_optimizer_step,
data_parallel_random_init=cfg.rng.data_parallel_random_init,
pg_collection=pg_collection,
mixed_precision_config=cfg.mixed_precision,
)

cfg.model.timers = timers
Expand Down
98 changes: 0 additions & 98 deletions tests/unit_tests/models/common/test_unimodal.py
Original file line number Diff line number Diff line change
Expand Up @@ -497,70 +497,6 @@ def test_returns_list_of_wrapped_modules(
assert isinstance(result, list)
assert len(result) == 2

@patch("megatron.bridge.models.common.unimodal.TorchFullyShardedDataParallel")
@patch("megatron.bridge.models.common.unimodal.FullyShardedDataParallel")
@patch("megatron.bridge.models.common.unimodal.DistributedDataParallel")
@patch("megatron.bridge.models.common.unimodal.get_model_config")
@patch("torch.cuda.stream", new_callable=MagicMock)
@patch("torch.cuda.current_stream")
@patch("torch.cuda.Stream")
def test_mixed_precision_config_forwarded_to_megatron_fsdp(
    self, mock_stream, mock_curr, mock_ctx, mock_cfg, mock_ddp, mock_fsdp, mock_torch_fsdp
):
    """With ``use_megatron_fsdp=True`` the mixed-precision dtypes reach the FSDP wrapper.

    Checks that ``_ddp_wrap`` calls ``finalize()`` on the config exactly once
    and passes ``main_params_dtype``, ``main_grads_dtype`` and
    ``grad_comm_dtype`` from the config to every ``FullyShardedDataParallel``
    construction.

    Mock arguments arrive in bottom-up decorator order: ``torch.cuda.Stream``
    first, ``TorchFullyShardedDataParallel`` last.
    """
    # Make the object returned by the patched torch.cuda.stream(...) usable in
    # a ``with`` statement (presumably entered inside _ddp_wrap).
    mock_ctx.return_value.__enter__ = Mock(return_value=None)
    mock_ctx.return_value.__exit__ = Mock(return_value=False)

    mp_cfg = Mock()
    mp_cfg.megatron_fsdp_main_params_dtype = torch.float32
    mp_cfg.megatron_fsdp_main_grads_dtype = torch.bfloat16
    mp_cfg.megatron_fsdp_grad_comm_dtype = torch.bfloat16

    _ddp_wrap(
        self.model,
        False,
        self.ddp_config,
        False,
        use_megatron_fsdp=True,
        pg_collection=self.pg,
        mixed_precision_config=mp_cfg,
    )

    mp_cfg.finalize.assert_called_once()
    # Without this guard the loop below would pass vacuously if the FSDP
    # wrapper were never constructed.
    mock_fsdp.assert_called()
    for fsdp_call in mock_fsdp.call_args_list:
        assert fsdp_call.kwargs["main_params_dtype"] == torch.float32
        assert fsdp_call.kwargs["main_grads_dtype"] == torch.bfloat16
        assert fsdp_call.kwargs["grad_comm_dtype"] == torch.bfloat16

@patch("megatron.bridge.models.common.unimodal.TorchFullyShardedDataParallel")
@patch("megatron.bridge.models.common.unimodal.FullyShardedDataParallel")
@patch("megatron.bridge.models.common.unimodal.DistributedDataParallel")
@patch("megatron.bridge.models.common.unimodal.get_model_config")
@patch("torch.cuda.stream", new_callable=MagicMock)
@patch("torch.cuda.current_stream")
@patch("torch.cuda.Stream")
def test_mixed_precision_config_ignored_for_standard_ddp(
    self, mock_stream, mock_curr, mock_ctx, mock_cfg, mock_ddp, mock_fsdp, mock_torch_fsdp
):
    """With ``use_megatron_fsdp=False`` the mixed-precision config is not consumed.

    Checks that ``finalize()`` is never called on the config and that none of
    the Megatron-FSDP dtype keyword arguments are passed to the
    ``DistributedDataParallel`` wrapper.

    Mock arguments arrive in bottom-up decorator order: ``torch.cuda.Stream``
    first, ``TorchFullyShardedDataParallel`` last.
    """
    # Make the object returned by the patched torch.cuda.stream(...) usable in
    # a ``with`` statement (presumably entered inside _ddp_wrap).
    mock_ctx.return_value.__enter__ = Mock(return_value=None)
    mock_ctx.return_value.__exit__ = Mock(return_value=False)

    mp_cfg = Mock()
    _ddp_wrap(
        self.model,
        False,
        self.ddp_config,
        False,
        use_megatron_fsdp=False,
        pg_collection=self.pg,
        mixed_precision_config=mp_cfg,
    )

    mp_cfg.finalize.assert_not_called()
    # Without this guard the loop below would pass vacuously if the DDP
    # wrapper were never constructed.
    mock_ddp.assert_called()
    for ddp_call in mock_ddp.call_args_list:
        assert "main_params_dtype" not in ddp_call.kwargs
        assert "main_grads_dtype" not in ddp_call.kwargs
        assert "grad_comm_dtype" not in ddp_call.kwargs


# =============================================================================
# Section 6 — TestUnimodalBuildDistributedModels
Expand Down Expand Up @@ -900,37 +836,3 @@ def test_returns_final_model_list(self):
assert result is ddp_result
finally:
self._stop_patches()

def test_mixed_precision_config_forwarded_to_ddp_wrap(self):
    """An explicit ``mixed_precision_config`` is handed unchanged to the DDP wrap step.

    Builds the models with ``wrap_with_ddp=True`` and verifies, by identity,
    that the keyword argument recorded on the patched ``"ddp"`` mock is the
    same object that was passed in.
    """
    patched = self._standard_patches()
    try:
        wrapper_cfg = Mock()
        precision_cfg = Mock()
        unimodal_build_distributed_models(
            Mock(),
            self.transformer_config,
            self.pg,
            ddp_config=wrapper_cfg,
            wrap_with_ddp=True,
            mixed_precision_config=precision_cfg,
        )
        forwarded = patched["ddp"].call_args.kwargs["mixed_precision_config"]
        assert forwarded is precision_cfg
    finally:
        # Always undo the patches, even when an assertion above fails.
        self._stop_patches()

def test_mixed_precision_config_none_by_default(self):
    """Omitting ``mixed_precision_config`` forwards ``None`` to the DDP wrap step.

    Builds the models with ``wrap_with_ddp=True`` and no mixed-precision
    config, then checks the keyword argument recorded on the patched
    ``"ddp"`` mock.
    """
    patched = self._standard_patches()
    try:
        wrapper_cfg = Mock()
        unimodal_build_distributed_models(
            Mock(),
            self.transformer_config,
            self.pg,
            ddp_config=wrapper_cfg,
            wrap_with_ddp=True,
        )
        forwarded = patched["ddp"].call_args.kwargs["mixed_precision_config"]
        assert forwarded is None
    finally:
        # Always undo the patches, even when an assertion above fails.
        self._stop_patches()
Loading
Loading