Skip to content

Commit 4ea5403

Browse files
authored
Merge pull request AUTOMATIC1111#3 from Kittensx/Kittensx-patch-Simple-KES
Update simple_karras_exponential_scheduler.py
2 parents 6ffb728 + da2e709 commit 4ea5403

File tree

1 file changed

+46
-50
lines changed

1 file changed

+46
-50
lines changed

modules/simple_karras_exponential_scheduler.py

+46-50
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,26 @@
1313
import logging
1414
from datetime import datetime
1515

16+
def get_random_or_default(scheduler_config, key_prefix, default_value, global_randomize):
17+
"""Helper function to either randomize a value based on conditions or return the default."""
18+
19+
# Determine if we should randomize based on global and individual flags
20+
randomize_flag = global_randomize or scheduler_config.get(f'{key_prefix}_rand', False)
21+
22+
if randomize_flag:
23+
# Use specified min/max values for randomization if they exist, else use default range
24+
rand_min = scheduler_config.get(f'{key_prefix}_rand_min', default_value * 0.8)
25+
rand_max = scheduler_config.get(f'{key_prefix}_rand_max', default_value * 1.2)
26+
value = random.uniform(rand_min, rand_max)
27+
custom_logger.info(f"Randomized {key_prefix}: {value}")
28+
else:
29+
# Use default value if no randomization is applied
30+
value = default_value
31+
custom_logger.info(f"Using default {key_prefix}: {value}")
32+
33+
return value
34+
35+
1636
class CustomLogger:
1737
def __init__(self, log_name, print_to_console=False, debug_enabled=False):
1838
self.print_to_console = print_to_console #prints to console
@@ -149,38 +169,8 @@ def start_config_watcher(config_manager, config_path):
149169

150170
# Start watching for config changes
151171
observer = start_config_watcher(config_manager, config_path)
152-
'''
153-
def get_random_or_default(config, key_prefix, default_value):
154-
"""Helper function to either randomize a value or return the default."""
155-
randomize_flag = config['scheduler'].get(f'{key_prefix}_rand', False)
156-
if randomize_flag:
157-
rand_min = config['scheduler'].get(f'{key_prefix}_rand_min', default_value * 0.8)
158-
rand_max = config['scheduler'].get(f'{key_prefix}_rand_max', default_value * 1.2)
159-
value = random.uniform(rand_min, rand_max)
160-
custom_logger.info(f"Randomized {key_prefix}: {value}" )
161-
else:
162-
value = default_value
163-
custom_logger.info(f"Using default {key_prefix}: {value}")
164-
return value
165-
'''
166-
def get_random_or_default(config, key_prefix, default_value, global_randomize):
167-
"""Helper function to either randomize a value based on conditions or return the default."""
168-
# Check if global randomize is on or the individual flag is on
169-
randomize_flag = global_randomize or config['scheduler'].get(f'{key_prefix}_rand', False)
170-
171-
if randomize_flag:
172-
# Use specified min/max for randomization if the individual flag is set or global randomize is on
173-
rand_min = config['scheduler'].get(f'{key_prefix}_rand_min', default_value * 0.8)
174-
rand_max = config['scheduler'].get(f'{key_prefix}_rand_max', default_value * 1.2)
175-
value = random.uniform(rand_min, rand_max)
176-
custom_logger.info(f"Randomized {key_prefix}: {value}")
177-
else:
178-
value = default_value
179-
custom_logger.info(f"Using default {key_prefix}: {value}")
180-
181-
return value
182172

183-
173+
184174
def simple_karras_exponential_scheduler(
185175
n, device, sigma_min=0.01, sigma_max=50, start_blend=0.1, end_blend=0.5,
186176
sharpness=0.95, early_stopping_threshold=0.01, update_interval=10, initial_step_size=0.9,
@@ -209,6 +199,15 @@ def simple_karras_exponential_scheduler(
209199
Returns:
210200
torch.Tensor: A tensor of blended sigma values.
211201
"""
202+
config_path = os.path.join(os.path.dirname(__file__), 'simple_kes_scheduler.yaml')
203+
config = config_manager.load_config()
204+
scheduler_config = config.get('scheduler', {})
205+
if not scheduler_config:
206+
raise ValueError("Scheduler configuration is missing from the config file.")
207+
208+
# Global randomization flag
209+
global_randomize = scheduler_config.get('randomize', False)
210+
212211
#debug_log("Entered simple_karras_exponential_scheduler function")
213212
default_config = {
214213
"debug": False,
@@ -272,30 +271,27 @@ def simple_karras_exponential_scheduler(
272271
"noise_scale_factor_rand_max": 0.95,
273272
}
274273
custom_logger.info(f"Default Config create {default_config}")
275-
for key, value in default_config.items():
276-
custom_logger.info(f"Default Config - {key}: {value}")
277-
278-
#config = config_manager.load_config()
279274
config = config_manager.load_config().get('scheduler', {})
280-
global_randomize = config.get('randomize', randomize)
281-
282-
custom_logger.info(f"Config loaded from yaml {config}")
275+
if not config:
276+
raise ValueError("Scheduler configuration is missing from the config file.")
277+
278+
# Log loaded YAML configuration
279+
custom_logger.info(f"Configuration loaded from YAML: {config}")
280+
283281
for key, value in config.items():
284-
custom_logger.info(f"Config - {key}: {value}")
285-
286-
# Check if the scheduler config is available in the YAML file
287-
scheduler_config = config.get('scheduler', {})
288-
if not scheduler_config:
289-
raise ValueError("Scheduler configuration is missing from the config file.")
290-
291-
for key, value in scheduler_config.items():
292-
custom_logger.info(f"Scheduler Config before update - {key}: {value}")
293-
for key, value in scheduler_config.items():
294282
if key in default_config:
295-
default_config[key] = value
283+
default_config[key] = value # Override default with YAML value
296284
custom_logger.info(f"Overriding default config: {key} = {value}")
297285
else:
298-
debug.log(f"Ignoring unknown config option: {key}")
286+
custom_logger.info(f"Ignoring unknown config option: {key}")
287+
288+
custom_logger.info(f"Final configuration after merging with YAML: {default_config}")
289+
290+
global_randomize = default_config.get('randomize', False)
291+
custom_logger.info(f"Global randomization flag set to: {global_randomize}")
292+
293+
custom_logger.info(f"Config loaded from yaml {config}")
294+
299295
# Now using default_config, updated with valid YAML values
300296
custom_logger.info(f"Final Config after overriding: {default_config}")
301297

0 commit comments

Comments
 (0)