@@ -842,6 +842,71 @@ TEST(TestGpuIndexIVFFlat, LongIVFList) {
842
842
#endif
843
843
}
844
844
845
+ TEST (TestGpuIndexIVFFlat, Reconstruct_n) {
846
+ Options opt;
847
+
848
+ std::vector<float > trainVecs = faiss::gpu::randVecs (opt.numTrain , opt.dim );
849
+ std::vector<float > addVecs = faiss::gpu::randVecs (opt.numAdd , opt.dim );
850
+
851
+ faiss::IndexFlatL2 cpuQuantizer (opt.dim );
852
+ faiss::IndexIVFFlat cpuIndex (
853
+ &cpuQuantizer, opt.dim , opt.numCentroids , faiss::METRIC_L2);
854
+ cpuIndex.nprobe = opt.nprobe ;
855
+ cpuIndex.train (opt.numTrain , trainVecs.data ());
856
+ cpuIndex.add (opt.numAdd , addVecs.data ());
857
+
858
+ faiss::gpu::StandardGpuResources res;
859
+ res.noTempMemory ();
860
+
861
+ faiss::gpu::GpuIndexIVFFlatConfig config;
862
+ config.device = opt.device ;
863
+ config.indicesOptions = faiss::gpu::INDICES_64_BIT;
864
+ config.use_raft = false ;
865
+
866
+ faiss::gpu::GpuIndexIVFFlat gpuIndex (
867
+ &res, opt.dim , opt.numCentroids , faiss::METRIC_L2, config);
868
+ gpuIndex.nprobe = opt.nprobe ;
869
+
870
+ gpuIndex.train (opt.numTrain , trainVecs.data ());
871
+ gpuIndex.add (opt.numAdd , addVecs.data ());
872
+
873
+ std::vector<float > gpuVals (opt.numAdd * opt.dim );
874
+
875
+ gpuIndex.reconstruct_n (0 , gpuIndex.ntotal , gpuVals.data ());
876
+
877
+ std::vector<float > cpuVals (opt.numAdd * opt.dim );
878
+
879
+ cpuIndex.reconstruct_n (0 , cpuIndex.ntotal , cpuVals.data ());
880
+
881
+ EXPECT_EQ (gpuVals, cpuVals);
882
+
883
+ config.indicesOptions = faiss::gpu::INDICES_32_BIT;
884
+
885
+ faiss::gpu::GpuIndexIVFFlat gpuIndex1 (
886
+ &res, opt.dim , opt.numCentroids , faiss::METRIC_L2, config);
887
+ gpuIndex1.nprobe = opt.nprobe ;
888
+
889
+ gpuIndex1.train (opt.numTrain , trainVecs.data ());
890
+ gpuIndex1.add (opt.numAdd , addVecs.data ());
891
+
892
+ gpuIndex1.reconstruct_n (0 , gpuIndex1.ntotal , gpuVals.data ());
893
+
894
+ EXPECT_EQ (gpuVals, cpuVals);
895
+
896
+ config.indicesOptions = faiss::gpu::INDICES_CPU;
897
+
898
+ faiss::gpu::GpuIndexIVFFlat gpuIndex2 (
899
+ &res, opt.dim , opt.numCentroids , faiss::METRIC_L2, config);
900
+ gpuIndex2.nprobe = opt.nprobe ;
901
+
902
+ gpuIndex2.train (opt.numTrain , trainVecs.data ());
903
+ gpuIndex2.add (opt.numAdd , addVecs.data ());
904
+
905
+ gpuIndex2.reconstruct_n (0 , gpuIndex2.ntotal , gpuVals.data ());
906
+
907
+ EXPECT_EQ (gpuVals, cpuVals);
908
+ }
909
+
845
910
int main (int argc, char ** argv) {
846
911
testing::InitGoogleTest (&argc, argv);
847
912
0 commit comments