 import torch
 import torch.onnx
-from torch.onnx import utils, OperatorExportTypes
+from torch.onnx import utils, OperatorExportTypes, TrainingMode
 from torch.onnx.symbolic_helper import _set_opset_version, _set_operator_export_type
 import torch.utils.cpp_extension
 from test_pytorch_common import skipIfUnsupportedMinOpsetVersion

+import torchvision
+
 import onnx
 import onnxruntime  # noqa

@@ -675,6 +677,109 @@ def forward(self, x):

         np.testing.assert_allclose(ratio_pytorch, ratio_ort, rtol=0.01, atol=0.01)

+    def test_fuse_conv_bn(self):
+        class Fuse(torch.nn.Module):
+            def __init__(self):
+                super(Fuse, self).__init__()
+                self.conv = torch.nn.Conv2d(3, 2, kernel_size=1, stride=2, padding=3, bias=True)
+                self.bn = torch.nn.BatchNorm2d(2)
+
+            def forward(self, x):
+                out = self.conv(x)
+                return self.bn(out)
+
+        x = torch.randn(2, 3, 2, 2, requires_grad=True)
+        graph, _, __ = utils._model_to_graph(Fuse(), (x, ),
+                                             do_constant_folding=True,
+                                             training=TrainingMode.EVAL)
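+        # With constant folding enabled and the model exported in eval mode,
+        # the BatchNorm should be folded into the preceding Conv, so the
+        # resulting graph contains a single onnx::Conv node.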
+        for node in graph.nodes():
+            assert node.kind() != "onnx::BatchNormalization"
+            assert node.kind() == "onnx::Conv"
+
+        assert len(list(graph.nodes())) == 1
+
+    def test_fuse_resnet18(self):
+        model = torchvision.models.resnet18(pretrained=True)
+        x = torch.randn(2, 3, 224, 224, requires_grad=True)
+        graph, _, __ = utils._model_to_graph(model, (x, ),
+                                             do_constant_folding=True)
+
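+        # Every Conv + BatchNorm pair in ResNet-18 should have been fused,
+        # so no BatchNormalization nodes remain in the exported graph.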
+        for node in graph.nodes():
+            assert node.kind() != "onnx::BatchNormalization"
+
+    def test_conv_bn(self):
+        class MyModule(torch.nn.Module):
+            def __init__(self):
+                super(MyModule, self).__init__()
+                self.conv = torch.nn.Conv2d(3, 16, kernel_size=1, stride=2, padding=3, bias=True)
+                self.bn = torch.nn.BatchNorm2d(16, affine=True)
+
+            def forward(self, x):
+                x = self.conv(x)
+                bn = self.bn(x)
+                return bn
+
+        model = MyModule()
+        x = torch.randn(10, 3, 128, 128)
+
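+        # Export the same model twice: in training mode the BatchNorm is kept
+        # as a standalone node, in eval mode it is folded into the Conv. Both
+        # exported graphs should produce matching outputs in ONNX Runtime.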
+        f = io.BytesIO()
+        torch.onnx.export(model, (x,), f,
+                          opset_version=self.opset_version, training=torch.onnx.TrainingMode.TRAINING)
+        ort_sess = onnxruntime.InferenceSession(f.getvalue())
+        ort_inputs = {ort_sess.get_inputs()[0].name: x.cpu().numpy()}
+        ort_outs1 = ort_sess.run(None, ort_inputs)
+
+        f = io.BytesIO()
+        torch.onnx.export(model, (x,), f,
+                          opset_version=self.opset_version, training=torch.onnx.TrainingMode.EVAL)
+        ort_sess = onnxruntime.InferenceSession(f.getvalue())
+        ort_inputs = {ort_sess.get_inputs()[0].name: x.cpu().numpy()}
+        ort_outs2 = ort_sess.run(None, ort_inputs)
+        for ort_out1, ort_out2 in zip(ort_outs1, ort_outs2):
+            np.testing.assert_allclose(ort_out1, ort_out2, atol=1e-7, rtol=0.001)
+
+    def test_multiple_conv_bn(self):
+        class MyModule(torch.nn.Module):
+            def __init__(self):
+                super(MyModule, self).__init__()
+                self.conv1 = torch.nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
+                self.conv2 = torch.nn.Conv2d(64, 2, kernel_size=1, stride=1, padding=0, bias=False)
+                self.conv3 = torch.nn.Conv2d(2, 2, kernel_size=3, stride=1, padding=1, bias=False)
+                self.bn = torch.nn.BatchNorm2d(64)
+                self.bn2 = torch.nn.BatchNorm2d(2)
+                self.relu = torch.nn.ReLU(inplace=True)
+                self.maxpool = torch.nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
+
+            def forward(self, x):
+                x = self.conv1(x)
+                x = self.bn(x)
+                x = self.relu(x)
+                x = self.maxpool(x)
+                x = self.conv2(x)
+                x = self.bn2(x)
+                x = self.relu(x)
+                x = self.conv3(x)
+                x = self.bn2(x)
+                x = self.relu(x)
+                return x
+
+        model = MyModule()
+        x = torch.randn(2, 3, 224, 224)
+
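+        # As above, the training-mode export (BatchNorm nodes preserved) and the
+        # eval-mode export (Conv-BatchNorm pairs fused) should agree in ONNX Runtime.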
+        f = io.BytesIO()
+        torch.onnx.export(model, (x,), f,
+                          opset_version=self.opset_version, training=torch.onnx.TrainingMode.TRAINING)
+        ort_sess = onnxruntime.InferenceSession(f.getvalue())
+        ort_inputs = {ort_sess.get_inputs()[0].name: x.cpu().numpy()}
+        ort_outs1 = ort_sess.run(None, ort_inputs)
+
+        f = io.BytesIO()
+        torch.onnx.export(model, (x,), f,
+                          opset_version=self.opset_version, training=torch.onnx.TrainingMode.EVAL)
+        ort_sess = onnxruntime.InferenceSession(f.getvalue())
+        ort_inputs = {ort_sess.get_inputs()[0].name: x.cpu().numpy()}
+        ort_outs2 = ort_sess.run(None, ort_inputs)
+        for ort_out1, ort_out2 in zip(ort_outs1, ort_outs2):
+            np.testing.assert_allclose(ort_out1, ort_out2, atol=1e-7, rtol=0.001)
+

 # opset 10 tests
 TestUtilityFuns_opset10 = type(str("TestUtilityFuns_opset10"),