
Commit 86de47a

junjieqi authored and abhinavdangeti committed
Implement reconstruct_n for GPU IVFFlat indexes (facebookresearch#3338)
Summary:
Pull Request resolved: facebookresearch#3338

add reconstruct_n for GPU IVFFlat

Reviewed By: mdouze

Differential Revision: D55577561

fbshipit-source-id: 47f8b939611e2df7dbcd087129538145f627293c
1 parent 07a189e commit 86de47a

11 files changed: +216 -3 lines changed

faiss/gpu/GpuIndexIVFFlat.cu (+22)

@@ -356,5 +356,27 @@ void GpuIndexIVFFlat::setIndex_(
     }
 }
 
+void GpuIndexIVFFlat::reconstruct_n(idx_t i0, idx_t ni, float* out) const {
+    FAISS_ASSERT(index_);
+
+    if (ni == 0) {
+        // nothing to do
+        return;
+    }
+
+    FAISS_THROW_IF_NOT_FMT(
+            i0 < this->ntotal,
+            "start index (%zu) out of bounds (ntotal %zu)",
+            i0,
+            this->ntotal);
+    FAISS_THROW_IF_NOT_FMT(
+            i0 + ni - 1 < this->ntotal,
+            "max index requested (%zu) out of bounds (ntotal %zu)",
+            i0 + ni - 1,
+            this->ntotal);
+
+    index_->reconstruct_n(i0, ni, out);
+}
+
 } // namespace gpu
 } // namespace faiss
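For context (illustration only, not part of the diff): with this change, reconstruct_n on a GpuIndexIVFFlat behaves like its CPU counterpart. A minimal Python sketch of the user-facing call, modeled on the new Python test added further below; it assumes a CUDA-enabled faiss build with a visible GPU:

import numpy as np
import faiss  # assumes a GPU-enabled faiss build

d, nb = 32, 1000
xb = np.random.RandomState(123).rand(nb, d).astype('float32')

# Build and populate a CPU IVFFlat index, then copy it to the GPU.
cpu_index = faiss.index_factory(d, "IVF10,Flat")
cpu_index.train(xb)
cpu_index.add(xb)

res = faiss.StandardGpuResources()
config = faiss.GpuIndexIVFFlatConfig()
config.use_raft = False  # as in the tests added by this commit
gpu_index = faiss.GpuIndexIVFFlat(res, cpu_index, config)

# Reconstruct the first 100 stored vectors directly from the GPU index.
recons = gpu_index.reconstruct_n(0, 100)
assert recons.shape == (100, d)
np.testing.assert_array_equal(recons, xb[:100])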

faiss/gpu/GpuIndexIVFFlat.h (+2)

@@ -87,6 +87,8 @@ class GpuIndexIVFFlat : public GpuIndexIVF {
     /// Trains the coarse quantizer based on the given vector data
     void train(idx_t n, const float* x) override;
 
+    void reconstruct_n(idx_t i0, idx_t n, float* out) const override;
+
    protected:
     /// Initialize appropriate index
     void setIndex_(

faiss/gpu/impl/IVFBase.cu (+4)

@@ -340,6 +340,10 @@ void IVFBase::copyInvertedListsTo(InvertedLists* ivf) {
     }
 }
 
+void IVFBase::reconstruct_n(idx_t i0, idx_t n, float* out) {
+    FAISS_THROW_MSG("not implemented");
+}
+
 void IVFBase::addEncodedVectorsToList_(
         idx_t listId,
         const void* codes,

faiss/gpu/impl/IVFBase.cuh (+11 -2)

@@ -109,9 +109,18 @@ class IVFBase {
             Tensor<idx_t, 2, true>& outIndices,
             bool storePairs) = 0;
 
+    /* Reconstructs a given number of vectors from an Inverted File (IVF)
+     * index.
+     * @param i0  index of the first vector to reconstruct
+     * @param n   number of vectors to reconstruct
+     * @param out pointer to the buffer where the reconstructed vectors
+     *            will be stored
+     */
+    virtual void reconstruct_n(idx_t i0, idx_t n, float* out);
+
    protected:
-    /// Adds a set of codes and indices to a list, with the representation
-    /// coming from the CPU equivalent
+    /// Adds a set of codes and indices to a list, with the
+    /// representation coming from the CPU equivalent
     virtual void addEncodedVectorsToList_(
             idx_t listId,
             // resident on the host

faiss/gpu/impl/IVFFlat.cu (+47)

@@ -283,6 +283,53 @@ void IVFFlat::searchPreassigned(
             storePairs);
 }
 
+void IVFFlat::reconstruct_n(idx_t i0, idx_t ni, float* out) {
+    if (ni == 0) {
+        // nothing to do
+        return;
+    }
+
+    auto stream = resources_->getDefaultStreamCurrentDevice();
+
+    for (idx_t list_no = 0; list_no < numLists_; list_no++) {
+        size_t list_size = deviceListData_[list_no]->numVecs;
+
+        auto idlist = getListIndices(list_no);
+
+        for (idx_t offset = 0; offset < list_size; offset++) {
+            idx_t id = idlist[offset];
+            if (!(id >= i0 && id < i0 + ni)) {
+                continue;
+            }
+
+            // vector data in the non-interleaved format is laid out like:
+            // v0d0 v0d1 ... v0d(dim-1) v1d0 v1d1 ... v1d(dim-1)
+
+            // vector data in the interleaved format is laid out like:
+            // (v0d0 v1d0 ... v31d0) (v0d1 v1d1 ... v31d1)
+            // (v0d(dim - 1) ... v31d(dim-1))
+            // (v32d0 v33d0 ... v63d0) (... v63d(dim-1)) (v64d0 ...)
+
+            // where vectors are chunked into groups of 32, and each dimension
+            // for each of the 32 vectors is contiguous
+
+            auto vectorChunk = offset / 32;
+            auto vectorWithinChunk = offset % 32;
+
+            auto listDataPtr = (float*)deviceListData_[list_no]->data.data();
+            listDataPtr += vectorChunk * 32 * dim_ + vectorWithinChunk;
+
+            for (int d = 0; d < dim_; ++d) {
+                fromDevice<float>(
+                        listDataPtr + 32 * d,
+                        out + (id - i0) * dim_ + d,
+                        1,
+                        stream);
+            }
+        }
+    }
+}
+
 void IVFFlat::searchImpl_(
         Tensor<float, 2, true>& queries,
         Tensor<float, 2, true>& coarseDistances,
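As an aside (illustration only, not part of the commit): the index arithmetic in the reconstruct_n loop above can be checked in isolation. A small Python sketch, assuming a hypothetical flat float32 array `interleaved` standing in for one list's device buffer:

import numpy as np

def reconstruct_from_interleaved(interleaved, offset, dim):
    # Illustration only: pick vector number `offset` out of a buffer laid
    # out in the 32-wide interleaved format described in the diff above.
    chunk = offset // 32              # which group of 32 vectors
    within = offset % 32              # slot of the vector inside the group
    base = chunk * 32 * dim + within  # mirrors the listDataPtr arithmetic
    # dimension d of this vector lives at base + 32 * d
    return interleaved[base + 32 * np.arange(dim)]

# Round-trip check: interleave 64 vectors of dimension 8, then recover one.
dim, nvec = 8, 64
vecs = np.random.rand(nvec, dim).astype('float32')
interleaved = (
    vecs.reshape(nvec // 32, 32, dim)  # chunks of 32 vectors
        .transpose(0, 2, 1)            # (chunk, dim, vector-within-chunk)
        .reshape(-1)
)
assert np.array_equal(reconstruct_from_interleaved(interleaved, 37, dim), vecs[37])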

faiss/gpu/impl/IVFFlat.cuh (+2)

@@ -51,6 +51,8 @@ class IVFFlat : public IVFBase {
             Tensor<idx_t, 2, true>& outIndices,
             bool storePairs) override;
 
+    void reconstruct_n(idx_t i0, idx_t n, float* out) override;
+
    protected:
     /// Returns the number of bytes in which an IVF list containing numVecs
     /// vectors is encoded on the device. Note that due to padding this is not

faiss/gpu/test/TestGpuIndexIVFFlat.cpp (+65)

@@ -842,6 +842,71 @@ TEST(TestGpuIndexIVFFlat, LongIVFList) {
 #endif
 }
 
+TEST(TestGpuIndexIVFFlat, Reconstruct_n) {
+    Options opt;
+
+    std::vector<float> trainVecs = faiss::gpu::randVecs(opt.numTrain, opt.dim);
+    std::vector<float> addVecs = faiss::gpu::randVecs(opt.numAdd, opt.dim);
+
+    faiss::IndexFlatL2 cpuQuantizer(opt.dim);
+    faiss::IndexIVFFlat cpuIndex(
+            &cpuQuantizer, opt.dim, opt.numCentroids, faiss::METRIC_L2);
+    cpuIndex.nprobe = opt.nprobe;
+    cpuIndex.train(opt.numTrain, trainVecs.data());
+    cpuIndex.add(opt.numAdd, addVecs.data());
+
+    faiss::gpu::StandardGpuResources res;
+    res.noTempMemory();
+
+    faiss::gpu::GpuIndexIVFFlatConfig config;
+    config.device = opt.device;
+    config.indicesOptions = faiss::gpu::INDICES_64_BIT;
+    config.use_raft = false;
+
+    faiss::gpu::GpuIndexIVFFlat gpuIndex(
+            &res, opt.dim, opt.numCentroids, faiss::METRIC_L2, config);
+    gpuIndex.nprobe = opt.nprobe;
+
+    gpuIndex.train(opt.numTrain, trainVecs.data());
+    gpuIndex.add(opt.numAdd, addVecs.data());
+
+    std::vector<float> gpuVals(opt.numAdd * opt.dim);
+
+    gpuIndex.reconstruct_n(0, gpuIndex.ntotal, gpuVals.data());
+
+    std::vector<float> cpuVals(opt.numAdd * opt.dim);
+
+    cpuIndex.reconstruct_n(0, cpuIndex.ntotal, cpuVals.data());
+
+    EXPECT_EQ(gpuVals, cpuVals);
+
+    config.indicesOptions = faiss::gpu::INDICES_32_BIT;
+
+    faiss::gpu::GpuIndexIVFFlat gpuIndex1(
+            &res, opt.dim, opt.numCentroids, faiss::METRIC_L2, config);
+    gpuIndex1.nprobe = opt.nprobe;
+
+    gpuIndex1.train(opt.numTrain, trainVecs.data());
+    gpuIndex1.add(opt.numAdd, addVecs.data());
+
+    gpuIndex1.reconstruct_n(0, gpuIndex1.ntotal, gpuVals.data());
+
+    EXPECT_EQ(gpuVals, cpuVals);
+
+    config.indicesOptions = faiss::gpu::INDICES_CPU;
+
+    faiss::gpu::GpuIndexIVFFlat gpuIndex2(
+            &res, opt.dim, opt.numCentroids, faiss::METRIC_L2, config);
+    gpuIndex2.nprobe = opt.nprobe;
+
+    gpuIndex2.train(opt.numTrain, trainVecs.data());
+    gpuIndex2.add(opt.numAdd, addVecs.data());
+
+    gpuIndex2.reconstruct_n(0, gpuIndex2.ntotal, gpuVals.data());
+
+    EXPECT_EQ(gpuVals, cpuVals);
+}
+
 int main(int argc, char** argv) {
     testing::InitGoogleTest(&argc, argv);

faiss/gpu/test/test_gpu_basics.py (+1)

@@ -11,6 +11,7 @@
 import random
 from common_faiss_tests import get_dataset_2
 
+
 class ReferencedObject(unittest.TestCase):
 
     d = 16
New Python test file (+25)

@@ -0,0 +1,25 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import unittest
+
+import faiss
+import numpy as np
+
+
+class TestGpuIndexIvfflat(unittest.TestCase):
+    def test_reconstruct_n(self):
+        index = faiss.index_factory(4, "IVF10,Flat")
+        x = np.random.RandomState(123).rand(10, 4).astype('float32')
+        index.train(x)
+        index.add(x)
+        res = faiss.StandardGpuResources()
+        res.noTempMemory()
+        config = faiss.GpuIndexIVFFlatConfig()
+        config.use_raft = False
+        index2 = faiss.GpuIndexIVFFlat(res, index, config)
+        recons = index2.reconstruct_n(0, 10)
+
+        np.testing.assert_array_equal(recons, x)

faiss/gpu/test/torch_test_contrib_gpu.py (+35 -1)

@@ -108,7 +108,7 @@ def test_train_add_with_ids(self):
         self.assertTrue(np.array_equal(I.reshape(10), ids_np[10:20]))
 
     # tests reconstruct, reconstruct_n
-    def test_reconstruct(self):
+    def test_flat_reconstruct(self):
         d = 32
         res = faiss.StandardGpuResources()
         res.noTempMemory()

@@ -157,6 +157,40 @@ def test_reconstruct(self):
         index.reconstruct_n(50, 10, y)
         self.assertTrue(torch.equal(xb[50:60], y))
 
+    def test_ivfflat_reconstruct(self):
+        d = 32
+        nlist = 5
+        res = faiss.StandardGpuResources()
+        res.noTempMemory()
+        config = faiss.GpuIndexIVFFlatConfig()
+        config.use_raft = False
+
+        index = faiss.GpuIndexIVFFlat(res, d, nlist, faiss.METRIC_L2, config)
+
+        xb = torch.rand(100, d, device=torch.device('cuda', 0), dtype=torch.float32)
+        index.train(xb)
+        index.add(xb)
+
+        # Test reconstruct_n with torch gpu (native return)
+        y = index.reconstruct_n(10, 10)
+        self.assertTrue(y.is_cuda)
+        self.assertTrue(torch.equal(xb[10:20], y))
+
+        # Test reconstruct with numpy output provided
+        y = np.empty((10, d), dtype='float32')
+        index.reconstruct_n(20, 10, y)
+        self.assertTrue(np.array_equal(xb.cpu().numpy()[20:30], y))
+
+        # Test reconstruct_n with torch cpu output provided
+        y = torch.empty(10, d, dtype=torch.float32)
+        index.reconstruct_n(40, 10, y)
+        self.assertTrue(torch.equal(xb[40:50].cpu(), y))
+
+        # Test reconstruct_n with torch gpu output provided
+        y = torch.empty(10, d, device=torch.device('cuda', 0), dtype=torch.float32)
+        index.reconstruct_n(50, 10, y)
+        self.assertTrue(torch.equal(xb[50:60], y))
+
     # tests assign
     def test_assign(self):
         d = 32

faiss/gpu/utils/DeviceVector.cuh (+2)

@@ -169,6 +169,8 @@ class DeviceVector {
         T out;
         CUDA_VERIFY(cudaMemcpyAsync(
                 &out, data() + idx, sizeof(T), cudaMemcpyDeviceToHost, stream));
+
+        return out;
     }
 
     // Clean up after oversized allocations, while leaving some space to
