Skip to content

Commit 6797877

Browse files
authored
Merge pull request #3 from lry22221111/main
[feat]Add binary flat and binary ivf flat sel
2 parents e49ff16 + 863a867 commit 6797877

File tree

6 files changed

+118
-57
lines changed

6 files changed

+118
-57
lines changed

faiss/IndexBinaryFlat.cpp

+7-6
Original file line numberDiff line numberDiff line change
@@ -38,8 +38,8 @@ void IndexBinaryFlat::search(
3838
int32_t* distances,
3939
idx_t* labels,
4040
const SearchParameters* params) const {
41-
FAISS_THROW_IF_NOT_MSG(
42-
!params, "search params not supported for this index");
41+
IDSelector* sel = params ? params->sel : nullptr;
42+
4343
FAISS_THROW_IF_NOT(k > 0);
4444

4545
const idx_t block_size = query_batch_size;
@@ -61,7 +61,8 @@ void IndexBinaryFlat::search(
6161
ntotal,
6262
code_size,
6363
/* ordered = */ true,
64-
approx_topk_mode);
64+
approx_topk_mode,
65+
sel);
6566
} else {
6667
hammings_knn_mc(
6768
x + s * code_size,
@@ -108,9 +109,9 @@ void IndexBinaryFlat::range_search(
108109
int radius,
109110
RangeSearchResult* result,
110111
const SearchParameters* params) const {
111-
FAISS_THROW_IF_NOT_MSG(
112-
!params, "search params not supported for this index");
113-
hamming_range_search(x, xb.data(), n, ntotal, radius, code_size, result);
112+
IDSelector* sel = params ? params->sel : nullptr;
113+
114+
hamming_range_search(x, xb.data(), n, ntotal, radius, code_size, result,sel);
114115
}
115116

116117
} // namespace faiss

faiss/IndexBinaryIVF.cpp

+62-18
Original file line numberDiff line numberDiff line change
@@ -113,25 +113,38 @@ void IndexBinaryIVF::search(
113113
idx_t k,
114114
int32_t* distances,
115115
idx_t* labels,
116-
const SearchParameters* params) const {
117-
FAISS_THROW_IF_NOT_MSG(
118-
!params, "search params not supported for this index");
116+
const SearchParameters* params_in) const {
117+
const IVFSearchParameters* params = nullptr;
118+
if (params_in) {
119+
params = dynamic_cast<const IVFSearchParameters*>(params_in);
120+
FAISS_THROW_IF_NOT_MSG(
121+
params, "IndexBinaryIVF params have incorrect type");
122+
}
123+
const size_t nprobe_2 =
124+
std::min(nlist, params ? params->nprobe : this->nprobe);
119125
FAISS_THROW_IF_NOT(k > 0);
120-
FAISS_THROW_IF_NOT(nprobe > 0);
126+
FAISS_THROW_IF_NOT(nprobe_2 > 0);
121127

122-
const size_t nprobe_2 = std::min(nlist, this->nprobe);
123128
std::unique_ptr<idx_t[]> idx(new idx_t[n * nprobe_2]);
124129
std::unique_ptr<int32_t[]> coarse_dis(new int32_t[n * nprobe_2]);
125130

126131
double t0 = getmillisecs();
127-
quantizer->search(n, x, nprobe_2, coarse_dis.get(), idx.get());
132+
quantizer->search(n, x, nprobe_2, coarse_dis.get(), idx.get(), nullptr);
128133
indexIVF_stats.quantization_time += getmillisecs() - t0;
129134

130135
t0 = getmillisecs();
131136
invlists->prefetch_lists(idx.get(), n * nprobe_2);
132137

133138
search_preassigned(
134-
n, x, k, idx.get(), coarse_dis.get(), distances, labels, false);
139+
n,
140+
x,
141+
k,
142+
idx.get(),
143+
coarse_dis.get(),
144+
distances,
145+
labels,
146+
false,
147+
params);
135148
indexIVF_stats.search_time += getmillisecs() - t0;
136149
}
137150

