Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Vectorize lexicographical_compare! #4552

Merged
27 changes: 21 additions & 6 deletions benchmarks/src/mismatch.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,12 @@ using namespace std;

constexpr int64_t no_pos = -1;

template <class T>
enum class op {
mismatch,
lexi,
};

template <class T, op Op>
void bm(benchmark::State& state) {
vector<T> a(static_cast<size_t>(state.range(0)), T{'.'});
vector<T> b(static_cast<size_t>(state.range(0)), T{'.'});
Expand All @@ -22,15 +27,25 @@ void bm(benchmark::State& state) {
}

for (auto _ : state) {
benchmark::DoNotOptimize(ranges::mismatch(a, b));
if constexpr (Op == op::mismatch) {
benchmark::DoNotOptimize(ranges::mismatch(a, b));
} else if constexpr (Op == op::lexi) {
benchmark::DoNotOptimize(ranges::lexicographical_compare(a, b));
}
}
}

#define COMMON_ARGS Args({8, 3})->Args({24, 22})->Args({105, -1})->Args({4021, 3056})

BENCHMARK(bm<uint8_t>)->COMMON_ARGS;
BENCHMARK(bm<uint16_t>)->COMMON_ARGS;
BENCHMARK(bm<uint32_t>)->COMMON_ARGS;
BENCHMARK(bm<uint64_t>)->COMMON_ARGS;
BENCHMARK(bm<uint8_t, op::mismatch>)->COMMON_ARGS;
BENCHMARK(bm<uint16_t, op::mismatch>)->COMMON_ARGS;
BENCHMARK(bm<uint32_t, op::mismatch>)->COMMON_ARGS;
BENCHMARK(bm<uint64_t, op::mismatch>)->COMMON_ARGS;

BENCHMARK(bm<uint8_t, op::lexi>)->COMMON_ARGS;
BENCHMARK(bm<int8_t, op::lexi>)->COMMON_ARGS;
BENCHMARK(bm<uint16_t, op::lexi>)->COMMON_ARGS;
BENCHMARK(bm<uint32_t, op::lexi>)->COMMON_ARGS;
BENCHMARK(bm<uint64_t, op::lexi>)->COMMON_ARGS;

BENCHMARK_MAIN();
16 changes: 15 additions & 1 deletion stl/inc/algorithm
Original file line number Diff line number Diff line change
Expand Up @@ -10886,8 +10886,22 @@ namespace ranges {
_Num2 = SIZE_MAX;
}

const int _Ans = _STD _Memcmp_count(_First1, _First2, (_STD min)(_Num1, _Num2));
const size_t _Num = (_STD min)(_Num1, _Num2);
#if _USE_STD_VECTOR_ALGORITHMS
const auto _First1_ptr = _STD to_address(_First1);
const auto _First2_ptr = _STD to_address(_First2);
const size_t _Pos = __std_mismatch<sizeof(*_First1_ptr)>(_First1_ptr, _First2_ptr, _Num);
if (_Pos == _Num2) {
return false;
} else if (_Pos == _Num1) {
return true;
} else {
return _STD invoke(_Pred, _First1_ptr[_Pos], _First2_ptr[_Pos]);
}
#else // ^^^ _USE_STD_VECTOR_ALGORITHMS / !_USE_STD_VECTOR_ALGORITHMS vvv
const int _Ans = _STD _Memcmp_count(_First1, _First2, _Num);
return _Memcmp_classification_pred{}(_Ans, 0) || (_Ans == 0 && _Num1 < _Num2);
#endif // ^^^ !_USE_STD_VECTOR_ALGORITHMS ^^^
}
}

