add configuration for float8 with rowwise scaling, via recipe lookup #808
Conversation
Force-pushed 3ebdf05 to 8ffb32c.
I'm looking into the unexpectedly high peak memory. As @vkuzo mentioned, it is partially mitigated after applying pytorch/ao#1629, but still at ~58GB while we expect around ~47GB.

Experiment (rowwise float8 + compile) with pytorch/ao#1629 applied:

bf16 + compile command:

row-wise fp8 + compile command:

This reproduced the peak memory discrepancy described in the PR. Looking at the memory snapshots side by side, one thing that stands out immediately is a big spike in memory at the start of the backward pass, when some buffers are allocated for logsoftmax backward. However, since this spike exists for both runs, I think it is expected.

Next, I can see the peak memory of the forward pass, before backward even begins, is much higher for fp8 + compile than bf16 + compile. The bf16 forward pass seems to peak around 32GB, vs the float8 forward pass which seems to peak around 43GB. This difference aligns with the roughly 11GB difference in peak memory we are investigating (47GB vs 58GB).

Looking closer, the discrepancy seems to be related to activations, not the model parameters or optimizer states, because at the beginning of each forward pass, before anything is allocated for activations, both bf16 and fp8 have about 12GB of model params and optimizer state allocated. To confirm the discrepancy is in the activations, as a next step I'll use the memory profiler to categorize the allocations.
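For reference, a minimal sketch of the memory-snapshot workflow referenced above (not torchtitan's exact code), assuming the standard `torch.cuda.memory` snapshot APIs:

```python
import torch

# Start recording allocator events (stack traces are captured by default).
torch.cuda.memory._record_memory_history(max_entries=100_000)

# ... run a few training steps of the configuration under investigation here ...

# Dump a snapshot that can be inspected at https://pytorch.org/memory_viz to
# categorize allocations (params / optimizer state vs activations).
torch.cuda.memory._dump_snapshot("fp8_rowwise_compile_snapshot.pickle")
torch.cuda.memory._record_memory_history(enabled=None)  # stop recording
```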
Side note, I think I just discovered another bug with fp8 row-wise scaling + torch.compile (cc @vkuzo): when NGPU=2 or more, things run smoothly; when NGPU=1, we get the following error:

The error is thrown here, which looks like where _scaled_mm dispatches to cuBLAS. Perhaps when NGPU=2+ torch.compile is using an autogenerated Triton kernel for the fp8 gemm, but when NGPU=1 it's using cuBLAS, which we have this guard in place for?
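A possible starting point for isolating this outside of torchtitan is calling `torch._scaled_mm` directly with rowwise scales in eager mode. This is a hedged sketch only: the shapes and dtypes are illustrative, it is not the confirmed failing call, and it assumes an H100-class GPU with a recent PyTorch that supports rowwise scales in `_scaled_mm`.

```python
import torch

M, K, N = 16, 4096, 16384
a = torch.randn(M, K, device="cuda").to(torch.float8_e4m3fn)        # row-major first operand
b = torch.randn(N, K, device="cuda").to(torch.float8_e4m3fn).t()    # column-major second operand
scale_a = torch.rand(M, 1, device="cuda", dtype=torch.float32)      # one scale per row of a
scale_b = torch.rand(1, N, device="cuda", dtype=torch.float32)      # one scale per column of b

# Direct eager call to the scaled-mm op, i.e. the non-Triton path discussed above.
out = torch._scaled_mm(a, b, scale_a=scale_a, scale_b=scale_b, out_dtype=torch.bfloat16)
```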
I did a quick round of experiments with the debug model with greatly reduced size (dim=32, n_heads=16, layers=32), comparing bf16+compile and fp8+compile for all AC settings (none, full, selective per layer, selective per op), and found the ONLY configuration where peak memory was higher for fp8 than bf16 was selective per op (see table below). To me this indicates the issue may be related to #786.
Ran experiments for a much smaller model (2 layers, dim 16, 1 head) and found memory usage of 0.58GB vs 0.59GB for bf16 vs fp8, so I increased the model size (32 layers, dim 4096, 16 heads) and found the memory discrepancy grew (73GB vs 93GB).

Looking at the memory snapshots for the tiny model, I found a discrepancy in allocations during the beginning of the forward pass between bf16 + compile and fp8 + compile, so I looked at the stack trace for each of these tensors. On the left we see 3 empty_strided_cuda calls in the Triton kernel generated by inductor for bf16. As you can see, even with 2 layers there is already a memory discrepancy in these initial allocations at the beginning of each forward pass, after the initial series of empty_strided_cuda calls, right under where the giant block above starts (71MB vs 72MB). I'm wondering if the size of the discrepancy will grow/compound with larger models.
My guess on what this is:
Did some quick additional experiments comparing eager vs compile and confirmed the discrepancy applies to both, further narrowing down the issue to an interaction between selective per op AC and fp8 quantization. Currently these tests use row-wise scaling, but given we saw a similar issue with tensor-wise scaling in #786, I'm inclined to think it isn't specific to row-wise scaling. I will confirm this with additional tests tomorrow.
per-op SAC, layers==2, dim==1024, n_heads==1 is an interesting repro. Does this combination also reproduce with tensorwise scaling? How about layers==1?
can we repro this outside of torchtitan, to narrow it down further? https://github.com/pytorch/ao/blob/main/benchmarks/float8/profile_linear_float8.py has some pre-canned launcher scripts you can quickly try on layers of various complexity (linear, norm -> linear -> activation, FFN, etc)
although that script does not yet integrate with per-op SAC
nm, 1 transformer block is already pretty small. At this point maybe we can jump into the logs and see if anything stands out?
I wrote a script to produce a minimal repro (with exported memory snapshot) using different configurations (N linear layers, FFN, etc.), optionally applying fsdp/compile/float8 rowwise quantization: #815. So far with 1 GPU:

Per op AC discards every 2nd matmul activation, but stores all outputs from these other ops. The 2 ops specific to low precision training seem to be
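For context, a rough sketch of what a per-op SAC policy with such a save list can look like, using PyTorch's selective checkpoint API. This is a simplification: torchtitan's real policy also alternates which mm outputs it saves, and the op list here is illustrative, with `abs` and `max` being the low-precision-specific entries discussed later in this thread.

```python
import torch
from torch.utils.checkpoint import (
    CheckpointPolicy,
    checkpoint,
    create_selective_checkpoint_contexts,
)

_save_list = {
    torch.ops.aten.mm.default,
    torch.ops.aten._scaled_mm.default,
    torch.ops.aten.max.default,   # amax used for float8 scaling
    torch.ops.aten.abs.default,   # abs used before amax for float8 scaling
}

def _policy(ctx, op, *args, **kwargs):
    # Save outputs of ops in the list; recompute everything else during backward.
    if op in _save_list:
        return CheckpointPolicy.MUST_SAVE
    return CheckpointPolicy.PREFER_RECOMPUTE

def sac_forward(block, *inputs):
    return checkpoint(
        block,
        *inputs,
        use_reentrant=False,
        context_fn=lambda: create_selective_checkpoint_contexts(_policy),
    )
```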
I documented tests doing all permutations of the following configurations using a repro script I wrote (#815); a sketch of the sweep follows below:
- AC
- Compile
- Test types
- Distributed settings
- Model types
- Quantization
For ALL tests, the float8 run had strictly less peak memory than the bf16 run, EXCEPT:
The attention layer increase is negligible, so I think the problem is the FFN, which is the most minimal layer that repros.
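A hypothetical sketch of the permutation sweep described above; the dimension values and the `run_repro` helper are illustrative, not the actual #815 script.

```python
import itertools

ac_modes = ["none", "full", "selective_per_layer", "selective_per_op"]
compile_modes = [False, True]
distributed = ["single_gpu", "fsdp"]
model_types = ["linear", "ffn", "attention", "transformer_block"]
quantization = ["bf16", "float8_rowwise"]

for ac, use_compile, dist, model, quant in itertools.product(
    ac_modes, compile_modes, distributed, model_types, quantization
):
    # Hypothetical helper: builds the model, runs a few steps, records peak memory.
    run_repro(ac=ac, use_compile=use_compile, dist=dist, model=model, quant=quant)
```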
AOT graphs for FFN + bf16 + fsdp + compile:

The fp8 AOT graph logs are way longer (~600 lines vs ~150 lines) and it's hard to tell what could be problematic; it is just a long list of ops performed in each fwd/bwd pass. Naively, the 4x increase in operations could indicate that more operator fusions were performed for the bf16 run than for float8, which would lead to more activations being stored.
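For anyone following along, the AOT graphs and generated Triton kernels being compared here can be dumped with the standard `torch._logging` switches; a small self-contained sketch (the toy FFN below is just for illustration):

```python
import torch
import torch.nn as nn

# Print the joint/forward/backward AOT graphs and the inductor-generated Triton code.
torch._logging.set_logs(aot_graphs=True, output_code=True)

ffn = nn.Sequential(nn.Linear(64, 256), nn.SiLU(), nn.Linear(256, 64)).cuda().bfloat16()
x = torch.randn(16, 64, device="cuda", dtype=torch.bfloat16, requires_grad=True)

torch.compile(ffn)(x).sum().backward()  # triggers compilation of both forward and backward
```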
The 1.11x memory increase in the FFN isn't as high as the 1.23x increase in the full Llama3 8b model, so I ran a test with 10 transformer blocks instead of just 1, to see if the discrepancy increases with model size, and it did not (18035.55 MB bf16 vs 19991.45 MB fp8 row-wise = ~1.11x increase).
I analyzed the FFN + FSDP2 Triton kernels for the bf16 and fp8 forward passes and annotated them, tracking the total current cumulative size of buffers as they are allocated/deallocated. I also analyzed the total size of the buffers returned, which I assume includes what is saved for backward.

Total size of buffers returned:

Peak memory utilized by buffers during FFN forward:

Next I plan to analyze the fp8 kernel to determine exactly what these huge tensors are that we are saving for backward.
I analyzed the forward kernels for bf16 and fp8 row-wise w/ FSDP, and I think the culprit is the buffer holding `abs(W3)`:

```python
# RETURNED FROM FP8 rowwise forward
return (
    reinterpret_tensor(
        buf21, (1, 16, 4096), (65536, 4096, 1), 0
    ),  # buf21:bf16=(16, 4096) => small (FFN output)
    primals_1,
    primals_3,
    primals_4,
    buf0,  # buf0:bf16=(1, 16, 4096) => abs(input) => small
    buf3,  # buf3:fp32=(16, 1) => rowwise scales for inputs
    buf7,  # buf7:bf16=(16, 16384) => W1(x) => small
    buf8,  # buf8:bf16=(4096, 16384) => abs(W3) => huge (134,217,728 bytes)
)
```

buf8 has shape (4096, 16384), so with dtype bf16 it is 134,217,728 bytes, which is 99.5% of the total memory of the extra buffers returned to save for backward (buf0, buf3, buf7, buf8). This is in contrast to the bf16 kernel, which does NOT return/store abs(W3):

```python
# RETURNED FROM BF16 forward
return (
    reinterpret_tensor(
        buf3, (1, 16, 4096), (65536, 4096, 1), 0
    ),  # FFN output => (1,16,4096) in bf16 = 131,072 bytes
    primals_1,
    primals_3,
    primals_4,
    buf0,  # W1(x) => (16,16384) in bf16 = 524,288 bytes
)
```

This seems to be what is causing the memory usage discrepancy; however, I don't yet know why the autograd engine / inductor decides this needs to be saved for fp8 row-wise + FSDP.
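As a quick sanity check, the 99.5% figure follows directly from the buffer shapes and dtypes annotated above:

```python
buf0 = 1 * 16 * 4096 * 2   # abs(input), bf16      ->     131,072 bytes
buf3 = 16 * 1 * 4          # rowwise scales, fp32  ->          64 bytes
buf7 = 16 * 16384 * 2      # W1(x), bf16           ->     524,288 bytes
buf8 = 4096 * 16384 * 2    # abs(W3), bf16         -> 134,217,728 bytes

print(buf8 / (buf0 + buf3 + buf7 + buf8))  # ~0.995
```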
Looking at the AOT graphs for the fp8 kernel, specifically the backward graph, we can see abs4 going through amax and other quantization steps here.
In my repro script, commenting out the abs op from the per-op AC save list:

Saving abs op:

Not saving abs op:
Taking this one step further to see if the improvement translates from the FFN to the full transformer model, I tried commenting out the abs op from the torchtitan per op AC save list and rerunning the original experiment from the PR description, and it virtually eliminated the extra peak memory without negatively impacting TPS:

bf16:

fp8 row-wise WITH storing abs() op output:

fp8 row-wise WITHOUT storing abs() op output:

I also discussed this with @soulitzer and he believes the RCA seems correct. He said that when compile is enabled we can only save max() and not abs(), whereas in eager mode saving abs() can make sense since there's no dead code elimination and thus it would need to be recomputed. I think one potential solution would be to conditionally decide whether to save abs() or not, depending on whether compile is enabled. Need to look into how feasible that is.
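A minimal sketch of the "conditionally save abs()" idea, assuming a per-op SAC save list like the one discussed in this thread; the `build_save_list` helper is hypothetical.

```python
import torch

_base_save_list = {
    torch.ops.aten.mm.default,
    torch.ops.aten._scaled_mm.default,
    torch.ops.aten.max.default,  # amax output is tiny and always cheap to save
}

def build_save_list(compile_enabled: bool) -> set:
    save_list = set(_base_save_list)
    if not compile_enabled:
        # In eager there is no dead code elimination, so recomputing abs() in backward is
        # real work; under compile, saving abs() pins a buffer the size of the unsharded
        # weight (abs(W3) above), so it is better left out of the save list.
        save_list.add(torch.ops.aten.abs.default)
    return save_list
```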
awesome analysis! I think me adding

Sure, will do
… increase in peak memory usage for float8 training with row-wise scaling (#820)

This fixes a bug in how per op AC interacts with float8 training. The bug surfaced through an unexpected increase in peak memory usage when integrating row-wise scaling into float8 training in #808. The full RCA can be found in this thread starting [here](#808 (comment)).

## TL;DR

- I [tested](#808 (comment)) different model layers and permutations of configurations and found the minimal repro was an FFN with per op AC + FSDP + compile + float8 row-wise scaling.
- I [analyzed](#808 (comment)) the Triton kernels generated for the compiled FFN forward pass in bf16 vs fp8, tracking peak cumulative memory allocated and what was saved for backward. I found the fp8 forward kernel was saving a huge buffer/tensor for backward which the bf16 kernel was not saving.
- I [tracked](#808 (comment)) how the various buffers were used and determined this huge buffer saved for backward was holding `abs(W3)`, where W3 is an **unsharded weight tensor.**
- I [tested](#808 (comment)) removing `abs()` from the save list for per op AC, and confirmed it resolved the memory issue without impacting TPS.

## Benchmarks

**float8 row-wise WITH storing abs() op output:**
```
[rank0]:2025-02-04 14:13:52,324 - root - INFO - step: 10 loss: 9.9765 memory: 58.55GiB(61.63%) tps: 6,379 mfu: 37.35%
[rank0]:2025-02-04 14:14:05,145 - root - INFO - step: 20 loss: 8.3610 memory: 58.55GiB(61.63%) tps: 6,390 mfu: 37.42%
[rank0]:2025-02-04 14:14:17,975 - root - INFO - step: 30 loss: 7.6860 memory: 58.55GiB(61.63%) tps: 6,386 mfu: 37.39%
```

**float8 row-wise WITHOUT storing abs() op output:**
```
[rank0]:2025-02-04 14:19:22,446 - root - INFO - step: 10 loss: 10.1946 memory: 47.77GiB(50.28%) tps: 6,293 mfu: 36.85%
[rank0]:2025-02-04 14:19:35,190 - root - INFO - step: 20 loss: 8.4865 memory: 47.77GiB(50.28%) tps: 6,429 mfu: 37.64%
[rank0]:2025-02-04 14:19:47,951 - root - INFO - step: 30 loss: 7.6982 memory: 47.77GiB(50.28%) tps: 6,420 mfu: 37.60%
```

I also confirmed this does not cause a regression in memory usage or TPS with tensorwise scaling:

**float8 tensorwise WITH storing abs() output:**
```
[rank0]:2025-02-04 15:06:53,589 - root - INFO - step: 10 loss: 9.8912 memory: 47.77GiB(50.28%) tps: 6,881 mfu: 40.30%
[rank0]:2025-02-04 15:07:05,499 - root - INFO - step: 20 loss: 8.4239 memory: 47.77GiB(50.28%) tps: 6,879 mfu: 40.28%
[rank0]:2025-02-04 15:07:17,412 - root - INFO - step: 30 loss: 7.6615 memory: 47.77GiB(50.28%) tps: 6,877 mfu: 40.27%
```

**float8 tensorwise WITHOUT storing abs() output:**
```
[rank0]:2025-02-04 15:08:32,527 - root - INFO - step: 10 loss: 9.9628 memory: 47.77GiB(50.28%) tps: 6,865 mfu: 40.20%
[rank0]:2025-02-04 15:08:44,450 - root - INFO - step: 20 loss: 8.5451 memory: 47.77GiB(50.28%) tps: 6,871 mfu: 40.24%
[rank0]:2025-02-04 15:08:56,383 - root - INFO - step: 30 loss: 7.8286 memory: 47.77GiB(50.28%) tps: 6,865 mfu: 40.20%
```
Force-pushed 8ffb32c to b08786a.
In general, the feature should be ready or almost ready (put under the
I'd like to understand better what you need beyond 1 for this change. Is it because:
makes sense, thank you! There are two features in this PR:

I think a good way to go with this PR is

If we need this landed before (b) for some reason, we can chat and figure it out then.
That's very thoughtful! Looking forward to the API!
Summary:

1. add a test for toy model + TP + float8 rowwise scaling training
2. fix underlying issues to make the test pass:
   a. add fast path for tensor view where the new shape is the same as old shape, for rowwise scaled float8 (this is needed for DTensor)
   b. modify the fake grad dependency workaround to work when grad is a DTensor

Test Plan:

1. ./test/float8/test_everything.sh (one transient failure: https://www.internalfb.com/phabricator/paste/view/P1733103301)
2. verified that float8 rowwise scaling behaves sanely in torchtitan on LLaMa 3 8B on 8 H100s, with tp 2:

```
// requires pytorch/torchtitan#808

// baseline - bfloat16 + compile + tp 2
> with-proxy CONFIG_FILE="./train_configs/llama3_8b.toml" ./run_llama_train.sh --training.tensor_parallel_degree 2 --training.compile
[rank0]:2025-02-14 13:41:16,175 - root - INFO - step: 40 loss: 7.4240 memory: 35.56GiB(37.43%) tps: 1,669 mfu: 9.77%

// float8 baseline - float8 tensorwise + compile + tp 2
> with-proxy CONFIG_FILE="./train_configs/llama3_8b.toml" ./run_llama_train.sh --float8.enable_float8_linear --training.tensor_parallel_degree 2 --training.compile
[rank0]:2025-02-14 13:44:07,806 - root - INFO - step: 40 loss: 7.4993 memory: 35.57GiB(37.44%) tps: 2,141 mfu: 12.54%

// float8 rowwise without zero fake dep (for sanity) + compile + tp 2
> with-proxy CONFIG_FILE="./train_configs/llama3_8b.toml" ./run_llama_train.sh --float8.enable_float8_linear --training.tensor_parallel_degree 2 --training.compile --float8.recipe_name all_axiswise
[rank0]:2025-02-14 13:47:51,400 - root - INFO - step: 40 loss: 7.3472 memory: 35.55GiB(37.42%) tps: 1,858 mfu: 10.88%

// float8 rowwise + compile + tp 2
> with-proxy CONFIG_FILE="./train_configs/llama3_8b.toml" ./run_llama_train.sh --float8.enable_float8_linear --training.tensor_parallel_degree 2 --training.compile --float8.recipe_name all_axiswise
[rank0]:2025-02-14 13:51:20,864 - root - INFO - step: 40 loss: 9.4211 memory: 35.55GiB(37.42%) tps: 1,820 mfu: 10.66%
```

Reviewers:

Subscribers:

Tasks:

Tags:
Force-pushed fc86c6f to 8c41cba.
Force-pushed 8c41cba to 8edf33b.
@tianyu-l we cleaned up the recipe lookup in torchao and I updated this PR to match; would love your review again.
    "rowwise",
    "rowwise_with_gw_hp",
)
enable_float8_tp = enable_float8_linear and not float8_is_rowwise
tianyu-l left a comment:
Had some comments. Could you also rebase to main? We recently moved float8.py to torchtitan/components/float8.py
float8_is_rowwise = job_config.float8.recipe_name in (
    "rowwise",
    "rowwise_with_gw_hp",
)
Since they are the only two valid options, can we do an assertion instead of using `in`, unless you plan to add other options soon?
could you explain what you mean by that? not sure I follow; `foo in ("a", "b")` seems like a clean way to enforce that foo is either a or b, to me
oh I was saying does it make sense to do
float8_is_rowwise = job_config.float8.recipe_name is not None
# and if you don't use `choices` to guard, we should
assert job_config.float8.recipe_name in ("rowwise", "rowwise_with_gw_hp")
if we only support two rowwise recipes.
NVM if torchao plans to support other recipes soon.
ah, got it. Currently this PR is written as "support all recipes, including future ones", looks like you are asking for "only support recipes which have been tested in torchtitan". I agree with that feedback, let me change this PR to accommodate.
float8_is_rowwise = job_config.float8.recipe_name is not None
I would prefer not to do this though, as this forces the code reader to go look at "which recipes are valid" to understand what's going on.
assert job_config.float8.recipe_name in ("rowwise", "rowwise_with_gw_hp")
I like this, as it displays the valid recipes here without making the code reader go look for them.
torchtitan/config_manager.py (Outdated)
"--float8.recipe_name",
type=str,
default=None,
help="If specified, creates float8 config from recipe name, valid choices are `rowwise` and `rowwise_with_gw_hp`",
can we add something like "if a valid recipe is specified, other float8 configs will be ignored"
    "rowwise",
    "rowwise_with_gw_hp",
)
enable_float8_tp = enable_float8_linear and not float8_is_rowwise
Could you share more on what happens / should happen when TP and rowwise scaling are used together. It looks from the change that the communication would still be in bf16, but matmul will be in rowwise fp8?
Force-pushed 7a04053 to abbc6f1.
tianyu-l left a comment:
LGTM. Please address inline comments before merging.
one question I have:
What changes if we use rowwise instead of tensorwise -- do FSDP param communications and TP activation communications happen in BF16? Is this because FP8 communication in those scenarios is temporarily not supported, or are there fundamental limitations?
assert (
    not float8_config.force_recompute_fp8_weight_in_bwd
), "using `float8_config.force_recompute_fp8_weight_in_bwd` together with `float8_config.recipe_name` is not supported"
self.config = Float8LinearConfig.from_recipe_name(float8_config.recipe_name)
what happens if recipe_name is not in ("rowwise", "rowwise_with_gw_hp")?
shall we use choices in config_manager.py to guard, similar to https://github.com/pytorch/torchtitan/blob/main/torchtitan/config_manager.py#L191
we will throw a detailed error message from torchao (https://github.com/pytorch/ao/blob/1ab1b77ad744115c3fee62716e7d2083c57d80a1/torchao/float8/config.py#L301) with valid choices displayed in the error message - this code is getting that for free
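For reference, a hedged sketch of the torchao lookup being discussed (API names as referenced in this PR; `model` is assumed to be the `nn.Module` being converted, and the set of valid recipe strings may evolve in torchao):

```python
from torchao.float8 import Float8LinearConfig, convert_to_float8_training

# Raises a torchao error listing the valid recipe names if the string is invalid.
config = Float8LinearConfig.from_recipe_name("rowwise")
convert_to_float8_training(model, config=config)
```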
never mind, moved to "choices" like you suggested
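After moving to `choices`, the flag might look roughly like this (a sketch based on the argparse snippet quoted above, with the help-text addition requested earlier; `parser` is assumed to be the `ArgumentParser` in config_manager.py):

```python
parser.add_argument(
    "--float8.recipe_name",
    type=str,
    default=None,
    choices=["rowwise", "rowwise_with_gw_hp"],
    help="If specified, creates float8 config from recipe name; "
    "if a valid recipe is specified, other float8 configs will be ignored.",
)
```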
Force-pushed abbc6f1 to 9773de6.
Summary: This is an example of how to call float8 training with rowwise scaling from torchao.

TODO: finalize API in torchao, and finalize how we want to expose it in torchtitan, and optimize performance.

```
// baseline (bf16 + compile)
> with-proxy CONFIG_FILE="./train_configs/llama3_8b.toml" ./run_llama_train.sh --training.compile
...
step: 20 loss: 8.4931 memory: 47.65GiB(50.16%) tps: 5,760 mfu: 33.73%

// experiment (rowwise float8 + compile)
> with-proxy CONFIG_FILE="./train_configs/llama3_8b.toml" ./run_llama_train.sh --float8.enable_float8_linear --training.compile
...
// torchao main branch
step: 40 loss: 7.3818 memory: 66.81GiB(70.33%) tps: 6,412 mfu: 37.55%
// torchao with pytorch/ao#1629
step: 20 loss: 8.3823 memory: 58.55GiB(61.63%) tps: 6,424 mfu: 37.62%

// for comparison, tensorwise float8 with float8 all-gather (on main branch)
with-proxy CONFIG_FILE="./train_configs/llama3_8b.toml" ./run_llama_train.sh --float8.enable_float8_linear --training.compile --float8.enable_fsdp_float8_all_gather --float8.precompute_float8_dynamic_scale_for_fsdp
...
step: 20 loss: 8.4258 memory: 47.32GiB(49.81%) tps: 7,186 mfu: 42.08%
```

Test Plan:

Reviewers:

Subscribers:

Tasks:

Tags:
Force-pushed 9773de6 to 3537666.
It's a fundamental limitation of rowwise. This is not relevant for tensorwise (the same scale can be applied on dim0 or dim1), so we can do float8 all-gather easily with a single scale + a single float8 raw data tensor. With rowwise scaling, the benefit is a lot less clear, as we'd need to send both dim0- and dim1-scaled float8 tensors, which is actually more bytes (including scales) than sending the bfloat16 tensor and casting to float8 after the all-gather. There could be some smaller wins here, but overall it's just not a clear and easy win with rowwise as it is with tensorwise.
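A back-of-the-envelope byte count makes the comparison concrete (the shard shape is illustrative; one fp32 scale per tensor is assumed for tensorwise, one fp32 scale per row/column for rowwise):

```python
K, N = 4096, 16384

bf16_bytes     = K * N * 2                                  # what gets all-gathered today
tensorwise_fp8 = K * N * 1 + 4                              # float8 data + a single scale
rowwise_fp8    = (K * N * 1 + K * 4) + (K * N * 1 + N * 4)  # dim0- and dim1-scaled copies + scale vectors

print(bf16_bytes, tensorwise_fp8, rowwise_fp8)  # rowwise ends up slightly larger than bf16
```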
…ytorch#808)

Summary:

Exposes the float8 config from recipe name API from torchao, so we can use it to easily configure float8 with rowwise scaling from torchtitan.

Usage:

```
with-proxy CONFIG_FILE="./torchtitan/models/llama/train_configs/llama3_8b.toml" ./run_train.sh \
  --model.converters float8 \
  --training.compile \
  --float8.recipe_name rowwise
```

Example, pretraining LLaMa 3 8B on 8 H100s:

```
// baseline (bf16 + compile)
> with-proxy CONFIG_FILE="./torchtitan/models/llama/train_configs/llama3_8b.toml" ./run_train.sh --training.compile
...
step: 20 loss: 8.4931 memory: 47.65GiB(50.16%) tps: 5,760 mfu: 33.73%

// experiment (rowwise float8 + compile)
> with-proxy CONFIG_FILE="./torchtitan/models/llama/train_configs/llama3_8b.toml" ./run_train.sh --model.converters float8 --training.compile --float8.recipe_name rowwise
...
step: 30 loss: 7.7109 memory: 47.77GiB(50.28%) tps: 6,468 mfu: 37.88%

// for comparison, tensorwise float8 with float8 all-gather (on main branch)
with-proxy CONFIG_FILE="./torchtitan/models/llama/train_configs/llama3_8b.toml" ./run_train.sh --model.converters float8 --training.compile --float8.enable_fsdp_float8_all_gather --float8.precompute_float8_dynamic_scale_for_fsdp
...
step: 20 loss: 8.4258 memory: 47.32GiB(49.81%) tps: 7,186 mfu: 42.08%
```

Test Plan:

Reviewers:

Subscribers:

Tasks:

Tags:






