Skip to content

Commit f325b2e

Browse files
author
RefalMachine
committed
short ver
1 parent 8d74c4d commit f325b2e

File tree

3 files changed

+97
-4
lines changed

3 files changed

+97
-4
lines changed

run_evaluate_multinode_multigpu.py

+65-1
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,59 @@
2525
{'name': 'copy_tasks_en', 'params': {'dataset_names': 'darumeru/cp_sent_en darumeru/cp_para_en', 'allow_vllm': False}}
2626
]
2727

28+
'''
29+
task_groups_zero_shot = [
30+
{'name': 'darumeru_rest', 'params': {'dataset_names': 'darumeru/parus darumeru/rcb darumeru/ruopenbookqa darumeru/ruworldtree darumeru/rwsd russiannlp/rucola_custom', 'allow_vllm': False}},
31+
{'name': 'darumeru_mmlu_ru_extractive', 'params': {'dataset_names': 'darumeru/rummlu daru/treewayextractive', 'allow_vllm': False}},
32+
{'name': 'nlpcoreteam_mmlu', 'params': {'dataset_names': 'nlpcoreteam/rummlu nlpcoreteam/enmmlu', 'allow_vllm': False}},
33+
{'name': 'treewayabstractive', 'params': {'dataset_names': 'daru/treewayabstractive', 'allow_vllm': False, 'max_sample_per_dataset': 500}},
34+
{'name': 'darumeru_multiq_use', 'params': {'dataset_names': 'darumeru/multiq darumeru/use', 'allow_vllm': False}},
35+
{'name': 'copy_tasks', 'params': {'dataset_names': 'darumeru/cp_sent_ru darumeru/cp_para_ru darumeru/cp_sent_en darumeru/cp_para_en', 'allow_vllm': False}}
36+
]
37+
'''
38+
39+
task_groups_zero_shot = [
40+
{'name': 'darumeru_rest', 'params': {'dataset_names': 'darumeru/multiq darumeru/use darumeru/parus darumeru/rcb darumeru/ruopenbookqa darumeru/ruworldtree darumeru/rwsd russiannlp/rucola_custom', 'allow_vllm': False}},
41+
{'name': 'darumeru_mmlu_ru_extractive', 'params': {'dataset_names': 'darumeru/rummlu daru/treewayextractive', 'allow_vllm': False}},
42+
{'name': 'nlpcoreteam_mmlu', 'params': {'dataset_names': 'nlpcoreteam/rummlu nlpcoreteam/enmmlu', 'allow_vllm': False}},
43+
{'name': 'treewayabstractive', 'params': {'dataset_names': 'daru/treewayabstractive', 'allow_vllm': False, 'max_sample_per_dataset': 500}},
44+
{'name': 'copy_tasks', 'params': {'dataset_names': 'darumeru/cp_sent_ru darumeru/cp_para_ru darumeru/cp_sent_en darumeru/cp_para_en', 'allow_vllm': False}}
45+
]
46+
47+
task_groups_zero_shot = [
48+
{'name': 'multiq', 'params': {'dataset_names': 'darumeru/multiq', 'allow_vllm': False}},
49+
{'name': 'use', 'params': {'dataset_names': 'darumeru/use', 'allow_vllm': False}},
50+
{'name': 'parus', 'params': {'dataset_names': 'darumeru/parus', 'allow_vllm': False}},
51+
{'name': 'rcb', 'params': {'dataset_names': 'darumeru/rcb', 'allow_vllm': False}},
52+
{'name': 'ruopenbookqa', 'params': {'dataset_names': 'darumeru/ruopenbookqa', 'allow_vllm': False}},
53+
{'name': 'ruworldtree', 'params': {'dataset_names': 'darumeru/ruworldtree', 'allow_vllm': False}},
54+
{'name': 'rwsd', 'params': {'dataset_names': 'darumeru/rwsd', 'allow_vllm': False}},
55+
{'name': 'rucola_custom', 'params': {'dataset_names': 'russiannlp/rucola_custom', 'allow_vllm': False}},
56+
{'name': 'darumeru_rummlu', 'params': {'dataset_names': 'darumeru/rummlu', 'allow_vllm': False}},
57+
{'name': 'daru_extractive', 'params': {'dataset_names': 'daru/treewayextractive', 'allow_vllm': False}},
58+
{'name': 'nlpcoreteam_rummlu', 'params': {'dataset_names': 'nlpcoreteam/rummlu', 'allow_vllm': False}},
59+
{'name': 'nlpcoreteam_enmmlu', 'params': {'dataset_names': 'nlpcoreteam/enmmlu', 'allow_vllm': False}},
60+
{'name': 'treewayabstractive', 'params': {'dataset_names': 'daru/treewayabstractive', 'allow_vllm': False, 'max_sample_per_dataset': 500}},
61+
{'name': 'cp_sent_ru', 'params': {'dataset_names': 'darumeru/cp_sent_ru', 'allow_vllm': False}},
62+
{'name': 'cp_para_ru', 'params': {'dataset_names': 'darumeru/cp_para_ru', 'allow_vllm': False}},
63+
{'name': 'cp_sent_en', 'params': {'dataset_names': 'darumeru/cp_sent_en', 'allow_vllm': False}},
64+
{'name': 'cp_para_en', 'params': {'dataset_names': 'darumeru/cp_para_en', 'allow_vllm': False}}
65+
]
66+
67+
task_groups_zero_shot_short_ver = [
68+
{'name': 'multiq', 'params': {'dataset_names': 'darumeru/multiq', 'allow_vllm': False}},
69+
{'name': 'parus', 'params': {'dataset_names': 'darumeru/parus', 'allow_vllm': False}},
70+
{'name': 'rcb', 'params': {'dataset_names': 'darumeru/rcb', 'allow_vllm': False}},
71+
{'name': 'ruopenbookqa', 'params': {'dataset_names': 'darumeru/ruopenbookqa', 'allow_vllm': False}},
72+
{'name': 'ruworldtree', 'params': {'dataset_names': 'darumeru/ruworldtree', 'allow_vllm': False}},
73+
{'name': 'rwsd', 'params': {'dataset_names': 'darumeru/rwsd', 'allow_vllm': False}},
74+
{'name': 'daru_extractive', 'params': {'dataset_names': 'daru/treewayextractive', 'allow_vllm': False, 'max_sample_per_dataset': 1000}},
75+
{'name': 'nlpcoreteam_rummlu', 'params': {'dataset_names': 'nlpcoreteam/rummlu', 'allow_vllm': False}},
76+
{'name': 'nlpcoreteam_enmmlu', 'params': {'dataset_names': 'nlpcoreteam/enmmlu', 'allow_vllm': False}},
77+
{'name': 'treewayabstractive', 'params': {'dataset_names': 'daru/treewayabstractive', 'allow_vllm': False, 'max_sample_per_dataset': 200}},
78+
{'name': 'cp_para_ru', 'params': {'dataset_names': 'darumeru/cp_para_ru', 'allow_vllm': False}},
79+
]
80+
2881
task_groups = None
2982
def get_current_groups(rank, total_workers):
3083
current_idx = [i for i in range(rank, len(task_groups), total_workers)]
@@ -45,6 +98,10 @@ def run_eval(args, group, local_rank):
4598
command += ['--device_map', f'cuda:{0}', '--output_dir', output_dir]
4699
if args.force_recalc:
47100
command += ['--force_recalc']
101+
102+
command += ['--alpha_scale', str(args.alpha_scale)]
103+
if args.not_scale_lm_head:
104+
command += ['--not_scale_lm_head']
48105

