@@ -397,6 +397,20 @@ struct QuantizerFP16<8> : QuantizerFP16<1> {
397
397
398
398
#endif
399
399
400
+ #ifdef __aarch64__
401
+
402
+ template <>
403
+ struct QuantizerFP16 <8 > : QuantizerFP16<1 > {
404
+ QuantizerFP16 (size_t d, const std::vector<float >& trained)
405
+ : QuantizerFP16<1 >(d, trained) {}
406
+
407
+ FAISS_ALWAYS_INLINE float32x4x2_t reconstruct_8_components (const uint8_t * code, int i) const {
408
+ uint16x4x2_t codei = vld2_u16 ((const uint16_t *)(code + 2 * i));
409
+ return vzipq_f32 (vcvt_f32_f16 (vreinterpret_f16_u16 (codei.val [0 ])), vcvt_f32_f16 (vreinterpret_f16_u16 (codei.val [1 ])));
410
+ }
411
+ };
412
+ #endif
413
+
400
414
/* ******************************************************************
401
415
* 8bit_direct quantizer
402
416
*******************************************************************/
@@ -446,31 +460,32 @@ struct Quantizer8bitDirect<8> : Quantizer8bitDirect<1> {
446
460
447
461
#endif
448
462
449
- template <int SIMDWIDTH>
463
+ template <int SIMDWIDTH, int SIMDWIDTH_DEFAULT >
450
464
ScalarQuantizer::SQuantizer* select_quantizer_1 (
451
465
QuantizerType qtype,
452
466
size_t d,
453
467
const std::vector<float >& trained) {
454
468
switch (qtype) {
455
469
case ScalarQuantizer::QT_8bit:
456
- return new QuantizerTemplate<Codec8bit, false , SIMDWIDTH >(
470
+ return new QuantizerTemplate<Codec8bit, false , SIMDWIDTH_DEFAULT >(
457
471
d, trained);
458
472
case ScalarQuantizer::QT_6bit:
459
- return new QuantizerTemplate<Codec6bit, false , SIMDWIDTH >(
473
+ return new QuantizerTemplate<Codec6bit, false , SIMDWIDTH_DEFAULT >(
460
474
d, trained);
461
475
case ScalarQuantizer::QT_4bit:
462
- return new QuantizerTemplate<Codec4bit, false , SIMDWIDTH >(
476
+ return new QuantizerTemplate<Codec4bit, false , SIMDWIDTH_DEFAULT >(
463
477
d, trained);
464
478
case ScalarQuantizer::QT_8bit_uniform:
465
- return new QuantizerTemplate<Codec8bit, true , SIMDWIDTH >(
479
+ return new QuantizerTemplate<Codec8bit, true , SIMDWIDTH_DEFAULT >(
466
480
d, trained);
467
481
case ScalarQuantizer::QT_4bit_uniform:
468
- return new QuantizerTemplate<Codec4bit, true , SIMDWIDTH >(
482
+ return new QuantizerTemplate<Codec4bit, true , SIMDWIDTH_DEFAULT >(
469
483
d, trained);
470
484
case ScalarQuantizer::QT_fp16:
471
485
return new QuantizerFP16<SIMDWIDTH>(d, trained);
472
486
case ScalarQuantizer::QT_8bit_direct:
473
- return new Quantizer8bitDirect<SIMDWIDTH>(d, trained);
487
+ return new Quantizer8bitDirect<SIMDWIDTH_DEFAULT>(d, trained);
488
+
474
489
}
475
490
FAISS_THROW_MSG (" unknown qtype" );
476
491
}
@@ -728,6 +743,57 @@ struct SimilarityL2<8> {
728
743
729
744
#endif
730
745
746
+ #ifdef __aarch64__
747
+ template <>
748
+ struct SimilarityL2 <8 > {
749
+ static constexpr int simdwidth = 8 ;
750
+ static constexpr MetricType metric_type = METRIC_L2;
751
+
752
+ const float *y, *yi;
753
+ explicit SimilarityL2 (const float * y) : y(y) {}
754
+ float32x4x2_t accu8;
755
+
756
+ FAISS_ALWAYS_INLINE void begin_8 () {
757
+ accu8 = vzipq_f32 (vdupq_n_f32 (0 .0f ),vdupq_n_f32 (0 .0f ));
758
+ yi = y;
759
+ }
760
+
761
+ FAISS_ALWAYS_INLINE void add_8_components (float32x4x2_t x) {
762
+ float32x4x2_t yiv = vld1q_f32_x2 (yi);
763
+ yi += 8 ;
764
+
765
+ float32x4_t sub0 = vsubq_f32 (yiv.val [0 ], x.val [0 ]);
766
+ float32x4_t sub1 = vsubq_f32 (yiv.val [1 ], x.val [1 ]);
767
+
768
+ float32x4_t accu8_0 = vaddq_f32 (accu8.val [0 ], vmulq_f32 (sub0,sub0));
769
+ float32x4_t accu8_1 = vaddq_f32 (accu8.val [1 ], vmulq_f32 (sub1,sub1));
770
+
771
+ float32x4x2_t accu8_temp = vzipq_f32 (accu8_0, accu8_1);
772
+ accu8 = vuzpq_f32 (accu8_temp.val [0 ],accu8_temp.val [1 ]);
773
+ }
774
+
775
+ FAISS_ALWAYS_INLINE void add_8_components_2 (float32x4x2_t x, float32x4x2_t y) {
776
+ float32x4_t sub0 = vsubq_f32 (y.val [0 ], x.val [0 ]);
777
+ float32x4_t sub1 = vsubq_f32 (y.val [1 ], x.val [1 ]);
778
+
779
+ float32x4_t accu8_0 = vaddq_f32 (accu8.val [0 ], vmulq_f32 (sub0,sub0));
780
+ float32x4_t accu8_1 = vaddq_f32 (accu8.val [1 ], vmulq_f32 (sub1,sub1));
781
+
782
+ float32x4x2_t accu8_temp = vzipq_f32 (accu8_0, accu8_1);
783
+ accu8 = vuzpq_f32 (accu8_temp.val [0 ],accu8_temp.val [1 ]);
784
+ }
785
+
786
+ FAISS_ALWAYS_INLINE float result_8 () {
787
+ float32x4_t sum_0 = vpaddq_f32 (accu8.val [0 ], accu8.val [0 ]);
788
+ float32x4_t sum_1 = vpaddq_f32 (accu8.val [1 ], accu8.val [1 ]);
789
+
790
+ float32x4_t sum2_0 = vpaddq_f32 (sum_0, sum_0);
791
+ float32x4_t sum2_1 = vpaddq_f32 (sum_1, sum_1);
792
+ return vgetq_lane_f32 (sum2_0, 0 ) + vgetq_lane_f32 (sum2_1, 0 );
793
+ }
794
+ };
795
+ #endif
796
+
731
797
template <int SIMDWIDTH>
732
798
struct SimilarityIP {};
733
799
@@ -801,6 +867,50 @@ struct SimilarityIP<8> {
801
867
};
802
868
#endif
803
869
870
+ #ifdef __aarch64__
871
+
872
+ template <>
873
+ struct SimilarityIP <8 > {
874
+ static constexpr int simdwidth = 8 ;
875
+ static constexpr MetricType metric_type = METRIC_INNER_PRODUCT;
876
+
877
+ const float *y, *yi;
878
+
879
+ explicit SimilarityIP (const float * y) : y(y) {}
880
+ float32x4x2_t accu8;
881
+
882
+ FAISS_ALWAYS_INLINE void begin_8 () {
883
+ accu8 = vzipq_f32 (vdupq_n_f32 (0 .0f ),vdupq_n_f32 (0 .0f ));
884
+ yi = y;
885
+ }
886
+
887
+ FAISS_ALWAYS_INLINE void add_8_components (float32x4x2_t x) {
888
+ float32x4x2_t yiv = vld1q_f32_x2 (yi);
889
+ yi += 8 ;
890
+
891
+ float32x4_t accu8_0 = vaddq_f32 (accu8.val [0 ], vmulq_f32 (yiv.val [0 ], x.val [0 ]));
892
+ float32x4_t accu8_1 = vaddq_f32 (accu8.val [1 ], vmulq_f32 (yiv.val [1 ], x.val [1 ]));
893
+ float32x4x2_t accu8_temp = vzipq_f32 (accu8_0, accu8_1);
894
+ accu8 = vuzpq_f32 (accu8_temp.val [0 ],accu8_temp.val [1 ]);
895
+ }
896
+
897
+ FAISS_ALWAYS_INLINE void add_8_components_2 (float32x4x2_t x1, float32x4x2_t x2) {
898
+ float32x4_t accu8_0 = vaddq_f32 (accu8.val [0 ], vmulq_f32 (x1.val [0 ], x2.val [0 ]));
899
+ float32x4_t accu8_1 = vaddq_f32 (accu8.val [1 ], vmulq_f32 (x1.val [1 ], x2.val [1 ]));
900
+ float32x4x2_t accu8_temp = vzipq_f32 (accu8_0, accu8_1);
901
+ accu8 = vuzpq_f32 (accu8_temp.val [0 ],accu8_temp.val [1 ]);
902
+ }
903
+
904
+ FAISS_ALWAYS_INLINE float result_8 () {
905
+ float32x4x2_t sum_tmp = vzipq_f32 (vpaddq_f32 (accu8.val [0 ], accu8.val [0 ]), vpaddq_f32 (accu8.val [1 ], accu8.val [1 ]));
906
+ float32x4x2_t sum = vuzpq_f32 (sum_tmp.val [0 ], sum_tmp.val [1 ]);
907
+ float32x4x2_t sum2_tmp = vzipq_f32 (vpaddq_f32 (sum.val [0 ], sum.val [0 ]), vpaddq_f32 (sum.val [1 ], sum.val [1 ]));
908
+ float32x4x2_t sum2 = vuzpq_f32 (sum2_tmp.val [0 ], sum2_tmp.val [1 ]);
909
+ return vgetq_lane_f32 (sum2.val [0 ], 0 ) + vgetq_lane_f32 (sum2.val [1 ], 0 );
910
+ }
911
+ };
912
+ #endif
913
+
804
914
/* ******************************************************************
805
915
* DistanceComputer: combines a similarity and a quantizer to do
806
916
* code-to-vector or code-to-code comparisons
@@ -903,6 +1013,53 @@ struct DCTemplate<Quantizer, Similarity, 8> : SQDistanceComputer {
903
1013
904
1014
#endif
905
1015
1016
+ #ifdef __aarch64__
1017
+
1018
+ template <class Quantizer , class Similarity >
1019
+ struct DCTemplate <Quantizer, Similarity, 8 > : SQDistanceComputer {
1020
+ using Sim = Similarity;
1021
+
1022
+ Quantizer quant;
1023
+
1024
+ DCTemplate (size_t d, const std::vector<float >& trained)
1025
+ : quant(d, trained) {}
1026
+ float compute_distance (const float * x, const uint8_t * code) const {
1027
+ Similarity sim (x);
1028
+ sim.begin_8 ();
1029
+ for (size_t i = 0 ; i < quant.d ; i += 8 ) {
1030
+ float32x4x2_t xi = quant.reconstruct_8_components (code, i);
1031
+ sim.add_8_components (xi);
1032
+ }
1033
+ return sim.result_8 ();
1034
+ }
1035
+
1036
+ float compute_code_distance (const uint8_t * code1, const uint8_t * code2)
1037
+ const {
1038
+ Similarity sim (nullptr );
1039
+ sim.begin_8 ();
1040
+ for (size_t i = 0 ; i < quant.d ; i += 8 ) {
1041
+ float32x4x2_t x1 = quant.reconstruct_8_components (code1, i);
1042
+ float32x4x2_t x2 = quant.reconstruct_8_components (code2, i);
1043
+ sim.add_8_components_2 (x1, x2);
1044
+ }
1045
+ return sim.result_8 ();
1046
+ }
1047
+
1048
+ void set_query (const float * x) final {
1049
+ q = x;
1050
+ }
1051
+
1052
+ float symmetric_dis (idx_t i, idx_t j) override {
1053
+ return compute_code_distance (
1054
+ codes + i * code_size, codes + j * code_size);
1055
+ }
1056
+
1057
+ float query_to_code (const uint8_t * code) const final {
1058
+ return compute_distance (q, code);
1059
+ }
1060
+ };
1061
+ #endif
1062
+
906
1063
/* ******************************************************************
907
1064
* DistanceComputerByte: computes distances in the integer domain
908
1065
*******************************************************************/
@@ -1024,55 +1181,57 @@ struct DistanceComputerByte<Similarity, 8> : SQDistanceComputer {
1024
1181
* specialization
1025
1182
*******************************************************************/
1026
1183
1027
- template <class Sim >
1184
+ template <class Sim , class Sim_default >
1028
1185
SQDistanceComputer* select_distance_computer (
1029
1186
QuantizerType qtype,
1030
1187
size_t d,
1031
1188
const std::vector<float >& trained) {
1032
1189
constexpr int SIMDWIDTH = Sim::simdwidth;
1190
+ constexpr int SIMDWIDTH_DEFAULT = Sim_default::simdwidth;
1033
1191
switch (qtype) {
1034
1192
case ScalarQuantizer::QT_8bit_uniform:
1035
1193
return new DCTemplate<
1036
- QuantizerTemplate<Codec8bit, true , SIMDWIDTH >,
1037
- Sim ,
1038
- SIMDWIDTH >(d, trained);
1194
+ QuantizerTemplate<Codec8bit, true , SIMDWIDTH_DEFAULT >,
1195
+ Sim_default ,
1196
+ SIMDWIDTH_DEFAULT >(d, trained);
1039
1197
1040
1198
case ScalarQuantizer::QT_4bit_uniform:
1041
1199
return new DCTemplate<
1042
- QuantizerTemplate<Codec4bit, true , SIMDWIDTH >,
1043
- Sim ,
1044
- SIMDWIDTH >(d, trained);
1200
+ QuantizerTemplate<Codec4bit, true , SIMDWIDTH_DEFAULT >,
1201
+ Sim_default ,
1202
+ SIMDWIDTH_DEFAULT >(d, trained);
1045
1203
1046
1204
case ScalarQuantizer::QT_8bit:
1047
1205
return new DCTemplate<
1048
- QuantizerTemplate<Codec8bit, false , SIMDWIDTH >,
1049
- Sim ,
1050
- SIMDWIDTH >(d, trained);
1206
+ QuantizerTemplate<Codec8bit, false , SIMDWIDTH_DEFAULT >,
1207
+ Sim_default ,
1208
+ SIMDWIDTH_DEFAULT >(d, trained);
1051
1209
1052
1210
case ScalarQuantizer::QT_6bit:
1053
1211
return new DCTemplate<
1054
- QuantizerTemplate<Codec6bit, false , SIMDWIDTH >,
1055
- Sim ,
1056
- SIMDWIDTH >(d, trained);
1212
+ QuantizerTemplate<Codec6bit, false , SIMDWIDTH_DEFAULT >,
1213
+ Sim_default ,
1214
+ SIMDWIDTH_DEFAULT >(d, trained);
1057
1215
1058
1216
case ScalarQuantizer::QT_4bit:
1059
1217
return new DCTemplate<
1060
- QuantizerTemplate<Codec4bit, false , SIMDWIDTH>,
1061
- Sim,
1062
- SIMDWIDTH>(d, trained);
1218
+ QuantizerTemplate<Codec4bit, false , SIMDWIDTH_DEFAULT>,
1219
+ Sim_default,
1220
+ SIMDWIDTH_DEFAULT>(d, trained);
1221
+
1063
1222
1064
1223
case ScalarQuantizer::QT_fp16:
1065
1224
return new DCTemplate<QuantizerFP16<SIMDWIDTH>, Sim, SIMDWIDTH>(
1066
1225
d, trained);
1067
1226
1068
1227
case ScalarQuantizer::QT_8bit_direct:
1069
1228
if (d % 16 == 0 ) {
1070
- return new DistanceComputerByte<Sim, SIMDWIDTH >(d, trained);
1229
+ return new DistanceComputerByte<Sim_default, SIMDWIDTH_DEFAULT >(d, trained);
1071
1230
} else {
1072
1231
return new DCTemplate<
1073
- Quantizer8bitDirect<SIMDWIDTH >,
1074
- Sim ,
1075
- SIMDWIDTH >(d, trained);
1232
+ Quantizer8bitDirect<SIMDWIDTH_DEFAULT >,
1233
+ Sim_default ,
1234
+ SIMDWIDTH_DEFAULT >(d, trained);
1076
1235
}
1077
1236
}
1078
1237
FAISS_THROW_MSG (" unknown qtype" );
@@ -1155,13 +1314,15 @@ void ScalarQuantizer::train(size_t n, const float* x) {
1155
1314
}
1156
1315
1157
1316
ScalarQuantizer::SQuantizer* ScalarQuantizer::select_quantizer () const {
1158
- #ifdef USE_F16C
1159
1317
if (d % 8 == 0 ) {
1160
- return select_quantizer_1<8 >(qtype, d, trained);
1161
- } else
1318
+ #if defined(USE_F16C)
1319
+ return select_quantizer_1<8 ,8 >(qtype, d, trained);
1320
+ #elif defined(__aarch64__)
1321
+ return select_quantizer_1<8 ,1 >(qtype, d, trained);
1162
1322
#endif
1323
+ } else
1163
1324
{
1164
- return select_quantizer_1<1 >(qtype, d, trained);
1325
+ return select_quantizer_1<1 , 1 >(qtype, d, trained);
1165
1326
}
1166
1327
}
1167
1328
@@ -1186,20 +1347,26 @@ void ScalarQuantizer::decode(const uint8_t* codes, float* x, size_t n) const {
1186
1347
SQDistanceComputer* ScalarQuantizer::get_distance_computer (
1187
1348
MetricType metric) const {
1188
1349
FAISS_THROW_IF_NOT (metric == METRIC_L2 || metric == METRIC_INNER_PRODUCT);
1189
- #ifdef USE_F16C
1190
1350
if (d % 8 == 0 ) {
1191
1351
if (metric == METRIC_L2) {
1192
- return select_distance_computer<SimilarityL2<8 >>(qtype, d, trained);
1352
+ #if defined(USE_F16C)
1353
+ return select_distance_computer<SimilarityL2<8 >,SimilarityL2<8 >>(qtype, d, trained);
1354
+ #elif defined(__aarch64__)
1355
+ return select_distance_computer<SimilarityL2<8 >,SimilarityL2<1 >>(qtype, d, trained);
1356
+ #endif
1193
1357
} else {
1194
- return select_distance_computer<SimilarityIP<8 >>(qtype, d, trained);
1358
+ #if defined(USE_F16C)
1359
+ return select_distance_computer<SimilarityIP<8 >,SimilarityIP<8 >>(qtype, d, trained);
1360
+ #elif defined(__aarch64__)
1361
+ return select_distance_computer<SimilarityIP<8 >,SimilarityIP<1 >>(qtype, d, trained);
1362
+ #endif
1195
1363
}
1196
1364
} else
1197
- #endif
1198
1365
{
1199
1366
if (metric == METRIC_L2) {
1200
- return select_distance_computer<SimilarityL2<1 >>(qtype, d, trained);
1367
+ return select_distance_computer<SimilarityL2<1 >,SimilarityL2< 1 > >(qtype, d, trained);
1201
1368
} else {
1202
- return select_distance_computer<SimilarityIP<1 >>(qtype, d, trained);
1369
+ return select_distance_computer<SimilarityIP<1 >,SimilarityIP< 1 > >(qtype, d, trained);
1203
1370
}
1204
1371
}
1205
1372
}
0 commit comments