Skip to content

Commit a91a288

Browse files
mdouzefacebook-github-bot
authored and committed
use dispatcher function to call HammingComputer (facebookresearch#2918)
Summary: Pull Request resolved: facebookresearch#2918 The HammingComputer class is optimized for several vector sizes. So far it has been the caller's responsibility to instantiate the relevant optimized version. This diff introduces a `dispatch_HammingComputer` function that can be called with a template class that is instantiated for all existing optimized HammingComputers. Reviewed By: algoriddle Differential Revision: D46858553 fbshipit-source-id: 32c31689bba7c0b406b309fc8574c95fa24022ba
1 parent a27036a commit a91a288

13 files changed

+365
-686
lines changed

benchs/bench_hamming_computer.cpp

+60
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,66 @@
1818

1919
using namespace faiss;
2020

21+
// These implementations are currently slower than HammingComputerDefault so
22+
// they are not in the main faiss anymore.
23+
struct HammingComputerM8 {
24+
const uint64_t* a;
25+
int n;
26+
27+
HammingComputerM8() {}
28+
29+
HammingComputerM8(const uint8_t* a8, int code_size) {
30+
set(a8, code_size);
31+
}
32+
33+
void set(const uint8_t* a8, int code_size) {
34+
assert(code_size % 8 == 0);
35+
a = (uint64_t*)a8;
36+
n = code_size / 8;
37+
}
38+
39+
int hamming(const uint8_t* b8) const {
40+
const uint64_t* b = (uint64_t*)b8;
41+
int accu = 0;
42+
for (int i = 0; i < n; i++)
43+
accu += popcount64(a[i] ^ b[i]);
44+
return accu;
45+
}
46+
47+
inline int get_code_size() const {
48+
return n * 8;
49+
}
50+
};
51+
52+
struct HammingComputerM4 {
53+
const uint32_t* a;
54+
int n;
55+
56+
HammingComputerM4() {}
57+
58+
HammingComputerM4(const uint8_t* a4, int code_size) {
59+
set(a4, code_size);
60+
}
61+
62+
void set(const uint8_t* a4, int code_size) {
63+
assert(code_size % 4 == 0);
64+
a = (uint32_t*)a4;
65+
n = code_size / 4;
66+
}
67+
68+
int hamming(const uint8_t* b8) const {
69+
const uint32_t* b = (uint32_t*)b8;
70+
int accu = 0;
71+
for (int i = 0; i < n; i++)
72+
accu += popcount64(a[i] ^ b[i]);
73+
return accu;
74+
}
75+
76+
inline int get_code_size() const {
77+
return n * 4;
78+
}
79+
};
80+
2181
template <class T>
2282
void hamming_cpt_test(
2383
int code_size,

faiss/IndexBinaryHNSW.cpp

+10-20
Original file line numberDiff line numberDiff line change
@@ -281,31 +281,21 @@ struct FlatHammingDis : DistanceComputer {
281281
}
282282
};
283283

284+
struct BuildDistanceComputer {
285+
using T = DistanceComputer*;
286+
template <class HammingComputer>
287+
DistanceComputer* f(IndexBinaryFlat* flat_storage) {
288+
return new FlatHammingDis<HammingComputer>(*flat_storage);
289+
}
290+
};
291+
284292
} // namespace
285293

286294
DistanceComputer* IndexBinaryHNSW::get_distance_computer() const {
287295
IndexBinaryFlat* flat_storage = dynamic_cast<IndexBinaryFlat*>(storage);
288-
289296
FAISS_ASSERT(flat_storage != nullptr);
290-
291-
switch (code_size) {
292-
case 4:
293-
return new FlatHammingDis<HammingComputer4>(*flat_storage);
294-
case 8:
295-
return new FlatHammingDis<HammingComputer8>(*flat_storage);
296-
case 16:
297-
return new FlatHammingDis<HammingComputer16>(*flat_storage);
298-
case 20:
299-
return new FlatHammingDis<HammingComputer20>(*flat_storage);
300-
case 32:
301-
return new FlatHammingDis<HammingComputer32>(*flat_storage);
302-
case 64:
303-
return new FlatHammingDis<HammingComputer64>(*flat_storage);
304-
default:
305-
break;
306-
}
307-
308-
return new FlatHammingDis<HammingComputerDefault>(*flat_storage);
297+
BuildDistanceComputer bd;
298+
return dispatch_HammingComputer(code_size, bd, flat_storage);
309299
}
310300

