@@ -397,6 +397,23 @@ struct QuantizerFP16<8> : QuantizerFP16<1> {
397
397
398
398
#endif
399
399
400
+ #ifdef __aarch64__
401
+
402
+ template <>
403
+ struct QuantizerFP16 <8 > : QuantizerFP16<1 > {
404
+ QuantizerFP16 (size_t d, const std::vector<float >& trained)
405
+ : QuantizerFP16<1 >(d, trained) {}
406
+
407
+ FAISS_ALWAYS_INLINE float32x4x2_t
408
+ reconstruct_8_components (const uint8_t * code, int i) const {
409
+ uint16x4x2_t codei = vld2_u16 ((const uint16_t *)(code + 2 * i));
410
+ return vzipq_f32 (
411
+ vcvt_f32_f16 (vreinterpret_f16_u16 (codei.val [0 ])),
412
+ vcvt_f32_f16 (vreinterpret_f16_u16 (codei.val [1 ])));
413
+ }
414
+ };
415
+ #endif
416
+
400
417
/* ******************************************************************
401
418
* 8bit_direct quantizer
402
419
*******************************************************************/
@@ -446,31 +463,31 @@ struct Quantizer8bitDirect<8> : Quantizer8bitDirect<1> {
446
463
447
464
#endif
448
465
449
- template <int SIMDWIDTH>
466
+ template <int SIMDWIDTH, int SIMDWIDTH_DEFAULT >
450
467
ScalarQuantizer::SQuantizer* select_quantizer_1 (
451
468
QuantizerType qtype,
452
469
size_t d,
453
470
const std::vector<float >& trained) {
454
471
switch (qtype) {
455
472
case ScalarQuantizer::QT_8bit:
456
- return new QuantizerTemplate<Codec8bit, false , SIMDWIDTH >(
473
+ return new QuantizerTemplate<Codec8bit, false , SIMDWIDTH_DEFAULT >(
457
474
d, trained);
458
475
case ScalarQuantizer::QT_6bit:
459
- return new QuantizerTemplate<Codec6bit, false , SIMDWIDTH >(
476
+ return new QuantizerTemplate<Codec6bit, false , SIMDWIDTH_DEFAULT >(
460
477
d, trained);
461
478
case ScalarQuantizer::QT_4bit:
462
- return new QuantizerTemplate<Codec4bit, false , SIMDWIDTH >(
479
+ return new QuantizerTemplate<Codec4bit, false , SIMDWIDTH_DEFAULT >(
463
480
d, trained);
464
481
case ScalarQuantizer::QT_8bit_uniform:
465
- return new QuantizerTemplate<Codec8bit, true , SIMDWIDTH >(
482
+ return new QuantizerTemplate<Codec8bit, true , SIMDWIDTH_DEFAULT >(
466
483
d, trained);
467
484
case ScalarQuantizer::QT_4bit_uniform:
468
- return new QuantizerTemplate<Codec4bit, true , SIMDWIDTH >(
485
+ return new QuantizerTemplate<Codec4bit, true , SIMDWIDTH_DEFAULT >(
469
486
d, trained);
470
487
case ScalarQuantizer::QT_fp16:
471
488
return new QuantizerFP16<SIMDWIDTH>(d, trained);
472
489
case ScalarQuantizer::QT_8bit_direct:
473
- return new Quantizer8bitDirect<SIMDWIDTH >(d, trained);
490
+ return new Quantizer8bitDirect<SIMDWIDTH_DEFAULT >(d, trained);
474
491
}
475
492
FAISS_THROW_MSG (" unknown qtype" );
476
493
}
@@ -728,6 +745,59 @@ struct SimilarityL2<8> {
728
745
729
746
#endif
730
747
748
+ #ifdef __aarch64__
749
+ template <>
750
+ struct SimilarityL2 <8 > {
751
+ static constexpr int simdwidth = 8 ;
752
+ static constexpr MetricType metric_type = METRIC_L2;
753
+
754
+ const float *y, *yi;
755
+ explicit SimilarityL2 (const float * y) : y(y) {}
756
+ float32x4x2_t accu8;
757
+
758
+ FAISS_ALWAYS_INLINE void begin_8 () {
759
+ accu8 = vzipq_f32 (vdupq_n_f32 (0 .0f ), vdupq_n_f32 (0 .0f ));
760
+ yi = y;
761
+ }
762
+
763
+ FAISS_ALWAYS_INLINE void add_8_components (float32x4x2_t x) {
764
+ float32x4x2_t yiv = vld1q_f32_x2 (yi);
765
+ yi += 8 ;
766
+
767
+ float32x4_t sub0 = vsubq_f32 (yiv.val [0 ], x.val [0 ]);
768
+ float32x4_t sub1 = vsubq_f32 (yiv.val [1 ], x.val [1 ]);
769
+
770
+ float32x4_t accu8_0 = vaddq_f32 (accu8.val [0 ], vmulq_f32 (sub0, sub0));
771
+ float32x4_t accu8_1 = vaddq_f32 (accu8.val [1 ], vmulq_f32 (sub1, sub1));
772
+
773
+ float32x4x2_t accu8_temp = vzipq_f32 (accu8_0, accu8_1);
774
+ accu8 = vuzpq_f32 (accu8_temp.val [0 ], accu8_temp.val [1 ]);
775
+ }
776
+
777
+ FAISS_ALWAYS_INLINE void add_8_components_2 (
778
+ float32x4x2_t x,
779
+ float32x4x2_t y) {
780
+ float32x4_t sub0 = vsubq_f32 (y.val [0 ], x.val [0 ]);
781
+ float32x4_t sub1 = vsubq_f32 (y.val [1 ], x.val [1 ]);
782
+
783
+ float32x4_t accu8_0 = vaddq_f32 (accu8.val [0 ], vmulq_f32 (sub0, sub0));
784
+ float32x4_t accu8_1 = vaddq_f32 (accu8.val [1 ], vmulq_f32 (sub1, sub1));
785
+
786
+ float32x4x2_t accu8_temp = vzipq_f32 (accu8_0, accu8_1);
787
+ accu8 = vuzpq_f32 (accu8_temp.val [0 ], accu8_temp.val [1 ]);
788
+ }
789
+
790
+ FAISS_ALWAYS_INLINE float result_8 () {
791
+ float32x4_t sum_0 = vpaddq_f32 (accu8.val [0 ], accu8.val [0 ]);
792
+ float32x4_t sum_1 = vpaddq_f32 (accu8.val [1 ], accu8.val [1 ]);
793
+
794
+ float32x4_t sum2_0 = vpaddq_f32 (sum_0, sum_0);
795
+ float32x4_t sum2_1 = vpaddq_f32 (sum_1, sum_1);
796
+ return vgetq_lane_f32 (sum2_0, 0 ) + vgetq_lane_f32 (sum2_1, 0 );
797
+ }
798
+ };
799
+ #endif
800
+
731
801
template <int SIMDWIDTH>
732
802
struct SimilarityIP {};
733
803
@@ -801,6 +871,60 @@ struct SimilarityIP<8> {
801
871
};
802
872
#endif
803
873
874
+ #ifdef __aarch64__
875
+
876
+ template <>
877
+ struct SimilarityIP <8 > {
878
+ static constexpr int simdwidth = 8 ;
879
+ static constexpr MetricType metric_type = METRIC_INNER_PRODUCT;
880
+
881
+ const float *y, *yi;
882
+
883
+ explicit SimilarityIP (const float * y) : y(y) {}
884
+ float32x4x2_t accu8;
885
+
886
+ FAISS_ALWAYS_INLINE void begin_8 () {
887
+ accu8 = vzipq_f32 (vdupq_n_f32 (0 .0f ), vdupq_n_f32 (0 .0f ));
888
+ yi = y;
889
+ }
890
+
891
+ FAISS_ALWAYS_INLINE void add_8_components (float32x4x2_t x) {
892
+ float32x4x2_t yiv = vld1q_f32_x2 (yi);
893
+ yi += 8 ;
894
+
895
+ float32x4_t accu8_0 =
896
+ vaddq_f32 (accu8.val [0 ], vmulq_f32 (yiv.val [0 ], x.val [0 ]));
897
+ float32x4_t accu8_1 =
898
+ vaddq_f32 (accu8.val [1 ], vmulq_f32 (yiv.val [1 ], x.val [1 ]));
899
+ float32x4x2_t accu8_temp = vzipq_f32 (accu8_0, accu8_1);
900
+ accu8 = vuzpq_f32 (accu8_temp.val [0 ], accu8_temp.val [1 ]);
901
+ }
902
+
903
+ FAISS_ALWAYS_INLINE void add_8_components_2 (
904
+ float32x4x2_t x1,
905
+ float32x4x2_t x2) {
906
+ float32x4_t accu8_0 =
907
+ vaddq_f32 (accu8.val [0 ], vmulq_f32 (x1.val [0 ], x2.val [0 ]));
908
+ float32x4_t accu8_1 =
909
+ vaddq_f32 (accu8.val [1 ], vmulq_f32 (x1.val [1 ], x2.val [1 ]));
910
+ float32x4x2_t accu8_temp = vzipq_f32 (accu8_0, accu8_1);
911
+ accu8 = vuzpq_f32 (accu8_temp.val [0 ], accu8_temp.val [1 ]);
912
+ }
913
+
914
+ FAISS_ALWAYS_INLINE float result_8 () {
915
+ float32x4x2_t sum_tmp = vzipq_f32 (
916
+ vpaddq_f32 (accu8.val [0 ], accu8.val [0 ]),
917
+ vpaddq_f32 (accu8.val [1 ], accu8.val [1 ]));
918
+ float32x4x2_t sum = vuzpq_f32 (sum_tmp.val [0 ], sum_tmp.val [1 ]);
919
+ float32x4x2_t sum2_tmp = vzipq_f32 (
920
+ vpaddq_f32 (sum.val [0 ], sum.val [0 ]),
921
+ vpaddq_f32 (sum.val [1 ], sum.val [1 ]));
922
+ float32x4x2_t sum2 = vuzpq_f32 (sum2_tmp.val [0 ], sum2_tmp.val [1 ]);
923
+ return vgetq_lane_f32 (sum2.val [0 ], 0 ) + vgetq_lane_f32 (sum2.val [1 ], 0 );
924
+ }
925
+ };
926
+ #endif
927
+
804
928
/* ******************************************************************
805
929
* DistanceComputer: combines a similarity and a quantizer to do
806
930
* code-to-vector or code-to-code comparisons
@@ -903,6 +1027,53 @@ struct DCTemplate<Quantizer, Similarity, 8> : SQDistanceComputer {
903
1027
904
1028
#endif
905
1029
1030
+ #ifdef __aarch64__
1031
+
1032
+ template <class Quantizer , class Similarity >
1033
+ struct DCTemplate <Quantizer, Similarity, 8 > : SQDistanceComputer {
1034
+ using Sim = Similarity;
1035
+
1036
+ Quantizer quant;
1037
+
1038
+ DCTemplate (size_t d, const std::vector<float >& trained)
1039
+ : quant(d, trained) {}
1040
+ float compute_distance (const float * x, const uint8_t * code) const {
1041
+ Similarity sim (x);
1042
+ sim.begin_8 ();
1043
+ for (size_t i = 0 ; i < quant.d ; i += 8 ) {
1044
+ float32x4x2_t xi = quant.reconstruct_8_components (code, i);
1045
+ sim.add_8_components (xi);
1046
+ }
1047
+ return sim.result_8 ();
1048
+ }
1049
+
1050
+ float compute_code_distance (const uint8_t * code1, const uint8_t * code2)
1051
+ const {
1052
+ Similarity sim (nullptr );
1053
+ sim.begin_8 ();
1054
+ for (size_t i = 0 ; i < quant.d ; i += 8 ) {
1055
+ float32x4x2_t x1 = quant.reconstruct_8_components (code1, i);
1056
+ float32x4x2_t x2 = quant.reconstruct_8_components (code2, i);
1057
+ sim.add_8_components_2 (x1, x2);
1058
+ }
1059
+ return sim.result_8 ();
1060
+ }
1061
+
1062
+ void set_query (const float * x) final {
1063
+ q = x;
1064
+ }
1065
+
1066
+ float symmetric_dis (idx_t i, idx_t j) override {
1067
+ return compute_code_distance (
1068
+ codes + i * code_size, codes + j * code_size);
1069
+ }
1070
+
1071
+ float query_to_code (const uint8_t * code) const final {
1072
+ return compute_distance (q, code);
1073
+ }
1074
+ };
1075
+ #endif
1076
+
906
1077
/* ******************************************************************
907
1078
* DistanceComputerByte: computes distances in the integer domain
908
1079
*******************************************************************/
@@ -1024,55 +1195,57 @@ struct DistanceComputerByte<Similarity, 8> : SQDistanceComputer {
1024
1195
* specialization
1025
1196
*******************************************************************/
1026
1197
1027
- template <class Sim >
1198
+ template <class Sim , class Sim_default >
1028
1199
SQDistanceComputer* select_distance_computer (
1029
1200
QuantizerType qtype,
1030
1201
size_t d,
1031
1202
const std::vector<float >& trained) {
1032
1203
constexpr int SIMDWIDTH = Sim::simdwidth;
1204
+ constexpr int SIMDWIDTH_DEFAULT = Sim_default::simdwidth;
1033
1205
switch (qtype) {
1034
1206
case ScalarQuantizer::QT_8bit_uniform:
1035
1207
return new DCTemplate<
1036
- QuantizerTemplate<Codec8bit, true , SIMDWIDTH >,
1037
- Sim ,
1038
- SIMDWIDTH >(d, trained);
1208
+ QuantizerTemplate<Codec8bit, true , SIMDWIDTH_DEFAULT >,
1209
+ Sim_default ,
1210
+ SIMDWIDTH_DEFAULT >(d, trained);
1039
1211
1040
1212
case ScalarQuantizer::QT_4bit_uniform:
1041
1213
return new DCTemplate<
1042
- QuantizerTemplate<Codec4bit, true , SIMDWIDTH >,
1043
- Sim ,
1044
- SIMDWIDTH >(d, trained);
1214
+ QuantizerTemplate<Codec4bit, true , SIMDWIDTH_DEFAULT >,
1215
+ Sim_default ,
1216
+ SIMDWIDTH_DEFAULT >(d, trained);
1045
1217
1046
1218
case ScalarQuantizer::QT_8bit:
1047
1219
return new DCTemplate<
1048
- QuantizerTemplate<Codec8bit, false , SIMDWIDTH >,
1049
- Sim ,
1050
- SIMDWIDTH >(d, trained);
1220
+ QuantizerTemplate<Codec8bit, false , SIMDWIDTH_DEFAULT >,
1221
+ Sim_default ,
1222
+ SIMDWIDTH_DEFAULT >(d, trained);
1051
1223
1052
1224
case ScalarQuantizer::QT_6bit:
1053
1225
return new DCTemplate<
1054
- QuantizerTemplate<Codec6bit, false , SIMDWIDTH >,
1055
- Sim ,
1056
- SIMDWIDTH >(d, trained);
1226
+ QuantizerTemplate<Codec6bit, false , SIMDWIDTH_DEFAULT >,
1227
+ Sim_default ,
1228
+ SIMDWIDTH_DEFAULT >(d, trained);
1057
1229
1058
1230
case ScalarQuantizer::QT_4bit:
1059
1231
return new DCTemplate<
1060
- QuantizerTemplate<Codec4bit, false , SIMDWIDTH >,
1061
- Sim ,
1062
- SIMDWIDTH >(d, trained);
1232
+ QuantizerTemplate<Codec4bit, false , SIMDWIDTH_DEFAULT >,
1233
+ Sim_default ,
1234
+ SIMDWIDTH_DEFAULT >(d, trained);
1063
1235
1064
1236
case ScalarQuantizer::QT_fp16:
1065
1237
return new DCTemplate<QuantizerFP16<SIMDWIDTH>, Sim, SIMDWIDTH>(
1066
1238
d, trained);
1067
1239
1068
1240
case ScalarQuantizer::QT_8bit_direct:
1069
1241
if (d % 16 == 0 ) {
1070
- return new DistanceComputerByte<Sim, SIMDWIDTH>(d, trained);
1242
+ return new DistanceComputerByte<Sim_default, SIMDWIDTH_DEFAULT>(
1243
+ d, trained);
1071
1244
} else {
1072
1245
return new DCTemplate<
1073
- Quantizer8bitDirect<SIMDWIDTH >,
1074
- Sim ,
1075
- SIMDWIDTH >(d, trained);
1246
+ Quantizer8bitDirect<SIMDWIDTH_DEFAULT >,
1247
+ Sim_default ,
1248
+ SIMDWIDTH_DEFAULT >(d, trained);
1076
1249
}
1077
1250
}
1078
1251
FAISS_THROW_MSG (" unknown qtype" );
@@ -1155,13 +1328,14 @@ void ScalarQuantizer::train(size_t n, const float* x) {
1155
1328
}
1156
1329
1157
1330
ScalarQuantizer::SQuantizer* ScalarQuantizer::select_quantizer () const {
1158
- #ifdef USE_F16C
1159
1331
if (d % 8 == 0 ) {
1160
- return select_quantizer_1<8 >(qtype, d, trained);
1161
- } else
1332
+ #if defined(USE_F16C)
1333
+ return select_quantizer_1<8 , 8 >(qtype, d, trained);
1334
+ #elif defined(__aarch64__)
1335
+ return select_quantizer_1<8 , 1 >(qtype, d, trained);
1162
1336
#endif
1163
- {
1164
- return select_quantizer_1<1 >(qtype, d, trained);
1337
+ } else {
1338
+ return select_quantizer_1<1 , 1 >(qtype, d, trained);
1165
1339
}
1166
1340
}
1167
1341
@@ -1186,20 +1360,31 @@ void ScalarQuantizer::decode(const uint8_t* codes, float* x, size_t n) const {
1186
1360
SQDistanceComputer* ScalarQuantizer::get_distance_computer (
1187
1361
MetricType metric) const {
1188
1362
FAISS_THROW_IF_NOT (metric == METRIC_L2 || metric == METRIC_INNER_PRODUCT);
1189
- #ifdef USE_F16C
1190
1363
if (d % 8 == 0 ) {
1191
1364
if (metric == METRIC_L2) {
1192
- return select_distance_computer<SimilarityL2<8 >>(qtype, d, trained);
1365
+ #if defined(USE_F16C)
1366
+ return select_distance_computer<SimilarityL2<8 >, SimilarityL2<8 >>(
1367
+ qtype, d, trained);
1368
+ #elif defined(__aarch64__)
1369
+ return select_distance_computer<SimilarityL2<8 >, SimilarityL2<1 >>(
1370
+ qtype, d, trained);
1371
+ #endif
1193
1372
} else {
1194
- return select_distance_computer<SimilarityIP<8 >>(qtype, d, trained);
1195
- }
1196
- } else
1373
+ #if defined(USE_F16C)
1374
+ return select_distance_computer<SimilarityIP<8 >, SimilarityIP<8 >>(
1375
+ qtype, d, trained);
1376
+ #elif defined(__aarch64__)
1377
+ return select_distance_computer<SimilarityIP<8 >, SimilarityIP<1 >>(
1378
+ qtype, d, trained);
1197
1379
#endif
1198
- {
1380
+ }
1381
+ } else {
1199
1382
if (metric == METRIC_L2) {
1200
- return select_distance_computer<SimilarityL2<1 >>(qtype, d, trained);
1383
+ return select_distance_computer<SimilarityL2<1 >, SimilarityL2<1 >>(
1384
+ qtype, d, trained);
1201
1385
} else {
1202
- return select_distance_computer<SimilarityIP<1 >>(qtype, d, trained);
1386
+ return select_distance_computer<SimilarityIP<1 >, SimilarityIP<1 >>(
1387
+ qtype, d, trained);
1203
1388
}
1204
1389
}
1205
1390
}
0 commit comments