Commit
call network.save_weights before bundling embedding (PiSSA compatibility)
feffy380 committed Apr 13, 2024
1 parent 1991ed9 commit 3794486
Showing 1 changed file with 11 additions and 4 deletions.
15 changes: 11 additions & 4 deletions train_network.py
@@ -951,9 +951,18 @@ def save_model(ckpt_name, unwrapped_nw, steps, epoch_no, embeddings_map, force_s
 sai_metadata = train_util.get_sai_model_spec(None, args, self.is_sdxl, True, False)
 metadata_to_save.update(sai_metadata)
 
+unwrapped_nw.save_weights(ckpt_file, save_dtype, metadata_to_save)
+
 if len(embeddings_map.keys()) > 0:
+    # load saved state dict
+    if model_util.is_safetensors(ckpt_file):
+        from safetensors.torch import load_file
+
+        state_dict = load_file(ckpt_file)
+    else:
+        state_dict = torch.load(ckpt_file, map_location="cpu")
+
     # Bundle embeddings in LoRA state dict
-    state_dict = unwrapped_nw.state_dict()
     is_sdxl = len(next(iter(embeddings_map.values()))) == 2
     for emb_name in embeddings_map.keys():
         accelerator.print(f"Bundling embedding: {emb_name}")
@@ -977,7 +986,7 @@ def save_model(ckpt_name, unwrapped_nw, steps, epoch_no, embeddings_map, force_s
         v = v.detach().clone().to("cpu").to(save_dtype)
         state_dict[key] = v
 
-    if os.path.splitext(ckpt_file)[1] == ".safetensors":
+    if model_util.is_safetensors(ckpt_file):
         from safetensors.torch import save_file
 
         # Precalculate model hashes to save time on indexing
@@ -990,8 +999,6 @@ def save_model(ckpt_name, unwrapped_nw, steps, epoch_no, embeddings_map, force_s
         save_file(state_dict, ckpt_file, metadata_to_save)
     else:
         torch.save(state_dict, ckpt_file)
-else:
-    unwrapped_nw.save_weights(ckpt_file, save_dtype, metadata_to_save)
 
 if args.huggingface_repo_id is not None:
     huggingface_util.upload(args, ckpt_file, "/" + ckpt_name, force_sync_upload=force_sync_upload)
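
The point of the change is that bundling now starts from the file network.save_weights just wrote, instead of from unwrapped_nw.state_dict(). Presumably this matters for PiSSA because save_weights can transform the weights at save time (for example, converting a PiSSA-initialized adapter into a plain LoRA delta), and bundling from the raw in-memory tensors would lose that conversion. A minimal sketch of the resulting load-then-bundle flow; the bundle_embeddings helper and the "bundle_emb.{name}.{key}" naming scheme are illustrative assumptions, not the repository's exact code:

    import torch
    from safetensors.torch import load_file, save_file

    def bundle_embeddings(ckpt_file, embeddings_map, save_dtype=torch.float16):
        # Reload exactly what network.save_weights wrote, so any save-time
        # conversion (e.g. PiSSA -> plain LoRA delta) is preserved.
        state_dict = load_file(ckpt_file)

        # Merge each embedding's tensors into the LoRA state dict.
        for emb_name, emb_sd in embeddings_map.items():
            for k, v in emb_sd.items():
                key = f"bundle_emb.{emb_name}.{k}"  # hypothetical key scheme
                state_dict[key] = v.detach().clone().to("cpu").to(save_dtype)

        save_file(state_dict, ckpt_file)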
