Commit 2ad0bf3 (1 parent: bbc2be4)

Commit message:
test
fix test

File tree: 1 file changed, +124 −0 lines


test/test_indexing.py (124 additions, 0 deletions)
@@ -988,6 +988,130 @@ def kernel(
        torch.testing.assert_close(src_result, expected_src)
        torch.testing.assert_close(dst_result, expected_dst)
    def test_indirect_indexing_2d(self):
        @helion.kernel()
        def test(
            col: torch.Tensor,  # [M, K] int64
            val: torch.Tensor,  # [M, K] fp32
            B: torch.Tensor,  # [K, N] fp32
        ) -> torch.Tensor:  # [M, N] fp32
            M, K = col.shape
            _, N = B.shape
            out_dtype = torch.promote_types(val.dtype, B.dtype)
            C = torch.empty((M, N), dtype=out_dtype, device=B.device)
            B_flat = B.reshape(-1)  # [K*N]

            for tile_m, tile_n in hl.tile([M, N]):
                # [tile_m, tile_n]
                acc = hl.zeros([tile_m, tile_n], dtype=torch.float32)

                for tile_k in hl.tile(K):
                    # [tile_m, tile_k]
                    cols_2d = col[tile_m, tile_k]
                    # [tile_m, tile_k, tile_n]
                    B_slice = hl.load(
                        B_flat,
                        [(cols_2d * N)[:, :, None] + tile_n.index[None, None, :]],
                    )
                    # [tile_m, tile_k]
                    vals_2d = val[tile_m, tile_k]
                    # [tile_m, tile_k, tile_n]
                    contrib = vals_2d[:, :, None] * B_slice
                    # [tile_m, tile_n]
                    contrib = contrib.sum(dim=1)
                    # [tile_m, tile_n]
                    acc = acc + contrib

                C[tile_m, tile_n] = acc.to(out_dtype)

            return C
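The kernel above gathers from B through a manually flattened view, relying on the row-major identity B.reshape(-1)[c * N + j] == B[c, j]. A minimal standalone check of that identity, in plain PyTorch with arbitrarily chosen sizes (not part of the diff):

    import torch

    B2 = torch.rand(4, 6)  # any [K, N] tensor
    c, j = 2, 5
    # row-major flattening: element (c, j) lives at flat offset c * N + j
    assert B2.reshape(-1)[c * B2.shape[1] + j] == B2[c, j]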
        M, K, N = 32, 16, 24
        col = torch.randint(0, K, (M, K), device=DEVICE, dtype=torch.int64)
        val = torch.rand((M, K), device=DEVICE, dtype=torch.float32)
        B = torch.rand((K, N), device=DEVICE, dtype=torch.float32)

        code, result = code_and_output(
            test,
            (col, val, B),
            block_size=[8, 8, 4],
        )

        # For each output position (i, j), compute sum over k: val[i, k] * B[col[i, k], j]
        expected = torch.zeros((M, N), device=DEVICE, dtype=torch.float32)
        for i in range(M):
            for j in range(N):
                for k in range(K):
                    expected[i, j] += val[i, k] * B[col[i, k], j]

        torch.testing.assert_close(result, expected, rtol=1e-5, atol=1e-5)
        self.assertExpectedJournal(code)
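The triple nested loop spells out the semantics one element at a time; the same reference can be computed as a single gather-and-reduce. An illustrative vectorized equivalent, assuming the shapes above (not part of the test):

    # B[col] gathers rows of B: [M, K] indices into [K, N] -> [M, K, N]
    expected_vec = (val[:, :, None] * B[col]).sum(dim=1)  # [M, N]
    torch.testing.assert_close(expected_vec, expected)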
    def test_indirect_indexing_3d(self):
        @helion.kernel()
        def test(
            col: torch.Tensor,  # [M, N, K] int64 - indices for first dimension of B
            val: torch.Tensor,  # [M, N, K] fp32 - values to multiply
            B: torch.Tensor,  # [K, P, Q] fp32 - tensor to index into
        ) -> torch.Tensor:  # [M, N, P, Q] fp32
            M, N, K = col.shape
            _, P, Q = B.shape
            out_dtype = torch.promote_types(val.dtype, B.dtype)
            C = torch.empty((M, N, P, Q), dtype=out_dtype, device=B.device)

            for tile_m, tile_n, tile_p, tile_q in hl.tile([M, N, P, Q]):
                # [tile_m, tile_n, tile_p, tile_q]
                acc = hl.zeros([tile_m, tile_n, tile_p, tile_q], dtype=torch.float32)

                for tile_k in hl.tile(K):
                    # [tile_m, tile_n, tile_k]
                    cols_3d = col[tile_m, tile_n, tile_k]

                    # [tile_m, tile_n, tile_k, tile_p, tile_q]
                    # Direct indexing into B using gather
                    B_slice = B[
                        cols_3d[:, :, :, None, None],
                        tile_p.index[None, None, :, None],
                        tile_q.index[None, None, None, :],
                    ]

                    # [tile_m, tile_n, tile_k]
                    vals_3d = val[tile_m, tile_n, tile_k]

                    # [tile_m, tile_n, tile_k, tile_p, tile_q]
                    contrib = vals_3d[:, :, :, None, None] * B_slice

                    # [tile_m, tile_n, tile_p, tile_q] - sum over k dimension
                    contrib = contrib.sum(dim=2)

                    # [tile_m, tile_n, tile_p, tile_q]
                    acc = acc + contrib

                C[tile_m, tile_n, tile_p, tile_q] = acc.to(out_dtype)
            return C
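The gather in this kernel leans on advanced-indexing broadcast rules: the three index tensors broadcast against one another, and the result takes the broadcast shape. A small plain-PyTorch shape check of the same pattern, with hypothetical sizes (not part of the diff):

    cols = torch.randint(0, 8, (2, 3, 4))         # [m, n, k]
    Bt = torch.rand(8, 5, 6)                      # [K, P, Q]
    out = Bt[
        cols[:, :, :, None, None],                # [m, n, k, 1, 1]
        torch.arange(5)[None, None, :, None],     # broadcasts into the P axis
        torch.arange(6)[None, None, None, :],     # broadcasts into the Q axis
    ]
    assert out.shape == (2, 3, 4, 5, 6)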
        M, N, K, P, Q = 16, 12, 8, 10, 14
        col = torch.randint(0, K, (M, N, K), device=DEVICE, dtype=torch.int64)
        val = torch.rand((M, N, K), device=DEVICE, dtype=torch.float32)
        B = torch.rand((K, P, Q), device=DEVICE, dtype=torch.float32)

        code, result = code_and_output(
            test,
            (col, val, B),
            block_size=[4, 4, 4, 4, 4],  # 5D tiling for M, N, P, Q, K
        )

        # For each output position (i, j, p, q), compute sum over k:
        # val[i, j, k] * B[col[i, j, k], p, q]
        expected = torch.zeros((M, N, P, Q), device=DEVICE, dtype=torch.float32)
        for i in range(M):
            for j in range(N):
                for p in range(P):
                    for q in range(Q):
                        for k in range(K):
                            expected[i, j, p, q] += val[i, j, k] * B[col[i, j, k], p, q]

        torch.testing.assert_close(result, expected, rtol=1e-5, atol=1e-5)
        self.assertExpectedJournal(code)
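As with the 2D case, the five nested loops have a short vectorized counterpart; an illustrative sketch assuming the shapes above (not part of the test):

    # B[col] gathers along B's first dim: [M, N, K] indices into [K, P, Q] -> [M, N, K, P, Q]
    expected_vec = (val[:, :, :, None, None] * B[col]).sum(dim=2)  # [M, N, P, Q]
    torch.testing.assert_close(expected_vec, expected)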

if __name__ == "__main__":
    unittest.main()
