Skip to content

Commit 27dee72

Browse files
Some small improvements.
Signed-off-by: Alexandr Guzhva <alexanderguzhva@gmail.com>
1 parent a56ee81 commit 27dee72

11 files changed

+550
-202
lines changed

faiss/IndexHNSW.cpp

+22-11
Original file line numberDiff line numberDiff line change
@@ -255,23 +255,23 @@ void hnsw_search(
255255
FAISS_THROW_IF_NOT_MSG(params, "params type invalid");
256256
efSearch = params->efSearch;
257257
}
258-
size_t n1 = 0, n2 = 0, ndis = 0;
258+
size_t n1 = 0, n2 = 0, ndis = 0, nhops = 0;
259259

260260
idx_t check_period = InterruptCallback::get_period_hint(
261261
hnsw.max_level * index->d * efSearch);
262262

263263
for (idx_t i0 = 0; i0 < n; i0 += check_period) {
264264
idx_t i1 = std::min(i0 + check_period, n);
265265

266-
#pragma omp parallel
266+
#pragma omp parallel if (i1 - i0 > 1)
267267
{
268268
VisitedTable vt(index->ntotal);
269269
typename BlockResultHandler::SingleResultHandler res(bres);
270270

271271
std::unique_ptr<DistanceComputer> dis(
272272
storage_distance_computer(index->storage));
273273

274-
#pragma omp for reduction(+ : n1, n2, ndis) schedule(guided)
274+
#pragma omp for reduction(+ : n1, n2, ndis, nhops) schedule(guided)
275275
for (idx_t i = i0; i < i1; i++) {
276276
res.begin(i);
277277
dis->set_query(x + i * index->d);
@@ -280,13 +280,14 @@ void hnsw_search(
280280
n1 += stats.n1;
281281
n2 += stats.n2;
282282
ndis += stats.ndis;
283+
nhops += stats.nhops;
283284
res.end();
284285
}
285286
}
286287
InterruptCallback::check();
287288
}
288289

289-
hnsw_stats.combine({n1, n2, ndis});
290+
hnsw_stats.combine({n1, n2, ndis, nhops});
290291
}
291292

292293
} // anonymous namespace
@@ -612,6 +613,10 @@ void IndexHNSW::permute_entries(const idx_t* perm) {
612613
hnsw.permute_entries(perm);
613614
}
614615

616+
DistanceComputer* IndexHNSW::get_distance_computer() const {
617+
return storage->get_distance_computer();
618+
}
619+
615620
/**************************************************************
616621
* IndexHNSWFlat implementation
617622
**************************************************************/
@@ -635,8 +640,13 @@ IndexHNSWFlat::IndexHNSWFlat(int d, int M, MetricType metric)
635640

636641
IndexHNSWPQ::IndexHNSWPQ() = default;
637642

638-
IndexHNSWPQ::IndexHNSWPQ(int d, int pq_m, int M, int pq_nbits)
639-
: IndexHNSW(new IndexPQ(d, pq_m, pq_nbits), M) {
643+
IndexHNSWPQ::IndexHNSWPQ(
644+
int d,
645+
int pq_m,
646+
int M,
647+
int pq_nbits,
648+
MetricType metric)
649+
: IndexHNSW(new IndexPQ(d, pq_m, pq_nbits, metric), M) {
640650
own_fields = true;
641651
is_trained = false;
642652
}
@@ -762,7 +772,7 @@ void IndexHNSW2Level::search(
762772
IndexHNSW::search(n, x, k, distances, labels);
763773

764774
} else { // "mixed" search
765-
size_t n1 = 0, n2 = 0, ndis = 0;
775+
size_t n1 = 0, n2 = 0, ndis = 0, nhops = 0;
766776

767777
const IndexIVFPQ* index_ivfpq =
768778
dynamic_cast<const IndexIVFPQ*>(storage);
@@ -791,10 +801,10 @@ void IndexHNSW2Level::search(
791801
std::unique_ptr<DistanceComputer> dis(
792802
storage_distance_computer(storage));
793803

794-
int candidates_size = hnsw.upper_beam;
804+
constexpr int candidates_size = 1;
795805
MinimaxHeap candidates(candidates_size);
796806

797-
#pragma omp for reduction(+ : n1, n2, ndis)
807+
#pragma omp for reduction(+ : n1, n2, ndis, nhops)
798808
for (idx_t i = 0; i < n; i++) {
799809
idx_t* idxi = labels + i * k;
800810
float* simi = distances + i * k;
@@ -816,7 +826,7 @@ void IndexHNSW2Level::search(
816826

817827
candidates.clear();
818828

819-
for (int j = 0; j < hnsw.upper_beam && j < k; j++) {
829+
for (int j = 0; j < k; j++) {
820830
if (idxi[j] < 0)
821831
break;
822832
candidates.push(idxi[j], simi[j]);
@@ -840,6 +850,7 @@ void IndexHNSW2Level::search(
840850
n1 += search_stats.n1;
841851
n2 += search_stats.n2;
842852
ndis += search_stats.ndis;
853+
nhops += search_stats.nhops;
843854

844855
vt.advance();
845856
vt.advance();
@@ -848,7 +859,7 @@ void IndexHNSW2Level::search(
848859
}
849860
}
850861

851-
hnsw_stats.combine({n1, n2, ndis});
862+
hnsw_stats.combine({n1, n2, ndis, nhops});
852863
}
853864
}
854865

faiss/IndexHNSW.h

+9-2
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ struct IndexHNSW;
2727
struct IndexHNSW : Index {
2828
typedef HNSW::storage_idx_t storage_idx_t;
2929

30-
// the link strcuture
30+
// the link structure
3131
HNSW hnsw;
3232

3333
// the sequential storage
@@ -111,6 +111,8 @@ struct IndexHNSW : Index {
111111
void link_singletons();
112112

113113
void permute_entries(const idx_t* perm);
114+
115+
DistanceComputer* get_distance_computer() const override;
114116
};
115117

116118
/** Flat index topped with a HNSW structure to access elements
@@ -127,7 +129,12 @@ struct IndexHNSWFlat : IndexHNSW {
127129
*/
128130
struct IndexHNSWPQ : IndexHNSW {
129131
IndexHNSWPQ();
130-
IndexHNSWPQ(int d, int pq_m, int M, int pq_nbits = 8);
132+
IndexHNSWPQ(
133+
int d,
134+
int pq_m,
135+
int M,
136+
int pq_nbits = 8,
137+
MetricType metric = METRIC_L2);
131138
void train(idx_t n, const float* x) override;
132139
};
133140

faiss/IndexRefine.cpp

+5-5
Original file line numberDiff line numberDiff line change
@@ -68,12 +68,12 @@ template <class C>
6868
static void reorder_2_heaps(
6969
idx_t n,
7070
idx_t k,
71-
idx_t* labels,
72-
float* distances,
71+
idx_t* __restrict labels,
72+
float* __restrict distances,
7373
idx_t k_base,
74-
const idx_t* base_labels,
75-
const float* base_distances) {
76-
#pragma omp parallel for
74+
const idx_t* __restrict base_labels,
75+
const float* __restrict base_distances) {
76+
#pragma omp parallel for if (n > 1)
7777
for (idx_t i = 0; i < n; i++) {
7878
idx_t* idxo = labels + i * k;
7979
float* diso = distances + i * k;

0 commit comments

Comments
 (0)