Labels: module: inductor, oncall: pt2, triaged
Description
🐛 Describe the bug
FP8 groupwise/blockwise GEMM support via cuBLAS + CUDA 12.9 was added in #158037.
However, it seems Inductor still needs some updates to handle the new scale shapes that are now allowed:
In eager mode, my bench script runs successfully.
With compile, I get an outdated error about the scale shapes, indicating that only rowwise and tensorwise scaling are supported:
```
E0820 17:59:02.818000 532144 site-packages/torch/_subclasses/fake_tensor.py:2757] [1/0] RuntimeError: Invalid scaling configuration. For tensorwise scaling, both scales should be scalar. For rowwise scaling, scale_a should be (16640, 1), scale_b should be (1, 5120). Got scale_a.size()=(16640, 64) and scale_b.size()=(64, 40)
```
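For reference, this can likely be reduced to a standalone call to `torch._scaled_mm`; a minimal sketch is below. The shapes are taken from the error message, the dummy scales and the `fp8_mm` helper are my own, and I have not verified the exact stride/layout requirements of the new cuBLAS groupwise/blockwise path.

```python
# Hypothetical standalone repro (not from the issue): calls torch._scaled_mm
# directly with 1x128 groupwise scales for A and 128x128 blockwise scales for B.
# Shapes mirror the error message above; scale values are dummies.
import torch

M, K, N = 16640, 8192, 5120   # so scale_a is (16640, 64) and scale_b is (64, 40)
BLOCK = 128

a = torch.randn(M, K, device="cuda").to(torch.float8_e4m3fn)
# cuBLAS expects the second operand of an FP8 GEMM in column-major layout.
b = torch.randn(N, K, device="cuda").to(torch.float8_e4m3fn).t()

scale_a = torch.ones(M, K // BLOCK, device="cuda", dtype=torch.float32)           # groupwise (1x128)
scale_b = torch.ones(K // BLOCK, N // BLOCK, device="cuda", dtype=torch.float32)  # blockwise (128x128)

def fp8_mm(a, b, scale_a, scale_b):
    return torch._scaled_mm(a, b, scale_a, scale_b, out_dtype=torch.bfloat16)

out_eager = fp8_mm(a, b, scale_a, scale_b)                    # succeeds in eager with CUDA 12.9
out_compiled = torch.compile(fp8_mm)(a, b, scale_a, scale_b)  # fails in fake_tensor.py as above
```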
Versions
This is with the latest torch nightly.
Repro
- Clone https://github.com/pytorch/torchao
- Check out this PR: [fp8 blockwise] wrap triton quantization kernels in custom ops for torch.compile compatibility ao#2829 (a hedged sketch of this custom-op wrapping pattern follows this list)
- Verify eager mode works:
  ```
  python benchmarks/prototype/blockwise_fp8_training/bench_linear_fwd_bwd.py
  ```
- Repro the compile failure:
  ```
  python benchmarks/prototype/blockwise_fp8_training/bench_linear_fwd_bwd.py --compile
  ```
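For context, here is a minimal sketch of the custom-op wrapping pattern that ao#2829 describes. This is not the torchao implementation: the op name, the plain-PyTorch quantization stand-in, and the fake kernel are illustrative only.

```python
# Hypothetical sketch: wrap a groupwise FP8 quantization routine in a custom op
# so torch.compile treats it as opaque and only traces its fake (meta) kernel,
# leaving _scaled_mm as the op that sees the blockwise/groupwise scale shapes.
import torch

@torch.library.custom_op("mylib::quantize_fp8_groupwise", mutates_args=())
def quantize_fp8_groupwise(x: torch.Tensor, block: int) -> tuple[torch.Tensor, torch.Tensor]:
    # Stand-in for the real Triton kernel: per-(1 x block) group absmax scaling.
    M, K = x.shape
    groups = x.view(M, K // block, block)
    amax = groups.abs().amax(dim=-1).clamp(min=1e-12)
    scale = amax / torch.finfo(torch.float8_e4m3fn).max
    xq = (groups / scale.unsqueeze(-1)).to(torch.float8_e4m3fn).view(M, K)
    return xq, scale.to(torch.float32)

@quantize_fp8_groupwise.register_fake
def _(x, block):
    # Shape-only implementation so torch.compile can trace through the op.
    M, K = x.shape
    return (x.new_empty(M, K, dtype=torch.float8_e4m3fn),
            x.new_empty(M, K // block, dtype=torch.float32))
```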
cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @chenyang78 @kadeng @muchulee8 @amjames @chauhang @aakhundov @coconutruben @alexsamardzic
Versions
```
pytorch-triton    3.4.0+gitf7888497          pypi_0    pypi
torch             2.9.0.dev20250820+cu129    pypi_0    pypi
```