Commit b317f03: wip

1 parent 0361edf

File tree: 3 files changed, +99 −0

helion/_compiler/indexing_strategy.py

Lines changed: 24 additions & 0 deletions

@@ -9,6 +9,7 @@
 import sympy
 import torch
 from torch._inductor.utils import triton_type
+from torch._prims_common import compute_required_storage_length
 
 from .. import exc
 from .._compat import get_tensor_descriptor_fn_name
@@ -519,6 +520,27 @@ def compute_shape(
         assert len(input_size) == 0, "invalid subscript"
         return output_size
 
+    @staticmethod
+    def _needs_int64(fake_value: torch.Tensor) -> bool:
+        storage_offset = fake_value.storage_offset()
+        try:
+            required = compute_required_storage_length(
+                fake_value.shape,
+                fake_value.stride(),
+                storage_offset,
+            )
+        except Exception:  # pragma: no cover - defensive fallback
+            return False
+
+        if not isinstance(required, int):
+            return False
+
+        if abs(storage_offset) > torch.iinfo(torch.int32).max:
+            return True
+
+        max_offset = required - 1
+        return max_offset > torch.iinfo(torch.int32).max
+
     @staticmethod
     def create(
         state: CodegenState,
@@ -533,6 +555,8 @@ def create(
         output_size = SubscriptIndexing.compute_shape(fake_value, index)
         env = CompileEnvironment.current()
         dtype = env.triton_index_type()
+        if dtype == "tl.int32" and SubscriptIndexing._needs_int64(fake_value):
+            raise exc.IndexOffsetOutOfRangeForInt32(env.settings.index_dtype)
 
         def _is_size_one(size: int | torch.SymInt) -> bool:
             return env.known_equal(size, 1)
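For intuition, here is a standalone sketch of the arithmetic _needs_int64 guards against, using a meta tensor so nothing is actually allocated (the shape matches the large case in the test below; compute_required_storage_length is the same PyTorch helper the diff imports):

import torch
from torch._prims_common import compute_required_storage_length

# A meta tensor carries shape/stride/storage_offset without allocating memory.
t = torch.empty((51200, 51200), dtype=torch.bfloat16, device="meta")

required = compute_required_storage_length(
    t.shape, t.stride(), t.storage_offset()
)
max_offset = required - 1  # largest flat element offset touched

print(required)                                    # 2621440000
print(torch.iinfo(torch.int32).max)                # 2147483647
print(max_offset > torch.iinfo(torch.int32).max)   # True -> int64 needed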

helion/exc.py

Lines changed: 7 additions & 0 deletions

@@ -107,6 +107,13 @@ class InvalidIndexingType(BaseError):
     message = "Expected tile/int/None/tensor/etc in tensor[...], got {0!s}."
 
 
+class IndexOffsetOutOfRangeForInt32(BaseError):
+    message = (
+        "Tensor indexing offsets exceed the int32 range, but the kernel index_dtype is {0}. "
+        "Use @helion.kernel(index_dtype=torch.int64) to enable larger offsets."
+    )
+
+
 class DataDependentOutputShapeNotSupported(BaseError):
     message = (
         "{op_desc} is not supported in Helion device loops because it produces "

test/test_indexing.py

Lines changed: 68 additions & 0 deletions

@@ -11,6 +11,7 @@
 from helion._testing import RefEagerTestBase
 from helion._testing import TestCase
 from helion._testing import code_and_output
+from helion._testing import skipIfLowVRAM
 from helion._testing import skipIfNormalMode
 from helion._testing import skipIfRefEager
 from helion._testing import skipIfRocm
@@ -241,6 +242,73 @@ def test_block_size_access(x: torch.Tensor) -> torch.Tensor:
         expected = torch.full_like(x, 1, dtype=torch.int32)
         torch.testing.assert_close(result, expected)
 
+    @skipIfLowVRAM("Test allocates ~15GB across multiple CUDA tensors")
+    def test_int32_offset_out_of_range_error(self):
+        repro_config = helion.Config(
+            block_sizes=[32, 32],
+            flatten_loops=[False],
+            indexing="pointer",
+            l2_groupings=[1],
+            loop_orders=[[0, 1]],
+            num_stages=3,
+            num_warps=4,
+            pid_type="flat",
+            range_flattens=[None],
+            range_multi_buffers=[None],
+            range_num_stages=[],
+            range_unroll_factors=[0],
+            range_warp_specializes=[],
+        )
+
+        def make_kernel(*, index_dtype: torch.dtype):
+            kwargs = dict(config=repro_config, static_shapes=False)
+            kwargs["index_dtype"] = index_dtype
+            decorator = helion.kernel(**kwargs)
+
+            @decorator
+            def repro_bf16_add(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
+                x, y = torch.broadcast_tensors(x, y)
+                out = torch.empty(
+                    x.shape,
+                    dtype=torch.promote_types(x.dtype, y.dtype),
+                    device=x.device,
+                )
+                for tile in hl.tile(out.size()):
+                    out[tile] = x[tile] + y[tile]
+                return out
+
+            return repro_bf16_add
+
+        def run_case(shape, *, index_dtype, expect_int64=False, expect_error=False):
+            kernel = make_kernel(index_dtype=index_dtype)
+            x = torch.randn(*shape, device=DEVICE, dtype=torch.bfloat16)
+            y = torch.randn(*shape, device=DEVICE, dtype=torch.bfloat16)
+            torch.cuda.synchronize()
+            if expect_error:
+                with self.assertRaisesRegex(
+                    helion.exc.IndexOffsetOutOfRangeForInt32,
+                    f"index_dtype is {index_dtype}",
+                ):
+                    code_and_output(kernel, (x, y))
+                torch.cuda.synchronize()
+                return
+
+            code, out = code_and_output(kernel, (x, y))
+            torch.cuda.synchronize()
+            checker = self.assertIn if expect_int64 else self.assertNotIn
+            checker("tl.int64", code)
+            torch.cuda.synchronize()
+            ref_out = torch.add(x, y)
+            torch.cuda.synchronize()
+            torch.testing.assert_close(out, ref_out, rtol=1e-2, atol=1e-2)
+
+        small_shape = (128, 128)
+        large_shape = (51200, 51200)
+
+        run_case(small_shape, index_dtype=torch.int32)
+        run_case(large_shape, index_dtype=torch.int32, expect_error=True)
+        run_case(large_shape, index_dtype=torch.int64, expect_int64=True)
+
     def test_assign_int(self):
         @helion.kernel
         def fn(x: torch.Tensor) -> torch.Tensor:
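Sanity arithmetic behind the chosen shapes and the skip message (plain Python, no project APIs involved):

elems = 51200 * 51200          # 2_621_440_000 elements in the large case
int32_max = 2**31 - 1          # 2_147_483_647
assert elems > int32_max       # int32 offsets must overflow -> error path

bytes_per_tensor = elems * 2   # bfloat16 is 2 bytes -> ~5.2 GB per tensor
total_gb = 3 * bytes_per_tensor / 1e9  # x, y, and out -> ~15.7 GB, matching @skipIfLowVRAM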
