Skip to content

Commit d643c41

Browse files
benfredfacebook-github-bot
authored andcommitted
use precomputed norms for raft brute_force knn calls (facebookresearch#3089)
Summary: Pull Request resolved: facebookresearch#3089 Reviewed By: algoriddle Differential Revision: D50933982 Pulled By: mdouze fbshipit-source-id: dd0d00cf71ac490f75b8c2f152e7ae4cc28019ef
1 parent b109d08 commit d643c41

File tree

2 files changed

+52
-68
lines changed

2 files changed

+52
-68
lines changed

faiss/gpu/GpuDistance.cu

+39-44
Original file line numberDiff line numberDiff line change
@@ -236,89 +236,84 @@ void bfKnn(GpuResourcesProvider* prov, const GpuDistanceParams& args) {
236236
raft::device_resources& handle = res->getRaftHandleCurrentDevice();
237237
auto stream = res->getDefaultStreamCurrentDevice();
238238

239-
idx_t dims = args.dims;
240-
idx_t num_vectors = args.numVectors;
241-
idx_t num_queries = args.numQueries;
239+
int64_t dims = args.dims;
240+
int64_t num_vectors = args.numVectors;
241+
int64_t num_queries = args.numQueries;
242242
int k = args.k;
243243
float metric_arg = args.metricArg;
244244

245-
auto inds = raft::make_writeback_temporary_device_buffer<idx_t, idx_t>(
246-
handle,
247-
reinterpret_cast<idx_t*>(args.outIndices),
248-
raft::matrix_extent<idx_t>(num_queries, (idx_t)k));
249-
auto dists = raft::make_writeback_temporary_device_buffer<float, idx_t>(
250-
handle,
251-
reinterpret_cast<float*>(args.outDistances),
252-
raft::matrix_extent<idx_t>(num_queries, (idx_t)k));
245+
auto inds =
246+
raft::make_writeback_temporary_device_buffer<idx_t, int64_t>(
247+
handle,
248+
reinterpret_cast<idx_t*>(args.outIndices),
249+
raft::matrix_extent<int64_t>(num_queries, (int64_t)k));
250+
auto dists =
251+
raft::make_writeback_temporary_device_buffer<float, int64_t>(
252+
handle,
253+
reinterpret_cast<float*>(args.outDistances),
254+
raft::matrix_extent<int64_t>(num_queries, (int64_t)k));
253255

254256
if (args.queriesRowMajor) {
255257
auto index = raft::make_readonly_temporary_device_buffer<
256258
const float,
257-
idx_t,
259+
int64_t,
258260
raft::row_major>(
259261
handle,
260262
const_cast<float*>(
261263
reinterpret_cast<const float*>(args.vectors)),
262-
raft::matrix_extent<idx_t>(num_vectors, dims));
264+
raft::matrix_extent<int64_t>(num_vectors, dims));
263265

264266
auto search = raft::make_readonly_temporary_device_buffer<
265267
const float,
266-
idx_t,
268+
int64_t,
267269
raft::row_major>(
268270
handle,
269271
const_cast<float*>(
270272
reinterpret_cast<const float*>(args.queries)),
271-
raft::matrix_extent<idx_t>(num_queries, dims));
273+
raft::matrix_extent<int64_t>(num_queries, dims));
272274

273-
// For now, use RAFT's fused KNN when k <= 64 and L2 metric is used
274-
if (args.k <= 64 && args.metric == MetricType::METRIC_L2 &&
275-
args.numVectors > 0) {
276-
RAFT_LOG_INFO("Invoking flat fused_l2_knn");
277-
brute_force::fused_l2_knn(
278-
handle,
279-
index.view(),
280-
search.view(),
281-
inds.view(),
282-
dists.view(),
283-
distance);
284-
} else {
285-
std::vector<raft::device_matrix_view<
275+
// get device_vector_view to the precalculate norms if available
276+
std::optional<raft::temporary_device_buffer<
277+
const float,
278+
raft::vector_extent<int64_t>>>
279+
norms;
280+
std::optional<raft::device_vector_view<const float, int64_t>>
281+
norms_view;
282+
if (args.vectorNorms) {
283+
norms = raft::make_readonly_temporary_device_buffer<
286284
const float,
287-
idx_t,
288-
raft::row_major>>
289-
index_vec = {index.view()};
290-
RAFT_LOG_INFO("Invoking flat bfknn");
291-
brute_force::knn(
285+
int64_t>(
292286
handle,
293-
index_vec,
294-
search.view(),
295-
inds.view(),
296-
dists.view(),
297-
distance,
298-
metric_arg);
287+
args.vectorNorms,
288+
raft::vector_extent<int64_t>(num_queries));
289+
norms_view = norms->view();
299290
}
291+
raft::neighbors::brute_force::index idx(
292+
handle, index.view(), norms_view, distance, metric_arg);
293+
raft::neighbors::brute_force::search<float, idx_t>(
294+
handle, idx, search.view(), inds.view(), dists.view());
300295
} else {
301296
auto index = raft::make_readonly_temporary_device_buffer<
302297
const float,
303-
idx_t,
298+
int64_t,
304299
raft::col_major>(
305300
handle,
306301
const_cast<float*>(
307302
reinterpret_cast<const float*>(args.vectors)),
308-
raft::matrix_extent<idx_t>(num_vectors, dims));
303+
raft::matrix_extent<int64_t>(num_vectors, dims));
309304

310305
auto search = raft::make_readonly_temporary_device_buffer<
311306
const float,
312-
idx_t,
307+
int64_t,
313308
raft::col_major>(
314309
handle,
315310
const_cast<float*>(
316311
reinterpret_cast<const float*>(args.queries)),
317-
raft::matrix_extent<idx_t>(num_queries, dims));
312+
raft::matrix_extent<int64_t>(num_queries, dims));
318313

