Skip to content

Commit 730f47b

Browse files
committed
add support to run it on sagemaker
1 parent 804d9fb commit 730f47b

File tree

7 files changed

+148
-26
lines changed

7 files changed

+148
-26
lines changed

modules/api/api.py

+24-4
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,8 @@ def __init__(self, app: FastAPI, queue_lock: Lock):
7777
self.app.add_api_route("/sdapi/v1/prompt-styles", self.get_promp_styles, methods=["GET"], response_model=List[PromptStyleItem])
7878
self.app.add_api_route("/sdapi/v1/artist-categories", self.get_artists_categories, methods=["GET"], response_model=List[str])
7979
self.app.add_api_route("/sdapi/v1/artists", self.get_artists, methods=["GET"], response_model=List[ArtistItem])
80+
self.app.add_api_route("/invocations", self.invocations, methods=["POST"], response_model=InvocationsResponse)
81+
self.app.add_api_route("/ping", self.ping, methods=["GET"], response_model=PingResponse)
8082

8183
def text2imgapi(self, txt2imgreq: StableDiffusionTxt2ImgProcessingAPI):
8284
sampler_index = sampler_to_index(txt2imgreq.sampler_index)
@@ -92,13 +94,17 @@ def text2imgapi(self, txt2imgreq: StableDiffusionTxt2ImgProcessingAPI):
9294
}
9395
)
9496
p = StableDiffusionProcessingTxt2Img(**vars(populate))
97+
9598
# Override object param
9699

97100
shared.state.begin()
98101

99102
with self.queue_lock:
100-
processed = process_images(p)
101-
103+
if p.script_args is not None:
104+
processed = p.scripts.run(p, *p.script_args)
105+
if processed is None:
106+
processed = process_images(p)
107+
102108
shared.state.end()
103109

104110
b64images = list(map(encode_pil_to_base64, processed.images))
@@ -141,7 +147,10 @@ def img2imgapi(self, img2imgreq: StableDiffusionImg2ImgProcessingAPI):
141147
shared.state.begin()
142148

143149
with self.queue_lock:
144-
processed = process_images(p)
150+
if p.script_args is not None:
151+
processed = p.scripts.run(p, *p.script_args)
152+
if processed is None:
153+
processed = process_images(p)
145154

146155
shared.state.end()
147156

@@ -297,6 +306,17 @@ def get_artists_categories(self):
297306
def get_artists(self):
298307
return [{"name":x[0], "score":x[1], "category":x[2]} for x in shared.artist_db.artists]
299308

309+
def invocations(self, req: InvocationsRequest):
310+
if req.task == 'text-to-image':
311+
return self.text2imgapi(req.payload)
312+
elif req.task == 'image-to-image':
313+
return self.img2imgapi(req.payload)
314+
else:
315+
raise NotImplementedError
316+
317+
def ping(self):
318+
return {'status': 'Healthy'}
319+
300320
def launch(self, server_name, port):
301321
self.app.include_router(self.router)
302-
uvicorn.run(self.app, host=server_name, port=port)
322+
uvicorn.run(self.app, host=server_name, port=port)

modules/api/models.py

+12
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
from modules.processing import StableDiffusionProcessingTxt2Img, StableDiffusionProcessingImg2Img
77
from modules.shared import sd_upscalers, opts, parser
88
from typing import Dict, List
9+
from typing import Union
910

