|
| 1 | +/** |
| 2 | + * Copyright (c) Facebook, Inc. and its affiliates. |
| 3 | + * |
| 4 | + * This source code is licensed under the MIT license found in the |
| 5 | + * LICENSE file in the root directory of this source tree. |
| 6 | + */ |
| 7 | + |
#include <gtest/gtest.h>

#include <chrono>
#include <cstddef>
#include <cstdint>
#include <memory>
#include <random>
#include <vector>

#include <omp.h>

#include <faiss/IndexFlat.h>
#include <faiss/IndexIVFPQFastScan.h>
#include <faiss/impl/AuxIndexStructures.h>
| 21 | + |
| 22 | +TEST(TestFastScan, knnVSrange) { |
| 23 | + // small vectors and database |
| 24 | + int d = 64; |
| 25 | + size_t nb = 1000; |
| 26 | + |
| 27 | + // ivf centroids |
| 28 | + size_t nlist = 4; |
| 29 | + |
| 30 | + // more than 2 threads to surface |
| 31 | + // problems related to multi-threading |
| 32 | + omp_set_num_threads(8); |
| 33 | + |
| 34 | + // random database, also used as queries |
| 35 | + std::vector<float> database(nb * d); |
| 36 | + std::mt19937 rng; |
| 37 | + std::uniform_real_distribution<> distrib; |
| 38 | + for (size_t i = 0; i < nb * d; i++) { |
| 39 | + database[i] = distrib(rng); |
| 40 | + } |
| 41 | + |
| 42 | + // build index |
| 43 | + faiss::IndexFlatL2 coarse_quantizer(d); |
| 44 | + faiss::IndexIVFPQFastScan index( |
| 45 | + &coarse_quantizer, d, nlist, d / 2, 4, faiss::METRIC_L2, 32); |
| 46 | + index.pq.cp.niter = 10; // speed up train |
| 47 | + index.nprobe = nlist; |
| 48 | + index.train(nb, database.data()); |
| 49 | + index.add(nb, database.data()); |
| 50 | + |
| 51 | + std::vector<float> distances(nb); |
| 52 | + std::vector<faiss::idx_t> labels(nb); |
| 53 | + auto t = std::chrono::high_resolution_clock::now(); |
| 54 | + index.search(nb, database.data(), 1, distances.data(), labels.data()); |
| 55 | + auto knn_time = std::chrono::duration_cast<std::chrono::milliseconds>( |
| 56 | + std::chrono::high_resolution_clock::now() - t) |
| 57 | + .count(); |
| 58 | + |
| 59 | + faiss::RangeSearchResult rsr(nb); |
| 60 | + t = std::chrono::high_resolution_clock::now(); |
| 61 | + index.range_search(nb, database.data(), 1.0, &rsr); |
| 62 | + auto range_time = std::chrono::duration_cast<std::chrono::milliseconds>( |
| 63 | + std::chrono::high_resolution_clock::now() - t) |
| 64 | + .count(); |
| 65 | + |
| 66 | + // we expect the perf of knn and range search |
| 67 | + // to be similar, at least within a factor of 2 |
| 68 | + ASSERT_LT(range_time, knn_time * 2); |
| 69 | + ASSERT_LT(knn_time, range_time * 2); |
| 70 | +} |
0 commit comments