Skip to content

Commit 5d6abd6

Browse files
Some small improvements.
Signed-off-by: Alexandr Guzhva <alexanderguzhva@gmail.com>
1 parent f2361a4 commit 5d6abd6

File tree

8 files changed

+517
-126
lines changed

8 files changed

+517
-126
lines changed

faiss/IndexHNSW.cpp

+20-9
Original file line numberDiff line numberDiff line change
@@ -275,23 +275,23 @@ void hnsw_search(
275275
FAISS_THROW_IF_NOT_MSG(params, "params type invalid");
276276
efSearch = params->efSearch;
277277
}
278-
size_t n1 = 0, n2 = 0, ndis = 0;
278+
size_t n1 = 0, n2 = 0, ndis = 0, nhops = 0;
279279

280280
idx_t check_period = InterruptCallback::get_period_hint(
281281
hnsw.max_level * index->d * efSearch);
282282

283283
for (idx_t i0 = 0; i0 < n; i0 += check_period) {
284284
idx_t i1 = std::min(i0 + check_period, n);
285285

286-
#pragma omp parallel
286+
#pragma omp parallel if (i1 - i0 > 1)
287287
{
288288
VisitedTable vt(index->ntotal);
289289
typename BlockResultHandler::SingleResultHandler res(bres);
290290

291291
std::unique_ptr<DistanceComputer> dis(
292292
storage_distance_computer(index->storage));
293293

294-
#pragma omp for reduction(+ : n1, n2, ndis) schedule(guided)
294+
#pragma omp for reduction(+ : n1, n2, ndis, nhops) schedule(guided)
295295
for (idx_t i = i0; i < i1; i++) {
296296
res.begin(i);
297297
dis->set_query(x + i * index->d);
@@ -300,13 +300,14 @@ void hnsw_search(
300300
n1 += stats.n1;
301301
n2 += stats.n2;
302302
ndis += stats.ndis;
303+
nhops += stats.nhops;
303304
res.end();
304305
}
305306
}
306307
InterruptCallback::check();
307308
}
308309

309-
hnsw_stats.combine({n1, n2, ndis});
310+
hnsw_stats.combine({n1, n2, ndis, nhops});
310311
}
311312

312313
} // anonymous namespace
@@ -632,6 +633,10 @@ void IndexHNSW::permute_entries(const idx_t* perm) {
632633
hnsw.permute_entries(perm);
633634
}
634635

636+
DistanceComputer* IndexHNSW::get_distance_computer() const {
637+
return storage->get_distance_computer();
638+
}
639+
635640
/**************************************************************
636641
* IndexHNSWFlat implementation
637642
**************************************************************/
@@ -655,8 +660,13 @@ IndexHNSWFlat::IndexHNSWFlat(int d, int M, MetricType metric)
655660

656661
IndexHNSWPQ::IndexHNSWPQ() = default;
657662

