Remove abs from save list for per op AC to fix bug causing unexpected increase in peak memory usage for float8 training with row-wise scaling (#820)

This fixes a bug in how per op AC interacts with float8 training. The
bug surfaced as an unexpected increase in peak memory usage when
integrating row-wise scaling into float8 training in
#808. The full RCA can be found in this thread, starting
[here](#808 (comment)).

## TL;DR
- I
[tested](#808 (comment))
different model layers and permutations of configurations and found the
minimal repro was a FFN with per op AC + FSDP + compile + float8
row-wise scaling.
- I
[analyzed](#808 (comment))
the triton kernels generated for the compiled FFN forward pass in bf16
vs fp8, tracking peak cumulative memory allocated and what was saved for
backward. I found the fp8 forward kernel was saving a huge buffer/tensor
for backward which the bf16 kernel was not saving.
- I
[tracked](#808 (comment))
how the various buffers were used and determined the huge buffer saved
for backward was holding `abs(W3)`, where W3 is an **unsharded weight
tensor**.
- I
[tested](#808 (comment))
removing `abs()` from the save list for per op AC and confirmed it
resolved the memory issue without impacting TPS (see the sketch after this list for how such a save list is wired up).
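
For readers unfamiliar with per op (selective) AC, the sketch below shows roughly how a save list like the one changed in this PR gets wired into `torch.utils.checkpoint` (assuming a recent PyTorch with `create_selective_checkpoint_contexts`). The op set, `_policy`, and `apply_per_op_ac` wrapper are illustrative assumptions, not torchtitan's actual code:

```python
import functools

import torch
from torch.utils.checkpoint import (
    CheckpointPolicy,
    checkpoint,
    create_selective_checkpoint_contexts,
)

# Illustrative save list (not torchtitan's full list): outputs of these ops
# are kept for backward, everything else is recomputed. aten.max is kept
# because its output is tiny, while aten.abs is deliberately NOT listed,
# since abs() of an unsharded weight is a full weight-sized tensor.
_save_list = {
    torch.ops.aten.mm.default,
    torch.ops.aten._scaled_dot_product_flash_attention.default,
    torch.ops.aten.max.default,
}


def _policy(ctx, op, *args, **kwargs):
    # Decide per op whether to save its output or recompute it in backward.
    return (
        CheckpointPolicy.MUST_SAVE
        if op in _save_list
        else CheckpointPolicy.PREFER_RECOMPUTE
    )


def apply_per_op_ac(module: torch.nn.Module, *inputs):
    # context_fn is invoked on each forward to build the save/recompute contexts.
    return checkpoint(
        module,
        *inputs,
        use_reentrant=False,
        context_fn=functools.partial(create_selective_checkpoint_contexts, _policy),
    )
```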


## Benchmarks 

**float8 row-wise WITH storing abs() op output:**
```
[rank0]:2025-02-04 14:13:52,324 - root - INFO - step: 10  loss:  9.9765  memory: 58.55GiB(61.63%)  tps: 6,379  mfu: 37.35%
[rank0]:2025-02-04 14:14:05,145 - root - INFO - step: 20  loss:  8.3610  memory: 58.55GiB(61.63%)  tps: 6,390  mfu: 37.42%
[rank0]:2025-02-04 14:14:17,975 - root - INFO - step: 30  loss:  7.6860  memory: 58.55GiB(61.63%)  tps: 6,386  mfu: 37.39%
```

**float8 row-wise WITHOUT storing abs() op output:**
```
[rank0]:2025-02-04 14:19:22,446 - root - INFO - step: 10  loss: 10.1946  memory: 47.77GiB(50.28%)  tps: 6,293  mfu: 36.85%
[rank0]:2025-02-04 14:19:35,190 - root - INFO - step: 20  loss:  8.4865  memory: 47.77GiB(50.28%)  tps: 6,429  mfu: 37.64%
[rank0]:2025-02-04 14:19:47,951 - root - INFO - step: 30  loss:  7.6982  memory: 47.77GiB(50.28%)  tps: 6,420  mfu: 37.60%
```

I also confirmed this does not cause a regression in memory usage
or TPS with tensorwise scaling:

**float8 tensorwise WITH storing abs() output:**
```
[rank0]:2025-02-04 15:06:53,589 - root - INFO - step: 10  loss:  9.8912  memory: 47.77GiB(50.28%)  tps: 6,881  mfu: 40.30%
[rank0]:2025-02-04 15:07:05,499 - root - INFO - step: 20  loss:  8.4239  memory: 47.77GiB(50.28%)  tps: 6,879  mfu: 40.28%
[rank0]:2025-02-04 15:07:17,412 - root - INFO - step: 30  loss:  7.6615  memory: 47.77GiB(50.28%)  tps: 6,877  mfu: 40.27%
```

**float8 tensorwise WITHOUT storing abs() output:**

```
[rank0]:2025-02-04 15:08:32,527 - root - INFO - step: 10  loss:  9.9628  memory: 47.77GiB(50.28%)  tps: 6,865  mfu: 40.20%
[rank0]:2025-02-04 15:08:44,450 - root - INFO - step: 20  loss:  8.5451  memory: 47.77GiB(50.28%)  tps: 6,871  mfu: 40.24%
[rank0]:2025-02-04 15:08:56,383 - root - INFO - step: 30  loss:  7.8286  memory: 47.77GiB(50.28%)  tps: 6,865  mfu: 40.20%
```
danielvegamyhre authored Feb 6, 2025
1 parent 690f299 commit 8824727
Showing 1 changed file with 2 additions and 2 deletions.
torchtitan/parallelisms/parallelize_llama.py (2 additions, 2 deletions)

@@ -217,8 +217,8 @@ def apply_tp(
     torch.ops.aten._scaled_dot_product_flash_attention.default,
     torch.ops._c10d_functional.reduce_scatter_tensor.default,
     # for low precision training, it's useful to always save
-    # the result of max(abs(tensor))
-    torch.ops.aten.abs.default,
+    # the result of max, since the absolute maximum is
+    # used to compute the scaling factor for quantization.
     torch.ops.aten.max.default,
 }
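
As a rough illustration of why saving `max` output is cheap while saving `abs` output is not: the row-wise scale is derived from the per-row absolute maximum, so the `max`/`amax` result is one value per row, whereas `abs(W)` is a full weight-sized intermediate. The sketch below is a hypothetical shape illustration, not the actual torchao/torchtitan scale computation (the exact formula and `eps` are assumptions):

```python
import torch

def rowwise_scale(w: torch.Tensor, eps: float = 1e-12) -> torch.Tensor:
    # Per-row absolute maximum: shape (rows, 1) -- tiny compared to w.
    amax = w.abs().amax(dim=-1, keepdim=True)
    fp8_max = torch.finfo(torch.float8_e4m3fn).max  # 448.0 for e4m3fn
    return fp8_max / amax.clamp(min=eps)            # shape (rows, 1)

w = torch.randn(4096, 4096)
scale = rowwise_scale(w)
print(w.abs().numel(), scale.numel())  # 16777216 intermediate vs 4096 saved
# Saving only the max/amax result keeps a (rows, 1) tensor for backward;
# saving the abs() output would keep the full 4096x4096 intermediate alive.
```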
