@@ -345,13 +345,16 @@ def test_knn_gpu(self, use_cuvs=False):
     def test_knn_gpu_cuvs(self):
         self.test_knn_gpu(use_cuvs=True)
 
-    def test_knn_gpu_datatypes(self, use_cuvs=False):
+    def test_knn_gpu_datatypes(self, use_cuvs=False, use_bf16=False):
         torch.manual_seed(10)
         d = 10
         nb = 1024
-        nq = 5
+        nq = 50
         k = 10
         res = faiss.StandardGpuResources()
+        if use_bf16 and not res.supportsBFloat16CurrentDevice():
+            print("WARNING bfloat16 not supported -- test not executed")
+            return
 
         # make GT on torch cpu and test using IndexFlatL2
         xb = torch.rand(nb, d, dtype=torch.float32)
@@ -361,29 +364,46 @@ def test_knn_gpu_datatypes(self, use_cuvs=False):
         index.add(xb)
         gt_D, gt_I = index.search(xq, k)
 
-        xb_c = xb.cuda().half()
-        xq_c = xq.cuda().half()
+        # convert to float16 or bfloat16
+        if use_bf16:
+            xb_c = xb.cuda().bfloat16()
+            xq_c = xq.cuda().bfloat16()
+        else:
+            xb_c = xb.cuda().half()
+            xq_c = xq.cuda().half()
 
         # use i32 output indices
         D = torch.zeros(nq, k, device=xb_c.device, dtype=torch.float32)
         I = torch.zeros(nq, k, device=xb_c.device, dtype=torch.int32)
 
         faiss.knn_gpu(res, xq_c, xb_c, k, D, I, use_cuvs=use_cuvs)
 
-        self.assertTrue(torch.equal(I.long().cpu(), gt_I))
-        self.assertLess((D.float().cpu() - gt_D).abs().max(), 1.5e-3)
+        ndiff = (I.cpu() != gt_I).sum().item()
+        MSE = ((D.float().cpu() - gt_D) ** 2).sum().item()
+        if use_bf16:
+            # observed ndiff: 57 -- bf16 is not as accurate as fp16
+            self.assertLess(ndiff, 80)
+            # observed MSE: 0.00515
+            self.assertLess(MSE, 8e-3)
+        else:
+            # observed ndiff: 5
+            self.assertLess(ndiff, 10)
+            # observed MSE: 8.565e-5
+            self.assertLess(MSE, 1e-4)
 
         # Test using numpy
-        D = np.zeros((nq, k), dtype=np.float32)
-        I = np.zeros((nq, k), dtype=np.int32)
+        if not use_bf16:  # bf16 not supported by numpy
+            # use i32 output indices
+            D = np.zeros((nq, k), dtype=np.float32)
+            I = np.zeros((nq, k), dtype=np.int32)
 
-        xb_c = xb.half().numpy()
-        xq_c = xq.half().numpy()
+            xb_c = xb.half().numpy()
+            xq_c = xq.half().numpy()
 
-        faiss.knn_gpu(res, xq_c, xb_c, k, D, I, use_cuvs=use_cuvs)
+            faiss.knn_gpu(res, xq_c, xb_c, k, D, I, use_cuvs=use_cuvs)
 
-        self.assertTrue(torch.equal(torch.from_numpy(I).long(), gt_I))
-        self.assertLess((torch.from_numpy(D) - gt_D).abs().max(), 1.5e-3)
+    def test_knn_gpu_bf16(self):
+        self.test_knn_gpu_datatypes(use_bf16=True)
 
 
 class TestTorchUtilsPairwiseDistanceGpu(unittest.TestCase):
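A minimal standalone sketch of the pattern this commit's test exercises -- not part of the diff, and assuming a faiss build with GPU support and a CUDA device; `faiss.knn_gpu`, `StandardGpuResources`, and `supportsBFloat16CurrentDevice` are the names used in the diff above:

import faiss
import torch

d, nb, nq, k = 10, 1024, 50, 10
res = faiss.StandardGpuResources()

xb = torch.rand(nb, d, dtype=torch.float32)  # database vectors
xq = torch.rand(nq, d, dtype=torch.float32)  # query vectors

# bfloat16 inputs are only usable when the current device supports them;
# fall back to float16 otherwise, as the test does
if res.supportsBFloat16CurrentDevice():
    xb_c, xq_c = xb.cuda().bfloat16(), xq.cuda().bfloat16()
else:
    xb_c, xq_c = xb.cuda().half(), xq.cuda().half()

# knn_gpu fills preallocated distance/index tensors on the same device
D = torch.zeros(nq, k, device=xb_c.device, dtype=torch.float32)
I = torch.zeros(nq, k, device=xb_c.device, dtype=torch.int32)
faiss.knn_gpu(res, xq_c, xb_c, k, D, I)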