Skip to content

Commit 07fe2b6

Browse files
mdouze authored and facebook-github-bot committed
Binary cloning and GPU range search (facebookresearch#2916)
Summary:
Pull Request resolved: facebookresearch#2916

Overall better support for binary indexes:
- cloning (to CPU and GPU), only for BinaryFlat for now
- fix bug in reconstruct_n
- range_search_max_results

Reviewed By: algoriddle

Differential Revision: D46755778

fbshipit-source-id: 777ad90aff5c54a77f9685ed6512247a922c6ef5
1 parent e153cac commit 07fe2b6

22 files changed

+384
-116
lines changed

contrib/exhaustive_search.py

+29-10
Original file line numberDiff line numberDiff line change
@@ -60,12 +60,18 @@ def range_search_gpu(xq, r2, index_gpu, index_cpu, gpu_k=1024):
6060
- None. In that case, at most gpu_k results will be returned
6161
"""
6262
nq, d = xq.shape
63-
k = min(index_gpu.ntotal, gpu_k)
63+
is_binary_index = isinstance(index_gpu, faiss.IndexBinary)
6464
keep_max = faiss.is_similarity_metric(index_gpu.metric_type)
65-
LOG.debug(f"GPU search {nq} queries with {k=:}")
65+
r2 = int(r2) if is_binary_index else float(r2)
66+
k = min(index_gpu.ntotal, gpu_k)
67+
LOG.debug(
68+
f"GPU search {nq} queries with {k=:} {is_binary_index=:} {keep_max=:}")
6669
t0 = time.time()
6770
D, I = index_gpu.search(xq, k)
6871
t1 = time.time() - t0
72+
if is_binary_index:
73+
assert d * 8 < 32768 # let's compact the distance matrix
74+
D = D.astype('int16')
6975
t2 = 0
7076
lim_remain = None
7177
if index_cpu is not None:
@@ -79,14 +85,24 @@ def range_search_gpu(xq, r2, index_gpu, index_cpu, gpu_k=1024):
7985
if isinstance(index_cpu, np.ndarray):
8086
# then it in fact an array that we have to make flat
8187
xb = index_cpu
82-
index_cpu = faiss.IndexFlat(d, index_gpu.metric_type)
88+
if is_binary_index:
89+
index_cpu = faiss.IndexBinaryFlat(d * 8)
90+
else:
91+
index_cpu = faiss.IndexFlat(d, index_gpu.metric_type)
8392
index_cpu.add(xb)
8493
lim_remain, D_remain, I_remain = index_cpu.range_search(xq[mask], r2)
94+
if is_binary_index:
95+
D_remain = D_remain.astype('int16')
8596
t2 = time.time() - t0
8697
LOG.debug("combine")
8798
t0 = time.time()
8899

89-
combiner = faiss.CombinerRangeKNN(nq, k, float(r2), keep_max)
100+
CombinerRangeKNN = (
101+
faiss.CombinerRangeKNNint16 if is_binary_index else
102+
faiss.CombinerRangeKNNfloat
103+
)
104+
105+
combiner = CombinerRangeKNN(nq, k, r2, keep_max)
90106
if True:
91107
sp = faiss.swig_ptr
92108
combiner.I = sp(I)
@@ -101,7 +117,7 @@ def range_search_gpu(xq, r2, index_gpu, index_cpu, gpu_k=1024):
101117
L_res = np.empty(nq + 1, dtype='int64')
102118
combiner.compute_sizes(sp(L_res))
103119
nres = L_res[-1]
104-
D_res = np.empty(nres, dtype='float32')
120+
D_res = np.empty(nres, dtype=D.dtype)
105121
I_res = np.empty(nres, dtype='int64')
106122
combiner.write_result(sp(D_res), sp(I_res))
107123
else:
@@ -251,6 +267,7 @@ def range_search_max_results(index, query_iterator, radius,
251267
"""
252268
# TODO: all result manipulations are in python, should move to C++ if perf
253269
# critical
270+
is_binary_index = isinstance(index, faiss.IndexBinary)
254271

255272
if min_results is None:
256273
assert max_results is not None
@@ -268,6 +285,8 @@ def range_search_max_results(index, query_iterator, radius,
268285
co = faiss.GpuMultipleClonerOptions()
269286
co.shard = shard
270287
index_gpu = faiss.index_cpu_to_all_gpus(index, co=co, ngpu=ngpu)
288+
else:
289+
index_gpu = None
271290

272291
t_start = time.time()
273292
t_search = t_post_process = 0
@@ -276,7 +295,8 @@ def range_search_max_results(index, query_iterator, radius,
276295

277296
for xqi in query_iterator:
278297
t0 = time.time()
279-
if ngpu > 0:
298+
LOG.debug(f"searching {len(xqi)} vectors")
299+
if index_gpu:
280300
lims_i, Di, Ii = range_search_gpu(xqi, radius, index_gpu, index)
281301
else:
282302
lims_i, Di, Ii = index.range_search(xqi, radius)
@@ -286,8 +306,7 @@ def range_search_max_results(index, query_iterator, radius,
286306
qtot += len(xqi)
287307

288308
t1 = time.time()
289-
if xqi.dtype != np.float32:
290-
# for binary indexes
309+
if is_binary_index:
291310
# weird Faiss quirk that returns floats for Hamming distances
292311
Di = Di.astype('int16')
293312

@@ -299,7 +318,7 @@ def range_search_max_results(index, query_iterator, radius,
299318
(totres, max_results))
300319
radius, totres = apply_maxres(
301320
res_batches, min_results,
302-
keep_max=faiss.is_similarity_metric(index.metric_type)
321+
keep_max=index.metric_type == faiss.METRIC_INNER_PRODUCT
303322
)
304323
t2 = time.time()
305324
t_search += t1 - t0
@@ -315,7 +334,7 @@ def range_search_max_results(index, query_iterator, radius,
315334
if clip_to_min and totres > min_results:
316335
radius, totres = apply_maxres(
317336
res_batches, min_results,
318-
keep_max=faiss.is_similarity_metric(index.metric_type)
337+
keep_max=index.metric_type == faiss.METRIC_INNER_PRODUCT
319338
)
320339

321340
nres = np.hstack([nres_i for nres_i, dis_i, ids_i in res_batches])

faiss/IndexBinary.cpp

+8-3
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,11 @@
1515

1616
namespace faiss {
1717

18+
IndexBinary::IndexBinary(idx_t d, MetricType metric)
19+
: d(d), code_size(d / 8), metric_type(metric) {
20+
FAISS_THROW_IF_NOT(d % 8 == 0);
21+
}
22+
1823
IndexBinary::~IndexBinary() {}
1924

2025
void IndexBinary::train(idx_t, const uint8_t*) {
@@ -51,7 +56,7 @@ void IndexBinary::reconstruct(idx_t, uint8_t*) const {
5156

5257
void IndexBinary::reconstruct_n(idx_t i0, idx_t ni, uint8_t* recons) const {
5358
for (idx_t i = 0; i < ni; i++) {
54-
reconstruct(i0 + i, recons + i * d);
59+
reconstruct(i0 + i, recons + i * code_size);
5560
}
5661
}
5762

@@ -70,10 +75,10 @@ void IndexBinary::search_and_reconstruct(
7075
for (idx_t j = 0; j < k; ++j) {
7176
idx_t ij = i * k + j;
7277
idx_t key = labels[ij];
73-
uint8_t* reconstructed = recons + ij * d;
78+
uint8_t* reconstructed = recons + ij * code_size;
7479
if (key < 0) {
7580
// Fill with NaNs
76-
memset(reconstructed, -1, sizeof(*reconstructed) * d);
81+
memset(reconstructed, -1, code_size);
7782
} else {
7883
reconstruct(key, reconstructed);
7984
}

faiss/IndexBinary.h

+8-19
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,6 @@
55
* LICENSE file in the root directory of this source tree.
66
*/
77

8-
// -*- c++ -*-
9-
108
#ifndef FAISS_INDEX_BINARY_H
119
#define FAISS_INDEX_BINARY_H
1210

@@ -16,7 +14,6 @@
1614
#include <typeinfo>
1715

1816
#include <faiss/Index.h>
19-
#include <faiss/impl/FaissAssert.h>
2017

2118
namespace faiss {
2219

@@ -35,27 +32,19 @@ struct IndexBinary {
3532
using component_t = uint8_t;
3633
using distance_t = int32_t;
3734

38-
int d; ///< vector dimension
39-
int code_size; ///< number of bytes per vector ( = d / 8 )
40-
idx_t ntotal; ///< total nb of indexed vectors
41-
bool verbose; ///< verbosity level
35+
int d = 0; ///< vector dimension
36+
int code_size = 0; ///< number of bytes per vector ( = d / 8 )
37+
idx_t ntotal = 0; ///< total nb of indexed vectors
38+
bool verbose = false; ///< verbosity level
4239

4340
/// set if the Index does not require training, or if training is done
4441
/// already
45-
bool is_trained;
42+
bool is_trained = true;
4643

4744
/// type of metric this index uses for search
48-
MetricType metric_type;
49-
50-
explicit IndexBinary(idx_t d = 0, MetricType metric = METRIC_L2)
51-
: d(d),
52-
code_size(d / 8),
53-
ntotal(0),
54-
verbose(false),
55-
is_trained(true),
56-
metric_type(metric) {
57-
FAISS_THROW_IF_NOT(d % 8 == 0);
58-
}
45+
MetricType metric_type = METRIC_L2;
46+
47+
explicit IndexBinary(idx_t d = 0, MetricType metric = METRIC_L2);
5948

6049
virtual ~IndexBinary();
6150

faiss/IndexBinaryFromFloat.cpp

+1
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99

1010
#include <faiss/IndexBinaryFromFloat.h>
1111

12+
#include <faiss/impl/FaissAssert.h>
1213
#include <faiss/utils/utils.h>
1314
#include <algorithm>
1415
#include <memory>

faiss/IndexIDMap.cpp

+15-2
Original file line numberDiff line numberDiff line change
@@ -21,18 +21,31 @@
2121

2222
namespace faiss {
2323

24+
namespace {
25+
26+
// IndexBinary needs to update the code_size when d is set...
27+
28+
void sync_d(Index* index) {}
29+
30+
void sync_d(IndexBinary* index) {
31+
FAISS_THROW_IF_NOT(index->d % 8 == 0);
32+
index->code_size = index->d / 8;
33+
}
34+
35+
} // anonymous namespace
36+
2437
/*****************************************************
2538
* IndexIDMap implementation
2639
*******************************************************/
2740

2841
template <typename IndexT>
29-
IndexIDMapTemplate<IndexT>::IndexIDMapTemplate(IndexT* index)
30-
: index(index), own_fields(false) {
42+
IndexIDMapTemplate<IndexT>::IndexIDMapTemplate(IndexT* index) : index(index) {
3143
FAISS_THROW_IF_NOT_MSG(index->ntotal == 0, "index must be empty on input");
3244
this->is_trained = index->is_trained;
3345
this->metric_type = index->metric_type;
3446
this->verbose = index->verbose;
3547
this->d = index->d;
48+
sync_d(this);
3649
}
3750

3851
template <typename IndexT>

faiss/IndexIDMap.h

+2-2
Original file line numberDiff line numberDiff line change
@@ -22,8 +22,8 @@ struct IndexIDMapTemplate : IndexT {
2222
using component_t = typename IndexT::component_t;
2323
using distance_t = typename IndexT::distance_t;
2424

25-
IndexT* index; ///! the sub-index
26-
bool own_fields; ///! whether pointers are deleted in destructo
25+
IndexT* index = nullptr; ///! the sub-index
26+
bool own_fields = false; ///! whether pointers are deleted in destructo
2727
std::vector<idx_t> id_map;
2828

2929
explicit IndexIDMapTemplate(IndexT* index);

faiss/IndexReplicas.cpp

+21-24
Original file line numberDiff line numberDiff line change
@@ -12,17 +12,34 @@
1212

1313
namespace faiss {
1414

15+
namespace {
16+
17+
// IndexBinary needs to update the code_size when d is set...
18+
19+
void sync_d(Index* index) {}
20+
21+
void sync_d(IndexBinary* index) {
22+
FAISS_THROW_IF_NOT(index->d % 8 == 0);
23+
index->code_size = index->d / 8;
24+
}
25+
26+
} // anonymous namespace
27+
1528
template <typename IndexT>
1629
IndexReplicasTemplate<IndexT>::IndexReplicasTemplate(bool threaded)
1730
: ThreadedIndex<IndexT>(threaded) {}
1831

1932
template <typename IndexT>
2033
IndexReplicasTemplate<IndexT>::IndexReplicasTemplate(idx_t d, bool threaded)
21-
: ThreadedIndex<IndexT>(d, threaded) {}
34+
: ThreadedIndex<IndexT>(d, threaded) {
35+
sync_d(this);
36+
}
2237

2338
template <typename IndexT>
2439
IndexReplicasTemplate<IndexT>::IndexReplicasTemplate(int d, bool threaded)
25-
: ThreadedIndex<IndexT>(d, threaded) {}
40+
: ThreadedIndex<IndexT>(d, threaded) {
41+
sync_d(this);
42+
}
2643

2744
template <typename IndexT>
2845
void IndexReplicasTemplate<IndexT>::onAfterAddIndex(IndexT* index) {
@@ -168,6 +185,8 @@ void IndexReplicasTemplate<IndexT>::syncWithSubIndexes() {
168185
}
169186

170187
auto firstIndex = this->at(0);
188+
this->d = firstIndex->d;
189+
sync_d(this);
171190
this->metric_type = firstIndex->metric_type;
172191
this->is_trained = firstIndex->is_trained;
173192
this->ntotal = firstIndex->ntotal;
@@ -181,28 +200,6 @@ void IndexReplicasTemplate<IndexT>::syncWithSubIndexes() {
181200
}
182201
}
183202

184-
// No metric_type for IndexBinary
185-
template <>
186-
void IndexReplicasTemplate<IndexBinary>::syncWithSubIndexes() {
187-
if (!this->count()) {
188-
this->is_trained = false;
189-
this->ntotal = 0;
190-
191-
return;
192-
}
193-
194-
auto firstIndex = this->at(0);
195-
this->is_trained = firstIndex->is_trained;
196-
this->ntotal = firstIndex->ntotal;
197-
198-
for (int i = 1; i < this->count(); ++i) {
199-
auto index = this->at(i);
200-
FAISS_THROW_IF_NOT(this->d == index->d);
201-
FAISS_THROW_IF_NOT(this->is_trained == index->is_trained);
202-
FAISS_THROW_IF_NOT(this->ntotal == index->ntotal);
203-
}
204-
}
205-
206203
// explicit instantiations
207204
template struct IndexReplicasTemplate<Index>;
208205
template struct IndexReplicasTemplate<IndexBinary>;

0 commit comments

Comments
 (0)