Skip to content

Commit c3ebd0c

Browse files
mdouze authored and facebook-github-bot committed Jul 5, 2024
Add search functionality to FlatCodes (facebookresearch#3611)
Summary: Pull Request resolved: facebookresearch#3611. Using the new dispatcher functions, add a search function to flat codes. To test it, make IndexLattice a subclass of FlatCodes and check the reconstruction there. Differential Revision: D59367989
1 parent 92594d1 commit c3ebd0c

6 files changed

+205
-50
lines changed
 

‎contrib/inspect_tools.py

+6
Original file line numberDiff line numberDiff line change
@@ -98,6 +98,12 @@ def get_flat_data(index):
9898
return xb.reshape(index.ntotal, index.d)
9999

100100

101+
def get_flat_codes(index_flat):
    """Return the codes stored in an IndexFlatCodes as a 2D array.

    The index's flat ``codes`` vector is converted with
    ``faiss.vector_to_array`` and reshaped to ``index_flat.ntotal`` rows of
    ``index_flat.code_size`` entries each.
    """
    flat = faiss.vector_to_array(index_flat.codes)
    return flat.reshape(index_flat.ntotal, index_flat.code_size)
105+
106+
101107
def get_NSG_neighbors(nsg):
102108
""" get the neighbor list for the vectors stored in the NSG structure, as
103109
a N-by-K matrix of indices """

‎faiss/IndexFlatCodes.cpp

+161-5
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,8 @@
1212
#include <faiss/impl/DistanceComputer.h>
1313
#include <faiss/impl/FaissAssert.h>
1414
#include <faiss/impl/IDSelector.h>
15+
#include <faiss/impl/ResultHandler.h>
16+
#include <faiss/utils/extra_distances.h>
1517

1618
namespace faiss {
1719

@@ -70,11 +72,6 @@ void IndexFlatCodes::reconstruct(idx_t key, float* recons) const {
7072
reconstruct_n(key, 1, recons);
7173
}
7274

73-
FlatCodesDistanceComputer* IndexFlatCodes::get_FlatCodesDistanceComputer()
74-
const {
75-
FAISS_THROW_MSG("not implemented");
76-
}
77-
7875
void IndexFlatCodes::check_compatible_for_merge(const Index& otherIndex) const {
7976
// minimal sanity checks
8077
const IndexFlatCodes* other =
@@ -114,4 +111,163 @@ void IndexFlatCodes::permute_entries(const idx_t* perm) {
114111
std::swap(codes, new_codes);
115112
}
116113

114+
namespace {
115+
116+
template <class VD>
117+
struct GenericFlatCodesDistanceComputer : FlatCodesDistanceComputer {
118+
const IndexFlatCodes& codec;
119+
const VD vd;
120+
// temp buffers
121+
std::vector<uint8_t> code_buffer;
122+
std::vector<float> vec_buffer;
123+
const float* query = nullptr;
124+
125+
GenericFlatCodesDistanceComputer(const IndexFlatCodes* codec, const VD& vd)
126+
: FlatCodesDistanceComputer(codec->codes.data(), codec->code_size),
127+
codec(*codec),
128+
vd(vd),
129+
code_buffer(codec->code_size * 4),
130+
vec_buffer(codec->d * 4) {}
131+
132+
void set_query(const float* x) override {
133+
query = x;
134+
}
135+
136+
float operator()(idx_t i) override {
137+
codec.sa_decode(1, codes + i * code_size, vec_buffer.data());
138+
return vd(query, vec_buffer.data());
139+
}
140+
141+
float distance_to_code(const uint8_t* code) override {
142+
codec.sa_decode(1, code, vec_buffer.data());
143+
return vd(query, vec_buffer.data());
144+
}
145+
146+
float symmetric_dis(idx_t i, idx_t j) override {
147+
codec.sa_decode(1, codes + i * code_size, vec_buffer.data());
148+
codec.sa_decode(1, codes + j * code_size, vec_buffer.data() + vd.d);
149+
return vd(vec_buffer.data(), vec_buffer.data() + vd.d);
150+
}
151+
152+
void distances_batch_4(
153+
const idx_t idx0,
154+
const idx_t idx1,
155+
const idx_t idx2,
156+
const idx_t idx3,
157+
float& dis0,
158+
float& dis1,
159+
float& dis2,
160+
float& dis3) {
161+
uint8_t* cp = code_buffer.data();
162+
for (idx_t i : {idx0, idx1, idx2, idx3}) {
163+
memcpy(cp, codes + i * code_size, code_size);
164+
cp += code_size;
165+
}
166+
// potential benefit is if batch decoding is more efficient than 1 by 1
167+
// decoding
168+
codec.sa_decode(4, code_buffer.data(), vec_buffer.data());
169+
dis0 = vd(query, vec_buffer.data());
170+
dis1 = vd(query, vec_buffer.data() + vd.d);
171+
dis2 = vd(query, vec_buffer.data() + 2 * vd.d);
172+
dis3 = vd(query, vec_buffer.data() + 3 * vd.d);
173+
}
174+
};
175+
176+
struct Run_get_distance_computer {
177+
using T = FlatCodesDistanceComputer*;
178+
179+
template <class VD>
180+
FlatCodesDistanceComputer* f(const VD& vd, const IndexFlatCodes* codec) {
181+
return new GenericFlatCodesDistanceComputer<VD>(codec, vd);
182+
}
183+
};
184+
185+
template <class BlockResultHandler>
186+
struct Run_search_with_decompress {
187+
using T = void;
188+
189+
template <class VectorDistance>
190+
void f(VectorDistance& vd,
191+
const IndexFlatCodes* index_ptr,
192+
const float* xq,
193+
BlockResultHandler& res) {
194+
// Note that there seems to be a clang (?) bug that "sometimes" passes
195+
// the const Index & parameters by value, so to be on the safe side,
196+
// it's better to use pointers.
197+
const IndexFlatCodes& index = *index_ptr;
198+
const uint8_t* codes = index.codes.data();
199+
size_t ntotal = index.ntotal;
200+
size_t code_size = index.code_size;
201+
using SingleResultHandler =
202+
typename BlockResultHandler::SingleResultHandler;
203+
using DC = GenericFlatCodesDistanceComputer<VectorDistance>;
204+
#pragma omp parallel // if (res.nq > 100)
205+
{
206+
std::unique_ptr<DC> dc(new DC(&index, vd));
207+
SingleResultHandler resi(res);
208+
#pragma omp for
209+
for (int64_t q = 0; q < res.nq; q++) {
210+
resi.begin(q);
211+
dc->set_query(xq + vd.d * q);
212+
for (size_t i = 0; i < ntotal; i++) {
213+
if (res.is_in_selection(i)) {
214+
float dis = (*dc)(i);
215+
resi.add_result(dis, i);
216+
}
217+
}
218+
resi.end();
219+
}
220+
}
221+
}
222+
};
223+
224+
struct Run_search_with_decompress_res {
225+
using T = void;
226+
227+
template <class ResultHandler>
228+
void f(ResultHandler& res, const IndexFlatCodes* index, const float* xq) {
229+
Run_search_with_decompress<ResultHandler> r;
230+
dispatch_VectorDistance(
231+
index->d,
232+
index->metric_type,
233+
index->metric_arg,
234+
r,
235+
index,
236+
xq,
237+
res);
238+
}
239+
};
240+
241+
} // anonymous namespace
242+
243+
/// Default distance computer: selects the vector distance matching this
/// index's d / metric_type / metric_arg and returns a computer that decodes
/// the stored codes before computing distances.
FlatCodesDistanceComputer* IndexFlatCodes::get_FlatCodesDistanceComputer()
        const {
    Run_get_distance_computer make_computer;
    FlatCodesDistanceComputer* computer = dispatch_VectorDistance(
            d, metric_type, metric_arg, make_computer, this);
    return computer;
}
248+
249+
/// k-nearest-neighbor search implemented by decoding the stored codes.
/// The optional ID selector is taken from `params` when it is provided.
void IndexFlatCodes::search(
        idx_t n,
        const float* x,
        idx_t k,
        float* distances,
        idx_t* labels,
        const SearchParameters* params) const {
    const IDSelector* sel = nullptr;
    if (params) {
        sel = params->sel;
    }
    Run_search_with_decompress_res run_search;
    dispatch_knn_ResultHandler(
            n, distances, labels, k, metric_type, sel, run_search, this, x);
}
261+
262+
/// Range search implemented by decoding the stored codes.
/// The optional ID selector is taken from `params` when it is provided.
void IndexFlatCodes::range_search(
        idx_t n,
        const float* x,
        float radius,
        RangeSearchResult* result,
        const SearchParameters* params) const {
    const IDSelector* sel = nullptr;
    if (params) {
        sel = params->sel;
    }
    Run_search_with_decompress_res run_search;
    dispatch_range_ResultHandler(
            result, radius, metric_type, sel, run_search, this, x);
}
272+
117273
} // namespace faiss

‎faiss/IndexFlatCodes.h

+20-3
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,6 @@
55
* LICENSE file in the root directory of this source tree.
66
*/
77

8-
// -*- c++ -*-
9-
108
#pragma once
119

1210
#include <faiss/Index.h>
@@ -45,13 +43,32 @@ struct IndexFlatCodes : Index {
4543
* different from the usual ones: the new ids are shifted */
4644
size_t remove_ids(const IDSelector& sel) override;
4745

48-
/** a FlatCodesDistanceComputer offers a distance_to_code method */
46+
/** a FlatCodesDistanceComputer offers a distance_to_code method
47+
*
48+
* The default implementation explicitly decodes the vector with sa_decode.
49+
*/
4950
virtual FlatCodesDistanceComputer* get_FlatCodesDistanceComputer() const;
5051

5152
DistanceComputer* get_distance_computer() const override {
5253
return get_FlatCodesDistanceComputer();
5354
}
5455

56+
/** Search implemented by decoding */
57+
void search(
58+
idx_t n,
59+
const float* x,
60+
idx_t k,
61+
float* distances,
62+
idx_t* labels,
63+
const SearchParameters* params = nullptr) const override;
64+
65+
void range_search(
66+
idx_t n,
67+
const float* x,
68+
float radius,
69+
RangeSearchResult* result,
70+
const SearchParameters* params = nullptr) const override;
71+
5572
// returns a new instance of a CodePacker
5673
CodePacker* get_CodePacker() const;
5774

‎faiss/IndexLattice.cpp

+1-19
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
namespace faiss {
1616

1717
IndexLattice::IndexLattice(idx_t d, int nsq, int scale_nbit, int r2)
18-
: Index(d),
18+
: IndexFlatCodes(0, d, METRIC_L2),
1919
nsq(nsq),
2020
dsq(d / nsq),
2121
zn_sphere_codec(dsq, r2),
@@ -114,22 +114,4 @@ void IndexLattice::sa_decode(idx_t n, const uint8_t* codes, float* x) const {
114114
}
115115
}
116116

117-
void IndexLattice::add(idx_t, const float*) {
118-
FAISS_THROW_MSG("not implemented");
119-
}
120-
121-
void IndexLattice::search(
122-
idx_t,
123-
const float*,
124-
idx_t,
125-
float*,
126-
idx_t*,
127-
const SearchParameters*) const {
128-
FAISS_THROW_MSG("not implemented");
129-
}
130-
131-
void IndexLattice::reset() {
132-
FAISS_THROW_MSG("not implemented");
133-
}
134-
135117
} // namespace faiss

‎faiss/IndexLattice.h

+3-22
Original file line numberDiff line numberDiff line change
@@ -5,21 +5,18 @@
55
* LICENSE file in the root directory of this source tree.
66
*/
77

8-
// -*- c++ -*-
9-
10-
#ifndef FAISS_INDEX_LATTICE_H
11-
#define FAISS_INDEX_LATTICE_H
8+
#pragma once
129

1310
#include <vector>
1411

15-
#include <faiss/IndexIVF.h>
12+
#include <faiss/IndexFlatCodes.h>
1613
#include <faiss/impl/lattice_Zn.h>
1714

1815
namespace faiss {
1916

2017
/** Index that encodes a vector with a series of Zn lattice quantizers
2118
*/
22-
struct IndexLattice : Index {
19+
struct IndexLattice : IndexFlatCodes {
2320
/// number of sub-vectors
2421
int nsq;
2522
/// dimension of sub-vectors
@@ -30,8 +27,6 @@ struct IndexLattice : Index {
3027

3128
/// nb bits used to encode the scale, per subvector
3229
int scale_nbit, lattice_nbit;
33-
/// total, in bytes
34-
size_t code_size;
3530

3631
/// mins and maxes of the vector norms, per subquantizer
3732
std::vector<float> trained;
@@ -46,20 +41,6 @@ struct IndexLattice : Index {
4641
void sa_encode(idx_t n, const float* x, uint8_t* bytes) const override;
4742

4843
void sa_decode(idx_t n, const uint8_t* bytes, float* x) const override;
49-
50-
/// not implemented
51-
void add(idx_t n, const float* x) override;
52-
void search(
53-
idx_t n,
54-
const float* x,
55-
idx_t k,
56-
float* distances,
57-
idx_t* labels,
58-
const SearchParameters* params = nullptr) const override;
59-
60-
void reset() override;
6144
};
6245

6346
} // namespace faiss
64-
65-
#endif

‎tests/test_standalone_codec.py

+14-1
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212

1313
from common_faiss_tests import get_dataset_2
1414
from faiss.contrib.datasets import SyntheticDataset
15-
from faiss.contrib.inspect_tools import get_additive_quantizer_codebooks
15+
from faiss.contrib.inspect_tools import get_additive_quantizer_codebooks, get_flat_codes
1616

1717
class TestEncodeDecode(unittest.TestCase):
1818

@@ -263,6 +263,19 @@ def test_ZnSphereCodecAlt32(self):
263263
def test_ZnSphereCodecAlt24(self):
264264
self.run_ZnSphereCodecAlt(24, 14)
265265

266+
def test_lattice_index(self):
    """Distances returned by searching a lattice index must match the
    squared L2 distances recomputed from the reconstructed results."""
    d, nq, nb, k = 96, 10, 20, 5
    index = faiss.index_factory(d, "ZnLattice3x10_4")
    rng = np.random.RandomState(123)
    # queries are drawn before the database vectors: keep this order so the
    # random stream is consumed exactly as before
    xq = rng.randn(nq, d).astype('float32')
    xb = rng.randn(nb, d).astype('float32')
    index.train(xb)
    index.add(xb)
    D, I = index.search(xq, k)
    for query, dis_row, id_row in zip(xq, D, I):
        recons = index.reconstruct_batch(id_row)
        ref_dis = ((recons - query) ** 2).sum(1)
        np.testing.assert_allclose(dis_row, ref_dis, atol=5e-5)
278+
266279

267280
class TestBitstring(unittest.TestCase):
268281

0 commit comments

Comments
 (0)