Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Skip HNSWPQ sdc init with new io flag #3250

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion faiss/impl/index_read.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -962,7 +962,7 @@ Index* read_index(IOReader* f, int io_flags) {
read_HNSW(&idxhnsw->hnsw, f);
idxhnsw->storage = read_index(f, io_flags);
idxhnsw->own_fields = true;
if (h == fourcc("IHNp")) {
if (h == fourcc("IHNp") && !(io_flags & IO_FLAG_PQ_SKIP_SDC_TABLE)) {
dynamic_cast<IndexPQ*>(idxhnsw->storage)->pq.compute_sdc_table();
}
idx = idxhnsw;
Expand Down
6 changes: 6 additions & 0 deletions faiss/index_io.h
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,12 @@ const int IO_FLAG_ONDISK_SAME_DIR = 4;
const int IO_FLAG_SKIP_IVF_DATA = 8;
// don't initialize precomputed table after loading
const int IO_FLAG_SKIP_PRECOMPUTE_TABLE = 16;
// don't compute the sdc table for PQ-based indices
// this will prevent distances from being computed
// between elements in the index. For indices like HNSWPQ,
// this will prevent graph building because sdc
// computations are required to construct the graph
const int IO_FLAG_PQ_SKIP_SDC_TABLE = 32;
// try to memmap data (useful to load an ArrayInvertedLists as an
// OnDiskInvertedLists)
const int IO_FLAG_MMAP = IO_FLAG_SKIP_IVF_DATA | 0x646f0000;
Expand Down
1 change: 1 addition & 0 deletions tests/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ set(FAISS_TEST_SRC
test_hnsw.cpp
test_partitioning.cpp
test_fastscan_perf.cpp
test_io.cpp
)

add_executable(faiss_test ${FAISS_TEST_SRC})
Expand Down
61 changes: 61 additions & 0 deletions tests/test_io.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
/**
* Copyright (c) Facebook, Inc. and its affiliates.
*
* This source code is licensed under the MIT license found in the
* LICENSE file in the root directory of this source tree.
*/

#include <gtest/gtest.h>

#include <random>

#include "faiss/Index.h"
#include "faiss/IndexHNSW.h"
#include "faiss/index_factory.h"
#include "faiss/index_io.h"
#include "test_util.h"

pthread_mutex_t temp_file_mutex = PTHREAD_MUTEX_INITIALIZER;

TEST(IO, TestReadHNSWPQ_whenSDCDisabledFlagPassed_thenDisableSDCTable) {
Tempfilename index_filename(&temp_file_mutex, "/tmp/faiss_TestReadHNSWPQ");
int d = 32, n = 256;
std::default_random_engine rng(123);
std::uniform_real_distribution<float> u(0, 100);
std::vector<float> vectors(n * d);
for (size_t i = 0; i < n * d; i++) {
vectors[i] = u(rng);
}

// Build the index and write it to the temp file
{
std::unique_ptr<faiss::Index> index_writer(
faiss::index_factory(d, "HNSW8,PQ4", faiss::METRIC_L2));
index_writer->train(n, vectors.data());
index_writer->add(n, vectors.data());

faiss::write_index(index_writer.get(), index_filename.c_str());
}

// Load index from disk. Confirm that the sdc table is equal to 0 when
// disable sdc is set
{
std::unique_ptr<faiss::IndexHNSWPQ> index_reader_read_write(
dynamic_cast<faiss::IndexHNSWPQ*>(
faiss::read_index(index_filename.c_str())));
std::unique_ptr<faiss::IndexHNSWPQ> index_reader_sdc_disabled(
dynamic_cast<faiss::IndexHNSWPQ*>(faiss::read_index(
index_filename.c_str(),
faiss::IO_FLAG_PQ_SKIP_SDC_TABLE)));

ASSERT_NE(
dynamic_cast<faiss::IndexPQ*>(index_reader_read_write->storage)
->pq.sdc_table.size(),
0);
ASSERT_EQ(
dynamic_cast<faiss::IndexPQ*>(
index_reader_sdc_disabled->storage)
->pq.sdc_table.size(),
0);
}
}
35 changes: 5 additions & 30 deletions tests/test_merge.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,47 +6,22 @@
*/

#include <cstdio>
#include <cstdlib>
#include <random>

#include <unistd.h>

#include <gtest/gtest.h>

#include <faiss/IVFlib.h>
#include <faiss/IndexFlat.h>
#include <faiss/IndexIVFFlat.h>
#include <faiss/IndexIVFPQ.h>
#include <faiss/IndexPreTransform.h>
#include <faiss/MetaIndexes.h>
#include <faiss/invlists/OnDiskInvertedLists.h>

namespace {

struct Tempfilename {
static pthread_mutex_t mutex;

std::string filename = "/tmp/faiss_tmp_XXXXXX";

Tempfilename() {
pthread_mutex_lock(&mutex);
int fd = mkstemp(&filename[0]);
close(fd);
pthread_mutex_unlock(&mutex);
}

~Tempfilename() {
if (access(filename.c_str(), F_OK)) {
unlink(filename.c_str());
}
}
#include "test_util.h"

const char* c_str() {
return filename.c_str();
}
};
namespace {

pthread_mutex_t Tempfilename::mutex = PTHREAD_MUTEX_INITIALIZER;
pthread_mutex_t temp_file_mutex = PTHREAD_MUTEX_INITIALIZER;

typedef faiss::idx_t idx_t;

Expand Down Expand Up @@ -95,7 +70,7 @@ int compare_merged(
std::vector<float> refD(k * nq);

index_shards->search(nq, cd.queries.data(), k, refD.data(), refI.data());
Tempfilename filename;
Tempfilename filename(&temp_file_mutex, "/tmp/faiss_tmp_XXXXXX");

std::vector<idx_t> newI(k * nq);
std::vector<float> newD(k * nq);
Expand Down Expand Up @@ -212,7 +187,7 @@ TEST(MERGE, merge_flat_vt) {
TEST(MERGE, merge_flat_ondisk) {
faiss::IndexShards index_shards(d, false, false);
index_shards.own_indices = true;
Tempfilename filename;
Tempfilename filename(&temp_file_mutex, "/tmp/faiss_tmp_XXXXXX");

for (int i = 0; i < nindex; i++) {
auto ivf = new faiss::IndexIVFFlat(&cd.quantizer, d, nlist);
Expand Down
39 changes: 39 additions & 0 deletions tests/test_util.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
/**
* Copyright (c) Facebook, Inc. and its affiliates.
*
* This source code is licensed under the MIT license found in the
* LICENSE file in the root directory of this source tree.
*/

#ifndef FAISS_TEST_UTIL_H
#define FAISS_TEST_UTIL_H

#include <faiss/IndexIVFPQ.h>
#include <unistd.h>
#include <cstdlib>

struct Tempfilename {
pthread_mutex_t* mutex;
std::string filename;

Tempfilename(pthread_mutex_t* mutex, std::string filename) {
this->mutex = mutex;
this->filename = filename;
pthread_mutex_lock(mutex);
int fd = mkstemp(&filename[0]);
close(fd);
pthread_mutex_unlock(mutex);
}

~Tempfilename() {
if (access(filename.c_str(), F_OK)) {
unlink(filename.c_str());
}
}

const char* c_str() {
return filename.c_str();
}
};

#endif // FAISS_TEST_UTIL_H