
add configuration for float8 with rowwise scaling, via recipe lookup #808

Open

vkuzo wants to merge 1 commit into main from 20250127_float8_rowwise_example

Conversation

@vkuzo (Contributor) commented Jan 27, 2025

Summary:

Exposes torchao's float8 config-from-recipe-name API, so we can use it to easily configure float8 with rowwise scaling from torchtitan.

Usage:

with-proxy CONFIG_FILE="./train_configs/llama3_8b.toml" ./run_llama_train.sh \
    --float8.enable_float8_linear \
    --training.compile \
    --float8.recipe_name all_axiswise
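
For context, below is a minimal sketch of what the recipe lookup might amount to on the torchao side. The lookup names (`Float8LinearRecipeName`, `recipe_name_to_linear_config`) are assumptions based on torchao's prototype API at the time and are not finalized (see the discussion further down); the sketch is illustrative, not the landed implementation.

```
# Sketch only: how torchtitan might turn --float8.recipe_name into a torchao config.
import torch.nn as nn
from torchao.float8 import convert_to_float8_training

def apply_float8_from_recipe(model: nn.Module, recipe_name: str = "all_axiswise") -> nn.Module:
    # prototype lookup; exact import path and names are assumptions
    from torchao.float8.config import (
        Float8LinearRecipeName,
        recipe_name_to_linear_config,
    )
    config = recipe_name_to_linear_config(Float8LinearRecipeName(recipe_name))
    # swap nn.Linear -> Float8Linear, skipping the output projection as torchtitan does
    return convert_to_float8_training(
        model,
        config=config,
        module_filter_fn=lambda mod, fqn: fqn != "output",
    )
```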

Example: pretraining Llama 3 8B on 8 H100s:

// baseline (bf16 + compile)
> with-proxy CONFIG_FILE="./train_configs/llama3_8b.toml" ./run_llama_train.sh --training.compile
...
step: 20  loss:  8.4931  memory: 47.65GiB(50.16%)  tps: 5,760  mfu: 33.73%

// experiment (rowwise float8 + compile)
> with-proxy CONFIG_FILE="./train_configs/llama3_8b.toml" ./run_llama_train.sh --float8.enable_float8_linear --training.compile --float8.recipe_name all_axiswise
...
step: 30  loss:  7.7109  memory: 47.77GiB(50.28%)  tps: 6,468  mfu: 37.88%

// for comparison, tensorwise float8 with float8 all-gather (on main branch)
with-proxy CONFIG_FILE="./train_configs/llama3_8b.toml" ./run_llama_train.sh --float8.enable_float8_linear --training.compile --float8.enable_fsdp_float8_all_gather --float8.precompute_float8_dynamic_scale_for_fsdp
...
step: 20  loss:  8.4258  memory: 47.32GiB(49.81%)  tps: 7,186  mfu: 42.08%


@facebook-github-bot added the "CLA Signed" label (managed by the Meta Open Source bot) on Jan 27, 2025
@vkuzo force-pushed the 20250127_float8_rowwise_example branch from 3ebdf05 to 8ffb32c on January 27, 2025 at 21:54
@danielvegamyhre (Contributor)

I'm looking into the unexpectedly high peak memory. As @vkuzo mentioned, it is partially mitigated after applying pytorch/ao#1629, but peak memory is still ~58GB while we expect around ~47GB.

Experiment (rowwise float8 + compile) with pytorch/ao#1629 applied

bf16 + compile command:

CONFIG_FILE="./train_configs/llama3_8b.toml" ./run_llama_train.sh --training.compile --profiling.enable_memory_snapshot --profiling.save_memory_snapshot_folder memory_snapshot

...
[rank0]:2025-01-30 15:11:14,277 - root - INFO - step:  1  loss: 12.2750  memory: 40.15GiB(42.26%)  tps: 572  mfu: 3.35%
[rank0]:2025-01-30 15:11:14,277 - root - INFO - Synchronizing and adjusting timeout for all ProcessGroups to 0:01:40
[rank0]:2025-01-30 15:11:27,104 - root - INFO - step: 10  loss:  9.9733  memory: 47.65GiB(50.16%)  tps: 5,748  mfu: 33.66%
[rank0]:2025-01-30 15:11:41,347 - root - INFO - step: 20  loss:  8.4896  memory: 47.65GiB(50.16%)  tps: 5,752  mfu: 33.68%

row-wise fp8 + compile command:

CONFIG_FILE="./train_configs/llama3_8b.toml" ./run_llama_train.sh --float8.enable_float8_linear --training.compile --profiling.enable_memory_snapshot --profiling.save_memory_snapshot_folder memory_snapshot

...

[rank0]:2025-01-30 14:12:48,407 - root - INFO - step:  1  loss: 12.2584  memory: 51.03GiB(53.71%)  tps: 639  mfu: 3.74%
[rank0]:2025-01-30 14:12:48,407 - root - INFO - Synchronizing and adjusting timeout for all ProcessGroups to 0:01:40
[rank0]:2025-01-30 14:12:59,952 - root - INFO - step: 10  loss:  9.8484  memory: 58.55GiB(61.63%)  tps: 6,387  mfu: 37.40%
[rank0]:2025-01-30 14:13:12,775 - root - INFO - step: 20  loss:  8.4263  memory: 58.55GiB(61.63%)  tps: 6,388  mfu: 37.41%
[rank0]:2025-01-30 14:13:25,617 - root - INFO - step: 30  loss:  7.6880  memory: 58.55GiB(61.63%)  tps: 6,380  mfu: 37.36%
[rank0]:2025-01-30 14:13:38,479 - root - INFO - step: 40  loss:  7.3205  memory: 58.55GiB(61.63%)  tps: 6,370  mfu: 37.30%
[rank0]:2025-01-30 14:13:51,346 - root - INFO - step: 50  loss:  7.1534  memory: 58.55GiB(61.63%)  tps: 6,367  mfu: 37.28%
[rank0]:2025-01-30 14:14:04,206 - root - INFO - step: 60  loss:  6.9532  memory: 58.55GiB(61.63%)  tps: 6,371  mfu: 37.31%
[rank0]:2025-01-30 14:14:17,074 - root - INFO - step: 70  loss:  6.9997  memory: 58.55GiB(61.63%)  tps: 6,367  mfu: 37.28%
[rank0]:2025-01-30 14:14:29,950 - root - INFO - step: 80  loss:  6.6942  memory: 58.55GiB(61.63%)  tps: 6,363  mfu: 37.26%
[rank0]:2025-01-30 14:14:42,834 - root - INFO - step: 90  loss:  6.7342  memory: 58.55GiB(61.63%)  tps: 6,359  mfu: 37.24%
[rank0]:2025-01-30 14:14:57,323 - root - INFO - step: 100  loss:  6.4553  memory: 58.55GiB(61.63%)  tps: 5,654  mfu: 33.11%

This reproduced the peak memory discrepancy described in the PR.

Looking at the memory snapshots side by side, one thing that stands out immediately is a big spike in memory at the start of the backward pass, when some buffers are allocated for logsoftmax backward. However, since this spike exists for both runs, I think it is expected:

[memory snapshot screenshot: peaks]

Next, I can see that the peak memory of the forward pass, before backward even begins, is much higher for fp8 + compile than for bf16 + compile. The bf16 forward pass seems to peak around 32GB, vs the float8 forward pass, which seems to peak around 43GB. This difference aligns with the roughly 11GB difference in peak memory we are investigating (47GB vs 58GB).

[memory snapshot screenshot: top_of_mountain]

Looking closer, the discrepancy seems to be related to activations, not the model parameters or optimizer states. This is because, at the beginning of each forward pass, before anything is allocated for activations, both bf16 and fp8 have about 12GB of model params and optimizer state allocated.

[memory snapshot screenshot: 12gb_start]
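
That ~12GB figure is roughly what persistent state alone would predict. A back-of-envelope check, assuming ~8B parameters, fp32 sharded parameters plus two fp32 Adam states, evenly sharded across 8 GPUs (these are assumptions about the torchtitan defaults, not measured values):

```
params = 8.0e9          # Llama 3 8B parameter count (approximate)
ngpu = 8
fp32 = 4                # bytes per element

param_shard = params * fp32 / ngpu       # sharded fp32 parameters
adam_states = 2 * params * fp32 / ngpu   # Adam exp_avg + exp_avg_sq, fp32
print(f"{(param_shard + adam_states) / 2**30:.1f} GiB per GPU")  # ~11.2 GiB
```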

To confirm if the discrepancy is in the activations, as a next step I'll use the memory profiler to categorize the allocations.
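
For reference, the snapshots above come from torchtitan's --profiling.enable_memory_snapshot flag; a standalone sketch of the underlying PyTorch API it wraps (a private, underscored API that may change between releases):

```
import torch

# start recording allocation/free events with stack traces
torch.cuda.memory._record_memory_history(max_entries=100_000)

# ... run a few training steps here ...

# dump a snapshot that can be inspected at https://pytorch.org/memory_viz
torch.cuda.memory._dump_snapshot("memory_snapshot/rank0_snapshot.pickle")
torch.cuda.memory._record_memory_history(enabled=None)  # stop recording
```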

@danielvegamyhre (Contributor) commented Jan 31, 2025

Side note, I think I just discovered another bug with fp8 row-wise scaling + torch.compile (cc @vkuzo):

When NGPU=2 or more, things run smoothly.

NGPU=2 CONFIG_FILE="./train_configs/llama3_8b.toml" ./run_llama_train.sh --float8.enable_float8_linear --training.compile
...
[rank0]:2025-01-30 17:01:22,822 - root - INFO - step:  1  loss: 12.2662  memory: 77.85GiB(81.94%)  tps: 986  mfu: 5.78%
[rank0]:2025-01-30 17:01:22,823 - root - INFO - Synchronizing and adjusting timeout for all ProcessGroups to 0:01:40
[rank0]:2025-01-30 17:01:39,017 - root - WARNING - 18 CUDA memory allocation retries.
[rank0]:2025-01-30 17:01:39,017 - root - INFO - step: 10  loss: 10.1890  memory: 93.32GiB(98.22%)  tps: 4,553  mfu: 26.66%

When NGPU=1, we get the following error:

NGPU=1 CONFIG_FILE="./train_configs/llama3_8b.toml" ./run_llama_train.sh --float8.enable_float8_linear --training.compile

...
    File "/home/danvm/.conda/envs/ao/lib/python3.13/site-packages/torch/_inductor/utils.py", line 2204, in run
      return model(new_inputs)
    File "/tmp/torchinductor_danvm/id/cido3o7oxjlzgvco4rlw2yfnalncojwlzvsfowmxsjejj5ppiduc.py", line 1331, in call
      extern_kernels._scaled_mm(buf5, buf6, buf4, buf7, out_dtype=torch.float32, use_fast_accum=True, out=buf8)
      ~~~~~~~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  RuntimeError: Only bf16 high precsion output types are supported for row-wise scaling.

The error is thrown here, which looks like the point where _scaled_mm dispatches to cuBLAS.

Perhaps when NGPU=2+, torch.compile is using an autogenerated Triton kernel for the fp8 gemm, but when NGPU=1 it's using cuBLAS, which has this guard in place?
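
A minimal single-GPU sketch of hitting that guard directly, assuming an H100-class GPU and a recent PyTorch build (torch._scaled_mm is a private op, so treat this as illustrative rather than a stable API):

```
import torch

M, K, N = 16, 4096, 4096
a = torch.randn(M, K, device="cuda").to(torch.float8_e4m3fn)       # row-major fp8 operand
b = torch.randn(N, K, device="cuda").to(torch.float8_e4m3fn).t()   # column-major fp8 operand
scale_a = torch.rand(M, 1, device="cuda", dtype=torch.float32)     # row-wise scales for a
scale_b = torch.rand(1, N, device="cuda", dtype=torch.float32)     # column-wise scales for b

# out_dtype=torch.bfloat16 should work; out_dtype=torch.float32 is expected to hit the same
# "Only bf16 high precsion output types are supported for row-wise scaling" RuntimeError as above.
out = torch._scaled_mm(a, b, scale_a, scale_b, out_dtype=torch.float32, use_fast_accum=True)
```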

@danielvegamyhre (Contributor)

I did a quick round of experiments with the debug model at greatly reduced size (dim=32, n_heads=16, layers=32), comparing bf16+compile and fp8+compile for all AC settings (none, full, selective per layer, selective per op), and found the ONLY configuration where peak memory was higher for fp8 than for bf16 was selective per op (see table below).

To me this indicates this issue may be related to #786.

[results table screenshot, 2025-01-30 5:13 PM]

@danielvegamyhre (Contributor) commented Jan 31, 2025

Ran experiments on a much smaller model (2 layers, dim 16, 1 head) and found a memory discrepancy of 0.58GB vs 0.59GB for bf16 vs fp8, so I increased the model size (32 layers, dim 4096, 16 heads) and found the memory discrepancy grew (73GB vs 93GB).

Looking at the memory snapshots for the tiny model, I found a discrepancy in allocations during the beginning of the forward pass between bf16 + compile and fp8 + compile, so I looked at the stack trace for each of these tensors.

On the left we see 3 empty_strided_cuda calls in the Triton kernel generated by inductor for bf16.
On the right we see 6 empty_strided_cuda calls in the Triton kernel generated by inductor for fp8.

As you can see, even with 2 layers there is already a memory discrepancy in these initial allocations at the beginning of each forward pass, right after the initial series of empty_strided_cuda calls and just under where the giant block above starts (71MB vs 72MB). I'm wondering if the size of the discrepancy will grow/compound with larger models.

[memory snapshot comparison screenshot: 2layers]

@vkuzo (Contributor, Author) commented Jan 31, 2025

Side note, I think I just discovered another bug with fp8 row-wise scaling + torch.compile (cc @vkuzo):
RuntimeError: Only bf16 high precsion output types are supported for row-wise scaling.

My guess on what this is:

  1. the rowwise scaling kernel only supports bfloat16 output (not float32)
  2. in torchtitan, the default logic is to turn on bfloat16 mixed precision if FSDP is on (2+ GPUs). For one GPU, it's possible that the weights + activations are in float32; if they are, this wouldn't work with the rowwise scaling kernel that we have (sketched below).
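
A sketch of guess 2 using the FSDP2 API (the import path for fully_shard / MixedPrecisionPolicy varies across PyTorch versions, and the helper below is illustrative, not torchtitan's actual code):

```
import torch
from torch.distributed._composable.fsdp import MixedPrecisionPolicy, fully_shard

def shard_model(model: torch.nn.Module, blocks: list) -> None:
    # with FSDP on, params are all-gathered in bf16 and grads reduced in fp32,
    # so the rowwise float8 gemm sees bf16 inputs; a single-GPU run without FSDP
    # can leave weights/activations in fp32, which the rowwise kernel rejects
    mp_policy = MixedPrecisionPolicy(param_dtype=torch.bfloat16, reduce_dtype=torch.float32)
    for block in blocks:  # the model's transformer blocks
        fully_shard(block, mp_policy=mp_policy)
    fully_shard(model, mp_policy=mp_policy)
```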

@danielvegamyhre (Contributor)

Did some quick additional experiments comparing eager vs compile and confirmed the discrepancy applies to both, further narrowing down the issue to an interaction between selective per op AC and fp8 quantization. These tests are currently using row-wise scaling, but given that we saw a similar issue with tensor-wise scaling in #786, I'm inclined to think it isn't specific to row-wise scaling. I will confirm this with additional tests tomorrow.

[results table screenshot, 2025-01-30 10:08 PM]

@vkuzo (Contributor, Author) commented Jan 31, 2025

per-op SAC, layers==2, dim==1024, n_heads==1 is an interesting repro. Does this combination also reproduce with tensorwise scaling? How about layers==1?

@danielvegamyhre (Contributor) commented Jan 31, 2025

per-op SAC, layers==2, dim==1024, n_heads==1 is an interesting repro. Does this combination also reproduce with tensorwise scaling? How about layers==1?

It reproduces with 1 layer. With tensorwise scaling, there are only minor differences, so this issue seems specific to per op SAC + fp8 row-wise.

[results table screenshot, 2025-01-31 9:01 AM]

@vkuzo (Contributor, Author) commented Jan 31, 2025

can we repro this outside of torchtitan, to narrow it down further?

https://github.com/pytorch/ao/blob/main/benchmarks/float8/profile_linear_float8.py has some pre-canned launcher scripts you can quickly try on layers of various complexity (linear, norm -> linear -> activation, FFN, etc)

@vkuzo (Contributor, Author) commented Jan 31, 2025

although that script does not yet integrate with per-op SAC

@vkuzo (Contributor, Author) commented Jan 31, 2025

nm, 1 transformer block is already pretty small. At this point maybe we can jump into the logs and see if anything stands out?

@danielvegamyhre (Contributor) commented Feb 1, 2025

I wrote a script to produce a minimal repro (with exported memory snapshot) using different configurations (N linear layers, FFN, etc.), optionally applying fsdp/compile/float8 rowwise quantization: #815

So far with 1 GPU:

  • only linear layers did not repro
  • FFN with silu activation did not repro

Per op AC discards every 2nd matmul activation, but stores all outputs from the other ops in the save list.

The 2 ops specific to low precision training seem to be abs and max (source); a sketch of such a save list is below.
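
A minimal sketch of per-op selective AC with such a save list, along the lines of torchtitan's implementation (the op list is abbreviated and the "save every other mm" counting logic is omitted; the exact list is the one linked as "source" above):

```
import torch
from torch.utils.checkpoint import (
    CheckpointPolicy,
    checkpoint,
    create_selective_checkpoint_contexts,
)

_save_list = {
    torch.ops.aten.mm.default,                                        # matmul outputs
    torch.ops.aten._scaled_dot_product_efficient_attention.default,   # attention outputs
    torch.ops.aten._scaled_mm.default,                                 # float8 gemm outputs
    torch.ops.aten.abs.default,   # float8-only: abs(tensor) feeding the amax computation
    torch.ops.aten.max.default,   # float8-only: amax reduction
}

def _policy(ctx, op, *args, **kwargs):
    # save outputs of ops in the list; recompute everything else during backward
    return CheckpointPolicy.MUST_SAVE if op in _save_list else CheckpointPolicy.PREFER_RECOMPUTE

def checkpoint_block(block, *inputs):
    context_fn = lambda: create_selective_checkpoint_contexts(_policy)
    return checkpoint(block, *inputs, use_reentrant=False, context_fn=context_fn)
```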

@danielvegamyhre (Contributor) commented Feb 3, 2025

I documented tests doing all permutations of the following configurations, using a repro script I wrote (#815):

AC

  • All tests used selective per op AC

Compile

  • All tests used torch.compile

Test types:

  • e2e training loops with optimizer steps
  • single fwd+bwd without optimizer step

Distributed settings:

  • None (1 gpu)
  • FSDP2 on 2 GPUs

Model types:

  • single linear layer
  • FFN
  • attention layer
  • single transformer decoder block
  • 2 transformer decoder blocks

Quantization

  • None (bf16)
  • float8 with row-wise scaling

For ALL tests, the float8 run had strictly lower peak memory than the bf16 run, EXCEPT:

  • FFN with distributed training / FSDP2 (1409 MB bf16 vs 1569 MB float8 row-wise) = 1.11x increase
  • Transformer blocks with distributed training / FSDP2 (1862.29 MB bf16 vs 2086.45 MB float8 row-wise) = 1.12x increase
  • Attention layer with distributed training / FSDP2 (517.13 MB bf16 vs 533.58 MB float8 row-wise) = 1.03x increase

The attention layer increase is negligible, so I think the problem is the FFN, which is the most minimal layer that repros.

@danielvegamyhre (Contributor) commented Feb 3, 2025

AOT graphs for FFN + bf16 + fsdp + compile
AOT graphs for FFN + fp8 + fsdp + compile

The fp8 AOT graph logs are way longer (~600 lines vs ~150 lines) and it's hard to tell what could be problematic; it is just a long list of ops performed in each fwd/bwd pass.

Naively, the 4x increase in operations could indicate that more operator fusions were performed for the bf16 run than for the float8 run, which would lead to more activations being stored in the float8 case.
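
For reproducibility, AOT graphs like the ones linked above can be captured by setting TORCH_LOGS="aot_graphs" in the environment or programmatically, assuming a recent PyTorch; the FFN below is a stand-in, not the exact module from the repro script:

```
import torch
import torch._logging
import torch.nn as nn

torch._logging.set_logs(aot_graphs=True)   # print AOT forward/backward graphs as they compile

# stand-in FFN; the real repro uses the module from #815
ffn = nn.Sequential(nn.Linear(4096, 16384), nn.SiLU(), nn.Linear(16384, 4096)).cuda().bfloat16()
compiled_ffn = torch.compile(ffn)
out = compiled_ffn(torch.randn(16, 4096, device="cuda", dtype=torch.bfloat16))
out.sum().backward()                       # backward graph gets logged too
```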

@danielvegamyhre (Contributor)

(Quoting the full test matrix and results from my previous comment above.)

The 1.11x memory increase in the FFN isn't as high as the 1.23x increase in the full Llama 3 8B model, so I ran a test with 10 transformer blocks instead of just 1 to see if the discrepancy increases with model size, and it did not (18,035.55 MB bf16 vs 19,991.45 MB fp8 row-wise = ~1.11x increase).

@danielvegamyhre (Contributor) commented Feb 4, 2025

I analyzed the FFN + FSDP2 Triton kernels for the bf16 and fp8 forward passes and annotated them, tracking the total cumulative size of buffers as they are allocated/deallocated.

I also analyzed the total size of the buffers returned, which I assume includes what is saved for backward.

Total size of buffers returned:

  1. bf16: 655,360 bytes
  2. fp8 row-wise: 134,873,152 bytes (205x larger than bf16).

Peak memory utilized by buffers during FFN forward:

  1. bf16: 655,360 bytes
  2. fp8 row-wise: 269,353,088 bytes (411x larger than bf16).

Next I plan to analyze the fp8 kernel to determine exactly what these huge tensors are that we are saving for backward.

@danielvegamyhre (Contributor) commented Feb 4, 2025

(Quoting the kernel buffer analysis from my previous comment above.)

I analyzed the forward kernels for bf16 and fp8 row-wise w/ FSDP, and I think the culprit is buf8, which holds abs(W3), where W3 is an unsharded weight matrix. This huge tensor is being returned and presumably stored for backward, which is not the case for the bf16 kernel.

# RETURNED FROM FP8 rowwise forward

    return (
        reinterpret_tensor(
            buf21, (1, 16, 4096), (65536, 4096, 1), 0
        ),  # buf21:bf16=(16, 4096) => small (FFN output)
        primals_1,
        primals_3,
        primals_4,
        buf0,  # buf0:bf16=(1, 16, 4096)    => abs(input) => small
        buf3,  # buf3:fp32=(16, 1)          => rowwise scales for inputs
        buf7,  # buf7:bf16=(16, 16384)      => W1(x) => small
        buf8,  # buf8:bf16=(4096, 16384)    => abs(W3) =>  huge (134,217,728  bytes)
    )

buf8 has shape (4096, 16384) so with dtype bf16 it is 134,217,728 bytes, which is 99.5% of the total memory of these extra buffers returned to save for backward (buf0, buf3, buf7, buf8).
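
A quick check of that byte count (pure arithmetic, no assumptions beyond the stated shape and dtype):

```
# buf8: (4096, 16384) elements in bf16 (2 bytes each)
print(4096 * 16384 * 2)  # 134217728 bytes, i.e. ~128 MiB
```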

This is in contrast to the bf16 kernel, which does NOT return/store the abs(W3):

# RETURNED FROM BF16 forward

    return (
        reinterpret_tensor(
            buf3, (1, 16, 4096), (65536, 4096, 1), 0
        ),  # FFN output => (1,16,4096) in bf16 = 131,072 bytes
        primals_1,
        primals_3,
        primals_4,
        buf0,  # W1(x) => (16,16384) in bf16 = 524,288 bytes
    )

This seems to be what is causing the memory usage discrepancy; however, I don't yet know why the autograd engine / inductor decides this needs to be saved for fp8 row-wise + FSDP.

@danielvegamyhre (Contributor) commented Feb 4, 2025

Looking at the AOT graphs for the fp8 kernel, we can see abs_4 corresponds to the buf8 abs(weight) matrix being stored for backward, as shown in my previous comment.

Specifically, in the backward graph we can see abs_4 going through amax and other quantization steps here.

@danielvegamyhre (Contributor) commented Feb 4, 2025

In my repro script, after commenting out the abs op from the _save_list for per op AC, the peak memory discrepancy is much smaller:

Saving abs op:
bf16: 1409.38 MB
fp8: 1569.63 MB

Not saving abs op:
bf16: 1409.38 MB
fp8: 1441.50 MB

@danielvegamyhre (Contributor)

In my repro script, after commenting out the abs op from the _save_list for per op AC, the peak memory discrepancy is much smaller:

Saving abs op: bf16: 1409.38 MB, fp8: 1569.63 MB

Not saving abs op: bf16: 1409.38 MB, fp8: 1441.50 MB

Taking this one step further to see if the improvement translates from the FFN to the full transformer model, I tried commenting out the abs op from the torchtitan per op AC save list and rerunning the original experiment from the PR description. It virtually eliminated the extra peak memory, without negatively impacting TPS:

bf16:

NGPU=8 CONFIG_FILE="./train_configs/llama3_8b.toml" ./run_llama_train.sh --training.compile
...
[rank0]:2025-02-04 14:06:54,584 - root - INFO - step: 10  loss: 10.1837  memory: 47.65GiB(50.16%)  tps: 5,746  mfu: 33.65%

fp8 row-wise WITH storing abs() op output:

[rank0]:2025-02-04 14:13:52,324 - root - INFO - step: 10  loss:  9.9765  memory: 58.55GiB(61.63%)  tps: 6,379  mfu: 37.35%
[rank0]:2025-02-04 14:14:05,145 - root - INFO - step: 20  loss:  8.3610  memory: 58.55GiB(61.63%)  tps: 6,390  mfu: 37.42%
[rank0]:2025-02-04 14:14:17,975 - root - INFO - step: 30  loss:  7.6860  memory: 58.55GiB(61.63%)  tps: 6,386  mfu: 37.39%

fp8 row-wise WITHOUT storing abs() op output:

[rank0]:2025-02-04 14:19:22,446 - root - INFO - step: 10  loss: 10.1946  memory: 47.77GiB(50.28%)  tps: 6,293  mfu: 36.85%
[rank0]:2025-02-04 14:19:35,190 - root - INFO - step: 20  loss:  8.4865  memory: 47.77GiB(50.28%)  tps: 6,429  mfu: 37.64%
[rank0]:2025-02-04 14:19:47,951 - root - INFO - step: 30  loss:  7.6982  memory: 47.77GiB(50.28%)  tps: 6,420  mfu: 37.60%

I also discussed this with @soulitzer and he believes the RCA is correct. He said that when compile is enabled we should only save max() and not abs(), whereas in eager mode saving abs() can make sense since, without dead code elimination, it would otherwise need to be recomputed.

I think one potential solution would be to conditionally decide whether to save abs(), depending on whether compile is enabled. I need to look into how feasible that is (a rough sketch of the idea is below).
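
A hypothetical sketch of that conditional idea (illustrative only; the eventual fix in #820 simply removed abs from the list):

```
import torch

def build_save_list(compile_enabled: bool) -> set:
    save_list = {
        torch.ops.aten.mm.default,
        torch.ops.aten._scaled_mm.default,
        torch.ops.aten.max.default,
    }
    if not compile_enabled:
        # in eager mode there is no dead code elimination, so saving abs()
        # avoids recomputing it in the backward pass
        save_list.add(torch.ops.aten.abs.default)
    return save_list
```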

@vkuzo (Contributor, Author) commented Feb 4, 2025

awesome analysis!

I think me adding abs to that list was a mistake, as abs is a pointwise operation. @danielvegamyhre are you up for just deleting abs from the list, and verifying that memory usage does not regress for per-tensor recipes on 8 GPUs? Sorry about that :)

@danielvegamyhre (Contributor)

awesome analysis!

I think me adding abs to that list was a mistake, as abs is a pointwise operation. @danielvegamyhre are you up for just deleting abs from the list, and verifying that memory usage does not regress for per-tensor recipes on 8 GPUs? Sorry about that :)

Sure, will do

tianyu-l pushed a commit that referenced this pull request Feb 6, 2025
… increase in peak memory usage for float8 training with row-wise scaling (#820)

This fixes a bug in how per op AC interacts with float8 training. The
bug surfaced as an unexpected increase in peak memory usage when
integrating row-wise scaling into float8 training in
#808. The full RCA can be
found in this thread, starting
[here](#808 (comment)).

## TL;DR
- I
[tested](#808 (comment))
different model layers and permutations of configurations and found the
minimal repro was a FFN with per op AC + FSDP + compile + float8
row-wise scaling.
- I
[analyzed](#808 (comment))
the triton kernels generated for the compiled FFN forward pass in bf16
vs fp8, tracking peak cumulative memory allocated and what was saved for
backward. I found the fp8 forward kernel was saving a huge buffer/tensor
for backward which the bf16 kernel was not saving.
- I
[tracked](#808 (comment))
how the various buffers were used and determined this huge buffer saved
for backwards was holding `abs(W3)`, where W3 is an **unsharded weight
tensor.**
- I
[tested](#808 (comment))
removing `abs()` from the save list for per op AC, and confirmed it
resolved the memory issue without impacting TPS.


## Benchmarks 

**float8 row-wise WITH storing abs() op output:**
```
[rank0]:2025-02-04 14:13:52,324 - root - INFO - step: 10  loss:  9.9765  memory: 58.55GiB(61.63%)  tps: 6,379  mfu: 37.35%
[rank0]:2025-02-04 14:14:05,145 - root - INFO - step: 20  loss:  8.3610  memory: 58.55GiB(61.63%)  tps: 6,390  mfu: 37.42%
[rank0]:2025-02-04 14:14:17,975 - root - INFO - step: 30  loss:  7.6860  memory: 58.55GiB(61.63%)  tps: 6,386  mfu: 37.39%
```

**float8 row-wise WITHOUT storing abs() op output:**
```
[rank0]:2025-02-04 14:19:22,446 - root - INFO - step: 10  loss: 10.1946  memory: 47.77GiB(50.28%)  tps: 6,293  mfu: 36.85%
[rank0]:2025-02-04 14:19:35,190 - root - INFO - step: 20  loss:  8.4865  memory: 47.77GiB(50.28%)  tps: 6,429  mfu: 37.64%
[rank0]:2025-02-04 14:19:47,951 - root - INFO - step: 30  loss:  7.6982  memory: 47.77GiB(50.28%)  tps: 6,420  mfu: 37.60%
```

I also confirmed this also does not cause a regression in memory usage
or TPS with tensorwise scaling:

**float8 tensorwise WITH storing abs() output:**
```
[rank0]:2025-02-04 15:06:53,589 - root - INFO - step: 10  loss:  9.8912  memory: 47.77GiB(50.28%)  tps: 6,881  mfu: 40.30%
[rank0]:2025-02-04 15:07:05,499 - root - INFO - step: 20  loss:  8.4239  memory: 47.77GiB(50.28%)  tps: 6,879  mfu: 40.28%
[rank0]:2025-02-04 15:07:17,412 - root - INFO - step: 30  loss:  7.6615  memory: 47.77GiB(50.28%)  tps: 6,877  mfu: 40.27%
```

**float8 tensorwise WITHOUT storing abs() output:**

```
[rank0]:2025-02-04 15:08:32,527 - root - INFO - step: 10  loss:  9.9628  memory: 47.77GiB(50.28%)  tps: 6,865  mfu: 40.20%
[rank0]:2025-02-04 15:08:44,450 - root - INFO - step: 20  loss:  8.5451  memory: 47.77GiB(50.28%)  tps: 6,871  mfu: 40.24%
[rank0]:2025-02-04 15:08:56,383 - root - INFO - step: 30  loss:  7.8286  memory: 47.77GiB(50.28%)  tps: 6,865  mfu: 40.20%
```
@vkuzo changed the title from "[not for land yet] example of float8 with rowwise scaling" to "add configuration for float8 with rowwise scaling, via recipe lookup" on Feb 7, 2025
@vkuzo force-pushed the 20250127_float8_rowwise_example branch from 8ffb32c to b08786a on February 7, 2025 at 16:10
@vkuzo requested review from tianyu-l and mori360 on February 7, 2025 at 16:11
@vkuzo (Contributor, Author) commented Feb 7, 2025

@tianyu-l , @mori360 , any chance you could share the command to run the linter locally? I see lint fail in CI but it doesn't share what specifically is wrong.

@mori360 (Contributor) commented Feb 7, 2025

@tianyu-l , @mori360 , any chance you could share the command to run the linter locally? I see lint fail in CI but it doesn't share what specifically is wrong.

We use pre-commit to run the linter locally. You can find the doc here.

Summary:

This is an example of how to call float8 training with rowwise scaling
from torchao.

TODO: finalize the API in torchao, finalize how we want to expose it in
torchtitan, and optimize performance.

```
// baseline (bf16 + compile)
> with-proxy CONFIG_FILE="./train_configs/llama3_8b.toml" ./run_llama_train.sh --training.compile
...
step: 20  loss:  8.4931  memory: 47.65GiB(50.16%)  tps: 5,760  mfu: 33.73%

// experiment (rowwise float8 + compile)
> with-proxy CONFIG_FILE="./train_configs/llama3_8b.toml" ./run_llama_train.sh --float8.enable_float8_linear --training.compile
...
// torchao main branch
step: 40  loss:  7.3818  memory: 66.81GiB(70.33%)  tps: 6,412  mfu: 37.55%
// torchao with pytorch/ao#1629
step: 20  loss:  8.3823  memory: 58.55GiB(61.63%)  tps: 6,424  mfu: 37.62%

// for comparison, tensorwise float8 with float8 all-gather (on main branch)
with-proxy CONFIG_FILE="./train_configs/llama3_8b.toml" ./run_llama_train.sh --float8.enable_float8_linear --training.compile --float8.enable_fsdp_float8_all_gather --float8.precompute_float8_dynamic_scale_for_fsdp
...
step: 20  loss:  8.4258  memory: 47.32GiB(49.81%)  tps: 7,186  mfu: 42.08%

```

@vkuzo force-pushed the 20250127_float8_rowwise_example branch from b08786a to fc86c6f on February 7, 2025 at 22:12
@vkuzo (Contributor, Author) commented Feb 7, 2025

thank you, fixed lint

@tianyu-l (Contributor) left a comment:

Could you add more description on the recipes in the PR summary? E.g. what are all_axiswise and lw_axiswise_with_gw_hp, are they both rowwise scaled but different flavor?

could you also update the comments here
https://github.com/pytorch/torchtitan/blob/main/torchtitan/parallelisms/parallelize_llama.py#L145-L147

"--float8.recipe_name",
type=str,
default=None,
help="If specified, creates float8 config from recipe name, choices are `all_axiswise` and `lw_axiswise_with_gw_hp`",
@tianyu-l (Contributor) commented Feb 7, 2025:

could you please:

  1. add a choices arg, similar to https://github.com/pytorch/torchtitan/blob/main/torchtitan/config_manager.py#L292
  2. add more info on the recipes, including references, e.g. if users are interested, where they can learn more

@lessw2020 (Contributor) commented Feb 9, 2025:

Related: imo, this naming should be made more intuitive for users. I understand you want the exact breakout within your code, but users won't necessarily grasp it vs. calling it in more direct terms.

@@ -42,29 +42,43 @@ def __init__(self, job_config: JobConfig, parallel_dims: ParallelDims):
return
try:
from torchao.float8 import Float8LinearConfig

# we should update this code after torchao exposes this publically
Contributor review comment:

Suggested change
# we should update this code after torchao exposes this publically
# TODO: we should update this code after torchao exposes this publically

@vkuzo (Contributor, Author) commented Feb 8, 2025

Could you add more description on the recipes in the PR summary? E.g. what are all_axiswise and lw_axiswise_with_gw_hp, are they both rowwise scaled but different flavor?

could you also update the comments here https://github.com/pytorch/torchtitan/blob/main/torchtitan/parallelisms/parallelize_llama.py#L145-L147

To be honest, we don't have good writeups for this, and we hope to land something here to unblock us from running the experiments needed to create this information. Is there a good way to land experimental features here (without such writeups), or is the recommendation to use unlanded code and only land ready-for-wide-usage features?

@tianyu-l (Contributor) commented Feb 9, 2025

Is there a good way to land experimental features here (without such writeups), or is the recommendation to use unlanded code and only land ready-for-wide-usage features?

In general, the feature should be ready or almost ready (put under the experimental section of configs) to land. For ease of exploration and experimentation, we are creating multiple (levels of) extension points:

  1. "Introducing a generic ModelHandler interface" (#823), as you requested in "should we have an extension point for model transforms out of tree?" (#790). It allows you to explore quantization in general without cloning torchtitan, e.g. you could have a dependency on torchtitan in another repo.
  2. "Add Dynamic Model Import and ModelSpec Definition" (#814) allows heavier changes, e.g. completely switching the model and training techniques.

I'd like to understand better what you need beyond 1 for this change. Is it because:

  • "This is a general way to enable recipe-based Float8 techniques, although the specific recipes are still under experiments. Sooner or later the recipes will be ready, with better naming." This is convincing, IMO. If that's the case, I think the PR itself is not experimental and can be landed -- it's just we need to be clear in the helper message that users in the world are not supposed to understand/use the config/recipes.
  • "Landing this change would enable more stable experiments with latest torchtitan" (otherwise, why not in a fork?)

@vkuzo (Contributor, Author) commented Feb 10, 2025

In general, the feature should be ready or almost ready (put under the experimental section of configs) to land. For ease of exploration and experimentation, we are creating multiple (levels of) extension points:

makes sense, thank you!

There are two features in this PR:

  1. float8 with rowwise scaling training recipe. This feature is ready for usage and we are looking for more real world signal.
  2. string-name-to-recipe lookup in torchao float8 training. This feature is something we definitely want in torchao, but we haven't finalized the design.

I think a good way to go with this PR is
a. leave it unlanded for now
b. design the proper recipe-by-name API in torchao (ETA: 2025Q1).
c. come back here and update this PR to link to the final version of (b), which will make it much easier to document and set expectations for users.

If we need this landed before (b) for some reason, we can chat and figure it out then.

@tianyu-l (Contributor)

I think a good way to go with this PR is
a. leave it unlanded for now
b. design the proper recipe-by-name API in torchao (ETA: 2025Q1).
c. come back here and update this PR to link to the final version of (b), which will make it much easier to document and set expectations for users.
If we need this landed before (b) for some reason, we can chat and figure it out then.

That's very thoughtful! Looking forward to the API!