@@ -335,11 +348,16 @@ struct IVFBinaryScannerL2 : BinaryInvertedListScanner {
335348
const idx_t* __restrict ids,
336349
int32_t* __restrict simi,
337350
idx_t* __restrict idxi,
338-
size_t k) const override {
351+
size_t k,
352+
const faiss::IDSelector* sel) const override {
339353
using C = CMax<int32_t, idx_t>;
340354

341355
size_t nup = 0;
342356
for (size_t j = 0; j < n; j++) {
357+
if (sel &&
358+
!sel->is_member(store_pairs ? lo_build(list_no, j) : ids[j])) {
359+
continue;
360+
}
343361
uint32_t dis = hc.hamming(codes);
344362
if (dis < simi[0]) {
345363
idx_t id = store_pairs ? lo_build(list_no, j) : ids[j];
@@ -356,8 +374,13 @@ struct IVFBinaryScannerL2 : BinaryInvertedListScanner {
356374
const uint8_t* __restrict codes,
357375
const idx_t* __restrict ids,
358376
int radius,
359-
RangeQueryResult& result) const override {
377+
RangeQueryResult& result,
378+
const faiss::IDSelector* sel) const override {
360379
for (size_t j = 0; j < n; j++) {
380+
if (sel &&
381+
!sel->is_member(store_pairs ? lo_build(list_no, j) : ids[j])) {
382+
continue;
383+
}
361384
uint32_t dis = hc.hamming(codes);
362385
if (dis < radius) {
363386
int64_t id = store_pairs ? lo_build(list_no, j) : ids[j];
@@ -441,7 +464,13 @@ void search_knn_hamming_heap(
441464
}
442465

443466
nheap += scanner->scan_codes(
444-
list_size, scodes.get(), ids, simi, idxi, k);
467+
list_size,
468+
scodes.get(),
469+
ids,
470+
simi,
471+
idxi,
472+
k,
473+
params ? params->sel : nullptr);
445474

446475
nscan += list_size;
447476
if (max_codes && nscan >= max_codes)
@@ -807,21 +836,30 @@ void IndexBinaryIVF::range_search(
807836
const uint8_t* __restrict x,
808837
int radius,
809838
RangeSearchResult* __restrict res,
810-
const SearchParameters* params) const {
811-
FAISS_THROW_IF_NOT_MSG(
812-
!params, "search params not supported for this index");
813-
const size_t nprobe_2 = std::min(nlist, this->nprobe);
839+
const SearchParameters* params_in) const {
840+
const IVFSearchParameters* params = nullptr;
841+
if (params_in) {
842+
params = dynamic_cast<const IVFSearchParameters*>(params_in);
843+
FAISS_THROW_IF_NOT_MSG(
844+
params, "IndexBinaryIVF params have incorrect type");
845+
}
846+
const size_t nprobe_2 =
847+
std::min(nlist, params ? params->nprobe : this->nprobe);
848+
849+
FAISS_THROW_IF_NOT(nprobe_2 > 0);
850+
814851
std::unique_ptr<idx_t[]> idx(new idx_t[n * nprobe_2]);
815852
std::unique_ptr<int32_t[]> coarse_dis(new int32_t[n * nprobe_2]);
816853

817854
double t0 = getmillisecs();
818-
quantizer->search(n, x, nprobe_2, coarse_dis.get(), idx.get());
855+
quantizer->search(n, x, nprobe_2, coarse_dis.get(), idx.get(),nullptr);
819856
indexIVF_stats.quantization_time += getmillisecs() - t0;
820857

821858
t0 = getmillisecs();
822859
invlists->prefetch_lists(idx.get(), n * nprobe_2);
823860

824-
range_search_preassigned(n, x, radius, idx.get(), coarse_dis.get(), res);
861+
range_search_preassigned(
862+
n, x, radius, idx.get(), coarse_dis.get(), res, params);
825863

826864
indexIVF_stats.search_time += getmillisecs() - t0;
827865
}
@@ -832,7 +870,8 @@ void IndexBinaryIVF::range_search_preassigned(
832870
int radius,
833871
const idx_t* __restrict assign,
834872
const int32_t* __restrict centroid_dis,
835-
RangeSearchResult* __restrict res) const {
873+
RangeSearchResult* __restrict res,
874+
const IVFSearchParameters* params) const {
836875
const size_t nprobe_2 = std::min(nlist, this->nprobe);
837876
bool store_pairs = false;
838877
size_t nlistv = 0, ndis = 0;
@@ -870,7 +909,12 @@ void IndexBinaryIVF::range_search_preassigned(
870909
nlistv++;
871910
ndis += list_size;
872911
scanner->scan_codes_range(
873-
list_size, scodes.get(), ids.get(), radius, qres);
912+
list_size,
913+
scodes.get(),
914+
ids.get(),
915+
radius,
916+
qres,
917+
params ? params->sel : nullptr);
874918
};
875919

876920
#pragma omp for

faiss/IndexBinaryIVF.h

+6-3
Original file line numberDiff line numberDiff line change
@@ -148,7 +148,8 @@ struct IndexBinaryIVF : IndexBinary {
148148
int radius,
149149
const idx_t* assign,
150150
const int32_t* centroid_dis,
151-
RangeSearchResult* result) const;
151+
RangeSearchResult* result,
152+
const IVFSearchParameters* params=nullptr) const;
152153

153154
void reconstruct(idx_t key, uint8_t* recons) const override;
154155

@@ -243,14 +244,16 @@ struct BinaryInvertedListScanner {
243244
const idx_t* ids,
244245
int32_t* distances,
245246
idx_t* labels,
246-
size_t k) const = 0;
247+
size_t k,
248+
const faiss::IDSelector* sel=nullptr) const = 0;
247249

248250
virtual void scan_codes_range(
249251
size_t n,
250252
const uint8_t* codes,
251253
const idx_t* ids,
252254
int radius,
253-
RangeQueryResult& result) const = 0;
255+
RangeQueryResult& result,
256+
const faiss::IDSelector* sel=nullptr) const = 0;
254257

255258
virtual ~BinaryInvertedListScanner() {}
256259
};

faiss/utils/approx_topk_hamming/approx_topk_hamming.h

+20-21
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
#include <limits>
1212
#include <utility>
1313

14+
#include <faiss/impl/IDSelector.h>
1415
#include <faiss/utils/Heap.h>
1516
#include <faiss/utils/simdlib.h>
1617

@@ -46,9 +47,11 @@ struct HeapWithBucketsForHamming32<
4647
// output distances
4748
int* const __restrict bh_val,
4849
// output indices, each being within [0, n) range
49-
int64_t* const __restrict bh_ids) {
50+
int64_t* const __restrict bh_ids,
51+
size_t globe_size,
52+
const faiss::IDSelector* sel = nullptr) {
5053
// forward a call to bs_addn with 1 beam
51-
bs_addn(1, n, hc, binaryVectors, k, bh_val, bh_ids);
54+
bs_addn(1, n, hc, binaryVectors, k, bh_val, bh_ids, globe_size, sel);
5255
}
5356

