Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Extension / Script load and execution order system #13943

Closed
wants to merge 6 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 11 additions & 1 deletion modules/extensions.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import os
import json
import threading

from modules import shared, errors, cache, scripts
Expand Down Expand Up @@ -94,9 +95,18 @@ def list_files(self, subdir, extension):
if not os.path.isdir(dirpath):
return []

load_order = {}
try:
with open(os.path.join(self.path, 'webui-extension-properties.json'), 'r', encoding='utf-8') as file:
load_order = json.load(file).get('load_order')
except FileNotFoundError:
pass
except Exception as e:
print(e)

res = []
for filename in sorted(os.listdir(dirpath)):
res.append(scripts.ScriptFile(self.path, filename, os.path.join(dirpath, filename)))
res.append(scripts.ScriptFile(self.path, filename, os.path.join(dirpath, filename), load_order.get(filename)))

res = [x for x in res if os.path.splitext(x.path)[1].lower() == extension and os.path.isfile(x.path)]

Expand Down
157 changes: 127 additions & 30 deletions modules/scripts.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import re
import sys
import inspect
from pathlib import Path
from collections import namedtuple
from dataclasses import dataclass

Expand Down Expand Up @@ -115,7 +116,6 @@ def setup(self, p, *args):
"""
pass


def before_process(self, p, *args):
"""
This function is called very early during processing begins for AlwaysVisible scripts.
Expand Down Expand Up @@ -242,10 +242,10 @@ def on_before_component(self, callback, *, elem_id):
"""
Calls callback before a component is created. The callback function is called with a single argument of type OnComponent.

May be called in show() or ui() - but it may be too late in latter as some components may already be created.
May be called in show() or ui() - but it may be too late in the latter as some components may already be created.

This function is an alternative to before_component in that it also cllows to run before a component is created, but
it doesn't require to be called for every created component - just for the one you need.
This function is an alternative to before_component in that it is also called before a component is created, but
it doesn't require being called for every created component - just for the one you need.
"""
if self.on_before_component_elem_id is None:
self.on_before_component_elem_id = []
Expand Down Expand Up @@ -305,7 +305,7 @@ def basedir():
return current_basedir


ScriptFile = namedtuple("ScriptFile", ["basedir", "filename", "path"])
ScriptFile = namedtuple("ScriptFile", ["basedir", "filename", "path", "load_order"])

scripts_data = []
postprocessing_scripts_data = []
Expand All @@ -318,7 +318,7 @@ def list_scripts(scriptdirname, extension, *, include_extensions=True):
basedir = os.path.join(paths.script_path, scriptdirname)
if os.path.exists(basedir):
for filename in sorted(os.listdir(basedir)):
scripts_list.append(ScriptFile(paths.script_path, filename, os.path.join(basedir, filename)))
scripts_list.append(ScriptFile(paths.script_path, filename, os.path.join(basedir, filename), None))

if include_extensions:
for ext in extensions.active():
Expand All @@ -345,6 +345,42 @@ def list_files_with_name(filename):
return res


def get_short_path(path: Path):
"""Returns the relative path if input path is sub-dir of data_path or script_path"""
try:
return path.relative_to(paths.data_path)
except ValueError:
try:
return path.relative_to(paths.script_path)
except ValueError:
return path


script_default_order = {
paths.extensions_dir: 70000,
paths.extensions_builtin_dir: 50000,
os.path.join(paths.script_path, 'scripts'): 30000,
paths.script_path: 1000000
}


def assign_script_default_order(file_path):
"""Assign load order base on script type
Args:
file_path: path to script file
Returns: order number
internal webui scripts: 10000
built-in scripts: 30000
built-in extensions: 50000
extension: 70000
other: 1000000 (should never reach)
"""
for key in script_default_order:
if file_path.is_relative_to(key):
return script_default_order[key]
return 1000000


def load_scripts():
global current_basedir
scripts_data.clear()
Expand All @@ -365,15 +401,24 @@ def register_scripts_from_module(module):
elif issubclass(script_class, scripts_postprocessing.ScriptPostprocessing):
postprocessing_scripts_data.append(ScriptClassData(script_class, scriptfile.path, scriptfile.basedir, module))

def orderby(basedir):
# 1st webui, 2nd extensions-builtin, 3rd extensions
priority = {os.path.join(paths.script_path, "extensions-builtin"):1, paths.script_path:0}
for key in priority:
if basedir.startswith(key):
return priority[key]
return 9999
def get_script_load_order(script_file):
"""get script file load order based on the following priority
1. User defined load order shared.opts.script_order_override {key: order}
key: object_id : "short_path"
order: number
2. Default load order specified by extension properties
3. Assign order using assign_script_default_order()

if multiple scripts file have the same load order then the object_id will be used for comparison
"""
path = Path(script_file.path)
object_id = get_short_path(path)
order = shared.opts.script_order_override.get(object_id) or script_file.load_order
if order is None:
return [assign_script_default_order(path), object_id]
return [order, object_id]

