Skip to content

Commit f41e8be

Browse files
committed
revise webui.py
1 parent 52c313e commit f41e8be

File tree

1 file changed

+21
-28
lines changed

1 file changed

+21
-28
lines changed

webui.py

+21-28
Original file line numberDiff line numberDiff line change
@@ -238,42 +238,29 @@ def webui():
238238
modules.sd_models.list_models()
239239
print('Restarting Gradio')
240240

241-
def upload_s3file(s3uri, file_path, file_name):
242-
s3_client = boto3.client('s3', region_name = cmd_opts.region_name)
243-
241+
def upload_s3files(s3uri, file_path_with_pattern):
244242
pos = s3uri.find('/', 5)
245243
bucket = s3uri[5 : pos]
246244
key = s3uri[pos + 1 : ]
247245

248-
binary = io.BytesIO(open(file_path, 'rb').read())
249-
key = key + file_name
246+
s3_resource = boto3.resource('s3')
247+
s3_bucket = s3_resource.Bucket(bucket)
248+
250249
try:
251-
s3_client.upload_fileobj(binary, bucket, key)
250+
for file_path in glob.glob(file_path_with_pattern):
251+
file_name = os.path.basename(file_path)
252+
__s3file = f'{key}/{file_name}'
253+
print(file_path, __s3file)
254+
s3_bucket.upload_file(file_path, __s3file)
252255
except ClientError as e:
253256
print(e)
254257
return False
255258
return True
256259

257-
def upload_s3files(s3uri, file_path_with_pattern):
258-
s3_client = boto3.client('s3', region_name = cmd_opts.region_name)
259-
260-
pos = s3uri.find('/', 5)
261-
bucket = s3uri[5 : pos]
262-
key = s3uri[pos + 1 : ]
263-
264-
for file_name in glob.glob(file_path_with_pattern):
265-
binary = io.BytesIO(open(file_name, 'rb').read())
266-
key = key + file_name
267-
try:
268-
s3_client.upload_fileobj(binary, bucket, key)
269-
except ClientError as e:
270-
print(e)
271-
return False
272-
return True
273-
274260
def upload_s3folder(s3uri, file_path):
275261
pos = s3uri.find('/', 5)
276262
bucket = s3uri[5 : pos]
263+
key = s3uri[pos + 1 : ]
277264

278265
s3_resource = boto3.resource('s3')
279266
s3_bucket = s3_resource.Bucket(bucket)
@@ -282,7 +269,7 @@ def upload_s3folder(s3uri, file_path):
282269
for path, _, files in os.walk(file_path):
283270
for file in files:
284271
dest_path = path.replace(file_path,"")
285-
__s3file = os.path.normpath(s3uri + dest_path + '/' + file)
272+
__s3file = f'{key}{dest_path}/{file}'
286273
__local_file = os.path.join(path, file)
287274
print(__local_file, __s3file)
288275
s3_bucket.upload_file(__local_file, __s3file)
@@ -400,7 +387,7 @@ def train():
400387
*txt2img_preview_params
401388
)
402389
try:
403-
upload_s3file(embeddings_s3uri, os.path.join(cmd_opts.embeddings_dir, '{0}.pt'.format(train_embedding_name)), '{0}.pt'.format(train_embedding_name))
390+
upload_s3files(embeddings_s3uri, os.path.join(cmd_opts.embeddings_dir, '{0}.pt'.format(train_embedding_name)))
404391
except Exception as e:
405392
traceback.print_exc()
406393
print(e)
@@ -666,13 +653,19 @@ def train():
666653
db_model_dir = os.path.join(db_model_dir, "dreambooth")
667654

668655
try:
656+
print('Uploading SD Models...')
657+
upload_s3files(
658+
sd_models_s3uri,
659+
os.path.join(sd_models_path, f'{sd_models_path}/{db_model_name}_*.yaml')
660+
)
669661
upload_s3files(
670662
sd_models_s3uri,
671-
os.path.join(sd_models_path, f'{sd_models_path}/{db_model_name}_*.pt')
663+
os.path.join(sd_models_path, f'{sd_models_path}/{db_model_name}_*.ckpt')
672664
)
665+
print('Uploading DB Models...')
673666
upload_s3folder(
674-
db_models_s3uri,
675-
db_model_dir
667+
f'{db_models_s3uri}/{db_model_name}',
668+
os.path.join(db_model_dir, db_model_name)
676669
)
677670
except Exception as e:
678671
traceback.print_exc()

0 commit comments

Comments
 (0)