Skip to content

Commit 89e93e2

Browse files
mdouzefacebook-github-bot
authored andcommitted
more fast-scan reconstruction (facebookresearch#4128)
Summary: Pull Request resolved: facebookresearch#4128 Fix reconstruction code for the fast-scan and IVF fast-scan indices. Reviewed By: asadoughi Differential Revision: D68159014 fbshipit-source-id: fb33416eed994196b34f0f6d3014f4d4859b6039
1 parent 86fa0db commit 89e93e2

10 files changed

+131
-57
lines changed

faiss/IndexFastScan.h

+9
Original file line numberDiff line numberDiff line change
@@ -133,6 +133,15 @@ struct IndexFastScan : Index {
133133

134134
void merge_from(Index& otherIndex, idx_t add_id = 0) override;
135135
void check_compatible_for_merge(const Index& otherIndex) const override;
136+
137+
/// standalone codes interface (but the codes are flattened)
138+
size_t sa_code_size() const override {
139+
return code_size;
140+
}
141+
142+
void sa_encode(idx_t n, const float* x, uint8_t* bytes) const override {
143+
compute_codes(bytes, n, x);
144+
}
136145
};
137146

138147
struct FastScanStats {

faiss/IndexIVF.h

+1-1
Original file line numberDiff line numberDiff line change
@@ -436,7 +436,7 @@ struct IndexIVF : Index, IndexIVFInterface {
436436
size_t sa_code_size() const override;
437437

438438
/** encode a set of vectors
439-
* sa_encode will call encode_vector with include_listno=true
439+
* sa_encode will call encode_vectors with include_listno=true
440440
* @param n nb of vectors to encode
441441
* @param x the vectors to encode
442442
* @param bytes output array for the codes

faiss/IndexIVFAdditiveQuantizerFastScan.cpp

+1-9
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@
77

88
#include <faiss/IndexIVFAdditiveQuantizerFastScan.h>
99

10-
#include <cassert>
1110
#include <cinttypes>
1211
#include <cstdio>
1312

@@ -67,7 +66,7 @@ void IndexIVFAdditiveQuantizerFastScan::init(
6766
} else {
6867
M = aq->M;
6968
}
70-
init_fastscan(M, 4, nlist, metric, bbs);
69+
init_fastscan(aq, M, 4, nlist, metric, bbs);
7170

7271
max_train_points = 1024 * ksub * M;
7372
by_residual = true;
@@ -440,13 +439,6 @@ void IndexIVFAdditiveQuantizerFastScan::compute_LUT(
440439
}
441440
}
442441

443-
void IndexIVFAdditiveQuantizerFastScan::sa_decode(
444-
idx_t n,
445-
const uint8_t* bytes,
446-
float* x) const {
447-
aq->decode(bytes, x, n);
448-
}
449-
450442
/********** IndexIVFLocalSearchQuantizerFastScan ************/
451443
IndexIVFLocalSearchQuantizerFastScan::IndexIVFLocalSearchQuantizerFastScan(
452444
Index* quantizer,

faiss/IndexIVFAdditiveQuantizerFastScan.h

-2
Original file line numberDiff line numberDiff line change
@@ -96,8 +96,6 @@ struct IndexIVFAdditiveQuantizerFastScan : IndexIVFFastScan {
9696
const CoarseQuantized& cq,
9797
AlignedTable<float>& dis_tables,
9898
AlignedTable<float>& biases) const override;
99-
100-
void sa_decode(idx_t n, const uint8_t* bytes, float* x) const override;
10199
};
102100

103101
struct IndexIVFLocalSearchQuantizerFastScan

faiss/IndexIVFFastScan.cpp

+29-1
Original file line numberDiff line numberDiff line change
@@ -55,20 +55,24 @@ IndexIVFFastScan::IndexIVFFastScan() {
5555
}
5656

5757
void IndexIVFFastScan::init_fastscan(
58+
Quantizer* fine_quantizer,
5859
size_t M,
5960
size_t nbits_init,
6061
size_t nlist,
6162
MetricType /* metric */,
6263
int bbs_2) {
6364
FAISS_THROW_IF_NOT(bbs_2 % 32 == 0);
6465
FAISS_THROW_IF_NOT(nbits_init == 4);
66+
FAISS_THROW_IF_NOT(fine_quantizer->d == d);
6567

68+
this->fine_quantizer = fine_quantizer;
6669
this->M = M;
6770
this->nbits = nbits_init;
6871
this->bbs = bbs_2;
6972
ksub = (1 << nbits_init);
7073
M2 = roundup(M, 2);
7174
code_size = M2 / 2;
75+
FAISS_THROW_IF_NOT(code_size == fine_quantizer->code_size);
7276

7377
is_trained = false;
7478
replace_invlists(new BlockInvertedLists(nlist, get_CodePacker()), true);
@@ -1373,7 +1377,7 @@ void IndexIVFFastScan::reconstruct_orig_invlists() {
13731377
FAISS_THROW_IF_NOT(orig_invlists->list_size(0) == 0);
13741378

13751379
#pragma omp parallel for if (nlist > 100)
1376-
for (size_t list_no = 0; list_no < nlist; list_no++) {
1380+
for (idx_t list_no = 0; list_no < nlist; list_no++) {
13771381
InvertedLists::ScopedCodes codes(invlists, list_no);
13781382
InvertedLists::ScopedIds ids(invlists, list_no);
13791383
size_t list_size = invlists->list_size(list_no);
@@ -1396,6 +1400,30 @@ void IndexIVFFastScan::reconstruct_orig_invlists() {
13961400
}
13971401
}
13981402

1403+
void IndexIVFFastScan::sa_decode(idx_t n, const uint8_t* codes, float* x)
1404+
const {
1405+
size_t coarse_size = coarse_code_size();
1406+
1407+
#pragma omp parallel if (n > 1)
1408+
{
1409+
std::vector<float> residual(d);
1410+
1411+
#pragma omp for
1412+
for (idx_t i = 0; i < n; i++) {
1413+
const uint8_t* code = codes + i * (code_size + coarse_size);
1414+
int64_t list_no = decode_listno(code);
1415+
float* xi = x + i * d;
1416+
fine_quantizer->decode(code + coarse_size, xi, 1);
1417+
if (by_residual) {
1418+
quantizer->reconstruct(list_no, residual.data());
1419+
for (size_t j = 0; j < d; j++) {
1420+
xi[j] += residual[j];
1421+
}
1422+
}
1423+
}
1424+
}
1425+
}
1426+
13991427
IVFFastScanStats IVFFastScan_stats;
14001428

14011429
} // namespace faiss

faiss/IndexIVFFastScan.h

+17
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ namespace faiss {
1616

1717
struct NormTableScaler;
1818
struct SIMDResultHandlerToFloat;
19+
struct Quantizer;
1920

2021
/** Fast scan version of IVFPQ and IVFAQ. Works for 4-bit PQ/AQ for now.
2122
*
@@ -59,6 +60,9 @@ struct IndexIVFFastScan : IndexIVF {
5960
int qbs = 0;
6061
size_t qbs2 = 0;
6162

63+
// quantizer used to pack the codes
64+
Quantizer* fine_quantizer = nullptr;
65+
6266
IndexIVFFastScan(
6367
Index* quantizer,
6468
size_t d,
@@ -68,7 +72,9 @@ struct IndexIVFFastScan : IndexIVF {
6872

6973
IndexIVFFastScan();
7074

75+
/// called by implementations
7176
void init_fastscan(
77+
Quantizer* fine_quantizer,
7278
size_t M,
7379
size_t nbits,
7480
size_t nlist,
@@ -225,6 +231,17 @@ struct IndexIVFFastScan : IndexIVF {
225231

226232
// reconstruct orig invlists (for debugging)
227233
void reconstruct_orig_invlists();
234+
235+
/** Decode a set of vectors.
236+
*
237+
* NOTE: The codes in the IndexFastScan object are non-contiguous.
238+
* But this method requires a contiguous representation.
239+
*
240+
* @param n number of vectors
241+
* @param bytes input encoded vectors, size n * code_size
242+
* @param x output vectors, size n * d
243+
*/
244+
void sa_decode(idx_t n, const uint8_t* bytes, float* x) const override;
228245
};
229246

230247
struct IVFFastScanStats {

faiss/IndexIVFPQFastScan.cpp

+4-27
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ IndexIVFPQFastScan::IndexIVFPQFastScan(
4242
: IndexIVFFastScan(quantizer, d, nlist, 0, metric), pq(d, M, nbits) {
4343
by_residual = false; // set to false by default because it's faster
4444

45-
init_fastscan(M, nbits, nlist, metric, bbs);
45+
init_fastscan(&pq, M, nbits, nlist, metric, bbs);
4646
}
4747

4848
IndexIVFPQFastScan::IndexIVFPQFastScan() {
@@ -61,7 +61,8 @@ IndexIVFPQFastScan::IndexIVFPQFastScan(const IndexIVFPQ& orig, int bbs)
6161
pq(orig.pq) {
6262
FAISS_THROW_IF_NOT(orig.pq.nbits == 4);
6363

64-
init_fastscan(orig.pq.M, orig.pq.nbits, orig.nlist, orig.metric_type, bbs);
64+
init_fastscan(
65+
&pq, orig.pq.M, orig.pq.nbits, orig.nlist, orig.metric_type, bbs);
6566

6667
by_residual = orig.by_residual;
6768
ntotal = orig.ntotal;
@@ -77,7 +78,7 @@ IndexIVFPQFastScan::IndexIVFPQFastScan(const IndexIVFPQ& orig, int bbs)
7778
}
7879

7980
#pragma omp parallel for if (nlist > 100)
80-
for (size_t i = 0; i < nlist; i++) {
81+
for (idx_t i = 0; i < nlist; i++) {
8182
size_t nb = orig.invlists->list_size(i);
8283
size_t nb2 = roundup(nb, bbs);
8384
AlignedTable<uint8_t> tmp(nb2 * M2 / 2);
@@ -283,28 +284,4 @@ void IndexIVFPQFastScan::compute_LUT(
283284
}
284285
}
285286

286-
void IndexIVFPQFastScan::sa_decode(idx_t n, const uint8_t* codes, float* x)
287-
const {
288-
size_t coarse_size = coarse_code_size();
289-
290-
#pragma omp parallel if (n > 1)
291-
{
292-
std::vector<float> residual(d);
293-
294-
#pragma omp for
295-
for (idx_t i = 0; i < n; i++) {
296-
const uint8_t* code = codes + i * (code_size + coarse_size);
297-
int64_t list_no = decode_listno(code);
298-
float* xi = x + i * d;
299-
pq.decode(code + coarse_size, xi);
300-
if (by_residual) {
301-
quantizer->reconstruct(list_no, residual.data());
302-
for (size_t j = 0; j < d; j++) {
303-
xi[j] += residual[j];
304-
}
305-
}
306-
}
307-
}
308-
}
309-
310287
} // namespace faiss

faiss/IndexIVFPQFastScan.h

-2
Original file line numberDiff line numberDiff line change
@@ -80,8 +80,6 @@ struct IndexIVFPQFastScan : IndexIVFFastScan {
8080
const CoarseQuantized& cq,
8181
AlignedTable<float>& dis_tables,
8282
AlignedTable<float>& biases) const override;
83-
84-
void sa_decode(idx_t n, const uint8_t* bytes, float* x) const override;
8583
};
8684

8785
} // namespace faiss

faiss/IndexPQFastScan.h

-9
Original file line numberDiff line numberDiff line change
@@ -47,15 +47,6 @@ struct IndexPQFastScan : IndexFastScan {
4747

4848
void compute_float_LUT(float* lut, idx_t n, const float* x) const override;
4949

50-
/** Decode a set of vectors.
51-
*
52-
* NOTE: The codes in the IndexPQFastScan object are non-contiguous.
53-
* But this method requires a contiguous representation.
54-
*
55-
* @param n number of vectors
56-
* @param bytes input encoded vectors, size n * code_size
57-
* @param x output vectors, size n * d
58-
*/
5950
void sa_decode(idx_t n, const uint8_t* bytes, float* x) const override;
6051
};
6152

tests/test_fast_scan_ivf.py

+70-6
Original file line numberDiff line numberDiff line change
@@ -544,35 +544,99 @@ def test_by_residual_odd_dim(self):
544544

545545

546546
class TestReconstruct(unittest.TestCase):
547+
""" test reconstruct and sa_encode / sa_decode
548+
(also for a few additive quantizer variants) """
547549

548550
def do_test(self, by_residual=False):
549551
d = 32
550552
metric = faiss.METRIC_L2
551553

552-
ds = datasets.SyntheticDataset(d, 2000, 5000, 200)
554+
ds = datasets.SyntheticDataset(d, 250, 200, 10)
553555

554-
index = faiss.IndexIVFPQFastScan(faiss.IndexFlatL2(d), d, 50, d // 2, 4, metric)
556+
index = faiss.IndexIVFPQFastScan(
557+
faiss.IndexFlatL2(d), d, 50, d // 2, 4, metric)
555558
index.by_residual = by_residual
556559
index.make_direct_map(True)
557560
index.train(ds.get_train())
558561
index.add(ds.get_database())
559562

560563
# Test reconstruction
561-
index.reconstruct(123) # single id
562-
index.reconstruct_n(123, 10) # single id
563-
index.reconstruct_batch(np.arange(10))
564+
v123 = index.reconstruct(123) # single id
565+
v120_10 = index.reconstruct_n(120, 10)
566+
np.testing.assert_array_equal(v120_10[3], v123)
567+
v120_10 = index.reconstruct_batch(np.arange(120, 130))
568+
np.testing.assert_array_equal(v120_10[3], v123)
564569

565570
# Test original list reconstruction
566-
index.orig_invlists = faiss.ArrayInvertedLists(index.nlist, index.code_size)
571+
index.orig_invlists = faiss.ArrayInvertedLists(
572+
index.nlist, index.code_size)
567573
index.reconstruct_orig_invlists()
568574
assert index.orig_invlists.compute_ntotal() == index.ntotal
569575

576+
# compare with non fast-scan index
577+
index2 = faiss.IndexIVFPQ(
578+
index.quantizer, d, 50, d // 2, 4, metric)
579+
index2.by_residual = by_residual
580+
index2.pq = index.pq
581+
index2.is_trained = True
582+
index2.replace_invlists(index.orig_invlists, False)
583+
index2.ntotal = index.ntotal
584+
index2.make_direct_map(True)
585+
assert np.all(index.reconstruct(123) == index2.reconstruct(123))
586+
570587
def test_no_residual(self):
571588
self.do_test(by_residual=False)
572589

573590
def test_by_residual(self):
574591
self.do_test(by_residual=True)
575592

593+
def do_test_generic(self, factory_string,
594+
by_residual=False, metric=faiss.METRIC_L2):
595+
d = 32
596+
ds = datasets.SyntheticDataset(d, 250, 200, 10)
597+
index = faiss.index_factory(ds.d, factory_string, metric)
598+
if "IVF" in factory_string:
599+
index.by_residual = by_residual
600+
index.make_direct_map(True)
601+
index.train(ds.get_train())
602+
index.add(ds.get_database())
603+
604+
# Test reconstruction
605+
v123 = index.reconstruct(123) # single id
606+
v120_10 = index.reconstruct_n(120, 10)
607+
np.testing.assert_array_equal(v120_10[3], v123)
608+
v120_10 = index.reconstruct_batch(np.arange(120, 130))
609+
np.testing.assert_array_equal(v120_10[3], v123)
610+
codes = index.sa_encode(ds.get_database()[120:130])
611+
np.testing.assert_array_equal(index.sa_decode(codes), v120_10)
612+
613+
# make sure pointers are correct after serialization
614+
index2 = faiss.deserialize_index(faiss.serialize_index(index))
615+
codes2 = index2.sa_encode(ds.get_database()[120:130])
616+
np.testing.assert_array_equal(codes, codes2)
617+
618+
619+
def test_ivfpq_residual(self):
620+
self.do_test_generic("IVF20,PQ16x4fs", by_residual=True)
621+
622+
def test_ivfpq_no_residual(self):
623+
self.do_test_generic("IVF20,PQ16x4fs", by_residual=False)
624+
625+
def test_pq(self):
626+
self.do_test_generic("PQ16x4fs")
627+
628+
def test_rq(self):
629+
self.do_test_generic("RQ4x4fs", metric=faiss.METRIC_INNER_PRODUCT)
630+
631+
def test_ivfprq(self):
632+
self.do_test_generic("IVF20,PRQ8x2x4fs", by_residual=True, metric=faiss.METRIC_INNER_PRODUCT)
633+
634+
def test_ivfprq_no_residual(self):
635+
self.do_test_generic("IVF20,PRQ8x2x4fs", by_residual=False, metric=faiss.METRIC_INNER_PRODUCT)
636+
637+
def test_prq(self):
638+
self.do_test_generic("PRQ8x2x4fs", metric=faiss.METRIC_INNER_PRODUCT)
639+
576640

577641
class TestIsTrained(unittest.TestCase):
578642

0 commit comments

Comments
 (0)