# pretrain.py — forked from benlevyx/florabert (78 lines, 2.1 KB).
# NOTE: the original page scrape included GitHub UI chrome and a rendered
# line-number gutter here; collapsed to this provenance comment.
"""
Pretraining on masked language model task.
"""
import sys
sys.path.append('/kaggle/working/florabert')
import torch
import os
from module.florabert import config, utils, training, dataio
from module.florabert import transformers as tr
DATA_DIR = config.data_final / "transformer" / "seq"
TOKENIZER_DIR = config.models / "byte-level-bpe-tokenizer"
OUTPUT_DIR = config.models / "transformer" / "language-model"
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
def main():
args = utils.get_args(
data_dir=DATA_DIR,
train_data="all_seqs_train.txt",
test_data="all_seqs_test.txt",
output_dir=OUTPUT_DIR,
model_name="roberta-lm",
pretrained_model = OUTPUT_DIR
)
# args.warmstart = True
print(args)
settings = utils.get_model_settings(config.settings, args, args.model_name)
config_obj, tokenizer, model = tr.load_model(
args.model_name,
TOKENIZER_DIR,
pretrained_model=args.pretrained_model,
**settings,
)
num_params = utils.count_model_parameters(model, trainable_only=True)
print(f"Loaded {args.model_name} model with {num_params:,} trainable parameters")
datasets = dataio.load_datasets(
tokenizer,
args.train_data,
test_data=args.test_data,
file_type="text",
seq_key="text",
)
dataset_train = datasets["train"]
dataset_test = datasets["test"]
print(f"Loaded training data with {len(dataset_train):,} examples")
data_collator = dataio.load_data_collator(
"language-model",
tokenizer=tokenizer,
)
training_settings = config.settings["training"]["pretrain"]
trainer = training.make_trainer(
model,
data_collator,
dataset_train,
dataset_test,
args.output_dir,
**training_settings,
)
print(f"Starting training on {torch.cuda.device_count()} GPUs" if "COLAB_TPU_ADDR" not in os.environ else "Starting TPU training")
training.do_training(trainer, args, args.output_dir)
print("Saving model")
trainer.save_model(str(args.output_dir))
if __name__ == "__main__":
main()