Skip to content

Commit 60c8177

Browse files
authored
Merge branch 'master' into master
2 parents 2ca56e3 + 95423e2 commit 60c8177

File tree

3 files changed

+54
-34
lines changed

3 files changed

+54
-34
lines changed

README.md

-13
Original file line numberDiff line numberDiff line change
@@ -133,16 +133,3 @@ the same effect. Use the --no-progressbar-hiding commandline option to revert th
133133
### Prompt validation
134134
Stable Diffusion has a limit for input text length. If your prompt is too long, you will get a
135135
warning in the text output field, showing which parts of your text were truncated and ignored by the model.
136-
137-
### Loopback
138-
A checkbox for img2img that automatically feeds the output image back in as input for the next batch. Equivalent to
139-
saving output image, and replacing input image with it. Batch count setting controls how many iterations of
140-
this you get.
141-
142-
Usually, when doing this, you would choose one of many images for the next iteration yourself, so the usefulness
143-
of this feature may be questionable, but I've managed to get some very nice outputs with it that I wasn't able
144-
to get otherwise.
145-
146-
Example: (cherrypicked result; original picture by anon)
147-
148-
![](images/loopback.jpg)

images/loopback.jpg

-465 KB
Binary file not shown.

webui.py

+54-21
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,6 @@
1717
from omegaconf import OmegaConf
1818
from PIL import Image, ImageFont, ImageDraw
1919
from torch import autocast
20-
2120
from ldm.models.diffusion.ddim import DDIMSampler
2221
from ldm.models.diffusion.plms import PLMSSampler
2322
from ldm.util import instantiate_from_config
@@ -53,8 +52,6 @@
5352
parser.add_argument("--no-verify-input", action='store_true', help="do not verify input to check if it's too long")
5453
parser.add_argument("--no-half", action='store_true', help="do not switch the model to 16-bit floats")
5554
parser.add_argument("--no-progressbar-hiding", action='store_true', help="do not hide progressbar in gradio UI (we hide it because it slows down ML if you have hardware accleration in browser)")
56-
parser.add_argument("--max-batch-count", type=int, default=16, help="maximum batch count value for the UI")
57-
parser.add_argument("--grid-format", type=str, default='png', help="file format for saved grids; can be png or jpg")
5855
opt = parser.parse_args()
5956

