Skip to content

Commit 71026d3

Browse files
mdouzefacebook-github-bot
authored andcommitted
begin torch_contrib
Summary: The contrib.torch subdirectory is intended to receive modules in python that are useful for similarity search and that apply to CPU or GPU pytorch tensors. The current version includes CPU clustering on torch tensors. To be added: * implementation of PQ Differential Revision: D62759207
1 parent 281c604 commit 71026d3

9 files changed

+282
-10
lines changed

contrib/clustering.py

+37-8
Original file line numberDiff line numberDiff line change
@@ -285,25 +285,40 @@ def imbalance_factor(k, assign):
285285
return faiss.imbalance_factor(len(assign), k, faiss.swig_ptr(assign))
286286

287287

288+
def check_if_torch(x):
289+
if x.__class__ == np.ndarray:
290+
return False
291+
import torch
292+
if isinstance(x, torch.Tensor):
293+
return True
294+
raise NotImplementedError(f"Unknown tensor type {type(x)}")
295+
296+
288297
def reassign_centroids(hassign, centroids, rs=None):
289298
""" reassign centroids when some of them collapse """
290299
if rs is None:
291300
rs = np.random
292301
k, d = centroids.shape
293302
nsplit = 0
303+
is_torch = check_if_torch(centroids)
304+
294305
empty_cents = np.where(hassign == 0)[0]
295306

296-
if empty_cents.size == 0:
307+
if len(empty_cents) == 0:
297308
return 0
298309

299-
fac = np.ones(d)
310+
if is_torch:
311+
import torch
312+
fac = torch.ones_like(centroids[0])
313+
else:
314+
fac = np.ones_like(centroids[0])
300315
fac[::2] += 1 / 1024.
301316
fac[1::2] -= 1 / 1024.
302317

303318
# this is a single pass unless there are more than k/2
304319
# empty centroids
305-
while empty_cents.size > 0:
306-
# choose which centroids to split
320+
while len(empty_cents) > 0:
321+
# choose which centroids to split (numpy)
307322
probas = hassign.astype('float') - 1
308323
probas[probas < 0] = 0
309324
probas /= probas.sum()
@@ -327,13 +342,17 @@ def reassign_centroids(hassign, centroids, rs=None):
327342
return nsplit
328343

329344

