@@ -470,105 +470,6 @@ void search_neighbors_to_add(
470
470
vt.advance ();
471
471
}
472
472
473
- /* *************************************************************
474
- * Searching subroutines
475
- **************************************************************/
476
-
477
- // / greedily update a nearest vector at a given level
478
- HNSWStats greedy_update_nearest (
479
- const HNSW& hnsw,
480
- DistanceComputer& qdis,
481
- int level,
482
- storage_idx_t & nearest,
483
- float & d_nearest) {
484
- // selects a version
485
- const bool reference_version = false ;
486
-
487
- HNSWStats stats;
488
-
489
- for (;;) {
490
- storage_idx_t prev_nearest = nearest;
491
-
492
- size_t begin, end;
493
- hnsw.neighbor_range (nearest, level, &begin, &end);
494
-
495
- size_t ndis = 0 ;
496
-
497
- // select a version, based on a flag
498
- if (reference_version) {
499
- // a reference version
500
- for (size_t i = begin; i < end; i++) {
501
- storage_idx_t v = hnsw.neighbors [i];
502
- if (v < 0 )
503
- break ;
504
- ndis += 1 ;
505
- float dis = qdis (v);
506
- if (dis < d_nearest) {
507
- nearest = v;
508
- d_nearest = dis;
509
- }
510
- }
511
- } else {
512
- // a faster version
513
-
514
- // the following version processes 4 neighbors at a time
515
- auto update_with_candidate = [&](const storage_idx_t idx,
516
- const float dis) {
517
- if (dis < d_nearest) {
518
- nearest = idx;
519
- d_nearest = dis;
520
- }
521
- };
522
-
523
- int n_buffered = 0 ;
524
- storage_idx_t buffered_ids[4 ];
525
-
526
- for (size_t j = begin; j < end; j++) {
527
- storage_idx_t v = hnsw.neighbors [j];
528
- if (v < 0 )
529
- break ;
530
- ndis += 1 ;
531
-
532
- buffered_ids[n_buffered] = v;
533
- n_buffered += 1 ;
534
-
535
- if (n_buffered == 4 ) {
536
- float dis[4 ];
537
- qdis.distances_batch_4 (
538
- buffered_ids[0 ],
539
- buffered_ids[1 ],
540
- buffered_ids[2 ],
541
- buffered_ids[3 ],
542
- dis[0 ],
543
- dis[1 ],
544
- dis[2 ],
545
- dis[3 ]);
546
-
547
- for (size_t id4 = 0 ; id4 < 4 ; id4++) {
548
- update_with_candidate (buffered_ids[id4], dis[id4]);
549
- }
550
-
551
- n_buffered = 0 ;
552
- }
553
- }
554
-
555
- // process leftovers
556
- for (size_t icnt = 0 ; icnt < n_buffered; icnt++) {
557
- float dis = qdis (buffered_ids[icnt]);
558
- update_with_candidate (buffered_ids[icnt], dis);
559
- }
560
- }
561
-
562
- // update stats
563
- stats.ndis += ndis;
564
- stats.nhops += 1 ;
565
-
566
- if (nearest == prev_nearest) {
567
- return stats;
568
- }
569
- }
570
- }
571
-
572
473
} // namespace
573
474
574
475
// / Finds neighbors and builds links with them, starting from an entry
@@ -644,7 +545,8 @@ void HNSW::add_with_locks(
644
545
float d_nearest = ptdis (nearest);
645
546
646
547
for (; level > pt_level; level--) {
647
- greedy_update_nearest (*this , ptdis, level, nearest, d_nearest);
548
+ hnsw_utils::greedy_update_nearest (
549
+ *this , ptdis, level, nearest, d_nearest);
648
550
}
649
551
650
552
for (; level >= 0 ; level--) {
@@ -667,16 +569,16 @@ void HNSW::add_with_locks(
667
569
}
668
570
}
669
571
572
+ namespace hnsw_utils {
573
+
670
574
/* *************************************************************
671
575
* Searching
672
576
**************************************************************/
673
577
674
- namespace {
675
578
using MinimaxHeap = HNSW::MinimaxHeap;
676
579
using Node = HNSW::Node;
677
580
using C = HNSW::C;
678
581
/* * Do a BFS on the candidates list */
679
-
680
582
int search_from_candidates (
681
583
const HNSW& hnsw,
682
584
DistanceComputer& qdis,
@@ -685,11 +587,9 @@ int search_from_candidates(
685
587
VisitedTable& vt,
686
588
HNSWStats& stats,
687
589
int level,
688
- int nres_in = 0 ,
689
- const SearchParametersHNSW* params = nullptr ) {
690
- // selects a version
691
- const bool reference_version = false ;
692
-
590
+ int nres_in,
591
+ const SearchParametersHNSW* params,
592
+ const bool reference_version) {
693
593
int nres = nres_in;
694
594
int ndis = 0 ;
695
595
@@ -851,10 +751,8 @@ std::priority_queue<HNSW::Node> search_from_candidate_unbounded(
851
751
DistanceComputer& qdis,
852
752
int ef,
853
753
VisitedTable* vt,
854
- HNSWStats& stats) {
855
- // selects a version
856
- const bool reference_version = false ;
857
-
754
+ HNSWStats& stats,
755
+ const bool reference_version) {
858
756
int ndis = 0 ;
859
757
std::priority_queue<Node> top_candidates;
860
758
std::priority_queue<Node, std::vector<Node>, std::greater<Node>> candidates;
@@ -984,6 +882,104 @@ std::priority_queue<HNSW::Node> search_from_candidate_unbounded(
984
882
return top_candidates;
985
883
}
986
884
885
+ // / greedily update a nearest vector at a given level
886
+ HNSWStats greedy_update_nearest (
887
+ const HNSW& hnsw,
888
+ DistanceComputer& qdis,
889
+ int level,
890
+ storage_idx_t & nearest,
891
+ float & d_nearest,
892
+ const bool reference_version) {
893
+ HNSWStats stats;
894
+
895
+ for (;;) {
896
+ storage_idx_t prev_nearest = nearest;
897
+
898
+ size_t begin, end;
899
+ hnsw.neighbor_range (nearest, level, &begin, &end);
900
+
901
+ size_t ndis = 0 ;
902
+
903
+ // select a version, based on a flag
904
+ if (reference_version) {
905
+ // a reference version
906
+ for (size_t i = begin; i < end; i++) {
907
+ storage_idx_t v = hnsw.neighbors [i];
908
+ if (v < 0 )
909
+ break ;
910
+ ndis += 1 ;
911
+ float dis = qdis (v);
912
+ if (dis < d_nearest) {
913
+ nearest = v;
914
+ d_nearest = dis;
915
+ }
916
+ }
917
+ } else {
918
+ // a faster version
919
+
920
+ // the following version processes 4 neighbors at a time
921
+ auto update_with_candidate = [&](const storage_idx_t idx,
922
+ const float dis) {
923
+ if (dis < d_nearest) {
924
+ nearest = idx;
925
+ d_nearest = dis;
926
+ }
927
+ };
928
+
929
+ int n_buffered = 0 ;
930
+ storage_idx_t buffered_ids[4 ];
931
+
932
+ for (size_t j = begin; j < end; j++) {
933
+ storage_idx_t v = hnsw.neighbors [j];
934
+ if (v < 0 )
935
+ break ;
936
+ ndis += 1 ;
937
+
938
+ buffered_ids[n_buffered] = v;
939
+ n_buffered += 1 ;
940
+
941
+ if (n_buffered == 4 ) {
942
+ float dis[4 ];
943
+ qdis.distances_batch_4 (
944
+ buffered_ids[0 ],
945
+ buffered_ids[1 ],
946
+ buffered_ids[2 ],
947
+ buffered_ids[3 ],
948
+ dis[0 ],
949
+ dis[1 ],
950
+ dis[2 ],
951
+ dis[3 ]);
952
+
953
+ for (size_t id4 = 0 ; id4 < 4 ; id4++) {
954
+ update_with_candidate (buffered_ids[id4], dis[id4]);
955
+ }
956
+
957
+ n_buffered = 0 ;
958
+ }
959
+ }
960
+
961
+ // process leftovers
962
+ for (size_t icnt = 0 ; icnt < n_buffered; icnt++) {
963
+ float dis = qdis (buffered_ids[icnt]);
964
+ update_with_candidate (buffered_ids[icnt], dis);
965
+ }
966
+ }
967
+
968
+ // update stats
969
+ stats.ndis += ndis;
970
+ stats.nhops += 1 ;
971
+
972
+ if (nearest == prev_nearest) {
973
+ return stats;
974
+ }
975
+ }
976
+ }
977
+ } // namespace hnsw_utils
978
+ namespace {
979
+ using MinimaxHeap = HNSW::MinimaxHeap;
980
+ using Node = HNSW::Node;
981
+ using C = HNSW::C;
982
+
987
983
// just used as a lower bound for the minmaxheap, but it is set for heap search
988
984
int extract_k_from_ResultHandler (ResultHandler<C>& res) {
989
985
using RH = HeapBlockResultHandler<C>;
@@ -993,7 +989,7 @@ int extract_k_from_ResultHandler(ResultHandler<C>& res) {
993
989
return 1 ;
994
990
}
995
991
996
- } // anonymous namespace
992
+ } // namespace
997
993
998
994
HNSWStats HNSW::search (
999
995
DistanceComputer& qdis,
@@ -1014,8 +1010,8 @@ HNSWStats HNSW::search(
1014
1010
float d_nearest = qdis (nearest);
1015
1011
1016
1012
for (int level = max_level; level >= 1 ; level--) {
1017
- HNSWStats local_stats =
1018
- greedy_update_nearest ( *this , qdis, level, nearest, d_nearest);
1013
+ HNSWStats local_stats = hnsw_utils::greedy_update_nearest (
1014
+ *this , qdis, level, nearest, d_nearest);
1019
1015
stats.combine (local_stats);
1020
1016
}
1021
1017
@@ -1025,11 +1021,11 @@ HNSWStats HNSW::search(
1025
1021
1026
1022
candidates.push (nearest, d_nearest);
1027
1023
1028
- search_from_candidates (
1024
+ hnsw_utils:: search_from_candidates (
1029
1025
*this , qdis, res, candidates, vt, stats, 0 , 0 , params);
1030
1026
} else {
1031
1027
std::priority_queue<Node> top_candidates =
1032
- search_from_candidate_unbounded (
1028
+ hnsw_utils:: search_from_candidate_unbounded (
1033
1029
*this , Node (d_nearest, nearest), qdis, ef, &vt, stats);
1034
1030
1035
1031
while (top_candidates.size () > k) {
@@ -1081,7 +1077,7 @@ void HNSW::search_level_0(
1081
1077
1082
1078
candidates.push (cj, nearest_d[j]);
1083
1079
1084
- nres = search_from_candidates (
1080
+ nres = hnsw_utils:: search_from_candidates (
1085
1081
hnsw,
1086
1082
qdis,
1087
1083
res,
@@ -1106,7 +1102,7 @@ void HNSW::search_level_0(
1106
1102
candidates.push (cj, nearest_d[j]);
1107
1103
}
1108
1104
1109
- search_from_candidates (
1105
+ hnsw_utils:: search_from_candidates (
1110
1106
hnsw, qdis, res, candidates, vt, search_stats, 0 , 0 , params);
1111
1107
}
1112
1108
}
0 commit comments