1
1
import argparse , os , sys , glob
2
- import torch
3
- import torch .nn as nn
4
- import numpy as np
5
2
import gradio as gr
6
- from omegaconf import OmegaConf
7
- from PIL import Image , ImageFont , ImageDraw
8
- from itertools import islice
9
- from einops import rearrange , repeat
10
- from torch import autocast
11
- from contextlib import contextmanager , nullcontext
12
- import mimetypes
13
- import random
3
+ import k_diffusion as K
14
4
import math
5
+ import mimetypes
6
+ import numpy as np
15
7
import pynvml
8
+ import random
16
9
import threading
17
10
import time
11
+ import torch
12
+ import torch .nn as nn
18
13
19
- import k_diffusion as K
20
- from ldm .util import instantiate_from_config
14
+ from contextlib import contextmanager , nullcontext
15
+ from einops import rearrange , repeat
16
+ from itertools import islice
17
+ from omegaconf import OmegaConf
18
+ from PIL import Image , ImageFont , ImageDraw
19
+ from torch import autocast
21
20
from ldm .models .diffusion .ddim import DDIMSampler
22
21
from ldm .models .diffusion .plms import PLMSSampler
22
+ from ldm .util import instantiate_from_config
23
23
24
24
try :
25
25
# this silences the annoying "Some weights of the model checkpoint were not used when initializing..." message at start.
@@ -87,6 +87,50 @@ def load_model_from_config(config, ckpt, verbose=False):
87
87
model .eval ()
88
88
return model
89
89
def crash(e, s):
    """Report a fatal error, drop the global model/device and schedule a hard exit.

    Args:
        e: The exception (or any printable object) describing the failure.
        s: A short headline printed before the exception.

    The process is terminated with ``os._exit(0)`` from a timer thread 0.25 s
    after this function returns.
    NOTE(review): the delay presumably lets the caller finish returning its
    error response, and the zero exit status presumably suits an external
    restart loop -- confirm against the launcher script.
    """
    print(s, '\n', e)

    # Remove the module-level references without assuming they are still
    # bound. A plain `del model` raises NameError when the name is already
    # gone (e.g. a second crash inside the 0.25 s window), which would skip
    # the timer below and leave the process hanging instead of exiting.
    module_globals = globals()
    module_globals.pop('model', None)
    module_globals.pop('device', None)

    print('exiting...calling os._exit(0)')
    t = threading.Timer(0.25, os._exit, args=[0])
    t.start()
102
+
class MemUsageMonitor(threading.Thread):
    """Background thread that tracks peak memory usage of GPU 0 via NVML.

    Start it with ``start()``; it polls ``nvmlDeviceGetMemoryInfo`` roughly
    every 0.1 s until ``stop()`` or ``read_and_stop()`` is called. The values
    returned by ``read()`` are whatever NVML reports for ``used`` / ``total``.
    """

    # Class-level defaults; assigning through ``self`` shadows them per instance.
    stop_flag = False  # set to True to make run() leave its polling loop
    max_usage = 0      # largest ``used`` value observed so far
    total = 0          # ``total`` value reported by NVML for device 0

    def __init__(self, name):
        """Create the monitor thread; ``name`` is used in the log messages."""
        threading.Thread.__init__(self)
        self.name = name

    def run(self):
        """Poll NVML until ``stop_flag`` is set, recording the peak usage."""
        print(f"[{self.name}] Recording max memory usage...\n")
        pynvml.nvmlInit()
        try:
            handle = pynvml.nvmlDeviceGetHandleByIndex(0)
            self.total = pynvml.nvmlDeviceGetMemoryInfo(handle).total
            while not self.stop_flag:
                m = pynvml.nvmlDeviceGetMemoryInfo(handle)
                self.max_usage = max(self.max_usage, m.used)
                time.sleep(0.1)
            print(f"[{self.name}] Stopped recording.\n")
        finally:
            # Always pair nvmlInit() with nvmlShutdown(), even when a query
            # above raises, so NVML is not left initialised by a dead thread.
            pynvml.nvmlShutdown()

    def read(self):
        """Return ``(max_usage, total)`` as recorded so far."""
        return self.max_usage, self.total

    def stop(self):
        """Ask the polling loop to finish; does not wait for the thread."""
        self.stop_flag = True

    def read_and_stop(self):
        """Request a stop and return ``(max_usage, total)`` without joining."""
        self.stop()
        return self.read()
