Raised by @v0i0 - supporting device loops inside helper functions would help consolidate common logic across attention variants, e.g. in #764.
Example:
```python
from __future__ import annotations

import torch

import helion
import helion.language as hl
from helion.language import Tile


# TODO: we could add some decorator here to specifically say that
# "this is a Helion device loop", e.g. `@helion.device_loop()`
def inner_device_loop(tile: Tile, x_chunk: torch.Tensor, y_chunk: torch.Tensor) -> torch.Tensor:
    """Device helper that performs its own hl.tile iteration."""
    tmp = torch.empty_like(x_chunk)
    # Second-level device loop: iterate over the elements owned by ``tile``.
    for local_tile in hl.tile(tile.block_size, block_size=32):
        tmp[local_tile] = x_chunk[local_tile] + y_chunk[local_tile]
    return tmp


@helion.kernel()
def nested_device_loops(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    """Outer kernel that delegates a chunk of work to ``inner_device_loop``."""
    assert x.shape == y.shape
    out = torch.empty_like(x)
    # First-level device loop tiles the full iteration space.
    for tile in hl.tile(x.numel(), block_size=128):
        x_chunk = x[tile]
        y_chunk = y[tile]
        # Call into a helper that contains another device loop.
        out[tile] = inner_device_loop(tile, x_chunk, y_chunk)
    return out


def main() -> None:
    if not torch.cuda.is_available():
        raise RuntimeError("This example expects a CUDA-capable device.")
    size = 1 << 12
    x = torch.randn(size, device="cuda", dtype=torch.float32)
    y = torch.randn(size, device="cuda", dtype=torch.float32)
    out = nested_device_loops(x, y)
    torch.testing.assert_close(out, x + y)


if __name__ == "__main__":
    main()
```
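
For concreteness, here is a minimal sketch of what the decorator floated in the TODO above could look like on the helper. To be clear, `helion.device_loop` does not exist in Helion today; the name and signature are purely illustrative, taken from the TODO, and are shown only to suggest how a helper containing a device loop might be marked so the compiler knows to trace it as device code rather than as a host-side call.

```python
# Hypothetical API sketch: `helion.device_loop` is NOT a real Helion API;
# it is the decorator suggested in the TODO above, shown for illustration.
@helion.device_loop()
def inner_device_loop(tile: Tile, x_chunk: torch.Tensor, y_chunk: torch.Tensor) -> torch.Tensor:
    """Device helper whose body is inlined into the calling kernel's device loop."""
    tmp = torch.empty_like(x_chunk)
    # Nested device loop over the elements owned by the caller's ``tile``.
    for local_tile in hl.tile(tile.block_size, block_size=32):
        tmp[local_tile] = x_chunk[local_tile] + y_chunk[local_tile]
    return tmp
```

An explicit marker like this would let the kernel decorator distinguish "helper that must be inlined and traced as device code" from an ordinary Python function, which seems to be the consolidation point the attention-variant use case needs.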