@@ -7,7 +7,7 @@
 from collections.abc import Sequence
 from contextlib import contextmanager
 from dataclasses import dataclass
-from typing import List, Optional
+from typing import cast, List, Optional, Union

 import torch
 import torch.nn as nn
@@ -20,6 +20,8 @@
     Shard,
 )
 from torch.distributed.device_mesh import _mesh_resources, DeviceMesh
+from torch.distributed.distributed_c10d import ReduceOp
+from torch.distributed.fsdp._fully_shard._fsdp_collectives import _div_if_needed
 from torch.distributed.tensor._dtensor_spec import DTensorSpec
 from torch.distributed.tensor._redistribute import redistribute_local_tensor
 from torch.distributed.tensor.placement_types import _StridedShard, Placement
@@ -49,6 +51,82 @@ class MixedPrecisionPolicy:
     reduce_dtype: Optional[torch.dtype] = None


+@dataclass(frozen=True)
+class SimpleFSDPPartial(Partial):
+    gradient_divide_factor: Optional[float] = None
+    reduce_dtype: Optional[torch.dtype] = None
+
+    def _get_gradient_divide_factors(
+        self,
+    ) -> tuple[
+        Optional[float],
+        Optional[float],
+        Optional[str],
+        Optional[str],
+    ]:
+        """
+        The logic follows
+        https://github.com/pytorch/pytorch/blob/main/torch/distributed/fsdp/_fully_shard/_fsdp_collectives.py#L688
+        """
+        if self.gradient_divide_factor is None:
+            return None, None, None, None
+        overflow_risk = self.reduce_dtype not in (torch.float32, torch.bfloat16)
+        pre_factor: Optional[float] = None
+        post_factor: Optional[float] = None
+        if overflow_risk:
+            # Since fp16 has a smaller dynamic range than fp32/bf16, we want to
+            # avoid overflow/underflow. For N data parallel workers, each worker
+            # computes g_i, and they collectively reduce (g_1 + ... + g_N) / N.
+            # To avoid overflow/underflow, we divide by ~sqrt(N) before and
+            # after the reduction.
+            pre_factor = 1
+            while (
+                self.gradient_divide_factor % pre_factor == 0
+                and self.gradient_divide_factor / pre_factor > pre_factor
+            ):
+                pre_factor *= 2
+            post_factor = self.gradient_divide_factor / pre_factor
+        else:
+            # fp32/bf16 have enough dynamic range; apply the full factor once
+            # before the reduction.
+            pre_factor, post_factor = self.gradient_divide_factor, None
+        reduce_scatter_op, all_reduce_op = "sum", "sum"
+        return pre_factor, post_factor, reduce_scatter_op, all_reduce_op
+
+    def _reduce_value(
+        self, tensor: torch.Tensor, mesh: DeviceMesh, mesh_dim: int
+    ) -> torch.Tensor:
+        (
+            pre_factor,
+            post_factor,
+            reduce_scatter_op,
+            all_reduce_op,
+        ) = self._get_gradient_divide_factors()
+        if pre_factor is not None:
+            _div_if_needed(tensor, pre_factor)
+        # Partial._reduce_value performs the all-reduce with self.reduce_op.
+        reduced = super()._reduce_value(tensor, mesh, mesh_dim)
+        if post_factor is not None:
+            _div_if_needed(reduced, post_factor)
+        return reduced
+
+    def _reduce_shard_value(
+        self,
+        tensor: torch.Tensor,
+        mesh: DeviceMesh,
+        mesh_dim: int,
+        shard_spec: Placement,
+    ) -> torch.Tensor:
+        (
+            pre_factor,
+            post_factor,
+            reduce_scatter_op,
+            all_reduce_op,
+        ) = self._get_gradient_divide_factors()
+
+        if pre_factor is not None:
+            _div_if_needed(tensor, pre_factor)
+        shard_spec = cast(Shard, shard_spec)
+        reduced = shard_spec._reduce_shard_tensor(
+            tensor, mesh, reduce_scatter_op, mesh_dim
+        )
+
+        if post_factor is not None:
+            _div_if_needed(reduced, post_factor)
+        return reduced
+
+
 def _distribute_dtensor(
     tensor: DTensor,
     device_mesh: DeviceMesh,
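To make the pre-/post-division above concrete, here is a standalone sketch (not part of the patch; `split_factor` is a hypothetical helper mirroring the loop in `_get_gradient_divide_factors`): for a divide factor N it picks a power-of-two pre-divisor near sqrt(N) and applies the remainder after the reduction, so the two divisions still amount to dividing by exactly N.

```python
# Standalone illustration of the factor split used by
# SimpleFSDPPartial._get_gradient_divide_factors (overflow-risk branch).
def split_factor(factor: float) -> tuple[float, float]:
    pre = 1
    while factor % pre == 0 and factor / pre > pre:
        pre *= 2
    return pre, factor / pre

for n in (8, 24, 64):
    pre, post = split_factor(n)
    # pre is a power of two near sqrt(n); pre * post always equals n, so
    # dividing by pre before and by post after the reduction divides by n total.
    print(n, pre, post, pre * post == n)
# 8 -> (4, 2.0), 24 -> (8, 3.0), 64 -> (8, 8.0)
```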
@@ -192,18 +270,26 @@ def __init__(
         mode,
         regional_ac,
         mp_policy,
+        gradient_divide_factor,
     ):
         super().__init__()
         self.device_mesh = device_mesh
         self.param_sharding = param_sharding
         self.mode = mode
         self.compute_placements = [Replicate()] * self.device_mesh.ndim
-        self.grad_placements = [Partial(reduce_op="avg")] * self.device_mesh.ndim
+        self.grad_placements = [
+            SimpleFSDPPartial(
+                reduce_op="avg",
+                gradient_divide_factor=gradient_divide_factor,
+                # mp_policy may still be None here; it is normalized below.
+                reduce_dtype=mp_policy.reduce_dtype if mp_policy else None,
+            )
+            if gradient_divide_factor is not None
+            else Partial(reduce_op="avg")
+        ] * self.device_mesh.ndim
         self.regional_ac = regional_ac
         mp_policy = mp_policy or MixedPrecisionPolicy()
         self.param_dtype = mp_policy.param_dtype
         self.reduce_dtype = mp_policy.reduce_dtype
-        self.ep_mesh_name, self.tp_mesh_name = "ep", "tp"

     def replicate_compute(self, x):
         # data parallel runtime replicate parameters and do local compute
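As a rough sanity check of the intended semantics (my reading of the patch, not something it states): when the reduce op is `"sum"` and the factor equals the data-parallel world size, dividing by the pre-factor before the reduction and by the post-factor after it reproduces the same averaged gradient that `Partial(reduce_op="avg")` would produce. A minimal simulation, with the collective replaced by a plain sum over simulated workers:

```python
import torch

# Simulate N data-parallel workers' local gradients for one parameter.
# Hypothetical sanity check; the real path runs _div_if_needed around a
# reduce-scatter, which is replaced here by a plain sum over the workers.
N = 4
grads = [torch.randn(8) for _ in range(N)]

gradient_divide_factor = float(N)  # e.g. the data-parallel world size
pre, post = 2.0, 2.0               # split of 4, as _get_gradient_divide_factors would produce

reduced = sum(g / pre for g in grads)  # "sum" reduce op applied after pre-division
reduced = reduced / post               # post-division after the collective

expected = torch.stack(grads).mean(dim=0)  # what Partial(reduce_op="avg") computes
print(torch.allclose(reduced, expected))   # True, since pre * post == N
```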
@@ -286,6 +372,7 @@ def data_parallel(
     ac_mode: str = "none",
     mp_policy: Optional[MixedPrecisionPolicy] = None,
     shard_dim: int = 0,
+    gradient_divide_factor: Optional[float] = None,
 ):
     if mode == "replicate":
         param_sharding = (Replicate(),)
@@ -348,6 +435,7 @@ def data_parallel(
                 mode,
                 regional_ac,
                 mp_policy=mp_policy,
+                gradient_divide_factor=gradient_divide_factor,
             ),
         )
     return model
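For reference, a hedged sketch of a call site using the new knob. The leading parameters of `data_parallel` (the module and its device mesh) are not visible in these hunks, so their names and order here are assumptions; only `ac_mode`, `mp_policy`, `shard_dim`, and `gradient_divide_factor` appear in the diff.

```python
import torch

# Hypothetical call site; `dp_mesh` and the positional arguments before
# ac_mode are assumed, not taken from the patch.
dp_mesh = ...  # the data-parallel DeviceMesh, assumed to exist
model = data_parallel(
    model,
    dp_mesh,
    mode="replicate",                # the mode handled in the shown hunk
    ac_mode="none",
    mp_policy=MixedPrecisionPolicy(
        param_dtype=torch.bfloat16,
        reduce_dtype=torch.float16,  # fp16 reduce dtype triggers the sqrt-split path
    ),
    gradient_divide_factor=float(dp_mesh.size()),  # e.g. the DP world size
)
```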