|
24 | 24 | #include <faiss/gpu/GpuResources.h>
|
25 | 25 | #include <faiss/gpu/utils/DeviceUtils.h>
|
26 | 26 | #include <faiss/impl/FaissAssert.h>
|
| 27 | +#include <faiss/utils/Heap.h> |
27 | 28 | #include <faiss/gpu/impl/Distance.cuh>
|
28 | 29 | #include <faiss/gpu/utils/ConversionOperators.cuh>
|
29 | 30 | #include <faiss/gpu/utils/CopyUtils.cuh>
|
@@ -218,7 +219,9 @@ void bfKnnConvert(GpuResourcesProvider* prov, const GpuDistanceParams& args) {
|
218 | 219 | fromDevice<float, 2>(tOutDistances, args.outDistances, stream);
|
219 | 220 | }
|
220 | 221 |
|
221 |
| -void bfKnn(GpuResourcesProvider* prov, const GpuDistanceParams& args) { |
| 222 | +void bfKnn_single_tile( |
| 223 | + GpuResourcesProvider* prov, |
| 224 | + const GpuDistanceParams& args) { |
222 | 225 | // For now, both vectors and queries must be of the same data type
|
223 | 226 | FAISS_THROW_IF_NOT_MSG(
|
224 | 227 | args.vectorType == args.queryType,
|
@@ -368,6 +371,126 @@ void bfKnn(GpuResourcesProvider* prov, const GpuDistanceParams& args) {
|
368 | 371 | }
|
369 | 372 | }
|
370 | 373 |
|
| 374 | +template <class C> |
| 375 | +void bfKnn_shard_database( |
| 376 | + GpuResourcesProvider* prov, |
| 377 | + const GpuDistanceParams& args, |
| 378 | + idx_t shard_size, |
| 379 | + idx_t distance_size) { |
| 380 | + std::vector<typename C::T> heaps_distances; |
| 381 | + if (args.ignoreOutDistances) { |
| 382 | + heaps_distances.resize(args.numQueries * args.k, 0); |
| 383 | + } |
| 384 | + HeapArray<C> heaps = { |
| 385 | + (size_t)args.numQueries, |
| 386 | + (size_t)args.k, |
| 387 | + (typename C::TI*)args.outIndices, |
| 388 | + args.ignoreOutDistances ? heaps_distances.data() |
| 389 | + : args.outDistances}; |
| 390 | + heaps.heapify(); |
| 391 | + std::vector<typename C::TI> labels(args.numQueries * args.k, -1); |
| 392 | + std::vector<typename C::T> distances(args.numQueries * args.k, 0); |
| 393 | + GpuDistanceParams args_batch = args; |
| 394 | + args_batch.outDistances = distances.data(); |
| 395 | + args_batch.ignoreOutDistances = false; |
| 396 | + args_batch.outIndices = labels.data(); |
| 397 | + for (idx_t i = 0; i < args.numVectors; i += shard_size) { |
| 398 | + args_batch.numVectors = min(shard_size, args.numVectors - i); |
| 399 | + args_batch.vectors = |
| 400 | + (char*)args.vectors + distance_size * args.dims * i; |
| 401 | + args_batch.vectorNorms = |
| 402 | + args.vectorNorms ? args.vectorNorms + i : nullptr; |
| 403 | + bfKnn_single_tile(prov, args_batch); |
| 404 | + for (auto& label : labels) { |
| 405 | + label += i; |
| 406 | + } |
| 407 | + heaps.addn_with_ids(args.k, distances.data(), labels.data(), args.k); |
| 408 | + } |
| 409 | + heaps.reorder(); |
| 410 | +} |
| 411 | + |
| 412 | +void bfKnn_single_query_shard( |
| 413 | + GpuResourcesProvider* prov, |
| 414 | + const GpuDistanceParams& args) { |
| 415 | + if (args.vectorsMemoryLimit == 0) { |
| 416 | + bfKnn_single_tile(prov, args); |
| 417 | + return; |
| 418 | + } |
| 419 | + FAISS_THROW_IF_NOT_MSG( |
| 420 | + args.vectorsRowMajor, |
| 421 | + "sharding vectors is only supported in row major mode"); |
| 422 | + FAISS_THROW_IF_NOT_MSG( |
| 423 | + args.k > 0, "sharding vectors is only supported for k > 0"); |
| 424 | + idx_t distance_size = args.vectorType == DistanceDataType::F32 ? 4 |
| 425 | + : args.vectorType == DistanceDataType::F16 ? 2 |
| 426 | + : 0; |
| 427 | + FAISS_THROW_IF_NOT_MSG(distance_size > 0, "unknown vectorType"); |
| 428 | + idx_t shard_size = args.vectorsMemoryLimit / (args.dims * distance_size); |
| 429 | + FAISS_THROW_IF_NOT_MSG( |
| 430 | + shard_size > 0, |
| 431 | + "vectorsMemoryLimit is too low, shard size would be zero"); |
| 432 | + if (args.numVectors <= shard_size) { |
| 433 | + bfKnn_single_tile(prov, args); |
| 434 | + return; |
| 435 | + } |
| 436 | + if (is_similarity_metric(args.metric)) { |
| 437 | + if (args.outIndicesType == IndicesDataType::I64) { |
| 438 | + bfKnn_shard_database<CMin<float, int64_t>>( |
| 439 | + prov, args, shard_size, distance_size); |
| 440 | + } else if (args.outIndicesType == IndicesDataType::I32) { |
| 441 | + bfKnn_shard_database<CMin<float, int32_t>>( |
| 442 | + prov, args, shard_size, distance_size); |
| 443 | + } else { |
| 444 | + FAISS_THROW_MSG("unknown outIndicesType"); |
| 445 | + } |
| 446 | + } else { |
| 447 | + if (args.outIndicesType == IndicesDataType::I64) { |
| 448 | + bfKnn_shard_database<CMax<float, int64_t>>( |
| 449 | + prov, args, shard_size, distance_size); |
| 450 | + } else if (args.outIndicesType == IndicesDataType::I32) { |
| 451 | + bfKnn_shard_database<CMax<float, int32_t>>( |
| 452 | + prov, args, shard_size, distance_size); |
| 453 | + } else { |
| 454 | + FAISS_THROW_MSG("unknown outIndicesType"); |
| 455 | + } |
| 456 | + } |
| 457 | +} |
| 458 | + |
| 459 | +void bfKnn(GpuResourcesProvider* prov, const GpuDistanceParams& args) { |
| 460 | + if (args.queriesMemoryLimit == 0) { |
| 461 | + bfKnn_single_query_shard(prov, args); |
| 462 | + return; |
| 463 | + } |
| 464 | + FAISS_THROW_IF_NOT_MSG( |
| 465 | + args.queriesRowMajor, |
| 466 | + "sharding queries is only supported in row major mode"); |
| 467 | + FAISS_THROW_IF_NOT_MSG( |
| 468 | + args.k > 0, "sharding queries is only supported for k > 0"); |
| 469 | + idx_t distance_size = args.queryType == DistanceDataType::F32 ? 4 |
| 470 | + : args.queryType == DistanceDataType::F16 ? 2 |
| 471 | + : 0; |
| 472 | + FAISS_THROW_IF_NOT_MSG(distance_size > 0, "unknown queryType"); |
| 473 | + idx_t label_size = args.outIndicesType == IndicesDataType::I64 ? 8 |
| 474 | + : args.outIndicesType == IndicesDataType::I32 ? 4 |
| 475 | + : 0; |
| 476 | + FAISS_THROW_IF_NOT_MSG(distance_size > 0, "unknown outIndicesType"); |
| 477 | + idx_t shard_size = args.queriesMemoryLimit / |
| 478 | + (args.k * (distance_size + label_size) + args.dims * distance_size); |
| 479 | + FAISS_THROW_IF_NOT_MSG(shard_size > 0, "queriesMemoryLimit is too low"); |
| 480 | + for (idx_t i = 0; i < args.numQueries; i += shard_size) { |
| 481 | + GpuDistanceParams args_batch = args; |
| 482 | + args_batch.numQueries = min(shard_size, args.numQueries - i); |
| 483 | + args_batch.queries = |
| 484 | + (char*)args.queries + distance_size * args.dims * i; |
| 485 | + if (!args_batch.ignoreOutDistances) { |
| 486 | + args_batch.outDistances = args.outDistances + args.k * i; |
| 487 | + } |
| 488 | + args_batch.outIndices = |
| 489 | + (char*)args.outIndices + args.k * label_size * i; |
| 490 | + bfKnn_single_query_shard(prov, args_batch); |
| 491 | + } |
| 492 | +} |
| 493 | + |
371 | 494 | // legacy version
|
372 | 495 | void bruteForceKnn(
|
373 | 496 | GpuResourcesProvider* res,
|
|
0 commit comments