Skip to content

Commit c8323a0

Browse files
Add SIMD NEON Optimization for QT_FP16 in Scalar Quantizer
Signed-off-by: Naveen Tatikonda <navtat@amazon.com>
1 parent 5b6c4b4 commit c8323a0

File tree

4 files changed

+236
-37
lines changed

4 files changed

+236
-37
lines changed

faiss/CMakeLists.txt

+1
Original file line numberDiff line numberDiff line change
@@ -189,6 +189,7 @@ set(FAISS_HEADERS
189189
utils/extra_distances.h
190190
utils/fp16-fp16c.h
191191
utils/fp16-inl.h
192+
utils/fp16-fp16.h
192193
utils/fp16.h
193194
utils/hamming-inl.h
194195
utils/hamming.h

faiss/impl/ScalarQuantizer.cpp

+204-37
Original file line numberDiff line numberDiff line change
@@ -397,6 +397,20 @@ struct QuantizerFP16<8> : QuantizerFP16<1> {
397397

398398
#endif
399399

400+
#ifdef __aarch64__
401+
402+
template <>
403+
struct QuantizerFP16<8> : QuantizerFP16<1> {
404+
QuantizerFP16(size_t d, const std::vector<float>& trained)
405+
: QuantizerFP16<1>(d, trained) {}
406+
407+
FAISS_ALWAYS_INLINE float32x4x2_t reconstruct_8_components(const uint8_t* code, int i) const {
408+
uint16x4x2_t codei = vld2_u16((const uint16_t*)(code + 2 * i));
409+
return vzipq_f32(vcvt_f32_f16(vreinterpret_f16_u16(codei.val[0])), vcvt_f32_f16(vreinterpret_f16_u16(codei.val[1])));
410+
}
411+
};
412+
#endif
413+
400414
/*******************************************************************
401415
* 8bit_direct quantizer
402416
*******************************************************************/
@@ -446,31 +460,32 @@ struct Quantizer8bitDirect<8> : Quantizer8bitDirect<1> {
446460

447461
#endif
448462

449-
template <int SIMDWIDTH>
463+
template <int SIMDWIDTH, int SIMDWIDTH_DEFAULT>
450464
ScalarQuantizer::SQuantizer* select_quantizer_1(
451465
QuantizerType qtype,
452466
size_t d,
453467
const std::vector<float>& trained) {
454468
switch (qtype) {
455469
case ScalarQuantizer::QT_8bit:
456-
return new QuantizerTemplate<Codec8bit, false, SIMDWIDTH>(
470+
return new QuantizerTemplate<Codec8bit, false, SIMDWIDTH_DEFAULT>(
457471
d, trained);
458472
case ScalarQuantizer::QT_6bit:
459-
return new QuantizerTemplate<Codec6bit, false, SIMDWIDTH>(
473+
return new QuantizerTemplate<Codec6bit, false, SIMDWIDTH_DEFAULT>(
460474
d, trained);
461475
case ScalarQuantizer::QT_4bit:
462-
return new QuantizerTemplate<Codec4bit, false, SIMDWIDTH>(
476+
return new QuantizerTemplate<Codec4bit, false, SIMDWIDTH_DEFAULT>(
463477
d, trained);
464478
case ScalarQuantizer::QT_8bit_uniform:
465-
return new QuantizerTemplate<Codec8bit, true, SIMDWIDTH>(
479+
return new QuantizerTemplate<Codec8bit, true, SIMDWIDTH_DEFAULT>(
466480
d, trained);
467481
case ScalarQuantizer::QT_4bit_uniform:
468-
return new QuantizerTemplate<Codec4bit, true, SIMDWIDTH>(
482+
return new QuantizerTemplate<Codec4bit, true, SIMDWIDTH_DEFAULT>(
469483
d, trained);
470484
case ScalarQuantizer::QT_fp16:
471485
return new QuantizerFP16<SIMDWIDTH>(d, trained);
472486
case ScalarQuantizer::QT_8bit_direct:
473-
return new Quantizer8bitDirect<SIMDWIDTH>(d, trained);
487+
return new Quantizer8bitDirect<SIMDWIDTH_DEFAULT>(d, trained);
488+
474489
}
475490
FAISS_THROW_MSG("unknown qtype");
476491
}
@@ -728,6 +743,57 @@ struct SimilarityL2<8> {
728743

729744
#endif
730745

746+
#ifdef __aarch64__
747+
template <>
748+
struct SimilarityL2<8> {
749+
static constexpr int simdwidth = 8;
750+
static constexpr MetricType metric_type = METRIC_L2;
751+
752+
const float *y, *yi;
753+
explicit SimilarityL2(const float* y) : y(y) {}
754+
float32x4x2_t accu8;
755+
756+
FAISS_ALWAYS_INLINE void begin_8() {
757+
accu8 = vzipq_f32(vdupq_n_f32(0.0f),vdupq_n_f32(0.0f));
758+
yi = y;
759+
}
760+
761+
FAISS_ALWAYS_INLINE void add_8_components(float32x4x2_t x) {
762+
float32x4x2_t yiv = vld1q_f32_x2(yi);
763+
yi += 8;
764+
765+
float32x4_t sub0 = vsubq_f32(yiv.val[0], x.val[0]);
766+
float32x4_t sub1 = vsubq_f32(yiv.val[1], x.val[1]);
767+
768+
float32x4_t accu8_0 = vaddq_f32(accu8.val[0], vmulq_f32(sub0,sub0));
769+
float32x4_t accu8_1 = vaddq_f32(accu8.val[1], vmulq_f32(sub1,sub1));
770+
771+
float32x4x2_t accu8_temp = vzipq_f32(accu8_0, accu8_1);
772+
accu8 = vuzpq_f32(accu8_temp.val[0],accu8_temp.val[1]);
773+
}
774+
775+
FAISS_ALWAYS_INLINE void add_8_components_2(float32x4x2_t x, float32x4x2_t y) {
776+
float32x4_t sub0 = vsubq_f32(y.val[0], x.val[0]);
777+
float32x4_t sub1 = vsubq_f32(y.val[1], x.val[1]);
778+
779+
float32x4_t accu8_0 = vaddq_f32(accu8.val[0], vmulq_f32(sub0,sub0));
780+
float32x4_t accu8_1 = vaddq_f32(accu8.val[1], vmulq_f32(sub1,sub1));
781+
782+
float32x4x2_t accu8_temp = vzipq_f32(accu8_0, accu8_1);
783+
accu8 = vuzpq_f32(accu8_temp.val[0],accu8_temp.val[1]);
784+
}
785+
786+
FAISS_ALWAYS_INLINE float result_8() {
787+
float32x4_t sum_0 = vpaddq_f32(accu8.val[0], accu8.val[0]);
788+
float32x4_t sum_1 = vpaddq_f32(accu8.val[1], accu8.val[1]);
789+
790+
float32x4_t sum2_0 = vpaddq_f32(sum_0, sum_0);
791+
float32x4_t sum2_1 = vpaddq_f32(sum_1, sum_1);
792+
return vgetq_lane_f32(sum2_0, 0) + vgetq_lane_f32(sum2_1, 0);
793+
}
794+
};
795+
#endif
796+
731797
template <int SIMDWIDTH>
732798
struct SimilarityIP {};
733799

@@ -801,6 +867,50 @@ struct SimilarityIP<8> {
801867
};
802868
#endif
803869

870+
#ifdef __aarch64__
871+
872+
template <>
873+
struct SimilarityIP<8> {
874+
static constexpr int simdwidth = 8;
875+
static constexpr MetricType metric_type = METRIC_INNER_PRODUCT;
876+
877+
const float *y, *yi;
878+
879+
explicit SimilarityIP(const float* y) : y(y) {}
880+
float32x4x2_t accu8;
881+
882+
FAISS_ALWAYS_INLINE void begin_8() {
883+
accu8 = vzipq_f32(vdupq_n_f32(0.0f),vdupq_n_f32(0.0f));
884+
yi = y;
885+
}
886+
887+
FAISS_ALWAYS_INLINE void add_8_components(float32x4x2_t x) {
888+
float32x4x2_t yiv = vld1q_f32_x2(yi);
889+
yi += 8;
890+
891+
float32x4_t accu8_0 = vaddq_f32(accu8.val[0], vmulq_f32(yiv.val[0], x.val[0]));
892+
float32x4_t accu8_1 = vaddq_f32(accu8.val[1], vmulq_f32(yiv.val[1], x.val[1]));
893+
float32x4x2_t accu8_temp = vzipq_f32(accu8_0, accu8_1);
894+
accu8 = vuzpq_f32(accu8_temp.val[0],accu8_temp.val[1]);
895+
}
896+
897+
FAISS_ALWAYS_INLINE void add_8_components_2(float32x4x2_t x1, float32x4x2_t x2) {
898+
float32x4_t accu8_0 = vaddq_f32(accu8.val[0], vmulq_f32(x1.val[0], x2.val[0]));
899+
float32x4_t accu8_1 = vaddq_f32(accu8.val[1], vmulq_f32(x1.val[1], x2.val[1]));
900+
float32x4x2_t accu8_temp = vzipq_f32(accu8_0, accu8_1);
901+
accu8 = vuzpq_f32(accu8_temp.val[0],accu8_temp.val[1]);
902+
}
903+
904+
FAISS_ALWAYS_INLINE float result_8() {
905+
float32x4x2_t sum_tmp = vzipq_f32(vpaddq_f32(accu8.val[0], accu8.val[0]), vpaddq_f32(accu8.val[1], accu8.val[1]));
906+
float32x4x2_t sum = vuzpq_f32(sum_tmp.val[0], sum_tmp.val[1]);
907+
float32x4x2_t sum2_tmp = vzipq_f32(vpaddq_f32(sum.val[0], sum.val[0]), vpaddq_f32(sum.val[1], sum.val[1]));
908+
float32x4x2_t sum2 = vuzpq_f32(sum2_tmp.val[0], sum2_tmp.val[1]);
909+
return vgetq_lane_f32(sum2.val[0], 0) + vgetq_lane_f32(sum2.val[1], 0);
910+
}
911+
};
912+
#endif
913+
804914
/*******************************************************************
805915
* DistanceComputer: combines a similarity and a quantizer to do
806916
* code-to-vector or code-to-code comparisons
@@ -903,6 +1013,53 @@ struct DCTemplate<Quantizer, Similarity, 8> : SQDistanceComputer {
9031013

9041014
#endif
9051015

1016+
#ifdef __aarch64__
1017+
1018+
template <class Quantizer, class Similarity>
1019+
struct DCTemplate<Quantizer, Similarity, 8> : SQDistanceComputer {
1020+
using Sim = Similarity;
1021+
1022+
Quantizer quant;
1023+
1024+
DCTemplate(size_t d, const std::vector<float>& trained)
1025+
: quant(d, trained) {}
1026+
float compute_distance(const float* x, const uint8_t* code) const {
1027+
Similarity sim(x);
1028+
sim.begin_8();
1029+
for (size_t i = 0; i < quant.d; i += 8) {
1030+
float32x4x2_t xi = quant.reconstruct_8_components(code, i);
1031+
sim.add_8_components(xi);
1032+
}
1033+
return sim.result_8();
1034+
}
1035+
1036+
float compute_code_distance(const uint8_t* code1, const uint8_t* code2)
1037+
const {
1038+
Similarity sim(nullptr);
1039+
sim.begin_8();
1040+
for (size_t i = 0; i < quant.d; i += 8) {
1041+
float32x4x2_t x1 = quant.reconstruct_8_components(code1, i);
1042+
float32x4x2_t x2 = quant.reconstruct_8_components(code2, i);
1043+
sim.add_8_components_2(x1, x2);
1044+
}
1045+
return sim.result_8();
1046+
}
1047+
1048+
void set_query(const float* x) final {
1049+
q = x;
1050+
}
1051+
1052+
float symmetric_dis(idx_t i, idx_t j) override {
1053+
return compute_code_distance(
1054+
codes + i * code_size, codes + j * code_size);
1055+
}
1056+
1057+
float query_to_code(const uint8_t* code) const final {
1058+
return compute_distance(q, code);
1059+
}
1060+
};
1061+
#endif
1062+
9061063
/*******************************************************************
9071064
* DistanceComputerByte: computes distances in the integer domain
9081065
*******************************************************************/
@@ -1024,55 +1181,57 @@ struct DistanceComputerByte<Similarity, 8> : SQDistanceComputer {
10241181
* specialization
10251182
*******************************************************************/
10261183

1027-
template <class Sim>
1184+
template <class Sim, class Sim_default>
10281185
SQDistanceComputer* select_distance_computer(
10291186
QuantizerType qtype,
10301187
size_t d,
10311188
const std::vector<float>& trained) {
10321189
constexpr int SIMDWIDTH = Sim::simdwidth;
1190+
constexpr int SIMDWIDTH_DEFAULT = Sim_default::simdwidth;
10331191
switch (qtype) {
10341192
case ScalarQuantizer::QT_8bit_uniform:
10351193
return new DCTemplate<
1036-
QuantizerTemplate<Codec8bit, true, SIMDWIDTH>,
1037-
Sim,
1038-
SIMDWIDTH>(d, trained);
1194+
QuantizerTemplate<Codec8bit, true, SIMDWIDTH_DEFAULT>,
1195+
Sim_default,
1196+
SIMDWIDTH_DEFAULT>(d, trained);
10391197

10401198
case ScalarQuantizer::QT_4bit_uniform:
10411199
return new DCTemplate<
1042-
QuantizerTemplate<Codec4bit, true, SIMDWIDTH>,
1043-
Sim,
1044-
SIMDWIDTH>(d, trained);
1200+
QuantizerTemplate<Codec4bit, true, SIMDWIDTH_DEFAULT>,
1201+
Sim_default,
1202+
SIMDWIDTH_DEFAULT>(d, trained);
10451203

10461204
case ScalarQuantizer::QT_8bit:
10471205
return new DCTemplate<
1048-
QuantizerTemplate<Codec8bit, false, SIMDWIDTH>,
1049-
Sim,
1050-
SIMDWIDTH>(d, trained);
1206+
QuantizerTemplate<Codec8bit, false, SIMDWIDTH_DEFAULT>,
1207+
Sim_default,
1208+
SIMDWIDTH_DEFAULT>(d, trained);
10511209

10521210
case ScalarQuantizer::QT_6bit:
10531211
return new DCTemplate<
1054-
QuantizerTemplate<Codec6bit, false, SIMDWIDTH>,
1055-
Sim,
1056-
SIMDWIDTH>(d, trained);
1212+
QuantizerTemplate<Codec6bit, false, SIMDWIDTH_DEFAULT>,
1213+
Sim_default,
1214+
SIMDWIDTH_DEFAULT>(d, trained);
10571215

10581216
case ScalarQuantizer::QT_4bit:
10591217
return new DCTemplate<
1060-
QuantizerTemplate<Codec4bit, false, SIMDWIDTH>,
1061-
Sim,
1062-
SIMDWIDTH>(d, trained);
1218+
QuantizerTemplate<Codec4bit, false, SIMDWIDTH_DEFAULT>,
1219+
Sim_default,
1220+
SIMDWIDTH_DEFAULT>(d, trained);
1221+
10631222

10641223
case ScalarQuantizer::QT_fp16:
10651224
return new DCTemplate<QuantizerFP16<SIMDWIDTH>, Sim, SIMDWIDTH>(
10661225
d, trained);
10671226

10681227
case ScalarQuantizer::QT_8bit_direct:
10691228
if (d % 16 == 0) {
1070-
return new DistanceComputerByte<Sim, SIMDWIDTH>(d, trained);
1229+
return new DistanceComputerByte<Sim_default, SIMDWIDTH_DEFAULT>(d, trained);
10711230
} else {
10721231
return new DCTemplate<
1073-
Quantizer8bitDirect<SIMDWIDTH>,
1074-
Sim,
1075-
SIMDWIDTH>(d, trained);
1232+
Quantizer8bitDirect<SIMDWIDTH_DEFAULT>,
1233+
Sim_default,
1234+
SIMDWIDTH_DEFAULT>(d, trained);
10761235
}
10771236
}
10781237
FAISS_THROW_MSG("unknown qtype");
@@ -1155,13 +1314,15 @@ void ScalarQuantizer::train(size_t n, const float* x) {
11551314
}
11561315

11571316
ScalarQuantizer::SQuantizer* ScalarQuantizer::select_quantizer() const {
1158-
#ifdef USE_F16C
11591317
if (d % 8 == 0) {
1160-
return select_quantizer_1<8>(qtype, d, trained);
1161-
} else
1318+
#if defined(USE_F16C)
1319+
return select_quantizer_1<8,8>(qtype, d, trained);
1320+
#elif defined(__aarch64__)
1321+
return select_quantizer_1<8,1>(qtype, d, trained);
11621322
#endif
1323+
} else
11631324
{
1164-
return select_quantizer_1<1>(qtype, d, trained);
1325+
return select_quantizer_1<1,1>(qtype, d, trained);
11651326
}
11661327
}
11671328

@@ -1186,20 +1347,26 @@ void ScalarQuantizer::decode(const uint8_t* codes, float* x, size_t n) const {
11861347
SQDistanceComputer* ScalarQuantizer::get_distance_computer(
11871348
MetricType metric) const {
11881349
FAISS_THROW_IF_NOT(metric == METRIC_L2 || metric == METRIC_INNER_PRODUCT);
1189-
#ifdef USE_F16C
11901350
if (d % 8 == 0) {
11911351
if (metric == METRIC_L2) {
1192-
return select_distance_computer<SimilarityL2<8>>(qtype, d, trained);
1352+
#if defined(USE_F16C)
1353+
return select_distance_computer<SimilarityL2<8>,SimilarityL2<8>>(qtype, d, trained);
1354+
#elif defined(__aarch64__)
1355+
return select_distance_computer<SimilarityL2<8>,SimilarityL2<1>>(qtype, d, trained);
1356+
#endif
11931357
} else {
1194-
return select_distance_computer<SimilarityIP<8>>(qtype, d, trained);
1358+
#if defined(USE_F16C)
1359+
return select_distance_computer<SimilarityIP<8>,SimilarityIP<8>>(qtype, d, trained);
1360+
#elif defined(__aarch64__)
1361+
return select_distance_computer<SimilarityIP<8>,SimilarityIP<1>>(qtype, d, trained);
1362+
#endif
11951363
}
11961364
} else
1197-
#endif
11981365
{
11991366
if (metric == METRIC_L2) {
1200-
return select_distance_computer<SimilarityL2<1>>(qtype, d, trained);
1367+
return select_distance_computer<SimilarityL2<1>,SimilarityL2<1>>(qtype, d, trained);
12011368
} else {
1202-
return select_distance_computer<SimilarityIP<1>>(qtype, d, trained);
1369+
return select_distance_computer<SimilarityIP<1>,SimilarityIP<1>>(qtype, d, trained);
12031370
}
12041371
}
12051372
}

faiss/utils/fp16-fp16.h

+29
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
/**
2+
* Copyright (c) Facebook, Inc. and its affiliates.
3+
*
4+
* This source code is licensed under the MIT license found in the
5+
* LICENSE file in the root directory of this source tree.
6+
*/
7+
8+
#pragma once
9+
10+
#include <cstdint>
11+
#include <arm_neon.h>
12+
13+
namespace faiss {
14+
15+
inline uint16_t encode_fp16(float x) {
16+
float32x4_t fx4 = vdupq_n_f32(x);
17+
float16x4_t f16x4 = vcvt_f16_f32(fx4);
18+
uint16x4_t ui16x4 = vreinterpret_u16_f16(f16x4);
19+
return vduph_lane_u16(ui16x4, 3);
20+
}
21+
22+
inline float decode_fp16(uint16_t x) {
23+
uint16x4_t ui16x4 = vdup_n_u16(x);
24+
float16x4_t f16x4 = vreinterpret_f16_u16(ui16x4);
25+
float32x4_t fx4 = vcvt_f32_f16(f16x4);
26+
return vdups_laneq_f32(fx4, 3);
27+
}
28+
29+
}

0 commit comments

Comments
 (0)