1
- #include < gtest/gtest.h>
2
-
3
1
#include < torch/csrc/jit/ir/ir.h>
4
2
#include < torch/csrc/jit/ir/irparser.h>
5
3
#include < torch/csrc/jit/passes/constant_pooling.h>
6
4
#include < torch/csrc/jit/passes/constant_propagation.h>
7
5
#include < torch/csrc/jit/testing/file_check.h>
6
+ #include " test/cpp/jit/test_base.h"
8
7
9
8
#include < sstream>
10
9
#include < string>
11
10
12
11
namespace torch {
13
12
namespace jit {
14
13
15
- TEST (ConstantPoolingTest, Int) {
16
- auto graph = std::make_shared<Graph>();
17
- parseIR (
18
- R"IR(
14
+ void testConstantPooling () {
15
+ {
16
+ auto graph = std::make_shared<Graph>();
17
+ parseIR (
18
+ R"IR(
19
19
graph():
20
20
%8 : int = prim::Constant[value=1]()
21
21
%10 : int = prim::Constant[value=1]()
22
22
return (%8, %10)
23
23
)IR" ,
24
- &*graph);
25
- ConstantPooling (graph);
26
- testing::FileCheck ()
27
- .check_count (" prim::Constant" , 1 , /* exactly*/ true )
28
- ->run (*graph);
29
- }
30
-
31
- TEST (ConstantPoolingTest, PoolingAcrossBlocks) {
32
- auto graph = std::make_shared<Graph>();
33
- parseIR (
34
- R"IR(
24
+ &*graph);
25
+ ConstantPooling (graph);
26
+ testing::FileCheck ()
27
+ .check_count (" prim::Constant" , 1 , /* exactly*/ true )
28
+ ->run (*graph);
29
+ }
30
+ {
31
+ auto graph = std::make_shared<Graph>();
32
+ parseIR (
33
+ R"IR(
35
34
graph(%cond : Tensor):
36
35
%a : str = prim::Constant[value="bcd"]()
37
36
%3 : bool = aten::Bool(%cond)
@@ -45,18 +44,17 @@ graph(%cond : Tensor):
45
44
%7 : (str, str) = prim::TupleConstruct(%a, %b)
46
45
return (%7)
47
46
)IR" ,
48
- &*graph);
49
- ConstantPooling (graph);
50
- testing::FileCheck ()
51
- .check_count (" prim::Constant[value=\" abc\" ]" , 1 , /* exactly*/ true )
52
- ->check_count (" prim::Constant[value=\" bcd\" ]" , 1 , /* exactly*/ true )
53
- ->run (*graph);
54
- }
55
-
56
- TEST (ConstantPoolingTest, PoolingDifferentDevices) {
57
- auto graph = std::make_shared<Graph>();
58
- parseIR (
59
- R"IR(
47
+ &*graph);
48
+ ConstantPooling (graph);
49
+ testing::FileCheck ()
50
+ .check_count (" prim::Constant[value=\" abc\" ]" , 1 , /* exactly*/ true )
51
+ ->check_count (" prim::Constant[value=\" bcd\" ]" , 1 , /* exactly*/ true )
52
+ ->run (*graph);
53
+ }
54
+ {
55
+ auto graph = std::make_shared<Graph>();
56
+ parseIR (
57
+ R"IR(
60
58
graph():
61
59
%2 : int = prim::Constant[value=2]()
62
60
%1 : int = prim::Constant[value=1]()
@@ -72,21 +70,22 @@ graph():
72
70
prim::Print(%x, %y, %z)
73
71
return (%1)
74
72
)IR" ,
75
- &*graph);
76
- // three tensors created - two different devices among the three
77
- // don't have good support for parsing tensor constants
78
- ConstantPropagation (graph);
79
- ConstantPooling (graph);
80
- testing::FileCheck ()
81
- .check_count (
82
- " Float(2:1, requires_grad=0, device=cpu) = prim::Constant" ,
83
- 1 ,
84
- /* exactly*/ true )
85
- ->check_count (
86
- " Long(2:1, requires_grad=0, device=cpu) = prim::Constant" ,
87
- 1 ,
88
- /* exactly*/ true )
89
- ->run (*graph);
73
+ &*graph);
74
+ // three tensors created - two different devices among the three
75
+ // don't have good support for parsing tensor constants
76
+ ConstantPropagation (graph);
77
+ ConstantPooling (graph);
78
+ testing::FileCheck ()
79
+ .check_count (
80
+ " Float(2:1, requires_grad=0, device=cpu) = prim::Constant" ,
81
+ 1 ,
82
+ /* exactly*/ true )
83
+ ->check_count (
84
+ " Long(2:1, requires_grad=0, device=cpu) = prim::Constant" ,
85
+ 1 ,
86
+ /* exactly*/ true )
87
+ ->run (*graph);
88
+ }
90
89
}
91
90
} // namespace jit
92
91
} // namespace torch
0 commit comments