Skip to content

Commit 14ad93d

Browse files
committed
cleanup
1 parent 8e0e392 commit 14ad93d

File tree

4 files changed

+80
-108
lines changed

4 files changed

+80
-108
lines changed

localizations/zh_CN.json

+1-1
Original file line numberDiff line numberDiff line change
@@ -683,7 +683,7 @@
683683
"Save Preview/Ckpt Every Epoch": "经过若干个 Epoch 保存预览/检查点",
684684
"Save Checkpoint Frequency": "保存检查点频率",
685685
"Save Preview(s) Frequency": "保存预览频率",
686-
"Batch": "批处理",
686+
"Batching": "批处理",
687687
"Batch Size": "批量大小",
688688
"Class Batch Size": "类批量大小",
689689
"Learning Rate": "学习率",

localizations/zh_TW.json

+1-1
Original file line numberDiff line numberDiff line change
@@ -677,7 +677,7 @@
677677
"Save Preview/Ckpt Every Epoch": "經過若干個 Epoch 保存預覽/檢查點",
678678
"Save Checkpoint Frequency": "保存檢查點頻率",
679679
"Save Preview(s) Frequency": "保存預覽頻率",
680-
"Batch": "批處理",
680+
"Batching": "批處理",
681681
"Batch Size": "批量大小",
682682
"Class Batch Size": "類批量大小",
683683
"Learning Rate": "學習率",

modules/shared.py

+1
Original file line numberDiff line numberDiff line change
@@ -98,6 +98,7 @@
9898
parser.add_argument('--hypernetwork-s3uri', default='', type=str, help='Hypernetwork S3Uri')
9999
parser.add_argument('--sd-models-s3uri', default='', type=str, help='SD Models S3Uri')
100100
parser.add_argument('--db-models-s3uri', default='', type=str, help='DB Models S3Uri')
101+
parser.add_argument('--lora-models-s3uri', default='', type=str, help='Lora Models S3Uri')
101102
parser.add_argument('--region-name', type=str, help='Region Name')
102103
parser.add_argument('--username', default='', type=str, help='Username')
103104
parser.add_argument('--api-endpoint', default='', type=str, help='API Endpoint')

webui.py

+77-106
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
from fastapi.exceptions import RequestValidationError
1313
from fastapi.responses import JSONResponse
1414

15-
from modules.call_queue import wrap_queued_call, queue_lock, wrap_gradio_gpu_call
15+
from modules.call_queue import wrap_queued_call, queue_lock
1616
from modules.paths import script_path
1717

1818
from modules import shared, devices, sd_samplers, upscaler, extensions, localization, ui_tempdir
@@ -43,11 +43,9 @@
4343
import json
4444
import uuid
4545
if not cmd_opts.api:
46-
from extensions.sd_dreambooth_extension.dreambooth.db_config import DreamboothConfig, sanitize_name
47-
from extensions.sd_dreambooth_extension.dreambooth.sd_to_diff import extract_checkpoint
48-
from extensions.sd_dreambooth_extension.dreambooth.dreambooth import start_training_from_config
49-
from extensions.sd_dreambooth_extension.dreambooth.dreambooth import performance_wizard, training_wizard
50-
from extensions.sd_dreambooth_extension.dreambooth.db_config import from_file
46+
from extensions.sd_dreambooth_extension.dreambooth.db_config import DreamboothConfig
47+
from extensions.sd_dreambooth_extension.scripts.dreambooth import start_training_from_config, create_model
48+
from extensions.sd_dreambooth_extension.scripts.dreambooth import performance_wizard, training_wizard
5149
from modules import paths
5250
import glob
5351

@@ -348,6 +346,7 @@ def train():
348346
hypernetwork_s3uri = cmd_opts.hypernetwork_s3uri
349347
sd_models_s3uri = cmd_opts.sd_models_s3uri
350348
db_models_s3uri = cmd_opts.db_models_s3uri
349+
lora_models_s3uri = cmd_opts.lora_models_s3uri
351350
api_endpoint = cmd_opts.api_endpoint
352351
username = cmd_opts.username
353352

@@ -564,16 +563,10 @@ def train():
564563
opts.data = default_options
565564
elif train_task == 'dreambooth':
566565
db_create_new_db_model = train_args['train_dreambooth_settings']['db_create_new_db_model']
567-
568-
db_lora_model_name = train_args['train_dreambooth_settings']['db_lora_model_name']
569-
db_lora_weight = train_args['train_dreambooth_settings']['db_lora_weight']
570-
db_lora_txt_weight = train_args['train_dreambooth_settings']['db_lora_txt_weight']
571-
db_train_imagic_only = train_args['train_dreambooth_settings']['db_train_imagic_only']
572-
db_use_subdir = train_args['train_dreambooth_settings']['db_use_subdir']
573-
db_custom_model_name = train_args['train_dreambooth_settings']['db_custom_model_name']
574566
db_train_wizard_person = train_args['train_dreambooth_settings']['db_train_wizard_person']
575567
db_train_wizard_object = train_args['train_dreambooth_settings']['db_train_wizard_object']
576568
db_performance_wizard = train_args['train_dreambooth_settings']['db_performance_wizard']
569+
db_use_txt2img = train_args['train_dreambooth_settings']['db_use_txt2img']
577570

