Skip to content

Commit 2f70d94

Browse files
authored
Merge pull request AUTOMATIC1111#7 from DarioFT/master
Reduce memory usage when merging, and improve the UX.
2 parents 056ae19 + 854c686 commit 2f70d94

File tree

2 files changed

+55
-35
lines changed

2 files changed

+55
-35
lines changed

modules/extras.py

+53-33
Original file line numberDiff line numberDiff line change
@@ -249,7 +249,7 @@ def run_pnginfo(image):
249249
return '', geninfo, info
250250

251251

252-
def run_modelmerger(primary_model_name, secondary_model_name, teritary_model_name, interp_method, multiplier, save_as_half, custom_name):
252+
def run_modelmerger(primary_model_name, secondary_model_name, tertiary_model_name, interp_method, multiplier, save_as_half, custom_name):
253253
def weighted_sum(theta0, theta1, alpha):
254254
return ((1 - alpha) * theta0) + (alpha * theta1)
255255

@@ -259,49 +259,69 @@ def get_difference(theta1, theta2):
259259
def add_difference(theta0, theta1_2_diff, alpha):
260260
return theta0 + (alpha * theta1_2_diff)
261261

262+
theta_funcs = {
263+
"Weighted sum": (None, weighted_sum),
264+
"Add difference": (get_difference, add_difference),
265+
}
266+
267+
theta_func1, theta_func2 = theta_funcs[interp_method]
268+
269+
# Load info for A and B as they're always required.
262270
primary_model_info = sd_models.checkpoints_list[primary_model_name]
263271
secondary_model_info = sd_models.checkpoints_list[secondary_model_name]
264-
teritary_model_info = sd_models.checkpoints_list.get(teritary_model_name, None)
272+
b_loaded = False
265273

266-
print(f"Loading {primary_model_info.filename}...")
267-
primary_model = torch.load(primary_model_info.filename, map_location='cpu')
268-
theta_0 = sd_models.get_state_dict_from_checkpoint(primary_model)
274+
print(f"Interpolation method: {interp_method}")
275+
print(f"Merging (Step 1/2)...")
269276

270-
print(f"Loading {secondary_model_info.filename}...")
271-
secondary_model = torch.load(secondary_model_info.filename, map_location='cpu')
272-
theta_1 = sd_models.get_state_dict_from_checkpoint(secondary_model)
277+
if interp_method == "Add difference":
273278

274-
if teritary_model_info is not None:
275-
print(f"Loading {teritary_model_info.filename}...")
276-
teritary_model = torch.load(teritary_model_info.filename, map_location='cpu')
277-
theta_2 = sd_models.get_state_dict_from_checkpoint(teritary_model)
278-
else:
279-
teritary_model = None
280-
theta_2 = None
279+
if tertiary_model_name != "":
281280

282-
theta_funcs = {
283-
"Weighted sum": (None, weighted_sum),
284-
"Add difference": (get_difference, add_difference),
285-
}
286-
theta_func1, theta_func2 = theta_funcs[interp_method]
281+
# Load models B and C.
282+
print(f"Loading secondary model (B): {secondary_model_info.filename}...")
283+
secondary_model = torch.load(secondary_model_info.filename, map_location='cpu')
284+
theta_1 = sd_models.get_state_dict_from_checkpoint(secondary_model)
285+
b_loaded = True
287286

288-
print(f"Merging...")
287+
tertiary_model_info = sd_models.checkpoints_list.get(tertiary_model_name, None)
288+
if tertiary_model_info is not None:
289+
print(f"Loading tertiary model (C): {tertiary_model_info.filename}...")
290+
tertiary_model = torch.load(tertiary_model_info.filename, map_location='cpu')
291+
theta_2 = sd_models.get_state_dict_from_checkpoint(tertiary_model)
292+
else:
293+
tertiary_model = None
294+
theta_2 = None
295+
296+
if theta_func1:
297+
for key in tqdm.tqdm(theta_1.keys()):
298+
if 'model' in key:
299+
if key in theta_2:
300+
t2 = theta_2.get(key, torch.zeros_like(theta_1[key]))
301+
theta_1[key] = theta_func1(theta_1[key], t2)
302+
else:
303+
theta_1[key] = torch.zeros_like(theta_1[key])
304+
del theta_2, tertiary_model
305+
else:
306+
print(f"No model selected for C.")
307+
return ["Select a tertiary model (C) or consider using 'Weighted sum'"] + [gr.Dropdown.update(choices=sd_models.checkpoint_tiles()) for _ in range(4)]
289308

