
Commit 80659b7

Gamrix authored and facebook-github-bot committed
Hoisting common expressions out of If blocks [retry] (pytorch#65645)
Summary: Pull Request resolved: pytorch#65645

This is a retry of PR pytorch#59492. Latest changes: added more tests, added the getOrCreateDB pattern, updated the logic to remove unnecessary checks, and addressed all review comments.

This adds a pass that finds expressions common to the two subblocks of an if operation and hoists them before the if block. Hoisting in turn lets Dead Code Elimination remove some if blocks entirely, once both branches are left empty.

Test Plan: python test_jit.py TestIfHoisting

Reviewed By: eellison

Differential Revision: D33302065

Pulled By: Gamrix

fbshipit-source-id: a5a184a480cf07354359aaca344c6e27b687a3c2
1 parent 569aeec commit 80659b7

12 files changed: +419, -105 lines
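
To make the pass concrete before the diffs, here is a minimal sketch of driving it by hand from Python, assuming a build that includes this commit. It uses the _jit_pass_common_expression_hoisting binding declared in torch/_C/__init__.pyi.in below, together with the pre-existing _jit_pass_dce binding:

import torch

def fn(x: bool, y: int):
    if x:
        z = y + 3
    else:
        z = y + 3
    return z

scripted = torch.jit.script(fn)

# Hoist the common "y + 3" out of both branches, then let DCE
# delete the now-empty prim::If node.
torch._C._jit_pass_common_expression_hoisting(scripted.graph)
torch._C._jit_pass_dce(scripted.graph)

print(scripted.graph)  # expect a single aten::add and no prim::If
assert scripted(True, 1) == fn(True, 1)

The new test file below exercises the same behavior through JitTestCase.run_pass.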

test/jit/test_if_hoisting.py

Lines changed: 213 additions & 0 deletions
@@ -0,0 +1,213 @@
import torch
from torch.testing import FileCheck
from torch.testing._internal.jit_utils import JitTestCase

if __name__ == "__main__":
    raise RuntimeError(
        "This test file is not meant to be run directly, use:\n\n"
        "\tpython test/test_jit.py TESTNAME\n\n"
        "instead."
    )


class TestIfHoisting(JitTestCase):
    def test_if_hoist_basic(self):
        """
        The simplest case: the common expression is the only
        statement in each branch, so the if block is removed entirely.
        """
        def fn(x: bool, y: int):
            if x:
                z = y + 3
            else:
                z = y + 3
            return z

        fn_script = torch.jit.script(fn)
        op_graph = fn_script.graph
        self.run_pass("common_expression_hoisting", op_graph)
        self.run_pass("dce", op_graph)
        FileCheck().check_count("prim::If", 0, exactly=True).run(op_graph)
        FileCheck().check_count("aten::add", 1, exactly=True).run(op_graph)
        self.assertEqual(fn(True, 1), fn_script(True, 1))

    def test_if_hoist_transposed_expr(self):
        """
        Making sure that we can properly eliminate
        an expression even if it is not at the start
        of a block
        """
        def fn(x: bool, y: int):
            if x:
                a = y + 3
                b = y * 2
            else:
                b = y * 2
                a = y + 3
            return a, b

        fn_script = torch.jit.script(fn)
        op_graph = fn_script.graph
        self.run_pass("common_expression_hoisting", op_graph)
        self.run_pass("dce", op_graph)

        FileCheck().check_count("prim::If", 0, exactly=True).run(op_graph)
        FileCheck().check_count("aten::add", 1, exactly=True).run(op_graph)

        self.assertEqual(fn(True, 1), fn_script(True, 1))
        self.assertEqual(fn(False, 5), fn_script(False, 5))

    def test_if_hoist_swapped_expr(self):
        """
        Making sure that the if statement doesn't get fully
        eliminated here: the same expressions are assigned to
        different variables in each branch.
        """
        def fn(x: bool, y: int):
            if x:
                a = y + 3
                b = y * 2
            else:
                a = y * 2
                b = y + 3
            return a, b

        fn_script = torch.jit.script(fn)
        op_graph = fn_script.graph
        self.run_pass("common_expression_hoisting", op_graph)
        self.run_pass("dce", op_graph)

        FileCheck().check_count("prim::If", 1, exactly=True).run(op_graph)
        FileCheck().check_count("aten::add", 1, exactly=True).run(op_graph)

        self.assertEqual(fn(True, 1), fn_script(True, 1))
        self.assertEqual(fn(False, 5), fn_script(False, 5))

    def test_if_hoist_reused_var(self):
        """
        Making sure that cases where a Python variable
        is reused are handled correctly
        """
        def fn(x: bool, y: int):
            b = 6
            if x:
                a = y + 3
                a = y * 2
            else:
                a = y * 2
                b = y + 3
            return a, b

        fn_script = torch.jit.script(fn)
        op_graph = fn_script.graph
        self.run_pass("common_expression_hoisting", op_graph)
        self.run_pass("dce", op_graph)

        FileCheck().check_count("prim::If", 1, exactly=True).run(op_graph)
        FileCheck().check_count("aten::add", 1, exactly=True).run(op_graph)
        FileCheck().check_count("aten::mul", 1, exactly=True).run(op_graph)

        self.assertEqual(fn(True, 1), fn_script(True, 1))
        self.assertEqual(fn(False, 5), fn_script(False, 5))

    def test_no_hoist(self):
        """
        Nothing should happen here; the expressions are different
        """
        def fn(x: bool, y: int, z: int):
            if x:
                a = y + 3
            else:
                a = z + 3
            return a

        fn_script = torch.jit.script(fn)
        op_graph = fn_script.graph
        self.run_pass("common_expression_hoisting", op_graph)
        self.run_pass("dce", op_graph)

        FileCheck().check_count("prim::If", 1, exactly=True).run(op_graph)
        FileCheck().check_count("aten::add", 2, exactly=True).run(op_graph)

        self.assertEqual(fn(True, 1, 3), fn_script(True, 1, 3))
        self.assertEqual(fn(False, 5, 10), fn_script(False, 5, 10))

    def test_mutate_before(self):
        """
        Make sure that if there is a mutation before the common
        op, the hoist doesn't happen
        """
        def fn(x: bool, y: torch.Tensor):
            if x:
                y.add_(8)
                a = y + 3
            else:
                a = y + 3
            return a

        fn_script = torch.jit.script(fn)
        op_graph = fn_script.graph
        self.run_pass("common_expression_hoisting", op_graph)
        self.run_pass("dce", op_graph)

        FileCheck().check_count("prim::If", 1, exactly=True).run(op_graph)
        FileCheck().check_count("aten::add", 2, exactly=True).run(op_graph)
        FileCheck().check_count("aten::add_", 1, exactly=True).run(op_graph)

        t1 = torch.Tensor([1])
        t2 = torch.Tensor([5, 6])
        # Clone the inputs: the True path mutates its argument, so the
        # eager and scripted calls must each get a fresh tensor.
        self.assertEqual(fn(True, t1.clone()), fn_script(True, t1.clone()))
        self.assertEqual(fn(False, t2.clone()), fn_script(False, t2.clone()))

    def test_mutate_after(self):
        """
        Check that the hoist can happen properly, and
        that the output is still correct.
        """
        def fn(x: bool, y: torch.Tensor):
            if x:
                b = 1
                a = y + 3
                y.add_(8)
            else:
                b = 2
                a = y + 3
            c = b + a
            return c

        fn_script = torch.jit.script(fn)
        op_graph = fn_script.graph
        self.run_pass("common_expression_hoisting", op_graph)
        self.run_pass("dce", op_graph)

        FileCheck().check_count("prim::If", 1, exactly=True).run(op_graph)
        FileCheck().check_count("aten::add", 2, exactly=True).run(op_graph)

        t1 = torch.Tensor([1])
        t2 = torch.Tensor([5, 6])
        self.assertEqual(fn(True, t1.clone()), fn_script(True, t1.clone()))
        self.assertEqual(fn(False, t2.clone()), fn_script(False, t2.clone()))

    def test_multiple_hoists(self):
        """
        Test that hoists that depend on other hoists are done correctly
        """
        def fn(x: bool, y: torch.Tensor):
            if x:
                a = y + 3
                b = a + y
            else:
                a = y + 3
                b = a + y
            c = b * 2
            return c

        fn_script = torch.jit.script(fn)
        op_graph = fn_script.graph
        self.run_pass("common_expression_hoisting", op_graph)
        self.run_pass("dce", op_graph)

        FileCheck().check_count("prim::If", 0, exactly=True).run(op_graph)
        FileCheck().check_count("aten::add", 2, exactly=True).run(op_graph)

        t1 = torch.Tensor([1])
        t2 = torch.Tensor([5, 6])
        self.assertEqual(fn(True, t1), fn_script(True, t1))
        self.assertEqual(fn(False, t2), fn_script(False, t2))

test/quantization/jit/test_quantize_jit.py

Lines changed: 2 additions & 1 deletion
@@ -1256,6 +1256,7 @@ class Res(torch.nn.Module):
     def __init__(self):
         super(Res, self).__init__()
         self.conv = torch.nn.Conv2d(3, 3, 1).float()
+        self.conv2 = torch.nn.Conv2d(3, 3, 1).float()
         self.use_skip = True
 
     def forward(self, x: torch.Tensor, cond: bool) -> torch.Tensor:
@@ -1264,7 +1265,7 @@ def forward(self, x: torch.Tensor, cond: bool) -> torch.Tensor:
         if self.use_skip:
             return self.conv(x)
         else:
-            return self.conv(x)
+            return self.conv2(x)
 
 class M(torch.nn.Module):
     def __init__(self):

(The second branch now calls a separate conv2: under the new hoisting pass, two identical self.conv(x) calls would be hoisted out and the If collapsed, which this quantization test needs to avoid.)

test/test_jit.py

Lines changed: 1 addition & 0 deletions
@@ -25,6 +25,7 @@
 from jit.test_builtins import TestBuiltins, TestTensorBuiltins # noqa: F401
 from jit.test_ignore_context_manager import TestIgnoreContextManager # noqa: F401
 from jit.test_symbolic_shape_analysis import TestSymbolicShapeAnalysis # noqa: F401
+from jit.test_if_hoisting import TestIfHoisting # noqa: F401
 from jit.test_unsupported_ops import TestUnsupportedOps # noqa: F401
 from jit.test_freezing import TestFreezing, TestFrozenOptimizations, TestMKLDNNReinplacing # noqa: F401
 from jit.test_peephole import TestPeephole # noqa: F401

tools/build_variables.bzl

Lines changed: 1 addition & 0 deletions
@@ -219,6 +219,7 @@ core_sources_full_mobile_no_backend_interface = [
     "torch/csrc/jit/passes/clear_profiling.cpp",
     "torch/csrc/jit/passes/clear_undefinedness.cpp",
     "torch/csrc/jit/passes/common_subexpression_elimination.cpp",
+    "torch/csrc/jit/passes/common_expression_hoisting.cpp",
     "torch/csrc/jit/passes/concat_opt.cpp",
     "torch/csrc/jit/passes/constant_pooling.cpp",
     "torch/csrc/jit/passes/constant_propagation.cpp",

torch/_C/__init__.pyi.in

Lines changed: 1 addition & 0 deletions
@@ -211,6 +211,7 @@ def _jit_pass_inline(Graph) -> None: ...
 def _jit_pass_constant_propagation(Graph) -> None: ...
 def _jit_pass_propagate_shapes_on_graph(Graph) -> None: ...
 def _jit_erase_non_input_shape_information(Graph) -> None: ...
+def _jit_pass_common_expression_hoisting(Graph) -> None: ...
 def _jit_get_schemas_for_operator(name :str) -> List[FunctionSchema]: ...
 def _jit_check_alias_annotation(g: Graph, args: Tuple[Any, ...], unqualified_op_name: str): ...
 def _jit_can_fuse_on_cpu() -> _bool: ...

torch/csrc/jit/ir/node_hashing.cpp

Lines changed: 14 additions & 0 deletions
@@ -208,6 +208,8 @@ bool attributesEqualCSE(const Node* lhs, const Node* rhs) {
 
 } // anonymous namespace
 
+// Makes a hash that hashes the input Values, the output type,
+// as well as the node attributes
 size_t HashNode::operator()(const Node* k) const {
   AT_ASSERT(k != nullptr);
   size_t constant_hash = 0;
@@ -235,6 +237,8 @@ size_t HashNode::operator()(const Node* k) const {
       constant_hash);
 };
 
+// Checks that two nodes have the same inputs, output types,
+// and node attributes.
 bool EqualNode::operator()(const Node* lhs, const Node* rhs) const {
   if (lhs == nullptr && rhs == nullptr)
     return true;
@@ -267,6 +271,16 @@ bool EqualNode::operator()(const Node* lhs, const Node* rhs) const {
   if (!attributesEqualCSE(lhs, rhs))
     return false;
 
+  // Check that the blocks contained in an op are the same
+  if (lhs->blocks().size() != rhs->blocks().size()) {
+    return false;
+  }
+  for (size_t i = 0; i < lhs->blocks().size(); ++i) {
+    if (lhs->blocks()[i] != rhs->blocks()[i]) {
+      return false;
+    }
+  }
+
   return true;
 };
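
These two functors are how the hoisting pass decides that a node from the true branch and a node from the false branch are the same expression; the new comparison makes nodes that contain sub-blocks (e.g. a nested prim::If) match only when they hold the very same Block objects. A rough Python analogue of this hash-and-compare canonicalization, with hypothetical names rather than PyTorch's API:

from dataclasses import dataclass
from typing import Dict, Tuple

@dataclass(frozen=True)
class Node:
    kind: str                        # e.g. "aten::add"
    inputs: Tuple[int, ...]          # ids of the Values feeding the node
    output_type: str                 # e.g. "int"
    block_ids: Tuple[int, ...] = ()  # identities of contained sub-blocks

# A frozen dataclass hashes and compares structurally, standing in for
# HashNode::operator() and EqualNode::operator().
canonical: Dict[Node, Node] = {}

def get_or_insert(node: Node) -> Node:
    # Returns the first structurally identical node ever seen.
    return canonical.setdefault(node, node)

a = get_or_insert(Node("aten::add", (0, 1), "int"))
b = get_or_insert(Node("aten::add", (0, 1), "int"))
assert a is b       # common expression: safe to keep one hoisted copy
c = get_or_insert(Node("prim::If", (2,), "int", block_ids=(99,)))
assert c is not a   # different kind and blocks: no match

Note that the C++ code compares Block pointers, so two structurally similar but distinct blocks never match; the sketch mirrors that with opaque block ids.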