90
134
91
135
class CFGDenoiser (nn .Module ):
92
136
def __init__ (self , model ):
@@ -389,8 +433,10 @@ def process_images(outpath, func_init, func_sample, prompt, seed, sampler_name,
389
433
390
434
precision_scope = autocast if opt .precision == "autocast" else nullcontext
391
435
output_images = []
436
+ stats = []
392
437
with torch .no_grad (), precision_scope ("cuda" ), model .ema_scope ():
393
438
init_data = func_init ()
439
+ tic = time .time ()
394
440
395
441
for n in range (n_iter ):
396
442
prompts = all_prompts [n * batch_size :(n + 1 ) * batch_size ]
@@ -432,7 +478,6 @@ def process_images(outpath, func_init, func_sample, prompt, seed, sampler_name,
432
478
grid = image_grid (output_images , batch_size , round_down = prompt_matrix )
433
479
434
480
if prompt_matrix :
435
-
436
481
try :
437
482
grid = draw_prompt_matrix (grid , width , height , prompt_matrix_parts )
438
483
except Exception :
@@ -442,31 +487,38 @@ def process_images(outpath, func_init, func_sample, prompt, seed, sampler_name,
442
487
443
488
output_images .insert (0 , grid )
444
489
490
+
445
491
grid_file = f"grid-{ grid_count :05} -{ seed } _{ prompts [i ].replace (' ' , '_' ).translate ({ord (x ): '' for x in invalid_filename_chars })[:128 ]} .jpg"
446
492
grid .save (os .path .join (outpath , grid_file ), 'jpeg' , quality = 80 , optimize = True )
447
493
grid_count += 1
494
+ toc = time .time ()
448
495
449
496
mem_max_used , mem_total = mem_mon .read_and_stop ()
450
497
time_diff = time .time ()- start_time
451
498
452
- notes = f'''
453
- Took { round (time_diff , 2 ) } s total ({ round (time_diff / (batch_size * n_iter ),2 ) } s per image)<br>
454
- Peak memory usage: { - (mem_max_used // - 1_048_576 ) } MiB / { - (mem_total // - 1_048_576 ) } MiB / { round (mem_max_used / mem_total * 100 , 3 ) } %<br>
455
- '''
499
+ mem_max_used , mem_total = mem_mon .read_and_stop ()
500
+ time_diff = time .time ()- start_time
456
501
457
502
info = f"""
458
503
{ prompt }
459
- Steps: { steps } , Sampler: { sampler_name } , CFG scale: { cfg_scale } , Seed: { seed } { ', GFPGAN' if use_GFPGAN and GFPGAN is not None else '' }
460
- """ .strip ()
461
-
504
+ Steps: { steps } , Sampler: { sampler_name } , CFG scale: { cfg_scale } , Seed: { seed } { ', GFPGAN' if use_GFPGAN and GFPGAN is not None else '' } { ', Prompt Matrix Mode.' if prompt_matrix else '' } """ .strip ()
505
+ stats = f'''
506
+ Took { round (time_diff , 2 ) } s total ({ round (time_diff / (len (all_prompts )),2 ) } s per image)
507
+ Peak memory usage: { - (mem_max_used // - 1_048_576 ) } MiB / { - (mem_total // - 1_048_576 ) } MiB / { round (mem_max_used / mem_total * 100 , 3 ) } %'''
508
+
462
509
for comment in comments :
463
510
info += "\n \n " + comment
511
+
512
+ #mem_mon.stop()
513
+ #del mem_mon
464
514
torch_gc ()
465
- return output_images , seed , info , notes
515
+
516
+ return output_images , seed , info , stats
466
517
467
518
468
519
def txt2img (prompt : str , ddim_steps : int , sampler_name : str , use_GFPGAN : bool , prompt_matrix : bool , ddim_eta : float , n_iter : int , batch_size : int , cfg_scale : float , seed : int , height : int , width : int ):
469
520
outpath = opt .outdir or "outputs/txt2img-samples"
521
+ err = False
470
522
471
523
if sampler_name == 'PLMS' :
472
524
sampler = PLMSSampler (model )
@@ -483,27 +535,35 @@ def init():
483
535
def sample (init_data , x , conditioning , unconditional_conditioning ):
484
536
samples_ddim , _ = sampler .sample (S = ddim_steps , conditioning = conditioning , batch_size = int (x .shape [0 ]), shape = x [0 ].shape , verbose = False , unconditional_guidance_scale = cfg_scale , unconditional_conditioning = unconditional_conditioning , eta = ddim_eta , x_T = x )
485
537
return samples_ddim
486
-
487
- output_images , seed , info , notes = process_images (
488
- outpath = outpath ,
489
- func_init = init ,
490
- func_sample = sample ,
491
- prompt = prompt ,
492
- seed = seed ,
493
- sampler_name = sampler_name ,
494
- batch_size = batch_size ,
495
- n_iter = n_iter ,
496
- steps = ddim_steps ,
497
- cfg_scale = cfg_scale ,
498
- width = width ,
499
- height = height ,
500
- prompt_matrix = prompt_matrix ,
501
- use_GFPGAN = use_GFPGAN
502
- )
503
-
504
- del sampler
505
-
506
- return output_images , seed , info , notes
538
+ try :
539
+ output_images , seed , info , stats = process_images (
540
+ outpath = outpath ,
541
+ func_init = init ,
542
+ func_sample = sample ,
543
+ prompt = prompt ,
544
+ seed = seed ,
545
+ sampler_name = sampler_name ,
546
+ batch_size = batch_size ,
547
+ n_iter = n_iter ,
548
+ steps = ddim_steps ,
549
+ cfg_scale = cfg_scale ,
550
+ width = width ,
551
+ height = height ,
552
+ prompt_matrix = prompt_matrix ,
553
+ use_GFPGAN = use_GFPGAN
554
+ )
555
+
556
+ del sampler
557
+
558
+ return output_images , seed , info , stats
559
+ except RuntimeError as e :
560
+ err = e
561
+ err_msg = f'CRASHED:<br><textarea rows="5" style="background: black;width: -webkit-fill-available;font-family: monospace;font-size: small;font-weight: bold;">{ str (e )} </textarea><br><br>Please wait while the program restarts.'
562
+ stats = err_msg
563
+ return [], 1
564
+ finally :
565
+ if err :
566
+ crash (err , '!!Runtime error (txt2img)!!' )
507
567
508
568
509
569
class Flagging (gr .FlaggingCallback ):
@@ -567,16 +627,17 @@ def flag(self, flag_data, flag_option=None, flag_index=None, username=None):
567
627
gr .Gallery (label = "Images" ),
568
628
gr .Number (label = 'Seed' ),
569
629
gr .Textbox (label = "Copy-paste generation parameters" ),
570
- gr .HTML (label = 'Notes ' ),
630
+ gr .HTML (label = 'Stats ' ),
571
631
],
572
- title = "Stable Diffusion Text-to-Image K " ,
573
- description = "Generate images from text with Stable Diffusion (using K-LMS) " ,
632
+ title = "Stable Diffusion Text-to-Image Unified " ,
633
+ description = "Generate images from text with Stable Diffusion" ,
574
634
flagging_callback = Flagging ()
575
635
)
576
636
577
637
578
638
def img2img (prompt : str , init_img , ddim_steps : int , use_GFPGAN : bool , prompt_matrix , n_iter : int , batch_size : int , cfg_scale : float , denoising_strength : float , seed : int , height : int , width : int , resize_mode : int ):
579
639
outpath = opt .outdir or "outputs/img2img-samples"
640
+ err = False
580
641
581
642
sampler = KDiffusionSampler (model )
582
643
@@ -609,26 +670,77 @@ def sample(init_data, x, conditioning, unconditional_conditioning):
609
670
samples_ddim = K .sampling .sample_lms (model_wrap_cfg , xi , sigma_sched , extra_args = {'cond' : conditioning , 'uncond' : unconditional_conditioning , 'cond_scale' : cfg_scale }, disable = False )
610
671
return samples_ddim
611
672
612
- output_images , seed , info , notes = process_images (
613
- outpath = outpath ,
614
- func_init = init ,
615
- func_sample = sample ,
616
- prompt = prompt ,
617
- seed = seed ,
618
- sampler_name = 'k-diffusion' ,
619
- batch_size = batch_size ,
620
- n_iter = n_iter ,
621
- steps = ddim_steps ,
622
- cfg_scale = cfg_scale ,
623
- width = width ,
624
- height = height ,
625
- prompt_matrix = prompt_matrix ,
626
- use_GFPGAN = use_GFPGAN
627
- )
628
673
629
- del sampler
674
+ try :
675
+ if loopback :
676
+ output_images , info = None , None
677
+ history = []
678
+ initial_seed = None
679
+
680
+ for i in range (n_iter ):
681
+ output_images , seed , info , stats = process_images (
682
+ outpath = outpath ,
683
+ func_init = init ,
684
+ func_sample = sample ,
685
+ prompt = prompt ,
686
+ seed = seed ,
687
+ sampler_name = 'k-diffusion' ,
688
+ batch_size = 1 ,
689
+ n_iter = 1 ,
690
+ steps = ddim_steps ,
691
+ cfg_scale = cfg_scale ,
692
+ width = width ,
693
+ height = height ,
694
+ prompt_matrix = prompt_matrix ,
695
+ use_GFPGAN = use_GFPGAN ,
696
+ do_not_save_grid = True
697
+ )
698
+
699
+ if initial_seed is None :
700
+ initial_seed = seed
701
+
702
+ init_img = output_images [0 ]
703
+ seed = seed + 1
704
+ denoising_strength = max (denoising_strength * 0.95 , 0.1 )
705
+ history .append (init_img )
706
+
707
+ grid_count = len (os .listdir (outpath )) - 1
708
+ grid = image_grid (history , batch_size , force_n_rows = 1 )
709
+ grid .save (os .path .join (outpath , f'grid-{ grid_count :04} .{ opt .grid_format } ' ))
710
+
711
+ output_images = history
712
+ seed = initial_seed
713
+
714
+ else :
715
+ output_images , seed , info , stats = process_images (
716
+ outpath = outpath ,
717
+ func_init = init ,
718
+ func_sample = sample ,
719
+ prompt = prompt ,
720
+ seed = seed ,
721
+ sampler_name = 'k-diffusion' ,
722
+ batch_size = batch_size ,
723
+ n_iter = n_iter ,
724
+ steps = ddim_steps ,
725
+ cfg_scale = cfg_scale ,
726
+ width = width ,
727
+ height = height ,
728
+ prompt_matrix = prompt_matrix ,
729
+ use_GFPGAN = use_GFPGAN
730
+ )
731
+
732
+ del sampler
733
+
734
+ return output_images , seed , info , stats
735
+ except RuntimeError as e :
736
+ err = e
737
+ err_msg = f'CRASHED:<br><textarea rows="5" style="background: black;width: -webkit-fill-available;font-family: monospace;font-size: small;font-weight: bold;">{ str (e )} </textarea><br><br>Please wait while the program restarts.'
738
+ stats = err_msg
739
+ return [], 1
740
+ finally :
741
+ if err :
742
+ crash (err , '!!Runtime error (img2img)!!' )
630
743
631
- return output_images , seed , info , notes
632
744
633
745
634
746
sample_img2img = "assets/stable-samples/img2img/sketch-mountains-input.jpg"
@@ -655,9 +767,9 @@ def sample(init_data, x, conditioning, unconditional_conditioning):
655
767
gr .Gallery (),
656
768
gr .Number (label = 'Seed' ),
657
769
gr .Textbox (label = "Copy-paste generation parameters" ),
658
- gr .HTML (label = 'Notes ' ),
770
+ gr .HTML (label = 'Stats ' ),
659
771
],
660
- title = "Stable Diffusion Image-to-Image" ,
772
+ title = "Stable Diffusion Image-to-Image Unified " ,
661
773
description = "Generate images from images with Stable Diffusion" ,
662
774
allow_flagging = "never" ,
663
775
)
@@ -700,4 +812,4 @@ def run_GFPGAN(image, strength):
700
812
css = ("" if opt .no_progressbar_hiding else css_hide_progressbar )
701
813
)
702
814
703
- demo .launch ()
815
+ demo .launch ()
0 commit comments