42
42
import io
43
43
import json
44
44
import uuid
45
+ from extensions .sd_dreambooth_extension .dreambooth .db_config import DreamboothConfig , sanitize_name
46
+ from extensions .sd_dreambooth_extension .dreambooth .sd_to_diff import extract_checkpoint
47
+ from extensions .sd_dreambooth_extension .dreambooth .dreambooth import start_training_from_config
48
+ from extensions .sd_dreambooth_extension .dreambooth .dreambooth import performance_wizard , training_wizard
49
+ from extensions .sd_dreambooth_extension .dreambooth .db_config import from_file
50
+ from modules import paths
51
+ import glob
45
52
46
53
if cmd_opts .server_name :
47
54
server_name = cmd_opts .server_name
@@ -134,6 +141,31 @@ def api_only():
134
141
app .add_middleware (GZipMiddleware , minimum_size = 1000 )
135
142
api = create_api (app )
136
143
144
+ ckpt_dir = cmd_opts .ckpt_dir
145
+ sd_models_path = os .path .join (shared .models_path , "Stable-diffusion" )
146
+ if ckpt_dir is not None :
147
+ sd_models_path = ckpt_dir
148
+
149
+ if 'endpoint_name' in os .environ :
150
+ items = []
151
+ api_endpoint = os .environ ['api_endpoint' ]
152
+ endpoint_name = os .environ ['endpoint_name' ]
153
+ for file in os .listdir (sd_models_path ):
154
+ if os .path .isfile (os .path .join (sd_models_path , file )) and file .endswith ('.ckpt' ):
155
+ hash = modules .sd_models .model_hash (os .path .join (sd_models_path , file ))
156
+ item = {}
157
+ item ['model_name' ] = file
158
+ item ['config' ] = '/opt/ml/code/stable-diffusion-webui/repositories/stable-diffusion/configs/stable-diffusion/v1-inference.yaml'
159
+ item ['filename' ] = '/opt/ml/code/stable-diffusion-webui/models/Stable-diffusion/{0}' .format (file )
160
+ item ['hash' ] = hash
161
+ item ['title' ] = '{0} [{1}]' .format (file , hash )
162
+ item ['endpoint_name' ] = endpoint_name
163
+ items .append (item )
164
+ inputs = {
165
+ 'items' : items
166
+ }
167
+ response = requests .post (url = f'{ api_endpoint } /sd/models' , json = inputs )
168
+
137
169
modules .script_callbacks .app_started_callback (None , app )
138
170
139
171
@app .exception_handler (RequestValidationError )
@@ -221,6 +253,41 @@ def upload_s3file(s3uri, file_path, file_name):
221
253
return False
222
254
return True
223
255
256
+ def upload_s3files (s3uri , file_path_with_pattern ):
257
+ s3_client = boto3 .client ('s3' , region_name = cmd_opts .region_name )
258
+
259
+ pos = s3uri .find ('/' , 5 )
260
+ bucket = s3uri [5 : pos ]
261
+ key = s3uri [pos + 1 : ]
262
+
263
+ for file_name in glob .glob (file_path_with_pattern ):
264
+ binary = io .BytesIO (open (file_name , 'rb' ).read ())
265
+ key = key + file_name
266
+ try :
267
+ s3_client .upload_fileobj (binary , bucket , key )
268
+ except ClientError as e :
269
+ print (e )
270
+ return False
271
+ return True
272
+
273
+ def upload_s3folder (s3uri , file_path ):
274
+ pos = s3uri .find ('/' , 5 )
275
+ bucket = s3uri [5 : pos ]
276
+
277
+ s3_resource = boto3 .resource ('s3' )
278
+ s3_bucket = s3_resource .Bucket (bucket )
279
+
280
+ try :
281
+ for path , _ , files in os .walk (file_path ):
282
+ for file in files :
283
+ dest_path = path .replace (file_path ,"" )
284
+ __s3file = os .path .normpath (s3uri + dest_path + '/' + file )
285
+ __local_file = os .path .join (path , file )
286
+ print (__local_file , __s3file )
287
+ s3_bucket .upload_file (__local_file , __s3file )
288
+ except Exception as e :
289
+ print (e )
290
+
224
291
def train ():
225
292
initialize ()
226
293
@@ -229,6 +296,8 @@ def train():
229
296
230
297
embeddings_s3uri = cmd_opts .embeddings_s3uri
231
298
hypernetwork_s3uri = cmd_opts .hypernetwork_s3uri
299
+ sd_models_s3uri = cmd_opts .sd_models_s3uri
300
+ db_models_s3uri = cmd_opts .db_models_s3uri
232
301
api_endpoint = cmd_opts .api_endpoint
233
302
username = cmd_opts .username
234
303
@@ -441,6 +510,172 @@ def train():
441
510
traceback .print_exc ()
442
511
print (e )
443
512
opts .data = default_options
513
+ elif train_task == 'dreambooth' :
514
+ db_create_new_db_model = train_args ['train_dreambooth_settings' ]['db_create_new_db_model' ]
515
+
516
+ db_lora_model_name = train_args ['train_dreambooth_settings' ]['db_lora_model_name' ]
517
+ db_lora_weight = train_args ['train_dreambooth_settings' ]['db_lora_weight' ]
518
+ db_lora_txt_weight = train_args ['train_dreambooth_settings' ]['db_lora_txt_weight' ]
519
+ db_train_imagic_only = train_args ['train_dreambooth_settings' ]['db_train_imagic_only' ]
520
+ db_use_subdir = train_args ['train_dreambooth_settings' ]['db_use_subdir' ]
521
+ db_custom_model_name = train_args ['train_dreambooth_settings' ]['db_custom_model_name' ]
522
+ db_train_wizard_person = train_args ['train_dreambooth_settings' ]['db_train_wizard_person' ]
523
+ db_train_wizard_object = train_args ['train_dreambooth_settings' ]['db_train_wizard_object' ]
524
+ db_performance_wizard = train_args ['train_dreambooth_settings' ]['db_performance_wizard' ]
525
+
526
+ if db_create_new_db_model :
527
+ db_new_model_name = train_args ['train_dreambooth_settings' ]['db_new_model_name' ]
528
+ db_new_model_src = train_args ['train_dreambooth_settings' ]['db_new_model_src' ]
529
+ db_new_model_scheduler = train_args ['train_dreambooth_settings' ]['db_new_model_scheduler' ]
530
+ db_create_from_hub = train_args ['train_dreambooth_settings' ]['db_create_from_hub' ]
531
+ db_new_model_url = train_args ['train_dreambooth_settings' ]['db_new_model_url' ]
532
+ db_new_model_token = train_args ['train_dreambooth_settings' ]['db_new_model_token' ]
533
+ db_new_model_extract_ema = train_args ['train_dreambooth_settings' ]['db_new_model_extract_ema' ]
534
+ db_model_name , _ , db_revision , db_scheduler , db_src , db_has_ema , db_v2 , db_resolution = extract_checkpoint (
535
+ db_new_model_name ,
536
+ db_new_model_src ,
537
+ db_new_model_scheduler ,
538
+ db_create_from_hub ,
539
+ db_new_model_url ,
540
+ db_new_model_token ,
541
+ db_new_model_extract_ema
542
+ )
543
+ dreambooth_config_id = cmd_opts .dreambooth_config_id
544
+ try :
545
+ with open (f'/opt/ml/input/data/config/{ dreambooth_config_id } .json' , 'r' ) as f :
546
+ content = f .read ()
547
+ except Exception :
548
+ params = {'module' : 'dreambooth_config' , 'dreambooth_config_id' : dreambooth_config_id }
549
+ response = requests .get (url = f'{ api_endpoint } /sd/models' , params = params )
550
+ if response .status_code == 200 :
551
+ content = response .text
552
+ else :
553
+ content = None
554
+
555
+ if content :
556
+ config_dict = json .loads (content )
557
+ print (db_model_name , db_revision , db_scheduler , db_src , db_has_ema , db_v2 , db_resolution )
558
+
559
+ config_dict [0 ] = db_model_name
560
+ config_dict [31 ] = db_revision
561
+ config_dict [39 ] = db_scheduler
562
+ config_dict [40 ] = db_src
563
+ config_dict [14 ] = db_has_ema
564
+ config_dict [49 ] = db_v2
565
+ config_dict [30 ] = db_resolution
566
+
567
+ db_config = DreamboothConfig (* config_dict )
568
+
569
+ if db_train_wizard_person :
570
+ _ , \
571
+ max_train_steps , \
572
+ num_train_epochs , \
573
+ c1_max_steps , \
574
+ c1_num_class_images , \
575
+ c2_max_steps , \
576
+ c2_num_class_images , \
577
+ c3_max_steps , \
578
+ c3_num_class_images = training_wizard (db_config , True )
579
+
580
+ config_dict [22 ] = int (max_train_steps )
581
+ config_dict [26 ] = int (num_train_epochs )
582
+ config_dict [59 ] = c1_max_steps
583
+ config_dict [61 ] = c1_num_class_images
584
+ config_dict [77 ] = c2_max_steps
585
+ config_dict [79 ] = c2_num_class_images
586
+ config_dict [95 ] = c3_max_steps
587
+ config_dict [97 ] = c3_num_class_images
588
+ if db_train_wizard_object :
589
+ _ , \
590
+ max_train_steps , \
591
+ num_train_epochs , \
592
+ c1_max_steps , \
593
+ c1_num_class_images , \
594
+ c2_max_steps , \
595
+ c2_num_class_images , \
596
+ c3_max_steps , \
597
+ c3_num_class_images = training_wizard (db_config , False )
598
+
599
+ config_dict [22 ] = int (max_train_steps )
600
+ config_dict [26 ] = int (num_train_epochs )
601
+ config_dict [59 ] = c1_max_steps
602
+ config_dict [61 ] = c1_num_class_images
603
+ config_dict [77 ] = c2_max_steps
604
+ config_dict [79 ] = c2_num_class_images
605
+ config_dict [95 ] = c3_max_steps
606
+ config_dict [97 ] = c3_num_class_images
607
+ if db_performance_wizard :
608
+ _ , \
609
+ attention , \
610
+ gradient_checkpointing , \
611
+ mixed_precision , \
612
+ not_cache_latents , \
613
+ sample_batch_size , \
614
+ train_batch_size , \
615
+ train_text_encoder , \
616
+ use_8bit_adam , \
617
+ use_cpu , \
618
+ use_ema = performance_wizard ()
619
+
620
+ config_dict [5 ] = attention
621
+ config_dict [12 ] = gradient_checkpointing
622
+ config_dict [23 ] = mixed_precision
623
+ config_dict [25 ] = not_cache_latents
624
+ config_dict [32 ] = sample_batch_size
625
+ config_dict [42 ] = train_batch_size
626
+ config_dict [43 ] = train_text_encoder
627
+ config_dict [44 ] = use_8bit_adam
628
+ config_dict [46 ] = use_cpu
629
+ config_dict [47 ] = use_ema
630
+ else :
631
+ db_model_name = train_args ['train_dreambooth_settings' ]['db_model_name' ]
632
+ db_model_name = sanitize_name (db_model_name )
633
+ db_models_path = cmd_opts .dreambooth_models_path
634
+ if db_models_path == "" or db_models_path is None :
635
+ db_models_path = os .path .join (shared .models_path , "dreambooth" )
636
+ working_dir = os .path .join (db_models_path , db_model_name , "working" )
637
+ config_dict = from_file (os .path .join (db_models_path , db_model_name ))
638
+ config_dict ["pretrained_model_name_or_path" ] = working_dir
639
+
640
+ db_config = DreamboothConfig (* config_dict )
641
+
642
+ ckpt_dir = cmd_opts .ckpt_dir
643
+ sd_models_path = os .path .join (shared .models_path , "Stable-diffusion" )
644
+ if ckpt_dir is not None :
645
+ sd_models_path = ckpt_dir
646
+
647
+ print (vars (db_config ))
648
+ start_training_from_config (
649
+ db_config ,
650
+ db_lora_model_name if db_lora_model_name != '' else None ,
651
+ db_lora_weight ,
652
+ db_lora_txt_weight ,
653
+ db_train_imagic_only ,
654
+ db_use_subdir ,
655
+ db_custom_model_name
656
+ )
657
+
658
+ try :
659
+ cmd_dreambooth_models_path = cmd_opts .dreambooth_models_path
660
+ except :
661
+ cmd_dreambooth_models_path = None
662
+
663
+ db_model_dir = os .path .dirname (cmd_dreambooth_models_path ) if cmd_dreambooth_models_path else paths .models_path
664
+ db_model_dir = os .path .join (db_model_dir , "dreambooth" )
665
+
666
+ try :
667
+ upload_s3files (
668
+ sd_models_s3uri ,
669
+ os .path .join (sd_models_path , f'{ sd_models_path } /{ db_model_name } _*.pt' )
670
+ )
671
+ upload_s3folder (
672
+ db_models_s3uri ,
673
+ db_model_dir
674
+ )
675
+ except Exception as e :
676
+ traceback .print_exc ()
677
+ print (e )
678
+ opts .data = default_options
444
679
else :
445
680
print ('Incorrect training task' )
446
681
exit (- 1 )
0 commit comments