Skip to content

Commit ac15d03

Browse files
committed
update stable-diffusion-webui
1 parent 4331462 commit ac15d03

File tree

1 file changed

+45
-29
lines changed

1 file changed

+45
-29
lines changed

webui.py

+45-29
Original file line numberDiff line numberDiff line change
@@ -215,6 +215,27 @@ def user_auth(username, password):
215215

216216
return response.status_code == 200
217217

218+
def get_bucket_and_key(s3uri):
219+
pos = s3uri.find('/', 5)
220+
bucket = s3uri[5 : pos]
221+
key = s3uri[pos + 1 : ]
222+
return bucket, key
223+
224+
def get_models(path, extensions):
225+
candidates = []
226+
models = []
227+
228+
for extension in extensions:
229+
candidates = candidates + glob.glob(os.path.join(path, f'**/{extension}'), recursive=True)
230+
231+
for filename in sorted(candidates, key=str.lower):
232+
if os.path.isdir(filename):
233+
continue
234+
235+
models.append(filename)
236+
237+
return models
238+
218239
def webui():
219240
launch_api = cmd_opts.api
220241

@@ -301,6 +322,8 @@ def webui():
301322
if launch_api:
302323
create_api(app)
303324

325+
os.path.splitext(os.path.basename(filename))[0]
326+
304327
cmd_sd_models_path = cmd_opts.ckpt_dir
305328
sd_models_dir = os.path.join(shared.models_path, "Stable-diffusion")
306329
if cmd_sd_models_path is not None:
@@ -324,17 +347,14 @@ def webui():
324347
params = {
325348
'module': 'Stable-diffusion'
326349
}
327-
for file in os.listdir(sd_models_dir):
328-
if os.path.isfile(os.path.join(sd_models_dir, file)) and (file.endswith('.ckpt') or file.endswith('.safetensors')):
329-
hash = modules.sd_models.model_hash(os.path.join(sd_models_dir, file))
330-
item = {}
331-
item['model_name'] = file
332-
item['config'] = '/opt/ml/code/stable-diffusion-webui/repositories/stable-diffusion/configs/stable-diffusion/v1-inference.yaml'
333-
item['filename'] = '/opt/ml/code/stable-diffusion-webui/models/Stable-diffusion/{0}'.format(file)
334-
item['hash'] = hash
335-
item['title'] = '{0} [{1}]'.format(file, hash)
336-
item['endpoint_name'] = endpoint_name
337-
items.append(item)
350+
for file in get_models(sd_models_dir, ['*.ckpt', '*.safetensors']):
351+
hash = modules.sd_models.model_hash(file)
352+
item = {}
353+
item['model_name'] = os.path.basename(file)
354+
item['hash'] = hash
355+
item['title'] = '{0} [{1}]'.format(os.path.basename(file), hash)
356+
item['endpoint_name'] = endpoint_name
357+
items.append(item)
338358
inputs = {
339359
'items': items
340360
}
@@ -346,15 +366,13 @@ def webui():
346366
params = {
347367
'module': 'ControlNet'
348368
}
349-
for file in os.listdir(cn_models_dir):
350-
if os.path.isfile(os.path.join(cn_models_dir, file)) and \
351-
(file.endswith('.pt') or file.endswith('.pth') or file.endswith('.ckpt') or file.endswith('.safetensors')):
352-
hash = modules.sd_models.model_hash(os.path.join(cn_models_dir, file))
353-
item = {}
354-
item['model_name'] = file
355-
item['title'] = '{0} [{1}]'.format(os.path.splitext(file)[0], hash)
356-
item['endpoint_name'] = endpoint_name
357-
items.append(item)
369+
for file in get_models(cn_models_dir, ['*.pt', '*.pth', '*.ckpt', '*.safetensors']):
370+
hash = modules.sd_models.model_hash(os.path.join(cn_models_dir, file))
371+
item = {}
372+
item['model_name'] = os.path.basename(file)
373+
item['title'] = '{0} [{1}]'.format(os.path.splitext(os.path.basename(file))[0], hash)
374+
item['endpoint_name'] = endpoint_name
375+
items.append(item)
358376
inputs = {
359377
'items': items
360378
}
@@ -366,15 +384,13 @@ def webui():
366384
params = {
367385
'module': 'Lora'
368386
}
369-
for file in os.listdir(lora_models_dir):
370-
if os.path.isfile(os.path.join(lora_models_dir, file)) and \
371-
(file.endswith('.pt') or file.endswith('.ckpt') or file.endswith('.safetensors')):
372-
hash = modules.sd_models.model_hash(os.path.join(lora_models_dir, file))
373-
item = {}
374-
item['model_name'] = file
375-
item['title'] = '{0} [{1}]'.format(os.path.splitext(file)[0], hash)
376-
item['endpoint_name'] = endpoint_name
377-
items.append(item)
387+
for file in get_models(lora_models_dir, ['*.pt', '*.ckpt', '*.safetensors']):
388+
hash = modules.sd_models.model_hash(os.path.join(lora_models_dir, file))
389+
item = {}
390+
item['model_name'] = os.path.basename(file)
391+
item['title'] = '{0} [{1}]'.format(os.path.splitext(os.path.basename(file))[0], hash)
392+
item['endpoint_name'] = endpoint_name
393+
items.append(item)
378394
inputs = {
379395
'items': items
380396
}

0 commit comments

Comments
 (0)