Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Some small improvements. #3692

Closed
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 22 additions & 11 deletions faiss/IndexHNSW.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -255,23 +255,23 @@ void hnsw_search(
FAISS_THROW_IF_NOT_MSG(params, "params type invalid");
efSearch = params->efSearch;
}
size_t n1 = 0, n2 = 0, ndis = 0;
size_t n1 = 0, n2 = 0, ndis = 0, nhops = 0;

idx_t check_period = InterruptCallback::get_period_hint(
hnsw.max_level * index->d * efSearch);

for (idx_t i0 = 0; i0 < n; i0 += check_period) {
idx_t i1 = std::min(i0 + check_period, n);

#pragma omp parallel
#pragma omp parallel if (i1 - i0 > 1)
{
VisitedTable vt(index->ntotal);
typename BlockResultHandler::SingleResultHandler res(bres);

std::unique_ptr<DistanceComputer> dis(
storage_distance_computer(index->storage));

#pragma omp for reduction(+ : n1, n2, ndis) schedule(guided)
#pragma omp for reduction(+ : n1, n2, ndis, nhops) schedule(guided)
for (idx_t i = i0; i < i1; i++) {
res.begin(i);
dis->set_query(x + i * index->d);
Expand All @@ -280,13 +280,14 @@ void hnsw_search(
n1 += stats.n1;
n2 += stats.n2;
ndis += stats.ndis;
nhops += stats.nhops;
res.end();
}
}
InterruptCallback::check();
}

hnsw_stats.combine({n1, n2, ndis});
hnsw_stats.combine({n1, n2, ndis, nhops});
}

} // anonymous namespace
Expand Down Expand Up @@ -612,6 +613,10 @@ void IndexHNSW::permute_entries(const idx_t* perm) {
hnsw.permute_entries(perm);
}

/// Return a DistanceComputer that delegates all distance evaluations to
/// the underlying storage index.
/// NOTE(review): ownership of the returned pointer passes to the caller —
/// presumably callers wrap it in a unique_ptr as done elsewhere in this
/// file (see `std::unique_ptr<DistanceComputer> dis(...)`); confirm.
DistanceComputer* IndexHNSW::get_distance_computer() const {
return storage->get_distance_computer();
}

/**************************************************************
* IndexHNSWFlat implementation
**************************************************************/
Expand All @@ -635,8 +640,13 @@ IndexHNSWFlat::IndexHNSWFlat(int d, int M, MetricType metric)

IndexHNSWPQ::IndexHNSWPQ() = default;

IndexHNSWPQ::IndexHNSWPQ(int d, int pq_m, int M, int pq_nbits)
: IndexHNSW(new IndexPQ(d, pq_m, pq_nbits), M) {
/// Build an HNSW graph over product-quantized vectors.
/// @param d        dimensionality of the input vectors
/// @param pq_m     number of PQ sub-quantizers
/// @param M        number of HNSW neighbors per node
/// @param pq_nbits bits per PQ sub-quantizer code (default 8)
/// @param metric   distance metric forwarded to the IndexPQ storage
///                 (default METRIC_L2)
IndexHNSWPQ::IndexHNSWPQ(
int d,
int pq_m,
int M,
int pq_nbits,
MetricType metric)
: IndexHNSW(new IndexPQ(d, pq_m, pq_nbits, metric), M) {
// The base IndexHNSW takes a raw storage pointer; own_fields makes this
// index responsible for deleting the IndexPQ it just allocated.
own_fields = true;
// PQ codebooks must be trained before vectors can be added.
is_trained = false;
}
Expand Down Expand Up @@ -762,7 +772,7 @@ void IndexHNSW2Level::search(
IndexHNSW::search(n, x, k, distances, labels);

} else { // "mixed" search
size_t n1 = 0, n2 = 0, ndis = 0;
size_t n1 = 0, n2 = 0, ndis = 0, nhops = 0;

const IndexIVFPQ* index_ivfpq =
dynamic_cast<const IndexIVFPQ*>(storage);
Expand Down Expand Up @@ -791,10 +801,10 @@ void IndexHNSW2Level::search(
std::unique_ptr<DistanceComputer> dis(
storage_distance_computer(storage));

int candidates_size = hnsw.upper_beam;
constexpr int candidates_size = 1;
MinimaxHeap candidates(candidates_size);

#pragma omp for reduction(+ : n1, n2, ndis)
#pragma omp for reduction(+ : n1, n2, ndis, nhops)
for (idx_t i = 0; i < n; i++) {
idx_t* idxi = labels + i * k;
float* simi = distances + i * k;
Expand All @@ -816,7 +826,7 @@ void IndexHNSW2Level::search(

candidates.clear();

for (int j = 0; j < hnsw.upper_beam && j < k; j++) {
for (int j = 0; j < k; j++) {
if (idxi[j] < 0)
break;
candidates.push(idxi[j], simi[j]);
Expand All @@ -840,6 +850,7 @@ void IndexHNSW2Level::search(
n1 += search_stats.n1;
n2 += search_stats.n2;
ndis += search_stats.ndis;
nhops += search_stats.nhops;

vt.advance();
vt.advance();
Expand All @@ -848,7 +859,7 @@ void IndexHNSW2Level::search(
}
}

hnsw_stats.combine({n1, n2, ndis});
hnsw_stats.combine({n1, n2, ndis, nhops});
}
}

Expand Down
11 changes: 9 additions & 2 deletions faiss/IndexHNSW.h
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ struct IndexHNSW;
struct IndexHNSW : Index {
typedef HNSW::storage_idx_t storage_idx_t;

// the link strcuture
// the link structure
HNSW hnsw;

// the sequential storage
Expand Down Expand Up @@ -111,6 +111,8 @@ struct IndexHNSW : Index {
void link_singletons();

void permute_entries(const idx_t* perm);

DistanceComputer* get_distance_computer() const override;
};

/** Flat index topped with a HNSW structure to access elements
Expand All @@ -127,7 +129,12 @@ struct IndexHNSWFlat : IndexHNSW {
*/
struct IndexHNSWPQ : IndexHNSW {
IndexHNSWPQ();
IndexHNSWPQ(int d, int pq_m, int M, int pq_nbits = 8);
IndexHNSWPQ(
int d,
int pq_m,
int M,
int pq_nbits = 8,
MetricType metric = METRIC_L2);
void train(idx_t n, const float* x) override;
};

Expand Down
10 changes: 5 additions & 5 deletions faiss/IndexRefine.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -68,12 +68,12 @@ template <class C>
static void reorder_2_heaps(
idx_t n,
idx_t k,
idx_t* labels,
float* distances,
idx_t* __restrict labels,
float* __restrict distances,
idx_t k_base,
const idx_t* base_labels,
const float* base_distances) {
#pragma omp parallel for
const idx_t* __restrict base_labels,
const float* __restrict base_distances) {
#pragma omp parallel for if (n > 1)
for (idx_t i = 0; i < n; i++) {
idx_t* idxo = labels + i * k;
float* diso = distances + i * k;
Expand Down
Loading
Loading