Skip to content

Commit

Permalink
Implement ranges::nth_element (#1063)
Browse files Browse the repository at this point in the history
  • Loading branch information
CaseyCarter authored Aug 1, 2020
1 parent accaf18 commit e58a945
Show file tree
Hide file tree
Showing 4 changed files with 324 additions and 9 deletions.
246 changes: 237 additions & 9 deletions stl/inc/algorithm
Original file line number Diff line number Diff line change
Expand Up @@ -7268,22 +7268,22 @@ void inplace_merge(_ExPo&&, _BidIt _First, _BidIt _Mid, _BidIt _Last) noexcept /

// FUNCTION TEMPLATE sort
template <class _BidIt, class _Pr>
_CONSTEXPR20 _BidIt _Insertion_sort_unchecked(_BidIt _First, const _BidIt _Last, _Pr _Pred) {
_CONSTEXPR20 _BidIt _Insertion_sort_unchecked(const _BidIt _First, const _BidIt _Last, _Pr _Pred) {
// insertion sort [_First, _Last)
if (_First != _Last) {
for (_BidIt _Next = _First; ++_Next != _Last;) { // order next element
_BidIt _Next1 = _Next;
_Iter_value_t<_BidIt> _Val = _STD move(*_Next);
for (_BidIt _Mid = _First; ++_Mid != _Last;) { // order next element
_BidIt _Hole = _Mid;
_Iter_value_t<_BidIt> _Val = _STD move(*_Mid);

if (_DEBUG_LT_PRED(_Pred, _Val, *_First)) { // found new earliest element, move to front
_Move_backward_unchecked(_First, _Next, ++_Next1);
_Move_backward_unchecked(_First, _Mid, ++_Hole);
*_First = _STD move(_Val);
} else { // look for insertion point after first
for (_BidIt _First1 = _Next1; _DEBUG_LT_PRED(_Pred, _Val, *--_First1); _Next1 = _First1) {
*_Next1 = _STD move(*_First1); // move hole down
for (_BidIt _Prev = _Hole; _DEBUG_LT_PRED(_Pred, _Val, *--_Prev); _Hole = _Prev) {
*_Hole = _STD move(*_Prev); // move hole down
}

*_Next1 = _STD move(_Val); // insert element in hole
*_Hole = _STD move(_Val); // insert element in hole
}
}
}
Expand Down Expand Up @@ -7347,6 +7347,7 @@ _CONSTEXPR20 pair<_RanIt, _RanIt> _Partition_by_median_guess_unchecked(_RanIt _F
for (;;) { // partition
for (; _Gfirst < _Last; ++_Gfirst) {
if (_DEBUG_LT_PRED(_Pred, *_Pfirst, *_Gfirst)) {
continue;
} else if (_Pred(*_Gfirst, *_Pfirst)) {
break;
} else if (_Plast != _Gfirst) {
Expand All @@ -7359,6 +7360,7 @@ _CONSTEXPR20 pair<_RanIt, _RanIt> _Partition_by_median_guess_unchecked(_RanIt _F

for (; _First < _Glast; --_Glast) {
if (_DEBUG_LT_PRED(_Pred, *_Prev_iter(_Glast), *_Pfirst)) {
continue;
} else if (_Pred(*_Pfirst, *_Prev_iter(_Glast))) {
break;
} else if (--_Pfirst != _Prev_iter(_Glast)) {
Expand Down Expand Up @@ -7444,6 +7446,156 @@ void sort(_ExPo&& _Exec, const _RanIt _First, const _RanIt _Last) noexcept /* te
// order [_First, _Last)
_STD sort(_STD forward<_ExPo>(_Exec), _First, _Last, less{});
}

#ifdef __cpp_lib_concepts
namespace ranges {
// clang-format off
template <bidirectional_iterator _It, class _Pr, class _Pj>
requires sortable<_It, _Pr, _Pj>
constexpr void _Insertion_sort_common(const _It _First, const _It _Last, _Pr _Pred, _Pj _Proj) {
// insertion sort [_First, _Last)

if (_First == _Last) { // empty range is sorted
return;
}

for (auto _Mid = _First; ++_Mid != _Last;) { // order next element
iter_value_t<_It> _Val = _RANGES iter_move(_Mid);
auto _Hole = _Mid;

for (auto _Prev = _Hole;;) {
--_Prev;
if (!_STD invoke(_Pred, _STD invoke(_Proj, _Val), _STD invoke(_Proj, *_Prev))) {
break;
}
*_Hole = _RANGES iter_move(_Prev); // move hole down
if (--_Hole == _First) {
break;
}
}

*_Hole = _STD move(_Val); // insert element in hole
}
}

template <random_access_iterator _It, class _Pr, class _Pj>
requires sortable<_It, _Pr, _Pj>
constexpr void _Med3_unchecked(_It _First, _It _Mid, _It _Last, _Pr _Pred, _Pj _Proj) {
// sort median of three elements to middle
if (_STD invoke(_Pred, _STD invoke(_Proj, *_Mid), _STD invoke(_Proj, *_First))) {
_RANGES iter_swap(_Mid, _First);
}

if (!_STD invoke(_Pred, _STD invoke(_Proj, *_Last), _STD invoke(_Proj, *_Mid))) {
return;
}

// swap middle and last, then test first again
_RANGES iter_swap(_Last, _Mid);

if (_STD invoke(_Pred, _STD invoke(_Proj, *_Mid), _STD invoke(_Proj, *_First))) {
_RANGES iter_swap(_Mid, _First);
}
}

template <random_access_iterator _It, class _Pr, class _Pj>
requires sortable<_It, _Pr, _Pj>
constexpr void _Guess_median_unchecked(_It _First, _It _Mid, _It _Last, _Pr _Pred, _Pj _Proj) {
// sort median element to middle
using _Diff = iter_difference_t<_It>;
const _Diff _Count = _Last - _First;
if (_Count > 40) { // Tukey's ninther
const _Diff _Step = (_Count + 1) >> 3; // +1 can't overflow because range was made inclusive in caller
const _Diff _Two_step = _Step << 1; // note: intentionally discards low-order bit
_Med3_unchecked(_First, _First + _Step, _First + _Two_step, _Pred, _Proj);
_Med3_unchecked(_Mid - _Step, _Mid, _Mid + _Step, _Pred, _Proj);
_Med3_unchecked(_Last - _Two_step, _Last - _Step, _Last, _Pred, _Proj);
_Med3_unchecked(_First + _Step, _Mid, _Last - _Step, _Pred, _Proj);
} else {
_Med3_unchecked(_First, _Mid, _Last, _Pred, _Proj);
}
}

template <random_access_iterator _It, class _Pr, class _Pj>
requires sortable<_It, _Pr, _Pj>
_NODISCARD constexpr subrange<_It> _Partition_by_median_guess_unchecked(
_It _First, _It _Last, _Pr _Pred, _Pj _Proj) {
// Choose a pivot, partition [_First, _Last) into elements less than pivot, elements equal to pivot, and
// elements greater than pivot; return the equal partition as a subrange.

_It _Mid = _First + ((_Last - _First) >> 1); // shift for codegen
_RANGES _Guess_median_unchecked(_First, _Mid, _RANGES prev(_Last), _Pred, _Proj);
_It _Pfirst = _Mid;
_It _Plast = _RANGES next(_Pfirst);

while (_First < _Pfirst
&& !_STD invoke(_Pred, _STD invoke(_Proj, *_RANGES prev(_Pfirst)), _STD invoke(_Proj, *_Pfirst))
&& !_STD invoke(_Pred, _STD invoke(_Proj, *_Pfirst), _STD invoke(_Proj, *_RANGES prev(_Pfirst)))) {
--_Pfirst;
}

while (_Plast < _Last
&& !_STD invoke(_Pred, _STD invoke(_Proj, *_Plast), _STD invoke(_Proj, *_Pfirst))
&& !_STD invoke(_Pred, _STD invoke(_Proj, *_Pfirst), _STD invoke(_Proj, *_Plast))) {
++_Plast;
}

_It _Gfirst = _Plast;
_It _Glast = _Pfirst;

for (;;) { // partition
for (; _Gfirst < _Last; ++_Gfirst) {
if (_STD invoke(_Pred, _STD invoke(_Proj, *_Pfirst), _STD invoke(_Proj, *_Gfirst))) {
continue;
} else if (_STD invoke(_Pred, _STD invoke(_Proj, *_Gfirst), _STD invoke(_Proj, *_Pfirst))) {
break;
} else if (_Plast != _Gfirst) {
_RANGES iter_swap(_Plast, _Gfirst);
++_Plast;
} else {
++_Plast;
}
}

for (; _First < _Glast; --_Glast) {
if (_STD invoke(_Pred, _STD invoke(_Proj, *_RANGES prev(_Glast)), _STD invoke(_Proj, *_Pfirst))) {
continue;
} else if (_STD invoke(
_Pred, _STD invoke(_Proj, *_Pfirst), _STD invoke(_Proj, *_RANGES prev(_Glast)))) {
break;
} else if (--_Pfirst != _RANGES prev(_Glast)) {
_RANGES iter_swap(_Pfirst, _RANGES prev(_Glast));
}
}

if (_Glast == _First && _Gfirst == _Last) {
return {_STD move(_Pfirst), _STD move(_Plast)};
}

if (_Glast == _First) { // no room at bottom, rotate pivot upward
if (_Plast != _Gfirst) {
_RANGES iter_swap(_Pfirst, _Plast);
}

++_Plast;
_RANGES iter_swap(_Pfirst, _Gfirst);
++_Pfirst;
++_Gfirst;
} else if (_Gfirst == _Last) { // no room at top, rotate pivot downward
if (--_Glast != --_Pfirst) {
_RANGES iter_swap(_Glast, _Pfirst);
}

_RANGES iter_swap(_Pfirst, --_Plast);
} else {
_RANGES iter_swap(_Gfirst, --_Glast);
++_Gfirst;
}
}
}
// clang-format on
} // namespace ranges
#endif // __cpp_lib_concepts
#endif // _HAS_CXX17

// FUNCTION TEMPLATE stable_sort
Expand Down Expand Up @@ -7765,7 +7917,7 @@ _CONSTEXPR20 void nth_element(_RanIt _First, _RanIt _Nth, _RanIt _Last, _Pr _Pre
if (_UMid.second <= _UNth) {
_UFirst = _UMid.second;
} else if (_UMid.first <= _UNth) {
return; // Nth inside fat pivot, done
return; // _Nth is in the subrange of elements equal to the pivot; done
} else {
_ULast = _UMid.first;
}
Expand Down Expand Up @@ -7793,6 +7945,82 @@ void nth_element(_ExPo&&, _RanIt _First, _RanIt _Nth, _RanIt _Last) noexcept /*
// not parallelized at present, parallelism expected to be feasible in a future release
_STD nth_element(_First, _Nth, _Last);
}

#ifdef __cpp_lib_concepts
namespace ranges {
// VARIABLE ranges::nth_element
class _Nth_element_fn : private _Not_quite_object {
public:
using _Not_quite_object::_Not_quite_object;

// clang-format off
template <random_access_iterator _It, sentinel_for<_It> _Se, class _Pr = ranges::less, class _Pj = identity>
requires sortable<_It, _Pr, _Pj>
constexpr _It operator()(_It _First, _It _Nth, _Se _Last, _Pr _Pred = {}, _Pj _Proj = {}) const {
_Adl_verify_range(_First, _Nth);
_Adl_verify_range(_Nth, _Last);
auto _UNth = _Get_unwrapped(_Nth);
auto _UFinal = _Get_final_iterator_unwrapped<_It>(_UNth, _STD move(_Last));
_Seek_wrapped(_Nth, _UFinal);

_Nth_element_common(_Get_unwrapped(_STD move(_First)), _STD move(_UNth), _STD move(_UFinal),
_Pass_fn(_Pred), _Pass_fn(_Proj));
return _Nth;
}

template <random_access_range _Rng, class _Pr = ranges::less, class _Pj = identity>
requires sortable<iterator_t<_Rng>, _Pr, _Pj>
constexpr borrowed_iterator_t<_Rng> operator()(
_Rng&& _Range, iterator_t<_Rng> _Nth, _Pr _Pred = {}, _Pj _Proj = {}) const {
_Adl_verify_range(_RANGES begin(_Range), _Nth);
_Adl_verify_range(_Nth, _RANGES end(_Range));
auto _UNth = _Get_unwrapped(_Nth);
auto _UFinal = [&] {
if constexpr (common_range<_Rng>) {
return _Uend(_Range);
} else if constexpr (sized_range<_Rng>) {
return _RANGES next(_Ubegin(_Range), _RANGES distance(_Range));
} else {
return _RANGES next(_UNth, _Uend(_Range));
}
}();
_Seek_wrapped(_Nth, _UFinal);

_Nth_element_common(
_Ubegin(_Range), _STD move(_UNth), _STD move(_UFinal), _Pass_fn(_Pred), _Pass_fn(_Proj));
return _Nth;
}
// clang-format on
private:
template <class _It, class _Pr, class _Pj>
static constexpr void _Nth_element_common(_It _First, _It _Nth, _It _Last, _Pr _Pred, _Pj _Proj) {
_STL_INTERNAL_STATIC_ASSERT(random_access_iterator<_It>);
_STL_INTERNAL_STATIC_ASSERT(sortable<_It, _Pr, _Pj>);

if (_Nth == _Last) {
return; // nothing to do
}

while (_ISORT_MAX < _Last - _First) { // divide and conquer, ordering partition containing Nth
subrange<_It> _Mid = _RANGES _Partition_by_median_guess_unchecked(_First, _Last, _Pred, _Proj);

if (_Mid.end() <= _Nth) {
_First = _Mid.end();
} else if (_Mid.begin() <= _Nth) {
return; // _Nth is in the subrange of elements equal to the pivot; done
} else {
_Last = _Mid.begin();
}
}

// sort any remainder
_RANGES _Insertion_sort_common(_STD move(_First), _STD move(_Last), _STD move(_Pred), _STD move(_Proj));
}
};

inline constexpr _Nth_element_fn nth_element{_Not_quite_object::_Construct_tag{}};
} // namespace ranges
#endif // __cpp_lib_concepts
#endif // _HAS_CXX17

// FUNCTION TEMPLATE includes
Expand Down
1 change: 1 addition & 0 deletions tests/std/test.lst
Original file line number Diff line number Diff line change
Expand Up @@ -271,6 +271,7 @@ tests\P0896R4_ranges_alg_mismatch
tests\P0896R4_ranges_alg_move
tests\P0896R4_ranges_alg_move_backward
tests\P0896R4_ranges_alg_none_of
tests\P0896R4_ranges_alg_nth_element
tests\P0896R4_ranges_alg_partition
tests\P0896R4_ranges_alg_partition_copy
tests\P0896R4_ranges_alg_partition_point
Expand Down
4 changes: 4 additions & 0 deletions tests/std/tests/P0896R4_ranges_alg_nth_element/env.lst
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

RUNALL_INCLUDE ..\concepts_matrix.lst
82 changes: 82 additions & 0 deletions tests/std/tests/P0896R4_ranges_alg_nth_element/test.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
// Copyright (c) Microsoft Corporation.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

#include <algorithm>
#include <cassert>
#include <concepts>
#include <ranges>
#include <utility>

#include <range_algorithm_support.hpp>

using namespace std;

// Validate dangling story
STATIC_ASSERT(same_as<decltype(ranges::nth_element(borrowed<false>{}, nullptr_to<int>)), ranges::dangling>);
STATIC_ASSERT(same_as<decltype(ranges::nth_element(borrowed<true>{}, nullptr_to<int>)), int*>);

using P = pair<int, int>;

struct instantiator {
static constexpr int keys[] = {7, 6, 5, 4, 3, 2, 1, 0};

template <ranges::random_access_range R>
static constexpr void call() {
#if !defined(__clang__) && !defined(__EDG__) // TRANSITION, VSO-938163
#pragma warning(suppress : 4127) // conditional expression is constant
if (!ranges::contiguous_range<R> || !is_constant_evaluated())
#endif // TRANSITION, VSO-938163
{
using ranges::nth_element, ranges::all_of, ranges::find, ranges::iterator_t, ranges::less, ranges::none_of,
ranges::size;

P input[size(keys)];
const auto init = [&] {
for (size_t j = 0; j < size(keys); ++j) {
input[j] = P{keys[j], static_cast<int>(10 + j)};
}
};

// Validate range overload
for (int i = 0; i < int{size(keys)}; ++i) {
init();
const R wrapped{input};
const auto nth = wrapped.begin() + i;
const same_as<iterator_t<R>> auto result = nth_element(wrapped, nth, less{}, get_first);
assert(result == wrapped.end());
assert((*nth == P{i, static_cast<int>(10 + (find(keys, i) - keys))}));
if (nth != wrapped.end()) {
assert(all_of(wrapped.begin(), nth, [&](auto&& x) { return get_first(x) <= get_first(*nth); }));
assert(all_of(nth, wrapped.end(), [&](auto&& x) { return get_first(*nth) <= get_first(x); }));
}
}

// Validate iterator overload
for (int i = 0; i < int{size(keys)}; ++i) {
init();
const R wrapped{input};
const auto nth = wrapped.begin() + i;
const same_as<iterator_t<R>> auto result =
nth_element(wrapped.begin(), nth, wrapped.end(), less{}, get_first);
assert(result == wrapped.end());
assert((input[i] == P{i, static_cast<int>(10 + (find(keys, i) - keys))}));
if (nth != wrapped.end()) {
assert(all_of(wrapped.begin(), nth, [&](auto&& x) { return get_first(x) <= get_first(*nth); }));
assert(all_of(nth, wrapped.end(), [&](auto&& x) { return get_first(*nth) <= get_first(x); }));
}
}

{
// Validate empty range
const R range{};
const same_as<iterator_t<R>> auto result = nth_element(range, range.begin(), less{}, get_first);
assert(result == range.end());
}
}
}
};

int main() {
STATIC_ASSERT((test_random<instantiator, P>(), true));
test_random<instantiator, P>();
}

0 comments on commit e58a945

Please sign in to comment.