Skip to content

Commit b1d9758

Browse files
committed
revise multi-user support
1 parent 18815ab commit b1d9758

File tree

9 files changed

+203
-175
lines changed

9 files changed

+203
-175
lines changed

launch.py

+5-4
Original file line numberDiff line numberDiff line change
@@ -149,6 +149,7 @@ def prepare_enviroment():
149149
sys.argv += shlex.split(commandline_args)
150150
test_argv = [x for x in sys.argv if x != '--tests']
151151

152+
sys.argv, skip_torch_cuda = extract_arg(sys.argv, '--skip-torch-cuda')
152153
sys.argv, skip_torch_cuda_test = extract_arg(sys.argv, '--skip-torch-cuda-test')
153154
sys.argv, reinstall_xformers = extract_arg(sys.argv, '--reinstall-xformers')
154155
sys.argv, update_check = extract_arg(sys.argv, '--update-check')
@@ -164,11 +165,11 @@ def prepare_enviroment():
164165

165166
print(f"Python {sys.version}")
166167
print(f"Commit hash: {commit}")
167-
168-
if not is_installed("torch") or not is_installed("torchvision"):
168+
169+
if not skip_torch_cuda and (not is_installed("torch") or not is_installed("torchvision")):
169170
run(f'"{python}" -m {torch_command}', "Installing torch and torchvision", "Couldn't install torch")
170171

171-
if not skip_torch_cuda_test:
172+
if not skip_torch_cuda and not skip_torch_cuda_test:
172173
run_python("import torch; assert torch.cuda.is_available(), 'Torch is not able to use GPU; add --skip-torch-cuda-test to COMMANDLINE_ARGS variable to disable this check'")
173174

174175
if not is_installed("gfpgan"):
@@ -206,7 +207,7 @@ def prepare_enviroment():
206207
if not is_installed("lpips"):
207208
run_pip(f"install -r {os.path.join(repo_dir('CodeFormer'), 'requirements.txt')}", "requirements for CodeFormer")
208209

209-
run_pip(f"install -r {requirements_file}", "requirements for Web UI")
210+
#run_pip(f"install -r {requirements_file}", "requirements for Web UI")
210211

211212
run_extensions_installers()
212213

localizations/zh_CN.json

+29-1
Original file line numberDiff line numberDiff line change
@@ -618,7 +618,35 @@
618618
"favorites": "收藏夹(已保存)",
619619
"others": "其他",
620620
"Collect": "收藏(保存)",
621-
621+
"Create & Train Embedding": "创建并训练 Embedding",
622+
"Train an embedding; you must specify a directory with a set of 1:1 ratio images": "训练 embedding; 必须指定一组具有 1:1 比例图像的目录",
623+
"Embedding settings": "Embedding 设置",
624+
"Image preprocess settings": "图像预处理设置",
625+
"Train settings": "训练设置",
626+
"Create & Train Hypernetwork": "创建并训练 Hypernetwork",
627+
"Train an hypernetwork; you must specify a directory with a set of 1:1 ratio images": "训练 hypernetwork; 必须指定一组具有 1:1 比例图像的目录",
628+
"Hypernetwork settings": "Hypernetwork 设置",
629+
"Sign Options": "登陆选项",
630+
"Sign In": "登入",
631+
"Sign Up": "注册",
632+
"Sign Out": "登出",
633+
"Username": "用户名",
634+
"Password": "密码",
635+
"Email": "电子邮箱",
636+
"Update": "更新",
637+
"Delete": "删除",
638+
"Mismatched username/password or not existed username": "用户名/密码不匹配或用户不存在",
639+
"Signup failed, please check and retry again": "注册失败,请检查后并重试",
640+
"Update failed, please check and retry again": "更新失败,请检查后并重试",
641+
"Output": "输出",
642+
"Images S3 URI": "图像 S3 位置",
643+
"Models S3 URI": "模型 S3 位置",
644+
"Instance type": "实例类型",
645+
"Instance count": "实例数量",
646+
"Submit training job sucessful": "训练任务提交成功",
647+
"Settings saved failed": "设置保存错误",
648+
"SageMaker endpoint": "SageMaker 端点",
649+
"User": "用户",
622650

