This repository was archived by the owner on Oct 31, 2023. It is now read-only.

Commit 4680235
Author: Sewon Min
Commit message: minor fixes
Parent: 1be566a

File tree: 9 files changed, +56 -101 lines

dpr_scale/models/hf_encoder.py (+2 -2)

@@ -10,7 +10,7 @@
 from dpr_scale.utils.utils import PathManager

 # @manual=//python/wheel/transformers3:transformers3
-from transformers import AutoModelForMaskedLM, AutoConfig
+from transformers import RobertaForMaskedLM, AutoModelForMaskedLM, AutoConfig

 class Encoder(nn.Module):
     def __init__(
@@ -40,7 +40,7 @@ def __init__(
             self.transformer = AutoModelForMaskedLM.from_pretrained(local_model_path, config=cfg)
             print ("Initializing from", local_model_path)
         else:
-            self.transformer = AutoModelForMaskedLM.from_pretrained(config=cfg)
+            self.transformer = RobertaForMaskedLM(config=cfg)

     def forward(self, tokens):
         return self.transformer(**tokens, return_dict=True)
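The fallback branch previously called `AutoModelForMaskedLM.from_pretrained(config=cfg)` without a model name or path, which is not a valid call; the fix builds a freshly initialized `RobertaForMaskedLM` from the config instead. A minimal sketch of the two initialization paths, assuming a standard `roberta-base` config (the model name here is illustrative, not part of the commit):

```python
from transformers import AutoConfig, AutoModelForMaskedLM, RobertaForMaskedLM

cfg = AutoConfig.from_pretrained("roberta-base")

# Load pretrained weights: from_pretrained needs a model name or a local path.
pretrained = AutoModelForMaskedLM.from_pretrained("roberta-base", config=cfg)

# Build the same architecture with randomly initialized weights (no checkpoint required).
scratch = RobertaForMaskedLM(cfg)
```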

npm/dstore.py (+16 -6)

@@ -51,6 +51,7 @@ def __init__(self,
                  probe=8,
                  num_keys_to_add_at_a_time=1000000,
                  remove_stopwords=False,
+                 remove_stopwords_except_k=None,
                  restricted=None,
                  consider_string_boundary=True,
                  cuda=True,
@@ -64,7 +65,7 @@ def __init__(self,
             if model_dir is not None:
                 model_dir = os.path.join(model_dir, setting)
         elif setting in ["enwiki", "enwiki-2022"]:
-            assert remove_stopwords, remove_stopwords
+            assert remove_stopwords or remove_stopwords_except_k
             data_path=[os.path.join(base_dir, setting, "{}.npy".format(idx)) for idx in range(20)]
             if model_dir is not None:
                 model_dir=[os.path.join(model_dir, "{}-{}".format(setting, idx)) for idx in range(20)]
@@ -76,13 +77,16 @@ def __init__(self,
         else:
             raise NotImplementedError(setting)

+        assert not (remove_stopwords and remove_stopwords_except_k)
+
         self.setting = setting
         self.dimension = dimension
         self.ncentroids = ncentroids
         self.code_size = code_size
         self.probe = probe
         self.num_keys_to_add_at_a_time = num_keys_to_add_at_a_time
         self.remove_stopwords = remove_stopwords
+        self.remove_stopwords_except_k = remove_stopwords_except_k
         self.restricted = restricted
         self.consider_string_boundary = consider_string_boundary
         self.cuda = cuda
@@ -127,7 +131,7 @@ def __init__(self,
             self.load_index(model_dir)

     def load_stopwords(self):
-        if self.remove_stopwords:
+        if self.remove_stopwords or self.remove_stopwords_except_k:
             stopwords = set()
             stopwords_dir = "/private/home/sewonmin/token-retrieval/task_data"
             with open(os.path.join(stopwords_dir, "roberta_stopwords.txt")) as f:
@@ -178,6 +182,9 @@ def load_data(self, data_path):
         true_dstore_size = 0
         offset_block = 0 if self.input_ids is None else len(self.input_ids)

+        remove_stopwords = self.remove_stopwords or (
+            self.remove_stopwords_except_k is not None and data_path_idx >= self.remove_stopwords_except_k)
+
         for block_idx, (valid_start, valid_end) in enumerate(tqdm(valid_candidates)):
             start = start_end_pairs[block_idx]
             end = start_end_pairs[block_idx+1] if block_idx<len(start_end_pairs)-1 else len(input_ids)
@@ -189,11 +196,11 @@ def load_data(self, data_path):
             curr_dstore_size = 0

             for i, curr_token in enumerate(curr_input_ids):
-                if self.remove_stopwords and curr_token in stopwords:
+                if remove_stopwords and curr_token in stopwords:
                     continue
                 if self.embs_consider_boundary and i not in valid_idxs:
                     continue
-                elif curr_token in [0, 2]:
+                if curr_token in [0, 2]:
                     continue
                 if is_valid:
                     self.token_idx_to_block_idx.append(len(self.input_ids))
@@ -234,16 +241,19 @@ def load_data(self, data_path):
         self.true_dstore_size = np.sum(true_dstore_size_list)

     def load_embeds(self, model_dir):
-        postfix = "_wo_stopwords" if self.remove_stopwords else ""
         if type(model_dir)==list:
             self.embs = []
-            for _model_dir, dstore_size in zip(model_dir, self.dstore_size_list):
+            for shard_idx, (_model_dir, dstore_size) in enumerate(zip(model_dir, self.dstore_size_list)):
+                remove_stopwords = self.remove_stopwords or (
+                    self.remove_stopwords_except_k is not None and shard_idx >= self.remove_stopwords_except_k)
+                postfix = "_wo_stopwords" if remove_stopwords else ""
                 embed_path = os.path.join(_model_dir,
                                           "embeddings{}.float16.npy".format(postfix))
                 print ("Start loading the embed from %s with (%d, %d)..." % (embed_path.split("/")[-2], dstore_size, self.dimension))
                 curr_emb = load_embs(embed_path, dstore_size, self.dimension)
                 self.embs.append(curr_emb)
         else:
+            postfix = "_wo_stopwords" if self.remove_stopwords else ""
             embed_path = os.path.join(model_dir, "embeddings{}.float16.npy".format(postfix))
             print ("Start loading the embed with (%d, %d)..." % (self.dstore_size, self.dimension))
             self.embs = load_embs(embed_path, self.dstore_size, self.dimension)
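The new `remove_stopwords_except_k` option makes stopword filtering per-shard: stopwords are kept in the first k shards and dropped from shard k onward, while the plain `remove_stopwords` flag still applies globally. A standalone sketch of the decision the diff repeats in `load_data` and `load_embeds` (the helper name and the k value are illustrative):

```python
def should_remove_stopwords(shard_idx, remove_stopwords=False, remove_stopwords_except_k=None):
    # Global flag: drop stopwords from every shard.
    if remove_stopwords:
        return True
    # Per-shard flag: keep stopwords in shards 0..k-1, drop them from shard k onward.
    if remove_stopwords_except_k is not None:
        return shard_idx >= remove_stopwords_except_k
    return False

# With remove_stopwords_except_k=4, shards 0-3 keep stopwords and later shards do not.
print([should_remove_stopwords(i, remove_stopwords_except_k=4) for i in range(6)])
# [False, False, False, False, True, True]
```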

npm/npm.py (+2 -5)

@@ -209,11 +209,8 @@ def get_scores(start_indices, end_indices):

         # now, assign scores to possible ngrams
         for (start, end) in all_start_and_end:
-            try:
-                assert start in idx2start_score
-                assert end in idx2end_score
-            except Exception:
-                from IPython import embed; embed(); exit()
+            assert start in idx2start_score
+            assert end in idx2end_score
             score = idx2start_score[start] + idx2end_score[end]

             pos2score[(start, end)] = score
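As the hunk above shows, a candidate ngram's score is simply the sum of its start-token score and its end-token score. A toy standalone illustration of that scoring loop (all score values are made up):

```python
# Hypothetical per-position scores for span starts and span ends.
idx2start_score = {0: 1.2, 1: 0.4}
idx2end_score = {1: 0.9, 2: 1.5}

pos2score = {}
for (start, end) in [(0, 1), (0, 2), (1, 2)]:
    # Every candidate span must have both a start score and an end score.
    assert start in idx2start_score and end in idx2end_score
    pos2score[(start, end)] = idx2start_score[start] + idx2end_score[end]

print(pos2score)  # {(0, 1): 2.1, (0, 2): 2.7, (1, 2): 1.9}
```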

npm/utils.py (-71)

This file was deleted.

scripts/demo.py (+17 -10)

@@ -18,17 +18,19 @@

 class NPMDemo(object):

-    def __init__(self, save_dir, checkpoint_path, k, temperature, remove_stopwords, single,
+    def __init__(self, save_dir, checkpoint_path, k, temperature,
+                 remove_stopwords, remove_stopwords_except_k, single, restricted,
                  embs_consider_boundary, keep_uint8):
         start_time = time.time()
         dstore = DataStore(setting="enwiki",
-                           model_dir=os.path.join(save_dir, "dstore"),
-                           do_load_index=False,
-                           remove_stopwords=remove_stopwords,
-                           restricted=True,
+                           model_dir=os.path.join(save_dir, "dstore"),
+                           do_load_index=False,
+                           remove_stopwords=remove_stopwords,
+                           remove_stopwords_except_k=remove_stopwords_except_k,
+                           restricted=restricted,
                            embs_consider_boundary=embs_consider_boundary,
                            keep_uint8=keep_uint8
-                           )
+                           )
         model_class = SingleModel if single else Model
         model = model_class(checkpoint_path=checkpoint_path)
         print ("Finish loading the model (%dsec)" % (time.time()-start_time))
@@ -49,7 +51,7 @@ def predict(self, text):
         predicted = self.npm.predict_span(text,
                                           ngram_max=10,
                                           valid_func=self.valid_func,
-                                          alphas=[0.0])[0.0]
+                                          alphas=[0.0])["a=0.0"]
         return self.npm.decode(predicted)

     def generate(self, text, num_tokens=20, num_masked_tokens=20):
@@ -59,7 +61,7 @@ def generate(self, text, num_tokens=20, num_masked_tokens=20):
         predicted = self.npm.predict_span(input_text,
                                           ngram_max=10,
                                           valid_func=self.valid_func,
-                                          alphas=[0.0])[0.0]
+                                          alphas=[0.0])["a=0.0"]
         predicted = self.npm.decode(predicted)
         text += predicted
         return text
@@ -71,7 +73,9 @@ def main():
     parser.add_argument('--k', type=int, default=4096)
     parser.add_argument('--temperature', type=float, default=1.0)
     parser.add_argument("--remove_stopwords", action="store_true")
+    parser.add_argument("--remove_stopwords_except_k", type=int, default=None)
     parser.add_argument("--single", action="store_true")
+    parser.add_argument("--restricted", action="store_true")

     parser.add_argument("--embs_consider_boundary", action="store_true", default=True)
     parser.add_argument("--keep_uint8", action="store_true")
@@ -82,7 +86,9 @@ def main():
                   k=args.k,
                   temperature=args.temperature,
                   remove_stopwords=args.remove_stopwords,
+                  remove_stopwords_except_k=args.remove_stopwords_except_k,
                   single=args.single,
+                  restricted=args.restricted,
                   embs_consider_boundary=args.embs_consider_boundary,
                   keep_uint8=args.keep_uint8)

@@ -92,9 +98,10 @@ def predict(text):
         print ("(Took %.2fs)" % (time.time()-start_time))
         return predicted

-    input_text = "Hagios Demetrios is located in <mask>."
+    input_text = "Hagios Demetrios is located in"
     print (predict(input_text))
-    from IPython import embed; embed()
+
+    print (npm.generate(input_text))


 if __name__=='__main__':
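For reference, a hedged sketch of how the updated constructor and the two entry points might be used from Python. The import path, the save_dir/checkpoint paths, and the `remove_stopwords_except_k` value are illustrative assumptions, not part of the commit; the keyword names mirror the new CLI flags:

```python
from scripts.demo import NPMDemo  # assumes the repository root is on PYTHONPATH

# Placeholder paths and flag values, mirroring the arguments added in main().
demo = NPMDemo(save_dir="save/npm",
               checkpoint_path="save/npm/model.ckpt",
               k=4096,
               temperature=1.0,
               remove_stopwords=False,
               remove_stopwords_except_k=4,
               single=False,
               restricted=True,
               embs_consider_boundary=True,
               keep_uint8=False)

print(demo.predict("Hagios Demetrios is located in"))   # fill-in prediction for the masked span
print(demo.generate("Hagios Demetrios is located in"))  # iterative continuation of the prompt
```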

scripts/prompt.py (+2)

@@ -26,6 +26,7 @@ def main():
     parser.add_argument('--temperature', type=float, default=1.0)
     parser.add_argument('--n_samples', type=int, default=3000)
     parser.add_argument("--remove_stopwords", action="store_true")
+    parser.add_argument("--remove_stopwords_except_k", type=int, default=None)

     parser.add_argument("--single", action="store_true")
     parser.add_argument("--open", action="store_true")
@@ -56,6 +57,7 @@ def main():
                        model_dir=os.path.join(args.save_dir, "dstore"),
                        do_load_index=not args.restricted,
                        remove_stopwords=args.remove_stopwords,
+                       remove_stopwords_except_k=args.remove_stopwords_except_k,
                        restricted=(True if args.load_all_embs else tasks) if args.restricted else None,
                        embs_consider_boundary=args.embs_consider_boundary,
                        keep_uint8=args.keep_uint8

scripts/save_embeddings.sh (+3 -2)

@@ -10,8 +10,9 @@ corpus=$2
 open=$3
 bs=$4

-checkpoint_path=$(pwd)/save/${model_name}/model.ckpt
-ctx_embeddings_dir=$(pwd)/save/${model_name}/dstore/${corpus}
+
+out=$(pwd)/save/${model_name}
+ctx_embeddings_dir=${out}/dstore/${corpus}

 if [[ $open == "true" ]] ; then
     if [[ $corpus == "enwiki-"* ]] ; then

scripts/train.sh (+1 -2)

@@ -71,8 +71,7 @@ HYDRA_FULL_ERROR=1 PYTHONPATH=. python dpr_scale/main.py -m \
     trainer.gradient_clip_val=${clip} \
     trainer=slurm \
     hydra.launcher.name=${SAVE_DIR} \
-    hydra.sweep.dir=${SAVE_DIR} \
-    hydra.launcher.partition=devlab
+    hydra.sweep.dir=${SAVE_DIR}


train.md (+13 -3)

@@ -9,8 +9,8 @@ This is a guideline for training the NPM model. The training code is largely bas
 * [Span Masking](#span-masking)
 * [Uniform Masking](#uniform-masking)
 2. [Training](#training)
-3. [Debugging locally](#debugging-locally): see this if you want to do a test run before running the entire pipeline.
-
+   * [Debugging locally](#debugging-locally): see this if you want to do a test run before running the entire pipeline.
+3. [Evaluation](#evaluation)

 ## Prepare Training Data

@@ -94,11 +94,21 @@ To train NPM-single with uniform masking, run
 bash scripts/train.sh {save_dir} false 3e-05 16 0.15 uniform
 ```

-## Debugging Locally
+### Debugging Locally
 If you want a training run on a subset of datas with one local GPU (instead of using slurm and hydra), simply run `scripts/train_debug.sh` instead of `scripts/train.sh` with the same arguments as in the [Training section](#training).

 This use RoBERTA-base instead of RoBERTa-large, and can work with >=9GB GPU memory.

 Note: This only uses the first shard of English Wikipedia (no CC-News), so if you have not started preprocessing and want to do a test run first, you can preprocess English Wikipedia only and keep CC-News later.

+## Evaluation
+Evaluation can be done by following the guidelines for inference in the main [README](README.md).
+
+* Checkpoints are saved every 10,000 training steps. You can find them under `{save_dir}/{hyperparam_settings}/0/lightning_logs/version_{slurm_id}/checkpoints`.
+* When saving embeddings, specify `+task.checkpoint_path=${checkpoint_path}`
+* When running `python -m scripts.prompt`, specify `--checkpoint_path ${checkpoint_path}`
+
+
+
+

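The new Evaluation section describes where Lightning writes checkpoints. A small hedged helper sketch for locating the newest checkpoint under that directory layout (the `save/npm` path and the helper itself are illustrative assumptions, not part of the repository):

```python
import glob
import os

def latest_checkpoint(save_dir):
    # Mirrors {save_dir}/{hyperparam_settings}/0/lightning_logs/version_{slurm_id}/checkpoints
    pattern = os.path.join(save_dir, "*", "0", "lightning_logs",
                           "version_*", "checkpoints", "*.ckpt")
    candidates = glob.glob(pattern)
    # Return the most recently modified checkpoint, or None if nothing has been saved yet.
    return max(candidates, key=os.path.getmtime) if candidates else None

checkpoint_path = latest_checkpoint("save/npm")
# Pass this value as +task.checkpoint_path=... when saving embeddings,
# or as --checkpoint_path ... when running python -m scripts.prompt.
```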
