Skip to content

Commit 6800ebe

Browse files
mdouze and facebook-github-bot
authored and committed
Support independent IVF coarse quantizer
Summary: In the IndexIVFIndependentQuantizer, the coarse quantizer is applied on the input vectors, but the encoding is performed on a vector-transformed version of the database elements. Reviewed By: alexanderguzhva Differential Revision: D45950970 fbshipit-source-id: 30f6cf46d44174b1d99a12384b7d5e2d475c1f88
1 parent a3296f4 commit 6800ebe

14 files changed

+448
-55
lines changed

contrib/inspect_tools.py

+14
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,20 @@ def get_LinearTransform_matrix(pca):
6868
return A, b
6969

7070

71+
def make_LinearTransform_matrix(A, b=None):
    """Build a faiss.LinearTransform from a matrix and an optional bias.

    A is unpacked as (d_out, d_in); when b is given it must have shape
    (d_out,). The returned transform is flagged as trained.
    """
    d_out, d_in = A.shape
    has_bias = b is not None
    if has_bias:
        assert b.shape == (d_out, )

    lt = faiss.LinearTransform(d_in, d_out, has_bias)
    # the matrix is copied in flattened form
    faiss.copy_array_to_vector(A.ravel(), lt.A)
    if has_bias:
        faiss.copy_array_to_vector(b, lt.b)

    lt.is_trained = True
    lt.set_is_orthonormal()
    return lt
83+
84+
7185
def get_additive_quantizer_codebooks(aq):
7286
""" return to codebooks of an additive quantizer """
7387
codebooks = faiss.vector_to_array(aq.codebooks).reshape(-1, aq.d)

faiss/CMakeLists.txt

