Commit 8bec26e

Add files via upload
Authored Mar 28, 2023 · 1 parent 9043086 · commit 8bec26e

4 files changed: 288 additions, 22 deletions
 

‎export_hf_checkpoint.py

9 deletions

@@ -3,15 +3,6 @@
 import torch
 import transformers
 from peft import PeftModel
-
-# Unused imports
-# import json
-# from peft import LoraConfig
-
-assert (
-    "LlamaTokenizer" in transformers._import_structure["models.llama"]
-), "LLaMA is now in HuggingFace's main branch.\nPlease reinstall it: pip uninstall transformers && pip install git+https://github.com/huggingface/transformers.git"  # noqa: E501
-
 from transformers import LlamaForCausalLM, LlamaTokenizer  # noqa: F402
 
 BASE_MODEL = os.environ.get("BASE_MODEL", None)
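
The removed assert only verified that the installed transformers build already exposes the LLaMA classes; once the plain import above succeeds there is nothing left for it to check, so the commit drops it (the same guard disappears from export_state_dict_checkpoint.py and generate.py below). If a friendlier failure on older installs is still wanted, a minimal sketch, not part of this commit, could look like this:

# Hypothetical replacement guard: fail with a clear message when the installed
# transformers build predates native LLaMA support.
try:
    from transformers import LlamaForCausalLM, LlamaTokenizer  # noqa: F401
except ImportError as err:
    raise SystemExit(
        "This transformers install does not ship the LLaMA classes; "
        "please upgrade to a build that does."
    ) from err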

‎export_state_dict_checkpoint.py

8 deletions

@@ -3,15 +3,7 @@
 
 import torch
 import transformers
-
-# Unused imports
-# from peft import LoraConfig
 from peft import PeftModel
-
-assert (
-    "LlamaTokenizer" in transformers._import_structure["models.llama"]
-), "LLaMA is now in HuggingFace's main branch.\nPlease reinstall it: pip uninstall transformers && pip install git+https://github.com/huggingface/transformers.git"  # noqa: E501
-
 from transformers import LlamaForCausalLM, LlamaTokenizer  # noqa: E402
 
 BASE_MODEL = os.environ.get("BASE_MODEL", None)

‎finetune_alpaca.py

286 additions (new file)

@@ -0,0 +1,286 @@
import os
import sys
from typing import List

import fire
import torch
import transformers
from datasets import load_dataset

"""
Unused imports:
import torch.nn as nn
import bitsandbytes as bnb
"""

from peft import (  # noqa: E402
    LoraConfig,
    get_peft_model,
    get_peft_model_state_dict,
    prepare_model_for_int8_training,
    set_peft_model_state_dict,
)
from transformers import LlamaForCausalLM, LlamaTokenizer  # noqa: F402

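
# Fine-tunes an 8-bit LLaMA base model with LoRA adapters (q_proj/v_proj by
# default) on an Alpaca-style instruction dataset; every argument below doubles
# as a command-line flag via fire.Fire(train) at the bottom of the file.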
def train(
    # model/data params
    base_model: str = "",  # the only required argument
    data_path: str = "./alpaca_data_cleaned.json",
    output_dir: str = "./lora-alpaca",
    # training hyperparams
    batch_size: int = 128,
    micro_batch_size: int = 4,
    num_epochs: int = 3,
    learning_rate: float = 3e-4,
    cutoff_len: int = 256,
    val_set_size: int = 2000,
    # lora hyperparams
    lora_r: int = 8,
    lora_alpha: int = 16,
    lora_dropout: float = 0.05,
    lora_target_modules: List[str] = [
        "q_proj",
        "v_proj",
    ],
    # llm hyperparams
    train_on_inputs: bool = True,  # if False, masks out inputs in loss
    group_by_length: bool = False,  # faster, but produces an odd training loss curve
    # wandb params
    wandb_project: str = "",
    wandb_run_name: str = "",
    wandb_watch: str = "",  # options: false | gradients | all
    wandb_log_model: str = "",  # options: false | true
    resume_from_checkpoint: str = None,  # either training checkpoint or final adapter
):
    print(
        f"Training Alpaca-LoRA model with params:\n"
        f"base_model: {base_model}\n"
        f"data_path: {data_path}\n"
        f"output_dir: {output_dir}\n"
        f"batch_size: {batch_size}\n"
        f"micro_batch_size: {micro_batch_size}\n"
        f"num_epochs: {num_epochs}\n"
        f"learning_rate: {learning_rate}\n"
        f"cutoff_len: {cutoff_len}\n"
        f"val_set_size: {val_set_size}\n"
        f"lora_r: {lora_r}\n"
        f"lora_alpha: {lora_alpha}\n"
        f"lora_dropout: {lora_dropout}\n"
        f"lora_target_modules: {lora_target_modules}\n"
        f"train_on_inputs: {train_on_inputs}\n"
        f"group_by_length: {group_by_length}\n"
        f"wandb_project: {wandb_project}\n"
        f"wandb_run_name: {wandb_run_name}\n"
        f"wandb_watch: {wandb_watch}\n"
        f"wandb_log_model: {wandb_log_model}\n"
        f"resume_from_checkpoint: {resume_from_checkpoint}\n"
    )
    assert (
        base_model
    ), "Please specify a --base_model, e.g. --base_model='decapoda-research/llama-7b-hf'"
    gradient_accumulation_steps = batch_size // micro_batch_size

    device_map = "auto"
    world_size = int(os.environ.get("WORLD_SIZE", 1))
    ddp = world_size != 1
    if ddp:
        device_map = {"": int(os.environ.get("LOCAL_RANK") or 0)}
        gradient_accumulation_steps = gradient_accumulation_steps // world_size

    # Check if parameter passed or if set within environ
    use_wandb = len(wandb_project) > 0 or (
        "WANDB_PROJECT" in os.environ and len(os.environ["WANDB_PROJECT"]) > 0
    )
    # Only overwrite environ if wandb param passed
    if len(wandb_project) > 0:
        os.environ["WANDB_PROJECT"] = wandb_project
    if len(wandb_watch) > 0:
        os.environ["WANDB_WATCH"] = wandb_watch
    if len(wandb_log_model) > 0:
        os.environ["WANDB_LOG_MODEL"] = wandb_log_model

    model = LlamaForCausalLM.from_pretrained(
        base_model,
        load_in_8bit=True,
        torch_dtype=torch.float16,
        device_map=device_map,
    )

    tokenizer = LlamaTokenizer.from_pretrained(base_model)

    tokenizer.pad_token_id = (
        0  # unk. we want this to be different from the eos token
    )
    tokenizer.padding_side = "left"  # Allow batched inference

    def tokenize(prompt, add_eos_token=True):
        # there's probably a way to do this with the tokenizer settings
        # but again, gotta move fast
        result = tokenizer(
            prompt,
            truncation=True,
            max_length=cutoff_len,
            padding=False,
            return_tensors=None,
        )
        if (
            result["input_ids"][-1] != tokenizer.eos_token_id
            and len(result["input_ids"]) < cutoff_len
            and add_eos_token
        ):
            result["input_ids"].append(tokenizer.eos_token_id)
            result["attention_mask"].append(1)

        result["labels"] = result["input_ids"].copy()

        return result

    def generate_and_tokenize_prompt(data_point):
        full_prompt = generate_prompt(data_point)
        tokenized_full_prompt = tokenize(full_prompt)
        if not train_on_inputs:
            user_prompt = generate_prompt({**data_point, "output": ""})
            tokenized_user_prompt = tokenize(user_prompt, add_eos_token=False)
            user_prompt_len = len(tokenized_user_prompt["input_ids"])

            tokenized_full_prompt["labels"] = [
                -100
            ] * user_prompt_len + tokenized_full_prompt["labels"][
                user_prompt_len:
            ]  # could be sped up, probably
        return tokenized_full_prompt

    model = prepare_model_for_int8_training(model)

    config = LoraConfig(
        r=lora_r,
        lora_alpha=lora_alpha,
        target_modules=lora_target_modules,
        lora_dropout=lora_dropout,
        bias="none",
        task_type="CAUSAL_LM",
    )
    model = get_peft_model(model, config)

    if data_path.endswith(".json"):  # todo: support jsonl
        data = load_dataset("json", data_files=data_path)
    else:
        data = load_dataset(data_path)

    if resume_from_checkpoint:
        # Check the available weights and load them
        checkpoint_name = os.path.join(
            resume_from_checkpoint, "pytorch_model.bin"
        )  # Full checkpoint
        if not os.path.exists(checkpoint_name):
            checkpoint_name = os.path.join(
                resume_from_checkpoint, "adapter_model.bin"
            )  # only LoRA model - LoRA config above has to fit
            resume_from_checkpoint = (
                False  # So the trainer won't try loading its state
            )
        # The two files above have a different name depending on how they were saved, but are actually the same.
        if os.path.exists(checkpoint_name):
            print(f"Restarting from {checkpoint_name}")
            adapters_weights = torch.load(checkpoint_name)
            model = set_peft_model_state_dict(model, adapters_weights)
        else:
            print(f"Checkpoint {checkpoint_name} not found")

    model.print_trainable_parameters()  # Be more transparent about the % of trainable params.

    if val_set_size > 0:
        train_val = data["train"].train_test_split(
            test_size=val_set_size, shuffle=True, seed=42
        )
        train_data = (
            train_val["train"].shuffle().map(generate_and_tokenize_prompt)
        )
        val_data = (
            train_val["test"].shuffle().map(generate_and_tokenize_prompt)
        )
    else:
        train_data = data["train"].shuffle().map(generate_and_tokenize_prompt)
        val_data = None

    if not ddp and torch.cuda.device_count() > 1:
        # keeps Trainer from trying its own DataParallelism when more than 1 gpu is available
        model.is_parallelizable = True
        model.model_parallel = True

    trainer = transformers.Trainer(
        model=model,
        train_dataset=train_data,
        eval_dataset=val_data,
        args=transformers.TrainingArguments(
            per_device_train_batch_size=micro_batch_size,
            gradient_accumulation_steps=gradient_accumulation_steps,
            warmup_steps=100,
            num_train_epochs=num_epochs,
            learning_rate=learning_rate,
            fp16=True,
            logging_steps=10,
            optim="adamw_torch",
            evaluation_strategy="steps" if val_set_size > 0 else "no",
            save_strategy="steps",
            eval_steps=200 if val_set_size > 0 else None,
            save_steps=200,
            output_dir=output_dir,
            save_total_limit=3,
            load_best_model_at_end=True if val_set_size > 0 else False,
            ddp_find_unused_parameters=False if ddp else None,
            group_by_length=group_by_length,
            report_to="wandb" if use_wandb else None,
            run_name=wandb_run_name if use_wandb else None,
        ),
        data_collator=transformers.DataCollatorForSeq2Seq(
            tokenizer, pad_to_multiple_of=8, return_tensors="pt", padding=True
        ),
    )
    model.config.use_cache = False

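    # Swap in a state_dict that returns only the LoRA adapter weights, so
    # checkpoints saved by the Trainer and save_pretrained() stay small instead
    # of duplicating the frozen base model.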
    old_state_dict = model.state_dict
    model.state_dict = (
        lambda self, *_, **__: get_peft_model_state_dict(
            self, old_state_dict()
        )
    ).__get__(model, type(model))

    if torch.__version__ >= "2" and sys.platform != "win32":
        model = torch.compile(model)

    trainer.train(resume_from_checkpoint=resume_from_checkpoint)

    model.save_pretrained(output_dir)

    print(
        "\n If there's a warning about missing keys above, please disregard :)"
    )


def generate_prompt(data_point):
    # sorry about the formatting disaster gotta move fast
    if data_point["input"]:
        return f"""Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.  # noqa: E501

### Instruction:
{data_point["instruction"]}

### Input:
{data_point["input"]}

### Response:
{data_point["output"]}"""
    else:
        return f"""Below is an instruction that describes a task. Write a response that appropriately completes the request.  # noqa: E501

### Instruction:
{data_point["instruction"]}

### Response:
{data_point["output"]}"""


if __name__ == "__main__":
    fire.Fire(train)
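
For orientation, a minimal invocation sketch (not part of the commit; the model name is just the example from the script's own assert message, and the data file is the script's default path):

# Sketch: call train() directly; the equivalent CLI form would be
#   python finetune_alpaca.py --base_model='decapoda-research/llama-7b-hf'
from finetune_alpaca import train

train(
    base_model="decapoda-research/llama-7b-hf",  # example from the assert message
    data_path="./alpaca_data_cleaned.json",      # script default
    output_dir="./lora-alpaca",
    num_epochs=1,      # shorter than the default of 3, for a quick smoke test
    val_set_size=0,    # skip the held-out split entirely
)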

‎generate.py

2 additions, 5 deletions

@@ -5,10 +5,6 @@
 import torch
 import transformers
 from peft import PeftModel
-
-assert (
-    "LlamaTokenizer" in transformers._import_structure["models.llama"]
-), "LLaMA is now in HuggingFace's main branch.\nPlease reinstall it: pip uninstall transformers && pip install git+https://github.com/huggingface/transformers.git"  # noqa: E501
 from transformers import GenerationConfig, LlamaForCausalLM, LlamaTokenizer
 
 if torch.cuda.is_available():

@@ -27,6 +23,7 @@ def main(
     load_8bit: bool = False,
     base_model: str = "",
     lora_weights: str = "tloen/alpaca-lora-7b",
+    share_gradio: bool = False,
 ):
     assert (
         base_model

@@ -144,7 +141,7 @@ def evaluate(
         ],
         title="🦙🌲 Alpaca-LoRA",
         description="Alpaca-LoRA is a 7B-parameter LLaMA model finetuned to follow instructions. It is trained on the [Stanford Alpaca](https://github.com/tatsu-lab/stanford_alpaca) dataset and makes use of the Huggingface LLaMA implementation. For more information, please visit [the project's website](https://github.com/tloen/alpaca-lora).",  # noqa: E501
-    ).launch()
+    ).launch(share=share_gradio)
     # Old testing code follows.
 
     """
