Skip to content

Commit

Permalink
feat: save on ctrl-c
Browse files Browse the repository at this point in the history
  • Loading branch information
feffy380 committed Oct 16, 2024
1 parent c6b63b7 commit ed3b2e1
Showing 1 changed file with 16 additions and 4 deletions.
20 changes: 16 additions & 4 deletions train_network.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import argparse
import math
import os
import signal
import sys
import random
import time
Expand Down Expand Up @@ -51,6 +52,7 @@ class NetworkTrainer:
# Initialise the trainer's default per-instance settings.
def __init__(self):
# Scale factor stored for the VAE; presumably the Stable Diffusion latent
# scaling constant -- TODO confirm against the code that reads it.
self.vae_scale_factor = 0.18215
# Defaults to the non-SDXL code path.
self.is_sdxl = False
# Set to True by the SIGINT handler installed during training; the step
# loop checks it to save a checkpoint and return early.
self.interrupted = False

# TODO: share this with the other scripts
def generate_step_logs(
Expand Down Expand Up @@ -1067,6 +1069,11 @@ def remove_model(old_ckpt_name):

clean_memory_on_device(accelerator.device)

# Signal handler: request a checkpoint save and clean exit on Ctrl-C.
def signal_handler(sig, frame):
    """Record that SIGINT was received so the training loop can save and stop.

    Args:
        sig: Number of the signal being handled.
        frame: Stack frame that was executing when the signal arrived (unused).
    """
    self.interrupted = True
    # Only the first Ctrl-C is deferred. Restoring Python's default handler
    # means a second Ctrl-C raises KeyboardInterrupt, so the process can still
    # be aborted if it stalls before the flag is checked or while saving.
    signal.signal(sig, signal.default_int_handler)
signal.signal(signal.SIGINT, signal_handler)

for epoch in range(epoch_to_start, num_train_epochs):
accelerator.print(f"\nepoch {epoch+1}/{num_train_epochs}")
current_epoch.value = epoch + 1
Expand Down Expand Up @@ -1209,17 +1216,18 @@ def remove_model(old_ckpt_name):
keys_scaled, mean_norm, maximum_norm = None, None, None

# Checks if the accelerator has performed an optimization step behind the scenes
if accelerator.sync_gradients:
progress_bar.update(1)
global_step += 1
if accelerator.sync_gradients or self.interrupted:
if not self.interrupted:
progress_bar.update(1)
global_step += 1

optimizer_eval_fn()
self.sample_images(
accelerator, args, None, global_step, accelerator.device, vae, tokenizers, text_encoder, unet
)

# 指定ステップごとにモデルを保存
if args.save_every_n_steps is not None and global_step % args.save_every_n_steps == 0:
if (args.save_every_n_steps is not None and global_step % args.save_every_n_steps == 0) or self.interrupted:
accelerator.wait_for_everyone()
if accelerator.is_main_process:
ckpt_name = train_util.get_step_ckpt_name(args, "." + args.save_model_as, global_step)
Expand All @@ -1232,6 +1240,10 @@ def remove_model(old_ckpt_name):
if remove_step_no is not None:
remove_ckpt_name = train_util.get_step_ckpt_name(args, "." + args.save_model_as, remove_step_no)
remove_model(remove_ckpt_name)

if self.interrupted:
logger.warning("Received Ctrl-C. Saving model and exiting.")
return
optimizer_train_fn()

current_loss = loss.detach().item()
Expand Down

0 comments on commit ed3b2e1

Please sign in to comment.