Skip to content

Commit b9ea339

Browse files
mdouzefacebook-github-bot
authored andcommitted
support range search from GPU (facebookresearch#2860)
Summary: Pull Request resolved: facebookresearch#2860 Optimized range search function where the GPU computes by default and falls back on gpu for queries where there are too many results. Parallelize the CPU to GPU cloning, it seems to work. Support range_search_preassigned in Python Fix long-standing issue with SWIG exposed functions that did not release the GIL (in particular the MapLong2Long). Adds a MapInt64ToInt64 that is more efficient than MapLong2Long. Reviewed By: algoriddle Differential Revision: D45672301 fbshipit-source-id: 2e77397c40083818584dbafa5427149359a2abfd
1 parent 54d331e commit b9ea339

19 files changed

+711
-181
lines changed

contrib/evaluation.py

+6-6
Original file line numberDiff line numberDiff line change
@@ -226,7 +226,7 @@ def compute_PR_for(q):
226226
# Functions that compare search results with a reference result.
227227
# They are intended for use in tests
228228

229-
def test_ref_knn_with_draws(Dref, Iref, Dnew, Inew):
229+
def check_ref_knn_with_draws(Dref, Iref, Dnew, Inew):
230230
""" test that knn search results are identical, raise if not """
231231
np.testing.assert_array_almost_equal(Dref, Dnew, decimal=5)
232232
# here we have to be careful because of draws
@@ -243,14 +243,14 @@ def test_ref_knn_with_draws(Dref, Iref, Dnew, Inew):
243243
testcase.assertEqual(set(Iref[i, mask]), set(Inew[i, mask]))
244244

245245

246-
def test_ref_range_results(lims_ref, Dref, Iref,
247-
lims_new, Dnew, Inew):
246+
def check_ref_range_results(Lref, Dref, Iref,
247+
Lnew, Dnew, Inew):
248248
""" compare range search results wrt. a reference result,
249249
throw if it fails """
250-
np.testing.assert_array_equal(lims_ref, lims_new)
251-
nq = len(lims_ref) - 1
250+
np.testing.assert_array_equal(Lref, Lnew)
251+
nq = len(Lref) - 1
252252
for i in range(nq):
253-
l0, l1 = lims_ref[i], lims_ref[i + 1]
253+
l0, l1 = Lref[i], Lref[i + 1]
254254
Ii_ref = Iref[l0:l1]
255255
Ii_new = Inew[l0:l1]
256256
Di_ref = Dref[l0:l1]

contrib/exhaustive_search.py

+68-34
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,6 @@
1111

1212
LOG = logging.getLogger(__name__)
1313

14-
1514
def knn_ground_truth(xq, db_iterator, k, metric_type=faiss.METRIC_L2):
1615
"""Computes the exact KNN search results for a dataset that possibly
1716
does not fit in RAM but for which we have an iterator that
@@ -51,47 +50,82 @@ def knn_ground_truth(xq, db_iterator, k, metric_type=faiss.METRIC_L2):
5150

5251

5352

54-
def range_search_gpu(xq, r2, index_gpu, index_cpu):
53+
def range_search_gpu(xq, r2, index_gpu, index_cpu, gpu_k=1024):
5554
"""GPU does not support range search, so we emulate it with
5655
knn search + fallback to CPU index.
5756
58-
The index_cpu can either be a CPU index or a numpy table that will
59-
be used to construct a Flat index if needed.
57+
The index_cpu can either be:
58+
- a CPU index that supports range search
59+
- a numpy table, that will be used to construct a Flat index if needed.
60+
- None. In that case, at most gpu_k results will be returned
6061
"""
6162
nq, d = xq.shape
62-
LOG.debug("GPU search %d queries" % nq)
63-
k = min(index_gpu.ntotal, 1024)
63+
k = min(index_gpu.ntotal, gpu_k)
64+
keep_max = faiss.is_similarity_metric(index_gpu.metric_type)
65+
LOG.debug(f"GPU search {nq} queries with {k=:}")
66+
t0 = time.time()
6467
D, I = index_gpu.search(xq, k)
65-
if index_gpu.metric_type == faiss.METRIC_L2:
66-
mask = D[:, k - 1] < r2
67-
else:
68-
mask = D[:, k - 1] > r2
69-
if mask.sum() > 0:
70-
LOG.debug("CPU search remain %d" % mask.sum())
71-
if isinstance(index_cpu, np.ndarray):
72-
# then it in fact an array that we have to make flat
73-
xb = index_cpu
74-
index_cpu = faiss.IndexFlat(d, index_gpu.metric_type)
75-
index_cpu.add(xb)
76-
lim_remain, D_remain, I_remain = index_cpu.range_search(xq[mask], r2)
68+
t1 = time.time() - t0
69+
t2 = 0
70+
lim_remain = None
71+
if index_cpu is not None:
72+
if not keep_max:
73+
mask = D[:, k - 1] < r2
74+
else:
75+
mask = D[:, k - 1] > r2
76+
if mask.sum() > 0:
77+
LOG.debug("CPU search remain %d" % mask.sum())
78+
t0 = time.time()
79+
if isinstance(index_cpu, np.ndarray):
80+
# then it in fact an array that we have to make flat
81+
xb = index_cpu
82+
index_cpu = faiss.IndexFlat(d, index_gpu.metric_type)
83+
index_cpu.add(xb)
84+
lim_remain, D_remain, I_remain = index_cpu.range_search(xq[mask], r2)
85+
t2 = time.time() - t0
7786
LOG.debug("combine")
78-
D_res, I_res = [], []
79-
nr = 0
80-
for i in range(nq):
81-
if not mask[i]:
82-
if index_gpu.metric_type == faiss.METRIC_L2:
83-
nv = (D[i, :] < r2).sum()
87+
t0 = time.time()
88+
89+
combiner = faiss.CombinerRangeKNN(nq, k, float(r2), keep_max)
90+
if True:
91+
sp = faiss.swig_ptr
92+
combiner.I = sp(I)
93+
combiner.D = sp(D)
94+
# combiner.set_knn_result(sp(I), sp(D))
95+
if lim_remain is not None:
96+
combiner.mask = sp(mask)
97+
combiner.D_remain = sp(D_remain)
98+
combiner.lim_remain = sp(lim_remain.view("int64"))
99+
combiner.I_remain = sp(I_remain)
100+
# combiner.set_range_result(sp(mask), sp(lim_remain.view("int64")), sp(D_remain), sp(I_remain))
101+
L_res = np.empty(nq + 1, dtype='int64')
102+
combiner.compute_sizes(sp(L_res))
103+
nres = L_res[-1]
104+
D_res = np.empty(nres, dtype='float32')
105+
I_res = np.empty(nres, dtype='int64')
106+
combiner.write_result(sp(D_res), sp(I_res))
107+
else:
108+
D_res, I_res = [], []
109+
nr = 0
110+
for i in range(nq):
111+
if not mask[i]:
112+
if index_gpu.metric_type == faiss.METRIC_L2:
113+
nv = (D[i, :] < r2).sum()
114+
else:
115+
nv = (D[i, :] > r2).sum()
116+
D_res.append(D[i, :nv])
117+
I_res.append(I[i, :nv])
84118
else:
85-
nv = (D[i, :] > r2).sum()
86-
D_res.append(D[i, :nv])
87-
I_res.append(I[i, :nv])
88-
else:
89-
l0, l1 = lim_remain[nr], lim_remain[nr + 1]
90-
D_res.append(D_remain[l0:l1])
91-
I_res.append(I_remain[l0:l1])
92-
nr += 1
93-
lims = np.cumsum([0] + [len(di) for di in D_res])
94-
return lims, np.hstack(D_res), np.hstack(I_res)
119+
l0, l1 = lim_remain[nr], lim_remain[nr + 1]
120+
D_res.append(D_remain[l0:l1])
121+
I_res.append(I_remain[l0:l1])
122+
nr += 1
123+
L_res = np.cumsum([0] + [len(di) for di in D_res])
124+
D_res = np.hstack(D_res)
125+
I_res = np.hstack(I_res)
126+
t3 = time.time() - t0
127+
LOG.debug(f"times {t1:.3f}s {t2:.3f}s {t3:.3f}s")
128+
return L_res, D_res, I_res
95129

96130

97131
def range_ground_truth(xq, db_iterator, threshold, metric_type=faiss.METRIC_L2,

contrib/ivf_tools.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,7 @@ def range_search_preassigned(index_ivf, x, radius, list_nos, coarse_dis=None):
7777
res = faiss.RangeSearchResult(n)
7878
sp = faiss.swig_ptr
7979

80-
index_ivf.range_search_preassigned(
80+
index_ivf.range_search_preassigned_c(
8181
n, sp(x), radius,
8282
sp(list_nos), sp(coarse_dis),
8383
res

faiss/gpu/GpuCloner.cpp

+1
Original file line numberDiff line numberDiff line change
@@ -309,6 +309,7 @@ Index* ToGpuClonerMultiple::clone_Index_to_shards(const Index* index) {
309309

310310
std::vector<faiss::Index*> shards(n);
311311

312+
#pragma omp parallel for
312313
for (idx_t i = 0; i < n; i++) {
313314
// make a shallow copy
314315
if (reserveVecs) {

faiss/gpu/test/test_contrib_gpu.py

+48-2
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212

1313
from faiss.contrib import datasets, evaluation, big_batch_search
1414
from faiss.contrib.exhaustive_search import knn_ground_truth, \
15-
range_ground_truth
15+
range_ground_truth, range_search_gpu
1616

1717

1818
class TestComputeGT(unittest.TestCase):
@@ -51,7 +51,7 @@ def do_test_range(self, metric):
5151
xq, ds.database_iterator(bs=100), threshold,
5252
metric_type=metric)
5353

54-
evaluation.test_ref_range_results(
54+
evaluation.check_ref_range_results(
5555
ref_lims, ref_D, ref_I,
5656
new_lims, new_D, new_I
5757
)
@@ -131,3 +131,49 @@ def knn_function(xq, xb, k, metric=faiss.METRIC_L2, thread_id=None):
131131

132132
def test_Flat(self):
133133
self.do_test("IVF64,Flat")
134+
135+
136+
class TestRangeSearchGpu(unittest.TestCase):
137+
138+
def do_test(self, factory_string):
139+
ds = datasets.SyntheticDataset(32, 2000, 4000, 1000)
140+
k = 10
141+
index_gpu = faiss.index_cpu_to_all_gpus(
142+
faiss.index_factory(ds.d, factory_string)
143+
)
144+
index_gpu.train(ds.get_train())
145+
index_gpu.add(ds.get_database())
146+
# just to find a reasonable threshold
147+
D, _ = index_gpu.search(ds.get_queries(), k)
148+
threshold = np.median(D[:, 5])
149+
150+
# ref run
151+
index_cpu = faiss.index_gpu_to_cpu(index_gpu)
152+
Lref, Dref, Iref = index_cpu.range_search(ds.get_queries(), threshold)
153+
nres_per_query = Lref[1:] - Lref[:-1]
154+
# make sure some entries were computed by CPU and some by GPU
155+
assert np.any(nres_per_query > 4) and not np.all(nres_per_query > 4)
156+
157+
# mixed GPU / CPU run
158+
Lnew, Dnew, Inew = range_search_gpu(
159+
ds.get_queries(), threshold, index_gpu, index_cpu, gpu_k=4)
160+
evaluation.check_ref_range_results(
161+
Lref, Dref, Iref,
162+
Lnew, Dnew, Inew
163+
)
164+
165+
# also test the version without CPU search
166+
Lnew2, Dnew2, Inew2 = range_search_gpu(
167+
ds.get_queries(), threshold, index_gpu, None, gpu_k=4)
168+
for q in range(ds.nq):
169+
ref = Iref[Lref[q]:Lref[q+1]]
170+
new = Inew2[Lnew2[q]:Lnew2[q+1]]
171+
if nres_per_query[q] <= 4:
172+
self.assertEqual(set(ref), set(new))
173+
else:
174+
ref = set(ref)
175+
for v in new:
176+
self.assertIn(v, ref)
177+
178+
def test_ivf(self):
179+
self.do_test("IVF64,Flat")

faiss/python/__init__.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
from faiss.extra_wrappers import kmin, kmax, pairwise_distances, rand, randint, \
2323
lrand, randn, rand_smooth_vectors, eval_intersection, normalize_L2, \
2424
ResultHeap, knn, Kmeans, checksum, matrix_bucket_sort_inplace, bucket_sort, \
25-
merge_knn_results
25+
merge_knn_results, MapInt64ToInt64
2626

2727

2828
__version__ = "%d.%d.%d" % (FAISS_VERSION_MAJOR,

faiss/python/class_wrappers.py

+96-1
Original file line numberDiff line numberDiff line change
@@ -544,6 +544,7 @@ def replacement_range_search(self, x, thresh, *, params=None):
544544
n, d = x.shape
545545
assert d == self.d
546546
x = np.ascontiguousarray(x, dtype='float32')
547+
thresh = float(thresh)
547548

548549
res = RangeSearchResult(n)
549550
self.range_search_c(n, swig_ptr(x), thresh, res, params)
@@ -618,6 +619,64 @@ def replacement_search_preassigned(self, x, k, Iq, Dq, *, params=None, D=None, I
618619
)
619620
return D, I
620621

622+
def replacement_range_search_preassigned(self, x, thresh, Iq, Dq, *, params=None):
623+
"""Search vectors that are within a distance of the query vectors.
624+
625+
Parameters
626+
----------
627+
x : array_like
628+
Query vectors, shape (n, d) where d is appropriate for the index.
629+
`dtype` must be float32.
630+
thresh : float
631+
Threshold to select neighbors. All elements within this radius are returned,
632+
except for maximum inner product indexes, where the elements above the
633+
threshold are returned
634+
Iq : array_like, optional
635+
Nearest centroids, size (n, nprobe)
636+
Dq : array_like, optional
637+
Distance array to the centroids, size (n, nprobe)
638+
params : SearchParameters
639+
Search parameters of the current search (overrides the class-level params)
640+
641+
642+
Returns
643+
-------
644+
lims: array_like
645+
Starting index of the results for each query vector, size n+1.
646+
D : array_like
647+
Distances of the nearest neighbors, shape `lims[n]`. The distances for
648+
query i are in `D[lims[i]:lims[i+1]]`.
649+
I : array_like
650+
Labels of nearest neighbors, shape `lims[n]`. The labels for query i
651+
are in `I[lims[i]:lims[i+1]]`.
652+
653+
"""
654+
n, d = x.shape
655+
assert d == self.d
656+
x = np.ascontiguousarray(x, dtype='float32')
657+
658+
Iq = np.ascontiguousarray(Iq, dtype='int64')
659+
assert params is None, "params not supported"
660+
assert Iq.shape == (n, self.nprobe)
661+
662+
if Dq is not None:
663+
Dq = np.ascontiguousarray(Dq, dtype='float32')
664+
assert Dq.shape == Iq.shape
665+
666+
thresh = float(thresh)
667+
res = RangeSearchResult(n)
668+
self.range_search_preassigned_c(
669+
n, swig_ptr(x), thresh,
670+
swig_ptr(Iq), swig_ptr(Dq),
671+
res
672+
)
673+
# get pointers and copy them
674+
lims = rev_swig_ptr(res.lims, n + 1).copy()
675+
nd = int(lims[-1])
676+
D = rev_swig_ptr(res.distances, nd).copy()
677+
I = rev_swig_ptr(res.labels, nd).copy()
678+
return lims, D, I
679+
621680
def replacement_sa_encode(self, x, codes=None):
622681
n, d = x.shape
623682
assert d == self.d
@@ -675,8 +734,12 @@ def replacement_permute_entries(self, perm):
675734
ignore_missing=True)
676735
replace_method(the_class, 'search_and_reconstruct',
677736
replacement_search_and_reconstruct, ignore_missing=True)
737+
738+
# these ones are IVF-specific
678739
replace_method(the_class, 'search_preassigned',
679740
replacement_search_preassigned, ignore_missing=True)
741+
replace_method(the_class, 'range_search_preassigned',
742+
replacement_range_search_preassigned, ignore_missing=True)
680743
replace_method(the_class, 'sa_encode', replacement_sa_encode)
681744
replace_method(the_class, 'sa_decode', replacement_sa_decode)
682745
replace_method(the_class, 'add_sa_codes', replacement_add_sa_codes,
@@ -776,6 +839,36 @@ def replacement_range_search(self, x, thresh):
776839
I = rev_swig_ptr(res.labels, nd).copy()
777840
return lims, D, I
778841

842+
def replacement_range_search_preassigned(self, x, thresh, Iq, Dq, *, params=None):
843+
n, d = x.shape
844+
x = _check_dtype_uint8(x)
845+
assert d * 8 == self.d
846+
847+
Iq = np.ascontiguousarray(Iq, dtype='int64')
848+
assert params is None, "params not supported"
849+
assert Iq.shape == (n, self.nprobe)
850+
851+
if Dq is not None:
852+
Dq = np.ascontiguousarray(Dq, dtype='int32')
853+
assert Dq.shape == Iq.shape
854+
855+
thresh = int(thresh)
856+
res = RangeSearchResult(n)
857+
self.range_search_preassigned_c(
858+
n, swig_ptr(x), thresh,
859+
swig_ptr(Iq), swig_ptr(Dq),
860+
res
861+
)
862+
# get pointers and copy them
863+
lims = rev_swig_ptr(res.lims, n + 1).copy()
864+
nd = int(lims[-1])
865+
D = rev_swig_ptr(res.distances, nd).copy()
866+
I = rev_swig_ptr(res.labels, nd).copy()
867+
return lims, D, I
868+
869+
870+
871+
779872
def replacement_remove_ids(self, x):
780873
if isinstance(x, IDSelector):
781874
sel = x
@@ -794,6 +887,8 @@ def replacement_remove_ids(self, x):
794887
replace_method(the_class, 'remove_ids', replacement_remove_ids)
795888
replace_method(the_class, 'search_preassigned',
796889
replacement_search_preassigned, ignore_missing=True)
890+
replace_method(the_class, 'range_search_preassigned',
891+
replacement_range_search_preassigned, ignore_missing=True)
797892

798893

799894
def handle_VectorTransform(the_class):
@@ -937,7 +1032,7 @@ def handle_MapLong2Long(the_class):
9371032

9381033
def replacement_map_add(self, keys, vals):
9391034
n, = keys.shape
940-
assert (n,) == keys.shape
1035+
assert (n,) == vals.shape
9411036
self.add_c(n, swig_ptr(keys), swig_ptr(vals))
9421037

9431038
def replacement_map_search_multiple(self, keys):

0 commit comments

Comments
 (0)