@@ -60,12 +60,18 @@ def range_search_gpu(xq, r2, index_gpu, index_cpu, gpu_k=1024):
60
60
- None. In that case, at most gpu_k results will be returned
61
61
"""
62
62
nq , d = xq .shape
63
- k = min (index_gpu . ntotal , gpu_k )
63
+ is_binary_index = isinstance (index_gpu , faiss . IndexBinary )
64
64
keep_max = faiss .is_similarity_metric (index_gpu .metric_type )
65
- LOG .debug (f"GPU search { nq } queries with { k = :} " )
65
+ r2 = int (r2 ) if is_binary_index else float (r2 )
66
+ k = min (index_gpu .ntotal , gpu_k )
67
+ LOG .debug (
68
+ f"GPU search { nq } queries with { k = :} { is_binary_index = :} { keep_max = :} " )
66
69
t0 = time .time ()
67
70
D , I = index_gpu .search (xq , k )
68
71
t1 = time .time () - t0
72
+ if is_binary_index :
73
+ assert d * 8 < 32768 # let's compact the distance matrix
74
+ D = D .astype ('int16' )
69
75
t2 = 0
70
76
lim_remain = None
71
77
if index_cpu is not None :
@@ -79,14 +85,24 @@ def range_search_gpu(xq, r2, index_gpu, index_cpu, gpu_k=1024):
79
85
if isinstance (index_cpu , np .ndarray ):
80
86
# then it in fact an array that we have to make flat
81
87
xb = index_cpu
82
- index_cpu = faiss .IndexFlat (d , index_gpu .metric_type )
88
+ if is_binary_index :
89
+ index_cpu = faiss .IndexBinaryFlat (d * 8 )
90
+ else :
91
+ index_cpu = faiss .IndexFlat (d , index_gpu .metric_type )
83
92
index_cpu .add (xb )
84
93
lim_remain , D_remain , I_remain = index_cpu .range_search (xq [mask ], r2 )
94
+ if is_binary_index :
95
+ D_remain = D_remain .astype ('int16' )
85
96
t2 = time .time () - t0
86
97
LOG .debug ("combine" )
87
98
t0 = time .time ()
88
99
89
- combiner = faiss .CombinerRangeKNN (nq , k , float (r2 ), keep_max )
100
+ CombinerRangeKNN = (
101
+ faiss .CombinerRangeKNNint16 if is_binary_index else
102
+ faiss .CombinerRangeKNNfloat
103
+ )
104
+
105
+ combiner = CombinerRangeKNN (nq , k , r2 , keep_max )
90
106
if True :
91
107
sp = faiss .swig_ptr
92
108
combiner .I = sp (I )
@@ -101,7 +117,7 @@ def range_search_gpu(xq, r2, index_gpu, index_cpu, gpu_k=1024):
101
117
L_res = np .empty (nq + 1 , dtype = 'int64' )
102
118
combiner .compute_sizes (sp (L_res ))
103
119
nres = L_res [- 1 ]
104
- D_res = np .empty (nres , dtype = 'float32' )
120
+ D_res = np .empty (nres , dtype = D . dtype )
105
121
I_res = np .empty (nres , dtype = 'int64' )
106
122
combiner .write_result (sp (D_res ), sp (I_res ))
107
123
else :
@@ -251,6 +267,7 @@ def range_search_max_results(index, query_iterator, radius,
251
267
"""
252
268
# TODO: all result manipulations are in python, should move to C++ if perf
253
269
# critical
270
+ is_binary_index = isinstance (index , faiss .IndexBinary )
254
271
255
272
if min_results is None :
256
273
assert max_results is not None
@@ -268,6 +285,8 @@ def range_search_max_results(index, query_iterator, radius,
268
285
co = faiss .GpuMultipleClonerOptions ()
269
286
co .shard = shard
270
287
index_gpu = faiss .index_cpu_to_all_gpus (index , co = co , ngpu = ngpu )
288
+ else :
289
+ index_gpu = None
271
290
272
291
t_start = time .time ()
273
292
t_search = t_post_process = 0
@@ -276,7 +295,8 @@ def range_search_max_results(index, query_iterator, radius,
276
295
277
296
for xqi in query_iterator :
278
297
t0 = time .time ()
279
- if ngpu > 0 :
298
+ LOG .debug (f"searching { len (xqi )} vectors" )
299
+ if index_gpu :
280
300
lims_i , Di , Ii = range_search_gpu (xqi , radius , index_gpu , index )
281
301
else :
282
302
lims_i , Di , Ii = index .range_search (xqi , radius )
@@ -286,8 +306,7 @@ def range_search_max_results(index, query_iterator, radius,
286
306
qtot += len (xqi )
287
307
288
308
t1 = time .time ()
289
- if xqi .dtype != np .float32 :
290
- # for binary indexes
309
+ if is_binary_index :
291
310
# weird Faiss quirk that returns floats for Hamming distances
292
311
Di = Di .astype ('int16' )
293
312
@@ -299,7 +318,7 @@ def range_search_max_results(index, query_iterator, radius,
299
318
(totres , max_results ))
300
319
radius , totres = apply_maxres (
301
320
res_batches , min_results ,
302
- keep_max = faiss . is_similarity_metric ( index .metric_type )
321
+ keep_max = index .metric_type == faiss . METRIC_INNER_PRODUCT
303
322
)
304
323
t2 = time .time ()
305
324
t_search += t1 - t0
@@ -315,7 +334,7 @@ def range_search_max_results(index, query_iterator, radius,
315
334
if clip_to_min and totres > min_results :
316
335
radius , totres = apply_maxres (
317
336
res_batches , min_results ,
318
- keep_max = faiss . is_similarity_metric ( index .metric_type )
337
+ keep_max = index .metric_type == faiss . METRIC_INNER_PRODUCT
319
338
)
320
339
321
340
nres = np .hstack ([nres_i for nres_i , dis_i , ids_i in res_batches ])
0 commit comments