Skip to content

Commit 0091af3

Browse files
Add Support for Other Quantizers in SQ
Signed-off-by: Naveen Tatikonda <navtat@amazon.com>
1 parent 6b13aba commit 0091af3

File tree

1 file changed

+201
-55
lines changed

1 file changed

+201
-55
lines changed

faiss/impl/ScalarQuantizer.cpp

+201-55
Original file line numberDiff line numberDiff line change
@@ -91,6 +91,20 @@ struct Codec8bit {
9191
return _mm256_fmadd_ps(f8, one_255, half_one_255);
9292
}
9393
#endif
94+
95+
#ifdef __aarch64__
96+
/// Decode 8 consecutive components of `code`, starting at component index i.
///
/// Each component is decoded individually with decode_component() and staged
/// in a small float buffer. Components i..i+3 are returned in val[0] and
/// components i+4..i+7 in val[1].
static FAISS_ALWAYS_INLINE float32x4x2_t
decode_8_components(const uint8_t* code, int i) {
    float32_t staged[8] = {};
    for (size_t j = 0; j < 8; j++) {
        staged[j] = decode_component(code, i + j);
    }
    // No lane shuffling is required here: interleaving the two halves
    // (vzipq_f32) and then de-interleaving them (vuzpq_f32) yields the very
    // same two vectors, so the halves are returned as loaded.
    float32x4x2_t out;
    out.val[0] = vld1q_f32(staged);
    out.val[1] = vld1q_f32(staged + 4);
    return out;
}
107+
#endif
94108
};
95109

96110
struct Codec4bit {
@@ -129,6 +143,20 @@ struct Codec4bit {
129143
return _mm256_mul_ps(f8, one_255);
130144
}
131145
#endif
146+
147+
#ifdef __aarch64__
148+
/// Decode 8 consecutive components of `code`, starting at component index i.
///
/// Each component is decoded individually with decode_component() and staged
/// in a small float buffer. Components i..i+3 are returned in val[0] and
/// components i+4..i+7 in val[1].
static FAISS_ALWAYS_INLINE float32x4x2_t
decode_8_components(const uint8_t* code, int i) {
    float32_t staged[8] = {};
    for (size_t j = 0; j < 8; j++) {
        staged[j] = decode_component(code, i + j);
    }
    // No lane shuffling is required here: interleaving the two halves
    // (vzipq_f32) and then de-interleaving them (vuzpq_f32) yields the very
    // same two vectors, so the halves are returned as loaded.
    float32x4x2_t out;
    out.val[0] = vld1q_f32(staged);
    out.val[1] = vld1q_f32(staged + 4);
    return out;
}
159+
#endif
132160
};
133161

134162
struct Codec6bit {
@@ -228,6 +256,20 @@ struct Codec6bit {
228256
}
229257

230258
#endif
259+
260+
#ifdef __aarch64__
261+
/// Decode 8 consecutive components of `code`, starting at component index i.
///
/// Each component is decoded individually with decode_component() and staged
/// in a small float buffer. Components i..i+3 are returned in val[0] and
/// components i+4..i+7 in val[1].
static FAISS_ALWAYS_INLINE float32x4x2_t
decode_8_components(const uint8_t* code, int i) {
    float32_t staged[8] = {};
    for (size_t j = 0; j < 8; j++) {
        staged[j] = decode_component(code, i + j);
    }
    // No lane shuffling is required here: interleaving the two halves
    // (vzipq_f32) and then de-interleaving them (vuzpq_f32) yields the very
    // same two vectors, so the halves are returned as loaded.
    float32x4x2_t out;
    out.val[0] = vld1q_f32(staged);
    out.val[1] = vld1q_f32(staged + 4);
    return out;
}
272+
#endif
231273
};
232274

