Skip to content

Commit a27036a

Browse files
mdouzefacebook-github-bot
authored andcommitted
add small benchmark for hamming computers
Summary: to measure impact of hamming computer diff Reviewed By: algoriddle Differential Revision: D46913890 fbshipit-source-id: 7b9850205885b9b7c5f394f17a79ba222e7b1e2e
1 parent 391601d commit a27036a

File tree

3 files changed

+82
-3
lines changed

3 files changed

+82
-3
lines changed

benchs/bench_hamming_knn.py

+29
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
# Copyright (c) Facebook, Inc. and its affiliates.
2+
#
3+
# This source code is licensed under the MIT license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
6+
import time
7+
import numpy as np
8+
import faiss
9+
10+
if __name__ == "__main__":
11+
faiss.omp_set_num_threads(1)
12+
13+
for d in 4, 8, 16, 13:
14+
nq = 10000
15+
nb = 30000
16+
print('Bits per vector = 8 *', d)
17+
xq = faiss.randint((nq, d // 4), seed=1234, vmax=256**4).view('uint8')
18+
xb = faiss.randint((nb, d // 4), seed=1234, vmax=256**4).view('uint8')
19+
for variant in "hc", "mc":
20+
print(f"{variant=:}", end="\t")
21+
for k in 1, 4, 16, 64, 256:
22+
times = []
23+
for _run in range(5):
24+
t0 = time.time()
25+
D, I = faiss.knn_hamming(xq, xb, k, variant=variant)
26+
t1 = time.time()
27+
times.append(t1 - t0)
28+
print(f'| {k=:} t={np.mean(times):.3f} s ± {np.std(times):.3f} ', flush=True, end="")
29+
print()

faiss/python/__init__.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
from faiss.extra_wrappers import kmin, kmax, pairwise_distances, rand, randint, \
2323
lrand, randn, rand_smooth_vectors, eval_intersection, normalize_L2, \
2424
ResultHeap, knn, Kmeans, checksum, matrix_bucket_sort_inplace, bucket_sort, \
25-
merge_knn_results, MapInt64ToInt64
25+
merge_knn_results, MapInt64ToInt64, knn_hamming
2626

2727

2828
__version__ = "%d.%d.%d" % (FAISS_VERSION_MAJOR,

faiss/python/extra_wrappers.py

+52-2
Original file line numberDiff line numberDiff line change
@@ -335,10 +335,10 @@ def knn(xq, xb, k, metric=METRIC_L2):
335335
Parameters
336336
----------
337337
xq : array_like
338-
Query vectors, shape (nq, d) where d is appropriate for the index.
338+
Query vectors, shape (nq, d) where the dimension d is that same as xb
339339
`dtype` must be float32.
340340
xb : array_like
341-
Database vectors, shape (nb, d) where d is appropriate for the index.
341+
Database vectors, shape (nb, d) where dimension d is the same as xq
342342
`dtype` must be float32.
343343
k : int
344344
Number of nearest neighbors.
@@ -375,6 +375,56 @@ def knn(xq, xb, k, metric=METRIC_L2):
375375
raise NotImplementedError("only L2 and INNER_PRODUCT are supported")
376376
return D, I
377377

378+
def knn_hamming(xq, xb, k, variant="hc"):
379+
"""
380+
Compute the k nearest neighbors of a set of vectors without constructing an index.
381+
382+
Parameters
383+
----------
384+
xq : array_like
385+
Query vectors, shape (nq, d) where d is the number of bits / 8
386+
`dtype` must be uint8.
387+
xb : array_like
388+
Database vectors, shape (nb, d) where d is the number of bits / 8
389+
`dtype` must be uint8.
390+
k : int
391+
Number of nearest neighbors.
392+
variant : string
393+
Function variant to use, either "mc" (counter) or "hc" (heap)
394+
395+
Returns
396+
-------
397+
D : array_like
398+
Distances of the nearest neighbors, shape (nq, k)
399+
I : array_like
400+
Labels of the nearest neighbors, shape (nq, k)
401+
"""
402+
# other variant is "mc"
403+
nq, d = xq.shape
404+
nb, d2 = xb.shape
405+
assert d == d2
406+
D = np.empty((nq, k), dtype='int32')
407+
I = np.empty((nq, k), dtype='int64')
408+
409+
if variant == "hc":
410+
heap = faiss.int_maxheap_array_t()
411+
heap.k = k
412+
heap.nh = nq
413+
heap.ids = faiss.swig_ptr(I)
414+
heap.val = faiss.swig_ptr(D)
415+
faiss.hammings_knn_hc(
416+
heap, faiss.swig_ptr(xq), faiss.swig_ptr(xb), nb,
417+
d, 1
418+
)
419+
elif variant == "mc":
420+
faiss.hammings_knn_mc(
421+
faiss.swig_ptr(xq), faiss.swig_ptr(xb), nq, nb, k, d,
422+
faiss.swig_ptr(D), faiss.swig_ptr(I)
423+
)
424+
else:
425+
raise NotImplementedError
426+
return D, I
427+
378428

379429
###########################################
380430
# Kmeans object

0 commit comments

Comments
 (0)