Skip to content

Commit 3e4f422

Browse files
committed
cleanup
1 parent 4642385 commit 3e4f422

File tree

2 files changed

+35
-38
lines changed

2 files changed

+35
-38
lines changed

modules/shared.py

-1
Original file line numberDiff line numberDiff line change
@@ -99,7 +99,6 @@
9999
parser.add_argument('--sd-models-s3uri', default='', type=str, help='SD Models S3Uri')
100100
parser.add_argument('--db-models-s3uri', default='', type=str, help='DB Models S3Uri')
101101
parser.add_argument('--lora-models-s3uri', default='', type=str, help='Lora Models S3Uri')
102-
parser.add_argument('--region-name', type=str, help='Region Name')
103102
parser.add_argument('--username', default='', type=str, help='Username')
104103
parser.add_argument('--api-endpoint', default='', type=str, help='API Endpoint')
105104
parser.add_argument('--dreambooth-config-id', default='', type=str, help='Dreambooth config ID')

webui.py

+35-37
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
import os
22
import shutil
3-
import threading
43
import time
54
import importlib
65
import signal
@@ -16,7 +15,7 @@
1615
from modules.call_queue import wrap_queued_call, queue_lock
1716
from modules.paths import script_path
1817

19-
from modules import shared, devices, sd_samplers, upscaler, extensions, localization, ui_tempdir
18+
from modules import shared, sd_samplers, upscaler, extensions, localization, ui_tempdir
2019
import modules.codeformer_model as codeformer
2120
import modules.extras
2221
import modules.face_restoration
@@ -220,24 +219,23 @@ def webui():
220219
if launch_api:
221220
create_api(app)
222221

223-
ckpt_dir = cmd_opts.ckpt_dir
224-
sd_models_path = os.path.join(shared.models_path, "Stable-diffusion")
225-
if ckpt_dir is not None:
226-
sd_models_path = ckpt_dir
222+
cmd_sd_models_path = cmd_opts.ckpt_dir
223+
sd_models_dir = os.path.join(shared.models_path, "Stable-diffusion")
224+
if cmd_sd_models_path is not None:
225+
sd_models_dir = cmd_sd_models_path
227226

228-
controlnet_dir = cmd_opts.controlnet_dir
229-
cn_models_path = os.path.join(shared.models_path, "ControlNet")
230-
os.makedirs(controlnet_dir, exist_ok=True)
231-
if controlnet_dir is not None:
232-
cn_models_path = controlnet_dir
227+
cmd_controlnet_models_path = cmd_opts.controlnet_dir
228+
cn_models_dir = os.path.join(shared.models_path, "ControlNet")
229+
if cmd_controlnet_models_path is not None:
230+
cn_models_dir = cmd_controlnet_models_path
233231

234232
if 'endpoint_name' in os.environ:
235233
items = []
236234
api_endpoint = os.environ['api_endpoint']
237235
endpoint_name = os.environ['endpoint_name']
238-
for file in os.listdir(sd_models_path):
239-
if os.path.isfile(os.path.join(sd_models_path, file)) and (file.endswith('.ckpt') or file.endswith('.safesentors')):
240-
hash = modules.sd_models.model_hash(os.path.join(sd_models_path, file))
236+
for file in os.listdir(sd_models_dir):
237+
if os.path.isfile(os.path.join(sd_models_dir, file)) and (file.endswith('.ckpt') or file.endswith('.safesentors')):
238+
hash = modules.sd_models.model_hash(os.path.join(sd_models_dir, file))
241239
item = {}
242240
item['model_name'] = file
243241
item['config'] = '/opt/ml/code/stable-diffusion-webui/repositories/stable-diffusion/configs/stable-diffusion/v1-inference.yaml'
@@ -263,10 +261,10 @@ def webui():
263261
params = {
264262
'module': 'ControlNet'
265263
}
266-
for file in os.listdir(cn_models_path):
267-
if os.path.isfile(os.path.join(cn_models_path, file)) and \
264+
for file in os.listdir(cn_models_dir):
265+
if os.path.isfile(os.path.join(cn_models_dir, file)) and \
268266
(file.endswith('pt') or file.endswith('.pth') or file.endswith('.ckpt') or file.endswith('.safetensors')):
269-
hash = modules.sd_models.model_hash(os.path.join(cn_models_path, file))
267+
hash = modules.sd_models.model_hash(os.path.join(cn_models_dir, file))
270268
item = {}
271269
item['model_name'] = file
272270
item['title'] = '{0} [{1}]'.format(os.path.splitext(file)[0], hash)
@@ -565,10 +563,10 @@ def train():
565563
opts.data = default_options
566564
elif train_task == 'dreambooth':
567565
db_create_new_db_model = train_args['train_dreambooth_settings']['db_create_new_db_model']
566+
db_use_txt2img = train_args['train_dreambooth_settings']['db_use_txt2img']
568567
db_train_wizard_person = train_args['train_dreambooth_settings']['db_train_wizard_person']
569568
db_train_wizard_object = train_args['train_dreambooth_settings']['db_train_wizard_object']
570569
db_performance_wizard = train_args['train_dreambooth_settings']['db_performance_wizard']
571-
db_use_txt2img = train_args['train_dreambooth_settings']['db_use_txt2img']
572570

