Skip to content

Commit 261edde

Browse files
mdouze authored and facebook-github-bot committed
add dispatcher for VectorDistance and ResultHandlers
Summary: Add dispatcher function to avoid repeating dispatching code for distance computation and result handlers. Reviewed By: asadoughi Differential Revision: D59318865 fbshipit-source-id: 59046ede02f71a0da3b8061289fc70306bf875cb
1 parent 444614b commit 261edde

File tree

5 files changed

+305
-236
lines changed

5 files changed

+305
-236
lines changed

faiss/impl/ResultHandler.h

+130-29
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,10 @@
1313

1414
#include <faiss/impl/AuxIndexStructures.h>
1515
#include <faiss/impl/FaissException.h>
16+
#include <faiss/impl/IDSelector.h>
1617
#include <faiss/utils/Heap.h>
1718
#include <faiss/utils/partitioning.h>
19+
#include <algorithm>
1820
#include <iostream>
1921

2022
namespace faiss {
@@ -26,16 +28,21 @@ namespace faiss {
2628
* - by instantiating a SingleResultHandler that tracks results for a single
2729
* query
2830
* - with begin_multiple/add_results/end_multiple calls where a whole block of
29-
* resutls is submitted
31+
* results is submitted
3032
* All classes are templated on C, which defines whether the min or the max of
31-
* results is to be kept.
33+
* results is to be kept, and on use_sel, so that the codepaths with / without
34+
* selector can be separated at compile time.
3235
*****************************************************************/
3336

34-
template <class C>
37+
template <class C, bool use_sel = false>
3538
struct BlockResultHandler {
3639
size_t nq; // number of queries for which we search
40+
const IDSelector* sel;
3741

38-
explicit BlockResultHandler(size_t nq) : nq(nq) {}
42+
explicit BlockResultHandler(size_t nq, const IDSelector* sel = nullptr)
43+
: nq(nq), sel(sel) {
44+
assert(!use_sel || sel);
45+
}
3946

4047
// currently handled query range
4148
size_t i0 = 0, i1 = 0;
@@ -53,13 +60,17 @@ struct BlockResultHandler {
5360
virtual void end_multiple() {}
5461

5562
virtual ~BlockResultHandler() {}
63+
64+
bool is_in_selection(idx_t i) const {
65+
return !use_sel || sel->is_member(i);
66+
}
5667
};
5768

5869
// handler for a single query
5970
template <class C>
6071
struct ResultHandler {
6172
// if not better than threshold, then not necessary to call add_result
62-
typename C::T threshold = 0;
73+
typename C::T threshold = C::neutral();
6374

6475
// return whether threshold was updated
6576
virtual bool add_result(typename C::T dis, typename C::TI idx) = 0;
@@ -73,20 +84,26 @@ struct ResultHandler {
7384
* some temporary data in memory.
7485
*****************************************************************/
7586

76-
template <class C>
77-
struct Top1BlockResultHandler : BlockResultHandler<C> {
87+
template <class C, bool use_sel = false>
88+
struct Top1BlockResultHandler : BlockResultHandler<C, use_sel> {
7889
using T = typename C::T;
7990
using TI = typename C::TI;
80-
using BlockResultHandler<C>::i0;
81-
using BlockResultHandler<C>::i1;
91+
using BlockResultHandler<C, use_sel>::i0;
92+
using BlockResultHandler<C, use_sel>::i1;
8293

8394
// contains exactly nq elements
8495
T* dis_tab;
8596
// contains exactly nq elements
8697
TI* ids_tab;
8798

88-
Top1BlockResultHandler(size_t nq, T* dis_tab, TI* ids_tab)
89-
: BlockResultHandler<C>(nq), dis_tab(dis_tab), ids_tab(ids_tab) {}
99+
Top1BlockResultHandler(
100+
size_t nq,
101+
T* dis_tab,
102+
TI* ids_tab,
103+
const IDSelector* sel = nullptr)
104+
: BlockResultHandler<C, use_sel>(nq, sel),
105+
dis_tab(dis_tab),
106+
ids_tab(ids_tab) {}
90107

91108
struct SingleResultHandler : ResultHandler<C> {
92109
Top1BlockResultHandler& hr;
@@ -165,12 +182,12 @@ struct Top1BlockResultHandler : BlockResultHandler<C> {
165182
* Heap based result handler
166183
*****************************************************************/
167184

168-
template <class C>
169-
struct HeapBlockResultHandler : BlockResultHandler<C> {
185+
template <class C, bool use_sel = false>
186+
struct HeapBlockResultHandler : BlockResultHandler<C, use_sel> {
170187
using T = typename C::T;
171188
using TI = typename C::TI;
172-
using BlockResultHandler<C>::i0;
173-
using BlockResultHandler<C>::i1;
189+
using BlockResultHandler<C, use_sel>::i0;
190+
using BlockResultHandler<C, use_sel>::i1;
174191

175192
T* heap_dis_tab;
176193
TI* heap_ids_tab;
@@ -181,8 +198,9 @@ struct HeapBlockResultHandler : BlockResultHandler<C> {
181198
size_t nq,
182199
T* heap_dis_tab,
183200
TI* heap_ids_tab,
184-
size_t k)
185-
: BlockResultHandler<C>(nq),
201+
size_t k,
202+
const IDSelector* sel = nullptr)
203+
: BlockResultHandler<C, use_sel>(nq, sel),
186204
heap_dis_tab(heap_dis_tab),
187205
heap_ids_tab(heap_ids_tab),
188206
k(k) {}
@@ -347,12 +365,12 @@ struct ReservoirTopN : ResultHandler<C> {
347365
}
348366
};
349367

350-
template <class C>
351-
struct ReservoirBlockResultHandler : BlockResultHandler<C> {
368+
template <class C, bool use_sel = false>
369+
struct ReservoirBlockResultHandler : BlockResultHandler<C, use_sel> {
352370
using T = typename C::T;
353371
using TI = typename C::TI;
354-
using BlockResultHandler<C>::i0;
355-
using BlockResultHandler<C>::i1;
372+
using BlockResultHandler<C, use_sel>::i0;
373+
using BlockResultHandler<C, use_sel>::i1;
356374

357375
T* heap_dis_tab;
358376
TI* heap_ids_tab;
@@ -364,8 +382,9 @@ struct ReservoirBlockResultHandler : BlockResultHandler<C> {
364382
size_t nq,
365383
T* heap_dis_tab,
366384
TI* heap_ids_tab,
367-
size_t k)
368-
: BlockResultHandler<C>(nq),
385+
size_t k,
386+
const IDSelector* sel = nullptr)
387+
: BlockResultHandler<C, use_sel>(nq, sel),
369388
heap_dis_tab(heap_dis_tab),
370389
heap_ids_tab(heap_ids_tab),
371390
k(k) {
@@ -460,18 +479,23 @@ struct ReservoirBlockResultHandler : BlockResultHandler<C> {
460479
* Result handler for range searches
461480
*****************************************************************/
462481

463-
template <class C>
464-
struct RangeSearchBlockResultHandler : BlockResultHandler<C> {
482+
template <class C, bool use_sel = false>
483+
struct RangeSearchBlockResultHandler : BlockResultHandler<C, use_sel> {
465484
using T = typename C::T;
466485
using TI = typename C::TI;
467-
using BlockResultHandler<C>::i0;
468-
using BlockResultHandler<C>::i1;
486+
using BlockResultHandler<C, use_sel>::i0;
487+
using BlockResultHandler<C, use_sel>::i1;
469488

470489
RangeSearchResult* res;
471490
T radius;
472491

473-
RangeSearchBlockResultHandler(RangeSearchResult* res, float radius)
474-
: BlockResultHandler<C>(res->nq), res(res), radius(radius) {}
492+
RangeSearchBlockResultHandler(
493+
RangeSearchResult* res,
494+
float radius,
495+
const IDSelector* sel = nullptr)
496+
: BlockResultHandler<C, use_sel>(res->nq, sel),
497+
res(res),
498+
radius(radius) {}
475499

476500
/******************************************************
477501
* API for 1 result at a time (each SingleResultHandler is
@@ -582,4 +606,81 @@ struct RangeSearchBlockResultHandler : BlockResultHandler<C> {
582606
}
583607
};
584608

609+
/*****************************************************************
610+
* Dispatcher function to choose the right knn result handler depending on k
611+
*****************************************************************/
612+
613+
// declared in distances.cpp
614+
FAISS_API extern int distance_compute_min_k_reservoir;
615+
616+
template <class Consumer, class... Types>
617+
typename Consumer::T dispatch_knn_ResultHandler(
618+
size_t nx,
619+
float* vals,
620+
int64_t* ids,
621+
size_t k,
622+
MetricType metric,
623+
const IDSelector* sel,
624+
Consumer& consumer,
625+
Types... args) {
626+
#define DISPATCH_C_SEL(C, use_sel) \
627+
if (k == 1) { \
628+
Top1BlockResultHandler<C, use_sel> res(nx, vals, ids, sel); \
629+
return consumer.template f<>(res, args...); \
630+
} else if (k < distance_compute_min_k_reservoir) { \
631+
HeapBlockResultHandler<C, use_sel> res(nx, vals, ids, k, sel); \
632+
return consumer.template f<>(res, args...); \
633+
} else { \
634+
ReservoirBlockResultHandler<C, use_sel> res(nx, vals, ids, k, sel); \
635+
return consumer.template f<>(res, args...); \
636+
}
637+
638+
if (is_similarity_metric(metric)) {
639+
using C = CMin<float, int64_t>;
640+
if (sel) {
641+
DISPATCH_C_SEL(C, true);
642+
} else {
643+
DISPATCH_C_SEL(C, false);
644+
}
645+
} else {
646+
using C = CMax<float, int64_t>;
647+
if (sel) {
648+
DISPATCH_C_SEL(C, true);
649+
} else {
650+
DISPATCH_C_SEL(C, false);
651+
}
652+
}
653+
#undef DISPATCH_C_SEL
654+
}
655+
656+
template <class Consumer, class... Types>
657+
typename Consumer::T dispatch_range_ResultHandler(
658+
RangeSearchResult* res,
659+
float radius,
660+
MetricType metric,
661+
const IDSelector* sel,
662+
Consumer& consumer,
663+
Types... args) {
664+
#define DISPATCH_C_SEL(C, use_sel) \
665+
RangeSearchBlockResultHandler<C, use_sel> resb(res, radius, sel); \
666+
return consumer.template f<>(resb, args...);
667+
668+
if (is_similarity_metric(metric)) {
669+
using C = CMin<float, int64_t>;
670+
if (sel) {
671+
DISPATCH_C_SEL(C, true);
672+
} else {
673+
DISPATCH_C_SEL(C, false);
674+
}
675+
} else {
676+
using C = CMax<float, int64_t>;
677+
if (sel) {
678+
DISPATCH_C_SEL(C, true);
679+
} else {
680+
DISPATCH_C_SEL(C, false);
681+
}
682+
}
683+
#undef DISPATCH_C_SEL
684+
}
685+
585686
} // namespace faiss

0 commit comments

Comments
 (0)