This repository was archived by the owner on Oct 31, 2023. It is now read-only.
File tree 2 files changed +10
-4
lines changed
2 files changed +10
-4
lines changed Original file line number Diff line number Diff line change @@ -547,7 +547,7 @@ def get_embs(self, indices):
547
547
assert len (indices )== len (self .dstores )
548
548
embs = []
549
549
for dstore , _indices in zip (self .dstores , indices ):
550
- embs .append (np . array ( dstore .get_embs (_indices ). tolist () ))
551
- return np . concatenate ( embs , - 2 )
550
+ embs .append (dstore .get_embs (_indices ))
551
+ return embs
552
552
553
553
Original file line number Diff line number Diff line change @@ -34,9 +34,15 @@ def decode(self, ids):
34
34
35
35
def get_scores (self , queries , x ):
36
36
if type (queries )== np .ndarray :
37
- all_scores = np .inner (queries , x ).squeeze (1 ) / np .sqrt (self .dstore .dimension )
37
+ if type (x )== list :
38
+ all_scores = np .concatenate ([self .get_scores (queries , xi ) for xi in x ], - 1 )
39
+ else :
40
+ all_scores = np .inner (queries , x ).squeeze (1 ) / np .sqrt (self .dstore .dimension )
38
41
else :
39
- all_scores = torch .inner (queries , x ).squeeze (1 ) / np .sqrt (self .dstore .dimension )
42
+ if type (x )== list :
43
+ all_scores = torch .cat ([self .get_scores (queries , xi ) for xi in x ], - 1 )
44
+ else :
45
+ all_scores = torch .inner (queries , x ).squeeze (1 ) / np .sqrt (self .dstore .dimension )
40
46
return all_scores
41
47
42
48
def get_all_scores (self , queries ):
You can’t perform that action at this time.
0 commit comments