convert_lora.py
import argparse
import logging

import torch
from safetensors.torch import load_file, save_file
from safetensors import safe_open

from utils import model_utils

logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)


def convert_from_diffusers(prefix, weights_sd):
    # convert from the Diffusers LoRA format to the default LoRA format
    # Diffusers format: {"diffusion_model.module.name.lora_A.weight": weight, "diffusion_model.module.name.lora_B.weight": weight, ...}
    # default LoRA format: {"prefix_module_name.lora_down.weight": weight, "prefix_module_name.lora_up.weight": weight, ...}
    # note: Diffusers has no alpha, so alpha is set to the rank
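    # e.g. (module name is illustrative, not taken from a real checkpoint):
    #   "diffusion_model.double_blocks.0.img_attn_qkv.lora_A.weight"
    #     -> "lora_unet_double_blocks_0_img_attn_qkv.lora_down.weight"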
    new_weights_sd = {}
    lora_dims = {}
    for key, weight in weights_sd.items():
        diffusers_prefix, key_body = key.split(".", 1)
        if diffusers_prefix != "diffusion_model" and diffusers_prefix != "transformer":
            logger.warning(f"unexpected key: {key} in diffusers format")
            continue
        new_key = f"{prefix}{key_body}".replace(".", "_").replace("_lora_A_", ".lora_down.").replace("_lora_B_", ".lora_up.")
        new_weights_sd[new_key] = weight
        lora_name = new_key.split(".")[0]  # before the first dot
        if lora_name not in lora_dims and "lora_down" in new_key:
            lora_dims[lora_name] = weight.shape[0]  # rank of the lora_down matrix

    # add alpha with rank
    for lora_name, dim in lora_dims.items():
        new_weights_sd[f"{lora_name}.alpha"] = torch.tensor(dim)
    return new_weights_sd
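
# a minimal sketch of calling convert_from_diffusers directly (the module name and
# tensor shape are illustrative, not taken from a real checkpoint):
#   sd = {"diffusion_model.double_blocks.0.img_attn_qkv.lora_A.weight": torch.zeros(16, 3072)}
#   convert_from_diffusers("lora_unet_", sd)
#   # -> {"lora_unet_double_blocks_0_img_attn_qkv.lora_down.weight": ...,
#   #     "lora_unet_double_blocks_0_img_attn_qkv.alpha": tensor(16)}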


def convert_to_diffusers(prefix, weights_sd):
    # convert from the default LoRA format to the Diffusers format
    # get alphas
    lora_alphas = {}
    for key, weight in weights_sd.items():
        if key.startswith(prefix):
            lora_name = key.split(".", 1)[0]  # before the first dot
            if lora_name not in lora_alphas and "alpha" in key:
                lora_alphas[lora_name] = weight

    new_weights_sd = {}
    for key, weight in weights_sd.items():
        if key.startswith(prefix):
            if "alpha" in key:
                continue

            lora_name = key.split(".", 1)[0]  # before the first dot

            # HunyuanVideo lora name to module name: ugly but works
            module_name = lora_name[len(prefix):]  # remove "lora_unet_"
            module_name = module_name.replace("_", ".")  # replace "_" with "."
            module_name = module_name.replace("double.blocks.", "double_blocks.")  # fix double blocks
            module_name = module_name.replace("single.blocks.", "single_blocks.")  # fix single blocks
            module_name = module_name.replace("img.", "img_")  # fix img
            module_name = module_name.replace("txt.", "txt_")  # fix txt
            module_name = module_name.replace("attn.", "attn_")  # fix attn
            diffusers_prefix = "diffusion_model"
            if "lora_down" in key:
                new_key = f"{diffusers_prefix}.{module_name}.lora_A.weight"
                dim = weight.shape[0]
            elif "lora_up" in key:
                new_key = f"{diffusers_prefix}.{module_name}.lora_B.weight"
                dim = weight.shape[1]
            else:
                logger.warning(f"unexpected key: {key} in default LoRA format")
                continue

            # scale weight by alpha: the effective scale alpha / rank applies to the
            # product up @ down, so scaling both matrices by sqrt(alpha / rank) is equivalent
            if lora_name in lora_alphas:
                scale = lora_alphas[lora_name] / dim
                scale = scale.sqrt()
                weight = weight * scale
            else:
                logger.warning(f"missing alpha for {lora_name}")

            new_weights_sd[new_key] = weight

    return new_weights_sd


def convert(input_file, output_file, target_format):
    logger.info(f"loading {input_file}")
    weights_sd = load_file(input_file)
    with safe_open(input_file, framework="pt") as f:
        metadata = f.metadata()

    logger.info(f"converting to {target_format}")
    prefix = "lora_unet_"
    if target_format == "default":
        new_weights_sd = convert_from_diffusers(prefix, weights_sd)
        metadata = metadata or {}
        model_utils.precalculate_safetensors_hashes(new_weights_sd, metadata)
    elif target_format == "other":
        new_weights_sd = convert_to_diffusers(prefix, weights_sd)
    else:
        raise ValueError(f"unknown target format: {target_format}")

    logger.info(f"saving to {output_file}")
    save_file(new_weights_sd, output_file, metadata=metadata)
    logger.info("done")


def parse_args():
    parser = argparse.ArgumentParser(description="Convert LoRA weights between the default and Diffusers formats")
    parser.add_argument("--input", type=str, required=True, help="input model file")
    parser.add_argument("--output", type=str, required=True, help="output model file")
    parser.add_argument("--target", type=str, required=True, choices=["other", "default"], help="target format ('other' = Diffusers)")
    args = parser.parse_args()
    return args


if __name__ == "__main__":
    args = parse_args()
    convert(args.input, args.output, args.target)
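
# Example usage (file names are placeholders):
#   python convert_lora.py --input lora_diffusers.safetensors --output lora_default.safetensors --target default
#   python convert_lora.py --input lora_default.safetensors --output lora_diffusers.safetensors --target other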