-
Notifications
You must be signed in to change notification settings - Fork 3
/
Copy pathprocess.py
55 lines (43 loc) · 2.06 KB
/
process.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
import re
from typing import List, Union
from dragon_baseline import DragonBaseline
class DragonSubmission(DragonBaseline):
    """DRAGON baseline variant that swaps in a different pretrained model.

    Doubles as a template showing where a custom text-cleaning step can be
    hooked into preprocessing (the hook is disabled by default, see
    ``preprocess``).
    """

    def __init__(self, **kwargs):
        """
        Initialise the baseline, then override its model and training settings.

        The model used is joeranbosma/dragon-roberta-base-mixed-domain. All
        keyword arguments are forwarded unchanged to ``DragonBaseline``.

        Note: when changing the model, update the Dockerfile to pre-download that model.
        """
        super().__init__(**kwargs)

        # Which pretrained checkpoint to fine-tune.
        self.model_name = "joeranbosma/dragon-roberta-base-mixed-domain"

        # Training hyperparameters overriding the values set by the parent class.
        self.per_device_train_batch_size = 4
        self.gradient_accumulation_steps = 2
        self.gradient_checkpointing = False
        self.max_seq_length = 512
        self.learning_rate = 1e-05
        self.num_train_epochs = 5

    def custom_text_cleaning(self, text: Union[str, List[str]]) -> Union[str, List[str]]:
        """
        Strip HTML tags and URLs from a string, or from each string in a list.

        Args:
            text (Union[str, List[str]]): A single string, or a list of strings
                (each element is cleaned recursively).

        Returns:
            Union[str, List[str]]: The cleaned string for string input, or a
                new list of cleaned strings for list input.
        """
        if not isinstance(text, str):
            # Anything that is not a plain string is treated as a collection
            # of texts and cleaned element by element.
            return [self.custom_text_cleaning(item) for item in text]

        # Applied in order: first HTML tags, then URLs.
        for pattern in (r"<.*?>", r"http\S+"):
            text = re.sub(pattern, "", text)
        return text

    def preprocess(self):
        """Run the parent's preprocessing; extension point for custom cleaning."""
        # Example of how to adapt the DRAGON baseline to use a different preprocessing function
        super().preprocess()

        # To apply custom_text_cleaning to every dataset split, uncomment:
        # for df in [self.df_train, self.df_val, self.df_test]:
        #     df[self.task.input_name] = df[self.task.input_name].map(self.custom_text_cleaning)
if __name__ == "__main__":
    # Script entry point: build the submission with default settings and run
    # its full processing pipeline.
    submission = DragonSubmission()
    submission.process()