17
17
from omegaconf import OmegaConf
18
18
from PIL import Image , ImageFont , ImageDraw
19
19
from torch import autocast
20
-
21
20
from ldm .models .diffusion .ddim import DDIMSampler
22
21
from ldm .models .diffusion .plms import PLMSSampler
23
22
from ldm .util import instantiate_from_config
53
52
parser .add_argument ("--no-verify-input" , action = 'store_true' , help = "do not verify input to check if it's too long" )
54
53
parser .add_argument ("--no-half" , action = 'store_true' , help = "do not switch the model to 16-bit floats" )
55
54
parser .add_argument ("--no-progressbar-hiding" , action = 'store_true' , help = "do not hide progressbar in gradio UI (we hide it because it slows down ML if you have hardware accleration in browser)" )
56
- parser .add_argument ("--max-batch-count" , type = int , default = 16 , help = "maximum batch count value for the UI" )
57
- parser .add_argument ("--grid-format" , type = str , default = 'png' , help = "file format for saved grids; can be png or jpg" )
58
55
opt = parser .parse_args ()
59
56
60
57
GFPGAN_dir = opt .gfpgan_dir
@@ -162,6 +159,37 @@ def sample(self, S, conditioning, batch_size, shape, verbose, unconditional_guid
162
159
163
160
return samples_ddim , None
164
161
162
+ class MemUsageMonitor (threading .Thread ):
163
+ stop_flag = False
164
+ max_usage = 0
165
+ total = 0
166
+
167
+ def __init__ (self , name ):
168
+ threading .Thread .__init__ (self )
169
+ self .name = name
170
+
171
+ def run (self ):
172
+ print (f"[{ self .name } ] Recording max memory usage...\n " )
173
+ pynvml .nvmlInit ()
174
+ handle = pynvml .nvmlDeviceGetHandleByIndex (0 )
175
+ self .total = pynvml .nvmlDeviceGetMemoryInfo (handle ).total
176
+ while not self .stop_flag :
177
+ m = pynvml .nvmlDeviceGetMemoryInfo (handle )
178
+ self .max_usage = max (self .max_usage , m .used )
179
+ # print(self.max_usage)
180
+ time .sleep (0.1 )
181
+ print (f"[{ self .name } ] Stopped recording.\n " )
182
+ pynvml .nvmlShutdown ()
183
+
184
+ def read (self ):
185
+ return self .max_usage , self .total
186
+
187
+ def stop (self ):
188
+ self .stop_flag = True
189
+
190
+ def read_and_stop (self ):
191
+ self .stop_flag = True
192
+ return self .max_usage , self .total
165
193
166
194
def create_random_tensors (shape , seeds ):
167
195
xs = []
@@ -209,10 +237,8 @@ def load_GFPGAN():
209
237
model = (model if opt .no_half else model .half ()).to (device )
210
238
211
239
212
- def image_grid (imgs , batch_size , round_down = False , force_n_rows = None ):
213
- if force_n_rows is not None :
214
- rows = force_n_rows
215
- elif opt .n_rows > 0 :
240
+ def image_grid (imgs , batch_size , round_down = False ):
241
+ if opt .n_rows > 0 :
216
242
rows = opt .n_rows
217
243
elif opt .n_rows == 0 :
218
244
rows = batch_size
@@ -351,14 +377,15 @@ def check_prompt_length(prompt, comments):
351
377
comments .append (f"Warning: too many input tokens; some ({ len (overflowing_words )} ) have been truncated:\n { overflowing_text } \n " )
352
378
353
379
354
- def process_images (outpath , func_init , func_sample , prompt , seed , sampler_name , batch_size , n_iter , steps , cfg_scale , width , height , prompt_matrix , use_GFPGAN , do_not_save_grid = False ):
380
+ def process_images (outpath , func_init , func_sample , prompt , seed , sampler_name , batch_size , n_iter , steps , cfg_scale , width , height , prompt_matrix , use_GFPGAN ):
355
381
"""this is the main loop that both txt2img and img2img use; it calls func_init once inside all the scopes and func_sample once per batch"""
356
- mem_mon = MemUsageMonitor ('MemMon' )
357
- mem_mon .start ()
382
+ assert prompt is not None
383
+ torch_gc ()
384
+ # start time after garbage collection (or before?)
358
385
start_time = time .time ()
359
386
360
- assert prompt is not None
361
- torch . cuda . empty_cache ()
387
+ mem_mon = MemUsageMonitor ( 'MemMon' )
388
+ mem_mon . start ()
362
389
363
390
if seed == - 1 :
364
391
seed = random .randrange (4294967294 )
@@ -447,7 +474,7 @@ def process_images(outpath, func_init, func_sample, prompt, seed, sampler_name,
447
474
output_images .append (image )
448
475
base_count += 1
449
476
450
- if ( prompt_matrix or not opt .skip_grid ) and not do_not_save_grid :
477
+ if prompt_matrix or not opt .skip_grid :
451
478
grid = image_grid (output_images , batch_size , round_down = prompt_matrix )
452
479
453
480
if prompt_matrix :
@@ -460,13 +487,18 @@ def process_images(outpath, func_init, func_sample, prompt, seed, sampler_name,
460
487
461
488
output_images .insert (0 , grid )
462
489
463
- grid .save (os .path .join (outpath , f'grid-{ grid_count :04} .{ opt .grid_format } ' ))
490
+
491
+ grid_file = f"grid-{ grid_count :05} -{ seed } _{ prompts [i ].replace (' ' , '_' ).translate ({ord (x ): '' for x in invalid_filename_chars })[:128 ]} .jpg"
492
+ grid .save (os .path .join (outpath , grid_file ), 'jpeg' , quality = 80 , optimize = True )
464
493
grid_count += 1
465
494
toc = time .time ()
466
495
467
496
mem_max_used , mem_total = mem_mon .read_and_stop ()
468
497
time_diff = time .time ()- start_time
469
498
499
+ mem_max_used , mem_total = mem_mon .read_and_stop ()
500
+ time_diff = time .time ()- start_time
501
+
470
502
info = f"""
471
503
{ prompt }
472
504
Steps: { steps } , Sampler: { sampler_name } , CFG scale: { cfg_scale } , Seed: { seed } { ', GFPGAN' if use_GFPGAN and GFPGAN is not None else '' } { ', Prompt Matrix Mode.' if prompt_matrix else '' } """ .strip ()
@@ -480,8 +512,10 @@ def process_images(outpath, func_init, func_sample, prompt, seed, sampler_name,
480
512
#mem_mon.stop()
481
513
#del mem_mon
482
514
torch_gc ()
515
+
483
516
return output_images , seed , info , stats
484
517
518
+
485
519
def txt2img (prompt : str , ddim_steps : int , sampler_name : str , use_GFPGAN : bool , prompt_matrix : bool , ddim_eta : float , n_iter : int , batch_size : int , cfg_scale : float , seed : int , height : int , width : int ):
486
520
outpath = opt .outdir or "outputs/txt2img-samples"
487
521
err = False
@@ -532,7 +566,6 @@ def sample(init_data, x, conditioning, unconditional_conditioning):
532
566
crash (err , '!!Runtime error (txt2img)!!' )
533
567
534
568
535
-
536
569
class Flagging (gr .FlaggingCallback ):
537
570
538
571
def setup (self , components , flagging_dir : str ):
@@ -583,7 +616,7 @@ def flag(self, flag_data, flag_option=None, flag_index=None, username=None):
583
616
gr .Checkbox (label = 'Fix faces using GFPGAN' , value = False , visible = GFPGAN is not None ),
584
617
gr .Checkbox (label = 'Create prompt matrix (separate multiple prompts using |, and get all combinations of them)' , value = False ),
585
618
gr .Slider (minimum = 0.0 , maximum = 1.0 , step = 0.01 , label = "DDIM ETA" , value = 0.0 , visible = False ),
586
- gr .Slider (minimum = 1 , maximum = opt . max_batch_count , step = 1 , label = 'Batch count (how many batches of images to generate)' , value = 1 ),
619
+ gr .Slider (minimum = 1 , maximum = 16 , step = 1 , label = 'Batch count (how many batches of images to generate)' , value = 1 ),
587
620
gr .Slider (minimum = 1 , maximum = 8 , step = 1 , label = 'Batch size (how many images are in a batch; memory-hungry)' , value = 1 ),
588
621
gr .Slider (minimum = 1.0 , maximum = 15.0 , step = 0.5 , label = 'Classifier Free Guidance Scale (how strongly the image should follow the prompt)' , value = 7.0 ),
589
622
gr .Number (label = 'Seed' , value = - 1 ),
@@ -602,13 +635,14 @@ def flag(self, flag_data, flag_option=None, flag_index=None, username=None):
602
635
)
603
636
604
637
605
- def img2img (prompt : str , init_img , ddim_steps : int , use_GFPGAN : bool , prompt_matrix , loopback : bool , n_iter : int , batch_size : int , cfg_scale : float , denoising_strength : float , seed : int , height : int , width : int , resize_mode : int ):
638
+ def img2img (prompt : str , init_img , ddim_steps : int , use_GFPGAN : bool , prompt_matrix , n_iter : int , batch_size : int , cfg_scale : float , denoising_strength : float , seed : int , height : int , width : int , resize_mode : int ):
606
639
outpath = opt .outdir or "outputs/img2img-samples"
607
640
err = False
608
641
609
642
sampler = KDiffusionSampler (model )
610
643
611
644
assert 0. <= denoising_strength <= 1. , 'can only work with strength in [0.0, 1.0]'
645
+ t_enc = int (denoising_strength * ddim_steps )
612
646
613
647
def init ():
614
648
image = init_img .convert ("RGB" )
@@ -625,8 +659,6 @@ def init():
625
659
return init_latent ,
626
660
627
661
def sample (init_data , x , conditioning , unconditional_conditioning ):
628
- t_enc = int (denoising_strength * ddim_steps )
629
-
630
662
x0 , = init_data
631
663
632
664
sigmas = sampler .model_wrap .get_sigmas (ddim_steps )
@@ -638,6 +670,7 @@ def sample(init_data, x, conditioning, unconditional_conditioning):
638
670
samples_ddim = K .sampling .sample_lms (model_wrap_cfg , xi , sigma_sched , extra_args = {'cond' : conditioning , 'uncond' : unconditional_conditioning , 'cond_scale' : cfg_scale }, disable = False )
639
671
return samples_ddim
640
672
673
+
641
674
try :
642
675
if loopback :
643
676
output_images , info = None , None
@@ -709,6 +742,7 @@ def sample(init_data, x, conditioning, unconditional_conditioning):
709
742
crash (err , '!!Runtime error (img2img)!!' )
710
743
711
744
745
+
712
746
sample_img2img = "assets/stable-samples/img2img/sketch-mountains-input.jpg"
713
747
sample_img2img = sample_img2img if os .path .exists (sample_img2img ) else None
714
748
@@ -720,8 +754,7 @@ def sample(init_data, x, conditioning, unconditional_conditioning):
720
754
gr .Slider (minimum = 1 , maximum = 150 , step = 1 , label = "Sampling Steps" , value = 50 ),
721
755
gr .Checkbox (label = 'Fix faces using GFPGAN' , value = False , visible = GFPGAN is not None ),
722
756
gr .Checkbox (label = 'Create prompt matrix (separate multiple prompts using |, and get all combinations of them)' , value = False ),
723
- gr .Checkbox (label = 'Loopback (use images from previous batch when creating next batch)' , value = False ),
724
- gr .Slider (minimum = 1 , maximum = opt .max_batch_count , step = 1 , label = 'Batch count (how many batches of images to generate)' , value = 1 ),
757
+ gr .Slider (minimum = 1 , maximum = 16 , step = 1 , label = 'Batch count (how many batches of images to generate)' , value = 1 ),
725
758
gr .Slider (minimum = 1 , maximum = 8 , step = 1 , label = 'Batch size (how many images are in a batch; memory-hungry)' , value = 1 ),
726
759
gr .Slider (minimum = 1.0 , maximum = 15.0 , step = 0.5 , label = 'Classifier Free Guidance Scale (how strongly the image should follow the prompt)' , value = 7.0 ),
727
760
gr .Slider (minimum = 0.0 , maximum = 1.0 , step = 0.01 , label = 'Denoising Strength' , value = 0.75 ),
0 commit comments