add configuration for float8 with rowwise scaling, via recipe lookup #808
base: main
Conversation
Force-pushed from 3ebdf05 to 8ffb32c
I'm looking into the unexpectedly high peak memory. As @vkuzo mentioned, it is partially mitigated after applying pytorch/ao#1629, but still at ~58GB while we expect around ~47GB.

Experiment (rowwise float8 + compile) with pytorch/ao#1629 applied:

bf16 + compile command:
row-wise fp8 + compile command:
This reproduced the peak memory discrepancy described in the PR. Looking at the memory snapshots side by side, one thing that stands out immediately is a big spike in memory at the start of the backward pass, when some buffers are allocated for logsoftmax backward. However, since this spike exists for both runs, I think it is expected.

Next, I can see the peak memory of the forward pass, before backward even begins, is much higher for fp8 + compile than for bf16 + compile. The bf16 forward pass seems to peak around 32GB, while the float8 forward pass seems to peak around 43GB. This difference aligns with the roughly 11GB difference in peak memory we are investigating (47GB vs 58GB).

Looking closer, the discrepancy seems to be related to activations, not the model parameters or optimizer states. This is because at the beginning of each forward pass, before anything is allocated for activations, both bf16 and fp8 have about 12GB of model params and optimizer state allocated.

To confirm that the discrepancy is in the activations, as a next step I'll use the memory profiler to categorize the allocations.
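For reference, this is roughly how such a snapshot can be captured and exported for the memory viewer — a minimal sketch using PyTorch's allocator-history hooks, not torchtitan's actual integration; `train_step` is a stand-in for one forward/backward/optimizer step:

```python
import torch

# Start recording allocator events (stack traces are included by default).
torch.cuda.memory._record_memory_history(max_entries=100_000)

for _ in range(3):
    train_step()  # assumed: one fwd/bwd/optimizer step of the model under test

# Dump a snapshot that can be inspected/categorized at https://pytorch.org/memory_viz
torch.cuda.memory._dump_snapshot("fp8_rowwise_compile_snapshot.pickle")

# Stop recording.
torch.cuda.memory._record_memory_history(enabled=None)
```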
Side note: I think I just discovered another bug with fp8 row-wise scaling + torch.compile (cc @vkuzo). When NGPU=2 or more, things run smoothly.
When NGPU=1, we get the following error:
The error is thrown here, which looks like where _scaled_mm dispatches to cuBLAS. Perhaps when NGPU=2+, torch.compile is using an autogenerated Triton kernel for the fp8 gemm, but when NGPU=1 it's using cuBLAS, which we have this guard in place for?
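For context, here is a hypothetical standalone sketch (not taken from this thread) of the rowwise-scaled GEMM going directly through torch._scaled_mm, i.e. the cuBLAS path that the guard protects; the shapes and the quantization below are illustrative:

```python
import torch

M, K, N = 16, 4096, 16384
E4M3_MAX = torch.finfo(torch.float8_e4m3fn).max

a = torch.randn(M, K, device="cuda", dtype=torch.bfloat16)
b = torch.randn(N, K, device="cuda", dtype=torch.bfloat16)

# one (dequantization) scale per row of each operand
a_scale = (a.abs().amax(dim=1, keepdim=True).float() / E4M3_MAX).clamp(min=1e-12)
b_scale = (b.abs().amax(dim=1, keepdim=True).float() / E4M3_MAX).clamp(min=1e-12)

a_fp8 = (a / a_scale).to(torch.float8_e4m3fn)       # (M, K), row-major
b_fp8 = (b / b_scale).to(torch.float8_e4m3fn).t()   # (K, N), column-major view

out = torch._scaled_mm(
    a_fp8,
    b_fp8,
    scale_a=a_scale,                  # (M, 1)
    scale_b=b_scale.t().contiguous(), # (1, N)
    out_dtype=torch.bfloat16,
)
```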
I did a quick round of experiments with the debug model with greatly reduced size (dim=32, n_heads=16, layers=32), comparing bf16+compile and fp8+compile for all AC settings (none, full, selective per layer, selective per op), and found the ONLY configuration where peak memory was higher for fp8 than bf16 was selective per-op (see table below). To me this indicates this issue may be related to #786.
Ran experiments for a much smaller model (2 layers, dim 16, 1 head) and found a small memory discrepancy (0.58GB bf16 vs 0.59GB fp8), so I increased the model size (32 layers, dim 4096, 16 heads) and found the discrepancy grew (73GB vs 93GB).

Looking at the memory snapshots for the tiny model, I found a discrepancy in allocations during the beginning of the forward pass between bf16 + compile and fp8 + compile, so I looked at the stack trace for each of these tensors. On the left we see 3 empty_strided_cuda calls in the Triton kernel generated by inductor for bf16. As you can see, even with 2 layers there is already a memory discrepancy in these initial allocations at the beginning of each forward pass, after the initial series of empty_strided_cuda calls, right under where the giant block above starts (71MB vs 72MB). I'm wondering if the size of the discrepancy will grow/compound with larger models.
My guess on what this is:
Did some quick additional experiments comparing eager vs compile and confirmed the discrepancy applies to both, further narrowing down the issue to an interaction between selective per-op AC and fp8 quantization. Currently these tests are using row-wise scaling, but given we saw a similar issue with tensor-wise scaling in #786, I'm inclined to think it isn't specific to row-wise scaling. I will confirm this with additional tests tomorrow.
per-op SAC, layers==2, dim==1024, n_heads==1 is an interesting repro. Does this combination also reproduce with tensorwise scaling? How about layers==1?
can we repro this outside of torchtitan, to narrow it down further? https://github.com/pytorch/ao/blob/main/benchmarks/float8/profile_linear_float8.py has some pre-canned launcher scripts you can quickly try on layers of various complexity (linear, norm -> linear -> activation, FFN, etc)
although that script does not yet integrate with per-op SAC
nm, 1 transformer block is already pretty small. At this point maybe we can jump into the logs and see if anything stands out?
I wrote a script to produce a minimal repro (with exported memory snapshot) using different configurations (N linear layers, FFN, etc), optionally applying fsdp/compile/float8 rowwise quantization: #815

So far with 1 GPU:
Per-op AC discards every 2nd matmul activation but stores all outputs from these other ops. The two ops specific to low-precision training seem to be
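For readers unfamiliar with per-op selective AC, here is a minimal sketch of how a save-list-based policy works with torch.utils.checkpoint. The ops listed are illustrative, not torchtitan's exact save list (which also only keeps every other mm output):

```python
from functools import partial

import torch
from torch.utils.checkpoint import (
    CheckpointPolicy,
    checkpoint,
    create_selective_checkpoint_contexts,
)

# Outputs of these ops are kept for backward; everything else is recomputed.
_save_list = {
    torch.ops.aten.mm.default,          # matmuls
    torch.ops.aten._scaled_mm.default,  # fp8 gemm
    torch.ops.aten.abs.default,         # the op under suspicion in this thread
}

def _policy_fn(ctx, op, *args, **kwargs):
    return (
        CheckpointPolicy.MUST_SAVE
        if op in _save_list
        else CheckpointPolicy.PREFER_RECOMPUTE
    )

def checkpoint_block(block, *inputs):
    # Wrap one transformer block / FFN with per-op selective AC.
    context_fn = partial(create_selective_checkpoint_contexts, _policy_fn)
    return checkpoint(block, *inputs, use_reentrant=False, context_fn=context_fn)
```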
I documented tests doing all permutations of the following configurations using a repro script I wrote (#815):
- AC
- Compile
- Test types
- Distributed settings
- Model types
- Quantization

For ALL tests, the float8 run had strictly less peak memory than the bf16 run, EXCEPT:
The attention layer increase is negligible, so I think the problem is the FFN, which is the most minimal layer that repros.
AOT graphs for FFN + bf16 + fsdp + compile

The fp8 AOT graph logs are way longer (~600 lines vs ~150 lines), and it's hard to tell what could be problematic; it is just a long list of ops performed in each fwd/bwd pass. Naively, the 4x increase in operations could indicate that more operator fusions were performed for the bf16 run than for float8, which would lead to more activations being stored in the float8 run.
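For reference, the AOT forward/backward graph logs referenced here can be produced with PyTorch's logging knobs, e.g. (a sketch; `ffn` and `x` are stand-ins for the module and input under test):

```python
import torch

# Equivalent to running with TORCH_LOGS="aot_graphs" on the command line.
torch._logging.set_logs(aot_graphs=True)

compiled_ffn = torch.compile(ffn)  # assumed: the FFN module under test
out = compiled_ffn(x)
out.sum().backward()               # forward and backward AOT graphs are logged
```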
The 1.11x memory increase in the FFN isn't as high as the 1.23x increase in the full Llama3 8b model, so I ran a test with 10 transformer blocks instead of just 1 to see if the discrepancy increases with model size, and it did not (18,035.55 MB bf16 vs 19,991.45 MB fp8 row-wise = ~1.11x increase).
I analyzed the FFN + FSDP2 Triton kernels for bf16 and fp8 forward passes, and annotated them, tracking the total current cumulative size of buffers as they are allocated/deallocated. I also analyzed the total size of the buffers returned, which I assume includes what is saved for backward. Total size of buffers returned:
Peak memory utilized by buffers during FFN forward:
Next I plan to analyze the fp8 kernel to determine exactly what these huge tensors are that we are saving for backward.
I analyzed the forward kernels for bf16 and fp8 row-wise w/ FSDP, and I think the culprit is:

```python
# RETURNED FROM FP8 rowwise forward
return (
    reinterpret_tensor(
        buf21, (1, 16, 4096), (65536, 4096, 1), 0
    ),  # buf21:bf16=(16, 4096) => small (FFN output)
    primals_1,
    primals_3,
    primals_4,
    buf0,  # buf0:bf16=(1, 16, 4096) => abs(input) => small
    buf3,  # buf3:fp32=(16, 1) => rowwise scales for inputs
    buf7,  # buf7:bf16=(16, 16384) => W1(x) => small
    buf8,  # buf8:bf16=(4096, 16384) => abs(W3) => huge (134,217,728 bytes)
)
```

buf8 has shape (4096, 16384), so with dtype bf16 it is 4096 × 16384 × 2 bytes = 134,217,728 bytes, which is 99.5% of the total memory of the extra buffers returned to save for backward (buf0, buf3, buf7, buf8). This is in contrast to the bf16 kernel, which does NOT return/store the abs(W3):

```python
# RETURNED FROM BF16 forward
return (
    reinterpret_tensor(
        buf3, (1, 16, 4096), (65536, 4096, 1), 0
    ),  # FFN output => (1,16,4096) in bf16 = 131,072 bytes
    primals_1,
    primals_3,
    primals_4,
    buf0,  # W1(x) => (16,16384) in bf16 = 524,288 bytes
)
```

This seems to be what is causing the memory usage discrepancy; however, I don't yet know why the autograd engine / inductor decides this needs to be saved for fp8 row-wise + FSDP.
Looking at the AOT graphs for the fp8 kernel, we can see the same thing. Specifically, in the backward graph we can see abs4 going through amax and other quantization steps here.
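For context, here is a rough sketch of rowwise float8 quantization (illustrative, not torchao's exact code): the abs() output is only an intermediate on the way to the per-row amax and scale, so saving the full abs tensor for backward costs far more memory than the reduction it feeds.

```python
import torch

def quantize_rowwise(x: torch.Tensor, float8_dtype=torch.float8_e4m3fn):
    # abs() is the intermediate that per-op AC was saving for backward
    x_abs = x.abs()
    amax = x_abs.amax(dim=-1, keepdim=True)            # per-row max of |x|
    scale = torch.finfo(float8_dtype).max / amax.clamp(min=1e-12)
    x_fp8 = (x * scale).clamp(
        min=torch.finfo(float8_dtype).min,
        max=torch.finfo(float8_dtype).max,
    ).to(float8_dtype)
    return x_fp8, scale.reciprocal()                   # fp8 data + per-row descale
```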
In my repro script, commenting out the abs op from the save list gives the following results.

Saving abs op:

Not saving abs op:
Taking this one step further, to see if the improvement translates from the FFN to the full transformer model, I tried commenting out the abs op from torchtitan's per-op AC save list and rerunning the original experiment from the PR description. It virtually eliminated the extra peak memory, without negatively impacting TPS:

bf16:
fp8 row-wise WITH storing abs() op output:
fp8 row-wise WITHOUT storing abs() op output:
I also discussed this with @soulitzer and he believes the RCA seems correct. He said that when compile is enabled we can save only max() and not abs(), whereas in eager mode saving abs() can make sense, since there is no dead code elimination and thus it would need to be recomputed. I think one potential solution would be to conditionally decide whether to save abs() or not, depending on whether compile is enabled. Need to look into how feasible that is.
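A rough sketch of what that conditional could look like (names here are illustrative, not torchtitan's actual config plumbing):

```python
import torch

def build_per_op_ac_save_list(compile_enabled: bool) -> set:
    # Ops whose outputs per-op AC keeps for backward.
    save_list = {
        torch.ops.aten.mm.default,
        torch.ops.aten._scaled_mm.default,
    }
    if not compile_enabled:
        # In eager mode there is no dead code elimination, so abs() would be
        # recomputed during backward; saving it can be worthwhile there.
        save_list.add(torch.ops.aten.abs.default)
    return save_list
```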
awesome analysis! I think me adding
Sure, will do
… increase in peak memory usage for float8 training with row-wise scaling (#820)

This fixes a bug in how per op AC interacts with float8 training. The bug surfaced through an unexpected increase in peak memory usage when integrating row-wise scaling into float8 training in #808. The full RCA can be found in this thread starting [here](#808 (comment)).

## TL;DR

- I [tested](#808 (comment)) different model layers and permutations of configurations and found the minimal repro was a FFN with per op AC + FSDP + compile + float8 row-wise scaling.
- I [analyzed](#808 (comment)) the Triton kernels generated for the compiled FFN forward pass in bf16 vs fp8, tracking peak cumulative memory allocated and what was saved for backward. I found the fp8 forward kernel was saving a huge buffer/tensor for backward which the bf16 kernel was not saving.
- I [tracked](#808 (comment)) how the various buffers were used and determined this huge buffer saved for backward was holding `abs(W3)`, where W3 is an **unsharded weight tensor.**
- I [tested](#808 (comment)) removing `abs()` from the save list for per op AC, and confirmed it resolved the memory issue without impacting TPS.

## Benchmarks

**float8 row-wise WITH storing abs() op output:**

```
[rank0]:2025-02-04 14:13:52,324 - root - INFO - step: 10 loss: 9.9765 memory: 58.55GiB(61.63%) tps: 6,379 mfu: 37.35%
[rank0]:2025-02-04 14:14:05,145 - root - INFO - step: 20 loss: 8.3610 memory: 58.55GiB(61.63%) tps: 6,390 mfu: 37.42%
[rank0]:2025-02-04 14:14:17,975 - root - INFO - step: 30 loss: 7.6860 memory: 58.55GiB(61.63%) tps: 6,386 mfu: 37.39%
```

**float8 row-wise WITHOUT storing abs() op output:**

```
[rank0]:2025-02-04 14:19:22,446 - root - INFO - step: 10 loss: 10.1946 memory: 47.77GiB(50.28%) tps: 6,293 mfu: 36.85%
[rank0]:2025-02-04 14:19:35,190 - root - INFO - step: 20 loss: 8.4865 memory: 47.77GiB(50.28%) tps: 6,429 mfu: 37.64%
[rank0]:2025-02-04 14:19:47,951 - root - INFO - step: 30 loss: 7.6982 memory: 47.77GiB(50.28%) tps: 6,420 mfu: 37.60%
```

I also confirmed this does not cause a regression in memory usage or TPS with tensorwise scaling:

**float8 tensorwise WITH storing abs() output:**

```
[rank0]:2025-02-04 15:06:53,589 - root - INFO - step: 10 loss: 9.8912 memory: 47.77GiB(50.28%) tps: 6,881 mfu: 40.30%
[rank0]:2025-02-04 15:07:05,499 - root - INFO - step: 20 loss: 8.4239 memory: 47.77GiB(50.28%) tps: 6,879 mfu: 40.28%
[rank0]:2025-02-04 15:07:17,412 - root - INFO - step: 30 loss: 7.6615 memory: 47.77GiB(50.28%) tps: 6,877 mfu: 40.27%
```

**float8 tensorwise WITHOUT storing abs() output:**

```
[rank0]:2025-02-04 15:08:32,527 - root - INFO - step: 10 loss: 9.9628 memory: 47.77GiB(50.28%) tps: 6,865 mfu: 40.20%
[rank0]:2025-02-04 15:08:44,450 - root - INFO - step: 20 loss: 8.5451 memory: 47.77GiB(50.28%) tps: 6,871 mfu: 40.24%
[rank0]:2025-02-04 15:08:56,383 - root - INFO - step: 30 loss: 7.8286 memory: 47.77GiB(50.28%) tps: 6,865 mfu: 40.20%
```
Force-pushed from 8ffb32c to b08786a
Summary:

This is an example of how to call float8 training with rowwise scaling from torchao.

TODO: finalize API in torchao, finalize how we want to expose it in torchtitan, and optimize performance.

```
// baseline (bf16 + compile)
> with-proxy CONFIG_FILE="./train_configs/llama3_8b.toml" ./run_llama_train.sh --training.compile
...
step: 20 loss: 8.4931 memory: 47.65GiB(50.16%) tps: 5,760 mfu: 33.73%

// experiment (rowwise float8 + compile)
> with-proxy CONFIG_FILE="./train_configs/llama3_8b.toml" ./run_llama_train.sh --float8.enable_float8_linear --training.compile
...
// torchao main branch
step: 40 loss: 7.3818 memory: 66.81GiB(70.33%) tps: 6,412 mfu: 37.55%
// torchao with pytorch/ao#1629
step: 20 loss: 8.3823 memory: 58.55GiB(61.63%) tps: 6,424 mfu: 37.62%

// for comparison, tensorwise float8 with float8 all-gather (on main branch)
with-proxy CONFIG_FILE="./train_configs/llama3_8b.toml" ./run_llama_train.sh --float8.enable_float8_linear --training.compile --float8.enable_fsdp_float8_all_gather --float8.precompute_float8_dynamic_scale_for_fsdp
...
step: 20 loss: 8.4258 memory: 47.32GiB(49.81%) tps: 7,186 mfu: 42.08%
```

Test Plan:

Reviewers:

Subscribers:

Tasks:

Tags:
Force-pushed from b08786a to fc86c6f
thank you, fixed lint
Could you add more description on the recipes in the PR summary? E.g. what are `all_axiswise` and `lw_axiswise_with_gw_hp`, are they both rowwise scaled but different flavors?

Could you also update the comments here:
https://github.com/pytorch/torchtitan/blob/main/torchtitan/parallelisms/parallelize_llama.py#L145-L147

```python
"--float8.recipe_name",
type=str,
default=None,
help="If specified, creates float8 config from recipe name, choices are `all_axiswise` and `lw_axiswise_with_gw_hp`",
```
could you please:
- add a `choices` arg, similar to https://github.com/pytorch/torchtitan/blob/main/torchtitan/config_manager.py#L292 (see the sketch after this comment)
- add more info on the recipes, including references, e.g. if users are interested, where they can learn more
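A sketch of what the requested change could look like (the `self.parser` attribute is assumed from config_manager.py):

```python
self.parser.add_argument(
    "--float8.recipe_name",
    type=str,
    default=None,
    choices=["all_axiswise", "lw_axiswise_with_gw_hp"],
    help=(
        "If specified, creates float8 config from recipe name. "
        "See torchao.float8 for details on each recipe."
    ),
)
```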
Related: IMO, this naming should be made more intuitive for users. I understand you want an exact breakout within your code, but users won't necessarily grasp it vs. calling it in more direct terms.
```
@@ -42,29 +42,43 @@ def __init__(self, job_config: JobConfig, parallel_dims: ParallelDims):
return
try:
    from torchao.float8 import Float8LinearConfig

# we should update this code after torchao exposes this publically
```
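For context, a sketch of the kind of recipe lookup being discussed, assuming torchao's `Float8LinearRecipeName` enum and `recipe_name_to_linear_config` helper (both still private at the time, per the comment above, so the exact spelling may differ from what this PR uses):

```python
from typing import Optional

from torchao.float8 import Float8LinearConfig
from torchao.float8.config import (
    Float8LinearRecipeName,
    recipe_name_to_linear_config,
)

def float8_config_from_recipe(recipe_name: Optional[str]) -> Float8LinearConfig:
    if recipe_name is None:
        # Fall back to the default (tensorwise) config.
        return Float8LinearConfig()
    # e.g. "all_axiswise" -> Float8LinearRecipeName.ALL_AXISWISE
    recipe = Float8LinearRecipeName(recipe_name)
    return recipe_name_to_linear_config(recipe)
```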
Suggested change:

```diff
-# we should update this code after torchao exposes this publically
+# TODO: we should update this code after torchao exposes this publically
```
to be honest, we don't have good writeups for this, and we hope to land something here to unblock us from running experiments to create this information. Is there a good way to land experimental features here (without such writeups), or is the recommendation to use unlanded code and only land ready-for-wide-usage features?
In general, the feature should be ready or almost ready (put under the
I'd like to understand better what you need beyond 1 for this change. Is it because:
makes sense, thank you! There are two features in this PR:
I think a good way to go with this PR is … If we need this landed before (b) for some reason, we can chat and figure it out then.
That's very thoughtful! Looking forward to the API!
Summary:
Exposes torchao's float8 config-from-recipe-name API, so we can use it to easily configure float8 with rowwise scaling from torchtitan.
Usage:
Example: pretraining Llama 3 8B on 8 H100s:
Test Plan:
Reviewers:
Subscribers:
Tasks:
Tags: