This repository was archived by the owner on Oct 31, 2023. It is now read-only.

Commit e4d35c5: "add missing files"
Author: Sewon Min
Parent: 527c02c
17 files changed: +1367, -53 lines

.gitignore (+2, -3)

```diff
@@ -5,19 +5,18 @@
 *vscode*
 Makefile
 *tmp*
-*.txt
 *.html
 *.out
 *.err
 *.log
 *.json
 *.npy
+my*
 task_data
 core
 data
 save
 corpus
+train_corpus
 deleted_files
-preprocess
-
```

README.md (+19, -13)

````diff
@@ -16,6 +16,8 @@ This repo contains the original implementation of the paper "[Nonparametric Mask
 
 Models are available from Huggingface Hub:hugs:! Check out [**npm**](https://huggingface.co/facebook/npm) (for phrase retrieval) and [**npm-single**](https://huggingface.co/facebook/npm-single) (for token retrieval).
 
+**We are working on a simple demo where you can simply download all the resources and deploy on your machine. Stay tuned!**
+
 ### Updates
 * **01/02/2023**: The code for training is released. See [train.md](train.md) for instructions.
 * **12/22/2022**: The code for inference is released. Stay tuned for the code for training.
@@ -85,8 +87,8 @@ python -m scripts.prompt \
 
 ```bash
 # To run on AGN, Yahoo and RTE:
-bash scripts/save_embeddings.sh npm enwiki-0 false 384
-bash scripts/save_embeddings.sh npm cc_news false 384
+bash scripts/save_embeddings.sh npm enwiki-0 false 320
+bash scripts/save_embeddings.sh npm cc_news false 320
 python -m scripts.prompt \
 --corpus_data enwiki-0+cc_news \
 --checkpoint_path npm \
@@ -95,7 +97,7 @@ python -m scripts.prompt \
 --save_dir save/npm
 
 # To run on Subj:
-bash scripts/save_embeddings.sh npm subj false 384
+bash scripts/save_embeddings.sh npm subj false 320
 python -m scripts.prompt \
 --corpus_data subj \
 --checkpoint_path npm \
@@ -104,8 +106,8 @@ python -m scripts.prompt \
 --save_dir save/npm
 
 # To run on SST-2, MR, RT, CR and Amazon:
-bash scripts/save_embeddings.sh npm imdb false 384
-bash scripts/save_embeddings.sh npm amazon false 384
+bash scripts/save_embeddings.sh npm imdb false 320
+bash scripts/save_embeddings.sh npm amazon false 320
 python -m scripts.prompt \
 --corpus_data imdb+amazon \
 --checkpoint_path npm \
@@ -114,14 +116,19 @@ python -m scripts.prompt \
 --save_dir save/npm
 ```
 
-Note that `scripts/save_embeddings.sh` takes 'model name', 'corpus name', 'whether it is an open-set task' and `batch size` (`384` is good for a 32gb GPU) as arguments. Embeddings are saved under `save/{model_name}/dstore`.
+Note that `scripts/save_embeddings.sh` takes
+- model name (npm or npm-single)
+- corpus name
+- whether it is an open-set task (true or false)
+- batch size (`320` is good for a 32gb GPU; if `trainer.precision=16` is used, `400` is good for a 32gb GPU)
+as arguments. Embeddings are saved under `save/{model_name}/dstore`.
 
 #### NPM Single on closed-set tasks
 
 ```bash
 # To run on AGN, Yahoo and RTE:
-bash scripts/save_embeddings.sh npm-single enwiki-0 false 384
-bash scripts/save_embeddings.sh npm-single cc_news false 384
+bash scripts/save_embeddings.sh npm-single enwiki-0 false 320
+bash scripts/save_embeddings.sh npm-single cc_news false 320
 python -m scripts.prompt \
 --corpus_data enwiki-0+cc_news \
 --checkpoint_path npm-single \
@@ -130,9 +137,8 @@ python -m scripts.prompt \
 --single \
 --save_dir save/npm-single
 
-
 # To run on Subj:
-bash scripts/save_embeddings.sh npm-single subj false 384
+bash scripts/save_embeddings.sh npm-single subj false 320
 python -m scripts.prompt \
 --corpus_data subj \
 --checkpoint_path npm-single \
@@ -142,8 +148,8 @@ python -m scripts.prompt \
 --save_dir save/npm-single
 
 # To run on SST-2, MR, RT, CR and Amazon:
-bash scripts/save_embeddings.sh npm-single imdb false 384
-bash scripts/save_embeddings.sh npm-single amazon false 384
+bash scripts/save_embeddings.sh npm-single imdb false 320
+bash scripts/save_embeddings.sh npm-single amazon false 320
 python -m scripts.prompt \
 --corpus_data imdb+amazon \
 --checkpoint_path npm-single \
@@ -175,7 +181,7 @@ Please note that running open-set tasks requires around 70GB of RAM and 1.4TB of
 ```bash
 # Note that this can be executed in parallel with up to 20 GPUs. In total, it takes about 10 GPU hours and 1.4TB of disk memory.
 for i in {0..19} ; do
-bash scripts/save_embeddings.sh npm enwiki-${i} true 384
+bash scripts/save_embeddings.sh npm enwiki-${i} true 320
 done
 
 # Loading the model takes about 40min, and 70GB of RAM (specify `--keep_uint8` to reduce RAM usage to 40GB which increases the model loading time to 60-80min).
````
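For readers who want to embed several corpora in a row, the four positional arguments of `scripts/save_embeddings.sh` described in the README above can be driven from a small Python wrapper. This is an illustrative sketch, not part of the repo; the argument order and values are taken directly from the README.

```python
import subprocess

# Illustrative wrapper (not part of this repo): calls scripts/save_embeddings.sh
# with the four positional arguments described in the README:
#   model name, corpus name, open-set flag, batch size.
model = "npm"          # or "npm-single"
open_set = "false"     # closed-set tasks
batch_size = "320"     # per the README, 320 fits a 32gb GPU (400 with trainer.precision=16)

for corpus in ["enwiki-0", "cc_news", "subj", "imdb", "amazon"]:
    subprocess.run(
        ["bash", "scripts/save_embeddings.sh", model, corpus, open_set, batch_size],
        check=True,
    )
```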

config/roberta_stopwords.txt (+211, new file)

The new file lists 211 RoBERTa vocabulary IDs (one ID per line in the original file) that are treated as stopwords:

```
4 5 6 7 8 9 10 11 12 13 14 15 16 19 21
22 23 24 25 28 30 31 32 33 34 35 36 37 39 40
41 42 43 45 47 49 50 51 52 53 54 55 56 57 58
59 61 62 63 64 66 68 69 70 71 73 77 79 81 84
87 88 89 95 97 98 99 103 106 108 109 110 111 113 114
116 122 123 127 128 129 131 136 137 141 142 143 144 145 147
148 149 150 159 160 162 167 172 182 197 207 209 215 218 222
223 227 258 259 276 308 328 349 350 351 359 367 385 399 454
456 473 475 479 519 524 579 596 608 617 630 646 683 742 769
787 849 874 938 939 947 965 1003 1009 1021 1039 1065 1215 1235 1423
1495 1589 1629 1640 1705 1721 1979 2025 2055 2156 2185 2220 2282 2512 2661
2744 2864 3226 3486 3559 4288 4395 4832 4839 5030 5214 5457 5844 7606 8061
9131 10431 10975 12905 14314 14434 15157 15483 15698 17487 18134 18212 19385 20343 22209
23367 24303 25522 25606 27779 27785 28696 31954 34437 35227 35524 37249 37457 41552 44128
45152
```
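A quick way to see which surface tokens these IDs correspond to is to decode them with a RoBERTa tokenizer. The snippet below is illustrative only; it assumes the IDs index the standard RoBERTa vocabulary (loaded here via `roberta-large`), which is not stated in the commit itself.

```python
from transformers import AutoTokenizer

# Illustrative only: assumes the IDs in config/roberta_stopwords.txt index the
# standard RoBERTa vocabulary (here loaded via roberta-large).
tokenizer = AutoTokenizer.from_pretrained("roberta-large")

with open("config/roberta_stopwords.txt") as f:
    stopword_ids = [int(line.strip()) for line in f if line.strip()]

# Print the first 20 stopword tokens to sanity-check the list.
print(tokenizer.convert_ids_to_tokens(stopword_ids[:20]))
```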

dpr_scale/task/mlm_task.py (+6, -5)

```diff
@@ -105,7 +105,6 @@ def setup(self, stage: str):
             state_dict = checkpoint_dict["state_dict"] if "state_dict" in checkpoint_dict else checkpoint_dict
             self.starting_global_step = checkpoint_dict["global_step"] if "global_step" in checkpoint_dict else 0
             self.load_state_dict(state_dict)
-
             print(f"Loaded state dict from {self.pretrained_checkpoint_path}")
         else:
             self.starting_global_step = 0
@@ -532,17 +531,20 @@ def _get_contrastive_loss(scores, labels, score_mask=score_mask, label_mask=labe
 
 class MaskedLanguageModelingEncodingTask(MaskedLanguageModelingTask):
 
-    def __init__(self, ctx_embeddings_dir, checkpoint_path=None, remove_stopwords=False, **kwargs):
+    def __init__(self, ctx_embeddings_dir, checkpoint_path=None, use_half_precision=True,
+                 remove_stopwords=False, stopwords_dir=None, **kwargs):
         super().__init__(**kwargs)
         self.ctx_embeddings_dir = ctx_embeddings_dir
         self.checkpoint_path = checkpoint_path
+        self.use_half_precision = use_half_precision
         pathlib.Path(ctx_embeddings_dir).mkdir(parents=True, exist_ok=True)
 
         self.remove_stopwords = remove_stopwords
 
         if self.remove_stopwords:
             stopwords = set()
-            stopwords_dir = "/".join(self.checkpoint_path.split("/")[:-3]) + "/config"
+            #assert stopwords_dir is not None
+            stopwords_dir = "/private/home/sewonmin/clean-token-retrieval/config"
             with open(os.path.join(stopwords_dir, "roberta_stopwords.txt")) as f:
                 for line in f:
                     stopwords.add(int(line.strip()))
@@ -572,7 +574,7 @@ def test_step(self, batch, batch_idx):
 
     def test_epoch_end(self, outputs):
         assert self.global_rank==0
-        use_half_precision = True
+        use_half_precision = self.use_half_precision
 
         if not self.ctx_embeddings_dir:
             self.ctx_embeddings_dir = self.trainer.weights_save_path
@@ -597,7 +599,6 @@ def _filter(curr_input_ids, curr_attention_mask, curr_is_valid, curr_hidden_stat
             for i, hidden_states in enumerate(curr_hidden_states):
                 if not curr_is_valid[i]:
                     continue
-                # if the current word is stopword, we don't have to save it
                 if self.remove_stopwords and curr_input_ids[i] in self.stopwords:
                     continue
                 vec.append(hidden_states)
```
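The new `use_half_precision` flag exposes the dtype used when `test_epoch_end` writes the contextual vectors. The sketch below only illustrates the disk-space effect of that choice; the array shape, variable names, and output are hypothetical and are not the repo's actual saving code.

```python
import numpy as np

# Hypothetical illustration of what use_half_precision buys: casting the saved
# contextual vectors to float16 halves the datastore's footprint on disk.
use_half_precision = True
vectors = np.random.randn(8, 1024).astype(np.float32)  # stand-in for hidden states

to_save = vectors.astype(np.float16) if use_half_precision else vectors
print(f"float32: {vectors.nbytes} bytes -> saved: {to_save.nbytes} bytes ({to_save.dtype})")
```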
