Commit 3e1b843
fix: pp grad accumulation is broken (#1732)
[problem]
Gradient accumulation is incompatible with the PipelineSchedule(...,
scale_grads=True) option, which defaults to True.
When this option is set, all of the gradients are divided by the number
of micro-batches at every step. This works fine for a single gradient
accumulation step, but with multiple steps the total accumulated
gradient is rescaled by this factor at every step, not just once at the
end of gradient accumulation. The result is that the accumulated
gradient is effectively an exponential moving average rather than a
sum. The resulting gradients are much smaller than they should be, and
gradient accumulation with PP is not equivalent to gradient
accumulation without PP: the loss curves diverge substantially and the
gradient norms are far off.
A secondary consequence is that dividing all of the gradients by
n_microbatches at every step is computationally expensive for a large
model.
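To make the failure mode concrete, here is a small illustration (plain
Python with made-up numbers, not the pipeline code) contrasting per-step
scaling of the accumulated gradient with scaling once at the end of
accumulation:

```python
# Illustration only: how per-step gradient scaling interacts with
# gradient accumulation. Numbers are made up.
n_microbatches = 8
accum_steps = 4
microbatch_grad = 1.0  # pretend every micro-batch contributes a gradient of 1.0

# With scale_grads=True, the whole gradient buffer, including gradients
# accumulated on earlier steps, is divided by n_microbatches after every
# schedule step.
grad_broken = 0.0
for _ in range(accum_steps):
    grad_broken = (grad_broken + microbatch_grad * n_microbatches) / n_microbatches
print(grad_broken)  # ~1.14: earlier steps decay geometrically (EMA-like)

# What gradient accumulation expects: accumulate a sum, scale once at the end.
grad_expected = 0.0
for _ in range(accum_steps):
    grad_expected += microbatch_grad * n_microbatches
grad_expected /= n_microbatches
print(grad_expected)  # 4.0: a true sum over the accumulation steps
```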
[solution]
Set "scale_grads=False" when creating the scheduler instance.
Compute "n_microbatches" in the constructor and apply this factor, along
with gradient_accumulation_steps, to the scale factor in
"rescale_accumulated_loss()". This will cause the loss to be scaled,
rather than the gradients, at each step by the correct factor.
A secondary benefit of this approach is that it avoids modifying all of
the gradients: scaling the loss is much cheaper computationally than
rescaling every gradient, and, unlike the previous behavior, it is
correct.
A side effect of this change is that the loss values returned by the
pipeline are now scaled by this factor, making them too small by a
factor of n_microbatches. This is corrected by rescaling the returned
loss by the same factor.
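A minimal sketch of the approach follows. The class name, constructor
arguments, and the exact wiring of rescale_accumulated_loss() are
illustrative rather than the actual forgather/torchtitan code;
Schedule1F1B and its scale_grads argument are from
torch.distributed.pipelining.

```python
import torch
from torch.distributed.pipelining import Schedule1F1B


class PPTrainer:  # illustrative name, not the actual trainer class
    def __init__(self, stage, loss_fn, local_batch_size, microbatch_size,
                 gradient_accumulation_steps):
        # Compute n_microbatches once, in the constructor.
        self.n_microbatches = local_batch_size // microbatch_size
        self.gradient_accumulation_steps = gradient_accumulation_steps
        # Disable per-step gradient scaling inside the pipeline schedule.
        self.schedule = Schedule1F1B(
            stage,
            n_microbatches=self.n_microbatches,
            loss_fn=self.rescale_accumulated_loss(loss_fn),
            scale_grads=False,
        )

    def rescale_accumulated_loss(self, loss_fn):
        # Scale the loss (and therefore the gradients) by the combined
        # factor, instead of rescaling every gradient tensor on every
        # schedule step.
        scale = 1.0 / (self.n_microbatches * self.gradient_accumulation_steps)

        def scaled_loss_fn(output, target):
            return loss_fn(output, target) * scale

        return scaled_loss_fn

    def train_step(self, inputs, targets):
        # Sketch assumes this rank holds the last stage, which receives
        # targets and collects per-micro-batch losses.
        losses = []
        self.schedule.step(inputs, target=targets, losses=losses)
        # The losses were computed with the scaled loss_fn, so multiply
        # back by n_microbatches for reporting (the side effect noted above).
        return torch.mean(torch.stack(losses)) * self.n_microbatches
```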
[testing]
Without these changes, a baseline run on a single GPU with 10 gradient
accumulation steps was compared against a run on a 2-GPU pipeline using
1F1B. The effective batch size is 320 in both cases, with all other
variables controlled. The result is a substantial divergence between
the loss curves and gradient norms of the two runs.
With this change applied, the results are nearly identical, apart from
minor differences due to non-determinism.
[references]
scale_grads option:
https://github.com/pytorch/pytorch/blob/281bb56cc50073159c8418c5c99c7459c914c4db/torch/distributed/pipelining/schedules.py#L286
scale_grads implementation:
https://github.com/pytorch/pytorch/blob/281bb56cc50073159c8418c5c99c7459c914c4db/torch/distributed/pipelining/stage.py#L567
Test code for reproducing the issue and testing the fix:
https://github.com/jdinalt/forgather/tree/main/examples/torchtitan/test_parallelisms
File tree (5 files changed, +19 −5 lines changed):
- torchtitan
  - components
  - distributed
  - experiments
    - deepseek_v3
    - forge