
Commit eebbbb1

🐛 Fix bug in evaluation with both mentions extractors and linkers (#34)

* 🐛 Fix bug in evaluation with both mentions extractors and linkers
* 🎨 Fix style

Signed-off-by: Marcos Martinez <Marcos.Martinez.Galindo@ibm.com>

1 parent 3ae7066 · commit eebbbb1

File tree: 2 files changed, +51 -84 lines changed

- zshot/evaluation/run_evaluation.py (+3 -11)
- zshot/evaluation/zshot_evaluate.py (+48 -73)

zshot/evaluation/run_evaluation.py (+3 -11)

```diff
@@ -1,9 +1,7 @@
 import argparse
 
 import spacy
-
 from zshot import PipelineConfig
-from zshot.evaluation import load_medmentions, load_ontonotes
 from zshot.evaluation.metrics.seqeval.seqeval import Seqeval
 from zshot.evaluation.zshot_evaluate import evaluate
 from zshot.linker import LinkerTARS, LinkerSMXM, LinkerRegen
@@ -26,7 +24,8 @@
 
 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
-    parser.add_argument("--dataset", default="ontonotes", type=str, help="Name or path to the validation data. Comma separated")
+    parser.add_argument("--dataset", default="ontonotes", type=str,
+                        help="Name or path to the validation data. Comma separated")
     parser.add_argument("--splits", required=False, default="train, test, validation", type=str,
                         help="Splits to evaluate. Comma separated")
     parser.add_argument("--mode", required=False, default="full", type=str,
@@ -62,8 +61,6 @@
             )
         else:
             configs[linker] = PipelineConfig(linker=LINKERS[linker]())
-        for mentions_extractor in mentions_extractors:
-            configs[mentions_extractor] = PipelineConfig(mentions_extractor=MENTION_EXTRACTORS[mentions_extractor]())
     elif args.mode == "mentions_extractor":
         for mentions_extractor in mentions_extractors:
             configs[mentions_extractor] = PipelineConfig(mentions_extractor=MENTION_EXTRACTORS[mentions_extractor]())
```
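
The deleted loop is the bug from the commit title: in full mode, the script additionally registered every mentions extractor as its own extractor-only config, duplicating what `--mode mentions_extractor` already does and producing extra pipelines whose results later collided in the summary table (see the zshot_evaluate.py change below). A simplified sketch of the surviving config-per-linker pattern; the contents of the `LINKERS` registry here are hypothetical stand-ins:

```python
from zshot import PipelineConfig
from zshot.linker import LinkerRegen, LinkerTARS

# Hypothetical stand-in for the script's LINKERS registry.
LINKERS = {"regen": LinkerRegen, "tars": LinkerTARS}

configs = {}
for linker in LINKERS:
    configs[linker] = PipelineConfig(linker=LINKERS[linker]())
# After the fix, full mode stops here: no extractor-only configs
# are appended, so each key maps to exactly one pipeline setup.
```
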
```diff
@@ -81,9 +78,4 @@
         nlp = spacy.blank("en") if "spacy" not in key else spacy.load("en_core_web_sm")
         nlp.add_pipe("zshot", config=config, last=True)
 
-        if args.dataset.lower() == "medmentions":
-            dataset = load_medmentions()
-        else:
-            dataset = load_ontonotes()
-
-        print(evaluate(nlp, dataset, splits=args.splits, metric=Seqeval()))
+        print(evaluate(nlp, args.dataset, splits=args.splits, metric=Seqeval()))
```
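
Dataset loading moves out of the script entirely: the name given via `--dataset` is handed straight to `evaluate`, which resolves it internally. A minimal sketch of the equivalent direct call, using only imports and calls visible in this diff:

```python
import spacy

from zshot import PipelineConfig
from zshot.evaluation.metrics.seqeval.seqeval import Seqeval
from zshot.evaluation.zshot_evaluate import evaluate
from zshot.linker import LinkerTARS

# Build one pipeline, as the script does for each entry in `configs`.
nlp = spacy.blank("en")
nlp.add_pipe("zshot", config=PipelineConfig(linker=LinkerTARS()), last=True)

# The dataset *name* is passed through; evaluate() loads it internally.
print(evaluate(nlp, "ontonotes", splits="test", metric=Seqeval()))
```
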

zshot/evaluation/zshot_evaluate.py (+48 -73)

```diff
@@ -1,25 +1,19 @@
 from typing import Optional, List, Union
 
 import spacy
-from datasets import Dataset, NamedSplit
 from evaluate import EvaluationModule
 from prettytable import PrettyTable
-
-from zshot.evaluation.evaluator import (
-    ZeroShotTokenClassificationEvaluator,
-    MentionsExtractorEvaluator,
-)
+from zshot.evaluation import load_medmentions, load_ontonotes
+from zshot.evaluation.evaluator import ZeroShotTokenClassificationEvaluator, MentionsExtractorEvaluator
 from zshot.evaluation.pipeline import LinkerPipeline, MentionsExtractorPipeline
 
 
-def evaluate(
-    nlp: spacy.language.Language,
-    datasets: Union[Dataset, List[Dataset]],
-    splits: Optional[Union[NamedSplit, List[NamedSplit]]] = None,
-    metric: Optional[Union[str, EvaluationModule]] = None,
-    batch_size: Optional[int] = 16,
-) -> str:
-    """Evaluate a spacy zshot model
+def evaluate(nlp: spacy.language.Language,
+             datasets: Union[str, List[str]],
+             splits: Optional[Union[str, List[str]]] = None,
+             metric: Optional[Union[str, EvaluationModule]] = None,
+             batch_size: Optional[int] = 16) -> str:
+    """ Evaluate a spacy zshot model
 
     :param nlp: Spacy Language pipeline with ZShot components
     :param datasets: Dataset or list of datasets to evaluate
```
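
`datasets` and `splits` change from `Dataset`/`NamedSplit` objects to plain strings (or lists of strings), which is what run_evaluation.py now passes. A hedged sketch of both accepted shapes, reusing the minimal pipeline setup shown above:

```python
import spacy

from zshot import PipelineConfig
from zshot.evaluation.zshot_evaluate import evaluate
from zshot.linker import LinkerRegen

nlp = spacy.blank("en")
nlp.add_pipe("zshot", config=PipelineConfig(linker=LinkerRegen()), last=True)

# A bare string is wrapped into a one-element list inside evaluate();
# lists of dataset names and split names pass through unchanged.
print(evaluate(nlp, "medmentions", splits="test"))
print(evaluate(nlp, ["medmentions", "ontonotes"], splits=["train", "test"]))
```
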
```diff
@@ -32,42 +26,52 @@ def evaluate(
     linker_evaluator = ZeroShotTokenClassificationEvaluator("token-classification")
     mentions_extractor_evaluator = MentionsExtractorEvaluator("token-classification")
 
-    if not isinstance(splits, list):
+    if type(splits) == str:
         splits = [splits]
 
-    if not isinstance(datasets, list):
+    if type(datasets) == str:
         datasets = [datasets]
 
     result = {}
     field_names = ["Metric"]
-    for dataset in datasets:
+    for dataset_name in datasets:
+        if dataset_name.lower() == "medmentions":
+            dataset = load_medmentions()
+        else:
+            dataset = load_ontonotes()
+
         for split in splits:
-            field_name = f"{dataset.description} {split}"
+            field_name = f"{dataset_name} {split}"
             field_names.append(field_name)
             nlp.get_pipe("zshot").mentions = dataset[split].entities
             nlp.get_pipe("zshot").entities = dataset[split].entities
             if nlp.get_pipe("zshot").linker:
                 pipe = LinkerPipeline(nlp, batch_size)
-                result.update(
-                    {
-                        field_name: {
-                            "linker": linker_evaluator.compute(
-                                pipe, dataset[split], metric=metric
-                            )
+                res_tmp = {
+                    'linker': linker_evaluator.compute(pipe, dataset[split], metric=metric)
+                }
+                if field_name not in result:
+                    result.update(
+                        {
+                            field_name: res_tmp
                         }
-                    }
-                )
+                    )
+                else:
+                    result[field_name].update(res_tmp)
             if nlp.get_pipe("zshot").mentions_extractor:
                 pipe = MentionsExtractorPipeline(nlp, batch_size)
-                result.update(
-                    {
-                        field_name: {
-                            "mentions_extractor": mentions_extractor_evaluator.compute(
-                                pipe, dataset[split], metric=metric
-                            )
+                res_tmp = {
+                    'mentions_extractor': mentions_extractor_evaluator.compute(pipe, dataset[split],
+                                                                               metric=metric)
+                }
+                if field_name not in result:
+                    result.update(
+                        {
+                            field_name: res_tmp
                         }
-                    }
-                )
+                    )
+                else:
+                    result[field_name].update(res_tmp)
 
     table = PrettyTable()
     table.field_names = field_names
```
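
This hunk is the actual fix. `dict.update` with a nested `{field_name: {...}}` replaces the entire per-column entry, so when a pipeline had both a linker and a mentions extractor, the second update used to discard the linker scores for that dataset/split. A self-contained sketch of the difference (plain Python, invented scores, no zshot required):

```python
# Old behaviour: the second update replaces the whole entry,
# so the linker scores for "ontonotes test" are lost.
result = {}
result.update({"ontonotes test": {"linker": {"overall_f1_macro": 0.71}}})
result.update({"ontonotes test": {"mentions_extractor": {"overall_f1_macro": 0.64}}})
assert "linker" not in result["ontonotes test"]

# Fixed behaviour: merge into the existing per-column entry instead.
result = {}
for res_tmp in ({"linker": {"overall_f1_macro": 0.71}},
                {"mentions_extractor": {"overall_f1_macro": 0.64}}):
    if "ontonotes test" not in result:
        result["ontonotes test"] = dict(res_tmp)
    else:
        result["ontonotes test"].update(res_tmp)
assert set(result["ontonotes test"]) == {"linker", "mentions_extractor"}
```
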
```diff
@@ -81,25 +85,11 @@ def evaluate(
     for field_name in field_names:
         if field_name == "Metric":
             continue
-        linker_precisions.append(
-            "{:.2f}%".format(
-                result[field_name]["linker"]["overall_precision_macro"] * 100
-            )
-        )
-        linker_recalls.append(
-            "{:.2f}%".format(
-                result[field_name]["linker"]["overall_recall_macro"] * 100
-            )
-        )
-        linker_accuracies.append(
-            "{:.2f}%".format(result[field_name]["linker"]["overall_accuracy"] * 100)
-        )
-        linker_micros.append(
-            "{:.2f}%".format(result[field_name]["linker"]["overall_f1_micro"] * 100)
-        )
-        linker_macros.append(
-            "{:.2f}%".format(result[field_name]["linker"]["overall_f1_macro"] * 100)
-        )
+        linker_precisions.append("{:.2f}%".format(result[field_name]['linker']['overall_precision_macro'] * 100))
+        linker_recalls.append("{:.2f}%".format(result[field_name]['linker']['overall_recall_macro'] * 100))
+        linker_accuracies.append("{:.2f}%".format(result[field_name]['linker']['overall_accuracy'] * 100))
+        linker_micros.append("{:.2f}%".format(result[field_name]['linker']['overall_f1_micro'] * 100))
+        linker_macros.append("{:.2f}%".format(result[field_name]['linker']['overall_f1_macro'] * 100))
 
     rows.append(["Linker Precision"] + linker_precisions)
     rows.append(["Linker Recall"] + linker_recalls)
@@ -117,30 +107,15 @@ def evaluate(
         if field_name == "Metric":
             continue
         mentions_extractor_precisions.append(
-            "{:.2f}%".format(
-                result[field_name]["mentions_extractor"]["overall_precision_macro"] * 100
-            )
-        )
+            "{:.2f}%".format(result[field_name]['mentions_extractor']['overall_precision_macro'] * 100))
         mentions_extractor_recalls.append(
-            "{:.2f}%".format(
-                result[field_name]["mentions_extractor"]["overall_recall_macro"] * 100
-            )
-        )
+            "{:.2f}%".format(result[field_name]['mentions_extractor']['overall_recall_macro'] * 100))
         mentions_extractor_accuracies.append(
-            "{:.2f}%".format(
-                result[field_name]["mentions_extractor"]["overall_accuracy"] * 100
-            )
-        )
+            "{:.2f}%".format(result[field_name]['mentions_extractor']['overall_accuracy'] * 100))
         mentions_extractor_micros.append(
-            "{:.2f}%".format(
-                result[field_name]["mentions_extractor"]["overall_f1_micro"] * 100
-            )
-        )
+            "{:.2f}%".format(result[field_name]['mentions_extractor']['overall_f1_micro'] * 100))
         mentions_extractor_macros.append(
-            "{:.2f}%".format(
-                result[field_name]["mentions_extractor"]["overall_f1_macro"] * 100
-            )
-        )
+            "{:.2f}%".format(result[field_name]['mentions_extractor']['overall_f1_macro'] * 100))
 
     rows.append(["Mentions extractor Precision"] + mentions_extractor_precisions)
     rows.append(["Mentions extractor Recall"] + mentions_extractor_recalls)
```
