
Commit 3e1b843

fix: pp grad accumulation is broken (#1732)
[problem]
Gradient accumulation is incompatible with the PipelineSchedule(..., scale_grads=True) option, which defaults to True. With this option set, every step scales all gradients down by the number of microbatches. That is fine for a single gradient accumulation step, but with multiple steps the total accumulated gradient gets rescaled by this factor at every step, not just once at the end of accumulation. The accumulated gradient is therefore an exponential moving average rather than a sum: the resulting gradients are much smaller than they should be, and gradient accumulation with PP is not equivalent to gradient accumulation without PP; the loss curves diverge substantially and the gradient norms are far off. A secondary consequence is that dividing every gradient by n_microbatches at each step is computationally expensive for a large model.

[solution]
Set scale_grads=False when creating the scheduler instance. Compute n_microbatches in the constructor and apply this factor, along with gradient_accumulation_steps, to the scale factor in rescale_accumulated_loss(). The loss, rather than the gradients, is then scaled at each step by the correct factor. A secondary benefit of this approach is that it avoids modifying all of the gradients, which is much cheaper computationally, and, unlike the previous behavior, it is correct. A side effect of this change is that the loss values returned by the pipeline are scaled by the same factor, making them too small by a factor of n_microbatches; this is corrected by rescaling the returned loss accordingly (summing instead of averaging the per-microbatch losses).

[testing]
Without these changes, a baseline run with 10 gradient accumulation steps on a single GPU was compared against a run (also without the changes) on a 2-GPU pipeline using 1F1B. The effective batch size is 320 in both cases, with all other variables controlled. The result is a substantial divergence between the loss curves and gradient norms of the two runs. With this change applied, the results are nearly identical, up to minor differences from non-determinism.

[references]
scale_grads option: https://github.com/pytorch/pytorch/blob/281bb56cc50073159c8418c5c99c7459c914c4db/torch/distributed/pipelining/schedules.py#L286
scale_grads implementation: https://github.com/pytorch/pytorch/blob/281bb56cc50073159c8418c5c99c7459c914c4db/torch/distributed/pipelining/stage.py#L567
Test code for reproducing the issue and testing the fix: https://github.com/jdinalt/forgather/tree/main/examples/torchtitan/test_parallelisms
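To make the failure mode concrete, here is a minimal standalone sketch (my illustration, not code from this commit) contrasting per-step gradient rescaling, which scale_grads=True effectively performs, with scaling the loss once per microbatch; the variable names are made up for the example:

# Illustration only: why rescaling the whole gradient buffer at every
# accumulation step breaks accumulation, while scaling the loss does not.
n_microbatches = 4
accum_steps = 3
per_microbatch_grad = 1.0  # pretend each microbatch contributes gradient 1.0

# Old behavior (scale_grads=True): the schedule divides the *entire*
# accumulated gradient by n_microbatches at every accumulation step.
grad_old = 0.0
for _ in range(accum_steps):
    grad_old += n_microbatches * per_microbatch_grad  # backward over microbatches
    grad_old /= n_microbatches                        # per-step rescale
# grad_old == 1.3125: an exponential moving average, not the intended average

# New behavior: leave gradients alone and scale each microbatch loss by
# 1 / (n_microbatches * accum_steps) before backward.
grad_new = 0.0
for _ in range(accum_steps):
    grad_new += n_microbatches * per_microbatch_grad / (n_microbatches * accum_steps)
# grad_new == 1.0: the correct average over the effective batch

print(grad_old, grad_new)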
1 parent: 0f34257 · commit: 3e1b843

File tree: 5 files changed, +19 −5 lines

torchtitan/components/validate.py

Lines changed: 4 additions & 1 deletion

@@ -147,7 +147,10 @@ def validate(
             # accumulate losses across pipeline microbatches
             # TODO: PP+FSDP unexpectedly puts the loss back to the CPU
             loss = (
-                torch.mean(torch.stack(losses)).to(device_type)
+                # using sum instead of mean because we already rescale the
+                # loss_fn down by a factor of n_microbatches in
+                # torchtitan/distributed/pipeline_parallel.py
+                torch.sum(torch.stack(losses)).to(device_type)
                 if self.pp_has_last_stage
                 else torch.tensor([-1.0], device=device_type)
             )
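Why summing is now correct (an illustration, not part of the diff): the loss_fn handed to the pipeline schedule is wrapped so that each per-microbatch loss is already divided by n_microbatches, so summing the returned losses reproduces the mean that the old code computed from unscaled losses:

# Illustration only: sum of pre-scaled microbatch losses == mean of raw losses.
import torch

raw_losses = torch.tensor([2.0, 4.0, 6.0, 8.0])  # hypothetical per-microbatch losses
n_microbatches = raw_losses.numel()
scaled_losses = raw_losses / n_microbatches       # what the wrapped loss_fn now returns

assert torch.allclose(torch.sum(scaled_losses), torch.mean(raw_losses))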

torchtitan/distributed/pipeline_parallel.py

Lines changed: 3 additions & 1 deletion

@@ -22,6 +22,7 @@
     ScheduleZBVZeroBubble,
 )
 
+from torchtitan.components.loss import rescale_accumulated_loss
 from torchtitan.config import JobConfig
 from torchtitan.tools.logging import logger
 
@@ -82,7 +83,8 @@ def build_pipeline_schedule(
     schedule = schedule_class(
         stages if looped_schedule else stages[0],
         n_microbatches=n_microbatches,
-        loss_fn=loss_fn,
+        loss_fn=rescale_accumulated_loss(loss_fn, n_microbatches),
+        scale_grads=False,
     )
     logger.info(
         f"Using pipeline schedule {job_config.parallelism.pipeline_parallel_schedule} "

torchtitan/experiments/deepseek_v3/train_ds_dev.py

Lines changed: 4 additions & 1 deletion

@@ -126,7 +126,10 @@ def run_full_model(
             y = pp_schedule.step(x)
         elif pp_rank == pp_size - 1:
             y = pp_schedule.step(target=label, losses=losses)
-            loss = torch.mean(torch.stack(losses))
+            # using sum instead of mean because we already rescale the
+            # loss_fn down by a factor of n_microbatches in
+            # torchtitan/distributed/pipeline_parallel.py
+            loss = torch.sum(torch.stack(losses))
         else:
             pp_schedule.step()
     else:

torchtitan/experiments/forge/example_train.py

Lines changed: 4 additions & 1 deletion

@@ -197,7 +197,10 @@ def forward_backward_step(
             # accumulate losses across pipeline microbatches
             # TODO: PP+FSDP unexpectedly puts the loss back to the CPU
             loss = (
-                torch.mean(torch.stack(losses)).to(self.device)
+                # using sum instead of mean because we already rescale the
+                # loss_fn down by a factor of n_microbatches in
+                # torchtitan/distributed/pipeline_parallel.py
+                torch.sum(torch.stack(losses)).to(self.device)
                 if self.pp_has_last_stage
                 else torch.tensor([-1.0], device=self.device)
             )

torchtitan/train.py

Lines changed: 4 additions & 1 deletion

@@ -457,7 +457,10 @@ def forward_backward_step(
             # accumulate losses across pipeline microbatches
             # TODO: PP+FSDP unexpectedly puts the loss back to the CPU
             loss = (
-                torch.mean(torch.stack(losses)).to(self.device)
+                # using sum instead of mean because we already rescale the
+                # loss_fn down by a factor of n_microbatches in
+                # torchtitan/distributed/pipeline_parallel.py
+                torch.sum(torch.stack(losses)).to(self.device)
                 if self.pp_has_last_stage
                 else torch.tensor([-1.0], device=self.device)
             )
