@@ -152,6 +152,40 @@ void AdditiveQuantizer::train_norm(size_t n, const float* norms) {
152
152
}
153
153
}
154
154
155
+ void AdditiveQuantizer::compute_codebook_tables () {
156
+ centroid_norms.resize (total_codebook_size);
157
+ fvec_norms_L2sqr (
158
+ centroid_norms.data (), codebooks.data (), d, total_codebook_size);
159
+ size_t cross_table_size = 0 ;
160
+ for (int m = 0 ; m < M; m++) {
161
+ size_t K = (size_t )1 << nbits[m];
162
+ cross_table_size += K * codebook_offsets[m];
163
+ }
164
+ codebook_cross_products.resize (cross_table_size);
165
+ size_t ofs = 0 ;
166
+ for (int m = 1 ; m < M; m++) {
167
+ FINTEGER ki = (size_t )1 << nbits[m];
168
+ FINTEGER kk = codebook_offsets[m];
169
+ FINTEGER di = d;
170
+ float zero = 0 , one = 1 ;
171
+ assert (ofs + ki * kk <= cross_table_size);
172
+ sgemm_ (" Transposed" ,
173
+ " Not transposed" ,
174
+ &ki,
175
+ &kk,
176
+ &di,
177
+ &one,
178
+ codebooks.data () + d * kk,
179
+ &di,
180
+ codebooks.data (),
181
+ &di,
182
+ &zero,
183
+ codebook_cross_products.data () + ofs,
184
+ &ki);
185
+ ofs += ki * kk;
186
+ }
187
+ }
188
+
155
189
namespace {
156
190
157
191
// TODO
@@ -471,7 +505,6 @@ namespace {
471
505
float accumulate_IPs (
472
506
const AdditiveQuantizer& aq,
473
507
BitstringReader& bs,
474
- const uint8_t * codes,
475
508
const float * LUT) {
476
509
float accu = 0 ;
477
510
for (int m = 0 ; m < aq.M ; m++) {
@@ -483,6 +516,29 @@ float accumulate_IPs(
483
516
return accu;
484
517
}
485
518
519
+ float compute_norm_from_LUT (const AdditiveQuantizer& aq, BitstringReader& bs) {
520
+ float accu = 0 ;
521
+ std::vector<int > idx (aq.M );
522
+ const float * c = aq.codebook_cross_products .data ();
523
+ for (int m = 0 ; m < aq.M ; m++) {
524
+ size_t nbit = aq.nbits [m];
525
+ int i = bs.read (nbit);
526
+ size_t K = 1 << nbit;
527
+ idx[m] = i;
528
+
529
+ accu += aq.centroid_norms [aq.codebook_offsets [m] + i];
530
+
531
+ for (int l = 0 ; l < m; l++) {
532
+ int j = idx[l];
533
+ accu += 2 * c[j * K + i];
534
+ c += (1 << aq.nbits [l]) * K;
535
+ }
536
+ }
537
+ // FAISS_THROW_IF_NOT(c == aq.codebook_cross_products.data() +
538
+ // aq.codebook_cross_products.size());
539
+ return accu;
540
+ }
541
+
486
542
} // anonymous namespace
487
543
488
544
template <>
@@ -491,7 +547,7 @@ float AdditiveQuantizer::
491
547
const uint8_t * codes,
492
548
const float * LUT) const {
493
549
BitstringReader bs (codes, code_size);
494
- return accumulate_IPs (*this , bs, codes, LUT);
550
+ return accumulate_IPs (*this , bs, LUT);
495
551
}
496
552
497
553
template <>
@@ -500,7 +556,7 @@ float AdditiveQuantizer::
500
556
const uint8_t * codes,
501
557
const float * LUT) const {
502
558
BitstringReader bs (codes, code_size);
503
- return -accumulate_IPs (*this , bs, codes, LUT);
559
+ return -accumulate_IPs (*this , bs, LUT);
504
560
}
505
561
506
562
template <>
@@ -509,7 +565,7 @@ float AdditiveQuantizer::
509
565
const uint8_t * codes,
510
566
const float * LUT) const {
511
567
BitstringReader bs (codes, code_size);
512
- float accu = accumulate_IPs (*this , bs, codes, LUT);
568
+ float accu = accumulate_IPs (*this , bs, LUT);
513
569
uint32_t norm_i = bs.read (32 );
514
570
float norm2;
515
571
memcpy (&norm2, &norm_i, 4 );
@@ -522,7 +578,7 @@ float AdditiveQuantizer::
522
578
const uint8_t * codes,
523
579
const float * LUT) const {
524
580
BitstringReader bs (codes, code_size);
525
- float accu = accumulate_IPs (*this , bs, codes, LUT);
581
+ float accu = accumulate_IPs (*this , bs, LUT);
526
582
uint32_t norm_i = bs.read (8 );
527
583
float norm2 = decode_qcint (norm_i);
528
584
return norm2 - 2 * accu;
@@ -534,7 +590,7 @@ float AdditiveQuantizer::
534
590
const uint8_t * codes,
535
591
const float * LUT) const {
536
592
BitstringReader bs (codes, code_size);
537
- float accu = accumulate_IPs (*this , bs, codes, LUT);
593
+ float accu = accumulate_IPs (*this , bs, LUT);
538
594
uint32_t norm_i = bs.read (4 );
539
595
float norm2 = decode_qcint (norm_i);
540
596
return norm2 - 2 * accu;
@@ -546,7 +602,7 @@ float AdditiveQuantizer::
546
602
const uint8_t * codes,
547
603
const float * LUT) const {
548
604
BitstringReader bs (codes, code_size);
549
- float accu = accumulate_IPs (*this , bs, codes, LUT);
605
+ float accu = accumulate_IPs (*this , bs, LUT);
550
606
uint32_t norm_i = bs.read (8 );
551
607
float norm2 = decode_qint8 (norm_i, norm_min, norm_max);
552
608
return norm2 - 2 * accu;
@@ -558,10 +614,23 @@ float AdditiveQuantizer::
558
614
const uint8_t * codes,
559
615
const float * LUT) const {
560
616
BitstringReader bs (codes, code_size);
561
- float accu = accumulate_IPs (*this , bs, codes, LUT);
617
+ float accu = accumulate_IPs (*this , bs, LUT);
562
618
uint32_t norm_i = bs.read (4 );
563
619
float norm2 = decode_qint4 (norm_i, norm_min, norm_max);
564
620
return norm2 - 2 * accu;
565
621
}
566
622
623
+ template <>
624
+ float AdditiveQuantizer::
625
+ compute_1_distance_LUT<false , AdditiveQuantizer::ST_norm_from_LUT>(
626
+ const uint8_t * codes,
627
+ const float * LUT) const {
628
+ FAISS_THROW_IF_NOT (codebook_cross_products.size () > 0 );
629
+ BitstringReader bs (codes, code_size);
630
+ float accu = accumulate_IPs (*this , bs, LUT);
631
+ BitstringReader bs2 (codes, code_size);
632
+ float norm2 = compute_norm_from_LUT (*this , bs2);
633
+ return norm2 - 2 * accu;
634
+ }
635
+
567
636
} // namespace faiss
0 commit comments