1011
API_NOT_ALLOWED = [
1112
"self",
@@ -239,3 +240,14 @@ class ArtistItem(BaseModel):
239240
score: float = Field(title="Score")
240241
category: str = Field(title="Category")
241242

243+
class InvocationsRequest(BaseModel):
244+
task: str
245+
payload: Union[StableDiffusionTxt2ImgProcessingAPI, StableDiffusionImg2ImgProcessingAPI]
246+
247+
class InvocationsResponse(BaseModel):
248+
images: list[str] = Field(default=None, title="Image", description="The generated image in base64 format.")
249+
parameters: dict
250+
info: str
251+
252+
class PingResponse(BaseModel):
253+
status: str

modules/processing.py

+5-6
Original file line numberDiff line numberDiff line change
@@ -78,10 +78,10 @@ class StableDiffusionProcessing():
7878
"""
7979
The first set of paramaters: sd_models -> do_not_reload_embeddings represent the minimum required to create a StableDiffusionProcessing
8080
"""
81-
def __init__(self, sd_model=None, outpath_samples=None, outpath_grids=None, prompt: str = "", styles: List[str] = None, seed: int = -1, subseed: int = -1, subseed_strength: float = 0, seed_resize_from_h: int = -1, seed_resize_from_w: int = -1, seed_enable_extras: bool = True, sampler_index: int = 0, batch_size: int = 1, n_iter: int = 1, steps: int = 50, cfg_scale: float = 7.0, width: int = 512, height: int = 512, restore_faces: bool = False, tiling: bool = False, do_not_save_samples: bool = False, do_not_save_grid: bool = False, extra_generation_params: Dict[Any, Any] = None, overlay_images: Any = None, negative_prompt: str = None, eta: float = None, do_not_reload_embeddings: bool = False, denoising_strength: float = 0, ddim_discretize: str = None, s_churn: float = 0.0, s_tmax: float = None, s_tmin: float = 0.0, s_noise: float = 1.0, override_settings: Dict[str, Any] = None):
81+
def __init__(self, sd_model=None, outpath_samples=None, outpath_grids=None, prompt: str = "", styles: List[str] = None, seed: int = -1, subseed: int = -1, subseed_strength: float = 0, seed_resize_from_h: int = -1, seed_resize_from_w: int = -1, seed_enable_extras: bool = True, sampler_index: int = 0, batch_size: int = 1, n_iter: int = 1, steps: int = 50, cfg_scale: float = 7.0, width: int = 512, height: int = 512, restore_faces: bool = False, tiling: bool = False, do_not_save_samples: bool = False, do_not_save_grid: bool = False, extra_generation_params: Dict[Any, Any] = None, overlay_images: Any = None, negative_prompt: str = None, eta: float = None, do_not_reload_embeddings: bool = False, denoising_strength: float = 0, ddim_discretize: str = None, s_churn: float = 0.0, s_tmax: float = None, s_tmin: float = 0.0, s_noise: float = 1.0, script_args: str = None, override_settings: Dict[str, Any] = None):
8282
self.sd_model = sd_model
83-
self.outpath_samples: str = outpath_samples
84-
self.outpath_grids: str = outpath_grids
83+
self.outpath_samples: str = outpath_samples or opts.outdir_samples or opts.outdir_txt2img_samples
84+
self.outpath_grids: str = outpath_grids or opts.outdir_grids or opts.outdir_txt2img_grids
8585
self.prompt: str = prompt
8686
self.prompt_for_display: str = None
8787
self.negative_prompt: str = (negative_prompt or "")
@@ -116,15 +116,14 @@ def __init__(self, sd_model=None, outpath_samples=None, outpath_grids=None, prom
116116
self.s_tmax = s_tmax or float('inf') # not representable as a standard ui option
117117
self.s_noise = s_noise or opts.s_noise
118118
self.override_settings = {k: v for k, v in (override_settings or {}).items() if k not in shared.restricted_opts}
119-
119+
self.script_args = json.loads(script_args) if script_args != None else None
120+
120121
if not seed_enable_extras:
121122
self.subseed = -1
122123
self.subseed_strength = 0
123124
self.seed_resize_from_h = 0
124125
self.seed_resize_from_w = 0
125126

126-
self.scripts = None
127-
self.script_args = None
128127
self.all_prompts = None
129128
self.all_seeds = None
130129
self.all_subseeds = None

modules/scripts.py

+16
Original file line numberDiff line numberDiff line change
@@ -201,6 +201,22 @@ def __init__(self):
201201
self.titles = []
202202
self.infotext_fields = []
203203

204+
def setup_scripts(self, is_img2img):
205+
for script_class, path, basedir in scripts_data:
206+
script = script_class()
207+
script.filename = path
208+
209+
visibility = script.show(is_img2img)
210+
211+
if visibility == AlwaysVisible:
212+
self.scripts.append(script)
213+
self.alwayson_scripts.append(script)
214+
script.alwayson = True
215+
216+
elif visibility:
217+
self.scripts.append(script)
218+
self.selectable_scripts.append(script)
219+
204220
def setup_ui(self, is_img2img):
205221
for script_class, path, basedir in scripts_data:
206222
script = script_class()

modules/shared.py

+3
Original file line numberDiff line numberDiff line change
@@ -90,6 +90,7 @@
9090
parser.add_argument("--tls-keyfile", type=str, help="Partially enables TLS, requires --tls-certfile to fully function", default=None)
9191
parser.add_argument("--tls-certfile", type=str, help="Partially enables TLS, requires --tls-keyfile to fully function", default=None)
9292
parser.add_argument("--server-name", type=str, help="Sets hostname of server", default=None)
93+
parser.add_argument("--pureui", action='store_true', help="Pure UI without local inference and progress bar", default=False)
9394

9495
cmd_opts = parser.parse_args()
9596
restricted_opts = {
@@ -490,6 +491,8 @@ def reorder(self):
490491
opts = Options()
491492
if os.path.exists(config_filename):
492493
opts.load(config_filename)
494+
if cmd_opts.pureui:
495+
opts.show_progressbar = False
493496

494497
sd_upscalers = []
495498

webui.py

+87-3
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,13 @@
88
from fastapi.middleware.cors import CORSMiddleware
99
from fastapi.middleware.gzip import GZipMiddleware
1010

11+
import requests
12+
import json
13+
import time
14+
from PIL import Image
15+
import base64
16+
import io
17+
1118
from modules.paths import script_path
1219

1320
from modules import devices, sd_samplers, upscaler, extensions, localization
@@ -29,16 +36,93 @@
2936

3037
import modules.ui
3138
from modules import modelloader
32-
from modules.shared import cmd_opts
39+
from modules.shared import cmd_opts, opts
3340
import modules.hypernetworks.hypernetwork
3441

3542
queue_lock = threading.Lock()
3643
server_name = "0.0.0.0" if cmd_opts.listen else cmd_opts.server_name
3744

45+
api_endpoint = os.environ['api_endpoint']
46+
endpoint_name = os.environ['endpoint_name']
47+
3848
def wrap_queued_call(func):
49+
def sagemaker_inference(task, *args, **kwargs):
50+
script_args = []
51+
for i in range(23, len(args)):
52+
script_args.append(args[i])
53+
54+
payload = {
55+
"prompt": args[0],
56+
"negative_prompt": args[1],
57+
"styles": [args[2], args[3]],
58+
"steps": args[4],
59+
"sampler_index": sd_samplers.samplers[args[5]].name,
60+
"restore_faces": args[6],
61+
"tiling": args[7],
62+
"batch_count": args[8],
63+
"batch_size": args[9],
64+
"cfg_scale": args[10],
65+
"seed": args[11],
66+
"subseed": args[12],
67+
"subseed_strength": args[13],
68+
"seed_resize_from_h": args[14],
69+
"seed_resize_from_w": args[15],
70+
"seed_checkbox": args[16],
71+
"width": args[17],
72+
"height": args[18],
73+
"enable_hr": args[19],
74+
"denoising_strength": args[20],
75+
"firstphase_width": args[21],
76+
"firstphase_height": args[22],
77+
"script_args": json.dumps(script_args),
78+
"eta": opts.eta_ddim if sd_samplers.samplers[args[5]].name == 'DDIM' or sd_samplers.samplers[args[5]].name == 'PLMS' else opts.eta_ancestral,
79+
"s_churn": opts.s_churn,
80+
"s_tmax": None,
81+
"s_tmin": opts.s_tmin,
82+
"s_noise": opts.s_noise,
83+
}
84+
inputs = {
85+
'task': task,
86+
'payload': payload
87+
}
88+
params = {
89+
'endpoint_name': endpoint_name,
90+
'infer_type': 'async'
91+
}
92+
93+
response = requests.post(url=f'{api_endpoint}/inference', params = params, json = inputs)
94+
s3uri = response.text
95+
params = {'s3uri': s3uri}
96+
start = time.time()
97+
while True:
98+
response = requests.get(url=f'{api_endpoint}/s3', params = params)
99+
text = json.loads(response.text)
100+
101+
if text['count'] > 0:
102+
break
103+
else:
104+
time.sleep(1)
105+
106+
httpuri = text['payload'][0]['httpuri']
107+
response = requests.get(url=httpuri)
108+
processed = json.loads(response.text)
109+
images = []
110+
for image in processed['images']:
111+
images.append(Image.open(io.BytesIO(base64.b64decode(image.split(',')[1]))))
112+
parameters = processed['parameters']
113+
info = processed['info']
114+
print(f"Time taken: {time.time() - start}s")
115+
116+
return images, json.dumps(payload), modules.ui.plaintext_to_html(info)
117+
39118
def f(*args, **kwargs):
40-
with queue_lock:
41-
res = func(*args, **kwargs)
119+
if cmd_opts.pureui and func == modules.txt2img.txt2img:
120+
res = sagemaker_inference('text-to-image', *args, **kwargs)
121+
elif(cmd_opts.pureui and func == modules.img2img.img2img):
122+
res = sagemaker_inference('image-to-image', *args, **kwargs)
123+
else:
124+
with queue_lock:
125+
res = func(*args, **kwargs)
42126

43127
return res
44128

webui.sh

+1-13
Original file line numberDiff line numberDiff line change
@@ -60,19 +60,6 @@ printf "\e[1m\e[32mInstall script for stable-diffusion + Web UI\n"
6060
printf "\e[1m\e[34mTested on Debian 11 (Bullseye)\e[0m"
6161
printf "\n%s\n" "${delimiter}"
6262

63-
# Do not run as root
64-
if [[ $(id -u) -eq 0 ]]
65-
then
66-
printf "\n%s\n" "${delimiter}"
67-
printf "\e[1m\e[31mERROR: This script must not be launched as root, aborting...\e[0m"
68-
printf "\n%s\n" "${delimiter}"
69-
exit 1
70-
else
71-
printf "\n%s\n" "${delimiter}"
72-
printf "Running on \e[1m\e[32m%s\e[0m user" "$(whoami)"
73-
printf "\n%s\n" "${delimiter}"
74-
fi
75-
7663
if [[ -d .git ]]
7764
then
7865
printf "\n%s\n" "${delimiter}"
@@ -106,6 +93,7 @@ cd "${install_dir}"/ || { printf "\e[1m\e[31mERROR: Can't cd to %s/, aborting...
10693
if [[ -d "${clone_dir}" ]]
10794
then
10895
cd "${clone_dir}"/ || { printf "\e[1m\e[31mERROR: Can't cd to %s/%s/, aborting...\e[0m" "${install_dir}" "${clone_dir}"; exit 1; }
96+
"${GIT}" pull
10997
else
11098
printf "\n%s\n" "${delimiter}"
11199
printf "Clone stable-diffusion-webui"

0 commit comments

Comments
 (0)