Skip to content

Commit e9aa689

Browse files
suofacebook-github-bot
authored andcommitted
Revert D23802296: gtest-ify JIT tests, through the letter c
Test Plan: revert-hammer Differential Revision: D23802296 (pytorch@d2b0450) Original commit changeset: 20c9798a414e fbshipit-source-id: a28d56039ca404fe94ed7572f1febd1673e3e788
1 parent 89c570e commit e9aa689

11 files changed

+294
-282
lines changed

test/cpp/jit/test_autodiff.cpp

+4-5
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
1-
#include <gtest/gtest.h>
2-
1+
#include "test/cpp/jit/test_base.h"
32
#include "test/cpp/jit/test_utils.h"
43
#include "torch/csrc/jit/frontend/tracer.h"
54
#include "torch/csrc/jit/passes/common_subexpression_elimination.h"
@@ -84,7 +83,7 @@ variable_list grad(
8483
fmap(inputs, get_edge));
8584
}
8685

87-
TEST(AutodiffTest, ADFormulas) {
86+
void testADFormulas() {
8887
const auto cast = [](const Variable& v) {
8988
return static_cast<at::Tensor>(v);
9089
};
@@ -175,7 +174,7 @@ TEST(AutodiffTest, ADFormulas) {
175174
}
176175
}
177176

178-
TEST(AutodiffTest, Differentiate) {
177+
void testDifferentiate() {
179178
// Note: can't use IRParser for this test due to issue #23989
180179
auto graph = std::make_shared<Graph>();
181180
std::vector<int64_t> sizes{2, 3, 4};
@@ -230,7 +229,7 @@ TEST(AutodiffTest, Differentiate) {
230229
->run(*grad_spec.df);
231230
}
232231

233-
TEST(AutodiffTest, DifferentiateWithRequiresGrad) {
232+
void testDifferentiateWithRequiresGrad() {
234233
const auto graph_string = R"IR(
235234
graph(%0 : Tensor,
236235
%1 : Tensor):

test/cpp/jit/test_class_import.cpp

+6-6
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
1-
#include <gtest/gtest.h>
1+
#include <test/cpp/jit/test_base.h>
2+
#include <test/cpp/jit/test_utils.h>
23

34
#include <ATen/core/qualified_name.h>
4-
#include <test/cpp/jit/test_utils.h>
55
#include <torch/csrc/jit/frontend/resolver.h>
66
#include <torch/csrc/jit/serialization/import_source.h>
77
#include <torch/torch.h>
@@ -45,7 +45,7 @@ static void import_libs(
4545
si.loadType(QualifiedName(class_name));
4646
}
4747

48-
TEST(ClassImportTest, Basic) {
48+
void testClassImport() {
4949
auto cu1 = std::make_shared<CompilationUnit>();
5050
auto cu2 = std::make_shared<CompilationUnit>();
5151
std::vector<at::IValue> constantTable;
@@ -80,7 +80,7 @@ TEST(ClassImportTest, Basic) {
8080
ASSERT_FALSE(c);
8181
}
8282

83-
TEST(ClassImportTest, ScriptObject) {
83+
void testScriptObject() {
8484
Module m1("m1");
8585
Module m2("m2");
8686
std::vector<at::IValue> constantTable;
@@ -114,7 +114,7 @@ def __init__(self, x):
114114
return x
115115
)JIT";
116116

117-
TEST(ClassImportTest, ClassDerive) {
117+
void testClassDerive() {
118118
auto cu = std::make_shared<CompilationUnit>();
119119
auto cls = ClassType::create("foo.bar", cu);
120120
const auto self = SimpleSelf(cls);
@@ -142,7 +142,7 @@ class FooBar1234(Module):
142142
return (self.f).top()
143143
)JIT";
144144

145-
TEST(ClassImportTest, CustomClass) {
145+
void testSaveLoadTorchbind() {
146146
auto cu1 = std::make_shared<CompilationUnit>();
147147
std::vector<at::IValue> constantTable;
148148
// Import different versions of FooTest into two namespaces.

test/cpp/jit/test_class_parser.cpp

+1-3
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,3 @@
1-
#include <gtest/gtest.h>
2-
31
#include <test/cpp/jit/test_base.h>
42
#include <torch/csrc/jit/frontend/parser.h>
53
#include <torch/csrc/jit/frontend/resolver.h>
@@ -17,7 +15,7 @@ const auto testSource = R"JIT(
1715
an_attribute : Tensor
1816
)JIT";
1917

20-
TEST(ClassParserTest, Basic) {
18+
void testClassParser() {
2119
Parser p(std::make_shared<Source>(testSource));
2220
std::vector<Def> definitions;
2321
std::vector<Resolver> resolvers;

test/cpp/jit/test_cleanup_passes.cpp

+19-18
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,19 @@
1-
#include <gtest/gtest.h>
2-
31
#include <torch/csrc/jit/frontend/ir_emitter.h>
42
#include <torch/csrc/jit/ir/ir.h>
53
#include <torch/csrc/jit/ir/irparser.h>
64
#include <torch/csrc/jit/testing/file_check.h>
5+
#include "test/cpp/jit/test_base.h"
76

87
namespace torch {
98
namespace jit {
109

11-
TEST(CleanupPassTest, Basic) {
10+
void testCleanUpPasses() {
1211
// Tests stability of clean up passes when dealing with constant pooling
1312
// and constant propagation.
14-
auto graph = std::make_shared<Graph>();
15-
parseIR(
16-
R"IR(
13+
{
14+
auto graph = std::make_shared<Graph>();
15+
parseIR(
16+
R"IR(
1717
graph(%cond.1 : Tensor,
1818
%suffix.1 : str):
1919
%3 : bool = aten::Bool(%cond.1) # o.py:6:7
@@ -31,19 +31,20 @@ graph(%cond.1 : Tensor,
3131
-> (%12)
3232
return (%25)
3333
)IR",
34-
&*graph);
35-
runCleanupPasses(graph);
36-
testing::FileCheck()
37-
.check_count(
38-
"prim::Constant[value=\"same string with a twist\"]",
39-
1,
40-
/*exactly=*/true)
41-
->run(*graph);
34+
&*graph);
35+
runCleanupPasses(graph);
36+
testing::FileCheck()
37+
.check_count(
38+
"prim::Constant[value=\"same string with a twist\"]",
39+
1,
40+
/*exactly=*/true)
41+
->run(*graph);
4242

43-
auto graph_after_pass_once = graph->toString();
44-
runCleanupPasses(graph);
45-
auto graph_after_pass_twice = graph->toString();
46-
ASSERT_EQ(graph_after_pass_once, graph_after_pass_twice);
43+
auto graph_after_pass_once = graph->toString();
44+
runCleanupPasses(graph);
45+
auto graph_after_pass_twice = graph->toString();
46+
ASSERT_EQ(graph_after_pass_once, graph_after_pass_twice);
47+
}
4748
}
4849
} // namespace jit
4950
} // namespace torch

test/cpp/jit/test_code_template.cpp

+26-24
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
1-
#include <gtest/gtest.h>
1+
#include "test/cpp/jit/test_base.h"
2+
#include "test/cpp/jit/test_utils.h"
23

3-
#include <test/cpp/jit/test_utils.h>
44
#include "torch/csrc/jit/frontend/code_template.h"
55

66
namespace torch {
@@ -33,29 +33,31 @@ static const auto ct_expect = R"(
3333
int notest(int a)
3434
)";
3535

36-
TEST(TestCodeTemplate, Copying) {
37-
TemplateEnv e;
38-
e.s("hi", "foo");
39-
e.v("what", {"is", "this"});
40-
TemplateEnv c(e);
41-
c.s("hi", "foo2");
42-
ASSERT_EQ(e.s("hi"), "foo");
43-
ASSERT_EQ(c.s("hi"), "foo2");
44-
ASSERT_EQ(e.v("what")[0], "is");
45-
}
36+
void testCodeTemplate() {
37+
{
38+
TemplateEnv e;
39+
e.s("hi", "foo");
40+
e.v("what", {"is", "this"});
41+
TemplateEnv c(e);
42+
c.s("hi", "foo2");
43+
ASSERT_EQ(e.s("hi"), "foo");
44+
ASSERT_EQ(c.s("hi"), "foo2");
45+
ASSERT_EQ(e.v("what")[0], "is");
46+
}
4647

47-
TEST(TestCodeTemplate, Formatting) {
48-
TemplateEnv e;
49-
e.v("args", {"hi", "8"});
50-
e.v("bar", {"what\non many\nlines...", "7"});
51-
e.s("a", "3");
52-
e.s("b", "4");
53-
e.v("stuff", {"things...", "others"});
54-
e.v("empty", {});
55-
auto s = ct.format(e);
56-
// std::cout << "'" << s << "'\n";
57-
// std::cout << "'" << ct_expect << "'\n";
58-
ASSERT_EQ(s, ct_expect);
48+
{
49+
TemplateEnv e;
50+
e.v("args", {"hi", "8"});
51+
e.v("bar", {"what\non many\nlines...", "7"});
52+
e.s("a", "3");
53+
e.s("b", "4");
54+
e.v("stuff", {"things...", "others"});
55+
e.v("empty", {});
56+
auto s = ct.format(e);
57+
// std::cout << "'" << s << "'\n";
58+
// std::cout << "'" << ct_expect << "'\n";
59+
ASSERT_EQ(s, ct_expect);
60+
}
5961
}
6062

6163
} // namespace jit
+43-44
Original file line numberDiff line numberDiff line change
@@ -1,37 +1,36 @@
1-
#include <gtest/gtest.h>
2-
31
#include <torch/csrc/jit/ir/ir.h>
42
#include <torch/csrc/jit/ir/irparser.h>
53
#include <torch/csrc/jit/passes/constant_pooling.h>
64
#include <torch/csrc/jit/passes/constant_propagation.h>
75
#include <torch/csrc/jit/testing/file_check.h>
6+
#include "test/cpp/jit/test_base.h"
87

98
#include <sstream>
109
#include <string>
1110

1211
namespace torch {
1312
namespace jit {
1413

15-
TEST(ConstantPoolingTest, Int) {
16-
auto graph = std::make_shared<Graph>();
17-
parseIR(
18-
R"IR(
14+
void testConstantPooling() {
15+
{
16+
auto graph = std::make_shared<Graph>();
17+
parseIR(
18+
R"IR(
1919
graph():
2020
%8 : int = prim::Constant[value=1]()
2121
%10 : int = prim::Constant[value=1]()
2222
return (%8, %10)
2323
)IR",
24-
&*graph);
25-
ConstantPooling(graph);
26-
testing::FileCheck()
27-
.check_count("prim::Constant", 1, /*exactly*/ true)
28-
->run(*graph);
29-
}
30-
31-
TEST(ConstantPoolingTest, PoolingAcrossBlocks) {
32-
auto graph = std::make_shared<Graph>();
33-
parseIR(
34-
R"IR(
24+
&*graph);
25+
ConstantPooling(graph);
26+
testing::FileCheck()
27+
.check_count("prim::Constant", 1, /*exactly*/ true)
28+
->run(*graph);
29+
}
30+
{
31+
auto graph = std::make_shared<Graph>();
32+
parseIR(
33+
R"IR(
3534
graph(%cond : Tensor):
3635
%a : str = prim::Constant[value="bcd"]()
3736
%3 : bool = aten::Bool(%cond)
@@ -45,18 +44,17 @@ graph(%cond : Tensor):
4544
%7 : (str, str) = prim::TupleConstruct(%a, %b)
4645
return (%7)
4746
)IR",
48-
&*graph);
49-
ConstantPooling(graph);
50-
testing::FileCheck()
51-
.check_count("prim::Constant[value=\"abc\"]", 1, /*exactly*/ true)
52-
->check_count("prim::Constant[value=\"bcd\"]", 1, /*exactly*/ true)
53-
->run(*graph);
54-
}
55-
56-
TEST(ConstantPoolingTest, PoolingDifferentDevices) {
57-
auto graph = std::make_shared<Graph>();
58-
parseIR(
59-
R"IR(
47+
&*graph);
48+
ConstantPooling(graph);
49+
testing::FileCheck()
50+
.check_count("prim::Constant[value=\"abc\"]", 1, /*exactly*/ true)
51+
->check_count("prim::Constant[value=\"bcd\"]", 1, /*exactly*/ true)
52+
->run(*graph);
53+
}
54+
{
55+
auto graph = std::make_shared<Graph>();
56+
parseIR(
57+
R"IR(
6058
graph():
6159
%2 : int = prim::Constant[value=2]()
6260
%1 : int = prim::Constant[value=1]()
@@ -72,21 +70,22 @@ graph():
7270
prim::Print(%x, %y, %z)
7371
return (%1)
7472
)IR",
75-
&*graph);
76-
// three tensors created - two different devices among the three
77-
// don't have good support for parsing tensor constants
78-
ConstantPropagation(graph);
79-
ConstantPooling(graph);
80-
testing::FileCheck()
81-
.check_count(
82-
"Float(2:1, requires_grad=0, device=cpu) = prim::Constant",
83-
1,
84-
/*exactly*/ true)
85-
->check_count(
86-
"Long(2:1, requires_grad=0, device=cpu) = prim::Constant",
87-
1,
88-
/*exactly*/ true)
89-
->run(*graph);
73+
&*graph);
74+
// three tensors created - two different devices among the three
75+
// don't have good support for parsing tensor constants
76+
ConstantPropagation(graph);
77+
ConstantPooling(graph);
78+
testing::FileCheck()
79+
.check_count(
80+
"Float(2:1, requires_grad=0, device=cpu) = prim::Constant",
81+
1,
82+
/*exactly*/ true)
83+
->check_count(
84+
"Long(2:1, requires_grad=0, device=cpu) = prim::Constant",
85+
1,
86+
/*exactly*/ true)
87+
->run(*graph);
88+
}
9089
}
9190
} // namespace jit
9291
} // namespace torch

test/cpp/jit/test_create_autodiff_subgraphs.cpp

+2-3
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,12 @@
1-
#include <gtest/gtest.h>
2-
1+
#include "test/cpp/jit/test_base.h"
32
#include "test/cpp/jit/test_utils.h"
43

54
#include "torch/csrc/jit/passes/create_autodiff_subgraphs.h"
65

76
namespace torch {
87
namespace jit {
98

10-
TEST(CreateAutodiffSubgraphsTest, Basic) {
9+
void testCreateAutodiffSubgraphs() {
1110
auto graph = build_lstm();
1211
CreateAutodiffSubgraphs(graph, /*threshold=*/2);
1312
// all of the ops are within the DifferentiableGraph

test/cpp/jit/test_custom_class.cpp

+1-3
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,3 @@
1-
#include <gtest/gtest.h>
2-
31
#include <torch/custom_class.h>
42
#include <torch/script.h>
53

@@ -320,7 +318,7 @@ TORCH_LIBRARY(_TorchScriptTesting, m) {
320318

321319
} // namespace
322320

323-
TEST(CustomClassTest, TorchbindIValueAPI) {
321+
void testTorchbindIValueAPI() {
324322
script::Module m("m");
325323

326324
// test make_custom_class API

0 commit comments

Comments
 (0)