Skip to content

Commit 79db790

Browse files
GabrielePiccoGabriele Picco
and
Gabriele Picco
authored
🐛 Fix displacy render function (#10)
* 🐛 Fix displacy render Signed-off-by: Gabriele Picco <gabriele.picco@ibm.comm> * ✅ Add displacy test Signed-off-by: Gabriele Picco <gabriele.picco@ibm.comm> Signed-off-by: Gabriele Picco <gabriele.picco@ibm.comm> Co-authored-by: Gabriele Picco <gabriele.picco@ibm.comm>
1 parent 89b8d5b commit 79db790

File tree

2 files changed

+22
-2
lines changed

2 files changed

+22
-2
lines changed

zshot/tests/utils/test_utils.py

+20
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,9 @@
1+
import spacy
2+
3+
from zshot import PipelineConfig, displacy
4+
from zshot.tests.config import EX_ENTITIES, EX_DOCS
5+
from zshot.tests.linker.test_linker import DummyLinkerEnd2End
6+
from zshot.tests.mentions_extractor.test_mention_extractor import DummyMentionsExtractor
17
from zshot.utils.data_models import Span
28
from zshot.utils.alignment_utils import align_spans, AlignmentMode, filter_overlapping_spans
39

@@ -164,3 +170,17 @@ def test_alignment_expand_overlaps_no_score():
164170
assert filtered_spans[0].label == "A"
165171
assert filtered_spans[1].start == 3 and filtered_spans[1].end == 8
166172
assert filtered_spans[1].label == "C"
173+
174+
175+
def test_displacy_render():
176+
nlp = spacy.blank("en")
177+
178+
nlp.add_pipe("zshot", config=PipelineConfig(
179+
mentions_extractor=DummyMentionsExtractor(),
180+
linker=DummyLinkerEnd2End(),
181+
entities=EX_ENTITIES), last=True)
182+
doc = nlp(EX_DOCS[1])
183+
assert len(doc.ents) > 0
184+
assert len(doc._.spans) > 0
185+
res = displacy.render(doc, style="ent", jupyter=False)
186+
assert res is not None

zshot/utils/displacy.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -27,12 +27,12 @@ def color_from_label(label: str):
2727
class displacy:
2828

2929
@staticmethod
30-
def render(doc, options: Dict = None, **kwargs):
30+
def render(doc, options: Dict = None, **kwargs) -> str:
3131
if options:
3232
options['colors'] = ents_colors(doc)
3333
else:
3434
options = {'colors': ents_colors(doc)}
35-
s_displacy.render(doc, options=options, **kwargs)
35+
return s_displacy.render(doc, options=options, **kwargs)
3636

3737
@staticmethod
3838
def serve(doc, options: Dict = None, **kwargs):

0 commit comments

Comments
 (0)