Skip to content

Commit 99112d2

Browse files
committed
add dreambooth support
1 parent 665db1c commit 99112d2

File tree

5 files changed

+247
-39
lines changed

5 files changed

+247
-39
lines changed

modules/sd_models.py

+4-1
Original file line numberDiff line numberDiff line change
@@ -99,7 +99,10 @@ def modeltitle(path, shorthash):
9999
checkpoints_list[title] = CheckpointInfo(filename, title, h, short_model_name, config)
100100

101101
if shared.cmd_opts.pureui:
102-
response = requests.get(url=f'{api_endpoint}/sd/models')
102+
params = {
103+
'endpoint_name': shared.opts.sagemaker_endpoint
104+
}
105+
response = requests.get(url=f'{api_endpoint}/sd/models', params=params)
103106
if response.status_code == 200:
104107
model_list = json.loads(response.text)
105108

modules/shared.py

+3
Original file line numberDiff line numberDiff line change
@@ -96,9 +96,12 @@
9696
parser.add_argument("--train-args", type=str, help='Train args', default='')
9797
parser.add_argument('--embeddings-s3uri', default='', type=str, help='Embedding S3Uri')
9898
parser.add_argument('--hypernetwork-s3uri', default='', type=str, help='Hypernetwork S3Uri')
99+
parser.add_argument('--sd-models-s3uri', default='', type=str, help='SD Models S3Uri')
100+
parser.add_argument('--db-models-s3uri', default='', type=str, help='DB Models S3Uri')
99101
parser.add_argument('--region-name', type=str, help='Region Name')
100102
parser.add_argument('--username', default='', type=str, help='Username')
101103
parser.add_argument('--api-endpoint', default='', type=str, help='API Endpoint')
104+
parser.add_argument('--dreambooth-config-id', default='', type=str, help='Dreambooth config ID')
102105

103106
script_loading.preload_extensions(extensions.extensions_dir, parser)
104107
script_loading.preload_extensions(extensions.extensions_builtin_dir, parser)

modules/ui.py

+5-5
Original file line numberDiff line numberDiff line change
@@ -45,16 +45,16 @@
4545
training_instance_types = [
4646
'ml.p2.xlarge',
4747
'ml.p2.8xlarge',
48-
'ml.p2.16xlarge',
48+
'ml.p2.16xlarge',
4949
'ml.p3.2xlarge',
5050
'ml.p3.8xlarge',
5151
'ml.p3.16xlarge',
5252
'ml.g4dn.xlarge',
5353
'ml.g4dn.2xlarge',
54-
'ml.g4dn.4xlarge',
55-
'ml.g4dn.8xlarge',
56-
'ml.g4dn.12xlarge',
57-
'ml.g4dn.16xlarge'
54+
'ml.g4dn.4xlarge',
55+
'ml.g4dn.8xlarge',
56+
'ml.g4dn.12xlarge',
57+
'ml.g4dn.16xlarge'
5858
]
5959

6060
# this is a fix for Windows users. Without it, javascript files will be served with text/html content-type and the browser will not show any UI

requirements.txt

-33
This file was deleted.

webui.py

+235
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,13 @@
4242
import io
4343
import json
4444
import uuid
45+
from extensions.sd_dreambooth_extension.dreambooth.db_config import DreamboothConfig, sanitize_name
46+
from extensions.sd_dreambooth_extension.dreambooth.sd_to_diff import extract_checkpoint
47+
from extensions.sd_dreambooth_extension.dreambooth.dreambooth import start_training_from_config
48+
from extensions.sd_dreambooth_extension.dreambooth.dreambooth import performance_wizard, training_wizard
49+
from extensions.sd_dreambooth_extension.dreambooth.db_config import from_file
50+
from modules import paths
51+
import glob
4552

4653
if cmd_opts.server_name:
4754
server_name = cmd_opts.server_name
@@ -134,6 +141,31 @@ def api_only():
134141
app.add_middleware(GZipMiddleware, minimum_size=1000)
135142
api = create_api(app)
136143

144+
ckpt_dir = cmd_opts.ckpt_dir
145+
sd_models_path = os.path.join(shared.models_path, "Stable-diffusion")
146+
if ckpt_dir is not None:
147+
sd_models_path = ckpt_dir
148+
149+
if 'endpoint_name' in os.environ:
150+
items = []
151+
api_endpoint = os.environ['api_endpoint']
152+
endpoint_name = os.environ['endpoint_name']
153+
for file in os.listdir(sd_models_path):
154+
if os.path.isfile(os.path.join(sd_models_path, file)) and file.endswith('.ckpt'):
155+
hash = modules.sd_models.model_hash(os.path.join(sd_models_path, file))
156+
item = {}
157+
item['model_name'] = file
158+
item['config'] = '/opt/ml/code/stable-diffusion-webui/repositories/stable-diffusion/configs/stable-diffusion/v1-inference.yaml'
159+
item['filename'] = '/opt/ml/code/stable-diffusion-webui/models/Stable-diffusion/{0}'.format(file)
160+
item['hash'] = hash
161+
item['title'] = '{0} [{1}]'.format(file, hash)
162+
item['endpoint_name'] = endpoint_name
163+
items.append(item)
164+
inputs = {
165+
'items': items
166+
}
167+
response = requests.post(url=f'{api_endpoint}/sd/models', json=inputs)
168+
137169
modules.script_callbacks.app_started_callback(None, app)
138170

