Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Relax IVFFlatDedup test #3077

Closed
wants to merge 1 commit into from
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Relax IVFFlatDedup test (#3077)
mdouze authored and facebook-github-bot committed Sep 28, 2023
commit 7da034bad625083c63b6237a3b80699c2df7c6f1
36 changes: 27 additions & 9 deletions contrib/evaluation.py
Original file line number Diff line number Diff line change
@@ -226,23 +226,41 @@ def compute_PR_for(q):
# Functions that compare search results with a reference result.
# They are intended for use in tests

def check_ref_knn_with_draws(Dref, Iref, Dnew, Inew):
""" test that knn search results are identical, raise if not """
np.testing.assert_array_almost_equal(Dref, Dnew, decimal=5)
def _cluster_tables_with_tolerance(tab1, tab2, thr):
""" for two tables, cluster them by merging values closer than thr.
Returns the cluster ids for each table element """
tab = np.hstack([tab1, tab2])
tab.sort()
n = len(tab)
diffs = np.ones(n)
diffs[1:] = tab[1:] - tab[:-1]
unique_vals = tab[diffs > thr]
idx1 = np.searchsorted(unique_vals, tab1, side='right') - 1
idx2 = np.searchsorted(unique_vals, tab2, side='right') - 1
return idx1, idx2


def check_ref_knn_with_draws(Dref, Iref, Dnew, Inew, rtol=1e-5):
""" test that knn search results are identical, with possible ties.
Raise if not. """
np.testing.assert_allclose(Dref, Dnew, rtol=rtol)
# here we have to be careful because of draws
testcase = unittest.TestCase() # because it makes nice error messages
for i in range(len(Iref)):
if np.all(Iref[i] == Inew[i]): # easy case
continue
# we can deduce nothing about the latest line
skip_dis = Dref[i, -1]
for dis in np.unique(Dref):
if dis == skip_dis:

# otherwise collect elements per distance
r = rtol * Dref[i].max()

DrefC, DnewC = _cluster_tables_with_tolerance(Dref[i], Dnew[i], r)

for dis in np.unique(DrefC):
if dis == DrefC[-1]:
continue
mask = Dref[i, :] == dis
mask = DrefC == dis
testcase.assertEqual(set(Iref[i, mask]), set(Inew[i, mask]))


def check_ref_range_results(Lref, Dref, Iref,
Lnew, Dnew, Inew):
""" compare range search results wrt. a reference result,
23 changes: 4 additions & 19 deletions tests/test_index_composite.py
Original file line number Diff line number Diff line change
@@ -17,7 +17,7 @@
from common_faiss_tests import get_dataset_2
from faiss.contrib.datasets import SyntheticDataset
from faiss.contrib.inspect_tools import make_LinearTransform_matrix

from faiss.contrib.evaluation import check_ref_knn_with_draws

class TestRemoveFastScan(unittest.TestCase):
def do_test(self, ntotal, removed):
@@ -430,12 +430,6 @@ def test_mmappedIO_pretrans(self):

class TestIVFFlatDedup(unittest.TestCase):

def normalize_res(self, D, I):
dmax = D[-1]
res = [(d, i) for d, i in zip(D, I) if d < dmax]
res.sort()
return res

def test_dedup(self):
d = 10
nb = 1000
@@ -471,10 +465,7 @@ def test_dedup(self):
Dref, Iref = index_ref.search(xq, 20)
Dnew, Inew = index_new.search(xq, 20)

for i in range(nq):
ref = self.normalize_res(Dref[i], Iref[i])
new = self.normalize_res(Dnew[i], Inew[i])
assert ref == new
check_ref_knn_with_draws(Dref, Iref, Dnew, Inew)

# test I/O
fd, tmpfile = tempfile.mkstemp()
@@ -487,10 +478,7 @@ def test_dedup(self):
os.unlink(tmpfile)
Dst, Ist = index_st.search(xq, 20)

for i in range(nq):
new = self.normalize_res(Dnew[i], Inew[i])
st = self.normalize_res(Dst[i], Ist[i])
assert st == new
check_ref_knn_with_draws(Dnew, Inew, Dst, Ist)

# test remove
toremove = np.hstack((np.arange(3, 1000, 5), np.arange(850, 950)))
@@ -501,10 +489,7 @@ def test_dedup(self):
Dref, Iref = index_ref.search(xq, 20)
Dnew, Inew = index_new.search(xq, 20)

for i in range(nq):
ref = self.normalize_res(Dref[i], Iref[i])
new = self.normalize_res(Dnew[i], Inew[i])
assert ref == new
check_ref_knn_with_draws(Dref, Iref, Dnew, Inew)


class TestSerialize(unittest.TestCase):