Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Remove abs from save list for per op AC to fix bug causing unexpected…
… increase in peak memory usage for float8 training with row-wise scaling (#820) This fixes a bug in how per op AC interacts with float8 training. The bug surfaced through unexpected increase in peak memory usage when integrating row-wise scaling into float8 training in #808. The full RCA can be in found in this thread starting [here](#808 (comment)). ## TL;DR - I [tested](#808 (comment)) different model layers and permutations of configurations and found the minimal repro was a FFN with per op AC + FSDP + compile + float8 row-wise scaling. - I [analyzed](#808 (comment)) the triton kernels generated for the compiled FFN forward pass in bf16 vs fp8, tracking peak cumulative memory allocated and what was saved for backward. I found the fp8 forward kernel was saving a huge buffer/tensor for backward which the bf16 kernel was not saving. - I [tracked](#808 (comment)) how the various buffers were used and determined this huge buffer saved for backwards was holding `abs(W3)`, where W3 is an **unsharded weight tensor.** - I [tested](#808 (comment)) removing `abs()` from the save list for per op AC, and confirmed it resolved the memory issue without impacting TPS. ## Benchmarks **float8 row-wise WITH storing abs() op output:** ``` [rank0]:2025-02-04 14:13:52,324 - root - INFO - step: 10 loss: 9.9765 memory: 58.55GiB(61.63%) tps: 6,379 mfu: 37.35% [rank0]:2025-02-04 14:14:05,145 - root - INFO - step: 20 loss: 8.3610 memory: 58.55GiB(61.63%) tps: 6,390 mfu: 37.42% [rank0]:2025-02-04 14:14:17,975 - root - INFO - step: 30 loss: 7.6860 memory: 58.55GiB(61.63%) tps: 6,386 mfu: 37.39% ``` **float8 row-wise WITHOUT storing abs() op output:** ``` [rank0]:2025-02-04 14:19:22,446 - root - INFO - step: 10 loss: 10.1946 memory: 47.77GiB(50.28%) tps: 6,293 mfu: 36.85% [rank0]:2025-02-04 14:19:35,190 - root - INFO - step: 20 loss: 8.4865 memory: 47.77GiB(50.28%) tps: 6,429 mfu: 37.64% [rank0]:2025-02-04 14:19:47,951 - root - INFO - step: 30 loss: 7.6982 memory: 47.77GiB(50.28%) tps: 6,420 mfu: 37.60% ``` I also confirmed this also does not cause a regression in memory usage or TPS with tensorwise scaling: **float8 tensorwise WITH storing abs() output:** ``` [rank0]:2025-02-04 15:06:53,589 - root - INFO - step: 10 loss: 9.8912 memory: 47.77GiB(50.28%) tps: 6,881 mfu: 40.30% [rank0]:2025-02-04 15:07:05,499 - root - INFO - step: 20 loss: 8.4239 memory: 47.77GiB(50.28%) tps: 6,879 mfu: 40.28% [rank0]:2025-02-04 15:07:17,412 - root - INFO - step: 30 loss: 7.6615 memory: 47.77GiB(50.28%) tps: 6,877 mfu: 40.27% ``` **float8 tensorwise WITHOUT storing abs() output:** ``` [rank0]:2025-02-04 15:08:32,527 - root - INFO - step: 10 loss: 9.9628 memory: 47.77GiB(50.28%) tps: 6,865 mfu: 40.20% [rank0]:2025-02-04 15:08:44,450 - root - INFO - step: 20 loss: 8.5451 memory: 47.77GiB(50.28%) tps: 6,871 mfu: 40.24% [rank0]:2025-02-04 15:08:56,383 - root - INFO - step: 30 loss: 7.8286 memory: 47.77GiB(50.28%) tps: 6,865 mfu: 40.20% ```
- Loading branch information