@@ -113,25 +113,38 @@ void IndexBinaryIVF::search(
113
113
idx_t k,
114
114
int32_t * distances,
115
115
idx_t * labels,
116
- const SearchParameters* params) const {
117
- FAISS_THROW_IF_NOT_MSG (
118
- !params, " search params not supported for this index" );
116
+ const SearchParameters* params_in) const {
117
+ const IVFSearchParameters* params = nullptr ;
118
+ if (params_in) {
119
+ params = dynamic_cast <const IVFSearchParameters*>(params_in);
120
+ FAISS_THROW_IF_NOT_MSG (
121
+ params, " IndexBinaryIVF params have incorrect type" );
122
+ }
123
+ const size_t nprobe_2 =
124
+ std::min (nlist, params ? params->nprobe : this ->nprobe );
119
125
FAISS_THROW_IF_NOT (k > 0 );
120
- FAISS_THROW_IF_NOT (nprobe > 0 );
126
+ FAISS_THROW_IF_NOT (nprobe_2 > 0 );
121
127
122
- const size_t nprobe_2 = std::min (nlist, this ->nprobe );
123
128
std::unique_ptr<idx_t []> idx (new idx_t [n * nprobe_2]);
124
129
std::unique_ptr<int32_t []> coarse_dis (new int32_t [n * nprobe_2]);
125
130
126
131
double t0 = getmillisecs ();
127
- quantizer->search (n, x, nprobe_2, coarse_dis.get (), idx.get ());
132
+ quantizer->search (n, x, nprobe_2, coarse_dis.get (), idx.get (), nullptr );
128
133
indexIVF_stats.quantization_time += getmillisecs () - t0;
129
134
130
135
t0 = getmillisecs ();
131
136
invlists->prefetch_lists (idx.get (), n * nprobe_2);
132
137
133
138
search_preassigned (
134
- n, x, k, idx.get (), coarse_dis.get (), distances, labels, false );
139
+ n,
140
+ x,
141
+ k,
142
+ idx.get (),
143
+ coarse_dis.get (),
144
+ distances,
145
+ labels,
146
+ false ,
147
+ params);
135
148
indexIVF_stats.search_time += getmillisecs () - t0;
136
149
}
137
150
@@ -335,11 +348,16 @@ struct IVFBinaryScannerL2 : BinaryInvertedListScanner {
335
348
const idx_t * __restrict ids,
336
349
int32_t * __restrict simi,
337
350
idx_t * __restrict idxi,
338
- size_t k) const override {
351
+ size_t k,
352
+ const faiss::IDSelector* sel) const override {
339
353
using C = CMax<int32_t , idx_t >;
340
354
341
355
size_t nup = 0 ;
342
356
for (size_t j = 0 ; j < n; j++) {
357
+ if (sel &&
358
+ !sel->is_member (store_pairs ? lo_build (list_no, j) : ids[j])) {
359
+ continue ;
360
+ }
343
361
uint32_t dis = hc.hamming (codes);
344
362
if (dis < simi[0 ]) {
345
363
idx_t id = store_pairs ? lo_build (list_no, j) : ids[j];
@@ -356,8 +374,13 @@ struct IVFBinaryScannerL2 : BinaryInvertedListScanner {
356
374
const uint8_t * __restrict codes,
357
375
const idx_t * __restrict ids,
358
376
int radius,
359
- RangeQueryResult& result) const override {
377
+ RangeQueryResult& result,
378
+ const faiss::IDSelector* sel) const override {
360
379
for (size_t j = 0 ; j < n; j++) {
380
+ if (sel &&
381
+ !sel->is_member (store_pairs ? lo_build (list_no, j) : ids[j])) {
382
+ continue ;
383
+ }
361
384
uint32_t dis = hc.hamming (codes);
362
385
if (dis < radius) {
363
386
int64_t id = store_pairs ? lo_build (list_no, j) : ids[j];
@@ -441,7 +464,13 @@ void search_knn_hamming_heap(
441
464
}
442
465
443
466
nheap += scanner->scan_codes (
444
- list_size, scodes.get (), ids, simi, idxi, k);
467
+ list_size,
468
+ scodes.get (),
469
+ ids,
470
+ simi,
471
+ idxi,
472
+ k,
473
+ params ? params->sel : nullptr );
445
474
446
475
nscan += list_size;
447
476
if (max_codes && nscan >= max_codes)
@@ -807,21 +836,30 @@ void IndexBinaryIVF::range_search(
807
836
const uint8_t * __restrict x,
808
837
int radius,
809
838
RangeSearchResult* __restrict res,
810
- const SearchParameters* params) const {
811
- FAISS_THROW_IF_NOT_MSG (
812
- !params, " search params not supported for this index" );
813
- const size_t nprobe_2 = std::min (nlist, this ->nprobe );
839
+ const SearchParameters* params_in) const {
840
+ const IVFSearchParameters* params = nullptr ;
841
+ if (params_in) {
842
+ params = dynamic_cast <const IVFSearchParameters*>(params_in);
843
+ FAISS_THROW_IF_NOT_MSG (
844
+ params, " IndexBinaryIVF params have incorrect type" );
845
+ }
846
+ const size_t nprobe_2 =
847
+ std::min (nlist, params ? params->nprobe : this ->nprobe );
848
+
849
+ FAISS_THROW_IF_NOT (nprobe_2 > 0 );
850
+
814
851
std::unique_ptr<idx_t []> idx (new idx_t [n * nprobe_2]);
815
852
std::unique_ptr<int32_t []> coarse_dis (new int32_t [n * nprobe_2]);
816
853
817
854
double t0 = getmillisecs ();
818
- quantizer->search (n, x, nprobe_2, coarse_dis.get (), idx.get ());
855
+ quantizer->search (n, x, nprobe_2, coarse_dis.get (), idx.get (), nullptr );
819
856
indexIVF_stats.quantization_time += getmillisecs () - t0;
820
857
821
858
t0 = getmillisecs ();
822
859
invlists->prefetch_lists (idx.get (), n * nprobe_2);
823
860
824
- range_search_preassigned (n, x, radius, idx.get (), coarse_dis.get (), res);
861
+ range_search_preassigned (
862
+ n, x, radius, idx.get (), coarse_dis.get (), res, params);
825
863
826
864
indexIVF_stats.search_time += getmillisecs () - t0;
827
865
}
@@ -832,7 +870,8 @@ void IndexBinaryIVF::range_search_preassigned(
832
870
int radius,
833
871
const idx_t * __restrict assign,
834
872
const int32_t * __restrict centroid_dis,
835
- RangeSearchResult* __restrict res) const {
873
+ RangeSearchResult* __restrict res,
874
+ const IVFSearchParameters* params) const {
836
875
const size_t nprobe_2 = std::min (nlist, this ->nprobe );
837
876
bool store_pairs = false ;
838
877
size_t nlistv = 0 , ndis = 0 ;
@@ -870,7 +909,12 @@ void IndexBinaryIVF::range_search_preassigned(
870
909
nlistv++;
871
910
ndis += list_size;
872
911
scanner->scan_codes_range (
873
- list_size, scodes.get (), ids.get (), radius, qres);
912
+ list_size,
913
+ scodes.get (),
914
+ ids.get (),
915
+ radius,
916
+ qres,
917
+ params ? params->sel : nullptr );
874
918
};
875
919
876
920
#pragma omp for
0 commit comments