 from modules import shared, devices, sd_models, errors, scripts, sd_hijack
 import modules.textual_inversion.textual_inversion as textual_inversion
+import modules.models.sd3.mmdit
 
 from lora_logger import logger
@@ -166,12 +167,26 @@ def load_network(name, network_on_disk):
 
     keys_failed_to_match = {}
     is_sd2 = 'model_transformer_resblocks' in shared.sd_model.network_layer_mapping
+    if hasattr(shared.sd_model, 'diffusers_weight_map'):
+        diffusers_weight_map = shared.sd_model.diffusers_weight_map
+    elif hasattr(shared.sd_model, 'diffusers_weight_mapping'):
+        diffusers_weight_map = {}
+        for k, v in shared.sd_model.diffusers_weight_mapping():
+            diffusers_weight_map[k] = v
+        shared.sd_model.diffusers_weight_map = diffusers_weight_map
+    else:
+        diffusers_weight_map = None
 
     matched_networks = {}
     bundle_embeddings = {}
 
     for key_network, weight in sd.items():
-        key_network_without_network_parts, _, network_part = key_network.partition(".")
+
+        if diffusers_weight_map:
+            key_network_without_network_parts, network_name, network_weight = key_network.rsplit(".", 2)
+            network_part = network_name + '.' + network_weight
+        else:
+            key_network_without_network_parts, _, network_part = key_network.partition(".")
 
         if key_network_without_network_parts == "bundle_emb":
             emb_name, vec_name = network_part.split(".", 1)
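
A concrete key makes the two branches easier to compare. For a hypothetical SD3-style state-dict key (the name below is illustrative, not taken from this commit), rsplit(".", 2) keeps the dotted module path intact and peels off only the LoRA suffix, whereas partition(".") cuts at the first dot, which suits the underscore-flattened compvis-style names:

    # Sketch only: the first key is a made-up example.
    key_network = "transformer.blocks.0.attn.qkv.lora_A.weight"

    # New branch (diffusers_weight_map available): split off the last two parts.
    base, network_name, network_weight = key_network.rsplit(".", 2)
    print(base)                                 # transformer.blocks.0.attn.qkv
    print(network_name + '.' + network_weight)  # lora_A.weight

    # Old branch: cut at the first dot, suitable for flattened names.
    base, _, network_part = "lora_unet_down_blocks_0_attentions_0_proj_in.lora_down.weight".partition(".")
    print(base)          # lora_unet_down_blocks_0_attentions_0_proj_in
    print(network_part)  # lora_down.weight
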
@@ -183,7 +198,11 @@ def load_network(name, network_on_disk):
                 emb_dict[vec_name] = weight
             bundle_embeddings[emb_name] = emb_dict
 
-        key = convert_diffusers_name_to_compvis(key_network_without_network_parts, is_sd2)
+        if diffusers_weight_map:
+            key = diffusers_weight_map.get(key_network_without_network_parts, key_network_without_network_parts)
+        else:
+            key = convert_diffusers_name_to_compvis(key_network_without_network_parts, is_sd2)
+
         sd_module = shared.sd_model.network_layer_mapping.get(key, None)
 
         if sd_module is None:
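
Two details of the new lookup are easy to miss: the map built from diffusers_weight_mapping() is cached back onto the model so later loads skip the rebuild, and dict.get(key, key) falls back to the unmapped name rather than None. A minimal sketch of the same pattern, with invented mapping pairs:

    def diffusers_weight_mapping():
        # Stand-in for the model's generator of (diffusers_name, internal_name)
        # pairs; both entries here are invented for illustration.
        yield "transformer.blocks.0.attn.qkv", "diffusion_model_joint_blocks_0_x_block_attn_qkv"
        yield "transformer.blocks.0.mlp.fc1", "diffusion_model_joint_blocks_0_x_block_mlp_fc1"

    diffusers_weight_map = dict(diffusers_weight_mapping())

    key = "transformer.blocks.0.attn.qkv"
    print(diffusers_weight_map.get(key, key))  # mapped internal name

    key = "some.unmapped.key"
    print(diffusers_weight_map.get(key, key))  # falls through unchanged
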
@@ -347,6 +366,28 @@ def load_networks(names, te_multipliers=None, unet_multipliers=None, dyn_dims=None):
     purge_networks_from_memory()
 
 
+def allowed_layer_without_weight(layer):
+    if isinstance(layer, torch.nn.LayerNorm) and not layer.elementwise_affine:
+        return True
+
+    return False
+
+
+def store_weights_backup(weight):
+    if weight is None:
+        return None
+
+    return weight.to(devices.cpu, copy=True)
+
+
+def restore_weights_backup(obj, field, weight):
+    if weight is None:
+        setattr(obj, field, None)
+        return
+
+    getattr(obj, field).copy_(weight)
+
+
 def network_restore_weights_from_backup(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn.GroupNorm, torch.nn.LayerNorm, torch.nn.MultiheadAttention]):
     weights_backup = getattr(self, "network_weights_backup", None)
     bias_backup = getattr(self, "network_bias_backup", None)
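
These helpers centralize logic that was previously inlined at each call site: store_weights_backup copies a tensor to CPU (or passes None through), and restore_weights_backup writes the backup back via copy_ so the live parameter keeps its device and dtype, clearing the field when the backup is None (the bias case). A round-trip sketch, using torch.device('cpu') directly where the real code uses devices.cpu:

    import torch

    def store_weights_backup(weight):
        if weight is None:
            return None
        return weight.to(torch.device('cpu'), copy=True)

    def restore_weights_backup(obj, field, weight):
        if weight is None:
            setattr(obj, field, None)
            return
        getattr(obj, field).copy_(weight)

    layer = torch.nn.Linear(4, 4)
    backup = store_weights_backup(layer.weight)

    with torch.no_grad():
        layer.weight += 1.0  # simulate a network patching the weight
        restore_weights_backup(layer, 'weight', backup)

    print(torch.equal(layer.weight.detach().cpu(), backup))  # True
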
@@ -356,21 +397,15 @@ def network_restore_weights_from_backup(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn.GroupNorm, torch.nn.LayerNorm, torch.nn.MultiheadAttention]):
 
     if weights_backup is not None:
         if isinstance(self, torch.nn.MultiheadAttention):
-            self.in_proj_weight.copy_(weights_backup[0])
-            self.out_proj.weight.copy_(weights_backup[1])
+            restore_weights_backup(self, 'in_proj_weight', weights_backup[0])
+            restore_weights_backup(self.out_proj, 'weight', weights_backup[1])
         else:
-            self.weight.copy_(weights_backup)
+            restore_weights_backup(self, 'weight', weights_backup)
 
-    if bias_backup is not None:
-        if isinstance(self, torch.nn.MultiheadAttention):
-            self.out_proj.bias.copy_(bias_backup)
-        else:
-            self.bias.copy_(bias_backup)
+    if isinstance(self, torch.nn.MultiheadAttention):
+        restore_weights_backup(self.out_proj, 'bias', bias_backup)
     else:
-        if isinstance(self, torch.nn.MultiheadAttention):
-            self.out_proj.bias = None
-        else:
-            self.bias = None
+        restore_weights_backup(self, 'bias', bias_backup)
 
 
 def network_apply_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn.GroupNorm, torch.nn.LayerNorm, torch.nn.MultiheadAttention]):
@@ -389,37 +424,38 @@ def network_apply_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn.GroupNorm, torch.nn.LayerNorm, torch.nn.MultiheadAttention]):
 
     weights_backup = getattr(self, "network_weights_backup", None)
     if weights_backup is None and wanted_names != ():
-        if current_names != ():
-            raise RuntimeError("no backup weights found and current weights are not unchanged")
+        if current_names != () and not allowed_layer_without_weight(self):
+            raise RuntimeError(f"{network_layer_name} - no backup weights found and current weights are not unchanged")
 
         if isinstance(self, torch.nn.MultiheadAttention):
-            weights_backup = (self.in_proj_weight.to(devices.cpu, copy=True), self.out_proj.weight.to(devices.cpu, copy=True))
+            weights_backup = (store_weights_backup(self.in_proj_weight), store_weights_backup(self.out_proj.weight))
         else:
-            weights_backup = self.weight.to(devices.cpu, copy=True)
+            weights_backup = store_weights_backup(self.weight)
 
         self.network_weights_backup = weights_backup
 
     bias_backup = getattr(self, "network_bias_backup", None)
     if bias_backup is None and wanted_names != ():
         if isinstance(self, torch.nn.MultiheadAttention) and self.out_proj.bias is not None:
-            bias_backup = self.out_proj.bias.to(devices.cpu, copy=True)
+            bias_backup = store_weights_backup(self.out_proj.bias)
         elif getattr(self, 'bias', None) is not None:
-            bias_backup = self.bias.to(devices.cpu, copy=True)
+            bias_backup = store_weights_backup(self.bias)
         else:
             bias_backup = None
 
         # Unlike weight which always has value, some modules don't have bias.
         # Only report if bias is not None and current bias are not unchanged.
         if bias_backup is not None and current_names != ():
            raise RuntimeError("no backup bias found and current bias are not unchanged")
+
         self.network_bias_backup = bias_backup
 
     if current_names != wanted_names:
         network_restore_weights_from_backup(self)
 
         for net in loaded_networks:
             module = net.modules.get(network_layer_name, None)
-            if module is not None and hasattr(self, 'weight'):
+            if module is not None and hasattr(self, 'weight') and not isinstance(module, modules.models.sd3.mmdit.QkvLinear):
                 try:
                     with torch.no_grad():
                         if getattr(self, 'fp16_weight', None) is None:
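
allowed_layer_without_weight exists because a LayerNorm built with elementwise_affine=False has no learnable parameters at all, so the "no backup weights found" error would be a false alarm for it. A quick illustration:

    import torch

    ln = torch.nn.LayerNorm(8, elementwise_affine=False)
    print(ln.weight, ln.bias)      # None None -- nothing to back up or restore

    ln_affine = torch.nn.LayerNorm(8)
    print(ln_affine.weight.shape)  # torch.Size([8])
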
@@ -479,6 +515,24 @@ def network_apply_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn.GroupNorm, torch.nn.LayerNorm, torch.nn.MultiheadAttention]):
 
                 continue
 
+            if isinstance(self, modules.models.sd3.mmdit.QkvLinear) and module_q and module_k and module_v:
+                try:
+                    with torch.no_grad():
+                        # Send "real" orig_weight into MHA's lora module
+                        qw, kw, vw = self.weight.chunk(3, 0)
+                        updown_q, _ = module_q.calc_updown(qw)
+                        updown_k, _ = module_k.calc_updown(kw)
+                        updown_v, _ = module_v.calc_updown(vw)
+                        del qw, kw, vw
+                        updown_qkv = torch.vstack([updown_q, updown_k, updown_v])
+                        self.weight += updown_qkv
+
+                except RuntimeError as e:
+                    logging.debug(f"Network {net.name} layer {network_layer_name}: {e}")
+                    extra_network_lora.errors[net.name] = extra_network_lora.errors.get(net.name, 0) + 1
+
+                continue
+
             if module is None:
                 continue
 
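
The fused-QKV path leans on chunk(3, 0) and vstack being exact inverses along dim 0: the q/k/v LoRA modules each compute a delta against their slice of the fused weight, and the deltas are stacked back into one update with the fused shape. A standalone sketch with illustrative shapes:

    import torch

    embed_dim = 4
    qkv_weight = torch.randn(3 * embed_dim, embed_dim)  # fused [q; k; v] weight

    # Split into per-projection weights, as each calc_updown sees them.
    qw, kw, vw = qkv_weight.chunk(3, 0)

    # Pretend each LoRA produced a delta of matching shape (zeros here).
    updown_q, updown_k, updown_v = map(torch.zeros_like, (qw, kw, vw))

    # Stack the deltas back into the fused layout.
    updown_qkv = torch.vstack([updown_q, updown_k, updown_v])
    assert updown_qkv.shape == qkv_weight.shape

    # chunk followed by vstack reproduces the original tensor exactly.
    assert torch.equal(torch.vstack(list(qkv_weight.chunk(3, 0))), qkv_weight)
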