Skip to content

Commit 92594d1

Browse files
mdouzefacebook-github-bot
authored andcommitted
add dispatcher for VectorDistance and ResultHandlers
Summary: Add dispatcher function to avoid repeating dispatching code for distance computation and result handlers. Differential Revision: D59318865
1 parent 3fe0b93 commit 92594d1

File tree

5 files changed

+304
-236
lines changed

5 files changed

+304
-236
lines changed

faiss/impl/ResultHandler.h

+129-29
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313

1414
#include <faiss/impl/AuxIndexStructures.h>
1515
#include <faiss/impl/FaissException.h>
16+
#include <faiss/impl/IDSelector.h>
1617
#include <faiss/utils/Heap.h>
1718
#include <faiss/utils/partitioning.h>
1819
#include <iostream>
@@ -26,16 +27,21 @@ namespace faiss {
2627
* - by instantiating a SingleResultHandler that tracks results for a single
2728
* query
2829
* - with begin_multiple/add_results/end_multiple calls where a whole block of
29-
* resutls is submitted
30+
* results is submitted
3031
* All classes are templated on C, which defines whether the min or the max of
31-
* results is to be kept.
32+
* results is to be kept, and on sel, so that the codepaths for with / without
33+
* selector can be separated at compile time.
3234
*****************************************************************/
3335

34-
template <class C>
36+
template <class C, bool use_sel = false>
3537
struct BlockResultHandler {
3638
size_t nq; // number of queries for which we search
39+
const IDSelector* sel;
3740

38-
explicit BlockResultHandler(size_t nq) : nq(nq) {}
41+
explicit BlockResultHandler(size_t nq, const IDSelector* sel = nullptr)
42+
: nq(nq), sel(sel) {
43+
assert(!use_sel || sel);
44+
}
3945

4046
// currently handled query range
4147
size_t i0 = 0, i1 = 0;
@@ -53,13 +59,17 @@ struct BlockResultHandler {
5359
virtual void end_multiple() {}
5460

5561
virtual ~BlockResultHandler() {}
62+
63+
bool is_in_selection(idx_t i) const {
64+
return !use_sel || sel->is_member(i);
65+
}
5666
};
5767

5868
// handler for a single query
5969
template <class C>
6070
struct ResultHandler {
6171
// if not better than threshold, then not necessary to call add_result
62-
typename C::T threshold = 0;
72+
typename C::T threshold = C::neutral();
6373

6474
// return whether threshold was updated
6575
virtual bool add_result(typename C::T dis, typename C::TI idx) = 0;
@@ -73,20 +83,26 @@ struct ResultHandler {
7383
* some temporary data in memory.
7484
*****************************************************************/
7585

76-
template <class C>
77-
struct Top1BlockResultHandler : BlockResultHandler<C> {
86+
template <class C, bool use_sel = false>
87+
struct Top1BlockResultHandler : BlockResultHandler<C, use_sel> {
7888
using T = typename C::T;
7989
using TI = typename C::TI;
80-
using BlockResultHandler<C>::i0;
81-
using BlockResultHandler<C>::i1;
90+
using BlockResultHandler<C, use_sel>::i0;
91+
using BlockResultHandler<C, use_sel>::i1;
8292

8393
// contains exactly nq elements
8494
T* dis_tab;
8595
// contains exactly nq elements
8696
TI* ids_tab;
8797

88-
Top1BlockResultHandler(size_t nq, T* dis_tab, TI* ids_tab)
89-
: BlockResultHandler<C>(nq), dis_tab(dis_tab), ids_tab(ids_tab) {}
98+
Top1BlockResultHandler(
99+
size_t nq,
100+
T* dis_tab,
101+
TI* ids_tab,
102+
const IDSelector* sel = nullptr)
103+
: BlockResultHandler<C, use_sel>(nq, sel),
104+
dis_tab(dis_tab),
105+
ids_tab(ids_tab) {}
90106

91107
struct SingleResultHandler : ResultHandler<C> {
92108
Top1BlockResultHandler& hr;
@@ -165,12 +181,12 @@ struct Top1BlockResultHandler : BlockResultHandler<C> {
165181
* Heap based result handler
166182
*****************************************************************/
167183

168-
template <class C>
169-
struct HeapBlockResultHandler : BlockResultHandler<C> {
184+
template <class C, bool use_sel = false>
185+
struct HeapBlockResultHandler : BlockResultHandler<C, use_sel> {
170186
using T = typename C::T;
171187
using TI = typename C::TI;
172-
using BlockResultHandler<C>::i0;
173-
using BlockResultHandler<C>::i1;
188+
using BlockResultHandler<C, use_sel>::i0;
189+
using BlockResultHandler<C, use_sel>::i1;
174190

175191
T* heap_dis_tab;
176192
TI* heap_ids_tab;
@@ -181,8 +197,9 @@ struct HeapBlockResultHandler : BlockResultHandler<C> {
181197
size_t nq,
182198
T* heap_dis_tab,
183199
TI* heap_ids_tab,
184-
size_t k)
185-
: BlockResultHandler<C>(nq),
200+
size_t k,
201+
const IDSelector* sel = nullptr)
202+
: BlockResultHandler<C, use_sel>(nq, sel),
186203
heap_dis_tab(heap_dis_tab),
187204
heap_ids_tab(heap_ids_tab),
188205
k(k) {}
@@ -347,12 +364,12 @@ struct ReservoirTopN : ResultHandler<C> {
347364
}
348365
};
349366

350-
template <class C>
351-
struct ReservoirBlockResultHandler : BlockResultHandler<C> {
367+
template <class C, bool use_sel = false>
368+
struct ReservoirBlockResultHandler : BlockResultHandler<C, use_sel> {
352369
using T = typename C::T;
353370
using TI = typename C::TI;
354-
using BlockResultHandler<C>::i0;
355-
using BlockResultHandler<C>::i1;
371+
using BlockResultHandler<C, use_sel>::i0;
372+
using BlockResultHandler<C, use_sel>::i1;
356373

357374
T* heap_dis_tab;
358375
TI* heap_ids_tab;
@@ -364,8 +381,9 @@ struct ReservoirBlockResultHandler : BlockResultHandler<C> {
364381
size_t nq,
365382
T* heap_dis_tab,
366383
TI* heap_ids_tab,
367-
size_t k)
368-
: BlockResultHandler<C>(nq),
384+
size_t k,
385+
const IDSelector* sel = nullptr)
386+
: BlockResultHandler<C, use_sel>(nq, sel),
369387
heap_dis_tab(heap_dis_tab),
370388
heap_ids_tab(heap_ids_tab),
371389
k(k) {
@@ -460,18 +478,23 @@ struct ReservoirBlockResultHandler : BlockResultHandler<C> {
460478
* Result handler for range searches
461479
*****************************************************************/
462480

463-
template <class C>
464-
struct RangeSearchBlockResultHandler : BlockResultHandler<C> {
481+
template <class C, bool use_sel = false>
482+
struct RangeSearchBlockResultHandler : BlockResultHandler<C, use_sel> {
465483
using T = typename C::T;
466484
using TI = typename C::TI;
467-
using BlockResultHandler<C>::i0;
468-
using BlockResultHandler<C>::i1;
485+
using BlockResultHandler<C, use_sel>::i0;
486+
using BlockResultHandler<C, use_sel>::i1;
469487

470488
RangeSearchResult* res;
471489
T radius;
472490

473-
RangeSearchBlockResultHandler(RangeSearchResult* res, float radius)
474-
: BlockResultHandler<C>(res->nq), res(res), radius(radius) {}
491+
RangeSearchBlockResultHandler(
492+
RangeSearchResult* res,
493+
float radius,
494+
const IDSelector* sel = nullptr)
495+
: BlockResultHandler<C, use_sel>(res->nq, sel),
496+
res(res),
497+
radius(radius) {}
475498

476499
/******************************************************
477500
* API for 1 result at a time (each SingleResultHandler is
@@ -582,4 +605,81 @@ struct RangeSearchBlockResultHandler : BlockResultHandler<C> {
582605
}
583606
};
584607

608+
/*****************************************************************
609+
* Dispatcher function to choose the right knn result handler depending on k
610+
*****************************************************************/
611+
612+
// declared in distances.cpp
613+
extern int distance_compute_min_k_reservoir;
614+
615+
template <class Consumer, class... Types>
616+
typename Consumer::T dispatch_knn_ResultHandler(
617+
size_t nx,
618+
float* vals,
619+
int64_t* ids,
620+
size_t k,
621+
MetricType metric,
622+
const IDSelector* sel,
623+
Consumer& consumer,
624+
Types... args) {
625+
#define DISPATCH_C_SEL(C, use_sel) \
626+
if (k == 1) { \
627+
Top1BlockResultHandler<C, use_sel> res(nx, vals, ids, sel); \
628+
return consumer.template f(res, args...); \
629+
} else if (k < distance_compute_min_k_reservoir) { \
630+
HeapBlockResultHandler<C, use_sel> res(nx, vals, ids, k, sel); \
631+
return consumer.template f(res, args...); \
632+
} else { \
633+
ReservoirBlockResultHandler<C, use_sel> res(nx, vals, ids, k, sel); \
634+
return consumer.template f(res, args...); \
635+
}
636+
637+
if (is_similarity_metric(metric)) {
638+
using C = CMin<float, int64_t>;
639+
if (sel) {
640+
DISPATCH_C_SEL(C, true);
641+
} else {
642+
DISPATCH_C_SEL(C, false);
643+
}
644+
} else {
645+
using C = CMax<float, int64_t>;
646+
if (sel) {
647+
DISPATCH_C_SEL(C, true);
648+
} else {
649+
DISPATCH_C_SEL(C, false);
650+
}
651+
}
652+
#undef DISPATCH_C_SEL
653+
}
654+
655+
template <class Consumer, class... Types>
656+
typename Consumer::T dispatch_range_ResultHandler(
657+
RangeSearchResult* res,
658+
float radius,
659+
MetricType metric,
660+
const IDSelector* sel,
661+
Consumer& consumer,
662+
Types... args) {
663+
#define DISPATCH_C_SEL(C, use_sel) \
664+
RangeSearchBlockResultHandler<C, use_sel> resb(res, radius, sel); \
665+
return consumer.template f(resb, args...);
666+
667+
if (is_similarity_metric(metric)) {
668+
using C = CMin<float, int64_t>;
669+
if (sel) {
670+
DISPATCH_C_SEL(C, true);
671+
} else {
672+
DISPATCH_C_SEL(C, false);
673+
}
674+
} else {
675+
using C = CMax<float, int64_t>;
676+
if (sel) {
677+
DISPATCH_C_SEL(C, true);
678+
} else {
679+
DISPATCH_C_SEL(C, false);
680+
}
681+
}
682+
#undef DISPATCH_C_SEL
683+
}
684+
585685
} // namespace faiss

0 commit comments

Comments
 (0)