From 281c604e08d5f624546afb76fbed89dbdddcfc32 Mon Sep 17 00:00:00 2001 From: Matthijs Douze Date: Fri, 20 Sep 2024 03:07:32 -0700 Subject: [PATCH 1/3] rewrite python kmeans without scipy Summary: The previous version required scipy to do the accumulation, which is replaced here with a nifty piece of numpy accumulation. This removes the need for scipy for non-sparse data. Differential Revision: D62884307 --- contrib/clustering.py | 16 ++++++++-------- tests/test_contrib.py | 20 ++++++++++++++++++++ tests/test_contrib_with_scipy.py | 20 -------------------- 3 files changed, 28 insertions(+), 28 deletions(-) diff --git a/contrib/clustering.py b/contrib/clustering.py index e84a7e63f6..79b6b05a5f 100644 --- a/contrib/clustering.py +++ b/contrib/clustering.py @@ -151,14 +151,12 @@ def assign_to(self, centroids, weights=None): I = I.ravel() D = D.ravel() - n = len(self.x) + nc, d = centroids.shape + sum_per_centroid = np.zeros((nc, d), dtype='float32') if weights is None: - weights = np.ones(n, dtype='float32') - nc = len(centroids) - m = scipy.sparse.csc_matrix( - (weights, I, np.arange(n + 1)), - shape=(nc, n)) - sum_per_centroid = m * self.x + np.add.at(sum_per_centroid, I, self.x) + else: + np.add.at(sum_per_centroid, I, weights[:, np.newaxis] * self.x) return I, D, sum_per_centroid @@ -185,7 +183,8 @@ def perform_search(self, centroids): def sparse_assign_to_dense(xq, xb, xq_norms=None, xb_norms=None): """ assignment function for xq is sparse, xb is dense - uses a matrix multiplication. The squared norms can be provided if available. + uses a matrix multiplication. The squared norms can be provided if + available. """ nq = xq.shape[0] nb = xb.shape[0] @@ -272,6 +271,7 @@ def assign_to(self, centroids, weights=None): if weights is None: weights = np.ones(n, dtype='float32') nc = len(centroids) + m = scipy.sparse.csc_matrix( (weights, I, np.arange(n + 1)), shape=(nc, n)) diff --git a/tests/test_contrib.py b/tests/test_contrib.py index 05a2c4ac8b..fa5d85ab51 100644 --- a/tests/test_contrib.py +++ b/tests/test_contrib.py @@ -517,6 +517,26 @@ def test_binary(self): class TestClustering(unittest.TestCase): + def test_python_kmeans(self): + """ Test the python implementation of kmeans """ + ds = datasets.SyntheticDataset(32, 10000, 0, 0) + x = ds.get_train() + + # bad distribution to stress-test split code + xt = x[:10000].copy() + xt[:5000] = x[0] + + km_ref = faiss.Kmeans(ds.d, 100, niter=10) + km_ref.train(xt) + err = faiss.knn(xt, km_ref.centroids, 1)[0].sum() + + data = clustering.DatasetAssign(xt) + centroids = clustering.kmeans(100, data, 10) + err2 = faiss.knn(xt, centroids, 1)[0].sum() + + # err=33498.332 err2=33380.477 + self.assertLess(err2, err * 1.1) + def test_2level(self): " verify that 2-level clustering is not too sub-optimal " ds = datasets.SyntheticDataset(32, 10000, 0, 0) diff --git a/tests/test_contrib_with_scipy.py b/tests/test_contrib_with_scipy.py index 4f89e2fc1b..618a550b73 100644 --- a/tests/test_contrib_with_scipy.py +++ b/tests/test_contrib_with_scipy.py @@ -17,26 +17,6 @@ class TestClustering(unittest.TestCase): - def test_python_kmeans(self): - """ Test the python implementation of kmeans """ - ds = datasets.SyntheticDataset(32, 10000, 0, 0) - x = ds.get_train() - - # bad distribution to stress-test split code - xt = x[:10000].copy() - xt[:5000] = x[0] - - km_ref = faiss.Kmeans(ds.d, 100, niter=10) - km_ref.train(xt) - err = faiss.knn(xt, km_ref.centroids, 1)[0].sum() - - data = clustering.DatasetAssign(xt) - centroids = clustering.kmeans(100, data, 10) - err2 = faiss.knn(xt, centroids, 1)[0].sum() - - # 33517.645 and 33031.098 - self.assertLess(err2, err * 1.1) - def test_sparse_routines(self): """ the sparse assignment routine """ ds = datasets.SyntheticDataset(1000, 2000, 0, 200) From 866e3fe16400a146cf530d9d76908d35a3925f4d Mon Sep 17 00:00:00 2001 From: Matthijs Douze Date: Fri, 20 Sep 2024 06:00:35 -0700 Subject: [PATCH 2/3] begin torch_contrib Summary: The contrib.torch subdirectory is intended to receive modules in python that are useful for similarity search and that apply to CPU or GPU pytorch tensors. The current version includes CPU clustering on torch tensors. To be added: * implementation of PQ Differential Revision: D62759207 --- contrib/clustering.py | 45 ++++++++++++++--- contrib/torch/README.md | 6 +++ contrib/torch/__init__.py | 0 contrib/torch/clustering.py | 60 +++++++++++++++++++++++ contrib/torch/quantization.py | 53 ++++++++++++++++++++ contrib/torch_utils.py | 62 ++++++++++++++++++++++++ faiss/gpu/test/torch_test_contrib_gpu.py | 33 +++++++++++++ tests/test_contrib.py | 3 +- tests/torch_test_contrib.py | 30 ++++++++++++ 9 files changed, 282 insertions(+), 10 deletions(-) create mode 100644 contrib/torch/README.md create mode 100644 contrib/torch/__init__.py create mode 100644 contrib/torch/clustering.py create mode 100644 contrib/torch/quantization.py diff --git a/contrib/clustering.py b/contrib/clustering.py index 79b6b05a5f..c1e8775c9b 100644 --- a/contrib/clustering.py +++ b/contrib/clustering.py @@ -285,25 +285,40 @@ def imbalance_factor(k, assign): return faiss.imbalance_factor(len(assign), k, faiss.swig_ptr(assign)) +def check_if_torch(x): + if x.__class__ == np.ndarray: + return False + import torch + if isinstance(x, torch.Tensor): + return True + raise NotImplementedError(f"Unknown tensor type {type(x)}") + + def reassign_centroids(hassign, centroids, rs=None): """ reassign centroids when some of them collapse """ if rs is None: rs = np.random k, d = centroids.shape nsplit = 0 + is_torch = check_if_torch(centroids) + empty_cents = np.where(hassign == 0)[0] - if empty_cents.size == 0: + if len(empty_cents) == 0: return 0 - fac = np.ones(d) + if is_torch: + import torch + fac = torch.ones_like(centroids[0]) + else: + fac = np.ones_like(centroids[0]) fac[::2] += 1 / 1024. fac[1::2] -= 1 / 1024. # this is a single pass unless there are more than k/2 # empty centroids - while empty_cents.size > 0: - # choose which centroids to split + while len(empty_cents) > 0: + # choose which centroids to split (numpy) probas = hassign.astype('float') - 1 probas[probas < 0] = 0 probas /= probas.sum() @@ -327,13 +342,17 @@ def reassign_centroids(hassign, centroids, rs=None): return nsplit + def kmeans(k, data, niter=25, seed=1234, checkpoint=None, verbose=True, return_stats=False): """Pure python kmeans implementation. Follows the Faiss C++ version quite closely, but takes a DatasetAssign instead of a training data - matrix. Also redo is not implemented. """ + matrix. Also redo is not implemented. + + For the torch implementation, the centroids are tensors (possibly on GPU), + but the indices remain numpy on CPU. + """ n, d = data.count(), data.dim() - log = print if verbose else print_nop log(("Clustering %d points in %dD to %d clusters, " + @@ -345,6 +364,7 @@ def kmeans(k, data, niter=25, seed=1234, checkpoint=None, verbose=True, # initialization perm = rs.choice(n, size=k, replace=False) centroids = data.get_subset(perm) + is_torch = check_if_torch(centroids) iteration_stats = [] @@ -362,12 +382,17 @@ def kmeans(k, data, niter=25, seed=1234, checkpoint=None, verbose=True, t_search_tot += time.time() - t0s; err = D.sum() + if is_torch: + err = err.item() obj.append(err) hassign = np.bincount(assign, minlength=k) fac = hassign.reshape(-1, 1).astype('float32') - fac[fac == 0] = 1 # quiet warning + fac[fac == 0] = 1 # quiet warning + if is_torch: + import torch + fac = torch.from_numpy(fac).to(sums.device) centroids = sums / fac @@ -391,7 +416,11 @@ def kmeans(k, data, niter=25, seed=1234, checkpoint=None, verbose=True, if checkpoint is not None: log('storing centroids in', checkpoint) - np.save(checkpoint, centroids) + if is_torch: + import torch + torch.save(centroids, checkpoint) + else: + np.save(checkpoint, centroids) if return_stats: return centroids, iteration_stats diff --git a/contrib/torch/README.md b/contrib/torch/README.md new file mode 100644 index 0000000000..470d062250 --- /dev/null +++ b/contrib/torch/README.md @@ -0,0 +1,6 @@ +# The Torch contrib + +This contrib directory contains a few Pytorch routines that +are useful for similarity search. They do not necessarily depend on Faiss. + +The code is designed to work with CPU and GPU tensors. diff --git a/contrib/torch/__init__.py b/contrib/torch/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/contrib/torch/clustering.py b/contrib/torch/clustering.py new file mode 100644 index 0000000000..bdaa0a1f9a --- /dev/null +++ b/contrib/torch/clustering.py @@ -0,0 +1,60 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +""" +This contrib module contains Pytorch code for k-means clustering +""" +import faiss +import faiss.contrib.torch_utils +import torch + +# the kmeans can produce both torch and numpy centroids +from faiss.contrib.clustering import DatasetAssign, kmeans + + +class DatasetAssign: + """Wrapper for a tensor that offers a function to assign the vectors + to centroids. All other implementations offer the same interface""" + + def __init__(self, x): + self.x = x + + def count(self): + return self.x.shape[0] + + def dim(self): + return self.x.shape[1] + + def get_subset(self, indices): + return self.x[indices] + + def perform_search(self, centroids): + return faiss.knn(self.x, centroids, 1) + + def assign_to(self, centroids, weights=None): + D, I = self.perform_search(centroids) + + I = I.ravel() + D = D.ravel() + nc, d = centroids.shape + + sum_per_centroid = torch.zeros_like(centroids) + if weights is None: + sum_per_centroid.index_add_(0, I, self.x) + else: + sum_per_centroid.index_add_(0, I, self.x * weights[:, None]) + + # the indices are still in numpy. + return I.cpu().numpy(), D, sum_per_centroid + + +class DatasetAssignGPU(DatasetAssign): + + def __init__(self, res, x): + DatasetAssign.__init__(self, x) + self.res = res + + def perform_search(self, centroids): + return faiss.knn_gpu(self.res, self.x, centroids, 1) diff --git a/contrib/torch/quantization.py b/contrib/torch/quantization.py new file mode 100644 index 0000000000..550c17dbb7 --- /dev/null +++ b/contrib/torch/quantization.py @@ -0,0 +1,53 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +""" +This contrib module contains Pytorch code for quantization. +""" + +import numpy as np +import torch +import faiss + +from faiss.contrib import torch_utils + + +class Quantizer: + + def __init__(self, d, code_size): + self.d = d + self.code_size = code_size + + def train(self, x): + pass + + def encode(self, x): + pass + + def decode(self, x): + pass + + +class VectorQuantizer(Quantizer): + + def __init__(self, d, k): + code_size = int(torch.ceil(torch.log2(k) / 8)) + Quantizer.__init__(d, code_size) + self.k = k + + def train(self, x): + pass + + +class ProductQuantizer(Quantizer): + + def __init__(self, d, M, nbits): + code_size = int(torch.ceil(M * nbits / 8)) + Quantizer.__init__(d, code_size) + self.M = M + self.nbits = nbits + + def train(self, x): + pass diff --git a/contrib/torch_utils.py b/contrib/torch_utils.py index 18f136e914..21e6439726 100644 --- a/contrib/torch_utils.py +++ b/contrib/torch_utils.py @@ -28,6 +28,10 @@ import sys import numpy as np +################################################################## +# Equivalent of swig_ptr for Torch tensors +################################################################## + def swig_ptr_from_UInt8Tensor(x): """ gets a Faiss SWIG pointer from a pytorch tensor (on CPU or GPU) """ assert x.is_contiguous() @@ -35,6 +39,7 @@ def swig_ptr_from_UInt8Tensor(x): return faiss.cast_integer_to_uint8_ptr( x.untyped_storage().data_ptr() + x.storage_offset()) + def swig_ptr_from_HalfTensor(x): """ gets a Faiss SWIG pointer from a pytorch tensor (on CPU or GPU) """ assert x.is_contiguous() @@ -43,6 +48,7 @@ def swig_ptr_from_HalfTensor(x): return faiss.cast_integer_to_void_ptr( x.untyped_storage().data_ptr() + x.storage_offset() * 2) + def swig_ptr_from_FloatTensor(x): """ gets a Faiss SWIG pointer from a pytorch tensor (on CPU or GPU) """ assert x.is_contiguous() @@ -50,6 +56,7 @@ def swig_ptr_from_FloatTensor(x): return faiss.cast_integer_to_float_ptr( x.untyped_storage().data_ptr() + x.storage_offset() * 4) + def swig_ptr_from_IntTensor(x): """ gets a Faiss SWIG pointer from a pytorch tensor (on CPU or GPU) """ assert x.is_contiguous() @@ -57,6 +64,7 @@ def swig_ptr_from_IntTensor(x): return faiss.cast_integer_to_int_ptr( x.untyped_storage().data_ptr() + x.storage_offset() * 4) + def swig_ptr_from_IndicesTensor(x): """ gets a Faiss SWIG pointer from a pytorch tensor (on CPU or GPU) """ assert x.is_contiguous() @@ -64,6 +72,10 @@ def swig_ptr_from_IndicesTensor(x): return faiss.cast_integer_to_idx_t_ptr( x.untyped_storage().data_ptr() + x.storage_offset() * 8) +################################################################## +# utilities +################################################################## + @contextlib.contextmanager def using_stream(res, pytorch_stream=None): """ Creates a scoping object to make Faiss GPU use the same stream @@ -107,6 +119,10 @@ def torch_replace_method(the_class, name, replacement, setattr(the_class, name + '_numpy', orig_method) setattr(the_class, name, replacement) +################################################################## +# Setup wrappers +################################################################## + def handle_torch_Index(the_class): def torch_replacement_add(self, x): if type(x) is np.ndarray: @@ -493,6 +509,52 @@ def torch_replacement_sa_decode(self, codes, x=None): handle_torch_Index(the_class) +# allows torch tensor usage with knn +def torch_replacement_knn(xq, xb, k, metric=faiss.METRIC_L2, metric_arg=0): + if type(xb) is np.ndarray: + # Forward to faiss __init__.py base method + return faiss.knn_numpy(xq, xb, k, metric=metric, metric_arg=metric_arg) + + nb, d = xb.size() + assert xb.is_contiguous() + assert xb.dtype == torch.float32 + assert not xb.is_cuda, "use knn_gpu for GPU tensors" + + nq, d2 = xq.size() + assert d2 == d + assert xq.is_contiguous() + assert xq.dtype == torch.float32 + assert not xq.is_cuda, "use knn_gpu for GPU tensors" + + D = torch.empty(nq, k, device=xb.device, dtype=torch.float32) + I = torch.empty(nq, k, device=xb.device, dtype=torch.int64) + I_ptr = swig_ptr_from_IndicesTensor(I) + D_ptr = swig_ptr_from_FloatTensor(D) + xb_ptr = swig_ptr_from_FloatTensor(xb) + xq_ptr = swig_ptr_from_FloatTensor(xq) + + if metric == faiss.METRIC_L2: + faiss.knn_L2sqr( + xq_ptr, xb_ptr, + d, nq, nb, k, D_ptr, I_ptr + ) + elif metric == faiss.METRIC_INNER_PRODUCT: + faiss.knn_inner_product( + xq_ptr, xb_ptr, + d, nq, nb, k, D_ptr, I_ptr + ) + else: + faiss.knn_extra_metrics( + xq_ptr, xb_ptr, + d, nq, nb, metric, metric_arg, k, D_ptr, I_ptr + ) + + return D, I + + +torch_replace_method(faiss_module, 'knn', torch_replacement_knn, True, True) + + # allows torch tensor usage with bfKnn def torch_replacement_knn_gpu(res, xq, xb, k, D=None, I=None, metric=faiss.METRIC_L2, device=-1, use_raft=False): if type(xb) is np.ndarray: diff --git a/faiss/gpu/test/torch_test_contrib_gpu.py b/faiss/gpu/test/torch_test_contrib_gpu.py index f1a92c33b3..6c58b37b25 100644 --- a/faiss/gpu/test/torch_test_contrib_gpu.py +++ b/faiss/gpu/test/torch_test_contrib_gpu.py @@ -9,6 +9,10 @@ import faiss import faiss.contrib.torch_utils +from faiss.contrib import datasets +from faiss.contrib.torch import clustering + + def to_column_major_torch(x): if hasattr(torch, 'contiguous_format'): return x.t().clone(memory_format=torch.contiguous_format).t() @@ -377,6 +381,7 @@ def test_knn_gpu_datatypes(self, use_raft=False): self.assertTrue(torch.equal(torch.from_numpy(I).long(), gt_I)) self.assertLess((torch.from_numpy(D) - gt_D).abs().max(), 1.5e-3) + class TestTorchUtilsPairwiseDistanceGpu(unittest.TestCase): def test_pairwise_distance_gpu(self): torch.manual_seed(10) @@ -470,3 +475,31 @@ def test_pairwise_distance_gpu(self): D, _ = torch.sort(D, dim=1) self.assertLess((D.cpu() - gt_D[4:8]).abs().max(), 1e-4) + + +class TestClustering(unittest.TestCase): + + def test_python_kmeans(self): + """ Test the python implementation of kmeans """ + ds = datasets.SyntheticDataset(32, 10000, 0, 0) + x = ds.get_train() + + # bad distribution to stress-test split code + xt = x[:10000].copy() + xt[:5000] = x[0] + + # CPU baseline + km_ref = faiss.Kmeans(ds.d, 100, niter=10) + km_ref.train(xt) + err = faiss.knn(xt, km_ref.centroids, 1)[0].sum() + + xt_torch = torch.from_numpy(xt).to("cuda:0") + res = faiss.StandardGpuResources() + data = clustering.DatasetAssignGPU(res, xt_torch) + centroids = clustering.kmeans(100, data, 10) + centroids = centroids.cpu().numpy() + err2 = faiss.knn(xt, centroids, 1)[0].sum() + + # 33498.332 33380.477 + print(err, err2) + self.assertLess(err2, err * 1.1) diff --git a/tests/test_contrib.py b/tests/test_contrib.py index fa5d85ab51..a2eb7046bd 100644 --- a/tests/test_contrib.py +++ b/tests/test_contrib.py @@ -26,8 +26,7 @@ range_search_max_results, exponential_query_iterator from contextlib import contextmanager -@unittest.skipIf(platform.python_version_tuple()[0] < '3', - 'Submodule import broken in python 2.') + class TestComputeGT(unittest.TestCase): def do_test_compute_GT(self, metric=faiss.METRIC_L2): diff --git a/tests/torch_test_contrib.py b/tests/torch_test_contrib.py index e26a79c6bb..41311d6c78 100644 --- a/tests/torch_test_contrib.py +++ b/tests/torch_test_contrib.py @@ -9,6 +9,10 @@ import faiss # usort: skip import faiss.contrib.torch_utils # usort: skip +from faiss.contrib import datasets +from faiss.contrib.torch import clustering + + class TestTorchUtilsCPU(unittest.TestCase): # tests add, search @@ -344,3 +348,29 @@ def test_non_contiguous(self): # disabled since we now accept non-contiguous arrays # with self.assertRaises(ValueError): # index.add(xb.numpy()) + + +class TestClustering(unittest.TestCase): + + def test_python_kmeans(self): + """ Test the python implementation of kmeans """ + ds = datasets.SyntheticDataset(32, 10000, 0, 0) + x = ds.get_train() + + # bad distribution to stress-test split code + xt = x[:10000].copy() + xt[:5000] = x[0] + + km_ref = faiss.Kmeans(ds.d, 100, niter=10) + km_ref.train(xt) + err = faiss.knn(xt, km_ref.centroids, 1)[0].sum() + + xt_torch = torch.from_numpy(xt) + data = clustering.DatasetAssign(xt_torch) + centroids = clustering.kmeans(100, data, 10) + centroids = centroids.numpy() + err2 = faiss.knn(xt, centroids, 1)[0].sum() + + # 33498.332 33380.477 + # print(err, err2) 1/0 + self.assertLess(err2, err * 1.1) From 64d29f3da8b689de5b8c091919c90629541ca9e3 Mon Sep 17 00:00:00 2001 From: Matthijs Douze Date: Fri, 20 Sep 2024 07:05:06 -0700 Subject: [PATCH 3/3] torch.distributed kmeans (#3876) Summary: Pull Request resolved: https://github.com/facebookresearch/faiss/pull/3876 Demo script for distributed kmeans. It provides a `DatasetAssign` object and shows how to run it with torch.distributed. Reviewed By: asadoughi, pankajsingh88 Differential Revision: D63013820 --- contrib/clustering.py | 32 ++--- contrib/torch/clustering.py | 5 +- contrib/torch/quantization.py | 30 ++-- contrib/torch_utils.py | 6 +- demos/demo_distributed_kmeans_torch.py | 173 +++++++++++++++++++++++ faiss/gpu/test/torch_test_contrib_gpu.py | 2 +- faiss/python/CMakeLists.txt | 4 +- faiss/python/setup.py | 5 +- tests/torch_test_contrib.py | 4 +- 9 files changed, 217 insertions(+), 44 deletions(-) create mode 100644 demos/demo_distributed_kmeans_torch.py diff --git a/contrib/clustering.py b/contrib/clustering.py index c1e8775c9b..19c2656dc1 100644 --- a/contrib/clustering.py +++ b/contrib/clustering.py @@ -155,7 +155,7 @@ def assign_to(self, centroids, weights=None): sum_per_centroid = np.zeros((nc, d), dtype='float32') if weights is None: np.add.at(sum_per_centroid, I, self.x) - else: + else: np.add.at(sum_per_centroid, I, weights[:, np.newaxis] * self.x) return I, D, sum_per_centroid @@ -183,7 +183,7 @@ def perform_search(self, centroids): def sparse_assign_to_dense(xq, xb, xq_norms=None, xb_norms=None): """ assignment function for xq is sparse, xb is dense - uses a matrix multiplication. The squared norms can be provided if + uses a matrix multiplication. The squared norms can be provided if available. """ nq = xq.shape[0] @@ -271,7 +271,7 @@ def assign_to(self, centroids, weights=None): if weights is None: weights = np.ones(n, dtype='float32') nc = len(centroids) - + m = scipy.sparse.csc_matrix( (weights, I, np.arange(n + 1)), shape=(nc, n)) @@ -289,7 +289,7 @@ def check_if_torch(x): if x.__class__ == np.ndarray: return False import torch - if isinstance(x, torch.Tensor): + if isinstance(x, torch.Tensor): return True raise NotImplementedError(f"Unknown tensor type {type(x)}") @@ -307,11 +307,11 @@ def reassign_centroids(hassign, centroids, rs=None): if len(empty_cents) == 0: return 0 - if is_torch: + if is_torch: import torch - fac = torch.ones_like(centroids[0]) - else: - fac = np.ones_like(centroids[0]) + fac = torch.ones_like(centroids[0]) + else: + fac = np.ones_like(centroids[0]) fac[::2] += 1 / 1024. fac[1::2] -= 1 / 1024. @@ -347,9 +347,9 @@ def kmeans(k, data, niter=25, seed=1234, checkpoint=None, verbose=True, return_stats=False): """Pure python kmeans implementation. Follows the Faiss C++ version quite closely, but takes a DatasetAssign instead of a training data - matrix. Also redo is not implemented. - - For the torch implementation, the centroids are tensors (possibly on GPU), + matrix. Also redo is not implemented. + + For the torch implementation, the centroids are tensors (possibly on GPU), but the indices remain numpy on CPU. """ n, d = data.count(), data.dim() @@ -382,7 +382,7 @@ def kmeans(k, data, niter=25, seed=1234, checkpoint=None, verbose=True, t_search_tot += time.time() - t0s; err = D.sum() - if is_torch: + if is_torch: err = err.item() obj.append(err) @@ -390,7 +390,7 @@ def kmeans(k, data, niter=25, seed=1234, checkpoint=None, verbose=True, fac = hassign.reshape(-1, 1).astype('float32') fac[fac == 0] = 1 # quiet warning - if is_torch: + if is_torch: import torch fac = torch.from_numpy(fac).to(sums.device) @@ -402,7 +402,7 @@ def kmeans(k, data, niter=25, seed=1234, checkpoint=None, verbose=True, "obj": err, "time": (time.time() - t0), "time_search": t_search_tot, - "imbalance_factor": imbalance_factor (k, assign), + "imbalance_factor": imbalance_factor(k, assign), "nsplit": nsplit } @@ -416,10 +416,10 @@ def kmeans(k, data, niter=25, seed=1234, checkpoint=None, verbose=True, if checkpoint is not None: log('storing centroids in', checkpoint) - if is_torch: + if is_torch: import torch torch.save(centroids, checkpoint) - else: + else: np.save(checkpoint, centroids) if return_stats: diff --git a/contrib/torch/clustering.py b/contrib/torch/clustering.py index bdaa0a1f9a..9e1ce94b91 100644 --- a/contrib/torch/clustering.py +++ b/contrib/torch/clustering.py @@ -11,8 +11,7 @@ import torch # the kmeans can produce both torch and numpy centroids -from faiss.contrib.clustering import DatasetAssign, kmeans - +from faiss.contrib.clustering import kmeans class DatasetAssign: """Wrapper for a tensor that offers a function to assign the vectors @@ -52,7 +51,7 @@ def assign_to(self, centroids, weights=None): class DatasetAssignGPU(DatasetAssign): - def __init__(self, res, x): + def __init__(self, res, x): DatasetAssign.__init__(self, x) self.res = res diff --git a/contrib/torch/quantization.py b/contrib/torch/quantization.py index 550c17dbb7..8d6b17fa8f 100644 --- a/contrib/torch/quantization.py +++ b/contrib/torch/quantization.py @@ -8,46 +8,46 @@ """ import numpy as np -import torch +import torch import faiss from faiss.contrib import torch_utils -class Quantizer: +class Quantizer: - def __init__(self, d, code_size): - self.d = d + def __init__(self, d, code_size): + self.d = d self.code_size = code_size - def train(self, x): + def train(self, x): pass - - def encode(self, x): + + def encode(self, x): pass - - def decode(self, x): + + def decode(self, x): pass -class VectorQuantizer(Quantizer): +class VectorQuantizer(Quantizer): - def __init__(self, d, k): + def __init__(self, d, k): code_size = int(torch.ceil(torch.log2(k) / 8)) Quantizer.__init__(d, code_size) self.k = k - def train(self, x): + def train(self, x): pass -class ProductQuantizer(Quantizer): +class ProductQuantizer(Quantizer): - def __init__(self, d, M, nbits): + def __init__(self, d, M, nbits): code_size = int(torch.ceil(M * nbits / 8)) Quantizer.__init__(d, code_size) self.M = M self.nbits = nbits - def train(self, x): + def train(self, x): pass diff --git a/contrib/torch_utils.py b/contrib/torch_utils.py index 21e6439726..9b4855ea3a 100644 --- a/contrib/torch_utils.py +++ b/contrib/torch_utils.py @@ -73,7 +73,7 @@ def swig_ptr_from_IndicesTensor(x): x.untyped_storage().data_ptr() + x.storage_offset() * 8) ################################################################## -# utilities +# utilities ################################################################## @contextlib.contextmanager @@ -519,7 +519,7 @@ def torch_replacement_knn(xq, xb, k, metric=faiss.METRIC_L2, metric_arg=0): assert xb.is_contiguous() assert xb.dtype == torch.float32 assert not xb.is_cuda, "use knn_gpu for GPU tensors" - + nq, d2 = xq.size() assert d2 == d assert xq.is_contiguous() @@ -543,7 +543,7 @@ def torch_replacement_knn(xq, xb, k, metric=faiss.METRIC_L2, metric_arg=0): xq_ptr, xb_ptr, d, nq, nb, k, D_ptr, I_ptr ) - else: + else: faiss.knn_extra_metrics( xq_ptr, xb_ptr, d, nq, nb, metric, metric_arg, k, D_ptr, I_ptr diff --git a/demos/demo_distributed_kmeans_torch.py b/demos/demo_distributed_kmeans_torch.py new file mode 100644 index 0000000000..279efa2080 --- /dev/null +++ b/demos/demo_distributed_kmeans_torch.py @@ -0,0 +1,173 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import numpy as np + +import torch +import torch.distributed + +import faiss + +import faiss.contrib.torch_utils +from faiss.contrib.torch import clustering +from faiss.contrib import datasets + + +class DatasetAssignDistributedGPU(clustering.DatasetAssign): + """ + There is one instance per worker, each worker has a dataset shard. + The non-master workers do not run through the k-means function, so some + code has run it to keep the workers in sync. + """ + + def __init__(self, res, x, rank, nproc): + clustering.DatasetAssign.__init__(self, x) + self.res = res + self.rank = rank + self.nproc = nproc + self.device = x.device + + n = len(x) + sizes = torch.zeros(nproc, device=self.device, dtype=torch.int64) + sizes[rank] = n + torch.distributed.all_gather( + [sizes[i:i + 1] for i in range(nproc)], sizes[rank:rank + 1]) + self.sizes = sizes.cpu().numpy() + + # begin & end of each shard + self.cs = np.zeros(nproc + 1, dtype='int64') + self.cs[1:] = np.cumsum(self.sizes) + + def count(self): + return int(self.sizes.sum()) + + def int_to_slaves(self, i): + " broadcast an int to all workers " + rank = self.rank + tab = torch.zeros(1, device=self.device, dtype=torch.int64) + if rank == 0: + tab[0] = i + else: + assert i is None + torch.distributed.broadcast(tab, 0) + return tab.item() + + def get_subset(self, indices): + rank = self.rank + assert rank == 0 or indices is None + + len_indices = self.int_to_slaves(len(indices) if rank == 0 else None) + + if rank == 0: + indices = torch.from_numpy(indices).to(self.device) + else: + indices = torch.zeros( + len_indices, dtype=torch.int64, device=self.device) + torch.distributed.broadcast(indices, 0) + + # select subset of indices + + i0, i1 = self.cs[rank], self.cs[rank + 1] + + mask = torch.logical_and(indices < i1, indices >= i0) + output = torch.zeros( + len_indices, self.x.shape[1], + dtype=self.x.dtype, device=self.device) + output[mask] = self.x[indices[mask] - i0] + torch.distributed.reduce(output, 0) # sum + if rank == 0: + return output + else: + return None + + def perform_search(self, centroids): + assert False, "shoudl not be called" + + def assign_to(self, centroids, weights=None): + assert weights is None + + rank, nproc = self.rank, self.nproc + assert rank == 0 or centroids is None + nc = self.int_to_slaves(len(centroids) if rank == 0 else None) + + if rank != 0: + centroids = torch.zeros( + nc, self.x.shape[1], dtype=self.x.dtype, device=self.device) + torch.distributed.broadcast(centroids, 0) + + # perform search + D, I = faiss.knn_gpu( + self.res, self.x, centroids, 1, device=self.device.index) + + I = I.ravel() + D = D.ravel() + + sum_per_centroid = torch.zeros_like(centroids) + if weights is None: + sum_per_centroid.index_add_(0, I, self.x) + else: + sum_per_centroid.index_add_(0, I, self.x * weights[:, None]) + + torch.distributed.reduce(sum_per_centroid, 0) + + if rank == 0: + # gather deos not support tensors of different sizes + # should be implemented with point-to-point communication + assert np.all(self.sizes == self.sizes[0]) + device = self.device + all_I = torch.zeros(self.count(), dtype=I.dtype, device=device) + all_D = torch.zeros(self.count(), dtype=D.dtype, device=device) + torch.distributed.gather( + I, [all_I[self.cs[r]:self.cs[r + 1]] for r in range(nproc)], + dst=0, + ) + torch.distributed.gather( + D, [all_D[self.cs[r]:self.cs[r + 1]] for r in range(nproc)], + dst=0, + ) + return all_I.cpu().numpy(), all_D, sum_per_centroid + else: + torch.distributed.gather(I, None, dst=0) + torch.distributed.gather(D, None, dst=0) + return None + + +if __name__ == "__main__": + + torch.distributed.init_process_group( + backend="nccl", + ) + rank = torch.distributed.get_rank() + nproc = torch.distributed.get_world_size() + + # current version does only support shards of the same size + ds = datasets.SyntheticDataset(32, 10000, 0, 0, seed=1234 + rank) + x = ds.get_train() + + device = torch.device(f"cuda:{rank}") + + torch.cuda.set_device(device) + x = torch.from_numpy(x).to(device) + res = faiss.StandardGpuResources() + + da = DatasetAssignDistributedGPU(res, x, rank, nproc) + + k = 1000 + niter = 25 + + if rank == 0: + print(f"sizes = {da.sizes}") + centroids, iteration_stats = clustering.kmeans( + k, da, niter=niter, return_stats=True) + print("clusters:", centroids.cpu().numpy()) + else: + # make sure the iterations are aligned with master + da.get_subset(None) + + for _ in range(niter): + da.assign_to(None) + + torch.distributed.barrier() + print("Done") diff --git a/faiss/gpu/test/torch_test_contrib_gpu.py b/faiss/gpu/test/torch_test_contrib_gpu.py index 6c58b37b25..1f6f27ecca 100644 --- a/faiss/gpu/test/torch_test_contrib_gpu.py +++ b/faiss/gpu/test/torch_test_contrib_gpu.py @@ -501,5 +501,5 @@ def test_python_kmeans(self): err2 = faiss.knn(xt, centroids, 1)[0].sum() # 33498.332 33380.477 - print(err, err2) + print(err, err2) self.assertLess(err2, err * 1.1) diff --git a/faiss/python/CMakeLists.txt b/faiss/python/CMakeLists.txt index c7b22d19c8..aea99af795 100644 --- a/faiss/python/CMakeLists.txt +++ b/faiss/python/CMakeLists.txt @@ -261,5 +261,5 @@ configure_file(gpu_wrappers.py gpu_wrappers.py COPYONLY) configure_file(extra_wrappers.py extra_wrappers.py COPYONLY) configure_file(array_conversions.py array_conversions.py COPYONLY) -file(GLOB files "${PROJECT_SOURCE_DIR}/../../contrib/*.py") -file(COPY ${files} DESTINATION contrib/) +# file(GLOB files "${PROJECT_SOURCE_DIR}/../../contrib/*.py") +file(COPY ${PROJECT_SOURCE_DIR}/../../contrib DESTINATION .) diff --git a/faiss/python/setup.py b/faiss/python/setup.py index ea623ee1b2..6590a84e3f 100644 --- a/faiss/python/setup.py +++ b/faiss/python/setup.py @@ -13,6 +13,7 @@ shutil.rmtree("faiss", ignore_errors=True) os.mkdir("faiss") shutil.copytree("contrib", "faiss/contrib") +shutil.copytree("contrib/torch", "faiss/contrib/torch") shutil.copyfile("__init__.py", "faiss/__init__.py") shutil.copyfile("loader.py", "faiss/loader.py") shutil.copyfile("class_wrappers.py", "faiss/class_wrappers.py") @@ -79,12 +80,12 @@ long_description=long_description, url='https://github.com/facebookresearch/faiss', author='Matthijs Douze, Jeff Johnson, Herve Jegou, Lucas Hosseini', - author_email='matthijs@fb.com', + author_email='matthijs@meta.com', license='MIT', keywords='search nearest neighbors', install_requires=['numpy', 'packaging'], - packages=['faiss', 'faiss.contrib'], + packages=['faiss', 'faiss.contrib', 'faiss.contrib.torch'], package_data={ 'faiss': ['*.so', '*.pyd'], }, diff --git a/tests/torch_test_contrib.py b/tests/torch_test_contrib.py index 41311d6c78..d3dd8c0ae8 100644 --- a/tests/torch_test_contrib.py +++ b/tests/torch_test_contrib.py @@ -9,7 +9,7 @@ import faiss # usort: skip import faiss.contrib.torch_utils # usort: skip -from faiss.contrib import datasets +from faiss.contrib import datasets from faiss.contrib.torch import clustering @@ -365,7 +365,7 @@ def test_python_kmeans(self): km_ref.train(xt) err = faiss.knn(xt, km_ref.centroids, 1)[0].sum() - xt_torch = torch.from_numpy(xt) + xt_torch = torch.from_numpy(xt) data = clustering.DatasetAssign(xt_torch) centroids = clustering.kmeans(100, data, 10) centroids = centroids.numpy()