@@ -107,17 +107,15 @@ def randn(n, seed=12345):
107
107
def checksum(a):
    """ compute a checksum for quick-and-dirty comparisons of arrays

    The input is first reinterpreted as raw bytes with view('uint8').

    - 1D input: returns whatever bvec_checksum computes over all the bytes.
    - 2D input of shape (n, d): returns a uint64 array of size n that is
      passed to bvecs_checksum as its output buffer (presumably one
      checksum per row -- confirm against the C++ implementation).
    """
    a = a.view('uint8')
    if a.ndim == 1:
        # the byte count must come from `a` itself; no other array is in scope
        return bvec_checksum(a.size, swig_ptr(a))
    n, d = a.shape
    cs = np.zeros(n, dtype='uint64')
    bvecs_checksum(n, d, swig_ptr(a), swig_ptr(cs))
    return cs
116
116
117
-
118
117
# Keep a reference to the current rand_smooth_vectors under a second name:
# the def that follows rebinds rand_smooth_vectors and calls this alias.
rand_smooth_vectors_c = rand_smooth_vectors
119
118
120
-
121
119
def rand_smooth_vectors (n , d , seed = 1234 ):
122
120
res = np .empty ((n , d ), dtype = 'float32' )
123
121
rand_smooth_vectors_c (n , d , swig_ptr (res ), seed )
@@ -422,7 +420,7 @@ def __init__(self, d, k, **kwargs):
422
420
including niter=25, verbose=False, spherical = False
423
421
"""
424
422
self .d = d
425
- self .k = k
423
+ self .reset ( k )
426
424
self .gpu = False
427
425
if "progressive_dim_steps" in kwargs :
428
426
self .cp = ProgressiveDimClusteringParameters ()
@@ -437,7 +435,32 @@ def __init__(self, d, k, **kwargs):
437
435
# if this raises an exception, it means that it is a non-existent field
438
436
getattr (self .cp , k )
439
437
setattr (self .cp , k , v )
438
+ self .set_index ()
439
+
440
def set_index(self):
    """ Prepare the object used for assignment during training.

    With plain ClusteringParameters, self.index is set to a flat index
    (inner product if self.cp.spherical, L2 otherwise), and replaced by
    its multi-GPU version when self.gpu is set.
    With any other parameter class, self.fac is set to a progressive-dim
    index factory (the GPU variant when self.gpu is set).
    """
    dim = self.d
    plain_params = self.cp.__class__ == ClusteringParameters

    if not plain_params:
        # progressive-dim clustering builds its indexes through a factory
        self.fac = (
            GpuProgressiveDimIndexFactory(ngpu=self.gpu)
            if self.gpu
            else ProgressiveDimIndexFactory()
        )
        return

    self.index = IndexFlatIP(dim) if self.cp.spherical else IndexFlatL2(dim)
    if self.gpu:
        # self.gpu doubles as the on/off switch and the ngpu argument
        self.index = faiss.index_cpu_to_all_gpus(self.index, ngpu=self.gpu)
455
+
456
def reset(self, k=None):
    """ Clear the results of a previous clustering so the object can be
    trained again.

    k: optional new number of centroids, stored as int(k); when None the
       current self.k is left untouched.
    """
    if k is not None:
        self.k = int(k)
    # drop everything produced by an earlier run
    self.centroids = self.obj = self.iteration_stats = None
441
464
442
465
def train (self , x , weights = None , init_centroids = None ):
443
466
""" Perform k-means clustering.
@@ -476,24 +499,14 @@ def train(self, x, weights=None, init_centroids=None):
476
499
nc , d2 = init_centroids .shape
477
500
assert d2 == d
478
501
faiss .copy_array_to_vector (init_centroids .ravel (), clus .centroids )
479
- if self .cp .spherical :
480
- self .index = IndexFlatIP (d )
481
- else :
482
- self .index = IndexFlatL2 (d )
483
- if self .gpu :
484
- self .index = faiss .index_cpu_to_all_gpus (self .index , ngpu = self .gpu )
485
502
clus .train (x , self .index , weights )
486
503
else :
487
504
# not supported for progressive dim
488
505
assert weights is None
489
506
assert init_centroids is None
490
507
assert not self .cp .spherical
491
508
clus = ProgressiveDimClustering (d , self .k , self .cp )
492
- if self .gpu :
493
- fac = GpuProgressiveDimIndexFactory (ngpu = self .gpu )
494
- else :
495
- fac = ProgressiveDimIndexFactory ()
496
- clus .train (n , swig_ptr (x ), fac )
509
+ clus .train (n , swig_ptr (x ), self .fac )
497
510
498
511
centroids = faiss .vector_float_to_array (clus .centroids )
499
512
0 commit comments