Skip to content

Commit af70c5b

Browse files
gtwang01facebook-github-bot
authored andcommitted
3893 - Fix index factory order of idmap and refinement (#3928)
Summary: Pull Request resolved: #3928 Fix issue in T203425107 Reviewed By: asadoughi Differential Revision: D64068971 fbshipit-source-id: 56db439793539570a102773ff2c7158d48feb7a9
1 parent c5aed7c commit af70c5b

File tree

2 files changed

+37
-23
lines changed

2 files changed

+37
-23
lines changed

faiss/index_factory.cpp

+18-18
Original file line numberDiff line numberDiff line change
@@ -679,6 +679,24 @@ std::unique_ptr<Index> index_factory_sub(
679679
// for the current match
680680
std::smatch sm;
681681

682+
// IndexIDMap -- it turns out is was used both as a prefix and a suffix, so
683+
// support both
684+
if (re_match(description, "(.+),IDMap2", sm) ||
685+
re_match(description, "IDMap2,(.+)", sm)) {
686+
IndexIDMap2* idmap2 = new IndexIDMap2(
687+
index_factory_sub(d, sm[1].str(), metric).release());
688+
idmap2->own_fields = true;
689+
return std::unique_ptr<Index>(idmap2);
690+
}
691+
692+
if (re_match(description, "(.+),IDMap", sm) ||
693+
re_match(description, "IDMap,(.+)", sm)) {
694+
IndexIDMap* idmap = new IndexIDMap(
695+
index_factory_sub(d, sm[1].str(), metric).release());
696+
idmap->own_fields = true;
697+
return std::unique_ptr<Index>(idmap);
698+
}
699+
682700
// handle refines
683701
if (re_match(description, "(.+),RFlat", sm) ||
684702
re_match(description, "(.+),Refine\\((.+)\\)", sm)) {
@@ -755,24 +773,6 @@ std::unique_ptr<Index> index_factory_sub(
755773
d);
756774
}
757775

758-
// IndexIDMap -- it turns out is was used both as a prefix and a suffix, so
759-
// support both
760-
if (re_match(description, "(.+),IDMap2", sm) ||
761-
re_match(description, "IDMap2,(.+)", sm)) {
762-
IndexIDMap2* idmap2 = new IndexIDMap2(
763-
index_factory_sub(d, sm[1].str(), metric).release());
764-
idmap2->own_fields = true;
765-
return std::unique_ptr<Index>(idmap2);
766-
}
767-
768-
if (re_match(description, "(.+),IDMap", sm) ||
769-
re_match(description, "IDMap,(.+)", sm)) {
770-
IndexIDMap* idmap = new IndexIDMap(
771-
index_factory_sub(d, sm[1].str(), metric).release());
772-
idmap->own_fields = true;
773-
return std::unique_ptr<Index>(idmap);
774-
}
775-
776776
{ // handle basic index types
777777
Index* index = parse_other_indexes(description, d, metric);
778778
if (index) {

tests/test_factory.py

+19-5
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
from faiss.contrib import factory_tools
1313
from faiss.contrib import datasets
1414

15+
1516
class TestFactory(unittest.TestCase):
1617

1718
def test_factory_1(self):
@@ -40,7 +41,6 @@ def test_factory_2(self):
4041
index = faiss.index_factory(12, "SQ8")
4142
assert index.code_size == 12
4243

43-
4444
def test_factory_3(self):
4545

4646
index = faiss.index_factory(12, "IVF10,PQ4")
@@ -73,7 +73,8 @@ def test_factory_HNSW(self):
7373
def test_factory_HNSW_newstyle(self):
7474
index = faiss.index_factory(12, "HNSW32,Flat")
7575
assert index.storage.sa_code_size() == 12 * 4
76-
index = faiss.index_factory(12, "HNSW32,SQ8", faiss.METRIC_INNER_PRODUCT)
76+
index = faiss.index_factory(12, "HNSW32,SQ8",
77+
faiss.METRIC_INNER_PRODUCT)
7778
assert index.storage.sa_code_size() == 12
7879
assert index.metric_type == faiss.METRIC_INNER_PRODUCT
7980
index = faiss.index_factory(12, "HNSW,PQ4")
@@ -131,7 +132,8 @@ def test_factory_fast_scan(self):
131132
self.assertEqual(index.pq.nbits, 4)
132133
index = faiss.index_factory(56, "PQ28x4fs_64")
133134
self.assertEqual(index.bbs, 64)
134-
index = faiss.index_factory(56, "IVF50,PQ28x4fs_64", faiss.METRIC_INNER_PRODUCT)
135+
index = faiss.index_factory(56, "IVF50,PQ28x4fs_64",
136+
faiss.METRIC_INNER_PRODUCT)
135137
self.assertEqual(index.bbs, 64)
136138
self.assertEqual(index.nlist, 50)
137139
self.assertTrue(index.cp.spherical)
@@ -158,7 +160,6 @@ def test_parenthesis_refine(self):
158160
self.assertEqual(rf.pq.M, 25)
159161
self.assertEqual(rf.pq.nbits, 12)
160162

161-
162163
def test_parenthesis_refine_2(self):
163164
# Refine applies on the whole index including pre-transforms
164165
index = faiss.index_factory(50, "PCA32,IVF32,Flat,Refine(PQ25x12)")
@@ -264,6 +265,19 @@ def test_idmap2_prefix(self):
264265
index = faiss.downcast_index(index)
265266
self.assertEqual(index.__class__, faiss.IndexIDMap2)
266267

268+
def test_idmap_refine(self):
269+
index = faiss.index_factory(8, "IDMap,PQ4x4fs,RFlat")
270+
self.assertEqual(index.__class__, faiss.IndexIDMap)
271+
refine_index = faiss.downcast_index(index.index)
272+
self.assertEqual(refine_index.__class__, faiss.IndexRefineFlat)
273+
base_index = faiss.downcast_index(refine_index.base_index)
274+
self.assertEqual(base_index.__class__, faiss.IndexPQFastScan)
275+
276+
# Index now works with add_with_ids, but not with add
277+
index.train(np.zeros((16, 8)))
278+
index.add_with_ids(np.zeros((16, 8)), np.arange(16))
279+
self.assertRaises(RuntimeError, index.add, np.zeros((16, 8)))
280+
267281
def test_ivf_hnsw(self):
268282
index = faiss.index_factory(123, "IVF100_HNSW,Flat")
269283
quantizer = faiss.downcast_index(index.quantizer)
@@ -337,4 +351,4 @@ def test_replace_vt(self):
337351
index = faiss.IndexIVFSpectralHash(faiss.IndexFlat(10), 10, 20, 10, 1)
338352
index.replace_vt(faiss.ITQTransform(10, 10))
339353
gc.collect()
340-
index.vt.d_out # this should not crash
354+
index.vt.d_out # this should not crash

0 commit comments

Comments
 (0)