Skip to content

Commit 7d21c92

Browse files
mlomeli1 authored and facebook-github-bot committed
Dim reduction support in OIVFBBS (#3290)
Summary: This PR adds support for dimensionality reduction in OIVFBBS. I tested the code with the index `OPQ64_128,IVF4096,PQ64` on the ssnpp embeddings; this index string is added to config_ssnpp.yaml to showcase the functionality.
Pull Request resolved: #3290
Reviewed By: junjieqi
Differential Revision: D54878345
Pulled By: mlomeli1
fbshipit-source-id: 98ecdeb2224ce0325e37720cc113d82f9c6c75d6
1 parent d5e4c79 commit 7d21c92

File tree

2 files changed

+21
-10
lines changed

2 files changed

+21
-10
lines changed

demos/offline_ivf/config_ssnpp.yaml

+1
Original file line number | Diff line number | Diff line change
@@ -6,6 +6,7 @@ index:
66
non-prod:
77
- 'IVF16384,PQ128'
88
- 'IVF32768,PQ128'
9+
- 'OPQ64_128,IVF4096,PQ64'
910
nprobe:
1011
prod:
1112
- 512

demos/offline_ivf/offline_ivf.py

+20-10
Original file line number | Diff line number | Diff line change
@@ -178,7 +178,7 @@ def dedupe(self):
178178
idxs.append(np.empty((0,), dtype=np.uint32))
179179
bs = 1_000_000
180180
i = 0
181-
for buffer in tqdm(self.xb_ds.iterate(0, bs, np.float32)):
181+
for buffer in tqdm(self._iterate_transformed(self.xb_ds, 0, bs, np.float32)):
182182
for j in range(len(codecs)):
183183
codec, codeset, idx = codecs[j], codesets[j], idxs[j]
184184
uniq = codeset.insert(codec.sa_encode(buffer))
@@ -267,11 +267,18 @@ def index_shard(self):
267267
),
268268
file=sys.stdout,
269269
):
270-
assert xb_j.shape[1] == index.d
271-
index.add_with_ids(
272-
xb_j,
273-
np.arange(start + jj, start + jj + xb_j.shape[0]),
274-
)
270+
if is_pretransform_index(index):
271+
assert xb_j.shape[1] == index.chain.at(0).d_out
272+
index_ivf.add_with_ids(
273+
xb_j,
274+
np.arange(start + jj, start + jj + xb_j.shape[0]),
275+
)
276+
else:
277+
assert xb_j.shape[1] == index.d
278+
index.add_with_ids(
279+
xb_j,
280+
np.arange(start + jj, start + jj + xb_j.shape[0]),
281+
)
275282
jj += xb_j.shape[0]
276283
logging.info(jj)
277284
assert (
@@ -670,10 +677,14 @@ def search(self):
670677
os.remove(Ifn)
671678
os.remove(Dfn)
672679

673-
try: # TODO: modify shape for pretransform case
680+
try:
681+
if is_pretransform_index(index):
682+
d = index.chain.at(0).d_out
683+
else:
684+
d = self.input_d
674685
with open(Ifn, "xb") as f, open(Dfn, "xb") as g:
675686
xq_i = np.empty(
676-
shape=(self.xq_bs, self.input_d), dtype=np.float16
687+
shape=(self.xq_bs, d), dtype=np.float16
677688
)
678689
q_assign = np.empty(
679690
(self.xq_bs, self.nprobe), dtype=np.int32
@@ -835,8 +846,7 @@ def consistency_check(self):
835846
for j in range(SMALL_DATA_SAMPLE):
836847
assert np.where(I[j] == j + r)[0].size > 0, (
837848
f"I[j]: {I[j]}, j: {j}, i: {i}, shard_size:"
838-
f" {self.shard_size}"
839-
)
849+
f" {self.shard_size}")
840850

841851
logging.info("search results...")
842852
index_ivf.nprobe = self.nprobe

0 commit comments

Comments
 (0)