Commit e3d2139

xie river committed: dynamic loading models
2 parents: 972a04f + f352ab2

21 files changed: +1726 additions, -291 deletions
extensions-builtin/Lora/extra_networks_lora.py

+26
@@ -0,0 +1,26 @@
from modules import extra_networks, shared
import lora

class ExtraNetworkLora(extra_networks.ExtraNetwork):
    def __init__(self):
        super().__init__('lora')

    def activate(self, p, params_list):
        additional = shared.opts.sd_lora

        if additional != "" and additional in lora.available_loras and len([x for x in params_list if x.items[0] == additional]) == 0:
            p.all_prompts = [x + f"<lora:{additional}:{shared.opts.extra_networks_default_multiplier}>" for x in p.all_prompts]
            params_list.append(extra_networks.ExtraNetworkParams(items=[additional, shared.opts.extra_networks_default_multiplier]))

        names = []
        multipliers = []
        for params in params_list:
            assert len(params.items) > 0

            names.append(params.items[0])
            multipliers.append(float(params.items[1]) if len(params.items) > 1 else 1.0)

        lora.load_loras(names, multipliers)

    def deactivate(self, p):
        pass
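For orientation, a hedged sketch of how activate() above is exercised at runtime. The parsing of <lora:...> prompt tags lives in modules.extra_networks, which is not among the hunks in this commit, and the Lora name "filmGrain" below is purely hypothetical:

    # Assumed flow, not part of the committed file:
    #   prompt:       "portrait of a cat <lora:filmGrain:0.6>"
    #   parsed into:  params_list = [extra_networks.ExtraNetworkParams(items=["filmGrain", "0.6"])]
    #   activate():   names = ["filmGrain"], multipliers = [0.6]
    #   net effect:   lora.load_loras(["filmGrain"], [0.6])
    # If shared.opts.sd_lora names an available Lora, activate() first appends
    # "<lora:NAME:multiplier>" to every prompt and adds it to params_list, so that
    # Lora is applied even when no prompt mentions it.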

extensions-builtin/Lora/lora.py

+362
@@ -0,0 +1,362 @@
import glob
import os
import re
import torch
from typing import Union

from modules import shared, devices, sd_models, errors

metadata_tags_order = {"ss_sd_model_name": 1, "ss_resolution": 2, "ss_clip_skip": 3, "ss_num_train_images": 10, "ss_tag_frequency": 20}

re_digits = re.compile(r"\d+")
re_x_proj = re.compile(r"(.*)_([qkv]_proj)$")
re_compiled = {}

suffix_conversion = {
    "attentions": {},
    "resnets": {
        "conv1": "in_layers_2",
        "conv2": "out_layers_3",
        "time_emb_proj": "emb_layers_1",
        "conv_shortcut": "skip_connection",
    }
}


def convert_diffusers_name_to_compvis(key, is_sd2):
    def match(match_list, regex_text):
        regex = re_compiled.get(regex_text)
        if regex is None:
            regex = re.compile(regex_text)
            re_compiled[regex_text] = regex

        r = re.match(regex, key)
        if not r:
            return False

        match_list.clear()
        match_list.extend([int(x) if re.match(re_digits, x) else x for x in r.groups()])
        return True

    m = []

    if match(m, r"lora_unet_down_blocks_(\d+)_(attentions|resnets)_(\d+)_(.+)"):
        suffix = suffix_conversion.get(m[1], {}).get(m[3], m[3])
        return f"diffusion_model_input_blocks_{1 + m[0] * 3 + m[2]}_{1 if m[1] == 'attentions' else 0}_{suffix}"

    if match(m, r"lora_unet_mid_block_(attentions|resnets)_(\d+)_(.+)"):
        suffix = suffix_conversion.get(m[0], {}).get(m[2], m[2])
        return f"diffusion_model_middle_block_{1 if m[0] == 'attentions' else m[1] * 2}_{suffix}"

    if match(m, r"lora_unet_up_blocks_(\d+)_(attentions|resnets)_(\d+)_(.+)"):
        suffix = suffix_conversion.get(m[1], {}).get(m[3], m[3])
        return f"diffusion_model_output_blocks_{m[0] * 3 + m[2]}_{1 if m[1] == 'attentions' else 0}_{suffix}"

    if match(m, r"lora_unet_down_blocks_(\d+)_downsamplers_0_conv"):
        return f"diffusion_model_input_blocks_{3 + m[0] * 3}_0_op"

    if match(m, r"lora_unet_up_blocks_(\d+)_upsamplers_0_conv"):
        return f"diffusion_model_output_blocks_{2 + m[0] * 3}_{2 if m[0]>0 else 1}_conv"

    if match(m, r"lora_te_text_model_encoder_layers_(\d+)_(.+)"):
        if is_sd2:
            if 'mlp_fc1' in m[1]:
                return f"model_transformer_resblocks_{m[0]}_{m[1].replace('mlp_fc1', 'mlp_c_fc')}"
            elif 'mlp_fc2' in m[1]:
                return f"model_transformer_resblocks_{m[0]}_{m[1].replace('mlp_fc2', 'mlp_c_proj')}"
            else:
                return f"model_transformer_resblocks_{m[0]}_{m[1].replace('self_attn', 'attn')}"

        return f"transformer_text_model_encoder_layers_{m[0]}_{m[1]}"

    return key


class LoraOnDisk:
    def __init__(self, name, filename):
        self.name = name
        self.filename = filename
        self.metadata = {}

        _, ext = os.path.splitext(filename)
        if ext.lower() == ".safetensors":
            try:
                self.metadata = sd_models.read_metadata_from_safetensors(filename)
            except Exception as e:
                errors.display(e, f"reading lora {filename}")

        if self.metadata:
            m = {}
            for k, v in sorted(self.metadata.items(), key=lambda x: metadata_tags_order.get(x[0], 999)):
                m[k] = v

            self.metadata = m

        self.ssmd_cover_images = self.metadata.pop('ssmd_cover_images', None)  # those are cover images and they are too big to display in UI as text


class LoraModule:
    def __init__(self, name):
        self.name = name
        self.multiplier = 1.0
        self.modules = {}
        self.mtime = None


class LoraUpDownModule:
    def __init__(self):
        self.up = None
        self.down = None
        self.alpha = None


def assign_lora_names_to_compvis_modules(sd_model):
    lora_layer_mapping = {}

    for name, module in shared.sd_model.cond_stage_model.wrapped.named_modules():
        lora_name = name.replace(".", "_")
        lora_layer_mapping[lora_name] = module
        module.lora_layer_name = lora_name

    for name, module in shared.sd_model.model.named_modules():
        lora_name = name.replace(".", "_")
        lora_layer_mapping[lora_name] = module
        module.lora_layer_name = lora_name

    sd_model.lora_layer_mapping = lora_layer_mapping


def load_lora(name, filename):
    lora = LoraModule(name)
    lora.mtime = os.path.getmtime(filename)

    sd = sd_models.read_state_dict(filename)

    keys_failed_to_match = {}
    is_sd2 = 'model_transformer_resblocks' in shared.sd_model.lora_layer_mapping

    for key_diffusers, weight in sd.items():
        key_diffusers_without_lora_parts, lora_key = key_diffusers.split(".", 1)
        key = convert_diffusers_name_to_compvis(key_diffusers_without_lora_parts, is_sd2)

        sd_module = shared.sd_model.lora_layer_mapping.get(key, None)

        if sd_module is None:
            m = re_x_proj.match(key)
            if m:
                sd_module = shared.sd_model.lora_layer_mapping.get(m.group(1), None)

        if sd_module is None:
            keys_failed_to_match[key_diffusers] = key
            continue

        lora_module = lora.modules.get(key, None)
        if lora_module is None:
            lora_module = LoraUpDownModule()
            lora.modules[key] = lora_module

        if lora_key == "alpha":
            lora_module.alpha = weight.item()
            continue

        if type(sd_module) == torch.nn.Linear:
            module = torch.nn.Linear(weight.shape[1], weight.shape[0], bias=False)
        elif type(sd_module) == torch.nn.modules.linear.NonDynamicallyQuantizableLinear:
            module = torch.nn.Linear(weight.shape[1], weight.shape[0], bias=False)
        elif type(sd_module) == torch.nn.MultiheadAttention:
            module = torch.nn.Linear(weight.shape[1], weight.shape[0], bias=False)
        elif type(sd_module) == torch.nn.Conv2d:
            module = torch.nn.Conv2d(weight.shape[1], weight.shape[0], (1, 1), bias=False)
        else:
            print(f'Lora layer {key_diffusers} matched a layer with unsupported type: {type(sd_module).__name__}')
            continue
            assert False, f'Lora layer {key_diffusers} matched a layer with unsupported type: {type(sd_module).__name__}'

        with torch.no_grad():
            module.weight.copy_(weight)

        module.to(device=devices.cpu, dtype=devices.dtype)

        if lora_key == "lora_up.weight":
            lora_module.up = module
        elif lora_key == "lora_down.weight":
            lora_module.down = module
        else:
            assert False, f'Bad Lora layer name: {key_diffusers} - must end in lora_up.weight, lora_down.weight or alpha'

    if len(keys_failed_to_match) > 0:
        print(f"Failed to match keys when loading Lora {filename}: {keys_failed_to_match}")

    return lora


def load_loras(names, multipliers=None):
    already_loaded = {}

    for lora in loaded_loras:
        if lora.name in names:
            already_loaded[lora.name] = lora

    loaded_loras.clear()

    loras_on_disk = [available_loras.get(name, None) for name in names]
    if any([x is None for x in loras_on_disk]):
        list_available_loras()

        loras_on_disk = [available_loras.get(name, None) for name in names]

    for i, name in enumerate(names):
        lora = already_loaded.get(name, None)

        lora_on_disk = loras_on_disk[i]
        if lora_on_disk is not None:
            if lora is None or os.path.getmtime(lora_on_disk.filename) > lora.mtime:
                lora = load_lora(name, lora_on_disk.filename)

        if lora is None:
            print(f"Couldn't find Lora with name {name}")
            continue

        lora.multiplier = multipliers[i] if multipliers else 1.0
        loaded_loras.append(lora)


def lora_calc_updown(lora, module, target):
    with torch.no_grad():
        up = module.up.weight.to(target.device, dtype=target.dtype)
        down = module.down.weight.to(target.device, dtype=target.dtype)

        if up.shape[2:] == (1, 1) and down.shape[2:] == (1, 1):
            updown = (up.squeeze(2).squeeze(2) @ down.squeeze(2).squeeze(2)).unsqueeze(2).unsqueeze(3)
        else:
            updown = up @ down

        updown = updown * lora.multiplier * (module.alpha / module.up.weight.shape[1] if module.alpha else 1.0)

        return updown


def lora_apply_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn.MultiheadAttention]):
    """
    Applies the currently selected set of Loras to the weights of torch layer self.
    If weights already have this particular set of loras applied, does nothing.
    If not, restores original weights from backup and alters weights according to loras.
    """

    lora_layer_name = getattr(self, 'lora_layer_name', None)
    if lora_layer_name is None:
        return

    current_names = getattr(self, "lora_current_names", ())
    wanted_names = tuple((x.name, x.multiplier) for x in loaded_loras)

    weights_backup = getattr(self, "lora_weights_backup", None)
    if weights_backup is None:
        if isinstance(self, torch.nn.MultiheadAttention):
            weights_backup = (self.in_proj_weight.to(devices.cpu, copy=True), self.out_proj.weight.to(devices.cpu, copy=True))
        else:
            weights_backup = self.weight.to(devices.cpu, copy=True)

        self.lora_weights_backup = weights_backup

    if current_names != wanted_names:
        if weights_backup is not None:
            if isinstance(self, torch.nn.MultiheadAttention):
                self.in_proj_weight.copy_(weights_backup[0])
                self.out_proj.weight.copy_(weights_backup[1])
            else:
                self.weight.copy_(weights_backup)

        for lora in loaded_loras:
            module = lora.modules.get(lora_layer_name, None)
            if module is not None and hasattr(self, 'weight'):
                self.weight += lora_calc_updown(lora, module, self.weight)
                continue

            module_q = lora.modules.get(lora_layer_name + "_q_proj", None)
            module_k = lora.modules.get(lora_layer_name + "_k_proj", None)
            module_v = lora.modules.get(lora_layer_name + "_v_proj", None)
            module_out = lora.modules.get(lora_layer_name + "_out_proj", None)

            if isinstance(self, torch.nn.MultiheadAttention) and module_q and module_k and module_v and module_out:
                updown_q = lora_calc_updown(lora, module_q, self.in_proj_weight)
                updown_k = lora_calc_updown(lora, module_k, self.in_proj_weight)
                updown_v = lora_calc_updown(lora, module_v, self.in_proj_weight)
                updown_qkv = torch.vstack([updown_q, updown_k, updown_v])

                self.in_proj_weight += updown_qkv
                self.out_proj.weight += lora_calc_updown(lora, module_out, self.out_proj.weight)
                continue

            if module is None:
                continue

            print(f'failed to calculate lora weights for layer {lora_layer_name}')

        setattr(self, "lora_current_names", wanted_names)


def lora_reset_cached_weight(self: Union[torch.nn.Conv2d, torch.nn.Linear]):
    setattr(self, "lora_current_names", ())
    setattr(self, "lora_weights_backup", None)


def lora_Linear_forward(self, input):
    lora_apply_weights(self)

    return torch.nn.Linear_forward_before_lora(self, input)


def lora_Linear_load_state_dict(self, *args, **kwargs):
    lora_reset_cached_weight(self)

    return torch.nn.Linear_load_state_dict_before_lora(self, *args, **kwargs)


def lora_Conv2d_forward(self, input):
    lora_apply_weights(self)

    return torch.nn.Conv2d_forward_before_lora(self, input)


def lora_Conv2d_load_state_dict(self, *args, **kwargs):
    lora_reset_cached_weight(self)

    return torch.nn.Conv2d_load_state_dict_before_lora(self, *args, **kwargs)


def lora_MultiheadAttention_forward(self, *args, **kwargs):
    lora_apply_weights(self)

    return torch.nn.MultiheadAttention_forward_before_lora(self, *args, **kwargs)


def lora_MultiheadAttention_load_state_dict(self, *args, **kwargs):
    lora_reset_cached_weight(self)

    return torch.nn.MultiheadAttention_load_state_dict_before_lora(self, *args, **kwargs)


def list_available_loras():
    available_loras.clear()

    os.makedirs(shared.cmd_opts.lora_dir, exist_ok=True)

    candidates = \
        glob.glob(os.path.join(shared.cmd_opts.lora_dir, '**/*.pt'), recursive=True) + \
        glob.glob(os.path.join(shared.cmd_opts.lora_dir, '**/*.safetensors'), recursive=True) + \
        glob.glob(os.path.join(shared.cmd_opts.lora_dir, '**/*.ckpt'), recursive=True)

    for filename in sorted(candidates, key=str.lower):
        if os.path.isdir(filename):
            continue

        name = os.path.splitext(os.path.basename(filename))[0]

        available_loras[name] = LoraOnDisk(name, filename)


available_loras = {}
loaded_loras = []

list_available_loras()
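The lora_* forward and load_state_dict replacements above call attributes such as torch.nn.Linear_forward_before_lora that this file never defines; presumably an accompanying extension script, not among the hunks shown here, saves the stock implementations and installs the wrappers. A minimal sketch of that wiring under that assumption (using torch.nn.Module._load_from_state_dict as the load hook is also an assumption):

# Sketch only; not part of the files in this commit.
import torch

import lora

if not hasattr(torch.nn, 'Linear_forward_before_lora'):
    # keep references to the stock implementations once, so the lora_* wrappers can delegate to them
    torch.nn.Linear_forward_before_lora = torch.nn.Linear.forward
    torch.nn.Linear_load_state_dict_before_lora = torch.nn.Linear._load_from_state_dict
    torch.nn.Conv2d_forward_before_lora = torch.nn.Conv2d.forward
    torch.nn.Conv2d_load_state_dict_before_lora = torch.nn.Conv2d._load_from_state_dict
    torch.nn.MultiheadAttention_forward_before_lora = torch.nn.MultiheadAttention.forward
    torch.nn.MultiheadAttention_load_state_dict_before_lora = torch.nn.MultiheadAttention._load_from_state_dict

# route every forward pass and state-dict load through the Lora-aware wrappers,
# so weights get patched lazily the first time a layer is actually used
torch.nn.Linear.forward = lora.lora_Linear_forward
torch.nn.Linear._load_from_state_dict = lora.lora_Linear_load_state_dict
torch.nn.Conv2d.forward = lora.lora_Conv2d_forward
torch.nn.Conv2d._load_from_state_dict = lora.lora_Conv2d_load_state_dict
torch.nn.MultiheadAttention.forward = lora.lora_MultiheadAttention_forward
torch.nn.MultiheadAttention._load_from_state_dict = lora.lora_MultiheadAttention_load_state_dict

# assign_lora_names_to_compvis_modules() must also run once per model load so that
# every layer carries a lora_layer_name before lora_apply_weights() is first hit

For reference, the update computed by lora_calc_updown() is the usual low-rank form delta_W = multiplier * (alpha / rank) * up @ down, with rank taken from up.weight.shape[1]; lora_apply_weights() adds it onto a CPU backup of the original weight, which is what allows a different set of Loras to be swapped in later without reloading the checkpoint.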

extensions-builtin/Lora/preload.py

+6
@@ -0,0 +1,6 @@
import os
from modules import paths


def preload(parser):
    parser.add_argument("--lora-dir", type=str, help="Path to directory with Lora networks.", default=os.path.join(paths.models_path, 'Lora'))
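This preload hook is what defines shared.cmd_opts.lora_dir, which list_available_loras() in lora.py reads: the webui collects each extension's preload(parser) before parsing command-line arguments, so the Lora directory defaults to models/Lora and can be overridden with --lora-dir (for example --lora-dir /data/loras, where the path is only an illustration).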

0 commit comments
