Skip to content
This repository was archived by the owner on Oct 31, 2023. It is now read-only.

Commit d284f40

Browse files
author
Sewon Min
committed
improve speed when using union of corpora
1 parent 91dec78 commit d284f40

File tree

2 files changed

+10
-4
lines changed

2 files changed

+10
-4
lines changed

npm/dstore.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -547,7 +547,7 @@ def get_embs(self, indices):
547547
assert len(indices)==len(self.dstores)
548548
embs = []
549549
for dstore, _indices in zip(self.dstores, indices):
550-
embs.append(np.array(dstore.get_embs(_indices).tolist()))
551-
return np.concatenate(embs, -2)
550+
embs.append(dstore.get_embs(_indices))
551+
return embs
552552

553553

npm/npm_single.py

+8-2
Original file line numberDiff line numberDiff line change
@@ -34,9 +34,15 @@ def decode(self, ids):
3434

3535
def get_scores(self, queries, x):
3636
if type(queries)==np.ndarray:
37-
all_scores = np.inner(queries, x).squeeze(1) / np.sqrt(self.dstore.dimension)
37+
if type(x)==list:
38+
all_scores = np.concatenate([self.get_scores(queries, xi) for xi in x], -1)
39+
else:
40+
all_scores = np.inner(queries, x).squeeze(1) / np.sqrt(self.dstore.dimension)
3841
else:
39-
all_scores = torch.inner(queries, x).squeeze(1) / np.sqrt(self.dstore.dimension)
42+
if type(x)==list:
43+
all_scores = torch.cat([self.get_scores(queries, xi) for xi in x], -1)
44+
else:
45+
all_scores = torch.inner(queries, x).squeeze(1) / np.sqrt(self.dstore.dimension)
4046
return all_scores
4147

4248
def get_all_scores(self, queries):

0 commit comments

Comments
 (0)