578571
if db_create_new_db_model:
579572
db_new_model_name = train_args['train_dreambooth_settings']['db_new_model_name']
@@ -583,15 +576,20 @@ def train():
583576
db_new_model_url = train_args['train_dreambooth_settings']['db_new_model_url']
584577
db_new_model_token = train_args['train_dreambooth_settings']['db_new_model_token']
585578
db_new_model_extract_ema = train_args['train_dreambooth_settings']['db_new_model_extract_ema']
586-
db_model_name, _, db_revision, db_scheduler, db_src, db_has_ema, db_v2, db_resolution = extract_checkpoint(
587-
db_new_model_name,
588-
db_new_model_src,
589-
db_new_model_scheduler,
590-
db_create_from_hub,
591-
db_new_model_url,
592-
db_new_model_token,
593-
db_new_model_extract_ema
594-
)
579+
db_train_unfrozen = train_args['train_dreambooth_settings']['db_train_unfrozen']
580+
db_512_model = train_args['train_dreambooth_settings']['db_512_model']
581+
582+
db_model_name, db_model_path, db_revision, db_epochs, db_scheduler, db_src, _, _, _ = create_model(
583+
db_new_model_name,
584+
db_new_model_src,
585+
db_new_model_scheduler,
586+
db_create_from_hub,
587+
db_new_model_url,
588+
db_new_model_token,
589+
db_new_model_extract_ema,
590+
db_train_unfrozen,
591+
db_512_model,
592+
)
595593
dreambooth_config_id = cmd_opts.dreambooth_config_id
596594
try:
597595
with open(f'/opt/ml/input/data/config/{dreambooth_config_id}.json', 'r') as f:
@@ -605,91 +603,61 @@ def train():
605603
content = None
606604

