
Commit ff15eef

mdouze authored and facebook-github-bot committed
support bfloat16 in python (facebookresearch#4037)
Summary: Add support in python
Differential Revision: D66074156
1 parent eaab46c commit ff15eef

File tree

1 file changed, +32 -12 lines changed


faiss/gpu/test/torch_test_contrib_gpu.py

+32 -12
@@ -345,13 +345,16 @@ def test_knn_gpu(self, use_cuvs=False):
     def test_knn_gpu_cuvs(self):
         self.test_knn_gpu(use_cuvs=True)
 
-    def test_knn_gpu_datatypes(self, use_cuvs=False):
+    def test_knn_gpu_datatypes(self, use_cuvs=False, use_bf16=False):
         torch.manual_seed(10)
         d = 10
         nb = 1024
-        nq = 5
+        nq = 50
         k = 10
         res = faiss.StandardGpuResources()
+        if use_bf16 and not res.supportsBFloat16CurrentDevice():
+            print("WARNING bfloat16 not supported -- test not executed")
+            return
 
         # make GT on torch cpu and test using IndexFlatL2
         xb = torch.rand(nb, d, dtype=torch.float32)
@@ -361,29 +364,46 @@ def test_knn_gpu_datatypes(self, use_cuvs=False):
         index.add(xb)
         gt_D, gt_I = index.search(xq, k)
 
-        xb_c = xb.cuda().half()
-        xq_c = xq.cuda().half()
+        # convert to float16 or bfloat16
+        if use_bf16:
+            xb_c = xb.cuda().bfloat16()
+            xq_c = xq.cuda().bfloat16()
+        else:
+            xb_c = xb.cuda().half()
+            xq_c = xq.cuda().half()
 
         # use i32 output indices
         D = torch.zeros(nq, k, device=xb_c.device, dtype=torch.float32)
         I = torch.zeros(nq, k, device=xb_c.device, dtype=torch.int32)
 
         faiss.knn_gpu(res, xq_c, xb_c, k, D, I, use_cuvs=use_cuvs)
 
-        self.assertTrue(torch.equal(I.long().cpu(), gt_I))
-        self.assertLess((D.float().cpu() - gt_D).abs().max(), 1.5e-3)
+        ndiff = (I.cpu() != gt_I).sum().item()
+        MSE = ((D.float().cpu() - gt_D) ** 2).sum().item()
+        if use_bf16:
+            # observed ndiff 57 -- bf16 is not as accurate as fp16
+            self.assertLess(ndiff, 80)
+            # observed MSE 0.00515
+            self.assertLess(MSE, 8e-3)
+        else:
+            # observed ndiff 5
+            self.assertLess(ndiff, 10)
+            # observed MSE 8.565e-5
+            self.assertLess(MSE, 1e-4)
 
         # Test using numpy
-        D = np.zeros((nq, k), dtype=np.float32)
-        I = np.zeros((nq, k), dtype=np.int32)
+        if not use_bf16: # bf16 not supported by numpy
+            # use i32 output indices
+            D = np.zeros((nq, k), dtype=np.float32)
+            I = np.zeros((nq, k), dtype=np.int32)
 
-        xb_c = xb.half().numpy()
-        xq_c = xq.half().numpy()
+            xb_c = xb.half().numpy()
+            xq_c = xq.half().numpy()
 
         faiss.knn_gpu(res, xq_c, xb_c, k, D, I, use_cuvs=use_cuvs)
 
-        self.assertTrue(torch.equal(torch.from_numpy(I).long(), gt_I))
-        self.assertLess((torch.from_numpy(D) - gt_D).abs().max(), 1.5e-3)
+    def test_knn_gpu_bf16(self):
+        self.test_knn_gpu_datatypes(use_bf16=True)
 
 
 class TestTorchUtilsPairwiseDistanceGpu(unittest.TestCase):
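
For context, a minimal usage sketch of the bfloat16 path this commit adds, built only from the calls the test exercises (StandardGpuResources, supportsBFloat16CurrentDevice, knn_gpu). It assumes a CUDA-enabled faiss build and that the torch contrib bridge is imported, as torch_test_contrib_gpu.py does; this is an illustrative sketch, not part of the commit.

import torch
import faiss
import faiss.contrib.torch_utils  # lets knn_gpu accept torch GPU tensors

res = faiss.StandardGpuResources()
if res.supportsBFloat16CurrentDevice():
    d, nb, nq, k = 10, 1024, 50, 10
    # database and query vectors, converted to bfloat16 on the GPU
    xb = torch.rand(nb, d, dtype=torch.float32).cuda().bfloat16()
    xq = torch.rand(nq, d, dtype=torch.float32).cuda().bfloat16()
    # preallocated outputs: float32 distances, int32 indices
    D = torch.zeros(nq, k, device=xq.device, dtype=torch.float32)
    I = torch.zeros(nq, k, device=xq.device, dtype=torch.int32)
    faiss.knn_gpu(res, xq, xb, k, D, I)
else:
    print("bfloat16 not supported on this device")

As the diff's comment notes, the numpy input path is not covered for bf16, since numpy has no native bfloat16 dtype; bfloat16 inputs go through torch GPU tensors.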
