diff --git a/faiss/IndexHNSW.cpp b/faiss/IndexHNSW.cpp index e18e203208..6a1186ca6a 100644 --- a/faiss/IndexHNSW.cpp +++ b/faiss/IndexHNSW.cpp @@ -255,7 +255,7 @@ void hnsw_search( FAISS_THROW_IF_NOT_MSG(params, "params type invalid"); efSearch = params->efSearch; } - size_t n1 = 0, n2 = 0, ndis = 0; + size_t n1 = 0, n2 = 0, ndis = 0, nhops = 0; idx_t check_period = InterruptCallback::get_period_hint( hnsw.max_level * index->d * efSearch); @@ -263,7 +263,7 @@ void hnsw_search( for (idx_t i0 = 0; i0 < n; i0 += check_period) { idx_t i1 = std::min(i0 + check_period, n); -#pragma omp parallel +#pragma omp parallel if (i1 - i0 > 1) { VisitedTable vt(index->ntotal); typename BlockResultHandler::SingleResultHandler res(bres); @@ -271,7 +271,7 @@ void hnsw_search( std::unique_ptr dis( storage_distance_computer(index->storage)); -#pragma omp for reduction(+ : n1, n2, ndis) schedule(guided) +#pragma omp for reduction(+ : n1, n2, ndis, nhops) schedule(guided) for (idx_t i = i0; i < i1; i++) { res.begin(i); dis->set_query(x + i * index->d); @@ -280,13 +280,14 @@ void hnsw_search( n1 += stats.n1; n2 += stats.n2; ndis += stats.ndis; + nhops += stats.nhops; res.end(); } } InterruptCallback::check(); } - hnsw_stats.combine({n1, n2, ndis}); + hnsw_stats.combine({n1, n2, ndis, nhops}); } } // anonymous namespace @@ -612,6 +613,10 @@ void IndexHNSW::permute_entries(const idx_t* perm) { hnsw.permute_entries(perm); } +DistanceComputer* IndexHNSW::get_distance_computer() const { + return storage->get_distance_computer(); +} + /************************************************************** * IndexHNSWFlat implementation **************************************************************/ @@ -635,8 +640,13 @@ IndexHNSWFlat::IndexHNSWFlat(int d, int M, MetricType metric) IndexHNSWPQ::IndexHNSWPQ() = default; -IndexHNSWPQ::IndexHNSWPQ(int d, int pq_m, int M, int pq_nbits) - : IndexHNSW(new IndexPQ(d, pq_m, pq_nbits), M) { +IndexHNSWPQ::IndexHNSWPQ( + int d, + int pq_m, + int M, + int pq_nbits, + MetricType metric) + : IndexHNSW(new IndexPQ(d, pq_m, pq_nbits, metric), M) { own_fields = true; is_trained = false; } @@ -762,7 +772,7 @@ void IndexHNSW2Level::search( IndexHNSW::search(n, x, k, distances, labels); } else { // "mixed" search - size_t n1 = 0, n2 = 0, ndis = 0; + size_t n1 = 0, n2 = 0, ndis = 0, nhops = 0; const IndexIVFPQ* index_ivfpq = dynamic_cast(storage); @@ -791,10 +801,10 @@ void IndexHNSW2Level::search( std::unique_ptr dis( storage_distance_computer(storage)); - int candidates_size = hnsw.upper_beam; + constexpr int candidates_size = 1; MinimaxHeap candidates(candidates_size); -#pragma omp for reduction(+ : n1, n2, ndis) +#pragma omp for reduction(+ : n1, n2, ndis, nhops) for (idx_t i = 0; i < n; i++) { idx_t* idxi = labels + i * k; float* simi = distances + i * k; @@ -816,7 +826,7 @@ void IndexHNSW2Level::search( candidates.clear(); - for (int j = 0; j < hnsw.upper_beam && j < k; j++) { + for (int j = 0; j < k; j++) { if (idxi[j] < 0) break; candidates.push(idxi[j], simi[j]); @@ -840,6 +850,7 @@ void IndexHNSW2Level::search( n1 += search_stats.n1; n2 += search_stats.n2; ndis += search_stats.ndis; + nhops += search_stats.nhops; vt.advance(); vt.advance(); @@ -848,7 +859,7 @@ void IndexHNSW2Level::search( } } - hnsw_stats.combine({n1, n2, ndis}); + hnsw_stats.combine({n1, n2, ndis, nhops}); } } diff --git a/faiss/IndexHNSW.h b/faiss/IndexHNSW.h index 71807c6537..0768eb88b9 100644 --- a/faiss/IndexHNSW.h +++ b/faiss/IndexHNSW.h @@ -27,7 +27,7 @@ struct IndexHNSW; struct IndexHNSW : Index { typedef HNSW::storage_idx_t storage_idx_t; - // the link strcuture + // the link structure HNSW hnsw; // the sequential storage @@ -111,6 +111,8 @@ struct IndexHNSW : Index { void link_singletons(); void permute_entries(const idx_t* perm); + + DistanceComputer* get_distance_computer() const override; }; /** Flat index topped with with a HNSW structure to access elements @@ -127,7 +129,12 @@ struct IndexHNSWFlat : IndexHNSW { */ struct IndexHNSWPQ : IndexHNSW { IndexHNSWPQ(); - IndexHNSWPQ(int d, int pq_m, int M, int pq_nbits = 8); + IndexHNSWPQ( + int d, + int pq_m, + int M, + int pq_nbits = 8, + MetricType metric = METRIC_L2); void train(idx_t n, const float* x) override; }; diff --git a/faiss/IndexRefine.cpp b/faiss/IndexRefine.cpp index 8fb0ea80bb..4f1d34d5bf 100644 --- a/faiss/IndexRefine.cpp +++ b/faiss/IndexRefine.cpp @@ -68,12 +68,12 @@ template static void reorder_2_heaps( idx_t n, idx_t k, - idx_t* labels, - float* distances, + idx_t* __restrict labels, + float* __restrict distances, idx_t k_base, - const idx_t* base_labels, - const float* base_distances) { -#pragma omp parallel for + const idx_t* __restrict base_labels, + const float* __restrict base_distances) { +#pragma omp parallel for if (n > 1) for (idx_t i = 0; i < n; i++) { idx_t* idxo = labels + i * k; float* diso = distances + i * k; diff --git a/faiss/impl/HNSW.cpp b/faiss/impl/HNSW.cpp index 277e194b3e..3eb2f5a76b 100644 --- a/faiss/impl/HNSW.cpp +++ b/faiss/impl/HNSW.cpp @@ -409,18 +409,22 @@ void search_neighbors_to_add( **************************************************************/ /// greedily update a nearest vector at a given level -void greedy_update_nearest( +HNSWStats greedy_update_nearest( const HNSW& hnsw, DistanceComputer& qdis, int level, storage_idx_t& nearest, float& d_nearest) { + HNSWStats stats; + for (;;) { storage_idx_t prev_nearest = nearest; size_t begin, end; hnsw.neighbor_range(nearest, level, &begin, &end); - for (size_t i = begin; i < end; i++) { + + size_t ndis = 0; + for (size_t i = begin; i < end; i++, ndis++) { storage_idx_t v = hnsw.neighbors[i]; if (v < 0) break; @@ -430,8 +434,13 @@ void greedy_update_nearest( d_nearest = dis; } } + + // update stats + stats.ndis += ndis; + stats.nhops += 1; + if (nearest == prev_nearest) { - return; + return stats; } } } @@ -641,6 +650,7 @@ int search_from_candidates( if (dis < threshold) { if (res.add_result(dis, idx)) { threshold = res.threshold; + nres += 1; } } } @@ -692,6 +702,7 @@ int search_from_candidates( stats.n2++; } stats.ndis += ndis; + stats.nhops += nstep; } return nres; @@ -814,6 +825,8 @@ std::priority_queue search_from_candidate_unbounded( float dis = qdis(saved_j[icnt]); add_to_heap(saved_j[icnt], dis); } + + stats.nhops += 1; } ++stats.n1; @@ -850,85 +863,44 @@ HNSWStats HNSW::search( bool bounded_queue = params ? params->bounded_queue : this->search_bounded_queue; - if (upper_beam == 1) { - // greedy search on upper levels - storage_idx_t nearest = entry_point; - float d_nearest = qdis(nearest); - - for (int level = max_level; level >= 1; level--) { - greedy_update_nearest(*this, qdis, level, nearest, d_nearest); - } + // greedy search on upper levels + storage_idx_t nearest = entry_point; + float d_nearest = qdis(nearest); - int ef = std::max(params ? params->efSearch : efSearch, k); - if (bounded_queue) { // this is the most common branch - MinimaxHeap candidates(ef); - - candidates.push(nearest, d_nearest); - - search_from_candidates( - *this, qdis, res, candidates, vt, stats, 0, 0, params); - } else { - std::priority_queue top_candidates = - search_from_candidate_unbounded( - *this, - Node(d_nearest, nearest), - qdis, - ef, - &vt, - stats); - - while (top_candidates.size() > k) { - top_candidates.pop(); - } + for (int level = max_level; level >= 1; level--) { + HNSWStats local_stats = + greedy_update_nearest(*this, qdis, level, nearest, d_nearest); + stats.combine(local_stats); + } - while (!top_candidates.empty()) { - float d; - storage_idx_t label; - std::tie(d, label) = top_candidates.top(); - res.add_result(d, label); - top_candidates.pop(); - } - } + int ef = std::max(params ? params->efSearch : efSearch, k); + if (bounded_queue) { // this is the most common branch + MinimaxHeap candidates(ef); - vt.advance(); + candidates.push(nearest, d_nearest); + search_from_candidates( + *this, qdis, res, candidates, vt, stats, 0, 0, params); } else { - int candidates_size = upper_beam; - MinimaxHeap candidates(candidates_size); - - std::vector I_to_next(candidates_size); - std::vector D_to_next(candidates_size); - - HeapBlockResultHandler block_resh( - 1, D_to_next.data(), I_to_next.data(), candidates_size); - HeapBlockResultHandler::SingleResultHandler resh(block_resh); + std::priority_queue top_candidates = + search_from_candidate_unbounded( + *this, Node(d_nearest, nearest), qdis, ef, &vt, stats); - int nres = 1; - I_to_next[0] = entry_point; - D_to_next[0] = qdis(entry_point); - - for (int level = max_level; level >= 0; level--) { - // copy I, D -> candidates - - candidates.clear(); - - for (int i = 0; i < nres; i++) { - candidates.push(I_to_next[i], D_to_next[i]); - } + while (top_candidates.size() > k) { + top_candidates.pop(); + } - if (level == 0) { - nres = search_from_candidates( - *this, qdis, res, candidates, vt, stats, 0); - } else { - resh.begin(0); - nres = search_from_candidates( - *this, qdis, resh, candidates, vt, stats, level); - resh.end(); - } - vt.advance(); + while (!top_candidates.empty()) { + float d; + storage_idx_t label; + std::tie(d, label) = top_candidates.top(); + res.add_result(d, label); + top_candidates.pop(); } } + vt.advance(); + return stats; } @@ -973,6 +945,7 @@ void HNSW::search_level_0( 0, nres, params); + nres = std::min(nres, candidates_size); } } else if (search_type == 2) { int candidates_size = std::max(efSearch, int(k)); @@ -1054,7 +1027,99 @@ void HNSW::MinimaxHeap::clear() { nvalid = k = 0; } -#ifdef __AVX2__ +#ifdef __AVX512F__ + +int HNSW::MinimaxHeap::pop_min(float* vmin_out) { + assert(k > 0); + static_assert( + std::is_same::value, + "This code expects storage_idx_t to be int32_t"); + + int32_t min_idx = -1; + float min_dis = std::numeric_limits::infinity(); + + __m512i min_indices = _mm512_set1_epi32(-1); + __m512 min_distances = + _mm512_set1_ps(std::numeric_limits::infinity()); + __m512i current_indices = _mm512_setr_epi32( + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15); + __m512i offset = _mm512_set1_epi32(16); + + // The following loop tracks the rightmost index with the min distance. + // -1 index values are ignored. + const int k16 = (k / 16) * 16; + for (size_t iii = 0; iii < k16; iii += 16) { + __m512i indices = + _mm512_loadu_si512((const __m512i*)(ids.data() + iii)); + __m512 distances = _mm512_loadu_ps(dis.data() + iii); + + // This mask filters out -1 values among indices. + __mmask16 m1mask = + _mm512_cmpgt_epi32_mask(_mm512_setzero_si512(), indices); + + __mmask16 dmask = + _mm512_cmp_ps_mask(min_distances, distances, _CMP_LT_OS); + __mmask16 finalmask = m1mask | dmask; + + const __m512i min_indices_new = _mm512_mask_blend_epi32( + finalmask, current_indices, min_indices); + const __m512 min_distances_new = + _mm512_mask_blend_ps(finalmask, distances, min_distances); + + min_indices = min_indices_new; + min_distances = min_distances_new; + + current_indices = _mm512_add_epi32(current_indices, offset); + } + + // leftovers + if (k16 != k) { + const __mmask16 kmask = (1 << (k - k16)) - 1; + + __m512i indices = _mm512_mask_loadu_epi32( + _mm512_set1_epi32(-1), kmask, ids.data() + k16); + __m512 distances = _mm512_maskz_loadu_ps(kmask, dis.data() + k16); + + // This mask filters out -1 values among indices. + __mmask16 m1mask = + _mm512_cmpgt_epi32_mask(_mm512_setzero_si512(), indices); + + __mmask16 dmask = + _mm512_cmp_ps_mask(min_distances, distances, _CMP_LT_OS); + __mmask16 finalmask = m1mask | dmask; + + const __m512i min_indices_new = _mm512_mask_blend_epi32( + finalmask, current_indices, min_indices); + const __m512 min_distances_new = + _mm512_mask_blend_ps(finalmask, distances, min_distances); + + min_indices = min_indices_new; + min_distances = min_distances_new; + } + + // grab min distance + min_dis = _mm512_reduce_min_ps(min_distances); + // blend + __mmask16 mindmask = + _mm512_cmpeq_ps_mask(min_distances, _mm512_set1_ps(min_dis)); + // pick the max one + min_idx = _mm512_mask_reduce_max_epi32(mindmask, min_indices); + + if (min_idx == -1) { + return -1; + } + + if (vmin_out) { + *vmin_out = min_dis; + } + int ret = ids[min_idx]; + ids[min_idx] = -1; + --nvalid; + return ret; +} + +#elif __AVX2__ + int HNSW::MinimaxHeap::pop_min(float* vmin_out) { assert(k > 0); static_assert( diff --git a/faiss/impl/HNSW.h b/faiss/impl/HNSW.h index 5916922360..d2c974f384 100644 --- a/faiss/impl/HNSW.h +++ b/faiss/impl/HNSW.h @@ -142,9 +142,6 @@ struct HNSW { /// enough? bool check_relative_distance = true; - /// number of entry points in levels > 0. - int upper_beam = 1; - /// use bounded queue during exploration bool search_bounded_queue = true; @@ -235,20 +232,23 @@ struct HNSW { }; struct HNSWStats { - size_t n1 = 0; /// numbner of vectors searched + size_t n1 = 0; /// number of vectors searched size_t n2 = - 0; /// number of queries for which the candidate list is exhasted - size_t ndis = 0; /// number of distances computed + 0; /// number of queries for which the candidate list is exhausted + size_t ndis = 0; /// number of distances computed + size_t nhops = 0; /// number of hops aka number of edges traversed void reset() { n1 = n2 = 0; ndis = 0; + nhops = 0; } void combine(const HNSWStats& other) { n1 += other.n1; n2 += other.n2; ndis += other.ndis; + nhops += other.nhops; } }; diff --git a/faiss/impl/code_distance/code_distance-avx512.h b/faiss/impl/code_distance/code_distance-avx512.h new file mode 100644 index 0000000000..6c6afc7e0a --- /dev/null +++ b/faiss/impl/code_distance/code_distance-avx512.h @@ -0,0 +1,248 @@ +/** + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once + +#ifdef __AVX512F__ + +#include + +#include + +#include +#include + +namespace faiss { + +// According to experiments, the AVX-512 version may be SLOWER than +// the AVX2 version, which is somewhat unexpected. +// This version is not used for now, but it may be used later. +// +// TODO: test for AMD CPUs. + +template +typename std::enable_if::value, float>:: + type inline distance_single_code_avx512( + // number of subquantizers + const size_t M, + // number of bits per quantization index + const size_t nbits, + // precomputed distances, layout (M, ksub) + const float* sim_table, + const uint8_t* code) { + // default implementation + return distance_single_code_generic(M, nbits, sim_table, code); +} + +template +typename std::enable_if::value, float>:: + type inline distance_single_code_avx512( + // number of subquantizers + const size_t M, + // number of bits per quantization index + const size_t nbits, + // precomputed distances, layout (M, ksub) + const float* sim_table, + const uint8_t* code0) { + float result0 = 0; + constexpr size_t ksub = 1 << 8; + + size_t m = 0; + const size_t pqM16 = M / 16; + + constexpr intptr_t N = 1; + + const float* tab = sim_table; + + if (pqM16 > 0) { + // process 16 values per loop + const __m512i vksub = _mm512_set1_epi32(ksub); + __m512i offsets_0 = _mm512_setr_epi32( + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15); + offsets_0 = _mm512_mullo_epi32(offsets_0, vksub); + + // accumulators of partial sums + __m512 partialSums[N]; + for (intptr_t j = 0; j < N; j++) { + partialSums[j] = _mm512_setzero_ps(); + } + + // loop + for (m = 0; m < pqM16 * 16; m += 16) { + // load 16 uint8 values + __m128i mm1[N]; + mm1[0] = _mm_loadu_si128((const __m128i_u*)(code0 + m)); + + // process first 8 codes + for (intptr_t j = 0; j < N; j++) { + const __m512i idx1 = _mm512_cvtepu8_epi32(mm1[j]); + + // add offsets + const __m512i indices_to_read_from = + _mm512_add_epi32(idx1, offsets_0); + + // gather 16 values, similar to 16 operations of tab[idx] + __m512 collected = _mm512_i32gather_ps( + indices_to_read_from, tab, sizeof(float)); + + // collect partial sums + partialSums[j] = _mm512_add_ps(partialSums[j], collected); + } + tab += ksub * 16; + } + + // horizontal sum for partialSum + result0 += _mm512_reduce_add_ps(partialSums[0]); + } + + // + if (m < M) { + // process leftovers + PQDecoder8 decoder0(code0 + m, nbits); + for (; m < M; m++) { + result0 += tab[decoder0.decode()]; + tab += ksub; + } + } + + return result0; +} + +template +typename std::enable_if::value, void>:: + type + distance_four_codes_avx512( + // number of subquantizers + const size_t M, + // number of bits per quantization index + const size_t nbits, + // precomputed distances, layout (M, ksub) + const float* sim_table, + // codes + const uint8_t* __restrict code0, + const uint8_t* __restrict code1, + const uint8_t* __restrict code2, + const uint8_t* __restrict code3, + // computed distances + float& result0, + float& result1, + float& result2, + float& result3) { + distance_four_codes_generic( + M, + nbits, + sim_table, + code0, + code1, + code2, + code3, + result0, + result1, + result2, + result3); +} + +// Combines 4 operations of distance_single_code() +template +typename std::enable_if::value, void>::type +distance_four_codes_avx512( + // number of subquantizers + const size_t M, + // number of bits per quantization index + const size_t nbits, + // precomputed distances, layout (M, ksub) + const float* sim_table, + // codes + const uint8_t* __restrict code0, + const uint8_t* __restrict code1, + const uint8_t* __restrict code2, + const uint8_t* __restrict code3, + // computed distances + float& result0, + float& result1, + float& result2, + float& result3) { + result0 = 0; + result1 = 0; + result2 = 0; + result3 = 0; + constexpr size_t ksub = 1 << 8; + + size_t m = 0; + const size_t pqM16 = M / 16; + + constexpr intptr_t N = 4; + + const float* tab = sim_table; + + if (pqM16 > 0) { + // process 16 values per loop + const __m512i vksub = _mm512_set1_epi32(ksub); + __m512i offsets_0 = _mm512_setr_epi32( + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15); + offsets_0 = _mm512_mullo_epi32(offsets_0, vksub); + + // accumulators of partial sums + __m512 partialSums[N]; + for (intptr_t j = 0; j < N; j++) { + partialSums[j] = _mm512_setzero_ps(); + } + + // loop + for (m = 0; m < pqM16 * 16; m += 16) { + // load 16 uint8 values + __m128i mm1[N]; + mm1[0] = _mm_loadu_si128((const __m128i_u*)(code0 + m)); + mm1[1] = _mm_loadu_si128((const __m128i_u*)(code1 + m)); + mm1[2] = _mm_loadu_si128((const __m128i_u*)(code2 + m)); + mm1[3] = _mm_loadu_si128((const __m128i_u*)(code3 + m)); + + // process first 8 codes + for (intptr_t j = 0; j < N; j++) { + const __m512i idx1 = _mm512_cvtepu8_epi32(mm1[j]); + + // add offsets + const __m512i indices_to_read_from = + _mm512_add_epi32(idx1, offsets_0); + + // gather 16 values, similar to 16 operations of tab[idx] + __m512 collected = _mm512_i32gather_ps( + indices_to_read_from, tab, sizeof(float)); + + // collect partial sums + partialSums[j] = _mm512_add_ps(partialSums[j], collected); + } + tab += ksub * 16; + } + + // horizontal sum for partialSum + result0 += _mm512_reduce_add_ps(partialSums[0]); + result1 += _mm512_reduce_add_ps(partialSums[1]); + result2 += _mm512_reduce_add_ps(partialSums[2]); + result3 += _mm512_reduce_add_ps(partialSums[3]); + } + + // + if (m < M) { + // process leftovers + PQDecoder8 decoder0(code0 + m, nbits); + PQDecoder8 decoder1(code1 + m, nbits); + PQDecoder8 decoder2(code2 + m, nbits); + PQDecoder8 decoder3(code3 + m, nbits); + for (; m < M; m++) { + result0 += tab[decoder0.decode()]; + result1 += tab[decoder1.decode()]; + result2 += tab[decoder2.decode()]; + result3 += tab[decoder3.decode()]; + tab += ksub; + } + } +} + +} // namespace faiss + +#endif diff --git a/faiss/impl/code_distance/code_distance_avx512.h b/faiss/impl/code_distance/code_distance_avx512.h deleted file mode 100644 index 296e0df1b6..0000000000 --- a/faiss/impl/code_distance/code_distance_avx512.h +++ /dev/null @@ -1,102 +0,0 @@ -/** - * Copyright (c) Facebook, Inc. and its affiliates. - * - * This source code is licensed under the MIT license found in the - * LICENSE file in the root directory of this source tree. - */ - -// // // AVX-512 version. It is not used, but let it be for the future -// // // needs. -// // template -// // typename std::enable_if<(std::is_same::value), void>:: -// // type distance_four_codes( -// // const uint8_t* __restrict code0, -// // const uint8_t* __restrict code1, -// // const uint8_t* __restrict code2, -// // const uint8_t* __restrict code3, -// // float& result0, -// // float& result1, -// // float& result2, -// // float& result3 -// // ) const { -// // result0 = 0; -// // result1 = 0; -// // result2 = 0; -// // result3 = 0; - -// // size_t m = 0; -// // const size_t pqM16 = pq.M / 16; - -// // constexpr intptr_t N = 4; - -// // const float* tab = sim_table; - -// // if (pqM16 > 0) { -// // // process 16 values per loop -// // const __m512i ksub = _mm512_set1_epi32(pq.ksub); -// // __m512i offsets_0 = _mm512_setr_epi32(0, 1, 2, 3, 4, 5, 6, 7, -// // 8, 9, 10, 11, 12, 13, 14, 15); -// // offsets_0 = _mm512_mullo_epi32(offsets_0, ksub); - -// // // accumulators of partial sums -// // __m512 partialSums[N]; -// // for (intptr_t j = 0; j < N; j++) { -// // partialSums[j] = _mm512_setzero_ps(); -// // } - -// // // loop -// // for (m = 0; m < pqM16 * 16; m += 16) { -// // // load 16 uint8 values -// // __m128i mm1[N]; -// // mm1[0] = _mm_loadu_si128((const __m128i_u*)(code0 + m)); -// // mm1[1] = _mm_loadu_si128((const __m128i_u*)(code1 + m)); -// // mm1[2] = _mm_loadu_si128((const __m128i_u*)(code2 + m)); -// // mm1[3] = _mm_loadu_si128((const __m128i_u*)(code3 + m)); - -// // // process first 8 codes -// // for (intptr_t j = 0; j < N; j++) { -// // // convert uint8 values (low part of __m128i) to int32 -// // // values -// // const __m512i idx1 = _mm512_cvtepu8_epi32(mm1[j]); - -// // // add offsets -// // const __m512i indices_to_read_from = -// // _mm512_add_epi32(idx1, offsets_0); - -// // // gather 8 values, similar to 8 operations of -// // // tab[idx] -// // __m512 collected = -// // _mm512_i32gather_ps( -// // indices_to_read_from, tab, sizeof(float)); - -// // // collect partial sums -// // partialSums[j] = _mm512_add_ps(partialSums[j], -// // collected); -// // } -// // tab += pq.ksub * 16; - -// // } - -// // // horizontal sum for partialSum -// // result0 += _mm512_reduce_add_ps(partialSums[0]); -// // result1 += _mm512_reduce_add_ps(partialSums[1]); -// // result2 += _mm512_reduce_add_ps(partialSums[2]); -// // result3 += _mm512_reduce_add_ps(partialSums[3]); -// // } - -// // // -// // if (m < pq.M) { -// // // process leftovers -// // PQDecoder decoder0(code0 + m, pq.nbits); -// // PQDecoder decoder1(code1 + m, pq.nbits); -// // PQDecoder decoder2(code2 + m, pq.nbits); -// // PQDecoder decoder3(code3 + m, pq.nbits); -// // for (; m < pq.M; m++) { -// // result0 += tab[decoder0.decode()]; -// // result1 += tab[decoder1.decode()]; -// // result2 += tab[decoder2.decode()]; -// // result3 += tab[decoder3.decode()]; -// // tab += pq.ksub; -// // } -// // } -// // } diff --git a/faiss/impl/index_read.cpp b/faiss/impl/index_read.cpp index aa041c0fac..f0aff59481 100644 --- a/faiss/impl/index_read.cpp +++ b/faiss/impl/index_read.cpp @@ -373,7 +373,10 @@ static void read_HNSW(HNSW* hnsw, IOReader* f) { READ1(hnsw->max_level); READ1(hnsw->efConstruction); READ1(hnsw->efSearch); - READ1(hnsw->upper_beam); + + // // deprecated field + // READ1(hnsw->upper_beam); + READ1_DUMMY(int) } static void read_NSG(NSG* nsg, IOReader* f) { diff --git a/faiss/impl/index_write.cpp b/faiss/impl/index_write.cpp index 0a924d0225..6e787aed44 100644 --- a/faiss/impl/index_write.cpp +++ b/faiss/impl/index_write.cpp @@ -312,7 +312,11 @@ static void write_HNSW(const HNSW* hnsw, IOWriter* f) { WRITE1(hnsw->max_level); WRITE1(hnsw->efConstruction); WRITE1(hnsw->efSearch); - WRITE1(hnsw->upper_beam); + + // // deprecated field + // WRITE1(hnsw->upper_beam); + constexpr int tmp_upper_beam = 1; + WRITE1(tmp_upper_beam); } static void write_NSG(const NSG* nsg, IOWriter* f) { diff --git a/faiss/impl/io_macros.h b/faiss/impl/io_macros.h index 4b974b7e2e..4cdcade554 100644 --- a/faiss/impl/io_macros.h +++ b/faiss/impl/io_macros.h @@ -29,6 +29,12 @@ #define READ1(x) READANDCHECK(&(x), 1) +#define READ1_DUMMY(x_type) \ + { \ + x_type x = {}; \ + READ1(x); \ + } + // will fail if we write 256G of data at once... #define READVECTOR(vec) \ { \ diff --git a/faiss/utils/Heap.h b/faiss/utils/Heap.h index cdb714f4d6..b67707ecb1 100644 --- a/faiss/utils/Heap.h +++ b/faiss/utils/Heap.h @@ -30,6 +30,7 @@ #include #include +#include #include @@ -200,6 +201,110 @@ inline void maxheap_replace_top( heap_replace_top>(k, bh_val, bh_ids, val, ids); } +/******************************************************************* + * Basic heap> ops: push and pop + *******************************************************************/ + +// This section contains a heap implementation that works with +// std::pair elements. + +/** Pops the top element from the heap defined by bh_val[0..k-1] and + * bh_ids[0..k-1]. on output the element at k-1 is undefined. + */ +template +inline void heap_pop(size_t k, std::pair* bh) { + bh--; /* Use 1-based indexing for easier node->child translation */ + typename C::T val = bh[k].first; + typename C::TI id = bh[k].second; + size_t i = 1, i1, i2; + while (1) { + i1 = i << 1; + i2 = i1 + 1; + if (i1 > k) + break; + if ((i2 == k + 1) || + C::cmp2(bh[i1].first, bh[i2].first, bh[i1].second, bh[i2].second)) { + if (C::cmp2(val, bh[i1].first, id, bh[i1].second)) { + break; + } + bh[i] = bh[i1]; + i = i1; + } else { + if (C::cmp2(val, bh[i2].first, id, bh[i2].second)) { + break; + } + bh[i] = bh[i2]; + i = i2; + } + } + bh[i] = bh[k]; +} + +/** Pushes the element (val, ids) into the heap bh_val[0..k-2] and + * bh_ids[0..k-2]. on output the element at k-1 is defined. + */ +template +inline void heap_push( + size_t k, + std::pair* bh, + typename C::T val, + typename C::TI id) { + bh--; /* Use 1-based indexing for easier node->child translation */ + size_t i = k, i_father; + while (i > 1) { + i_father = i >> 1; + auto bh_v = bh[i_father]; + if (!C::cmp2(val, bh_v.first, id, bh_v.second)) { + /* the heap structure is ok */ + break; + } + bh[i] = bh_v; + i = i_father; + } + bh[i] = std::make_pair(val, id); +} + +/** + * Replaces the top element from the heap defined by bh_val[0..k-1] and + * bh_ids[0..k-1], and for identical bh_val[] values also sorts by bh_ids[] + * values. + */ +template +inline void heap_replace_top( + size_t k, + std::pair* bh, + typename C::T val, + typename C::TI id) { + bh--; /* Use 1-based indexing for easier node->child translation */ + size_t i = 1, i1, i2; + while (1) { + i1 = i << 1; + i2 = i1 + 1; + if (i1 > k) { + break; + } + + // Note that C::cmp2() is a bool function answering + // `(a1 > b1) || ((a1 == b1) && (a2 > b2))` for max + // heap and same with the `<` sign for min heap. + if ((i2 == k + 1) || + C::cmp2(bh[i1].first, bh[i2].first, bh[i1].second, bh[i2].second)) { + if (C::cmp2(val, bh[i1].first, id, bh[i1].second)) { + break; + } + bh[i] = bh[i1]; + i = i1; + } else { + if (C::cmp2(val, bh[i2].first, id, bh[i2].second)) { + break; + } + bh[i] = bh[i2]; + i = i2; + } + } + bh[i] = std::make_pair(val, id); +} + /******************************************************************* * Heap initialization *******************************************************************/