Commit 06ae2ff 1 parent 216a63c commit 06ae2ff Copy full SHA for 06ae2ff
File tree 1 file changed +4
-3
lines changed
1 file changed +4
-3
lines changed Original file line number Diff line number Diff line change @@ -746,15 +746,16 @@ def train():
746
746
#automatic tar latest checkpoint and upload to s3 by zheng on 2023.03.22
747
747
os .makedirs (os .path .dirname ("/opt/ml/model/" ), exist_ok = True )
748
748
train_steps = int (db_config .revision )
749
- f1 = os .path .join (sd_models_path , db_model_name , f'{ db_model_name } _{ train_steps } .yaml' )
749
+ model_file_basename = f'{ db_model_name } _{ train_steps } _lora' if db_config .use_lora else f'{ db_model_name } _{ train_steps } '
750
+ f1 = os .path .join (sd_models_path , db_model_name , f'{ model_file_basename } .yaml' )
750
751
if os .path .exists (f1 ):
751
752
shutil .copy (f1 ,"/opt/ml/model/" )
752
753
if db_save_safetensors :
753
- f2 = os .path .join (sd_models_path , db_model_name , f'{ db_model_name } _ { train_steps } .safetensors' )
754
+ f2 = os .path .join (sd_models_path , db_model_name , f'{ model_file_basename } .safetensors' )
754
755
if os .path .exists (f2 ):
755
756
shutil .copy (f2 ,"/opt/ml/model/" )
756
757
else :
757
- f2 = os .path .join (sd_models_path , db_model_name , f'{ db_model_name } _ { train_steps } .ckpt' )
758
+ f2 = os .path .join (sd_models_path , db_model_name , f'{ model_file_basename } .ckpt' )
758
759
if os .path .exists (f2 ):
759
760
shutil .copy (f2 ,"/opt/ml/model/" )
760
761
except Exception as e :
You can’t perform that action at this time.
0 commit comments