
Commit c96f0f4

mdouze authored and facebook-github-bot committed
torch.distributed kmeans (#3876)
Summary:
Pull Request resolved: #3876

Demo script for distributed kmeans. It provides a `DatasetAssign` object and shows how to run it with torch.distributed.

Reviewed By: asadoughi, pankajsingh88

Differential Revision: D63013820
1 parent 866e3fe commit c96f0f4

9 files changed: +215 -43 lines

contrib/clustering.py (+16 -16)
@@ -155,7 +155,7 @@ def assign_to(self, centroids, weights=None):
         sum_per_centroid = np.zeros((nc, d), dtype='float32')
         if weights is None:
             np.add.at(sum_per_centroid, I, self.x)
-        else:
+        else:
             np.add.at(sum_per_centroid, I, weights[:, np.newaxis] * self.x)
 
         return I, D, sum_per_centroid
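
The - and + lines above differ only in trailing whitespace; the substantive code is the np.add.at scatter-add that accumulates per-centroid sums. A standalone sketch of that pattern (toy data, not part of the commit):

import numpy as np

I = np.array([0, 2, 0, 1])                        # nearest-centroid id per vector
x = np.ones((4, 3), dtype='float32')              # 4 vectors of dimension 3
sum_per_centroid = np.zeros((3, 3), dtype='float32')
np.add.at(sum_per_centroid, I, x)                 # row c accumulates every x[j] with I[j] == c
# sum_per_centroid[0] == x[0] + x[2]; repeated indices are accumulated, unlike plain fancy indexing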
@@ -183,7 +183,7 @@ def perform_search(self, centroids):
 
 def sparse_assign_to_dense(xq, xb, xq_norms=None, xb_norms=None):
     """ assignment function for xq is sparse, xb is dense
-    uses a matrix multiplication. The squared norms can be provided if
+    uses a matrix multiplication. The squared norms can be provided if
     available.
     """
     nq = xq.shape[0]
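
Another whitespace-only hunk. The matrix multiplication the docstring mentions is the usual expansion ||q - b||^2 = ||q||^2 + ||b||^2 - 2 q.b, which costs a single sparse-dense product; a hedged sketch with assumed shapes (not the library code):

import numpy as np
import scipy.sparse

xq = scipy.sparse.random(5, 32, density=0.1, format='csr', dtype='float32')
xb = np.random.rand(100, 32).astype('float32')

xq_norms = np.asarray(xq.multiply(xq).sum(1))     # (5, 1) squared norms, can be precomputed
xb_norms = (xb ** 2).sum(1)                       # (100,)
D = xq_norms + xb_norms - 2 * (xq @ xb.T)         # (5, 100) squared L2 distances
I = D.argmin(axis=1)                              # nearest dense vector per sparse query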
@@ -271,7 +271,7 @@ def assign_to(self, centroids, weights=None):
         if weights is None:
             weights = np.ones(n, dtype='float32')
         nc = len(centroids)
-
+
         m = scipy.sparse.csc_matrix(
             (weights, I, np.arange(n + 1)),
             shape=(nc, n))
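
The csc_matrix constructed above is an nc x n matrix with exactly one entry per column: column j holds weights[j] in row I[j], so one sparse product m @ x yields the weighted per-centroid sums. A small hedged illustration:

import numpy as np
import scipy.sparse

n, nc, d = 4, 3, 2
I = np.array([0, 2, 0, 1])                          # centroid assigned to each vector
weights = np.ones(n, dtype='float32')
x = np.arange(n * d, dtype='float32').reshape(n, d)

m = scipy.sparse.csc_matrix(
    (weights, I, np.arange(n + 1)), shape=(nc, n))  # indptr steps by 1: one entry per column
sum_per_centroid = np.asarray(m @ x)                # row c = weighted sum of vectors assigned to c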
@@ -289,7 +289,7 @@ def check_if_torch(x):
     if x.__class__ == np.ndarray:
         return False
     import torch
-    if isinstance(x, torch.Tensor):
+    if isinstance(x, torch.Tensor):
         return True
     raise NotImplementedError(f"Unknown tensor type {type(x)}")
 
@@ -307,11 +307,11 @@ def reassign_centroids(hassign, centroids, rs=None):
     if len(empty_cents) == 0:
         return 0
 
-    if is_torch:
+    if is_torch:
         import torch
-        fac = torch.ones_like(centroids[0])
-    else:
-        fac = np.ones_like(centroids[0])
+        fac = torch.ones_like(centroids[0])
+    else:
+        fac = np.ones_like(centroids[0])
     fac[::2] += 1 / 1024.
     fac[1::2] -= 1 / 1024.
 
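reassign_centroids deals with empty clusters by splitting a large one: the empty slot receives a copy of a populated centroid, and fac nudges alternating components up and down by 1/1024 so the two copies separate at the next assignment. A hedged illustration of the perturbation (the fac pattern is from the code above; how the two copies are paired is an assumption):

import numpy as np

centroid = np.random.rand(8).astype('float32')
fac = np.ones_like(centroid)
fac[::2] += 1 / 1024.          # even components slightly up
fac[1::2] -= 1 / 1024.         # odd components slightly down
split_a = centroid * fac       # e.g. fills the empty slot
split_b = centroid / fac       # hypothetical counterpart for the donor centroid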
@@ -347,9 +347,9 @@ def kmeans(k, data, niter=25, seed=1234, checkpoint=None, verbose=True,
                return_stats=False):
     """Pure python kmeans implementation. Follows the Faiss C++ version
     quite closely, but takes a DatasetAssign instead of a training data
-    matrix. Also redo is not implemented.
-
-    For the torch implementation, the centroids are tensors (possibly on GPU),
+    matrix. Also redo is not implemented.
+
+    For the torch implementation, the centroids are tensors (possibly on GPU),
     but the indices remain numpy on CPU.
     """
     n, d = data.count(), data.dim()
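
The signature shown in the hunk header is enough for a basic usage sketch of this pure-python kmeans (hedged; DatasetAssign is the plain numpy wrapper from the same module):

import numpy as np
from faiss.contrib.clustering import DatasetAssign, kmeans

x = np.random.rand(10000, 32).astype('float32')
centroids = kmeans(256, DatasetAssign(x), niter=25)
# with return_stats=True it returns (centroids, iteration_stats) instead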
@@ -382,15 +382,15 @@ def kmeans(k, data, niter=25, seed=1234, checkpoint=None, verbose=True,
         t_search_tot += time.time() - t0s;
 
         err = D.sum()
-        if is_torch:
+        if is_torch:
             err = err.item()
         obj.append(err)
 
         hassign = np.bincount(assign, minlength=k)
 
         fac = hassign.reshape(-1, 1).astype('float32')
         fac[fac == 0] = 1 # quiet warning
-        if is_torch:
+        if is_torch:
             import torch
             fac = torch.from_numpy(fac).to(sums.device)
 
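fac here holds the per-cluster counts used in the centroid update that follows (zeros are replaced by 1 so empty clusters do not divide by zero). Condensed, the update step looks like this (a sketch of the surrounding code, with toy inputs so it runs standalone):

import numpy as np

k, d = 4, 2
assign = np.array([0, 1, 1, 3, 0])                 # toy assignment; cluster 2 is empty
sums = np.random.rand(k, d).astype('float32')      # per-cluster sums from assign_to()

hassign = np.bincount(assign, minlength=k)         # cluster sizes: [2, 2, 0, 1]
fac = hassign.reshape(-1, 1).astype('float32')
fac[fac == 0] = 1                                  # quiet the divide warning
centroids = sums / fac                             # new centroid = mean of assigned vectors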
@@ -402,7 +402,7 @@ def kmeans(k, data, niter=25, seed=1234, checkpoint=None, verbose=True,
                 "obj": err,
                 "time": (time.time() - t0),
                 "time_search": t_search_tot,
-                "imbalance_factor": imbalance_factor (k, assign),
+                "imbalance_factor": imbalance_factor(k, assign),
                 "nsplit": nsplit
             }
 
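For reference, Faiss defines the imbalance factor for cluster sizes h_i with n = sum(h_i) as k * sum(h_i^2) / n^2, which equals 1.0 for perfectly balanced clusters and grows as they skew. A hedged numpy rendition (the stats dict above calls the library helper, not this sketch):

import numpy as np

def imbalance_factor_np(k, assign):
    h = np.bincount(assign, minlength=k).astype('float64')
    n = h.sum()
    return k * float((h * h).sum()) / (n * n)

print(imbalance_factor_np(2, np.array([0, 0, 1, 1])))   # 1.0, perfectly balanced
print(imbalance_factor_np(2, np.array([0, 0, 0, 1])))   # 1.25, mildly skewed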
@@ -416,10 +416,10 @@ def kmeans(k, data, niter=25, seed=1234, checkpoint=None, verbose=True,
 
         if checkpoint is not None:
             log('storing centroids in', checkpoint)
-            if is_torch:
+            if is_torch:
                 import torch
                 torch.save(centroids, checkpoint)
-            else:
+            else:
                 np.save(checkpoint, centroids)
 
     if return_stats:
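
A small practical note on reading the checkpoints back (an assumption about the caller, not code from the commit): torch checkpoints reload with torch.load, while np.save appends a .npy suffix unless the path already has one.

import numpy as np
import torch

centroids = torch.load("centroids.pt")       # hypothetical path, torch branch
centroids = np.load("centroids.npy")         # hypothetical path, numpy branch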

contrib/torch/clustering.py (+2 -3)
@@ -11,8 +11,7 @@
 import torch
 
 # the kmeans can produce both torch and numpy centroids
-from faiss.contrib.clustering import DatasetAssign, kmeans
-
+from faiss.contrib.clustering import kmeans
 
 class DatasetAssign:
     """Wrapper for a tensor that offers a function to assign the vectors
@@ -52,7 +51,7 @@ def assign_to(self, centroids, weights=None):
 
 class DatasetAssignGPU(DatasetAssign):
 
-    def __init__(self, res, x):
+    def __init__(self, res, x):
         DatasetAssign.__init__(self, x)
         self.res = res
 
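The torch module mirrors the numpy API; a hedged single-GPU usage sketch (class names from the diff, data shapes assumed):

import torch
import faiss
from faiss.contrib.torch import clustering

res = faiss.StandardGpuResources()
x = torch.rand(10000, 32, device="cuda")
da = clustering.DatasetAssignGPU(res, x)
centroids = clustering.kmeans(256, da, niter=25)   # centroids come back as torch tensors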
contrib/torch/quantization.py (+15 -15)
@@ -8,46 +8,46 @@
 """
 
 import numpy as np
-import torch
+import torch
 import faiss
 
 from faiss.contrib import torch_utils
 
 
-class Quantizer:
+class Quantizer:
 
-    def __init__(self, d, code_size):
-        self.d = d
+    def __init__(self, d, code_size):
+        self.d = d
         self.code_size = code_size
 
-    def train(self, x):
+    def train(self, x):
         pass
-
-    def encode(self, x):
+
+    def encode(self, x):
         pass
-
-    def decode(self, x):
+
+    def decode(self, x):
         pass
 
 
-class VectorQuantizer(Quantizer):
+class VectorQuantizer(Quantizer):
 
-    def __init__(self, d, k):
+    def __init__(self, d, k):
         code_size = int(torch.ceil(torch.log2(k) / 8))
         Quantizer.__init__(d, code_size)
         self.k = k
 
-    def train(self, x):
+    def train(self, x):
         pass
 
 
-class ProductQuantizer(Quantizer):
+class ProductQuantizer(Quantizer):
 
-    def __init__(self, d, M, nbits):
+    def __init__(self, d, M, nbits):
         code_size = int(torch.ceil(M * nbits / 8))
         Quantizer.__init__(d, code_size)
         self.M = M
         self.nbits = nbits
 
-    def train(self, x):
+    def train(self, x):
         pass
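
The hunk is again whitespace cleanup, but the code-size arithmetic it touches deserves a note: a vector quantizer with k centroids needs ceil(log2(k) / 8) bytes per code, and a product quantizer with M sub-quantizers of nbits each needs ceil(M * nbits / 8) bytes. A plain-python check (a sketch; the module itself uses torch.ceil):

import math

def vq_code_size(k):
    return math.ceil(math.log2(k) / 8)      # bytes to store one id out of k

def pq_code_size(M, nbits):
    return math.ceil(M * nbits / 8)         # bytes for M packed nbits-wide codes

print(vq_code_size(65536))    # 2
print(pq_code_size(8, 4))     # 4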

contrib/torch_utils.py (+3 -3)
@@ -73,7 +73,7 @@ def swig_ptr_from_IndicesTensor(x):
         x.untyped_storage().data_ptr() + x.storage_offset() * 8)
 
 ##################################################################
-# utilities
+# utilities
 ##################################################################
 
 @contextlib.contextmanager
@@ -519,7 +519,7 @@ def torch_replacement_knn(xq, xb, k, metric=faiss.METRIC_L2, metric_arg=0):
     assert xb.is_contiguous()
     assert xb.dtype == torch.float32
     assert not xb.is_cuda, "use knn_gpu for GPU tensors"
-
+
     nq, d2 = xq.size()
     assert d2 == d
     assert xq.is_contiguous()
@@ -543,7 +543,7 @@ def torch_replacement_knn(xq, xb, k, metric=faiss.METRIC_L2, metric_arg=0):
             xq_ptr, xb_ptr,
             d, nq, nb, k, D_ptr, I_ptr
         )
-    else:
+    else:
         faiss.knn_extra_metrics(
             xq_ptr, xb_ptr,
             d, nq, nb, metric, metric_arg, k, D_ptr, I_ptr
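
As the asserts suggest, importing faiss.contrib.torch_utils patches faiss.knn (among other functions) to accept contiguous float32 CPU torch tensors. A hedged sketch of a call:

import torch
import faiss
import faiss.contrib.torch_utils   # installs the torch replacements

xb = torch.rand(1000, 32)           # database vectors, float32, CPU
xq = torch.rand(10, 32)             # queries
D, I = faiss.knn(xq, xb, 5)         # CUDA tensors trip the assert; use faiss.knn_gpu instead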
(new file) (+173)

@@ -0,0 +1,173 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import numpy as np

import torch
import torch.distributed

import faiss

import faiss.contrib.torch_utils
from faiss.contrib.torch import clustering
from faiss.contrib import datasets


class DatasetAssignDistributedGPU(clustering.DatasetAssign):
    """
    There is one instance per worker; each worker has a dataset shard.
    The non-master workers do not run through the k-means function, so some
    code has to run on them to keep them in sync with the master.
    """

    def __init__(self, res, x, rank, nproc):
        clustering.DatasetAssign.__init__(self, x)
        self.res = res
        self.rank = rank
        self.nproc = nproc
        self.device = x.device

        n = len(x)
        sizes = torch.zeros(nproc, device=self.device, dtype=torch.int64)
        sizes[rank] = n
        torch.distributed.all_gather(
            [sizes[i:i + 1] for i in range(nproc)], sizes[rank:rank + 1])
        self.sizes = sizes.cpu().numpy()

        # begin & end of each shard
        self.cs = np.zeros(nproc + 1, dtype='int64')
        self.cs[1:] = np.cumsum(self.sizes)

    def count(self):
        return int(self.sizes.sum())

    def int_to_slaves(self, i):
        " broadcast an int to all workers "
        rank = self.rank
        tab = torch.zeros(1, device=self.device, dtype=torch.int64)
        if rank == 0:
            tab[0] = i
        else:
            assert i is None
        torch.distributed.broadcast(tab, 0)
        return tab.item()

    def get_subset(self, indices):
        rank = self.rank
        assert rank == 0 or indices is None

        len_indices = self.int_to_slaves(len(indices) if rank == 0 else None)

        if rank == 0:
            indices = torch.from_numpy(indices).to(self.device)
        else:
            indices = torch.zeros(
                len_indices, dtype=torch.int64, device=self.device)
        torch.distributed.broadcast(indices, 0)

        # select subset of indices

        i0, i1 = self.cs[rank], self.cs[rank + 1]

        mask = torch.logical_and(indices < i1, indices >= i0)
        output = torch.zeros(
            len_indices, self.x.shape[1],
            dtype=self.x.dtype, device=self.device)
        output[mask] = self.x[indices[mask] - i0]
        torch.distributed.reduce(output, 0)    # sum
        if rank == 0:
            return output
        else:
            return None

    def perform_search(self, centroids):
        assert False, "should not be called"

    def assign_to(self, centroids, weights=None):
        assert weights is None

        rank, nproc = self.rank, self.nproc
        assert rank == 0 or centroids is None
        nc = self.int_to_slaves(len(centroids) if rank == 0 else None)

        if rank != 0:
            centroids = torch.zeros(
                nc, self.x.shape[1], dtype=self.x.dtype, device=self.device)
        torch.distributed.broadcast(centroids, 0)

        # perform search
        D, I = faiss.knn_gpu(
            self.res, self.x, centroids, 1, device=self.device.index)

        I = I.ravel()
        D = D.ravel()

        sum_per_centroid = torch.zeros_like(centroids)
        if weights is None:
            sum_per_centroid.index_add_(0, I, self.x)
        else:
            sum_per_centroid.index_add_(0, I, self.x * weights[:, None])

        torch.distributed.reduce(sum_per_centroid, 0)

        if rank == 0:
            # gather does not support tensors of different sizes
            # should be implemented with point-to-point communication
            assert np.all(self.sizes == self.sizes[0])
            device = self.device
            all_I = torch.zeros(self.count(), dtype=I.dtype, device=device)
            all_D = torch.zeros(self.count(), dtype=D.dtype, device=device)
            torch.distributed.gather(
                I, [all_I[self.cs[r]:self.cs[r + 1]] for r in range(nproc)],
                dst=0,
            )
            torch.distributed.gather(
                D, [all_D[self.cs[r]:self.cs[r + 1]] for r in range(nproc)],
                dst=0,
            )
            return all_I.cpu().numpy(), all_D, sum_per_centroid
        else:
            torch.distributed.gather(I, None, dst=0)
            torch.distributed.gather(D, None, dst=0)
            return None


if __name__ == "__main__":

    torch.distributed.init_process_group(
        backend="nccl",
    )
    rank = torch.distributed.get_rank()
    nproc = torch.distributed.get_world_size()

    # the current version only supports shards of the same size
    ds = datasets.SyntheticDataset(32, 10000, 0, 0, seed=1234 + rank)
    x = ds.get_train()

    device = torch.device(f"cuda:{rank}")

    torch.cuda.set_device(device)
    x = torch.from_numpy(x).to(device)
    res = faiss.StandardGpuResources()

    da = DatasetAssignDistributedGPU(res, x, rank, nproc)

    k = 1000
    niter = 25

    if rank == 0:
        print(f"sizes = {da.sizes}")
        centroids, iteration_stats = clustering.kmeans(
            k, da, niter=niter, return_stats=True)
        print("clusters:", centroids.cpu().numpy())
    else:
        # make sure the iterations are aligned with the master
        da.get_subset(None)

        for _ in range(niter):
            da.assign_to(None)

    torch.distributed.barrier()
    print("Done")

faiss/gpu/test/torch_test_contrib_gpu.py (+1 -1)
@@ -501,5 +501,5 @@ def test_python_kmeans(self):
         err2 = faiss.knn(xt, centroids, 1)[0].sum()
 
         # 33498.332 33380.477
-        print(err, err2)
+        print(err, err2)
         self.assertLess(err2, err * 1.1)
