
integrate torch._scaled_mm into Float8BlockwiseLinear and add bench script #2785


Open
danielvegamyhre wants to merge 1 commit into main from danielvegamyhre/stack/44

Conversation

Contributor

@danielvegamyhre danielvegamyhre commented Aug 17, 2025

Stacked PRs:


integrate torch._scaled_mm into Float8BlockwiseLinear and add bench script

Summary

  • Add a use_triton flag to Float8BlockwiseLinear that defaults to False. When True, use Triton for GEMMs; when False, use torch._scaled_mm.
  • Compute scales in column-major format for compatibility with torch._scaled_mm.
  • Add some autotuner configs for the IO-bound dynamic quantization kernels.
  • Rename the Triton kernels to be prefixed with "triton_" for clarity.
  • Write out reciprocal scales directly from the Triton quant kernels, rather than writing regular scales and then computing 1.0/scales to pass reciprocals to the GEMM for output rescaling. This is necessary because in some cases the division op (1.0/scales) changes the memory layout from column-major to row-major, which triggers memory-layout assertions. A sketch of this point follows this list.
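
As a concrete illustration of the last two points, the snippet below (PyTorch, illustrative only, not code from this PR) builds a column-major scale tensor of the shape a 1x128-blockwise recipe would produce for M=16640, K=8192, and shows why relying on a separate 1.0/scales op is fragile: depending on how the op is executed, the result may no longer be column-major, which is what the memory-layout assertions catch.

    import torch

    M, K = 16640, 8192
    block = 128

    # Column-major (M, K // block) scale tensor, the layout torch._scaled_mm
    # expects per this PR's summary.
    scales = torch.rand(M, K // block).t().contiguous().t()
    print(scales.stride())  # (1, 16640) -> column-major

    # Computing reciprocals as a separate elementwise op may or may not preserve
    # that layout depending on how the op is executed; the PR observed it
    # flipping to row-major in some cases, tripping the layout assertions.
    recip = 1.0 / scales
    print(recip.stride(), recip.is_contiguous())

    # The approach taken in this PR: the Triton quantization kernels store
    # 1.0 / scale directly into a column-major output buffer, so no post-hoc
    # division is needed and the layout is preserved by construction.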

Benchmarks

Benchmarking an eager linear forward + backward pass with llama4 shapes, using torch._scaled_mm improves performance by roughly 2x over the Triton GEMM path, but it is still slower than bf16. The next step is to look at the trace to see which quantization kernels are slowing things down. A timing sketch follows the benchmark table below.

    M     N     K  out_dtype         bf16_mm_linear_us    fp8_triton_linear_us    fp8_scaled_mm_linear_us
-----  ----  ----  --------------  -------------------  ----------------------  -------------------------
16640  5120  8192  torch.bfloat16              5349.79                 12643.5                    7746.98
16640  8192  5120  torch.bfloat16              6100.62                 12912.6                    7783.31
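
For reference, the numbers above come from an eager-mode forward + backward loop. The sketch below shows one way such a measurement can be taken with CUDA events; it is not the actual bench script from this PR, and the commented-out Float8BlockwiseLinear constructor arguments are assumptions based on the summary above.

    import torch
    from torch import nn

    def bench_fwd_bwd_us(layer: nn.Module, M: int, K: int, iters: int = 100) -> float:
        x = torch.randn(M, K, device="cuda", dtype=torch.bfloat16, requires_grad=True)
        for _ in range(10):  # warmup
            layer(x).sum().backward()
        torch.cuda.synchronize()
        start = torch.cuda.Event(enable_timing=True)
        end = torch.cuda.Event(enable_timing=True)
        start.record()
        for _ in range(iters):
            layer(x).sum().backward()
        end.record()
        torch.cuda.synchronize()
        # elapsed_time returns milliseconds for the whole loop -> microseconds per iteration
        return start.elapsed_time(end) * 1e3 / iters

    M, N, K = 16640, 5120, 8192
    bf16_linear = nn.Linear(K, N, bias=False, device="cuda", dtype=torch.bfloat16)
    print(f"bf16: {bench_fwd_bwd_us(bf16_linear, M, K):.2f} us")
    # fp8_linear = Float8BlockwiseLinear(K, N, bias=False, use_triton=False)  # assumed API
    # print(f"fp8 scaled_mm: {bench_fwd_bwd_us(fp8_linear, M, K):.2f} us")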


pytorch-bot bot commented Aug 17, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/2785

Note: Links to docs will display an error until the docs builds have been completed.

✅ You can merge normally! (3 Unrelated Failures)

As of commit 3c78e2a with merge base 253d65a:

BROKEN TRUNK - The following jobs failed but were present on the merge base:

👉 Rebase onto the `viable/strict` branch to avoid these failures

This comment was automatically generated by Dr. CI and updates every 15 minutes.

danielvegamyhre added a commit that referenced this pull request Aug 17, 2025
…cript

stack-info: PR: #2785, branch: danielvegamyhre/stack/44
@danielvegamyhre danielvegamyhre force-pushed the danielvegamyhre/stack/44 branch from 3489e41 to fc4fe47 Compare August 17, 2025 16:17
@meta-cla meta-cla bot added the CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. label Aug 17, 2025
@danielvegamyhre danielvegamyhre marked this pull request as draft August 18, 2025 15:32
@danielvegamyhre danielvegamyhre changed the title integrate torch._scaled_mm into Float8BlockwiseLinear and add bench script [WIP] integrate torch._scaled_mm into Float8BlockwiseLinear and add bench script Aug 18, 2025
@danielvegamyhre danielvegamyhre changed the title [WIP] integrate torch._scaled_mm into Float8BlockwiseLinear and add bench script [fp8 blockwise] integrate torch._scaled_mm into Float8BlockwiseLinear and add bench script Aug 20, 2025
@danielvegamyhre danielvegamyhre marked this pull request as ready for review August 20, 2025 00:59
Contributor

drisspg commented Aug 20, 2025

Do we actually need these triton kernels? It seems like scaled_mm is pretty universally better; is there a valid case for the small shapes you mentioned previously?

I guess the other valid argument is that some people aren't using a new enough Triton to have this fp8 mm support.

Contributor Author

danielvegamyhre commented Aug 20, 2025

> Do we actually need these triton kernels? It seems like scaled_mm is pretty universally better; is there a valid case for the small shapes you mentioned previously?
>
> I guess the other valid argument is that some people aren't using a new enough Triton to have this fp8 mm support.

No, scaled_mm is definitely better. I just wanted to keep the Triton GEMMs so that, as a side project for learning, I can try optimizing them to match scaled_mm using additional features like TMA, etc.

@danielvegamyhre danielvegamyhre changed the base branch from danielvegamyhre/stack/43 to main August 20, 2025 17:24
danielvegamyhre added a commit that referenced this pull request Aug 20, 2025
…cript

stack-info: PR: #2785, branch: danielvegamyhre/stack/44
@danielvegamyhre danielvegamyhre force-pushed the danielvegamyhre/stack/44 branch from fc4fe47 to 8875c4f Compare August 20, 2025 17:24
@danielvegamyhre danielvegamyhre changed the title [fp8 blockwise] integrate torch._scaled_mm into Float8BlockwiseLinear and add bench script integrate torch._scaled_mm into Float8BlockwiseLinear and add bench script Aug 20, 2025
@danielvegamyhre danielvegamyhre changed the base branch from main to danielvegamyhre/stack/43 August 20, 2025 17:24
danielvegamyhre added a commit that referenced this pull request Aug 20, 2025
…cript

stack-info: PR: #2785, branch: danielvegamyhre/stack/44
@danielvegamyhre danielvegamyhre force-pushed the danielvegamyhre/stack/44 branch from 8875c4f to 07a35b6 Compare August 20, 2025 17:30
@danielvegamyhre danielvegamyhre changed the base branch from danielvegamyhre/stack/43 to main August 20, 2025 17:30
danielvegamyhre added a commit that referenced this pull request Aug 20, 2025
…cript

stack-info: PR: #2785, branch: danielvegamyhre/stack/44
@danielvegamyhre danielvegamyhre force-pushed the danielvegamyhre/stack/44 branch from 07a35b6 to ef6fc50 Compare August 20, 2025 20:35
@danielvegamyhre danielvegamyhre added the topic: not user facing Use this tag if you don't want this PR to show up in release notes label Aug 20, 2025
@danielvegamyhre danielvegamyhre changed the title integrate torch._scaled_mm into Float8BlockwiseLinear and add bench script [fp8 blockwise] integrate torch._scaled_mm into Float8BlockwiseLinear and add bench script Aug 20, 2025
danielvegamyhre added a commit that referenced this pull request Aug 20, 2025
…cript

stack-info: PR: #2785, branch: danielvegamyhre/stack/44
@danielvegamyhre danielvegamyhre force-pushed the danielvegamyhre/stack/44 branch from ef6fc50 to 3a81b75 Compare August 20, 2025 21:21
@danielvegamyhre danielvegamyhre changed the title [fp8 blockwise] integrate torch._scaled_mm into Float8BlockwiseLinear and add bench script integrate torch._scaled_mm into Float8BlockwiseLinear and add bench script Aug 20, 2025
@danielvegamyhre danielvegamyhre changed the title integrate torch._scaled_mm into Float8BlockwiseLinear and add bench script [fp8 blockwise] integrate torch._scaled_mm into Float8BlockwiseLinear and add bench script Aug 20, 2025
danielvegamyhre added a commit that referenced this pull request Aug 20, 2025
…cript

stack-info: PR: #2785, branch: danielvegamyhre/stack/44
@danielvegamyhre danielvegamyhre force-pushed the danielvegamyhre/stack/44 branch from 3a81b75 to f7a4a59 Compare August 20, 2025 21:25
@danielvegamyhre danielvegamyhre changed the title [fp8 blockwise] integrate torch._scaled_mm into Float8BlockwiseLinear and add bench script integrate torch._scaled_mm into Float8BlockwiseLinear and add bench script Aug 20, 2025
@@ -48,7 +48,7 @@ def _scaled_grouped_mm(
     """
     # TODO: Remove logging once prototype is more mature. This is currently very useful for development and debugging.
     if scaling_type == MoEScalingType.FP8_ROWWISE:
-        # print("Using fp8 rowwise scaled_grouped_mm")
+        print("Using fp8 rowwise scaled_grouped_mm")
Contributor

!

Contributor Author

Yeah, this shouldn't be commented out; it was mistakenly left that way in a previous PR, so I'm fixing it here.

Contributor

We shouldn't be printing, though; maybe logging, but for sure not printing.

Contributor Author

Agreed. I had this as logging originally but changed it to a print when I noticed the logs didn't show up in torchtitan training runs while I was trying to get ScaledGroupedMMTensor to work e2e, and I wanted to focus on the task at hand rather than debug whatever torchtitan is doing with the logging module.

I switched back to logging statements in this PR: #2835
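
For context, the change in #2835 amounts to the usual module-level logger pattern, roughly like the sketch below; the exact logger setup and message in that PR may differ.

    import logging

    logger = logging.getLogger(__name__)

    def _log_scaling_type(scaling_type) -> None:
        # logger.info respects the application's logging config (level, handlers),
        # unlike a bare print().
        logger.info("Using %s scaled_grouped_mm", scaling_type)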

Contributor Author

@drisspg I landed #2835 and rebased on top of main, so the print statements are logging now.

@danielvegamyhre danielvegamyhre changed the title integrate torch._scaled_mm into Float8BlockwiseLinear and add bench script [fp8 blockwise] integrate torch._scaled_mm into Float8BlockwiseLinear and add bench script Aug 21, 2025
@danielvegamyhre danielvegamyhre changed the title [fp8 blockwise] integrate torch._scaled_mm into Float8BlockwiseLinear and add bench script integrate torch._scaled_mm into Float8BlockwiseLinear and add bench script Aug 21, 2025
…cript

stack-info: PR: #2785, branch: danielvegamyhre/stack/44
@danielvegamyhre danielvegamyhre force-pushed the danielvegamyhre/stack/44 branch from f7a4a59 to 3c78e2a Compare August 22, 2025 21:34