|
24 | 24 | #include <faiss/gpu/GpuResources.h>
|
25 | 25 | #include <faiss/gpu/utils/DeviceUtils.h>
|
26 | 26 | #include <faiss/impl/FaissAssert.h>
|
| 27 | +#include <faiss/utils/Heap.h> |
27 | 28 | #include <faiss/gpu/impl/Distance.cuh>
|
28 | 29 | #include <faiss/gpu/utils/ConversionOperators.cuh>
|
29 | 30 | #include <faiss/gpu/utils/CopyUtils.cuh>
|
@@ -368,6 +369,155 @@ void bfKnn(GpuResourcesProvider* prov, const GpuDistanceParams& args) {
|
368 | 369 | }
|
369 | 370 | }
|
370 | 371 |
|
/// Searches the database in consecutive shards of at most `shard_size`
/// vectors and merges the per-shard top-k results on the host.
///
/// Every shard is searched with bfKnn() into temporary host buffers. The
/// shard-local ids are then offset by the shard's first row and pushed into a
/// HeapArray<C> that is built directly on top of the caller's output buffers.
/// When `args.ignoreOutDistances` is set, a local buffer stands in for the
/// distance output so the heap still has somewhere to keep its keys.
///
/// @tparam C             heap comparator; supplies the distance type `C::T`
///                       and the id type `C::TI`
/// @param prov           forwarded unchanged to bfKnn()
/// @param args           search parameters; a copy is made whose `vectors`,
///                       `vectorNorms`, `numVectors` and output fields are
///                       overridden for each shard
/// @param shard_size     maximum number of database vectors per bfKnn() call
/// @param distance_size  size in bytes of one component of `args.vectors`
template <class C>
void bfKnn_shard_database(
        GpuResourcesProvider* prov,
        const GpuDistanceParams& args,
        size_t shard_size,
        size_t distance_size) {
    using Dist = typename C::T;
    using Label = typename C::TI;

    // One (distance, id) slot per query and per requested neighbor.
    const size_t numResults = args.numQueries * args.k;

    // Stand-in distance storage, only populated when the caller discards
    // distances.
    std::vector<Dist> discardedDistances;
    if (args.ignoreOutDistances) {
        discardedDistances.resize(numResults, 0);
    }
    Dist* mergedDistances = args.ignoreOutDistances
            ? discardedDistances.data()
            : args.outDistances;

    // The running top-k lives in the caller's output buffers.
    HeapArray<C> merged = {
            (size_t)args.numQueries,
            (size_t)args.k,
            (Label*)args.outIndices,
            mergedDistances};
    merged.heapify();

    // Per-shard results land here before being merged.
    std::vector<Label> shardLabels(numResults);
    std::vector<Dist> shardDistances(numResults);

    GpuDistanceParams shardArgs = args;
    shardArgs.outIndices = shardLabels.data();
    shardArgs.outDistances = shardDistances.data();
    shardArgs.ignoreOutDistances = false;

    for (idx_t first = 0; first < args.numVectors; first += shard_size) {
        shardArgs.numVectors = min(shard_size, args.numVectors - first);
        // Byte offset of row `first` in the row-major database.
        shardArgs.vectors =
                (char*)args.vectors + distance_size * args.dims * first;
        shardArgs.vectorNorms =
                args.vectorNorms ? args.vectorNorms + first : nullptr;

        bfKnn(prov, shardArgs);

        // Translate shard-local ids into ids relative to the full database.
        // NOTE(review): every id is shifted, including any "no result"
        // placeholder bfKnn may emit -- confirm such entries can never win a
        // heap slot.
        for (size_t j = 0; j < shardLabels.size(); ++j) {
            shardLabels[j] += first;
        }

        merged.addn_with_ids(
                args.k, shardDistances.data(), shardLabels.data(), args.k);
    }

    merged.reorder();
}
| 409 | + |
/// Brute-force k-NN for one batch of queries, splitting the database into
/// shards when it does not fit in `vectorsMemoryLimit` bytes.
///
/// Behavior:
///  - `vectorsMemoryLimit == 0` disables sharding: the call is forwarded to
///    bfKnn() unchanged.
///  - Otherwise the shard size is `vectorsMemoryLimit / (dims * element size)`
///    rows. If the whole database fits in one shard, bfKnn() is called
///    directly; if not, bfKnn_shard_database() is instantiated with the heap
///    comparator matching the metric (CMin for similarity metrics, CMax
///    otherwise) and the output index type.
///
/// @param prov                forwarded to bfKnn() / bfKnn_shard_database()
/// @param args                search parameters
/// @param vectorsMemoryLimit  byte budget for one shard of database vectors;
///                            0 means "no limit"
///
/// @throws FaissException when sharding is requested and: `numVectors <= 0`,
///         `vectors` is null or not in CPU memory, the vectors are not row
///         major, `k <= 0`, `vectorType` is neither F32 nor F16, `dims <= 0`,
///         the limit is smaller than one row, the output index type is
///         unknown, or (when the database actually has to be split) a
///         required output buffer is null.
void bfKnn_single_query_shard(
        GpuResourcesProvider* prov,
        const GpuDistanceParams& args,
        size_t vectorsMemoryLimit) {
    if (vectorsMemoryLimit == 0) {
        bfKnn(prov, args);
        return;
    }
    FAISS_THROW_IF_NOT_MSG(
            args.numVectors > 0, "bfKnn_tiling: numVectors must be > 0");
    FAISS_THROW_IF_NOT_MSG(
            args.vectors,
            "bfKnn_tiling: vectors must be provided (passed null)");
    FAISS_THROW_IF_NOT_MSG(
            getDeviceForAddress(args.vectors) == -1,
            "bfKnn_tiling: vectors should be in CPU memory when vectorsMemoryLimit > 0");
    FAISS_THROW_IF_NOT_MSG(
            args.vectorsRowMajor,
            "bfKnn_tiling: tiling vectors is only supported in row major mode");
    FAISS_THROW_IF_NOT_MSG(
            args.k > 0,
            "bfKnn_tiling: tiling vectors is only supported for k > 0");
    // Bytes per vector component; 0 flags an unsupported type.
    size_t distance_size = args.vectorType == DistanceDataType::F32 ? 4
            : args.vectorType == DistanceDataType::F16              ? 2
                                                                    : 0;
    FAISS_THROW_IF_NOT_MSG(
            distance_size > 0, "bfKnn_tiling: unknown vectorType");
    // Guard the division below: dims == 0 would divide by zero.
    FAISS_THROW_IF_NOT_MSG(args.dims > 0, "bfKnn_tiling: dims must be > 0");
    // Number of whole database rows that fit in the byte budget.
    size_t shard_size = vectorsMemoryLimit / (args.dims * distance_size);
    FAISS_THROW_IF_NOT_MSG(
            shard_size > 0, "bfKnn_tiling: vectorsMemoryLimit is too low");
    if (args.numVectors <= shard_size) {
        // Everything fits in a single shard: no host-side merge needed.
        bfKnn(prov, args);
        return;
    }
    // The sharded path merges results on the host directly into the caller's
    // output buffers, so they must be present before we start.
    FAISS_THROW_IF_NOT_MSG(
            args.outIndices,
            "bfKnn_tiling: outIndices must be provided (passed null)");
    FAISS_THROW_IF_NOT_MSG(
            args.ignoreOutDistances || args.outDistances,
            "bfKnn_tiling: outDistances must be provided (passed null)");
    if (is_similarity_metric(args.metric)) {
        // Similarity metrics use the CMin comparator for the merge heap.
        if (args.outIndicesType == IndicesDataType::I64) {
            bfKnn_shard_database<CMin<float, int64_t>>(
                    prov, args, shard_size, distance_size);
        } else if (args.outIndicesType == IndicesDataType::I32) {
            bfKnn_shard_database<CMin<float, int32_t>>(
                    prov, args, shard_size, distance_size);
        } else {
            FAISS_THROW_MSG("bfKnn_tiling: unknown outIndicesType");
        }
    } else {
        // All other metrics use the CMax comparator for the merge heap.
        if (args.outIndicesType == IndicesDataType::I64) {
            bfKnn_shard_database<CMax<float, int64_t>>(
                    prov, args, shard_size, distance_size);
        } else if (args.outIndicesType == IndicesDataType::I32) {
            bfKnn_shard_database<CMax<float, int32_t>>(
                    prov, args, shard_size, distance_size);
        } else {
            FAISS_THROW_MSG("bfKnn_tiling: unknown outIndicesType");
        }
    }
}
| 466 | + |
/// Brute-force k-NN that tiles both the queries and the database so that each
/// underlying search stays within the given byte budgets.
///
/// Queries are processed in batches. For every batch the query pointer and
/// the output pointers are advanced to the batch's first row, and the batch is
/// handed to bfKnn_single_query_shard(), which applies `vectorsMemoryLimit`
/// to the database side.
///
/// The batch size is
/// `queriesMemoryLimit / (k * (distance_size + label_size) + dims * distance_size)`
/// rows, where `distance_size` is the byte size of one query component and
/// `label_size` the byte size of one output index.
///
/// @param prov                forwarded to bfKnn_single_query_shard()
/// @param args                search parameters
/// @param vectorsMemoryLimit  byte budget for database shards (0 = no limit)
/// @param queriesMemoryLimit  byte budget for one query batch; 0 disables
///                            query tiling and forwards the call as-is
///
/// @throws FaissException when query tiling is requested and: `numQueries <= 0`,
///         `queries` is null or not in CPU memory, the queries are not row
///         major, `k <= 0`, `queryType` is neither F32 nor F16,
///         `outIndicesType` is neither I64 nor I32, the limit is smaller than
///         one row, or `outIndices` is null.
void bfKnn_tiling(
        GpuResourcesProvider* prov,
        const GpuDistanceParams& args,
        size_t vectorsMemoryLimit,
        size_t queriesMemoryLimit) {
    if (queriesMemoryLimit == 0) {
        bfKnn_single_query_shard(prov, args, vectorsMemoryLimit);
        return;
    }
    FAISS_THROW_IF_NOT_MSG(
            args.numQueries > 0, "bfKnn_tiling: numQueries must be > 0");
    FAISS_THROW_IF_NOT_MSG(
            args.queries,
            "bfKnn_tiling: queries must be provided (passed null)");
    FAISS_THROW_IF_NOT_MSG(
            getDeviceForAddress(args.queries) == -1,
            "bfKnn_tiling: queries should be in CPU memory when queriesMemoryLimit > 0");
    FAISS_THROW_IF_NOT_MSG(
            args.queriesRowMajor,
            "bfKnn_tiling: tiling queries is only supported in row major mode");
    FAISS_THROW_IF_NOT_MSG(
            args.k > 0,
            "bfKnn_tiling: tiling queries is only supported for k > 0");
    // Bytes per query component; 0 flags an unsupported type.
    size_t distance_size = args.queryType == DistanceDataType::F32 ? 4
            : args.queryType == DistanceDataType::F16              ? 2
                                                                   : 0;
    FAISS_THROW_IF_NOT_MSG(
            distance_size > 0, "bfKnn_tiling: unknown queryType");
    // Bytes per output index; 0 flags an unsupported type.
    size_t label_size = args.outIndicesType == IndicesDataType::I64 ? 8
            : args.outIndicesType == IndicesDataType::I32           ? 4
                                                                    : 0;
    // Must test label_size here: with a zero label size every batch would
    // write its indices at the same output offset.
    FAISS_THROW_IF_NOT_MSG(
            label_size > 0, "bfKnn_tiling: unknown outIndicesType");
    // Rows per batch: each query costs its k results plus its own components.
    // The denominator is non-zero because k > 0 and both sizes are > 0.
    size_t shard_size = queriesMemoryLimit /
            (args.k * (distance_size + label_size) + args.dims * distance_size);
    FAISS_THROW_IF_NOT_MSG(
            shard_size > 0, "bfKnn_tiling: queriesMemoryLimit is too low");
    FAISS_THROW_IF_NOT_MSG(
            args.outIndices,
            "bfKnn: outIndices must be provided (passed null)");
    for (idx_t i = 0; i < args.numQueries; i += shard_size) {
        GpuDistanceParams args_batch = args;
        args_batch.numQueries = min(shard_size, args.numQueries - i);
        // Byte offset of query row i in the row-major query matrix.
        args_batch.queries =
                (char*)args.queries + distance_size * args.dims * i;
        if (!args_batch.ignoreOutDistances) {
            // k distances per query, addressed in elements.
            args_batch.outDistances = args.outDistances + args.k * i;
        }
        // k indices per query, addressed in bytes because the index type is
        // only known at run time.
        args_batch.outIndices =
                (char*)args.outIndices + args.k * label_size * i;
        bfKnn_single_query_shard(prov, args_batch, vectorsMemoryLimit);
    }
}
| 520 | + |
371 | 521 | // legacy version
|
372 | 522 | void bruteForceKnn(
|
373 | 523 | GpuResourcesProvider* res,
|
|
0 commit comments