+2
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@ set(FAISS_SRC
3737
IndexPQ.cpp
3838
IndexFastScan.cpp
3939
IndexAdditiveQuantizerFastScan.cpp
40+
IndexIVFIndependentQuantizer.cpp
4041
IndexPQFastScan.cpp
4142
IndexPreTransform.cpp
4243
IndexRefine.cpp
@@ -113,6 +114,7 @@ set(FAISS_HEADERS
113114
IndexIDMap.h
114115
IndexIVF.h
115116
IndexIVFAdditiveQuantizer.h
117+
IndexIVFIndependentQuantizer.h
116118
IndexIVFFlat.h
117119
IndexIVFPQ.h
118120
IndexIVFFastScan.h

faiss/IVFlib.cpp

+5
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212

1313
#include <faiss/IndexAdditiveQuantizer.h>
1414
#include <faiss/IndexIVFAdditiveQuantizer.h>
15+
#include <faiss/IndexIVFIndependentQuantizer.h>
1516
#include <faiss/IndexPreTransform.h>
1617
#include <faiss/MetaIndexes.h>
1718
#include <faiss/impl/FaissAssert.h>
@@ -67,6 +68,10 @@ const IndexIVF* try_extract_index_ivf(const Index* index) {
6768
if (auto* idmap = dynamic_cast<const IndexIDMap2*>(index)) {
6869
index = idmap->index;
6970
}
71+
if (auto* indep =
72+
dynamic_cast<const IndexIVFIndependentQuantizer*>(index)) {
73+
index = indep->index_ivf;
74+
}
7075

7176
auto* ivf = dynamic_cast<const IndexIVF*>(index);
7277

faiss/IndexIVFIndependentQuantizer.cpp

+172
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,172 @@
1+
/**
2+
* Copyright (c) Facebook, Inc. and its affiliates.
3+
*
4+
* This source code is licensed under the MIT license found in the
5+
* LICENSE file in the root directory of this source tree.
6+
*/
7+
8+
#include <faiss/IndexIVFIndependentQuantizer.h>
9+
#include <faiss/IndexIVFPQ.h>
10+
#include <faiss/impl/FaissAssert.h>
11+
#include <faiss/utils/utils.h>
12+
13+
namespace faiss {
14+
15+
/// Assemble the index from an input-space coarse quantizer, a payload IVF
/// index and an optional transform mapping input vectors to the payload
/// space. The index dimension is the quantizer's; the metric is the IVF's.
/// Throws if the dimensions of the three components are inconsistent.
IndexIVFIndependentQuantizer::IndexIVFIndependentQuantizer(
        Index* quantizer,
        IndexIVF* index_ivf,
        VectorTransform* vt)
        : Index(quantizer->d, index_ivf->metric_type),
          quantizer(quantizer),
          vt(vt),
          index_ivf(index_ivf) {
    // The transform (if any) must map dimension d onto the IVF dimension;
    // without a transform the IVF index must itself be of dimension d.
    if (vt) {
        FAISS_THROW_IF_NOT_MSG(
                vt->d_in == d && vt->d_out == index_ivf->d,
                "invalid vector dimensions");
    } else {
        FAISS_THROW_IF_NOT_MSG(index_ivf->d == d, "invalid vector dimensions");
    }

    // A trained, populated quantizer must hold exactly one entry per list.
    if (quantizer->is_trained && quantizer->ntotal != 0) {
        FAISS_THROW_IF_NOT(quantizer->ntotal == index_ivf->nlist);
    }
    // A trained IVF index implies that its input transform is trained too.
    if (index_ivf->is_trained && vt) {
        FAISS_THROW_IF_NOT(vt->is_trained);
    }

    ntotal = index_ivf->ntotal;

    // The wrapper counts as trained only when every component is.
    const bool quantizer_ready =
            quantizer->is_trained && quantizer->ntotal == index_ivf->nlist;
    const bool transform_ready = !vt || vt->is_trained;
    is_trained = quantizer_ready && transform_ready && index_ivf->is_trained;

    // disable precomputed tables because they use the distances that are
    // provided by the coarse quantizer (that are out of sync with the IVFPQ)
    if (auto* ivfpq = dynamic_cast<IndexIVFPQ*>(index_ivf)) {
        ivfpq->use_precomputed_table = -1;
    }
}
48+
49+
/// Releases the quantizer, the IVF index and the transform, but only when
/// own_fields is set; otherwise the three pointers are left untouched.
IndexIVFIndependentQuantizer::~IndexIVFIndependentQuantizer() {
    if (!own_fields) {
        return;
    }
    delete quantizer;
    delete index_ivf;
    delete vt;
}
56+
57+
namespace {
58+
59+
struct VTransformedVectors : TransformedVectors {
60+
VTransformedVectors(const VectorTransform* vt, idx_t n, const float* x)
61+
: TransformedVectors(x, vt ? vt->apply(n, x) : x) {}
62+
};
63+
64+
struct SubsampledVectors : TransformedVectors {
65+
SubsampledVectors(int d, idx_t* n, idx_t max_n, const float* x)
66+
: TransformedVectors(
67+
x,
68+
fvecs_maybe_subsample(d, (size_t*)n, max_n, x, true)) {}
69+
};
70+
71+
} // anonymous namespace
72+
73+
/// Add n vectors of dimension d to the index.
///
/// Each vector is assigned to an inverted list by the independent quantizer,
/// which sees the raw input. The stored payload is the vector after the
/// optional VectorTransform.
///
/// @param n  number of vectors
/// @param x  input vectors, size n * d
void IndexIVFIndependentQuantizer::add(idx_t n, const float* x) {
    // nearest centroid of each raw input vector (distances are not used)
    std::vector<float> D(n);
    std::vector<idx_t> I(n);
    quantizer->search(n, x, 1, D.data(), I.data());

    // payload vectors, in the space of index_ivf
    VTransformedVectors tv(vt, n, x);

    index_ivf->add_core(n, tv.x, nullptr, I.data());

    // keep the wrapper's vector count in sync with the payload index, as the
    // constructor and reset() do
    ntotal = index_ivf->ntotal;
}
82+
83+
/// k-nearest-neighbor search. The inverted lists to visit are chosen by the
/// independent quantizer on the raw queries (index_ivf->nprobe lists per
/// query); the lists are then scanned with the transformed queries.
/// Search parameters are not supported: a non-null params throws.
void IndexIVFIndependentQuantizer::search(
        idx_t n,
        const float* x,
        idx_t k,
        float* distances,
        idx_t* labels,
        const SearchParameters* params) const {
    FAISS_THROW_IF_NOT_MSG(!params, "search parameters not supported");

    // coarse quantization on the untransformed queries
    int nprobe = index_ivf->nprobe;
    std::vector<float> coarse_dis(n * nprobe);
    std::vector<idx_t> coarse_ids(n * nprobe);
    quantizer->search(n, x, nprobe, coarse_dis.data(), coarse_ids.data());

    // queries mapped to the payload space
    VTransformedVectors queries(vt, n, x);

    index_ivf->search_preassigned(
            n,
            queries.x,
            k,
            coarse_ids.data(),
            coarse_dis.data(),
            distances,
            labels,
            false);
}
101+
102+
void IndexIVFIndependentQuantizer::reset() {
103+
index_ivf->reset();
104+
ntotal = 0;
105+
}
106+
107+
/// Train every component of the index on n vectors of dimension d.
///
/// Steps, in order:
///  1. train the independent quantizer on the raw vectors;
///  2. train the VectorTransform if there is one and it is not yet trained;
///  3. reconstruct the quantizer centroids, transform them and add them to
///     index_ivf's own quantizer;
///  4. train index_ivf's encoder on (optionally subsampled) transformed
///     vectors. When index_ivf->by_residual is set, the encoder is trained
///     on residuals with respect to index_ivf's quantizer, using list
///     assignments that come from the independent quantizer.
///
/// @param n  number of training vectors
/// @param x  training vectors, size n * d
void IndexIVFIndependentQuantizer::train(idx_t n, const float* x) {
    // quantizer training
    size_t nlist = index_ivf->nlist;
    Level1Quantizer l1(quantizer, nlist);
    l1.train_q1(n, x, verbose, metric_type);

    // train the VectorTransform
    if (vt && !vt->is_trained) {
        if (verbose) {
            printf("IndexIVFIndependentQuantizer: train the VectorTransform\n");
        }
        vt->train(n, x);
    }

    // get the centroids from the quantizer, transform them and
    // add them to the index_ivf's quantizer
    if (verbose) {
        printf("IndexIVFIndependentQuantizer: extract the main quantizer centroids\n");
    }
    std::vector<float> centroids(nlist * d);
    quantizer->reconstruct_n(0, nlist, centroids.data());
    VTransformedVectors tcent(vt, nlist, centroids.data());

    if (verbose) {
        printf("IndexIVFIndependentQuantizer: add centroids to the secondary quantizer\n");
    }
    if (!index_ivf->quantizer->is_trained) {
        index_ivf->quantizer->train(nlist, tcent.x);
    }
    index_ivf->quantizer->add(nlist, tcent.x);

    // train the payload

    // optional subsampling
    idx_t max_nt = index_ivf->train_encoder_num_vectors();
    if (max_nt <= 0) {
        // no limit requested by the encoder: use a bound large enough that
        // subsampling is effectively disabled
        max_nt = (size_t)1 << 35;
    }
    // x holds vectors of the input dimension d: sv.x is fed to the
    // transform and to the independent quantizer below. Subsampling must
    // therefore use d, not index_ivf->d (the two differ whenever the
    // transform changes the dimensionality).
    SubsampledVectors sv(d, &n, max_nt, x);

    // transform subsampled vectors
    VTransformedVectors tv(vt, n, sv.x);

    if (verbose) {
        printf("IndexIVFIndependentQuantizer: train encoder\n");
    }

    if (index_ivf->by_residual) {
        // assign with quantizer
        std::vector<idx_t> assign(n);
        quantizer->assign(n, sv.x, assign.data());

        // compute residual with IVF quantizer
        std::vector<float> residuals(n * index_ivf->d);
        index_ivf->quantizer->compute_residual_n(
                n, tv.x, residuals.data(), assign.data());

        index_ivf->train_encoder(n, residuals.data(), assign.data());
    } else {
        index_ivf->train_encoder(n, tv.x, nullptr);
    }
    index_ivf->is_trained = true;
    is_trained = true;
}
171+
172+
} // namespace faiss

faiss/IndexIVFIndependentQuantizer.h

+56
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
/**
2+
* Copyright (c) Facebook, Inc. and its affiliates.
3+
*
4+
* This source code is licensed under the MIT license found in the
5+
* LICENSE file in the root directory of this source tree.
6+
*/
7+
8+
#pragma once
9+
10+
#include <faiss/IndexIVF.h>
11+
#include <faiss/VectorTransform.h>
12+
13+
namespace faiss {
14+
15+
/** An IVF index with a quantizer that has a different input dimension from the
16+
* payload size. The vectors to encode are obtained from the input vectors by a
17+
* VectorTransform.
18+
*/
19+
struct IndexIVFIndependentQuantizer : Index {
20+
/// quantizer is fed directly with the input vectors
21+
Index* quantizer = nullptr;
22+
23+
/// transform before the IVF vectors are applied
24+
VectorTransform* vt = nullptr;
25+
26+
/// the IVF index, controls nlist and nprobe
27+
IndexIVF* index_ivf = nullptr;
28+
29+
/// whether *this owns the 3 fields
30+
bool own_fields = false;
31+
32+
IndexIVFIndependentQuantizer(
33+
Index* quantizer,
34+
IndexIVF* index_ivf,
35+
VectorTransform* vt = nullptr);
36+
37+
IndexIVFIndependentQuantizer() {}
38+
39+
void train(idx_t n, const float* x) override;
40+
41+
void add(idx_t n, const float* x) override;
42+
43+
void search(
44+
idx_t n,
45+
const float* x,
46+
idx_t k,
47+
float* distances,
48+
idx_t* labels,
49+
const SearchParameters* params = nullptr) const override;
50+
51+
void reset() override;
52+
53+
~IndexIVFIndependentQuantizer() override;
54+
};
55+
56+
} // namespace faiss

faiss/IndexPreTransform.cpp

+10-19
Original file line numberDiff line numberDiff line change
@@ -141,9 +141,8 @@ void IndexPreTransform::reverse_chain(idx_t n, const float* xt, float* x)
141141

142142
void IndexPreTransform::add(idx_t n, const float* x) {
143143
FAISS_THROW_IF_NOT(is_trained);
144-
const float* xt = apply_chain(n, x);
145-
ScopeDeleter<float> del(xt == x ? nullptr : xt);
146-
index->add(n, xt);
144+
TransformedVectors tv(x, apply_chain(n, x));
145+
index->add(n, tv.x);
147146
ntotal = index->ntotal;
148147
}
149148

@@ -152,9 +151,8 @@ void IndexPreTransform::add_with_ids(
152151
const float* x,
153152
const idx_t* xids) {
154153
FAISS_THROW_IF_NOT(is_trained);
155-
const float* xt = apply_chain(n, x);
156-
ScopeDeleter<float> del(xt == x ? nullptr : xt);
157-
index->add_with_ids(n, xt, xids);
154+
TransformedVectors tv(x, apply_chain(n, x));
155+
index->add_with_ids(n, tv.x, xids);
158156
ntotal = index->ntotal;
159157
}
160158

@@ -190,10 +188,9 @@ void IndexPreTransform::range_search(
190188
RangeSearchResult* result,
191189
const SearchParameters* params) const {
192190
FAISS_THROW_IF_NOT(is_trained);
193-
const float* xt = apply_chain(n, x);
194-
ScopeDeleter<float> del(xt == x ? nullptr : xt);
191+
TransformedVectors tv(x, apply_chain(n, x));
195192
index->range_search(
196-
n, xt, radius, result, extract_index_search_params(params));
193+
n, tv.x, radius, result, extract_index_search_params(params));
197194
}
198195

199196
void IndexPreTransform::reset() {
@@ -238,14 +235,13 @@ void IndexPreTransform::search_and_reconstruct(
238235
FAISS_THROW_IF_NOT(k > 0);
239236
FAISS_THROW_IF_NOT(is_trained);
240237

241-
const float* xt = apply_chain(n, x);
242-
ScopeDeleter<float> del((xt == x) ? nullptr : xt);
238+
TransformedVectors trans(x, apply_chain(n, x));
243239

244240
float* recons_temp = chain.empty() ? recons : new float[n * k * index->d];
245241
ScopeDeleter<float> del2((recons_temp == recons) ? nullptr : recons_temp);
246242
index->search_and_reconstruct(
247243
n,
248-
xt,
244+
trans.x,
249245
k,
250246
distances,
251247
labels,
@@ -262,13 +258,8 @@ size_t IndexPreTransform::sa_code_size() const {
262258

263259
void IndexPreTransform::sa_encode(idx_t n, const float* x, uint8_t* bytes)
264260
const {
265-
if (chain.empty()) {
266-
index->sa_encode(n, x, bytes);
267-
} else {
268-
const float* xt = apply_chain(n, x);
269-
ScopeDeleter<float> del(xt == x ? nullptr : xt);
270-
index->sa_encode(n, xt, bytes);
271-
}
261+
TransformedVectors tv(x, apply_chain(n, x));
262+
index->sa_encode(n, tv.x, bytes);
272263
}
273264

274265
void IndexPreTransform::sa_decode(idx_t n, const uint8_t* bytes, float* x)

0 commit comments

Comments
 (0)