Commit 0ed0dd7

Merge pull request #1 from ShinkoNet/master
Memory Patch
2 parents 95423e2 + 60c8177

File tree: 2 files changed, +190 −67 lines

relauncher.py (+11)

```diff
@@ -0,0 +1,11 @@
+import os, time
+
+n = 0
+while True:
+    print('Relauncher: Launching...')
+    if n > 0:
+        print(f'\tRelaunch count: {n}')
+    os.system("python scripts/webui.py")
+    print('Relauncher: Process ending. Relaunching in 0.5s...')
+    n += 1
+    time.sleep(0.5)
```
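relauncher.py is a watchdog: `os.system` blocks until webui.py exits (for instance after the new `crash()` helper below calls `os._exit(0)`), then the loop restarts it after half a second. A rough equivalent using subprocess instead of a shell call, offered only as a sketch (the script path is the one from the diff; everything else is illustrative):

```python
# Sketch of the same relaunch loop using subprocess instead of os.system.
import subprocess, sys, time

relaunch_count = 0
while True:
    print('Relauncher: Launching...')
    if relaunch_count > 0:
        print(f'\tRelaunch count: {relaunch_count}')
    # Blocks until the child exits (e.g. after crash() calls os._exit(0)).
    subprocess.run([sys.executable, "scripts/webui.py"])
    print('Relauncher: Process ending. Relaunching in 0.5s...')
    relaunch_count += 1
    time.sleep(0.5)
```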

webui.py (+179 −67)

```diff
@@ -1,25 +1,25 @@
 import argparse, os, sys, glob
-import torch
-import torch.nn as nn
-import numpy as np
 import gradio as gr
-from omegaconf import OmegaConf
-from PIL import Image, ImageFont, ImageDraw
-from itertools import islice
-from einops import rearrange, repeat
-from torch import autocast
-from contextlib import contextmanager, nullcontext
-import mimetypes
-import random
+import k_diffusion as K
 import math
+import mimetypes
+import numpy as np
 import pynvml
+import random
 import threading
 import time
+import torch
+import torch.nn as nn
 
-import k_diffusion as K
-from ldm.util import instantiate_from_config
+from contextlib import contextmanager, nullcontext
+from einops import rearrange, repeat
+from itertools import islice
+from omegaconf import OmegaConf
+from PIL import Image, ImageFont, ImageDraw
+from torch import autocast
 from ldm.models.diffusion.ddim import DDIMSampler
 from ldm.models.diffusion.plms import PLMSSampler
+from ldm.util import instantiate_from_config
 
 try:
     # this silences the annoying "Some weights of the model checkpoint were not used when initializing..." message at start.
```
```diff
@@ -87,6 +87,50 @@ def load_model_from_config(config, ckpt, verbose=False):
     model.eval()
     return model
 
+def crash(e, s):
+    global model
+    global device
+
+    print(s, '\n', e)
+
+    del model
+    del device
+
+    print('exiting...calling os._exit(0)')
+    t = threading.Timer(0.25, os._exit, args=[0])
+    t.start()
+
+class MemUsageMonitor(threading.Thread):
+    stop_flag = False
+    max_usage = 0
+    total = 0
+
+    def __init__(self, name):
+        threading.Thread.__init__(self)
+        self.name = name
+
+    def run(self):
+        print(f"[{self.name}] Recording max memory usage...\n")
+        pynvml.nvmlInit()
+        handle = pynvml.nvmlDeviceGetHandleByIndex(0)
+        self.total = pynvml.nvmlDeviceGetMemoryInfo(handle).total
+        while not self.stop_flag:
+            m = pynvml.nvmlDeviceGetMemoryInfo(handle)
+            self.max_usage = max(self.max_usage, m.used)
+            # print(self.max_usage)
+            time.sleep(0.1)
+        print(f"[{self.name}] Stopped recording.\n")
+        pynvml.nvmlShutdown()
+
+    def read(self):
+        return self.max_usage, self.total
+
+    def stop(self):
+        self.stop_flag = True
+
+    def read_and_stop(self):
+        self.stop_flag = True
+        return self.max_usage, self.total
 
 class CFGDenoiser(nn.Module):
     def __init__(self, model):
```
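MemUsageMonitor polls NVML every 100 ms on a background thread and keeps the peak of the device's used bytes; process_images later collects it with `read_and_stop()`. A minimal standalone sketch of that flow, assuming the class above is in scope, pynvml is installed, and an NVIDIA GPU is present (the tensor allocation is only a stand-in workload; the thread name is illustrative):

```python
# Sketch only: assumes MemUsageMonitor (defined above), pynvml, and a CUDA GPU.
import torch

mem_mon = MemUsageMonitor('txt2img')  # name appears in the log lines it prints
mem_mon.start()

x = torch.zeros(1024, 1024, 256, device='cuda')  # stand-in workload (~1 GiB)

mem_max_used, mem_total = mem_mon.read_and_stop()
print(f"peak: {-(mem_max_used // -1_048_576)} MiB / {-(mem_total // -1_048_576)} MiB")
```

Note that `stop_flag`, `max_usage` and `total` are class attributes rather than instance attributes, so a fresh monitor instance is best created for each measurement.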
```diff
@@ -389,8 +433,10 @@ def process_images(outpath, func_init, func_sample, prompt, seed, sampler_name,
 
     precision_scope = autocast if opt.precision == "autocast" else nullcontext
     output_images = []
+    stats = []
    with torch.no_grad(), precision_scope("cuda"), model.ema_scope():
         init_data = func_init()
+        tic = time.time()
 
         for n in range(n_iter):
             prompts = all_prompts[n * batch_size:(n + 1) * batch_size]
@@ -432,7 +478,6 @@ def process_images(outpath, func_init, func_sample, prompt, seed, sampler_name,
         grid = image_grid(output_images, batch_size, round_down=prompt_matrix)
 
         if prompt_matrix:
-
             try:
                 grid = draw_prompt_matrix(grid, width, height, prompt_matrix_parts)
             except Exception:
@@ -442,31 +487,38 @@ def process_images(outpath, func_init, func_sample, prompt, seed, sampler_name,
 
         output_images.insert(0, grid)
 
+
         grid_file = f"grid-{grid_count:05}-{seed}_{prompts[i].replace(' ', '_').translate({ord(x): '' for x in invalid_filename_chars})[:128]}.jpg"
         grid.save(os.path.join(outpath, grid_file), 'jpeg', quality=80, optimize=True)
         grid_count += 1
+        toc = time.time()
 
     mem_max_used, mem_total = mem_mon.read_and_stop()
     time_diff = time.time()-start_time
 
-    notes = f'''
-Took { round(time_diff, 2) }s total ({ round(time_diff/(batch_size*n_iter),2) }s per image)<br>
-Peak memory usage: { -(mem_max_used // -1_048_576) } MiB / { -(mem_total // -1_048_576) } MiB / { round(mem_max_used/mem_total*100, 3) }%<br>
-'''
+    mem_max_used, mem_total = mem_mon.read_and_stop()
+    time_diff = time.time()-start_time
 
     info = f"""
 {prompt}
-Steps: {steps}, Sampler: {sampler_name}, CFG scale: {cfg_scale}, Seed: {seed}{', GFPGAN' if use_GFPGAN and GFPGAN is not None else ''}
-""".strip()
-
+Steps: {steps}, Sampler: {sampler_name}, CFG scale: {cfg_scale}, Seed: {seed}{', GFPGAN' if use_GFPGAN and GFPGAN is not None else ''}{', Prompt Matrix Mode.' if prompt_matrix else ''}""".strip()
+    stats = f'''
+Took { round(time_diff, 2) }s total ({ round(time_diff/(len(all_prompts)),2) }s per image)
+Peak memory usage: { -(mem_max_used // -1_048_576) } MiB / { -(mem_total // -1_048_576) } MiB / { round(mem_max_used/mem_total*100, 3) }%'''
+
     for comment in comments:
         info += "\n\n" + comment
+
+    #mem_mon.stop()
+    #del mem_mon
     torch_gc()
-    return output_images, seed, info, notes
+
+    return output_images, seed, info, stats
 
 
 def txt2img(prompt: str, ddim_steps: int, sampler_name: str, use_GFPGAN: bool, prompt_matrix: bool, ddim_eta: float, n_iter: int, batch_size: int, cfg_scale: float, seed: int, height: int, width: int):
     outpath = opt.outdir or "outputs/txt2img-samples"
+    err = False
 
     if sampler_name == 'PLMS':
         sampler = PLMSSampler(model)
```
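The stats string rounds bytes up to whole MiB with `-(n // -1_048_576)`: Python's `//` floors, so flipping the sign twice yields ceiling division. A quick self-check of that identity (stdlib only):

```python
# -(n // -d) is ceiling division: // floors, so negating twice rounds up.
def mib_ceil(n_bytes: int) -> int:
    return -(n_bytes // -1_048_576)

assert mib_ceil(1)         == 1  # a single byte still counts as a full MiB
assert mib_ceil(1_048_576) == 1  # exactly 1 MiB
assert mib_ceil(1_048_577) == 2  # one byte over rounds up
```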
```diff
@@ -483,27 +535,35 @@ def init():
     def sample(init_data, x, conditioning, unconditional_conditioning):
         samples_ddim, _ = sampler.sample(S=ddim_steps, conditioning=conditioning, batch_size=int(x.shape[0]), shape=x[0].shape, verbose=False, unconditional_guidance_scale=cfg_scale, unconditional_conditioning=unconditional_conditioning, eta=ddim_eta, x_T=x)
         return samples_ddim
-
-    output_images, seed, info, notes = process_images(
-        outpath=outpath,
-        func_init=init,
-        func_sample=sample,
-        prompt=prompt,
-        seed=seed,
-        sampler_name=sampler_name,
-        batch_size=batch_size,
-        n_iter=n_iter,
-        steps=ddim_steps,
-        cfg_scale=cfg_scale,
-        width=width,
-        height=height,
-        prompt_matrix=prompt_matrix,
-        use_GFPGAN=use_GFPGAN
-    )
-
-    del sampler
-
-    return output_images, seed, info, notes
+    try:
+        output_images, seed, info, stats = process_images(
+            outpath=outpath,
+            func_init=init,
+            func_sample=sample,
+            prompt=prompt,
+            seed=seed,
+            sampler_name=sampler_name,
+            batch_size=batch_size,
+            n_iter=n_iter,
+            steps=ddim_steps,
+            cfg_scale=cfg_scale,
+            width=width,
+            height=height,
+            prompt_matrix=prompt_matrix,
+            use_GFPGAN=use_GFPGAN
+        )
+
+        del sampler
+
+        return output_images, seed, info, stats
+    except RuntimeError as e:
+        err = e
+        err_msg = f'CRASHED:<br><textarea rows="5" style="background: black;width: -webkit-fill-available;font-family: monospace;font-size: small;font-weight: bold;">{str(e)}</textarea><br><br>Please wait while the program restarts.'
+        stats = err_msg
+        return [], 1
+    finally:
+        if err:
+            crash(err, '!!Runtime error (txt2img)!!')
 
 
 class Flagging(gr.FlaggingCallback):
```
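Both txt2img and img2img now share the same recovery pattern: a RuntimeError (typically CUDA out-of-memory) becomes an HTML error box in `stats`, and the `finally` block hands the exception to `crash()`, whose 0.25 s `threading.Timer` lets the return value reach Gradio before `os._exit(0)` kills the process for relauncher.py to restart. A condensed, generic sketch of that control flow (`work()` and the other names here are placeholders, not webui.py identifiers):

```python
# Generic sketch of the crash-and-relaunch pattern; work() is a placeholder.
import os, threading

def guarded(work):
    err = False
    try:
        return work()
    except RuntimeError as e:
        err = e
        return [], 1  # minimal outputs so the UI can still render the error
    finally:
        if err:
            # Exit 0.25s later so the return value above is delivered first.
            threading.Timer(0.25, os._exit, args=[0]).start()
```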
```diff
@@ -567,16 +627,17 @@ def flag(self, flag_data, flag_option=None, flag_index=None, username=None):
         gr.Gallery(label="Images"),
         gr.Number(label='Seed'),
         gr.Textbox(label="Copy-paste generation parameters"),
-        gr.HTML(label='Notes'),
+        gr.HTML(label='Stats'),
     ],
-    title="Stable Diffusion Text-to-Image K",
-    description="Generate images from text with Stable Diffusion (using K-LMS)",
+    title="Stable Diffusion Text-to-Image Unified",
+    description="Generate images from text with Stable Diffusion",
     flagging_callback=Flagging()
 )
 
 
 def img2img(prompt: str, init_img, ddim_steps: int, use_GFPGAN: bool, prompt_matrix, n_iter: int, batch_size: int, cfg_scale: float, denoising_strength: float, seed: int, height: int, width: int, resize_mode: int):
     outpath = opt.outdir or "outputs/img2img-samples"
+    err = False
 
     sampler = KDiffusionSampler(model)
 
```
```diff
@@ -609,26 +670,77 @@ def sample(init_data, x, conditioning, unconditional_conditioning):
         samples_ddim = K.sampling.sample_lms(model_wrap_cfg, xi, sigma_sched, extra_args={'cond': conditioning, 'uncond': unconditional_conditioning, 'cond_scale': cfg_scale}, disable=False)
         return samples_ddim
 
-    output_images, seed, info, notes = process_images(
-        outpath=outpath,
-        func_init=init,
-        func_sample=sample,
-        prompt=prompt,
-        seed=seed,
-        sampler_name='k-diffusion',
-        batch_size=batch_size,
-        n_iter=n_iter,
-        steps=ddim_steps,
-        cfg_scale=cfg_scale,
-        width=width,
-        height=height,
-        prompt_matrix=prompt_matrix,
-        use_GFPGAN=use_GFPGAN
-    )
 
-    del sampler
+    try:
+        if loopback:
+            output_images, info = None, None
+            history = []
+            initial_seed = None
+
+            for i in range(n_iter):
+                output_images, seed, info, stats = process_images(
+                    outpath=outpath,
+                    func_init=init,
+                    func_sample=sample,
+                    prompt=prompt,
+                    seed=seed,
+                    sampler_name='k-diffusion',
+                    batch_size=1,
+                    n_iter=1,
+                    steps=ddim_steps,
+                    cfg_scale=cfg_scale,
+                    width=width,
+                    height=height,
+                    prompt_matrix=prompt_matrix,
+                    use_GFPGAN=use_GFPGAN,
+                    do_not_save_grid=True
+                )
+
+                if initial_seed is None:
+                    initial_seed = seed
+
+                init_img = output_images[0]
+                seed = seed + 1
+                denoising_strength = max(denoising_strength * 0.95, 0.1)
+                history.append(init_img)
+
+            grid_count = len(os.listdir(outpath)) - 1
+            grid = image_grid(history, batch_size, force_n_rows=1)
+            grid.save(os.path.join(outpath, f'grid-{grid_count:04}.{opt.grid_format}'))
+
+            output_images = history
+            seed = initial_seed
+
+        else:
+            output_images, seed, info, stats = process_images(
+                outpath=outpath,
+                func_init=init,
+                func_sample=sample,
+                prompt=prompt,
+                seed=seed,
+                sampler_name='k-diffusion',
+                batch_size=batch_size,
+                n_iter=n_iter,
+                steps=ddim_steps,
+                cfg_scale=cfg_scale,
+                width=width,
+                height=height,
+                prompt_matrix=prompt_matrix,
+                use_GFPGAN=use_GFPGAN
+            )
+
+        del sampler
+
+        return output_images, seed, info, stats
+    except RuntimeError as e:
+        err = e
+        err_msg = f'CRASHED:<br><textarea rows="5" style="background: black;width: -webkit-fill-available;font-family: monospace;font-size: small;font-weight: bold;">{str(e)}</textarea><br><br>Please wait while the program restarts.'
+        stats = err_msg
+        return [], 1
+    finally:
+        if err:
+            crash(err, '!!Runtime error (img2img)!!')
 
-    return output_images, seed, info, notes
 
 
 sample_img2img = "assets/stable-samples/img2img/sketch-mountains-input.jpg"
```
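The new loopback branch chains img2img passes: each iteration's first output becomes the next `init_img`, the seed increments, and `denoising_strength` decays by 5% per pass with a floor of 0.1, so later passes drift less from the accumulated image. The schedule on its own, with an illustrative starting value (webui.py takes it from the UI):

```python
# Loopback denoising-strength decay from the branch above.
d = 0.75  # example starting strength
for i in range(8):
    d = max(d * 0.95, 0.1)
    print(f"pass {i + 1}: denoising_strength = {d:.3f}")
```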
```diff
@@ -655,9 +767,9 @@ def sample(init_data, x, conditioning, unconditional_conditioning):
         gr.Gallery(),
         gr.Number(label='Seed'),
         gr.Textbox(label="Copy-paste generation parameters"),
-        gr.HTML(label='Notes'),
+        gr.HTML(label='Stats'),
     ],
-    title="Stable Diffusion Image-to-Image",
+    title="Stable Diffusion Image-to-Image Unified",
     description="Generate images from images with Stable Diffusion",
     allow_flagging="never",
 )
@@ -700,4 +812,4 @@ def run_GFPGAN(image, strength):
     css=("" if opt.no_progressbar_hiding else css_hide_progressbar)
 )
 
-demo.launch()
+demo.launch()
```
