12 changes: 8 additions & 4 deletions torchtitan/experiments/simple_fsdp/deepseek_v3/parallelize.py
@@ -125,18 +125,22 @@ def parallelize_deepseekv3(
):
experts_shard_dim = 1

# when EP is enabled, the routed experts' gradient reduction is done over
# dp_mod_ep_mesh instead of the whole dp_mesh.
# we add a `fsdp_gradient_divide_factor` to scale gradients as if they were
# reduced over the whole dp_mesh, keeping them consistent with the dense params.
# TODO (ruisizhang123): update the logic following the link below instead
# of using a reduction_divide_factor
# https://github.com/pytorch/torchtitan/pull/1803#discussion_r2415190883
transformer_block.moe.experts = data_parallel(
transformer_block.moe.experts,
dp_mod_ep_mesh,
dp_mode,
ac_mode=job_config.activation_checkpoint.mode,
mp_policy=mp_policy,
shard_dim=experts_shard_dim,
reduction_divide_factor=parallel_dims.fsdp_gradient_divide_factor,
Contributor:
In the future we probably should deprecate this logic anyway. See related PR #1803 (comment).

Contributor Author:
I will add a TODO for it.
)
# TODO(ruisizhang123): support set_gradient_divide_factor in simplefsdp
# transformer_block.moe.experts.set_gradient_divide_factor(
# parallel_dims.fsdp_gradient_divide_factor,
# )

model = data_parallel(
model,
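To make the comment in this hunk concrete, here is a small arithmetic sketch; the degrees below are hypothetical and only chosen for illustration (the PR itself does not fix any sizes):

# Hypothetical parallel degrees, for illustration only.
dp_degree = 8                               # dense params reduce their grads over all 8 DP ranks
ep_degree = 4                               # expert parallel degree
dp_mod_ep_degree = dp_degree // ep_degree   # routed experts reduce only over these 2 ranks

# A plain Partial(reduce_op="avg") over dp_mod_ep_mesh would divide expert grads by 2,
# while dense grads are divided by 8 over dp_mesh. Passing the full DP size as
# reduction_divide_factor (what fsdp_gradient_divide_factor is used for here) makes
# the experts divide by 8 as well, so both sets of grads are scaled consistently.
assert dp_mod_ep_degree * ep_degree == dp_degree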
43 changes: 41 additions & 2 deletions torchtitan/experiments/simple_fsdp/simple_fsdp.py
@@ -49,6 +49,37 @@ class MixedPrecisionPolicy:
reduce_dtype: Optional[torch.dtype] = None


class _ScaledPartial(Partial):
# A subclass of the Partial placement that lets the user perform the reduction with a
# custom divisor (reduction_divide_factor) instead of the default world size.
def __init__(
self,
reduction_divide_factor: float,
):
self.reduction_divide_factor = reduction_divide_factor
super().__init__(reduce_op="sum")

def _reduce_value(
self, tensor: torch.Tensor, mesh: DeviceMesh, mesh_dim: int
) -> torch.Tensor:
# for all_reduce in DDP
tensor.div_(self.reduction_divide_factor)
reduced = super()._reduce_value(tensor, mesh, mesh_dim)
return reduced

def _reduce_shard_value(
self,
tensor: torch.Tensor,
mesh: DeviceMesh,
mesh_dim: int,
shard_spec: Placement,
) -> torch.Tensor:
# for reduce_scatter in FSDP
tensor.div_(self.reduction_divide_factor)
reduced = super()._reduce_shard_value(tensor, mesh, mesh_dim, shard_spec)
return reduced
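
A minimal single-process sketch of what _ScaledPartial changes, simulated with plain tensors (the mesh size of 2 and the factor of 8 below are hypothetical):

import torch

# Pretend two ranks each hold this local gradient shard on a mesh of size 2,
# while the full data-parallel size is 8.
per_rank_grads = [torch.full((4,), 1.0), torch.full((4,), 1.0)]

# Default Partial(reduce_op="avg"): sum across the mesh, divide by the mesh size (2).
avg_reduced = sum(per_rank_grads) / 2                   # -> tensor of 1.0

# _ScaledPartial(reduction_divide_factor=8): each rank pre-divides by the factor,
# then the parent Partial sums, so the effective divisor is 8 instead of 2.
scaled_reduced = sum(g / 8.0 for g in per_rank_grads)   # -> tensor of 0.25

print(avg_reduced, scaled_reduced)

When the factor equals the mesh size, this matches Partial(reduce_op="avg"), which is why leaving reduction_divide_factor as None (see __init__ below) preserves the previous behavior.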


def _distribute_dtensor(
tensor: DTensor,
device_mesh: DeviceMesh,
@@ -192,18 +223,24 @@ def __init__(
mode,
regional_ac,
mp_policy,
reduction_divide_factor,
):
super().__init__()
self.device_mesh = device_mesh
self.param_sharding = param_sharding
self.mode = mode
self.compute_placements = [Replicate()] * self.device_mesh.ndim
self.grad_placements = [Partial(reduce_op="avg")] * self.device_mesh.ndim
self.grad_placements = [
_ScaledPartial(
reduction_divide_factor=reduction_divide_factor,
)
if reduction_divide_factor is not None
else Partial(reduce_op="avg")
] * self.device_mesh.ndim
self.regional_ac = regional_ac
mp_policy = mp_policy or MixedPrecisionPolicy()
self.param_dtype = mp_policy.param_dtype
self.reduce_dtype = mp_policy.reduce_dtype
self.ep_mesh_name, self.tp_mesh_name = "ep", "tp"

def replicate_compute(self, x):
# data parallel runtime replicates parameters and does local compute
@@ -286,6 +323,7 @@ def data_parallel(
ac_mode: str = "none",
mp_policy: Optional[MixedPrecisionPolicy] = None,
shard_dim: int = 0,
reduction_divide_factor: Optional[float] = None,
):
if mode == "replicate":
param_sharding = (Replicate(),)
@@ -348,6 +386,7 @@ def data_parallel(
mode,
regional_ac,
mp_policy=mp_policy,
reduction_divide_factor=reduction_divide_factor,
),
)
return model
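
For completeness, a hedged usage sketch of the extended data_parallel signature; the module and variable names are placeholders, and the "fully_shard" mode string is assumed for the FSDP path:

sharded_experts = data_parallel(
    experts_module,                 # placeholder nn.Module holding the routed experts
    dp_mod_ep_mesh,                 # the smaller data-parallel mesh used when EP is enabled
    "fully_shard",                  # assumed dp_mode for the FSDP path
    ac_mode="none",
    mp_policy=MixedPrecisionPolicy(),
    shard_dim=1,
    reduction_divide_factor=parallel_dims.fsdp_gradient_divide_factor,
)
# Leaving reduction_divide_factor=None (the default) keeps the old behavior:
# gradients fall back to Partial(reduce_op="avg") over the mesh.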