Skip to content

Commit 2817344

Browse files
algoriddle authored and facebook-github-bot committed
fix ACCESS VIOLATION error when searching using IDSelectorArray
Summary: Fixes #3156.
Metamate says: "This diff fixes an ACCESS VIOLATION error that occurs when searching using IDSelectorArray. The code changes include adding a new parameter to the knn_inner_products_by_idx and knn_L2sqr_by_idx functions in the distances.cpp file, as well as modifying the test_search_params.py file to test the bounds of the IDSelectorArray."
Reviewed By: mdouze
Differential Revision: D53185461
fbshipit-source-id: c7ec4783f77455684c078bba3aace160078f6c27
1 parent 67c6a19 commit 2817344

File tree

3 files changed: +32 −5 lines changed

faiss/utils/distances.cpp

+12-5
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
#include <algorithm>
1111
#include <cassert>
1212
#include <cmath>
13+
#include <cstddef>
1314
#include <cstdio>
1415
#include <cstring>
1516

@@ -670,7 +671,7 @@ void knn_inner_product(
670671
}
671672
if (auto sela = dynamic_cast<const IDSelectorArray*>(sel)) {
672673
knn_inner_products_by_idx(
673-
x, y, sela->ids, d, nx, sela->n, k, vals, ids, 0);
674+
x, y, sela->ids, d, nx, ny, sela->n, k, vals, ids, 0);
674675
return;
675676
}
676677

@@ -726,7 +727,7 @@ void knn_L2sqr(
726727
sel = nullptr;
727728
}
728729
if (auto sela = dynamic_cast<const IDSelectorArray*>(sel)) {
729-
knn_L2sqr_by_idx(x, y, sela->ids, d, nx, sela->n, k, vals, ids, 0);
730+
knn_L2sqr_by_idx(x, y, sela->ids, d, nx, ny, sela->n, k, vals, ids, 0);
730731
return;
731732
}
732733
if (k == 1) {
@@ -904,6 +905,7 @@ void knn_inner_products_by_idx(
904905
size_t d,
905906
size_t nx,
906907
size_t ny,
908+
size_t nsubset,
907909
size_t k,
908910
float* res_vals,
909911
int64_t* res_ids,
@@ -921,9 +923,10 @@ void knn_inner_products_by_idx(
921923
int64_t* __restrict idxi = res_ids + i * k;
922924
minheap_heapify(k, simi, idxi);
923925

924-
for (j = 0; j < ny; j++) {
925-
if (idsi[j] < 0)
926+
for (j = 0; j < nsubset; j++) {
927+
if (idsi[j] < 0 || idsi[j] >= ny) {
926928
break;
929+
}
927930
float ip = fvec_inner_product(x_, y + d * idsi[j], d);
928931

929932
if (ip > simi[0]) {
@@ -941,6 +944,7 @@ void knn_L2sqr_by_idx(
941944
size_t d,
942945
size_t nx,
943946
size_t ny,
947+
size_t nsubset,
944948
size_t k,
945949
float* res_vals,
946950
int64_t* res_ids,
@@ -955,7 +959,10 @@ void knn_L2sqr_by_idx(
955959
float* __restrict simi = res_vals + i * k;
956960
int64_t* __restrict idxi = res_ids + i * k;
957961
maxheap_heapify(k, simi, idxi);
958-
for (size_t j = 0; j < ny; j++) {
962+
for (size_t j = 0; j < nsubset; j++) {
963+
if (idsi[j] < 0 || idsi[j] >= ny) {
964+
break;
965+
}
959966
float disij = fvec_L2sqr(x_, y + d * idsi[j], d);
960967

961968
if (disij < simi[0]) {

faiss/utils/distances.h

+2
Original file line numberDiff line numberDiff line change
@@ -376,6 +376,7 @@ void knn_inner_products_by_idx(
376376
const int64_t* subset,
377377
size_t d,
378378
size_t nx,
379+
size_t ny,
379380
size_t nsubset,
380381
size_t k,
381382
float* vals,
@@ -398,6 +399,7 @@ void knn_L2sqr_by_idx(
398399
const int64_t* subset,
399400
size_t d,
400401
size_t nx,
402+
size_t ny,
401403
size_t nsubset,
402404
size_t k,
403405
float* vals,

tests/test_search_params.py

+18
Original file line numberDiff line numberDiff line change
@@ -257,6 +257,24 @@ def test_idmap(self):
257257
np.testing.assert_array_equal(Iref, Inew)
258258
np.testing.assert_array_almost_equal(Dref, Dnew, decimal=5)
259259

260+
def test_bounds(self):
    """Regression test: an IDSelectorArray holding an id past the end of
    the database must not crash a flat-index search.

    See https://github.com/facebookresearch/faiss/issues/3156. Both the
    inner-product and the L2 flat indexes are searched, since each one
    dispatches to its own *_by_idx kernel when given an IDSelectorArray.
    """
    d = 64  # dimension
    nb = 100000  # database size
    xb = np.random.random((nb, d))

    index_ip = faiss.IndexFlatIP(d)
    index_ip.add(xb)
    # Must be an L2 index: the original built a second IndexFlatIP here,
    # so the L2 code path was never exercised by this test.
    index_l2 = faiss.IndexFlatL2(d)
    index_l2.add(xb)

    # Per the original report, offsets of nb + 14 or lower did not
    # trigger the crash, hence + 15.
    out_of_bounds_id = nb + 15
    id_selector = faiss.IDSelectorArray([out_of_bounds_id])
    search_params = faiss.SearchParameters(sel=id_selector)

    for index in (index_ip, index_l2):
        # The search must return without crashing ...
        distances, indices = index.search(
            xb[:2], k=3, params=search_params)
        # ... and must never report the out-of-range id as a result.
        self.assertTrue(np.all(indices < nb))
277+
260278

261279
class TestSearchParams(unittest.TestCase):
262280

0 commit comments

Comments
 (0)