|
12 | 12 | import model_defs.word_language_model as word_language_model
|
13 | 13 | import numpy as np
|
14 | 14 | import onnx
|
| 15 | +import pytorch_test_common |
15 | 16 | import torch.onnx
|
16 | 17 | import torch.onnx.operators
|
17 | 18 | import torch.utils.model_zoo as model_zoo
|
@@ -129,18 +130,10 @@ def do_export(model, inputs, *args, **kwargs):
|
129 | 130 | }
|
130 | 131 |
|
131 | 132 |
|
132 |
| -class TestCaffe2Backend_opset9(common_utils.TestCase): |
| 133 | +class TestCaffe2Backend_opset9(pytorch_test_common.ExportTestCase): |
133 | 134 | opset_version = 9
|
134 | 135 | embed_params = False
|
135 | 136 |
|
136 |
| - def setUp(self): |
137 |
| - # the following should ideally be super().setUp(), https://github.com/pytorch/pytorch/issues/79630 |
138 |
| - common_utils.TestCase.setUp(self) |
139 |
| - torch.manual_seed(0) |
140 |
| - if torch.cuda.is_available(): |
141 |
| - torch.cuda.manual_seed_all(0) |
142 |
| - np.random.seed(seed=0) |
143 |
| - |
144 | 137 | def convert_cuda(self, model, input):
|
145 | 138 | cuda_model = model.cuda()
|
146 | 139 | # input might be nested - we want to move everything to GPU
|
@@ -3198,52 +3191,52 @@ def setup_rnn_tests():
|
3198 | 3191 | # to embed_params=True
|
3199 | 3192 | TestCaffe2BackendEmbed_opset9 = type(
|
3200 | 3193 | "TestCaffe2BackendEmbed_opset9",
|
3201 |
| - (common_utils.TestCase,), |
| 3194 | + (pytorch_test_common.ExportTestCase,), |
3202 | 3195 | dict(TestCaffe2Backend_opset9.__dict__, embed_params=True),
|
3203 | 3196 | )
|
3204 | 3197 |
|
3205 | 3198 | # opset 7 tests
|
3206 | 3199 | TestCaffe2Backend_opset7 = type(
|
3207 | 3200 | "TestCaffe2Backend_opset7",
|
3208 |
| - (common_utils.TestCase,), |
| 3201 | + (pytorch_test_common.ExportTestCase,), |
3209 | 3202 | dict(TestCaffe2Backend_opset9.__dict__, opset_version=7),
|
3210 | 3203 | )
|
3211 | 3204 | TestCaffe2BackendEmbed_opset7 = type(
|
3212 | 3205 | "TestCaffe2BackendEmbed_opset7",
|
3213 |
| - (common_utils.TestCase,), |
| 3206 | + (pytorch_test_common.ExportTestCase,), |
3214 | 3207 | dict(TestCaffe2Backend_opset9.__dict__, embed_params=True, opset_version=7),
|
3215 | 3208 | )
|
3216 | 3209 |
|
3217 | 3210 | # opset 8 tests
|
3218 | 3211 | TestCaffe2Backend_opset8 = type(
|
3219 | 3212 | "TestCaffe2Backend_opset8",
|
3220 |
| - (common_utils.TestCase,), |
| 3213 | + (pytorch_test_common.ExportTestCase,), |
3221 | 3214 | dict(TestCaffe2Backend_opset9.__dict__, opset_version=8),
|
3222 | 3215 | )
|
3223 | 3216 | TestCaffe2BackendEmbed_opset8 = type(
|
3224 | 3217 | "TestCaffe2BackendEmbed_opset8",
|
3225 |
| - (common_utils.TestCase,), |
| 3218 | + (pytorch_test_common.ExportTestCase,), |
3226 | 3219 | dict(TestCaffe2Backend_opset9.__dict__, embed_params=True, opset_version=8),
|
3227 | 3220 | )
|
3228 | 3221 |
|
3229 | 3222 | # opset 10 tests
|
3230 | 3223 | TestCaffe2Backend_opset10 = type(
|
3231 | 3224 | "TestCaffe2Backend_opset10",
|
3232 |
| - (common_utils.TestCase,), |
| 3225 | + (pytorch_test_common.ExportTestCase,), |
3233 | 3226 | dict(TestCaffe2Backend_opset9.__dict__, opset_version=10),
|
3234 | 3227 | )
|
3235 | 3228 |
|
3236 | 3229 | TestCaffe2BackendEmbed_opset10 = type(
|
3237 | 3230 | "TestCaffe2BackendEmbed_opset10",
|
3238 |
| - (common_utils.TestCase,), |
| 3231 | + (pytorch_test_common.ExportTestCase,), |
3239 | 3232 | dict(TestCaffe2Backend_opset9.__dict__, embed_params=True, opset_version=10),
|
3240 | 3233 | )
|
3241 | 3234 |
|
3242 | 3235 | # add the same test suite as above, but switch embed_params=False
|
3243 | 3236 | # to embed_params=True
|
3244 | 3237 | TestCaffe2BackendEmbed_opset9_new_jit_API = type(
|
3245 | 3238 | "TestCaffe2BackendEmbed_opset9_new_jit_API",
|
3246 |
| - (common_utils.TestCase,), |
| 3239 | + (pytorch_test_common.ExportTestCase,), |
3247 | 3240 | dict(TestCaffe2Backend_opset9.__dict__, embed_params=True),
|
3248 | 3241 | )
|
3249 | 3242 |
|
|
0 commit comments