623651
"--------": "--------"
624652
}

localizations/zh_TW.json

+117-86
Large diffs are not rendered by default.

modules/api/api.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
import json
1919
import os
2020
import boto3
21-
from modules import sd_hijack, hypernetworks
21+
from modules import sd_hijack, hypernetworks, sd_models
2222
from typing import Union
2323
import traceback
2424
import requests
@@ -359,6 +359,7 @@ def invocations(self, req: InvocationsRequest):
359359
response = requests.post(url=f'{api_endpoint}/sd/user', json=inputs)
360360
if response.status_code == 200 and response.text != '':
361361
shared.opts.data = json.loads(response.text)
362+
sd_models.reload_model_weights()
362363

363364
self.download_s3files(hypernetwork_s3uri, os.path.join(script_path, shared.cmd_opts.hypernetwork_dir))
364365
hypernetworks.hypernetwork.load_hypernetwork(shared.opts.sd_hypernetwork)

modules/sd_models.py

+2
Original file line numberDiff line numberDiff line change
@@ -311,6 +311,8 @@ def reload_model_weights(sd_model=None, info=None):
311311
if not sd_model:
312312
sd_model = shared.sd_model
313313

314+
print('Origin checkpoint: ', sd_model.sd_model_checkpoint)
315+
print('Current checkpoint: ', checkpoint_info.filename)
314316
if sd_model.sd_model_checkpoint == checkpoint_info.filename:
315317
return
316318

modules/shared.py

+43-46
Original file line numberDiff line numberDiff line change
@@ -136,8 +136,6 @@
136136
username = ''
137137
api_endpoint = os.environ['api_endpoint']
138138
industrial_model = ''
139-
endpoint_name = ''
140-
endpoint_names = []
141139
default_options = {}
142140

143141
def reload_hypernetworks():
@@ -268,7 +266,49 @@ def options_section(section_identifier, options_dict):
268266

269267
options_templates = {}
270268

269+
def refresh_sagemaker_endpoints():
270+
global industrial_model, api_endpoint, default_options
271+
272+
if industrial_model == '':
273+
response = requests.get(url=f'{api_endpoint}/sd/industrialmodel')
274+
if response.status_code == 200:
275+
industrial_model = response.text
276+
else:
277+
model_name = 'stable-diffusion-webui'
278+
model_description = model_name
279+
inputs = {
280+
'model_algorithm': 'stable-diffusion-webui',
281+
'model_name': model_name,
282+
'model_description': model_description,
283+
'model_extra': '{"visible": "false"}',
284+
'model_samples': '',
285+
'file_content': {
286+
'data': [(lambda x: int(x))(x) for x in open(os.path.join(script_path, 'logo.ico'), 'rb').read()]
287+
}
288+
}
289+
290+
response = requests.post(url=f'{api_endpoint}/industrialmodel', json = inputs)
291+
if response.status_code == 200:
292+
body = json.loads(response.text)
293+
industrial_model = body['id']
294+
295+
default_options = self.data
296+
297+
sagemaker_endpoints = []
298+
299+
if industrial_model != '':
300+
params = {
301+
'industrial_model': industrial_model
302+
}
303+
response = requests.get(url=f'{api_endpoint}/endpoint', params=params)
304+
if response.status_code == 200:
305+
for endpoint_item in json.loads(response.text):
306+
sagemaker_endpoints.append(endpoint_item['EndpointName'])
307+
308+
return sagemaker_endpoints
309+
271310
options_templates.update(options_section(('sd', "Stable Diffusion"), {
311+
"sagemaker_endpoint": OptionInfo(None, "SaegMaker endpoint", gr.Dropdown, lambda: {"choices": refresh_sagemaker_endpoints()}, refresh=refresh_sagemaker_endpoints),
272312
"sd_model_checkpoint": OptionInfo(None, "Stable Diffusion checkpoint", gr.Dropdown, lambda: {"choices": modules.sd_models.checkpoint_tiles()}, refresh=sd_models.list_models),
273313
"sd_checkpoint_cache": OptionInfo(0, "Checkpoints to cache in RAM", gr.Slider, {"minimum": 0, "maximum": 10, "step": 1}),
274314
"sd_vae": OptionInfo("auto", "SD VAE", gr.Dropdown, lambda: {"choices": list(sd_vae.vae_list)}, refresh=sd_vae.refresh_vae_list),
@@ -393,7 +433,7 @@ def options_section(section_identifier, options_dict):
393433
"js_modal_lightbox": OptionInfo(True, "Enable full page image viewer"),
394434
"js_modal_lightbox_initially_zoomed": OptionInfo(True, "Show images zoomed in by default in full page image viewer"),
395435
"show_progress_in_title": OptionInfo(True, "Show generation progress in window title."),
396-
'quicksettings': OptionInfo("sd_model_checkpoint", "Quicksettings list"),
436+
'quicksettings': OptionInfo("sagemaker_endpoint", "Quicksettings list"),
397437
'localization': OptionInfo("None", "Localization (requires restart)", gr.Dropdown, lambda: {"choices": ["None"] + list(localization.localizations.keys())}, refresh=lambda: localization.list_localizations(cmd_opts.localizations_dir)),
398438
}))
399439

