Skip to content

Commit 959cd49

Browse files
algoriddletarang-jain
authored andcommittedJun 25, 2024
add use_raft to knn_gpu (torch) (facebookresearch#3509)
Summary: Add support for `use_raft` in the torch version of `knn_gpu`. The numpy version already has this support, see https://github.com/facebookresearch/faiss/blob/main/faiss/python/gpu_wrappers.py#L59 Pull Request resolved: facebookresearch#3509 Reviewed By: mlomeli1, junjieqi Differential Revision: D58489851 Pulled By: algoriddle fbshipit-source-id: cfad722fefd4809b135b765d0d43587cfd782d0e
1 parent 2ce4b28 commit 959cd49

File tree

2 files changed

+16
-8
lines changed

2 files changed

+16
-8
lines changed
 

‎contrib/torch_utils.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -492,8 +492,9 @@ def torch_replacement_sa_decode(self, codes, x=None):
492492
if issubclass(the_class, faiss.Index):
493493
handle_torch_Index(the_class)
494494

495+
495496
# allows torch tensor usage with bfKnn
496-
def torch_replacement_knn_gpu(res, xq, xb, k, D=None, I=None, metric=faiss.METRIC_L2, device=-1):
497+
def torch_replacement_knn_gpu(res, xq, xb, k, D=None, I=None, metric=faiss.METRIC_L2, device=-1, use_raft=False):
497498
if type(xb) is np.ndarray:
498499
# Forward to faiss __init__.py base method
499500
return faiss.knn_gpu_numpy(res, xq, xb, k, D, I, metric, device)
@@ -574,6 +575,7 @@ def torch_replacement_knn_gpu(res, xq, xb, k, D=None, I=None, metric=faiss.METRI
574575
args.outIndices = I_ptr
575576
args.outIndicesType = I_type
576577
args.device = device
578+
args.use_raft = use_raft
577579

578580
with using_stream(res):
579581
faiss.bfKnn(res, args)

‎faiss/gpu/test/torch_test_contrib_gpu.py

+13-7
Original file line numberDiff line numberDiff line change
@@ -249,7 +249,7 @@ def test_sa_encode_decode(self):
249249
return
250250

251251
class TestTorchUtilsKnnGpu(unittest.TestCase):
252-
def test_knn_gpu(self):
252+
def test_knn_gpu(self, use_raft=False):
253253
torch.manual_seed(10)
254254
d = 32
255255
nb = 1024
@@ -286,7 +286,7 @@ def test_knn_gpu(self):
286286
else:
287287
xb_c = xb_np
288288

289-
D, I = faiss.knn_gpu(res, xq_c, xb_c, k)
289+
D, I = faiss.knn_gpu(res, xq_c, xb_c, k, use_raft=use_raft)
290290

291291
self.assertTrue(torch.equal(torch.from_numpy(I), gt_I))
292292
self.assertLess((torch.from_numpy(D) - gt_D).abs().max(), 1e-4)
@@ -312,15 +312,15 @@ def test_knn_gpu(self):
312312
xb_c = to_column_major_torch(xb)
313313
assert not xb_c.is_contiguous()
314314

315-
D, I = faiss.knn_gpu(res, xq_c, xb_c, k)
315+
D, I = faiss.knn_gpu(res, xq_c, xb_c, k, use_raft=use_raft)
316316

317317
self.assertTrue(torch.equal(I.cpu(), gt_I))
318318
self.assertLess((D.cpu() - gt_D).abs().max(), 1e-4)
319319

320320
# test on subset
321321
try:
322322
# This internally uses the current pytorch stream
323-
D, I = faiss.knn_gpu(res, xq_c[6:8], xb_c, k)
323+
D, I = faiss.knn_gpu(res, xq_c[6:8], xb_c, k, use_raft=use_raft)
324324
except TypeError:
325325
if not xq_row_major:
326326
# then it is expected
@@ -331,7 +331,13 @@ def test_knn_gpu(self):
331331
self.assertTrue(torch.equal(I.cpu(), gt_I[6:8]))
332332
self.assertLess((D.cpu() - gt_D[6:8]).abs().max(), 1e-4)
333333

334-
def test_knn_gpu_datatypes(self):
334+
@unittest.skipUnless(
335+
"RAFT" in faiss.get_compile_options(),
336+
"only if RAFT is compiled in")
337+
def test_knn_gpu_raft(self):
338+
self.test_knn_gpu(use_raft=True)
339+
340+
def test_knn_gpu_datatypes(self, use_raft=False):
335341
torch.manual_seed(10)
336342
d = 10
337343
nb = 1024
@@ -354,7 +360,7 @@ def test_knn_gpu_datatypes(self):
354360
D = torch.zeros(nq, k, device=xb_c.device, dtype=torch.float32)
355361
I = torch.zeros(nq, k, device=xb_c.device, dtype=torch.int32)
356362

357-
faiss.knn_gpu(res, xq_c, xb_c, k, D, I)
363+
faiss.knn_gpu(res, xq_c, xb_c, k, D, I, use_raft=use_raft)
358364

359365
self.assertTrue(torch.equal(I.long().cpu(), gt_I))
360366
self.assertLess((D.float().cpu() - gt_D).abs().max(), 1.5e-3)
@@ -366,7 +372,7 @@ def test_knn_gpu_datatypes(self):
366372
xb_c = xb.half().numpy()
367373
xq_c = xq.half().numpy()
368374

369-
faiss.knn_gpu(res, xq_c, xb_c, k, D, I)
375+
faiss.knn_gpu(res, xq_c, xb_c, k, D, I, use_raft=use_raft)
370376

371377
self.assertTrue(torch.equal(torch.from_numpy(I).long(), gt_I))
372378
self.assertLess((torch.from_numpy(D) - gt_D).abs().max(), 1.5e-3)

0 commit comments

Comments
 (0)
Please sign in to comment.