Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add more unit tests for index_read and index_write #4068

Closed
wants to merge 1 commit into from
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
113 changes: 113 additions & 0 deletions tests/test_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,14 @@
import sys
import pickle
from multiprocessing.pool import ThreadPool
from common_faiss_tests import get_dataset_2


d = 32
nt = 2000
nb = 1000
nq = 200

class TestIOVariants(unittest.TestCase):

def test_io_error(self):
Expand Down Expand Up @@ -338,6 +344,113 @@ def test_read_vector_transform(self):
os.unlink(fname)


class Test_IO_PQ(unittest.TestCase):
"""
test read and write PQ.
"""
def test_io_pq(self):
xt, xb, xq = get_dataset_2(d, nt, nb, nq)
index = faiss.IndexPQ(d, 4, 4)
index.train(xt)

fd, fname = tempfile.mkstemp()
os.close(fd)

try:
faiss.write_ProductQuantizer(index.pq, fname)

read_pq = faiss.read_ProductQuantizer(fname)

self.assertEqual(index.pq.M, read_pq.M)
self.assertEqual(index.pq.nbits, read_pq.nbits)
self.assertEqual(index.pq.dsub, read_pq.dsub)
self.assertEqual(index.pq.ksub, read_pq.ksub)
np.testing.assert_array_equal(
faiss.vector_to_array(index.pq.centroids),
faiss.vector_to_array(read_pq.centroids)
)

finally:
if os.path.exists(fname):
os.unlink(fname)


class Test_IO_IndexLSH(unittest.TestCase):
"""
test read and write IndexLSH.
"""
def test_io_lsh(self):
xt, xb, xq = get_dataset_2(d, nt, nb, nq)
index_lsh = faiss.IndexLSH(d, 32, True, True)
index_lsh.train(xt)
index_lsh.add(xb)
D, I = index_lsh.search(xq, 10)

fd, fname = tempfile.mkstemp()
os.close(fd)

try:
faiss.write_index(index_lsh, fname)

reader = faiss.BufferedIOReader(
faiss.FileIOReader(fname), 1234)
read_index_lsh = faiss.read_index(reader)
# Delete reader to prevent [WinError 32] The process cannot
# access the file because it is being used by another process
del reader

self.assertEqual(index_lsh.d, read_index_lsh.d)
np.testing.assert_array_equal(
faiss.vector_to_array(index_lsh.codes),
faiss.vector_to_array(read_index_lsh.codes)
)
D_read, I_read = read_index_lsh.search(xq, 10)

np.testing.assert_array_equal(D, D_read)
np.testing.assert_array_equal(I, I_read)

finally:
if os.path.exists(fname):
os.unlink(fname)


class Test_IO_IndexIVFSpectralHash(unittest.TestCase):
"""
test read and write IndexIVFSpectralHash.
"""
def test_io_ivf_spectral_hash(self):
nlist = 1000
xt, xb, xq = get_dataset_2(d, nt, nb, nq)
quantizer = faiss.IndexFlatL2(d)
index = faiss.IndexIVFSpectralHash(quantizer, d, nlist, 8, 1.0)
index.train(xt)
index.add(xb)
D, I = index.search(xq, 10)

fd, fname = tempfile.mkstemp()
os.close(fd)

try:
faiss.write_index(index, fname)

reader = faiss.BufferedIOReader(
faiss.FileIOReader(fname), 1234)
read_index = faiss.read_index(reader)
del reader

self.assertEqual(index.d, read_index.d)
self.assertEqual(index.nbit, read_index.nbit)
self.assertEqual(index.period, read_index.period)
self.assertEqual(index.threshold_type, read_index.threshold_type)

D_read, I_read = read_index.search(xq, 10)
np.testing.assert_array_equal(D, D_read)
np.testing.assert_array_equal(I, I_read)

finally:
if os.path.exists(fname):
os.unlink(fname)

class TestIVFPQRead(unittest.TestCase):
def test_reader(self):
d, n = 32, 1000
Expand Down
Loading