Skip to content

Commit a0b9590

Browse files
committed
Skip PQ sdc init with new io flag
Add new IO flag, IO_FLAG_PQ_SKIP_SDC_TABLE, so that when reading HNSWPQ from disk, it will skip building the sdc table. sdc table is only used during graph construction, so if this flag is set, the HNSWPQ index will not be updateable. In addition, adds cpp test case verifying functionality and build test util header file to share creation of temporary files amongst tests. Signed-off-by: John Mazanec <jmazane@amazon.com>
1 parent a187394 commit a0b9590

File tree

6 files changed

+113
-31
lines changed

6 files changed

+113
-31
lines changed

faiss/impl/index_read.cpp

+1-1
Original file line numberDiff line numberDiff line change
@@ -962,7 +962,7 @@ Index* read_index(IOReader* f, int io_flags) {
962962
read_HNSW(&idxhnsw->hnsw, f);
963963
idxhnsw->storage = read_index(f, io_flags);
964964
idxhnsw->own_fields = true;
965-
if (h == fourcc("IHNp")) {
965+
if (h == fourcc("IHNp") && !(io_flags & IO_FLAG_PQ_SKIP_SDC_TABLE)) {
966966
dynamic_cast<IndexPQ*>(idxhnsw->storage)->pq.compute_sdc_table();
967967
}
968968
idx = idxhnsw;

faiss/index_io.h

+6
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,12 @@ const int IO_FLAG_ONDISK_SAME_DIR = 4;
5252
const int IO_FLAG_SKIP_IVF_DATA = 8;
5353
// don't initialize precomputed table after loading
5454
const int IO_FLAG_SKIP_PRECOMPUTE_TABLE = 16;
55+
// don't compute the sdc table for PQ-based indices
56+
// this will prevent distances from being computed
57+
// between elements in the index. For indices like HNSWPQ,
58+
// this will prevent graph building because sdc
59+
// computations are required to construct the graph
60+
const int IO_FLAG_PQ_SKIP_SDC_TABLE = 32;
5561
// try to memmap data (useful to load an ArrayInvertedLists as an
5662
// OnDiskInvertedLists)
5763
const int IO_FLAG_MMAP = IO_FLAG_SKIP_IVF_DATA | 0x646f0000;

tests/CMakeLists.txt

+1
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ set(FAISS_TEST_SRC
3232
test_hnsw.cpp
3333
test_partitioning.cpp
3434
test_fastscan_perf.cpp
35+
test_io.cpp
3536
)
3637

3738
add_executable(faiss_test ${FAISS_TEST_SRC})

tests/test_io.cpp

+61
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,61 @@
1+
/**
2+
* Copyright (c) Facebook, Inc. and its affiliates.
3+
*
4+
* This source code is licensed under the MIT license found in the
5+
* LICENSE file in the root directory of this source tree.
6+
*/
7+
8+
#include <gtest/gtest.h>
9+
10+
#include <random>
11+
12+
#include "faiss/Index.h"
13+
#include "faiss/IndexHNSW.h"
14+
#include "faiss/index_factory.h"
15+
#include "faiss/index_io.h"
16+
#include "test_util.h"
17+
18+
pthread_mutex_t temp_file_mutex = PTHREAD_MUTEX_INITIALIZER;
19+
20+
TEST(IO, TestReadHNSWPQ_whenSDCDisabledFlagPassed_thenDisableSDCTable) {
21+
Tempfilename index_filename(&temp_file_mutex, "/tmp/faiss_TestReadHNSWPQ");
22+
int d = 32, n = 256;
23+
std::default_random_engine rng(123);
24+
std::uniform_real_distribution<float> u(0, 100);
25+
std::vector<float> vectors(n * d);
26+
for (size_t i = 0; i < n * d; i++) {
27+
vectors[i] = u(rng);
28+
}
29+
30+
// Build the index and write it to the temp file
31+
{
32+
std::unique_ptr<faiss::Index> index_writer(
33+
faiss::index_factory(d, "HNSW8,PQ4", faiss::METRIC_L2));
34+
index_writer->train(n, vectors.data());
35+
index_writer->add(n, vectors.data());
36+
37+
faiss::write_index(index_writer.get(), index_filename.c_str());
38+
}
39+
40+
// Load index from disk. Confirm that the sdc table is equal to 0 when
41+
// disable sdc is set
42+
{
43+
std::unique_ptr<faiss::IndexHNSWPQ> index_reader_read_write(
44+
dynamic_cast<faiss::IndexHNSWPQ*>(
45+
faiss::read_index(index_filename.c_str())));
46+
std::unique_ptr<faiss::IndexHNSWPQ> index_reader_sdc_disabled(
47+
dynamic_cast<faiss::IndexHNSWPQ*>(faiss::read_index(
48+
index_filename.c_str(),
49+
faiss::IO_FLAG_PQ_SKIP_SDC_TABLE)));
50+
51+
ASSERT_NE(
52+
dynamic_cast<faiss::IndexPQ*>(index_reader_read_write->storage)
53+
->pq.sdc_table.size(),
54+
0);
55+
ASSERT_EQ(
56+
dynamic_cast<faiss::IndexPQ*>(
57+
index_reader_sdc_disabled->storage)
58+
->pq.sdc_table.size(),
59+
0);
60+
}
61+
}

tests/test_merge.cpp

+5-30
Original file line numberDiff line numberDiff line change
@@ -6,47 +6,22 @@
66
*/
77

88
#include <cstdio>
9-
#include <cstdlib>
109
#include <random>
1110

12-
#include <unistd.h>
13-
1411
#include <gtest/gtest.h>
1512

1613
#include <faiss/IVFlib.h>
1714
#include <faiss/IndexFlat.h>
1815
#include <faiss/IndexIVFFlat.h>
19-
#include <faiss/IndexIVFPQ.h>
2016
#include <faiss/IndexPreTransform.h>
2117
#include <faiss/MetaIndexes.h>
2218
#include <faiss/invlists/OnDiskInvertedLists.h>
2319

24-
namespace {
25-
26-
struct Tempfilename {
27-
static pthread_mutex_t mutex;
28-
29-
std::string filename = "/tmp/faiss_tmp_XXXXXX";
30-
31-
Tempfilename() {
32-
pthread_mutex_lock(&mutex);
33-
int fd = mkstemp(&filename[0]);
34-
close(fd);
35-
pthread_mutex_unlock(&mutex);
36-
}
37-
38-
~Tempfilename() {
39-
if (access(filename.c_str(), F_OK)) {
40-
unlink(filename.c_str());
41-
}
42-
}
20+
#include "test_util.h"
4321

44-
const char* c_str() {
45-
return filename.c_str();
46-
}
47-
};
22+
namespace {
4823

49-
pthread_mutex_t Tempfilename::mutex = PTHREAD_MUTEX_INITIALIZER;
24+
pthread_mutex_t temp_file_mutex = PTHREAD_MUTEX_INITIALIZER;
5025

5126
typedef faiss::idx_t idx_t;
5227

@@ -95,7 +70,7 @@ int compare_merged(
9570
std::vector<float> refD(k * nq);
9671

9772
index_shards->search(nq, cd.queries.data(), k, refD.data(), refI.data());
98-
Tempfilename filename;
73+
Tempfilename filename(&temp_file_mutex, "/tmp/faiss_tmp_XXXXXX");
9974

10075
std::vector<idx_t> newI(k * nq);
10176
std::vector<float> newD(k * nq);
@@ -212,7 +187,7 @@ TEST(MERGE, merge_flat_vt) {
212187
TEST(MERGE, merge_flat_ondisk) {
213188
faiss::IndexShards index_shards(d, false, false);
214189
index_shards.own_indices = true;
215-
Tempfilename filename;
190+
Tempfilename filename(&temp_file_mutex, "/tmp/faiss_tmp_XXXXXX");
216191

217192
for (int i = 0; i < nindex; i++) {
218193
auto ivf = new faiss::IndexIVFFlat(&cd.quantizer, d, nlist);

tests/test_util.h

+39
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
/**
2+
* Copyright (c) Facebook, Inc. and its affiliates.
3+
*
4+
* This source code is licensed under the MIT license found in the
5+
* LICENSE file in the root directory of this source tree.
6+
*/
7+
8+
#ifndef FAISS_TEST_UTIL_H
9+
#define FAISS_TEST_UTIL_H
10+
11+
#include <faiss/IndexIVFPQ.h>
12+
#include <unistd.h>
13+
#include <cstdlib>
14+
15+
struct Tempfilename {
16+
pthread_mutex_t* mutex;
17+
std::string filename;
18+
19+
Tempfilename(pthread_mutex_t* mutex, std::string filename) {
20+
this->mutex = mutex;
21+
this->filename = filename;
22+
pthread_mutex_lock(mutex);
23+
int fd = mkstemp(&filename[0]);
24+
close(fd);
25+
pthread_mutex_unlock(mutex);
26+
}
27+
28+
~Tempfilename() {
29+
if (access(filename.c_str(), F_OK)) {
30+
unlink(filename.c_str());
31+
}
32+
}
33+
34+
const char* c_str() {
35+
return filename.c_str();
36+
}
37+
};
38+
39+
#endif // FAISS_TEST_UTIL_H

0 commit comments

Comments
 (0)