Skip to content

Commit b1f9223

Browse files
mdouzeabhinavdangeti
authored andcommitted
Add ABS_INNER_PRODUCT metric (facebookresearch#3524)
Summary: Pull Request resolved: facebookresearch#3524 Searches with the metric abs(dot(query, database)) This makes it possible to search vectors that are closest to a hyperplane * adds support for alternative metrics in faiss.knn in python * checks that it works with HNSW * simplifies the extra distances interface by removing the template on Reviewed By: asadoughi Differential Revision: D58695971 fbshipit-source-id: 2a0ff49c7f7ac2c005d85f141cc5de148081c9c4
1 parent 32a6a91 commit b1f9223

8 files changed

+82
-50
lines changed

faiss/IndexFlat.cpp

+11-7
Original file line numberDiff line numberDiff line change
@@ -41,15 +41,19 @@ void IndexFlat::search(
4141
} else if (metric_type == METRIC_L2) {
4242
float_maxheap_array_t res = {size_t(n), size_t(k), labels, distances};
4343
knn_L2sqr(x, get_xb(), d, n, ntotal, &res, nullptr, sel);
44-
} else if (is_similarity_metric(metric_type)) {
45-
float_minheap_array_t res = {size_t(n), size_t(k), labels, distances};
46-
knn_extra_metrics(
47-
x, get_xb(), d, n, ntotal, metric_type, metric_arg, &res);
4844
} else {
49-
FAISS_THROW_IF_NOT(!sel);
50-
float_maxheap_array_t res = {size_t(n), size_t(k), labels, distances};
45+
FAISS_THROW_IF_NOT(!sel); // TODO implement with selector
5146
knn_extra_metrics(
52-
x, get_xb(), d, n, ntotal, metric_type, metric_arg, &res);
47+
x,
48+
get_xb(),
49+
d,
50+
n,
51+
ntotal,
52+
metric_type,
53+
metric_arg,
54+
k,
55+
distances,
56+
labels);
5357
}
5458
}
5559

faiss/MetricType.h

+6-2
Original file line numberDiff line numberDiff line change
@@ -31,9 +31,13 @@ enum MetricType {
3131
METRIC_Canberra = 20,
3232
METRIC_BrayCurtis,
3333
METRIC_JensenShannon,
34-
METRIC_Jaccard, ///< defined as: sum_i(min(a_i, b_i)) / sum_i(max(a_i, b_i))
35-
///< where a_i, b_i > 0
34+
35+
/// sum_i(min(a_i, b_i)) / sum_i(max(a_i, b_i)) where a_i, b_i > 0
36+
METRIC_Jaccard,
37+
/// Squared Eucliden distance, ignoring NaNs
3638
METRIC_NaNEuclidean,
39+
/// abs(x | y): the distance to a hyperplane
40+
METRIC_ABS_INNER_PRODUCT,
3741
};
3842

3943
/// all vector indices are this type

faiss/python/extra_wrappers.py

+9-3
Original file line numberDiff line numberDiff line change
@@ -330,7 +330,7 @@ def lookup(self, keys):
330330
# KNN function
331331
######################################################
332332

333-
def knn(xq, xb, k, metric=METRIC_L2):
333+
def knn(xq, xb, k, metric=METRIC_L2, metric_arg=0.0):
334334
"""
335335
Compute the k nearest neighbors of a vector without constructing an index
336336
@@ -374,10 +374,16 @@ def knn(xq, xb, k, metric=METRIC_L2):
374374
swig_ptr(xq), swig_ptr(xb),
375375
d, nq, nb, k, swig_ptr(D), swig_ptr(I)
376376
)
377-
else:
378-
raise NotImplementedError("only L2 and INNER_PRODUCT are supported")
377+
else:
378+
knn_extra_metrics(
379+
swig_ptr(xq), swig_ptr(xb),
380+
d, nq, nb, metric, metric_arg, k,
381+
swig_ptr(D), swig_ptr(I)
382+
)
383+
379384
return D, I
380385

386+
381387
def knn_hamming(xq, xb, k, variant="hc"):
382388
"""
383389
Compute the k nearest neighbors of a set of vectors without constructing an index.

faiss/utils/extra_distances-inl.h

