Skip to content

Commit e481dc2

Browse files
Made count_true a new format callable (#2075)
1 parent f19fea0 commit e481dc2

File tree

7 files changed

+314
-187
lines changed

7 files changed

+314
-187
lines changed

include/eve/detail/abi.hpp

+19
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,25 @@
3333
# endif
3434
#endif
3535

36+
// Hint to the optimizer that an expression always holds (undefined behavior if it does not)
37+
#if defined(EVE_NO_ASSUME)
38+
# define EVE_ASSUME(...)
39+
#else
40+
# if defined(__clang__)
41+
# define EVE_ASSUME(...) do { __builtin_assume(__VA_ARGS__); } while(0)
42+
# elif defined(__GNUC__)
43+
# if __GNUC__ >= 13
44+
# define EVE_ASSUME(...) do { __attribute__((__assume__(__VA_ARGS__))); } while(0)
45+
# else
46+
# define EVE_ASSUME(...) do { if (!bool{__VA_ARGS__}) __builtin_unreachable(); } while(0)
47+
# endif
48+
# elif defined(_MSC_VER)
49+
# define EVE_ASSUME(...) do { __assume(__VA_ARGS__); } while(0)
50+
# else
51+
# define EVE_ASSUME(...)
52+
# endif
53+
#endif
54+
3655
// Captures math related options and translate to proper setup
3756
#if defined(__FAST_MATH__) && !defined(EVE_FAST_MATH)
3857
# define EVE_FAST_MATH

include/eve/module/core/regular/count_true.hpp

+73-63
Original file line numberDiff line numberDiff line change
@@ -11,69 +11,79 @@
1111

1212
namespace eve
1313
{
14-
//================================================================================================
15-
//! @addtogroup core_reduction
16-
//! @{
17-
//! @var count_true
18-
//! @brief Computes the number of non 0 elements
19-
//!
20-
//! @groupheader{Header file}
21-
//!
22-
//! @code
23-
//! #include <eve/module/core.hpp>
24-
//! @endcode
25-
//!
26-
//! @groupheader{Callable Signatures}
27-
//!
28-
//! @code
29-
//! namespace eve
30-
//! {
31-
//! // Regular overloads
32-
//! constexpr auto count_true(logical_value auto x) noexcept; // 1
33-
//! constexpr auto count_true(top_bits auto t) noexcept; // 1
34-
//!
35-
//! // Lanes masking
36-
//! constexpr auto count_true[conditional_expr auto c](/* any of the above overloads */) noexcept; // 2
37-
//! constexpr auto count_true[logical_value auto m](/* any of the above overloads */) noexcept; // 2
38-
//! }
39-
//! @endcode
40-
//!
41-
//! **Parameters**
42-
//!
43-
//! * `x`: [argument](@ref eve::logical_value).
44-
//! * `t`: [top bits](@ref top_bits).
45-
//! * `c`: [Conditional expression](@ref eve::conditional_expr) masking the operation.
46-
//! * `m`: [Logical value](@ref eve::logical_value) masking the operation.
47-
//!
48-
//! **Return value**
49-
//!
50-
//! 1. The value in the element type of `x` of the number of non 0 elements.
51-
//! 2. A masked version which return the number of true retained elements.
52-
//!
53-
//! **Parameters**
54-
//!
55-
//! * `x` : [argument](@ref eve::value).
56-
//!
57-
//! **Return value**
58-
//!
59-
//! The value of the number of non 0 elements
60-
//! is returned.
61-
//!
62-
//! @groupheader{Example}
63-
//!
64-
//! @godbolt{doc/core/count_true.cpp}
65-
//! @groupheader{Semantic Modifiers}
66-
//!
67-
//! * Masked Call
68-
//!
69-
//! The call `eve::$name$[mask](x, ...)` provides a masked
70-
//! version of `count_true which count the non masked non zero element
71-
//!
72-
//================================================================================================
73-
EVE_MAKE_CALLABLE(count_true_, count_true);
74-
//================================================================================================
75-
//! @}
76-
//================================================================================================
14+
template<typename Options>
15+
struct count_true_t : conditional_callable<count_true_t, Options>
16+
{
17+
template<relaxed_logical_value L>
18+
EVE_FORCEINLINE std::ptrdiff_t operator()(L v) const noexcept
19+
{
20+
static_assert(detail::validate_mask_for<decltype(this->options()), L>(),
21+
"[eve::count_true] - Cannot use a relative conditional expression or a simd value to mask a scalar value");
22+
23+
return EVE_DISPATCH_CALL(v);
24+
}
25+
26+
template<logical_simd_value L>
27+
EVE_FORCEINLINE std::ptrdiff_t operator()(top_bits<L> v) const noexcept
28+
{
29+
return EVE_DISPATCH_CALL(v);
30+
}
31+
32+
EVE_CALLABLE_OBJECT(count_true_t, count_true_);
33+
};
34+
35+
//================================================================================================
36+
//! @addtogroup core_reduction
37+
//! @{
38+
//! @var count_true
39+
//! @brief Computes the number of elements of the input which evaluate to `true`.
40+
//!
41+
//! @groupheader{Header file}
42+
//!
43+
//! @code
44+
//! #include <eve/module/core.hpp>
45+
//! @endcode
46+
//!
47+
//! @groupheader{Callable Signatures}
48+
//!
49+
//! @code
50+
//! namespace eve
51+
//! {
52+
//! // Regular overloads
53+
//! template<relaxed_logical_value L>
54+
//! std::ptrdiff_t count_true(L x) noexcept; // 1
55+
//!
56+
//! template<logical_simd_value L>
57+
//! std::ptrdiff_t count_true(top_bits<L> t) noexcept; // 1
58+
//!
59+
//! // Lanes masking
60+
//! std::ptrdiff_t count_true[conditional_expr auto c](/* any of the above overloads */) noexcept; // 2
61+
//! std::ptrdiff_t count_true[logical_value auto m](/* any of the above overloads */) noexcept; // 2
62+
//! }
63+
//! @endcode
64+
//!
65+
//! **Parameters**
66+
//!
67+
//! * `x`: [argument](@ref eve::logical_value).
68+
//! * `t`: [top bits](@ref top_bits).
69+
//! * `c`: [Conditional expression](@ref eve::conditional_expr) masking the operation.
70+
//! * `m`: [Logical value](@ref eve::logical_value) masking the operation.
71+
//!
72+
//! **Return value**
73+
//!
74+
//! 1. The number of elements in `x` which evaluate to `true`. Scalar values are treated as one element.
75+
//! 2. The masked version, which returns the number of non-masked elements that evaluate to `true`.
76+
//!
77+
//! @groupheader{Example}
78+
//!
79+
//! @godbolt{doc/core/count_true.cpp}
80+
//! @groupheader{Semantic Modifiers}
81+
//!
82+
//================================================================================================
83+
inline constexpr auto count_true = functor<count_true_t>;
84+
//================================================================================================
85+
//! @}
86+
//================================================================================================
7787
}
7888

7989
#include <eve/module/core/regular/impl/count_true.hpp>

include/eve/module/core/regular/impl/count_true.hpp

+85-26
Original file line numberDiff line numberDiff line change
@@ -16,35 +16,94 @@
1616

1717
namespace eve::detail
1818
{
19-
EVE_FORCEINLINE std::ptrdiff_t
20-
count_true_(EVE_SUPPORTS(cpu_), bool v) noexcept
21-
{
22-
return v ? 1 : 0;
23-
}
19+
template<callable_options O>
20+
EVE_FORCEINLINE std::ptrdiff_t count_true_(EVE_REQUIRES(cpu_), O const& opts, bool v) noexcept
21+
{
22+
if constexpr (match_option<condition_key, O, ignore_none_>)
23+
{
24+
return v ? 1 : 0;
25+
}
26+
else
27+
{
28+
return opts[condition_key].mask(as(v)) && v ? 1 : 0;
29+
}
30+
}
2431

25-
template<value T>
26-
EVE_FORCEINLINE std::ptrdiff_t
27-
count_true_(EVE_SUPPORTS(cpu_), logical<T> v) noexcept
28-
{
29-
if constexpr( scalar_value<T> ) return v.value() ? 1 : 0;
30-
else return count_true(eve::top_bits {v});
31-
}
32+
template<callable_options O, value T>
33+
EVE_FORCEINLINE std::ptrdiff_t count_true_(EVE_REQUIRES(cpu_), O const& opts, logical<T> v) noexcept
34+
{
35+
using C = rbr::result::fetch_t<condition_key, O>;
36+
const auto cx = opts[condition_key];
3237

33-
template<simd_value T, relative_conditional_expr C>
34-
EVE_FORCEINLINE std::ptrdiff_t
35-
count_true_(EVE_SUPPORTS(cpu_), C cond, logical<T> v) noexcept
36-
{
37-
return count_true(top_bits {v, cond});
38-
}
38+
constexpr bool relative_nonignore = relative_conditional_expr<C> && !std::same_as<C, ignore_none_>;
3939

40-
template<logical_simd_value Logical>
41-
EVE_FORCEINLINE std::ptrdiff_t
42-
count_true_(EVE_SUPPORTS(cpu_), top_bits<Logical> mmask) noexcept
43-
{
44-
if constexpr( !top_bits<Logical>::is_aggregated )
40+
if constexpr (scalar_value<T>) return count_true[cx](v.value());
41+
else if constexpr (C::is_complete && !C::is_inverted) return 0;
42+
else if constexpr (has_emulated_abi_v<T>)
43+
{
44+
std::ptrdiff_t count = 0;
45+
46+
if constexpr (relative_conditional_expr<C>)
47+
{
48+
const std::ptrdiff_t begin = cx.offset(as(v));
49+
const std::ptrdiff_t end = begin + cx.count(as(v));
50+
constexpr std::ptrdiff_t size = T::size();
51+
52+
EVE_ASSUME((begin >= 0) && (begin <= end) && (end <= size));
53+
54+
for (std::ptrdiff_t i = begin; i < end; ++i)
55+
{
56+
count += v.get(i);
57+
}
58+
}
59+
else
60+
{
61+
auto mask = expand_mask(cx, as(v));
62+
for (std::ptrdiff_t i = 0; i < v.size(); ++i)
63+
{
64+
count += (v.get(i) && mask.get(i));
65+
}
66+
}
67+
68+
return count;
69+
}
70+
else if constexpr (relative_nonignore) return count_true(top_bits{v, cx});
71+
else return count_true[cx](eve::top_bits{v});
72+
}
73+
74+
template<callable_options O, logical_simd_value Logical>
75+
EVE_FORCEINLINE std::ptrdiff_t count_true_(EVE_REQUIRES(cpu_), O const& opts, top_bits<Logical> mmask) noexcept
4576
{
46-
return std::popcount(mmask.as_int()) / top_bits<Logical>::bits_per_element;
77+
if constexpr (match_option<condition_key, O, ignore_none_>)
78+
{
79+
if constexpr (top_bits<Logical>::is_aggregated)
80+
{
81+
return count_true(mmask.storage[0]) + count_true(mmask.storage[1]);
82+
}
83+
else
84+
{
85+
return std::popcount(mmask.as_int()) / top_bits<Logical>::bits_per_element;
86+
}
87+
}
88+
else
89+
{
90+
using C = rbr::result::fetch_t<condition_key, O>;
91+
auto cx = opts[condition_key];
92+
93+
if constexpr (top_bits<Logical>::is_aggregated)
94+
{
95+
auto [cx_l, cx_h] = expand_mask(cx, as<Logical>()).slice();
96+
return count_true[cx_l](mmask.storage[0]) + count_true[cx_h](mmask.storage[1]);
97+
}
98+
else
99+
{
100+
auto vm = mmask.as_int();
101+
102+
if constexpr (relative_conditional_expr<C>) vm &= top_bits<Logical>{cx}.as_int();
103+
else vm &= top_bits{expand_mask(cx, as<Logical>())}.as_int();
104+
105+
return std::popcount(vm) / top_bits<Logical>::bits_per_element;
106+
}
107+
}
47108
}
48-
else { return count_true(mmask.storage[0]) + count_true(mmask.storage[1]); }
49-
}
50109
}

include/eve/module/core/regular/impl/simd/arm/sve/all.hpp

+1-1
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,6 @@ requires sve_abi<abi_t<T, N>>
2323
if constexpr( C::is_inverted ) return count_true(v) == N::value;
2424
else return true;
2525
}
26-
else return count_true(cond, v) == cond.count(as<wide<T,N>>());
26+
else return count_true[cond](v) == cond.count(as<wide<T,N>>());
2727
}
2828
}

