Commit 7fdba45

Mikhail Zolotukhin authored and facebook-github-bot committed
[TensorExpr] IRSimplifier: sort terms in polynomials, terms, minterms, maxterms. (pytorch#63197)
Summary:
Pull Request resolved: pytorch#63197

This solves non-determinism from using hash values in sort methods. Changes in tests are mostly mechanical.

Test Plan: Imported from OSS

Reviewed By: navahgar

Differential Revision: D30292776

Pulled By: ZolotukhinM

fbshipit-source-id: 74f57b53c3afc9d4be45715fd74781271373e055
1 parent 8bdd542 commit 7fdba45
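For context: a sort keyed on hash values is unstable across runs, because hashes can be derived from pointer addresses that change with every execution, so the simplifier could print the same polynomial with its terms in different orders. Below is a minimal sketch of the idea behind the fix, using hypothetical names rather than the actual TensorExpr classes; the real comparison logic in this PR may differ.

#include <algorithm>
#include <string>
#include <vector>

// Hypothetical stand-in for one term of a polynomial (coef * var);
// the real TensorExpr Term/Polynomial classes are more involved.
struct Term {
  int coef;
  std::string var;
};

// Deterministic ordering: compare by stable, content-based keys
// (variable name, then coefficient) instead of by hash values, so the
// simplifier emits terms in the same order on every run.
void sortTerms(std::vector<Term>& terms) {
  std::sort(terms.begin(), terms.end(), [](const Term& a, const Term& b) {
    return a.var != b.var ? a.var < b.var : a.coef < b.coef;
  });
}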

File tree

8 files changed (+288, -329 lines)


test/cpp/tensorexpr/test_cuda.cpp

Lines changed: 10 additions & 10 deletions
@@ -1575,10 +1575,10 @@ TEST(Cuda, MaskMultiDim_CUDA) {
   const std::string& verification_pattern =
       R"IR(
 # CHECK-NOT: if (
-# CHECK: C[100 * blockIdx.x + threadIdx.x] =
+# CHECK: C[threadIdx.x + 100 * blockIdx.x] =
 # CHECK: __syncthreads();
 # CHECK: if (threadIdx.x<50
-# CHECK: D[50 * blockIdx.x + threadIdx.x] =)IR";
+# CHECK: D[threadIdx.x + 50 * blockIdx.x] =)IR";
 
   torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
 
@@ -1705,10 +1705,10 @@ TEST(Cuda, MaskMultiDimSymbolic_CUDA) {
   const std::string& verification_pattern =
       R"IR(
 # CHECK: if (threadIdx.x<A_SIZE
-# CHECK: C[threadIdx.x + A_SIZE * blockIdx.x] =
+# CHECK: C[A_SIZE * blockIdx.x + threadIdx.x] =
 # CHECK: __syncthreads();
 # CHECK: if (threadIdx.x<B_SIZE
-# CHECK: D[threadIdx.x + B_SIZE * blockIdx.x] =)IR";
+# CHECK: D[B_SIZE * blockIdx.x + threadIdx.x] =)IR";
 
   torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
 
@@ -1852,10 +1852,10 @@ TEST(Cuda, MaskCompoundInnerLoop_CUDA) {
   const std::string& verification_pattern =
       R"IR(
 # CHECK-NOT: if (
-# CHECK: c[100 * blockIdx.x + threadIdx.x] =
+# CHECK: c[threadIdx.x + 100 * blockIdx.x] =
 # CHECK: __syncthreads();
 # CHECK: if (threadIdx.x<50
-# CHECK: d[50 * blockIdx.x + threadIdx.x] =)IR";
+# CHECK: d[threadIdx.x + 50 * blockIdx.x] =)IR";
 
   torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
 
@@ -1991,10 +1991,10 @@ TEST(Cuda, MaskInnerLoopOneBlock_CUDA) {
       R"IR(
 # CHECK: for (int i = 0; i < 10
 # CHECK-NOT: if (
-# CHECK: c[100 * i + threadIdx.x] =
+# CHECK: c[threadIdx.x + 100 * i] =
 # CHECK: __syncthreads();
 # CHECK: if (threadIdx.x<50
-# CHECK: d[50 * i + threadIdx.x] =)IR";
+# CHECK: d[threadIdx.x + 50 * i] =)IR";
 
   torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
 
@@ -2119,7 +2119,7 @@ TEST(Cuda, MaskMultiDimMultiAxis_CUDA) {
   const std::string& verification_pattern =
       R"IR(
 # CHECK: if (threadIdx.y<1
-# CHECK: C[30 * blockIdx.x + threadIdx.x] =
+# CHECK: C[threadIdx.x + 30 * blockIdx.x] =
 # CHECK: __syncthreads();
 # CHECK: if (threadIdx.x<1
 # CHECK: D[threadIdx.y + 15 * blockIdx.x] =)IR";
@@ -2250,7 +2250,7 @@ TEST(Cuda, MaskMultiDimMultiLevel_CUDA) {
   const std::string& verification_pattern =
       R"IR(
 # CHECK-NOT: if (
-# CHECK: C[30 * blockIdx.x + threadIdx.x] =
+# CHECK: C[threadIdx.x + 30 * blockIdx.x] =
 # CHECK: __syncthreads();
 # CHECK: if (blockIdx.x<5
 # CHECK: if (threadIdx.x<15
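Every change in this file is the same mechanical update: the index expressions are algebraically identical, and only the order in which the simplifier now prints their terms has changed. As an aside, a test that wanted to accept either order could use FileCheck's {{...}} regex syntax; a hypothetical pattern (not used in this change) would look like:

# CHECK: C[{{threadIdx\.x \+ 100 \* blockIdx\.x|100 \* blockIdx\.x \+ threadIdx\.x}}] =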

test/cpp/tensorexpr/test_loopnest.cpp

Lines changed: 29 additions & 18 deletions
@@ -29,6 +29,17 @@ void checkIR(StmtPtr s, const std::string& pattern) {
   torch::jit::testing::FileCheck().run(pattern, oss.str());
 }
 
+void checkExprIR(ExprPtr e, const std::string& pattern) {
+  std::string prefixed_pattern = "# CHECK: " + pattern + "\n";
+  std::ostringstream oss;
+  oss << *e << "\n";
+  torch::jit::testing::FileCheck().run(prefixed_pattern, oss.str());
+}
+
+void checkExprIR(const ExprHandle& e, const std::string& pattern) {
+  checkExprIR(e.node(), pattern);
+}
+
 TEST(LoopNest, ExprSimple01) {
   KernelScope kernel_scope;
   Tensor* tensor = Compute(
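The new checkExprIR overloads mirror checkIR but wrap a single expression, prepending the "# CHECK: " prefix automatically. A hypothetical usage sketch (not part of this diff), reusing the VarHandle/ExprHandle arithmetic that appears elsewhere in these tests:

TEST(LoopNest, CheckExprIRUsageSketch) {
  KernelScope kernel_scope;
  VarHandle x("x", kInt);
  VarHandle y("y", kInt);
  // The pattern below is illustrative; the exact textual form depends on
  // how the IR printer renders the expression.
  checkExprIR(x * 2 + y, "x * 2 + y");
}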
@@ -1305,7 +1316,7 @@ TEST(LoopNest, ScheduleInlineRandomUnrelated) {
 # CHECK: for (int m2 = 0; m2 < 4; m2++)
 # CHECK: for (int n2 = 0; n2 < 5; n2++)
 # CHECK: for (int k2 = 0; k2 < 6; k2++)
-# CHECK: y[m2, n2, k2] = ((n2 * m2) * k2 + (rand())) + (rand());)IR");
+# CHECK: y[m2, n2, k2] = ((k2 * m2) * n2 + (rand())) + (rand());)IR");
 }
 
 // Make sure we generate the right number of random values == the dimensionality
@@ -1710,11 +1721,11 @@ TEST(LoopNest, ScheduleInlineOutputTensors) {
 # CHECK: for (int m1 = 0; m1 < 4; m1++)
 # CHECK: for (int n1 = 0; n1 < 5; n1++)
 # CHECK: for (int k1 = 0; k1 < 6; k1++)
-# CHECK: x[m1, n1, k1] = (n1 * m1) * k1;
+# CHECK: x[m1, n1, k1] = (k1 * m1) * n1;
 # CHECK: for (int m2 = 0; m2 < 4; m2++)
 # CHECK: for (int n2 = 0; n2 < 5; n2++)
 # CHECK: for (int k2 = 0; k2 < 6; k2++)
-# CHECK: y[m2, n2, k2] = (n2 * m2) * k2 + m2;)IR");
+# CHECK: y[m2, n2, k2] = (k2 * m2) * n2 + m2;)IR");
 }
 
 TEST(LoopNest, ScheduleFuserStyle) {
@@ -2130,7 +2141,7 @@ TEST(LoopNest, Reduce2dComputeAt) {
 # CHECK: cons[(0 + cy * (1 * W)) + cx * 1] = int(0);
 # CHECK: for (int r = 0; r < 2; r++) {
 # CHECK: for (int s = 0; s < 2; s++) {
-# CHECK: cons[(0 + cy * (1 * W)) + cx * 1] = (cons[(0 + cy * (1 * W)) + cx * 1]) + (temp[(0 + r * (1 * (W + 1))) + (s + cx) * 1]);
+# CHECK: cons[(0 + cy * (1 * W)) + cx * 1] = (cons[(0 + cy * (1 * W)) + cx * 1]) + (temp[(0 + r * (1 * (W + 1))) + (cx + s) * 1]);
 # CHECK: }
 # CHECK: }
 # CHECK: }
@@ -3225,7 +3236,7 @@ TEST(LoopNest, NormalizeStartVariable) {
       {Store::make(a_buf, {x}, Load::make(kInt, b_buf, {x})),
        Store::make(b_buf, {x}, x * 2)});
   auto for_stmt = For::make(x, y, 100, for_body);
-  Block::make({for_stmt});
+  auto parent_block = Block::make({for_stmt});
 
   LoopNest::normalize(for_stmt);
 
@@ -3235,8 +3246,8 @@ TEST(LoopNest, NormalizeStartVariable) {
   const std::string& expected_ir =
       R"IR(
 # CHECK: for (int x = 0; x < 100 - y; x++) {
-# CHECK: A[y + x] = B[y + x];
-# CHECK: B[y + x] = 2 * (y + x);
+# CHECK: A[x + y] = B[x + y];
+# CHECK: B[x + y] = 2 * (x + y);
 )IR";
   torch::jit::testing::FileCheck().run(expected_ir, oss.str());
 }
@@ -3304,7 +3315,7 @@ TEST(LoopNest, NormalizeOnNestedInnerLoop) {
       R"IR(
 # CHECK: for (int x = 50; x < 100; x++) {
 # CHECK: for (int y = 0; y < 90; y++) {
-# CHECK: A[x] = (((B[y + 10]) + 2 * y) + (A[x])) + 20;
+# CHECK: A[x] = (((A[x]) + (B[y + 10])) + 2 * y) + 20;
 )IR";
   torch::jit::testing::FileCheck().run(expected_ir, oss.str());
 }
@@ -3327,7 +3338,7 @@ TEST(LoopNest, NormalizeAndSplitWithTail) {
   BufHandle a_buf("A", {ExprHandle(kTotalSize)}, kInt);
   VarHandle x("x", kInt);
   auto for_stmt = For::make(x, 5, 10, Store::make(a_buf, {x}, x * 2));
-  Block::make({for_stmt});
+  auto parent_block = Block::make({for_stmt});
 
   LoopNest::normalize(for_stmt);
 
@@ -3373,7 +3384,7 @@ TEST(LoopNest, FlattenSimpleLoopNest2D) {
   auto for_body = Block::make({Store::make(a_buf, {i, j}, i * j)});
   auto inner_for = For::make(j, 0, 5, for_body);
   auto outer_for = For::make(i, 0, 10, inner_for);
-  Block::make({outer_for});
+  auto parent_block = Block::make({outer_for});
 
   std::vector<ForPtr> loops = {outer_for, inner_for};
   ForPtr flattened = nullptr;
@@ -3420,7 +3431,7 @@ TEST(LoopNest, FlattenSimpleLoopNest3D) {
   auto for1 = For::make(k, 0, 7, for_body);
   auto for2 = For::make(j, 0, 5, for1);
   auto for3 = For::make(i, 0, 10, for2);
-  Block::make({for3});
+  auto parent_block = Block::make({for3});
 
   std::vector<ForPtr> loops = {for3, for2, for1};
   ForPtr flattened = nullptr;
@@ -3463,7 +3474,7 @@ TEST(LoopNest, FlattenLoopNestAfterNormalize) {
   auto for_body = Block::make({Store::make(a_buf, {i - 2, j - 3}, i * j)});
   auto inner_for = For::make(j, 3, 15, for_body);
   auto outer_for = For::make(i, 2, 10, inner_for);
-  Block::make({outer_for});
+  auto parent_block = Block::make({outer_for});
 
   std::vector<ForPtr> loops = {outer_for, inner_for};
   ForPtr flattened = nullptr;
@@ -3712,7 +3723,7 @@ TEST(LoopNest, CacheReadsSimple) {
 #CHECK: A_local[j_1] = A[
 #CHECK: }
 #CHECK: for (int j_2
-#CHECK: B[10 * i_1 + j_2] = A_local[j_2];
+#CHECK: B[j_2 + 10 * i_1] = A_local[j_2];
 #CHECK: }
 #CHECK: }
 #CHECK: for (int i_2
@@ -3769,7 +3780,7 @@ TEST(LoopNest, CacheReadsOuter) {
   checkIR(result, R"IR(
 #CHECK: Allocate(A_local); // dtype=int, dims=[21, 11]
 #CHECK: A_local[j_1 + 11 * i_1] =
-#CHECK: B[10 * i_2 + j_2] = (A_local[(j_2 + 11 * i_2) + 12]) + (A_local[j_2 + 11 * i_2]);
+#CHECK: B[j_2 + 10 * i_2] = (A_local[j_2 + 11 * i_2]) + (A_local[(j_2 + 11 * i_2) + 12]);
 )IR");
 
   std::vector<int> b_data(200, 0);
@@ -3816,7 +3827,7 @@ TEST(LoopNest, CacheReadsInternal) {
   checkIR(result, R"IR(
 #CHECK: Allocate(A_local); // dtype=int, dims=[2, 11]
 #CHECK: A_local[j_1 + 11 * i_2] =
-#CHECK: B[10 * i_1 + j_2] = (A_local[j_2 + 12]) + (A_local[j_2]);
+#CHECK: B[j_2 + 10 * i_1] = (A_local[j_2 + 12]) + (A_local[j_2]);
 )IR");
 
   std::vector<int> b_data(200, 0);
@@ -3863,8 +3874,8 @@ TEST(LoopNest, CacheReadsInner) {
 
   checkIR(result, R"IR(
 #CHECK: Allocate(A_local); // dtype=int, dims=[5, 2]
-#CHECK: A_local[2 * i_2 + j_2] =
-#CHECK: B[10 * i_1 + j_1] = (A_local[1]) + (A_local[8]);
+#CHECK: A_local[j_2 + 2 * i_2] =
+#CHECK: B[j_1 + 10 * i_1] = (A_local[1]) + (A_local[8]);
 )IR");
 
   std::vector<int> b_data(200, 0);
@@ -3914,7 +3925,7 @@ TEST(LoopNest, CacheWritesSimple) {
 #CHECK: for (int j = 0; j < 64
 #CHECK: A_local[j] = i * j;
 #CHECK: for (int j_1 = 0; j_1 < 64
-#CHECK: A[64 * i + j_1] = A_local[
+#CHECK: A[j_1 + 64 * i] = A_local[
 #CHECK: Free(A_local);
 #CHECK-NOT: A_local
 )IR");

test/cpp/tensorexpr/test_reductions.cpp

Lines changed: 4 additions & 4 deletions
@@ -1578,8 +1578,8 @@ TEST(Reductions, ReductionCacheBodyAccess) {
 #CHECK: Allocate(scale_local); // dtype=float, dims=[1, 32, 12]
 #CHECK: for (int j = 0; j < 32; j++) {
 #CHECK: for (int k = 0; k < 12; k++) {
-#CHECK: scale_local[k + 12 * j] = scale[(k + 384 * l1) + 12 * j];
-#CHECK: sum[l1] = (sum[l1]) + (scale_local[12 * n1_1 + m1_1]);
+#CHECK: scale_local[k + 12 * j] = scale[(k + 12 * j) + 384 * l1];
+#CHECK: sum[l1] = (sum[l1]) + (scale_local[m1_1 + 12 * n1_1]);
 #CHECK: scale_1[l] = (b[l]) * (sum[l]);
 #CHECK: Free(scale_local);
 )IR";
@@ -1667,7 +1667,7 @@ TEST(Reductions, ReductionSplitCacheConsumerAccess) {
   const std::string& expected_ir =
       R"IR(
 #CHECK: Allocate(sum_local); // dtype=float, dims=[4]
-#CHECK: sum[l1_inner + 4 * l1_outer] = (sum[l1_inner + 4 * l1_outer]) + (scale[((12 * n1_1 + 384 * l1_inner) + m1_1) + 1536 * l1_outer]);
+#CHECK: sum[l1_inner + 4 * l1_outer] = (sum[l1_inner + 4 * l1_outer]) + (scale[((m1_1 + 12 * n1_1) + 1536 * l1_outer) + 384 * l1_inner]);
 #CHECK: for (int i = 0; i < 4
 #CHECK: sum_local[i] = sum[i + 4 * l_outer];
 #CHECK: scale_1[l_inner + 4 * l_outer] = (b[l_inner + 4 * l_outer]) * (sum_local[l_inner]);
@@ -1716,7 +1716,7 @@ TEST(Reductions, ReductionReorderCacheConsumerAccess) {
   const std::string& expected_ir =
       R"IR(
 #CHECK: Allocate(sum_local); // dtype=float, dims=[4]
-#CHECK: sum[l1] = (sum[l1]) + (scale[(12 * n1_1 + m1_1) + 384 * l1]);
+#CHECK: sum[l1] = (sum[l1]) + (scale[(m1_1 + 12 * n1_1) + 384 * l1]);
 #CHECK: for (int i = 0; i < 4
 #CHECK: sum_local[i] = sum[i + 4 * l_outer];
 #CHECK: scale_1[l_inner + 4 * l_outer] = (b[l_inner + 4 * l_outer]) * (sum_local[l_inner]);
