Skip to content

Commit 933de3e

Browse files
Add SIMD NEON Optimization for QT_FP16 in Scalar Quantizer
Signed-off-by: Naveen Tatikonda <navtat@amazon.com>
1 parent 5b6c4b4 commit 933de3e

File tree

4 files changed

+257
-40
lines changed

4 files changed

+257
-40
lines changed

faiss/CMakeLists.txt

+1
Original file line numberDiff line numberDiff line change
@@ -189,6 +189,7 @@ set(FAISS_HEADERS
189189
utils/extra_distances.h
190190
utils/fp16-fp16c.h
191191
utils/fp16-inl.h
192+
utils/fp16-fp16.h
192193
utils/fp16.h
193194
utils/hamming-inl.h
194195
utils/hamming.h

faiss/impl/ScalarQuantizer.cpp

+225-40
Original file line numberDiff line numberDiff line change
@@ -397,6 +397,23 @@ struct QuantizerFP16<8> : QuantizerFP16<1> {
397397

398398
#endif
399399

400+
#ifdef __aarch64__
401+
402+
template <>
403+
struct QuantizerFP16<8> : QuantizerFP16<1> {
404+
QuantizerFP16(size_t d, const std::vector<float>& trained)
405+
: QuantizerFP16<1>(d, trained) {}
406+
407+
FAISS_ALWAYS_INLINE float32x4x2_t
408+
reconstruct_8_components(const uint8_t* code, int i) const {
409+
uint16x4x2_t codei = vld2_u16((const uint16_t*)(code + 2 * i));
410+
return vzipq_f32(
411+
vcvt_f32_f16(vreinterpret_f16_u16(codei.val[0])),
412+
vcvt_f32_f16(vreinterpret_f16_u16(codei.val[1])));
413+
}
414+
};
415+
#endif
416+
400417
/*******************************************************************
401418
* 8bit_direct quantizer
402419
*******************************************************************/
@@ -446,31 +463,31 @@ struct Quantizer8bitDirect<8> : Quantizer8bitDirect<1> {
446463

447464
#endif
448465

449-
template <int SIMDWIDTH>
466+
template <int SIMDWIDTH, int SIMDWIDTH_DEFAULT>
450467
ScalarQuantizer::SQuantizer* select_quantizer_1(
451468
QuantizerType qtype,
452469
size_t d,
453470
const std::vector<float>& trained) {
454471
switch (qtype) {
455472
case ScalarQuantizer::QT_8bit:
456-
return new QuantizerTemplate<Codec8bit, false, SIMDWIDTH>(
473+
return new QuantizerTemplate<Codec8bit, false, SIMDWIDTH_DEFAULT>(
457474
d, trained);
458475
case ScalarQuantizer::QT_6bit:
459-
return new QuantizerTemplate<Codec6bit, false, SIMDWIDTH>(
476+
return new QuantizerTemplate<Codec6bit, false, SIMDWIDTH_DEFAULT>(
460477
d, trained);
461478
case ScalarQuantizer::QT_4bit:
462-
return new QuantizerTemplate<Codec4bit, false, SIMDWIDTH>(
479+
return new QuantizerTemplate<Codec4bit, false, SIMDWIDTH_DEFAULT>(
463480
d, trained);
464481
case ScalarQuantizer::QT_8bit_uniform:
465-
return new QuantizerTemplate<Codec8bit, true, SIMDWIDTH>(
482+
return new QuantizerTemplate<Codec8bit, true, SIMDWIDTH_DEFAULT>(
466483
d, trained);
467484
case ScalarQuantizer::QT_4bit_uniform:
468-
return new QuantizerTemplate<Codec4bit, true, SIMDWIDTH>(
485+
return new QuantizerTemplate<Codec4bit, true, SIMDWIDTH_DEFAULT>(
469486
d, trained);
470487
case ScalarQuantizer::QT_fp16:
471488
return new QuantizerFP16<SIMDWIDTH>(d, trained);
472489
case ScalarQuantizer::QT_8bit_direct:
473-
return new Quantizer8bitDirect<SIMDWIDTH>(d, trained);
490+
return new Quantizer8bitDirect<SIMDWIDTH_DEFAULT>(d, trained);
474491
}
475492
FAISS_THROW_MSG("unknown qtype");
476493
}
@@ -728,6 +745,59 @@ struct SimilarityL2<8> {
728745

729746
#endif
730747

748+
#ifdef __aarch64__
749+
template <>
750+
struct SimilarityL2<8> {
751+
static constexpr int simdwidth = 8;
752+
static constexpr MetricType metric_type = METRIC_L2;
753+
754+
const float *y, *yi;
755+
explicit SimilarityL2(const float* y) : y(y) {}
756+
float32x4x2_t accu8;
757+
758+
FAISS_ALWAYS_INLINE void begin_8() {
759+
accu8 = vzipq_f32(vdupq_n_f32(0.0f), vdupq_n_f32(0.0f));
760+
yi = y;
761+
}
762+
763+
FAISS_ALWAYS_INLINE void add_8_components(float32x4x2_t x) {
764+
float32x4x2_t yiv = vld1q_f32_x2(yi);
765+
yi += 8;
766+
767+
float32x4_t sub0 = vsubq_f32(yiv.val[0], x.val[0]);
768+
float32x4_t sub1 = vsubq_f32(yiv.val[1], x.val[1]);
769+
770+
float32x4_t accu8_0 = vaddq_f32(accu8.val[0], vmulq_f32(sub0, sub0));
771+
float32x4_t accu8_1 = vaddq_f32(accu8.val[1], vmulq_f32(sub1, sub1));
772+
773+
float32x4x2_t accu8_temp = vzipq_f32(accu8_0, accu8_1);
774+
accu8 = vuzpq_f32(accu8_temp.val[0], accu8_temp.val[1]);
775+
}
776+
777+
FAISS_ALWAYS_INLINE void add_8_components_2(
778+
float32x4x2_t x,
779+
float32x4x2_t y) {
780+
float32x4_t sub0 = vsubq_f32(y.val[0], x.val[0]);
781+
float32x4_t sub1 = vsubq_f32(y.val[1], x.val[1]);
782+
783+
float32x4_t accu8_0 = vaddq_f32(accu8.val[0], vmulq_f32(sub0, sub0));
784+
float32x4_t accu8_1 = vaddq_f32(accu8.val[1], vmulq_f32(sub1, sub1));
785+
786+
float32x4x2_t accu8_temp = vzipq_f32(accu8_0, accu8_1);
787+
accu8 = vuzpq_f32(accu8_temp.val[0], accu8_temp.val[1]);
788+
}
789+
790+
FAISS_ALWAYS_INLINE float result_8() {
791+
float32x4_t sum_0 = vpaddq_f32(accu8.val[0], accu8.val[0]);
792+
float32x4_t sum_1 = vpaddq_f32(accu8.val[1], accu8.val[1]);
793+
794+
float32x4_t sum2_0 = vpaddq_f32(sum_0, sum_0);
795+
float32x4_t sum2_1 = vpaddq_f32(sum_1, sum_1);
796+
return vgetq_lane_f32(sum2_0, 0) + vgetq_lane_f32(sum2_1, 0);
797+
}
798+
};
799+
#endif
800+
731801
template <int SIMDWIDTH>
732802
struct SimilarityIP {};
733803

@@ -801,6 +871,60 @@ struct SimilarityIP<8> {
801871
};
802872
#endif
803873

874+
#ifdef __aarch64__
875+
876+
template <>
877+
struct SimilarityIP<8> {
878+
static constexpr int simdwidth = 8;
879+
static constexpr MetricType metric_type = METRIC_INNER_PRODUCT;
880+
881+
const float *y, *yi;
882+
883+
explicit SimilarityIP(const float* y) : y(y) {}
884+
float32x4x2_t accu8;
885+
886+
FAISS_ALWAYS_INLINE void begin_8() {
887+
accu8 = vzipq_f32(vdupq_n_f32(0.0f), vdupq_n_f32(0.0f));
888+
yi = y;
889+
}
890+
891+
FAISS_ALWAYS_INLINE void add_8_components(float32x4x2_t x) {
892+
float32x4x2_t yiv = vld1q_f32_x2(yi);
893+
yi += 8;
894+
895+
float32x4_t accu8_0 =
896+
vaddq_f32(accu8.val[0], vmulq_f32(yiv.val[0], x.val[0]));
897+
float32x4_t accu8_1 =
898+
vaddq_f32(accu8.val[1], vmulq_f32(yiv.val[1], x.val[1]));
899+
float32x4x2_t accu8_temp = vzipq_f32(accu8_0, accu8_1);
900+
accu8 = vuzpq_f32(accu8_temp.val[0], accu8_temp.val[1]);
901+
}
902+
903+
FAISS_ALWAYS_INLINE void add_8_components_2(
904+
float32x4x2_t x1,
905+
float32x4x2_t x2) {
906+
float32x4_t accu8_0 =
907+
vaddq_f32(accu8.val[0], vmulq_f32(x1.val[0], x2.val[0]));
908+
float32x4_t accu8_1 =
909+
vaddq_f32(accu8.val[1], vmulq_f32(x1.val[1], x2.val[1]));
910+
float32x4x2_t accu8_temp = vzipq_f32(accu8_0, accu8_1);
911+
accu8 = vuzpq_f32(accu8_temp.val[0], accu8_temp.val[1]);
912+
}
913+
914+
FAISS_ALWAYS_INLINE float result_8() {
915+
float32x4x2_t sum_tmp = vzipq_f32(
916+
vpaddq_f32(accu8.val[0], accu8.val[0]),
917+
vpaddq_f32(accu8.val[1], accu8.val[1]));
918+
float32x4x2_t sum = vuzpq_f32(sum_tmp.val[0], sum_tmp.val[1]);
919+
float32x4x2_t sum2_tmp = vzipq_f32(
920+
vpaddq_f32(sum.val[0], sum.val[0]),
921+
vpaddq_f32(sum.val[1], sum.val[1]));
922+
float32x4x2_t sum2 = vuzpq_f32(sum2_tmp.val[0], sum2_tmp.val[1]);
923+
return vgetq_lane_f32(sum2.val[0], 0) + vgetq_lane_f32(sum2.val[1], 0);
924+
}
925+
};
926+
#endif
927+
804928
/*******************************************************************
805929
* DistanceComputer: combines a similarity and a quantizer to do
806930
* code-to-vector or code-to-code comparisons
@@ -903,6 +1027,53 @@ struct DCTemplate<Quantizer, Similarity, 8> : SQDistanceComputer {
9031027

9041028
#endif
9051029

1030+
#ifdef __aarch64__
1031+
1032+
template <class Quantizer, class Similarity>
1033+
struct DCTemplate<Quantizer, Similarity, 8> : SQDistanceComputer {
1034+
using Sim = Similarity;
1035+
1036+
Quantizer quant;
1037+
1038+
DCTemplate(size_t d, const std::vector<float>& trained)
1039+
: quant(d, trained) {}
1040+
float compute_distance(const float* x, const uint8_t* code) const {
1041+
Similarity sim(x);
1042+
sim.begin_8();
1043+
for (size_t i = 0; i < quant.d; i += 8) {
1044+
float32x4x2_t xi = quant.reconstruct_8_components(code, i);
1045+
sim.add_8_components(xi);
1046+
}
1047+
return sim.result_8();
1048+
}
1049+
1050+
float compute_code_distance(const uint8_t* code1, const uint8_t* code2)
1051+
const {
1052+
Similarity sim(nullptr);
1053+
sim.begin_8();
1054+
for (size_t i = 0; i < quant.d; i += 8) {
1055+
float32x4x2_t x1 = quant.reconstruct_8_components(code1, i);
1056+
float32x4x2_t x2 = quant.reconstruct_8_components(code2, i);
1057+
sim.add_8_components_2(x1, x2);
1058+
}
1059+
return sim.result_8();
1060+
}
1061+
1062+
void set_query(const float* x) final {
1063+
q = x;
1064+
}
1065+
1066+
float symmetric_dis(idx_t i, idx_t j) override {
1067+
return compute_code_distance(
1068+
codes + i * code_size, codes + j * code_size);
1069+
}
1070+
1071+
float query_to_code(const uint8_t* code) const final {
1072+
return compute_distance(q, code);
1073+
}
1074+
};
1075+
#endif
1076+
9061077
/*******************************************************************
9071078
* DistanceComputerByte: computes distances in the integer domain
9081079
*******************************************************************/
@@ -1024,55 +1195,57 @@ struct DistanceComputerByte<Similarity, 8> : SQDistanceComputer {
10241195
* specialization
10251196
*******************************************************************/
10261197

1027-
template <class Sim>
1198+
template <class Sim, class Sim_default>
10281199
SQDistanceComputer* select_distance_computer(
10291200
QuantizerType qtype,
10301201
size_t d,
10311202
const std::vector<float>& trained) {
10321203
constexpr int SIMDWIDTH = Sim::simdwidth;
1204+
constexpr int SIMDWIDTH_DEFAULT = Sim_default::simdwidth;
10331205
switch (qtype) {
10341206
case ScalarQuantizer::QT_8bit_uniform:
10351207
return new DCTemplate<
1036-
QuantizerTemplate<Codec8bit, true, SIMDWIDTH>,
1037-
Sim,
1038-
SIMDWIDTH>(d, trained);
1208+
QuantizerTemplate<Codec8bit, true, SIMDWIDTH_DEFAULT>,
1209+
Sim_default,
1210+
SIMDWIDTH_DEFAULT>(d, trained);
10391211

10401212
case ScalarQuantizer::QT_4bit_uniform:
10411213
return new DCTemplate<
1042-
QuantizerTemplate<Codec4bit, true, SIMDWIDTH>,
1043-
Sim,
1044-
SIMDWIDTH>(d, trained);
1214+
QuantizerTemplate<Codec4bit, true, SIMDWIDTH_DEFAULT>,
1215+
Sim_default,
1216+
SIMDWIDTH_DEFAULT>(d, trained);
10451217

10461218
case ScalarQuantizer::QT_8bit:
10471219
return new DCTemplate<
1048-
QuantizerTemplate<Codec8bit, false, SIMDWIDTH>,
1049-
Sim,
1050-
SIMDWIDTH>(d, trained);
1220+
QuantizerTemplate<Codec8bit, false, SIMDWIDTH_DEFAULT>,
1221+
Sim_default,
1222+
SIMDWIDTH_DEFAULT>(d, trained);
10511223

10521224
case ScalarQuantizer::QT_6bit:
10531225
return new DCTemplate<
1054-
QuantizerTemplate<Codec6bit, false, SIMDWIDTH>,
1055-
Sim,
1056-
SIMDWIDTH>(d, trained);
1226+
QuantizerTemplate<Codec6bit, false, SIMDWIDTH_DEFAULT>,
1227+
Sim_default,
1228+
SIMDWIDTH_DEFAULT>(d, trained);
10571229

10581230
case ScalarQuantizer::QT_4bit:
10591231
return new DCTemplate<
1060-
QuantizerTemplate<Codec4bit, false, SIMDWIDTH>,
1061-
Sim,
1062-
SIMDWIDTH>(d, trained);
1232+
QuantizerTemplate<Codec4bit, false, SIMDWIDTH_DEFAULT>,
1233+
Sim_default,
1234+
SIMDWIDTH_DEFAULT>(d, trained);
10631235

10641236
case ScalarQuantizer::QT_fp16:
10651237
return new DCTemplate<QuantizerFP16<SIMDWIDTH>, Sim, SIMDWIDTH>(
10661238
d, trained);
10671239

10681240
case ScalarQuantizer::QT_8bit_direct:
10691241
if (d % 16 == 0) {
1070-
return new DistanceComputerByte<Sim, SIMDWIDTH>(d, trained);
1242+
return new DistanceComputerByte<Sim_default, SIMDWIDTH_DEFAULT>(
1243+
d, trained);
10711244
} else {
10721245
return new DCTemplate<
1073-
Quantizer8bitDirect<SIMDWIDTH>,
1074-
Sim,
1075-
SIMDWIDTH>(d, trained);
1246+
Quantizer8bitDirect<SIMDWIDTH_DEFAULT>,
1247+
Sim_default,
1248+
SIMDWIDTH_DEFAULT>(d, trained);
10761249
}
10771250
}
10781251
FAISS_THROW_MSG("unknown qtype");
@@ -1155,13 +1328,14 @@ void ScalarQuantizer::train(size_t n, const float* x) {
11551328
}
11561329

11571330
/// Pick the quantizer implementation for this ScalarQuantizer.
///
/// When `d` is a multiple of 8 and a SIMD backend is compiled in, the
/// 8-wide variant is used: all quantizer types with USE_F16C, only QT_fp16
/// with __aarch64__. In every other case, including builds with neither
/// backend, the scalar implementation is returned.
ScalarQuantizer::SQuantizer* ScalarQuantizer::select_quantizer() const {
#if defined(USE_F16C)
    if (d % 8 == 0) {
        return select_quantizer_1<8, 8>(qtype, d, trained);
    }
#elif defined(__aarch64__)
    if (d % 8 == 0) {
        return select_quantizer_1<8, 1>(qtype, d, trained);
    }
#endif
    // Scalar fallback. Keeping it outside the preprocessor guards ensures a
    // value is always returned; previously a build without F16C/aarch64 fell
    // off the end of the function when d % 8 == 0.
    return select_quantizer_1<1, 1>(qtype, d, trained);
}
11671341

@@ -1186,20 +1360,31 @@ void ScalarQuantizer::decode(const uint8_t* codes, float* x, size_t n) const {
11861360
/// Build a distance computer for this quantizer's codes.
///
/// @param metric  must be METRIC_L2 or METRIC_INNER_PRODUCT (checked with
///                FAISS_THROW_IF_NOT).
/// @return the computer produced by select_distance_computer.
///
/// When `d` is a multiple of 8 and a SIMD backend is compiled in, the 8-wide
/// similarity is used: for all quantizer types with USE_F16C, for QT_fp16 only
/// with __aarch64__ (the other types get the scalar similarity). In every
/// other case, including builds with neither backend, the scalar similarity
/// is used.
SQDistanceComputer* ScalarQuantizer::get_distance_computer(
        MetricType metric) const {
    FAISS_THROW_IF_NOT(metric == METRIC_L2 || metric == METRIC_INNER_PRODUCT);
#if defined(USE_F16C)
    if (d % 8 == 0) {
        if (metric == METRIC_L2) {
            return select_distance_computer<SimilarityL2<8>, SimilarityL2<8>>(
                    qtype, d, trained);
        }
        return select_distance_computer<SimilarityIP<8>, SimilarityIP<8>>(
                qtype, d, trained);
    }
#elif defined(__aarch64__)
    if (d % 8 == 0) {
        if (metric == METRIC_L2) {
            return select_distance_computer<SimilarityL2<8>, SimilarityL2<1>>(
                    qtype, d, trained);
        }
        return select_distance_computer<SimilarityIP<8>, SimilarityIP<1>>(
                qtype, d, trained);
    }
#endif
    // Scalar fallback. Keeping it outside the preprocessor guards ensures a
    // value is always returned; previously a build without F16C/aarch64 fell
    // off the end of the function when d % 8 == 0.
    if (metric == METRIC_L2) {
        return select_distance_computer<SimilarityL2<1>, SimilarityL2<1>>(
                qtype, d, trained);
    }
    return select_distance_computer<SimilarityIP<1>, SimilarityIP<1>>(
            qtype, d, trained);
}

0 commit comments

Comments
 (0)