From 5abbb93f6e8eaa8ae6b8e12fa43fd954847f1502 Mon Sep 17 00:00:00 2001 From: Feng Yuan Date: Fri, 19 Sep 2025 01:02:15 +0800 Subject: [PATCH] Roll back to EU copy for AB buffer in Sort kernel Signed-off-by: Feng Yuan --- src/ATen/native/xpu/sycl/SortingKernels.h | 29 +++++++++++++++-------- 1 file changed, 19 insertions(+), 10 deletions(-) diff --git a/src/ATen/native/xpu/sycl/SortingKernels.h b/src/ATen/native/xpu/sycl/SortingKernels.h index bd5972a997..aad93d9eb6 100644 --- a/src/ATen/native/xpu/sycl/SortingKernels.h +++ b/src/ATen/native/xpu/sycl/SortingKernels.h @@ -316,6 +316,13 @@ void segmented_radix_sort_pairs_downsweep_kernel( // ======================= large sort ======================= +template +struct ABBufferCopyFunctor { + scalar_t operator()(scalar_t x) const { + return x; + } +}; + template < typename key_t, typename value_t, @@ -409,18 +416,20 @@ void segmented_radix_sort_pairs_kernel( auto input_calc = TrivialOffsetCalculator<2>(); at::detail::Array data; if (keys_out) { - auto q = at::xpu::getCurrentSYCLQueue(); - q.memcpy( - (void*)keys_out, - (void*)keys_temp, - sizeof(key_t) * num_segments * num_elements); + data[0] = (char*)keys_out; + data[1] = (char*)keys_temp; + auto fn = ABBufferCopyFunctor(); + auto vec_size = memory::can_vectorize_up_to(data); + launch_vectorized_kernel( + num_segments * num_elements, fn, data, input_calc, vec_size); } if (values_out) { - auto q = at::xpu::getCurrentSYCLQueue(); - q.memcpy( - (void*)values_out, - (void*)values_temp, - sizeof(value_t) * num_segments * num_elements); + data[0] = (char*)values_out; + data[1] = (char*)values_temp; + auto fn = ABBufferCopyFunctor(); + auto vec_size = memory::can_vectorize_up_to(data); + launch_vectorized_kernel( + num_segments * num_elements, fn, data, input_calc, vec_size); } } }