Skip to content

Commit 5a0fdac

Browse files
mdouze authored and facebook-github-bot committed
reduce cross table size (#3012)
Summary: Pull Request resolved: #3012. The cross-tables for codebook construction contained the dot products between all pairs of codebook entries (a total_codebook_size * total_codebook_size table), which is not necessary (and caused OOMs in some cases). This diff computes only the off-diagonal blocks, i.e. the dot products of each codebook's entries with the entries of the previous codebooks. Differential Revision: D48448615. fbshipit-source-id: 28d17ccff22f458f3fb96e2a51ce780c80621f3d
1 parent f40952c commit 5a0fdac

5 files changed

+145
-132
lines changed

faiss/impl/ResidualQuantizer.cpp

+19-12
Original file line numberDiff line numberDiff line change
@@ -493,29 +493,36 @@ void ResidualQuantizer::refine_beam(
493493
*******************************************************************/
494494

495495
void ResidualQuantizer::compute_codebook_tables() {
496-
codebook_cross_products.resize(total_codebook_size * total_codebook_size);
497496
cent_norms.resize(total_codebook_size);
498-
// stricly speaking we could use ssyrk
499-
{
500-
FINTEGER ni = total_codebook_size;
497+
fvec_norms_L2sqr(
498+
cent_norms.data(), codebooks.data(), d, total_codebook_size);
499+
size_t cross_table_size = 0;
500+
for (int m = 0; m < M; m++) {
501+
size_t K = (size_t)1 << nbits[m];
502+
cross_table_size += K * codebook_offsets[m];
503+
}
504+
codebook_cross_products.resize(cross_table_size);
505+
size_t ofs = 0;
506+
for (int m = 1; m < M; m++) {
507+
FINTEGER ki = (size_t)1 << nbits[m];
508+
FINTEGER kk = codebook_offsets[m];
501509
FINTEGER di = d;
502510
float zero = 0, one = 1;
511+
assert(ofs + ki * kk <= cross_table_size);
503512
sgemm_("Transposed",
504513
"Not transposed",
505-
&ni,
506-
&ni,
514+
&ki,
515+
&kk,
507516
&di,
508517
&one,
509-
codebooks.data(),
518+
codebooks.data() + d * kk,
510519
&di,
511520
codebooks.data(),
512521
&di,
513522
&zero,
514-
codebook_cross_products.data(),
515-
&ni);
516-
}
517-
for (size_t i = 0; i < total_codebook_size; i++) {
518-
cent_norms[i] = codebook_cross_products[i + i * total_codebook_size];
523+
codebook_cross_products.data() + ofs,
524+
&ki);
525+
ofs += ki * kk;
519526
}
520527
}
521528

faiss/impl/ResidualQuantizer.h

+3-3
Original file line numberDiff line numberDiff line change
@@ -148,10 +148,10 @@ struct ResidualQuantizer : AdditiveQuantizer {
148148
*/
149149
void compute_codebook_tables();
150150

151-
/// dot products of all codebook vectors with each other
152-
/// size total_codebook_size * total_codebook_size
151+
/// dot products of all codebook entries with the previous codebooks
152+
/// size sum(codebook_offsets[m] * 2^nbits[m], m=0..M-1)
153153
std::vector<float> codebook_cross_products;
154-
/// norms of all vectors
154+
/// norms of all codebook entries (size total_codebook_size)
155155
std::vector<float> cent_norms;
156156
};
157157

faiss/impl/residual_quantizer_encode_steps.cpp

+103-98
Original file line numberDiff line numberDiff line change
@@ -384,11 +384,11 @@ void beam_search_encode_step_tab(
384384
size_t n,
385385
size_t beam_size, // input sizes
386386
const float* codebook_cross_norms, // size K * ldc
387-
size_t ldc, // >= K
388-
const uint64_t* codebook_offsets, // m
389-
const float* query_cp, // size n * ldqc
390-
size_t ldqc, // >= K
391-
const float* cent_norms_i, // size K
387+
size_t ldc,
388+
const uint64_t* codebook_offsets, // m
389+
const float* query_cp, // size n * ldqc
390+
size_t ldqc, // >= K
391+
const float* cent_norms_i, // size K
392392
size_t m,
393393
const int32_t* codes, // n * beam_size * m
394394
const float* distances, // n * beam_size
@@ -412,35 +412,38 @@ void beam_search_encode_step_tab(
412412
cd_common[k] = cent_norms_i[k] - 2 * query_cp_i[k];
413413
}
414414

415-
/*
415+
bool use_baseline_implementation = false;
416+
416417
// This is the baseline implementation. Its primary flaw
417418
// is that it writes far too much data to the temporary buffer
418419
// called dp.
419420
//
420421
// This baseline code is kept intentionally because it is easy to
421422
// understand what an optimized version optimizes exactly.
422423
//
423-
for (size_t b = 0; b < beam_size; b++) {
424-
std::vector<float> dp(K);
425-
426-
for (size_t m1 = 0; m1 < m; m1++) {
427-
size_t c = codes_i[b * m + m1];
428-
const float* cb =
429-
&codebook_cross_norms[(codebook_offsets[m1] + c) * ldc];
430-
fvec_add(K, cb, dp.data(), dp.data());
431-
}
424+
if (use_baseline_implementation) {
425+
for (size_t b = 0; b < beam_size; b++) {
426+
std::vector<float> dp(K);
432427

433-
for (size_t k = 0; k < K; k++) {
434-
cent_distances[b * K + k] =
435-
distances_i[b] + cd_common[k] + 2 * dp[k];
428+
for (size_t m1 = 0; m1 < m; m1++) {
429+
size_t c = codes_i[b * m + m1];
430+
const float* cb =
431+
&codebook_cross_norms
432+
[(codebook_offsets[m1] + c) * ldc];
433+
fvec_add(K, cb, dp.data(), dp.data());
434+
}
435+
436+
for (size_t k = 0; k < K; k++) {
437+
cent_distances[b * K + k] =
438+
distances_i[b] + cd_common[k] + 2 * dp[k];
439+
}
436440
}
437-
}
438-
*/
439441

440-
// An optimized implementation that avoids using a temporary buffer
441-
// and does the accumulation in registers.
442+
} else {
443+
// An optimized implementation that avoids using a temporary buffer
444+
// and does the accumulation in registers.
442445

443-
// Compute a sum of NK AQ codes.
446+
// Compute a sum of NK AQ codes.
444447
#define ACCUM_AND_FINALIZE_TAB(NK) \
445448
case NK: \
446449
for (size_t b = 0; b < beam_size; b++) { \
@@ -457,51 +460,52 @@ void beam_search_encode_step_tab(
457460
} \
458461
break;
459462

460-
// this version contains many switch-case scenarios, but
461-
// they won't affect branch predictor.
462-
switch (m) {
463-
case 0:
464-
// trivial case
465-
for (size_t b = 0; b < beam_size; b++) {
466-
for (size_t k = 0; k < K; k++) {
467-
cent_distances[b * K + k] =
468-
distances_i[b] + cd_common[k];
463+
// this version contains many switch-case scenarios, but
464+
// they won't affect branch predictor.
465+
switch (m) {
466+
case 0:
467+
// trivial case
468+
for (size_t b = 0; b < beam_size; b++) {
469+
for (size_t k = 0; k < K; k++) {
470+
cent_distances[b * K + k] =
471+
distances_i[b] + cd_common[k];
472+
}
469473
}
470-
}
471-
break;
472-
473-
ACCUM_AND_FINALIZE_TAB(1)
474-
ACCUM_AND_FINALIZE_TAB(2)
475-
ACCUM_AND_FINALIZE_TAB(3)
476-
ACCUM_AND_FINALIZE_TAB(4)
477-
ACCUM_AND_FINALIZE_TAB(5)
478-
ACCUM_AND_FINALIZE_TAB(6)
479-
ACCUM_AND_FINALIZE_TAB(7)
480-
481-
default: {
482-
// m >= 8 case.
483-
484-
// A temporary buffer has to be used due to the lack of
485-
// registers. But we'll try to accumulate up to 8 AQ codes in
486-
// registers and issue a single write operation to the buffer,
487-
// while the baseline does no accumulation. So, the number of
488-
// write operations to the temporary buffer is reduced 8x.
489-
490-
// allocate a temporary buffer
491-
std::vector<float> dp(K);
492-
493-
for (size_t b = 0; b < beam_size; b++) {
494-
// Initialize it. Compute a sum of first 8 AQ codes
495-
// because m >= 8 .
496-
accum_and_store_tab<8, 4>(
497-
m,
498-
codebook_cross_norms,
499-
codebook_offsets,
500-
codes_i,
501-
b,
502-
ldc,
503-
K,
504-
dp.data());
474+
break;
475+
476+
ACCUM_AND_FINALIZE_TAB(1)
477+
ACCUM_AND_FINALIZE_TAB(2)
478+
ACCUM_AND_FINALIZE_TAB(3)
479+
ACCUM_AND_FINALIZE_TAB(4)
480+
ACCUM_AND_FINALIZE_TAB(5)
481+
ACCUM_AND_FINALIZE_TAB(6)
482+
ACCUM_AND_FINALIZE_TAB(7)
483+
484+
default: {
485+
// m >= 8 case.
486+
487+
// A temporary buffer has to be used due to the lack of
488+
// registers. But we'll try to accumulate up to 8 AQ codes
489+
// in registers and issue a single write operation to the
490+
// buffer, while the baseline does no accumulation. So, the
491+
// number of write operations to the temporary buffer is
492+
// reduced 8x.
493+
494+
// allocate a temporary buffer
495+
std::vector<float> dp(K);
496+
497+
for (size_t b = 0; b < beam_size; b++) {
498+
// Initialize it. Compute a sum of first 8 AQ codes
499+
// because m >= 8 .
500+
accum_and_store_tab<8, 4>(
501+
m,
502+
codebook_cross_norms,
503+
codebook_offsets,
504+
codes_i,
505+
b,
506+
ldc,
507+
K,
508+
dp.data());
505509

506510
#define ACCUM_AND_ADD_TAB(NK) \
507511
case NK: \
@@ -516,37 +520,37 @@ void beam_search_encode_step_tab(
516520
dp.data()); \
517521
break;
518522

519-
// accumulate up to 8 additional AQ codes into
520-
// a temporary buffer
521-
for (size_t im = 8; im < ((m + 7) / 8) * 8; im += 8) {
522-
size_t m_left = m - im;
523-
if (m_left > 8) {
524-
m_left = 8;
523+
// accumulate up to 8 additional AQ codes into
524+
// a temporary buffer
525+
for (size_t im = 8; im < ((m + 7) / 8) * 8; im += 8) {
526+
size_t m_left = m - im;
527+
if (m_left > 8) {
528+
m_left = 8;
529+
}
530+
531+
switch (m_left) {
532+
ACCUM_AND_ADD_TAB(1)
533+
ACCUM_AND_ADD_TAB(2)
534+
ACCUM_AND_ADD_TAB(3)
535+
ACCUM_AND_ADD_TAB(4)
536+
ACCUM_AND_ADD_TAB(5)
537+
ACCUM_AND_ADD_TAB(6)
538+
ACCUM_AND_ADD_TAB(7)
539+
ACCUM_AND_ADD_TAB(8)
540+
}
525541
}
526542

527-
switch (m_left) {
528-
ACCUM_AND_ADD_TAB(1)
529-
ACCUM_AND_ADD_TAB(2)
530-
ACCUM_AND_ADD_TAB(3)
531-
ACCUM_AND_ADD_TAB(4)
532-
ACCUM_AND_ADD_TAB(5)
533-
ACCUM_AND_ADD_TAB(6)
534-
ACCUM_AND_ADD_TAB(7)
535-
ACCUM_AND_ADD_TAB(8)
543+
// done. finalize the result
544+
for (size_t k = 0; k < K; k++) {
545+
cent_distances[b * K + k] =
546+
distances_i[b] + cd_common[k] + 2 * dp[k];
536547
}
537548
}
538-
539-
// done. finalize the result
540-
for (size_t k = 0; k < K; k++) {
541-
cent_distances[b * K + k] =
542-
distances_i[b] + cd_common[k] + 2 * dp[k];
543-
}
544549
}
545550
}
546-
}
547-
548-
// the optimized implementation ends here
549551

552+
// the optimized implementation ends here
553+
}
550554
using C = CMax<float, int>;
551555
int32_t* new_codes_i = new_codes + i * (m + 1) * new_beam_size;
552556
float* new_distances_i = new_distances + i * new_beam_size;
@@ -784,6 +788,7 @@ void refine_beam_LUT_mp(
784788
// main loop
785789
size_t codes_size = 0;
786790
size_t distances_size = 0;
791+
size_t cross_ofs = 0;
787792
for (int m = 0; m < rq.M; m++) {
788793
int K = 1 << rq.nbits[m];
789794

@@ -792,13 +797,15 @@ void refine_beam_LUT_mp(
792797

793798
codes_size = n * new_beam_size * (m + 1);
794799
distances_size = n * new_beam_size;
795-
800+
FAISS_THROW_IF_NOT(
801+
cross_ofs + rq.codebook_offsets[m] * K <=
802+
rq.codebook_cross_products.size());
796803
beam_search_encode_step_tab(
797804
K,
798805
n,
799806
beam_size,
800-
rq.codebook_cross_products.data() + rq.codebook_offsets[m],
801-
rq.total_codebook_size,
807+
rq.codebook_cross_products.data() + cross_ofs,
808+
K,
802809
rq.codebook_offsets.data(),
803810
query_cp + rq.codebook_offsets[m],
804811
rq.total_codebook_size,
@@ -810,7 +817,7 @@ void refine_beam_LUT_mp(
810817
new_codes_ptr,
811818
new_distances_ptr,
812819
rq.approx_topk_mode);
813-
820+
cross_ofs += rq.codebook_offsets[m] * K;
814821
std::swap(codes_ptr, new_codes_ptr);
815822
std::swap(distances_ptr, new_distances_ptr);
816823

@@ -830,7 +837,6 @@ void refine_beam_LUT_mp(
830837
beam_size);
831838
}
832839
}
833-
834840
if (out_codes) {
835841
memcpy(out_codes, codes_ptr, codes_size * sizeof(*codes_ptr));
836842
}
@@ -903,8 +909,7 @@ void compute_codes_add_centroids_mp_lut1(
903909
pool.distances.resize(rq.max_beam_size * n);
904910

905911
FAISS_THROW_IF_NOT_MSG(
906-
rq.codebook_cross_products.size() ==
907-
rq.total_codebook_size * rq.total_codebook_size,
912+
rq.M == 1 || rq.codebook_cross_products.size() > 0,
908913
"call compute_codebook_tables first");
909914

910915
pool.query_norms.resize(n);

faiss/impl/residual_quantizer_encode_steps.h

+2-2
Original file line numberDiff line numberDiff line change
@@ -70,8 +70,8 @@ void beam_search_encode_step_tab(
7070
size_t K,
7171
size_t n,
7272
size_t beam_size, // input sizes
73-
const float* codebook_cross_norms, // size K * ldc
74-
size_t ldc, // >= K
73+
const float* codebook_cross_norms, // size ldc * K
74+
size_t ldc, // >= codebook_offsets[m]
7575
const uint64_t* codebook_offsets, // m
7676
const float* query_cp, // size n * ldqc
7777
size_t ldqc, // >= K

0 commit comments

Comments
 (0)