Skip to content

Try to get better type errors for the static sorting functions #202

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 13 additions & 5 deletions src/xss-common-argsort.h
Original file line number Diff line number Diff line change
Expand Up @@ -584,6 +584,14 @@ X86_SIMD_SORT_INLINE void argselect_(type_t *arr,
arr, arg, pos, pivot_index, right, max_iters - 1);
}

template <typename T, typename vtype>
X86_SIMD_SORT_FINLINE bool is_sorted(T *arr, arrsize_t arrsize, bool descending)
{
auto comp = descending ? Comparator<vtype, true>::STDSortComparator
: Comparator<vtype, false>::STDSortComparator;
return std::is_sorted(arr, arr + arrsize, comp);
}

/* argsort methods for 32-bit and 64-bit dtypes */
template <typename T,
template <typename...>
Expand All @@ -600,11 +608,12 @@ X86_SIMD_SORT_INLINE void xss_argsort(T *arr,
using vectype = typename std::conditional<sizeof(T) == sizeof(int32_t),
half_vector<T>,
full_vector<T>>::type;

using argtype =
typename std::conditional<sizeof(arrsize_t) == sizeof(int32_t),
half_vector<arrsize_t>,
full_vector<arrsize_t>>::type;
static_assert(is_valid_vector_type_key_value<vectype, argtype>(),
"Invalid type for argsort!");

if (arrsize > 1) {
/* simdargsort does not work for float/double arrays with nan */
Expand All @@ -620,9 +629,7 @@ X86_SIMD_SORT_INLINE void xss_argsort(T *arr,
UNUSED(hasnan);

/* early exit for already sorted arrays: float/double with nan never reach here*/
auto comp = descending ? Comparator<vectype, true>::STDSortComparator
: Comparator<vectype, false>::STDSortComparator;
if (std::is_sorted(arr, arr + arrsize, comp)) { return; }
if (is_sorted<T, vectype>(arr, arrsize, descending)) { return; }

#ifdef XSS_COMPILE_OPENMP

Expand Down Expand Up @@ -708,11 +715,12 @@ X86_SIMD_SORT_INLINE void xss_argselect(T *arr,
using vectype = typename std::conditional<sizeof(T) == sizeof(int32_t),
half_vector<T>,
full_vector<T>>::type;

using argtype =
typename std::conditional<sizeof(arrsize_t) == sizeof(int32_t),
half_vector<arrsize_t>,
full_vector<arrsize_t>>::type;
static_assert(is_valid_vector_type_key_value<vectype, argtype>(),
"Invalid type for argselect!");

if (arrsize > 1) {
if constexpr (xss::fp::is_floating_point_v<T>) {
Expand Down
45 changes: 39 additions & 6 deletions src/xss-common-includes.h
Original file line number Diff line number Diff line change
Expand Up @@ -92,19 +92,27 @@ constexpr bool always_false = false;

typedef size_t arrsize_t;

template <typename type>
struct zmm_vector;
enum class simd_type : int { INVALID, AVX2, AVX512 };

template <typename type>
struct ymm_vector;
struct zmm_vector {
static constexpr simd_type vec_type = simd_type::INVALID;
};

template <typename type>
struct avx2_vector;
struct ymm_vector {
static constexpr simd_type vec_type = simd_type::INVALID;
};

template <typename type>
struct avx2_half_vector;
struct avx2_vector {
static constexpr simd_type vec_type = simd_type::INVALID;
};

enum class simd_type : int { AVX2, AVX512 };
template <typename type>
struct avx2_half_vector {
static constexpr simd_type vec_type = simd_type::INVALID;
};

template <typename vtype, typename T = typename vtype::type_t>
X86_SIMD_SORT_INLINE bool comparison_func(const T &a, const T &b);
Expand All @@ -113,4 +121,29 @@ struct float16 {
uint16_t val;
};

template <typename vtype>
constexpr bool is_valid_vector_type()
{
return vtype::vec_type != simd_type::INVALID;
}

template <typename vtype>
constexpr bool is_valid_vector_type_32_or_64_bit()
{
if constexpr (is_valid_vector_type<vtype>()) {
constexpr int type_size = sizeof(typename vtype::type_t);
return type_size == 4 || type_size == 8;
}
else {
return false;
}
}

template <typename vtype1, typename vtype2>
constexpr bool is_valid_vector_type_key_value()
{
return is_valid_vector_type_32_or_64_bit<vtype1>()
&& is_valid_vector_type_32_or_64_bit<vtype2>();
}

#endif // XSS_COMMON_INCLUDES
17 changes: 17 additions & 0 deletions src/xss-common-keyvaluesort.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -580,6 +580,8 @@ X86_SIMD_SORT_INLINE void xss_qsort_kv(
&& sizeof(T2) == sizeof(int32_t),
half_vector<T2>,
full_vector<T2>>::type;
static_assert(is_valid_vector_type_key_value<keytype, valtype>(),
"Invalid type for keyvalue_qsort!");

// Exit early if no work would be done
if (arrsize <= 1) return;
Expand Down Expand Up @@ -677,6 +679,8 @@ X86_SIMD_SORT_INLINE void xss_select_kv(T1 *keys,
&& sizeof(T2) == sizeof(int32_t),
half_vector<T2>,
full_vector<T2>>::type;
static_assert(is_valid_vector_type_key_value<keytype, valtype>(),
"Invalid type for keyvalue_select!");

// Exit early if no work would be done
if (arrsize <= 1) return;
Expand Down Expand Up @@ -732,6 +736,19 @@ X86_SIMD_SORT_INLINE void xss_partial_sort_kv(T1 *keys,
bool hasnan,
bool descending)
{
using keytype =
typename std::conditional<sizeof(T1) != sizeof(T2)
&& sizeof(T1) == sizeof(int32_t),
half_vector<T1>,
full_vector<T1>>::type;
using valtype =
typename std::conditional<sizeof(T1) != sizeof(T2)
&& sizeof(T2) == sizeof(int32_t),
half_vector<T2>,
full_vector<T2>>::type;
static_assert(is_valid_vector_type_key_value<keytype, valtype>(),
"Invalid type for keyvalue_partial_sort!");

if (k == 0) return;
xss_select_kv<T1, T2, full_vector, half_vector>(
keys, indexes, k - 1, arrsize, hasnan, descending);
Expand Down
4 changes: 4 additions & 0 deletions src/xss-common-qsort.h
Original file line number Diff line number Diff line change
Expand Up @@ -652,6 +652,7 @@ X86_SIMD_SORT_INLINE void qselect_(type_t *arr,
template <typename vtype, typename T, bool descending = false>
X86_SIMD_SORT_INLINE void xss_qsort(T *arr, arrsize_t arrsize, bool hasnan)
{
static_assert(is_valid_vector_type<vtype>(), "Invalid type for qsort!");
using comparator =
typename std::conditional<descending,
Comparator<vtype, true>,
Expand Down Expand Up @@ -716,6 +717,7 @@ template <typename vtype, typename T, bool descending = false>
X86_SIMD_SORT_INLINE void
xss_qselect(T *arr, arrsize_t k, arrsize_t arrsize, bool hasnan)
{
static_assert(is_valid_vector_type<vtype>(), "Invalid type for qselect!");
using comparator =
typename std::conditional<descending,
Comparator<vtype, true>,
Expand Down Expand Up @@ -758,6 +760,8 @@ template <typename vtype, typename T, bool descending = false>
X86_SIMD_SORT_INLINE void
xss_partial_qsort(T *arr, arrsize_t k, arrsize_t arrsize, bool hasnan)
{
static_assert(is_valid_vector_type<vtype>(),
"Invalid type for partial_qsort!");
if (k == 0) return;
xss_qselect<vtype, T, descending>(arr, k - 1, arrsize, hasnan);
xss_qsort<vtype, T, descending>(arr, k - 1, hasnan);
Expand Down