Skip to content

Commit c67d210

Browse files
mdouzeaalekhpatel07
authored andcommitted
add skip_storage flag to HNSW (facebookresearch#3487)
Summary: Pull Request resolved: facebookresearch#3487 Sometimes it is not useful to serialize the storage index along with a HNSW index. This diff adds a flag that supports skipping the storage of the index. Searchign and adding to the index is not possible until a storage index is added back in. Reviewed By: junjieqi Differential Revision: D57911060 fbshipit-source-id: 5a4ceee4a8f53f6f746df59af3942b813a99c14f
1 parent 673ba81 commit c67d210

File tree

6 files changed

+68
-23
lines changed

6 files changed

+68
-23
lines changed

faiss/IndexHNSW.cpp

+2-3
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,6 @@
55
* LICENSE file in the root directory of this source tree.
66
*/
77

8-
// -*- c++ -*-
9-
108
#include <faiss/IndexHNSW.h>
119

1210
#include <omp.h>
@@ -251,7 +249,8 @@ void hnsw_search(
251249
const SearchParameters* params_in) {
252250
FAISS_THROW_IF_NOT_MSG(
253251
index->storage,
254-
"Please use IndexHNSWFlat (or variants) instead of IndexHNSW directly");
252+
"No storage index, please use IndexHNSWFlat (or variants) "
253+
"instead of IndexHNSW directly");
255254
const SearchParametersHNSW* params = nullptr;
256255
const HNSW& hnsw = index->hnsw;
257256

faiss/impl/index_read.cpp

+6-4
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,6 @@
55
* LICENSE file in the root directory of this source tree.
66
*/
77

8-
// -*- c++ -*-
9-
108
#include <faiss/index_io.h>
119

1210
#include <faiss/impl/io_macros.h>
@@ -531,7 +529,11 @@ Index* read_index(IOReader* f, int io_flags) {
531529
Index* idx = nullptr;
532530
uint32_t h;
533531
READ1(h);
534-
if (h == fourcc("IxFI") || h == fourcc("IxF2") || h == fourcc("IxFl")) {
532+
if (h == fourcc("null")) {
533+
// denotes a missing index, useful for some cases
534+
return nullptr;
535+
} else if (
536+
h == fourcc("IxFI") || h == fourcc("IxF2") || h == fourcc("IxFl")) {
535537
IndexFlat* idxf;
536538
if (h == fourcc("IxFI")) {
537539
idxf = new IndexFlatIP();
@@ -961,7 +963,7 @@ Index* read_index(IOReader* f, int io_flags) {
961963
read_index_header(idxhnsw, f);
962964
read_HNSW(&idxhnsw->hnsw, f);
963965
idxhnsw->storage = read_index(f, io_flags);
964-
idxhnsw->own_fields = true;
966+
idxhnsw->own_fields = idxhnsw->storage != nullptr;
965967
if (h == fourcc("IHNp") && !(io_flags & IO_FLAG_PQ_SKIP_SDC_TABLE)) {
966968
dynamic_cast<IndexPQ*>(idxhnsw->storage)->pq.compute_sdc_table();
967969
}

faiss/impl/index_write.cpp

+16-9
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,6 @@
55
* LICENSE file in the root directory of this source tree.
66
*/
77

8-
// -*- c++ -*-
9-
108
#include <faiss/index_io.h>
119

1210
#include <faiss/impl/io.h>
@@ -390,8 +388,12 @@ static void write_ivf_header(const IndexIVF* ivf, IOWriter* f) {
390388
write_direct_map(&ivf->direct_map, f);
391389
}
392390

393-
void write_index(const Index* idx, IOWriter* f) {
394-
if (const IndexFlat* idxf = dynamic_cast<const IndexFlat*>(idx)) {
391+
void write_index(const Index* idx, IOWriter* f, int io_flags) {
392+
if (idx == nullptr) {
393+
// eg. for a storage component of HNSW that is set to nullptr
394+
uint32_t h = fourcc("null");
395+
WRITE1(h);
396+
} else if (const IndexFlat* idxf = dynamic_cast<const IndexFlat*>(idx)) {
395397
uint32_t h =
396398
fourcc(idxf->metric_type == METRIC_INNER_PRODUCT ? "IxFI"
397399
: idxf->metric_type == METRIC_L2 ? "IxF2"
@@ -765,7 +767,12 @@ void write_index(const Index* idx, IOWriter* f) {
765767
WRITE1(h);
766768
write_index_header(idxhnsw, f);
767769
write_HNSW(&idxhnsw->hnsw, f);
768-
write_index(idxhnsw->storage, f);
770+
if (io_flags & IO_FLAG_SKIP_STORAGE) {
771+
uint32_t n4 = fourcc("null");
772+
WRITE1(n4);
773+
} else {
774+
write_index(idxhnsw->storage, f);
775+
}
769776
} else if (const IndexNSG* idxnsg = dynamic_cast<const IndexNSG*>(idx)) {
770777
uint32_t h = dynamic_cast<const IndexNSGFlat*>(idx) ? fourcc("INSf")
771778
: dynamic_cast<const IndexNSGPQ*>(idx) ? fourcc("INSp")
@@ -841,14 +848,14 @@ void write_index(const Index* idx, IOWriter* f) {
841848
}
842849
}
843850

844-
void write_index(const Index* idx, FILE* f) {
851+
void write_index(const Index* idx, FILE* f, int io_flags) {
845852
FileIOWriter writer(f);
846-
write_index(idx, &writer);
853+
write_index(idx, &writer, io_flags);
847854
}
848855

849-
void write_index(const Index* idx, const char* fname) {
856+
void write_index(const Index* idx, const char* fname, int io_flags) {
850857
FileIOWriter writer(fname);
851-
write_index(idx, &writer);
858+
write_index(idx, &writer, io_flags);
852859
}
853860

854861
void write_VectorTransform(const VectorTransform* vt, const char* fname) {

faiss/index_io.h

+6-5
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,6 @@
55
* LICENSE file in the root directory of this source tree.
66
*/
77

8-
// -*- c++ -*-
9-
108
// I/O code for indexes
119

1210
#ifndef FAISS_INDEX_IO_H
@@ -35,9 +33,12 @@ struct IOReader;
3533
struct IOWriter;
3634
struct InvertedLists;
3735

38-
void write_index(const Index* idx, const char* fname);
39-
void write_index(const Index* idx, FILE* f);
40-
void write_index(const Index* idx, IOWriter* writer);
36+
/// skip the storage for graph-based indexes
37+
const int IO_FLAG_SKIP_STORAGE = 1;
38+
39+
void write_index(const Index* idx, const char* fname, int io_flags = 0);
40+
void write_index(const Index* idx, FILE* f, int io_flags = 0);
41+
void write_index(const Index* idx, IOWriter* writer, int io_flags = 0);
4142

4243
void write_index_binary(const IndexBinary* idx, const char* fname);
4344
void write_index_binary(const IndexBinary* idx, FILE* f);

faiss/python/__init__.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -292,10 +292,10 @@ def range_search_with_parameters(index, x, radius, params=None, output_stats=Fal
292292
###########################################
293293

294294

295-
def serialize_index(index):
295+
def serialize_index(index, io_flags=0):
296296
""" convert an index to a numpy uint8 array """
297297
writer = VectorIOWriter()
298-
write_index(index, writer)
298+
write_index(index, writer, io_flags)
299299
return vector_to_array(writer.data)
300300

301301

tests/test_graph_based.py

+36
Original file line numberDiff line numberDiff line change
@@ -133,6 +133,42 @@ def test_ndis_stats(self):
133133
Dhnsw, Ihnsw = index.search(self.xq, 1)
134134
self.assertGreater(stats.ndis, len(self.xq) * index.hnsw.efSearch)
135135

136+
def test_io_no_storage(self):
137+
d = self.xq.shape[1]
138+
index = faiss.IndexHNSWFlat(d, 16)
139+
index.add(self.xb)
140+
141+
Dref, Iref = index.search(self.xq, 5)
142+
143+
# test writing without storage
144+
index2 = faiss.deserialize_index(
145+
faiss.serialize_index(index, faiss.IO_FLAG_SKIP_STORAGE)
146+
)
147+
self.assertEquals(index2.storage, None)
148+
self.assertRaises(
149+
RuntimeError,
150+
index2.search, self.xb, 1)
151+
152+
# make sure we can store an index with empty storage
153+
index4 = faiss.deserialize_index(
154+
faiss.serialize_index(index2))
155+
156+
# add storage afterwards
157+
index.storage = faiss.clone_index(index.storage)
158+
index.own_fields = True
159+
160+
Dnew, Inew = index.search(self.xq, 5)
161+
np.testing.assert_array_equal(Dnew, Dref)
162+
np.testing.assert_array_equal(Inew, Iref)
163+
164+
if False:
165+
# test reading without storage
166+
# not implemented because it is hard to skip over an index
167+
index3 = faiss.deserialize_index(
168+
faiss.serialize_index(index), faiss.IO_FLAG_SKIP_STORAGE
169+
)
170+
self.assertEquals(index3.storage, None)
171+
136172

137173
class TestNSG(unittest.TestCase):
138174

0 commit comments

Comments
 (0)