Skip to content

Commit 81a74f4

Browse files
mdouzeketor
authored and committed
Add search functionality to FlatCodes (facebookresearch#3611)
Summary: Pull Request resolved: facebookresearch#3611 Using the new dispatcher functions, add search func to flat codes. To test it, make IndexLattice a subclass of FlatCodes and check the reconstruction there. Reviewed By: asadoughi Differential Revision: D59367989 fbshipit-source-id: 405dab4358fe34b2e38ac8bcc222b19f58643229
1 parent 7303cf6 commit 81a74f4

6 files changed

+203
-49
lines changed

contrib/inspect_tools.py

+6
Original file line numberDiff line numberDiff line change
@@ -98,6 +98,12 @@ def get_flat_data(index):
9898
return xb.reshape(index.ntotal, index.d)
9999

100100

101+
def get_flat_codes(index_flat):
    """Return the codes stored in an IndexFlatCodes as a 2D array.

    The flat code storage of the index is converted to an array and viewed
    as a matrix with index_flat.ntotal rows and index_flat.code_size columns.
    """
    flat_codes = faiss.vector_to_array(index_flat.codes)
    n_rows, n_cols = index_flat.ntotal, index_flat.code_size
    return flat_codes.reshape(n_rows, n_cols)
105+
106+
101107
def get_NSG_neighbors(nsg):
102108
""" get the neighbor list for the vectors stored in the NSG structure, as
103109
a N-by-K matrix of indices """

faiss/IndexFlatCodes.cpp

+159-5
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,8 @@
1212
#include <faiss/impl/DistanceComputer.h>
1313
#include <faiss/impl/FaissAssert.h>
1414
#include <faiss/impl/IDSelector.h>
15+
#include <faiss/impl/ResultHandler.h>
16+
#include <faiss/utils/extra_distances.h>
1517

1618
namespace faiss {
1719

@@ -70,11 +72,6 @@ void IndexFlatCodes::reconstruct(idx_t key, float* recons) const {
7072
reconstruct_n(key, 1, recons);
7173
}
7274

73-
FlatCodesDistanceComputer* IndexFlatCodes::get_FlatCodesDistanceComputer()
74-
const {
75-
FAISS_THROW_MSG("not implemented");
76-
}
77-
7875
void IndexFlatCodes::check_compatible_for_merge(const Index& otherIndex) const {
7976
// minimal sanity checks
8077
const IndexFlatCodes* other =
@@ -114,4 +111,161 @@ void IndexFlatCodes::permute_entries(const idx_t* perm) {
114111
std::swap(codes, new_codes);
115112
}
116113

114+
namespace {
115+
116+
template <class VD>
117+
struct GenericFlatCodesDistanceComputer : FlatCodesDistanceComputer {
118+
const IndexFlatCodes& codec;
119+
const VD vd;
120+
// temp buffers
121+
std::vector<uint8_t> code_buffer;
122+
std::vector<float> vec_buffer;
123+
const float* query = nullptr;
124+
125+
GenericFlatCodesDistanceComputer(const IndexFlatCodes* codec, const VD& vd)
126+
: FlatCodesDistanceComputer(codec->codes.data(), codec->code_size),
127+
codec(*codec),
128+
vd(vd),
129+
code_buffer(codec->code_size * 4),
130+
vec_buffer(codec->d * 4) {}
131+
132+
void set_query(const float* x) override {
133+
query = x;
134+
}
135+
136+
float operator()(idx_t i) override {
137+
codec.sa_decode(1, codes + i * code_size, vec_buffer.data());
138+
return vd(query, vec_buffer.data());
139+
}
140+
141+
float distance_to_code(const uint8_t* code) override {
142+
codec.sa_decode(1, code, vec_buffer.data());
143+
return vd(query, vec_buffer.data());
144+
}
145+
146+
float symmetric_dis(idx_t i, idx_t j) override {
147+
codec.sa_decode(1, codes + i * code_size, vec_buffer.data());
148+
codec.sa_decode(1, codes + j * code_size, vec_buffer.data() + vd.d);
149+
return vd(vec_buffer.data(), vec_buffer.data() + vd.d);
150+
}
151+
152+
void distances_batch_4(
153+
const idx_t idx0,
154+
const idx_t idx1,
155+
const idx_t idx2,
156+
const idx_t idx3,
157+
float& dis0,
158+
float& dis1,
159+
float& dis2,
160+
float& dis3) override {
161+
uint8_t* cp = code_buffer.data();
162+
for (idx_t i : {idx0, idx1, idx2, idx3}) {
163+
memcpy(cp, codes + i * code_size, code_size);
164+
cp += code_size;
165+
}
166+
// potential benefit is if batch decoding is more efficient than 1 by 1
167+
// decoding
168+
codec.sa_decode(4, code_buffer.data(), vec_buffer.data());
169+
dis0 = vd(query, vec_buffer.data());
170+
dis1 = vd(query, vec_buffer.data() + vd.d);
171+
dis2 = vd(query, vec_buffer.data() + 2 * vd.d);
172+
dis3 = vd(query, vec_buffer.data() + 3 * vd.d);
173+
}
174+
};
175+
176+
struct Run_get_distance_computer {
177+
using T = FlatCodesDistanceComputer*;
178+
179+
template <class VD>
180+
FlatCodesDistanceComputer* f(const VD& vd, const IndexFlatCodes* codec) {
181+
return new GenericFlatCodesDistanceComputer<VD>(codec, vd);
182+
}
183+
};
184+
185+
template <class BlockResultHandler>
186+
struct Run_search_with_decompress {
187+
using T = void;
188+
189+
template <class VectorDistance>
190+
void f(VectorDistance& vd,
191+
const IndexFlatCodes* index_ptr,
192+
const float* xq,
193+
BlockResultHandler& res) {
194+
// Note that there seems to be a clang (?) bug that "sometimes" passes
195+
// the const Index & parameters by value, so to be on the safe side,
196+
// it's better to use pointers.
197+
const IndexFlatCodes& index = *index_ptr;
198+
size_t ntotal = index.ntotal;
199+
using SingleResultHandler =
200+
typename BlockResultHandler::SingleResultHandler;
201+
using DC = GenericFlatCodesDistanceComputer<VectorDistance>;
202+
#pragma omp parallel // if (res.nq > 100)
203+
{
204+
std::unique_ptr<DC> dc(new DC(&index, vd));
205+
SingleResultHandler resi(res);
206+
#pragma omp for
207+
for (int64_t q = 0; q < res.nq; q++) {
208+
resi.begin(q);
209+
dc->set_query(xq + vd.d * q);
210+
for (size_t i = 0; i < ntotal; i++) {
211+
if (res.is_in_selection(i)) {
212+
float dis = (*dc)(i);
213+
resi.add_result(dis, i);
214+
}
215+
}
216+
resi.end();
217+
}
218+
}
219+
}
220+
};
221+
222+
struct Run_search_with_decompress_res {
223+
using T = void;
224+
225+
template <class ResultHandler>
226+
void f(ResultHandler& res, const IndexFlatCodes* index, const float* xq) {
227+
Run_search_with_decompress<ResultHandler> r;
228+
dispatch_VectorDistance(
229+
index->d,
230+
index->metric_type,
231+
index->metric_arg,
232+
r,
233+
index,
234+
xq,
235+
res);
236+
}
237+
};
238+
239+
} // anonymous namespace
240+
241+
/// Build a distance computer for this index by dispatching on its dimension,
/// metric type and metric argument.
FlatCodesDistanceComputer* IndexFlatCodes::get_FlatCodesDistanceComputer()
        const {
    Run_get_distance_computer factory;
    FlatCodesDistanceComputer* computer =
            dispatch_VectorDistance(d, metric_type, metric_arg, factory, this);
    return computer;
}
246+
247+
/// k-nearest-neighbor search: dispatches on the result handler type, then
/// runs the search with decompression on the n query vectors in x.
void IndexFlatCodes::search(
        idx_t n,
        const float* x,
        idx_t k,
        float* distances,
        idx_t* labels,
        const SearchParameters* params) const {
    // optional id selector coming from the search parameters
    const IDSelector* sel = nullptr;
    if (params) {
        sel = params->sel;
    }
    Run_search_with_decompress_res runner;
    dispatch_knn_ResultHandler(
            n, distances, labels, k, metric_type, sel, runner, this, x);
}
259+
260+
/// Range search: dispatches on the range result handler type, then runs the
/// search with decompression on the query vectors in x.
void IndexFlatCodes::range_search(
        idx_t n,
        const float* x,
        float radius,
        RangeSearchResult* result,
        const SearchParameters* params) const {
    // optional id selector coming from the search parameters
    const IDSelector* sel = nullptr;
    if (params) {
        sel = params->sel;
    }
    Run_search_with_decompress_res runner;
    dispatch_range_ResultHandler(
            result, radius, metric_type, sel, runner, this, x);
}
270+
117271
} // namespace faiss

faiss/IndexFlatCodes.h

+20-3
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,6 @@
55
* LICENSE file in the root directory of this source tree.
66
*/
77

8-
// -*- c++ -*-
9-
108
#pragma once
119

1210
#include <faiss/Index.h>
@@ -45,13 +43,32 @@ struct IndexFlatCodes : Index {
4543
* different from the usual ones: the new ids are shifted */
4644
size_t remove_ids(const IDSelector& sel) override;
4745

48-
/** a FlatCodesDistanceComputer offers a distance_to_code method */
46+
/** a FlatCodesDistanceComputer offers a distance_to_code method
47+
*
48+
* The default implementation explicitly decodes the vector with sa_decode.
49+
*/
4950
virtual FlatCodesDistanceComputer* get_FlatCodesDistanceComputer() const;
5051

5152
/// the generic distance computer is the FlatCodesDistanceComputer of the index
DistanceComputer* get_distance_computer() const override {
    FlatCodesDistanceComputer* computer = get_FlatCodesDistanceComputer();
    return computer;
}
5455

56+
/** Search implemented by decoding */
57+
void search(
58+
idx_t n,
59+
const float* x,
60+
idx_t k,
61+
float* distances,
62+
idx_t* labels,
63+
const SearchParameters* params = nullptr) const override;
64+
65+
void range_search(
66+
idx_t n,
67+
const float* x,
68+
float radius,
69+
RangeSearchResult* result,
70+
const SearchParameters* params = nullptr) const override;
71+
5572
// returns a new instance of a CodePacker
5673
CodePacker* get_CodePacker() const;
5774

faiss/IndexLattice.cpp

+1-19
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
namespace faiss {
1616

1717
IndexLattice::IndexLattice(idx_t d, int nsq, int scale_nbit, int r2)
18-
: Index(d),
18+
: IndexFlatCodes(0, d, METRIC_L2),
1919
nsq(nsq),
2020
dsq(d / nsq),
2121
zn_sphere_codec(dsq, r2),
@@ -114,22 +114,4 @@ void IndexLattice::sa_decode(idx_t n, const uint8_t* codes, float* x) const {
114114
}
115115
}
116116

117-
void IndexLattice::add(idx_t, const float*) {
118-
FAISS_THROW_MSG("not implemented");
119-
}
120-
121-
void IndexLattice::search(
122-
idx_t,
123-
const float*,
124-
idx_t,
125-
float*,
126-
idx_t*,
127-
const SearchParameters*) const {
128-
FAISS_THROW_MSG("not implemented");
129-
}
130-
131-
void IndexLattice::reset() {
132-
FAISS_THROW_MSG("not implemented");
133-
}
134-
135117
} // namespace faiss

faiss/IndexLattice.h

+3-22
Original file line numberDiff line numberDiff line change
@@ -5,21 +5,18 @@
55
* LICENSE file in the root directory of this source tree.
66
*/
77

8-
// -*- c++ -*-
9-
10-
#ifndef FAISS_INDEX_LATTICE_H
11-
#define FAISS_INDEX_LATTICE_H
8+
#pragma once
129

1310
#include <vector>
1411

15-
#include <faiss/IndexIVF.h>
12+
#include <faiss/IndexFlatCodes.h>
1613
#include <faiss/impl/lattice_Zn.h>
1714

1815
namespace faiss {
1916

2017
/** Index that encodes a vector with a series of Zn lattice quantizers
2118
*/
22-
struct IndexLattice : Index {
19+
struct IndexLattice : IndexFlatCodes {
2320
/// number of sub-vectors
2421
int nsq;
2522
/// dimension of sub-vectors
@@ -30,8 +27,6 @@ struct IndexLattice : Index {
3027

3128
/// nb bits used to encode the scale, per subvector
3229
int scale_nbit, lattice_nbit;
33-
/// total, in bytes
34-
size_t code_size;
3530

3631
/// mins and maxes of the vector norms, per subquantizer
3732
std::vector<float> trained;
@@ -46,20 +41,6 @@ struct IndexLattice : Index {
4641
void sa_encode(idx_t n, const float* x, uint8_t* bytes) const override;
4742

4843
void sa_decode(idx_t n, const uint8_t* bytes, float* x) const override;
49-
50-
/// not implemented
51-
void add(idx_t n, const float* x) override;
52-
void search(
53-
idx_t n,
54-
const float* x,
55-
idx_t k,
56-
float* distances,
57-
idx_t* labels,
58-
const SearchParameters* params = nullptr) const override;
59-
60-
void reset() override;
6144
};
6245

6346
} // namespace faiss
64-
65-
#endif

tests/test_standalone_codec.py

+14
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
from faiss.contrib.datasets import SyntheticDataset
1515
from faiss.contrib.inspect_tools import get_additive_quantizer_codebooks
1616

17+
1718
class TestEncodeDecode(unittest.TestCase):
1819

1920
def do_encode_twice(self, factory_key):
@@ -263,6 +264,19 @@ def test_ZnSphereCodecAlt32(self):
263264
def test_ZnSphereCodecAlt24(self):
264265
self.run_ZnSphereCodecAlt(24, 14)
265266

267+
def test_lattice_index(self):
    """Search on a ZnLattice3x10_4 index must report, for every result,
    the squared L2 distance between the query and the reconstruction of
    the returned vector (up to an absolute tolerance of 1e-4)."""
    d, nq, nb, k = 96, 10, 20, 5
    index = faiss.index_factory(d, "ZnLattice3x10_4")
    rs = np.random.RandomState(123)
    # queries are drawn before the database, from the same random stream
    xq = rs.randn(nq, d).astype('float32')
    xb = rs.randn(nb, d).astype('float32')
    index.train(xb)
    index.add(xb)
    D, I = index.search(xq, k)
    for query, dis_row, ids_row in zip(xq, D, I):
        recons = index.reconstruct_batch(ids_row)
        ref_dis = ((recons - query) ** 2).sum(1)
        np.testing.assert_allclose(dis_row, ref_dis, atol=1e-4)
279+
266280

267281
class TestBitstring(unittest.TestCase):
268282

0 commit comments

Comments
 (0)