+12
Original file line numberDiff line numberDiff line change
@@ -150,4 +150,16 @@ inline float VectorDistance<METRIC_NaNEuclidean>::operator()(
150150
}
151151
return float(d) / float(present) * accu;
152152
}
153+
154+
template <>
155+
inline float VectorDistance<METRIC_ABS_INNER_PRODUCT>::operator()(
156+
const float* x,
157+
const float* y) const {
158+
float accu = 0;
159+
for (size_t i = 0; i < d; i++) {
160+
accu += fabs(x[i] * y[i]);
161+
}
162+
return accu;
163+
}
164+
153165
} // namespace faiss

faiss/utils/extra_distances.cpp

+19-36
Original file line numberDiff line numberDiff line change
@@ -50,16 +50,18 @@ void pairwise_extra_distances_template(
5050
}
5151
}
5252

53-
template <class VD, class C>
53+
template <class VD>
5454
void knn_extra_metrics_template(
5555
VD vd,
5656
const float* x,
5757
const float* y,
5858
size_t nx,
5959
size_t ny,
60-
HeapArray<C>* res) {
61-
size_t k = res->k;
60+
size_t k,
61+
float* distances,
62+
int64_t* labels) {
6263
size_t d = vd.d;
64+
using C = typename VD::C;
6365
size_t check_period = InterruptCallback::get_period_hint(ny * d);
6466
check_period *= omp_get_max_threads();
6567

@@ -71,18 +73,15 @@ void knn_extra_metrics_template(
7173
const float* x_i = x + i * d;
7274
const float* y_j = y;
7375
size_t j;
74-
float* simi = res->get_val(i);
75-
int64_t* idxi = res->get_ids(i);
76+
float* simi = distances + k * i;
77+
int64_t* idxi = labels + k * i;
7678

7779
// maxheap_heapify(k, simi, idxi);
7880
heap_heapify<C>(k, simi, idxi);
7981
for (j = 0; j < ny; j++) {
8082
float disij = vd(x_i, y_j);
8183

82-
// if (disij < simi[0]) {
83-
if ((!vd.is_similarity && (disij < simi[0])) ||
84-
(vd.is_similarity && (disij > simi[0]))) {
85-
// maxheap_replace_top(k, simi, idxi, disij, j);
84+
if (C::cmp(simi[0], disij)) {
8685
heap_replace_top<C>(k, simi, idxi, disij, j);
8786
}
8887
y_j += d;
@@ -165,13 +164,13 @@ void pairwise_extra_distances(
165164
HANDLE_VAR(Lp);
166165
HANDLE_VAR(Jaccard);
167166
HANDLE_VAR(NaNEuclidean);
167+
HANDLE_VAR(ABS_INNER_PRODUCT);
168168
#undef HANDLE_VAR
169169
default:
170170
FAISS_THROW_MSG("metric type not implemented");
171171
}
172172
}
173173

174-
template <class C>
175174
void knn_extra_metrics(
176175
const float* x,
177176
const float* y,
@@ -180,13 +179,15 @@ void knn_extra_metrics(
180179
size_t ny,
181180
MetricType mt,
182181
float metric_arg,
183-
HeapArray<C>* res) {
182+
size_t k,
183+
float* distances,
184+
int64_t* indexes) {
184185
switch (mt) {
185-
#define HANDLE_VAR(kw) \
186-
case METRIC_##kw: { \
187-
VectorDistance<METRIC_##kw> vd = {(size_t)d, metric_arg}; \
188-
knn_extra_metrics_template(vd, x, y, nx, ny, res); \
189-
break; \
186+
#define HANDLE_VAR(kw) \
187+
case METRIC_##kw: { \
188+
VectorDistance<METRIC_##kw> vd = {(size_t)d, metric_arg}; \
189+
knn_extra_metrics_template(vd, x, y, nx, ny, k, distances, indexes); \
190+
break; \
190191
}
191192
HANDLE_VAR(L2);
192193
HANDLE_VAR(L1);
@@ -197,32 +198,13 @@ void knn_extra_metrics(
197198
HANDLE_VAR(Lp);
198199
HANDLE_VAR(Jaccard);
199200
HANDLE_VAR(NaNEuclidean);
201+
HANDLE_VAR(ABS_INNER_PRODUCT);
200202
#undef HANDLE_VAR
201203
default:
202204
FAISS_THROW_MSG("metric type not implemented");
203205
}
204206
}
205207

206-
template void knn_extra_metrics<CMax<float, int64_t>>(
207-
const float* x,
208-
const float* y,
209-
size_t d,
210-
size_t nx,
211-
size_t ny,
212-
MetricType mt,
213-
float metric_arg,
214-
HeapArray<CMax<float, int64_t>>* res);
215-
216-
template void knn_extra_metrics<CMin<float, int64_t>>(
217-
const float* x,
218-
const float* y,
219-
size_t d,
220-
size_t nx,
221-
size_t ny,
222-
MetricType mt,
223-
float metric_arg,
224-
HeapArray<CMin<float, int64_t>>* res);
225-
226208
FlatCodesDistanceComputer* get_extra_distance_computer(
227209
size_t d,
228210
MetricType mt,
@@ -245,6 +227,7 @@ FlatCodesDistanceComputer* get_extra_distance_computer(
245227
HANDLE_VAR(Lp);
246228
HANDLE_VAR(Jaccard);
247229
HANDLE_VAR(NaNEuclidean);
230+
HANDLE_VAR(ABS_INNER_PRODUCT);
248231
#undef HANDLE_VAR
249232
default:
250233
FAISS_THROW_MSG("metric type not implemented");

faiss/utils/extra_distances.h

+3-2
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,6 @@ void pairwise_extra_distances(
3333
int64_t ldb = -1,
3434
int64_t ldd = -1);
3535

36-
template <class C>
3736
void knn_extra_metrics(
3837
const float* x,
3938
const float* y,
@@ -42,7 +41,9 @@ void knn_extra_metrics(
4241
size_t ny,
4342
MetricType mt,
4443
float metric_arg,
45-
HeapArray<C>* res);
44+
size_t k,
45+
float* distances,
46+
int64_t* indexes);
4647

4748
/** get a DistanceComputer that refers to this type of distance and
4849
* indexes a flat array of size nb */

tests/test_extra_distances.py

+7
Original file line numberDiff line numberDiff line change
@@ -114,6 +114,13 @@ def test_nan_euclidean(self):
114114
new_dis = faiss.pairwise_distances(x, q, faiss.METRIC_NaNEuclidean)
115115
self.assertTrue(np.isnan(new_dis[0]))
116116

117+
def test_abs_inner_product(self):
118+
xq, yb = self.make_example()
119+
dis = faiss.pairwise_distances(xq, yb, faiss.METRIC_ABS_INNER_PRODUCT)
120+
121+
gt_dis = np.abs(xq @ yb.T)
122+
np.testing.assert_allclose(dis, gt_dis, atol=1e-5)
123+
117124

118125
class TestKNN(unittest.TestCase):
119126
""" test that the knn search gives the same as distance matrix + argmin """

tests/test_graph_based.py

+15
Original file line numberDiff line numberDiff line change
@@ -169,6 +169,21 @@ def test_io_no_storage(self):
169169
)
170170
self.assertEquals(index3.storage, None)
171171

172+
def test_abs_inner_product(self):
173+
"""Test HNSW with abs inner product (not a real distance, so dubious that triangular inequality works)"""
174+
d = self.xq.shape[1]
175+
xb = self.xb - self.xb.mean(axis=0) # need to be centered to give interesting directions
176+
xq = self.xq - self.xq.mean(axis=0)
177+
Dref, Iref = faiss.knn(xq, xb, 10, faiss.METRIC_ABS_INNER_PRODUCT)
178+
179+
index = faiss.IndexHNSWFlat(d, 32, faiss.METRIC_ABS_INNER_PRODUCT)
180+
index.add(xb)
181+
Dnew, Inew = index.search(xq, 10)
182+
183+
inter = faiss.eval_intersection(Iref, Inew)
184+
# 4769 vs. 500*10
185+
self.assertGreater(inter, Iref.size * 0.9)
186+
172187

173188
class TestNSG(unittest.TestCase):
174189

0 commit comments

Comments
 (0)