Skip to content

Commit 1876925

Browse files
Amir Sadoughifacebook-github-bot
Amir Sadoughi
authored andcommitted
Implement METRIC.NaNEuclidean (facebookresearch#3414)
Summary: Pull Request resolved: facebookresearch#3414 facebookresearch#3355 A couple open questions: - Given L2 was squared, I figured I would leave this one as squared as well? - Also, wasn't sure if we wanted to return nan when present == 0 or -1? Reviewed By: mdouze Differential Revision: D57017608 fbshipit-source-id: ba14458b92c8b055f3bf2a871565175935c8333a
1 parent 72571c7 commit 1876925

File tree

4 files changed

+44
-0
lines changed

4 files changed

+44
-0
lines changed

faiss/MetricType.h

+1
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@ enum MetricType {
3333
METRIC_JensenShannon,
3434
METRIC_Jaccard, ///< defined as: sum_i(min(a_i, b_i)) / sum_i(max(a_i, b_i))
3535
///< where a_i, b_i > 0
36+
METRIC_NaNEuclidean,
3637
};
3738

3839
/// all vector indices are this type

faiss/utils/extra_distances-inl.h

+20
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010

1111
#include <faiss/MetricType.h>
1212
#include <faiss/utils/distances.h>
13+
#include <cmath>
1314
#include <type_traits>
1415

1516
namespace faiss {
@@ -130,4 +131,23 @@ inline float VectorDistance<METRIC_Jaccard>::operator()(
130131
return accu_num / accu_den;
131132
}
132133

134+
template <>
135+
inline float VectorDistance<METRIC_NaNEuclidean>::operator()(
136+
const float* x,
137+
const float* y) const {
138+
// https://scikit-learn.org/stable/modules/generated/sklearn.metrics.pairwise.nan_euclidean_distances.html
139+
float accu = 0;
140+
size_t present = 0;
141+
for (size_t i = 0; i < d; i++) {
142+
if (!std::isnan(x[i]) && !std::isnan(y[i])) {
143+
float diff = x[i] - y[i];
144+
accu += diff * diff;
145+
present++;
146+
}
147+
}
148+
if (present == 0) {
149+
return NAN;
150+
}
151+
return float(d) / float(present) * accu;
152+
}
133153
} // namespace faiss

faiss/utils/extra_distances.cpp

+3
Original file line numberDiff line numberDiff line change
@@ -164,6 +164,7 @@ void pairwise_extra_distances(
164164
HANDLE_VAR(JensenShannon);
165165
HANDLE_VAR(Lp);
166166
HANDLE_VAR(Jaccard);
167+
HANDLE_VAR(NaNEuclidean);
167168
#undef HANDLE_VAR
168169
default:
169170
FAISS_THROW_MSG("metric type not implemented");
@@ -195,6 +196,7 @@ void knn_extra_metrics(
195196
HANDLE_VAR(JensenShannon);
196197
HANDLE_VAR(Lp);
197198
HANDLE_VAR(Jaccard);
199+
HANDLE_VAR(NaNEuclidean);
198200
#undef HANDLE_VAR
199201
default:
200202
FAISS_THROW_MSG("metric type not implemented");
@@ -242,6 +244,7 @@ FlatCodesDistanceComputer* get_extra_distance_computer(
242244
HANDLE_VAR(JensenShannon);
243245
HANDLE_VAR(Lp);
244246
HANDLE_VAR(Jaccard);
247+
HANDLE_VAR(NaNEuclidean);
245248
#undef HANDLE_VAR
246249
default:
247250
FAISS_THROW_MSG("metric type not implemented");

tests/test_extra_distances.py

+20
Original file line numberDiff line numberDiff line change
@@ -94,6 +94,26 @@ def test_jaccard(self):
9494
new_dis = faiss.pairwise_distances(xq, yb, faiss.METRIC_Jaccard)
9595
self.assertTrue(np.allclose(ref_dis, new_dis))
9696

97+
def test_nan_euclidean(self):
98+
xq, yb = self.make_example()
99+
ref_dis = np.array([
100+
[scipy.spatial.distance.sqeuclidean(x, y) for y in yb]
101+
for x in xq
102+
])
103+
new_dis = faiss.pairwise_distances(xq, yb, faiss.METRIC_NaNEuclidean)
104+
self.assertTrue(np.allclose(ref_dis, new_dis))
105+
106+
x = [[3, np.nan, np.nan, 6]]
107+
q = [[1, np.nan, np.nan, 5]]
108+
dis = [(4 / 2 * ((3 - 1)**2 + (6 - 5)**2))]
109+
new_dis = faiss.pairwise_distances(x, q, faiss.METRIC_NaNEuclidean)
110+
self.assertTrue(np.allclose(new_dis, dis))
111+
112+
x = [[np.nan] * 4]
113+
q = [[np.nan] * 4]
114+
new_dis = faiss.pairwise_distances(x, q, faiss.METRIC_NaNEuclidean)
115+
self.assertTrue(np.isnan(new_dis[0]))
116+
97117

98118
class TestKNN(unittest.TestCase):
99119
""" test that the knn search gives the same as distance matrix + argmin """

0 commit comments

Comments
 (0)