290-
if theta_func1:
291-
for key in tqdm.tqdm(theta_1.keys()):
292-
if 'model' in key:
293-
if key in theta_2:
294-
t2 = theta_2.get(key, torch.zeros_like(theta_1[key]))
295-
theta_1[key] = theta_func1(theta_1[key], t2)
296-
else:
297-
theta_1[key] = torch.zeros_like(theta_1[key])
298-
del theta_2, teritary_model
309+
# Load model A.
310+
print(f"Loading primary model (A): {primary_model_info.filename}...")
311+
primary_model = torch.load(primary_model_info.filename, map_location='cpu')
312+
theta_0 = sd_models.get_state_dict_from_checkpoint(primary_model)
313+
314+
# Load model B if we haven't loaded it yet to operate with C.
315+
if b_loaded == False:
316+
print(f"Loading secondary model (B): {secondary_model_info.filename}...")
317+
secondary_model = torch.load(secondary_model_info.filename, map_location='cpu')
318+
theta_1 = sd_models.get_state_dict_from_checkpoint(secondary_model)
319+
320+
print(f"Merging (Step 2/2)...")
299321

300322
for key in tqdm.tqdm(theta_0.keys()):
301323
if 'model' in key and key in theta_1:
302-
303324
theta_0[key] = theta_func2(theta_0[key], theta_1[key], multiplier)
304-
305325
if save_as_half:
306326
theta_0[key] = theta_0[key].half()
307327

@@ -324,4 +344,4 @@ def add_difference(theta0, theta1_2_diff, alpha):
324344
sd_models.list_models()
325345

326346
print(f"Checkpoint saved.")
327-
return ["Checkpoint saved to " + output_modelname] + [gr.Dropdown.update(choices=sd_models.checkpoint_tiles()) for _ in range(4)]
347+
return ["Checkpoint saved to " + output_modelname] + [gr.Dropdown.update(choices=sd_models.checkpoint_tiles()) for _ in range(4)]

modules/ui.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -1197,8 +1197,8 @@ def create_ui(wrap_gradio_gpu_call):
11971197
gr.HTML(value="<p>A merger of the two checkpoints will be generated in your <b>checkpoint</b> directory.</p>")
11981198

11991199
with gr.Row():
1200-
primary_model_name = gr.Dropdown(modules.sd_models.checkpoint_tiles(), elem_id="modelmerger_primary_model_name", label="Primary model (A)")
1201-
secondary_model_name = gr.Dropdown(modules.sd_models.checkpoint_tiles(), elem_id="modelmerger_secondary_model_name", label="Secondary model (B)")
1200+
primary_model_name = gr.Dropdown(modules.sd_models.checkpoint_tiles(), value=random.choice(modules.sd_models.checkpoint_tiles()), elem_id="modelmerger_primary_model_name", label="Primary model (A)")
1201+
secondary_model_name = gr.Dropdown(modules.sd_models.checkpoint_tiles(), value=random.choice(modules.sd_models.checkpoint_tiles()), elem_id="modelmerger_secondary_model_name", label="Secondary model (B)")
12021202
tertiary_model_name = gr.Dropdown(modules.sd_models.checkpoint_tiles(), elem_id="modelmerger_tertiary_model_name", label="Tertiary model (C)")
12031203
custom_name = gr.Textbox(label="Custom Name (Optional)")
12041204
interp_amount = gr.Slider(minimum=0.0, maximum=1.0, step=0.05, label='Multiplier (M) - set to 0 to get model A', value=0.3)

0 commit comments

Comments
 (0)