@@ -12,7 +12,7 @@
 from fastapi.exceptions import RequestValidationError
 from fastapi.responses import JSONResponse

-from modules.call_queue import wrap_queued_call, queue_lock, wrap_gradio_gpu_call
+from modules.call_queue import wrap_queued_call, queue_lock
 from modules.paths import script_path

 from modules import shared, devices, sd_samplers, upscaler, extensions, localization, ui_tempdir
@@ -43,11 +43,9 @@
 import json
 import uuid
 if not cmd_opts.api:
-    from extensions.sd_dreambooth_extension.dreambooth.db_config import DreamboothConfig, sanitize_name
-    from extensions.sd_dreambooth_extension.dreambooth.sd_to_diff import extract_checkpoint
-    from extensions.sd_dreambooth_extension.dreambooth.dreambooth import start_training_from_config
-    from extensions.sd_dreambooth_extension.dreambooth.dreambooth import performance_wizard, training_wizard
-    from extensions.sd_dreambooth_extension.dreambooth.db_config import from_file
+    from extensions.sd_dreambooth_extension.dreambooth.db_config import DreamboothConfig
+    from extensions.sd_dreambooth_extension.scripts.dreambooth import start_training_from_config, create_model
+    from extensions.sd_dreambooth_extension.scripts.dreambooth import performance_wizard, training_wizard
 from modules import paths
 import glob

@@ -348,6 +346,7 @@ def train():
     hypernetwork_s3uri = cmd_opts.hypernetwork_s3uri
     sd_models_s3uri = cmd_opts.sd_models_s3uri
     db_models_s3uri = cmd_opts.db_models_s3uri
+    lora_models_s3uri = cmd_opts.lora_models_s3uri
     api_endpoint = cmd_opts.api_endpoint
     username = cmd_opts.username

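
The new lora_models_s3uri value is read from cmd_opts, so a matching command-line option has to be registered wherever the other *_s3uri flags are declared. A minimal sketch of that wiring, not part of this diff; the flag name and help text are assumptions:

# Hypothetical flag registration, assumed to sit next to the existing
# --sd-models-s3uri / --db-models-s3uri options (not shown in this diff).
parser.add_argument("--lora-models-s3uri", type=str, default=None,
                    help="S3 URI to upload trained LoRA weights to")
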
@@ -564,16 +563,10 @@ def train():
         opts.data = default_options
     elif train_task == 'dreambooth':
         db_create_new_db_model = train_args['train_dreambooth_settings']['db_create_new_db_model']
-
-        db_lora_model_name = train_args['train_dreambooth_settings']['db_lora_model_name']
-        db_lora_weight = train_args['train_dreambooth_settings']['db_lora_weight']
-        db_lora_txt_weight = train_args['train_dreambooth_settings']['db_lora_txt_weight']
-        db_train_imagic_only = train_args['train_dreambooth_settings']['db_train_imagic_only']
-        db_use_subdir = train_args['train_dreambooth_settings']['db_use_subdir']
-        db_custom_model_name = train_args['train_dreambooth_settings']['db_custom_model_name']
         db_train_wizard_person = train_args['train_dreambooth_settings']['db_train_wizard_person']
         db_train_wizard_object = train_args['train_dreambooth_settings']['db_train_wizard_object']
         db_performance_wizard = train_args['train_dreambooth_settings']['db_performance_wizard']
+        db_use_txt2img = train_args['train_dreambooth_settings']['db_use_txt2img']

         if db_create_new_db_model:
             db_new_model_name = train_args['train_dreambooth_settings']['db_new_model_name']
@@ -583,15 +576,20 @@ def train():
             db_new_model_url = train_args['train_dreambooth_settings']['db_new_model_url']
             db_new_model_token = train_args['train_dreambooth_settings']['db_new_model_token']
             db_new_model_extract_ema = train_args['train_dreambooth_settings']['db_new_model_extract_ema']
-            db_model_name, _, db_revision, db_scheduler, db_src, db_has_ema, db_v2, db_resolution = extract_checkpoint(
-                db_new_model_name,
-                db_new_model_src,
-                db_new_model_scheduler,
-                db_create_from_hub,
-                db_new_model_url,
-                db_new_model_token,
-                db_new_model_extract_ema
-            )
+            db_train_unfrozen = train_args['train_dreambooth_settings']['db_train_unfrozen']
+            db_512_model = train_args['train_dreambooth_settings']['db_512_model']
+
+            db_model_name, db_model_path, db_revision, db_epochs, db_scheduler, db_src, _, _, _ = create_model(
+                db_new_model_name,
+                db_new_model_src,
+                db_new_model_scheduler,
+                db_create_from_hub,
+                db_new_model_url,
+                db_new_model_token,
+                db_new_model_extract_ema,
+                db_train_unfrozen,
+                db_512_model,
+            )
             dreambooth_config_id = cmd_opts.dreambooth_config_id
             try:
                 with open(f'/opt/ml/input/data/config/{dreambooth_config_id}.json', 'r') as f:
@@ -605,91 +603,61 @@ def train():
                 content = None

             if content:
-                config_dict = json.loads(content)
-                print(db_model_name, db_revision, db_scheduler, db_src, db_has_ema, db_v2, db_resolution)
-
-                config_dict[0] = db_model_name
-                config_dict[31] = db_revision
-                config_dict[39] = db_scheduler
-                config_dict[40] = db_src
-                config_dict[14] = db_has_ema
-                config_dict[49] = db_v2
-                config_dict[30] = db_resolution
-
-                db_config = DreamboothConfig(*config_dict)
-
-                if db_train_wizard_person:
-                    _, \
-                    max_train_steps, \
-                    num_train_epochs, \
-                    c1_max_steps, \
-                    c1_num_class_images, \
-                    c2_max_steps, \
-                    c2_num_class_images, \
-                    c3_max_steps, \
-                    c3_num_class_images = training_wizard(db_config, True)
-
-                    config_dict[22] = int(max_train_steps)
-                    config_dict[26] = int(num_train_epochs)
-                    config_dict[59] = c1_max_steps
-                    config_dict[61] = c1_num_class_images
-                    config_dict[77] = c2_max_steps
-                    config_dict[79] = c2_num_class_images
-                    config_dict[95] = c3_max_steps
-                    config_dict[97] = c3_num_class_images
-                if db_train_wizard_object:
-                    _, \
-                    max_train_steps, \
-                    num_train_epochs, \
-                    c1_max_steps, \
-                    c1_num_class_images, \
-                    c2_max_steps, \
-                    c2_num_class_images, \
-                    c3_max_steps, \
-                    c3_num_class_images = training_wizard(db_config, False)
-
-                    config_dict[22] = int(max_train_steps)
-                    config_dict[26] = int(num_train_epochs)
-                    config_dict[59] = c1_max_steps
-                    config_dict[61] = c1_num_class_images
-                    config_dict[77] = c2_max_steps
-                    config_dict[79] = c2_num_class_images
-                    config_dict[95] = c3_max_steps
-                    config_dict[97] = c3_num_class_images
+                params_dict = json.loads(content)
+
+                params_dict['db_model_name'] = db_model_name
+                params_dict['db_model_path'] = db_model_path
+                params_dict['db_revision'] = db_revision
+                params_dict['db_epochs'] = db_epochs
+                params_dict['db_scheduler'] = db_scheduler
+                params_dict['db_src'] = db_src
+
+                if db_train_wizard_person or db_train_wizard_object:
+                    db_num_train_epochs, \
+                    c1_num_class_images_per, \
+                    c2_num_class_images_per, \
+                    c3_num_class_images_per, \
+                    c4_num_class_images_per = training_wizard(db_config, db_train_wizard_person if db_train_wizard_person else db_train_wizard_object)
+
+                    params_dict['db_num_train_epochs'] = db_num_train_epochs
+                    params_dict[59] = c1_num_class_images_per
+                    params_dict[61] = c2_num_class_images_per
+                    params_dict[77] = c3_num_class_images_per
+                    params_dict[79] = c4_num_class_images_per
                 if db_performance_wizard:
-                    _, \
                     attention, \
                     gradient_checkpointing, \
+                    gradient_accumulation_steps, \
                     mixed_precision, \
-                    not_cache_latents, \
+                    cache_latents, \
                     sample_batch_size, \
                     train_batch_size, \
-                    train_text_encoder, \
+                    stop_text_encoder, \
                     use_8bit_adam, \
-                    use_cpu, \
-                    use_ema = performance_wizard()
-
-                    config_dict[5] = attention
-                    config_dict[12] = gradient_checkpointing
-                    config_dict[23] = mixed_precision
-                    config_dict[25] = not_cache_latents
-                    config_dict[32] = sample_batch_size
-                    config_dict[42] = train_batch_size
-                    config_dict[43] = train_text_encoder
-                    config_dict[44] = use_8bit_adam
-                    config_dict[46] = use_cpu
-                    config_dict[47] = use_ema
+                    use_lora, \
+                    use_ema, \
+                    save_samples_every, \
+                    save_weights_every = performance_wizard()
+
+                    params_dict['attention'] = attention
+                    params_dict['gradient_checkpointing'] = gradient_checkpointing
+                    params_dict['gradient_accumulation_steps'] = gradient_accumulation_steps
+                    params_dict['mixed_precision'] = mixed_precision
+                    params_dict['cache_latents'] = cache_latents
+                    params_dict['sample_batch_size'] = sample_batch_size
+                    params_dict['train_batch_size'] = train_batch_size
+                    params_dict['stop_text_encoder'] = stop_text_encoder
+                    params_dict['use_8bit_adam'] = use_8bit_adam
+                    params_dict['use_lora'] = use_lora
+                    params_dict['use_ema'] = use_ema
+                    params_dict['save_samples_every'] = save_samples_every
+                    params_dict['params_dict'] = save_weights_every
+
+                db_config = DreamboothConfig(db_model_name)
+                db_config.load_params(params_dict)
         else:
             db_model_name = train_args['train_dreambooth_settings']['db_model_name']
-            db_model_name = sanitize_name(db_model_name)
-            db_models_path = cmd_opts.dreambooth_models_path
-            if db_models_path == "" or db_models_path is None:
-                db_models_path = os.path.join(shared.models_path, "dreambooth")
-            working_dir = os.path.join(db_models_path, db_model_name, "working")
-            config_dict = from_file(os.path.join(db_models_path, db_model_name))
-            config_dict["pretrained_model_name_or_path"] = working_dir
-
-            db_config = DreamboothConfig(*config_dict)
+            db_config = DreamboothConfig(db_model_name)

         ckpt_dir = cmd_opts.ckpt_dir
         sd_models_path = os.path.join(shared.models_path, "Stable-diffusion")
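
For reference, the file read from /opt/ml/input/data/config/{dreambooth_config_id}.json is a plain JSON object whose keys are merged into params_dict above and then applied with db_config.load_params(). A minimal illustrative payload, using only keys that appear in this hunk; a real config would carry the full DreamboothConfig parameter set, and the example values are assumptions:

# Illustrative only: a tiny dreambooth config payload as consumed above.
# Keys mirror the params_dict assignments in this hunk; values are made up.
import json

example_params = {
    "db_num_train_epochs": 100,
    "train_batch_size": 1,
    "gradient_accumulation_steps": 1,
    "mixed_precision": "fp16",
    "use_lora": True,
    "use_ema": False,
}

with open("/opt/ml/input/data/config/example.json", "w") as f:
    json.dump(example_params, f, indent=2)
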
@@ -699,12 +667,7 @@ def train():
         print(vars(db_config))
         start_training_from_config(
             db_config,
-            db_lora_model_name if db_lora_model_name != '' else None,
-            db_lora_weight,
-            db_lora_txt_weight,
-            db_train_imagic_only,
-            db_use_subdir,
-            db_custom_model_name
+            db_use_txt2img,
         )

         try:
@@ -715,12 +678,14 @@ def train():
         db_model_dir = os.path.dirname(cmd_dreambooth_models_path) if cmd_dreambooth_models_path else paths.models_path
         db_model_dir = os.path.join(db_model_dir, "dreambooth")

+        lora_models_path = os.path.join(shared.models_path, "Lora")
+
         try:
-            print('Uploading SD Models...')
+            print('Uploading SD Models...')
             upload_s3files(
                 sd_models_s3uri,
                 os.path.join(sd_models_path, f'{sd_models_path}/{db_model_name}_*.yaml')
-            )
+            )
             upload_s3files(
                 sd_models_s3uri,
                 os.path.join(sd_models_path, f'{sd_models_path}/{db_model_name}_*.ckpt')
@@ -730,6 +695,12 @@ def train():
                 f'{db_models_s3uri}{db_model_name}',
                 os.path.join(db_model_dir, db_model_name)
             )
+            if db_config.use_lora:
+                print('Uploading Lora Models...')
+                upload_s3files(
+                    lora_models_s3uri,
+                    os.path.join(lora_models_path, f'{lora_models_path}/{db_model_name}_*.pt')
+                )
         except Exception as e:
             traceback.print_exc()
             print(e)
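
upload_s3files is defined elsewhere in this file and is only called in this diff. As a rough sketch of what such a helper has to do, here is a hypothetical boto3-based implementation (the function name, argument names, and S3 URI handling are assumptions, not the actual code):

# Hypothetical sketch of an upload_s3files-style helper (not from this diff):
# expand the local glob pattern and push every match under the S3 prefix.
import glob
import os
from urllib.parse import urlparse

import boto3

def upload_s3files_sketch(s3uri, file_path_with_pattern):
    parsed = urlparse(s3uri)                                  # e.g. s3://bucket/prefix/
    bucket, prefix = parsed.netloc, parsed.path.lstrip('/')
    s3_client = boto3.client('s3')
    for local_path in glob.glob(file_path_with_pattern):
        key = os.path.join(prefix, os.path.basename(local_path))
        s3_client.upload_file(local_path, bucket, key)        # managed upload to S3
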