@@ -575,6 +575,10 @@ def compute_shape(
         input_size = collections.deque(tensor.size())
         output_size = []
         env = CompileEnvironment.current()
+
+        tensor_indexers = [k for k in index if isinstance(k, torch.Tensor)]
+        should_broadcast = env.should_broadcast_tensor_indexers(tensor_indexers)
+
         k_index = 0
         for k in index:
             if k is None:
@@ -617,11 +621,14 @@ def compute_shape(
                 else:
                     output_size.append(1)
                 k_index += 1
-            elif isinstance(k, torch.Tensor) and (
-                k.ndim == 1 or (len(index) == 1 and tensor.ndim == 1)
-            ):
-                input_size.popleft()
-                output_size.extend(k.size())
+            elif isinstance(k, torch.Tensor):
+                base_dim = input_size.popleft()
+                if not should_broadcast:
+                    output_size.extend(env.tensor_indexer_dims(k))
+                elif k is tensor_indexers[0]:
+                    output_size.extend(
+                        env.tensor_indexer_broadcast_shape(tensor_indexers)
+                    )
                 k_index += 1
             else:
                 raise exc.InvalidIndexingType(k)
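The `compute_shape` branch above now sizes the output from the broadcast of all tensor indexers (via `tensor_indexer_dims` / `tensor_indexer_broadcast_shape`) instead of assuming a single 1-D indexer. For reference, a minimal eager-mode sketch of the PyTorch advanced-indexing rule being mirrored (plain `torch` only, no Helion internals):

```python
import torch

x = torch.randn(8, 16, 4)

# Two 1-D indexers of equal length broadcast to a single output dimension.
i = torch.tensor([0, 2, 5])
j = torch.tensor([1, 1, 3])
assert x[i, j].shape == (3, 4)

# "Cartesian" pattern: (3, 1) and (1, 2) indexers broadcast to (3, 2), so the
# tensor indexers together contribute two output dimensions.
assert x[i[:, None], j[None, :2]].shape == (3, 2, 4)
```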
@@ -667,13 +674,99 @@ def create(
         output_size = SubscriptIndexing.compute_shape(fake_value, index, state)
         env = CompileEnvironment.current()
         dtype = env.triton_index_type()
+        tensor_indexers = [k for k in index if isinstance(k, torch.Tensor)]
+        should_broadcast = env.should_broadcast_tensor_indexers(tensor_indexers)
+        broadcast_dims = 0
+        if should_broadcast:
+            broadcast_dims = len(env.tensor_indexer_broadcast_shape(tensor_indexers))
+        is_cartesian = (
+            broadcast_dims >= 2
+            and len(tensor_indexers) == broadcast_dims
+            and all(
+                t.ndim == 1
+                or sum(1 for d in t.size() if env.size_hint(d) != 1) <= 1
+                for t in tensor_indexers
+            )
+        )
         if dtype == "tl.int32" and SubscriptIndexing._needs_int64(fake_value):
             raise exc.IndexOffsetOutOfRangeForInt32(env.index_dtype)
 
         def _is_size_one(size: int | torch.SymInt) -> bool:
             return env.known_equal(size, 1)
 
         k_index = 0
+
+        def tensor_index_source_and_mask(
+            index_elem: torch.Tensor, index_var: str, pos: int
+        ) -> tuple[str, int | None]:
+            tile_id = env.get_tile_index_tensor_block_id(index_elem)
+            src = state.codegen.index_var(tile_id) if tile_id else index_var
+            mask_id = tile_id or (
+                env.get_block_id(output_size[pos]) if pos < len(output_size) else None
+            )
+            return src, mask_id
+
+        def handle_broadcast_tensor(
+            position: int,
+            index_elem: torch.Tensor,
+            index_var: str,
+            cur_output_idx: int,
+        ) -> tuple[str, dict[str, None]]:
+            """Handle tensor index with broadcast shape (cartesian or general)."""
+            assert broadcast_dims > 0
+            tensor_idx = next(
+                i for i, t in enumerate(tensor_indexers) if t is index_elem
+            )
+            first_tensor_out_idx = (
+                cur_output_idx if tensor_idx == 0 else cur_output_idx - broadcast_dims
+            )
+            non_trivial_output_positions: list[int] = []
+            if is_cartesian:
+                pos = first_tensor_out_idx + tensor_idx
+                single_output_dim = True
+            else:
+                # Find position(s) where this tensor contributes non-trivial dims
+                offset = max(0, broadcast_dims - index_elem.ndim)
+                non_trivial_output_positions = [
+                    first_tensor_out_idx + offset + i
+                    for i in range(index_elem.ndim)
+                    if env.size_hint(index_elem.size(i)) != 1
+                ]
+                pos = non_trivial_output_positions[0]
+                single_output_dim = len(non_trivial_output_positions) <= 1
+
+            new_masks: dict[str, None] = {}
+            if single_output_dim:
+                src, _ = tensor_index_source_and_mask(index_elem, index_var, pos)
+                expand = (
+                    tile_strategy.expand_str(output_size, pos)
+                    if index_elem.ndim == 1
+                    else ""
+                )
+                idx_val = f"({src}){expand}"
+            else:
+                # Multi-dim tensor with multiple non-trivial dims
+                idx_val = f"({index_var})"
+            if tensor_idx == 0:
+                for p in non_trivial_output_positions:
+                    if (
+                        p < len(output_size)
+                        and (bid := env.get_block_id(output_size[p]))
+                        and (mv := state.codegen.mask_var(bid))
+                        and not _is_size_one(fake_value.size(len(index_values)))
+                    ):
+                        new_masks.setdefault(
+                            f"({mv}){tile_strategy.expand_str(output_size, p)}"
+                        )
+            # Padded iota mask
+            if (
+                orig_len := _get_padded_iota_original_length(state, position)
+            ) is not None:
+                new_masks.setdefault(
+                    f"(({index_var} < {orig_len}){tile_strategy.expand_str(output_size, first_tensor_out_idx + tensor_idx)})"
+                )
+            return idx_val, new_masks
+
         for n, k in enumerate(index):
             if k is None:
                 output_idx += 1
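`handle_broadcast_tensor` distinguishes the cartesian case (each indexer carries exactly one non-trivial dimension and the indexers together span the broadcast shape) from general broadcasting. Below is a hedged standalone sketch of that predicate; `looks_cartesian` is a hypothetical name, and it reads concrete sizes where the real code goes through `env.size_hint` to resolve symbolic sizes:

```python
import torch

def looks_cartesian(tensor_indexers: list[torch.Tensor], broadcast_dims: int) -> bool:
    # Each indexer is 1-D or has at most one non-size-1 dim, and together the
    # indexers account for every broadcast output dimension.
    return (
        broadcast_dims >= 2
        and len(tensor_indexers) == broadcast_dims
        and all(
            t.ndim == 1 or sum(1 for d in t.size() if d != 1) <= 1
            for t in tensor_indexers
        )
    )

i = torch.arange(4)[:, None]  # (4, 1): one non-trivial dim
j = torch.arange(3)[None, :]  # (1, 3): one non-trivial dim
assert looks_cartesian([i, j], broadcast_dims=2)
assert not looks_cartesian([torch.zeros(4, 3, dtype=torch.long)], broadcast_dims=2)
```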
@@ -752,40 +845,41 @@ def _is_size_one(size: int | torch.SymInt) -> bool:
                 index_values.append(f"tl.zeros([1], {dtype}){expand}")
                 output_idx += 1
                 k_index += 1
-            elif isinstance(k, torch.Tensor) and k.ndim == 1:
-                expand = tile_strategy.expand_str(output_size, output_idx)
+            elif isinstance(k, torch.Tensor):
                 ast_index = state.ast_args[1]
                 assert isinstance(ast_index, (list, tuple))
-                assert len(ast_index) == len(index)
                 index_var = state.codegen.lift(ast_index[n], prefix="index").id
-                index_values.append(f"({index_var}){expand}")
-                if (block_idx := env.get_block_id(output_size[output_idx])) is not None:
-                    if mask := state.codegen.mask_var(block_idx):
-                        mask_values.setdefault(f"({mask}){expand}")
-                # Check if this index comes from a padded hl.arange and generate mask
-                if (
-                    original_length := _get_padded_iota_original_length(state, n)
-                ) is not None:
-                    mask_values.setdefault(f"({index_var} < {original_length}){expand}")
-                output_idx += 1
-                k_index += 1
-            elif (
-                isinstance(k, torch.Tensor) and len(index) == 1 and fake_value.ndim == 1
-            ):
-                # TODO(jansel): combine this case with the above
-                ast_index = state.ast_args[1]
-                assert isinstance(ast_index, (list, tuple))
-                assert len(ast_index) == 1
-                index_var = state.codegen.lift(ast_index[0], prefix="index").id
-                index_values.append(index_var)
-                output_idx += k.ndim
-                for n, s in enumerate(output_size):
-                    if (block_idx := env.get_block_id(s)) is not None and (
-                        mask := state.codegen.mask_var(block_idx)
+
+                # Use broadcast handling for: multiple tensors, or single tensor with ndim > 1
+                if should_broadcast:
+                    idx_val, new_masks = handle_broadcast_tensor(
+                        n, k, index_var, output_idx
+                    )
+                    index_values.append(idx_val)
+                    mask_values.update(new_masks)
+                    if k is tensor_indexers[0]:
+                        output_idx += broadcast_dims
+                    k_index += 1
+                    continue
+
+                index_source, mask_block_id = tensor_index_source_and_mask(
+                    k, index_var, output_idx
+                )
+
+                expand = (
+                    tile_strategy.expand_str(output_size, output_idx)
+                    if k.ndim < len(output_size)
+                    else ""
+                )
+                index_values.append(f"({index_source}){expand}")
+                if mask_block_id is not None:
+                    mask_var = state.codegen.mask_var(mask_block_id)
+                    if mask_var and not _is_size_one(
+                        fake_value.size(len(index_values) - 1)
                     ):
-                        mask_values.setdefault(
-                            f"({mask}){tile_strategy.expand_str(output_size, n)}"
-                        )
+                        mask_values.setdefault(f"({mask_var}){expand}")
+
+                output_idx += k.ndim
                 k_index += 1
             else:
                 raise exc.InvalidIndexingType(type(k))
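In the rewritten loop, each tensor indexer contributes one entry to `index_values` (expanded to the broadcast output shape) and zero or more entries to `mask_values`, which downstream code roughly combines with the tensor's strides into flat offsets for the load/store. As a rough eager-mode illustration of that addressing math only (not the generated Triton), assuming a contiguous tensor:

```python
import torch

x = torch.randn(8, 16)
i = torch.tensor([0, 2, 5])[:, None]  # (3, 1)
j = torch.tensor([1, 4])[None, :]     # (1, 2)

# Broadcast indexers combined with strides give a (3, 2) grid of flat offsets.
offsets = i * x.stride(0) + j * x.stride(1)
gathered = x.reshape(-1)[offsets.reshape(-1)].reshape(3, 2)
assert torch.equal(gathered, x[i, j])
```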