Skip to content

Commit 7178bf8

Browse files
mdouzefacebook-github-bot
authored andcommitted
Fix radius search with HSNW and IP (facebookresearch#3698)
Summary: Pull Request resolved: facebookresearch#3698 Addressed issue facebookresearch#3684 I forgot to negate the threshold of the radius search. This diff adds a test and fixes the issue. Reviewed By: mengdilin Differential Revision: D60373054 fbshipit-source-id: 70f3daa8292177a4038846a94aff6221f88077e8
1 parent 34bbe5e commit 7178bf8

File tree

2 files changed

+29
-21
lines changed

2 files changed

+29
-21
lines changed

faiss/IndexHNSW.cpp

+1-21
Original file line numberDiff line numberDiff line change
@@ -35,26 +35,6 @@
3535
#include <faiss/utils/random.h>
3636
#include <faiss/utils/sorting.h>
3737

38-
extern "C" {
39-
40-
/* declare BLAS functions, see http://www.netlib.org/clapack/cblas/ */
41-
42-
int sgemm_(
43-
const char* transa,
44-
const char* transb,
45-
FINTEGER* m,
46-
FINTEGER* n,
47-
FINTEGER* k,
48-
const float* alpha,
49-
const float* a,
50-
FINTEGER* lda,
51-
const float* b,
52-
FINTEGER* ldb,
53-
float* beta,
54-
float* c,
55-
FINTEGER* ldc);
56-
}
57-
5838
namespace faiss {
5939

6040
using MinimaxHeap = HNSW::MinimaxHeap;
@@ -340,7 +320,7 @@ void IndexHNSW::range_search(
340320
RangeSearchResult* result,
341321
const SearchParameters* params) const {
342322
using RH = RangeSearchBlockResultHandler<HNSW::C>;
343-
RH bres(result, radius);
323+
RH bres(result, is_similarity_metric(metric_type) ? -radius : radius);
344324

345325
hnsw_search(this, n, x, bres, params);
346326

tests/test_graph_based.py

+28
Original file line numberDiff line numberDiff line change
@@ -184,6 +184,34 @@ def test_abs_inner_product(self):
184184
# 4769 vs. 500*10
185185
self.assertGreater(inter, Iref.size * 0.9)
186186

187+
188+
class Issue3684(unittest.TestCase):
189+
190+
def test_issue3684(self):
191+
np.random.seed(1234) # For reproducibility
192+
d = 256 # Example dimension
193+
nb = 10 # Number of database vectors
194+
nq = 2 # Number of query vectors
195+
xb = np.random.random((nb, d)).astype('float32')
196+
xq = np.random.random((nq, d)).astype('float32')
197+
198+
faiss.normalize_L2(xb) # Normalize both query and database vectors
199+
faiss.normalize_L2(xq)
200+
201+
hnsw_index_ip = faiss.IndexHNSWFlat(256, 16, faiss.METRIC_INNER_PRODUCT)
202+
hnsw_index_ip.hnsw.efConstruction = 512
203+
hnsw_index_ip.hnsw.efSearch = 512
204+
hnsw_index_ip.add(xb)
205+
206+
# test knn
207+
D, I = hnsw_index_ip.search(xq, 10)
208+
self.assertTrue(np.all(D[:, :-1] >= D[:, 1:]))
209+
210+
# test range search
211+
radius = 0.74 # Cosine similarity threshold
212+
lims, D, I = hnsw_index_ip.range_search(xq, radius)
213+
self.assertTrue(np.all(D >= radius))
214+
187215

188216
class TestNSG(unittest.TestCase):
189217

0 commit comments

Comments
 (0)