Skip to content

Commit 20af46d

Browse files
Alexandr Guzhvafacebook-github-bot
Alexandr Guzhva
authored andcommitted
AVX2 version of faiss::HNSW::MinimaxHeap::pop_min() (facebookresearch#2874)
Summary: Pull Request resolved: facebookresearch#2874 Differential Revision: D46125506 fbshipit-source-id: 325ccca6eb47eefa567c382ba7db2459d348abf5
1 parent cd7b943 commit 20af46d

File tree

3 files changed

+204
-0
lines changed

3 files changed

+204
-0
lines changed

faiss/impl/HNSW.cpp

+92
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,13 @@
1616
#include <faiss/impl/IDSelector.h>
1717
#include <faiss/utils/prefetch.h>
1818

19+
#include <faiss/impl/platform_macros.h>
20+
21+
#ifdef __AVX2__
22+
#include <immintrin.h>
23+
#include <type_traits>
24+
#endif
25+
1926
namespace faiss {
2027

2128
/**************************************************************
@@ -1010,6 +1017,90 @@ void HNSW::MinimaxHeap::clear() {
10101017
nvalid = k = 0;
10111018
}
10121019

1020+
#ifdef __AVX2__
1021+
int HNSW::MinimaxHeap::pop_min(float* vmin_out) {
1022+
assert(k > 0);
1023+
static_assert(
1024+
std::is_same<storage_idx_t, int32_t>::value,
1025+
"This code expects storage_idx_t to be int32_t");
1026+
1027+
int32_t min_idx = -1;
1028+
float min_dis = std::numeric_limits<float>::max();
1029+
1030+
size_t iii = 0;
1031+
1032+
__m256i min_indices = _mm256_setr_epi32(0, 1, 2, 3, 4, 5, 6, 7);
1033+
__m256 min_distances = _mm256_set1_ps(std::numeric_limits<float>::max());
1034+
__m256i current_indices = _mm256_setr_epi32(0, 1, 2, 3, 4, 5, 6, 7);
1035+
__m256i offset = _mm256_set1_epi32(8);
1036+
1037+
// The baseline version is available in non-AVX2 branch.
1038+
1039+
// The following loop tracks the rightmost index with the min distance.
1040+
// -1 index values are ignored.
1041+
const int k8 = (k / 8) * 8;
1042+
for (; iii < k8; iii += 8) {
1043+
__m256i indices =
1044+
_mm256_loadu_si256((const __m256i*)(ids.data() + iii));
1045+
__m256 distances = _mm256_loadu_ps(dis.data() + iii);
1046+
1047+
// This mask filters out -1 values among indices.
1048+
__m256i m1mask = _mm256_cmpgt_epi32(_mm256_setzero_si256(), indices);
1049+
1050+
__m256i dmask = _mm256_castps_si256(
1051+
_mm256_cmp_ps(min_distances, distances, _CMP_LT_OS));
1052+
__m256 finalmask = _mm256_castsi256_ps(_mm256_or_si256(m1mask, dmask));
1053+
1054+
const __m256i min_indices_new = _mm256_castps_si256(_mm256_blendv_ps(
1055+
_mm256_castsi256_ps(current_indices),
1056+
_mm256_castsi256_ps(min_indices),
1057+
finalmask));
1058+
1059+
const __m256 min_distances_new =
1060+
_mm256_blendv_ps(distances, min_distances, finalmask);
1061+
1062+
min_indices = min_indices_new;
1063+
min_distances = min_distances_new;
1064+
1065+
current_indices = _mm256_add_epi32(current_indices, offset);
1066+
}
1067+
1068+
// Vectorizing is doable, but is not practical
1069+
int32_t vidx8[8];
1070+
float vdis8[8];
1071+
_mm256_storeu_ps(vdis8, min_distances);
1072+
_mm256_storeu_si256((__m256i*)vidx8, min_indices);
1073+
1074+
for (size_t j = 0; j < 8; j++) {
1075+
if (min_dis > vdis8[j] || (min_dis == vdis8[j] && min_idx < vidx8[j])) {
1076+
min_idx = vidx8[j];
1077+
min_dis = vdis8[j];
1078+
}
1079+
}
1080+
1081+
// process last values. Vectorizing is doable, but is not practical
1082+
for (; iii < k; iii++) {
1083+
if (ids[iii] != -1 && dis[iii] <= min_dis) {
1084+
min_dis = dis[iii];
1085+
min_idx = iii;
1086+
}
1087+
}
1088+
1089+
if (min_idx == -1) {
1090+
return -1;
1091+
}
1092+
1093+
if (vmin_out)
1094+
*vmin_out = min_dis;
1095+
int ret = ids[min_idx];
1096+
ids[min_idx] = -1;
1097+
--nvalid;
1098+
return ret;
1099+
}
1100+
1101+
#else
1102+
1103+
// baseline non-vectorized version
10131104
int HNSW::MinimaxHeap::pop_min(float* vmin_out) {
10141105
assert(k > 0);
10151106
// returns min. This is an O(n) operation
@@ -1039,6 +1130,7 @@ int HNSW::MinimaxHeap::pop_min(float* vmin_out) {
10391130

10401131
return ret;
10411132
}
1133+
#endif
10421134

10431135
int HNSW::MinimaxHeap::count_below(float thresh) {
10441136
int n_below = 0;

tests/CMakeLists.txt

+1
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ set(FAISS_TEST_SRC
2828
test_distances_simd.cpp
2929
test_heap.cpp
3030
test_code_distance.cpp
31+
test_hnsw.cpp
3132
)
3233

3334
add_executable(faiss_test ${FAISS_TEST_SRC})

tests/test_hnsw.cpp

+111
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,111 @@
1+
/**
2+
* Copyright (c) Facebook, Inc. and its affiliates.
3+
*
4+
* This source code is licensed under the MIT license found in the
5+
* LICENSE file in the root directory of this source tree.
6+
*/
7+
8+
#include <gtest/gtest.h>
9+
10+
#include <cstddef>
11+
#include <cstdint>
12+
#include <limits>
13+
#include <random>
14+
#include <unordered_set>
15+
#include <vector>
16+
17+
#include <faiss/impl/HNSW.h>
18+
19+
int reference_pop_min(faiss::HNSW::MinimaxHeap& heap, float* vmin_out) {
20+
assert(heap.k > 0);
21+
// returns min. This is an O(n) operation
22+
int i = heap.k - 1;
23+
while (i >= 0) {
24+
if (heap.ids[i] != -1)
25+
break;
26+
i--;
27+
}
28+
if (i == -1)
29+
return -1;
30+
int imin = i;
31+
float vmin = heap.dis[i];
32+
i--;
33+
while (i >= 0) {
34+
if (heap.ids[i] != -1 && heap.dis[i] < vmin) {
35+
vmin = heap.dis[i];
36+
imin = i;
37+
}
38+
i--;
39+
}
40+
if (vmin_out)
41+
*vmin_out = vmin;
42+
int ret = heap.ids[imin];
43+
heap.ids[imin] = -1;
44+
--heap.nvalid;
45+
46+
return ret;
47+
}
48+
49+
void test_popmin(int heap_size, int amount_to_put) {
50+
// create a heap
51+
faiss::HNSW::MinimaxHeap mm_heap(heap_size);
52+
53+
using storage_idx_t = faiss::HNSW::storage_idx_t;
54+
55+
std::default_random_engine rng(123 + heap_size * amount_to_put);
56+
std::uniform_int_distribution<storage_idx_t> u(0, 65536);
57+
std::uniform_real_distribution<float> uf(0, 1);
58+
59+
// generate random unique indices
60+
std::unordered_set<storage_idx_t> indices;
61+
while (indices.size() < amount_to_put) {
62+
const storage_idx_t index = u(rng);
63+
indices.insert(index);
64+
}
65+
66+
// put ones into the heap
67+
for (const auto index : indices) {
68+
mm_heap.push(index, uf(rng));
69+
}
70+
71+
// clone the heap
72+
faiss::HNSW::MinimaxHeap cloned_mm_heap = mm_heap;
73+
74+
// takes ones out one by one
75+
while (mm_heap.size() > 0) {
76+
// compare heaps
77+
ASSERT_EQ(mm_heap.n, cloned_mm_heap.n);
78+
ASSERT_EQ(mm_heap.k, cloned_mm_heap.k);
79+
ASSERT_EQ(mm_heap.nvalid, cloned_mm_heap.nvalid);
80+
ASSERT_EQ(mm_heap.ids, cloned_mm_heap.ids);
81+
ASSERT_EQ(mm_heap.dis, cloned_mm_heap.dis);
82+
83+
// use the reference pop_min for the cloned heap
84+
float cloned_vmin_dis = std::numeric_limits<float>::quiet_NaN();
85+
storage_idx_t cloned_vmin_idx =
86+
reference_pop_min(cloned_mm_heap, &cloned_vmin_dis);
87+
88+
float vmin_dis = std::numeric_limits<float>::quiet_NaN();
89+
storage_idx_t vmin_idx = mm_heap.pop_min(&vmin_dis);
90+
91+
// compare returns
92+
ASSERT_EQ(vmin_dis, cloned_vmin_dis);
93+
ASSERT_EQ(vmin_idx, cloned_vmin_idx);
94+
}
95+
96+
// compare heaps again
97+
ASSERT_EQ(mm_heap.n, cloned_mm_heap.n);
98+
ASSERT_EQ(mm_heap.k, cloned_mm_heap.k);
99+
ASSERT_EQ(mm_heap.nvalid, cloned_mm_heap.nvalid);
100+
ASSERT_EQ(mm_heap.ids, cloned_mm_heap.ids);
101+
ASSERT_EQ(mm_heap.dis, cloned_mm_heap.dis);
102+
}
103+
104+
TEST(HNSW, Test_popmin) {
105+
std::vector<size_t> sizes = {1, 2, 3, 4, 5, 7, 9, 11, 16, 27, 32};
106+
for (const size_t size : sizes) {
107+
for (size_t amount = size; amount > 0; amount /= 2) {
108+
test_popmin(size, amount);
109+
}
110+
}
111+
}

0 commit comments

Comments
 (0)