49106
env = os.environ.copy()
50107
env['CUDA_VISIBLE_DEVICES'] = str(local_rank)
@@ -74,6 +131,10 @@ def func(command, env):
74131
parser.add_argument('--few_shot_count', default=0, type=int)
75132
parser.add_argument('--vllm', action='store_true')
76133
parser.add_argument('--force_recalc', action='store_true')
134+
parser.add_argument('--alpha_scale', type=float, default=1.0)
135+
parser.add_argument('--not_scale_lm_head', action='store_true')
136+
parser.add_argument('--short', action='store_true')
137+
77138
args = parser.parse_args()
78139

79140
local_rank = int(os.environ['LOCAL_RANK'])
@@ -85,7 +146,10 @@ def func(command, env):
85146
if int(args.few_shot_count) > 0:
86147
task_groups = task_groups_few_shot
87148
else:
88-
task_groups = task_groups_zero_shot
149+
if args.short:
150+
task_groups = task_groups_zero_shot_short_ver
151+
else:
152+
task_groups = task_groups_zero_shot
89153

90154
for group in get_current_groups(rank, workers):
91155
print(f'RANK {rank} starting {group}')

run_evaluate_singlenode_multigpu.sh

+28-3
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,34 @@ pip install pytest==8.0.0
77

88
echo $GPUS_PER_NODE
99

