Skip to content

Commit 3219e3d

Browse files
mdouze authored and
facebook-github-bot committed
Support Selector for IDMap (facebookresearch#2848)
Summary: Pull Request resolved: facebookresearch#2848 Add selector support for IDMap-wrapped indices. Caveat: this requires wrapping the IDSelector in another selector. Since the params are const, the const is cast away. This is a problem if the same params are used from multiple execution threads with different selectors. However, this seems rare enough to take the risk. Reviewed By: alexanderguzhva Differential Revision: D45598823 fbshipit-source-id: ec23465c13f1f8273a6a46f9aa869ccae2cdb79c
1 parent 5b17225 commit 3219e3d

File tree

5 files changed

+111
-30
lines changed

5 files changed

+111
-30
lines changed

faiss/IndexIDMap.cpp

+42-21
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,6 @@
1616

1717
#include <faiss/impl/AuxIndexStructures.h>
1818
#include <faiss/impl/FaissAssert.h>
19-
#include <faiss/impl/IDSelector.h>
2019
#include <faiss/utils/Heap.h>
2120
#include <faiss/utils/WorkerThread.h>
2221

@@ -71,6 +70,27 @@ void IndexIDMapTemplate<IndexT>::add_with_ids(
7170
this->ntotal = index->ntotal;
7271
}
7372

73+
namespace {
74+
75+
/// RAII object to reset the IDSelector in the params object
76+
// Temporarily replaces the `sel` field of a SearchParameters object and writes
// the previous value back when this object is destroyed.
// NOTE(review): copy construction/assignment are not deleted here, so a copy
// of an armed object would write old_sel back a second time on destruction.
struct ScopedSelChange {
77+
// SearchParameters whose `sel` was overridden; stays nullptr until set() runs.
SearchParameters* params = nullptr;
78+
// Value of params->sel captured by set(); written back by the destructor.
IDSelector* old_sel = nullptr;
79+
80+
// Remember `params`, save its current selector and install `new_sel`.
// `params` must be non-null: it is dereferenced without a check.
// Calling set() a second time overwrites old_sel with whatever the first call
// installed, so the original selector would no longer be restored.
void set(SearchParameters* params, IDSelector* new_sel) {
81+
this->params = params;
82+
old_sel = params->sel;
83+
params->sel = new_sel;
84+
}
85+
// Restore the saved selector, but only if set() was actually called.
~ScopedSelChange() {
86+
if (params) {
87+
params->sel = old_sel;
88+
}
89+
}
90+
};
91+
92+
} // namespace
93+
7494
template <typename IndexT>
7595
void IndexIDMapTemplate<IndexT>::search(
7696
idx_t n,
@@ -79,9 +99,26 @@ void IndexIDMapTemplate<IndexT>::search(
7999
typename IndexT::distance_t* distances,
80100
idx_t* labels,
81101
const SearchParameters* params) const {
82-
FAISS_THROW_IF_NOT_MSG(
83-
!params, "search params not supported for this index");
84-
index->search(n, x, k, distances, labels);
102+
IDSelectorTranslated this_idtrans(this->id_map, nullptr);
103+
ScopedSelChange sel_change;
104+
105+
if (params && params->sel) {
106+
auto idtrans = dynamic_cast<const IDSelectorTranslated*>(params->sel);
107+
108+
if (!idtrans) {
109+
/*
110+
FAISS_THROW_IF_NOT_MSG(
111+
idtrans,
112+
"IndexIDMap requires an IDSelectorTranslated on input");
113+
*/
114+
// then make an idtrans and force it into the SearchParameters
115+
// (hence the const_cast)
116+
auto params_non_const = const_cast<SearchParameters*>(params);
117+
this_idtrans.sel = params->sel;
118+
sel_change.set(params_non_const, &this_idtrans);
119+
}
120+
}
121+
index->search(n, x, k, distances, labels, params);
85122
idx_t* li = labels;
86123
#pragma omp parallel for
87124
for (idx_t i = 0; i < n * k; i++) {
@@ -106,26 +143,10 @@ void IndexIDMapTemplate<IndexT>::range_search(
106143
}
107144
}
108145

109-
namespace {
110-
111-
struct IDTranslatedSelector : IDSelector {
112-
const std::vector<int64_t>& id_map;
113-
const IDSelector& sel;
114-
IDTranslatedSelector(
115-
const std::vector<int64_t>& id_map,
116-
const IDSelector& sel)
117-
: id_map(id_map), sel(sel) {}
118-
bool is_member(idx_t id) const override {
119-
return sel.is_member(id_map[id]);
120-
}
121-
};
122-
123-
} // namespace
124-
125146
template <typename IndexT>
126147
size_t IndexIDMapTemplate<IndexT>::remove_ids(const IDSelector& sel) {
127148
// remove in sub-index first
128-
IDTranslatedSelector sel2(id_map, sel);
149+
IDSelectorTranslated sel2(id_map, &sel);
129150
size_t nremove = index->remove_ids(sel2);
130151

131152
int64_t j = 0;

faiss/IndexIDMap.h

+22
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99

1010
#include <faiss/Index.h>
1111
#include <faiss/IndexBinary.h>
12+
#include <faiss/impl/IDSelector.h>
1213

1314
#include <unordered_map>
1415
#include <vector>
@@ -102,4 +103,25 @@ struct IndexIDMap2Template : IndexIDMapTemplate<IndexT> {
102103
using IndexIDMap2 = IndexIDMap2Template<Index>;
103104
using IndexBinaryIDMap2 = IndexIDMap2Template<IndexBinary>;
104105

106+
// IDSelector that translates the ids using an IDMap
107+
// Selector adapter: membership of `id` is decided by asking the wrapped
// selector about id_map[id] instead of about `id` itself.
struct IDSelectorTranslated : IDSelector {
108+
// Translation table, held by reference (not copied): the referenced vector
// must stay alive for as long as this selector is used.
const std::vector<int64_t>& id_map;
109+
// Wrapped selector, stored as a raw pointer; nothing in this struct deletes
// it. It may be passed as nullptr at construction, but must be non-null by the
// time is_member() is called.
const IDSelector* sel;
110+
111+
// Build from an explicit translation table and the selector to wrap.
IDSelectorTranslated(
112+
const std::vector<int64_t>& id_map,
113+
const IDSelector* sel)
114+
: id_map(id_map), sel(sel) {}
115+
116+
// Convenience overload: use the id_map member of a binary IDMap index.
IDSelectorTranslated(IndexBinaryIDMap& index_idmap, const IDSelector* sel)
117+
: id_map(index_idmap.id_map), sel(sel) {}
118+
119+
// Convenience overload: use the id_map member of an IDMap index.
IDSelectorTranslated(IndexIDMap& index_idmap, const IDSelector* sel)
120+
: id_map(index_idmap.id_map), sel(sel) {}
121+
122+
// Returns sel->is_member(id_map[id]).
// `id` is used directly as an index into id_map with no bounds check, and
// `sel` is dereferenced with no null check.
bool is_member(idx_t id) const override {
123+
return sel->is_member(id_map[id]);
124+
}
125+
};
126+
105127
} // namespace faiss

faiss/python/__init__.py

+1
Original file line numberDiff line numberDiff line change
@@ -192,6 +192,7 @@ def replacement_function(*args):
192192
add_ref_in_constructor(IDSelectorAnd, slice(2))
193193
add_ref_in_constructor(IDSelectorOr, slice(2))
194194
add_ref_in_constructor(IDSelectorXOr, slice(2))
195+
add_ref_in_constructor(IDSelectorTranslated, slice(2))
195196

196197
# seems really marginal...
197198
# remove_ref_from_method(IndexReplicas, 'removeIndex', 0)

faiss/python/swigfaiss.swig

+7-5
Original file line numberDiff line numberDiff line change
@@ -494,11 +494,6 @@ void gpu_sync_all_devices()
494494
%template(IndexBinaryReplicas) faiss::IndexReplicasTemplate<faiss::IndexBinary>;
495495

496496
%include <faiss/MetaIndexes.h>
497-
%include <faiss/IndexIDMap.h>
498-
%template(IndexIDMap) faiss::IndexIDMapTemplate<faiss::Index>;
499-
%template(IndexBinaryIDMap) faiss::IndexIDMapTemplate<faiss::IndexBinary>;
500-
%template(IndexIDMap2) faiss::IndexIDMap2Template<faiss::Index>;
501-
%template(IndexBinaryIDMap2) faiss::IndexIDMap2Template<faiss::IndexBinary>;
502497

503498
%include <faiss/IndexRowwiseMinMax.h>
504499

@@ -513,6 +508,13 @@ void gpu_sync_all_devices()
513508
%include <faiss/impl/AuxIndexStructures.h>
514509
%include <faiss/impl/IDSelector.h>
515510

511+
%include <faiss/IndexIDMap.h>
512+
%template(IndexIDMap) faiss::IndexIDMapTemplate<faiss::Index>;
513+
%template(IndexBinaryIDMap) faiss::IndexIDMapTemplate<faiss::IndexBinary>;
514+
%template(IndexIDMap2) faiss::IndexIDMap2Template<faiss::Index>;
515+
%template(IndexBinaryIDMap2) faiss::IndexIDMap2Template<faiss::IndexBinary>;
516+
517+
516518
%include <faiss/utils/approx_topk/mode.h>
517519

518520
#ifdef GPU_WRAPPER

tests/test_search_params.py

+39-4
Original file line numberDiff line numberDiff line change
@@ -101,17 +101,17 @@ def do_test_id_selector(self, index_key, id_selector_type="batch", mt=faiss.METR
101101
sel = faiss.IDSelectorNot(faiss.IDSelectorBatch(inverse_subset))
102102
elif id_selector_type == "or":
103103
sel = faiss.IDSelectorOr(
104-
faiss.IDSelectorBatch(lhs_subset),
104+
faiss.IDSelectorBatch(lhs_subset),
105105
faiss.IDSelectorBatch(rhs_subset)
106106
)
107107
elif id_selector_type == "and":
108108
sel = faiss.IDSelectorAnd(
109-
faiss.IDSelectorBatch(lhs_subset),
109+
faiss.IDSelectorBatch(lhs_subset),
110110
faiss.IDSelectorBatch(rhs_subset)
111111
)
112112
elif id_selector_type == "xor":
113113
sel = faiss.IDSelectorXOr(
114-
faiss.IDSelectorBatch(lhs_subset),
114+
faiss.IDSelectorBatch(lhs_subset),
115115
faiss.IDSelectorBatch(rhs_subset)
116116
)
117117
else:
@@ -181,7 +181,7 @@ def test_Flat_id_bitmap(self):
181181

182182
def test_Flat_id_not(self):
183183
self.do_test_id_selector("Flat", id_selector_type="not")
184-
184+
185185
def test_Flat_id_or(self):
186186
self.do_test_id_selector("Flat", id_selector_type="or")
187187

@@ -220,6 +220,41 @@ def do_test_id_selector_weak(self, index_key):
220220
def test_HSNW(self):
221221
self.do_test_id_selector_weak("HNSW")
222222

223+
# Checks that searching an "IDMap,SQ8" index with an id selector returns the
# same results as searching an index that contains only the selected vectors.
def test_idmap(self):
224+
ds = datasets.SyntheticDataset(32, 100, 100, 20)
225+
rs = np.random.RandomState(123)
226+
# 100 distinct ids drawn from range(10000); the fixed seed keeps the test deterministic
ids = rs.choice(10000, size=100, replace=False)
227+
# boolean mask selecting the entries whose id is even
mask = ids % 2 == 0
228+
index = faiss.index_factory(ds.d, "IDMap,SQ8")
229+
index.train(ds.get_train())
230+
231+
# ref result
# (index holds only the masked vectors, searched without any selector)
232+
index.add_with_ids(ds.get_database()[mask], ids[mask])
233+
Dref, Iref = index.search(ds.get_queries(), 10)
234+
235+
# with selector
# (index now holds every vector; the selector must filter out the odd ids)
236+
index.reset()
237+
index.add_with_ids(ds.get_database(), ids)
238+
239+
valid_ids = ids[mask]
240+
# case 1: the caller wraps the batch selector in IDSelectorTranslated explicitly
sel = faiss.IDSelectorTranslated(
241+
index, faiss.IDSelectorBatch(valid_ids))
242+
243+
Dnew, Inew = index.search(
244+
ds.get_queries(), 10,
245+
params=faiss.SearchParameters(sel=sel)
246+
)
247+
# labels must match exactly; distances are compared to 5 decimals
np.testing.assert_array_equal(Iref, Inew)
248+
np.testing.assert_array_almost_equal(Dref, Dnew, decimal=5)
249+
250+
# let the IDMap::search add the translation...
# case 2: a plain IDSelectorBatch is passed without the IDSelectorTranslated wrapper
251+
Dnew, Inew = index.search(
252+
ds.get_queries(), 10,
253+
params=faiss.SearchParameters(sel=faiss.IDSelectorBatch(valid_ids))
254+
)
255+
np.testing.assert_array_equal(Iref, Inew)
256+
np.testing.assert_array_almost_equal(Dref, Dnew, decimal=5)
257+
257+
223258

224259
class TestSearchParams(unittest.TestCase):
225260

0 commit comments

Comments
 (0)