
Commit 8824727

Remove abs from save list for per op AC to fix a bug causing an unexpected increase in peak memory usage for float8 training with row-wise scaling (#820)
This fixes a bug in how per op AC interacts with float8 training. The bug surfaced as an unexpected increase in peak memory usage when integrating row-wise scaling into float8 training in #808. The full RCA can be found in this thread starting [here](#808 (comment)).

## TL;DR

- I [tested](#808 (comment)) different model layers and permutations of configurations and found the minimal repro was an FFN with per op AC + FSDP + compile + float8 row-wise scaling.
- I [analyzed](#808 (comment)) the Triton kernels generated for the compiled FFN forward pass in bf16 vs fp8, tracking peak cumulative memory allocated and what was saved for backward. I found the fp8 forward kernel was saving a huge buffer/tensor for backward which the bf16 kernel was not saving.
- I [tracked](#808 (comment)) how the various buffers were used and determined this huge buffer saved for backward was holding `abs(W3)`, where W3 is an **unsharded weight tensor**.
- I [tested](#808 (comment)) removing `abs()` from the save list for per op AC, and confirmed it resolved the memory issue without impacting TPS.

## Benchmarks

**float8 row-wise WITH storing abs() op output:**

```
[rank0]:2025-02-04 14:13:52,324 - root - INFO - step: 10 loss: 9.9765 memory: 58.55GiB(61.63%) tps: 6,379 mfu: 37.35%
[rank0]:2025-02-04 14:14:05,145 - root - INFO - step: 20 loss: 8.3610 memory: 58.55GiB(61.63%) tps: 6,390 mfu: 37.42%
[rank0]:2025-02-04 14:14:17,975 - root - INFO - step: 30 loss: 7.6860 memory: 58.55GiB(61.63%) tps: 6,386 mfu: 37.39%
```

**float8 row-wise WITHOUT storing abs() op output:**

```
[rank0]:2025-02-04 14:19:22,446 - root - INFO - step: 10 loss: 10.1946 memory: 47.77GiB(50.28%) tps: 6,293 mfu: 36.85%
[rank0]:2025-02-04 14:19:35,190 - root - INFO - step: 20 loss: 8.4865 memory: 47.77GiB(50.28%) tps: 6,429 mfu: 37.64%
[rank0]:2025-02-04 14:19:47,951 - root - INFO - step: 30 loss: 7.6982 memory: 47.77GiB(50.28%) tps: 6,420 mfu: 37.60%
```

I also confirmed this does not cause a regression in memory usage or TPS with tensorwise scaling:

**float8 tensorwise WITH storing abs() output:**

```
[rank0]:2025-02-04 15:06:53,589 - root - INFO - step: 10 loss: 9.8912 memory: 47.77GiB(50.28%) tps: 6,881 mfu: 40.30%
[rank0]:2025-02-04 15:07:05,499 - root - INFO - step: 20 loss: 8.4239 memory: 47.77GiB(50.28%) tps: 6,879 mfu: 40.28%
[rank0]:2025-02-04 15:07:17,412 - root - INFO - step: 30 loss: 7.6615 memory: 47.77GiB(50.28%) tps: 6,877 mfu: 40.27%
```

**float8 tensorwise WITHOUT storing abs() output:**

```
[rank0]:2025-02-04 15:08:32,527 - root - INFO - step: 10 loss: 9.9628 memory: 47.77GiB(50.28%) tps: 6,865 mfu: 40.20%
[rank0]:2025-02-04 15:08:44,450 - root - INFO - step: 20 loss: 8.5451 memory: 47.77GiB(50.28%) tps: 6,871 mfu: 40.24%
[rank0]:2025-02-04 15:08:56,383 - root - INFO - step: 30 loss: 7.8286 memory: 47.77GiB(50.28%) tps: 6,865 mfu: 40.20%
```
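To make the RCA concrete, here is a small illustrative sketch (not code from this repo) of why saving the output of `abs()` is so expensive for float8 row-wise scaling: the scaling factor only needs the per-row absolute maximum, a tiny vector, whereas `abs(W3)` is as large as the full unsharded weight. The weight shape and the `E4M3_MAX` constant name are illustrative assumptions.

```python
import torch

# Illustrative constant: maximum representable magnitude of float8 e4m3.
E4M3_MAX = 448.0

# Hypothetical unsharded FFN weight; the shape is for illustration only.
W3 = torch.randn(8192, 8192, dtype=torch.bfloat16)

# Row-wise scaling only needs one absolute maximum per row.
row_amax = W3.abs().amax(dim=1, keepdim=True)          # shape [8192, 1] -- tiny
scale = E4M3_MAX / row_amax.float().clamp(min=1e-12)   # per-row scaling factors

# Keeping the max op's output for backward stores ~8K values; keeping the
# abs op's output would store the full 8192 x 8192 intermediate, which is
# what caused the peak memory regression described above.
```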
1 parent 690f299 commit 8824727

File tree

1 file changed (+2, -2 lines)


torchtitan/parallelisms/parallelize_llama.py

Lines changed: 2 additions & 2 deletions
```diff
@@ -217,8 +217,8 @@ def apply_tp(
         torch.ops.aten._scaled_dot_product_flash_attention.default,
         torch.ops._c10d_functional.reduce_scatter_tensor.default,
         # for low precision training, it's useful to always save
-        # the result of max(abs(tensor))
-        torch.ops.aten.abs.default,
+        # the result of max, since the absolute maximum is
+        # used to compute the scaling factor for quantization.
         torch.ops.aten.max.default,
     }

```
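For context, below is a minimal sketch of how a save list like the one in this diff can drive per-op selective activation checkpointing. This is not the torchtitan implementation: the `_policy` and `checkpointed_forward` names are hypothetical, and it assumes PyTorch >= 2.4, where `CheckpointPolicy` and `create_selective_checkpoint_contexts` are available in `torch.utils.checkpoint`.

```python
import torch
from torch.utils.checkpoint import (
    CheckpointPolicy,
    checkpoint,
    create_selective_checkpoint_contexts,
)

_save_list = {
    torch.ops.aten.mm.default,
    # Save max (needed for the fp8 scaling factor) but not abs: abs(W) is as
    # large as the weight itself and is cheap to recompute in backward.
    torch.ops.aten.max.default,
}


def _policy(ctx, func, *args, **kwargs):
    # Outputs of ops in the save list are kept for backward; everything else
    # is recomputed during the backward pass.
    if func in _save_list:
        return CheckpointPolicy.MUST_SAVE
    return CheckpointPolicy.PREFER_RECOMPUTE


def checkpointed_forward(block, x):
    # Hypothetical wrapper: `block` is any nn.Module / callable, `x` its input.
    context_fn = lambda: create_selective_checkpoint_contexts(_policy)
    return checkpoint(block, x, use_reentrant=False, context_fn=context_fn)
```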
