
Commit dd72e41

mdouze authored and facebook-github-bot committed
QINCo implementation in CPU Faiss (facebookresearch#3608)
Summary:
Pull Request resolved: facebookresearch#3608

This is a straightforward implementation of QINCo in CPU Faiss, with encoding and decoding capabilities (not training). For this, we translate a simplified version of some torch classes:

- tensors, restricted to 2D and int32 + float32
- Linear and Embedding layers

Then the QINCoStep and QINCo can just be defined as C++ objects that are copy-constructible. There is some plumbing required in the wrapping layers to support the integration: Pytorch tensors are converted to numpy for getting / setting them in C++.

Reviewed By: asadoughi

Differential Revision: D59132952

fbshipit-source-id: eea4856507a5b7c5f219efcf8d19fe56944df088
1 parent ab109c2 commit dd72e41

11 files changed (+1212, −1 lines)
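
As a minimal sketch of the translated classes in action (assuming a Faiss build that includes this commit, so the Python wrappers added below are installed), a 2D float32 numpy array can be round-tripped through the C++ tensor class:

import numpy as np
import faiss

# round trip a 2D float32 array through the C++ nn::Tensor2D:
# the wrapper copies data in on construction and back out via .numpy()
x = np.arange(12, dtype='float32').reshape(3, 4)
t = faiss.Tensor2D(x)
assert np.array_equal(t.numpy(), x)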

demos/demo_qinco.py

+77
@@ -0,0 +1,77 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

"""
This demonstrates how to reproduce the QINCo paper results using the Faiss
QINCo implementation. The code loads the reference model because training
is not implemented in Faiss.

Prepare the data with

cd /tmp

# get the reference qinco code
git clone https://github.com/facebookresearch/Qinco.git

# get the data
wget https://dl.fbaipublicfiles.com/QINCo/datasets/bigann/bigann1M.bvecs

# get the model
wget https://dl.fbaipublicfiles.com/QINCo/models/bigann_8x8_L2.pt

"""

import numpy as np
from faiss.contrib.vecs_io import bvecs_mmap
import sys
import time
import torch
import faiss

# make sure pickle deserialization will work
sys.path.append("/tmp/Qinco")
import model_qinco

with torch.no_grad():

    qinco = torch.load("/tmp/bigann_8x8_L2.pt")
    qinco.eval()
    # print(qinco)
    if True:
        torch.set_num_threads(1)
        faiss.omp_set_num_threads(1)

    x_base = bvecs_mmap("/tmp/bigann1M.bvecs")[:1000].astype('float32')
    x_scaled = torch.from_numpy(x_base) / qinco.db_scale

    t0 = time.time()
    codes, _ = qinco.encode(x_scaled)
    x_decoded_scaled = qinco.decode(codes)
    print(f"Pytorch encode {time.time() - t0:.3f} s")
    # multi-thread: 1.13s, single-thread: 7.744s

    x_decoded = x_decoded_scaled.numpy() * qinco.db_scale

    err = ((x_decoded - x_base) ** 2).sum(1).mean()
    print("MSE=", err)  # = 14211.956, near the L=2 result in Fig 4 of the paper

    qinco2 = faiss.QINCo(qinco)
    t0 = time.time()
    codes2 = qinco2.encode(faiss.Tensor2D(x_scaled))
    x_decoded2 = qinco2.decode(codes2).numpy() * qinco.db_scale
    print(f"Faiss encode {time.time() - t0:.3f} s")
    # multi-thread: 3.2s, single-thread: 7.019s

    # these tests don't pass because there are outlier encodings
    # np.testing.assert_array_equal(codes.numpy(), codes2.numpy())
    # np.testing.assert_allclose(x_decoded, x_decoded2)

    ndiff = (codes.numpy() != codes2.numpy()).sum() / codes.numel()
    assert ndiff < 0.01
    ndiff = (((x_decoded - x_decoded2) ** 2).sum(1) > 1e-5).sum()
    assert ndiff / len(x_base) < 0.01

    err = ((x_decoded2 - x_base) ** 2).sum(1).mean()
    print("MSE=", err)  # = 14213.551

faiss/CMakeLists.txt

+2
@@ -46,6 +46,7 @@ set(FAISS_SRC
   IndexScalarQuantizer.cpp
   IndexShards.cpp
   IndexShardsIVF.cpp
+  IndexNeuralNetCodec.cpp
   MatrixStats.cpp
   MetaIndexes.cpp
   VectorTransform.cpp
@@ -81,6 +82,7 @@ set(FAISS_SRC
   invlists/InvertedLists.cpp
   invlists/InvertedListsIOHook.cpp
   utils/Heap.cpp
+  utils/NeuralNet.cpp
   utils/WorkerThread.cpp
   utils/distances.cpp
   utils/distances_simd.cpp

faiss/IndexNeuralNetCodec.cpp

+56
@@ -0,0 +1,56 @@
/**
 * Copyright (c) Facebook, Inc. and its affiliates.
 *
 * This source code is licensed under the MIT license found in the
 * LICENSE file in the root directory of this source tree.
 */

#include <faiss/IndexNeuralNetCodec.h>
#include <faiss/impl/FaissAssert.h>
#include <faiss/utils/hamming.h>

namespace faiss {

/*********************************************************
 * IndexNeuralNetCodec implementation
 *********************************************************/

IndexNeuralNetCodec::IndexNeuralNetCodec(
        int d,
        int M,
        int nbits,
        MetricType metric)
        : IndexFlatCodes((M * nbits + 7) / 8, d, metric), M(M), nbits(nbits) {
    is_trained = false;
}

void IndexNeuralNetCodec::train(idx_t n, const float* x) {
    FAISS_THROW_MSG("Training not implemented in C++, use Pytorch");
}

void IndexNeuralNetCodec::sa_encode(idx_t n, const float* x, uint8_t* codes)
        const {
    nn::Tensor2D x_tensor(n, d, x);
    nn::Int32Tensor2D codes_tensor = net->encode(x_tensor);
    pack_bitstrings(n, M, nbits, codes_tensor.data(), codes, code_size);
}

void IndexNeuralNetCodec::sa_decode(idx_t n, const uint8_t* codes, float* x)
        const {
    nn::Int32Tensor2D codes_tensor(n, M);
    unpack_bitstrings(n, M, nbits, codes, code_size, codes_tensor.data());
    nn::Tensor2D x_tensor = net->decode(codes_tensor);
    memcpy(x, x_tensor.data(), d * n * sizeof(float));
}

/*********************************************************
 * IndexQINCo implementation
 *********************************************************/

IndexQINCo::IndexQINCo(int d, int M, int nbits, int L, int h, MetricType metric)
        : IndexNeuralNetCodec(d, M, nbits, metric),
          qinco(d, 1 << nbits, L, M, h) {
    net = &qinco;
}

} // namespace faiss
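
sa_encode packs each vector's M sub-codes of nbits bits into code_size bytes with pack_bitstrings, and sa_decode reverses this with unpack_bitstrings. A quick Python sketch of the constructor's code-size arithmetic:

# rounded-up byte count, as in the IndexNeuralNetCodec constructor
for M, nbits in [(8, 8), (8, 10), (16, 12)]:
    code_size = (M * nbits + 7) // 8
    print(f"M={M}, nbits={nbits}: {code_size} bytes per vector")
# -> 8, 10 and 24 bytes respectively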

faiss/IndexNeuralNetCodec.h

+49
@@ -0,0 +1,49 @@
/**
 * Copyright (c) Facebook, Inc. and its affiliates.
 *
 * This source code is licensed under the MIT license found in the
 * LICENSE file in the root directory of this source tree.
 */

#pragma once

#include <vector>

#include <faiss/IndexFlatCodes.h>
#include <faiss/utils/NeuralNet.h>

namespace faiss {

struct IndexNeuralNetCodec : IndexFlatCodes {
    NeuralNetCodec* net = nullptr;
    size_t M, nbits;

    explicit IndexNeuralNetCodec(
            int d = 0,
            int M = 0,
            int nbits = 0,
            MetricType metric = METRIC_L2);

    void train(idx_t n, const float* x) override;

    void sa_encode(idx_t n, const float* x, uint8_t* codes) const override;
    void sa_decode(idx_t n, const uint8_t* codes, float* x) const override;

    ~IndexNeuralNetCodec() {}
};

struct IndexQINCo : IndexNeuralNetCodec {
    QINCo qinco;

    IndexQINCo(
            int d,
            int M,
            int nbits,
            int L,
            int h,
            MetricType metric = METRIC_L2);

    ~IndexQINCo() {}
};

} // namespace faiss
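
A hedged construction sketch, assuming the SWIG layer exposes IndexQINCo to Python (parameter values mirror the bigann_8x8_L2 model used in the demo):

import faiss

# d=128, M=8 sub-quantizers of nbits=8, L=2 residual blocks, hidden dim h=256
index = faiss.IndexQINCo(128, 8, 8, 2, 256)
assert index.code_size == (8 * 8 + 7) // 8  # 8 bytes per vector
assert not index.is_trained  # train() throws: training stays in Pytorch
# weights would be copied from a trained torch model, e.g. (assumption):
# index.qinco.from_torch(torch_qinco)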

faiss/impl/ResultHandler.h

+1
@@ -16,6 +16,7 @@
 #include <faiss/impl/IDSelector.h>
 #include <faiss/utils/Heap.h>
 #include <faiss/utils/partitioning.h>
+
 #include <algorithm>
 #include <iostream>

faiss/python/__init__.py

+8
@@ -44,6 +44,14 @@
 class_wrappers.handle_IDSelectorSubset(IDSelectorBitmap, class_owns=False, force_int64=False)
 class_wrappers.handle_CodeSet(CodeSet)
 
+class_wrappers.handle_Tensor2D(Tensor2D)
+class_wrappers.handle_Tensor2D(Int32Tensor2D)
+class_wrappers.handle_Embedding(Embedding)
+class_wrappers.handle_Linear(Linear)
+class_wrappers.handle_QINCo(QINCo)
+class_wrappers.handle_QINCoStep(QINCoStep)
+
+
 this_module = sys.modules[__name__]
 
 # handle sub-classes
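
With these handlers installed at import time, the C++ layers accept torch modules directly; a small sketch (assuming a build that includes this commit):

import torch
import faiss

# the replacement constructors take a torch module and copy its weights
# through numpy, as described in the commit summary
emb = faiss.Embedding(torch.nn.Embedding(256, 32))
lin = faiss.Linear(torch.nn.Linear(32, 64))  # bias is copied too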

faiss/python/class_wrappers.py

+149
@@ -1247,3 +1247,152 @@ def replacement_insert(self, codes, inserted=None):
         return inserted
 
     replace_method(the_class, 'insert', replacement_insert)
+
+
+######################################################
+# Syntactic sugar for NeuralNet classes
+######################################################
+
+def handle_Tensor2D(the_class):
+    the_class.original_init = the_class.__init__
+
+    def replacement_init(self, *args):
+        if len(args) == 1:
+            array, = args
+            n, d = array.shape
+            self.original_init(n, d)
+            faiss.copy_array_to_vector(
+                np.ascontiguousarray(array).ravel(), self.v)
+        else:
+            self.original_init(*args)
+
+    def numpy(self):
+        shape = np.zeros(2, dtype=np.int64)
+        faiss.memcpy(faiss.swig_ptr(shape), self.shape, shape.nbytes)
+        return faiss.vector_to_array(self.v).reshape(shape[0], shape[1])
+
+    the_class.__init__ = replacement_init
+    the_class.numpy = numpy
+
+
+def handle_Embedding(the_class):
+    the_class.original_init = the_class.__init__
+
+    def replacement_init(self, *args):
+        if len(args) != 1 or args[0].__class__ == the_class:
+            self.original_init(*args)
+            return
+        # assume it's a torch.Embedding
+        emb = args[0]
+        self.original_init(emb.num_embeddings, emb.embedding_dim)
+        self.from_torch(emb)
+
+    def from_torch(self, emb):
+        """ copy weights from torch.Embedding """
+        assert emb.weight.shape == (self.num_embeddings, self.embedding_dim)
+        faiss.copy_array_to_vector(
+            np.ascontiguousarray(emb.weight.data).ravel(), self.weight)
+
+    def from_array(self, array):
+        """ copy weights from numpy array """
+        assert array.shape == (self.num_embeddings, self.embedding_dim)
+        faiss.copy_array_to_vector(
+            np.ascontiguousarray(array).ravel(), self.weight)
+
+    the_class.from_array = from_array
+    the_class.from_torch = from_torch
+    the_class.__init__ = replacement_init
+
+
+def handle_Linear(the_class):
+    the_class.original_init = the_class.__init__
+
+    def replacement_init(self, *args):
+        if len(args) != 1 or args[0].__class__ == the_class:
+            self.original_init(*args)
+            return
+        # assume it's a torch.Linear
+        linear = args[0]
+        bias = linear.bias is not None
+        self.original_init(linear.in_features, linear.out_features, bias)
+        self.from_torch(linear)
+
+    def from_torch(self, linear):
+        """ copy weights from torch.Linear """
+        assert linear.weight.shape == (self.out_features, self.in_features)
+        faiss.copy_array_to_vector(
+            linear.weight.data.numpy().ravel(), self.weight)
+        if linear.bias is not None:
+            assert linear.bias.shape == (self.out_features,)
+            faiss.copy_array_to_vector(linear.bias.data.numpy(), self.bias)
+
+    def from_array(self, array, bias=None):
+        """ copy weights from numpy array """
+        assert array.shape == (self.out_features, self.in_features)
+        faiss.copy_array_to_vector(
+            np.ascontiguousarray(array).ravel(), self.weight)
+        if bias is not None:
+            assert bias.shape == (self.out_features,)
+            faiss.copy_array_to_vector(bias, self.bias)
+
+    the_class.__init__ = replacement_init
+    the_class.from_array = from_array
+    the_class.from_torch = from_torch
+
+
+######################################################
+# Syntactic sugar for QINCo and QINCoStep
+######################################################
+
+def handle_QINCoStep(the_class):
+    the_class.original_init = the_class.__init__
+
+    def replacement_init(self, *args):
+        if len(args) != 1 or args[0].__class__ == the_class:
+            self.original_init(*args)
+            return
+        step = args[0]
+        # assume it's a Torch QINCoStep
+        self.original_init(step.d, step.K, step.L, step.h)
+        self.from_torch(step)
+
+    def from_torch(self, step):
+        """ copy weights from torch.QINCoStep """
+        assert (step.d, step.K, step.L, step.h) == (self.d, self.K, self.L, self.h)
+        self.codebook.from_torch(step.codebook)
+        self.MLPconcat.from_torch(step.MLPconcat)
+
+        for l in range(step.L):
+            src = step.residual_blocks[l]
+            dest = self.get_residual_block(l)
+            dest.linear1.from_torch(src[0])
+            dest.linear2.from_torch(src[2])
+
+    the_class.__init__ = replacement_init
+    the_class.from_torch = from_torch
+
+
+def handle_QINCo(the_class):
+    the_class.original_init = the_class.__init__
+
+    def replacement_init(self, *args):
+        if len(args) != 1 or args[0].__class__ == the_class:
+            self.original_init(*args)
+            return
+
+        # assume it's a Torch QINCo
+        qinco = args[0]
+        self.original_init(qinco.d, qinco.K, qinco.L, qinco.M, qinco.h)
+        self.from_torch(qinco)
+
+    def from_torch(self, qinco):
+        """ copy weights from torch.QINCo """
+        assert (
+            (qinco.d, qinco.K, qinco.L, qinco.M, qinco.h) ==
+            (self.d, self.K, self.L, self.M, self.h)
+        )
+        self.codebook0.from_torch(qinco.codebook0)
+        for m in range(qinco.M - 1):
+            self.get_step(m).from_torch(qinco.steps[m])
+
+    the_class.__init__ = replacement_init
+    the_class.from_torch = from_torch
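
The from_array variants load weights without torch; a usage sketch under the same assumptions (Linear's plain constructor is taken to be (in_features, out_features, bias), matching the replacement_init call above):

import numpy as np
import faiss

rs = np.random.RandomState(123)
lin = faiss.Linear(32, 64, True)
# weight is (out_features, in_features), as asserted in from_array
lin.from_array(
    rs.rand(64, 32).astype('float32'),
    rs.rand(64).astype('float32'),
)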
