Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 29 additions & 0 deletions megatron/core/distributed/distributed_data_parallel_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
from dataclasses import dataclass
from typing import Optional

import torch


@dataclass
class DistributedDataParallelConfig:
Expand Down Expand Up @@ -162,6 +164,33 @@ class DistributedDataParallelConfig:
delay_wgrad_compute: bool = False
"""Delay the weight gradient computation to improve batch-level communication overlapping"""

megatron_fsdp_main_params_dtype: Optional[torch.dtype] = torch.float32
"""Data type for the main weight buffer utilized for distributed optimization
and quantization with Megatron-FSDP. If set to None, the model compute weight
buffer will take the role of the main weights, or when no sharding is applied,
the native model weights become the main weights. Defaults to torch.float32.
"""

megatron_fsdp_main_grads_dtype: Optional[torch.dtype] = None
"""Data type for the main gradient buffer utilized for distributed optimization with
Megatron-FSDP. If set to None, main gradients will match the dtype of the model
compute parameters specified by the user model. Defaults to None.
"""

megatron_fsdp_grad_comm_dtype: Optional[torch.dtype] = None
"""Data type for gradient gather / scatter communications. Can be utilized to reduce
communication latency, but adds overhead for type-casting and copy operations.
If using NCCL UBR v2.27+, gradient reduction may be performed in high-precision
depending on the network domain (NVLink or IB), and can enable mixed-precision
communication and accumulation, e.g. setting grad_comm_dtype to `BF16` can support
`FP32` reduction even though we have `BF16` input and output communication buffers.
If set to None, the `main_grads_dtype` is used. If using HSDP (either DP-Replicate
or DP-Outer in `outer_dp_sharding_strategy`), `no_shard`, `optim`, or a
`FixedPoolAllocator` (`fsdp_double_buffer`), allocating `dtype`-custom gradient
communication buffers (per FSDP group) adds memory overhead. Defaults to None.
No additional memory is allocated when `grad_comm_dtype == main_grads_dtype`.
"""

def __post_init__(self):
import os

