
Commit 07c9893

✨ Add regen wikification (#44)
* ✨ Add regen wikification
* 🐛 Fix Wikification
* 🐛 Reduce tests complexity
* 🐛 Reduce test resources
* 🐛 Fix test
* ➖ Remove test file
* ✏️ Remove too expensive test
1 parent b2c2a27 commit 07c9893

File tree

4 files changed, +111 -15 lines changed


zshot/linker/linker_regen/linker_regen.py

+13-9
@@ -16,33 +16,37 @@
 
 class LinkerRegen(Linker):
     """ REGEN linker """
-    def __init__(self, max_input_len=384, max_output_len=15, num_beams=10):
+    def __init__(self, max_input_len=384, max_output_len=15, num_beams=10, trie=None):
         """
         :param max_input_len: Max length of input
         :param max_output_len: Max length of output
         :param num_beams: Number of beams to use
+        :param trie: If the trie is given, the linker will use it to restrict the search space.
+            Custom entities won't be used if the trie is given.
         """
         super().__init__()
         self.model = None
         self.tokenizer = None
-        self.trie = None
         self.max_input_len = max_input_len
         self.max_output_len = max_output_len
         self.num_beams = num_beams
+        self.skip_set_kg = False if trie is None else True
+        self.trie = trie
 
     def set_kg(self, entities: Iterator[Entity]):
         """ Set new entities
 
         :param entities: New entities to use
         """
         super().set_kg(entities)
-        self.load_tokenizer()
-        self.trie = Trie(
-            [
-                self.tokenizer(e.name, return_tensors="pt")['input_ids'][0].tolist()
-                for e in entities
-            ]
-        )
+        if not self.skip_set_kg:
+            self.load_tokenizer()
+            self.trie = Trie(
+                [
+                    self.tokenizer(e.name, return_tensors="pt")['input_ids'][0].tolist()
+                    for e in entities
+                ]
+            )
 
     def load_models(self):
         """ Load Model """

zshot/linker/linker_regen/trie.py

+4-4
@@ -1,20 +1,20 @@
-from typing import List
+from typing import Collection
 
 
 class Trie(object):
-    def __init__(self, sequences: List[List[int]] = []):
+    def __init__(self, sequences: Collection[Collection[int]] = []):
         self.trie_dict = {}
         for sequence in sequences:
             self.add(sequence)
 
-    def add(self, sequence: List[int]):
+    def add(self, sequence: Collection[int]):
         trie = self.trie_dict
         for idx in sequence:
             if idx not in trie:
                 trie[idx] = {}
             trie = trie[idx]
 
-    def postfix(self, prefix_sequence: List[int]):
+    def postfix(self, prefix_sequence: Collection[int]):
         if len(prefix_sequence) == 1:
             return list(self.trie_dict.keys())
         trie = self.trie_dict
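
A quick illustration (not in the commit) of the Trie structure these hunks touch. It relies only on behavior visible above; the tail of postfix() falls outside this hunk, so only the length-1 prefix case is shown.

from zshot.linker.linker_regen.trie import Trie

trie = Trie([[794, 536, 1], [794, 357, 1]])

# add() builds nested dicts, one level per token id:
print(trie.trie_dict)  # {794: {536: {1: {}}, 357: {1: {}}}}

# With a length-1 prefix, postfix() returns every possible first token.
print(trie.postfix([0]))  # [794]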

zshot/linker/linker_regen/utils.py

+55
@@ -1,3 +1,17 @@
+import json
+import pickle
+from typing import Dict, List
+
+from huggingface_hub import hf_hub_download
+
+from zshot.linker.linker_regen.trie import Trie
+from zshot.utils.data_models import Span
+
+REPO_ID = "ibm/regen-disambiguation"
+TRIE_FILE_NAME = "wikipedia_trie.pkl"
+WIKIPEDIA_MAP = "wikipedia_map_id.json"
+
+
 def create_input(sentence, max_length, start_delimiter, end_delimiter):
     sent_list = sentence.split(" ")
     if len(sent_list) < max_length:
@@ -12,3 +26,44 @@ def create_input(sentence, max_length, start_delimiter, end_delimiter):
         left_index = left_index - max(0, (half_context - (right_index - end_delimiter_index)))
     print(len(sent_list[left_index:right_index]))
     return " ".join(sent_list[left_index:right_index])
+
+
+def load_wikipedia_trie() -> Trie:
+    """
+    Load the Wikipedia trie from the HF Hub
+    :return: The Wikipedia trie
+    """
+    wikipedia_trie_file = hf_hub_download(repo_id=REPO_ID,
+                                          repo_type='model',
+                                          filename=TRIE_FILE_NAME)
+    with open(wikipedia_trie_file, "rb") as f:
+        wikipedia_trie = pickle.load(f)
+    return wikipedia_trie
+
+
+def load_wikipedia_mapping() -> Dict[str, str]:
+    """
+    Load the Wikipedia title-to-page-id mapping from the HF Hub
+    :return: The Wikipedia mapping
+    """
+    wikipedia_map = hf_hub_download(repo_id=REPO_ID,
+                                    repo_type='model',
+                                    filename=WIKIPEDIA_MAP)
+    with open(wikipedia_map, "r") as f:
+        wikipedia_map = json.load(f)
+    return wikipedia_map
+
+
+def spans_to_wikipedia(spans: List[Span]) -> List[str]:
+    """
+    Generate Wikipedia links for spans
+    :return: The list of generated links
+    """
+    links = []
+    wikipedia_map = load_wikipedia_mapping()
+    for s in spans:
+        if s.label in wikipedia_map:
+            links.append(f"https://en.wikipedia.org/wiki?curid={wikipedia_map[s.label]}")
+        else:
+            links.append(None)
+    return links
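
A short sketch (not in the commit) of how the new helpers compose, mirroring test_load_wikipedia_trie and test_span_to_wiki below; both calls download files from the ibm/regen-disambiguation Hub repo, so they need network access.

from zshot.linker.linker_regen.utils import load_wikipedia_trie, spans_to_wikipedia
from zshot.utils.data_models import Span

# Pickled Trie over Wikipedia titles, fetched from the HF Hub.
wikipedia_trie = load_wikipedia_trie()

# Turn linked spans into Wikipedia URLs; labels missing from the map yield None.
links = spans_to_wikipedia([Span(label="Surfing", start=0, end=10)])
print(links[0])  # https://en.wikipedia.org/wiki?curid=<page id>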

zshot/tests/linker/test_regen_linker.py

+39-2
@@ -8,8 +8,12 @@
 
 from zshot import PipelineConfig
 from zshot.linker.linker_regen.linker_regen import LinkerRegen
+from zshot.linker.linker_regen.trie import Trie
+from zshot.linker.linker_regen.utils import load_wikipedia_trie, spans_to_wikipedia
 from zshot.mentions_extractor import MentionsExtractorSpacy
 from zshot.tests.config import EX_DOCS, EX_ENTITIES
+from zshot.tests.mentions_extractor.test_mention_extractor import DummyMentionsExtractor
+from zshot.utils.data_models import Span
 
 logger = logging.getLogger(__name__)
 
@@ -25,9 +29,9 @@ def teardown():
 
 
 def test_regen_linker():
-    nlp = spacy.load("en_core_web_sm")
+    nlp = spacy.blank("en")
     config = PipelineConfig(
-        mentions_extractor=MentionsExtractorSpacy(),
+        mentions_extractor=DummyMentionsExtractor(),
         linker=LinkerRegen(),
         entities=EX_ENTITIES
     )
@@ -60,3 +64,36 @@ def test_regen_linker_pipeline():
         nlp.get_pipe('zshot').linker.model, nlp.get_pipe('zshot').linker
     nlp.remove_pipe('zshot')
     del docs, nlp, config
+
+
+def test_regen_linker_wikification():
+    nlp = spacy.blank("en")
+    trie = Trie()
+    trie.add([794, 536, 1])
+    trie.add([794, 357, 1])
+    config = PipelineConfig(
+        mentions_extractor=DummyMentionsExtractor(),
+        linker=LinkerRegen(trie=trie),
+    )
+    nlp.add_pipe("zshot", config=config, last=True)
+    assert "zshot" in nlp.pipe_names
+
+    doc = nlp(EX_DOCS[1])
+    assert len(doc.ents) > 0
+    del nlp.get_pipe('zshot').mentions_extractor, nlp.get_pipe('zshot').entities, nlp.get_pipe('zshot').nlp
+    del nlp.get_pipe('zshot').linker.tokenizer, nlp.get_pipe('zshot').linker.trie, \
+        nlp.get_pipe('zshot').linker.model, nlp.get_pipe('zshot').linker
+    nlp.remove_pipe('zshot')
+    del doc, nlp, config
+
+
+def test_load_wikipedia_trie():
+    trie = load_wikipedia_trie()
+    assert len(list(trie.trie_dict.keys())) == 6952
+
+
+def test_span_to_wiki():
+    s = Span(label="Surfing", start=0, end=10)
+    wiki_links = spans_to_wikipedia([s])
+    assert len(wiki_links) > 0
+    assert wiki_links[0].startswith("https://en.wikipedia.org/wiki?curid=")
