Skip to content

Commit e1f156e

Browse files
ramilbakhshyievabhinavdangeti
authored andcommitted
Throw when attempting to move IndexPQ to GPU (facebookresearch#3328)
Summary: Pull Request resolved: facebookresearch#3328 Reviewed By: junjieqi Differential Revision: D55476917 fbshipit-source-id: e7f64adefa07650fda32ad2300a1b933cedc9c79
1 parent 01abe5b commit e1f156e

File tree

3 files changed

+38
-0
lines changed

3 files changed

+38
-0
lines changed

faiss/gpu/GpuCloner.cpp

+2
Original file line numberDiff line numberDiff line change
@@ -224,6 +224,8 @@ faiss::Index* index_cpu_to_gpu(
224224
int device,
225225
const faiss::Index* index,
226226
const GpuClonerOptions* options) {
227+
auto index_pq = dynamic_cast<const faiss::IndexPQ*>(index);
228+
FAISS_THROW_IF_MSG(index_pq, "This index type is not implemented on GPU.");
227229
GpuClonerOptions defaults;
228230
ToGpuCloner cl(provider, device, options ? *options : defaults);
229231
return cl.clone_Index(index);
+29
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
import numpy as np
2+
import unittest
3+
import faiss
4+
5+
6+
class TestMoveToGpu(unittest.TestCase):
7+
def test_index_cpu_to_gpu(self):
8+
dimension = 128
9+
n = 2500
10+
db_vectors = np.random.random((n, dimension)).astype('float32')
11+
code_size = 16
12+
res = faiss.StandardGpuResources()
13+
index_pq = faiss.IndexPQ(dimension, code_size, 6)
14+
index_pq.train(db_vectors)
15+
index_pq.add(db_vectors)
16+
self.assertRaisesRegex(Exception, ".*not implemented.*",
17+
faiss.index_cpu_to_gpu, res, 0, index_pq)
18+
19+
def test_index_cpu_to_gpu_does_not_throw_with_index_flat(self):
20+
dimension = 128
21+
n = 100
22+
db_vectors = np.random.random((n, dimension)).astype('float32')
23+
res = faiss.StandardGpuResources()
24+
index_flat = faiss.IndexFlatL2(dimension)
25+
index_flat.add(db_vectors)
26+
try:
27+
faiss.index_cpu_to_gpu(res, 0, index_flat)
28+
except Exception:
29+
self.fail("index_cpu_to_gpu() threw an unexpected exception.")

faiss/impl/FaissAssert.h

+7
Original file line numberDiff line numberDiff line change
@@ -94,6 +94,13 @@
9494
} \
9595
} while (false)
9696

97+
#define FAISS_THROW_IF_MSG(X, MSG) \
98+
do { \
99+
if (X) { \
100+
FAISS_THROW_FMT("Error: '%s' failed: " MSG, #X); \
101+
} \
102+
} while (false)
103+
97104
#define FAISS_THROW_IF_NOT_MSG(X, MSG) \
98105
do { \
99106
if (!(X)) { \

0 commit comments

Comments
 (0)