@@ -544,35 +544,99 @@ def test_by_residual_odd_dim(self):
544
544
545
545
546
546
class TestReconstruct(unittest.TestCase):
    """Test reconstruct and sa_encode / sa_decode
    (also for a few additive quantizer variants)."""

    def _check_reconstruct(self, index, first=120, n=10, probe=123):
        """Check that the three reconstruction entry points agree.

        Reconstructs vector `probe` on its own, then as part of the id
        range [first, first + n) through both reconstruct_n and
        reconstruct_batch, and asserts that all results are identical.

        Returns the array produced by reconstruct_batch so that callers
        can compare it against other decodings of the same vectors.
        """
        offset = probe - first  # row of `probe` inside the range results
        single = index.reconstruct(probe)

        by_range = index.reconstruct_n(first, n)
        np.testing.assert_array_equal(by_range[offset], single)

        by_batch = index.reconstruct_batch(np.arange(first, first + n))
        np.testing.assert_array_equal(by_batch[offset], single)
        return by_batch

    def do_test(self, by_residual=False):
        """Check reconstruction of an IndexIVFPQFastScan.

        Verifies that the reconstruction methods agree with each other,
        that reconstruct_orig_invlists() yields as many entries as the
        index holds, and that an IndexIVFPQ assembled from the same
        quantizer, PQ and inverted lists reconstructs the same vector.
        """
        d = 32
        nlist = 50  # shared by both indexes so the inverted lists match
        metric = faiss.METRIC_L2

        ds = datasets.SyntheticDataset(d, 250, 200, 10)

        index = faiss.IndexIVFPQFastScan(
            faiss.IndexFlatL2(d), d, nlist, d // 2, 4, metric)
        index.by_residual = by_residual
        index.make_direct_map(True)
        index.train(ds.get_train())
        index.add(ds.get_database())

        # Test reconstruction
        self._check_reconstruct(index)

        # Test original list reconstruction
        index.orig_invlists = faiss.ArrayInvertedLists(
            index.nlist, index.code_size)
        index.reconstruct_orig_invlists()
        self.assertEqual(
            index.orig_invlists.compute_ntotal(), index.ntotal)

        # compare with non fast-scan index
        index2 = faiss.IndexIVFPQ(
            index.quantizer, d, nlist, d // 2, 4, metric)
        index2.by_residual = by_residual
        index2.pq = index.pq
        index2.is_trained = True
        # NOTE(review): the False flag presumably means index2 does not
        # take ownership of the inverted lists -- confirm against the
        # faiss API.
        index2.replace_invlists(index.orig_invlists, False)
        index2.ntotal = index.ntotal
        index2.make_direct_map(True)
        np.testing.assert_array_equal(
            index.reconstruct(123), index2.reconstruct(123))

    def test_no_residual(self):
        self.do_test(by_residual=False)

    def test_by_residual(self):
        self.do_test(by_residual=True)

    def do_test_generic(self, factory_string,
                        by_residual=False, metric=faiss.METRIC_L2):
        """Check reconstruct and sa_encode / sa_decode on a factory index.

        factory_string is passed to faiss.index_factory. by_residual and
        the direct map are only set when the string contains "IVF".
        The index is also round-tripped through serialization to make
        sure the deserialized copy produces the same codes.
        """
        d = 32
        ds = datasets.SyntheticDataset(d, 250, 200, 10)
        index = faiss.index_factory(ds.d, factory_string, metric)
        if "IVF" in factory_string:
            index.by_residual = by_residual
            index.make_direct_map(True)
        index.train(ds.get_train())
        index.add(ds.get_database())

        # Test reconstruction
        v120_10 = self._check_reconstruct(index)

        # decoding freshly encoded vectors must match the reconstruction
        # of the stored ones
        codes = index.sa_encode(ds.get_database()[120:130])
        np.testing.assert_array_equal(index.sa_decode(codes), v120_10)

        # make sure pointers are correct after serialization
        index2 = faiss.deserialize_index(faiss.serialize_index(index))
        codes2 = index2.sa_encode(ds.get_database()[120:130])
        np.testing.assert_array_equal(codes, codes2)

    def test_ivfpq_residual(self):
        self.do_test_generic("IVF20,PQ16x4fs", by_residual=True)

    def test_ivfpq_no_residual(self):
        self.do_test_generic("IVF20,PQ16x4fs", by_residual=False)

    def test_pq(self):
        self.do_test_generic("PQ16x4fs")

    def test_rq(self):
        self.do_test_generic("RQ4x4fs", metric=faiss.METRIC_INNER_PRODUCT)

    def test_ivfprq(self):
        self.do_test_generic(
            "IVF20,PRQ8x2x4fs", by_residual=True,
            metric=faiss.METRIC_INNER_PRODUCT)

    def test_ivfprq_no_residual(self):
        self.do_test_generic(
            "IVF20,PRQ8x2x4fs", by_residual=False,
            metric=faiss.METRIC_INNER_PRODUCT)

    def test_prq(self):
        self.do_test_generic(
            "PRQ8x2x4fs", metric=faiss.METRIC_INNER_PRODUCT)
639
+
576
640
577
641
class TestIsTrained (unittest .TestCase ):
578
642
0 commit comments