@@ -91,6 +91,20 @@ struct Codec8bit {
91
91
return _mm256_fmadd_ps (f8, one_255, half_one_255);
92
92
}
93
93
#endif
94
+
95
+ #ifdef __aarch64__
96
+ static FAISS_ALWAYS_INLINE float32x4x2_t
97
+ decode_8_components (const uint8_t * code, int i) {
98
+ float32_t result[8 ] = {};
99
+ for (size_t j = 0 ; j < 8 ; j++) {
100
+ result[j] = decode_component (code, i + j);
101
+ }
102
+ float32x4_t res1 = vld1q_f32 (result);
103
+ float32x4_t res2 = vld1q_f32 (result + 4 );
104
+ float32x4x2_t res = vzipq_f32 (res1, res2);
105
+ return vuzpq_f32 (res.val [0 ], res.val [1 ]);
106
+ }
107
+ #endif
94
108
};
95
109
96
110
struct Codec4bit {
@@ -129,6 +143,20 @@ struct Codec4bit {
129
143
return _mm256_mul_ps (f8, one_255);
130
144
}
131
145
#endif
146
+
147
+ #ifdef __aarch64__
148
+ static FAISS_ALWAYS_INLINE float32x4x2_t
149
+ decode_8_components (const uint8_t * code, int i) {
150
+ float32_t result[8 ] = {};
151
+ for (size_t j = 0 ; j < 8 ; j++) {
152
+ result[j] = decode_component (code, i + j);
153
+ }
154
+ float32x4_t res1 = vld1q_f32 (result);
155
+ float32x4_t res2 = vld1q_f32 (result + 4 );
156
+ float32x4x2_t res = vzipq_f32 (res1, res2);
157
+ return vuzpq_f32 (res.val [0 ], res.val [1 ]);
158
+ }
159
+ #endif
132
160
};
133
161
134
162
struct Codec6bit {
@@ -228,6 +256,20 @@ struct Codec6bit {
228
256
}
229
257
230
258
#endif
259
+
260
+ #ifdef __aarch64__
261
+ static FAISS_ALWAYS_INLINE float32x4x2_t
262
+ decode_8_components (const uint8_t * code, int i) {
263
+ float32_t result[8 ] = {};
264
+ for (size_t j = 0 ; j < 8 ; j++) {
265
+ result[j] = decode_component (code, i + j);
266
+ }
267
+ float32x4_t res1 = vld1q_f32 (result);
268
+ float32x4_t res2 = vld1q_f32 (result + 4 );
269
+ float32x4x2_t res = vzipq_f32 (res1, res2);
270
+ return vuzpq_f32 (res.val [0 ], res.val [1 ]);
271
+ }
272
+ #endif
231
273
};
232
274
233
275
/* ******************************************************************
@@ -293,6 +335,31 @@ struct QuantizerTemplate<Codec, true, 8> : QuantizerTemplate<Codec, true, 1> {
293
335
294
336
#endif
295
337
338
+ #ifdef __aarch64__
339
+
340
+ template <class Codec >
341
+ struct QuantizerTemplate <Codec, true , 8 > : QuantizerTemplate<Codec, true , 1 > {
342
+ QuantizerTemplate (size_t d, const std::vector<float >& trained)
343
+ : QuantizerTemplate<Codec, true , 1 >(d, trained) {}
344
+
345
+ FAISS_ALWAYS_INLINE float32x4x2_t
346
+ reconstruct_8_components (const uint8_t * code, int i) const {
347
+ float32x4x2_t xi = Codec::decode_8_components (code, i);
348
+ float32x4x2_t res = vzipq_f32 (
349
+ vfmaq_f32 (
350
+ vdupq_n_f32 (this ->vmin ),
351
+ xi.val [0 ],
352
+ vdupq_n_f32 (this ->vdiff )),
353
+ vfmaq_f32 (
354
+ vdupq_n_f32 (this ->vmin ),
355
+ xi.val [1 ],
356
+ vdupq_n_f32 (this ->vdiff )));
357
+ return vuzpq_f32 (res.val [0 ], res.val [1 ]);
358
+ }
359
+ };
360
+
361
+ #endif
362
+
296
363
template <class Codec >
297
364
struct QuantizerTemplate <Codec, false , 1 > : ScalarQuantizer::SQuantizer {
298
365
const size_t d;
@@ -350,6 +417,29 @@ struct QuantizerTemplate<Codec, false, 8> : QuantizerTemplate<Codec, false, 1> {
350
417
351
418
#endif
352
419
420
+ #ifdef __aarch64__
421
+
422
+ template <class Codec >
423
+ struct QuantizerTemplate <Codec, false , 8 > : QuantizerTemplate<Codec, false , 1 > {
424
+ QuantizerTemplate (size_t d, const std::vector<float >& trained)
425
+ : QuantizerTemplate<Codec, false , 1 >(d, trained) {}
426
+
427
+ FAISS_ALWAYS_INLINE float32x4x2_t
428
+ reconstruct_8_components (const uint8_t * code, int i) const {
429
+ float32x4x2_t xi = Codec::decode_8_components (code, i);
430
+
431
+ float32x4x2_t vmin_8 = vld1q_f32_x2 (this ->vmin + i);
432
+ float32x4x2_t vdiff_8 = vld1q_f32_x2 (this ->vdiff + i);
433
+
434
+ float32x4x2_t res = vzipq_f32 (
435
+ vfmaq_f32 (vmin_8.val [0 ], xi.val [0 ], vdiff_8.val [0 ]),
436
+ vfmaq_f32 (vmin_8.val [1 ], xi.val [1 ], vdiff_8.val [1 ]));
437
+ return vuzpq_f32 (res.val [0 ], res.val [1 ]);
438
+ }
439
+ };
440
+
441
+ #endif
442
+
353
443
/* ******************************************************************
354
444
* FP16 quantizer
355
445
*******************************************************************/
@@ -463,31 +553,53 @@ struct Quantizer8bitDirect<8> : Quantizer8bitDirect<1> {
463
553
464
554
#endif
465
555
466
- template <int SIMDWIDTH, int SIMDWIDTH_DEFAULT>
556
+ #ifdef __aarch64__
557
+
558
+ template <>
559
+ struct Quantizer8bitDirect <8 > : Quantizer8bitDirect<1 > {
560
+ Quantizer8bitDirect (size_t d, const std::vector<float >& trained)
561
+ : Quantizer8bitDirect<1 >(d, trained) {}
562
+
563
+ FAISS_ALWAYS_INLINE float32x4x2_t
564
+ reconstruct_8_components (const uint8_t * code, int i) const {
565
+ float32_t result[8 ] = {};
566
+ for (size_t j = 0 ; j < 8 ; j++) {
567
+ result[j] = code[i + j];
568
+ }
569
+ float32x4_t res1 = vld1q_f32 (result);
570
+ float32x4_t res2 = vld1q_f32 (result + 4 );
571
+ float32x4x2_t res = vzipq_f32 (res1, res2);
572
+ return vuzpq_f32 (res.val [0 ], res.val [1 ]);
573
+ }
574
+ };
575
+
576
+ #endif
577
+
578
+ template <int SIMDWIDTH>
467
579
ScalarQuantizer::SQuantizer* select_quantizer_1 (
468
580
QuantizerType qtype,
469
581
size_t d,
470
582
const std::vector<float >& trained) {
471
583
switch (qtype) {
472
584
case ScalarQuantizer::QT_8bit:
473
- return new QuantizerTemplate<Codec8bit, false , SIMDWIDTH_DEFAULT >(
585
+ return new QuantizerTemplate<Codec8bit, false , SIMDWIDTH >(
474
586
d, trained);
475
587
case ScalarQuantizer::QT_6bit:
476
- return new QuantizerTemplate<Codec6bit, false , SIMDWIDTH_DEFAULT >(
588
+ return new QuantizerTemplate<Codec6bit, false , SIMDWIDTH >(
477
589
d, trained);
478
590
case ScalarQuantizer::QT_4bit:
479
- return new QuantizerTemplate<Codec4bit, false , SIMDWIDTH_DEFAULT >(
591
+ return new QuantizerTemplate<Codec4bit, false , SIMDWIDTH >(
480
592
d, trained);
481
593
case ScalarQuantizer::QT_8bit_uniform:
482
- return new QuantizerTemplate<Codec8bit, true , SIMDWIDTH_DEFAULT >(
594
+ return new QuantizerTemplate<Codec8bit, true , SIMDWIDTH >(
483
595
d, trained);
484
596
case ScalarQuantizer::QT_4bit_uniform:
485
- return new QuantizerTemplate<Codec4bit, true , SIMDWIDTH_DEFAULT >(
597
+ return new QuantizerTemplate<Codec4bit, true , SIMDWIDTH >(
486
598
d, trained);
487
599
case ScalarQuantizer::QT_fp16:
488
600
return new QuantizerFP16<SIMDWIDTH>(d, trained);
489
601
case ScalarQuantizer::QT_8bit_direct:
490
- return new Quantizer8bitDirect<SIMDWIDTH_DEFAULT >(d, trained);
602
+ return new Quantizer8bitDirect<SIMDWIDTH >(d, trained);
491
603
}
492
604
FAISS_THROW_MSG (" unknown qtype" );
493
605
}
@@ -1186,62 +1298,108 @@ struct DistanceComputerByte<Similarity, 8> : SQDistanceComputer {
1186
1298
1187
1299
#endif
1188
1300
1301
+ #ifdef __aarch64__
1302
+
1303
+ template <class Similarity >
1304
+ struct DistanceComputerByte <Similarity, 8 > : SQDistanceComputer {
1305
+ using Sim = Similarity;
1306
+
1307
+ int d;
1308
+ std::vector<uint8_t > tmp;
1309
+
1310
+ DistanceComputerByte (int d, const std::vector<float >&) : d(d), tmp(d) {}
1311
+
1312
+ int compute_code_distance (const uint8_t * code1, const uint8_t * code2)
1313
+ const {
1314
+ int accu = 0 ;
1315
+ for (int i = 0 ; i < d; i++) {
1316
+ if (Sim::metric_type == METRIC_INNER_PRODUCT) {
1317
+ accu += int (code1[i]) * code2[i];
1318
+ } else {
1319
+ int diff = int (code1[i]) - code2[i];
1320
+ accu += diff * diff;
1321
+ }
1322
+ }
1323
+ return accu;
1324
+ }
1325
+
1326
+ void set_query (const float * x) final {
1327
+ for (int i = 0 ; i < d; i++) {
1328
+ tmp[i] = int (x[i]);
1329
+ }
1330
+ }
1331
+
1332
+ int compute_distance (const float * x, const uint8_t * code) {
1333
+ set_query (x);
1334
+ return compute_code_distance (tmp.data (), code);
1335
+ }
1336
+
1337
+ float symmetric_dis (idx_t i, idx_t j) override {
1338
+ return compute_code_distance (
1339
+ codes + i * code_size, codes + j * code_size);
1340
+ }
1341
+
1342
+ float query_to_code (const uint8_t * code) const final {
1343
+ return compute_code_distance (tmp.data (), code);
1344
+ }
1345
+ };
1346
+
1347
+ #endif
1348
+
1189
1349
/* ******************************************************************
1190
1350
* select_distance_computer: runtime selection of template
1191
1351
* specialization
1192
1352
*******************************************************************/
1193
1353
1194
- template <class Sim , class Sim_default >
1354
+ template <class Sim >
1195
1355
SQDistanceComputer* select_distance_computer (
1196
1356
QuantizerType qtype,
1197
1357
size_t d,
1198
1358
const std::vector<float >& trained) {
1199
1359
constexpr int SIMDWIDTH = Sim::simdwidth;
1200
- constexpr int SIMDWIDTH_DEFAULT = Sim_default::simdwidth;
1201
1360
switch (qtype) {
1202
1361
case ScalarQuantizer::QT_8bit_uniform:
1203
1362
return new DCTemplate<
1204
- QuantizerTemplate<Codec8bit, true , SIMDWIDTH_DEFAULT >,
1205
- Sim_default ,
1206
- SIMDWIDTH_DEFAULT >(d, trained);
1363
+ QuantizerTemplate<Codec8bit, true , SIMDWIDTH >,
1364
+ Sim ,
1365
+ SIMDWIDTH >(d, trained);
1207
1366
1208
1367
case ScalarQuantizer::QT_4bit_uniform:
1209
1368
return new DCTemplate<
1210
- QuantizerTemplate<Codec4bit, true , SIMDWIDTH_DEFAULT >,
1211
- Sim_default ,
1212
- SIMDWIDTH_DEFAULT >(d, trained);
1369
+ QuantizerTemplate<Codec4bit, true , SIMDWIDTH >,
1370
+ Sim ,
1371
+ SIMDWIDTH >(d, trained);
1213
1372
1214
1373
case ScalarQuantizer::QT_8bit:
1215
1374
return new DCTemplate<
1216
- QuantizerTemplate<Codec8bit, false , SIMDWIDTH_DEFAULT >,
1217
- Sim_default ,
1218
- SIMDWIDTH_DEFAULT >(d, trained);
1375
+ QuantizerTemplate<Codec8bit, false , SIMDWIDTH >,
1376
+ Sim ,
1377
+ SIMDWIDTH >(d, trained);
1219
1378
1220
1379
case ScalarQuantizer::QT_6bit:
1221
1380
return new DCTemplate<
1222
- QuantizerTemplate<Codec6bit, false , SIMDWIDTH_DEFAULT >,
1223
- Sim_default ,
1224
- SIMDWIDTH_DEFAULT >(d, trained);
1381
+ QuantizerTemplate<Codec6bit, false , SIMDWIDTH >,
1382
+ Sim ,
1383
+ SIMDWIDTH >(d, trained);
1225
1384
1226
1385
case ScalarQuantizer::QT_4bit:
1227
1386
return new DCTemplate<
1228
- QuantizerTemplate<Codec4bit, false , SIMDWIDTH_DEFAULT >,
1229
- Sim_default ,
1230
- SIMDWIDTH_DEFAULT >(d, trained);
1387
+ QuantizerTemplate<Codec4bit, false , SIMDWIDTH >,
1388
+ Sim ,
1389
+ SIMDWIDTH >(d, trained);
1231
1390
1232
1391
case ScalarQuantizer::QT_fp16:
1233
1392
return new DCTemplate<QuantizerFP16<SIMDWIDTH>, Sim, SIMDWIDTH>(
1234
1393
d, trained);
1235
1394
1236
1395
case ScalarQuantizer::QT_8bit_direct:
1237
1396
if (d % 16 == 0 ) {
1238
- return new DistanceComputerByte<Sim_default, SIMDWIDTH_DEFAULT>(
1239
- d, trained);
1397
+ return new DistanceComputerByte<Sim, SIMDWIDTH>(d, trained);
1240
1398
} else {
1241
1399
return new DCTemplate<
1242
- Quantizer8bitDirect<SIMDWIDTH_DEFAULT >,
1243
- Sim_default ,
1244
- SIMDWIDTH_DEFAULT >(d, trained);
1400
+ Quantizer8bitDirect<SIMDWIDTH >,
1401
+ Sim ,
1402
+ SIMDWIDTH >(d, trained);
1245
1403
}
1246
1404
}
1247
1405
FAISS_THROW_MSG (" unknown qtype" );
@@ -1324,14 +1482,13 @@ void ScalarQuantizer::train(size_t n, const float* x) {
1324
1482
}
1325
1483
1326
1484
/// Pick the quantizer implementation: the 8-wide variant when SIMD support
/// is compiled in (USE_F16C or aarch64) and d is a multiple of 8, otherwise
/// the width-1 variant.
ScalarQuantizer::SQuantizer* ScalarQuantizer::select_quantizer() const {
#if defined(USE_F16C) || defined(__aarch64__)
    if (d % 8 == 0) {
        return select_quantizer_1<8>(qtype, d, trained);
    }
#endif
    // Width-1 fallback: no SIMD build, or d is not a multiple of 8.
    return select_quantizer_1<1>(qtype, d, trained);
}
1337
1494
@@ -1356,31 +1513,20 @@ void ScalarQuantizer::decode(const uint8_t* codes, float* x, size_t n) const {
1356
1513
/// Build a distance computer for the given metric. Only METRIC_L2 and
/// METRIC_INNER_PRODUCT pass the FAISS_THROW_IF_NOT check. The 8-wide
/// similarity is used when SIMD support is compiled in (USE_F16C or
/// aarch64) and d is a multiple of 8; otherwise the width-1 one.
SQDistanceComputer* ScalarQuantizer::get_distance_computer(
        MetricType metric) const {
    FAISS_THROW_IF_NOT(metric == METRIC_L2 || metric == METRIC_INNER_PRODUCT);
    const bool is_l2 = (metric == METRIC_L2);
#if defined(USE_F16C) || defined(__aarch64__)
    if (d % 8 == 0) {
        return is_l2
                ? select_distance_computer<SimilarityL2<8>>(qtype, d, trained)
                : select_distance_computer<SimilarityIP<8>>(qtype, d, trained);
    }
#endif
    // Width-1 fallback: no SIMD build, or d is not a multiple of 8.
    return is_l2
            ? select_distance_computer<SimilarityL2<1>>(qtype, d, trained)
            : select_distance_computer<SimilarityIP<1>>(qtype, d, trained);
}
@@ -1702,7 +1848,7 @@ InvertedListScanner* ScalarQuantizer::select_InvertedListScanner(
1702
1848
bool store_pairs,
1703
1849
const IDSelector* sel,
1704
1850
bool by_residual) const {
1705
- #ifdef USE_F16C
1851
+ #if defined( USE_F16C) || defined(__aarch64__)
1706
1852
if (d % 8 == 0 ) {
1707
1853
return sel0_InvertedListScanner<8 >(
1708
1854
mt, this , quantizer, store_pairs, sel, by_residual);
0 commit comments