[example] flex attention #764
base: main
Conversation
Pull Request Overview
This PR implements a custom flex attention kernel using Helion and PyTorch for efficient computation of scaled dot-product attention with support for static and dynamic input shapes.
- Adds a complete flex attention kernel implementation in examples/flex_attention.py
- Integrates the new example with the benchmarking system via benchmarks/run.py
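For context, a kernel like this would typically be validated against PyTorch's reference flex_attention with the same score_mod. The sketch below is illustrative only: `helion_flex_attention` is a placeholder name for the kernel added in examples/flex_attention.py, and the shapes and the `relative_bias` score_mod are assumptions, not taken from this PR.

```python
import torch
from torch.nn.attention.flex_attention import flex_attention

# Example score_mod: add a distance-based bias to each attention score.
def relative_bias(score, b, h, q_idx, kv_idx):
    return score + (q_idx - kv_idx)

# Hypothetical shapes: (batch, heads, seq_len, head_dim).
q, k, v = (
    torch.randn(2, 8, 1024, 64, device="cuda", dtype=torch.float16)
    for _ in range(3)
)
ref = flex_attention(q, k, v, score_mod=relative_bias)
# `helion_flex_attention` is a placeholder for the kernel in examples/flex_attention.py;
# compare it against the PyTorch reference:
# out = helion_flex_attention(q, k, v, score_mod=relative_bias)
# torch.testing.assert_close(out, ref, rtol=1e-2, atol=1e-2)
```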
Reviewed Changes
Copilot reviewed 2 out of 2 changed files in this pull request and generated 8 comments.
| File | Description |
|---|---|
| examples/flex_attention.py | Complete implementation of the flex attention kernel with Helion optimization and PyTorch compatibility |
| benchmarks/run.py | Adds flex_attention to the benchmark operators mapping |
```python
# iterate through partial tiles
if True:
```
Copilot AI (Oct 6, 2025):
Replace `if True:` with a meaningful condition, or add a comment explaining why this branch is always executed; a bare `if True:` makes the code structure unclear.
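One way to address this, assuming the branch exists only to scope the partial-tile pass (the flag name below is made up, not from the PR):

```python
# Hypothetical rewrite; process_partial_tiles is an assumed name.
process_partial_tiles = True  # currently always taken; kept as a named flag for clarity

if process_partial_tiles:
    # iterate through partial tiles
    ...
```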
```python
qk = hl.dot(q_i, k.T, acc=qk)
qk = qk * scale
bcast_qk = qk[None, None, :, :]
score = score_mod(bcast_qk, bcast_b, bcast_h, bcast_m, bcast_n)
```
Copilot AI (Oct 6, 2025):
Remove the commented-out masking code or add a comment explaining why it is temporarily disabled; commented-out code hurts readability and maintainability.

Suggested change:

```python
score = score_mod(bcast_qk, bcast_b, bcast_h, bcast_m, bcast_n)
# The following masking code is temporarily disabled for debugging purposes.
# Re-enable if masking is required for your use case.
```
```python
# figure out how many tiles there are here
for tile_n in hl.tile(start_n, end_n, block_size=block_n):
    k = key[tile_b.begin, tile_h.begin // num_groups, tile_n, :]
```
Copilot AI (Oct 6, 2025):
The indexing expression `tile_h.begin // num_groups` is duplicated across multiple locations. Consider extracting it into a variable for better maintainability.
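A minimal sketch of that extraction, assuming the expression maps a query-head tile to its shared KV head for grouped-query attention (the name `kv_head` is hypothetical, not from the PR):

```python
# Hypothetical refactor; `kv_head` is an assumed name.
kv_head = tile_h.begin // num_groups  # query-head tile -> shared KV head

for tile_n in hl.tile(start_n, end_n, block_size=block_n):
    k = key[tile_b.begin, kv_head, tile_n, :]
    ...
    v = value[tile_b.begin, kv_head, tile_n, :]
```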
```python
m_i = m_ij
l_i = l_i * alpha + l_ij
acc = acc * alpha[:, None]
v = value[tile_b.begin, tile_h.begin // num_groups, tile_n, :]
```
Copilot AI (Oct 6, 2025):
The indexing expression `tile_h.begin // num_groups` is duplicated across multiple locations. Consider extracting it into a variable for better maintainability.
```python
# figure out how many tiles there are here
for tile_n in hl.tile(start_n, end_n, block_size=block_n):
    k = key[tile_b.begin, tile_h.begin // num_groups, tile_n, :]
```
Copilot AI (Oct 6, 2025):
The indexing expression `tile_h.begin // num_groups` is duplicated across multiple locations. Consider extracting it into a variable for better maintainability.
```python
m_i = m_ij
l_i = l_i * alpha + l_ij
acc = acc * alpha[:, None]
v = value[tile_b.begin, tile_h.begin // num_groups, tile_n, :]
```
Copilot AI (Oct 6, 2025):
The indexing expression `tile_h.begin // num_groups` is duplicated across multiple locations. Consider extracting it into a variable for better maintainability.
```python
for tile_n in hl.tile(start_n, end_n, block_size=block_n):
    k = key[tile_b.begin, tile_h.begin // num_groups, tile_n, :]
    bcast_b = (tile_b.begin + hl.arange(tile_b.block_size))[
        :, None, None, None
    ]
    bcast_h = (tile_h.begin + hl.arange(tile_h.block_size))[
        None, :, None, None
    ]
    bcast_m = (tile_m.begin + hl.arange(tile_m.block_size))[
        None, None, :, None
    ]
    bcast_n = (tile_n.begin + hl.arange(tile_n.block_size))[
        None, None, None, :
    ]
    qk = hl.zeros([tile_m, tile_n], dtype=torch.float32)
    qk = hl.dot(q_i, k.T, acc=qk)
    qk = qk * scale
    bcast_qk = qk[None, None, :, :]
    score = score_mod(bcast_qk, bcast_b, bcast_h, bcast_m, bcast_n)
    # mask = block_mask_mask_mod(bcast_b, bcast_h, bcast_m, bcast_n)
    # score = torch.where(mask, score, -float("inf"))
    qk = score.squeeze(0).squeeze(0)

    m_ij = torch.maximum(m_i, torch.amax(qk, -1))
    qk = qk - m_ij[:, None]
    p = torch.exp2(log_2_e * qk)
    l_ij = torch.sum(p, -1)
    alpha = torch.exp2(m_i - m_ij)
    m_i = m_ij
    l_i = l_i * alpha + l_ij
    acc = acc * alpha[:, None]
    v = value[tile_b.begin, tile_h.begin // num_groups, tile_n, :]
    p = p.to(v.dtype)
    acc = hl.dot(p, v, acc=acc)
```
Copilot AI (Oct 6, 2025):
There's significant code duplication between the two attention computation blocks (lines 92-125 and 144-177). Consider extracting this logic into a helper function to reduce duplication.
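A rough sketch of what that extraction could look like. It assumes the same context as examples/flex_attention.py (torch, hl, and the surrounding kernel variables) and that Helion allows the loop body to live in a helper; the names `attend_tiles`, `kv_head`, and `apply_mask` are made up, and the only intended difference between the two blocks is whether the mask_mod is applied.

```python
# Hypothetical refactor, not part of this PR.
def attend_tiles(q_i, key, value, kv_head, score_mod, mask_mod,
                 tile_b, tile_h, tile_m, start_n, end_n, block_n,
                 scale, log_2_e, m_i, l_i, acc, apply_mask):
    """Shared inner loop for the two nearly identical attention blocks."""
    for tile_n in hl.tile(start_n, end_n, block_size=block_n):
        k = key[tile_b.begin, kv_head, tile_n, :]
        bcast_b = (tile_b.begin + hl.arange(tile_b.block_size))[:, None, None, None]
        bcast_h = (tile_h.begin + hl.arange(tile_h.block_size))[None, :, None, None]
        bcast_m = (tile_m.begin + hl.arange(tile_m.block_size))[None, None, :, None]
        bcast_n = (tile_n.begin + hl.arange(tile_n.block_size))[None, None, None, :]
        qk = hl.dot(q_i, k.T, acc=hl.zeros([tile_m, tile_n], dtype=torch.float32)) * scale
        score = score_mod(qk[None, None, :, :], bcast_b, bcast_h, bcast_m, bcast_n)
        if apply_mask:  # only one of the two passes applies mask_mod
            mask = mask_mod(bcast_b, bcast_h, bcast_m, bcast_n)
            score = torch.where(mask, score, -float("inf"))
        qk = score.squeeze(0).squeeze(0)
        # online-softmax update, same math as the blocks above
        m_ij = torch.maximum(m_i, torch.amax(qk, -1))
        p = torch.exp2(log_2_e * (qk - m_ij[:, None]))
        alpha = torch.exp2(m_i - m_ij)
        l_i = l_i * alpha + torch.sum(p, -1)
        acc = acc * alpha[:, None]
        v = value[tile_b.begin, kv_head, tile_n, :]
        acc = hl.dot(p.to(v.dtype), v, acc=acc)
        m_i = m_ij
    return m_i, l_i, acc
```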
```python
for tile_n in hl.tile(start_n, end_n, block_size=block_n):
    k = key[tile_b.begin, tile_h.begin // num_groups, tile_n, :]
    bcast_b = (tile_b.begin + hl.arange(tile_b.block_size))[
        :, None, None, None
    ]
    bcast_h = (tile_h.begin + hl.arange(tile_h.block_size))[
        None, :, None, None
    ]
    bcast_m = (tile_m.begin + hl.arange(tile_m.block_size))[
        None, None, :, None
    ]
    bcast_n = (tile_n.begin + hl.arange(tile_n.block_size))[
        None, None, None, :
    ]
    qk = hl.zeros([tile_m, tile_n], dtype=torch.float32)
    qk = hl.dot(q_i, k.T, acc=qk)
    qk = qk * scale
    bcast_qk = qk[None, None, :, :]
    score = score_mod(bcast_qk, bcast_b, bcast_h, bcast_m, bcast_n)
    mask = block_mask_mask_mod(bcast_b, bcast_h, bcast_m, bcast_n)
    score = torch.where(mask, score, -float("inf"))
    qk = score.squeeze(0).squeeze(0)

    m_ij = torch.maximum(m_i, torch.amax(qk, -1))
    qk = qk - m_ij[:, None]
    p = torch.exp2(log_2_e * qk)
    l_ij = torch.sum(p, -1)
    alpha = torch.exp2(m_i - m_ij)
    m_i = m_ij
    l_i = l_i * alpha + l_ij
    acc = acc * alpha[:, None]
    v = value[tile_b.begin, tile_h.begin // num_groups, tile_n, :]
    p = p.to(v.dtype)
    acc = hl.dot(p, v, acc=acc)
```
Copilot AI (Oct 6, 2025):
There's significant code duplication between the two attention computation blocks (lines 92-125 and 144-177). Consider extracting this logic into a helper function to reduce duplication.
@v0i0 how does the perf look?
Yes, curious about the same, and also wondering whether we have results from the tritonbench accuracy checks.
Yeah, I believe it's still being worked on. Maybe we'll wait for regular attention perf first, and then update this PR to generate the same pattern for best perf.
20x sounds like more of a perf bug in the kernel.
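For anyone reproducing the numbers locally, a minimal timing comparison could look like the sketch below. It is only an outline: `helion_flex_attention` is a placeholder for the kernel in examples/flex_attention.py, the shapes are arbitrary, and this is not the tritonbench setup used for the reported results.

```python
import torch
from torch.nn.attention.flex_attention import flex_attention
from triton.testing import do_bench

# Placeholder import; the actual entry point name in examples/flex_attention.py may differ.
# from examples.flex_attention import helion_flex_attention

q, k, v = (
    torch.randn(4, 32, 2048, 64, device="cuda", dtype=torch.float16)
    for _ in range(3)
)
compiled_ref = torch.compile(flex_attention)

ref_ms = do_bench(lambda: compiled_ref(q, k, v))
# helion_ms = do_bench(lambda: helion_flex_attention(q, k, v))
print(f"compiled flex_attention: {ref_ms:.3f} ms")
```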