319314
std::vector<raft::device_matrix_view<
320315
const float,
321-
idx_t,
316+
int64_t,
322317
raft::col_major>>
323318
index_vec = {index.view()};
324319
RAFT_LOG_INFO("Invoking flat bfknn");

faiss/gpu/impl/RaftFlatIndex.cu

+13-24
Original file line numberDiff line numberDiff line change
@@ -77,41 +77,30 @@ void RaftFlatIndex::query(
7777
raft::device_resources& handle =
7878
resources_->getRaftHandleCurrentDevice();
7979

80-
auto index = raft::make_device_matrix_view<const float, idx_t>(
80+
auto index = raft::make_device_matrix_view<const float, int64_t>(
8181
vectors_.data(), vectors_.getSize(0), vectors_.getSize(1));
82-
auto search = raft::make_device_matrix_view<const float, idx_t>(
82+
auto search = raft::make_device_matrix_view<const float, int64_t>(
8383
input.data(), input.getSize(0), input.getSize(1));
84-
auto inds = raft::make_device_matrix_view<idx_t, idx_t>(
84+
85+
auto inds = raft::make_device_matrix_view<idx_t, int64_t>(
8586
outIndices.data(),
8687
outIndices.getSize(0),
8788
outIndices.getSize(1));
88-
auto dists = raft::make_device_matrix_view<float, idx_t>(
89+
auto dists = raft::make_device_matrix_view<float, int64_t>(
8990
outDistances.data(),
9091
outDistances.getSize(0),
9192
outDistances.getSize(1));
9293

9394
DistanceType distance = faiss_to_raft(metric, exactDistance);
9495

95-
std::vector<raft::device_matrix_view<const float, idx_t>> index_vec = {
96-
index};
97-
98-
// For now, use RAFT's fused KNN when k <= 64 and L2 metric is used
99-
if (k <= 64 && metric == MetricType::METRIC_L2 &&
100-
vectors_.getSize(0) > 0) {
101-
RAFT_LOG_INFO("Invoking flat fused_l2_knn");
102-
brute_force::fused_l2_knn(
103-
handle, index, search, inds, dists, distance);
104-
} else {
105-
RAFT_LOG_INFO("Invoking flat bfknn");
106-
brute_force::knn(
107-
handle,
108-
index_vec,
109-
search,
110-
inds,
111-
dists,
112-
distance,
113-
metricArg);
114-
}
96+
std::optional<raft::device_vector_view<const float, int64_t>>
97+
norms_view = raft::make_device_vector_view(
98+
norms_.data(), norms_.getSize(0));
99+
100+
raft::neighbors::brute_force::index idx(
101+
handle, index, norms_view, distance, metricArg);
102+
raft::neighbors::brute_force::search<float, int64_t>(
103+
handle, idx, search, inds, dists);
115104

116105
if (metric == MetricType::METRIC_Lp) {
117106
raft::linalg::unary_op(

0 commit comments

Comments
 (0)