@@ -215,6 +215,27 @@ def user_auth(username, password):
215
215
216
216
return response .status_code == 200
217
217
218
+ def get_bucket_and_key (s3uri ):
219
+ pos = s3uri .find ('/' , 5 )
220
+ bucket = s3uri [5 : pos ]
221
+ key = s3uri [pos + 1 : ]
222
+ return bucket , key
223
+
224
+ def get_models (path , extensions ):
225
+ candidates = []
226
+ models = []
227
+
228
+ for extension in extensions :
229
+ candidates = candidates + glob .glob (os .path .join (path , f'**/{ extension } ' ), recursive = True )
230
+
231
+ for filename in sorted (candidates , key = str .lower ):
232
+ if os .path .isdir (filename ):
233
+ continue
234
+
235
+ models .append (filename )
236
+
237
+ return models
238
+
218
239
def webui ():
219
240
launch_api = cmd_opts .api
220
241
@@ -301,6 +322,8 @@ def webui():
301
322
if launch_api :
302
323
create_api (app )
303
324
325
+ os .path .splitext (os .path .basename (filename ))[0 ]
326
+
304
327
cmd_sd_models_path = cmd_opts .ckpt_dir
305
328
sd_models_dir = os .path .join (shared .models_path , "Stable-diffusion" )
306
329
if cmd_sd_models_path is not None :
@@ -324,17 +347,14 @@ def webui():
324
347
params = {
325
348
'module' : 'Stable-diffusion'
326
349
}
327
- for file in os .listdir (sd_models_dir ):
328
- if os .path .isfile (os .path .join (sd_models_dir , file )) and (file .endswith ('.ckpt' ) or file .endswith ('.safetensors' )):
329
- hash = modules .sd_models .model_hash (os .path .join (sd_models_dir , file ))
330
- item = {}
331
- item ['model_name' ] = file
332
- item ['config' ] = '/opt/ml/code/stable-diffusion-webui/repositories/stable-diffusion/configs/stable-diffusion/v1-inference.yaml'
333
- item ['filename' ] = '/opt/ml/code/stable-diffusion-webui/models/Stable-diffusion/{0}' .format (file )
334
- item ['hash' ] = hash
335
- item ['title' ] = '{0} [{1}]' .format (file , hash )
336
- item ['endpoint_name' ] = endpoint_name
337
- items .append (item )
350
+ for file in get_models (sd_models_dir , ['*.ckpt' , '*.safetensors' ]):
351
+ hash = modules .sd_models .model_hash (file )
352
+ item = {}
353
+ item ['model_name' ] = os .path .basename (file )
354
+ item ['hash' ] = hash
355
+ item ['title' ] = '{0} [{1}]' .format (os .path .basename (file ), hash )
356
+ item ['endpoint_name' ] = endpoint_name
357
+ items .append (item )
338
358
inputs = {
339
359
'items' : items
340
360
}
@@ -346,15 +366,13 @@ def webui():
346
366
params = {
347
367
'module' : 'ControlNet'
348
368
}
349
- for file in os .listdir (cn_models_dir ):
350
- if os .path .isfile (os .path .join (cn_models_dir , file )) and \
351
- (file .endswith ('.pt' ) or file .endswith ('.pth' ) or file .endswith ('.ckpt' ) or file .endswith ('.safetensors' )):
352
- hash = modules .sd_models .model_hash (os .path .join (cn_models_dir , file ))
353
- item = {}
354
- item ['model_name' ] = file
355
- item ['title' ] = '{0} [{1}]' .format (os .path .splitext (file )[0 ], hash )
356
- item ['endpoint_name' ] = endpoint_name
357
- items .append (item )
369
+ for file in get_models (cn_models_dir , ['*.pt' , '*.pth' , '*.ckpt' , '*.safetensors' ]):
370
+ hash = modules .sd_models .model_hash (os .path .join (cn_models_dir , file ))
371
+ item = {}
372
+ item ['model_name' ] = os .path .basename (file )
373
+ item ['title' ] = '{0} [{1}]' .format (os .path .splitext (os .path .basename (file ))[0 ], hash )
374
+ item ['endpoint_name' ] = endpoint_name
375
+ items .append (item )
358
376
inputs = {
359
377
'items' : items
360
378
}
@@ -366,15 +384,13 @@ def webui():
366
384
params = {
367
385
'module' : 'Lora'
368
386
}
369
- for file in os .listdir (lora_models_dir ):
370
- if os .path .isfile (os .path .join (lora_models_dir , file )) and \
371
- (file .endswith ('.pt' ) or file .endswith ('.ckpt' ) or file .endswith ('.safetensors' )):
372
- hash = modules .sd_models .model_hash (os .path .join (lora_models_dir , file ))
373
- item = {}
374
- item ['model_name' ] = file
375
- item ['title' ] = '{0} [{1}]' .format (os .path .splitext (file )[0 ], hash )
376
- item ['endpoint_name' ] = endpoint_name
377
- items .append (item )
387
+ for file in get_models (lora_models_dir , ['*.pt' , '*.ckpt' , '*.safetensors' ]):
388
+ hash = modules .sd_models .model_hash (os .path .join (lora_models_dir , file ))
389
+ item = {}
390
+ item ['model_name' ] = os .path .basename (file )
391
+ item ['title' ] = '{0} [{1}]' .format (os .path .splitext (os .path .basename (file ))[0 ], hash )
392
+ item ['endpoint_name' ] = endpoint_name
393
+ items .append (item )
378
394
inputs = {
379
395
'items' : items
380
396
}
0 commit comments