Skip to content

Commit

Permalink
allow setting token strings manually for PTI
Browse files Browse the repository at this point in the history
  • Loading branch information
feffy380 committed Apr 13, 2024
1 parent 3794486 commit 53274a7
Showing 1 changed file with 12 additions and 2 deletions.
14 changes: 12 additions & 2 deletions train_network.py
Original file line number Diff line number Diff line change
Expand Up @@ -226,15 +226,19 @@ def train(self, args):
embeddings_map = {}
embedding_to_token_ids = {}
if len(args.embeddings) > 0:
for embeds_file in args.embeddings:
if not args.token_strings:
args.token_strings = [Path(embeds_file).stem for embeds_file in args.embeddings]
if len(args.token_strings) != len(args.embeddings):
raise ValueError(f"token_strings must have a name for each embedding. Got {len(args.token_strings)}/{len(args.embeddings)}")

for embeds_file, token_string in zip(args.embeddings, args.token_strings):
if model_util.is_safetensors(embeds_file):
from safetensors.torch import load_file

data = load_file(embeds_file)
else:
data = torch.load(embeds_file, map_location="cpu")

token_string = Path(embeds_file).stem
embeds_list, _shape, num_vectors_per_token = self.create_embedding_from_data(data, token_string)
if isinstance(embeds_list, dict) and "clip_l" in embeds_list and "clip_g" in embeds_list:
embeds_list = [embeds_list["clip_l"], embeds_list["clip_g"]]
Expand Down Expand Up @@ -1401,6 +1405,12 @@ def setup_parser() -> argparse.ArgumentParser:
nargs="*",
help="Embeddings files of Textual Inversion / Textual Inversionのembeddings",
)
parser.add_argument(
"--token_strings",
type=str,
nargs="*",
help="Names to use for each embedding instead of filename",
)
parser.add_argument("--continue_inversion", action="store_true", help="Continue the textual inversion when training the LoRA")
parser.add_argument("--embedding_lr", type=float, default=None, help="Learning rate used when continuing the textual inversion")
parser.add_argument(
Expand Down

0 comments on commit 53274a7

Please sign in to comment.