10-
torchrun --nnodes=1 --nproc-per-node=$GPUS_PER_NODE run_evaluate_multinode_multigpu.py \
11-
--model_dir ../../data/models/llama3_cluster/ruadapt_llama3_bpe_extended_part1-2_vo_1e4_bs256 \
12-
--conv_path conversation_configs/non_instruct_simple.json \
10+
torchrun --nnodes=1 --nproc-per-node=6 run_evaluate_multinode_multigpu.py \
11+
--model_dir ../../data/models/saiga_scored_d7_mistral \
12+
--conv_path conversation_configs/openchat_3.5_1210.json \
13+
--output_dir ../../data/models/saiga_scored_d7_mistral/llmtf_eval_k5 \
14+
--batch_size 1 \
15+
--max_len 4000 \
16+
--few_shot_count 5
17+
18+
torchrun --nnodes=1 --nproc-per-node=6 run_evaluate_multinode_multigpu.py \
19+
--model_dir ../../data/models/saiga_scored_d7_mistral_extended_darulm_20_05_24_part1-2_32000_bpe_full_lr1e4_bs256 \
20+
--conv_path conversation_configs/openchat_3.5_1210.json \
21+
--output_dir ../../data/models/saiga_scored_d7_mistral_extended_darulm_20_05_24_part1-2_32000_bpe_full_lr1e4_bs256/llmtf_eval_k5 \
22+
--batch_size 1 \
23+
--max_len 4000 \
24+
--few_shot_count 5
25+
26+
torchrun --nnodes=1 --nproc-per-node=6 run_evaluate_multinode_multigpu.py \
27+
--model_dir ../../data/models/saiga_scored_d7_mistral_darulm_20_05_24_part1-2_32000_unigram_full_lr1e4_bs256 \
28+
--conv_path conversation_configs/openchat_3.5_1210.json \
29+
--output_dir ../../data/models/saiga_scored_d7_mistral_darulm_20_05_24_part1-2_32000_unigram_full_lr1e4_bs256/llmtf_eval_k5 \
30+
--batch_size 1 \
31+
--max_len 4000 \
32+
--few_shot_count 5
33+
34+
torchrun --nnodes=1 --nproc-per-node=6 run_evaluate_multinode_multigpu.py \
35+
--model_dir ../../data/models/saiga_scored_d7_mistral_darulm_20_05_24_part1-2_32000_bpe_full_lr1e4_bs256 \
36+
--conv_path conversation_configs/openchat_3.5_1210.json \
37+
--output_dir ../../data/models/saiga_scored_d7_mistral_darulm_20_05_24_part1-2_32000_bpe_full_lr1e4_bs256/llmtf_eval_k5 \
1338
--batch_size 1 \
1439
--max_len 4000 \
1540
--few_shot_count 5

todo.txt

+4
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,9 @@
11
#TODO: обработка few-shot не с 0 до K, а сразу с K, а затем уменьшать, если не влезает. (должно существенно ускорить подготовку данных)
22

3+
#TODO: stop strings truncation.
4+
5+
#TODO: fix \n at start of generation
6+
37
#TODO: ленивая инициализация модели.
48

59
#TODO: task gen config context

0 commit comments

Comments
 (0)