fix: on resume, preserve progress bar and current step/epoch
feffy380 committed Oct 16, 2024
Parent ed3b2e1, commit 3cf8bc3.
Showing 1 changed file with 6 additions and 6 deletions.
train_network.py: 6 additions & 6 deletions
@@ -964,7 +964,7 @@ def load_model_hook(models, input_dir):
), f"max_train_steps should be greater than initial step / max_train_stepsは初期ステップより大きい必要があります: {args.max_train_steps} vs {initial_step}"

progress_bar = tqdm(
range(args.max_train_steps - initial_step), smoothing=0, disable=not accelerator.is_local_main_process, desc="steps"
initial=initial_step, total=args.max_train_steps, smoothing=0, disable=not accelerator.is_local_main_process, desc="steps"
)

epoch_to_start = 0
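Note on this hunk: dropping the range(...) iterable puts tqdm in manual mode, where the bar advances only on explicit update() calls, and initial=initial_step makes a resumed run show its true position (e.g. 300/1000) instead of counting from 0 out of the remaining steps. A minimal standalone sketch of that behavior, assuming the training loop calls update(1) once per optimizer step; the names and numbers here are illustrative, not the script's actual loop:

    from tqdm import tqdm

    initial_step = 300        # optimizer steps already completed, e.g. restored from a checkpoint
    max_train_steps = 1000

    # Manual mode: no iterable is passed, so the bar moves only on update().
    # The old bar wrapped range(max_train_steps - initial_step), so a resumed
    # run appeared to start at 0/700 rather than 300/1000.
    progress_bar = tqdm(initial=initial_step, total=max_train_steps, smoothing=0, desc="steps")
    for _ in range(initial_step, max_train_steps):
        # ... one optimizer step would run here ...
        progress_bar.update(1)
    progress_bar.close()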
@@ -976,7 +976,6 @@ def load_model_hook(models, input_dir):
f"initial_step is specified but not resuming. lr scheduler will be started from the beginning / initial_stepが指定されていますがresumeしていないため、lr schedulerは最初から始まります"
)
logger.info(f"skipping {initial_step} steps / {initial_step}ステップをスキップします")
initial_step *= args.gradient_accumulation_steps

# set epoch to start to make initial_step less than len(train_dataloader)
epoch_to_start = initial_step // math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
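Note on this hunk: the deleted multiplication converted initial_step from optimizer steps into dataloader (micro-batch) steps; without it, initial_step stays in optimizer steps throughout, matching the progress bar and global_step, since one epoch contains math.ceil(len(train_dataloader) / gradient_accumulation_steps) optimizer steps. A worked example of the epoch arithmetic, with assumed values:

    import math

    # Illustrative values only.
    batches_per_epoch = 100              # len(train_dataloader)
    gradient_accumulation_steps = 4
    initial_step = 60                    # optimizer steps restored on resume

    steps_per_epoch = math.ceil(batches_per_epoch / gradient_accumulation_steps)  # 25
    epoch_to_start = initial_step // steps_per_epoch                              # 60 // 25 = 2
    # Two whole epochs get skipped; 60 - 2 * 25 = 10 optimizer steps remain
    # to be skipped inside the first resumed epoch.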
@@ -1045,18 +1044,19 @@ def remove_model(old_ckpt_name):

         # For --sample_at_first
         optimizer_eval_fn()
-        self.sample_images(accelerator, args, 0, global_step, accelerator.device, vae, tokenizers, text_encoder, unet)
+        if args.sample_at_first and initial_step == 0:
+            self.sample_images(accelerator, args, None, global_step, accelerator.device, vae, tokenizers, text_encoder, unet)
         optimizer_train_fn()
         if len(accelerator.trackers) > 0:
             # log empty object to commit the sample images to wandb
             accelerator.log({}, step=0)
 
         # training loop
         if initial_step > 0:  # only if skip_until_initial_step is specified
-            for skip_epoch in range(epoch_to_start):  # skip epochs
-                logger.info(f"skipping epoch {skip_epoch+1} because initial_step (multiplied) is {initial_step}")
-                initial_step -= len(train_dataloader)
             global_step = initial_step
+            for skip_epoch in range(epoch_to_start):  # skip epochs
+                logger.info(f"skipping epoch {skip_epoch+1} because initial_step is {initial_step}")
+                initial_step -= math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
 
         # log device and dtype for each model
         logger.info(f"unet dtype: {unet_weight_dtype}, device: {unet.device}")