25
25
{'name' : 'copy_tasks_en' , 'params' : {'dataset_names' : 'darumeru/cp_sent_en darumeru/cp_para_en' , 'allow_vllm' : False }}
26
26
]
27
27
28
def _task_group(name, *dataset_names, **extra_params):
    """Build one task-group entry in the shape used by the tables below.

    Args:
        name: Label of the group.
        *dataset_names: One or more dataset identifiers; they are stored as a
            single space-separated string under ``params['dataset_names']``.
        **extra_params: Additional ``params`` entries appended after the
            defaults (for example ``max_sample_per_dataset``).

    Returns:
        A dict with the keys ``'name'`` and ``'params'``. ``params`` always
        starts with ``'dataset_names'`` and ``'allow_vllm'`` (set to False),
        followed by ``extra_params`` in the order given.
    """
    params = {'dataset_names': ' '.join(dataset_names), 'allow_vllm': False}
    params.update(extra_params)
    return {'name': name, 'params': params}


# Coarse grouping: several datasets share one group. It is kept under its own
# name so it stays available without rebinding ``task_groups_zero_shot``.
task_groups_zero_shot_grouped = [
    _task_group(
        'darumeru_rest',
        'darumeru/multiq',
        'darumeru/use',
        'darumeru/parus',
        'darumeru/rcb',
        'darumeru/ruopenbookqa',
        'darumeru/ruworldtree',
        'darumeru/rwsd',
        'russiannlp/rucola_custom',
    ),
    _task_group(
        'darumeru_mmlu_ru_extractive',
        'darumeru/rummlu',
        'daru/treewayextractive',
    ),
    _task_group(
        'nlpcoreteam_mmlu',
        'nlpcoreteam/rummlu',
        'nlpcoreteam/enmmlu',
    ),
    _task_group(
        'treewayabstractive',
        'daru/treewayabstractive',
        max_sample_per_dataset=500,
    ),
    _task_group(
        'copy_tasks',
        'darumeru/cp_sent_ru',
        'darumeru/cp_para_ru',
        'darumeru/cp_sent_en',
        'darumeru/cp_para_en',
    ),
]

# Effective zero-shot table: one group per dataset.
task_groups_zero_shot = [
    _task_group('multiq', 'darumeru/multiq'),
    _task_group('use', 'darumeru/use'),
    _task_group('parus', 'darumeru/parus'),
    _task_group('rcb', 'darumeru/rcb'),
    _task_group('ruopenbookqa', 'darumeru/ruopenbookqa'),
    _task_group('ruworldtree', 'darumeru/ruworldtree'),
    _task_group('rwsd', 'darumeru/rwsd'),
    _task_group('rucola_custom', 'russiannlp/rucola_custom'),
    _task_group('darumeru_rummlu', 'darumeru/rummlu'),
    _task_group('daru_extractive', 'daru/treewayextractive'),
    _task_group('nlpcoreteam_rummlu', 'nlpcoreteam/rummlu'),
    _task_group('nlpcoreteam_enmmlu', 'nlpcoreteam/enmmlu'),
    _task_group(
        'treewayabstractive',
        'daru/treewayabstractive',
        max_sample_per_dataset=500,
    ),
    _task_group('cp_sent_ru', 'darumeru/cp_sent_ru'),
    _task_group('cp_para_ru', 'darumeru/cp_para_ru'),
    _task_group('cp_sent_en', 'darumeru/cp_sent_en'),
    _task_group('cp_para_en', 'darumeru/cp_para_en'),
]

# Reduced zero-shot table: a subset of the per-dataset groups above, with
# sample caps on the two treeway datasets.
task_groups_zero_shot_short_ver = [
    _task_group('multiq', 'darumeru/multiq'),
    _task_group('parus', 'darumeru/parus'),
    _task_group('rcb', 'darumeru/rcb'),
    _task_group('ruopenbookqa', 'darumeru/ruopenbookqa'),
    _task_group('ruworldtree', 'darumeru/ruworldtree'),
    _task_group('rwsd', 'darumeru/rwsd'),
    _task_group(
        'daru_extractive',
        'daru/treewayextractive',
        max_sample_per_dataset=1000,
    ),
    _task_group('nlpcoreteam_rummlu', 'nlpcoreteam/rummlu'),
    _task_group('nlpcoreteam_enmmlu', 'nlpcoreteam/enmmlu'),
    _task_group(
        'treewayabstractive',
        'daru/treewayabstractive',
        max_sample_per_dataset=200,
    ),
    _task_group('cp_para_ru', 'darumeru/cp_para_ru'),
]
80
+
28
81
task_groups = None
29
82
def get_current_groups (rank , total_workers ):
30
83
current_idx = [i for i in range (rank , len (task_groups ), total_workers )]
@@ -45,6 +98,10 @@ def run_eval(args, group, local_rank):
45
98
command += ['--device_map' , f'cuda:{ 0 } ' , '--output_dir' , output_dir ]
46
99
if args .force_recalc :
47
100
command += ['--force_recalc' ]
101
+
102
+ command += ['--alpha_scale' , str (args .alpha_scale )]
103
+ if args .not_scale_lm_head :
104
+ command += ['--not_scale_lm_head' ]
48
105
49
106
env = os .environ .copy ()
50
107
env ['CUDA_VISIBLE_DEVICES' ] = str (local_rank )
@@ -74,6 +131,10 @@ def func(command, env):
74
131
parser .add_argument ('--few_shot_count' , default = 0 , type = int )
75
132
parser .add_argument ('--vllm' , action = 'store_true' )
76
133
parser .add_argument ('--force_recalc' , action = 'store_true' )
134
+ parser .add_argument ('--alpha_scale' , type = float , default = 1.0 )
135
+ parser .add_argument ('--not_scale_lm_head' , action = 'store_true' )
136
+ parser .add_argument ('--short' , action = 'store_true' )
137
+
77
138
args = parser .parse_args ()
78
139
79
140
local_rank = int (os .environ ['LOCAL_RANK' ])
@@ -85,7 +146,10 @@ def func(command, env):
85
146
if int (args .few_shot_count ) > 0 :
86
147
task_groups = task_groups_few_shot
87
148
else :
88
- task_groups = task_groups_zero_shot
149
+ if args .short :
150
+ task_groups = task_groups_zero_shot_short_ver
151
+ else :
152
+ task_groups = task_groups_zero_shot
89
153
90
154
for group in get_current_groups (rank , workers ):
91
155
print (f'RANK { rank } starting { group } ' )
0 commit comments