Skip to content

Commit f262011

Browse files
algoriddlefacebook-github-bot
authored andcommitted
fix omp parallelism in fast scan range search
Summary: Fix omp n^2 parallelism Reviewed By: mdouze Differential Revision: D53705601 fbshipit-source-id: 3fcc2368c436185119f6e988ee2867dfd7d8eb07
1 parent 8898eab commit f262011

File tree

3 files changed

+72
-1
lines changed

3 files changed

+72
-1
lines changed

faiss/IndexIVFFastScan.cpp

+1-1
Original file line numberDiff line numberDiff line change
@@ -643,7 +643,7 @@ void IndexIVFFastScan::range_search_dispatch_implem(
643643
{
644644
RangeSearchPartialResult pres(&rres);
645645

646-
#pragma omp parallel for reduction(+ : ndis, nlist_visited)
646+
#pragma omp for reduction(+ : ndis, nlist_visited)
647647
for (int slice = 0; slice < nslice; slice++) {
648648
idx_t i0 = n * slice / nslice;
649649
idx_t i1 = n * (slice + 1) / nslice;

tests/CMakeLists.txt

+1
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ set(FAISS_TEST_SRC
3131
test_code_distance.cpp
3232
test_hnsw.cpp
3333
test_partitioning.cpp
34+
test_fastscan_perf.cpp
3435
)
3536

3637
add_executable(faiss_test ${FAISS_TEST_SRC})

tests/test_fastscan_perf.cpp

+70
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,70 @@
1+
/**
2+
* Copyright (c) Facebook, Inc. and its affiliates.
3+
*
4+
* This source code is licensed under the MIT license found in the
5+
* LICENSE file in the root directory of this source tree.
6+
*/
7+
8+
#include <gtest/gtest.h>
9+
10+
#include <cstddef>
11+
#include <cstdint>
12+
#include <memory>
13+
#include <random>
14+
#include <vector>
15+
16+
#include <omp.h>
17+
18+
#include <faiss/IndexFlat.h>
19+
#include <faiss/IndexIVFPQFastScan.h>
20+
#include <faiss/impl/AuxIndexStructures.h>
21+
22+
TEST(TestFastScan, knnVSrange) {
23+
// small vectors and database
24+
int d = 64;
25+
size_t nb = 1000;
26+
27+
// ivf centroids
28+
size_t nlist = 4;
29+
30+
// more than 2 threads to surface
31+
// problems related to multi-threading
32+
omp_set_num_threads(8);
33+
34+
// random database, also used as queries
35+
std::vector<float> database(nb * d);
36+
std::mt19937 rng;
37+
std::uniform_real_distribution<> distrib;
38+
for (size_t i = 0; i < nb * d; i++) {
39+
database[i] = distrib(rng);
40+
}
41+
42+
// build index
43+
faiss::IndexFlatL2 coarse_quantizer(d);
44+
faiss::IndexIVFPQFastScan index(
45+
&coarse_quantizer, d, nlist, d / 2, 4, faiss::METRIC_L2, 32);
46+
index.pq.cp.niter = 10; // speed up train
47+
index.nprobe = nlist;
48+
index.train(nb, database.data());
49+
index.add(nb, database.data());
50+
51+
std::vector<float> distances(nb);
52+
std::vector<faiss::idx_t> labels(nb);
53+
auto t = std::chrono::high_resolution_clock::now();
54+
index.search(nb, database.data(), 1, distances.data(), labels.data());
55+
auto knn_time = std::chrono::duration_cast<std::chrono::milliseconds>(
56+
std::chrono::high_resolution_clock::now() - t)
57+
.count();
58+
59+
faiss::RangeSearchResult rsr(nb);
60+
t = std::chrono::high_resolution_clock::now();
61+
index.range_search(nb, database.data(), 1.0, &rsr);
62+
auto range_time = std::chrono::duration_cast<std::chrono::milliseconds>(
63+
std::chrono::high_resolution_clock::now() - t)
64+
.count();
65+
66+
// we expect the perf of knn and range search
67+
// to be similar, at least within a factor of 2
68+
ASSERT_LT(range_time, knn_time * 2);
69+
ASSERT_LT(knn_time, range_time * 2);
70+
}

0 commit comments

Comments
 (0)