@@ -20,17 +20,11 @@ def main():
    parser.add_argument("--data_dir", type=str, default="train_corpus")
    parser.add_argument("--mr", type=float, default=0.15)
    parser.add_argument("--p", type=float, default=0.5)
-
    parser.add_argument("--batch_size", type=int, default=16)
-    parser.add_argument("--analysis", action="store_true")
    parser.add_argument("--num_shards", type=int, default=10)

    args = parser.parse_args()

-    if args.analysis:
-        analysis(args)
-        return
-
    ext = "_mr{}_p{}.jsonl".format(args.mr, args.p)

def find_files(out_dir):
@@ -60,163 +54,6 @@ def find_files(out_dir):
        tot += 1


-def analysis(args):
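-    # Compare masked versions of the same raw texts across different (mr, p) masking settings and write them to an HTML page for manual inspection.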
-    import json
-    from transformers import RobertaTokenizer
-    tokenizer = RobertaTokenizer.from_pretrained("roberta-large")
-    mask_id = tokenizer.mask_token_id
-
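-    # Read up to 3000 grouped examples from a .jsonl shard, indexing each raw text by its (example index, position) pair.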
-    def load(fn):
-        print("Starting loading", fn)
-        data = []
-        raw_text_to_position = {}
-        with open(fn, "r") as f:
-            for line in f:
-                dp = json.loads(line)
-                for i, raw_text in enumerate(dp["contents"]):
-                    raw_text_to_position[raw_text] = (len(data), i)
-                data.append(dp)
-                if len(data) == 3000:
-                    break
-        return data, raw_text_to_position
-
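-    # Wrap text in a colored <span> for highlighting in the HTML output.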
-    def backgrounded(text, color):
-        return "<span style='background-color: {}'>{}</span>".format(color, text)
-
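-    # Replace each run of <mask> tokens with its decoded gold label, alternating highlight colors between adjacent spans.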
-    def decode(masked_input_ids_list, merged_labels):
-        decoded_list = []
-        colors = ["#FAF884", "#E2FAB5"]
-        for i, (labels, masked_input_ids) in enumerate(zip(merged_labels, masked_input_ids_list)):
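-            # Strip trailing zero padding before decoding.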
-            while masked_input_ids[-1] == 0:
-                masked_input_ids = masked_input_ids[:-1]
-            decoded = tokenizer.decode(masked_input_ids)
-            color_idx = 0
-            for label in labels:
-                assert "<mask>" * len(label) in decoded
-                decoded = decoded.replace("<mask>" * len(label),
-                                          backgrounded(tokenizer.decode(label), colors[color_idx]),
-                                          1)
-                color_idx = 1 - color_idx
-            assert "<mask>" not in decoded
-            decoded_list.append(decoded.replace("<s>", "").replace("</s>", ""))
-        return decoded_list
-
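-    # Choose the pre-tokenized corpus: English Wikipedia or CC-News.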
-    if args.wiki:
-        data_dir = "/private/home/sewonmin/data/enwiki/enwiki_roberta_tokenized"
-        prefix = "enwiki0_grouped"
-    else:
-        data_dir = "/private/home/sewonmin/data/cc_news_en/cc_news_roberta_tokenized"
-        prefix = "batch0" #_grouped_v4"
-
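-    # Six variants of the same shard: three (mr, p) settings, each with and without token_ids.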
-    output_file = os.path.join(data_dir, "{}_{}.jsonl".format(prefix, "mr0.4_p0.2"))
-    output2_file = os.path.join(data_dir, "{}_{}_token_ids.jsonl".format(prefix, "mr0.4_p0.2"))
-    output3_file = os.path.join(data_dir, "{}_{}.jsonl".format(prefix, "mr0.15_p0.5"))
-    output4_file = os.path.join(data_dir, "{}_{}_token_ids.jsonl".format(prefix, "mr0.15_p0.5"))
-    output5_file = os.path.join(data_dir, "{}_{}.jsonl".format(prefix, "mr0.15_p0.2"))
-    output6_file = os.path.join(data_dir, "{}_{}_token_ids.jsonl".format(prefix, "mr0.15_p0.2"))
-
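-    # Fall back to filenames without the "batch" prefix if the primary files are missing.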
-    if not os.path.exists(output_file):
-        output_file = output_file.replace("batch", "")
-    if not os.path.exists(output2_file):
-        output2_file = output2_file.replace("batch", "")
-    if not os.path.exists(output3_file):
-        output3_file = output3_file.replace("batch", "")
-    if not os.path.exists(output4_file):
-        output4_file = output4_file.replace("batch", "")
-    if not os.path.exists(output5_file):
-        output5_file = output5_file.replace("batch", "")
-    if not os.path.exists(output6_file):
-        output6_file = output6_file.replace("batch", "")
-
-    '''
-    output_file = os.path.join(data_dir, "{}_{}.jsonl".format(0, "mr0.4_p0.2"))
-    output2_file = os.path.join(data_dir, "{}_{}.jsonl".format(0, "mr0.4_p0.2"))
-    output3_file = os.path.join(data_dir, "{}_{}_token_ids.jsonl".format(0, "mr0.4_p0.2"))
-    output4_file = os.path.join(data_dir, "{}_{}_inv_token_ids.jsonl".format(0, "mr0.4_p0.2"))
-    output5_file = os.path.join(data_dir, "{}_{}_token_ids_ent.jsonl".format(0, "mr0.15_p0.2"))
-    output6_file = os.path.join(data_dir, "{}_{}_inv_token_ids_ent.jsonl".format(0, "mr0.4_p0.2"))
-    '''
-
-    start_time = time.time()
-    np.random.seed(2022)
-    data1, raw_text_to_position1 = load(output_file)
-    data2, raw_text_to_position2 = load(output2_file)
-    data3, raw_text_to_position3 = load(output3_file)
-    data4, raw_text_to_position4 = load(output4_file)
-    data5, raw_text_to_position5 = load(output5_file)
-    data6, raw_text_to_position6 = load(output6_file)
-
-    is_same = []
-
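-    # Write side-by-side comparisons into samples.html (wiki_samples.html for Wikipedia).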
-    with open("{}samples.html".format("wiki_" if args.wiki else ""), "w") as f:
-
-        for dp_idx in range(50):
-            dp = data3[dp_idx]
-            masked_texts = decode(dp["masked_input_ids"], dp["merged_labels"])
-            raw_texts = dp["contents"]
-
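-            # Skip examples whose texts never appear in the first variant.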
-            if np.all([raw_text not in raw_text_to_position1 for raw_text in raw_texts]):
-                continue
-
-            for masked_text3, raw_text in zip(masked_texts, raw_texts):
-                if raw_text not in raw_text_to_position1:
-                    continue
-                if raw_text not in raw_text_to_position2:
-                    continue
-                if raw_text not in raw_text_to_position4:
-                    continue
-                if raw_text not in raw_text_to_position5:
-                    continue
-                if raw_text not in raw_text_to_position6:
-                    continue
-
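-                # Look up the same raw text in each remaining variant and decode its masked version.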
-                p = raw_text_to_position1[raw_text]
-                other_input_ids = data1[p[0]]["masked_input_ids"][p[1]]
-                other_labels = data1[p[0]]["merged_labels"][p[1]]
-                masked_text1 = decode([other_input_ids], [other_labels])[0]
-
-                p = raw_text_to_position2[raw_text]
-                other_input_ids = data2[p[0]]["masked_input_ids"][p[1]]
-                other_labels = data2[p[0]]["merged_labels"][p[1]]
-                masked_text2 = decode([other_input_ids], [other_labels])[0]
-
-                p = raw_text_to_position4[raw_text]
-                other_input_ids = data4[p[0]]["masked_input_ids"][p[1]]
-                other_labels = data4[p[0]]["merged_labels"][p[1]]
-                masked_text4 = decode([other_input_ids], [other_labels])[0]
-
-                p = raw_text_to_position5[raw_text]
-                other_input_ids = data5[p[0]]["masked_input_ids"][p[1]]
-                other_labels = data5[p[0]]["merged_labels"][p[1]]
-                masked_text5 = decode([other_input_ids], [other_labels])[0]
-
-                p = raw_text_to_position6[raw_text]
-                other_input_ids = data6[p[0]]["masked_input_ids"][p[1]]
-                other_labels = data6[p[0]]["merged_labels"][p[1]]
-                masked_text6 = decode([other_input_ids], [other_labels])[0]
-
-                is_same.append(masked_text3 == masked_text4)
-
-                '''
-                f.write("<strong>w/ BM25 (inv_token_ids, ent):</strong> {}<br /><br />".format(masked_text6))
-                f.write("<strong>w/ BM25 (inv_token_ids):</strong> {}<br /><br />".format(masked_text5))
-                f.write("<strong>w/ BM25 (token_ids, ent):</strong> {}<br /><br />".format(masked_text4))
-                f.write("<strong>w/ BM25 (token_ids):</strong> {}<br /><br />".format(masked_text3))
-                f.write("<strong>w/ BM25:</strong> {}<br /><br />".format(masked_text2))
-                f.write("<strong>w/o BM25:</strong> {}<br /><br />".format(masked_text1))
-                '''
-                f.write("<strong>w/o BM25 (token_ids, 0.15, 0.2):</strong> {}<br /><br />".format(masked_text6))
-                f.write("<strong>w/o BM25 (0.15, 0.2):</strong> {}<br /><br />".format(masked_text5))
-                f.write("<strong>w/o BM25 (token_ids, 0.15, 0.5):</strong> {}<br /><br />".format(masked_text4))
-                f.write("<strong>w/o BM25 (0.15, 0.5):</strong> {}<br /><br />".format(masked_text3))
-                f.write("<strong>w/o BM25 (token_ids, 0.4, 0.2):</strong> {}<br /><br />".format(masked_text2))
-                f.write("<strong>w/o BM25 (0.4, 0.2):</strong> {}<br /><br />".format(masked_text1))
-
-
-            f.write("<hr />")
-
-    print(np.mean(is_same))
-
if __name__ == '__main__':
    main()