@@ -29,6 +29,17 @@ void checkIR(StmtPtr s, const std::string& pattern) {
29
29
torch::jit::testing::FileCheck ().run (pattern, oss.str ());
30
30
}
31
31
32
// Test helper: prints expression `e` and verifies the printed text against a
// single FileCheck directive built from `pattern`.
//   e       - expression node to print (dereferenced and streamed below).
//   pattern - text expected in the printed expression; it is wrapped into a
//             "# CHECK:" directive here, so callers pass only the bare text.
// NOTE(review): the spaces inside the string literals below (" # CHECK: ",
// " \n ") look like extraction artifacts of this diff view — confirm against
// the real file before relying on the exact bytes.
+ void checkExprIR (ExprPtr e, const std::string& pattern) {
33
// Prepend the CHECK directive and terminate the pattern with a newline.
+ std::string prefixed_pattern = " # CHECK: " + pattern + " \n " ;
34
+ std::ostringstream oss;
35
// Stream the expression itself (not the pointer), followed by a newline.
+ oss << *e << " \n " ;
36
// Run FileCheck with the prefixed pattern against the printed expression.
+ torch::jit::testing::FileCheck ().run (prefixed_pattern, oss.str ());
37
+ }
38
+
39
// Convenience overload: accepts an ExprHandle and forwards its underlying
// node (e.node()) and the unchanged `pattern` to the ExprPtr overload.
+ void checkExprIR (const ExprHandle& e, const std::string& pattern) {
40
+ checkExprIR (e.node (), pattern);
41
+ }
42
+
32
43
TEST (LoopNest, ExprSimple01) {
33
44
KernelScope kernel_scope;
34
45
Tensor* tensor = Compute (
@@ -1305,7 +1316,7 @@ TEST(LoopNest, ScheduleInlineRandomUnrelated) {
1305
1316
# CHECK: for (int m2 = 0; m2 < 4; m2++)
1306
1317
# CHECK: for (int n2 = 0; n2 < 5; n2++)
1307
1318
# CHECK: for (int k2 = 0; k2 < 6; k2++)
1308
- # CHECK: y[m2, n2, k2] = ((n2 * m2) * k2 + (rand())) + (rand());)IR" );
1319
+ # CHECK: y[m2, n2, k2] = ((k2 * m2) * n2 + (rand())) + (rand());)IR" );
1309
1320
}
1310
1321
1311
1322
// Make sure we generate the right number of random values == the dimensionality
@@ -1710,11 +1721,11 @@ TEST(LoopNest, ScheduleInlineOutputTensors) {
1710
1721
# CHECK: for (int m1 = 0; m1 < 4; m1++)
1711
1722
# CHECK: for (int n1 = 0; n1 < 5; n1++)
1712
1723
# CHECK: for (int k1 = 0; k1 < 6; k1++)
1713
- # CHECK: x[m1, n1, k1] = (n1 * m1) * k1 ;
1724
+ # CHECK: x[m1, n1, k1] = (k1 * m1) * n1 ;
1714
1725
# CHECK: for (int m2 = 0; m2 < 4; m2++)
1715
1726
# CHECK: for (int n2 = 0; n2 < 5; n2++)
1716
1727
# CHECK: for (int k2 = 0; k2 < 6; k2++)
1717
- # CHECK: y[m2, n2, k2] = (n2 * m2) * k2 + m2;)IR" );
1728
+ # CHECK: y[m2, n2, k2] = (k2 * m2) * n2 + m2;)IR" );
1718
1729
}
1719
1730
1720
1731
TEST (LoopNest, ScheduleFuserStyle) {
@@ -2130,7 +2141,7 @@ TEST(LoopNest, Reduce2dComputeAt) {
2130
2141
# CHECK: cons[(0 + cy * (1 * W)) + cx * 1] = int(0);
2131
2142
# CHECK: for (int r = 0; r < 2; r++) {
2132
2143
# CHECK: for (int s = 0; s < 2; s++) {
2133
- # CHECK: cons[(0 + cy * (1 * W)) + cx * 1] = (cons[(0 + cy * (1 * W)) + cx * 1]) + (temp[(0 + r * (1 * (W + 1))) + (s + cx ) * 1]);
2144
+ # CHECK: cons[(0 + cy * (1 * W)) + cx * 1] = (cons[(0 + cy * (1 * W)) + cx * 1]) + (temp[(0 + r * (1 * (W + 1))) + (cx + s ) * 1]);
2134
2145
# CHECK: }
2135
2146
# CHECK: }
2136
2147
# CHECK: }
@@ -3225,7 +3236,7 @@ TEST(LoopNest, NormalizeStartVariable) {
3225
3236
{Store::make (a_buf, {x}, Load::make (kInt , b_buf, {x})),
3226
3237
Store::make (b_buf, {x}, x * 2 )});
3227
3238
auto for_stmt = For::make (x, y, 100 , for_body);
3228
- Block::make ({for_stmt});
3239
+ auto parent_block = Block::make ({for_stmt});
3229
3240
3230
3241
LoopNest::normalize (for_stmt);
3231
3242
@@ -3235,8 +3246,8 @@ TEST(LoopNest, NormalizeStartVariable) {
3235
3246
const std::string& expected_ir =
3236
3247
R"IR(
3237
3248
# CHECK: for (int x = 0; x < 100 - y; x++) {
3238
- # CHECK: A[y + x ] = B[y + x ];
3239
- # CHECK: B[y + x ] = 2 * (y + x );
3249
+ # CHECK: A[x + y ] = B[x + y ];
3250
+ # CHECK: B[x + y ] = 2 * (x + y );
3240
3251
)IR" ;
3241
3252
torch::jit::testing::FileCheck ().run (expected_ir, oss.str ());
3242
3253
}
@@ -3304,7 +3315,7 @@ TEST(LoopNest, NormalizeOnNestedInnerLoop) {
3304
3315
R"IR(
3305
3316
# CHECK: for (int x = 50; x < 100; x++) {
3306
3317
# CHECK: for (int y = 0; y < 90; y++) {
3307
- # CHECK: A[x] = (((B[y + 10]) + 2 * y) + (A[x]) ) + 20;
3318
+ # CHECK: A[x] = (((A[x]) + ( B[y + 10])) + 2 * y) + 20;
3308
3319
)IR" ;
3309
3320
torch::jit::testing::FileCheck ().run (expected_ir, oss.str ());
3310
3321
}
@@ -3327,7 +3338,7 @@ TEST(LoopNest, NormalizeAndSplitWithTail) {
3327
3338
BufHandle a_buf (" A" , {ExprHandle (kTotalSize )}, kInt );
3328
3339
VarHandle x (" x" , kInt );
3329
3340
auto for_stmt = For::make (x, 5 , 10 , Store::make (a_buf, {x}, x * 2 ));
3330
- Block::make ({for_stmt});
3341
+ auto parent_block = Block::make ({for_stmt});
3331
3342
3332
3343
LoopNest::normalize (for_stmt);
3333
3344
@@ -3373,7 +3384,7 @@ TEST(LoopNest, FlattenSimpleLoopNest2D) {
3373
3384
auto for_body = Block::make ({Store::make (a_buf, {i, j}, i * j)});
3374
3385
auto inner_for = For::make (j, 0 , 5 , for_body);
3375
3386
auto outer_for = For::make (i, 0 , 10 , inner_for);
3376
- Block::make ({outer_for});
3387
+ auto parent_block = Block::make ({outer_for});
3377
3388
3378
3389
std::vector<ForPtr> loops = {outer_for, inner_for};
3379
3390
ForPtr flattened = nullptr ;
@@ -3420,7 +3431,7 @@ TEST(LoopNest, FlattenSimpleLoopNest3D) {
3420
3431
auto for1 = For::make (k, 0 , 7 , for_body);
3421
3432
auto for2 = For::make (j, 0 , 5 , for1);
3422
3433
auto for3 = For::make (i, 0 , 10 , for2);
3423
- Block::make ({for3});
3434
+ auto parent_block = Block::make ({for3});
3424
3435
3425
3436
std::vector<ForPtr> loops = {for3, for2, for1};
3426
3437
ForPtr flattened = nullptr ;
@@ -3463,7 +3474,7 @@ TEST(LoopNest, FlattenLoopNestAfterNormalize) {
3463
3474
auto for_body = Block::make ({Store::make (a_buf, {i - 2 , j - 3 }, i * j)});
3464
3475
auto inner_for = For::make (j, 3 , 15 , for_body);
3465
3476
auto outer_for = For::make (i, 2 , 10 , inner_for);
3466
- Block::make ({outer_for});
3477
+ auto parent_block = Block::make ({outer_for});
3467
3478
3468
3479
std::vector<ForPtr> loops = {outer_for, inner_for};
3469
3480
ForPtr flattened = nullptr ;
@@ -3712,7 +3723,7 @@ TEST(LoopNest, CacheReadsSimple) {
3712
3723
#CHECK: A_local[j_1] = A[
3713
3724
#CHECK: }
3714
3725
#CHECK: for (int j_2
3715
- #CHECK: B[10 * i_1 + j_2 ] = A_local[j_2];
3726
+ #CHECK: B[j_2 + 10 * i_1] = A_local[j_2];
3716
3727
#CHECK: }
3717
3728
#CHECK: }
3718
3729
#CHECK: for (int i_2
@@ -3769,7 +3780,7 @@ TEST(LoopNest, CacheReadsOuter) {
3769
3780
checkIR (result, R"IR(
3770
3781
#CHECK: Allocate(A_local); // dtype=int, dims=[21, 11]
3771
3782
#CHECK: A_local[j_1 + 11 * i_1] =
3772
- #CHECK: B[10 * i_2 + j_2 ] = (A_local[( j_2 + 11 * i_2) + 12 ]) + (A_local[j_2 + 11 * i_2]);
3783
+ #CHECK: B[j_2 + 10 * i_2] = (A_local[j_2 + 11 * i_2]) + (A_local[( j_2 + 11 * i_2) + 12 ]);
3773
3784
)IR" );
3774
3785
3775
3786
std::vector<int > b_data (200 , 0 );
@@ -3816,7 +3827,7 @@ TEST(LoopNest, CacheReadsInternal) {
3816
3827
checkIR (result, R"IR(
3817
3828
#CHECK: Allocate(A_local); // dtype=int, dims=[2, 11]
3818
3829
#CHECK: A_local[j_1 + 11 * i_2] =
3819
- #CHECK: B[10 * i_1 + j_2 ] = (A_local[j_2 + 12]) + (A_local[j_2]);
3830
+ #CHECK: B[j_2 + 10 * i_1] = (A_local[j_2 + 12]) + (A_local[j_2]);
3820
3831
)IR" );
3821
3832
3822
3833
std::vector<int > b_data (200 , 0 );
@@ -3863,8 +3874,8 @@ TEST(LoopNest, CacheReadsInner) {
3863
3874
3864
3875
checkIR (result, R"IR(
3865
3876
#CHECK: Allocate(A_local); // dtype=int, dims=[5, 2]
3866
- #CHECK: A_local[2 * i_2 + j_2 ] =
3867
- #CHECK: B[10 * i_1 + j_1 ] = (A_local[1]) + (A_local[8]);
3877
+ #CHECK: A_local[j_2 + 2 * i_2] =
3878
+ #CHECK: B[j_1 + 10 * i_1] = (A_local[1]) + (A_local[8]);
3868
3879
)IR" );
3869
3880
3870
3881
std::vector<int > b_data (200 , 0 );
@@ -3914,7 +3925,7 @@ TEST(LoopNest, CacheWritesSimple) {
3914
3925
#CHECK: for (int j = 0; j < 64
3915
3926
#CHECK: A_local[j] = i * j;
3916
3927
#CHECK: for (int j_1 = 0; j_1 < 64
3917
- #CHECK: A[64 * i + j_1 ] = A_local[
3928
+ #CHECK: A[j_1 + 64 * i] = A_local[
3918
3929
#CHECK: Free(A_local);
3919
3930
#CHECK-NOT: A_local
3920
3931
)IR" );
0 commit comments