Commit 50646f1

initial version
1 parent 5c71db4 commit 50646f1

8 files changed: +447 -78 lines changed


helion/_compiler/ast_extension.py

Lines changed: 23 additions & 0 deletions
@@ -8,6 +8,8 @@
 from typing import TYPE_CHECKING
 from typing import TypeVar
 
+import torch
+
 from .. import exc
 from .source_location import SourceLocation
 from .source_location import current_location
@@ -82,10 +84,31 @@ def __repr__(self) -> str:
 
     def update_type_info(self, type_info: TypeInfo) -> TypeInfo:
         if self._type_info is not None and type_info != self._type_info:
+            prev_rank = self._tensor_rank(self._type_info)
+            new_rank = self._tensor_rank(type_info)
+            if (
+                prev_rank is not None
+                and new_rank is not None
+                and prev_rank != new_rank
+            ):
+                self._type_info = type_info
+                return self._type_info
             type_info = self._type_info.merge(type_info)
         self._type_info = type_info
         return self._type_info
 
+    @staticmethod
+    def _tensor_rank(type_info: "TypeInfo") -> int | None:
+        fake_value = getattr(type_info, "fake_value", None)
+        if isinstance(fake_value, torch.Tensor):
+            return fake_value.dim()
+        tensor = getattr(type_info, "tensor", None)
+        if tensor is not None:
+            fake_value = getattr(tensor, "fake_value", None)
+            if isinstance(fake_value, torch.Tensor):
+                return fake_value.dim()
+        return None
+
     def debug_annotations(self) -> list[str]:
         result = []
         if self._type_info:
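The effect of the new branch: when two propagation passes disagree on a tensor's rank, the latest type wins outright instead of going through merge. A minimal standalone sketch of that decision, using a hypothetical FakeTypeInfo stand-in rather than Helion's real TypeInfo:

import dataclasses
import torch

@dataclasses.dataclass
class FakeTypeInfo:
    # Hypothetical stand-in: only carries the attribute the rank check inspects.
    fake_value: torch.Tensor

def tensor_rank(info: FakeTypeInfo) -> int | None:
    # Mirrors _tensor_rank: the rank is read off the fake tensor, if there is one.
    return info.fake_value.dim() if isinstance(info.fake_value, torch.Tensor) else None

prev = FakeTypeInfo(torch.empty(4, 8))      # rank 2 recorded on an earlier pass
new = FakeTypeInfo(torch.empty(4, 8, 16))   # rank 3 observed on a later pass

prev_rank, new_rank = tensor_rank(prev), tensor_rank(new)
if prev_rank is not None and new_rank is not None and prev_rank != new_rank:
    resolved = new    # rank changed: overwrite, skip the merge
else:
    resolved = prev   # same rank (or no tensor): the real code falls through to merge
print(resolved.fake_value.shape)  # torch.Size([4, 8, 16])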

helion/_compiler/compile_environment.py

Lines changed: 148 additions & 3 deletions
@@ -142,6 +142,30 @@ def allocate_reduction_dimension(self, size: torch.SymInt | int) -> BlockSizeInfo:
             if rdim.reduction and rdim.size == size:
                 return rdim
 
+        # Check if size matches any tile dimension for symbolic equality.
+        # When building expressions that mix sizes derived from tiles
+        # (e.g., via slicing) with sizes coming directly from tile block vars, we
+        # want them to share the same SymInt variable whenever they are equal by
+        # construction. This preserves equality in the shape environment and avoids
+        # spurious "size mismatch" issues during fake-tensor broadcasting and
+        # arithmetic in type propagation.
+        if isinstance(size, torch.SymInt):
+            size_str = str(size)
+            for block_info in self.block_sizes:
+                if not block_info.reduction and str(block_info.var) == size_str:
+                    # Create reduction dimension with the same var to preserve
+                    # symbolic equality and ensure all later users see identical
+                    # symbols (rather than equal-but-distinct SymInts).
+                    rdim_idx = self.allocate_block_size(
+                        size,
+                        reduction=True,
+                        source=ReductionLoopBlockSizeSource(
+                            reduction_loop=len([b for b in self.block_sizes if b.reduction])
+                        ),
+                    )
+                    self.block_sizes[rdim_idx].var = block_info.var
+                    return self.block_sizes[rdim_idx]
+
         # Allocate a new reduction dimension
         rdim_idx = self.allocate_block_size(
             size,
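What the var-sharing buys: a reduction size that reuses the tile's SymInt symbol stays provably equal to the tile size without recording any extra shape-environment facts. A sympy-only sketch of the difference between reusing a symbol and minting an equal-but-distinct one (hypothetical symbol names, not the real BlockSizeInfo plumbing):

import sympy

# The tile's block-size symbol, and a fresh symbol that is "equal by construction"
# but distinct as an expression.
tile_var = sympy.Symbol("block_size_0", positive=True, integer=True)
fresh_var = sympy.Symbol("block_size_1", positive=True, integer=True)

# Reusing the same symbol makes equality trivial for any downstream check:
print(sympy.simplify(tile_var - tile_var) == 0)   # True
# A distinct symbol cannot be proven equal without extra facts:
print(sympy.simplify(tile_var - fresh_var) == 0)  # False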
@@ -203,6 +227,91 @@ def cached_create_unbacked_symint(
         self._symint_cache[key] = result
         return result
 
+
+    def register_tile_index_tensor_block_id(self, tensor: torch.Tensor, block_id: int) -> None:
+        """Annotate ``tensor`` as originating from ``tile.index`` with ``block_id`` provenance."""
+        tensor._tile_index_block_id = block_id  # type: ignore[attr-defined]
+
+    def get_tile_index_tensor_block_id(self, tensor: torch.Tensor) -> int | None:
+        """Return the originating ``tile.index`` block id if present."""
+        return getattr(tensor, "_tile_index_block_id", None)
+
+    def get_indexer_output_dims(
+        self,
+        indexer_tensor: torch.Tensor,
+        base_dim_size: int | torch.SymInt | None,
+    ) -> list[int | torch.SymInt]:
+        """Map a tensor indexer's shape to the output dimensions for advanced indexing."""
+
+        dims = list(indexer_tensor.size())
+        non_broadcast_dims = [d for d in dims if self.size_hint(d) != 1]
+
+        # Multi-dimensional indexer - return full shape
+        if len(non_broadcast_dims) > 1:
+            return dims
+
+        block_id = self.get_tile_index_tensor_block_id(indexer_tensor)
+        if block_id is None and base_dim_size is not None:
+            block_id = self.get_block_id(base_dim_size)
+        if block_id is None and non_broadcast_dims:
+            block_id = self.get_block_id(non_broadcast_dims[0])
+
+        if block_id is not None:
+            return [self.block_sizes[block_id].var]
+        if non_broadcast_dims:
+            return [non_broadcast_dims[0]]
+        return [1]
+
+    def tensor_indexer_broadcast_shape(
+        self, tensors: typing.Sequence[torch.Tensor]
+    ) -> list[int | torch.SymInt] | None:
+        """Compute a shared broadcast shape for tensor indexers when needed."""
+
+        tensor_list = [t for t in tensors if isinstance(t, torch.Tensor)]
+        if not tensor_list:
+            return None
+
+        if all(self.get_tile_index_tensor_block_id(t) is not None for t in tensor_list):
+            return None
+
+        shapes = [list(t.size()) for t in tensor_list]
+        return compute_broadcast_shape_for_tensor_indexers(shapes, self)
+
+    def resolve_tile_index_shape(
+        self, input_tensor: torch.Tensor, output_shape: typing.Sequence[int | torch.SymInt]
+    ) -> tuple[list[int | torch.SymInt], int | None]:
+        """Resolve the symbolic shape for tensors derived from ``tile.index``.
+
+        Returns a copy of ``output_shape`` where the single non-broadcast
+        dimension is replaced with the canonical block-symbol and the associated
+        block_id to register on the new tensor. If the tensor is not a tile
+        indexer or it introduces more than one non-broadcast dimension, the
+        original shape and ``None`` are returned.
+        """
+
+        block_id = self.get_tile_index_tensor_block_id(input_tensor)
+        if block_id is None:
+            return list(output_shape), None
+
+        resolved = list(output_shape)
+        non_broadcast = [i for i, s in enumerate(resolved) if self.size_hint(s) != 1]
+        if len(non_broadcast) <= 1:
+            if non_broadcast:
+                resolved[non_broadcast[0]] = self.block_sizes[block_id].var
+            return resolved, block_id
+        return resolved, None
+
+    def new_index_result(
+        self, tensor: torch.Tensor, output_shape: typing.Sequence[int | torch.SymInt]
+    ) -> torch.Tensor:
+        """Create a new tensor for indexing/view ops while preserving tile index provenance."""
+
+        resolved_shape, block_id = self.resolve_tile_index_shape(tensor, output_shape)
+        result = tensor.new_empty(resolved_shape)
+        if block_id is not None:
+            self.register_tile_index_tensor_block_id(result, block_id)
+        return result
+
     def to_fake(self, obj: object, origin: Origin) -> object:
         if isinstance(obj, torch.Tensor):
            return self._to_fake_tensor(obj, origin.to_source())
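The provenance here is just a private attribute on the tensor object, and view or indexing ops do not copy Python attributes automatically; that is why new_index_result (and _to_fake_tensor in the next hunk) re-registers it on each result. A minimal standalone sketch of the tagging mechanism with plain tensors and hypothetical helper names, not the CompileEnvironment API:

import torch

def tag_block_id(t: torch.Tensor, block_id: int) -> torch.Tensor:
    # Mirrors register_tile_index_tensor_block_id: stash provenance on the tensor object.
    t._tile_index_block_id = block_id  # type: ignore[attr-defined]
    return t

def block_id_of(t: torch.Tensor) -> int | None:
    # Mirrors get_tile_index_tensor_block_id: read it back if present.
    return getattr(t, "_tile_index_block_id", None)

idx = tag_block_id(torch.arange(16), block_id=0)  # stands in for a tile.index tensor
view = idx.reshape(1, 16)                         # the view is a new object: attribute is gone
print(block_id_of(view))                          # None
result = tag_block_id(view, 0)                    # so each derived result must be re-tagged
print(block_id_of(result))                        # 0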
@@ -283,6 +392,10 @@ def _to_fake_tensor(self, tensor: torch.Tensor, source: Source) -> torch.Tensor:
             self.fake_mode, tensor, shape_env=self.shape_env, source=source
         )
         self.input_sources[result] = source
+        if hasattr(tensor, "_tile_index_block_id"):
+            self.register_tile_index_tensor_block_id(
+                result, typing.cast(int, getattr(tensor, "_tile_index_block_id"))
+            )
         if isinstance(source, LocalSource):
             for i, s in enumerate(result.size()):
                 if isinstance(s, torch.SymInt) and isinstance(
@@ -357,9 +470,9 @@ def current() -> CompileEnvironment:
     @staticmethod
     def has_current() -> bool:
         try:
-            CompileEnvironment.current()
-            return True
-        except NoCurrentEnvironment:
+            CompileEnvironment.current()
+            return True
+        except NoCurrentEnvironment:
             return False
 
     def get_block_id(self, size: int | torch.SymInt | sympy.Expr) -> int | None:
@@ -535,3 +648,35 @@ def _to_sympy(x: int | torch.SymInt) -> sympy.Expr:
 
 def _has_unbacked(expr: sympy.Expr) -> bool:
     return any(n.name.startswith("u") for n in expr.free_symbols)  # pyright: ignore[reportAttributeAccessIssue]
+
+
+def compute_broadcast_shape_for_tensor_indexers(
+    shapes: list[list[int | torch.SymInt]],
+    env: "CompileEnvironment"
+) -> list[int | torch.SymInt]:
+    """
+    Compute broadcast shape for multiple tensor indexers using right-aligned broadcasting.
+
+    Args:
+        shapes: List of shapes from each tensor indexer
+        env: CompileEnvironment for size_hint and known_equal checks
+
+    Returns:
+        Broadcast shape as list of dimensions
+    """
+    if not shapes:
+        return []
+
+    max_ndim = max(len(s) for s in shapes)
+    padded = [([1] * (max_ndim - len(s)) + s) for s in shapes]
+    broadcast_shape: list[int | torch.SymInt] = []
+
+    for dims_at_pos in zip(*padded, strict=True):
+        chosen: int | torch.SymInt | None = None
+        for d in dims_at_pos:
+            if env.size_hint(d) != 1:
+                if chosen is None or env.known_equal(chosen, d):
+                    chosen = d
+        broadcast_shape.append(chosen if chosen is not None else 1)
+
+    return broadcast_shape
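A plain-int check of the right-aligned rule this function implements; size_hint and known_equal are replaced by literal comparisons, so this is only a sketch of the shape logic, not the SymInt-aware version:

def broadcast_indexer_shapes(shapes: list[list[int]]) -> list[int]:
    # Hypothetical int-only analogue of compute_broadcast_shape_for_tensor_indexers.
    if not shapes:
        return []
    max_ndim = max(len(s) for s in shapes)
    padded = [[1] * (max_ndim - len(s)) + s for s in shapes]  # right-align by left-padding with 1s
    out: list[int] = []
    for dims in zip(*padded):
        non_ones = [d for d in dims if d != 1]  # plays the role of size_hint(d) != 1
        out.append(non_ones[0] if non_ones else 1)
    return out

print(broadcast_indexer_shapes([[8], [4, 1]]))     # [4, 8]
print(broadcast_indexer_shapes([[1, 8], [4, 8]]))  # [4, 8]

The real function keeps the first non-broadcast candidate at each position and only replaces it with a later candidate when known_equal can prove the two equal, so mismatched symbolic sizes do not silently win.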
