integrate torch._scaled_mm into Float8BlockwiseLinear and add bench script #2785
Conversation
Do we actually need these triton kernels? It seems like scaled_mm is pretty universally better. Is there a valid case for the small shapes you mentioned previously? I guess the other valid argument is that some people aren't using a new enough Triton to have this fp8 mm support.
No, scaled_mm is definitely better. I just wanted to keep the Triton gemms so that, as a side project for learning, I can try optimizing them to match scaled_mm using additional features like TMA, etc.
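For reference, here is a minimal sketch of dispatching an fp8 gemm to torch._scaled_mm directly, using rowwise scales as a simpler illustration than the blockwise scaling this PR targets. torch._scaled_mm is a private PyTorch API, and its signature and supported scale layouts vary across versions and GPUs, so treat this as an assumption-laden example rather than the PR's implementation.

```python
# Minimal illustration (not the PR's code) of an fp8 gemm via torch._scaled_mm
# with rowwise scales. torch._scaled_mm is a private API; signature and
# supported scale layouts depend on the PyTorch version and GPU.
import torch

M, K, N = 16, 4096, 4096
fp8_dtype = torch.float8_e4m3fn
fp8_max = torch.finfo(fp8_dtype).max

a_hp = torch.randn(M, K, device="cuda", dtype=torch.bfloat16)
b_hp = torch.randn(N, K, device="cuda", dtype=torch.bfloat16)

# One float32 scale per row of each operand (rowwise scaling).
a_scale = (a_hp.abs().amax(dim=1, keepdim=True).float() / fp8_max).clamp(min=1e-12)
b_scale = (b_hp.abs().amax(dim=1, keepdim=True).float() / fp8_max).clamp(min=1e-12)

a_fp8 = (a_hp / a_scale).to(fp8_dtype)  # (M, K), row-major
b_fp8 = (b_hp / b_scale).to(fp8_dtype)  # (N, K), row-major

# scaled_mm wants the second operand column-major; transposing a row-major
# (N, K) tensor gives a column-major (K, N) view.
out = torch._scaled_mm(
    a_fp8,
    b_fp8.t(),
    scale_a=a_scale,      # (M, 1), float32
    scale_b=b_scale.t(),  # (1, N), float32
    out_dtype=torch.bfloat16,
)
```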
```diff
@@ -48,7 +48,7 @@ def _scaled_grouped_mm(
     """
     # TODO: Remove logging once prototype is more mature. This is currently very useful for development and debugging.
     if scaling_type == MoEScalingType.FP8_ROWWISE:
-        # print("Using fp8 rowwise scaled_grouped_mm")
+        print("Using fp8 rowwise scaled_grouped_mm")
```
!
Yeah, this shouldn't be commented out; it was mistakenly left that way in a previous PR, so I'm fixing it here.
We shouldn't be printing, though. Maybe logging, but for sure not printing.
Agreed. I originally had this as a logging statement, but changed it to a print when I noticed the logs didn't show up in torchtitan training runs while I was trying to get ScaledGroupedMMTensor working e2e, and I wanted to focus on the task at hand rather than debug whatever torchtitan is doing with the logging module.
I switched back to logging statements in this PR: #2835
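For illustration, a minimal sketch of the print-to-logging swap discussed here; the actual change landed in #2835 and may differ in detail (the helper name below is hypothetical).

```python
# Minimal sketch, assuming a module-level logger; not the exact code from #2835.
import logging

logger = logging.getLogger(__name__)

def log_scaling_path(scaling_type: str) -> None:
    # Debug-level so it stays out of normal training output unless the
    # application opts in via its logging configuration.
    logger.debug("Using %s scaled_grouped_mm", scaling_type)

if __name__ == "__main__":
    logging.basicConfig(level=logging.DEBUG)
    log_scaling_path("fp8 rowwise")
```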
Stacked PRs:
integrate torch._scaled_mm into Float8BlockwiseLinear and add bench script
Summary
Add a `use_triton` flag to `Float8BlockwiseLinear` that defaults to `False`. When `True`, use Triton for gemms. When `False`, use `torch._scaled_mm` for gemms.
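A hypothetical usage sketch of the flag; the import path and constructor arguments below are assumptions for illustration, not the exact torchao API.

```python
# Hypothetical usage sketch; import path and constructor kwargs are assumed.
import torch

# Assumed location of the prototype module.
from torchao.prototype.blockwise_fp8_training.linear import Float8BlockwiseLinear

# Default path: gemms dispatch to torch._scaled_mm.
fp8_linear = Float8BlockwiseLinear(
    4096, 4096, bias=False, device="cuda", dtype=torch.bfloat16
)

# Opt back into the Triton gemms (e.g. for kernel experimentation).
fp8_linear_triton = Float8BlockwiseLinear(
    4096, 4096, bias=False, device="cuda", dtype=torch.bfloat16, use_triton=True
)

x = torch.randn(16, 4096, device="cuda", dtype=torch.bfloat16, requires_grad=True)
y = fp8_linear(x)
y.sum().backward()
```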
Benchmarks
Benchmarking an eager linear forward + backward pass with llama4 shapes, using torch._scaled_mm improves perf by ~2x but is still slower than bf16. The next step is to look at the trace to see which quantization kernels are slowing things down.
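A rough sketch of the kind of eager forward + backward micro-benchmark described above; the shape, the import path, and the use of triton.testing.do_bench are assumptions rather than the PR's actual bench script.

```python
# Rough micro-benchmark sketch; shape and import path are assumptions.
import torch
from triton.testing import do_bench

# Assumed location of the prototype module.
from torchao.prototype.blockwise_fp8_training.linear import Float8BlockwiseLinear

M, K, N = 16384, 5120, 8192  # assumed llama4-like linear shape

x = torch.randn(M, K, device="cuda", dtype=torch.bfloat16, requires_grad=True)
bf16_lin = torch.nn.Linear(K, N, bias=False, device="cuda", dtype=torch.bfloat16)
fp8_lin = Float8BlockwiseLinear(K, N, bias=False, device="cuda", dtype=torch.bfloat16)

def fwd_bwd(mod):
    # One eager forward + backward pass through the given linear module.
    y = mod(x)
    y.sum().backward()
    x.grad = None

print(f"bf16 fwd+bwd:          {do_bench(lambda: fwd_bwd(bf16_lin)):.3f} ms")
print(f"fp8 blockwise fwd+bwd: {do_bench(lambda: fwd_bwd(fp8_lin)):.3f} ms")
```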