Skip to content

Commit 9a66532

Browse files
Add search parameters for IndexRefine::search() and IndexRefineFlat::search() (facebookresearch#3122)
Summary: Add search params for `faiss::IndexRefine` and `faiss::IndexRefineFlat` Pull Request resolved: facebookresearch#3122 Test Plan: buck test //faiss/tests/:test_refine Reviewed By: pemazare Differential Revision: D50968413 Pulled By: mdouze fbshipit-source-id: 9f020d7e9c9d96b9acba54d9d7fff13bcf703b9e
1 parent df7280b commit 9a66532

File tree

3 files changed

+102
-10
lines changed

3 files changed

+102
-10
lines changed

faiss/IndexRefine.cpp

+40-10
Original file line numberDiff line numberDiff line change
@@ -96,12 +96,26 @@ void IndexRefine::search(
9696
idx_t k,
9797
float* distances,
9898
idx_t* labels,
99-
const SearchParameters* params) const {
100-
FAISS_THROW_IF_NOT_MSG(
101-
!params, "search params not supported for this index");
99+
const SearchParameters* params_in) const {
100+
const IndexRefineSearchParameters* params = nullptr;
101+
if (params_in) {
102+
params = dynamic_cast<const IndexRefineSearchParameters*>(params_in);
103+
FAISS_THROW_IF_NOT_MSG(
104+
params, "IndexRefine params have incorrect type");
105+
}
106+
107+
idx_t k_base = (params != nullptr) ? idx_t(k * params->k_factor)
108+
: idx_t(k * k_factor);
109+
SearchParameters* base_index_params =
110+
(params != nullptr) ? params->base_index_params : nullptr;
111+
112+
FAISS_THROW_IF_NOT(k_base >= k);
113+
114+
FAISS_THROW_IF_NOT(base_index);
115+
FAISS_THROW_IF_NOT(refine_index);
116+
102117
FAISS_THROW_IF_NOT(k > 0);
103118
FAISS_THROW_IF_NOT(is_trained);
104-
idx_t k_base = idx_t(k * k_factor);
105119
idx_t* base_labels = labels;
106120
float* base_distances = distances;
107121
ScopeDeleter<idx_t> del1;
@@ -114,7 +128,8 @@ void IndexRefine::search(
114128
del2.set(base_distances);
115129
}
116130

117-
base_index->search(n, x, k_base, base_distances, base_labels);
131+
base_index->search(
132+
n, x, k_base, base_distances, base_labels, base_index_params);
118133

119134
for (int i = 0; i < n * k_base; i++)
120135
assert(base_labels[i] >= -1 && base_labels[i] < ntotal);
@@ -225,12 +240,26 @@ void IndexRefineFlat::search(
225240
idx_t k,
226241
float* distances,
227242
idx_t* labels,
228-
const SearchParameters* params) const {
229-
FAISS_THROW_IF_NOT_MSG(
230-
!params, "search params not supported for this index");
243+
const SearchParameters* params_in) const {
244+
const IndexRefineSearchParameters* params = nullptr;
245+
if (params_in) {
246+
params = dynamic_cast<const IndexRefineSearchParameters*>(params_in);
247+
FAISS_THROW_IF_NOT_MSG(
248+
params, "IndexRefineFlat params have incorrect type");
249+
}
250+
251+
idx_t k_base = (params != nullptr) ? idx_t(k * params->k_factor)
252+
: idx_t(k * k_factor);
253+
SearchParameters* base_index_params =
254+
(params != nullptr) ? params->base_index_params : nullptr;
255+
256+
FAISS_THROW_IF_NOT(k_base >= k);
257+
258+
FAISS_THROW_IF_NOT(base_index);
259+
FAISS_THROW_IF_NOT(refine_index);
260+
231261
FAISS_THROW_IF_NOT(k > 0);
232262
FAISS_THROW_IF_NOT(is_trained);
233-
idx_t k_base = idx_t(k * k_factor);
234263
idx_t* base_labels = labels;
235264
float* base_distances = distances;
236265
ScopeDeleter<idx_t> del1;
@@ -243,7 +272,8 @@ void IndexRefineFlat::search(
243272
del2.set(base_distances);
244273
}
245274

246-
base_index->search(n, x, k_base, base_distances, base_labels);
275+
base_index->search(
276+
n, x, k_base, base_distances, base_labels, base_index_params);
247277

248278
for (int i = 0; i < n * k_base; i++)
249279
assert(base_labels[i] >= -1 && base_labels[i] < ntotal);

faiss/IndexRefine.h

+7
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,13 @@
1111

1212
namespace faiss {
1313

14+
struct IndexRefineSearchParameters : SearchParameters {
15+
float k_factor = 1;
16+
SearchParameters* base_index_params = nullptr; // non-owning
17+
18+
virtual ~IndexRefineSearchParameters() = default;
19+
};
20+
1421
/** Index that queries in a base_index (a fast one) and refines the
1522
* results with an exact search, hopefully improving the results.
1623
*/

tests/test_refine.py

+55
Original file line numberDiff line numberDiff line change
@@ -64,3 +64,58 @@ def test_distance_computer_AQ_LUT(self):
6464

6565
def test_distance_computer_AQ_LUT_IP(self):
6666
self.do_test("RQ3x4_Nqint8", faiss.METRIC_INNER_PRODUCT)
67+
68+
69+
class TestIndexRefineSearchParams(unittest.TestCase):
70+
71+
def do_test(self, factory_string):
72+
ds = datasets.SyntheticDataset(32, 256, 100, 40)
73+
74+
index = faiss.index_factory(32, factory_string)
75+
index.train(ds.get_train())
76+
index.add(ds.get_database())
77+
index.nprobe = 4
78+
xq = ds.get_queries()
79+
80+
# do a search with k_factor = 1
81+
D1, I1 = index.search(xq, 10)
82+
inter1 = faiss.eval_intersection(I1, ds.get_groundtruth(10))
83+
84+
# do a search with k_factor = 1.5
85+
params = faiss.IndexRefineSearchParameters(k_factor=1.1)
86+
D2, I2 = index.search(xq, 10, params=params)
87+
inter2 = faiss.eval_intersection(I2, ds.get_groundtruth(10))
88+
89+
# do a search with k_factor = 2
90+
params = faiss.IndexRefineSearchParameters(k_factor=2)
91+
D3, I3 = index.search(xq, 10, params=params)
92+
inter3 = faiss.eval_intersection(I3, ds.get_groundtruth(10))
93+
94+
# make sure that the recall rate increases with k_factor
95+
self.assertGreater(inter2, inter1)
96+
self.assertGreater(inter3, inter2)
97+
98+
# make sure that the baseline k_factor is unchanged
99+
self.assertEqual(index.k_factor, 1)
100+
101+
# try passing params for the baseline index, change nprobe
102+
base_params = faiss.IVFSearchParameters(nprobe=10)
103+
params = faiss.IndexRefineSearchParameters(k_factor=1, base_index_params=base_params)
104+
D4, I4 = index.search(xq, 10, params=params)
105+
inter4 = faiss.eval_intersection(I4, ds.get_groundtruth(10))
106+
107+
base_params = faiss.IVFSearchParameters(nprobe=2)
108+
params = faiss.IndexRefineSearchParameters(k_factor=1, base_index_params=base_params)
109+
D5, I5 = index.search(xq, 10, params=params)
110+
inter5 = faiss.eval_intersection(I5, ds.get_groundtruth(10))
111+
112+
# make sure that the recall rate changes
113+
self.assertNotEqual(inter4, inter5)
114+
115+
def test_rflat(self):
116+
# flat is handled by the IndexRefineFlat class
117+
self.do_test("IVF8,PQ2x4np,RFlat")
118+
119+
def test_refine_sq8(self):
120+
# this case uses the IndexRefine class
121+
self.do_test("IVF8,PQ2x4np,Refine(SQ8)")

0 commit comments

Comments
 (0)