@@ -236,89 +236,84 @@ void bfKnn(GpuResourcesProvider* prov, const GpuDistanceParams& args) {
236
236
raft::device_resources& handle = res->getRaftHandleCurrentDevice ();
237
237
auto stream = res->getDefaultStreamCurrentDevice ();
238
238
239
- idx_t dims = args.dims ;
240
- idx_t num_vectors = args.numVectors ;
241
- idx_t num_queries = args.numQueries ;
239
+ int64_t dims = args.dims ;
240
+ int64_t num_vectors = args.numVectors ;
241
+ int64_t num_queries = args.numQueries ;
242
242
int k = args.k ;
243
243
float metric_arg = args.metricArg ;
244
244
245
- auto inds = raft::make_writeback_temporary_device_buffer<idx_t , idx_t >(
246
- handle,
247
- reinterpret_cast <idx_t *>(args.outIndices ),
248
- raft::matrix_extent<idx_t >(num_queries, (idx_t )k));
249
- auto dists = raft::make_writeback_temporary_device_buffer<float , idx_t >(
250
- handle,
251
- reinterpret_cast <float *>(args.outDistances ),
252
- raft::matrix_extent<idx_t >(num_queries, (idx_t )k));
245
+ auto inds =
246
+ raft::make_writeback_temporary_device_buffer<idx_t , int64_t >(
247
+ handle,
248
+ reinterpret_cast <idx_t *>(args.outIndices ),
249
+ raft::matrix_extent<int64_t >(num_queries, (int64_t )k));
250
+ auto dists =
251
+ raft::make_writeback_temporary_device_buffer<float , int64_t >(
252
+ handle,
253
+ reinterpret_cast <float *>(args.outDistances ),
254
+ raft::matrix_extent<int64_t >(num_queries, (int64_t )k));
253
255
254
256
if (args.queriesRowMajor ) {
255
257
auto index = raft::make_readonly_temporary_device_buffer<
256
258
const float ,
257
- idx_t ,
259
+ int64_t ,
258
260
raft::row_major>(
259
261
handle,
260
262
const_cast <float *>(
261
263
reinterpret_cast <const float *>(args.vectors )),
262
- raft::matrix_extent<idx_t >(num_vectors, dims));
264
+ raft::matrix_extent<int64_t >(num_vectors, dims));
263
265
264
266
auto search = raft::make_readonly_temporary_device_buffer<
265
267
const float ,
266
- idx_t ,
268
+ int64_t ,
267
269
raft::row_major>(
268
270
handle,
269
271
const_cast <float *>(
270
272
reinterpret_cast <const float *>(args.queries )),
271
- raft::matrix_extent<idx_t >(num_queries, dims));
273
+ raft::matrix_extent<int64_t >(num_queries, dims));
272
274
273
- // For now, use RAFT's fused KNN when k <= 64 and L2 metric is used
274
- if (args.k <= 64 && args.metric == MetricType::METRIC_L2 &&
275
- args.numVectors > 0 ) {
276
- RAFT_LOG_INFO (" Invoking flat fused_l2_knn" );
277
- brute_force::fused_l2_knn (
278
- handle,
279
- index .view (),
280
- search.view (),
281
- inds.view (),
282
- dists.view (),
283
- distance);
284
- } else {
285
- std::vector<raft::device_matrix_view<
275
+ // get device_vector_view to the precalculate norms if available
276
+ std::optional<raft::temporary_device_buffer<
277
+ const float ,
278
+ raft::vector_extent<int64_t >>>
279
+ norms;
280
+ std::optional<raft::device_vector_view<const float , int64_t >>
281
+ norms_view;
282
+ if (args.vectorNorms ) {
283
+ norms = raft::make_readonly_temporary_device_buffer<
286
284
const float ,
287
- idx_t ,
288
- raft::row_major>>
289
- index_vec = {index .view ()};
290
- RAFT_LOG_INFO (" Invoking flat bfknn" );
291
- brute_force::knn (
285
+ int64_t >(
292
286
handle,
293
- index_vec,
294
- search.view (),
295
- inds.view (),
296
- dists.view (),
297
- distance,
298
- metric_arg);
287
+ args.vectorNorms ,
288
+ raft::vector_extent<int64_t >(num_queries));
289
+ norms_view = norms->view ();
299
290
}
291
+ raft::neighbors::brute_force::index idx (
292
+ handle, index .view (), norms_view, distance, metric_arg);
293
+ raft::neighbors::brute_force::search<float , idx_t >(
294
+ handle, idx, search.view (), inds.view (), dists.view ());
300
295
} else {
301
296
auto index = raft::make_readonly_temporary_device_buffer<
302
297
const float ,
303
- idx_t ,
298
+ int64_t ,
304
299
raft::col_major>(
305
300
handle,
306
301
const_cast <float *>(
307
302
reinterpret_cast <const float *>(args.vectors )),
308
- raft::matrix_extent<idx_t >(num_vectors, dims));
303
+ raft::matrix_extent<int64_t >(num_vectors, dims));
309
304
310
305
auto search = raft::make_readonly_temporary_device_buffer<
311
306
const float ,
312
- idx_t ,
307
+ int64_t ,
313
308
raft::col_major>(
314
309
handle,
315
310
const_cast <float *>(
316
311
reinterpret_cast <const float *>(args.queries )),
317
- raft::matrix_extent<idx_t >(num_queries, dims));
312
+ raft::matrix_extent<int64_t >(num_queries, dims));
318
313
319
314
std::vector<raft::device_matrix_view<
320
315
const float ,
321
- idx_t ,
316
+ int64_t ,
322
317
raft::col_major>>
323
318
index_vec = {index .view ()};
324
319
RAFT_LOG_INFO (" Invoking flat bfknn" );
0 commit comments