Skip to content

Commit d3a1a6d

Browse files
committed
fix
1 parent e695ab4 commit d3a1a6d

File tree

9 files changed

+302
-48
lines changed

9 files changed

+302
-48
lines changed

helion/_compat.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import contextlib
44
import functools
55
import re
6+
from typing import TYPE_CHECKING
67
from typing import Any
78
from typing import Callable
89
from typing import cast
@@ -16,6 +17,9 @@
1617
import triton.language as tl
1718
import triton.runtime.jit as triton_jit
1819

20+
if TYPE_CHECKING:
21+
from collections.abc import Generator
22+
1923
NativeSpecializeImpl = Callable[
2024
[type[BaseBackend], object, bool, bool, bool], tuple[object, ...]
2125
]
@@ -306,3 +310,26 @@ def supports_amd_cdna_tunables() -> bool:
306310
return match is not None and int(match.group(1), 16) >= 0x908
307311
except Exception:
308312
return False
313+
314+
315+
@contextlib.contextmanager
316+
def patch_fake_tensor_ctor() -> Generator[None, None, None]:
317+
"""Context manager that patches FakeTensor.__new__ for the following purpose:
318+
- Add _tile_index_block_id attribute with None as initial value.
319+
This ensures all FakeTensors have a _tile_index_block_id attribute,
320+
which is used to track which block a tile.index tensor originated from.
321+
"""
322+
from torch._subclasses.fake_tensor import FakeTensor
323+
324+
original_new = FakeTensor.__new__
325+
326+
def patched_new(*args: Any, **kwargs: Any) -> FakeTensor: # noqa: ANN401
327+
result = original_new(*args, **kwargs)
328+
result._tile_index_block_id = None # type: ignore[attr-defined]
329+
return result
330+
331+
FakeTensor.__new__ = staticmethod(patched_new) # type: ignore[method-assign]
332+
try:
333+
yield
334+
finally:
335+
FakeTensor.__new__ = original_new # type: ignore[method-assign]

helion/_compiler/compile_environment.py

Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
from torch.fx.experimental.symbolic_shapes import ShapeEnv
2121

