@@ -409,18 +409,22 @@ void search_neighbors_to_add(
409
409
**************************************************************/
410
410
411
411
// / greedily update a nearest vector at a given level
412
- void greedy_update_nearest (
412
+ HNSWStats greedy_update_nearest (
413
413
const HNSW& hnsw,
414
414
DistanceComputer& qdis,
415
415
int level,
416
416
storage_idx_t & nearest,
417
417
float & d_nearest) {
418
+ HNSWStats stats;
419
+
418
420
for (;;) {
419
421
storage_idx_t prev_nearest = nearest;
420
422
421
423
size_t begin, end;
422
424
hnsw.neighbor_range (nearest, level, &begin, &end);
423
- for (size_t i = begin; i < end; i++) {
425
+
426
+ size_t ndis = 0 ;
427
+ for (size_t i = begin; i < end; i++, ndis++) {
424
428
storage_idx_t v = hnsw.neighbors [i];
425
429
if (v < 0 )
426
430
break ;
@@ -430,8 +434,13 @@ void greedy_update_nearest(
430
434
d_nearest = dis;
431
435
}
432
436
}
437
+
438
+ // update stats
439
+ stats.ndis += ndis;
440
+ stats.nhops += 1 ;
441
+
433
442
if (nearest == prev_nearest) {
434
- return ;
443
+ return stats ;
435
444
}
436
445
}
437
446
}
@@ -641,6 +650,7 @@ int search_from_candidates(
641
650
if (dis < threshold) {
642
651
if (res.add_result (dis, idx)) {
643
652
threshold = res.threshold ;
653
+ nres += 1 ;
644
654
}
645
655
}
646
656
}
@@ -692,6 +702,7 @@ int search_from_candidates(
692
702
stats.n2 ++;
693
703
}
694
704
stats.ndis += ndis;
705
+ stats.nhops += nstep;
695
706
}
696
707
697
708
return nres;
@@ -814,6 +825,8 @@ std::priority_queue<HNSW::Node> search_from_candidate_unbounded(
814
825
float dis = qdis (saved_j[icnt]);
815
826
add_to_heap (saved_j[icnt], dis);
816
827
}
828
+
829
+ stats.nhops += 1 ;
817
830
}
818
831
819
832
++stats.n1 ;
@@ -853,7 +866,9 @@ HNSWStats HNSW::search(
853
866
float d_nearest = qdis (nearest);
854
867
855
868
for (int level = max_level; level >= 1 ; level--) {
856
- greedy_update_nearest (*this , qdis, level, nearest, d_nearest);
869
+ HNSWStats local_stats = greedy_update_nearest (
870
+ *this , qdis, level, nearest, d_nearest);
871
+ stats.combine (local_stats);
857
872
}
858
873
859
874
int ef = std::max (params ? params->efSearch : efSearch, k);
@@ -916,11 +931,23 @@ HNSWStats HNSW::search(
916
931
if (level == 0 ) {
917
932
nres = search_from_candidates (
918
933
*this , qdis, res, candidates, vt, stats, 0 );
934
+ nres = std::min (nres, candidates_size);
919
935
} else {
936
+ const auto nres_prev = nres;
937
+
920
938
resh.begin (0 );
921
939
nres = search_from_candidates (
922
940
*this , qdis, resh, candidates, vt, stats, level);
941
+ nres = std::min (nres, candidates_size);
923
942
resh.end ();
943
+
944
+ // if the search on a particular level produces no improvements,
945
+ // then we need to repopulate candidates.
946
+ // search_from_candidates() will always damage candidates
947
+ // by doing 1 pop_min().
948
+ if (nres == 0 ) {
949
+ nres = nres_prev;
950
+ }
924
951
}
925
952
vt.advance ();
926
953
}
@@ -970,6 +997,7 @@ void HNSW::search_level_0(
970
997
0 ,
971
998
nres,
972
999
params);
1000
+ nres = std::min (nres, candidates_size);
973
1001
}
974
1002
} else if (search_type == 2 ) {
975
1003
int candidates_size = std::max (efSearch, int (k));
@@ -1051,7 +1079,99 @@ void HNSW::MinimaxHeap::clear() {
1051
1079
nvalid = k = 0 ;
1052
1080
}
1053
1081
1054
- #ifdef __AVX2__
1082
+ #ifdef __AVX512F__
1083
+
1084
+ int HNSW::MinimaxHeap::pop_min (float * vmin_out) {
1085
+ assert (k > 0 );
1086
+ static_assert (
1087
+ std::is_same<storage_idx_t , int32_t >::value,
1088
+ " This code expects storage_idx_t to be int32_t" );
1089
+
1090
+ int32_t min_idx = -1 ;
1091
+ float min_dis = std::numeric_limits<float >::infinity ();
1092
+
1093
+ __m512i min_indices = _mm512_set1_epi32 (-1 );
1094
+ __m512 min_distances =
1095
+ _mm512_set1_ps (std::numeric_limits<float >::infinity ());
1096
+ __m512i current_indices = _mm512_setr_epi32 (
1097
+ 0 , 1 , 2 , 3 , 4 , 5 , 6 , 7 , 8 , 9 , 10 , 11 , 12 , 13 , 14 , 15 );
1098
+ __m512i offset = _mm512_set1_epi32 (16 );
1099
+
1100
+ // The following loop tracks the rightmost index with the min distance.
1101
+ // -1 index values are ignored.
1102
+ const int k16 = (k / 16 ) * 16 ;
1103
+ for (size_t iii = 0 ; iii < k16; iii += 16 ) {
1104
+ __m512i indices =
1105
+ _mm512_loadu_si512 ((const __m512i*)(ids.data () + iii));
1106
+ __m512 distances = _mm512_loadu_ps (dis.data () + iii);
1107
+
1108
+ // This mask filters out -1 values among indices.
1109
+ __mmask16 m1mask =
1110
+ _mm512_cmpgt_epi32_mask (_mm512_setzero_si512 (), indices);
1111
+
1112
+ __mmask16 dmask =
1113
+ _mm512_cmp_ps_mask (min_distances, distances, _CMP_LT_OS);
1114
+ __mmask16 finalmask = m1mask | dmask;
1115
+
1116
+ const __m512i min_indices_new = _mm512_mask_blend_epi32 (
1117
+ finalmask, current_indices, min_indices);
1118
+ const __m512 min_distances_new =
1119
+ _mm512_mask_blend_ps (finalmask, distances, min_distances);
1120
+
1121
+ min_indices = min_indices_new;
1122
+ min_distances = min_distances_new;
1123
+
1124
+ current_indices = _mm512_add_epi32 (current_indices, offset);
1125
+ }
1126
+
1127
+ // leftovers
1128
+ if (k16 != k) {
1129
+ const __mmask16 kmask = (1 << (k - k16)) - 1 ;
1130
+
1131
+ __m512i indices = _mm512_mask_loadu_epi32 (
1132
+ _mm512_set1_epi32 (-1 ), kmask, ids.data () + k16);
1133
+ __m512 distances = _mm512_maskz_loadu_ps (kmask, dis.data () + k16);
1134
+
1135
+ // This mask filters out -1 values among indices.
1136
+ __mmask16 m1mask =
1137
+ _mm512_cmpgt_epi32_mask (_mm512_setzero_si512 (), indices);
1138
+
1139
+ __mmask16 dmask =
1140
+ _mm512_cmp_ps_mask (min_distances, distances, _CMP_LT_OS);
1141
+ __mmask16 finalmask = m1mask | dmask;
1142
+
1143
+ const __m512i min_indices_new = _mm512_mask_blend_epi32 (
1144
+ finalmask, current_indices, min_indices);
1145
+ const __m512 min_distances_new =
1146
+ _mm512_mask_blend_ps (finalmask, distances, min_distances);
1147
+
1148
+ min_indices = min_indices_new;
1149
+ min_distances = min_distances_new;
1150
+ }
1151
+
1152
+ // grab min distance
1153
+ min_dis = _mm512_reduce_min_ps (min_distances);
1154
+ // blend
1155
+ __mmask16 mindmask =
1156
+ _mm512_cmpeq_ps_mask (min_distances, _mm512_set1_ps (min_dis));
1157
+ // pick the max one
1158
+ min_idx = _mm512_mask_reduce_max_epi32 (mindmask, min_indices);
1159
+
1160
+ if (min_idx == -1 ) {
1161
+ return -1 ;
1162
+ }
1163
+
1164
+ if (vmin_out) {
1165
+ *vmin_out = min_dis;
1166
+ }
1167
+ int ret = ids[min_idx];
1168
+ ids[min_idx] = -1 ;
1169
+ --nvalid;
1170
+ return ret;
1171
+ }
1172
+
1173
+ #elif __AVX2__
1174
+
1055
1175
int HNSW::MinimaxHeap::pop_min (float * vmin_out) {
1056
1176
assert (k > 0 );
1057
1177
static_assert (
0 commit comments