@@ -988,6 +988,130 @@ def kernel(
         torch.testing.assert_close(src_result, expected_src)
         torch.testing.assert_close(dst_result, expected_dst)
 
+    def test_indirect_indexing_2d(self):
+        @helion.kernel()
+        def test(
+            col: torch.Tensor,  # [M, K] int64
+            val: torch.Tensor,  # [M, K] fp32
+            B: torch.Tensor,  # [K, N] fp32
+        ) -> torch.Tensor:  # [M, N] fp32
+            M, K = col.shape
+            _, N = B.shape
+            out_dtype = torch.promote_types(val.dtype, B.dtype)
+            C = torch.empty((M, N), dtype=out_dtype, device=B.device)
+            B_flat = B.reshape(-1)  # [K*N]
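+            # Flattening lets the inner loop gather B[col, n] with a single
+            # 1-D load at manually computed offsets: col * N + n.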
+
+            for tile_m, tile_n in hl.tile([M, N]):
+                # [tile_m, tile_n]
+                acc = hl.zeros([tile_m, tile_n], dtype=torch.float32)
+
+                for tile_k in hl.tile(K):
+                    # [tile_m, tile_k]
+                    cols_2d = col[tile_m, tile_k]
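+                    # Broadcast the 2-D row indices against the n tile to form
+                    # [tile_m, tile_k, tile_n] flat offsets into B_flat.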
+                    # [tile_m, tile_k, tile_n]
+                    B_slice = hl.load(
+                        B_flat,
+                        [(cols_2d * N)[:, :, None] + tile_n.index[None, None, :]],
+                    )
+                    # [tile_m, tile_k]
+                    vals_2d = val[tile_m, tile_k]
+                    # [tile_m, tile_k, tile_n]
+                    contrib = vals_2d[:, :, None] * B_slice
+                    # [tile_m, tile_n]
+                    contrib = contrib.sum(dim=1)
+                    # [tile_m, tile_n]
+                    acc = acc + contrib
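+                    # Net effect per tile: acc[m, n] += sum_k val[m, k] * B[col[m, k], n]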
+
+                C[tile_m, tile_n] = acc.to(out_dtype)
+
+            return C
+
+        M, K, N = 32, 16, 24
+        col = torch.randint(0, K, (M, K), device=DEVICE, dtype=torch.int64)
+        val = torch.rand((M, K), device=DEVICE, dtype=torch.float32)
+        B = torch.rand((K, N), device=DEVICE, dtype=torch.float32)
+
+        code, result = code_and_output(
+            test,
+            (col, val, B),
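+            # block sizes follow tile-creation order: m, n, then k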
+            block_size=[8, 8, 4],
+        )
+
+        # For each output position (i, j), compute sum over k: val[i, k] * B[col[i, k], j]
+        expected = torch.zeros((M, N), device=DEVICE, dtype=torch.float32)
+        for i in range(M):
+            for j in range(N):
+                for k in range(K):
+                    expected[i, j] += val[i, k] * B[col[i, k], j]
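+        # A vectorized equivalent, for reference: torch.einsum("mk,mkn->mn", val, B[col])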
+
+        torch.testing.assert_close(result, expected, rtol=1e-5, atol=1e-5)
+        self.assertExpectedJournal(code)
+
+    def test_indirect_indexing_3d(self):
+        @helion.kernel()
+        def test(
+            col: torch.Tensor,  # [M, N, K] int64 - indices for first dimension of B
+            val: torch.Tensor,  # [M, N, K] fp32 - values to multiply
+            B: torch.Tensor,  # [K, P, Q] fp32 - tensor to index into
+        ) -> torch.Tensor:  # [M, N, P, Q] fp32
+            M, N, K = col.shape
+            _, P, Q = B.shape
+            out_dtype = torch.promote_types(val.dtype, B.dtype)
+            C = torch.empty((M, N, P, Q), dtype=out_dtype, device=B.device)
+
+            for tile_m, tile_n, tile_p, tile_q in hl.tile([M, N, P, Q]):
+                # [tile_m, tile_n, tile_p, tile_q]
+                acc = hl.zeros([tile_m, tile_n, tile_p, tile_q], dtype=torch.float32)
+
+                for tile_k in hl.tile(K):
+                    # [tile_m, tile_n, tile_k]
+                    cols_3d = col[tile_m, tile_n, tile_k]
+
+                    # [tile_m, tile_n, tile_k, tile_p, tile_q]
+                    # Direct indexing into B using gather
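+                    # The index tensors broadcast together (shorter ones are
+                    # left-padded with size-1 axes), yielding the 5-D shape above.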
+                    B_slice = B[
+                        cols_3d[:, :, :, None, None],
+                        tile_p.index[None, None, :, None],
+                        tile_q.index[None, None, None, :],
+                    ]
+
+                    # [tile_m, tile_n, tile_k]
+                    vals_3d = val[tile_m, tile_n, tile_k]
+
+                    # [tile_m, tile_n, tile_k, tile_p, tile_q]
+                    contrib = vals_3d[:, :, :, None, None] * B_slice
+
+                    # [tile_m, tile_n, tile_p, tile_q] - sum over k dimension
+                    contrib = contrib.sum(dim=2)
+
+                    # [tile_m, tile_n, tile_p, tile_q]
+                    acc = acc + contrib
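+                    # Net effect per tile: acc[m, n, p, q] += sum_k val[m, n, k] * B[col[m, n, k], p, q]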
+
+                C[tile_m, tile_n, tile_p, tile_q] = acc.to(out_dtype)
+            return C
+
+        M, N, K, P, Q = 16, 12, 8, 10, 14
+        col = torch.randint(0, K, (M, N, K), device=DEVICE, dtype=torch.int64)
+        val = torch.rand((M, N, K), device=DEVICE, dtype=torch.float32)
+        B = torch.rand((K, P, Q), device=DEVICE, dtype=torch.float32)
+
+        code, result = code_and_output(
+            test,
+            (col, val, B),
+            block_size=[4, 4, 4, 4, 4],  # 5D tiling for M, N, P, Q, K
+        )
+
+        # For each output position (i, j, p, q), compute sum over k: val[i, j, k] * B[col[i, j, k], p, q]
+        expected = torch.zeros((M, N, P, Q), device=DEVICE, dtype=torch.float32)
+        for i in range(M):
+            for j in range(N):
+                for p in range(P):
+                    for q in range(Q):
+                        for k in range(K):
+                            expected[i, j, p, q] += val[i, j, k] * B[col[i, j, k], p, q]
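+        # A vectorized equivalent, for reference: torch.einsum("mnk,mnkpq->mnpq", val, B[col])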
+
+        torch.testing.assert_close(result, expected, rtol=1e-5, atol=1e-5)
+        self.assertExpectedJournal(code)
 
 if __name__ == "__main__":
     unittest.main()