for scriptfile in sorted(scripts_list, key=lambda x: [orderby(x.basedir), x]):
for scriptfile in sorted(scripts_list, key=get_script_load_order):
try:
if scriptfile.basedir != paths.script_path:
sys.path = [scriptfile.basedir] + sys.path
Expand Down Expand Up @@ -406,6 +451,23 @@ def wrap_call(func, filename, funcname, *args, default=None, **kwargs):
return default


def get_function_execution_order(script_path, script_class, function_name):
"""get execution order of function_name for class based on the following priority
1. User defined execution order shared.opts.script_order_override {key: order}
key: object_id: "short_path>ScriptClassName>function_name"
order: number
2. Default execution order specified by function attribute function_name.order
3. Assign order using assign_script_default_order()

if multiple scripts file have the same load order then the short_path will be used for comparison
"""
path = Path(script_path)
object_id = f'{get_short_path(path)}>{script_class.__name__}>{function_name}'
order = shared.opts.script_order_override.get(object_id) or getattr(getattr(script_class, function_name), 'order', None)
if order is None:
return [assign_script_default_order(path), object_id]
return [order, object_id]

class ScriptRunner:
def __init__(self):
self.scripts = []
Expand All @@ -416,13 +478,46 @@ def __init__(self):
self.infotext_fields = []
self.paste_field_names = []
self.inputs = [None]
self.script_callback_map = {}

self.on_before_component_elem_id = {}
"""dict of callbacks to be called before an element is created; key=elem_id, value=list of callbacks"""

self.on_after_component_elem_id = {}
"""dict of callbacks to be called after an element is created; key=elem_id, value=list of callbacks"""

alwayson_scripts_callbacks = [
'ui',
'setup',
'before_process',
'process',
'before_process_batch',
'after_extra_networks_activate',
'process_batch',
'before_hr',
'postprocess',
'postprocess_batch',
'postprocess_batch_list',
'postprocess_image',
]

base_scripts_callbacks = [
'before_component',
'after_component',
]

def init_script_callback_map(self):
self.alwayson_scripts.sort(key=lambda x: get_function_execution_order(x.filename, x.__class__, 'title'))
self.selectable_scripts.sort(key=lambda x: get_function_execution_order(x.filename, x.__class__, 'title'))
self.scripts.sort(key=lambda x: get_function_execution_order(x.filename, x.__class__, 'title'))

self.script_callback_map = {
key: sorted(self.alwayson_scripts, key=lambda x: get_function_execution_order(x.filename, x.__class__, key)) for key in self.alwayson_scripts_callbacks
}
self.script_callback_map.update({
key: sorted(self.scripts, key=lambda x: get_function_execution_order(x.filename, x.__class__, key)) for key in self.base_scripts_callbacks
})

def initialize_scripts(self, is_img2img):
from modules import scripts_auto_postprocessing

Expand All @@ -432,7 +527,8 @@ def initialize_scripts(self, is_img2img):

auto_processing_scripts = scripts_auto_postprocessing.create_auto_preprocessing_script_data()

for script_data in auto_processing_scripts + scripts_data:
script_list = auto_processing_scripts + scripts_data
for script_data in sorted(script_list, key=lambda x: get_function_execution_order(x.path, x.script_class, 'show')):
script = script_data.script_class()
script.filename = script_data.path
script.is_txt2img = not is_img2img
Expand All @@ -450,6 +546,7 @@ def initialize_scripts(self, is_img2img):
self.scripts.append(script)
self.selectable_scripts.append(script)

self.init_script_callback_map()
self.apply_on_before_component_callbacks()

def apply_on_before_component_callbacks(self):
Expand Down Expand Up @@ -520,7 +617,7 @@ def create_script_ui(self, script):

def setup_ui_for_section(self, section, scriptlist=None):
if scriptlist is None:
scriptlist = self.alwayson_scripts
scriptlist = self.script_callback_map['ui']

for script in scriptlist:
if script.alwayson and script.section != section:
Expand Down Expand Up @@ -550,7 +647,7 @@ def setup_ui(self):
self.setup_ui_for_section(None, self.selectable_scripts)

def select_script(script_index):
selected_script = self.selectable_scripts[script_index - 1] if script_index>0 else None
selected_script = self.selectable_scripts[script_index - 1] if script_index > 0 else None

return [gr.update(visible=selected_script == s) for s in self.selectable_scripts]

Expand Down Expand Up @@ -609,71 +706,71 @@ def run(self, p, *args):
return processed

def before_process(self, p):
for script in self.alwayson_scripts:
for script in self.script_callback_map['before_process']:
try:
script_args = p.script_args[script.args_from:script.args_to]
script.before_process(p, *script_args)
except Exception:
errors.report(f"Error running before_process: {script.filename}", exc_info=True)