345+
330346
def kmeans(k, data, niter=25, seed=1234, checkpoint=None, verbose=True,
331347
return_stats=False):
332348
"""Pure python kmeans implementation. Follows the Faiss C++ version
333349
quite closely, but takes a DatasetAssign instead of a training data
334-
matrix. Also redo is not implemented. """
350+
matrix. Also redo is not implemented.
351+
352+
For the torch implementation, the centroids are tensors (possibly on GPU),
353+
but the indices remain numpy on CPU.
354+
"""
335355
n, d = data.count(), data.dim()
336-
337356
log = print if verbose else print_nop
338357

339358
log(("Clustering %d points in %dD to %d clusters, " +
@@ -345,6 +364,7 @@ def kmeans(k, data, niter=25, seed=1234, checkpoint=None, verbose=True,
345364
# initialization
346365
perm = rs.choice(n, size=k, replace=False)
347366
centroids = data.get_subset(perm)
367+
is_torch = check_if_torch(centroids)
348368

349369
iteration_stats = []
350370

@@ -362,12 +382,17 @@ def kmeans(k, data, niter=25, seed=1234, checkpoint=None, verbose=True,
362382
t_search_tot += time.time() - t0s;
363383

364384
err = D.sum()
385+
if is_torch:
386+
err = err.item()
365387
obj.append(err)
366388

367389
hassign = np.bincount(assign, minlength=k)
368390

369391
fac = hassign.reshape(-1, 1).astype('float32')
370-
fac[fac == 0] = 1 # quiet warning
392+
fac[fac == 0] = 1 # quiet warning
393+
if is_torch:
394+
import torch
395+
fac = torch.from_numpy(fac).to(sums.device)
371396

372397
centroids = sums / fac
373398

@@ -391,7 +416,11 @@ def kmeans(k, data, niter=25, seed=1234, checkpoint=None, verbose=True,
391416

392417
if checkpoint is not None:
393418
log('storing centroids in', checkpoint)
394-
np.save(checkpoint, centroids)
419+
if is_torch:
420+
import torch
421+
torch.save(centroids, checkpoint)
422+
else:
423+
np.save(checkpoint, centroids)
395424

396425
if return_stats:
397426
return centroids, iteration_stats

contrib/torch/README.md

+6
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
# The Torch contrib
2+
3+
This contrib directory contains a few Pytorch routines that
4+
are useful for similarity search. They do not necessarily depend on Faiss.
5+
6+
The code is designed to work with CPU and GPU tensors.

contrib/torch/__init__.py

Whitespace-only changes.

contrib/torch/clustering.py

+60
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,60 @@
1+
# Copyright (c) Facebook, Inc. and its affiliates.
2+
#
3+
# This source code is licensed under the MIT license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
6+
"""
7+
This contrib module contains Pytorch code for k-means clustering
8+
"""
9+
import faiss
10+
import faiss.contrib.torch_utils
11+
import torch
12+
13+
# the kmeans can produce both torch and numpy centroids
14+
from faiss.contrib.clustering import DatasetAssign, kmeans
15+
16+
17+
class DatasetAssign:
18+
"""Wrapper for a tensor that offers a function to assign the vectors
19+
to centroids. All other implementations offer the same interface"""
20+
21+
def __init__(self, x):
22+
self.x = x
23+
24+
def count(self):
25+
return self.x.shape[0]
26+
27+
def dim(self):
28+
return self.x.shape[1]
29+
30+
def get_subset(self, indices):
31+
return self.x[indices]
32+
33+
def perform_search(self, centroids):
34+
return faiss.knn(self.x, centroids, 1)
35+
36+
def assign_to(self, centroids, weights=None):
37+
D, I = self.perform_search(centroids)
38+
39+
I = I.ravel()
40+
D = D.ravel()
41+
nc, d = centroids.shape
42+
43+
sum_per_centroid = torch.zeros_like(centroids)
44+
if weights is None:
45+
sum_per_centroid.index_add_(0, I, self.x)
46+
else:
47+
sum_per_centroid.index_add_(0, I, self.x * weights[:, None])
48+
49+
# the indices are still in numpy.
50+
return I.cpu().numpy(), D, sum_per_centroid
51+
52+
53+
class DatasetAssignGPU(DatasetAssign):
54+
55+
def __init__(self, res, x):
56+
DatasetAssign.__init__(self, x)
57+
self.res = res
58+
59+
def perform_search(self, centroids):
60+
return faiss.knn_gpu(self.res, self.x, centroids, 1)

contrib/torch/quantization.py

+53
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
# Copyright (c) Facebook, Inc. and its affiliates.
2+
#
3+
# This source code is licensed under the MIT license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
6+
"""
7+
This contrib module contains Pytorch code for quantization.
8+
"""
9+
10+
import numpy as np
11+
import torch
12+
import faiss
13+
14+
from faiss.contrib import torch_utils
15+
16+
17+
class Quantizer:
18+
19+
def __init__(self, d, code_size):
20+
self.d = d
21+
self.code_size = code_size
22+
23+
def train(self, x):
24+
pass
25+
26+
def encode(self, x):
27+
pass
28+
29+
def decode(self, x):
30+
pass
31+
32+
33+
class VectorQuantizer(Quantizer):
34+
35+
def __init__(self, d, k):
36+
code_size = int(torch.ceil(torch.log2(k) / 8))
37+
Quantizer.__init__(d, code_size)
38+
self.k = k
39+
40+
def train(self, x):
41+
pass
42+
43+
44+
class ProductQuantizer(Quantizer):
45+
46+
def __init__(self, d, M, nbits):
47+
code_size = int(torch.ceil(M * nbits / 8))
48+
Quantizer.__init__(d, code_size)
49+
self.M = M
50+
self.nbits = nbits
51+
52+
def train(self, x):
53+
pass

contrib/torch_utils.py

+62
Original file line numberDiff line numberDiff line change
@@ -28,13 +28,18 @@
2828
import sys
2929
import numpy as np
3030

31+
##################################################################
32+
# Equivalent of swig_ptr for Torch tensors
33+
##################################################################
34+
3135
def swig_ptr_from_UInt8Tensor(x):
3236
""" gets a Faiss SWIG pointer from a pytorch tensor (on CPU or GPU) """
3337
assert x.is_contiguous()
3438
assert x.dtype == torch.uint8
3539
return faiss.cast_integer_to_uint8_ptr(
3640
x.untyped_storage().data_ptr() + x.storage_offset())
3741

42+
3843
def swig_ptr_from_HalfTensor(x):
3944
""" gets a Faiss SWIG pointer from a pytorch tensor (on CPU or GPU) """
4045
assert x.is_contiguous()
@@ -43,27 +48,34 @@ def swig_ptr_from_HalfTensor(x):
4348
return faiss.cast_integer_to_void_ptr(
4449
x.untyped_storage().data_ptr() + x.storage_offset() * 2)
4550

51+
4652
def swig_ptr_from_FloatTensor(x):
4753
""" gets a Faiss SWIG pointer from a pytorch tensor (on CPU or GPU) """
4854
assert x.is_contiguous()
4955
assert x.dtype == torch.float32
5056
return faiss.cast_integer_to_float_ptr(
5157
x.untyped_storage().data_ptr() + x.storage_offset() * 4)
5258

59+
5360
def swig_ptr_from_IntTensor(x):
5461
""" gets a Faiss SWIG pointer from a pytorch tensor (on CPU or GPU) """
5562
assert x.is_contiguous()
5663
assert x.dtype == torch.int32, 'dtype=%s' % x.dtype
5764
return faiss.cast_integer_to_int_ptr(
5865
x.untyped_storage().data_ptr() + x.storage_offset() * 4)
5966

67+
6068
def swig_ptr_from_IndicesTensor(x):
6169
""" gets a Faiss SWIG pointer from a pytorch tensor (on CPU or GPU) """
6270
assert x.is_contiguous()
6371
assert x.dtype == torch.int64, 'dtype=%s' % x.dtype
6472
return faiss.cast_integer_to_idx_t_ptr(
6573
x.untyped_storage().data_ptr() + x.storage_offset() * 8)
6674

75+
##################################################################
76+
# utilities
77+
##################################################################
78+
6779
@contextlib.contextmanager
6880
def using_stream(res, pytorch_stream=None):
6981
""" Creates a scoping object to make Faiss GPU use the same stream
@@ -107,6 +119,10 @@ def torch_replace_method(the_class, name, replacement,
107119
setattr(the_class, name + '_numpy', orig_method)
108120
setattr(the_class, name, replacement)
109121

122+
##################################################################
123+
# Setup wrappers
124+
##################################################################
125+
110126
def handle_torch_Index(the_class):
111127
def torch_replacement_add(self, x):
112128
if type(x) is np.ndarray:
@@ -493,6 +509,52 @@ def torch_replacement_sa_decode(self, codes, x=None):
493509
handle_torch_Index(the_class)
494510

495511

512+
# allows torch tensor usage with knn
513+
def torch_replacement_knn(xq, xb, k, metric=faiss.METRIC_L2, metric_arg=0):
514+
if type(xb) is np.ndarray:
515+
# Forward to faiss __init__.py base method
516+
return faiss.knn_numpy(xq, xb, k, metric=metric, metric_arg=metric_arg)
517+
518+
nb, d = xb.size()
519+
assert xb.is_contiguous()
520+
assert xb.dtype == torch.float32
521+
assert not xb.is_cuda, "use knn_gpu for GPU tensors"
522+
523+
nq, d2 = xq.size()
524+
assert d2 == d
525+
assert xq.is_contiguous()
526+
assert xq.dtype == torch.float32
527+
assert not xq.is_cuda, "use knn_gpu for GPU tensors"
528+
529+
D = torch.empty(nq, k, device=xb.device, dtype=torch.float32)
530+
I = torch.empty(nq, k, device=xb.device, dtype=torch.int64)
531+
I_ptr = swig_ptr_from_IndicesTensor(I)
532+
D_ptr = swig_ptr_from_FloatTensor(D)
533+
xb_ptr = swig_ptr_from_FloatTensor(xb)
534+
xq_ptr = swig_ptr_from_FloatTensor(xq)
535+
536+
if metric == faiss.METRIC_L2:
537+
faiss.knn_L2sqr(
538+
xq_ptr, xb_ptr,
539+
d, nq, nb, k, D_ptr, I_ptr
540+
)
541+
elif metric == faiss.METRIC_INNER_PRODUCT:
542+
faiss.knn_inner_product(
543+
xq_ptr, xb_ptr,
544+
d, nq, nb, k, D_ptr, I_ptr
545+
)
546+
else:
547+
faiss.knn_extra_metrics(
548+
xq_ptr, xb_ptr,
549+
d, nq, nb, metric, metric_arg, k, D_ptr, I_ptr
550+
)
551+
552+
return D, I
553+
554+
555+
torch_replace_method(faiss_module, 'knn', torch_replacement_knn, True, True)
556+
557+
496558
# allows torch tensor usage with bfKnn
497559
def torch_replacement_knn_gpu(res, xq, xb, k, D=None, I=None, metric=faiss.METRIC_L2, device=-1, use_raft=False):
498560
if type(xb) is np.ndarray:

faiss/gpu/test/torch_test_contrib_gpu.py

+33
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,10 @@
99
import faiss
1010
import faiss.contrib.torch_utils
1111

12+
from faiss.contrib import datasets
13+
from faiss.contrib.torch import clustering
14+
15+
1216
def to_column_major_torch(x):
1317
if hasattr(torch, 'contiguous_format'):
1418
return x.t().clone(memory_format=torch.contiguous_format).t()
@@ -377,6 +381,7 @@ def test_knn_gpu_datatypes(self, use_raft=False):
377381
self.assertTrue(torch.equal(torch.from_numpy(I).long(), gt_I))
378382
self.assertLess((torch.from_numpy(D) - gt_D).abs().max(), 1.5e-3)
379383

384+
380385
class TestTorchUtilsPairwiseDistanceGpu(unittest.TestCase):
381386
def test_pairwise_distance_gpu(self):
382387
torch.manual_seed(10)
@@ -470,3 +475,31 @@ def test_pairwise_distance_gpu(self):
470475
D, _ = torch.sort(D, dim=1)
471476

472477
self.assertLess((D.cpu() - gt_D[4:8]).abs().max(), 1e-4)
478+
479+
480+
class TestClustering(unittest.TestCase):
481+
482+
def test_python_kmeans(self):
483+
""" Test the python implementation of kmeans """
484+
ds = datasets.SyntheticDataset(32, 10000, 0, 0)
485+
x = ds.get_train()
486+
487+
# bad distribution to stress-test split code
488+
xt = x[:10000].copy()
489+
xt[:5000] = x[0]
490+
491+
# CPU baseline
492+
km_ref = faiss.Kmeans(ds.d, 100, niter=10)
493+
km_ref.train(xt)
494+
err = faiss.knn(xt, km_ref.centroids, 1)[0].sum()
495+
496+
xt_torch = torch.from_numpy(xt).to("cuda:0")
497+
res = faiss.StandardGpuResources()
498+
data = clustering.DatasetAssignGPU(res, xt_torch)
499+
centroids = clustering.kmeans(100, data, 10)
500+
centroids = centroids.cpu().numpy()
501+
err2 = faiss.knn(xt, centroids, 1)[0].sum()
502+
503+
# 33498.332 33380.477
504+
print(err, err2)
505+
self.assertLess(err2, err * 1.1)

0 commit comments

Comments
 (0)