Skip to content

Commit 2b50233

Browse files
committed
fix bugs in lora support
1 parent 7e5cdaa commit 2b50233

File tree

1 file changed

+2
-2
lines changed

1 file changed

+2
-2
lines changed

extensions-builtin/Lora/networks.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -398,7 +398,7 @@ def network_restore_weights_from_backup(self: Union[torch.nn.Conv2d, torch.nn.Li
398398
if weights_backup is not None:
399399
if isinstance(self, torch.nn.MultiheadAttention):
400400
restore_weights_backup(self, 'in_proj_weight', weights_backup[0])
401-
restore_weights_backup(self.out_proj, 'weight', weights_backup[0])
401+
restore_weights_backup(self.out_proj, 'weight', weights_backup[1])
402402
else:
403403
restore_weights_backup(self, 'weight', weights_backup)
404404

@@ -437,7 +437,7 @@ def network_apply_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn
437437
bias_backup = getattr(self, "network_bias_backup", None)
438438
if bias_backup is None and wanted_names != ():
439439
if isinstance(self, torch.nn.MultiheadAttention) and self.out_proj.bias is not None:
440-
bias_backup = store_weights_backup(self.out_proj)
440+
bias_backup = store_weights_backup(self.out_proj.bias)
441441
elif getattr(self, 'bias', None) is not None:
442442
bias_backup = store_weights_backup(self.bias)
443443
else:

0 commit comments

Comments
 (0)