@@ -50,16 +50,18 @@ void pairwise_extra_distances_template(
50
50
}
51
51
}
52
52
53
- template <class VD , class C >
53
+ template <class VD >
54
54
void knn_extra_metrics_template (
55
55
VD vd,
56
56
const float * x,
57
57
const float * y,
58
58
size_t nx,
59
59
size_t ny,
60
- HeapArray<C>* res) {
61
- size_t k = res->k ;
60
+ size_t k,
61
+ float * distances,
62
+ int64_t * labels) {
62
63
size_t d = vd.d ;
64
+ using C = typename VD::C;
63
65
size_t check_period = InterruptCallback::get_period_hint (ny * d);
64
66
check_period *= omp_get_max_threads ();
65
67
@@ -71,18 +73,15 @@ void knn_extra_metrics_template(
71
73
const float * x_i = x + i * d;
72
74
const float * y_j = y;
73
75
size_t j;
74
- float * simi = res-> get_val (i) ;
75
- int64_t * idxi = res-> get_ids (i) ;
76
+ float * simi = distances + k * i ;
77
+ int64_t * idxi = labels + k * i ;
76
78
77
79
// maxheap_heapify(k, simi, idxi);
78
80
heap_heapify<C>(k, simi, idxi);
79
81
for (j = 0 ; j < ny; j++) {
80
82
float disij = vd (x_i, y_j);
81
83
82
- // if (disij < simi[0]) {
83
- if ((!vd.is_similarity && (disij < simi[0 ])) ||
84
- (vd.is_similarity && (disij > simi[0 ]))) {
85
- // maxheap_replace_top(k, simi, idxi, disij, j);
84
+ if (C::cmp (simi[0 ], disij)) {
86
85
heap_replace_top<C>(k, simi, idxi, disij, j);
87
86
}
88
87
y_j += d;
@@ -165,13 +164,13 @@ void pairwise_extra_distances(
165
164
HANDLE_VAR (Lp);
166
165
HANDLE_VAR (Jaccard);
167
166
HANDLE_VAR (NaNEuclidean);
167
+ HANDLE_VAR (ABS_INNER_PRODUCT);
168
168
#undef HANDLE_VAR
169
169
default :
170
170
FAISS_THROW_MSG (" metric type not implemented" );
171
171
}
172
172
}
173
173
174
- template <class C >
175
174
void knn_extra_metrics (
176
175
const float * x,
177
176
const float * y,
@@ -180,13 +179,15 @@ void knn_extra_metrics(
180
179
size_t ny,
181
180
MetricType mt,
182
181
float metric_arg,
183
- HeapArray<C>* res) {
182
+ size_t k,
183
+ float * distances,
184
+ int64_t * indexes) {
184
185
switch (mt) {
185
- #define HANDLE_VAR (kw ) \
186
- case METRIC_##kw: { \
187
- VectorDistance<METRIC_##kw> vd = {(size_t )d, metric_arg}; \
188
- knn_extra_metrics_template (vd, x, y, nx, ny, res); \
189
- break ; \
186
+ #define HANDLE_VAR (kw ) \
187
+ case METRIC_##kw: { \
188
+ VectorDistance<METRIC_##kw> vd = {(size_t )d, metric_arg}; \
189
+ knn_extra_metrics_template (vd, x, y, nx, ny, k, distances, indexes); \
190
+ break ; \
190
191
}
191
192
HANDLE_VAR (L2);
192
193
HANDLE_VAR (L1);
@@ -197,32 +198,13 @@ void knn_extra_metrics(
197
198
HANDLE_VAR (Lp);
198
199
HANDLE_VAR (Jaccard);
199
200
HANDLE_VAR (NaNEuclidean);
201
+ HANDLE_VAR (ABS_INNER_PRODUCT);
200
202
#undef HANDLE_VAR
201
203
default :
202
204
FAISS_THROW_MSG (" metric type not implemented" );
203
205
}
204
206
}
205
207
206
- template void knn_extra_metrics<CMax<float , int64_t >>(
207
- const float * x,
208
- const float * y,
209
- size_t d,
210
- size_t nx,
211
- size_t ny,
212
- MetricType mt,
213
- float metric_arg,
214
- HeapArray<CMax<float , int64_t >>* res);
215
-
216
- template void knn_extra_metrics<CMin<float , int64_t >>(
217
- const float * x,
218
- const float * y,
219
- size_t d,
220
- size_t nx,
221
- size_t ny,
222
- MetricType mt,
223
- float metric_arg,
224
- HeapArray<CMin<float , int64_t >>* res);
225
-
226
208
FlatCodesDistanceComputer* get_extra_distance_computer (
227
209
size_t d,
228
210
MetricType mt,
@@ -245,6 +227,7 @@ FlatCodesDistanceComputer* get_extra_distance_computer(
245
227
HANDLE_VAR (Lp);
246
228
HANDLE_VAR (Jaccard);
247
229
HANDLE_VAR (NaNEuclidean);
230
+ HANDLE_VAR (ABS_INNER_PRODUCT);
248
231
#undef HANDLE_VAR
249
232
default :
250
233
FAISS_THROW_MSG (" metric type not implemented" );
0 commit comments