@@ -480,34 +520,6 @@ def load(self, filename):
480520
if bad_settings > 0:
481521
print(f"The program is likely to not work with bad settings.\nSettings file: {filename}\nEither fix the file, or delete it and restart.", file=sys.stderr)
482522

483-
if cmd_opts.pureui:
484-
global api_endpoint, industrial_model, default_options
485-
486-
#opts.show_progressbar = False
487-
response = requests.get(url=f'{api_endpoint}/sd/industrialmodel')
488-
if response.status_code == 200:
489-
industrial_model = response.text
490-
else:
491-
model_name = 'stable-diffusion-webui'
492-
model_description = model_name
493-
inputs = {
494-
'model_algorithm': 'stable-diffusion-webui',
495-
'model_name': model_name,
496-
'model_description': model_description,
497-
'model_extra': '{"visible": "false"}',
498-
'model_samples': '',
499-
'file_content': {
500-
'data': [(lambda x: int(x))(x) for x in open(os.path.join(script_path, 'logo.ico'), 'rb').read()]
501-
}
502-
}
503-
504-
response = requests.post(url=f'{api_endpoint}/industrialmodel', json = inputs)
505-
if response.status_code == 200:
506-
body = json.loads(response.text)
507-
industrial_model = body['id']
508-
509-
default_options = self.data
510-
511523
def onchange(self, key, func, call=True):
512524
item = self.data_labels.get(key)
513525
item.onchange = func
@@ -587,18 +599,3 @@ def clear(self):
587599
def listfiles(dirname):
588600
filenames = [os.path.join(dirname, x) for x in sorted(os.listdir(dirname)) if not x.startswith(".")]
589601
return [file for file in filenames if os.path.isfile(file)]
590-
591-
if cmd_opts.pureui:
592-
def init_endpoints():
593-
global endpoint_name, endpoint_names, industrial_model, api_endpoint
594-
595-
endpoints = []
596-
params = {
597-
'industrial_model': industrial_model
598-
}
599-
response = requests.get(url=f'{api_endpoint}/endpoint', params=params)
600-
if response.status_code == 200:
601-
for endpoint_item in json.loads(response.text):
602-
endpoints.append(endpoint_item['EndpointName'])
603-
endpoint_name = endpoints[0] if len(endpoints) > 0 else ''
604-
endpoint_names = endpoints

modules/ui.py

-30
Original file line numberDiff line numberDiff line change
@@ -2331,36 +2331,6 @@ def user_delete(login_username, login_password, login_email):
23312331
component = create_setting_component(k, is_quicksettings=True)
23322332
component_dict[k] = component
23332333

