@@ -29,7 +29,6 @@ def __init__(self, name, filename):
29
29
30
30
def read_metadata():
    # Return the raw metadata dict stored in the safetensors header of
    # this network file (closure over `filename`).
    return sd_models.read_metadata_from_safetensors(filename)
35
34
@@ -117,6 +116,12 @@ def __init__(self, net: Network, weights: NetworkWeights):
117
116
118
117
if hasattr (self .sd_module , 'weight' ):
119
118
self .shape = self .sd_module .weight .shape
119
+ elif isinstance (self .sd_module , nn .MultiheadAttention ):
120
+ # For now, only self-attn use Pytorch's MHA
121
+ # So assume all qkvo proj have same shape
122
+ self .shape = self .sd_module .out_proj .weight .shape
123
+ else :
124
+ self .shape = None
120
125
121
126
self .ops = None
122
127
self .extra_kwargs = {}
@@ -146,6 +151,9 @@ def __init__(self, net: Network, weights: NetworkWeights):
146
151
self .alpha = weights .w ["alpha" ].item () if "alpha" in weights .w else None
147
152
self .scale = weights .w ["scale" ].item () if "scale" in weights .w else None
148
153
154
+ self .dora_scale = weights .w .get ("dora_scale" , None )
155
+ self .dora_norm_dims = len (self .shape ) - 1
156
+
149
157
def multiplier (self ):
150
158
if 'transformer' in self .sd_key [:20 ]:
151
159
return self .network .te_multiplier
@@ -160,6 +168,27 @@ def calc_scale(self):
160
168
161
169
return 1.0
162
170
171
def apply_weight_decompose(self, updown, orig_weight):
    """DoRA weight decomposition.

    Rescales the merged weight (orig_weight + updown) so each input-channel
    slice has magnitude given by `self.dora_scale`, then returns the delta
    relative to the original weight.
    """
    # Move everything onto the update's dtype and the original weight's device.
    orig_weight = orig_weight.to(updown.dtype)
    dora_scale = self.dora_scale.to(device=orig_weight.device, dtype=updown.dtype)
    updown = updown.to(orig_weight.device)

    merged = orig_weight + updown

    # Per-input-channel L2 norm of the merged weight, reshaped so it
    # broadcasts back against `merged` (dim 1 is the input-channel axis).
    flattened = merged.transpose(0, 1).reshape(merged.shape[1], -1)
    channel_norm = flattened.norm(dim=1, keepdim=True)
    channel_norm = (
        channel_norm
        .reshape(merged.shape[1], *[1] * self.dora_norm_dims)
        .transpose(0, 1)
    )

    # Renormalize to the learned DoRA magnitude, then express as a delta.
    renormalized = merged * (dora_scale / channel_norm)
    return renormalized - orig_weight
191
+
163
192
def finalize_updown (self , updown , orig_weight , output_shape , ex_bias = None ):
164
193
if self .bias is not None :
165
194
updown = updown .reshape (self .bias .shape )
@@ -175,6 +204,9 @@ def finalize_updown(self, updown, orig_weight, output_shape, ex_bias=None):
175
204
if ex_bias is not None :
176
205
ex_bias = ex_bias * self .multiplier ()
177
206
207
+ if self .dora_scale is not None :
208
+ updown = self .apply_weight_decompose (updown , orig_weight )
209
+
178
210
return updown * self .calc_scale () * self .multiplier (), ex_bias
179
211
180
212
def calc_updown (self , target ):
0 commit comments