From ed3b2e13d4975a6e61db891a4f21772c96026686 Mon Sep 17 00:00:00 2001 From: feffy380 <114889020+feffy380@users.noreply.github.com> Date: Wed, 16 Oct 2024 11:58:48 +0200 Subject: [PATCH] feat: save on ctrl-c --- train_network.py | 20 ++++++++++++++++---- 1 file changed, 16 insertions(+), 4 deletions(-) diff --git a/train_network.py b/train_network.py index d5330aef4..0eaa7ac7e 100644 --- a/train_network.py +++ b/train_network.py @@ -2,6 +2,7 @@ import argparse import math import os +import signal import sys import random import time @@ -51,6 +52,7 @@ class NetworkTrainer: def __init__(self): self.vae_scale_factor = 0.18215 self.is_sdxl = False + self.interrupted = False # TODO 他のスクリプトと共通化する def generate_step_logs( @@ -1067,6 +1069,11 @@ def remove_model(old_ckpt_name): clean_memory_on_device(accelerator.device) + # signal handler: save on ctrl-c + def signal_handler(sig, frame): + self.interrupted = True + signal.signal(signal.SIGINT, signal_handler) + for epoch in range(epoch_to_start, num_train_epochs): accelerator.print(f"\nepoch {epoch+1}/{num_train_epochs}") current_epoch.value = epoch + 1 @@ -1209,9 +1216,10 @@ def remove_model(old_ckpt_name): keys_scaled, mean_norm, maximum_norm = None, None, None # Checks if the accelerator has performed an optimization step behind the scenes - if accelerator.sync_gradients: - progress_bar.update(1) - global_step += 1 + if accelerator.sync_gradients or self.interrupted: + if not self.interrupted: + progress_bar.update(1) + global_step += 1 optimizer_eval_fn() self.sample_images( @@ -1219,7 +1227,7 @@ def remove_model(old_ckpt_name): ) # 指定ステップごとにモデルを保存 - if args.save_every_n_steps is not None and global_step % args.save_every_n_steps == 0: + if (args.save_every_n_steps is not None and global_step % args.save_every_n_steps == 0) or self.interrupted: accelerator.wait_for_everyone() if accelerator.is_main_process: ckpt_name = train_util.get_step_ckpt_name(args, "." + args.save_model_as, global_step) @@ -1232,6 +1240,10 @@ def remove_model(old_ckpt_name): if remove_step_no is not None: remove_ckpt_name = train_util.get_step_ckpt_name(args, "." + args.save_model_as, remove_step_no) remove_model(remove_ckpt_name) + + if self.interrupted: + logger.warning("Received Ctrl-C. Saving model and exiting.") + return optimizer_train_fn() current_loss = loss.detach().item()