diff --git a/INSTALL.md b/INSTALL.md
index 26b51a80b1..515ba7c788 100644
--- a/INSTALL.md
+++ b/INSTALL.md
@@ -109,6 +109,8 @@ section of the wiki](https://github.com/facebookresearch/faiss/wiki/Troubleshoot
 
 ### Building with NVIDIA cuVS
 
+[cuVS](https://docs.rapids.ai/api/cuvs/nightly/) contains state-of-the-art implementations of several algorithms for approximate nearest-neighbor search and clustering on the GPU. It is built on top of the [RAPIDS RAFT](https://github.com/rapidsai/raft) library of high-performance machine learning primitives. Building FAISS with cuVS enabled allows a user to choose, per algorithm, between the regular FAISS GPU implementations and the cuVS implementations.
+
 The libcuvs dependency should be installed via conda:
 1. With CUDA 12.0 - 12.5:
 ```
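To illustrate what this choice looks like from Python, here is a minimal sketch of switching a single GPU index between the classical FAISS implementation and the cuVS one (assuming a cuVS-enabled build and a CUDA-capable device; the random data and parameter values are placeholders, not part of this patch):

```python
import numpy as np
import faiss

d, nlist = 128, 1024
xt = np.random.rand(100000, d).astype('float32')

res = faiss.StandardGpuResources()
config = faiss.GpuIndexIVFFlatConfig()
config.use_cuvs = True  # False selects the classical FAISS GPU code path

index = faiss.GpuIndexIVFFlat(res, d, nlist, faiss.METRIC_L2, config)
index.train(xt)
```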
diff --git a/benchs/bench_ivfflat_cuvs.py b/benchs/bench_ivfflat_cuvs.py
index 3628ec7422..0e3f74207f 100644
--- a/benchs/bench_ivfflat_cuvs.py
+++ b/benchs/bench_ivfflat_cuvs.py
@@ -4,7 +4,7 @@
 # This source code is licensed under the MIT license found in the
 # LICENSE file in the root directory of this source tree.
 #
-# Copyright (c) 2024, NVIDIA CORPORATION.
+# Copyright (c) 2024-2025, NVIDIA CORPORATION.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -25,6 +25,28 @@
 import argparse
 import rmm
 
+try:
+    from faiss.contrib.datasets_fb import \
+        DatasetSIFT1M, DatasetDeep1B, DatasetBigANN
+except ImportError:
+    from faiss.contrib.datasets import \
+        DatasetSIFT1M, DatasetDeep1B, DatasetBigANN
+
+
+# ds = DatasetDeep1B(10**6)
+# ds = DatasetBigANN(nb_M=1)
+ds = DatasetSIFT1M()
+
+xq = ds.get_queries()
+xb = ds.get_database()
+gt = ds.get_groundtruth()
+
+xt = ds.get_train()
+
+nb, d = xb.shape
+nq, d = xq.shape
+nt, d = xt.shape
+
 ######################################################
 # Command-line parsing
 ######################################################
@@ -38,25 +60,23 @@ def aa(*args, **kwargs):
 
 group = parser.add_argument_group('benchmarking options')
 
-aa('--bm_train', default=False, action='store_true',
+aa('--bm_train', default=True,
    help='whether to benchmark train operation on GPU index')
-aa('--bm_add', default=False, action='store_true',
+aa('--bm_add', default=True,
    help='whether to benchmark add operation on GPU index')
 aa('--bm_search', default=True,
    help='whether to benchmark search operation on GPU index')
-aa('--cuvs_only', default=False, action='store_true',
-   help='whether to only produce cuVS enabled benchmarks')
 
 group = parser.add_argument_group('IVF options')
-aa('--n_centroids', default=256, type=int,
+aa('--nlist', default=1024, type=int,
    help="number of IVF centroids")
 
 group = parser.add_argument_group('searching')
 
-aa('--k', default=100, type=int, help='nb of nearest neighbors')
-aa('--nprobe', default=50, help='nb of IVF lists to probe')
+aa('--k', default=10, type=int, help='nb of nearest neighbors')
+aa('--nprobe', default=10, type=int, help='nb of IVF lists to probe')
 
 args = parser.parse_args()
@@ -70,42 +90,38 @@ def aa(*args, **kwargs):
 mr = rmm.mr.PoolMemoryResource(rmm.mr.CudaMemoryResource())
 rmm.mr.set_current_device_resource(mr)
 
-def bench_train_milliseconds(index, trainVecs, use_cuvs):
-    co = faiss.GpuMultipleClonerOptions()
-    co.use_cuvs = use_cuvs
-    index_gpu = faiss.index_cpu_to_gpu(res, 0, index, co)
+
+def bench_train_milliseconds(trainVecs, ncols, nlist, use_cuvs):
+    config = faiss.GpuIndexIVFFlatConfig()
+    config.use_cuvs = use_cuvs
+    index = faiss.GpuIndexIVFFlat(res, ncols, nlist, faiss.METRIC_L2, config)
     t0 = time.time()
-    index_gpu.train(trainVecs)
+    index.train(trainVecs)
     return 1000*(time.time() - t0)
 
+
+# warmup
+xw = rs.rand(nt, d).astype('float32')
+bench_train_milliseconds(xw, d, args.nlist, True)
+
+
 if args.bm_train:
     print("=" * 40)
     print("GPU Train Benchmarks")
     print("=" * 40)
-    trainset_sizes = [5000, 10000, 100000, 1000000, 5000000]
-    dataset_dims = [128, 256, 1024]
-    for n_rows in trainset_sizes:
-        for n_cols in dataset_dims:
-            index = faiss.index_factory(n_cols, "IVF{},Flat".format(args.n_centroids))
-            trainVecs = rs.rand(n_rows, n_cols).astype('float32')
-            cuvs_gpu_train_time = bench_train_milliseconds(
-                index, trainVecs, True)
-            if args.cuvs_only:
-                print("Method: IVFFlat, Operation: TRAIN, dim: %d, n_centroids %d, numTrain: %d, cuVS enabled GPU train time: %.3f milliseconds" % (
-                    n_cols, args.n_centroids, n_rows, cuvs_gpu_train_time))
-            else:
-                classical_gpu_train_time = bench_train_milliseconds(
-                    index, trainVecs, False)
-                print("Method: IVFFlat, Operation: TRAIN, dim: %d, n_centroids %d, numTrain: %d, classical GPU train time: %.3f milliseconds, cuVS enabled GPU train time: %.3f milliseconds" % (
-                    n_cols, args.n_centroids, n_rows, classical_gpu_train_time, cuvs_gpu_train_time))
-
-
-def bench_add_milliseconds(index, addVecs, use_cuvs):
-    co = faiss.GpuMultipleClonerOptions()
-    co.use_cuvs = use_cuvs
-    index_gpu = faiss.index_cpu_to_gpu(res, 0, index, co)
-    index_gpu.copyFrom(index)
+
+    cuvs_gpu_train_time = bench_train_milliseconds(xt, d, args.nlist, True)
+    classical_gpu_train_time = bench_train_milliseconds(xt, d, args.nlist, False)
+    print("Method: IVFFlat, Operation: TRAIN, dim: %d, nlist %d, numTrain: %d, classical GPU train time: %.3f milliseconds, cuVS enabled GPU train time: %.3f milliseconds" % (
+        d, args.nlist, nt, classical_gpu_train_time, cuvs_gpu_train_time))
+
+
+def bench_add_milliseconds(addVecs, q, use_cuvs):
+    # construct a GPU index using the same trained coarse quantizer
+    config = faiss.GpuIndexIVFFlatConfig()
+    config.use_cuvs = use_cuvs
+    index_gpu = faiss.GpuIndexIVFFlat(res, q, d, args.nlist, faiss.METRIC_L2, config)
+    assert(index_gpu.is_trained)
     t0 = time.time()
     index_gpu.add(addVecs)
     return 1000*(time.time() - t0)
@@ -115,33 +131,19 @@ def bench_add_milliseconds(index, addVecs, use_cuvs):
     print("=" * 40)
     print("GPU Add Benchmarks")
     print("=" * 40)
-    addset_sizes = [5000, 10000, 100000, 1000000]
-    dataset_dims = [128, 256, 1024]
-    n_train = 10000
-    trainVecs = rs.rand(n_train, n_cols).astype('float32')
-    index = faiss.index_factory(
-        n_cols, "IVF" + str(args.n_centroids) + ",Flat")
-    index.train(trainVecs)
-    for n_rows in addset_sizes:
-        for n_cols in dataset_dims:
-            addVecs = rs.rand(n_rows, n_cols).astype('float32')
-            cuvs_gpu_add_time = bench_add_milliseconds(index, addVecs, True)
-            if args.cuvs_only:
-                print("Method: IVFFlat, Operation: ADD, dim: %d, n_centroids %d, numAdd: %d, cuVS enabled GPU add time: %.3f milliseconds" % (
-                    n_train, n_rows, n_cols, args.n_centroids, cuvs_gpu_add_time))
-            else:
-                classical_gpu_add_time = bench_add_milliseconds(
-                    index, addVecs, False)
-                print("Method: IVFFlat, Operation: ADD, dim: %d, n_centroids %d, numAdd: %d, classical GPU add time: %.3f milliseconds, cuVS enabled GPU add time: %.3f milliseconds" % (
-                    n_train, n_rows, n_cols, args.n_centroids, classical_gpu_add_time, cuvs_gpu_add_time))
-
-
-def bench_search_milliseconds(index, addVecs, queryVecs, nprobe, k, use_cuvs):
-    co = faiss.GpuMultipleClonerOptions()
+    quantizer = faiss.IndexFlatL2(d)
+    idx_cpu = faiss.IndexIVFFlat(quantizer, d, args.nlist)
+    idx_cpu.train(xt)
+    cuvs_gpu_add_time = bench_add_milliseconds(xb, quantizer, True)
+    classical_gpu_add_time = bench_add_milliseconds(xb, quantizer, False)
+    print("Method: IVFFlat, Operation: ADD, dim: %d, nlist %d, numAdd: %d, classical GPU add time: %.3f milliseconds, cuVS enabled GPU add time: %.3f milliseconds" % (
+        d, args.nlist, nb, classical_gpu_add_time, cuvs_gpu_add_time))
+
+
+def bench_search_milliseconds(index, queryVecs, nprobe, k, use_cuvs):
+    co = faiss.GpuClonerOptions()
     co.use_cuvs = use_cuvs
     index_gpu = faiss.index_cpu_to_gpu(res, 0, index, co)
-    index_gpu.copyFrom(index)
-    index_gpu.add(addVecs)
     index_gpu.nprobe = nprobe
     t0 = time.time()
     index_gpu.search(queryVecs, k)
@@ -152,43 +154,14 @@ def bench_search_milliseconds(index, addVecs, queryVecs, nprobe, k, use_cuvs):
     print("=" * 40)
     print("GPU Search Benchmarks")
     print("=" * 40)
-    queryset_sizes = [5000, 10000, 100000, 500000]
-    n_train = 10000
-    n_add = 100000
-    search_bm_dims = [8, 16, 32]
-    for n_cols in search_bm_dims:
-        index = faiss.index_factory(n_cols, "IVF{},Flat".format(args.n_centroids))
-        trainVecs = rs.rand(n_train, n_cols).astype('float32')
-        index.train(trainVecs)
-        addVecs = rs.rand(n_add, n_cols).astype('float32')
-        for n_rows in queryset_sizes:
-            queryVecs = rs.rand(n_rows, n_cols).astype('float32')
-            cuvs_gpu_search_time = bench_search_milliseconds(
-                index, addVecs, queryVecs, args.nprobe, args.k, True)
-            if args.cuvs_only:
-                print("Method: IVFFlat, Operation: SEARCH, dim: %d, n_centroids: %d, numVecs: %d, numQuery: %d, nprobe: %d, k: %d, cuVS enabled GPU search time: %.3f milliseconds" % (
-                    n_cols, args.n_centroids, n_add, n_rows, args.nprobe, args.k, cuvs_gpu_search_time))
-            else:
-                classical_gpu_search_time = bench_search_milliseconds(
-                    index, addVecs, queryVecs, args.nprobe, args.k, False)
-                print("Method: IVFFlat, Operation: SEARCH, dim: %d, n_centroids: %d, numVecs: %d, numQuery: %d, nprobe: %d, k: %d, classical GPU search time: %.3f milliseconds, cuVS enabled GPU search time: %.3f milliseconds" % (
-                    n_cols, args.n_centroids, n_add, n_rows, args.nprobe, args.k, classical_gpu_search_time, cuvs_gpu_search_time))
-
-    print("=" * 40)
-    print("Large cuVS Enabled Benchmarks")
-    print("=" * 40)
-    # Avoid classical GPU Benchmarks for large datasets because of OOM for more than 500000 queries and/or large dims as well as for large k
-    queryset_sizes = [100000, 500000, 1000000]
-    large_search_bm_dims = [128, 256, 1024]
-    for n_cols in large_search_bm_dims:
-        trainVecs = rs.rand(n_train, n_cols).astype('float32')
-        index = faiss.index_factory(
-            n_cols, "IVF" + str(args.n_centroids) + ",Flat")
-        index.train(trainVecs)
-        addVecs = rs.rand(n_add, n_cols).astype('float32')
-        for n_rows in queryset_sizes:
-            queryVecs = rs.rand(n_rows, n_cols).astype('float32')
-            cuvs_gpu_search_time = bench_search_milliseconds(
-                index, addVecs, queryVecs, args.nprobe, args.k, True)
-            print("Method: IVFFlat, Operation: SEARCH, numTrain: %d, dim: %d, n_centroids: %d, numVecs: %d, numQuery: %d, nprobe: %d, k: %d, cuVS enabled GPU search time: %.3f milliseconds" % (
-                n_cols, args.n_centroids, n_add, n_rows, args.nprobe, args.k, cuvs_gpu_search_time))
+    idx_cpu = faiss.IndexIVFFlat(
+        faiss.IndexFlatL2(d), d, args.nlist)
+    idx_cpu.train(xt)
+    idx_cpu.add(xb)
+
+    cuvs_gpu_search_time = bench_search_milliseconds(
+        idx_cpu, xq, args.nprobe, args.k, True)
+    classical_gpu_search_time = bench_search_milliseconds(
+        idx_cpu, xq, args.nprobe, args.k, False)
+    print("Method: IVFFlat, Operation: SEARCH, dim: %d, nlist: %d, numVecs: %d, numQuery: %d, nprobe: %d, k: %d, classical GPU search time: %.3f milliseconds, cuVS enabled GPU search time: %.3f milliseconds" % (
+        d, args.nlist, nb, nq, args.nprobe, args.k, classical_gpu_search_time, cuvs_gpu_search_time))
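The rewritten search benchmark times both implementations starting from one trained CPU index, cloned twice to the GPU. A condensed sketch of that comparison pattern (same APIs as the benchmark; data and sizes are illustrative only):

```python
import numpy as np
import faiss

d, nlist, nprobe, k = 128, 1024, 10, 10
xb = np.random.rand(100000, d).astype('float32')
xq = np.random.rand(1000, d).astype('float32')

idx_cpu = faiss.IndexIVFFlat(faiss.IndexFlatL2(d), d, nlist)
idx_cpu.train(xb)
idx_cpu.add(xb)

res = faiss.StandardGpuResources()
for use_cuvs in (False, True):
    co = faiss.GpuClonerOptions()
    co.use_cuvs = use_cuvs
    idx_gpu = faiss.index_cpu_to_gpu(res, 0, idx_cpu, co)
    idx_gpu.nprobe = nprobe
    D, I = idx_gpu.search(xq, k)  # both clones share identical IVF lists
```

Because both GPU indexes are cloned from the same trained CPU index, timing differences reflect the implementation rather than the index contents.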
diff --git a/benchs/bench_ivfpq_cuvs.py b/benchs/bench_ivfpq_cuvs.py
index 7668afffea..924f24038a 100644
--- a/benchs/bench_ivfpq_cuvs.py
+++ b/benchs/bench_ivfpq_cuvs.py
@@ -1,9 +1,10 @@
-# Copyright (c) Facebook, Inc. and its affiliates.
+# @lint-ignore-every LICENSELINT
+# Copyright (c) Meta Platforms, Inc. and its affiliates.
 #
 # This source code is licensed under the MIT license found in the
 # LICENSE file in the root directory of this source tree.
 #
-# Copyright (c) 2024, NVIDIA CORPORATION.
+# Copyright (c) 2024-2025, NVIDIA CORPORATION.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -23,6 +24,31 @@
 import time
 import argparse
 import rmm
 
+try:
+    from faiss.contrib.datasets_fb import \
+        DatasetSIFT1M, DatasetDeep1B, DatasetBigANN
+except ImportError:
+    from faiss.contrib.datasets import \
+        DatasetSIFT1M, DatasetDeep1B, DatasetBigANN
+
+
+# ds = DatasetDeep1B(10**6)
+# ds = DatasetBigANN(nb_M=1)
+ds = DatasetSIFT1M()
+
+xq = ds.get_queries()
+xb = ds.get_database()
+gt = ds.get_groundtruth()
+
+xt = ds.get_train()
+
+nb, d = xb.shape
+nq, d = xq.shape
+nt, d = xt.shape
+
+# number of PQ subquantizers (d must be divisible by this)
+M = d // 2
 
 ######################################################
 # Command-line parsing
 ######################################################
@@ -30,33 +56,40 @@
 
 parser = argparse.ArgumentParser()
 
-from datasets import load_sift1M, evaluate
-
-
-print("load data")
-xb, xq, xt, gt = load_sift1M()
 
 def aa(*args, **kwargs):
     group.add_argument(*args, **kwargs)
 
 
 group = parser.add_argument_group('benchmarking options')
-aa('--cuvs_only', default=False, action='store_true',
-   help='whether to only produce cuVS enabled benchmarks')
+
+aa('--bm_train', default=True,
+   help='whether to benchmark train operation on GPU index')
+aa('--bm_add', default=True,
+   help='whether to benchmark add operation on GPU index')
+aa('--bm_search', default=True,
+   help='whether to benchmark search operation on GPU index')
+
 
 group = parser.add_argument_group('IVF options')
-aa('--bits_per_code', default=8, type=int, help='bits per code. Note that < 8 is only supported when cuVS is enabled')
-aa('--pq_len', default=2, type=int, help='number of vector elements represented by one PQ code')
-aa('--use_precomputed', default=True, type=bool, help='use precomputed codes (not with cuVS enabled)')
+aa('--nlist', default=1024, type=int,
+   help="number of IVF centroids")
+aa('--bits_per_code', default=8, type=int, help='bits per code. Note that < 8 is only supported when cuVS is enabled')
 
 group = parser.add_argument_group('searching')
+
 aa('--k', default=10, type=int, help='nb of nearest neighbors')
-aa('--nprobe', default=50, type=int, help='nb of IVF lists to probe')
+aa('--nprobe', default=10, type=int, help='nb of IVF lists to probe')
 
 args = parser.parse_args()
 
 print("args:", args)
 
+gt = gt[:, :args.k]
+nlist = args.nlist
+bits_per_code = args.bits_per_code
+
 rs = np.random.RandomState(123)
 
 res = faiss.StandardGpuResources()
@@ -65,104 +98,90 @@ def aa(*args, **kwargs):
 mr = rmm.mr.PoolMemoryResource(rmm.mr.CudaMemoryResource())
 rmm.mr.set_current_device_resource(mr)
 
-# A heuristic to select a suitable number of lists
-def compute_nlist(numVecs):
-    nlist = np.sqrt(numVecs)
-    if (numVecs / nlist < 1000):
-        nlist = numVecs / 1000
-    return int(nlist)
+
+def eval_recall(neighbors, t):
+    # t is the total search time in milliseconds for the nq queries
+    qps = nq * 1000.0 / t
+
+    # element-wise match against the ground truth: a cheap approximation
+    # of recall@k that only credits neighbors returned at the same rank
+    corrects = (gt == neighbors).sum()
+    recall = corrects / (nq * args.k)
+
+    return recall, qps
 
 
-def bench_train_milliseconds(index, trainVecs, use_cuvs):
-    co = faiss.GpuMultipleClonerOptions()
-    # use float 16 lookup tables to save space
-    co.useFloat16LookupTables = True
-    co.use_cuvs = use_cuvs
-    index_gpu = faiss.index_cpu_to_gpu(res, 0, index, co)
+def bench_train_milliseconds(trainVecs, use_cuvs):
+    config = faiss.GpuIndexIVFPQConfig()
+    config.use_cuvs = use_cuvs
+    index = faiss.GpuIndexIVFPQ(res, d, nlist, M, bits_per_code, faiss.METRIC_L2, config)
     t0 = time.time()
-    index_gpu.train(trainVecs)
+    index.train(trainVecs)
    return 1000*(time.time() - t0)
 
-n_rows, n_cols = xb.shape
-n_train, _ = xt.shape
-M = n_cols // args.pq_len
-nlist = compute_nlist(n_rows)
-index = faiss.index_factory(n_cols, "IVF{},PQ{}x{}np".format(nlist, M, args.bits_per_code))
-
-print("=" * 40)
-print("GPU Train Benchmarks")
-print("=" * 40)
-cuvs_gpu_train_time = bench_train_milliseconds(index, xt, True)
-if args.cuvs_only:
-    print("Method: IVFPQ, Operation: TRAIN, dim: %d, n_centroids %d, numSubQuantizers %d, bitsPerCode %d, numTrain: %d, cuVS enabled GPU train time: %.3f milliseconds" % (
-        n_cols, nlist, M, args.bits_per_code, n_train, cuvs_gpu_train_time))
-else:
-    classical_gpu_train_time = bench_train_milliseconds(
-        index, xt, False)
-    print("Method: IVFPQ, Operation: TRAIN, dim: %d, n_centroids %d, numSubQuantizers %d, bitsPerCode %d, numTrain: %d, classical GPU train time: %.3f milliseconds, cuVS enabled GPU train time: %.3f milliseconds" % (
-        n_cols, nlist, M, args.bits_per_code, n_train, classical_gpu_train_time, cuvs_gpu_train_time))
+
+# warmup
+xw = rs.rand(nt, d).astype('float32')
+bench_train_milliseconds(xw, True)
+
+
+if args.bm_train:
+    print("=" * 40)
+    print("GPU Train Benchmarks")
+    print("=" * 40)
+
+    cuvs_gpu_train_time = bench_train_milliseconds(xt, True)
+    classical_gpu_train_time = bench_train_milliseconds(xt, False)
+    print("TRAIN, dim: %d, nlist %d, numTrain: %d, classical GPU train time: %.3f milliseconds, cuVS enabled GPU train time: %.3f milliseconds" % (
+        d, nlist, nt, classical_gpu_train_time, cuvs_gpu_train_time))
 
 
-def bench_add_milliseconds(index, addVecs, use_cuvs):
-    co = faiss.GpuMultipleClonerOptions()
-    # use float 16 lookup tables to save space
-    co.useFloat16LookupTables = True
-    co.use_cuvs = use_cuvs
-    index_gpu = faiss.index_cpu_to_gpu(res, 0, index, co)
-    index_gpu.copyFrom(index)
+def bench_add_milliseconds(addVecs, index_cpu, use_cuvs):
+    # clone a GPU index that reuses the already trained coarse quantizer
+    config = faiss.GpuClonerOptions()
+    config.use_cuvs = use_cuvs
+    index_gpu = faiss.index_cpu_to_gpu(res, 0, index_cpu, config)
+    assert(index_gpu.is_trained)
     t0 = time.time()
     index_gpu.add(addVecs)
     return 1000*(time.time() - t0)
 
-print("=" * 40)
-print("GPU Add Benchmarks")
-print("=" * 40)
-index.train(xt)
-cuvs_gpu_add_time = bench_add_milliseconds(index, xb, True)
-if args.cuvs_only:
-    print("Method: IVFPQ, Operation: ADD, dim: %d, n_centroids %d numSubQuantizers %d, bitsPerCode %d, numAdd %d, cuVS enabled GPU add time: %.3f milliseconds" % (
-        n_cols, nlist, M, args.bits_per_code, n_rows, cuvs_gpu_add_time))
-else:
-    classical_gpu_add_time = bench_add_milliseconds(
-        index, xb, False)
-    print("Method: IVFFPQ, Operation: ADD, dim: %d, n_centroids %d, numSubQuantizers %d, bitsPerCode %d, numAdd %d, classical GPU add time: %.3f milliseconds, cuVS enabled GPU add time: %.3f milliseconds" % (
-        n_cols, nlist, M, args.bits_per_code, n_rows, classical_gpu_add_time, cuvs_gpu_add_time))
-
-
-def bench_search_milliseconds(index, addVecs, queryVecs, nprobe, k, use_cuvs):
-    co = faiss.GpuMultipleClonerOptions()
+
+if args.bm_add:
+    print("=" * 40)
+    print("GPU Add Benchmarks")
+    print("=" * 40)
+    quantizer = faiss.IndexFlatL2(d)
+    index_cpu = faiss.IndexIVFPQ(quantizer, d, nlist, M, bits_per_code, faiss.METRIC_L2)
+    index_cpu.train(xt)
+    cuvs_gpu_add_time = bench_add_milliseconds(xb, index_cpu, True)
+    classical_gpu_add_time = bench_add_milliseconds(xb, index_cpu, False)
+    print("ADD, dim: %d, nlist %d, numAdd: %d, classical GPU add time: %.3f milliseconds, cuVS enabled GPU add time: %.3f milliseconds" % (
+        d, nlist, nb, classical_gpu_add_time, cuvs_gpu_add_time))
+
+
+def bench_search_milliseconds(index, queryVecs, nprobe, k, use_cuvs):
+    co = faiss.GpuClonerOptions()
     co.use_cuvs = use_cuvs
-    co.useFloat16LookupTables = True
     index_gpu = faiss.index_cpu_to_gpu(res, 0, index, co)
-    index_gpu.copyFrom(index)
-    index_gpu.add(addVecs)
     index_gpu.nprobe = nprobe
     t0 = time.time()
-    index_gpu.search(queryVecs, k)
-    return 1000*(time.time() - t0)
+    _, I = index_gpu.search(queryVecs, k)
+    return I, 1000*(time.time() - t0)
 
 
+# Search benchmarks: both GPU indexes are cloned from the same trained CPU
+# index, so they share identical IVF centroids and inverted lists.
 if args.bm_search:
     print("=" * 40)
     print("GPU Search Benchmarks")
     print("=" * 40)
-    queryset_sizes = [1, 10, 100, 1000, 10000]
-    n_train, n_cols = xt.shape
-    n_add, _ = xb.shape
-    print(xq.shape)
-    M = n_cols // args.pq_len
-    nlist = compute_nlist(n_add)
-    index = faiss.index_factory(n_cols, "IVF{},PQ{}x{}np".format(nlist, M, args.bits_per_code))
-    index.train(xt)
-    for n_rows in queryset_sizes:
-        queryVecs = xq[np.random.choice(xq.shape[0], n_rows, replace=False)]
-        cuvs_gpu_search_time = bench_search_milliseconds(
-            index, xb, queryVecs, args.nprobe, args.k, True)
-        if args.cuvs_only:
-            print("Method: IVFPQ, Operation: SEARCH, dim: %d, n_centroids: %d, numSubQuantizers %d, bitsPerCode %d, numVecs: %d, numQuery: %d, nprobe: %d, k: %d, cuVS enabled GPU search time: %.3f milliseconds" % (
-                n_cols, nlist, M, args.bits_per_code, n_add, n_rows, args.nprobe, args.k, cuvs_gpu_search_time))
-        else:
-            classical_gpu_search_time = bench_search_milliseconds(
-                index, xb, queryVecs, args.nprobe, args.k, False)
-            print("Method: IVFPQ, Operation: SEARCH, dim: %d, n_centroids: %d, numSubQuantizers %d, bitsPerCode %d, numVecs: %d, numQuery: %d, nprobe: %d, k: %d, classical GPU search time: %.3f milliseconds, cuVS enabled GPU search time: %.3f milliseconds" % (
-                n_cols, nlist, M, args.bits_per_code, n_add, n_rows, args.nprobe, args.k, classical_gpu_search_time, cuvs_gpu_search_time))
\ No newline at end of file
+    index_cpu = faiss.IndexIVFPQ(
+        faiss.IndexFlatL2(d), d, nlist, M, bits_per_code, faiss.METRIC_L2)
+    index_cpu.train(xt)
+    index_cpu.add(xb)
+
+    cuvs_indices, cuvs_gpu_search_time = bench_search_milliseconds(
+        index_cpu, xq, args.nprobe, args.k, True)
+    classical_gpu_indices, classical_gpu_search_time = bench_search_milliseconds(
+        index_cpu, xq, args.nprobe, args.k, False)
+    cuvs_recall, cuvs_qps = eval_recall(cuvs_indices, cuvs_gpu_search_time)
+    classical_recall, classical_qps = eval_recall(classical_gpu_indices, classical_gpu_search_time)
+    print("SEARCH, dim: %d, nlist: %d, numVecs: %d, numQuery: %d, nprobe: %d, k: %d, classical GPU qps: %.3f, recall: %.4f, cuVS enabled GPU qps: %.3f, recall: %.4f" % (
+        d, nlist, nb, nq, args.nprobe, args.k, classical_qps, classical_recall, cuvs_qps, cuvs_recall))
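One caveat worth knowing about `eval_recall`: the element-wise `gt == neighbors` comparison only credits a neighbor returned at exactly the same rank as in the ground truth, so it understates recall when results are merely reordered. A rank-insensitive variant, sketched here for two `(nq, k)` integer arrays (a hypothetical helper, not part of this patch):

```python
import numpy as np

def recall_at_k(gt, neighbors, k):
    # count ground-truth ids found anywhere in the returned top-k,
    # regardless of rank
    found = sum(
        np.intersect1d(gt[i, :k], neighbors[i, :k]).size
        for i in range(gt.shape[0])
    )
    return found / (gt.shape[0] * k)
```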
diff --git a/faiss/gpu/CMakeLists.txt b/faiss/gpu/CMakeLists.txt
index 16574aab61..04d28907d1 100644
--- a/faiss/gpu/CMakeLists.txt
+++ b/faiss/gpu/CMakeLists.txt
@@ -273,7 +273,7 @@ if(FAISS_ENABLE_CUVS)
   target_compile_definitions(faiss_avx512_spr PUBLIC USE_NVIDIA_CUVS=1)
 
   # Mark all functions as hidden so that we don't generate
-  # global 'public' functions that also exist in libraft.so
+  # global 'public' functions that also exist in libcuvs.so
   #
   # This ensures that faiss functions will call the local version
   # inside libfaiss.so . This is needed to ensure that things
@@ -285,8 +285,13 @@ if(FAISS_ENABLE_CUVS)
   # respective classes/types in the headers are explicitly marked
   # as 'public' so they can be used by consumers
   set_source_files_properties(
+    GpuIndexCagra.cu
     GpuDistance.cu
+    GpuIndexIVFFlat.cu
+    GpuIndexIVFPQ.cu
+    GpuIndexFlat.cu
     StandardGpuResources.cpp
+    impl/CuvsCagra.cu
     impl/CuvsFlatIndex.cu
     impl/CuvsIVFFlat.cu
     impl/CuvsIVFPQ.cu
diff --git a/faiss/gpu/GpuIndexIVFFlat.cu b/faiss/gpu/GpuIndexIVFFlat.cu
index ceeb2dda76..eb5dacc1cd 100644
--- a/faiss/gpu/GpuIndexIVFFlat.cu
+++ b/faiss/gpu/GpuIndexIVFFlat.cu
@@ -72,9 +72,6 @@ GpuIndexIVFFlat::GpuIndexIVFFlat(
           config),
         ivfFlatConfig_(config),
         reserveMemoryVecs_(0) {
-    FAISS_THROW_IF_NOT_MSG(
-        !should_use_cuvs(config),
-        "GpuIndexIVFFlat: cuVS does not support separate coarseQuantizer");
     // We could have been passed an already trained coarse quantizer. There is
     // no other quantizer that we need to train, so this is sufficient
     if (this->is_trained) {
diff --git a/faiss/gpu/test/test_gpu_index.py b/faiss/gpu/test/test_gpu_index.py
index d3892e190d..287f27d958 100755
--- a/faiss/gpu/test/test_gpu_index.py
+++ b/faiss/gpu/test/test_gpu_index.py
@@ -141,7 +141,6 @@ def test_ivfflat_cpu_coarse(self):
         # construct a GPU index using the same trained coarse quantizer
         # from the CPU index
         config = faiss.GpuIndexIVFFlatConfig()
-        config.use_cuvs = False
         idx_gpu = faiss.GpuIndexIVFFlat(res, q, d, nlist, faiss.METRIC_L2, config)
         assert(idx_gpu.is_trained)
         idx_gpu.add(xb)
@@ -156,6 +155,7 @@ def test_ivfflat_cpu_coarse(self):
         self.assertGreaterEqual((i_g == i_c).sum(), i_g.size * 0.9)
         self.assertTrue(np.allclose(d_g, d_c, rtol=5e-5, atol=5e-5))
 
+
     def test_ivfsq_cpu_coarse(self):
         res = faiss.StandardGpuResources()
         d = 128
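The removed `FAISS_THROW_IF_NOT_MSG` guard is what previously made the cuVS path reject a separately trained coarse quantizer; with it gone, the pattern exercised by `test_ivfflat_cpu_coarse` also works with `use_cuvs = True`. A minimal sketch of that pattern (assuming a cuVS-enabled build; data sizes are placeholders):

```python
import numpy as np
import faiss

d, nlist = 128, 1024
xt = np.random.rand(10000, d).astype('float32')
xb = np.random.rand(100000, d).astype('float32')

# train a coarse quantizer on the CPU via a throwaway IVF index
quantizer = faiss.IndexFlatL2(d)
idx_cpu = faiss.IndexIVFFlat(quantizer, d, nlist)
idx_cpu.train(xt)

res = faiss.StandardGpuResources()
config = faiss.GpuIndexIVFFlatConfig()
config.use_cuvs = True  # previously raised "cuVS does not support separate coarseQuantizer"
idx_gpu = faiss.GpuIndexIVFFlat(res, quantizer, d, nlist, faiss.METRIC_L2, config)
assert idx_gpu.is_trained
idx_gpu.add(xb)
```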