
Conversation

v0i0 (Contributor) commented Oct 1, 2025

No description provided.

meta-cla bot added the CLA Signed label on Oct 1, 2025
v0i0 requested a review from yf225 on October 6, 2025 23:31
choijon5 requested a review from Copilot on October 6, 2025 23:31
Copilot AI left a comment

Pull Request Overview

This PR implements a custom flex attention kernel using Helion and PyTorch for efficient computation of scaled dot-product attention with support for static and dynamic input shapes.

  • Adds a complete flex attention kernel implementation in examples/flex_attention.py
  • Integrates the new example with the benchmarking system via benchmarks/run.py
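
For reference, a minimal sketch of how the example could be checked against PyTorch's eager flex_attention reference. The Helion entry-point name below is hypothetical (the real one lives in examples/flex_attention.py added by this PR), and the score_mod is just an illustrative relative-position bias:

    # Sketch only: compare a Helion flex attention kernel against PyTorch's reference.
    import torch
    from torch.nn.attention.flex_attention import flex_attention

    def rel_bias(score, b, h, q_idx, kv_idx):
        # score_mod: add a relative positional bias to each attention score
        return score + (q_idx - kv_idx)

    B, H, S, D = 4, 8, 1024, 64
    q, k, v = (torch.randn(B, H, S, D, device="cuda", dtype=torch.float16) for _ in range(3))

    ref = flex_attention(q, k, v, score_mod=rel_bias)
    # out = helion_flex_attention(q, k, v, score_mod=rel_bias)   # hypothetical: kernel under review
    # torch.testing.assert_close(out, ref, rtol=2e-2, atol=2e-2)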

Reviewed Changes

Copilot reviewed 2 out of 2 changed files in this pull request and generated 8 comments.

File | Description
examples/flex_attention.py | Complete implementation of flex attention kernel with Helion optimization and PyTorch compatibility
benchmarks/run.py | Adds flex_attention to the benchmark operators mapping



# iterate through partial tiles
if True:
Copilot AI commented Oct 6, 2025

Replace if True: with a meaningful condition or comment explaining why this branch is always executed. This makes the code structure unclear.
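
A minimal sketch of what the suggestion could look like; has_partial_tiles is a hypothetical name, and the real predicate depends on how the kernel derives its partial tiles:

    # Hypothetical: name the condition the branch encodes instead of `if True:`.
    has_partial_tiles = True  # e.g. derived from the block mask's partial-block metadata
    if has_partial_tiles:
        # iterate through partial tiles
        ...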


qk = hl.dot(q_i, k.T, acc=qk)
qk = qk * scale
bcast_qk = qk[None, None, :, :]
score = score_mod(bcast_qk, bcast_b, bcast_h, bcast_m, bcast_n)
Copilot AI commented Oct 6, 2025

Remove commented-out code or provide a comment explaining why it's temporarily disabled. Commented code affects readability and maintainability.

Suggested change:
- score = score_mod(bcast_qk, bcast_b, bcast_h, bcast_m, bcast_n)
+ score = score_mod(bcast_qk, bcast_b, bcast_h, bcast_m, bcast_n)
+ # The following masking code is temporarily disabled for debugging purposes.
+ # Re-enable if masking is required for your use case.



# figure out how many tiles there are here
for tile_n in hl.tile(start_n, end_n, block_size=block_n):
    k = key[tile_b.begin, tile_h.begin // num_groups, tile_n, :]
Copilot AI commented Oct 6, 2025

The indexing logic tile_h.begin // num_groups is duplicated across multiple locations. Consider extracting this into a variable for better maintainability.
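
A sketch of the suggested hoist (the variable name kv_head is illustrative):

    # Hoist the repeated GQA index so the key and value reads share one definition.
    kv_head = tile_h.begin // num_groups  # key/value head serving this group of query heads
    k = key[tile_b.begin, kv_head, tile_n, :]
    ...
    v = value[tile_b.begin, kv_head, tile_n, :]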


m_i = m_ij
l_i = l_i * alpha + l_ij
acc = acc * alpha[:, None]
v = value[tile_b.begin, tile_h.begin // num_groups, tile_n, :]
Copilot AI commented Oct 6, 2025

The indexing logic tile_h.begin // num_groups is duplicated across multiple locations. Consider extracting this into a variable for better maintainability.



# figure out how many tiles there are here
for tile_n in hl.tile(start_n, end_n, block_size=block_n):
    k = key[tile_b.begin, tile_h.begin // num_groups, tile_n, :]
Copilot AI commented Oct 6, 2025

The indexing logic tile_h.begin // num_groups is duplicated across multiple locations. Consider extracting this into a variable for better maintainability.


m_i = m_ij
l_i = l_i * alpha + l_ij
acc = acc * alpha[:, None]
v = value[tile_b.begin, tile_h.begin // num_groups, tile_n, :]
Copilot AI commented Oct 6, 2025

The indexing logic tile_h.begin // num_groups is duplicated across multiple locations. Consider extracting this into a variable for better maintainability.


Comment on lines +92 to +125
for tile_n in hl.tile(start_n, end_n, block_size=block_n):
    k = key[tile_b.begin, tile_h.begin // num_groups, tile_n, :]
    bcast_b = (tile_b.begin + hl.arange(tile_b.block_size))[
        :, None, None, None
    ]
    bcast_h = (tile_h.begin + hl.arange(tile_h.block_size))[
        None, :, None, None
    ]
    bcast_m = (tile_m.begin + hl.arange(tile_m.block_size))[
        None, None, :, None
    ]
    bcast_n = (tile_n.begin + hl.arange(tile_n.block_size))[
        None, None, None, :
    ]
    qk = hl.zeros([tile_m, tile_n], dtype=torch.float32)
    qk = hl.dot(q_i, k.T, acc=qk)
    qk = qk * scale
    bcast_qk = qk[None, None, :, :]
    score = score_mod(bcast_qk, bcast_b, bcast_h, bcast_m, bcast_n)
    # mask = block_mask_mask_mod(bcast_b, bcast_h, bcast_m, bcast_n)
    # score = torch.where(mask, score, -float("inf"))
    qk = score.squeeze(0).squeeze(0)

    m_ij = torch.maximum(m_i, torch.amax(qk, -1))
    qk = qk - m_ij[:, None]
    p = torch.exp2(log_2_e * qk)
    l_ij = torch.sum(p, -1)
    alpha = torch.exp2(m_i - m_ij)
    m_i = m_ij
    l_i = l_i * alpha + l_ij
    acc = acc * alpha[:, None]
    v = value[tile_b.begin, tile_h.begin // num_groups, tile_n, :]
    p = p.to(v.dtype)
    acc = hl.dot(p, v, acc=acc)
Copilot AI commented Oct 6, 2025

There's significant code duplication between the two attention computation blocks (lines 92-125 and 144-177). Consider extracting this logic into a helper function to reduce duplication.
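
One possible shape for that helper, written as a hedged sketch: whether Helion supports factoring the traced loop body into a plain Python function like this has not been verified here, and the signature is illustrative only. It mirrors the hunks above (including their exp2 conventions) and assumes the repo's usual import helion.language as hl style; mask_mod=None would cover the unmasked loop, while the partial-block loop would pass block_mask_mask_mod.

    import torch
    import helion.language as hl  # assumed import convention

    # Hypothetical helper shared by both tile loops.
    def _attend_tile(q_i, k, v, m_i, l_i, acc, qk_acc, scale, log_2_e,
                     score_mod, mask_mod, bcast_b, bcast_h, bcast_m, bcast_n):
        qk = hl.dot(q_i, k.T, acc=qk_acc) * scale
        score = score_mod(qk[None, None, :, :], bcast_b, bcast_h, bcast_m, bcast_n)
        if mask_mod is not None:
            mask = mask_mod(bcast_b, bcast_h, bcast_m, bcast_n)
            score = torch.where(mask, score, -float("inf"))
        qk = score.squeeze(0).squeeze(0)
        # online-softmax update of the running max, normalizer, and accumulator
        m_ij = torch.maximum(m_i, torch.amax(qk, -1))
        p = torch.exp2(log_2_e * (qk - m_ij[:, None]))
        alpha = torch.exp2(m_i - m_ij)
        l_i = l_i * alpha + torch.sum(p, -1)
        acc = hl.dot(p.to(v.dtype), v, acc=acc * alpha[:, None])
        return m_ij, l_i, acc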


Comment on lines +144 to +177
for tile_n in hl.tile(start_n, end_n, block_size=block_n):
    k = key[tile_b.begin, tile_h.begin // num_groups, tile_n, :]
    bcast_b = (tile_b.begin + hl.arange(tile_b.block_size))[
        :, None, None, None
    ]
    bcast_h = (tile_h.begin + hl.arange(tile_h.block_size))[
        None, :, None, None
    ]
    bcast_m = (tile_m.begin + hl.arange(tile_m.block_size))[
        None, None, :, None
    ]
    bcast_n = (tile_n.begin + hl.arange(tile_n.block_size))[
        None, None, None, :
    ]
    qk = hl.zeros([tile_m, tile_n], dtype=torch.float32)
    qk = hl.dot(q_i, k.T, acc=qk)
    qk = qk * scale
    bcast_qk = qk[None, None, :, :]
    score = score_mod(bcast_qk, bcast_b, bcast_h, bcast_m, bcast_n)
    mask = block_mask_mask_mod(bcast_b, bcast_h, bcast_m, bcast_n)
    score = torch.where(mask, score, -float("inf"))
    qk = score.squeeze(0).squeeze(0)

    m_ij = torch.maximum(m_i, torch.amax(qk, -1))
    qk = qk - m_ij[:, None]
    p = torch.exp2(log_2_e * qk)
    l_ij = torch.sum(p, -1)
    alpha = torch.exp2(m_i - m_ij)
    m_i = m_ij
    l_i = l_i * alpha + l_ij
    acc = acc * alpha[:, None]
    v = value[tile_b.begin, tile_h.begin // num_groups, tile_n, :]
    p = p.to(v.dtype)
    acc = hl.dot(p, v, acc=acc)
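
For readers skimming the two hunks: both loops implement the standard online-softmax (flash-attention style) update, with exp rewritten via $e^{x} = 2^{x \log_2 e}$, which is what the log_2_e factor is for. For each key/value tile with scores $s_j$, the running state $(m, \ell, \mathrm{acc})$ is updated as

$$
\begin{aligned}
m' &= \max\bigl(m,\ \max_j s_j\bigr), \qquad \alpha = e^{m - m'},\\
\ell' &= \alpha\,\ell + \textstyle\sum_j e^{s_j - m'},\\
\mathrm{acc}' &= \alpha\,\mathrm{acc} + \textstyle\sum_j e^{s_j - m'}\, v_j,
\end{aligned}
$$

with the output row taken as $\mathrm{acc} / \ell$ after the last tile (typically handled outside the quoted hunks).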
Copilot AI commented Oct 6, 2025

There's significant code duplication between the two attention computation blocks (lines 92-125 and 144-177). Consider extracting this logic into a helper function to reduce duplication.


oulgen (Contributor) commented Oct 7, 2025

@v0i0 how does the perf look?

yf225 (Contributor) commented Oct 7, 2025

> @v0i0 how does the perf look?

yes, curious about the same, and also wondering if we have results from the tritonbench accuracy checks

v0i0 (Contributor, Author) commented Oct 7, 2025

@oulgen @yf225 perf is currently absolutely atrocious, like 20x slowdown iirc. i wonder if triton hates the way the broadcast is done, or something else. do you think whatever you're doing for regular attn perf could apply here / what is that perf like?

yf225 (Contributor) commented Oct 7, 2025

> do you think whatever you're doing for regular attn perf could apply here / what is that perf like

yeah I believe it's still being worked on - maybe we will wait for regular attn perf first, and then we can update this PR to generate the same pattern for best perf

jansel (Contributor) commented Oct 7, 2025

20x sounds like more of a perf bug in the kernel.
