diff --git a/contrib/clustering.py b/contrib/clustering.py index e84a7e63f6..19c2656dc1 100644 --- a/contrib/clustering.py +++ b/contrib/clustering.py @@ -151,14 +151,12 @@ def assign_to(self, centroids, weights=None): I = I.ravel() D = D.ravel() - n = len(self.x) + nc, d = centroids.shape + sum_per_centroid = np.zeros((nc, d), dtype='float32') if weights is None: - weights = np.ones(n, dtype='float32') - nc = len(centroids) - m = scipy.sparse.csc_matrix( - (weights, I, np.arange(n + 1)), - shape=(nc, n)) - sum_per_centroid = m * self.x + np.add.at(sum_per_centroid, I, self.x) + else: + np.add.at(sum_per_centroid, I, weights[:, np.newaxis] * self.x) return I, D, sum_per_centroid @@ -185,7 +183,8 @@ def perform_search(self, centroids): def sparse_assign_to_dense(xq, xb, xq_norms=None, xb_norms=None): """ assignment function for xq is sparse, xb is dense - uses a matrix multiplication. The squared norms can be provided if available. + uses a matrix multiplication. The squared norms can be provided if + available. 
def check_if_torch(x):
    """Tell whether `x` is a torch tensor rather than a numpy array.

    Returns False when `x` is a numpy array (ndarray subclasses such as
    np.memmap or masked arrays included) and True when it is a
    torch.Tensor. Raises NotImplementedError for any other type.

    torch is imported lazily, only when `x` is not a numpy array, so the
    numpy code path does not require torch.
    """
    # isinstance rather than an exact class comparison: an ndarray subclass
    # must take the numpy path instead of being reported as an unknown type
    if isinstance(x, np.ndarray):
        return False
    import torch
    if isinstance(x, torch.Tensor):
        return True
    raise NotImplementedError(f"Unknown tensor type {type(x)}")
+ """ + n, d = data.count(), data.dim() log = print if verbose else print_nop log(("Clustering %d points in %dD to %d clusters, " + @@ -345,6 +364,7 @@ def kmeans(k, data, niter=25, seed=1234, checkpoint=None, verbose=True, # initialization perm = rs.choice(n, size=k, replace=False) centroids = data.get_subset(perm) + is_torch = check_if_torch(centroids) iteration_stats = [] @@ -362,12 +382,17 @@ def kmeans(k, data, niter=25, seed=1234, checkpoint=None, verbose=True, t_search_tot += time.time() - t0s; err = D.sum() + if is_torch: + err = err.item() obj.append(err) hassign = np.bincount(assign, minlength=k) fac = hassign.reshape(-1, 1).astype('float32') - fac[fac == 0] = 1 # quiet warning + fac[fac == 0] = 1 # quiet warning + if is_torch: + import torch + fac = torch.from_numpy(fac).to(sums.device) centroids = sums / fac @@ -377,7 +402,7 @@ def kmeans(k, data, niter=25, seed=1234, checkpoint=None, verbose=True, "obj": err, "time": (time.time() - t0), "time_search": t_search_tot, - "imbalance_factor": imbalance_factor (k, assign), + "imbalance_factor": imbalance_factor(k, assign), "nsplit": nsplit } @@ -391,7 +416,11 @@ def kmeans(k, data, niter=25, seed=1234, checkpoint=None, verbose=True, if checkpoint is not None: log('storing centroids in', checkpoint) - np.save(checkpoint, centroids) + if is_torch: + import torch + torch.save(centroids, checkpoint) + else: + np.save(checkpoint, centroids) if return_stats: return centroids, iteration_stats diff --git a/contrib/torch/README.md b/contrib/torch/README.md new file mode 100644 index 0000000000..470d062250 --- /dev/null +++ b/contrib/torch/README.md @@ -0,0 +1,6 @@ +# The Torch contrib + +This contrib directory contains a few Pytorch routines that +are useful for similarity search. They do not necessarily depend on Faiss. + +The code is designed to work with CPU and GPU tensors. 
class DatasetAssign:
    """Wrapper for a tensor that offers a function to assign the vectors
    to centroids. All other implementations offer the same interface"""

    def __init__(self, x):
        # x is indexed as a 2D tensor: count() reads shape[0], dim() shape[1]
        self.x = x

    def count(self):
        """ number of vectors in the dataset """
        return self.x.shape[0]

    def dim(self):
        """ dimension of the vectors """
        return self.x.shape[1]

    def get_subset(self, indices):
        """ rows of the dataset selected by `indices` """
        return self.x[indices]

    def perform_search(self, centroids):
        """ nearest-centroid search, returns the (D, I) pair of faiss.knn
        with k=1 """
        return faiss.knn(self.x, centroids, 1)

    def assign_to(self, centroids, weights=None):
        """Assign every vector to its nearest centroid.

        Returns a triplet (I, D, sum_per_centroid):
        - I: index of the nearest centroid per vector, as a numpy array
        - D: corresponding distance per vector, flattened tensor
        - sum_per_centroid: tensor shaped like `centroids` that holds, for
          each centroid, the sum of the vectors assigned to it (each vector
          is multiplied by its entry in `weights` when weights are given)
        """
        D, I = self.perform_search(centroids)

        I = I.ravel()
        D = D.ravel()

        sum_per_centroid = torch.zeros_like(centroids)
        if weights is None:
            sum_per_centroid.index_add_(0, I, self.x)
        else:
            sum_per_centroid.index_add_(0, I, self.x * weights[:, None])

        # the indices are still in numpy.
        return I.cpu().numpy(), D, sum_per_centroid


class DatasetAssignGPU(DatasetAssign):
    """ Same as DatasetAssign, the search goes through faiss.knn_gpu with
    the resource object `res` """

    def __init__(self, res, x):
        DatasetAssign.__init__(self, x)
        self.res = res

    def perform_search(self, centroids):
        return faiss.knn_gpu(self.res, self.x, centroids, 1)


class Quantizer:
    """ Base class of the quantizers: stores the vector dimension `d` and
    the size of a code in bytes. train / encode / decode are placeholders
    that do nothing here. """

    def __init__(self, d, code_size):
        self.d = d
        self.code_size = code_size

    def train(self, x):
        pass

    def encode(self, x):
        pass

    def decode(self, x):
        pass


class VectorQuantizer(Quantizer):
    """ Quantizer with k centroids; the code size is the number of whole
    bytes needed to store an index in [0, k). """

    def __init__(self, d, k):
        # ceil(log2(k)) bits are needed to index k values, which is exactly
        # (k - 1).bit_length(); then round up to whole bytes. Plain integer
        # arithmetic: torch.log2 / torch.ceil only accept tensors.
        nbits = int(k - 1).bit_length()
        code_size = (nbits + 7) // 8
        # explicit base-class call needs self as first argument
        Quantizer.__init__(self, d, code_size)
        self.k = k

    def train(self, x):
        pass


class ProductQuantizer(Quantizer):
    """ Quantizer made of M sub-quantizers of nbits bits each; the code
    size is M * nbits bits rounded up to whole bytes. """

    def __init__(self, d, M, nbits):
        # integer ceil(M * nbits / 8)
        code_size = (int(M) * int(nbits) + 7) // 8
        Quantizer.__init__(self, d, code_size)
        self.M = M
        self.nbits = nbits

    def train(self, x):
        pass
faiss.cast_integer_to_uint8_ptr( x.untyped_storage().data_ptr() + x.storage_offset()) + def swig_ptr_from_HalfTensor(x): """ gets a Faiss SWIG pointer from a pytorch tensor (on CPU or GPU) """ assert x.is_contiguous() @@ -43,6 +48,7 @@ def swig_ptr_from_HalfTensor(x): return faiss.cast_integer_to_void_ptr( x.untyped_storage().data_ptr() + x.storage_offset() * 2) + def swig_ptr_from_FloatTensor(x): """ gets a Faiss SWIG pointer from a pytorch tensor (on CPU or GPU) """ assert x.is_contiguous() @@ -50,6 +56,7 @@ def swig_ptr_from_FloatTensor(x): return faiss.cast_integer_to_float_ptr( x.untyped_storage().data_ptr() + x.storage_offset() * 4) + def swig_ptr_from_IntTensor(x): """ gets a Faiss SWIG pointer from a pytorch tensor (on CPU or GPU) """ assert x.is_contiguous() @@ -57,6 +64,7 @@ def swig_ptr_from_IntTensor(x): return faiss.cast_integer_to_int_ptr( x.untyped_storage().data_ptr() + x.storage_offset() * 4) + def swig_ptr_from_IndicesTensor(x): """ gets a Faiss SWIG pointer from a pytorch tensor (on CPU or GPU) """ assert x.is_contiguous() @@ -64,6 +72,10 @@ def swig_ptr_from_IndicesTensor(x): return faiss.cast_integer_to_idx_t_ptr( x.untyped_storage().data_ptr() + x.storage_offset() * 8) +################################################################## +# utilities +################################################################## + @contextlib.contextmanager def using_stream(res, pytorch_stream=None): """ Creates a scoping object to make Faiss GPU use the same stream @@ -107,6 +119,10 @@ def torch_replace_method(the_class, name, replacement, setattr(the_class, name + '_numpy', orig_method) setattr(the_class, name, replacement) +################################################################## +# Setup wrappers +################################################################## + def handle_torch_Index(the_class): def torch_replacement_add(self, x): if type(x) is np.ndarray: @@ -493,6 +509,52 @@ def torch_replacement_sa_decode(self, codes, x=None): 
# torch-aware version of faiss.knn (CPU tensors)
def torch_replacement_knn(xq, xb, k, metric=faiss.METRIC_L2, metric_arg=0):
    """Brute-force k-nearest-neighbor search that accepts torch tensors.

    When the database `xb` is a numpy array, the call is forwarded to
    faiss.knn_numpy. Otherwise xq and xb must be contiguous float32 CPU
    tensors with the same number of columns.

    Returns (D, I): a float32 tensor and an int64 tensor, both of shape
    (number of queries, k), allocated on xb's device.
    """
    if type(xb) is np.ndarray:
        # Forward to faiss __init__.py base method
        return faiss.knn_numpy(xq, xb, k, metric=metric, metric_arg=metric_arg)

    # checks on the database side
    nb, d = xb.size()
    assert xb.is_contiguous()
    assert xb.dtype == torch.float32
    assert not xb.is_cuda, "use knn_gpu for GPU tensors"

    # checks on the query side
    nq, d2 = xq.size()
    assert d2 == d
    assert xq.is_contiguous()
    assert xq.dtype == torch.float32
    assert not xq.is_cuda, "use knn_gpu for GPU tensors"

    # output buffers, filled in place by the Faiss routines
    out_shape = (nq, k)
    D = torch.empty(out_shape, device=xb.device, dtype=torch.float32)
    I = torch.empty(out_shape, device=xb.device, dtype=torch.int64)

    I_ptr = swig_ptr_from_IndicesTensor(I)
    D_ptr = swig_ptr_from_FloatTensor(D)
    xb_ptr = swig_ptr_from_FloatTensor(xb)
    xq_ptr = swig_ptr_from_FloatTensor(xq)

    # leading arguments shared by the three search routines
    head = (xq_ptr, xb_ptr, d, nq, nb)

    if metric == faiss.METRIC_L2:
        faiss.knn_L2sqr(*head, k, D_ptr, I_ptr)
    elif metric == faiss.METRIC_INNER_PRODUCT:
        faiss.knn_inner_product(*head, k, D_ptr, I_ptr)
    else:
        faiss.knn_extra_metrics(*head, metric, metric_arg, k, D_ptr, I_ptr)

    return D, I


torch_replace_method(faiss_module, 'knn', torch_replacement_knn, True, True)
import numpy as np

import torch
import torch.distributed

import faiss

import faiss.contrib.torch_utils
from faiss.contrib.torch import clustering
from faiss.contrib import datasets


class DatasetAssignDistributedGPU(clustering.DatasetAssign):
    """
    There is one instance per worker, each worker has a dataset shard.
    The non-master workers do not run through the k-means function, so some
    code has to be run on them to keep the workers in sync: every call that
    rank 0 makes to get_subset / assign_to must be matched by the same call
    (with None arguments) on the other ranks, because these methods issue
    collective operations.
    """

    def __init__(self, res, x, rank, nproc):
        # res: resource object passed to faiss.knn_gpu
        # x: local shard of the dataset (tensor, its device is reused for
        #    all communication buffers)
        # rank, nproc: index of this worker and number of workers
        clustering.DatasetAssign.__init__(self, x)
        self.res = res
        self.rank = rank
        self.nproc = nproc
        self.device = x.device

        # exchange the shard sizes: each worker contributes its own entry
        n = len(x)
        sizes = torch.zeros(nproc, device=self.device, dtype=torch.int64)
        sizes[rank] = n
        torch.distributed.all_gather(
            [sizes[i:i + 1] for i in range(nproc)], sizes[rank:rank + 1])
        self.sizes = sizes.cpu().numpy()

        # begin & end of each shard
        self.cs = np.zeros(nproc + 1, dtype='int64')
        self.cs[1:] = np.cumsum(self.sizes)

    def count(self):
        """ total number of vectors over all shards """
        return int(self.sizes.sum())

    def int_to_slaves(self, i):
        """ broadcast an int to all workers: rank 0 passes the value, the
        other ranks pass None; every rank gets the value back """
        rank = self.rank
        tab = torch.zeros(1, device=self.device, dtype=torch.int64)
        if rank == 0:
            tab[0] = i
        else:
            assert i is None
        torch.distributed.broadcast(tab, 0)
        return tab.item()

    def get_subset(self, indices):
        """ collect the vectors at global positions `indices` (numpy array
        on rank 0, None elsewhere). Returns the vectors on rank 0 and None
        on the other ranks. """
        rank = self.rank
        assert rank == 0 or indices is None

        # the other ranks need the number of indices to size their buffer
        len_indices = self.int_to_slaves(len(indices) if rank == 0 else None)

        if rank == 0:
            indices = torch.from_numpy(indices).to(self.device)
        else:
            indices = torch.zeros(
                len_indices, dtype=torch.int64, device=self.device)
        torch.distributed.broadcast(indices, 0)

        # select subset of indices that fall in the local shard

        i0, i1 = self.cs[rank], self.cs[rank + 1]

        mask = torch.logical_and(indices < i1, indices >= i0)
        # rows owned by other shards stay at zero, so the reduction below
        # assembles the complete subset on rank 0
        output = torch.zeros(
            len_indices, self.x.shape[1],
            dtype=self.x.dtype, device=self.device)
        output[mask] = self.x[indices[mask] - i0]
        torch.distributed.reduce(output, 0)   # sum
        if rank == 0:
            return output
        else:
            return None

    def perform_search(self, centroids):
        # the search is done directly in assign_to
        assert False, "shoudl not be called"

    def assign_to(self, centroids, weights=None):
        """ distributed assignment: `centroids` is given on rank 0 and is
        None elsewhere. Rank 0 returns (I, D, sum_per_centroid) for the
        whole dataset, with I as a numpy array; the other ranks return
        None. Weights are not supported. """
        assert weights is None

        rank, nproc = self.rank, self.nproc
        assert rank == 0 or centroids is None
        nc = self.int_to_slaves(len(centroids) if rank == 0 else None)

        # the other ranks allocate a buffer that receives the centroids
        if rank != 0:
            centroids = torch.zeros(
                nc, self.x.shape[1], dtype=self.x.dtype, device=self.device)
        torch.distributed.broadcast(centroids, 0)

        # perform search on the local shard
        D, I = faiss.knn_gpu(
            self.res, self.x, centroids, 1, device=self.device.index)

        I = I.ravel()
        D = D.ravel()

        sum_per_centroid = torch.zeros_like(centroids)
        # NOTE(review): weights is asserted to be None above, so the else
        # branch cannot be reached
        if weights is None:
            sum_per_centroid.index_add_(0, I, self.x)
        else:
            sum_per_centroid.index_add_(0, I, self.x * weights[:, None])

        # accumulate the per-shard sums on rank 0
        torch.distributed.reduce(sum_per_centroid, 0)

        if rank == 0:
            # gather does not support tensors of different sizes
            # should be implemented with point-to-point communication
            assert np.all(self.sizes == self.sizes[0])
            device = self.device
            all_I = torch.zeros(self.count(), dtype=I.dtype, device=device)
            all_D = torch.zeros(self.count(), dtype=D.dtype, device=device)
            # shard r lands in the slice [cs[r], cs[r + 1])
            torch.distributed.gather(
                I, [all_I[self.cs[r]:self.cs[r + 1]] for r in range(nproc)],
                dst=0,
            )
            torch.distributed.gather(
                D, [all_D[self.cs[r]:self.cs[r + 1]] for r in range(nproc)],
                dst=0,
            )
            return all_I.cpu().numpy(), all_D, sum_per_centroid
        else:
            torch.distributed.gather(I, None, dst=0)
            torch.distributed.gather(D, None, dst=0)
            return None


if __name__ == "__main__":

    torch.distributed.init_process_group(
        backend="nccl",
    )
    rank = torch.distributed.get_rank()
    nproc = torch.distributed.get_world_size()

    # current version does only support shards of the same size
    # (each rank draws its own shard by using a different seed)
    ds = datasets.SyntheticDataset(32, 10000, 0, 0, seed=1234 + rank)
    x = ds.get_train()

    # one GPU per rank
    device = torch.device(f"cuda:{rank}")

    torch.cuda.set_device(device)
    x = torch.from_numpy(x).to(device)
    res = faiss.StandardGpuResources()

    da = DatasetAssignDistributedGPU(res, x, rank, nproc)

    k = 1000
    niter = 25

    if rank == 0:
        # only the master runs the k-means loop
        print(f"sizes = {da.sizes}")
        centroids, iteration_stats = clustering.kmeans(
            k, da, niter=niter, return_stats=True)
        print("clusters:", centroids.cpu().numpy())
    else:
        # make sure the iterations are aligned with master:
        # one get_subset for the initialization, then one assign_to
        # per iteration
        da.get_subset(None)

        for _ in range(niter):
            da.assign_to(None)

    torch.distributed.barrier()
    print("Done")
faiss.StandardGpuResources() + data = clustering.DatasetAssignGPU(res, xt_torch) + centroids = clustering.kmeans(100, data, 10) + centroids = centroids.cpu().numpy() + err2 = faiss.knn(xt, centroids, 1)[0].sum() + + # 33498.332 33380.477 + print(err, err2) + self.assertLess(err2, err * 1.1) diff --git a/faiss/python/CMakeLists.txt b/faiss/python/CMakeLists.txt index c7b22d19c8..aea99af795 100644 --- a/faiss/python/CMakeLists.txt +++ b/faiss/python/CMakeLists.txt @@ -261,5 +261,5 @@ configure_file(gpu_wrappers.py gpu_wrappers.py COPYONLY) configure_file(extra_wrappers.py extra_wrappers.py COPYONLY) configure_file(array_conversions.py array_conversions.py COPYONLY) -file(GLOB files "${PROJECT_SOURCE_DIR}/../../contrib/*.py") -file(COPY ${files} DESTINATION contrib/) +# file(GLOB files "${PROJECT_SOURCE_DIR}/../../contrib/*.py") +file(COPY ${PROJECT_SOURCE_DIR}/../../contrib DESTINATION .) diff --git a/faiss/python/setup.py b/faiss/python/setup.py index ea623ee1b2..6590a84e3f 100644 --- a/faiss/python/setup.py +++ b/faiss/python/setup.py @@ -13,6 +13,7 @@ shutil.rmtree("faiss", ignore_errors=True) os.mkdir("faiss") shutil.copytree("contrib", "faiss/contrib") +shutil.copytree("contrib/torch", "faiss/contrib/torch") shutil.copyfile("__init__.py", "faiss/__init__.py") shutil.copyfile("loader.py", "faiss/loader.py") shutil.copyfile("class_wrappers.py", "faiss/class_wrappers.py") @@ -79,12 +80,12 @@ long_description=long_description, url='https://github.com/facebookresearch/faiss', author='Matthijs Douze, Jeff Johnson, Herve Jegou, Lucas Hosseini', - author_email='matthijs@fb.com', + author_email='matthijs@meta.com', license='MIT', keywords='search nearest neighbors', install_requires=['numpy', 'packaging'], - packages=['faiss', 'faiss.contrib'], + packages=['faiss', 'faiss.contrib', 'faiss.contrib.torch'], package_data={ 'faiss': ['*.so', '*.pyd'], }, diff --git a/tests/test_contrib.py b/tests/test_contrib.py index 05a2c4ac8b..a2eb7046bd 100644 --- a/tests/test_contrib.py 
def test_python_kmeans(self):
    """ The python k-means must be about as good as the Faiss one """
    nlist, niter = 100, 10
    dataset = datasets.SyntheticDataset(32, 10000, 0, 0)
    train = dataset.get_train()

    # degenerate distribution to exercise the centroid splitting code:
    # the first half of the training points are all the same vector
    xt = train[:10000].copy()
    xt[:5000] = train[0]

    # reference: Faiss k-means
    reference = faiss.Kmeans(dataset.d, nlist, niter=niter)
    reference.train(xt)
    ref_dis, _ = faiss.knn(xt, reference.centroids, 1)
    err = ref_dis.sum()

    # python k-means on the same data
    centroids = clustering.kmeans(nlist, clustering.DatasetAssign(xt), niter)
    new_dis, _ = faiss.knn(xt, centroids, 1)
    err2 = new_dis.sum()

    # observed values: err=33498.332 err2=33380.477
    self.assertLess(err2, err * 1.1)
class TestClustering(unittest.TestCase):
    """ k-means through the torch DatasetAssign wrapper (CPU tensors) """

    def test_python_kmeans(self):
        """ The torch python k-means must be about as good as the Faiss one """
        nlist, niter = 100, 10
        dataset = datasets.SyntheticDataset(32, 10000, 0, 0)
        train = dataset.get_train()

        # degenerate distribution to exercise the centroid splitting code:
        # the first half of the training points are all the same vector
        xt = train[:10000].copy()
        xt[:5000] = train[0]

        # reference: Faiss k-means on the numpy data
        reference = faiss.Kmeans(dataset.d, nlist, niter=niter)
        reference.train(xt)
        ref_dis, _ = faiss.knn(xt, reference.centroids, 1)
        err = ref_dis.sum()

        # same data as a torch tensor, centroids come back as a tensor
        data = clustering.DatasetAssign(torch.from_numpy(xt))
        centroids = clustering.kmeans(nlist, data, niter).numpy()
        new_dis, _ = faiss.knn(xt, centroids, 1)
        err2 = new_dis.sum()

        # observed values: 33498.332 33380.477
        self.assertLess(err2, err * 1.1)