@@ -384,11 +384,11 @@ void beam_search_encode_step_tab(
         size_t n,
         size_t beam_size,                  // input sizes
         const float* codebook_cross_norms, // size K * ldc
-        size_t ldc,                        // >= K
-        const uint64_t* codebook_offsets,  // m
-        const float* query_cp,             // size n * ldqc
-        size_t ldqc,                       // >= K
-        const float* cent_norms_i,         // size K
+        size_t ldc,
+        const uint64_t* codebook_offsets, // m
+        const float* query_cp,            // size n * ldqc
+        size_t ldqc,                      // >= K
+        const float* cent_norms_i,        // size K
         size_t m,
         const int32_t* codes,   // n * beam_size * m
         const float* distances, // n * beam_size
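Aside from whitespace, the change in this hunk is dropping the `// >= K` note on `ldc`: with the call-site change later in this diff, the cross-norm block handed to this function is addressed with a row stride of exactly `K` rather than the full `total_codebook_size`. A minimal sketch of the addressing under the two layouts (helper name and layout comments are illustrative, not faiss API):

    // Sketch (not part of the diff): how beam_search_encode_step_tab reads a
    // cross-norm entry. `base` is the codebook_cross_norms argument.
    //  - old call site: base = full table + codebook_offsets[m],
    //    ldc = total_codebook_size (full square matrix)
    //  - new call site: base = packed table + cross_ofs, ldc = K
    //    (one [codebook_offsets[m], K] block per step m)
    // The in-function indexing is identical in both cases.
    #include <cstddef>
    #include <cstdint>

    inline float cross_norm_at(
            const float* base,
            const uint64_t* codebook_offsets,
            size_t ldc, // row stride
            size_t m1,  // index of an already-encoded codebook, m1 < m
            size_t c,   // code chosen from codebook m1
            size_t k) { // candidate entry of the current codebook
        return base[(codebook_offsets[m1] + c) * ldc + k];
    }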
@@ -412,35 +412,38 @@ void beam_search_encode_step_tab(
             cd_common[k] = cent_norms_i[k] - 2 * query_cp_i[k];
         }
 
-        /*
+        bool use_baseline_implementation = false;
+
         // This is the baseline implementation. Its primary flaw
         // that it writes way too many info to the temporary buffer
         // called dp.
         //
         // This baseline code is kept intentionally because it is easy to
         // understand what an optimized version optimizes exactly.
         //
-        for (size_t b = 0; b < beam_size; b++) {
-            std::vector<float> dp(K);
-
-            for (size_t m1 = 0; m1 < m; m1++) {
-                size_t c = codes_i[b * m + m1];
-                const float* cb =
-                        &codebook_cross_norms[(codebook_offsets[m1] + c) * ldc];
-                fvec_add(K, cb, dp.data(), dp.data());
-            }
-
-            for (size_t k = 0; k < K; k++) {
-                cent_distances[b * K + k] =
-                        distances_i[b] + cd_common[k] + 2 * dp[k];
+        if (use_baseline_implementation) {
+            for (size_t b = 0; b < beam_size; b++) {
+                std::vector<float> dp(K);
+
+                for (size_t m1 = 0; m1 < m; m1++) {
+                    size_t c = codes_i[b * m + m1];
+                    const float* cb =
+                            &codebook_cross_norms
+                                    [(codebook_offsets[m1] + c) * ldc];
+                    fvec_add(K, cb, dp.data(), dp.data());
+                }
+
+                for (size_t k = 0; k < K; k++) {
+                    cent_distances[b * K + k] =
+                            distances_i[b] + cd_common[k] + 2 * dp[k];
+                }
             }
-        }
-        */
 
-        // An optimized implementation that avoids using a temporary buffer
-        // and does the accumulation in registers.
+        } else {
+            // An optimized implementation that avoids using a temporary buffer
+            // and does the accumulation in registers.
 
-        // Compute a sum of NK AQ codes.
+            // Compute a sum of NK AQ codes.
 #define ACCUM_AND_FINALIZE_TAB(NK) \
     case NK: \
         for (size_t b = 0; b < beam_size; b++) { \
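For reference, the branch now guarded by `use_baseline_implementation` computes, for each beam entry `b` and each candidate code `k` of the current codebook, `distances_i[b] + cd_common[k] + 2 * dp[k]`, where `dp[k]` accumulates the cross norms between `k` and the codes already assigned to that beam entry. A standalone sketch of that computation, with scalar loops standing in for `fvec_add` (names simplified, not the faiss signature):

    // Sketch (not part of the diff): the quantity the baseline branch computes.
    #include <cstddef>
    #include <cstdint>
    #include <vector>

    void baseline_cent_distances(
            size_t K,
            size_t m,
            size_t beam_size,
            const float* cross_norms,         // rows of length ldc
            size_t ldc,
            const uint64_t* codebook_offsets, // m entries
            const int32_t* codes_i,           // beam_size * m
            const float* distances_i,         // beam_size
            const std::vector<float>& cd_common,   // K entries
            std::vector<float>& cent_distances) {  // beam_size * K
        for (size_t b = 0; b < beam_size; b++) {
            std::vector<float> dp(K, 0.0f);
            for (size_t m1 = 0; m1 < m; m1++) {
                const size_t c = codes_i[b * m + m1];
                const float* row = cross_norms + (codebook_offsets[m1] + c) * ldc;
                for (size_t k = 0; k < K; k++) {
                    dp[k] += row[k]; // fvec_add in the real code
                }
            }
            for (size_t k = 0; k < K; k++) {
                cent_distances[b * K + k] =
                        distances_i[b] + cd_common[k] + 2 * dp[k];
            }
        }
    }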
@@ -457,51 +460,52 @@ void beam_search_encode_step_tab(
         } \
         break;
 
-        // this version contains many switch-case scenarios, but
-        // they won't affect branch predictor.
-        switch (m) {
-            case 0:
-                // trivial case
-                for (size_t b = 0; b < beam_size; b++) {
-                    for (size_t k = 0; k < K; k++) {
-                        cent_distances[b * K + k] =
-                                distances_i[b] + cd_common[k];
+            // this version contains many switch-case scenarios, but
+            // they won't affect branch predictor.
+            switch (m) {
+                case 0:
+                    // trivial case
+                    for (size_t b = 0; b < beam_size; b++) {
+                        for (size_t k = 0; k < K; k++) {
+                            cent_distances[b * K + k] =
+                                    distances_i[b] + cd_common[k];
+                        }
                     }
-                }
-                break;
-
-            ACCUM_AND_FINALIZE_TAB(1)
-            ACCUM_AND_FINALIZE_TAB(2)
-            ACCUM_AND_FINALIZE_TAB(3)
-            ACCUM_AND_FINALIZE_TAB(4)
-            ACCUM_AND_FINALIZE_TAB(5)
-            ACCUM_AND_FINALIZE_TAB(6)
-            ACCUM_AND_FINALIZE_TAB(7)
-
-            default: {
-                // m >= 8 case.
-
-                // A temporary buffer has to be used due to the lack of
-                // registers. But we'll try to accumulate up to 8 AQ codes in
-                // registers and issue a single write operation to the buffer,
-                // while the baseline does no accumulation. So, the number of
-                // write operations to the temporary buffer is reduced 8x.
-
-                // allocate a temporary buffer
-                std::vector<float> dp(K);
-
-                for (size_t b = 0; b < beam_size; b++) {
-                    // Initialize it. Compute a sum of first 8 AQ codes
-                    // because m >= 8.
-                    accum_and_store_tab<8, 4>(
-                            m,
-                            codebook_cross_norms,
-                            codebook_offsets,
-                            codes_i,
-                            b,
-                            ldc,
-                            K,
-                            dp.data());
+                    break;
+
+                ACCUM_AND_FINALIZE_TAB(1)
+                ACCUM_AND_FINALIZE_TAB(2)
+                ACCUM_AND_FINALIZE_TAB(3)
+                ACCUM_AND_FINALIZE_TAB(4)
+                ACCUM_AND_FINALIZE_TAB(5)
+                ACCUM_AND_FINALIZE_TAB(6)
+                ACCUM_AND_FINALIZE_TAB(7)
+
+                default: {
+                    // m >= 8 case.
+
+                    // A temporary buffer has to be used due to the lack of
+                    // registers. But we'll try to accumulate up to 8 AQ codes
+                    // in registers and issue a single write operation to the
+                    // buffer, while the baseline does no accumulation. So, the
+                    // number of write operations to the temporary buffer is
+                    // reduced 8x.
+
+                    // allocate a temporary buffer
+                    std::vector<float> dp(K);
+
+                    for (size_t b = 0; b < beam_size; b++) {
+                        // Initialize it. Compute a sum of first 8 AQ codes
+                        // because m >= 8.
+                        accum_and_store_tab<8, 4>(
+                                m,
+                                codebook_cross_norms,
+                                codebook_offsets,
+                                codes_i,
+                                b,
+                                ldc,
+                                K,
+                                dp.data());
 
 #define ACCUM_AND_ADD_TAB(NK) \
     case NK: \
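The `ACCUM_AND_FINALIZE_TAB` cases re-indented above map the runtime code count `m` onto a compile-time parameter, so the accumulation over previous codes can be fully unrolled and kept in registers; since `m` is constant for the whole call, the switch is trivially predictable. A minimal sketch of that dispatch pattern (kernel and function names are illustrative, not the faiss ones):

    // Sketch (not part of the diff): dispatching a runtime count onto a
    // compile-time template parameter, as the macro expansion does.
    #include <cstddef>
    #include <cstdio>

    template <size_t NK>
    void accum_and_finalize_fixed() {
        // stand-in for accum_and_finalize_tab<NK, 4>(...): NK is known at
        // compile time, so the per-code loop of the real kernel unrolls.
        std::printf("kernel specialized for %zu codes\n", NK);
    }

    void dispatch_on_m(size_t m) {
        switch (m) {
            case 1: accum_and_finalize_fixed<1>(); break;
            case 2: accum_and_finalize_fixed<2>(); break;
            case 3: accum_and_finalize_fixed<3>(); break;
            case 4: accum_and_finalize_fixed<4>(); break;
            case 5: accum_and_finalize_fixed<5>(); break;
            case 6: accum_and_finalize_fixed<6>(); break;
            case 7: accum_and_finalize_fixed<7>(); break;
            default: break; // m == 0 or m >= 8: handled separately
        }
    }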
@@ -516,37 +520,37 @@ void beam_search_encode_step_tab(
                 dp.data()); \
         break;
 
-                    // accumulate up to 8 additional AQ codes into
-                    // a temporary buffer
-                    for (size_t im = 8; im < ((m + 7) / 8) * 8; im += 8) {
-                        size_t m_left = m - im;
-                        if (m_left > 8) {
-                            m_left = 8;
+                        // accumulate up to 8 additional AQ codes into
+                        // a temporary buffer
+                        for (size_t im = 8; im < ((m + 7) / 8) * 8; im += 8) {
+                            size_t m_left = m - im;
+                            if (m_left > 8) {
+                                m_left = 8;
+                            }
+
+                            switch (m_left) {
+                                ACCUM_AND_ADD_TAB(1)
+                                ACCUM_AND_ADD_TAB(2)
+                                ACCUM_AND_ADD_TAB(3)
+                                ACCUM_AND_ADD_TAB(4)
+                                ACCUM_AND_ADD_TAB(5)
+                                ACCUM_AND_ADD_TAB(6)
+                                ACCUM_AND_ADD_TAB(7)
+                                ACCUM_AND_ADD_TAB(8)
+                            }
                         }
 
-                    switch (m_left) {
-                        ACCUM_AND_ADD_TAB(1)
-                        ACCUM_AND_ADD_TAB(2)
-                        ACCUM_AND_ADD_TAB(3)
-                        ACCUM_AND_ADD_TAB(4)
-                        ACCUM_AND_ADD_TAB(5)
-                        ACCUM_AND_ADD_TAB(6)
-                        ACCUM_AND_ADD_TAB(7)
-                        ACCUM_AND_ADD_TAB(8)
+                        // done. finalize the result
+                        for (size_t k = 0; k < K; k++) {
+                            cent_distances[b * K + k] =
+                                    distances_i[b] + cd_common[k] + 2 * dp[k];
                         }
                     }
-
-                // done. finalize the result
-                for (size_t k = 0; k < K; k++) {
-                    cent_distances[b * K + k] =
-                            distances_i[b] + cd_common[k] + 2 * dp[k];
-                    }
                 }
             }
-        }
-
-        // the optimized implementation ends here
 
+            // the optimized implementation ends here
+        }
         using C = CMax<float, int>;
         int32_t* new_codes_i = new_codes + i * (m + 1) * new_beam_size;
         float* new_distances_i = new_distances + i * new_beam_size;
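In the `m >= 8` path moved above, the first eight codes are folded into `dp` by a single `accum_and_store_tab<8, 4>` call and every later group of up to eight codes by one `ACCUM_AND_ADD_TAB` dispatch, so the K-sized buffer is written roughly once per eight codes instead of once per code. A small sketch of just the chunking arithmetic (illustrative only):

    // Sketch (not part of the diff): how the m >= 8 branch splits the m AQ
    // codes into chunks of at most 8. For m = 19 this visits chunk sizes
    // 8 (initial store), then 8 and 3 (additions): 3 buffer updates
    // instead of 19.
    #include <cstddef>
    #include <cstdio>

    void show_chunks(size_t m) { // assumes m >= 8
        std::printf("chunk of 8 (initializes the buffer)\n");
        for (size_t im = 8; im < ((m + 7) / 8) * 8; im += 8) {
            size_t m_left = m - im;
            if (m_left > 8) {
                m_left = 8;
            }
            std::printf("chunk of %zu (added to the buffer)\n", m_left);
        }
    }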
@@ -784,6 +788,7 @@ void refine_beam_LUT_mp(
     // main loop
     size_t codes_size = 0;
     size_t distances_size = 0;
+    size_t cross_ofs = 0;
     for (int m = 0; m < rq.M; m++) {
         int K = 1 << rq.nbits[m];
 
@@ -792,13 +797,15 @@ void refine_beam_LUT_mp(
 
         codes_size = n * new_beam_size * (m + 1);
         distances_size = n * new_beam_size;
-
+        FAISS_THROW_IF_NOT(
+                cross_ofs + rq.codebook_offsets[m] * K <=
+                rq.codebook_cross_products.size());
         beam_search_encode_step_tab(
                 K,
                 n,
                 beam_size,
-                rq.codebook_cross_products.data() + rq.codebook_offsets[m],
-                rq.total_codebook_size,
+                rq.codebook_cross_products.data() + cross_ofs,
+                K,
                 rq.codebook_offsets.data(),
                 query_cp + rq.codebook_offsets[m],
                 rq.total_codebook_size,
@@ -810,7 +817,7 @@ void refine_beam_LUT_mp(
                 new_codes_ptr,
                 new_distances_ptr,
                 rq.approx_topk_mode);
-
+        cross_ofs += rq.codebook_offsets[m] * K;
         std::swap(codes_ptr, new_codes_ptr);
         std::swap(distances_ptr, new_distances_ptr);
         beam_size = new_beam_size;
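With `cross_ofs` advancing by `codebook_offsets[m] * K` after each step, `codebook_cross_products` is consumed as a sequence of per-step blocks: step `m` holds the cross products of its `K` entries against the `codebook_offsets[m]` entries of all earlier codebooks, which is also why the strict `total_codebook_size * total_codebook_size` size check is relaxed in the last hunk of this diff. A sketch of the total size such a packed layout implies (types simplified, not faiss API):

    // Sketch (not part of the diff): total number of floats occupied by the
    // packed codebook_cross_products layout implied by cross_ofs. Step m
    // stores a block of codebook_offsets[m] * (1 << nbits[m]) cross products.
    #include <cstddef>
    #include <cstdint>
    #include <vector>

    size_t packed_cross_products_size(
            const std::vector<size_t>& nbits,
            const std::vector<uint64_t>& codebook_offsets) {
        size_t total = 0;
        for (size_t m = 0; m < nbits.size(); m++) {
            const size_t K = size_t(1) << nbits[m];
            // same increment as cross_ofs in the loop above
            total += static_cast<size_t>(codebook_offsets[m]) * K;
        }
        return total;
    }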
@@ -830,7 +837,6 @@ void refine_beam_LUT_mp(
                     beam_size);
         }
     }
-
 
     if (out_codes) {
         memcpy(out_codes, codes_ptr, codes_size * sizeof(*codes_ptr));
     }
@@ -903,8 +909,7 @@ void compute_codes_add_centroids_mp_lut1(
     pool.distances.resize(rq.max_beam_size * n);
 
     FAISS_THROW_IF_NOT_MSG(
-            rq.codebook_cross_products.size() ==
-                    rq.total_codebook_size * rq.total_codebook_size,
+            rq.M == 1 || rq.codebook_cross_products.size() > 0,
             "call compute_codebook_tables first");
 
     pool.query_norms.resize(n);