Skip to content

Commit 7b96c39

Browse files
kuarorafacebook-github-bot
authored andcommitted
Support for Remove ids from IVFPQFastScan index (#3354)
Summary: **This change was previously reverted because of a build failure, as change D55577576 removed the definition of FAISS_THROW_IF_MSG.** **Context** [Issue 3128](#3128) is an enhancement request to support remove_ids for IVFPQFastScan. The existing mechanism uses the direct map, iterates over the items in the selector, and uses scopecodes and scopeIds to replace the item to be removed. Because the codes are packed, it is hard to return a single code given how it is packed in CodePackerPQ4. Thus, we need a custom implementation of remove_ids. **In this diff**, 1. We have added a custom implementation of remove_ids to BlockInvertedLists, which unpacks each code as it iterates and repacks it in its new position. DirectMap uses this remove_ids function in BlockInvertedLists for the NoMap type. 2. Also, we throw an exception for the other map type in DirectMap, i.e. Hashtable. Reviewed By: ramilbakhshyiev Differential Revision: D55858959
1 parent 366a814 commit 7b96c39

File tree

5 files changed

+81
-22
lines changed

5 files changed

+81
-22
lines changed

faiss/impl/FaissAssert.h

+4-2
Original file line numberDiff line numberDiff line change
@@ -94,13 +94,15 @@
9494
} \
9595
} while (false)
9696