2222
from .. import exc
23+
from .._compat import patch_fake_tensor_ctor
2324
from ..language.constexpr import ConstExpr
2425
from .loop_dependency_checker import LoopDependencyChecker
2526
from .source_location import SourceLocation
@@ -272,6 +273,67 @@ def cached_create_unbacked_symint(
272273
self._symint_cache[key] = result
273274
return result
274275

276+
def get_tile_index_tensor_block_id(self, tensor: torch.Tensor) -> int | None:
277+
"""Return the originating ``tile.index`` block id if present."""
278+
return tensor._tile_index_block_id # type: ignore[attr-defined]
279+
280+
def should_broadcast_tensor_indexers(
281+
self, tensors: typing.Sequence[torch.Tensor]
282+
) -> bool:
283+
"""Check whether tensor indexers need broadcasting."""
284+
if not tensors:
285+
return False
286+
# tile.index tensors don't need broadcasting
287+
if all(self.get_tile_index_tensor_block_id(t) for t in tensors):
288+
return False
289+
# Single 1D tensor doesn't need broadcast handling
290+
return not (len(tensors) == 1 and tensors[0].ndim == 1)
291+
292+
def tensor_indexer_broadcast_shape(
293+
self, tensors: typing.Sequence[torch.Tensor]
294+
) -> list[int | torch.SymInt]:
295+
"""Compute broadcast shape for tensor indexers."""
296+
shapes = [list(t.size()) for t in tensors]
297+
if all(len(s) == 1 for s in shapes) and len(shapes) > 1: # Cartesian
298+
return [s[0] for s in shapes]
299+
max_ndim = max(len(s) for s in shapes)
300+
padded = [([1] * (max_ndim - len(s)) + s) for s in shapes]
301+
return [
302+
next((d for d in dims if self.size_hint(d) != 1), 1)
303+
for dims in zip(*padded, strict=True)
304+
]
305+
306+
def tensor_indexer_dims(
307+
self, indexer_tensor: torch.Tensor
308+
) -> list[int | torch.SymInt]:
309+
"""Return dims contributed by a tensor indexer (non-broadcast case)."""
310+
non_trivial = [d for d in indexer_tensor.size() if self.size_hint(d) != 1]
311+
bid = self.get_tile_index_tensor_block_id(indexer_tensor) or (
312+
self.get_block_id(non_trivial[0]) if non_trivial else None
313+
)
314+
if bid:
315+
return [self.block_sizes[bid].var]
316+
return non_trivial or [1] # type: ignore[return-value]
317+
318+
def new_index_result(
319+
self, tensor: torch.Tensor, output_shape: typing.Sequence[int | torch.SymInt]
320+
) -> torch.Tensor:
321+
"""Create tensor for indexing ops, preserving tile index provenance."""
322+
shape = list(output_shape)
323+
non_trivial = [i for i, s in enumerate(shape) if self.size_hint(s) != 1]
324+
if len(non_trivial) > 1:
325+
return tensor.new_empty(shape)
326+
bid = self.get_tile_index_tensor_block_id(tensor)
327+
if non_trivial:
328+
if bid is None:
329+
bid = self.get_block_id(shape[non_trivial[0]])
330+
if bid:
331+
shape[non_trivial[0]] = self.block_sizes[bid].var
332+
result = tensor.new_empty(shape)
333+
if bid:
334+
result._tile_index_block_id = bid # type: ignore[attr-defined]
335+
return result
336+
275337
def to_fake(self, obj: object, origin: Origin) -> object:
276338
if obj is None:
277339
return None
@@ -418,6 +480,8 @@ def sympy_debug(self, expr: sympy.Expr) -> str:
418480

419481
def __enter__(self) -> Self:
420482
assert getattr(tls, "env", None) is None, "CompileEnvironment already active"
483+
self.fake_tensor_ctor_patch_ctx = patch_fake_tensor_ctor()
484+
self.fake_tensor_ctor_patch_ctx.__enter__()
421485
self.fake_mode.__enter__()
422486
tls.env = self
423487
self.loop_dependency_checker = LoopDependencyChecker()
@@ -431,6 +495,7 @@ def __exit__(
431495
) -> None:
432496
tls.env = None
433497
self.fake_mode.__exit__(exc_type, exc_value, traceback)
498+
self.fake_tensor_ctor_patch_ctx.__exit__(exc_type, exc_value, traceback)
434499

435500
@staticmethod
436501
def current() -> CompileEnvironment:

helion/_compiler/indexing_strategy.py

Lines changed: 129 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -575,6 +575,10 @@ def compute_shape(
575575
input_size = collections.deque(tensor.size())
576576
output_size = []
577577
env = CompileEnvironment.current()
578+
579+
tensor_indexers = [k for k in index if isinstance(k, torch.Tensor)]
580+
should_broadcast = env.should_broadcast_tensor_indexers(tensor_indexers)
581+
578582
k_index = 0
579583
for k in index:
580584
if k is None:
@@ -617,11 +621,14 @@ def compute_shape(
617621
else:
618622
output_size.append(1)
619623
k_index += 1
620-
elif isinstance(k, torch.Tensor) and (
621-
k.ndim == 1 or (len(index) == 1 and tensor.ndim == 1)
622-
):
623-
input_size.popleft()
624-
output_size.extend(k.size())
624+
elif isinstance(k, torch.Tensor):
625+
base_dim = input_size.popleft()
626+
if not should_broadcast:
627+
output_size.extend(env.tensor_indexer_dims(k))
628+
elif k is tensor_indexers[0]:
629+
output_size.extend(
630+
env.tensor_indexer_broadcast_shape(tensor_indexers)
631+
)
625632
k_index += 1
626633
else:
627634
raise exc.InvalidIndexingType(k)
@@ -667,13 +674,99 @@ def create(
667674
output_size = SubscriptIndexing.compute_shape(fake_value, index, state)
668675
env = CompileEnvironment.current()
669676
dtype = env.triton_index_type()
677+
tensor_indexers = [k for k in index if isinstance(k, torch.Tensor)]
678+
should_broadcast = env.should_broadcast_tensor_indexers(tensor_indexers)
679+
broadcast_dims = 0
680+
if should_broadcast:
681+
broadcast_dims = len(env.tensor_indexer_broadcast_shape(tensor_indexers))
682+
is_cartesian = (
683+
broadcast_dims >= 2
684+
and len(tensor_indexers) == broadcast_dims
685+
and all(
686+
t.ndim == 1
687+
or sum(1 for d in t.size() if env.size_hint(d) != 1) <= 1
688+
for t in tensor_indexers
689+
)
690+
)
670691
if dtype == "tl.int32" and SubscriptIndexing._needs_int64(fake_value):
671692
raise exc.IndexOffsetOutOfRangeForInt32(env.index_dtype)
672693

673694
def _is_size_one(size: int | torch.SymInt) -> bool:
674695
return env.known_equal(size, 1)
675696

676697
k_index = 0
698+
699+
def tensor_index_source_and_mask(
700+
index_elem: torch.Tensor, index_var: str, pos: int
701+
) -> tuple[str, int | None]:
702+
tile_id = env.get_tile_index_tensor_block_id(index_elem)
703+
src = state.codegen.index_var(tile_id) if tile_id else index_var
704+
mask_id = tile_id or (
705+
env.get_block_id(output_size[pos]) if pos < len(output_size) else None
706+
)
707+
return src, mask_id
708+
709+
def handle_broadcast_tensor(
710+
position: int,
711+
index_elem: torch.Tensor,
712+
index_var: str,
713+
cur_output_idx: int,
714+
) -> tuple[str, dict[str, None]]:
715+
"""Handle tensor index with broadcast shape (cartesian or general)."""
716+
assert broadcast_dims > 0
717+
tensor_idx = next(
718+
i for i, t in enumerate(tensor_indexers) if t is index_elem
719+
)
720+
first_tensor_out_idx = (
721+
cur_output_idx if tensor_idx == 0 else cur_output_idx - broadcast_dims
722+
)
723+
non_trivial_output_positions: list[int] = []
724+
if is_cartesian:
725+
pos = first_tensor_out_idx + tensor_idx
726+
single_output_dim = True
727+
else:
728+
# Find position(s) where this tensor contributes non-trivial dims
729+
offset = max(0, broadcast_dims - index_elem.ndim)
730+
non_trivial_output_positions = [
731+
first_tensor_out_idx + offset + i
732+
for i in range(index_elem.ndim)
733+
if env.size_hint(index_elem.size(i)) != 1
734+
]
735+
pos = non_trivial_output_positions[0]
736+
single_output_dim = len(non_trivial_output_positions) <= 1
737+
738+
new_masks: dict[str, None] = {}
739+
if single_output_dim:
740+
src, _ = tensor_index_source_and_mask(index_elem, index_var, pos)
741+
expand = (
742+
tile_strategy.expand_str(output_size, pos)
743+
if index_elem.ndim == 1
744+
else ""
745+
)
746+
idx_val = f"({src}){expand}"
747+
else:
748+
# Multi-dim tensor with multiple non-trivial dims
749+
idx_val = f"({index_var})"
750+
if tensor_idx == 0:
751+
for p in non_trivial_output_positions:
752+
if (
753+
p < len(output_size)
754+
and (bid := env.get_block_id(output_size[p]))
755+
and (mv := state.codegen.mask_var(bid))
756+
and not _is_size_one(fake_value.size(len(index_values)))
757+
):
758+
new_masks.setdefault(
759+
f"({mv}){tile_strategy.expand_str(output_size, p)}"
760+
)
761+
# Padded iota mask
762+
if (
763+
orig_len := _get_padded_iota_original_length(state, position)
764+
) is not None:
765+
new_masks.setdefault(
766+
f"(({index_var} < {orig_len}){tile_strategy.expand_str(output_size, first_tensor_out_idx + tensor_idx)})"
767+
)
768+
return idx_val, new_masks
769+
677770
for n, k in enumerate(index):
678771
if k is None:
679772
output_idx += 1
@@ -752,40 +845,41 @@ def _is_size_one(size: int | torch.SymInt) -> bool:
752845
index_values.append(f"tl.zeros([1], {dtype}){expand}")
753846
output_idx += 1
754847
k_index += 1
755-
elif isinstance(k, torch.Tensor) and k.ndim == 1:
756-
expand = tile_strategy.expand_str(output_size, output_idx)
848+
elif isinstance(k, torch.Tensor):
757849
ast_index = state.ast_args[1]
758850
assert isinstance(ast_index, (list, tuple))
759-
assert len(ast_index) == len(index)
760851
index_var = state.codegen.lift(ast_index[n], prefix="index").id
761-
index_values.append(f"({index_var}){expand}")
762-
if (block_idx := env.get_block_id(output_size[output_idx])) is not None:
763-
if mask := state.codegen.mask_var(block_idx):
764-
mask_values.setdefault(f"({mask}){expand}")
765-
# Check if this index comes from a padded hl.arange and generate mask
766-
if (
767-
original_length := _get_padded_iota_original_length(state, n)
768-
) is not None:
769-
mask_values.setdefault(f"({index_var} < {original_length}){expand}")
770-
output_idx += 1
771-
k_index += 1
772-
elif (
773-
isinstance(k, torch.Tensor) and len(index) == 1 and fake_value.ndim == 1
774-
):
775-
# TODO(jansel): combine this case with the above
776-
ast_index = state.ast_args[1]
777-
assert isinstance(ast_index, (list, tuple))
778-
assert len(ast_index) == 1
779-
index_var = state.codegen.lift(ast_index[0], prefix="index").id
780-
index_values.append(index_var)
781-
output_idx += k.ndim
782-
for n, s in enumerate(output_size):
783-
if (block_idx := env.get_block_id(s)) is not None and (
784-
mask := state.codegen.mask_var(block_idx)
852+
853+
# Use broadcast handling for: multiple tensors, or single tensor with ndim > 1
854+
if should_broadcast:
855+
idx_val, new_masks = handle_broadcast_tensor(
856+
n, k, index_var, output_idx
857+
)
858+
index_values.append(idx_val)
859+
mask_values.update(new_masks)
860+
if k is tensor_indexers[0]:
861+
output_idx += broadcast_dims
862+
k_index += 1
863+
continue
864+
865+
index_source, mask_block_id = tensor_index_source_and_mask(
866+
k, index_var, output_idx
867+
)
868+
869+
expand = (
870+
tile_strategy.expand_str(output_size, output_idx)
871+
if k.ndim < len(output_size)
872+
else ""
873+
)
874+
index_values.append(f"({index_source}){expand}")
875+
if mask_block_id is not None:
876+
mask_var = state.codegen.mask_var(mask_block_id)
877+
if mask_var and not _is_size_one(
878+
fake_value.size(len(index_values) - 1)
785879
):
786-
mask_values.setdefault(
787-
f"({mask}){tile_strategy.expand_str(output_size, n)}"
788-
)
880+
mask_values.setdefault(f"({mask_var}){expand}")
881+
882+
output_idx += k.ndim
789883
k_index += 1
790884
else:
791885
raise exc.InvalidIndexingType(type(k))

0 commit comments

Comments
 (0)