Re-implement FlashAttention with new Xe atoms #547
base: main
Conversation
I will break up this large commit into self-contained smaller commits after review is complete.
Why is this here? This isn't FlashAttention-specific, is it?
No, it's not. These started as some simple helpers to make copying to/from SLM easier for the epilogue. We could move them, maybe to include/cute/algorithm/cute.hpp, though they should be made more sophisticated (use smaller/larger block sizes as appropriate, automatic fallback to scatter/gather, etc.).
```cpp
// No diagnostics/error will be issued by the compiler if it is not.
template <typename T>
CUTE_HOST_DEVICE void
set_wi_value(T &x, int i, T val)
```
Why don't you take `i` as a compile-time value to make this safer? The usage is on line 137, where the input comes from the unrolled loop index. If you replace the loop with `for_each`, you have a compile-time constant.
That is an option -- I did it this way since compile-time unrolling of the loop is, IMO, harder to use and harder to read.
I opened a compiler ticket for the lack of diagnostics, and they have a patch under review now to address it.
I see. As long as we have a diagnostic, that's fine. The current solution won't compile at -O0, though. Not sure whether that matters.
Good point on -O0. How important is it to support -O0 operation? Does the rest of CUTLASS work OK under -O0? (I know SYCL in general has had some functionality issues at -O0.)
I believe the SYCL issues with -O0 have been resolved. I'm not aware of a good reason to compile with -O0; even for debugging, -O1/-Og tends to be better. I think other parts of CUTLASS are correct at -O0, but IGC used to crash for larger kernels compiled at -O0. I haven't tried it with more recent IGC versions.
This PR updates FlashAttention to the new copy/MMA atoms.
Changes:
Current status: prefill/decode examples almost all working, with similar or better performance compared to the old examples.
Known issues:
Additional features (causal masking, variable sequence lengths, etc.) to be added later.
Reminder: the new atoms require a very recent driver due to necessary IGC fixes/enhancements. Recommended version: ci-comp_igc-30613.