1
1
import os
2
2
import shutil
3
- import threading
4
3
import time
5
4
import importlib
6
5
import signal
16
15
from modules .call_queue import wrap_queued_call , queue_lock
17
16
from modules .paths import script_path
18
17
19
- from modules import shared , devices , sd_samplers , upscaler , extensions , localization , ui_tempdir
18
+ from modules import shared , sd_samplers , upscaler , extensions , localization , ui_tempdir
20
19
import modules .codeformer_model as codeformer
21
20
import modules .extras
22
21
import modules .face_restoration
@@ -220,24 +219,23 @@ def webui():
220
219
if launch_api :
221
220
create_api (app )
222
221
223
- ckpt_dir = cmd_opts .ckpt_dir
224
- sd_models_path = os .path .join (shared .models_path , "Stable-diffusion" )
225
- if ckpt_dir is not None :
226
- sd_models_path = ckpt_dir
222
+ cmd_sd_models_path = cmd_opts .ckpt_dir
223
+ sd_models_dir = os .path .join (shared .models_path , "Stable-diffusion" )
224
+ if cmd_sd_models_path is not None :
225
+ sd_models_dir = cmd_sd_models_path
227
226
228
- controlnet_dir = cmd_opts .controlnet_dir
229
- cn_models_path = os .path .join (shared .models_path , "ControlNet" )
230
- os .makedirs (controlnet_dir , exist_ok = True )
231
- if controlnet_dir is not None :
232
- cn_models_path = controlnet_dir
227
+ cmd_controlnet_models_path = cmd_opts .controlnet_dir
228
+ cn_models_dir = os .path .join (shared .models_path , "ControlNet" )
229
+ if cmd_controlnet_models_path is not None :
230
+ cn_models_dir = cmd_controlnet_models_path
233
231
234
232
if 'endpoint_name' in os .environ :
235
233
items = []
236
234
api_endpoint = os .environ ['api_endpoint' ]
237
235
endpoint_name = os .environ ['endpoint_name' ]
238
- for file in os .listdir (sd_models_path ):
239
- if os .path .isfile (os .path .join (sd_models_path , file )) and (file .endswith ('.ckpt' ) or file .endswith ('.safesentors' )):
240
- hash = modules .sd_models .model_hash (os .path .join (sd_models_path , file ))
236
+ for file in os .listdir (sd_models_dir ):
237
+ if os .path .isfile (os .path .join (sd_models_dir , file )) and (file .endswith ('.ckpt' ) or file .endswith ('.safesentors' )):
238
+ hash = modules .sd_models .model_hash (os .path .join (sd_models_dir , file ))
241
239
item = {}
242
240
item ['model_name' ] = file
243
241
item ['config' ] = '/opt/ml/code/stable-diffusion-webui/repositories/stable-diffusion/configs/stable-diffusion/v1-inference.yaml'
@@ -263,10 +261,10 @@ def webui():
263
261
params = {
264
262
'module' : 'ControlNet'
265
263
}
266
- for file in os .listdir (cn_models_path ):
267
- if os .path .isfile (os .path .join (cn_models_path , file )) and \
264
+ for file in os .listdir (cn_models_dir ):
265
+ if os .path .isfile (os .path .join (cn_models_dir , file )) and \
268
266
(file .endswith ('pt' ) or file .endswith ('.pth' ) or file .endswith ('.ckpt' ) or file .endswith ('.safetensors' )):
269
- hash = modules .sd_models .model_hash (os .path .join (cn_models_path , file ))
267
+ hash = modules .sd_models .model_hash (os .path .join (cn_models_dir , file ))
270
268
item = {}
271
269
item ['model_name' ] = file
272
270
item ['title' ] = '{0} [{1}]' .format (os .path .splitext (file )[0 ], hash )
@@ -565,10 +563,10 @@ def train():
565
563
opts .data = default_options
566
564
elif train_task == 'dreambooth' :
567
565
db_create_new_db_model = train_args ['train_dreambooth_settings' ]['db_create_new_db_model' ]
566
+ db_use_txt2img = train_args ['train_dreambooth_settings' ]['db_use_txt2img' ]
568
567
db_train_wizard_person = train_args ['train_dreambooth_settings' ]['db_train_wizard_person' ]
569
568
db_train_wizard_object = train_args ['train_dreambooth_settings' ]['db_train_wizard_object' ]
570
569
db_performance_wizard = train_args ['train_dreambooth_settings' ]['db_performance_wizard' ]
571
- db_use_txt2img = train_args ['train_dreambooth_settings' ]['db_use_txt2img' ]
572
570
573
571
if db_create_new_db_model :
574
572
db_new_model_name = train_args ['train_dreambooth_settings' ]['db_new_model_name' ]
@@ -623,13 +621,13 @@ def train():
623
621
c1_num_class_images_per , \
624
622
c2_num_class_images_per , \
625
623
c3_num_class_images_per , \
626
- c4_num_class_images_per = training_wizard (db_config , db_train_wizard_person if db_train_wizard_person else db_train_wizard_object )
624
+ c4_num_class_images_per = training_wizard (db_train_wizard_person if db_train_wizard_person else db_train_wizard_object )
627
625
628
626
params_dict ['db_num_train_epochs' ] = db_num_train_epochs
629
- params_dict [59 ] = c1_num_class_images_per
630
- params_dict [61 ] = c2_num_class_images_per
631
- params_dict [77 ] = c3_num_class_images_per
632
- params_dict [79 ] = c4_num_class_images_per
627
+ params_dict ['c1_num_class_images_per' ] = c1_num_class_images_per
628
+ params_dict ['c1_num_class_images_per' ] = c2_num_class_images_per
629
+ params_dict ['c1_num_class_images_per' ] = c3_num_class_images_per
630
+ params_dict ['c1_num_class_images_per' ] = c4_num_class_images_per
633
631
if db_performance_wizard :
634
632
attention , \
635
633
gradient_checkpointing , \
@@ -684,17 +682,17 @@ def train():
684
682
db_model_name = train_args ['train_dreambooth_settings' ]['db_model_name' ]
685
683
db_config = DreamboothConfig (db_model_name )
686
684
687
- ckpt_dir = cmd_opts .ckpt_dir
688
- sd_models_path = os .path .join (shared .models_path , "Stable-diffusion" )
689
- if ckpt_dir is not None :
690
- sd_models_path = ckpt_dir
691
-
692
685
print (vars (db_config ))
693
686
start_training_from_config (
694
687
db_config ,
695
688
db_use_txt2img ,
696
689
)
697
690
691
+ cmd_sd_models_path = cmd_opts .ckpt_dir
692
+ sd_models_dir = os .path .join (shared .models_path , "Stable-diffusion" )
693
+ if cmd_sd_models_path is not None :
694
+ sd_models_dir = cmd_sd_models_path
695
+
698
696
try :
699
697
cmd_dreambooth_models_path = cmd_opts .dreambooth_models_path
700
698
except :
@@ -711,26 +709,26 @@ def train():
711
709
lora_model_dir = os .path .dirname (cmd_lora_models_path ) if cmd_lora_models_path else paths .models_path
712
710
lora_model_dir = os .path .join (lora_model_dir , "lora" )
713
711
714
- print ('---models path---' , sd_models_path , lora_model_dir )
715
- os .system (f'ls -l { sd_models_path } ' )
716
- os .system ('ls -l {0}' .format (os .path .join (sd_models_path , db_model_name )))
712
+ print ('---models path---' , sd_models_dir , lora_model_dir )
713
+ os .system (f'ls -l { sd_models_dir } ' )
714
+ os .system ('ls -l {0}' .format (os .path .join (sd_models_dir , db_model_name )))
717
715
os .system (f'ls -l { lora_model_dir } ' )
718
716
719
717
try :
720
718
print ('Uploading SD Models...' )
721
719
upload_s3files (
722
720
sd_models_s3uri ,
723
- os .path .join (sd_models_path , db_model_name , f'{ db_model_name } _*.yaml' )
721
+ os .path .join (sd_models_dir , db_model_name , f'{ db_model_name } _*.yaml' )
724
722
)
725
723
if db_save_safetensors :
726
724
upload_s3files (
727
725
sd_models_s3uri ,
728
- os .path .join (sd_models_path , db_model_name , f'{ db_model_name } _*.safetensors' )
726
+ os .path .join (sd_models_dir , db_model_name , f'{ db_model_name } _*.safetensors' )
729
727
)
730
728
else :
731
729
upload_s3files (
732
730
sd_models_s3uri ,
733
- os .path .join (sd_models_path , db_model_name , f'{ db_model_name } _*.ckpt' )
731
+ os .path .join (sd_models_dir , db_model_name , f'{ db_model_name } _*.ckpt' )
734
732
)
735
733
print ('Uploading DB Models...' )
736
734
upload_s3folder (
@@ -747,15 +745,15 @@ def train():
747
745
os .makedirs (os .path .dirname ("/opt/ml/model/" ), exist_ok = True )
748
746
train_steps = int (db_config .revision )
749
747
model_file_basename = f'{ db_model_name } _{ train_steps } _lora' if db_config .use_lora else f'{ db_model_name } _{ train_steps } '
750
- f1 = os .path .join (sd_models_path , db_model_name , f'{ model_file_basename } .yaml' )
748
+ f1 = os .path .join (sd_models_dir , db_model_name , f'{ model_file_basename } .yaml' )
751
749
if os .path .exists (f1 ):
752
750
shutil .copy (f1 ,"/opt/ml/model/" )
753
751
if db_save_safetensors :
754
- f2 = os .path .join (sd_models_path , db_model_name , f'{ model_file_basename } .safetensors' )
752
+ f2 = os .path .join (sd_models_dir , db_model_name , f'{ model_file_basename } .safetensors' )
755
753
if os .path .exists (f2 ):
756
754
shutil .copy (f2 ,"/opt/ml/model/" )
757
755
else :
758
- f2 = os .path .join (sd_models_path , db_model_name , f'{ model_file_basename } .ckpt' )
756
+ f2 = os .path .join (sd_models_dir , db_model_name , f'{ model_file_basename } .ckpt' )
759
757
if os .path .exists (f2 ):
760
758
shutil .copy (f2 ,"/opt/ml/model/" )
761
759
except Exception as e :
0 commit comments