Skip to content

Commit 6cfe555

Browse files
Thiago Crepaldi and pytorchmergebot
Thiago Crepaldi
authored and pytorchmergebot committed
[ONNX] Apply Common Subexpression Elimination pass to ONNX export (pytorch#85665)
## Summary Exporting graphs with Autocast may fail due to a limitation on JIT tracer. By disabling Autocast cache, tracer works, but there can be performance hit when there is reuse of weights in convolution, for example By applying CSE, such performance loss can be reverted. ps: As a comment at pytorch#84092 mentioned, disabling Autocast cache is an acceptable workaround and used throughout PyTorch code. Fixes pytorch#84092 ## Examples of before and after CSE being applied: ### Example: eliminating `%17` and reusing `%16` instead ```python # BEFORE graph(%0 : Float(requires_grad=0, device=cpu)): %3 : Scalar = aten::ScalarImplicit(%0), scope: test_onnx_opset.TestONNXOpset.test_full.<locals>.MyModule:: %13 : int = prim::Constant[value=3](), scope: test_onnx_opset.TestONNXOpset.test_full.<locals>.MyModule:: # /home/thiagofc/dev/github/pytorch/test/onnx/test_onnx_opset.py:347:0 %14 : int = prim::Constant[value=4](), scope: test_onnx_opset.TestONNXOpset.test_full.<locals>.MyModule:: # /home/thiagofc/dev/github/pytorch/test/onnx/test_onnx_opset.py:347:0 %15 : int[] = prim::ListConstruct(%13, %14), scope: test_onnx_opset.TestONNXOpset.test_full.<locals>.MyModule:: %16 : NoneType = prim::Constant(), scope: test_onnx_opset.TestONNXOpset.test_full.<locals>.MyModule:: %17 : NoneType = prim::Constant(), scope: test_onnx_opset.TestONNXOpset.test_full.<locals>.MyModule:: %18 : Device = prim::Constant[value="cpu"](), scope: test_onnx_opset.TestONNXOpset.test_full.<locals>.MyModule:: # /home/thiagofc/dev/github/pytorch/test/onnx/test_onnx_opset.py:347:0 %19 : bool = prim::Constant[value=0](), scope: test_onnx_opset.TestONNXOpset.test_full.<locals>.MyModule:: # /home/thiagofc/dev/github/pytorch/test/onnx/test_onnx_opset.py:347:0 %20 : Float(3, 4, strides=[4, 1], requires_grad=0, device=cpu) = aten::full(%15, %3, %16, %17, %18, %19), scope: test_onnx_opset.TestONNXOpset.test_full.<locals>.MyModule:: # /home/thiagofc/dev/github/pytorch/test/onnx/test_onnx_opset.py:347:0 return (%20) 
AFTER graph(%0 : Float(requires_grad=0, device=cpu)): %3 : Scalar = aten::ScalarImplicit(%0), scope: test_onnx_opset.TestONNXOpset.test_full.<locals>.MyModule:: %13 : int = prim::Constant[value=3](), scope: test_onnx_opset.TestONNXOpset.test_full.<locals>.MyModule:: # /home/thiagofc/dev/github/pytorch/test/onnx/test_onnx_opset.py:347:0 %14 : int = prim::Constant[value=4](), scope: test_onnx_opset.TestONNXOpset.test_full.<locals>.MyModule:: # /home/thiagofc/dev/github/pytorch/test/onnx/test_onnx_opset.py:347:0 %15 : int[] = prim::ListConstruct(%13, %14), scope: test_onnx_opset.TestONNXOpset.test_full.<locals>.MyModule:: %16 : NoneType = prim::Constant(), scope: test_onnx_opset.TestONNXOpset.test_full.<locals>.MyModule:: %18 : Device = prim::Constant[value="cpu"](), scope: test_onnx_opset.TestONNXOpset.test_full.<locals>.MyModule:: # /home/thiagofc/dev/github/pytorch/test/onnx/test_onnx_opset.py:347:0 %19 : bool = prim::Constant[value=0](), scope: test_onnx_opset.TestONNXOpset.test_full.<locals>.MyModule:: # /home/thiagofc/dev/github/pytorch/test/onnx/test_onnx_opset.py:347:0 %20 : Float(3, 4, strides=[4, 1], requires_grad=0, device=cpu) = aten::full(%15, %3, %16, %16, %18, %19), scope: test_onnx_opset.TestONNXOpset.test_full.<locals>.MyModule:: # /home/thiagofc/dev/github/pytorch/test/onnx/test_onnx_opset.py:347:0 return (%20) ``` Pull Request resolved: pytorch#85665 Approved by: https://github.com/ngimel, https://github.com/AllenTiTaiWang, https://github.com/BowenBao
1 parent c719ec9 commit 6cfe555

6 files changed

+51
-44
lines changed

test/onnx/expect/TestOperators.test_baddbmm.expect

+12-12
Original file line numberDiff line numberDiff line change
@@ -5,12 +5,12 @@ graph {
55
node {
66
input: "onnx::MatMul_1"
77
input: "onnx::MatMul_2"
8-
output: "onnx::Mul_5"
8+
output: "onnx::Mul_4"
99
name: "MatMul_0"
1010
op_type: "MatMul"
1111
}
1212
node {
13-
output: "onnx::Mul_11"
13+
output: "onnx::Mul_10"
1414
name: "Constant_1"
1515
op_type: "Constant"
1616
attribute {
@@ -23,14 +23,14 @@ graph {
2323
}
2424
}
2525
node {
26-
input: "onnx::Mul_5"
27-
input: "onnx::Mul_11"
28-
output: "onnx::Add_7"
26+
input: "onnx::Mul_4"
27+
input: "onnx::Mul_10"
28+
output: "onnx::Add_6"
2929
name: "Mul_2"
3030
op_type: "Mul"
3131
}
3232
node {
33-
output: "onnx::Mul_12"
33+
output: "onnx::Mul_11"
3434
name: "Constant_3"
3535
op_type: "Constant"
3636
attribute {
@@ -44,15 +44,15 @@ graph {
4444
}
4545
node {
4646
input: "onnx::Mul_0"
47-
input: "onnx::Mul_12"
48-
output: "onnx::Add_9"
47+
input: "onnx::Mul_11"
48+
output: "onnx::Add_8"
4949
name: "Mul_4"
5050
op_type: "Mul"
5151
}
5252
node {
53-
input: "onnx::Add_7"
54-
input: "onnx::Add_9"
55-
output: "10"
53+
input: "onnx::Add_6"
54+
input: "onnx::Add_8"
55+
output: "9"
5656
name: "Add_5"
5757
op_type: "Add"
5858
}
@@ -115,7 +115,7 @@ graph {
115115
}
116116
}
117117
output {
118-
name: "10"
118+
name: "9"
119119
type {
120120
tensor_type {
121121
elem_type: 1

test/onnx/expect/TestOperators.test_narrow.expect

+6-6
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ producer_name: "pytorch"
33
producer_version: "CURRENT_VERSION"
44
graph {
55
node {
6-
output: "onnx::Slice_14"
6+
output: "onnx::Slice_13"
77
name: "Constant_0"
88
op_type: "Constant"
99
attribute {
@@ -17,7 +17,7 @@ graph {
1717
}
1818
}
1919
node {
20-
output: "onnx::Slice_15"
20+
output: "onnx::Slice_14"
2121
name: "Constant_1"
2222
op_type: "Constant"
2323
attribute {
@@ -31,7 +31,7 @@ graph {
3131
}
3232
}
3333
node {
34-
output: "onnx::Slice_16"
34+
output: "onnx::Slice_15"
3535
name: "Constant_2"
3636
op_type: "Constant"
3737
attribute {
@@ -46,10 +46,10 @@ graph {
4646
}
4747
node {
4848
input: "onnx::Slice_0"
49+
input: "onnx::Slice_13"
4950
input: "onnx::Slice_14"
5051
input: "onnx::Slice_15"
51-
input: "onnx::Slice_16"
52-
output: "12"
52+
output: "11"
5353
name: "Slice_3"
5454
op_type: "Slice"
5555
}
@@ -71,7 +71,7 @@ graph {
7171
}
7272
}
7373
output {
74-
name: "12"
74+
name: "11"
7575
type {
7676
tensor_type {
7777
elem_type: 1

test/onnx/expect/TestOperators.test_shape_value_map.expect

+14-14
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@ graph {
5555
op_type: "Unsqueeze"
5656
}
5757
node {
58-
output: "onnx::Concat_26"
58+
output: "onnx::Concat_25"
5959
name: "Constant_5"
6060
op_type: "Constant"
6161
attribute {
@@ -69,7 +69,7 @@ graph {
6969
}
7070
}
7171
node {
72-
output: "onnx::Concat_27"
72+
output: "onnx::Concat_26"
7373
name: "Constant_6"
7474
op_type: "Constant"
7575
attribute {
@@ -83,7 +83,7 @@ graph {
8383
}
8484
}
8585
node {
86-
output: "onnx::Concat_28"
86+
output: "onnx::Concat_27"
8787
name: "Constant_7"
8888
op_type: "Constant"
8989
attribute {
@@ -98,9 +98,9 @@ graph {
9898
}
9999
node {
100100
input: "onnx::Concat_8"
101+
input: "onnx::Concat_25"
101102
input: "onnx::Concat_26"
102103
input: "onnx::Concat_27"
103-
input: "onnx::Concat_28"
104104
output: "onnx::Reshape_15"
105105
name: "Concat_8"
106106
op_type: "Concat"
@@ -148,7 +148,7 @@ graph {
148148
}
149149
}
150150
node {
151-
output: "onnx::Unsqueeze_20"
151+
output: "onnx::Unsqueeze_19"
152152
name: "Constant_12"
153153
op_type: "Constant"
154154
attribute {
@@ -163,13 +163,13 @@ graph {
163163
}
164164
node {
165165
input: "onnx::Unsqueeze_3"
166-
input: "onnx::Unsqueeze_20"
167-
output: "onnx::Concat_21"
166+
input: "onnx::Unsqueeze_19"
167+
output: "onnx::Concat_20"
168168
name: "Unsqueeze_13"
169169
op_type: "Unsqueeze"
170170
}
171171
node {
172-
output: "onnx::Concat_29"
172+
output: "onnx::Concat_28"
173173
name: "Constant_14"
174174
op_type: "Constant"
175175
attribute {
@@ -183,9 +183,9 @@ graph {
183183
}
184184
}
185185
node {
186-
input: "onnx::Concat_21"
187-
input: "onnx::Concat_29"
188-
output: "onnx::Reshape_24"
186+
input: "onnx::Concat_20"
187+
input: "onnx::Concat_28"
188+
output: "onnx::Reshape_23"
189189
name: "Concat_15"
190190
op_type: "Concat"
191191
attribute {
@@ -196,8 +196,8 @@ graph {
196196
}
197197
node {
198198
input: "onnx::Reshape_18"
199-
input: "onnx::Reshape_24"
200-
output: "25"
199+
input: "onnx::Reshape_23"
200+
output: "24"
201201
name: "Reshape_16"
202202
op_type: "Reshape"
203203
attribute {
@@ -230,7 +230,7 @@ graph {
230230
}
231231
}
232232
output {
233-
name: "25"
233+
name: "24"
234234
type {
235235
tensor_type {
236236
elem_type: 1

test/onnx/test_operators.py

+9-9
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,11 @@
11
# Owner(s): ["module: onnx"]
22

3+
"""
4+
Usage: python test/onnx/test_operators.py [--no-onnx] [--produce-onnx-test-data]
5+
--no-onnx: no onnx python dependency
6+
--produce-onnx-test-data: generate onnx test data
7+
--accept: accept onnx updates and overwrite models
8+
"""
39
import glob
410
import inspect
511
import io
@@ -8,6 +14,9 @@
814
import shutil
915
import tempfile
1016

17+
# Full diff for expect files
18+
import unittest
19+
1120
import torch
1221
import torch.nn as nn
1322
import torch.nn.functional as F
@@ -30,15 +39,6 @@
3039
from torch.testing._internal import common_utils
3140
from torch.testing._internal.common_utils import skipIfCaffe2, skipIfNoLapack
3241

33-
"""Usage: python test/onnx/test_operators.py [--no-onnx] [--produce-onnx-test-data]
34-
--no-onnx: no onnx python dependence
35-
--produce-onnx-test-data: generate onnx test data
36-
--accept: accept onnx updates and overwrite models
37-
"""
38-
39-
# Full diff for expect files
40-
import unittest
41-
4242
unittest.TestCase.maxDiff = None
4343

4444
_onnx_test = False # flag to produce onnx test cases.

torch/_C/__init__.pyi.in

+1
Original file line numberDiff line numberDiff line change
@@ -989,6 +989,7 @@ def _jit_set_inline_everything_mode(enabled: _bool) -> None: ...
989989
def _jit_get_logging_option() -> str: ...
990990
def _jit_set_logging_option(option: str) -> None: ...
991991
def _jit_set_logging_stream(stream_name: str) -> None: ...
992+
def _jit_pass_cse(Graph) -> _bool: ...
992993
def _jit_pass_dce(Graph) -> None: ...
993994
def _jit_pass_lint(Graph) -> None: ...
994995

torch/onnx/utils.py

+9-3
Original file line numberDiff line numberDiff line change
@@ -573,6 +573,12 @@ def _optimize_graph(
573573
_C._jit_pass_dce(graph)
574574
_C._jit_pass_lint(graph)
575575

576+
# CSE should improve perf when Autocast is used with disabled cache
577+
# Autocast is disabled due to a limitation on tracer as described at https://github.com/pytorch/pytorch/issues/84092
578+
# Must run before _C._jit_pass_erase_number_types to prevent type substitution
579+
if _C._jit_pass_cse(graph):
580+
_C._jit_pass_onnx_lint(graph)
581+
576582
_C._jit_pass_canonicalize_graph_fuser_ops(graph)
577583
_C._jit_pass_lint(graph)
578584
_C._jit_pass_peephole(graph, True)
@@ -632,6 +638,7 @@ def _optimize_graph(
632638
dynamic_axes = {} if dynamic_axes is None else dynamic_axes
633639
_C._jit_pass_onnx_set_dynamic_input_shape(graph, dynamic_axes, input_names)
634640
_C._jit_pass_onnx_lint(graph)
641+
635642
graph = _C._jit_pass_onnx(graph, operator_export_type)
636643
_C._jit_pass_onnx_lint(graph)
637644
_C._jit_pass_lint(graph)
@@ -851,11 +858,10 @@ def _trace_and_get_graph_from_model(model, args):
851858
orig_state_dict_keys = torch.jit._unique_state_dict(model).keys()
852859

853860
# Disable Autocast cache because it replaces kernel's weight and bias
854-
# to be replaced by (undesired) constants
861+
# by (undesired) constants.
862+
# No perf impact for when there are reused weights since https://github.com/pytorch/pytorch/pull/85665
855863
# TODO: https://github.com/pytorch/pytorch/issues/84092
856864
prev_autocast_cache_enabled = torch.is_autocast_cache_enabled()
857-
# When weights are not reused, there is no perf impact
858-
# ONNX runtimes can also apply CSE optimization to compensate the lack of cache here
859865
torch.set_autocast_cache_enabled(False)
860866
trace_graph, torch_out, inputs_states = torch.jit._get_trace_graph(
861867
model, args, strict=False, _force_outplace=False, _return_inputs_states=True

0 commit comments

Comments
 (0)