12
12
from faiss .contrib import factory_tools
13
13
from faiss .contrib import datasets
14
14
15
+
15
16
class TestFactory (unittest .TestCase ):
16
17
17
18
def test_factory_1 (self ):
@@ -40,7 +41,6 @@ def test_factory_2(self):
40
41
index = faiss .index_factory (12 , "SQ8" )
41
42
assert index .code_size == 12
42
43
43
-
44
44
def test_factory_3 (self ):
45
45
46
46
index = faiss .index_factory (12 , "IVF10,PQ4" )
@@ -73,7 +73,8 @@ def test_factory_HNSW(self):
73
73
def test_factory_HNSW_newstyle (self ):
74
74
index = faiss .index_factory (12 , "HNSW32,Flat" )
75
75
assert index .storage .sa_code_size () == 12 * 4
76
- index = faiss .index_factory (12 , "HNSW32,SQ8" , faiss .METRIC_INNER_PRODUCT )
76
+ index = faiss .index_factory (12 , "HNSW32,SQ8" ,
77
+ faiss .METRIC_INNER_PRODUCT )
77
78
assert index .storage .sa_code_size () == 12
78
79
assert index .metric_type == faiss .METRIC_INNER_PRODUCT
79
80
index = faiss .index_factory (12 , "HNSW,PQ4" )
@@ -131,7 +132,8 @@ def test_factory_fast_scan(self):
131
132
self .assertEqual (index .pq .nbits , 4 )
132
133
index = faiss .index_factory (56 , "PQ28x4fs_64" )
133
134
self .assertEqual (index .bbs , 64 )
134
- index = faiss .index_factory (56 , "IVF50,PQ28x4fs_64" , faiss .METRIC_INNER_PRODUCT )
135
+ index = faiss .index_factory (56 , "IVF50,PQ28x4fs_64" ,
136
+ faiss .METRIC_INNER_PRODUCT )
135
137
self .assertEqual (index .bbs , 64 )
136
138
self .assertEqual (index .nlist , 50 )
137
139
self .assertTrue (index .cp .spherical )
@@ -158,7 +160,6 @@ def test_parenthesis_refine(self):
158
160
self .assertEqual (rf .pq .M , 25 )
159
161
self .assertEqual (rf .pq .nbits , 12 )
160
162
161
-
162
163
def test_parenthesis_refine_2 (self ):
163
164
# Refine applies on the whole index including pre-transforms
164
165
index = faiss .index_factory (50 , "PCA32,IVF32,Flat,Refine(PQ25x12)" )
@@ -264,6 +265,19 @@ def test_idmap2_prefix(self):
264
265
index = faiss .downcast_index (index )
265
266
self .assertEqual (index .__class__ , faiss .IndexIDMap2 )
266
267
268
+ def test_idmap_refine (self ):
269
+ index = faiss .index_factory (8 , "IDMap,PQ4x4fs,RFlat" )
270
+ self .assertEqual (index .__class__ , faiss .IndexIDMap )
271
+ refine_index = faiss .downcast_index (index .index )
272
+ self .assertEqual (refine_index .__class__ , faiss .IndexRefineFlat )
273
+ base_index = faiss .downcast_index (refine_index .base_index )
274
+ self .assertEqual (base_index .__class__ , faiss .IndexPQFastScan )
275
+
276
+ # Index now works with add_with_ids, but not with add
277
+ index .train (np .zeros ((16 , 8 )))
278
+ index .add_with_ids (np .zeros ((16 , 8 )), np .arange (16 ))
279
+ self .assertRaises (RuntimeError , index .add , np .zeros ((16 , 8 )))
280
+
267
281
def test_ivf_hnsw (self ):
268
282
index = faiss .index_factory (123 , "IVF100_HNSW,Flat" )
269
283
quantizer = faiss .downcast_index (index .quantizer )
@@ -337,4 +351,4 @@ def test_replace_vt(self):
337
351
index = faiss .IndexIVFSpectralHash (faiss .IndexFlat (10 ), 10 , 20 , 10 , 1 )
338
352
index .replace_vt (faiss .ITQTransform (10 , 10 ))
339
353
gc .collect ()
340
- index .vt .d_out # this should not crash
354
+ index .vt .d_out # this should not crash
0 commit comments