2334-
if cmd_opts.pureui:
2335-
shared.init_endpoints()
2336-
2337-
with gr.Row():
2338-
with gr.Column(scale=9):
2339-
endpoint_names = gr.Dropdown(label='SageMaker endpoint', value=shared.endpoint_name, choices=shared.endpoint_names)
2340-
with gr.Column(scale=1):
2341-
endpoint_refresh = gr.Button(refresh_symbol)
2342-
2343-
def refresh_endpoint():
2344-
shared.init_endpoints()
2345-
return {
2346-
endpoint_names: gr.update(value=shared.endpoint_name, choices=shared.endpoint_names)
2347-
}
2348-
2349-
def change_endpoint(endpoint_names):
2350-
shared.endpoint_name = endpoint_names
2351-
2352-
endpoint_names.change(
2353-
fn=change_endpoint,
2354-
inputs=[endpoint_names],
2355-
outputs=[]
2356-
)
2357-
2358-
endpoint_refresh.click(
2359-
fn=refresh_endpoint,
2360-
inputs=[],
2361-
outputs=[endpoint_names]
2362-
)
2363-
23642334
parameters_copypaste.integrate_settings_paste_fields(component_dict)
23652335
parameters_copypaste.run_bind()
23662336

style.css

+1-1
Original file line numberDiff line numberDiff line change
@@ -501,7 +501,7 @@ input[type="range"]{
501501
padding: 0;
502502
}
503503

504-
#refresh_sd_model_checkpoint, #refresh_sd_vae, #refresh_sd_hypernetwork, #refresh_train_hypernetwork_name, #refresh_train_embedding_name, #refresh_localization{
504+
#refresh_sagemaker_endpoint, #refresh_sd_model_checkpoint, #refresh_sd_vae, #refresh_sd_hypernetwork, #refresh_train_hypernetwork_name, #refresh_train_embedding_name, #refresh_localization{
505505
max-width: 2.5em;
506506
min-width: 2.5em;
507507
height: 2.4em;

webui.py

+4-6
Original file line numberDiff line numberDiff line change
@@ -286,7 +286,7 @@ def sagemaker_inference(task, infer, *args, **kwargs):
286286
}
287287

288288
params = {
289-
'endpoint_name': shared.endpoint_name
289+
'endpoint_name': shared.opts.sagemaker_endpoint
290290
}
291291

292292
response = requests.post(url=f'{shared.api_endpoint}/inference', params=params, json=inputs)
@@ -384,7 +384,7 @@ def sagemaker_inference(task, infer, *args, **kwargs):
384384
}
385385

386386
params = {
387-
'endpoint_name': shared.endpoint_name
387+
'endpoint_name': shared.opts.sagemaker_endpoint
388388
}
389389
response = requests.post(url=f'{shared.api_endpoint}/inference', params=params, json=inputs)
390390
if infer == 'async':
@@ -440,6 +440,7 @@ def initialize():
440440
modules.scripts.load_scripts()
441441

442442
modules.sd_vae.refresh_vae_list()
443+
443444
shared.opts.onchange("sd_vae", wrap_queued_call(lambda: modules.sd_vae.reload_vae_weights()), call=False)
444445

445446
if not cmd_opts.pureui:
@@ -593,10 +594,7 @@ def train():
593594
response = requests.post(url=f'{api_endpoint}/sd/user', json=inputs)
594595
if response.status_code == 200 and response.text != '':
595596
opts.data = json.loads(response.text)
596-
for key in modules.sd_models.checkpoints_list:
597-
if modules.sd_models.checkpoints_list[key].title == opts.data['sd_model_checkpoint']:
598-
shared.sd_model.sd_model_name = modules.sd_models.checkpoints_list[key].model_name
599-
break
597+
modules.sd_models.load_model()
600598

601599
if train_task == 'embedding':
602600
name = train_args['embedding_settings']['name']

0 commit comments

Comments
 (0)