97-
#define FAISS_THROW_IF_NOT_MSG(X, MSG) \
97+
#define FAISS_THROW_IF_MSG(X, MSG) \
9898
do { \
99-
if (!(X)) { \
99+
if (X) { \
100100
FAISS_THROW_FMT("Error: '%s' failed: " MSG, #X); \
101101
} \
102102
} while (false)
103103

104+
/// Invokes FAISS_THROW_FMT when condition X evaluates to false, appending MSG
/// to the error text.
///
/// Defined directly rather than as FAISS_THROW_IF_MSG(!(X), MSG): with the
/// delegating form the stringified condition in the message becomes "!(X)",
/// so the report would name the negated expression instead of the condition
/// the caller actually required.
#define FAISS_THROW_IF_NOT_MSG(X, MSG)                       \
    do {                                                     \
        if (!(X)) {                                          \
            FAISS_THROW_FMT("Error: '%s' failed: " MSG, #X); \
        }                                                    \
    } while (false)
105+
104106
#define FAISS_THROW_IF_NOT_FMT(X, FMT, ...) \
105107
do { \
106108
if (!(X)) { \

faiss/invlists/BlockInvertedLists.cpp

+27-7
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99

1010
#include <faiss/impl/CodePacker.h>
1111
#include <faiss/impl/FaissAssert.h>
12+
#include <faiss/impl/IDSelector.h>
1213

1314
#include <faiss/impl/io.h>
1415
#include <faiss/impl/io_macros.h>
@@ -54,7 +55,9 @@ size_t BlockInvertedLists::add_entries(
5455
codes[list_no].resize(n_block * block_size);
5556
if (o % block_size == 0) {
5657
// copy whole blocks
57-
memcpy(&codes[list_no][o * code_size], code, n_block * block_size);
58+
memcpy(&codes[list_no][o * packer->code_size],
59+
code,
60+
n_block * block_size);
5861
} else {
5962
FAISS_THROW_IF_NOT_MSG(packer, "missing code packer");
6063
std::vector<uint8_t> buffer(packer->code_size);
@@ -76,6 +79,29 @@ const uint8_t* BlockInvertedLists::get_codes(size_t list_no) const {
7679
return codes[list_no].get();
7780
}
7881

82+
size_t BlockInvertedLists::remove_ids(const IDSelector& sel) {
83+
idx_t nremove = 0;
84+
#pragma omp parallel for
85+
for (idx_t i = 0; i < nlist; i++) {
86+
std::vector<uint8_t> buffer(packer->code_size);
87+
idx_t l = ids[i].size(), j = 0;
88+
while (j < l) {
89+
if (sel.is_member(ids[i][j])) {
90+
l--;
91+
ids[i][j] = ids[i][l];
92+
packer->unpack_1(codes[i].data(), l, buffer.data());
93+
packer->pack_1(buffer.data(), j, codes[i].data());
94+
} else {
95+
j++;
96+
}
97+
}
98+
resize(i, l);
99+
nremove += ids[i].size() - l;
100+
}
101+
102+
return nremove;
103+
}
104+
79105
const idx_t* BlockInvertedLists::get_ids(size_t list_no) const {
80106
assert(list_no < nlist);
81107
return ids[list_no].data();
@@ -102,12 +128,6 @@ void BlockInvertedLists::update_entries(
102128
const idx_t*,
103129
const uint8_t*) {
104130
FAISS_THROW_MSG("not impemented");
105-
/*
106-
assert (list_no < nlist);
107-
assert (n_entry + offset <= ids[list_no].size());
108-
memcpy (&ids[list_no][offset], ids_in, sizeof(ids_in[0]) * n_entry);
109-
memcpy (&codes[list_no][offset * code_size], codes_in, code_size * n_entry);
110-
*/
111131
}
112132

113133
BlockInvertedLists::~BlockInvertedLists() {

faiss/invlists/BlockInvertedLists.h

+3
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
namespace faiss {
1616

1717
struct CodePacker;
18+
struct IDSelector;
1819

1920
/** Inverted Lists that are organized by blocks.
2021
*
@@ -47,6 +48,8 @@ struct BlockInvertedLists : InvertedLists {
4748
size_t list_size(size_t list_no) const override;
4849
const uint8_t* get_codes(size_t list_no) const override;
4950
const idx_t* get_ids(size_t list_no) const override;
51+
/// remove ids from the InvertedLists
52+
size_t remove_ids(const IDSelector& sel);
5053

5154
// works only on empty BlockInvertedLists
5255
// the codes should be of size ceil(n_entry / n_per_block) * block_size

faiss/invlists/DirectMap.cpp

+9-1
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
#include <faiss/impl/AuxIndexStructures.h>
1616
#include <faiss/impl/FaissAssert.h>
1717
#include <faiss/impl/IDSelector.h>
18+
#include <faiss/invlists/BlockInvertedLists.h>
1819

1920
namespace faiss {
2021

@@ -148,8 +149,12 @@ size_t DirectMap::remove_ids(const IDSelector& sel, InvertedLists* invlists) {
148149
std::vector<idx_t> toremove(nlist);
149150

150151
size_t nremove = 0;
151-
152+
BlockInvertedLists* block_invlists =
153+
dynamic_cast<BlockInvertedLists*>(invlists);
152154
if (type == NoMap) {
155+
if (block_invlists != nullptr) {
156+
return block_invlists->remove_ids(sel);
157+
}
153158
// exhaustive scan of IVF
154159
#pragma omp parallel for
155160
for (idx_t i = 0; i < nlist; i++) {
@@ -178,6 +183,9 @@ size_t DirectMap::remove_ids(const IDSelector& sel, InvertedLists* invlists) {
178183
}
179184
}
180185
} else if (type == Hashtable) {
186+
FAISS_THROW_IF_MSG(
187+
block_invlists,
188+
"remove with hashtable is not supported with BlockInvertedLists");
181189
const IDSelectorArray* sela =
182190
dynamic_cast<const IDSelectorArray*>(&sel);
183191
FAISS_THROW_IF_NOT_MSG(

tests/test_merge_index.py

+38-12
Original file line numberDiff line numberDiff line change
@@ -246,19 +246,45 @@ def test_merge_IDMap2(self):
246246

247247
class TestRemoveFastScan(unittest.TestCase):
248248

249-
def do_fast_scan_test(self, factory_key, size1):
249+
def do_fast_scan_test(self,
250+
factory_key,
251+
with_ids=False,
252+
direct_map_type=faiss.DirectMap.NoMap):
250253
ds = SyntheticDataset(110, 1000, 1000, 100)
251-
index1 = faiss.index_factory(ds.d, factory_key)
252-
index1.train(ds.get_train())
253-
index1.reset()
254+
index = faiss.index_factory(ds.d, factory_key)
255+
index.train(ds.get_train())
256+
257+
index.reset()
254258
tokeep = [i % 3 == 0 for i in range(ds.nb)]
255-
index1.add(ds.get_database()[tokeep])
256-
_, Iref = index1.search(ds.get_queries(), 5)
257-
index1.reset()
258-
index1.add(ds.get_database())
259-
index1.remove_ids(np.where(np.logical_not(tokeep))[0])
260-
_, Inew = index1.search(ds.get_queries(), 5)
259+
if with_ids:
260+
index.add_with_ids(ds.get_database()[tokeep], np.arange(ds.nb)[tokeep])
261+
faiss.extract_index_ivf(index).nprobe = 5
262+
else:
263+
index.add(ds.get_database()[tokeep])
264+
_, Iref = index.search(ds.get_queries(), 5)
265+
266+
index.reset()
267+
if with_ids:
268+
index.add_with_ids(ds.get_database(), np.arange(ds.nb))
269+
index.set_direct_map_type(direct_map_type)
270+
faiss.extract_index_ivf(index).nprobe = 5
271+
else:
272+
index.add(ds.get_database())
273+
index.remove_ids(np.where(np.logical_not(tokeep))[0])
274+
_, Inew = index.search(ds.get_queries(), 5)
261275
np.testing.assert_array_equal(Inew, Iref)
262276

263-
def test_remove(self):
264-
self.do_fast_scan_test("PQ5x4fs", 320)
277+
def test_remove_PQFastScan(self):
278+
# with_ids is not supported for this type of index
279+
self.do_fast_scan_test("PQ5x4fs", False)
280+
281+
def test_remove_IVFPQFastScan(self):
282+
self.do_fast_scan_test("IVF20,PQ5x4fs", True)
283+
284+
def test_remove_IVFPQFastScan_2(self):
285+
self.assertRaisesRegex(Exception,
286+
".*not supported.*",
287+
self.do_fast_scan_test,
288+
"IVF20,PQ5x4fs",
289+
True,
290+
faiss.DirectMap.Hashtable)

0 commit comments

Comments
 (0)