Skip to content

Commit b8c3a97

Browse files
mengdilinfacebook-github-bot
authored andcommitted
add hnsw unit test for PR 3840
Summary: #3845 Add unit tests for the HNSW helper search utilities. These utility functions live inside an anonymous namespace, and each has a reference version gated behind a `const bool`. I refactored them so that the choice of the reference version is a function parameter that defaults to false. If we are concerned about the performance overhead of the extra `if` branch (whether or not to use the reference version) inside these utility functions, I'm happy to lift the reference versions out into their own functions inside the unit test. Differential Revision: D62510014
1 parent d85fda7 commit b8c3a97

File tree

3 files changed

+298
-120
lines changed

3 files changed

+298
-120
lines changed

faiss/impl/HNSW.cpp

+114-118
Original file line numberDiff line numberDiff line change
@@ -470,105 +470,6 @@ void search_neighbors_to_add(
470470
vt.advance();
471471
}
472472

473-
/**************************************************************
474-
* Searching subroutines
475-
**************************************************************/
476-
477-
/// greedily update a nearest vector at a given level
478-
HNSWStats greedy_update_nearest(
479-
const HNSW& hnsw,
480-
DistanceComputer& qdis,
481-
int level,
482-
storage_idx_t& nearest,
483-
float& d_nearest) {
484-
// selects a version
485-
const bool reference_version = false;
486-
487-
HNSWStats stats;
488-
489-
for (;;) {
490-
storage_idx_t prev_nearest = nearest;
491-
492-
size_t begin, end;
493-
hnsw.neighbor_range(nearest, level, &begin, &end);
494-
495-
size_t ndis = 0;
496-
497-
// select a version, based on a flag
498-
if (reference_version) {
499-
// a reference version
500-
for (size_t i = begin; i < end; i++) {
501-
storage_idx_t v = hnsw.neighbors[i];
502-
if (v < 0)
503-
break;
504-
ndis += 1;
505-
float dis = qdis(v);
506-
if (dis < d_nearest) {
507-
nearest = v;
508-
d_nearest = dis;
509-
}
510-
}
511-
} else {
512-
// a faster version
513-
514-
// the following version processes 4 neighbors at a time
515-
auto update_with_candidate = [&](const storage_idx_t idx,
516-
const float dis) {
517-
if (dis < d_nearest) {
518-
nearest = idx;
519-
d_nearest = dis;
520-
}
521-
};
522-
523-
int n_buffered = 0;
524-
storage_idx_t buffered_ids[4];
525-
526-
for (size_t j = begin; j < end; j++) {
527-
storage_idx_t v = hnsw.neighbors[j];
528-
if (v < 0)
529-
break;
530-
ndis += 1;
531-
532-
buffered_ids[n_buffered] = v;
533-
n_buffered += 1;
534-
535-
if (n_buffered == 4) {
536-
float dis[4];
537-
qdis.distances_batch_4(
538-
buffered_ids[0],
539-
buffered_ids[1],
540-
buffered_ids[2],
541-
buffered_ids[3],
542-
dis[0],
543-
dis[1],
544-
dis[2],
545-
dis[3]);
546-
547-
for (size_t id4 = 0; id4 < 4; id4++) {
548-
update_with_candidate(buffered_ids[id4], dis[id4]);
549-
}
550-
551-
n_buffered = 0;
552-
}
553-
}
554-
555-
// process leftovers
556-
for (size_t icnt = 0; icnt < n_buffered; icnt++) {
557-
float dis = qdis(buffered_ids[icnt]);
558-
update_with_candidate(buffered_ids[icnt], dis);
559-
}
560-
}
561-
562-
// update stats
563-
stats.ndis += ndis;
564-
stats.nhops += 1;
565-
566-
if (nearest == prev_nearest) {
567-
return stats;
568-
}
569-
}
570-
}
571-
572473
} // namespace
573474

