
Commit 4bb5c2c

wconstab authored and pytorchmergebot committed
Add docstring to DDPOptimizer (pytorch#88521)
Pull Request resolved: pytorch#88521 Approved by: https://github.com/aazzolini
1 parent 1f32c3c commit 4bb5c2c

File tree

1 file changed: +53 −0 lines changed

torch/_dynamo/optimizations/distributed.py

Lines changed: 53 additions & 0 deletions
@@ -54,6 +54,59 @@ def pretty_print_buckets(buckets: List[Bucket]):
class DDPOptimizer:
    """
    DDPOptimizer applies when dynamo compiles models wrapped in DistributedDataParallel (DDP),
    breaking the dynamo graph into chunks to compile separately, with the breaks aligning to
    the boundaries of gradient-allreduce buckets chosen by DDP.

    Background/Motivation
    - DDP uses allreduce collectives to synchronize partial gradients computed on different workers
    - DDP groups gradient allreduces into 'buckets' to optimize the communication efficiency of allreduce
    - Parameters grouped into buckets are assumed to be adjacent in time, so they become ready
      at around the same time during backward and thus can share the same allreduce efficiently
    - Allreduces must overlap with backward compute for optimal training performance
    - DDP schedules allreduces using 'hooks' fired from the C++ autograd engine in pytorch, which
      fire when individual grads become 'ready'
    - Dynamo+AOTAutograd produces a single fused graph that runs 'atomically' from the perspective of the
      autograd engine, such that all gradients become 'ready' at the same time. Hooks fire after the whole
      fused backward function executes, preventing any overlap of compute and communication

    Algorithm
    - DDPOptimizer starts off with an FX graph traced by dynamo which represents forward. It can traverse
      this graph in reverse order to determine the true order that gradients will become ready during backward.
    - Parameter sizes are counted in reverse order, up to a bucket size limit, at which point a new bucket is
      started and a graph break is introduced
    - Each of the subgraphs is compiled by the compiler provided to dynamo by the user, and then fused back
      together into an outer module that is returned to the user

    Notes
    - It would be better to enforce (by adding an API to DDP) that the bucket splits chosen here are used by DDP,
      and that DDP does not need to detect or optimize bucket order by observing execution at runtime, as it does
      in eager.
    - If Dynamo can't capture a whole graph for the portion of the model wrapped by DDP, this algorithm will
      currently produce splits that do not necessarily align with the buckets used by DDP. This should result in
      performance degradation approaching the baseline case where graph-splits are not used, but not worse.
    - If the backend compiler fails to compile a single subgraph, that subgraph will execute eagerly despite the
      rest of the subgraphs being compiled
    - DDP has a 'parameters_and_buffers_to_ignore' field, which DDPOptimizer attempts to honor by reading markers
      left by DDP on individual parameters. In cases where other transformations, such as reparameterization, are
      also used, the ignore markers could be lost. If DDPOptimizer fails to ignore a parameter ignored by DDP,
      it is not catastrophic but could impact performance by choosing sub-optimal bucket splits.
    - DDPOptimizer always ignores all buffers, regardless of their ignore flag, since buffers do not require
      gradients and therefore aren't allreduced by DDP. (They are broadcast during forward, but this is not
      covered by DDPOptimizer)

    Args:
        bucket_bytes_cap (int): Controls the size of buckets, in bytes, used to determine graph breaks. Should be
            set to match the equivalent parameter on the original DDP module.

        backend_compile_fn (callable): A dynamo compiler function, invoked to compile each subgraph.

        first_bucket_cap (int): Controls the size of the first bucket. Should match DDP's first bucket cap. DDP
            special-cases the first bucket size since it is sometimes optimal to start a small allreduce early.

    """
    def __init__(
        self,
        bucket_bytes_cap: int,
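
To make the Algorithm section of the docstring concrete, here is a minimal sketch of the reverse-order bucketing pass it describes, not the actual DDPOptimizer implementation: Bucket here is a simplified stand-in for the real helper, gm is assumed to be a torch.fx.GraphModule traced by dynamo, and the graph-break bookkeeping is elided.

    # Sketch only: reverse-order parameter bucketing as described in the docstring.
    from dataclasses import dataclass, field
    from typing import List

    import torch.fx

    @dataclass
    class Bucket:
        size: int = 0  # running total of parameter bytes in this bucket
        param_names: List[str] = field(default_factory=list)

    def split_into_buckets(gm: torch.fx.GraphModule, bucket_bytes_cap: int) -> List[Bucket]:
        buckets = [Bucket()]
        # Traverse the forward graph in reverse: gradients become ready in
        # roughly the reverse of forward execution order during backward.
        for node in reversed(list(gm.graph.nodes)):
            if node.op != "get_attr":
                continue
            try:
                # get_parameter raises AttributeError for buffers and other
                # non-Parameter attributes, so buffers are skipped naturally.
                param = gm.get_parameter(node.target)
            except AttributeError:
                continue
            if not param.requires_grad:
                continue
            # Once the cap is reached, start a new bucket; the real pass also
            # introduces a graph break at this boundary.
            if buckets[-1].size >= bucket_bytes_cap:
                buckets.append(Bucket())
            buckets[-1].size += param.numel() * param.element_size()
            buckets[-1].param_names.append(node.target)
        return buckets

Note how buffers fall out of this sketch for free: get_parameter only returns nn.Parameter attributes, matching the docstring's note that DDPOptimizer always ignores all buffers.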

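And a hedged usage sketch for the Args above. In practice dynamo constructs DDPOptimizer itself when compiling a DDP-wrapped module, so constructing one directly, as shown here, is for illustration only; model and my_backend are assumed to be defined elsewhere, and an initialized process group is assumed for DDP.

    # Illustrative only: dynamo normally builds DDPOptimizer internally.
    from torch.nn.parallel import DistributedDataParallel as DDP
    from torch._dynamo.optimizations.distributed import DDPOptimizer

    bucket_cap_mb = 25  # DDP's default bucket size
    ddp_model = DDP(model, bucket_cap_mb=bucket_cap_mb)

    ddp_optimizer = DDPOptimizer(
        bucket_bytes_cap=bucket_cap_mb * 1024 * 1024,  # keep in sync with DDP
        backend_compile_fn=my_backend,                 # the user's dynamo backend
        first_bucket_cap=1 * 1024 * 1024,              # match DDP's first bucket
    )

The important invariant, per the Args section, is that bucket_bytes_cap matches the bucket size DDP itself uses, so that the graph breaks line up with DDP's allreduce buckets.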