13
13
14
14
#include < faiss/impl/AuxIndexStructures.h>
15
15
#include < faiss/impl/FaissException.h>
16
+ #include < faiss/impl/IDSelector.h>
16
17
#include < faiss/utils/Heap.h>
17
18
#include < faiss/utils/partitioning.h>
19
+ #include < algorithm>
18
20
#include < iostream>
19
21
20
22
namespace faiss {
@@ -26,16 +28,21 @@ namespace faiss {
26
28
* - by instantiating a SingleResultHandler that tracks results for a single
27
29
* query
28
30
* - with begin_multiple/add_results/end_multiple calls where a whole block of
29
- * resutls is submitted
31
+ * results is submitted
30
32
* All classes are templated on C, which defines whether the min or the max of
31
- * results is to be kept.
33
+ * results is to be kept, and on use_sel, so that the code paths with / without
34
+ * selector can be separated at compile time.
32
35
*****************************************************************/
33
36
34
- template <class C >
37
+ template <class C , bool use_sel = false >
35
38
struct BlockResultHandler {
36
39
size_t nq; // number of queries for which we search
40
+ const IDSelector* sel;
37
41
38
- explicit BlockResultHandler (size_t nq) : nq(nq) {}
42
+ explicit BlockResultHandler (size_t nq, const IDSelector* sel = nullptr )
43
+ : nq(nq), sel(sel) {
44
+ assert (!use_sel || sel);
45
+ }
39
46
40
47
// currently handled query range
41
48
size_t i0 = 0 , i1 = 0 ;
@@ -53,13 +60,17 @@ struct BlockResultHandler {
53
60
virtual void end_multiple () {}
54
61
55
62
virtual ~BlockResultHandler () {}
63
+
64
+ bool is_in_selection (idx_t i) const {
65
+ return !use_sel || sel->is_member (i);
66
+ }
56
67
};
57
68
58
69
// handler for a single query
59
70
template <class C >
60
71
struct ResultHandler {
61
72
// if not better than threshold, then not necessary to call add_result
62
- typename C::T threshold = 0 ;
73
+ typename C::T threshold = C::neutral() ;
63
74
64
75
// return whether threshold was updated
65
76
virtual bool add_result (typename C::T dis, typename C::TI idx) = 0;
@@ -73,20 +84,26 @@ struct ResultHandler {
73
84
* some temporary data in memory.
74
85
*****************************************************************/
75
86
76
- template <class C >
77
- struct Top1BlockResultHandler : BlockResultHandler<C> {
87
+ template <class C , bool use_sel = false >
88
+ struct Top1BlockResultHandler : BlockResultHandler<C, use_sel > {
78
89
using T = typename C::T;
79
90
using TI = typename C::TI;
80
- using BlockResultHandler<C>::i0;
81
- using BlockResultHandler<C>::i1;
91
+ using BlockResultHandler<C, use_sel >::i0;
92
+ using BlockResultHandler<C, use_sel >::i1;
82
93
83
94
// contains exactly nq elements
84
95
T* dis_tab;
85
96
// contains exactly nq elements
86
97
TI* ids_tab;
87
98
88
- Top1BlockResultHandler (size_t nq, T* dis_tab, TI* ids_tab)
89
- : BlockResultHandler<C>(nq), dis_tab(dis_tab), ids_tab(ids_tab) {}
99
+ Top1BlockResultHandler (
100
+ size_t nq,
101
+ T* dis_tab,
102
+ TI* ids_tab,
103
+ const IDSelector* sel = nullptr )
104
+ : BlockResultHandler<C, use_sel>(nq, sel),
105
+ dis_tab (dis_tab),
106
+ ids_tab (ids_tab) {}
90
107
91
108
struct SingleResultHandler : ResultHandler<C> {
92
109
Top1BlockResultHandler& hr;
@@ -165,12 +182,12 @@ struct Top1BlockResultHandler : BlockResultHandler<C> {
165
182
* Heap based result handler
166
183
*****************************************************************/
167
184
168
- template <class C >
169
- struct HeapBlockResultHandler : BlockResultHandler<C> {
185
+ template <class C , bool use_sel = false >
186
+ struct HeapBlockResultHandler : BlockResultHandler<C, use_sel > {
170
187
using T = typename C::T;
171
188
using TI = typename C::TI;
172
- using BlockResultHandler<C>::i0;
173
- using BlockResultHandler<C>::i1;
189
+ using BlockResultHandler<C, use_sel >::i0;
190
+ using BlockResultHandler<C, use_sel >::i1;
174
191
175
192
T* heap_dis_tab;
176
193
TI* heap_ids_tab;
@@ -181,8 +198,9 @@ struct HeapBlockResultHandler : BlockResultHandler<C> {
181
198
size_t nq,
182
199
T* heap_dis_tab,
183
200
TI* heap_ids_tab,
184
- size_t k)
185
- : BlockResultHandler<C>(nq),
201
+ size_t k,
202
+ const IDSelector* sel = nullptr )
203
+ : BlockResultHandler<C, use_sel>(nq, sel),
186
204
heap_dis_tab (heap_dis_tab),
187
205
heap_ids_tab (heap_ids_tab),
188
206
k (k) {}
@@ -347,12 +365,12 @@ struct ReservoirTopN : ResultHandler<C> {
347
365
}
348
366
};
349
367
350
- template <class C >
351
- struct ReservoirBlockResultHandler : BlockResultHandler<C> {
368
+ template <class C , bool use_sel = false >
369
+ struct ReservoirBlockResultHandler : BlockResultHandler<C, use_sel > {
352
370
using T = typename C::T;
353
371
using TI = typename C::TI;
354
- using BlockResultHandler<C>::i0;
355
- using BlockResultHandler<C>::i1;
372
+ using BlockResultHandler<C, use_sel >::i0;
373
+ using BlockResultHandler<C, use_sel >::i1;
356
374
357
375
T* heap_dis_tab;
358
376
TI* heap_ids_tab;
@@ -364,8 +382,9 @@ struct ReservoirBlockResultHandler : BlockResultHandler<C> {
364
382
size_t nq,
365
383
T* heap_dis_tab,
366
384
TI* heap_ids_tab,
367
- size_t k)
368
- : BlockResultHandler<C>(nq),
385
+ size_t k,
386
+ const IDSelector* sel = nullptr )
387
+ : BlockResultHandler<C, use_sel>(nq, sel),
369
388
heap_dis_tab (heap_dis_tab),
370
389
heap_ids_tab (heap_ids_tab),
371
390
k (k) {
@@ -460,18 +479,23 @@ struct ReservoirBlockResultHandler : BlockResultHandler<C> {
460
479
* Result handler for range searches
461
480
*****************************************************************/
462
481
463
- template <class C >
464
- struct RangeSearchBlockResultHandler : BlockResultHandler<C> {
482
+ template <class C , bool use_sel = false >
483
+ struct RangeSearchBlockResultHandler : BlockResultHandler<C, use_sel > {
465
484
using T = typename C::T;
466
485
using TI = typename C::TI;
467
- using BlockResultHandler<C>::i0;
468
- using BlockResultHandler<C>::i1;
486
+ using BlockResultHandler<C, use_sel >::i0;
487
+ using BlockResultHandler<C, use_sel >::i1;
469
488
470
489
RangeSearchResult* res;
471
490
T radius;
472
491
473
- RangeSearchBlockResultHandler (RangeSearchResult* res, float radius)
474
- : BlockResultHandler<C>(res->nq), res(res), radius(radius) {}
492
+ RangeSearchBlockResultHandler (
493
+ RangeSearchResult* res,
494
+ float radius,
495
+ const IDSelector* sel = nullptr )
496
+ : BlockResultHandler<C, use_sel>(res->nq, sel),
497
+ res (res),
498
+ radius (radius) {}
475
499
476
500
/* *****************************************************
477
501
* API for 1 result at a time (each SingleResultHandler is
@@ -582,4 +606,81 @@ struct RangeSearchBlockResultHandler : BlockResultHandler<C> {
582
606
}
583
607
};
584
608
609
+ /* ****************************************************************
610
+ * Dispatcher function to choose the right knn result handler depending on k
611
+ *****************************************************************/
612
+
613
+ // declared in distances.cpp
614
+ FAISS_API extern int distance_compute_min_k_reservoir;
615
+
616
+ template <class Consumer , class ... Types>
617
+ typename Consumer::T dispatch_knn_ResultHandler (
618
+ size_t nx,
619
+ float * vals,
620
+ int64_t * ids,
621
+ size_t k,
622
+ MetricType metric,
623
+ const IDSelector* sel,
624
+ Consumer& consumer,
625
+ Types... args) {
626
+ #define DISPATCH_C_SEL (C, use_sel ) \
627
+ if (k == 1 ) { \
628
+ Top1BlockResultHandler<C, use_sel> res (nx, vals, ids, sel); \
629
+ return consumer.template f <>(res, args...); \
630
+ } else if (k < distance_compute_min_k_reservoir) { \
631
+ HeapBlockResultHandler<C, use_sel> res (nx, vals, ids, k, sel); \
632
+ return consumer.template f <>(res, args...); \
633
+ } else { \
634
+ ReservoirBlockResultHandler<C, use_sel> res (nx, vals, ids, k, sel); \
635
+ return consumer.template f <>(res, args...); \
636
+ }
637
+
638
+ if (is_similarity_metric (metric)) {
639
+ using C = CMin<float , int64_t >;
640
+ if (sel) {
641
+ DISPATCH_C_SEL (C, true );
642
+ } else {
643
+ DISPATCH_C_SEL (C, false );
644
+ }
645
+ } else {
646
+ using C = CMax<float , int64_t >;
647
+ if (sel) {
648
+ DISPATCH_C_SEL (C, true );
649
+ } else {
650
+ DISPATCH_C_SEL (C, false );
651
+ }
652
+ }
653
+ #undef DISPATCH_C_SEL
654
+ }
655
+
656
+ template <class Consumer , class ... Types>
657
+ typename Consumer::T dispatch_range_ResultHandler (
658
+ RangeSearchResult* res,
659
+ float radius,
660
+ MetricType metric,
661
+ const IDSelector* sel,
662
+ Consumer& consumer,
663
+ Types... args) {
664
+ #define DISPATCH_C_SEL (C, use_sel ) \
665
+ RangeSearchBlockResultHandler<C, use_sel> resb (res, radius, sel); \
666
+ return consumer.template f <>(resb, args...);
667
+
668
+ if (is_similarity_metric (metric)) {
669
+ using C = CMin<float , int64_t >;
670
+ if (sel) {
671
+ DISPATCH_C_SEL (C, true );
672
+ } else {
673
+ DISPATCH_C_SEL (C, false );
674
+ }
675
+ } else {
676
+ using C = CMax<float , int64_t >;
677
+ if (sel) {
678
+ DISPATCH_C_SEL (C, true );
679
+ } else {
680
+ DISPATCH_C_SEL (C, false );
681
+ }
682
+ }
683
+ #undef DISPATCH_C_SEL
684
+ }
685
+
585
686
} // namespace faiss
0 commit comments