573571
if db_create_new_db_model:
574572
db_new_model_name = train_args['train_dreambooth_settings']['db_new_model_name']
@@ -623,13 +621,13 @@ def train():
623621
c1_num_class_images_per, \
624622
c2_num_class_images_per, \
625623
c3_num_class_images_per, \
626-
c4_num_class_images_per = training_wizard(db_config, db_train_wizard_person if db_train_wizard_person else db_train_wizard_object)
624+
c4_num_class_images_per = training_wizard(db_train_wizard_person if db_train_wizard_person else db_train_wizard_object)
627625

628626
params_dict['db_num_train_epochs'] = db_num_train_epochs
629-
params_dict[59] = c1_num_class_images_per
630-
params_dict[61] = c2_num_class_images_per
631-
params_dict[77] = c3_num_class_images_per
632-
params_dict[79] = c4_num_class_images_per
627+
params_dict['c1_num_class_images_per'] = c1_num_class_images_per
628+
params_dict['c1_num_class_images_per'] = c2_num_class_images_per
629+
params_dict['c1_num_class_images_per'] = c3_num_class_images_per
630+
params_dict['c1_num_class_images_per'] = c4_num_class_images_per
633631
if db_performance_wizard:
634632
attention, \
635633
gradient_checkpointing, \
@@ -684,17 +682,17 @@ def train():
684682
db_model_name = train_args['train_dreambooth_settings']['db_model_name']
685683
db_config = DreamboothConfig(db_model_name)
686684

687-
ckpt_dir = cmd_opts.ckpt_dir
688-
sd_models_path = os.path.join(shared.models_path, "Stable-diffusion")
689-
if ckpt_dir is not None:
690-
sd_models_path = ckpt_dir
691-
692685
print(vars(db_config))
693686
start_training_from_config(
694687
db_config,
695688
db_use_txt2img,
696689
)
697690

691+
cmd_sd_models_path = cmd_opts.ckpt_dir
692+
sd_models_dir = os.path.join(shared.models_path, "Stable-diffusion")
693+
if cmd_sd_models_path is not None:
694+
sd_models_dir = cmd_sd_models_path
695+
698696
try:
699697
cmd_dreambooth_models_path = cmd_opts.dreambooth_models_path
700698
except:
@@ -711,26 +709,26 @@ def train():
711709
lora_model_dir = os.path.dirname(cmd_lora_models_path) if cmd_lora_models_path else paths.models_path
712710
lora_model_dir = os.path.join(lora_model_dir, "lora")
713711

714-
print('---models path---', sd_models_path, lora_model_dir)
715-
os.system(f'ls -l {sd_models_path}')
716-
os.system('ls -l {0}'.format(os.path.join(sd_models_path, db_model_name)))
712+
print('---models path---', sd_models_dir, lora_model_dir)
713+
os.system(f'ls -l {sd_models_dir}')
714+
os.system('ls -l {0}'.format(os.path.join(sd_models_dir, db_model_name)))
717715
os.system(f'ls -l {lora_model_dir}')
718716

719717
try:
720718
print('Uploading SD Models...')
721719
upload_s3files(
722720
sd_models_s3uri,
723-
os.path.join(sd_models_path, db_model_name, f'{db_model_name}_*.yaml')
721+
os.path.join(sd_models_dir, db_model_name, f'{db_model_name}_*.yaml')
724722
)
725723
if db_save_safetensors:
726724
upload_s3files(
727725
sd_models_s3uri,
728-
os.path.join(sd_models_path, db_model_name, f'{db_model_name}_*.safetensors')
726+
os.path.join(sd_models_dir, db_model_name, f'{db_model_name}_*.safetensors')
729727
)
730728
else:
731729
upload_s3files(
732730
sd_models_s3uri,
733-
os.path.join(sd_models_path, db_model_name, f'{db_model_name}_*.ckpt')
731+
os.path.join(sd_models_dir, db_model_name, f'{db_model_name}_*.ckpt')
734732
)
735733
print('Uploading DB Models...')
736734
upload_s3folder(
@@ -747,15 +745,15 @@ def train():
747745
os.makedirs(os.path.dirname("/opt/ml/model/"), exist_ok=True)
748746
train_steps=int(db_config.revision)
749747
model_file_basename = f'{db_model_name}_{train_steps}_lora' if db_config.use_lora else f'{db_model_name}_{train_steps}'
750-
f1=os.path.join(sd_models_path, db_model_name, f'{model_file_basename}.yaml')
748+
f1=os.path.join(sd_models_dir, db_model_name, f'{model_file_basename}.yaml')
751749
if os.path.exists(f1):
752750
shutil.copy(f1,"/opt/ml/model/")
753751
if db_save_safetensors:
754-
f2=os.path.join(sd_models_path, db_model_name, f'{model_file_basename}.safetensors')
752+
f2=os.path.join(sd_models_dir, db_model_name, f'{model_file_basename}.safetensors')
755753
if os.path.exists(f2):
756754
shutil.copy(f2,"/opt/ml/model/")
757755
else:
758-
f2=os.path.join(sd_models_path, db_model_name, f'{model_file_basename}.ckpt')
756+
f2=os.path.join(sd_models_dir, db_model_name, f'{model_file_basename}.ckpt')
759757
if os.path.exists(f2):
760758
shutil.copy(f2,"/opt/ml/model/")
761759
except Exception as e:

0 commit comments

Comments
 (0)