13
13
14
14
#include < faiss/impl/AuxIndexStructures.h>
15
15
#include < faiss/impl/FaissException.h>
16
+ #include < faiss/impl/IDSelector.h>
16
17
#include < faiss/utils/Heap.h>
17
18
#include < faiss/utils/partitioning.h>
18
19
#include < iostream>
@@ -26,16 +27,21 @@ namespace faiss {
26
27
* - by instanciating a SingleResultHandler that tracks results for a single
27
28
* query
28
29
* - with begin_multiple/add_results/end_multiple calls where a whole block of
29
- * resutls is submitted
30
+ * results is submitted
30
31
* All classes are templated on C which to define wheter the min or the max of
31
- * results is to be kept.
32
+ * results is to be kept, and on sel, so that the codepaths for with / without
33
+ * selector can be separated at compile time.
32
34
*****************************************************************/
33
35
34
- template <class C >
36
+ template <class C , bool use_sel = false >
35
37
struct BlockResultHandler {
36
38
size_t nq; // number of queries for which we search
39
+ const IDSelector* sel;
37
40
38
- explicit BlockResultHandler (size_t nq) : nq(nq) {}
41
+ explicit BlockResultHandler (size_t nq, const IDSelector* sel = nullptr )
42
+ : nq(nq), sel(sel) {
43
+ assert (!use_sel || sel);
44
+ }
39
45
40
46
// currently handled query range
41
47
size_t i0 = 0 , i1 = 0 ;
@@ -53,13 +59,17 @@ struct BlockResultHandler {
53
59
virtual void end_multiple () {}
54
60
55
61
virtual ~BlockResultHandler () {}
62
+
63
+ bool is_in_selection (idx_t i) const {
64
+ return !use_sel || sel->is_member (i);
65
+ }
56
66
};
57
67
58
68
// handler for a single query
59
69
template <class C >
60
70
struct ResultHandler {
61
71
// if not better than threshold, then not necessary to call add_result
62
- typename C::T threshold = 0 ;
72
+ typename C::T threshold = C::neutral() ;
63
73
64
74
// return whether threshold was updated
65
75
virtual bool add_result (typename C::T dis, typename C::TI idx) = 0;
@@ -73,20 +83,26 @@ struct ResultHandler {
73
83
* some temporary data in memory.
74
84
*****************************************************************/
75
85
76
- template <class C >
77
- struct Top1BlockResultHandler : BlockResultHandler<C> {
86
+ template <class C , bool use_sel = false >
87
+ struct Top1BlockResultHandler : BlockResultHandler<C, use_sel > {
78
88
using T = typename C::T;
79
89
using TI = typename C::TI;
80
- using BlockResultHandler<C>::i0;
81
- using BlockResultHandler<C>::i1;
90
+ using BlockResultHandler<C, use_sel >::i0;
91
+ using BlockResultHandler<C, use_sel >::i1;
82
92
83
93
// contains exactly nq elements
84
94
T* dis_tab;
85
95
// contains exactly nq elements
86
96
TI* ids_tab;
87
97
88
- Top1BlockResultHandler (size_t nq, T* dis_tab, TI* ids_tab)
89
- : BlockResultHandler<C>(nq), dis_tab(dis_tab), ids_tab(ids_tab) {}
98
+ Top1BlockResultHandler (
99
+ size_t nq,
100
+ T* dis_tab,
101
+ TI* ids_tab,
102
+ const IDSelector* sel = nullptr )
103
+ : BlockResultHandler<C, use_sel>(nq, sel),
104
+ dis_tab (dis_tab),
105
+ ids_tab (ids_tab) {}
90
106
91
107
struct SingleResultHandler : ResultHandler<C> {
92
108
Top1BlockResultHandler& hr;
@@ -165,12 +181,12 @@ struct Top1BlockResultHandler : BlockResultHandler<C> {
165
181
* Heap based result handler
166
182
*****************************************************************/
167
183
168
- template <class C >
169
- struct HeapBlockResultHandler : BlockResultHandler<C> {
184
+ template <class C , bool use_sel = false >
185
+ struct HeapBlockResultHandler : BlockResultHandler<C, use_sel > {
170
186
using T = typename C::T;
171
187
using TI = typename C::TI;
172
- using BlockResultHandler<C>::i0;
173
- using BlockResultHandler<C>::i1;
188
+ using BlockResultHandler<C, use_sel >::i0;
189
+ using BlockResultHandler<C, use_sel >::i1;
174
190
175
191
T* heap_dis_tab;
176
192
TI* heap_ids_tab;
@@ -181,8 +197,9 @@ struct HeapBlockResultHandler : BlockResultHandler<C> {
181
197
size_t nq,
182
198
T* heap_dis_tab,
183
199
TI* heap_ids_tab,
184
- size_t k)
185
- : BlockResultHandler<C>(nq),
200
+ size_t k,
201
+ const IDSelector* sel = nullptr )
202
+ : BlockResultHandler<C, use_sel>(nq, sel),
186
203
heap_dis_tab (heap_dis_tab),
187
204
heap_ids_tab (heap_ids_tab),
188
205
k (k) {}
@@ -347,12 +364,12 @@ struct ReservoirTopN : ResultHandler<C> {
347
364
}
348
365
};
349
366
350
- template <class C >
351
- struct ReservoirBlockResultHandler : BlockResultHandler<C> {
367
+ template <class C , bool use_sel = false >
368
+ struct ReservoirBlockResultHandler : BlockResultHandler<C, use_sel > {
352
369
using T = typename C::T;
353
370
using TI = typename C::TI;
354
- using BlockResultHandler<C>::i0;
355
- using BlockResultHandler<C>::i1;
371
+ using BlockResultHandler<C, use_sel >::i0;
372
+ using BlockResultHandler<C, use_sel >::i1;
356
373
357
374
T* heap_dis_tab;
358
375
TI* heap_ids_tab;
@@ -364,8 +381,9 @@ struct ReservoirBlockResultHandler : BlockResultHandler<C> {
364
381
size_t nq,
365
382
T* heap_dis_tab,
366
383
TI* heap_ids_tab,
367
- size_t k)
368
- : BlockResultHandler<C>(nq),
384
+ size_t k,
385
+ const IDSelector* sel = nullptr )
386
+ : BlockResultHandler<C, use_sel>(nq, sel),
369
387
heap_dis_tab (heap_dis_tab),
370
388
heap_ids_tab (heap_ids_tab),
371
389
k (k) {
@@ -460,18 +478,23 @@ struct ReservoirBlockResultHandler : BlockResultHandler<C> {
460
478
* Result handler for range searches
461
479
*****************************************************************/
462
480
463
- template <class C >
464
- struct RangeSearchBlockResultHandler : BlockResultHandler<C> {
481
+ template <class C , bool use_sel = false >
482
+ struct RangeSearchBlockResultHandler : BlockResultHandler<C, use_sel > {
465
483
using T = typename C::T;
466
484
using TI = typename C::TI;
467
- using BlockResultHandler<C>::i0;
468
- using BlockResultHandler<C>::i1;
485
+ using BlockResultHandler<C, use_sel >::i0;
486
+ using BlockResultHandler<C, use_sel >::i1;
469
487
470
488
RangeSearchResult* res;
471
489
T radius;
472
490
473
- RangeSearchBlockResultHandler (RangeSearchResult* res, float radius)
474
- : BlockResultHandler<C>(res->nq), res(res), radius(radius) {}
491
+ RangeSearchBlockResultHandler (
492
+ RangeSearchResult* res,
493
+ float radius,
494
+ const IDSelector* sel = nullptr )
495
+ : BlockResultHandler<C, use_sel>(res->nq, sel),
496
+ res (res),
497
+ radius (radius) {}
475
498
476
499
/* *****************************************************
477
500
* API for 1 result at a time (each SingleResultHandler is
@@ -582,4 +605,81 @@ struct RangeSearchBlockResultHandler : BlockResultHandler<C> {
582
605
}
583
606
};
584
607
608
+ /* ****************************************************************
609
+ * Dispatcher function to choose the right knn result handler depending on k
610
+ *****************************************************************/
611
+
612
+ // declared in distances.cpp
613
+ extern int distance_compute_min_k_reservoir;
614
+
615
+ template <class Consumer , class ... Types>
616
+ typename Consumer::T dispatch_knn_ResultHandler (
617
+ size_t nx,
618
+ float * vals,
619
+ int64_t * ids,
620
+ size_t k,
621
+ MetricType metric,
622
+ const IDSelector* sel,
623
+ Consumer& consumer,
624
+ Types... args) {
625
+ #define DISPATCH_C_SEL (C, use_sel ) \
626
+ if (k == 1 ) { \
627
+ Top1BlockResultHandler<C, use_sel> res (nx, vals, ids, sel); \
628
+ return consumer.template f (res, args...); \
629
+ } else if (k < distance_compute_min_k_reservoir) { \
630
+ HeapBlockResultHandler<C, use_sel> res (nx, vals, ids, k, sel); \
631
+ return consumer.template f (res, args...); \
632
+ } else { \
633
+ ReservoirBlockResultHandler<C, use_sel> res (nx, vals, ids, k, sel); \
634
+ return consumer.template f (res, args...); \
635
+ }
636
+
637
+ if (is_similarity_metric (metric)) {
638
+ using C = CMin<float , int64_t >;
639
+ if (sel) {
640
+ DISPATCH_C_SEL (C, true );
641
+ } else {
642
+ DISPATCH_C_SEL (C, false );
643
+ }
644
+ } else {
645
+ using C = CMax<float , int64_t >;
646
+ if (sel) {
647
+ DISPATCH_C_SEL (C, true );
648
+ } else {
649
+ DISPATCH_C_SEL (C, false );
650
+ }
651
+ }
652
+ #undef DISPATCH_C_SEL
653
+ }
654
+
655
+ template <class Consumer , class ... Types>
656
+ typename Consumer::T dispatch_range_ResultHandler (
657
+ RangeSearchResult* res,
658
+ float radius,
659
+ MetricType metric,
660
+ const IDSelector* sel,
661
+ Consumer& consumer,
662
+ Types... args) {
663
+ #define DISPATCH_C_SEL (C, use_sel ) \
664
+ RangeSearchBlockResultHandler<C, use_sel> resb (res, radius, sel); \
665
+ return consumer.template f (resb, args...);
666
+
667
+ if (is_similarity_metric (metric)) {
668
+ using C = CMin<float , int64_t >;
669
+ if (sel) {
670
+ DISPATCH_C_SEL (C, true );
671
+ } else {
672
+ DISPATCH_C_SEL (C, false );
673
+ }
674
+ } else {
675
+ using C = CMax<float , int64_t >;
676
+ if (sel) {
677
+ DISPATCH_C_SEL (C, true );
678
+ } else {
679
+ DISPATCH_C_SEL (C, false );
680
+ }
681
+ }
682
+ #undef DISPATCH_C_SEL
683
+ }
684
+
585
685
} // namespace faiss
0 commit comments