
Commit 875c168

algoriddle authored and facebook-github-bot committed
tiling bfKnn
Summary: Adding tiling support for bfKnn, breaking up both queries and vectors into tiles of at most queriesMemoryLimit and vectorsMemoryLimit bytes, respectively.

Differential Revision: D45944524

fbshipit-source-id: 9dfab73338601c6278171a37282694273473ace7
1 parent 48d48a3 commit 875c168

File tree

faiss/gpu/GpuDistance.cu
faiss/gpu/GpuDistance.h
faiss/gpu/test/test_gpu_basics.py
faiss/python/gpu_wrappers.py
faiss/utils/Heap.cpp

5 files changed: +174 −5 lines changed
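
For orientation before the diff: the tiling described in the summary amounts to scanning the database (and the query set) in fixed-size chunks and merging each chunk's top-k results into a running per-query top-k, with chunk-local labels offset by the chunk's starting row. A minimal NumPy sketch of that merge idea follows; knn_tiled and chunk_rows are illustrative names, not part of FAISS, and the real implementation performs each per-chunk search on the GPU via bfKnn.

import numpy as np

def knn_tiled(xq, xb, k, chunk_rows):
    # Running top-k (squared L2) per query, merged across database chunks.
    nq = xq.shape[0]
    best_d = np.full((nq, k), np.inf)
    best_i = np.full((nq, k), -1, dtype=np.int64)
    for start in range(0, xb.shape[0], chunk_rows):
        chunk = xb[start:start + chunk_rows]
        # Brute-force distances from every query to this chunk only.
        d = ((xq[:, None, :] - chunk[None, :, :]) ** 2).sum(-1)
        i = np.broadcast_to(
            np.arange(start, start + chunk.shape[0]), d.shape)
        # Merge the chunk results into the running top-k and re-select the k best.
        cat_d = np.concatenate([best_d, d], axis=1)
        cat_i = np.concatenate([best_i, i], axis=1)
        order = np.argsort(cat_d, axis=1)[:, :k]
        best_d = np.take_along_axis(cat_d, order, axis=1)
        best_i = np.take_along_axis(cat_i, order, axis=1)
    return best_d, best_i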

faiss/gpu/GpuDistance.cu

+124 −1

@@ -24,6 +24,7 @@
 #include <faiss/gpu/GpuResources.h>
 #include <faiss/gpu/utils/DeviceUtils.h>
 #include <faiss/impl/FaissAssert.h>
+#include <faiss/utils/Heap.h>
 #include <faiss/gpu/impl/Distance.cuh>
 #include <faiss/gpu/utils/ConversionOperators.cuh>
 #include <faiss/gpu/utils/CopyUtils.cuh>
@@ -218,7 +219,9 @@ void bfKnnConvert(GpuResourcesProvider* prov, const GpuDistanceParams& args) {
     fromDevice<float, 2>(tOutDistances, args.outDistances, stream);
 }
 
-void bfKnn(GpuResourcesProvider* prov, const GpuDistanceParams& args) {
+void bfKnn_single_tile(
+        GpuResourcesProvider* prov,
+        const GpuDistanceParams& args) {
     // For now, both vectors and queries must be of the same data type
     FAISS_THROW_IF_NOT_MSG(
             args.vectorType == args.queryType,
@@ -368,6 +371,126 @@ void bfKnn(GpuResourcesProvider* prov, const GpuDistanceParams& args) {
     }
 }
 
+template <class C>
+void bfKnn_shard_database(
+        GpuResourcesProvider* prov,
+        const GpuDistanceParams& args,
+        idx_t shard_size,
+        idx_t distance_size) {
+    std::vector<typename C::T> heaps_distances;
+    if (args.ignoreOutDistances) {
+        heaps_distances.resize(args.numQueries * args.k, 0);
+    }
+    HeapArray<C> heaps = {
+            (size_t)args.numQueries,
+            (size_t)args.k,
+            (typename C::TI*)args.outIndices,
+            args.ignoreOutDistances ? heaps_distances.data()
+                                    : args.outDistances};
+    heaps.heapify();
+    std::vector<typename C::TI> labels(args.numQueries * args.k, -1);
+    std::vector<typename C::T> distances(args.numQueries * args.k, 0);
+    GpuDistanceParams args_batch = args;
+    args_batch.outDistances = distances.data();
+    args_batch.ignoreOutDistances = false;
+    args_batch.outIndices = labels.data();
+    for (idx_t i = 0; i < args.numVectors; i += shard_size) {
+        args_batch.numVectors = min(shard_size, args.numVectors - i);
+        args_batch.vectors =
+                (char*)args.vectors + distance_size * args.dims * i;
+        args_batch.vectorNorms =
+                args.vectorNorms ? args.vectorNorms + i : nullptr;
+        bfKnn_single_tile(prov, args_batch);
+        for (auto& label : labels) {
+            label += i;
+        }
+        heaps.addn_with_ids(args.k, distances.data(), labels.data(), args.k);
+    }
+    heaps.reorder();
+}
+
+void bfKnn_single_query_shard(
+        GpuResourcesProvider* prov,
+        const GpuDistanceParams& args) {
+    if (args.vectorsMemoryLimit == 0) {
+        bfKnn_single_tile(prov, args);
+        return;
+    }
+    FAISS_THROW_IF_NOT_MSG(
+            args.vectorsRowMajor,
+            "sharding vectors is only supported in row major mode");
+    FAISS_THROW_IF_NOT_MSG(
+            args.k > 0, "sharding vectors is only supported for k > 0");
+    idx_t distance_size = args.vectorType == DistanceDataType::F32 ? 4
+            : args.vectorType == DistanceDataType::F16             ? 2
+                                                                   : 0;
+    FAISS_THROW_IF_NOT_MSG(distance_size > 0, "unknown vectorType");
+    idx_t shard_size = args.vectorsMemoryLimit / (args.dims * distance_size);
+    FAISS_THROW_IF_NOT_MSG(
+            shard_size > 0,
+            "vectorsMemoryLimit is too low, shard size would be zero");
+    if (args.numVectors <= shard_size) {
+        bfKnn_single_tile(prov, args);
+        return;
+    }
+    if (is_similarity_metric(args.metric)) {
+        if (args.outIndicesType == IndicesDataType::I64) {
+            bfKnn_shard_database<CMin<float, int64_t>>(
+                    prov, args, shard_size, distance_size);
+        } else if (args.outIndicesType == IndicesDataType::I32) {
+            bfKnn_shard_database<CMin<float, int32_t>>(
+                    prov, args, shard_size, distance_size);
+        } else {
+            FAISS_THROW_MSG("unknown outIndicesType");
+        }
+    } else {
+        if (args.outIndicesType == IndicesDataType::I64) {
+            bfKnn_shard_database<CMax<float, int64_t>>(
+                    prov, args, shard_size, distance_size);
+        } else if (args.outIndicesType == IndicesDataType::I32) {
+            bfKnn_shard_database<CMax<float, int32_t>>(
+                    prov, args, shard_size, distance_size);
+        } else {
+            FAISS_THROW_MSG("unknown outIndicesType");
+        }
+    }
+}
+
+void bfKnn(GpuResourcesProvider* prov, const GpuDistanceParams& args) {
+    if (args.queriesMemoryLimit == 0) {
+        bfKnn_single_query_shard(prov, args);
+        return;
+    }
+    FAISS_THROW_IF_NOT_MSG(
+            args.queriesRowMajor,
+            "sharding queries is only supported in row major mode");
+    FAISS_THROW_IF_NOT_MSG(
+            args.k > 0, "sharding queries is only supported for k > 0");
+    idx_t distance_size = args.queryType == DistanceDataType::F32 ? 4
+            : args.queryType == DistanceDataType::F16             ? 2
+                                                                  : 0;
+    FAISS_THROW_IF_NOT_MSG(distance_size > 0, "unknown queryType");
+    idx_t label_size = args.outIndicesType == IndicesDataType::I64 ? 8
+            : args.outIndicesType == IndicesDataType::I32          ? 4
+                                                                   : 0;
+    FAISS_THROW_IF_NOT_MSG(label_size > 0, "unknown outIndicesType");
+    idx_t shard_size = args.queriesMemoryLimit /
+            (args.k * (distance_size + label_size) + args.dims * distance_size);
+    FAISS_THROW_IF_NOT_MSG(shard_size > 0, "queriesMemoryLimit is too low");
+    for (idx_t i = 0; i < args.numQueries; i += shard_size) {
+        GpuDistanceParams args_batch = args;
+        args_batch.numQueries = min(shard_size, args.numQueries - i);
+        args_batch.queries =
+                (char*)args.queries + distance_size * args.dims * i;
+        if (!args_batch.ignoreOutDistances) {
+            args_batch.outDistances = args.outDistances + args.k * i;
+        }
+        args_batch.outIndices =
+                (char*)args.outIndices + args.k * label_size * i;
+        bfKnn_single_query_shard(prov, args_batch);
+    }
+}
+
 // legacy version
 void bruteForceKnn(
         GpuResourcesProvider* res,
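
To make the sizing above concrete, here is a rough Python rendering of the tile-size arithmetic used in this file: a vector tile holds as many database rows as fit in vectorsMemoryLimit, while a query tile must also budget k distances and k labels of results per query. The element sizes below assume float32 data and int64 labels, and the numbers in the example are made up for illustration; the C++ code derives the actual sizes from vectorType/queryType and outIndicesType.

def vector_tile_rows(vectors_memory_limit, dims, elem_bytes=4):
    # Database rows whose raw data fits in the budget.
    return vectors_memory_limit // (dims * elem_bytes)

def query_tile_rows(queries_memory_limit, dims, k, elem_bytes=4, label_bytes=8):
    # Each query costs its raw data plus k (distance, label) result pairs.
    per_query = k * (elem_bytes + label_bytes) + dims * elem_bytes
    return queries_memory_limit // per_query

# Example: d=128 float32, k=10, int64 labels, 256 MiB / 64 MiB limits.
print(vector_tile_rows(256 << 20, dims=128))      # 524288 rows per vector tile
print(query_tile_rows(64 << 20, dims=128, k=10))  # 106184 queries per tile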

faiss/gpu/GpuDistance.h

+15 −2

@@ -46,7 +46,10 @@ struct GpuDistanceParams {
               ignoreOutDistances(false),
               outIndicesType(IndicesDataType::I64),
               outIndices(nullptr),
-              device(-1) {}
+              device(-1),
+              vectorsMemoryLimit(0),
+              queriesMemoryLimit(0),
+              use_raft(false) {}
 
     //
     // Search parameters
@@ -125,8 +128,18 @@ struct GpuDistanceParams {
     /// execution
     int device;
 
+    // Memory limits for vectors and queries.
+    // If not 0, the GPU will use at most this amount of memory
+    // for vectors and queries respectively.
+    // Vectors are broken up into chunks of size vectorsMemoryLimit,
+    // and queries are broken up into chunks of size queriesMemoryLimit,
+    // including the memory required for the results.
+    // Only supported for row major matrices.
+    uint64_t vectorsMemoryLimit;
+    uint64_t queriesMemoryLimit;
+
     /// Should the index dispatch down to RAFT?
-    bool use_raft = false;
+    bool use_raft;
 };
 
 /// A wrapper for gpu/impl/Distance.cuh to expose direct brute-force k-nearest

faiss/gpu/test/test_gpu_basics.py

+22 −1

@@ -225,6 +225,14 @@ def make_t(num, d, clamp=False, seed=None):
 
 class TestKnn(unittest.TestCase):
     def test_input_types(self):
+        self.do_test_input_types(0, 0)
+
+    def test_input_types_tiling(self):
+        self.do_test_input_types(0, 500)
+        self.do_test_input_types(1000, 0)
+        self.do_test_input_types(1000, 500)
+
+    def do_test_input_types(self, vectorsMemoryLimit, queriesMemoryLimit):
         d = 33
         k = 5
         nb = 1000
@@ -243,6 +251,8 @@ def test_input_types(self):
         out_d = np.empty((nq, k), dtype=np.float32)
         out_i = np.empty((nq, k), dtype=np.int64)
 
+        gpu_id = random.randrange(0, faiss.get_num_gpus())
+
         # Try f32 data/queries, i64 out indices
         params = faiss.GpuDistanceParams()
         params.k = k
@@ -253,19 +263,30 @@ def test_input_types(self):
         params.numQueries = nq
         params.outDistances = faiss.swig_ptr(out_d)
         params.outIndices = faiss.swig_ptr(out_i)
-        params.device = random.randrange(0, faiss.get_num_gpus())
+        params.device = gpu_id
+        params.vectorsMemoryLimit = vectorsMemoryLimit
+        params.queriesMemoryLimit = queriesMemoryLimit
 
         faiss.bfKnn(res, params)
 
         self.assertTrue(np.allclose(ref_d, out_d, atol=1e-5))
         self.assertGreaterEqual((out_i == ref_i).sum(), ref_i.size)
 
+        out_d, out_i = faiss.knn_gpu(
+            res, qs, xs, k, device=gpu_id,
+            vectorsMemoryLimit=vectorsMemoryLimit,
+            queriesMemoryLimit=queriesMemoryLimit)
+
+        self.assertTrue(np.allclose(ref_d, out_d, atol=1e-5))
+        self.assertGreaterEqual((out_i == ref_i).sum(), ref_i.size)
+
         # Try int32 out indices
         out_i32 = np.empty((nq, k), dtype=np.int32)
         params.outIndices = faiss.swig_ptr(out_i32)
         params.outIndicesType = faiss.IndicesDataType_I32
 
         faiss.bfKnn(res, params)
+
         self.assertEqual((out_i32 == ref_i).sum(), ref_i.size)
 
         # Try float16 data/queries, i64 out indices

faiss/python/gpu_wrappers.py

+11 −1

@@ -54,7 +54,7 @@ def index_cpu_to_gpus_list(index, co=None, gpus=None, ngpu=-1):
 # allows numpy ndarray usage with bfKnn
 
 
-def knn_gpu(res, xq, xb, k, D=None, I=None, metric=METRIC_L2, device=-1):
+def knn_gpu(res, xq, xb, k, D=None, I=None, metric=METRIC_L2, device=-1, vectorsMemoryLimit=0, queriesMemoryLimit=0):
     """
     Compute the k nearest neighbors of a vector on one GPU without constructing an index
 
@@ -82,6 +82,14 @@ def knn_gpu(res, xq, xb, k, D=None, I=None, metric=METRIC_L2, device=-1):
         (can also be set via torch.cuda.set_device in PyTorch)
         Otherwise, an integer 0 <= device < numDevices indicates the GPU on which
         the computation should be run
+    vectorsMemoryLimit: int, optional
+    queriesMemoryLimit: int, optional
+        Memory limits for vectors and queries.
+        If not 0, the GPU will use at most this amount of memory
+        for vectors and queries respectively.
+        Vectors are broken up into chunks of size vectorsMemoryLimit,
+        and queries are broken up into chunks of size queriesMemoryLimit,
+        including the memory required for the results.
 
     Returns
     -------
@@ -168,6 +176,8 @@ def knn_gpu(res, xq, xb, k, D=None, I=None, metric=METRIC_L2, device=-1):
     args.outIndices = I_ptr
     args.outIndicesType = I_type
     args.device = device
+    args.vectorsMemoryLimit = vectorsMemoryLimit
+    args.queriesMemoryLimit = queriesMemoryLimit
 
     # no stream synchronization needed, inputs and outputs are guaranteed to
     # be on the CPU (numpy arrays)
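
A usage sketch for the extended wrapper, assuming a GPU build of faiss and NumPy inputs; the array sizes and the 256 MiB / 64 MiB limits are arbitrary example values. With both limits left at their default of 0, the call behaves exactly as before (no tiling).

import numpy as np
import faiss  # GPU build

res = faiss.StandardGpuResources()
xb = np.random.rand(1_000_000, 128).astype('float32')  # database vectors
xq = np.random.rand(10_000, 128).astype('float32')     # query vectors

# Same results as an untiled call, but the database is scanned in
# <= 256 MiB tiles and the queries are processed in <= 64 MiB batches.
D, I = faiss.knn_gpu(
    res, xq, xb, k=10,
    vectorsMemoryLimit=256 * 1024 * 1024,
    queriesMemoryLimit=64 * 1024 * 1024)
print(D.shape, I.shape)  # (10000, 10) (10000, 10)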

faiss/utils/Heap.cpp

+2 −0

@@ -136,6 +136,8 @@ void HeapArray<C>::per_line_extrema(T* out_val, TI* out_ids) const {
 
 template struct HeapArray<CMin<float, int64_t>>;
 template struct HeapArray<CMax<float, int64_t>>;
+template struct HeapArray<CMin<float, int32_t>>;
+template struct HeapArray<CMax<float, int32_t>>;
 template struct HeapArray<CMin<int, int64_t>>;
 template struct HeapArray<CMax<int, int64_t>>;
 