658-
IndexHNSWPQ::IndexHNSWPQ(int d, int pq_m, int M, int pq_nbits)
659-
: IndexHNSW(new IndexPQ(d, pq_m, pq_nbits), M) {
663+
IndexHNSWPQ::IndexHNSWPQ(
664+
int d,
665+
int pq_m,
666+
int M,
667+
int pq_nbits,
668+
MetricType metric)
669+
: IndexHNSW(new IndexPQ(d, pq_m, pq_nbits, metric), M) {
660670
own_fields = true;
661671
is_trained = false;
662672
}
@@ -782,7 +792,7 @@ void IndexHNSW2Level::search(
782792
IndexHNSW::search(n, x, k, distances, labels);
783793

784794
} else { // "mixed" search
785-
size_t n1 = 0, n2 = 0, ndis = 0;
795+
size_t n1 = 0, n2 = 0, ndis = 0, nhops = 0;
786796

787797
const IndexIVFPQ* index_ivfpq =
788798
dynamic_cast<const IndexIVFPQ*>(storage);
@@ -814,7 +824,7 @@ void IndexHNSW2Level::search(
814824
int candidates_size = hnsw.upper_beam;
815825
MinimaxHeap candidates(candidates_size);
816826

817-
#pragma omp for reduction(+ : n1, n2, ndis)
827+
#pragma omp for reduction(+ : n1, n2, ndis, nhops)
818828
for (idx_t i = 0; i < n; i++) {
819829
idx_t* idxi = labels + i * k;
820830
float* simi = distances + i * k;
@@ -860,6 +870,7 @@ void IndexHNSW2Level::search(
860870
n1 += search_stats.n1;
861871
n2 += search_stats.n2;
862872
ndis += search_stats.ndis;
873+
nhops += search_stats.nhops;
863874

864875
vt.advance();
865876
vt.advance();
@@ -868,7 +879,7 @@ void IndexHNSW2Level::search(
868879
}
869880
}
870881

871-
hnsw_stats.combine({n1, n2, ndis});
882+
hnsw_stats.combine({n1, n2, ndis, nhops});
872883
}
873884
}
874885

faiss/IndexHNSW.h

+9-2
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ struct IndexHNSW;
2727
struct IndexHNSW : Index {
2828
typedef HNSW::storage_idx_t storage_idx_t;
2929

30-
// the link strcuture
30+
// the link structure
3131
HNSW hnsw;
3232

3333
// the sequential storage
@@ -111,6 +111,8 @@ struct IndexHNSW : Index {
111111
void link_singletons();
112112

113113
void permute_entries(const idx_t* perm);
114+
115+
DistanceComputer* get_distance_computer() const override;
114116
};
115117

116118
/** Flat index topped with with a HNSW structure to access elements
@@ -127,7 +129,12 @@ struct IndexHNSWFlat : IndexHNSW {
127129
*/
128130
struct IndexHNSWPQ : IndexHNSW {
129131
IndexHNSWPQ();
130-
IndexHNSWPQ(int d, int pq_m, int M, int pq_nbits = 8);
132+
IndexHNSWPQ(
133+
int d,
134+
int pq_m,
135+
int M,
136+
int pq_nbits = 8,
137+
MetricType metric = METRIC_L2);
131138
void train(idx_t n, const float* x) override;
132139
};
133140

faiss/IndexRefine.cpp

+5-5
Original file line numberDiff line numberDiff line change
@@ -68,12 +68,12 @@ template <class C>
6868
static void reorder_2_heaps(
6969
idx_t n,
7070
idx_t k,
71-
idx_t* labels,
72-
float* distances,
71+
idx_t* __restrict labels,
72+
float* __restrict distances,
7373
idx_t k_base,
74-
const idx_t* base_labels,
75-
const float* base_distances) {
76-
#pragma omp parallel for
74+
const idx_t* __restrict base_labels,
75+
const float* __restrict base_distances) {
76+
#pragma omp parallel for if (n > 1)
7777
for (idx_t i = 0; i < n; i++) {
7878
idx_t* idxo = labels + i * k;
7979
float* diso = distances + i * k;

faiss/impl/HNSW.cpp

+125-5
Original file line numberDiff line numberDiff line change
@@ -409,18 +409,22 @@ void search_neighbors_to_add(
409409
**************************************************************/
410410

411411
/// greedily update a nearest vector at a given level
412-
void greedy_update_nearest(
412+
HNSWStats greedy_update_nearest(
413413
const HNSW& hnsw,
414414
DistanceComputer& qdis,
415415
int level,
416416
storage_idx_t& nearest,
417417
float& d_nearest) {
418+
HNSWStats stats;
419+
418420
for (;;) {
419421
storage_idx_t prev_nearest = nearest;
420422

421423
size_t begin, end;
422424
hnsw.neighbor_range(nearest, level, &begin, &end);
423-
for (size_t i = begin; i < end; i++) {
425+
426+
size_t ndis = 0;
427+
for (size_t i = begin; i < end; i++, ndis++) {
424428
storage_idx_t v = hnsw.neighbors[i];
425429
if (v < 0)
426430
break;
@@ -430,8 +434,13 @@ void greedy_update_nearest(
430434
d_nearest = dis;
431435
}
432436
}
437+
438+
// update stats
439+
stats.ndis += ndis;
440+
stats.nhops += 1;
441+
433442
if (nearest == prev_nearest) {
434-
return;
443+
return stats;
435444
}
436445
}
437446
}
@@ -641,6 +650,7 @@ int search_from_candidates(
641650
if (dis < threshold) {
642651
if (res.add_result(dis, idx)) {
643652
threshold = res.threshold;
653+
nres += 1;
644654
}
645655
}
646656
}
@@ -692,6 +702,7 @@ int search_from_candidates(
692702
stats.n2++;
693703
}
694704
stats.ndis += ndis;
705+
stats.nhops += nstep;
695706
}
696707

697708
return nres;
@@ -814,6 +825,8 @@ std::priority_queue<HNSW::Node> search_from_candidate_unbounded(
814825
float dis = qdis(saved_j[icnt]);
815826
add_to_heap(saved_j[icnt], dis);
816827
}
828+
829+
stats.nhops += 1;
817830
}
818831

819832
++stats.n1;
@@ -853,7 +866,9 @@ HNSWStats HNSW::search(
853866
float d_nearest = qdis(nearest);
854867

855868
for (int level = max_level; level >= 1; level--) {
856-
greedy_update_nearest(*this, qdis, level, nearest, d_nearest);
869+
HNSWStats local_stats = greedy_update_nearest(
870+
*this, qdis, level, nearest, d_nearest);
871+
stats.combine(local_stats);
857872
}
858873

859874
int ef = std::max(params ? params->efSearch : efSearch, k);
@@ -916,11 +931,23 @@ HNSWStats HNSW::search(
916931
if (level == 0) {
917932
nres = search_from_candidates(
918933
*this, qdis, res, candidates, vt, stats, 0);
934+
nres = std::min(nres, candidates_size);
919935
} else {
936+
const auto nres_prev = nres;
937+
920938
resh.begin(0);
921939
nres = search_from_candidates(
922940
*this, qdis, resh, candidates, vt, stats, level);
941+
nres = std::min(nres, candidates_size);
923942
resh.end();
943+
944+
// if the search on a particular level produces no improvements,
945+
// then we need to repopulate candidates.
946+
// search_from_candidates() will always damage candidates
947+
// by doing 1 pop_min().
948+
if (nres == 0) {
949+
nres = nres_prev;
950+
}
924951
}
925952
vt.advance();
926953
}
@@ -970,6 +997,7 @@ void HNSW::search_level_0(
970997
0,
971998
nres,
972999
params);
1000+
nres = std::min(nres, candidates_size);
9731001
}
9741002
} else if (search_type == 2) {
9751003
int candidates_size = std::max(efSearch, int(k));
@@ -1051,7 +1079,99 @@ void HNSW::MinimaxHeap::clear() {
10511079
nvalid = k = 0;
10521080
}
10531081

1054-
#ifdef __AVX2__
1082+
#ifdef __AVX512F__
1083+
1084+
int HNSW::MinimaxHeap::pop_min(float* vmin_out) {
1085+
assert(k > 0);
1086+
static_assert(
1087+
std::is_same<storage_idx_t, int32_t>::value,
1088+
"This code expects storage_idx_t to be int32_t");
1089+
1090+
int32_t min_idx = -1;
1091+
float min_dis = std::numeric_limits<float>::infinity();
1092+
1093+
__m512i min_indices = _mm512_set1_epi32(-1);
1094+
__m512 min_distances =
1095+
_mm512_set1_ps(std::numeric_limits<float>::infinity());
1096+
__m512i current_indices = _mm512_setr_epi32(
1097+
0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15);
1098+
__m512i offset = _mm512_set1_epi32(16);
1099+
1100+
// The following loop tracks the rightmost index with the min distance.
1101+
// -1 index values are ignored.
1102+
const int k16 = (k / 16) * 16;
1103+
for (size_t iii = 0; iii < k16; iii += 16) {
1104+
__m512i indices =
1105+
_mm512_loadu_si512((const __m512i*)(ids.data() + iii));
1106+
__m512 distances = _mm512_loadu_ps(dis.data() + iii);
1107+
1108+
// This mask filters out -1 values among indices.
1109+
__mmask16 m1mask =
1110+
_mm512_cmpgt_epi32_mask(_mm512_setzero_si512(), indices);
1111+
1112+
__mmask16 dmask =
1113+
_mm512_cmp_ps_mask(min_distances, distances, _CMP_LT_OS);
1114+
__mmask16 finalmask = m1mask | dmask;
1115+
1116+
const __m512i min_indices_new = _mm512_mask_blend_epi32(
1117+
finalmask, current_indices, min_indices);
1118+
const __m512 min_distances_new =
1119+
_mm512_mask_blend_ps(finalmask, distances, min_distances);
1120+
1121+
min_indices = min_indices_new;
1122+
min_distances = min_distances_new;
1123+
1124+
current_indices = _mm512_add_epi32(current_indices, offset);
1125+
}
1126+
1127+
// leftovers
1128+
if (k16 != k) {
1129+
const __mmask16 kmask = (1 << (k - k16)) - 1;
1130+
1131+
__m512i indices = _mm512_mask_loadu_epi32(
1132+
_mm512_set1_epi32(-1), kmask, ids.data() + k16);
1133+
__m512 distances = _mm512_maskz_loadu_ps(kmask, dis.data() + k16);
1134+
1135+
// This mask filters out -1 values among indices.
1136+
__mmask16 m1mask =
1137+
_mm512_cmpgt_epi32_mask(_mm512_setzero_si512(), indices);
1138+
1139+
__mmask16 dmask =
1140+
_mm512_cmp_ps_mask(min_distances, distances, _CMP_LT_OS);
1141+
__mmask16 finalmask = m1mask | dmask;
1142+
1143+
const __m512i min_indices_new = _mm512_mask_blend_epi32(
1144+
finalmask, current_indices, min_indices);
1145+
const __m512 min_distances_new =
1146+
_mm512_mask_blend_ps(finalmask, distances, min_distances);
1147+
1148+
min_indices = min_indices_new;
1149+
min_distances = min_distances_new;
1150+
}
1151+
1152+
// grab min distance
1153+
min_dis = _mm512_reduce_min_ps(min_distances);
1154+
// blend
1155+
__mmask16 mindmask =
1156+
_mm512_cmpeq_ps_mask(min_distances, _mm512_set1_ps(min_dis));
1157+
// pick the max one
1158+
min_idx = _mm512_mask_reduce_max_epi32(mindmask, min_indices);
1159+
1160+
if (min_idx == -1) {
1161+
return -1;
1162+
}
1163+
1164+
if (vmin_out) {
1165+
*vmin_out = min_dis;
1166+
}
1167+
int ret = ids[min_idx];
1168+
ids[min_idx] = -1;
1169+
--nvalid;
1170+
return ret;
1171+
}
1172+
1173+
#elif __AVX2__
1174+
10551175
int HNSW::MinimaxHeap::pop_min(float* vmin_out) {
10561176
assert(k > 0);
10571177
static_assert(

faiss/impl/HNSW.h

+6-3
Original file line numberDiff line numberDiff line change
@@ -234,20 +234,23 @@ struct HNSW {
234234
};
235235

236236
struct HNSWStats {
237-
size_t n1 = 0; /// numbner of vectors searched
237+
size_t n1 = 0; /// number of vectors searched
238238
size_t n2 =
239-
0; /// number of queries for which the candidate list is exhasted
240-
size_t ndis = 0; /// number of distances computed
239+
0; /// number of queries for which the candidate list is exhausted
240+
size_t ndis = 0; /// number of distances computed
241+
size_t nhops = 0; /// number of hops aka number of edges traversed
241242

242243
void reset() {
243244
n1 = n2 = 0;
244245
ndis = 0;
246+
nhops = 0;
245247
}
246248

247249
void combine(const HNSWStats& other) {
248250
n1 += other.n1;
249251
n2 += other.n2;
250252
ndis += other.ndis;
253+
nhops += other.nhops;
251254
}
252255
};
253256

0 commit comments

Comments
 (0)