Commit 8824727
Remove abs from save list for per op AC to fix bug causing unexpected increase in peak memory usage for float8 training with row-wise scaling (#820)
This fixes a bug in how per op AC interacts with float8 training. The
bug surfaced as an unexpected increase in peak memory usage when
integrating row-wise scaling into float8 training in
#808. The full RCA can be found in this thread starting
[here](#808 (comment)).
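
For context on the mechanism: per op AC (selective activation checkpointing) takes a policy that marks the outputs of certain ops as "save for backward" while everything else is recomputed. Below is a minimal sketch using PyTorch's `create_selective_checkpoint_contexts`; the `_save_list` contents are an illustrative pre-fix subset, not the verbatim list from the repo:
```
import torch
from functools import partial
from torch.utils.checkpoint import (
    CheckpointPolicy,
    checkpoint,
    create_selective_checkpoint_contexts,
)

# Illustrative pre-fix save list: abs/max were saved so the float8 scale
# computation would not be redone in backward.
_save_list = {
    torch.ops.aten.mm.default,
    torch.ops.aten.abs.default,  # full-tensor output gets stored
    torch.ops.aten.max.default,  # tiny reduction output gets stored
}

def _policy(ctx, op, *args, **kwargs):
    # Save outputs of ops in the list; recompute everything else.
    if op in _save_list:
        return CheckpointPolicy.MUST_SAVE
    return CheckpointPolicy.PREFER_RECOMPUTE

def per_op_ac(fn, *inputs):
    # Selective checkpointing requires the non-reentrant implementation.
    context_fn = partial(create_selective_checkpoint_contexts, _policy)
    return checkpoint(fn, *inputs, use_reentrant=False, context_fn=context_fn)
```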
## TL;DR
- I
[tested](#808 (comment))
different model layers and permutations of configurations, and found the
minimal repro was an FFN with per op AC + FSDP + compile + float8
row-wise scaling.
- I
[analyzed](#808 (comment))
the Triton kernels generated for the compiled FFN forward pass in bf16
vs fp8, tracking peak cumulative memory allocated and what was saved for
backward. I found the fp8 forward kernel was saving a huge buffer
for backward which the bf16 kernel was not.
- I
[tracked](#808 (comment))
how the various buffers were used and determined that this huge buffer saved
for backward was holding `abs(W3)`, where W3 is an **unsharded weight
tensor** (see the memory sketch after this list).
- I
[tested](#808 (comment))
removing `abs()` from the save list for per op AC, and confirmed it
resolved the memory issue without impacting TPS.
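
Back-of-the-envelope arithmetic on why a saved `abs(W3)` buffer is so costly (the shapes below are hypothetical Llama-like sizes, not measured from the repro):
```
# Hypothetical FFN projection shape; abs(W3) is weight-sized, and W3 is
# unsharded here, so the saved buffer does not shrink with FSDP sharding.
dim, hidden_dim = 4096, 14336
bytes_per_elem = 2  # bf16
abs_buffer_bytes = dim * hidden_dim * bytes_per_elem
print(f"abs(W3) saved per FFN: {abs_buffer_bytes / 2**20:.0f} MiB")  # ~112 MiB
# Repeated across every transformer layer, buffers of this size can account
# for a multi-GiB gap in peak memory like the one in the benchmarks below.
```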
## Benchmarks
**float8 row-wise WITH storing abs() op output:**
```
[rank0]:2025-02-04 14:13:52,324 - root - INFO - step: 10 loss: 9.9765 memory: 58.55GiB(61.63%) tps: 6,379 mfu: 37.35%
[rank0]:2025-02-04 14:14:05,145 - root - INFO - step: 20 loss: 8.3610 memory: 58.55GiB(61.63%) tps: 6,390 mfu: 37.42%
[rank0]:2025-02-04 14:14:17,975 - root - INFO - step: 30 loss: 7.6860 memory: 58.55GiB(61.63%) tps: 6,386 mfu: 37.39%
```
**float8 row-wise WITHOUT storing abs() op output:**
```
[rank0]:2025-02-04 14:19:22,446 - root - INFO - step: 10 loss: 10.1946 memory: 47.77GiB(50.28%) tps: 6,293 mfu: 36.85%
[rank0]:2025-02-04 14:19:35,190 - root - INFO - step: 20 loss: 8.4865 memory: 47.77GiB(50.28%) tps: 6,429 mfu: 37.64%
[rank0]:2025-02-04 14:19:47,951 - root - INFO - step: 30 loss: 7.6982 memory: 47.77GiB(50.28%) tps: 6,420 mfu: 37.60%
```
I also confirmed this does not cause a regression in memory usage
or TPS with tensorwise scaling (see the sketch after the logs for why both
granularities are unaffected):
**float8 tensorwise WITH storing abs() output:**
```
[rank0]:2025-02-04 15:06:53,589 - root - INFO - step: 10 loss: 9.8912 memory: 47.77GiB(50.28%) tps: 6,881 mfu: 40.30%
[rank0]:2025-02-04 15:07:05,499 - root - INFO - step: 20 loss: 8.4239 memory: 47.77GiB(50.28%) tps: 6,879 mfu: 40.28%
[rank0]:2025-02-04 15:07:17,412 - root - INFO - step: 30 loss: 7.6615 memory: 47.77GiB(50.28%) tps: 6,877 mfu: 40.27%
```
**float8 tensorwise WITHOUT storing abs() output:**
```
[rank0]:2025-02-04 15:08:32,527 - root - INFO - step: 10 loss: 9.9628 memory: 47.77GiB(50.28%) tps: 6,865 mfu: 40.20%
[rank0]:2025-02-04 15:08:44,450 - root - INFO - step: 20 loss: 8.5451 memory: 47.77GiB(50.28%) tps: 6,871 mfu: 40.24%
[rank0]:2025-02-04 15:08:56,383 - root - INFO - step: 30 loss: 7.8286 memory: 47.77GiB(50.28%) tps: 6,865 mfu: 40.20%
```
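For intuition on why dropping `abs()` from the save list is safe for both scaling granularities: the float8 scale only needs the small reduced max, and the full-size `abs()` intermediate is cheap to recompute in backward. A sketch of the two reductions (illustrative, not the actual float8 implementation):
```
import torch

w = torch.randn(4096, 14336, dtype=torch.bfloat16)

# Tensorwise scaling: a single scalar amax for the whole tensor.
amax_tensorwise = w.abs().max()                   # shape []

# Row-wise scaling: one amax per row.
amax_rowwise = w.abs().amax(dim=1, keepdim=True)  # shape [4096, 1]

# Only these small reduction results are needed to build the float8 scales;
# the weight-sized abs(w) intermediate need not be saved for backward.
```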
1 file changed: +2 −2 lines.
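
The fix itself amounts to removing `abs` from the save list while keeping `max`; against the illustrative `_save_list` above, the post-fix state would look like this (a sketch, not the verbatim diff):
```
_save_list = {
    torch.ops.aten.mm.default,
    # abs() removed: its weight-sized output was being saved for backward;
    # recomputing it is cheap, and only max's small output needs saving.
    torch.ops.aten.max.default,
}
```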