import os
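# Point the Hugging Face cache at local storage. These variables are set
# before `transformers`/`datasets` are imported so the custom cache location
# takes effect for all downloads.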
os.environ['TRANSFORMERS_CACHE'] = '/data/private/peilin/cache'
os.environ['HF_HOME'] = '/data/private/peilin/cache'
import sys
from typing import List

import fire
import torch
import transformers
from datasets import load_dataset

"""
Unused imports:
import torch.nn as nn
import bitsandbytes as bnb
"""

from peft import (
    LoraConfig,
    get_peft_model,
    get_peft_model_state_dict,
    prepare_model_for_int8_training,
    set_peft_model_state_dict,
)
from transformers import LlamaForCausalLM, LlamaTokenizer

from utils.prompter import Prompter
from utils.data_format_transform import filter_and_convert_rec


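# Fine-tune a LLaMA base model with LoRA adapters on Alpaca-style prompt data.
# Every argument of train() is exposed as a command-line flag via `fire`
# (see the entry point at the bottom of the file).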
def train(
    # model/data params
    base_model: str = "decapoda-research/llama-7b-hf",  # the only required argument
    data_path: str = "data/beauty_sequential_single_prompt_train_sample.json",
    valid_data_path: str = "data/beauty_sequential_single_prompt_test.json",
    output_dir: str = "./lora-alpaca",
    # training hyperparams
    batch_size: int = 32,
    micro_batch_size: int = 2,
    num_epochs: int = 10,
    learning_rate: float = 3e-4,
    cutoff_len: int = 256,
    val_set_size: int = 500,
    save_eval_steps: int = 10,
    # lora hyperparams
    lora_r: int = 8,
    lora_alpha: int = 16,
    lora_dropout: float = 0.05,
    lora_target_modules: List[str] = [
        "q_proj",
        "v_proj",
    ],
    # llm hyperparams
    train_on_inputs: bool = False,  # if False, masks out inputs in loss
    group_by_length: bool = False,  # faster, but produces an odd training loss curve
    # wandb params
    wandb_project: str = "llama_med",
    wandb_run_name: str = "",
    wandb_watch: str = "",  # options: false | gradients | all
    wandb_log_model: str = "",  # options: false | true
    resume_from_checkpoint: str = None,  # either training checkpoint or final adapter
    prompt_template_name: str = "alpaca",  # the prompt template to use; defaults to alpaca
):
    if int(os.environ.get("LOCAL_RANK", 0)) == 0:
        print(
            f"Training Alpaca-LoRA model with params:\n"
            f"base_model: {base_model}\n"
            f"data_path: {data_path}\n"
            f"valid_data_path: {valid_data_path}\n"
            f"output_dir: {output_dir}\n"
            f"batch_size: {batch_size}\n"
            f"micro_batch_size: {micro_batch_size}\n"
            f"num_epochs: {num_epochs}\n"
            f"learning_rate: {learning_rate}\n"
            f"cutoff_len: {cutoff_len}\n"
            f"val_set_size: {val_set_size}\n"
            f"save_eval_steps: {save_eval_steps}\n"
            f"lora_r: {lora_r}\n"
            f"lora_alpha: {lora_alpha}\n"
            f"lora_dropout: {lora_dropout}\n"
            f"lora_target_modules: {lora_target_modules}\n"
            f"train_on_inputs: {train_on_inputs}\n"
            f"group_by_length: {group_by_length}\n"
            f"wandb_project: {wandb_project}\n"
            f"wandb_run_name: {wandb_run_name}\n"
            f"wandb_watch: {wandb_watch}\n"
            f"wandb_log_model: {wandb_log_model}\n"
            f"resume_from_checkpoint: {resume_from_checkpoint or False}\n"
            f"prompt template: {prompt_template_name}\n"
        )
    assert (
        base_model
    ), "Please specify a --base_model, e.g. --base_model='decapoda-research/llama-7b-hf'"
    gradient_accumulation_steps = batch_size // micro_batch_size

    prompter = Prompter(prompt_template_name)

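    # Detect distributed training: launchers such as torchrun set WORLD_SIZE.
    # Under DDP each rank keeps the full model on its own GPU, and the
    # accumulation steps are divided across ranks so the global batch size
    # stays at batch_size.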
| 96 | + device_map = "auto" |
| 97 | + world_size = int(os.environ.get("WORLD_SIZE", 1)) |
| 98 | + ddp = world_size != 1 |
| 99 | + if ddp: |
| 100 | + device_map = {"": int(os.environ.get("LOCAL_RANK") or 0)} |
| 101 | + gradient_accumulation_steps = gradient_accumulation_steps // world_size |
| 102 | + |
| 103 | + # Check if parameter passed or if set within environ |
| 104 | + use_wandb = len(wandb_project) > 0 or ( |
| 105 | + "WANDB_PROJECT" in os.environ and len(os.environ["WANDB_PROJECT"]) > 0 |
| 106 | + ) |
| 107 | + # Only overwrite environ if wandb param passed |
| 108 | + if len(wandb_project) > 0: |
| 109 | + os.environ["WANDB_PROJECT"] = wandb_project |
| 110 | + if len(wandb_watch) > 0: |
| 111 | + os.environ["WANDB_WATCH"] = wandb_watch |
| 112 | + if len(wandb_log_model) > 0: |
| 113 | + os.environ["WANDB_LOG_MODEL"] = wandb_log_model |
| 114 | + |
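    # Load the base LLaMA weights in 8-bit (via bitsandbytes) so a 7B model
    # fits on a single GPU; only the LoRA adapters added below are trained.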
    model = LlamaForCausalLM.from_pretrained(
        base_model,
        load_in_8bit=True,
        torch_dtype=torch.float16,
        device_map=device_map,
    )

    tokenizer = LlamaTokenizer.from_pretrained(base_model)

    tokenizer.pad_token_id = (
        0  # unk. we want this to be different from the eos token
    )
    tokenizer.padding_side = "left"  # Allow batched inference

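    # Tokenize a prompt, truncating to cutoff_len and appending EOS when there
    # is room; labels start as a copy of input_ids.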
    def tokenize(prompt, add_eos_token=True):
        # there's probably a way to do this with the tokenizer settings
        # but again, gotta move fast
        result = tokenizer(
            prompt,
            truncation=True,
            max_length=cutoff_len,
            padding=False,
            return_tensors=None,
        )
        if (
            result["input_ids"][-1] != tokenizer.eos_token_id
            and len(result["input_ids"]) < cutoff_len
            and add_eos_token
        ):
            result["input_ids"].append(tokenizer.eos_token_id)
            result["attention_mask"].append(1)

        result["labels"] = result["input_ids"].copy()

        return result

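    # Build the full prompt (instruction + response) for one example. When
    # train_on_inputs is False, the prompt portion of the labels is replaced
    # with -100 so the loss is computed only on the response tokens.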
    def generate_and_tokenize_prompt(data_point):
        full_prompt = prompter.generate_prompt(
            data_point["source"],
            None,
            data_point["target"],
        )
        tokenized_full_prompt = tokenize(full_prompt)
        if not train_on_inputs:
            user_prompt = prompter.generate_prompt(
                data_point["source"], None
            )
            tokenized_user_prompt = tokenize(user_prompt, add_eos_token=False)
            user_prompt_len = len(tokenized_user_prompt["input_ids"])

            tokenized_full_prompt["labels"] = (
                [-100] * user_prompt_len
                + tokenized_full_prompt["labels"][user_prompt_len:]
            )  # could be sped up, probably
        return tokenized_full_prompt

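    # prepare_model_for_int8_training freezes the base model weights and casts
    # a few layers (layer norms, LM head output) to fp32 for stable training
    # on top of the 8-bit weights.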
    model = prepare_model_for_int8_training(model)

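    # LoRA configuration: rank-r update matrices are injected into the
    # attention projections listed in lora_target_modules; only these adapter
    # weights are trainable.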
    config = LoraConfig(
        r=lora_r,
        lora_alpha=lora_alpha,
        target_modules=lora_target_modules,
        lora_dropout=lora_dropout,
        bias="none",
        task_type="CAUSAL_LM",
    )
    model = get_peft_model(model, config)

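    # Load the training set. JSON/JSONL files go through the `json` builder;
    # any other path is treated as a dataset name for `load_dataset`.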
    if data_path.endswith(".json") or data_path.endswith(".jsonl"):
        data = load_dataset("json", data_files=data_path)
    else:
        data = load_dataset(data_path)

    if os.path.exists(valid_data_path):
        if valid_data_path.endswith(".json") or valid_data_path.endswith(".jsonl"):
            valid_data = load_dataset("json", data_files=valid_data_path)
        else:
            valid_data = load_dataset(valid_data_path)
    else:
        valid_data = None

    if resume_from_checkpoint:
        # Check the available weights and load them
        checkpoint_name = os.path.join(
            resume_from_checkpoint, "pytorch_model.bin"
        )  # Full checkpoint
        if not os.path.exists(checkpoint_name):
            checkpoint_name = os.path.join(
                resume_from_checkpoint, "adapter_model.bin"
            )  # only LoRA model - LoRA config above has to fit
            resume_from_checkpoint = (
                False  # So the trainer won't try loading its state
            )
        # The two files above have a different name depending on how they were saved, but are actually the same.
        if os.path.exists(checkpoint_name):
            print(f"Restarting from {checkpoint_name}")
            adapters_weights = torch.load(checkpoint_name)
            set_peft_model_state_dict(model, adapters_weights)  # loads the adapter weights in place
        else:
            print(f"Checkpoint {checkpoint_name} not found")

    model.print_trainable_parameters()  # Be more transparent about the % of trainable params.

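    # Three cases for the evaluation split:
    #  - no separate validation file: hold out val_set_size examples from the training data
    #  - a validation file exists: sample val_set_size examples from it for evaluation
    #  - val_set_size == 0: train on everything and skip evaluation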
    if val_set_size > 0 and not valid_data:
        train_val = data["train"].train_test_split(
            test_size=val_set_size, shuffle=True, seed=2023
        )
        train_data = (
            train_val["train"].shuffle().map(generate_and_tokenize_prompt)
        )
        val_data = (
            train_val["test"].shuffle().map(generate_and_tokenize_prompt)
        )
    elif val_set_size > 0 and valid_data:
        train_data = (
            data["train"].shuffle(seed=2023).map(generate_and_tokenize_prompt)
        )
        val_sample = valid_data["train"].train_test_split(
            test_size=val_set_size, shuffle=True, seed=2023
        )
        val_data = (
            val_sample["test"].shuffle().map(generate_and_tokenize_prompt)
        )
    else:
        train_data = data["train"].shuffle().map(generate_and_tokenize_prompt)
        val_data = None

    if not ddp and torch.cuda.device_count() > 1:
        # keeps Trainer from trying its own DataParallelism when more than 1 gpu is available
        model.is_parallelizable = True
        model.model_parallel = True

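    # Standard Hugging Face Trainer. Evaluation and checkpointing both run
    # every save_eval_steps steps, and the best checkpoint (by eval loss) is
    # reloaded at the end when a validation set is used. The seq2seq collator
    # dynamically pads input_ids, attention_mask and labels per batch.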
    trainer = transformers.Trainer(
        model=model,
        train_dataset=train_data,
        eval_dataset=val_data,
        args=transformers.TrainingArguments(
            per_device_train_batch_size=micro_batch_size,
            gradient_accumulation_steps=gradient_accumulation_steps,
            warmup_ratio=0.1,
            num_train_epochs=num_epochs,
            learning_rate=learning_rate,
            fp16=True,
            logging_steps=8,
            optim="adamw_torch",
            evaluation_strategy="steps" if val_set_size > 0 else "no",
            save_strategy="steps",
            eval_steps=save_eval_steps if val_set_size > 0 else None,
            save_steps=save_eval_steps,
            output_dir=output_dir,
            save_total_limit=1000,
            load_best_model_at_end=True if val_set_size > 0 else False,
            ddp_find_unused_parameters=False if ddp else None,
            group_by_length=group_by_length,
            report_to="wandb" if use_wandb else None,
            run_name=wandb_run_name if use_wandb else None,
        ),
        data_collator=transformers.DataCollatorForSeq2Seq(
            tokenizer, pad_to_multiple_of=8, return_tensors="pt", padding=True
        ),
    )
    model.config.use_cache = False

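    # Monkey-patch state_dict so that saved checkpoints contain only the LoRA
    # adapter weights rather than the full base model.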
    old_state_dict = model.state_dict
    model.state_dict = (
        lambda self, *_, **__: get_peft_model_state_dict(
            self, old_state_dict()
        )
    ).__get__(model, type(model))

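    # On PyTorch 2.x (not supported on Windows), compile the model for a
    # potential training speed-up.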
    if torch.__version__ >= "2" and sys.platform != "win32":
        model = torch.compile(model)

    trainer.train(resume_from_checkpoint=resume_from_checkpoint)

    model.save_pretrained(output_dir)

    print(
        "\n If there's a warning about missing keys above, please disregard :)"
    )


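# Example invocation (flag names mirror the train() arguments; the script
# filename below is illustrative):
#   python finetune.py \
#       --base_model 'decapoda-research/llama-7b-hf' \
#       --data_path 'data/beauty_sequential_single_prompt_train_sample.json' \
#       --output_dir './lora-alpaca'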
| 298 | +if __name__ == "__main__": |
| 299 | + fire.Fire(train) |