Distributed Data Parallel (DDP) is the most widely used parallelism strategy for distributed training. It partitions the input data batch across multiple devices, replicates the model on each device, and synchronizes gradients (e.g., with AllReduce) in the background.
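To make the idea concrete, here is a toy, single-process sketch of data parallelism in plain PyTorch (not veScale code): the model is replicated, the batch is partitioned, and the per-replica gradients are averaged, which is what AllReduce achieves across real devices.

```python
# Toy illustration of data parallelism (single process, no veScale):
# replicate the model, partition the batch, then average gradients across replicas.
import copy
import torch
import torch.nn as nn

model = nn.Linear(8, 1)
replicas = [copy.deepcopy(model) for _ in range(2)]  # one replica per "device"

batch = torch.randn(4, 8)
micro_batches = batch.chunk(2)  # partition the input batch across replicas

for replica, mb in zip(replicas, micro_batches):
    replica(mb).sum().backward()

# The effect of AllReduce: every replica ends up with the same averaged gradient.
for grads in zip(*(r.parameters() for r in replicas)):
    mean_grad = torch.stack([p.grad for p in grads]).mean(dim=0)
    for p in grads:
        p.grad = mean_grad.clone()
```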
veScale DDP is primarily inherited from Megatron-LM's DDP for its performance and compatibility with the ZeRO optimizer. We extend and enhance the original DDP with extra features surrounding veScale `DTensor` and `DModule`:
- conversion between `DTensor` and `Tensor` gradients
- support nested gradient synchronization with `DModule` (for Sequence Parallel)
- support gradient synchronization for dynamic control flow
`DDP` is a module (`DModule`) wrapper that creates a flattened Gradient Buffer to store the gradients produced by the model backward. (This is achieved by adding a hook to the `grad_fn` of each model parameter, which fills the `DTensor` gradient output by the PyTorch Autograd engine into the pre-allocated Gradient Buffer.) The Gradient Buffer both accelerates gradient synchronization and reduces memory fragmentation, as synchronization only needs to be performed once for the entire buffer, rather than once per parameter.
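Below is a minimal, single-process sketch of this mechanism in plain PyTorch, not veScale's actual code; the names `grad_buffer`, `main_grad`, and `make_param_hook` are illustrative only.

```python
# Sketch: a flat gradient buffer filled by hooks on each parameter's grad-accumulation node.
import torch
import torch.nn as nn

model = nn.Linear(4, 4)
params = [p for p in model.parameters() if p.requires_grad]

# One flat buffer holding all gradients back-to-back.
grad_buffer = torch.zeros(sum(p.numel() for p in params))

# Give each parameter a view ("main_grad") into the flat buffer.
offset = 0
for p in params:
    p.main_grad = grad_buffer[offset:offset + p.numel()].view_as(p)
    offset += p.numel()

def make_param_hook(p):
    # Fired by autograd after the gradient for `p` has been accumulated;
    # it copies the fresh gradient into the pre-allocated buffer view.
    def hook(*unused):
        if p.grad is not None:
            p.main_grad.add_(p.grad)
            p.grad = None  # the flat buffer is now the single source of truth
    return hook

# Hook the gradient-accumulation node behind each parameter's grad_fn
# (a common trick in Megatron-style DDP implementations).
grad_accs = []
for p in params:
    acc = p.expand_as(p).grad_fn.next_functions[0][0]
    acc.register_hook(make_param_hook(p))
    grad_accs.append(acc)  # keep references so the hooks stay alive

model(torch.randn(2, 4)).sum().backward()
print(grad_buffer)  # all gradients now live contiguously in one buffer
```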
For maximum performance, the Gradient Buffer is further divided into multiple _Bucket_s so that backward compute and gradient synchronization can be overlapped across Buckets. As soon as all gradients in a Bucket are generated, its gradient synchronization is triggered immediately, rather than waiting until the whole Gradient Buffer is ready.
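The sketch below illustrates the bucketing idea in plain PyTorch, not the veScale API: a Bucket owns a slice of the flat buffer, counts how many of its gradients have arrived, and launches an asynchronous collective as soon as the count is complete. The `Bucket` class, `mark_grad_ready`, and the single-rank `gloo` group are illustrative assumptions made so the example runs standalone.

```python
# Sketch of bucketed, overlapped gradient synchronization (not the veScale API).
import os
import torch
import torch.distributed as dist

# Single-process process group only so the sketch runs standalone.
os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29501")
dist.init_process_group("gloo", rank=0, world_size=1)

class Bucket:
    def __init__(self, data: torch.Tensor, num_params: int):
        self.data = data            # contiguous slice of the flat gradient buffer
        self.expected = num_params  # number of parameters whose grads land here
        self.n_ready = 0
        self.handle = None

    def mark_grad_ready(self):
        # Called from a per-parameter backward hook once that gradient is in the buffer.
        self.n_ready += 1
        if self.n_ready == self.expected:
            # All grads of this Bucket are ready: start communication now,
            # while autograd keeps producing grads for earlier layers.
            self.data /= dist.get_world_size()
            self.handle = dist.all_reduce(self.data, async_op=True)

    def wait(self):
        # Called at the end of backward (cf. finish_grad_sync()).
        if self.handle is not None:
            self.handle.wait()

# Tiny usage example: one Bucket covering two "parameters".
buffer = torch.randn(5)       # pretend the grads were already filled in
bucket = Bucket(buffer, num_params=2)
bucket.mark_grad_ready()
bucket.mark_grad_ready()      # second grad arrives -> async all-reduce is launched
bucket.wait()
dist.destroy_process_group()
```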
Under the DDP hood, the gradient synchronization can be either `AllReduce` or `ReduceScatter` (sketched below):

- `AllReduce` is used when no ZeRO optimizer is present
- `ReduceScatter` is used when a ZeRO optimizer (e.g., `DistributedOptimizer`) exists
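A hedged sketch of this choice, assuming an already-initialized data-parallel process group, PyTorch's `torch.distributed` collectives, and a bucket length divisible by the DP world size; `sync_bucket` is a hypothetical helper, not a veScale function:

```python
import torch
import torch.distributed as dist

def sync_bucket(bucket: torch.Tensor, use_distributed_optimizer: bool, dp_group=None):
    """Hypothetical helper showing the two synchronization modes described above."""
    world_size = dist.get_world_size(group=dp_group)
    bucket /= world_size  # pre-divide so the SUM collective yields an average
    if use_distributed_optimizer:
        # ZeRO-style (DistributedOptimizer): each DP rank only needs the shard of the
        # gradients it owns, so ReduceScatter suffices and is cheaper than AllReduce.
        shard = torch.empty(bucket.numel() // world_size,
                            dtype=bucket.dtype, device=bucket.device)
        dist.reduce_scatter_tensor(shard, bucket, group=dp_group)
        return shard
    # Plain DDP: every DP rank keeps the full averaged gradient.
    dist.all_reduce(bucket, group=dp_group)
    return bucket
```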
```python
import torch

from vescale.ddp.distributed_data_parallel import DistributedDataParallel as DDP
from vescale.dmodule.api import parallelize_module
from vescale.dtensor.device_mesh import DeviceMesh

# create a torch-native model
mlp = MLP()
# create a 2-dim DeviceMesh: the first dim is for data parallelism, the second for tensor parallelism.
device_mesh = DeviceMesh("cuda", [[0, 1], [2, 3]], mesh_dim_names=("DP", "TP"))
# parallelize the torch-native model into a TP model (`sharding_plan` describes how the model is sharded)
tp_mlp = parallelize_module(mlp, device_mesh["TP"], sharding_plan)
# wrap the TP model with `DDP`
dp_tp_mlp = DDP(
    # feed the TP model
    tp_mlp,
    # feed DP's sub-mesh, or just `device_mesh` (by default, the first dim of the DeviceMesh is treated as data parallelism).
    device_mesh["DP"],
    # choose whether to overlap gradient all-reduce with backward compute
    overlap_grad_reduce=True or False,
    # choose whether `DistributedOptimizer` is used:
    # if True, `DDP` will be used with `DistributedOptimizer`, so `DDP` reduce-scatters the gradient along data-parallel ranks.
    # if False, `DDP` just all-reduces the gradient along data-parallel ranks.
    use_distributed_optimizer=True or False
)
# train the model
dp_tp_mlp(torch.rand(...)).sum().backward()
# all-reduce / reduce-scatter the gradient across the DP world.
dp_tp_mlp.finish_grad_sync()
```
- APIs can be found in `<repo>/vescale/ddp/distributed_data_parallel.py`
- More examples can be found in `<repo>/test/parallel/ddp_optim/test_ddp.py`