Skip to content

Commit 57e6d05

Browse files
committed
added tool for profiling code
1 parent aafbb5b commit 57e6d05

File tree

5 files changed

+78
-5
lines changed

5 files changed

+78
-5
lines changed

modules/call_queue.py

+8-2
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
1+
import os.path
12
from functools import wraps
23
import html
34
import time
45

5-
from modules import shared, progress, errors, devices, fifo_lock
6+
from modules import shared, progress, errors, devices, fifo_lock, profiling
67

78
queue_lock = fifo_lock.FIFOLock()
89

@@ -111,8 +112,13 @@ def f(*args, extra_outputs_array=extra_outputs, **kwargs):
111112
else:
112113
vram_html = ''
113114

115+
if shared.opts.profiling_enable and os.path.exists(shared.opts.profiling_filename):
116+
profiling_html = f"<p class='profile'> [ <a href='{profiling.webpath()}' download>Profile</a> ] </p>"
117+
else:
118+
profiling_html = ''
119+
114120
# last item is always HTML
115-
res[-1] += f"<div class='performance'><p class='time'>Time taken: <wbr><span class='measurement'>{elapsed_text}</span></p>{vram_html}</div>"
121+
res[-1] += f"<div class='performance'><p class='time'>Time taken: <wbr><span class='measurement'>{elapsed_text}</span></p>{vram_html}{profiling_html}</div>"
116122

117123
return tuple(res)
118124

modules/processing.py

+3-2
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
from typing import Any
1717

1818
import modules.sd_hijack
19-
from modules import devices, prompt_parser, masking, sd_samplers, lowvram, infotext_utils, extra_networks, sd_vae_approx, scripts, sd_samplers_common, sd_unet, errors, rng
19+
from modules import devices, prompt_parser, masking, sd_samplers, lowvram, infotext_utils, extra_networks, sd_vae_approx, scripts, sd_samplers_common, sd_unet, errors, rng, profiling
2020
from modules.rng import slerp # noqa: F401
2121
from modules.sd_hijack import model_hijack
2222
from modules.sd_samplers_common import images_tensor_to_samples, decode_first_stage, approximation_indexes
@@ -843,7 +843,8 @@ def process_images(p: StableDiffusionProcessing) -> Processed:
843843
# backwards compatibility, fix sampler and scheduler if invalid
844844
sd_samplers.fix_p_invalid_sampler_and_scheduler(p)
845845

846-
res = process_images_inner(p)
846+
with profiling.Profiler():
847+
res = process_images_inner(p)
847848

848849
finally:
849850
sd_models.apply_token_merging(p.sd_model, 0)

modules/profiling.py

+46
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
import torch
2+
3+
from modules import shared, ui_gradio_extensions
4+
5+
6+
class Profiler:
    """Context manager that runs ``torch.profiler`` around a block when enabled in settings.

    If the ``profiling_enable`` option is off, or none of the known activities
    ("CPU", "CUDA") is selected in ``profiling_activities``, entering and
    leaving the context does nothing. Otherwise a ``torch.profiler.profile``
    is started on enter, and on exit it is stopped and its Chrome trace is
    exported to the path stored in the ``profiling_filename`` option.
    """

    def __init__(self):
        # None marks "profiling inactive"; replaced below only when everything is configured
        self.profiler = None

        opts = shared.opts
        if not opts.profiling_enable:
            return

        # the option labels are spelled exactly like the ProfilerActivity members
        selected = [
            getattr(torch.profiler.ProfilerActivity, label)
            for label in ("CPU", "CUDA")
            if label in opts.profiling_activities
        ]
        if not selected:
            return

        self.profiler = torch.profiler.profile(
            activities=selected,
            record_shapes=opts.profiling_record_shapes,
            profile_memory=opts.profiling_profile_memory,
            with_stack=opts.profiling_with_stack,
        )

    def __enter__(self):
        """Start the wrapped torch profiler, if there is one, and return this object."""
        if not self.profiler:
            return self

        self.profiler.__enter__()
        return self

    def __exit__(self, exc_type, exc, exc_tb):
        """Stop the wrapped torch profiler, if there is one, and write its trace file."""
        active = self.profiler
        if not active:
            return

        # shown to the user while the trace is being finalized and written
        shared.state.textinfo = "Finishing profile..."

        active.__exit__(exc_type, exc, exc_tb)
        active.export_chrome_trace(shared.opts.profiling_filename)
42+
43+
44+
def webpath():
    """Return ``ui_gradio_extensions.webpath()`` of the configured profile file (``profiling_filename`` option)."""
    trace_file = shared.opts.profiling_filename
    return ui_gradio_extensions.webpath(trace_file)
46+

modules/shared_options.py

+16
Original file line numberDiff line numberDiff line change
@@ -129,6 +129,22 @@
129129
"dump_stacks_on_signal": OptionInfo(False, "Print stack traces before exiting the program with ctrl+c."),
130130
}))
131131

132+
# Settings section for the torch profiler. These options are read by
# modules/profiling.py (Profiler, webpath) and by modules/call_queue.py
# when it decides whether to show the "Profile" download link.
options_templates.update(options_section(('profiler', "Profiler", "system"), {
    "profiling_explanation": OptionHTML("""
Those settings allow you to enable torch profiler when generating pictures.
Profiling allows you to see which code uses how much of computer's resources during generation.
Each generation writes its own profile to one file, overwriting previous.
The file can be viewed in <a href="chrome:tracing">Chrome</a>, or on a <a href="https://ui.perfetto.dev/">Perfetto</a> web site.
Warning: writing profile can take a lot of time, up to 30 seconds, and the file itself can be around 500MB in size.
"""),
    "profiling_enable": OptionInfo(False, "Enable profiling"),
    "profiling_activities": OptionInfo(["CPU"], "Activities", gr.CheckboxGroup, {"choices": ["CPU", "CUDA"]}),
    "profiling_record_shapes": OptionInfo(True, "Record shapes"),
    "profiling_profile_memory": OptionInfo(True, "Profile memory"),
    "profiling_with_stack": OptionInfo(True, "Include python stack"),
    "profiling_filename": OptionInfo("trace.json", "Profile filename"),
}))
147+
132148
options_templates.update(options_section(('API', "API", "system"), {
133149
"api_enable_requests": OptionInfo(True, "Allow http:// and https:// URLs for input images in API", restrict_api=True),
134150
"api_forbid_local_requests": OptionInfo(True, "Forbid URLs to local resources", restrict_api=True),

style.css

+5-1
Original file line numberDiff line numberDiff line change
@@ -279,7 +279,7 @@ input[type="checkbox"].input-accordion-checkbox{
279279
display: inline-block;
280280
}
281281

282-
.html-log .performance p.time, .performance p.vram, .performance p.time abbr, .performance p.vram abbr {
282+
.html-log .performance p.time, .performance p.vram, .performance p.profile, .performance p.time abbr, .performance p.vram abbr {
283283
margin-bottom: 0;
284284
color: var(--block-title-text-color);
285285
}
@@ -291,6 +291,10 @@ input[type="checkbox"].input-accordion-checkbox{
291291
margin-left: auto;
292292
}
293293

294+
/* small gap to the left of the "Profile" download link paragraph in the performance line */
.html-log .performance p.profile {
    margin-left: 0.5em;
}
297+
294298
.html-log .performance .measurement{
295299
color: var(--body-text-color);
296300
font-weight: bold;

0 commit comments

Comments
 (0)