19 changes: 18 additions & 1 deletion helion/_compiler/generate_ast.py
@@ -376,7 +376,16 @@ def has_mask(self) -> bool:
)


def generate_ast(func: HostFunction, config: Config) -> ast.AST:
def emit_main_def() -> ast.stmt:
return statement_from_string("""
if __name__ == "__main__":
call()
""")


def generate_ast(
func: HostFunction, config: Config, emit_repro_caller: bool
) -> ast.AST:
with func:
codegen = GenerateAST(func, config)
with codegen.device_function:
@@ -386,12 +395,20 @@ def generate_ast(func: HostFunction, config: Config) -> ast.AST:
codegen.host_dead_code_elimination()
host_def = func.codegen_function_def(codegen.host_statements)

call_def = []
main_def = []
if emit_repro_caller:
call_def = [func.codegen_call_function()]
main_def = [emit_main_def()]

result = ast.Module(
[
*func.codegen_imports(),
*codegen.device_function.codegen_helper_functions(),
*kernel_def,
host_def,
*call_def,
*main_def,
],
[],
)
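
For reference, a standalone sketch of what emit_main_def produces. The real helper goes through statement_from_string; approximating it with the stdlib ast module (an assumption about its implementation), the guard source parses to a single statement that can be appended to the generated module:

import ast

# Parse the __main__ guard the way emit_main_def's source string is parsed;
# the result is one ast.If statement ready to append to an ast.Module body.
stmt = ast.parse('if __name__ == "__main__":\n    call()').body[0]
print(type(stmt).__name__)  # If
print(ast.unparse(stmt))    # if __name__ == '__main__': call()
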
51 changes: 50 additions & 1 deletion helion/_compiler/host_function.py
@@ -13,6 +13,8 @@
import sympy
import torch
from torch._inductor.codegen.wrapper import pexpr
from torch.utils._sympy.symbol import SymT
from torch.utils._sympy.symbol import symbol_is_type

from .. import exc
from . import ast_extension
@@ -98,14 +100,17 @@ def __init__(
self.args: ast.arguments = root.args
self.body: list[ast.stmt] = root.body

self.params = inspect.signature(fn).bind(*fake_args)
self.params.apply_defaults()

HostFunction.validate_ast(root)

from .device_ir import lower_to_device_ir
from .static_loop_unroller import unroll_static_loops
from .type_propagation import propagate_types

unroll_static_loops(self)
propagate_types(self, fake_args)
propagate_types(self)
env.finalize_config_spec()
self.device_ir = lower_to_device_ir(self)

Expand Down Expand Up @@ -246,6 +251,50 @@ def codegen_function_def(self, statements: list[ast.AST]) -> ast.FunctionDef:
type_params=None,
)

def codegen_call_function(self) -> ast.FunctionDef:
def stringify(arg: object) -> str:
if isinstance(arg, (list, tuple)):
parts = [stringify(a) for a in arg]
return f"({','.join(parts)},)"
if isinstance(arg, str):
return f'"{arg}"'
if isinstance(arg, torch.SymInt):
return str(CompileEnvironment.current().size_hint(arg))
if isinstance(arg, torch.SymFloat):
if symbol_is_type(arg.node.expr, SymT.UNBACKED_FLOAT):
return "1.1"
return str(arg.node._hint)
if isinstance(arg, torch.SymBool):
if not arg.node._hint:
return "False"
return str(arg.node._hint)
return str(arg)

inits = []
for name, arg in self.params.arguments.items():
if isinstance(arg, torch.Tensor):
rhs = f"rand_strided(size={stringify(arg.size())}, stride={stringify(arg.stride())}, dtype={arg.dtype}, device='{arg.device}')"
else:
rhs = stringify(arg)
inits.append(statement_from_string(f"{name} = {rhs}"))

call_args = self.params.arguments.keys()
statements = [
statement_from_string("from torch._dynamo.testing import rand_strided"),
*inits,
statement_from_string(f"{self.name}({', '.join(call_args)})"),
]
return ast_extension.create(
ast.FunctionDef,
name="call",
args=[],
body=statements,
decorator_list=[],
type_comment=None,
returns=None,
type_params=None,
)

def codegen_imports(self) -> list[ast.stmt]:
return [
statement_from_string(line.codegen())
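
As an illustration of the rhs string built above for tensor arguments, a runnable sketch of the rand_strided reconstruction (device switched to "cpu" here so it runs anywhere; the generated repro uses the argument's original device):

import torch
from torch._dynamo.testing import rand_strided

# Rebuild a tensor purely from recorded size/stride/dtype/device metadata.
x = rand_strided(size=(16, 1), stride=(1, 1), dtype=torch.float32, device="cpu")
print(x.shape, x.stride())  # torch.Size([16, 1]) (1, 1)
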
7 changes: 2 additions & 5 deletions helion/_compiler/type_propagation.py
@@ -5,7 +5,6 @@
import contextlib
import dataclasses
import functools
import inspect
import re
import types
from typing import TYPE_CHECKING
@@ -2264,14 +2263,12 @@ def _to_proxy(arg: TypeInfo) -> object:
raise exc.TracedArgNotSupported(arg) from None


def propagate_types(func: HostFunction, fake_args: list[object]) -> None:
def propagate_types(func: HostFunction) -> None:
# Lock needed since patch.object(torch.SymInt.__index__, ...) is not thread safe
with compile_lock, func, enable_python_dispatcher():
global_scope = GlobalScope(function=func)
local_scope = LocalScope(parent=global_scope)
params = inspect.signature(func.fn).bind(*fake_args)
params.apply_defaults()
for name, value in params.arguments.items():
for name, value in func.params.arguments.items():
# TODO(jansel): handle specializations/constexpr
type_info = TypeInfo.from_example(
value,
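
The binding logic removed here now runs once in HostFunction.__init__ (see host_function.py above). A minimal standalone sketch of that inspect.signature(...).bind(...) pattern:

import inspect

def kernel(t, i=8192, s="foo"):
    pass

# Bind the positional fake args once and fill in defaults, so later passes
# can read params.arguments instead of re-binding per pass.
params = inspect.signature(kernel).bind("fake_tensor")
params.apply_defaults()
print(dict(params.arguments))  # {'t': 'fake_tensor', 'i': 8192, 's': 'foo'}
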
11 changes: 10 additions & 1 deletion helion/_testing.py
@@ -32,7 +32,8 @@


DEVICE = torch.device("cuda")
EXAMPLES_DIR: Path = Path(__file__).parent.parent / "examples"
PROJECT_ROOT: Path = Path(__file__).parent.parent
EXAMPLES_DIR: Path = PROJECT_ROOT / "examples"


def skipIfRefEager(reason: str) -> Callable[[Callable], Callable]:
@@ -518,6 +519,13 @@ def normalize_tensor_descriptors(code: str) -> str:
get_tensor_descriptor_fn_name(), "tl.make_tensor_descriptor"
)

@staticmethod
def normalize_device_name(code: str) -> str:
"""
Convert device='cuda:0' etc. to device=DEVICE.
"""
return re.sub(r"device\s*=\s*['\"][^'\"]+['\"]", "device=DEVICE", code)

def lookup(self, test_id: str, value: str) -> tuple[str, str]:
test_id = self.normalize_id(test_id)
if self._current_id != test_id:
@@ -533,6 +541,7 @@ def lookup(self, test_id: str, value: str) -> tuple[str, str]:
expected = ""

value = self.normalize_tensor_descriptors(value)
value = self.normalize_device_name(value)
value = value.strip()
if value != expected and os.environ.get("EXPECTTEST_ACCEPT", "0") not in {
"0",
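
A quick standalone check of normalize_device_name, using the exact pattern from the diff:

import re

code = "t = rand_strided(size=(16, 1), stride=(1, 1), dtype=torch.float32, device='cuda:0')"
# Any quoted device string becomes the symbolic DEVICE placeholder before
# the generated code is compared against the expected journal.
print(re.sub(r"device\s*=\s*['\"][^'\"]+['\"]", "device=DEVICE", code))
# t = rand_strided(size=(16, 1), stride=(1, 1), dtype=torch.float32, device=DEVICE)
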
14 changes: 10 additions & 4 deletions helion/runtime/kernel.py
@@ -378,12 +378,15 @@ def configs(self) -> list[Config]:
"""
return self.kernel.configs

def to_triton_code(self, config: ConfigLike | None = None) -> str:
def to_triton_code(
self, config: ConfigLike | None = None, emit_repro_caller: bool = False
) -> str:
"""
Generate Triton code for the kernel based on the given configuration.

Args:
config: The configuration to use for code generation.
emit_repro_caller: Whether to also emit a main function that calls the Triton kernel with example inputs.

Returns:
str: The generated Triton code as a string.
@@ -394,7 +397,7 @@ def to_triton_code(self, config: ConfigLike | None = None) -> str:
if not isinstance(config, Config):
config = Config(**config) # pyright: ignore[reportArgumentType]
self.env.config_spec.normalize(config)
root = generate_ast(self.host_function, config)
root = generate_ast(self.host_function, config, emit_repro_caller)
return get_needed_imports(root) + unparse(root)

def compile_config(
@@ -418,13 +421,16 @@ def compile_config(
)
if (rv := self._compile_cache.get(config)) is not None:
return rv
triton_code = self.to_triton_code(config)
triton_code = self.to_triton_code(
config, emit_repro_caller=self.settings.print_output_code
)
module = PyCodeCache.load(triton_code)
if allow_print:
log.info("Output code: \n%s", triton_code)
log.info("Output code written to: %s", module.__file__)
log.debug("Debug string: \n%s", LazyString(lambda: self._debug_str()))
if self.settings.print_output_code:
print(triton_code, file=sys.stderr)
module = PyCodeCache.load(triton_code)
rv = getattr(module, self.kernel.name)
self._compile_cache[config] = rv
return rv
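
A hedged usage sketch of the new flag; the kernel below is an assumed example written in the style of the expected output that follows, not code taken from this PR:

import torch
import helion
import helion.language as hl

@helion.kernel
def add(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    out = torch.empty_like(x)
    for tile in hl.tile(out.size(0)):
        out[tile] = x[tile] + y[tile]
    return out

x = torch.randn(16, device="cuda")
bound = add.bind((x, x))
# With emit_repro_caller=True the returned source ends in the
# call()/__main__ scaffold shown in test_misc.expected below.
print(bound.to_triton_code(helion.Config(block_sizes=[16]), emit_repro_caller=True))
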
169 changes: 169 additions & 0 deletions test/test_misc.expected
@@ -220,6 +220,175 @@ def fn(x: torch.Tensor, *, _launcher=_default_launcher):
_launcher(_helion_fn, (triton.cdiv(m, _BLOCK_SIZE_1),), x, out, out.stride(0), x.stride(0), x.stride(1), m, n, _BLOCK_SIZE_1, _BLOCK_SIZE_0, num_warps=4, num_stages=3)
return out

--- assertExpectedJournal(TestMisc.test_triton_repro)
from __future__ import annotations

import torch
import triton
import triton.language as tl
from helion.runtime import default_launcher as _default_launcher

@triton.jit
def _helion_add(x, out, x_size_0, out_stride_0, x_stride_0, _BLOCK_SIZE_0: tl.constexpr):
num_blocks_0 = tl.cdiv(x_size_0, _BLOCK_SIZE_0)
pid_0 = tl.program_id(0) % num_blocks_0
offset_0 = pid_0 * _BLOCK_SIZE_0
indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32)
mask_0 = indices_0 < x_size_0
load = tl.load(x + indices_0[:, None] * x_stride_0, mask_0[:, None], other=0)
load_1 = tl.load(x + indices_0[:, None] * x_stride_0, mask_0[:, None], other=0)
v_0 = load + load_1
tl.store(out + indices_0[:, None] * out_stride_0, v_0, mask_0[:, None])

def add(x: torch.Tensor, y: torch.Tensor, *, _launcher=_default_launcher):
"""
Add two tensors element-wise with broadcasting support.

Args:
x: First input tensor
y: Second input tensor

Returns:
A new tensor containing the element-wise sum of x and y
"""
x, y = torch.broadcast_tensors(x, y)
out = torch.empty(x.shape, dtype=torch.promote_types(x.dtype, y.dtype), device=x.device)
_BLOCK_SIZE_0 = 16
_launcher(_helion_add, (triton.cdiv(x.size(0), _BLOCK_SIZE_0) * 1,), x, out, x.size(0), out.stride(0), x.stride(0), _BLOCK_SIZE_0, num_warps=4, num_stages=3)
return out

def call():
from torch._dynamo.testing import rand_strided
x = rand_strided(size=torch.Size([16, 1]), stride=(1, 1), dtype=torch.float32, device=DEVICE)
y = rand_strided(size=torch.Size([16, 1]), stride=(1, 1), dtype=torch.float32, device=DEVICE)
add(x, y)
if __name__ == '__main__':
call()

--- assertExpectedJournal(TestMisc.test_triton_repro_add)
from __future__ import annotations

import torch
import triton
import triton.language as tl
from helion.runtime import default_launcher as _default_launcher

@triton.jit
def _helion_add(x, out, x_size_0, out_stride_0, x_stride_0, _BLOCK_SIZE_0: tl.constexpr):
num_blocks_0 = tl.cdiv(x_size_0, _BLOCK_SIZE_0)
pid_0 = tl.program_id(0) % num_blocks_0
offset_0 = pid_0 * _BLOCK_SIZE_0
indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32)
mask_0 = indices_0 < x_size_0
load = tl.load(x + indices_0[:, None] * x_stride_0, mask_0[:, None], other=0)
load_1 = tl.load(x + indices_0[:, None] * x_stride_0, mask_0[:, None], other=0)
v_0 = load + load_1
tl.store(out + indices_0[:, None] * out_stride_0, v_0, mask_0[:, None])

def add(x: torch.Tensor, y: torch.Tensor, *, _launcher=_default_launcher):
"""
Add two tensors element-wise with broadcasting support.

Args:
x: First input tensor
y: Second input tensor

Returns:
A new tensor containing the element-wise sum of x and y
"""
x, y = torch.broadcast_tensors(x, y)
out = torch.empty(x.shape, dtype=torch.promote_types(x.dtype, y.dtype), device=x.device)
_BLOCK_SIZE_0 = 16
_launcher(_helion_add, (triton.cdiv(x.size(0), _BLOCK_SIZE_0) * 1,), x, out, x.size(0), out.stride(0), x.stride(0), _BLOCK_SIZE_0, num_warps=4, num_stages=3)
return out

def call():
from torch._dynamo.testing import rand_strided
x = rand_strided(size=(16, 1), stride=(1, 1), dtype=torch.float32, device=DEVICE)
y = rand_strided(size=(16, 1), stride=(1, 1), dtype=torch.float32, device=DEVICE)
add(x, y)
if __name__ == '__main__':
call()

--- assertExpectedJournal(TestMisc.test_triton_repro_custom_static_shapes_False)
from __future__ import annotations

import torch
import triton
import triton.language as tl
from helion.runtime import default_launcher as _default_launcher

@triton.jit
def _helion_kernel(t, out, t_size_0, out_stride_0, t_stride_0, b, i, f, _BLOCK_SIZE_0: tl.constexpr):
num_blocks_0 = tl.cdiv(t_size_0, _BLOCK_SIZE_0)
pid_0 = tl.program_id(0) % num_blocks_0
offset_0 = pid_0 * _BLOCK_SIZE_0
indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32)
mask_0 = indices_0 < t_size_0
_and = b and True
if _and:
load = tl.load(t + indices_0[:, None] * t_stride_0, mask_0[:, None], other=0)
v_0 = i.to(tl.float32)
v_1 = load + v_0
v_2 = v_1 + f
tl.store(out + indices_0[:, None] * out_stride_0, v_2, mask_0[:, None])

def kernel(t: torch.Tensor, i: int, s: str, b: bool, f: float, *, _launcher=_default_launcher):
out = torch.empty_like(t)
_BLOCK_SIZE_0 = 16
_launcher(_helion_kernel, (triton.cdiv(t.size(0), _BLOCK_SIZE_0) * 1,), t, out, t.size(0), out.stride(0), t.stride(0), b, i, f, _BLOCK_SIZE_0, num_warps=4, num_stages=3)
return out

def call():
from torch._dynamo.testing import rand_strided
t = rand_strided(size=(16, 1), stride=(1, 1), dtype=torch.float32, device=DEVICE)
i = 8192
s = 'foo'
b = False
f = 1.1
kernel(t, i, s, b, f)
if __name__ == '__main__':
call()

--- assertExpectedJournal(TestMisc.test_triton_repro_custom_static_shapes_True)
from __future__ import annotations

import torch
import triton
import triton.language as tl
from helion.runtime import default_launcher as _default_launcher

@triton.jit
def _helion_kernel(t, out, b, i, f, _BLOCK_SIZE_0: tl.constexpr):
num_blocks_0 = tl.cdiv(16, _BLOCK_SIZE_0)
pid_0 = tl.program_id(0) % num_blocks_0
offset_0 = pid_0 * _BLOCK_SIZE_0
indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32)
_and = b and True
if _and:
load = tl.load(t + indices_0[:, None] * 1, None)
v_0 = i.to(tl.float32)
v_1 = load + v_0
v_2 = v_1 + f
tl.store(out + indices_0[:, None] * 1, v_2, None)

def kernel(t: torch.Tensor, i: int, s: str, b: bool, f: float, *, _launcher=_default_launcher):
out = torch.empty_like(t)
_BLOCK_SIZE_0 = 16
_launcher(_helion_kernel, (triton.cdiv(16, _BLOCK_SIZE_0) * 1,), t, out, b, i, f, _BLOCK_SIZE_0, num_warps=4, num_stages=3)
return out

def call():
from torch._dynamo.testing import rand_strided
t = rand_strided(size=(16, 1), stride=(1, 1), dtype=torch.float32, device=DEVICE)
i = 8192
s = 'foo'
b = False
f = 1.1
kernel(t, i, s, b, f)
if __name__ == '__main__':
call()

--- assertExpectedJournal(TestMisc.test_tuple_literal_subscript)
from __future__ import annotations