Expand Down
50 changes: 42 additions & 8 deletions stl/inc/xutility
Original file line number Diff line number Diff line change
Expand Up @@ -5609,8 +5609,14 @@ namespace ranges {
#endif // _HAS_CXX20

template <class _Elem1, class _Elem2>
_INLINE_VAR constexpr bool _Lex_compare_memcmp_classify_elements = conjunction_v<_Is_character_or_bool<_Elem1>,
_Is_character_or_bool<_Elem2>, is_unsigned<_Elem1>, is_unsigned<_Elem2>>;
constexpr bool _Lex_compare_memcmp_classify_elements =
#if _USE_STD_VECTOR_ALGORITHMS
is_integral_v<_Elem1> && is_integral_v<_Elem2> && sizeof(_Elem1) == sizeof(_Elem2)
&& is_unsigned_v<_Elem1> == is_unsigned_v<_Elem2>;
#else // ^^^ _USE_STD_VECTOR_ALGORITHMS / !_USE_STD_VECTOR_ALGORITHMS vvv
conjunction_v<_Is_character_or_bool<_Elem1>, _Is_character_or_bool<_Elem2>, is_unsigned<_Elem1>,
is_unsigned<_Elem2>>;
#endif // ^^^ !_USE_STD_VECTOR_ALGORITHMS ^^^

#ifdef __cpp_lib_byte
template <>
Expand Down Expand Up @@ -5682,10 +5688,24 @@ _NODISCARD _CONSTEXPR20 bool lexicographical_compare(
if (!_STD is_constant_evaluated())
#endif // _HAS_CXX20
{
const auto _Num1 = static_cast<size_t>(_ULast1 - _UFirst1);
const auto _Num2 = static_cast<size_t>(_ULast2 - _UFirst2);
const int _Ans = _STD _Memcmp_count(_UFirst1, _UFirst2, (_STD min)(_Num1, _Num2));
const auto _Num1 = static_cast<size_t>(_ULast1 - _UFirst1);
const auto _Num2 = static_cast<size_t>(_ULast2 - _UFirst2);
const size_t _Num = (_STD min)(_Num1, _Num2);
#if _USE_STD_VECTOR_ALGORITHMS
const auto _First1_ptr = _STD _To_address(_UFirst1);
const auto _First2_ptr = _STD _To_address(_UFirst2);
const size_t _Pos = __std_mismatch<sizeof(*_First1_ptr)>(_First1_ptr, _First2_ptr, _Num);
if (_Pos == _Num2) {
return false;
} else if (_Pos == _Num1) {
return true;
} else {
return _Pred(_First1_ptr[_Pos], _First2_ptr[_Pos]);
}
#else // ^^^ _USE_STD_VECTOR_ALGORITHMS / !_USE_STD_VECTOR_ALGORITHMS vvv
const int _Ans = _STD _Memcmp_count(_UFirst1, _UFirst2, _Num);
return _Memcmp_pred{}(_Ans, 0) || (_Ans == 0 && _Num1 < _Num2);
#endif // ^^^ !_USE_STD_VECTOR_ALGORITHMS ^^^
}
}

Expand Down Expand Up @@ -5782,14 +5802,28 @@ _NODISCARD constexpr auto lexicographical_compare_three_way(const _InIt1 _First1
using _Memcmp_pred = _Lex_compare_three_way_memcmp_classify<decltype(_UFirst1), decltype(_UFirst2), _Cmp>;
if constexpr (!is_void_v<_Memcmp_pred>) {
if (!_STD is_constant_evaluated()) {
const auto _Num1 = static_cast<size_t>(_ULast1 - _UFirst1);
const auto _Num2 = static_cast<size_t>(_ULast2 - _UFirst2);
const int _Ans = _STD _Memcmp_count(_UFirst1, _UFirst2, (_STD min)(_Num1, _Num2));
const auto _Num1 = static_cast<size_t>(_ULast1 - _UFirst1);
const auto _Num2 = static_cast<size_t>(_ULast2 - _UFirst2);
const size_t _Num = (_STD min)(_Num1, _Num2);
#if _USE_STD_VECTOR_ALGORITHMS
const auto _First1_ptr = _STD to_address(_UFirst1);
const auto _First2_ptr = _STD to_address(_UFirst2);
const size_t _Pos = __std_mismatch<sizeof(*_First1_ptr)>(_First1_ptr, _First2_ptr, _Num);
if (_Pos == _Num1) {
return _Pos == _Num2 ? strong_ordering::equal : strong_ordering::less;
} else if (_Pos == _Num2) {
return strong_ordering::greater;
} else {
return _Comp(_First1_ptr[_Pos], _First2_ptr[_Pos]);
}
#else // ^^^ _USE_STD_VECTOR_ALGORITHMS / !_USE_STD_VECTOR_ALGORITHMS vvv
const int _Ans = _STD _Memcmp_count(_UFirst1, _UFirst2, _Num);
if (_Ans == 0) {
return _Num1 <=> _Num2;
} else {
return _Memcmp_pred{}(_Ans, 0);
}
#endif // ^^^ !_USE_STD_VECTOR_ALGORITHMS ^^^
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,3 +2,6 @@
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

RUNALL_INCLUDE ..\char8_t_matrix.lst
RUNALL_CROSSLIST
* PM_CL="" # Test memcmp and manual vectorization
* PM_CL="/D_USE_STD_VECTOR_ALGORITHMS=0" # Test memcmp only
Original file line number Diff line number Diff line change
Expand Up @@ -130,8 +130,10 @@ void test_lex_compare_memcmp_classify_for_types() {
test_lex_compare_memcmp_classify_for_pred<expected_less, Type1, Type2, less<>>();
test_lex_compare_memcmp_classify_for_pred<expected_greater, Type1, Type2, greater<>>();

test_lex_compare_memcmp_classify_for_pred<void, Type1, Type2, less<int>>();
test_lex_compare_memcmp_classify_for_pred<void, Type1, Type2, greater<int>>();
using bigger_type = conditional_t<sizeof(Type1) == sizeof(int), long long, int>;

test_lex_compare_memcmp_classify_for_pred<void, Type1, Type2, less<bigger_type>>();
test_lex_compare_memcmp_classify_for_pred<void, Type1, Type2, greater<bigger_type>>();

test_lex_compare_memcmp_classify_for_pred<void, Type1, Type2, less<volatile Type1>>();
test_lex_compare_memcmp_classify_for_pred<void, Type1, Type2, greater<volatile Type1>>();
Expand Down Expand Up @@ -176,13 +178,15 @@ void test_lex_compare_memcmp_classify_for_types() {
#endif // _HAS_CXX20
}

constexpr bool vec_alg = _USE_STD_VECTOR_ALGORITHMS;

template <bool Expected, class Type1, class Type2>
void test_lex_compare_memcmp_classify_for_1byte_integrals() {
test_lex_compare_memcmp_classify_for_types<Expected, Type1, Type2>();

test_lex_compare_memcmp_classify_for_opaque_preds<is_unsigned_v<char>, Type1, Type2, char>();
test_lex_compare_memcmp_classify_for_opaque_preds<is_unsigned_v<char> || vec_alg, Type1, Type2, char>();
test_lex_compare_memcmp_classify_for_opaque_preds<true, Type1, Type2, unsigned char>();
test_lex_compare_memcmp_classify_for_opaque_preds<false, Type1, Type2, signed char>();
test_lex_compare_memcmp_classify_for_opaque_preds<vec_alg, Type1, Type2, signed char>();
#ifdef __cpp_lib_char8_t
test_lex_compare_memcmp_classify_for_opaque_preds<true, Type1, Type2, char8_t>();
#endif // __cpp_lib_char8_t
Expand Down Expand Up @@ -225,13 +229,13 @@ bool operator<(const user_struct&, const user_struct&) {

void lex_compare_memcmp_classify_test_cases() {
// Allow unsigned 1 byte integrals
test_lex_compare_memcmp_classify_for_1byte_integrals<is_unsigned_v<char>, char, char>();
test_lex_compare_memcmp_classify_for_1byte_integrals<is_unsigned_v<char> || vec_alg, char, char>();
test_lex_compare_memcmp_classify_for_1byte_integrals<is_unsigned_v<char>, unsigned char, char>();
test_lex_compare_memcmp_classify_for_1byte_integrals<is_unsigned_v<char>, char, unsigned char>();
test_lex_compare_memcmp_classify_for_1byte_integrals<true, unsigned char, unsigned char>();
test_lex_compare_memcmp_classify_for_1byte_integrals<false, signed char, signed char>();
test_lex_compare_memcmp_classify_for_1byte_integrals<false, char, signed char>();
test_lex_compare_memcmp_classify_for_1byte_integrals<false, signed char, char>();
test_lex_compare_memcmp_classify_for_1byte_integrals<vec_alg, signed char, signed char>();
test_lex_compare_memcmp_classify_for_1byte_integrals<vec_alg && is_signed_v<char>, char, signed char>();
test_lex_compare_memcmp_classify_for_1byte_integrals<vec_alg && is_signed_v<char>, signed char, char>();
test_lex_compare_memcmp_classify_for_1byte_integrals<false, unsigned char, signed char>();
test_lex_compare_memcmp_classify_for_1byte_integrals<false, signed char, unsigned char>();
#ifdef __cpp_lib_char8_t
Expand All @@ -252,8 +256,8 @@ void lex_compare_memcmp_classify_test_cases() {
test_lex_compare_memcmp_classify_for_1byte_integrals<true, bool, char8_t>();
test_lex_compare_memcmp_classify_for_1byte_integrals<true, char8_t, bool>();
#endif // __cpp_lib_char8_t
test_lex_compare_memcmp_classify_for_1byte_integrals<false, char, bool>();
test_lex_compare_memcmp_classify_for_1byte_integrals<false, bool, char>();
test_lex_compare_memcmp_classify_for_1byte_integrals<vec_alg && is_unsigned_v<char>, char, bool>();
test_lex_compare_memcmp_classify_for_1byte_integrals<vec_alg && is_unsigned_v<char>, bool, char>();
test_lex_compare_memcmp_classify_for_1byte_integrals<false, signed char, bool>();
test_lex_compare_memcmp_classify_for_1byte_integrals<false, bool, signed char>();

Expand Down Expand Up @@ -281,10 +285,10 @@ void lex_compare_memcmp_classify_test_cases() {
// Don't allow bigger integrals
test_lex_compare_memcmp_classify_for_types<false, unsigned char, int>();
test_lex_compare_memcmp_classify_for_types<false, int, unsigned char>();
test_lex_compare_memcmp_classify_for_types<false, int, int>();
test_lex_compare_memcmp_classify_for_types<false, unsigned int, unsigned int>();
test_lex_compare_memcmp_classify_for_types<false, short, short>();
test_lex_compare_memcmp_classify_for_types<false, unsigned short, unsigned short>();
test_lex_compare_memcmp_classify_for_types<vec_alg, int, int>();
test_lex_compare_memcmp_classify_for_types<vec_alg, unsigned int, unsigned int>();
test_lex_compare_memcmp_classify_for_types<vec_alg, short, short>();
test_lex_compare_memcmp_classify_for_types<vec_alg, unsigned short, unsigned short>();

// Don't allow pointers
test_lex_compare_memcmp_classify_for_types<false, int*, int*>();
Expand All @@ -298,9 +302,11 @@ void lex_compare_memcmp_classify_test_cases() {
test_lex_compare_memcmp_classify_for_pred<less<int>, char8_t, char8_t, _Char_traits_lt<char_traits<char8_t>>>();
#endif // __cpp_lib_char8_t

test_lex_compare_memcmp_classify_for_pred<void, wchar_t, wchar_t, _Char_traits_lt<char_traits<wchar_t>>>();
test_lex_compare_memcmp_classify_for_pred<void, char16_t, char16_t, _Char_traits_lt<char_traits<char16_t>>>();
test_lex_compare_memcmp_classify_for_pred<void, char32_t, char32_t, _Char_traits_lt<char_traits<char32_t>>>();
using vless = conditional_t<vec_alg, less<int>, void>;

test_lex_compare_memcmp_classify_for_pred<vless, wchar_t, wchar_t, _Char_traits_lt<char_traits<wchar_t>>>();
test_lex_compare_memcmp_classify_for_pred<vless, char16_t, char16_t, _Char_traits_lt<char_traits<char16_t>>>();
test_lex_compare_memcmp_classify_for_pred<vless, char32_t, char32_t, _Char_traits_lt<char_traits<char32_t>>>();

// Test different containers
#if _HAS_CXX20
Expand Down
Loading