Skip to content

Commit a99dbcd

Browse files
mdouze authored and facebook-github-bot committed
implement ST_norm_from_LUT for the ResidualQuantizer (facebookresearch#3917)
Summary: Pull Request resolved: facebookresearch#3917 The norm computation ST_norm_from_LUT was not implemented in Faiss. See issue facebookresearch#3882 This diff adds an implementation for it. It is probably not very quick. A few precomputed tables for AdditiveQuantizer were moved from ResidualQuantizer. Reviewed By: asadoughi Differential Revision: D63975689 fbshipit-source-id: 1bbe497a66bb3891ae727a1cd2b719479f80a836
1 parent 07a345c commit a99dbcd

8 files changed

+130
-58
lines changed

faiss/IndexAdditiveQuantizer.cpp

+1
Original file line numberDiff line numberDiff line change
@@ -273,6 +273,7 @@ void IndexAdditiveQuantizer::search(
273273
DISPATCH(ST_norm_qint8)
274274
DISPATCH(ST_norm_qint4)
275275
DISPATCH(ST_norm_cqint4)
276+
DISPATCH(ST_norm_from_LUT)
276277
case AdditiveQuantizer::ST_norm_cqint8:
277278
case AdditiveQuantizer::ST_norm_lsq2x4:
278279
case AdditiveQuantizer::ST_norm_rq2x4:

faiss/IndexIVFAdditiveQuantizer.cpp

+1-1
Original file line numberDiff line numberDiff line change
@@ -275,7 +275,7 @@ InvertedListScanner* IndexIVFAdditiveQuantizer::get_InvertedListScanner(
275275
return new AQInvertedListScannerLUT<false, AdditiveQuantizer::st>( \
276276
*this, store_pairs);
277277
A(ST_LUT_nonorm)
278-
// A(ST_norm_from_LUT)
278+
A(ST_norm_from_LUT)
279279
A(ST_norm_float)
280280
A(ST_norm_qint8)
281281
A(ST_norm_qint4)

faiss/impl/AdditiveQuantizer.cpp

+77-8
Original file line numberDiff line numberDiff line change
@@ -152,6 +152,40 @@ void AdditiveQuantizer::train_norm(size_t n, const float* norms) {
152152
}
153153
}
154154

155+
void AdditiveQuantizer::compute_codebook_tables() {
156+
centroid_norms.resize(total_codebook_size);
157+
fvec_norms_L2sqr(
158+
centroid_norms.data(), codebooks.data(), d, total_codebook_size);
159+
size_t cross_table_size = 0;
160+
for (int m = 0; m < M; m++) {
161+
size_t K = (size_t)1 << nbits[m];
162+
cross_table_size += K * codebook_offsets[m];
163+
}
164+
codebook_cross_products.resize(cross_table_size);
165+
size_t ofs = 0;
166+
for (int m = 1; m < M; m++) {
167+
FINTEGER ki = (size_t)1 << nbits[m];
168+
FINTEGER kk = codebook_offsets[m];
169+
FINTEGER di = d;
170+
float zero = 0, one = 1;
171+
assert(ofs + ki * kk <= cross_table_size);
172+
sgemm_("Transposed",
173+
"Not transposed",
174+
&ki,
175+
&kk,
176+
&di,
177+
&one,
178+
codebooks.data() + d * kk,
179+
&di,
180+
codebooks.data(),
181+
&di,
182+
&zero,
183+
codebook_cross_products.data() + ofs,
184+
&ki);
185+
ofs += ki * kk;
186+
}
187+
}
188+
155189
namespace {
156190

157191
// TODO
@@ -471,7 +505,6 @@ namespace {
471505
float accumulate_IPs(
472506
const AdditiveQuantizer& aq,
473507
BitstringReader& bs,
474-
const uint8_t* codes,
475508
const float* LUT) {
476509
float accu = 0;
477510
for (int m = 0; m < aq.M; m++) {
@@ -483,6 +516,29 @@ float accumulate_IPs(
483516
return accu;
484517
}
485518

519+
float compute_norm_from_LUT(const AdditiveQuantizer& aq, BitstringReader& bs) {
520+
float accu = 0;
521+
std::vector<int> idx(aq.M);
522+
const float* c = aq.codebook_cross_products.data();
523+
for (int m = 0; m < aq.M; m++) {
524+
size_t nbit = aq.nbits[m];
525+
int i = bs.read(nbit);
526+
size_t K = 1 << nbit;
527+
idx[m] = i;
528+
529+
accu += aq.centroid_norms[aq.codebook_offsets[m] + i];
530+
531+
for (int l = 0; l < m; l++) {
532+
int j = idx[l];
533+
accu += 2 * c[j * K + i];
534+
c += (1 << aq.nbits[l]) * K;
535+
}
536+
}
537+
// FAISS_THROW_IF_NOT(c == aq.codebook_cross_products.data() +
538+
// aq.codebook_cross_products.size());
539+
return accu;
540+
}
541+
486542
} // anonymous namespace
487543

488544
template <>
@@ -491,7 +547,7 @@ float AdditiveQuantizer::
491547
const uint8_t* codes,
492548
const float* LUT) const {
493549
BitstringReader bs(codes, code_size);
494-
return accumulate_IPs(*this, bs, codes, LUT);
550+
return accumulate_IPs(*this, bs, LUT);
495551
}
496552

497553
template <>
@@ -500,7 +556,7 @@ float AdditiveQuantizer::
500556
const uint8_t* codes,
501557
const float* LUT) const {
502558
BitstringReader bs(codes, code_size);
503-
return -accumulate_IPs(*this, bs, codes, LUT);
559+
return -accumulate_IPs(*this, bs, LUT);
504560
}
505561

506562
template <>
@@ -509,7 +565,7 @@ float AdditiveQuantizer::
509565
const uint8_t* codes,
510566
const float* LUT) const {
511567
BitstringReader bs(codes, code_size);
512-
float accu = accumulate_IPs(*this, bs, codes, LUT);
568+
float accu = accumulate_IPs(*this, bs, LUT);
513569
uint32_t norm_i = bs.read(32);
514570
float norm2;
515571
memcpy(&norm2, &norm_i, 4);
@@ -522,7 +578,7 @@ float AdditiveQuantizer::
522578
const uint8_t* codes,
523579
const float* LUT) const {
524580
BitstringReader bs(codes, code_size);
525-
float accu = accumulate_IPs(*this, bs, codes, LUT);
581+
float accu = accumulate_IPs(*this, bs, LUT);
526582
uint32_t norm_i = bs.read(8);
527583
float norm2 = decode_qcint(norm_i);
528584
return norm2 - 2 * accu;
@@ -534,7 +590,7 @@ float AdditiveQuantizer::
534590
const uint8_t* codes,
535591
const float* LUT) const {
536592
BitstringReader bs(codes, code_size);
537-
float accu = accumulate_IPs(*this, bs, codes, LUT);
593+
float accu = accumulate_IPs(*this, bs, LUT);
538594
uint32_t norm_i = bs.read(4);
539595
float norm2 = decode_qcint(norm_i);
540596
return norm2 - 2 * accu;
@@ -546,7 +602,7 @@ float AdditiveQuantizer::
546602
const uint8_t* codes,
547603
const float* LUT) const {
548604
BitstringReader bs(codes, code_size);
549-
float accu = accumulate_IPs(*this, bs, codes, LUT);
605+
float accu = accumulate_IPs(*this, bs, LUT);
550606
uint32_t norm_i = bs.read(8);
551607
float norm2 = decode_qint8(norm_i, norm_min, norm_max);
552608
return norm2 - 2 * accu;
@@ -558,10 +614,23 @@ float AdditiveQuantizer::
558614
const uint8_t* codes,
559615
const float* LUT) const {
560616
BitstringReader bs(codes, code_size);
561-
float accu = accumulate_IPs(*this, bs, codes, LUT);
617+
float accu = accumulate_IPs(*this, bs, LUT);
562618
uint32_t norm_i = bs.read(4);
563619
float norm2 = decode_qint4(norm_i, norm_min, norm_max);
564620
return norm2 - 2 * accu;
565621
}
566622

623+
template <>
624+
float AdditiveQuantizer::
625+
compute_1_distance_LUT<false, AdditiveQuantizer::ST_norm_from_LUT>(
626+
const uint8_t* codes,
627+
const float* LUT) const {
628+
FAISS_THROW_IF_NOT(codebook_cross_products.size() > 0);
629+
BitstringReader bs(codes, code_size);
630+
float accu = accumulate_IPs(*this, bs, LUT);
631+
BitstringReader bs2(codes, code_size);
632+
float norm2 = compute_norm_from_LUT(*this, bs2);
633+
return norm2 - 2 * accu;
634+
}
635+
567636
} // namespace faiss

faiss/impl/AdditiveQuantizer.h

+15-3
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,8 @@ struct AdditiveQuantizer : Quantizer {
2929
std::vector<float> codebooks; ///< codebooks
3030

3131
// derived values
32+
/// codebook #i is stored in rows codebook_offsets[i]:codebook_offsets[i+1]
33+
/// in the codebooks table of size total_codebook_size by d
3234
std::vector<uint64_t> codebook_offsets;
3335
size_t tot_bits = 0; ///< total number of bits (indexes + norms)
3436
size_t norm_bits = 0; ///< bits allocated for the norms
@@ -38,9 +40,19 @@ struct AdditiveQuantizer : Quantizer {
3840
bool verbose = false; ///< verbose during training?
3941
bool is_trained = false; ///< is trained or not
4042

41-
IndexFlat1D qnorm; ///< store and search norms
42-
std::vector<float> norm_tabs; ///< store norms of codebook entries for 4-bit
43-
///< fastscan search
43+
/// auxiliary data for ST_norm_lsq2x4 and ST_norm_rq2x4
44+
/// store norms of codebook entries for 4-bit fastscan
45+
std::vector<float> norm_tabs;
46+
IndexFlat1D qnorm; ///< store and search norms
47+
48+
void compute_codebook_tables();
49+
50+
/// norms of all codebook entries (size total_codebook_size)
51+
std::vector<float> centroid_norms;
52+
53+
/// dot products of all codebook entries with the previous codebooks
54+
/// size sum(codebook_offsets[m] * 2^nbits[m], m=0..M-1)
55+
std::vector<float> codebook_cross_products;
4456

4557
/// norms and distance matrixes with beam search can get large, so use this
4658
/// to control for the amount of memory that can be allocated

faiss/impl/ResidualQuantizer.cpp

-34
Original file line numberDiff line numberDiff line change
@@ -492,40 +492,6 @@ void ResidualQuantizer::refine_beam(
492492
* Functions using the dot products between codebook entries
493493
*******************************************************************/
494494

495-
void ResidualQuantizer::compute_codebook_tables() {
496-
cent_norms.resize(total_codebook_size);
497-
fvec_norms_L2sqr(
498-
cent_norms.data(), codebooks.data(), d, total_codebook_size);
499-
size_t cross_table_size = 0;
500-
for (int m = 0; m < M; m++) {
501-
size_t K = (size_t)1 << nbits[m];
502-
cross_table_size += K * codebook_offsets[m];
503-
}
504-
codebook_cross_products.resize(cross_table_size);
505-
size_t ofs = 0;
506-
for (int m = 1; m < M; m++) {
507-
FINTEGER ki = (size_t)1 << nbits[m];
508-
FINTEGER kk = codebook_offsets[m];
509-
FINTEGER di = d;
510-
float zero = 0, one = 1;
511-
assert(ofs + ki * kk <= cross_table_size);
512-
sgemm_("Transposed",
513-
"Not transposed",
514-
&ki,
515-
&kk,
516-
&di,
517-
&one,
518-
codebooks.data() + d * kk,
519-
&di,
520-
codebooks.data(),
521-
&di,
522-
&zero,
523-
codebook_cross_products.data() + ofs,
524-
&ki);
525-
ofs += ki * kk;
526-
}
527-
}
528-
529495
void ResidualQuantizer::refine_beam_LUT(
530496
size_t n,
531497
const float* query_norms, // size n

faiss/impl/ResidualQuantizer.h

-10
Original file line numberDiff line numberDiff line change
@@ -143,16 +143,6 @@ struct ResidualQuantizer : AdditiveQuantizer {
143143
* @param beam_size if != -1, override the beam size
144144
*/
145145
size_t memory_per_point(int beam_size = -1) const;
146-
147-
/** Cross products used in codebook tables used for beam_LUT = 1
148-
*/
149-
void compute_codebook_tables();
150-
151-
/// dot products of all codebook entries with the previous codebooks
152-
/// size sum(codebook_offsets[m] * 2^nbits[m], m=0..M-1)
153-
std::vector<float> codebook_cross_products;
154-
/// norms of all codebook entries (size total_codebook_size)
155-
std::vector<float> cent_norms;
156146
};
157147

158148
} // namespace faiss

faiss/impl/residual_quantizer_encode_steps.cpp

+1-1
Original file line numberDiff line numberDiff line change
@@ -809,7 +809,7 @@ void refine_beam_LUT_mp(
809809
rq.codebook_offsets.data(),
810810
query_cp + rq.codebook_offsets[m],
811811
rq.total_codebook_size,
812-
rq.cent_norms.data() + rq.codebook_offsets[m],
812+
rq.centroid_norms.data() + rq.codebook_offsets[m],
813813
m,
814814
codes_ptr,
815815
distances_ptr,

tests/test_residual_quantizer.py

+35-1
Original file line numberDiff line numberDiff line change
@@ -457,6 +457,37 @@ def test_search_decompress(self):
457457
# recalls are {1: 0.05, 10: 0.37, 100: 0.37}
458458
self.assertGreater(recalls[10], 0.35)
459459

460+
def do_exact_search_equiv(self, norm_type):
461+
""" searching with this normalization should yield
462+
exactly the same results as decompression (because the
463+
norms are exact) """
464+
ds = datasets.SyntheticDataset(32, 1000, 1000, 100)
465+
466+
# decompresses by default
467+
ir = faiss.IndexResidualQuantizer(ds.d, 3, 6)
468+
ir.rq.train_type = faiss.ResidualQuantizer.Train_default
469+
ir.train(ds.get_train())
470+
ir.add(ds.get_database())
471+
Dref, Iref = ir.search(ds.get_queries(), 10)
472+
473+
ir2 = faiss.IndexResidualQuantizer(
474+
ds.d, 3, 6, faiss.METRIC_L2, norm_type)
475+
476+
# assumes training is reproducible
477+
ir2.rq.train_type = faiss.ResidualQuantizer.Train_default
478+
ir2.train(ds.get_train())
479+
ir2.add(ds.get_database())
480+
D, I = ir2.search(ds.get_queries(), 10)
481+
482+
np.testing.assert_allclose(D, Dref, atol=1e-5)
483+
np.testing.assert_array_equal(I, Iref)
484+
485+
def test_exact_equiv_norm_float(self):
486+
self.do_exact_search_equiv(faiss.AdditiveQuantizer.ST_norm_float)
487+
488+
def test_exact_equiv_norm_from_LUT(self):
489+
self.do_exact_search_equiv(faiss.AdditiveQuantizer.ST_norm_from_LUT)
490+
460491
def test_reestimate_codebook(self):
461492
ds = datasets.SyntheticDataset(32, 1000, 1000, 100)
462493

@@ -858,6 +889,9 @@ def test_norm_cqint(self):
858889
self.do_test_accuracy(True, faiss.AdditiveQuantizer.ST_norm_cqint8)
859890
self.do_test_accuracy(True, faiss.AdditiveQuantizer.ST_norm_cqint4)
860891

892+
def test_norm_from_LUT(self):
893+
self.do_test_accuracy(True, faiss.AdditiveQuantizer.ST_norm_from_LUT)
894+
861895
def test_factory(self):
862896
index = faiss.index_factory(12, "IVF1024,RQ8x8_Nfloat")
863897
self.assertEqual(index.nlist, 1024)
@@ -1105,7 +1139,7 @@ def test_precomp(self):
11051139
ofs += kk * K
11061140
np.testing.assert_allclose(py_table, cpp_table, atol=1e-5)
11071141

1108-
cent_norms = faiss.vector_to_array(rq.cent_norms)
1142+
cent_norms = faiss.vector_to_array(rq.centroid_norms)
11091143
np.testing.assert_array_almost_equal(
11101144
np.hstack(cent_norms_ref), cent_norms, decimal=5)
11111145

0 commit comments

Comments
 (0)