Skip to content

Commit a3fbf2d

Browse files
mdouze authored and facebook-github-bot committed
Better NaN handling (facebookresearch#2986)
Summary:
Pull Request resolved: facebookresearch#2986

A NaN vector is a vector with at least one NaN (not-a-number) entry. After discussion in the Faiss team we decided that:
- training should throw an exception on NaN vectors
- added NaN vectors should be ignored (never returned)
- searched NaN vectors should return only -1s

This diff implements this for a few common index types and adds relevant tests.

Reviewed By: algoriddle
Differential Revision: D48031390
fbshipit-source-id: 99e7786582e91950e3a53c1d8bcffdd00b6afd24
1 parent a4ddb18 commit a3fbf2d

File tree

5 files changed

+215
-108
lines changed

5 files changed

+215
-108
lines changed

faiss/impl/HNSW.cpp

+14 -11
Original file line number | Diff line number | Diff line change
@@ -9,6 +9,7 @@
99

1010
#include <faiss/impl/HNSW.h>
1111

12+
#include <cmath>
1213
#include <string>
1314

1415
#include <faiss/impl/AuxIndexStructures.h>
@@ -542,12 +543,11 @@ int search_from_candidates(
542543
for (int i = 0; i < candidates.size(); i++) {
543544
idx_t v1 = candidates.ids[i];
544545
float d = candidates.dis[i];
545-
FAISS_ASSERT(v1 >= 0);
546+
assert(v1 >= 0);
546547
if (!sel || sel->is_member(v1)) {
547-
if (nres < k) {
548-
faiss::maxheap_push(++nres, D, I, d, v1);
549-
} else if (d < D[0]) {
550-
faiss::maxheap_replace_top(nres, D, I, d, v1);
548+
if (d < D[0]) {
549+
faiss::maxheap_replace_top(k, D, I, d, v1);
550+
nres++;
551551
}
552552
}
553553
vt.set(v1);
@@ -612,10 +612,9 @@ int search_from_candidates(
612612

613613
auto add_to_heap = [&](const size_t idx, const float dis) {
614614
if (!sel || sel->is_member(idx)) {
615-
if (nres < k) {
616-
faiss::maxheap_push(++nres, D, I, dis, idx);
617-
} else if (dis < D[0]) {
618-
faiss::maxheap_replace_top(nres, D, I, dis, idx);
615+
if (dis < D[0]) {
616+
faiss::maxheap_replace_top(k, D, I, dis, idx);
617+
nres++;
619618
}
620619
}
621620
candidates.push(idx, dis);
@@ -668,7 +667,7 @@ int search_from_candidates(
668667
stats.n3 += ndis;
669668
}
670669

671-
return nres;
670+
return std::min(nres, k);
672671
}
673672

674673
std::priority_queue<HNSW::Node> search_from_candidate_unbounded(
@@ -816,6 +815,11 @@ HNSWStats HNSW::search(
816815
// greedy search on upper levels
817816
storage_idx_t nearest = entry_point;
818817
float d_nearest = qdis(nearest);
818+
if (!std::isfinite(d_nearest)) {
819+
// means either the query or the entry point are NaN: in
820+
// both cases we can only return -1 as a result
821+
return stats;
822+
}
819823

820824
for (int level = max_level; level >= 1; level--) {
821825
greedy_update_nearest(*this, qdis, level, nearest, d_nearest);
@@ -826,7 +830,6 @@ HNSWStats HNSW::search(
826830
MinimaxHeap candidates(ef);
827831

828832
candidates.push(nearest, d_nearest);
829-
830833
search_from_candidates(
831834
*this, qdis, k, I, D, candidates, vt, stats, 0, 0, params);
832835
} else {

faiss/impl/ResultHandler.h

+4 -3
Original file line number | Diff line number | Diff line change
@@ -445,8 +445,8 @@ struct SingleBestResultHandler {
445445
/// begin results for query # i
446446
void begin(const size_t current_idx) {
447447
this->current_idx = current_idx;
448-
min_dis = HUGE_VALF;
449-
min_idx = 0;
448+
min_dis = C::neutral();
449+
min_idx = -1;
450450
}
451451

452452
/// add one result for query i
@@ -472,7 +472,8 @@ struct SingleBestResultHandler {
472472
this->i1 = i1;
473473

474474
for (size_t i = i0; i < i1; i++) {
475-
this->dis_tab[i] = HUGE_VALF;
475+
this->dis_tab[i] = C::neutral();
476+
this->ids_tab[i] = -1;
476477
}
477478
}
478479

faiss/impl/ScalarQuantizer.cpp

+5
Original file line number | Diff line number | Diff line change
@@ -1075,6 +1075,11 @@ void ScalarQuantizer::set_derived_sizes() {
10751075
}
10761076

10771077
void ScalarQuantizer::train(size_t n, const float* x) {
1078+
for (size_t i = 0; i < n * d; i++) {
1079+
FAISS_THROW_IF_NOT_MSG(
1080+
std::isfinite(x[i]), "training data contains NaN or Inf");
1081+
}
1082+
10781083
int bit_per_dim = qtype == QT_4bit_uniform ? 4
10791084
: qtype == QT_4bit ? 4
10801085
: qtype == QT_6bit ? 6

tests/test_error_reporting.py

+182
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,182 @@
1+
# Copyright (c) Facebook, Inc. and its affiliates.
2+
#
3+
# This source code is licensed under the MIT license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
6+
"""This script tests a few failure cases of Faiss and whether they are handled
7+
properly."""
8+
9+
import numpy as np
10+
import unittest
11+
import faiss
12+
13+
from common_faiss_tests import get_dataset_2
14+
from faiss.contrib.datasets import SyntheticDataset
15+
16+
17+
class TestValidIndexParams(unittest.TestCase):
18+
19+
def test_IndexIVFPQ(self):
20+
d = 32
21+
nb = 1000
22+
nt = 1500
23+
nq = 200
24+
25+
(xt, xb, xq) = get_dataset_2(d, nt, nb, nq)
26+
27+
coarse_quantizer = faiss.IndexFlatL2(d)
28+
index = faiss.IndexIVFPQ(coarse_quantizer, d, 32, 8, 8)
29+
index.cp.min_points_per_centroid = 5 # quiet warning
30+
index.train(xt)
31+
index.add(xb)
32+
33+
# invalid nprobe
34+
index.nprobe = 0
35+
k = 10
36+
self.assertRaises(RuntimeError, index.search, xq, k)
37+
38+
# invalid k
39+
index.nprobe = 4
40+
k = -10
41+
self.assertRaises(AssertionError, index.search, xq, k)
42+
43+
# valid params
44+
index.nprobe = 4
45+
k = 10
46+
D, nns = index.search(xq, k)
47+
48+
self.assertEqual(D.shape[0], nq)
49+
self.assertEqual(D.shape[1], k)
50+
51+
def test_IndexFlat(self):
52+
d = 32
53+
nb = 1000
54+
nt = 0
55+
nq = 200
56+
57+
(xt, xb, xq) = get_dataset_2(d, nt, nb, nq)
58+
index = faiss.IndexFlat(d, faiss.METRIC_L2)
59+
60+
index.add(xb)
61+
62+
# invalid k
63+
k = -5
64+
self.assertRaises(AssertionError, index.search, xq, k)
65+
66+
# valid k
67+
k = 5
68+
D, I = index.search(xq, k)
69+
70+
self.assertEqual(D.shape[0], nq)
71+
self.assertEqual(D.shape[1], k)
72+
73+
74+
class TestReconsException(unittest.TestCase):
75+
76+
def test_recons_exception(self):
77+
78+
d = 64 # dimension
79+
nb = 1000
80+
rs = np.random.RandomState(1234)
81+
xb = rs.rand(nb, d).astype('float32')
82+
nlist = 10
83+
quantizer = faiss.IndexFlatL2(d) # the other index
84+
index = faiss.IndexIVFFlat(quantizer, d, nlist)
85+
index.train(xb)
86+
index.add(xb)
87+
index.make_direct_map()
88+
89+
index.reconstruct(9)
90+
91+
self.assertRaises(
92+
RuntimeError,
93+
index.reconstruct, 100001
94+
)
95+
96+
def test_reconstuct_after_add(self):
97+
index = faiss.index_factory(10, 'IVF5,SQfp16')
98+
index.train(faiss.randn((100, 10), 123))
99+
index.add(faiss.randn((100, 10), 345))
100+
index.make_direct_map()
101+
index.add(faiss.randn((100, 10), 678))
102+
103+
# should not raise an exception
104+
index.reconstruct(5)
105+
print(index.ntotal)
106+
index.reconstruct(150)
107+
108+
109+
class TestNaN(unittest.TestCase):
110+
""" NaN values handling is transparent: they don't produce results
111+
but should not crash. The tests below cover a few common index types.
112+
"""
113+
114+
def do_test_train(self, factory_string):
115+
""" NaN and Inf should raise an exception at train time """
116+
ds = SyntheticDataset(32, 200, 20, 10)
117+
index = faiss.index_factory(ds.d, factory_string)
118+
# try to train with NaNs
119+
xt = ds.get_train().copy()
120+
xt[:, ::4] = np.nan
121+
self.assertRaises(RuntimeError, index.train, xt)
122+
123+
def test_train_IVFSQ(self):
124+
self.do_test_train("IVF10,SQ8")
125+
126+
def test_train_IVFPQ(self):
127+
self.do_test_train("IVF10,PQ4np")
128+
129+
def test_train_SQ(self):
130+
self.do_test_train("SQ8")
131+
132+
def do_test_add(self, factory_string):
133+
""" stored NaNs should not be returned at search time """
134+
ds = SyntheticDataset(32, 200, 20, 10)
135+
index = faiss.index_factory(ds.d, factory_string)
136+
if not index.is_trained:
137+
index.train(ds.get_train())
138+
xb = ds.get_database()
139+
xb[12, 3] = np.nan
140+
index.add(xb)
141+
D, I = index.search(ds.get_queries(), 20)
142+
self.assertTrue(np.where(I == 12)[0].size == 0)
143+
144+
def test_add_Flat(self):
145+
self.do_test_add("Flat")
146+
147+
def test_add_HNSW(self):
148+
self.do_test_add("HNSW32,Flat")
149+
150+
def xx_test_add_SQ8(self):
151+
# this is expected to fail because:
152+
# in ASAN mode, the float NaN -> int conversion crashes
153+
# in opt mode it works but there is no way to encode the NaN,
154+
# so the value cannot be ignored.
155+
self.do_test_add("SQ8")
156+
157+
def test_add_IVFFlat(self):
158+
self.do_test_add("IVF10,Flat")
159+
160+
def do_test_search(self, factory_string):
161+
""" NaN query vectors should return -1 """
162+
ds = SyntheticDataset(32, 200, 20, 10)
163+
index = faiss.index_factory(ds.d, factory_string)
164+
if not index.is_trained:
165+
index.train(ds.get_train())
166+
index.add(ds.get_database())
167+
xq = ds.get_queries()
168+
xq[7, 3] = np.nan
169+
D, I = index.search(ds.get_queries(), 20)
170+
self.assertTrue(np.all(I[7] == -1))
171+
172+
def test_search_Flat(self):
173+
self.do_test_search("Flat")
174+
175+
def test_search_HNSW(self):
176+
self.do_test_search("HNSW32,Flat")
177+
178+
def test_search_IVFFlat(self):
179+
self.do_test_search("IVF10,Flat")
180+
181+
def test_search_SQ(self):
182+
self.do_test_search("SQ8")

0 commit comments

Comments
 (0)