Skip to content

Commit 67d8727

Browse files
mdouzefacebook-github-bot
authored andcommitted
Clean up batch comments + obey IO_FLAG_SKIP_PRECOMPUTE_TABLE (facebookresearch#3013)
Summary: Pull Request resolved: facebookresearch#3013 To avoid OOM when loading some RCQs, don't precompute cross product tables when io_flags contains bit IO_FLAG_SKIP_PRECOMPUTE_TABLE Reviewed By: pemazare Differential Revision: D48448616 fbshipit-source-id: a261259f1fb583aa358d6b6c42d9b851e9729247
1 parent 82352dd commit 67d8727

8 files changed

+232
-227
lines changed

faiss/IndexAdditiveQuantizer.cpp

+155-157
Large diffs are not rendered by default.

faiss/impl/AdditiveQuantizer.cpp

+2
Original file line numberDiff line numberDiff line change
@@ -370,6 +370,8 @@ void AdditiveQuantizer::compute_LUT(
370370

371371
namespace {
372372

373+
/* compute inner products of one query with all centroids, given a look-up
374+
* table of all inner producst with codebook entries */
373375
void compute_inner_prod_with_LUT(
374376
const AdditiveQuantizer& aq,
375377
const float* LUT,

faiss/impl/AdditiveQuantizer.h

+4-2
Original file line numberDiff line numberDiff line change
@@ -49,11 +49,13 @@ struct AdditiveQuantizer : Quantizer {
4949
/// encode a norm into norm_bits bits
5050
uint64_t encode_norm(float norm) const;
5151

52+
/// encode norm by non-uniform scalar quantization
5253
uint32_t encode_qcint(
53-
float x) const; ///< encode norm by non-uniform scalar quantization
54+
float x) const;
5455

56+
/// decode norm by non-uniform scalar quantization
5557
float decode_qcint(uint32_t c)
56-
const; ///< decode norm by non-uniform scalar quantization
58+
const;
5759

5860
/// Encodes how search is performed and how vectors are encoded
5961
enum Search_type_t {

faiss/impl/ResidualQuantizer.cpp

+26-41
Original file line numberDiff line numberDiff line change
@@ -125,6 +125,10 @@ void ResidualQuantizer::initialize_from(
125125
}
126126
}
127127

128+
/****************************************************************
129+
* Encoding steps, used both for training and search
130+
*/
131+
128132
void beam_search_encode_step(
129133
size_t d,
130134
size_t K,
@@ -277,6 +281,10 @@ void beam_search_encode_step(
277281
}
278282
}
279283

284+
/****************************************************************
285+
* Training
286+
****************************************************************/
287+
280288
void ResidualQuantizer::train(size_t n, const float* x) {
281289
codebooks.resize(d * codebook_offsets.back());
282290

@@ -568,7 +576,12 @@ size_t ResidualQuantizer::memory_per_point(int beam_size) const {
568576
return mem;
569577
}
570578

571-
// a namespace full of preallocated buffers
579+
/****************************************************************
580+
* Encoding
581+
****************************************************************/
582+
583+
// a namespace full of preallocated buffers. This speeds up
584+
// computations, instead of re-allocating them at every encoing step
572585
namespace {
573586

574587
// Preallocated memory chunk for refine_beam_mp() call
@@ -609,8 +622,6 @@ struct ComputeCodesAddCentroidsLUT1MemoryPool {
609622
RefineBeamLUTMemoryPool refine_beam_lut_pool;
610623
};
611624

612-
} // namespace
613-
614625
// forward declaration
615626
void refine_beam_mp(
616627
const ResidualQuantizer& rq,
@@ -743,6 +754,8 @@ void compute_codes_add_centroids_mp_lut1(
743754
centroids);
744755
}
745756

757+
} // namespace
758+
746759
void ResidualQuantizer::compute_codes_add_centroids(
747760
const float* x,
748761
uint8_t* codes_out,
@@ -769,11 +782,6 @@ void ResidualQuantizer::compute_codes_add_centroids(
769782
cent = centroids + i0 * d;
770783
}
771784

772-
// compute_codes_add_centroids(
773-
// x + i0 * d,
774-
// codes_out + i0 * code_size,
775-
// i1 - i0,
776-
// cent);
777785
if (use_beam_LUT == 0) {
778786
compute_codes_add_centroids_mp_lut0(
779787
*this,
@@ -794,6 +802,8 @@ void ResidualQuantizer::compute_codes_add_centroids(
794802
}
795803
}
796804

805+
namespace {
806+
797807
void refine_beam_mp(
798808
const ResidualQuantizer& rq,
799809
size_t n,
@@ -873,15 +883,11 @@ void refine_beam_mp(
873883
codebooks_m,
874884
n,
875885
cur_beam_size,
876-
// residuals.data(),
877886
residuals_ptr,
878887
m,
879-
// codes.data(),
880888
codes_ptr,
881889
new_beam_size,
882-
// new_codes.data(),
883890
new_codes_ptr,
884-
// new_residuals.data(),
885891
new_residuals_ptr,
886892
pool.distances.data(),
887893
assign_index.get(),
@@ -896,9 +902,6 @@ void refine_beam_mp(
896902

897903
if (rq.verbose) {
898904
float sum_distances = 0;
899-
// for (int j = 0; j < distances.size(); j++) {
900-
// sum_distances += distances[j];
901-
// }
902905
for (int j = 0; j < distances_size; j++) {
903906
sum_distances += pool.distances[j];
904907
}
@@ -914,27 +917,22 @@ void refine_beam_mp(
914917
}
915918

916919
if (out_codes) {
917-
// memcpy(out_codes, codes.data(), codes.size() * sizeof(codes[0]));
918920
memcpy(out_codes, codes_ptr, codes_size * sizeof(*codes_ptr));
919921
}
920922
if (out_residuals) {
921-
// memcpy(out_residuals,
922-
// residuals.data(),
923-
// residuals.size() * sizeof(residuals[0]));
924923
memcpy(out_residuals,
925924
residuals_ptr,
926925
residuals_size * sizeof(*residuals_ptr));
927926
}
928927
if (out_distances) {
929-
// memcpy(out_distances,
930-
// distances.data(),
931-
// distances.size() * sizeof(distances[0]));
932928
memcpy(out_distances,
933929
pool.distances.data(),
934930
distances_size * sizeof(pool.distances[0]));
935931
}
936932
}
937933

934+
} // anonymous namespace
935+
938936
void ResidualQuantizer::refine_beam(
939937
size_t n,
940938
size_t beam_size,
@@ -1165,7 +1163,7 @@ void accum_and_finalize_tab(
11651163
}
11661164
}
11671165

1168-
} // namespace
1166+
} // anonymous namespace
11691167

11701168
void beam_search_encode_step_tab(
11711169
size_t K,
@@ -1390,6 +1388,8 @@ void beam_search_encode_step_tab(
13901388
}
13911389
}
13921390

1391+
namespace {
1392+
13931393
//
13941394
void refine_beam_LUT_mp(
13951395
const ResidualQuantizer& rq,
@@ -1443,13 +1443,9 @@ void refine_beam_LUT_mp(
14431443
for (int m = 0; m < rq.M; m++) {
14441444
int K = 1 << rq.nbits[m];
14451445

1446-
// it is guaranteed that (new_beam_size <= than max_beam_size) ==
1447-
// true
1446+
// it is guaranteed that (new_beam_size <= max_beam_size)
14481447
int new_beam_size = std::min(beam_size * K, out_beam_size);
14491448

1450-
// std::vector<int32_t> new_codes(n * new_beam_size * (m + 1));
1451-
// std::vector<float> new_distances(n * new_beam_size);
1452-
14531449
codes_size = n * new_beam_size * (m + 1);
14541450
distances_size = n * new_beam_size;
14551451

@@ -1464,29 +1460,20 @@ void refine_beam_LUT_mp(
14641460
rq.total_codebook_size,
14651461
rq.cent_norms.data() + rq.codebook_offsets[m],
14661462
m,
1467-
// codes.data(),
14681463
codes_ptr,
1469-
// distances.data(),
14701464
distances_ptr,
14711465
new_beam_size,
1472-
// new_codes.data(),
14731466
new_codes_ptr,
1474-
// new_distances.data()
14751467
new_distances_ptr,
14761468
rq.approx_topk_mode);
14771469

1478-
// codes.swap(new_codes);
14791470
std::swap(codes_ptr, new_codes_ptr);
1480-
// distances.swap(new_distances);
14811471
std::swap(distances_ptr, new_distances_ptr);
14821472

14831473
beam_size = new_beam_size;
14841474

14851475
if (rq.verbose) {
14861476
float sum_distances = 0;
1487-
// for (int j = 0; j < distances.size(); j++) {
1488-
// sum_distances += distances[j];
1489-
// }
14901477
for (int j = 0; j < distances_size; j++) {
14911478
sum_distances += distances_ptr[j];
14921479
}
@@ -1501,19 +1488,17 @@ void refine_beam_LUT_mp(
15011488
}
15021489

15031490
if (out_codes) {
1504-
// memcpy(out_codes, codes.data(), codes.size() * sizeof(codes[0]));
15051491
memcpy(out_codes, codes_ptr, codes_size * sizeof(*codes_ptr));
15061492
}
15071493
if (out_distances) {
1508-
// memcpy(out_distances,
1509-
// distances.data(),
1510-
// distances.size() * sizeof(distances[0]));
15111494
memcpy(out_distances,
15121495
distances_ptr,
15131496
distances_size * sizeof(*distances_ptr));
15141497
}
15151498
}
15161499

1500+
} // namespace
1501+
15171502
void ResidualQuantizer::refine_beam_LUT(
15181503
size_t n,
15191504
const float* query_norms, // size n

faiss/impl/ResidualQuantizer.h

+10-3
Original file line numberDiff line numberDiff line change
@@ -144,9 +144,7 @@ struct ResidualQuantizer : AdditiveQuantizer {
144144
*/
145145
size_t memory_per_point(int beam_size = -1) const;
146146

147-
/** Cross products used in codebook tables
148-
*
149-
* These are used to keep trak of norms of centroids.
147+
/** Cross products used in codebook tables used for beam_LUT = 1
150148
*/
151149
void compute_codebook_tables();
152150

@@ -194,6 +192,15 @@ void beam_search_encode_step(
194192

195193
/** Encode a set of vectors using their dot products with the codebooks
196194
*
195+
* @param K number of vectors in the codebook
196+
* @param n nb of vectors to encode
197+
* @param beam_size input beam size
198+
* @param codebook_cross_norms inner product of this codebook with the m
199+
* previously encoded codebooks
200+
* @param codebook_offsets offsets into codebook_cross_norms for each
201+
* previous codebook
202+
* @param query_cp dot products of query vectors with ???
203+
* @param cent_norms_i norms of centroids
197204
*/
198205
void beam_search_encode_step_tab(
199206
size_t K,

faiss/impl/index_read.cpp

+26-13
Original file line numberDiff line numberDiff line change
@@ -292,11 +292,17 @@ static void read_AdditiveQuantizer(AdditiveQuantizer* aq, IOReader* f) {
292292
aq->set_derived_values();
293293
}
294294

295-
static void read_ResidualQuantizer(ResidualQuantizer* rq, IOReader* f) {
295+
static void read_ResidualQuantizer(
296+
ResidualQuantizer* rq,
297+
IOReader* f,
298+
int io_flags) {
296299
read_AdditiveQuantizer(rq, f);
297300
READ1(rq->train_type);
298301
READ1(rq->max_beam_size);
299-
if (!(rq->train_type & ResidualQuantizer::Skip_codebook_tables)) {
302+
if ((rq->train_type & ResidualQuantizer::Skip_codebook_tables) ||
303+
(io_flags & IO_FLAG_SKIP_PRECOMPUTE_TABLE)) {
304+
// don't precompute the tables
305+
} else {
300306
rq->compute_codebook_tables();
301307
}
302308
}
@@ -325,12 +331,13 @@ static void read_ProductAdditiveQuantizer(
325331

326332
static void read_ProductResidualQuantizer(
327333
ProductResidualQuantizer* prq,
328-
IOReader* f) {
334+
IOReader* f,
335+
int io_flags) {
329336
read_ProductAdditiveQuantizer(prq, f);
330337

331338
for (size_t i = 0; i < prq->nsplits; i++) {
332339
auto rq = new ResidualQuantizer();
333-
read_ResidualQuantizer(rq, f);
340+
read_ResidualQuantizer(rq, f, io_flags);
334341
prq->quantizers.push_back(rq);
335342
}
336343
}
@@ -601,7 +608,7 @@ Index* read_index(IOReader* f, int io_flags) {
601608
if (h == fourcc("IxRQ")) {
602609
read_ResidualQuantizer_old(&idxr->rq, f);
603610
} else {
604-
read_ResidualQuantizer(&idxr->rq, f);
611+
read_ResidualQuantizer(&idxr->rq, f, io_flags);
605612
}
606613
READ1(idxr->code_size);
607614
READVECTOR(idxr->codes);
@@ -616,7 +623,7 @@ Index* read_index(IOReader* f, int io_flags) {
616623
} else if (h == fourcc("IxPR")) {
617624
auto idxpr = new IndexProductResidualQuantizer();
618625
read_index_header(idxpr, f);
619-
read_ProductResidualQuantizer(&idxpr->prq, f);
626+
read_ProductResidualQuantizer(&idxpr->prq, f, io_flags);
620627
READ1(idxpr->code_size);
621628
READVECTOR(idxpr->codes);
622629
idx = idxpr;
@@ -630,8 +637,13 @@ Index* read_index(IOReader* f, int io_flags) {
630637
} else if (h == fourcc("ImRQ")) {
631638
ResidualCoarseQuantizer* idxr = new ResidualCoarseQuantizer();
632639
read_index_header(idxr, f);
633-
read_ResidualQuantizer(&idxr->rq, f);
640+
read_ResidualQuantizer(&idxr->rq, f, io_flags);
634641
READ1(idxr->beam_factor);
642+
if (io_flags & IO_FLAG_SKIP_PRECOMPUTE_TABLE) {
643+
// then we force the beam factor to -1
644+
// which skips the table precomputation.
645+
idxr->beam_factor = -1;
646+
}
635647
idxr->set_beam_factor(idxr->beam_factor);
636648
idx = idxr;
637649
} else if (
@@ -656,13 +668,14 @@ Index* read_index(IOReader* f, int io_flags) {
656668
if (is_LSQ) {
657669
read_LocalSearchQuantizer((LocalSearchQuantizer*)idxaqfs->aq, f);
658670
} else if (is_RQ) {
659-
read_ResidualQuantizer((ResidualQuantizer*)idxaqfs->aq, f);
671+
read_ResidualQuantizer(
672+
(ResidualQuantizer*)idxaqfs->aq, f, io_flags);
660673
} else if (is_PLSQ) {
661674
read_ProductLocalSearchQuantizer(
662675
(ProductLocalSearchQuantizer*)idxaqfs->aq, f);
663676
} else {
664677
read_ProductResidualQuantizer(
665-
(ProductResidualQuantizer*)idxaqfs->aq, f);
678+
(ProductResidualQuantizer*)idxaqfs->aq, f, io_flags);
666679
}
667680

668681
READ1(idxaqfs->implem);
@@ -704,13 +717,13 @@ Index* read_index(IOReader* f, int io_flags) {
704717
if (is_LSQ) {
705718
read_LocalSearchQuantizer((LocalSearchQuantizer*)ivaqfs->aq, f);
706719
} else if (is_RQ) {
707-
read_ResidualQuantizer((ResidualQuantizer*)ivaqfs->aq, f);
720+
read_ResidualQuantizer((ResidualQuantizer*)ivaqfs->aq, f, io_flags);
708721
} else if (is_PLSQ) {
709722
read_ProductLocalSearchQuantizer(
710723
(ProductLocalSearchQuantizer*)ivaqfs->aq, f);
711724
} else {
712725
read_ProductResidualQuantizer(
713-
(ProductResidualQuantizer*)ivaqfs->aq, f);
726+
(ProductResidualQuantizer*)ivaqfs->aq, f, io_flags);
714727
}
715728

716729
READ1(ivaqfs->by_residual);
@@ -832,13 +845,13 @@ Index* read_index(IOReader* f, int io_flags) {
832845
if (is_LSQ) {
833846
read_LocalSearchQuantizer((LocalSearchQuantizer*)iva->aq, f);
834847
} else if (is_RQ) {
835-
read_ResidualQuantizer((ResidualQuantizer*)iva->aq, f);
848+
read_ResidualQuantizer((ResidualQuantizer*)iva->aq, f, io_flags);
836849
} else if (is_PLSQ) {
837850
read_ProductLocalSearchQuantizer(
838851
(ProductLocalSearchQuantizer*)iva->aq, f);
839852
} else {
840853
read_ProductResidualQuantizer(
841-
(ProductResidualQuantizer*)iva->aq, f);
854+
(ProductResidualQuantizer*)iva->aq, f, io_flags);
842855
}
843856
READ1(iva->by_residual);
844857
READ1(iva->use_precomputed_table);

faiss/python/__init__.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -298,10 +298,10 @@ def serialize_index(index):
298298
return vector_to_array(writer.data)
299299

300300

301-
def deserialize_index(data):
301+
def deserialize_index(data, io_flags=0):
302302
reader = VectorIOReader()
303303
copy_array_to_vector(data, reader.data)
304-
return read_index(reader)
304+
return read_index(reader, io_flags)
305305

306306

307307
def serialize_index_binary(index):

0 commit comments

Comments
 (0)