139171
@app.exception_handler(RequestValidationError)
@@ -221,6 +253,41 @@ def upload_s3file(s3uri, file_path, file_name):
221253
return False
222254
return True
223255

256+
def upload_s3files(s3uri, file_path_with_pattern):
257+
s3_client = boto3.client('s3', region_name = cmd_opts.region_name)
258+
259+
pos = s3uri.find('/', 5)
260+
bucket = s3uri[5 : pos]
261+
key = s3uri[pos + 1 : ]
262+
263+
for file_name in glob.glob(file_path_with_pattern):
264+
binary = io.BytesIO(open(file_name, 'rb').read())
265+
key = key + file_name
266+
try:
267+
s3_client.upload_fileobj(binary, bucket, key)
268+
except ClientError as e:
269+
print(e)
270+
return False
271+
return True
272+
273+
def upload_s3folder(s3uri, file_path):
274+
pos = s3uri.find('/', 5)
275+
bucket = s3uri[5 : pos]
276+
277+
s3_resource = boto3.resource('s3')
278+
s3_bucket = s3_resource.Bucket(bucket)
279+
280+
try:
281+
for path, _, files in os.walk(file_path):
282+
for file in files:
283+
dest_path = path.replace(file_path,"")
284+
__s3file = os.path.normpath(s3uri + dest_path + '/' + file)
285+
__local_file = os.path.join(path, file)
286+
print(__local_file, __s3file)
287+
s3_bucket.upload_file(__local_file, __s3file)
288+
except Exception as e:
289+
print(e)
290+
224291
def train():
225292
initialize()
226293

@@ -229,6 +296,8 @@ def train():
229296

230297
embeddings_s3uri = cmd_opts.embeddings_s3uri
231298
hypernetwork_s3uri = cmd_opts.hypernetwork_s3uri
299+
sd_models_s3uri = cmd_opts.sd_models_s3uri
300+
db_models_s3uri = cmd_opts.db_models_s3uri
232301
api_endpoint = cmd_opts.api_endpoint
233302
username = cmd_opts.username
234303

