helion/_compiler/ast_extension.py (21 additions, 0 deletions)

@@ -8,6 +8,8 @@
from typing import TYPE_CHECKING
from typing import TypeVar

import torch

from .. import exc
from .source_location import SourceLocation
from .source_location import current_location
@@ -82,10 +84,29 @@ def __repr__(self) -> str:

def update_type_info(self, type_info: TypeInfo) -> TypeInfo:
if self._type_info is not None and type_info != self._type_info:
prev_rank = self._tensor_rank(self._type_info)
new_rank = self._tensor_rank(type_info)
if (
prev_rank is not None
and new_rank is not None
and prev_rank != new_rank
):
self._type_info = type_info
return self._type_info
type_info = self._type_info.merge(type_info)
self._type_info = type_info
return self._type_info

@staticmethod
def _tensor_rank(type_info: "TypeInfo") -> int | None:
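        """Return the rank of the tensor carried by ``type_info``, if any."""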
for attr in ["fake_value", "tensor"]:
obj = getattr(type_info, attr, None)
if attr == "tensor" and obj is not None:
obj = getattr(obj, "fake_value", None)
if isinstance(obj, torch.Tensor):
return obj.dim()
return None
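
The rank check makes a rank-changing rebinding replace the previous type outright instead of merging with it. A minimal sketch of that rule, assuming a hypothetical FakeTypeInfo wrapper (not helion's real TypeInfo) around fake tensor values:

import torch

class FakeTypeInfo:
    """Hypothetical stand-in for helion's TypeInfo, wrapping a fake value."""
    def __init__(self, fake_value: torch.Tensor) -> None:
        self.fake_value = fake_value

def resolve(prev: FakeTypeInfo, new: FakeTypeInfo) -> FakeTypeInfo:
    # Mirrors update_type_info: a rank change replaces outright; equal
    # ranks would fall through to TypeInfo.merge() (elided in this sketch).
    if prev.fake_value.dim() != new.fake_value.dim():
        return new
    return prev

prev = FakeTypeInfo(torch.empty(4, 8))  # rank 2
new = FakeTypeInfo(torch.empty(4))      # rank 1
assert resolve(prev, new) is new        # rank changed -> replace, not merge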

def debug_annotations(self) -> list[str]:
result = []
if self._type_info:

helion/_compiler/compile_environment.py (136 additions, 1 deletion)

@@ -142,17 +142,47 @@ def allocate_reduction_dimension(self, size: torch.SymInt | int) -> BlockSizeInfo:
if rdim.reduction and rdim.size == size:
return rdim

# Check if size matches any tile dimension for symbolic equality.
# When building expressions that mix sizes derived from tiles (e.g. via
# slicing) with sizes coming directly from tile block vars, we want them
# to share the same SymInt variable whenever they are equal by
# construction. This preserves equality in the shape environment and
# avoids spurious "size mismatch" issues during fake-tensor broadcasting
# and arithmetic in type propagation.
if isinstance(size, torch.SymInt):
block_idx = self.get_block_id(size)
if block_idx is not None and not self.block_sizes[block_idx].reduction:
return self._clone_block_size_as_reduction(block_idx, size)

sym = size._sympy_()
for block_idx, block_info in enumerate(self.block_sizes):
if not block_info.reduction and sym == block_info.symbol():
return self._clone_block_size_as_reduction(block_idx, size)

# Allocate a new reduction dimension
return self._allocate_new_reduction(size)

def _clone_block_size_as_reduction(
self, block_idx: int, size: torch.SymInt | int
) -> BlockSizeInfo:
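        """Allocate a reduction dim whose ``var`` aliases block ``block_idx``'s var."""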
rdim = self._allocate_new_reduction(size)
rdim.var = self.block_sizes[block_idx].var
return rdim

def _allocate_new_reduction(self, size: torch.SymInt | int) -> BlockSizeInfo:
rdim_idx = self.allocate_block_size(
size,
reduction=True,
source=ReductionLoopBlockSizeSource(
                self._next_reduction_loop_index()
),
hint=next_power_of_2(self.size_hint(size)),
)
return self.block_sizes[rdim_idx]

def _next_reduction_loop_index(self) -> int:
return sum(int(info.reduction) for info in self.block_sizes)
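
A minimal sketch of the three-step reuse rule above, with plain dicts standing in for BlockSizeInfo (the real bookkeeping also tracks sources and size hints):

def allocate_reduction(block_sizes: list[dict], size: object) -> dict:
    for info in block_sizes:                     # 1) exact reduction reuse
        if info["reduction"] and info["size"] == size:
            return info
    for info in block_sizes:                     # 2) clone an equal tile dim,
        if not info["reduction"] and info["size"] == size:
            rdim = {"reduction": True, "size": size, "var": info["var"]}
            block_sizes.append(rdim)             #    sharing its block var
            return rdim
    rdim = {"reduction": True, "size": size, "var": object()}
    block_sizes.append(rdim)                     # 3) fresh reduction dim
    return rdim

blocks = [{"reduction": False, "size": 64, "var": "s0"}]
assert allocate_reduction(blocks, 64)["var"] == "s0"  # equality preserved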

def create_block_var(self, debug_name: str, hint: int = 64) -> torch.SymInt:
with self.shape_env.ignore_fresh_unbacked_symbols():
sym = self.shape_env.create_unbacked_symint()
@@ -203,6 +233,90 @@ def cached_create_unbacked_symint(
self._symint_cache[key] = result
return result


def register_tile_index_tensor_block_id(self, tensor: torch.Tensor, block_id: int) -> None:
"""Annotate ``tensor`` as originating from ``tile.index`` with ``block_id`` provenance."""
tensor._tile_index_block_id = block_id # type: ignore[attr-defined]

def get_tile_index_tensor_block_id(self, tensor: torch.Tensor) -> int | None:
"""Return the originating ``tile.index`` block id if present."""
return getattr(tensor, "_tile_index_block_id", None)
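
The provenance marker is an ordinary Python attribute on the (fake) tensor, so the round-trip can be illustrated with bare eager tensors as stand-ins:

import torch

t = torch.empty(16)
t._tile_index_block_id = 3  # what register_tile_index_tensor_block_id does
assert getattr(t, "_tile_index_block_id", None) == 3          # annotated
assert getattr(torch.empty(16), "_tile_index_block_id", None) is None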

def get_indexer_output_dims(
self,
indexer_tensor: torch.Tensor,
base_dim_size: int | torch.SymInt | None,
) -> list[int | torch.SymInt]:
"""Map a tensor indexer's shape to the output dimensions for advanced indexing."""

dims = list(indexer_tensor.size())
non_broadcast_dims = [d for d in dims if self.size_hint(d) != 1]

# Multi-dimensional indexer - return full shape
if len(non_broadcast_dims) > 1:
return dims

        # Try to find block_id from various sources, checking ``is not None``
        # explicitly so that a valid block id of 0 is not skipped.
        block_id = self.get_tile_index_tensor_block_id(indexer_tensor)
        if block_id is None and base_dim_size is not None:
            block_id = self.get_block_id(base_dim_size)
        if block_id is None and non_broadcast_dims:
            block_id = self.get_block_id(non_broadcast_dims[0])

if block_id is not None:
return [self.block_sizes[block_id].var]
return [non_broadcast_dims[0]] if non_broadcast_dims else [1]
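
A toy version of the mapping rule, assuming concrete int sizes so that size_hint is the identity and no block lookup is needed:

def indexer_output_dims(dims: list[int]) -> list[int]:
    non_broadcast = [d for d in dims if d != 1]
    if len(non_broadcast) > 1:
        return dims                      # multi-dim indexer: full shape
    return [non_broadcast[0]] if non_broadcast else [1]

assert indexer_output_dims([8, 4]) == [8, 4]  # genuinely 2-D indexer
assert indexer_output_dims([1, 8]) == [8]     # broadcast dims collapse
assert indexer_output_dims([1, 1]) == [1]     # pure broadcast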

def tensor_indexer_broadcast_shape(
self, tensors: typing.Sequence[torch.Tensor]
) -> list[int | torch.SymInt] | None:
"""Compute a shared broadcast shape for tensor indexers when needed."""

tensor_list = [t for t in tensors if isinstance(t, torch.Tensor)]
if not tensor_list:
return None

if all(self.get_tile_index_tensor_block_id(t) is not None for t in tensor_list):
return None

shapes = [list(t.size()) for t in tensor_list]
return compute_broadcast_shape_for_tensor_indexers(shapes, self)
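
When every indexer carries tile provenance, the helper returns None so the tile path keeps its canonical symbols; the gate is just an all() check, sketched here with bare tensors as stand-ins:

import torch

a, b = torch.empty(8), torch.empty(8)
a._tile_index_block_id = 0
b._tile_index_block_id = 1
assert all(
    getattr(t, "_tile_index_block_id", None) is not None for t in (a, b)
)  # all tile indexers -> tensor_indexer_broadcast_shape returns None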

def resolve_tile_index_shape(
self, input_tensor: torch.Tensor, output_shape: typing.Sequence[int | torch.SymInt]
) -> tuple[list[int | torch.SymInt], int | None]:
"""Resolve the symbolic shape for tensors derived from ``tile.index``.

        Returns a copy of ``output_shape`` in which the single non-broadcast
        dimension is replaced by the canonical block symbol, together with the
        block_id to register on the resulting tensor. If the tensor is not a
        tile indexer, or the shape introduces more than one non-broadcast
        dimension, the original shape and ``None`` are returned.
"""

block_id = self.get_tile_index_tensor_block_id(input_tensor)
if block_id is None:
return list(output_shape), None

resolved = list(output_shape)
non_broadcast = [i for i, s in enumerate(resolved) if self.size_hint(s) != 1]
if len(non_broadcast) <= 1:
if non_broadcast:
resolved[non_broadcast[0]] = self.block_sizes[block_id].var
return resolved, block_id
return resolved, None

def new_index_result(
self, tensor: torch.Tensor, output_shape: typing.Sequence[int | torch.SymInt]
) -> torch.Tensor:
"""Create a new tensor for indexing/view ops while preserving tile index provenance."""

resolved_shape, block_id = self.resolve_tile_index_shape(tensor, output_shape)
result = tensor.new_empty(resolved_shape)
if block_id is not None:
self.register_tile_index_tensor_block_id(result, block_id)
return result
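
Taken together, a view-like op on a tile indexer swaps the single non-broadcast dimension for the canonical block size and copies the provenance onto the result. A sketch, with a plain int standing in for block_sizes[block_id].var:

import torch

canonical_var = 64                 # stands in for block_sizes[0].var
src = torch.empty(32)
src._tile_index_block_id = 0       # pretend src came from tile.index

out_shape = [1, src.size(0)]       # e.g. the shape of tile.index[None, :]
resolved = [canonical_var if s != 1 else 1 for s in out_shape]
result = src.new_empty(resolved)   # what new_index_result does
result._tile_index_block_id = src._tile_index_block_id
assert result.shape == (1, 64)
assert result._tile_index_block_id == 0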

def to_fake(self, obj: object, origin: Origin) -> object:
if isinstance(obj, torch.Tensor):
return self._to_fake_tensor(obj, origin.to_source())
@@ -283,6 +397,10 @@ def _to_fake_tensor(self, tensor: torch.Tensor, source: Source) -> torch.Tensor:
self.fake_mode, tensor, shape_env=self.shape_env, source=source
)
self.input_sources[result] = source
if hasattr(tensor, "_tile_index_block_id"):
self.register_tile_index_tensor_block_id(
result, typing.cast(int, getattr(tensor, "_tile_index_block_id"))
)
if isinstance(source, LocalSource):
for i, s in enumerate(result.size()):
if isinstance(s, torch.SymInt) and isinstance(
@@ -535,3 +653,20 @@ def _to_sympy(x: int | torch.SymInt) -> sympy.Expr:

def _has_unbacked(expr: sympy.Expr) -> bool:
return any(n.name.startswith("u") for n in expr.free_symbols) # pyright: ignore[reportAttributeAccessIssue]


def compute_broadcast_shape_for_tensor_indexers(
shapes: list[list[int | torch.SymInt]],
env: "CompileEnvironment"
) -> list[int | torch.SymInt]:
"""Compute broadcast shape for multiple tensor indexers using right-aligned broadcasting."""
if not shapes:
return []

max_ndim = max(len(s) for s in shapes)
padded = [([1] * (max_ndim - len(s)) + s) for s in shapes]

return [
next((d for d in dims if env.size_hint(d) != 1), 1)
for dims in zip(*padded, strict=True)
]
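
A quick check of the right-aligned rule, pairing the helper above with a trivial stand-in for CompileEnvironment.size_hint (plain ints hint to themselves):

class _Env:
    def size_hint(self, d: int) -> int:
        return d

shapes = [[4, 1], [8]]  # right-aligned: [4, 1] vs [1, 8]
assert compute_broadcast_shape_for_tensor_indexers(shapes, _Env()) == [4, 8]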