def process(self, p):
for script in self.alwayson_scripts:
for script in self.script_callback_map['process']:
try:
script_args = p.script_args[script.args_from:script.args_to]
script.process(p, *script_args)
except Exception:
errors.report(f"Error running process: {script.filename}", exc_info=True)

def before_process_batch(self, p, **kwargs):
for script in self.alwayson_scripts:
for script in self.script_callback_map['before_process_batch']:
try:
script_args = p.script_args[script.args_from:script.args_to]
script.before_process_batch(p, *script_args, **kwargs)
except Exception:
errors.report(f"Error running before_process_batch: {script.filename}", exc_info=True)

def after_extra_networks_activate(self, p, **kwargs):
for script in self.alwayson_scripts:
for script in self.script_callback_map['after_extra_networks_activate']:
try:
script_args = p.script_args[script.args_from:script.args_to]
script.after_extra_networks_activate(p, *script_args, **kwargs)
except Exception:
errors.report(f"Error running after_extra_networks_activate: {script.filename}", exc_info=True)

def process_batch(self, p, **kwargs):
for script in self.alwayson_scripts:
for script in self.script_callback_map['process_batch']:
try:
script_args = p.script_args[script.args_from:script.args_to]
script.process_batch(p, *script_args, **kwargs)
except Exception:
errors.report(f"Error running process_batch: {script.filename}", exc_info=True)

def postprocess(self, p, processed):
for script in self.alwayson_scripts:
for script in self.script_callback_map['postprocess']:
try:
script_args = p.script_args[script.args_from:script.args_to]
script.postprocess(p, processed, *script_args)
except Exception:
errors.report(f"Error running postprocess: {script.filename}", exc_info=True)

def postprocess_batch(self, p, images, **kwargs):
for script in self.alwayson_scripts:
for script in self.script_callback_map['postprocess_batch']:
try:
script_args = p.script_args[script.args_from:script.args_to]
script.postprocess_batch(p, *script_args, images=images, **kwargs)
except Exception:
errors.report(f"Error running postprocess_batch: {script.filename}", exc_info=True)

def postprocess_batch_list(self, p, pp: PostprocessBatchListArgs, **kwargs):
for script in self.alwayson_scripts:
for script in self.script_callback_map['postprocess_batch_list']:
try:
script_args = p.script_args[script.args_from:script.args_to]
script.postprocess_batch_list(p, pp, *script_args, **kwargs)
except Exception:
errors.report(f"Error running postprocess_batch_list: {script.filename}", exc_info=True)

def postprocess_image(self, p, pp: PostprocessImageArgs):
for script in self.alwayson_scripts:
for script in self.script_callback_map['postprocess_image']:
try:
script_args = p.script_args[script.args_from:script.args_to]
script.postprocess_image(p, pp, *script_args)
Expand All @@ -687,7 +784,7 @@ def before_component(self, component, **kwargs):
except Exception:
errors.report(f"Error running on_before_component: {script.filename}", exc_info=True)

for script in self.scripts:
for script in self.script_callback_map['before_component']:
try:
script.before_component(component, **kwargs)
except Exception:
Expand All @@ -700,7 +797,7 @@ def after_component(self, component, **kwargs):
except Exception:
errors.report(f"Error running on_after_component: {script.filename}", exc_info=True)

for script in self.scripts:
for script in self.script_callback_map['after_component']:
try:
script.after_component(component, **kwargs)
except Exception:
Expand Down Expand Up @@ -728,15 +825,15 @@ def reload_sources(self, cache):
self.scripts[si].args_to = args_to

def before_hr(self, p):
for script in self.alwayson_scripts:
for script in self.script_callback_map['before_hr']:
try:
script_args = p.script_args[script.args_from:script.args_to]
script.before_hr(p, *script_args)
except Exception:
errors.report(f"Error running before_hr: {script.filename}", exc_info=True)

def setup_scrips(self, p, *, is_ui=True):
for script in self.alwayson_scripts:
for script in self.script_callback_map['setup']:
if not is_ui and script.setup_for_ui_only:
continue

Expand Down
1 change: 1 addition & 0 deletions modules/shared_options.py
Original file line number Diff line number Diff line change
Expand Up @@ -338,4 +338,5 @@
"disable_all_extensions": OptionInfo("none", "Disable all extensions (preserves the list of disabled extensions)", gr.Radio, {"choices": ["none", "extra", "all"]}),
"restore_config_state_file": OptionInfo("", "Config state file to restore from, under 'config-states/' folder"),
"sd_checkpoint_hash": OptionInfo("", "SHA256 hash of the current checkpoint"),
"script_order_override": OptionInfo({}, "Override default load and callback order of scripts/extensions"),
}))