import glob
import os
import re
import torch
from typing import Union

from modules import shared, devices, sd_models, errors

metadata_tags_order = {"ss_sd_model_name": 1, "ss_resolution": 2, "ss_clip_skip": 3, "ss_num_train_images": 10, "ss_tag_frequency": 20}

re_digits = re.compile(r"\d+")
re_x_proj = re.compile(r"(.*)_([qkv]_proj)$")
re_compiled = {}

suffix_conversion = {
    "attentions": {},
    "resnets": {
        "conv1": "in_layers_2",
        "conv2": "out_layers_3",
        "time_emb_proj": "emb_layers_1",
        "conv_shortcut": "skip_connection",
    }
}

def convert_diffusers_name_to_compvis(key, is_sd2):
    def match(match_list, regex_text):
        regex = re_compiled.get(regex_text)
        if regex is None:
            regex = re.compile(regex_text)
            re_compiled[regex_text] = regex

        r = re.match(regex, key)
        if not r:
            return False

        match_list.clear()
        match_list.extend([int(x) if re.match(re_digits, x) else x for x in r.groups()])
        return True

    m = []

    if match(m, r"lora_unet_down_blocks_(\d+)_(attentions|resnets)_(\d+)_(.+)"):
        suffix = suffix_conversion.get(m[1], {}).get(m[3], m[3])
        return f"diffusion_model_input_blocks_{1 + m[0] * 3 + m[2]}_{1 if m[1] == 'attentions' else 0}_{suffix}"

    if match(m, r"lora_unet_mid_block_(attentions|resnets)_(\d+)_(.+)"):
        suffix = suffix_conversion.get(m[0], {}).get(m[2], m[2])
        return f"diffusion_model_middle_block_{1 if m[0] == 'attentions' else m[1] * 2}_{suffix}"

    if match(m, r"lora_unet_up_blocks_(\d+)_(attentions|resnets)_(\d+)_(.+)"):
        suffix = suffix_conversion.get(m[1], {}).get(m[3], m[3])
        return f"diffusion_model_output_blocks_{m[0] * 3 + m[2]}_{1 if m[1] == 'attentions' else 0}_{suffix}"

    if match(m, r"lora_unet_down_blocks_(\d+)_downsamplers_0_conv"):
        return f"diffusion_model_input_blocks_{3 + m[0] * 3}_0_op"

    if match(m, r"lora_unet_up_blocks_(\d+)_upsamplers_0_conv"):
        return f"diffusion_model_output_blocks_{2 + m[0] * 3}_{2 if m[0] > 0 else 1}_conv"

    if match(m, r"lora_te_text_model_encoder_layers_(\d+)_(.+)"):
        if is_sd2:
            if 'mlp_fc1' in m[1]:
                return f"model_transformer_resblocks_{m[0]}_{m[1].replace('mlp_fc1', 'mlp_c_fc')}"
            elif 'mlp_fc2' in m[1]:
                return f"model_transformer_resblocks_{m[0]}_{m[1].replace('mlp_fc2', 'mlp_c_proj')}"
            else:
                return f"model_transformer_resblocks_{m[0]}_{m[1].replace('self_attn', 'attn')}"

        return f"transformer_text_model_encoder_layers_{m[0]}_{m[1]}"

    return key

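# For reference, a few conversions the function above performs (hand-traced from its regexes;
# not an exhaustive list):
#   convert_diffusers_name_to_compvis("lora_unet_down_blocks_0_attentions_1_proj_in", False)
#       -> "diffusion_model_input_blocks_2_1_proj_in"
#   convert_diffusers_name_to_compvis("lora_unet_up_blocks_1_resnets_0_conv1", False)
#       -> "diffusion_model_output_blocks_3_0_in_layers_2"
#   convert_diffusers_name_to_compvis("lora_te_text_model_encoder_layers_0_mlp_fc1", False)
#       -> "transformer_text_model_encoder_layers_0_mlp_fc1"
# Keys that match none of the patterns are returned unchanged.
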
class LoraOnDisk:
    def __init__(self, name, filename):
        self.name = name
        self.filename = filename
        self.metadata = {}

        _, ext = os.path.splitext(filename)
        if ext.lower() == ".safetensors":
            try:
                self.metadata = sd_models.read_metadata_from_safetensors(filename)
            except Exception as e:
                errors.display(e, f"reading lora {filename}")

        if self.metadata:
            m = {}
            for k, v in sorted(self.metadata.items(), key=lambda x: metadata_tags_order.get(x[0], 999)):
                m[k] = v

            self.metadata = m

        self.ssmd_cover_images = self.metadata.pop('ssmd_cover_images', None)  # cover images are too big to display in the UI as text

class LoraModule:
    def __init__(self, name):
        self.name = name
        self.multiplier = 1.0
        self.modules = {}
        self.mtime = None


class LoraUpDownModule:
    def __init__(self):
        self.up = None
        self.down = None
        self.alpha = None

def assign_lora_names_to_compvis_modules(sd_model):
    lora_layer_mapping = {}

    for name, module in shared.sd_model.cond_stage_model.wrapped.named_modules():
        lora_name = name.replace(".", "_")
        lora_layer_mapping[lora_name] = module
        module.lora_layer_name = lora_name

    for name, module in shared.sd_model.model.named_modules():
        lora_name = name.replace(".", "_")
        lora_layer_mapping[lora_name] = module
        module.lora_layer_name = lora_name

    sd_model.lora_layer_mapping = lora_layer_mapping

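# After this runs, every sub-module of the text encoder and the UNet carries a lora_layer_name
# attribute, and sd_model.lora_layer_mapping lets load_lora() look modules up by the converted
# key, e.g. lora_layer_mapping["diffusion_model_input_blocks_2_1_proj_in"]. The text-encoder
# keys start with "transformer_" for SD1 and "model_transformer_resblocks_" for SD2, which is
# what the is_sd2 check in load_lora() relies on. (The loops walk shared.sd_model, assumed here
# to be the same object as the sd_model argument.)
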
def load_lora(name, filename):
    lora = LoraModule(name)
    lora.mtime = os.path.getmtime(filename)

    sd = sd_models.read_state_dict(filename)

    keys_failed_to_match = {}
    is_sd2 = 'model_transformer_resblocks' in shared.sd_model.lora_layer_mapping

    for key_diffusers, weight in sd.items():
        key_diffusers_without_lora_parts, lora_key = key_diffusers.split(".", 1)
        key = convert_diffusers_name_to_compvis(key_diffusers_without_lora_parts, is_sd2)

        sd_module = shared.sd_model.lora_layer_mapping.get(key, None)

        if sd_module is None:
            m = re_x_proj.match(key)
            if m:
                sd_module = shared.sd_model.lora_layer_mapping.get(m.group(1), None)

        if sd_module is None:
            keys_failed_to_match[key_diffusers] = key
            continue

        lora_module = lora.modules.get(key, None)
        if lora_module is None:
            lora_module = LoraUpDownModule()
            lora.modules[key] = lora_module

        if lora_key == "alpha":
            lora_module.alpha = weight.item()
            continue

        if type(sd_module) == torch.nn.Linear:
            module = torch.nn.Linear(weight.shape[1], weight.shape[0], bias=False)
        elif type(sd_module) == torch.nn.modules.linear.NonDynamicallyQuantizableLinear:
            module = torch.nn.Linear(weight.shape[1], weight.shape[0], bias=False)
        elif type(sd_module) == torch.nn.MultiheadAttention:
            module = torch.nn.Linear(weight.shape[1], weight.shape[0], bias=False)
        elif type(sd_module) == torch.nn.Conv2d:
            module = torch.nn.Conv2d(weight.shape[1], weight.shape[0], (1, 1), bias=False)
        else:
            print(f'Lora layer {key_diffusers} matched a layer with unsupported type: {type(sd_module).__name__}')
            continue

        with torch.no_grad():
            module.weight.copy_(weight)

        module.to(device=devices.cpu, dtype=devices.dtype)

        if lora_key == "lora_up.weight":
            lora_module.up = module
        elif lora_key == "lora_down.weight":
            lora_module.down = module
        else:
            assert False, f'Bad Lora layer name: {key_diffusers} - must end in lora_up.weight, lora_down.weight or alpha'

    if len(keys_failed_to_match) > 0:
        print(f"Failed to match keys when loading Lora {filename}: {keys_failed_to_match}")

    return lora

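# load_lora() expects state-dict keys of the form "<diffusers-style module name>.<lora part>",
# where the lora part is "lora_up.weight", "lora_down.weight" or "alpha", e.g.
#   "lora_unet_down_blocks_0_attentions_1_proj_in.lora_up.weight"
# Module names that cannot be mapped onto the current model are collected in
# keys_failed_to_match and reported once at the end.
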
def load_loras(names, multipliers=None):
    already_loaded = {}

    for lora in loaded_loras:
        if lora.name in names:
            already_loaded[lora.name] = lora

    loaded_loras.clear()

    loras_on_disk = [available_loras.get(name, None) for name in names]
    if any([x is None for x in loras_on_disk]):
        list_available_loras()

        loras_on_disk = [available_loras.get(name, None) for name in names]

    for i, name in enumerate(names):
        lora = already_loaded.get(name, None)

        lora_on_disk = loras_on_disk[i]
        if lora_on_disk is not None:
            if lora is None or os.path.getmtime(lora_on_disk.filename) > lora.mtime:
                lora = load_lora(name, lora_on_disk.filename)

        if lora is None:
            print(f"Couldn't find Lora with name {name}")
            continue

        lora.multiplier = multipliers[i] if multipliers else 1.0
        loaded_loras.append(lora)

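# Typical call (hypothetical names and multipliers; in the webui this is driven by the
# <lora:name:multiplier> prompt syntax, handled elsewhere):
#   load_loras(["myLora", "otherLora"], [0.8, 0.5])
# Networks already in loaded_loras are reused unless the file on disk has a newer mtime, in
# which case they are reloaded; loaded_loras ends up holding the requested set in order.
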
def lora_calc_updown(lora, module, target):
    with torch.no_grad():
        up = module.up.weight.to(target.device, dtype=target.dtype)
        down = module.down.weight.to(target.device, dtype=target.dtype)

        if up.shape[2:] == (1, 1) and down.shape[2:] == (1, 1):
            updown = (up.squeeze(2).squeeze(2) @ down.squeeze(2).squeeze(2)).unsqueeze(2).unsqueeze(3)
        else:
            updown = up @ down

        updown = updown * lora.multiplier * (module.alpha / module.up.weight.shape[1] if module.alpha else 1.0)

        return updown

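# In effect this computes, per layer,
#   delta_W = multiplier * (alpha / rank) * (up @ down)
# where rank == module.up.weight.shape[1] (the LoRA rank); the 1x1-kernel branch just flattens
# Conv2d LoRA weights so the same matrix product applies, then restores the spatial dims.
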
def lora_apply_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn.MultiheadAttention]):
    """
    Applies the currently selected set of Loras to the weights of torch layer self.
    If weights already have this particular set of loras applied, does nothing.
    If not, restores original weights from backup and alters weights according to loras.
    """

    lora_layer_name = getattr(self, 'lora_layer_name', None)
    if lora_layer_name is None:
        return

    current_names = getattr(self, "lora_current_names", ())
    wanted_names = tuple((x.name, x.multiplier) for x in loaded_loras)

    weights_backup = getattr(self, "lora_weights_backup", None)
    if weights_backup is None:
        if isinstance(self, torch.nn.MultiheadAttention):
            weights_backup = (self.in_proj_weight.to(devices.cpu, copy=True), self.out_proj.weight.to(devices.cpu, copy=True))
        else:
            weights_backup = self.weight.to(devices.cpu, copy=True)

        self.lora_weights_backup = weights_backup

    if current_names != wanted_names:
        if weights_backup is not None:
            if isinstance(self, torch.nn.MultiheadAttention):
                self.in_proj_weight.copy_(weights_backup[0])
                self.out_proj.weight.copy_(weights_backup[1])
            else:
                self.weight.copy_(weights_backup)

        for lora in loaded_loras:
            module = lora.modules.get(lora_layer_name, None)
            if module is not None and hasattr(self, 'weight'):
                self.weight += lora_calc_updown(lora, module, self.weight)
                continue

            module_q = lora.modules.get(lora_layer_name + "_q_proj", None)
            module_k = lora.modules.get(lora_layer_name + "_k_proj", None)
            module_v = lora.modules.get(lora_layer_name + "_v_proj", None)
            module_out = lora.modules.get(lora_layer_name + "_out_proj", None)

            if isinstance(self, torch.nn.MultiheadAttention) and module_q and module_k and module_v and module_out:
                # in_proj_weight is torch's concatenation of the q, k and v projection weights,
                # so the three per-projection deltas are stacked to match its shape
                updown_q = lora_calc_updown(lora, module_q, self.in_proj_weight)
                updown_k = lora_calc_updown(lora, module_k, self.in_proj_weight)
                updown_v = lora_calc_updown(lora, module_v, self.in_proj_weight)
                updown_qkv = torch.vstack([updown_q, updown_k, updown_v])

                self.in_proj_weight += updown_qkv
                self.out_proj.weight += lora_calc_updown(lora, module_out, self.out_proj.weight)
                continue

            if module is None:
                continue

            print(f'failed to calculate lora weights for layer {lora_layer_name}')

        setattr(self, "lora_current_names", wanted_names)

def lora_reset_cached_weight(self: Union[torch.nn.Conv2d, torch.nn.Linear]):
    setattr(self, "lora_current_names", ())
    setattr(self, "lora_weights_backup", None)

def lora_Linear_forward(self, input):
    lora_apply_weights(self)

    return torch.nn.Linear_forward_before_lora(self, input)


def lora_Linear_load_state_dict(self, *args, **kwargs):
    lora_reset_cached_weight(self)

    return torch.nn.Linear_load_state_dict_before_lora(self, *args, **kwargs)


def lora_Conv2d_forward(self, input):
    lora_apply_weights(self)

    return torch.nn.Conv2d_forward_before_lora(self, input)


def lora_Conv2d_load_state_dict(self, *args, **kwargs):
    lora_reset_cached_weight(self)

    return torch.nn.Conv2d_load_state_dict_before_lora(self, *args, **kwargs)


def lora_MultiheadAttention_forward(self, *args, **kwargs):
    lora_apply_weights(self)

    return torch.nn.MultiheadAttention_forward_before_lora(self, *args, **kwargs)


def lora_MultiheadAttention_load_state_dict(self, *args, **kwargs):
    lora_reset_cached_weight(self)

    return torch.nn.MultiheadAttention_load_state_dict_before_lora(self, *args, **kwargs)

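# The wrappers above assume the original torch methods have been stashed on torch.nn and the
# classes monkey-patched elsewhere (in the webui this is done by the extension's script file,
# not by this module). A minimal sketch of that setup for the Linear case, under that assumption:
#
#   if not hasattr(torch.nn, 'Linear_forward_before_lora'):
#       torch.nn.Linear_forward_before_lora = torch.nn.Linear.forward
#   if not hasattr(torch.nn, 'Linear_load_state_dict_before_lora'):
#       torch.nn.Linear_load_state_dict_before_lora = torch.nn.Linear._load_from_state_dict
#   torch.nn.Linear.forward = lora_Linear_forward
#   torch.nn.Linear._load_from_state_dict = lora_Linear_load_state_dict
#
# Conv2d and MultiheadAttention would be patched the same way with their respective wrappers.
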
def list_available_loras():
    available_loras.clear()

    os.makedirs(shared.cmd_opts.lora_dir, exist_ok=True)

    candidates = \
        glob.glob(os.path.join(shared.cmd_opts.lora_dir, '**/*.pt'), recursive=True) + \
        glob.glob(os.path.join(shared.cmd_opts.lora_dir, '**/*.safetensors'), recursive=True) + \
        glob.glob(os.path.join(shared.cmd_opts.lora_dir, '**/*.ckpt'), recursive=True)

    for filename in sorted(candidates, key=str.lower):
        if os.path.isdir(filename):
            continue

        name = os.path.splitext(os.path.basename(filename))[0]

        available_loras[name] = LoraOnDisk(name, filename)

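# Module-level state: available_loras maps a display name to its LoraOnDisk entry and is
# refreshed by list_available_loras(); loaded_loras holds the networks currently being applied
# by lora_apply_weights(). list_available_loras() is also run once below, so the list is
# populated as soon as the module is imported.
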
available_loras = {}
loaded_loras = []

list_available_loras()