Skip to content

Commit 500fd65

Browse files
BowenBaopytorchmergebot
authored andcommitted
[ONNX] Create common ExportTestCase base class (pytorch#88145)
Refactor out a common base class `ExportTestCase`, for common things in `setUp`. Pull Request resolved: pytorch#88145 Approved by: https://github.com/justinchuby, https://github.com/abock, https://github.com/AllenTiTaiWang
1 parent 20ae19a commit 500fd65

16 files changed

+76
-55
lines changed

test/onnx/onnx_test_common.py

+3-11
Original file line numberDiff line numberDiff line change
@@ -3,15 +3,13 @@
33
from __future__ import annotations
44

55
import os
6-
import random
76
from typing import Any, Mapping, Type
87

9-
import numpy as np
108
import onnxruntime
9+
import pytorch_test_common
1110

1211
import torch
1312
from torch.onnx import _constants, verification
14-
from torch.testing._internal import common_utils
1513

1614
onnx_model_dir = os.path.join(
1715
os.path.dirname(os.path.realpath(__file__)),
@@ -54,21 +52,15 @@ def parameterize_class_name(cls: Type, idx: int, input_dicts: Mapping[Any, Any])
5452
return f"{cls.__name__}_{suffix}"
5553

5654

57-
def set_rng_seed(seed):
58-
torch.manual_seed(seed)
59-
random.seed(seed)
60-
np.random.seed(seed)
61-
62-
63-
class _TestONNXRuntime(common_utils.TestCase):
55+
class _TestONNXRuntime(pytorch_test_common.ExportTestCase):
6456
opset_version = _constants.ONNX_DEFAULT_OPSET
6557
keep_initializers_as_inputs = True # For IR version 3 type export.
6658
is_script = False
6759
check_shape = True
6860
check_dtype = True
6961

7062
def setUp(self):
71-
set_rng_seed(0)
63+
super().setUp()
7264
onnxruntime.set_seed(0)
7365
if torch.cuda.is_available():
7466
torch.cuda.manual_seed_all(0)

test/onnx/pytorch_test_common.py

+26
Original file line numberDiff line numberDiff line change
@@ -2,12 +2,17 @@
22

33
import functools
44
import os
5+
import random
56
import sys
67
import unittest
78
from typing import Optional
89

10+
import numpy as np
11+
912
import torch
1013
from torch.autograd import function
14+
from torch.onnx._internal import diagnostics
15+
from torch.testing._internal import common_utils
1116

1217
pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
1318
sys.path.insert(-1, pytorch_test_dir)
@@ -188,3 +193,24 @@ def wrapper(self, *args, **kwargs):
188193

189194
def flatten(x):
190195
return tuple(function._iter_filter(lambda o: isinstance(o, torch.Tensor))(x))
196+
197+
198+
def set_rng_seed(seed):
199+
torch.manual_seed(seed)
200+
random.seed(seed)
201+
np.random.seed(seed)
202+
203+
204+
class ExportTestCase(common_utils.TestCase):
205+
"""Test case for ONNX export.
206+
207+
Any test case that tests functionalities under torch.onnx should inherit from this class.
208+
"""
209+
210+
def setUp(self):
211+
super().setUp()
212+
# TODO(#88264): Flaky test failures after changing seed.
213+
set_rng_seed(0)
214+
if torch.cuda.is_available():
215+
torch.cuda.manual_seed_all(0)
216+
diagnostics.engine.clear()

test/onnx/test_autograd_funs.py

+4-4
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,16 @@
11
# Owner(s): ["module: onnx"]
22

3-
import unittest
3+
import pytorch_test_common
44

55
import torch
6-
76
from onnx_test_common import run_model_test
87
from torch.onnx import OperatorExportTypes
98
from torch.onnx._globals import GLOBALS
109
from torch.onnx.utils import _model_to_graph
10+
from torch.testing._internal import common_utils
1111

1212

13-
class TestAutogradFuns(unittest.TestCase):
13+
class TestAutogradFuns(pytorch_test_common.ExportTestCase):
1414
opset_version = GLOBALS.export_onnx_opset_version
1515
keep_initializers_as_inputs = False
1616
onnx_shape_inference = True
@@ -209,4 +209,4 @@ def forward(self, input):
209209

210210

211211
if __name__ == "__main__":
212-
unittest.main()
212+
common_utils.run_tests()

test/onnx/test_custom_ops.py

+4-3
Original file line numberDiff line numberDiff line change
@@ -4,14 +4,15 @@
44
import numpy as np
55
import onnx
66
import onnx_test_common
7+
import pytorch_test_common
78
import torch
89
import torch.utils.cpp_extension
910
from test_pytorch_onnx_caffe2 import do_export
1011
from torch.onnx import symbolic_helper
1112
from torch.testing._internal import common_utils
1213

1314

14-
class TestCustomOps(common_utils.TestCase):
15+
class TestCustomOps(pytorch_test_common.ExportTestCase):
1516
def test_custom_add(self):
1617
op_source = """
1718
#include <torch/script.h>
@@ -56,7 +57,7 @@ def symbolic_custom_add(g, self, other):
5657
np.testing.assert_array_equal(caffe2_out[0], model(x, y).cpu().numpy())
5758

5859

59-
class TestCustomAutogradFunction(common_utils.TestCase):
60+
class TestCustomAutogradFunction(pytorch_test_common.ExportTestCase):
6061
opset_version = 9
6162
keep_initializers_as_inputs = False
6263
onnx_shape_inference = True
@@ -130,7 +131,7 @@ def symbolic_pythonop(ctx: torch.onnx.SymbolicContext, g, *args, **kwargs):
130131
onnx_test_common.run_model_test(self, model, input_args=(x,))
131132

132133

133-
class TestExportAsContribOps(common_utils.TestCase):
134+
class TestExportAsContribOps(pytorch_test_common.ExportTestCase):
134135
opset_version = 14
135136
keep_initializers_as_inputs = False
136137
onnx_shape_inference = True

test/onnx/test_export_modes.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -15,11 +15,13 @@
1515
# Make the helper files in test/ importable
1616
pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
1717
sys.path.append(pytorch_test_dir)
18+
import pytorch_test_common
19+
1820
from torch.testing._internal import common_utils
1921

2022

2123
# Smoke tests for export methods
22-
class TestExportModes(common_utils.TestCase):
24+
class TestExportModes(pytorch_test_common.ExportTestCase):
2325
class MyModel(nn.Module):
2426
def __init__(self):
2527
super(TestExportModes.MyModel, self).__init__()

test/onnx/test_models.py

+3-2
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,9 @@
22

33
import unittest
44

5-
import torch
5+
import pytorch_test_common
66

7+
import torch
78
from model_defs.dcgan import _netD, _netG, bsz, imgsz, nz, weights_init
89
from model_defs.emb_seq import EmbeddingNetwork1, EmbeddingNetwork2
910
from model_defs.mnist import MNIST
@@ -44,7 +45,7 @@ def toC(x):
4445
BATCH_SIZE = 2
4546

4647

47-
class TestModels(common_utils.TestCase):
48+
class TestModels(pytorch_test_common.ExportTestCase):
4849
opset_version = 9 # Caffe2 doesn't support the default.
4950
keep_initializers_as_inputs = False
5051

test/onnx/test_models_onnxruntime.py

+3-2
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
import onnx_test_common
99
import parameterized
1010
import PIL
11+
import pytorch_test_common
1112
import test_models
1213

1314
import torch
@@ -64,7 +65,7 @@ def exportTest(
6465

6566
TestModels = type(
6667
"TestModels",
67-
(common_utils.TestCase,),
68+
(pytorch_test_common.ExportTestCase,),
6869
dict(
6970
test_models.TestModels.__dict__,
7071
is_script_test_enabled=False,
@@ -77,7 +78,7 @@ def exportTest(
7778
# model tests for scripting with new JIT APIs and shape inference
7879
TestModels_new_jit_API = type(
7980
"TestModels_new_jit_API",
80-
(common_utils.TestCase,),
81+
(pytorch_test_common.ExportTestCase,),
8182
dict(
8283
TestModels.__dict__,
8384
exportTest=exportTest,

test/onnx/test_onnx_opset.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import itertools
55

66
import onnx
7+
import pytorch_test_common
78

89
import torch
910
import torch.onnx
@@ -70,7 +71,7 @@ def check_onnx_opsets_operator(
7071
check_onnx_opset_operator(model, ops[opset_version], opset_version)
7172

7273

73-
class TestONNXOpset(common_utils.TestCase):
74+
class TestONNXOpset(pytorch_test_common.ExportTestCase):
7475
def test_opset_fallback(self):
7576
class MyModule(Module):
7677
def forward(self, x):

test/onnx/test_operators.py

+5
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
)
3232
from torch.autograd import Function, Variable
3333
from torch.nn import functional, Module
34+
from torch.onnx._internal import diagnostics
3435
from torch.onnx.symbolic_helper import (
3536
_get_tensor_dim_size,
3637
_get_tensor_sizes,
@@ -71,6 +72,10 @@ def forward(self, *args):
7172

7273

7374
class TestOperators(common_utils.TestCase):
75+
def setUp(self):
76+
super().setUp()
77+
diagnostics.engine.clear()
78+
7479
def assertONNX(self, f, args, params=None, **kwargs):
7580
if params is None:
7681
params = ()

test/onnx/test_pytorch_helper.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import unittest
55

66
import numpy as np
7+
import pytorch_test_common
78

89
import torch.nn.init as init
910
import torch.onnx
@@ -15,7 +16,7 @@
1516
from torch.testing._internal.common_utils import skipIfNoLapack
1617

1718

18-
class TestCaffe2Backend(common_utils.TestCase):
19+
class TestCaffe2Backend(pytorch_test_common.ExportTestCase):
1920
@skipIfNoLapack
2021
@unittest.skip("test broken because Lapack was always missing.")
2122
def test_helper(self):

test/onnx/test_pytorch_jit_onnx.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
# Owner(s): ["module: onnx"]
22
import onnxruntime
3+
import pytorch_test_common
34

45
import torch
56
from pytorch_test_common import skipIfNoCuda
@@ -171,7 +172,7 @@ def MakeTestCase(opset_version: int) -> type:
171172
name = f"TestJITIRToONNX_opset{opset_version}"
172173
return type(
173174
str(name),
174-
(common_utils.TestCase,),
175+
(pytorch_test_common.ExportTestCase,),
175176
dict(_TestJITIRToONNX.__dict__, opset_version=opset_version),
176177
)
177178

test/onnx/test_pytorch_onnx_caffe2.py

+10-17
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
import model_defs.word_language_model as word_language_model
1313
import numpy as np
1414
import onnx
15+
import pytorch_test_common
1516
import torch.onnx
1617
import torch.onnx.operators
1718
import torch.utils.model_zoo as model_zoo
@@ -129,18 +130,10 @@ def do_export(model, inputs, *args, **kwargs):
129130
}
130131

131132

132-
class TestCaffe2Backend_opset9(common_utils.TestCase):
133+
class TestCaffe2Backend_opset9(pytorch_test_common.ExportTestCase):
133134
opset_version = 9
134135
embed_params = False
135136

136-
def setUp(self):
137-
# the following should ideally be super().setUp(), https://github.com/pytorch/pytorch/issues/79630
138-
common_utils.TestCase.setUp(self)
139-
torch.manual_seed(0)
140-
if torch.cuda.is_available():
141-
torch.cuda.manual_seed_all(0)
142-
np.random.seed(seed=0)
143-
144137
def convert_cuda(self, model, input):
145138
cuda_model = model.cuda()
146139
# input might be nested - we want to move everything to GPU
@@ -3198,52 +3191,52 @@ def setup_rnn_tests():
31983191
# to embed_params=True
31993192
TestCaffe2BackendEmbed_opset9 = type(
32003193
"TestCaffe2BackendEmbed_opset9",
3201-
(common_utils.TestCase,),
3194+
(pytorch_test_common.ExportTestCase,),
32023195
dict(TestCaffe2Backend_opset9.__dict__, embed_params=True),
32033196
)
32043197

32053198
# opset 7 tests
32063199
TestCaffe2Backend_opset7 = type(
32073200
"TestCaffe2Backend_opset7",
3208-
(common_utils.TestCase,),
3201+
(pytorch_test_common.ExportTestCase,),
32093202
dict(TestCaffe2Backend_opset9.__dict__, opset_version=7),
32103203
)
32113204
TestCaffe2BackendEmbed_opset7 = type(
32123205
"TestCaffe2BackendEmbed_opset7",
3213-
(common_utils.TestCase,),
3206+
(pytorch_test_common.ExportTestCase,),
32143207
dict(TestCaffe2Backend_opset9.__dict__, embed_params=True, opset_version=7),
32153208
)
32163209

32173210
# opset 8 tests
32183211
TestCaffe2Backend_opset8 = type(
32193212
"TestCaffe2Backend_opset8",
3220-
(common_utils.TestCase,),
3213+
(pytorch_test_common.ExportTestCase,),
32213214
dict(TestCaffe2Backend_opset9.__dict__, opset_version=8),
32223215
)
32233216
TestCaffe2BackendEmbed_opset8 = type(
32243217
"TestCaffe2BackendEmbed_opset8",
3225-
(common_utils.TestCase,),
3218+
(pytorch_test_common.ExportTestCase,),
32263219
dict(TestCaffe2Backend_opset9.__dict__, embed_params=True, opset_version=8),
32273220
)
32283221

32293222
# opset 10 tests
32303223
TestCaffe2Backend_opset10 = type(
32313224
"TestCaffe2Backend_opset10",
3232-
(common_utils.TestCase,),
3225+
(pytorch_test_common.ExportTestCase,),
32333226
dict(TestCaffe2Backend_opset9.__dict__, opset_version=10),
32343227
)
32353228

32363229
TestCaffe2BackendEmbed_opset10 = type(
32373230
"TestCaffe2BackendEmbed_opset10",
3238-
(common_utils.TestCase,),
3231+
(pytorch_test_common.ExportTestCase,),
32393232
dict(TestCaffe2Backend_opset9.__dict__, embed_params=True, opset_version=10),
32403233
)
32413234

32423235
# add the same test suite as above, but switch embed_params=False
32433236
# to embed_params=True
32443237
TestCaffe2BackendEmbed_opset9_new_jit_API = type(
32453238
"TestCaffe2BackendEmbed_opset9_new_jit_API",
3246-
(common_utils.TestCase,),
3239+
(pytorch_test_common.ExportTestCase,),
32473240
dict(TestCaffe2Backend_opset9.__dict__, embed_params=True),
32483241
)
32493242

test/onnx/test_pytorch_onnx_caffe2_quantized.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -6,13 +6,14 @@
66

77
import numpy as np
88
import onnx
9+
import pytorch_test_common
910
import torch.ao.nn.quantized as nnq
1011
import torch.nn as nn
1112
import torch.onnx
1213
from torch.testing._internal import common_utils
1314

1415

15-
class TestQuantizedOps(common_utils.TestCase):
16+
class TestQuantizedOps(pytorch_test_common.ExportTestCase):
1617
def generic_test(
1718
self, model, sample_inputs, input_names=None, decimal=3, relaxed_check=False
1819
):

test/onnx/test_pytorch_onnx_no_runtime.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -11,9 +11,9 @@
1111
from typing import Callable, Dict, Iterable, List, Optional, Tuple, Union
1212

1313
import numpy as np
14-
1514
import onnx
1615
import onnx.numpy_helper
16+
import pytorch_test_common
1717

1818
import torch
1919
import torch.nn.functional as F
@@ -74,7 +74,7 @@ def export_to_onnx(
7474
return onnx_model
7575

7676

77-
class TestONNXExport(common_utils.TestCase):
77+
class TestONNXExport(pytorch_test_common.ExportTestCase):
7878
def test_fuse_addmm(self):
7979
class AddmmModel(torch.nn.Module):
8080
def forward(self, x):

0 commit comments

Comments
 (0)