Skip to content

Commit 9c88422

Browse files
wx257osn2facebook-github-bot
authored andcommitted
Some changes to simdlib (facebookresearch#2885)
Summary: - Use elementwise operation and reduction once instead of across-vector comparing operation twice - Use already implemented supporting functions - Unify semantics of `operator==` as same as `simd16uint16` - `operator==` of `simd8uint32` and `simd8float32` had been implemented on facebookresearch#2568, but these has not same semantics as `simd16uint16` (which had been implemented in a long time ago). For getting the vector equality as `bool` , now we should use `is_same_as` member function. - Change `is_same_as` to accept any vector type as argument for `simdlib_neon` - `is_same_as` has supported any vector type on `simdlib_avx2` and `simdlib_emulated` already - Remove unused function `simd16uint16::is_same` on `simdlib_avx2` - Is it typo of `is_same_as` ? Anyway it seems to be used unlikely Pull Request resolved: facebookresearch#2885 Reviewed By: mdouze Differential Revision: D46330666 Pulled By: alexanderguzhva fbshipit-source-id: 0ea14f8e9a8bda78f24a655219dffe3e07fc110f
1 parent bbc95b1 commit 9c88422

File tree

2 files changed

+72
-83
lines changed

2 files changed

+72
-83
lines changed

faiss/utils/simdlib_avx2.h

-6
Original file line numberDiff line numberDiff line change
@@ -202,12 +202,6 @@ struct simd16uint16 : simd256bit {
202202
return simd16uint16(_mm256_cmpeq_epi16(lhs.i, rhs.i));
203203
}
204204

205-
bool is_same(simd16uint16 other) const {
206-
const __m256i pcmp = _mm256_cmpeq_epi16(i, other.i);
207-
unsigned bitmask = _mm256_movemask_epi8(pcmp);
208-
return (bitmask == 0xffffffffU);
209-
}
210-
211205
simd16uint16 operator~() const {
212206
return simd16uint16(_mm256_xor_si256(i, _mm256_set1_epi32(-1)));
213207
}

faiss/utils/simdlib_neon.h

+72-77
Original file line numberDiff line numberDiff line change
@@ -559,15 +559,13 @@ struct simd16uint16 {
559559
}
560560

561561
// Checks whether the other holds exactly the same bytes.
562-
bool is_same_as(simd16uint16 other) const {
563-
const bool equal0 =
564-
(vminvq_u16(vceqq_u16(data.val[0], other.data.val[0])) ==
565-
0xffff);
566-
const bool equal1 =
567-
(vminvq_u16(vceqq_u16(data.val[1], other.data.val[1])) ==
568-
0xffff);
569-
570-
return equal0 && equal1;
562+
template <typename T>
563+
bool is_same_as(T other) const {
564+
const auto o = detail::simdlib::reinterpret_u16(other.data);
565+
const auto equals = detail::simdlib::binary_func(data, o)
566+
.template call<&vceqq_u16>();
567+
const auto equal = vandq_u16(equals.val[0], equals.val[1]);
568+
return vminvq_u16(equal) == 0xffffu;
571569
}
572570

573571
simd16uint16 operator~() const {
@@ -689,13 +687,12 @@ inline void cmplt_min_max_fast(
689687
simd16uint16& minIndices,
690688
simd16uint16& maxValues,
691689
simd16uint16& maxIndices) {
692-
const uint16x8x2_t comparison = uint16x8x2_t{
693-
vcltq_u16(candidateValues.data.val[0], currentValues.data.val[0]),
694-
vcltq_u16(candidateValues.data.val[1], currentValues.data.val[1])};
690+
const uint16x8x2_t comparison =
691+
detail::simdlib::binary_func(
692+
candidateValues.data, currentValues.data)
693+
.call<&vcltq_u16>();
695694

696-
minValues.data = uint16x8x2_t{
697-
vminq_u16(candidateValues.data.val[0], currentValues.data.val[0]),
698-
vminq_u16(candidateValues.data.val[1], currentValues.data.val[1])};
695+
minValues = min(candidateValues, currentValues);
699696
minIndices.data = uint16x8x2_t{
700697
vbslq_u16(
701698
comparison.val[0],
@@ -706,9 +703,7 @@ inline void cmplt_min_max_fast(
706703
candidateIndices.data.val[1],
707704
currentIndices.data.val[1])};
708705

709-
maxValues.data = uint16x8x2_t{
710-
vmaxq_u16(candidateValues.data.val[0], currentValues.data.val[0]),
711-
vmaxq_u16(candidateValues.data.val[1], currentValues.data.val[1])};
706+
maxValues = max(candidateValues, currentValues);
712707
maxIndices.data = uint16x8x2_t{
713708
vbslq_u16(
714709
comparison.val[0],
@@ -869,13 +864,13 @@ struct simd32uint8 {
869864
}
870865

871866
// Checks whether the other holds exactly the same bytes.
872-
bool is_same_as(simd32uint8 other) const {
873-
const bool equal0 =
874-
(vminvq_u8(vceqq_u8(data.val[0], other.data.val[0])) == 0xff);
875-
const bool equal1 =
876-
(vminvq_u8(vceqq_u8(data.val[1], other.data.val[1])) == 0xff);
877-
878-
return equal0 && equal1;
867+
template <typename T>
868+
bool is_same_as(T other) const {
869+
const auto o = detail::simdlib::reinterpret_u8(other.data);
870+
const auto equals = detail::simdlib::binary_func(data, o)
871+
.template call<&vceqq_u8>();
872+
const auto equal = vandq_u8(equals.val[0], equals.val[1]);
873+
return vminvq_u8(equal) == 0xffu;
879874
}
880875
};
881876

@@ -960,27 +955,28 @@ struct simd8uint32 {
960955
return *this;
961956
}
962957

963-
bool operator==(simd8uint32 other) const {
964-
const auto equals = detail::simdlib::binary_func(data, other.data)
965-
.call<&vceqq_u32>();
966-
const auto equal = vandq_u32(equals.val[0], equals.val[1]);
967-
return vminvq_u32(equal) == 0xffffffff;
958+
simd8uint32 operator==(simd8uint32 other) const {
959+
return simd8uint32{detail::simdlib::binary_func(data, other.data)
960+
.call<&vceqq_u32>()};
968961
}
969962

970-
bool operator!=(simd8uint32 other) const {
971-
return !(*this == other);
963+
simd8uint32 operator~() const {
964+
return simd8uint32{
965+
detail::simdlib::unary_func(data).call<&vmvnq_u32>()};
972966
}
973967

974-
// Checks whether the other holds exactly the same bytes.
975-
bool is_same_as(simd8uint32 other) const {
976-
const bool equal0 =
977-
(vminvq_u32(vceqq_u32(data.val[0], other.data.val[0])) ==
978-
0xffffffff);
979-
const bool equal1 =
980-
(vminvq_u32(vceqq_u32(data.val[1], other.data.val[1])) ==
981-
0xffffffff);
968+
simd8uint32 operator!=(simd8uint32 other) const {
969+
return ~(*this == other);
970+
}
982971

983-
return equal0 && equal1;
972+
// Checks whether the other holds exactly the same bytes.
973+
template <typename T>
974+
bool is_same_as(T other) const {
975+
const auto o = detail::simdlib::reinterpret_u32(other.data);
976+
const auto equals = detail::simdlib::binary_func(data, o)
977+
.template call<&vceqq_u32>();
978+
const auto equal = vandq_u32(equals.val[0], equals.val[1]);
979+
return vminvq_u32(equal) == 0xffffffffu;
984980
}
985981

986982
void clear() {
@@ -1053,13 +1049,14 @@ inline void cmplt_min_max_fast(
10531049
simd8uint32& minIndices,
10541050
simd8uint32& maxValues,
10551051
simd8uint32& maxIndices) {
1056-
const uint32x4x2_t comparison = uint32x4x2_t{
1057-
vcltq_u32(candidateValues.data.val[0], currentValues.data.val[0]),
1058-
vcltq_u32(candidateValues.data.val[1], currentValues.data.val[1])};
1059-
1060-
minValues.data = uint32x4x2_t{
1061-
vminq_u32(candidateValues.data.val[0], currentValues.data.val[0]),
1062-
vminq_u32(candidateValues.data.val[1], currentValues.data.val[1])};
1052+
const uint32x4x2_t comparison =
1053+
detail::simdlib::binary_func(
1054+
candidateValues.data, currentValues.data)
1055+
.call<&vcltq_u32>();
1056+
1057+
minValues.data = detail::simdlib::binary_func(
1058+
candidateValues.data, currentValues.data)
1059+
.call<&vminq_u32>();
10631060
minIndices.data = uint32x4x2_t{
10641061
vbslq_u32(
10651062
comparison.val[0],
@@ -1070,9 +1067,9 @@ inline void cmplt_min_max_fast(
10701067
candidateIndices.data.val[1],
10711068
currentIndices.data.val[1])};
10721069

1073-
maxValues.data = uint32x4x2_t{
1074-
vmaxq_u32(candidateValues.data.val[0], currentValues.data.val[0]),
1075-
vmaxq_u32(candidateValues.data.val[1], currentValues.data.val[1])};
1070+
maxValues.data = detail::simdlib::binary_func(
1071+
candidateValues.data, currentValues.data)
1072+
.call<&vmaxq_u32>();
10761073
maxIndices.data = uint32x4x2_t{
10771074
vbslq_u32(
10781075
comparison.val[0],
@@ -1167,28 +1164,25 @@ struct simd8float32 {
11671164
return *this;
11681165
}
11691166

1170-
bool operator==(simd8float32 other) const {
1171-
const auto equals =
1167+
simd8uint32 operator==(simd8float32 other) const {
1168+
return simd8uint32{
11721169
detail::simdlib::binary_func<::uint32x4x2_t>(data, other.data)
1173-
.call<&vceqq_f32>();
1174-
const auto equal = vandq_u32(equals.val[0], equals.val[1]);
1175-
return vminvq_u32(equal) == 0xffffffff;
1170+
.call<&vceqq_f32>()};
11761171
}
11771172

1178-
bool operator!=(simd8float32 other) const {
1179-
return !(*this == other);
1173+
simd8uint32 operator!=(simd8float32 other) const {
1174+
return ~(*this == other);
11801175
}
11811176

11821177
// Checks whether the other holds exactly the same bytes.
1183-
bool is_same_as(simd8float32 other) const {
1184-
const bool equal0 =
1185-
(vminvq_u32(vceqq_f32(data.val[0], other.data.val[0])) ==
1186-
0xffffffff);
1187-
const bool equal1 =
1188-
(vminvq_u32(vceqq_f32(data.val[1], other.data.val[1])) ==
1189-
0xffffffff);
1190-
1191-
return equal0 && equal1;
1178+
template <typename T>
1179+
bool is_same_as(T other) const {
1180+
const auto o = detail::simdlib::reinterpret_f32(other.data);
1181+
const auto equals =
1182+
detail::simdlib::binary_func<::uint32x4x2_t>(data, o)
1183+
.template call<&vceqq_f32>();
1184+
const auto equal = vandq_u32(equals.val[0], equals.val[1]);
1185+
return vminvq_u32(equal) == 0xffffffffu;
11921186
}
11931187

11941188
std::string tostring() const {
@@ -1302,13 +1296,14 @@ inline void cmplt_min_max_fast(
13021296
simd8uint32& minIndices,
13031297
simd8float32& maxValues,
13041298
simd8uint32& maxIndices) {
1305-
const uint32x4x2_t comparison = uint32x4x2_t{
1306-
vcltq_f32(candidateValues.data.val[0], currentValues.data.val[0]),
1307-
vcltq_f32(candidateValues.data.val[1], currentValues.data.val[1])};
1308-
1309-
minValues.data = float32x4x2_t{
1310-
vminq_f32(candidateValues.data.val[0], currentValues.data.val[0]),
1311-
vminq_f32(candidateValues.data.val[1], currentValues.data.val[1])};
1299+
const uint32x4x2_t comparison =
1300+
detail::simdlib::binary_func<::uint32x4x2_t>(
1301+
candidateValues.data, currentValues.data)
1302+
.call<&vcltq_f32>();
1303+
1304+
minValues.data = detail::simdlib::binary_func(
1305+
candidateValues.data, currentValues.data)
1306+
.call<&vminq_f32>();
13121307
minIndices.data = uint32x4x2_t{
13131308
vbslq_u32(
13141309
comparison.val[0],
@@ -1319,9 +1314,9 @@ inline void cmplt_min_max_fast(
13191314
candidateIndices.data.val[1],
13201315
currentIndices.data.val[1])};
13211316

1322-
maxValues.data = float32x4x2_t{
1323-
vmaxq_f32(candidateValues.data.val[0], currentValues.data.val[0]),
1324-
vmaxq_f32(candidateValues.data.val[1], currentValues.data.val[1])};
1317+
maxValues.data = detail::simdlib::binary_func(
1318+
candidateValues.data, currentValues.data)
1319+
.call<&vmaxq_f32>();
13251320
maxIndices.data = uint32x4x2_t{
13261321
vbslq_u32(
13271322
comparison.val[0],

0 commit comments

Comments
 (0)