diff --git a/helion/_compiler/ast_extension.py b/helion/_compiler/ast_extension.py index 2aca45765..171467a32 100644 --- a/helion/_compiler/ast_extension.py +++ b/helion/_compiler/ast_extension.py @@ -8,6 +8,8 @@ from typing import TYPE_CHECKING from typing import TypeVar +import torch + from .. import exc from .source_location import SourceLocation from .source_location import current_location @@ -82,10 +84,29 @@ def __repr__(self) -> str: def update_type_info(self, type_info: TypeInfo) -> TypeInfo: if self._type_info is not None and type_info != self._type_info: + prev_rank = self._tensor_rank(self._type_info) + new_rank = self._tensor_rank(type_info) + if ( + prev_rank is not None + and new_rank is not None + and prev_rank != new_rank + ): + self._type_info = type_info + return self._type_info type_info = self._type_info.merge(type_info) self._type_info = type_info return self._type_info + @staticmethod + def _tensor_rank(type_info: "TypeInfo") -> int | None: + for attr in ["fake_value", "tensor"]: + obj = getattr(type_info, attr, None) + if attr == "tensor" and obj is not None: + obj = getattr(obj, "fake_value", None) + if isinstance(obj, torch.Tensor): + return obj.dim() + return None + def debug_annotations(self) -> list[str]: result = [] if self._type_info: diff --git a/helion/_compiler/compile_environment.py b/helion/_compiler/compile_environment.py index 3e05ad16c..beb4764b9 100644 --- a/helion/_compiler/compile_environment.py +++ b/helion/_compiler/compile_environment.py @@ -142,17 +142,47 @@ def allocate_reduction_dimension(self, size: torch.SymInt | int) -> BlockSizeInf if rdim.reduction and rdim.size == size: return rdim + # Check if size matches any tile dimension for symbolic equality. + # When building expressions that mix sizes derived from tiles (e.g. via + # slicing) with sizes coming directly from tile block vars, we want them + # to share the same SymInt variable whenever they are equal by + # construction. This preserves equality in the shape environment and + # avoids spurious "size mismatch" issues during fake-tensor broadcasting + # and arithmetic in type propagation. 
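+        # Concretely (illustrative): if ``size`` is the SymInt of an existing
+        # non-reduction block, the direct ``get_block_id`` check below reuses
+        # that block's var; otherwise a sympy-equality scan over
+        # ``self.block_sizes`` catches sizes that are equal by construction
+        # but carried by a fresh symbolic expression.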
+ if isinstance(size, torch.SymInt): + block_idx = self.get_block_id(size) + if block_idx is not None and not self.block_sizes[block_idx].reduction: + return self._clone_block_size_as_reduction(block_idx, size) + + sym = size._sympy_() + for block_idx, block_info in enumerate(self.block_sizes): + if not block_info.reduction and sym == block_info.symbol(): + return self._clone_block_size_as_reduction(block_idx, size) + # Allocate a new reduction dimension + return self._allocate_new_reduction(size) + + def _clone_block_size_as_reduction( + self, block_idx: int, size: torch.SymInt | int + ) -> BlockSizeInfo: + rdim = self._allocate_new_reduction(size) + rdim.var = self.block_sizes[block_idx].var + return rdim + + def _allocate_new_reduction(self, size: torch.SymInt | int) -> BlockSizeInfo: rdim_idx = self.allocate_block_size( size, reduction=True, source=ReductionLoopBlockSizeSource( - sum([int(bs.reduction) for bs in self.block_sizes]) + self._next_reduction_loop_index() ), hint=next_power_of_2(self.size_hint(size)), ) return self.block_sizes[rdim_idx] + def _next_reduction_loop_index(self) -> int: + return sum(int(info.reduction) for info in self.block_sizes) + def create_block_var(self, debug_name: str, hint: int = 64) -> torch.SymInt: with self.shape_env.ignore_fresh_unbacked_symbols(): sym = self.shape_env.create_unbacked_symint() @@ -203,6 +233,90 @@ def cached_create_unbacked_symint( self._symint_cache[key] = result return result + + def register_tile_index_tensor_block_id(self, tensor: torch.Tensor, block_id: int) -> None: + """Annotate ``tensor`` as originating from ``tile.index`` with ``block_id`` provenance.""" + tensor._tile_index_block_id = block_id # type: ignore[attr-defined] + + def get_tile_index_tensor_block_id(self, tensor: torch.Tensor) -> int | None: + """Return the originating ``tile.index`` block id if present.""" + return getattr(tensor, "_tile_index_block_id", None) + + def get_indexer_output_dims( + self, + indexer_tensor: torch.Tensor, + base_dim_size: int | torch.SymInt | None, + ) -> list[int | torch.SymInt]: + """Map a tensor indexer's shape to the output dimensions for advanced indexing.""" + + dims = list(indexer_tensor.size()) + non_broadcast_dims = [d for d in dims if self.size_hint(d) != 1] + + # Multi-dimensional indexer - return full shape + if len(non_broadcast_dims) > 1: + return dims + + # Try to find block_id from various sources + block_id = ( + self.get_tile_index_tensor_block_id(indexer_tensor) + or (self.get_block_id(base_dim_size) if base_dim_size is not None else None) + or (self.get_block_id(non_broadcast_dims[0]) if non_broadcast_dims else None) + ) + + if block_id is not None: + return [self.block_sizes[block_id].var] + return [non_broadcast_dims[0]] if non_broadcast_dims else [1] + + def tensor_indexer_broadcast_shape( + self, tensors: typing.Sequence[torch.Tensor] + ) -> list[int | torch.SymInt] | None: + """Compute a shared broadcast shape for tensor indexers when needed.""" + + tensor_list = [t for t in tensors if isinstance(t, torch.Tensor)] + if not tensor_list: + return None + + if all(self.get_tile_index_tensor_block_id(t) is not None for t in tensor_list): + return None + + shapes = [list(t.size()) for t in tensor_list] + return compute_broadcast_shape_for_tensor_indexers(shapes, self) + + def resolve_tile_index_shape( + self, input_tensor: torch.Tensor, output_shape: typing.Sequence[int | torch.SymInt] + ) -> tuple[list[int | torch.SymInt], int | None]: + """Resolve the symbolic shape for tensors derived from ``tile.index``. 
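+
+        For example (illustrative), viewing a ``tile.index`` tensor as
+        ``[1, block, 1]`` resolves to ``[1, self.block_sizes[block_id].var, 1]``
+        together with that ``block_id``.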
+ + Returns a copy of ``output_shape`` where the single non-broadcast + dimension is replaced with the canonical block-symbol and the associated + block_id to register on the new tensor. If the tensor is not a tile + indexer or it introduces more than one non-broadcast dimension, the + original shape and ``None`` are returned. + """ + + block_id = self.get_tile_index_tensor_block_id(input_tensor) + if block_id is None: + return list(output_shape), None + + resolved = list(output_shape) + non_broadcast = [i for i, s in enumerate(resolved) if self.size_hint(s) != 1] + if len(non_broadcast) <= 1: + if non_broadcast: + resolved[non_broadcast[0]] = self.block_sizes[block_id].var + return resolved, block_id + return resolved, None + + def new_index_result( + self, tensor: torch.Tensor, output_shape: typing.Sequence[int | torch.SymInt] + ) -> torch.Tensor: + """Create a new tensor for indexing/view ops while preserving tile index provenance.""" + + resolved_shape, block_id = self.resolve_tile_index_shape(tensor, output_shape) + result = tensor.new_empty(resolved_shape) + if block_id is not None: + self.register_tile_index_tensor_block_id(result, block_id) + return result + def to_fake(self, obj: object, origin: Origin) -> object: if isinstance(obj, torch.Tensor): return self._to_fake_tensor(obj, origin.to_source()) @@ -283,6 +397,10 @@ def _to_fake_tensor(self, tensor: torch.Tensor, source: Source) -> torch.Tensor: self.fake_mode, tensor, shape_env=self.shape_env, source=source ) self.input_sources[result] = source + if hasattr(tensor, "_tile_index_block_id"): + self.register_tile_index_tensor_block_id( + result, typing.cast(int, getattr(tensor, "_tile_index_block_id")) + ) if isinstance(source, LocalSource): for i, s in enumerate(result.size()): if isinstance(s, torch.SymInt) and isinstance( @@ -535,3 +653,20 @@ def _to_sympy(x: int | torch.SymInt) -> sympy.Expr: def _has_unbacked(expr: sympy.Expr) -> bool: return any(n.name.startswith("u") for n in expr.free_symbols) # pyright: ignore[reportAttributeAccessIssue] + + +def compute_broadcast_shape_for_tensor_indexers( + shapes: list[list[int | torch.SymInt]], + env: "CompileEnvironment" +) -> list[int | torch.SymInt]: + """Compute broadcast shape for multiple tensor indexers using right-aligned broadcasting.""" + if not shapes: + return [] + + max_ndim = max(len(s) for s in shapes) + padded = [([1] * (max_ndim - len(s)) + s) for s in shapes] + + return [ + next((d for d in dims if env.size_hint(d) != 1), 1) + for dims in zip(*padded, strict=True) + ] diff --git a/helion/_compiler/indexing_strategy.py b/helion/_compiler/indexing_strategy.py index f26cc333f..f80fede46 100644 --- a/helion/_compiler/indexing_strategy.py +++ b/helion/_compiler/indexing_strategy.py @@ -4,7 +4,9 @@ import collections import dataclasses from typing import TYPE_CHECKING +from typing import Any from typing import NamedTuple +from typing import cast import sympy import torch @@ -447,6 +449,24 @@ def codegen_store( ) +@dataclasses.dataclass +class _TensorIndexContext: + tensors: list[torch.Tensor] + shapes: list[list[int | torch.SymInt]] + dims_count: list[int] + broadcast_shape: list[int | torch.SymInt] | None + + @property + def shared_shape(self) -> list[int | torch.SymInt]: + return self.broadcast_shape or [] + + @property + def broadcast_width(self) -> int: + if self.broadcast_shape is not None: + return len(self.broadcast_shape) + return max((len(shape) for shape in self.shapes), default=0) + + class SubscriptIndexing(NamedTuple): index_expr: ast.AST mask_expr: ast.AST 
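A minimal pure-Python mirror of the right-aligned broadcast rule used by
``compute_broadcast_shape_for_tensor_indexers`` above (illustrative only: the
real helper routes sizes through ``env.size_hint`` so SymInts are handled,
while plain ints stand in here):

    def broadcast_shape(shapes):
        # Left-pad with 1s to right-align, then take the first non-1 size
        # per column; all non-1 sizes in a column are assumed equal.
        max_ndim = max(len(s) for s in shapes)
        padded = [[1] * (max_ndim - len(s)) + s for s in shapes]
        return [next((d for d in dims if d != 1), 1) for dims in zip(*padded)]

    assert broadcast_shape([[8, 1], [16]]) == [8, 16]
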
@@ -456,6 +476,27 @@ def has_mask(self) -> bool: isinstance(self.mask_expr, ast.Constant) and self.mask_expr.value is None ) + @staticmethod + def _tensor_index_context(index: list[object]) -> _TensorIndexContext: + tensors = [cast(torch.Tensor, k) for k in index if isinstance(k, torch.Tensor)] + return SubscriptIndexing._build_tensor_context(tensors) + + @staticmethod + def _build_tensor_context(tensors: list[torch.Tensor]) -> _TensorIndexContext: + env = CompileEnvironment.current() + shapes = [list(t.size()) for t in tensors] + dims_count = [] + for dims in shapes: + non_bcast = [d for d in dims if env.size_hint(d) != 1] + dims_count.append(len(dims) if len(non_bcast) > 1 else 1) + broadcast_shape = env.tensor_indexer_broadcast_shape(tensors) + return _TensorIndexContext( + tensors=tensors, + shapes=shapes, + dims_count=dims_count, + broadcast_shape=broadcast_shape, + ) + @staticmethod def compute_shape( tensor: torch.Tensor, index: list[object] @@ -463,8 +504,13 @@ def compute_shape( assert isinstance(tensor, torch.Tensor) assert isinstance(index, (list, tuple)), index input_size = collections.deque(tensor.size()) - output_size = [] + output_size: list[int | torch.SymInt] = [] env = CompileEnvironment.current() + ctx = SubscriptIndexing._tensor_index_context(index) + + use_broadcast_once = ctx.broadcast_shape is not None + shared_shape = ctx.shared_shape + added_broadcast_shape = False for k in index: if k is None: output_size.append(1) @@ -482,19 +528,21 @@ def compute_shape( output_size.append(1) elif isinstance(k, slice): size = input_size.popleft() - # Handle slices with steps slice_size = compute_slice_size(k, size) - if slice_size != 1: rdim = env.allocate_reduction_dimension(slice_size) output_size.append(rdim.var) else: output_size.append(1) - elif isinstance(k, torch.Tensor) and ( - k.ndim == 1 or (len(index) == 1 and tensor.ndim == 1) - ): - input_size.popleft() - output_size.extend(k.size()) + elif isinstance(k, torch.Tensor): + if use_broadcast_once: + input_size.popleft() + if not added_broadcast_shape: + output_size.extend(shared_shape) + added_broadcast_shape = True + else: + base_dim = input_size.popleft() + output_size.extend(env.get_indexer_output_dims(k, base_dim)) else: raise exc.InvalidIndexingType(k) assert len(input_size) == 0, "invalid subscript" @@ -514,6 +562,11 @@ def create( output_size = SubscriptIndexing.compute_shape(fake_value, index) env = CompileEnvironment.current() dtype = env.triton_index_type() + ctx = SubscriptIndexing._tensor_index_context(index) + broadcast_width = ctx.broadcast_width + first_tensor_start: int | None = None + tensor_seen = 0 + for n, k in enumerate(index): if k is None: output_idx += 1 @@ -573,34 +626,38 @@ def create( else: index_values.append(f"tl.zeros([1], {dtype}){expand}") output_idx += 1 - elif isinstance(k, torch.Tensor) and k.ndim == 1: - expand = tile_strategy.expand_str(output_size, output_idx) - ast_index = state.ast_args[1] - assert isinstance(ast_index, (list, tuple)) - assert len(ast_index) == len(index) - index_var = state.codegen.lift(ast_index[n], prefix="index").id - index_values.append(f"({index_var}){expand}") - if (block_idx := env.get_block_id(output_size[output_idx])) is not None: - if mask := state.codegen.mask_var(block_idx): - mask_values.setdefault(f"({mask}){expand}") - output_idx += 1 - elif ( - isinstance(k, torch.Tensor) and len(index) == 1 and fake_value.ndim == 1 - ): - # TODO(jansel): combine this case with the above - ast_index = state.ast_args[1] - assert isinstance(ast_index, (list, tuple)) 
- assert len(ast_index) == 1 - index_var = state.codegen.lift(ast_index[0], prefix="index").id - index_values.append(index_var) - output_idx += k.ndim - for n, s in enumerate(output_size): - if (block_idx := env.get_block_id(s)) is not None and ( - mask := state.codegen.mask_var(block_idx) - ): - mask_values.setdefault( - f"({mask}){tile_strategy.expand_str(output_size, n)}" - ) + elif isinstance(k, torch.Tensor): + # Determine this tensor indexer's behavior + k_shape = ctx.shapes[tensor_seen] + dims_count = ctx.dims_count[tensor_seen] + # Right-align within the shared broadcast region when present + right_aligned_offset = max(0, broadcast_width - len(k_shape)) + if first_tensor_start is None: + start_pos = output_idx + right_aligned_offset + first_tensor_start = output_idx + output_idx += broadcast_width + else: + # Subsequent tensor indexers: align to the shared region at offset + if dims_count == 1: + non_one_positions = [ + i for i, d in enumerate(k_shape) if env.size_hint(d) != 1 + ] + rel = non_one_positions[0] if non_one_positions else (len(k_shape) - 1) + start_pos = first_tensor_start + right_aligned_offset + rel + else: + start_pos = first_tensor_start + right_aligned_offset + + # Clamp start_pos to valid range + if output_size: + start_pos = max(0, min(start_pos, len(output_size) - 1)) + else: + start_pos = 0 + + SubscriptIndexing._emit_tensor_indexer( + k, n, k_shape, dims_count, output_size, start_pos, + index, state, tile_strategy, index_values, mask_values, env, + ) + tensor_seen += 1 else: raise exc.InvalidIndexingType(type(k)) assert len(output_size) == output_idx @@ -618,10 +675,94 @@ def create( if extra_mask is not None: mask_values.setdefault("{_extra_mask}") kwargs["_extra_mask"] = extra_mask + return SubscriptIndexing( expr_from_string("+".join(index_expr)), expr_from_string("&".join(mask_values) or "None", **kwargs), ) + + @staticmethod + def _emit_tensor_indexer( + k: torch.Tensor, + n: int, + k_shape: list[int | torch.SymInt], + dims_count: int, + output_size: list[int | torch.SymInt], + start_pos: int, + index: list[object], + state: CodegenState, + tile_strategy: Any, + index_values: list[str], + mask_values: dict[str, None], + env: CompileEnvironment, + ) -> None: + ast_index = state.ast_args[1] + assert isinstance(ast_index, (list, tuple)) + assert len(ast_index) == len(index) + + available = max(0, len(output_size) - start_pos) + width = 1 if dims_count == 1 else min(k.ndim, available) + if width <= 0: + lifted = state.codegen.lift(ast_index[n], prefix="index").id + index_values.append(f"({lifted})") + return + + tile_origin_block_id = env.get_tile_index_tensor_block_id(k) + + if width == 1: + expand = tile_strategy.expand_str(output_size, start_pos) + if tile_origin_block_id is not None: + index_var = state.codegen.index_var(tile_origin_block_id) + index_values.append(f"({index_var}){expand}") + if (mask := state.codegen.mask_var(tile_origin_block_id)) is not None: + mask_values.setdefault(f"({mask}){expand}") + return + + lifted = state.codegen.lift(ast_index[n], prefix="index").id + index_values.append(f"({lifted}){expand}") + output_block_id = env.get_block_id(output_size[start_pos]) + if output_block_id is not None: + if mask := state.codegen.mask_var(output_block_id): + mask_values.setdefault(f"({mask}){expand}") + return + + # Multi-dimensional tensor indexer path + index_var = state.codegen.lift(ast_index[n], prefix="index").id + positions = [start_pos + d for d in range(width)] + bracket = SubscriptIndexing._merge_expand_bracket(tile_strategy, 
output_size, positions) + index_values.append(f"({index_var}){bracket}") + + for pos in positions: + block_idx = env.get_block_id(output_size[pos]) + if block_idx is not None: + if mask := state.codegen.mask_var(block_idx): + expand = tile_strategy.expand_str(output_size, pos) + mask_values.setdefault(f"({mask}){expand}") + + @staticmethod + def _merge_expand_bracket( + tile_strategy: Any, + output_size: list[int | torch.SymInt], + positions: list[int], + ) -> str: + tokens: list[str] | None = None + for pos in positions: + expand = tile_strategy.expand_str(output_size, pos) + if expand == "": + current = [":"] + else: + assert expand.startswith("[") and expand.endswith("]"), expand + current = expand[1:-1].split(", ") if len(expand) > 2 else [] + if tokens is None: + tokens = current + elif current: + tokens = [ + ":" if (a == ":" or b == ":") else "None" + for a, b in zip(tokens, current, strict=True) + ] + if not tokens or all(t == ":" for t in tokens): + return "" + return f"[{', '.join(tokens)}]" @dataclasses.dataclass diff --git a/helion/_compiler/inductor_lowering.py b/helion/_compiler/inductor_lowering.py index 0dcbadfcb..ff0152766 100644 --- a/helion/_compiler/inductor_lowering.py +++ b/helion/_compiler/inductor_lowering.py @@ -483,7 +483,7 @@ def _check_block_broadcast_compatibility(self, node: torch.fx.Node) -> None: We right-align shapes and then, per-dimension, verify that there aren't two distinct non-1 symbolic sizes that are not known-equal. This is more - robust than relying solely on block-id provenance and works even if + robust than relying solely on block-id origin tracking and works even if upstream rewrites introduced fresh symbolic expressions. """ env = CompileEnvironment.current() @@ -522,32 +522,25 @@ def is_one(x: int | torch.SymInt) -> bool: # Check each dimension independently for dim in range(max_rank): - # First, see if multiple distinct block-ids appear in this dim - block_ids: set[int] = set() - for s in shapes: - size_i = s[dim] - if is_one(size_i): - continue - block_id = env.get_block_id(size_i) - if block_id is not None: - block_ids.add(block_id) + non_one_sizes = [s[dim] for s in shapes if not is_one(s[dim])] + if len(non_one_sizes) <= 1: + continue + + # Check block_ids first - different tile loops cannot broadcast + block_ids = { + block_id + for sz in non_one_sizes + if (block_id := env.get_block_id(sz)) is not None + } if len(block_ids) >= 2: raise exc.ShapeMismatch( str(shapes[0]), ", ".join(map(str, shapes[1:])), ) - # Otherwise, fall back to strict symbolic inequality among non-1 sizes - exprs: set[object] = set() - for s in shapes: - size_i = s[dim] - if is_one(size_i): - continue - if isinstance(size_i, torch.SymInt): - exprs.add(size_i._sympy_()) - else: - exprs.add(size_i) - if len(exprs) >= 2: + # Check symbolic equality + base = non_one_sizes[0] + if not all(env.known_equal(base, sz) for sz in non_one_sizes[1:]): raise exc.ShapeMismatch( str(shapes[0]), ", ".join(map(str, shapes[1:])), @@ -1008,6 +1001,8 @@ def codegen_expand(ctx: GraphInterpreter, node: torch.fx.Node) -> object: val = node.meta["val"] assert isinstance(val, torch.Tensor) shape = [*val.size()] + # Prepend None-indexing to match target rank for tl.broadcast_to. + # Triton requires input rank to match output rank for broadcast_to. 
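+    # For example (illustrative), expanding a value of shape [_BLOCK_SIZE_1]
+    # to [_BLOCK_SIZE_0, _BLOCK_SIZE_1] is emitted as
+    # tl.broadcast_to(v[None, :], [_BLOCK_SIZE_0, _BLOCK_SIZE_1]).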
if node.args[0].meta["val"].ndim != len(shape): # pyright: ignore[reportAttributeAccessIssue,reportOptionalMemberAccess] broadcasting = [":"] * len(shape) for i in range(len(shape) - node.args[0].meta["val"].ndim): # pyright: ignore[reportAttributeAccessIssue,reportOptionalMemberAccess] diff --git a/helion/_compiler/type_propagation.py b/helion/_compiler/type_propagation.py index e0e1aca5c..e4448c640 100644 --- a/helion/_compiler/type_propagation.py +++ b/helion/_compiler/type_propagation.py @@ -38,6 +38,7 @@ from .compile_environment import LoopSpecBlockSizeSource from .compile_environment import warning from .host_function import HostFunction +from .indexing_strategy import SubscriptIndexing from .host_function import SymbolOrigin from .output_header import library_imports from .source_location import current_location @@ -136,11 +137,10 @@ def set(self, name: str, type_info: TypeInfo) -> None: # pyright: ignore[report def merge(self, other: LocalScope | dict[str, TypeInfo]) -> LocalScope: if isinstance(other, LocalScope): other = other.variables + for k, v in other.items(): if k in self.variables: - existing = self.variables[k] - merged = existing.merge(v, var_name=k) - self.variables[k] = merged + self.variables[k] = self.variables[k].merge(v, var_name=k) else: self.variables[k] = v return self @@ -426,8 +426,13 @@ def _device_indexing_size(self, key: TypeInfo) -> list[int | torch.SymInt]: else: keys = [key] inputs_consumed = 0 - output_sizes = [] + output_sizes: list[int | torch.SymInt] = [] env = CompileEnvironment.current() + tensor_indexers = [cast(TensorType, k).fake_value for k in keys if isinstance(k, TensorType)] + ctx = SubscriptIndexing._build_tensor_context(tensor_indexers) + use_broadcast_once = ctx.broadcast_shape is not None + shared_shape: list[int | torch.SymInt] = ctx.shared_shape + added_broadcast_shape = False for k in keys: if isinstance(k, LiteralType): if isinstance(k.value, (int, torch.SymInt)): @@ -459,9 +464,19 @@ def _device_indexing_size(self, key: TypeInfo) -> list[int | torch.SymInt]: elif isinstance(k, TileIndexType): inputs_consumed += 1 output_sizes.append(env.block_sizes[k.block_id].var) - elif isinstance(k, TensorType) and k.fake_value.ndim == 1: - inputs_consumed += 1 - output_sizes.append(k.fake_value.size(0)) + elif isinstance(k, TensorType): + if use_broadcast_once: + inputs_consumed += 1 + if not added_broadcast_shape: + output_sizes.extend(shared_shape) + added_broadcast_shape = True + else: + # tile.index-only case: treat each indexer as contributing its own dim + base_dim_size = self.fake_value.size(inputs_consumed) + inputs_consumed += 1 + output_sizes.extend( + env.get_indexer_output_dims(k.fake_value, base_dim_size) + ) elif k.contains_type(TileIndexType): raise exc.OverpackedTile(k) else: @@ -506,9 +521,11 @@ def propagate_getitem(self, key: TypeInfo, origin: Origin) -> TypeInfo: raise exc.TypeInferenceError( f"Subscript not supported on {self!s} with key={key!s}" ) from None - return TensorType( - origin, self.fake_value.new_empty(self._device_indexing_size(key)) - ) + new_sizes = self._device_indexing_size(key) + env = CompileEnvironment.current() + new_fake = env.new_index_result(self.fake_value, new_sizes) + + return TensorType(origin, new_fake) def merge(self, other: TypeInfo, var_name: str | None = None) -> TypeInfo: if isinstance(other, TensorType): diff --git a/helion/language/memory_ops.py b/helion/language/memory_ops.py index 5cf28e6df..a3d13f62f 100644 --- a/helion/language/memory_ops.py +++ b/helion/language/memory_ops.py @@ 
-187,7 +187,10 @@ def _( ) -> torch.Tensor: if isinstance(tensor, torch.Tensor): target_shape = SubscriptIndexing.compute_shape(tensor, index) - return tensor.new_empty(target_shape) + from .._compiler.compile_environment import CompileEnvironment + + env = CompileEnvironment.current() + return env.new_index_result(tensor, target_shape) if isinstance(tensor, tuple): tensor_like, dev_ptrs = tensor assert isinstance(tensor_like, torch.Tensor) @@ -207,6 +210,17 @@ def _(state: CodegenState) -> ast.AST: assert isinstance(extra_mask, (type(None), ast.AST)) if isinstance(tensor, torch.Tensor): + # Fast-path for tile_index(...) being broadcast-only indexed + from ..language import tile_index + tensor_node = state.fx_node.args[0] + if ( + isinstance(tensor_node, torch.fx.Node) + and tensor_node.op == "call_function" + and tensor_node.target == tile_index + and all(idx is None or isinstance(idx, slice) for idx in subscript) + ): + return state.ast_args[0] + return state.device_function.indexing_strategy.codegen_load( state, tensor, [*subscript], extra_mask ) diff --git a/helion/language/tile_ops.py b/helion/language/tile_ops.py index fc3abc1b7..6a6fd92dc 100644 --- a/helion/language/tile_ops.py +++ b/helion/language/tile_ops.py @@ -39,8 +39,11 @@ def arange(length: int, device: torch.device) -> torch.Tensor: def _(tile: torch.SymInt) -> torch.Tensor: assert isinstance(tile, torch.SymInt) env = CompileEnvironment.current() - assert env.get_block_id(tile) is not None - return torch.empty([tile], dtype=env.settings.index_dtype, device=env.device) + block_id = env.get_block_id(tile) + assert block_id is not None + t = torch.empty([tile], dtype=env.settings.index_dtype, device=env.device) + env.register_tile_index_tensor_block_id(t, block_id) + return t @_decorators.codegen(tile_index) diff --git a/helion/language/view_ops.py b/helion/language/view_ops.py index ef17082ad..f66866103 100644 --- a/helion/language/view_ops.py +++ b/helion/language/view_ops.py @@ -78,12 +78,15 @@ def _(tensor: torch.Tensor, index: list[object]) -> torch.Tensor: for val in index: if val is None: output_size.append(1) - elif isinstance(val, slice) and repr(val) == "slice(None, None, None)": + elif isinstance(val, slice) and val == slice(None): output_size.append(input_size.popleft()) else: raise exc.InvalidIndexingType(repr(val)) assert len(input_size) == 0 - return tensor.new_empty(output_size) + from .._compiler.compile_environment import CompileEnvironment + + env = CompileEnvironment.current() + return env.new_index_result(tensor, output_size) @_decorators.codegen(subscript) @@ -92,7 +95,7 @@ def _(state: CodegenState) -> ast.AST: for val in state.proxy_arg(1): # pyright: ignore[reportGeneralTypeIssues] if val is None: output_keys.append("None") - elif isinstance(val, slice) and repr(val) == "slice(None, None, None)": + elif isinstance(val, slice) and val == slice(None): output_keys.append(":") else: raise exc.InvalidIndexingType(repr(val)) diff --git a/test/test_indexing.expected b/test/test_indexing.expected index f87c7dc41..c8e4880ee 100644 --- a/test/test_indexing.expected +++ b/test/test_indexing.expected @@ -185,6 +185,116 @@ def broadcast_add_3d(x: torch.Tensor, bias1: torch.Tensor, bias2: torch.Tensor, _launcher(_helion_broadcast_add_3d, (triton.cdiv(d0, _BLOCK_SIZE_0) * triton.cdiv(d1, _BLOCK_SIZE_1) * triton.cdiv(d2, _BLOCK_SIZE_2),), x, bias1, bias2, out, bias1.size(1), bias1.size(2), bias2.size(0), bias2.size(2), out.size(0), out.size(1), out.size(2), x.size(0), x.size(1), x.size(2), bias1.stride(0), 
bias1.stride(1), bias1.stride(2), bias2.stride(0), bias2.stride(1), bias2.stride(2), out.stride(0), out.stride(1), out.stride(2), x.stride(0), x.stride(1), x.stride(2), d0, d1, _BLOCK_SIZE_0, _BLOCK_SIZE_1, _BLOCK_SIZE_2, num_warps=4, num_stages=3) return out +--- assertExpectedJournal(TestIndexing.test_indirect_indexing_2d) +from __future__ import annotations + +import torch +import triton +import triton.language as tl +from helion.runtime import default_launcher as _default_launcher + +@triton.jit +def _helion_test(col, B_flat, val, C, B_flat_stride_0, C_stride_0, C_stride_1, col_stride_0, col_stride_1, val_stride_0, val_stride_1, M, N, K, _BLOCK_SIZE_0: tl.constexpr, _BLOCK_SIZE_1: tl.constexpr, _BLOCK_SIZE_2: tl.constexpr): + num_blocks_0 = tl.cdiv(M, _BLOCK_SIZE_0) + pid_0 = tl.program_id(0) % num_blocks_0 + pid_1 = tl.program_id(0) // num_blocks_0 + offset_0 = pid_0 * _BLOCK_SIZE_0 + indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32) + mask_0 = indices_0 < M + offset_1 = pid_1 * _BLOCK_SIZE_1 + indices_1 = (offset_1 + tl.arange(0, _BLOCK_SIZE_1)).to(tl.int32) + mask_1 = indices_1 < N + acc = tl.full([_BLOCK_SIZE_0, _BLOCK_SIZE_1], 0.0, tl.float32) + for offset_3 in tl.range(0, K.to(tl.int32), _BLOCK_SIZE_2): + indices_3 = offset_3 + tl.arange(0, _BLOCK_SIZE_2).to(tl.int32) + mask_2 = indices_3 < K + acc_copy = acc + acc_copy_0 = acc_copy + cols_2d = tl.load(col + (indices_0[:, None] * col_stride_0 + indices_3[None, :] * col_stride_1), mask_0[:, None] & mask_2[None, :], other=0) + v_0 = cols_2d * N + subscript = v_0[:, :, None] + v_1 = tl.cast(indices_1, tl.int64) + v_2 = subscript + v_1 + B_slice = tl.load(B_flat + v_2 * B_flat_stride_0, mask_0[:, None, None] & mask_2[None, :, None] & mask_1[None, None, :], other=0) + vals_2d = tl.load(val + (indices_0[:, None] * val_stride_0 + indices_3[None, :] * val_stride_1), mask_0[:, None] & mask_2[None, :], other=0) + subscript_1 = vals_2d[:, :, None] + v_3 = subscript_1 * B_slice + contrib_1 = tl.cast(tl.sum(v_3, 1), tl.float32) + acc = acc_copy_0 + contrib_1 + tl.store(C + (indices_0[:, None] * C_stride_0 + indices_1[None, :] * C_stride_1), acc, mask_0[:, None] & mask_1[None, :]) + +def test(col: torch.Tensor, val: torch.Tensor, B: torch.Tensor, *, _launcher=_default_launcher): + M, K = col.shape + _, N = B.shape + out_dtype = torch.promote_types(val.dtype, B.dtype) + C = torch.empty((M, N), dtype=out_dtype, device=B.device) + B_flat = B.reshape(-1) + _BLOCK_SIZE_0 = 8 + _BLOCK_SIZE_1 = 8 + _BLOCK_SIZE_2 = 4 + _RDIM_SIZE_3 = triton.next_power_of_2(_BLOCK_SIZE_1) + _launcher(_helion_test, (triton.cdiv(M, _BLOCK_SIZE_0) * triton.cdiv(N, _BLOCK_SIZE_1),), col, B_flat, val, C, B_flat.stride(0), C.stride(0), C.stride(1), col.stride(0), col.stride(1), val.stride(0), val.stride(1), M, N, K, _BLOCK_SIZE_0, _BLOCK_SIZE_1, _BLOCK_SIZE_2, num_warps=4, num_stages=3) + return C + +--- assertExpectedJournal(TestIndexing.test_indirect_indexing_3d) +from __future__ import annotations + +import torch +import triton +import triton.language as tl +from helion.runtime import default_launcher as _default_launcher + +@triton.jit +def _helion_test(col, B, val, C, B_stride_0, B_stride_1, B_stride_2, C_stride_0, C_stride_1, C_stride_2, C_stride_3, col_stride_0, col_stride_1, col_stride_2, val_stride_0, val_stride_1, val_stride_2, M, N, P, Q, K, _BLOCK_SIZE_0: tl.constexpr, _BLOCK_SIZE_1: tl.constexpr, _BLOCK_SIZE_2: tl.constexpr, _BLOCK_SIZE_3: tl.constexpr, _BLOCK_SIZE_4: tl.constexpr): + num_blocks_0 = tl.cdiv(M, _BLOCK_SIZE_0) + num_blocks_1 = 
tl.cdiv(N, _BLOCK_SIZE_1) + num_blocks_2 = tl.cdiv(P, _BLOCK_SIZE_2) + pid_0 = tl.program_id(0) % num_blocks_0 + pid_1 = tl.program_id(0) // num_blocks_0 % num_blocks_1 + pid_2 = tl.program_id(0) // (num_blocks_0 * num_blocks_1) % num_blocks_2 + pid_3 = tl.program_id(0) // (num_blocks_0 * num_blocks_1 * num_blocks_2) + offset_0 = pid_0 * _BLOCK_SIZE_0 + indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32) + mask_0 = indices_0 < M + offset_1 = pid_1 * _BLOCK_SIZE_1 + indices_1 = (offset_1 + tl.arange(0, _BLOCK_SIZE_1)).to(tl.int32) + mask_1 = indices_1 < N + offset_2 = pid_2 * _BLOCK_SIZE_2 + indices_2 = (offset_2 + tl.arange(0, _BLOCK_SIZE_2)).to(tl.int32) + mask_2 = indices_2 < P + offset_3 = pid_3 * _BLOCK_SIZE_3 + indices_3 = (offset_3 + tl.arange(0, _BLOCK_SIZE_3)).to(tl.int32) + mask_3 = indices_3 < Q + acc = tl.full([_BLOCK_SIZE_0, _BLOCK_SIZE_1, _BLOCK_SIZE_2, _BLOCK_SIZE_3], 0.0, tl.float32) + for offset_5 in tl.range(0, K.to(tl.int32), _BLOCK_SIZE_4): + indices_5 = offset_5 + tl.arange(0, _BLOCK_SIZE_4).to(tl.int32) + mask_4 = indices_5 < K + acc_copy = acc + acc_copy_0 = acc_copy + cols_3d = tl.load(col + (indices_0[:, None, None] * col_stride_0 + indices_1[None, :, None] * col_stride_1 + indices_5[None, None, :] * col_stride_2), mask_0[:, None, None] & mask_1[None, :, None] & mask_4[None, None, :], other=0) + subscript = cols_3d[:, :, :, None, None] + B_slice = tl.load(B + (subscript * B_stride_0 + indices_2[None, None, None, :, None] * B_stride_1 + indices_3[None, None, None, None, :] * B_stride_2), mask_0[:, None, None, None, None] & mask_1[None, :, None, None, None] & mask_4[None, None, :, None, None] & mask_2[None, None, None, :, None] & mask_3[None, None, None, None, :], other=0) + vals_3d = tl.load(val + (indices_0[:, None, None] * val_stride_0 + indices_1[None, :, None] * val_stride_1 + indices_5[None, None, :] * val_stride_2), mask_0[:, None, None] & mask_1[None, :, None] & mask_4[None, None, :], other=0) + subscript_1 = vals_3d[:, :, :, None, None] + v_0 = subscript_1 * B_slice + contrib_1 = tl.cast(tl.sum(v_0, 2), tl.float32) + acc = acc_copy_0 + contrib_1 + tl.store(C + (indices_0[:, None, None, None] * C_stride_0 + indices_1[None, :, None, None] * C_stride_1 + indices_2[None, None, :, None] * C_stride_2 + indices_3[None, None, None, :] * C_stride_3), acc, mask_0[:, None, None, None] & mask_1[None, :, None, None] & mask_2[None, None, :, None] & mask_3[None, None, None, :]) + +def test(col: torch.Tensor, val: torch.Tensor, B: torch.Tensor, *, _launcher=_default_launcher): + M, N, K = col.shape + _, P, Q = B.shape + out_dtype = torch.promote_types(val.dtype, B.dtype) + C = torch.empty((M, N, P, Q), dtype=out_dtype, device=B.device) + _BLOCK_SIZE_0 = 4 + _BLOCK_SIZE_1 = 4 + _BLOCK_SIZE_2 = 4 + _BLOCK_SIZE_3 = 4 + _BLOCK_SIZE_4 = 4 + _RDIM_SIZE_5 = triton.next_power_of_2(_BLOCK_SIZE_2) + _launcher(_helion_test, (triton.cdiv(M, _BLOCK_SIZE_0) * triton.cdiv(N, _BLOCK_SIZE_1) * triton.cdiv(P, _BLOCK_SIZE_2) * triton.cdiv(Q, _BLOCK_SIZE_3),), col, B, val, C, B.stride(0), B.stride(1), B.stride(2), C.stride(0), C.stride(1), C.stride(2), C.stride(3), col.stride(0), col.stride(1), col.stride(2), val.stride(0), val.stride(1), val.stride(2), M, N, P, Q, K, _BLOCK_SIZE_0, _BLOCK_SIZE_1, _BLOCK_SIZE_2, _BLOCK_SIZE_3, _BLOCK_SIZE_4, num_warps=4, num_stages=3) + return C + --- assertExpectedJournal(TestIndexing.test_mask_load) from __future__ import annotations diff --git a/test/test_indexing.py b/test/test_indexing.py index 60522e269..569987961 100644 --- 
a/test/test_indexing.py +++ b/test/test_indexing.py @@ -988,6 +988,130 @@ def kernel( torch.testing.assert_close(src_result, expected_src) torch.testing.assert_close(dst_result, expected_dst) + def test_indirect_indexing_2d(self): + @helion.kernel() + def test( + col: torch.Tensor, # [M, K] int64 + val: torch.Tensor, # [M, K] fp32 + B: torch.Tensor, # [K, N] fp32 + ) -> torch.Tensor: # [M, N] fp32 + M, K = col.shape + _, N = B.shape + out_dtype = torch.promote_types(val.dtype, B.dtype) + C = torch.empty((M, N), dtype=out_dtype, device=B.device) + B_flat = B.reshape(-1) # [K*N] + + for tile_m, tile_n in hl.tile([M, N]): + # [tile_m, tile_n] + acc = hl.zeros([tile_m, tile_n], dtype=torch.float32) + + for tile_k in hl.tile(K): + # [tile_m, tile_k] + cols_2d = col[tile_m, tile_k] + # [tile_m, tile_k, tile_n] + B_slice = hl.load( + B_flat, + [(cols_2d * N)[:, :, None] + tile_n.index[None, None, :]] + ) + # [tile_m, tile_k] + vals_2d = val[tile_m, tile_k] + # [tile_m, tile_k, tile_n] + contrib = vals_2d[:, :, None] * B_slice + # [tile_m, tile_n] + contrib = contrib.sum(dim=1) + # [tile_m, tile_n] + acc = acc + contrib + + C[tile_m, tile_n] = acc.to(out_dtype) + + return C + + M, K, N = 32, 16, 24 + col = torch.randint(0, K, (M, K), device=DEVICE, dtype=torch.int64) + val = torch.rand((M, K), device=DEVICE, dtype=torch.float32) + B = torch.rand((K, N), device=DEVICE, dtype=torch.float32) + + code, result = code_and_output( + test, + (col, val, B), + block_size=[8, 8, 4], + ) + + # For each output position (i,j), compute sum over k: val[i,k] * B[col[i,k], j] + expected = torch.zeros((M, N), device=DEVICE, dtype=torch.float32) + for i in range(M): + for j in range(N): + for k in range(K): + expected[i, j] += val[i, k] * B[col[i, k], j] + + torch.testing.assert_close(result, expected, rtol=1e-5, atol=1e-5) + self.assertExpectedJournal(code) + + def test_indirect_indexing_3d(self): + @helion.kernel() + def test( + col: torch.Tensor, # [M, N, K] int64 - indices for first dimension of B + val: torch.Tensor, # [M, N, K] fp32 - values to multiply + B: torch.Tensor, # [K, P, Q] fp32 - tensor to index into + ) -> torch.Tensor: # [M, N, P, Q] fp32 + M, N, K = col.shape + _, P, Q = B.shape + out_dtype = torch.promote_types(val.dtype, B.dtype) + C = torch.empty((M, N, P, Q), dtype=out_dtype, device=B.device) + + for tile_m, tile_n, tile_p, tile_q in hl.tile([M, N, P, Q]): + # [tile_m, tile_n, tile_p, tile_q] + acc = hl.zeros([tile_m, tile_n, tile_p, tile_q], dtype=torch.float32) + + for tile_k in hl.tile(K): + # [tile_m, tile_n, tile_k] + cols_3d = col[tile_m, tile_n, tile_k] + + # [tile_m, tile_n, tile_k, tile_p, tile_q] + # Direct indexing into B using gather + B_slice = B[ + cols_3d[:, :, :, None, None], + tile_p.index[None, None, :, None], + tile_q.index[None, None, None, :], + ] + + # [tile_m, tile_n, tile_k] + vals_3d = val[tile_m, tile_n, tile_k] + + # [tile_m, tile_n, tile_k, tile_p, tile_q] + contrib = vals_3d[:, :, :, None, None] * B_slice + + # [tile_m, tile_n, tile_p, tile_q] - sum over k dimension + contrib = contrib.sum(dim=2) + + # [tile_m, tile_n, tile_p, tile_q] + acc = acc + contrib + + C[tile_m, tile_n, tile_p, tile_q] = acc.to(out_dtype) + return C + + M, N, K, P, Q = 16, 12, 8, 10, 14 + col = torch.randint(0, K, (M, N, K), device=DEVICE, dtype=torch.int64) + val = torch.rand((M, N, K), device=DEVICE, dtype=torch.float32) + B = torch.rand((K, P, Q), device=DEVICE, dtype=torch.float32) + + code, result = code_and_output( + test, + (col, val, B), + block_size=[4, 4, 4, 4, 4], # 5D tiling for 
M, N, P, Q, K + ) + + # For each output position (i,j,p,q), compute sum over k: val[i,j,k] * B[col[i,j,k], p, q] + expected = torch.zeros((M, N, P, Q), device=DEVICE, dtype=torch.float32) + for i in range(M): + for j in range(N): + for p in range(P): + for q in range(Q): + for k in range(K): + expected[i, j, p, q] += val[i, j, k] * B[col[i, j, k], p, q] + + torch.testing.assert_close(result, expected, rtol=1e-5, atol=1e-5) + self.assertExpectedJournal(code) if __name__ == "__main__": unittest.main()
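
For reference, the gather semantics that ``test_indirect_indexing_2d``
validates can be written in eager PyTorch as the sketch below (the helper
name is illustrative, not part of the patch); it computes the same
C[i, j] = sum_k val[i, k] * B[col[i, k], j] as the test's Python loops:

    import torch

    def indirect_2d_reference(
        col: torch.Tensor, val: torch.Tensor, B: torch.Tensor
    ) -> torch.Tensor:
        # B[col] gathers rows of B per index, giving shape [M, K, N];
        # weight each gathered row by val and reduce over K.
        return (val.unsqueeze(-1) * B[col]).sum(dim=1)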