311301
} // namespace faiss

faiss/IndexBinaryHash.cpp

+25-49
Original file line numberDiff line numberDiff line change
@@ -176,6 +176,14 @@ void search_single_query_template(
176176
} while (fe.next());
177177
}
178178

179+
struct Run_search_single_query {
180+
using T = void;
181+
template <class HammingComputer, class... Types>
182+
T f(Types... args) {
183+
search_single_query_template<HammingComputer>(args...);
184+
}
185+
};
186+
179187
template <class SearchResults>
180188
void search_single_query(
181189
const IndexBinaryHash& index,
@@ -184,29 +192,9 @@ void search_single_query(
184192
size_t& n0,
185193
size_t& nlist,
186194
size_t& ndis) {
187-
#define HC(name) \
188-
search_single_query_template<name>(index, q, res, n0, nlist, ndis);
189-
switch (index.code_size) {
190-
case 4:
191-
HC(HammingComputer4);
192-
break;
193-
case 8:
194-
HC(HammingComputer8);
195-
break;
196-
case 16:
197-
HC(HammingComputer16);
198-
break;
199-
case 20:
200-
HC(HammingComputer20);
201-
break;
202-
case 32:
203-
HC(HammingComputer32);
204-
break;
205-
default:
206-
HC(HammingComputerDefault);
207-
break;
208-
}
209-
#undef HC
195+
Run_search_single_query r;
196+
dispatch_HammingComputer(
197+
index.code_size, r, index, q, res, n0, nlist, ndis);
210198
}
211199

212200
} // anonymous namespace
@@ -349,22 +337,30 @@ namespace {
349337

350338
template <class HammingComputer, class SearchResults>
351339
static void verify_shortlist(
352-
const IndexBinaryFlat& index,
340+
const IndexBinaryFlat* index,
353341
const uint8_t* q,
354342
const std::unordered_set<idx_t>& shortlist,
355343
SearchResults& res) {
356-
size_t code_size = index.code_size;
344+
size_t code_size = index->code_size;
357345
size_t nlist = 0, ndis = 0, n0 = 0;
358346

359347
HammingComputer hc(q, code_size);
360-
const uint8_t* codes = index.xb.data();
348+
const uint8_t* codes = index->xb.data();
361349

362350
for (auto i : shortlist) {
363351
int dis = hc.hamming(codes + i * code_size);
364352
res.add(dis, i);
365353
}
366354
}
367355

356+
struct Run_verify_shortlist {
357+
using T = void;
358+
template <class HammingComputer, class... Types>
359+
void f(Types... args) {
360+
verify_shortlist<HammingComputer>(args...);
361+
}
362+
};
363+
368364
template <class SearchResults>
369365
void search_1_query_multihash(
370366
const IndexBinaryMultiHash& index,
@@ -405,29 +401,9 @@ void search_1_query_multihash(
405401
ndis += shortlist.size();
406402

407403
// verify shortlist
408-
409-
#define HC(name) verify_shortlist<name>(*index.storage, xi, shortlist, res)
410-
switch (index.code_size) {
411-
case 4:
412-
HC(HammingComputer4);
413-
break;
414-
case 8:
415-
HC(HammingComputer8);
416-
break;
417-
case 16:
418-
HC(HammingComputer16);
419-
break;
420-
case 20:
421-
HC(HammingComputer20);
422-
break;
423-
case 32:
424-
HC(HammingComputer32);
425-
break;
426-
default:
427-
HC(HammingComputerDefault);
428-
break;
429-
}
430-
#undef HC
404+
Run_verify_shortlist r;
405+
dispatch_HammingComputer(
406+
index.code_size, r, index.storage, xi, shortlist, res);
431407
}
432408

433409
} // anonymous namespace

0 commit comments

Comments
 (0)