@@ -441,6 +510,172 @@ def train():
441510
traceback.print_exc()
442511
print(e)
443512
opts.data = default_options
513+
elif train_task == 'dreambooth':
514+
db_create_new_db_model = train_args['train_dreambooth_settings']['db_create_new_db_model']
515+
516+
db_lora_model_name = train_args['train_dreambooth_settings']['db_lora_model_name']
517+
db_lora_weight = train_args['train_dreambooth_settings']['db_lora_weight']
518+
db_lora_txt_weight = train_args['train_dreambooth_settings']['db_lora_txt_weight']
519+
db_train_imagic_only = train_args['train_dreambooth_settings']['db_train_imagic_only']
520+
db_use_subdir = train_args['train_dreambooth_settings']['db_use_subdir']
521+
db_custom_model_name = train_args['train_dreambooth_settings']['db_custom_model_name']
522+
db_train_wizard_person = train_args['train_dreambooth_settings']['db_train_wizard_person']
523+
db_train_wizard_object = train_args['train_dreambooth_settings']['db_train_wizard_object']
524+
db_performance_wizard = train_args['train_dreambooth_settings']['db_performance_wizard']
525+
526+
if db_create_new_db_model:
527+
db_new_model_name = train_args['train_dreambooth_settings']['db_new_model_name']
528+
db_new_model_src = train_args['train_dreambooth_settings']['db_new_model_src']
529+
db_new_model_scheduler = train_args['train_dreambooth_settings']['db_new_model_scheduler']
530+
db_create_from_hub = train_args['train_dreambooth_settings']['db_create_from_hub']
531+
db_new_model_url = train_args['train_dreambooth_settings']['db_new_model_url']
532+
db_new_model_token = train_args['train_dreambooth_settings']['db_new_model_token']
533+
db_new_model_extract_ema = train_args['train_dreambooth_settings']['db_new_model_extract_ema']
534+
db_model_name, _, db_revision, db_scheduler, db_src, db_has_ema, db_v2, db_resolution = extract_checkpoint(
535+
db_new_model_name,
536+
db_new_model_src,
537+
db_new_model_scheduler,
538+
db_create_from_hub,
539+
db_new_model_url,
540+
db_new_model_token,
541+
db_new_model_extract_ema
542+
)
543+
dreambooth_config_id = cmd_opts.dreambooth_config_id
544+
try:
545+
with open(f'/opt/ml/input/data/config/{dreambooth_config_id}.json', 'r') as f:
546+
content = f.read()
547+
except Exception:
548+
params = {'module': 'dreambooth_config', 'dreambooth_config_id': dreambooth_config_id}
549+
response = requests.get(url=f'{api_endpoint}/sd/models', params=params)
550+
if response.status_code == 200:
551+
content = response.text
552+
else:
553+
content = None
554+
555+
if content:
556+
config_dict = json.loads(content)
557+
print(db_model_name, db_revision, db_scheduler, db_src, db_has_ema, db_v2, db_resolution)
558+
559+
config_dict[0] = db_model_name
560+
config_dict[31] = db_revision
561+
config_dict[39] = db_scheduler
562+
config_dict[40] = db_src
563+
config_dict[14] = db_has_ema
564+
config_dict[49] = db_v2
565+
config_dict[30] = db_resolution
566+
567+
db_config = DreamboothConfig(*config_dict)
568+
569+
if db_train_wizard_person:
570+
_, \
571+
max_train_steps, \
572+
num_train_epochs, \
573+
c1_max_steps, \
574+
c1_num_class_images, \
575+
c2_max_steps, \
576+
c2_num_class_images, \
577+
c3_max_steps, \
578+
c3_num_class_images = training_wizard(db_config, True)
579+
580+
config_dict[22] = int(max_train_steps)
581+
config_dict[26] = int(num_train_epochs)
582+
config_dict[59] = c1_max_steps
583+
config_dict[61] = c1_num_class_images
584+
config_dict[77] = c2_max_steps
585+
config_dict[79] = c2_num_class_images
586+
config_dict[95] = c3_max_steps
587+
config_dict[97] = c3_num_class_images
588+
if db_train_wizard_object:
589+
_, \
590+
max_train_steps, \
591+
num_train_epochs, \
592+
c1_max_steps, \
593+
c1_num_class_images, \
594+
c2_max_steps, \
595+
c2_num_class_images, \
596+
c3_max_steps, \
597+
c3_num_class_images = training_wizard(db_config, False)
598+
599+
config_dict[22] = int(max_train_steps)
600+
config_dict[26] = int(num_train_epochs)
601+
config_dict[59] = c1_max_steps
602+
config_dict[61] = c1_num_class_images
603+
config_dict[77] = c2_max_steps
604+
config_dict[79] = c2_num_class_images
605+
config_dict[95] = c3_max_steps
606+
config_dict[97] = c3_num_class_images
607+
if db_performance_wizard:
608+
_, \
609+
attention, \
610+
gradient_checkpointing, \
611+
mixed_precision, \
612+
not_cache_latents, \
613+
sample_batch_size, \
614+
train_batch_size, \
615+
train_text_encoder, \
616+
use_8bit_adam, \
617+
use_cpu, \
618+
use_ema = performance_wizard()
619+
620+
config_dict[5] = attention
621+
config_dict[12] = gradient_checkpointing
622+
config_dict[23] = mixed_precision
623+
config_dict[25] = not_cache_latents
624+
config_dict[32] = sample_batch_size
625+
config_dict[42] = train_batch_size
626+
config_dict[43] = train_text_encoder
627+
config_dict[44] = use_8bit_adam
628+
config_dict[46] = use_cpu
629+
config_dict[47] = use_ema
630+
else:
631+
db_model_name = train_args['train_dreambooth_settings']['db_model_name']
632+
db_model_name = sanitize_name(db_model_name)
633+
db_models_path = cmd_opts.dreambooth_models_path
634+
if db_models_path == "" or db_models_path is None:
635+
db_models_path = os.path.join(shared.models_path, "dreambooth")
636+
working_dir = os.path.join(db_models_path, db_model_name, "working")
637+
config_dict = from_file(os.path.join(db_models_path, db_model_name))
638+
config_dict["pretrained_model_name_or_path"] = working_dir
639+
640+
db_config = DreamboothConfig(*config_dict)
641+
642+
ckpt_dir = cmd_opts.ckpt_dir
643+
sd_models_path = os.path.join(shared.models_path, "Stable-diffusion")
644+
if ckpt_dir is not None:
645+
sd_models_path = ckpt_dir
646+
647+
print(vars(db_config))
648+
start_training_from_config(
649+
db_config,
650+
db_lora_model_name if db_lora_model_name != '' else None,
651+
db_lora_weight,
652+
db_lora_txt_weight,
653+
db_train_imagic_only,
654+
db_use_subdir,
655+
db_custom_model_name
656+
)
657+
658+
try:
659+
cmd_dreambooth_models_path = cmd_opts.dreambooth_models_path
660+
except:
661+
cmd_dreambooth_models_path = None
662+
663+
db_model_dir = os.path.dirname(cmd_dreambooth_models_path) if cmd_dreambooth_models_path else paths.models_path
664+
db_model_dir = os.path.join(db_model_dir, "dreambooth")
665+
666+
try:
667+
upload_s3files(
668+
sd_models_s3uri,
669+
os.path.join(sd_models_path, f'{sd_models_path}/{db_model_name}_*.pt')
670+
)
671+
upload_s3folder(
672+
db_models_s3uri,
673+
db_model_dir
674+
)
675+
except Exception as e:
676+
traceback.print_exc()
677+
print(e)
678+
opts.data = default_options
444679
else:
445680
print('Incorrect training task')
446681
exit(-1)

0 commit comments

Comments
 (0)