Skip to content

Commit 86fa0db

Browse files
alisafayafacebook-github-bot
authored andcommitted
Fix IndexIVFFastScan reconstruct_from_offset method (facebookresearch#4095)
Summary: Resolves issue facebookresearch#4089 - IndexIVFPQFastScan crashes with certain nlist values The `reconstruct_from_offset` method in `IndexIVFFastScan` was incorrectly reconstructing vectors, causing crashes when the `nlist` parameter was not byte-aligned (e.g. 100 instead of 256). The root cause was that the `list_no` (Voronoi cell number) was not being properly encoded into the `code` vector before passing it to the `sa_decode` function. This resulted in invalid `list_no` values being read in `sa_decode`, triggering the assertion failure `'list_no >= 0 && list_no < nlist'` when `nlist` in some cases. This PR fixes the issue with the following changes to `reconstruct_from_offset`: 1. Encode the `list_no` into the beginning of the `code` vector using the existing `encode_listno` method 2. Start the `BitstringWriter` after the coarse code portion of `code` (shifted by `coarse_code_size()` bytes) 3. Remove the residual centroid addition logic, as it is already handled in `sa_decode` After these changes: - Crashes no longer occur for any `nlist` value - Reconstruction is now correct, matching the output of `IndexIVFPQ` Fixes facebookresearch#4089 Please review and let me know if any changes are needed. Thanks! Pull Request resolved: facebookresearch#4095 Reviewed By: asadoughi Differential Revision: D67937160 Pulled By: mdouze fbshipit-source-id: 4705106ba49c01c788b3c75c39c2260615f45764
1 parent b9fe1dc commit 86fa0db

File tree

3 files changed

+40
-12
lines changed

3 files changed

+40
-12
lines changed

faiss/IndexIVFFastScan.cpp

+8-12
Original file line numberDiff line numberDiff line change
@@ -1353,34 +1353,30 @@ void IndexIVFFastScan::reconstruct_from_offset(
13531353
int64_t offset,
13541354
float* recons) const {
13551355
// unpack codes
1356+
size_t coarse_size = coarse_code_size();
1357+
std::vector<uint8_t> code(coarse_size + code_size, 0);
1358+
encode_listno(list_no, code.data());
13561359
InvertedLists::ScopedCodes list_codes(invlists, list_no);
1357-
std::vector<uint8_t> code(code_size, 0);
1358-
BitstringWriter bsw(code.data(), code_size);
1360+
BitstringWriter bsw(code.data() + coarse_size, code_size);
1361+
13591362
for (size_t m = 0; m < M; m++) {
13601363
uint8_t c =
13611364
pq4_get_packed_element(list_codes.get(), bbs, M2, offset, m);
13621365
bsw.write(c, nbits);
13631366
}
1364-
sa_decode(1, code.data(), recons);
13651367

1366-
// add centroid to it
1367-
if (by_residual) {
1368-
std::vector<float> centroid(d);
1369-
quantizer->reconstruct(list_no, centroid.data());
1370-
for (int i = 0; i < d; ++i) {
1371-
recons[i] += centroid[i];
1372-
}
1373-
}
1368+
sa_decode(1, code.data(), recons);
13741369
}
13751370

13761371
void IndexIVFFastScan::reconstruct_orig_invlists() {
13771372
FAISS_THROW_IF_NOT(orig_invlists != nullptr);
13781373
FAISS_THROW_IF_NOT(orig_invlists->list_size(0) == 0);
13791374

1375+
#pragma omp parallel for if (nlist > 100)
13801376
for (size_t list_no = 0; list_no < nlist; list_no++) {
13811377
InvertedLists::ScopedCodes codes(invlists, list_no);
13821378
InvertedLists::ScopedIds ids(invlists, list_no);
1383-
size_t list_size = orig_invlists->list_size(list_no);
1379+
size_t list_size = invlists->list_size(list_no);
13841380
std::vector<uint8_t> code(code_size, 0);
13851381

13861382
for (size_t offset = 0; offset < list_size; offset++) {

faiss/IndexIVFPQFastScan.cpp

+1
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,7 @@ IndexIVFPQFastScan::IndexIVFPQFastScan(const IndexIVFPQ& orig, int bbs)
7676
precomputed_table.nbytes());
7777
}
7878

79+
#pragma omp parallel for if (nlist > 100)
7980
for (size_t i = 0; i < nlist; i++) {
8081
size_t nb = orig.invlists->list_size(i);
8182
size_t nb2 = roundup(nb, bbs);

tests/test_fast_scan_ivf.py

+31
Original file line numberDiff line numberDiff line change
@@ -543,6 +543,37 @@ def test_by_residual_odd_dim(self):
543543
self.do_test(by_residual=True, d=30)
544544

545545

546+
class TestReconstruct(unittest.TestCase):
547+
548+
def do_test(self, by_residual=False):
549+
d = 32
550+
metric = faiss.METRIC_L2
551+
552+
ds = datasets.SyntheticDataset(d, 2000, 5000, 200)
553+
554+
index = faiss.IndexIVFPQFastScan(faiss.IndexFlatL2(d), d, 50, d // 2, 4, metric)
555+
index.by_residual = by_residual
556+
index.make_direct_map(True)
557+
index.train(ds.get_train())
558+
index.add(ds.get_database())
559+
560+
# Test reconstruction
561+
index.reconstruct(123) # single id
562+
index.reconstruct_n(123, 10) # single id
563+
index.reconstruct_batch(np.arange(10))
564+
565+
# Test original list reconstruction
566+
index.orig_invlists = faiss.ArrayInvertedLists(index.nlist, index.code_size)
567+
index.reconstruct_orig_invlists()
568+
assert index.orig_invlists.compute_ntotal() == index.ntotal
569+
570+
def test_no_residual(self):
571+
self.do_test(by_residual=False)
572+
573+
def test_by_residual(self):
574+
self.do_test(by_residual=True)
575+
576+
546577
class TestIsTrained(unittest.TestCase):
547578

548579
def test_issue_2019(self):

0 commit comments

Comments
 (0)