5457
static void bs_addn(
@@ -66,7 +69,9 @@ struct HeapWithBucketsForHamming32<
6669
int* const __restrict bh_val,
6770
// output indices, each being within [0, n_per_beam * beam_size)
6871
// range
69-
int64_t* const __restrict bh_ids) {
72+
int64_t* const __restrict bh_ids,
73+
size_t globe_size,
74+
const faiss::IDSelector* sel = nullptr) {
7075
//
7176
using C = CMax<int, int64_t>;
7277

@@ -96,6 +101,12 @@ struct HeapWithBucketsForHamming32<
96101
for (uint32_t j = 0; j < NBUCKETS_8; j++) {
97102
uint32_t hamming_distances[8];
98103
for (size_t j8 = 0; j8 < 8; j8++) {
104+
const auto index = globe_size + j8 + j * 8 + ip +
105+
n_per_beam * beam_index;
106+
if (sel && !sel->is_member(index)) {
107+
hamming_distances[j8] = std::numeric_limits<uint32_t>::max();
108+
continue;
109+
}
99110
hamming_distances[j8] = hc.hamming(
100111
binary_vectors +
101112
(j8 + j * 8 + ip + n_per_beam * beam_index) *
@@ -168,11 +179,15 @@ struct HeapWithBucketsForHamming32<
168179
// process leftovers
169180
for (uint32_t ip = nb; ip < n_per_beam; ip++) {
170181
const auto index = ip + n_per_beam * beam_index;
182+
if (sel && !sel->is_member(index + globe_size)) {
183+
continue;
184+
}
171185
const auto value =
172186
hc.hamming(binary_vectors + (index)*code_size);
173187

174188
if (C::cmp(bh_val[0], value)) {
175-
heap_replace_top<C>(k, bh_val, bh_ids, value, index);
189+
heap_replace_top<C>(
190+
k, bh_val, bh_ids, value, index + globe_size);
176191
}
177192
}
178193
}
@@ -249,23 +264,7 @@ struct HeapWithBucketsForHamming16<
249264
for (uint32_t p = 0; p < N; p++) {
250265
min_distances_i[j][p] =
251266
simd16uint16(std::numeric_limits<int16_t>::max());
252-
min_indices_i[j][p] = simd16uint16(
253-
0,
254-
1,
255-
2,
256-
3,
257-
4,
258-
5,
259-
6,
260-
7,
261-
8,
262-
9,
263-
10,
264-
11,
265-
12,
266-
13,
267-
14,
268-
15);
267+
min_indices_i[j][p] = simd16uint16(0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15);
269268
}
270269
}
271270

faiss/utils/hamming.cpp

+18-7
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131

3232
#include <faiss/impl/AuxIndexStructures.h>
3333
#include <faiss/impl/FaissAssert.h>
34+
#include <faiss/impl/IDSelector.h>
3435
#include <faiss/utils/Heap.h>
3536
#include <faiss/utils/approx_topk_hamming/approx_topk_hamming.h>
3637
#include <faiss/utils/utils.h>
@@ -172,7 +173,8 @@ void hammings_knn_hc(
172173
size_t n2,
173174
bool order = true,
174175
bool init_heap = true,
175-
ApproxTopK_mode_t approx_topk_mode = ApproxTopK_mode_t::EXACT_TOPK) {
176+
ApproxTopK_mode_t approx_topk_mode = ApproxTopK_mode_t::EXACT_TOPK,
177+
const faiss::IDSelector* sel = nullptr) {
176178
size_t k = ha->k;
177179
if (init_heap)
178180
ha->heapify();
@@ -205,7 +207,7 @@ void hammings_knn_hc(
205207
NB, \
206208
BD, \
207209
HammingComputer>:: \
208-
addn(j1 - j0, hc, bs2_, k, bh_val_, bh_ids_); \
210+
addn(j1 - j0, hc, bs2_, k, bh_val_, bh_ids_, j0, sel); \
209211
break;
210212

211213
switch (approx_topk_mode) {
@@ -215,6 +217,9 @@ void hammings_knn_hc(
215217
HANDLE_APPROX(32, 2)
216218
default: {
217219
for (size_t j = j0; j < j1; j++, bs2_ += bytes_per_code) {
220+
if (sel && !sel->is_member(j)) {
221+
continue;
222+
}
218223
dis = hc.hamming(bs2_);
219224
if (dis < bh_val_[0]) {
220225
faiss::maxheap_replace_top<hamdis_t>(
@@ -292,7 +297,8 @@ void hamming_range_search(
292297
size_t nb,
293298
int radius,
294299
size_t code_size,
295-
RangeSearchResult* res) {
300+
RangeSearchResult* res,
301+
const faiss::IDSelector* sel = nullptr) {
296302
#pragma omp parallel
297303
{
298304
RangeSearchPartialResult pres(res);
@@ -304,6 +310,9 @@ void hamming_range_search(
304310
RangeQueryResult& qres = pres.new_result(i);
305311

306312
for (size_t j = 0; j < nb; j++) {
313+
if (sel && !sel->is_member(j)) {
314+
continue;
315+
}
307316
int dis = hc.hamming(yi);
308317
if (dis < radius) {
309318
qres.add(dis, j);
@@ -490,10 +499,11 @@ void hammings_knn_hc(
490499
size_t nb,
491500
size_t ncodes,
492501
int order,
493-
ApproxTopK_mode_t approx_topk_mode) {
502+
ApproxTopK_mode_t approx_topk_mode,
503+
const faiss::IDSelector* sel) {
494504
Run_hammings_knn_hc r;
495505
dispatch_HammingComputer(
496-
ncodes, r, ncodes, ha, a, b, nb, order, true, approx_topk_mode);
506+
ncodes, r, ncodes, ha, a, b, nb, order, true, approx_topk_mode,sel);
497507
}
498508

499509
void hammings_knn_mc(
@@ -517,10 +527,11 @@ void hamming_range_search(
517527
size_t nb,
518528
int radius,
519529
size_t code_size,
520-
RangeSearchResult* result) {
530+
RangeSearchResult* result,
531+
const faiss::IDSelector* sel) {
521532
Run_hamming_range_search r;
522533
dispatch_HammingComputer(
523-
code_size, r, a, b, na, nb, radius, code_size, result);
534+
code_size, r, a, b, na, nb, radius, code_size, result,sel);
524535
}
525536

526537
/* Count number of matches given a max threshold */

0 commit comments

Comments
 (0)