Expand Down
17 changes: 11 additions & 6 deletions megatron/core/distributed/fsdp/mcore_fsdp_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,9 +70,6 @@ def __init__(
ddp_config: DistributedDataParallelConfig,
module: torch.nn.Module,
fsdp_unit_modules: Optional[List[torch.nn.Module]] = None,
main_params_dtype: Optional[torch.dtype] = torch.float32,
main_grads_dtype: Optional[torch.dtype] = None,
grad_comm_dtype: Optional[torch.dtype] = None,
disable_bucketing: bool = False,
device: Optional[torch.device] = None,
pg_collection: Optional[ProcessGroupCollection] = None,
Expand All @@ -90,10 +87,18 @@ def __init__(
f'Setting up DistributedDataParallel with config {self.ddp_config}',
)
self.mp_policy = MixedPrecisionPolicy(
main_params_dtype=main_params_dtype,
main_params_dtype=ddp_config.megatron_fsdp_main_params_dtype,
# Grandfathered Argument: grad_reduce_in_fp32
main_grads_dtype=torch.float32 if ddp_config.grad_reduce_in_fp32 else main_grads_dtype,
grad_comm_dtype=torch.float32 if ddp_config.grad_reduce_in_fp32 else grad_comm_dtype,
main_grads_dtype=(
torch.float32
if ddp_config.grad_reduce_in_fp32
else ddp_config.megatron_fsdp_main_grads_dtype
),
grad_comm_dtype=(
torch.float32
if ddp_config.grad_reduce_in_fp32
else ddp_config.megatron_fsdp_grad_comm_dtype
),
)
log_single_rank(
logger,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
from dataclasses import dataclass
from typing import Optional

import torch


@dataclass
class DistributedDataParallelConfig:
Expand Down Expand Up @@ -116,6 +118,33 @@ class DistributedDataParallelConfig:
to minimize the registration time.
"""

megatron_fsdp_main_params_dtype: Optional[torch.dtype] = torch.float32
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why not just add a MixedPrecisionPolicy field instead of duplicating its arguments?

Copy link
Member Author

@cspades cspades Mar 17, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good question! MixedPrecisionPolicy also works without Megatron, i.e. the fully_shard API that native Torch users and NeMo Automodel can use without installing anything. PyPI: https://pypi.org/project/megatron-fsdp/

Nesting that dataclass in the DDPConfig while also using it as a standalone config can be quite confusing. The only reason I'm even including these arguments in DDPConfig is because of Megatron-Bridge conventions, and I'd prefer not to make it more complicated. One complexity is obvious: users would need to import Megatron-FSDP sub-modules (megatron_fsdp.MixedPrecisionPolicy) just to use the DDPConfig, which is unnecessary and possibly circular.

"""Data type for the main weight buffer utilized for distributed optimization
and quantization with Megatron-FSDP. If set to None, the model compute weight
buffer will take the role of the main weights, or when no sharding is applied,
the native model weights become the main weights. Defaults to torch.float32.
"""

megatron_fsdp_main_grads_dtype: Optional[torch.dtype] = None
"""Data type for the main gradient buffer utilized for distributed optimization with
Megatron-FSDP. If set to None, main gradients will match the dtype of the model
compute parameters specified by the user model. Defaults to None.
"""

megatron_fsdp_grad_comm_dtype: Optional[torch.dtype] = None
"""Data type for gradient gather / scatter communications. Can be utilized to reduce
communication latency, but adds overhead for type-casting and copy operations.
If using NCCL UBR v2.27+, gradient reduction may be performed in high-precision
depending on the network domain (NVLink or IB), and can enable mixed-precision
communication and accumulation, e.g. setting grad_comm_dtype to `BF16` can support
`FP32` reduction even though we have `BF16` input and output communication buffers.
If set to None, the `main_grads_dtype` is used. If using HSDP (either DP-Replicate
or DP-Outer in `outer_dp_sharding_strategy`), `no_shard`, `optim`, or a
`FixedPoolAllocator` (`fsdp_double_buffer`), allocating `dtype`-custom gradient
communication buffers (per FSDP group) adds memory overhead. Defaults to None.
No additional memory is allocated when `grad_comm_dtype == main_grads_dtype`.
"""

def __post_init__(self):
import os

Expand Down
9 changes: 0 additions & 9 deletions megatron/training/training.py
Original file line number Diff line number Diff line change
Expand Up @@ -1378,14 +1378,6 @@ def build_model():
ddp_stream.wait_stream(torch.cuda.current_stream())
# Make ddp_stream start after whatever the default stream already queued
with torch.cuda.stream(ddp_stream):
# To pass kwargs unique to specific DDP classes.
ddp_init_kwargs = {}
if args.use_megatron_fsdp:
# Also pass the mixed-precision arguments for Megatron-FSDP only.
ddp_init_kwargs["main_params_dtype"] = args.megatron_fsdp_main_params_dtype
ddp_init_kwargs["main_grads_dtype"] = args.megatron_fsdp_main_grads_dtype
ddp_init_kwargs["grad_comm_dtype"] = args.megatron_fsdp_grad_comm_dtype

model = [
DP(
config=config,
Expand All @@ -1394,7 +1386,6 @@ def build_model():
# Turn off bucketing for model_chunk 2 onwards, since communication
# for these model chunks is overlapped with compute anyway.
disable_bucketing=(model_chunk_idx > 0) or args.overlap_param_gather_with_optimizer_step,
**ddp_init_kwargs,
)
for (model_chunk_idx, model_chunk) in enumerate(model)
]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,9 @@ def _build_fsdp_model(
bucket_size=10000,
use_megatron_fsdp=True,
grad_reduce_in_fp32=grad_reduce_in_fp32,
megatron_fsdp_main_params_dtype=main_params_dtype,
megatron_fsdp_main_grads_dtype=main_grads_dtype,
megatron_fsdp_grad_comm_dtype=grad_comm_dtype,
)
model = TestModel(input_dim=13, output_dim=17).cuda()
transformer_config = TransformerConfig(
Expand All @@ -112,9 +115,6 @@ def _build_fsdp_model(
ddp_config=fsdp_config,
module=model,
fsdp_unit_modules=[torch.nn.Linear],
main_params_dtype=main_params_dtype,
main_grads_dtype=main_grads_dtype,
grad_comm_dtype=grad_comm_dtype,
)
return fsdp_model

Expand Down
Loading