@@ -457,16 +457,19 @@ void masked_select_kernel(TensorIterator& iter, int64_t result_stride) {
   });
 }
 
-
 template <typename scalar_t>
 void cpu_hflip_vec(at::TensorIterator& iter) {
 
   auto loop2d = [&](char** base, const int64_t *strides, int64_t size0, int64_t size1) {
 
-    static constexpr int ntensors = 3;
+    // Here ntensors is defined for output and 1 input. But tensor iterator has defined output, input
+    // and restrided_input (see aten/src/ATen/native/TensorTransformations.cpp#L64-L66) but we use only
+    // output and input.
+    static constexpr int ntensors = 2;
+    const int64_t *outer_strides = &strides[3];
+
     std::array<char*, ntensors> data_arr;
     std::copy_n(base, ntensors, data_arr.data());
-    const int64_t *outer_strides = &strides[ntensors];
 
     using Vec = Vectorized<scalar_t>;
 
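For reference, the loop2d callback above follows TensorIterator's 2D-loop convention: `strides` carries one inner (dim-0) stride per operand followed by one outer (dim-1) stride per operand, which is why `outer_strides` now starts at `strides[3]` (the iterator is built with three operands) even though the kernel only walks the first two data pointers. A minimal sketch of that contract, with an illustrative name (`run_loop2d`) and element-wise byte copy that are not part of ATen:

#include <algorithm>
#include <array>
#include <cstdint>
#include <cstring>

// Sketch of the assumed loop2d contract: for an iterator configured with N
// operands, `strides` holds N inner (dim-0) strides followed by N outer
// (dim-1) strides.
template <int N>
void run_loop2d(char** base, const int64_t* strides,
                int64_t size0, int64_t size1, int64_t element_size) {
  const int64_t* inner_strides = strides;      // one per operand, dim 0
  const int64_t* outer_strides = strides + N;  // one per operand, dim 1
  std::array<char*, N> ptrs;
  std::copy_n(base, N, ptrs.data());
  for (int64_t j = 0; j < size1; ++j) {
    for (int64_t i = 0; i < size0; ++i) {
      // operand 0 is the output, operand 1 the input: copy one element
      std::memcpy(ptrs[0] + i * inner_strides[0],
                  ptrs[1] + i * inner_strides[1], element_size);
    }
    for (int t = 0; t < N; ++t) {
      ptrs[t] += outer_strides[t];  // advance every operand to the next row
    }
  }
}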
@@ -514,7 +517,7 @@ void cpu_hflip_vec(at::TensorIterator& iter) {
       }
 
       // advance:
-      for (const auto arg : c10::irange(data_arr.size())) {
+      for (const auto arg : c10::irange(ntensors)) {
        data_arr[arg] += outer_strides[arg];
      }
    }
@@ -525,6 +528,46 @@ void cpu_hflip_vec(at::TensorIterator& iter) {
   iter.cast_outputs();
 }
 
+void cpu_vflip_memcpy(at::TensorIterator& iter) {
+  // This is a vertical flip specialization using memcpy to speed-up the runtime
+
+  auto loop2d = [&](char** base, const int64_t *strides, int64_t size0, int64_t size1) {
+
+    // Here ntensors is defined for output and 1 input. But tensor iterator has defined output, input
+    // and restrided_input (see aten/src/ATen/native/TensorTransformations.cpp#L64-L66) but we use only
+    // output and input.
+    static constexpr int ntensors = 2;
+    const int64_t *outer_strides = &strides[3];
+
+    std::array<char*, ntensors> data_arr;
+    std::copy_n(base, ntensors, data_arr.data());
+
+    TORCH_INTERNAL_ASSERT(strides[0] == strides[1]);
+    const int64_t stride = strides[0];
+
+    for (const auto j C10_UNUSED : c10::irange(size1)) {
+
+      char** C10_RESTRICT data_ = data_arr.data();
+      int64_t n = size0;
+
+      char* C10_RESTRICT data[ntensors];
+      for (const auto arg : c10::irange(ntensors)) {
+        data[arg] = data_[arg];
+      }
+
+      memcpy(data[0], data[1], n * stride);
+
+      // advance:
+      for (const auto arg : c10::irange(data_arr.size())) {
+        data_arr[arg] += outer_strides[arg];
+      }
+    }
+  };
+
+  int64_t grain_size = at::internal::GRAIN_SIZE;
+  iter.for_each(loop2d, grain_size);
+  iter.cast_outputs();
+}
 
 void flip_kernel(TensorIterator& iter, const bool quantized) {
   if (quantized) {
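Conceptually, the new cpu_vflip_memcpy specialization copies one contiguous inner row per outer iteration; the row reversal itself comes from the negatively restrided input that flip() feeds into the iterator. A standalone sketch of the same idea on a plain H x W buffer (illustrative only, not the ATen kernel):

#include <cstdint>
#include <cstring>

// Illustrative only: a vertical flip of a dense H x W buffer with element
// size `esize` is a row-by-row memcpy in reverse order, which is what
// cpu_vflip_memcpy amounts to once TensorIterator has restrided the input
// along the flipped dimension.
void vflip_rows(char* dst, const char* src, int64_t H, int64_t W, int64_t esize) {
  const int64_t row_bytes = W * esize;
  for (int64_t i = 0; i < H; ++i) {
    // output row i is input row (H - 1 - i), copied as one contiguous block
    std::memcpy(dst + i * row_bytes, src + (H - 1 - i) * row_bytes, row_bytes);
  }
}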
@@ -535,27 +578,47 @@ void flip_kernel(TensorIterator& iter, const bool quantized) {
       });
     });
   } else {
-    // Special case: horizontal flip with vectorization and input is contiguous
-    // Context: horizontal flip leads to strides[0] < 0 and
-    // thus is_contiguous condition is not satisfied and non-vectorized code path is taken.
     auto output_strides = iter.strides(0);
     auto input_strides = iter.strides(1);
-    if (iter.ndim() > 0 && output_strides[0] < 0 && input_strides[0] == iter.element_size(1)) {
+    if (iter.ndim() > 0 && output_strides[0] == -iter.element_size(0) && input_strides[0] == iter.element_size(1)) {
+      // Special case: horizontal flip with vectorization and input is contiguous
+      // Context: horizontal flip leads to strides[0] < 0 and
+      // thus is_contiguous condition is not satisfied and non-vectorized code path is taken.
       auto iter_dtype = iter.dtype();
-      if (iter_dtype == kByte) {
-        return cpu_hflip_vec<uint8_t>(iter);
-      } else if (iter_dtype == kFloat) {
-        return cpu_hflip_vec<float>(iter);
-      } else if (iter_dtype == kInt) {
-        return cpu_hflip_vec<int32_t>(iter);
-      } else if (iter_dtype == kShort) {
-        return cpu_hflip_vec<int16_t>(iter);
-      } else if (iter_dtype == kLong) {
-        return cpu_hflip_vec<int64_t>(iter);
-      } else if (iter_dtype == kDouble) {
-        return cpu_hflip_vec<double>(iter);
-      }
-      // other dtypes are handled below with cpu_kernel_vec
+      // Ignoring half and bfloat16 as cpu_hflip_vec is slower than cpu_kernel_vec
+      if (isIntegralType(iter_dtype, true) || iter_dtype == kDouble || iter_dtype == kFloat) {
+        // Replace AT_DISPATCH_ALL_TYPES_AND by manual if/else due to internal test failures:
+        // - "dtype 'Float' not selected for kernel tag hflip_cpu"
+        // - "dtype 'Long' not selected for kernel tag hflip_cpu"
+        //
+        // AT_DISPATCH_ALL_TYPES_AND(kBool,
+        //   iter_dtype, "hflip_cpu", [&iter] {
+        //     cpu_hflip_vec<scalar_t>(iter);
+        // });
+
+        if (iter_dtype == kByte) {
+          return cpu_hflip_vec<uint8_t>(iter);
+        } else if (iter_dtype == kChar) {
+          return cpu_hflip_vec<int8_t>(iter);
+        } else if (iter_dtype == kInt) {
+          return cpu_hflip_vec<int32_t>(iter);
+        } else if (iter_dtype == kLong) {
+          return cpu_hflip_vec<int64_t>(iter);
+        } else if (iter_dtype == kShort) {
+          return cpu_hflip_vec<int16_t>(iter);
+        } else if (iter_dtype == kBool) {
+          return cpu_hflip_vec<bool>(iter);
+        } else if (iter_dtype == kFloat) {
+          return cpu_hflip_vec<float>(iter);
+        } else if (iter_dtype == kDouble) {
+          return cpu_hflip_vec<double>(iter);
+        }
+
+      }
+      // other dtypes (float16, bfloat16, complex) are handled by cpu_kernel_vec (see below)
+    } else if (iter.has_contiguous_first_dim()) {
+      // Special case: vertical flip using memcpy (faster than generic cpu_kernel_vec)
+      return cpu_vflip_memcpy(iter);
     }
 
    AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3(kBool, kHalf, kBFloat16, iter.dtype(), "flip_cpu",
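As a rough usage sketch (assuming a CPU build that includes these kernels; whether a given call actually lands on a fast path depends on how TensorIterator sets up the operands), both new paths are reachable through the public at::flip API:

#include <ATen/ATen.h>

int main() {
  auto img = at::arange(2 * 3 * 4, at::kFloat).reshape({2, 3, 4});

  // Flip over the innermost dim: the output gets a negative inner stride while
  // the input stays contiguous, i.e. the case cpu_hflip_vec targets.
  auto hflipped = at::flip(img, {-1});

  // Flip over a leading dim with the innermost dim still contiguous, i.e. the
  // case the cpu_vflip_memcpy specialization targets.
  auto vflipped = at::flip(img, {0});
  return 0;
}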