from typing import Optional, List, Union

import spacy
- from datasets import Dataset, NamedSplit
from evaluate import EvaluationModule
from prettytable import PrettyTable
-
- from zshot.evaluation.evaluator import (
-     ZeroShotTokenClassificationEvaluator,
-     MentionsExtractorEvaluator,
- )
+ from zshot.evaluation import load_medmentions, load_ontonotes
+ from zshot.evaluation.evaluator import ZeroShotTokenClassificationEvaluator, MentionsExtractorEvaluator
from zshot.evaluation.pipeline import LinkerPipeline, MentionsExtractorPipeline


- def evaluate(
-     nlp: spacy.language.Language,
-     datasets: Union[Dataset, List[Dataset]],
-     splits: Optional[Union[NamedSplit, List[NamedSplit]]] = None,
-     metric: Optional[Union[str, EvaluationModule]] = None,
-     batch_size: Optional[int] = 16,
- ) -> str:
-     """Evaluate a spacy zshot model
+ def evaluate(nlp: spacy.language.Language,
+              datasets: Union[str, List[str]],
+              splits: Optional[Union[str, List[str]]] = None,
+              metric: Optional[Union[str, EvaluationModule]] = None,
+              batch_size: Optional[int] = 16) -> str:
+     """ Evaluate a spacy zshot model

    :param nlp: Spacy Language pipeline with ZShot components
    :param datasets: Dataset or list of datasets to evaluate
@@ -32,42 +26,52 @@ def evaluate(
    linker_evaluator = ZeroShotTokenClassificationEvaluator("token-classification")
    mentions_extractor_evaluator = MentionsExtractorEvaluator("token-classification")

-     if not isinstance(splits, list):
+     if type(splits) == str:
        splits = [splits]

-     if not isinstance(datasets, list):
+     if type(datasets) == str:
        datasets = [datasets]

    result = {}
    field_names = ["Metric"]
-     for dataset in datasets:
+     for dataset_name in datasets:
+         if dataset_name.lower() == "medmentions":
+             dataset = load_medmentions()
+         else:
+             dataset = load_ontonotes()
+
        for split in splits:
-             field_name = f"{dataset.description} {split}"
+             field_name = f"{dataset_name} {split}"
            field_names.append(field_name)
            nlp.get_pipe("zshot").mentions = dataset[split].entities
            nlp.get_pipe("zshot").entities = dataset[split].entities
            if nlp.get_pipe("zshot").linker:
                pipe = LinkerPipeline(nlp, batch_size)
-                 result.update(
-                     {
-                         field_name: {
-                             "linker": linker_evaluator.compute(
-                                 pipe, dataset[split], metric=metric
-                             )
+                 res_tmp = {
+                     'linker': linker_evaluator.compute(pipe, dataset[split], metric=metric)
+                 }
+                 if field_name not in result:
+                     result.update(
+                         {
+                             field_name: res_tmp
                        }
-                     }
-                 )
+                     )
+                 else:
+                     result[field_name].update(res_tmp)
            if nlp.get_pipe("zshot").mentions_extractor:
                pipe = MentionsExtractorPipeline(nlp, batch_size)
-                 result.update(
-                     {
-                         field_name: {
-                             "mentions_extractor": mentions_extractor_evaluator.compute(
-                                 pipe, dataset[split], metric=metric
-                             )
+                 res_tmp = {
+                     'mentions_extractor': mentions_extractor_evaluator.compute(pipe, dataset[split],
+                                                                                metric=metric)
+                 }
+                 if field_name not in result:
+                     result.update(
+                         {
+                             field_name: res_tmp
                        }
-                     }
-                 )
+                     )
+                 else:
+                     result[field_name].update(res_tmp)

    table = PrettyTable()
    table.field_names = field_names
@@ -81,25 +85,11 @@ def evaluate(
    for field_name in field_names:
        if field_name == "Metric":
            continue
-         linker_precisions.append(
-             "{:.2f}%".format(
-                 result[field_name]["linker"]["overall_precision_macro"] * 100
-             )
-         )
-         linker_recalls.append(
-             "{:.2f}%".format(
-                 result[field_name]["linker"]["overall_recall_macro"] * 100
-             )
-         )
-         linker_accuracies.append(
-             "{:.2f}%".format(result[field_name]["linker"]["overall_accuracy"] * 100)
-         )
-         linker_micros.append(
-             "{:.2f}%".format(result[field_name]["linker"]["overall_f1_micro"] * 100)
-         )
-         linker_macros.append(
-             "{:.2f}%".format(result[field_name]["linker"]["overall_f1_macro"] * 100)
-         )
+         linker_precisions.append("{:.2f}%".format(result[field_name]['linker']['overall_precision_macro'] * 100))
+         linker_recalls.append("{:.2f}%".format(result[field_name]['linker']['overall_recall_macro'] * 100))
+         linker_accuracies.append("{:.2f}%".format(result[field_name]['linker']['overall_accuracy'] * 100))
+         linker_micros.append("{:.2f}%".format(result[field_name]['linker']['overall_f1_micro'] * 100))
+         linker_macros.append("{:.2f}%".format(result[field_name]['linker']['overall_f1_macro'] * 100))

    rows.append(["Linker Precision"] + linker_precisions)
    rows.append(["Linker Recall"] + linker_recalls)
@@ -117,30 +107,15 @@ def evaluate(
        if field_name == "Metric":
            continue
        mentions_extractor_precisions.append(
-             "{:.2f}%".format(
-                 result[field_name]["mentions_extractor"]["overall_precision_macro"] * 100
-             )
-         )
+             "{:.2f}%".format(result[field_name]['mentions_extractor']['overall_precision_macro'] * 100))
        mentions_extractor_recalls.append(
-             "{:.2f}%".format(
-                 result[field_name]["mentions_extractor"]["overall_recall_macro"] * 100
-             )
-         )
+             "{:.2f}%".format(result[field_name]['mentions_extractor']['overall_recall_macro'] * 100))
        mentions_extractor_accuracies.append(
-             "{:.2f}%".format(
-                 result[field_name]["mentions_extractor"]["overall_accuracy"] * 100
-             )
-         )
+             "{:.2f}%".format(result[field_name]['mentions_extractor']['overall_accuracy'] * 100))
        mentions_extractor_micros.append(
-             "{:.2f}%".format(
-                 result[field_name]["mentions_extractor"]["overall_f1_micro"] * 100
-             )
-         )
+             "{:.2f}%".format(result[field_name]['mentions_extractor']['overall_f1_micro'] * 100))
        mentions_extractor_macros.append(
-             "{:.2f}%".format(
-                 result[field_name]["mentions_extractor"]["overall_f1_macro"] * 100
-             )
-         )
+             "{:.2f}%".format(result[field_name]['mentions_extractor']['overall_f1_macro'] * 100))

    rows.append(["Mentions extractor Precision"] + mentions_extractor_precisions)
    rows.append(["Mentions extractor Recall"] + mentions_extractor_recalls)