|
17 | 17 |
|
18 | 18 | namespace {
|
19 | 19 |
|
20 |
| -// Computes a horizontal sum over an __m256 register |
21 |
| -inline float horizontal_sum(const __m256 reg) { |
22 |
| - const __m256 h0 = _mm256_hadd_ps(reg, reg); |
23 |
| - const __m256 h1 = _mm256_hadd_ps(h0, h0); |
24 |
| - |
25 |
| - // extract high and low __m128 regs from __m256 |
26 |
| - const __m128 h2 = _mm256_extractf128_ps(h1, 1); |
27 |
| - const __m128 h3 = _mm256_castps256_ps128(h1); |
28 |
| - |
29 |
| - // get a final hsum into all 4 regs |
30 |
| - const __m128 h4 = _mm_add_ss(h2, h3); |
| 20 | +inline float horizontal_sum(const __m128 v) { |
| 21 | + const __m128 v0 = _mm_shuffle_ps(v, v, _MM_SHUFFLE(0, 0, 3, 2)); |
| 22 | + const __m128 v1 = _mm_add_ps(v, v0); |
| 23 | + __m128 v2 = _mm_shuffle_ps(v1, v1, _MM_SHUFFLE(0, 0, 0, 1)); |
| 24 | + const __m128 v3 = _mm_add_ps(v1, v2); |
| 25 | + return _mm_cvtss_f32(v3); |
| 26 | +} |
31 | 27 |
|
32 |
| - // extract f[0] from __m128 |
33 |
| - const float hsum = _mm_cvtss_f32(h4); |
34 |
| - return hsum; |
| 28 | +// Computes a horizontal sum over an __m256 register |
| 29 | +inline float horizontal_sum(const __m256 v) { |
| 30 | + const __m128 v0 = |
| 31 | + _mm_add_ps(_mm256_castps256_ps128(v), _mm256_extractf128_ps(v, 1)); |
| 32 | + return horizontal_sum(v0); |
35 | 33 | }
|
36 | 34 |
|
37 | 35 | } // namespace
|
|
0 commit comments