@@ -249,7 +249,7 @@ def test_sa_encode_decode(self):
249
249
return
250
250
251
251
class TestTorchUtilsKnnGpu (unittest .TestCase ):
252
- def test_knn_gpu (self ):
252
+ def test_knn_gpu (self , use_raft = False ):
253
253
torch .manual_seed (10 )
254
254
d = 32
255
255
nb = 1024
@@ -286,7 +286,7 @@ def test_knn_gpu(self):
286
286
else :
287
287
xb_c = xb_np
288
288
289
- D , I = faiss .knn_gpu (res , xq_c , xb_c , k )
289
+ D , I = faiss .knn_gpu (res , xq_c , xb_c , k , use_raft = use_raft )
290
290
291
291
self .assertTrue (torch .equal (torch .from_numpy (I ), gt_I ))
292
292
self .assertLess ((torch .from_numpy (D ) - gt_D ).abs ().max (), 1e-4 )
@@ -312,15 +312,15 @@ def test_knn_gpu(self):
312
312
xb_c = to_column_major_torch (xb )
313
313
assert not xb_c .is_contiguous ()
314
314
315
- D , I = faiss .knn_gpu (res , xq_c , xb_c , k )
315
+ D , I = faiss .knn_gpu (res , xq_c , xb_c , k , use_raft = use_raft )
316
316
317
317
self .assertTrue (torch .equal (I .cpu (), gt_I ))
318
318
self .assertLess ((D .cpu () - gt_D ).abs ().max (), 1e-4 )
319
319
320
320
# test on subset
321
321
try :
322
322
# This internally uses the current pytorch stream
323
- D , I = faiss .knn_gpu (res , xq_c [6 :8 ], xb_c , k )
323
+ D , I = faiss .knn_gpu (res , xq_c [6 :8 ], xb_c , k , use_raft = use_raft )
324
324
except TypeError :
325
325
if not xq_row_major :
326
326
# then it is expected
@@ -331,7 +331,13 @@ def test_knn_gpu(self):
331
331
self .assertTrue (torch .equal (I .cpu (), gt_I [6 :8 ]))
332
332
self .assertLess ((D .cpu () - gt_D [6 :8 ]).abs ().max (), 1e-4 )
333
333
334
- def test_knn_gpu_datatypes (self ):
334
+ @unittest .skipUnless (
335
+ "RAFT" in faiss .get_compile_options (),
336
+ "only if RAFT is compiled in" )
337
+ def test_knn_gpu_raft (self ):
338
+ self .test_knn_gpu (use_raft = True )
339
+
340
+ def test_knn_gpu_datatypes (self , use_raft = False ):
335
341
torch .manual_seed (10 )
336
342
d = 10
337
343
nb = 1024
@@ -354,7 +360,7 @@ def test_knn_gpu_datatypes(self):
354
360
D = torch .zeros (nq , k , device = xb_c .device , dtype = torch .float32 )
355
361
I = torch .zeros (nq , k , device = xb_c .device , dtype = torch .int32 )
356
362
357
- faiss .knn_gpu (res , xq_c , xb_c , k , D , I )
363
+ faiss .knn_gpu (res , xq_c , xb_c , k , D , I , use_raft = use_raft )
358
364
359
365
self .assertTrue (torch .equal (I .long ().cpu (), gt_I ))
360
366
self .assertLess ((D .float ().cpu () - gt_D ).abs ().max (), 1.5e-3 )
@@ -366,7 +372,7 @@ def test_knn_gpu_datatypes(self):
366
372
xb_c = xb .half ().numpy ()
367
373
xq_c = xq .half ().numpy ()
368
374
369
- faiss .knn_gpu (res , xq_c , xb_c , k , D , I )
375
+ faiss .knn_gpu (res , xq_c , xb_c , k , D , I , use_raft = use_raft )
370
376
371
377
self .assertTrue (torch .equal (torch .from_numpy (I ).long (), gt_I ))
372
378
self .assertLess ((torch .from_numpy (D ) - gt_D ).abs ().max (), 1.5e-3 )
0 commit comments