Skip to content

Commit e8b7575

Browse files
Alexandr Guzhva, facebook-github-bot
Alexandr Guzhva
authored and committed
AVX2 version of faiss::HNSW::MinimaxHeap::pop_min() (#2874)
Summary: Pull Request resolved: #2874 Reviewed By: mdouze Differential Revision: D46125506 fbshipit-source-id: 4099e5c95bfb168b2097a42f5308c4bea1f72ca8
1 parent 6800ebe commit e8b7575

File tree

3 files changed

+295
-3
lines changed

3 files changed

+295
-3
lines changed

faiss/impl/HNSW.cpp

+102-3
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,15 @@
1616
#include <faiss/impl/IDSelector.h>
1717
#include <faiss/utils/prefetch.h>
1818

19+
#include <faiss/impl/platform_macros.h>
20+
21+
#ifdef __AVX2__
22+
#include <immintrin.h>
23+
24+
#include <limits>
25+
#include <type_traits>
26+
#endif
27+
1928
namespace faiss {
2029

2130
/**************************************************************
@@ -1010,17 +1019,105 @@ void HNSW::MinimaxHeap::clear() {
10101019
nvalid = k = 0;
10111020
}
10121021

1022+
#ifdef __AVX2__
1023+
int HNSW::MinimaxHeap::pop_min(float* vmin_out) {
1024+
assert(k > 0);
1025+
static_assert(
1026+
std::is_same<storage_idx_t, int32_t>::value,
1027+
"This code expects storage_idx_t to be int32_t");
1028+
1029+
int32_t min_idx = -1;
1030+
float min_dis = std::numeric_limits<float>::infinity();
1031+
1032+
size_t iii = 0;
1033+
1034+
__m256i min_indices = _mm256_setr_epi32(-1, -1, -1, -1, -1, -1, -1, -1);
1035+
__m256 min_distances =
1036+
_mm256_set1_ps(std::numeric_limits<float>::infinity());
1037+
__m256i current_indices = _mm256_setr_epi32(0, 1, 2, 3, 4, 5, 6, 7);
1038+
__m256i offset = _mm256_set1_epi32(8);
1039+
1040+
// The baseline version is available in non-AVX2 branch.
1041+
1042+
// The following loop tracks the rightmost index with the min distance.
1043+
// -1 index values are ignored.
1044+
const int k8 = (k / 8) * 8;
1045+
for (; iii < k8; iii += 8) {
1046+
__m256i indices =
1047+
_mm256_loadu_si256((const __m256i*)(ids.data() + iii));
1048+
__m256 distances = _mm256_loadu_ps(dis.data() + iii);
1049+
1050+
// This mask filters out -1 values among indices.
1051+
__m256i m1mask = _mm256_cmpgt_epi32(_mm256_setzero_si256(), indices);
1052+
1053+
__m256i dmask = _mm256_castps_si256(
1054+
_mm256_cmp_ps(min_distances, distances, _CMP_LT_OS));
1055+
__m256 finalmask = _mm256_castsi256_ps(_mm256_or_si256(m1mask, dmask));
1056+
1057+
const __m256i min_indices_new = _mm256_castps_si256(_mm256_blendv_ps(
1058+
_mm256_castsi256_ps(current_indices),
1059+
_mm256_castsi256_ps(min_indices),
1060+
finalmask));
1061+
1062+
const __m256 min_distances_new =
1063+
_mm256_blendv_ps(distances, min_distances, finalmask);
1064+
1065+
min_indices = min_indices_new;
1066+
min_distances = min_distances_new;
1067+
1068+
current_indices = _mm256_add_epi32(current_indices, offset);
1069+
}
1070+
1071+
// Vectorizing is doable, but is not practical
1072+
int32_t vidx8[8];
1073+
float vdis8[8];
1074+
_mm256_storeu_ps(vdis8, min_distances);
1075+
_mm256_storeu_si256((__m256i*)vidx8, min_indices);
1076+
1077+
for (size_t j = 0; j < 8; j++) {
1078+
if (min_dis > vdis8[j] || (min_dis == vdis8[j] && min_idx < vidx8[j])) {
1079+
min_idx = vidx8[j];
1080+
min_dis = vdis8[j];
1081+
}
1082+
}
1083+
1084+
// process last values. Vectorizing is doable, but is not practical
1085+
for (; iii < k; iii++) {
1086+
if (ids[iii] != -1 && dis[iii] <= min_dis) {
1087+
min_dis = dis[iii];
1088+
min_idx = iii;
1089+
}
1090+
}
1091+
1092+
if (min_idx == -1) {
1093+
return -1;
1094+
}
1095+
1096+
if (vmin_out) {
1097+
*vmin_out = min_dis;
1098+
}
1099+
int ret = ids[min_idx];
1100+
ids[min_idx] = -1;
1101+
--nvalid;
1102+
return ret;
1103+
}
1104+
1105+
#else
1106+
1107+
// baseline non-vectorized version
10131108
int HNSW::MinimaxHeap::pop_min(float* vmin_out) {
10141109
assert(k > 0);
10151110
// returns min. This is an O(n) operation
10161111
int i = k - 1;
10171112
while (i >= 0) {
1018-
if (ids[i] != -1)
1113+
if (ids[i] != -1) {
10191114
break;
1115+
}
10201116
i--;
10211117
}
1022-
if (i == -1)
1118+
if (i == -1) {
10231119
return -1;
1120+
}
10241121
int imin = i;
10251122
float vmin = dis[i];
10261123
i--;
@@ -1031,14 +1128,16 @@ int HNSW::MinimaxHeap::pop_min(float* vmin_out) {
10311128
}
10321129
i--;
10331130
}
1034-
if (vmin_out)
1131+
if (vmin_out) {
10351132
*vmin_out = vmin;
1133+
}
10361134
int ret = ids[imin];
10371135
ids[imin] = -1;
10381136
--nvalid;
10391137

10401138
return ret;
10411139
}
1140+
#endif
10421141

10431142
int HNSW::MinimaxHeap::count_below(float thresh) {
10441143
int n_below = 0;

tests/CMakeLists.txt

+1
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ set(FAISS_TEST_SRC
2828
test_distances_simd.cpp
2929
test_heap.cpp
3030
test_code_distance.cpp
31+
test_hnsw.cpp
3132
)
3233

3334
add_executable(faiss_test ${FAISS_TEST_SRC})

tests/test_hnsw.cpp

+192
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,192 @@
1+
/**
2+
* Copyright (c) Facebook, Inc. and its affiliates.
3+
*
4+
* This source code is licensed under the MIT license found in the
5+
* LICENSE file in the root directory of this source tree.
6+
*/
7+
8+
#include <gtest/gtest.h>
9+
10+
#include <cstddef>
11+
#include <cstdint>
12+
#include <limits>
13+
#include <random>
14+
#include <unordered_set>
15+
#include <vector>
16+
17+
#include <faiss/impl/HNSW.h>
18+
19+
/// Scalar reference implementation of MinimaxHeap::pop_min(), used to
/// cross-check the library version.
///
/// Walks the first `heap.k` slots from the back. The last valid slot
/// (id != -1) is the initial candidate and is only replaced by a strictly
/// smaller distance, so ties resolve to the highest slot index.
///
/// @param heap      heap to pop from; the chosen slot is set to -1 and
///                  `nvalid` is decremented
/// @param vmin_out  if non-null, receives the distance of the popped entry
/// @return the popped id, or -1 if the heap holds no valid entry
int reference_pop_min(faiss::HNSW::MinimaxHeap& heap, float* vmin_out) {
    assert(heap.k > 0);

    // locate the last valid slot; this is an O(n) operation overall
    int best_pos = heap.k - 1;
    while (best_pos >= 0 && heap.ids[best_pos] == -1) {
        --best_pos;
    }
    if (best_pos < 0) {
        return -1;
    }

    // scan the remaining slots towards the front for a strictly smaller one
    float best_dis = heap.dis[best_pos];
    for (int pos = best_pos - 1; pos >= 0; --pos) {
        if (heap.ids[pos] == -1) {
            continue;
        }
        if (heap.dis[pos] < best_dis) {
            best_dis = heap.dis[pos];
            best_pos = pos;
        }
    }

    if (vmin_out != nullptr) {
        *vmin_out = best_dis;
    }

    // remove the entry from the heap
    const int popped_id = heap.ids[best_pos];
    heap.ids[best_pos] = -1;
    --heap.nvalid;

    return popped_id;
}
48+
49+
void test_popmin(int heap_size, int amount_to_put) {
50+
// create a heap
51+
faiss::HNSW::MinimaxHeap mm_heap(heap_size);
52+
53+
using storage_idx_t = faiss::HNSW::storage_idx_t;
54+
55+
std::default_random_engine rng(123 + heap_size * amount_to_put);
56+
std::uniform_int_distribution<storage_idx_t> u(0, 65536);
57+
std::uniform_real_distribution<float> uf(0, 1);
58+
59+
// generate random unique indices
60+
std::unordered_set<storage_idx_t> indices;
61+
while (indices.size() < amount_to_put) {
62+
const storage_idx_t index = u(rng);
63+
indices.insert(index);
64+
}
65+
66+
// put ones into the heap
67+
for (const auto index : indices) {
68+
float distance = uf(rng);
69+
if (distance >= 0.7f) {
70+
// add infinity values from time to time
71+
distance = std::numeric_limits<float>::infinity();
72+
}
73+
mm_heap.push(index, distance);
74+
}
75+
76+
// clone the heap
77+
faiss::HNSW::MinimaxHeap cloned_mm_heap = mm_heap;
78+
79+
// takes ones out one by one
80+
while (mm_heap.size() > 0) {
81+
// compare heaps
82+
ASSERT_EQ(mm_heap.n, cloned_mm_heap.n);
83+
ASSERT_EQ(mm_heap.k, cloned_mm_heap.k);
84+
ASSERT_EQ(mm_heap.nvalid, cloned_mm_heap.nvalid);
85+
ASSERT_EQ(mm_heap.ids, cloned_mm_heap.ids);
86+
ASSERT_EQ(mm_heap.dis, cloned_mm_heap.dis);
87+
88+
// use the reference pop_min for the cloned heap
89+
float cloned_vmin_dis = std::numeric_limits<float>::quiet_NaN();
90+
storage_idx_t cloned_vmin_idx =
91+
reference_pop_min(cloned_mm_heap, &cloned_vmin_dis);
92+
93+
float vmin_dis = std::numeric_limits<float>::quiet_NaN();
94+
storage_idx_t vmin_idx = mm_heap.pop_min(&vmin_dis);
95+
96+
// compare returns
97+
ASSERT_EQ(vmin_dis, cloned_vmin_dis);
98+
ASSERT_EQ(vmin_idx, cloned_vmin_idx);
99+
}
100+
101+
// compare heaps again
102+
ASSERT_EQ(mm_heap.n, cloned_mm_heap.n);
103+
ASSERT_EQ(mm_heap.k, cloned_mm_heap.k);
104+
ASSERT_EQ(mm_heap.nvalid, cloned_mm_heap.nvalid);
105+
ASSERT_EQ(mm_heap.ids, cloned_mm_heap.ids);
106+
ASSERT_EQ(mm_heap.dis, cloned_mm_heap.dis);
107+
}
108+
109+
void test_popmin_identical_distances(
110+
int heap_size,
111+
int amount_to_put,
112+
const float distance) {
113+
// create a heap
114+
faiss::HNSW::MinimaxHeap mm_heap(heap_size);
115+
116+
using storage_idx_t = faiss::HNSW::storage_idx_t;
117+
118+
std::default_random_engine rng(123 + heap_size * amount_to_put);
119+
std::uniform_int_distribution<storage_idx_t> u(0, 65536);
120+
121+
// generate random unique indices
122+
std::unordered_set<storage_idx_t> indices;
123+
while (indices.size() < amount_to_put) {
124+
const storage_idx_t index = u(rng);
125+
indices.insert(index);
126+
}
127+
128+
// put ones into the heap
129+
for (const auto index : indices) {
130+
mm_heap.push(index, distance);
131+
}
132+
133+
// clone the heap
134+
faiss::HNSW::MinimaxHeap cloned_mm_heap = mm_heap;
135+
136+
// takes ones out one by one
137+
while (mm_heap.size() > 0) {
138+
// compare heaps
139+
ASSERT_EQ(mm_heap.n, cloned_mm_heap.n);
140+
ASSERT_EQ(mm_heap.k, cloned_mm_heap.k);
141+
ASSERT_EQ(mm_heap.nvalid, cloned_mm_heap.nvalid);
142+
ASSERT_EQ(mm_heap.ids, cloned_mm_heap.ids);
143+
ASSERT_EQ(mm_heap.dis, cloned_mm_heap.dis);
144+
145+
// use the reference pop_min for the cloned heap
146+
float cloned_vmin_dis = std::numeric_limits<float>::quiet_NaN();
147+
storage_idx_t cloned_vmin_idx =
148+
reference_pop_min(cloned_mm_heap, &cloned_vmin_dis);
149+
150+
float vmin_dis = std::numeric_limits<float>::quiet_NaN();
151+
storage_idx_t vmin_idx = mm_heap.pop_min(&vmin_dis);
152+
153+
// compare returns
154+
ASSERT_EQ(vmin_dis, cloned_vmin_dis);
155+
ASSERT_EQ(vmin_idx, cloned_vmin_idx);
156+
}
157+
158+
// compare heaps again
159+
ASSERT_EQ(mm_heap.n, cloned_mm_heap.n);
160+
ASSERT_EQ(mm_heap.k, cloned_mm_heap.k);
161+
ASSERT_EQ(mm_heap.nvalid, cloned_mm_heap.nvalid);
162+
ASSERT_EQ(mm_heap.ids, cloned_mm_heap.ids);
163+
ASSERT_EQ(mm_heap.dis, cloned_mm_heap.dis);
164+
}
165+
166+
// Runs test_popmin() over a range of heap capacities. The list mixes
// capacities below 8, multiples of 8 and values in between.
TEST(HNSW, Test_popmin) {
    const std::vector<size_t> heap_sizes = {
            1, 2, 3, 4, 5, 7, 9, 11, 16, 27, 32, 64, 128};
    for (const size_t heap_size : heap_sizes) {
        // fill levels: completely full, then repeatedly halved down to 1
        size_t fill = heap_size;
        while (fill > 0) {
            test_popmin(heap_size, fill);
            fill /= 2;
        }
    }
}
174+
175+
// Runs test_popmin_identical_distances() with every entry at distance 1.0,
// over a range of heap capacities and fill levels.
TEST(HNSW, Test_popmin_identical_distances) {
    const std::vector<size_t> heap_sizes = {1, 2, 3, 4, 5, 7, 9, 11, 16, 27, 32};
    for (const size_t heap_size : heap_sizes) {
        // fill levels: completely full, then repeatedly halved down to 1
        size_t fill = heap_size;
        while (fill > 0) {
            test_popmin_identical_distances(heap_size, fill, 1.0f);
            fill /= 2;
        }
    }
}
183+
184+
// Runs test_popmin_identical_distances() with every entry at +infinity,
// over a range of heap capacities and fill levels.
TEST(HNSW, Test_popmin_infinite_distances) {
    const std::vector<size_t> heap_sizes = {1, 2, 3, 4, 5, 7, 9, 11, 16, 27, 32};
    const float inf = std::numeric_limits<float>::infinity();
    for (const size_t heap_size : heap_sizes) {
        // fill levels: completely full, then repeatedly halved down to 1
        size_t fill = heap_size;
        while (fill > 0) {
            test_popmin_identical_distances(heap_size, fill, inf);
            fill /= 2;
        }
    }
}

0 commit comments

Comments
 (0)