Skip to content

Commit c9b2470

Browse files
authored
Fix error with path initializing StylesDatabase (#8)
fix path NameError in styles.py
1 parent 8199e7f commit c9b2470

File tree

1 file changed

+45
-60
lines changed

1 file changed

+45
-60
lines changed

modules/styles.py

+45-60
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
from __future__ import annotations
22
from pathlib import Path
3-
from modules import errors
43
import csv
54
import os
65
import typing
@@ -45,13 +44,13 @@ def extract_style_text_from_prompt(style_text, prompt):
4544
if "{prompt}" in stripped_style_text:
4645
left, _, right = stripped_style_text.partition("{prompt}")
4746
if stripped_prompt.startswith(left) and stripped_prompt.endswith(right):
48-
prompt = stripped_prompt[len(left):len(stripped_prompt)-len(right)]
47+
prompt = stripped_prompt[len(left) : len(stripped_prompt) - len(right)]
4948
return True, prompt
5049
else:
5150
if stripped_prompt.endswith(stripped_style_text):
52-
prompt = stripped_prompt[:len(stripped_prompt)-len(stripped_style_text)]
51+
prompt = stripped_prompt[: len(stripped_prompt) - len(stripped_style_text)]
5352

54-
if prompt.endswith(', '):
53+
if prompt.endswith(", "):
5554
prompt = prompt[:-2]
5655

5756
return True, prompt
@@ -68,11 +67,15 @@ def extract_original_prompts(style: PromptStyle, prompt, negative_prompt):
6867
if not style.prompt and not style.negative_prompt:
6968
return False, prompt, negative_prompt
7069

71-
match_positive, extracted_positive = extract_style_text_from_prompt(style.prompt, prompt)
70+
match_positive, extracted_positive = extract_style_text_from_prompt(
71+
style.prompt, prompt
72+
)
7273
if not match_positive:
7374
return False, prompt, negative_prompt
7475

75-
match_negative, extracted_negative = extract_style_text_from_prompt(style.negative_prompt, negative_prompt)
76+
match_negative, extracted_negative = extract_style_text_from_prompt(
77+
style.negative_prompt, negative_prompt
78+
)
7679
if not match_negative:
7780
return False, prompt, negative_prompt
7881

@@ -89,20 +92,28 @@ def _format_divider(file: str) -> str:
8992
return divider
9093

9194

95+
def _expand_path(path: list[str | Path] | str | Path) -> list[str]:
96+
if isinstance(path, (str, Path)):
97+
return [str(Path(path))]
98+
99+
paths = []
100+
for pattern in path:
101+
folder, file = os.path.split(pattern)
102+
if "*" in file or "?" in file:
103+
matching_files = Path(folder).glob(file)
104+
[paths.append(str(file)) for file in matching_files]
105+
else:
106+
paths.append(str(Path(pattern)))
107+
108+
return paths
109+
110+
92111
class StyleDatabase:
93-
def __init__(self, paths: list[str | Path]):
112+
def __init__(self, path: str | Path):
94113
self.no_style = PromptStyle("None", "", "", None)
95114
self.styles = {}
96115
self.path = path
97116
self.prompt_fields = [field for field in PromptStyle._fields if field != "path"]
98-
99-
# The default path will be self.path with any wildcard removed. If it
100-
# doesn't exist, the reload() method updates this to be 'styles.csv'.
101-
self.default_file = "styles.csv"
102-
folder, file = os.path.split(self.path)
103-
filename, _, ext = file.partition('*')
104-
self.default_path = os.path.join(folder, filename + ext)
105-
106117
self.reload()
107118

108119
def reload(self):
@@ -112,50 +123,18 @@ def reload(self):
112123
"""
113124
self.styles.clear()
114125

115-
# scans for all styles files
116-
all_styles_files = []
117-
for pattern in self.paths:
118-
folder, file = os.path.split(pattern)
119-
if '*' in file or '?' in file:
120-
found_files = Path(folder).glob(file)
121-
[all_styles_files.append(file) for file in found_files]
122-
else:
123-
# if os.path.exists(pattern):
124-
all_styles_files.append(Path(pattern))
125-
126-
if "*" in filename:
127-
fileglob = filename.split("*")[0] + "*.csv"
128-
filelist = []
129-
for file in os.listdir(path):
130-
if fnmatch.fnmatch(file, fileglob):
131-
filelist.append(file)
132-
# Add a visible divider to the style list
133-
divider = _format_divider(file)
134-
self.styles[divider] = PromptStyle(
135-
f"{divider}", None, None, "do_not_save"
136-
)
137-
# Add styles from this CSV file
138-
self.load_from_csv(os.path.join(path, file))
139-
140-
# Ensure the default file is loaded, else its contents may be lost:
141-
if os.path.split(self.default_path)[1] not in filelist:
142-
self.default_path = os.path.join(path, self.default_file)
143-
divider = _format_divider(self.default_file)
144-
self.styles[divider] = PromptStyle(
145-
f"{divider}", None, None, "do_not_save"
146-
)
147-
self.load_from_csv(os.path.join(path, self.default_file))
148-
149-
if len(filelist) == 0:
150-
print(f"No styles found in {path} matching {fileglob}")
151-
self.load_from_csv(self.default_path)
152-
return
126+
# Expand the path to a list of full paths, expanding any wildcards. The
127+
# default path will be the first of these:
128+
style_files = _expand_path(self.path)
129+
self.default_path = style_files[0]
153130

154-
elif not os.path.exists(self.path):
155-
print(f"Style database not found: {self.path}")
156-
return
157-
else:
158-
self.load_from_csv(self.path)
131+
for file in style_files:
132+
_, filename = os.path.split(file)
133+
# Add a visible divider to the style list
134+
divider = _format_divider(filename)
135+
self.styles[divider] = PromptStyle(f"{divider}", None, None, "do_not_save")
136+
# Add styles from this CSV file
137+
self.load_from_csv(file)
159138

160139
def load_from_csv(self, path: str):
161140
with open(path, "r", encoding="utf-8-sig", newline="") as file:
@@ -173,7 +152,10 @@ def load_from_csv(self, path: str):
173152
)
174153

175154
def get_style_paths(self) -> set:
176-
"""Returns a set of all distinct paths of files that styles are loaded from."""
155+
"""
156+
Using the collection of styles in the StyleDatabase, returns a set of
157+
all distinct files that styles are loaded from.
158+
"""
177159
# Update any styles without a path to the default path
178160
for style in list(self.styles.values()):
179161
if not style.path:
@@ -224,14 +206,17 @@ def save_styles(self, path: str = None) -> None:
224206
writer = csv.DictWriter(file, fieldnames=self.prompt_fields)
225207
writer.writeheader()
226208
for style in (s for s in self.styles.values() if s.path == style_path):
227-
# Skip style list dividers, e.g. "STYLES.CSV"
209+
# Skip style list divider entries, e.g. "## STYLES.CSV ##"
228210
if style.name.lower().strip("# ") in csv_names:
229211
continue
230212
# Write style fields, ignoring the path field
231213
writer.writerow(
232214
{k: v for k, v in style._asdict().items() if k != "path"}
233215
)
234216

217+
# Reloading the styles to re-order the drop-down lists
218+
self.reload()
219+
235220
def extract_styles_from_prompt(self, prompt, negative_prompt):
236221
extracted = []
237222

0 commit comments

Comments
 (0)