Commit 5ac2278

vfdev-5 authored and pytorchmergebot committed
Optimized vertical flip using memcpy (pytorch#89414)
## Description

- Use memcpy for vertical flip.
- Added bool type support for horizontal flip.
- Channels-last input with horizontal flip now also goes into cpu_vflip_memcpy and gets a speed-up.

Previous PRs:

- pytorch#90013
- pytorch#88989

## Results

### Horizontal flip

- AVX2 (channels-last input only)

```
[-------------------------------- Horizontal flip --------------------------------]
      | torch (1.14.0a0+giteb3e189) PR | Pillow (9.3.0) | torch (1.14.0a0+gitb0bd5c4) nightly
1 threads: -------------------------------------------------------------------------
channels=3, size=256, dtype=torch.int64, mf=channels_last | 204.813 (+-1.018) | | 308.070 (+-1.573)
channels=3, size=520, dtype=torch.int64, mf=channels_last | 844.523 (+-2.302) | | 1226.801 (+-5.069)
channels=3, size=712, dtype=torch.int64, mf=channels_last | 2246.512 (+-8.935) | | 2689.692 (+-22.654)
channels=1, size=256, dtype=torch.int32, mf=channels_last | 21.024 (+-0.083) | 44.196 (+-0.131) | 22.564 (+-0.066)
channels=1, size=520, dtype=torch.int32, mf=channels_last | 71.806 (+-0.150) | 166.653 (+-0.789) | 72.660 (+-0.160)
channels=1, size=712, dtype=torch.int32, mf=channels_last | 129.354 (+-0.385) | 306.998 (+-0.819) | 130.094 (+-0.274)
channels=3, size=256, dtype=torch.uint8, mf=channels_last | 177.250 (+-0.485) | 44.232 (+-0.465) | 289.201 (+-2.837)
channels=3, size=520, dtype=torch.uint8, mf=channels_last | 699.055 (+-1.940) | 166.540 (+-0.903) | 1172.747 (+-3.645)
channels=3, size=712, dtype=torch.uint8, mf=channels_last | 1302.968 (+-5.390) | 307.210 (+-0.852) | 2149.396 (+-23.570)
channels=1, size=256, dtype=torch.int16, mf=channels_last | 11.943 (+-0.079) | | 12.451 (+-0.033)
channels=1, size=520, dtype=torch.int16, mf=channels_last | 39.830 (+-0.093) | | 40.583 (+-0.070)
channels=1, size=712, dtype=torch.int16, mf=channels_last | 69.001 (+-0.078) | | 69.590 (+-0.162)
channels=3, size=256, dtype=torch.int8, mf=channels_last | 177.378 (+-0.507) | | 283.461 (+-2.957)
channels=3, size=520, dtype=torch.int8, mf=channels_last | 698.915 (+-1.840) | | 1061.208 (+-10.449)
channels=3, size=712, dtype=torch.int8, mf=channels_last | 1299.365 (+-3.919) | | 1957.424 (+-13.149)
channels=3, size=256, dtype=torch.int8, mf=channels_first | 17.955 (+-0.077) | | 89.456 (+-0.285)
channels=3, size=520, dtype=torch.int8, mf=channels_first | 56.901 (+-0.081) | | 339.802 (+-0.879)
channels=3, size=712, dtype=torch.int8, mf=channels_first | 103.629 (+-0.256) | | 627.845 (+-1.185)
channels=1, size=256, dtype=torch.float32, mf=channels_last | 21.179 (+-0.077) | 44.146 (+-0.260) | 22.957 (+-0.138)
channels=1, size=520, dtype=torch.float32, mf=channels_last | 71.685 (+-0.155) | 166.666 (+-0.730) | 72.606 (+-0.124)
channels=1, size=712, dtype=torch.float32, mf=channels_last | 129.168 (+-0.288) | 307.094 (+-1.571) | 130.156 (+-0.453)
channels=1, size=256, dtype=torch.float16, mf=channels_last | 33.049 (+-0.089) | | 33.056 (+-0.477)
channels=1, size=520, dtype=torch.float16, mf=channels_last | 116.635 (+-0.299) | | 113.433 (+-0.891)
channels=1, size=712, dtype=torch.float16, mf=channels_last | 212.134 (+-0.413) | | 204.394 (+-0.822)
channels=3, size=256, dtype=torch.float64, mf=channels_last | 207.214 (+-0.586) | | 302.370 (+-0.670)
channels=3, size=520, dtype=torch.float64, mf=channels_last | 846.553 (+-2.301) | | 1223.851 (+-5.280)
channels=3, size=712, dtype=torch.float64, mf=channels_last | 2251.687 (+-6.513) | | 2711.557 (+-14.011)
channels=1, size=256, dtype=torch.bfloat16, mf=channels_last | 33.237 (+-0.072) | | 33.101 (+-0.070)
channels=1, size=520, dtype=torch.bfloat16, mf=channels_last | 113.605 (+-0.337) | | 117.067 (+-0.547)
channels=1, size=712, dtype=torch.bfloat16, mf=channels_last | 204.632 (+-0.487) | | 212.590 (+-0.848)
channels=1, size=256, dtype=torch.bool, mf=channels_last | 7.950 (+-0.030) | | 37.757 (+-0.080)
channels=1, size=520, dtype=torch.bool, mf=channels_last | 23.799 (+-0.080) | | 136.571 (+-0.441)
channels=1, size=712, dtype=torch.bool, mf=channels_last | 37.970 (+-0.075) | | 246.894 (+-0.926)
channels=1, size=256, dtype=torch.bool, mf=channels_first | 8.009 (+-0.077) | | 37.800 (+-0.100)
channels=1, size=520, dtype=torch.bool, mf=channels_first | 23.861 (+-0.099) | | 136.553 (+-0.519)
channels=1, size=712, dtype=torch.bool, mf=channels_first | 38.211 (+-0.104) | | 246.939 (+-0.692)

Times are in microseconds (us).
```

[Source](https://gist.github.com/vfdev-5/c2ca615b522aeb1c4636dc8d948fec74#file-20221209-100405-pr_vs_nightly-md)

- AVX512 (channels-last input only)

```
[-------------------------------- Horizontal flip --------------------------------]
      | torch (1.14.0a0+giteb3e189) PR | Pillow (9.3.0) | torch (1.14.0.dev20221208+cu116) nightly
1 threads: -------------------------------------------------------------------------
channels=3, size=256, dtype=torch.int64, mf=channels_last | 194.708 (+-9.566) | | 372.067 (+-12.430)
channels=3, size=520, dtype=torch.int64, mf=channels_last | 765.151 (+-10.098) | | 1524.231 (+-111.283)
channels=3, size=712, dtype=torch.int64, mf=channels_last | 1587.229 (+-88.117) | | 2950.081 (+-92.322)
channels=1, size=256, dtype=torch.int32, mf=channels_last | 13.328 (+-0.375) | 49.693 (+-1.193) | 10.323 (+-0.333)
channels=1, size=520, dtype=torch.int32, mf=channels_last | 90.580 (+-0.812) | 191.936 (+-4.369) | 92.269 (+-0.980)
channels=1, size=712, dtype=torch.int32, mf=channels_last | 163.821 (+-3.174) | 352.053 (+-10.909) | 165.661 (+-4.436)
channels=3, size=256, dtype=torch.uint8, mf=channels_last | 206.862 (+-4.417) | 49.336 (+-1.492) | 287.373 (+-7.266)
channels=3, size=520, dtype=torch.uint8, mf=channels_last | 829.736 (+-15.857) | 191.489 (+-5.645) | 1166.126 (+-45.667)
channels=3, size=712, dtype=torch.uint8, mf=channels_last | 1540.953 (+-28.269) | 352.171 (+-8.784) | 2171.570 (+-82.740)
channels=1, size=256, dtype=torch.int16, mf=channels_last | 7.856 (+-0.131) | | 7.943 (+-0.148)
channels=1, size=520, dtype=torch.int16, mf=channels_last | 34.750 (+-1.195) | | 36.309 (+-0.716)
channels=1, size=712, dtype=torch.int16, mf=channels_last | 85.858 (+-0.729) | | 87.306 (+-0.981)
channels=3, size=256, dtype=torch.int8, mf=channels_last | 206.896 (+-5.716) | | 262.551 (+-6.598)
channels=3, size=520, dtype=torch.int8, mf=channels_last | 828.212 (+-13.441) | | 1077.916 (+-28.810)
channels=3, size=712, dtype=torch.int8, mf=channels_last | 1542.748 (+-31.379) | | 2003.661 (+-71.614)
channels=3, size=256, dtype=torch.int8, mf=channels_first | 11.038 (+-0.271) | | 126.867 (+-5.590)
channels=3, size=520, dtype=torch.int8, mf=channels_first | 90.190 (+-1.185) | | 501.446 (+-13.498)
channels=3, size=712, dtype=torch.int8, mf=channels_first | 165.797 (+-3.016) | | 921.131 (+-20.500)
channels=1, size=256, dtype=torch.float32, mf=channels_last | 13.516 (+-0.578) | 49.678 (+-1.966) | 10.360 (+-0.256)
channels=1, size=520, dtype=torch.float32, mf=channels_last | 91.195 (+-0.830) | 191.778 (+-4.742) | 91.117 (+-0.855)
channels=1, size=712, dtype=torch.float32, mf=channels_last | 168.551 (+-3.352) | 351.585 (+-8.230) | 164.199 (+-3.725)
channels=1, size=256, dtype=torch.float16, mf=channels_last | 35.832 (+-0.840) | | 35.087 (+-0.972)
channels=1, size=520, dtype=torch.float16, mf=channels_last | 133.624 (+-5.293) | | 131.423 (+-6.002)
channels=1, size=712, dtype=torch.float16, mf=channels_last | 240.702 (+-5.213) | | 236.876 (+-7.867)
channels=3, size=256, dtype=torch.float64, mf=channels_last | 192.351 (+-6.740) | | 313.999 (+-12.141)
channels=3, size=520, dtype=torch.float64, mf=channels_last | 766.553 (+-16.669) | | 1270.797 (+-49.828)
channels=3, size=712, dtype=torch.float64, mf=channels_last | 1501.700 (+-69.499) | | 2427.303 (+-126.694)
channels=1, size=256, dtype=torch.bfloat16, mf=channels_last | 35.386 (+-0.801) | | 34.539 (+-0.844)
channels=1, size=520, dtype=torch.bfloat16, mf=channels_last | 132.369 (+-4.107) | | 130.926 (+-3.597)
channels=1, size=712, dtype=torch.bfloat16, mf=channels_last | 237.722 (+-6.680) | | 237.072 (+-5.027)
channels=1, size=256, dtype=torch.bool, mf=channels_last | 6.796 (+-0.132) | | 44.727 (+-0.905)
channels=1, size=520, dtype=torch.bool, mf=channels_last | 24.827 (+-0.669) | | 166.758 (+-5.141)
channels=1, size=712, dtype=torch.bool, mf=channels_last | 42.392 (+-0.980) | | 310.830 (+-6.130)
channels=1, size=256, dtype=torch.bool, mf=channels_first | 8.114 (+-0.141) | | 44.776 (+-0.707)
channels=1, size=520, dtype=torch.bool, mf=channels_first | 24.787 (+-0.787) | | 167.766 (+-5.004)
channels=1, size=712, dtype=torch.bool, mf=channels_first | 42.545 (+-0.636) | | 313.715 (+-7.603)

Times are in microseconds (us).
```

[Source](https://gist.github.com/vfdev-5/c2ca615b522aeb1c4636dc8d948fec74#file-20221209-105633-pr_vs_nightly-avx512-md)

### Vertical flip

- AVX2 (all tested cases show a speed-up or comparable performance)

```
[--------------------------------- Vertical flip ---------------------------------]
      | torch (1.14.0a0+giteb3e189) PR | Pillow (9.3.0) | torch (1.14.0a0+gitb0bd5c4) nightly
1 threads: -------------------------------------------------------------------------
channels=3, size=256, dtype=torch.int64, mf=channels_last | 93.125 (+-3.022) | | 101.064 (+-0.436)
channels=3, size=520, dtype=torch.int64, mf=channels_last | 412.942 (+-57.066) | | 461.463 (+-2.098)
channels=3, size=712, dtype=torch.int64, mf=channels_last | 1533.265 (+-4.071) | | 1829.713 (+-14.311)
channels=3, size=256, dtype=torch.int64, mf=channels_first | 101.134 (+-0.924) | | 102.858 (+-0.319)
channels=3, size=520, dtype=torch.int64, mf=channels_first | 421.679 (+-1.101) | | 477.413 (+-1.809)
channels=3, size=712, dtype=torch.int64, mf=channels_first | 1550.418 (+-3.647) | | 1877.143 (+-6.622)
channels=1, size=256, dtype=torch.int32, mf=channels_last | 20.961 (+-0.063) | 19.515 (+-0.302) | 21.980 (+-0.070)
channels=1, size=520, dtype=torch.int32, mf=channels_last | 71.199 (+-0.173) | 70.199 (+-0.332) | 95.262 (+-0.109)
channels=1, size=712, dtype=torch.int32, mf=channels_last | 128.532 (+-0.318) | 127.325 (+-0.328) | 167.190 (+-0.370)
channels=1, size=256, dtype=torch.int32, mf=channels_first | 21.206 (+-0.059) | 19.471 (+-0.128) | 21.469 (+-0.064)
channels=1, size=520, dtype=torch.int32, mf=channels_first | 71.284 (+-0.163) | 70.124 (+-0.388) | 94.988 (+-0.239)
channels=1, size=712, dtype=torch.int32, mf=channels_first | 129.017 (+-0.286) | 128.088 (+-0.461) | 167.115 (+-1.075)
channels=3, size=256, dtype=torch.uint8, mf=channels_last | 16.909 (+-0.057) | 19.570 (+-0.353) | 17.981 (+-0.072)
channels=3, size=520, dtype=torch.uint8, mf=channels_last | 55.163 (+-0.138) | 70.218 (+-0.275) | 107.938 (+-0.620)
channels=3, size=712, dtype=torch.uint8, mf=channels_last | 98.518 (+-0.121) | 127.737 (+-0.486) | 170.965 (+-0.436)
channels=3, size=256, dtype=torch.uint8, mf=channels_first | 18.150 (+-0.084) | 19.758 (+-0.221) | 18.122 (+-0.088)
channels=3, size=520, dtype=torch.uint8, mf=channels_first | 56.693 (+-0.200) | 70.278 (+-0.386) | 89.018 (+-0.206)
channels=3, size=712, dtype=torch.uint8, mf=channels_first | 100.409 (+-0.235) | 127.772 (+-0.457) | 168.072 (+-0.436)
channels=1, size=256, dtype=torch.int16, mf=channels_last | 12.817 (+-0.041) | | 12.818 (+-0.049)
channels=1, size=520, dtype=torch.int16, mf=channels_last | 38.359 (+-0.081) | | 63.378 (+-0.165)
channels=1, size=712, dtype=torch.int16, mf=channels_last | 68.246 (+-0.090) | | 116.637 (+-0.583)
channels=1, size=256, dtype=torch.int16, mf=channels_first | 12.899 (+-0.054) | | 12.649 (+-0.060)
channels=1, size=520, dtype=torch.int16, mf=channels_first | 38.404 (+-0.069) | | 63.448 (+-0.108)
channels=1, size=712, dtype=torch.int16, mf=channels_first | 68.378 (+-0.104) | | 116.415 (+-0.332)
channels=3, size=256, dtype=torch.int8, mf=channels_last | 17.071 (+-0.044) | | 17.792 (+-0.050)
channels=3, size=520, dtype=torch.int8, mf=channels_last | 55.163 (+-0.100) | | 108.539 (+-0.466)
channels=3, size=712, dtype=torch.int8, mf=channels_last | 98.537 (+-0.091) | | 171.675 (+-0.553)
channels=3, size=256, dtype=torch.int8, mf=channels_first | 17.837 (+-0.071) | | 18.355 (+-0.067)
channels=3, size=520, dtype=torch.int8, mf=channels_first | 56.051 (+-0.087) | | 88.261 (+-0.129)
channels=3, size=712, dtype=torch.int8, mf=channels_first | 100.603 (+-0.245) | | 169.067 (+-0.430)
channels=1, size=256, dtype=torch.float32, mf=channels_last | 21.204 (+-0.063) | 19.607 (+-0.140) | 22.202 (+-0.094)
channels=1, size=520, dtype=torch.float32, mf=channels_last | 71.356 (+-0.211) | 69.844 (+-0.343) | 94.614 (+-0.167)
channels=1, size=712, dtype=torch.float32, mf=channels_last | 129.087 (+-0.290) | 127.065 (+-0.319) | 166.513 (+-0.444)
channels=1, size=256, dtype=torch.float32, mf=channels_first | 21.196 (+-0.065) | 19.156 (+-0.132) | 21.516 (+-0.073)
channels=1, size=520, dtype=torch.float32, mf=channels_first | 71.422 (+-0.180) | 70.296 (+-0.136) | 94.913 (+-0.095)
channels=1, size=712, dtype=torch.float32, mf=channels_first | 129.045 (+-0.312) | 128.023 (+-0.585) | 166.089 (+-0.409)
channels=1, size=256, dtype=torch.float16, mf=channels_last | 12.770 (+-0.045) | | 34.853 (+-0.089)
channels=1, size=520, dtype=torch.float16, mf=channels_last | 38.363 (+-0.064) | | 131.969 (+-0.577)
channels=1, size=712, dtype=torch.float16, mf=channels_last | 67.954 (+-0.107) | | 239.507 (+-0.835)
channels=1, size=256, dtype=torch.float16, mf=channels_first | 12.855 (+-0.067) | | 35.124 (+-0.109)
channels=1, size=520, dtype=torch.float16, mf=channels_first | 38.725 (+-0.079) | | 131.708 (+-0.586)
channels=1, size=712, dtype=torch.float16, mf=channels_first | 68.931 (+-0.086) | | 239.022 (+-0.914)
channels=3, size=256, dtype=torch.float64, mf=channels_last | 90.277 (+-0.083) | | 101.512 (+-0.285)
channels=3, size=520, dtype=torch.float64, mf=channels_last | 421.277 (+-1.030) | | 471.913 (+-3.654)
channels=3, size=712, dtype=torch.float64, mf=channels_last | 1534.394 (+-7.572) | | 1833.262 (+-12.185)
channels=3, size=256, dtype=torch.float64, mf=channels_first | 100.809 (+-0.328) | | 103.166 (+-0.335)
channels=3, size=520, dtype=torch.float64, mf=channels_first | 425.535 (+-0.926) | | 482.606 (+-1.450)
channels=3, size=712, dtype=torch.float64, mf=channels_first | 1550.832 (+-3.547) | | 1859.098 (+-6.517)
channels=1, size=256, dtype=torch.bfloat16, mf=channels_last | 12.954 (+-0.051) | | 12.744 (+-0.046)
channels=1, size=520, dtype=torch.bfloat16, mf=channels_last | 41.180 (+-0.064) | | 63.362 (+-0.139)
channels=1, size=712, dtype=torch.bfloat16, mf=channels_last | 68.136 (+-0.142) | | 117.009 (+-0.292)
channels=1, size=256, dtype=torch.bfloat16, mf=channels_first | 13.049 (+-0.052) | | 12.792 (+-0.076)
channels=1, size=520, dtype=torch.bfloat16, mf=channels_first | 38.488 (+-0.092) | | 63.451 (+-0.096)
channels=1, size=712, dtype=torch.bfloat16, mf=channels_first | 68.103 (+-0.091) | | 116.693 (+-0.290)
channels=1, size=256, dtype=torch.bool, mf=channels_last | 7.572 (+-0.029) | | 8.017 (+-0.071)
channels=1, size=520, dtype=torch.bool, mf=channels_last | 22.121 (+-0.061) | | 23.614 (+-0.074)
channels=1, size=712, dtype=torch.bool, mf=channels_last | 36.896 (+-0.094) | | 39.460 (+-0.084)
channels=1, size=256, dtype=torch.bool, mf=channels_first | 7.671 (+-0.028) | | 8.034 (+-0.058)
channels=1, size=520, dtype=torch.bool, mf=channels_first | 21.989 (+-0.053) | | 23.645 (+-0.063)
channels=1, size=712, dtype=torch.bool, mf=channels_first | 37.252 (+-0.072) | | 39.477 (+-0.100)
channels=1, size=256, dtype=torch.complex64, mf=channels_last | 37.129 (+-0.052) | | 37.801 (+-0.101)
channels=1, size=520, dtype=torch.complex64, mf=channels_last | 122.646 (+-0.230) | | 139.074 (+-0.467)
channels=1, size=712, dtype=torch.complex64, mf=channels_last | 228.946 (+-0.736) | | 257.589 (+-0.545)
channels=1, size=256, dtype=torch.complex64, mf=channels_first | 37.088 (+-0.070) | | 37.894 (+-0.078)
channels=1, size=520, dtype=torch.complex64, mf=channels_first | 122.695 (+-0.268) | | 138.933 (+-0.336)
channels=1, size=712, dtype=torch.complex64, mf=channels_first | 234.655 (+-0.454) | | 255.787 (+-0.530)

Times are in microseconds (us).
```

[Source](https://gist.github.com/vfdev-5/c2ca615b522aeb1c4636dc8d948fec74#file-20221209-100440-pr_vs_nightly-md)

- AVX512 (all tested cases show a speed-up or comparable performance)

```
[--------------------------------- Vertical flip ---------------------------------]
      | torch (1.14.0a0+giteb3e189) PR | Pillow (9.3.0) | torch (1.14.0.dev20221208+cu116) nightly
1 threads: -------------------------------------------------------------------------
channels=3, size=256, dtype=torch.int64, mf=channels_last | 122.544 (+-1.962) | | 129.161 (+-1.809)
channels=3, size=520, dtype=torch.int64, mf=channels_last | 508.274 (+-4.790) | | 533.872 (+-7.457)
channels=3, size=712, dtype=torch.int64, mf=channels_last | 951.176 (+-29.534) | | 1073.603 (+-44.676)
channels=3, size=256, dtype=torch.int64, mf=channels_first | 127.872 (+-2.700) | | 127.326 (+-2.666)
channels=3, size=520, dtype=torch.int64, mf=channels_first | 518.019 (+-4.157) | | 538.094 (+-6.600)
channels=3, size=712, dtype=torch.int64, mf=channels_first | 1002.176 (+-42.545) | | 1033.989 (+-42.137)
channels=1, size=256, dtype=torch.int32, mf=channels_last | 10.025 (+-0.135) | 10.054 (+-0.369) | 10.155 (+-0.285)
channels=1, size=520, dtype=torch.int32, mf=channels_last | 89.867 (+-0.994) | 88.712 (+-0.622) | 103.029 (+-2.254)
channels=1, size=712, dtype=torch.int32, mf=channels_last | 161.787 (+-2.080) | 161.370 (+-1.801) | 182.608 (+-7.031)
channels=1, size=256, dtype=torch.int32, mf=channels_first | 10.005 (+-0.277) | 9.965 (+-0.338) | 10.604 (+-0.334)
channels=1, size=520, dtype=torch.int32, mf=channels_first | 89.116 (+-0.996) | 88.840 (+-0.608) | 102.103 (+-2.111)
channels=1, size=712, dtype=torch.int32, mf=channels_first | 164.328 (+-3.284) | 161.538 (+-2.739) | 181.702 (+-3.770)
channels=3, size=256, dtype=torch.uint8, mf=channels_last | 8.853 (+-0.148) | 10.292 (+-0.494) | 8.961 (+-0.190)
channels=3, size=520, dtype=torch.uint8, mf=channels_last | 68.368 (+-1.158) | 90.068 (+-1.780) | 81.155 (+-0.945)
channels=3, size=712, dtype=torch.uint8, mf=channels_last | 125.458 (+-2.511) | 163.150 (+-2.532) | 147.039 (+-4.264)
channels=3, size=256, dtype=torch.uint8, mf=channels_first | 10.409 (+-0.435) | 10.406 (+-0.351) | 10.263 (+-0.252)
channels=3, size=520, dtype=torch.uint8, mf=channels_first | 69.077 (+-1.062) | 90.057 (+-0.992) | 79.910 (+-0.884)
channels=3, size=712, dtype=torch.uint8, mf=channels_first | 127.286 (+-2.789) | 162.862 (+-2.953) | 142.821 (+-2.119)
channels=1, size=256, dtype=torch.int16, mf=channels_last | 7.513 (+-0.143) | | 7.364 (+-0.154)
channels=1, size=520, dtype=torch.int16, mf=channels_last | 33.140 (+-0.779) | | 42.141 (+-0.820)
channels=1, size=712, dtype=torch.int16, mf=channels_last | 86.235 (+-1.187) | | 104.205 (+-2.205)
channels=1, size=256, dtype=torch.int16, mf=channels_first | 7.410 (+-0.162) | | 7.075 (+-0.126)
channels=1, size=520, dtype=torch.int16, mf=channels_first | 33.656 (+-0.914) | | 40.991 (+-0.893)
channels=1, size=712, dtype=torch.int16, mf=channels_first | 86.087 (+-1.191) | | 105.419 (+-1.801)
channels=3, size=256, dtype=torch.int8, mf=channels_last | 8.802 (+-0.196) | | 8.627 (+-0.202)
channels=3, size=520, dtype=torch.int8, mf=channels_last | 66.348 (+-0.775) | | 80.631 (+-1.832)
channels=3, size=712, dtype=torch.int8, mf=channels_last | 126.275 (+-2.318) | | 144.597 (+-4.242)
channels=3, size=256, dtype=torch.int8, mf=channels_first | 10.255 (+-0.383) | | 10.101 (+-0.335)
channels=3, size=520, dtype=torch.int8, mf=channels_first | 68.124 (+-0.849) | | 79.286 (+-0.748)
channels=3, size=712, dtype=torch.int8, mf=channels_first | 127.118 (+-2.225) | | 142.029 (+-2.507)
channels=1, size=256, dtype=torch.float32, mf=channels_last | 9.850 (+-0.453) | 9.299 (+-0.253) | 10.030 (+-0.234)
channels=1, size=520, dtype=torch.float32, mf=channels_last | 91.506 (+-1.319) | 90.265 (+-0.824) | 107.570 (+-2.093)
channels=1, size=712, dtype=torch.float32, mf=channels_last | 167.820 (+-3.883) | 162.871 (+-2.397) | 180.046 (+-8.952)
channels=1, size=256, dtype=torch.float32, mf=channels_first | 10.118 (+-0.359) | 10.433 (+-0.479) | 10.204 (+-0.344)
channels=1, size=520, dtype=torch.float32, mf=channels_first | 90.862 (+-1.486) | 90.138 (+-0.969) | 107.011 (+-1.801)
channels=1, size=712, dtype=torch.float32, mf=channels_first | 163.931 (+-3.653) | 163.155 (+-2.673) | 186.707 (+-2.248)
channels=1, size=256, dtype=torch.float16, mf=channels_last | 7.304 (+-0.134) | | 24.141 (+-0.444)
channels=1, size=520, dtype=torch.float16, mf=channels_last | 35.186 (+-0.656) | | 101.523 (+-1.465)
channels=1, size=712, dtype=torch.float16, mf=channels_last | 85.707 (+-0.841) | | 192.640 (+-4.942)
channels=1, size=256, dtype=torch.float16, mf=channels_first | 7.286 (+-0.142) | | 24.155 (+-0.555)
channels=1, size=520, dtype=torch.float16, mf=channels_first | 33.819 (+-1.009) | | 101.620 (+-3.034)
channels=1, size=712, dtype=torch.float16, mf=channels_first | 84.811 (+-0.993) | | 192.286 (+-4.707)
channels=3, size=256, dtype=torch.float64, mf=channels_last | 126.273 (+-2.519) | | 128.831 (+-1.975)
channels=3, size=520, dtype=torch.float64, mf=channels_last | 551.861 (+-4.159) | | 517.343 (+-4.501)
channels=3, size=712, dtype=torch.float64, mf=channels_last | 1102.465 (+-66.427) | | 1224.532 (+-55.656)
channels=3, size=256, dtype=torch.float64, mf=channels_first | 129.965 (+-2.083) | | 130.709 (+-2.261)
channels=3, size=520, dtype=torch.float64, mf=channels_first | 526.332 (+-5.354) | | 515.399 (+-4.320)
channels=3, size=712, dtype=torch.float64, mf=channels_first | 1169.215 (+-78.889) | | 1102.536 (+-51.178)
channels=1, size=256, dtype=torch.bfloat16, mf=channels_last | 7.478 (+-0.147) | | 7.154 (+-0.162)
channels=1, size=520, dtype=torch.bfloat16, mf=channels_last | 33.836 (+-1.022) | | 38.854 (+-0.648)
channels=1, size=712, dtype=torch.bfloat16, mf=channels_last | 85.483 (+-0.582) | | 99.190 (+-2.202)
channels=1, size=256, dtype=torch.bfloat16, mf=channels_first | 7.416 (+-0.125) | | 7.169 (+-0.121)
channels=1, size=520, dtype=torch.bfloat16, mf=channels_first | 34.958 (+-0.717) | | 40.136 (+-0.784)
channels=1, size=712, dtype=torch.bfloat16, mf=channels_first | 85.505 (+-1.207) | | 99.793 (+-2.065)
channels=1, size=256, dtype=torch.bool, mf=channels_last | 5.856 (+-0.178) | | 5.824 (+-0.118)
channels=1, size=520, dtype=torch.bool, mf=channels_last | 12.030 (+-0.330) | | 14.478 (+-0.554)
channels=1, size=712, dtype=torch.bool, mf=channels_last | 30.116 (+-0.639) | | 31.163 (+-0.873)
channels=1, size=256, dtype=torch.bool, mf=channels_first | 5.804 (+-0.113) | | 5.825 (+-0.102)
channels=1, size=520, dtype=torch.bool, mf=channels_first | 12.043 (+-0.363) | | 14.240 (+-0.341)
channels=1, size=712, dtype=torch.bool, mf=channels_first | 30.001 (+-1.001) | | 33.199 (+-0.430)
channels=1, size=256, dtype=torch.complex64, mf=channels_last | 29.941 (+-0.861) | | 28.229 (+-0.904)
channels=1, size=520, dtype=torch.complex64, mf=channels_last | 173.244 (+-2.577) | | 173.173 (+-2.260)
channels=1, size=712, dtype=torch.complex64, mf=channels_last | 323.548 (+-3.338) | | 318.318 (+-2.764)
channels=1, size=256, dtype=torch.complex64, mf=channels_first | 29.001 (+-1.029) | | 28.565 (+-2.074)
channels=1, size=520, dtype=torch.complex64, mf=channels_first | 173.078 (+-1.993) | | 170.664 (+-1.722)
channels=1, size=712, dtype=torch.complex64, mf=channels_first | 324.782 (+-3.759) | | 315.745 (+-2.600)

Times are in microseconds (us).
```

[Source](https://gist.github.com/vfdev-5/c2ca615b522aeb1c4636dc8d948fec74#file-20221209-105707-pr_vs_nightly-avx512-md)

Pull Request resolved: pytorch#89414
Approved by: https://github.com/peterbell10, https://github.com/lezcano, https://github.com/albanD
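
The numbers above were collected with `torch.utils.benchmark` against Pillow and a nightly build; the actual scripts live in the linked gists. As a minimal sketch of such a harness (the shape, dtype, and channels-last construction here are illustrative assumptions, not the original code):

```python
import torch
import torch.utils.benchmark as benchmark

# Hypothetical setup mirroring one table row: channels=3, size=520,
# torch.uint8, channels-last layout (channel stride 1 for a (C, H, W) tensor).
x = torch.randint(0, 256, (3, 520, 520), dtype=torch.uint8)
x = x.permute(1, 2, 0).contiguous().permute(2, 0, 1)  # make memory layout HWC

for label, dims in [("Horizontal flip", (-1,)), ("Vertical flip", (-2,))]:
    timer = benchmark.Timer(
        stmt="torch.flip(x, dims=dims)",
        globals={"torch": torch, "x": x, "dims": dims},
        num_threads=1,
        label=label,
    )
    print(timer.blocked_autorange())
```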
1 parent 3873575 commit 5ac2278

File tree

1 file changed: +85 lines, -22 lines

aten/src/ATen/native/cpu/IndexKernel.cpp

Lines changed: 85 additions & 22 deletions

```diff
@@ -457,16 +457,19 @@ void masked_select_kernel(TensorIterator& iter, int64_t result_stride) {
   });
 }
 
-
 template <typename scalar_t>
 void cpu_hflip_vec(at::TensorIterator& iter) {
 
   auto loop2d = [&](char** base, const int64_t *strides, int64_t size0, int64_t size1) {
 
-    static constexpr int ntensors = 3;
+    // Here ntensors is defined for output and 1 input. But tensor iterator has defined output, input
+    // and restrided_input (see aten/src/ATen/native/TensorTransformations.cpp#L64-L66) but we use only
+    // output and input.
+    static constexpr int ntensors = 2;
+    const int64_t *outer_strides = &strides[3];
+
     std::array<char*, ntensors> data_arr;
     std::copy_n(base, ntensors, data_arr.data());
-    const int64_t *outer_strides = &strides[ntensors];
 
     using Vec = Vectorized<scalar_t>;
 
```

```diff
@@ -514,7 +517,7 @@ void cpu_hflip_vec(at::TensorIterator& iter) {
     }
 
     // advance:
-    for (const auto arg : c10::irange(data_arr.size())) {
+    for (const auto arg : c10::irange(ntensors)) {
       data_arr[arg] += outer_strides[arg];
     }
   }
```

```diff
@@ -525,6 +528,46 @@ void cpu_hflip_vec(at::TensorIterator& iter) {
   iter.cast_outputs();
 }
 
+void cpu_vflip_memcpy(at::TensorIterator& iter) {
+  // This is a vertical flip specialization using memcpy to speed-up the runtime
+
+  auto loop2d = [&](char** base, const int64_t *strides, int64_t size0, int64_t size1) {
+
+    // Here ntensors is defined for output and 1 input. But tensor iterator has defined output, input
+    // and restrided_input (see aten/src/ATen/native/TensorTransformations.cpp#L64-L66) but we use only
+    // output and input.
+    static constexpr int ntensors = 2;
+    const int64_t *outer_strides = &strides[3];
+
+    std::array<char*, ntensors> data_arr;
+    std::copy_n(base, ntensors, data_arr.data());
+
+    TORCH_INTERNAL_ASSERT(strides[0] == strides[1]);
+    const int64_t stride = strides[0];
+
+    for (const auto j C10_UNUSED : c10::irange(size1)) {
+
+      char** C10_RESTRICT data_ = data_arr.data();
+      int64_t n = size0;
+
+      char* C10_RESTRICT data[ntensors];
+      for (const auto arg : c10::irange(ntensors)) {
+        data[arg] = data_[arg];
+      }
+
+      memcpy(data[0], data[1], n * stride);
+
+      // advance:
+      for (const auto arg : c10::irange(data_arr.size())) {
+        data_arr[arg] += outer_strides[arg];
+      }
+    }
+  };
+
+  int64_t grain_size = at::internal::GRAIN_SIZE;
+  iter.for_each(loop2d, grain_size);
+  iter.cast_outputs();
+}
 
 void flip_kernel(TensorIterator& iter, const bool quantized) {
   if (quantized) {
```
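
Why a plain `memcpy` is valid in `cpu_vflip_memcpy`: a vertical flip keeps every innermost row intact and only reverses the order of rows, so each output row is a byte-for-byte copy of one input row (hence the `strides[0] == strides[1]` assertion above). A small sketch, not part of the PR, demonstrating the equivalence from the Python side:

```python
import torch

x = torch.arange(24, dtype=torch.uint8).reshape(2, 3, 4)  # contiguous (C, H, W)
H, W = x.size(-2), x.size(-1)

# Vertical flip: output row i is a verbatim copy of input row H-1-i,
# which is exactly what the per-row memcpy in the kernel performs.
rows_reversed = torch.stack([x[:, H - 1 - i, :] for i in range(H)], dim=-2)
assert torch.equal(torch.flip(x, dims=(-2,)), rows_reversed)

# Horizontal flip reverses elements *within* each row, so a single memcpy
# per row cannot express it; that case stays on the vectorized hflip path.
assert torch.equal(torch.flip(x, dims=(-1,)), x[..., torch.arange(W - 1, -1, -1)])
```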

```diff
@@ -535,27 +578,47 @@ void flip_kernel(TensorIterator& iter, const bool quantized) {
       });
     });
   } else {
-    // Special case: horizontal flip with vectorization and input is contiguous
-    // Context: horizontal flip leads to strides[0] < 0 and
-    // thus is_contiguous condition is not satisfied and non-vectorized code path is taken.
     auto output_strides = iter.strides(0);
     auto input_strides = iter.strides(1);
-    if (iter.ndim() > 0 && output_strides[0] < 0 && input_strides[0] == iter.element_size(1)) {
+    if (iter.ndim() > 0 && output_strides[0] == -iter.element_size(0) && input_strides[0] == iter.element_size(1)) {
+      // Special case: horizontal flip with vectorization and input is contiguous
+      // Context: horizontal flip leads to strides[0] < 0 and
+      // thus is_contiguous condition is not satisfied and non-vectorized code path is taken.
       auto iter_dtype = iter.dtype();
-      if (iter_dtype == kByte) {
-        return cpu_hflip_vec<uint8_t>(iter);
-      } else if (iter_dtype == kFloat) {
-        return cpu_hflip_vec<float>(iter);
-      } else if (iter_dtype == kInt) {
-        return cpu_hflip_vec<int32_t>(iter);
-      } else if (iter_dtype == kShort) {
-        return cpu_hflip_vec<int16_t>(iter);
-      } else if (iter_dtype == kLong) {
-        return cpu_hflip_vec<int64_t>(iter);
-      } else if (iter_dtype == kDouble) {
-        return cpu_hflip_vec<double>(iter);
-      }
-      // other dtypes are handled below with cpu_kernel_vec
+      // Ignoring half and bfloat16 as cpu_hflip_vec is slower than cpu_kernel_vec
+      if (isIntegralType(iter_dtype, true) || iter_dtype == kDouble || iter_dtype == kFloat) {
+        // Replace AT_DISPATCH_ALL_TYPES_AND by manual if/else due to internal test failures:
+        // - "dtype 'Float' not selected for kernel tag hflip_cpu"
+        // - "dtype 'Long' not selected for kernel tag hflip_cpu"
+        //
+        // AT_DISPATCH_ALL_TYPES_AND(kBool,
+        //   iter_dtype, "hflip_cpu", [&iter] {
+        //     cpu_hflip_vec<scalar_t>(iter);
+        //   });
+
+        if (iter_dtype == kByte) {
+          return cpu_hflip_vec<uint8_t>(iter);
+        } else if (iter_dtype == kChar) {
+          return cpu_hflip_vec<int8_t>(iter);
+        } else if (iter_dtype == kInt) {
+          return cpu_hflip_vec<int32_t>(iter);
+        } else if (iter_dtype == kLong) {
+          return cpu_hflip_vec<int64_t>(iter);
+        } else if (iter_dtype == kShort) {
+          return cpu_hflip_vec<int16_t>(iter);
+        } else if (iter_dtype == kBool) {
+          return cpu_hflip_vec<bool>(iter);
+        } else if (iter_dtype == kFloat) {
+          return cpu_hflip_vec<float>(iter);
+        } else if (iter_dtype == kDouble) {
+          return cpu_hflip_vec<double>(iter);
+        }
+
+      }
+      // other dtypes (float16, bfloat16, complex) are handled by cpu_kernel_vec (see below)
+    } else if (iter.has_contiguous_first_dim()) {
+      // Special case: vertical flip using memcpy (faster than generic cpu_kernel_vec)
+      return cpu_vflip_memcpy(iter);
     }
 
     AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3(kBool, kHalf, kBFloat16, iter.dtype(), "flip_cpu",
```
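
With this dispatch, contiguous horizontal flips over the supported dtypes (now including bool and int8) take `cpu_hflip_vec`, flips whose first dimension stays contiguous take `cpu_vflip_memcpy`, and the remaining dtypes (half, bfloat16, complex) fall through to the generic `cpu_kernel_vec` path. A hypothetical smoke test; which kernel each call reaches is inferred from the diff above, not verified here:

```python
import torch

x = torch.rand(3, 64, 64)  # contiguous (C, H, W), float32

torch.flip(x, dims=(-1,))         # horizontal flip: cpu_hflip_vec<float>
torch.flip(x > 0.5, dims=(-1,))   # bool hflip, newly supported by cpu_hflip_vec<bool>
torch.flip(x, dims=(-2,))         # vertical flip: cpu_vflip_memcpy
torch.flip(x.half(), dims=(-1,))  # float16: generic cpu_kernel_vec fallback
```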
