Commit 8a414cf

fix simplefsdp gradient_divide_factor
1 parent 5d8e2d5 commit 8a414cf

2 files changed: 92 additions, 7 deletions

torchtitan/experiments/simple_fsdp/deepseek_v3/parallelize.py

Lines changed: 1 addition & 4 deletions
@@ -132,11 +132,8 @@ def parallelize_deepseekv3(
                     ac_mode=job_config.activation_checkpoint.mode,
                     mp_policy=mp_policy,
                     shard_dim=experts_shard_dim,
+                    gradient_divide_factor=parallel_dims.fsdp_gradient_divide_factor,
                 )
-                # TODO(ruisizhang123): support set_gradient_divide_factor in simplefsdp
-                # transformer_block.moe.experts.set_gradient_divide_factor(
-                #     parallel_dims.fsdp_gradient_divide_factor,
-                # )
 
     model = data_parallel(
         model,

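With this change the divide factor is passed straight into data_parallel when the MoE experts are wrapped, instead of being set on the experts module afterwards (the removed set_gradient_divide_factor TODO). Below is a minimal call sketch; the names experts and ep_fsdp_mesh, the mode string, and the factor 8.0 are placeholder assumptions for illustration only, not torchtitan's actual wiring, which lives in parallelize_deepseekv3 above.

import torch
from torchtitan.experiments.simple_fsdp.simple_fsdp import (
    MixedPrecisionPolicy,
    data_parallel,
)

# Placeholders: in torchtitan these come from the job config and parallel_dims.
experts = ...          # the MoE experts module to shard
ep_fsdp_mesh = ...     # an (ep, fsdp) device sub-mesh

experts = data_parallel(
    experts,
    ep_fsdp_mesh,
    mode="fully_shard",
    ac_mode="none",
    mp_policy=MixedPrecisionPolicy(
        param_dtype=torch.bfloat16, reduce_dtype=torch.float32
    ),
    shard_dim=0,
    # e.g. the global FSDP world size (parallel_dims.fsdp_gradient_divide_factor)
    gradient_divide_factor=8.0,
)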
torchtitan/experiments/simple_fsdp/simple_fsdp.py

Lines changed: 91 additions & 3 deletions
@@ -7,7 +7,7 @@
 from collections.abc import Sequence
 from contextlib import contextmanager
 from dataclasses import dataclass
-from typing import List, Optional
+from typing import cast, List, Optional, Union
 
 import torch
 import torch.nn as nn
@@ -20,6 +20,8 @@
     Shard,
 )
 from torch.distributed.device_mesh import _mesh_resources, DeviceMesh
+from torch.distributed.distributed_c10d import ReduceOp
+from torch.distributed.fsdp._fully_shard._fsdp_collectives import _div_if_needed
 from torch.distributed.tensor._dtensor_spec import DTensorSpec
 from torch.distributed.tensor._redistribute import redistribute_local_tensor
 from torch.distributed.tensor.placement_types import _StridedShard, Placement
@@ -49,6 +51,82 @@ class MixedPrecisionPolicy:
     reduce_dtype: Optional[torch.dtype] = None
 
 
+@dataclass(frozen=True)
+class SimpleFSDPPartial(Partial):
+    gradient_divide_factor: Optional[float] = None
+    reduce_dtype: Optional[torch.dtype] = None
+
+    def _get_gradient_divide_factors(
+        self,
+    ) -> tuple[
+        Optional[float],
+        Optional[float],
+        Optional[str],
+        Optional[str],
+    ]:
+        """
+        The logic follows
+        https://github.com/pytorch/pytorch/blob/main/torch/distributed/fsdp/_fully_shard/_fsdp_collectives.py#L688
+        """
+        if self.gradient_divide_factor is None:
+            return None, None, None, None
+        overflow_risk = self.reduce_dtype not in (torch.float32, torch.bfloat16)
+        pre_factor: Optional[float] = None
+        post_factor: Optional[float] = None
+        # Since fp16 has a smaller dynamic range than fp32/bf16, we want to avoid
+        # overflow/underflow. For N data parallel workers, each worker computes
+        # g_i, and they collectively reduce (g_1 + ... + g_N) / N. To avoid
+        # overflow/underflow, we divide by ~sqrt(N) before/after the reduction.
+        pre_factor = 1
+        while (
+            self.gradient_divide_factor % pre_factor == 0
+            and self.gradient_divide_factor / pre_factor > pre_factor
+        ):
+            pre_factor *= 2
+        post_factor = self.gradient_divide_factor / pre_factor
+        reduce_scatter_op, all_reduce_op = "sum", "sum"
+        return pre_factor, post_factor, reduce_scatter_op, all_reduce_op
+
+    def _reduce_value(
+        self, tensor: torch.Tensor, mesh: DeviceMesh, mesh_dim: int
+    ) -> torch.Tensor:
+        (
+            pre_factor,
+            post_factor,
+            reduce_scatter_op,
+            all_reduce_op,
+        ) = self._get_gradient_divide_factors()
+        if pre_factor is not None:
+            _div_if_needed(tensor, pre_factor)
+        reduced = super()._reduce_value(tensor, mesh, mesh_dim)
+        if post_factor is not None:
+            _div_if_needed(reduced, post_factor)
+        return reduced
+
+    def _reduce_shard_value(
+        self,
+        tensor: torch.Tensor,
+        mesh: DeviceMesh,
+        mesh_dim: int,
+        shard_spec: Placement,
+    ) -> torch.Tensor:
+        (
+            pre_factor,
+            post_factor,
+            reduce_scatter_op,
+            all_reduce_op,
+        ) = self._get_gradient_divide_factors()
+
+        if pre_factor is not None:
+            _div_if_needed(tensor, pre_factor)
+        shard_spec = cast(Shard, shard_spec)
+        reduced = shard_spec._reduce_shard_tensor(tensor, mesh, reduce_scatter_op, mesh_dim)
+
+        if post_factor is not None:
+            _div_if_needed(reduced, post_factor)
+        return reduced
+
+
 def _distribute_dtensor(
     tensor: DTensor,
     device_mesh: DeviceMesh,
@@ -192,18 +270,26 @@ def __init__(
         mode,
         regional_ac,
         mp_policy,
+        gradient_divide_factor,
     ):
         super().__init__()
         self.device_mesh = device_mesh
         self.param_sharding = param_sharding
         self.mode = mode
         self.compute_placements = [Replicate()] * self.device_mesh.ndim
-        self.grad_placements = [Partial(reduce_op="avg")] * self.device_mesh.ndim
+        self.grad_placements = [
+            SimpleFSDPPartial(
+                reduce_op="avg",
+                gradient_divide_factor=gradient_divide_factor,
+                reduce_dtype=mp_policy.reduce_dtype,
+            )
+            if gradient_divide_factor is not None
+            else Partial(reduce_op="avg")
+        ] * self.device_mesh.ndim
         self.regional_ac = regional_ac
         mp_policy = mp_policy or MixedPrecisionPolicy()
         self.param_dtype = mp_policy.param_dtype
         self.reduce_dtype = mp_policy.reduce_dtype
-        self.ep_mesh_name, self.tp_mesh_name = "ep", "tp"
 
     def replicate_compute(self, x):
        # data parallel runtime replicate parameters and do local compute
@@ -286,6 +372,7 @@ def data_parallel(
     ac_mode: str = "none",
     mp_policy: Optional[MixedPrecisionPolicy] = None,
     shard_dim: int = 0,
+    gradient_divide_factor: Optional[float] = None,
 ):
     if mode == "replicate":
         param_sharding = (Replicate(),)
@@ -348,6 +435,7 @@ def data_parallel(
                         mode,
                         regional_ac,
                         mp_policy=mp_policy,
+                        gradient_divide_factor=gradient_divide_factor,
                     ),
                 )
     return model

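The comment block in _get_gradient_divide_factors carries the key reasoning: with a low-precision reduce dtype such as fp16, dividing by the full factor N only at one end of the reduction risks overflow or underflow, so the factor is split into a power-of-two pre_factor near sqrt(N) and a matching post_factor, applied around the reduction via _div_if_needed. Below is a standalone sketch of that split, mirroring the loop above; the helper name split_divide_factor is made up for this illustration.

def split_divide_factor(factor: float) -> tuple[float, float]:
    # Grow pre_factor by powers of two while it still divides `factor` evenly
    # and the remaining quotient exceeds it, landing near sqrt(factor).
    pre_factor = 1.0
    while factor % pre_factor == 0 and factor / pre_factor > pre_factor:
        pre_factor *= 2
    return pre_factor, factor / pre_factor


print(split_divide_factor(8))   # (4.0, 2.0): divide by 4 before, by 2 after
print(split_divide_factor(64))  # (8.0, 8.0): divide by 8 before and after

Dividing the local gradients by pre_factor before a "sum" reduction and the result by post_factor afterwards is numerically equivalent to dividing by the full factor once, but keeps the intermediate values within fp16 range.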