@@ -238,42 +238,29 @@ def webui():
238
238
modules .sd_models .list_models ()
239
239
print ('Restarting Gradio' )
240
240
241
- def upload_s3file (s3uri , file_path , file_name ):
242
- s3_client = boto3 .client ('s3' , region_name = cmd_opts .region_name )
243
-
241
+ def upload_s3files (s3uri , file_path_with_pattern ):
244
242
pos = s3uri .find ('/' , 5 )
245
243
bucket = s3uri [5 : pos ]
246
244
key = s3uri [pos + 1 : ]
247
245
248
- binary = io .BytesIO (open (file_path , 'rb' ).read ())
249
- key = key + file_name
246
+ s3_resource = boto3 .resource ('s3' )
247
+ s3_bucket = s3_resource .Bucket (bucket )
248
+
250
249
try :
251
- s3_client .upload_fileobj (binary , bucket , key )
250
+ for file_path in glob .glob (file_path_with_pattern ):
251
+ file_name = os .path .basename (file_path )
252
+ __s3file = f'{ key } /{ file_name } '
253
+ print (file_path , __s3file )
254
+ s3_bucket .upload_file (file_path , __s3file )
252
255
except ClientError as e :
253
256
print (e )
254
257
return False
255
258
return True
256
259
257
- def upload_s3files (s3uri , file_path_with_pattern ):
258
- s3_client = boto3 .client ('s3' , region_name = cmd_opts .region_name )
259
-
260
- pos = s3uri .find ('/' , 5 )
261
- bucket = s3uri [5 : pos ]
262
- key = s3uri [pos + 1 : ]
263
-
264
- for file_name in glob .glob (file_path_with_pattern ):
265
- binary = io .BytesIO (open (file_name , 'rb' ).read ())
266
- key = key + file_name
267
- try :
268
- s3_client .upload_fileobj (binary , bucket , key )
269
- except ClientError as e :
270
- print (e )
271
- return False
272
- return True
273
-
274
260
def upload_s3folder (s3uri , file_path ):
275
261
pos = s3uri .find ('/' , 5 )
276
262
bucket = s3uri [5 : pos ]
263
+ key = s3uri [pos + 1 : ]
277
264
278
265
s3_resource = boto3 .resource ('s3' )
279
266
s3_bucket = s3_resource .Bucket (bucket )
@@ -282,7 +269,7 @@ def upload_s3folder(s3uri, file_path):
282
269
for path , _ , files in os .walk (file_path ):
283
270
for file in files :
284
271
dest_path = path .replace (file_path ,"" )
285
- __s3file = os . path . normpath ( s3uri + dest_path + '/' + file )
272
+ __s3file = f' { key } { dest_path } / { file } '
286
273
__local_file = os .path .join (path , file )
287
274
print (__local_file , __s3file )
288
275
s3_bucket .upload_file (__local_file , __s3file )
@@ -400,7 +387,7 @@ def train():
400
387
* txt2img_preview_params
401
388
)
402
389
try :
403
- upload_s3file (embeddings_s3uri , os .path .join (cmd_opts .embeddings_dir , '{0}.pt' .format (train_embedding_name )), '{0}.pt' . format ( train_embedding_name ))
390
+ upload_s3files (embeddings_s3uri , os .path .join (cmd_opts .embeddings_dir , '{0}.pt' .format (train_embedding_name )))
404
391
except Exception as e :
405
392
traceback .print_exc ()
406
393
print (e )
@@ -666,13 +653,19 @@ def train():
666
653
db_model_dir = os .path .join (db_model_dir , "dreambooth" )
667
654
668
655
try :
656
+ print ('Uploading SD Models...' )
657
+ upload_s3files (
658
+ sd_models_s3uri ,
659
+ os .path .join (sd_models_path , f'{ sd_models_path } /{ db_model_name } _*.yaml' )
660
+ )
669
661
upload_s3files (
670
662
sd_models_s3uri ,
671
- os .path .join (sd_models_path , f'{ sd_models_path } /{ db_model_name } _*.pt ' )
663
+ os .path .join (sd_models_path , f'{ sd_models_path } /{ db_model_name } _*.ckpt ' )
672
664
)
665
+ print ('Uploading DB Models...' )
673
666
upload_s3folder (
674
- db_models_s3uri ,
675
- db_model_dir
667
+ f' { db_models_s3uri } / { db_model_name } ' ,
668
+ os . path . join ( db_model_dir , db_model_name )
676
669
)
677
670
except Exception as e :
678
671
traceback .print_exc ()
0 commit comments