include/eve/module/core/regular/impl/simd/arm/sve/count_true.hpp

+10-18
Original file line numberDiff line numberDiff line change
@@ -13,23 +13,15 @@
1313

1414
namespace eve::detail
1515
{
16-
template<scalar_value T, typename N, relative_conditional_expr C>
17-
EVE_FORCEINLINE std::ptrdiff_t
18-
count_true_(EVE_SUPPORTS(cpu_), C cond, logical<wide<T,N>> v) noexcept
19-
requires sve_abi<abi_t<T, N>>
20-
{
21-
auto const m = cond.mask(as<wide<T,N>>{});
22-
if constexpr(sizeof(T) == 1) return svcntp_b8(m,v);
23-
else if constexpr(sizeof(T) == 2) return svcntp_b16(m,v);
24-
else if constexpr(sizeof(T) == 4) return svcntp_b32(m,v);
25-
else if constexpr(sizeof(T) == 8) return svcntp_b64(m,v);
26-
}
16+
template<callable_options O, arithmetic_scalar_value T, typename N>
17+
EVE_FORCEINLINE std::ptrdiff_t count_true_(EVE_REQUIRES(sve_), O const& opts, logical<wide<T,N>> v) noexcept
18+
requires sve_abi<abi_t<T, N>>
19+
{
20+
auto const m = expand_mask(opts[condition_key], as<wide<T, N>>{});
2721

28-
template<scalar_value T, typename N>
29-
EVE_FORCEINLINE std::ptrdiff_t
30-
count_true_(EVE_SUPPORTS(sve_), logical<wide<T,N>> v) noexcept
31-
requires sve_abi<abi_t<T, N>>
32-
{
33-
return count_true[ignore_none](v);
34-
}
22+
if constexpr (sizeof(T) == 1) return svcntp_b8(m, v);
23+
else if constexpr (sizeof(T) == 2) return svcntp_b16(m, v);
24+
else if constexpr (sizeof(T) == 4) return svcntp_b32(m, v);
25+
else if constexpr (sizeof(T) == 8) return svcntp_b64(m, v);
26+
}
3527
}

