This repository was archived by the owner on Oct 31, 2023. It is now read-only.

Commit 91dec78
Author: Sewon Min
delete hard-coded paths
1 parent: e4d35c5

4 files changed: +2 -324 lines

dpr_scale/task/mlm_task.py (+1 -2)

@@ -543,8 +543,7 @@ def __init__(self, ctx_embeddings_dir, checkpoint_path=None, use_half_precision=
 
         if self.remove_stopwords:
             stopwords = set()
-            #assert stopwords_dir is not None
-            stopwords_dir = "/private/home/sewonmin/clean-token-retrieval/config"
+            assert stopwords_dir is not None
             with open(os.path.join(stopwords_dir, "roberta_stopwords.txt")) as f:
                 for line in f:
                     stopwords.add(int(line.strip()))
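
Note: with the hard-coded directory removed, stopwords_dir must now be supplied by the caller. Below is a minimal sketch of the loading logic this hunk keeps, assuming a caller-provided directory containing roberta_stopwords.txt (one RoBERTa token id per line); the "config" path in the usage comment is an assumption, not taken from the repo.

import os

def load_stopwords(stopwords_dir):
    # stopwords_dir is now required; the old hard-coded default no longer exists.
    assert stopwords_dir is not None
    stopwords = set()
    with open(os.path.join(stopwords_dir, "roberta_stopwords.txt")) as f:
        for line in f:
            stopwords.add(int(line.strip()))  # each line holds an integer token id
    return stopwords

# Hypothetical usage (directory name is an assumption):
# stopwords = load_stopwords("config")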

preprocess/mask_spans.py (-163)

@@ -20,17 +20,11 @@ def main():
     parser.add_argument("--data_dir", type=str, default="train_corpus")
     parser.add_argument("--mr", type=float, default=0.15)
     parser.add_argument("--p", type=float, default=0.5)
-
     parser.add_argument("--batch_size", type=int, default=16)
-    parser.add_argument("--analysis", action="store_true")
     parser.add_argument("--num_shards", type=int, default=10)
 
     args = parser.parse_args()
 
-    if args.analysis:
-        analysis(args)
-        return
-
     ext = "_mr{}_p{}.jsonl".format(args.mr, args.p)
 
     def find_files(out_dir):
@@ -60,163 +54,6 @@ def find_files(out_dir):
             tot += 1
 
 
-def analysis(args):
-    import json
-    from transformers import RobertaTokenizer
-    tokenizer = RobertaTokenizer.from_pretrained("roberta-large")
-    mask_id = tokenizer.mask_token_id
-
-    def load(fn):
-        print ("Starting loading", fn)
-        data = []
-        raw_text_to_position = {}
-        with open(fn, "r") as f:
-            for line in f:
-                dp = json.loads(line)
-                for i, raw_text in enumerate(dp["contents"]):
-                    raw_text_to_position[raw_text] = (len(data), i)
-                data.append(dp)
-                if len(data)==3000:
-                    break
-        return data, raw_text_to_position
-
-    def backgrounded(text, color):
-        return "<span style='background-color: {}'>{}</span>".format(color, text)
-
-    def decode(masked_input_ids_list, merged_labels):
-        decoded_list = []
-        colors = ["#FAF884", "#E2FAB5"]
-        for i, (labels, masked_input_ids) in enumerate(zip(merged_labels, masked_input_ids_list)):
-            while masked_input_ids[-1]==0:
-                masked_input_ids = masked_input_ids[:-1]
-            decoded = tokenizer.decode(masked_input_ids)
-            color_idx = 0
-            for label in labels:
-                assert "<mask>"*len(label) in decoded
-                decoded = decoded.replace("<mask>"*len(label),
-                                          backgrounded(tokenizer.decode(label), colors[color_idx]),
-                                          1)
-                color_idx = 1-color_idx
-            assert "<mask>" not in decoded
-            decoded_list.append(decoded.replace("<s>", "").replace("</s>", ""))
-        return decoded_list
-
-    if args.wiki:
-        data_dir = "/private/home/sewonmin/data/enwiki/enwiki_roberta_tokenized"
-        prefix = "enwiki0_grouped"
-    else:
-        data_dir = "/private/home/sewonmin/data/cc_news_en/cc_news_roberta_tokenized"
-        prefix = "batch0" #_grouped_v4"
-
-    output_file = os.path.join(data_dir, "{}_{}.jsonl".format(prefix, "mr0.4_p0.2"))
-    output2_file = os.path.join(data_dir, "{}_{}_token_ids.jsonl".format(prefix, "mr0.4_p0.2"))
-    output3_file = os.path.join(data_dir, "{}_{}.jsonl".format(prefix, "mr0.15_p0.5"))
-    output4_file = os.path.join(data_dir, "{}_{}_token_ids.jsonl".format(prefix, "mr0.15_p0.5"))
-    output5_file = os.path.join(data_dir, "{}_{}.jsonl".format(prefix, "mr0.15_p0.2"))
-    output6_file = os.path.join(data_dir, "{}_{}_token_ids.jsonl".format(prefix, "mr0.15_p0.2"))
-
-    if not os.path.exists(output_file):
-        output_file = output_file.replace("batch", "")
-    if not os.path.exists(output2_file):
-        output2_file = output2_file.replace("batch", "")
-    if not os.path.exists(output3_file):
-        output3_file = output3_file.replace("batch", "")
-    if not os.path.exists(output4_file):
-        output4_file = output4_file.replace("batch", "")
-    if not os.path.exists(output5_file):
-        output5_file = output5_file.replace("batch", "")
-    if not os.path.exists(output6_file):
-        output6_file = output6_file.replace("batch", "")
-
-    '''
-    output_file = os.path.join(data_dir, "{}_{}.jsonl".format(0, "mr0.4_p0.2"))
-    output2_file = os.path.join(data_dir, "{}_{}.jsonl".format(0, "mr0.4_p0.2"))
-    output3_file = os.path.join(data_dir, "{}_{}_token_ids.jsonl".format(0, "mr0.4_p0.2"))
-    output4_file = os.path.join(data_dir, "{}_{}_inv_token_ids.jsonl".format(0, "mr0.4_p0.2"))
-    output5_file = os.path.join(data_dir, "{}_{}_token_ids_ent.jsonl".format(0, "mr0.15_p0.2"))
-    output6_file = os.path.join(data_dir, "{}_{}_inv_token_ids_ent.jsonl".format(0, "mr0.4_p0.2"))
-    '''
-
-    start_time = time.time()
-    np.random.seed(2022)
-    data1, raw_text_to_position1 = load(output_file)
-    data2, raw_text_to_position2 = load(output2_file)
-    data3, raw_text_to_position3 = load(output3_file)
-    data4, raw_text_to_position4 = load(output4_file)
-    data5, raw_text_to_position5 = load(output5_file)
-    data6, raw_text_to_position6 = load(output6_file)
-
-    is_same = []
-
-    with open("{}samples.html".format("wiki_" if args.wiki else ""), "w") as f:
-
-        for dp_idx in range(50):
-            dp = data3[dp_idx]
-            masked_texts = decode(dp["masked_input_ids"], dp["merged_labels"])
-            raw_texts = dp["contents"]
-
-            if np.all([raw_text not in raw_text_to_position1 for raw_text in raw_texts]):
-                continue
-
-            for masked_text3, raw_text in zip(masked_texts, raw_texts):
-                if raw_text not in raw_text_to_position1:
-                    continue
-                if raw_text not in raw_text_to_position2:
-                    continue
-                if raw_text not in raw_text_to_position4:
-                    continue
-                if raw_text not in raw_text_to_position5:
-                    continue
-                if raw_text not in raw_text_to_position6:
-                    continue
-
-                p = raw_text_to_position1[raw_text]
-                other_input_ids = data1[p[0]]["masked_input_ids"][p[1]]
-                other_labels = data1[p[0]]["merged_labels"][p[1]]
-                masked_text1 = decode([other_input_ids], [other_labels])[0]
-
-                p = raw_text_to_position2[raw_text]
-                other_input_ids = data2[p[0]]["masked_input_ids"][p[1]]
-                other_labels = data2[p[0]]["merged_labels"][p[1]]
-                masked_text2 = decode([other_input_ids], [other_labels])[0]
-
-                p = raw_text_to_position4[raw_text]
-                other_input_ids = data4[p[0]]["masked_input_ids"][p[1]]
-                other_labels = data4[p[0]]["merged_labels"][p[1]]
-                masked_text4 = decode([other_input_ids], [other_labels])[0]
-
-                p = raw_text_to_position5[raw_text]
-                other_input_ids = data5[p[0]]["masked_input_ids"][p[1]]
-                other_labels = data5[p[0]]["merged_labels"][p[1]]
-                masked_text5 = decode([other_input_ids], [other_labels])[0]
-
-                p = raw_text_to_position6[raw_text]
-                other_input_ids = data6[p[0]]["masked_input_ids"][p[1]]
-                other_labels = data6[p[0]]["merged_labels"][p[1]]
-                masked_text6 = decode([other_input_ids], [other_labels])[0]
-
-                is_same.append(masked_text3==masked_text4)
-
-                '''
-                f.write("<strong>w/ BM25 (inv_token_ids, ent):</strong> {}<br /><br />".format(masked_text6))
-                f.write("<strong>w/ BM25 (inv_token_ids):</strong> {}<br /><br />".format(masked_text5))
-                f.write("<strong>w/ BM25 (token_ids, ent):</strong> {}<br /><br />".format(masked_text4))
-                f.write("<strong>w/ BM25 (token_ids):</strong> {}<br /><br />".format(masked_text3))
-                f.write("<strong>w/ BM25:</strong> {}<br /><br />".format(masked_text2))
-                f.write("<strong>w/o BM25:</strong> {}<br /><br />".format(masked_text1))
-                '''
-                f.write("<strong>w/o BM25 (token_ids, 0.15, 0.2):</strong> {}<br /><br />".format(masked_text6))
-                f.write("<strong>w/o BM25 (0.15, 0.2):</strong> {}<br /><br />".format(masked_text5))
-                f.write("<strong>w/o BM25 (token_ids, 0.15, 0.5):</strong> {}<br /><br />".format(masked_text4))
-                f.write("<strong>w/o BM25 (0.15, 0.5):</strong> {}<br /><br />".format(masked_text3))
-                f.write("<strong>w/o BM25 (token_ids, 0.4, 0.2):</strong> {}<br /><br />".format(masked_text2))
-                f.write("<strong>w/o BM25 (0.4, 0.2):</strong> {}<br /><br />".format(masked_text1))
-
-
-                f.write("<hr />")
-
-    print (np.mean(is_same))
-
 if __name__=='__main__':
     main()
 
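
Note: the --analysis mode and its hard-coded /private/home paths are gone; the script's outputs are still named by the remaining masking arguments. A small sketch of how the output suffix is formed, using the defaults shown above (mr=0.15, p=0.5); this only restates what the remaining code does.

# Suffix for the masked-span output files, built from the remaining arguments.
mr, p = 0.15, 0.5
ext = "_mr{}_p{}.jsonl".format(mr, p)
print(ext)  # _mr0.15_p0.5.jsonl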

scripts/clm_prompt.py (+1 -1)

@@ -101,7 +101,7 @@ def load_inputs(args):
 
     if args.ret:
         from npm.searcher import BM25Searcher
-        base_dir = "/checkpoint/sewonmin/npm_checkpoints/data"
+        base_dir = "corpus"
         name = "new-enwiki" if ret=="bm25_2022" else "enwiki"
         data_dir = os.path.join(base_dir, name)
         index_dir = os.path.join(base_dir, name + "-index")
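
Note: retrieval data and index now resolve under a relative "corpus" directory instead of a private checkpoint path. A sketch of the resulting layout, assuming ret is not "bm25_2022" so name is "enwiki".

import os

base_dir = "corpus"
name = "enwiki"  # "new-enwiki" when ret == "bm25_2022"
data_dir = os.path.join(base_dir, name)               # corpus/enwiki
index_dir = os.path.join(base_dir, name + "-index")   # corpus/enwiki-index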

task/utils_entity_translation.py (-158)

This file was deleted.
