Commit af5d0bf

KsenijaS authored and facebook-github-bot committed
[ONNX] Add pass that fuses Conv and BatchNormalization (pytorch#40547)
Summary: Add a pass that fuses Conv and BatchNormalization nodes into a single Conv node. The pass is applied only in inference mode (training is None or TrainingMode.EVAL). Because it needs access to param_dict, it lives outside the peephole file where passes of this kind (fusing multiple nodes into one) are usually placed.

This PR also adds a skipIfNoEmbed wrapper to skip the debug_embed_params test: the fusion pass changes the parameters of the ResNet model, so the parameters of the ONNX and PyTorch models no longer match. Since the parameters do not match, the debug_embed_params test for test_resnet is expected to fail and should therefore be skipped.

Pull Request resolved: pytorch#40547
Reviewed By: gchanan
Differential Revision: D22631687
Pulled By: bzinodev
fbshipit-source-id: fe45812400398a32541e797f727fd8697eb6d8c0
1 parent ad7133d commit af5d0bf
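
The fusion itself is closed-form algebra: BatchNorm computes y = gamma * (x - mean) / sqrt(var + eps) + beta, so applying it to a convolution output W*x + b is equivalent to a single convolution with a per-channel rescaled weight and a shifted bias. A minimal PyTorch sketch of that fold (the helper name and signature are illustrative, not part of this commit):

    import torch

    def fuse_conv_bn_params(w_conv, b_conv, gamma, beta, mean, var, eps):
        # Per-output-channel scale factor: gamma / sqrt(var + eps)
        scale = gamma / torch.sqrt(var + eps)
        # Rescale each output-channel slice of the conv weight
        w_fused = w_conv * scale.reshape([-1] + [1] * (w_conv.dim() - 1))
        # Fold the conv bias (zero when absent) through the BatchNorm affine transform
        if b_conv is None:
            b_conv = torch.zeros_like(mean)
        b_fused = (b_conv - mean) * scale + beta
        return w_fused, b_fused

Because the multiply and shift are absorbed into the Conv parameters, the eval-mode export can drop the BatchNormalization node entirely; it is also why the fused model's parameters no longer match the original PyTorch state dict, which is what debug_embed_params would compare.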

File tree

9 files changed, +346 -7 lines changed


aten/src/ATen/core/interned_strings.h

+2

@@ -239,6 +239,8 @@ namespace c10 {
   _(onnx, LogSoftmax) \
   _(onnx, ReduceL1) \
   _(onnx, ReduceL2) \
+  _(onnx, Conv) \
+  _(onnx, BatchNormalization) \
   FORALL_ATTR_BASE_SYMBOLS(_) \
   _(attr, Subgraph) \
   _(attr, ReverseSubgraph) \

test/onnx/test_pytorch_onnx_caffe2.py

+8

@@ -54,6 +54,13 @@ def wrapper(self):
         return func(self)
     return wrapper

+def skipIfNoEmbed(func):
+    def wrapper(self):
+        if not self.embed_params:
+            raise unittest.SkipTest("Skip debug embed_params test")
+        return func(self)
+    return wrapper
+
 # def import_model(proto, input, workspace=None, use_gpu=True):
 #     model_def = onnx.ModelProto.FromString(proto)
 #     onnx.checker.check_model(model_def)
@@ -504,6 +511,7 @@ def test_inception(self):
         self.run_model_test(inception_v3(), train=False, batch_size=BATCH_SIZE,
                             state_dict=state_dict, input=x)

+    @skipIfNoEmbed
     def test_resnet(self):
         state_dict = model_zoo.load_url(model_urls['resnet50'], progress=False)
         self.run_model_test(resnet50(), train=False, batch_size=BATCH_SIZE,

test/onnx/test_pytorch_onnx_onnxruntime.py

+45

@@ -300,6 +300,51 @@ def test_r2plus1d_18_video(self):
         x = torch.randn(1, 3, 4, 112, 112, requires_grad=True)
         self.run_test(model, (x,), rtol=1e-3, atol=1e-5)

+    def test_fuse_conv_bn1d(self):
+        class Fuse(torch.nn.Module):
+            def __init__(self):
+                super(Fuse, self).__init__()
+                self.conv = torch.nn.Conv1d(16, 33, 3, stride=2)
+                self.bn = torch.nn.BatchNorm1d(33)
+
+            def forward(self, x):
+                out = self.conv(x)
+                return self.bn(out)
+
+        model = Fuse()
+        x = torch.randn(20, 16, 50, requires_grad=True)
+        self.run_test(model, (x,))
+
+    def test_fuse_conv_bn2d(self):
+        class Fuse(torch.nn.Module):
+            def __init__(self):
+                super(Fuse, self).__init__()
+                self.conv = torch.nn.Conv2d(3, 2, kernel_size=1, stride=2, padding=3, bias=False)
+                self.bn = torch.nn.BatchNorm2d(2)
+
+            def forward(self, x):
+                out = self.conv(x)
+                return self.bn(out)
+
+        model = Fuse()
+        x = torch.randn(2, 3, 2, 2, requires_grad=True)
+        self.run_test(model, (x,))
+
+    def test_fuse_conv_bn3d(self):
+        class Fuse(torch.nn.Module):
+            def __init__(self):
+                super(Fuse, self).__init__()
+                self.conv = torch.nn.Conv3d(3, 2, (3, 5, 2), stride=(2, 1, 1), padding=(3, 2, 0), bias=False)
+                self.bn = torch.nn.BatchNorm3d(2)
+
+            def forward(self, x):
+                out = self.conv(x)
+                return self.bn(out)
+
+        model = Fuse()
+        x = torch.randn(2, 3, 10, 50, 100, requires_grad=True)
+        self.run_test(model, (x,), rtol=1e-3, atol=1e-6)
+
     def test_reshape_constant_fold(self):
         class Reshape(torch.nn.Module):
             def __init__(self, ):

test/onnx/test_utility_funs.py

+106 -1

@@ -3,11 +3,13 @@

 import torch
 import torch.onnx
-from torch.onnx import utils, OperatorExportTypes
+from torch.onnx import utils, OperatorExportTypes, TrainingMode
 from torch.onnx.symbolic_helper import _set_opset_version, _set_operator_export_type
 import torch.utils.cpp_extension
 from test_pytorch_common import skipIfUnsupportedMinOpsetVersion

+import torchvision
+
 import onnx
 import onnxruntime  # noqa

@@ -675,6 +677,109 @@ def forward(self, x):

         np.testing.assert_allclose(ratio_pytorch, ratio_ort, rtol=0.01, atol=0.01)

+    def test_fuse_conv_bn(self):
+        class Fuse(torch.nn.Module):
+            def __init__(self):
+                super(Fuse, self).__init__()
+                self.conv = torch.nn.Conv2d(3, 2, kernel_size=1, stride=2, padding=3, bias=True)
+                self.bn = torch.nn.BatchNorm2d(2)
+
+            def forward(self, x):
+                out = self.conv(x)
+                return self.bn(out)
+
+        x = torch.randn(2, 3, 2, 2, requires_grad=True)
+        graph, _, __ = utils._model_to_graph(Fuse(), (x, ),
+                                             do_constant_folding=True,
+                                             training=TrainingMode.EVAL)
+        for node in graph.nodes():
+            assert node.kind() != "onnx::BatchNormalization"
+            assert node.kind() == "onnx::Conv"
+
+        assert len(list(graph.nodes())) == 1
+
+    def test_fuse_resnet18(self):
+        model = torchvision.models.resnet18(pretrained=True)
+        x = torch.randn(2, 3, 224, 224, requires_grad=True)
+        graph, _, __ = utils._model_to_graph(model, (x, ),
+                                             do_constant_folding=True)
+
+        for node in graph.nodes():
+            assert node.kind() != "onnx::BatchNormalization"
+
+    def test_conv_bn(self):
+        class MyModule(torch.nn.Module):
+            def __init__(self):
+                super(MyModule, self).__init__()
+                self.conv = torch.nn.Conv2d(3, 16, kernel_size=1, stride=2, padding=3, bias=True)
+                self.bn = torch.nn.BatchNorm2d(16, affine=True)
+
+            def forward(self, x):
+                x = self.conv(x)
+                bn = self.bn(x)
+                return bn
+
+        model = MyModule()
+        x = torch.randn(10, 3, 128, 128)
+
+        f = io.BytesIO()
+        torch.onnx.export(model, (x,), f,
+                          opset_version=self.opset_version, training=torch.onnx.TrainingMode.TRAINING)
+        ort_sess = onnxruntime.InferenceSession(f.getvalue())
+        ort_inputs = {ort_sess.get_inputs()[0].name: x.cpu().numpy()}
+        ort_outs1 = ort_sess.run(None, ort_inputs)
+
+        f = io.BytesIO()
+        torch.onnx.export(model, (x,), f,
+                          opset_version=self.opset_version, training=torch.onnx.TrainingMode.EVAL)
+        ort_sess = onnxruntime.InferenceSession(f.getvalue())
+        ort_inputs = {ort_sess.get_inputs()[0].name: x.cpu().numpy()}
+        ort_outs2 = ort_sess.run(None, ort_inputs)
+        [np.testing.assert_allclose(ort_out1, ort_out2, atol=1e-7, rtol=0.001) for ort_out1, ort_out2 in zip(ort_outs1, ort_outs2)]
+
+    def test_multiple_conv_bn(self):
+        class MyModule(torch.nn.Module):
+            def __init__(self):
+                super(MyModule, self).__init__()
+                self.conv1 = torch.nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
+                self.conv2 = torch.nn.Conv2d(64, 2, kernel_size=1, stride=1, padding=0, bias=False)
+                self.conv3 = torch.nn.Conv2d(2, 2, kernel_size=3, stride=1, padding=1, bias=False)
+                self.bn = torch.nn.BatchNorm2d(64)
+                self.bn2 = torch.nn.BatchNorm2d(2)
+                self.relu = torch.nn.ReLU(inplace=True)
+                self.maxpool = torch.nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
+
+            def forward(self, x):
+                x = self.conv1(x)
+                x = self.bn(x)
+                x = self.relu(x)
+                x = self.maxpool(x)
+                x = self.conv2(x)
+                x = self.bn2(x)
+                x = self.relu(x)
+                x = self.conv3(x)
+                x = self.bn2(x)
+                x = self.relu(x)
+                return x
+
+        model = MyModule()
+        x = torch.randn(2, 3, 224, 224)
+
+        f = io.BytesIO()
+        torch.onnx.export(model, (x,), f,
+                          opset_version=self.opset_version, training=torch.onnx.TrainingMode.TRAINING)
+        ort_sess = onnxruntime.InferenceSession(f.getvalue())
+        ort_inputs = {ort_sess.get_inputs()[0].name: x.cpu().numpy()}
+        ort_outs1 = ort_sess.run(None, ort_inputs)
+        f = io.BytesIO()
+        torch.onnx.export(model, (x,), f,
+                          opset_version=self.opset_version, training=torch.onnx.TrainingMode.EVAL)
+        ort_sess = onnxruntime.InferenceSession(f.getvalue())
+        ort_inputs = {ort_sess.get_inputs()[0].name: x.cpu().numpy()}
+        ort_outs2 = ort_sess.run(None, ort_inputs)
+        [np.testing.assert_allclose(ort_out1, ort_out2, atol=1e-7, rtol=0.001) for ort_out1, ort_out2 in zip(ort_outs1, ort_outs2)]

 # opset 10 tests
 TestUtilityFuns_opset10 = type(str("TestUtilityFuns_opset10"),
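
The graph-level asserts above can also be reproduced on a serialized export; a small sketch, assuming model and x as in the tests in this file:

    import io
    import onnx
    import torch

    f = io.BytesIO()
    torch.onnx.export(model, (x,), f, training=torch.onnx.TrainingMode.EVAL)
    onnx_model = onnx.load_from_string(f.getvalue())
    # After fusion, no BatchNormalization ops should remain in an eval-mode export
    assert all(node.op_type != "BatchNormalization" for node in onnx_model.graph.node)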

tools/build_variables.bzl

+1

@@ -482,6 +482,7 @@ libtorch_python_core_sources = [
     "torch/csrc/jit/python/init.cpp",
     "torch/csrc/jit/passes/onnx.cpp",
     "torch/csrc/jit/passes/onnx/cast_all_constant_to_floating.cpp",
+    "torch/csrc/jit/passes/onnx/eval_peephole.cpp",
     "torch/csrc/jit/passes/onnx/constant_fold.cpp",
    "torch/csrc/jit/passes/onnx/fixup_onnx_conditionals.cpp",
     "torch/csrc/jit/passes/onnx/fixup_onnx_loop.cpp",
torch/csrc/jit/passes/onnx/eval_peephole.cpp

+155

@@ -0,0 +1,155 @@
+#include <torch/csrc/jit/passes/onnx/eval_peephole.h>
+#include <torch/csrc/jit/passes/onnx/helper.h>
+#include <torch/torch.h>
+
+#include <c10/util/Optional.h>
+#include <algorithm>
+
+namespace torch {
+namespace jit {
+
+namespace onnx {
+using namespace ::c10::onnx;
+}
+
+// Collects the tensor inputs of a node that are either graph parameters
+// or onnx::Constant values; other inputs are skipped.
+std::vector<at::Tensor> getValues(
+    Node* node,
+    const ValueToParamPairMap& valsToParamsMap) {
+  size_t numInputs = node->inputs().size();
+  std::vector<at::Tensor> inputTensorValues;
+  inputTensorValues.reserve(numInputs);
+  for (auto val : node->inputs()) {
+    if (val->node()->kind() == prim::Param) {
+      auto itr = valsToParamsMap.find(val);
+      if (itr == valsToParamsMap.end()) {
+        continue;
+      }
+      inputTensorValues.push_back(itr->second.second.toTensor());
+    } else if (val->node()->kind() == onnx::Constant) {
+      inputTensorValues.push_back(val->node()->t(attr::value));
+    } else {
+      continue;
+    }
+  }
+  return inputTensorValues;
+}
+
+// This pass fuses Conv and BatchNorm into a single Conv node.
+// The two nodes can be fused only if the BatchNorm inputs
+// (scale, bias, mean, and var) are all tensors of the same shape (C),
+// and the size of the first dimension (dim 0) of the Conv weight
+// matches the size of the BatchNorm scale.
+static void fuseConvBatchNorm(Block* b, ValueToParamPairMap& valsToParamsMap) {
+  for (auto it = b->nodes().begin(), end = b->nodes().end(); it != end; ++it) {
+    for (auto* child_block : it->blocks()) {
+      fuseConvBatchNorm(child_block, valsToParamsMap);
+    }
+    if (it->kind() == onnx::Conv) {
+      if (it->output()->uses().size() != 1) {
+        continue;
+      }
+      auto bnNode = it->output()->uses()[0].user;
+      if (bnNode->kind() != onnx::BatchNormalization) {
+        continue;
+      }
+      auto origconvNode = *it;
+      auto epsilon = bnNode->f(attr::epsilon);
+      auto w_conv_value = getValues(origconvNode, valsToParamsMap);
+      if (w_conv_value.size() < 1 ||
+          (origconvNode->inputs().size() == 3 && w_conv_value.size() != 2)) {
+        continue;
+      }
+
+      auto bn_value = getValues(bnNode, valsToParamsMap);
+      if (bn_value.size() != 4) {
+        continue;
+      }
+
+      auto bn_scale = bn_value[0].clone();
+      auto bn_B = bn_value[1].clone();
+      auto bn_mean = bn_value[2].clone();
+      auto bn_var = bn_value[3].clone();
+      auto w_conv = w_conv_value[0].clone();
+      at::Tensor b_conv;
+
+      if (!bn_scale.is_floating_point() || !bn_B.is_floating_point() ||
+          !bn_mean.is_floating_point() || !bn_var.is_floating_point() ||
+          !w_conv.is_floating_point() || bn_scale.dim() != 1 ||
+          bn_B.dim() != 1 || bn_mean.dim() != 1 || bn_var.dim() != 1 ||
+          !(bn_scale.size(0) == bn_B.size(0)) ||
+          !(bn_B.size(0) == bn_mean.size(0)) ||
+          !(bn_mean.size(0) == bn_var.size(0)) || !(w_conv.dim() > 2) ||
+          !(w_conv.size(0) == bn_scale.size(0))) {
+        continue;
+      }
+
+      bn_var = bn_var.add(epsilon);
+      bn_var = bn_var.sqrt();
+      bn_scale = bn_scale.div(bn_var);
+
+      // Calculate weight
+      for (size_t i = 0; i < w_conv.size(0); i++) {
+        w_conv[i] = w_conv[i].mul(bn_scale[i]);
+      }
+
+      // Calculate bias
+      if (origconvNode->inputs().size() == 3) {
+        b_conv = w_conv_value[1].clone();
+        b_conv = b_conv.sub(bn_mean);
+        b_conv = b_conv.mul(bn_scale);
+        b_conv = b_conv.add(bn_B);
+      } else {
+        bn_mean = bn_mean.mul(bn_scale);
+        bn_B = bn_B.sub(bn_mean);
+        b_conv = bn_B;
+      }
+
+      Node* convNode =
+          b->owningGraph()->create(onnx::Conv, bnNode->outputs().size());
+      for (size_t i = 0; i < convNode->outputs().size(); ++i) {
+        convNode->outputs()[i]->copyMetadata(bnNode->outputs()[i]);
+      }
+
+      convNode->copyAttributes(*origconvNode);
+      convNode->insertBefore(bnNode);
+      convNode->addInput(origconvNode->inputs().at(0));
+
+      auto conv_W = b->owningGraph()->addInput();
+      valsToParamsMap.insert(
+          {conv_W, std::make_pair(conv_W->debugName(), w_conv)});
+      conv_W->inferTypeFrom(w_conv);
+      convNode->addInput(conv_W);
+
+      auto conv_B = b->addInput();
+      valsToParamsMap.insert(
+          {conv_B, std::make_pair(conv_B->debugName(), b_conv)});
+      conv_B->inferTypeFrom(b_conv);
+      convNode->addInput(conv_B);
+
+      bnNode->replaceAllUsesWith(convNode);
+      bnNode->removeAllInputs();
+      it->removeAllInputs();
+      bnNode->destroy();
+      it.destroyCurrent();
+    }
+  }
+}
+
+void buildParamsMapFromValueToParamsMap(
+    const ValueToParamPairMap& valsToParamsMap,
+    ParamMap& paramsDict) {
+  paramsDict.clear();
+  for (const auto& nameTensorParamPair : valsToParamsMap) {
+    paramsDict.insert(nameTensorParamPair.second);
+  }
+}
+
+void EvalPeepholeONNX(Block* b, ParamMap& paramsDict) {
+  auto valsToParamsMap = buildValueToParamsMap(b, paramsDict);
+  fuseConvBatchNorm(b, valsToParamsMap);
+  buildParamsMapFromValueToParamsMap(valsToParamsMap, paramsDict);
+}
+
+} // namespace jit
+} // namespace torch
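
For reference, the weight and bias updates in fuseConvBatchNorm amount to the following per-channel algebra, written with the variable names used in the pass:

    bn_scale' = bn_scale / sqrt(bn_var + epsilon)
    w_conv'   = w_conv * bn_scale'                         (along dim 0 of w_conv)
    b_conv'   = (b_conv - bn_mean) * bn_scale' + bn_B      (Conv with a bias input)
    b_conv'   = bn_B - bn_mean * bn_scale'                 (Conv without a bias input)

The two bias branches are the same closed form; the second is the first with b_conv = 0.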
torch/csrc/jit/passes/onnx/eval_peephole.h

+12

@@ -0,0 +1,12 @@
+#pragma once
+
+#include <torch/csrc/jit/ir/ir.h>
+
+namespace torch {
+namespace jit {
+
+void EvalPeepholeONNX(Block* b, std::map<std::string, IValue>& paramDict);
+
+} // namespace jit
+
+} // namespace torch
