Commit 596b58e
Add support for transposed grouped convolution in torch to linalg lowering (#4056)
The conversion of the convolutiong torch operation to linalg currently
works for grouped convolution (number of groups > 1) and transposed
convolution, but the conversion failed when both are used at the same
time. This change set correct this.
The core of the changes are in the Linear.cpp. In transposed grouped
convolution, the output filters is the one divided by the groups in the
weights, not the input channel (see the "Variables" section in both
links below for details). This was one of the fixes.
https://pytorch.org/docs/stable/generated/torch.nn.Conv2d.html
https://pytorch.org/docs/stable/generated/torch.nn.ConvTranspose2d.html
The other issue was that the weights expansion had to happen before the
Channel/Filter dimension permutation/flip. This is because the expansion
deals with adjacent dimensions, but in the final weights tensor the
group and the input channel are not going to be adjacent. Once the
dimensions are flipped, the expansion operation can't generate the
expected dimension format. See the comment in the code for details.
@rsuderman
@vivekkhandelwal1
@zjgarvey
@penguin-wwy
@ubfx
@sahas3
@dixinzhou
@rafaelubalmw
---------
Co-authored-by: Ivan Garcia <[email protected]>1 parent 96da98b commit 596b58e
File tree
4 files changed
+185
-31
lines changed- lib/Conversion/TorchToLinalg
- projects/pt1
- e2e_testing
- python/torch_mlir_e2e_test/test_suite
- test/Conversion/TorchToLinalg
4 files changed
+185
-31
lines changed| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
955 | 955 | | |
956 | 956 | | |
957 | 957 | | |
| 958 | + | |
| 959 | + | |
| 960 | + | |
| 961 | + | |
| 962 | + | |
| 963 | + | |
| 964 | + | |
| 965 | + | |
| 966 | + | |
| 967 | + | |
| 968 | + | |
| 969 | + | |
| 970 | + | |
| 971 | + | |
| 972 | + | |
| 973 | + | |
| 974 | + | |
| 975 | + | |
| 976 | + | |
| 977 | + | |
| 978 | + | |
| 979 | + | |
| 980 | + | |
| 981 | + | |
| 982 | + | |
| 983 | + | |
| 984 | + | |
| 985 | + | |
| 986 | + | |
| 987 | + | |
| 988 | + | |
| 989 | + | |
| 990 | + | |
| 991 | + | |
| 992 | + | |
| 993 | + | |
| 994 | + | |
| 995 | + | |
| 996 | + | |
| 997 | + | |
| 998 | + | |
| 999 | + | |
| 1000 | + | |
| 1001 | + | |
| 1002 | + | |
958 | 1003 | | |
| 1004 | + | |
| 1005 | + | |
| 1006 | + | |
959 | 1007 | | |
960 | 1008 | | |
961 | 1009 | | |
| |||
965 | 1013 | | |
966 | 1014 | | |
967 | 1015 | | |
968 | | - | |
969 | | - | |
| 1016 | + | |
| 1017 | + | |
| 1018 | + | |
| 1019 | + | |
| 1020 | + | |
| 1021 | + | |
| 1022 | + | |
| 1023 | + | |
| 1024 | + | |
| 1025 | + | |
| 1026 | + | |
| 1027 | + | |
| 1028 | + | |
970 | 1029 | | |
971 | 1030 | | |
972 | 1031 | | |
973 | | - | |
| 1032 | + | |
974 | 1033 | | |
975 | | - | |
| 1034 | + | |
976 | 1035 | | |
977 | 1036 | | |
978 | 1037 | | |
979 | 1038 | | |
980 | 1039 | | |
981 | 1040 | | |
982 | | - | |
| 1041 | + | |
983 | 1042 | | |
984 | | - | |
985 | | - | |
986 | | - | |
| 1043 | + | |
| 1044 | + | |
| 1045 | + | |
| 1046 | + | |
| 1047 | + | |
| 1048 | + | |
| 1049 | + | |
987 | 1050 | | |
988 | 1051 | | |
989 | 1052 | | |
| |||
1373 | 1436 | | |
1374 | 1437 | | |
1375 | 1438 | | |
1376 | | - | |
1377 | | - | |
1378 | | - | |
1379 | | - | |
1380 | | - | |
1381 | | - | |
1382 | | - | |
1383 | | - | |
1384 | | - | |
1385 | | - | |
1386 | | - | |
1387 | | - | |
1388 | | - | |
1389 | | - | |
1390 | | - | |
1391 | | - | |
1392 | | - | |
1393 | | - | |
1394 | | - | |
1395 | 1439 | | |
1396 | | - | |
| 1440 | + | |
| 1441 | + | |
| 1442 | + | |
| 1443 | + | |
1397 | 1444 | | |
1398 | 1445 | | |
1399 | 1446 | | |
1400 | 1447 | | |
1401 | 1448 | | |
1402 | 1449 | | |
1403 | 1450 | | |
1404 | | - | |
| 1451 | + | |
1405 | 1452 | | |
1406 | 1453 | | |
1407 | 1454 | | |
1408 | 1455 | | |
1409 | 1456 | | |
1410 | 1457 | | |
1411 | | - | |
1412 | | - | |
| 1458 | + | |
1413 | 1459 | | |
1414 | 1460 | | |
1415 | 1461 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
3537 | 3537 | | |
3538 | 3538 | | |
3539 | 3539 | | |
| 3540 | + | |
3540 | 3541 | | |
3541 | 3542 | | |
3542 | 3543 | | |
| |||
4113 | 4114 | | |
4114 | 4115 | | |
4115 | 4116 | | |
| 4117 | + | |
4116 | 4118 | | |
4117 | 4119 | | |
4118 | 4120 | | |
| |||
Lines changed: 32 additions & 0 deletions
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
1725 | 1725 | | |
1726 | 1726 | | |
1727 | 1727 | | |
| 1728 | + | |
| 1729 | + | |
| 1730 | + | |
| 1731 | + | |
| 1732 | + | |
| 1733 | + | |
| 1734 | + | |
| 1735 | + | |
| 1736 | + | |
| 1737 | + | |
| 1738 | + | |
| 1739 | + | |
| 1740 | + | |
| 1741 | + | |
| 1742 | + | |
| 1743 | + | |
| 1744 | + | |
| 1745 | + | |
| 1746 | + | |
| 1747 | + | |
| 1748 | + | |
| 1749 | + | |
| 1750 | + | |
| 1751 | + | |
| 1752 | + | |
| 1753 | + | |
| 1754 | + | |
| 1755 | + | |
| 1756 | + | |
| 1757 | + | |
| 1758 | + | |
| 1759 | + | |
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
76 | 76 | | |
77 | 77 | | |
78 | 78 | | |
| 79 | + | |
| 80 | + | |
| 81 | + | |
| 82 | + | |
| 83 | + | |
| 84 | + | |
| 85 | + | |
| 86 | + | |
| 87 | + | |
| 88 | + | |
| 89 | + | |
| 90 | + | |
| 91 | + | |
| 92 | + | |
| 93 | + | |
| 94 | + | |
| 95 | + | |
| 96 | + | |
| 97 | + | |
| 98 | + | |
| 99 | + | |
| 100 | + | |
| 101 | + | |
| 102 | + | |
| 103 | + | |
| 104 | + | |
| 105 | + | |
| 106 | + | |
| 107 | + | |
| 108 | + | |
| 109 | + | |
| 110 | + | |
| 111 | + | |
| 112 | + | |
| 113 | + | |
| 114 | + | |
| 115 | + | |
| 116 | + | |
| 117 | + | |
| 118 | + | |
| 119 | + | |
| 120 | + | |
| 121 | + | |
| 122 | + | |
| 123 | + | |
| 124 | + | |
| 125 | + | |
| 126 | + | |
| 127 | + | |
| 128 | + | |
| 129 | + | |
| 130 | + | |
| 131 | + | |
| 132 | + | |
| 133 | + | |
| 134 | + | |
| 135 | + | |
| 136 | + | |
| 137 | + | |
| 138 | + | |
| 139 | + | |
| 140 | + | |
| 141 | + | |
| 142 | + | |
| 143 | + | |
| 144 | + | |
| 145 | + | |
| 146 | + | |
| 147 | + | |
| 148 | + | |
| 149 | + | |
| 150 | + | |
| 151 | + | |
| 152 | + | |
0 commit comments