Skip to content

Commit 32f0e8c

Browse files
mdouze, facebook-github-bot
authored and committed Jan 11, 2024
Generalize ResultHandler, support range search for HNSW and Fast Scan (#3190)
Summary: Pull Request resolved: #3190. This diff adds more result handlers in order to expose them externally. This enables range search for HNSW and Fast Scan, and nprobe parameter support for FastScan. Reviewed By: pemazare Differential Revision: D52547384 fbshipit-source-id: 271da5ffea6411df3d8e50641abade18bd7b774b
1 parent 0013c70 commit 32f0e8c

38 files changed

+1995
-2015
lines changed
 

‎CHANGELOG.md

+1
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ We try to indicate most contributions here with the contributor names who are no
99
the Facebook Faiss team. Feel free to add entries here if you submit a PR.
1010

1111
## [Unreleased]
12+
- Support for range search in HNSW and Fast scan IVF.
1213
## [1.7.4] - 2023-04-12
1314
### Added
1415
- Added big batch IVF search for conducting efficient search with big batches of queries

‎benchs/link_and_code/README.md

+2
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,8 @@ The code runs on top of Faiss. The HNSW index can be extended with a
3939
`ReconstructFromNeighbors` C++ object that refines the distances. The
4040
training is implemented in Python.
4141

42+
Update: 2023-12-28: the current Faiss dropped support for reconstruction with
43+
this method.
4244

4345
Reproducing Table 2 in the paper
4446
--------------------------------

‎contrib/evaluation.py

+1
Original file line numberDiff line numberDiff line change
@@ -261,6 +261,7 @@ def check_ref_knn_with_draws(Dref, Iref, Dnew, Inew, rtol=1e-5):
261261
mask = DrefC == dis
262262
testcase.assertEqual(set(Iref[i, mask]), set(Inew[i, mask]))
263263

264+
264265
def check_ref_range_results(Lref, Dref, Iref,
265266
Lnew, Dnew, Inew):
266267
""" compare range search results wrt. a reference result,

‎faiss/IndexAdditiveQuantizer.cpp

+17-10
Original file line numberDiff line numberDiff line change
@@ -114,18 +114,19 @@ struct AQDistanceComputerLUT : FlatCodesDistanceComputer {
114114
* scanning implementation for search
115115
************************************************************/
116116

117-
template <class VectorDistance, class ResultHandler>
117+
template <class VectorDistance, class BlockResultHandler>
118118
void search_with_decompress(
119119
const IndexAdditiveQuantizer& ir,
120120
const float* xq,
121121
VectorDistance& vd,
122-
ResultHandler& res) {
122+
BlockResultHandler& res) {
123123
const uint8_t* codes = ir.codes.data();
124124
size_t ntotal = ir.ntotal;
125125
size_t code_size = ir.code_size;
126126
const AdditiveQuantizer* aq = ir.aq;
127127

128-
using SingleResultHandler = typename ResultHandler::SingleResultHandler;
128+
using SingleResultHandler =
129+
typename BlockResultHandler::SingleResultHandler;
129130

130131
#pragma omp parallel for if (res.nq > 100)
131132
for (int64_t q = 0; q < res.nq; q++) {
@@ -142,19 +143,23 @@ void search_with_decompress(
142143
}
143144
}
144145

145-
template <bool is_IP, AdditiveQuantizer::Search_type_t st, class ResultHandler>
146+
template <
147+
bool is_IP,
148+
AdditiveQuantizer::Search_type_t st,
149+
class BlockResultHandler>
146150
void search_with_LUT(
147151
const IndexAdditiveQuantizer& ir,
148152
const float* xq,
149-
ResultHandler& res) {
153+
BlockResultHandler& res) {
150154
const AdditiveQuantizer& aq = *ir.aq;
151155
const uint8_t* codes = ir.codes.data();
152156
size_t ntotal = ir.ntotal;
153157
size_t code_size = aq.code_size;
154158
size_t nq = res.nq;
155159
size_t d = ir.d;
156160

157-
using SingleResultHandler = typename ResultHandler::SingleResultHandler;
161+
using SingleResultHandler =
162+
typename BlockResultHandler::SingleResultHandler;
158163
std::unique_ptr<float[]> LUT(new float[nq * aq.total_codebook_size]);
159164

160165
aq.compute_LUT(nq, xq, LUT.get());
@@ -241,21 +246,23 @@ void IndexAdditiveQuantizer::search(
241246
if (metric_type == METRIC_L2) {
242247
using VD = VectorDistance<METRIC_L2>;
243248
VD vd = {size_t(d), metric_arg};
244-
HeapResultHandler<VD::C> rh(n, distances, labels, k);
249+
HeapBlockResultHandler<VD::C> rh(n, distances, labels, k);
245250
search_with_decompress(*this, x, vd, rh);
246251
} else if (metric_type == METRIC_INNER_PRODUCT) {
247252
using VD = VectorDistance<METRIC_INNER_PRODUCT>;
248253
VD vd = {size_t(d), metric_arg};
249-
HeapResultHandler<VD::C> rh(n, distances, labels, k);
254+
HeapBlockResultHandler<VD::C> rh(n, distances, labels, k);
250255
search_with_decompress(*this, x, vd, rh);
251256
}
252257
} else {
253258
if (metric_type == METRIC_INNER_PRODUCT) {
254-
HeapResultHandler<CMin<float, idx_t>> rh(n, distances, labels, k);
259+
HeapBlockResultHandler<CMin<float, idx_t>> rh(
260+
n, distances, labels, k);
255261
search_with_LUT<true, AdditiveQuantizer::ST_LUT_nonorm>(
256262
*this, x, rh);
257263
} else {
258-
HeapResultHandler<CMax<float, idx_t>> rh(n, distances, labels, k);
264+
HeapBlockResultHandler<CMax<float, idx_t>> rh(
265+
n, distances, labels, k);
259266
switch (aq->search_type) {
260267
#define DISPATCH(st) \
261268
case AdditiveQuantizer::st: \

‎faiss/IndexAdditiveQuantizerFastScan.cpp

+2-2
Original file line numberDiff line numberDiff line change
@@ -203,9 +203,9 @@ void IndexAdditiveQuantizerFastScan::search(
203203

204204
NormTableScaler scaler(norm_scale);
205205
if (metric_type == METRIC_L2) {
206-
search_dispatch_implem<true>(n, x, k, distances, labels, scaler);
206+
search_dispatch_implem<true>(n, x, k, distances, labels, &scaler);
207207
} else {
208-
search_dispatch_implem<false>(n, x, k, distances, labels, scaler);
208+
search_dispatch_implem<false>(n, x, k, distances, labels, &scaler);
209209
}
210210
}
211211

‎faiss/IndexBinaryHNSW.cpp

+13-10
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,6 @@
55
* LICENSE file in the root directory of this source tree.
66
*/
77

8-
// -*- c++ -*-
9-
108
#include <faiss/IndexBinaryHNSW.h>
119

1210
#include <omp.h>
@@ -28,6 +26,7 @@
2826
#include <faiss/impl/AuxIndexStructures.h>
2927
#include <faiss/impl/DistanceComputer.h>
3028
#include <faiss/impl/FaissAssert.h>
29+
#include <faiss/impl/ResultHandler.h>
3130
#include <faiss/utils/Heap.h>
3231
#include <faiss/utils/hamming.h>
3332
#include <faiss/utils/random.h>
@@ -201,27 +200,31 @@ void IndexBinaryHNSW::search(
201200
!params, "search params not supported for this index");
202201
FAISS_THROW_IF_NOT(k > 0);
203202

203+
// we use the buffer for distances as float but convert them back
204+
// to int in the end
205+
float* distances_f = (float*)distances;
206+
207+
using RH = HeapBlockResultHandler<HNSW::C>;
208+
RH bres(n, distances_f, labels, k);
209+
204210
#pragma omp parallel
205211
{
206212
VisitedTable vt(ntotal);
207213
std::unique_ptr<DistanceComputer> dis(get_distance_computer());
214+
RH::SingleResultHandler res(bres);
208215

209216
#pragma omp for
210217
for (idx_t i = 0; i < n; i++) {
211-
idx_t* idxi = labels + i * k;
212-
float* simi = (float*)(distances + i * k);
213-
218+
res.begin(i);
214219
dis->set_query((float*)(x + i * code_size));
215-
216-
maxheap_heapify(k, simi, idxi);
217-
hnsw.search(*dis, k, idxi, simi, vt);
218-
maxheap_reorder(k, simi, idxi);
220+
hnsw.search(*dis, res, vt);
221+
res.end();
219222
}
220223
}
221224

222225
#pragma omp parallel for
223226
for (int i = 0; i < n * k; ++i) {
224-
distances[i] = std::round(((float*)distances)[i]);
227+
distances[i] = std::round(distances_f[i]);
225228
}
226229
}
227230

0 commit comments

Comments
 (0)
Please sign in to comment.