diff --git a/CHANGELOG.md b/CHANGELOG.md index 5e37852bcc..f6026f1373 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -8,6 +8,7 @@ This project adheres to [Semantic Versioning](https://semver.org/). - [#2479](https://github.com/plotly/dash/pull/2479) Fix `KeyError` "Callback function not found for output [...], , perhaps you forgot to prepend the '@'?" issue when using duplicate callbacks targeting the same output. This issue would occur when the app is restarted or when running with multiple `gunicorn` workers. - [#2471](https://github.com/plotly/dash/pull/2471) Fix `allow_duplicate` output with clientside callback, fix [#2467](https://github.com/plotly/dash/issues/2467) +- [#2473](https://github.com/plotly/dash/pull/2473) Fix background callbacks with different outputs but same function, fix [#2221](https://github.com/plotly/dash/issues/2221) ## [2.9.1] - 2023-03-17 diff --git a/dash/_callback.py b/dash/_callback.py index 666a3dccef..879d16f43b 100644 --- a/dash/_callback.py +++ b/dash/_callback.py @@ -316,7 +316,9 @@ def wrap_func(func): if long is not None: long_key = BaseLongCallbackManager.register_func( - func, long.get("progress") is not None + func, + long.get("progress") is not None, + callback_id, ) @wraps(func) diff --git a/dash/long_callback/managers/__init__.py b/dash/long_callback/managers/__init__.py index b7bf175c4f..cf5dcd2182 100644 --- a/dash/long_callback/managers/__init__.py +++ b/dash/long_callback/managers/__init__.py @@ -36,7 +36,7 @@ def terminate_unhealthy_job(self, job): def job_running(self, job): raise NotImplementedError - def make_job_fn(self, fn, progress): + def make_job_fn(self, fn, progress, key=None): raise NotImplementedError def call_job_fn(self, key, job_fn, args, context): @@ -76,11 +76,11 @@ def build_cache_key(self, fn, args, cache_args_to_ignore): return hashlib.sha1(str(hash_dict).encode("utf-8")).hexdigest() def register(self, key, fn, progress): - self.func_registry[key] = self.make_job_fn(fn, progress) + self.func_registry[key] = self.make_job_fn(fn, progress, key) @staticmethod - def register_func(fn, progress): - key = BaseLongCallbackManager.hash_function(fn) + def register_func(fn, progress, callback_id): + key = BaseLongCallbackManager.hash_function(fn, callback_id) BaseLongCallbackManager.functions.append( ( key, @@ -99,7 +99,9 @@ def _make_progress_key(key): return key + "-progress" @staticmethod - def hash_function(fn): + def hash_function(fn, callback_id=""): fn_source = inspect.getsource(fn) fn_str = fn_source - return hashlib.sha1(fn_str.encode("utf-8")).hexdigest() + return hashlib.sha1( + callback_id.encode("utf-8") + fn_str.encode("utf-8") + ).hexdigest() diff --git a/dash/long_callback/managers/celery_manager.py b/dash/long_callback/managers/celery_manager.py index b4e0165c8d..5090d42178 100644 --- a/dash/long_callback/managers/celery_manager.py +++ b/dash/long_callback/managers/celery_manager.py @@ -1,6 +1,4 @@ import json -import inspect -import hashlib import traceback from contextvars import copy_context @@ -78,8 +76,8 @@ def job_running(self, job): "PROGRESS", ) - def make_job_fn(self, fn, progress): - return _make_job_fn(fn, self.handle, progress) + def make_job_fn(self, fn, progress, key=None): + return _make_job_fn(fn, self.handle, progress, key) def get_task(self, job): if job: @@ -127,15 +125,10 @@ def get_result(self, key, job): return result -def _make_job_fn(fn, celery_app, progress): +def _make_job_fn(fn, celery_app, progress, key): cache = celery_app.backend - # Hash function source and module to create a unique (but stable) celery task name - fn_source = inspect.getsource(fn) - fn_str = fn_source - fn_hash = hashlib.sha1(fn_str.encode("utf-8")).hexdigest() - - @celery_app.task(name=f"long_callback_{fn_hash}") + @celery_app.task(name=f"long_callback_{key}") def job_fn(result_key, progress_key, user_callback_args, context=None): def _set_progress(progress_value): if not isinstance(progress_value, (list, tuple)): diff --git a/dash/long_callback/managers/diskcache_manager.py b/dash/long_callback/managers/diskcache_manager.py index fea69f37f5..22fb7bd5c5 100644 --- a/dash/long_callback/managers/diskcache_manager.py +++ b/dash/long_callback/managers/diskcache_manager.py @@ -107,7 +107,7 @@ def job_running(self, job): return proc.status() != psutil.STATUS_ZOMBIE return False - def make_job_fn(self, fn, progress): + def make_job_fn(self, fn, progress, key=None): return _make_job_fn(fn, self.handle, progress) def clear_cache_entry(self, key): diff --git a/tests/integration/long_callback/app_diff_outputs.py b/tests/integration/long_callback/app_diff_outputs.py new file mode 100644 index 0000000000..4294394747 --- /dev/null +++ b/tests/integration/long_callback/app_diff_outputs.py @@ -0,0 +1,36 @@ +from dash import Dash, Input, Output, html + +from tests.integration.long_callback.utils import get_long_callback_manager + +long_callback_manager = get_long_callback_manager() +handle = long_callback_manager.handle + +app = Dash(__name__, long_callback_manager=long_callback_manager) + +app.layout = html.Div( + [ + html.Button("click 1", id="button-1"), + html.Button("click 2", id="button-2"), + html.Div(id="output-1"), + html.Div(id="output-2"), + ] +) + + +def gen_callback(index): + @app.callback( + Output(f"output-{index}", "children"), + Input(f"button-{index}", "n_clicks"), + background=True, + prevent_initial_call=True, + ) + def callback_name(_): + return f"Clicked on {index}" + + +for i in range(1, 3): + gen_callback(i) + + +if __name__ == "__main__": + app.run_server(debug=True) diff --git a/tests/integration/long_callback/test_basic_long_callback.py b/tests/integration/long_callback/test_basic_long_callback.py index e82e4dc628..93cc56c04d 100644 --- a/tests/integration/long_callback/test_basic_long_callback.py +++ b/tests/integration/long_callback/test_basic_long_callback.py @@ -547,3 +547,12 @@ def test_lcbc014_progress_delete(dash_duo, manager): dash_duo.wait_for_text_to_equal("#output", "done") assert dash_duo.find_element("#progress-counter").text == "2" + + +def test_lcbc015_diff_outputs_same_func(dash_duo, manager): + with setup_long_callback_app(manager, "app_diff_outputs") as app: + dash_duo.start_server(app) + + for i in range(1, 3): + dash_duo.find_element(f"#button-{i}").click() + dash_duo.wait_for_text_to_equal(f"#output-{i}", f"Clicked on {i}")