@@ -255,23 +255,23 @@ void hnsw_search(
255
255
FAISS_THROW_IF_NOT_MSG (params, " params type invalid" );
256
256
efSearch = params->efSearch ;
257
257
}
258
- size_t n1 = 0 , n2 = 0 , ndis = 0 ;
258
+ size_t n1 = 0 , n2 = 0 , ndis = 0 , nhops = 0 ;
259
259
260
260
idx_t check_period = InterruptCallback::get_period_hint (
261
261
hnsw.max_level * index ->d * efSearch);
262
262
263
263
for (idx_t i0 = 0 ; i0 < n; i0 += check_period) {
264
264
idx_t i1 = std::min (i0 + check_period, n);
265
265
266
- #pragma omp parallel
266
+ #pragma omp parallel if (i1 - i0 > 1)
267
267
{
268
268
VisitedTable vt (index ->ntotal );
269
269
typename BlockResultHandler::SingleResultHandler res (bres);
270
270
271
271
std::unique_ptr<DistanceComputer> dis (
272
272
storage_distance_computer (index ->storage ));
273
273
274
- #pragma omp for reduction(+ : n1, n2, ndis) schedule(guided)
274
+ #pragma omp for reduction(+ : n1, n2, ndis, nhops ) schedule(guided)
275
275
for (idx_t i = i0; i < i1; i++) {
276
276
res.begin (i);
277
277
dis->set_query (x + i * index ->d );
@@ -280,13 +280,14 @@ void hnsw_search(
280
280
n1 += stats.n1 ;
281
281
n2 += stats.n2 ;
282
282
ndis += stats.ndis ;
283
+ nhops += stats.nhops ;
283
284
res.end ();
284
285
}
285
286
}
286
287
InterruptCallback::check ();
287
288
}
288
289
289
- hnsw_stats.combine ({n1, n2, ndis});
290
+ hnsw_stats.combine ({n1, n2, ndis, nhops });
290
291
}
291
292
292
293
} // anonymous namespace
@@ -612,6 +613,10 @@ void IndexHNSW::permute_entries(const idx_t* perm) {
612
613
hnsw.permute_entries (perm);
613
614
}
614
615
616
+ DistanceComputer* IndexHNSW::get_distance_computer () const {
617
+ return storage->get_distance_computer ();
618
+ }
619
+
615
620
/* *************************************************************
616
621
* IndexHNSWFlat implementation
617
622
**************************************************************/
@@ -635,8 +640,13 @@ IndexHNSWFlat::IndexHNSWFlat(int d, int M, MetricType metric)
635
640
636
641
IndexHNSWPQ::IndexHNSWPQ () = default ;
637
642
638
- IndexHNSWPQ::IndexHNSWPQ (int d, int pq_m, int M, int pq_nbits)
639
- : IndexHNSW(new IndexPQ(d, pq_m, pq_nbits), M) {
643
+ IndexHNSWPQ::IndexHNSWPQ (
644
+ int d,
645
+ int pq_m,
646
+ int M,
647
+ int pq_nbits,
648
+ MetricType metric)
649
+ : IndexHNSW(new IndexPQ(d, pq_m, pq_nbits, metric), M) {
640
650
own_fields = true ;
641
651
is_trained = false ;
642
652
}
@@ -762,7 +772,7 @@ void IndexHNSW2Level::search(
762
772
IndexHNSW::search (n, x, k, distances, labels);
763
773
764
774
} else { // "mixed" search
765
- size_t n1 = 0 , n2 = 0 , ndis = 0 ;
775
+ size_t n1 = 0 , n2 = 0 , ndis = 0 , nhops = 0 ;
766
776
767
777
const IndexIVFPQ* index_ivfpq =
768
778
dynamic_cast <const IndexIVFPQ*>(storage);
@@ -791,10 +801,10 @@ void IndexHNSW2Level::search(
791
801
std::unique_ptr<DistanceComputer> dis (
792
802
storage_distance_computer (storage));
793
803
794
- int candidates_size = hnsw. upper_beam ;
804
+ constexpr int candidates_size = 1 ;
795
805
MinimaxHeap candidates (candidates_size);
796
806
797
- #pragma omp for reduction(+ : n1, n2, ndis)
807
+ #pragma omp for reduction(+ : n1, n2, ndis, nhops )
798
808
for (idx_t i = 0 ; i < n; i++) {
799
809
idx_t * idxi = labels + i * k;
800
810
float * simi = distances + i * k;
@@ -816,7 +826,7 @@ void IndexHNSW2Level::search(
816
826
817
827
candidates.clear ();
818
828
819
- for (int j = 0 ; j < hnsw. upper_beam && j < k; j++) {
829
+ for (int j = 0 ; j < k; j++) {
820
830
if (idxi[j] < 0 )
821
831
break ;
822
832
candidates.push (idxi[j], simi[j]);
@@ -840,6 +850,7 @@ void IndexHNSW2Level::search(
840
850
n1 += search_stats.n1 ;
841
851
n2 += search_stats.n2 ;
842
852
ndis += search_stats.ndis ;
853
+ nhops += search_stats.nhops ;
843
854
844
855
vt.advance ();
845
856
vt.advance ();
@@ -848,7 +859,7 @@ void IndexHNSW2Level::search(
848
859
}
849
860
}
850
861
851
- hnsw_stats.combine ({n1, n2, ndis});
862
+ hnsw_stats.combine ({n1, n2, ndis, nhops });
852
863
}
853
864
}
854
865
0 commit comments