1
1
from __future__ import annotations
2
2
from pathlib import Path
3
- from modules import errors
4
3
import csv
5
4
import os
6
5
import typing
@@ -45,13 +44,13 @@ def extract_style_text_from_prompt(style_text, prompt):
45
44
if "{prompt}" in stripped_style_text :
46
45
left , _ , right = stripped_style_text .partition ("{prompt}" )
47
46
if stripped_prompt .startswith (left ) and stripped_prompt .endswith (right ):
48
- prompt = stripped_prompt [len (left ): len (stripped_prompt )- len (right )]
47
+ prompt = stripped_prompt [len (left ) : len (stripped_prompt ) - len (right )]
49
48
return True , prompt
50
49
else :
51
50
if stripped_prompt .endswith (stripped_style_text ):
52
- prompt = stripped_prompt [:len (stripped_prompt )- len (stripped_style_text )]
51
+ prompt = stripped_prompt [: len (stripped_prompt ) - len (stripped_style_text )]
53
52
54
- if prompt .endswith (', ' ):
53
+ if prompt .endswith (", " ):
55
54
prompt = prompt [:- 2 ]
56
55
57
56
return True , prompt
@@ -68,11 +67,15 @@ def extract_original_prompts(style: PromptStyle, prompt, negative_prompt):
68
67
if not style .prompt and not style .negative_prompt :
69
68
return False , prompt , negative_prompt
70
69
71
- match_positive , extracted_positive = extract_style_text_from_prompt (style .prompt , prompt )
70
+ match_positive , extracted_positive = extract_style_text_from_prompt (
71
+ style .prompt , prompt
72
+ )
72
73
if not match_positive :
73
74
return False , prompt , negative_prompt
74
75
75
- match_negative , extracted_negative = extract_style_text_from_prompt (style .negative_prompt , negative_prompt )
76
+ match_negative , extracted_negative = extract_style_text_from_prompt (
77
+ style .negative_prompt , negative_prompt
78
+ )
76
79
if not match_negative :
77
80
return False , prompt , negative_prompt
78
81
@@ -89,20 +92,28 @@ def _format_divider(file: str) -> str:
89
92
return divider
90
93
91
94
95
+ def _expand_path (path : list [str | Path ] | str | Path ) -> list [str ]:
96
+ if isinstance (path , (str , Path )):
97
+ return [str (Path (path ))]
98
+
99
+ paths = []
100
+ for pattern in path :
101
+ folder , file = os .path .split (pattern )
102
+ if "*" in file or "?" in file :
103
+ matching_files = Path (folder ).glob (file )
104
+ [paths .append (str (file )) for file in matching_files ]
105
+ else :
106
+ paths .append (str (Path (pattern )))
107
+
108
+ return paths
109
+
110
+
92
111
class StyleDatabase :
93
- def __init__ (self , paths : list [ str | Path ] ):
112
+ def __init__ (self , path : str | Path ):
94
113
self .no_style = PromptStyle ("None" , "" , "" , None )
95
114
self .styles = {}
96
115
self .path = path
97
116
self .prompt_fields = [field for field in PromptStyle ._fields if field != "path" ]
98
-
99
- # The default path will be self.path with any wildcard removed. If it
100
- # doesn't exist, the reload() method updates this to be 'styles.csv'.
101
- self .default_file = "styles.csv"
102
- folder , file = os .path .split (self .path )
103
- filename , _ , ext = file .partition ('*' )
104
- self .default_path = os .path .join (folder , filename + ext )
105
-
106
117
self .reload ()
107
118
108
119
def reload (self ):
@@ -112,50 +123,18 @@ def reload(self):
112
123
"""
113
124
self .styles .clear ()
114
125
115
- # scans for all styles files
116
- all_styles_files = []
117
- for pattern in self .paths :
118
- folder , file = os .path .split (pattern )
119
- if '*' in file or '?' in file :
120
- found_files = Path (folder ).glob (file )
121
- [all_styles_files .append (file ) for file in found_files ]
122
- else :
123
- # if os.path.exists(pattern):
124
- all_styles_files .append (Path (pattern ))
125
-
126
- if "*" in filename :
127
- fileglob = filename .split ("*" )[0 ] + "*.csv"
128
- filelist = []
129
- for file in os .listdir (path ):
130
- if fnmatch .fnmatch (file , fileglob ):
131
- filelist .append (file )
132
- # Add a visible divider to the style list
133
- divider = _format_divider (file )
134
- self .styles [divider ] = PromptStyle (
135
- f"{ divider } " , None , None , "do_not_save"
136
- )
137
- # Add styles from this CSV file
138
- self .load_from_csv (os .path .join (path , file ))
139
-
140
- # Ensure the default file is loaded, else its contents may be lost:
141
- if os .path .split (self .default_path )[1 ] not in filelist :
142
- self .default_path = os .path .join (path , self .default_file )
143
- divider = _format_divider (self .default_file )
144
- self .styles [divider ] = PromptStyle (
145
- f"{ divider } " , None , None , "do_not_save"
146
- )
147
- self .load_from_csv (os .path .join (path , self .default_file ))
148
-
149
- if len (filelist ) == 0 :
150
- print (f"No styles found in { path } matching { fileglob } " )
151
- self .load_from_csv (self .default_path )
152
- return
126
+ # Expand the path to a list of full paths, expanding any wildcards. The
127
+ # default path will be the first of these:
128
+ style_files = _expand_path (self .path )
129
+ self .default_path = style_files [0 ]
153
130
154
- elif not os .path .exists (self .path ):
155
- print (f"Style database not found: { self .path } " )
156
- return
157
- else :
158
- self .load_from_csv (self .path )
131
+ for file in style_files :
132
+ _ , filename = os .path .split (file )
133
+ # Add a visible divider to the style list
134
+ divider = _format_divider (filename )
135
+ self .styles [divider ] = PromptStyle (f"{ divider } " , None , None , "do_not_save" )
136
+ # Add styles from this CSV file
137
+ self .load_from_csv (file )
159
138
160
139
def load_from_csv (self , path : str ):
161
140
with open (path , "r" , encoding = "utf-8-sig" , newline = "" ) as file :
@@ -173,7 +152,10 @@ def load_from_csv(self, path: str):
173
152
)
174
153
175
154
def get_style_paths (self ) -> set :
176
- """Returns a set of all distinct paths of files that styles are loaded from."""
155
+ """
156
+ Using the collection of styles in the StyleDatabase, returns a set of
157
+ all distinct files that styles are loaded from.
158
+ """
177
159
# Update any styles without a path to the default path
178
160
for style in list (self .styles .values ()):
179
161
if not style .path :
@@ -224,14 +206,17 @@ def save_styles(self, path: str = None) -> None:
224
206
writer = csv .DictWriter (file , fieldnames = self .prompt_fields )
225
207
writer .writeheader ()
226
208
for style in (s for s in self .styles .values () if s .path == style_path ):
227
- # Skip style list dividers , e.g. "STYLES.CSV"
209
+ # Skip style list divider entries , e.g. "## STYLES.CSV ## "
228
210
if style .name .lower ().strip ("# " ) in csv_names :
229
211
continue
230
212
# Write style fields, ignoring the path field
231
213
writer .writerow (
232
214
{k : v for k , v in style ._asdict ().items () if k != "path" }
233
215
)
234
216
217
+ # Reloading the styles to re-order the drop-down lists
218
+ self .reload ()
219
+
235
220
def extract_styles_from_prompt (self , prompt , negative_prompt ):
236
221
extracted = []
237
222
0 commit comments