From d3bc705555cdf911d72ab781e95606307cc81f01 Mon Sep 17 00:00:00 2001 From: Cory Ye Date: Tue, 17 Mar 2026 12:34:11 -0700 Subject: [PATCH] Move Megatron-FSDP MixedPrecisionPolicy arguments from FSDP adapter to DDPConfig. Signed-off-by: Cory Ye --- .../distributed_data_parallel_config.py | 29 +++++++++++++++++++ .../distributed/fsdp/mcore_fsdp_adapter.py | 17 +++++++---- .../distributed_data_parallel_config.py | 29 +++++++++++++++++++ megatron/training/training.py | 9 ------ .../test_mcore_fully_sharded_data_parallel.py | 6 ++-- 5 files changed, 72 insertions(+), 18 deletions(-) diff --git a/megatron/core/distributed/distributed_data_parallel_config.py b/megatron/core/distributed/distributed_data_parallel_config.py index 80118bd6ce1..807b122b0e6 100644 --- a/megatron/core/distributed/distributed_data_parallel_config.py +++ b/megatron/core/distributed/distributed_data_parallel_config.py @@ -3,6 +3,8 @@ from dataclasses import dataclass from typing import Optional +import torch + @dataclass class DistributedDataParallelConfig: @@ -162,6 +164,33 @@ class DistributedDataParallelConfig: delay_wgrad_compute: bool = False """Delay the weight gradient computation to improve batch-level communication overlapping""" + megatron_fsdp_main_params_dtype: Optional[torch.dtype] = torch.float32 + """Data type for the main weight buffer utilized for distributed optimization + and quantization with Megatron-FSDP. If set to None, the model compute weight + buffer will take the role of the main weights, or when no sharding is applied, + the native model weights become the main weights. Defaults to torch.float32. + """ + + megatron_fsdp_main_grads_dtype: Optional[torch.dtype] = None + """Data type for the main gradient buffer utilized for distributed optimization with + Megatron-FSDP. If set to None, main gradients will match the dtype of the model + compute parameters specified by the user model. Defaults to None. 
+ """ + megatron_fsdp_grad_comm_dtype: Optional[torch.dtype] = None + """Data type for gradient gather / scatter communications. Can be utilized to reduce + communication latency, but adds overhead for type-casting and copy operations. + If using NCCL UBR v2.27+, gradient reduction may be performed in high-precision + depending on the network domain (NVLink or IB), and can enable mixed-precision + communication and accumulation, e.g. setting grad_comm_dtype to `BF16` can support + `FP32` reduction even though we have `BF16` input and output communication buffers. + If set to None, `megatron_fsdp_main_grads_dtype` is used. If using HSDP (either DP-Replicate + or DP-Outer in `outer_dp_sharding_strategy`), `no_shard`, `optim`, or a + `FixedPoolAllocator` (`fsdp_double_buffer`), allocating `dtype`-custom gradient + communication buffers (per FSDP group) adds memory overhead. Defaults to None. + No additional memory is allocated when `grad_comm_dtype == main_grads_dtype`. + """ + def __post_init__(self): import os diff --git a/megatron/core/distributed/fsdp/mcore_fsdp_adapter.py b/megatron/core/distributed/fsdp/mcore_fsdp_adapter.py index 78fe63130f3..8993620c779 100644 --- a/megatron/core/distributed/fsdp/mcore_fsdp_adapter.py +++ b/megatron/core/distributed/fsdp/mcore_fsdp_adapter.py @@ -70,9 +70,6 @@ def __init__( ddp_config: DistributedDataParallelConfig, module: torch.nn.Module, fsdp_unit_modules: Optional[List[torch.nn.Module]] = None, - main_params_dtype: Optional[torch.dtype] = torch.float32, - main_grads_dtype: Optional[torch.dtype] = None, - grad_comm_dtype: Optional[torch.dtype] = None, disable_bucketing: bool = False, device: Optional[torch.device] = None, pg_collection: Optional[ProcessGroupCollection] = None, @@ -90,10 +87,18 @@ def __init__( f'Setting up DistributedDataParallel with config {self.ddp_config}', ) self.mp_policy = MixedPrecisionPolicy( - main_params_dtype=main_params_dtype, + main_params_dtype=ddp_config.megatron_fsdp_main_params_dtype, # 
Grandfathered Argument: grad_reduce_in_fp32 - main_grads_dtype=torch.float32 if ddp_config.grad_reduce_in_fp32 else main_grads_dtype, - grad_comm_dtype=torch.float32 if ddp_config.grad_reduce_in_fp32 else grad_comm_dtype, + main_grads_dtype=( + torch.float32 + if ddp_config.grad_reduce_in_fp32 + else ddp_config.megatron_fsdp_main_grads_dtype + ), + grad_comm_dtype=( + torch.float32 + if ddp_config.grad_reduce_in_fp32 + else ddp_config.megatron_fsdp_grad_comm_dtype + ), ) log_single_rank( logger, diff --git a/megatron/core/distributed/fsdp/src/megatron_fsdp/distributed_data_parallel_config.py b/megatron/core/distributed/fsdp/src/megatron_fsdp/distributed_data_parallel_config.py index aae16b2f57d..4ad5a8dddac 100644 --- a/megatron/core/distributed/fsdp/src/megatron_fsdp/distributed_data_parallel_config.py +++ b/megatron/core/distributed/fsdp/src/megatron_fsdp/distributed_data_parallel_config.py @@ -3,6 +3,8 @@ from dataclasses import dataclass from typing import Optional +import torch + @dataclass class DistributedDataParallelConfig: @@ -116,6 +118,33 @@ class DistributedDataParallelConfig: to minimize the registration time. """ + megatron_fsdp_main_params_dtype: Optional[torch.dtype] = torch.float32 + """Data type for the main weight buffer utilized for distributed optimization + and quantization with Megatron-FSDP. If set to None, the model compute weight + buffer will take the role of the main weights, or when no sharding is applied, + the native model weights become the main weights. Defaults to torch.float32. + """ + + megatron_fsdp_main_grads_dtype: Optional[torch.dtype] = None + """Data type for the main gradient buffer utilized for distributed optimization with + Megatron-FSDP. If set to None, main gradients will match the dtype of the model + compute parameters specified by the user model. Defaults to None. + """ + + megatron_fsdp_grad_comm_dtype: Optional[torch.dtype] = None + """Data type for gradient gather / scatter communications. 
Can be utilized to reduce + communication latency, but adds overhead for type-casting and copy operations. + If using NCCL UBR v2.27+, gradient reduction may be performed in high-precision + depending on the network domain (NVLink or IB), and can enable mixed-precision + communication and accumulation, e.g. setting grad_comm_dtype to `BF16` can support + `FP32` reduction even though we have `BF16` input and output communication buffers. + If set to None, `megatron_fsdp_main_grads_dtype` is used. If using HSDP (either DP-Replicate + or DP-Outer in `outer_dp_sharding_strategy`), `no_shard`, `optim`, or a + `FixedPoolAllocator` (`fsdp_double_buffer`), allocating `dtype`-custom gradient + communication buffers (per FSDP group) adds memory overhead. Defaults to None. + No additional memory is allocated when `grad_comm_dtype == main_grads_dtype`. + """ + def __post_init__(self): import os diff --git a/megatron/training/training.py b/megatron/training/training.py index e319617d6d5..0ad2c8373f9 100644 --- a/megatron/training/training.py +++ b/megatron/training/training.py @@ -1378,14 +1378,6 @@ def build_model(): ddp_stream.wait_stream(torch.cuda.current_stream()) # Make ddp_stream start after whatever the default stream already queued with torch.cuda.stream(ddp_stream): - # To pass kwargs unique to specific DDP classes. - ddp_init_kwargs = {} - if args.use_megatron_fsdp: - # Also pass the mixed-precision arguments for Megatron-FSDP only. - ddp_init_kwargs["main_params_dtype"] = args.megatron_fsdp_main_params_dtype - ddp_init_kwargs["main_grads_dtype"] = args.megatron_fsdp_main_grads_dtype - ddp_init_kwargs["grad_comm_dtype"] = args.megatron_fsdp_grad_comm_dtype - model = [ DP( config=config, @@ -1394,7 +1386,6 @@ def build_model(): # Turn off bucketing for model_chunk 2 onwards, since communication # for these model chunks is overlapped with compute anyway. 
disable_bucketing=(model_chunk_idx > 0) or args.overlap_param_gather_with_optimizer_step, - **ddp_init_kwargs, ) for (model_chunk_idx, model_chunk) in enumerate(model) ] diff --git a/tests/unit_tests/distributed/megatron_fsdp/test_mcore_fully_sharded_data_parallel.py b/tests/unit_tests/distributed/megatron_fsdp/test_mcore_fully_sharded_data_parallel.py index 313119ad864..0271da1fed9 100644 --- a/tests/unit_tests/distributed/megatron_fsdp/test_mcore_fully_sharded_data_parallel.py +++ b/tests/unit_tests/distributed/megatron_fsdp/test_mcore_fully_sharded_data_parallel.py @@ -102,6 +102,9 @@ def _build_fsdp_model( bucket_size=10000, use_megatron_fsdp=True, grad_reduce_in_fp32=grad_reduce_in_fp32, + megatron_fsdp_main_params_dtype=main_params_dtype, + megatron_fsdp_main_grads_dtype=main_grads_dtype, + megatron_fsdp_grad_comm_dtype=grad_comm_dtype, ) model = TestModel(input_dim=13, output_dim=17).cuda() transformer_config = TransformerConfig( @@ -112,9 +115,6 @@ def _build_fsdp_model( ddp_config=fsdp_config, module=model, fsdp_unit_modules=[torch.nn.Linear], - main_params_dtype=main_params_dtype, - main_grads_dtype=main_grads_dtype, - grad_comm_dtype=grad_comm_dtype, ) return fsdp_model