@@ -142,6 +142,30 @@ def allocate_reduction_dimension(self, size: torch.SymInt | int) -> BlockSizeInf
             if rdim.reduction and rdim.size == size:
                 return rdim
 
+        # Check if size matches any tile dimension for symbolic equality.
+        # When building expressions that mix sizes derived from tiles
+        # (e.g., via slicing) with sizes coming directly from tile block vars, we
+        # want them to share the same SymInt variable whenever they are equal by
+        # construction. This preserves equality in the shape environment and avoids
+        # spurious "size mismatch" issues during fake-tensor broadcasting and
+        # arithmetic in type propagation.
+        if isinstance(size, torch.SymInt):
+            size_str = str(size)
+            for block_info in self.block_sizes:
+                if not block_info.reduction and str(block_info.var) == size_str:
+                    # Create reduction dimension with the same var to preserve
+                    # symbolic equality and ensure all later users see identical
+                    # symbols (rather than equal-but-distinct SymInts).
+                    rdim_idx = self.allocate_block_size(
+                        size,
+                        reduction=True,
+                        source=ReductionLoopBlockSizeSource(
+                            reduction_loop=len([b for b in self.block_sizes if b.reduction])
+                        ),
+                    )
+                    self.block_sizes[rdim_idx].var = block_info.var
+                    return self.block_sizes[rdim_idx]
+
         # Allocate a new reduction dimension
         rdim_idx = self.allocate_block_size(
             size,
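
Aside: a minimal sketch, using plain sympy rather than Helion's shape environment, of the problem the comment in the hunk above describes. Two distinct symbols that happen to denote the same runtime size cannot be proven equal structurally, whereas reusing the identical symbol (as the hunk does by copying block_info.var) makes the equality hold by construction:

import sympy

s0, s1 = sympy.symbols("s0 s1", integer=True, positive=True)
print(sympy.Eq(s0, s1))  # stays symbolic: Eq(s0, s1) -- equal-but-distinct symbols
print(sympy.Eq(s0, s0))  # True -- identical symbol, equality is structural
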
@@ -203,6 +227,91 @@ def cached_create_unbacked_symint(
         self._symint_cache[key] = result
         return result
 
+
+    def register_tile_index_tensor_block_id(self, tensor: torch.Tensor, block_id: int) -> None:
+        """Annotate ``tensor`` as originating from ``tile.index`` with ``block_id`` provenance."""
+        tensor._tile_index_block_id = block_id  # type: ignore[attr-defined]
+
+    def get_tile_index_tensor_block_id(self, tensor: torch.Tensor) -> int | None:
+        """Return the originating ``tile.index`` block id if present."""
+        return getattr(tensor, "_tile_index_block_id", None)
+
+    def get_indexer_output_dims(
+        self,
+        indexer_tensor: torch.Tensor,
+        base_dim_size: int | torch.SymInt | None,
+    ) -> list[int | torch.SymInt]:
+        """Map a tensor indexer's shape to the output dimensions for advanced indexing."""
+
+        dims = list(indexer_tensor.size())
+        non_broadcast_dims = [d for d in dims if self.size_hint(d) != 1]
+
+        # Multi-dimensional indexer - return full shape
+        if len(non_broadcast_dims) > 1:
+            return dims
+
+        block_id = self.get_tile_index_tensor_block_id(indexer_tensor)
+        if block_id is None and base_dim_size is not None:
+            block_id = self.get_block_id(base_dim_size)
+        if block_id is None and non_broadcast_dims:
+            block_id = self.get_block_id(non_broadcast_dims[0])
+
+        if block_id is not None:
+            return [self.block_sizes[block_id].var]
+        if non_broadcast_dims:
+            return [non_broadcast_dims[0]]
+        return [1]
+
+    def tensor_indexer_broadcast_shape(
+        self, tensors: typing.Sequence[torch.Tensor]
+    ) -> list[int | torch.SymInt] | None:
+        """Compute a shared broadcast shape for tensor indexers when needed."""
+
+        tensor_list = [t for t in tensors if isinstance(t, torch.Tensor)]
+        if not tensor_list:
+            return None
+
+        if all(self.get_tile_index_tensor_block_id(t) is not None for t in tensor_list):
+            return None
+
+        shapes = [list(t.size()) for t in tensor_list]
+        return compute_broadcast_shape_for_tensor_indexers(shapes, self)
+
+    def resolve_tile_index_shape(
+        self, input_tensor: torch.Tensor, output_shape: typing.Sequence[int | torch.SymInt]
+    ) -> tuple[list[int | torch.SymInt], int | None]:
+        """Resolve the symbolic shape for tensors derived from ``tile.index``.
+
+        Returns a copy of ``output_shape`` where the single non-broadcast
+        dimension is replaced with the canonical block-symbol and the associated
+        block_id to register on the new tensor. If the tensor is not a tile
+        indexer or it introduces more than one non-broadcast dimension, the
+        original shape and ``None`` are returned.
+        """
+
+        block_id = self.get_tile_index_tensor_block_id(input_tensor)
+        if block_id is None:
+            return list(output_shape), None
+
+        resolved = list(output_shape)
+        non_broadcast = [i for i, s in enumerate(resolved) if self.size_hint(s) != 1]
+        if len(non_broadcast) <= 1:
+            if non_broadcast:
+                resolved[non_broadcast[0]] = self.block_sizes[block_id].var
+            return resolved, block_id
+        return resolved, None
+
+    def new_index_result(
+        self, tensor: torch.Tensor, output_shape: typing.Sequence[int | torch.SymInt]
+    ) -> torch.Tensor:
+        """Create a new tensor for indexing/view ops while preserving tile index provenance."""
+
+        resolved_shape, block_id = self.resolve_tile_index_shape(tensor, output_shape)
+        result = tensor.new_empty(resolved_shape)
+        if block_id is not None:
+            self.register_tile_index_tensor_block_id(result, block_id)
+        return result
+
     def to_fake(self, obj: object, origin: Origin) -> object:
         if isinstance(obj, torch.Tensor):
             return self._to_fake_tensor(obj, origin.to_source())
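
Aside: the `_tile_index_block_id` provenance introduced above is an ordinary Python attribute on the tensor object, so registration and lookup reduce to setattr/getattr. A tiny sketch with a plain tensor (illustrative only, outside the diff):

import torch

t = torch.empty(8)
t._tile_index_block_id = 0  # what register_tile_index_tensor_block_id() stores
print(getattr(t, "_tile_index_block_id", None))  # 0 -- the lookup get_tile_index_tensor_block_id() performs
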
@@ -283,6 +392,10 @@ def _to_fake_tensor(self, tensor: torch.Tensor, source: Source) -> torch.Tensor:
             self.fake_mode, tensor, shape_env=self.shape_env, source=source
         )
         self.input_sources[result] = source
+        if hasattr(tensor, "_tile_index_block_id"):
+            self.register_tile_index_tensor_block_id(
+                result, typing.cast(int, getattr(tensor, "_tile_index_block_id"))
+            )
         if isinstance(source, LocalSource):
             for i, s in enumerate(result.size()):
                 if isinstance(s, torch.SymInt) and isinstance(
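
Aside: a sketch of why _to_fake_tensor re-registers the tag. Producing a new tensor object does not carry Python-level attributes along, so the provenance has to be copied onto the fake tensor explicitly (plain tensors shown here, no FakeTensorMode):

import torch

src = torch.empty(4)
src._tile_index_block_id = 2
dst = src.clone()  # new tensor object; __dict__ attributes are not propagated
print(getattr(dst, "_tile_index_block_id", None))  # None -- hence the explicit re-registration above
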
@@ -357,9 +470,9 @@ def current() -> CompileEnvironment:
     @staticmethod
     def has_current() -> bool:
         try:
-            CompileEnvironment.current()
-            return True
-        except NoCurrentEnvironment:
+            CompileEnvironment.current()
+            return True
+        except NoCurrentEnvironment:
             return False
 
     def get_block_id(self, size: int | torch.SymInt | sympy.Expr) -> int | None:
@@ -535,3 +648,35 @@ def _to_sympy(x: int | torch.SymInt) -> sympy.Expr:
 
 def _has_unbacked(expr: sympy.Expr) -> bool:
     return any(n.name.startswith("u") for n in expr.free_symbols)  # pyright: ignore[reportAttributeAccessIssue]
+
+
+def compute_broadcast_shape_for_tensor_indexers(
+    shapes: list[list[int | torch.SymInt]],
+    env: "CompileEnvironment"
+) -> list[int | torch.SymInt]:
+    """
+    Compute broadcast shape for multiple tensor indexers using right-aligned broadcasting.
+
+    Args:
+        shapes: List of shapes from each tensor indexer
+        env: CompileEnvironment for size_hint and known_equal checks
+
+    Returns:
+        Broadcast shape as list of dimensions
+    """
+    if not shapes:
+        return []
+
+    max_ndim = max(len(s) for s in shapes)
+    padded = [([1] * (max_ndim - len(s)) + s) for s in shapes]
+    broadcast_shape: list[int | torch.SymInt] = []
+
+    for dims_at_pos in zip(*padded, strict=True):
+        chosen: int | torch.SymInt | None = None
+        for d in dims_at_pos:
+            if env.size_hint(d) != 1:
+                if chosen is None or env.known_equal(chosen, d):
+                    chosen = d
+        broadcast_shape.append(chosen if chosen is not None else 1)
+
+    return broadcast_shape
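
Aside: the same right-aligned broadcasting rule reduced to plain ints, with size_hint and known_equal assumed away, to show the intended behavior of the helper above (toy code, not part of the PR):

def broadcast_ints(shapes):
    max_ndim = max(len(s) for s in shapes)
    padded = [[1] * (max_ndim - len(s)) + s for s in shapes]
    # At each position, any non-1 size wins; otherwise the dimension stays 1.
    return [next((d for d in dims if d != 1), 1) for dims in zip(*padded)]

print(broadcast_ints([[4], [3, 1]]))  # [3, 4]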