Skip to content

Commit 51b6083

Browse files
algoriddlefacebook-github-bot
authored andcommitted
faiss on rocksdb demo (facebookresearch#3216)
Summary: Pull Request resolved: facebookresearch#3216 Reviewed By: mdouze Differential Revision: D53051090 Pulled By: algoriddle fbshipit-source-id: 13a027db36207af9be11a2f181116994b2aff2cb
1 parent c4b91a5 commit 51b6083

File tree

5 files changed

+278
-0
lines changed

5 files changed

+278
-0
lines changed

demos/rocksdb_ivf/CMakeLists.txt

+8
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
cmake_minimum_required(VERSION 3.17 FATAL_ERROR)
project(ROCKSDB_IVF)

# Default to a Debug build, but do not override a build type the user passed
# on the command line (e.g. -DCMAKE_BUILD_TYPE=Release).
if(NOT CMAKE_BUILD_TYPE)
  set(CMAKE_BUILD_TYPE Debug)
endif()

find_package(faiss REQUIRED)
find_package(RocksDB REQUIRED)

add_executable(demo_rocksdb_ivf demo_rocksdb_ivf.cpp RocksDBInvertedLists.cpp)
# Request at least C++17 explicitly instead of relying on the compiler default
# (recent RocksDB headers need it).
target_compile_features(demo_rocksdb_ivf PRIVATE cxx_std_17)
target_link_libraries(demo_rocksdb_ivf faiss RocksDB::rocksdb)

demos/rocksdb_ivf/README.md

+23
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
# Storing Faiss inverted lists in RocksDB
2+
3+
Demo of storing the inverted lists of any IVF index in RocksDB or any similar key-value store which supports the prefix scan operation.
4+
5+
# How to build
6+
7+
We use conda to create the build environment for simplicity. Only tested on Linux x86.
8+
9+
```
10+
conda create -n rocksdb_ivf
11+
conda activate rocksdb_ivf
12+
conda install pytorch::faiss-cpu conda-forge::rocksdb cmake make gxx_linux-64 sysroot_linux-64
13+
cd ~/faiss/demos/rocksdb_ivf
14+
cmake -B build .
15+
make -C build -j$(nproc)
16+
```
17+
18+
# Run the example
19+
20+
```
21+
cd ~/faiss/demos/rocksdb_ivf/build
22+
./demo_rocksdb_ivf test_db
23+
```
+108
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,108 @@
1+
// (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary.
2+
3+
#include "RocksDBInvertedLists.h"
4+
5+
#include <faiss/impl/FaissAssert.h>
6+
7+
using namespace faiss;
8+
9+
namespace faiss_rocksdb {
10+
11+
/// Create a RocksDB iterator and position it for scanning inverted list
/// `list_no`.
///
/// The seek target is the raw bytes of `list_no` (a size_t), which is the
/// prefix add_entries() puts at the start of every key it writes. Seek()
/// therefore leaves the iterator on the first key at or after that prefix;
/// is_available() then tells whether that key really belongs to the list.
///
/// @param db         open database to read from (not owned)
/// @param list_no    inverted list to scan
/// @param code_size  expected size in bytes of every stored code
RocksDBInvertedListsIterator::RocksDBInvertedListsIterator(
        rocksdb::DB* db,
        size_t list_no,
        size_t code_size)
        : it(db->NewIterator(rocksdb::ReadOptions())),
          list_no(list_no),
          code_size(code_size) {
    // The `codes` member is deliberately left empty: get_id_and_codes()
    // returns a pointer straight into the RocksDB value, so sizing a scratch
    // buffer here would only cost one heap allocation per iterator.
    it->Seek(rocksdb::Slice(
            reinterpret_cast<const char*>(&this->list_no), sizeof(size_t)));
}
23+
24+
bool RocksDBInvertedListsIterator::is_available() const {
25+
return it->Valid() &&
26+
it->key().starts_with(rocksdb::Slice(
27+
reinterpret_cast<const char*>(&list_no), sizeof(size_t)));
28+
}
29+
30+
/// Move the underlying RocksDB iterator forward by one key.
/// Whether the new position still belongs to this list is answered by
/// is_available().
void RocksDBInvertedListsIterator::next() {
    it->Next();
}
33+
34+
std::pair<idx_t, const uint8_t*> RocksDBInvertedListsIterator::
35+
get_id_and_codes() {
36+
idx_t id =
37+
*reinterpret_cast<const idx_t*>(&it->key().data()[sizeof(size_t)]);
38+
assert(code_size == it->value().size());
39+
return {id, reinterpret_cast<const uint8_t*>(it->value().data())};
40+
}
41+
42+
/// Open (creating it if missing) the RocksDB database that backs the
/// inverted lists and switch the base class to iterator-based access.
///
/// @param db_directory  path of the RocksDB database directory
/// @param nlist         number of inverted lists
/// @param code_size     size in bytes of one stored code
/// @throws faiss::FaissException if `db_directory` is null or the database
///         cannot be opened
RocksDBInvertedLists::RocksDBInvertedLists(
        const char* db_directory,
        size_t nlist,
        size_t code_size)
        : InvertedLists(nlist, code_size) {
    use_iterator = true;

    FAISS_THROW_IF_NOT_MSG(
            db_directory != nullptr, "db_directory must not be null");

    rocksdb::Options options;
    options.create_if_missing = true;

    // Start from nullptr so that db_ never adopts an indeterminate pointer
    // when Open() fails.
    rocksdb::DB* db = nullptr;
    const rocksdb::Status status =
            rocksdb::DB::Open(options, db_directory, &db);
    db_.reset(db);

    // Report the failure even in NDEBUG builds, where an assert() would be
    // compiled out and later calls would dereference a null database.
    FAISS_THROW_IF_NOT_FMT(
            status.ok(),
            "cannot open RocksDB database at %s: %s",
            db_directory,
            status.ToString().c_str());
}
56+
57+
/// Not available for this storage backend: always throws.
/// Entries are reached through get_iterator() instead.
size_t RocksDBInvertedLists::list_size(size_t /*list_no*/) const {
    FAISS_THROW_MSG("list_size is not supported");
}
60+
61+
/// Not available for this storage backend: always throws.
/// Codes are reached through get_iterator() instead.
const uint8_t* RocksDBInvertedLists::get_codes(size_t /*list_no*/) const {
    FAISS_THROW_MSG("get_codes is not supported");
}
64+
65+
/// Not available for this storage backend: always throws.
/// Ids are reached through get_iterator() instead.
const idx_t* RocksDBInvertedLists::get_ids(size_t /*list_no*/) const {
    FAISS_THROW_MSG("get_ids is not supported");
}
68+
69+
size_t RocksDBInvertedLists::add_entries(
70+
size_t list_no,
71+
size_t n_entry,
72+
const idx_t* ids,
73+
const uint8_t* code) {
74+
rocksdb::WriteOptions wo;
75+
std::vector<char> key(sizeof(size_t) + sizeof(idx_t));
76+
memcpy(key.data(), &list_no, sizeof(size_t));
77+
for (size_t i = 0; i < n_entry; i++) {
78+
memcpy(key.data() + sizeof(size_t), ids + i, sizeof(idx_t));
79+
rocksdb::Status status = db_->Put(
80+
wo,
81+
rocksdb::Slice(key.data(), key.size()),
82+
rocksdb::Slice(
83+
reinterpret_cast<const char*>(code + i * code_size),
84+
code_size));
85+
assert(status.ok());
86+
}
87+
return 0; // ignored
88+
}
89+
90+
void RocksDBInvertedLists::update_entries(
91+
size_t /*list_no*/,
92+
size_t /*offset*/,
93+
size_t /*n_entry*/,
94+
const idx_t* /*ids*/,
95+
const uint8_t* /*code*/) {
96+
FAISS_THROW_MSG("update_entries is not supported");
97+
}
98+
99+
/// Not available for this storage backend: always throws.
void RocksDBInvertedLists::resize(size_t /*list_no*/, size_t /*new_size*/) {
    FAISS_THROW_MSG("resize is not supported");
}
102+
103+
/// Build an iterator over the entries of inverted list `list_no`.
///
/// The iterator is allocated with `new` and handed back as a raw pointer;
/// nothing in this class frees it, so the caller is expected to take
/// ownership.
InvertedListsIterator* RocksDBInvertedLists::get_iterator(
        size_t list_no) const {
    rocksdb::DB* const db = db_.get();
    return new RocksDBInvertedListsIterator(db, list_no, code_size);
}
107+
108+
} // namespace faiss_rocksdb
+58
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,58 @@
1+
// (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary.
2+
3+
#pragma once
4+
5+
#include <faiss/invlists/InvertedLists.h>
6+
7+
#include <rocksdb/db.h>
8+
9+
namespace faiss_rocksdb {
10+
11+
struct RocksDBInvertedListsIterator : faiss::InvertedListsIterator {
12+
RocksDBInvertedListsIterator(
13+
rocksdb::DB* db,
14+
size_t list_no,
15+
size_t code_size);
16+
virtual bool is_available() const override;
17+
virtual void next() override;
18+
virtual std::pair<faiss::idx_t, const uint8_t*> get_id_and_codes() override;
19+
20+
private:
21+
std::unique_ptr<rocksdb::Iterator> it;
22+
size_t list_no;
23+
size_t code_size;
24+
std::vector<uint8_t> codes; // buffer for returning codes in next()
25+
};
26+
27+
struct RocksDBInvertedLists : faiss::InvertedLists {
28+
RocksDBInvertedLists(
29+
const char* db_directory,
30+
size_t nlist,
31+
size_t code_size);
32+
33+
size_t list_size(size_t list_no) const override;
34+
const uint8_t* get_codes(size_t list_no) const override;
35+
const faiss::idx_t* get_ids(size_t list_no) const override;
36+
37+
size_t add_entries(
38+
size_t list_no,
39+
size_t n_entry,
40+
const faiss::idx_t* ids,
41+
const uint8_t* code) override;
42+
43+
void update_entries(
44+
size_t list_no,
45+
size_t offset,
46+
size_t n_entry,
47+
const faiss::idx_t* ids,
48+
const uint8_t* code) override;
49+
50+
void resize(size_t list_no, size_t new_size) override;
51+
52+
faiss::InvertedListsIterator* get_iterator(size_t list_no) const override;
53+
54+
private:
55+
std::unique_ptr<rocksdb::DB> db_;
56+
};
57+
58+
} // namespace faiss_rocksdb
+81
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,81 @@
1+
// (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary.
2+
3+
#include <cstdlib>
#include <exception>
#include <iostream>
#include <memory>
#include <numeric>
#include <vector>

#include "RocksDBInvertedLists.h"

#include <faiss/IndexFlat.h>
#include <faiss/IndexIVFFlat.h>
#include <faiss/impl/AuxIndexStructures.h>
#include <faiss/impl/FaissException.h>
#include <faiss/utils/random.h>
14+
15+
using namespace faiss;
16+
17+
int main(int argc, char* argv[]) {
18+
try {
19+
if (argc != 2) {
20+
std::cerr << "missing db directory argument" << std::endl;
21+
return -1;
22+
}
23+
size_t d = 128;
24+
size_t nlist = 100;
25+
IndexFlatL2 quantizer(d);
26+
IndexIVFFlat index(&quantizer, d, nlist);
27+
faiss_rocksdb::RocksDBInvertedLists ril(
28+
argv[1], nlist, index.code_size);
29+
index.replace_invlists(&ril, false);
30+
31+
idx_t nb = 10000;
32+
std::vector<float> xb(d * nb);
33+
float_rand(xb.data(), d * nb, 12345);
34+
std::vector<idx_t> xids(nb);
35+
std::iota(xids.begin(), xids.end(), 0);
36+
37+
index.train(nb, xb.data());
38+
index.add_with_ids(nb, xb.data(), xids.data());
39+
40+
idx_t nq = 20; // nb;
41+
index.nprobe = 2;
42+
43+
std::cout << "search" << std::endl;
44+
idx_t k = 5;
45+
std::vector<float> distances(nq * k);
46+
std::vector<idx_t> labels(nq * k, -1);
47+
index.search(
48+
nq, xb.data(), k, distances.data(), labels.data(), nullptr);
49+
50+
for (idx_t iq = 0; iq < nq; iq++) {
51+
std::cout << iq << ": ";
52+
for (auto j = 0; j < k; j++) {
53+
std::cout << labels[iq * k + j] << " " << distances[iq * k + j]
54+
<< " | ";
55+
}
56+
std::cout << std::endl;
57+
}
58+
59+
std::cout << std::endl << "range search" << std::endl;
60+
float range = 15.0f;
61+
RangeSearchResult result(nq);
62+
index.range_search(nq, xb.data(), range, &result);
63+
64+
for (idx_t iq = 0; iq < nq; iq++) {
65+
std::cout << iq << ": ";
66+
for (auto j = result.lims[iq]; j < result.lims[iq + 1]; j++) {
67+
std::cout << result.labels[j] << " " << result.distances[j]
68+
<< " | ";
69+
}
70+
std::cout << std::endl;
71+
}
72+
73+
} catch (FaissException& e) {
74+
std::cerr << e.what() << '\n';
75+
} catch (std::exception& e) {
76+
std::cerr << e.what() << '\n';
77+
} catch (...) {
78+
std::cerr << "Unrecognized exception!\n";
79+
}
80+
return 0;
81+
}

0 commit comments

Comments
 (0)