Skip to content

Commit d1b920b

Browse files
harouwufacebook-github-bot
authored andcommitted
keep net type info when generating model complete net (pytorch#11032)
Summary: keep net type info when generating model complete net. This will keep the performance optimization option Pull Request resolved: pytorch#11032 Reviewed By: wat3rBro Differential Revision: D9564125 Pulled By: harouwu fbshipit-source-id: c6546af9b1d4ff5eddf6124e24a5da1b8baf47df
1 parent 56bdd87 commit d1b920b

File tree

2 files changed

+22
-0
lines changed

2 files changed

+22
-0
lines changed

caffe2/python/model_helper.py

+2
Original file line numberDiff line numberDiff line change
@@ -465,6 +465,8 @@ def GetCompleteNet(self):
465465
for op in new_net.Proto().op:
466466
op.debug_info = op.debug_info + "/param_init_net"
467467
new_net.AppendNet(self.net)
468+
# keep the execution optimization
469+
new_net.Proto().type = self.net.Proto().type
468470
return new_net
469471

470472
def ConstructInitTrainNetfromNet(self, net):

caffe2/python/model_helper_test.py

+20
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,26 @@
88

99

1010
class ModelHelperTest(unittest.TestCase):
11+
def test_get_complete_net_type(self):
12+
model = model_helper.ModelHelper("test_orig")
13+
brew.conv(
14+
model,
15+
"input",
16+
"conv",
17+
dim_in=3,
18+
dim_out=16,
19+
weight_init=("MSRAFill", {}),
20+
kernel=3,
21+
stride=1,
22+
pad=0,
23+
)
24+
model.net.Proto().type = "async_scheduling"
25+
net = model.GetCompleteNet()
26+
model2 = model_helper.ModelHelper("test_new")
27+
model2.ConstructInitTrainNetfromNet(net)
28+
self.assertTrue(model2.net.Proto().type, "async_scheduling")
29+
self.assertTrue(model2.param_init_net.Proto().type, "async_scheduling")
30+
1131
def test_get_complete_net(self):
1232
model = model_helper.ModelHelper("test_orig")
1333
conv = brew.conv(

0 commit comments

Comments
 (0)