233275
/*******************************************************************
@@ -293,6 +335,31 @@ struct QuantizerTemplate<Codec, true, 8> : QuantizerTemplate<Codec, true, 1> {
293335

294336
#endif
295337

338+
#ifdef __aarch64__
339+
340+
template <class Codec>
341+
struct QuantizerTemplate<Codec, true, 8> : QuantizerTemplate<Codec, true, 1> {
342+
QuantizerTemplate(size_t d, const std::vector<float>& trained)
343+
: QuantizerTemplate<Codec, true, 1>(d, trained) {}
344+
345+
FAISS_ALWAYS_INLINE float32x4x2_t
346+
reconstruct_8_components(const uint8_t* code, int i) const {
347+
float32x4x2_t xi = Codec::decode_8_components(code, i);
348+
float32x4x2_t res = vzipq_f32(
349+
vfmaq_f32(
350+
vdupq_n_f32(this->vmin),
351+
xi.val[0],
352+
vdupq_n_f32(this->vdiff)),
353+
vfmaq_f32(
354+
vdupq_n_f32(this->vmin),
355+
xi.val[1],
356+
vdupq_n_f32(this->vdiff)));
357+
return vuzpq_f32(res.val[0], res.val[1]);
358+
}
359+
};
360+
361+
#endif
362+
296363
template <class Codec>
297364
struct QuantizerTemplate<Codec, false, 1> : ScalarQuantizer::SQuantizer {
298365
const size_t d;
@@ -350,6 +417,29 @@ struct QuantizerTemplate<Codec, false, 8> : QuantizerTemplate<Codec, false, 1> {
350417

351418
#endif
352419

420+
#ifdef __aarch64__
421+
422+
template <class Codec>
423+
struct QuantizerTemplate<Codec, false, 8> : QuantizerTemplate<Codec, false, 1> {
424+
QuantizerTemplate(size_t d, const std::vector<float>& trained)
425+
: QuantizerTemplate<Codec, false, 1>(d, trained) {}
426+
427+
FAISS_ALWAYS_INLINE float32x4x2_t
428+
reconstruct_8_components(const uint8_t* code, int i) const {
429+
float32x4x2_t xi = Codec::decode_8_components(code, i);
430+
431+
float32x4x2_t vmin_8 = vld1q_f32_x2(this->vmin + i);
432+
float32x4x2_t vdiff_8 = vld1q_f32_x2(this->vdiff + i);
433+
434+
float32x4x2_t res = vzipq_f32(
435+
vfmaq_f32(vmin_8.val[0], xi.val[0], vdiff_8.val[0]),
436+
vfmaq_f32(vmin_8.val[1], xi.val[1], vdiff_8.val[1]));
437+
return vuzpq_f32(res.val[0], res.val[1]);
438+
}
439+
};
440+
441+
#endif
442+
353443
/*******************************************************************
354444
* FP16 quantizer
355445
*******************************************************************/
@@ -463,31 +553,53 @@ struct Quantizer8bitDirect<8> : Quantizer8bitDirect<1> {
463553

464554
#endif
465555

466-
template <int SIMDWIDTH, int SIMDWIDTH_DEFAULT>
556+
#ifdef __aarch64__
557+
558+
template <>
559+
struct Quantizer8bitDirect<8> : Quantizer8bitDirect<1> {
560+
Quantizer8bitDirect(size_t d, const std::vector<float>& trained)
561+
: Quantizer8bitDirect<1>(d, trained) {}
562+
563+
FAISS_ALWAYS_INLINE float32x4x2_t
564+
reconstruct_8_components(const uint8_t* code, int i) const {
565+
float32_t result[8] = {};
566+
for (size_t j = 0; j < 8; j++) {
567+
result[j] = code[i + j];
568+
}
569+
float32x4_t res1 = vld1q_f32(result);
570+
float32x4_t res2 = vld1q_f32(result + 4);
571+
float32x4x2_t res = vzipq_f32(res1, res2);
572+
return vuzpq_f32(res.val[0], res.val[1]);
573+
}
574+
};
575+
576+
#endif
577+
578+
template <int SIMDWIDTH>
467579
ScalarQuantizer::SQuantizer* select_quantizer_1(
468580
QuantizerType qtype,
469581
size_t d,
470582
const std::vector<float>& trained) {
471583
switch (qtype) {
472584
case ScalarQuantizer::QT_8bit:
473-
return new QuantizerTemplate<Codec8bit, false, SIMDWIDTH_DEFAULT>(
585+
return new QuantizerTemplate<Codec8bit, false, SIMDWIDTH>(
474586
d, trained);
475587
case ScalarQuantizer::QT_6bit:
476-
return new QuantizerTemplate<Codec6bit, false, SIMDWIDTH_DEFAULT>(
588+
return new QuantizerTemplate<Codec6bit, false, SIMDWIDTH>(
477589
d, trained);
478590
case ScalarQuantizer::QT_4bit:
479-
return new QuantizerTemplate<Codec4bit, false, SIMDWIDTH_DEFAULT>(
591+
return new QuantizerTemplate<Codec4bit, false, SIMDWIDTH>(
480592
d, trained);
481593
case ScalarQuantizer::QT_8bit_uniform:
482-
return new QuantizerTemplate<Codec8bit, true, SIMDWIDTH_DEFAULT>(
594+
return new QuantizerTemplate<Codec8bit, true, SIMDWIDTH>(
483595
d, trained);
484596
case ScalarQuantizer::QT_4bit_uniform:
485-
return new QuantizerTemplate<Codec4bit, true, SIMDWIDTH_DEFAULT>(
597+
return new QuantizerTemplate<Codec4bit, true, SIMDWIDTH>(
486598
d, trained);
487599
case ScalarQuantizer::QT_fp16:
488600
return new QuantizerFP16<SIMDWIDTH>(d, trained);
489601
case ScalarQuantizer::QT_8bit_direct:
490-
return new Quantizer8bitDirect<SIMDWIDTH_DEFAULT>(d, trained);
602+
return new Quantizer8bitDirect<SIMDWIDTH>(d, trained);
491603
}
492604
FAISS_THROW_MSG("unknown qtype");
493605
}
@@ -1186,62 +1298,108 @@ struct DistanceComputerByte<Similarity, 8> : SQDistanceComputer {
11861298

11871299
#endif
11881300

1301+
#ifdef __aarch64__
1302+
1303+
template <class Similarity>
1304+
struct DistanceComputerByte<Similarity, 8> : SQDistanceComputer {
1305+
using Sim = Similarity;
1306+
1307+
int d;
1308+
std::vector<uint8_t> tmp;
1309+
1310+
DistanceComputerByte(int d, const std::vector<float>&) : d(d), tmp(d) {}
1311+
1312+
int compute_code_distance(const uint8_t* code1, const uint8_t* code2)
1313+
const {
1314+
int accu = 0;
1315+
for (int i = 0; i < d; i++) {
1316+
if (Sim::metric_type == METRIC_INNER_PRODUCT) {
1317+
accu += int(code1[i]) * code2[i];
1318+
} else {
1319+
int diff = int(code1[i]) - code2[i];
1320+
accu += diff * diff;
1321+
}
1322+
}
1323+
return accu;
1324+
}
1325+
1326+
void set_query(const float* x) final {
1327+
for (int i = 0; i < d; i++) {
1328+
tmp[i] = int(x[i]);
1329+
}
1330+
}
1331+
1332+
int compute_distance(const float* x, const uint8_t* code) {
1333+
set_query(x);
1334+
return compute_code_distance(tmp.data(), code);
1335+
}
1336+
1337+
float symmetric_dis(idx_t i, idx_t j) override {
1338+
return compute_code_distance(
1339+
codes + i * code_size, codes + j * code_size);
1340+
}
1341+
1342+
float query_to_code(const uint8_t* code) const final {
1343+
return compute_code_distance(tmp.data(), code);
1344+
}
1345+
};
1346+
1347+
#endif
1348+
11891349
/*******************************************************************
11901350
* select_distance_computer: runtime selection of template
11911351
* specialization
11921352
*******************************************************************/
11931353

1194-
template <class Sim, class Sim_default>
1354+
template <class Sim>
11951355
SQDistanceComputer* select_distance_computer(
11961356
QuantizerType qtype,
11971357
size_t d,
11981358
const std::vector<float>& trained) {
11991359
constexpr int SIMDWIDTH = Sim::simdwidth;
1200-
constexpr int SIMDWIDTH_DEFAULT = Sim_default::simdwidth;
12011360
switch (qtype) {
12021361
case ScalarQuantizer::QT_8bit_uniform:
12031362
return new DCTemplate<
1204-
QuantizerTemplate<Codec8bit, true, SIMDWIDTH_DEFAULT>,
1205-
Sim_default,
1206-
SIMDWIDTH_DEFAULT>(d, trained);
1363+
QuantizerTemplate<Codec8bit, true, SIMDWIDTH>,
1364+
Sim,
1365+
SIMDWIDTH>(d, trained);
12071366

12081367
case ScalarQuantizer::QT_4bit_uniform:
12091368
return new DCTemplate<
1210-
QuantizerTemplate<Codec4bit, true, SIMDWIDTH_DEFAULT>,
1211-
Sim_default,
1212-
SIMDWIDTH_DEFAULT>(d, trained);
1369+
QuantizerTemplate<Codec4bit, true, SIMDWIDTH>,
1370+
Sim,
1371+
SIMDWIDTH>(d, trained);
12131372

12141373
case ScalarQuantizer::QT_8bit:
12151374
return new DCTemplate<
1216-
QuantizerTemplate<Codec8bit, false, SIMDWIDTH_DEFAULT>,
1217-
Sim_default,
1218-
SIMDWIDTH_DEFAULT>(d, trained);
1375+
QuantizerTemplate<Codec8bit, false, SIMDWIDTH>,
1376+
Sim,
1377+
SIMDWIDTH>(d, trained);
12191378

12201379
case ScalarQuantizer::QT_6bit:
12211380
return new DCTemplate<
1222-
QuantizerTemplate<Codec6bit, false, SIMDWIDTH_DEFAULT>,
1223-
Sim_default,
1224-
SIMDWIDTH_DEFAULT>(d, trained);
1381+
QuantizerTemplate<Codec6bit, false, SIMDWIDTH>,
1382+
Sim,
1383+
SIMDWIDTH>(d, trained);
12251384

12261385
case ScalarQuantizer::QT_4bit:
12271386
return new DCTemplate<
1228-
QuantizerTemplate<Codec4bit, false, SIMDWIDTH_DEFAULT>,
1229-
Sim_default,
1230-
SIMDWIDTH_DEFAULT>(d, trained);
1387+
QuantizerTemplate<Codec4bit, false, SIMDWIDTH>,
1388+
Sim,
1389+
SIMDWIDTH>(d, trained);
12311390

12321391
case ScalarQuantizer::QT_fp16:
12331392
return new DCTemplate<QuantizerFP16<SIMDWIDTH>, Sim, SIMDWIDTH>(
12341393
d, trained);
12351394

12361395
case ScalarQuantizer::QT_8bit_direct:
12371396
if (d % 16 == 0) {
1238-
return new DistanceComputerByte<Sim_default, SIMDWIDTH_DEFAULT>(
1239-
d, trained);
1397+
return new DistanceComputerByte<Sim, SIMDWIDTH>(d, trained);
12401398
} else {
12411399
return new DCTemplate<
1242-
Quantizer8bitDirect<SIMDWIDTH_DEFAULT>,
1243-
Sim_default,
1244-
SIMDWIDTH_DEFAULT>(d, trained);
1400+
Quantizer8bitDirect<SIMDWIDTH>,
1401+
Sim,
1402+
SIMDWIDTH>(d, trained);
12451403
}
12461404
}
12471405
FAISS_THROW_MSG("unknown qtype");
@@ -1324,14 +1482,13 @@ void ScalarQuantizer::train(size_t n, const float* x) {
13241482
}
13251483

13261484
ScalarQuantizer::SQuantizer* ScalarQuantizer::select_quantizer() const {
1485+
#if defined(USE_F16C) || defined(__aarch64__)
13271486
if (d % 8 == 0) {
1328-
#if defined(USE_F16C)
1329-
return select_quantizer_1<8, 8>(qtype, d, trained);
1330-
#elif defined(__aarch64__)
1331-
return select_quantizer_1<8, 1>(qtype, d, trained);
1487+
return select_quantizer_1<8>(qtype, d, trained);
1488+
} else
13321489
#endif
1333-
} else {
1334-
return select_quantizer_1<1, 1>(qtype, d, trained);
1490+
{
1491+
return select_quantizer_1<1>(qtype, d, trained);
13351492
}
13361493
}
13371494

@@ -1356,31 +1513,20 @@ void ScalarQuantizer::decode(const uint8_t* codes, float* x, size_t n) const {
13561513
SQDistanceComputer* ScalarQuantizer::get_distance_computer(
13571514
MetricType metric) const {
13581515
FAISS_THROW_IF_NOT(metric == METRIC_L2 || metric == METRIC_INNER_PRODUCT);
1516+
#if defined(USE_F16C) || defined(__aarch64__)
13591517
if (d % 8 == 0) {
13601518
if (metric == METRIC_L2) {
1361-
#if defined(USE_F16C)
1362-
return select_distance_computer<SimilarityL2<8>, SimilarityL2<8>>(
1363-
qtype, d, trained);
1364-
#elif defined(__aarch64__)
1365-
return select_distance_computer<SimilarityL2<8>, SimilarityL2<1>>(
1366-
qtype, d, trained);
1367-
#endif
1519+
return select_distance_computer<SimilarityL2<8>>(qtype, d, trained);
13681520
} else {
1369-
#if defined(USE_F16C)
1370-
return select_distance_computer<SimilarityIP<8>, SimilarityIP<8>>(
1371-
qtype, d, trained);
1372-
#elif defined(__aarch64__)
1373-
return select_distance_computer<SimilarityIP<8>, SimilarityIP<1>>(
1374-
qtype, d, trained);
1375-
#endif
1521+
return select_distance_computer<SimilarityIP<8>>(qtype, d, trained);
13761522
}
1377-
} else {
1523+
} else
1524+
#endif
1525+
{
13781526
if (metric == METRIC_L2) {
1379-
return select_distance_computer<SimilarityL2<1>, SimilarityL2<1>>(
1380-
qtype, d, trained);
1527+
return select_distance_computer<SimilarityL2<1>>(qtype, d, trained);
13811528
} else {
1382-
return select_distance_computer<SimilarityIP<1>, SimilarityIP<1>>(
1383-
qtype, d, trained);
1529+
return select_distance_computer<SimilarityIP<1>>(qtype, d, trained);
13841530
}
13851531
}
13861532
}
@@ -1703,7 +1849,7 @@ InvertedListScanner* ScalarQuantizer::select_InvertedListScanner(
17031849
bool store_pairs,
17041850
const IDSelector* sel,
17051851
bool by_residual) const {
1706-
#ifdef USE_F16C
1852+
#if defined(USE_F16C) || defined(__aarch64__)
17071853
if (d % 8 == 0) {
17081854
return sel0_InvertedListScanner<8>(
17091855
mt, this, quantizer, store_pairs, sel, by_residual);

0 commit comments

Comments
 (0)