From fe874aa7bdb28b1f662697806fd47fdd84c4fe0b Mon Sep 17 00:00:00 2001
From: bmaltais
Date: Sun, 7 May 2023 16:14:19 -0400
Subject: [PATCH] Update run_cmd_training syntax

---
 library/common_gui.py | 104 +++++++++++++++++++++++++-----------------
 1 file changed, 61 insertions(+), 43 deletions(-)

diff --git a/library/common_gui.py b/library/common_gui.py
index bcb6be46a..e9ba737a7 100644
--- a/library/common_gui.py
+++ b/library/common_gui.py
@@ -821,48 +821,66 @@ def gradio_training(
 
 
 def run_cmd_training(**kwargs):
-    options = [
-        f' --learning_rate="{kwargs.get("learning_rate", "")}"'
-        if kwargs.get('learning_rate')
-        else '',
-        f' --lr_scheduler="{kwargs.get("lr_scheduler", "")}"'
-        if kwargs.get('lr_scheduler')
-        else '',
-        f' --lr_warmup_steps="{kwargs.get("lr_warmup_steps", "")}"'
-        if kwargs.get('lr_warmup_steps')
-        else '',
-        f' --train_batch_size="{kwargs.get("train_batch_size", "")}"'
-        if kwargs.get('train_batch_size')
-        else '',
-        f' --max_train_steps="{kwargs.get("max_train_steps", "")}"'
-        if kwargs.get('max_train_steps')
-        else '',
-        f' --save_every_n_epochs="{int(kwargs.get("save_every_n_epochs", 1))}"'
-        if int(kwargs.get('save_every_n_epochs'))
-        else '',
-        f' --mixed_precision="{kwargs.get("mixed_precision", "")}"'
-        if kwargs.get('mixed_precision')
-        else '',
-        f' --save_precision="{kwargs.get("save_precision", "")}"'
-        if kwargs.get('save_precision')
-        else '',
-        f' --seed="{kwargs.get("seed", "")}"'
-        if kwargs.get('seed') != ''
-        else '',
-        f' --caption_extension="{kwargs.get("caption_extension", "")}"'
-        if kwargs.get('caption_extension')
-        else '',
-        ' --cache_latents' if kwargs.get('cache_latents') else '',
-        ' --cache_latents_to_disk'
-        if kwargs.get('cache_latents_to_disk')
-        else '',
-        # ' --use_lion_optimizer' if kwargs.get('optimizer') == 'Lion' else '',
-        f' --optimizer_type="{kwargs.get("optimizer", "AdamW")}"',
-        f' --optimizer_args {kwargs.get("optimizer_args", "")}'
-        if not kwargs.get('optimizer_args') == ''
-        else '',
-    ]
-    run_cmd = ''.join(options)
+    run_cmd = ''
+
+    learning_rate = kwargs.get("learning_rate", "")
+    if learning_rate:
+        run_cmd += f' --learning_rate="{learning_rate}"'
+
+    lr_scheduler = kwargs.get("lr_scheduler", "")
+    if lr_scheduler:
+        run_cmd += f' --lr_scheduler="{lr_scheduler}"'
+
+    lr_warmup_steps = kwargs.get("lr_warmup_steps", "")
+    if lr_warmup_steps:
+        if lr_scheduler == 'constant':
+            print('Can\'t use LR warmup with LR Scheduler constant... ignoring...')
+        else:
+            run_cmd += f' --lr_warmup_steps="{lr_warmup_steps}"'
+
+    train_batch_size = kwargs.get("train_batch_size", "")
+    if train_batch_size:
+        run_cmd += f' --train_batch_size="{train_batch_size}"'
+
+    max_train_steps = kwargs.get("max_train_steps", "")
+    if max_train_steps:
+        run_cmd += f' --max_train_steps="{max_train_steps}"'
+
+    save_every_n_epochs = kwargs.get("save_every_n_epochs")
+    if save_every_n_epochs:
+        run_cmd += f' --save_every_n_epochs="{int(save_every_n_epochs)}"'
+
+    mixed_precision = kwargs.get("mixed_precision", "")
+    if mixed_precision:
+        run_cmd += f' --mixed_precision="{mixed_precision}"'
+
+    save_precision = kwargs.get("save_precision", "")
+    if save_precision:
+        run_cmd += f' --save_precision="{save_precision}"'
+
+    seed = kwargs.get("seed", "")
+    if seed != '':
+        run_cmd += f' --seed="{seed}"'
+
+    caption_extension = kwargs.get("caption_extension", "")
+    if caption_extension:
+        run_cmd += f' --caption_extension="{caption_extension}"'
+
+    cache_latents = kwargs.get('cache_latents')
+    if cache_latents:
+        run_cmd += ' --cache_latents'
+
+    cache_latents_to_disk = kwargs.get('cache_latents_to_disk')
+    if cache_latents_to_disk:
+        run_cmd += ' --cache_latents_to_disk'
+
+    optimizer_type = kwargs.get("optimizer", "AdamW")
+    run_cmd += f' --optimizer_type="{optimizer_type}"'
+
+    optimizer_args = kwargs.get("optimizer_args", "")
+    if optimizer_args != '':
+        run_cmd += f' --optimizer_args {optimizer_args}'
+
     return run_cmd
 
 
@@ -1084,7 +1102,7 @@ def run_cmd_advanced_training(**kwargs):
 
     max_train_epochs = kwargs.get("max_train_epochs", "")
     if max_train_epochs:
-        run_cmd += ' --max_train_epochs={max_train_epochs}'
+        run_cmd += f' --max_train_epochs={max_train_epochs}'
 
     max_data_loader_n_workers = kwargs.get("max_data_loader_n_workers", "")
    if max_data_loader_n_workers:
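
Illustrative usage (a minimal sketch, not part of the patch; the call and its argument values below are assumed, using the keyword names handled in the hunk above):

    # Hypothetical call; values are made up for illustration.
    cmd = run_cmd_training(
        learning_rate='1e-4',
        lr_scheduler='constant',
        lr_warmup_steps='100',   # skipped with a console warning because lr_scheduler is 'constant'
        train_batch_size='1',
        mixed_precision='fp16',
        seed='1234',
        cache_latents=True,
        optimizer='AdamW',
        optimizer_args='',       # empty string, so no --optimizer_args flag is emitted
    )
    # cmd is now:
    #  --learning_rate="1e-4" --lr_scheduler="constant" --train_batch_size="1"
    #  --mixed_precision="fp16" --seed="1234" --cache_latents --optimizer_type="AdamW"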