607605
if content:
608-
config_dict = json.loads(content)
609-
print(db_model_name, db_revision, db_scheduler, db_src, db_has_ema, db_v2, db_resolution)
610-
611-
config_dict[0] = db_model_name
612-
config_dict[31] = db_revision
613-
config_dict[39] = db_scheduler
614-
config_dict[40] = db_src
615-
config_dict[14] = db_has_ema
616-
config_dict[49] = db_v2
617-
config_dict[30] = db_resolution
618-
619-
db_config = DreamboothConfig(*config_dict)
620-
621-
if db_train_wizard_person:
622-
_, \
623-
max_train_steps, \
624-
num_train_epochs, \
625-
c1_max_steps, \
626-
c1_num_class_images, \
627-
c2_max_steps, \
628-
c2_num_class_images, \
629-
c3_max_steps, \
630-
c3_num_class_images = training_wizard(db_config, True)
631-
632-
config_dict[22] = int(max_train_steps)
633-
config_dict[26] = int(num_train_epochs)
634-
config_dict[59] = c1_max_steps
635-
config_dict[61] = c1_num_class_images
636-
config_dict[77] = c2_max_steps
637-
config_dict[79] = c2_num_class_images
638-
config_dict[95] = c3_max_steps
639-
config_dict[97] = c3_num_class_images
640-
if db_train_wizard_object:
641-
_, \
642-
max_train_steps, \
643-
num_train_epochs, \
644-
c1_max_steps, \
645-
c1_num_class_images, \
646-
c2_max_steps, \
647-
c2_num_class_images, \
648-
c3_max_steps, \
649-
c3_num_class_images = training_wizard(db_config, False)
650-
651-
config_dict[22] = int(max_train_steps)
652-
config_dict[26] = int(num_train_epochs)
653-
config_dict[59] = c1_max_steps
654-
config_dict[61] = c1_num_class_images
655-
config_dict[77] = c2_max_steps
656-
config_dict[79] = c2_num_class_images
657-
config_dict[95] = c3_max_steps
658-
config_dict[97] = c3_num_class_images
606+
params_dict = json.loads(content)
607+
608+
params_dict['db_model_name'] = db_model_name
609+
params_dict['db_model_path'] = db_model_path
610+
params_dict['db_revision'] = db_revision
611+
params_dict['db_epochs'] = db_epochs
612+
params_dict['db_scheduler'] = db_scheduler
613+
params_dict['db_src'] = db_src
614+
615+
if db_train_wizard_person or db_train_wizard_object:
616+
db_num_train_epochs, \
617+
c1_num_class_images_per, \
618+
c2_num_class_images_per, \
619+
c3_num_class_images_per, \
620+
c4_num_class_images_per = training_wizard(db_config, db_train_wizard_person if db_train_wizard_person else db_train_wizard_object)
621+
622+
params_dict['db_num_train_epochs'] = db_num_train_epochs
623+
params_dict[59] = c1_num_class_images_per
624+
params_dict[61] = c2_num_class_images_per
625+
params_dict[77] = c3_num_class_images_per
626+
params_dict[79] = c4_num_class_images_per
659627
if db_performance_wizard:
660-
_, \
661628
attention, \
662629
gradient_checkpointing, \
630+
gradient_accumulation_steps, \
663631
mixed_precision, \
664-
not_cache_latents, \
632+
cache_latents, \
665633
sample_batch_size, \
666634
train_batch_size, \
667-
train_text_encoder, \
635+
stop_text_encoder, \
668636
use_8bit_adam, \
669-
use_cpu, \
670-
use_ema = performance_wizard()
671-
672-
config_dict[5] = attention
673-
config_dict[12] = gradient_checkpointing
674-
config_dict[23] = mixed_precision
675-
config_dict[25] = not_cache_latents
676-
config_dict[32] = sample_batch_size
677-
config_dict[42] = train_batch_size
678-
config_dict[43] = train_text_encoder
679-
config_dict[44] = use_8bit_adam
680-
config_dict[46] = use_cpu
681-
config_dict[47] = use_ema
637+
use_lora, \
638+
use_ema, \
639+
save_samples_every, \
640+
save_weights_every = performance_wizard()
641+
642+
params_dict['attention'] = attention
643+
params_dict['gradient_checkpointing'] = gradient_checkpointing
644+
params_dict['gradient_accumulation_steps'] = gradient_accumulation_steps
645+
params_dict['mixed_precision'] = mixed_precision
646+
params_dict['cache_latents'] = cache_latents
647+
params_dict['sample_batch_size'] = sample_batch_size
648+
params_dict['train_batch_size'] = train_batch_size
649+
params_dict['stop_text_encoder'] = stop_text_encoder
650+
params_dict['use_8bit_adam'] = use_8bit_adam
651+
params_dict['use_lora'] = use_lora
652+
params_dict['use_ema'] = use_ema
653+
params_dict['save_samples_every'] = save_samples_every
654+
params_dict['params_dict'] = save_weights_every
655+
656+
db_config = DreamboothConfig(db_model_name)
657+
db_config.load_params(params_dict)
682658
else:
683659
db_model_name = train_args['train_dreambooth_settings']['db_model_name']
684-
db_model_name = sanitize_name(db_model_name)
685-
db_models_path = cmd_opts.dreambooth_models_path
686-
if db_models_path == "" or db_models_path is None:
687-
db_models_path = os.path.join(shared.models_path, "dreambooth")
688-
working_dir = os.path.join(db_models_path, db_model_name, "working")
689-
config_dict = from_file(os.path.join(db_models_path, db_model_name))
690-
config_dict["pretrained_model_name_or_path"] = working_dir
691-
692-
db_config = DreamboothConfig(*config_dict)
660+
db_config = DreamboothConfig(db_model_name)
693661

694662
ckpt_dir = cmd_opts.ckpt_dir
695663
sd_models_path = os.path.join(shared.models_path, "Stable-diffusion")
@@ -699,12 +667,7 @@ def train():
699667
print(vars(db_config))
700668
start_training_from_config(
701669
db_config,
702-
db_lora_model_name if db_lora_model_name != '' else None,
703-
db_lora_weight,
704-
db_lora_txt_weight,
705-
db_train_imagic_only,
706-
db_use_subdir,
707-
db_custom_model_name
670+
db_use_txt2img,
708671
)
709672

710673
try:
@@ -715,12 +678,14 @@ def train():
715678
db_model_dir = os.path.dirname(cmd_dreambooth_models_path) if cmd_dreambooth_models_path else paths.models_path
716679
db_model_dir = os.path.join(db_model_dir, "dreambooth")
717680

681+
lora_models_path = os.path.join(shared.models_path, "Lora")
682+
718683
try:
719-
print('Uploading SD Models...')
684+
print('Uploading SD Models...')
720685
upload_s3files(
721686
sd_models_s3uri,
722687
os.path.join(sd_models_path, f'{sd_models_path}/{db_model_name}_*.yaml')
723-
)
688+
)
724689
upload_s3files(
725690
sd_models_s3uri,
726691
os.path.join(sd_models_path, f'{sd_models_path}/{db_model_name}_*.ckpt')
@@ -730,6 +695,12 @@ def train():
730695
f'{db_models_s3uri}{db_model_name}',
731696
os.path.join(db_model_dir, db_model_name)
732697
)
698+
if db_config.use_lora:
699+
print('Uploading Lora Models...')
700+
upload_s3files(
701+
lora_models_s3uri,
702+
os.path.join(lora_models_path, f'{lora_models_path}/{db_model_name}_*.pt')
703+
)
733704
except Exception as e:
734705
traceback.print_exc()
735706
print(e)

0 commit comments

Comments
 (0)