include/eve/module/core/regular/impl/simd/riscv/count_true.hpp

+20-28
Original file line numberDiff line numberDiff line change
@@ -13,36 +13,28 @@
1313

1414
namespace eve::detail
1515
{
16-
template<scalar_value T, typename N, relative_conditional_expr C>
17-
EVE_FORCEINLINE std::ptrdiff_t
18-
count_true_(EVE_SUPPORTS(cpu_), C cond, logical<wide<T, N>> v) noexcept
19-
requires rvv_abi<abi_t<T, N>>
20-
{
21-
if constexpr( C::is_complete )
16+
template<callable_options O, scalar_value T, typename N>
17+
EVE_FORCEINLINE std::ptrdiff_t count_true_(EVE_REQUIRES(rvv_), O const& opts, logical<wide<T, N>> v) noexcept
18+
requires rvv_abi<abi_t<T, N>>
2219
{
23-
if constexpr( !C::is_inverted ) return 0;
24-
else return __riscv_vcpop(v, N::value);
20+
using C = rbr::result::fetch_t<condition_key, O>;
21+
22+
if constexpr (C::is_complete)
23+
{
24+
if constexpr (!C::is_inverted) return 0;
25+
else return __riscv_vcpop(v, N::value);
26+
}
27+
else
28+
{
29+
const auto m = expand_mask(opts[condition_key], as<wide<T, N>>{});
30+
return __riscv_vcpop(m, v, N::value);
31+
}
2532
}
26-
else
33+
34+
template<callable_options O, scalar_value T, typename N>
35+
EVE_FORCEINLINE std::ptrdiff_t count_true_(EVE_REQUIRES(rvv_), O const& opts, top_bits<logical<wide<T, N>>> v) noexcept
36+
requires rvv_abi<abi_t<T, N>>
2737
{
28-
auto const m = cond.mask(as<wide<T, N>> {});
29-
return __riscv_vcpop(m, v, N::value);
38+
return count_true.behavior(current_api, opts, v.storage);
3039
}
3140
}
32-
33-
template<scalar_value T, typename N>
34-
EVE_FORCEINLINE std::ptrdiff_t
35-
count_true_(EVE_SUPPORTS(rvv_), logical<wide<T, N>> v) noexcept
36-
requires rvv_abi<abi_t<T, N>>
37-
{
38-
return count_true[ignore_none](v);
39-
}
40-
41-
template<scalar_value T, typename N>
42-
EVE_FORCEINLINE std::ptrdiff_t
43-
count_true_(EVE_SUPPORTS(rvv_), top_bits<logical<wide<T, N>>> v) noexcept
44-
requires rvv_abi<abi_t<T, N>>
45-
{
46-
return count_true[ignore_none](v.storage);
47-
}
48-
}

0 commit comments

Comments
 (0)