Skip to content

Commit 905963f

Browse files
Di-Isfacebook-github-bot
authored andcommitted
Add ngpu default argument to knn_ground_truth (facebookresearch#4123)
Summary: This pull request introduces a new default argument, `ngpu=-1`, to the `knn_ground_truth` function in the `faiss.contrib`. ## Purpose of Change ### Bug Fix In the current implementation, running tests under the tests directory (CPU tests) in an environment with faiss-gpu installed would inadvertently use the GPU and cause unintended behavior. This pull request prevents the GPU from being used during CPU-only tests by explicitly controlling GPU allocation via the ngpu parameter. ### API Consistency Other functions that call `faiss.get_num_gpus` in `faiss.contrib`, such as `range_search_max_results` and `range_ground_truth`, already include the `ngpu` argument. Adding this parameter to `knn_ground_truth` will ensure consistency across the API, reduce potential confusion, and improve ease of use. Pull Request resolved: facebookresearch#4123 Reviewed By: asadoughi Differential Revision: D68199506 Pulled By: junjieqi fbshipit-source-id: cb50e206d8a1a982c21b0ccb42825ea45873f3ef
1 parent 4c315a9 commit 905963f

File tree

2 files changed

+10
-5
lines changed

2 files changed

+10
-5
lines changed

contrib/exhaustive_search.py

+9-4
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111

1212
LOG = logging.getLogger(__name__)
1313

14-
def knn_ground_truth(xq, db_iterator, k, metric_type=faiss.METRIC_L2):
14+
def knn_ground_truth(xq, db_iterator, k, metric_type=faiss.METRIC_L2, shard=False, ngpu=-1):
1515
"""Computes the exact KNN search results for a dataset that possibly
1616
does not fit in RAM but for which we have an iterator that
1717
returns it block by block.
@@ -23,9 +23,14 @@ def knn_ground_truth(xq, db_iterator, k, metric_type=faiss.METRIC_L2):
2323
rh = faiss.ResultHeap(nq, k, keep_max=keep_max)
2424

2525
index = faiss.IndexFlat(d, metric_type)
26-
if faiss.get_num_gpus():
27-
LOG.info('running on %d GPUs' % faiss.get_num_gpus())
28-
index = faiss.index_cpu_to_all_gpus(index)
26+
if ngpu == -1:
27+
ngpu = faiss.get_num_gpus()
28+
29+
if ngpu:
30+
LOG.info('running on %d GPUs' % ngpu)
31+
co = faiss.GpuMultipleClonerOptions()
32+
co.shard = shard
33+
index = faiss.index_cpu_to_all_gpus(index, co=co, ngpu=ngpu)
2934

3035
# compute ground-truth by blocks, and add to heaps
3136
i0 = 0

tests/test_contrib.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@ def matrix_iterator(xb, bs):
5050
yield xb[i0:i0 + bs]
5151

5252
Dnew, Inew = knn_ground_truth(
53-
xq, matrix_iterator(xb, 1000), 10, metric)
53+
xq, matrix_iterator(xb, 1000), 10, metric, ngpu=0)
5454

5555
np.testing.assert_array_equal(Iref, Inew)
5656
# decimal = 4 required when run on GPU

0 commit comments

Comments
 (0)