574475
/// Finds neighbors and builds links with them, starting from an entry
@@ -644,7 +545,8 @@ void HNSW::add_with_locks(
644545
float d_nearest = ptdis(nearest);
645546

646547
for (; level > pt_level; level--) {
647-
greedy_update_nearest(*this, ptdis, level, nearest, d_nearest);
548+
hnsw_utils::greedy_update_nearest(
549+
*this, ptdis, level, nearest, d_nearest);
648550
}
649551

650552
for (; level >= 0; level--) {
@@ -667,16 +569,16 @@ void HNSW::add_with_locks(
667569
}
668570
}
669571

572+
namespace hnsw_utils {
573+
670574
/**************************************************************
671575
* Searching
672576
**************************************************************/
673577

674-
namespace {
675578
using MinimaxHeap = HNSW::MinimaxHeap;
676579
using Node = HNSW::Node;
677580
using C = HNSW::C;
678581
/** Do a BFS on the candidates list */
679-
680582
int search_from_candidates(
681583
const HNSW& hnsw,
682584
DistanceComputer& qdis,
@@ -685,11 +587,9 @@ int search_from_candidates(
685587
VisitedTable& vt,
686588
HNSWStats& stats,
687589
int level,
688-
int nres_in = 0,
689-
const SearchParametersHNSW* params = nullptr) {
690-
// selects a version
691-
const bool reference_version = false;
692-
590+
int nres_in,
591+
const SearchParametersHNSW* params,
592+
const bool reference_version) {
693593
int nres = nres_in;
694594
int ndis = 0;
695595

@@ -851,10 +751,8 @@ std::priority_queue<HNSW::Node> search_from_candidate_unbounded(
851751
DistanceComputer& qdis,
852752
int ef,
853753
VisitedTable* vt,
854-
HNSWStats& stats) {
855-
// selects a version
856-
const bool reference_version = false;
857-
754+
HNSWStats& stats,
755+
const bool reference_version) {
858756
int ndis = 0;
859757
std::priority_queue<Node> top_candidates;
860758
std::priority_queue<Node, std::vector<Node>, std::greater<Node>> candidates;
@@ -984,6 +882,104 @@ std::priority_queue<HNSW::Node> search_from_candidate_unbounded(
984882
return top_candidates;
985883
}
986884

885+
/// Greedily update a nearest vector at a given level.
///
/// Starting from `nearest`, repeatedly scans the neighbor list of the current
/// nearest node at `level` and moves to any neighbor whose distance is
/// strictly smaller than `d_nearest`. Stops after the first full pass that
/// leaves `nearest` unchanged.
///
/// @param hnsw               graph whose `neighbors` array and
///                           `neighbor_range()` are read
/// @param qdis               distance computer, called per id (`qdis(v)`) or
///                           through `distances_batch_4`
/// @param level              level passed to `neighbor_range()`
/// @param nearest            in/out: id of the current nearest node
/// @param d_nearest          in/out: distance of `nearest`
/// @param reference_version  true: one `qdis(v)` call per neighbor;
///                           false: neighbors are evaluated in batches of 4
/// @return stats with `ndis` increased by the number of neighbors evaluated
///         and `nhops` increased by one per pass over a neighbor list
886+
HNSWStats greedy_update_nearest(
887+
const HNSW& hnsw,
888+
DistanceComputer& qdis,
889+
int level,
890+
storage_idx_t& nearest,
891+
float& d_nearest,
892+
const bool reference_version) {
893+
HNSWStats stats;
894+
895+
for (;;) {
896+
// remembered so that the end of the pass can detect "no improvement"
storage_idx_t prev_nearest = nearest;
897+
898+
size_t begin, end;
899+
hnsw.neighbor_range(nearest, level, &begin, &end);
900+
901+
size_t ndis = 0;
902+
903+
// select a version, based on the reference_version flag
904+
if (reference_version) {
905+
// a reference version: one distance computation per neighbor
906+
for (size_t i = begin; i < end; i++) {
907+
storage_idx_t v = hnsw.neighbors[i];
908+
// a negative id ends the scan of this neighbor list
// (presumably it marks unused slots -- confirm)
if (v < 0)
909+
break;
910+
ndis += 1;
911+
float dis = qdis(v);
912+
if (dis < d_nearest) {
913+
nearest = v;
914+
d_nearest = dis;
915+
}
916+
}
917+
} else {
918+
// a faster version
919+
920+
// the following version processes 4 neighbors at a time;
// candidates are still applied in neighbor-list order
921+
auto update_with_candidate = [&](const storage_idx_t idx,
922+
const float dis) {
923+
if (dis < d_nearest) {
924+
nearest = idx;
925+
d_nearest = dis;
926+
}
927+
};
928+
929+
int n_buffered = 0;
930+
storage_idx_t buffered_ids[4];
931+
932+
for (size_t j = begin; j < end; j++) {
933+
storage_idx_t v = hnsw.neighbors[j];
934+
// a negative id ends the scan of this neighbor list
if (v < 0)
935+
break;
936+
// counted when buffered; the distance itself is computed below,
// either in a batch of 4 or in the leftover loop
ndis += 1;
937+
938+
buffered_ids[n_buffered] = v;
939+
n_buffered += 1;
940+
941+
// buffer is full: evaluate the 4 ids with a single batched call
if (n_buffered == 4) {
942+
float dis[4];
943+
qdis.distances_batch_4(
944+
buffered_ids[0],
945+
buffered_ids[1],
946+
buffered_ids[2],
947+
buffered_ids[3],
948+
dis[0],
949+
dis[1],
950+
dis[2],
951+
dis[3]);
952+
953+
for (size_t id4 = 0; id4 < 4; id4++) {
954+
update_with_candidate(buffered_ids[id4], dis[id4]);
955+
}
956+
957+
n_buffered = 0;
958+
}
959+
}
960+
961+
// process leftovers: the 0 to 3 ids still buffered, one at a time
// NOTE(review): icnt is size_t while n_buffered is int, so this is a
// signed/unsigned comparison; n_buffered is never negative here
962+
for (size_t icnt = 0; icnt < n_buffered; icnt++) {
963+
float dis = qdis(buffered_ids[icnt]);
964+
update_with_candidate(buffered_ids[icnt], dis);
965+
}
966+
}
967+
968+
// update stats: one hop per pass over a neighbor list
969+
stats.ndis += ndis;
970+
stats.nhops += 1;
971+
972+
// no neighbor was closer than the current nearest: stop
if (nearest == prev_nearest) {
973+
return stats;
974+
}
975+
}
976+
}
977+
} // namespace hnsw_utils
978+
namespace {
979+
using MinimaxHeap = HNSW::MinimaxHeap;
980+
using Node = HNSW::Node;
981+
using C = HNSW::C;
982+
987983
// just used as a lower bound for the minmaxheap, but it is set for heap search
988984
int extract_k_from_ResultHandler(ResultHandler<C>& res) {
989985
using RH = HeapBlockResultHandler<C>;
@@ -993,7 +989,7 @@ int extract_k_from_ResultHandler(ResultHandler<C>& res) {
993989
return 1;
994990
}
995991

996-
} // anonymous namespace
992+
} // namespace
997993

998994
HNSWStats HNSW::search(
999995
DistanceComputer& qdis,
@@ -1014,8 +1010,8 @@ HNSWStats HNSW::search(
10141010
float d_nearest = qdis(nearest);
10151011

10161012
for (int level = max_level; level >= 1; level--) {
1017-
HNSWStats local_stats =
1018-
greedy_update_nearest(*this, qdis, level, nearest, d_nearest);
1013+
HNSWStats local_stats = hnsw_utils::greedy_update_nearest(
1014+
*this, qdis, level, nearest, d_nearest);
10191015
stats.combine(local_stats);
10201016
}
10211017

@@ -1025,11 +1021,11 @@ HNSWStats HNSW::search(
10251021

10261022
candidates.push(nearest, d_nearest);
10271023

1028-
search_from_candidates(
1024+
hnsw_utils::search_from_candidates(
10291025
*this, qdis, res, candidates, vt, stats, 0, 0, params);
10301026
} else {
10311027
std::priority_queue<Node> top_candidates =
1032-
search_from_candidate_unbounded(
1028+
hnsw_utils::search_from_candidate_unbounded(
10331029
*this, Node(d_nearest, nearest), qdis, ef, &vt, stats);
10341030

10351031
while (top_candidates.size() > k) {
@@ -1081,7 +1077,7 @@ void HNSW::search_level_0(
10811077

10821078
candidates.push(cj, nearest_d[j]);
10831079

1084-
nres = search_from_candidates(
1080+
nres = hnsw_utils::search_from_candidates(
10851081
hnsw,
10861082
qdis,
10871083
res,
@@ -1106,7 +1102,7 @@ void HNSW::search_level_0(
11061102
candidates.push(cj, nearest_d[j]);
11071103
}
11081104

1109-
search_from_candidates(
1105+
hnsw_utils::search_from_candidates(
11101106
hnsw, qdis, res, candidates, vt, search_stats, 0, 0, params);
11111107
}
11121108
}

faiss/impl/HNSW.h

+29-1
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,6 @@ struct SearchParametersHNSW : SearchParameters {
4747
int efSearch = 16;
4848
bool check_relative_distance = true;
4949
bool bounded_queue = true;
50-
5150
~SearchParametersHNSW() {}
5251
};
5352

@@ -255,4 +254,33 @@ struct HNSWStats {
255254
// global var that collects them all
256255
FAISS_API extern HNSWStats hnsw_stats;
257256

257+
// Search helper routines declared here so that code outside HNSW.cpp can
// call them. Each takes a trailing `reference_version` flag that defaults
// to false.
namespace hnsw_utils {
258+
/// Do a BFS on the candidates list.
///
/// Returns an int; callers in HNSW.cpp either ignore it or store it in a
/// variable named `nres`.
/// `nres_in`, `params` and `reference_version` are optional (defaults: 0,
/// nullptr, false).
/// NOTE(review): presumably `reference_version = true` selects the slower
/// reference implementation -- confirm against the definition in HNSW.cpp.
int search_from_candidates(
259+
const HNSW& hnsw,
260+
DistanceComputer& qdis,
261+
ResultHandler<HNSW::C>& res,
262+
HNSW::MinimaxHeap& candidates,
263+
VisitedTable& vt,
264+
HNSWStats& stats,
265+
int level,
266+
int nres_in = 0,
267+
const SearchParametersHNSW* params = nullptr,
268+
const bool reference_version = false);
269+
270+
/// Greedily update a nearest vector at a given level.
///
/// `nearest` and `d_nearest` are in/out parameters: they are overwritten
/// whenever a neighbor with a strictly smaller distance is found. The
/// returned HNSWStats carries the `ndis` / `nhops` counts of this call.
/// With `reference_version = true` distances are computed one neighbor at a
/// time; otherwise they are computed in batches of 4.
HNSWStats greedy_update_nearest(
271+
const HNSW& hnsw,
272+
DistanceComputer& qdis,
273+
int level,
274+
HNSW::storage_idx_t& nearest,
275+
float& d_nearest,
276+
const bool reference_version = false);
277+
/// Search starting from a single candidate `node`; returns a priority queue
/// of HNSW::Node (the caller in HNSW::search pops from it while it holds
/// more than k entries).
/// NOTE(review): presumably `reference_version = true` selects the slower
/// reference implementation -- confirm against the definition in HNSW.cpp.
std::priority_queue<HNSW::Node> search_from_candidate_unbounded(
278+
const HNSW& hnsw,
279+
const HNSW::Node& node,
280+
DistanceComputer& qdis,
281+
int ef,
282+
VisitedTable* vt,
283+
HNSWStats& stats,
284+
const bool reference_version = false);
285+
} // namespace hnsw_utils
258286
} // namespace faiss

0 commit comments

Comments
 (0)