6057
GFPGAN_dir = opt.gfpgan_dir
@@ -162,6 +159,37 @@ def sample(self, S, conditioning, batch_size, shape, verbose, unconditional_guid
162159

163160
return samples_ddim, None
164161

162+
class MemUsageMonitor(threading.Thread):
    """Background thread that records the peak memory usage NVML reports.

    After ``start()``, the thread queries ``pynvml`` for the memory info of
    device index 0 every 0.1 seconds and keeps the largest ``used`` value it
    has seen, until ``stop()`` or ``read_and_stop()`` is called.
    """

    # Class-level defaults; each instance shadows them on first assignment.
    stop_flag = False  # set to True to ask run() to leave its polling loop
    max_usage = 0      # largest ``used`` value observed so far
    total = 0          # ``total`` value reported by NVML for the device

    def __init__(self, name):
        """Create the monitor; ``name`` is used as the thread name and log prefix."""
        threading.Thread.__init__(self)
        self.name = name

    def run(self):
        """Poll NVML until ``stop_flag`` is set, tracking the peak ``used`` value."""
        print(f"[{self.name}] Recording max memory usage...\n")
        pynvml.nvmlInit()
        try:
            handle = pynvml.nvmlDeviceGetHandleByIndex(0)
            self.total = pynvml.nvmlDeviceGetMemoryInfo(handle).total
            while not self.stop_flag:
                m = pynvml.nvmlDeviceGetMemoryInfo(handle)
                self.max_usage = max(self.max_usage, m.used)
                time.sleep(0.1)
            print(f"[{self.name}] Stopped recording.\n")
        finally:
            # Always release NVML, even if a query raised inside the loop;
            # previously an exception skipped the shutdown call entirely.
            pynvml.nvmlShutdown()

    def read(self):
        """Return ``(max_usage, total)`` as recorded so far, without stopping."""
        return self.max_usage, self.total

    def stop(self):
        """Ask the polling loop to exit; takes effect on its next iteration."""
        self.stop_flag = True

    def read_and_stop(self):
        """Request a stop and return ``(max_usage, total)`` as recorded so far."""
        self.stop()
        return self.read()
165193

166194
def create_random_tensors(shape, seeds):
167195
xs = []
@@ -209,10 +237,8 @@ def load_GFPGAN():
209237
model = (model if opt.no_half else model.half()).to(device)
210238

211239

212-
def image_grid(imgs, batch_size, round_down=False, force_n_rows=None):
213-
if force_n_rows is not None:
214-
rows = force_n_rows
215-
elif opt.n_rows > 0:
240+
def image_grid(imgs, batch_size, round_down=False):
241+
if opt.n_rows > 0:
216242
rows = opt.n_rows
217243
elif opt.n_rows == 0:
218244
rows = batch_size
@@ -351,14 +377,15 @@ def check_prompt_length(prompt, comments):
351377
comments.append(f"Warning: too many input tokens; some ({len(overflowing_words)}) have been truncated:\n{overflowing_text}\n")
352378

353379

354-
def process_images(outpath, func_init, func_sample, prompt, seed, sampler_name, batch_size, n_iter, steps, cfg_scale, width, height, prompt_matrix, use_GFPGAN, do_not_save_grid=False):
380+
def process_images(outpath, func_init, func_sample, prompt, seed, sampler_name, batch_size, n_iter, steps, cfg_scale, width, height, prompt_matrix, use_GFPGAN):
355381
"""this is the main loop that both txt2img and img2img use; it calls func_init once inside all the scopes and func_sample once per batch"""
356-
mem_mon = MemUsageMonitor('MemMon')
357-
mem_mon.start()
382+
assert prompt is not None
383+
torch_gc()
384+
# start time after garbage collection (or before?)
358385
start_time = time.time()
359386

360-
assert prompt is not None
361-
torch.cuda.empty_cache()
387+
mem_mon = MemUsageMonitor('MemMon')
388+
mem_mon.start()
362389

363390
if seed == -1:
364391
seed = random.randrange(4294967294)
@@ -447,7 +474,7 @@ def process_images(outpath, func_init, func_sample, prompt, seed, sampler_name,
447474
output_images.append(image)
448475
base_count += 1
449476

450-
if (prompt_matrix or not opt.skip_grid) and not do_not_save_grid:
477+
if prompt_matrix or not opt.skip_grid:
451478
grid = image_grid(output_images, batch_size, round_down=prompt_matrix)
452479

453480
if prompt_matrix:
@@ -460,13 +487,18 @@ def process_images(outpath, func_init, func_sample, prompt, seed, sampler_name,
460487

461488
output_images.insert(0, grid)
462489

463-
grid.save(os.path.join(outpath, f'grid-{grid_count:04}.{opt.grid_format}'))
490+
491+
grid_file = f"grid-{grid_count:05}-{seed}_{prompts[i].replace(' ', '_').translate({ord(x): '' for x in invalid_filename_chars})[:128]}.jpg"
492+
grid.save(os.path.join(outpath, grid_file), 'jpeg', quality=80, optimize=True)
464493
grid_count += 1
465494
toc = time.time()
466495

467496
mem_max_used, mem_total = mem_mon.read_and_stop()
468497
time_diff = time.time()-start_time
469498

499+
mem_max_used, mem_total = mem_mon.read_and_stop()
500+
time_diff = time.time()-start_time
501+
470502
info = f"""
471503
{prompt}
472504
Steps: {steps}, Sampler: {sampler_name}, CFG scale: {cfg_scale}, Seed: {seed}{', GFPGAN' if use_GFPGAN and GFPGAN is not None else ''}{', Prompt Matrix Mode.' if prompt_matrix else ''}""".strip()
@@ -480,8 +512,10 @@ def process_images(outpath, func_init, func_sample, prompt, seed, sampler_name,
480512
#mem_mon.stop()
481513
#del mem_mon
482514
torch_gc()
515+
483516
return output_images, seed, info, stats
484517

518+
485519
def txt2img(prompt: str, ddim_steps: int, sampler_name: str, use_GFPGAN: bool, prompt_matrix: bool, ddim_eta: float, n_iter: int, batch_size: int, cfg_scale: float, seed: int, height: int, width: int):
486520
outpath = opt.outdir or "outputs/txt2img-samples"
487521
err = False
@@ -532,7 +566,6 @@ def sample(init_data, x, conditioning, unconditional_conditioning):
532566
crash(err, '!!Runtime error (txt2img)!!')
533567

534568

535-
536569
class Flagging(gr.FlaggingCallback):
537570

538571
def setup(self, components, flagging_dir: str):
@@ -583,7 +616,7 @@ def flag(self, flag_data, flag_option=None, flag_index=None, username=None):
583616
gr.Checkbox(label='Fix faces using GFPGAN', value=False, visible=GFPGAN is not None),
584617
gr.Checkbox(label='Create prompt matrix (separate multiple prompts using |, and get all combinations of them)', value=False),
585618
gr.Slider(minimum=0.0, maximum=1.0, step=0.01, label="DDIM ETA", value=0.0, visible=False),
586-
gr.Slider(minimum=1, maximum=opt.max_batch_count, step=1, label='Batch count (how many batches of images to generate)', value=1),
619+
gr.Slider(minimum=1, maximum=16, step=1, label='Batch count (how many batches of images to generate)', value=1),
587620
gr.Slider(minimum=1, maximum=8, step=1, label='Batch size (how many images are in a batch; memory-hungry)', value=1),
588621
gr.Slider(minimum=1.0, maximum=15.0, step=0.5, label='Classifier Free Guidance Scale (how strongly the image should follow the prompt)', value=7.0),
589622
gr.Number(label='Seed', value=-1),
@@ -602,13 +635,14 @@ def flag(self, flag_data, flag_option=None, flag_index=None, username=None):
602635
)
603636

604637

605-
def img2img(prompt: str, init_img, ddim_steps: int, use_GFPGAN: bool, prompt_matrix, loopback: bool, n_iter: int, batch_size: int, cfg_scale: float, denoising_strength: float, seed: int, height: int, width: int, resize_mode: int):
638+
def img2img(prompt: str, init_img, ddim_steps: int, use_GFPGAN: bool, prompt_matrix, n_iter: int, batch_size: int, cfg_scale: float, denoising_strength: float, seed: int, height: int, width: int, resize_mode: int):
606639
outpath = opt.outdir or "outputs/img2img-samples"
607640
err = False
608641

609642
sampler = KDiffusionSampler(model)
610643

611644
assert 0. <= denoising_strength <= 1., 'can only work with strength in [0.0, 1.0]'
645+
t_enc = int(denoising_strength * ddim_steps)
612646

613647
def init():
614648
image = init_img.convert("RGB")
@@ -625,8 +659,6 @@ def init():
625659
return init_latent,
626660

627661
def sample(init_data, x, conditioning, unconditional_conditioning):
628-
t_enc = int(denoising_strength * ddim_steps)
629-
630662
x0, = init_data
631663

632664
sigmas = sampler.model_wrap.get_sigmas(ddim_steps)
@@ -638,6 +670,7 @@ def sample(init_data, x, conditioning, unconditional_conditioning):
638670
samples_ddim = K.sampling.sample_lms(model_wrap_cfg, xi, sigma_sched, extra_args={'cond': conditioning, 'uncond': unconditional_conditioning, 'cond_scale': cfg_scale}, disable=False)
639671
return samples_ddim
640672

673+
641674
try:
642675
if loopback:
643676
output_images, info = None, None
@@ -709,6 +742,7 @@ def sample(init_data, x, conditioning, unconditional_conditioning):
709742
crash(err, '!!Runtime error (img2img)!!')
710743

711744

745+
712746
sample_img2img = "assets/stable-samples/img2img/sketch-mountains-input.jpg"
713747
sample_img2img = sample_img2img if os.path.exists(sample_img2img) else None
714748

@@ -720,8 +754,7 @@ def sample(init_data, x, conditioning, unconditional_conditioning):
720754
gr.Slider(minimum=1, maximum=150, step=1, label="Sampling Steps", value=50),
721755
gr.Checkbox(label='Fix faces using GFPGAN', value=False, visible=GFPGAN is not None),
722756
gr.Checkbox(label='Create prompt matrix (separate multiple prompts using |, and get all combinations of them)', value=False),
723-
gr.Checkbox(label='Loopback (use images from previous batch when creating next batch)', value=False),
724-
gr.Slider(minimum=1, maximum=opt.max_batch_count, step=1, label='Batch count (how many batches of images to generate)', value=1),
757+
gr.Slider(minimum=1, maximum=16, step=1, label='Batch count (how many batches of images to generate)', value=1),
725758
gr.Slider(minimum=1, maximum=8, step=1, label='Batch size (how many images are in a batch; memory-hungry)', value=1),
726759
gr.Slider(minimum=1.0, maximum=15.0, step=0.5, label='Classifier Free Guidance Scale (how strongly the image should follow the prompt)', value=7.0),
727760
gr.Slider(minimum=0.0, maximum=1.0, step=0.01, label='Denoising Strength', value=0.75),

0 commit comments

Comments
 (0)