diff --git a/src/xss-common-argsort.h b/src/xss-common-argsort.h index 4f5c30d..93d46bc 100644 --- a/src/xss-common-argsort.h +++ b/src/xss-common-argsort.h @@ -584,6 +584,14 @@ X86_SIMD_SORT_INLINE void argselect_(type_t *arr, arr, arg, pos, pivot_index, right, max_iters - 1); } +template +X86_SIMD_SORT_FINLINE bool is_sorted(T *arr, arrsize_t arrsize, bool descending) +{ + auto comp = descending ? Comparator::STDSortComparator + : Comparator::STDSortComparator; + return std::is_sorted(arr, arr + arrsize, comp); +} + /* argsort methods for 32-bit and 64-bit dtypes */ template @@ -600,11 +608,12 @@ X86_SIMD_SORT_INLINE void xss_argsort(T *arr, using vectype = typename std::conditional, full_vector>::type; - using argtype = typename std::conditional, full_vector>::type; + static_assert(is_valid_vector_type_key_value(), + "Invalid type for argsort!"); if (arrsize > 1) { /* simdargsort does not work for float/double arrays with nan */ @@ -620,9 +629,7 @@ X86_SIMD_SORT_INLINE void xss_argsort(T *arr, UNUSED(hasnan); /* early exit for already sorted arrays: float/double with nan never reach here*/ - auto comp = descending ? Comparator::STDSortComparator - : Comparator::STDSortComparator; - if (std::is_sorted(arr, arr + arrsize, comp)) { return; } + if (is_sorted(arr, arrsize, descending)) { return; } #ifdef XSS_COMPILE_OPENMP @@ -708,11 +715,12 @@ X86_SIMD_SORT_INLINE void xss_argselect(T *arr, using vectype = typename std::conditional, full_vector>::type; - using argtype = typename std::conditional, full_vector>::type; + static_assert(is_valid_vector_type_key_value(), + "Invalid type for argselect!"); if (arrsize > 1) { if constexpr (xss::fp::is_floating_point_v) { diff --git a/src/xss-common-includes.h b/src/xss-common-includes.h index a7c34c1..46bfb18 100644 --- a/src/xss-common-includes.h +++ b/src/xss-common-includes.h @@ -92,19 +92,27 @@ constexpr bool always_false = false; typedef size_t arrsize_t; -template -struct zmm_vector; +enum class simd_type : int { INVALID, AVX2, AVX512 }; template -struct ymm_vector; +struct zmm_vector { + static constexpr simd_type vec_type = simd_type::INVALID; +}; template -struct avx2_vector; +struct ymm_vector { + static constexpr simd_type vec_type = simd_type::INVALID; +}; template -struct avx2_half_vector; +struct avx2_vector { + static constexpr simd_type vec_type = simd_type::INVALID; +}; -enum class simd_type : int { AVX2, AVX512 }; +template +struct avx2_half_vector { + static constexpr simd_type vec_type = simd_type::INVALID; +}; template X86_SIMD_SORT_INLINE bool comparison_func(const T &a, const T &b); @@ -113,4 +121,29 @@ struct float16 { uint16_t val; }; +template +constexpr bool is_valid_vector_type() +{ + return vtype::vec_type != simd_type::INVALID; +} + +template +constexpr bool is_valid_vector_type_32_or_64_bit() +{ + if constexpr (is_valid_vector_type()) { + constexpr int type_size = sizeof(typename vtype::type_t); + return type_size == 4 || type_size == 8; + } + else { + return false; + } +} + +template +constexpr bool is_valid_vector_type_key_value() +{ + return is_valid_vector_type_32_or_64_bit() + && is_valid_vector_type_32_or_64_bit(); +} + #endif // XSS_COMMON_INCLUDES diff --git a/src/xss-common-keyvaluesort.hpp b/src/xss-common-keyvaluesort.hpp index 1a15de7..775d439 100644 --- a/src/xss-common-keyvaluesort.hpp +++ b/src/xss-common-keyvaluesort.hpp @@ -580,6 +580,8 @@ X86_SIMD_SORT_INLINE void xss_qsort_kv( && sizeof(T2) == sizeof(int32_t), half_vector, full_vector>::type; + static_assert(is_valid_vector_type_key_value(), + "Invalid type for keyvalue_qsort!"); // Exit early if no work would be done if (arrsize <= 1) return; @@ -677,6 +679,8 @@ X86_SIMD_SORT_INLINE void xss_select_kv(T1 *keys, && sizeof(T2) == sizeof(int32_t), half_vector, full_vector>::type; + static_assert(is_valid_vector_type_key_value(), + "Invalid type for keyvalue_select!"); // Exit early if no work would be done if (arrsize <= 1) return; @@ -732,6 +736,19 @@ X86_SIMD_SORT_INLINE void xss_partial_sort_kv(T1 *keys, bool hasnan, bool descending) { + using keytype = + typename std::conditional, + full_vector>::type; + using valtype = + typename std::conditional, + full_vector>::type; + static_assert(is_valid_vector_type_key_value(), + "Invalid type for keyvalue_partial_sort!"); + if (k == 0) return; xss_select_kv( keys, indexes, k - 1, arrsize, hasnan, descending); diff --git a/src/xss-common-qsort.h b/src/xss-common-qsort.h index cf4a34a..f43124f 100644 --- a/src/xss-common-qsort.h +++ b/src/xss-common-qsort.h @@ -652,6 +652,7 @@ X86_SIMD_SORT_INLINE void qselect_(type_t *arr, template X86_SIMD_SORT_INLINE void xss_qsort(T *arr, arrsize_t arrsize, bool hasnan) { + static_assert(is_valid_vector_type(), "Invalid type for qsort!"); using comparator = typename std::conditional, @@ -716,6 +717,7 @@ template X86_SIMD_SORT_INLINE void xss_qselect(T *arr, arrsize_t k, arrsize_t arrsize, bool hasnan) { + static_assert(is_valid_vector_type(), "Invalid type for qselect!"); using comparator = typename std::conditional, @@ -758,6 +760,8 @@ template X86_SIMD_SORT_INLINE void xss_partial_qsort(T *arr, arrsize_t k, arrsize_t arrsize, bool hasnan) { + static_assert(is_valid_vector_type(), + "Invalid type for partial_qsort!"); if (k == 0) return; xss_qselect(arr, k - 1, arrsize, hasnan); xss_qsort(arr, k - 1, hasnan);