@@ -96,12 +96,26 @@ void IndexRefine::search(
96
96
idx_t k,
97
97
float * distances,
98
98
idx_t * labels,
99
- const SearchParameters* params) const {
100
- FAISS_THROW_IF_NOT_MSG (
101
- !params, " search params not supported for this index" );
99
+ const SearchParameters* params_in) const {
100
+ const IndexRefineSearchParameters* params = nullptr ;
101
+ if (params_in) {
102
+ params = dynamic_cast <const IndexRefineSearchParameters*>(params_in);
103
+ FAISS_THROW_IF_NOT_MSG (
104
+ params, " IndexRefine params have incorrect type" );
105
+ }
106
+
107
+ idx_t k_base = (params != nullptr ) ? idx_t (k * params->k_factor )
108
+ : idx_t (k * k_factor);
109
+ SearchParameters* base_index_params =
110
+ (params != nullptr ) ? params->base_index_params : nullptr ;
111
+
112
+ FAISS_THROW_IF_NOT (k_base >= k);
113
+
114
+ FAISS_THROW_IF_NOT (base_index);
115
+ FAISS_THROW_IF_NOT (refine_index);
116
+
102
117
FAISS_THROW_IF_NOT (k > 0 );
103
118
FAISS_THROW_IF_NOT (is_trained);
104
- idx_t k_base = idx_t (k * k_factor);
105
119
idx_t * base_labels = labels;
106
120
float * base_distances = distances;
107
121
ScopeDeleter<idx_t > del1;
@@ -114,7 +128,8 @@ void IndexRefine::search(
114
128
del2.set (base_distances);
115
129
}
116
130
117
- base_index->search (n, x, k_base, base_distances, base_labels);
131
+ base_index->search (
132
+ n, x, k_base, base_distances, base_labels, base_index_params);
118
133
119
134
for (int i = 0 ; i < n * k_base; i++)
120
135
assert (base_labels[i] >= -1 && base_labels[i] < ntotal);
@@ -225,12 +240,26 @@ void IndexRefineFlat::search(
225
240
idx_t k,
226
241
float * distances,
227
242
idx_t * labels,
228
- const SearchParameters* params) const {
229
- FAISS_THROW_IF_NOT_MSG (
230
- !params, " search params not supported for this index" );
243
+ const SearchParameters* params_in) const {
244
+ const IndexRefineSearchParameters* params = nullptr ;
245
+ if (params_in) {
246
+ params = dynamic_cast <const IndexRefineSearchParameters*>(params_in);
247
+ FAISS_THROW_IF_NOT_MSG (
248
+ params, " IndexRefineFlat params have incorrect type" );
249
+ }
250
+
251
+ idx_t k_base = (params != nullptr ) ? idx_t (k * params->k_factor )
252
+ : idx_t (k * k_factor);
253
+ SearchParameters* base_index_params =
254
+ (params != nullptr ) ? params->base_index_params : nullptr ;
255
+
256
+ FAISS_THROW_IF_NOT (k_base >= k);
257
+
258
+ FAISS_THROW_IF_NOT (base_index);
259
+ FAISS_THROW_IF_NOT (refine_index);
260
+
231
261
FAISS_THROW_IF_NOT (k > 0 );
232
262
FAISS_THROW_IF_NOT (is_trained);
233
- idx_t k_base = idx_t (k * k_factor);
234
263
idx_t * base_labels = labels;
235
264
float * base_distances = distances;
236
265
ScopeDeleter<idx_t > del1;
@@ -243,7 +272,8 @@ void IndexRefineFlat::search(
243
272
del2.set (base_distances);
244
273
}
245
274
246
- base_index->search (n, x, k_base, base_distances, base_labels);
275
+ base_index->search (
276
+ n, x, k_base, base_distances, base_labels, base_index_params);
247
277
248
278
for (int i = 0 ; i < n * k_base; i++)
249
279
assert (base_labels[i] >= -1 && base_labels[i] < ntotal);
0 commit comments