@@ -249,7 +249,7 @@ def run_pnginfo(image):
249
249
return '' , geninfo , info
250
250
251
251
252
- def run_modelmerger (primary_model_name , secondary_model_name , teritary_model_name , interp_method , multiplier , save_as_half , custom_name ):
252
+ def run_modelmerger (primary_model_name , secondary_model_name , tertiary_model_name , interp_method , multiplier , save_as_half , custom_name ):
253
253
def weighted_sum (theta0 , theta1 , alpha ):
254
254
return ((1 - alpha ) * theta0 ) + (alpha * theta1 )
255
255
@@ -259,49 +259,69 @@ def get_difference(theta1, theta2):
259
259
def add_difference (theta0 , theta1_2_diff , alpha ):
260
260
return theta0 + (alpha * theta1_2_diff )
261
261
262
+ theta_funcs = {
263
+ "Weighted sum" : (None , weighted_sum ),
264
+ "Add difference" : (get_difference , add_difference ),
265
+ }
266
+
267
+ theta_func1 , theta_func2 = theta_funcs [interp_method ]
268
+
269
+ # Load info for A and B as they're always required.
262
270
primary_model_info = sd_models .checkpoints_list [primary_model_name ]
263
271
secondary_model_info = sd_models .checkpoints_list [secondary_model_name ]
264
- teritary_model_info = sd_models . checkpoints_list . get ( teritary_model_name , None )
272
+ b_loaded = False
265
273
266
- print (f"Loading { primary_model_info .filename } ..." )
267
- primary_model = torch .load (primary_model_info .filename , map_location = 'cpu' )
268
- theta_0 = sd_models .get_state_dict_from_checkpoint (primary_model )
274
+ print (f"Interpolation method: { interp_method } " )
275
+ print (f"Merging (Step 1/2)..." )
269
276
270
- print (f"Loading { secondary_model_info .filename } ..." )
271
- secondary_model = torch .load (secondary_model_info .filename , map_location = 'cpu' )
272
- theta_1 = sd_models .get_state_dict_from_checkpoint (secondary_model )
277
+ if interp_method == "Add difference" :
273
278
274
- if teritary_model_info is not None :
275
- print (f"Loading { teritary_model_info .filename } ..." )
276
- teritary_model = torch .load (teritary_model_info .filename , map_location = 'cpu' )
277
- theta_2 = sd_models .get_state_dict_from_checkpoint (teritary_model )
278
- else :
279
- teritary_model = None
280
- theta_2 = None
279
+ if tertiary_model_name != "" :
281
280
282
- theta_funcs = {
283
- "Weighted sum" : ( None , weighted_sum ),
284
- "Add difference" : ( get_difference , add_difference ),
285
- }
286
- theta_func1 , theta_func2 = theta_funcs [ interp_method ]
281
+ # Load models B and C.
282
+ print ( f"Loading secondary model (B): { secondary_model_info . filename } ..." )
283
+ secondary_model = torch . load ( secondary_model_info . filename , map_location = 'cpu' )
284
+ theta_1 = sd_models . get_state_dict_from_checkpoint ( secondary_model )
285
+ b_loaded = True
287
286
288
- print (f"Merging..." )
287
+ tertiary_model_info = sd_models .checkpoints_list .get (tertiary_model_name , None )
288
+ if tertiary_model_info is not None :
289
+ print (f"Loading tertiary model (C): { tertiary_model_info .filename } ..." )
290
+ tertiary_model = torch .load (tertiary_model_info .filename , map_location = 'cpu' )
291
+ theta_2 = sd_models .get_state_dict_from_checkpoint (tertiary_model )
292
+ else :
293
+ tertiary_model = None
294
+ theta_2 = None
295
+
296
+ if theta_func1 :
297
+ for key in tqdm .tqdm (theta_1 .keys ()):
298
+ if 'model' in key :
299
+ if key in theta_2 :
300
+ t2 = theta_2 .get (key , torch .zeros_like (theta_1 [key ]))
301
+ theta_1 [key ] = theta_func1 (theta_1 [key ], t2 )
302
+ else :
303
+ theta_1 [key ] = torch .zeros_like (theta_1 [key ])
304
+ del theta_2 , tertiary_model
305
+ else :
306
+ print (f"No model selected for C." )
307
+ return ["Select a tertiary model (C) or consider using 'Weighted sum'" ] + [gr .Dropdown .update (choices = sd_models .checkpoint_tiles ()) for _ in range (4 )]
289
308
290
- if theta_func1 :
291
- for key in tqdm .tqdm (theta_1 .keys ()):
292
- if 'model' in key :
293
- if key in theta_2 :
294
- t2 = theta_2 .get (key , torch .zeros_like (theta_1 [key ]))
295
- theta_1 [key ] = theta_func1 (theta_1 [key ], t2 )
296
- else :
297
- theta_1 [key ] = torch .zeros_like (theta_1 [key ])
298
- del theta_2 , teritary_model
309
+ # Load model A.
310
+ print (f"Loading primary model (A): { primary_model_info .filename } ..." )
311
+ primary_model = torch .load (primary_model_info .filename , map_location = 'cpu' )
312
+ theta_0 = sd_models .get_state_dict_from_checkpoint (primary_model )
313
+
314
+ # Load model B if we haven't loaded it yet to operate with C.
315
+ if b_loaded == False :
316
+ print (f"Loading secondary model (B): { secondary_model_info .filename } ..." )
317
+ secondary_model = torch .load (secondary_model_info .filename , map_location = 'cpu' )
318
+ theta_1 = sd_models .get_state_dict_from_checkpoint (secondary_model )
319
+
320
+ print (f"Merging (Step 2/2)..." )
299
321
300
322
for key in tqdm .tqdm (theta_0 .keys ()):
301
323
if 'model' in key and key in theta_1 :
302
-
303
324
theta_0 [key ] = theta_func2 (theta_0 [key ], theta_1 [key ], multiplier )
304
-
305
325
if save_as_half :
306
326
theta_0 [key ] = theta_0 [key ].half ()
307
327
@@ -324,4 +344,4 @@ def add_difference(theta0, theta1_2_diff, alpha):
324
344
sd_models .list_models ()
325
345
326
346
print (f"Checkpoint saved." )
327
- return ["Checkpoint saved to " + output_modelname ] + [gr .Dropdown .update (choices = sd_models .checkpoint_tiles ()) for _ in range (4 )]
347
+ return ["Checkpoint saved to " + output_modelname ] + [gr .Dropdown .update (choices = sd_models .checkpoint_tiles ()) for _ in range (4 )]
0 commit comments