@@ -54,6 +54,59 @@ def pretty_print_buckets(buckets: List[Bucket]):
 
 
 class DDPOptimizer:
+    """
+    DDPOptimizer applies when dynamo compiles models wrapped in DistributedDataParallel (DDP),
+    breaking the dynamo graph into chunks to compile separately, with the breaks aligning to
+    the boundaries of gradient-allreduce buckets chosen by DDP.
+
+    Background/Motivation
+    - DDP uses allreduce collectives to synchronize partial gradients computed on different workers
+    - DDP groups gradient allreduces into 'buckets' to optimize the communication efficiency of allreduce
+    - Parameters grouped into buckets are assumed to be adjacent in time, so they become ready
+      at around the same time during backward and thus can share the same allreduce efficiently
+    - Allreduces must overlap with backward compute for optimal training performance
+    - DDP schedules allreduces using 'hooks' fired from the C++ autograd engine in PyTorch, which
+      operates when individual grads become 'ready'
+    - Dynamo+AOTAutograd produces a single fused graph that runs 'atomically' from the perspective of the
+      autograd engine, such that all gradients become 'ready' at the same time. Hooks fire after the whole
+      fused backward function executes, preventing any overlap of compute and communication
+
+    Algorithm
+    - DDPOptimizer starts with an FX graph traced by dynamo, which represents the forward pass. It can traverse
+      this graph in reverse order to determine the true order in which gradients will become ready during backward.
+    - Parameter sizes are counted in reverse order, up to a bucket size limit, at which point a new bucket is started
+      and a graph break introduced
+    - Each of the subgraphs is compiled by the compiler provided to dynamo by the user, and then fused back together
+      into an outer module that is returned to the user
+
+    Notes
+    - It would be better to enforce (by adding an API to DDP) that the bucket splits chosen here are used by DDP,
+      and that DDP does not need to detect or optimize bucket order by observing execution at runtime, as it does
+      in eager.
+    - If Dynamo can't capture a whole graph for the portion of the model wrapped by DDP, this algorithm will currently
+      produce splits that do not necessarily align with the buckets used by DDP. This should result in performance
+      degradation approaching the baseline case where graph splits are not used, but not worse.
+    - If the backend compiler fails to compile a single subgraph, that subgraph will execute eagerly while the rest of
+      the subgraphs remain compiled
+    - DDP has a 'parameters_and_buffers_to_ignore' field, which DDPOptimizer attempts to honor by reading markers
+      left by DDP on individual parameters. In cases where other transformations, such as reparameterization, are
+      also used, the ignore markers could be lost. If DDPOptimizer fails to ignore a parameter ignored by DDP,
+      it is not catastrophic but could impact performance by choosing sub-optimal bucket splits.
+    - DDPOptimizer always ignores all buffers, regardless of their ignore flag, since buffers do not require gradients
+      and therefore aren't allreduced by DDP. (They are broadcast during forward, but this is not covered by
+      DDPOptimizer.)
+
+    Args:
+        bucket_bytes_cap (int): Controls the size of buckets, in bytes, used to determine graph breaks. Should be
+            set to match the equivalent parameter on the original DDP module.
+
+        backend_compile_fn (callable): A dynamo compiler function, to be invoked to compile each subgraph.
+
+        first_bucket_cap (int): Controls the size of the first bucket. Should match DDP's first bucket cap. DDP
+            special-cases the first bucket size since it is sometimes optimal to start a small allreduce early.
+
+    """
+
     def __init__(
         self,
         bucket_bytes_cap: int,
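
To make the bucket-splitting described under "Algorithm" in the docstring above concrete, here is a minimal, self-contained sketch of reverse-order bucketing. It is an illustration only, not code from this PR; the `split_into_buckets` helper, its parameter list, and the example sizes are hypothetical.

```python
from typing import List, Tuple


def split_into_buckets(
    params_in_forward_order: List[Tuple[str, int]],
    bucket_bytes_cap: int,
    first_bucket_cap: int,
) -> List[List[str]]:
    """Group parameters into buckets by walking them in reverse (backward) order.

    Gradients become ready roughly in reverse order of forward execution, so
    parameter sizes are accumulated in that order until a cap is reached, at
    which point a new bucket (and hence a graph break) is started.
    """
    buckets: List[List[str]] = [[]]
    bucket_bytes = 0
    # The first bucket (the last parameters used in forward) gets a smaller cap,
    # so a small allreduce can be launched early during backward.
    cap = first_bucket_cap

    for name, size_bytes in reversed(params_in_forward_order):
        if buckets[-1] and bucket_bytes + size_bytes > cap:
            buckets.append([])  # new bucket -> graph break in the dynamo FX graph
            bucket_bytes = 0
            cap = bucket_bytes_cap
        buckets[-1].append(name)
        bucket_bytes += size_bytes
    return buckets


if __name__ == "__main__":
    # Hypothetical parameters as (name, size in bytes), in forward-execution order.
    params = [("w0", 200), ("w1", 300), ("w2", 100), ("w3", 150), ("w4", 50)]
    print(split_into_buckets(params, bucket_bytes_cap=500, first_bucket_cap=100))
    # -> [['w4'], ['w3', 'w2'], ['w1', 'w0']]: three buckets, two graph breaks
```

The real DDPOptimizer derives the ordering from the traced FX graph rather than a flat parameter list, but the cap-and-split bookkeeping above reflects the same idea.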
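For the documented constructor arguments, a hedged usage sketch might look like the following. Only `bucket_bytes_cap` is visible in the `__init__` signature shown in this diff; the other keyword names are taken from the Args section, and the values and the no-op backend are placeholders.

```python
# Hypothetical construction of DDPOptimizer with the documented arguments.
# How dynamo subsequently invokes the optimizer is not shown in this diff.

def noop_backend(gm, example_inputs):
    # A dynamo compiler function: takes an FX GraphModule plus example inputs
    # and returns a callable; here we simply run the captured subgraph eagerly.
    return gm.forward


ddp_optimizer = DDPOptimizer(
    bucket_bytes_cap=25 * 1024 * 1024,  # match DDP(bucket_cap_mb=25), converted to bytes
    backend_compile_fn=noop_backend,
    first_bucket_cap=1024 * 1024,       # assumed to match DDP's smaller first bucket
)
```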