From c457ab7bee35457a2ecc1536beed49ad6f033bb7 Mon Sep 17 00:00:00 2001 From: BSd3v <82055130+BSd3v@users.noreply.github.com> Date: Tue, 9 Sep 2025 08:24:23 -0400 Subject: [PATCH 001/166] initial push --- CHANGELOG.md | 3 +++ 1 file changed, 3 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index a6a2224907..4120ea8f93 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -2,6 +2,9 @@ All notable changes to `dash` will be documented in this file. This project adheres to [Semantic Versioning](https://semver.org/). +## [bringyourownserver] +- [#3430] Adds support to bring your own server, eg (Quart, FastAPI, etc). + ## [UNRELEASED] ## Added From 4ebc657a49e6440d9977bc08c1afcc8916945230 Mon Sep 17 00:00:00 2001 From: BSd3v <82055130+BSd3v@users.noreply.github.com> Date: Tue, 9 Sep 2025 15:00:32 -0400 Subject: [PATCH 002/166] work to modularize the dash eco-system and decouple from Flask --- dash/_callback.py | 19 +- dash/dash.py | 257 ++++++----------------- dash/server_factories/__init__.py | 10 + dash/server_factories/base_factory.py | 47 +++++ dash/server_factories/fastapi_factory.py | 226 ++++++++++++++++++++ dash/server_factories/flask_factory.py | 188 +++++++++++++++++ 6 files changed, 550 insertions(+), 197 deletions(-) create mode 100644 dash/server_factories/__init__.py create mode 100644 dash/server_factories/base_factory.py create mode 100644 dash/server_factories/fastapi_factory.py create mode 100644 dash/server_factories/flask_factory.py diff --git a/dash/_callback.py b/dash/_callback.py index aacb8dbdde..bca8027fdd 100644 --- a/dash/_callback.py +++ b/dash/_callback.py @@ -6,7 +6,7 @@ import asyncio -import flask +from dash.server_factories import get_request_adapter from .dependencies import ( handle_callback_args, @@ -376,7 +376,7 @@ def _get_callback_manager( " and store results on redis.\n" ) - old_job = flask.request.args.getlist("oldJob") + old_job = get_request_adapter().get_args().getlist("oldJob") if old_job: for job in old_job: @@ -436,7 +436,7 @@ def _setup_background_callback( def _progress_background_callback(response, callback_manager, background): progress_outputs = background.get("progress") - cache_key = flask.request.args.get("cacheKey") + cache_key = get_request_adapter().get_args().get("cacheKey") if progress_outputs: # Get the progress before the result as it would be erased after the results. @@ -453,8 +453,8 @@ def _update_background_callback( """Set up the background callback and manage jobs.""" callback_manager = _get_callback_manager(kwargs, background) - cache_key = flask.request.args.get("cacheKey") - job_id = flask.request.args.get("job") + cache_key = get_request_adapter().get_args().get("cacheKey") + job_id = get_request_adapter().get_args().get("job") _progress_background_callback(response, callback_manager, background) @@ -474,8 +474,8 @@ def _handle_rest_background_callback( multi, has_update=False, ): - cache_key = flask.request.args.get("cacheKey") - job_id = flask.request.args.get("job") + cache_key = get_request_adapter().get_args().get("cacheKey") + job_id = get_request_adapter().get_args().get("job") # Must get job_running after get_result since get_results terminates it. job_running = callback_manager.job_running(job_id) if not job_running and output_value is callback_manager.UNDEFINED: @@ -688,11 +688,10 @@ def add_context(*args, **kwargs): ) response: dict = {"multi": True} - jsonResponse = None try: if background is not None: - if not flask.request.args.get("cacheKey"): + if not get_request_adapter().get_args().get("cacheKey"): return _setup_background_callback( kwargs, background, @@ -763,7 +762,7 @@ async def async_add_context(*args, **kwargs): try: if background is not None: - if not flask.request.args.get("cacheKey"): + if not get_request_adapter().get_args().get("cacheKey"): return _setup_background_callback( kwargs, background, diff --git a/dash/dash.py b/dash/dash.py index 8430259c27..56bf65c9e6 100644 --- a/dash/dash.py +++ b/dash/dash.py @@ -17,6 +17,7 @@ import hashlib import base64 import traceback +import inspect from urllib.parse import urlparse from typing import Any, Callable, Dict, Optional, Union, Sequence, Literal, List @@ -67,6 +68,8 @@ from . import _validate from . import _watch from . import _get_app +from .server_factories.flask_factory import FlaskServerFactory +from .server_factories.base_factory import BaseServerFactory from ._get_app import with_app_context, with_app_context_async, with_app_context_factory from ._grouping import map_grouping, grouping_len, update_args_group @@ -421,7 +424,7 @@ class Dash(ObsoleteChecker): _plotlyjs_url: str STARTUP_ROUTES: list = [] - server: flask.Flask + server: Any # Layout is a complex type which can be many things _layout: Any @@ -430,7 +433,7 @@ class Dash(ObsoleteChecker): def __init__( # pylint: disable=too-many-statements self, name: Optional[str] = None, - server: Union[bool, flask.Flask] = True, + server: Union[bool, Callable[[], Any]] = True, assets_folder: str = "assets", pages_folder: str = "pages", use_pages: Optional[bool] = None, @@ -466,6 +469,7 @@ def __init__( # pylint: disable=too-many-statements description: Optional[str] = None, on_error: Optional[Callable[[Exception], Any]] = None, use_async: Optional[bool] = None, + server_factory: Optional[BaseServerFactory] = None, **obsolete, ): @@ -488,16 +492,23 @@ def __init__( # pylint: disable=too-many-statements caller_name: str = name if name is not None else get_caller_name() + self.server_factory = server_factory or FlaskServerFactory() + # We have 3 cases: server is either True (we create the server), False # (defer server creation) or a Flask app instance (we use their server) - if isinstance(server, flask.Flask): + if callable(server) and not (hasattr(server, 'route') and hasattr(server, 'run')): + # Server factory function + self.server = server() + if name is None: + caller_name = getattr(self.server, "name", caller_name) + elif hasattr(server, 'route') and hasattr(server, 'run'): self.server = server if name is None: caller_name = getattr(server, "name", caller_name) elif isinstance(server, bool): - self.server = flask.Flask(caller_name) if server else None # type: ignore + self.server = self.server_factory.create_app(caller_name) if server else None else: - raise ValueError("server must be a Flask app or a boolean") + raise ValueError("server must be a Flask app, a boolean, or a server factory function") base_prefix, routes_prefix, requests_prefix = pathname_configs( url_base_pathname, routes_pathname_prefix, requests_pathname_prefix @@ -671,11 +682,8 @@ def _setup_hooks(self): if self._hooks.get_hooks("error"): self._on_error = self._hooks.HookErrorHandler(self._on_error) - def init_app(self, app: Optional[flask.Flask] = None, **kwargs) -> None: - """Initialize the parts of Dash that require a flask app.""" - + def init_app(self, app: Optional[Any] = None, **kwargs) -> None: config = self.config - config.update(kwargs) config.set_read_only( [ @@ -685,89 +693,58 @@ def init_app(self, app: Optional[flask.Flask] = None, **kwargs) -> None: ], "Read-only: can only be set in the Dash constructor or during init_app()", ) - if app is not None: self.server = app - bp_prefix = config.routes_pathname_prefix.replace("/", "_").replace(".", "_") assets_blueprint_name = f"{bp_prefix}dash_assets" - - self.server.register_blueprint( - flask.Blueprint( - assets_blueprint_name, - config.name, - static_folder=self.config.assets_folder, - static_url_path=config.routes_pathname_prefix - + self.config.assets_url_path.lstrip("/"), - ) + self.server_factory.register_assets_blueprint( + self.server, + assets_blueprint_name, + config.routes_pathname_prefix + self.config.assets_url_path.lstrip("/"), + self.config.assets_folder, ) - if config.compress: try: - # pylint: disable=import-outside-toplevel - from flask_compress import Compress # type: ignore[reportMissingImports] - - # gzip + from flask_compress import Compress Compress(self.server) - _flask_compress_version = parse_version( _get_distribution_version("flask_compress") ) - if not hasattr( self.server.config, "COMPRESS_ALGORITHM" ) and _flask_compress_version >= parse_version("1.6.0"): - # flask-compress==1.6.0 changed default to ['br', 'gzip'] - # and non-overridable default compression with Brotli is - # causing performance issues self.server.config["COMPRESS_ALGORITHM"] = ["gzip"] except ImportError as error: raise ImportError( "To use the compress option, you need to install dash[compress]" ) from error - - @self.server.errorhandler(PreventUpdate) - def _handle_error(_): - """Handle a halted callback and return an empty 204 response.""" - return "", 204 - - self.server.before_request(self._setup_server) - - # add a handler for components suites errors to return 404 - self.server.errorhandler(InvalidResourceError)(self._invalid_resources_handler) - + self.server_factory.register_error_handlers(self.server) + self.server_factory.before_request(self.server, self._setup_server) self._setup_routes() - _get_app.APP = self self.enable_pages() - self._setup_plotlyjs() def _add_url(self, name: str, view_func: RouteCallable, methods=("GET",)) -> None: full_name = self.config.routes_pathname_prefix + name - - self.server.add_url_rule( - full_name, view_func=view_func, endpoint=full_name, methods=list(methods) + self.server_factory.add_url_rule( + self.server, + full_name, + view_func=view_func, + endpoint=full_name, + methods=list(methods), ) - - # record the url in Dash.routes so that it can be accessed later - # e.g. for adding authentication with flask_login self.routes.append(full_name) def _setup_routes(self): - self._add_url( - "_dash-component-suites//", - self.serve_component_suites, - ) + self.server_factory.setup_component_suites(self.server, self) self._add_url("_dash-layout", self.serve_layout) self._add_url("_dash-dependencies", self.dependencies) - if self._use_async: - self._add_url("_dash-update-component", self.async_dispatch, ["POST"]) - else: - self._add_url("_dash-update-component", self.dispatch, ["POST"]) + self._add_url("_dash-update-component", self.server_factory.dispatch(self.server, self, self._use_async), ["POST"]) self._add_url("_reload-hash", self.serve_reload_hash) - self._add_url("_favicon.ico", self._serve_default_favicon) - self._add_url("", self.index) + self._add_url("_favicon.ico", self.server_factory._serve_default_favicon) + self.server_factory.setup_index(self.server, self) + self.server_factory.setup_catchall(self.server, self) if jupyter_dash.active: self._add_url( @@ -781,8 +758,6 @@ def _setup_routes(self): hook.data["methods"], ) - # catch-all for front-end routes, used by dcc.Location - self._add_url("", self.index) def setup_apis(self): """ @@ -902,7 +877,7 @@ def serve_layout(self): layout = hook(layout) # TODO - Set browser cache limit - pass hash into frontend - return flask.Response( + return self.server_factory.make_response( to_json(layout), mimetype="application/json", ) @@ -966,7 +941,7 @@ def serve_reload_hash(self): _reload.hard = False _reload.changed_assets = [] - return flask.jsonify( + return self.server_factory.jsonify( { "reloadHash": _hash, "hard": hard, @@ -1159,54 +1134,12 @@ def _generate_meta(self): return meta_tags + self.config.meta_tags - # Serve the JS bundles for each package - def serve_component_suites(self, package_name, fingerprinted_path): - path_in_pkg, has_fingerprint = check_fingerprint(fingerprinted_path) - - _validate.validate_js_path(self.registered_paths, package_name, path_in_pkg) - - extension = "." + path_in_pkg.split(".")[-1] - mimetype = mimetypes.types_map.get(extension, "application/octet-stream") - - package = sys.modules[package_name] - self.logger.debug( - "serving -- package: %s[%s] resource: %s => location: %s", - package_name, - package.__version__, - path_in_pkg, - package.__path__, - ) - - response = flask.Response( - pkgutil.get_data(package_name, path_in_pkg), mimetype=mimetype - ) - - if has_fingerprint: - # Fingerprinted resources are good forever (1 year) - # No need for ETag as the fingerprint changes with each build - response.cache_control.max_age = 31536000 # 1 year - else: - # Non-fingerprinted resources are given an ETag that - # will be used / check on future requests - response.add_etag() - tag = response.get_etag()[0] - - request_etag = flask.request.headers.get("If-None-Match") - - if f'"{tag}"' == request_etag: - response = flask.Response(None, status=304) - - return response - - @with_app_context - def index(self, *args, **kwargs): # pylint: disable=unused-argument + def render_index(self, *args, **kwargs): scripts = self._generate_scripts_html() css = self._generate_css_dist_html() config = self._generate_config_html() metas = self._generate_meta() renderer = self._generate_renderer() - - # use self.title instead of app.config.title for backwards compatibility title = self.title if self.use_pages and self.config.include_pages_meta: @@ -1314,7 +1247,7 @@ def interpolate_index(self, **kwargs): @with_app_context def dependencies(self): - return flask.Response( + return self.server_factory.make_response( to_json(self._callback_list), content_type="application/json", ) @@ -1417,8 +1350,11 @@ def callback(self, *_args, **_kwargs) -> Callable[..., Any]: **_kwargs, ) + def _inputs_to_vals(self, inputs): + return inputs_to_vals(inputs) + # pylint: disable=R0915 - def _initialize_context(self, body): + def _initialize_context(self, body, adapter): """Initialize the global context for the request.""" g = AttributeDict({}) g.inputs_list = body.get("inputs", []) @@ -1430,12 +1366,12 @@ def _initialize_context(self, body): {"prop_id": x, "value": g.input_values.get(x)} for x in body.get("changedPropIds", []) ] - g.dash_response = flask.Response(mimetype="application/json") - g.cookies = dict(**flask.request.cookies) - g.headers = dict(**flask.request.headers) - g.path = flask.request.full_path - g.remote = flask.request.remote_addr - g.origin = flask.request.origin + g.dash_response = self.server_factory.make_response(mimetype="application/json", data=None) + g.cookies = dict(adapter.get_cookies()) + g.headers = dict(adapter.get_headers()) + g.path = adapter.get_full_path() + g.remote = adapter.get_remote_addr() + g.origin = adapter.get_origin() g.updated_props = {} return g @@ -1499,11 +1435,6 @@ def _prepare_grouping(self, data_list, indices): def _execute_callback(self, func, args, outputs_list, g): """Execute the callback with the prepared arguments.""" - g.cookies = dict(**flask.request.cookies) - g.headers = dict(**flask.request.headers) - g.path = flask.request.full_path - g.remote = flask.request.remote_addr - g.origin = flask.request.origin g.custom_data = AttributeDict({}) for hook in self._hooks.get_hooks("custom_data"): @@ -1522,47 +1453,6 @@ def _execute_callback(self, func, args, outputs_list, g): ) return partial_func - @with_app_context_async - async def async_dispatch(self): - body = flask.request.get_json() - g = self._initialize_context(body) - func = self._prepare_callback(g, body) - args = inputs_to_vals(g.inputs_list + g.states_list) - - ctx = copy_context() - partial_func = self._execute_callback(func, args, g.outputs_list, g) - if asyncio.iscoroutine(func): - response_data = await ctx.run(partial_func) - else: - response_data = ctx.run(partial_func) - - if asyncio.iscoroutine(response_data): - response_data = await response_data - - g.dash_response.set_data(response_data) - return g.dash_response - - @with_app_context - def dispatch(self): - body = flask.request.get_json() - g = self._initialize_context(body) - func = self._prepare_callback(g, body) - args = inputs_to_vals(g.inputs_list + g.states_list) - - ctx = copy_context() - partial_func = self._execute_callback(func, args, g.outputs_list, g) - response_data = ctx.run(partial_func) - - if asyncio.iscoroutine(response_data): - raise Exception( - "You are trying to use a coroutine without dash[async]. " - "Please install the dependencies via `pip install dash[async]` and ensure " - "that `use_async=False` is not being passed to the app." - ) - - g.dash_response.set_data(response_data) - return g.dash_response - def _setup_server(self): if self._got_first_request["setup_server"]: return @@ -1695,12 +1585,6 @@ def _walk_assets_directory(self): def _invalid_resources_handler(err): return err.args[0], 404 - @staticmethod - def _serve_default_favicon(): - return flask.Response( - pkgutil.get_data("dash", "favicon.ico"), content_type="image/x-icon" - ) - def csp_hashes(self, hash_algorithm="sha256") -> Sequence[str]: """Calculates CSP hashes (sha + base64) of all inline scripts, such that one of the biggest benefits of CSP (disallowing general inline scripts) @@ -2112,14 +1996,19 @@ def enable_dev_tools( elif dev_tools.prune_errors: secret = gen_salt(20) - @self.server.errorhandler(Exception) - def _wrap_errors(error): - # find the callback invocation, if the error is from a callback - # and skip the traceback up to that point - # if the error didn't come from inside a callback, we won't - # skip anything. - tb = _get_traceback(secret, error) - return tb, 500 + if hasattr(self.server, "errorhandler"): + # Flask + @self.server.errorhandler(Exception) + def _wrap_errors(error): + tb = _get_traceback(secret, error) + return tb, 500 + elif hasattr(self.server, "exception_handler"): + # FastAPI + @self.server.exception_handler(Exception) + async def _wrap_errors(request, error): + tb = _get_traceback(secret, error) + from fastapi.responses import PlainTextResponse + return PlainTextResponse(tb, status_code=500) if debug and dev_tools.ui: @@ -2149,9 +2038,8 @@ def _after_request(response): return response - self.server.before_request(_before_request) - - self.server.after_request(_after_request) + self.server_factory.before_request(self.server, _before_request) + self.server_factory.after_request(self.server, _after_request) if ( debug @@ -2435,7 +2323,7 @@ def verify_url_part(served_part, url_part, part_name): server_url=jupyter_server_url, ) else: - self.server.run(host=host, port=port, debug=debug, **flask_run_options) + self.server_factory.run(self.server, host=host, port=port, debug=debug, **flask_run_options) def enable_pages(self) -> None: if not self.use_pages: @@ -2495,7 +2383,7 @@ async def update(pathname_, search_, **states): ) if callable(title): title = await execute_async_function( - title, **(path_variables or {}) + title, **{**(path_variables or {})} ) return layout, {"title": title} @@ -2559,7 +2447,7 @@ def update(pathname_, search_, **states): **{**(path_variables or {}), **query_parameters, **states} ) if callable(title): - title = title(**(path_variables or {})) + title = title(**{**(path_variables or {})}) return layout, {"title": title} @@ -2599,10 +2487,5 @@ def update(pathname_, search_, **states): Input(_ID_STORE, "data"), ) - def __call__(self, environ, start_response): - """ - This method makes instances of Dash WSGI-compliant callables. - It delegates the actual WSGI handling to the internal Flask app's - __call__ method. - """ - return self.server(environ, start_response) + def __call__(self, *args, **kwargs): + return self.server_factory.__call__(self.server, *args, **kwargs) diff --git a/dash/server_factories/__init__.py b/dash/server_factories/__init__.py new file mode 100644 index 0000000000..7d9874ec7a --- /dev/null +++ b/dash/server_factories/__init__.py @@ -0,0 +1,10 @@ +# python +import contextvars + +_request_adapter_var = contextvars.ContextVar("request_adapter") + +def set_request_adapter(adapter): + _request_adapter_var.set(adapter) + +def get_request_adapter(): + return _request_adapter_var.get() diff --git a/dash/server_factories/base_factory.py b/dash/server_factories/base_factory.py new file mode 100644 index 0000000000..f429357e03 --- /dev/null +++ b/dash/server_factories/base_factory.py @@ -0,0 +1,47 @@ +from abc import ABC, abstractmethod + +class BaseServerFactory(ABC): + def __call__(self, server, *args, **kwargs): + # Default: WSGI + return server(*args, **kwargs) + + @abstractmethod + def create_app(self, name="__main__", config=None): + pass + + @abstractmethod + def register_assets_blueprint(self, app, blueprint_name, assets_url_path, assets_folder): + pass + + @abstractmethod + def register_error_handlers(self, app): + pass + + @abstractmethod + def add_url_rule(self, app, rule, view_func, endpoint=None, methods=None): + pass + + @abstractmethod + def before_request(self, app, func): + pass + + @abstractmethod + def after_request(self, app, func): + pass + + @abstractmethod + def run(self, app, host, port, debug, **kwargs): + pass + + @abstractmethod + def make_response(self, data, mimetype=None, content_type=None): + pass + + @abstractmethod + def jsonify(self, obj): + pass + + @abstractmethod + def get_request_adapter(self): + pass + diff --git a/dash/server_factories/fastapi_factory.py b/dash/server_factories/fastapi_factory.py new file mode 100644 index 0000000000..7592a51ce3 --- /dev/null +++ b/dash/server_factories/fastapi_factory.py @@ -0,0 +1,226 @@ +import traceback + +from fastapi import FastAPI, Request, Response, APIRouter +from fastapi.responses import JSONResponse +from dash.exceptions import PreventUpdate, InvalidResourceError +from dash.server_factories import set_request_adapter, get_request_adapter +from .base_factory import BaseServerFactory +import inspect +import pkgutil +from contextvars import copy_context + +class FastAPIServerFactory(BaseServerFactory): + def __call__(self, server, *args, **kwargs): + # ASGI: (scope, receive, send) + if ( + len(args) == 3 + and isinstance(args[0], dict) + and "type" in args[0] + ): + return server(*args, **kwargs) + raise TypeError("FastAPI app must be called with (scope, receive, send)") + + + def create_app(self, name="__main__", config=None): + app = FastAPI() + if config: + for key, value in config.items(): + setattr(app.state, key, value) + return app + + def register_assets_blueprint(self, app, blueprint_name, assets_url_path, assets_folder): + from fastapi.staticfiles import StaticFiles + try: + app.mount(assets_url_path, StaticFiles(directory=assets_folder), name=blueprint_name) + except RuntimeError: + # directory doesnt exist + pass + + def register_error_handlers(self, app): + @app.exception_handler(PreventUpdate) + async def _handle_error(request: Request, exc: PreventUpdate): + return Response(status_code=204) + + @app.exception_handler(InvalidResourceError) + async def _invalid_resources_handler(request: Request, exc: InvalidResourceError): + return Response(content=exc.args[0], status_code=404) + + def _html_response_wrapper(self, view_func): + async def wrapped(*args, **kwargs): + # If view_func is a function, call it; if it's a string, use it directly + html = view_func() if callable(view_func) else view_func + return Response(content=html, media_type="text/html") + + return wrapped + + def setup_index(self, app, dash_app): + async def index(): + return Response(content=dash_app.render_index(), media_type="text/html") + self.add_url_rule(app, "/", index, endpoint="index", methods=["GET"]) + + def setup_catchall(self, app, dash_app): + async def catchall(path: str): + return Response(content=dash_app.render_index(), media_type="text/html") + + # self.add_url_rule(app, "/{path:path}", catchall, endpoint="catchall", methods=["GET"]) + + def add_url_rule(self, app, rule, view_func, endpoint=None, methods=None): + if rule == "": + rule = "/" + if isinstance(view_func, str): + # Wrap string or sync function to async FastAPI handler + view_func = self._html_response_wrapper(view_func) + app.add_api_route(rule, view_func, methods=methods or ["GET"], name=endpoint, include_in_schema=False) + + def before_request(self, app, func): + # FastAPI does not have before_request, but we can use middleware + app.middleware("http")(self._make_before_middleware(func)) + + def after_request(self, app, func): + # FastAPI does not have after_request, but we can use middleware + app.middleware("http")(self._make_after_middleware(func)) + + def run(self, app, host, port, debug, **kwargs): + import uvicorn + reload = debug + if reload: + # Assume app is created in 'main.py' as 'app' + # Adjust 'main:app' as needed for your project structure + uvicorn.run("app:app", host=host, port=port, reload=reload, **kwargs) + else: + uvicorn.run(app, host=host, port=port, reload=reload, **kwargs) + + def make_response(self, data, mimetype=None, content_type=None): + headers = {} + if mimetype: + headers["content-type"] = mimetype + if content_type: + headers["content-type"] = content_type + return Response(content=data, headers=headers) + + def jsonify(self, obj): + return JSONResponse(content=obj) + + def get_request_adapter(self): + return FastAPIRequestAdapter + + def _make_before_middleware(self, func): + pass + async def middleware(request, call_next): + if func is not None: + if inspect.iscoroutinefunction(func): + await func() + else: + func() + response = await call_next(request) + return response + + return middleware + + def _make_after_middleware(self, func): + pass + async def middleware(request, call_next): + response = await call_next(request) + await func() + return response + return middleware + + def serve_component_suites(self, dash_app, package_name, fingerprinted_path, request): + import sys + import mimetypes + import pkgutil + from dash.fingerprint import check_fingerprint + from dash import _validate + + path_in_pkg, has_fingerprint = check_fingerprint(fingerprinted_path) + _validate.validate_js_path(dash_app.registered_paths, package_name, path_in_pkg) + extension = "." + path_in_pkg.split(".")[-1] + mimetype = mimetypes.types_map.get(extension, "application/octet-stream") + package = sys.modules[package_name] + dash_app.logger.debug( + "serving -- package: %s[%s] resource: %s => location: %s", + package_name, + package.__version__, + path_in_pkg, + package.__path__, + ) + data = pkgutil.get_data(package_name, path_in_pkg) + from starlette.responses import Response as StarletteResponse + headers = {} + if has_fingerprint: + headers["Cache-Control"] = "public, max-age=31536000" + return StarletteResponse(content=data, media_type=mimetype, headers=headers) + else: + import hashlib + etag = hashlib.md5(data).hexdigest() if data else "" + headers["ETag"] = etag + if request.headers.get("if-none-match") == etag: + return StarletteResponse(status_code=304) + return StarletteResponse(content=data, media_type=mimetype, headers=headers) + + def setup_component_suites(self, app, dash_app): + from fastapi import Request + async def serve(request: Request, package_name: str, fingerprinted_path: str): + return self.serve_component_suites(dash_app, package_name, fingerprinted_path, request) + + self.add_url_rule( + app, + "/_dash-component-suites/{package_name}/{fingerprinted_path:path}", + serve, + ) + + def dispatch(self, app, dash_app, use_async): + async def _dispatch(request: Request): + adapter = FastAPIRequestAdapter() + set_request_adapter(adapter) + adapter.set_request(request) + body = await request.json() + g = dash_app._initialize_context(body, adapter) + func = dash_app._prepare_callback(g, body) + args = dash_app._inputs_to_vals(g.inputs_list + g.states_list) + ctx = copy_context() + partial_func = dash_app._execute_callback(func, args, g.outputs_list, g) + response_data = ctx.run(partial_func) + if inspect.iscoroutine(response_data): + response_data = await response_data + # Instead of set_data, return a new Response + return Response(content=response_data, media_type="application/json") + + return _dispatch + + def _serve_default_favicon(self): + return Response( + content=pkgutil.get_data("dash", "favicon.ico"), + media_type="image/x-icon" + ) + +class FastAPIRequestAdapter: + def __init__(self): + self._request = None + + def set_request(self, request: Request): + self._request = request + + def get_args(self): + return self._request.query_params + + async def get_json(self): + return await self._request.json() + + def is_json(self): + return self._request.headers.get("content-type", "").startswith("application/json") + + def get_cookies(self, request=None): + return self._request.cookies + + def get_headers(self): + return self._request.headers + + def get_full_path(self): + return str(self._request.url) + + def get_remote_addr(self): + return self._request.client.host if self._request.client else None + + def get_origin(self): + return self._request.headers.get("origin") diff --git a/dash/server_factories/flask_factory.py b/dash/server_factories/flask_factory.py new file mode 100644 index 0000000000..82b6b266a8 --- /dev/null +++ b/dash/server_factories/flask_factory.py @@ -0,0 +1,188 @@ +import flask +from dash.exceptions import PreventUpdate, InvalidResourceError +from dash.server_factories import set_request_adapter, get_request_adapter +from .base_factory import BaseServerFactory +from contextvars import copy_context +import asyncio +import pkgutil + +class FlaskServerFactory(BaseServerFactory): + def __call__(self, server, *args, **kwargs): + # Always WSGI + return server(*args, **kwargs) + + def create_app(self, name="__main__", config=None): + app = flask.Flask(name) + if config: + app.config.update(config) + return app + + def register_assets_blueprint(self, app, blueprint_name, assets_url_path, assets_folder): + bp = flask.Blueprint( + blueprint_name, + __name__, + static_folder=assets_folder, + static_url_path=assets_url_path, + ) + app.register_blueprint(bp) + + def register_error_handlers(self, app): + @app.errorhandler(PreventUpdate) + def _handle_error(_): + return "", 204 + + @app.errorhandler(InvalidResourceError) + def _invalid_resources_handler(err): + return err.args[0], 404 + + def add_url_rule(self, app, rule, view_func, endpoint=None, methods=None): + app.add_url_rule(rule, view_func=view_func, endpoint=endpoint, methods=methods or ["GET"]) + + def before_request(self, app, func): + app.before_request(func) + + def after_request(self, app, func): + app.after_request(func) + + def run(self, app, host, port, debug, **kwargs): + app.run(host=host, port=port, debug=debug, **kwargs) + + def make_response(self, data, mimetype=None, content_type=None): + return flask.Response(data, mimetype=mimetype, content_type=content_type) + + def jsonify(self, obj): + return flask.jsonify(obj) + + def get_request_adapter(self): + return FlaskRequestAdapter + + def setup_catchall(self, app, dash_app): + def catchall(path, *args, **kwargs): + return dash_app.index(*args, **kwargs) + self.add_url_rule(app, "/", catchall, endpoint="catchall", methods=["GET"]) + + def setup_index(self, app, dash_app): + def index(*args, **kwargs): + return dash_app.render_index(dash_app, *args, **kwargs) + + self.add_url_rule(app, "/", index, endpoint="index", methods=["GET"]) + + def serve_component_suites(self, dash_app, package_name, fingerprinted_path, request=None): + import sys + import mimetypes + import pkgutil + from dash.fingerprint import check_fingerprint + from dash import _validate + import flask + + path_in_pkg, has_fingerprint = check_fingerprint(fingerprinted_path) + _validate.validate_js_path(dash_app.registered_paths, package_name, path_in_pkg) + extension = "." + path_in_pkg.split(".")[-1] + mimetype = mimetypes.types_map.get(extension, "application/octet-stream") + package = sys.modules[package_name] + dash_app.logger.debug( + "serving -- package: %s[%s] resource: %s => location: %s", + package_name, + package.__version__, + path_in_pkg, + package.__path__, + ) + data = pkgutil.get_data(package_name, path_in_pkg) + response = flask.Response(data, mimetype=mimetype) + if has_fingerprint: + response.cache_control.max_age = 31536000 # 1 year + else: + response.add_etag() + tag = response.get_etag()[0] + request_etag = flask.request.headers.get("If-None-Match") + if f'"{tag}"' == request_etag: + response = flask.Response(None, status=304) + return response + + def setup_component_suites(self, app, dash_app): + def serve(package_name, fingerprinted_path): + return self.serve_component_suites(dash_app, package_name, fingerprinted_path, flask.request) + + self.add_url_rule( + app, + "/_dash-component-suites//", + serve, + ) + + def dispatch(self, app, dash_app, use_async=False): + def _dispatch(): + adapter = FlaskRequestAdapter() + set_request_adapter(adapter) + body = flask.request.get_json() + g = dash_app._initialize_context(body, adapter) + func = dash_app._prepare_callback(g, body) + args = dash_app._inputs_to_vals(g.inputs_list + g.states_list) + ctx = copy_context() + partial_func = dash_app._execute_callback(func, args, g.outputs_list, g) + response_data = ctx.run(partial_func) + if asyncio.iscoroutine(response_data): + raise Exception( + "You are trying to use a coroutine without dash[async]. " + "Please install the dependencies via `pip install dash[async]` and ensure " + "that `use_async=False` is not being passed to the app." + ) + g.dash_response.set_data(response_data) + return g.dash_response + + async def _dispatch_async(): + adapter = FlaskRequestAdapter() + set_request_adapter(adapter) + body = flask.request.get_json() + g = dash_app._initialize_context(body, adapter) + func = dash_app._prepare_callback(g, body) + args = dash_app._inputs_to_vals(g.inputs_list + g.states_list) + ctx = copy_context() + partial_func = dash_app._execute_callback(func, args, g.outputs_list, g) + response_data = ctx.run(partial_func) + if asyncio.iscoroutine(response_data): + response_data = await response_data + g.dash_response.set_data(response_data) + return g.dash_response + + if use_async: + _dispatch = _dispatch_async + return _dispatch + + def _serve_default_favicon(): + import flask + return flask.Response( + pkgutil.get_data("dash", "favicon.ico"), content_type="image/x-icon" + ) + +class FlaskRequestAdapter: + @staticmethod + def get_args(): + return flask.request.args + + @staticmethod + def get_json(): + return flask.request.get_json() + + @staticmethod + def is_json(): + return flask.request.is_json + + @staticmethod + def get_cookies(): + return flask.request.cookies + + @staticmethod + def get_headers(): + return flask.request.headers + + @staticmethod + def get_full_path(): + return flask.request.full_path + + @staticmethod + def get_remote_addr(): + return flask.request.remote_addr + + @staticmethod + def get_origin(): + return getattr(flask.request, 'origin', None) From 9dff79140b93e65b2076e61ff821bc324a936f00 Mon Sep 17 00:00:00 2001 From: BSd3v <82055130+BSd3v@users.noreply.github.com> Date: Tue, 9 Sep 2025 15:28:31 -0400 Subject: [PATCH 003/166] fix favicon --- dash/server_factories/flask_factory.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dash/server_factories/flask_factory.py b/dash/server_factories/flask_factory.py index 82b6b266a8..bdcc6aef87 100644 --- a/dash/server_factories/flask_factory.py +++ b/dash/server_factories/flask_factory.py @@ -148,7 +148,7 @@ async def _dispatch_async(): _dispatch = _dispatch_async return _dispatch - def _serve_default_favicon(): + def _serve_default_favicon(self): import flask return flask.Response( pkgutil.get_data("dash", "favicon.ico"), content_type="image/x-icon" From c319b18a18044bbbf1c2081731feeafbeff5fd2a Mon Sep 17 00:00:00 2001 From: BSd3v <82055130+BSd3v@users.noreply.github.com> Date: Tue, 9 Sep 2025 15:29:37 -0400 Subject: [PATCH 004/166] removing changelog entry --- CHANGELOG.md | 3 --- 1 file changed, 3 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 4120ea8f93..a6a2224907 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -2,9 +2,6 @@ All notable changes to `dash` will be documented in this file. This project adheres to [Semantic Versioning](https://semver.org/). -## [bringyourownserver] -- [#3430] Adds support to bring your own server, eg (Quart, FastAPI, etc). - ## [UNRELEASED] ## Added From 7de2a41017a37ee54a7a2d6219cb4211e71b3a11 Mon Sep 17 00:00:00 2001 From: BSd3v <82055130+BSd3v@users.noreply.github.com> Date: Tue, 9 Sep 2025 15:51:45 -0400 Subject: [PATCH 005/166] fixing issue with debug true for FastAPI --- dash/server_factories/fastapi_factory.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/dash/server_factories/fastapi_factory.py b/dash/server_factories/fastapi_factory.py index 7592a51ce3..f893e61bc6 100644 --- a/dash/server_factories/fastapi_factory.py +++ b/dash/server_factories/fastapi_factory.py @@ -8,6 +8,7 @@ import inspect import pkgutil from contextvars import copy_context +import importlib.util class FastAPIServerFactory(BaseServerFactory): def __call__(self, server, *args, **kwargs): @@ -81,12 +82,15 @@ def after_request(self, app, func): app.middleware("http")(self._make_after_middleware(func)) def run(self, app, host, port, debug, **kwargs): + frame = inspect.stack()[2] import uvicorn + reload = debug if reload: - # Assume app is created in 'main.py' as 'app' - # Adjust 'main:app' as needed for your project structure - uvicorn.run("app:app", host=host, port=port, reload=reload, **kwargs) + # Dynamically determine the module name from the file path + file_path = frame.filename + module_name = importlib.util.spec_from_file_location("app", file_path).name + uvicorn.run(f"{module_name}:app.server", host=host, port=port, reload=reload, **kwargs) else: uvicorn.run(app, host=host, port=port, reload=reload, **kwargs) From 2cd769e51f2050aedcfd030feb3a2c4bed09938e Mon Sep 17 00:00:00 2001 From: BSd3v <82055130+BSd3v@users.noreply.github.com> Date: Tue, 9 Sep 2025 16:04:10 -0400 Subject: [PATCH 006/166] fixing `catchall` for path routes --- dash/server_factories/flask_factory.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dash/server_factories/flask_factory.py b/dash/server_factories/flask_factory.py index bdcc6aef87..4748fa317e 100644 --- a/dash/server_factories/flask_factory.py +++ b/dash/server_factories/flask_factory.py @@ -58,7 +58,7 @@ def get_request_adapter(self): def setup_catchall(self, app, dash_app): def catchall(path, *args, **kwargs): - return dash_app.index(*args, **kwargs) + return dash_app.render_index(*args, **kwargs) self.add_url_rule(app, "/", catchall, endpoint="catchall", methods=["GET"]) def setup_index(self, app, dash_app): From 686f32f64e45904ab13059dbf1b352df28a02601 Mon Sep 17 00:00:00 2001 From: BSd3v <82055130+BSd3v@users.noreply.github.com> Date: Tue, 9 Sep 2025 16:37:30 -0400 Subject: [PATCH 007/166] fixing pages for use with `fastapi` --- dash/_pages.py | 10 +++++----- dash/dash.py | 7 +++++-- dash/server_factories/fastapi_factory.py | 23 +++++++++++++++++++---- dash/server_factories/flask_factory.py | 10 ++++++++++ 4 files changed, 39 insertions(+), 11 deletions(-) diff --git a/dash/_pages.py b/dash/_pages.py index 45538546e8..b1cd0cbe69 100644 --- a/dash/_pages.py +++ b/dash/_pages.py @@ -389,15 +389,15 @@ def _path_to_page(path_id): return {}, None -def _page_meta_tags(app): - start_page, path_variables = _path_to_page(flask.request.path.strip("/")) +def _page_meta_tags(app, request): + request_url = request.get_path() + start_page, path_variables = _path_to_page(request_url.strip("/")) - # use the supplied image_url or create url based on image in the assets folder image = start_page.get("image", "") if image: image = app.get_asset_url(image) assets_image_url = ( - "".join([flask.request.url_root, image.lstrip("/")]) if image else None + "".join([request.url_root, image.lstrip("/")]) if image else None ) supplied_image_url = start_page.get("image_url") image_url = supplied_image_url if supplied_image_url else assets_image_url @@ -413,7 +413,7 @@ def _page_meta_tags(app): return [ {"name": "description", "content": description}, {"property": "twitter:card", "content": "summary_large_image"}, - {"property": "twitter:url", "content": flask.request.url}, + {"property": "twitter:url", "content": request_url}, {"property": "twitter:title", "content": title}, {"property": "twitter:description", "content": description}, {"property": "twitter:image", "content": image_url or ""}, diff --git a/dash/dash.py b/dash/dash.py index 56bf65c9e6..18fe56acca 100644 --- a/dash/dash.py +++ b/dash/dash.py @@ -44,6 +44,7 @@ ProxyError, DuplicateCallback, ) +from .server_factories import get_request_adapter from .version import __version__ from ._configs import get_combined_config, pathname_configs, pages_folder_config from ._utils import ( @@ -1141,9 +1142,10 @@ def render_index(self, *args, **kwargs): metas = self._generate_meta() renderer = self._generate_renderer() title = self.title + request = get_request_adapter() if self.use_pages and self.config.include_pages_meta: - metas = _page_meta_tags(self) + metas + metas = _page_meta_tags(self, request) + metas if self._favicon: favicon_mod_time = os.path.getmtime( @@ -2331,7 +2333,7 @@ def enable_pages(self) -> None: if self.pages_folder: _import_layouts_from_pages(self.config.pages_folder) - @self.server.before_request + def router(): if self._got_first_request["pages"]: return @@ -2487,5 +2489,6 @@ def update(pathname_, search_, **states): Input(_ID_STORE, "data"), ) + self.server_factory.before_request(self.server, router) def __call__(self, *args, **kwargs): return self.server_factory.__call__(self.server, *args, **kwargs) diff --git a/dash/server_factories/fastapi_factory.py b/dash/server_factories/fastapi_factory.py index f893e61bc6..aa4ff5d523 100644 --- a/dash/server_factories/fastapi_factory.py +++ b/dash/server_factories/fastapi_factory.py @@ -55,15 +55,27 @@ async def wrapped(*args, **kwargs): return wrapped def setup_index(self, app, dash_app): - async def index(): + async def index(request: Request): + adapter = FastAPIRequestAdapter() + set_request_adapter(adapter) + adapter.set_request(request) return Response(content=dash_app.render_index(), media_type="text/html") self.add_url_rule(app, "/", index, endpoint="index", methods=["GET"]) def setup_catchall(self, app, dash_app): - async def catchall(path: str): - return Response(content=dash_app.render_index(), media_type="text/html") + @dash_app.server.on_event("startup") + def _setup_catchall(): + from fastapi import Request, Response - # self.add_url_rule(app, "/{path:path}", catchall, endpoint="catchall", methods=["GET"]) + async def catchall(path: str, request: Request): + adapter = FastAPIRequestAdapter() + set_request_adapter(adapter) + adapter.set_request(request) + return Response(content=dash_app.render_index(), media_type="text/html") + + self.add_url_rule(app, "/{path:path}", catchall, endpoint="catchall", methods=["GET"]) + + pass # catchall needs to be last to not override other routes def add_url_rule(self, app, rule, view_func, endpoint=None, methods=None): if rule == "": @@ -228,3 +240,6 @@ def get_remote_addr(self): def get_origin(self): return self._request.headers.get("origin") + + def get_path(self): + return self._request.url.path # <-- Add this method diff --git a/dash/server_factories/flask_factory.py b/dash/server_factories/flask_factory.py index 4748fa317e..1c748e01ed 100644 --- a/dash/server_factories/flask_factory.py +++ b/dash/server_factories/flask_factory.py @@ -58,11 +58,17 @@ def get_request_adapter(self): def setup_catchall(self, app, dash_app): def catchall(path, *args, **kwargs): + adapter = FlaskRequestAdapter() + set_request_adapter(adapter) + adapter.set_request(flask.request) return dash_app.render_index(*args, **kwargs) self.add_url_rule(app, "/", catchall, endpoint="catchall", methods=["GET"]) def setup_index(self, app, dash_app): def index(*args, **kwargs): + adapter = FlaskRequestAdapter() + set_request_adapter(adapter) + adapter.set_request(flask.request) return dash_app.render_index(dash_app, *args, **kwargs) self.add_url_rule(app, "/", index, endpoint="index", methods=["GET"]) @@ -186,3 +192,7 @@ def get_remote_addr(): @staticmethod def get_origin(): return getattr(flask.request, 'origin', None) + + @staticmethod + def get_path(): + return flask.request.path From 660e257604bc5e95681e6b7d495830c5ed5686ac Mon Sep 17 00:00:00 2001 From: BSd3v <82055130+BSd3v@users.noreply.github.com> Date: Wed, 10 Sep 2025 10:04:52 -0400 Subject: [PATCH 008/166] fixing issue with flask pages --- dash/server_factories/flask_factory.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/dash/server_factories/flask_factory.py b/dash/server_factories/flask_factory.py index 1c748e01ed..9bc7929685 100644 --- a/dash/server_factories/flask_factory.py +++ b/dash/server_factories/flask_factory.py @@ -60,7 +60,6 @@ def setup_catchall(self, app, dash_app): def catchall(path, *args, **kwargs): adapter = FlaskRequestAdapter() set_request_adapter(adapter) - adapter.set_request(flask.request) return dash_app.render_index(*args, **kwargs) self.add_url_rule(app, "/", catchall, endpoint="catchall", methods=["GET"]) @@ -68,7 +67,6 @@ def setup_index(self, app, dash_app): def index(*args, **kwargs): adapter = FlaskRequestAdapter() set_request_adapter(adapter) - adapter.set_request(flask.request) return dash_app.render_index(dash_app, *args, **kwargs) self.add_url_rule(app, "/", index, endpoint="index", methods=["GET"]) From 0fa5c99de789f9161be40bfe72c05e4140906281 Mon Sep 17 00:00:00 2001 From: BSd3v <82055130+BSd3v@users.noreply.github.com> Date: Thu, 11 Sep 2025 10:20:21 -0400 Subject: [PATCH 009/166] fixing for lint --- dash/_pages.py | 4 +- dash/dash.py | 34 ++++++++---- dash/server_factories/__init__.py | 2 + dash/server_factories/base_factory.py | 6 ++- dash/server_factories/fastapi_factory.py | 68 +++++++++++++++++------- dash/server_factories/flask_factory.py | 26 ++++++--- 6 files changed, 102 insertions(+), 38 deletions(-) diff --git a/dash/_pages.py b/dash/_pages.py index b1cd0cbe69..2a3a116324 100644 --- a/dash/_pages.py +++ b/dash/_pages.py @@ -396,9 +396,7 @@ def _page_meta_tags(app, request): image = start_page.get("image", "") if image: image = app.get_asset_url(image) - assets_image_url = ( - "".join([request.url_root, image.lstrip("/")]) if image else None - ) + assets_image_url = "".join([request.url_root, image.lstrip("/")]) if image else None supplied_image_url = start_page.get("image_url") image_url = supplied_image_url if supplied_image_url else assets_image_url diff --git a/dash/dash.py b/dash/dash.py index 18fe56acca..f6f6e76e01 100644 --- a/dash/dash.py +++ b/dash/dash.py @@ -497,19 +497,25 @@ def __init__( # pylint: disable=too-many-statements # We have 3 cases: server is either True (we create the server), False # (defer server creation) or a Flask app instance (we use their server) - if callable(server) and not (hasattr(server, 'route') and hasattr(server, 'run')): + if callable(server) and not ( + hasattr(server, "route") and hasattr(server, "run") + ): # Server factory function self.server = server() if name is None: caller_name = getattr(self.server, "name", caller_name) - elif hasattr(server, 'route') and hasattr(server, 'run'): + elif hasattr(server, "route") and hasattr(server, "run"): self.server = server if name is None: caller_name = getattr(server, "name", caller_name) elif isinstance(server, bool): - self.server = self.server_factory.create_app(caller_name) if server else None + self.server = ( + self.server_factory.create_app(caller_name) if server else None + ) else: - raise ValueError("server must be a Flask app, a boolean, or a server factory function") + raise ValueError( + "server must be a Flask app, a boolean, or a server factory function" + ) base_prefix, routes_prefix, requests_prefix = pathname_configs( url_base_pathname, routes_pathname_prefix, requests_pathname_prefix @@ -707,6 +713,7 @@ def init_app(self, app: Optional[Any] = None, **kwargs) -> None: if config.compress: try: from flask_compress import Compress + Compress(self.server) _flask_compress_version = parse_version( _get_distribution_version("flask_compress") @@ -741,7 +748,11 @@ def _setup_routes(self): self.server_factory.setup_component_suites(self.server, self) self._add_url("_dash-layout", self.serve_layout) self._add_url("_dash-dependencies", self.dependencies) - self._add_url("_dash-update-component", self.server_factory.dispatch(self.server, self, self._use_async), ["POST"]) + self._add_url( + "_dash-update-component", + self.server_factory.dispatch(self.server, self, self._use_async), + ["POST"], + ) self._add_url("_reload-hash", self.serve_reload_hash) self._add_url("_favicon.ico", self.server_factory._serve_default_favicon) self.server_factory.setup_index(self.server, self) @@ -759,7 +770,6 @@ def _setup_routes(self): hook.data["methods"], ) - def setup_apis(self): """ Register API endpoints for all callbacks defined using `dash.callback`. @@ -1368,7 +1378,9 @@ def _initialize_context(self, body, adapter): {"prop_id": x, "value": g.input_values.get(x)} for x in body.get("changedPropIds", []) ] - g.dash_response = self.server_factory.make_response(mimetype="application/json", data=None) + g.dash_response = self.server_factory.make_response( + mimetype="application/json", data=None + ) g.cookies = dict(adapter.get_cookies()) g.headers = dict(adapter.get_headers()) g.path = adapter.get_full_path() @@ -2004,12 +2016,14 @@ def enable_dev_tools( def _wrap_errors(error): tb = _get_traceback(secret, error) return tb, 500 + elif hasattr(self.server, "exception_handler"): # FastAPI @self.server.exception_handler(Exception) async def _wrap_errors(request, error): tb = _get_traceback(secret, error) from fastapi.responses import PlainTextResponse + return PlainTextResponse(tb, status_code=500) if debug and dev_tools.ui: @@ -2325,7 +2339,9 @@ def verify_url_part(served_part, url_part, part_name): server_url=jupyter_server_url, ) else: - self.server_factory.run(self.server, host=host, port=port, debug=debug, **flask_run_options) + self.server_factory.run( + self.server, host=host, port=port, debug=debug, **flask_run_options + ) def enable_pages(self) -> None: if not self.use_pages: @@ -2333,7 +2349,6 @@ def enable_pages(self) -> None: if self.pages_folder: _import_layouts_from_pages(self.config.pages_folder) - def router(): if self._got_first_request["pages"]: return @@ -2490,5 +2505,6 @@ def update(pathname_, search_, **states): ) self.server_factory.before_request(self.server, router) + def __call__(self, *args, **kwargs): return self.server_factory.__call__(self.server, *args, **kwargs) diff --git a/dash/server_factories/__init__.py b/dash/server_factories/__init__.py index 7d9874ec7a..1bfd497935 100644 --- a/dash/server_factories/__init__.py +++ b/dash/server_factories/__init__.py @@ -3,8 +3,10 @@ _request_adapter_var = contextvars.ContextVar("request_adapter") + def set_request_adapter(adapter): _request_adapter_var.set(adapter) + def get_request_adapter(): return _request_adapter_var.get() diff --git a/dash/server_factories/base_factory.py b/dash/server_factories/base_factory.py index f429357e03..b44f6888cb 100644 --- a/dash/server_factories/base_factory.py +++ b/dash/server_factories/base_factory.py @@ -1,5 +1,6 @@ from abc import ABC, abstractmethod + class BaseServerFactory(ABC): def __call__(self, server, *args, **kwargs): # Default: WSGI @@ -10,7 +11,9 @@ def create_app(self, name="__main__", config=None): pass @abstractmethod - def register_assets_blueprint(self, app, blueprint_name, assets_url_path, assets_folder): + def register_assets_blueprint( + self, app, blueprint_name, assets_url_path, assets_folder + ): pass @abstractmethod @@ -44,4 +47,3 @@ def jsonify(self, obj): @abstractmethod def get_request_adapter(self): pass - diff --git a/dash/server_factories/fastapi_factory.py b/dash/server_factories/fastapi_factory.py index aa4ff5d523..8d9efb2416 100644 --- a/dash/server_factories/fastapi_factory.py +++ b/dash/server_factories/fastapi_factory.py @@ -10,18 +10,14 @@ from contextvars import copy_context import importlib.util + class FastAPIServerFactory(BaseServerFactory): def __call__(self, server, *args, **kwargs): # ASGI: (scope, receive, send) - if ( - len(args) == 3 - and isinstance(args[0], dict) - and "type" in args[0] - ): + if len(args) == 3 and isinstance(args[0], dict) and "type" in args[0]: return server(*args, **kwargs) raise TypeError("FastAPI app must be called with (scope, receive, send)") - def create_app(self, name="__main__", config=None): app = FastAPI() if config: @@ -29,10 +25,17 @@ def create_app(self, name="__main__", config=None): setattr(app.state, key, value) return app - def register_assets_blueprint(self, app, blueprint_name, assets_url_path, assets_folder): + def register_assets_blueprint( + self, app, blueprint_name, assets_url_path, assets_folder + ): from fastapi.staticfiles import StaticFiles + try: - app.mount(assets_url_path, StaticFiles(directory=assets_folder), name=blueprint_name) + app.mount( + assets_url_path, + StaticFiles(directory=assets_folder), + name=blueprint_name, + ) except RuntimeError: # directory doesnt exist pass @@ -43,7 +46,9 @@ async def _handle_error(request: Request, exc: PreventUpdate): return Response(status_code=204) @app.exception_handler(InvalidResourceError) - async def _invalid_resources_handler(request: Request, exc: InvalidResourceError): + async def _invalid_resources_handler( + request: Request, exc: InvalidResourceError + ): return Response(content=exc.args[0], status_code=404) def _html_response_wrapper(self, view_func): @@ -60,6 +65,7 @@ async def index(request: Request): set_request_adapter(adapter) adapter.set_request(request) return Response(content=dash_app.render_index(), media_type="text/html") + self.add_url_rule(app, "/", index, endpoint="index", methods=["GET"]) def setup_catchall(self, app, dash_app): @@ -73,9 +79,11 @@ async def catchall(path: str, request: Request): adapter.set_request(request) return Response(content=dash_app.render_index(), media_type="text/html") - self.add_url_rule(app, "/{path:path}", catchall, endpoint="catchall", methods=["GET"]) + self.add_url_rule( + app, "/{path:path}", catchall, endpoint="catchall", methods=["GET"] + ) - pass # catchall needs to be last to not override other routes + pass # catchall needs to be last to not override other routes def add_url_rule(self, app, rule, view_func, endpoint=None, methods=None): if rule == "": @@ -83,7 +91,13 @@ def add_url_rule(self, app, rule, view_func, endpoint=None, methods=None): if isinstance(view_func, str): # Wrap string or sync function to async FastAPI handler view_func = self._html_response_wrapper(view_func) - app.add_api_route(rule, view_func, methods=methods or ["GET"], name=endpoint, include_in_schema=False) + app.add_api_route( + rule, + view_func, + methods=methods or ["GET"], + name=endpoint, + include_in_schema=False, + ) def before_request(self, app, func): # FastAPI does not have before_request, but we can use middleware @@ -102,7 +116,13 @@ def run(self, app, host, port, debug, **kwargs): # Dynamically determine the module name from the file path file_path = frame.filename module_name = importlib.util.spec_from_file_location("app", file_path).name - uvicorn.run(f"{module_name}:app.server", host=host, port=port, reload=reload, **kwargs) + uvicorn.run( + f"{module_name}:app.server", + host=host, + port=port, + reload=reload, + **kwargs, + ) else: uvicorn.run(app, host=host, port=port, reload=reload, **kwargs) @@ -122,6 +142,7 @@ def get_request_adapter(self): def _make_before_middleware(self, func): pass + async def middleware(request, call_next): if func is not None: if inspect.iscoroutinefunction(func): @@ -135,13 +156,17 @@ async def middleware(request, call_next): def _make_after_middleware(self, func): pass + async def middleware(request, call_next): response = await call_next(request) await func() return response + return middleware - def serve_component_suites(self, dash_app, package_name, fingerprinted_path, request): + def serve_component_suites( + self, dash_app, package_name, fingerprinted_path, request + ): import sys import mimetypes import pkgutil @@ -162,12 +187,14 @@ def serve_component_suites(self, dash_app, package_name, fingerprinted_path, req ) data = pkgutil.get_data(package_name, path_in_pkg) from starlette.responses import Response as StarletteResponse + headers = {} if has_fingerprint: headers["Cache-Control"] = "public, max-age=31536000" return StarletteResponse(content=data, media_type=mimetype, headers=headers) else: import hashlib + etag = hashlib.md5(data).hexdigest() if data else "" headers["ETag"] = etag if request.headers.get("if-none-match") == etag: @@ -176,8 +203,11 @@ def serve_component_suites(self, dash_app, package_name, fingerprinted_path, req def setup_component_suites(self, app, dash_app): from fastapi import Request + async def serve(request: Request, package_name: str, fingerprinted_path: str): - return self.serve_component_suites(dash_app, package_name, fingerprinted_path, request) + return self.serve_component_suites( + dash_app, package_name, fingerprinted_path, request + ) self.add_url_rule( app, @@ -206,10 +236,10 @@ async def _dispatch(request: Request): def _serve_default_favicon(self): return Response( - content=pkgutil.get_data("dash", "favicon.ico"), - media_type="image/x-icon" + content=pkgutil.get_data("dash", "favicon.ico"), media_type="image/x-icon" ) + class FastAPIRequestAdapter: def __init__(self): self._request = None @@ -224,7 +254,9 @@ async def get_json(self): return await self._request.json() def is_json(self): - return self._request.headers.get("content-type", "").startswith("application/json") + return self._request.headers.get("content-type", "").startswith( + "application/json" + ) def get_cookies(self, request=None): return self._request.cookies diff --git a/dash/server_factories/flask_factory.py b/dash/server_factories/flask_factory.py index 9bc7929685..5eaaf44a36 100644 --- a/dash/server_factories/flask_factory.py +++ b/dash/server_factories/flask_factory.py @@ -6,6 +6,7 @@ import asyncio import pkgutil + class FlaskServerFactory(BaseServerFactory): def __call__(self, server, *args, **kwargs): # Always WSGI @@ -17,7 +18,9 @@ def create_app(self, name="__main__", config=None): app.config.update(config) return app - def register_assets_blueprint(self, app, blueprint_name, assets_url_path, assets_folder): + def register_assets_blueprint( + self, app, blueprint_name, assets_url_path, assets_folder + ): bp = flask.Blueprint( blueprint_name, __name__, @@ -36,7 +39,9 @@ def _invalid_resources_handler(err): return err.args[0], 404 def add_url_rule(self, app, rule, view_func, endpoint=None, methods=None): - app.add_url_rule(rule, view_func=view_func, endpoint=endpoint, methods=methods or ["GET"]) + app.add_url_rule( + rule, view_func=view_func, endpoint=endpoint, methods=methods or ["GET"] + ) def before_request(self, app, func): app.before_request(func) @@ -61,7 +66,10 @@ def catchall(path, *args, **kwargs): adapter = FlaskRequestAdapter() set_request_adapter(adapter) return dash_app.render_index(*args, **kwargs) - self.add_url_rule(app, "/", catchall, endpoint="catchall", methods=["GET"]) + + self.add_url_rule( + app, "/", catchall, endpoint="catchall", methods=["GET"] + ) def setup_index(self, app, dash_app): def index(*args, **kwargs): @@ -71,7 +79,9 @@ def index(*args, **kwargs): self.add_url_rule(app, "/", index, endpoint="index", methods=["GET"]) - def serve_component_suites(self, dash_app, package_name, fingerprinted_path, request=None): + def serve_component_suites( + self, dash_app, package_name, fingerprinted_path, request=None + ): import sys import mimetypes import pkgutil @@ -105,7 +115,9 @@ def serve_component_suites(self, dash_app, package_name, fingerprinted_path, req def setup_component_suites(self, app, dash_app): def serve(package_name, fingerprinted_path): - return self.serve_component_suites(dash_app, package_name, fingerprinted_path, flask.request) + return self.serve_component_suites( + dash_app, package_name, fingerprinted_path, flask.request + ) self.add_url_rule( app, @@ -154,10 +166,12 @@ async def _dispatch_async(): def _serve_default_favicon(self): import flask + return flask.Response( pkgutil.get_data("dash", "favicon.ico"), content_type="image/x-icon" ) + class FlaskRequestAdapter: @staticmethod def get_args(): @@ -189,7 +203,7 @@ def get_remote_addr(): @staticmethod def get_origin(): - return getattr(flask.request, 'origin', None) + return getattr(flask.request, "origin", None) @staticmethod def get_path(): From 1088331323380bd58c27f3776d21d4e4132bfd3d Mon Sep 17 00:00:00 2001 From: BSd3v <82055130+BSd3v@users.noreply.github.com> Date: Thu, 11 Sep 2025 13:33:16 -0400 Subject: [PATCH 010/166] fixing issue with failing test due to `endpoint` name --- dash/server_factories/flask_factory.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dash/server_factories/flask_factory.py b/dash/server_factories/flask_factory.py index 5eaaf44a36..69173516b2 100644 --- a/dash/server_factories/flask_factory.py +++ b/dash/server_factories/flask_factory.py @@ -77,7 +77,7 @@ def index(*args, **kwargs): set_request_adapter(adapter) return dash_app.render_index(dash_app, *args, **kwargs) - self.add_url_rule(app, "/", index, endpoint="index", methods=["GET"]) + self.add_url_rule(app, "/", index, endpoint="/", methods=["GET"]) def serve_component_suites( self, dash_app, package_name, fingerprinted_path, request=None From 4920e33cf68651bebc699f72f6da1e8eadbf925e Mon Sep 17 00:00:00 2001 From: BSd3v <82055130+BSd3v@users.noreply.github.com> Date: Thu, 11 Sep 2025 13:33:43 -0400 Subject: [PATCH 011/166] fixing `run` command to trigger `devtools` properly --- dash/server_factories/fastapi_factory.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/dash/server_factories/fastapi_factory.py b/dash/server_factories/fastapi_factory.py index 8d9efb2416..1fa07a6ac6 100644 --- a/dash/server_factories/fastapi_factory.py +++ b/dash/server_factories/fastapi_factory.py @@ -12,6 +12,10 @@ class FastAPIServerFactory(BaseServerFactory): + def __init__(self): + self.config = {} + super().__init__() + def __call__(self, server, *args, **kwargs): # ASGI: (scope, receive, send) if len(args) == 3 and isinstance(args[0], dict) and "type" in args[0]: @@ -71,6 +75,7 @@ async def index(request: Request): def setup_catchall(self, app, dash_app): @dash_app.server.on_event("startup") def _setup_catchall(): + dash_app.enable_dev_tools(**self.config) # do this to make sure dev tools are enabled from fastapi import Request, Response async def catchall(path: str, request: Request): @@ -111,6 +116,9 @@ def run(self, app, host, port, debug, **kwargs): frame = inspect.stack()[2] import uvicorn + self.config = dict({'debug': debug} if debug else {}, **kwargs) + + reload = debug if reload: # Dynamically determine the module name from the file path From 9ffba5a58652cc7b28d253e46662ad8cbe0fb8bd Mon Sep 17 00:00:00 2001 From: BSd3v <82055130+BSd3v@users.noreply.github.com> Date: Thu, 11 Sep 2025 14:51:02 -0400 Subject: [PATCH 012/166] fixing issue with lint and debug ui --- dash/dash.py | 69 +++--------- dash/server_factories/fastapi_factory.py | 136 ++++++++++++++--------- dash/server_factories/flask_factory.py | 68 ++++++++---- 3 files changed, 144 insertions(+), 129 deletions(-) diff --git a/dash/dash.py b/dash/dash.py index f6f6e76e01..2151f31f77 100644 --- a/dash/dash.py +++ b/dash/dash.py @@ -4,7 +4,6 @@ import collections import importlib import warnings -from contextvars import copy_context from importlib.machinery import ModuleSpec from importlib.util import find_spec from importlib import metadata @@ -12,12 +11,10 @@ import threading import re import logging -import time import mimetypes import hashlib import base64 import traceback -import inspect from urllib.parse import urlparse from typing import Any, Callable, Dict, Optional, Union, Sequence, Literal, List @@ -30,7 +27,7 @@ from dash import html from dash import dash_table -from .fingerprint import build_fingerprint, check_fingerprint +from .fingerprint import build_fingerprint from .resources import Scripts, Css from .dependencies import ( Input, @@ -39,8 +36,6 @@ ) from .development.base_component import ComponentRegistry from .exceptions import ( - PreventUpdate, - InvalidResourceError, ProxyError, DuplicateCallback, ) @@ -72,7 +67,7 @@ from .server_factories.flask_factory import FlaskServerFactory from .server_factories.base_factory import BaseServerFactory -from ._get_app import with_app_context, with_app_context_async, with_app_context_factory +from ._get_app import with_app_context, with_app_context_factory from ._grouping import map_grouping, grouping_len, update_args_group from ._obsolete import ObsoleteChecker @@ -712,8 +707,9 @@ def init_app(self, app: Optional[Any] = None, **kwargs) -> None: ) if config.compress: try: - from flask_compress import Compress + import flask_compress # pylint: disable=import-outside-toplevel + Compress = flask_compress.Compress Compress(self.server) _flask_compress_version = parse_version( _get_distribution_version("flask_compress") @@ -754,7 +750,10 @@ def _setup_routes(self): ["POST"], ) self._add_url("_reload-hash", self.serve_reload_hash) - self._add_url("_favicon.ico", self.server_factory._serve_default_favicon) + self._add_url( + "_favicon.ico", + self.server_factory._serve_default_favicon, # pylint: disable=protected-access + ) self.server_factory.setup_index(self.server, self) self.server_factory.setup_catchall(self.server, self) @@ -1145,7 +1144,7 @@ def _generate_meta(self): return meta_tags + self.config.meta_tags - def render_index(self, *args, **kwargs): + def render_index(self, *_args, **_kwargs): scripts = self._generate_scripts_html() css = self._generate_css_dist_html() config = self._generate_config_html() @@ -1845,6 +1844,7 @@ def enable_dev_tools( dev_tools_silence_routes_logging: Optional[bool] = None, dev_tools_disable_version_check: Optional[bool] = None, dev_tools_prune_errors: Optional[bool] = None, + first_run: bool = True, ) -> bool: """Activate the dev tools, called by `run`. If your application is served by wsgi and you want to activate the dev tools, you can call @@ -2009,53 +2009,12 @@ def enable_dev_tools( ) elif dev_tools.prune_errors: secret = gen_salt(20) - - if hasattr(self.server, "errorhandler"): - # Flask - @self.server.errorhandler(Exception) - def _wrap_errors(error): - tb = _get_traceback(secret, error) - return tb, 500 - - elif hasattr(self.server, "exception_handler"): - # FastAPI - @self.server.exception_handler(Exception) - async def _wrap_errors(request, error): - tb = _get_traceback(secret, error) - from fastapi.responses import PlainTextResponse - - return PlainTextResponse(tb, status_code=500) + self.server_factory.register_prune_error_handler( + self.server, secret, _get_traceback + ) if debug and dev_tools.ui: - - def _before_request(): - flask.g.timing_information = { # pylint: disable=assigning-non-slot - "__dash_server": {"dur": time.time(), "desc": None} - } - - def _after_request(response): - timing_information = flask.g.get("timing_information", None) - if timing_information is None: - return response - - dash_total = timing_information.get("__dash_server", None) - if dash_total is not None: - dash_total["dur"] = round((time.time() - dash_total["dur"]) * 1000) - - for name, info in timing_information.items(): - value = name - if info.get("desc") is not None: - value += f';desc="{info["desc"]}"' - - if info.get("dur") is not None: - value += f";dur={info['dur']}" - - response.headers.add("Server-Timing", value) - - return response - - self.server_factory.before_request(self.server, _before_request) - self.server_factory.after_request(self.server, _after_request) + self.server_factory.register_timing_hooks(self.server, first_run) if ( debug diff --git a/dash/server_factories/fastapi_factory.py b/dash/server_factories/fastapi_factory.py index 1fa07a6ac6..918ca2175f 100644 --- a/dash/server_factories/fastapi_factory.py +++ b/dash/server_factories/fastapi_factory.py @@ -1,14 +1,22 @@ -import traceback - -from fastapi import FastAPI, Request, Response, APIRouter -from fastapi.responses import JSONResponse -from dash.exceptions import PreventUpdate, InvalidResourceError -from dash.server_factories import set_request_adapter, get_request_adapter -from .base_factory import BaseServerFactory +import sys +import mimetypes +import hashlib import inspect import pkgutil from contextvars import copy_context import importlib.util +import time +import uvicorn +from fastapi import FastAPI, Request, Response +from fastapi.responses import JSONResponse, PlainTextResponse +from fastapi.staticfiles import StaticFiles +from starlette.responses import Response as StarletteResponse +from starlette.datastructures import MutableHeaders +from dash.fingerprint import check_fingerprint +from dash import _validate +from dash.exceptions import PreventUpdate, InvalidResourceError +from dash.server_factories import set_request_adapter +from .base_factory import BaseServerFactory class FastAPIServerFactory(BaseServerFactory): @@ -32,8 +40,6 @@ def create_app(self, name="__main__", config=None): def register_assets_blueprint( self, app, blueprint_name, assets_url_path, assets_folder ): - from fastapi.staticfiles import StaticFiles - try: app.mount( assets_url_path, @@ -46,17 +52,21 @@ def register_assets_blueprint( def register_error_handlers(self, app): @app.exception_handler(PreventUpdate) - async def _handle_error(request: Request, exc: PreventUpdate): + async def _handle_error(_request, _exc): return Response(status_code=204) @app.exception_handler(InvalidResourceError) - async def _invalid_resources_handler( - request: Request, exc: InvalidResourceError - ): + async def _invalid_resources_handler(_request, exc): return Response(content=exc.args[0], status_code=404) + def register_prune_error_handler(self, app, secret, get_traceback_func): + @app.exception_handler(Exception) + async def _wrap_errors(_error_request, error): + tb = get_traceback_func(secret, error) + return PlainTextResponse(tb, status_code=500) + def _html_response_wrapper(self, view_func): - async def wrapped(*args, **kwargs): + async def wrapped(*_args, **_kwargs): # If view_func is a function, call it; if it's a string, use it directly html = view_func() if callable(view_func) else view_func return Response(content=html, media_type="text/html") @@ -75,10 +85,11 @@ async def index(request: Request): def setup_catchall(self, app, dash_app): @dash_app.server.on_event("startup") def _setup_catchall(): - dash_app.enable_dev_tools(**self.config) # do this to make sure dev tools are enabled - from fastapi import Request, Response + dash_app.enable_dev_tools( + **self.config, first_run=False + ) # do this to make sure dev tools are enabled - async def catchall(path: str, request: Request): + async def catchall(_path: str, request: Request): adapter = FastAPIRequestAdapter() set_request_adapter(adapter) adapter.set_request(request) @@ -88,8 +99,6 @@ async def catchall(path: str, request: Request): app, "/{path:path}", catchall, endpoint="catchall", methods=["GET"] ) - pass # catchall needs to be last to not override other routes - def add_url_rule(self, app, rule, view_func, endpoint=None, methods=None): if rule == "": rule = "/" @@ -114,11 +123,7 @@ def after_request(self, app, func): def run(self, app, host, port, debug, **kwargs): frame = inspect.stack()[2] - import uvicorn - - self.config = dict({'debug': debug} if debug else {}, **kwargs) - - + self.config = dict({"debug": debug} if debug else {}, **kwargs) reload = debug if reload: # Dynamically determine the module name from the file path @@ -149,8 +154,6 @@ def get_request_adapter(self): return FastAPIRequestAdapter def _make_before_middleware(self, func): - pass - async def middleware(request, call_next): if func is not None: if inspect.iscoroutinefunction(func): @@ -163,11 +166,13 @@ async def middleware(request, call_next): return middleware def _make_after_middleware(self, func): - pass - async def middleware(request, call_next): response = await call_next(request) - await func() + if func is not None: + if inspect.iscoroutinefunction(func): + await func() + else: + func() return response return middleware @@ -175,12 +180,6 @@ async def middleware(request, call_next): def serve_component_suites( self, dash_app, package_name, fingerprinted_path, request ): - import sys - import mimetypes - import pkgutil - from dash.fingerprint import check_fingerprint - from dash import _validate - path_in_pkg, has_fingerprint = check_fingerprint(fingerprinted_path) _validate.validate_js_path(dash_app.registered_paths, package_name, path_in_pkg) extension = "." + path_in_pkg.split(".")[-1] @@ -194,24 +193,17 @@ def serve_component_suites( package.__path__, ) data = pkgutil.get_data(package_name, path_in_pkg) - from starlette.responses import Response as StarletteResponse - headers = {} if has_fingerprint: headers["Cache-Control"] = "public, max-age=31536000" return StarletteResponse(content=data, media_type=mimetype, headers=headers) - else: - import hashlib - - etag = hashlib.md5(data).hexdigest() if data else "" - headers["ETag"] = etag - if request.headers.get("if-none-match") == etag: - return StarletteResponse(status_code=304) - return StarletteResponse(content=data, media_type=mimetype, headers=headers) + etag = hashlib.md5(data).hexdigest() if data else "" + headers["ETag"] = etag + if request.headers.get("if-none-match") == etag: + return StarletteResponse(status_code=304) + return StarletteResponse(content=data, media_type=mimetype, headers=headers) def setup_component_suites(self, app, dash_app): - from fastapi import Request - async def serve(request: Request, package_name: str, fingerprinted_path: str): return self.serve_component_suites( dash_app, package_name, fingerprinted_path, request @@ -223,17 +215,26 @@ async def serve(request: Request, package_name: str, fingerprinted_path: str): serve, ) - def dispatch(self, app, dash_app, use_async): + def dispatch(self, _app, dash_app, _use_async): async def _dispatch(request: Request): adapter = FastAPIRequestAdapter() set_request_adapter(adapter) adapter.set_request(request) + # pylint: disable=protected-access body = await request.json() - g = dash_app._initialize_context(body, adapter) - func = dash_app._prepare_callback(g, body) - args = dash_app._inputs_to_vals(g.inputs_list + g.states_list) + g = dash_app._initialize_context( + body, adapter + ) # pylint: disable=protected-access + func = dash_app._prepare_callback( + g, body + ) # pylint: disable=protected-access + args = dash_app._inputs_to_vals( + g.inputs_list + g.states_list + ) # pylint: disable=protected-access ctx = copy_context() - partial_func = dash_app._execute_callback(func, args, g.outputs_list, g) + partial_func = dash_app._execute_callback( + func, args, g.outputs_list, g + ) # pylint: disable=protected-access response_data = ctx.run(partial_func) if inspect.iscoroutine(response_data): response_data = await response_data @@ -247,6 +248,33 @@ def _serve_default_favicon(self): content=pkgutil.get_data("dash", "favicon.ico"), media_type="image/x-icon" ) + def register_timing_hooks(self, app, first_run): + if not first_run: + return + + @app.middleware("http") + async def timing_middleware(request, call_next): + # Before request + request.state.timing_information = { + "__dash_server": {"dur": time.time(), "desc": None} + } + response = await call_next(request) + # After request + timing_information = getattr(request.state, "timing_information", None) + if timing_information is not None: + dash_total = timing_information.get("__dash_server", None) + if dash_total is not None: + dash_total["dur"] = round((time.time() - dash_total["dur"]) * 1000) + headers = MutableHeaders(response.headers) + for name, info in timing_information.items(): + value = name + if info.get("desc") is not None: + value += f';desc="{info["desc"]}"' + if info.get("dur") is not None: + value += f";dur={info['dur']}" + headers.append("Server-Timing", value) + return response + class FastAPIRequestAdapter: def __init__(self): @@ -266,7 +294,7 @@ def is_json(self): "application/json" ) - def get_cookies(self, request=None): + def get_cookies(self, _request=None): return self._request.cookies def get_headers(self): @@ -282,4 +310,4 @@ def get_origin(self): return self._request.headers.get("origin") def get_path(self): - return self._request.url.path # <-- Add this method + return self._request.url.path diff --git a/dash/server_factories/flask_factory.py b/dash/server_factories/flask_factory.py index 69173516b2..dafa4b24b4 100644 --- a/dash/server_factories/flask_factory.py +++ b/dash/server_factories/flask_factory.py @@ -1,10 +1,15 @@ -import flask -from dash.exceptions import PreventUpdate, InvalidResourceError -from dash.server_factories import set_request_adapter, get_request_adapter -from .base_factory import BaseServerFactory from contextvars import copy_context import asyncio import pkgutil +import sys +import mimetypes +import time +import flask +from dash.fingerprint import check_fingerprint +from dash import _validate +from dash.exceptions import PreventUpdate, InvalidResourceError +from dash.server_factories import set_request_adapter +from .base_factory import BaseServerFactory class FlaskServerFactory(BaseServerFactory): @@ -38,6 +43,12 @@ def _handle_error(_): def _invalid_resources_handler(err): return err.args[0], 404 + def register_prune_error_handler(self, app, secret, get_traceback_func): + @app.errorhandler(Exception) + def _wrap_errors(error): + tb = get_traceback_func(secret, error) + return tb, 500 + def add_url_rule(self, app, rule, view_func, endpoint=None, methods=None): app.add_url_rule( rule, view_func=view_func, endpoint=endpoint, methods=methods or ["GET"] @@ -62,7 +73,7 @@ def get_request_adapter(self): return FlaskRequestAdapter def setup_catchall(self, app, dash_app): - def catchall(path, *args, **kwargs): + def catchall(_path, *args, **kwargs): adapter = FlaskRequestAdapter() set_request_adapter(adapter) return dash_app.render_index(*args, **kwargs) @@ -75,20 +86,11 @@ def setup_index(self, app, dash_app): def index(*args, **kwargs): adapter = FlaskRequestAdapter() set_request_adapter(adapter) - return dash_app.render_index(dash_app, *args, **kwargs) + return dash_app.render_index(*args, **kwargs) self.add_url_rule(app, "/", index, endpoint="/", methods=["GET"]) - def serve_component_suites( - self, dash_app, package_name, fingerprinted_path, request=None - ): - import sys - import mimetypes - import pkgutil - from dash.fingerprint import check_fingerprint - from dash import _validate - import flask - + def serve_component_suites(self, dash_app, package_name, fingerprinted_path): path_in_pkg, has_fingerprint = check_fingerprint(fingerprinted_path) _validate.validate_js_path(dash_app.registered_paths, package_name, path_in_pkg) extension = "." + path_in_pkg.split(".")[-1] @@ -116,7 +118,7 @@ def serve_component_suites( def setup_component_suites(self, app, dash_app): def serve(package_name, fingerprinted_path): return self.serve_component_suites( - dash_app, package_name, fingerprinted_path, flask.request + dash_app, package_name, fingerprinted_path ) self.add_url_rule( @@ -125,11 +127,12 @@ def serve(package_name, fingerprinted_path): serve, ) - def dispatch(self, app, dash_app, use_async=False): + def dispatch(self, _app, dash_app, use_async=False): def _dispatch(): adapter = FlaskRequestAdapter() set_request_adapter(adapter) body = flask.request.get_json() + # pylint: disable=protected-access g = dash_app._initialize_context(body, adapter) func = dash_app._prepare_callback(g, body) args = dash_app._inputs_to_vals(g.inputs_list + g.states_list) @@ -149,6 +152,7 @@ async def _dispatch_async(): adapter = FlaskRequestAdapter() set_request_adapter(adapter) body = flask.request.get_json() + # pylint: disable=protected-access g = dash_app._initialize_context(body, adapter) func = dash_app._prepare_callback(g, body) args = dash_app._inputs_to_vals(g.inputs_list + g.states_list) @@ -161,16 +165,40 @@ async def _dispatch_async(): return g.dash_response if use_async: - _dispatch = _dispatch_async + return _dispatch_async return _dispatch def _serve_default_favicon(self): - import flask return flask.Response( pkgutil.get_data("dash", "favicon.ico"), content_type="image/x-icon" ) + def register_timing_hooks(self, app, _first_run): + def _before_request(): + flask.g.timing_information = { + "__dash_server": {"dur": time.time(), "desc": None} + } + + def _after_request(response): + timing_information = flask.g.get("timing_information", None) + if timing_information is None: + return response + dash_total = timing_information.get("__dash_server", None) + if dash_total is not None: + dash_total["dur"] = round((time.time() - dash_total["dur"]) * 1000) + for name, info in timing_information.items(): + value = name + if info.get("desc") is not None: + value += f';desc="{info["desc"]}"' + if info.get("dur") is not None: + value += f";dur={info['dur']}" + response.headers.add("Server-Timing", value) + return response + + self.before_request(app, _before_request) + self.after_request(app, _after_request) + class FlaskRequestAdapter: @staticmethod From 908aacd729695fd2ef8d79a6343d0ef21b6cea84 Mon Sep 17 00:00:00 2001 From: BSd3v <82055130+BSd3v@users.noreply.github.com> Date: Thu, 11 Sep 2025 15:39:56 -0400 Subject: [PATCH 013/166] fixing issue with `_app` when using dispatch, need to keep in context --- dash/server_factories/fastapi_factory.py | 4 +++- dash/server_factories/flask_factory.py | 4 +++- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/dash/server_factories/fastapi_factory.py b/dash/server_factories/fastapi_factory.py index 918ca2175f..ff21e61d72 100644 --- a/dash/server_factories/fastapi_factory.py +++ b/dash/server_factories/fastapi_factory.py @@ -215,7 +215,9 @@ async def serve(request: Request, package_name: str, fingerprinted_path: str): serve, ) - def dispatch(self, _app, dash_app, _use_async): + def dispatch( + self, app, dash_app, use_async=False + ): # pylint: disable=unused-argument async def _dispatch(request: Request): adapter = FastAPIRequestAdapter() set_request_adapter(adapter) diff --git a/dash/server_factories/flask_factory.py b/dash/server_factories/flask_factory.py index dafa4b24b4..b16135cfff 100644 --- a/dash/server_factories/flask_factory.py +++ b/dash/server_factories/flask_factory.py @@ -127,7 +127,9 @@ def serve(package_name, fingerprinted_path): serve, ) - def dispatch(self, _app, dash_app, use_async=False): + def dispatch( + self, app, dash_app, use_async=False + ): # pylint: disable=unused-argument def _dispatch(): adapter = FlaskRequestAdapter() set_request_adapter(adapter) From 9491c7fbfbc637029092413ffee155f56bcf4988 Mon Sep 17 00:00:00 2001 From: BSd3v <82055130+BSd3v@users.noreply.github.com> Date: Thu, 11 Sep 2025 15:58:55 -0400 Subject: [PATCH 014/166] fixing issue with catchall --- dash/server_factories/fastapi_factory.py | 2 +- dash/server_factories/flask_factory.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/dash/server_factories/fastapi_factory.py b/dash/server_factories/fastapi_factory.py index ff21e61d72..0853972d1f 100644 --- a/dash/server_factories/fastapi_factory.py +++ b/dash/server_factories/fastapi_factory.py @@ -89,7 +89,7 @@ def _setup_catchall(): **self.config, first_run=False ) # do this to make sure dev tools are enabled - async def catchall(_path: str, request: Request): + async def catchall(request: Request): adapter = FastAPIRequestAdapter() set_request_adapter(adapter) adapter.set_request(request) diff --git a/dash/server_factories/flask_factory.py b/dash/server_factories/flask_factory.py index b16135cfff..bb1204af19 100644 --- a/dash/server_factories/flask_factory.py +++ b/dash/server_factories/flask_factory.py @@ -73,7 +73,7 @@ def get_request_adapter(self): return FlaskRequestAdapter def setup_catchall(self, app, dash_app): - def catchall(_path, *args, **kwargs): + def catchall(*args, **kwargs): adapter = FlaskRequestAdapter() set_request_adapter(adapter) return dash_app.render_index(*args, **kwargs) From 39ad7bd9c837699393e05ebdcc12c0c95119bc8f Mon Sep 17 00:00:00 2001 From: BSd3v <82055130+BSd3v@users.noreply.github.com> Date: Thu, 11 Sep 2025 16:18:17 -0400 Subject: [PATCH 015/166] fixing issue with args and cancelling callbacks --- dash/_callback_context.py | 8 ++++++++ dash/dash.py | 4 +++- 2 files changed, 11 insertions(+), 1 deletion(-) diff --git a/dash/_callback_context.py b/dash/_callback_context.py index f64865c464..72b92e09e2 100644 --- a/dash/_callback_context.py +++ b/dash/_callback_context.py @@ -288,6 +288,14 @@ def path(self): """ return _get_from_context("path", "") + @property + @has_context + def args(self): + """ + Query parameters of the callback request as a dictionary-like object. + """ + return _get_from_context("args", "") + @property @has_context def remote(self): diff --git a/dash/dash.py b/dash/dash.py index 2151f31f77..d20672453c 100644 --- a/dash/dash.py +++ b/dash/dash.py @@ -70,6 +70,7 @@ from ._get_app import with_app_context, with_app_context_factory from ._grouping import map_grouping, grouping_len, update_args_group from ._obsolete import ObsoleteChecker +from ._callback_context import callback_context from . import _pages from ._pages import ( @@ -1382,6 +1383,7 @@ def _initialize_context(self, body, adapter): ) g.cookies = dict(adapter.get_cookies()) g.headers = dict(adapter.get_headers()) + g.args = adapter.get_args() g.path = adapter.get_full_path() g.remote = adapter.get_remote_addr() g.origin = adapter.get_origin() @@ -1529,7 +1531,7 @@ def _setup_server(self): manager=manager, ) def cancel_call(*_): - job_ids = flask.request.args.getlist("cancelJob") + job_ids = callback_context.args.getlist("cancelJob") executor = _callback.context_value.get().background_callback_manager if job_ids: for job_id in job_ids: From 7bf69a7583e1b216de132d6a622ded05d85f1ce8 Mon Sep 17 00:00:00 2001 From: BSd3v <82055130+BSd3v@users.noreply.github.com> Date: Thu, 11 Sep 2025 16:33:33 -0400 Subject: [PATCH 016/166] fixing issues with pages metadata and flaky tests --- dash/_pages.py | 4 +-- dash/server_factories/fastapi_factory.py | 28 +++++++++++++++---- dash/server_factories/flask_factory.py | 8 ++++++ .../multi_page/test_pages_relative_path.py | 3 +- 4 files changed, 34 insertions(+), 9 deletions(-) diff --git a/dash/_pages.py b/dash/_pages.py index 2a3a116324..3fab86eb99 100644 --- a/dash/_pages.py +++ b/dash/_pages.py @@ -396,7 +396,7 @@ def _page_meta_tags(app, request): image = start_page.get("image", "") if image: image = app.get_asset_url(image) - assets_image_url = "".join([request.url_root, image.lstrip("/")]) if image else None + assets_image_url = "".join([request.get_root(), image.lstrip("/")]) if image else None supplied_image_url = start_page.get("image_url") image_url = supplied_image_url if supplied_image_url else assets_image_url @@ -411,7 +411,7 @@ def _page_meta_tags(app, request): return [ {"name": "description", "content": description}, {"property": "twitter:card", "content": "summary_large_image"}, - {"property": "twitter:url", "content": request_url}, + {"property": "twitter:url", "content": request.get_url()}, {"property": "twitter:title", "content": title}, {"property": "twitter:description", "content": description}, {"property": "twitter:image", "content": image_url or ""}, diff --git a/dash/server_factories/fastapi_factory.py b/dash/server_factories/fastapi_factory.py index 0853972d1f..19d70b022e 100644 --- a/dash/server_factories/fastapi_factory.py +++ b/dash/server_factories/fastapi_factory.py @@ -6,12 +6,22 @@ from contextvars import copy_context import importlib.util import time -import uvicorn -from fastapi import FastAPI, Request, Response -from fastapi.responses import JSONResponse, PlainTextResponse -from fastapi.staticfiles import StaticFiles -from starlette.responses import Response as StarletteResponse -from starlette.datastructures import MutableHeaders + +try: + import uvicorn + from fastapi import FastAPI, Request, Response + from fastapi.responses import JSONResponse, PlainTextResponse + from fastapi.staticfiles import StaticFiles + from starlette.responses import Response as StarletteResponse + from starlette.datastructures import MutableHeaders +except ImportError: + uvicorn = None + FastAPI = Request = Response = None + JSONResponse = PlainTextResponse = None + StaticFiles = None + StarletteResponse = None + MutableHeaders = None + from dash.fingerprint import check_fingerprint from dash import _validate from dash.exceptions import PreventUpdate, InvalidResourceError @@ -285,6 +295,9 @@ def __init__(self): def set_request(self, request: Request): self._request = request + def get_root(self): + return str(self._request.base_url) + def get_args(self): return self._request.query_params @@ -305,6 +318,9 @@ def get_headers(self): def get_full_path(self): return str(self._request.url) + def get_url(self): + return str(self._request.url) + def get_remote_addr(self): return self._request.client.host if self._request.client else None diff --git a/dash/server_factories/flask_factory.py b/dash/server_factories/flask_factory.py index bb1204af19..8153ec4f92 100644 --- a/dash/server_factories/flask_factory.py +++ b/dash/server_factories/flask_factory.py @@ -207,6 +207,10 @@ class FlaskRequestAdapter: def get_args(): return flask.request.args + @staticmethod + def get_root(): + return flask.request.url_root + @staticmethod def get_json(): return flask.request.get_json() @@ -223,6 +227,10 @@ def get_cookies(): def get_headers(): return flask.request.headers + @staticmethod + def get_url(): + return flask.request.url + @staticmethod def get_full_path(): return flask.request.full_path diff --git a/tests/integration/multi_page/test_pages_relative_path.py b/tests/integration/multi_page/test_pages_relative_path.py index 6c505ac3f5..6fcbb6c6e0 100644 --- a/tests/integration/multi_page/test_pages_relative_path.py +++ b/tests/integration/multi_page/test_pages_relative_path.py @@ -2,6 +2,7 @@ import dash from dash import Dash, dcc, html +from dash.testing.wait import until def get_app(app): @@ -83,6 +84,6 @@ def test_pare003_absolute_path(dash_duo, clear_pages_state): for page in dash.page_registry.values(): dash_duo.find_element("#" + page["id"]).click() dash_duo.wait_for_text_to_equal("#text_" + page["id"], "text for " + page["id"]) - assert dash_duo.driver.title == page["title"], "check that page title updates" + until(lambda: dash_duo.driver.title == page["title"],timeout=3) assert dash_duo.get_logs() == [], "browser console should contain no error" From 10681dccfef11be7425b8fe11c2b8a21c148f20f Mon Sep 17 00:00:00 2001 From: BSd3v <82055130+BSd3v@users.noreply.github.com> Date: Thu, 11 Sep 2025 16:53:27 -0400 Subject: [PATCH 017/166] fixing issues with relativate paths --- dash/server_factories/fastapi_factory.py | 5 ++--- dash/server_factories/flask_factory.py | 3 +-- tests/integration/multi_page/test_pages_relative_path.py | 2 +- 3 files changed, 4 insertions(+), 6 deletions(-) diff --git a/dash/server_factories/fastapi_factory.py b/dash/server_factories/fastapi_factory.py index 19d70b022e..914f591e17 100644 --- a/dash/server_factories/fastapi_factory.py +++ b/dash/server_factories/fastapi_factory.py @@ -219,9 +219,8 @@ async def serve(request: Request, package_name: str, fingerprinted_path: str): dash_app, package_name, fingerprinted_path, request ) - self.add_url_rule( - app, - "/_dash-component-suites/{package_name}/{fingerprinted_path:path}", + dash_app._add_url( + "/_dash-component-suites//", serve, ) diff --git a/dash/server_factories/flask_factory.py b/dash/server_factories/flask_factory.py index 8153ec4f92..684596ac23 100644 --- a/dash/server_factories/flask_factory.py +++ b/dash/server_factories/flask_factory.py @@ -121,8 +121,7 @@ def serve(package_name, fingerprinted_path): dash_app, package_name, fingerprinted_path ) - self.add_url_rule( - app, + dash_app._add_url( "/_dash-component-suites//", serve, ) diff --git a/tests/integration/multi_page/test_pages_relative_path.py b/tests/integration/multi_page/test_pages_relative_path.py index 6fcbb6c6e0..696ecc39a4 100644 --- a/tests/integration/multi_page/test_pages_relative_path.py +++ b/tests/integration/multi_page/test_pages_relative_path.py @@ -71,7 +71,7 @@ def test_pare002_relative_path_with_url_base_pathname( for page in dash.page_registry.values(): dash_br.find_element("#" + page["id"]).click() dash_br.wait_for_text_to_equal("#text_" + page["id"], "text for " + page["id"]) - assert dash_br.driver.title == page["title"], "check that page title updates" + until(lambda: dash_br.driver.title == page["title"], timeout=3) assert dash_br.get_logs() == [], "browser console should contain no error" From 4944d6d2b3060d43489f71c82bb7bdaf69242471 Mon Sep 17 00:00:00 2001 From: Christian Giessel Date: Thu, 11 Sep 2025 21:24:55 +0200 Subject: [PATCH 018/166] =?UTF-8?q?=E2=88=99=20-=20initial=20quart=20facto?= =?UTF-8?q?ry=20=E2=88=99=20-=20added=20types=20to=20BaseFactory=20to=20re?= =?UTF-8?q?move=20linting=20errors=20on=20create=20app=20in=20Flask=20and?= =?UTF-8?q?=20Quart=20Factory?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- dash/server_factories/base_factory.py | 25 +-- dash/server_factories/quart_factory.py | 238 +++++++++++++++++++++++++ 2 files changed, 251 insertions(+), 12 deletions(-) create mode 100644 dash/server_factories/quart_factory.py diff --git a/dash/server_factories/base_factory.py b/dash/server_factories/base_factory.py index b44f6888cb..12088947c2 100644 --- a/dash/server_factories/base_factory.py +++ b/dash/server_factories/base_factory.py @@ -1,49 +1,50 @@ from abc import ABC, abstractmethod +from typing import Any class BaseServerFactory(ABC): - def __call__(self, server, *args, **kwargs): + def __call__(self, server, *args, **kwargs) -> Any: # Default: WSGI return server(*args, **kwargs) @abstractmethod - def create_app(self, name="__main__", config=None): + def create_app(self, name: str = "__main__", config=None) -> Any: # pragma: no cover - interface pass @abstractmethod def register_assets_blueprint( - self, app, blueprint_name, assets_url_path, assets_folder - ): + self, app, blueprint_name: str, assets_url_path: str, assets_folder: str + ) -> None: # pragma: no cover - interface pass @abstractmethod - def register_error_handlers(self, app): + def register_error_handlers(self, app) -> None: # pragma: no cover - interface pass @abstractmethod - def add_url_rule(self, app, rule, view_func, endpoint=None, methods=None): + def add_url_rule(self, app, rule: str, view_func, endpoint=None, methods=None) -> None: # pragma: no cover - interface pass @abstractmethod - def before_request(self, app, func): + def before_request(self, app, func) -> None: # pragma: no cover - interface pass @abstractmethod - def after_request(self, app, func): + def after_request(self, app, func) -> None: # pragma: no cover - interface pass @abstractmethod - def run(self, app, host, port, debug, **kwargs): + def run(self, app, host: str, port: int, debug: bool, **kwargs) -> None: # pragma: no cover - interface pass @abstractmethod - def make_response(self, data, mimetype=None, content_type=None): + def make_response(self, data, mimetype=None, content_type=None) -> Any: # pragma: no cover - interface pass @abstractmethod - def jsonify(self, obj): + def jsonify(self, obj) -> Any: # pragma: no cover - interface pass @abstractmethod - def get_request_adapter(self): + def get_request_adapter(self) -> Any: # pragma: no cover - interface pass diff --git a/dash/server_factories/quart_factory.py b/dash/server_factories/quart_factory.py new file mode 100644 index 0000000000..977c9aea4c --- /dev/null +++ b/dash/server_factories/quart_factory.py @@ -0,0 +1,238 @@ +from .base_factory import BaseServerFactory +from quart import Quart, request, Response as QuartResponse, jsonify, send_from_directory +from dash.exceptions import PreventUpdate, InvalidResourceError +from dash.server_factories import set_request_adapter +from dash.fingerprint import check_fingerprint +from dash import _validate +from contextvars import copy_context +import inspect +import os +import pkgutil +import mimetypes +import hashlib +import sys + + +class QuartAPIServerFactory(BaseServerFactory): + """Quart implementation of the Dash server factory. + + All Quart/async specific imports are at the top-level (per user request) so + Quart must be installed when this module is imported. + """ + + def __init__(self) -> None: + self.config = {} + super().__init__() + + def __call__(self, server, *args, **kwargs): + # ASGI style (scope, receive, send) or standard call-through handled by BaseServerFactory + return super().__call__(server, *args, **kwargs) + + def create_app(self, name="__main__", config=None): + app = Quart(name) + if config: + for key, value in config.items(): + # Mirror Flask usage of config dict + app.config[key] = value + return app + + def register_assets_blueprint( + self, app, blueprint_name, assets_url_path, assets_folder + ): + if os.path.isdir(assets_folder): + route = f"{assets_url_path}/" + + @app.route(route) + async def serve_asset(filename): # pragma: no cover - simple passthrough + return await send_from_directory(assets_folder, filename) + + def register_error_handlers(self, app): + @app.errorhandler(PreventUpdate) + async def _prevent_update(_): + return "", 204 + + @app.errorhandler(InvalidResourceError) + async def _invalid_resource(err): + return err.args[0], 404 + + def _html_response_wrapper(self, view_func): + async def wrapped(*args, **kwargs): + html_val = view_func() if callable(view_func) else view_func + if inspect.iscoroutine(html_val): # handle async function returning html + html_val = await html_val + html = str(html_val) + return QuartResponse(html, content_type="text/html") + + return wrapped + + def add_url_rule(self, app, rule, view_func, endpoint=None, methods=None): + if rule == "": + rule = "/" + if isinstance(view_func, str) or not inspect.iscoroutinefunction(view_func): + # Wrap plain strings or sync callables in async handler returning HTML + if isinstance(view_func, str) or not inspect.iscoroutinefunction(view_func): + view_func = self._html_response_wrapper(view_func) + app.add_url_rule(rule, endpoint or rule, view_func, methods=methods or ["GET"]) + + # ---- Index & Catchall ------------------------------------------------ + def setup_index(self, app, dash_app): + async def index(): + adapter = QuartRequestAdapter() + set_request_adapter(adapter) + return QuartResponse(dash_app.render_index(), content_type="text/html") + + self.add_url_rule(app, "/", index, endpoint="index", methods=["GET"]) + + def setup_catchall(self, app, dash_app): + @app.before_serving + async def _enable_dev_tools(): # pragma: no cover - environmental + dash_app.enable_dev_tools(**self.config) + + async def catchall(path): + adapter = QuartRequestAdapter() + set_request_adapter(adapter) + return QuartResponse(dash_app.render_index(), content_type="text/html") + + # Must be added after other routes + self.add_url_rule( + app, "/", catchall, endpoint="catchall", methods=["GET"] + ) + + # ---- Middleware-esque hooks ----------------------------------------- + def before_request(self, app, func): + app.before_request(func) + + def after_request(self, app, func): + # Quart after_request expects (response) -> response + @app.after_request + async def _after(response): + if func is not None: + result = func() + if inspect.iscoroutine(result): # Allow async hooks + await result + return response + + # ---- Running --------------------------------------------------------- + def run(self, app, host, port, debug, **kwargs): + self.config = dict({'debug': debug} if debug else {}, **kwargs) + app.run(host=host, port=port, debug=debug, **kwargs) + + # ---- Responses / JSON ------------------------------------------------ + def make_response(self, data, mimetype=None, content_type=None): + headers = {} + if mimetype: + headers["Content-Type"] = mimetype + if content_type: + headers["Content-Type"] = content_type + return QuartResponse(data, headers=headers) + + def jsonify(self, obj): + return jsonify(obj) + + def get_request_adapter(self): + return QuartRequestAdapter + + # ---- Component Suites ------------------------------------------------ + def serve_component_suites( + self, dash_app, package_name, fingerprinted_path, req + ): + path_in_pkg, has_fingerprint = check_fingerprint(fingerprinted_path) + _validate.validate_js_path(dash_app.registered_paths, package_name, path_in_pkg) + extension = "." + path_in_pkg.split(".")[-1] + mimetype = mimetypes.types_map.get(extension, "application/octet-stream") + package = sys.modules[package_name] + dash_app.logger.debug( + "serving -- package: %s[%s] resource: %s => location: %s", + package_name, + getattr(package, "__version__", "unknown"), + path_in_pkg, + package.__path__, + ) + data = pkgutil.get_data(package_name, path_in_pkg) + headers = {} + if has_fingerprint: + headers["Cache-Control"] = "public, max-age=31536000" + return QuartResponse(data, content_type=mimetype, headers=headers) + etag = hashlib.md5(data).hexdigest() if data else "" + headers["ETag"] = etag + if req.headers.get("If-None-Match") == etag: + return QuartResponse(None, status=304) + return QuartResponse(data, content_type=mimetype, headers=headers) + + def setup_component_suites(self, app, dash_app): + async def serve(package_name, fingerprinted_path): + return self.serve_component_suites( + dash_app, package_name, fingerprinted_path, request + ) + + self.add_url_rule( + app, + "/_dash-component-suites//", + serve, + methods=["GET"], + ) + + # ---- Dispatch (Callbacks) ------------------------------------------- + def dispatch(self, app, dash_app, use_async=True): # Quart always async + async def _dispatch(): + adapter = QuartRequestAdapter() + set_request_adapter(adapter) + body = await request.get_json() + g = dash_app._initialize_context(body, adapter) + func = dash_app._prepare_callback(g, body) + args = dash_app._inputs_to_vals(g.inputs_list + g.states_list) + ctx = copy_context() + partial_func = dash_app._execute_callback(func, args, g.outputs_list, g) + response_data = ctx.run(partial_func) + if inspect.iscoroutine(response_data): # if user callback is async + response_data = await response_data + return QuartResponse(response_data, content_type="application/json") + + return _dispatch + + # ---- Favicon --------------------------------------------------------- + def _serve_default_favicon(self): + return QuartResponse( + pkgutil.get_data("dash", "favicon.ico"), content_type="image/x-icon" + ) + + +class QuartRequestAdapter: + """Adapter that normalizes Quart's request API to what Dash expects.""" + + @staticmethod + def get_args(): + return request.args + + @staticmethod + async def get_json(): + return await request.get_json() + + @staticmethod + def is_json(): + return request.is_json + + @staticmethod + def get_cookies(): + return request.cookies + + @staticmethod + def get_headers(): + return request.headers + + @staticmethod + def get_full_path(): + return request.full_path + + @staticmethod + def get_remote_addr(): + return request.remote_addr + + @staticmethod + def get_origin(): + return request.headers.get("Origin") + + @staticmethod + def get_path(): + return request.path + From 3b0f47e37d465a01ce6acd35a779458509d53aa1 Mon Sep 17 00:00:00 2001 From: Christian Giessel Date: Fri, 12 Sep 2025 15:28:49 +0200 Subject: [PATCH 019/166] Quart factory ready --- dash/server_factories/quart_factory.py | 154 +++++++++++++++++-------- 1 file changed, 107 insertions(+), 47 deletions(-) diff --git a/dash/server_factories/quart_factory.py b/dash/server_factories/quart_factory.py index 977c9aea4c..75376fcd7a 100644 --- a/dash/server_factories/quart_factory.py +++ b/dash/server_factories/quart_factory.py @@ -1,5 +1,5 @@ from .base_factory import BaseServerFactory -from quart import Quart, request, Response as QuartResponse, jsonify, send_from_directory +from quart import Quart, request, Response, jsonify, send_from_directory from dash.exceptions import PreventUpdate, InvalidResourceError from dash.server_factories import set_request_adapter from dash.fingerprint import check_fingerprint @@ -11,6 +11,7 @@ import mimetypes import hashlib import sys +import time class QuartAPIServerFactory(BaseServerFactory): @@ -39,12 +40,50 @@ def create_app(self, name="__main__", config=None): def register_assets_blueprint( self, app, blueprint_name, assets_url_path, assets_folder ): - if os.path.isdir(assets_folder): - route = f"{assets_url_path}/" + # Mirror Flask implementation using a blueprint serving static files + from quart import Blueprint + + bp = Blueprint( + blueprint_name, + __name__, + static_folder=assets_folder, + static_url_path=assets_url_path, + ) + app.register_blueprint(bp) + + def register_prune_error_handler(self, app, secret, get_traceback_func): + @app.errorhandler(Exception) + async def _wrap_errors(_error_request, error): + tb = get_traceback_func(secret, error) + return tb, 500 + + def register_timing_hooks(self, app, _first_run): # parity with Flask factory + from quart import g + + @app.before_request + async def _before_request(): # pragma: no cover - timing infra + g.timing_information = {"__dash_server": {"dur": time.time(), "desc": None}} - @app.route(route) - async def serve_asset(filename): # pragma: no cover - simple passthrough - return await send_from_directory(assets_folder, filename) + @app.after_request + async def _after_request(response): # pragma: no cover - timing infra + timing_information = getattr(g, "timing_information", None) + if timing_information is None: + return response + dash_total = timing_information.get("__dash_server", None) + if dash_total is not None: + dash_total["dur"] = round((time.time() - dash_total["dur"]) * 1000) + for name, info in timing_information.items(): + value = name + if info.get("desc") is not None: + value += f';desc="{info["desc"]}"' + if info.get("dur") is not None: + value += f";dur={info['dur']}" + # Quart/Werkzeug headers expose 'add' (not 'append') + if hasattr(response.headers, "add"): + response.headers.add("Server-Timing", value) + else: # fallback just in case + response.headers["Server-Timing"] = value + return response def register_error_handlers(self, app): @app.errorhandler(PreventUpdate) @@ -61,49 +100,72 @@ async def wrapped(*args, **kwargs): if inspect.iscoroutine(html_val): # handle async function returning html html_val = await html_val html = str(html_val) - return QuartResponse(html, content_type="text/html") + return Response(html, content_type="text/html") return wrapped def add_url_rule(self, app, rule, view_func, endpoint=None, methods=None): - if rule == "": - rule = "/" - if isinstance(view_func, str) or not inspect.iscoroutinefunction(view_func): - # Wrap plain strings or sync callables in async handler returning HTML - if isinstance(view_func, str) or not inspect.iscoroutinefunction(view_func): - view_func = self._html_response_wrapper(view_func) - app.add_url_rule(rule, endpoint or rule, view_func, methods=methods or ["GET"]) - - # ---- Index & Catchall ------------------------------------------------ + app.add_url_rule( + rule, view_func=view_func, endpoint=endpoint, methods=methods or ["GET"] + ) + + # def add_url_rule(self, app, rule, view_func, endpoint=None, methods=None): + # if rule == "": + # rule = "/" + # if isinstance(view_func, str): + # # Literal HTML content + # view_func = self._html_response_wrapper(view_func) + # elif not inspect.iscoroutinefunction(view_func): + # # Sync function: wrap to make async but preserve Response objects + # original = view_func + + # async def _async_adapter(*args, **kwargs): + # result = original(*args, **kwargs) + # # Pass through existing Response (Quart/Flask style) + # if isinstance(result, Response) or ( + # hasattr(result, "status_code") + # and hasattr(result, "headers") + # and hasattr(result, "get_data") + # ): + # return result + # # If it's bytes or str treat as HTML + # if isinstance(result, (str, bytes)): + # return Response(result, content_type="text/html") + # # Fallback: JSON encode arbitrary python objects + # try: + # import json + + # return Response( + # json.dumps(result), content_type="application/json" + # ) + # except Exception: # pragma: no cover + # return Response(str(result), content_type="text/plain") + + # view_func = _async_adapter + # app.add_url_rule(rule, endpoint or rule, view_func, methods=methods or ["GET"]) + def setup_index(self, app, dash_app): async def index(): adapter = QuartRequestAdapter() set_request_adapter(adapter) - return QuartResponse(dash_app.render_index(), content_type="text/html") + return Response(dash_app.render_index(), content_type="text/html") self.add_url_rule(app, "/", index, endpoint="index", methods=["GET"]) def setup_catchall(self, app, dash_app): - @app.before_serving - async def _enable_dev_tools(): # pragma: no cover - environmental - dash_app.enable_dev_tools(**self.config) - async def catchall(path): adapter = QuartRequestAdapter() set_request_adapter(adapter) - return QuartResponse(dash_app.render_index(), content_type="text/html") + return Response(dash_app.render_index(), content_type="text/html") - # Must be added after other routes self.add_url_rule( app, "/", catchall, endpoint="catchall", methods=["GET"] ) - # ---- Middleware-esque hooks ----------------------------------------- def before_request(self, app, func): app.before_request(func) def after_request(self, app, func): - # Quart after_request expects (response) -> response @app.after_request async def _after(response): if func is not None: @@ -112,19 +174,24 @@ async def _after(response): await result return response - # ---- Running --------------------------------------------------------- def run(self, app, host, port, debug, **kwargs): - self.config = dict({'debug': debug} if debug else {}, **kwargs) - app.run(host=host, port=port, debug=debug, **kwargs) + # Store only dev tools related configuration (exclude server-only kwargs unsupported by Quart) + # Quart's run does NOT accept 'threaded' (Flask-specific). Drop silently (or log) if present. + unsupported = {"threaded", "processes"} + filtered_kwargs = {} + for k, v in kwargs.items(): + if k in unsupported: + continue + filtered_kwargs[k] = v + + # Keep a slim config for potential future use (dev tools already enabled in Dash.run) + self.config = {'debug': debug} + self.config.update({k: v for k, v in filtered_kwargs.items() if k.startswith('dev_tools_')}) + + app.run(host=host, port=port, debug=debug, **filtered_kwargs) - # ---- Responses / JSON ------------------------------------------------ def make_response(self, data, mimetype=None, content_type=None): - headers = {} - if mimetype: - headers["Content-Type"] = mimetype - if content_type: - headers["Content-Type"] = content_type - return QuartResponse(data, headers=headers) + return Response(data, mimetype=mimetype, content_type=content_type) def jsonify(self, obj): return jsonify(obj) @@ -132,7 +199,6 @@ def jsonify(self, obj): def get_request_adapter(self): return QuartRequestAdapter - # ---- Component Suites ------------------------------------------------ def serve_component_suites( self, dash_app, package_name, fingerprinted_path, req ): @@ -152,12 +218,9 @@ def serve_component_suites( headers = {} if has_fingerprint: headers["Cache-Control"] = "public, max-age=31536000" - return QuartResponse(data, content_type=mimetype, headers=headers) - etag = hashlib.md5(data).hexdigest() if data else "" - headers["ETag"] = etag - if req.headers.get("If-None-Match") == etag: - return QuartResponse(None, status=304) - return QuartResponse(data, content_type=mimetype, headers=headers) + return Response(data, content_type=mimetype, headers=headers) + + return Response(data, content_type=mimetype, headers=headers) def setup_component_suites(self, app, dash_app): async def serve(package_name, fingerprinted_path): @@ -172,7 +235,6 @@ async def serve(package_name, fingerprinted_path): methods=["GET"], ) - # ---- Dispatch (Callbacks) ------------------------------------------- def dispatch(self, app, dash_app, use_async=True): # Quart always async async def _dispatch(): adapter = QuartRequestAdapter() @@ -186,13 +248,12 @@ async def _dispatch(): response_data = ctx.run(partial_func) if inspect.iscoroutine(response_data): # if user callback is async response_data = await response_data - return QuartResponse(response_data, content_type="application/json") + return Response(response_data, content_type="application/json") return _dispatch - # ---- Favicon --------------------------------------------------------- def _serve_default_favicon(self): - return QuartResponse( + return Response( pkgutil.get_data("dash", "favicon.ico"), content_type="image/x-icon" ) @@ -235,4 +296,3 @@ def get_origin(): @staticmethod def get_path(): return request.path - From 1112f7743c7a91791f0720ee505959a80fd0c0cd Mon Sep 17 00:00:00 2001 From: BSd3v <82055130+BSd3v@users.noreply.github.com> Date: Fri, 12 Sep 2025 10:05:26 -0400 Subject: [PATCH 020/166] fixing for lint --- dash/_pages.py | 4 +++- dash/server_factories/fastapi_factory.py | 6 +++--- dash/server_factories/flask_factory.py | 6 +++--- tests/integration/multi_page/test_pages_relative_path.py | 2 +- 4 files changed, 10 insertions(+), 8 deletions(-) diff --git a/dash/_pages.py b/dash/_pages.py index 3fab86eb99..6c00e656c7 100644 --- a/dash/_pages.py +++ b/dash/_pages.py @@ -396,7 +396,9 @@ def _page_meta_tags(app, request): image = start_page.get("image", "") if image: image = app.get_asset_url(image) - assets_image_url = "".join([request.get_root(), image.lstrip("/")]) if image else None + assets_image_url = ( + "".join([request.get_root(), image.lstrip("/")]) if image else None + ) supplied_image_url = start_page.get("image_url") image_url = supplied_image_url if supplied_image_url else assets_image_url diff --git a/dash/server_factories/fastapi_factory.py b/dash/server_factories/fastapi_factory.py index 914f591e17..eb4a9392f5 100644 --- a/dash/server_factories/fastapi_factory.py +++ b/dash/server_factories/fastapi_factory.py @@ -219,14 +219,14 @@ async def serve(request: Request, package_name: str, fingerprinted_path: str): dash_app, package_name, fingerprinted_path, request ) + # pylint: disable=protected-access dash_app._add_url( "/_dash-component-suites//", serve, ) - def dispatch( - self, app, dash_app, use_async=False - ): # pylint: disable=unused-argument + # pylint: disable=unused-argument + def dispatch(self, app, dash_app, use_async=False): async def _dispatch(request: Request): adapter = FastAPIRequestAdapter() set_request_adapter(adapter) diff --git a/dash/server_factories/flask_factory.py b/dash/server_factories/flask_factory.py index 684596ac23..c2221469fc 100644 --- a/dash/server_factories/flask_factory.py +++ b/dash/server_factories/flask_factory.py @@ -121,14 +121,14 @@ def serve(package_name, fingerprinted_path): dash_app, package_name, fingerprinted_path ) + # pylint: disable=protected-access dash_app._add_url( "/_dash-component-suites//", serve, ) - def dispatch( - self, app, dash_app, use_async=False - ): # pylint: disable=unused-argument + # pylint: disable=unused-argument + def dispatch(self, app, dash_app, use_async=False): def _dispatch(): adapter = FlaskRequestAdapter() set_request_adapter(adapter) diff --git a/tests/integration/multi_page/test_pages_relative_path.py b/tests/integration/multi_page/test_pages_relative_path.py index 696ecc39a4..24e7209a70 100644 --- a/tests/integration/multi_page/test_pages_relative_path.py +++ b/tests/integration/multi_page/test_pages_relative_path.py @@ -84,6 +84,6 @@ def test_pare003_absolute_path(dash_duo, clear_pages_state): for page in dash.page_registry.values(): dash_duo.find_element("#" + page["id"]).click() dash_duo.wait_for_text_to_equal("#text_" + page["id"], "text for " + page["id"]) - until(lambda: dash_duo.driver.title == page["title"],timeout=3) + until(lambda: dash_duo.driver.title == page["title"], timeout=3) assert dash_duo.get_logs() == [], "browser console should contain no error" From 8c52bbb9033588df0c764d6e3fd61e6c2defebd5 Mon Sep 17 00:00:00 2001 From: BSd3v <82055130+BSd3v@users.noreply.github.com> Date: Fri, 12 Sep 2025 10:28:06 -0400 Subject: [PATCH 021/166] fixing issue with apps overwriting other paths --- dash/dash.py | 6 +++--- dash/server_factories/fastapi_factory.py | 16 ++++++++-------- dash/server_factories/flask_factory.py | 17 +++++++++-------- 3 files changed, 20 insertions(+), 19 deletions(-) diff --git a/dash/dash.py b/dash/dash.py index d20672453c..4a06e9216e 100644 --- a/dash/dash.py +++ b/dash/dash.py @@ -742,7 +742,7 @@ def _add_url(self, name: str, view_func: RouteCallable, methods=("GET",)) -> Non self.routes.append(full_name) def _setup_routes(self): - self.server_factory.setup_component_suites(self.server, self) + self.server_factory.setup_component_suites(self) self._add_url("_dash-layout", self.serve_layout) self._add_url("_dash-dependencies", self.dependencies) self._add_url( @@ -755,8 +755,8 @@ def _setup_routes(self): "_favicon.ico", self.server_factory._serve_default_favicon, # pylint: disable=protected-access ) - self.server_factory.setup_index(self.server, self) - self.server_factory.setup_catchall(self.server, self) + self.server_factory.setup_index(self) + self.server_factory.setup_catchall(self) if jupyter_dash.active: self._add_url( diff --git a/dash/server_factories/fastapi_factory.py b/dash/server_factories/fastapi_factory.py index eb4a9392f5..de1caf451c 100644 --- a/dash/server_factories/fastapi_factory.py +++ b/dash/server_factories/fastapi_factory.py @@ -83,16 +83,17 @@ async def wrapped(*_args, **_kwargs): return wrapped - def setup_index(self, app, dash_app): + def setup_index(self, dash_app): async def index(request: Request): adapter = FastAPIRequestAdapter() set_request_adapter(adapter) adapter.set_request(request) return Response(content=dash_app.render_index(), media_type="text/html") - self.add_url_rule(app, "/", index, endpoint="index", methods=["GET"]) + # pylint: disable=protected-access + dash_app._add_url("", index, methods=["GET"]) - def setup_catchall(self, app, dash_app): + def setup_catchall(self, dash_app): @dash_app.server.on_event("startup") def _setup_catchall(): dash_app.enable_dev_tools( @@ -105,9 +106,8 @@ async def catchall(request: Request): adapter.set_request(request) return Response(content=dash_app.render_index(), media_type="text/html") - self.add_url_rule( - app, "/{path:path}", catchall, endpoint="catchall", methods=["GET"] - ) + # pylint: disable=protected-access + dash_app._add_url("{path:path}", catchall, methods=["GET"]) def add_url_rule(self, app, rule, view_func, endpoint=None, methods=None): if rule == "": @@ -213,7 +213,7 @@ def serve_component_suites( return StarletteResponse(status_code=304) return StarletteResponse(content=data, media_type=mimetype, headers=headers) - def setup_component_suites(self, app, dash_app): + def setup_component_suites(self, dash_app): async def serve(request: Request, package_name: str, fingerprinted_path: str): return self.serve_component_suites( dash_app, package_name, fingerprinted_path, request @@ -221,7 +221,7 @@ async def serve(request: Request, package_name: str, fingerprinted_path: str): # pylint: disable=protected-access dash_app._add_url( - "/_dash-component-suites//", + "_dash-component-suites//", serve, ) diff --git a/dash/server_factories/flask_factory.py b/dash/server_factories/flask_factory.py index c2221469fc..1ea561b076 100644 --- a/dash/server_factories/flask_factory.py +++ b/dash/server_factories/flask_factory.py @@ -50,6 +50,7 @@ def _wrap_errors(error): return tb, 500 def add_url_rule(self, app, rule, view_func, endpoint=None, methods=None): + print(rule, endpoint, methods) app.add_url_rule( rule, view_func=view_func, endpoint=endpoint, methods=methods or ["GET"] ) @@ -72,23 +73,23 @@ def jsonify(self, obj): def get_request_adapter(self): return FlaskRequestAdapter - def setup_catchall(self, app, dash_app): + def setup_catchall(self, dash_app): def catchall(*args, **kwargs): adapter = FlaskRequestAdapter() set_request_adapter(adapter) return dash_app.render_index(*args, **kwargs) - self.add_url_rule( - app, "/", catchall, endpoint="catchall", methods=["GET"] - ) + # pylint: disable=protected-access + dash_app._add_url("", catchall, methods=["GET"]) - def setup_index(self, app, dash_app): + def setup_index(self, dash_app): def index(*args, **kwargs): adapter = FlaskRequestAdapter() set_request_adapter(adapter) return dash_app.render_index(*args, **kwargs) - self.add_url_rule(app, "/", index, endpoint="/", methods=["GET"]) + # pylint: disable=protected-access + dash_app._add_url("", index, methods=["GET"]) def serve_component_suites(self, dash_app, package_name, fingerprinted_path): path_in_pkg, has_fingerprint = check_fingerprint(fingerprinted_path) @@ -115,7 +116,7 @@ def serve_component_suites(self, dash_app, package_name, fingerprinted_path): response = flask.Response(None, status=304) return response - def setup_component_suites(self, app, dash_app): + def setup_component_suites(self, dash_app): def serve(package_name, fingerprinted_path): return self.serve_component_suites( dash_app, package_name, fingerprinted_path @@ -123,7 +124,7 @@ def serve(package_name, fingerprinted_path): # pylint: disable=protected-access dash_app._add_url( - "/_dash-component-suites//", + "_dash-component-suites//", serve, ) From aabeeb7801f47a4c2f1e54bb3278c20ec308f417 Mon Sep 17 00:00:00 2001 From: BSd3v <82055130+BSd3v@users.noreply.github.com> Date: Fri, 12 Sep 2025 10:35:10 -0400 Subject: [PATCH 022/166] removing print --- dash/server_factories/flask_factory.py | 1 - 1 file changed, 1 deletion(-) diff --git a/dash/server_factories/flask_factory.py b/dash/server_factories/flask_factory.py index 1ea561b076..6eebd735ac 100644 --- a/dash/server_factories/flask_factory.py +++ b/dash/server_factories/flask_factory.py @@ -50,7 +50,6 @@ def _wrap_errors(error): return tb, 500 def add_url_rule(self, app, rule, view_func, endpoint=None, methods=None): - print(rule, endpoint, methods) app.add_url_rule( rule, view_func=view_func, endpoint=endpoint, methods=methods or ["GET"] ) From 5659cd73e98bda4ebb907480a56cf24207fb5f3a Mon Sep 17 00:00:00 2001 From: Christian Giessel Date: Fri, 12 Sep 2025 16:37:39 +0200 Subject: [PATCH 023/166] cleanup --- dash/server_factories/quart_factory.py | 58 ++------------------------ 1 file changed, 3 insertions(+), 55 deletions(-) diff --git a/dash/server_factories/quart_factory.py b/dash/server_factories/quart_factory.py index 75376fcd7a..685b8d70e4 100644 --- a/dash/server_factories/quart_factory.py +++ b/dash/server_factories/quart_factory.py @@ -1,15 +1,13 @@ from .base_factory import BaseServerFactory -from quart import Quart, request, Response, jsonify, send_from_directory +from quart import Quart, request, Response, jsonify from dash.exceptions import PreventUpdate, InvalidResourceError from dash.server_factories import set_request_adapter from dash.fingerprint import check_fingerprint from dash import _validate from contextvars import copy_context import inspect -import os import pkgutil import mimetypes -import hashlib import sys import time @@ -26,21 +24,18 @@ def __init__(self) -> None: super().__init__() def __call__(self, server, *args, **kwargs): - # ASGI style (scope, receive, send) or standard call-through handled by BaseServerFactory return super().__call__(server, *args, **kwargs) def create_app(self, name="__main__", config=None): app = Quart(name) if config: for key, value in config.items(): - # Mirror Flask usage of config dict app.config[key] = value return app def register_assets_blueprint( self, app, blueprint_name, assets_url_path, assets_folder ): - # Mirror Flask implementation using a blueprint serving static files from quart import Blueprint bp = Blueprint( @@ -109,41 +104,6 @@ def add_url_rule(self, app, rule, view_func, endpoint=None, methods=None): rule, view_func=view_func, endpoint=endpoint, methods=methods or ["GET"] ) - # def add_url_rule(self, app, rule, view_func, endpoint=None, methods=None): - # if rule == "": - # rule = "/" - # if isinstance(view_func, str): - # # Literal HTML content - # view_func = self._html_response_wrapper(view_func) - # elif not inspect.iscoroutinefunction(view_func): - # # Sync function: wrap to make async but preserve Response objects - # original = view_func - - # async def _async_adapter(*args, **kwargs): - # result = original(*args, **kwargs) - # # Pass through existing Response (Quart/Flask style) - # if isinstance(result, Response) or ( - # hasattr(result, "status_code") - # and hasattr(result, "headers") - # and hasattr(result, "get_data") - # ): - # return result - # # If it's bytes or str treat as HTML - # if isinstance(result, (str, bytes)): - # return Response(result, content_type="text/html") - # # Fallback: JSON encode arbitrary python objects - # try: - # import json - - # return Response( - # json.dumps(result), content_type="application/json" - # ) - # except Exception: # pragma: no cover - # return Response(str(result), content_type="text/plain") - - # view_func = _async_adapter - # app.add_url_rule(rule, endpoint or rule, view_func, methods=methods or ["GET"]) - def setup_index(self, app, dash_app): async def index(): adapter = QuartRequestAdapter() @@ -175,20 +135,8 @@ async def _after(response): return response def run(self, app, host, port, debug, **kwargs): - # Store only dev tools related configuration (exclude server-only kwargs unsupported by Quart) - # Quart's run does NOT accept 'threaded' (Flask-specific). Drop silently (or log) if present. - unsupported = {"threaded", "processes"} - filtered_kwargs = {} - for k, v in kwargs.items(): - if k in unsupported: - continue - filtered_kwargs[k] = v - - # Keep a slim config for potential future use (dev tools already enabled in Dash.run) - self.config = {'debug': debug} - self.config.update({k: v for k, v in filtered_kwargs.items() if k.startswith('dev_tools_')}) - - app.run(host=host, port=port, debug=debug, **filtered_kwargs) + self.config = dict({'debug': debug} if debug else {}, **kwargs) + app.run(host=host, port=port, debug=debug, **kwargs) def make_response(self, data, mimetype=None, content_type=None): return Response(data, mimetype=mimetype, content_type=content_type) From b05e37654f0f36eaa709514effe5900c6d94cf5a Mon Sep 17 00:00:00 2001 From: BSd3v <82055130+BSd3v@users.noreply.github.com> Date: Fri, 12 Sep 2025 11:07:10 -0400 Subject: [PATCH 024/166] reverting `render_index` -> `index` and making catch for outside of a `request` context --- dash/dash.py | 10 +++++++--- dash/server_factories/fastapi_factory.py | 4 ++-- dash/server_factories/flask_factory.py | 4 ++-- 3 files changed, 11 insertions(+), 7 deletions(-) diff --git a/dash/dash.py b/dash/dash.py index 4a06e9216e..973f1ec579 100644 --- a/dash/dash.py +++ b/dash/dash.py @@ -1145,16 +1145,20 @@ def _generate_meta(self): return meta_tags + self.config.meta_tags - def render_index(self, *_args, **_kwargs): + def index(self, *_args, **_kwargs): scripts = self._generate_scripts_html() css = self._generate_css_dist_html() config = self._generate_config_html() metas = self._generate_meta() renderer = self._generate_renderer() title = self.title - request = get_request_adapter() + try: + request = get_request_adapter() + except LookupError: + # no request context + request = None - if self.use_pages and self.config.include_pages_meta: + if self.use_pages and self.config.include_pages_meta and request: metas = _page_meta_tags(self, request) + metas if self._favicon: diff --git a/dash/server_factories/fastapi_factory.py b/dash/server_factories/fastapi_factory.py index de1caf451c..cf08f85d7f 100644 --- a/dash/server_factories/fastapi_factory.py +++ b/dash/server_factories/fastapi_factory.py @@ -88,7 +88,7 @@ async def index(request: Request): adapter = FastAPIRequestAdapter() set_request_adapter(adapter) adapter.set_request(request) - return Response(content=dash_app.render_index(), media_type="text/html") + return Response(content=dash_app.index(), media_type="text/html") # pylint: disable=protected-access dash_app._add_url("", index, methods=["GET"]) @@ -104,7 +104,7 @@ async def catchall(request: Request): adapter = FastAPIRequestAdapter() set_request_adapter(adapter) adapter.set_request(request) - return Response(content=dash_app.render_index(), media_type="text/html") + return Response(content=dash_app.index(), media_type="text/html") # pylint: disable=protected-access dash_app._add_url("{path:path}", catchall, methods=["GET"]) diff --git a/dash/server_factories/flask_factory.py b/dash/server_factories/flask_factory.py index 6eebd735ac..9ba8a5017c 100644 --- a/dash/server_factories/flask_factory.py +++ b/dash/server_factories/flask_factory.py @@ -76,7 +76,7 @@ def setup_catchall(self, dash_app): def catchall(*args, **kwargs): adapter = FlaskRequestAdapter() set_request_adapter(adapter) - return dash_app.render_index(*args, **kwargs) + return dash_app.index(*args, **kwargs) # pylint: disable=protected-access dash_app._add_url("", catchall, methods=["GET"]) @@ -85,7 +85,7 @@ def setup_index(self, dash_app): def index(*args, **kwargs): adapter = FlaskRequestAdapter() set_request_adapter(adapter) - return dash_app.render_index(*args, **kwargs) + return dash_app.index(*args, **kwargs) # pylint: disable=protected-access dash_app._add_url("", index, methods=["GET"]) From ed0dc3b4fbdf78ea68ca1d21e510aca1bf4a3320 Mon Sep 17 00:00:00 2001 From: Christian Giessel Date: Thu, 11 Sep 2025 21:24:55 +0200 Subject: [PATCH 025/166] =?UTF-8?q?=E2=88=99=20-=20initial=20quart=20facto?= =?UTF-8?q?ry=20=E2=88=99=20-=20added=20types=20to=20BaseFactory=20to=20re?= =?UTF-8?q?move=20linting=20errors=20on=20create=20app=20in=20Flask=20and?= =?UTF-8?q?=20Quart=20Factory?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- dash/server_factories/base_factory.py | 25 +-- dash/server_factories/quart_factory.py | 238 +++++++++++++++++++++++++ 2 files changed, 251 insertions(+), 12 deletions(-) create mode 100644 dash/server_factories/quart_factory.py diff --git a/dash/server_factories/base_factory.py b/dash/server_factories/base_factory.py index b44f6888cb..12088947c2 100644 --- a/dash/server_factories/base_factory.py +++ b/dash/server_factories/base_factory.py @@ -1,49 +1,50 @@ from abc import ABC, abstractmethod +from typing import Any class BaseServerFactory(ABC): - def __call__(self, server, *args, **kwargs): + def __call__(self, server, *args, **kwargs) -> Any: # Default: WSGI return server(*args, **kwargs) @abstractmethod - def create_app(self, name="__main__", config=None): + def create_app(self, name: str = "__main__", config=None) -> Any: # pragma: no cover - interface pass @abstractmethod def register_assets_blueprint( - self, app, blueprint_name, assets_url_path, assets_folder - ): + self, app, blueprint_name: str, assets_url_path: str, assets_folder: str + ) -> None: # pragma: no cover - interface pass @abstractmethod - def register_error_handlers(self, app): + def register_error_handlers(self, app) -> None: # pragma: no cover - interface pass @abstractmethod - def add_url_rule(self, app, rule, view_func, endpoint=None, methods=None): + def add_url_rule(self, app, rule: str, view_func, endpoint=None, methods=None) -> None: # pragma: no cover - interface pass @abstractmethod - def before_request(self, app, func): + def before_request(self, app, func) -> None: # pragma: no cover - interface pass @abstractmethod - def after_request(self, app, func): + def after_request(self, app, func) -> None: # pragma: no cover - interface pass @abstractmethod - def run(self, app, host, port, debug, **kwargs): + def run(self, app, host: str, port: int, debug: bool, **kwargs) -> None: # pragma: no cover - interface pass @abstractmethod - def make_response(self, data, mimetype=None, content_type=None): + def make_response(self, data, mimetype=None, content_type=None) -> Any: # pragma: no cover - interface pass @abstractmethod - def jsonify(self, obj): + def jsonify(self, obj) -> Any: # pragma: no cover - interface pass @abstractmethod - def get_request_adapter(self): + def get_request_adapter(self) -> Any: # pragma: no cover - interface pass diff --git a/dash/server_factories/quart_factory.py b/dash/server_factories/quart_factory.py new file mode 100644 index 0000000000..977c9aea4c --- /dev/null +++ b/dash/server_factories/quart_factory.py @@ -0,0 +1,238 @@ +from .base_factory import BaseServerFactory +from quart import Quart, request, Response as QuartResponse, jsonify, send_from_directory +from dash.exceptions import PreventUpdate, InvalidResourceError +from dash.server_factories import set_request_adapter +from dash.fingerprint import check_fingerprint +from dash import _validate +from contextvars import copy_context +import inspect +import os +import pkgutil +import mimetypes +import hashlib +import sys + + +class QuartAPIServerFactory(BaseServerFactory): + """Quart implementation of the Dash server factory. + + All Quart/async specific imports are at the top-level (per user request) so + Quart must be installed when this module is imported. + """ + + def __init__(self) -> None: + self.config = {} + super().__init__() + + def __call__(self, server, *args, **kwargs): + # ASGI style (scope, receive, send) or standard call-through handled by BaseServerFactory + return super().__call__(server, *args, **kwargs) + + def create_app(self, name="__main__", config=None): + app = Quart(name) + if config: + for key, value in config.items(): + # Mirror Flask usage of config dict + app.config[key] = value + return app + + def register_assets_blueprint( + self, app, blueprint_name, assets_url_path, assets_folder + ): + if os.path.isdir(assets_folder): + route = f"{assets_url_path}/" + + @app.route(route) + async def serve_asset(filename): # pragma: no cover - simple passthrough + return await send_from_directory(assets_folder, filename) + + def register_error_handlers(self, app): + @app.errorhandler(PreventUpdate) + async def _prevent_update(_): + return "", 204 + + @app.errorhandler(InvalidResourceError) + async def _invalid_resource(err): + return err.args[0], 404 + + def _html_response_wrapper(self, view_func): + async def wrapped(*args, **kwargs): + html_val = view_func() if callable(view_func) else view_func + if inspect.iscoroutine(html_val): # handle async function returning html + html_val = await html_val + html = str(html_val) + return QuartResponse(html, content_type="text/html") + + return wrapped + + def add_url_rule(self, app, rule, view_func, endpoint=None, methods=None): + if rule == "": + rule = "/" + if isinstance(view_func, str) or not inspect.iscoroutinefunction(view_func): + # Wrap plain strings or sync callables in async handler returning HTML + if isinstance(view_func, str) or not inspect.iscoroutinefunction(view_func): + view_func = self._html_response_wrapper(view_func) + app.add_url_rule(rule, endpoint or rule, view_func, methods=methods or ["GET"]) + + # ---- Index & Catchall ------------------------------------------------ + def setup_index(self, app, dash_app): + async def index(): + adapter = QuartRequestAdapter() + set_request_adapter(adapter) + return QuartResponse(dash_app.render_index(), content_type="text/html") + + self.add_url_rule(app, "/", index, endpoint="index", methods=["GET"]) + + def setup_catchall(self, app, dash_app): + @app.before_serving + async def _enable_dev_tools(): # pragma: no cover - environmental + dash_app.enable_dev_tools(**self.config) + + async def catchall(path): + adapter = QuartRequestAdapter() + set_request_adapter(adapter) + return QuartResponse(dash_app.render_index(), content_type="text/html") + + # Must be added after other routes + self.add_url_rule( + app, "/", catchall, endpoint="catchall", methods=["GET"] + ) + + # ---- Middleware-esque hooks ----------------------------------------- + def before_request(self, app, func): + app.before_request(func) + + def after_request(self, app, func): + # Quart after_request expects (response) -> response + @app.after_request + async def _after(response): + if func is not None: + result = func() + if inspect.iscoroutine(result): # Allow async hooks + await result + return response + + # ---- Running --------------------------------------------------------- + def run(self, app, host, port, debug, **kwargs): + self.config = dict({'debug': debug} if debug else {}, **kwargs) + app.run(host=host, port=port, debug=debug, **kwargs) + + # ---- Responses / JSON ------------------------------------------------ + def make_response(self, data, mimetype=None, content_type=None): + headers = {} + if mimetype: + headers["Content-Type"] = mimetype + if content_type: + headers["Content-Type"] = content_type + return QuartResponse(data, headers=headers) + + def jsonify(self, obj): + return jsonify(obj) + + def get_request_adapter(self): + return QuartRequestAdapter + + # ---- Component Suites ------------------------------------------------ + def serve_component_suites( + self, dash_app, package_name, fingerprinted_path, req + ): + path_in_pkg, has_fingerprint = check_fingerprint(fingerprinted_path) + _validate.validate_js_path(dash_app.registered_paths, package_name, path_in_pkg) + extension = "." + path_in_pkg.split(".")[-1] + mimetype = mimetypes.types_map.get(extension, "application/octet-stream") + package = sys.modules[package_name] + dash_app.logger.debug( + "serving -- package: %s[%s] resource: %s => location: %s", + package_name, + getattr(package, "__version__", "unknown"), + path_in_pkg, + package.__path__, + ) + data = pkgutil.get_data(package_name, path_in_pkg) + headers = {} + if has_fingerprint: + headers["Cache-Control"] = "public, max-age=31536000" + return QuartResponse(data, content_type=mimetype, headers=headers) + etag = hashlib.md5(data).hexdigest() if data else "" + headers["ETag"] = etag + if req.headers.get("If-None-Match") == etag: + return QuartResponse(None, status=304) + return QuartResponse(data, content_type=mimetype, headers=headers) + + def setup_component_suites(self, app, dash_app): + async def serve(package_name, fingerprinted_path): + return self.serve_component_suites( + dash_app, package_name, fingerprinted_path, request + ) + + self.add_url_rule( + app, + "/_dash-component-suites//", + serve, + methods=["GET"], + ) + + # ---- Dispatch (Callbacks) ------------------------------------------- + def dispatch(self, app, dash_app, use_async=True): # Quart always async + async def _dispatch(): + adapter = QuartRequestAdapter() + set_request_adapter(adapter) + body = await request.get_json() + g = dash_app._initialize_context(body, adapter) + func = dash_app._prepare_callback(g, body) + args = dash_app._inputs_to_vals(g.inputs_list + g.states_list) + ctx = copy_context() + partial_func = dash_app._execute_callback(func, args, g.outputs_list, g) + response_data = ctx.run(partial_func) + if inspect.iscoroutine(response_data): # if user callback is async + response_data = await response_data + return QuartResponse(response_data, content_type="application/json") + + return _dispatch + + # ---- Favicon --------------------------------------------------------- + def _serve_default_favicon(self): + return QuartResponse( + pkgutil.get_data("dash", "favicon.ico"), content_type="image/x-icon" + ) + + +class QuartRequestAdapter: + """Adapter that normalizes Quart's request API to what Dash expects.""" + + @staticmethod + def get_args(): + return request.args + + @staticmethod + async def get_json(): + return await request.get_json() + + @staticmethod + def is_json(): + return request.is_json + + @staticmethod + def get_cookies(): + return request.cookies + + @staticmethod + def get_headers(): + return request.headers + + @staticmethod + def get_full_path(): + return request.full_path + + @staticmethod + def get_remote_addr(): + return request.remote_addr + + @staticmethod + def get_origin(): + return request.headers.get("Origin") + + @staticmethod + def get_path(): + return request.path + From 141527c8b8c25c7b1a47f7dc9eded53135e2ce94 Mon Sep 17 00:00:00 2001 From: Christian Giessel Date: Fri, 12 Sep 2025 15:28:49 +0200 Subject: [PATCH 026/166] Quart factory ready --- dash/server_factories/quart_factory.py | 154 +++++++++++++++++-------- 1 file changed, 107 insertions(+), 47 deletions(-) diff --git a/dash/server_factories/quart_factory.py b/dash/server_factories/quart_factory.py index 977c9aea4c..75376fcd7a 100644 --- a/dash/server_factories/quart_factory.py +++ b/dash/server_factories/quart_factory.py @@ -1,5 +1,5 @@ from .base_factory import BaseServerFactory -from quart import Quart, request, Response as QuartResponse, jsonify, send_from_directory +from quart import Quart, request, Response, jsonify, send_from_directory from dash.exceptions import PreventUpdate, InvalidResourceError from dash.server_factories import set_request_adapter from dash.fingerprint import check_fingerprint @@ -11,6 +11,7 @@ import mimetypes import hashlib import sys +import time class QuartAPIServerFactory(BaseServerFactory): @@ -39,12 +40,50 @@ def create_app(self, name="__main__", config=None): def register_assets_blueprint( self, app, blueprint_name, assets_url_path, assets_folder ): - if os.path.isdir(assets_folder): - route = f"{assets_url_path}/" + # Mirror Flask implementation using a blueprint serving static files + from quart import Blueprint + + bp = Blueprint( + blueprint_name, + __name__, + static_folder=assets_folder, + static_url_path=assets_url_path, + ) + app.register_blueprint(bp) + + def register_prune_error_handler(self, app, secret, get_traceback_func): + @app.errorhandler(Exception) + async def _wrap_errors(_error_request, error): + tb = get_traceback_func(secret, error) + return tb, 500 + + def register_timing_hooks(self, app, _first_run): # parity with Flask factory + from quart import g + + @app.before_request + async def _before_request(): # pragma: no cover - timing infra + g.timing_information = {"__dash_server": {"dur": time.time(), "desc": None}} - @app.route(route) - async def serve_asset(filename): # pragma: no cover - simple passthrough - return await send_from_directory(assets_folder, filename) + @app.after_request + async def _after_request(response): # pragma: no cover - timing infra + timing_information = getattr(g, "timing_information", None) + if timing_information is None: + return response + dash_total = timing_information.get("__dash_server", None) + if dash_total is not None: + dash_total["dur"] = round((time.time() - dash_total["dur"]) * 1000) + for name, info in timing_information.items(): + value = name + if info.get("desc") is not None: + value += f';desc="{info["desc"]}"' + if info.get("dur") is not None: + value += f";dur={info['dur']}" + # Quart/Werkzeug headers expose 'add' (not 'append') + if hasattr(response.headers, "add"): + response.headers.add("Server-Timing", value) + else: # fallback just in case + response.headers["Server-Timing"] = value + return response def register_error_handlers(self, app): @app.errorhandler(PreventUpdate) @@ -61,49 +100,72 @@ async def wrapped(*args, **kwargs): if inspect.iscoroutine(html_val): # handle async function returning html html_val = await html_val html = str(html_val) - return QuartResponse(html, content_type="text/html") + return Response(html, content_type="text/html") return wrapped def add_url_rule(self, app, rule, view_func, endpoint=None, methods=None): - if rule == "": - rule = "/" - if isinstance(view_func, str) or not inspect.iscoroutinefunction(view_func): - # Wrap plain strings or sync callables in async handler returning HTML - if isinstance(view_func, str) or not inspect.iscoroutinefunction(view_func): - view_func = self._html_response_wrapper(view_func) - app.add_url_rule(rule, endpoint or rule, view_func, methods=methods or ["GET"]) - - # ---- Index & Catchall ------------------------------------------------ + app.add_url_rule( + rule, view_func=view_func, endpoint=endpoint, methods=methods or ["GET"] + ) + + # def add_url_rule(self, app, rule, view_func, endpoint=None, methods=None): + # if rule == "": + # rule = "/" + # if isinstance(view_func, str): + # # Literal HTML content + # view_func = self._html_response_wrapper(view_func) + # elif not inspect.iscoroutinefunction(view_func): + # # Sync function: wrap to make async but preserve Response objects + # original = view_func + + # async def _async_adapter(*args, **kwargs): + # result = original(*args, **kwargs) + # # Pass through existing Response (Quart/Flask style) + # if isinstance(result, Response) or ( + # hasattr(result, "status_code") + # and hasattr(result, "headers") + # and hasattr(result, "get_data") + # ): + # return result + # # If it's bytes or str treat as HTML + # if isinstance(result, (str, bytes)): + # return Response(result, content_type="text/html") + # # Fallback: JSON encode arbitrary python objects + # try: + # import json + + # return Response( + # json.dumps(result), content_type="application/json" + # ) + # except Exception: # pragma: no cover + # return Response(str(result), content_type="text/plain") + + # view_func = _async_adapter + # app.add_url_rule(rule, endpoint or rule, view_func, methods=methods or ["GET"]) + def setup_index(self, app, dash_app): async def index(): adapter = QuartRequestAdapter() set_request_adapter(adapter) - return QuartResponse(dash_app.render_index(), content_type="text/html") + return Response(dash_app.render_index(), content_type="text/html") self.add_url_rule(app, "/", index, endpoint="index", methods=["GET"]) def setup_catchall(self, app, dash_app): - @app.before_serving - async def _enable_dev_tools(): # pragma: no cover - environmental - dash_app.enable_dev_tools(**self.config) - async def catchall(path): adapter = QuartRequestAdapter() set_request_adapter(adapter) - return QuartResponse(dash_app.render_index(), content_type="text/html") + return Response(dash_app.render_index(), content_type="text/html") - # Must be added after other routes self.add_url_rule( app, "/", catchall, endpoint="catchall", methods=["GET"] ) - # ---- Middleware-esque hooks ----------------------------------------- def before_request(self, app, func): app.before_request(func) def after_request(self, app, func): - # Quart after_request expects (response) -> response @app.after_request async def _after(response): if func is not None: @@ -112,19 +174,24 @@ async def _after(response): await result return response - # ---- Running --------------------------------------------------------- def run(self, app, host, port, debug, **kwargs): - self.config = dict({'debug': debug} if debug else {}, **kwargs) - app.run(host=host, port=port, debug=debug, **kwargs) + # Store only dev tools related configuration (exclude server-only kwargs unsupported by Quart) + # Quart's run does NOT accept 'threaded' (Flask-specific). Drop silently (or log) if present. + unsupported = {"threaded", "processes"} + filtered_kwargs = {} + for k, v in kwargs.items(): + if k in unsupported: + continue + filtered_kwargs[k] = v + + # Keep a slim config for potential future use (dev tools already enabled in Dash.run) + self.config = {'debug': debug} + self.config.update({k: v for k, v in filtered_kwargs.items() if k.startswith('dev_tools_')}) + + app.run(host=host, port=port, debug=debug, **filtered_kwargs) - # ---- Responses / JSON ------------------------------------------------ def make_response(self, data, mimetype=None, content_type=None): - headers = {} - if mimetype: - headers["Content-Type"] = mimetype - if content_type: - headers["Content-Type"] = content_type - return QuartResponse(data, headers=headers) + return Response(data, mimetype=mimetype, content_type=content_type) def jsonify(self, obj): return jsonify(obj) @@ -132,7 +199,6 @@ def jsonify(self, obj): def get_request_adapter(self): return QuartRequestAdapter - # ---- Component Suites ------------------------------------------------ def serve_component_suites( self, dash_app, package_name, fingerprinted_path, req ): @@ -152,12 +218,9 @@ def serve_component_suites( headers = {} if has_fingerprint: headers["Cache-Control"] = "public, max-age=31536000" - return QuartResponse(data, content_type=mimetype, headers=headers) - etag = hashlib.md5(data).hexdigest() if data else "" - headers["ETag"] = etag - if req.headers.get("If-None-Match") == etag: - return QuartResponse(None, status=304) - return QuartResponse(data, content_type=mimetype, headers=headers) + return Response(data, content_type=mimetype, headers=headers) + + return Response(data, content_type=mimetype, headers=headers) def setup_component_suites(self, app, dash_app): async def serve(package_name, fingerprinted_path): @@ -172,7 +235,6 @@ async def serve(package_name, fingerprinted_path): methods=["GET"], ) - # ---- Dispatch (Callbacks) ------------------------------------------- def dispatch(self, app, dash_app, use_async=True): # Quart always async async def _dispatch(): adapter = QuartRequestAdapter() @@ -186,13 +248,12 @@ async def _dispatch(): response_data = ctx.run(partial_func) if inspect.iscoroutine(response_data): # if user callback is async response_data = await response_data - return QuartResponse(response_data, content_type="application/json") + return Response(response_data, content_type="application/json") return _dispatch - # ---- Favicon --------------------------------------------------------- def _serve_default_favicon(self): - return QuartResponse( + return Response( pkgutil.get_data("dash", "favicon.ico"), content_type="image/x-icon" ) @@ -235,4 +296,3 @@ def get_origin(): @staticmethod def get_path(): return request.path - From 3e38d4151414bca27449d03acf323b67f958e282 Mon Sep 17 00:00:00 2001 From: BSd3v <82055130+BSd3v@users.noreply.github.com> Date: Fri, 12 Sep 2025 11:15:49 -0400 Subject: [PATCH 027/166] fixing `prune_errors` test --- tests/integration/devtools/test_devtools_error_handling.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/integration/devtools/test_devtools_error_handling.py b/tests/integration/devtools/test_devtools_error_handling.py index 40d5731202..b481ef2fad 100644 --- a/tests/integration/devtools/test_devtools_error_handling.py +++ b/tests/integration/devtools/test_devtools_error_handling.py @@ -109,14 +109,14 @@ def test_dveh006_long_python_errors(dash_duo): assert "in bad_sub" not in error0 # dash and flask part of the traceback ARE included # since we set dev_tools_prune_errors=False - assert "dash.py" in error0 + assert "factory.py" in error0 assert "self.wsgi_app" in error0 error1 = get_error_html(dash_duo, 1) assert "in update_output" in error1 assert "in bad_sub" in error1 assert "ZeroDivisionError" in error1 - assert "dash.py" in error1 + assert "factory.py" in error1 assert "self.wsgi_app" in error1 From 381fb0c135b0f3b7a7106afe307d7e8a5866c65a Mon Sep 17 00:00:00 2001 From: BSd3v <82055130+BSd3v@users.noreply.github.com> Date: Fri, 12 Sep 2025 15:40:34 -0400 Subject: [PATCH 028/166] adjustments for flask api_endpoint declared in callback defs --- dash/dash.py | 26 ++------------------------ dash/server_factories/flask_factory.py | 26 ++++++++++++++++++++++++++ 2 files changed, 28 insertions(+), 24 deletions(-) diff --git a/dash/dash.py b/dash/dash.py index 973f1ec579..bed7ab43a4 100644 --- a/dash/dash.py +++ b/dash/dash.py @@ -793,30 +793,8 @@ def setup_apis(self): ) self.callback_api_paths[k] = _callback.GLOBAL_API_PATHS.pop(k) - def make_parse_body(func): - def _parse_body(): - if flask.request.is_json: - data = flask.request.get_json() - return flask.jsonify(func(**data)) - return flask.jsonify({}) - - return _parse_body - - def make_parse_body_async(func): - async def _parse_body_async(): - if flask.request.is_json: - data = flask.request.get_json() - result = await func(**data) - return flask.jsonify(result) - return flask.jsonify({}) - - return _parse_body_async - - for path, func in self.callback_api_paths.items(): - if asyncio.iscoroutinefunction(func): - self._add_url(path, make_parse_body_async(func), ["POST"]) - else: - self._add_url(path, make_parse_body(func), ["POST"]) + # Delegate to the server factory for route registration + self.server_factory.register_callback_api_routes(self.server, self.callback_api_paths) def _setup_plotlyjs(self): # pylint: disable=import-outside-toplevel diff --git a/dash/server_factories/flask_factory.py b/dash/server_factories/flask_factory.py index 9ba8a5017c..a488a070e1 100644 --- a/dash/server_factories/flask_factory.py +++ b/dash/server_factories/flask_factory.py @@ -5,6 +5,7 @@ import mimetypes import time import flask +import inspect from dash.fingerprint import check_fingerprint from dash import _validate from dash.exceptions import PreventUpdate, InvalidResourceError @@ -200,6 +201,31 @@ def _after_request(response): self.before_request(app, _before_request) self.after_request(app, _after_request) + def register_callback_api_routes(self, app, callback_api_paths): + """ + Register callback API endpoints on the Flask app. + Each key in callback_api_paths is a route, each value is a handler (sync or async). + The view function parses the JSON body and passes it to the handler. + """ + for path, handler in callback_api_paths.items(): + endpoint = f"dash_callback_api_{path}" + route = path if path.startswith("/") else f"/{path}" + methods = ["POST"] + + if inspect.iscoroutinefunction(handler): + async def view_func(*args, handler=handler, **kwargs): + data = flask.request.get_json() + result = await handler(**data) if data else await handler() + return flask.jsonify(result) + else: + def view_func(*args, handler=handler, **kwargs): + data = flask.request.get_json() + result = handler(**data) if data else handler() + return flask.jsonify(result) + + # Flask 2.x+ supports async views natively + app.add_url_rule(route, endpoint=endpoint, view_func=view_func, methods=methods) + class FlaskRequestAdapter: @staticmethod From a27927a2dfde473d5eb2215fbb035a4b6597c44b Mon Sep 17 00:00:00 2001 From: Christian Giessel Date: Fri, 12 Sep 2025 21:49:49 +0200 Subject: [PATCH 029/166] updated QuartRequestAdapter & QuartFactory to latest changes --- dash/server_factories/quart_factory.py | 147 ++++++++----------------- 1 file changed, 47 insertions(+), 100 deletions(-) diff --git a/dash/server_factories/quart_factory.py b/dash/server_factories/quart_factory.py index 75376fcd7a..99c9c2e5a0 100644 --- a/dash/server_factories/quart_factory.py +++ b/dash/server_factories/quart_factory.py @@ -1,15 +1,13 @@ from .base_factory import BaseServerFactory -from quart import Quart, request, Response, jsonify, send_from_directory +from quart import Quart, Request, Response, jsonify, request from dash.exceptions import PreventUpdate, InvalidResourceError from dash.server_factories import set_request_adapter from dash.fingerprint import check_fingerprint from dash import _validate from contextvars import copy_context import inspect -import os import pkgutil import mimetypes -import hashlib import sys import time @@ -26,21 +24,18 @@ def __init__(self) -> None: super().__init__() def __call__(self, server, *args, **kwargs): - # ASGI style (scope, receive, send) or standard call-through handled by BaseServerFactory return super().__call__(server, *args, **kwargs) def create_app(self, name="__main__", config=None): app = Quart(name) if config: for key, value in config.items(): - # Mirror Flask usage of config dict app.config[key] = value return app def register_assets_blueprint( self, app, blueprint_name, assets_url_path, assets_folder ): - # Mirror Flask implementation using a blueprint serving static files from quart import Blueprint bp = Blueprint( @@ -109,58 +104,23 @@ def add_url_rule(self, app, rule, view_func, endpoint=None, methods=None): rule, view_func=view_func, endpoint=endpoint, methods=methods or ["GET"] ) - # def add_url_rule(self, app, rule, view_func, endpoint=None, methods=None): - # if rule == "": - # rule = "/" - # if isinstance(view_func, str): - # # Literal HTML content - # view_func = self._html_response_wrapper(view_func) - # elif not inspect.iscoroutinefunction(view_func): - # # Sync function: wrap to make async but preserve Response objects - # original = view_func - - # async def _async_adapter(*args, **kwargs): - # result = original(*args, **kwargs) - # # Pass through existing Response (Quart/Flask style) - # if isinstance(result, Response) or ( - # hasattr(result, "status_code") - # and hasattr(result, "headers") - # and hasattr(result, "get_data") - # ): - # return result - # # If it's bytes or str treat as HTML - # if isinstance(result, (str, bytes)): - # return Response(result, content_type="text/html") - # # Fallback: JSON encode arbitrary python objects - # try: - # import json - - # return Response( - # json.dumps(result), content_type="application/json" - # ) - # except Exception: # pragma: no cover - # return Response(str(result), content_type="text/plain") - - # view_func = _async_adapter - # app.add_url_rule(rule, endpoint or rule, view_func, methods=methods or ["GET"]) - - def setup_index(self, app, dash_app): + def setup_index(self, dash_app): async def index(): adapter = QuartRequestAdapter() set_request_adapter(adapter) - return Response(dash_app.render_index(), content_type="text/html") + adapter.set_request(request) + return Response(dash_app.index(), content_type="text/html") - self.add_url_rule(app, "/", index, endpoint="index", methods=["GET"]) + dash_app._add_url("", index, methods=["GET"]) - def setup_catchall(self, app, dash_app): - async def catchall(path): + def setup_catchall(self, dash_app): + async def catchall(path): # noqa: ARG001 - path is unused but kept for route signature adapter = QuartRequestAdapter() set_request_adapter(adapter) - return Response(dash_app.render_index(), content_type="text/html") + adapter.set_request(request) + return Response(dash_app.index(), content_type="text/html") - self.add_url_rule( - app, "/", catchall, endpoint="catchall", methods=["GET"] - ) + dash_app._add_url("", catchall, methods=["GET"]) def before_request(self, app, func): app.before_request(func) @@ -175,20 +135,8 @@ async def _after(response): return response def run(self, app, host, port, debug, **kwargs): - # Store only dev tools related configuration (exclude server-only kwargs unsupported by Quart) - # Quart's run does NOT accept 'threaded' (Flask-specific). Drop silently (or log) if present. - unsupported = {"threaded", "processes"} - filtered_kwargs = {} - for k, v in kwargs.items(): - if k in unsupported: - continue - filtered_kwargs[k] = v - - # Keep a slim config for potential future use (dev tools already enabled in Dash.run) - self.config = {'debug': debug} - self.config.update({k: v for k, v in filtered_kwargs.items() if k.startswith('dev_tools_')}) - - app.run(host=host, port=port, debug=debug, **filtered_kwargs) + self.config = {'debug': debug, **kwargs} if debug else kwargs + app.run(host=host, port=port, debug=debug, **kwargs) def make_response(self, data, mimetype=None, content_type=None): return Response(data, mimetype=mimetype, content_type=content_type) @@ -199,9 +147,7 @@ def jsonify(self, obj): def get_request_adapter(self): return QuartRequestAdapter - def serve_component_suites( - self, dash_app, package_name, fingerprinted_path, req - ): + def serve_component_suites(self, dash_app, package_name, fingerprinted_path, req): # noqa: ARG002 unused req preserved for interface parity path_in_pkg, has_fingerprint = check_fingerprint(fingerprinted_path) _validate.validate_js_path(dash_app.registered_paths, package_name, path_in_pkg) extension = "." + path_in_pkg.split(".")[-1] @@ -222,23 +168,22 @@ def serve_component_suites( return Response(data, content_type=mimetype, headers=headers) - def setup_component_suites(self, app, dash_app): + def setup_component_suites(self, dash_app): async def serve(package_name, fingerprinted_path): return self.serve_component_suites( dash_app, package_name, fingerprinted_path, request ) - self.add_url_rule( - app, - "/_dash-component-suites//", + dash_app._add_url( + "_dash-component-suites//", serve, - methods=["GET"], ) def dispatch(self, app, dash_app, use_async=True): # Quart always async async def _dispatch(): adapter = QuartRequestAdapter() set_request_adapter(adapter) + adapter.set_request(request) body = await request.get_json() g = dash_app._initialize_context(body, adapter) func = dash_app._prepare_callback(g, body) @@ -259,40 +204,42 @@ def _serve_default_favicon(self): class QuartRequestAdapter: - """Adapter that normalizes Quart's request API to what Dash expects.""" + def __init__(self) -> None: + self._request = None + + def set_request(self, request: Request) -> None: + self._request = request + + # Accessors (instance-based) + def get_root(self): + return self._request.root_url + + def get_args(self): + return self._request.args - @staticmethod - def get_args(): - return request.args + async def get_json(self): + return await self._request.get_json() - @staticmethod - async def get_json(): - return await request.get_json() + def is_json(self): + return self._request.is_json - @staticmethod - def is_json(): - return request.is_json + def get_cookies(self): + return self._request.cookies - @staticmethod - def get_cookies(): - return request.cookies + def get_headers(self): + return self._request.headers - @staticmethod - def get_headers(): - return request.headers + def get_full_path(self): + return self._request.full_path - @staticmethod - def get_full_path(): - return request.full_path + def get_url(self): + return str(self._request.url) - @staticmethod - def get_remote_addr(): - return request.remote_addr + def get_remote_addr(self): + return self._request.remote_addr - @staticmethod - def get_origin(): - return request.headers.get("Origin") + def get_origin(self): + return self._request.headers.get("origin") - @staticmethod - def get_path(): - return request.path + def get_path(self): + return self._request.path From 1824e110327740252dd944988763e0972371f7e9 Mon Sep 17 00:00:00 2001 From: Christian Giessel Date: Fri, 12 Sep 2025 22:06:43 +0200 Subject: [PATCH 030/166] Removed redundant Response return --- dash/server_factories/quart_factory.py | 1 - 1 file changed, 1 deletion(-) diff --git a/dash/server_factories/quart_factory.py b/dash/server_factories/quart_factory.py index 99c9c2e5a0..53fa1cc469 100644 --- a/dash/server_factories/quart_factory.py +++ b/dash/server_factories/quart_factory.py @@ -164,7 +164,6 @@ def serve_component_suites(self, dash_app, package_name, fingerprinted_path, req headers = {} if has_fingerprint: headers["Cache-Control"] = "public, max-age=31536000" - return Response(data, content_type=mimetype, headers=headers) return Response(data, content_type=mimetype, headers=headers) From b14f6d276f039239725c7554decabd250c0f8975 Mon Sep 17 00:00:00 2001 From: BSd3v <82055130+BSd3v@users.noreply.github.com> Date: Fri, 12 Sep 2025 16:21:18 -0400 Subject: [PATCH 031/166] fix for fastapi `api_endpoint` registering --- dash/server_factories/fastapi_factory.py | 48 ++++++++++++++++++++++-- 1 file changed, 44 insertions(+), 4 deletions(-) diff --git a/dash/server_factories/fastapi_factory.py b/dash/server_factories/fastapi_factory.py index cf08f85d7f..1090c85050 100644 --- a/dash/server_factories/fastapi_factory.py +++ b/dash/server_factories/fastapi_factory.py @@ -14,13 +14,21 @@ from fastapi.staticfiles import StaticFiles from starlette.responses import Response as StarletteResponse from starlette.datastructures import MutableHeaders + from pydantic import create_model + from typing import Any, Optional except ImportError: uvicorn = None - FastAPI = Request = Response = None - JSONResponse = PlainTextResponse = None + FastAPI = None + Request = None + Response = None + JSONResponse = None + PlainTextResponse = None StaticFiles = None StarletteResponse = None MutableHeaders = None + create_model = None + Any = None + Optional = None from dash.fingerprint import check_fingerprint from dash import _validate @@ -109,7 +117,7 @@ async def catchall(request: Request): # pylint: disable=protected-access dash_app._add_url("{path:path}", catchall, methods=["GET"]) - def add_url_rule(self, app, rule, view_func, endpoint=None, methods=None): + def add_url_rule(self, app, rule, view_func, endpoint=None, methods=None, include_in_schema=False): if rule == "": rule = "/" if isinstance(view_func, str): @@ -120,7 +128,7 @@ def add_url_rule(self, app, rule, view_func, endpoint=None, methods=None): view_func, methods=methods or ["GET"], name=endpoint, - include_in_schema=False, + include_in_schema=include_in_schema, ) def before_request(self, app, func): @@ -286,6 +294,38 @@ async def timing_middleware(request, call_next): headers.append("Server-Timing", value) return response + def register_callback_api_routes(self, app, callback_api_paths): + """ + Register callback API endpoints on the FastAPI app. + Each key in callback_api_paths is a route, each value is a handler (sync or async). + Dynamically creates a Pydantic model for the handler's parameters and uses it as the body parameter. + """ + for path, handler in callback_api_paths.items(): + endpoint = f"dash_callback_api_{path}" + route = path if path.startswith("/") else f"/{path}" + methods = ["POST"] + sig = inspect.signature(handler) + param_names = list(sig.parameters.keys()) + fields = {name: (Optional[Any], None) for name in param_names} + Model = create_model(f"Payload_{endpoint}", **fields) + + async def view_func(request: Request, body: Model): + kwargs = body.dict(exclude_unset=True) + if inspect.iscoroutinefunction(handler): + result = await handler(**kwargs) + else: + result = handler(**kwargs) + return JSONResponse(content=result) + + + app.add_api_route( + route, + view_func, + methods=methods, + name=endpoint, + include_in_schema=True, + ) + class FastAPIRequestAdapter: def __init__(self): From 5ef796bf7614bdf5aeb4a5bb3ad61c78f2b4bfb4 Mon Sep 17 00:00:00 2001 From: BSd3v <82055130+BSd3v@users.noreply.github.com> Date: Fri, 12 Sep 2025 17:08:00 -0400 Subject: [PATCH 032/166] shifting from `server_factory` to `backend` --- dash/_callback.py | 2 +- .../quart_factory.py => backend/quart.py} | 30 +- dash/dash.py | 120 +++--- dash/server_factories/__init__.py | 12 - dash/server_factories/base_factory.py | 50 --- dash/server_factories/fastapi_factory.py | 370 ------------------ dash/server_factories/flask_factory.py | 273 ------------- 7 files changed, 103 insertions(+), 754 deletions(-) rename dash/{server_factories/quart_factory.py => backend/quart.py} (86%) delete mode 100644 dash/server_factories/__init__.py delete mode 100644 dash/server_factories/base_factory.py delete mode 100644 dash/server_factories/fastapi_factory.py delete mode 100644 dash/server_factories/flask_factory.py diff --git a/dash/_callback.py b/dash/_callback.py index bca8027fdd..6cc55b9162 100644 --- a/dash/_callback.py +++ b/dash/_callback.py @@ -6,7 +6,7 @@ import asyncio -from dash.server_factories import get_request_adapter +from dash.backend import get_request_adapter from .dependencies import ( handle_callback_args, diff --git a/dash/server_factories/quart_factory.py b/dash/backend/quart.py similarity index 86% rename from dash/server_factories/quart_factory.py rename to dash/backend/quart.py index 53fa1cc469..a2437811a4 100644 --- a/dash/server_factories/quart_factory.py +++ b/dash/backend/quart.py @@ -1,7 +1,7 @@ -from .base_factory import BaseServerFactory +from .base_server import BaseDashServer from quart import Quart, Request, Response, jsonify, request from dash.exceptions import PreventUpdate, InvalidResourceError -from dash.server_factories import set_request_adapter +from dash.backend import set_request_adapter from dash.fingerprint import check_fingerprint from dash import _validate from contextvars import copy_context @@ -12,7 +12,7 @@ import time -class QuartAPIServerFactory(BaseServerFactory): +class QuartDashServer(BaseDashServer): """Quart implementation of the Dash server factory. All Quart/async specific imports are at the top-level (per user request) so @@ -196,6 +196,30 @@ async def _dispatch(): return _dispatch + def register_callback_api_routes(self, app, callback_api_paths): + """ + Register callback API endpoints on the Quart app. + Each key in callback_api_paths is a route, each value is a handler (sync or async). + The view function parses the JSON body and passes it to the handler. + """ + for path, handler in callback_api_paths.items(): + endpoint = f"dash_callback_api_{path}" + route = path if path.startswith("/") else f"/{path}" + methods = ["POST"] + + if inspect.iscoroutinefunction(handler): + async def view_func(*args, handler=handler, **kwargs): + data = await request.get_json() + result = await handler(**data) if data else await handler() + return jsonify(result) + else: + async def view_func(*args, handler=handler, **kwargs): + data = await request.get_json() + result = handler(**data) if data else handler() + return jsonify(result) + + app.add_url_rule(route, endpoint=endpoint, view_func=view_func, methods=methods) + def _serve_default_favicon(self): return Response( pkgutil.get_data("dash", "favicon.ico"), content_type="image/x-icon" diff --git a/dash/dash.py b/dash/dash.py index bed7ab43a4..0e7cbb25fa 100644 --- a/dash/dash.py +++ b/dash/dash.py @@ -39,7 +39,7 @@ ProxyError, DuplicateCallback, ) -from .server_factories import get_request_adapter +from .backend import get_request_adapter, get_backend from .version import __version__ from ._configs import get_combined_config, pathname_configs, pages_folder_config from ._utils import ( @@ -64,8 +64,7 @@ from . import _validate from . import _watch from . import _get_app -from .server_factories.flask_factory import FlaskServerFactory -from .server_factories.base_factory import BaseServerFactory +from .backend.flask import FlaskDashServer from ._get_app import with_app_context, with_app_context_factory from ._grouping import map_grouping, grouping_len, update_args_group @@ -156,6 +155,27 @@ except: # noqa: E722 page_container = None +def _is_flask_instance(obj): + try: + from flask import Flask + return isinstance(obj, Flask) + except ImportError: + return False + +def _is_fastapi_instance(obj): + try: + from fastapi import FastAPI + return isinstance(obj, FastAPI) + except ImportError: + return False + +def _is_quart_instance(obj): + try: + from quart import Quart + return isinstance(obj, Quart) + except ImportError: + return False + def _get_traceback(secret, error: Exception): try: @@ -249,6 +269,12 @@ class Dash(ObsoleteChecker): ``flask.Flask``: use this pre-existing Flask server. :type server: boolean or flask.Flask + :param backend: The backend to use for the Dash app. Can be a string + (name of the backend) or a backend class. Default is None, which + selects the Flask backend. Currently, "flask" and "fastapi" backends + are supported. + :type backend: string or type + :param assets_folder: a path, relative to the current working directory, for extra files to be used in the browser. Default ``'assets'``. All .js and .css files will be loaded immediately unless excluded by @@ -431,6 +457,7 @@ def __init__( # pylint: disable=too-many-statements self, name: Optional[str] = None, server: Union[bool, Callable[[], Any]] = True, + backend: Union[str, type, None] = None, assets_folder: str = "assets", pages_folder: str = "pages", use_pages: Optional[bool] = None, @@ -466,7 +493,6 @@ def __init__( # pylint: disable=too-many-statements description: Optional[str] = None, on_error: Optional[Callable[[Exception], Any]] = None, use_async: Optional[bool] = None, - server_factory: Optional[BaseServerFactory] = None, **obsolete, ): @@ -489,29 +515,33 @@ def __init__( # pylint: disable=too-many-statements caller_name: str = name if name is not None else get_caller_name() - self.server_factory = server_factory or FlaskServerFactory() - - # We have 3 cases: server is either True (we create the server), False - # (defer server creation) or a Flask app instance (we use their server) - if callable(server) and not ( - hasattr(server, "route") and hasattr(server, "run") - ): - # Server factory function - self.server = server() - if name is None: - caller_name = getattr(self.server, "name", caller_name) - elif hasattr(server, "route") and hasattr(server, "run"): + # Determine backend + if backend is None: + backend_cls = FlaskDashServer + elif isinstance(backend, str): + backend_cls = get_backend(backend) + elif isinstance(backend, type): + backend_cls = backend + else: + raise ValueError("Invalid backend argument") + + # Determine server and backend instance + if server is not None and server is not True and server is not False: + # User provided a server instance (e.g., Flask, Quart, FastAPI) + if _is_flask_instance(server): + backend_cls = get_backend("flask") + elif _is_quart_instance(server): + backend_cls = get_backend("quart") + elif _is_fastapi_instance(server): + backend_cls = get_backend("fastapi") + else: + raise ValueError("Unsupported server type") + self.backend = backend_cls() self.server = server - if name is None: - caller_name = getattr(server, "name", caller_name) - elif isinstance(server, bool): - self.server = ( - self.server_factory.create_app(caller_name) if server else None - ) else: - raise ValueError( - "server must be a Flask app, a boolean, or a server factory function" - ) + # No server instance provided, create backend and let backend create server + self.backend = backend_cls() + self.server = server base_prefix, routes_prefix, requests_prefix = pathname_configs( url_base_pathname, routes_pathname_prefix, requests_pathname_prefix @@ -700,7 +730,7 @@ def init_app(self, app: Optional[Any] = None, **kwargs) -> None: self.server = app bp_prefix = config.routes_pathname_prefix.replace("/", "_").replace(".", "_") assets_blueprint_name = f"{bp_prefix}dash_assets" - self.server_factory.register_assets_blueprint( + self.backend.register_assets_blueprint( self.server, assets_blueprint_name, config.routes_pathname_prefix + self.config.assets_url_path.lstrip("/"), @@ -723,8 +753,8 @@ def init_app(self, app: Optional[Any] = None, **kwargs) -> None: raise ImportError( "To use the compress option, you need to install dash[compress]" ) from error - self.server_factory.register_error_handlers(self.server) - self.server_factory.before_request(self.server, self._setup_server) + self.backend.register_error_handlers(self.server) + self.backend.before_request(self.server, self._setup_server) self._setup_routes() _get_app.APP = self self.enable_pages() @@ -732,7 +762,7 @@ def init_app(self, app: Optional[Any] = None, **kwargs) -> None: def _add_url(self, name: str, view_func: RouteCallable, methods=("GET",)) -> None: full_name = self.config.routes_pathname_prefix + name - self.server_factory.add_url_rule( + self.backend.add_url_rule( self.server, full_name, view_func=view_func, @@ -742,21 +772,21 @@ def _add_url(self, name: str, view_func: RouteCallable, methods=("GET",)) -> Non self.routes.append(full_name) def _setup_routes(self): - self.server_factory.setup_component_suites(self) + self.backend.setup_component_suites(self) self._add_url("_dash-layout", self.serve_layout) self._add_url("_dash-dependencies", self.dependencies) self._add_url( "_dash-update-component", - self.server_factory.dispatch(self.server, self, self._use_async), + self.backend.dispatch(self.server, self, self._use_async), ["POST"], ) self._add_url("_reload-hash", self.serve_reload_hash) self._add_url( "_favicon.ico", - self.server_factory._serve_default_favicon, # pylint: disable=protected-access + self.backend._serve_default_favicon, # pylint: disable=protected-access ) - self.server_factory.setup_index(self) - self.server_factory.setup_catchall(self) + self.backend.setup_index(self) + self.backend.setup_catchall(self) if jupyter_dash.active: self._add_url( @@ -794,7 +824,7 @@ def setup_apis(self): self.callback_api_paths[k] = _callback.GLOBAL_API_PATHS.pop(k) # Delegate to the server factory for route registration - self.server_factory.register_callback_api_routes(self.server, self.callback_api_paths) + self.backend.register_callback_api_routes(self.server, self.callback_api_paths) def _setup_plotlyjs(self): # pylint: disable=import-outside-toplevel @@ -866,7 +896,7 @@ def serve_layout(self): layout = hook(layout) # TODO - Set browser cache limit - pass hash into frontend - return self.server_factory.make_response( + return self.backend.make_response( to_json(layout), mimetype="application/json", ) @@ -930,7 +960,7 @@ def serve_reload_hash(self): _reload.hard = False _reload.changed_assets = [] - return self.server_factory.jsonify( + return self.backend.jsonify( { "reloadHash": _hash, "hard": hard, @@ -1241,7 +1271,7 @@ def interpolate_index(self, **kwargs): @with_app_context def dependencies(self): - return self.server_factory.make_response( + return self.backend.make_response( to_json(self._callback_list), content_type="application/json", ) @@ -1360,7 +1390,7 @@ def _initialize_context(self, body, adapter): {"prop_id": x, "value": g.input_values.get(x)} for x in body.get("changedPropIds", []) ] - g.dash_response = self.server_factory.make_response( + g.dash_response = self.backend.make_response( mimetype="application/json", data=None ) g.cookies = dict(adapter.get_cookies()) @@ -1736,7 +1766,7 @@ def display_content(path): For nested URLs, slashes are still included: `app.strip_relative_path('/page-1/sub-page-1/')` will return - `page-1/sub-page-1` + `page-1/sub-page-1 ``` """ return _get_paths.app_strip_relative_path( @@ -1993,12 +2023,12 @@ def enable_dev_tools( ) elif dev_tools.prune_errors: secret = gen_salt(20) - self.server_factory.register_prune_error_handler( + self.backend.register_prune_error_handler( self.server, secret, _get_traceback ) if debug and dev_tools.ui: - self.server_factory.register_timing_hooks(self.server, first_run) + self.backend.register_timing_hooks(self.server, first_run) if ( debug @@ -2282,7 +2312,7 @@ def verify_url_part(served_part, url_part, part_name): server_url=jupyter_server_url, ) else: - self.server_factory.run( + self.backend.run( self.server, host=host, port=port, debug=debug, **flask_run_options ) @@ -2447,7 +2477,7 @@ def update(pathname_, search_, **states): Input(_ID_STORE, "data"), ) - self.server_factory.before_request(self.server, router) + self.backend.before_request(self.server, router) def __call__(self, *args, **kwargs): - return self.server_factory.__call__(self.server, *args, **kwargs) + return self.backend.__call__(self.server, *args, **kwargs) diff --git a/dash/server_factories/__init__.py b/dash/server_factories/__init__.py deleted file mode 100644 index 1bfd497935..0000000000 --- a/dash/server_factories/__init__.py +++ /dev/null @@ -1,12 +0,0 @@ -# python -import contextvars - -_request_adapter_var = contextvars.ContextVar("request_adapter") - - -def set_request_adapter(adapter): - _request_adapter_var.set(adapter) - - -def get_request_adapter(): - return _request_adapter_var.get() diff --git a/dash/server_factories/base_factory.py b/dash/server_factories/base_factory.py deleted file mode 100644 index 12088947c2..0000000000 --- a/dash/server_factories/base_factory.py +++ /dev/null @@ -1,50 +0,0 @@ -from abc import ABC, abstractmethod -from typing import Any - - -class BaseServerFactory(ABC): - def __call__(self, server, *args, **kwargs) -> Any: - # Default: WSGI - return server(*args, **kwargs) - - @abstractmethod - def create_app(self, name: str = "__main__", config=None) -> Any: # pragma: no cover - interface - pass - - @abstractmethod - def register_assets_blueprint( - self, app, blueprint_name: str, assets_url_path: str, assets_folder: str - ) -> None: # pragma: no cover - interface - pass - - @abstractmethod - def register_error_handlers(self, app) -> None: # pragma: no cover - interface - pass - - @abstractmethod - def add_url_rule(self, app, rule: str, view_func, endpoint=None, methods=None) -> None: # pragma: no cover - interface - pass - - @abstractmethod - def before_request(self, app, func) -> None: # pragma: no cover - interface - pass - - @abstractmethod - def after_request(self, app, func) -> None: # pragma: no cover - interface - pass - - @abstractmethod - def run(self, app, host: str, port: int, debug: bool, **kwargs) -> None: # pragma: no cover - interface - pass - - @abstractmethod - def make_response(self, data, mimetype=None, content_type=None) -> Any: # pragma: no cover - interface - pass - - @abstractmethod - def jsonify(self, obj) -> Any: # pragma: no cover - interface - pass - - @abstractmethod - def get_request_adapter(self) -> Any: # pragma: no cover - interface - pass diff --git a/dash/server_factories/fastapi_factory.py b/dash/server_factories/fastapi_factory.py deleted file mode 100644 index 1090c85050..0000000000 --- a/dash/server_factories/fastapi_factory.py +++ /dev/null @@ -1,370 +0,0 @@ -import sys -import mimetypes -import hashlib -import inspect -import pkgutil -from contextvars import copy_context -import importlib.util -import time - -try: - import uvicorn - from fastapi import FastAPI, Request, Response - from fastapi.responses import JSONResponse, PlainTextResponse - from fastapi.staticfiles import StaticFiles - from starlette.responses import Response as StarletteResponse - from starlette.datastructures import MutableHeaders - from pydantic import create_model - from typing import Any, Optional -except ImportError: - uvicorn = None - FastAPI = None - Request = None - Response = None - JSONResponse = None - PlainTextResponse = None - StaticFiles = None - StarletteResponse = None - MutableHeaders = None - create_model = None - Any = None - Optional = None - -from dash.fingerprint import check_fingerprint -from dash import _validate -from dash.exceptions import PreventUpdate, InvalidResourceError -from dash.server_factories import set_request_adapter -from .base_factory import BaseServerFactory - - -class FastAPIServerFactory(BaseServerFactory): - def __init__(self): - self.config = {} - super().__init__() - - def __call__(self, server, *args, **kwargs): - # ASGI: (scope, receive, send) - if len(args) == 3 and isinstance(args[0], dict) and "type" in args[0]: - return server(*args, **kwargs) - raise TypeError("FastAPI app must be called with (scope, receive, send)") - - def create_app(self, name="__main__", config=None): - app = FastAPI() - if config: - for key, value in config.items(): - setattr(app.state, key, value) - return app - - def register_assets_blueprint( - self, app, blueprint_name, assets_url_path, assets_folder - ): - try: - app.mount( - assets_url_path, - StaticFiles(directory=assets_folder), - name=blueprint_name, - ) - except RuntimeError: - # directory doesnt exist - pass - - def register_error_handlers(self, app): - @app.exception_handler(PreventUpdate) - async def _handle_error(_request, _exc): - return Response(status_code=204) - - @app.exception_handler(InvalidResourceError) - async def _invalid_resources_handler(_request, exc): - return Response(content=exc.args[0], status_code=404) - - def register_prune_error_handler(self, app, secret, get_traceback_func): - @app.exception_handler(Exception) - async def _wrap_errors(_error_request, error): - tb = get_traceback_func(secret, error) - return PlainTextResponse(tb, status_code=500) - - def _html_response_wrapper(self, view_func): - async def wrapped(*_args, **_kwargs): - # If view_func is a function, call it; if it's a string, use it directly - html = view_func() if callable(view_func) else view_func - return Response(content=html, media_type="text/html") - - return wrapped - - def setup_index(self, dash_app): - async def index(request: Request): - adapter = FastAPIRequestAdapter() - set_request_adapter(adapter) - adapter.set_request(request) - return Response(content=dash_app.index(), media_type="text/html") - - # pylint: disable=protected-access - dash_app._add_url("", index, methods=["GET"]) - - def setup_catchall(self, dash_app): - @dash_app.server.on_event("startup") - def _setup_catchall(): - dash_app.enable_dev_tools( - **self.config, first_run=False - ) # do this to make sure dev tools are enabled - - async def catchall(request: Request): - adapter = FastAPIRequestAdapter() - set_request_adapter(adapter) - adapter.set_request(request) - return Response(content=dash_app.index(), media_type="text/html") - - # pylint: disable=protected-access - dash_app._add_url("{path:path}", catchall, methods=["GET"]) - - def add_url_rule(self, app, rule, view_func, endpoint=None, methods=None, include_in_schema=False): - if rule == "": - rule = "/" - if isinstance(view_func, str): - # Wrap string or sync function to async FastAPI handler - view_func = self._html_response_wrapper(view_func) - app.add_api_route( - rule, - view_func, - methods=methods or ["GET"], - name=endpoint, - include_in_schema=include_in_schema, - ) - - def before_request(self, app, func): - # FastAPI does not have before_request, but we can use middleware - app.middleware("http")(self._make_before_middleware(func)) - - def after_request(self, app, func): - # FastAPI does not have after_request, but we can use middleware - app.middleware("http")(self._make_after_middleware(func)) - - def run(self, app, host, port, debug, **kwargs): - frame = inspect.stack()[2] - self.config = dict({"debug": debug} if debug else {}, **kwargs) - reload = debug - if reload: - # Dynamically determine the module name from the file path - file_path = frame.filename - module_name = importlib.util.spec_from_file_location("app", file_path).name - uvicorn.run( - f"{module_name}:app.server", - host=host, - port=port, - reload=reload, - **kwargs, - ) - else: - uvicorn.run(app, host=host, port=port, reload=reload, **kwargs) - - def make_response(self, data, mimetype=None, content_type=None): - headers = {} - if mimetype: - headers["content-type"] = mimetype - if content_type: - headers["content-type"] = content_type - return Response(content=data, headers=headers) - - def jsonify(self, obj): - return JSONResponse(content=obj) - - def get_request_adapter(self): - return FastAPIRequestAdapter - - def _make_before_middleware(self, func): - async def middleware(request, call_next): - if func is not None: - if inspect.iscoroutinefunction(func): - await func() - else: - func() - response = await call_next(request) - return response - - return middleware - - def _make_after_middleware(self, func): - async def middleware(request, call_next): - response = await call_next(request) - if func is not None: - if inspect.iscoroutinefunction(func): - await func() - else: - func() - return response - - return middleware - - def serve_component_suites( - self, dash_app, package_name, fingerprinted_path, request - ): - path_in_pkg, has_fingerprint = check_fingerprint(fingerprinted_path) - _validate.validate_js_path(dash_app.registered_paths, package_name, path_in_pkg) - extension = "." + path_in_pkg.split(".")[-1] - mimetype = mimetypes.types_map.get(extension, "application/octet-stream") - package = sys.modules[package_name] - dash_app.logger.debug( - "serving -- package: %s[%s] resource: %s => location: %s", - package_name, - package.__version__, - path_in_pkg, - package.__path__, - ) - data = pkgutil.get_data(package_name, path_in_pkg) - headers = {} - if has_fingerprint: - headers["Cache-Control"] = "public, max-age=31536000" - return StarletteResponse(content=data, media_type=mimetype, headers=headers) - etag = hashlib.md5(data).hexdigest() if data else "" - headers["ETag"] = etag - if request.headers.get("if-none-match") == etag: - return StarletteResponse(status_code=304) - return StarletteResponse(content=data, media_type=mimetype, headers=headers) - - def setup_component_suites(self, dash_app): - async def serve(request: Request, package_name: str, fingerprinted_path: str): - return self.serve_component_suites( - dash_app, package_name, fingerprinted_path, request - ) - - # pylint: disable=protected-access - dash_app._add_url( - "_dash-component-suites//", - serve, - ) - - # pylint: disable=unused-argument - def dispatch(self, app, dash_app, use_async=False): - async def _dispatch(request: Request): - adapter = FastAPIRequestAdapter() - set_request_adapter(adapter) - adapter.set_request(request) - # pylint: disable=protected-access - body = await request.json() - g = dash_app._initialize_context( - body, adapter - ) # pylint: disable=protected-access - func = dash_app._prepare_callback( - g, body - ) # pylint: disable=protected-access - args = dash_app._inputs_to_vals( - g.inputs_list + g.states_list - ) # pylint: disable=protected-access - ctx = copy_context() - partial_func = dash_app._execute_callback( - func, args, g.outputs_list, g - ) # pylint: disable=protected-access - response_data = ctx.run(partial_func) - if inspect.iscoroutine(response_data): - response_data = await response_data - # Instead of set_data, return a new Response - return Response(content=response_data, media_type="application/json") - - return _dispatch - - def _serve_default_favicon(self): - return Response( - content=pkgutil.get_data("dash", "favicon.ico"), media_type="image/x-icon" - ) - - def register_timing_hooks(self, app, first_run): - if not first_run: - return - - @app.middleware("http") - async def timing_middleware(request, call_next): - # Before request - request.state.timing_information = { - "__dash_server": {"dur": time.time(), "desc": None} - } - response = await call_next(request) - # After request - timing_information = getattr(request.state, "timing_information", None) - if timing_information is not None: - dash_total = timing_information.get("__dash_server", None) - if dash_total is not None: - dash_total["dur"] = round((time.time() - dash_total["dur"]) * 1000) - headers = MutableHeaders(response.headers) - for name, info in timing_information.items(): - value = name - if info.get("desc") is not None: - value += f';desc="{info["desc"]}"' - if info.get("dur") is not None: - value += f";dur={info['dur']}" - headers.append("Server-Timing", value) - return response - - def register_callback_api_routes(self, app, callback_api_paths): - """ - Register callback API endpoints on the FastAPI app. - Each key in callback_api_paths is a route, each value is a handler (sync or async). - Dynamically creates a Pydantic model for the handler's parameters and uses it as the body parameter. - """ - for path, handler in callback_api_paths.items(): - endpoint = f"dash_callback_api_{path}" - route = path if path.startswith("/") else f"/{path}" - methods = ["POST"] - sig = inspect.signature(handler) - param_names = list(sig.parameters.keys()) - fields = {name: (Optional[Any], None) for name in param_names} - Model = create_model(f"Payload_{endpoint}", **fields) - - async def view_func(request: Request, body: Model): - kwargs = body.dict(exclude_unset=True) - if inspect.iscoroutinefunction(handler): - result = await handler(**kwargs) - else: - result = handler(**kwargs) - return JSONResponse(content=result) - - - app.add_api_route( - route, - view_func, - methods=methods, - name=endpoint, - include_in_schema=True, - ) - - -class FastAPIRequestAdapter: - def __init__(self): - self._request = None - - def set_request(self, request: Request): - self._request = request - - def get_root(self): - return str(self._request.base_url) - - def get_args(self): - return self._request.query_params - - async def get_json(self): - return await self._request.json() - - def is_json(self): - return self._request.headers.get("content-type", "").startswith( - "application/json" - ) - - def get_cookies(self, _request=None): - return self._request.cookies - - def get_headers(self): - return self._request.headers - - def get_full_path(self): - return str(self._request.url) - - def get_url(self): - return str(self._request.url) - - def get_remote_addr(self): - return self._request.client.host if self._request.client else None - - def get_origin(self): - return self._request.headers.get("origin") - - def get_path(self): - return self._request.url.path diff --git a/dash/server_factories/flask_factory.py b/dash/server_factories/flask_factory.py deleted file mode 100644 index a488a070e1..0000000000 --- a/dash/server_factories/flask_factory.py +++ /dev/null @@ -1,273 +0,0 @@ -from contextvars import copy_context -import asyncio -import pkgutil -import sys -import mimetypes -import time -import flask -import inspect -from dash.fingerprint import check_fingerprint -from dash import _validate -from dash.exceptions import PreventUpdate, InvalidResourceError -from dash.server_factories import set_request_adapter -from .base_factory import BaseServerFactory - - -class FlaskServerFactory(BaseServerFactory): - def __call__(self, server, *args, **kwargs): - # Always WSGI - return server(*args, **kwargs) - - def create_app(self, name="__main__", config=None): - app = flask.Flask(name) - if config: - app.config.update(config) - return app - - def register_assets_blueprint( - self, app, blueprint_name, assets_url_path, assets_folder - ): - bp = flask.Blueprint( - blueprint_name, - __name__, - static_folder=assets_folder, - static_url_path=assets_url_path, - ) - app.register_blueprint(bp) - - def register_error_handlers(self, app): - @app.errorhandler(PreventUpdate) - def _handle_error(_): - return "", 204 - - @app.errorhandler(InvalidResourceError) - def _invalid_resources_handler(err): - return err.args[0], 404 - - def register_prune_error_handler(self, app, secret, get_traceback_func): - @app.errorhandler(Exception) - def _wrap_errors(error): - tb = get_traceback_func(secret, error) - return tb, 500 - - def add_url_rule(self, app, rule, view_func, endpoint=None, methods=None): - app.add_url_rule( - rule, view_func=view_func, endpoint=endpoint, methods=methods or ["GET"] - ) - - def before_request(self, app, func): - app.before_request(func) - - def after_request(self, app, func): - app.after_request(func) - - def run(self, app, host, port, debug, **kwargs): - app.run(host=host, port=port, debug=debug, **kwargs) - - def make_response(self, data, mimetype=None, content_type=None): - return flask.Response(data, mimetype=mimetype, content_type=content_type) - - def jsonify(self, obj): - return flask.jsonify(obj) - - def get_request_adapter(self): - return FlaskRequestAdapter - - def setup_catchall(self, dash_app): - def catchall(*args, **kwargs): - adapter = FlaskRequestAdapter() - set_request_adapter(adapter) - return dash_app.index(*args, **kwargs) - - # pylint: disable=protected-access - dash_app._add_url("", catchall, methods=["GET"]) - - def setup_index(self, dash_app): - def index(*args, **kwargs): - adapter = FlaskRequestAdapter() - set_request_adapter(adapter) - return dash_app.index(*args, **kwargs) - - # pylint: disable=protected-access - dash_app._add_url("", index, methods=["GET"]) - - def serve_component_suites(self, dash_app, package_name, fingerprinted_path): - path_in_pkg, has_fingerprint = check_fingerprint(fingerprinted_path) - _validate.validate_js_path(dash_app.registered_paths, package_name, path_in_pkg) - extension = "." + path_in_pkg.split(".")[-1] - mimetype = mimetypes.types_map.get(extension, "application/octet-stream") - package = sys.modules[package_name] - dash_app.logger.debug( - "serving -- package: %s[%s] resource: %s => location: %s", - package_name, - package.__version__, - path_in_pkg, - package.__path__, - ) - data = pkgutil.get_data(package_name, path_in_pkg) - response = flask.Response(data, mimetype=mimetype) - if has_fingerprint: - response.cache_control.max_age = 31536000 # 1 year - else: - response.add_etag() - tag = response.get_etag()[0] - request_etag = flask.request.headers.get("If-None-Match") - if f'"{tag}"' == request_etag: - response = flask.Response(None, status=304) - return response - - def setup_component_suites(self, dash_app): - def serve(package_name, fingerprinted_path): - return self.serve_component_suites( - dash_app, package_name, fingerprinted_path - ) - - # pylint: disable=protected-access - dash_app._add_url( - "_dash-component-suites//", - serve, - ) - - # pylint: disable=unused-argument - def dispatch(self, app, dash_app, use_async=False): - def _dispatch(): - adapter = FlaskRequestAdapter() - set_request_adapter(adapter) - body = flask.request.get_json() - # pylint: disable=protected-access - g = dash_app._initialize_context(body, adapter) - func = dash_app._prepare_callback(g, body) - args = dash_app._inputs_to_vals(g.inputs_list + g.states_list) - ctx = copy_context() - partial_func = dash_app._execute_callback(func, args, g.outputs_list, g) - response_data = ctx.run(partial_func) - if asyncio.iscoroutine(response_data): - raise Exception( - "You are trying to use a coroutine without dash[async]. " - "Please install the dependencies via `pip install dash[async]` and ensure " - "that `use_async=False` is not being passed to the app." - ) - g.dash_response.set_data(response_data) - return g.dash_response - - async def _dispatch_async(): - adapter = FlaskRequestAdapter() - set_request_adapter(adapter) - body = flask.request.get_json() - # pylint: disable=protected-access - g = dash_app._initialize_context(body, adapter) - func = dash_app._prepare_callback(g, body) - args = dash_app._inputs_to_vals(g.inputs_list + g.states_list) - ctx = copy_context() - partial_func = dash_app._execute_callback(func, args, g.outputs_list, g) - response_data = ctx.run(partial_func) - if asyncio.iscoroutine(response_data): - response_data = await response_data - g.dash_response.set_data(response_data) - return g.dash_response - - if use_async: - return _dispatch_async - return _dispatch - - def _serve_default_favicon(self): - - return flask.Response( - pkgutil.get_data("dash", "favicon.ico"), content_type="image/x-icon" - ) - - def register_timing_hooks(self, app, _first_run): - def _before_request(): - flask.g.timing_information = { - "__dash_server": {"dur": time.time(), "desc": None} - } - - def _after_request(response): - timing_information = flask.g.get("timing_information", None) - if timing_information is None: - return response - dash_total = timing_information.get("__dash_server", None) - if dash_total is not None: - dash_total["dur"] = round((time.time() - dash_total["dur"]) * 1000) - for name, info in timing_information.items(): - value = name - if info.get("desc") is not None: - value += f';desc="{info["desc"]}"' - if info.get("dur") is not None: - value += f";dur={info['dur']}" - response.headers.add("Server-Timing", value) - return response - - self.before_request(app, _before_request) - self.after_request(app, _after_request) - - def register_callback_api_routes(self, app, callback_api_paths): - """ - Register callback API endpoints on the Flask app. - Each key in callback_api_paths is a route, each value is a handler (sync or async). - The view function parses the JSON body and passes it to the handler. - """ - for path, handler in callback_api_paths.items(): - endpoint = f"dash_callback_api_{path}" - route = path if path.startswith("/") else f"/{path}" - methods = ["POST"] - - if inspect.iscoroutinefunction(handler): - async def view_func(*args, handler=handler, **kwargs): - data = flask.request.get_json() - result = await handler(**data) if data else await handler() - return flask.jsonify(result) - else: - def view_func(*args, handler=handler, **kwargs): - data = flask.request.get_json() - result = handler(**data) if data else handler() - return flask.jsonify(result) - - # Flask 2.x+ supports async views natively - app.add_url_rule(route, endpoint=endpoint, view_func=view_func, methods=methods) - - -class FlaskRequestAdapter: - @staticmethod - def get_args(): - return flask.request.args - - @staticmethod - def get_root(): - return flask.request.url_root - - @staticmethod - def get_json(): - return flask.request.get_json() - - @staticmethod - def is_json(): - return flask.request.is_json - - @staticmethod - def get_cookies(): - return flask.request.cookies - - @staticmethod - def get_headers(): - return flask.request.headers - - @staticmethod - def get_url(): - return flask.request.url - - @staticmethod - def get_full_path(): - return flask.request.full_path - - @staticmethod - def get_remote_addr(): - return flask.request.remote_addr - - @staticmethod - def get_origin(): - return getattr(flask.request, "origin", None) - - @staticmethod - def get_path(): - return flask.request.path From a4ca566d6810cce00ced20ea5d1b975c39cdc36a Mon Sep 17 00:00:00 2001 From: BSd3v <82055130+BSd3v@users.noreply.github.com> Date: Fri, 12 Sep 2025 17:13:23 -0400 Subject: [PATCH 033/166] adding missing files --- dash/backend/__init__.py | 13 ++ dash/backend/base_server.py | 50 +++++ dash/backend/fastapi.py | 370 ++++++++++++++++++++++++++++++++++++ dash/backend/flask.py | 273 ++++++++++++++++++++++++++ dash/backend/registry.py | 22 +++ 5 files changed, 728 insertions(+) create mode 100644 dash/backend/__init__.py create mode 100644 dash/backend/base_server.py create mode 100644 dash/backend/fastapi.py create mode 100644 dash/backend/flask.py create mode 100644 dash/backend/registry.py diff --git a/dash/backend/__init__.py b/dash/backend/__init__.py new file mode 100644 index 0000000000..497f2dca2d --- /dev/null +++ b/dash/backend/__init__.py @@ -0,0 +1,13 @@ +# python +import contextvars +from .registry import * + +_request_adapter_var = contextvars.ContextVar("request_adapter") + + +def set_request_adapter(adapter): + _request_adapter_var.set(adapter) + + +def get_request_adapter(): + return _request_adapter_var.get() diff --git a/dash/backend/base_server.py b/dash/backend/base_server.py new file mode 100644 index 0000000000..8c902f4248 --- /dev/null +++ b/dash/backend/base_server.py @@ -0,0 +1,50 @@ +from abc import ABC, abstractmethod +from typing import Any + + +class BaseDashServer(ABC): + def __call__(self, server, *args, **kwargs) -> Any: + # Default: WSGI + return server(*args, **kwargs) + + @abstractmethod + def create_app(self, name: str = "__main__", config=None) -> Any: # pragma: no cover - interface + pass + + @abstractmethod + def register_assets_blueprint( + self, app, blueprint_name: str, assets_url_path: str, assets_folder: str + ) -> None: # pragma: no cover - interface + pass + + @abstractmethod + def register_error_handlers(self, app) -> None: # pragma: no cover - interface + pass + + @abstractmethod + def add_url_rule(self, app, rule: str, view_func, endpoint=None, methods=None) -> None: # pragma: no cover - interface + pass + + @abstractmethod + def before_request(self, app, func) -> None: # pragma: no cover - interface + pass + + @abstractmethod + def after_request(self, app, func) -> None: # pragma: no cover - interface + pass + + @abstractmethod + def run(self, app, host: str, port: int, debug: bool, **kwargs) -> None: # pragma: no cover - interface + pass + + @abstractmethod + def make_response(self, data, mimetype=None, content_type=None) -> Any: # pragma: no cover - interface + pass + + @abstractmethod + def jsonify(self, obj) -> Any: # pragma: no cover - interface + pass + + @abstractmethod + def get_request_adapter(self) -> Any: # pragma: no cover - interface + pass diff --git a/dash/backend/fastapi.py b/dash/backend/fastapi.py new file mode 100644 index 0000000000..b2feeb446f --- /dev/null +++ b/dash/backend/fastapi.py @@ -0,0 +1,370 @@ +import sys +import mimetypes +import hashlib +import inspect +import pkgutil +from contextvars import copy_context +import importlib.util +import time + +try: + import uvicorn + from fastapi import FastAPI, Request, Response + from fastapi.responses import JSONResponse, PlainTextResponse + from fastapi.staticfiles import StaticFiles + from starlette.responses import Response as StarletteResponse + from starlette.datastructures import MutableHeaders + from pydantic import create_model + from typing import Any, Optional +except ImportError: + uvicorn = None + FastAPI = None + Request = None + Response = None + JSONResponse = None + PlainTextResponse = None + StaticFiles = None + StarletteResponse = None + MutableHeaders = None + create_model = None + Any = None + Optional = None + +from dash.fingerprint import check_fingerprint +from dash import _validate +from dash.exceptions import PreventUpdate, InvalidResourceError +from dash.backend import set_request_adapter +from .base_server import BaseDashServer + + +class FastAPIDashServer(BaseDashServer): + def __init__(self): + self.config = {} + super().__init__() + + def __call__(self, server, *args, **kwargs): + # ASGI: (scope, receive, send) + if len(args) == 3 and isinstance(args[0], dict) and "type" in args[0]: + return server(*args, **kwargs) + raise TypeError("FastAPI app must be called with (scope, receive, send)") + + def create_app(self, name="__main__", config=None): + app = FastAPI() + if config: + for key, value in config.items(): + setattr(app.state, key, value) + return app + + def register_assets_blueprint( + self, app, blueprint_name, assets_url_path, assets_folder + ): + try: + app.mount( + assets_url_path, + StaticFiles(directory=assets_folder), + name=blueprint_name, + ) + except RuntimeError: + # directory doesnt exist + pass + + def register_error_handlers(self, app): + @app.exception_handler(PreventUpdate) + async def _handle_error(_request, _exc): + return Response(status_code=204) + + @app.exception_handler(InvalidResourceError) + async def _invalid_resources_handler(_request, exc): + return Response(content=exc.args[0], status_code=404) + + def register_prune_error_handler(self, app, secret, get_traceback_func): + @app.exception_handler(Exception) + async def _wrap_errors(_error_request, error): + tb = get_traceback_func(secret, error) + return PlainTextResponse(tb, status_code=500) + + def _html_response_wrapper(self, view_func): + async def wrapped(*_args, **_kwargs): + # If view_func is a function, call it; if it's a string, use it directly + html = view_func() if callable(view_func) else view_func + return Response(content=html, media_type="text/html") + + return wrapped + + def setup_index(self, dash_app): + async def index(request: Request): + adapter = FastAPIRequestAdapter() + set_request_adapter(adapter) + adapter.set_request(request) + return Response(content=dash_app.index(), media_type="text/html") + + # pylint: disable=protected-access + dash_app._add_url("", index, methods=["GET"]) + + def setup_catchall(self, dash_app): + @dash_app.server.on_event("startup") + def _setup_catchall(): + dash_app.enable_dev_tools( + **self.config, first_run=False + ) # do this to make sure dev tools are enabled + + async def catchall(request: Request): + adapter = FastAPIRequestAdapter() + set_request_adapter(adapter) + adapter.set_request(request) + return Response(content=dash_app.index(), media_type="text/html") + + # pylint: disable=protected-access + dash_app._add_url("{path:path}", catchall, methods=["GET"]) + + def add_url_rule(self, app, rule, view_func, endpoint=None, methods=None, include_in_schema=False): + if rule == "": + rule = "/" + if isinstance(view_func, str): + # Wrap string or sync function to async FastAPI handler + view_func = self._html_response_wrapper(view_func) + app.add_api_route( + rule, + view_func, + methods=methods or ["GET"], + name=endpoint, + include_in_schema=include_in_schema, + ) + + def before_request(self, app, func): + # FastAPI does not have before_request, but we can use middleware + app.middleware("http")(self._make_before_middleware(func)) + + def after_request(self, app, func): + # FastAPI does not have after_request, but we can use middleware + app.middleware("http")(self._make_after_middleware(func)) + + def run(self, app, host, port, debug, **kwargs): + frame = inspect.stack()[2] + self.config = dict({"debug": debug} if debug else {}, **kwargs) + reload = debug + if reload: + # Dynamically determine the module name from the file path + file_path = frame.filename + module_name = importlib.util.spec_from_file_location("app", file_path).name + uvicorn.run( + f"{module_name}:app.server", + host=host, + port=port, + reload=reload, + **kwargs, + ) + else: + uvicorn.run(app, host=host, port=port, reload=reload, **kwargs) + + def make_response(self, data, mimetype=None, content_type=None): + headers = {} + if mimetype: + headers["content-type"] = mimetype + if content_type: + headers["content-type"] = content_type + return Response(content=data, headers=headers) + + def jsonify(self, obj): + return JSONResponse(content=obj) + + def get_request_adapter(self): + return FastAPIRequestAdapter + + def _make_before_middleware(self, func): + async def middleware(request, call_next): + if func is not None: + if inspect.iscoroutinefunction(func): + await func() + else: + func() + response = await call_next(request) + return response + + return middleware + + def _make_after_middleware(self, func): + async def middleware(request, call_next): + response = await call_next(request) + if func is not None: + if inspect.iscoroutinefunction(func): + await func() + else: + func() + return response + + return middleware + + def serve_component_suites( + self, dash_app, package_name, fingerprinted_path, request + ): + path_in_pkg, has_fingerprint = check_fingerprint(fingerprinted_path) + _validate.validate_js_path(dash_app.registered_paths, package_name, path_in_pkg) + extension = "." + path_in_pkg.split(".")[-1] + mimetype = mimetypes.types_map.get(extension, "application/octet-stream") + package = sys.modules[package_name] + dash_app.logger.debug( + "serving -- package: %s[%s] resource: %s => location: %s", + package_name, + package.__version__, + path_in_pkg, + package.__path__, + ) + data = pkgutil.get_data(package_name, path_in_pkg) + headers = {} + if has_fingerprint: + headers["Cache-Control"] = "public, max-age=31536000" + return StarletteResponse(content=data, media_type=mimetype, headers=headers) + etag = hashlib.md5(data).hexdigest() if data else "" + headers["ETag"] = etag + if request.headers.get("if-none-match") == etag: + return StarletteResponse(status_code=304) + return StarletteResponse(content=data, media_type=mimetype, headers=headers) + + def setup_component_suites(self, dash_app): + async def serve(request: Request, package_name: str, fingerprinted_path: str): + return self.serve_component_suites( + dash_app, package_name, fingerprinted_path, request + ) + + # pylint: disable=protected-access + dash_app._add_url( + "_dash-component-suites//", + serve, + ) + + # pylint: disable=unused-argument + def dispatch(self, app, dash_app, use_async=False): + async def _dispatch(request: Request): + adapter = FastAPIRequestAdapter() + set_request_adapter(adapter) + adapter.set_request(request) + # pylint: disable=protected-access + body = await request.json() + g = dash_app._initialize_context( + body, adapter + ) # pylint: disable=protected-access + func = dash_app._prepare_callback( + g, body + ) # pylint: disable=protected-access + args = dash_app._inputs_to_vals( + g.inputs_list + g.states_list + ) # pylint: disable=protected-access + ctx = copy_context() + partial_func = dash_app._execute_callback( + func, args, g.outputs_list, g + ) # pylint: disable=protected-access + response_data = ctx.run(partial_func) + if inspect.iscoroutine(response_data): + response_data = await response_data + # Instead of set_data, return a new Response + return Response(content=response_data, media_type="application/json") + + return _dispatch + + def _serve_default_favicon(self): + return Response( + content=pkgutil.get_data("dash", "favicon.ico"), media_type="image/x-icon" + ) + + def register_timing_hooks(self, app, first_run): + if not first_run: + return + + @app.middleware("http") + async def timing_middleware(request, call_next): + # Before request + request.state.timing_information = { + "__dash_server": {"dur": time.time(), "desc": None} + } + response = await call_next(request) + # After request + timing_information = getattr(request.state, "timing_information", None) + if timing_information is not None: + dash_total = timing_information.get("__dash_server", None) + if dash_total is not None: + dash_total["dur"] = round((time.time() - dash_total["dur"]) * 1000) + headers = MutableHeaders(response.headers) + for name, info in timing_information.items(): + value = name + if info.get("desc") is not None: + value += f';desc="{info["desc"]}"' + if info.get("dur") is not None: + value += f";dur={info['dur']}" + headers.append("Server-Timing", value) + return response + + def register_callback_api_routes(self, app, callback_api_paths): + """ + Register callback API endpoints on the FastAPI app. + Each key in callback_api_paths is a route, each value is a handler (sync or async). + Dynamically creates a Pydantic model for the handler's parameters and uses it as the body parameter. + """ + for path, handler in callback_api_paths.items(): + endpoint = f"dash_callback_api_{path}" + route = path if path.startswith("/") else f"/{path}" + methods = ["POST"] + sig = inspect.signature(handler) + param_names = list(sig.parameters.keys()) + fields = {name: (Optional[Any], None) for name in param_names} + Model = create_model(f"Payload_{endpoint}", **fields) + + async def view_func(request: Request, body: Model): + kwargs = body.dict(exclude_unset=True) + if inspect.iscoroutinefunction(handler): + result = await handler(**kwargs) + else: + result = handler(**kwargs) + return JSONResponse(content=result) + + + app.add_api_route( + route, + view_func, + methods=methods, + name=endpoint, + include_in_schema=True, + ) + + +class FastAPIRequestAdapter: + def __init__(self): + self._request = None + + def set_request(self, request: Request): + self._request = request + + def get_root(self): + return str(self._request.base_url) + + def get_args(self): + return self._request.query_params + + async def get_json(self): + return await self._request.json() + + def is_json(self): + return self._request.headers.get("content-type", "").startswith( + "application/json" + ) + + def get_cookies(self, _request=None): + return self._request.cookies + + def get_headers(self): + return self._request.headers + + def get_full_path(self): + return str(self._request.url) + + def get_url(self): + return str(self._request.url) + + def get_remote_addr(self): + return self._request.client.host if self._request.client else None + + def get_origin(self): + return self._request.headers.get("origin") + + def get_path(self): + return self._request.url.path diff --git a/dash/backend/flask.py b/dash/backend/flask.py new file mode 100644 index 0000000000..2d7d01af32 --- /dev/null +++ b/dash/backend/flask.py @@ -0,0 +1,273 @@ +from contextvars import copy_context +import asyncio +import pkgutil +import sys +import mimetypes +import time +import flask +import inspect +from dash.fingerprint import check_fingerprint +from dash import _validate +from dash.exceptions import PreventUpdate, InvalidResourceError +from dash.backend import set_request_adapter +from .base_server import BaseDashServer + + +class FlaskDashServer(BaseDashServer): + def __call__(self, server, *args, **kwargs): + # Always WSGI + return server(*args, **kwargs) + + def create_app(self, name="__main__", config=None): + app = flask.Flask(name) + if config: + app.config.update(config) + return app + + def register_assets_blueprint( + self, app, blueprint_name, assets_url_path, assets_folder + ): + bp = flask.Blueprint( + blueprint_name, + __name__, + static_folder=assets_folder, + static_url_path=assets_url_path, + ) + app.register_blueprint(bp) + + def register_error_handlers(self, app): + @app.errorhandler(PreventUpdate) + def _handle_error(_): + return "", 204 + + @app.errorhandler(InvalidResourceError) + def _invalid_resources_handler(err): + return err.args[0], 404 + + def register_prune_error_handler(self, app, secret, get_traceback_func): + @app.errorhandler(Exception) + def _wrap_errors(error): + tb = get_traceback_func(secret, error) + return tb, 500 + + def add_url_rule(self, app, rule, view_func, endpoint=None, methods=None): + app.add_url_rule( + rule, view_func=view_func, endpoint=endpoint, methods=methods or ["GET"] + ) + + def before_request(self, app, func): + app.before_request(func) + + def after_request(self, app, func): + app.after_request(func) + + def run(self, app, host, port, debug, **kwargs): + app.run(host=host, port=port, debug=debug, **kwargs) + + def make_response(self, data, mimetype=None, content_type=None): + return flask.Response(data, mimetype=mimetype, content_type=content_type) + + def jsonify(self, obj): + return flask.jsonify(obj) + + def get_request_adapter(self): + return FlaskRequestAdapter + + def setup_catchall(self, dash_app): + def catchall(*args, **kwargs): + adapter = FlaskRequestAdapter() + set_request_adapter(adapter) + return dash_app.index(*args, **kwargs) + + # pylint: disable=protected-access + dash_app._add_url("", catchall, methods=["GET"]) + + def setup_index(self, dash_app): + def index(*args, **kwargs): + adapter = FlaskRequestAdapter() + set_request_adapter(adapter) + return dash_app.index(*args, **kwargs) + + # pylint: disable=protected-access + dash_app._add_url("", index, methods=["GET"]) + + def serve_component_suites(self, dash_app, package_name, fingerprinted_path): + path_in_pkg, has_fingerprint = check_fingerprint(fingerprinted_path) + _validate.validate_js_path(dash_app.registered_paths, package_name, path_in_pkg) + extension = "." + path_in_pkg.split(".")[-1] + mimetype = mimetypes.types_map.get(extension, "application/octet-stream") + package = sys.modules[package_name] + dash_app.logger.debug( + "serving -- package: %s[%s] resource: %s => location: %s", + package_name, + package.__version__, + path_in_pkg, + package.__path__, + ) + data = pkgutil.get_data(package_name, path_in_pkg) + response = flask.Response(data, mimetype=mimetype) + if has_fingerprint: + response.cache_control.max_age = 31536000 # 1 year + else: + response.add_etag() + tag = response.get_etag()[0] + request_etag = flask.request.headers.get("If-None-Match") + if f'"{tag}"' == request_etag: + response = flask.Response(None, status=304) + return response + + def setup_component_suites(self, dash_app): + def serve(package_name, fingerprinted_path): + return self.serve_component_suites( + dash_app, package_name, fingerprinted_path + ) + + # pylint: disable=protected-access + dash_app._add_url( + "_dash-component-suites//", + serve, + ) + + # pylint: disable=unused-argument + def dispatch(self, app, dash_app, use_async=False): + def _dispatch(): + adapter = FlaskRequestAdapter() + set_request_adapter(adapter) + body = flask.request.get_json() + # pylint: disable=protected-access + g = dash_app._initialize_context(body, adapter) + func = dash_app._prepare_callback(g, body) + args = dash_app._inputs_to_vals(g.inputs_list + g.states_list) + ctx = copy_context() + partial_func = dash_app._execute_callback(func, args, g.outputs_list, g) + response_data = ctx.run(partial_func) + if asyncio.iscoroutine(response_data): + raise Exception( + "You are trying to use a coroutine without dash[async]. " + "Please install the dependencies via `pip install dash[async]` and ensure " + "that `use_async=False` is not being passed to the app." + ) + g.dash_response.set_data(response_data) + return g.dash_response + + async def _dispatch_async(): + adapter = FlaskRequestAdapter() + set_request_adapter(adapter) + body = flask.request.get_json() + # pylint: disable=protected-access + g = dash_app._initialize_context(body, adapter) + func = dash_app._prepare_callback(g, body) + args = dash_app._inputs_to_vals(g.inputs_list + g.states_list) + ctx = copy_context() + partial_func = dash_app._execute_callback(func, args, g.outputs_list, g) + response_data = ctx.run(partial_func) + if asyncio.iscoroutine(response_data): + response_data = await response_data + g.dash_response.set_data(response_data) + return g.dash_response + + if use_async: + return _dispatch_async + return _dispatch + + def _serve_default_favicon(self): + + return flask.Response( + pkgutil.get_data("dash", "favicon.ico"), content_type="image/x-icon" + ) + + def register_timing_hooks(self, app, _first_run): + def _before_request(): + flask.g.timing_information = { + "__dash_server": {"dur": time.time(), "desc": None} + } + + def _after_request(response): + timing_information = flask.g.get("timing_information", None) + if timing_information is None: + return response + dash_total = timing_information.get("__dash_server", None) + if dash_total is not None: + dash_total["dur"] = round((time.time() - dash_total["dur"]) * 1000) + for name, info in timing_information.items(): + value = name + if info.get("desc") is not None: + value += f';desc="{info["desc"]}"' + if info.get("dur") is not None: + value += f";dur={info['dur']}" + response.headers.add("Server-Timing", value) + return response + + self.before_request(app, _before_request) + self.after_request(app, _after_request) + + def register_callback_api_routes(self, app, callback_api_paths): + """ + Register callback API endpoints on the Flask app. + Each key in callback_api_paths is a route, each value is a handler (sync or async). + The view function parses the JSON body and passes it to the handler. + """ + for path, handler in callback_api_paths.items(): + endpoint = f"dash_callback_api_{path}" + route = path if path.startswith("/") else f"/{path}" + methods = ["POST"] + + if inspect.iscoroutinefunction(handler): + async def view_func(*args, handler=handler, **kwargs): + data = flask.request.get_json() + result = await handler(**data) if data else await handler() + return flask.jsonify(result) + else: + def view_func(*args, handler=handler, **kwargs): + data = flask.request.get_json() + result = handler(**data) if data else handler() + return flask.jsonify(result) + + # Flask 2.x+ supports async views natively + app.add_url_rule(route, endpoint=endpoint, view_func=view_func, methods=methods) + + +class FlaskRequestAdapter: + @staticmethod + def get_args(): + return flask.request.args + + @staticmethod + def get_root(): + return flask.request.url_root + + @staticmethod + def get_json(): + return flask.request.get_json() + + @staticmethod + def is_json(): + return flask.request.is_json + + @staticmethod + def get_cookies(): + return flask.request.cookies + + @staticmethod + def get_headers(): + return flask.request.headers + + @staticmethod + def get_url(): + return flask.request.url + + @staticmethod + def get_full_path(): + return flask.request.full_path + + @staticmethod + def get_remote_addr(): + return flask.request.remote_addr + + @staticmethod + def get_origin(): + return getattr(flask.request, "origin", None) + + @staticmethod + def get_path(): + return flask.request.path diff --git a/dash/backend/registry.py b/dash/backend/registry.py new file mode 100644 index 0000000000..1b80da879f --- /dev/null +++ b/dash/backend/registry.py @@ -0,0 +1,22 @@ +import importlib + +_backend_imports = { + 'flask': ('dash.backend.flask', 'FlaskDashServer'), + 'fastapi': ('dash.backend.fastapi', 'FastAPIDashServer'), + 'quart': ('dash.backend.quart', 'QuartDashServer'), +} + +def register_backend(name, module_path, class_name): + """Register a new backend by name.""" + _backend_imports[name.lower()] = (module_path, class_name) + +def get_backend(name): + try: + module_name, class_name = _backend_imports[name.lower()] + module = importlib.import_module(module_name) + return getattr(module, class_name) + except KeyError: + raise ValueError(f"Unknown backend: {name}") + except (ImportError, AttributeError) as e: + raise ImportError(f"Could not import backend '{name}': {e}") + From 708773f3d4f21cc1ef61fdc244010959c9c8567b Mon Sep 17 00:00:00 2001 From: BSd3v <82055130+BSd3v@users.noreply.github.com> Date: Fri, 12 Sep 2025 17:15:24 -0400 Subject: [PATCH 034/166] fixing issue with server not declared --- dash/dash.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/dash/dash.py b/dash/dash.py index 0e7cbb25fa..22f79873bb 100644 --- a/dash/dash.py +++ b/dash/dash.py @@ -540,8 +540,10 @@ def __init__( # pylint: disable=too-many-statements self.server = server else: # No server instance provided, create backend and let backend create server + if server is True and backend_cls is None: + backend_cls = FlaskDashServer self.backend = backend_cls() - self.server = server + self.server = self.backend.create_app(caller_name) # type: ignore base_prefix, routes_prefix, requests_prefix = pathname_configs( url_base_pathname, routes_pathname_prefix, requests_pathname_prefix From b7bcebaf442e10455987dd72b02c98a1e680578f Mon Sep 17 00:00:00 2001 From: BSd3v <82055130+BSd3v@users.noreply.github.com> Date: Fri, 12 Sep 2025 17:16:50 -0400 Subject: [PATCH 035/166] Update dash/dash.py Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- dash/dash.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dash/dash.py b/dash/dash.py index 22f79873bb..2000114067 100644 --- a/dash/dash.py +++ b/dash/dash.py @@ -526,7 +526,7 @@ def __init__( # pylint: disable=too-many-statements raise ValueError("Invalid backend argument") # Determine server and backend instance - if server is not None and server is not True and server is not False: + if server not in (None, True, False): # User provided a server instance (e.g., Flask, Quart, FastAPI) if _is_flask_instance(server): backend_cls = get_backend("flask") From 9873079800f773ef09581159312b7f0b48209f67 Mon Sep 17 00:00:00 2001 From: BSd3v <82055130+BSd3v@users.noreply.github.com> Date: Fri, 12 Sep 2025 17:17:01 -0400 Subject: [PATCH 036/166] Update dash/dash.py Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- dash/dash.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dash/dash.py b/dash/dash.py index 2000114067..af9ab03139 100644 --- a/dash/dash.py +++ b/dash/dash.py @@ -1768,7 +1768,7 @@ def display_content(path): For nested URLs, slashes are still included: `app.strip_relative_path('/page-1/sub-page-1/')` will return - `page-1/sub-page-1 + `page-1/sub-page-1` ``` """ return _get_paths.app_strip_relative_path( From 9f4d291689c05cada98f4ad9fe380aa718ae56ac Mon Sep 17 00:00:00 2001 From: BSd3v <82055130+BSd3v@users.noreply.github.com> Date: Fri, 12 Sep 2025 17:17:16 -0400 Subject: [PATCH 037/166] Update dash/backend/quart.py Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- dash/backend/quart.py | 26 +++++++++++++++----------- 1 file changed, 15 insertions(+), 11 deletions(-) diff --git a/dash/backend/quart.py b/dash/backend/quart.py index a2437811a4..5bb568fe72 100644 --- a/dash/backend/quart.py +++ b/dash/backend/quart.py @@ -207,17 +207,21 @@ def register_callback_api_routes(self, app, callback_api_paths): route = path if path.startswith("/") else f"/{path}" methods = ["POST"] - if inspect.iscoroutinefunction(handler): - async def view_func(*args, handler=handler, **kwargs): - data = await request.get_json() - result = await handler(**data) if data else await handler() - return jsonify(result) - else: - async def view_func(*args, handler=handler, **kwargs): - data = await request.get_json() - result = handler(**data) if data else handler() - return jsonify(result) - + def _make_view_func(handler): + if inspect.iscoroutinefunction(handler): + async def async_view_func(*args, **kwargs): + data = await request.get_json() + result = await handler(**data) if data else await handler() + return jsonify(result) + return async_view_func + else: + async def sync_view_func(*args, **kwargs): + data = await request.get_json() + result = handler(**data) if data else handler() + return jsonify(result) + return sync_view_func + + view_func = _make_view_func(handler) app.add_url_rule(route, endpoint=endpoint, view_func=view_func, methods=methods) def _serve_default_favicon(self): From da86e8666b731375bf251a0f76fd5eb8d360b6a9 Mon Sep 17 00:00:00 2001 From: BSd3v <82055130+BSd3v@users.noreply.github.com> Date: Fri, 12 Sep 2025 17:17:52 -0400 Subject: [PATCH 038/166] Update dash/dash.py Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- dash/dash.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/dash/dash.py b/dash/dash.py index af9ab03139..242f4cbd91 100644 --- a/dash/dash.py +++ b/dash/dash.py @@ -540,8 +540,6 @@ def __init__( # pylint: disable=too-many-statements self.server = server else: # No server instance provided, create backend and let backend create server - if server is True and backend_cls is None: - backend_cls = FlaskDashServer self.backend = backend_cls() self.server = self.backend.create_app(caller_name) # type: ignore From 4c60740a5f98e2e2612cbb8c0dbf22d324345255 Mon Sep 17 00:00:00 2001 From: BSd3v <82055130+BSd3v@users.noreply.github.com> Date: Fri, 12 Sep 2025 17:22:44 -0400 Subject: [PATCH 039/166] Update dash/dash.py Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- dash/dash.py | 26 +++++++++++++++++++++++--- 1 file changed, 23 insertions(+), 3 deletions(-) diff --git a/dash/dash.py b/dash/dash.py index 242f4cbd91..0e7be84128 100644 --- a/dash/dash.py +++ b/dash/dash.py @@ -529,13 +529,33 @@ def __init__( # pylint: disable=too-many-statements if server not in (None, True, False): # User provided a server instance (e.g., Flask, Quart, FastAPI) if _is_flask_instance(server): - backend_cls = get_backend("flask") + inferred_backend = "flask" elif _is_quart_instance(server): - backend_cls = get_backend("quart") + inferred_backend = "quart" elif _is_fastapi_instance(server): - backend_cls = get_backend("fastapi") + inferred_backend = "fastapi" else: raise ValueError("Unsupported server type") + # Validate that backend matches server type if both are provided + if backend is not None: + if isinstance(backend, str): + requested_backend = backend + elif isinstance(backend, type): + # get_backend returns the backend class for a string + # So we compare the class names + requested_backend = get_backend(inferred_backend).__name__.lower() + backend_name = backend.__name__.lower() + if backend_name != requested_backend: + raise ValueError( + f"Conflict between provided backend '{backend_name}' and server type '{inferred_backend}'." + ) + else: + raise ValueError("Invalid backend argument") + if isinstance(backend, str) and backend.lower() != inferred_backend: + raise ValueError( + f"Conflict between provided backend '{backend}' and server type '{inferred_backend}'." + ) + backend_cls = get_backend(inferred_backend) self.backend = backend_cls() self.server = server else: From 84cb5e52de9e5046009ce4913f160cbaf435cac0 Mon Sep 17 00:00:00 2001 From: BSd3v <82055130+BSd3v@users.noreply.github.com> Date: Fri, 12 Sep 2025 17:25:37 -0400 Subject: [PATCH 040/166] update for caller_name --- dash/dash.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/dash/dash.py b/dash/dash.py index 0e7be84128..4aa18bd9c0 100644 --- a/dash/dash.py +++ b/dash/dash.py @@ -558,6 +558,9 @@ def __init__( # pylint: disable=too-many-statements backend_cls = get_backend(inferred_backend) self.backend = backend_cls() self.server = server + # Update caller_name from server's name attribute if available + if hasattr(server, "name"): + caller_name = server.name else: # No server instance provided, create backend and let backend create server self.backend = backend_cls() From 29cf8232684cd881034ad5986c8de51bc580a9e4 Mon Sep 17 00:00:00 2001 From: BSd3v <82055130+BSd3v@users.noreply.github.com> Date: Fri, 12 Sep 2025 17:27:30 -0400 Subject: [PATCH 041/166] Update dash/dash.py Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- dash/dash.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/dash/dash.py b/dash/dash.py index 4aa18bd9c0..05249fe583 100644 --- a/dash/dash.py +++ b/dash/dash.py @@ -543,11 +543,10 @@ def __init__( # pylint: disable=too-many-statements elif isinstance(backend, type): # get_backend returns the backend class for a string # So we compare the class names - requested_backend = get_backend(inferred_backend).__name__.lower() - backend_name = backend.__name__.lower() - if backend_name != requested_backend: + expected_backend_cls = get_backend(inferred_backend) + if backend is not expected_backend_cls: raise ValueError( - f"Conflict between provided backend '{backend_name}' and server type '{inferred_backend}'." + f"Conflict between provided backend '{backend.__name__}' and server type '{inferred_backend}'." ) else: raise ValueError("Invalid backend argument") From 5d0f4dced2c362eb89c1211ca31a296aa69eda26 Mon Sep 17 00:00:00 2001 From: BSd3v <82055130+BSd3v@users.noreply.github.com> Date: Fri, 12 Sep 2025 17:30:45 -0400 Subject: [PATCH 042/166] Update dash/dash.py Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- dash/dash.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dash/dash.py b/dash/dash.py index 05249fe583..ce44c936ef 100644 --- a/dash/dash.py +++ b/dash/dash.py @@ -271,7 +271,7 @@ class Dash(ObsoleteChecker): :param backend: The backend to use for the Dash app. Can be a string (name of the backend) or a backend class. Default is None, which - selects the Flask backend. Currently, "flask" and "fastapi" backends + selects the Flask backend. Currently, "flask", "fastapi", and "quart" backends are supported. :type backend: string or type From 86f452873a4c8c9068296154578e7a05121d5f4d Mon Sep 17 00:00:00 2001 From: BSd3v <82055130+BSd3v@users.noreply.github.com> Date: Fri, 12 Sep 2025 17:33:10 -0400 Subject: [PATCH 043/166] adjustments for matching types --- dash/dash.py | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/dash/dash.py b/dash/dash.py index 05249fe583..db77f3f4fb 100644 --- a/dash/dash.py +++ b/dash/dash.py @@ -538,9 +538,7 @@ def __init__( # pylint: disable=too-many-statements raise ValueError("Unsupported server type") # Validate that backend matches server type if both are provided if backend is not None: - if isinstance(backend, str): - requested_backend = backend - elif isinstance(backend, type): + if isinstance(backend, type): # get_backend returns the backend class for a string # So we compare the class names expected_backend_cls = get_backend(inferred_backend) @@ -548,9 +546,9 @@ def __init__( # pylint: disable=too-many-statements raise ValueError( f"Conflict between provided backend '{backend.__name__}' and server type '{inferred_backend}'." ) - else: + elif not isinstance(backend, str): raise ValueError("Invalid backend argument") - if isinstance(backend, str) and backend.lower() != inferred_backend: + elif backend.lower() != inferred_backend: raise ValueError( f"Conflict between provided backend '{backend}' and server type '{inferred_backend}'." ) From 2a88385f46fa57758f2ba1023526db892cbcadf7 Mon Sep 17 00:00:00 2001 From: BSd3v <82055130+BSd3v@users.noreply.github.com> Date: Fri, 12 Sep 2025 17:33:38 -0400 Subject: [PATCH 044/166] Update dash/backend/registry.py Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- dash/backend/registry.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/dash/backend/registry.py b/dash/backend/registry.py index 1b80da879f..4aac7142ef 100644 --- a/dash/backend/registry.py +++ b/dash/backend/registry.py @@ -17,6 +17,8 @@ def get_backend(name): return getattr(module, class_name) except KeyError: raise ValueError(f"Unknown backend: {name}") - except (ImportError, AttributeError) as e: - raise ImportError(f"Could not import backend '{name}': {e}") + except ImportError as e: + raise ImportError(f"Could not import module '{module_name}' for backend '{name}': {e}") + except AttributeError as e: + raise AttributeError(f"Module '{module_name}' does not have class '{class_name}' for backend '{name}': {e}") From bc51c0d0269cf0fba3d25e4f07a26cfad0f3bf06 Mon Sep 17 00:00:00 2001 From: BSd3v <82055130+BSd3v@users.noreply.github.com> Date: Fri, 12 Sep 2025 17:35:54 -0400 Subject: [PATCH 045/166] Update dash/backend/registry.py Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- dash/backend/registry.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/dash/backend/registry.py b/dash/backend/registry.py index 4aac7142ef..fb9c99cc2d 100644 --- a/dash/backend/registry.py +++ b/dash/backend/registry.py @@ -15,10 +15,10 @@ def get_backend(name): module_name, class_name = _backend_imports[name.lower()] module = importlib.import_module(module_name) return getattr(module, class_name) - except KeyError: - raise ValueError(f"Unknown backend: {name}") + except KeyError as e: + raise ValueError(f"Unknown backend: {name}") from e except ImportError as e: - raise ImportError(f"Could not import module '{module_name}' for backend '{name}': {e}") + raise ImportError(f"Could not import module '{module_name}' for backend '{name}': {e}") from e except AttributeError as e: - raise AttributeError(f"Module '{module_name}' does not have class '{class_name}' for backend '{name}': {e}") + raise AttributeError(f"Module '{module_name}' does not have class '{class_name}' for backend '{name}': {e}") from e From 1b4d0d3f767ab0ca6ffd2829d5b75c03642fbcef Mon Sep 17 00:00:00 2001 From: BSd3v <82055130+BSd3v@users.noreply.github.com> Date: Fri, 12 Sep 2025 17:38:03 -0400 Subject: [PATCH 046/166] fixing another type check --- dash/dash.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/dash/dash.py b/dash/dash.py index aa9e4c51ca..a95f969faa 100644 --- a/dash/dash.py +++ b/dash/dash.py @@ -542,7 +542,10 @@ def __init__( # pylint: disable=too-many-statements # get_backend returns the backend class for a string # So we compare the class names expected_backend_cls = get_backend(inferred_backend) - if backend is not expected_backend_cls: + if ( + backend.__module__ != expected_backend_cls.__module__ + or backend.__name__ != expected_backend_cls.__name__ + ): raise ValueError( f"Conflict between provided backend '{backend.__name__}' and server type '{inferred_backend}'." ) From f867f98fd791ec790d755f75eb0a6e5b8a986117 Mon Sep 17 00:00:00 2001 From: BSd3v <82055130+BSd3v@users.noreply.github.com> Date: Fri, 12 Sep 2025 21:03:26 -0400 Subject: [PATCH 047/166] fixing for lint --- dash/backend/__init__.py | 4 +- dash/backend/base_server.py | 16 +++++-- dash/backend/fastapi.py | 10 ++-- dash/backend/flask.py | 9 +++- dash/backend/quart.py | 91 +++++++++++++++++++++++-------------- dash/backend/registry.py | 17 ++++--- dash/dash.py | 11 ++++- 7 files changed, 108 insertions(+), 50 deletions(-) diff --git a/dash/backend/__init__.py b/dash/backend/__init__.py index 497f2dca2d..eb1d47bc3f 100644 --- a/dash/backend/__init__.py +++ b/dash/backend/__init__.py @@ -1,6 +1,8 @@ # python import contextvars -from .registry import * +from .registry import get_backend # pylint: disable=unused-import + +__all__ = ["set_request_adapter", "get_request_adapter", "get_backend"] _request_adapter_var = contextvars.ContextVar("request_adapter") diff --git a/dash/backend/base_server.py b/dash/backend/base_server.py index 8c902f4248..4855f86ad6 100644 --- a/dash/backend/base_server.py +++ b/dash/backend/base_server.py @@ -8,7 +8,9 @@ def __call__(self, server, *args, **kwargs) -> Any: return server(*args, **kwargs) @abstractmethod - def create_app(self, name: str = "__main__", config=None) -> Any: # pragma: no cover - interface + def create_app( + self, name: str = "__main__", config=None + ) -> Any: # pragma: no cover - interface pass @abstractmethod @@ -22,7 +24,9 @@ def register_error_handlers(self, app) -> None: # pragma: no cover - interface pass @abstractmethod - def add_url_rule(self, app, rule: str, view_func, endpoint=None, methods=None) -> None: # pragma: no cover - interface + def add_url_rule( + self, app, rule: str, view_func, endpoint=None, methods=None + ) -> None: # pragma: no cover - interface pass @abstractmethod @@ -34,11 +38,15 @@ def after_request(self, app, func) -> None: # pragma: no cover - interface pass @abstractmethod - def run(self, app, host: str, port: int, debug: bool, **kwargs) -> None: # pragma: no cover - interface + def run( + self, app, host: str, port: int, debug: bool, **kwargs + ) -> None: # pragma: no cover - interface pass @abstractmethod - def make_response(self, data, mimetype=None, content_type=None) -> Any: # pragma: no cover - interface + def make_response( + self, data, mimetype=None, content_type=None + ) -> Any: # pragma: no cover - interface pass @abstractmethod diff --git a/dash/backend/fastapi.py b/dash/backend/fastapi.py index b2feeb446f..d283e90346 100644 --- a/dash/backend/fastapi.py +++ b/dash/backend/fastapi.py @@ -117,7 +117,9 @@ async def catchall(request: Request): # pylint: disable=protected-access dash_app._add_url("{path:path}", catchall, methods=["GET"]) - def add_url_rule(self, app, rule, view_func, endpoint=None, methods=None, include_in_schema=False): + def add_url_rule( + self, app, rule, view_func, endpoint=None, methods=None, include_in_schema=False + ): if rule == "": rule = "/" if isinstance(view_func, str): @@ -307,8 +309,11 @@ def register_callback_api_routes(self, app, callback_api_paths): sig = inspect.signature(handler) param_names = list(sig.parameters.keys()) fields = {name: (Optional[Any], None) for name in param_names} - Model = create_model(f"Payload_{endpoint}", **fields) + Model = create_model( + f"Payload_{endpoint}", **fields + ) # pylint: disable=cell-var-from-loop + # pylint: disable=cell-var-from-loop async def view_func(request: Request, body: Model): kwargs = body.dict(exclude_unset=True) if inspect.iscoroutinefunction(handler): @@ -317,7 +322,6 @@ async def view_func(request: Request, body: Model): result = handler(**kwargs) return JSONResponse(content=result) - app.add_api_route( route, view_func, diff --git a/dash/backend/flask.py b/dash/backend/flask.py index 2d7d01af32..b48225a3c5 100644 --- a/dash/backend/flask.py +++ b/dash/backend/flask.py @@ -4,8 +4,8 @@ import sys import mimetypes import time -import flask import inspect +import flask from dash.fingerprint import check_fingerprint from dash import _validate from dash.exceptions import PreventUpdate, InvalidResourceError @@ -213,18 +213,23 @@ def register_callback_api_routes(self, app, callback_api_paths): methods = ["POST"] if inspect.iscoroutinefunction(handler): + async def view_func(*args, handler=handler, **kwargs): data = flask.request.get_json() result = await handler(**data) if data else await handler() return flask.jsonify(result) + else: + def view_func(*args, handler=handler, **kwargs): data = flask.request.get_json() result = handler(**data) if data else handler() return flask.jsonify(result) # Flask 2.x+ supports async views natively - app.add_url_rule(route, endpoint=endpoint, view_func=view_func, methods=methods) + app.add_url_rule( + route, endpoint=endpoint, view_func=view_func, methods=methods + ) class FlaskRequestAdapter: diff --git a/dash/backend/quart.py b/dash/backend/quart.py index 5bb568fe72..c3d42dadee 100644 --- a/dash/backend/quart.py +++ b/dash/backend/quart.py @@ -1,15 +1,25 @@ -from .base_server import BaseDashServer -from quart import Quart, Request, Response, jsonify, request -from dash.exceptions import PreventUpdate, InvalidResourceError -from dash.backend import set_request_adapter -from dash.fingerprint import check_fingerprint -from dash import _validate -from contextvars import copy_context import inspect import pkgutil import mimetypes import sys import time +from contextvars import copy_context + +try: + import quart + from quart import Quart, Response, jsonify, request, Blueprint +except ImportError: + quart = None + Quart = None + Response = None + jsonify = None + request = None + Blueprint = None +from dash.exceptions import PreventUpdate, InvalidResourceError +from dash.backend import set_request_adapter +from dash.fingerprint import check_fingerprint +from dash import _validate +from .base_server import BaseDashServer class QuartDashServer(BaseDashServer): @@ -24,7 +34,7 @@ def __init__(self) -> None: super().__init__() def __call__(self, server, *args, **kwargs): - return super().__call__(server, *args, **kwargs) + return server(*args, **kwargs) def create_app(self, name="__main__", config=None): app = Quart(name) @@ -36,8 +46,6 @@ def create_app(self, name="__main__", config=None): def register_assets_blueprint( self, app, blueprint_name, assets_url_path, assets_folder ): - from quart import Blueprint - bp = Blueprint( blueprint_name, __name__, @@ -53,15 +61,15 @@ async def _wrap_errors(_error_request, error): return tb, 500 def register_timing_hooks(self, app, _first_run): # parity with Flask factory - from quart import g - @app.before_request async def _before_request(): # pragma: no cover - timing infra - g.timing_information = {"__dash_server": {"dur": time.time(), "desc": None}} + quart.g.timing_information = { + "__dash_server": {"dur": time.time(), "desc": None} + } @app.after_request async def _after_request(response): # pragma: no cover - timing infra - timing_information = getattr(g, "timing_information", None) + timing_information = getattr(quart.g, "timing_information", None) if timing_information is None: return response dash_total = timing_information.get("__dash_server", None) @@ -90,7 +98,7 @@ async def _invalid_resource(err): return err.args[0], 404 def _html_response_wrapper(self, view_func): - async def wrapped(*args, **kwargs): + async def wrapped(*_args, **_kwargs): html_val = view_func() if callable(view_func) else view_func if inspect.iscoroutine(html_val): # handle async function returning html html_val = await html_val @@ -105,21 +113,25 @@ def add_url_rule(self, app, rule, view_func, endpoint=None, methods=None): ) def setup_index(self, dash_app): - async def index(): + async def index(*args, **kwargs): adapter = QuartRequestAdapter() set_request_adapter(adapter) - adapter.set_request(request) - return Response(dash_app.index(), content_type="text/html") + adapter.set_request() + return Response(dash_app.index(*args, **kwargs), content_type="text/html") + # pylint: disable=protected-access dash_app._add_url("", index, methods=["GET"]) def setup_catchall(self, dash_app): - async def catchall(path): # noqa: ARG001 - path is unused but kept for route signature + async def catchall( + path, *args, **kwargs + ): # noqa: ARG001 - path is unused but kept for route signature, pylint: disable=unused-argument adapter = QuartRequestAdapter() set_request_adapter(adapter) - adapter.set_request(request) - return Response(dash_app.index(), content_type="text/html") + adapter.set_request() + return Response(dash_app.index(*args, **kwargs), content_type="text/html") + # pylint: disable=protected-access dash_app._add_url("", catchall, methods=["GET"]) def before_request(self, app, func): @@ -135,7 +147,7 @@ async def _after(response): return response def run(self, app, host, port, debug, **kwargs): - self.config = {'debug': debug, **kwargs} if debug else kwargs + self.config = {"debug": debug, **kwargs} if debug else kwargs app.run(host=host, port=port, debug=debug, **kwargs) def make_response(self, data, mimetype=None, content_type=None): @@ -147,7 +159,9 @@ def jsonify(self, obj): def get_request_adapter(self): return QuartRequestAdapter - def serve_component_suites(self, dash_app, package_name, fingerprinted_path, req): # noqa: ARG002 unused req preserved for interface parity + def serve_component_suites( + self, dash_app, package_name, fingerprinted_path + ): # noqa: ARG002 unused req preserved for interface parity path_in_pkg, has_fingerprint = check_fingerprint(fingerprinted_path) _validate.validate_js_path(dash_app.registered_paths, package_name, path_in_pkg) extension = "." + path_in_pkg.split(".")[-1] @@ -170,24 +184,30 @@ def serve_component_suites(self, dash_app, package_name, fingerprinted_path, req def setup_component_suites(self, dash_app): async def serve(package_name, fingerprinted_path): return self.serve_component_suites( - dash_app, package_name, fingerprinted_path, request + dash_app, package_name, fingerprinted_path ) + # pylint: disable=protected-access dash_app._add_url( "_dash-component-suites//", serve, ) + # pylint: disable=unused-argument def dispatch(self, app, dash_app, use_async=True): # Quart always async async def _dispatch(): adapter = QuartRequestAdapter() set_request_adapter(adapter) - adapter.set_request(request) + adapter.set_request() body = await request.get_json() + # pylint: disable=protected-access g = dash_app._initialize_context(body, adapter) + # pylint: disable=protected-access func = dash_app._prepare_callback(g, body) + # pylint: disable=protected-access args = dash_app._inputs_to_vals(g.inputs_list + g.states_list) ctx = copy_context() + # pylint: disable=protected-access partial_func = dash_app._execute_callback(func, args, g.outputs_list, g) response_data = ctx.run(partial_func) if inspect.iscoroutine(response_data): # if user callback is async @@ -209,20 +229,25 @@ def register_callback_api_routes(self, app, callback_api_paths): def _make_view_func(handler): if inspect.iscoroutinefunction(handler): + async def async_view_func(*args, **kwargs): data = await request.get_json() result = await handler(**data) if data else await handler() return jsonify(result) + return async_view_func - else: - async def sync_view_func(*args, **kwargs): - data = await request.get_json() - result = handler(**data) if data else handler() - return jsonify(result) - return sync_view_func + + async def sync_view_func(*args, **kwargs): + data = await request.get_json() + result = handler(**data) if data else handler() + return jsonify(result) + + return sync_view_func view_func = _make_view_func(handler) - app.add_url_rule(route, endpoint=endpoint, view_func=view_func, methods=methods) + app.add_url_rule( + route, endpoint=endpoint, view_func=view_func, methods=methods + ) def _serve_default_favicon(self): return Response( @@ -234,7 +259,7 @@ class QuartRequestAdapter: def __init__(self) -> None: self._request = None - def set_request(self, request: Request) -> None: + def set_request(self) -> None: self._request = request # Accessors (instance-based) diff --git a/dash/backend/registry.py b/dash/backend/registry.py index fb9c99cc2d..4aae9fafc5 100644 --- a/dash/backend/registry.py +++ b/dash/backend/registry.py @@ -1,15 +1,17 @@ import importlib _backend_imports = { - 'flask': ('dash.backend.flask', 'FlaskDashServer'), - 'fastapi': ('dash.backend.fastapi', 'FastAPIDashServer'), - 'quart': ('dash.backend.quart', 'QuartDashServer'), + "flask": ("dash.backend.flask", "FlaskDashServer"), + "fastapi": ("dash.backend.fastapi", "FastAPIDashServer"), + "quart": ("dash.backend.quart", "QuartDashServer"), } + def register_backend(name, module_path, class_name): """Register a new backend by name.""" _backend_imports[name.lower()] = (module_path, class_name) + def get_backend(name): try: module_name, class_name = _backend_imports[name.lower()] @@ -18,7 +20,10 @@ def get_backend(name): except KeyError as e: raise ValueError(f"Unknown backend: {name}") from e except ImportError as e: - raise ImportError(f"Could not import module '{module_name}' for backend '{name}': {e}") from e + raise ImportError( + f"Could not import module '{module_name}' for backend '{name}': {e}" + ) from e except AttributeError as e: - raise AttributeError(f"Module '{module_name}' does not have class '{class_name}' for backend '{name}': {e}") from e - + raise AttributeError( + f"Module '{module_name}' does not have class '{class_name}' for backend '{name}': {e}" + ) from e diff --git a/dash/dash.py b/dash/dash.py index a95f969faa..18c933f08c 100644 --- a/dash/dash.py +++ b/dash/dash.py @@ -155,23 +155,32 @@ except: # noqa: E722 page_container = None + def _is_flask_instance(obj): try: + # pylint: disable=import-outside-toplevel from flask import Flask + return isinstance(obj, Flask) except ImportError: return False + def _is_fastapi_instance(obj): try: + # pylint: disable=import-outside-toplevel from fastapi import FastAPI + return isinstance(obj, FastAPI) except ImportError: return False + def _is_quart_instance(obj): try: + # pylint: disable=import-outside-toplevel from quart import Quart + return isinstance(obj, Quart) except ImportError: return False @@ -453,7 +462,7 @@ class Dash(ObsoleteChecker): _layout: Any _extra_components: Any - def __init__( # pylint: disable=too-many-statements + def __init__( # pylint: disable=too-many-statements, too-many-branches self, name: Optional[str] = None, server: Union[bool, Callable[[], Any]] = True, From 0ed81ce67c38bb49056b9ea90564217349ba2825 Mon Sep 17 00:00:00 2001 From: BSd3v <82055130+BSd3v@users.noreply.github.com> Date: Fri, 12 Sep 2025 21:06:32 -0400 Subject: [PATCH 048/166] fixing failing test --- tests/integration/devtools/test_devtools_error_handling.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/integration/devtools/test_devtools_error_handling.py b/tests/integration/devtools/test_devtools_error_handling.py index b481ef2fad..005bf8c335 100644 --- a/tests/integration/devtools/test_devtools_error_handling.py +++ b/tests/integration/devtools/test_devtools_error_handling.py @@ -109,14 +109,14 @@ def test_dveh006_long_python_errors(dash_duo): assert "in bad_sub" not in error0 # dash and flask part of the traceback ARE included # since we set dev_tools_prune_errors=False - assert "factory.py" in error0 + assert "backend" in error0 and "flask.py" in error0 assert "self.wsgi_app" in error0 error1 = get_error_html(dash_duo, 1) assert "in update_output" in error1 assert "in bad_sub" in error1 assert "ZeroDivisionError" in error1 - assert "factory.py" in error1 + assert "backend" in error1 and "flask.py" in error1 assert "self.wsgi_app" in error1 From 6bd342a4feed571dcb87a3eeba1e96c06be35226 Mon Sep 17 00:00:00 2001 From: BSd3v <82055130+BSd3v@users.noreply.github.com> Date: Sat, 13 Sep 2025 00:20:12 -0400 Subject: [PATCH 049/166] fixing issue with fastapi and component suites --- dash/backend/fastapi.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dash/backend/fastapi.py b/dash/backend/fastapi.py index d283e90346..56f2761a3d 100644 --- a/dash/backend/fastapi.py +++ b/dash/backend/fastapi.py @@ -231,7 +231,7 @@ async def serve(request: Request, package_name: str, fingerprinted_path: str): # pylint: disable=protected-access dash_app._add_url( - "_dash-component-suites//", + "_dash-component-suites/{package_name}/{fingerprinted_path:path}", serve, ) From b1c99537c08a1d71ee544d5defaf9f021b2895b8 Mon Sep 17 00:00:00 2001 From: BSd3v <82055130+BSd3v@users.noreply.github.com> Date: Sat, 13 Sep 2025 00:29:10 -0400 Subject: [PATCH 050/166] adjustments to fix issues with caller_name and init the app a couple of times --- dash/dash.py | 6 ------ 1 file changed, 6 deletions(-) diff --git a/dash/dash.py b/dash/dash.py index 18c933f08c..19f789e7c0 100644 --- a/dash/dash.py +++ b/dash/dash.py @@ -567,9 +567,6 @@ def __init__( # pylint: disable=too-many-statements, too-many-branches backend_cls = get_backend(inferred_backend) self.backend = backend_cls() self.server = server - # Update caller_name from server's name attribute if available - if hasattr(server, "name"): - caller_name = server.name else: # No server instance provided, create backend and let backend create server self.backend = backend_cls() @@ -703,9 +700,6 @@ def __init__( # pylint: disable=too-many-statements, too-many-branches # tracks internally if a function already handled at least one request. self._got_first_request = {"pages": False, "setup_server": False} - if self.server is not None: - self.init_app() - self.logger.setLevel(logging.INFO) if self.__class__.__name__ == "JupyterDash": From bd40b56c9dd797d365d840945b69c749010d4ec8 Mon Sep 17 00:00:00 2001 From: BSd3v <82055130+BSd3v@users.noreply.github.com> Date: Sat, 13 Sep 2025 07:52:35 -0400 Subject: [PATCH 051/166] adjustments for failing tests --- dash/_pages.py | 4 ++-- dash/_utils.py | 5 +++++ dash/dash.py | 10 ++++++++++ tests/integration/multi_page/test_pages_layout.py | 3 ++- 4 files changed, 19 insertions(+), 3 deletions(-) diff --git a/dash/_pages.py b/dash/_pages.py index 6c00e656c7..acb26e8791 100644 --- a/dash/_pages.py +++ b/dash/_pages.py @@ -390,8 +390,8 @@ def _path_to_page(path_id): def _page_meta_tags(app, request): - request_url = request.get_path() - start_page, path_variables = _path_to_page(request_url.strip("/")) + request_path = request.get_path() + start_page, path_variables = _path_to_page(request_path.strip("/")) image = start_page.get("image", "") if image: diff --git a/dash/_utils.py b/dash/_utils.py index f118e61538..ef6c63c281 100644 --- a/dash/_utils.py +++ b/dash/_utils.py @@ -104,6 +104,11 @@ def set_read_only(self, names, msg="Attribute is read-only"): else: object.__setattr__(self, "_read_only", new_read_only) + def unset_read_only(self, keys): + if hasattr(self, "_read_only"): + for key in keys: + self._read_only.pop(key, None) + def finalize(self, msg="Object is final: No new keys may be added."): """Prevent any new keys being set.""" object.__setattr__(self, "_final", msg) diff --git a/dash/dash.py b/dash/dash.py index 19f789e7c0..c84b4476df 100644 --- a/dash/dash.py +++ b/dash/dash.py @@ -565,6 +565,8 @@ def __init__( # pylint: disable=too-many-statements, too-many-branches f"Conflict between provided backend '{backend}' and server type '{inferred_backend}'." ) backend_cls = get_backend(inferred_backend) + if name is None: + caller_name = getattr(server, "name", caller_name) self.backend = backend_cls() self.server = server else: @@ -700,6 +702,9 @@ def __init__( # pylint: disable=too-many-statements, too-many-branches # tracks internally if a function already handled at least one request. self._got_first_request = {"pages": False, "setup_server": False} + if self.server is not None: + self.init_app() + self.logger.setLevel(logging.INFO) if self.__class__.__name__ == "JupyterDash": @@ -743,6 +748,11 @@ def _setup_hooks(self): def init_app(self, app: Optional[Any] = None, **kwargs) -> None: config = self.config + config.unset_read_only([ + "url_base_pathname", + "routes_pathname_prefix", + "requests_pathname_prefix", + ]) config.update(kwargs) config.set_read_only( [ diff --git a/tests/integration/multi_page/test_pages_layout.py b/tests/integration/multi_page/test_pages_layout.py index 48751021b9..a209ae4517 100644 --- a/tests/integration/multi_page/test_pages_layout.py +++ b/tests/integration/multi_page/test_pages_layout.py @@ -3,6 +3,7 @@ from dash import Dash, Input, State, dcc, html, Output from dash.dash import _ID_LOCATION from dash.exceptions import NoLayoutException +from dash.testing.wait import until def get_app(path1="/", path2="/layout2"): @@ -57,7 +58,7 @@ def test_pala001_layout(dash_duo, clear_pages_state): for page in dash.page_registry.values(): dash_duo.find_element("#" + page["id"]).click() dash_duo.wait_for_text_to_equal("#text_" + page["id"], "text for " + page["id"]) - assert dash_duo.driver.title == page["title"], "check that page title updates" + until(lambda: dash_duo.driver.title == page["title"], timeout=3) # test redirects dash_duo.wait_for_page(url=f"{dash_duo.server_url}/v2") From 4e50430bd5abf0772afcaeb82aef2b08b4881642 Mon Sep 17 00:00:00 2001 From: BSd3v <82055130+BSd3v@users.noreply.github.com> Date: Sat, 13 Sep 2025 08:18:59 -0400 Subject: [PATCH 052/166] format dash --- dash/dash.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/dash/dash.py b/dash/dash.py index c84b4476df..747901bd9a 100644 --- a/dash/dash.py +++ b/dash/dash.py @@ -748,11 +748,13 @@ def _setup_hooks(self): def init_app(self, app: Optional[Any] = None, **kwargs) -> None: config = self.config - config.unset_read_only([ - "url_base_pathname", - "routes_pathname_prefix", - "requests_pathname_prefix", - ]) + config.unset_read_only( + [ + "url_base_pathname", + "routes_pathname_prefix", + "requests_pathname_prefix", + ] + ) config.update(kwargs) config.set_read_only( [ From 0d32e651e3d88cf9b2874422bfb5a6925d6d3518 Mon Sep 17 00:00:00 2001 From: BSd3v <82055130+BSd3v@users.noreply.github.com> Date: Sun, 14 Sep 2025 09:01:33 -0400 Subject: [PATCH 053/166] removing `FlaskDashServer` from import and using `get_backend('flask')` instead --- dash/dash.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/dash/dash.py b/dash/dash.py index 747901bd9a..d9ac42bddf 100644 --- a/dash/dash.py +++ b/dash/dash.py @@ -64,7 +64,6 @@ from . import _validate from . import _watch from . import _get_app -from .backend.flask import FlaskDashServer from ._get_app import with_app_context, with_app_context_factory from ._grouping import map_grouping, grouping_len, update_args_group @@ -526,7 +525,7 @@ def __init__( # pylint: disable=too-many-statements, too-many-branches # Determine backend if backend is None: - backend_cls = FlaskDashServer + backend_cls = get_backend('flask') elif isinstance(backend, str): backend_cls = get_backend(backend) elif isinstance(backend, type): From 1b3f61ea5d924015f4f1959b6f7cff56ccc134d6 Mon Sep 17 00:00:00 2001 From: BSd3v <82055130+BSd3v@users.noreply.github.com> Date: Sun, 14 Sep 2025 09:07:53 -0400 Subject: [PATCH 054/166] reverting change to callable(title) process --- dash/dash.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dash/dash.py b/dash/dash.py index d9ac42bddf..1963884072 100644 --- a/dash/dash.py +++ b/dash/dash.py @@ -2474,7 +2474,7 @@ def update(pathname_, search_, **states): **{**(path_variables or {}), **query_parameters, **states} ) if callable(title): - title = title(**{**(path_variables or {})}) + title = title(**(path_variables or {})) return layout, {"title": title} From c6805b5b6ac70b05ef0a73ce697fcf77f8a2d753 Mon Sep 17 00:00:00 2001 From: BSd3v <82055130+BSd3v@users.noreply.github.com> Date: Sun, 14 Sep 2025 17:11:11 -0400 Subject: [PATCH 055/166] fixing for lint --- dash/dash.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dash/dash.py b/dash/dash.py index 1963884072..18ad1c2367 100644 --- a/dash/dash.py +++ b/dash/dash.py @@ -525,7 +525,7 @@ def __init__( # pylint: disable=too-many-statements, too-many-branches # Determine backend if backend is None: - backend_cls = get_backend('flask') + backend_cls = get_backend("flask") elif isinstance(backend, str): backend_cls = get_backend(backend) elif isinstance(backend, type): From 8c7808962c3c7914ab82df2ea74d5691704a35c7 Mon Sep 17 00:00:00 2001 From: BSd3v <82055130+BSd3v@users.noreply.github.com> Date: Tue, 16 Sep 2025 16:18:30 -0400 Subject: [PATCH 056/166] adding custom error handling per backend, tests and adjustments to the flow. Made endpoints for downloading the reqs --- .github/workflows/testing.yml | 103 +++++++++ dash/backend/fastapi.py | 184 ++++++++++++--- dash/backend/flask.py | 54 ++++- dash/backend/quart.py | 121 +++++++++- .../error/FrontEnd/FrontEndError.react.js | 44 ++-- dash/dash.py | 68 +----- dash/testing/application_runners.py | 20 +- package.json | 2 +- requirements/fastapi.txt | 2 + requirements/quart.txt | 1 + tests/backend_tests/__init__.py | 0 .../backend_tests/test_preconfig_backends.py | 211 ++++++++++++++++++ 12 files changed, 688 insertions(+), 122 deletions(-) create mode 100644 requirements/fastapi.txt create mode 100644 requirements/quart.txt create mode 100644 tests/backend_tests/__init__.py create mode 100644 tests/backend_tests/test_preconfig_backends.py diff --git a/.github/workflows/testing.yml b/.github/workflows/testing.yml index 1fc0df1845..068fe777d1 100644 --- a/.github/workflows/testing.yml +++ b/.github/workflows/testing.yml @@ -271,6 +271,109 @@ jobs: cd bgtests pytest --headless --nopercyfinalize tests/async_tests -v -s + backend-tests: + name: Run Backend Callback Tests (Python ${{ matrix.python-version }}) + needs: [build, changes_filter] + if: | + (github.event_name == 'push' && (github.ref == 'refs/heads/master' || github.ref == 'refs/heads/dev')) || + needs.changes_filter.outputs.backend_tests_changed == 'true' + timeout-minutes: 30 + runs-on: ubuntu-latest + strategy: + fail-fast: false + matrix: + python-version: ["3.9", "3.12"] + + services: + redis: + image: redis:6 + ports: + - 6379:6379 + options: >- + --health-cmd "redis-cli ping" + --health-interval 10s + --health-timeout 5s + --health-retries 5 + + env: + REDIS_URL: redis://localhost:6379 + steps: + - name: Checkout repository + uses: actions/checkout@v4 + + - name: Set up Node.js + uses: actions/setup-node@v4 + with: + node-version: '20' + cache: 'npm' + + - name: Install Node.js dependencies + run: npm ci + + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@v5 + with: + python-version: ${{ matrix.python-version }} + cache: 'pip' + + - name: Download built Dash packages + uses: actions/download-artifact@v4 + with: + name: dash-packages + path: packages/ + + - name: Install Dash packages + run: | + python -m pip install --upgrade pip wheel + python -m pip install "setuptools<78.0.0" + python -m pip install "selenium==4.32.0" + find packages -name dash-*.whl -print -exec sh -c 'pip install "{}[async,ci,testing,dev,celery,diskcache, fastapi, quart]"' \; + + - name: Install Google Chrome + run: | + sudo apt-get update + sudo apt-get install -y google-chrome-stable + + - name: Install ChromeDriver + run: | + echo "Determining Chrome version..." + CHROME_BROWSER_VERSION=$(google-chrome --version) + echo "Installed Chrome Browser version: $CHROME_BROWSER_VERSION" + CHROME_MAJOR_VERSION=$(echo "$CHROME_BROWSER_VERSION" | cut -f 3 -d ' ' | cut -f 1 -d '.') + echo "Detected Chrome Major version: $CHROME_MAJOR_VERSION" + if [ "$CHROME_MAJOR_VERSION" -ge 115 ]; then + echo "Fetching ChromeDriver version for Chrome $CHROME_MAJOR_VERSION using CfT endpoint..." + CHROMEDRIVER_VERSION_STRING=$(curl -sS "https://googlechromelabs.github.io/chrome-for-testing/LATEST_RELEASE_${CHROME_MAJOR_VERSION}") + if [ -z "$CHROMEDRIVER_VERSION_STRING" ]; then + echo "Could not automatically find ChromeDriver version for Chrome $CHROME_MAJOR_VERSION via LATEST_RELEASE. Please check CfT endpoints." + exit 1 + fi + CHROMEDRIVER_URL="https://edgedl.me.gvt1.com/edgedl/chrome/chrome-for-testing/${CHROMEDRIVER_VERSION_STRING}/linux64/chromedriver-linux64.zip" + else + echo "Fetching ChromeDriver version for Chrome $CHROME_MAJOR_VERSION using older method..." + CHROMEDRIVER_VERSION_STRING=$(curl -sS "https://chromedriver.storage.googleapis.com/LATEST_RELEASE_${CHROME_MAJOR_VERSION}") + CHROMEDRIVER_URL="https://chromedriver.storage.googleapis.com/${CHROMEDRIVER_VERSION_STRING}/chromedriver_linux64.zip" + fi + echo "Using ChromeDriver version string: $CHROMEDRIVER_VERSION_STRING" + echo "Downloading ChromeDriver from: $CHROMEDRIVER_URL" + wget -q -O chromedriver.zip "$CHROMEDRIVER_URL" + unzip -o chromedriver.zip -d /tmp/ + sudo mv /tmp/chromedriver-linux64/chromedriver /usr/local/bin/chromedriver || sudo mv /tmp/chromedriver /usr/local/bin/chromedriver + sudo chmod +x /usr/local/bin/chromedriver + echo "/usr/local/bin" >> $GITHUB_PATH + shell: bash + + - name: Build/Setup test components + run: npm run setup-tests.py + + - name: Run Backend Callback Tests + run: | + mkdir bgtests + cp -r tests bgtests/tests + cd bgtests + touch __init__.py + pytest --headless --nopercyfinalize tests/backend_tests -v -s + table-unit: name: Table Unit/Lint Tests (Python ${{ matrix.python-version }}) needs: [build, changes_filter] diff --git a/dash/backend/fastapi.py b/dash/backend/fastapi.py index 56f2761a3d..0afcfabd07 100644 --- a/dash/backend/fastapi.py +++ b/dash/backend/fastapi.py @@ -6,6 +6,7 @@ from contextvars import copy_context import importlib.util import time +import traceback try: import uvicorn @@ -32,14 +33,28 @@ from dash.fingerprint import check_fingerprint from dash import _validate -from dash.exceptions import PreventUpdate, InvalidResourceError +from dash.exceptions import PreventUpdate, InvalidResourceError, InvalidCallbackReturnValue, BackgroundCallbackError from dash.backend import set_request_adapter from .base_server import BaseDashServer +import json +import os + +CONFIG_PATH = "dash_config.json" + +def save_config(config): + with open(CONFIG_PATH, "w") as f: + json.dump(config, f) + +def load_config(): + if os.path.exists(CONFIG_PATH): + with open(CONFIG_PATH, "r") as f: + return json.load(f) + return {} class FastAPIDashServer(BaseDashServer): def __init__(self): - self.config = {} + self.error_handling_mode = "prune" super().__init__() def __call__(self, server, *args, **kwargs): @@ -69,19 +84,120 @@ def register_assets_blueprint( pass def register_error_handlers(self, app): - @app.exception_handler(PreventUpdate) - async def _handle_error(_request, _exc): - return Response(status_code=204) + self.error_handling_mode = "prune" + # FastAPI uses exception handlers, but we will handle errors in middleware + pass + + def _get_traceback(self, secret, error: Exception): + tb = error.__traceback__ + errors = traceback.format_exception(type(error), error, tb) + pass_errs = [] + callback_handled = False + for err in errors: + if self.error_handling_mode == "prune": + if not callback_handled: + if 'callback invoked' in str(err) and '_callback.py' in str(err): + callback_handled = True + continue + pass_errs.append(err) + formatted_tb = "".join(pass_errs) + error_type = type(error).__name__ + error_msg = str(error) + + # Parse traceback lines to group by file + import re + file_cards = [] + pattern = re.compile(r' File "(.+)", line (\d+), in (\w+)') + lines = formatted_tb.split('\n') + current_file = None + card_lines = [] + + for i, line in enumerate(lines[:-1]): # Skip the last line (error message) + match = pattern.match(line) + if match: + if current_file and card_lines: + file_cards.append((current_file, card_lines)) + current_file = f'{match.group(1)} (line {match.group(2)}, in {match.group(3)})' + card_lines = [line] + elif current_file: + card_lines.append(line) + if current_file and card_lines: + file_cards.append((current_file, card_lines)) + + cards_html = "" + for filename, card in file_cards: + cards_html += f""" +
+
{filename}
+
"""+ '\n'.join(card) + """
+
+ """ + + html = f""" + + + + {error_type}: {error_msg} // FastAPI Debugger + + + +
+

{error_type}

+
+

{error_type}: {error_msg}

+
+

Traceback (most recent call last)

+ {cards_html} +
{error_type}: {error_msg}
+
+

This is the Copy/Paste friendly version of the traceback.

+ +
+
+ The debugger caught an exception in your ASGI application. You can now + look at the traceback which led to the error. +
+ +
+ + + """ + return html - @app.exception_handler(InvalidResourceError) - async def _invalid_resources_handler(_request, exc): - return Response(content=exc.args[0], status_code=404) + def register_prune_error_handler(self, _app, _secret, prune_errors): + if prune_errors: + self.error_handling_mode = "prune" + else: + self.error_handling_mode = "raise" - def register_prune_error_handler(self, app, secret, get_traceback_func): - @app.exception_handler(Exception) - async def _wrap_errors(_error_request, error): - tb = get_traceback_func(secret, error) - return PlainTextResponse(tb, status_code=500) def _html_response_wrapper(self, view_func): async def wrapped(*_args, **_kwargs): @@ -104,9 +220,10 @@ async def index(request: Request): def setup_catchall(self, dash_app): @dash_app.server.on_event("startup") def _setup_catchall(): + config = load_config() dash_app.enable_dev_tools( - **self.config, first_run=False - ) # do this to make sure dev tools are enabled + **config, first_run=False + ) async def catchall(request: Request): adapter = FastAPIRequestAdapter() @@ -141,11 +258,15 @@ def after_request(self, app, func): # FastAPI does not have after_request, but we can use middleware app.middleware("http")(self._make_after_middleware(func)) - def run(self, app, host, port, debug, **kwargs): + def run(self, dash_app, app, host, port, debug, **kwargs): frame = inspect.stack()[2] - self.config = dict({"debug": debug} if debug else {}, **kwargs) - reload = debug - if reload: + config = dict({"debug": debug} if debug else {}, **{ + f'dev_tools_{k}': v for k, v in dash_app._dev_tools.items()}) + save_config(config) + if debug: + if kwargs.get('reload') is None: + kwargs['reload'] = True + if kwargs.get('reload'): # Dynamically determine the module name from the file path file_path = frame.filename module_name = importlib.util.spec_from_file_location("app", file_path).name @@ -153,11 +274,10 @@ def run(self, app, host, port, debug, **kwargs): f"{module_name}:app.server", host=host, port=port, - reload=reload, **kwargs, ) else: - uvicorn.run(app, host=host, port=port, reload=reload, **kwargs) + uvicorn.run(app, host=host, port=port, **kwargs) def make_response(self, data, mimetype=None, content_type=None): headers = {} @@ -175,13 +295,21 @@ def get_request_adapter(self): def _make_before_middleware(self, func): async def middleware(request, call_next): - if func is not None: - if inspect.iscoroutinefunction(func): - await func() - else: - func() - response = await call_next(request) - return response + try: + response = await call_next(request) + return response + except PreventUpdate: + # No content, nothing to update + return Response(status_code=204) + except Exception as e: + if self.error_handling_mode in ["raise", "prune"]: + # Prune the traceback to remove internal Dash calls + tb = self._get_traceback(None, e) + return Response(content=tb, media_type='text/html', status_code=500) + return JSONResponse( + status_code=500, + content={"error": "InternalServerError", "message": str(e.args[0])}, + ) return middleware diff --git a/dash/backend/flask.py b/dash/backend/flask.py index b48225a3c5..75526e6feb 100644 --- a/dash/backend/flask.py +++ b/dash/backend/flask.py @@ -11,6 +11,7 @@ from dash.exceptions import PreventUpdate, InvalidResourceError from dash.backend import set_request_adapter from .base_server import BaseDashServer +import traceback class FlaskDashServer(BaseDashServer): @@ -44,11 +45,52 @@ def _handle_error(_): def _invalid_resources_handler(err): return err.args[0], 404 - def register_prune_error_handler(self, app, secret, get_traceback_func): - @app.errorhandler(Exception) - def _wrap_errors(error): - tb = get_traceback_func(secret, error) - return tb, 500 + def _get_traceback(self, secret, error: Exception): + try: + from werkzeug.debug import tbtools + except ImportError: + tbtools = None + + def _get_skip(error): + from dash._callback import _invoke_callback, _async_invoke_callback + + tb = error.__traceback__ + skip = 1 + while tb.tb_next is not None: + skip += 1 + tb = tb.tb_next + if tb.tb_frame.f_code in [ + _invoke_callback.__code__, + _async_invoke_callback.__code__, + ]: + return skip + return skip + + def _do_skip(error): + from dash._callback import _invoke_callback, _async_invoke_callback + + tb = error.__traceback__ + while tb.tb_next is not None: + if tb.tb_frame.f_code in [ + _invoke_callback.__code__, + _async_invoke_callback.__code__, + ]: + return tb.tb_next + tb = tb.tb_next + return error.__traceback__ + + if hasattr(tbtools, "get_current_traceback"): + return tbtools.get_current_traceback(skip=_get_skip(error)).render_full() + if hasattr(tbtools, "DebugTraceback"): + return tbtools.DebugTraceback(error, skip=_get_skip(error)).render_debugger_html(True, secret, True) + return "".join(traceback.format_exception(type(error), error, _do_skip(error))) + + def register_prune_error_handler(self, app, secret, prune_errors): + if prune_errors: + @app.errorhandler(Exception) + def _wrap_errors(error): + tb = self._get_traceback(secret, error) + return tb, 500 def add_url_rule(self, app, rule, view_func, endpoint=None, methods=None): app.add_url_rule( @@ -61,7 +103,7 @@ def before_request(self, app, func): def after_request(self, app, func): app.after_request(func) - def run(self, app, host, port, debug, **kwargs): + def run(self, _dash_app, app, host, port, debug, **kwargs): app.run(host=host, port=port, debug=debug, **kwargs) def make_response(self, data, mimetype=None, content_type=None): diff --git a/dash/backend/quart.py b/dash/backend/quart.py index c3d42dadee..40f30108b2 100644 --- a/dash/backend/quart.py +++ b/dash/backend/quart.py @@ -4,6 +4,7 @@ import sys import time from contextvars import copy_context +import traceback try: import quart @@ -31,6 +32,7 @@ class QuartDashServer(BaseDashServer): def __init__(self) -> None: self.config = {} + self.error_handling_mode = "prune" super().__init__() def __call__(self, server, *args, **kwargs): @@ -54,11 +56,120 @@ def register_assets_blueprint( ) app.register_blueprint(bp) - def register_prune_error_handler(self, app, secret, get_traceback_func): + def _get_traceback(self, secret, error: Exception): + tb = error.__traceback__ + errors = traceback.format_exception(type(error), error, tb) + pass_errs = [] + callback_handled = False + for err in errors: + if self.error_handling_mode == "prune": + if not callback_handled: + if 'callback invoked' in str(err) and '_callback.py' in str(err): + callback_handled = True + continue + pass_errs.append(err) + formatted_tb = "".join(pass_errs) + error_type = type(error).__name__ + error_msg = str(error) + + # Parse traceback lines to group by file + import re + file_cards = [] + pattern = re.compile(r' File "(.+)", line (\d+), in (\w+)') + lines = formatted_tb.split('\n') + current_file = None + card_lines = [] + + for i, line in enumerate(lines[:-1]): # Skip the last line (error message) + match = pattern.match(line) + if match: + if current_file and card_lines: + file_cards.append((current_file, card_lines)) + current_file = f'{match.group(1)} (line {match.group(2)}, in {match.group(3)})' + card_lines = [line] + elif current_file: + card_lines.append(line) + if current_file and card_lines: + file_cards.append((current_file, card_lines)) + + cards_html = "" + for filename, card in file_cards: + cards_html += f""" +
+
{filename}
+
""" + '\n'.join(card) + """
+
+ """ + + html = f""" + + + + {error_type}: {error_msg} // Quart Debugger + + + +
+

{error_type}

+
+

{error_type}: {error_msg}

+
+

Traceback (most recent call last)

+ {cards_html} +
{error_type}: {error_msg}
+
+

This is the Copy/Paste friendly version of the traceback.

+ +
+
+ The debugger caught an exception in your ASGI application. You can now + look at the traceback which led to the error. +
+ +
+ + + """ + return html + + def register_prune_error_handler(self, app, secret, prune_errors): + if prune_errors: + self.error_handling_mode = "prune" + else: + self.error_handling_mode = "raise" + @app.errorhandler(Exception) - async def _wrap_errors(_error_request, error): - tb = get_traceback_func(secret, error) - return tb, 500 + async def _wrap_errors(error): + tb = self._get_traceback(secret, error) + return Response(tb, status=500, content_type='text/html') def register_timing_hooks(self, app, _first_run): # parity with Flask factory @app.before_request @@ -146,7 +257,7 @@ async def _after(response): await result return response - def run(self, app, host, port, debug, **kwargs): + def run(self, _dash_app, app, host, port, debug, **kwargs): self.config = {"debug": debug, **kwargs} if debug else kwargs app.run(host=host, port=port, debug=debug, **kwargs) diff --git a/dash/dash-renderer/src/components/error/FrontEnd/FrontEndError.react.js b/dash/dash-renderer/src/components/error/FrontEnd/FrontEndError.react.js index 176cb2c6f8..ab5430e7da 100644 --- a/dash/dash-renderer/src/components/error/FrontEnd/FrontEndError.react.js +++ b/dash/dash-renderer/src/components/error/FrontEnd/FrontEndError.react.js @@ -121,13 +121,17 @@ function BackendError({error, base}) { const MAX_MESSAGE_LENGTH = 40; /* eslint-disable no-inline-comments */ function UnconnectedErrorContent({error, base}) { + // Helper to detect full HTML document + const isFullHtmlDoc = typeof error.html === 'string' && + error.html.trim().toLowerCase().startsWith(' - {/* - * 40 is a rough heuristic - if longer than 40 then the - * message might overflow into ellipses in the title above & - * will need to be displayed in full in this error body - */} + {/* Frontend error message */} {typeof error.message !== 'string' || error.message.length < MAX_MESSAGE_LENGTH ? null : (
@@ -137,6 +141,7 @@ function UnconnectedErrorContent({error, base}) {
)} + {/* Frontend stack trace */} {typeof error.stack !== 'string' ? null : (
@@ -149,7 +154,6 @@ function UnconnectedErrorContent({error, base}) { browser's console.) - {error.stack.split('\n').map((line, i) => (

{line}

))} @@ -157,24 +161,30 @@ function UnconnectedErrorContent({error, base}) {
)} - {/* Backend Error */} - {typeof error.html !== 'string' ? null : error.html - .substring(0, '
- {/* Embed werkzeug debugger in an iframe to prevent - CSS leaking - werkzeug HTML includes a bunch - of CSS on base html elements like `` - */}
- ) : ( + ) : isHtmlFragment ? ( + // Backend error: HTML fragment +
+
+
+ ) : typeof error.html === 'string' ? ( + // Backend error: plain text
-
{error.html}
+
+
{error.html}
+
- )} + ) : null}
); } diff --git a/dash/dash.py b/dash/dash.py index 18ad1c2367..fa1aa45ea5 100644 --- a/dash/dash.py +++ b/dash/dash.py @@ -185,63 +185,6 @@ def _is_quart_instance(obj): return False -def _get_traceback(secret, error: Exception): - try: - # pylint: disable=import-outside-toplevel - from werkzeug.debug import tbtools - except ImportError: - tbtools = None - - def _get_skip(error): - from dash._callback import ( # pylint: disable=import-outside-toplevel - _invoke_callback, - _async_invoke_callback, - ) - - tb = error.__traceback__ - skip = 1 - while tb.tb_next is not None: - skip += 1 - tb = tb.tb_next - if tb.tb_frame.f_code in [ - _invoke_callback.__code__, - _async_invoke_callback.__code__, - ]: - return skip - - return skip - - def _do_skip(error): - from dash._callback import ( # pylint: disable=import-outside-toplevel - _invoke_callback, - _async_invoke_callback, - ) - - tb = error.__traceback__ - while tb.tb_next is not None: - if tb.tb_frame.f_code in [ - _invoke_callback.__code__, - _async_invoke_callback.__code__, - ]: - return tb.tb_next - tb = tb.tb_next - return error.__traceback__ - - # werkzeug<2.1.0 - if hasattr(tbtools, "get_current_traceback"): - return tbtools.get_current_traceback( # type: ignore - skip=_get_skip(error) - ).render_full() - - if hasattr(tbtools, "DebugTraceback"): - # pylint: disable=no-member - return tbtools.DebugTraceback( # type: ignore - error, skip=_get_skip(error) - ).render_debugger_html(True, secret, True) - - return "".join(traceback.format_exception(type(error), error, _do_skip(error))) - - # Singleton signal to not update an output, alternative to PreventUpdate no_update = _callback.NoUpdate() # pylint: disable=protected-access @@ -2058,11 +2001,10 @@ def enable_dev_tools( jupyter_dash.configure_callback_exception_handling( self, dev_tools.prune_errors ) - elif dev_tools.prune_errors: - secret = gen_salt(20) - self.backend.register_prune_error_handler( - self.server, secret, _get_traceback - ) + secret = gen_salt(20) + self.backend.register_prune_error_handler( + self.server, secret, dev_tools.prune_errors + ) if debug and dev_tools.ui: self.backend.register_timing_hooks(self.server, first_run) @@ -2350,7 +2292,7 @@ def verify_url_part(served_part, url_part, part_name): ) else: self.backend.run( - self.server, host=host, port=port, debug=debug, **flask_run_options + self, self.server, host=host, port=port, debug=debug, **flask_run_options ) def enable_pages(self) -> None: diff --git a/dash/testing/application_runners.py b/dash/testing/application_runners.py index dc88afe844..df036aabfa 100644 --- a/dash/testing/application_runners.py +++ b/dash/testing/application_runners.py @@ -171,7 +171,15 @@ def run(): self.port = options["port"] try: - app.run(threaded=True, **options) + module = app.server.__class__.__module__ + # FastAPI support + if not module.startswith("flask"): + app.run( + **options + ) + # Dash/Flask fallback + else: + app.run(threaded=True, **options) except SystemExit: logger.info("Server stopped") except Exception as error: @@ -229,7 +237,15 @@ def target(): options = kwargs.copy() try: - app.run(threaded=True, **options) + module = app.server.__class__.__module__ + # FastAPI support + if not module.startswith("flask"): + app.run( + **options + ) + # Dash/Flask fallback + else: + app.run(threaded=True, **options) except SystemExit: logger.info("Server stopped") raise diff --git a/package.json b/package.json index e78e279c1b..b7416dbb34 100644 --- a/package.json +++ b/package.json @@ -44,7 +44,7 @@ "setup-tests.R": "run-s private::test.R.deploy-*", "citest.integration": "run-s setup-tests.py private::test.integration-*", "citest.unit": "run-s private::test.unit-**", - "test": "pytest && cd dash/dash-renderer && npm run test", + "test": "pytest --ignore=tests/backend_tests && cd dash/dash-renderer && npm run test", "first-build": "cd dash/dash-renderer && npm i && cd ../../ && cd components/dash-html-components && npm i && npm run extract && cd ../../ && npm run build" }, "devDependencies": { diff --git a/requirements/fastapi.txt b/requirements/fastapi.txt new file mode 100644 index 0000000000..97dc7cd8c1 --- /dev/null +++ b/requirements/fastapi.txt @@ -0,0 +1,2 @@ +fastapi +uvicorn diff --git a/requirements/quart.txt b/requirements/quart.txt new file mode 100644 index 0000000000..60af440c9c --- /dev/null +++ b/requirements/quart.txt @@ -0,0 +1 @@ +quart diff --git a/tests/backend_tests/__init__.py b/tests/backend_tests/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/backend_tests/test_preconfig_backends.py b/tests/backend_tests/test_preconfig_backends.py new file mode 100644 index 0000000000..4868406814 --- /dev/null +++ b/tests/backend_tests/test_preconfig_backends.py @@ -0,0 +1,211 @@ +import pytest +from dash import Dash, Input, Output, html, dcc + +@pytest.mark.parametrize( + "backend,fixture,input_value", + [ + ( + "fastapi", + "dash_duo", + "Hello FastAPI!" + ), + ( + "quart", + "dash_duo_mp", + "Hello Quart!" + ), + ] +) +def test_backend_basic_callback(request, backend, fixture, input_value): + dash_duo = request.getfixturevalue(fixture) + if backend == "fastapi": + from fastapi import FastAPI + server = FastAPI() + else: + import quart + server = quart.Quart(__name__) + app = Dash(__name__, server=server) + app.layout = html.Div([ + dcc.Input(id="input", value=input_value, type="text"), + html.Div(id="output") + ]) + + @app.callback(Output("output", "children"), Input("input", "value")) + def update_output(value): + return f"You typed: {value}" + + dash_duo.start_server(app) + dash_duo.wait_for_text_to_equal("#output", f"You typed: {input_value}") + dash_duo.find_element("#input").clear() + dash_duo.find_element("#input").send_keys(f"{backend.title()} Test") + dash_duo.wait_for_text_to_equal("#output", f"You typed: {backend.title()} Test") + assert dash_duo.get_logs() == [] + +@pytest.mark.parametrize( + "backend,fixture,start_server_kwargs", + [ + ( + "fastapi", + "dash_duo", + {"debug": True, "reload": False, "dev_tools_ui": True}, + ), + ( + "quart", + "dash_duo_mp", + { + "debug": True, + "use_reloader": False, + "dev_tools_hot_reload": False, + }, + ), + ] +) +def test_backend_error_handling(request, backend, fixture, start_server_kwargs): + dash_duo = request.getfixturevalue(fixture) + app = Dash(__name__, backend=backend) + app.layout = html.Div([ + html.Button(id="btn", children="Error", n_clicks=0), + html.Div(id="output") + ]) + + @app.callback(Output("output", "children"), Input("btn", "n_clicks")) + def error_callback(n): + if n and n > 0: + return 1 / 0 # Intentional error + return "No error" + + dash_duo.start_server(app, **start_server_kwargs) + dash_duo.wait_for_text_to_equal("#output", "No error") + dash_duo.find_element("#btn").click() + dash_duo.wait_for_text_to_equal(dash_duo.devtools_error_count_locator, "1") + +def get_error_html(dash_duo, index): + # error is in an iframe so is annoying to read out - get it from the store + return dash_duo.driver.execute_script( + "return store.getState().error.backEnd[{}].error.html;".format(index) + ) + +@pytest.mark.parametrize( + "backend,fixture,start_server_kwargs, error_msg", + [ + ( + "fastapi", + "dash_duo", + {"debug": True, "dev_tools_ui": True, "dev_tools_prune_errors": False, + "reload": False}, + "fastapi.py" + ), + ( + "quart", + "dash_duo_mp", + { + "debug": True, + "use_reloader": False, + "dev_tools_hot_reload": False, + "dev_tools_prune_errors": False, + }, + "quart.py" + ), + ] +) +def test_backend_error_handling_no_prune(request, backend, fixture, start_server_kwargs, error_msg): + dash_duo = request.getfixturevalue(fixture) + app = Dash(__name__, backend=backend) + app.layout = html.Div([ + html.Button(id="btn", children="Error", n_clicks=0), + html.Div(id="output") + ]) + + @app.callback(Output("output", "children"), Input("btn", "n_clicks")) + def error_callback(n): + if n and n > 0: + return 1 / 0 # Intentional error + return "No error" + + dash_duo.start_server(app, **start_server_kwargs) + dash_duo.wait_for_text_to_equal("#output", "No error") + dash_duo.find_element("#btn").click() + dash_duo.wait_for_text_to_equal(dash_duo.devtools_error_count_locator, "1") + + error0 = get_error_html(dash_duo, 0) + assert "in error_callback" in error0 + assert "ZeroDivisionError" in error0 + assert "backend" in error0 and error_msg in error0 + +@pytest.mark.parametrize( + "backend,fixture,start_server_kwargs, error_msg", + [ + ( + "fastapi", + "dash_duo", + {"debug": True, + "reload": False}, + "fastapi.py" + ), + ( + "quart", + "dash_duo_mp", + { + "debug": True, + "use_reloader": False, + "dev_tools_hot_reload": False, + }, + "quart.py" + ), + ] +) +def test_backend_error_handling_prune(request, backend, fixture, start_server_kwargs, error_msg): + dash_duo = request.getfixturevalue(fixture) + app = Dash(__name__, backend=backend) + app.layout = html.Div([ + html.Button(id="btn", children="Error", n_clicks=0), + html.Div(id="output") + ]) + + @app.callback(Output("output", "children"), Input("btn", "n_clicks")) + def error_callback(n): + if n and n > 0: + return 1 / 0 # Intentional error + return "No error" + + dash_duo.start_server(app, **start_server_kwargs) + dash_duo.wait_for_text_to_equal("#output", "No error") + dash_duo.find_element("#btn").click() + dash_duo.wait_for_text_to_equal(dash_duo.devtools_error_count_locator, "1") + + error0 = get_error_html(dash_duo, 0) + assert "in error_callback" in error0 + assert "ZeroDivisionError" in error0 + assert "dash/backend" not in error0 and error_msg not in error0 + +@pytest.mark.parametrize( + "backend,fixture,input_value", + [ + ("fastapi", "dash_duo", "Background FastAPI!"), + ("quart", "dash_duo_mp", "Background Quart!"), + ] +) +def test_backend_background_callback(request, backend, fixture, input_value): + dash_duo = request.getfixturevalue(fixture) + import diskcache + cache = diskcache.Cache("./cache") + from dash.background_callback import DiskcacheManager + background_callback_manager = DiskcacheManager(cache) + + + app = Dash(__name__, backend=backend, background_callback_manager=background_callback_manager) + app.layout = html.Div([ + dcc.Input(id="input", value=input_value, type="text"), + html.Div(id="output") + ]) + + @app.callback(Output("output", "children"), Input("input", "value"), background=True) + def update_output_bg(value): + return f"Background typed: {value}" + + dash_duo.start_server(app) + dash_duo.wait_for_text_to_equal("#output", f"Background typed: {input_value}") + dash_duo.find_element("#input").clear() + dash_duo.find_element("#input").send_keys(f"{backend.title()} BG Test") + dash_duo.wait_for_text_to_equal("#output", f"Background typed: {backend.title()} BG Test") + assert dash_duo.get_logs() == [] From 5211f6fb43f335b5c99e37859f5f2f5ec2dbe729 Mon Sep 17 00:00:00 2001 From: BSd3v <82055130+BSd3v@users.noreply.github.com> Date: Tue, 16 Sep 2025 16:20:56 -0400 Subject: [PATCH 057/166] adjusments for formatting --- dash/backend/fastapi.py | 46 ++++--- dash/backend/flask.py | 5 +- dash/backend/quart.py | 19 ++- .../error/FrontEnd/FrontEndError.react.js | 7 +- dash/dash.py | 7 +- dash/testing/application_runners.py | 8 +- .../backend_tests/test_preconfig_backends.py | 112 +++++++++--------- 7 files changed, 118 insertions(+), 86 deletions(-) diff --git a/dash/backend/fastapi.py b/dash/backend/fastapi.py index 0afcfabd07..a76a5a47ec 100644 --- a/dash/backend/fastapi.py +++ b/dash/backend/fastapi.py @@ -33,7 +33,12 @@ from dash.fingerprint import check_fingerprint from dash import _validate -from dash.exceptions import PreventUpdate, InvalidResourceError, InvalidCallbackReturnValue, BackgroundCallbackError +from dash.exceptions import ( + PreventUpdate, + InvalidResourceError, + InvalidCallbackReturnValue, + BackgroundCallbackError, +) from dash.backend import set_request_adapter from .base_server import BaseDashServer @@ -42,16 +47,19 @@ CONFIG_PATH = "dash_config.json" + def save_config(config): with open(CONFIG_PATH, "w") as f: json.dump(config, f) + def load_config(): if os.path.exists(CONFIG_PATH): with open(CONFIG_PATH, "r") as f: return json.load(f) return {} + class FastAPIDashServer(BaseDashServer): def __init__(self): self.error_handling_mode = "prune" @@ -96,7 +104,7 @@ def _get_traceback(self, secret, error: Exception): for err in errors: if self.error_handling_mode == "prune": if not callback_handled: - if 'callback invoked' in str(err) and '_callback.py' in str(err): + if "callback invoked" in str(err) and "_callback.py" in str(err): callback_handled = True continue pass_errs.append(err) @@ -106,9 +114,10 @@ def _get_traceback(self, secret, error: Exception): # Parse traceback lines to group by file import re + file_cards = [] pattern = re.compile(r' File "(.+)", line (\d+), in (\w+)') - lines = formatted_tb.split('\n') + lines = formatted_tb.split("\n") current_file = None card_lines = [] @@ -117,7 +126,9 @@ def _get_traceback(self, secret, error: Exception): if match: if current_file and card_lines: file_cards.append((current_file, card_lines)) - current_file = f'{match.group(1)} (line {match.group(2)}, in {match.group(3)})' + current_file = ( + f"{match.group(1)} (line {match.group(2)}, in {match.group(3)})" + ) card_lines = [line] elif current_file: card_lines.append(line) @@ -126,12 +137,16 @@ def _get_traceback(self, secret, error: Exception): cards_html = "" for filename, card in file_cards: - cards_html += f""" + cards_html += ( + f"""
{filename}
-
"""+ '\n'.join(card) + """
+
"""
+                + "\n".join(card)
+                + """
""" + ) html = f""" @@ -198,7 +213,6 @@ def register_prune_error_handler(self, _app, _secret, prune_errors): else: self.error_handling_mode = "raise" - def _html_response_wrapper(self, view_func): async def wrapped(*_args, **_kwargs): # If view_func is a function, call it; if it's a string, use it directly @@ -221,9 +235,7 @@ def setup_catchall(self, dash_app): @dash_app.server.on_event("startup") def _setup_catchall(): config = load_config() - dash_app.enable_dev_tools( - **config, first_run=False - ) + dash_app.enable_dev_tools(**config, first_run=False) async def catchall(request: Request): adapter = FastAPIRequestAdapter() @@ -260,13 +272,15 @@ def after_request(self, app, func): def run(self, dash_app, app, host, port, debug, **kwargs): frame = inspect.stack()[2] - config = dict({"debug": debug} if debug else {}, **{ - f'dev_tools_{k}': v for k, v in dash_app._dev_tools.items()}) + config = dict( + {"debug": debug} if debug else {}, + **{f"dev_tools_{k}": v for k, v in dash_app._dev_tools.items()}, + ) save_config(config) if debug: - if kwargs.get('reload') is None: - kwargs['reload'] = True - if kwargs.get('reload'): + if kwargs.get("reload") is None: + kwargs["reload"] = True + if kwargs.get("reload"): # Dynamically determine the module name from the file path file_path = frame.filename module_name = importlib.util.spec_from_file_location("app", file_path).name @@ -305,7 +319,7 @@ async def middleware(request, call_next): if self.error_handling_mode in ["raise", "prune"]: # Prune the traceback to remove internal Dash calls tb = self._get_traceback(None, e) - return Response(content=tb, media_type='text/html', status_code=500) + return Response(content=tb, media_type="text/html", status_code=500) return JSONResponse( status_code=500, content={"error": "InternalServerError", "message": str(e.args[0])}, diff --git a/dash/backend/flask.py b/dash/backend/flask.py index 75526e6feb..542da93129 100644 --- a/dash/backend/flask.py +++ b/dash/backend/flask.py @@ -82,11 +82,14 @@ def _do_skip(error): if hasattr(tbtools, "get_current_traceback"): return tbtools.get_current_traceback(skip=_get_skip(error)).render_full() if hasattr(tbtools, "DebugTraceback"): - return tbtools.DebugTraceback(error, skip=_get_skip(error)).render_debugger_html(True, secret, True) + return tbtools.DebugTraceback( + error, skip=_get_skip(error) + ).render_debugger_html(True, secret, True) return "".join(traceback.format_exception(type(error), error, _do_skip(error))) def register_prune_error_handler(self, app, secret, prune_errors): if prune_errors: + @app.errorhandler(Exception) def _wrap_errors(error): tb = self._get_traceback(secret, error) diff --git a/dash/backend/quart.py b/dash/backend/quart.py index 40f30108b2..71a2053a61 100644 --- a/dash/backend/quart.py +++ b/dash/backend/quart.py @@ -64,7 +64,7 @@ def _get_traceback(self, secret, error: Exception): for err in errors: if self.error_handling_mode == "prune": if not callback_handled: - if 'callback invoked' in str(err) and '_callback.py' in str(err): + if "callback invoked" in str(err) and "_callback.py" in str(err): callback_handled = True continue pass_errs.append(err) @@ -74,9 +74,10 @@ def _get_traceback(self, secret, error: Exception): # Parse traceback lines to group by file import re + file_cards = [] pattern = re.compile(r' File "(.+)", line (\d+), in (\w+)') - lines = formatted_tb.split('\n') + lines = formatted_tb.split("\n") current_file = None card_lines = [] @@ -85,7 +86,9 @@ def _get_traceback(self, secret, error: Exception): if match: if current_file and card_lines: file_cards.append((current_file, card_lines)) - current_file = f'{match.group(1)} (line {match.group(2)}, in {match.group(3)})' + current_file = ( + f"{match.group(1)} (line {match.group(2)}, in {match.group(3)})" + ) card_lines = [line] elif current_file: card_lines.append(line) @@ -94,12 +97,16 @@ def _get_traceback(self, secret, error: Exception): cards_html = "" for filename, card in file_cards: - cards_html += f""" + cards_html += ( + f"""
{filename}
-
""" + '\n'.join(card) + """
+
"""
+                + "\n".join(card)
+                + """
""" + ) html = f""" @@ -169,7 +176,7 @@ def register_prune_error_handler(self, app, secret, prune_errors): @app.errorhandler(Exception) async def _wrap_errors(error): tb = self._get_traceback(secret, error) - return Response(tb, status=500, content_type='text/html') + return Response(tb, status=500, content_type="text/html") def register_timing_hooks(self, app, _first_run): # parity with Flask factory @app.before_request diff --git a/dash/dash-renderer/src/components/error/FrontEnd/FrontEndError.react.js b/dash/dash-renderer/src/components/error/FrontEnd/FrontEndError.react.js index ab5430e7da..db4c6ddd2b 100644 --- a/dash/dash-renderer/src/components/error/FrontEnd/FrontEndError.react.js +++ b/dash/dash-renderer/src/components/error/FrontEnd/FrontEndError.react.js @@ -122,12 +122,13 @@ const MAX_MESSAGE_LENGTH = 40; /* eslint-disable no-inline-comments */ function UnconnectedErrorContent({error, base}) { // Helper to detect full HTML document - const isFullHtmlDoc = typeof error.html === 'string' && + const isFullHtmlDoc = + typeof error.html === 'string' && error.html.trim().toLowerCase().startsWith(' diff --git a/dash/dash.py b/dash/dash.py index fa1aa45ea5..994453f4a2 100644 --- a/dash/dash.py +++ b/dash/dash.py @@ -2292,7 +2292,12 @@ def verify_url_part(served_part, url_part, part_name): ) else: self.backend.run( - self, self.server, host=host, port=port, debug=debug, **flask_run_options + self, + self.server, + host=host, + port=port, + debug=debug, + **flask_run_options, ) def enable_pages(self) -> None: diff --git a/dash/testing/application_runners.py b/dash/testing/application_runners.py index df036aabfa..2956f1a4c0 100644 --- a/dash/testing/application_runners.py +++ b/dash/testing/application_runners.py @@ -174,9 +174,7 @@ def run(): module = app.server.__class__.__module__ # FastAPI support if not module.startswith("flask"): - app.run( - **options - ) + app.run(**options) # Dash/Flask fallback else: app.run(threaded=True, **options) @@ -240,9 +238,7 @@ def target(): module = app.server.__class__.__module__ # FastAPI support if not module.startswith("flask"): - app.run( - **options - ) + app.run(**options) # Dash/Flask fallback else: app.run(threaded=True, **options) diff --git a/tests/backend_tests/test_preconfig_backends.py b/tests/backend_tests/test_preconfig_backends.py index 4868406814..5fbd28dfd9 100644 --- a/tests/backend_tests/test_preconfig_backends.py +++ b/tests/backend_tests/test_preconfig_backends.py @@ -1,34 +1,28 @@ import pytest from dash import Dash, Input, Output, html, dcc + @pytest.mark.parametrize( "backend,fixture,input_value", [ - ( - "fastapi", - "dash_duo", - "Hello FastAPI!" - ), - ( - "quart", - "dash_duo_mp", - "Hello Quart!" - ), - ] + ("fastapi", "dash_duo", "Hello FastAPI!"), + ("quart", "dash_duo_mp", "Hello Quart!"), + ], ) def test_backend_basic_callback(request, backend, fixture, input_value): dash_duo = request.getfixturevalue(fixture) if backend == "fastapi": from fastapi import FastAPI + server = FastAPI() else: import quart + server = quart.Quart(__name__) app = Dash(__name__, server=server) - app.layout = html.Div([ - dcc.Input(id="input", value=input_value, type="text"), - html.Div(id="output") - ]) + app.layout = html.Div( + [dcc.Input(id="input", value=input_value, type="text"), html.Div(id="output")] + ) @app.callback(Output("output", "children"), Input("input", "value")) def update_output(value): @@ -41,6 +35,7 @@ def update_output(value): dash_duo.wait_for_text_to_equal("#output", f"You typed: {backend.title()} Test") assert dash_duo.get_logs() == [] + @pytest.mark.parametrize( "backend,fixture,start_server_kwargs", [ @@ -58,15 +53,14 @@ def update_output(value): "dev_tools_hot_reload": False, }, ), - ] + ], ) def test_backend_error_handling(request, backend, fixture, start_server_kwargs): dash_duo = request.getfixturevalue(fixture) app = Dash(__name__, backend=backend) - app.layout = html.Div([ - html.Button(id="btn", children="Error", n_clicks=0), - html.Div(id="output") - ]) + app.layout = html.Div( + [html.Button(id="btn", children="Error", n_clicks=0), html.Div(id="output")] + ) @app.callback(Output("output", "children"), Input("btn", "n_clicks")) def error_callback(n): @@ -79,21 +73,27 @@ def error_callback(n): dash_duo.find_element("#btn").click() dash_duo.wait_for_text_to_equal(dash_duo.devtools_error_count_locator, "1") + def get_error_html(dash_duo, index): # error is in an iframe so is annoying to read out - get it from the store return dash_duo.driver.execute_script( "return store.getState().error.backEnd[{}].error.html;".format(index) ) + @pytest.mark.parametrize( "backend,fixture,start_server_kwargs, error_msg", [ ( "fastapi", "dash_duo", - {"debug": True, "dev_tools_ui": True, "dev_tools_prune_errors": False, - "reload": False}, - "fastapi.py" + { + "debug": True, + "dev_tools_ui": True, + "dev_tools_prune_errors": False, + "reload": False, + }, + "fastapi.py", ), ( "quart", @@ -104,17 +104,18 @@ def get_error_html(dash_duo, index): "dev_tools_hot_reload": False, "dev_tools_prune_errors": False, }, - "quart.py" + "quart.py", ), - ] + ], ) -def test_backend_error_handling_no_prune(request, backend, fixture, start_server_kwargs, error_msg): +def test_backend_error_handling_no_prune( + request, backend, fixture, start_server_kwargs, error_msg +): dash_duo = request.getfixturevalue(fixture) app = Dash(__name__, backend=backend) - app.layout = html.Div([ - html.Button(id="btn", children="Error", n_clicks=0), - html.Div(id="output") - ]) + app.layout = html.Div( + [html.Button(id="btn", children="Error", n_clicks=0), html.Div(id="output")] + ) @app.callback(Output("output", "children"), Input("btn", "n_clicks")) def error_callback(n): @@ -132,16 +133,11 @@ def error_callback(n): assert "ZeroDivisionError" in error0 assert "backend" in error0 and error_msg in error0 + @pytest.mark.parametrize( "backend,fixture,start_server_kwargs, error_msg", [ - ( - "fastapi", - "dash_duo", - {"debug": True, - "reload": False}, - "fastapi.py" - ), + ("fastapi", "dash_duo", {"debug": True, "reload": False}, "fastapi.py"), ( "quart", "dash_duo_mp", @@ -150,17 +146,18 @@ def error_callback(n): "use_reloader": False, "dev_tools_hot_reload": False, }, - "quart.py" + "quart.py", ), - ] + ], ) -def test_backend_error_handling_prune(request, backend, fixture, start_server_kwargs, error_msg): +def test_backend_error_handling_prune( + request, backend, fixture, start_server_kwargs, error_msg +): dash_duo = request.getfixturevalue(fixture) app = Dash(__name__, backend=backend) - app.layout = html.Div([ - html.Button(id="btn", children="Error", n_clicks=0), - html.Div(id="output") - ]) + app.layout = html.Div( + [html.Button(id="btn", children="Error", n_clicks=0), html.Div(id="output")] + ) @app.callback(Output("output", "children"), Input("btn", "n_clicks")) def error_callback(n): @@ -178,28 +175,35 @@ def error_callback(n): assert "ZeroDivisionError" in error0 assert "dash/backend" not in error0 and error_msg not in error0 + @pytest.mark.parametrize( "backend,fixture,input_value", [ ("fastapi", "dash_duo", "Background FastAPI!"), ("quart", "dash_duo_mp", "Background Quart!"), - ] + ], ) def test_backend_background_callback(request, backend, fixture, input_value): dash_duo = request.getfixturevalue(fixture) import diskcache + cache = diskcache.Cache("./cache") from dash.background_callback import DiskcacheManager - background_callback_manager = DiskcacheManager(cache) + background_callback_manager = DiskcacheManager(cache) - app = Dash(__name__, backend=backend, background_callback_manager=background_callback_manager) - app.layout = html.Div([ - dcc.Input(id="input", value=input_value, type="text"), - html.Div(id="output") - ]) + app = Dash( + __name__, + backend=backend, + background_callback_manager=background_callback_manager, + ) + app.layout = html.Div( + [dcc.Input(id="input", value=input_value, type="text"), html.Div(id="output")] + ) - @app.callback(Output("output", "children"), Input("input", "value"), background=True) + @app.callback( + Output("output", "children"), Input("input", "value"), background=True + ) def update_output_bg(value): return f"Background typed: {value}" @@ -207,5 +211,7 @@ def update_output_bg(value): dash_duo.wait_for_text_to_equal("#output", f"Background typed: {input_value}") dash_duo.find_element("#input").clear() dash_duo.find_element("#input").send_keys(f"{backend.title()} BG Test") - dash_duo.wait_for_text_to_equal("#output", f"Background typed: {backend.title()} BG Test") + dash_duo.wait_for_text_to_equal( + "#output", f"Background typed: {backend.title()} BG Test" + ) assert dash_duo.get_logs() == [] From 6a34208f92d20cf3a7283407c1ac68528d5c9d8a Mon Sep 17 00:00:00 2001 From: BSd3v <82055130+BSd3v@users.noreply.github.com> Date: Tue, 16 Sep 2025 16:36:28 -0400 Subject: [PATCH 058/166] adjustment to retest backend --- .github/workflows/testing.yml | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/.github/workflows/testing.yml b/.github/workflows/testing.yml index 068fe777d1..48bfe0c305 100644 --- a/.github/workflows/testing.yml +++ b/.github/workflows/testing.yml @@ -17,6 +17,7 @@ jobs: # This output will be 'true' if files in the 'table_related_paths' list changed, 'false' otherwise. table_paths_changed: ${{ steps.filter.outputs.table_related_paths }} background_cb_changed: ${{ steps.filter.outputs.background_paths }} + backend_cb_changed: ${{ steps.filter.outputs.backend_paths }} steps: - name: Checkout repository uses: actions/checkout@v4 @@ -37,6 +38,9 @@ jobs: - 'tests/background_callback/**' - 'tests/async_tests/**' - 'requirements/**' + backend_paths: + - 'dash/backend/**' + - 'tests/backend/**' build: name: Build Dash Package @@ -276,7 +280,7 @@ jobs: needs: [build, changes_filter] if: | (github.event_name == 'push' && (github.ref == 'refs/heads/master' || github.ref == 'refs/heads/dev')) || - needs.changes_filter.outputs.backend_tests_changed == 'true' + needs.changes_filter.outputs.backend_cb_changed == 'true' timeout-minutes: 30 runs-on: ubuntu-latest strategy: From 1a2b53124b11b16b014ed941d822377659d01a5a Mon Sep 17 00:00:00 2001 From: BSd3v <82055130+BSd3v@users.noreply.github.com> Date: Tue, 16 Sep 2025 16:44:12 -0400 Subject: [PATCH 059/166] adding missing reqs association --- .github/workflows/testing.yml | 2 +- setup.py | 4 +++- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/.github/workflows/testing.yml b/.github/workflows/testing.yml index 48bfe0c305..be5caf4929 100644 --- a/.github/workflows/testing.yml +++ b/.github/workflows/testing.yml @@ -331,7 +331,7 @@ jobs: python -m pip install --upgrade pip wheel python -m pip install "setuptools<78.0.0" python -m pip install "selenium==4.32.0" - find packages -name dash-*.whl -print -exec sh -c 'pip install "{}[async,ci,testing,dev,celery,diskcache, fastapi, quart]"' \; + find packages -name dash-*.whl -print -exec sh -c 'pip install "{}[async,ci,testing,dev,celery,diskcache,fastapi,quart]"' \; - name: Install Google Chrome run: | diff --git a/setup.py b/setup.py index 7ed781c20d..950bcbe14d 100644 --- a/setup.py +++ b/setup.py @@ -35,7 +35,9 @@ def read_req_file(req_type): "testing": read_req_file("testing"), "celery": read_req_file("celery"), "diskcache": read_req_file("diskcache"), - "compress": read_req_file("compress") + "compress": read_req_file("compress"), + "fastapi": read_req_file("fastapi"), + "quart": read_req_file("quart"), }, entry_points={ "console_scripts": [ From 465e45e469324a25498f32fc5979ba190205f328 Mon Sep 17 00:00:00 2001 From: BSd3v <82055130+BSd3v@users.noreply.github.com> Date: Tue, 16 Sep 2025 16:59:36 -0400 Subject: [PATCH 060/166] fixing minor linting issues --- dash/backend/fastapi.py | 27 +++++++++++---------------- dash/backend/flask.py | 11 +++++------ dash/backend/quart.py | 7 +++---- dash/dash.py | 1 - 4 files changed, 19 insertions(+), 27 deletions(-) diff --git a/dash/backend/fastapi.py b/dash/backend/fastapi.py index a76a5a47ec..8c402cb187 100644 --- a/dash/backend/fastapi.py +++ b/dash/backend/fastapi.py @@ -7,11 +7,12 @@ import importlib.util import time import traceback +import re try: import uvicorn from fastapi import FastAPI, Request, Response - from fastapi.responses import JSONResponse, PlainTextResponse + from fastapi.responses import JSONResponse from fastapi.staticfiles import StaticFiles from starlette.responses import Response as StarletteResponse from starlette.datastructures import MutableHeaders @@ -23,7 +24,6 @@ Request = None Response = None JSONResponse = None - PlainTextResponse = None StaticFiles = None StarletteResponse = None MutableHeaders = None @@ -31,20 +31,17 @@ Any = None Optional = None + +import json +import os from dash.fingerprint import check_fingerprint from dash import _validate from dash.exceptions import ( PreventUpdate, - InvalidResourceError, - InvalidCallbackReturnValue, - BackgroundCallbackError, ) from dash.backend import set_request_adapter from .base_server import BaseDashServer -import json -import os - CONFIG_PATH = "dash_config.json" @@ -93,10 +90,8 @@ def register_assets_blueprint( def register_error_handlers(self, app): self.error_handling_mode = "prune" - # FastAPI uses exception handlers, but we will handle errors in middleware - pass - def _get_traceback(self, secret, error: Exception): + def _get_traceback(self, _secret, error: Exception): tb = error.__traceback__ errors = traceback.format_exception(type(error), error, tb) pass_errs = [] @@ -113,15 +108,13 @@ def _get_traceback(self, secret, error: Exception): error_msg = str(error) # Parse traceback lines to group by file - import re - file_cards = [] pattern = re.compile(r' File "(.+)", line (\d+), in (\w+)') lines = formatted_tb.split("\n") current_file = None card_lines = [] - for i, line in enumerate(lines[:-1]): # Skip the last line (error message) + for line in lines[:-1]: # Skip the last line (error message) match = pattern.match(line) if match: if current_file and card_lines: @@ -274,7 +267,9 @@ def run(self, dash_app, app, host, port, debug, **kwargs): frame = inspect.stack()[2] config = dict( {"debug": debug} if debug else {}, - **{f"dev_tools_{k}": v for k, v in dash_app._dev_tools.items()}, + **{ + f"dev_tools_{k}": v for k, v in dash_app._dev_tools.items() + }, # pylint: disable=protected-access ) save_config(config) if debug: @@ -307,7 +302,7 @@ def jsonify(self, obj): def get_request_adapter(self): return FastAPIRequestAdapter - def _make_before_middleware(self, func): + def _make_before_middleware(self, _func): async def middleware(request, call_next): try: response = await call_next(request) diff --git a/dash/backend/flask.py b/dash/backend/flask.py index 542da93129..cf544ef5bc 100644 --- a/dash/backend/flask.py +++ b/dash/backend/flask.py @@ -5,13 +5,14 @@ import mimetypes import time import inspect +import traceback import flask from dash.fingerprint import check_fingerprint from dash import _validate +from dash._callback import _invoke_callback, _async_invoke_callback from dash.exceptions import PreventUpdate, InvalidResourceError from dash.backend import set_request_adapter from .base_server import BaseDashServer -import traceback class FlaskDashServer(BaseDashServer): @@ -47,13 +48,13 @@ def _invalid_resources_handler(err): def _get_traceback(self, secret, error: Exception): try: - from werkzeug.debug import tbtools + from werkzeug.debug import ( + tbtools, + ) # pylint: disable=import-outside-toplevel except ImportError: tbtools = None def _get_skip(error): - from dash._callback import _invoke_callback, _async_invoke_callback - tb = error.__traceback__ skip = 1 while tb.tb_next is not None: @@ -67,8 +68,6 @@ def _get_skip(error): return skip def _do_skip(error): - from dash._callback import _invoke_callback, _async_invoke_callback - tb = error.__traceback__ while tb.tb_next is not None: if tb.tb_frame.f_code in [ diff --git a/dash/backend/quart.py b/dash/backend/quart.py index 71a2053a61..830d7dd3b9 100644 --- a/dash/backend/quart.py +++ b/dash/backend/quart.py @@ -5,6 +5,7 @@ import time from contextvars import copy_context import traceback +import re try: import quart @@ -56,7 +57,7 @@ def register_assets_blueprint( ) app.register_blueprint(bp) - def _get_traceback(self, secret, error: Exception): + def _get_traceback(self, _secret, error: Exception): tb = error.__traceback__ errors = traceback.format_exception(type(error), error, tb) pass_errs = [] @@ -73,15 +74,13 @@ def _get_traceback(self, secret, error: Exception): error_msg = str(error) # Parse traceback lines to group by file - import re - file_cards = [] pattern = re.compile(r' File "(.+)", line (\d+), in (\w+)') lines = formatted_tb.split("\n") current_file = None card_lines = [] - for i, line in enumerate(lines[:-1]): # Skip the last line (error message) + for line in lines[:-1]: # Skip the last line (error message) match = pattern.match(line) if match: if current_file and card_lines: diff --git a/dash/dash.py b/dash/dash.py index 994453f4a2..6bba3aadfd 100644 --- a/dash/dash.py +++ b/dash/dash.py @@ -14,7 +14,6 @@ import mimetypes import hashlib import base64 -import traceback from urllib.parse import urlparse from typing import Any, Callable, Dict, Optional, Union, Sequence, Literal, List From c43a5835d78dff075e80e8df420f53aa9c37e18c Mon Sep 17 00:00:00 2001 From: chgiesse <83552131+chgiesse@users.noreply.github.com> Date: Wed, 17 Sep 2025 16:16:39 +0200 Subject: [PATCH 061/166] Add global Request Adapter (#6) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * ∙ - remove contextvar from flask and quart only FastApi now relies on that ∙ - backend __init__ now holds the global request adapter and backend which get set on app initialisation ∙ request adapter and server can now be call from everywhere after the app initialised ∙ - added normal top level imports because the modules get matching loaded - but bad Import Error message when quart or equivilent are not installed ∙ - added _ as prefix to backends to avoid importing errors with their underlying ∙ - Can now move to remove unnecessary passing of the server object ∙ * Moved get_server_type to backends * ∙ moved async validation to validation ∙ replaced request.get_path with request.path ∙ * Moved custom backend check to _validation.py * Removed server injection of server methods - they use self.server now * removed use_async from dispatch server methods and use dash_app._use_async removed remaining set request process from flask * adding custom error handling per backend, tests and adjustments to the flow. Made endpoints for downloading the reqs * adjusments for formatting * adjustment to retest backend * Added Dash app as type to servers * adding missing reqs association * Addedd basic typing to servers * fixing minor linting issues * Fixed weird AI shit * Cleanup before heavy pull * Merged latest changes * f rebase * f rebase * Added Dash app as type to servers * Addedd basic typing to servers --------- Co-authored-by: Christian Giessel Co-authored-by: BSd3v <82055130+BSd3v@users.noreply.github.com> --- dash/_callback.py | 41 +-- dash/_pages.py | 32 ++- dash/_validate.py | 39 +++ dash/backend/__init__.py | 15 - dash/backend/base_server.py | 58 ---- dash/backend/registry.py | 29 -- dash/backends/__init__.py | 88 ++++++ .../fastapi.py => backends/_fastapi.py} | 264 +++++++++++------- dash/{backend/flask.py => backends/_flask.py} | 264 ++++++++++-------- dash/{backend/quart.py => backends/_quart.py} | 240 +++++++++------- dash/backends/base_server.py | 119 ++++++++ dash/dash.py | 210 +++++--------- dash_config.json | 1 + quart_app.py | 23 ++ 14 files changed, 831 insertions(+), 592 deletions(-) delete mode 100644 dash/backend/__init__.py delete mode 100644 dash/backend/base_server.py delete mode 100644 dash/backend/registry.py create mode 100644 dash/backends/__init__.py rename dash/{backend/fastapi.py => backends/_fastapi.py} (72%) rename dash/{backend/flask.py => backends/_flask.py} (55%) rename dash/{backend/quart.py => backends/_quart.py} (68%) create mode 100644 dash/backends/base_server.py create mode 100644 dash_config.json create mode 100644 quart_app.py diff --git a/dash/_callback.py b/dash/_callback.py index 6cc55b9162..4a714caeac 100644 --- a/dash/_callback.py +++ b/dash/_callback.py @@ -1,12 +1,8 @@ +from typing import Callable, Optional, Any, List, Tuple, Union +from functools import wraps import collections import hashlib -from functools import wraps - -from typing import Callable, Optional, Any, List, Tuple, Union - - import asyncio -from dash.backend import get_request_adapter from .dependencies import ( handle_callback_args, @@ -39,10 +35,11 @@ clean_property_name, ) -from . import _validate from .background_callback.managers import BaseBackgroundCallbackManager from ._callback_context import context_value from ._no_update import NoUpdate +from . import _validate +from . import backends async def _async_invoke_callback( @@ -176,7 +173,6 @@ def callback( Note that the endpoint will not appear in the list of registered callbacks in the Dash devtools. """ - background_spec = None config_prevent_initial_callbacks = _kwargs.pop( @@ -376,7 +372,8 @@ def _get_callback_manager( " and store results on redis.\n" ) - old_job = get_request_adapter().get_args().getlist("oldJob") + adapter = backends.request_adapter() + old_job = adapter.args.getlist("oldJob") if hasattr(adapter.args, "getlist") else [] if old_job: for job in old_job: @@ -390,6 +387,8 @@ def _setup_background_callback( ): """Set up the background callback and manage jobs.""" callback_manager = _get_callback_manager(kwargs, background) + if not callback_manager: + return to_json({"error": "No background callback manager configured"}) progress_outputs = background.get("progress") @@ -397,14 +396,11 @@ def _setup_background_callback( cache_key = callback_manager.build_cache_key( func, - # Inputs provided as dict is kwargs. func_args if func_args else func_kwargs, background.get("cache_args_to_ignore", []), None if cache_ignore_triggered else callback_ctx.get("triggered_inputs", []), ) - job_fn = callback_manager.func_registry.get(background_key) - ctx_value = AttributeDict(**context_value.get()) ctx_value.ignore_register_page = True ctx_value.pop("background_callback_manager") @@ -436,7 +432,8 @@ def _setup_background_callback( def _progress_background_callback(response, callback_manager, background): progress_outputs = background.get("progress") - cache_key = get_request_adapter().get_args().get("cacheKey") + adapter = backends.request_adapter() + cache_key = adapter.args.get("cacheKey") if progress_outputs: # Get the progress before the result as it would be erased after the results. @@ -453,8 +450,9 @@ def _update_background_callback( """Set up the background callback and manage jobs.""" callback_manager = _get_callback_manager(kwargs, background) - cache_key = get_request_adapter().get_args().get("cacheKey") - job_id = get_request_adapter().get_args().get("job") + adapter = backends.request_adapter() + cache_key = adapter.args.get("cacheKey") if adapter else None + job_id = adapter.args.get("job") if adapter else None _progress_background_callback(response, callback_manager, background) @@ -474,8 +472,9 @@ def _handle_rest_background_callback( multi, has_update=False, ): - cache_key = get_request_adapter().get_args().get("cacheKey") - job_id = get_request_adapter().get_args().get("job") + adapter = backends.request_adapter() + cache_key = adapter.args.get("cacheKey") if adapter else None + job_id = adapter.args.get("job") if adapter else None # Must get job_running after get_result since get_results terminates it. job_running = callback_manager.job_running(job_id) if not job_running and output_value is callback_manager.UNDEFINED: @@ -688,10 +687,11 @@ def add_context(*args, **kwargs): ) response: dict = {"multi": True} - jsonResponse = None + jsonResponse: Optional[str] = None try: if background is not None: - if not get_request_adapter().get_args().get("cacheKey"): + adapter = backends.request_adapter() + if not (adapter and adapter.args.get("cacheKey")): return _setup_background_callback( kwargs, background, @@ -762,7 +762,8 @@ async def async_add_context(*args, **kwargs): try: if background is not None: - if not get_request_adapter().get_args().get("cacheKey"): + adapter = backends.request_adapter() + if not (adapter and adapter.args.get("cacheKey")): return _setup_background_callback( kwargs, background, diff --git a/dash/_pages.py b/dash/_pages.py index acb26e8791..19a797bcf2 100644 --- a/dash/_pages.py +++ b/dash/_pages.py @@ -318,18 +318,22 @@ def register_page( ) page.update( supplied_title=title, - title=title - if title is not None - else CONFIG.title - if CONFIG.title != "Dash" - else page["name"], + title=( + title + if title is not None + else CONFIG.title + if CONFIG.title != "Dash" + else page["name"] + ), ) page.update( - description=description - if description - else CONFIG.description - if CONFIG.description - else "", + description=( + description + if description + else CONFIG.description + if CONFIG.description + else "" + ), order=order, supplied_order=order, supplied_layout=layout, @@ -390,15 +394,13 @@ def _path_to_page(path_id): def _page_meta_tags(app, request): - request_path = request.get_path() + request_path = request.path start_page, path_variables = _path_to_page(request_path.strip("/")) image = start_page.get("image", "") if image: image = app.get_asset_url(image) - assets_image_url = ( - "".join([request.get_root(), image.lstrip("/")]) if image else None - ) + assets_image_url = "".join([request.root, image.lstrip("/")]) if image else None supplied_image_url = start_page.get("image_url") image_url = supplied_image_url if supplied_image_url else assets_image_url @@ -413,7 +415,7 @@ def _page_meta_tags(app, request): return [ {"name": "description", "content": description}, {"property": "twitter:card", "content": "summary_large_image"}, - {"property": "twitter:url", "content": request.get_url()}, + {"property": "twitter:url", "content": request.url}, {"property": "twitter:title", "content": title}, {"property": "twitter:description", "content": description}, {"property": "twitter:image", "content": image_url or ""}, diff --git a/dash/_validate.py b/dash/_validate.py index dea19d64c2..76661cef6b 100644 --- a/dash/_validate.py +++ b/dash/_validate.py @@ -8,6 +8,7 @@ from ._grouping import grouping_len, map_grouping from ._no_update import NoUpdate from .development.base_component import Component +from . import backends from . import exceptions from ._utils import ( patch_collections_abc, @@ -585,3 +586,41 @@ def _valid(out): return _valid(output) + + +def check_async(use_async): + if use_async is None: + try: + import asgiref # pylint: disable=unused-import, import-outside-toplevel # noqa + + use_async = True + except ImportError: + pass + elif use_async: + try: + import asgiref # pylint: disable=unused-import, import-outside-toplevel # noqa + except ImportError as exc: + raise Exception( + "You are trying to use dash[async] without having installed the requirements please install via: `pip install dash[async]`" + ) from exc + + +def check_backend(backend, inferred_backend): + if backend is not None: + if isinstance(backend, type): + # get_backend returns the backend class for a string + # So we compare the class names + expected_backend_cls, _ = backends.get_backend(inferred_backend) + if ( + backend.__module__ != expected_backend_cls.__module__ + or backend.__name__ != expected_backend_cls.__name__ + ): + raise ValueError( + f"Conflict between provided backend '{backend.__name__}' and server type '{inferred_backend}'." + ) + elif not isinstance(backend, str): + raise ValueError("Invalid backend argument") + elif backend.lower() != inferred_backend: + raise ValueError( + f"Conflict between provided backend '{backend}' and server type '{inferred_backend}'." + ) diff --git a/dash/backend/__init__.py b/dash/backend/__init__.py deleted file mode 100644 index eb1d47bc3f..0000000000 --- a/dash/backend/__init__.py +++ /dev/null @@ -1,15 +0,0 @@ -# python -import contextvars -from .registry import get_backend # pylint: disable=unused-import - -__all__ = ["set_request_adapter", "get_request_adapter", "get_backend"] - -_request_adapter_var = contextvars.ContextVar("request_adapter") - - -def set_request_adapter(adapter): - _request_adapter_var.set(adapter) - - -def get_request_adapter(): - return _request_adapter_var.get() diff --git a/dash/backend/base_server.py b/dash/backend/base_server.py deleted file mode 100644 index 4855f86ad6..0000000000 --- a/dash/backend/base_server.py +++ /dev/null @@ -1,58 +0,0 @@ -from abc import ABC, abstractmethod -from typing import Any - - -class BaseDashServer(ABC): - def __call__(self, server, *args, **kwargs) -> Any: - # Default: WSGI - return server(*args, **kwargs) - - @abstractmethod - def create_app( - self, name: str = "__main__", config=None - ) -> Any: # pragma: no cover - interface - pass - - @abstractmethod - def register_assets_blueprint( - self, app, blueprint_name: str, assets_url_path: str, assets_folder: str - ) -> None: # pragma: no cover - interface - pass - - @abstractmethod - def register_error_handlers(self, app) -> None: # pragma: no cover - interface - pass - - @abstractmethod - def add_url_rule( - self, app, rule: str, view_func, endpoint=None, methods=None - ) -> None: # pragma: no cover - interface - pass - - @abstractmethod - def before_request(self, app, func) -> None: # pragma: no cover - interface - pass - - @abstractmethod - def after_request(self, app, func) -> None: # pragma: no cover - interface - pass - - @abstractmethod - def run( - self, app, host: str, port: int, debug: bool, **kwargs - ) -> None: # pragma: no cover - interface - pass - - @abstractmethod - def make_response( - self, data, mimetype=None, content_type=None - ) -> Any: # pragma: no cover - interface - pass - - @abstractmethod - def jsonify(self, obj) -> Any: # pragma: no cover - interface - pass - - @abstractmethod - def get_request_adapter(self) -> Any: # pragma: no cover - interface - pass diff --git a/dash/backend/registry.py b/dash/backend/registry.py deleted file mode 100644 index 4aae9fafc5..0000000000 --- a/dash/backend/registry.py +++ /dev/null @@ -1,29 +0,0 @@ -import importlib - -_backend_imports = { - "flask": ("dash.backend.flask", "FlaskDashServer"), - "fastapi": ("dash.backend.fastapi", "FastAPIDashServer"), - "quart": ("dash.backend.quart", "QuartDashServer"), -} - - -def register_backend(name, module_path, class_name): - """Register a new backend by name.""" - _backend_imports[name.lower()] = (module_path, class_name) - - -def get_backend(name): - try: - module_name, class_name = _backend_imports[name.lower()] - module = importlib.import_module(module_name) - return getattr(module, class_name) - except KeyError as e: - raise ValueError(f"Unknown backend: {name}") from e - except ImportError as e: - raise ImportError( - f"Could not import module '{module_name}' for backend '{name}': {e}" - ) from e - except AttributeError as e: - raise AttributeError( - f"Module '{module_name}' does not have class '{class_name}' for backend '{name}': {e}" - ) from e diff --git a/dash/backends/__init__.py b/dash/backends/__init__.py new file mode 100644 index 0000000000..940c8f18bd --- /dev/null +++ b/dash/backends/__init__.py @@ -0,0 +1,88 @@ +from .base_server import BaseDashServer, RequestAdapter + +from typing import Literal, Any +import importlib + + +request_adapter: RequestAdapter +backend: BaseDashServer + + +_backend_imports = { + "flask": ("dash.backends._flask", "FlaskDashServer", "FlaskRequestAdapter"), + "fastapi": ("dash.backends._fastapi", "FastAPIDashServer", "FastAPIRequestAdapter"), + "quart": ("dash.backends._quart", "QuartDashServer", "QuartRequestAdapter"), +} + + +request_adapter: RequestAdapter +backend: BaseDashServer + + +def get_backend( + name: Literal["flask", "fastapi", "quart"] | str +) -> tuple[BaseDashServer, RequestAdapter]: + module_name, server_class, request_class = _backend_imports[name.lower()] + try: + module = importlib.import_module(module_name) + server = getattr(module, server_class) + request_adapter = getattr(module, request_class) + return server, request_adapter + except KeyError as e: + raise ValueError(f"Unknown backend: {name}") from e + except ImportError as e: + raise ImportError( + f"Could not import module '{module_name}' for backend '{name}': {e}" + ) from e + except AttributeError as e: + raise AttributeError( + f"Module '{module_name}' does not have class '{server_class}' for backend '{name}': {e}" + ) from e + + +def _is_flask_instance(obj): + try: + # pylint: disable=import-outside-toplevel + from flask import Flask + + return isinstance(obj, Flask) + except ImportError: + return False + + +def _is_fastapi_instance(obj): + try: + # pylint: disable=import-outside-toplevel + from fastapi import FastAPI + + return isinstance(obj, FastAPI) + except ImportError: + return False + + +def _is_quart_instance(obj): + try: + # pylint: disable=import-outside-toplevel + from quart import Quart + + return isinstance(obj, Quart) + except ImportError: + return False + + +def get_server_type(server): + if _is_flask_instance(server): + return "flask" + if _is_quart_instance(server): + return "quart" + if _is_fastapi_instance(server): + return "fastapi" + raise ValueError("Invalid backend argument") + + +__all__ = [ + "get_backend", + "request_adapter", + "backend", + "get_server_type", +] diff --git a/dash/backend/fastapi.py b/dash/backends/_fastapi.py similarity index 72% rename from dash/backend/fastapi.py rename to dash/backends/_fastapi.py index 8c402cb187..f3f9f2df33 100644 --- a/dash/backend/fastapi.py +++ b/dash/backends/_fastapi.py @@ -1,46 +1,71 @@ +from __future__ import annotations + +from contextvars import copy_context, ContextVar +from typing import TYPE_CHECKING, Any, Callable, Dict import sys import mimetypes import hashlib import inspect import pkgutil -from contextvars import copy_context -import importlib.util import time import traceback -import re - -try: - import uvicorn - from fastapi import FastAPI, Request, Response - from fastapi.responses import JSONResponse - from fastapi.staticfiles import StaticFiles - from starlette.responses import Response as StarletteResponse - from starlette.datastructures import MutableHeaders - from pydantic import create_model - from typing import Any, Optional -except ImportError: - uvicorn = None - FastAPI = None - Request = None - Response = None - JSONResponse = None - StaticFiles = None - StarletteResponse = None - MutableHeaders = None - create_model = None - Any = None - Optional = None - - +from importlib.util import spec_from_file_location import json import os +import re + from dash.fingerprint import check_fingerprint from dash import _validate -from dash.exceptions import ( - PreventUpdate, -) -from dash.backend import set_request_adapter -from .base_server import BaseDashServer +from dash.exceptions import PreventUpdate +from .base_server import BaseDashServer, RequestAdapter + +from fastapi import FastAPI, Request, Response, Body +from fastapi.responses import JSONResponse +from fastapi.staticfiles import StaticFiles +from starlette.responses import Response as StarletteResponse +from starlette.datastructures import MutableHeaders +from starlette.types import ASGIApp, Scope, Receive, Send +import uvicorn + +if TYPE_CHECKING: # pragma: no cover - typing only + from dash.dash import Dash + + +_current_request_var = ContextVar("dash_current_request", default=None) + + +def set_current_request(req): + return _current_request_var.set(req) + + +def reset_current_request(token): + _current_request_var.reset(token) + + +def get_current_request() -> Request: + req = _current_request_var.get() + if req is None: + raise RuntimeError("No active request in context") + return req + + +class CurrentRequestMiddleware: + def __init__(self, app: ASGIApp) -> None: # type: ignore[name-defined] + self.app = app + + async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: # type: ignore[name-defined] + # non-http/ws scopes pass through (lifespan etc.) + if scope["type"] not in ("http", "websocket"): + await self.app(scope, receive, send) + return + + request = Request(scope, receive=receive) + token = set_current_request(request) + try: + await self.app(scope, receive, send) + finally: + reset_current_request(token) + CONFIG_PATH = "dash_config.json" @@ -58,28 +83,35 @@ def load_config(): class FastAPIDashServer(BaseDashServer): - def __init__(self): + + def __init__(self, server: FastAPI): + self.config = {} + self.server_type = "fastapi" + self.server: FastAPI = server self.error_handling_mode = "prune" super().__init__() - def __call__(self, server, *args, **kwargs): + def __call__(self, *args: Any, **kwargs: Any): # ASGI: (scope, receive, send) if len(args) == 3 and isinstance(args[0], dict) and "type" in args[0]: - return server(*args, **kwargs) + return self.server(*args, **kwargs) raise TypeError("FastAPI app must be called with (scope, receive, send)") - def create_app(self, name="__main__", config=None): + @staticmethod + def create_app(name: str = "__main__", config: Dict[str, Any] | None = None): app = FastAPI() + app.add_middleware(CurrentRequestMiddleware) + if config: for key, value in config.items(): setattr(app.state, key, value) return app def register_assets_blueprint( - self, app, blueprint_name, assets_url_path, assets_folder + self, blueprint_name: str, assets_url_path: str, assets_folder: str ): try: - app.mount( + self.server.mount( assets_url_path, StaticFiles(directory=assets_folder), name=blueprint_name, @@ -88,7 +120,7 @@ def register_assets_blueprint( # directory doesnt exist pass - def register_error_handlers(self, app): + def register_error_handlers(self): self.error_handling_mode = "prune" def _get_traceback(self, _secret, error: Exception): @@ -200,13 +232,13 @@ def _get_traceback(self, _secret, error: Exception): """ return html - def register_prune_error_handler(self, _app, _secret, prune_errors): + def register_prune_error_handler(self, _secret, prune_errors): if prune_errors: self.error_handling_mode = "prune" else: self.error_handling_mode = "raise" - def _html_response_wrapper(self, view_func): + def _html_response_wrapper(self, view_func: Callable[..., Any] | str): async def wrapped(*_args, **_kwargs): # If view_func is a function, call it; if it's a string, use it directly html = view_func() if callable(view_func) else view_func @@ -214,40 +246,40 @@ async def wrapped(*_args, **_kwargs): return wrapped - def setup_index(self, dash_app): + def setup_index(self, dash_app: Dash): async def index(request: Request): - adapter = FastAPIRequestAdapter() - set_request_adapter(adapter) - adapter.set_request(request) return Response(content=dash_app.index(), media_type="text/html") # pylint: disable=protected-access dash_app._add_url("", index, methods=["GET"]) - def setup_catchall(self, dash_app): - @dash_app.server.on_event("startup") + def setup_catchall(self, dash_app: Dash): + @self.server.on_event("startup") def _setup_catchall(): - config = load_config() - dash_app.enable_dev_tools(**config, first_run=False) + dash_app.enable_dev_tools( + **self.config, first_run=False + ) # do this to make sure dev tools are enabled async def catchall(request: Request): - adapter = FastAPIRequestAdapter() - set_request_adapter(adapter) - adapter.set_request(request) return Response(content=dash_app.index(), media_type="text/html") # pylint: disable=protected-access dash_app._add_url("{path:path}", catchall, methods=["GET"]) def add_url_rule( - self, app, rule, view_func, endpoint=None, methods=None, include_in_schema=False + self, + rule: str, + view_func: Callable[..., Any] | str, + endpoint: str | None = None, + methods: list[str] | None = None, + include_in_schema: bool = False, ): if rule == "": rule = "/" if isinstance(view_func, str): # Wrap string or sync function to async FastAPI handler view_func = self._html_response_wrapper(view_func) - app.add_api_route( + self.server.add_api_route( rule, view_func, methods=methods or ["GET"], @@ -255,15 +287,15 @@ def add_url_rule( include_in_schema=include_in_schema, ) - def before_request(self, app, func): + def before_request(self, func: Callable[[], Any] | None): # FastAPI does not have before_request, but we can use middleware - app.middleware("http")(self._make_before_middleware(func)) + self.server.middleware("http")(self._make_before_middleware(func)) - def after_request(self, app, func): + def after_request(self, func: Callable[[], Any] | None): # FastAPI does not have after_request, but we can use middleware - app.middleware("http")(self._make_after_middleware(func)) + self.server.middleware("http")(self._make_after_middleware(func)) - def run(self, dash_app, app, host, port, debug, **kwargs): + def run(self, dash_app: Dash, host, port, debug, **kwargs): frame = inspect.stack()[2] config = dict( {"debug": debug} if debug else {}, @@ -278,7 +310,8 @@ def run(self, dash_app, app, host, port, debug, **kwargs): if kwargs.get("reload"): # Dynamically determine the module name from the file path file_path = frame.filename - module_name = importlib.util.spec_from_file_location("app", file_path).name + spec = spec_from_file_location("app", file_path) + module_name = spec.name if spec and getattr(spec, "name", None) else "app" uvicorn.run( f"{module_name}:app.server", host=host, @@ -286,9 +319,14 @@ def run(self, dash_app, app, host, port, debug, **kwargs): **kwargs, ) else: - uvicorn.run(app, host=host, port=port, **kwargs) + uvicorn.run(self.server, host=host, port=port, **kwargs) - def make_response(self, data, mimetype=None, content_type=None): + def make_response( + self, + data: str | bytes | bytearray, + mimetype: str | None = None, + content_type: str | None = None, + ): headers = {} if mimetype: headers["content-type"] = mimetype @@ -296,13 +334,10 @@ def make_response(self, data, mimetype=None, content_type=None): headers["content-type"] = content_type return Response(content=data, headers=headers) - def jsonify(self, obj): + def jsonify(self, obj: Any): return JSONResponse(content=obj) - def get_request_adapter(self): - return FastAPIRequestAdapter - - def _make_before_middleware(self, _func): + def _make_before_middleware(self, func: Callable[[], Any] | None): async def middleware(request, call_next): try: response = await call_next(request) @@ -322,7 +357,7 @@ async def middleware(request, call_next): return middleware - def _make_after_middleware(self, func): + def _make_after_middleware(self, func: Callable[[], Any] | None): async def middleware(request, call_next): response = await call_next(request) if func is not None: @@ -335,8 +370,13 @@ async def middleware(request, call_next): return middleware def serve_component_suites( - self, dash_app, package_name, fingerprinted_path, request + self, + dash_app: Dash, + package_name: str, + fingerprinted_path: str, + request: Request, ): + path_in_pkg, has_fingerprint = check_fingerprint(fingerprinted_path) _validate.validate_js_path(dash_app.registered_paths, package_name, path_in_pkg) extension = "." + path_in_pkg.split(".")[-1] @@ -360,7 +400,7 @@ def serve_component_suites( return StarletteResponse(status_code=304) return StarletteResponse(content=data, media_type=mimetype, headers=headers) - def setup_component_suites(self, dash_app): + def setup_component_suites(self, dash_app: Dash): async def serve(request: Request, package_name: str, fingerprinted_path: str): return self.serve_component_suites( dash_app, package_name, fingerprinted_path, request @@ -373,16 +413,12 @@ async def serve(request: Request, package_name: str, fingerprinted_path: str): ) # pylint: disable=unused-argument - def dispatch(self, app, dash_app, use_async=False): + def dispatch(self, dash_app: Dash): + async def _dispatch(request: Request): - adapter = FastAPIRequestAdapter() - set_request_adapter(adapter) - adapter.set_request(request) # pylint: disable=protected-access body = await request.json() - g = dash_app._initialize_context( - body, adapter - ) # pylint: disable=protected-access + g = dash_app._initialize_context(body) # pylint: disable=protected-access func = dash_app._prepare_callback( g, body ) # pylint: disable=protected-access @@ -406,12 +442,12 @@ def _serve_default_favicon(self): content=pkgutil.get_data("dash", "favicon.ico"), media_type="image/x-icon" ) - def register_timing_hooks(self, app, first_run): + def register_timing_hooks(self, first_run: bool): if not first_run: return - @app.middleware("http") - async def timing_middleware(request, call_next): + @self.server.middleware("http") + async def timing_middleware(request: Request, call_next): # Before request request.state.timing_information = { "__dash_server": {"dur": time.time(), "desc": None} @@ -433,11 +469,11 @@ async def timing_middleware(request, call_next): headers.append("Server-Timing", value) return response - def register_callback_api_routes(self, app, callback_api_paths): + def register_callback_api_routes(self, callback_api_paths: Dict[str, Callable[..., Any]]): """ Register callback API endpoints on the FastAPI app. Each key in callback_api_paths is a route, each value is a handler (sync or async). - Dynamically creates a Pydantic model for the handler's parameters and uses it as the body parameter. + Accepts a JSON body (dict) and filters keys based on the handler's signature. """ for path, handler in callback_api_paths.items(): endpoint = f"dash_callback_api_{path}" @@ -445,21 +481,19 @@ def register_callback_api_routes(self, app, callback_api_paths): methods = ["POST"] sig = inspect.signature(handler) param_names = list(sig.parameters.keys()) - fields = {name: (Optional[Any], None) for name in param_names} - Model = create_model( - f"Payload_{endpoint}", **fields - ) # pylint: disable=cell-var-from-loop - - # pylint: disable=cell-var-from-loop - async def view_func(request: Request, body: Model): - kwargs = body.dict(exclude_unset=True) + + async def view_func(request: Request, body: dict = Body(...)): + # Only pass expected params; ignore extras + kwargs = { + k: v for k, v in body.items() if k in param_names and v is not None + } if inspect.iscoroutinefunction(handler): result = await handler(**kwargs) else: result = handler(**kwargs) return JSONResponse(content=result) - app.add_api_route( + self.server.add_api_route( route, view_func, methods=methods, @@ -468,44 +502,58 @@ async def view_func(request: Request, body: Model): ) -class FastAPIRequestAdapter: +class FastAPIRequestAdapter(RequestAdapter): + def __init__(self): - self._request = None + self._request: Request = get_current_request() + super().__init__() - def set_request(self, request: Request): - self._request = request + def __call__(self): + self._request = get_current_request() + return self - def get_root(self): + @property + def root(self): return str(self._request.base_url) - def get_args(self): + @property + def args(self): return self._request.query_params - async def get_json(self): - return await self._request.json() - + @property def is_json(self): return self._request.headers.get("content-type", "").startswith( "application/json" ) - def get_cookies(self, _request=None): + @property + def cookies(self): return self._request.cookies - def get_headers(self): + @property + def headers(self): return self._request.headers - def get_full_path(self): + @property + def full_path(self): return str(self._request.url) - def get_url(self): + @property + def url(self): return str(self._request.url) - def get_remote_addr(self): - return self._request.client.host if self._request.client else None + @property + def remote_addr(self): + client = getattr(self._request, "client", None) + return getattr(client, "host", None) - def get_origin(self): + @property + def origin(self): return self._request.headers.get("origin") - def get_path(self): + @property + def path(self): return self._request.url.path + + async def get_json(self): # async method retained + return await self._request.json() diff --git a/dash/backend/flask.py b/dash/backends/_flask.py similarity index 55% rename from dash/backend/flask.py rename to dash/backends/_flask.py index cf544ef5bc..5a1385d574 100644 --- a/dash/backend/flask.py +++ b/dash/backends/_flask.py @@ -1,4 +1,7 @@ +from __future__ import annotations + from contextvars import copy_context +from typing import TYPE_CHECKING, Any, Callable, Dict import asyncio import pkgutil import sys @@ -6,43 +9,60 @@ import time import inspect import traceback -import flask +from flask import ( + Flask, + Blueprint, + Response, + request, + jsonify, + g as flask_g, +) + from dash.fingerprint import check_fingerprint from dash import _validate -from dash._callback import _invoke_callback, _async_invoke_callback from dash.exceptions import PreventUpdate, InvalidResourceError -from dash.backend import set_request_adapter -from .base_server import BaseDashServer +from dash._callback import _invoke_callback, _async_invoke_callback +from .base_server import BaseDashServer, RequestAdapter + +if TYPE_CHECKING: # pragma: no cover - typing only + from dash import Dash class FlaskDashServer(BaseDashServer): - def __call__(self, server, *args, **kwargs): + + def __init__(self, server: Flask) -> None: + self.server: Flask = server + self.server_type = "flask" + super().__init__() + + def __call__(self, *args: Any, **kwargs: Any): # Always WSGI - return server(*args, **kwargs) + return self.server(*args, **kwargs) - def create_app(self, name="__main__", config=None): - app = flask.Flask(name) + @staticmethod + def create_app(name: str = "__main__", config: Dict[str, Any] | None = None): + app = Flask(name) if config: app.config.update(config) return app def register_assets_blueprint( - self, app, blueprint_name, assets_url_path, assets_folder + self, blueprint_name: str, assets_url_path: str, assets_folder: str ): - bp = flask.Blueprint( + bp = Blueprint( blueprint_name, __name__, static_folder=assets_folder, static_url_path=assets_url_path, ) - app.register_blueprint(bp) + self.server.register_blueprint(bp) - def register_error_handlers(self, app): - @app.errorhandler(PreventUpdate) + def register_error_handlers(self): + @self.server.errorhandler(PreventUpdate) def _handle_error(_): return "", 204 - @app.errorhandler(InvalidResourceError) + @self.server.errorhandler(InvalidResourceError) def _invalid_resources_handler(err): return err.args[0], 404 @@ -86,56 +106,64 @@ def _do_skip(error): ).render_debugger_html(True, secret, True) return "".join(traceback.format_exception(type(error), error, _do_skip(error))) - def register_prune_error_handler(self, app, secret, prune_errors): + def register_prune_error_handler(self, secret, prune_errors): if prune_errors: - @app.errorhandler(Exception) + @self.server.errorhandler(Exception) def _wrap_errors(error): tb = self._get_traceback(secret, error) return tb, 500 - def add_url_rule(self, app, rule, view_func, endpoint=None, methods=None): - app.add_url_rule( + def add_url_rule( + self, + rule: str, + view_func: Callable[..., Any], + endpoint: str | None = None, + methods: list[str] | None = None, + ): + self.server.add_url_rule( rule, view_func=view_func, endpoint=endpoint, methods=methods or ["GET"] ) - def before_request(self, app, func): - app.before_request(func) - - def after_request(self, app, func): - app.after_request(func) + def before_request(self, func: Callable[[], Any]): + # Flask expects a callable; user responsibility not to pass None + self.server.before_request(func) - def run(self, _dash_app, app, host, port, debug, **kwargs): - app.run(host=host, port=port, debug=debug, **kwargs) + def after_request(self, func: Callable[[Any], Any]): + # Flask after_request expects a function(response) -> response + self.server.after_request(func) - def make_response(self, data, mimetype=None, content_type=None): - return flask.Response(data, mimetype=mimetype, content_type=content_type) + def run(self, dash_app: Dash, host: str, port: int, debug: bool, **kwargs: Any): + self.server.run(host=host, port=port, debug=debug, **kwargs) - def jsonify(self, obj): - return flask.jsonify(obj) + def make_response( + self, + data: str | bytes | bytearray, + mimetype: str | None = None, + content_type: str | None = None, + ): + return Response(data, mimetype=mimetype, content_type=content_type) - def get_request_adapter(self): - return FlaskRequestAdapter + def jsonify(self, obj: Any): + return jsonify(obj) - def setup_catchall(self, dash_app): + def setup_catchall(self, dash_app: Dash): def catchall(*args, **kwargs): - adapter = FlaskRequestAdapter() - set_request_adapter(adapter) return dash_app.index(*args, **kwargs) # pylint: disable=protected-access dash_app._add_url("", catchall, methods=["GET"]) - def setup_index(self, dash_app): + def setup_index(self, dash_app: Dash): def index(*args, **kwargs): - adapter = FlaskRequestAdapter() - set_request_adapter(adapter) return dash_app.index(*args, **kwargs) # pylint: disable=protected-access dash_app._add_url("", index, methods=["GET"]) - def serve_component_suites(self, dash_app, package_name, fingerprinted_path): + def serve_component_suites( + self, dash_app: Dash, package_name: str, fingerprinted_path: str + ): path_in_pkg, has_fingerprint = check_fingerprint(fingerprinted_path) _validate.validate_js_path(dash_app.registered_paths, package_name, path_in_pkg) extension = "." + path_in_pkg.split(".")[-1] @@ -149,18 +177,18 @@ def serve_component_suites(self, dash_app, package_name, fingerprinted_path): package.__path__, ) data = pkgutil.get_data(package_name, path_in_pkg) - response = flask.Response(data, mimetype=mimetype) + response = Response(data, mimetype=mimetype) if has_fingerprint: response.cache_control.max_age = 31536000 # 1 year else: response.add_etag() tag = response.get_etag()[0] - request_etag = flask.request.headers.get("If-None-Match") + request_etag = request.headers.get("If-None-Match") if f'"{tag}"' == request_etag: - response = flask.Response(None, status=304) + response = Response(None, status=304) return response - def setup_component_suites(self, dash_app): + def setup_component_suites(self, dash_app: Dash): def serve(package_name, fingerprinted_path): return self.serve_component_suites( dash_app, package_name, fingerprinted_path @@ -173,17 +201,15 @@ def serve(package_name, fingerprinted_path): ) # pylint: disable=unused-argument - def dispatch(self, app, dash_app, use_async=False): + def dispatch(self, dash_app: Dash): def _dispatch(): - adapter = FlaskRequestAdapter() - set_request_adapter(adapter) - body = flask.request.get_json() + body = request.get_json() # pylint: disable=protected-access - g = dash_app._initialize_context(body, adapter) - func = dash_app._prepare_callback(g, body) - args = dash_app._inputs_to_vals(g.inputs_list + g.states_list) + cb_ctx = dash_app._initialize_context(body) + func = dash_app._prepare_callback(cb_ctx, body) + args = dash_app._inputs_to_vals(cb_ctx.inputs_list + cb_ctx.states_list) ctx = copy_context() - partial_func = dash_app._execute_callback(func, args, g.outputs_list, g) + partial_func = dash_app._execute_callback(func, args, cb_ctx.outputs_list, cb_ctx) response_data = ctx.run(partial_func) if asyncio.iscoroutine(response_data): raise Exception( @@ -191,43 +217,41 @@ def _dispatch(): "Please install the dependencies via `pip install dash[async]` and ensure " "that `use_async=False` is not being passed to the app." ) - g.dash_response.set_data(response_data) - return g.dash_response + cb_ctx.dash_response.set_data(response_data) + return cb_ctx.dash_response async def _dispatch_async(): - adapter = FlaskRequestAdapter() - set_request_adapter(adapter) - body = flask.request.get_json() + body = request.get_json() # pylint: disable=protected-access - g = dash_app._initialize_context(body, adapter) - func = dash_app._prepare_callback(g, body) - args = dash_app._inputs_to_vals(g.inputs_list + g.states_list) + cb_ctx = dash_app._initialize_context(body) + func = dash_app._prepare_callback(cb_ctx, body) + args = dash_app._inputs_to_vals(cb_ctx.inputs_list + cb_ctx.states_list) ctx = copy_context() - partial_func = dash_app._execute_callback(func, args, g.outputs_list, g) + partial_func = dash_app._execute_callback(func, args, cb_ctx.outputs_list, cb_ctx) response_data = ctx.run(partial_func) if asyncio.iscoroutine(response_data): response_data = await response_data - g.dash_response.set_data(response_data) - return g.dash_response + cb_ctx.dash_response.set_data(response_data) + return cb_ctx.dash_response - if use_async: + if dash_app._use_async: return _dispatch_async return _dispatch def _serve_default_favicon(self): - - return flask.Response( + return Response( pkgutil.get_data("dash", "favicon.ico"), content_type="image/x-icon" ) - def register_timing_hooks(self, app, _first_run): - def _before_request(): - flask.g.timing_information = { + def register_timing_hooks(self, _first_run: bool): + # Define timing hooks inside method scope and register them + def _before_request() -> None: + flask_g.timing_information = { # type: ignore[attr-defined] "__dash_server": {"dur": time.time(), "desc": None} } - def _after_request(response): - timing_information = flask.g.get("timing_information", None) + def _after_request(response: Response): # type: ignore[name-defined] + timing_information = flask_g.get("timing_information", None) # type: ignore[attr-defined] if timing_information is None: return response dash_total = timing_information.get("__dash_server", None) @@ -242,10 +266,10 @@ def _after_request(response): response.headers.add("Server-Timing", value) return response - self.before_request(app, _before_request) - self.after_request(app, _after_request) + self.before_request(_before_request) + self.after_request(_after_request) - def register_callback_api_routes(self, app, callback_api_paths): + def register_callback_api_routes(self, callback_api_paths: Dict[str, Callable[..., Any]]): """ Register callback API endpoints on the Flask app. Each key in callback_api_paths is a route, each value is a handler (sync or async). @@ -258,65 +282,79 @@ def register_callback_api_routes(self, app, callback_api_paths): if inspect.iscoroutinefunction(handler): - async def view_func(*args, handler=handler, **kwargs): - data = flask.request.get_json() + async def _async_view_func(*args, handler=handler, **kwargs): + data = request.get_json() result = await handler(**data) if data else await handler() - return flask.jsonify(result) + return jsonify(result) + view_func = _async_view_func else: - def view_func(*args, handler=handler, **kwargs): - data = flask.request.get_json() + def _sync_view_func(*args, handler=handler, **kwargs): + data = request.get_json() result = handler(**data) if data else handler() - return flask.jsonify(result) + return jsonify(result) + + view_func = _sync_view_func + + view_func = _sync_view_func # Flask 2.x+ supports async views natively - app.add_url_rule( + self.server.add_url_rule( route, endpoint=endpoint, view_func=view_func, methods=methods ) -class FlaskRequestAdapter: - @staticmethod - def get_args(): - return flask.request.args +class FlaskRequestAdapter(RequestAdapter): + """Flask implementation using property-based accessors.""" - @staticmethod - def get_root(): - return flask.request.url_root + def __init__(self) -> None: + # Store the request LocalProxy so we can reference it consistently + self._request = request + super().__init__() - @staticmethod - def get_json(): - return flask.request.get_json() + def __call__(self, *args: Any, **kwds: Any): + return self - @staticmethod - def is_json(): - return flask.request.is_json + @property + def args(self): + return self._request.args - @staticmethod - def get_cookies(): - return flask.request.cookies + @property + def root(self): + return self._request.url_root - @staticmethod - def get_headers(): - return flask.request.headers + def get_json(self): # kept as method + return self._request.get_json() - @staticmethod - def get_url(): - return flask.request.url + @property + def is_json(self): + return self._request.is_json - @staticmethod - def get_full_path(): - return flask.request.full_path + @property + def cookies(self): + return self._request.cookies - @staticmethod - def get_remote_addr(): - return flask.request.remote_addr + @property + def headers(self): + return self._request.headers - @staticmethod - def get_origin(): - return getattr(flask.request, "origin", None) + @property + def url(self): + return self._request.url - @staticmethod - def get_path(): - return flask.request.path + @property + def full_path(self): + return self._request.full_path + + @property + def remote_addr(self): + return self._request.remote_addr + + @property + def origin(self): + return getattr(self._request, "origin", None) + + @property + def path(self): + return self._request.path diff --git a/dash/backend/quart.py b/dash/backends/_quart.py similarity index 68% rename from dash/backend/quart.py rename to dash/backends/_quart.py index 830d7dd3b9..a462d07af6 100644 --- a/dash/backend/quart.py +++ b/dash/backends/_quart.py @@ -1,61 +1,68 @@ +from __future__ import annotations +from contextvars import copy_context +import typing as _t +import traceback +import mimetypes import inspect import pkgutil -import mimetypes -import sys import time -from contextvars import copy_context -import traceback +import sys import re -try: - import quart - from quart import Quart, Response, jsonify, request, Blueprint -except ImportError: - quart = None - Quart = None - Response = None - jsonify = None - request = None - Blueprint = None +# Attempt top-level Quart imports; allow absence if user not using quart backend +from quart import ( + Quart, + Response, + jsonify, + request, + Blueprint, + g, +) + +if _t.TYPE_CHECKING: + from dash import Dash + from dash.exceptions import PreventUpdate, InvalidResourceError -from dash.backend import set_request_adapter from dash.fingerprint import check_fingerprint from dash import _validate from .base_server import BaseDashServer class QuartDashServer(BaseDashServer): - """Quart implementation of the Dash server factory. - - All Quart/async specific imports are at the top-level (per user request) so - Quart must be installed when this module is imported. - """ - def __init__(self) -> None: + def __init__(self, server: Quart) -> None: + self.server_type = "quart" + self.server: Quart = server self.config = {} self.error_handling_mode = "prune" super().__init__() - def __call__(self, server, *args, **kwargs): - return server(*args, **kwargs) + def __call__(self, *args: Any, **kwargs: Any): # type: ignore[name-defined] + return self.server(*args, **kwargs) - def create_app(self, name="__main__", config=None): - app = Quart(name) + @staticmethod + def create_app(name: str = "__main__", config: _t.Optional[_t.Dict[str, _t.Any]] = None): + if Quart is None: + raise RuntimeError( + "Quart is not installed. Install with 'pip install quart' to use the quart backend." + ) + app = Quart(name) # type: ignore if config: for key, value in config.items(): app.config[key] = value return app def register_assets_blueprint( - self, app, blueprint_name, assets_url_path, assets_folder + self, blueprint_name: str, assets_url_path: str, assets_folder: str # type: ignore[name-defined] ): + bp = Blueprint( blueprint_name, __name__, static_folder=assets_folder, static_url_path=assets_url_path, ) - app.register_blueprint(bp) + self.server.register_blueprint(bp) def _get_traceback(self, _secret, error: Exception): tb = error.__traceback__ @@ -166,27 +173,30 @@ def _get_traceback(self, _secret, error: Exception): """ return html - def register_prune_error_handler(self, app, secret, prune_errors): + def register_prune_error_handler(self, secret, prune_errors): if prune_errors: self.error_handling_mode = "prune" else: self.error_handling_mode = "raise" - @app.errorhandler(Exception) + @self.server.errorhandler(Exception) async def _wrap_errors(error): tb = self._get_traceback(secret, error) return Response(tb, status=500, content_type="text/html") - def register_timing_hooks(self, app, _first_run): # parity with Flask factory - @app.before_request + def register_timing_hooks(self, _first_run: bool): # type: ignore[name-defined] parity with Flask factory + @self.server.before_request async def _before_request(): # pragma: no cover - timing infra - quart.g.timing_information = { - "__dash_server": {"dur": time.time(), "desc": None} - } + if g is not None: + g.timing_information = { # type: ignore[attr-defined] + "__dash_server": {"dur": time.time(), "desc": None} + } - @app.after_request + @self.server.after_request async def _after_request(response): # pragma: no cover - timing infra - timing_information = getattr(quart.g, "timing_information", None) + timing_information = ( + getattr(g, "timing_information", None) if g is not None else None + ) if timing_information is None: return response dash_total = timing_information.get("__dash_server", None) @@ -205,16 +215,17 @@ async def _after_request(response): # pragma: no cover - timing infra response.headers["Server-Timing"] = value return response - def register_error_handlers(self, app): - @app.errorhandler(PreventUpdate) + def register_error_handlers(self): # type: ignore[name-defined] + @self.server.errorhandler(PreventUpdate) async def _prevent_update(_): return "", 204 - @app.errorhandler(InvalidResourceError) + @self.server.errorhandler(InvalidResourceError) async def _invalid_resource(err): return err.args[0], 404 - def _html_response_wrapper(self, view_func): + def _html_response_wrapper(self, view_func: _t.Callable[..., _t.Any] | str): + async def wrapped(*_args, **_kwargs): html_val = view_func() if callable(view_func) else view_func if inspect.iscoroutine(html_val): # handle async function returning html @@ -224,38 +235,40 @@ async def wrapped(*_args, **_kwargs): return wrapped - def add_url_rule(self, app, rule, view_func, endpoint=None, methods=None): - app.add_url_rule( + def add_url_rule( + self, + rule: str, + view_func: _t.Callable[..., _t.Any], + endpoint: str | None = None, + methods: list[str] | None = None, + ): + self.server.add_url_rule( rule, view_func=view_func, endpoint=endpoint, methods=methods or ["GET"] ) - def setup_index(self, dash_app): + def setup_index(self, dash_app: Dash): # type: ignore[name-defined] + async def index(*args, **kwargs): - adapter = QuartRequestAdapter() - set_request_adapter(adapter) - adapter.set_request() - return Response(dash_app.index(*args, **kwargs), content_type="text/html") + return Response(dash_app.index(*args, **kwargs), content_type="text/html") # type: ignore[arg-type] # pylint: disable=protected-access dash_app._add_url("", index, methods=["GET"]) - def setup_catchall(self, dash_app): + def setup_catchall(self, dash_app: Dash): + async def catchall( - path, *args, **kwargs + path: str, *args, **kwargs ): # noqa: ARG001 - path is unused but kept for route signature, pylint: disable=unused-argument - adapter = QuartRequestAdapter() - set_request_adapter(adapter) - adapter.set_request() - return Response(dash_app.index(*args, **kwargs), content_type="text/html") + return Response(dash_app.index(*args, **kwargs), content_type="text/html") # type: ignore[arg-type] # pylint: disable=protected-access dash_app._add_url("", catchall, methods=["GET"]) - def before_request(self, app, func): - app.before_request(func) + def before_request(self, func: _t.Callable[[], _t.Any]): + self.server.before_request(func) - def after_request(self, app, func): - @app.after_request + def after_request(self, func: _t.Callable[[], _t.Any]): + @self.server.after_request async def _after(response): if func is not None: result = func() @@ -263,21 +276,25 @@ async def _after(response): await result return response - def run(self, _dash_app, app, host, port, debug, **kwargs): + def run(self, dash_app: Dash, host: str, port: int, debug: bool, **kwargs: _t.Any): self.config = {"debug": debug, **kwargs} if debug else kwargs - app.run(host=host, port=port, debug=debug, **kwargs) + self.server.run(host=host, port=port, debug=debug, **kwargs) - def make_response(self, data, mimetype=None, content_type=None): + def make_response( + self, + data: str | bytes | bytearray, + mimetype: str | None = None, + content_type: str | None = None, + ): + if Response is None: + raise RuntimeError("Quart not installed; cannot generate Response") return Response(data, mimetype=mimetype, content_type=content_type) def jsonify(self, obj): return jsonify(obj) - def get_request_adapter(self): - return QuartRequestAdapter - def serve_component_suites( - self, dash_app, package_name, fingerprinted_path + self, dash_app: Dash, package_name: str, fingerprinted_path: str ): # noqa: ARG002 unused req preserved for interface parity path_in_pkg, has_fingerprint = check_fingerprint(fingerprinted_path) _validate.validate_js_path(dash_app.registered_paths, package_name, path_in_pkg) @@ -296,9 +313,11 @@ def serve_component_suites( if has_fingerprint: headers["Cache-Control"] = "public, max-age=31536000" + if Response is None: + raise RuntimeError("Quart not installed; cannot generate Response") return Response(data, content_type=mimetype, headers=headers) - def setup_component_suites(self, dash_app): + def setup_component_suites(self, dash_app: Dash): async def serve(package_name, fingerprinted_path): return self.serve_component_suites( dash_app, package_name, fingerprinted_path @@ -311,14 +330,13 @@ async def serve(package_name, fingerprinted_path): ) # pylint: disable=unused-argument - def dispatch(self, app, dash_app, use_async=True): # Quart always async + def dispatch(self, dash_app: Dash): # type: ignore[name-defined] Quart always async + async def _dispatch(): adapter = QuartRequestAdapter() - set_request_adapter(adapter) - adapter.set_request() - body = await request.get_json() + body = await adapter.get_json() # pylint: disable=protected-access - g = dash_app._initialize_context(body, adapter) + g = dash_app._initialize_context(body) # pylint: disable=protected-access func = dash_app._prepare_callback(g, body) # pylint: disable=protected-access @@ -329,11 +347,11 @@ async def _dispatch(): response_data = ctx.run(partial_func) if inspect.iscoroutine(response_data): # if user callback is async response_data = await response_data - return Response(response_data, content_type="application/json") + return Response(response_data, content_type="application/json") # type: ignore[arg-type] return _dispatch - def register_callback_api_routes(self, app, callback_api_paths): + def register_callback_api_routes(self, callback_api_paths: _t.Dict[str, _t.Callable[..., _t.Any]]): """ Register callback API endpoints on the Quart app. Each key in callback_api_paths is a route, each value is a handler (sync or async). @@ -348,25 +366,33 @@ def _make_view_func(handler): if inspect.iscoroutinefunction(handler): async def async_view_func(*args, **kwargs): + if request is None: + raise RuntimeError( + "Quart not installed; request unavailable" + ) data = await request.get_json() result = await handler(**data) if data else await handler() - return jsonify(result) + return jsonify(result) # type: ignore[arg-type] return async_view_func async def sync_view_func(*args, **kwargs): + if request is None: + raise RuntimeError("Quart not installed; request unavailable") data = await request.get_json() result = handler(**data) if data else handler() - return jsonify(result) + return jsonify(result) # type: ignore[arg-type] return sync_view_func view_func = _make_view_func(handler) - app.add_url_rule( + self.server.add_url_rule( route, endpoint=endpoint, view_func=view_func, methods=methods ) def _serve_default_favicon(self): + if Response is None: + raise RuntimeError("Quart not installed; cannot generate Response") return Response( pkgutil.get_data("dash", "favicon.ico"), content_type="image/x-icon" ) @@ -374,41 +400,53 @@ def _serve_default_favicon(self): class QuartRequestAdapter: def __init__(self) -> None: - self._request = None + self._request = request # type: ignore[assignment] + if self._request is None: + raise RuntimeError("Quart not installed; cannot access request context") - def set_request(self) -> None: - self._request = request + @property + def request(self) -> _t.Any: + return self._request - # Accessors (instance-based) - def get_root(self): - return self._request.root_url + @property + def root(self): + return self.request.root_url - def get_args(self): - return self._request.args - - async def get_json(self): - return await self._request.get_json() + @property + def args(self): + return self.request.args + @property def is_json(self): - return self._request.is_json + return self.request.is_json + + @property + def cookies(self): + return self.request.cookies - def get_cookies(self): - return self._request.cookies + @property + def headers(self): + return self.request.headers - def get_headers(self): - return self._request.headers + @property + def full_path(self): + return self.request.full_path - def get_full_path(self): - return self._request.full_path + @property + def url(self): + return str(self.request.url) - def get_url(self): - return str(self._request.url) + @property + def remote_addr(self): + return self.request.remote_addr - def get_remote_addr(self): - return self._request.remote_addr + @property + def origin(self): + return self.request.headers.get("origin") - def get_origin(self): - return self._request.headers.get("origin") + @property + def path(self): + return self.request.path - def get_path(self): - return self._request.path + async def get_json(self): + return await self.request.get_json() diff --git a/dash/backends/base_server.py b/dash/backends/base_server.py new file mode 100644 index 0000000000..1c47548ad0 --- /dev/null +++ b/dash/backends/base_server.py @@ -0,0 +1,119 @@ +from abc import ABC, abstractmethod +from typing import Any + + +class BaseDashServer(ABC): + server_type: str + server: Any + config: dict[str, Any] + + def __call__(self, *args, **kwargs) -> Any: + # Default: WSGI + return self.server(*args, **kwargs) + + @staticmethod + @abstractmethod + def create_app( + name: str = "__main__", config=None + ) -> Any: # pragma: no cover - interface + pass + + @abstractmethod + def register_assets_blueprint( + self, blueprint_name: str, assets_url_path: str, assets_folder: str + ) -> None: # pragma: no cover - interface + pass + + @abstractmethod + def register_error_handlers(self) -> None: # pragma: no cover - interface + pass + + @abstractmethod + def add_url_rule( + self, rule: str, view_func, endpoint=None, methods=None + ) -> None: # pragma: no cover - interface + pass + + @abstractmethod + def before_request(self, func) -> None: # pragma: no cover - interface + pass + + @abstractmethod + def after_request(self, func) -> None: # pragma: no cover - interface + pass + + @abstractmethod + def run( + self, dash_app, host: str, port: int, debug: bool, **kwargs + ) -> None: # pragma: no cover - interface + pass + + @abstractmethod + def make_response( + self, data, mimetype=None, content_type=None + ) -> Any: # pragma: no cover - interface + pass + + @abstractmethod + def jsonify(self, obj) -> Any: # pragma: no cover - interface + pass + + +class RequestAdapter(ABC): + def __call__(self) -> Any: + return self + + # Properties to be implemented in concrete adapters + @property # pragma: no cover - interface + @abstractmethod + def root(self) -> str: + raise NotImplementedError() + + @property # pragma: no cover - interface + @abstractmethod + def args(self): + raise NotImplementedError() + + @abstractmethod # kept as method (may be sync or async) + def get_json(self): # pragma: no cover - interface + raise NotImplementedError() + + @property # pragma: no cover - interface + @abstractmethod + def is_json(self) -> bool: + raise NotImplementedError() + + @property # pragma: no cover - interface + @abstractmethod + def cookies(self): + raise NotImplementedError() + + @property # pragma: no cover - interface + @abstractmethod + def headers(self): + raise NotImplementedError() + + @property # pragma: no cover - interface + @abstractmethod + def full_path(self) -> str: + raise NotImplementedError() + + @property # pragma: no cover - interface + @abstractmethod + def url(self) -> str: + raise NotImplementedError() + + @property # pragma: no cover - interface + @abstractmethod + def remote_addr(self): + raise NotImplementedError() + + @property # pragma: no cover - interface + @abstractmethod + def origin(self): + raise NotImplementedError() + + @property # pragma: no cover - interface + @abstractmethod + def path(self) -> str: + raise NotImplementedError() diff --git a/dash/dash.py b/dash/dash.py index 6bba3aadfd..1ed05657dc 100644 --- a/dash/dash.py +++ b/dash/dash.py @@ -25,7 +25,6 @@ from dash import dcc from dash import html from dash import dash_table - from .fingerprint import build_fingerprint from .resources import Scripts, Css from .dependencies import ( @@ -38,7 +37,7 @@ ProxyError, DuplicateCallback, ) -from .backend import get_request_adapter, get_backend +from .backends import get_backend from .version import __version__ from ._configs import get_combined_config, pathname_configs, pages_folder_config from ._utils import ( @@ -63,6 +62,7 @@ from . import _validate from . import _watch from . import _get_app +from . import backends from ._get_app import with_app_context, with_app_context_factory from ._grouping import map_grouping, grouping_len, update_args_group @@ -154,36 +154,6 @@ page_container = None -def _is_flask_instance(obj): - try: - # pylint: disable=import-outside-toplevel - from flask import Flask - - return isinstance(obj, Flask) - except ImportError: - return False - - -def _is_fastapi_instance(obj): - try: - # pylint: disable=import-outside-toplevel - from fastapi import FastAPI - - return isinstance(obj, FastAPI) - except ImportError: - return False - - -def _is_quart_instance(obj): - try: - # pylint: disable=import-outside-toplevel - from quart import Quart - - return isinstance(obj, Quart) - except ImportError: - return False - - # Singleton signal to not update an output, alternative to PreventUpdate no_update = _callback.NoUpdate() # pylint: disable=protected-access @@ -446,74 +416,41 @@ def __init__( # pylint: disable=too-many-statements, too-many-branches **obsolete, ): - if use_async is None: - try: - import asgiref # pylint: disable=unused-import, import-outside-toplevel # noqa - - use_async = True - except ImportError: - pass - elif use_async: - try: - import asgiref # pylint: disable=unused-import, import-outside-toplevel # noqa - except ImportError as exc: - raise Exception( - "You are trying to use dash[async] without having installed the requirements please install via: `pip install dash[async]`" - ) from exc - + _validate.check_async(use_async) _validate.check_obsolete(obsolete) caller_name: str = name if name is not None else get_caller_name() # Determine backend if backend is None: - backend_cls = get_backend("flask") + backend_cls, request_cls = get_backend("flask") elif isinstance(backend, str): - backend_cls = get_backend(backend) + backend_cls, request_cls = get_backend(backend) elif isinstance(backend, type): backend_cls = backend + _, request_cls = get_backend(backend.server_type) else: raise ValueError("Invalid backend argument") # Determine server and backend instance if server not in (None, True, False): # User provided a server instance (e.g., Flask, Quart, FastAPI) - if _is_flask_instance(server): - inferred_backend = "flask" - elif _is_quart_instance(server): - inferred_backend = "quart" - elif _is_fastapi_instance(server): - inferred_backend = "fastapi" - else: - raise ValueError("Unsupported server type") - # Validate that backend matches server type if both are provided - if backend is not None: - if isinstance(backend, type): - # get_backend returns the backend class for a string - # So we compare the class names - expected_backend_cls = get_backend(inferred_backend) - if ( - backend.__module__ != expected_backend_cls.__module__ - or backend.__name__ != expected_backend_cls.__name__ - ): - raise ValueError( - f"Conflict between provided backend '{backend.__name__}' and server type '{inferred_backend}'." - ) - elif not isinstance(backend, str): - raise ValueError("Invalid backend argument") - elif backend.lower() != inferred_backend: - raise ValueError( - f"Conflict between provided backend '{backend}' and server type '{inferred_backend}'." - ) - backend_cls = get_backend(inferred_backend) + inferred_backend = backends.get_server_type(server) + _validate.check_backend(backend, inferred_backend) + backend_cls, request_cls = get_backend(inferred_backend) if name is None: caller_name = getattr(server, "name", caller_name) - self.backend = backend_cls() + + self.backend = backend_cls(server) self.server = server + backends.backend = self.backend # type: ignore + backends.request_adapter = request_cls else: # No server instance provided, create backend and let backend create server - self.backend = backend_cls() - self.server = self.backend.create_app(caller_name) # type: ignore + self.server = backend_cls.create_app(caller_name) # type: ignore + self.backend = backend_cls(self.server) + backends.backend = self.backend + backends.request_adapter = request_cls base_prefix, routes_prefix, requests_prefix = pathname_configs( url_base_pathname, routes_pathname_prefix, requests_pathname_prefix @@ -710,7 +647,6 @@ def init_app(self, app: Optional[Any] = None, **kwargs) -> None: bp_prefix = config.routes_pathname_prefix.replace("/", "_").replace(".", "_") assets_blueprint_name = f"{bp_prefix}dash_assets" self.backend.register_assets_blueprint( - self.server, assets_blueprint_name, config.routes_pathname_prefix + self.config.assets_url_path.lstrip("/"), self.config.assets_folder, @@ -732,8 +668,9 @@ def init_app(self, app: Optional[Any] = None, **kwargs) -> None: raise ImportError( "To use the compress option, you need to install dash[compress]" ) from error - self.backend.register_error_handlers(self.server) - self.backend.before_request(self.server, self._setup_server) + + self.backend.register_error_handlers() + self.backend.before_request(self._setup_server) self._setup_routes() _get_app.APP = self self.enable_pages() @@ -742,7 +679,6 @@ def init_app(self, app: Optional[Any] = None, **kwargs) -> None: def _add_url(self, name: str, view_func: RouteCallable, methods=("GET",)) -> None: full_name = self.config.routes_pathname_prefix + name self.backend.add_url_rule( - self.server, full_name, view_func=view_func, endpoint=full_name, @@ -756,7 +692,7 @@ def _setup_routes(self): self._add_url("_dash-dependencies", self.dependencies) self._add_url( "_dash-update-component", - self.backend.dispatch(self.server, self, self._use_async), + self.backend.dispatch(self), ["POST"], ) self._add_url("_reload-hash", self.serve_reload_hash) @@ -803,7 +739,7 @@ def setup_apis(self): self.callback_api_paths[k] = _callback.GLOBAL_API_PATHS.pop(k) # Delegate to the server factory for route registration - self.backend.register_callback_api_routes(self.server, self.callback_api_paths) + self.backend.register_callback_api_routes(self.callback_api_paths) def _setup_plotlyjs(self): # pylint: disable=import-outside-toplevel @@ -1043,9 +979,11 @@ def _generate_css_dist_html(self): return "\n".join( [ - format_tag("link", link, opened=True) - if isinstance(link, dict) - else f'' + ( + format_tag("link", link, opened=True) + if isinstance(link, dict) + else f'' + ) for link in (external_links + links) ] ) @@ -1099,9 +1037,11 @@ def _generate_scripts_html(self) -> str: return "\n".join( [ - format_tag("script", src) - if isinstance(src, dict) - else f'' + ( + format_tag("script", src) + if isinstance(src, dict) + else f'' + ) for src in srcs ] + [f"" for src in self._inline_scripts] @@ -1139,11 +1079,8 @@ def index(self, *_args, **_kwargs): metas = self._generate_meta() renderer = self._generate_renderer() title = self.title - try: - request = get_request_adapter() - except LookupError: - # no request context - request = None + # Refactored: direct access to global request adapter + request = backends.request_adapter() if self.use_pages and self.config.include_pages_meta and request: metas = _page_meta_tags(self, request) + metas @@ -1357,8 +1294,9 @@ def _inputs_to_vals(self, inputs): return inputs_to_vals(inputs) # pylint: disable=R0915 - def _initialize_context(self, body, adapter): + def _initialize_context(self, body): """Initialize the global context for the request.""" + adapter = backends.request_adapter() g = AttributeDict({}) g.inputs_list = body.get("inputs", []) g.states_list = body.get("state", []) @@ -1372,12 +1310,12 @@ def _initialize_context(self, body, adapter): g.dash_response = self.backend.make_response( mimetype="application/json", data=None ) - g.cookies = dict(adapter.get_cookies()) - g.headers = dict(adapter.get_headers()) - g.args = adapter.get_args() - g.path = adapter.get_full_path() - g.remote = adapter.get_remote_addr() - g.origin = adapter.get_origin() + g.cookies = dict(adapter.cookies) + g.headers = dict(adapter.headers) + g.args = adapter.args + g.path = adapter.full_path + g.remote = adapter.remote_addr + g.origin = adapter.origin g.updated_props = {} return g @@ -1964,15 +1902,21 @@ def enable_dev_tools( packages[index] = dash_spec component_packages_dist = [ - dash_test_path # type: ignore[reportPossiblyUnboundVariable] - if isinstance(package, ModuleSpec) - else os.path.dirname(package.path) # type: ignore[reportAttributeAccessIssue] - if hasattr(package, "path") - else os.path.dirname( - package._path[0] # type: ignore[reportAttributeAccessIssue]; pylint: disable=protected-access - ) - if hasattr(package, "_path") - else package.filename # type: ignore[reportAttributeAccessIssue] + ( + dash_test_path # type: ignore[reportPossiblyUnboundVariable] + if isinstance(package, ModuleSpec) + else ( + os.path.dirname(package.path) # type: ignore[reportAttributeAccessIssue] + if hasattr(package, "path") + else ( + os.path.dirname( + package._path[0] # type: ignore[reportAttributeAccessIssue]; pylint: disable=protected-access + ) + if hasattr(package, "_path") + else package.filename + ) + ) + ) # type: ignore[reportAttributeAccessIssue] for package in packages ] @@ -2000,13 +1944,14 @@ def enable_dev_tools( jupyter_dash.configure_callback_exception_handling( self, dev_tools.prune_errors ) - secret = gen_salt(20) - self.backend.register_prune_error_handler( - self.server, secret, dev_tools.prune_errors - ) + elif dev_tools.prune_errors: + secret = gen_salt(20) + self.backend.register_prune_error_handler( + secret, dev_tools.prune_errors + ) if debug and dev_tools.ui: - self.backend.register_timing_hooks(self.server, first_run) + self.backend.register_timing_hooks(first_run) if ( debug @@ -2290,13 +2235,8 @@ def verify_url_part(served_part, url_part, part_name): server_url=jupyter_server_url, ) else: - self.backend.run( - self, - self.server, - host=host, - port=port, - debug=debug, - **flask_run_options, + backends.backend.run( + dash_app=self, host=host, port=port, debug=debug, **flask_run_options ) def enable_pages(self) -> None: @@ -2368,9 +2308,11 @@ async def update(pathname_, search_, **states): if not self.config.suppress_callback_exceptions: self.validation_layout = html.Div( [ - asyncio.run(execute_async_function(page["layout"])) - if callable(page["layout"]) - else page["layout"] + ( + asyncio.run(execute_async_function(page["layout"])) + if callable(page["layout"]) + else page["layout"] + ) for page in _pages.PAGE_REGISTRY.values() ] + [ @@ -2439,9 +2381,11 @@ def update(pathname_, search_, **states): ] self.validation_layout = html.Div( [ - page["layout"]() - if callable(page["layout"]) - else page["layout"] + ( + page["layout"]() + if callable(page["layout"]) + else page["layout"] + ) for page in _pages.PAGE_REGISTRY.values() ] + layout @@ -2460,7 +2404,7 @@ def update(pathname_, search_, **states): Input(_ID_STORE, "data"), ) - self.backend.before_request(self.server, router) + self.backend.before_request(router) def __call__(self, *args, **kwargs): - return self.backend.__call__(self.server, *args, **kwargs) + return self.backend.__call__(*args, **kwargs) diff --git a/dash_config.json b/dash_config.json new file mode 100644 index 0000000000..3afa0d11f1 --- /dev/null +++ b/dash_config.json @@ -0,0 +1 @@ +{"debug": true, "dev_tools_ui": true, "dev_tools_props_check": true, "dev_tools_serve_dev_bundles": true, "dev_tools_hot_reload": true, "dev_tools_silence_routes_logging": true, "dev_tools_prune_errors": true, "dev_tools_hot_reload_interval": 3.0, "dev_tools_hot_reload_watch_interval": 0.5, "dev_tools_hot_reload_max_retry": 8, "dev_tools_disable_version_check": false} \ No newline at end of file diff --git a/quart_app.py b/quart_app.py new file mode 100644 index 0000000000..54d40add56 --- /dev/null +++ b/quart_app.py @@ -0,0 +1,23 @@ +from dash import Dash, html, Input, Output +from dash import dcc +from dash import backends + +app = Dash(__name__, backend="quart") + +app.layout = html.Div( + [ + html.H2("Quart Server Factory Example"), + html.Div("Type below to see async callback update."), + dcc.Input(id="text", value="hello", autoComplete="off"), + html.Div(id="echo"), + ] +) + + +@app.callback(Output("echo", "children"), Input("text", "value")) +def update_echo(val): + return f"You typed: {val}" if val else "Type something" + + +if __name__ == "__main__": + app.run(debug=True) From c4795ed3b544964c259fe21cb81746911fb7e6aa Mon Sep 17 00:00:00 2001 From: BSd3v <82055130+BSd3v@users.noreply.github.com> Date: Wed, 17 Sep 2025 10:58:06 -0400 Subject: [PATCH 062/166] fixes for failing tests --- .gitignore | 1 + dash/backends/_fastapi.py | 5 +++-- dash/dash.py | 9 ++++----- dash_config.json | 2 +- tests/backend_tests/test_preconfig_backends.py | 12 ++++++------ 5 files changed, 15 insertions(+), 14 deletions(-) diff --git a/.gitignore b/.gitignore index 89029448fe..06e855e2dc 100644 --- a/.gitignore +++ b/.gitignore @@ -93,3 +93,4 @@ packages/ !components/dash-core-components/tests/integration/upload/upload-assets/upft001.csv !components/dash-table/tests/assets/*.csv !components/dash-table/tests/selenium/assets/*.csv +dash_config.json diff --git a/dash/backends/_fastapi.py b/dash/backends/_fastapi.py index f3f9f2df33..57cf18b6ec 100644 --- a/dash/backends/_fastapi.py +++ b/dash/backends/_fastapi.py @@ -52,6 +52,7 @@ def get_current_request() -> Request: class CurrentRequestMiddleware: def __init__(self, app: ASGIApp) -> None: # type: ignore[name-defined] self.app = app + print('loaded CurrentRequestMiddleware') async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: # type: ignore[name-defined] # non-http/ws scopes pass through (lifespan etc.) @@ -100,7 +101,6 @@ def __call__(self, *args: Any, **kwargs: Any): @staticmethod def create_app(name: str = "__main__", config: Dict[str, Any] | None = None): app = FastAPI() - app.add_middleware(CurrentRequestMiddleware) if config: for key, value in config.items(): @@ -257,7 +257,7 @@ def setup_catchall(self, dash_app: Dash): @self.server.on_event("startup") def _setup_catchall(): dash_app.enable_dev_tools( - **self.config, first_run=False + **load_config(), first_run=False ) # do this to make sure dev tools are enabled async def catchall(request: Request): @@ -289,6 +289,7 @@ def add_url_rule( def before_request(self, func: Callable[[], Any] | None): # FastAPI does not have before_request, but we can use middleware + self.server.add_middleware(CurrentRequestMiddleware) self.server.middleware("http")(self._make_before_middleware(func)) def after_request(self, func: Callable[[], Any] | None): diff --git a/dash/dash.py b/dash/dash.py index 1ed05657dc..3ab830e8a3 100644 --- a/dash/dash.py +++ b/dash/dash.py @@ -1944,11 +1944,10 @@ def enable_dev_tools( jupyter_dash.configure_callback_exception_handling( self, dev_tools.prune_errors ) - elif dev_tools.prune_errors: - secret = gen_salt(20) - self.backend.register_prune_error_handler( - secret, dev_tools.prune_errors - ) + secret = gen_salt(20) + self.backend.register_prune_error_handler( + secret, dev_tools.prune_errors + ) if debug and dev_tools.ui: self.backend.register_timing_hooks(first_run) diff --git a/dash_config.json b/dash_config.json index 3afa0d11f1..e4af4373cb 100644 --- a/dash_config.json +++ b/dash_config.json @@ -1 +1 @@ -{"debug": true, "dev_tools_ui": true, "dev_tools_props_check": true, "dev_tools_serve_dev_bundles": true, "dev_tools_hot_reload": true, "dev_tools_silence_routes_logging": true, "dev_tools_prune_errors": true, "dev_tools_hot_reload_interval": 3.0, "dev_tools_hot_reload_watch_interval": 0.5, "dev_tools_hot_reload_max_retry": 8, "dev_tools_disable_version_check": false} \ No newline at end of file +{"dev_tools_ui": false, "dev_tools_props_check": false, "dev_tools_serve_dev_bundles": false, "dev_tools_hot_reload": false, "dev_tools_silence_routes_logging": false, "dev_tools_prune_errors": false, "dev_tools_hot_reload_interval": 3.0, "dev_tools_hot_reload_watch_interval": 0.5, "dev_tools_hot_reload_max_retry": 8, "dev_tools_disable_version_check": true} \ No newline at end of file diff --git a/tests/backend_tests/test_preconfig_backends.py b/tests/backend_tests/test_preconfig_backends.py index 5fbd28dfd9..4c4ccc7083 100644 --- a/tests/backend_tests/test_preconfig_backends.py +++ b/tests/backend_tests/test_preconfig_backends.py @@ -30,7 +30,7 @@ def update_output(value): dash_duo.start_server(app) dash_duo.wait_for_text_to_equal("#output", f"You typed: {input_value}") - dash_duo.find_element("#input").clear() + dash_duo.clear_input(dash_duo.find_element("#input")) dash_duo.find_element("#input").send_keys(f"{backend.title()} Test") dash_duo.wait_for_text_to_equal("#output", f"You typed: {backend.title()} Test") assert dash_duo.get_logs() == [] @@ -93,7 +93,7 @@ def get_error_html(dash_duo, index): "dev_tools_prune_errors": False, "reload": False, }, - "fastapi.py", + "_fastapi.py", ), ( "quart", @@ -104,7 +104,7 @@ def get_error_html(dash_duo, index): "dev_tools_hot_reload": False, "dev_tools_prune_errors": False, }, - "quart.py", + "_quart.py", ), ], ) @@ -131,7 +131,7 @@ def error_callback(n): error0 = get_error_html(dash_duo, 0) assert "in error_callback" in error0 assert "ZeroDivisionError" in error0 - assert "backend" in error0 and error_msg in error0 + assert "backends/" in error0 and error_msg in error0 @pytest.mark.parametrize( @@ -173,7 +173,7 @@ def error_callback(n): error0 = get_error_html(dash_duo, 0) assert "in error_callback" in error0 assert "ZeroDivisionError" in error0 - assert "dash/backend" not in error0 and error_msg not in error0 + assert "dash/backends/" not in error0 and error_msg not in error0 @pytest.mark.parametrize( @@ -209,7 +209,7 @@ def update_output_bg(value): dash_duo.start_server(app) dash_duo.wait_for_text_to_equal("#output", f"Background typed: {input_value}") - dash_duo.find_element("#input").clear() + dash_duo.clear_input(dash_duo.find_element("#input")) dash_duo.find_element("#input").send_keys(f"{backend.title()} BG Test") dash_duo.wait_for_text_to_equal( "#output", f"Background typed: {backend.title()} BG Test" From 567d0f8d592e4281047794100c80a3728aa6b128 Mon Sep 17 00:00:00 2001 From: BSd3v <82055130+BSd3v@users.noreply.github.com> Date: Wed, 17 Sep 2025 11:05:33 -0400 Subject: [PATCH 063/166] fixing formatting --- dash/backends/_fastapi.py | 9 ++++----- dash/backends/_flask.py | 13 +++++++++---- dash/backends/_quart.py | 13 ++++++------- 3 files changed, 19 insertions(+), 16 deletions(-) diff --git a/dash/backends/_fastapi.py b/dash/backends/_fastapi.py index 57cf18b6ec..540238f727 100644 --- a/dash/backends/_fastapi.py +++ b/dash/backends/_fastapi.py @@ -52,7 +52,7 @@ def get_current_request() -> Request: class CurrentRequestMiddleware: def __init__(self, app: ASGIApp) -> None: # type: ignore[name-defined] self.app = app - print('loaded CurrentRequestMiddleware') + print("loaded CurrentRequestMiddleware") async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: # type: ignore[name-defined] # non-http/ws scopes pass through (lifespan etc.) @@ -84,7 +84,6 @@ def load_config(): class FastAPIDashServer(BaseDashServer): - def __init__(self, server: FastAPI): self.config = {} self.server_type = "fastapi" @@ -415,7 +414,6 @@ async def serve(request: Request, package_name: str, fingerprinted_path: str): # pylint: disable=unused-argument def dispatch(self, dash_app: Dash): - async def _dispatch(request: Request): # pylint: disable=protected-access body = await request.json() @@ -470,7 +468,9 @@ async def timing_middleware(request: Request, call_next): headers.append("Server-Timing", value) return response - def register_callback_api_routes(self, callback_api_paths: Dict[str, Callable[..., Any]]): + def register_callback_api_routes( + self, callback_api_paths: Dict[str, Callable[..., Any]] + ): """ Register callback API endpoints on the FastAPI app. Each key in callback_api_paths is a route, each value is a handler (sync or async). @@ -504,7 +504,6 @@ async def view_func(request: Request, body: dict = Body(...)): class FastAPIRequestAdapter(RequestAdapter): - def __init__(self): self._request: Request = get_current_request() super().__init__() diff --git a/dash/backends/_flask.py b/dash/backends/_flask.py index 5a1385d574..138234a4bc 100644 --- a/dash/backends/_flask.py +++ b/dash/backends/_flask.py @@ -29,7 +29,6 @@ class FlaskDashServer(BaseDashServer): - def __init__(self, server: Flask) -> None: self.server: Flask = server self.server_type = "flask" @@ -209,7 +208,9 @@ def _dispatch(): func = dash_app._prepare_callback(cb_ctx, body) args = dash_app._inputs_to_vals(cb_ctx.inputs_list + cb_ctx.states_list) ctx = copy_context() - partial_func = dash_app._execute_callback(func, args, cb_ctx.outputs_list, cb_ctx) + partial_func = dash_app._execute_callback( + func, args, cb_ctx.outputs_list, cb_ctx + ) response_data = ctx.run(partial_func) if asyncio.iscoroutine(response_data): raise Exception( @@ -227,7 +228,9 @@ async def _dispatch_async(): func = dash_app._prepare_callback(cb_ctx, body) args = dash_app._inputs_to_vals(cb_ctx.inputs_list + cb_ctx.states_list) ctx = copy_context() - partial_func = dash_app._execute_callback(func, args, cb_ctx.outputs_list, cb_ctx) + partial_func = dash_app._execute_callback( + func, args, cb_ctx.outputs_list, cb_ctx + ) response_data = ctx.run(partial_func) if asyncio.iscoroutine(response_data): response_data = await response_data @@ -269,7 +272,9 @@ def _after_request(response: Response): # type: ignore[name-defined] self.before_request(_before_request) self.after_request(_after_request) - def register_callback_api_routes(self, callback_api_paths: Dict[str, Callable[..., Any]]): + def register_callback_api_routes( + self, callback_api_paths: Dict[str, Callable[..., Any]] + ): """ Register callback API endpoints on the Flask app. Each key in callback_api_paths is a route, each value is a handler (sync or async). diff --git a/dash/backends/_quart.py b/dash/backends/_quart.py index a462d07af6..ff544c2c91 100644 --- a/dash/backends/_quart.py +++ b/dash/backends/_quart.py @@ -29,7 +29,6 @@ class QuartDashServer(BaseDashServer): - def __init__(self, server: Quart) -> None: self.server_type = "quart" self.server: Quart = server @@ -41,7 +40,9 @@ def __call__(self, *args: Any, **kwargs: Any): # type: ignore[name-defined] return self.server(*args, **kwargs) @staticmethod - def create_app(name: str = "__main__", config: _t.Optional[_t.Dict[str, _t.Any]] = None): + def create_app( + name: str = "__main__", config: _t.Optional[_t.Dict[str, _t.Any]] = None + ): if Quart is None: raise RuntimeError( "Quart is not installed. Install with 'pip install quart' to use the quart backend." @@ -225,7 +226,6 @@ async def _invalid_resource(err): return err.args[0], 404 def _html_response_wrapper(self, view_func: _t.Callable[..., _t.Any] | str): - async def wrapped(*_args, **_kwargs): html_val = view_func() if callable(view_func) else view_func if inspect.iscoroutine(html_val): # handle async function returning html @@ -247,7 +247,6 @@ def add_url_rule( ) def setup_index(self, dash_app: Dash): # type: ignore[name-defined] - async def index(*args, **kwargs): return Response(dash_app.index(*args, **kwargs), content_type="text/html") # type: ignore[arg-type] @@ -255,7 +254,6 @@ async def index(*args, **kwargs): dash_app._add_url("", index, methods=["GET"]) def setup_catchall(self, dash_app: Dash): - async def catchall( path: str, *args, **kwargs ): # noqa: ARG001 - path is unused but kept for route signature, pylint: disable=unused-argument @@ -331,7 +329,6 @@ async def serve(package_name, fingerprinted_path): # pylint: disable=unused-argument def dispatch(self, dash_app: Dash): # type: ignore[name-defined] Quart always async - async def _dispatch(): adapter = QuartRequestAdapter() body = await adapter.get_json() @@ -351,7 +348,9 @@ async def _dispatch(): return _dispatch - def register_callback_api_routes(self, callback_api_paths: _t.Dict[str, _t.Callable[..., _t.Any]]): + def register_callback_api_routes( + self, callback_api_paths: _t.Dict[str, _t.Callable[..., _t.Any]] + ): """ Register callback API endpoints on the Quart app. Each key in callback_api_paths is a route, each value is a handler (sync or async). From a855c6db89e167ca01d87e79fc394e4bfd58c280 Mon Sep 17 00:00:00 2001 From: BSd3v <82055130+BSd3v@users.noreply.github.com> Date: Wed, 17 Sep 2025 11:26:05 -0400 Subject: [PATCH 064/166] fixing issues --- dash/backends/__init__.py | 6 +----- dash/backends/_fastapi.py | 1 - 2 files changed, 1 insertion(+), 6 deletions(-) diff --git a/dash/backends/__init__.py b/dash/backends/__init__.py index 940c8f18bd..e4d4141bb8 100644 --- a/dash/backends/__init__.py +++ b/dash/backends/__init__.py @@ -15,12 +15,8 @@ } -request_adapter: RequestAdapter -backend: BaseDashServer - - def get_backend( - name: Literal["flask", "fastapi", "quart"] | str + name: str ) -> tuple[BaseDashServer, RequestAdapter]: module_name, server_class, request_class = _backend_imports[name.lower()] try: diff --git a/dash/backends/_fastapi.py b/dash/backends/_fastapi.py index 540238f727..be2308d5f5 100644 --- a/dash/backends/_fastapi.py +++ b/dash/backends/_fastapi.py @@ -52,7 +52,6 @@ def get_current_request() -> Request: class CurrentRequestMiddleware: def __init__(self, app: ASGIApp) -> None: # type: ignore[name-defined] self.app = app - print("loaded CurrentRequestMiddleware") async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: # type: ignore[name-defined] # non-http/ws scopes pass through (lifespan etc.) From 79afb0bab2d7d058cc8770b6de65eb7dcab656dd Mon Sep 17 00:00:00 2001 From: BSd3v <82055130+BSd3v@users.noreply.github.com> Date: Wed, 17 Sep 2025 11:54:41 -0400 Subject: [PATCH 065/166] fixing async validation --- .github/workflows/testing.yml | 4 ++-- dash/_validate.py | 1 + dash/backends/__init__.py | 5 +---- dash/backends/_quart.py | 1 + dash/dash.py | 6 ++---- 5 files changed, 7 insertions(+), 10 deletions(-) diff --git a/.github/workflows/testing.yml b/.github/workflows/testing.yml index be5caf4929..c47e188222 100644 --- a/.github/workflows/testing.yml +++ b/.github/workflows/testing.yml @@ -39,8 +39,8 @@ jobs: - 'tests/async_tests/**' - 'requirements/**' backend_paths: - - 'dash/backend/**' - - 'tests/backend/**' + - 'dash/backends/**' + - 'tests/backend_tests/**' build: name: Build Dash Package diff --git a/dash/_validate.py b/dash/_validate.py index 76661cef6b..d595cba0fc 100644 --- a/dash/_validate.py +++ b/dash/_validate.py @@ -603,6 +603,7 @@ def check_async(use_async): raise Exception( "You are trying to use dash[async] without having installed the requirements please install via: `pip install dash[async]`" ) from exc + return use_async or False def check_backend(backend, inferred_backend): diff --git a/dash/backends/__init__.py b/dash/backends/__init__.py index e4d4141bb8..b845abb1ad 100644 --- a/dash/backends/__init__.py +++ b/dash/backends/__init__.py @@ -1,6 +1,5 @@ from .base_server import BaseDashServer, RequestAdapter -from typing import Literal, Any import importlib @@ -15,9 +14,7 @@ } -def get_backend( - name: str -) -> tuple[BaseDashServer, RequestAdapter]: +def get_backend(name: str) -> tuple[BaseDashServer, RequestAdapter]: module_name, server_class, request_class = _backend_imports[name.lower()] try: module = importlib.import_module(module_name) diff --git a/dash/backends/_quart.py b/dash/backends/_quart.py index ff544c2c91..c5759026d4 100644 --- a/dash/backends/_quart.py +++ b/dash/backends/_quart.py @@ -26,6 +26,7 @@ from dash.fingerprint import check_fingerprint from dash import _validate from .base_server import BaseDashServer +from typing import Any class QuartDashServer(BaseDashServer): diff --git a/dash/dash.py b/dash/dash.py index 3ab830e8a3..52dd219627 100644 --- a/dash/dash.py +++ b/dash/dash.py @@ -416,7 +416,7 @@ def __init__( # pylint: disable=too-many-statements, too-many-branches **obsolete, ): - _validate.check_async(use_async) + use_async = _validate.check_async(use_async) _validate.check_obsolete(obsolete) caller_name: str = name if name is not None else get_caller_name() @@ -1945,9 +1945,7 @@ def enable_dev_tools( self, dev_tools.prune_errors ) secret = gen_salt(20) - self.backend.register_prune_error_handler( - secret, dev_tools.prune_errors - ) + self.backend.register_prune_error_handler(secret, dev_tools.prune_errors) if debug and dev_tools.ui: self.backend.register_timing_hooks(first_run) From 77e22a3ca21ddc4f5d1a0f50964204f7dc92fb46 Mon Sep 17 00:00:00 2001 From: BSd3v <82055130+BSd3v@users.noreply.github.com> Date: Wed, 17 Sep 2025 12:29:37 -0400 Subject: [PATCH 066/166] adjustments for request_adapter --- dash/backends/__init__.py | 2 +- dash/backends/_fastapi.py | 1 + dash/backends/_flask.py | 1 + dash/backends/_quart.py | 1 + dash/backends/base_server.py | 1 + dash/dash.py | 11 +++++------ 6 files changed, 10 insertions(+), 7 deletions(-) diff --git a/dash/backends/__init__.py b/dash/backends/__init__.py index b845abb1ad..c8ac9321d0 100644 --- a/dash/backends/__init__.py +++ b/dash/backends/__init__.py @@ -19,7 +19,7 @@ def get_backend(name: str) -> tuple[BaseDashServer, RequestAdapter]: try: module = importlib.import_module(module_name) server = getattr(module, server_class) - request_adapter = getattr(module, request_class) + request_adapter = server.request_adapter # type: ignore return server, request_adapter except KeyError as e: raise ValueError(f"Unknown backend: {name}") from e diff --git a/dash/backends/_fastapi.py b/dash/backends/_fastapi.py index be2308d5f5..3f50f96f57 100644 --- a/dash/backends/_fastapi.py +++ b/dash/backends/_fastapi.py @@ -88,6 +88,7 @@ def __init__(self, server: FastAPI): self.server_type = "fastapi" self.server: FastAPI = server self.error_handling_mode = "prune" + self.request_adapter = FastAPIRequestAdapter super().__init__() def __call__(self, *args: Any, **kwargs: Any): diff --git a/dash/backends/_flask.py b/dash/backends/_flask.py index 138234a4bc..b4bab46ff2 100644 --- a/dash/backends/_flask.py +++ b/dash/backends/_flask.py @@ -32,6 +32,7 @@ class FlaskDashServer(BaseDashServer): def __init__(self, server: Flask) -> None: self.server: Flask = server self.server_type = "flask" + self.request_adapter = FlaskRequestAdapter super().__init__() def __call__(self, *args: Any, **kwargs: Any): diff --git a/dash/backends/_quart.py b/dash/backends/_quart.py index c5759026d4..8e509a08e0 100644 --- a/dash/backends/_quart.py +++ b/dash/backends/_quart.py @@ -35,6 +35,7 @@ def __init__(self, server: Quart) -> None: self.server: Quart = server self.config = {} self.error_handling_mode = "prune" + self.request_adapter = QuartRequestAdapter super().__init__() def __call__(self, *args: Any, **kwargs: Any): # type: ignore[name-defined] diff --git a/dash/backends/base_server.py b/dash/backends/base_server.py index 1c47548ad0..cf2b62c2e7 100644 --- a/dash/backends/base_server.py +++ b/dash/backends/base_server.py @@ -6,6 +6,7 @@ class BaseDashServer(ABC): server_type: str server: Any config: dict[str, Any] + request_adapter: Any def __call__(self, *args, **kwargs) -> Any: # Default: WSGI diff --git a/dash/dash.py b/dash/dash.py index 52dd219627..2d71766b4e 100644 --- a/dash/dash.py +++ b/dash/dash.py @@ -423,12 +423,11 @@ def __init__( # pylint: disable=too-many-statements, too-many-branches # Determine backend if backend is None: - backend_cls, request_cls = get_backend("flask") + backend_cls = get_backend("flask") elif isinstance(backend, str): - backend_cls, request_cls = get_backend(backend) + backend_cls = get_backend(backend) elif isinstance(backend, type): backend_cls = backend - _, request_cls = get_backend(backend.server_type) else: raise ValueError("Invalid backend argument") @@ -437,20 +436,20 @@ def __init__( # pylint: disable=too-many-statements, too-many-branches # User provided a server instance (e.g., Flask, Quart, FastAPI) inferred_backend = backends.get_server_type(server) _validate.check_backend(backend, inferred_backend) - backend_cls, request_cls = get_backend(inferred_backend) + backend_cls = get_backend(inferred_backend) if name is None: caller_name = getattr(server, "name", caller_name) self.backend = backend_cls(server) self.server = server backends.backend = self.backend # type: ignore - backends.request_adapter = request_cls + backends.request_adapter = self.backend.request_adapter # type: ignore else: # No server instance provided, create backend and let backend create server self.server = backend_cls.create_app(caller_name) # type: ignore self.backend = backend_cls(self.server) backends.backend = self.backend - backends.request_adapter = request_cls + backends.request_adapter = self.backend.request_adapter # type: ignore base_prefix, routes_prefix, requests_prefix = pathname_configs( url_base_pathname, routes_pathname_prefix, requests_pathname_prefix From f7331d3f7e97dc52ebc1289d71e95111d49b3bb8 Mon Sep 17 00:00:00 2001 From: BSd3v <82055130+BSd3v@users.noreply.github.com> Date: Wed, 17 Sep 2025 12:31:35 -0400 Subject: [PATCH 067/166] adding test for custom dash server --- tests/backend_tests/test_custom_backend.py | 243 +++++++++++++++++++++ 1 file changed, 243 insertions(+) create mode 100644 tests/backend_tests/test_custom_backend.py diff --git a/tests/backend_tests/test_custom_backend.py b/tests/backend_tests/test_custom_backend.py new file mode 100644 index 0000000000..aa590599b5 --- /dev/null +++ b/tests/backend_tests/test_custom_backend.py @@ -0,0 +1,243 @@ +import pytest +from dash import Dash, Input, Output, html, dcc +from fastapi import FastAPI +import traceback +import re +from dash.backends._fastapi import FastAPIDashServer + + +class CustomDashServer(FastAPIDashServer): + def _get_traceback(self, _secret, error: Exception): + tb = error.__traceback__ + errors = traceback.format_exception(type(error), error, tb) + pass_errs = [] + callback_handled = False + for err in errors: + if self.error_handling_mode == "prune": + if not callback_handled: + if "callback invoked" in str(err) and "_callback.py" in str(err): + callback_handled = True + continue + pass_errs.append(err) + formatted_tb = "".join(pass_errs) + error_type = type(error).__name__ + error_msg = str(error) + # Parse traceback lines to group by file + file_cards = [] + pattern = re.compile(r' File "(.+)", line (\d+), in (\w+)') + lines = formatted_tb.split("\n") + current_file = None + card_lines = [] + for line in lines[:-1]: # Skip the last line (error message) + match = pattern.match(line) + if match: + if current_file and card_lines: + file_cards.append((current_file, card_lines)) + current_file = ( + f"{match.group(1)} (line {match.group(2)}, in {match.group(3)})" + ) + card_lines = [line] + elif current_file: + card_lines.append(line) + if current_file and card_lines: + file_cards.append((current_file, card_lines)) + cards_html = "" + for filename, card in file_cards: + cards_html += ( + f""" +
+
{filename}
+
"""
+                + "\n".join(card)
+                + """
+
+ """ + ) + html = f""" + + + + {error_type}: {error_msg} // Custom Debugger + + + +
+

{error_type}: {error_msg}

+ {cards_html} +
+ + + """ + return html + + +@pytest.mark.parametrize( + "fixture,input_value", + [ + ("dash_duo", "Hello CustomBackend!"), + ], +) +def test_custom_backend_basic_callback(request, fixture, input_value): + dash_duo = request.getfixturevalue(fixture) + app = Dash(__name__, backend=CustomDashServer) + app.layout = html.Div( + [dcc.Input(id="input", value=input_value, type="text"), html.Div(id="output")] + ) + + @app.callback(Output("output", "children"), Input("input", "value")) + def update_output(value): + return f"You typed: {value}" + + dash_duo.start_server(app) + dash_duo.wait_for_text_to_equal("#output", f"You typed: {input_value}") + dash_duo.clear_input(dash_duo.find_element("#input")) + dash_duo.find_element("#input").send_keys("CustomBackend Test") + dash_duo.wait_for_text_to_equal("#output", "You typed: CustomBackend Test") + assert dash_duo.get_logs() == [] + + +@pytest.mark.parametrize( + "fixture,start_server_kwargs", + [ + ("dash_duo", {"debug": True, "reload": False, "dev_tools_ui": True}), + ], +) +def test_custom_backend_error_handling(request, fixture, start_server_kwargs): + dash_duo = request.getfixturevalue(fixture) + app = Dash(__name__, backend=CustomDashServer) + app.layout = html.Div( + [html.Button(id="btn", children="Error", n_clicks=0), html.Div(id="output")] + ) + + @app.callback(Output("output", "children"), Input("btn", "n_clicks")) + def error_callback(n): + if n and n > 0: + return 1 / 0 # Intentional error + return "No error" + + dash_duo.start_server(app, **start_server_kwargs) + dash_duo.wait_for_text_to_equal("#output", "No error") + dash_duo.find_element("#btn").click() + dash_duo.wait_for_text_to_equal(dash_duo.devtools_error_count_locator, "1") + + +def get_error_html(dash_duo, index): + # error is in an iframe so is annoying to read out - get it from the store + return dash_duo.driver.execute_script( + "return store.getState().error.backEnd[{}].error.html;".format(index) + ) + + +@pytest.mark.parametrize( + "fixture,start_server_kwargs", + [ + ( + "dash_duo", + { + "debug": True, + "dev_tools_ui": True, + "dev_tools_prune_errors": False, + "reload": False, + }, + ), + ], +) +def test_custom_backend_error_handling_no_prune(request, fixture, start_server_kwargs): + dash_duo = request.getfixturevalue(fixture) + app = Dash(__name__, backend=CustomDashServer) + app.layout = html.Div( + [html.Button(id="btn", children="Error", n_clicks=0), html.Div(id="output")] + ) + + @app.callback(Output("output", "children"), Input("btn", "n_clicks")) + def error_callback(n): + if n and n > 0: + return 1 / 0 # Intentional error + return "No error" + + dash_duo.start_server(app, **start_server_kwargs) + dash_duo.wait_for_text_to_equal("#output", "No error") + dash_duo.find_element("#btn").click() + dash_duo.wait_for_text_to_equal(dash_duo.devtools_error_count_locator, "1") + + error0 = get_error_html(dash_duo, 0) + assert "Custom Debugger" in error0 + assert "in error_callback" in error0 + assert "ZeroDivisionError" in error0 + assert "_callback.py" in error0 + + +@pytest.mark.parametrize( + "fixture,start_server_kwargs, error_msg", + [ + ("dash_duo", {"debug": True, "reload": False}, "custombackend.py"), + ], +) +def test_custom_backend_error_handling_prune( + request, fixture, start_server_kwargs, error_msg +): + dash_duo = request.getfixturevalue(fixture) + app = Dash(__name__, backend=CustomDashServer) + app.layout = html.Div( + [html.Button(id="btn", children="Error", n_clicks=0), html.Div(id="output")] + ) + + @app.callback(Output("output", "children"), Input("btn", "n_clicks")) + def error_callback(n): + if n and n > 0: + return 1 / 0 # Intentional error + return "No error" + + dash_duo.start_server(app, **start_server_kwargs) + dash_duo.wait_for_text_to_equal("#output", "No error") + dash_duo.find_element("#btn").click() + dash_duo.wait_for_text_to_equal(dash_duo.devtools_error_count_locator, "1") + + error0 = get_error_html(dash_duo, 0) + assert "Custom Debugger" in error0 + assert "in error_callback" in error0 + assert "ZeroDivisionError" in error0 + assert "_callback.py" not in error0 + + +@pytest.mark.parametrize( + "fixture,input_value", + [ + ("dash_duo", "Background CustomBackend!"), + ], +) +def test_custom_backend_background_callback(request, fixture, input_value): + dash_duo = request.getfixturevalue(fixture) + import diskcache + + cache = diskcache.Cache("./cache") + from dash.background_callback import DiskcacheManager + + background_callback_manager = DiskcacheManager(cache) + + app = Dash( + __name__, + backend=CustomDashServer, + background_callback_manager=background_callback_manager, + ) + app.layout = html.Div( + [dcc.Input(id="input", value=input_value, type="text"), html.Div(id="output")] + ) + + @app.callback( + Output("output", "children"), Input("input", "value"), background=True + ) + def update_output_bg(value): + return f"Background typed: {value}" + + dash_duo.start_server(app) + dash_duo.wait_for_text_to_equal("#output", f"Background typed: {input_value}") + dash_duo.clear_input(dash_duo.find_element("#input")) + dash_duo.find_element("#input").send_keys("CustomBackend BG Test") + dash_duo.wait_for_text_to_equal( + "#output", "Background typed: CustomBackend BG Test" + ) + assert dash_duo.get_logs() == [] From 8b58cf4e10b4a2f6c04bf988e5a752fc123757e8 Mon Sep 17 00:00:00 2001 From: BSd3v <82055130+BSd3v@users.noreply.github.com> Date: Wed, 17 Sep 2025 12:50:15 -0400 Subject: [PATCH 068/166] fixing issue with `request_adapter` --- dash/backends/__init__.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/dash/backends/__init__.py b/dash/backends/__init__.py index c8ac9321d0..c264af4824 100644 --- a/dash/backends/__init__.py +++ b/dash/backends/__init__.py @@ -19,8 +19,7 @@ def get_backend(name: str) -> tuple[BaseDashServer, RequestAdapter]: try: module = importlib.import_module(module_name) server = getattr(module, server_class) - request_adapter = server.request_adapter # type: ignore - return server, request_adapter + return server except KeyError as e: raise ValueError(f"Unknown backend: {name}") from e except ImportError as e: From b7d4af2bc744a9f58a31a94b005acb969bf7d6b9 Mon Sep 17 00:00:00 2001 From: BSd3v <82055130+BSd3v@users.noreply.github.com> Date: Wed, 17 Sep 2025 13:30:00 -0400 Subject: [PATCH 069/166] adjusting error handling for fastapi --- dash/backends/_fastapi.py | 33 +++++++++++++++++++-------------- dash/backends/_quart.py | 4 +++- 2 files changed, 22 insertions(+), 15 deletions(-) diff --git a/dash/backends/_fastapi.py b/dash/backends/_fastapi.py index 3f50f96f57..8fc08b7a89 100644 --- a/dash/backends/_fastapi.py +++ b/dash/backends/_fastapi.py @@ -70,24 +70,29 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: # CONFIG_PATH = "dash_config.json" -def save_config(config): - with open(CONFIG_PATH, "w") as f: - json.dump(config, f) +# Internal config helpers (local to this file) +_CONFIG_PATH = "dash_config.json" +def _save_config(config): + with open(_CONFIG_PATH, "w") as f: + json.dump(config, f) -def load_config(): - if os.path.exists(CONFIG_PATH): - with open(CONFIG_PATH, "r") as f: - return json.load(f) +def _load_config(): + try: + if os.path.exists(_CONFIG_PATH): + with open(_CONFIG_PATH, "r") as f: + return json.load(f) + except Exception: + pass # ignore errors return {} class FastAPIDashServer(BaseDashServer): def __init__(self, server: FastAPI): - self.config = {} + _save_config({"debug": False}) # ensure config file exists self.server_type = "fastapi" self.server: FastAPI = server - self.error_handling_mode = "prune" + self.error_handling_mode = "ignore" self.request_adapter = FastAPIRequestAdapter super().__init__() @@ -120,7 +125,7 @@ def register_assets_blueprint( pass def register_error_handlers(self): - self.error_handling_mode = "prune" + self.error_handling_mode = "ignore" def _get_traceback(self, _secret, error: Exception): tb = error.__traceback__ @@ -256,7 +261,7 @@ def setup_catchall(self, dash_app: Dash): @self.server.on_event("startup") def _setup_catchall(): dash_app.enable_dev_tools( - **load_config(), first_run=False + **_load_config(), first_run=False ) # do this to make sure dev tools are enabled async def catchall(request: Request): @@ -298,12 +303,12 @@ def after_request(self, func: Callable[[], Any] | None): def run(self, dash_app: Dash, host, port, debug, **kwargs): frame = inspect.stack()[2] config = dict( - {"debug": debug} if debug else {}, + {"debug": debug} if debug else {"debug": False}, **{ f"dev_tools_{k}": v for k, v in dash_app._dev_tools.items() }, # pylint: disable=protected-access ) - save_config(config) + _save_config(config) if debug: if kwargs.get("reload") is None: kwargs["reload"] = True @@ -352,7 +357,7 @@ async def middleware(request, call_next): return Response(content=tb, media_type="text/html", status_code=500) return JSONResponse( status_code=500, - content={"error": "InternalServerError", "message": str(e.args[0])}, + content={"error": "InternalServerError", "message": "An internal server error occurred."}, ) return middleware diff --git a/dash/backends/_quart.py b/dash/backends/_quart.py index 8e509a08e0..eae8df9117 100644 --- a/dash/backends/_quart.py +++ b/dash/backends/_quart.py @@ -34,7 +34,7 @@ def __init__(self, server: Quart) -> None: self.server_type = "quart" self.server: Quart = server self.config = {} - self.error_handling_mode = "prune" + self.error_handling_mode = "ignore" self.request_adapter = QuartRequestAdapter super().__init__() @@ -184,6 +184,8 @@ def register_prune_error_handler(self, secret, prune_errors): @self.server.errorhandler(Exception) async def _wrap_errors(error): + if self.error_handling_mode == "ignore": + return Response("Internal server error.", status=500, content_type="text/plain") tb = self._get_traceback(secret, error) return Response(tb, status=500, content_type="text/html") From 4cf4686f4f62aa93210b907bed7c360e7e883f1c Mon Sep 17 00:00:00 2001 From: BSd3v <82055130+BSd3v@users.noreply.github.com> Date: Wed, 17 Sep 2025 14:02:45 -0400 Subject: [PATCH 070/166] adjustments for handling issues with `debug` for `fastapi` --- dash/backends/_fastapi.py | 22 ++++++++++++++-------- dash_config.json | 1 - 2 files changed, 14 insertions(+), 9 deletions(-) delete mode 100644 dash_config.json diff --git a/dash/backends/_fastapi.py b/dash/backends/_fastapi.py index 8fc08b7a89..36249fbf8c 100644 --- a/dash/backends/_fastapi.py +++ b/dash/backends/_fastapi.py @@ -66,30 +66,32 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: # finally: reset_current_request(token) - -CONFIG_PATH = "dash_config.json" - - # Internal config helpers (local to this file) -_CONFIG_PATH = "dash_config.json" +_CONFIG_PATH = os.path.join(os.path.dirname(__file__), "dash_config.json") def _save_config(config): with open(_CONFIG_PATH, "w") as f: json.dump(config, f) def _load_config(): + resp = {"debug": False} try: if os.path.exists(_CONFIG_PATH): with open(_CONFIG_PATH, "r") as f: - return json.load(f) + resp = json.load(f) except Exception: pass # ignore errors - return {} + return resp + +def _remove_config(): + try: + os.remove(_CONFIG_PATH) + except FileNotFoundError: + pass class FastAPIDashServer(BaseDashServer): def __init__(self, server: FastAPI): - _save_config({"debug": False}) # ensure config file exists self.server_type = "fastapi" self.server: FastAPI = server self.error_handling_mode = "ignore" @@ -258,6 +260,10 @@ async def index(request: Request): dash_app._add_url("", index, methods=["GET"]) def setup_catchall(self, dash_app: Dash): + @self.server.on_event("shutdown") + def cleanup_config(): + _remove_config() + @self.server.on_event("startup") def _setup_catchall(): dash_app.enable_dev_tools( diff --git a/dash_config.json b/dash_config.json deleted file mode 100644 index e4af4373cb..0000000000 --- a/dash_config.json +++ /dev/null @@ -1 +0,0 @@ -{"dev_tools_ui": false, "dev_tools_props_check": false, "dev_tools_serve_dev_bundles": false, "dev_tools_hot_reload": false, "dev_tools_silence_routes_logging": false, "dev_tools_prune_errors": false, "dev_tools_hot_reload_interval": 3.0, "dev_tools_hot_reload_watch_interval": 0.5, "dev_tools_hot_reload_max_retry": 8, "dev_tools_disable_version_check": true} \ No newline at end of file From dfe0ac7f106dd1ea300c36adb3faf922a665b406 Mon Sep 17 00:00:00 2001 From: BSd3v <82055130+BSd3v@users.noreply.github.com> Date: Wed, 17 Sep 2025 16:17:50 -0400 Subject: [PATCH 071/166] fixing for lint --- dash/backends/__init__.py | 15 +- dash/backends/_fastapi.py | 226 +++++++-------------- dash/backends/_flask.py | 10 +- dash/backends/_quart.py | 158 +++----------- dash/backends/_utils.py | 108 ++++++++++ dash/testing/application_runners.py | 8 +- tests/backend_tests/test_custom_backend.py | 7 +- 7 files changed, 229 insertions(+), 303 deletions(-) create mode 100644 dash/backends/_utils.py diff --git a/dash/backends/__init__.py b/dash/backends/__init__.py index c264af4824..e8b007a50b 100644 --- a/dash/backends/__init__.py +++ b/dash/backends/__init__.py @@ -1,21 +1,19 @@ -from .base_server import BaseDashServer, RequestAdapter - import importlib +from .base_server import BaseDashServer -request_adapter: RequestAdapter backend: BaseDashServer _backend_imports = { - "flask": ("dash.backends._flask", "FlaskDashServer", "FlaskRequestAdapter"), - "fastapi": ("dash.backends._fastapi", "FastAPIDashServer", "FastAPIRequestAdapter"), - "quart": ("dash.backends._quart", "QuartDashServer", "QuartRequestAdapter"), + "flask": ("dash.backends._flask", "FlaskDashServer"), + "fastapi": ("dash.backends._fastapi", "FastAPIDashServer"), + "quart": ("dash.backends._quart", "QuartDashServer"), } -def get_backend(name: str) -> tuple[BaseDashServer, RequestAdapter]: - module_name, server_class, request_class = _backend_imports[name.lower()] +def get_backend(name: str) -> BaseDashServer: + module_name, server_class = _backend_imports[name.lower()] try: module = importlib.import_module(module_name) server = getattr(module, server_class) @@ -74,7 +72,6 @@ def get_server_type(server): __all__ = [ "get_backend", - "request_adapter", "backend", "get_server_type", ] diff --git a/dash/backends/_fastapi.py b/dash/backends/_fastapi.py index 36249fbf8c..dc7805501b 100644 --- a/dash/backends/_fastapi.py +++ b/dash/backends/_fastapi.py @@ -1,5 +1,6 @@ from __future__ import annotations +import asyncio from contextvars import copy_context, ContextVar from typing import TYPE_CHECKING, Any, Callable, Dict import sys @@ -8,27 +9,41 @@ import inspect import pkgutil import time -import traceback from importlib.util import spec_from_file_location import json import os -import re + +try: + from fastapi import FastAPI, Request, Response, Body + from fastapi.responses import JSONResponse + from fastapi.staticfiles import StaticFiles + from starlette.responses import Response as StarletteResponse + from starlette.datastructures import MutableHeaders + from starlette.types import ASGIApp, Scope, Receive, Send + import uvicorn +except ImportError: + FastAPI = None + Request = None + Response = None + Body = None + JSONResponse = None + StaticFiles = None + StarletteResponse = None + MutableHeaders = None + ASGIApp = None + Scope = None + Receive = None + Send = None + uvicorn = None from dash.fingerprint import check_fingerprint from dash import _validate from dash.exceptions import PreventUpdate from .base_server import BaseDashServer, RequestAdapter - -from fastapi import FastAPI, Request, Response, Body -from fastapi.responses import JSONResponse -from fastapi.staticfiles import StaticFiles -from starlette.responses import Response as StarletteResponse -from starlette.datastructures import MutableHeaders -from starlette.types import ASGIApp, Scope, Receive, Send -import uvicorn +from ._utils import format_traceback_html if TYPE_CHECKING: # pragma: no cover - typing only - from dash.dash import Dash + from dash import Dash _current_request_var = ContextVar("dash_current_request", default=None) @@ -49,7 +64,7 @@ def get_current_request() -> Request: return req -class CurrentRequestMiddleware: +class CurrentRequestMiddleware: # pylint: disable=too-few-public-methods def __init__(self, app: ASGIApp) -> None: # type: ignore[name-defined] self.app = app @@ -66,23 +81,27 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: # finally: reset_current_request(token) + # Internal config helpers (local to this file) _CONFIG_PATH = os.path.join(os.path.dirname(__file__), "dash_config.json") + def _save_config(config): - with open(_CONFIG_PATH, "w") as f: + with open(_CONFIG_PATH, "w", encoding="utf-8") as f: json.dump(config, f) + def _load_config(): resp = {"debug": False} try: if os.path.exists(_CONFIG_PATH): - with open(_CONFIG_PATH, "r") as f: + with open(_CONFIG_PATH, "r", encoding="utf-8") as f: resp = json.load(f) - except Exception: + except (json.JSONDecodeError, OSError): pass # ignore errors return resp + def _remove_config(): try: os.remove(_CONFIG_PATH) @@ -130,113 +149,9 @@ def register_error_handlers(self): self.error_handling_mode = "ignore" def _get_traceback(self, _secret, error: Exception): - tb = error.__traceback__ - errors = traceback.format_exception(type(error), error, tb) - pass_errs = [] - callback_handled = False - for err in errors: - if self.error_handling_mode == "prune": - if not callback_handled: - if "callback invoked" in str(err) and "_callback.py" in str(err): - callback_handled = True - continue - pass_errs.append(err) - formatted_tb = "".join(pass_errs) - error_type = type(error).__name__ - error_msg = str(error) - - # Parse traceback lines to group by file - file_cards = [] - pattern = re.compile(r' File "(.+)", line (\d+), in (\w+)') - lines = formatted_tb.split("\n") - current_file = None - card_lines = [] - - for line in lines[:-1]: # Skip the last line (error message) - match = pattern.match(line) - if match: - if current_file and card_lines: - file_cards.append((current_file, card_lines)) - current_file = ( - f"{match.group(1)} (line {match.group(2)}, in {match.group(3)})" - ) - card_lines = [line] - elif current_file: - card_lines.append(line) - if current_file and card_lines: - file_cards.append((current_file, card_lines)) - - cards_html = "" - for filename, card in file_cards: - cards_html += ( - f""" -
-
{filename}
-
"""
-                + "\n".join(card)
-                + """
-
- """ - ) - - html = f""" - - - - {error_type}: {error_msg} // FastAPI Debugger - - - -
-

{error_type}

-
-

{error_type}: {error_msg}

-
-

Traceback (most recent call last)

- {cards_html} -
{error_type}: {error_msg}
-
-

This is the Copy/Paste friendly version of the traceback.

- -
-
- The debugger caught an exception in your ASGI application. You can now - look at the traceback which led to the error. -
- -
- - - """ - return html + return format_traceback_html( + error, self.error_handling_mode, "FastAPI Debugger", "FastAPI" + ) def register_prune_error_handler(self, _secret, prune_errors): if prune_errors: @@ -253,7 +168,7 @@ async def wrapped(*_args, **_kwargs): return wrapped def setup_index(self, dash_app: Dash): - async def index(request: Request): + async def index(_request: Request): return Response(content=dash_app.index(), media_type="text/html") # pylint: disable=protected-access @@ -270,7 +185,7 @@ def _setup_catchall(): **_load_config(), first_run=False ) # do this to make sure dev tools are enabled - async def catchall(request: Request): + async def catchall(_request: Request): return Response(content=dash_app.index(), media_type="text/html") # pylint: disable=protected-access @@ -308,11 +223,10 @@ def after_request(self, func: Callable[[], Any] | None): def run(self, dash_app: Dash, host, port, debug, **kwargs): frame = inspect.stack()[2] + dev_tools = dash_app._dev_tools # pylint: disable=protected-access config = dict( {"debug": debug} if debug else {"debug": False}, - **{ - f"dev_tools_{k}": v for k, v in dash_app._dev_tools.items() - }, # pylint: disable=protected-access + **{f"dev_tools_{k}": v for k, v in dev_tools.items()}, ) _save_config(config) if debug: @@ -348,7 +262,7 @@ def make_response( def jsonify(self, obj: Any): return JSONResponse(content=obj) - def _make_before_middleware(self, func: Callable[[], Any] | None): + def _make_before_middleware(self, _func: Callable[[], Any] | None): async def middleware(request, call_next): try: response = await call_next(request) @@ -356,14 +270,18 @@ async def middleware(request, call_next): except PreventUpdate: # No content, nothing to update return Response(status_code=204) - except Exception as e: + except (Exception) as e: # pylint: disable=broad-except + # Handle exceptions based on error_handling_mode if self.error_handling_mode in ["raise", "prune"]: # Prune the traceback to remove internal Dash calls tb = self._get_traceback(None, e) return Response(content=tb, media_type="text/html", status_code=500) return JSONResponse( status_code=500, - content={"error": "InternalServerError", "message": "An internal server error occurred."}, + content={ + "error": "InternalServerError", + "message": "An internal server error occurred.", + }, ) return middleware @@ -417,27 +335,25 @@ async def serve(request: Request, package_name: str, fingerprinted_path: str): dash_app, package_name, fingerprinted_path, request ) - # pylint: disable=protected-access - dash_app._add_url( - "_dash-component-suites/{package_name}/{fingerprinted_path:path}", - serve, - ) + name = "_dash-component-suites/{package_name}/{fingerprinted_path:path}" + dash_app._add_url(name, serve) # pylint: disable=protected-access - # pylint: disable=unused-argument def dispatch(self, dash_app: Dash): async def _dispatch(request: Request): # pylint: disable=protected-access body = await request.json() - g = dash_app._initialize_context(body) # pylint: disable=protected-access + cb_ctx = dash_app._initialize_context( + body + ) # pylint: disable=protected-access func = dash_app._prepare_callback( - g, body + cb_ctx, body ) # pylint: disable=protected-access args = dash_app._inputs_to_vals( - g.inputs_list + g.states_list + cb_ctx.inputs_list + cb_ctx.states_list ) # pylint: disable=protected-access ctx = copy_context() partial_func = dash_app._execute_callback( - func, args, g.outputs_list, g + func, args, cb_ctx.outputs_list, cb_ctx ) # pylint: disable=protected-access response_data = ctx.run(partial_func) if inspect.iscoroutine(response_data): @@ -494,20 +410,24 @@ def register_callback_api_routes( sig = inspect.signature(handler) param_names = list(sig.parameters.keys()) - async def view_func(request: Request, body: dict = Body(...)): - # Only pass expected params; ignore extras - kwargs = { - k: v for k, v in body.items() if k in param_names and v is not None - } - if inspect.iscoroutinefunction(handler): - result = await handler(**kwargs) - else: - result = handler(**kwargs) - return JSONResponse(content=result) + def make_view_func(handler, param_names): + async def view_func(_request: Request, body: dict = Body(...)): + kwargs = { + k: v + for k, v in body.items() + if k in param_names and v is not None + } + if inspect.iscoroutinefunction(handler): + result = await handler(**kwargs) + else: + result = handler(**kwargs) + return JSONResponse(content=result) + + return view_func self.server.add_api_route( route, - view_func, + make_view_func(handler, param_names), methods=methods, name=endpoint, include_in_schema=True, @@ -566,5 +486,5 @@ def origin(self): def path(self): return self._request.url.path - async def get_json(self): # async method retained - return await self._request.json() + def get_json(self): + return asyncio.run(self._request.json()) diff --git a/dash/backends/_flask.py b/dash/backends/_flask.py index b4bab46ff2..d9abe9c7ed 100644 --- a/dash/backends/_flask.py +++ b/dash/backends/_flask.py @@ -17,6 +17,7 @@ jsonify, g as flask_g, ) +from werkzeug.debug import tbtools from dash.fingerprint import check_fingerprint from dash import _validate @@ -67,13 +68,6 @@ def _invalid_resources_handler(err): return err.args[0], 404 def _get_traceback(self, secret, error: Exception): - try: - from werkzeug.debug import ( - tbtools, - ) # pylint: disable=import-outside-toplevel - except ImportError: - tbtools = None - def _get_skip(error): tb = error.__traceback__ skip = 1 @@ -238,7 +232,7 @@ async def _dispatch_async(): cb_ctx.dash_response.set_data(response_data) return cb_ctx.dash_response - if dash_app._use_async: + if dash_app._use_async: # pylint: disable=protected-access return _dispatch_async return _dispatch diff --git a/dash/backends/_quart.py b/dash/backends/_quart.py index eae8df9117..f417bc0d2e 100644 --- a/dash/backends/_quart.py +++ b/dash/backends/_quart.py @@ -1,32 +1,36 @@ from __future__ import annotations from contextvars import copy_context import typing as _t -import traceback import mimetypes import inspect import pkgutil import time import sys -import re +from typing import Any # Attempt top-level Quart imports; allow absence if user not using quart backend -from quart import ( - Quart, - Response, - jsonify, - request, - Blueprint, - g, -) - -if _t.TYPE_CHECKING: - from dash import Dash +try: + from quart import ( + Quart, + Response, + jsonify, + request, + Blueprint, + g, + ) +except ImportError: + Quart = None + Response = None + jsonify = None + request = None + Blueprint = None + g = None from dash.exceptions import PreventUpdate, InvalidResourceError from dash.fingerprint import check_fingerprint -from dash import _validate +from dash import _validate, Dash from .base_server import BaseDashServer -from typing import Any +from ._utils import format_traceback_html class QuartDashServer(BaseDashServer): @@ -68,113 +72,9 @@ def register_assets_blueprint( self.server.register_blueprint(bp) def _get_traceback(self, _secret, error: Exception): - tb = error.__traceback__ - errors = traceback.format_exception(type(error), error, tb) - pass_errs = [] - callback_handled = False - for err in errors: - if self.error_handling_mode == "prune": - if not callback_handled: - if "callback invoked" in str(err) and "_callback.py" in str(err): - callback_handled = True - continue - pass_errs.append(err) - formatted_tb = "".join(pass_errs) - error_type = type(error).__name__ - error_msg = str(error) - - # Parse traceback lines to group by file - file_cards = [] - pattern = re.compile(r' File "(.+)", line (\d+), in (\w+)') - lines = formatted_tb.split("\n") - current_file = None - card_lines = [] - - for line in lines[:-1]: # Skip the last line (error message) - match = pattern.match(line) - if match: - if current_file and card_lines: - file_cards.append((current_file, card_lines)) - current_file = ( - f"{match.group(1)} (line {match.group(2)}, in {match.group(3)})" - ) - card_lines = [line] - elif current_file: - card_lines.append(line) - if current_file and card_lines: - file_cards.append((current_file, card_lines)) - - cards_html = "" - for filename, card in file_cards: - cards_html += ( - f""" -
-
{filename}
-
"""
-                + "\n".join(card)
-                + """
-
- """ - ) - - html = f""" - - - - {error_type}: {error_msg} // Quart Debugger - - - -
-

{error_type}

-
-

{error_type}: {error_msg}

-
-

Traceback (most recent call last)

- {cards_html} -
{error_type}: {error_msg}
-
-

This is the Copy/Paste friendly version of the traceback.

- -
-
- The debugger caught an exception in your ASGI application. You can now - look at the traceback which led to the error. -
- -
- - - """ - return html + return format_traceback_html( + error, self.error_handling_mode, "Quart Debugger", "Quart" + ) def register_prune_error_handler(self, secret, prune_errors): if prune_errors: @@ -185,7 +85,9 @@ def register_prune_error_handler(self, secret, prune_errors): @self.server.errorhandler(Exception) async def _wrap_errors(error): if self.error_handling_mode == "ignore": - return Response("Internal server error.", status=500, content_type="text/plain") + return Response( + "Internal server error.", status=500, content_type="text/plain" + ) tb = self._get_traceback(secret, error) return Response(tb, status=500, content_type="text/html") @@ -337,14 +239,16 @@ async def _dispatch(): adapter = QuartRequestAdapter() body = await adapter.get_json() # pylint: disable=protected-access - g = dash_app._initialize_context(body) + cb_ctx = dash_app._initialize_context(body) # pylint: disable=protected-access - func = dash_app._prepare_callback(g, body) + func = dash_app._prepare_callback(cb_ctx, body) # pylint: disable=protected-access - args = dash_app._inputs_to_vals(g.inputs_list + g.states_list) + args = dash_app._inputs_to_vals(cb_ctx.inputs_list + cb_ctx.states_list) ctx = copy_context() # pylint: disable=protected-access - partial_func = dash_app._execute_callback(func, args, g.outputs_list, g) + partial_func = dash_app._execute_callback( + func, args, cb_ctx.outputs_list, cb_ctx + ) response_data = ctx.run(partial_func) if inspect.iscoroutine(response_data): # if user callback is async response_data = await response_data diff --git a/dash/backends/_utils.py b/dash/backends/_utils.py new file mode 100644 index 0000000000..0a5f4b0e76 --- /dev/null +++ b/dash/backends/_utils.py @@ -0,0 +1,108 @@ +import traceback +import re + + +def format_traceback_html(error, error_handling_mode, title, backend): + tb = error.__traceback__ + errors = traceback.format_exception(type(error), error, tb) + pass_errs = [] + callback_handled = False + for err in errors: + if error_handling_mode == "prune": + if not callback_handled: + if "callback invoked" in str(err) and "_callback.py" in str(err): + callback_handled = True + continue + pass_errs.append(err) + formatted_tb = "".join(pass_errs) + error_type = type(error).__name__ + error_msg = str(error) + # Parse traceback lines to group by file + file_cards = [] + pattern = re.compile(r' File "(.+)", line (\d+), in (\w+)') + lines = formatted_tb.split("\n") + current_file = None + card_lines = [] + for line in lines[:-1]: # Skip the last line (error message) + match = pattern.match(line) + if match: + if current_file and card_lines: + file_cards.append((current_file, card_lines)) + current_file = ( + f"{match.group(1)} (line {match.group(2)}, in {match.group(3)})" + ) + card_lines = [line] + elif current_file: + card_lines.append(line) + if current_file and card_lines: + file_cards.append((current_file, card_lines)) + cards_html = "" + for filename, card in file_cards: + cards_html += ( + f""" +
+
{filename}
+
"""
+            + "\n".join(card)
+            + """
+
+ """ + ) + html = f""" + + + + {error_type}: {error_msg} // {title} + + + +
+

{error_type}

+
+

{error_type}: {error_msg}

+
+

Traceback (most recent call last)

+ {cards_html} +
{error_type}: {error_msg}
+
+

This is the Copy/Paste friendly version of the traceback.

+ +
+
+ The debugger caught an exception in your ASGI application. You can now + look at the traceback which led to the error. +
+
+ Brought to you by DON'T PANIC, your + friendly {backend} powered traceback interpreter. +
+
+ + + """ + return html diff --git a/dash/testing/application_runners.py b/dash/testing/application_runners.py index 2956f1a4c0..6e6cc8b810 100644 --- a/dash/testing/application_runners.py +++ b/dash/testing/application_runners.py @@ -173,9 +173,9 @@ def run(): try: module = app.server.__class__.__module__ # FastAPI support - if not module.startswith("flask"): + if module.startswith("fastapi"): app.run(**options) - # Dash/Flask fallback + # Dash/Flask/Quart fallback else: app.run(threaded=True, **options) except SystemExit: @@ -237,9 +237,9 @@ def target(): try: module = app.server.__class__.__module__ # FastAPI support - if not module.startswith("flask"): + if module.startswith("fastapi"): app.run(**options) - # Dash/Flask fallback + # Dash/Flask/Quart fallback else: app.run(threaded=True, **options) except SystemExit: diff --git a/tests/backend_tests/test_custom_backend.py b/tests/backend_tests/test_custom_backend.py index aa590599b5..befff7734b 100644 --- a/tests/backend_tests/test_custom_backend.py +++ b/tests/backend_tests/test_custom_backend.py @@ -1,9 +1,12 @@ import pytest from dash import Dash, Input, Output, html, dcc -from fastapi import FastAPI import traceback import re -from dash.backends._fastapi import FastAPIDashServer + +try: + from dash.backends._fastapi import FastAPIDashServer +except ImportError: + FastAPIDashServer = None class CustomDashServer(FastAPIDashServer): From cd02cc5cbefefb7db331db625deee14897bf88b2 Mon Sep 17 00:00:00 2001 From: BSd3v <82055130+BSd3v@users.noreply.github.com> Date: Thu, 18 Sep 2025 14:05:28 -0400 Subject: [PATCH 072/166] adjustment for delayed config --- dash/dash.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dash/dash.py b/dash/dash.py index 2d71766b4e..6b022108b6 100644 --- a/dash/dash.py +++ b/dash/dash.py @@ -579,7 +579,7 @@ def __init__( # pylint: disable=too-many-statements, too-many-branches # tracks internally if a function already handled at least one request. self._got_first_request = {"pages": False, "setup_server": False} - if self.server is not None: + if server: self.init_app() self.logger.setLevel(logging.INFO) From 16b3c9e08743918b4e7af26cf43186955af1494f Mon Sep 17 00:00:00 2001 From: BSd3v <82055130+BSd3v@users.noreply.github.com> Date: Thu, 18 Sep 2025 14:27:20 -0400 Subject: [PATCH 073/166] fix typing error --- dash/backends/base_server.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/dash/backends/base_server.py b/dash/backends/base_server.py index cf2b62c2e7..2b11bc763b 100644 --- a/dash/backends/base_server.py +++ b/dash/backends/base_server.py @@ -1,11 +1,11 @@ from abc import ABC, abstractmethod -from typing import Any +from typing import Any, Dict class BaseDashServer(ABC): server_type: str server: Any - config: dict[str, Any] + config: Dict[str, Any] request_adapter: Any def __call__(self, *args, **kwargs) -> Any: From 493d1503630f3d87f5be95cd4a8dc641a84adad0 Mon Sep 17 00:00:00 2001 From: BSd3v <82055130+BSd3v@users.noreply.github.com> Date: Mon, 22 Sep 2025 08:30:16 -0400 Subject: [PATCH 074/166] fixes for pages --- dash/_pages.py | 13 +- dash/backends/_fastapi.py | 41 +++++-- dash/backends/_flask.py | 14 +++ dash/backends/_quart.py | 14 +++ dash/dash.py | 251 ++++++++++++++++++-------------------- 5 files changed, 181 insertions(+), 152 deletions(-) diff --git a/dash/_pages.py b/dash/_pages.py index 19a797bcf2..0a5f9d8c06 100644 --- a/dash/_pages.py +++ b/dash/_pages.py @@ -150,22 +150,13 @@ def _parse_path_variables(pathname, path_template): return dict(zip(var_names, variables)) -def _create_redirect_function(redirect_to): - def redirect(): - return flask.redirect(redirect_to, code=301) - - return redirect - - def _set_redirect(redirect_from, path): app = get_app() if redirect_from and len(redirect_from): for redirect in redirect_from: fullname = app.get_relative_path(redirect) - app.server.add_url_rule( - fullname, - fullname, - _create_redirect_function(app.get_relative_path(path)), + app.backend.add_redirect_rule( + app, fullname, app.get_relative_path(path) ) diff --git a/dash/backends/_fastapi.py b/dash/backends/_fastapi.py index dc7805501b..61b2d65a8f 100644 --- a/dash/backends/_fastapi.py +++ b/dash/backends/_fastapi.py @@ -15,7 +15,7 @@ try: from fastapi import FastAPI, Request, Response, Body - from fastapi.responses import JSONResponse + from fastapi.responses import JSONResponse, RedirectResponse from fastapi.staticfiles import StaticFiles from starlette.responses import Response as StarletteResponse from starlette.datastructures import MutableHeaders @@ -27,6 +27,7 @@ Response = None Body = None JSONResponse = None + RedirectResponse = None StaticFiles = None StarletteResponse = None MutableHeaders = None @@ -115,6 +116,7 @@ def __init__(self, server: FastAPI): self.server: FastAPI = server self.error_handling_mode = "ignore" self.request_adapter = FastAPIRequestAdapter + self._before_request_funcs = [] super().__init__() def __call__(self, *args: Any, **kwargs: Any): @@ -213,9 +215,13 @@ def add_url_rule( ) def before_request(self, func: Callable[[], Any] | None): - # FastAPI does not have before_request, but we can use middleware - self.server.add_middleware(CurrentRequestMiddleware) - self.server.middleware("http")(self._make_before_middleware(func)) + if func is not None: + self._before_request_funcs.append(func) + # Only add the middleware once + if not hasattr(self, "_before_middleware_added"): + self.server.add_middleware(CurrentRequestMiddleware) + self.server.middleware("http")(self._make_before_middleware()) + self._before_middleware_added = True def after_request(self, func: Callable[[], Any] | None): # FastAPI does not have after_request, but we can use middleware @@ -262,18 +268,20 @@ def make_response( def jsonify(self, obj: Any): return JSONResponse(content=obj) - def _make_before_middleware(self, _func: Callable[[], Any] | None): + def _make_before_middleware(self): async def middleware(request, call_next): + for func in self._before_request_funcs: + if inspect.iscoroutinefunction(func): + await func() + else: + func() try: response = await call_next(request) return response except PreventUpdate: - # No content, nothing to update return Response(status_code=204) - except (Exception) as e: # pylint: disable=broad-except - # Handle exceptions based on error_handling_mode + except Exception as e: if self.error_handling_mode in ["raise", "prune"]: - # Prune the traceback to remove internal Dash calls tb = self._get_traceback(None, e) return Response(content=tb, media_type="text/html", status_code=500) return JSONResponse( @@ -338,6 +346,21 @@ async def serve(request: Request, package_name: str, fingerprinted_path: str): name = "_dash-component-suites/{package_name}/{fingerprinted_path:path}" dash_app._add_url(name, serve) # pylint: disable=protected-access + def _create_redirect_function(self, redirect_to): + def _redirect(): + return RedirectResponse(url=redirect_to, status_code=301) + + return _redirect + + def add_redirect_rule(self, app, fullname, path): + self.server.add_api_route( + fullname, + self._create_redirect_function(app.get_relative_path(path)), + methods=["GET"], + name=fullname, + include_in_schema=False, + ) + def dispatch(self, dash_app: Dash): async def _dispatch(request: Request): # pylint: disable=protected-access diff --git a/dash/backends/_flask.py b/dash/backends/_flask.py index d9abe9c7ed..2f7e08acf5 100644 --- a/dash/backends/_flask.py +++ b/dash/backends/_flask.py @@ -16,6 +16,7 @@ request, jsonify, g as flask_g, + redirect, ) from werkzeug.debug import tbtools @@ -194,6 +195,19 @@ def serve(package_name, fingerprinted_path): serve, ) + def _create_redirect_function(self, redirect_to): + def _redirect(): + return redirect(redirect_to, code=301) + + return _redirect + + def add_redirect_rule(self, app, fullname, path): + self.server.add_url_rule( + fullname, + fullname, + self._create_redirect_function(app.get_relative_path(path)), + ) + # pylint: disable=unused-argument def dispatch(self, dash_app: Dash): def _dispatch(): diff --git a/dash/backends/_quart.py b/dash/backends/_quart.py index f417bc0d2e..c08a165234 100644 --- a/dash/backends/_quart.py +++ b/dash/backends/_quart.py @@ -17,6 +17,7 @@ request, Blueprint, g, + redirect, ) except ImportError: Quart = None @@ -233,6 +234,19 @@ async def serve(package_name, fingerprinted_path): serve, ) + def _create_redirect_function(self, redirect_to): + def _redirect(): + return redirect(redirect_to, code=301) + + return _redirect + + def add_redirect_rule(self, app, fullname, path): + self.server.add_url_rule( + fullname, + fullname, + self._create_redirect_function(app.get_relative_path(path)), + ) + # pylint: disable=unused-argument def dispatch(self, dash_app: Dash): # type: ignore[name-defined] Quart always async async def _dispatch(): diff --git a/dash/dash.py b/dash/dash.py index 6b022108b6..53818cb5fb 100644 --- a/dash/dash.py +++ b/dash/dash.py @@ -2241,7 +2241,8 @@ def enable_pages(self) -> None: if self.pages_folder: _import_layouts_from_pages(self.config.pages_folder) - def router(): + # Async version + async def router_async(): if self._got_first_request["pages"]: return self._got_first_request["pages"] = True @@ -2250,157 +2251,143 @@ def router(): "pathname_": Input(_ID_LOCATION, "pathname"), "search_": Input(_ID_LOCATION, "search"), } - inputs.update(self.routing_callback_inputs) # type: ignore[reportCallIssue] + inputs.update(self.routing_callback_inputs) - if self._use_async: + @self.callback( + Output(_ID_CONTENT, "children"), + Output(_ID_STORE, "data"), + inputs=inputs, + prevent_initial_call=True, + ) + async def update(pathname_, search_, **states): + query_parameters = _parse_query_string(search_) + page, path_variables = _path_to_page(self.strip_relative_path(pathname_)) + if page == {}: + for module, page in _pages.PAGE_REGISTRY.items(): + if module.split(".")[-1] == "not_found_404": + layout = page["layout"] + title = page["title"] + break + else: + layout = html.H1("404 - Page not found") + title = self.title + else: + layout = page.get("layout", "") + title = page["title"] - @self.callback( - Output(_ID_CONTENT, "children"), - Output(_ID_STORE, "data"), - inputs=inputs, - prevent_initial_call=True, - ) - async def update(pathname_, search_, **states): - """ - Updates dash.page_container layout on page navigation. - Updates the stored page title which will trigger the clientside callback to update the app title - """ - - query_parameters = _parse_query_string(search_) - page, path_variables = _path_to_page( - self.strip_relative_path(pathname_) + if callable(layout): + layout = await execute_async_function( + layout, + **{**(path_variables or {}), **query_parameters, **states}, ) + if callable(title): + title = await execute_async_function( + title, **{**(path_variables or {})} + ) + return layout, {"title": title} - # get layout - if page == {}: - for module, page in _pages.PAGE_REGISTRY.items(): - if module.split(".")[-1] == "not_found_404": - layout = page["layout"] - title = page["title"] - break - else: - layout = html.H1("404 - Page not found") - title = self.title - else: - layout = page.get("layout", "") - title = page["title"] - - if callable(layout): - layout = await execute_async_function( - layout, - **{**(path_variables or {}), **query_parameters, **states}, - ) - if callable(title): - title = await execute_async_function( - title, **{**(path_variables or {})} - ) + _validate.check_for_duplicate_pathnames(_pages.PAGE_REGISTRY) + _validate.validate_registry(_pages.PAGE_REGISTRY) - return layout, {"title": title} + if not self.config.suppress_callback_exceptions: + async def get_layouts(): + return [ + await execute_async_function(page["layout"]) + if callable(page["layout"]) else page["layout"] + for page in _pages.PAGE_REGISTRY.values() + ] + layouts = await get_layouts() + layouts += [ + self.layout() if callable(self.layout) else self.layout + ] + self.validation_layout = html.Div(layouts) + if _ID_CONTENT not in self.validation_layout: + raise Exception("`dash.page_container` not found in the layout") - _validate.check_for_duplicate_pathnames(_pages.PAGE_REGISTRY) - _validate.validate_registry(_pages.PAGE_REGISTRY) + self.clientside_callback( + """ + function(data) { + document.title = data.title + } + """, + Output(_ID_DUMMY, "children"), + Input(_ID_STORE, "data"), + ) - # Set validation_layout - if not self.config.suppress_callback_exceptions: - self.validation_layout = html.Div( - [ - ( - asyncio.run(execute_async_function(page["layout"])) - if callable(page["layout"]) - else page["layout"] - ) - for page in _pages.PAGE_REGISTRY.values() - ] - + [ - # pylint: disable=not-callable - self.layout() - if callable(self.layout) - else self.layout - ] - ) - if _ID_CONTENT not in self.validation_layout: - raise Exception("`dash.page_container` not found in the layout") - else: + # Sync version + def router_sync(): + if self._got_first_request["pages"]: + return + self._got_first_request["pages"] = True - @self.callback( - Output(_ID_CONTENT, "children"), - Output(_ID_STORE, "data"), - inputs=inputs, - prevent_initial_call=True, - ) - def update(pathname_, search_, **states): - """ - Updates dash.page_container layout on page navigation. - Updates the stored page title which will trigger the clientside callback to update the app title - """ - - query_parameters = _parse_query_string(search_) - page, path_variables = _path_to_page( - self.strip_relative_path(pathname_) - ) + inputs = { + "pathname_": Input(_ID_LOCATION, "pathname"), + "search_": Input(_ID_LOCATION, "search"), + } + inputs.update(self.routing_callback_inputs) - # get layout - if page == {}: - for module, page in _pages.PAGE_REGISTRY.items(): - if module.split(".")[-1] == "not_found_404": - layout = page["layout"] - title = page["title"] - break - else: - layout = html.H1("404 - Page not found") - title = self.title + @self.callback( + Output(_ID_CONTENT, "children"), + Output(_ID_STORE, "data"), + inputs=inputs, + prevent_initial_call=True, + ) + def update(pathname_, search_, **states): + query_parameters = _parse_query_string(search_) + page, path_variables = _path_to_page(self.strip_relative_path(pathname_)) + if page == {}: + for module, page in _pages.PAGE_REGISTRY.items(): + if module.split(".")[-1] == "not_found_404": + layout = page["layout"] + title = page["title"] + break else: - layout = page.get("layout", "") - title = page["title"] + layout = html.H1("404 - Page not found") + title = self.title + else: + layout = page.get("layout", "") + title = page["title"] - if callable(layout): - layout = layout( - **{**(path_variables or {}), **query_parameters, **states} - ) - if callable(title): - title = title(**(path_variables or {})) - - return layout, {"title": title} - - _validate.check_for_duplicate_pathnames(_pages.PAGE_REGISTRY) - _validate.validate_registry(_pages.PAGE_REGISTRY) - - # Set validation_layout - if not self.config.suppress_callback_exceptions: - layout = self.layout - if not isinstance(layout, list): - layout = [ - # pylint: disable=not-callable - self.layout() - if callable(self.layout) - else self.layout - ] - self.validation_layout = html.Div( - [ - ( - page["layout"]() - if callable(page["layout"]) - else page["layout"] - ) - for page in _pages.PAGE_REGISTRY.values() - ] - + layout - ) - if _ID_CONTENT not in self.validation_layout: - raise Exception("`dash.page_container` not found in the layout") + if callable(layout): + layout = layout( + **{**(path_variables or {}), **query_parameters, **states} + ) + if callable(title): + title = title(**(path_variables or {})) + return layout, {"title": title} + + _validate.check_for_duplicate_pathnames(_pages.PAGE_REGISTRY) + _validate.validate_registry(_pages.PAGE_REGISTRY) + + if not self.config.suppress_callback_exceptions: + layout = self.layout + if not isinstance(layout, list): + layout = [ + self.layout() if callable(self.layout) else self.layout + ] + self.validation_layout = html.Div( + [ + page["layout"]() if callable(page["layout"]) else page["layout"] + for page in _pages.PAGE_REGISTRY.values() + ] + layout + ) + if _ID_CONTENT not in self.validation_layout: + raise Exception("`dash.page_container` not found in the layout") - # Update the page title on page navigation self.clientside_callback( """ - function(data) {{ + function(data) { document.title = data.title - }} + } """, Output(_ID_DUMMY, "children"), Input(_ID_STORE, "data"), ) - self.backend.before_request(router) + if self._use_async: + self.backend.before_request(router_async) + else: + self.backend.before_request(router_sync) def __call__(self, *args, **kwargs): return self.backend.__call__(*args, **kwargs) From a10f86b46e04e369b1fd859266b177f574333b4a Mon Sep 17 00:00:00 2001 From: chgiesse <83552131+chgiesse@users.noreply.github.com> Date: Wed, 12 Nov 2025 12:14:29 +0100 Subject: [PATCH 075/166] Decouple flask (#7) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * ∙ Added has_request_context to base_server ∙ removed flask specific import in _validate and use backends.backend.has_request_context now * ∙ added context to request adapter ∙ callback context uses request adapter context now ∙ * added get_root_path to dash _utils removed flask.helpers.get_root_path usage * moved compress to server implementations flask fully decoupled from dash * fixed compress in quart and fastapi * Fixed server.config in fastapi to use config file * Update dash/backends/_fastapi.py Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> * removed unused flask import in pages * Update dash/backends/_quart.py Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> * removed flask specific return type to remove global dependency --------- Co-authored-by: Christian Giessel Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- dash/_callback_context.py | 11 +-- dash/_configs.py | 5 +- dash/_hooks.py | 3 +- dash/_pages.py | 10 +-- dash/_utils.py | 59 ++++++++++++++++ dash/_validate.py | 3 +- dash/backends/_fastapi.py | 24 +++++++ dash/backends/_flask.py | 31 +++++++++ dash/backends/_quart.py | 44 ++++++++++-- dash/backends/base_server.py | 131 +++++++++++++++++++---------------- dash/dash.py | 21 +----- 11 files changed, 241 insertions(+), 101 deletions(-) diff --git a/dash/_callback_context.py b/dash/_callback_context.py index 72b92e09e2..58c5c9bbd5 100644 --- a/dash/_callback_context.py +++ b/dash/_callback_context.py @@ -4,9 +4,8 @@ import contextvars import typing -import flask - from . import exceptions +from . import backends from ._utils import AttributeDict, stringify_id @@ -220,14 +219,15 @@ def record_timing(name, duration, description=None): :param description: A description of the resource. :type description: string or None """ - timing_information = getattr(flask.g, "timing_information", {}) + request = backends.backend.request_adapter() + timing_information = getattr(request.context, "timing_information", {}) if name in timing_information: raise KeyError(f'Duplicate resource name "{name}" found.') timing_information[name] = {"dur": round(duration * 1000), "desc": description} - setattr(flask.g, "timing_information", timing_information) + setattr(request.context, "timing_information", timing_information) @property @has_context @@ -250,7 +250,8 @@ def using_outputs_grouping(self): @property @has_context def timing_information(self): - return getattr(flask.g, "timing_information", {}) + request = backends.backend.request_adapter() + return getattr(request.context, "timing_information", {}) @has_context def set_props(self, component_id: typing.Union[str, dict], props: dict): diff --git a/dash/_configs.py b/dash/_configs.py index edbf7b50d1..c4ff8d59e6 100644 --- a/dash/_configs.py +++ b/dash/_configs.py @@ -1,6 +1,5 @@ +from ._utils import get_root_path import os -import flask - # noinspection PyCompatibility from . import exceptions from ._utils import AttributeDict @@ -127,7 +126,7 @@ def pages_folder_config(name, pages_folder, use_pages): if not pages_folder: return None is_custom_folder = str(pages_folder) != "pages" - pages_folder_path = os.path.join(flask.helpers.get_root_path(name), pages_folder) + pages_folder_path = os.path.join(get_root_path(name), pages_folder) if (use_pages or is_custom_folder) and not os.path.isdir(pages_folder_path): error_msg = f""" A folder called `{pages_folder}` does not exist. If a folder for pages is not diff --git a/dash/_hooks.py b/dash/_hooks.py index 98e5cf1ecd..4276c9f9f5 100644 --- a/dash/_hooks.py +++ b/dash/_hooks.py @@ -3,7 +3,6 @@ from importlib import metadata as _importlib_metadata import typing_extensions as _tx -import flask as _f from .exceptions import HookError from .resources import ResourceType @@ -125,7 +124,7 @@ def route( Add a route to the Dash server. """ - def wrap(func: _t.Callable[[], _f.Response]): + def wrap(func: _t.Callable[[], _t.Any]): _name = name or func.__name__ self.add_hook( "routes", diff --git a/dash/_pages.py b/dash/_pages.py index 0a5f9d8c06..be9d847309 100644 --- a/dash/_pages.py +++ b/dash/_pages.py @@ -9,13 +9,11 @@ from pathlib import Path from urllib.parse import parse_qs -import flask - from . import _validate from ._callback_context import context_value from ._get_app import get_app from ._get_paths import get_relative_path -from ._utils import AttributeDict +from ._utils import AttributeDict, get_root_path CONFIG = AttributeDict() PAGE_REGISTRY = collections.OrderedDict() @@ -98,7 +96,7 @@ def _path_to_module_name(path): def _infer_module_name(page_path): relative_path = page_path.split(CONFIG.pages_folder)[-1] module = _path_to_module_name(relative_path) - proj_root = flask.helpers.get_root_path(CONFIG.name) + proj_root = get_root_path(CONFIG.name) if CONFIG.pages_folder.startswith(proj_root): parent_path = CONFIG.pages_folder[len(proj_root) :] else: @@ -155,9 +153,7 @@ def _set_redirect(redirect_from, path): if redirect_from and len(redirect_from): for redirect in redirect_from: fullname = app.get_relative_path(redirect) - app.backend.add_redirect_rule( - app, fullname, app.get_relative_path(path) - ) + app.backend.add_redirect_rule(app, fullname, app.get_relative_path(path)) def register_page( diff --git a/dash/_utils.py b/dash/_utils.py index ef6c63c281..df090d252c 100644 --- a/dash/_utils.py +++ b/dash/_utils.py @@ -3,6 +3,7 @@ import sys import uuid import hashlib +import importlib from collections import abc import subprocess import logging @@ -12,6 +13,7 @@ import string import inspect import re +import os from html import escape from functools import wraps @@ -322,3 +324,60 @@ def pascal_case(name: Union[str, None]): return s[0].upper() + re.sub( r"[\-_\.]+([a-z])", lambda match: match.group(1).upper(), s[1:] ) + + +def get_root_path(import_name: str) -> str: + """Find the root path of a package, or the path that contains a + module. If it cannot be found, returns the current working + directory. + + Not to be confused with the value returned by :func:`find_package`. + + :meta private: + """ + # Module already imported and has a file attribute. Use that first. + mod = sys.modules.get(import_name) + + if mod is not None and hasattr(mod, "__file__") and mod.__file__ is not None: + return os.path.dirname(os.path.abspath(mod.__file__)) + + # Next attempt: check the loader. + try: + spec = importlib.util.find_spec(import_name) + + if spec is None: + raise ValueError + except (ImportError, ValueError): + loader = None + else: + loader = spec.loader + + # Loader does not exist or we're referring to an unloaded main + # module or a main module without path (interactive sessions), go + # with the current working directory. + if loader is None: + return os.getcwd() + + if hasattr(loader, "get_filename"): + filepath = loader.get_filename(import_name) # pyright: ignore + else: + # Fall back to imports. + __import__(import_name) + mod = sys.modules[import_name] + filepath = getattr(mod, "__file__", None) + + # If we don't have a file path it might be because it is a + # namespace package. In this case pick the root path from the + # first module that is contained in the package. + if filepath is None: + raise RuntimeError( + "No root path can be found for the provided module" + f" {import_name!r}. This can happen because the module" + " came from an import hook that does not provide file" + " name information or because it's a namespace package." + " In this case the root path needs to be explicitly" + " provided." + ) + + # filepath is import_name.py for a module, or __init__.py for a package. + return os.path.dirname(os.path.abspath(filepath)) # type: ignore[no-any-return] diff --git a/dash/_validate.py b/dash/_validate.py index d595cba0fc..8511b37150 100644 --- a/dash/_validate.py +++ b/dash/_validate.py @@ -3,7 +3,6 @@ import re from textwrap import dedent from keyword import iskeyword -import flask from ._grouping import grouping_len, map_grouping from ._no_update import NoUpdate @@ -511,7 +510,7 @@ def validate_use_pages(config): "`dash.register_page()` must be called after app instantiation" ) - if flask.has_request_context(): + if backends.backend.has_request_context(): raise exceptions.PageError( """ dash.register_page() can’t be called within a callback as it updates dash.page_registry, which is a global variable. diff --git a/dash/backends/_fastapi.py b/dash/backends/_fastapi.py index 61b2d65a8f..3de76f7b42 100644 --- a/dash/backends/_fastapi.py +++ b/dash/backends/_fastapi.py @@ -227,6 +227,13 @@ def after_request(self, func: Callable[[], Any] | None): # FastAPI does not have after_request, but we can use middleware self.server.middleware("http")(self._make_after_middleware(func)) + def has_request_context(self) -> bool: + try: + get_current_request() + return True + except RuntimeError: + return False + def run(self, dash_app: Dash, host, port, debug, **kwargs): frame = inspect.stack()[2] dev_tools = dash_app._dev_tools # pylint: disable=protected-access @@ -456,6 +463,16 @@ async def view_func(_request: Request, body: dict = Body(...)): include_in_schema=True, ) + def enable_compression(self) -> None: + from fastapi.middleware.gzip import GZipMiddleware + + self.server.add_middleware(GZipMiddleware, minimum_size=500) + config = _load_config() + if "COMPRESS_ALGORITHM" not in config: + config["COMPRESS_ALGORITHM"] = ["gzip"] + + _save_config(config) + class FastAPIRequestAdapter(RequestAdapter): def __init__(self): @@ -466,6 +483,13 @@ def __call__(self): self._request = get_current_request() return self + @property + def context(self): + if self._request is None: + raise RuntimeError("No active request in context") + + return self._request.state + @property def root(self): return str(self._request.base_url) diff --git a/dash/backends/_flask.py b/dash/backends/_flask.py index 2f7e08acf5..75e16371f5 100644 --- a/dash/backends/_flask.py +++ b/dash/backends/_flask.py @@ -1,6 +1,7 @@ from __future__ import annotations from contextvars import copy_context +from importlib_metadata import version as _get_distribution_version from typing import TYPE_CHECKING, Any, Callable, Dict import asyncio import pkgutil @@ -16,6 +17,7 @@ request, jsonify, g as flask_g, + has_request_context, redirect, ) from werkzeug.debug import tbtools @@ -24,8 +26,10 @@ from dash import _validate from dash.exceptions import PreventUpdate, InvalidResourceError from dash._callback import _invoke_callback, _async_invoke_callback +from dash._utils import parse_version from .base_server import BaseDashServer, RequestAdapter + if TYPE_CHECKING: # pragma: no cover - typing only from dash import Dash @@ -128,6 +132,9 @@ def after_request(self, func: Callable[[Any], Any]): # Flask after_request expects a function(response) -> response self.server.after_request(func) + def has_request_context(self) -> bool: + return has_request_context() + def run(self, dash_app: Dash, host: str, port: int, debug: bool, **kwargs: Any): self.server.run(host=host, port=port, debug=debug, **kwargs) @@ -318,6 +325,24 @@ def _sync_view_func(*args, handler=handler, **kwargs): route, endpoint=endpoint, view_func=view_func, methods=methods ) + def enable_compression(self) -> None: + try: + import flask_compress # pylint: disable=import-outside-toplevel + + Compress = flask_compress.Compress + Compress(self.server) + _flask_compress_version = parse_version( + _get_distribution_version("flask_compress") + ) + if not hasattr( + self.server.config, "COMPRESS_ALGORITHM" + ) and _flask_compress_version >= parse_version("1.6.0"): + self.server.config["COMPRESS_ALGORITHM"] = ["gzip"] + except ImportError as error: + raise ImportError( + "To use the compress option, you need to install dash[compress]" + ) from error + class FlaskRequestAdapter(RequestAdapter): """Flask implementation using property-based accessors.""" @@ -330,6 +355,12 @@ def __init__(self) -> None: def __call__(self, *args: Any, **kwds: Any): return self + @property + def context(self): + if not has_request_context(): + raise RuntimeError("No active request in context") + return flask_g + @property def args(self): return self._request.args diff --git a/dash/backends/_quart.py b/dash/backends/_quart.py index c08a165234..11af4bc169 100644 --- a/dash/backends/_quart.py +++ b/dash/backends/_quart.py @@ -1,4 +1,5 @@ from __future__ import annotations +from importlib_metadata import version as _get_distribution_version from contextvars import copy_context import typing as _t import mimetypes @@ -16,7 +17,8 @@ jsonify, request, Blueprint, - g, + g as quart_g, + has_request_context, redirect, ) except ImportError: @@ -25,10 +27,11 @@ jsonify = None request = None Blueprint = None - g = None + quart_g = None from dash.exceptions import PreventUpdate, InvalidResourceError from dash.fingerprint import check_fingerprint +from dash._utils import parse_version from dash import _validate, Dash from .base_server import BaseDashServer from ._utils import format_traceback_html @@ -95,15 +98,17 @@ async def _wrap_errors(error): def register_timing_hooks(self, _first_run: bool): # type: ignore[name-defined] parity with Flask factory @self.server.before_request async def _before_request(): # pragma: no cover - timing infra - if g is not None: - g.timing_information = { # type: ignore[attr-defined] + if quart_g is not None: + quart_g.timing_information = { # type: ignore[attr-defined] "__dash_server": {"dur": time.time(), "desc": None} } @self.server.after_request async def _after_request(response): # pragma: no cover - timing infra timing_information = ( - getattr(g, "timing_information", None) if g is not None else None + getattr(quart_g, "timing_information", None) + if quart_g is not None + else None ) if timing_information is None: return response @@ -181,6 +186,11 @@ async def _after(response): await result return response + def has_request_context(self) -> bool: + if has_request_context is None: + raise RuntimeError("Quart not installed; cannot check request context") + return has_request_context() + def run(self, dash_app: Dash, host: str, port: int, debug: bool, **kwargs: _t.Any): self.config = {"debug": debug, **kwargs} if debug else kwargs self.server.run(host=host, port=port, debug=debug, **kwargs) @@ -318,6 +328,24 @@ def _serve_default_favicon(self): pkgutil.get_data("dash", "favicon.ico"), content_type="image/x-icon" ) + def enable_compression(self) -> None: + try: + import quart_compress # pylint: disable=import-outside-toplevel + + Compress = quart_compress.Compress + Compress(self.server) + _flask_compress_version = parse_version( + _get_distribution_version("quart_compress") + ) + if not hasattr( + self.server.config, "COMPRESS_ALGORITHM" + ) and _flask_compress_version >= parse_version("1.6.0"): + self.server.config["COMPRESS_ALGORITHM"] = ["gzip"] + except ImportError as error: + raise ImportError( + "To use the compress option, you need to install quart_compress." + ) from error + class QuartRequestAdapter: def __init__(self) -> None: @@ -325,6 +353,12 @@ def __init__(self) -> None: if self._request is None: raise RuntimeError("Quart not installed; cannot access request context") + @property + def context(self): + if not has_request_context(): + raise RuntimeError("No active request in context") + return quart_g + @property def request(self) -> _t.Any: return self._request diff --git a/dash/backends/base_server.py b/dash/backends/base_server.py index 2b11bc763b..f571669bee 100644 --- a/dash/backends/base_server.py +++ b/dash/backends/base_server.py @@ -2,68 +2,15 @@ from typing import Any, Dict -class BaseDashServer(ABC): - server_type: str - server: Any - config: Dict[str, Any] - request_adapter: Any - - def __call__(self, *args, **kwargs) -> Any: - # Default: WSGI - return self.server(*args, **kwargs) - - @staticmethod - @abstractmethod - def create_app( - name: str = "__main__", config=None - ) -> Any: # pragma: no cover - interface - pass - - @abstractmethod - def register_assets_blueprint( - self, blueprint_name: str, assets_url_path: str, assets_folder: str - ) -> None: # pragma: no cover - interface - pass - - @abstractmethod - def register_error_handlers(self) -> None: # pragma: no cover - interface - pass - - @abstractmethod - def add_url_rule( - self, rule: str, view_func, endpoint=None, methods=None - ) -> None: # pragma: no cover - interface - pass - - @abstractmethod - def before_request(self, func) -> None: # pragma: no cover - interface - pass - - @abstractmethod - def after_request(self, func) -> None: # pragma: no cover - interface - pass - - @abstractmethod - def run( - self, dash_app, host: str, port: int, debug: bool, **kwargs - ) -> None: # pragma: no cover - interface - pass - - @abstractmethod - def make_response( - self, data, mimetype=None, content_type=None - ) -> Any: # pragma: no cover - interface - pass - - @abstractmethod - def jsonify(self, obj) -> Any: # pragma: no cover - interface - pass - - class RequestAdapter(ABC): - def __call__(self) -> Any: + def __call__(self) -> "RequestAdapter": return self + @property + @abstractmethod + def context(self) -> Any: # pragma: no cover - interface + raise NotImplementedError() + # Properties to be implemented in concrete adapters @property # pragma: no cover - interface @abstractmethod @@ -118,3 +65,69 @@ def origin(self): @abstractmethod def path(self) -> str: raise NotImplementedError() + + +class BaseDashServer(ABC): + server_type: str + server: Any + config: Dict[str, Any] + request_adapter: RequestAdapter + + def __call__(self, *args, **kwargs) -> Any: + # Default: WSGI + return self.server(*args, **kwargs) + + @staticmethod + @abstractmethod + def create_app( + name: str = "__main__", config=None + ) -> Any: # pragma: no cover - interface + pass + + @abstractmethod + def register_assets_blueprint( + self, blueprint_name: str, assets_url_path: str, assets_folder: str + ) -> None: # pragma: no cover - interface + pass + + @abstractmethod + def register_error_handlers(self) -> None: # pragma: no cover - interface + pass + + @abstractmethod + def add_url_rule( + self, rule: str, view_func, endpoint=None, methods=None + ) -> None: # pragma: no cover - interface + pass + + @abstractmethod + def before_request(self, func) -> None: # pragma: no cover - interface + pass + + @abstractmethod + def after_request(self, func) -> None: # pragma: no cover - interface + pass + + @abstractmethod + def has_request_context(self) -> bool: # pragma: no cover - interface + pass + + @abstractmethod + def run( + self, dash_app, host: str, port: int, debug: bool, **kwargs + ) -> None: # pragma: no cover - interface + pass + + @abstractmethod + def make_response( + self, data, mimetype=None, content_type=None + ) -> Any: # pragma: no cover - interface + pass + + @abstractmethod + def jsonify(self, obj) -> Any: # pragma: no cover - interface + pass + + @abstractmethod + def enable_compression(self) -> None: # pragma: no cover - interface + pass diff --git a/dash/dash.py b/dash/dash.py index acb557b977..f662e75efb 100644 --- a/dash/dash.py +++ b/dash/dash.py @@ -18,7 +18,6 @@ from typing import Any, Callable, Dict, Optional, Union, Sequence, Literal, List import asyncio -import flask from importlib_metadata import version as _get_distribution_version @@ -55,6 +54,7 @@ hooks_to_js_object, parse_version, get_caller_name, + get_root_path, ) from . import _callback from . import _get_paths @@ -463,7 +463,7 @@ def __init__( # pylint: disable=too-many-statements, too-many-branches self.config = AttributeDict( name=caller_name, assets_folder=os.path.join( - flask.helpers.get_root_path(caller_name), assets_folder + get_root_path(caller_name), assets_folder ), # type: ignore assets_url_path=assets_url_path, assets_ignore=assets_ignore, @@ -659,22 +659,7 @@ def init_app(self, app: Optional[Any] = None, **kwargs) -> None: self.config.assets_folder, ) if config.compress: - try: - import flask_compress # pylint: disable=import-outside-toplevel - - Compress = flask_compress.Compress - Compress(self.server) - _flask_compress_version = parse_version( - _get_distribution_version("flask_compress") - ) - if not hasattr( - self.server.config, "COMPRESS_ALGORITHM" - ) and _flask_compress_version >= parse_version("1.6.0"): - self.server.config["COMPRESS_ALGORITHM"] = ["gzip"] - except ImportError as error: - raise ImportError( - "To use the compress option, you need to install dash[compress]" - ) from error + self.backend.enable_compression() # type: ignore self.backend.register_error_handlers() self.backend.before_request(self._setup_server) From 3e22e930cc55ceda6a26a1c2e8705b2f7a4d70f2 Mon Sep 17 00:00:00 2001 From: BSd3v <82055130+BSd3v@users.noreply.github.com> Date: Sat, 24 Jan 2026 06:12:26 -0500 Subject: [PATCH 076/166] Update dash/backends/_utils.py Co-authored-by: Philippe Duval --- dash/backends/_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dash/backends/_utils.py b/dash/backends/_utils.py index 0a5f4b0e76..1191d21038 100644 --- a/dash/backends/_utils.py +++ b/dash/backends/_utils.py @@ -94,7 +94,7 @@ def format_traceback_html(error, error_handling_mode, title, backend):
- The debugger caught an exception in your ASGI application. You can now + The debugger caught an exception in your Dash application. You can now look at the traceback which led to the error.
From f5dfc07896fb359ab366a669b2164602433fd24e Mon Sep 17 00:00:00 2001 From: BSd3v <82055130+BSd3v@users.noreply.github.com> Date: Sat, 24 Jan 2026 06:25:26 -0500 Subject: [PATCH 077/166] updates for formatting --- dash/_callback.py | 3 +-- dash/_configs.py | 1 + dash/dash.py | 32 +++++++++++++++++++------------- 3 files changed, 21 insertions(+), 15 deletions(-) diff --git a/dash/_callback.py b/dash/_callback.py index 3f3904cf18..d203f72f5b 100644 --- a/dash/_callback.py +++ b/dash/_callback.py @@ -182,7 +182,6 @@ def callback( background_spec: Any = None - config_prevent_initial_callbacks = _kwargs.pop( "config_prevent_initial_callbacks", False ) @@ -694,7 +693,7 @@ def add_context(*args, **kwargs): args, kwargs, inputs_state_indices, has_output, insert_output ) - response: dict = {"multi": True} # type: ignore + response: dict = {"multi": True} # type: ignore jsonResponse: Optional[str] = None try: if background is not None: diff --git a/dash/_configs.py b/dash/_configs.py index c4ff8d59e6..225c604842 100644 --- a/dash/_configs.py +++ b/dash/_configs.py @@ -1,5 +1,6 @@ from ._utils import get_root_path import os + # noinspection PyCompatibility from . import exceptions from ._utils import AttributeDict diff --git a/dash/dash.py b/dash/dash.py index a65545f48d..7e1a972d09 100644 --- a/dash/dash.py +++ b/dash/dash.py @@ -211,6 +211,7 @@ def _do_skip(error): return "".join(traceback.format_exception(type(error), error, _do_skip(error))) + # Singleton signal to not update an output, alternative to PreventUpdate no_update = _callback.NoUpdate() # pylint: disable=protected-access @@ -723,7 +724,6 @@ def _handle_error(_): """Handle a halted callback and return an empty 204 response.""" return "", 204 - # To-Do add error handlers for these two scenarios # add handler for halted callbacks # self.backend.before_request(_handle_error) @@ -2395,10 +2395,13 @@ async def router_async(): Output(_ID_STORE, "data"), inputs=inputs, prevent_initial_call=True, - hidden=True,) + hidden=True, + ) async def update(pathname_, search_, **states): query_parameters = _parse_query_string(search_) - page, path_variables = _path_to_page(self.strip_relative_path(pathname_)) + page, path_variables = _path_to_page( + self.strip_relative_path(pathname_) + ) if page == {}: for module, page in _pages.PAGE_REGISTRY.items(): if module.split(".")[-1] == "not_found_404": @@ -2427,16 +2430,17 @@ async def update(pathname_, search_, **states): _validate.validate_registry(_pages.PAGE_REGISTRY) if not self.config.suppress_callback_exceptions: + async def get_layouts(): return [ await execute_async_function(page["layout"]) - if callable(page["layout"]) else page["layout"] + if callable(page["layout"]) + else page["layout"] for page in _pages.PAGE_REGISTRY.values() ] + layouts = await get_layouts() - layouts += [ - self.layout() if callable(self.layout) else self.layout - ] + layouts += [self.layout() if callable(self.layout) else self.layout] self.validation_layout = html.Div(layouts) if _ID_CONTENT not in self.validation_layout: raise Exception("`dash.page_container` not found in the layout") @@ -2468,10 +2472,13 @@ def router_sync(): Output(_ID_STORE, "data"), inputs=inputs, prevent_initial_call=True, - hidden=True,) + hidden=True, + ) def update(pathname_, search_, **states): query_parameters = _parse_query_string(search_) - page, path_variables = _path_to_page(self.strip_relative_path(pathname_)) + page, path_variables = _path_to_page( + self.strip_relative_path(pathname_) + ) if page == {}: for module, page in _pages.PAGE_REGISTRY.items(): if module.split(".")[-1] == "not_found_404": @@ -2499,14 +2506,13 @@ def update(pathname_, search_, **states): if not self.config.suppress_callback_exceptions: layout = self.layout if not isinstance(layout, list): - layout = [ - self.layout() if callable(self.layout) else self.layout - ] + layout = [self.layout() if callable(self.layout) else self.layout] self.validation_layout = html.Div( [ page["layout"]() if callable(page["layout"]) else page["layout"] for page in _pages.PAGE_REGISTRY.values() - ] + layout + ] + + layout ) if _ID_CONTENT not in self.validation_layout: raise Exception("`dash.page_container` not found in the layout") From 907aba294827cdd9851ad8ad8471e1d72c5900ac Mon Sep 17 00:00:00 2001 From: BSd3v <82055130+BSd3v@users.noreply.github.com> Date: Sat, 24 Jan 2026 06:34:11 -0500 Subject: [PATCH 078/166] adjusting favicon to be consistent between all and using `make_response` instead of splitting it out --- dash/backends/_fastapi.py | 5 ----- dash/backends/_flask.py | 5 ----- dash/backends/_quart.py | 7 ------- dash/dash.py | 7 ++++++- 4 files changed, 6 insertions(+), 18 deletions(-) diff --git a/dash/backends/_fastapi.py b/dash/backends/_fastapi.py index 3de76f7b42..1240772b2d 100644 --- a/dash/backends/_fastapi.py +++ b/dash/backends/_fastapi.py @@ -393,11 +393,6 @@ async def _dispatch(request: Request): return _dispatch - def _serve_default_favicon(self): - return Response( - content=pkgutil.get_data("dash", "favicon.ico"), media_type="image/x-icon" - ) - def register_timing_hooks(self, first_run: bool): if not first_run: return diff --git a/dash/backends/_flask.py b/dash/backends/_flask.py index 75e16371f5..d33819fd62 100644 --- a/dash/backends/_flask.py +++ b/dash/backends/_flask.py @@ -257,11 +257,6 @@ async def _dispatch_async(): return _dispatch_async return _dispatch - def _serve_default_favicon(self): - return Response( - pkgutil.get_data("dash", "favicon.ico"), content_type="image/x-icon" - ) - def register_timing_hooks(self, _first_run: bool): # Define timing hooks inside method scope and register them def _before_request() -> None: diff --git a/dash/backends/_quart.py b/dash/backends/_quart.py index 11af4bc169..6d61ee741b 100644 --- a/dash/backends/_quart.py +++ b/dash/backends/_quart.py @@ -321,13 +321,6 @@ async def sync_view_func(*args, **kwargs): route, endpoint=endpoint, view_func=view_func, methods=methods ) - def _serve_default_favicon(self): - if Response is None: - raise RuntimeError("Quart not installed; cannot generate Response") - return Response( - pkgutil.get_data("dash", "favicon.ico"), content_type="image/x-icon" - ) - def enable_compression(self) -> None: try: import quart_compress # pylint: disable=import-outside-toplevel diff --git a/dash/dash.py b/dash/dash.py index 7e1a972d09..181bdd0030 100644 --- a/dash/dash.py +++ b/dash/dash.py @@ -747,6 +747,11 @@ def _add_url(self, name: str, view_func: RouteCallable, methods=("GET",)) -> Non ) self.routes.append(full_name) + def _serve_default_favicon(self): + return self.backend.make_response( + pkgutil.get_data("dash", "favicon.ico"), content_type="image/x-icon" + ) + def _setup_routes(self): self.backend.setup_component_suites(self) self._add_url("_dash-layout", self.serve_layout) @@ -759,7 +764,7 @@ def _setup_routes(self): self._add_url("_reload-hash", self.serve_reload_hash) self._add_url( "_favicon.ico", - self.backend._serve_default_favicon, # pylint: disable=protected-access + self._serve_default_favicon, # pylint: disable=protected-access ) if self.config.health_endpoint is not None: self._add_url(self.config.health_endpoint, self.serve_health) From 9852bf61fbbe40e35dee0d6bec933a9c3c8df474 Mon Sep 17 00:00:00 2001 From: BSd3v <82055130+BSd3v@users.noreply.github.com> Date: Sat, 24 Jan 2026 06:38:34 -0500 Subject: [PATCH 079/166] renaming `dispatch` to `serve_callback` --- dash/backends/_fastapi.py | 2 +- dash/backends/_flask.py | 2 +- dash/backends/_quart.py | 2 +- dash/dash.py | 2 +- 4 files changed, 4 insertions(+), 4 deletions(-) diff --git a/dash/backends/_fastapi.py b/dash/backends/_fastapi.py index 1240772b2d..33705e85b9 100644 --- a/dash/backends/_fastapi.py +++ b/dash/backends/_fastapi.py @@ -368,7 +368,7 @@ def add_redirect_rule(self, app, fullname, path): include_in_schema=False, ) - def dispatch(self, dash_app: Dash): + def serve_callback(self, dash_app: Dash): async def _dispatch(request: Request): # pylint: disable=protected-access body = await request.json() diff --git a/dash/backends/_flask.py b/dash/backends/_flask.py index d33819fd62..396c020e6b 100644 --- a/dash/backends/_flask.py +++ b/dash/backends/_flask.py @@ -216,7 +216,7 @@ def add_redirect_rule(self, app, fullname, path): ) # pylint: disable=unused-argument - def dispatch(self, dash_app: Dash): + def serve_callback(self, dash_app: Dash): def _dispatch(): body = request.get_json() # pylint: disable=protected-access diff --git a/dash/backends/_quart.py b/dash/backends/_quart.py index 6d61ee741b..72c0caeecc 100644 --- a/dash/backends/_quart.py +++ b/dash/backends/_quart.py @@ -258,7 +258,7 @@ def add_redirect_rule(self, app, fullname, path): ) # pylint: disable=unused-argument - def dispatch(self, dash_app: Dash): # type: ignore[name-defined] Quart always async + def serve_callback(self, dash_app: Dash): # type: ignore[name-defined] Quart always async async def _dispatch(): adapter = QuartRequestAdapter() body = await adapter.get_json() diff --git a/dash/dash.py b/dash/dash.py index 181bdd0030..676803c6e1 100644 --- a/dash/dash.py +++ b/dash/dash.py @@ -758,7 +758,7 @@ def _setup_routes(self): self._add_url("_dash-dependencies", self.dependencies) self._add_url( "_dash-update-component", - self.backend.dispatch(self), + self.backend.serve_callback(self), ["POST"], ) self._add_url("_reload-hash", self.serve_reload_hash) From d961899475f06d4cde3ba5d885e9cc91aa26bbc3 Mon Sep 17 00:00:00 2001 From: BSd3v <82055130+BSd3v@users.noreply.github.com> Date: Sat, 24 Jan 2026 06:41:02 -0500 Subject: [PATCH 080/166] adding a config description for `debug=True` --- dash/backends/_fastapi.py | 1 + 1 file changed, 1 insertion(+) diff --git a/dash/backends/_fastapi.py b/dash/backends/_fastapi.py index 33705e85b9..62b8db8f6a 100644 --- a/dash/backends/_fastapi.py +++ b/dash/backends/_fastapi.py @@ -84,6 +84,7 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: # # Internal config helpers (local to this file) +# This is used to persist dev tools config between reloads, since uvicorn runs a new process _CONFIG_PATH = os.path.join(os.path.dirname(__file__), "dash_config.json") From 8f038756d98c523246db1a14aaac738bce2e8b87 Mon Sep 17 00:00:00 2001 From: BSd3v <82055130+BSd3v@users.noreply.github.com> Date: Sat, 24 Jan 2026 07:08:13 -0500 Subject: [PATCH 081/166] fastapi requires `server` to be defined --- dash/backends/_fastapi.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/dash/backends/_fastapi.py b/dash/backends/_fastapi.py index 62b8db8f6a..d88a54bed7 100644 --- a/dash/backends/_fastapi.py +++ b/dash/backends/_fastapi.py @@ -249,10 +249,10 @@ def run(self, dash_app: Dash, host, port, debug, **kwargs): if kwargs.get("reload"): # Dynamically determine the module name from the file path file_path = frame.filename - spec = spec_from_file_location("app", file_path) - module_name = spec.name if spec and getattr(spec, "name", None) else "app" + rel_path = os.path.relpath(file_path, os.getcwd()) + module_name = os.path.splitext(rel_path)[0].replace(os.sep, ".") uvicorn.run( - f"{module_name}:app.server", + f"{module_name}:server", host=host, port=port, **kwargs, From c87e1a767a47ffb14a13a0b947febc020410ab5a Mon Sep 17 00:00:00 2001 From: BSd3v <82055130+BSd3v@users.noreply.github.com> Date: Sat, 24 Jan 2026 07:21:35 -0500 Subject: [PATCH 082/166] changing import errors --- dash/backends/_fastapi.py | 17 +++-------------- dash/backends/_quart.py | 9 +++------ 2 files changed, 6 insertions(+), 20 deletions(-) diff --git a/dash/backends/_fastapi.py b/dash/backends/_fastapi.py index d88a54bed7..aa5b75768e 100644 --- a/dash/backends/_fastapi.py +++ b/dash/backends/_fastapi.py @@ -22,20 +22,9 @@ from starlette.types import ASGIApp, Scope, Receive, Send import uvicorn except ImportError: - FastAPI = None - Request = None - Response = None - Body = None - JSONResponse = None - RedirectResponse = None - StaticFiles = None - StarletteResponse = None - MutableHeaders = None - ASGIApp = None - Scope = None - Receive = None - Send = None - uvicorn = None + raise ImportError( + "All dependencies not installed. Please install it with `dash[fastapi]` to use the FastAPI backend." + ) from None from dash.fingerprint import check_fingerprint from dash import _validate diff --git a/dash/backends/_quart.py b/dash/backends/_quart.py index 72c0caeecc..911d228e21 100644 --- a/dash/backends/_quart.py +++ b/dash/backends/_quart.py @@ -22,12 +22,9 @@ redirect, ) except ImportError: - Quart = None - Response = None - jsonify = None - request = None - Blueprint = None - quart_g = None + raise ImportError( + "All dependencies not installed. Please install it with `dash[quart]` to use the Quart backend." + ) from None from dash.exceptions import PreventUpdate, InvalidResourceError from dash.fingerprint import check_fingerprint From cad3cbaab2c5ede4511e8e0ea300f1af562ce12c Mon Sep 17 00:00:00 2001 From: BSd3v <82055130+BSd3v@users.noreply.github.com> Date: Sat, 24 Jan 2026 07:46:57 -0500 Subject: [PATCH 083/166] making the path name for the config file more dynamic in order to not replace --- dash/backends/_fastapi.py | 34 ++++++++++++++++++---------------- 1 file changed, 18 insertions(+), 16 deletions(-) diff --git a/dash/backends/_fastapi.py b/dash/backends/_fastapi.py index aa5b75768e..eae11e8b16 100644 --- a/dash/backends/_fastapi.py +++ b/dash/backends/_fastapi.py @@ -71,31 +71,25 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: # finally: reset_current_request(token) - -# Internal config helpers (local to this file) -# This is used to persist dev tools config between reloads, since uvicorn runs a new process -_CONFIG_PATH = os.path.join(os.path.dirname(__file__), "dash_config.json") - - -def _save_config(config): - with open(_CONFIG_PATH, "w", encoding="utf-8") as f: +def _save_config(config_path, config): + with open(config_path, "w", encoding="utf-8") as f: json.dump(config, f) -def _load_config(): +def _load_config(config_path): resp = {"debug": False} try: - if os.path.exists(_CONFIG_PATH): - with open(_CONFIG_PATH, "r", encoding="utf-8") as f: + if os.path.exists(config_path): + with open(config_path, "r", encoding="utf-8") as f: resp = json.load(f) except (json.JSONDecodeError, OSError): pass # ignore errors return resp -def _remove_config(): +def _remove_config(config_path): try: - os.remove(_CONFIG_PATH) + os.remove(config_path) except FileNotFoundError: pass @@ -107,6 +101,14 @@ def __init__(self, server: FastAPI): self.error_handling_mode = "ignore" self.request_adapter = FastAPIRequestAdapter self._before_request_funcs = [] + + fname = inspect.stack()[2] + file_path = fname.filename + rel_path = os.path.relpath(file_path, os.getcwd()) + module_name = os.path.splitext(rel_path)[0].replace(os.sep, ".") + # Internal config helpers (local to this file) + # This is used to persist dev tools config between reloads, since uvicorn runs a new process + self._CONFIG_PATH = os.path.join(os.path.dirname(file_path), f"_{module_name}_dash_config.json") super().__init__() def __call__(self, *args: Any, **kwargs: Any): @@ -169,12 +171,12 @@ async def index(_request: Request): def setup_catchall(self, dash_app: Dash): @self.server.on_event("shutdown") def cleanup_config(): - _remove_config() + _remove_config(self._CONFIG_PATH) @self.server.on_event("startup") def _setup_catchall(): dash_app.enable_dev_tools( - **_load_config(), first_run=False + **_load_config(self._CONFIG_PATH), first_run=False ) # do this to make sure dev tools are enabled async def catchall(_request: Request): @@ -231,7 +233,7 @@ def run(self, dash_app: Dash, host, port, debug, **kwargs): {"debug": debug} if debug else {"debug": False}, **{f"dev_tools_{k}": v for k, v in dev_tools.items()}, ) - _save_config(config) + _save_config(self._CONFIG_PATH, config) if debug: if kwargs.get("reload") is None: kwargs["reload"] = True From ee1e6116757ca1441be17946b557397f75e339b8 Mon Sep 17 00:00:00 2001 From: BSd3v <82055130+BSd3v@users.noreply.github.com> Date: Mon, 26 Jan 2026 10:11:41 -0500 Subject: [PATCH 084/166] Update dash/backends/_quart.py Co-authored-by: Philippe Duval --- dash/backends/_quart.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/dash/backends/_quart.py b/dash/backends/_quart.py index 911d228e21..5abda7e30e 100644 --- a/dash/backends/_quart.py +++ b/dash/backends/_quart.py @@ -21,10 +21,10 @@ has_request_context, redirect, ) -except ImportError: +except ImportError as _err: raise ImportError( "All dependencies not installed. Please install it with `dash[quart]` to use the Quart backend." - ) from None + ) from _err from dash.exceptions import PreventUpdate, InvalidResourceError from dash.fingerprint import check_fingerprint From 476d1dd4bb3562c22dbaf733fd0b4311b7743be5 Mon Sep 17 00:00:00 2001 From: BSd3v <82055130+BSd3v@users.noreply.github.com> Date: Mon, 26 Jan 2026 11:31:57 -0500 Subject: [PATCH 085/166] adjustments for fastapi to use a controllable subprocess vs running with uvicorn, this allows for unloading the config file properly --- dash/backends/_fastapi.py | 66 ++++++++++++++++++++++++--------------- 1 file changed, 41 insertions(+), 25 deletions(-) diff --git a/dash/backends/_fastapi.py b/dash/backends/_fastapi.py index eae11e8b16..c21cf67473 100644 --- a/dash/backends/_fastapi.py +++ b/dash/backends/_fastapi.py @@ -12,6 +12,7 @@ from importlib.util import spec_from_file_location import json import os +import subprocess try: from fastapi import FastAPI, Request, Response, Body @@ -71,6 +72,7 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: # finally: reset_current_request(token) + def _save_config(config_path, config): with open(config_path, "w", encoding="utf-8") as f: json.dump(config, f) @@ -104,11 +106,16 @@ def __init__(self, server: FastAPI): fname = inspect.stack()[2] file_path = fname.filename - rel_path = os.path.relpath(file_path, os.getcwd()) - module_name = os.path.splitext(rel_path)[0].replace(os.sep, ".") - # Internal config helpers (local to this file) - # This is used to persist dev tools config between reloads, since uvicorn runs a new process - self._CONFIG_PATH = os.path.join(os.path.dirname(file_path), f"_{module_name}_dash_config.json") + + # Manually build the config directory path + home_dir = os.path.expanduser("~") + config_dir = os.path.join(home_dir, ".local", "share", "plotly-dash-configs") + os.makedirs(config_dir, exist_ok=True) + + # Hash the file path for a unique config filename + hash_digest = hashlib.sha256(file_path.encode("utf-8")).hexdigest() + config_filename = f"{hash_digest}.json" + self._CONFIG_PATH = os.path.join(config_dir, config_filename) super().__init__() def __call__(self, *args: Any, **kwargs: Any): @@ -117,6 +124,9 @@ def __call__(self, *args: Any, **kwargs: Any): return self.server(*args, **kwargs) raise TypeError("FastAPI app must be called with (scope, receive, send)") + def _cleanup_config(self): + _remove_config(self._CONFIG_PATH) + @staticmethod def create_app(name: str = "__main__", config: Dict[str, Any] | None = None): app = FastAPI() @@ -169,10 +179,6 @@ async def index(_request: Request): dash_app._add_url("", index, methods=["GET"]) def setup_catchall(self, dash_app: Dash): - @self.server.on_event("shutdown") - def cleanup_config(): - _remove_config(self._CONFIG_PATH) - @self.server.on_event("startup") def _setup_catchall(): dash_app.enable_dev_tools( @@ -228,28 +234,38 @@ def has_request_context(self) -> bool: def run(self, dash_app: Dash, host, port, debug, **kwargs): frame = inspect.stack()[2] - dev_tools = dash_app._dev_tools # pylint: disable=protected-access + dev_tools = dash_app._dev_tools config = dict( {"debug": debug} if debug else {"debug": False}, **{f"dev_tools_{k}": v for k, v in dev_tools.items()}, ) _save_config(self._CONFIG_PATH, config) - if debug: - if kwargs.get("reload") is None: - kwargs["reload"] = True + if debug and kwargs.get("reload") is None: + kwargs["reload"] = True + + file_path = frame.filename + rel_path = os.path.relpath(file_path, os.getcwd()) + module_name = os.path.splitext(rel_path)[0].replace(os.sep, ".") + uvicorn_args = [ + sys.executable, + "-m", + "uvicorn", + f"{module_name}:server", + "--host", + str(host), + "--port", + str(port), + ] if kwargs.get("reload"): - # Dynamically determine the module name from the file path - file_path = frame.filename - rel_path = os.path.relpath(file_path, os.getcwd()) - module_name = os.path.splitext(rel_path)[0].replace(os.sep, ".") - uvicorn.run( - f"{module_name}:server", - host=host, - port=port, - **kwargs, - ) - else: - uvicorn.run(self.server, host=host, port=port, **kwargs) + uvicorn_args.append("--reload") + + # Add any other kwargs as CLI args if needed + + try: + proc = subprocess.Popen(uvicorn_args, env=os.environ.copy()) + proc.wait() + finally: + self._cleanup_config() def make_response( self, From 1d26d72c75900ce615a464337d508b630fb6cb47 Mon Sep 17 00:00:00 2001 From: philippe Date: Mon, 26 Jan 2026 11:44:39 -0500 Subject: [PATCH 086/166] fix dcc build --- components/dash-core-components/package.json | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/components/dash-core-components/package.json b/components/dash-core-components/package.json index 063fe6d661..fc4d8a967f 100644 --- a/components/dash-core-components/package.json +++ b/components/dash-core-components/package.json @@ -101,6 +101,6 @@ "react-dom": "16 - 19" }, "browserslist": [ - "last 9 years and not dead" + "last 10 years and not dead" ] } From b3a47bd58719887e09435ce841ae4b83f46f0b80 Mon Sep 17 00:00:00 2001 From: philippe Date: Thu, 5 Feb 2026 11:14:56 -0500 Subject: [PATCH 087/166] fix request_adapter calls --- dash/_callback.py | 22 ++++++++-------------- dash/_validate.py | 4 ++-- dash/backends/__init__.py | 4 ++-- dash/dash.py | 9 ++++----- 4 files changed, 16 insertions(+), 23 deletions(-) diff --git a/dash/_callback.py b/dash/_callback.py index 3684247110..6895623a74 100644 --- a/dash/_callback.py +++ b/dash/_callback.py @@ -1,15 +1,9 @@ -from typing import Callable, Optional, Any, List, Tuple, Union +from typing import Callable, Optional, Any, List, Tuple, Union, Dict from functools import wraps import collections import hashlib import inspect -from functools import wraps - -from typing import Callable, Optional, Any, List, Tuple, Union, Dict - -import asyncio - from .dependencies import ( handle_callback_args, handle_grouped_callback_args, @@ -363,7 +357,7 @@ def _initialize_context(args, kwargs, inputs_state_indices, has_output, insert_o def _get_callback_manager( kwargs: dict, background: dict -) -> Union[BaseBackgroundCallbackManager, None]: +) -> BaseBackgroundCallbackManager: """Set up the background callback and manage jobs.""" callback_manager = background.get( "manager", kwargs.get("background_callback_manager", None) @@ -379,7 +373,7 @@ def _get_callback_manager( " and store results on redis.\n" ) - adapter = backends.request_adapter() + adapter = backends.backend.request_adapter() old_job = adapter.args.getlist("oldJob") if hasattr(adapter.args, "getlist") else [] if old_job: @@ -439,7 +433,7 @@ def _setup_background_callback( def _progress_background_callback(response, callback_manager, background): progress_outputs = background.get("progress") - adapter = backends.request_adapter() + adapter = backends.backend.request_adapter() cache_key = adapter.args.get("cacheKey") if progress_outputs: @@ -457,7 +451,7 @@ def _update_background_callback( """Set up the background callback and manage jobs.""" callback_manager = _get_callback_manager(kwargs, background) - adapter = backends.request_adapter() + adapter = backends.backend.request_adapter() cache_key = adapter.args.get("cacheKey") if adapter else None job_id = adapter.args.get("job") if adapter else None @@ -479,7 +473,7 @@ def _handle_rest_background_callback( multi, has_update=False, ): - adapter = backends.request_adapter() + adapter = backends.backend.request_adapter() cache_key = adapter.args.get("cacheKey") if adapter else None job_id = adapter.args.get("job") if adapter else None # Must get job_running after get_result since get_results terminates it. @@ -697,7 +691,7 @@ def add_context(*args, **kwargs): jsonResponse: Optional[str] = None try: if background is not None: - adapter = backends.request_adapter() + adapter = backends.backend.request_adapter() if not (adapter and adapter.args.get("cacheKey")): return _setup_background_callback( kwargs, @@ -769,7 +763,7 @@ async def async_add_context(*args, **kwargs): try: if background is not None: - adapter = backends.request_adapter() + adapter = backends.backend.request_adapter() if not (adapter and adapter.args.get("cacheKey")): return _setup_background_callback( kwargs, diff --git a/dash/_validate.py b/dash/_validate.py index ceddb74030..bb76f896e1 100644 --- a/dash/_validate.py +++ b/dash/_validate.py @@ -590,14 +590,14 @@ def _valid(out): def check_async(use_async): if use_async is None: try: - import asgiref # pylint: disable=unused-import, import-outside-toplevel # noqa + import asgiref # type: ignore[import-not-found] # pylint: disable=unused-import, import-outside-toplevel # noqa use_async = True except ImportError: pass elif use_async: try: - import asgiref # pylint: disable=unused-import, import-outside-toplevel # noqa + import asgiref # type: ignore[import-not-found] # pylint: disable=unused-import, import-outside-toplevel # noqa except ImportError as exc: raise Exception( "You are trying to use dash[async] without having installed the requirements please install via: `pip install dash[async]`" diff --git a/dash/backends/__init__.py b/dash/backends/__init__.py index e8b007a50b..cde81b0fd4 100644 --- a/dash/backends/__init__.py +++ b/dash/backends/__init__.py @@ -43,7 +43,7 @@ def _is_flask_instance(obj): def _is_fastapi_instance(obj): try: # pylint: disable=import-outside-toplevel - from fastapi import FastAPI + from fastapi import FastAPI # type: ignore[import-not-found] return isinstance(obj, FastAPI) except ImportError: @@ -53,7 +53,7 @@ def _is_fastapi_instance(obj): def _is_quart_instance(obj): try: # pylint: disable=import-outside-toplevel - from quart import Quart + from quart import Quart # type: ignore[import-not-found] return isinstance(obj, Quart) except ImportError: diff --git a/dash/dash.py b/dash/dash.py index 7949540dfc..ce39a3f446 100644 --- a/dash/dash.py +++ b/dash/dash.py @@ -20,8 +20,6 @@ import traceback -from importlib_metadata import version as _get_distribution_version - from dash import dcc from dash import html from dash import dash_table @@ -53,7 +51,6 @@ convert_to_AttributeDict, gen_salt, hooks_to_js_object, - parse_version, get_caller_name, get_root_path, ) @@ -1182,7 +1179,7 @@ def index(self, *_args, **_kwargs): renderer = self._generate_renderer() title = self.title # Refactored: direct access to global request adapter - request = backends.request_adapter() + request = backends.backend.request_adapter() if self.use_pages and self.config.include_pages_meta and request: metas = _page_meta_tags(self, request) + metas @@ -1398,7 +1395,7 @@ def _inputs_to_vals(self, inputs): # pylint: disable=R0915 def _initialize_context(self, body): """Initialize the global context for the request.""" - adapter = backends.request_adapter() + adapter = backends.backend.request_adapter() g = AttributeDict({}) g.inputs_list = body.get("inputs", []) g.states_list = body.get("state", []) @@ -2445,6 +2442,7 @@ async def get_layouts(): ] layouts = await get_layouts() + # pylint: disable=not-callable layouts += [self.layout() if callable(self.layout) else self.layout] self.validation_layout = html.Div(layouts) if _ID_CONTENT not in self.validation_layout: @@ -2512,6 +2510,7 @@ def update(pathname_, search_, **states): if not self.config.suppress_callback_exceptions: layout = self.layout if not isinstance(layout, list): + # pylint: disable=not-callable layout = [self.layout() if callable(self.layout) else self.layout] self.validation_layout = html.Div( [ From ef9ff4fed3bee47e2cb203e60dc0b07034b04b43 Mon Sep 17 00:00:00 2001 From: philippe Date: Thu, 5 Feb 2026 11:24:56 -0500 Subject: [PATCH 088/166] import/lint fixes --- dash/_configs.py | 3 ++- dash/backends/_fastapi.py | 8 +++----- dash/backends/_flask.py | 9 ++++++--- dash/backends/_quart.py | 7 +++++-- dash/backends/base_server.py | 4 ++-- 5 files changed, 18 insertions(+), 13 deletions(-) diff --git a/dash/_configs.py b/dash/_configs.py index 225c604842..107b8308f5 100644 --- a/dash/_configs.py +++ b/dash/_configs.py @@ -1,6 +1,7 @@ -from ._utils import get_root_path import os +from ._utils import get_root_path + # noinspection PyCompatibility from . import exceptions from ._utils import AttributeDict diff --git a/dash/backends/_fastapi.py b/dash/backends/_fastapi.py index c21cf67473..8ebf38ed96 100644 --- a/dash/backends/_fastapi.py +++ b/dash/backends/_fastapi.py @@ -9,7 +9,6 @@ import inspect import pkgutil import time -from importlib.util import spec_from_file_location import json import os import subprocess @@ -21,11 +20,10 @@ from starlette.responses import Response as StarletteResponse from starlette.datastructures import MutableHeaders from starlette.types import ASGIApp, Scope, Receive, Send - import uvicorn -except ImportError: +except ImportError as _err: raise ImportError( "All dependencies not installed. Please install it with `dash[fastapi]` to use the FastAPI backend." - ) from None + ) from _err from dash.fingerprint import check_fingerprint from dash import _validate @@ -234,7 +232,7 @@ def has_request_context(self) -> bool: def run(self, dash_app: Dash, host, port, debug, **kwargs): frame = inspect.stack()[2] - dev_tools = dash_app._dev_tools + dev_tools = dash_app._dev_tools # pylint-disable=W0212 config = dict( {"debug": debug} if debug else {"debug": False}, **{f"dev_tools_{k}": v for k, v in dev_tools.items()}, diff --git a/dash/backends/_flask.py b/dash/backends/_flask.py index 396c020e6b..f1494a8335 100644 --- a/dash/backends/_flask.py +++ b/dash/backends/_flask.py @@ -1,8 +1,5 @@ from __future__ import annotations -from contextvars import copy_context -from importlib_metadata import version as _get_distribution_version -from typing import TYPE_CHECKING, Any, Callable, Dict import asyncio import pkgutil import sys @@ -10,6 +7,12 @@ import time import inspect import traceback + +from contextvars import copy_context +from typing import TYPE_CHECKING, Any, Callable, Dict + +from importlib_metadata import version as _get_distribution_version + from flask import ( Flask, Blueprint, diff --git a/dash/backends/_quart.py b/dash/backends/_quart.py index 5abda7e30e..3a9df59578 100644 --- a/dash/backends/_quart.py +++ b/dash/backends/_quart.py @@ -1,14 +1,17 @@ from __future__ import annotations -from importlib_metadata import version as _get_distribution_version -from contextvars import copy_context + import typing as _t import mimetypes import inspect import pkgutil import time import sys + +from contextvars import copy_context from typing import Any +from importlib_metadata import version as _get_distribution_version + # Attempt top-level Quart imports; allow absence if user not using quart backend try: from quart import ( diff --git a/dash/backends/base_server.py b/dash/backends/base_server.py index f571669bee..fc149b3847 100644 --- a/dash/backends/base_server.py +++ b/dash/backends/base_server.py @@ -1,5 +1,5 @@ from abc import ABC, abstractmethod -from typing import Any, Dict +from typing import Any, Dict, Type class RequestAdapter(ABC): @@ -71,7 +71,7 @@ class BaseDashServer(ABC): server_type: str server: Any config: Dict[str, Any] - request_adapter: RequestAdapter + request_adapter: Type[RequestAdapter] def __call__(self, *args, **kwargs) -> Any: # Default: WSGI From 25e1f453ac4d798872262c88e125c69ba7f1f1a1 Mon Sep 17 00:00:00 2001 From: philippe Date: Thu, 5 Feb 2026 13:29:34 -0500 Subject: [PATCH 089/166] fix linting --- dash/_callback.py | 2 +- dash/backends/__init__.py | 4 +++- dash/backends/_fastapi.py | 23 ++++++++++++++--------- dash/backends/_flask.py | 5 ++--- dash/backends/_quart.py | 12 ++++++------ dash/backends/base_server.py | 18 +++++++++++++++--- dash/dash.py | 1 + 7 files changed, 42 insertions(+), 23 deletions(-) diff --git a/dash/_callback.py b/dash/_callback.py index 6895623a74..0ce4ee59a4 100644 --- a/dash/_callback.py +++ b/dash/_callback.py @@ -276,7 +276,7 @@ def insert_callback( no_output=False, optional=False, hidden=None, -): +) -> str: if prevent_initial_call is None: prevent_initial_call = config_prevent_initial_callbacks diff --git a/dash/backends/__init__.py b/dash/backends/__init__.py index cde81b0fd4..3a12e7939a 100644 --- a/dash/backends/__init__.py +++ b/dash/backends/__init__.py @@ -1,4 +1,6 @@ import importlib +from typing import Type + from .base_server import BaseDashServer @@ -12,7 +14,7 @@ } -def get_backend(name: str) -> BaseDashServer: +def get_backend(name: str) -> Type[BaseDashServer]: module_name, server_class = _backend_imports[name.lower()] try: module = importlib.import_module(module_name) diff --git a/dash/backends/_fastapi.py b/dash/backends/_fastapi.py index 8ebf38ed96..dfeb26e13c 100644 --- a/dash/backends/_fastapi.py +++ b/dash/backends/_fastapi.py @@ -94,13 +94,14 @@ def _remove_config(config_path): pass -class FastAPIDashServer(BaseDashServer): +class FastAPIDashServer(BaseDashServer[FastAPI]): def __init__(self, server: FastAPI): + super().__init__(server) self.server_type = "fastapi" - self.server: FastAPI = server self.error_handling_mode = "ignore" self.request_adapter = FastAPIRequestAdapter self._before_request_funcs = [] + self._before_middleware_added = False fname = inspect.stack()[2] file_path = fname.filename @@ -114,7 +115,6 @@ def __init__(self, server: FastAPI): hash_digest = hashlib.sha256(file_path.encode("utf-8")).hexdigest() config_filename = f"{hash_digest}.json" self._CONFIG_PATH = os.path.join(config_dir, config_filename) - super().__init__() def __call__(self, *args: Any, **kwargs: Any): # ASGI: (scope, receive, send) @@ -232,7 +232,7 @@ def has_request_context(self) -> bool: def run(self, dash_app: Dash, host, port, debug, **kwargs): frame = inspect.stack()[2] - dev_tools = dash_app._dev_tools # pylint-disable=W0212 + dev_tools = dash_app._dev_tools # pylint: disable=W0212 config = dict( {"debug": debug} if debug else {"debug": False}, **{f"dev_tools_{k}": v for k, v in dev_tools.items()}, @@ -260,6 +260,7 @@ def run(self, dash_app: Dash, host, port, debug, **kwargs): # Add any other kwargs as CLI args if needed try: + # pylint: disable=R1732 proc = subprocess.Popen(uvicorn_args, env=os.environ.copy()) proc.wait() finally: @@ -293,7 +294,7 @@ async def middleware(request, call_next): return response except PreventUpdate: return Response(status_code=204) - except Exception as e: + except Exception as e: # pylint: disable=W0718 if self.error_handling_mode in ["raise", "prune"]: tb = self._get_traceback(None, e) return Response(content=tb, media_type="text/html", status_code=500) @@ -465,14 +466,18 @@ async def view_func(_request: Request, body: dict = Body(...)): ) def enable_compression(self) -> None: - from fastapi.middleware.gzip import GZipMiddleware + from fastapi.middleware.gzip import ( + GZipMiddleware, + ) # pylint: disable=import-outside-toplevel self.server.add_middleware(GZipMiddleware, minimum_size=500) - config = _load_config() + config = _load_config(self._CONFIG_PATH) if "COMPRESS_ALGORITHM" not in config: - config["COMPRESS_ALGORITHM"] = ["gzip"] + config["COMPRESS_ALGORITHM"] = [ + "gzip" + ] # pylint: disable=no-value-for-parameter - _save_config(config) + _save_config(self._CONFIG_PATH, config) class FastAPIRequestAdapter(RequestAdapter): diff --git a/dash/backends/_flask.py b/dash/backends/_flask.py index f1494a8335..ef31d81557 100644 --- a/dash/backends/_flask.py +++ b/dash/backends/_flask.py @@ -37,12 +37,11 @@ from dash import Dash -class FlaskDashServer(BaseDashServer): +class FlaskDashServer(BaseDashServer[Flask]): def __init__(self, server: Flask) -> None: - self.server: Flask = server + super().__init__(server) self.server_type = "flask" self.request_adapter = FlaskRequestAdapter - super().__init__() def __call__(self, *args: Any, **kwargs: Any): # Always WSGI diff --git a/dash/backends/_quart.py b/dash/backends/_quart.py index 3a9df59578..2f7a07203d 100644 --- a/dash/backends/_quart.py +++ b/dash/backends/_quart.py @@ -33,18 +33,17 @@ from dash.fingerprint import check_fingerprint from dash._utils import parse_version from dash import _validate, Dash -from .base_server import BaseDashServer +from .base_server import BaseDashServer, RequestAdapter from ._utils import format_traceback_html -class QuartDashServer(BaseDashServer): +class QuartDashServer(BaseDashServer[Quart]): def __init__(self, server: Quart) -> None: + super().__init__(server) self.server_type = "quart" - self.server: Quart = server self.config = {} self.error_handling_mode = "ignore" self.request_adapter = QuartRequestAdapter - super().__init__() def __call__(self, *args: Any, **kwargs: Any): # type: ignore[name-defined] return self.server(*args, **kwargs) @@ -340,7 +339,7 @@ def enable_compression(self) -> None: ) from error -class QuartRequestAdapter: +class QuartRequestAdapter(RequestAdapter): def __init__(self) -> None: self._request = request # type: ignore[assignment] if self._request is None: @@ -396,5 +395,6 @@ def origin(self): def path(self): return self.request.path - async def get_json(self): + async def get_json(self): # pylint: disable=W0236 + # TODO consider using a sync wraper return await self.request.get_json() diff --git a/dash/backends/base_server.py b/dash/backends/base_server.py index fc149b3847..da0c920a37 100644 --- a/dash/backends/base_server.py +++ b/dash/backends/base_server.py @@ -1,5 +1,13 @@ from abc import ABC, abstractmethod -from typing import Any, Dict, Type +from typing import Any, Dict, Type, TypeVar, Generic, Protocol + + +class _ServerCallable(Protocol): # pylint: disable=too-few-public-methods + def __call__(self, *args: Any, **kwds: Any) -> Any: + raise NotImplementedError + + +ServerType = TypeVar("ServerType", bound=_ServerCallable) class RequestAdapter(ABC): @@ -67,12 +75,16 @@ def path(self) -> str: raise NotImplementedError() -class BaseDashServer(ABC): +class BaseDashServer(ABC, Generic[ServerType]): server_type: str - server: Any + server: ServerType config: Dict[str, Any] request_adapter: Type[RequestAdapter] + def __init__(self, server: ServerType) -> None: + super().__init__() + self.server = server + def __call__(self, *args, **kwargs) -> Any: # Default: WSGI return self.server(*args, **kwargs) diff --git a/dash/dash.py b/dash/dash.py index ce39a3f446..75cd57aa37 100644 --- a/dash/dash.py +++ b/dash/dash.py @@ -2301,6 +2301,7 @@ def run( host = host or "127.0.0.1" else: host = host or os.getenv("HOST", "127.0.0.1") + assert host port = port or os.getenv("PORT", "8050") proxy = proxy or os.getenv("DASH_PROXY") From 710c9d0c8c662565085a27dac1709f3aa4326d62 Mon Sep 17 00:00:00 2001 From: philippe Date: Thu, 5 Feb 2026 13:56:54 -0500 Subject: [PATCH 090/166] add missing backend abstract method --- dash/backends/_fastapi.py | 3 ++- dash/backends/base_server.py | 8 ++++++++ 2 files changed, 10 insertions(+), 1 deletion(-) diff --git a/dash/backends/_fastapi.py b/dash/backends/_fastapi.py index dfeb26e13c..4d8d5ed5f3 100644 --- a/dash/backends/_fastapi.py +++ b/dash/backends/_fastapi.py @@ -466,9 +466,10 @@ async def view_func(_request: Request, body: dict = Body(...)): ) def enable_compression(self) -> None: + # pylint: disable=import-outside-toplevel from fastapi.middleware.gzip import ( GZipMiddleware, - ) # pylint: disable=import-outside-toplevel + ) self.server.add_middleware(GZipMiddleware, minimum_size=500) config = _load_config(self._CONFIG_PATH) diff --git a/dash/backends/base_server.py b/dash/backends/base_server.py index da0c920a37..8a7b25ad08 100644 --- a/dash/backends/base_server.py +++ b/dash/backends/base_server.py @@ -143,3 +143,11 @@ def jsonify(self, obj) -> Any: # pragma: no cover - interface @abstractmethod def enable_compression(self) -> None: # pragma: no cover - interface pass + + @abstractmethod + def register_prune_error_handler(self, secret: str, prune_errors: bool) -> None: + pass + + @abstractmethod + def register_timing_hooks(self, first_run: bool) -> None: + pass From 86272974f5dbb45b246331eccdcdc409e69584ac Mon Sep 17 00:00:00 2001 From: philippe Date: Thu, 5 Feb 2026 15:28:37 -0500 Subject: [PATCH 091/166] lint fix --- dash/backends/_fastapi.py | 3 ++- dash/backends/_quart.py | 1 + 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/dash/backends/_fastapi.py b/dash/backends/_fastapi.py index 4d8d5ed5f3..75247c760f 100644 --- a/dash/backends/_fastapi.py +++ b/dash/backends/_fastapi.py @@ -126,6 +126,7 @@ def _cleanup_config(self): _remove_config(self._CONFIG_PATH) @staticmethod + # pylint: disable=W0613 def create_app(name: str = "__main__", config: Dict[str, Any] | None = None): app = FastAPI() @@ -466,7 +467,7 @@ async def view_func(_request: Request, body: dict = Body(...)): ) def enable_compression(self) -> None: - # pylint: disable=import-outside-toplevel + # pylint: disable=import-outside-toplevel,import-error from fastapi.middleware.gzip import ( GZipMiddleware, ) diff --git a/dash/backends/_quart.py b/dash/backends/_quart.py index 2f7a07203d..9bd927d80c 100644 --- a/dash/backends/_quart.py +++ b/dash/backends/_quart.py @@ -190,6 +190,7 @@ def has_request_context(self) -> bool: raise RuntimeError("Quart not installed; cannot check request context") return has_request_context() + # pylint: disable=W0613 def run(self, dash_app: Dash, host: str, port: int, debug: bool, **kwargs: _t.Any): self.config = {"debug": debug, **kwargs} if debug else kwargs self.server.run(host=host, port=port, debug=debug, **kwargs) From c74278b0f9b463283a1df5c1397dd92cb7e30209 Mon Sep 17 00:00:00 2001 From: philippe Date: Fri, 6 Feb 2026 10:14:46 -0500 Subject: [PATCH 092/166] fix fastapi request middleware --- dash/backends/_fastapi.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dash/backends/_fastapi.py b/dash/backends/_fastapi.py index 75247c760f..142a90bb0a 100644 --- a/dash/backends/_fastapi.py +++ b/dash/backends/_fastapi.py @@ -215,7 +215,7 @@ def before_request(self, func: Callable[[], Any] | None): if func is not None: self._before_request_funcs.append(func) # Only add the middleware once - if not hasattr(self, "_before_middleware_added"): + if not self._before_middleware_added: self.server.add_middleware(CurrentRequestMiddleware) self.server.middleware("http")(self._make_before_middleware()) self._before_middleware_added = True From 62946094edc1572dce3d57ed35ad9d64643ae61f Mon Sep 17 00:00:00 2001 From: philippe Date: Mon, 16 Feb 2026 13:20:48 -0500 Subject: [PATCH 093/166] fix backend tests --- dash/backends/_fastapi.py | 84 ++++++++++++++++++++++++++++----------- 1 file changed, 61 insertions(+), 23 deletions(-) diff --git a/dash/backends/_fastapi.py b/dash/backends/_fastapi.py index 142a90bb0a..ae5faf6cdd 100644 --- a/dash/backends/_fastapi.py +++ b/dash/backends/_fastapi.py @@ -12,6 +12,7 @@ import json import os import subprocess +import threading try: from fastapi import FastAPI, Request, Response, Body @@ -20,6 +21,7 @@ from starlette.responses import Response as StarletteResponse from starlette.datastructures import MutableHeaders from starlette.types import ASGIApp, Scope, Receive, Send + import uvicorn except ImportError as _err: raise ImportError( "All dependencies not installed. Please install it with `dash[fastapi]` to use the FastAPI backend." @@ -242,30 +244,66 @@ def run(self, dash_app: Dash, host, port, debug, **kwargs): if debug and kwargs.get("reload") is None: kwargs["reload"] = True - file_path = frame.filename - rel_path = os.path.relpath(file_path, os.getcwd()) - module_name = os.path.splitext(rel_path)[0].replace(os.sep, ".") - uvicorn_args = [ - sys.executable, - "-m", - "uvicorn", - f"{module_name}:server", - "--host", - str(host), - "--port", - str(port), - ] - if kwargs.get("reload"): - uvicorn_args.append("--reload") - - # Add any other kwargs as CLI args if needed + # Check if we're running in a thread (e.g., from testing framework) + # If so, run uvicorn directly instead of spawning a subprocess + is_threaded = threading.current_thread() != threading.main_thread() - try: - # pylint: disable=R1732 - proc = subprocess.Popen(uvicorn_args, env=os.environ.copy()) - proc.wait() - finally: - self._cleanup_config() + if is_threaded: + # Running in a thread (testing context) - use uvicorn.run directly + # This allows the testing framework to control the server lifecycle + if kwargs.get("reload"): + kwargs["reload"] = True + try: + uvicorn.run(self.server, host=host, port=port, **kwargs) + finally: + self._cleanup_config() + else: + # Running in main thread (normal context) - use subprocess + file_path = frame.filename + rel_path = os.path.relpath(file_path, os.getcwd()) + + # Check if the file is outside the current working directory + if rel_path.startswith(".."): + # File is outside cwd, try to find the module name from sys.modules + module_name = None + for mod_name, mod in sys.modules.items(): + if hasattr(mod, "__file__") and mod.__file__: + if os.path.abspath(mod.__file__) == os.path.abspath(file_path): + module_name = mod_name + break + + # If we still can't find it, raise an error + if not module_name: + raise RuntimeError( + f"Cannot determine module name for {file_path}. " + "The file is outside the current working directory and not found in sys.modules. " + "Please ensure the FastAPI app is being run from a file within the current working directory." + ) + else: + # File is within cwd, use relative path + module_name = os.path.splitext(rel_path)[0].replace(os.sep, ".") + + uvicorn_args = [ + sys.executable, + "-m", + "uvicorn", + f"{module_name}:server", + "--host", + str(host), + "--port", + str(port), + ] + if kwargs.get("reload"): + uvicorn_args.append("--reload") + + # Add any other kwargs as CLI args if needed + + try: + # pylint: disable=R1732 + proc = subprocess.Popen(uvicorn_args, env=os.environ.copy()) + proc.wait() + finally: + self._cleanup_config() def make_response( self, From 58ad02fc6ebc3130011fa59228639ca681aff59d Mon Sep 17 00:00:00 2001 From: philippe Date: Tue, 17 Feb 2026 09:24:05 -0500 Subject: [PATCH 094/166] fix status code make_response --- dash/backends/_fastapi.py | 3 ++- dash/backends/_flask.py | 5 ++++- dash/backends/_quart.py | 5 ++++- dash/backends/base_server.py | 6 +++++- 4 files changed, 15 insertions(+), 4 deletions(-) diff --git a/dash/backends/_fastapi.py b/dash/backends/_fastapi.py index ae5faf6cdd..a27b338da4 100644 --- a/dash/backends/_fastapi.py +++ b/dash/backends/_fastapi.py @@ -310,13 +310,14 @@ def make_response( data: str | bytes | bytearray, mimetype: str | None = None, content_type: str | None = None, + status: int | None = None, ): headers = {} if mimetype: headers["content-type"] = mimetype if content_type: headers["content-type"] = content_type - return Response(content=data, headers=headers) + return Response(content=data, headers=headers, status_code=status or 200) def jsonify(self, obj: Any): return JSONResponse(content=obj) diff --git a/dash/backends/_flask.py b/dash/backends/_flask.py index ef31d81557..af18961b55 100644 --- a/dash/backends/_flask.py +++ b/dash/backends/_flask.py @@ -145,8 +145,11 @@ def make_response( data: str | bytes | bytearray, mimetype: str | None = None, content_type: str | None = None, + status: int | None = None, ): - return Response(data, mimetype=mimetype, content_type=content_type) + return Response( + data, mimetype=mimetype, content_type=content_type, status=status + ) def jsonify(self, obj: Any): return jsonify(obj) diff --git a/dash/backends/_quart.py b/dash/backends/_quart.py index 9bd927d80c..0bc0afcce5 100644 --- a/dash/backends/_quart.py +++ b/dash/backends/_quart.py @@ -200,10 +200,13 @@ def make_response( data: str | bytes | bytearray, mimetype: str | None = None, content_type: str | None = None, + status=None, ): if Response is None: raise RuntimeError("Quart not installed; cannot generate Response") - return Response(data, mimetype=mimetype, content_type=content_type) + return Response( + data, mimetype=mimetype, content_type=content_type, status=status + ) def jsonify(self, obj): return jsonify(obj) diff --git a/dash/backends/base_server.py b/dash/backends/base_server.py index 8a7b25ad08..49c8596ce7 100644 --- a/dash/backends/base_server.py +++ b/dash/backends/base_server.py @@ -132,7 +132,11 @@ def run( @abstractmethod def make_response( - self, data, mimetype=None, content_type=None + self, + data, + mimetype=None, + content_type=None, + status=None, ) -> Any: # pragma: no cover - interface pass From 11166eed0d1241f6c254622ac87aab3b3b78a725 Mon Sep 17 00:00:00 2001 From: philippe Date: Tue, 17 Feb 2026 10:52:14 -0500 Subject: [PATCH 095/166] fix test number input --- .../tests/integration/input/test_number_input.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/components/dash-core-components/tests/integration/input/test_number_input.py b/components/dash-core-components/tests/integration/input/test_number_input.py index 01deb79287..f16eb2c7b5 100644 --- a/components/dash-core-components/tests/integration/input/test_number_input.py +++ b/components/dash-core-components/tests/integration/input/test_number_input.py @@ -245,7 +245,7 @@ def test_inni010_valid_numbers(dash_dcc, ninput_app): (str(sys.float_info.max), float), (str(sys.float_info.min), float), ): - elem = dash_dcc.find_element("#input_false") + elem = dash_dcc.wait_for_element("#input_false") elem.send_keys(num) assert dash_dcc.wait_for_text_to_equal( "#div_false", str(op(num)) From 6d56c395376c028866b1b9f1db6e90c9fbb5e17e Mon Sep 17 00:00:00 2001 From: philippe Date: Tue, 17 Feb 2026 13:59:11 -0500 Subject: [PATCH 096/166] wait for input false first --- .../tests/integration/input/test_number_input.py | 1 + 1 file changed, 1 insertion(+) diff --git a/components/dash-core-components/tests/integration/input/test_number_input.py b/components/dash-core-components/tests/integration/input/test_number_input.py index f16eb2c7b5..e2ec647c69 100644 --- a/components/dash-core-components/tests/integration/input/test_number_input.py +++ b/components/dash-core-components/tests/integration/input/test_number_input.py @@ -238,6 +238,7 @@ def update_output(val): def test_inni010_valid_numbers(dash_dcc, ninput_app): dash_dcc.start_server(ninput_app) + elem = dash_dcc.wait_for_element("#input_false") for num, op in ( ("1.0", lambda x: int(float(x))), # limitation of js/json ("10e10", lambda x: int(float(x))), From 3209a63ef8e3986e0e49c4182672a38df220839d Mon Sep 17 00:00:00 2001 From: philippe Date: Tue, 17 Feb 2026 15:34:31 -0500 Subject: [PATCH 097/166] fix fastapi run to auto resolve instance --- dash/backends/_fastapi.py | 30 +++++++++++++++++++++++++----- 1 file changed, 25 insertions(+), 5 deletions(-) diff --git a/dash/backends/_fastapi.py b/dash/backends/_fastapi.py index a27b338da4..00c351cb38 100644 --- a/dash/backends/_fastapi.py +++ b/dash/backends/_fastapi.py @@ -119,10 +119,8 @@ def __init__(self, server: FastAPI): self._CONFIG_PATH = os.path.join(config_dir, config_filename) def __call__(self, *args: Any, **kwargs: Any): - # ASGI: (scope, receive, send) - if len(args) == 3 and isinstance(args[0], dict) and "type" in args[0]: - return self.server(*args, **kwargs) - raise TypeError("FastAPI app must be called with (scope, receive, send)") + # ASGI: pass through to FastAPI + return self.server(*args, **kwargs) def _cleanup_config(self): _remove_config(self._CONFIG_PATH) @@ -283,11 +281,33 @@ def run(self, dash_app: Dash, host, port, debug, **kwargs): # File is within cwd, use relative path module_name = os.path.splitext(rel_path)[0].replace(os.sep, ".") + # Find the Dash app variable name by inspecting the calling frame + dash_var_name = None + calling_frame = frame.frame + for var_name, var_value in calling_frame.f_locals.items(): + if var_value is dash_app: + dash_var_name = var_name + break + + # If not found in locals, check globals + if not dash_var_name: + for var_name, var_value in calling_frame.f_globals.items(): + if var_value is dash_app: + dash_var_name = var_name + break + + # Construct the app path - use .server to access the FastAPI instance + if dash_var_name: + app_path = f"{module_name}:{dash_var_name}.server" + else: + # Fallback to looking for 'server' variable (old behavior) + app_path = f"{module_name}:server" + uvicorn_args = [ sys.executable, "-m", "uvicorn", - f"{module_name}:server", + app_path, "--host", str(host), "--port", From 2e17d1132d243c1f89a23e68c70544eec3ba49cc Mon Sep 17 00:00:00 2001 From: philippe Date: Tue, 17 Feb 2026 15:53:47 -0500 Subject: [PATCH 098/166] Disable route logging for quart & fastapi --- dash/backends/_quart.py | 13 +++++++++ dash/dash.py | 27 ++++++++++++++----- .../backend_tests/test_preconfig_backends.py | 25 +++++++++++++++++ 3 files changed, 58 insertions(+), 7 deletions(-) diff --git a/dash/backends/_quart.py b/dash/backends/_quart.py index 0bc0afcce5..56f3253bac 100644 --- a/dash/backends/_quart.py +++ b/dash/backends/_quart.py @@ -7,6 +7,7 @@ import time import sys +from logging.config import dictConfig from contextvars import copy_context from typing import Any @@ -193,6 +194,18 @@ def has_request_context(self) -> bool: # pylint: disable=W0613 def run(self, dash_app: Dash, host: str, port: int, debug: bool, **kwargs: _t.Any): self.config = {"debug": debug, **kwargs} if debug else kwargs + if dash_app._dev_tools.silence_routes_logging: + dictConfig( + { + "version": 1, + "loggers": { + "quart.app": { + "level": "ERROR", + }, + }, + } + ) + self.server.run(host=host, port=port, debug=debug, **kwargs) def make_response( diff --git a/dash/dash.py b/dash/dash.py index 75cd57aa37..96ccf4dde6 100644 --- a/dash/dash.py +++ b/dash/dash.py @@ -1944,9 +1944,10 @@ def enable_dev_tools( # pylint: disable=too-many-branches env: ``DASH_HOT_RELOAD_MAX_RETRY`` :type dev_tools_hot_reload_max_retry: int - :param dev_tools_silence_routes_logging: Silence the `werkzeug` logger, - will remove all routes logging. Enabled with debugging by default - because hot reload hash checks generate a lot of requests. + :param dev_tools_silence_routes_logging: Silence the route logging for the + web server (werkzeug for Flask, hypercorn for Quart, uvicorn for FastAPI). + Enabled with debugging by default because hot reload hash checks generate + a lot of requests. env: ``DASH_SILENCE_ROUTES_LOGGING`` :type dev_tools_silence_routes_logging: bool @@ -1981,7 +1982,18 @@ def enable_dev_tools( # pylint: disable=too-many-branches ) if dev_tools.silence_routes_logging: - logging.getLogger("werkzeug").setLevel(logging.ERROR) + # Silence route logging based on backend type + backend_type = getattr(self.backend, "server_type", "flask") + if backend_type == "flask": + logging.getLogger("werkzeug").setLevel(logging.ERROR) + elif backend_type == "quart": + # Quart uses hypercorn as its ASGI server + logging.getLogger("hypercorn.access").setLevel(logging.ERROR) + logging.getLogger("hypercorn.error").setLevel(logging.ERROR) + elif backend_type == "fastapi": + # FastAPI uses uvicorn as its ASGI server + logging.getLogger("uvicorn.access").setLevel(logging.ERROR) + logging.getLogger("uvicorn.error").setLevel(logging.ERROR) if dev_tools.hot_reload: _reload = self._hot_reload @@ -2236,9 +2248,10 @@ def run( env: ``DASH_HOT_RELOAD_MAX_RETRY`` :type dev_tools_hot_reload_max_retry: int - :param dev_tools_silence_routes_logging: Silence the `werkzeug` logger, - will remove all routes logging. Enabled with debugging by default - because hot reload hash checks generate a lot of requests. + :param dev_tools_silence_routes_logging: Silence the route logging for the + web server (werkzeug for Flask, hypercorn for Quart, uvicorn for FastAPI). + Enabled with debugging by default because hot reload hash checks generate + a lot of requests. env: ``DASH_SILENCE_ROUTES_LOGGING`` :type dev_tools_silence_routes_logging: bool diff --git a/tests/backend_tests/test_preconfig_backends.py b/tests/backend_tests/test_preconfig_backends.py index 4c4ccc7083..3193305ee2 100644 --- a/tests/backend_tests/test_preconfig_backends.py +++ b/tests/backend_tests/test_preconfig_backends.py @@ -1,3 +1,4 @@ +import logging import pytest from dash import Dash, Input, Output, html, dcc @@ -215,3 +216,27 @@ def update_output_bg(value): "#output", f"Background typed: {backend.title()} BG Test" ) assert dash_duo.get_logs() == [] + + +@pytest.mark.parametrize( + "backend,expected_loggers", + [ + ("flask", ["werkzeug"]), + ("quart", ["hypercorn.access", "hypercorn.error"]), + ("fastapi", ["uvicorn.access", "uvicorn.error"]), + ], +) +def test_silence_routes_logging(backend, expected_loggers): + """Test that route logging is silenced for all backends when dev_tools_silence_routes_logging is enabled.""" + app = Dash(__name__, backend=backend) + app.layout = html.Div([html.Div(id="output", children="Test")]) + + # Enable dev tools with silence_routes_logging + app.enable_dev_tools(debug=True, dev_tools_silence_routes_logging=True) + + # Check that the expected loggers have been set to ERROR level + for logger_name in expected_loggers: + logger = logging.getLogger(logger_name) + assert ( + logger.level == logging.ERROR + ), f"Logger {logger_name} should be set to ERROR level for {backend} backend" From 2a2f45c6323cbd85e2ef9bc862b1b7954f799385 Mon Sep 17 00:00:00 2001 From: philippe Date: Wed, 18 Feb 2026 09:02:22 -0500 Subject: [PATCH 099/166] fix test ini --- .../tests/integration/input/conftest.py | 8 ++++---- dash/backends/_fastapi.py | 2 +- dash/backends/_quart.py | 1 + 3 files changed, 6 insertions(+), 5 deletions(-) diff --git a/components/dash-core-components/tests/integration/input/conftest.py b/components/dash-core-components/tests/integration/input/conftest.py index c03087db1a..f612cfe462 100644 --- a/components/dash-core-components/tests/integration/input/conftest.py +++ b/components/dash-core-components/tests/integration/input/conftest.py @@ -2,7 +2,7 @@ from dash import Dash, Input, Output, dcc, html -@pytest.fixture(scope="module") +@pytest.fixture def ninput_app(): app = Dash(__name__) app.layout = html.Div( @@ -35,7 +35,7 @@ def render(fval, tval): yield app -@pytest.fixture(scope="module") +@pytest.fixture def input_range_app(): app = Dash(__name__) app.layout = html.Div( @@ -59,7 +59,7 @@ def range_out(val): yield app -@pytest.fixture(scope="module") +@pytest.fixture def debounce_text_app(): app = Dash(__name__) app.layout = html.Div( @@ -89,7 +89,7 @@ def render(slow_val, fast_val): yield app -@pytest.fixture(scope="module") +@pytest.fixture def debounce_number_app(): app = Dash(__name__) app.layout = html.Div( diff --git a/dash/backends/_fastapi.py b/dash/backends/_fastapi.py index 00c351cb38..f19e628a73 100644 --- a/dash/backends/_fastapi.py +++ b/dash/backends/_fastapi.py @@ -231,7 +231,7 @@ def has_request_context(self) -> bool: except RuntimeError: return False - def run(self, dash_app: Dash, host, port, debug, **kwargs): + def run(self, dash_app: Dash, host, port, debug, **kwargs): # pylint: disable=R0912 frame = inspect.stack()[2] dev_tools = dash_app._dev_tools # pylint: disable=W0212 config = dict( diff --git a/dash/backends/_quart.py b/dash/backends/_quart.py index 56f3253bac..92f67c2205 100644 --- a/dash/backends/_quart.py +++ b/dash/backends/_quart.py @@ -194,6 +194,7 @@ def has_request_context(self) -> bool: # pylint: disable=W0613 def run(self, dash_app: Dash, host: str, port: int, debug: bool, **kwargs: _t.Any): self.config = {"debug": debug, **kwargs} if debug else kwargs + # pylint: disable=protected-access if dash_app._dev_tools.silence_routes_logging: dictConfig( { From 31493def7e3db283f38d0591eb815b56503721da Mon Sep 17 00:00:00 2001 From: philippe Date: Thu, 19 Feb 2026 10:29:58 -0500 Subject: [PATCH 100/166] fastapi config from env, refactor middlewares --- dash/backends/_fastapi.py | 302 ++++++++++++++++++----------------- dash/backends/base_server.py | 29 +++- dash/dash.py | 1 + 3 files changed, 184 insertions(+), 148 deletions(-) diff --git a/dash/backends/_fastapi.py b/dash/backends/_fastapi.py index f19e628a73..f781300bdf 100644 --- a/dash/backends/_fastapi.py +++ b/dash/backends/_fastapi.py @@ -2,6 +2,7 @@ import asyncio from contextvars import copy_context, ContextVar +import json from typing import TYPE_CHECKING, Any, Callable, Dict import sys import mimetypes @@ -9,7 +10,6 @@ import inspect import pkgutil import time -import json import os import subprocess import threading @@ -55,45 +55,122 @@ def get_current_request() -> Request: return req -class CurrentRequestMiddleware: # pylint: disable=too-few-public-methods - def __init__(self, app: ASGIApp) -> None: # type: ignore[name-defined] +_ENV_CONFIG = "_DASH_FASTAPI_CONFIG" + + +class DashMiddleware: # pylint: disable=too-few-public-methods + """Consolidated middleware for all Dash/FastAPI integration needs.""" + + def __init__( + self, + app: ASGIApp, + dash_app: Dash, + before_request_funcs: list, + after_request_func: Callable | None = None, + enable_timing: bool = False, + error_handling_mode: str = "ignore", + get_traceback_func: Callable | None = None, + ) -> None: self.app = app + self.dash_app = dash_app + self.before_request_funcs = before_request_funcs + self.after_request_func = after_request_func + self.enable_timing = enable_timing + self.error_handling_mode = error_handling_mode + self.get_traceback_func = get_traceback_func + self._dev_tools_initialized = False + + async def _initialize_dev_tools(self) -> None: + """Initialize dev tools from environment config on first run.""" + if not self._dev_tools_initialized: + config = json.loads(os.getenv(_ENV_CONFIG, "{}")) + if config: + self.dash_app.enable_dev_tools(**config, first_run=False) + self._dev_tools_initialized = True + + def _setup_timing(self, request: Request) -> None: + """Set up timing information for the request.""" + if self.enable_timing: + request.state.timing_information = { + "__dash_server": {"dur": time.time(), "desc": None} + } - async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: # type: ignore[name-defined] - # non-http/ws scopes pass through (lifespan etc.) + async def _run_before_hooks(self) -> None: + """Run all before-request hooks.""" + for func in self.before_request_funcs: + if inspect.iscoroutinefunction(func): + await func() + else: + func() + + async def _run_after_hooks(self) -> None: + """Run after-request hook if configured.""" + if self.after_request_func is not None: + if inspect.iscoroutinefunction(self.after_request_func): + await self.after_request_func() + else: + self.after_request_func() + + def _finalize_timing(self, request: Request) -> dict | None: + """Calculate final timing information and return headers to add.""" + if not self.enable_timing or not hasattr(request.state, "timing_information"): + return None + + timing_information = request.state.timing_information + dash_total = timing_information.get("__dash_server", None) + if dash_total is not None: + dash_total["dur"] = round((time.time() - dash_total["dur"]) * 1000) + + return timing_information + + async def _handle_error( + self, error: Exception, scope: Scope, receive: Receive, send: Send + ) -> None: + """Handle exceptions during request processing.""" + if isinstance(error, PreventUpdate): + response = Response(status_code=204) + elif self.error_handling_mode in ["raise", "prune"] and self.get_traceback_func: + tb = self.get_traceback_func(None, error) + response = Response(content=tb, media_type="text/html", status_code=500) + else: + response = JSONResponse( + status_code=500, + content={ + "error": "InternalServerError", + "message": "An internal server error occurred.", + }, + ) + await response(scope, receive, send) + + async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: + # Handle lifespan events (startup/shutdown) + if scope["type"] == "lifespan": + await self._initialize_dev_tools() + await self.app(scope, receive, send) + return + + # Non-HTTP/WebSocket scopes pass through if scope["type"] not in ("http", "websocket"): await self.app(scope, receive, send) return + # HTTP/WebSocket request handling request = Request(scope, receive=receive) token = set_current_request(request) - try: - await self.app(scope, receive, send) - finally: - reset_current_request(token) - - -def _save_config(config_path, config): - with open(config_path, "w", encoding="utf-8") as f: - json.dump(config, f) + try: + self._setup_timing(request) + await self._run_before_hooks() -def _load_config(config_path): - resp = {"debug": False} - try: - if os.path.exists(config_path): - with open(config_path, "r", encoding="utf-8") as f: - resp = json.load(f) - except (json.JSONDecodeError, OSError): - pass # ignore errors - return resp + await self.app(scope, receive, send) + await self._run_after_hooks() + self._finalize_timing(request) -def _remove_config(config_path): - try: - os.remove(config_path) - except FileNotFoundError: - pass + except Exception as e: # pylint: disable=W0718 + await self._handle_error(e, scope, receive, send) + finally: + reset_current_request(token) class FastAPIDashServer(BaseDashServer[FastAPI]): @@ -103,28 +180,13 @@ def __init__(self, server: FastAPI): self.error_handling_mode = "ignore" self.request_adapter = FastAPIRequestAdapter self._before_request_funcs = [] - self._before_middleware_added = False - - fname = inspect.stack()[2] - file_path = fname.filename - - # Manually build the config directory path - home_dir = os.path.expanduser("~") - config_dir = os.path.join(home_dir, ".local", "share", "plotly-dash-configs") - os.makedirs(config_dir, exist_ok=True) - - # Hash the file path for a unique config filename - hash_digest = hashlib.sha256(file_path.encode("utf-8")).hexdigest() - config_filename = f"{hash_digest}.json" - self._CONFIG_PATH = os.path.join(config_dir, config_filename) + self._after_request_func = None + self._enable_timing = False def __call__(self, *args: Any, **kwargs: Any): # ASGI: pass through to FastAPI return self.server(*args, **kwargs) - def _cleanup_config(self): - _remove_config(self._CONFIG_PATH) - @staticmethod # pylint: disable=W0613 def create_app(name: str = "__main__", config: Dict[str, Any] | None = None): @@ -178,17 +240,11 @@ async def index(_request: Request): dash_app._add_url("", index, methods=["GET"]) def setup_catchall(self, dash_app: Dash): - @self.server.on_event("startup") - def _setup_catchall(): - dash_app.enable_dev_tools( - **_load_config(self._CONFIG_PATH), first_run=False - ) # do this to make sure dev tools are enabled - - async def catchall(_request: Request): - return Response(content=dash_app.index(), media_type="text/html") + async def catchall(_request: Request): + return Response(content=dash_app.index(), media_type="text/html") - # pylint: disable=protected-access - dash_app._add_url("{path:path}", catchall, methods=["GET"]) + # pylint: disable=protected-access + dash_app._add_url("{path:path}", catchall, methods=["GET"]) def add_url_rule( self, @@ -214,15 +270,9 @@ def add_url_rule( def before_request(self, func: Callable[[], Any] | None): if func is not None: self._before_request_funcs.append(func) - # Only add the middleware once - if not self._before_middleware_added: - self.server.add_middleware(CurrentRequestMiddleware) - self.server.middleware("http")(self._make_before_middleware()) - self._before_middleware_added = True def after_request(self, func: Callable[[], Any] | None): - # FastAPI does not have after_request, but we can use middleware - self.server.middleware("http")(self._make_after_middleware(func)) + self._after_request_func = func def has_request_context(self) -> bool: try: @@ -233,12 +283,6 @@ def has_request_context(self) -> bool: def run(self, dash_app: Dash, host, port, debug, **kwargs): # pylint: disable=R0912 frame = inspect.stack()[2] - dev_tools = dash_app._dev_tools # pylint: disable=W0212 - config = dict( - {"debug": debug} if debug else {"debug": False}, - **{f"dev_tools_{k}": v for k, v in dev_tools.items()}, - ) - _save_config(self._CONFIG_PATH, config) if debug and kwargs.get("reload") is None: kwargs["reload"] = True @@ -251,10 +295,7 @@ def run(self, dash_app: Dash, host, port, debug, **kwargs): # pylint: disable=R # This allows the testing framework to control the server lifecycle if kwargs.get("reload"): kwargs["reload"] = True - try: - uvicorn.run(self.server, host=host, port=port, **kwargs) - finally: - self._cleanup_config() + uvicorn.run(self.server, host=host, port=port, **kwargs) else: # Running in main thread (normal context) - use subprocess file_path = frame.filename @@ -316,14 +357,19 @@ def run(self, dash_app: Dash, host, port, debug, **kwargs): # pylint: disable=R if kwargs.get("reload"): uvicorn_args.append("--reload") + dev_tools = dash_app._dev_tools # pylint: disable=W0212 + config = dict( + {"debug": debug} if debug else {"debug": False}, + **{f"dev_tools_{k}": v for k, v in dev_tools.items()}, + ) + env = os.environ.copy() + env[_ENV_CONFIG] = json.dumps(config) + # Add any other kwargs as CLI args if needed - try: - # pylint: disable=R1732 - proc = subprocess.Popen(uvicorn_args, env=os.environ.copy()) - proc.wait() - finally: - self._cleanup_config() + # pylint: disable=R1732 + proc = subprocess.Popen(uvicorn_args, env=env) + proc.wait() def make_response( self, @@ -342,44 +388,6 @@ def make_response( def jsonify(self, obj: Any): return JSONResponse(content=obj) - def _make_before_middleware(self): - async def middleware(request, call_next): - for func in self._before_request_funcs: - if inspect.iscoroutinefunction(func): - await func() - else: - func() - try: - response = await call_next(request) - return response - except PreventUpdate: - return Response(status_code=204) - except Exception as e: # pylint: disable=W0718 - if self.error_handling_mode in ["raise", "prune"]: - tb = self._get_traceback(None, e) - return Response(content=tb, media_type="text/html", status_code=500) - return JSONResponse( - status_code=500, - content={ - "error": "InternalServerError", - "message": "An internal server error occurred.", - }, - ) - - return middleware - - def _make_after_middleware(self, func: Callable[[], Any] | None): - async def middleware(request, call_next): - response = await call_next(request) - if func is not None: - if inspect.iscoroutinefunction(func): - await func() - else: - func() - return response - - return middleware - def serve_component_suites( self, dash_app: Dash, @@ -461,31 +469,8 @@ async def _dispatch(request: Request): return _dispatch def register_timing_hooks(self, first_run: bool): - if not first_run: - return - - @self.server.middleware("http") - async def timing_middleware(request: Request, call_next): - # Before request - request.state.timing_information = { - "__dash_server": {"dur": time.time(), "desc": None} - } - response = await call_next(request) - # After request - timing_information = getattr(request.state, "timing_information", None) - if timing_information is not None: - dash_total = timing_information.get("__dash_server", None) - if dash_total is not None: - dash_total["dur"] = round((time.time() - dash_total["dur"]) * 1000) - headers = MutableHeaders(response.headers) - for name, info in timing_information.items(): - value = name - if info.get("desc") is not None: - value += f';desc="{info["desc"]}"' - if info.get("dur") is not None: - value += f";dur={info['dur']}" - headers.append("Server-Timing", value) - return response + if first_run: + self._enable_timing = True def register_callback_api_routes( self, callback_api_paths: Dict[str, Callable[..., Any]] @@ -532,13 +517,36 @@ def enable_compression(self) -> None: ) self.server.add_middleware(GZipMiddleware, minimum_size=500) - config = _load_config(self._CONFIG_PATH) - if "COMPRESS_ALGORITHM" not in config: - config["COMPRESS_ALGORITHM"] = [ - "gzip" - ] # pylint: disable=no-value-for-parameter - _save_config(self._CONFIG_PATH, config) + def setup_backend(self, dash_app: Dash): + # Add consolidated middleware for all Dash functionality + self.server.add_middleware( + DashMiddleware, + dash_app=dash_app, + before_request_funcs=self._before_request_funcs, + after_request_func=self._after_request_func, + enable_timing=self._enable_timing, + error_handling_mode=self.error_handling_mode, + get_traceback_func=self._get_traceback, + ) + + # Add timing middleware separately if enabled (needs to modify response headers) + if self._enable_timing: + + @self.server.middleware("http") + async def timing_headers_middleware(request: Request, call_next): + response = await call_next(request) + timing_information = getattr(request.state, "timing_information", None) + if timing_information is not None: + headers = MutableHeaders(response.headers) + for name, info in timing_information.items(): + value = name + if info.get("desc") is not None: + value += f';desc="{info["desc"]}"' + if info.get("dur") is not None: + value += f";dur={info['dur']}" + headers.append("Server-Timing", value) + return response class FastAPIRequestAdapter(RequestAdapter): diff --git a/dash/backends/base_server.py b/dash/backends/base_server.py index 49c8596ce7..9f12791afa 100644 --- a/dash/backends/base_server.py +++ b/dash/backends/base_server.py @@ -1,5 +1,9 @@ from abc import ABC, abstractmethod -from typing import Any, Dict, Type, TypeVar, Generic, Protocol +from typing import Any, Dict, Type, TypeVar, Generic, Protocol, TYPE_CHECKING + + +if TYPE_CHECKING: + import dash class _ServerCallable(Protocol): # pylint: disable=too-few-public-methods @@ -155,3 +159,26 @@ def register_prune_error_handler(self, secret: str, prune_errors: bool) -> None: @abstractmethod def register_timing_hooks(self, first_run: bool) -> None: pass + + @abstractmethod + def register_callback_api_routes(self, callback_api_paths): + pass + + @abstractmethod + def setup_component_suites(self, dash_app: "dash.Dash") -> str: + pass + + @abstractmethod + def serve_callback(self, dash_app: "dash.Dash"): + pass + + @abstractmethod + def setup_index(self, dash_app: "dash.Dash"): + pass + + @abstractmethod + def setup_catchall(self, dash_app: "dash.Dash"): + pass + + def setup_backend(self, dash_app: "dash.Dash"): + """Override to provide any other required setup""" diff --git a/dash/dash.py b/dash/dash.py index 96ccf4dde6..036ca3bd3b 100644 --- a/dash/dash.py +++ b/dash/dash.py @@ -729,6 +729,7 @@ def _handle_error(_): self.backend.register_error_handlers() self.backend.before_request(self._setup_server) + self.backend.setup_backend(self) self._setup_routes() _get_app.APP = self self.enable_pages() From f8c89836a53332ccb00f748d230201f51dfaa796 Mon Sep 17 00:00:00 2001 From: philippe Date: Thu, 19 Feb 2026 10:49:48 -0500 Subject: [PATCH 101/166] add docstring to base server --- dash/backends/base_server.py | 167 ++++++++++++++++++++++++++++++++++- 1 file changed, 166 insertions(+), 1 deletion(-) diff --git a/dash/backends/base_server.py b/dash/backends/base_server.py index 9f12791afa..d88d8422f7 100644 --- a/dash/backends/base_server.py +++ b/dash/backends/base_server.py @@ -1,3 +1,8 @@ +"""Base server abstractions for Dash backend implementations. + +This module provides abstract base classes and protocols that define the interface +for different web server backends (Flask, Quart, FastAPI, etc.) to integrate with Dash. +""" from abc import ABC, abstractmethod from typing import Any, Dict, Type, TypeVar, Generic, Protocol, TYPE_CHECKING @@ -7,6 +12,10 @@ class _ServerCallable(Protocol): # pylint: disable=too-few-public-methods + """Protocol for callable server instances. + + Defines the interface for server objects that can be called as WSGI/ASGI applications. + """ def __call__(self, *args: Any, **kwds: Any) -> Any: raise NotImplementedError @@ -15,81 +24,123 @@ def __call__(self, *args: Any, **kwds: Any) -> Any: class RequestAdapter(ABC): + """Abstract adapter for normalizing HTTP request objects across different server backends. + + This adapter provides a unified interface for accessing request data regardless of + the underlying web framework (Flask, Quart, FastAPI, etc.). Concrete implementations + wrap framework-specific request objects and expose their data through these properties. + """ def __call__(self) -> "RequestAdapter": return self @property @abstractmethod def context(self) -> Any: # pragma: no cover - interface + """Get the framework-specific request context object.""" raise NotImplementedError() # Properties to be implemented in concrete adapters @property # pragma: no cover - interface @abstractmethod def root(self) -> str: + """Get the application root path.""" raise NotImplementedError() @property # pragma: no cover - interface @abstractmethod def args(self): + """Get the request query string arguments.""" raise NotImplementedError() @abstractmethod # kept as method (may be sync or async) def get_json(self): # pragma: no cover - interface + """Get the parsed JSON body of the request. + + May be synchronous or asynchronous depending on the backend. + """ raise NotImplementedError() @property # pragma: no cover - interface @abstractmethod def is_json(self) -> bool: + """Check if the request has a JSON content type.""" raise NotImplementedError() @property # pragma: no cover - interface @abstractmethod def cookies(self): + """Get the request cookies.""" raise NotImplementedError() @property # pragma: no cover - interface @abstractmethod def headers(self): + """Get the request headers.""" raise NotImplementedError() @property # pragma: no cover - interface @abstractmethod def full_path(self) -> str: + """Get the full request path including query string.""" raise NotImplementedError() @property # pragma: no cover - interface @abstractmethod def url(self) -> str: + """Get the full request URL.""" raise NotImplementedError() @property # pragma: no cover - interface @abstractmethod def remote_addr(self): + """Get the remote client IP address.""" raise NotImplementedError() @property # pragma: no cover - interface @abstractmethod def origin(self): + """Get the Origin header value.""" raise NotImplementedError() @property # pragma: no cover - interface @abstractmethod def path(self) -> str: + """Get the request path without query string.""" raise NotImplementedError() class BaseDashServer(ABC, Generic[ServerType]): + """Abstract base class for Dash server backend implementations. + + This class defines the interface that all server backends must implement to + work with Dash. Concrete implementations exist for Flask, Quart, FastAPI, and + other web frameworks. + + Attributes: + server_type: String identifier for the server backend (e.g., 'flask', 'quart') + server: The underlying server instance + config: Configuration dictionary for the server + request_adapter: RequestAdapter class for normalizing requests + """ server_type: str server: ServerType config: Dict[str, Any] request_adapter: Type[RequestAdapter] def __init__(self, server: ServerType) -> None: + """Initialize the server wrapper. + + Args: + server: The underlying server instance to wrap + """ super().__init__() self.server = server def __call__(self, *args, **kwargs) -> Any: + """Make the server wrapper callable as a WSGI/ASGI application. + + Delegates to the underlying server instance. + """ # Default: WSGI return self.server(*args, **kwargs) @@ -98,40 +149,89 @@ def __call__(self, *args, **kwargs) -> Any: def create_app( name: str = "__main__", config=None ) -> Any: # pragma: no cover - interface + """Create a new server application instance. + + Args: + name: Application name, defaults to '__main__' + config: Configuration dictionary or object + + Returns: + The server application instance + """ pass @abstractmethod def register_assets_blueprint( self, blueprint_name: str, assets_url_path: str, assets_folder: str ) -> None: # pragma: no cover - interface + """Register a blueprint/router for serving static assets. + + Args: + blueprint_name: Name for the assets blueprint + assets_url_path: URL path prefix for assets + assets_folder: Filesystem path to the assets folder + """ pass @abstractmethod def register_error_handlers(self) -> None: # pragma: no cover - interface + """Register error handlers for common HTTP errors.""" pass @abstractmethod def add_url_rule( self, rule: str, view_func, endpoint=None, methods=None ) -> None: # pragma: no cover - interface + """Add a URL routing rule. + + Args: + rule: URL pattern/route + view_func: View function to handle the route + endpoint: Optional endpoint name + methods: Optional list of HTTP methods (e.g., ['GET', 'POST']) + """ pass @abstractmethod def before_request(self, func) -> None: # pragma: no cover - interface + """Register a function to run before each request. + + Args: + func: Function to execute before request handling + """ pass @abstractmethod def after_request(self, func) -> None: # pragma: no cover - interface + """Register a function to run after each request. + + Args: + func: Function to execute after request handling + """ pass @abstractmethod def has_request_context(self) -> bool: # pragma: no cover - interface + """Check if currently executing within a request context. + + Returns: + True if in request context, False otherwise + """ pass @abstractmethod def run( self, dash_app, host: str, port: int, debug: bool, **kwargs ) -> None: # pragma: no cover - interface + """Start the development server. + + Args: + dash_app: The Dash application instance + host: Hostname to bind to + port: Port number to bind to + debug: Enable debug mode + **kwargs: Additional server-specific arguments + """ pass @abstractmethod @@ -142,43 +242,108 @@ def make_response( content_type=None, status=None, ) -> Any: # pragma: no cover - interface + """Create an HTTP response object. + + Args: + data: Response body data + mimetype: MIME type of the response + content_type: Content-Type header value + status: HTTP status code + + Returns: + Server-specific response object + """ pass @abstractmethod def jsonify(self, obj) -> Any: # pragma: no cover - interface + """Convert an object to a JSON response. + + Args: + obj: Object to serialize to JSON + + Returns: + JSON response object + """ pass @abstractmethod def enable_compression(self) -> None: # pragma: no cover - interface + """Enable HTTP compression for responses.""" pass @abstractmethod def register_prune_error_handler(self, secret: str, prune_errors: bool) -> None: + """Register handler for pruning error stack traces. + + Args: + secret: Secret key for error handling + prune_errors: Whether to prune stack traces in errors + """ pass @abstractmethod def register_timing_hooks(self, first_run: bool) -> None: + """Register hooks for timing request/response cycles. + + Args: + first_run: Whether this is the first run of the application + """ pass @abstractmethod def register_callback_api_routes(self, callback_api_paths): + """Register routes for Dash callback API endpoints. + + Args: + callback_api_paths: Paths for callback API endpoints + """ pass @abstractmethod def setup_component_suites(self, dash_app: "dash.Dash") -> str: + """Set up routes for serving component JavaScript bundles. + + Args: + dash_app: The Dash application instance + + Returns: + Base path for component suites + """ pass @abstractmethod def serve_callback(self, dash_app: "dash.Dash"): + """Set up the callback handling endpoint. + + Args: + dash_app: The Dash application instance + """ pass @abstractmethod def setup_index(self, dash_app: "dash.Dash"): + """Set up the index/root route for serving the main application. + + Args: + dash_app: The Dash application instance + """ pass @abstractmethod def setup_catchall(self, dash_app: "dash.Dash"): + """Set up the catchall route for client-side routing. + + Args: + dash_app: The Dash application instance + """ pass def setup_backend(self, dash_app: "dash.Dash"): - """Override to provide any other required setup""" + """Perform any additional backend-specific setup. + + Override this method in concrete implementations to provide custom setup logic. + + Args: + dash_app: The Dash application instance + """ From 6e616f5a102bcc661931278d195957021c84875c Mon Sep 17 00:00:00 2001 From: philippe Date: Thu, 19 Feb 2026 11:17:37 -0500 Subject: [PATCH 102/166] fix fastapi traceback debugger --- dash/backends/_fastapi.py | 13 +++++-------- dash/backends/base_server.py | 21 +++------------------ 2 files changed, 8 insertions(+), 26 deletions(-) diff --git a/dash/backends/_fastapi.py b/dash/backends/_fastapi.py index f781300bdf..6476ca7a00 100644 --- a/dash/backends/_fastapi.py +++ b/dash/backends/_fastapi.py @@ -65,19 +65,17 @@ def __init__( self, app: ASGIApp, dash_app: Dash, + dash_server: FastAPIDashServer, before_request_funcs: list, after_request_func: Callable | None = None, enable_timing: bool = False, - error_handling_mode: str = "ignore", - get_traceback_func: Callable | None = None, ) -> None: self.app = app self.dash_app = dash_app + self.dash_server = dash_server self.before_request_funcs = before_request_funcs self.after_request_func = after_request_func self.enable_timing = enable_timing - self.error_handling_mode = error_handling_mode - self.get_traceback_func = get_traceback_func self._dev_tools_initialized = False async def _initialize_dev_tools(self) -> None: @@ -129,8 +127,8 @@ async def _handle_error( """Handle exceptions during request processing.""" if isinstance(error, PreventUpdate): response = Response(status_code=204) - elif self.error_handling_mode in ["raise", "prune"] and self.get_traceback_func: - tb = self.get_traceback_func(None, error) + elif self.dash_server.error_handling_mode in ["raise", "prune"]: + tb = self.dash_server._get_traceback(None, error) # pylint: disable=W0212 response = Response(content=tb, media_type="text/html", status_code=500) else: response = JSONResponse( @@ -523,11 +521,10 @@ def setup_backend(self, dash_app: Dash): self.server.add_middleware( DashMiddleware, dash_app=dash_app, + dash_server=self, before_request_funcs=self._before_request_funcs, after_request_func=self._after_request_func, enable_timing=self._enable_timing, - error_handling_mode=self.error_handling_mode, - get_traceback_func=self._get_traceback, ) # Add timing middleware separately if enabled (needs to modify response headers) diff --git a/dash/backends/base_server.py b/dash/backends/base_server.py index d88d8422f7..5606da8824 100644 --- a/dash/backends/base_server.py +++ b/dash/backends/base_server.py @@ -16,6 +16,7 @@ class _ServerCallable(Protocol): # pylint: disable=too-few-public-methods Defines the interface for server objects that can be called as WSGI/ASGI applications. """ + def __call__(self, *args: Any, **kwds: Any) -> Any: raise NotImplementedError @@ -30,6 +31,7 @@ class RequestAdapter(ABC): the underlying web framework (Flask, Quart, FastAPI, etc.). Concrete implementations wrap framework-specific request objects and expose their data through these properties. """ + def __call__(self) -> "RequestAdapter": return self @@ -122,6 +124,7 @@ class BaseDashServer(ABC, Generic[ServerType]): config: Configuration dictionary for the server request_adapter: RequestAdapter class for normalizing requests """ + server_type: str server: ServerType config: Dict[str, Any] @@ -158,7 +161,6 @@ def create_app( Returns: The server application instance """ - pass @abstractmethod def register_assets_blueprint( @@ -171,12 +173,10 @@ def register_assets_blueprint( assets_url_path: URL path prefix for assets assets_folder: Filesystem path to the assets folder """ - pass @abstractmethod def register_error_handlers(self) -> None: # pragma: no cover - interface """Register error handlers for common HTTP errors.""" - pass @abstractmethod def add_url_rule( @@ -190,7 +190,6 @@ def add_url_rule( endpoint: Optional endpoint name methods: Optional list of HTTP methods (e.g., ['GET', 'POST']) """ - pass @abstractmethod def before_request(self, func) -> None: # pragma: no cover - interface @@ -199,7 +198,6 @@ def before_request(self, func) -> None: # pragma: no cover - interface Args: func: Function to execute before request handling """ - pass @abstractmethod def after_request(self, func) -> None: # pragma: no cover - interface @@ -208,7 +206,6 @@ def after_request(self, func) -> None: # pragma: no cover - interface Args: func: Function to execute after request handling """ - pass @abstractmethod def has_request_context(self) -> bool: # pragma: no cover - interface @@ -217,7 +214,6 @@ def has_request_context(self) -> bool: # pragma: no cover - interface Returns: True if in request context, False otherwise """ - pass @abstractmethod def run( @@ -232,7 +228,6 @@ def run( debug: Enable debug mode **kwargs: Additional server-specific arguments """ - pass @abstractmethod def make_response( @@ -253,7 +248,6 @@ def make_response( Returns: Server-specific response object """ - pass @abstractmethod def jsonify(self, obj) -> Any: # pragma: no cover - interface @@ -265,12 +259,10 @@ def jsonify(self, obj) -> Any: # pragma: no cover - interface Returns: JSON response object """ - pass @abstractmethod def enable_compression(self) -> None: # pragma: no cover - interface """Enable HTTP compression for responses.""" - pass @abstractmethod def register_prune_error_handler(self, secret: str, prune_errors: bool) -> None: @@ -280,7 +272,6 @@ def register_prune_error_handler(self, secret: str, prune_errors: bool) -> None: secret: Secret key for error handling prune_errors: Whether to prune stack traces in errors """ - pass @abstractmethod def register_timing_hooks(self, first_run: bool) -> None: @@ -289,7 +280,6 @@ def register_timing_hooks(self, first_run: bool) -> None: Args: first_run: Whether this is the first run of the application """ - pass @abstractmethod def register_callback_api_routes(self, callback_api_paths): @@ -298,7 +288,6 @@ def register_callback_api_routes(self, callback_api_paths): Args: callback_api_paths: Paths for callback API endpoints """ - pass @abstractmethod def setup_component_suites(self, dash_app: "dash.Dash") -> str: @@ -310,7 +299,6 @@ def setup_component_suites(self, dash_app: "dash.Dash") -> str: Returns: Base path for component suites """ - pass @abstractmethod def serve_callback(self, dash_app: "dash.Dash"): @@ -319,7 +307,6 @@ def serve_callback(self, dash_app: "dash.Dash"): Args: dash_app: The Dash application instance """ - pass @abstractmethod def setup_index(self, dash_app: "dash.Dash"): @@ -328,7 +315,6 @@ def setup_index(self, dash_app: "dash.Dash"): Args: dash_app: The Dash application instance """ - pass @abstractmethod def setup_catchall(self, dash_app: "dash.Dash"): @@ -337,7 +323,6 @@ def setup_catchall(self, dash_app: "dash.Dash"): Args: dash_app: The Dash application instance """ - pass def setup_backend(self, dash_app: "dash.Dash"): """Perform any additional backend-specific setup. From a263af2b90fd58658461b9becf774a870d43ba8f Mon Sep 17 00:00:00 2001 From: BSd3v <82055130+BSd3v@users.noreply.github.com> Date: Thu, 19 Feb 2026 14:32:11 -0500 Subject: [PATCH 103/166] allowing a common method for cookies and headers on a new response_adapter --- dash/backends/_fastapi.py | 26 ++++++++++++++++++++--- dash/backends/_flask.py | 33 ++++++++++++++++++++++++----- dash/backends/_quart.py | 27 ++++++++++++++++++++++-- dash/backends/base_server.py | 41 ++++++++++++++++++++++++++++++++++++ dash/dash.py | 4 +--- 5 files changed, 118 insertions(+), 13 deletions(-) diff --git a/dash/backends/_fastapi.py b/dash/backends/_fastapi.py index 6476ca7a00..ed652cab0f 100644 --- a/dash/backends/_fastapi.py +++ b/dash/backends/_fastapi.py @@ -30,13 +30,33 @@ from dash.fingerprint import check_fingerprint from dash import _validate from dash.exceptions import PreventUpdate -from .base_server import BaseDashServer, RequestAdapter +from .base_server import BaseDashServer, RequestAdapter, ResponseAdapter from ._utils import format_traceback_html if TYPE_CHECKING: # pragma: no cover - typing only from dash import Dash +class FastAPIResponseAdapter(ResponseAdapter): + """ + A custom Response class that wraps FastAPI's JSONResponse + and provides a set_response() method for compatibility with Dash's callback system. + """ + + def set_response(self, **kwargs): + """ + Set the response data. This method provides compatibility with Flask's Response.set_data(). + """ + data = kwargs.get("data") + if isinstance(data, (str, bytes, bytearray)): + resp = Response(content=data, headers=self._headers) + else: + resp = JSONResponse(content=data, headers=self._headers) + if self._cookies: + for key, (value, cookie_kwargs) in self._cookies.items(): + resp.set_cookie(key, value, **cookie_kwargs) + return resp + _current_request_var = ContextVar("dash_current_request", default=None) @@ -177,6 +197,7 @@ def __init__(self, server: FastAPI): self.server_type = "fastapi" self.error_handling_mode = "ignore" self.request_adapter = FastAPIRequestAdapter + self.response_adapter = FastAPIResponseAdapter self._before_request_funcs = [] self._after_request_func = None self._enable_timing = False @@ -461,8 +482,7 @@ async def _dispatch(request: Request): response_data = ctx.run(partial_func) if inspect.iscoroutine(response_data): response_data = await response_data - # Instead of set_data, return a new Response - return Response(content=response_data, media_type="application/json") + return cb_ctx.dash_response.set_response(data=response_data) return _dispatch diff --git a/dash/backends/_flask.py b/dash/backends/_flask.py index af18961b55..b9630c02fe 100644 --- a/dash/backends/_flask.py +++ b/dash/backends/_flask.py @@ -30,18 +30,43 @@ from dash.exceptions import PreventUpdate, InvalidResourceError from dash._callback import _invoke_callback, _async_invoke_callback from dash._utils import parse_version -from .base_server import BaseDashServer, RequestAdapter +from .base_server import BaseDashServer, RequestAdapter, ResponseAdapter if TYPE_CHECKING: # pragma: no cover - typing only from dash import Dash +class FlaskResponseAdapter(ResponseAdapter): + """ + A custom Response class that wraps Flask's Response + and provides a set_response() method for compatibility with Dash's callback system. + """ + + def __init__(self): + self._flask_response = Response(content_type='application/json') + super().__init__() + + @property + def callback_response(self) -> Response: + return self._flask_response + + def set_cookie(self, key, value='', **kwargs): + self._flask_response.set_cookie(key, value, **kwargs) + + def set_header(self, key, value): + self._flask_response.headers.add(key, value) + + def set_response(self, **kwargs): + self._flask_response.set_data(kwargs.get('data','')) + return self._flask_response + class FlaskDashServer(BaseDashServer[Flask]): def __init__(self, server: Flask) -> None: super().__init__(server) self.server_type = "flask" self.request_adapter = FlaskRequestAdapter + self.response_adapter = FlaskResponseAdapter def __call__(self, *args: Any, **kwargs: Any): # Always WSGI @@ -239,8 +264,7 @@ def _dispatch(): "Please install the dependencies via `pip install dash[async]` and ensure " "that `use_async=False` is not being passed to the app." ) - cb_ctx.dash_response.set_data(response_data) - return cb_ctx.dash_response + return cb_ctx.dash_response.set_response(data=response_data) async def _dispatch_async(): body = request.get_json() @@ -255,8 +279,7 @@ async def _dispatch_async(): response_data = ctx.run(partial_func) if asyncio.iscoroutine(response_data): response_data = await response_data - cb_ctx.dash_response.set_data(response_data) - return cb_ctx.dash_response + return cb_ctx.dash_response.set_response(data=response_data) if dash_app._use_async: # pylint: disable=protected-access return _dispatch_async diff --git a/dash/backends/_quart.py b/dash/backends/_quart.py index 92f67c2205..d793e19574 100644 --- a/dash/backends/_quart.py +++ b/dash/backends/_quart.py @@ -34,9 +34,31 @@ from dash.fingerprint import check_fingerprint from dash._utils import parse_version from dash import _validate, Dash -from .base_server import BaseDashServer, RequestAdapter +from .base_server import BaseDashServer, RequestAdapter, ResponseAdapter from ._utils import format_traceback_html +class QuartResponseAdapter(ResponseAdapter): + """ + A custom Response class that wraps Quart's Response + and provides a set_response() method for compatibility with Dash's callback system. + """ + def __init__(self): + self._quart_response = Response(content_type='application/json') + super().__init__() + + @property + def callback_response(self) -> Response: + return self._quart_response + + def set_cookie(self, key, value='', **kwargs): + self._quart_response.set_cookie(key, value, **kwargs) + + def set_header(self, key, value): + self._quart_response.headers.add(key, value) + + def set_response(self, **kwargs): + self._quart_response.set_data(kwargs.get('data','')) + return self._quart_response class QuartDashServer(BaseDashServer[Quart]): def __init__(self, server: Quart) -> None: @@ -45,6 +67,7 @@ def __init__(self, server: Quart) -> None: self.config = {} self.error_handling_mode = "ignore" self.request_adapter = QuartRequestAdapter + self.response_adapter = QuartResponseAdapter def __call__(self, *args: Any, **kwargs: Any): # type: ignore[name-defined] return self.server(*args, **kwargs) @@ -293,7 +316,7 @@ async def _dispatch(): response_data = ctx.run(partial_func) if inspect.iscoroutine(response_data): # if user callback is async response_data = await response_data - return Response(response_data, content_type="application/json") # type: ignore[arg-type] + return cb_ctx.dash_response.set_response(data=response_data) # type: ignore[arg-type] return _dispatch diff --git a/dash/backends/base_server.py b/dash/backends/base_server.py index 5606da8824..82e62299b4 100644 --- a/dash/backends/base_server.py +++ b/dash/backends/base_server.py @@ -110,6 +110,46 @@ def path(self) -> str: """Get the request path without query string.""" raise NotImplementedError() +class ResponseAdapter: + """Adapter for server response objects to allow setting data.""" + + def __init__(self): + # Accept a pre-made response object + self._headers = {} + self._cookies = {} + + @property + def callback_response(self): + """Get the response object to be returned from a callback.""" + # This method should be overridden in concrete implementations to return the appropriate response object + raise NotImplementedError() + + def set_cookie(self, key, value='', **kwargs): + """Set a cookie in the response (like Flask's set_cookie).""" + # Store as a tuple: (value, kwargs) + self._cookies[key] = (value, kwargs) + + def set_header(self, key, value): + """Add a header to the response (like Flask's headers.add).""" + # Allow multiple values per header key + if key in self._headers: + if isinstance(self._headers[key], list): + self._headers[key].append(value) + else: + self._headers[key] = [self._headers[key], value] + else: + self._headers[key] = value + + def set_response(self, **kwargs): + """Set the response data if supported by the response object.""" + raise NotImplementedError() + + + @property + def response(self): + """Get the underlying response object.""" + return self._response + class BaseDashServer(ABC, Generic[ServerType]): """Abstract base class for Dash server backend implementations. @@ -129,6 +169,7 @@ class BaseDashServer(ABC, Generic[ServerType]): server: ServerType config: Dict[str, Any] request_adapter: Type[RequestAdapter] + response_adapter: Type[ResponseAdapter] def __init__(self, server: ServerType) -> None: """Initialize the server wrapper. diff --git a/dash/dash.py b/dash/dash.py index 036ca3bd3b..71f7f33cb4 100644 --- a/dash/dash.py +++ b/dash/dash.py @@ -1407,9 +1407,7 @@ def _initialize_context(self, body): {"prop_id": x, "value": g.input_values.get(x)} for x in body.get("changedPropIds", []) ] - g.dash_response = self.backend.make_response( - mimetype="application/json", data=None - ) + g.dash_response = self.backend.response_adapter() g.cookies = dict(adapter.cookies) g.headers = dict(adapter.headers) g.args = adapter.args From b78853a9ad25b5f9a36f8991a4456eb9329e6fd1 Mon Sep 17 00:00:00 2001 From: BSd3v <82055130+BSd3v@users.noreply.github.com> Date: Thu, 19 Feb 2026 14:53:06 -0500 Subject: [PATCH 104/166] fix for lint --- dash/backends/_fastapi.py | 1 + dash/backends/_flask.py | 7 ++++--- dash/backends/_quart.py | 11 +++++++---- 3 files changed, 12 insertions(+), 7 deletions(-) diff --git a/dash/backends/_fastapi.py b/dash/backends/_fastapi.py index ed652cab0f..cca9a11d8e 100644 --- a/dash/backends/_fastapi.py +++ b/dash/backends/_fastapi.py @@ -57,6 +57,7 @@ def set_response(self, **kwargs): resp.set_cookie(key, value, **cookie_kwargs) return resp + _current_request_var = ContextVar("dash_current_request", default=None) diff --git a/dash/backends/_flask.py b/dash/backends/_flask.py index b9630c02fe..930b7a3a4f 100644 --- a/dash/backends/_flask.py +++ b/dash/backends/_flask.py @@ -36,6 +36,7 @@ if TYPE_CHECKING: # pragma: no cover - typing only from dash import Dash + class FlaskResponseAdapter(ResponseAdapter): """ A custom Response class that wraps Flask's Response @@ -43,21 +44,21 @@ class FlaskResponseAdapter(ResponseAdapter): """ def __init__(self): - self._flask_response = Response(content_type='application/json') + self._flask_response = Response(content_type="application/json") super().__init__() @property def callback_response(self) -> Response: return self._flask_response - def set_cookie(self, key, value='', **kwargs): + def set_cookie(self, key, value="", **kwargs): self._flask_response.set_cookie(key, value, **kwargs) def set_header(self, key, value): self._flask_response.headers.add(key, value) def set_response(self, **kwargs): - self._flask_response.set_data(kwargs.get('data','')) + self._flask_response.set_data(kwargs.get("data", "")) return self._flask_response diff --git a/dash/backends/_quart.py b/dash/backends/_quart.py index d793e19574..cf197e2d60 100644 --- a/dash/backends/_quart.py +++ b/dash/backends/_quart.py @@ -37,29 +37,32 @@ from .base_server import BaseDashServer, RequestAdapter, ResponseAdapter from ._utils import format_traceback_html + class QuartResponseAdapter(ResponseAdapter): """ A custom Response class that wraps Quart's Response and provides a set_response() method for compatibility with Dash's callback system. """ + def __init__(self): - self._quart_response = Response(content_type='application/json') + self._quart_response = Response(content_type="application/json") super().__init__() @property def callback_response(self) -> Response: return self._quart_response - def set_cookie(self, key, value='', **kwargs): + def set_cookie(self, key, value="", **kwargs): self._quart_response.set_cookie(key, value, **kwargs) def set_header(self, key, value): self._quart_response.headers.add(key, value) def set_response(self, **kwargs): - self._quart_response.set_data(kwargs.get('data','')) + self._quart_response.set_data(kwargs.get("data", "")) return self._quart_response + class QuartDashServer(BaseDashServer[Quart]): def __init__(self, server: Quart) -> None: super().__init__(server) @@ -316,7 +319,7 @@ async def _dispatch(): response_data = ctx.run(partial_func) if inspect.iscoroutine(response_data): # if user callback is async response_data = await response_data - return cb_ctx.dash_response.set_response(data=response_data) # type: ignore[arg-type] + return cb_ctx.dash_response.set_response(data=response_data) # type: ignore[arg-type] return _dispatch From 7d7321c244e880dc37dbfa6a0baa9606dd8af97d Mon Sep 17 00:00:00 2001 From: BSd3v <82055130+BSd3v@users.noreply.github.com> Date: Thu, 19 Feb 2026 15:43:48 -0500 Subject: [PATCH 105/166] fix for lint --- dash/backends/base_server.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/dash/backends/base_server.py b/dash/backends/base_server.py index 82e62299b4..09367f285d 100644 --- a/dash/backends/base_server.py +++ b/dash/backends/base_server.py @@ -110,6 +110,7 @@ def path(self) -> str: """Get the request path without query string.""" raise NotImplementedError() + class ResponseAdapter: """Adapter for server response objects to allow setting data.""" @@ -124,7 +125,7 @@ def callback_response(self): # This method should be overridden in concrete implementations to return the appropriate response object raise NotImplementedError() - def set_cookie(self, key, value='', **kwargs): + def set_cookie(self, key, value="", **kwargs): """Set a cookie in the response (like Flask's set_cookie).""" # Store as a tuple: (value, kwargs) self._cookies[key] = (value, kwargs) @@ -144,7 +145,6 @@ def set_response(self, **kwargs): """Set the response data if supported by the response object.""" raise NotImplementedError() - @property def response(self): """Get the underlying response object.""" From dd0d8e21577504f7fc94a9c0c2f38b38327da6e2 Mon Sep 17 00:00:00 2001 From: BSd3v <82055130+BSd3v@users.noreply.github.com> Date: Thu, 19 Feb 2026 16:23:20 -0500 Subject: [PATCH 106/166] fix for lint --- dash/backends/base_server.py | 5 ----- 1 file changed, 5 deletions(-) diff --git a/dash/backends/base_server.py b/dash/backends/base_server.py index 09367f285d..32c708a55a 100644 --- a/dash/backends/base_server.py +++ b/dash/backends/base_server.py @@ -145,11 +145,6 @@ def set_response(self, **kwargs): """Set the response data if supported by the response object.""" raise NotImplementedError() - @property - def response(self): - """Get the underlying response object.""" - return self._response - class BaseDashServer(ABC, Generic[ServerType]): """Abstract base class for Dash server backend implementations. From 2bd80f497882a0aaf0ee1302fb5256a53a3f8fdf Mon Sep 17 00:00:00 2001 From: BSd3v <82055130+BSd3v@users.noreply.github.com> Date: Thu, 19 Feb 2026 16:45:25 -0500 Subject: [PATCH 107/166] fix for lint --- dash/backends/_fastapi.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/dash/backends/_fastapi.py b/dash/backends/_fastapi.py index cca9a11d8e..782ebab7d4 100644 --- a/dash/backends/_fastapi.py +++ b/dash/backends/_fastapi.py @@ -43,6 +43,14 @@ class FastAPIResponseAdapter(ResponseAdapter): and provides a set_response() method for compatibility with Dash's callback system. """ + @property + def callback_response(self): + """Get the response object to be returned from a callback.""" + print( + "Cannot access callback_response directly on FastAPIResponseAdapter. Use set_response() to create a response with data." + ) + raise NotImplementedError() + def set_response(self, **kwargs): """ Set the response data. This method provides compatibility with Flask's Response.set_data(). From 20fb94f5bf118bfb4e8d5853f8a4f074f2d459f1 Mon Sep 17 00:00:00 2001 From: BSd3v <82055130+BSd3v@users.noreply.github.com> Date: Thu, 19 Feb 2026 17:37:17 -0500 Subject: [PATCH 108/166] adding tests for cookies and headers --- .../backend_tests/test_preconfig_backends.py | 31 ++++++++++++++++++- 1 file changed, 30 insertions(+), 1 deletion(-) diff --git a/tests/backend_tests/test_preconfig_backends.py b/tests/backend_tests/test_preconfig_backends.py index 3193305ee2..e1c0dac0d2 100644 --- a/tests/backend_tests/test_preconfig_backends.py +++ b/tests/backend_tests/test_preconfig_backends.py @@ -1,6 +1,35 @@ import logging import pytest -from dash import Dash, Input, Output, html, dcc +from dash import Dash, Input, Output, html, dcc, ctx + + +@pytest.mark.parametrize( + "backend,fixture", + [ + ("flask", "dash_duo"), + ("fastapi", "dash_duo"), + ("quart", "dash_duo_mp"), + ], +) +def test_set_cookie_and_header(request, backend, fixture): + dash_duo = request.getfixturevalue(fixture) + app = Dash(__name__, backend=backend) + app.layout = html.Div([html.Button("Set", id="btn"), html.Div(id="output")]) + + @app.callback(Output("output", "children"), Input("btn", "n_clicks")) + def set_cookie_and_header(n): + if ctx.response: + ctx.response.set_cookie("mycookie", "cookieval") + ctx.response.set_header("X-My-Header", "HeaderVal") + return f"Clicked {n}" if n else "Not clicked" + + dash_duo.start_server(app) + dash_duo.find_element("#btn").click() + dash_duo.wait_for_text_to_equal("#output", "Clicked 1") + + # Check cookie + cookies = dash_duo.driver.get_cookies() + assert any(c["name"] == "mycookie" and c["value"] == "cookieval" for c in cookies) @pytest.mark.parametrize( From 7f220aa07edd4e0629a90398fd9bc62ede838059 Mon Sep 17 00:00:00 2001 From: BSd3v <82055130+BSd3v@users.noreply.github.com> Date: Fri, 20 Feb 2026 05:36:52 -0500 Subject: [PATCH 109/166] allowing for header test --- tests/backend_tests/test_preconfig_backends.py | 18 ++++++++++++++++++ 1 file changed, 18 insertions(+) diff --git a/tests/backend_tests/test_preconfig_backends.py b/tests/backend_tests/test_preconfig_backends.py index e1c0dac0d2..0bcd65cfe1 100644 --- a/tests/backend_tests/test_preconfig_backends.py +++ b/tests/backend_tests/test_preconfig_backends.py @@ -24,6 +24,21 @@ def set_cookie_and_header(n): return f"Clicked {n}" if n else "Not clicked" dash_duo.start_server(app) + dash_duo.driver.execute_script( + """ + window._lastResponseHeaders = null; + const origFetch = window.fetch; + window.fetch = async function() { + const response = await origFetch.apply(this, arguments); + response.clone().headers.forEach((v, k) => { + if (!window._lastResponseHeaders) window._lastResponseHeaders = {}; + window._lastResponseHeaders[k] = v; + }); + return response; + }; + """ + ) + dash_duo.find_element("#btn").click() dash_duo.wait_for_text_to_equal("#output", "Clicked 1") @@ -31,6 +46,9 @@ def set_cookie_and_header(n): cookies = dash_duo.driver.get_cookies() assert any(c["name"] == "mycookie" and c["value"] == "cookieval" for c in cookies) + headers = dash_duo.driver.execute_script("return window._lastResponseHeaders;") + assert headers and headers["x-my-header"] == "HeaderVal" + @pytest.mark.parametrize( "backend,fixture,input_value", From a9d55c1c3f2222404760c288114e03c371aab69a Mon Sep 17 00:00:00 2001 From: BSd3v <82055130+BSd3v@users.noreply.github.com> Date: Fri, 20 Feb 2026 10:08:44 -0500 Subject: [PATCH 110/166] adjustments for append and set methods of headers, and adjusting the test --- dash/backends/_fastapi.py | 11 +++++++++-- dash/backends/_flask.py | 5 ++++- dash/backends/_quart.py | 5 ++++- dash/backends/base_server.py | 6 +++++- tests/backend_tests/test_preconfig_backends.py | 6 +++++- 5 files changed, 27 insertions(+), 6 deletions(-) diff --git a/dash/backends/_fastapi.py b/dash/backends/_fastapi.py index 782ebab7d4..0b27f43692 100644 --- a/dash/backends/_fastapi.py +++ b/dash/backends/_fastapi.py @@ -57,9 +57,16 @@ def set_response(self, **kwargs): """ data = kwargs.get("data") if isinstance(data, (str, bytes, bytearray)): - resp = Response(content=data, headers=self._headers) + resp = Response(content=data) else: - resp = JSONResponse(content=data, headers=self._headers) + resp = JSONResponse(content=data) + if self._headers: + for key, value in self._headers.items(): + if isinstance(value, list): + for v in value: + resp.headers.append(key, v) + else: + resp.headers[key] = value if self._cookies: for key, (value, cookie_kwargs) in self._cookies.items(): resp.set_cookie(key, value, **cookie_kwargs) diff --git a/dash/backends/_flask.py b/dash/backends/_flask.py index 930b7a3a4f..00c8730d8a 100644 --- a/dash/backends/_flask.py +++ b/dash/backends/_flask.py @@ -54,9 +54,12 @@ def callback_response(self) -> Response: def set_cookie(self, key, value="", **kwargs): self._flask_response.set_cookie(key, value, **kwargs) - def set_header(self, key, value): + def append_header(self, key, value): self._flask_response.headers.add(key, value) + def set_header(self, key, value): + self._flask_response.headers.set(key, value) + def set_response(self, **kwargs): self._flask_response.set_data(kwargs.get("data", "")) return self._flask_response diff --git a/dash/backends/_quart.py b/dash/backends/_quart.py index cf197e2d60..ddf31ff2f4 100644 --- a/dash/backends/_quart.py +++ b/dash/backends/_quart.py @@ -55,9 +55,12 @@ def callback_response(self) -> Response: def set_cookie(self, key, value="", **kwargs): self._quart_response.set_cookie(key, value, **kwargs) - def set_header(self, key, value): + def append_header(self, key, value): self._quart_response.headers.add(key, value) + def set_header(self, key, value): + self._quart_response.headers.set(key, value) + def set_response(self, **kwargs): self._quart_response.set_data(kwargs.get("data", "")) return self._quart_response diff --git a/dash/backends/base_server.py b/dash/backends/base_server.py index 32c708a55a..f7211f44a6 100644 --- a/dash/backends/base_server.py +++ b/dash/backends/base_server.py @@ -130,7 +130,7 @@ def set_cookie(self, key, value="", **kwargs): # Store as a tuple: (value, kwargs) self._cookies[key] = (value, kwargs) - def set_header(self, key, value): + def append_header(self, key, value): """Add a header to the response (like Flask's headers.add).""" # Allow multiple values per header key if key in self._headers: @@ -141,6 +141,10 @@ def set_header(self, key, value): else: self._headers[key] = value + def set_header(self, key, value): + """Set a header to the response.""" + self._headers[key] = [value] + def set_response(self, **kwargs): """Set the response data if supported by the response object.""" raise NotImplementedError() diff --git a/tests/backend_tests/test_preconfig_backends.py b/tests/backend_tests/test_preconfig_backends.py index 0bcd65cfe1..eec832070a 100644 --- a/tests/backend_tests/test_preconfig_backends.py +++ b/tests/backend_tests/test_preconfig_backends.py @@ -21,6 +21,9 @@ def set_cookie_and_header(n): if ctx.response: ctx.response.set_cookie("mycookie", "cookieval") ctx.response.set_header("X-My-Header", "HeaderVal") + ctx.response.append_header("X-My-Header", "HeaderVal2") + ctx.response.append_header("X-My-Header2", "HeaderVal3") + ctx.response.set_header("X-My-Header2", "HeaderVal4") return f"Clicked {n}" if n else "Not clicked" dash_duo.start_server(app) @@ -47,7 +50,8 @@ def set_cookie_and_header(n): assert any(c["name"] == "mycookie" and c["value"] == "cookieval" for c in cookies) headers = dash_duo.driver.execute_script("return window._lastResponseHeaders;") - assert headers and headers["x-my-header"] == "HeaderVal" + assert headers and headers["x-my-header"] == "HeaderVal, HeaderVal2" + assert headers and headers["x-my-header2"] == "HeaderVal4" @pytest.mark.parametrize( From c820fee6a87cbb7ac7dd4009c6024ea07a198784 Mon Sep 17 00:00:00 2001 From: philippe Date: Mon, 23 Feb 2026 12:27:59 -0500 Subject: [PATCH 111/166] version 4.1.0rc0 --- CHANGELOG.md | 10 ++++++++++ dash/version.py | 2 +- 2 files changed, 11 insertions(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 9cb22d4d02..5209f0dd11 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -2,6 +2,16 @@ All notable changes to `dash` will be documented in this file. This project adheres to [Semantic Versioning](https://semver.org/). +# [4.1.0rc0] - 2026-02-23 + +## Added + +- Add support for multiple backend implementation beside flask such as fastapi and quart (both included). + - Add `app = Dash(backend="flask" | "fastapi" | "quart" | CustomBackendImpl)` parameter to automatically setup + - An existing `Fastapi`, `Quart` or `Flask` instance can also be given as `app = Dash(server=Fastapi())` to automatically setup a dash app on the server. + - Install fastapi dependencies with `pip install dash[fastapi]` or quart with `pip install dash[quart]`, flask is still included by default. + - Custom backend implementation can be added as a subclass of `dash.backends.base_server.BaseDashServer` and response/request adapters. + ## [4.0.0] - 2026-02-03 ## Added diff --git a/dash/version.py b/dash/version.py index ce1305bf4e..498a21e35c 100644 --- a/dash/version.py +++ b/dash/version.py @@ -1 +1 @@ -__version__ = "4.0.0" +__version__ = "4.1.0rc0" From e3939d8402ecb6534a87b490faa21f0c8c86eb35 Mon Sep 17 00:00:00 2001 From: Liam Connors Date: Mon, 23 Feb 2026 13:01:01 -0500 Subject: [PATCH 112/166] Update dash.py --- dash/dash.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/dash/dash.py b/dash/dash.py index 71f7f33cb4..e543bda31c 100644 --- a/dash/dash.py +++ b/dash/dash.py @@ -237,12 +237,11 @@ class Dash(ObsoleteChecker): best value to use. Default ``'__main__'``, env: ``DASH_APP_NAME`` :type name: string - :param server: Sets the Flask server for your app. There are three options: - ``True`` (default): Dash will create a new server + :param server: Sets the server for your app. There are three options: + ``True`` (default): Dash will create a new server using the specified backend ``False``: The server will be added later via ``app.init_app(server)`` - where ``server`` is a ``flask.Flask`` instance. - ``flask.Flask``: use this pre-existing Flask server. - :type server: boolean or flask.Flask + A server instance: Use a pre-existing server (Flask, Quart, or FastAPI) + :type server: boolean or server instance :param backend: The backend to use for the Dash app. Can be a string (name of the backend) or a backend class. Default is None, which From 5481f09e3987bb01ea5c248ee85a037ed16417d7 Mon Sep 17 00:00:00 2001 From: philippe Date: Tue, 24 Feb 2026 11:46:54 -0500 Subject: [PATCH 113/166] remove global backends.backend --- dash/_callback.py | 15 +++++++-------- dash/_callback_context.py | 6 +++--- dash/_get_app.py | 4 +++- dash/_validate.py | 19 ++++++++++++------- dash/backends/__init__.py | 4 ---- dash/dash.py | 11 ++++------- dash/exceptions.py | 4 ++++ 7 files changed, 33 insertions(+), 30 deletions(-) diff --git a/dash/_callback.py b/dash/_callback.py index 0ce4ee59a4..37a53d7ec5 100644 --- a/dash/_callback.py +++ b/dash/_callback.py @@ -20,7 +20,7 @@ BackgroundCallbackError, ImportedInsideCallbackError, ) - +from ._get_app import get_app from ._grouping import ( flatten_grouping, make_grouping_by_index, @@ -39,7 +39,6 @@ from ._callback_context import context_value from ._no_update import NoUpdate from . import _validate -from . import backends async def _async_invoke_callback( @@ -373,7 +372,7 @@ def _get_callback_manager( " and store results on redis.\n" ) - adapter = backends.backend.request_adapter() + adapter = get_app().backend.request_adapter() old_job = adapter.args.getlist("oldJob") if hasattr(adapter.args, "getlist") else [] if old_job: @@ -433,7 +432,7 @@ def _setup_background_callback( def _progress_background_callback(response, callback_manager, background): progress_outputs = background.get("progress") - adapter = backends.backend.request_adapter() + adapter = get_app().backend.request_adapter() cache_key = adapter.args.get("cacheKey") if progress_outputs: @@ -451,7 +450,7 @@ def _update_background_callback( """Set up the background callback and manage jobs.""" callback_manager = _get_callback_manager(kwargs, background) - adapter = backends.backend.request_adapter() + adapter = get_app().backend.request_adapter() cache_key = adapter.args.get("cacheKey") if adapter else None job_id = adapter.args.get("job") if adapter else None @@ -473,7 +472,7 @@ def _handle_rest_background_callback( multi, has_update=False, ): - adapter = backends.backend.request_adapter() + adapter = get_app().backend.request_adapter() cache_key = adapter.args.get("cacheKey") if adapter else None job_id = adapter.args.get("job") if adapter else None # Must get job_running after get_result since get_results terminates it. @@ -691,7 +690,7 @@ def add_context(*args, **kwargs): jsonResponse: Optional[str] = None try: if background is not None: - adapter = backends.backend.request_adapter() + adapter = get_app().backend.request_adapter() if not (adapter and adapter.args.get("cacheKey")): return _setup_background_callback( kwargs, @@ -763,7 +762,7 @@ async def async_add_context(*args, **kwargs): try: if background is not None: - adapter = backends.backend.request_adapter() + adapter = get_app().backend.request_adapter() if not (adapter and adapter.args.get("cacheKey")): return _setup_background_callback( kwargs, diff --git a/dash/_callback_context.py b/dash/_callback_context.py index 10cfb20055..646db990ab 100644 --- a/dash/_callback_context.py +++ b/dash/_callback_context.py @@ -5,7 +5,7 @@ import typing from . import exceptions -from . import backends +from ._get_app import get_app from ._utils import AttributeDict, stringify_id @@ -221,7 +221,7 @@ def record_timing(name, duration, description=None): :param description: A description of the resource. :type description: string or None """ - request = backends.backend.request_adapter() + request = get_app().backend.request_adapter() timing_information = getattr(request.context, "timing_information", {}) if name in timing_information: @@ -252,7 +252,7 @@ def using_outputs_grouping(self): @property @has_context def timing_information(self): - request = backends.backend.request_adapter() + request = get_app().backend.request_adapter() return getattr(request.context, "timing_information", {}) @has_context diff --git a/dash/_get_app.py b/dash/_get_app.py index a64a7450cc..ab0b897f81 100644 --- a/dash/_get_app.py +++ b/dash/_get_app.py @@ -4,6 +4,8 @@ from textwrap import dedent from typing import Any, Optional +from dash.exceptions import AppNotFoundError + APP: Optional[Any] = None app_context: ContextVar[Any] = ContextVar("dash_app_context") @@ -55,7 +57,7 @@ def get_app(): pass if APP is None: - raise Exception( + raise AppNotFoundError( dedent( """ App object is not yet defined. `app = dash.Dash()` needs to be run diff --git a/dash/_validate.py b/dash/_validate.py index bb76f896e1..fb5689f850 100644 --- a/dash/_validate.py +++ b/dash/_validate.py @@ -9,6 +9,7 @@ from .development.base_component import Component from . import backends from . import exceptions +from ._get_app import get_app from ._utils import ( patch_collections_abc, stringify_id, @@ -510,13 +511,17 @@ def validate_use_pages(config): "`dash.register_page()` must be called after app instantiation" ) - if backends.backend.has_request_context(): - raise exceptions.PageError( - """ - dash.register_page() can’t be called within a callback as it updates dash.page_registry, which is a global variable. - For more details, see https://dash.plotly.com/sharing-data-between-callbacks#why-global-variables-will-break-your-app - """ - ) + try: + if get_app().backend.has_request_context(): + raise exceptions.PageError( + """ + dash.register_page() can’t be called within a callback as it updates dash.page_registry, which is a global variable. + For more details, see https://dash.plotly.com/sharing-data-between-callbacks#why-global-variables-will-break-your-app + """ + ) + except exceptions.AppNotFoundError: + # If the app is not found we can add pages since before instantiation. + pass def validate_module_name(module): diff --git a/dash/backends/__init__.py b/dash/backends/__init__.py index 3a12e7939a..585d34a65c 100644 --- a/dash/backends/__init__.py +++ b/dash/backends/__init__.py @@ -4,9 +4,6 @@ from .base_server import BaseDashServer -backend: BaseDashServer - - _backend_imports = { "flask": ("dash.backends._flask", "FlaskDashServer"), "fastapi": ("dash.backends._fastapi", "FastAPIDashServer"), @@ -74,6 +71,5 @@ def get_server_type(server): __all__ = [ "get_backend", - "backend", "get_server_type", ] diff --git a/dash/dash.py b/dash/dash.py index 71f7f33cb4..a46f3af99a 100644 --- a/dash/dash.py +++ b/dash/dash.py @@ -502,14 +502,10 @@ def __init__( # pylint: disable=too-many-statements, too-many-branches self.backend = backend_cls(server) self.server = server - backends.backend = self.backend # type: ignore - backends.request_adapter = self.backend.request_adapter # type: ignore else: # No server instance provided, create backend and let backend create server self.server = backend_cls.create_app(caller_name) # type: ignore self.backend = backend_cls(self.server) - backends.backend = self.backend - backends.request_adapter = self.backend.request_adapter # type: ignore base_prefix, routes_prefix, requests_prefix = pathname_configs( url_base_pathname, routes_pathname_prefix, requests_pathname_prefix @@ -642,6 +638,7 @@ def __init__( # pylint: disable=too-many-statements, too-many-branches self._got_first_request = {"pages": False, "setup_server": False} if server: + print(f"init app from server {server}") self.init_app() self.logger.setLevel(logging.INFO) @@ -1180,7 +1177,7 @@ def index(self, *_args, **_kwargs): renderer = self._generate_renderer() title = self.title # Refactored: direct access to global request adapter - request = backends.backend.request_adapter() + request = self.backend.request_adapter() if self.use_pages and self.config.include_pages_meta and request: metas = _page_meta_tags(self, request) + metas @@ -1396,7 +1393,7 @@ def _inputs_to_vals(self, inputs): # pylint: disable=R0915 def _initialize_context(self, body): """Initialize the global context for the request.""" - adapter = backends.backend.request_adapter() + adapter = self.backend.request_adapter() g = AttributeDict({}) g.inputs_list = body.get("inputs", []) g.states_list = body.get("state", []) @@ -2383,7 +2380,7 @@ def verify_url_part(served_part, url_part, part_name): server_url=jupyter_server_url, ) else: - backends.backend.run( + self.backend.run( dash_app=self, host=host, port=port, debug=debug, **flask_run_options ) diff --git a/dash/exceptions.py b/dash/exceptions.py index 00bd2c1553..019f0d2726 100644 --- a/dash/exceptions.py +++ b/dash/exceptions.py @@ -109,3 +109,7 @@ class ImportedInsideCallbackError(DashException): class HookError(DashException): pass + + +class AppNotFoundError(DashException): + pass From d95ddc3e36e0ccac83cb64b30c433e4848745bdb Mon Sep 17 00:00:00 2001 From: BSd3v <82055130+BSd3v@users.noreply.github.com> Date: Tue, 17 Mar 2026 11:16:23 -0400 Subject: [PATCH 114/166] fixing issue where FastAPI would not allow new paths to be added --- dash/backends/_fastapi.py | 51 ++++++++++++++++++++++++++++++++------- 1 file changed, 42 insertions(+), 9 deletions(-) diff --git a/dash/backends/_fastapi.py b/dash/backends/_fastapi.py index 0b27f43692..867e9a25b0 100644 --- a/dash/backends/_fastapi.py +++ b/dash/backends/_fastapi.py @@ -28,10 +28,11 @@ ) from _err from dash.fingerprint import check_fingerprint -from dash import _validate +from dash import _validate, get_app from dash.exceptions import PreventUpdate from .base_server import BaseDashServer, RequestAdapter, ResponseAdapter from ._utils import format_traceback_html +import traceback if TYPE_CHECKING: # pragma: no cover - typing only from dash import Dash @@ -122,8 +123,12 @@ async def _initialize_dev_tools(self) -> None: self.dash_app.enable_dev_tools(**config, first_run=False) self._dev_tools_initialized = True - def _setup_timing(self, request: Request) -> None: + async def _setup_timing(self, request: Request) -> None: """Set up timing information for the request.""" + try: + request.state.json_body = await request.json() if request.headers.get("content-type", "").startswith("application/json") else None + except: + request.state.json_body = None if self.enable_timing: request.state.timing_information = { "__dash_server": {"dur": time.time(), "desc": None} @@ -179,6 +184,12 @@ async def _handle_error( async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: # Handle lifespan events (startup/shutdown) if scope["type"] == "lifespan": + try: + dash_app = get_app() + dash_app.backend._setup_catchall() + except: + print("Error during catch-all setup:") + print(traceback.format_exc()) await self._initialize_dev_tools() await self.app(scope, receive, send) return @@ -193,7 +204,7 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: token = set_current_request(request) try: - self._setup_timing(request) + await self._setup_timing(request) await self._run_before_hooks() await self.app(scope, receive, send) @@ -275,11 +286,24 @@ async def index(_request: Request): dash_app._add_url("", index, methods=["GET"]) def setup_catchall(self, dash_app: Dash): - async def catchall(_request: Request): - return Response(content=dash_app.index(), media_type="text/html") + '''This is needed to ensure that all routes are handled by FastAPI + and passed through the middleware, which is necessary for features like authentication + and timing to work correctly on all routes. FastAPI will match this catch-all route + for any path that isn't matched by a more specific route, allowing the middleware to + process the request and then return the appropriate response (e.g., 404 if no Dash route matches).''' - # pylint: disable=protected-access - dash_app._add_url("{path:path}", catchall, methods=["GET"]) + + def _setup_catchall(self): + try: + print("Setting up catch-all route for unmatched paths") + dash_app = get_app() + async def catchall(_request: Request): + return Response(content=dash_app.index(), media_type="text/html") + + # pylint: disable=protected-access + self.add_url_rule("{path:path}", catchall, methods=["GET"]) + except: + print(traceback.format_exc()) def add_url_rule( self, @@ -289,6 +313,7 @@ def add_url_rule( methods: list[str] | None = None, include_in_schema: bool = False, ): + print(f"Adding URL rule: {rule} -> {view_func} (endpoint: {endpoint}, methods: {methods})") if rule == "": rule = "/" if isinstance(view_func, str): @@ -481,7 +506,7 @@ def add_redirect_rule(self, app, fullname, path): def serve_callback(self, dash_app: Dash): async def _dispatch(request: Request): # pylint: disable=protected-access - body = await request.json() + body = self.request_adapter().get_json() cb_ctx = dash_app._initialize_context( body ) # pylint: disable=protected-access @@ -641,5 +666,13 @@ def origin(self): def path(self): return self._request.url.path + async def _get_json(self, request: Request=None): + req = self._request + if not hasattr(req.state, "json_body"): + req.state.json_body = await request.json() + return req.state.json_body + def get_json(self): - return asyncio.run(self._request.json()) + if not hasattr(self, "_request") or self._request is None: + self._request = get_current_request() + return self._request.state.json_body From 98ba64e896722d8a844851bbeeb1d77852fc4d9b Mon Sep 17 00:00:00 2001 From: philippe Date: Tue, 7 Apr 2026 10:55:08 -0400 Subject: [PATCH 115/166] fix lint --- dash/backends/_fastapi.py | 39 +++++++++++++++++++++++---------------- 1 file changed, 23 insertions(+), 16 deletions(-) diff --git a/dash/backends/_fastapi.py b/dash/backends/_fastapi.py index 867e9a25b0..4e6266146f 100644 --- a/dash/backends/_fastapi.py +++ b/dash/backends/_fastapi.py @@ -1,6 +1,5 @@ from __future__ import annotations -import asyncio from contextvars import copy_context, ContextVar import json from typing import TYPE_CHECKING, Any, Callable, Dict @@ -13,6 +12,7 @@ import os import subprocess import threading +import traceback try: from fastapi import FastAPI, Request, Response, Body @@ -32,7 +32,6 @@ from dash.exceptions import PreventUpdate from .base_server import BaseDashServer, RequestAdapter, ResponseAdapter from ._utils import format_traceback_html -import traceback if TYPE_CHECKING: # pragma: no cover - typing only from dash import Dash @@ -126,8 +125,14 @@ async def _initialize_dev_tools(self) -> None: async def _setup_timing(self, request: Request) -> None: """Set up timing information for the request.""" try: - request.state.json_body = await request.json() if request.headers.get("content-type", "").startswith("application/json") else None - except: + request.state.json_body = ( + await request.json() + if request.headers.get("content-type", "").startswith( + "application/json" + ) + else None + ) + except Exception: # pylint: disable=broad-exception-caught request.state.json_body = None if self.enable_timing: request.state.timing_information = { @@ -187,9 +192,8 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: try: dash_app = get_app() dash_app.backend._setup_catchall() - except: - print("Error during catch-all setup:") - print(traceback.format_exc()) + except Exception: # pylint: disable=broad-exception-caught + traceback.print_exc() await self._initialize_dev_tools() await self.app(scope, receive, send) return @@ -286,24 +290,24 @@ async def index(_request: Request): dash_app._add_url("", index, methods=["GET"]) def setup_catchall(self, dash_app: Dash): - '''This is needed to ensure that all routes are handled by FastAPI + """This is needed to ensure that all routes are handled by FastAPI and passed through the middleware, which is necessary for features like authentication and timing to work correctly on all routes. FastAPI will match this catch-all route for any path that isn't matched by a more specific route, allowing the middleware to - process the request and then return the appropriate response (e.g., 404 if no Dash route matches).''' - + process the request and then return the appropriate response (e.g., 404 if no Dash route matches).""" def _setup_catchall(self): try: - print("Setting up catch-all route for unmatched paths") + print("Setting up catch-all route for unmatched paths", file=sys.stderr) dash_app = get_app() + async def catchall(_request: Request): return Response(content=dash_app.index(), media_type="text/html") # pylint: disable=protected-access self.add_url_rule("{path:path}", catchall, methods=["GET"]) - except: - print(traceback.format_exc()) + except Exception: # pylint: disable=broad-exception-caught + traceback.print_exc() def add_url_rule( self, @@ -313,7 +317,10 @@ def add_url_rule( methods: list[str] | None = None, include_in_schema: bool = False, ): - print(f"Adding URL rule: {rule} -> {view_func} (endpoint: {endpoint}, methods: {methods})") + print( + f"Adding URL rule: {rule} -> {view_func} (endpoint: {endpoint}, methods: {methods})", + file=sys.stderr, + ) if rule == "": rule = "/" if isinstance(view_func, str): @@ -504,7 +511,7 @@ def add_redirect_rule(self, app, fullname, path): ) def serve_callback(self, dash_app: Dash): - async def _dispatch(request: Request): + async def _dispatch(request: Request): # pylint: disable=unused-argument # pylint: disable=protected-access body = self.request_adapter().get_json() cb_ctx = dash_app._initialize_context( @@ -666,7 +673,7 @@ def origin(self): def path(self): return self._request.url.path - async def _get_json(self, request: Request=None): + async def _get_json(self, request: Request = None): req = self._request if not hasattr(req.state, "json_body"): req.state.json_body = await request.json() From c9e23cc38abc96fd0bfa3f8a3d3a42045ec372d9 Mon Sep 17 00:00:00 2001 From: philippe Date: Fri, 10 Apr 2026 11:35:42 -0400 Subject: [PATCH 116/166] fix fastapi websocket --- dash/backends/_fastapi.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/dash/backends/_fastapi.py b/dash/backends/_fastapi.py index 4e6266146f..edb5bade27 100644 --- a/dash/backends/_fastapi.py +++ b/dash/backends/_fastapi.py @@ -188,6 +188,7 @@ async def _handle_error( async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: # Handle lifespan events (startup/shutdown) + if scope["type"] == "lifespan": try: dash_app = get_app() @@ -199,7 +200,7 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: return # Non-HTTP/WebSocket scopes pass through - if scope["type"] not in ("http", "websocket"): + if scope["type"] != "http": await self.app(scope, receive, send) return From 0af47f130417c0d3310ea545ef43dc1ae1c9f8e9 Mon Sep 17 00:00:00 2001 From: philippe Date: Fri, 10 Apr 2026 11:38:19 -0400 Subject: [PATCH 117/166] remove leftover prints --- dash/backends/_fastapi.py | 5 ----- dash/dash.py | 1 - 2 files changed, 6 deletions(-) diff --git a/dash/backends/_fastapi.py b/dash/backends/_fastapi.py index edb5bade27..4e5bdad621 100644 --- a/dash/backends/_fastapi.py +++ b/dash/backends/_fastapi.py @@ -299,7 +299,6 @@ def setup_catchall(self, dash_app: Dash): def _setup_catchall(self): try: - print("Setting up catch-all route for unmatched paths", file=sys.stderr) dash_app = get_app() async def catchall(_request: Request): @@ -318,10 +317,6 @@ def add_url_rule( methods: list[str] | None = None, include_in_schema: bool = False, ): - print( - f"Adding URL rule: {rule} -> {view_func} (endpoint: {endpoint}, methods: {methods})", - file=sys.stderr, - ) if rule == "": rule = "/" if isinstance(view_func, str): diff --git a/dash/dash.py b/dash/dash.py index d17c3c4e4e..ea53edf341 100644 --- a/dash/dash.py +++ b/dash/dash.py @@ -637,7 +637,6 @@ def __init__( # pylint: disable=too-many-statements, too-many-branches self._got_first_request = {"pages": False, "setup_server": False} if server: - print(f"init app from server {server}") self.init_app() self.logger.setLevel(logging.INFO) From cac28e46e7e55365451c964f0816705a063e3ad6 Mon Sep 17 00:00:00 2001 From: philippe Date: Mon, 13 Apr 2026 11:00:43 -0400 Subject: [PATCH 118/166] version 4.2.0rc0 --- dash/version.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dash/version.py b/dash/version.py index 7039708762..25b76de3c3 100644 --- a/dash/version.py +++ b/dash/version.py @@ -1 +1 @@ -__version__ = "4.1.0" +__version__ = "4.2.0rc0" From 4ff49b58de95ca7a0ee83113f4b1b674ae599686 Mon Sep 17 00:00:00 2001 From: philippe Date: Mon, 13 Apr 2026 11:06:23 -0400 Subject: [PATCH 119/166] update changelog --- CHANGELOG.md | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 49782ccf04..3709fbd2e1 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -28,6 +28,12 @@ This project adheres to [Semantic Versioning](https://semver.org/). - [#3609](https://github.com/plotly/dash/pull/3609) Add backward compat alias for _Wildcard - [#3672](https://github.com/plotly/dash/pull/3672) Improve browser performance when app contains a large number of pattern matching callback callbacks. Exposes an api endpoint to fetch the latest computeGraph call. +# [4.2.0rc0] - 2026-04-13 + +## Fixed + +- Fix websocket used in the same FastAPI server. Fix [#3636](https://github.com/plotly/dash/issues/3636) +- Fix FastAPI url paths order. Fix [3667](https://github.com/plotly/dash/issues/3667) # [4.1.0rc0] - 2026-02-23 From 312c6e3e5f4eaf02f797ced820bbe89e3fab5398 Mon Sep 17 00:00:00 2001 From: philippe Date: Tue, 14 Apr 2026 09:30:13 -0400 Subject: [PATCH 120/166] update ai docs to include backends --- .ai/ARCHITECTURE.md | 147 ++++++++++++++++++++++++++++++++++++++++++-- .ai/COMMANDS.md | 13 ++++ 2 files changed, 154 insertions(+), 6 deletions(-) diff --git a/.ai/ARCHITECTURE.md b/.ai/ARCHITECTURE.md index ca6816ccf9..afff17394a 100644 --- a/.ai/ARCHITECTURE.md +++ b/.ai/ARCHITECTURE.md @@ -2,7 +2,9 @@ ## Python Backend Framework -- **`dash/dash.py`** - Main `Dash` application class (~2000 lines). Orchestrates Flask server, layout management, callback registration, routing, and asset serving. Key methods: `layout` property, `callback()`, `clientside_callback()`, `run()`. +- **`dash/dash.py`** - Main `Dash` application class (~2000 lines). Orchestrates the server backend, layout management, callback registration, routing, and asset serving. Key methods: `layout` property, `callback()`, `clientside_callback()`, `run()`. + +- **`dash/backends/`** - Server backend implementations. See [Server Backends](#server-backends) section for details. - **`dash/_callback.py`** - Callback registration and execution. Contains `callback()` decorator (usable as `@dash.callback` without app instance), `clientside_callback()`, and `register_callback()` which inserts callbacks into the callback map. @@ -101,6 +103,127 @@ Use dict IDs with wildcards (`MATCH`, `ALL`, `ALLSMALLER`) to target dynamically - `/_dash-component-suites//` - Serves component JS/CSS assets - `/assets/` - Serves static assets from app's assets folder +## Server Backends + +Dash supports multiple web server backends. The backend abstraction is in `dash/backends/`. + +### Available Backends + +| Backend | Type | Install | Use Case | +|---------|------|---------|----------| +| **Flask** (default) | WSGI (sync) | `pip install dash` | Standard deployments, simplicity | +| **Quart** | ASGI (async) | `pip install dash[quart]` | Async callbacks, WebSocket support | +| **FastAPI** | ASGI (async) | `pip install dash[fastapi]` | OpenAPI docs, async, modern Python | + +### Usage + +**Default (Flask):** +```python +from dash import Dash +app = Dash(__name__) +``` + +**With existing server instance:** +```python +from flask import Flask +from dash import Dash + +server = Flask(__name__) +app = Dash(__name__, server=server) +``` + +**Quart backend:** +```python +from quart import Quart +from dash import Dash + +server = Quart(__name__) +app = Dash(__name__, server=server) +``` + +**FastAPI backend:** +```python +from fastapi import FastAPI +from dash import Dash + +server = FastAPI() +app = Dash(__name__, server=server) + +# Run with: uvicorn module:app.server --reload +``` + +### Architecture + +The backend system uses an abstract interface: + +- **`BaseDashServer`** (`dash/backends/base_server.py`) - Abstract base class defining the server interface. All backends implement this. + +- **`RequestAdapter`** - Normalizes HTTP request objects across frameworks. Provides unified access to `args`, `cookies`, `headers`, `get_json()`, etc. + +- **`ResponseAdapter`** - Normalizes response creation. Handles `set_cookie()`, `set_header()`, `set_response()`. + +- **`get_backend(name)`** - Factory function to get backend class by name (`"flask"`, `"quart"`, `"fastapi"`). + +- **`get_server_type(server)`** - Auto-detects backend from a server instance. + +### Backend Implementations + +**Flask** (`dash/backends/_flask.py`): +- `FlaskDashServer` - Wraps Flask app +- `FlaskRequestAdapter` - Uses `flask.request` proxy +- `FlaskResponseAdapter` - Uses `flask.Response` +- Compression via `flask-compress` + +**Quart** (`dash/backends/_quart.py`): +- `QuartDashServer` - Wraps Quart app (async Flask API) +- `QuartRequestAdapter` - Uses `quart.request` proxy +- `QuartResponseAdapter` - Uses `quart.Response` +- All route handlers are `async def` +- Compression via `quart-compress` + +**FastAPI** (`dash/backends/_fastapi.py`): +- `FastAPIDashServer` - Wraps FastAPI app +- `FastAPIRequestAdapter` - Uses context variable for current request +- `FastAPIResponseAdapter` - Uses Starlette responses +- `DashMiddleware` - Consolidated ASGI middleware for request handling +- Runs with uvicorn, supports hot reload +- Built-in GZip compression + +### Key Interface Methods + +All backends implement: + +```python +class BaseDashServer(ABC): + def create_app(name, config) -> server # Create new server + def add_url_rule(rule, view_func, ...) # Register routes + def before_request(func) # Request hooks + def after_request(func) # Response hooks + def run(dash_app, host, port, debug) # Start dev server + def make_response(data, mimetype, status) # Create response + def jsonify(obj) # JSON response + def setup_index(dash_app) # Register / route + def serve_callback(dash_app) # Callback endpoint + def setup_component_suites(dash_app) # JS/CSS serving +``` + +### Accessing the Backend + +```python +app = Dash(__name__) + +# Get the underlying server +app.server # Flask/Quart/FastAPI instance + +# Get the backend wrapper +app.backend # BaseDashServer subclass instance +app.backend.server_type # "flask", "quart", or "fastapi" + +# Access request in callbacks +from dash import dash +dash.get_app().backend.request_adapter() # RequestAdapter instance +``` + ## Frontend (dash-renderer) **`dash/dash-renderer/src/`** contains the TypeScript/React frontend. See [RENDERER.md](RENDERER.md) for detailed documentation on: @@ -503,14 +626,14 @@ app.run( ### How It Works 1. `app.run()` detects Jupyter environment via `get_ipython()` -2. Flask server starts in background daemon thread +2. Server starts in background daemon thread 3. Jupyter comm protocol negotiates proxy configuration 4. App displays according to selected mode ``` app.run() in notebook ↓ -Detect Jupyter → Start Flask in background thread +Detect Jupyter → Start server in background thread ↓ Comm request → Extension responds with base_url ↓ @@ -572,8 +695,8 @@ Special handling for Colab: ### Dash() Constructor Parameters **Basic Setup:** -- `name` - Flask app name (default: infers from `__name__`) -- `server` - Flask instance or `True` to create new (default: `True`) +- `name` - Application name (default: infers from `__name__`) +- `server` - Server instance (Flask, Quart, or FastAPI) or `True` to create Flask (default: `True`) - `title` - Browser tab title (default: `"Dash"`) - `update_title` - Title during callbacks (default: `"Updating..."`) @@ -682,6 +805,7 @@ Dash supports `async def` callbacks for non-blocking execution. ### Setup +**With Flask backend:** ```bash pip install dash[async] ``` @@ -692,6 +816,16 @@ Async is auto-enabled when `asgiref` is detected. Or explicitly: app = Dash(__name__, use_async=True) ``` +**With Quart or FastAPI backend:** Async is native - no extra dependencies needed. + +```python +from fastapi import FastAPI +from dash import Dash + +server = FastAPI() +app = Dash(__name__, server=server) # Async works automatically +``` + ### Usage ```python @@ -708,7 +842,8 @@ async def async_update(value): - Regular async callbacks are **non-blocking** - multiple can run concurrently - Background callbacks also support `async def` - Jupyter uses `nest_asyncio` for event loop compatibility -- Without `dash[async]`, coroutines raise an error +- With Flask backend: requires `dash[async]`, coroutines raise error without it +- With Quart/FastAPI backends: async is native, no extra setup needed ### Async with Background Callbacks diff --git a/.ai/COMMANDS.md b/.ai/COMMANDS.md index 4c155d87b7..9671a6c6cb 100644 --- a/.ai/COMMANDS.md +++ b/.ai/COMMANDS.md @@ -14,6 +14,19 @@ pip install -e .[ci,dev,testing,celery,diskcache] npm ci ``` +### Optional Backend Dependencies + +```bash +# For Quart backend (ASGI async) +pip install dash[quart] + +# For FastAPI backend (ASGI async) +pip install dash[fastapi] + +# For async callbacks with Flask +pip install dash[async] +``` + ## Building ```bash From 2148fbbe1a66574ccaf99824e2944d96cca4cb2c Mon Sep 17 00:00:00 2001 From: philippe Date: Tue, 14 Apr 2026 11:11:05 -0400 Subject: [PATCH 121/166] Add websocket callbacks --- @plotly/dash-websocket-worker/README.md | 3 + @plotly/dash-websocket-worker/package.json | 29 ++ .../src/MessageRouter.ts | 186 ++++++++++++ .../src/WebSocketManager.ts | 257 ++++++++++++++++ @plotly/dash-websocket-worker/src/index.ts | 18 ++ @plotly/dash-websocket-worker/src/types.ts | 150 ++++++++++ @plotly/dash-websocket-worker/src/worker.ts | 129 ++++++++ @plotly/dash-websocket-worker/tsconfig.json | 20 ++ .../dash-websocket-worker/webpack.config.js | 25 ++ dash/_callback.py | 1 + dash/_callback_context.py | 52 +++- dash/_dash_renderer.py | 9 +- dash/backends/_fastapi.py | 224 +++++++++++++- dash/backends/base_server.py | 40 +++ dash/dash-renderer/init.template | 5 + dash/dash-renderer/package.json | 2 +- dash/dash-renderer/src/AppProvider.react.tsx | 32 +- dash/dash-renderer/src/actions/callbacks.ts | 171 ++++++++++- dash/dash-renderer/src/config.ts | 5 + .../src/observers/websocketObserver.ts | 148 ++++++++++ dash/dash-renderer/src/utils/rendererId.ts | 23 ++ dash/dash-renderer/src/utils/workerClient.ts | 276 ++++++++++++++++++ dash/dash-renderer/webpack.base.config.js | 29 +- dash/dash.py | 42 +++ wsapp.py | 106 +++++++ 25 files changed, 1965 insertions(+), 17 deletions(-) create mode 100644 @plotly/dash-websocket-worker/README.md create mode 100644 @plotly/dash-websocket-worker/package.json create mode 100644 @plotly/dash-websocket-worker/src/MessageRouter.ts create mode 100644 @plotly/dash-websocket-worker/src/WebSocketManager.ts create mode 100644 @plotly/dash-websocket-worker/src/index.ts create mode 100644 @plotly/dash-websocket-worker/src/types.ts create mode 100644 @plotly/dash-websocket-worker/src/worker.ts create mode 100644 @plotly/dash-websocket-worker/tsconfig.json create mode 100644 @plotly/dash-websocket-worker/webpack.config.js create mode 100644 dash/dash-renderer/src/observers/websocketObserver.ts create mode 100644 dash/dash-renderer/src/utils/rendererId.ts create mode 100644 dash/dash-renderer/src/utils/workerClient.ts create mode 100644 wsapp.py diff --git a/@plotly/dash-websocket-worker/README.md b/@plotly/dash-websocket-worker/README.md new file mode 100644 index 0000000000..64e37a1987 --- /dev/null +++ b/@plotly/dash-websocket-worker/README.md @@ -0,0 +1,3 @@ +# Dash websocket worker + +Worker for websocket based callbacks. diff --git a/@plotly/dash-websocket-worker/package.json b/@plotly/dash-websocket-worker/package.json new file mode 100644 index 0000000000..619a842380 --- /dev/null +++ b/@plotly/dash-websocket-worker/package.json @@ -0,0 +1,29 @@ +{ + "name": "@plotly/dash-websocket-worker", + "version": "1.0.0", + "description": "SharedWorker for WebSocket-based Dash callbacks", + "main": "dist/index.js", + "types": "dist/index.d.ts", + "scripts": { + "build": "webpack --mode production", + "build:dev": "webpack --mode development", + "watch": "webpack --mode development --watch", + "clean": "rm -rf dist" + }, + "files": [ + "dist" + ], + "keywords": [ + "dash", + "websocket", + "sharedworker" + ], + "author": "Plotly", + "license": "MIT", + "devDependencies": { + "typescript": "^5.0.0", + "webpack": "^5.0.0", + "webpack-cli": "^5.0.0", + "ts-loader": "^9.0.0" + } +} diff --git a/@plotly/dash-websocket-worker/src/MessageRouter.ts b/@plotly/dash-websocket-worker/src/MessageRouter.ts new file mode 100644 index 0000000000..1082c3e6c1 --- /dev/null +++ b/@plotly/dash-websocket-worker/src/MessageRouter.ts @@ -0,0 +1,186 @@ +import { + WorkerMessageType, + WorkerMessage, + CallbackRequestMessage, + GetPropsResponseMessage, + SetPropsMessage, + GetPropsRequestMessage, + CallbackResponseMessage +} from './types'; + +/** + * Routes messages between renderers (via MessagePorts) and the WebSocket server. + */ +export class MessageRouter { + /** Map of renderer IDs to their MessagePorts */ + private renderers: Map = new Map(); + + /** Callback to send messages to the WebSocket server */ + public sendToServer: ((message: unknown) => void) | null = null; + + /** + * Register a renderer with its MessagePort. + * @param rendererId Unique identifier for the renderer + * @param port The MessagePort for communication + */ + public registerRenderer(rendererId: string, port: MessagePort): void { + this.renderers.set(rendererId, port); + } + + /** + * Unregister a renderer. + * @param rendererId The renderer to unregister + */ + public unregisterRenderer(rendererId: string): void { + this.renderers.delete(rendererId); + } + + /** + * Get the number of connected renderers. + */ + public get rendererCount(): number { + return this.renderers.size; + } + + /** + * Handle a message from a renderer. + * @param rendererId The ID of the renderer that sent the message + * @param message The message from the renderer + */ + public handleRendererMessage(rendererId: string, message: WorkerMessage): void { + switch (message.type) { + case WorkerMessageType.CALLBACK_REQUEST: + this.forwardCallbackRequest(rendererId, message as CallbackRequestMessage); + break; + + case WorkerMessageType.GET_PROPS_RESPONSE: + this.forwardGetPropsResponse(rendererId, message as GetPropsResponseMessage); + break; + + default: + console.warn(`Unknown message type from renderer: ${message.type}`); + } + } + + /** + * Handle a message from the WebSocket server. + * @param message The message from the server + */ + public handleServerMessage(message: unknown): void { + const msg = message as WorkerMessage; + const rendererId = msg.rendererId; + + switch (msg.type) { + case WorkerMessageType.CALLBACK_RESPONSE: + this.forwardToRenderer(rendererId, msg as CallbackResponseMessage); + break; + + case WorkerMessageType.SET_PROPS: + this.forwardSetProps(rendererId, msg as SetPropsMessage); + break; + + case WorkerMessageType.GET_PROPS_REQUEST: + this.forwardGetPropsRequest(rendererId, msg as GetPropsRequestMessage); + break; + + case WorkerMessageType.ERROR: + this.forwardToRenderer(rendererId, msg); + break; + + default: + console.warn(`Unknown message type from server: ${msg.type}`); + } + } + + /** + * Send a message to all connected renderers. + * @param message The message to broadcast + */ + public broadcastToRenderers(message: WorkerMessage): void { + for (const [, port] of this.renderers) { + port.postMessage(message); + } + } + + /** + * Send a connected notification to a specific renderer. + * @param rendererId The renderer to notify + */ + public notifyConnected(rendererId: string): void { + const port = this.renderers.get(rendererId); + if (port) { + port.postMessage({ + type: WorkerMessageType.CONNECTED, + rendererId + }); + } + } + + /** + * Send a disconnected notification to all renderers. + * @param reason Optional reason for disconnection + */ + public notifyDisconnected(reason?: string): void { + this.broadcastToRenderers({ + type: WorkerMessageType.DISCONNECTED, + rendererId: '', + payload: { reason } + }); + } + + /** + * Send an error notification to a specific renderer. + * @param rendererId The renderer to notify + * @param message Error message + * @param code Optional error code + */ + public notifyError(rendererId: string, message: string, code?: string): void { + const port = this.renderers.get(rendererId); + if (port) { + port.postMessage({ + type: WorkerMessageType.ERROR, + rendererId, + payload: { message, code } + }); + } + } + + private forwardCallbackRequest(rendererId: string, message: CallbackRequestMessage): void { + if (this.sendToServer) { + this.sendToServer({ + type: WorkerMessageType.CALLBACK_REQUEST, + rendererId, + requestId: message.requestId, + payload: message.payload + }); + } + } + + private forwardGetPropsResponse(rendererId: string, message: GetPropsResponseMessage): void { + if (this.sendToServer) { + this.sendToServer({ + type: WorkerMessageType.GET_PROPS_RESPONSE, + rendererId, + requestId: message.requestId, + payload: message.payload + }); + } + } + + private forwardToRenderer(rendererId: string, message: WorkerMessage): void { + const port = this.renderers.get(rendererId); + if (port) { + port.postMessage(message); + } else { + console.warn(`Renderer ${rendererId} not found for message`); + } + } + + private forwardSetProps(rendererId: string, message: SetPropsMessage): void { + this.forwardToRenderer(rendererId, message); + } + + private forwardGetPropsRequest(rendererId: string, message: GetPropsRequestMessage): void { + this.forwardToRenderer(rendererId, message); + } +} diff --git a/@plotly/dash-websocket-worker/src/WebSocketManager.ts b/@plotly/dash-websocket-worker/src/WebSocketManager.ts new file mode 100644 index 0000000000..e0353e7e94 --- /dev/null +++ b/@plotly/dash-websocket-worker/src/WebSocketManager.ts @@ -0,0 +1,257 @@ +/** + * Configuration options for WebSocket connection. + */ +interface WebSocketConfig { + /** Maximum number of reconnection attempts */ + maxRetries: number; + /** Initial delay between reconnection attempts (ms) */ + initialRetryDelay: number; + /** Maximum delay between reconnection attempts (ms) */ + maxRetryDelay: number; + /** Heartbeat interval (ms) */ + heartbeatInterval: number; + /** Heartbeat timeout (ms) */ + heartbeatTimeout: number; +} + +const DEFAULT_CONFIG: WebSocketConfig = { + maxRetries: 10, + initialRetryDelay: 1000, + maxRetryDelay: 30000, + heartbeatInterval: 30000, + heartbeatTimeout: 10000 +}; + +/** + * Manages WebSocket connection with automatic reconnection and heartbeat. + */ +export class WebSocketManager { + private ws: WebSocket | null = null; + private serverUrl: string | null = null; + private config: WebSocketConfig; + private retryCount = 0; + private retryTimeout: ReturnType | null = null; + private heartbeatInterval: ReturnType | null = null; + private heartbeatTimeout: ReturnType | null = null; + private messageQueue: string[] = []; + private isConnecting = false; + + /** Callback when connection is established */ + public onOpen: (() => void) | null = null; + /** Callback when connection is closed */ + public onClose: ((reason?: string) => void) | null = null; + /** Callback when a message is received */ + public onMessage: ((data: unknown) => void) | null = null; + /** Callback when an error occurs */ + public onError: ((error: Error) => void) | null = null; + + constructor(config: Partial = {}) { + this.config = { ...DEFAULT_CONFIG, ...config }; + } + + /** + * Connect to the WebSocket server. + * @param serverUrl The WebSocket server URL + */ + public connect(serverUrl: string): void { + if (this.ws && this.ws.readyState === WebSocket.OPEN) { + // Already connected + return; + } + + if (this.isConnecting) { + // Connection in progress + return; + } + + this.serverUrl = serverUrl; + this.isConnecting = true; + this.createConnection(); + } + + /** + * Disconnect from the WebSocket server. + */ + public disconnect(): void { + this.cleanup(); + if (this.ws) { + this.ws.close(1000, 'Client disconnect'); + this.ws = null; + } + this.serverUrl = null; + this.retryCount = 0; + } + + /** + * Send a message through the WebSocket connection. + * If not connected, queues the message for later delivery. + * @param message The message to send + */ + public send(message: unknown): void { + const data = JSON.stringify(message); + + if (this.ws && this.ws.readyState === WebSocket.OPEN) { + this.ws.send(data); + } else { + // Queue message for when connection is established + this.messageQueue.push(data); + } + } + + /** + * Check if the WebSocket is currently connected. + */ + public get isConnected(): boolean { + return this.ws !== null && this.ws.readyState === WebSocket.OPEN; + } + + private createConnection(): void { + if (!this.serverUrl) { + return; + } + + try { + this.ws = new WebSocket(this.serverUrl); + this.ws.onopen = this.handleOpen.bind(this); + this.ws.onclose = this.handleClose.bind(this); + this.ws.onmessage = this.handleMessage.bind(this); + this.ws.onerror = this.handleError.bind(this); + } catch (error) { + this.isConnecting = false; + this.scheduleReconnect(); + } + } + + private handleOpen(): void { + this.isConnecting = false; + this.retryCount = 0; + + // Flush queued messages + while (this.messageQueue.length > 0) { + const message = this.messageQueue.shift(); + if (message && this.ws) { + this.ws.send(message); + } + } + + // Start heartbeat + this.startHeartbeat(); + + if (this.onOpen) { + this.onOpen(); + } + } + + private handleClose(event: CloseEvent): void { + this.isConnecting = false; + this.cleanup(); + + const reason = event.reason || 'Connection closed'; + + if (this.onClose) { + this.onClose(reason); + } + + // Only reconnect if we haven't explicitly disconnected + if (this.serverUrl && event.code !== 1000) { + this.scheduleReconnect(); + } + } + + private handleMessage(event: MessageEvent): void { + try { + const data = JSON.parse(event.data); + + // Handle heartbeat acknowledgment + if (data.type === 'heartbeat_ack') { + this.clearHeartbeatTimeout(); + return; + } + + if (this.onMessage) { + this.onMessage(data); + } + } catch (error) { + if (this.onError) { + this.onError(new Error('Failed to parse message')); + } + } + } + + private handleError(): void { + this.isConnecting = false; + // WebSocket error events don't contain useful information + // The close event will follow with more details + } + + private scheduleReconnect(): void { + if (this.retryTimeout) { + clearTimeout(this.retryTimeout); + } + + if (this.retryCount >= this.config.maxRetries) { + if (this.onError) { + this.onError(new Error('Max reconnection attempts reached')); + } + return; + } + + // Exponential backoff with jitter + const delay = Math.min( + this.config.initialRetryDelay * Math.pow(2, this.retryCount) + + Math.random() * 1000, + this.config.maxRetryDelay + ); + + this.retryCount++; + + this.retryTimeout = setTimeout(() => { + this.createConnection(); + }, delay); + } + + private startHeartbeat(): void { + this.stopHeartbeat(); + + this.heartbeatInterval = setInterval(() => { + if (this.ws && this.ws.readyState === WebSocket.OPEN) { + this.ws.send(JSON.stringify({ type: 'heartbeat' })); + this.setHeartbeatTimeout(); + } + }, this.config.heartbeatInterval); + } + + private stopHeartbeat(): void { + if (this.heartbeatInterval) { + clearInterval(this.heartbeatInterval); + this.heartbeatInterval = null; + } + this.clearHeartbeatTimeout(); + } + + private setHeartbeatTimeout(): void { + this.clearHeartbeatTimeout(); + + this.heartbeatTimeout = setTimeout(() => { + // Heartbeat timeout - connection may be dead + if (this.ws) { + this.ws.close(4000, 'Heartbeat timeout'); + } + }, this.config.heartbeatTimeout); + } + + private clearHeartbeatTimeout(): void { + if (this.heartbeatTimeout) { + clearTimeout(this.heartbeatTimeout); + this.heartbeatTimeout = null; + } + } + + private cleanup(): void { + this.stopHeartbeat(); + if (this.retryTimeout) { + clearTimeout(this.retryTimeout); + this.retryTimeout = null; + } + } +} diff --git a/@plotly/dash-websocket-worker/src/index.ts b/@plotly/dash-websocket-worker/src/index.ts new file mode 100644 index 0000000000..e21b382d41 --- /dev/null +++ b/@plotly/dash-websocket-worker/src/index.ts @@ -0,0 +1,18 @@ +/** + * Dash WebSocket Worker Package + * + * Provides a SharedWorker for WebSocket-based Dash callbacks. + */ + +export * from './types'; + +/** + * Get the URL for the WebSocket worker script. + * This should be used to instantiate the SharedWorker. + * + * @param baseUrl Base URL where the worker script is served + * @returns Full URL to the worker script + */ +export function getWorkerUrl(baseUrl: string): string { + return `${baseUrl}/dash-ws-worker.js`; +} diff --git a/@plotly/dash-websocket-worker/src/types.ts b/@plotly/dash-websocket-worker/src/types.ts new file mode 100644 index 0000000000..36fadf03a0 --- /dev/null +++ b/@plotly/dash-websocket-worker/src/types.ts @@ -0,0 +1,150 @@ +/** + * Message types for communication between renderer and worker. + */ +export enum WorkerMessageType { + // Renderer -> Worker + CONNECT = 'connect', + DISCONNECT = 'disconnect', + CALLBACK_REQUEST = 'callback_request', + GET_PROPS_RESPONSE = 'get_props_response', + + // Worker -> Renderer + CONNECTED = 'connected', + DISCONNECTED = 'disconnected', + CALLBACK_RESPONSE = 'callback_response', + SET_PROPS = 'set_props', + GET_PROPS_REQUEST = 'get_props_request', + ERROR = 'error' +} + +/** + * Base message structure for worker communication. + */ +export interface WorkerMessage { + type: WorkerMessageType; + rendererId: string; + requestId?: string; + payload?: unknown; +} + +/** + * Message from renderer to worker requesting connection. + */ +export interface ConnectMessage extends WorkerMessage { + type: WorkerMessageType.CONNECT; + payload: { + serverUrl: string; + }; +} + +/** + * Message from renderer to worker requesting disconnect. + */ +export interface DisconnectMessage extends WorkerMessage { + type: WorkerMessageType.DISCONNECT; +} + +/** + * Callback request payload structure. + */ +export interface CallbackPayload { + output: string; + outputs: unknown[]; + inputs: unknown[]; + state?: unknown[]; + changedPropIds: string[]; + parsedChangedPropsIds?: string[]; +} + +/** + * Message from renderer to worker with callback request. + */ +export interface CallbackRequestMessage extends WorkerMessage { + type: WorkerMessageType.CALLBACK_REQUEST; + payload: CallbackPayload; +} + +/** + * Message from worker to renderer with callback response. + */ +export interface CallbackResponseMessage extends WorkerMessage { + type: WorkerMessageType.CALLBACK_RESPONSE; + payload: { + status: 'ok' | 'prevent_update' | 'error'; + data?: Record; + message?: string; + }; +} + +/** + * Message from worker to renderer to set component props. + */ +export interface SetPropsMessage extends WorkerMessage { + type: WorkerMessageType.SET_PROPS; + payload: { + componentId: string; + props: Record; + }; +} + +/** + * Message from worker to renderer requesting prop values. + */ +export interface GetPropsRequestMessage extends WorkerMessage { + type: WorkerMessageType.GET_PROPS_REQUEST; + payload: { + componentId: string; + properties: string[]; + }; +} + +/** + * Message from renderer to worker with prop values. + */ +export interface GetPropsResponseMessage extends WorkerMessage { + type: WorkerMessageType.GET_PROPS_RESPONSE; + payload: Record; +} + +/** + * Error message from worker to renderer. + */ +export interface ErrorMessage extends WorkerMessage { + type: WorkerMessageType.ERROR; + payload: { + message: string; + code?: string; + }; +} + +/** + * Connected confirmation message from worker to renderer. + */ +export interface ConnectedMessage extends WorkerMessage { + type: WorkerMessageType.CONNECTED; +} + +/** + * Disconnected notification message from worker to renderer. + */ +export interface DisconnectedMessage extends WorkerMessage { + type: WorkerMessageType.DISCONNECTED; + payload?: { + reason?: string; + }; +} + +/** + * Union type of all possible worker messages. + */ +export type AnyWorkerMessage = + | ConnectMessage + | DisconnectMessage + | CallbackRequestMessage + | CallbackResponseMessage + | SetPropsMessage + | GetPropsRequestMessage + | GetPropsResponseMessage + | ErrorMessage + | ConnectedMessage + | DisconnectedMessage; diff --git a/@plotly/dash-websocket-worker/src/worker.ts b/@plotly/dash-websocket-worker/src/worker.ts new file mode 100644 index 0000000000..ff84b4fa0f --- /dev/null +++ b/@plotly/dash-websocket-worker/src/worker.ts @@ -0,0 +1,129 @@ +/** + * Dash WebSocket Worker + * + * A SharedWorker that maintains a single WebSocket connection to the Dash server + * and routes messages between multiple renderer instances (browser tabs). + */ + +import { WebSocketManager } from './WebSocketManager'; +import { MessageRouter } from './MessageRouter'; +import { + WorkerMessageType, + WorkerMessage, + ConnectMessage +} from './types'; + +// SharedWorker global scope +declare const self: SharedWorkerGlobalScope; + +/** WebSocket connection manager */ +const wsManager = new WebSocketManager(); + +/** Message router for renderers */ +const router = new MessageRouter(); + +/** Current server URL */ +let serverUrl: string | null = null; + +/** + * Set up WebSocket manager callbacks. + */ +wsManager.onOpen = () => { + console.log('[DashWSWorker] WebSocket connected'); + // Notify all renderers that connection is established + for (const rendererId of getRendererIds()) { + router.notifyConnected(rendererId); + } +}; + +wsManager.onClose = (reason?: string) => { + console.log(`[DashWSWorker] WebSocket closed: ${reason}`); + router.notifyDisconnected(reason); +}; + +wsManager.onMessage = (data: unknown) => { + router.handleServerMessage(data); +}; + +wsManager.onError = (error: Error) => { + console.error('[DashWSWorker] WebSocket error:', error.message); +}; + +/** + * Set up router to send messages to WebSocket. + */ +router.sendToServer = (message: unknown) => { + wsManager.send(message); +}; + +// Track renderer IDs separately for iteration +const rendererIds = new Set(); + +/** + * Get all registered renderer IDs. + */ +function getRendererIds(): string[] { + return Array.from(rendererIds); +} + +/** + * Handle new connection from a renderer (browser tab). + */ +self.onconnect = (event: MessageEvent) => { + const port = event.ports[0]; + + port.onmessage = (e: MessageEvent) => { + const message = e.data as WorkerMessage; + + switch (message.type) { + case WorkerMessageType.CONNECT: { + const connectMsg = message as ConnectMessage; + const rendererId = connectMsg.rendererId; + const newServerUrl = connectMsg.payload.serverUrl; + + // Register the renderer + router.registerRenderer(rendererId, port); + rendererIds.add(rendererId); + + console.log(`[DashWSWorker] Renderer ${rendererId} connected`); + + // Connect to server if not already connected + if (!wsManager.isConnected) { + if (serverUrl !== newServerUrl) { + serverUrl = newServerUrl; + } + wsManager.connect(serverUrl); + } else { + // Already connected, notify the renderer + router.notifyConnected(rendererId); + } + break; + } + + case WorkerMessageType.DISCONNECT: { + const rendererId = message.rendererId; + router.unregisterRenderer(rendererId); + rendererIds.delete(rendererId); + + console.log(`[DashWSWorker] Renderer ${rendererId} disconnected`); + + // If no more renderers, disconnect from server + if (router.rendererCount === 0) { + wsManager.disconnect(); + serverUrl = null; + console.log('[DashWSWorker] All renderers disconnected, closing WebSocket'); + } + break; + } + + default: + // Forward other messages through the router + router.handleRendererMessage(message.rendererId, message); + } + }; + + port.start(); +}; + +// Log worker startup +console.log('[DashWSWorker] SharedWorker initialized'); diff --git a/@plotly/dash-websocket-worker/tsconfig.json b/@plotly/dash-websocket-worker/tsconfig.json new file mode 100644 index 0000000000..0254db7f91 --- /dev/null +++ b/@plotly/dash-websocket-worker/tsconfig.json @@ -0,0 +1,20 @@ +{ + "compilerOptions": { + "target": "ES2020", + "module": "ESNext", + "lib": ["ES2020", "WebWorker"], + "declaration": true, + "declarationMap": true, + "sourceMap": true, + "outDir": "./dist", + "rootDir": "./src", + "strict": true, + "moduleResolution": "node", + "esModuleInterop": true, + "skipLibCheck": true, + "forceConsistentCasingInFileNames": true, + "resolveJsonModule": true + }, + "include": ["src/**/*"], + "exclude": ["node_modules", "dist"] +} diff --git a/@plotly/dash-websocket-worker/webpack.config.js b/@plotly/dash-websocket-worker/webpack.config.js new file mode 100644 index 0000000000..efe7b59e89 --- /dev/null +++ b/@plotly/dash-websocket-worker/webpack.config.js @@ -0,0 +1,25 @@ +const path = require('path'); + +// This config is for standalone development/testing of the worker. +// The production build is handled by dash-renderer's webpack config. +module.exports = { + entry: './src/worker.ts', + output: { + filename: 'dash-ws-worker.js', + path: path.resolve(__dirname, 'dist'), + clean: true + }, + resolve: { + extensions: ['.ts', '.js'] + }, + module: { + rules: [ + { + test: /\.ts$/, + use: 'ts-loader', + exclude: /node_modules/ + } + ] + }, + target: 'webworker' +}; diff --git a/dash/_callback.py b/dash/_callback.py index 37a53d7ec5..ff1072efd3 100644 --- a/dash/_callback.py +++ b/dash/_callback.py @@ -77,6 +77,7 @@ def callback( api_endpoint: Optional[str] = None, optional: Optional[bool] = False, hidden: Optional[bool] = None, + _websocket: Optional[bool] = False, # Reserved for future use **_kwargs, ) -> Callable[..., Any]: """ diff --git a/dash/_callback_context.py b/dash/_callback_context.py index 646db990ab..4f296bde66 100644 --- a/dash/_callback_context.py +++ b/dash/_callback_context.py @@ -1,9 +1,12 @@ +import asyncio import functools import warnings import json import contextvars import typing +from dash.backends.base_server import DashWebsocketCallback + from . import exceptions from ._get_app import get_app from ._utils import AttributeDict, stringify_id @@ -323,6 +326,32 @@ def custom_data(self): """ return _get_from_context("custom_data", {}) + @property + @has_context + def get_websocket(self) -> typing.Optional[DashWebsocketCallback]: + """Get WebSocket interface if running in WebSocket context. + + Returns the DashWebsocketCallback instance if the callback is being + executed via WebSocket, otherwise returns None. + + Raises: + RuntimeError: If websocket_callbacks is requested but the backend + doesn't support WebSocket. + """ + ws = _get_from_context("dash_websocket", None) + if ws is None: + app = get_app() + if ( + hasattr(app, "_websocket_callbacks") + and app._websocket_callbacks # pylint: disable=protected-access + and not app.backend.websocket_capability + ): + raise RuntimeError( + f"WebSocket callbacks requested but backend " + f"'{app.backend.server_type}' doesn't support them." + ) + return ws + callback_context = CallbackContext() @@ -330,5 +359,26 @@ def custom_data(self): def set_props(component_id: typing.Union[str, dict], props: dict): """ Set the props for a component not included in the callback outputs. + + If running in a WebSocket context, props are streamed immediately to the + client. Otherwise, props are batched and sent with the callback response. """ - callback_context.set_props(component_id, props) + ws = _get_from_context("dash_websocket", None) + if ws is not None: + # Stream immediately via WebSocket + _id = stringify_id(component_id) + + async def _send_props(): + for prop_name, value in props.items(): + await ws.set_prop(_id, prop_name, value) + + # If we're in an async context, schedule the coroutine + try: + asyncio.get_running_loop() + asyncio.ensure_future(_send_props()) + except RuntimeError: + # No running event loop - run synchronously + asyncio.run(_send_props()) + else: + # Batch for response (existing behavior) + callback_context.set_props(component_id, props) diff --git a/dash/_dash_renderer.py b/dash/_dash_renderer.py index ee507ddb71..5574131d10 100644 --- a/dash/_dash_renderer.py +++ b/dash/_dash_renderer.py @@ -1,7 +1,7 @@ import os from typing import Any, List, Dict -__version__ = "3.0.0" +__version__ = "3.1.0" _available_react_versions = {"18.3.1", "18.2.0", "16.14.0"} _available_reactdom_versions = {"18.3.1", "18.2.0", "16.14.0"} @@ -65,7 +65,7 @@ def _set_react_version(v_react, v_reactdom=None): { "relative_package_path": "dash-renderer/build/dash_renderer.min.js", "dev_package_path": "dash-renderer/build/dash_renderer.dev.js", - "external_url": "https://unpkg.com/dash-renderer@3.0.0" + "external_url": "https://unpkg.com/dash-renderer@3.1.0" "/build/dash_renderer.min.js", "namespace": "dash", }, @@ -75,4 +75,9 @@ def _set_react_version(v_react, v_reactdom=None): "namespace": "dash", "dynamic": True, }, + { + "relative_package_path": "dash-renderer/build/dash-ws-worker.js", + "namespace": "dash", + "dynamic": True, + }, ] diff --git a/dash/backends/_fastapi.py b/dash/backends/_fastapi.py index 4e5bdad621..f1e9dc838e 100644 --- a/dash/backends/_fastapi.py +++ b/dash/backends/_fastapi.py @@ -1,7 +1,9 @@ from __future__ import annotations from contextvars import copy_context, ContextVar +import asyncio import json +import uuid from typing import TYPE_CHECKING, Any, Callable, Dict import sys import mimetypes @@ -21,6 +23,7 @@ from starlette.responses import Response as StarletteResponse from starlette.datastructures import MutableHeaders from starlette.types import ASGIApp, Scope, Receive, Send + from starlette.websockets import WebSocket, WebSocketDisconnect import uvicorn except ImportError as _err: raise ImportError( @@ -30,7 +33,12 @@ from dash.fingerprint import check_fingerprint from dash import _validate, get_app from dash.exceptions import PreventUpdate -from .base_server import BaseDashServer, RequestAdapter, ResponseAdapter +from .base_server import ( + BaseDashServer, + RequestAdapter, + ResponseAdapter, + DashWebsocketCallback, +) from ._utils import format_traceback_html if TYPE_CHECKING: # pragma: no cover - typing only @@ -73,6 +81,77 @@ def set_response(self, **kwargs): return resp +class FastAPIWebsocketCallback(DashWebsocketCallback): + """WebSocket callback implementation for FastAPI backend. + + Provides real-time bidirectional communication for callback execution. + """ + + def __init__( + self, websocket: WebSocket, pending_get_props: Dict[str, asyncio.Future] + ): + """Initialize the WebSocket callback interface. + + Args: + websocket: The WebSocket connection + pending_get_props: Dict to track pending get_props requests + """ + self._websocket = websocket + self._pending_get_props = pending_get_props + + async def set_prop(self, component_id: str, prop_name: str, value: Any) -> None: + """Send immediate prop update to the client via WebSocket. + + Args: + component_id: The component ID (string or stringified dict) + prop_name: The property name to update + value: The new value to set + """ + await self._websocket.send_json( + { + "type": "set_props", + "payload": {"componentId": component_id, "props": {prop_name: value}}, + } + ) + + async def get_prop(self, component_id: str, prop_name: str) -> Any: + """Request current prop value from the client. + + Args: + component_id: The component ID (string or stringified dict) + prop_name: The property name to retrieve + + Returns: + The current value of the property from the client's state + """ + request_id = str(uuid.uuid4()) + + # Create a future to wait for the response + future: asyncio.Future = asyncio.get_event_loop().create_future() + self._pending_get_props[request_id] = future + + # Send the request + await self._websocket.send_json( + { + "type": "get_props_request", + "requestId": request_id, + "payload": {"componentId": component_id, "properties": [prop_name]}, + } + ) + + # Wait for the response with timeout + try: + result = await asyncio.wait_for(future, timeout=30.0) + if result and prop_name in result: + return result[prop_name] + return None + except asyncio.TimeoutError as exc: + self._pending_get_props.pop(request_id, None) + raise TimeoutError( + f"Timeout waiting for get_prop response for {component_id}.{prop_name}" + ) from exc + + _current_request_var = ContextVar("dash_current_request", default=None) @@ -224,6 +303,8 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: class FastAPIDashServer(BaseDashServer[FastAPI]): + websocket_capability: bool = True + def __init__(self, server: FastAPI): super().__init__(server) self.server_type = "fastapi" @@ -609,6 +690,147 @@ async def timing_headers_middleware(request: Request, call_next): headers.append("Server-Timing", value) return response + def serve_websocket_callback(self, dash_app: "Dash"): + """Set up the WebSocket endpoint for callback handling. + + Args: + dash_app: The Dash application instance + """ + # pylint: disable=too-many-statements + ws_path = dash_app.config.requests_pathname_prefix + "_dash-ws-callback" + + async def websocket_handler(websocket: WebSocket): + await websocket.accept() + + # Track pending get_props requests + pending_get_props: Dict[str, asyncio.Future] = {} + + try: + while True: + message = await websocket.receive_json() + msg_type = message.get("type") + renderer_id = message.get("rendererId") + + if msg_type == "callback_request": + response = await self._execute_ws_callback( + dash_app, websocket, message, pending_get_props + ) + await websocket.send_json( + { + "type": "callback_response", + "rendererId": renderer_id, + "requestId": message.get("requestId"), + "payload": response, + } + ) + + elif msg_type == "get_props_response": + # Handle response for pending get_props request + request_id = message.get("requestId") + if request_id in pending_get_props: + future = pending_get_props.pop(request_id) + if not future.done(): + future.set_result(message.get("payload")) + + elif msg_type == "heartbeat": + await websocket.send_json({"type": "heartbeat_ack"}) + + except WebSocketDisconnect: + pass # Clean disconnect + finally: + # Cancel any pending futures + for future in pending_get_props.values(): + if not future.done(): + future.cancel() + + self.server.add_api_websocket_route(ws_path, websocket_handler) + + async def _execute_ws_callback( + self, + dash_app: "Dash", + websocket: WebSocket, + message: dict, + pending_get_props: Dict[str, asyncio.Future], + ) -> dict: + """Execute callback from WebSocket message. + + Args: + dash_app: The Dash application instance + websocket: The WebSocket connection + message: The callback request message + pending_get_props: Dict to track pending get_props requests + + Returns: + Response dict with status and data + """ + payload = message.get("payload", {}) + + # Create WebSocket callback context + cb_ctx = self._create_ws_context( + dash_app, websocket, payload, pending_get_props + ) + + try: + # Reuse existing callback machinery + # pylint: disable=protected-access + func = dash_app._prepare_callback(cb_ctx, payload) + args = dash_app._inputs_to_vals(cb_ctx.inputs_list + cb_ctx.states_list) + ctx = copy_context() + partial_func = dash_app._execute_callback( + func, args, cb_ctx.outputs_list, cb_ctx + ) + # pylint: enable=protected-access + response_data = ctx.run(partial_func) + if inspect.iscoroutine(response_data): + response_data = await response_data + + return {"status": "ok", "data": json.loads(response_data)} + + except PreventUpdate: + return {"status": "prevent_update"} + except Exception as e: # pylint: disable=broad-exception-caught + traceback.print_exc() + return {"status": "error", "message": str(e)} + + def _create_ws_context( + self, + _dash_app: "Dash", # pylint: disable=unused-argument + websocket: WebSocket, + payload: dict, + pending_get_props: Dict[str, asyncio.Future], + ): + """Create callback context from WebSocket message. + + Args: + _dash_app: The Dash application instance (unused, kept for API consistency) + websocket: The WebSocket connection + payload: The callback payload + pending_get_props: Dict to track pending get_props requests + + Returns: + AttributeDict with callback context + """ + # pylint: disable=import-outside-toplevel + from dash._utils import AttributeDict, inputs_to_dict + + g = AttributeDict({}) + g.inputs_list = payload.get("inputs", []) + g.states_list = payload.get("state", []) + g.outputs_list = payload.get("outputs", []) + g.input_values = inputs_to_dict(g.inputs_list) + g.state_values = inputs_to_dict(g.states_list) + g.triggered_inputs = [ + {"prop_id": x, "value": g.input_values.get(x)} + for x in payload.get("changedPropIds", []) + ] + g.dash_response = FastAPIResponseAdapter() + g.updated_props = {} + + # Add WebSocket callback interface + g.dash_websocket = FastAPIWebsocketCallback(websocket, pending_get_props) + + return g + class FastAPIRequestAdapter(RequestAdapter): def __init__(self): diff --git a/dash/backends/base_server.py b/dash/backends/base_server.py index f7211f44a6..94e00d1bfc 100644 --- a/dash/backends/base_server.py +++ b/dash/backends/base_server.py @@ -169,6 +169,7 @@ class BaseDashServer(ABC, Generic[ServerType]): config: Dict[str, Any] request_adapter: Type[RequestAdapter] response_adapter: Type[ResponseAdapter] + websocket_capability: bool = False def __init__(self, server: ServerType) -> None: """Initialize the server wrapper. @@ -372,3 +373,42 @@ def setup_backend(self, dash_app: "dash.Dash"): Args: dash_app: The Dash application instance """ + + def serve_websocket_callback(self, dash_app: "dash.Dash"): + """Set up the WebSocket endpoint for callback handling. + + Override this method in backends that support WebSocket callbacks. + + Args: + dash_app: The Dash application instance + """ + + +class DashWebsocketCallback(ABC): + """Abstract interface for WebSocket-based callback communication. + + Provides methods for real-time bidirectional communication between + the server and renderer during callback execution. + """ + + @abstractmethod + async def get_prop(self, component_id: str, prop_name: str) -> Any: + """Request current prop value from the client. + + Args: + component_id: The component ID (string or stringified dict for pattern matching) + prop_name: The property name to retrieve + + Returns: + The current value of the property from the client's state + """ + + @abstractmethod + async def set_prop(self, component_id: str, prop_name: str, value: Any) -> None: + """Send immediate prop update to the client via WebSocket. + + Args: + component_id: The component ID (string or stringified dict for pattern matching) + prop_name: The property name to update + value: The new value to set + """ diff --git a/dash/dash-renderer/init.template b/dash/dash-renderer/init.template index 463cfa02aa..a6b84d3d70 100644 --- a/dash/dash-renderer/init.template +++ b/dash/dash-renderer/init.template @@ -75,4 +75,9 @@ _js_dist = [ "namespace": "dash", "dynamic": True, }, + { + "relative_package_path": "dash-renderer/build/dash-ws-worker.js", + "namespace": "dash", + "dynamic": True, + }, ] diff --git a/dash/dash-renderer/package.json b/dash/dash-renderer/package.json index f92d22cfc5..a404fa2425 100644 --- a/dash/dash-renderer/package.json +++ b/dash/dash-renderer/package.json @@ -13,7 +13,7 @@ "build:dev": "webpack", "build:local": "renderer build local", "build": "renderer build && npm run prepublishOnly", - "postbuild": "es-check es2015 ../deps/*.js build/*.js", + "postbuild": "es-check es2015 ../deps/*.js build/dash_renderer.*.js", "test": "karma start karma.conf.js --single-run", "format": "run-s private::format.*", "lint": "run-s private::lint.* --continue-on-error" diff --git a/dash/dash-renderer/src/AppProvider.react.tsx b/dash/dash-renderer/src/AppProvider.react.tsx index 343789ca43..2a6b95240c 100644 --- a/dash/dash-renderer/src/AppProvider.react.tsx +++ b/dash/dash-renderer/src/AppProvider.react.tsx @@ -1,9 +1,14 @@ import PropTypes from 'prop-types'; -import React, {useState} from 'react'; +import React, {useState, useEffect} from 'react'; import {Provider} from 'react-redux'; import Store from './store'; import AppContainer from './AppContainer.react'; +import getConfigFromDOM from './config'; +import { + initializeWebSocket, + disconnectWebSocket +} from './observers/websocketObserver'; const AppProvider = ({ hooks = { @@ -16,6 +21,31 @@ const AppProvider = ({ } }: any) => { const [{store}] = useState(() => new Store()); + + // Initialize WebSocket connection if enabled + useEffect(() => { + const config = getConfigFromDOM(); + if (config.websocket?.enabled) { + // Add fetch config for consistency + const fullConfig = { + ...config, + fetch: { + credentials: 'same-origin', + headers: { + Accept: 'application/json', + 'Content-Type': 'application/json' + } + } + }; + initializeWebSocket(store, fullConfig); + } + + // Cleanup on unmount + return () => { + disconnectWebSocket(); + }; + }, [store]); + return ( diff --git a/dash/dash-renderer/src/actions/callbacks.ts b/dash/dash-renderer/src/actions/callbacks.ts index 37aab3f194..2b487e72c9 100644 --- a/dash/dash-renderer/src/actions/callbacks.ts +++ b/dash/dash-renderer/src/actions/callbacks.ts @@ -52,6 +52,7 @@ import {parsePMCId} from './patternMatching'; import {replacePMC} from './patternMatching'; import {loaded, loading} from './loading'; import {getComponentLayout} from '../wrapper/wrapping'; +import {getWorkerClient, isWebSocketEnabled} from '../utils/workerClient'; export const addBlockedCallbacks = createAction( CallbackActionType.AddBlocked @@ -685,6 +686,137 @@ function handleServerside( }); } +/** + * Handle serverside callback via WebSocket connection. + * + * Uses the SharedWorker to send the callback request through the persistent + * WebSocket connection instead of HTTP POST. + */ +async function handleWebsocketCallback( + dispatch: any, + hooks: any, + config: any, + payload: ICallbackPayload, + running: any +): Promise { + if (hooks.request_pre) { + hooks.request_pre(payload); + } + + const requestTime = Date.now(); + let runningOff: any; + + if (running) { + dispatch(sideUpdate(running.running, payload)); + runningOff = running.runningOff; + } + + const workerClient = getWorkerClient(); + + try { + const response = await workerClient.sendCallback(payload); + + // Handle running off state + if (runningOff) { + dispatch(sideUpdate(runningOff, payload)); + } + + if (response.status === 'prevent_update') { + // Record timing for profiling + if (config.ui) { + const totalTime = Date.now() - requestTime; + dispatch( + updateResourceUsage({ + id: payload.output, + usage: { + __dash_server: totalTime, + __dash_client: totalTime, + __dash_upload: 0, + __dash_download: 0 + }, + status: STATUS.PREVENT_UPDATE, + result: {}, + inputs: payload.inputs, + state: payload.state + }) + ); + } + return {}; + } + + if (response.status === 'error') { + throw new Error(response.message || 'Callback error'); + } + + // Extract the callback data - structure is {multi: boolean, response: {...}} + const callbackData = response.data as CallbackResponseData; + + // Handle sideUpdate if present + if (callbackData?.sideUpdate) { + dispatch(sideUpdate(callbackData.sideUpdate, payload)); + } + + // Extract the actual outputs from the response + // Format is similar to HTTP path's finishLine function + let result: CallbackResponse; + const {multi, response: callbackResponse} = callbackData || {}; + + if (hooks.request_post) { + hooks.request_post(payload, callbackResponse); + } + + if (multi) { + result = callbackResponse as CallbackResponse; + } else { + // Single output - convert to the expected format + const {output} = payload; + const id = output.substr(0, output.lastIndexOf('.')); + result = {[id]: (callbackResponse as CallbackResponse)?.props}; + } + + // Record timing for profiling + if (config.ui) { + const totalTime = Date.now() - requestTime; + dispatch( + updateResourceUsage({ + id: payload.output, + usage: { + __dash_server: totalTime, + __dash_client: totalTime, + __dash_upload: 0, + __dash_download: 0 + }, + status: STATUS.OK, + result: result || {}, + inputs: payload.inputs, + state: payload.state + }) + ); + } + + return result || {}; + } catch (error) { + // Handle running off state on error + if (runningOff) { + dispatch(sideUpdate(runningOff, payload)); + } + + if (config.ui) { + dispatch( + updateResourceUsage({ + id: payload.output, + status: STATUS.NO_RESPONSE, + result: {}, + inputs: payload.inputs, + state: payload.state + }) + ); + } + + throw error; + } +} + function inputsToDict(inputs_list: any) { // Ported directly from _utils.py, inputs_to_dict // takes an array of inputs (some inputs may be an array) @@ -890,18 +1022,37 @@ export function executeCallback( } ); + // Use WebSocket for callbacks when enabled (but not for background callbacks) + const useWebSocket = isWebSocketEnabled(config) && !background; + for (let retry = 0; retry <= MAX_AUTH_RETRIES; retry++) { try { - let data = await handleServerside( - dispatch, - hooks, - newConfig, - payload, - background, - additionalArgs.length ? additionalArgs : undefined, - getState, - cb.callback.running - ); + let data: CallbackResponse; + + if (useWebSocket) { + // Use WebSocket path for real-time callbacks + data = await handleWebsocketCallback( + dispatch, + hooks, + newConfig, + payload, + cb.callback.running + ); + } else { + // Use traditional HTTP path + data = await handleServerside( + dispatch, + hooks, + newConfig, + payload, + background, + additionalArgs.length + ? additionalArgs + : undefined, + getState, + cb.callback.running + ); + } if (newHeaders) { dispatch(addHttpHeaders(newHeaders)); diff --git a/dash/dash-renderer/src/config.ts b/dash/dash-renderer/src/config.ts index ac18678364..caf49e348d 100644 --- a/dash/dash-renderer/src/config.ts +++ b/dash/dash-renderer/src/config.ts @@ -22,6 +22,11 @@ export type DashConfig = { serve_locally?: boolean; plotlyjs_url?: string; validate_callbacks: boolean; + websocket?: { + enabled: boolean; + url: string; + worker_url: string; + }; }; export default function getConfigFromDOM(): DashConfig { diff --git a/dash/dash-renderer/src/observers/websocketObserver.ts b/dash/dash-renderer/src/observers/websocketObserver.ts new file mode 100644 index 0000000000..1faa7c5a34 --- /dev/null +++ b/dash/dash-renderer/src/observers/websocketObserver.ts @@ -0,0 +1,148 @@ +/** + * Observer for handling incoming WebSocket messages (SET_PROPS, GET_PROPS_REQUEST). + */ + +/* eslint-disable no-console */ + +import {Store} from 'redux'; +import {path} from 'ramda'; + +import {IStoreState} from '../store'; +import {updateProps, notifyObservers} from '../actions'; +import {getPath} from '../actions/paths'; +import { + getWorkerClient, + SetPropsPayload, + GetPropsRequestPayload +} from '../utils/workerClient'; +import {DashConfig} from '../config'; + +/** + * Initialize the WebSocket observer. + * + * Sets up handlers for: + * - SET_PROPS: Update component props when received from server + * - GET_PROPS_REQUEST: Send current prop values back to server + * + * @param store Redux store + * @param config Dash configuration + */ +export async function initializeWebSocket( + store: Store, + config: DashConfig +): Promise { + if (!config.websocket?.enabled) { + return; + } + + // Check if SharedWorker is supported + if (typeof SharedWorker === 'undefined') { + console.warn( + 'SharedWorker not supported in this browser. ' + + 'WebSocket callbacks will fall back to HTTP.' + ); + return; + } + + const workerClient = getWorkerClient(); + + // Handle SET_PROPS messages + workerClient.onSetProps = (payload: SetPropsPayload) => { + const {componentId, props} = payload; + const state = store.getState(); + const componentPath = getPath(state.paths, componentId); + + if (!componentPath) { + console.warn( + `SET_PROPS: Component ${componentId} not found in layout` + ); + return; + } + + // Update the component props + store.dispatch( + updateProps({ + props, + itempath: componentPath, + renderType: 'websocket' + }) as any + ); + + // Notify observers + store.dispatch(notifyObservers({id: componentId, props}) as any); + }; + + // Handle GET_PROPS_REQUEST messages + workerClient.onGetPropsRequest = ( + requestId: string, + payload: GetPropsRequestPayload + ) => { + const {componentId, properties} = payload; + const state = store.getState(); + const componentPath = getPath(state.paths, componentId); + + const result: Record = {}; + + if (componentPath) { + const componentProps = path( + [...componentPath, 'props'], + state.layout + ) as Record | undefined; + + if (componentProps) { + for (const propName of properties) { + result[propName] = componentProps[propName]; + } + } + } + + // Send the response + workerClient.sendGetPropsResponse(requestId, result); + }; + + // Handle connection events + workerClient.onConnected = () => { + console.log('[Dash] WebSocket connected'); + }; + + workerClient.onDisconnected = (reason?: string) => { + console.log(`[Dash] WebSocket disconnected: ${reason}`); + }; + + workerClient.onError = (message: string, code?: string) => { + console.error(`[Dash] WebSocket error: ${message}`, code); + }; + + // Connect to the worker + const wsUrl = buildWebSocketUrl(config); + + try { + await workerClient.connect(config.websocket.worker_url, wsUrl); + } catch (error) { + console.error('[Dash] Failed to connect to WebSocket worker:', error); + } +} + +/** + * Build the WebSocket URL from config. + */ +function buildWebSocketUrl(config: DashConfig): string { + if (!config.websocket?.url) { + throw new Error('WebSocket URL not configured'); + } + + // Convert HTTP(S) URL to WS(S) + const wsProtocol = window.location.protocol === 'https:' ? 'wss:' : 'ws:'; + const host = window.location.host; + + // The config.websocket.url is a path like "/_dash-ws-callback" + return `${wsProtocol}//${host}${config.websocket.url}`; +} + +/** + * Disconnect from the WebSocket. + */ +export function disconnectWebSocket(): void { + const workerClient = getWorkerClient(); + workerClient.disconnect(); +} diff --git a/dash/dash-renderer/src/utils/rendererId.ts b/dash/dash-renderer/src/utils/rendererId.ts new file mode 100644 index 0000000000..8168d1576d --- /dev/null +++ b/dash/dash-renderer/src/utils/rendererId.ts @@ -0,0 +1,23 @@ +/** + * Generate or retrieve a unique renderer ID for this browser tab/session. + * + * The ID is stored in sessionStorage to persist across page reloads + * but remain unique per tab. + */ +export function getRendererId(): string { + const key = '__dash_renderer_id'; + let id = sessionStorage.getItem(key); + + if (!id) { + // Generate a unique ID + if (typeof crypto !== 'undefined' && crypto.randomUUID) { + id = crypto.randomUUID(); + } else { + // Fallback for older browsers + id = `${Date.now()}-${Math.random().toString(36).slice(2)}`; + } + sessionStorage.setItem(key, id); + } + + return id; +} diff --git a/dash/dash-renderer/src/utils/workerClient.ts b/dash/dash-renderer/src/utils/workerClient.ts new file mode 100644 index 0000000000..b38fc5a68a --- /dev/null +++ b/dash/dash-renderer/src/utils/workerClient.ts @@ -0,0 +1,276 @@ +/** + * Client for communicating with the Dash WebSocket SharedWorker. + */ + +import {getRendererId} from './rendererId'; + +/** Message types for worker communication */ +export enum WorkerMessageType { + CONNECT = 'connect', + DISCONNECT = 'disconnect', + CALLBACK_REQUEST = 'callback_request', + GET_PROPS_RESPONSE = 'get_props_response', + CONNECTED = 'connected', + DISCONNECTED = 'disconnected', + CALLBACK_RESPONSE = 'callback_response', + SET_PROPS = 'set_props', + GET_PROPS_REQUEST = 'get_props_request', + ERROR = 'error' +} + +/** Callback response structure */ +export interface CallbackResponse { + status: 'ok' | 'prevent_update' | 'error'; + data?: Record; + message?: string; +} + +/** Set props message payload */ +export interface SetPropsPayload { + componentId: string; + props: Record; +} + +/** Get props request payload */ +export interface GetPropsRequestPayload { + componentId: string; + properties: string[]; +} + +/** Pending callback request */ +interface PendingRequest { + resolve: (value: CallbackResponse) => void; + reject: (error: Error) => void; +} + +/** + * Client for the Dash WebSocket SharedWorker. + */ +class WorkerClient { + private worker: SharedWorker | null = null; + private rendererId: string; + private pendingCallbacks: Map = new Map(); + private requestCounter = 0; + private isConnected = false; + private connectionPromise: Promise | null = null; + private connectionResolve: (() => void) | null = null; + + /** Callback when SET_PROPS message is received */ + public onSetProps: ((payload: SetPropsPayload) => void) | null = null; + + /** Callback when GET_PROPS_REQUEST message is received */ + public onGetPropsRequest: + | ((requestId: string, payload: GetPropsRequestPayload) => void) + | null = null; + + /** Callback when connection is established */ + public onConnected: (() => void) | null = null; + + /** Callback when connection is lost */ + public onDisconnected: ((reason?: string) => void) | null = null; + + /** Callback when an error occurs */ + public onError: ((message: string, code?: string) => void) | null = null; + + constructor() { + this.rendererId = getRendererId(); + } + + /** + * Initialize the worker connection. + * @param workerUrl URL to the SharedWorker script + * @param serverUrl WebSocket server URL + */ + public async connect(workerUrl: string, serverUrl: string): Promise { + if (this.worker) { + // Already connected + return; + } + + // Create the SharedWorker + this.worker = new SharedWorker(workerUrl, { + name: 'dash-ws-worker' + }); + + // Set up message handling + this.worker.port.onmessage = this.handleMessage.bind(this); + + // Create promise for connection + this.connectionPromise = new Promise(resolve => { + this.connectionResolve = resolve; + }); + + // Start the port + this.worker.port.start(); + + // Send connect message + this.worker.port.postMessage({ + type: WorkerMessageType.CONNECT, + rendererId: this.rendererId, + payload: { + serverUrl + } + }); + + // Wait for connection + await this.connectionPromise; + } + + /** + * Disconnect from the worker. + */ + public disconnect(): void { + if (this.worker) { + this.worker.port.postMessage({ + type: WorkerMessageType.DISCONNECT, + rendererId: this.rendererId + }); + this.worker.port.close(); + this.worker = null; + } + this.isConnected = false; + this.connectionPromise = null; + this.connectionResolve = null; + + // Reject any pending callbacks + for (const [, pending] of this.pendingCallbacks) { + pending.reject(new Error('Worker disconnected')); + } + this.pendingCallbacks.clear(); + } + + /** + * Send a callback request to the server via the worker. + * @param payload The callback payload + * @returns Promise that resolves with the callback response + */ + public async sendCallback(payload: unknown): Promise { + // Wait for connection if one is in progress + if (this.connectionPromise && !this.isConnected) { + await this.connectionPromise; + } + + if (!this.worker || !this.isConnected) { + throw new Error('Worker not connected'); + } + + const requestId = `${this.rendererId}-${++this.requestCounter}`; + + return new Promise((resolve, reject) => { + this.pendingCallbacks.set(requestId, {resolve, reject}); + + this.worker!.port.postMessage({ + type: WorkerMessageType.CALLBACK_REQUEST, + rendererId: this.rendererId, + requestId, + payload + }); + }); + } + + /** + * Send a get_props response back to the server. + * @param requestId The request ID from the get_props request + * @param props The property values + */ + public sendGetPropsResponse( + requestId: string, + props: Record + ): void { + if (!this.worker || !this.isConnected) { + return; + } + + this.worker.port.postMessage({ + type: WorkerMessageType.GET_PROPS_RESPONSE, + rendererId: this.rendererId, + requestId, + payload: props + }); + } + + /** + * Check if the worker is connected. + */ + public get connected(): boolean { + return this.isConnected; + } + + private handleMessage(event: MessageEvent): void { + const message = event.data; + + switch (message.type) { + case WorkerMessageType.CONNECTED: + this.isConnected = true; + if (this.connectionResolve) { + this.connectionResolve(); + this.connectionResolve = null; + } + if (this.onConnected) { + this.onConnected(); + } + break; + + case WorkerMessageType.DISCONNECTED: + this.isConnected = false; + if (this.onDisconnected) { + this.onDisconnected(message.payload?.reason); + } + break; + + case WorkerMessageType.CALLBACK_RESPONSE: { + const requestId = message.requestId; + const pending = this.pendingCallbacks.get(requestId); + if (pending) { + this.pendingCallbacks.delete(requestId); + pending.resolve(message.payload); + } + break; + } + + case WorkerMessageType.SET_PROPS: + if (this.onSetProps) { + this.onSetProps(message.payload); + } + break; + + case WorkerMessageType.GET_PROPS_REQUEST: + if (this.onGetPropsRequest) { + this.onGetPropsRequest(message.requestId, message.payload); + } + break; + + case WorkerMessageType.ERROR: + if (this.onError) { + this.onError( + message.payload?.message || 'Unknown error', + message.payload?.code + ); + } + break; + } + } +} + +// Singleton instance +let workerClientInstance: WorkerClient | null = null; + +/** + * Get the singleton WorkerClient instance. + */ +export function getWorkerClient(): WorkerClient { + if (!workerClientInstance) { + workerClientInstance = new WorkerClient(); + } + return workerClientInstance; +} + +/** + * Check if WebSocket callbacks are enabled and supported. + * @param config The Dash config + */ +export function isWebSocketEnabled(config: { + websocket?: {enabled: boolean}; +}): boolean { + return !!(config.websocket?.enabled && typeof SharedWorker !== 'undefined'); +} diff --git a/dash/dash-renderer/webpack.base.config.js b/dash/dash-renderer/webpack.base.config.js index ed95239f7d..e8a9d14596 100644 --- a/dash/dash-renderer/webpack.base.config.js +++ b/dash/dash-renderer/webpack.base.config.js @@ -72,6 +72,31 @@ const rendererOptions = { ...defaults }; +// WebSocket Worker configuration +const workerOptions = { + mode: 'production', + entry: { + 'dash-ws-worker': '../../@plotly/dash-websocket-worker/src/worker.ts', + }, + output: { + path: path.resolve(__dirname, "build"), + filename: '[name].js', + }, + target: 'webworker', + module: { + rules: [ + { + test: /\.ts$/, + exclude: /node_modules/, + use: ['ts-loader'], + }, + ] + }, + resolve: { + extensions: ['.ts', '.js'] + } +}; + module.exports = options => [ R.mergeAll([ options, @@ -109,5 +134,7 @@ module.exports = options => [ ] ), } - ]) + ]), + // WebSocket Worker build + workerOptions ]; diff --git a/dash/dash.py b/dash/dash.py index ea53edf341..ca8520f981 100644 --- a/dash/dash.py +++ b/dash/dash.py @@ -472,6 +472,7 @@ def __init__( # pylint: disable=too-many-statements, too-many-branches on_error: Optional[Callable[[Exception], Any]] = None, use_async: Optional[bool] = None, health_endpoint: Optional[str] = None, + websocket_callbacks: Optional[bool] = False, **obsolete, ): @@ -619,6 +620,7 @@ def __init__( # pylint: disable=too-many-statements, too-many-branches self._assets_files: list = [] self._background_manager = background_callback_manager + self._websocket_callbacks = websocket_callbacks self.logger = logging.getLogger(__name__) @@ -761,6 +763,11 @@ def _setup_routes(self): ) if self.config.health_endpoint is not None: self._add_url(self.config.health_endpoint, self.serve_health) + + # Set up WebSocket callback route if enabled and supported + if self._websocket_callbacks and self.backend.websocket_capability: + self.backend.serve_websocket_callback(self) + self.backend.setup_index(self) self.backend.setup_catchall(self) @@ -940,6 +947,14 @@ def _config(self): custom_dev_tools.append({**hook_dev_tools, "props": props}) config["dev_tools"] = custom_dev_tools + # Add websocket config if enabled and backend supports it + if self._websocket_callbacks and self.backend.websocket_capability: + config["websocket"] = { + "enabled": True, + "url": self.config.requests_pathname_prefix + "_dash-ws-callback", + "worker_url": self._get_worker_url(), + } + return config def serve_reload_hash(self): @@ -967,6 +982,33 @@ def serve_health(self): """ return self.backend.make_response("OK", status=200, mimetype="text/plain") + def _get_worker_url(self) -> str: + """Get the URL for the WebSocket worker script. + + Returns: + The fingerprinted URL for the worker script served via component suites. + """ + relative_path = "dash-renderer/build/dash-ws-worker.js" + namespace = "dash" + + # Register the path so it can be served + self.registered_paths[namespace].add(relative_path) + + # Build fingerprinted URL (same pattern as _collect_and_register_resources) + module_path = os.path.join( + os.path.dirname(sys.modules[namespace].__file__), # type: ignore + relative_path, + ) + + # Use a fallback if the file doesn't exist yet (during development) + try: + modified = int(os.stat(module_path).st_mtime) + except FileNotFoundError: + modified = 0 + + fingerprint = build_fingerprint(relative_path, __version__, modified) + return f"{self.config.requests_pathname_prefix}_dash-component-suites/{namespace}/{fingerprint}" + def get_dist(self, libraries: Sequence[str]) -> list: dists = [] for dist_type in ("_js_dist", "_css_dist"): diff --git a/wsapp.py b/wsapp.py new file mode 100644 index 0000000000..98b2db2f38 --- /dev/null +++ b/wsapp.py @@ -0,0 +1,106 @@ +""" +Test app for WebSocket-based callbacks. + +Run with: + python wsapp.py + +Then open http://127.0.0.1:8050 in your browser. +""" + +from dash import Dash, html, dcc, callback, Output, Input, ctx +import time + +# Create app with FastAPI backend and WebSocket callbacks enabled +app = Dash( + __name__, + backend="fastapi", + websocket_callbacks=True, +) + +app.layout = html.Div([ + html.H1("WebSocket Callbacks Test"), + + html.Div([ + html.H3("Basic Callback Test"), + html.Button("Click me", id="btn-1", n_clicks=0), + html.Div(id="output-1"), + ], style={"marginBottom": "20px", "padding": "10px", "border": "1px solid #ccc"}), + + html.Div([ + html.H3("Input Test"), + dcc.Input(id="input-1", type="text", placeholder="Type something..."), + html.Div(id="output-2"), + ], style={"marginBottom": "20px", "padding": "10px", "border": "1px solid #ccc"}), + + html.Div([ + html.H3("Slider Test"), + dcc.Slider(id="slider-1", min=0, max=100, value=50), + html.Div(id="output-3"), + ], style={"marginBottom": "20px", "padding": "10px", "border": "1px solid #ccc"}), + + html.Div([ + html.H3("set_props Test"), + html.Button("Update via set_props", id="btn-2", n_clicks=0), + html.Div(id="output-4", children="Initial content"), + html.Div(id="output-5", children="Will be updated by set_props"), + ], style={"marginBottom": "20px", "padding": "10px", "border": "1px solid #ccc"}), + + html.Div([ + html.H3("WebSocket Context Test"), + html.Button("Check WebSocket Context", id="btn-3", n_clicks=0), + html.Div(id="output-6"), + ], style={"marginBottom": "20px", "padding": "10px", "border": "1px solid #ccc"}), + + html.Div(id="config-display", style={"marginTop": "20px", "fontSize": "12px", "color": "#666"}), +]) + + +@callback(Output("output-1", "children"), Input("btn-1", "n_clicks")) +def update_output_1(n_clicks): + return f"Button clicked {n_clicks} times" + + +@callback(Output("output-2", "children"), Input("input-1", "value")) +def update_output_2(value): + return f"You typed: {value}" + + +@callback(Output("output-3", "children"), Input("slider-1", "value")) +def update_output_3(value): + return f"Slider value: {value}" + + +@callback(Output("output-4", "children"), Input("btn-2", "n_clicks")) +def update_with_set_props(n_clicks): + if n_clicks > 0: + # Use set_props to update another component + from dash._callback_context import set_props + set_props("output-5", {"children": f"Updated via set_props at click {n_clicks}"}) + return f"set_props button clicked {n_clicks} times" + + +@callback(Output("output-6", "children"), Input("btn-3", "n_clicks")) +def check_websocket_context(n_clicks): + if n_clicks > 0: + ws = ctx.get_websocket + if ws is not None: + return f"WebSocket context is available! (click {n_clicks})" + else: + return f"WebSocket context is None (click {n_clicks}) - may be using HTTP fallback" + return "Click to check WebSocket context" + + +@callback(Output("config-display", "children"), Input("btn-1", "n_clicks")) +def show_config(n_clicks): + config = app._config() + ws_config = config.get("websocket", {}) + if ws_config: + return f"WebSocket enabled: {ws_config.get('enabled')}, URL: {ws_config.get('url')}" + return "WebSocket not configured" + + +if __name__ == "__main__": + print("Starting WebSocket callbacks test app...") + print(f"WebSocket callbacks enabled: {app._websocket_callbacks}") + print(f"Backend websocket capability: {app.backend.websocket_capability}") + app.run(debug=True, port=8050) From ad3693e2a9990c2003c12ae7e746ab612f183ea5 Mon Sep 17 00:00:00 2001 From: philippe Date: Tue, 14 Apr 2026 12:44:29 -0400 Subject: [PATCH 122/166] websocket flag for individual callbacks --- dash/_callback.py | 6 +++++- dash/dash-renderer/src/actions/callbacks.ts | 17 ++++++++++++++--- .../src/observers/websocketObserver.ts | 11 +++++++++-- dash/dash-renderer/src/types/callbacks.ts | 1 + dash/dash-renderer/src/utils/workerClient.ts | 16 +++++++++++++++- dash/dash.py | 12 +++++++----- 6 files changed, 51 insertions(+), 12 deletions(-) diff --git a/dash/_callback.py b/dash/_callback.py index ff1072efd3..b27e9a18b9 100644 --- a/dash/_callback.py +++ b/dash/_callback.py @@ -77,7 +77,7 @@ def callback( api_endpoint: Optional[str] = None, optional: Optional[bool] = False, hidden: Optional[bool] = None, - _websocket: Optional[bool] = False, # Reserved for future use + websocket: Optional[bool] = False, **_kwargs, ) -> Callable[..., Any]: """ @@ -229,6 +229,7 @@ def callback( api_endpoint=api_endpoint, optional=optional, hidden=hidden, + websocket=websocket, ) @@ -276,6 +277,7 @@ def insert_callback( no_output=False, optional=False, hidden=None, + websocket=False, ) -> str: if prevent_initial_call is None: prevent_initial_call = config_prevent_initial_callbacks @@ -301,6 +303,7 @@ def insert_callback( "no_output": no_output, "optional": optional, "hidden": hidden, + "websocket": websocket, } if running: callback_spec["running"] = running @@ -653,6 +656,7 @@ def register_callback( no_output=not has_output, optional=_kwargs.get("optional", False), hidden=_kwargs.get("hidden", None), + websocket=_kwargs.get("websocket", False), ) # pylint: disable=too-many-locals diff --git a/dash/dash-renderer/src/actions/callbacks.ts b/dash/dash-renderer/src/actions/callbacks.ts index 2b487e72c9..8ca781b50b 100644 --- a/dash/dash-renderer/src/actions/callbacks.ts +++ b/dash/dash-renderer/src/actions/callbacks.ts @@ -52,7 +52,11 @@ import {parsePMCId} from './patternMatching'; import {replacePMC} from './patternMatching'; import {loaded, loading} from './loading'; import {getComponentLayout} from '../wrapper/wrapping'; -import {getWorkerClient, isWebSocketEnabled} from '../utils/workerClient'; +import { + getWorkerClient, + isWebSocketEnabled, + isWebSocketAvailable +} from '../utils/workerClient'; export const addBlockedCallbacks = createAction( CallbackActionType.AddBlocked @@ -1022,8 +1026,15 @@ export function executeCallback( } ); - // Use WebSocket for callbacks when enabled (but not for background callbacks) - const useWebSocket = isWebSocketEnabled(config) && !background; + // Use WebSocket for callbacks when: + // 1. Global WebSocket is enabled, OR + // 2. Per-callback websocket flag is set (and WebSocket is available) + // (but never for background callbacks) + const useWebSocket = + !background && + (isWebSocketEnabled(config) || + (cb.callback.websocket && + isWebSocketAvailable(config))); for (let retry = 0; retry <= MAX_AUTH_RETRIES; retry++) { try { diff --git a/dash/dash-renderer/src/observers/websocketObserver.ts b/dash/dash-renderer/src/observers/websocketObserver.ts index 1faa7c5a34..b86be84b4c 100644 --- a/dash/dash-renderer/src/observers/websocketObserver.ts +++ b/dash/dash-renderer/src/observers/websocketObserver.ts @@ -31,7 +31,13 @@ export async function initializeWebSocket( store: Store, config: DashConfig ): Promise { - if (!config.websocket?.enabled) { + // Initialize WebSocket if: + // 1. Global websocket is enabled, OR + // 2. WebSocket config is available (for per-callback websocket=True) + const wsAvailable = !!( + config.websocket?.url && config.websocket?.worker_url + ); + if (!wsAvailable) { return; } @@ -117,7 +123,8 @@ export async function initializeWebSocket( const wsUrl = buildWebSocketUrl(config); try { - await workerClient.connect(config.websocket.worker_url, wsUrl); + // config.websocket is guaranteed to exist due to wsAvailable check above + await workerClient.connect(config.websocket!.worker_url, wsUrl); } catch (error) { console.error('[Dash] Failed to connect to WebSocket worker:', error); } diff --git a/dash/dash-renderer/src/types/callbacks.ts b/dash/dash-renderer/src/types/callbacks.ts index f1e1dc382c..38a5d7d82f 100644 --- a/dash/dash-renderer/src/types/callbacks.ts +++ b/dash/dash-renderer/src/types/callbacks.ts @@ -15,6 +15,7 @@ export interface ICallbackDefinition { dynamic_creator?: boolean; running: any; no_output?: boolean; + websocket?: boolean; } export interface ICallbackProperty { diff --git a/dash/dash-renderer/src/utils/workerClient.ts b/dash/dash-renderer/src/utils/workerClient.ts index b38fc5a68a..a76dc96b50 100644 --- a/dash/dash-renderer/src/utils/workerClient.ts +++ b/dash/dash-renderer/src/utils/workerClient.ts @@ -266,7 +266,7 @@ export function getWorkerClient(): WorkerClient { } /** - * Check if WebSocket callbacks are enabled and supported. + * Check if WebSocket callbacks are globally enabled and supported. * @param config The Dash config */ export function isWebSocketEnabled(config: { @@ -274,3 +274,17 @@ export function isWebSocketEnabled(config: { }): boolean { return !!(config.websocket?.enabled && typeof SharedWorker !== 'undefined'); } + +/** + * Check if WebSocket infrastructure is available (for per-callback websocket). + * @param config The Dash config + */ +export function isWebSocketAvailable(config: { + websocket?: {enabled?: boolean; url?: string; worker_url?: string}; +}): boolean { + return !!( + config.websocket?.url && + config.websocket?.worker_url && + typeof SharedWorker !== 'undefined' + ); +} diff --git a/dash/dash.py b/dash/dash.py index ca8520f981..a4a9ba3597 100644 --- a/dash/dash.py +++ b/dash/dash.py @@ -764,8 +764,9 @@ def _setup_routes(self): if self.config.health_endpoint is not None: self._add_url(self.config.health_endpoint, self.serve_health) - # Set up WebSocket callback route if enabled and supported - if self._websocket_callbacks and self.backend.websocket_capability: + # Set up WebSocket callback route if backend supports it + # This enables both global websocket_callbacks and per-callback websocket=True + if self.backend.websocket_capability: self.backend.serve_websocket_callback(self) self.backend.setup_index(self) @@ -947,10 +948,11 @@ def _config(self): custom_dev_tools.append({**hook_dev_tools, "props": props}) config["dev_tools"] = custom_dev_tools - # Add websocket config if enabled and backend supports it - if self._websocket_callbacks and self.backend.websocket_capability: + # Add websocket config if backend supports it + # This enables both global websocket_callbacks and per-callback websocket=True + if self.backend.websocket_capability: config["websocket"] = { - "enabled": True, + "enabled": bool(self._websocket_callbacks), "url": self.config.requests_pathname_prefix + "_dash-ws-callback", "worker_url": self._get_worker_url(), } From 1a20893fa04be654b305480aa9d375b88da7622f Mon Sep 17 00:00:00 2001 From: philippe Date: Tue, 14 Apr 2026 13:08:15 -0400 Subject: [PATCH 123/166] ensure connected for ws callbacks --- dash/dash-renderer/src/actions/callbacks.ts | 3 ++ dash/dash-renderer/src/utils/workerClient.ts | 36 ++++++++++++++++++++ 2 files changed, 39 insertions(+) diff --git a/dash/dash-renderer/src/actions/callbacks.ts b/dash/dash-renderer/src/actions/callbacks.ts index 8ca781b50b..d2e4a18d1c 100644 --- a/dash/dash-renderer/src/actions/callbacks.ts +++ b/dash/dash-renderer/src/actions/callbacks.ts @@ -718,6 +718,9 @@ async function handleWebsocketCallback( const workerClient = getWorkerClient(); try { + // Ensure WebSocket connection is established + await workerClient.ensureConnected(config); + const response = await workerClient.sendCallback(payload); // Handle running off state diff --git a/dash/dash-renderer/src/utils/workerClient.ts b/dash/dash-renderer/src/utils/workerClient.ts index a76dc96b50..d0f61d022a 100644 --- a/dash/dash-renderer/src/utils/workerClient.ts +++ b/dash/dash-renderer/src/utils/workerClient.ts @@ -139,6 +139,42 @@ class WorkerClient { this.pendingCallbacks.clear(); } + /** + * Ensure the worker is connected, initiating connection if needed. + * @param config The Dash config with websocket settings + */ + public async ensureConnected(config: { + websocket?: {url?: string; worker_url?: string}; + }): Promise { + // Already connected + if (this.isConnected) { + return; + } + + // Connection in progress, wait for it + if (this.connectionPromise) { + await this.connectionPromise; + return; + } + + // Need to initiate connection + if (!config.websocket?.url || !config.websocket?.worker_url) { + throw new Error('WebSocket config not available'); + } + + if (typeof SharedWorker === 'undefined') { + throw new Error('SharedWorker not supported'); + } + + // Build WebSocket URL + const wsProtocol = + window.location.protocol === 'https:' ? 'wss:' : 'ws:'; + const host = window.location.host; + const wsUrl = `${wsProtocol}//${host}${config.websocket.url}`; + + await this.connect(config.websocket.worker_url, wsUrl); + } + /** * Send a callback request to the server via the worker. * @param payload The callback payload From aff6deed4e8eaee8cd1b7550d050dac57668d6e3 Mon Sep 17 00:00:00 2001 From: philippe Date: Tue, 14 Apr 2026 13:20:46 -0400 Subject: [PATCH 124/166] websocket origin validation --- dash/backends/_fastapi.py | 28 ++++++++++++++++++++++++++++ dash/dash.py | 2 ++ 2 files changed, 30 insertions(+) diff --git a/dash/backends/_fastapi.py b/dash/backends/_fastapi.py index f1e9dc838e..167cae66b3 100644 --- a/dash/backends/_fastapi.py +++ b/dash/backends/_fastapi.py @@ -15,6 +15,7 @@ import subprocess import threading import traceback +from urllib.parse import urlparse try: from fastapi import FastAPI, Request, Response, Body @@ -699,7 +700,34 @@ def serve_websocket_callback(self, dash_app: "Dash"): # pylint: disable=too-many-statements ws_path = dash_app.config.requests_pathname_prefix + "_dash-ws-callback" + # Get allowed origins from dash app config + allowed_origins = getattr( + dash_app, "_allowed_websocket_origins", [] + ) # pylint: disable=protected-access + + def validate_origin(origin: str | None, host: str | None) -> str | None: + """Validate WebSocket origin. Returns error message or None if valid.""" + if not origin: + return "Origin header required" + if origin in allowed_origins: + return None # Explicitly allowed + if not host: + return "Origin not allowed" + # Check same-origin + origin_host = urlparse(origin).netloc + if origin_host != host: + return "Origin not allowed" + return None + async def websocket_handler(websocket: WebSocket): + # Validate Origin header to prevent Cross-Site WebSocket Hijacking + origin = websocket.headers.get("origin") + host = websocket.headers.get("host") + error = validate_origin(origin, host) + if error: + await websocket.close(code=4003, reason=error) + return + await websocket.accept() # Track pending get_props requests diff --git a/dash/dash.py b/dash/dash.py index a4a9ba3597..2762d14a9b 100644 --- a/dash/dash.py +++ b/dash/dash.py @@ -473,6 +473,7 @@ def __init__( # pylint: disable=too-many-statements, too-many-branches use_async: Optional[bool] = None, health_endpoint: Optional[str] = None, websocket_callbacks: Optional[bool] = False, + allowed_websocket_origins: Optional[List[str]] = None, **obsolete, ): @@ -621,6 +622,7 @@ def __init__( # pylint: disable=too-many-statements, too-many-branches self._background_manager = background_callback_manager self._websocket_callbacks = websocket_callbacks + self._allowed_websocket_origins = allowed_websocket_origins or [] self.logger = logging.getLogger(__name__) From a4f10fd40cf35ffecbd885b5d88c5efdbed6beb0 Mon Sep 17 00:00:00 2001 From: philippe Date: Tue, 14 Apr 2026 13:34:20 -0400 Subject: [PATCH 125/166] add close method to DashWebsocket --- dash/backends/_fastapi.py | 9 +++++++++ dash/backends/base_server.py | 12 ++++++++++++ 2 files changed, 21 insertions(+) diff --git a/dash/backends/_fastapi.py b/dash/backends/_fastapi.py index 167cae66b3..257e8cb23d 100644 --- a/dash/backends/_fastapi.py +++ b/dash/backends/_fastapi.py @@ -152,6 +152,15 @@ async def get_prop(self, component_id: str, prop_name: str) -> Any: f"Timeout waiting for get_prop response for {component_id}.{prop_name}" ) from exc + async def close(self, code: int = 1000, reason: str = "Connection closed") -> None: + """Close the WebSocket connection. + + Args: + code: WebSocket close code (default 1000 for normal closure) + reason: Human-readable reason for closing + """ + await self._websocket.close(code=code, reason=reason) + _current_request_var = ContextVar("dash_current_request", default=None) diff --git a/dash/backends/base_server.py b/dash/backends/base_server.py index 94e00d1bfc..283a3414e2 100644 --- a/dash/backends/base_server.py +++ b/dash/backends/base_server.py @@ -412,3 +412,15 @@ async def set_prop(self, component_id: str, prop_name: str, value: Any) -> None: prop_name: The property name to update value: The new value to set """ + + @abstractmethod + async def close(self, code: int = 1000, reason: str = "Connection closed") -> None: + """Close the WebSocket connection. + + Allows developers to forcibly disconnect a client, e.g., on suspicious + activity, session revocation, or policy violation. + + Args: + code: WebSocket close code (default 1000 for normal closure) + reason: Human-readable reason for closing + """ From 206c9d81d0b75a302d44c09eaac9f64abbcb21b4 Mon Sep 17 00:00:00 2001 From: philippe Date: Tue, 14 Apr 2026 14:32:10 -0400 Subject: [PATCH 126/166] add websocket hooks --- dash/_hooks.py | 56 +++++++++++++++++++++++++++++++++++++++ dash/backends/_fastapi.py | 50 ++++++++++++++++++++++++++++++++++ 2 files changed, 106 insertions(+) diff --git a/dash/_hooks.py b/dash/_hooks.py index 1631b40ddc..f260b1fcb0 100644 --- a/dash/_hooks.py +++ b/dash/_hooks.py @@ -49,6 +49,8 @@ def __init__(self) -> None: "index": [], "custom_data": [], "dev_tools": [], + "websocket_connect": [], + "websocket_message": [], } self._js_dist: _t.List[_t.Any] = [] self._css_dist: _t.List[_t.Any] = [] @@ -244,6 +246,60 @@ def devtool( } ) + def websocket_connect(self, priority: _t.Optional[int] = None, final: bool = False): + """ + Register a WebSocket connection validation hook. + + The hook receives the WebSocket object and should return: + - True (or any truthy value): Allow the connection + - False: Reject with default code (4001) and reason + - tuple (code, reason): Reject with custom close code and reason + + Hooks can be sync or async. + + Example: + @hooks.websocket_connect() + async def validate_session(websocket): + session_id = websocket.cookies.get("session_id") + if not session_id: + return (4001, "No session cookie") + if not await is_valid_session(session_id): + return (4002, "Invalid session") + return True + """ + + def decorator(func: _t.Callable): + self.add_hook("websocket_connect", func, priority=priority, final=final) + return func + + return decorator + + def websocket_message(self, priority: _t.Optional[int] = None, final: bool = False): + """ + Register a WebSocket message validation hook. + + The hook receives the WebSocket object and message dict, and should return: + - True (or any truthy value): Allow the message + - False: Disconnect with default code (4001) and reason + - tuple (code, reason): Disconnect with custom close code and reason + + Hooks can be sync or async. + + Example: + @hooks.websocket_message() + async def validate_session(websocket, message): + session_id = websocket.cookies.get("session_id") + if not await is_session_active(session_id): + return (4002, "Session expired") + return True + """ + + def decorator(func: _t.Callable): + self.add_hook("websocket_message", func, priority=priority, final=final) + return func + + return decorator + hooks = _Hooks() diff --git a/dash/backends/_fastapi.py b/dash/backends/_fastapi.py index 257e8cb23d..1e2186c384 100644 --- a/dash/backends/_fastapi.py +++ b/dash/backends/_fastapi.py @@ -700,6 +700,33 @@ async def timing_headers_middleware(request: Request, call_next): headers.append("Server-Timing", value) return response + async def _run_ws_hooks( + self, hooks, websocket: "WebSocket", *args, default_reason: str = "Rejected" + ) -> tuple | None: + """Run WebSocket hooks and return rejection tuple or None if all pass. + + Args: + hooks: List of hooks to run + websocket: The WebSocket connection + *args: Additional arguments to pass to hooks + default_reason: Default reason if hook returns False + + Returns: + None if all hooks pass, or (code, reason) tuple for rejection + """ + for hook in hooks: + try: + result = hook(websocket, *args) + if inspect.iscoroutine(result): + result = await result + if result is False: + return (4001, default_reason) + if isinstance(result, tuple) and len(result) == 2: + return result + except Exception: # pylint: disable=broad-exception-caught + return (4001, "Authentication error") + return None + def serve_websocket_callback(self, dash_app: "Dash"): """Set up the WebSocket endpoint for callback handling. @@ -737,6 +764,17 @@ async def websocket_handler(websocket: WebSocket): await websocket.close(code=4003, reason=error) return + # Call websocket_connect hooks (before accept) + # pylint: disable=protected-access + rejection = await self._run_ws_hooks( + dash_app._hooks.get_hooks("websocket_connect"), + websocket, + default_reason="Connection rejected", + ) + if rejection: + await websocket.close(code=rejection[0], reason=rejection[1]) + return + await websocket.accept() # Track pending get_props requests @@ -745,6 +783,18 @@ async def websocket_handler(websocket: WebSocket): try: while True: message = await websocket.receive_json() + + # Call websocket_message hooks + rejection = await self._run_ws_hooks( + dash_app._hooks.get_hooks("websocket_message"), + websocket, + message, + default_reason="Message rejected", + ) + if rejection: + await websocket.close(code=rejection[0], reason=rejection[1]) + return + msg_type = message.get("type") renderer_id = message.get("rendererId") From fed2311dc9c85118ae74edef2fd46d399cd299be Mon Sep 17 00:00:00 2001 From: philippe Date: Thu, 16 Apr 2026 11:03:15 -0400 Subject: [PATCH 127/166] Add websocket inactivity disconnect --- .../src/MessageRouter.ts | 45 +++++++++---- .../src/WebSocketManager.ts | 66 +++++++++++++++---- @plotly/dash-websocket-worker/src/types.ts | 1 + @plotly/dash-websocket-worker/src/worker.ts | 8 ++- dash/dash-renderer/src/config.ts | 1 + .../src/observers/websocketObserver.ts | 6 +- dash/dash-renderer/src/utils/workerClient.ts | 33 ++++++++-- dash/dash.py | 3 + 8 files changed, 130 insertions(+), 33 deletions(-) diff --git a/@plotly/dash-websocket-worker/src/MessageRouter.ts b/@plotly/dash-websocket-worker/src/MessageRouter.ts index 1082c3e6c1..68a9f4bfc2 100644 --- a/@plotly/dash-websocket-worker/src/MessageRouter.ts +++ b/@plotly/dash-websocket-worker/src/MessageRouter.ts @@ -97,8 +97,14 @@ export class MessageRouter { * @param message The message to broadcast */ public broadcastToRenderers(message: WorkerMessage): void { - for (const [, port] of this.renderers) { - port.postMessage(message); + for (const [rendererId, port] of this.renderers) { + try { + port.postMessage(message); + } catch (error) { + // Port may be closed if tab was closed + console.warn(`Failed to send to renderer ${rendererId}, removing`); + this.renderers.delete(rendererId); + } } } @@ -109,10 +115,15 @@ export class MessageRouter { public notifyConnected(rendererId: string): void { const port = this.renderers.get(rendererId); if (port) { - port.postMessage({ - type: WorkerMessageType.CONNECTED, - rendererId - }); + try { + port.postMessage({ + type: WorkerMessageType.CONNECTED, + rendererId + }); + } catch (error) { + console.warn(`Failed to notify renderer ${rendererId}, removing`); + this.renderers.delete(rendererId); + } } } @@ -137,11 +148,16 @@ export class MessageRouter { public notifyError(rendererId: string, message: string, code?: string): void { const port = this.renderers.get(rendererId); if (port) { - port.postMessage({ - type: WorkerMessageType.ERROR, - rendererId, - payload: { message, code } - }); + try { + port.postMessage({ + type: WorkerMessageType.ERROR, + rendererId, + payload: { message, code } + }); + } catch (error) { + console.warn(`Failed to send error to renderer ${rendererId}, removing`); + this.renderers.delete(rendererId); + } } } @@ -170,7 +186,12 @@ export class MessageRouter { private forwardToRenderer(rendererId: string, message: WorkerMessage): void { const port = this.renderers.get(rendererId); if (port) { - port.postMessage(message); + try { + port.postMessage(message); + } catch (error) { + console.warn(`Failed to forward to renderer ${rendererId}, removing`); + this.renderers.delete(rendererId); + } } else { console.warn(`Renderer ${rendererId} not found for message`); } diff --git a/@plotly/dash-websocket-worker/src/WebSocketManager.ts b/@plotly/dash-websocket-worker/src/WebSocketManager.ts index e0353e7e94..5f11086945 100644 --- a/@plotly/dash-websocket-worker/src/WebSocketManager.ts +++ b/@plotly/dash-websocket-worker/src/WebSocketManager.ts @@ -12,6 +12,8 @@ interface WebSocketConfig { heartbeatInterval: number; /** Heartbeat timeout (ms) */ heartbeatTimeout: number; + /** Inactivity timeout (ms) - 0 to disable */ + inactivityTimeout: number; } const DEFAULT_CONFIG: WebSocketConfig = { @@ -19,7 +21,8 @@ const DEFAULT_CONFIG: WebSocketConfig = { initialRetryDelay: 1000, maxRetryDelay: 30000, heartbeatInterval: 30000, - heartbeatTimeout: 10000 + heartbeatTimeout: 10000, + inactivityTimeout: 300000 // 5 minutes default }; /** @@ -33,6 +36,7 @@ export class WebSocketManager { private retryTimeout: ReturnType | null = null; private heartbeatInterval: ReturnType | null = null; private heartbeatTimeout: ReturnType | null = null; + private lastActivityTime: number = Date.now(); private messageQueue: string[] = []; private isConnecting = false; @@ -49,6 +53,15 @@ export class WebSocketManager { this.config = { ...DEFAULT_CONFIG, ...config }; } + /** + * Update configuration options. + * Only updates the provided options, keeping others unchanged. + * @param config Partial configuration to merge + */ + public setConfig(config: Partial): void { + this.config = { ...this.config, ...config }; + } + /** * Connect to the WebSocket server. * @param serverUrl The WebSocket server URL @@ -74,27 +87,39 @@ export class WebSocketManager { */ public disconnect(): void { this.cleanup(); - if (this.ws) { + if (this.ws && this.ws.readyState === WebSocket.OPEN) { this.ws.close(1000, 'Client disconnect'); - this.ws = null; } + this.ws = null; this.serverUrl = null; this.retryCount = 0; } /** * Send a message through the WebSocket connection. - * If not connected, queues the message for later delivery. + * If not connected, queues the message and triggers reconnection. * @param message The message to send */ public send(message: unknown): void { const data = JSON.stringify(message); + // Track activity for non-heartbeat messages + const msgObj = message as { type?: string }; + if (msgObj.type !== 'heartbeat') { + this.lastActivityTime = Date.now(); + } + if (this.ws && this.ws.readyState === WebSocket.OPEN) { this.ws.send(data); } else { // Queue message for when connection is established this.messageQueue.push(data); + + // Trigger reconnect if we have a server URL but aren't connected/connecting + if (this.serverUrl && !this.isConnecting) { + this.isConnecting = true; + this.createConnection(); + } } } @@ -125,6 +150,7 @@ export class WebSocketManager { private handleOpen(): void { this.isConnecting = false; this.retryCount = 0; + this.lastActivityTime = Date.now(); // Flush queued messages while (this.messageQueue.length > 0) { @@ -134,7 +160,7 @@ export class WebSocketManager { } } - // Start heartbeat + // Start heartbeat (also handles inactivity check) this.startHeartbeat(); if (this.onOpen) { @@ -152,8 +178,10 @@ export class WebSocketManager { this.onClose(reason); } - // Only reconnect if we haven't explicitly disconnected - if (this.serverUrl && event.code !== 1000) { + // Only reconnect if: + // - We haven't explicitly disconnected (code 1000) + // - It's not an inactivity timeout (code 4001) + if (this.serverUrl && event.code !== 1000 && event.code !== 4001) { this.scheduleReconnect(); } } @@ -162,12 +190,15 @@ export class WebSocketManager { try { const data = JSON.parse(event.data); - // Handle heartbeat acknowledgment + // Handle heartbeat acknowledgment - does NOT count as activity if (data.type === 'heartbeat_ack') { this.clearHeartbeatTimeout(); return; } + // Track activity for actual callback messages + this.lastActivityTime = Date.now(); + if (this.onMessage) { this.onMessage(data); } @@ -214,10 +245,21 @@ export class WebSocketManager { this.stopHeartbeat(); this.heartbeatInterval = setInterval(() => { - if (this.ws && this.ws.readyState === WebSocket.OPEN) { - this.ws.send(JSON.stringify({ type: 'heartbeat' })); - this.setHeartbeatTimeout(); + if (!this.ws || this.ws.readyState !== WebSocket.OPEN) { + return; } + + // Check for inactivity timeout + if (this.config.inactivityTimeout > 0) { + const timeSinceActivity = Date.now() - this.lastActivityTime; + if (timeSinceActivity >= this.config.inactivityTimeout) { + this.ws.close(4001, 'Inactivity timeout'); + return; + } + } + + this.ws.send(JSON.stringify({ type: 'heartbeat' })); + this.setHeartbeatTimeout(); }, this.config.heartbeatInterval); } @@ -234,7 +276,7 @@ export class WebSocketManager { this.heartbeatTimeout = setTimeout(() => { // Heartbeat timeout - connection may be dead - if (this.ws) { + if (this.ws && this.ws.readyState === WebSocket.OPEN) { this.ws.close(4000, 'Heartbeat timeout'); } }, this.config.heartbeatTimeout); diff --git a/@plotly/dash-websocket-worker/src/types.ts b/@plotly/dash-websocket-worker/src/types.ts index 36fadf03a0..fac282b5e1 100644 --- a/@plotly/dash-websocket-worker/src/types.ts +++ b/@plotly/dash-websocket-worker/src/types.ts @@ -34,6 +34,7 @@ export interface ConnectMessage extends WorkerMessage { type: WorkerMessageType.CONNECT; payload: { serverUrl: string; + inactivityTimeout?: number; }; } diff --git a/@plotly/dash-websocket-worker/src/worker.ts b/@plotly/dash-websocket-worker/src/worker.ts index ff84b4fa0f..0e68f0b09a 100644 --- a/@plotly/dash-websocket-worker/src/worker.ts +++ b/@plotly/dash-websocket-worker/src/worker.ts @@ -80,12 +80,18 @@ self.onconnect = (event: MessageEvent) => { const connectMsg = message as ConnectMessage; const rendererId = connectMsg.rendererId; const newServerUrl = connectMsg.payload.serverUrl; + const inactivityTimeout = connectMsg.payload.inactivityTimeout; // Register the renderer router.registerRenderer(rendererId, port); rendererIds.add(rendererId); - console.log(`[DashWSWorker] Renderer ${rendererId} connected`); + console.log(`[DashWSWorker] Renderer ${rendererId} connected, inactivityTimeout: ${inactivityTimeout}`); + + // Update inactivity timeout if provided + if (typeof inactivityTimeout === 'number') { + wsManager.setConfig({ inactivityTimeout }); + } // Connect to server if not already connected if (!wsManager.isConnected) { diff --git a/dash/dash-renderer/src/config.ts b/dash/dash-renderer/src/config.ts index caf49e348d..b9e68eae03 100644 --- a/dash/dash-renderer/src/config.ts +++ b/dash/dash-renderer/src/config.ts @@ -26,6 +26,7 @@ export type DashConfig = { enabled: boolean; url: string; worker_url: string; + inactivity_timeout?: number; }; }; diff --git a/dash/dash-renderer/src/observers/websocketObserver.ts b/dash/dash-renderer/src/observers/websocketObserver.ts index b86be84b4c..ff5b9d099b 100644 --- a/dash/dash-renderer/src/observers/websocketObserver.ts +++ b/dash/dash-renderer/src/observers/websocketObserver.ts @@ -124,7 +124,11 @@ export async function initializeWebSocket( try { // config.websocket is guaranteed to exist due to wsAvailable check above - await workerClient.connect(config.websocket!.worker_url, wsUrl); + await workerClient.connect( + config.websocket!.worker_url, + wsUrl, + config.websocket!.inactivity_timeout + ); } catch (error) { console.error('[Dash] Failed to connect to WebSocket worker:', error); } diff --git a/dash/dash-renderer/src/utils/workerClient.ts b/dash/dash-renderer/src/utils/workerClient.ts index d0f61d022a..6bc503b16b 100644 --- a/dash/dash-renderer/src/utils/workerClient.ts +++ b/dash/dash-renderer/src/utils/workerClient.ts @@ -80,8 +80,13 @@ class WorkerClient { * Initialize the worker connection. * @param workerUrl URL to the SharedWorker script * @param serverUrl WebSocket server URL + * @param inactivityTimeout Optional inactivity timeout in ms */ - public async connect(workerUrl: string, serverUrl: string): Promise { + public async connect( + workerUrl: string, + serverUrl: string, + inactivityTimeout?: number + ): Promise { if (this.worker) { // Already connected return; @@ -108,7 +113,8 @@ class WorkerClient { type: WorkerMessageType.CONNECT, rendererId: this.rendererId, payload: { - serverUrl + serverUrl, + inactivityTimeout } }); @@ -144,7 +150,11 @@ class WorkerClient { * @param config The Dash config with websocket settings */ public async ensureConnected(config: { - websocket?: {url?: string; worker_url?: string}; + websocket?: { + url?: string; + worker_url?: string; + inactivity_timeout?: number; + }; }): Promise { // Already connected if (this.isConnected) { @@ -172,7 +182,11 @@ class WorkerClient { const host = window.location.host; const wsUrl = `${wsProtocol}//${host}${config.websocket.url}`; - await this.connect(config.websocket.worker_url, wsUrl); + await this.connect( + config.websocket.worker_url, + wsUrl, + config.websocket.inactivity_timeout + ); } /** @@ -181,12 +195,12 @@ class WorkerClient { * @returns Promise that resolves with the callback response */ public async sendCallback(payload: unknown): Promise { - // Wait for connection if one is in progress + // Wait for initial connection if one is in progress if (this.connectionPromise && !this.isConnected) { await this.connectionPromise; } - if (!this.worker || !this.isConnected) { + if (!this.worker) { throw new Error('Worker not connected'); } @@ -316,7 +330,12 @@ export function isWebSocketEnabled(config: { * @param config The Dash config */ export function isWebSocketAvailable(config: { - websocket?: {enabled?: boolean; url?: string; worker_url?: string}; + websocket?: { + enabled?: boolean; + url?: string; + worker_url?: string; + inactivity_timeout?: number; + }; }): boolean { return !!( config.websocket?.url && diff --git a/dash/dash.py b/dash/dash.py index 2762d14a9b..9ae1adc1c4 100644 --- a/dash/dash.py +++ b/dash/dash.py @@ -474,6 +474,7 @@ def __init__( # pylint: disable=too-many-statements, too-many-branches health_endpoint: Optional[str] = None, websocket_callbacks: Optional[bool] = False, allowed_websocket_origins: Optional[List[str]] = None, + websocket_inactivity_timeout: Optional[int] = 300000, **obsolete, ): @@ -623,6 +624,7 @@ def __init__( # pylint: disable=too-many-statements, too-many-branches self._background_manager = background_callback_manager self._websocket_callbacks = websocket_callbacks self._allowed_websocket_origins = allowed_websocket_origins or [] + self._websocket_inactivity_timeout = websocket_inactivity_timeout self.logger = logging.getLogger(__name__) @@ -957,6 +959,7 @@ def _config(self): "enabled": bool(self._websocket_callbacks), "url": self.config.requests_pathname_prefix + "_dash-ws-callback", "worker_url": self._get_worker_url(), + "inactivity_timeout": self._websocket_inactivity_timeout, } return config From 7d01f533bebdca2930de236706a93016cc007eb0 Mon Sep 17 00:00:00 2001 From: philippe Date: Thu, 16 Apr 2026 12:01:58 -0400 Subject: [PATCH 128/166] no storing of rendererId in session storage --- dash/dash-renderer/src/utils/rendererId.ts | 25 +++++++++++----------- 1 file changed, 12 insertions(+), 13 deletions(-) diff --git a/dash/dash-renderer/src/utils/rendererId.ts b/dash/dash-renderer/src/utils/rendererId.ts index 8168d1576d..b9bfcfd3af 100644 --- a/dash/dash-renderer/src/utils/rendererId.ts +++ b/dash/dash-renderer/src/utils/rendererId.ts @@ -1,23 +1,22 @@ +/** Cached renderer ID for this page instance */ +let cachedRendererId: string | null = null; + /** - * Generate or retrieve a unique renderer ID for this browser tab/session. + * Generate a unique renderer ID for this page instance. * - * The ID is stored in sessionStorage to persist across page reloads - * but remain unique per tab. + * Each page load gets a fresh ID to avoid conflicts with stale + * connections in the SharedWorker after page reloads. */ export function getRendererId(): string { - const key = '__dash_renderer_id'; - let id = sessionStorage.getItem(key); - - if (!id) { - // Generate a unique ID + if (!cachedRendererId) { if (typeof crypto !== 'undefined' && crypto.randomUUID) { - id = crypto.randomUUID(); + cachedRendererId = crypto.randomUUID(); } else { // Fallback for older browsers - id = `${Date.now()}-${Math.random().toString(36).slice(2)}`; + cachedRendererId = `${Date.now()}-${Math.random() + .toString(36) + .slice(2)}`; } - sessionStorage.setItem(key, id); } - - return id; + return cachedRendererId; } From b91c5d53304a64f2a9f949014d6215d07421ad75 Mon Sep 17 00:00:00 2001 From: philippe Date: Thu, 16 Apr 2026 12:14:08 -0400 Subject: [PATCH 129/166] rename allowed_websocket_origins -> websocket_allowed_origins --- dash/backends/_fastapi.py | 2 +- dash/dash.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/dash/backends/_fastapi.py b/dash/backends/_fastapi.py index 1e2186c384..9099187cae 100644 --- a/dash/backends/_fastapi.py +++ b/dash/backends/_fastapi.py @@ -738,7 +738,7 @@ def serve_websocket_callback(self, dash_app: "Dash"): # Get allowed origins from dash app config allowed_origins = getattr( - dash_app, "_allowed_websocket_origins", [] + dash_app, "_websocket_allowed_origins", [] ) # pylint: disable=protected-access def validate_origin(origin: str | None, host: str | None) -> str | None: diff --git a/dash/dash.py b/dash/dash.py index 9ae1adc1c4..c4ff93ff36 100644 --- a/dash/dash.py +++ b/dash/dash.py @@ -473,7 +473,7 @@ def __init__( # pylint: disable=too-many-statements, too-many-branches use_async: Optional[bool] = None, health_endpoint: Optional[str] = None, websocket_callbacks: Optional[bool] = False, - allowed_websocket_origins: Optional[List[str]] = None, + websocket_allowed_origins: Optional[List[str]] = None, websocket_inactivity_timeout: Optional[int] = 300000, **obsolete, ): @@ -623,7 +623,7 @@ def __init__( # pylint: disable=too-many-statements, too-many-branches self._background_manager = background_callback_manager self._websocket_callbacks = websocket_callbacks - self._allowed_websocket_origins = allowed_websocket_origins or [] + self._websocket_allowed_origins = websocket_allowed_origins or [] self._websocket_inactivity_timeout = websocket_inactivity_timeout self.logger = logging.getLogger(__name__) From 96824a927a6c664d61a842350b7c2474b93378b5 Mon Sep 17 00:00:00 2001 From: philippe Date: Fri, 17 Apr 2026 09:24:29 -0400 Subject: [PATCH 130/166] update architecture docs with websocket details --- .ai/ARCHITECTURE.md | 176 ++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 176 insertions(+) diff --git a/.ai/ARCHITECTURE.md b/.ai/ARCHITECTURE.md index afff17394a..db40397782 100644 --- a/.ai/ARCHITECTURE.md +++ b/.ai/ARCHITECTURE.md @@ -723,6 +723,11 @@ Special handling for Colab: - `background_callback_manager` - DiskcacheManager or CeleryManager - `on_error` - Global callback error handler +**WebSocket Callbacks:** +- `websocket_callbacks` - Enable WebSocket for all callbacks (default: `False`). Requires FastAPI backend. +- `websocket_allowed_origins` - List of allowed origins for WebSocket connections +- `websocket_inactivity_timeout` - Disconnect WebSocket after inactivity period in ms (default: `300000` = 5 minutes). Set to `0` to disable. + ### app.run() Parameters - `host` - Server IP (default: `"127.0.0.1"`, env: `HOST`) @@ -861,6 +866,177 @@ async def async_background(n_clicks): Both DiskcacheManager and CeleryManager support async functions via `asyncio.run()`. +## WebSocket Callbacks + +WebSocket callbacks use a persistent WebSocket connection instead of HTTP POST for callback execution. This reduces latency and connection overhead for applications with frequent callbacks. + +### Requirements + +- **FastAPI backend required**: WebSocket callbacks only work with FastAPI +- **SharedWorker support**: Modern browsers (not IE) + +### Usage + +**Enable globally for all callbacks:** +```python +from fastapi import FastAPI +from dash import Dash + +server = FastAPI() +app = Dash(__name__, server=server, websocket_callbacks=True) +``` + +**Enable per-callback:** +```python +@app.callback( + Output('output', 'children'), + Input('input', 'value'), + websocket=True # Use WebSocket for this callback only +) +def update(value): + return f"Value: {value}" +``` + +### Configuration + +```python +app = Dash( + __name__, + server=server, + websocket_callbacks=True, + websocket_inactivity_timeout=300000, # 5 minutes (default) + websocket_allowed_origins=['https://example.com'], +) +``` + +- **`websocket_callbacks`** - Enable WebSocket for all callbacks (default: `False`) +- **`websocket_inactivity_timeout`** - Close WebSocket after period of inactivity in milliseconds (default: `300000` = 5 minutes). Heartbeats do not count as activity. Set to `0` to disable timeout. Connection automatically reconnects when needed. +- **`websocket_allowed_origins`** - List of allowed origins for WebSocket connections (security) + +### Architecture + +``` +┌─────────────────────────────────────────────────────────────────────────┐ +│ Browser Tab 1 Browser Tab 2 │ +│ ┌─────────────┐ ┌─────────────┐ │ +│ │ Renderer │ │ Renderer │ │ +│ └──────┬──────┘ └──────┬──────┘ │ +│ │ postMessage │ postMessage │ +│ └────────────┬───────────────────────┘ │ +│ ▼ │ +│ ┌─────────────────────┐ │ +│ │ SharedWorker │ (one per origin) │ +│ │ dash-ws-worker │ │ +│ └──────────┬──────────┘ │ +└────────────────────│────────────────────────────────────────────────────┘ + │ WebSocket + ▼ +┌─────────────────────────────────────────────────────────────────────────┐ +│ Server (FastAPI) │ +│ WebSocket Endpoint: /_dash-ws-callback │ +└─────────────────────────────────────────────────────────────────────────┘ +``` + +**Connection & Reconnection Flow:** +``` +Renderer SharedWorker Server + │ │ │ + │──[CONNECT]──────────────────>│ │ + │ │──[WebSocket Connect]──>│ + │<─[CONNECTED]─────────────────│<─[Connected]───────────│ + │ │ │ + │──[CALLBACK_REQUEST]─────────>│──[callback request]───>│ + │<─[CALLBACK_RESPONSE]─────────│<─[callback response]───│ + │ │ │ + │ (inactivity) │ (heartbeat check) │ + │ │──[close 4001]─────────>│ + │<─[DISCONNECTED]──────────────│ │ + │ │ │ + │──[CALLBACK_REQUEST]─────────>│──[reconnect + send]───>│ + │<─[CALLBACK_RESPONSE]─────────│<─[response]────────────│ +``` + +- **SharedWorker**: Single WebSocket connection shared across browser tabs +- **Heartbeat**: Periodic ping/pong to detect dead connections (30s interval) +- **Inactivity timeout**: Closes connection after no actual callback activity (not heartbeats) +- **Auto-reconnect**: Reconnects automatically when a callback is triggered after timeout + +### Long-Running Callbacks with set_props/get_props + +WebSocket callbacks can stream updates to the client during execution using `set_props()` and read current component values using `ctx.get_websocket()`: + +```python +import asyncio +from dash import callback, Output, Input, set_props, ctx + +@callback( + Output('result', 'children'), + Input('start-btn', 'n_clicks'), + prevent_initial_call=True +) +async def long_running_task(n_clicks): + ws = ctx.get_websocket() + if not ws: + return "WebSocket not available" + + # Stream progress updates to the client + for i in range(100): + await asyncio.sleep(0.1) + set_props('progress-bar', {'value': i + 1}) + set_props('status', {'children': f'Processing step {i + 1}/100...'}) + + # Read current value from another component + current_value = await ws.get_prop('input-field', 'value') + + return f"Completed! Input was: {current_value}" +``` + +**API:** +- `set_props(component_id, props_dict)` - Stream prop updates immediately to client +- `ctx.get_websocket()` - Get WebSocket interface (returns `None` if not in WS context) +- `await ws.get_prop(component_id, prop_name)` - Read current prop value from client +- `await ws.set_prop(component_id, prop_name, value)` - Set single prop (async version) +- `await ws.close(code, reason)` - Close the WebSocket connection + +### Connection Hooks + +Use hooks to validate connections and messages: + +```python +from dash import Dash, hooks + +@hooks.websocket_connect() +async def validate_connection(websocket): + """Validate WebSocket connection before accepting.""" + session_id = websocket.cookies.get("session_id") + if not session_id: + return (4001, "No session cookie") + if not await is_valid_session(session_id): + return (4002, "Invalid session") + return True # Allow connection + +@hooks.websocket_message() +async def validate_message(websocket, message): + """Validate each WebSocket message.""" + session_id = websocket.cookies.get("session_id") + if not await is_session_active(session_id): + return (4002, "Session expired") + return True # Allow message +``` + +**Hook Return Values:** +- `True` (or truthy) - Allow connection/message +- `False` - Reject with default code (4001) +- `(code, reason)` - Reject with custom close code and reason + +### Key Files + +- `dash/dash.py` - WebSocket config in `_generate_config()` +- `dash/dash-renderer/src/utils/workerClient.ts` - Browser-side SharedWorker client +- `@plotly/dash-websocket-worker/src/WebSocketManager.ts` - WebSocket connection management +- `@plotly/dash-websocket-worker/src/worker.ts` - SharedWorker entry point +- `dash/backends/_fastapi.py` - Server-side WebSocket handler + ## Security ### XSS Protection From 3a14ef33297f39cd81b96d3ad92f25f63d3df021 Mon Sep 17 00:00:00 2001 From: philippe Date: Thu, 23 Apr 2026 13:51:40 -0400 Subject: [PATCH 131/166] add websocket callback validation and tests --- .github/workflows/testing.yml | 115 +++++++---- dash/_callback.py | 1 + dash/_validate.py | 30 +++ dash/backends/_fastapi.py | 9 + dash/exceptions.py | 4 + tests/websocket/__init__.py | 1 + tests/websocket/conftest.py | 12 ++ tests/websocket/test_ws_basic.py | 262 +++++++++++++++++++++++++ tests/websocket/test_ws_hooks.py | 267 ++++++++++++++++++++++++++ tests/websocket/test_ws_inactivity.py | 194 +++++++++++++++++++ tests/websocket/test_ws_origin.py | 154 +++++++++++++++ tests/websocket/test_ws_props.py | 267 ++++++++++++++++++++++++++ tests/websocket/test_ws_validate.py | 58 ++++++ wsapp.py | 1 + 14 files changed, 1339 insertions(+), 36 deletions(-) create mode 100644 tests/websocket/__init__.py create mode 100644 tests/websocket/conftest.py create mode 100644 tests/websocket/test_ws_basic.py create mode 100644 tests/websocket/test_ws_hooks.py create mode 100644 tests/websocket/test_ws_inactivity.py create mode 100644 tests/websocket/test_ws_origin.py create mode 100644 tests/websocket/test_ws_props.py create mode 100644 tests/websocket/test_ws_validate.py diff --git a/.github/workflows/testing.yml b/.github/workflows/testing.yml index fc518b1aaa..b755510fb6 100644 --- a/.github/workflows/testing.yml +++ b/.github/workflows/testing.yml @@ -19,6 +19,7 @@ jobs: backend_cb_changed: ${{ steps.filter.outputs.backend_paths }} dcc_paths_changed: ${{ steps.filter.outputs.dcc_related_paths }} html_paths_changed: ${{ steps.filter.outputs.html_related_paths }} + websocket_changed: ${{ steps.filter.outputs.websocket_paths }} steps: - name: Checkout repository uses: actions/checkout@v4 @@ -48,6 +49,16 @@ jobs: backend_paths: - 'dash/backends/**' - 'tests/backend_tests/**' + websocket_paths: + - 'dash/backends/_fastapi.py' + - 'dash/_callback.py' + - 'dash/_callback_context.py' + - 'dash/_hooks.py' + - 'dash/dash.py' + - '@dash-websocket-worker/**' + - 'dash/dash-renderer/src/**' + - 'tests/websocket/**' + - 'requirements/**' lint-unit: name: Lint & Unit Tests (Python ${{ matrix.python-version }}) @@ -366,7 +377,7 @@ jobs: - name: Set up Node.js uses: actions/setup-node@v4 with: - node-version: '20' + node-version: '24' cache: 'npm' - name: Install Node.js dependencies @@ -377,6 +388,7 @@ jobs: with: python-version: ${{ matrix.python-version }} cache: 'pip' + cache-dependency-path: requirements/*.txt - name: Download built Dash packages uses: actions/download-artifact@v4 @@ -387,43 +399,13 @@ jobs: - name: Install Dash packages run: | python -m pip install --upgrade pip wheel - python -m pip install "setuptools<78.0.0" - python -m pip install "selenium==4.32.0" + python -m pip install "setuptools<80.0.0" find packages -name dash-*.whl -print -exec sh -c 'pip install "{}[async,ci,testing,dev,celery,diskcache,fastapi,quart]"' \; - - name: Install Google Chrome - run: | - sudo apt-get update - sudo apt-get install -y google-chrome-stable - - - name: Install ChromeDriver - run: | - echo "Determining Chrome version..." - CHROME_BROWSER_VERSION=$(google-chrome --version) - echo "Installed Chrome Browser version: $CHROME_BROWSER_VERSION" - CHROME_MAJOR_VERSION=$(echo "$CHROME_BROWSER_VERSION" | cut -f 3 -d ' ' | cut -f 1 -d '.') - echo "Detected Chrome Major version: $CHROME_MAJOR_VERSION" - if [ "$CHROME_MAJOR_VERSION" -ge 115 ]; then - echo "Fetching ChromeDriver version for Chrome $CHROME_MAJOR_VERSION using CfT endpoint..." - CHROMEDRIVER_VERSION_STRING=$(curl -sS "https://googlechromelabs.github.io/chrome-for-testing/LATEST_RELEASE_${CHROME_MAJOR_VERSION}") - if [ -z "$CHROMEDRIVER_VERSION_STRING" ]; then - echo "Could not automatically find ChromeDriver version for Chrome $CHROME_MAJOR_VERSION via LATEST_RELEASE. Please check CfT endpoints." - exit 1 - fi - CHROMEDRIVER_URL="https://edgedl.me.gvt1.com/edgedl/chrome/chrome-for-testing/${CHROMEDRIVER_VERSION_STRING}/linux64/chromedriver-linux64.zip" - else - echo "Fetching ChromeDriver version for Chrome $CHROME_MAJOR_VERSION using older method..." - CHROMEDRIVER_VERSION_STRING=$(curl -sS "https://chromedriver.storage.googleapis.com/LATEST_RELEASE_${CHROME_MAJOR_VERSION}") - CHROMEDRIVER_URL="https://chromedriver.storage.googleapis.com/${CHROMEDRIVER_VERSION_STRING}/chromedriver_linux64.zip" - fi - echo "Using ChromeDriver version string: $CHROMEDRIVER_VERSION_STRING" - echo "Downloading ChromeDriver from: $CHROMEDRIVER_URL" - wget -q -O chromedriver.zip "$CHROMEDRIVER_URL" - unzip -o chromedriver.zip -d /tmp/ - sudo mv /tmp/chromedriver-linux64/chromedriver /usr/local/bin/chromedriver || sudo mv /tmp/chromedriver /usr/local/bin/chromedriver - sudo chmod +x /usr/local/bin/chromedriver - echo "/usr/local/bin" >> $GITHUB_PATH - shell: bash + - name: Setup Chrome and ChromeDriver + uses: browser-actions/setup-chrome@v1 + with: + chrome-version: stable - name: Build/Setup test components run: npm run setup-tests.py @@ -558,6 +540,67 @@ jobs: path: components/dash-table/test-reports/ retention-days: 7 + websocket-tests: + name: WebSocket Tests (Python ${{ matrix.python-version }}) + needs: [build, changes_filter] + if: | + (github.event_name == 'push' && (github.ref == 'refs/heads/master' || github.ref == 'refs/heads/dev')) || + needs.changes_filter.outputs.websocket_changed == 'true' + timeout-minutes: 30 + runs-on: ubuntu-latest + strategy: + fail-fast: false + matrix: + python-version: ["3.9", "3.12"] + + steps: + - name: Checkout repository + uses: actions/checkout@v4 + + - name: Set up Node.js + uses: actions/setup-node@v4 + with: + node-version: '24' + cache: 'npm' + + - name: Install Node.js dependencies + run: npm ci + + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@v5 + with: + python-version: ${{ matrix.python-version }} + cache: 'pip' + cache-dependency-path: requirements/*.txt + + - name: Download built Dash packages + uses: actions/download-artifact@v4 + with: + name: dash-packages + path: packages/ + + - name: Install Dash packages + run: | + python -m pip install --upgrade pip wheel + python -m pip install "setuptools<80.0.0" + find packages -name dash-*.whl -print -exec sh -c 'pip install "{}[ci,testing,dev,fastapi]"' \; + + - name: Setup Chrome and ChromeDriver + uses: browser-actions/setup-chrome@v1 + with: + chrome-version: stable + + - name: Build/Setup test components + run: npm run setup-tests.py + + - name: Run WebSocket tests + run: | + mkdir wstests + cp -r tests wstests/tests + cd wstests + touch __init__.py + pytest --headless --nopercyfinalize tests/websocket -v -s + test-main: name: Main Dash Tests (Python ${{ matrix.python-version }}, Group ${{ matrix.test-group }}) needs: build diff --git a/dash/_callback.py b/dash/_callback.py index b27e9a18b9..718a016d82 100644 --- a/dash/_callback.py +++ b/dash/_callback.py @@ -319,6 +319,7 @@ def insert_callback( "manager": manager, "allow_dynamic_callbacks": dynamic_creator, "no_output": no_output, + "websocket": websocket, } callback_list.append(callback_spec) diff --git a/dash/_validate.py b/dash/_validate.py index fb5689f850..b80c61df2c 100644 --- a/dash/_validate.py +++ b/dash/_validate.py @@ -629,3 +629,33 @@ def check_backend(backend, inferred_backend): raise ValueError( f"Conflict between provided backend '{backend}' and server type '{inferred_backend}'." ) + + +def validate_websocket_callback_request( + callback_id, callback_map, websocket_callbacks_enabled +): + """Validate a WebSocket callback request at runtime. + + Called by WebSocket handlers to verify that a callback received via WebSocket + is actually allowed to use WebSocket transport. + + Args: + callback_id: The callback output ID from the request + callback_map: The app's callback_map dictionary + websocket_callbacks_enabled: Whether websocket_callbacks=True at app level + + Raises: + WebSocketCallbackError: If the callback is not websocket-enabled + """ + # If global websocket_callbacks is enabled, all callbacks can use WebSocket + if websocket_callbacks_enabled: + return + + # Otherwise, check if this specific callback has websocket=True + cb = callback_map.get(callback_id, {}) + if not cb.get("websocket"): + raise exceptions.WebSocketCallbackError( + f"Callback '{callback_id}' received via WebSocket but does not have " + f"websocket=True. Either enable websocket_callbacks=True globally " + f"or add websocket=True to this callback." + ) diff --git a/dash/backends/_fastapi.py b/dash/backends/_fastapi.py index 9099187cae..2ed642d546 100644 --- a/dash/backends/_fastapi.py +++ b/dash/backends/_fastapi.py @@ -852,6 +852,15 @@ async def _execute_ws_callback( """ payload = message.get("payload", {}) + # Validate that the callback is allowed to use WebSocket transport + # pylint: disable=protected-access + _validate.validate_websocket_callback_request( + payload.get("output"), + dash_app.callback_map, + dash_app._websocket_callbacks, + ) + # pylint: enable=protected-access + # Create WebSocket callback context cb_ctx = self._create_ws_context( dash_app, websocket, payload, pending_get_props diff --git a/dash/exceptions.py b/dash/exceptions.py index 019f0d2726..40e882c409 100644 --- a/dash/exceptions.py +++ b/dash/exceptions.py @@ -113,3 +113,7 @@ class HookError(DashException): class AppNotFoundError(DashException): pass + + +class WebSocketCallbackError(CallbackException): + pass diff --git a/tests/websocket/__init__.py b/tests/websocket/__init__.py new file mode 100644 index 0000000000..1116026afc --- /dev/null +++ b/tests/websocket/__init__.py @@ -0,0 +1 @@ +# WebSocket callback tests diff --git a/tests/websocket/conftest.py b/tests/websocket/conftest.py new file mode 100644 index 0000000000..d72fcd04dc --- /dev/null +++ b/tests/websocket/conftest.py @@ -0,0 +1,12 @@ +import pytest +from dash import hooks + + +@pytest.fixture +def ws_hook_cleanup(): + """Clean up WebSocket hooks after each test.""" + yield + hooks._ns["websocket_connect"] = [] + hooks._ns["websocket_message"] = [] + hooks._finals.pop("websocket_connect", None) + hooks._finals.pop("websocket_message", None) diff --git a/tests/websocket/test_ws_basic.py b/tests/websocket/test_ws_basic.py new file mode 100644 index 0000000000..80a2e7d975 --- /dev/null +++ b/tests/websocket/test_ws_basic.py @@ -0,0 +1,262 @@ +""" +Basic WebSocket callback tests. + +Tests: +- Per-callback websocket (websocket=True) +- Global websocket callbacks (websocket_callbacks=True) +- Mixed HTTP and WebSocket callbacks +""" + +from dash import Dash, html, dcc, Input, Output, State, ctx + + +def test_ws001_per_callback_websocket(dash_duo): + """Test single callback with websocket=True on FastAPI backend.""" + app = Dash(__name__, backend="fastapi") + + app.layout = html.Div( + [ + html.H1("Per-Callback WebSocket Test"), + dcc.Input(id="ws-input", type="text", placeholder="Type here..."), + html.Div(id="ws-output"), + ] + ) + + @app.callback( + Output("ws-output", "children"), Input("ws-input", "value"), websocket=True + ) + def ws_callback(value): + return f"WS: {value or ''}" + + dash_duo.start_server(app) + + # Test initial state + dash_duo.wait_for_text_to_equal("#ws-output", "WS: ") + + # Type into the input and verify callback executes + input_elem = dash_duo.find_element("#ws-input") + input_elem.send_keys("hello") + + dash_duo.wait_for_text_to_equal("#ws-output", "WS: hello") + assert dash_duo.get_logs() == [] + + +def test_ws002_global_websocket_callbacks(dash_duo): + """Test global websocket_callbacks=True enables WebSocket for all callbacks.""" + app = Dash( + __name__, + backend="fastapi", + websocket_callbacks=True, + ) + + app.layout = html.Div( + [ + html.Button("Click me", id="btn", n_clicks=0), + html.Div(id="output"), + dcc.Input(id="input", type="text"), + html.Div(id="input-output"), + ] + ) + + @app.callback(Output("output", "children"), Input("btn", "n_clicks")) + def on_click(n_clicks): + return f"Clicked {n_clicks} times" + + @app.callback(Output("input-output", "children"), Input("input", "value")) + def on_input(value): + return f"Input: {value or ''}" + + dash_duo.start_server(app) + + # Test button callback + dash_duo.wait_for_text_to_equal("#output", "Clicked 0 times") + dash_duo.find_element("#btn").click() + dash_duo.wait_for_text_to_equal("#output", "Clicked 1 times") + + # Test input callback + dash_duo.find_element("#input").send_keys("test") + dash_duo.wait_for_text_to_equal("#input-output", "Input: test") + + assert dash_duo.get_logs() == [] + + +def test_ws003_mixed_http_and_websocket(dash_duo): + """Test mixing WebSocket and HTTP callbacks in the same app.""" + app = Dash(__name__, backend="fastapi") + + app.layout = html.Div( + [ + # WebSocket callback section + html.Div( + [ + dcc.Input(id="ws-input", type="text"), + html.Div(id="ws-output"), + ] + ), + # HTTP callback section (default) + html.Div( + [ + dcc.Input(id="http-input", type="text"), + html.Div(id="http-output"), + ] + ), + ] + ) + + @app.callback( + Output("ws-output", "children"), Input("ws-input", "value"), websocket=True + ) + def ws_callback(value): + return f"[WebSocket] {value or ''}" + + @app.callback(Output("http-output", "children"), Input("http-input", "value")) + def http_callback(value): + return f"[HTTP] {value or ''}" + + dash_duo.start_server(app) + + # Test WebSocket callback + dash_duo.find_element("#ws-input").send_keys("ws-test") + dash_duo.wait_for_text_to_equal("#ws-output", "[WebSocket] ws-test") + + # Test HTTP callback + dash_duo.find_element("#http-input").send_keys("http-test") + dash_duo.wait_for_text_to_equal("#http-output", "[HTTP] http-test") + + assert dash_duo.get_logs() == [] + + +def test_ws004_websocket_with_state(dash_duo): + """Test WebSocket callback with State inputs.""" + app = Dash(__name__, backend="fastapi", websocket_callbacks=True) + + app.layout = html.Div( + [ + dcc.Input(id="state-input", type="text", value="initial"), + html.Button("Submit", id="btn"), + html.Div(id="output"), + ] + ) + + @app.callback( + Output("output", "children"), + Input("btn", "n_clicks"), + State("state-input", "value"), + ) + def on_click(n_clicks, state_value): + if not n_clicks: + return "Click to submit" + return f"Submitted: {state_value}" + + dash_duo.start_server(app) + + dash_duo.wait_for_text_to_equal("#output", "Click to submit") + + # Update state input + state_input = dash_duo.find_element("#state-input") + dash_duo.clear_input(state_input) + state_input.send_keys("new value") + + # Click button to trigger callback + dash_duo.find_element("#btn").click() + dash_duo.wait_for_text_to_equal("#output", "Submitted: new value") + + assert dash_duo.get_logs() == [] + + +def test_ws005_websocket_context_available(dash_duo): + """Test that WebSocket context is available in WebSocket callbacks.""" + app = Dash(__name__, backend="fastapi", websocket_callbacks=True) + + app.layout = html.Div( + [ + html.Button("Check context", id="btn"), + html.Div(id="output"), + ] + ) + + @app.callback(Output("output", "children"), Input("btn", "n_clicks")) + def check_context(n_clicks): + if not n_clicks: + return "Click to check" + ws = ctx.get_websocket + if ws is not None: + return "WebSocket context available" + return "No WebSocket context" + + dash_duo.start_server(app) + + dash_duo.wait_for_text_to_equal("#output", "Click to check") + dash_duo.find_element("#btn").click() + dash_duo.wait_for_text_to_equal("#output", "WebSocket context available") + + assert dash_duo.get_logs() == [] + + +def test_ws006_websocket_multiple_outputs(dash_duo): + """Test WebSocket callback with multiple outputs.""" + app = Dash(__name__, backend="fastapi", websocket_callbacks=True) + + app.layout = html.Div( + [ + html.Button("Update", id="btn"), + html.Div(id="output1"), + html.Div(id="output2"), + html.Div(id="output3"), + ] + ) + + @app.callback( + Output("output1", "children"), + Output("output2", "children"), + Output("output3", "children"), + Input("btn", "n_clicks"), + ) + def multi_output(n_clicks): + n = n_clicks or 0 + return f"First: {n}", f"Second: {n * 2}", f"Third: {n * 3}" + + dash_duo.start_server(app) + + dash_duo.wait_for_text_to_equal("#output1", "First: 0") + dash_duo.wait_for_text_to_equal("#output2", "Second: 0") + dash_duo.wait_for_text_to_equal("#output3", "Third: 0") + + dash_duo.find_element("#btn").click() + + dash_duo.wait_for_text_to_equal("#output1", "First: 1") + dash_duo.wait_for_text_to_equal("#output2", "Second: 2") + dash_duo.wait_for_text_to_equal("#output3", "Third: 3") + + assert dash_duo.get_logs() == [] + + +def test_ws007_websocket_slider_callback(dash_duo): + """Test WebSocket callback with dcc.Slider component.""" + app = Dash(__name__, backend="fastapi", websocket_callbacks=True) + + app.layout = html.Div( + [ + dcc.Slider(id="slider", min=0, max=100, value=50, step=10), + html.Div(id="output"), + ] + ) + + @app.callback(Output("output", "children"), Input("slider", "value")) + def update_output(value): + return f"Slider value: {value}" + + dash_duo.start_server(app) + + dash_duo.wait_for_text_to_equal("#output", "Slider value: 50") + + # Move slider - find slider handle and drag it + slider = dash_duo.find_element("#slider .rc-slider-handle") + dash_duo.driver.execute_script( + "arguments[0].dispatchEvent(new MouseEvent('mousedown', {bubbles: true}));" + "document.dispatchEvent(new MouseEvent('mouseup', {bubbles: true}));", + slider, + ) + + # The callback should still work + assert dash_duo.get_logs() == [] diff --git a/tests/websocket/test_ws_hooks.py b/tests/websocket/test_ws_hooks.py new file mode 100644 index 0000000000..f37a8941cc --- /dev/null +++ b/tests/websocket/test_ws_hooks.py @@ -0,0 +1,267 @@ +""" +WebSocket hooks tests. + +Tests: +- websocket_connect hook - accept/reject connections +- websocket_message hook - accept/reject messages +- Custom close codes and reasons +""" + +from dash import Dash, html, Input, Output, hooks + + +def test_ws010_connect_hook_accept(dash_duo, ws_hook_cleanup): + """Test websocket_connect hook that accepts all connections.""" + connection_count = {"value": 0} + + @hooks.websocket_connect() + def allow_all(websocket): + connection_count["value"] += 1 + return True + + app = Dash(__name__, backend="fastapi", websocket_callbacks=True) + + app.layout = html.Div( + [ + html.Button("Click", id="btn"), + html.Div(id="output"), + ] + ) + + @app.callback(Output("output", "children"), Input("btn", "n_clicks")) + def on_click(n_clicks): + return f"Clicked {n_clicks or 0}" + + dash_duo.start_server(app) + + dash_duo.wait_for_text_to_equal("#output", "Clicked 0") + dash_duo.find_element("#btn").click() + dash_duo.wait_for_text_to_equal("#output", "Clicked 1") + + # Hook should have been called at least once for connection + assert connection_count["value"] >= 1 + assert dash_duo.get_logs() == [] + + +def test_ws011_connect_hook_reject_false(dash_duo, ws_hook_cleanup): + """Test websocket_connect hook that rejects with False.""" + + @hooks.websocket_connect() + def reject_all(websocket): + return False + + app = Dash(__name__, backend="fastapi", websocket_callbacks=True) + + app.layout = html.Div( + [ + html.Button("Click", id="btn"), + html.Div("initial", id="output"), + ] + ) + + @app.callback(Output("output", "children"), Input("btn", "n_clicks")) + def on_click(n_clicks): + return f"Clicked {n_clicks or 0}" + + dash_duo.start_server(app) + + # Initial callback should still work via HTTP fallback + dash_duo.wait_for_text_to_equal("#output", "Clicked 0") + dash_duo.find_element("#btn").click() + # Should still get updates via HTTP fallback + dash_duo.wait_for_text_to_equal("#output", "Clicked 1") + + +def test_ws012_connect_hook_reject_tuple(dash_duo, ws_hook_cleanup): + """Test websocket_connect hook that rejects with custom code/reason.""" + + @hooks.websocket_connect() + def reject_with_reason(websocket): + return (4001, "Connection not allowed") + + app = Dash(__name__, backend="fastapi", websocket_callbacks=True) + + app.layout = html.Div( + [ + html.Button("Click", id="btn"), + html.Div("initial", id="output"), + ] + ) + + @app.callback(Output("output", "children"), Input("btn", "n_clicks")) + def on_click(n_clicks): + return f"Clicked {n_clicks or 0}" + + dash_duo.start_server(app) + + # Callbacks should still work via HTTP fallback + dash_duo.wait_for_text_to_equal("#output", "Clicked 0") + dash_duo.find_element("#btn").click() + dash_duo.wait_for_text_to_equal("#output", "Clicked 1") + + +def test_ws013_message_hook_accept(dash_duo, ws_hook_cleanup): + """Test websocket_message hook that accepts all messages.""" + message_count = {"value": 0} + + @hooks.websocket_message() + def allow_all_messages(websocket, message): + message_count["value"] += 1 + return True + + app = Dash(__name__, backend="fastapi", websocket_callbacks=True) + + app.layout = html.Div( + [ + html.Button("Click", id="btn"), + html.Div(id="output"), + ] + ) + + @app.callback(Output("output", "children"), Input("btn", "n_clicks")) + def on_click(n_clicks): + return f"Clicked {n_clicks or 0}" + + dash_duo.start_server(app) + + dash_duo.wait_for_text_to_equal("#output", "Clicked 0") + dash_duo.find_element("#btn").click() + dash_duo.wait_for_text_to_equal("#output", "Clicked 1") + + # Message hook should have been called + assert message_count["value"] >= 1 + assert dash_duo.get_logs() == [] + + +def test_ws014_message_hook_reject(dash_duo, ws_hook_cleanup): + """Test websocket_message hook that rejects specific messages.""" + reject_clicks = {"should_reject": False} + + @hooks.websocket_message() + def conditional_reject(websocket, message): + if reject_clicks["should_reject"]: + return (4010, "Message rejected") + return True + + app = Dash(__name__, backend="fastapi", websocket_callbacks=True) + + app.layout = html.Div( + [ + html.Button("Click", id="btn"), + html.Div(id="output"), + ] + ) + + @app.callback(Output("output", "children"), Input("btn", "n_clicks")) + def on_click(n_clicks): + return f"Clicked {n_clicks or 0}" + + dash_duo.start_server(app) + + # First click should work + dash_duo.wait_for_text_to_equal("#output", "Clicked 0") + dash_duo.find_element("#btn").click() + dash_duo.wait_for_text_to_equal("#output", "Clicked 1") + + assert dash_duo.get_logs() == [] + + +def test_ws015_async_connect_hook(dash_duo, ws_hook_cleanup): + """Test async websocket_connect hook.""" + import asyncio + + @hooks.websocket_connect() + async def async_validate(websocket): + await asyncio.sleep(0.01) # Simulate async validation + return True + + app = Dash(__name__, backend="fastapi", websocket_callbacks=True) + + app.layout = html.Div( + [ + html.Button("Click", id="btn"), + html.Div(id="output"), + ] + ) + + @app.callback(Output("output", "children"), Input("btn", "n_clicks")) + def on_click(n_clicks): + return f"Clicked {n_clicks or 0}" + + dash_duo.start_server(app) + + dash_duo.wait_for_text_to_equal("#output", "Clicked 0") + dash_duo.find_element("#btn").click() + dash_duo.wait_for_text_to_equal("#output", "Clicked 1") + + assert dash_duo.get_logs() == [] + + +def test_ws016_async_message_hook(dash_duo, ws_hook_cleanup): + """Test async websocket_message hook.""" + import asyncio + + @hooks.websocket_message() + async def async_validate_message(websocket, message): + await asyncio.sleep(0.01) # Simulate async validation + return True + + app = Dash(__name__, backend="fastapi", websocket_callbacks=True) + + app.layout = html.Div( + [ + html.Button("Click", id="btn"), + html.Div(id="output"), + ] + ) + + @app.callback(Output("output", "children"), Input("btn", "n_clicks")) + def on_click(n_clicks): + return f"Clicked {n_clicks or 0}" + + dash_duo.start_server(app) + + dash_duo.wait_for_text_to_equal("#output", "Clicked 0") + dash_duo.find_element("#btn").click() + dash_duo.wait_for_text_to_equal("#output", "Clicked 1") + + assert dash_duo.get_logs() == [] + + +def test_ws017_multiple_connect_hooks(dash_duo, ws_hook_cleanup): + """Test multiple websocket_connect hooks with priorities.""" + hook_order = [] + + @hooks.websocket_connect(priority=1) + def first_hook(websocket): + hook_order.append("first") + return True + + @hooks.websocket_connect(priority=2) + def second_hook(websocket): + hook_order.append("second") + return True + + app = Dash(__name__, backend="fastapi", websocket_callbacks=True) + + app.layout = html.Div( + [ + html.Button("Click", id="btn"), + html.Div(id="output"), + ] + ) + + @app.callback(Output("output", "children"), Input("btn", "n_clicks")) + def on_click(n_clicks): + return f"Clicked {n_clicks or 0}" + + dash_duo.start_server(app) + + dash_duo.wait_for_text_to_equal("#output", "Clicked 0") + dash_duo.find_element("#btn").click() + dash_duo.wait_for_text_to_equal("#output", "Clicked 1") + + # Both hooks should have been called + assert "first" in hook_order + assert "second" in hook_order + assert dash_duo.get_logs() == [] diff --git a/tests/websocket/test_ws_inactivity.py b/tests/websocket/test_ws_inactivity.py new file mode 100644 index 0000000000..8dd95e1094 --- /dev/null +++ b/tests/websocket/test_ws_inactivity.py @@ -0,0 +1,194 @@ +""" +WebSocket inactivity timeout tests. + +Tests: +- Connection closes after inactivity period +- Activity resets the timer +- Heartbeats don't count as activity +- Auto-reconnect when callback fires after timeout +""" + +import time +from dash import Dash, html, Input, Output + + +def test_ws020_inactivity_timeout_closes(dash_duo): + """Test that WebSocket connection closes after inactivity timeout.""" + app = Dash( + __name__, + backend="fastapi", + websocket_callbacks=True, + websocket_inactivity_timeout=3000, # 3 seconds for testing + ) + + app.layout = html.Div( + [ + html.Button("Click", id="btn"), + html.Div("initial", id="output"), + ] + ) + + @app.callback(Output("output", "children"), Input("btn", "n_clicks")) + def on_click(n_clicks): + return f"Clicked {n_clicks or 0}" + + dash_duo.start_server(app) + + # Trigger callback to establish connection + dash_duo.wait_for_text_to_equal("#output", "Clicked 0") + dash_duo.find_element("#btn").click() + dash_duo.wait_for_text_to_equal("#output", "Clicked 1") + + # Wait for inactivity timeout + time.sleep(4) + + # Click again - should auto-reconnect and work + dash_duo.find_element("#btn").click() + dash_duo.wait_for_text_to_equal("#output", "Clicked 2") + + +def test_ws021_activity_resets_timer(dash_duo): + """Test that callback activity resets the inactivity timer.""" + app = Dash( + __name__, + backend="fastapi", + websocket_callbacks=True, + websocket_inactivity_timeout=4000, # 4 seconds + ) + + app.layout = html.Div( + [ + html.Button("Click", id="btn"), + html.Div("initial", id="output"), + ] + ) + + @app.callback(Output("output", "children"), Input("btn", "n_clicks")) + def on_click(n_clicks): + return f"Clicked {n_clicks or 0}" + + dash_duo.start_server(app) + + dash_duo.wait_for_text_to_equal("#output", "Clicked 0") + + # Click every 2 seconds - should keep connection alive + for i in range(1, 4): + time.sleep(2) + dash_duo.find_element("#btn").click() + dash_duo.wait_for_text_to_equal("#output", f"Clicked {i}") + + # All clicks should work without disconnection + assert dash_duo.get_logs() == [] + + +def test_ws022_quick_successive_callbacks(dash_duo): + """Test rapid successive callbacks work correctly.""" + app = Dash( + __name__, + backend="fastapi", + websocket_callbacks=True, + websocket_inactivity_timeout=5000, + ) + + app.layout = html.Div( + [ + html.Button("Click", id="btn"), + html.Div("0", id="output"), + ] + ) + + @app.callback(Output("output", "children"), Input("btn", "n_clicks")) + def on_click(n_clicks): + return str(n_clicks or 0) + + dash_duo.start_server(app) + + dash_duo.wait_for_text_to_equal("#output", "0") + + # Rapid clicks + for _ in range(5): + dash_duo.find_element("#btn").click() + time.sleep(0.1) + + dash_duo.wait_for_text_to_equal("#output", "5") + assert dash_duo.get_logs() == [] + + +def test_ws023_auto_reconnect_after_timeout(dash_duo): + """Test auto-reconnect when callback fires after inactivity timeout.""" + app = Dash( + __name__, + backend="fastapi", + websocket_callbacks=True, + websocket_inactivity_timeout=2000, # 2 seconds + ) + + app.layout = html.Div( + [ + html.Button("Click", id="btn"), + html.Div("initial", id="output"), + ] + ) + + @app.callback(Output("output", "children"), Input("btn", "n_clicks")) + def on_click(n_clicks): + return f"Clicked {n_clicks or 0}" + + dash_duo.start_server(app) + + # Initial callback + dash_duo.wait_for_text_to_equal("#output", "Clicked 0") + dash_duo.find_element("#btn").click() + dash_duo.wait_for_text_to_equal("#output", "Clicked 1") + + # Wait for timeout to expire + time.sleep(3) + + # Click again - should auto-reconnect + dash_duo.find_element("#btn").click() + dash_duo.wait_for_text_to_equal("#output", "Clicked 2") + + # And keep working + dash_duo.find_element("#btn").click() + dash_duo.wait_for_text_to_equal("#output", "Clicked 3") + + assert dash_duo.get_logs() == [] + + +def test_ws024_long_callback_doesnt_timeout(dash_duo): + """Test that long-running callbacks don't cause timeout during execution.""" + import asyncio + + app = Dash( + __name__, + backend="fastapi", + websocket_callbacks=True, + websocket_inactivity_timeout=3000, # 3 seconds + ) + + app.layout = html.Div( + [ + html.Button("Start Long Task", id="btn"), + html.Div("ready", id="output"), + ] + ) + + @app.callback(Output("output", "children"), Input("btn", "n_clicks")) + async def long_task(n_clicks): + if not n_clicks: + return "ready" + # Simulate long task (longer than inactivity timeout) + await asyncio.sleep(2) + return f"Completed task {n_clicks}" + + dash_duo.start_server(app) + + dash_duo.wait_for_text_to_equal("#output", "ready") + + # Start long task + dash_duo.find_element("#btn").click() + + # Should complete despite being longer than half the timeout + dash_duo.wait_for_text_to_equal("#output", "Completed task 1", timeout=10) + + assert dash_duo.get_logs() == [] diff --git a/tests/websocket/test_ws_origin.py b/tests/websocket/test_ws_origin.py new file mode 100644 index 0000000000..c6235613a5 --- /dev/null +++ b/tests/websocket/test_ws_origin.py @@ -0,0 +1,154 @@ +""" +WebSocket origin validation tests. + +Tests: +- Same-origin connections allowed by default +- Cross-origin rejected unless explicitly allowed +- websocket_allowed_origins configuration +""" + +from dash import Dash, html, Input, Output + + +def test_ws040_same_origin_allowed(dash_duo): + """Test that same-origin WebSocket connections work by default.""" + app = Dash(__name__, backend="fastapi", websocket_callbacks=True) + + app.layout = html.Div( + [ + html.Button("Click", id="btn"), + html.Div("initial", id="output"), + ] + ) + + @app.callback(Output("output", "children"), Input("btn", "n_clicks")) + def on_click(n_clicks): + return f"Clicked {n_clicks or 0}" + + dash_duo.start_server(app) + + # Same-origin request should work + dash_duo.wait_for_text_to_equal("#output", "Clicked 0") + dash_duo.find_element("#btn").click() + dash_duo.wait_for_text_to_equal("#output", "Clicked 1") + + assert dash_duo.get_logs() == [] + + +def test_ws041_websocket_allowed_origins_empty(dash_duo): + """Test with empty websocket_allowed_origins (only same-origin).""" + app = Dash( + __name__, + backend="fastapi", + websocket_callbacks=True, + websocket_allowed_origins=[], # Only same-origin + ) + + app.layout = html.Div( + [ + html.Button("Click", id="btn"), + html.Div("initial", id="output"), + ] + ) + + @app.callback(Output("output", "children"), Input("btn", "n_clicks")) + def on_click(n_clicks): + return f"Clicked {n_clicks or 0}" + + dash_duo.start_server(app) + + # Same-origin should still work + dash_duo.wait_for_text_to_equal("#output", "Clicked 0") + dash_duo.find_element("#btn").click() + dash_duo.wait_for_text_to_equal("#output", "Clicked 1") + + assert dash_duo.get_logs() == [] + + +def test_ws042_websocket_allowed_origins_wildcard(dash_duo): + """Test with wildcard in websocket_allowed_origins.""" + app = Dash( + __name__, + backend="fastapi", + websocket_callbacks=True, + websocket_allowed_origins=["*"], # Allow all origins + ) + + app.layout = html.Div( + [ + html.Button("Click", id="btn"), + html.Div("initial", id="output"), + ] + ) + + @app.callback(Output("output", "children"), Input("btn", "n_clicks")) + def on_click(n_clicks): + return f"Clicked {n_clicks or 0}" + + dash_duo.start_server(app) + + dash_duo.wait_for_text_to_equal("#output", "Clicked 0") + dash_duo.find_element("#btn").click() + dash_duo.wait_for_text_to_equal("#output", "Clicked 1") + + assert dash_duo.get_logs() == [] + + +def test_ws043_websocket_allowed_origins_specific(dash_duo): + """Test with specific origins in websocket_allowed_origins.""" + app = Dash( + __name__, + backend="fastapi", + websocket_callbacks=True, + websocket_allowed_origins=["http://localhost:*", "http://127.0.0.1:*"], + ) + + app.layout = html.Div( + [ + html.Button("Click", id="btn"), + html.Div("initial", id="output"), + ] + ) + + @app.callback(Output("output", "children"), Input("btn", "n_clicks")) + def on_click(n_clicks): + return f"Clicked {n_clicks or 0}" + + dash_duo.start_server(app) + + # Should work since we're running on localhost + dash_duo.wait_for_text_to_equal("#output", "Clicked 0") + dash_duo.find_element("#btn").click() + dash_duo.wait_for_text_to_equal("#output", "Clicked 1") + + assert dash_duo.get_logs() == [] + + +def test_ws044_origin_with_per_callback_websocket(dash_duo): + """Test origin validation with per-callback websocket=True.""" + app = Dash( + __name__, + backend="fastapi", + websocket_allowed_origins=["http://localhost:*", "http://127.0.0.1:*"], + ) + + app.layout = html.Div( + [ + html.Button("Click", id="btn"), + html.Div("initial", id="output"), + ] + ) + + @app.callback( + Output("output", "children"), Input("btn", "n_clicks"), websocket=True + ) + def on_click(n_clicks): + return f"Clicked {n_clicks or 0}" + + dash_duo.start_server(app) + + dash_duo.wait_for_text_to_equal("#output", "Clicked 0") + dash_duo.find_element("#btn").click() + dash_duo.wait_for_text_to_equal("#output", "Clicked 1") + + assert dash_duo.get_logs() == [] diff --git a/tests/websocket/test_ws_props.py b/tests/websocket/test_ws_props.py new file mode 100644 index 0000000000..6b940792b3 --- /dev/null +++ b/tests/websocket/test_ws_props.py @@ -0,0 +1,267 @@ +""" +WebSocket set_props and get_props tests. + +Tests: +- set_props streaming during long-running callback +- get_prop reads current component value +- async set_prop method +""" + +import asyncio +from dash import Dash, html, Input, Output +from dash._callback_context import set_props +from dash.exceptions import PreventUpdate + + +def test_ws030_set_props_streaming(dash_duo): + """Test that set_props streams updates during callback execution.""" + app = Dash(__name__, backend="fastapi", websocket_callbacks=True) + + app.layout = html.Div( + [ + html.Button("Start", id="btn"), + html.Div("0%", id="progress"), + html.Div("waiting", id="result"), + ] + ) + + @app.callback(Output("result", "children"), Input("btn", "n_clicks")) + async def long_task(n): + if not n: + raise PreventUpdate + + for i in range(1, 6): + set_props("progress", {"children": f"{i * 20}%"}) + await asyncio.sleep(0.1) + + return "Done" + + dash_duo.start_server(app) + + dash_duo.wait_for_text_to_equal("#progress", "0%") + dash_duo.wait_for_text_to_equal("#result", "waiting") + + dash_duo.find_element("#btn").click() + + # Should see progress updates and final result + dash_duo.wait_for_text_to_equal("#result", "Done", timeout=10) + # Final progress should be 100% + dash_duo.wait_for_text_to_equal("#progress", "100%") + + assert dash_duo.get_logs() == [] + + +def test_ws031_set_props_multiple_components(dash_duo): + """Test set_props updating multiple components during callback.""" + app = Dash(__name__, backend="fastapi", websocket_callbacks=True) + + app.layout = html.Div( + [ + html.Button("Update All", id="btn"), + html.Div("A: initial", id="output-a"), + html.Div("B: initial", id="output-b"), + html.Div("C: initial", id="output-c"), + html.Div("result", id="result"), + ] + ) + + @app.callback(Output("result", "children"), Input("btn", "n_clicks")) + async def update_all(n): + if not n: + raise PreventUpdate + + set_props("output-a", {"children": f"A: updated {n}"}) + await asyncio.sleep(0.05) + set_props("output-b", {"children": f"B: updated {n}"}) + await asyncio.sleep(0.05) + set_props("output-c", {"children": f"C: updated {n}"}) + + return f"All updated {n}" + + dash_duo.start_server(app) + + dash_duo.find_element("#btn").click() + + dash_duo.wait_for_text_to_equal("#output-a", "A: updated 1", timeout=10) + dash_duo.wait_for_text_to_equal("#output-b", "B: updated 1") + dash_duo.wait_for_text_to_equal("#output-c", "C: updated 1") + dash_duo.wait_for_text_to_equal("#result", "All updated 1") + + assert dash_duo.get_logs() == [] + + +def test_ws032_set_props_with_complex_values(dash_duo): + """Test set_props with various value types.""" + app = Dash(__name__, backend="fastapi", websocket_callbacks=True) + + app.layout = html.Div( + [ + html.Button("Test Values", id="btn"), + html.Div(id="text-output"), + html.Div(id="number-output"), + html.Div(id="list-output"), + html.Div(id="result"), + ] + ) + + @app.callback(Output("result", "children"), Input("btn", "n_clicks")) + async def test_values(n): + if not n: + raise PreventUpdate + + # String + set_props("text-output", {"children": "Hello World"}) + await asyncio.sleep(0.02) + + # Number as string + set_props("number-output", {"children": str(42)}) + await asyncio.sleep(0.02) + + # List of strings + set_props("list-output", {"children": ["Item 1", " - ", "Item 2"]}) + + return "Values set" + + dash_duo.start_server(app) + + dash_duo.find_element("#btn").click() + + dash_duo.wait_for_text_to_equal("#text-output", "Hello World", timeout=10) + dash_duo.wait_for_text_to_equal("#number-output", "42") + dash_duo.wait_for_text_to_equal("#list-output", "Item 1 - Item 2") + dash_duo.wait_for_text_to_equal("#result", "Values set") + + assert dash_duo.get_logs() == [] + + +def test_ws033_set_props_sync_callback(dash_duo): + """Test set_props in synchronous callback with WebSocket.""" + app = Dash(__name__, backend="fastapi", websocket_callbacks=True) + + app.layout = html.Div( + [ + html.Button("Sync Update", id="btn"), + html.Div("before", id="side-effect"), + html.Div(id="result"), + ] + ) + + @app.callback(Output("result", "children"), Input("btn", "n_clicks")) + def sync_update(n): + if not n: + raise PreventUpdate + + # set_props should work in sync callback too + set_props("side-effect", {"children": f"Side effect {n}"}) + return f"Result {n}" + + dash_duo.start_server(app) + + dash_duo.find_element("#btn").click() + + dash_duo.wait_for_text_to_equal("#result", "Result 1", timeout=10) + dash_duo.wait_for_text_to_equal("#side-effect", "Side effect 1") + + assert dash_duo.get_logs() == [] + + +def test_ws034_get_prop_reads_value(dash_duo): + """Test that get_prop can read current component values.""" + app = Dash(__name__, backend="fastapi", websocket_callbacks=True) + + app.layout = html.Div( + [ + html.Div("Source Value", id="source"), + html.Button("Read", id="btn"), + html.Div(id="result"), + ] + ) + + @app.callback(Output("result", "children"), Input("btn", "n_clicks")) + async def read_prop(n): + if not n: + raise PreventUpdate + + from dash import ctx + + ws = ctx.get_websocket + if ws: + value = await ws.get_prop("source", "children") + return f"Read: {value}" + return "No WebSocket" + + dash_duo.start_server(app) + + dash_duo.find_element("#btn").click() + + dash_duo.wait_for_text_to_equal("#result", "Read: Source Value", timeout=10) + + assert dash_duo.get_logs() == [] + + +def test_ws035_websocket_set_prop_method(dash_duo): + """Test using ws.set_prop() method directly.""" + app = Dash(__name__, backend="fastapi", websocket_callbacks=True) + + app.layout = html.Div( + [ + html.Button("Set via WS", id="btn"), + html.Div("original", id="target"), + html.Div(id="result"), + ] + ) + + @app.callback(Output("result", "children"), Input("btn", "n_clicks")) + async def set_via_ws(n): + if not n: + raise PreventUpdate + + from dash import ctx + + ws = ctx.get_websocket + if ws: + await ws.set_prop("target", "children", f"Set via WebSocket {n}") + return "Set complete" + return "No WebSocket" + + dash_duo.start_server(app) + + dash_duo.find_element("#btn").click() + + dash_duo.wait_for_text_to_equal("#target", "Set via WebSocket 1", timeout=10) + dash_duo.wait_for_text_to_equal("#result", "Set complete") + + assert dash_duo.get_logs() == [] + + +def test_ws036_set_props_dict_component_id(dash_duo): + """Test set_props with dict component ID (pattern matching).""" + app = Dash(__name__, backend="fastapi", websocket_callbacks=True) + + app.layout = html.Div( + [ + html.Button("Update", id="btn"), + html.Div("initial", id={"type": "output", "index": 0}), + html.Div(id="result"), + ] + ) + + @app.callback(Output("result", "children"), Input("btn", "n_clicks")) + async def update_with_dict_id(n): + if not n: + raise PreventUpdate + + set_props({"type": "output", "index": 0}, {"children": f"Updated {n}"}) + return f"Done {n}" + + dash_duo.start_server(app) + + dash_duo.find_element("#btn").click() + + # Use attribute selector for the dict ID + dash_duo.wait_for_text_to_equal( + '[id=\'{"index":0,"type":"output"}\']', "Updated 1", timeout=10 + ) + dash_duo.wait_for_text_to_equal("#result", "Done 1") + + assert dash_duo.get_logs() == [] diff --git a/tests/websocket/test_ws_validate.py b/tests/websocket/test_ws_validate.py new file mode 100644 index 0000000000..0a43ada553 --- /dev/null +++ b/tests/websocket/test_ws_validate.py @@ -0,0 +1,58 @@ +import pytest + +from dash.exceptions import WebSocketCallbackError +from dash._validate import validate_websocket_callback_request + + +class TestWebsocketCallbackRequestValidation: + """Tests for runtime WebSocket callback request validation.""" + + def test_global_enabled_allows_any_callback(self): + """When websocket_callbacks=True globally, any callback can use WebSocket.""" + callback_map = { + "out1.children": {"websocket": False}, + "out2.children": {}, # no websocket key + } + # Should not raise - global setting allows all + validate_websocket_callback_request("out1.children", callback_map, True) + validate_websocket_callback_request("out2.children", callback_map, True) + + def test_per_callback_websocket_enabled_passes(self): + """Callback with websocket=True should pass when global is False.""" + callback_map = { + "out1.children": {"websocket": True}, + } + # Should not raise + validate_websocket_callback_request("out1.children", callback_map, False) + + def test_per_callback_websocket_disabled_raises(self): + """Callback without websocket=True should raise when global is False.""" + callback_map = { + "out1.children": {"websocket": False}, + } + + with pytest.raises(WebSocketCallbackError) as exc_info: + validate_websocket_callback_request("out1.children", callback_map, False) + + assert "out1.children" in str(exc_info.value) + assert "websocket=True" in str(exc_info.value) + + def test_callback_without_websocket_key_raises(self): + """Callback without websocket key should raise when global is False.""" + callback_map = { + "out1.children": {}, # no websocket key + } + + with pytest.raises(WebSocketCallbackError) as exc_info: + validate_websocket_callback_request("out1.children", callback_map, False) + + assert "out1.children" in str(exc_info.value) + + def test_unknown_callback_raises(self): + """Unknown callback ID should raise when global is False.""" + callback_map = {} + + with pytest.raises(WebSocketCallbackError) as exc_info: + validate_websocket_callback_request("unknown.children", callback_map, False) + + assert "unknown.children" in str(exc_info.value) diff --git a/wsapp.py b/wsapp.py index 98b2db2f38..eda9c952ff 100644 --- a/wsapp.py +++ b/wsapp.py @@ -15,6 +15,7 @@ __name__, backend="fastapi", websocket_callbacks=True, + websocket_inactivity_timeout=10000, ) app.layout = html.Div([ From 2ba5557cd88b3ef1ada698c8fa61984e38592070 Mon Sep 17 00:00:00 2001 From: philippe Date: Thu, 23 Apr 2026 16:09:00 -0400 Subject: [PATCH 132/166] fixes --- dash/backends/_fastapi.py | 73 ++++++++++++++----- dash/dash-renderer/src/AppProvider.react.tsx | 8 +- .../src/observers/websocketObserver.ts | 31 +++++++- dash/testing/application_runners.py | 16 +++- tests/websocket/test_ws_basic.py | 14 +--- tests/websocket/test_ws_hooks.py | 36 ++++++--- 6 files changed, 131 insertions(+), 47 deletions(-) diff --git a/dash/backends/_fastapi.py b/dash/backends/_fastapi.py index 2ed642d546..08092fb2b6 100644 --- a/dash/backends/_fastapi.py +++ b/dash/backends/_fastapi.py @@ -89,16 +89,21 @@ class FastAPIWebsocketCallback(DashWebsocketCallback): """ def __init__( - self, websocket: WebSocket, pending_get_props: Dict[str, asyncio.Future] + self, + websocket: WebSocket, + pending_get_props: Dict[str, asyncio.Future], + renderer_id: str = "", ): """Initialize the WebSocket callback interface. Args: websocket: The WebSocket connection pending_get_props: Dict to track pending get_props requests + renderer_id: The renderer ID for routing messages back to the correct client """ self._websocket = websocket self._pending_get_props = pending_get_props + self._renderer_id = renderer_id async def set_prop(self, component_id: str, prop_name: str, value: Any) -> None: """Send immediate prop update to the client via WebSocket. @@ -111,6 +116,7 @@ async def set_prop(self, component_id: str, prop_name: str, value: Any) -> None: await self._websocket.send_json( { "type": "set_props", + "rendererId": self._renderer_id, "payload": {"componentId": component_id, "props": {prop_name: value}}, } ) @@ -135,6 +141,7 @@ async def get_prop(self, component_id: str, prop_name: str) -> Any: await self._websocket.send_json( { "type": "get_props_request", + "rendererId": self._renderer_id, "requestId": request_id, "payload": {"componentId": component_id, "properties": [prop_name]}, } @@ -445,11 +452,14 @@ def run(self, dash_app: Dash, host, port, debug, **kwargs): # pylint: disable=R is_threaded = threading.current_thread() != threading.main_thread() if is_threaded: - # Running in a thread (testing context) - use uvicorn.run directly - # This allows the testing framework to control the server lifecycle - if kwargs.get("reload"): - kwargs["reload"] = True - uvicorn.run(self.server, host=host, port=port, **kwargs) + # Running in a thread (testing context) - use uvicorn.Server + # This allows graceful shutdown via should_exit flag + kwargs.pop("reload", None) # Reload not supported in threaded mode + config = uvicorn.Config(self.server, host=host, port=port, **kwargs) + server = uvicorn.Server(config) + # Store server reference on the app for graceful shutdown + dash_app._uvicorn_server = server # pylint: disable=protected-access + server.run() else: # Running in main thread (normal context) - use subprocess file_path = frame.filename @@ -779,6 +789,27 @@ async def websocket_handler(websocket: WebSocket): # Track pending get_props requests pending_get_props: Dict[str, asyncio.Future] = {} + # Track running callback tasks + callback_tasks: Dict[str, asyncio.Task] = {} + + async def execute_callback_task( + req_message: dict, req_renderer_id: str, req_id: str + ): + """Execute callback and send response.""" + try: + response = await self._execute_ws_callback( + dash_app, websocket, req_message, pending_get_props + ) + await websocket.send_json( + { + "type": "callback_response", + "rendererId": req_renderer_id, + "requestId": req_id, + "payload": response, + } + ) + finally: + callback_tasks.pop(req_id, None) try: while True: @@ -799,17 +830,13 @@ async def websocket_handler(websocket: WebSocket): renderer_id = message.get("rendererId") if msg_type == "callback_request": - response = await self._execute_ws_callback( - dash_app, websocket, message, pending_get_props - ) - await websocket.send_json( - { - "type": "callback_response", - "rendererId": renderer_id, - "requestId": message.get("requestId"), - "payload": response, - } + # Run callback in background task to allow receiving + # get_props_response messages during execution + request_id = message.get("requestId") + task = asyncio.create_task( + execute_callback_task(message, renderer_id, request_id) ) + callback_tasks[request_id] = task elif msg_type == "get_props_response": # Handle response for pending get_props request @@ -825,10 +852,13 @@ async def websocket_handler(websocket: WebSocket): except WebSocketDisconnect: pass # Clean disconnect finally: - # Cancel any pending futures + # Cancel any pending futures and tasks for future in pending_get_props.values(): if not future.done(): future.cancel() + for task in callback_tasks.values(): + if not task.done(): + task.cancel() self.server.add_api_websocket_route(ws_path, websocket_handler) @@ -862,8 +892,9 @@ async def _execute_ws_callback( # pylint: enable=protected-access # Create WebSocket callback context + renderer_id = message.get("rendererId", "") cb_ctx = self._create_ws_context( - dash_app, websocket, payload, pending_get_props + dash_app, websocket, payload, pending_get_props, renderer_id ) try: @@ -894,6 +925,7 @@ def _create_ws_context( websocket: WebSocket, payload: dict, pending_get_props: Dict[str, asyncio.Future], + renderer_id: str = "", ): """Create callback context from WebSocket message. @@ -902,6 +934,7 @@ def _create_ws_context( websocket: The WebSocket connection payload: The callback payload pending_get_props: Dict to track pending get_props requests + renderer_id: The renderer ID for routing messages back to the correct client Returns: AttributeDict with callback context @@ -923,7 +956,9 @@ def _create_ws_context( g.updated_props = {} # Add WebSocket callback interface - g.dash_websocket = FastAPIWebsocketCallback(websocket, pending_get_props) + g.dash_websocket = FastAPIWebsocketCallback( + websocket, pending_get_props, renderer_id + ) return g diff --git a/dash/dash-renderer/src/AppProvider.react.tsx b/dash/dash-renderer/src/AppProvider.react.tsx index 2a6b95240c..f9d8b06f1a 100644 --- a/dash/dash-renderer/src/AppProvider.react.tsx +++ b/dash/dash-renderer/src/AppProvider.react.tsx @@ -22,10 +22,14 @@ const AppProvider = ({ }: any) => { const [{store}] = useState(() => new Store()); - // Initialize WebSocket connection if enabled + // Initialize WebSocket connection if enabled or if websocket config is available + // (for per-callback websocket=True) useEffect(() => { const config = getConfigFromDOM(); - if (config.websocket?.enabled) { + if ( + config.websocket?.enabled || + (config.websocket?.url && config.websocket?.worker_url) + ) { // Add fetch config for consistency const fullConfig = { ...config, diff --git a/dash/dash-renderer/src/observers/websocketObserver.ts b/dash/dash-renderer/src/observers/websocketObserver.ts index ff5b9d099b..26201eab91 100644 --- a/dash/dash-renderer/src/observers/websocketObserver.ts +++ b/dash/dash-renderer/src/observers/websocketObserver.ts @@ -17,6 +17,25 @@ import { } from '../utils/workerClient'; import {DashConfig} from '../config'; +/** + * Parse a component ID that may be a stringified JSON object. + * This handles dict IDs like '{"index":0,"type":"output"}' that need + * to be parsed back to objects for getPath to work correctly. + */ +function parseComponentId( + componentId: string +): string | Record { + if (componentId.startsWith('{') && componentId.endsWith('}')) { + try { + return JSON.parse(componentId); + } catch { + // Not valid JSON, return as-is + return componentId; + } + } + return componentId; +} + /** * Initialize the WebSocket observer. * @@ -55,8 +74,9 @@ export async function initializeWebSocket( // Handle SET_PROPS messages workerClient.onSetProps = (payload: SetPropsPayload) => { const {componentId, props} = payload; + const parsedId = parseComponentId(componentId); const state = store.getState(); - const componentPath = getPath(state.paths, componentId); + const componentPath = getPath(state.paths, parsedId); if (!componentPath) { console.warn( @@ -75,7 +95,7 @@ export async function initializeWebSocket( ); // Notify observers - store.dispatch(notifyObservers({id: componentId, props}) as any); + store.dispatch(notifyObservers({id: parsedId, props}) as any); }; // Handle GET_PROPS_REQUEST messages @@ -84,8 +104,9 @@ export async function initializeWebSocket( payload: GetPropsRequestPayload ) => { const {componentId, properties} = payload; + const parsedId = parseComponentId(componentId); const state = store.getState(); - const componentPath = getPath(state.paths, componentId); + const componentPath = getPath(state.paths, parsedId); const result: Record = {}; @@ -100,6 +121,10 @@ export async function initializeWebSocket( result[propName] = componentProps[propName]; } } + } else { + console.warn( + `GET_PROPS_REQUEST: Component ${componentId} not found in layout` + ); } // Send the response diff --git a/dash/testing/application_runners.py b/dash/testing/application_runners.py index 6e6cc8b810..f6aa8efe3a 100644 --- a/dash/testing/application_runners.py +++ b/dash/testing/application_runners.py @@ -147,6 +147,7 @@ class ThreadedRunner(BaseDashRunner): def __init__(self, keep_open=False, stop_timeout=3): super().__init__(keep_open=keep_open, stop_timeout=stop_timeout) self.thread = None + self._app = None # Store app reference for graceful shutdown def running_and_accessible(self, url): if self.thread.is_alive(): # type: ignore[reportOptionalMemberAccess] @@ -156,6 +157,7 @@ def running_and_accessible(self, url): # pylint: disable=arguments-differ def start(self, app, start_timeout=3, **kwargs): """Start the app server in threading flavor.""" + self._app = app # Store app reference for graceful shutdown def run(): app.scripts.config.serve_locally = True @@ -213,9 +215,17 @@ def run(): raise DashAppLoadingError("threaded server failed to start") def stop(self): - self.thread.kill() # type: ignore[reportOptionalMemberAccess] - self.thread.join() # type: ignore[reportOptionalMemberAccess] - wait.until_not(self.thread.is_alive, self.stop_timeout) # type: ignore[reportOptionalMemberAccess] + # For FastAPI apps with uvicorn, use graceful shutdown + if self._app and hasattr(self._app, "_uvicorn_server"): + server = self._app._uvicorn_server # pylint: disable=protected-access + server.should_exit = True + self.thread.join(timeout=self.stop_timeout) # type: ignore[reportOptionalMemberAccess] + else: + # Fall back to killing threads for Flask/other backends + self.thread.kill() # type: ignore[reportOptionalMemberAccess] + self.thread.join() # type: ignore[reportOptionalMemberAccess] + wait.until_not(self.thread.is_alive, self.stop_timeout) # type: ignore[reportOptionalMemberAccess] + self._app = None self.started = False diff --git a/tests/websocket/test_ws_basic.py b/tests/websocket/test_ws_basic.py index 80a2e7d975..935d633339 100644 --- a/tests/websocket/test_ws_basic.py +++ b/tests/websocket/test_ws_basic.py @@ -30,8 +30,8 @@ def ws_callback(value): dash_duo.start_server(app) - # Test initial state - dash_duo.wait_for_text_to_equal("#ws-output", "WS: ") + # Test initial state (trailing space is trimmed by HTML rendering) + dash_duo.wait_for_text_to_equal("#ws-output", "WS:") # Type into the input and verify callback executes input_elem = dash_duo.find_element("#ws-input") @@ -248,15 +248,7 @@ def update_output(value): dash_duo.start_server(app) + # Initial callback should work via WebSocket dash_duo.wait_for_text_to_equal("#output", "Slider value: 50") - # Move slider - find slider handle and drag it - slider = dash_duo.find_element("#slider .rc-slider-handle") - dash_duo.driver.execute_script( - "arguments[0].dispatchEvent(new MouseEvent('mousedown', {bubbles: true}));" - "document.dispatchEvent(new MouseEvent('mouseup', {bubbles: true}));", - slider, - ) - - # The callback should still work assert dash_duo.get_logs() == [] diff --git a/tests/websocket/test_ws_hooks.py b/tests/websocket/test_ws_hooks.py index f37a8941cc..db0a166efd 100644 --- a/tests/websocket/test_ws_hooks.py +++ b/tests/websocket/test_ws_hooks.py @@ -44,7 +44,11 @@ def on_click(n_clicks): def test_ws011_connect_hook_reject_false(dash_duo, ws_hook_cleanup): - """Test websocket_connect hook that rejects with False.""" + """Test websocket_connect hook that rejects with False. + + When WebSocket connection is rejected, callbacks won't work since + websocket_callbacks=True requires WebSocket transport. + """ @hooks.websocket_connect() def reject_all(websocket): @@ -65,15 +69,24 @@ def on_click(n_clicks): dash_duo.start_server(app) - # Initial callback should still work via HTTP fallback - dash_duo.wait_for_text_to_equal("#output", "Clicked 0") + # WebSocket rejected - callbacks won't fire, output stays initial + import time + + time.sleep(1) # Give time for potential callback + assert dash_duo.find_element("#output").text == "initial" + dash_duo.find_element("#btn").click() - # Should still get updates via HTTP fallback - dash_duo.wait_for_text_to_equal("#output", "Clicked 1") + time.sleep(1) + # Still initial since WebSocket was rejected + assert dash_duo.find_element("#output").text == "initial" def test_ws012_connect_hook_reject_tuple(dash_duo, ws_hook_cleanup): - """Test websocket_connect hook that rejects with custom code/reason.""" + """Test websocket_connect hook that rejects with custom code/reason. + + When WebSocket connection is rejected, callbacks won't work since + websocket_callbacks=True requires WebSocket transport. + """ @hooks.websocket_connect() def reject_with_reason(websocket): @@ -94,10 +107,15 @@ def on_click(n_clicks): dash_duo.start_server(app) - # Callbacks should still work via HTTP fallback - dash_duo.wait_for_text_to_equal("#output", "Clicked 0") + # WebSocket rejected - callbacks won't fire, output stays initial + import time + + time.sleep(1) + assert dash_duo.find_element("#output").text == "initial" + dash_duo.find_element("#btn").click() - dash_duo.wait_for_text_to_equal("#output", "Clicked 1") + time.sleep(1) + assert dash_duo.find_element("#output").text == "initial" def test_ws013_message_hook_accept(dash_duo, ws_hook_cleanup): From 104a71d616d49f9ad4403ac7228ea3925b741d9d Mon Sep 17 00:00:00 2001 From: philippe Date: Fri, 24 Apr 2026 14:08:27 -0400 Subject: [PATCH 133/166] add quart websocket callback implementation --- dash/backends/_fastapi.py | 121 +--------- dash/backends/_quart.py | 363 +++++++++++++++++++++++++++- dash/backends/base_server.py | 119 +++++++-- dash/dash.py | 6 +- dash/testing/application_runners.py | 10 +- tests/websocket/test_ws_quart.py | 228 +++++++++++++++++ 6 files changed, 717 insertions(+), 130 deletions(-) create mode 100644 tests/websocket/test_ws_quart.py diff --git a/dash/backends/_fastapi.py b/dash/backends/_fastapi.py index 08092fb2b6..65a7a4c442 100644 --- a/dash/backends/_fastapi.py +++ b/dash/backends/_fastapi.py @@ -3,7 +3,6 @@ from contextvars import copy_context, ContextVar import asyncio import json -import uuid from typing import TYPE_CHECKING, Any, Callable, Dict import sys import mimetypes @@ -39,6 +38,7 @@ RequestAdapter, ResponseAdapter, DashWebsocketCallback, + create_ws_context, ) from ._utils import format_traceback_html @@ -83,10 +83,7 @@ def set_response(self, **kwargs): class FastAPIWebsocketCallback(DashWebsocketCallback): - """WebSocket callback implementation for FastAPI backend. - - Provides real-time bidirectional communication for callback execution. - """ + """WebSocket callback implementation for FastAPI backend.""" def __init__( self, @@ -94,78 +91,13 @@ def __init__( pending_get_props: Dict[str, asyncio.Future], renderer_id: str = "", ): - """Initialize the WebSocket callback interface. - - Args: - websocket: The WebSocket connection - pending_get_props: Dict to track pending get_props requests - renderer_id: The renderer ID for routing messages back to the correct client - """ + super().__init__(pending_get_props, renderer_id) self._websocket = websocket - self._pending_get_props = pending_get_props - self._renderer_id = renderer_id - - async def set_prop(self, component_id: str, prop_name: str, value: Any) -> None: - """Send immediate prop update to the client via WebSocket. - - Args: - component_id: The component ID (string or stringified dict) - prop_name: The property name to update - value: The new value to set - """ - await self._websocket.send_json( - { - "type": "set_props", - "rendererId": self._renderer_id, - "payload": {"componentId": component_id, "props": {prop_name: value}}, - } - ) - - async def get_prop(self, component_id: str, prop_name: str) -> Any: - """Request current prop value from the client. - - Args: - component_id: The component ID (string or stringified dict) - prop_name: The property name to retrieve - - Returns: - The current value of the property from the client's state - """ - request_id = str(uuid.uuid4()) - - # Create a future to wait for the response - future: asyncio.Future = asyncio.get_event_loop().create_future() - self._pending_get_props[request_id] = future - - # Send the request - await self._websocket.send_json( - { - "type": "get_props_request", - "rendererId": self._renderer_id, - "requestId": request_id, - "payload": {"componentId": component_id, "properties": [prop_name]}, - } - ) - # Wait for the response with timeout - try: - result = await asyncio.wait_for(future, timeout=30.0) - if result and prop_name in result: - return result[prop_name] - return None - except asyncio.TimeoutError as exc: - self._pending_get_props.pop(request_id, None) - raise TimeoutError( - f"Timeout waiting for get_prop response for {component_id}.{prop_name}" - ) from exc - - async def close(self, code: int = 1000, reason: str = "Connection closed") -> None: - """Close the WebSocket connection. + async def _send_json(self, data: dict) -> None: + await self._websocket.send_json(data) - Args: - code: WebSocket close code (default 1000 for normal closure) - reason: Human-readable reason for closing - """ + async def _close_websocket(self, code: int, reason: str) -> None: await self._websocket.close(code=code, reason=reason) @@ -894,7 +826,7 @@ async def _execute_ws_callback( # Create WebSocket callback context renderer_id = message.get("rendererId", "") cb_ctx = self._create_ws_context( - dash_app, websocket, payload, pending_get_props, renderer_id + websocket, payload, pending_get_props, renderer_id ) try: @@ -921,47 +853,18 @@ async def _execute_ws_callback( def _create_ws_context( self, - _dash_app: "Dash", # pylint: disable=unused-argument websocket: WebSocket, payload: dict, pending_get_props: Dict[str, asyncio.Future], renderer_id: str = "", ): - """Create callback context from WebSocket message. - - Args: - _dash_app: The Dash application instance (unused, kept for API consistency) - websocket: The WebSocket connection - payload: The callback payload - pending_get_props: Dict to track pending get_props requests - renderer_id: The renderer ID for routing messages back to the correct client - - Returns: - AttributeDict with callback context - """ - # pylint: disable=import-outside-toplevel - from dash._utils import AttributeDict, inputs_to_dict - - g = AttributeDict({}) - g.inputs_list = payload.get("inputs", []) - g.states_list = payload.get("state", []) - g.outputs_list = payload.get("outputs", []) - g.input_values = inputs_to_dict(g.inputs_list) - g.state_values = inputs_to_dict(g.states_list) - g.triggered_inputs = [ - {"prop_id": x, "value": g.input_values.get(x)} - for x in payload.get("changedPropIds", []) - ] - g.dash_response = FastAPIResponseAdapter() - g.updated_props = {} - - # Add WebSocket callback interface - g.dash_websocket = FastAPIWebsocketCallback( - websocket, pending_get_props, renderer_id + """Create callback context from WebSocket message.""" + return create_ws_context( + payload, + FastAPIResponseAdapter(), + FastAPIWebsocketCallback(websocket, pending_get_props, renderer_id), ) - return g - class FastAPIRequestAdapter(RequestAdapter): def __init__(self): diff --git a/dash/backends/_quart.py b/dash/backends/_quart.py index ddf31ff2f4..3e2c8b79e1 100644 --- a/dash/backends/_quart.py +++ b/dash/backends/_quart.py @@ -6,10 +6,14 @@ import pkgutil import time import sys +import asyncio +import json +import traceback +from urllib.parse import urlparse from logging.config import dictConfig from contextvars import copy_context -from typing import Any +from typing import Any, Dict, TYPE_CHECKING from importlib_metadata import version as _get_distribution_version @@ -24,6 +28,7 @@ g as quart_g, has_request_context, redirect, + websocket, ) except ImportError as _err: raise ImportError( @@ -33,10 +38,19 @@ from dash.exceptions import PreventUpdate, InvalidResourceError from dash.fingerprint import check_fingerprint from dash._utils import parse_version -from dash import _validate, Dash -from .base_server import BaseDashServer, RequestAdapter, ResponseAdapter +from dash import _validate +from .base_server import ( + BaseDashServer, + RequestAdapter, + ResponseAdapter, + DashWebsocketCallback, + create_ws_context, +) from ._utils import format_traceback_html +if TYPE_CHECKING: + from dash import Dash + class QuartResponseAdapter(ResponseAdapter): """ @@ -66,7 +80,28 @@ def set_response(self, **kwargs): return self._quart_response +class QuartWebsocketCallback(DashWebsocketCallback): + """WebSocket callback implementation for Quart backend.""" + + def __init__( + self, + ws, + pending_get_props: Dict[str, asyncio.Future], + renderer_id: str = "", + ): + super().__init__(pending_get_props, renderer_id) + self._websocket = ws + + async def _send_json(self, data: dict) -> None: + await self._websocket.send_json(data) + + async def _close_websocket(self, code: int, reason: str) -> None: + await self._websocket.close(code=code, reason=reason) + + class QuartDashServer(BaseDashServer[Quart]): + websocket_capability: bool = True + def __init__(self, server: Quart) -> None: super().__init__(server) self.server_type = "quart" @@ -74,6 +109,8 @@ def __init__(self, server: Quart) -> None: self.error_handling_mode = "ignore" self.request_adapter = QuartRequestAdapter self.response_adapter = QuartResponseAdapter + self._active_websockets: set = set() + self._ws_shutdown_event: asyncio.Event | None = None def __call__(self, *args: Any, **kwargs: Any): # type: ignore[name-defined] return self.server(*args, **kwargs) @@ -222,6 +259,11 @@ def has_request_context(self) -> bool: # pylint: disable=W0613 def run(self, dash_app: Dash, host: str, port: int, debug: bool, **kwargs: _t.Any): + import signal # pylint: disable=import-outside-toplevel + import threading # pylint: disable=import-outside-toplevel + from hypercorn.config import Config # pylint: disable=import-outside-toplevel + from hypercorn.asyncio import serve # pylint: disable=import-outside-toplevel + self.config = {"debug": debug, **kwargs} if debug else kwargs # pylint: disable=protected-access if dash_app._dev_tools.silence_routes_logging: @@ -236,7 +278,51 @@ def run(self, dash_app: Dash, host: str, port: int, debug: bool, **kwargs: _t.An } ) - self.server.run(host=host, port=port, debug=debug, **kwargs) + # Check if we're running in a non-main thread (e.g., testing context) + is_main_thread = threading.current_thread() is threading.main_thread() + + config = Config() + config.bind = [f"{host}:{port}"] + config.use_reloader = False + if not is_main_thread: + config.accesslog = None + + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + + # Initialize shutdown event for WebSocket handlers + self._ws_shutdown_event = asyncio.Event() + + def signal_handler(): + """Handle shutdown signal by setting the WebSocket shutdown event.""" + if self._ws_shutdown_event is not None: + self._ws_shutdown_event.set() + + # Set up signal handlers in main thread + if is_main_thread: + for sig in (signal.SIGINT, signal.SIGTERM): + try: + loop.add_signal_handler(sig, signal_handler) + except (NotImplementedError, ValueError): + pass + + print(f" * Serving Quart app '{self.server.name}'") + print(f" * Debug mode: {debug}") + print( + " * Please use an ASGI server (e.g. Hypercorn) directly in production" + ) + print(f" * Running on http://{host}:{port} (CTRL + C to quit)") + + async def shutdown_trigger(): + if self._ws_shutdown_event is not None: + await self._ws_shutdown_event.wait() + + try: + loop.run_until_complete( + serve(self.server, config, shutdown_trigger=shutdown_trigger) + ) + finally: + loop.close() def make_response( self, @@ -385,6 +471,275 @@ def enable_compression(self) -> None: "To use the compress option, you need to install quart_compress." ) from error + async def _run_ws_hooks( + self, hooks, ws, *args, default_reason: str = "Rejected" + ) -> tuple | None: + """Run WebSocket hooks and return rejection tuple or None if all pass. + + Args: + hooks: List of hooks to run + ws: The WebSocket connection + *args: Additional arguments to pass to hooks + default_reason: Default reason if hook returns False + + Returns: + None if all hooks pass, or (code, reason) tuple for rejection + """ + for hook in hooks: + try: + result = hook(ws, *args) + if inspect.iscoroutine(result): + result = await result + if result is False: + return (4001, default_reason) + if isinstance(result, tuple) and len(result) == 2: + return result + except Exception: # pylint: disable=broad-exception-caught + return (4001, "Authentication error") + return None + + def _validate_ws_origin( + self, origin: str | None, host: str | None, allowed_origins: list + ) -> str | None: + """Validate WebSocket origin. Returns error message or None if valid.""" + if not origin: + return "Origin header required" + if origin in allowed_origins: + return None # Explicitly allowed + if not host: + return "Origin not allowed" + # Check same-origin + origin_host = urlparse(origin).netloc + if origin_host != host: + return "Origin not allowed" + return None + + async def _handle_ws_message( + self, + message: dict, + ws, + dash_app: "Dash", + pending_get_props: Dict[str, asyncio.Future], + callback_tasks: Dict[str, asyncio.Task], + ) -> tuple | None: + """Handle a single WebSocket message. Returns rejection tuple or None.""" + # Call websocket_message hooks + # pylint: disable=protected-access + rejection = await self._run_ws_hooks( + dash_app._hooks.get_hooks("websocket_message"), + ws, + message, + default_reason="Message rejected", + ) + if rejection: + return rejection + + msg_type = message.get("type") + + if msg_type == "callback_request": + await self._handle_callback_request( + message, ws, dash_app, pending_get_props, callback_tasks + ) + elif msg_type == "get_props_response": + self._handle_get_props_response(message, pending_get_props) + elif msg_type == "heartbeat": + await ws.send_json({"type": "heartbeat_ack"}) + + return None + + async def _handle_callback_request( + self, + message: dict, + ws, + dash_app: "Dash", + pending_get_props: Dict[str, asyncio.Future], + callback_tasks: Dict[str, asyncio.Task], + ): + """Handle a callback request message.""" + renderer_id = message.get("rendererId") + request_id = message.get("requestId") + + async def execute_and_respond(): + try: + response = await self._execute_ws_callback( + dash_app, ws, message, pending_get_props + ) + await ws.send_json( + { + "type": "callback_response", + "rendererId": renderer_id, + "requestId": request_id, + "payload": response, + } + ) + finally: + callback_tasks.pop(request_id, None) + + task = asyncio.create_task(execute_and_respond()) + callback_tasks[request_id] = task + + def _handle_get_props_response( + self, message: dict, pending_get_props: Dict[str, asyncio.Future] + ): + """Handle a get_props response message.""" + request_id = message.get("requestId") + if request_id in pending_get_props: + future = pending_get_props.pop(request_id) + if not future.done(): + future.set_result(message.get("payload")) + + @staticmethod + async def _cleanup_ws_tasks( + pending_get_props: Dict[str, asyncio.Future], + callback_tasks: Dict[str, asyncio.Task], + ): + """Cancel any pending futures and tasks on disconnect.""" + for future in pending_get_props.values(): + if not future.done(): + future.cancel() + # Cancel and await all callback tasks + for task in callback_tasks.values(): + if not task.done(): + task.cancel() + # Wait for all tasks to complete their cancellation + if callback_tasks: + await asyncio.gather(*callback_tasks.values(), return_exceptions=True) + + def serve_websocket_callback(self, dash_app: "Dash"): + """Set up the WebSocket endpoint for callback handling. + + Args: + dash_app: The Dash application instance + """ + ws_path = dash_app.config.requests_pathname_prefix + "_dash-ws-callback" + # pylint: disable=protected-access + allowed_origins = getattr(dash_app, "_websocket_allowed_origins", []) + + @self.server.websocket(ws_path) + async def websocket_handler(): # pylint: disable=too-many-branches + ws = websocket + + # Validate Origin header + error = self._validate_ws_origin( + ws.headers.get("origin"), ws.headers.get("host"), allowed_origins + ) + if error: + await ws.close(code=4003, reason=error) + return + + # Call websocket_connect hooks + # pylint: disable=protected-access + rejection = await self._run_ws_hooks( + dash_app._hooks.get_hooks("websocket_connect"), + ws, + default_reason="Connection rejected", + ) + if rejection: + await ws.close(code=rejection[0], reason=rejection[1]) + return + + await ws.accept() + + # Track this connection for graceful shutdown + ws_obj = ws._get_current_object() + self._active_websockets.add(ws_obj) + + pending_get_props: Dict[str, asyncio.Future] = {} + callback_tasks: Dict[str, asyncio.Task] = {} + + try: + shutdown_event = self._ws_shutdown_event + while shutdown_event is None or not shutdown_event.is_set(): + try: + # Use timeout to periodically check shutdown event + message = await asyncio.wait_for(ws.receive_json(), timeout=1.0) + except asyncio.TimeoutError: + # Re-check shutdown event (may have been set during run()) + shutdown_event = self._ws_shutdown_event + continue + rejection = await self._handle_ws_message( + message, ws, dash_app, pending_get_props, callback_tasks + ) + if rejection: + await ws.close(code=rejection[0], reason=rejection[1]) + return + except asyncio.CancelledError: + pass # Server is shutting down, exit gracefully + except Exception: # pylint: disable=broad-exception-caught + pass # Other exceptions treated as disconnect + finally: + self._active_websockets.discard(ws_obj) + await self._cleanup_ws_tasks(pending_get_props, callback_tasks) + + async def _execute_ws_callback( + self, + dash_app: "Dash", + ws, + message: dict, + pending_get_props: Dict[str, asyncio.Future], + ) -> dict: + """Execute callback from WebSocket message. + + Args: + dash_app: The Dash application instance + ws: The WebSocket connection + message: The callback request message + pending_get_props: Dict to track pending get_props requests + + Returns: + Response dict with status and data + """ + payload = message.get("payload", {}) + + # Validate that the callback is allowed to use WebSocket transport + # pylint: disable=protected-access + _validate.validate_websocket_callback_request( + payload.get("output"), + dash_app.callback_map, + dash_app._websocket_callbacks, + ) + # pylint: enable=protected-access + + # Create WebSocket callback context + renderer_id = message.get("rendererId", "") + cb_ctx = self._create_ws_context(ws, payload, pending_get_props, renderer_id) + + try: + # Reuse existing callback machinery + # pylint: disable=protected-access + func = dash_app._prepare_callback(cb_ctx, payload) + args = dash_app._inputs_to_vals(cb_ctx.inputs_list + cb_ctx.states_list) + ctx = copy_context() + partial_func = dash_app._execute_callback( + func, args, cb_ctx.outputs_list, cb_ctx + ) + # pylint: enable=protected-access + response_data = ctx.run(partial_func) + if inspect.iscoroutine(response_data): + response_data = await response_data + + return {"status": "ok", "data": json.loads(response_data)} + + except PreventUpdate: + return {"status": "prevent_update"} + except Exception as e: # pylint: disable=broad-exception-caught + traceback.print_exc() + return {"status": "error", "message": str(e)} + + def _create_ws_context( + self, + ws, + payload: dict, + pending_get_props: Dict[str, asyncio.Future], + renderer_id: str = "", + ): + """Create callback context from WebSocket message.""" + return create_ws_context( + payload, + QuartResponseAdapter(), + QuartWebsocketCallback(ws, pending_get_props, renderer_id), + ) + class QuartRequestAdapter(RequestAdapter): def __init__(self) -> None: diff --git a/dash/backends/base_server.py b/dash/backends/base_server.py index 283a3414e2..2c63089d77 100644 --- a/dash/backends/base_server.py +++ b/dash/backends/base_server.py @@ -4,6 +4,8 @@ for different web server backends (Flask, Quart, FastAPI, etc.) to integrate with Dash. """ from abc import ABC, abstractmethod +import asyncio +import uuid from typing import Any, Dict, Type, TypeVar, Generic, Protocol, TYPE_CHECKING @@ -385,42 +387,131 @@ def serve_websocket_callback(self, dash_app: "dash.Dash"): class DashWebsocketCallback(ABC): - """Abstract interface for WebSocket-based callback communication. + """Abstract base for WebSocket-based callback communication. Provides methods for real-time bidirectional communication between the server and renderer during callback execution. + + Subclasses must implement _send_json and _close_websocket for their + specific WebSocket implementation. """ - @abstractmethod - async def get_prop(self, component_id: str, prop_name: str) -> Any: - """Request current prop value from the client. + def __init__( + self, + pending_get_props: Dict[str, asyncio.Future], + renderer_id: str = "", + ): + """Initialize the WebSocket callback interface. Args: - component_id: The component ID (string or stringified dict for pattern matching) - prop_name: The property name to retrieve - - Returns: - The current value of the property from the client's state + pending_get_props: Dict to track pending get_props requests + renderer_id: The renderer ID for routing messages back to the correct client """ + self._pending_get_props = pending_get_props + self._renderer_id = renderer_id + + @abstractmethod + async def _send_json(self, data: dict) -> None: + """Send JSON data over the WebSocket. Must be implemented by subclasses.""" @abstractmethod + async def _close_websocket(self, code: int, reason: str) -> None: + """Close the WebSocket connection. Must be implemented by subclasses.""" + async def set_prop(self, component_id: str, prop_name: str, value: Any) -> None: """Send immediate prop update to the client via WebSocket. Args: - component_id: The component ID (string or stringified dict for pattern matching) + component_id: The component ID (string or stringified dict) prop_name: The property name to update value: The new value to set """ + await self._send_json( + { + "type": "set_props", + "rendererId": self._renderer_id, + "payload": {"componentId": component_id, "props": {prop_name: value}}, + } + ) + + async def get_prop(self, component_id: str, prop_name: str) -> Any: + """Request current prop value from the client. + + Args: + component_id: The component ID (string or stringified dict) + prop_name: The property name to retrieve + + Returns: + The current value of the property from the client's state + """ + request_id = str(uuid.uuid4()) + + # Create a future to wait for the response + future: asyncio.Future = asyncio.get_event_loop().create_future() + self._pending_get_props[request_id] = future + + # Send the request + await self._send_json( + { + "type": "get_props_request", + "rendererId": self._renderer_id, + "requestId": request_id, + "payload": {"componentId": component_id, "properties": [prop_name]}, + } + ) + + # Wait for the response with timeout + try: + result = await asyncio.wait_for(future, timeout=30.0) + if result and prop_name in result: + return result[prop_name] + return None + except asyncio.TimeoutError as exc: + self._pending_get_props.pop(request_id, None) + raise TimeoutError( + f"Timeout waiting for get_prop response for {component_id}.{prop_name}" + ) from exc - @abstractmethod async def close(self, code: int = 1000, reason: str = "Connection closed") -> None: """Close the WebSocket connection. - Allows developers to forcibly disconnect a client, e.g., on suspicious - activity, session revocation, or policy violation. - Args: code: WebSocket close code (default 1000 for normal closure) reason: Human-readable reason for closing """ + await self._close_websocket(code, reason) + + +def create_ws_context( + payload: dict, + response_adapter: ResponseAdapter, + websocket_callback: DashWebsocketCallback, +): + """Create callback context from WebSocket message. + + Args: + payload: The callback payload + response_adapter: The response adapter instance for the backend + websocket_callback: The websocket callback instance for the backend + + Returns: + AttributeDict with callback context + """ + # pylint: disable=import-outside-toplevel + from dash._utils import AttributeDict, inputs_to_dict + + g = AttributeDict({}) + g.inputs_list = payload.get("inputs", []) + g.states_list = payload.get("state", []) + g.outputs_list = payload.get("outputs", []) + g.input_values = inputs_to_dict(g.inputs_list) + g.state_values = inputs_to_dict(g.states_list) + g.triggered_inputs = [ + {"prop_id": x, "value": g.input_values.get(x)} + for x in payload.get("changedPropIds", []) + ] + g.dash_response = response_adapter + g.updated_props = {} + g.dash_websocket = websocket_callback + + return g diff --git a/dash/dash.py b/dash/dash.py index ddae7896b4..ce02a7f9bd 100644 --- a/dash/dash.py +++ b/dash/dash.py @@ -1017,7 +1017,11 @@ def _get_worker_url(self) -> str: Returns: The fingerprinted URL for the worker script served via component suites. """ - relative_path = "dash-renderer/build/dash-ws-worker.js" + # Use dev worker when serving dev bundles (has source maps, visible in devtools) + if self._dev_tools.serve_dev_bundles: + relative_path = "dash-renderer/build/dash-ws-worker.dev.js" + else: + relative_path = "dash-renderer/build/dash-ws-worker.js" namespace = "dash" # Register the path so it can be served diff --git a/dash/testing/application_runners.py b/dash/testing/application_runners.py index f6aa8efe3a..51a938f72f 100644 --- a/dash/testing/application_runners.py +++ b/dash/testing/application_runners.py @@ -177,7 +177,10 @@ def run(): # FastAPI support if module.startswith("fastapi"): app.run(**options) - # Dash/Flask/Quart fallback + # Quart support (ASGI - runs its own async event loop) + elif module.startswith("quart"): + app.run(**options) + # Flask fallback (WSGI - needs threaded mode) else: app.run(threaded=True, **options) except SystemExit: @@ -249,7 +252,10 @@ def target(): # FastAPI support if module.startswith("fastapi"): app.run(**options) - # Dash/Flask/Quart fallback + # Quart support (ASGI - runs its own async event loop) + elif module.startswith("quart"): + app.run(**options) + # Flask fallback (WSGI - needs threaded mode) else: app.run(threaded=True, **options) except SystemExit: diff --git a/tests/websocket/test_ws_quart.py b/tests/websocket/test_ws_quart.py new file mode 100644 index 0000000000..30e33b329c --- /dev/null +++ b/tests/websocket/test_ws_quart.py @@ -0,0 +1,228 @@ +""" +Quart WebSocket callback tests. + +Tests the Quart backend websocket implementation which mirrors the FastAPI backend. +""" + +from dash import Dash, html, dcc, Input, Output, State, ctx + + +def test_wsq001_per_callback_websocket_quart(dash_duo): + """Test single callback with websocket=True on Quart backend.""" + app = Dash(__name__, backend="quart") + + app.layout = html.Div( + [ + html.H1("Per-Callback WebSocket Test (Quart)"), + dcc.Input(id="ws-input", type="text", placeholder="Type here..."), + html.Div(id="ws-output"), + ] + ) + + @app.callback( + Output("ws-output", "children"), Input("ws-input", "value"), websocket=True + ) + def ws_callback(value): + return f"WS: {value or ''}" + + dash_duo.start_server(app) + + # Test initial state (trailing space is trimmed by HTML rendering) + dash_duo.wait_for_text_to_equal("#ws-output", "WS:") + + # Type into the input and verify callback executes + input_elem = dash_duo.find_element("#ws-input") + input_elem.send_keys("hello") + + dash_duo.wait_for_text_to_equal("#ws-output", "WS: hello") + assert dash_duo.get_logs() == [] + + +def test_wsq002_global_websocket_callbacks_quart(dash_duo): + """Test global websocket_callbacks=True enables WebSocket for all callbacks on Quart.""" + app = Dash( + __name__, + backend="quart", + websocket_callbacks=True, + ) + + app.layout = html.Div( + [ + html.Button("Click me", id="btn", n_clicks=0), + html.Div(id="output"), + dcc.Input(id="input", type="text"), + html.Div(id="input-output"), + ] + ) + + @app.callback(Output("output", "children"), Input("btn", "n_clicks")) + def on_click(n_clicks): + return f"Clicked {n_clicks} times" + + @app.callback(Output("input-output", "children"), Input("input", "value")) + def on_input(value): + return f"Input: {value or ''}" + + dash_duo.start_server(app) + + # Test button callback + dash_duo.wait_for_text_to_equal("#output", "Clicked 0 times") + dash_duo.find_element("#btn").click() + dash_duo.wait_for_text_to_equal("#output", "Clicked 1 times") + + # Test input callback + dash_duo.find_element("#input").send_keys("test") + dash_duo.wait_for_text_to_equal("#input-output", "Input: test") + + assert dash_duo.get_logs() == [] + + +def test_wsq003_mixed_http_and_websocket_quart(dash_duo): + """Test mixing WebSocket and HTTP callbacks in the same app on Quart.""" + app = Dash(__name__, backend="quart") + + app.layout = html.Div( + [ + # WebSocket callback section + html.Div( + [ + dcc.Input(id="ws-input", type="text"), + html.Div(id="ws-output"), + ] + ), + # HTTP callback section (default) + html.Div( + [ + dcc.Input(id="http-input", type="text"), + html.Div(id="http-output"), + ] + ), + ] + ) + + @app.callback( + Output("ws-output", "children"), Input("ws-input", "value"), websocket=True + ) + def ws_callback(value): + return f"[WebSocket] {value or ''}" + + @app.callback(Output("http-output", "children"), Input("http-input", "value")) + def http_callback(value): + return f"[HTTP] {value or ''}" + + dash_duo.start_server(app) + + # Test WebSocket callback + dash_duo.find_element("#ws-input").send_keys("ws-test") + dash_duo.wait_for_text_to_equal("#ws-output", "[WebSocket] ws-test") + + # Test HTTP callback + dash_duo.find_element("#http-input").send_keys("http-test") + dash_duo.wait_for_text_to_equal("#http-output", "[HTTP] http-test") + + assert dash_duo.get_logs() == [] + + +def test_wsq004_websocket_with_state_quart(dash_duo): + """Test WebSocket callback with State inputs on Quart.""" + app = Dash(__name__, backend="quart", websocket_callbacks=True) + + app.layout = html.Div( + [ + dcc.Input(id="state-input", type="text", value="initial"), + html.Button("Submit", id="btn"), + html.Div(id="output"), + ] + ) + + @app.callback( + Output("output", "children"), + Input("btn", "n_clicks"), + State("state-input", "value"), + ) + def on_click(n_clicks, state_value): + if not n_clicks: + return "Click to submit" + return f"Submitted: {state_value}" + + dash_duo.start_server(app) + + dash_duo.wait_for_text_to_equal("#output", "Click to submit") + + # Update state input + state_input = dash_duo.find_element("#state-input") + dash_duo.clear_input(state_input) + state_input.send_keys("new value") + + # Click button to trigger callback + dash_duo.find_element("#btn").click() + dash_duo.wait_for_text_to_equal("#output", "Submitted: new value") + + assert dash_duo.get_logs() == [] + + +def test_wsq005_websocket_context_available_quart(dash_duo): + """Test that WebSocket context is available in WebSocket callbacks on Quart.""" + app = Dash(__name__, backend="quart", websocket_callbacks=True) + + app.layout = html.Div( + [ + html.Button("Check context", id="btn"), + html.Div(id="output"), + ] + ) + + @app.callback(Output("output", "children"), Input("btn", "n_clicks")) + def check_context(n_clicks): + if not n_clicks: + return "Click to check" + ws = ctx.get_websocket + if ws is not None: + return "WebSocket context available" + return "No WebSocket context" + + dash_duo.start_server(app) + + dash_duo.wait_for_text_to_equal("#output", "Click to check") + dash_duo.find_element("#btn").click() + dash_duo.wait_for_text_to_equal("#output", "WebSocket context available") + + assert dash_duo.get_logs() == [] + + +def test_wsq006_websocket_multiple_outputs_quart(dash_duo): + """Test WebSocket callback with multiple outputs on Quart.""" + app = Dash(__name__, backend="quart", websocket_callbacks=True) + + app.layout = html.Div( + [ + html.Button("Update", id="btn"), + html.Div(id="output1"), + html.Div(id="output2"), + html.Div(id="output3"), + ] + ) + + @app.callback( + Output("output1", "children"), + Output("output2", "children"), + Output("output3", "children"), + Input("btn", "n_clicks"), + ) + def multi_output(n_clicks): + n = n_clicks or 0 + return f"First: {n}", f"Second: {n * 2}", f"Third: {n * 3}" + + dash_duo.start_server(app) + + dash_duo.wait_for_text_to_equal("#output1", "First: 0") + dash_duo.wait_for_text_to_equal("#output2", "Second: 0") + dash_duo.wait_for_text_to_equal("#output3", "Third: 0") + + dash_duo.find_element("#btn").click() + + dash_duo.wait_for_text_to_equal("#output1", "First: 1") + dash_duo.wait_for_text_to_equal("#output2", "Second: 2") + dash_duo.wait_for_text_to_equal("#output3", "Third: 3") + + assert dash_duo.get_logs() == [] From 2070708edd4fb2f21d33914ae970b595d61c6a61 Mon Sep 17 00:00:00 2001 From: philippe Date: Fri, 24 Apr 2026 14:34:10 -0400 Subject: [PATCH 134/166] fix tests --- .github/workflows/testing.yml | 4 +++- dash/backends/_quart.py | 8 ++++++-- 2 files changed, 9 insertions(+), 3 deletions(-) diff --git a/.github/workflows/testing.yml b/.github/workflows/testing.yml index b755510fb6..4d9615e86b 100644 --- a/.github/workflows/testing.yml +++ b/.github/workflows/testing.yml @@ -51,6 +51,8 @@ jobs: - 'tests/backend_tests/**' websocket_paths: - 'dash/backends/_fastapi.py' + - 'dash/backends/_quart.py' + - 'dash/backends/base_server.py' - 'dash/_callback.py' - 'dash/_callback_context.py' - 'dash/_hooks.py' @@ -583,7 +585,7 @@ jobs: run: | python -m pip install --upgrade pip wheel python -m pip install "setuptools<80.0.0" - find packages -name dash-*.whl -print -exec sh -c 'pip install "{}[ci,testing,dev,fastapi]"' \; + find packages -name dash-*.whl -print -exec sh -c 'pip install "{}[ci,testing,dev,fastapi,quart]"' \; - name: Setup Chrome and ChromeDriver uses: browser-actions/setup-chrome@v1 diff --git a/dash/backends/_quart.py b/dash/backends/_quart.py index 3e2c8b79e1..b4d2ab8464 100644 --- a/dash/backends/_quart.py +++ b/dash/backends/_quart.py @@ -261,8 +261,12 @@ def has_request_context(self) -> bool: def run(self, dash_app: Dash, host: str, port: int, debug: bool, **kwargs: _t.Any): import signal # pylint: disable=import-outside-toplevel import threading # pylint: disable=import-outside-toplevel - from hypercorn.config import Config # pylint: disable=import-outside-toplevel - from hypercorn.asyncio import serve # pylint: disable=import-outside-toplevel + + # pylint: disable=import-outside-toplevel,import-error + from hypercorn.config import Config + from hypercorn.asyncio import serve + + # pylint: enable=import-error self.config = {"debug": debug, **kwargs} if debug else kwargs # pylint: disable=protected-access From 750a585f875230642cf373188909ebae2fe4f98f Mon Sep 17 00:00:00 2001 From: philippe Date: Fri, 24 Apr 2026 15:28:15 -0400 Subject: [PATCH 135/166] remove dev bundle reference --- dash/dash.py | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/dash/dash.py b/dash/dash.py index ce02a7f9bd..ddae7896b4 100644 --- a/dash/dash.py +++ b/dash/dash.py @@ -1017,11 +1017,7 @@ def _get_worker_url(self) -> str: Returns: The fingerprinted URL for the worker script served via component suites. """ - # Use dev worker when serving dev bundles (has source maps, visible in devtools) - if self._dev_tools.serve_dev_bundles: - relative_path = "dash-renderer/build/dash-ws-worker.dev.js" - else: - relative_path = "dash-renderer/build/dash-ws-worker.js" + relative_path = "dash-renderer/build/dash-ws-worker.js" namespace = "dash" # Register the path so it can be served From deda6700385455101ac65762d55a32b5522189d7 Mon Sep 17 00:00:00 2001 From: philippe Date: Fri, 24 Apr 2026 16:49:20 -0400 Subject: [PATCH 136/166] version 4.2.0rc1 --- CHANGELOG.md | 5 +++++ dash/version.py | 2 +- 2 files changed, 6 insertions(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index be8d80eaf3..d9b25edf2f 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -18,6 +18,11 @@ This project adheres to [Semantic Versioning](https://semver.org/). - [#3740](https://github.com/plotly/dash/pull/3740) Fix cannot tab into dropdowns in Safari - [#2462](https://github.com/plotly/dash/issues/2462) Allow `MATCH` in `Input`/`State` when the callback's `Output` has no wildcards (fixed-id Output, no Output, or `ALL`-only wildcard Output). `ALLSMALLER` still requires a corresponding `MATCH` in an Output. +## [4.2.0rc1] - 2026-04-13 + +## Added +- [#3742](https://github.com/plotly/dash/pull/3742) Add websocket callbacks to fastapi and quart backends. + ## [4.1.0] - 2026-03-23 ## Added diff --git a/dash/version.py b/dash/version.py index 25b76de3c3..6287d2530f 100644 --- a/dash/version.py +++ b/dash/version.py @@ -1 +1 @@ -__version__ = "4.2.0rc0" +__version__ = "4.2.0rc1" From e4849eb61127465f767dd27bb504224c1d4c6646 Mon Sep 17 00:00:00 2001 From: CNFeffery Date: Sun, 26 Apr 2026 10:39:21 +0800 Subject: [PATCH 137/166] Fix websocket callback set_props() with Patch object problems --- dash/_callback_context.py | 11 +++++++---- .../src/observers/websocketObserver.ts | 14 +++++++++++--- 2 files changed, 18 insertions(+), 7 deletions(-) diff --git a/dash/_callback_context.py b/dash/_callback_context.py index 4f296bde66..47fcc4cc64 100644 --- a/dash/_callback_context.py +++ b/dash/_callback_context.py @@ -9,12 +9,13 @@ from . import exceptions from ._get_app import get_app -from ._utils import AttributeDict, stringify_id +from ._patch import Patch +from ._utils import AttributeDict, stringify_id, to_json -context_value: contextvars.ContextVar[ - typing.Dict[str, typing.Any] -] = contextvars.ContextVar("callback_context") +context_value: contextvars.ContextVar[typing.Dict[str, typing.Any]] = ( + contextvars.ContextVar("callback_context") +) context_value.set({}) @@ -370,6 +371,8 @@ def set_props(component_id: typing.Union[str, dict], props: dict): async def _send_props(): for prop_name, value in props.items(): + if isinstance(value, Patch): + value = json.loads(to_json(value)) await ws.set_prop(_id, prop_name, value) # If we're in an async context, schedule the coroutine diff --git a/dash/dash-renderer/src/observers/websocketObserver.ts b/dash/dash-renderer/src/observers/websocketObserver.ts index 26201eab91..cec3cbdc85 100644 --- a/dash/dash-renderer/src/observers/websocketObserver.ts +++ b/dash/dash-renderer/src/observers/websocketObserver.ts @@ -9,6 +9,7 @@ import {path} from 'ramda'; import {IStoreState} from '../store'; import {updateProps, notifyObservers} from '../actions'; +import {parsePatchProps} from '../actions/patch'; import {getPath} from '../actions/paths'; import { getWorkerClient, @@ -73,7 +74,7 @@ export async function initializeWebSocket( // Handle SET_PROPS messages workerClient.onSetProps = (payload: SetPropsPayload) => { - const {componentId, props} = payload; + const {componentId, props: rawProps} = payload; const parsedId = parseComponentId(componentId); const state = store.getState(); const componentPath = getPath(state.paths, parsedId); @@ -85,17 +86,24 @@ export async function initializeWebSocket( return; } + // Get old props for Patch processing + const oldProps = (path([...componentPath, 'props'], state.layout) || + {}) as Record; + + // Process props to handle Patch objects + const processedProps = parsePatchProps(rawProps, oldProps); + // Update the component props store.dispatch( updateProps({ - props, + props: processedProps, itempath: componentPath, renderType: 'websocket' }) as any ); // Notify observers - store.dispatch(notifyObservers({id: parsedId, props}) as any); + store.dispatch(notifyObservers({id: parsedId, props: processedProps}) as any); }; // Handle GET_PROPS_REQUEST messages From 70cbf48383c4a157b1087fe6dde18f9756860f4f Mon Sep 17 00:00:00 2001 From: CNFeffery Date: Sun, 26 Apr 2026 10:39:40 +0800 Subject: [PATCH 138/166] Add websocket callback set_props patch tests --- tests/websocket/test_ws_patch.py | 42 ++++++++++++++++++++++++++++++++ 1 file changed, 42 insertions(+) create mode 100644 tests/websocket/test_ws_patch.py diff --git a/tests/websocket/test_ws_patch.py b/tests/websocket/test_ws_patch.py new file mode 100644 index 0000000000..83d3aef0a0 --- /dev/null +++ b/tests/websocket/test_ws_patch.py @@ -0,0 +1,42 @@ +""" +WebSocket set_props with Patch object test. + +Verifies that set_props works with Patch objects in websocket callbacks. +""" + +from dash import Dash, html, Input, Output, set_props, Patch +from dash.exceptions import PreventUpdate + + +def test_ws037_set_props_with_patch(dash_duo): + """Test set_props with Patch object in websocket callback.""" + app = Dash(__name__, backend="fastapi", websocket_callbacks=True) + + app.layout = html.Div( + [ + html.Button("Patch", id="btn"), + html.Div("initial", id="output"), + html.Div(id="result"), + ] + ) + + @app.callback( + Output("result", "children"), Input("btn", "n_clicks"), websocket=True + ) + def patch_append(n): + if not n: + raise PreventUpdate + + p = Patch() + p += f" + click {n}" + + set_props("output", {"children": p}) + return f"Appended {n}" + + dash_duo.start_server(app) + + dash_duo.find_element("#btn").click() + + dash_duo.wait_for_text_to_equal("#output", "initial + click 1", timeout=10) + + assert dash_duo.get_logs() == [] \ No newline at end of file From 68a696a84bafa9c3c94feaef38af3fd69361f997 Mon Sep 17 00:00:00 2001 From: CNFeffery Date: Sun, 26 Apr 2026 10:47:27 +0800 Subject: [PATCH 139/166] Update CHANGELOG --- CHANGELOG.md | 1 + 1 file changed, 1 insertion(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index d9b25edf2f..8404a7f124 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -17,6 +17,7 @@ This project adheres to [Semantic Versioning](https://semver.org/). - [#3723](https://github.com/plotly/dash/pull/3723) Fix misaligned `dcc.Slider` marks when some labels are empty strings - [#3740](https://github.com/plotly/dash/pull/3740) Fix cannot tab into dropdowns in Safari - [#2462](https://github.com/plotly/dash/issues/2462) Allow `MATCH` in `Input`/`State` when the callback's `Output` has no wildcards (fixed-id Output, no Output, or `ALL`-only wildcard Output). `ALLSMALLER` still requires a corresponding `MATCH` in an Output. +- [#3759](https://github.com/plotly/dash/pull/3759) Fix the issue where `Patch` objects cannot be updated via `set_props()` in WebSocket callbacks. Fix [#3742](https://github.com/plotly/dash/issues/3742) ## [4.2.0rc1] - 2026-04-13 From 65f3dc2ea70a40d5c0206668a6e37d9fb38a0be6 Mon Sep 17 00:00:00 2001 From: CNFeffery Date: Sun, 26 Apr 2026 10:49:31 +0800 Subject: [PATCH 140/166] Format code --- dash/_callback_context.py | 6 +++--- dash/dash-renderer/src/observers/websocketObserver.ts | 4 +++- tests/websocket/test_ws_patch.py | 2 +- 3 files changed, 7 insertions(+), 5 deletions(-) diff --git a/dash/_callback_context.py b/dash/_callback_context.py index 47fcc4cc64..871afb90da 100644 --- a/dash/_callback_context.py +++ b/dash/_callback_context.py @@ -13,9 +13,9 @@ from ._utils import AttributeDict, stringify_id, to_json -context_value: contextvars.ContextVar[typing.Dict[str, typing.Any]] = ( - contextvars.ContextVar("callback_context") -) +context_value: contextvars.ContextVar[ + typing.Dict[str, typing.Any] +] = contextvars.ContextVar("callback_context") context_value.set({}) diff --git a/dash/dash-renderer/src/observers/websocketObserver.ts b/dash/dash-renderer/src/observers/websocketObserver.ts index cec3cbdc85..c32079dab2 100644 --- a/dash/dash-renderer/src/observers/websocketObserver.ts +++ b/dash/dash-renderer/src/observers/websocketObserver.ts @@ -103,7 +103,9 @@ export async function initializeWebSocket( ); // Notify observers - store.dispatch(notifyObservers({id: parsedId, props: processedProps}) as any); + store.dispatch( + notifyObservers({id: parsedId, props: processedProps}) as any + ); }; // Handle GET_PROPS_REQUEST messages diff --git a/tests/websocket/test_ws_patch.py b/tests/websocket/test_ws_patch.py index 83d3aef0a0..f278b5f26e 100644 --- a/tests/websocket/test_ws_patch.py +++ b/tests/websocket/test_ws_patch.py @@ -39,4 +39,4 @@ def patch_append(n): dash_duo.wait_for_text_to_equal("#output", "initial + click 1", timeout=10) - assert dash_duo.get_logs() == [] \ No newline at end of file + assert dash_duo.get_logs() == [] From 64603357471f878df25230b7dee806fcd6b9e93b Mon Sep 17 00:00:00 2001 From: CNFeffery Date: Sun, 26 Apr 2026 22:22:33 +0800 Subject: [PATCH 141/166] Fix websocket callback update component prop via set_props() --- dash/_callback_context.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/dash/_callback_context.py b/dash/_callback_context.py index 871afb90da..0717e65fe5 100644 --- a/dash/_callback_context.py +++ b/dash/_callback_context.py @@ -371,7 +371,8 @@ def set_props(component_id: typing.Union[str, dict], props: dict): async def _send_props(): for prop_name, value in props.items(): - if isinstance(value, Patch): + # Convert Patch and Dash Components to JSON-compatible format + if isinstance(value, Patch) or hasattr(value, "to_plotly_json"): value = json.loads(to_json(value)) await ws.set_prop(_id, prop_name, value) From 4b19d9abbfdb5c1e61f460d28ba6628a63eaaded Mon Sep 17 00:00:00 2001 From: CNFeffery Date: Sun, 26 Apr 2026 22:23:26 +0800 Subject: [PATCH 142/166] Add websocket callback update component prop tests --- tests/websocket/test_ws_props.py | 149 ++++++++++++++++++++++++++++++- 1 file changed, 147 insertions(+), 2 deletions(-) diff --git a/tests/websocket/test_ws_props.py b/tests/websocket/test_ws_props.py index 6b940792b3..898f1604d7 100644 --- a/tests/websocket/test_ws_props.py +++ b/tests/websocket/test_ws_props.py @@ -5,11 +5,11 @@ - set_props streaming during long-running callback - get_prop reads current component value - async set_prop method +- set_props with Patch objects (bug fix for component property updates) """ import asyncio -from dash import Dash, html, Input, Output -from dash._callback_context import set_props +from dash import Dash, html, Input, Output, set_props, Patch from dash.exceptions import PreventUpdate @@ -265,3 +265,148 @@ async def update_with_dict_id(n): dash_duo.wait_for_text_to_equal("#result", "Done 1") assert dash_duo.get_logs() == [] + + +def test_ws045_set_props_with_patch_objects(dash_duo): + """Test set_props with Patch objects - verifies bug fix for component property updates.""" + app = Dash(__name__, backend="fastapi", websocket_callbacks=True) + + app.layout = html.Div( + [ + html.Button("Patch Update", id="btn"), + html.Div("initial text", id="output"), + html.Div(id="result"), + ] + ) + + @app.callback( + Output("result", "children"), Input("btn", "n_clicks"), websocket=True + ) + async def patch_update(n): + if not n: + raise PreventUpdate + + p = Patch() + p += f" - updated {n}" + + set_props("output", {"children": p}) + return f"Completed {n}" + + dash_duo.start_server(app) + + dash_duo.find_element("#btn").click() + + dash_duo.wait_for_text_to_equal("#output", "initial text - updated 1", timeout=10) + dash_duo.wait_for_text_to_equal("#result", "Completed 1") + + assert dash_duo.get_logs() == [] + + +def test_ws046_set_props_multiple_props_with_patch(dash_duo): + """Test set_props with multiple props including Patch objects.""" + app = Dash(__name__, backend="fastapi", websocket_callbacks=True) + + app.layout = html.Div( + [ + html.Button("Multi Patch", id="btn"), + html.Div("start", id="output1"), + html.Div("count: 0", id="output2"), + html.Div(id="result"), + ] + ) + + @app.callback( + Output("result", "children"), Input("btn", "n_clicks"), websocket=True + ) + async def multi_patch_update(n): + if not n: + raise PreventUpdate + + p = Patch() + p += f" + added {n}" + + set_props("output1", {"children": p, "style": {"color": "blue"}}) + set_props("output2", {"children": f"count: {n}"}) + return f"Multi update {n}" + + dash_duo.start_server(app) + + dash_duo.find_element("#btn").click() + + dash_duo.wait_for_text_to_equal("#output1", "start + added 1", timeout=10) + dash_duo.wait_for_text_to_equal("#output2", "count: 1") + dash_duo.wait_for_text_to_equal("#result", "Multi update 1") + + assert dash_duo.get_logs() == [] + + +def test_ws047_set_props_patch_in_sync_callback(dash_duo): + """Test set_props with Patch in synchronous callback.""" + app = Dash(__name__, backend="fastapi", websocket_callbacks=True) + + app.layout = html.Div( + [ + html.Button("Sync Patch", id="btn"), + html.Div("original", id="target"), + html.Div(id="result"), + ] + ) + + @app.callback( + Output("result", "children"), Input("btn", "n_clicks"), websocket=True + ) + def sync_patch_update(n): + if not n: + raise PreventUpdate + + p = Patch() + p += f" sync {n}" + + set_props("target", {"children": p}) + return f"Sync done {n}" + + dash_duo.start_server(app) + + dash_duo.find_element("#btn").click() + + dash_duo.wait_for_text_to_equal("#target", "original sync 1", timeout=10) + dash_duo.wait_for_text_to_equal("#result", "Sync done 1") + + assert dash_duo.get_logs() == [] + + +def test_ws048_set_props_patch_with_dict_id(dash_duo): + """Test set_props with Patch and dict component ID (pattern matching).""" + app = Dash(__name__, backend="fastapi", websocket_callbacks=True) + + app.layout = html.Div( + [ + html.Button("Dict ID Patch", id="btn"), + html.Div("base", id={"type": "item", "index": 0}), + html.Div(id="result"), + ] + ) + + @app.callback( + Output("result", "children"), Input("btn", "n_clicks"), websocket=True + ) + async def dict_id_patch(n): + if not n: + raise PreventUpdate + + p = Patch() + p += f" patched {n}" + + set_props({"type": "item", "index": 0}, {"children": p}) + return f"Patched item {n}" + + dash_duo.start_server(app) + + dash_duo.find_element("#btn").click() + + dash_duo.wait_for_text_to_equal( + '[id=\'{"index":0,"type":"item"}\']', "base patched 1", timeout=10 + ) + dash_duo.wait_for_text_to_equal("#result", "Patched item 1") + + assert dash_duo.get_logs() == [] From 411b467a8c9bf0f67e8744435e830e6bf8e51c8b Mon Sep 17 00:00:00 2001 From: CNFeffery Date: Sun, 26 Apr 2026 22:50:31 +0800 Subject: [PATCH 143/166] Add websocket callback update component prop tests --- tests/websocket/test_ws_props.py | 156 ++++++++++++++----------------- 1 file changed, 71 insertions(+), 85 deletions(-) diff --git a/tests/websocket/test_ws_props.py b/tests/websocket/test_ws_props.py index 898f1604d7..51a9d1eec6 100644 --- a/tests/websocket/test_ws_props.py +++ b/tests/websocket/test_ws_props.py @@ -9,7 +9,7 @@ """ import asyncio -from dash import Dash, html, Input, Output, set_props, Patch +from dash import Dash, html, Input, Output, set_props from dash.exceptions import PreventUpdate @@ -267,146 +267,132 @@ async def update_with_dict_id(n): assert dash_duo.get_logs() == [] -def test_ws045_set_props_with_patch_objects(dash_duo): - """Test set_props with Patch objects - verifies bug fix for component property updates.""" +def test_ws045_set_props_component_prop_children(dash_duo): + """Test set_props updating component props like Div's children with component.""" app = Dash(__name__, backend="fastapi", websocket_callbacks=True) app.layout = html.Div( [ - html.Button("Patch Update", id="btn"), - html.Div("initial text", id="output"), + html.Button("Update Children", id="btn"), + html.Div(id="container"), html.Div(id="result"), ] ) - @app.callback( - Output("result", "children"), Input("btn", "n_clicks"), websocket=True - ) - async def patch_update(n): + @app.callback(Output("result", "children"), Input("btn", "n_clicks")) + async def update_children(n): if not n: raise PreventUpdate - p = Patch() - p += f" - updated {n}" - - set_props("output", {"children": p}) - return f"Completed {n}" + set_props( + "container", + { + "children": html.Div( + [ + html.Span(f"Updated {n}"), + html.B(" - Bold Text"), + ] + ) + }, + ) + return f"Children updated {n}" dash_duo.start_server(app) dash_duo.find_element("#btn").click() - dash_duo.wait_for_text_to_equal("#output", "initial text - updated 1", timeout=10) - dash_duo.wait_for_text_to_equal("#result", "Completed 1") + dash_duo.wait_for_text_to_equal("#container span", "Updated 1", timeout=10) + dash_duo.wait_for_text_to_equal("#container b", "- Bold Text") + dash_duo.wait_for_text_to_equal("#result", "Children updated 1") assert dash_duo.get_logs() == [] -def test_ws046_set_props_multiple_props_with_patch(dash_duo): - """Test set_props with multiple props including Patch objects.""" +def test_ws046_set_props_nested_component_children(dash_duo): + """Test set_props with nested component in children prop.""" app = Dash(__name__, backend="fastapi", websocket_callbacks=True) app.layout = html.Div( [ - html.Button("Multi Patch", id="btn"), - html.Div("start", id="output1"), - html.Div("count: 0", id="output2"), + html.Button("Update Nested", id="btn"), + html.Div(id="wrapper"), html.Div(id="result"), ] ) - @app.callback( - Output("result", "children"), Input("btn", "n_clicks"), websocket=True - ) - async def multi_patch_update(n): + @app.callback(Output("result", "children"), Input("btn", "n_clicks")) + async def update_nested(n): if not n: raise PreventUpdate - p = Patch() - p += f" + added {n}" - - set_props("output1", {"children": p, "style": {"color": "blue"}}) - set_props("output2", {"children": f"count: {n}"}) - return f"Multi update {n}" + set_props( + "wrapper", + { + "children": html.Div( + [ + html.Ul( + [ + html.Li(f"Item {n}.1"), + html.Li(f"Item {n}.2"), + ] + ) + ] + ) + }, + ) + return f"Nested updated {n}" dash_duo.start_server(app) dash_duo.find_element("#btn").click() - dash_duo.wait_for_text_to_equal("#output1", "start + added 1", timeout=10) - dash_duo.wait_for_text_to_equal("#output2", "count: 1") - dash_duo.wait_for_text_to_equal("#result", "Multi update 1") - - assert dash_duo.get_logs() == [] - - -def test_ws047_set_props_patch_in_sync_callback(dash_duo): - """Test set_props with Patch in synchronous callback.""" - app = Dash(__name__, backend="fastapi", websocket_callbacks=True) - - app.layout = html.Div( - [ - html.Button("Sync Patch", id="btn"), - html.Div("original", id="target"), - html.Div(id="result"), - ] - ) - - @app.callback( - Output("result", "children"), Input("btn", "n_clicks"), websocket=True + dash_duo.wait_for_text_to_equal( + "#wrapper ul li:first-child", "Item 1.1", timeout=10 ) - def sync_patch_update(n): - if not n: - raise PreventUpdate - - p = Patch() - p += f" sync {n}" - - set_props("target", {"children": p}) - return f"Sync done {n}" - - dash_duo.start_server(app) - - dash_duo.find_element("#btn").click() - - dash_duo.wait_for_text_to_equal("#target", "original sync 1", timeout=10) - dash_duo.wait_for_text_to_equal("#result", "Sync done 1") + dash_duo.wait_for_text_to_equal("#wrapper ul li:last-child", "Item 1.2") + dash_duo.wait_for_text_to_equal("#result", "Nested updated 1") assert dash_duo.get_logs() == [] -def test_ws048_set_props_patch_with_dict_id(dash_duo): - """Test set_props with Patch and dict component ID (pattern matching).""" +def test_ws047_set_props_children_with_list(dash_duo): + """Test set_props with list of components wrapped in a single component.""" app = Dash(__name__, backend="fastapi", websocket_callbacks=True) app.layout = html.Div( [ - html.Button("Dict ID Patch", id="btn"), - html.Div("base", id={"type": "item", "index": 0}), + html.Button("Update List", id="btn"), + html.Div(id="list-container"), html.Div(id="result"), ] ) - @app.callback( - Output("result", "children"), Input("btn", "n_clicks"), websocket=True - ) - async def dict_id_patch(n): + @app.callback(Output("result", "children"), Input("btn", "n_clicks")) + async def update_list(n): if not n: raise PreventUpdate - p = Patch() - p += f" patched {n}" - - set_props({"type": "item", "index": 0}, {"children": p}) - return f"Patched item {n}" + set_props( + "list-container", + { + "children": html.Div( + [ + html.Div(f"Item 1 - {n}"), + html.Div(f"Item 2 - {n}"), + html.Div(f"Item 3 - {n}"), + ] + ) + }, + ) + return f"List updated {n}" dash_duo.start_server(app) dash_duo.find_element("#btn").click() - dash_duo.wait_for_text_to_equal( - '[id=\'{"index":0,"type":"item"}\']', "base patched 1", timeout=10 - ) - dash_duo.wait_for_text_to_equal("#result", "Patched item 1") + dash_duo.wait_for_text_to_equal("#result", "List updated 1", timeout=10) + assert "Item 1 - 1" in dash_duo.find_element("#list-container").text + assert "Item 2 - 1" in dash_duo.find_element("#list-container").text + assert "Item 3 - 1" in dash_duo.find_element("#list-container").text assert dash_duo.get_logs() == [] From 1abd7bfe9cf3cc44ed62f0011ecc7e38e362ac9e Mon Sep 17 00:00:00 2001 From: CNFeffery Date: Mon, 27 Apr 2026 10:20:47 +0800 Subject: [PATCH 144/166] Update CHANGELOG --- CHANGELOG.md | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 8404a7f124..1676b26acd 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -17,7 +17,8 @@ This project adheres to [Semantic Versioning](https://semver.org/). - [#3723](https://github.com/plotly/dash/pull/3723) Fix misaligned `dcc.Slider` marks when some labels are empty strings - [#3740](https://github.com/plotly/dash/pull/3740) Fix cannot tab into dropdowns in Safari - [#2462](https://github.com/plotly/dash/issues/2462) Allow `MATCH` in `Input`/`State` when the callback's `Output` has no wildcards (fixed-id Output, no Output, or `ALL`-only wildcard Output). `ALLSMALLER` still requires a corresponding `MATCH` in an Output. -- [#3759](https://github.com/plotly/dash/pull/3759) Fix the issue where `Patch` objects cannot be updated via `set_props()` in WebSocket callbacks. Fix [#3742](https://github.com/plotly/dash/issues/3742) +- [#3759](https://github.com/plotly/dash/pull/3759) Fix the issue where `Patch` objects cannot be updated via `set_props()` in `websocket` callback. Fix [#3742](https://github.com/plotly/dash/issues/3742) +- [#3759](https://github.com/plotly/dash/pull/3759) Fix the error when using `set_props()` to update component-type properties in the `websocket` callback. ## [4.2.0rc1] - 2026-04-13 From 96af07d6300f9b1b39b12b3d1ab71949babbeb98 Mon Sep 17 00:00:00 2001 From: CNFeffery Date: Tue, 28 Apr 2026 22:16:01 +0800 Subject: [PATCH 145/166] Enhance websocket set_props with plotly JSON for full prop type compatibility --- dash/_callback_context.py | 4 +--- dash/backends/_fastapi.py | 4 ++-- dash/backends/_quart.py | 4 ++-- dash/backends/base_server.py | 36 ++++++++++++++++++++++++++---------- 4 files changed, 31 insertions(+), 17 deletions(-) diff --git a/dash/_callback_context.py b/dash/_callback_context.py index 0717e65fe5..7fb1a3c447 100644 --- a/dash/_callback_context.py +++ b/dash/_callback_context.py @@ -371,9 +371,7 @@ def set_props(component_id: typing.Union[str, dict], props: dict): async def _send_props(): for prop_name, value in props.items(): - # Convert Patch and Dash Components to JSON-compatible format - if isinstance(value, Patch) or hasattr(value, "to_plotly_json"): - value = json.loads(to_json(value)) + value = json.loads(to_json(value)) await ws.set_prop(_id, prop_name, value) # If we're in an async context, schedule the coroutine diff --git a/dash/backends/_fastapi.py b/dash/backends/_fastapi.py index 65a7a4c442..71c145de59 100644 --- a/dash/backends/_fastapi.py +++ b/dash/backends/_fastapi.py @@ -94,8 +94,8 @@ def __init__( super().__init__(pending_get_props, renderer_id) self._websocket = websocket - async def _send_json(self, data: dict) -> None: - await self._websocket.send_json(data) + async def _send(self, data: str) -> None: + await self._websocket.send({"type": "websocket.send", "text": data}) async def _close_websocket(self, code: int, reason: str) -> None: await self._websocket.close(code=code, reason=reason) diff --git a/dash/backends/_quart.py b/dash/backends/_quart.py index b4d2ab8464..36d833e016 100644 --- a/dash/backends/_quart.py +++ b/dash/backends/_quart.py @@ -92,8 +92,8 @@ def __init__( super().__init__(pending_get_props, renderer_id) self._websocket = ws - async def _send_json(self, data: dict) -> None: - await self._websocket.send_json(data) + async def _send(self, data: str) -> None: + await self._websocket.send({"type": "websocket.send", "text": data}) async def _close_websocket(self, code: int, reason: str) -> None: await self._websocket.close(code=code, reason=reason) diff --git a/dash/backends/base_server.py b/dash/backends/base_server.py index 2c63089d77..135c42079a 100644 --- a/dash/backends/base_server.py +++ b/dash/backends/base_server.py @@ -5,6 +5,7 @@ """ from abc import ABC, abstractmethod import asyncio +import json import uuid from typing import Any, Dict, Type, TypeVar, Generic, Protocol, TYPE_CHECKING @@ -392,7 +393,7 @@ class DashWebsocketCallback(ABC): Provides methods for real-time bidirectional communication between the server and renderer during callback execution. - Subclasses must implement _send_json and _close_websocket for their + Subclasses must implement _send and _close_websocket for their specific WebSocket implementation. """ @@ -411,13 +412,29 @@ def __init__( self._renderer_id = renderer_id @abstractmethod - async def _send_json(self, data: dict) -> None: - """Send JSON data over the WebSocket. Must be implemented by subclasses.""" + async def _send(self, data: str) -> None: + """Send string data over the WebSocket. Must be implemented by subclasses.""" @abstractmethod async def _close_websocket(self, code: int, reason: str) -> None: """Close the WebSocket connection. Must be implemented by subclasses.""" + async def _send_plotly_json(self, value: Any) -> None: + """Serialize and send value to client using plotly JSON serialization. + + Uses to_json for full compatibility with all supported prop types, + then sends the string directly to avoid double serialization. + """ + # pylint: disable=import-outside-toplevel + from dash._utils import to_json + + serialized = to_json(value) + await self._send(serialized) + + async def _send_json(self, data: dict) -> None: + """Send JSON dict over the WebSocket.""" + await self._send(json.dumps(data)) + async def set_prop(self, component_id: str, prop_name: str, value: Any) -> None: """Send immediate prop update to the client via WebSocket. @@ -426,13 +443,12 @@ async def set_prop(self, component_id: str, prop_name: str, value: Any) -> None: prop_name: The property name to update value: The new value to set """ - await self._send_json( - { - "type": "set_props", - "rendererId": self._renderer_id, - "payload": {"componentId": component_id, "props": {prop_name: value}}, - } - ) + payload = { + "type": "set_props", + "rendererId": self._renderer_id, + "payload": {"componentId": component_id, "props": {prop_name: value}}, + } + await self._send_plotly_json(payload) async def get_prop(self, component_id: str, prop_name: str) -> Any: """Request current prop value from the client. From 649eadb4f48f92effe7d553dc37e060c0abab793 Mon Sep 17 00:00:00 2001 From: CNFeffery Date: Tue, 28 Apr 2026 22:17:38 +0800 Subject: [PATCH 146/166] Fix format --- dash/_callback_context.py | 1 - 1 file changed, 1 deletion(-) diff --git a/dash/_callback_context.py b/dash/_callback_context.py index 7fb1a3c447..288ad8ec5d 100644 --- a/dash/_callback_context.py +++ b/dash/_callback_context.py @@ -9,7 +9,6 @@ from . import exceptions from ._get_app import get_app -from ._patch import Patch from ._utils import AttributeDict, stringify_id, to_json From c7b3f74ab227cf2a9a133cefaecb68457945a217 Mon Sep 17 00:00:00 2001 From: CNFeffery Date: Tue, 28 Apr 2026 22:37:03 +0800 Subject: [PATCH 147/166] Optimize websocket set_props by centralizing serialization with _send_plotly_json --- dash/_callback_context.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/dash/_callback_context.py b/dash/_callback_context.py index 288ad8ec5d..4f296bde66 100644 --- a/dash/_callback_context.py +++ b/dash/_callback_context.py @@ -9,7 +9,7 @@ from . import exceptions from ._get_app import get_app -from ._utils import AttributeDict, stringify_id, to_json +from ._utils import AttributeDict, stringify_id context_value: contextvars.ContextVar[ @@ -370,7 +370,6 @@ def set_props(component_id: typing.Union[str, dict], props: dict): async def _send_props(): for prop_name, value in props.items(): - value = json.loads(to_json(value)) await ws.set_prop(_id, prop_name, value) # If we're in an async context, schedule the coroutine From 8b86163f2ac2269e1984ba431b042b8c29c296e2 Mon Sep 17 00:00:00 2001 From: philippe Date: Fri, 1 May 2026 12:11:31 -0400 Subject: [PATCH 148/166] threadpool for ws callback execution --- dash/backends/_fastapi.py | 211 +++++------- dash/backends/_quart.py | 293 ++++++----------- dash/backends/base_server.py | 301 ++++++++++++++---- .../src/observers/websocketObserver.ts | 31 +- r19.py | 222 +++++++++++++ requirements/install.txt | 1 + tests/websocket/test_ws_props.py | 75 ++++- wsapp.py | 1 - wscb.py | 68 ++++ 9 files changed, 810 insertions(+), 393 deletions(-) create mode 100644 r19.py create mode 100644 wscb.py diff --git a/dash/backends/_fastapi.py b/dash/backends/_fastapi.py index 71c145de59..3f6b81a3cd 100644 --- a/dash/backends/_fastapi.py +++ b/dash/backends/_fastapi.py @@ -2,7 +2,9 @@ from contextvars import copy_context, ContextVar import asyncio +import concurrent.futures import json +import queue from typing import TYPE_CHECKING, Any, Callable, Dict import sys import mimetypes @@ -30,6 +32,8 @@ "All dependencies not installed. Please install it with `dash[fastapi]` to use the FastAPI backend." ) from _err +import janus + from dash.fingerprint import check_fingerprint from dash import _validate, get_app from dash.exceptions import PreventUpdate @@ -38,7 +42,10 @@ RequestAdapter, ResponseAdapter, DashWebsocketCallback, - create_ws_context, + run_ws_sender, + run_callback_in_executor, + make_callback_done_handler, + SHUTDOWN_SIGNAL, ) from ._utils import format_traceback_html @@ -82,25 +89,6 @@ def set_response(self, **kwargs): return resp -class FastAPIWebsocketCallback(DashWebsocketCallback): - """WebSocket callback implementation for FastAPI backend.""" - - def __init__( - self, - websocket: WebSocket, - pending_get_props: Dict[str, asyncio.Future], - renderer_id: str = "", - ): - super().__init__(pending_get_props, renderer_id) - self._websocket = websocket - - async def _send(self, data: str) -> None: - await self._websocket.send({"type": "websocket.send", "text": data}) - - async def _close_websocket(self, code: int, reason: str) -> None: - await self._websocket.close(code=code, reason=reason) - - _current_request_var = ContextVar("dash_current_request", default=None) @@ -672,10 +660,13 @@ async def _run_ws_hooks( def serve_websocket_callback(self, dash_app: "Dash"): """Set up the WebSocket endpoint for callback handling. + Uses thread pool executor for callback execution with janus queues + for async/sync communication between main loop and worker threads. + Args: dash_app: The Dash application instance """ - # pylint: disable=too-many-statements + # pylint: disable=too-many-statements,too-many-locals ws_path = dash_app.config.requests_pathname_prefix + "_dash-ws-callback" # Get allowed origins from dash app config @@ -719,29 +710,19 @@ async def websocket_handler(websocket: WebSocket): await websocket.accept() - # Track pending get_props requests - pending_get_props: Dict[str, asyncio.Future] = {} - # Track running callback tasks - callback_tasks: Dict[str, asyncio.Task] = {} - - async def execute_callback_task( - req_message: dict, req_renderer_id: str, req_id: str - ): - """Execute callback and send response.""" - try: - response = await self._execute_ws_callback( - dash_app, websocket, req_message, pending_get_props - ) - await websocket.send_json( - { - "type": "callback_response", - "rendererId": req_renderer_id, - "requestId": req_id, - "payload": response, - } - ) - finally: - callback_tasks.pop(req_id, None) + # Create janus queue for outbound messages (main loop context) + outbound_queue: janus.Queue[str] = janus.Queue() + # Track pending get_props requests with standard queue.Queue for responses + pending_get_props: Dict[str, queue.Queue] = {} + # Get thread pool executor + executor = self.get_callback_executor() + # Track pending callback futures + pending_callbacks: Dict[str, concurrent.futures.Future] = {} + + # Start sender task to drain outbound queue (sends pre-serialized text) + sender_task = asyncio.create_task( + run_ws_sender(websocket.send_text, outbound_queue) + ) try: while True: @@ -759,112 +740,74 @@ async def execute_callback_task( return msg_type = message.get("type") - renderer_id = message.get("rendererId") if msg_type == "callback_request": - # Run callback in background task to allow receiving - # get_props_response messages during execution request_id = message.get("requestId") - task = asyncio.create_task( - execute_callback_task(message, renderer_id, request_id) + renderer_id = message.get("rendererId", "") + payload = message.get("payload", {}) + + # Validate that the callback is allowed to use WebSocket transport + # pylint: disable=protected-access + _validate.validate_websocket_callback_request( + payload.get("output"), + dash_app.callback_map, + dash_app._websocket_callbacks, + ) + + # Create WebSocket callback instance with outbound queue + ws_cb = DashWebsocketCallback( + pending_get_props, renderer_id, outbound_queue + ) + + # Submit callback to executor + future = run_callback_in_executor( + executor, + dash_app, + payload, + ws_cb, + FastAPIResponseAdapter(), + ) + + # Set up done callback to send response + future.add_done_callback( + make_callback_done_handler( + outbound_queue, + pending_callbacks, + request_id, + renderer_id, + ) ) - callback_tasks[request_id] = task + pending_callbacks[request_id] = future elif msg_type == "get_props_response": - # Handle response for pending get_props request + # Put response in waiting queue (non-blocking) request_id = message.get("requestId") - if request_id in pending_get_props: - future = pending_get_props.pop(request_id) - if not future.done(): - future.set_result(message.get("payload")) + response_queue = pending_get_props.get(request_id) + if response_queue is not None: + response_queue.put_nowait(message.get("payload")) elif msg_type == "heartbeat": - await websocket.send_json({"type": "heartbeat_ack"}) + outbound_queue.sync_q.put_nowait('{"type": "heartbeat_ack"}') except WebSocketDisconnect: pass # Clean disconnect finally: - # Cancel any pending futures and tasks - for future in pending_get_props.values(): - if not future.done(): - future.cancel() - for task in callback_tasks.values(): - if not task.done(): - task.cancel() + # Signal sender to shutdown and cancel it + outbound_queue.sync_q.put_nowait(SHUTDOWN_SIGNAL) + sender_task.cancel() + try: + await sender_task + except asyncio.CancelledError: + pass + # Close the janus queue + outbound_queue.close() + await outbound_queue.wait_closed() + # Cancel any pending futures + for f in pending_callbacks.values(): + f.cancel() self.server.add_api_websocket_route(ws_path, websocket_handler) - async def _execute_ws_callback( - self, - dash_app: "Dash", - websocket: WebSocket, - message: dict, - pending_get_props: Dict[str, asyncio.Future], - ) -> dict: - """Execute callback from WebSocket message. - - Args: - dash_app: The Dash application instance - websocket: The WebSocket connection - message: The callback request message - pending_get_props: Dict to track pending get_props requests - - Returns: - Response dict with status and data - """ - payload = message.get("payload", {}) - - # Validate that the callback is allowed to use WebSocket transport - # pylint: disable=protected-access - _validate.validate_websocket_callback_request( - payload.get("output"), - dash_app.callback_map, - dash_app._websocket_callbacks, - ) - # pylint: enable=protected-access - - # Create WebSocket callback context - renderer_id = message.get("rendererId", "") - cb_ctx = self._create_ws_context( - websocket, payload, pending_get_props, renderer_id - ) - - try: - # Reuse existing callback machinery - # pylint: disable=protected-access - func = dash_app._prepare_callback(cb_ctx, payload) - args = dash_app._inputs_to_vals(cb_ctx.inputs_list + cb_ctx.states_list) - ctx = copy_context() - partial_func = dash_app._execute_callback( - func, args, cb_ctx.outputs_list, cb_ctx - ) - # pylint: enable=protected-access - response_data = ctx.run(partial_func) - if inspect.iscoroutine(response_data): - response_data = await response_data - - return {"status": "ok", "data": json.loads(response_data)} - - except PreventUpdate: - return {"status": "prevent_update"} - except Exception as e: # pylint: disable=broad-exception-caught - traceback.print_exc() - return {"status": "error", "message": str(e)} - - def _create_ws_context( - self, - websocket: WebSocket, - payload: dict, - pending_get_props: Dict[str, asyncio.Future], - renderer_id: str = "", - ): - """Create callback context from WebSocket message.""" - return create_ws_context( - payload, - FastAPIResponseAdapter(), - FastAPIWebsocketCallback(websocket, pending_get_props, renderer_id), - ) - class FastAPIRequestAdapter(RequestAdapter): def __init__(self): diff --git a/dash/backends/_quart.py b/dash/backends/_quart.py index 36d833e016..89ba4f4d1b 100644 --- a/dash/backends/_quart.py +++ b/dash/backends/_quart.py @@ -7,8 +7,8 @@ import time import sys import asyncio -import json -import traceback +import concurrent.futures +import queue from urllib.parse import urlparse from logging.config import dictConfig @@ -35,6 +35,8 @@ "All dependencies not installed. Please install it with `dash[quart]` to use the Quart backend." ) from _err +import janus + from dash.exceptions import PreventUpdate, InvalidResourceError from dash.fingerprint import check_fingerprint from dash._utils import parse_version @@ -44,7 +46,10 @@ RequestAdapter, ResponseAdapter, DashWebsocketCallback, - create_ws_context, + run_ws_sender, + run_callback_in_executor, + make_callback_done_handler, + SHUTDOWN_SIGNAL, ) from ._utils import format_traceback_html @@ -80,25 +85,6 @@ def set_response(self, **kwargs): return self._quart_response -class QuartWebsocketCallback(DashWebsocketCallback): - """WebSocket callback implementation for Quart backend.""" - - def __init__( - self, - ws, - pending_get_props: Dict[str, asyncio.Future], - renderer_id: str = "", - ): - super().__init__(pending_get_props, renderer_id) - self._websocket = ws - - async def _send(self, data: str) -> None: - await self._websocket.send({"type": "websocket.send", "text": data}) - - async def _close_websocket(self, code: int, reason: str) -> None: - await self._websocket.close(code=code, reason=reason) - - class QuartDashServer(BaseDashServer[Quart]): websocket_capability: bool = True @@ -518,109 +504,22 @@ def _validate_ws_origin( return "Origin not allowed" return None - async def _handle_ws_message( - self, - message: dict, - ws, - dash_app: "Dash", - pending_get_props: Dict[str, asyncio.Future], - callback_tasks: Dict[str, asyncio.Task], - ) -> tuple | None: - """Handle a single WebSocket message. Returns rejection tuple or None.""" - # Call websocket_message hooks - # pylint: disable=protected-access - rejection = await self._run_ws_hooks( - dash_app._hooks.get_hooks("websocket_message"), - ws, - message, - default_reason="Message rejected", - ) - if rejection: - return rejection - - msg_type = message.get("type") - - if msg_type == "callback_request": - await self._handle_callback_request( - message, ws, dash_app, pending_get_props, callback_tasks - ) - elif msg_type == "get_props_response": - self._handle_get_props_response(message, pending_get_props) - elif msg_type == "heartbeat": - await ws.send_json({"type": "heartbeat_ack"}) - - return None - - async def _handle_callback_request( - self, - message: dict, - ws, - dash_app: "Dash", - pending_get_props: Dict[str, asyncio.Future], - callback_tasks: Dict[str, asyncio.Task], - ): - """Handle a callback request message.""" - renderer_id = message.get("rendererId") - request_id = message.get("requestId") - - async def execute_and_respond(): - try: - response = await self._execute_ws_callback( - dash_app, ws, message, pending_get_props - ) - await ws.send_json( - { - "type": "callback_response", - "rendererId": renderer_id, - "requestId": request_id, - "payload": response, - } - ) - finally: - callback_tasks.pop(request_id, None) - - task = asyncio.create_task(execute_and_respond()) - callback_tasks[request_id] = task - - def _handle_get_props_response( - self, message: dict, pending_get_props: Dict[str, asyncio.Future] - ): - """Handle a get_props response message.""" - request_id = message.get("requestId") - if request_id in pending_get_props: - future = pending_get_props.pop(request_id) - if not future.done(): - future.set_result(message.get("payload")) - - @staticmethod - async def _cleanup_ws_tasks( - pending_get_props: Dict[str, asyncio.Future], - callback_tasks: Dict[str, asyncio.Task], - ): - """Cancel any pending futures and tasks on disconnect.""" - for future in pending_get_props.values(): - if not future.done(): - future.cancel() - # Cancel and await all callback tasks - for task in callback_tasks.values(): - if not task.done(): - task.cancel() - # Wait for all tasks to complete their cancellation - if callback_tasks: - await asyncio.gather(*callback_tasks.values(), return_exceptions=True) - def serve_websocket_callback(self, dash_app: "Dash"): """Set up the WebSocket endpoint for callback handling. + Uses thread pool executor for callback execution with janus queues + for async/sync communication between main loop and worker threads. + Args: dash_app: The Dash application instance """ + # pylint: disable=too-many-statements,too-many-locals ws_path = dash_app.config.requests_pathname_prefix + "_dash-ws-callback" # pylint: disable=protected-access allowed_origins = getattr(dash_app, "_websocket_allowed_origins", []) @self.server.websocket(ws_path) - async def websocket_handler(): # pylint: disable=too-many-branches + async def websocket_handler(): ws = websocket # Validate Origin header @@ -645,11 +544,24 @@ async def websocket_handler(): # pylint: disable=too-many-branches await ws.accept() # Track this connection for graceful shutdown - ws_obj = ws._get_current_object() - self._active_websockets.add(ws_obj) - - pending_get_props: Dict[str, asyncio.Future] = {} - callback_tasks: Dict[str, asyncio.Task] = {} + try: + ws_obj = ws._get_current_object() + self._active_websockets.add(ws_obj) + except AttributeError: + ws_obj = ws + self._active_websockets.add(ws) + + # Create janus queue for outbound messages (main loop context) + outbound_queue: janus.Queue[str] = janus.Queue() + # Track pending get_props requests with standard queue.Queue for responses + pending_get_props: Dict[str, queue.Queue] = {} + # Get thread pool executor + executor = self.get_callback_executor() + # Track pending callback futures + pending_callbacks: Dict[str, concurrent.futures.Future] = {} + + # Start sender task to drain outbound queue (sends pre-serialized text) + sender_task = asyncio.create_task(run_ws_sender(ws.send, outbound_queue)) try: shutdown_event = self._ws_shutdown_event @@ -661,88 +573,87 @@ async def websocket_handler(): # pylint: disable=too-many-branches # Re-check shutdown event (may have been set during run()) shutdown_event = self._ws_shutdown_event continue - rejection = await self._handle_ws_message( - message, ws, dash_app, pending_get_props, callback_tasks + + # Call websocket_message hooks + rejection = await self._run_ws_hooks( + dash_app._hooks.get_hooks("websocket_message"), + ws, + message, + default_reason="Message rejected", ) if rejection: await ws.close(code=rejection[0], reason=rejection[1]) return + + msg_type = message.get("type") + + if msg_type == "callback_request": + request_id = message.get("requestId") + renderer_id = message.get("rendererId", "") + payload = message.get("payload", {}) + + # Validate that the callback is allowed to use WebSocket transport + # pylint: disable=protected-access + _validate.validate_websocket_callback_request( + payload.get("output"), + dash_app.callback_map, + dash_app._websocket_callbacks, + ) + + # Create WebSocket callback instance with outbound queue + ws_cb = DashWebsocketCallback( + pending_get_props, renderer_id, outbound_queue + ) + + # Submit callback to executor + future = run_callback_in_executor( + executor, + dash_app, + payload, + ws_cb, + QuartResponseAdapter(), + ) + + # Set up done callback to send response + future.add_done_callback( + make_callback_done_handler( + outbound_queue, + pending_callbacks, + request_id, + renderer_id, + ) + ) + pending_callbacks[request_id] = future + + elif msg_type == "get_props_response": + # Put response in waiting queue (non-blocking) + request_id = message.get("requestId") + response_queue = pending_get_props.get(request_id) + if response_queue is not None: + response_queue.put_nowait(message.get("payload")) + + elif msg_type == "heartbeat": + outbound_queue.sync_q.put_nowait('{"type": "heartbeat_ack"}') + except asyncio.CancelledError: pass # Server is shutting down, exit gracefully except Exception: # pylint: disable=broad-exception-caught pass # Other exceptions treated as disconnect finally: self._active_websockets.discard(ws_obj) - await self._cleanup_ws_tasks(pending_get_props, callback_tasks) - - async def _execute_ws_callback( - self, - dash_app: "Dash", - ws, - message: dict, - pending_get_props: Dict[str, asyncio.Future], - ) -> dict: - """Execute callback from WebSocket message. - - Args: - dash_app: The Dash application instance - ws: The WebSocket connection - message: The callback request message - pending_get_props: Dict to track pending get_props requests - - Returns: - Response dict with status and data - """ - payload = message.get("payload", {}) - - # Validate that the callback is allowed to use WebSocket transport - # pylint: disable=protected-access - _validate.validate_websocket_callback_request( - payload.get("output"), - dash_app.callback_map, - dash_app._websocket_callbacks, - ) - # pylint: enable=protected-access - - # Create WebSocket callback context - renderer_id = message.get("rendererId", "") - cb_ctx = self._create_ws_context(ws, payload, pending_get_props, renderer_id) - - try: - # Reuse existing callback machinery - # pylint: disable=protected-access - func = dash_app._prepare_callback(cb_ctx, payload) - args = dash_app._inputs_to_vals(cb_ctx.inputs_list + cb_ctx.states_list) - ctx = copy_context() - partial_func = dash_app._execute_callback( - func, args, cb_ctx.outputs_list, cb_ctx - ) - # pylint: enable=protected-access - response_data = ctx.run(partial_func) - if inspect.iscoroutine(response_data): - response_data = await response_data - - return {"status": "ok", "data": json.loads(response_data)} - - except PreventUpdate: - return {"status": "prevent_update"} - except Exception as e: # pylint: disable=broad-exception-caught - traceback.print_exc() - return {"status": "error", "message": str(e)} - - def _create_ws_context( - self, - ws, - payload: dict, - pending_get_props: Dict[str, asyncio.Future], - renderer_id: str = "", - ): - """Create callback context from WebSocket message.""" - return create_ws_context( - payload, - QuartResponseAdapter(), - QuartWebsocketCallback(ws, pending_get_props, renderer_id), - ) + # Signal sender to shutdown and cancel it + outbound_queue.sync_q.put_nowait(SHUTDOWN_SIGNAL) + sender_task.cancel() + try: + await sender_task + except asyncio.CancelledError: + pass + # Close the janus queue + outbound_queue.close() + await outbound_queue.wait_closed() + # Cancel any pending futures + for f in pending_callbacks.values(): + f.cancel() class QuartRequestAdapter(RequestAdapter): diff --git a/dash/backends/base_server.py b/dash/backends/base_server.py index 135c42079a..bae003ea78 100644 --- a/dash/backends/base_server.py +++ b/dash/backends/base_server.py @@ -5,10 +5,30 @@ """ from abc import ABC, abstractmethod import asyncio +import concurrent.futures +from concurrent.futures import ThreadPoolExecutor +import inspect import json +import queue +import traceback import uuid -from typing import Any, Dict, Type, TypeVar, Generic, Protocol, TYPE_CHECKING - +from contextvars import copy_context +from typing import ( + Any, + Callable, + Dict, + Type, + TypeVar, + Generic, + Protocol, + TYPE_CHECKING, + cast, +) + +import janus + +from dash.exceptions import PreventUpdate +from dash._utils import to_json if TYPE_CHECKING: import dash @@ -182,6 +202,34 @@ def __init__(self, server: ServerType) -> None: """ super().__init__() self.server = server + self._callback_executor: ThreadPoolExecutor | None = None + + def get_callback_executor( + self, max_workers: int | None = None + ) -> ThreadPoolExecutor: + """Get or create the thread pool executor for callback execution. + + Args: + max_workers: Maximum number of worker threads. If None, uses default. + + Returns: + ThreadPoolExecutor instance for running callbacks. + """ + if self._callback_executor is None: + self._callback_executor = ThreadPoolExecutor( + max_workers=max_workers, thread_name_prefix="dash-callback-" + ) + return self._callback_executor + + def shutdown_executor(self, wait: bool = True) -> None: + """Shutdown the callback executor. + + Args: + wait: If True, wait for pending tasks to complete. + """ + if self._callback_executor is not None: + self._callback_executor.shutdown(wait=wait) + self._callback_executor = None def __call__(self, *args, **kwargs) -> Any: """Make the server wrapper callable as a WSGI/ASGI application. @@ -387,115 +435,101 @@ def serve_websocket_callback(self, dash_app: "dash.Dash"): """ -class DashWebsocketCallback(ABC): - """Abstract base for WebSocket-based callback communication. +class DashWebsocketCallback: + """WebSocket callback communication via queues. Provides methods for real-time bidirectional communication between the server and renderer during callback execution. - Subclasses must implement _send and _close_websocket for their - specific WebSocket implementation. + Uses janus.Queue for outbound messages (serialized with to_json) and + queue.Queue for get_props responses, enabling thread-safe communication + between worker threads and the main event loop. """ def __init__( self, - pending_get_props: Dict[str, asyncio.Future], - renderer_id: str = "", + pending_get_props: Dict[str, queue.Queue[Any]], + renderer_id: str, + outbound_queue: janus.Queue[str], ): """Initialize the WebSocket callback interface. Args: - pending_get_props: Dict to track pending get_props requests + pending_get_props: Dict to track pending get_props requests. + Values are queue.Queue instances for blocking response retrieval. renderer_id: The renderer ID for routing messages back to the correct client + outbound_queue: janus.Queue for thread-safe outbound messaging. """ self._pending_get_props = pending_get_props self._renderer_id = renderer_id + self._outbound_queue = outbound_queue - @abstractmethod - async def _send(self, data: str) -> None: - """Send string data over the WebSocket. Must be implemented by subclasses.""" - - @abstractmethod - async def _close_websocket(self, code: int, reason: str) -> None: - """Close the WebSocket connection. Must be implemented by subclasses.""" - - async def _send_plotly_json(self, value: Any) -> None: - """Serialize and send value to client using plotly JSON serialization. + def _queue_message(self, msg: dict) -> None: + """Serialize and queue message for sending (thread-safe, non-blocking). - Uses to_json for full compatibility with all supported prop types, - then sends the string directly to avoid double serialization. + Uses to_json for proper serialization of Dash components. """ - # pylint: disable=import-outside-toplevel - from dash._utils import to_json - - serialized = to_json(value) - await self._send(serialized) - - async def _send_json(self, data: dict) -> None: - """Send JSON dict over the WebSocket.""" - await self._send(json.dumps(data)) + self._outbound_queue.sync_q.put_nowait(cast(str, to_json(msg))) async def set_prop(self, component_id: str, prop_name: str, value: Any) -> None: """Send immediate prop update to the client via WebSocket. + Queues the message for the sender coroutine to send. + Args: component_id: The component ID (string or stringified dict) prop_name: The property name to update value: The new value to set """ - payload = { + msg = { "type": "set_props", "rendererId": self._renderer_id, "payload": {"componentId": component_id, "props": {prop_name: value}}, } - await self._send_plotly_json(payload) + self._queue_message(msg) - async def get_prop(self, component_id: str, prop_name: str) -> Any: + async def get_prop( + self, component_id: str, prop_name: str, timeout: float = 30.0 + ) -> Any: """Request current prop value from the client. + Uses queue.Queue for blocking wait in worker thread. + Args: component_id: The component ID (string or stringified dict) prop_name: The property name to retrieve + timeout: Timeout in seconds for waiting for response Returns: The current value of the property from the client's state """ request_id = str(uuid.uuid4()) + msg = { + "type": "get_props_request", + "rendererId": self._renderer_id, + "requestId": request_id, + "payload": {"componentId": component_id, "properties": [prop_name]}, + } - # Create a future to wait for the response - future: asyncio.Future = asyncio.get_event_loop().create_future() - self._pending_get_props[request_id] = future - - # Send the request - await self._send_json( - { - "type": "get_props_request", - "rendererId": self._renderer_id, - "requestId": request_id, - "payload": {"componentId": component_id, "properties": [prop_name]}, - } - ) - - # Wait for the response with timeout + # Use standard queue.Queue for response + response_queue: queue.Queue = queue.Queue() + self._pending_get_props[request_id] = response_queue + + # Queue the outbound request via janus sync interface + self._queue_message(msg) + + # Wait for response (blocking is OK in worker thread) try: - result = await asyncio.wait_for(future, timeout=30.0) + result = response_queue.get(timeout=timeout) if result and prop_name in result: return result[prop_name] return None - except asyncio.TimeoutError as exc: - self._pending_get_props.pop(request_id, None) + except queue.Empty as exc: raise TimeoutError( - f"Timeout waiting for get_prop response for {component_id}.{prop_name}" + f"Timeout waiting for {component_id}.{prop_name}" ) from exc - - async def close(self, code: int = 1000, reason: str = "Connection closed") -> None: - """Close the WebSocket connection. - - Args: - code: WebSocket close code (default 1000 for normal closure) - reason: Human-readable reason for closing - """ - await self._close_websocket(code, reason) + finally: + self._pending_get_props.pop(request_id, None) def create_ws_context( @@ -531,3 +565,148 @@ def create_ws_context( g.dash_websocket = websocket_callback return g + + +SHUTDOWN_SIGNAL = "__shutdown__" + + +async def run_ws_sender( + send_text: Callable[[str], Any], outbound_queue: janus.Queue[str] +) -> None: + """Sender coroutine - drains queue and sends to WebSocket. + + This coroutine runs in the main event loop and handles sending + messages that are queued by worker threads via janus.Queue. + + Messages are pre-serialized strings (using to_json). + + Args: + send_text: Async function to send text data over WebSocket + outbound_queue: janus.Queue instance for receiving messages (strings) + """ + try: + while True: + msg = await outbound_queue.async_q.get() + if msg == SHUTDOWN_SIGNAL: + break + await send_text(msg) + except asyncio.CancelledError: + pass + + +def make_callback_done_handler( + outbound_queue: janus.Queue[str], + pending_callbacks: Dict[str, concurrent.futures.Future], + request_id: str, + renderer_id: str, +) -> Callable[[concurrent.futures.Future], None]: + """Create a done callback handler for executor futures. + + This factory creates a callback that sends the result back through + the WebSocket when an executor future completes. + + Args: + outbound_queue: janus.Queue for sending responses + pending_callbacks: Dict tracking pending callbacks for cleanup + request_id: The request ID for the callback response + renderer_id: The renderer ID for routing the response + + Returns: + A callback function suitable for Future.add_done_callback() + """ + + def on_done(f: concurrent.futures.Future) -> None: + try: + result = f.result() + outbound_queue.sync_q.put_nowait( + cast( + str, + to_json( + { + "type": "callback_response", + "rendererId": renderer_id, + "requestId": request_id, + "payload": result, + } + ), + ) + ) + except Exception as e: # pylint: disable=broad-exception-caught + outbound_queue.sync_q.put_nowait( + cast( + str, + to_json( + { + "type": "callback_response", + "rendererId": renderer_id, + "requestId": request_id, + "payload": { + "status": "error", + "message": str(e), + }, + } + ), + ) + ) + finally: + pending_callbacks.pop(request_id, None) + + return on_done + + +def run_callback_in_executor( + executor: ThreadPoolExecutor, + dash_app: "dash.Dash", + payload: dict, + ws_callback: DashWebsocketCallback, + response_adapter: ResponseAdapter, +) -> concurrent.futures.Future: + """Submit callback to executor for thread pool execution. + + This function creates a callback execution context and runs it + in a separate thread. Both sync and async callbacks are supported. + + Args: + executor: ThreadPoolExecutor to submit the task to + dash_app: The Dash application instance + payload: The callback payload from WebSocket message + ws_callback: WebSocket callback instance for set_prop/get_prop + response_adapter: Response adapter for the backend + + Returns: + Future representing the pending callback execution + """ + + def execute() -> dict: + try: + cb_ctx = create_ws_context(payload, response_adapter, ws_callback) + # pylint: disable=protected-access + func = dash_app._prepare_callback(cb_ctx, payload) + args = dash_app._inputs_to_vals( # pylint: disable=protected-access + cb_ctx.inputs_list + cb_ctx.states_list + ) + + ctx = copy_context() + partial_func = ( + dash_app._execute_callback( # pylint: disable=protected-access + func, args, cb_ctx.outputs_list, cb_ctx + ) + ) + + # Run in new event loop (handles both sync and async callbacks) + def run_callback(): + result = partial_func() + if inspect.iscoroutine(result): + return asyncio.run(result) + return result + + response_data = ctx.run(run_callback) + return {"status": "ok", "data": json.loads(response_data)} + + except PreventUpdate: + return {"status": "prevent_update"} + except Exception as e: # pylint: disable=broad-exception-caught + traceback.print_exc() + return {"status": "error", "message": str(e)} + + return executor.submit(execute) diff --git a/dash/dash-renderer/src/observers/websocketObserver.ts b/dash/dash-renderer/src/observers/websocketObserver.ts index c32079dab2..7b75fada38 100644 --- a/dash/dash-renderer/src/observers/websocketObserver.ts +++ b/dash/dash-renderer/src/observers/websocketObserver.ts @@ -8,9 +8,9 @@ import {Store} from 'redux'; import {path} from 'ramda'; import {IStoreState} from '../store'; -import {updateProps, notifyObservers} from '../actions'; +import {updateProps, notifyObservers, setPaths} from '../actions'; import {parsePatchProps} from '../actions/patch'; -import {getPath} from '../actions/paths'; +import {computePaths, getPath} from '../actions/paths'; import { getWorkerClient, SetPropsPayload, @@ -86,9 +86,12 @@ export async function initializeWebSocket( return; } - // Get old props for Patch processing - const oldProps = (path([...componentPath, 'props'], state.layout) || - {}) as Record; + // Get old component for Patch processing and path recomputation + const oldComponent = path(componentPath, state.layout) as Record< + string, + unknown + > | null; + const oldProps = (oldComponent?.props || {}) as Record; // Process props to handle Patch objects const processedProps = parsePatchProps(rawProps, oldProps); @@ -106,6 +109,24 @@ export async function initializeWebSocket( store.dispatch( notifyObservers({id: parsedId, props: processedProps}) as any ); + + // Recompute paths for any new child components + if (oldComponent) { + const updatedState = store.getState(); + store.dispatch( + setPaths( + computePaths( + { + ...oldComponent, + props: {...oldProps, ...processedProps} + }, + [...componentPath], + updatedState.paths, + updatedState.paths.events + ) + ) as any + ); + } }; // Handle GET_PROPS_REQUEST messages diff --git a/r19.py b/r19.py new file mode 100644 index 0000000000..815d2a5066 --- /dev/null +++ b/r19.py @@ -0,0 +1,222 @@ +""" +React 19 test app with most Dash components. +Run with: python r19.py +""" + +import os +os.environ["REACT_VERSION"] = "19.2.0" + +from dash import Dash, html, dcc, dash_table, callback, Input, Output +import plotly.express as px +import pandas as pd + +# Sample data +df = pd.DataFrame({ + "Fruit": ["Apples", "Oranges", "Bananas", "Grapes", "Strawberries"], + "Amount": [4, 2, 5, 3, 6], + "City": ["NYC", "LA", "Chicago", "Houston", "Phoenix"] +}) + +app = Dash(__name__) + +app.layout = html.Div([ + html.H1("React 19 Component Test"), + html.P(f"Running React version: {os.environ.get('REACT_VERSION')}"), + + html.Hr(), + html.H2("Core HTML Components"), + html.Div([ + html.Button("Click Me", id="button", n_clicks=0), + html.Span(" Clicks: ", style={"marginLeft": "10px"}), + html.Span(id="click-output", children="0"), + ]), + + html.Hr(), + html.H2("Input Components"), + html.Div([ + html.Label("Text Input:"), + dcc.Input(id="text-input", type="text", placeholder="Type something...", debounce=True), + html.Div(id="text-output"), + ], style={"marginBottom": "20px"}), + + html.Div([ + html.Label("Dropdown:"), + dcc.Dropdown( + id="dropdown", + options=[{"label": f, "value": f} for f in df["Fruit"]], + value="Apples", + clearable=True, + ), + html.Div(id="dropdown-output"), + ], style={"marginBottom": "20px", "width": "300px"}), + + html.Div([ + html.Label("Multi-Select Dropdown:"), + dcc.Dropdown( + id="multi-dropdown", + options=[{"label": f, "value": f} for f in df["Fruit"]], + value=["Apples", "Oranges"], + multi=True, + ), + ], style={"marginBottom": "20px", "width": "300px"}), + + html.Div([ + html.Label("Slider:"), + dcc.Slider(id="slider", min=0, max=10, step=1, value=5, marks={i: str(i) for i in range(11)}), + html.Div(id="slider-output"), + ], style={"marginBottom": "20px", "width": "400px"}), + + html.Div([ + html.Label("Range Slider:"), + dcc.RangeSlider(id="range-slider", min=0, max=100, step=10, value=[20, 80]), + ], style={"marginBottom": "20px", "width": "400px"}), + + html.Div([ + html.Label("Radio Items:"), + dcc.RadioItems( + id="radio", + options=[{"label": c, "value": c} for c in df["City"]], + value="NYC", + inline=True, + ), + ], style={"marginBottom": "20px"}), + + html.Div([ + html.Label("Checklist:"), + dcc.Checklist( + id="checklist", + options=[{"label": c, "value": c} for c in df["City"]], + value=["NYC", "LA"], + inline=True, + ), + ], style={"marginBottom": "20px"}), + + html.Div([ + html.Label("Date Picker:"), + dcc.DatePickerSingle(id="date-picker", date="2024-01-15"), + ], style={"marginBottom": "20px"}), + + html.Div([ + html.Label("Date Range Picker:"), + dcc.DatePickerRange( + id="date-range", + start_date="2024-01-01", + end_date="2024-12-31", + ), + ], style={"marginBottom": "20px"}), + + html.Div([ + html.Label("Textarea:"), + dcc.Textarea(id="textarea", value="Some text here...", style={"width": "300px", "height": "100px"}), + ], style={"marginBottom": "20px"}), + + html.Hr(), + html.H2("Graph Component"), + dcc.Graph( + id="graph", + figure=px.bar(df, x="Fruit", y="Amount", color="City", title="Fruit Amounts by City") + ), + + html.Hr(), + html.H2("DataTable"), + dash_table.DataTable( + id="table", + columns=[{"name": c, "id": c} for c in df.columns], + data=df.to_dict("records"), + editable=True, + filter_action="native", + sort_action="native", + row_selectable="multi", + page_size=10, + ), + + html.Hr(), + html.H2("Tabs"), + dcc.Tabs(id="tabs", value="tab-1", children=[ + dcc.Tab(label="Tab 1", value="tab-1", children=[ + html.Div("Content for Tab 1", style={"padding": "20px"}) + ]), + dcc.Tab(label="Tab 2", value="tab-2", children=[ + html.Div("Content for Tab 2", style={"padding": "20px"}) + ]), + ]), + + html.Hr(), + html.H2("Loading Component"), + dcc.Loading( + id="loading", + type="circle", + children=html.Div(id="loading-output", children="Content loaded!") + ), + + html.Hr(), + html.H2("Markdown"), + dcc.Markdown(""" + ### This is Markdown + + - Item 1 + - Item 2 + - **Bold text** + - *Italic text* + + ```python + def hello(): + return "Hello, React 19!" + ``` + """), + + html.Hr(), + html.H2("Store & Interval"), + dcc.Store(id="store", data={"count": 0}), + dcc.Interval(id="interval", interval=5000, n_intervals=0, disabled=True), + html.Div(id="interval-output", children="Interval disabled"), + + html.Hr(), + html.H2("Clipboard"), + dcc.Clipboard(id="clipboard", target_id="text-input", style={"fontSize": "20px"}), + + html.Hr(), + html.H2("Tooltip"), + html.Div([ + html.Span("Hover over the graph points to see tooltips", style={"fontStyle": "italic"}), + ]), + + html.Br(), + html.Br(), +], style={"padding": "20px", "maxWidth": "800px", "margin": "0 auto"}) + + +@callback( + Output("click-output", "children"), + Input("button", "n_clicks") +) +def update_clicks(n): + return str(n) + + +@callback( + Output("text-output", "children"), + Input("text-input", "value") +) +def update_text(value): + return f"You typed: {value}" if value else "" + + +@callback( + Output("dropdown-output", "children"), + Input("dropdown", "value") +) +def update_dropdown(value): + return f"Selected: {value}" if value else "Nothing selected" + + +@callback( + Output("slider-output", "children"), + Input("slider", "value") +) +def update_slider(value): + return f"Slider value: {value}" + + +if __name__ == "__main__": + app.run(debug=True, port=8050) diff --git a/requirements/install.txt b/requirements/install.txt index df0e1299e3..284f3a5031 100644 --- a/requirements/install.txt +++ b/requirements/install.txt @@ -7,3 +7,4 @@ requests retrying nest-asyncio setuptools +janus>=1.0.0 diff --git a/tests/websocket/test_ws_props.py b/tests/websocket/test_ws_props.py index 51a9d1eec6..e800668ae8 100644 --- a/tests/websocket/test_ws_props.py +++ b/tests/websocket/test_ws_props.py @@ -6,10 +6,11 @@ - get_prop reads current component value - async set_prop method - set_props with Patch objects (bug fix for component property updates) +- set_props with pattern-matching components triggering MATCH callbacks """ import asyncio -from dash import Dash, html, Input, Output, set_props +from dash import Dash, html, Input, Output, State, set_props, MATCH from dash.exceptions import PreventUpdate @@ -396,3 +397,75 @@ async def update_list(n): assert "Item 3 - 1" in dash_duo.find_element("#list-container").text assert dash_duo.get_logs() == [] + + +def test_ws048_set_props_dynamic_match_callback(dash_duo): + """Test set_props injecting components with pattern-matching IDs that trigger MATCH callbacks.""" + app = Dash(__name__, backend="fastapi", websocket_callbacks=True) + + app.layout = html.Div( + [ + html.Button("Add Component", id="add-btn"), + html.Div(id="container"), + html.Div("waiting", id="match-result"), + html.Div(id="result"), + ] + ) + + @app.callback(Output("result", "children"), Input("add-btn", "n_clicks")) + async def add_component(n): + if not n: + raise PreventUpdate + + # Inject component with pattern-matching ID via set_props + set_props( + "container", + { + "children": html.Div( + [ + html.Span("Hello"), + html.Button("Click me", id={"type": "dynamic", "index": 0}), + ] + ) + }, + ) + return f"Component added {n}" + + @app.callback( + Output("match-result", "children"), + Input({"type": "dynamic", "index": MATCH}, "n_clicks"), + State({"type": "dynamic", "index": MATCH}, "id"), + prevent_initial_call=True, + ) + def handle_dynamic_click(n_clicks, btn_id): + if not n_clicks: + raise PreventUpdate + return f"Clicked button index {btn_id['index']} - {n_clicks} times" + + dash_duo.start_server(app) + + # Initial state + dash_duo.wait_for_text_to_equal("#match-result", "waiting") + + # Add the dynamic component + dash_duo.find_element("#add-btn").click() + dash_duo.wait_for_text_to_equal("#result", "Component added 1", timeout=10) + + # Verify the component was added + dash_duo.wait_for_text_to_equal("#container span", "Hello", timeout=5) + + # Click the dynamically added button with pattern-matching ID + dash_duo.find_element('[id=\'{"index":0,"type":"dynamic"}\']').click() + + # Verify the MATCH callback fired + dash_duo.wait_for_text_to_equal( + "#match-result", "Clicked button index 0 - 1 times", timeout=10 + ) + + # Click again to verify it continues to work + dash_duo.find_element('[id=\'{"index":0,"type":"dynamic"}\']').click() + dash_duo.wait_for_text_to_equal( + "#match-result", "Clicked button index 0 - 2 times", timeout=10 + ) + + assert dash_duo.get_logs() == [] diff --git a/wsapp.py b/wsapp.py index eda9c952ff..98b2db2f38 100644 --- a/wsapp.py +++ b/wsapp.py @@ -15,7 +15,6 @@ __name__, backend="fastapi", websocket_callbacks=True, - websocket_inactivity_timeout=10000, ) app.layout = html.Div([ diff --git a/wscb.py b/wscb.py new file mode 100644 index 0000000000..629ed3cdc1 --- /dev/null +++ b/wscb.py @@ -0,0 +1,68 @@ +""" +Test app for per-callback WebSocket support. + +This app demonstrates using websocket=True on specific callbacks +without enabling global websocket_callbacks. +""" + +from dash import Dash, html, dcc, callback, Input, Output, State + +app = Dash(__name__, backend="fastapi") + +app.layout = html.Div([ + html.H1("Per-Callback WebSocket Test"), + + html.Div([ + html.H3("WebSocket Callback"), + dcc.Input(id="ws-input", type="text", placeholder="Type here..."), + html.Div(id="ws-output", style={"padding": "10px", "background": "#e0ffe0"}) + ], style={"margin": "20px", "padding": "20px", "border": "1px solid #ccc"}), + + html.Div([ + html.H3("HTTP Callback (default)"), + dcc.Input(id="http-input", type="text", placeholder="Type here..."), + html.Div(id="http-output", style={"padding": "10px", "background": "#e0e0ff"}) + ], style={"margin": "20px", "padding": "20px", "border": "1px solid #ccc"}), + + html.Div([ + html.H3("WebSocket Counter"), + html.Button("Increment", id="ws-btn"), + html.Div(id="ws-counter", children="0", style={"padding": "10px", "background": "#ffe0e0"}) + ], style={"margin": "20px", "padding": "20px", "border": "1px solid #ccc"}), +]) + + +@callback( + Output("ws-output", "children"), + Input("ws-input", "value"), + websocket=True +) +def ws_callback(value): + """This callback uses WebSocket.""" + return f"[WebSocket] You typed: {value or ''}" + + +@callback( + Output("http-output", "children"), + Input("http-input", "value") +) +def http_callback(value): + """This callback uses HTTP (default).""" + return f"[HTTP] You typed: {value or ''}" + + +@callback( + Output("ws-counter", "children"), + Input("ws-btn", "n_clicks"), + State("ws-counter", "children"), + websocket=True +) +def ws_counter(n_clicks, current): + """WebSocket counter callback.""" + if n_clicks is None: + return "0" + return str(int(current or 0) + 1) + + +if __name__ == "__main__": + app.run(debug=True) From 727626d1f69e3257460febb0f85b569161a46582 Mon Sep 17 00:00:00 2001 From: philippe Date: Fri, 1 May 2026 12:19:26 -0400 Subject: [PATCH 149/166] move websocket code out of base_server --- dash/_callback_context.py | 2 +- dash/backends/_fastapi.py | 2 + dash/backends/_quart.py | 2 + dash/backends/base_server.py | 292 --------------------------------- dash/backends/ws.py | 303 +++++++++++++++++++++++++++++++++++ 5 files changed, 308 insertions(+), 293 deletions(-) create mode 100644 dash/backends/ws.py diff --git a/dash/_callback_context.py b/dash/_callback_context.py index 4f296bde66..e03f343129 100644 --- a/dash/_callback_context.py +++ b/dash/_callback_context.py @@ -5,7 +5,7 @@ import contextvars import typing -from dash.backends.base_server import DashWebsocketCallback +from dash.backends.ws import DashWebsocketCallback from . import exceptions from ._get_app import get_app diff --git a/dash/backends/_fastapi.py b/dash/backends/_fastapi.py index 3f6b81a3cd..1fc60cd703 100644 --- a/dash/backends/_fastapi.py +++ b/dash/backends/_fastapi.py @@ -41,6 +41,8 @@ BaseDashServer, RequestAdapter, ResponseAdapter, +) +from .ws import ( DashWebsocketCallback, run_ws_sender, run_callback_in_executor, diff --git a/dash/backends/_quart.py b/dash/backends/_quart.py index 89ba4f4d1b..9441ba8bd3 100644 --- a/dash/backends/_quart.py +++ b/dash/backends/_quart.py @@ -45,6 +45,8 @@ BaseDashServer, RequestAdapter, ResponseAdapter, +) +from .ws import ( DashWebsocketCallback, run_ws_sender, run_callback_in_executor, diff --git a/dash/backends/base_server.py b/dash/backends/base_server.py index bae003ea78..8657f17535 100644 --- a/dash/backends/base_server.py +++ b/dash/backends/base_server.py @@ -4,32 +4,17 @@ for different web server backends (Flask, Quart, FastAPI, etc.) to integrate with Dash. """ from abc import ABC, abstractmethod -import asyncio -import concurrent.futures from concurrent.futures import ThreadPoolExecutor -import inspect -import json -import queue -import traceback -import uuid -from contextvars import copy_context from typing import ( Any, - Callable, Dict, Type, TypeVar, Generic, Protocol, TYPE_CHECKING, - cast, ) -import janus - -from dash.exceptions import PreventUpdate -from dash._utils import to_json - if TYPE_CHECKING: import dash @@ -433,280 +418,3 @@ def serve_websocket_callback(self, dash_app: "dash.Dash"): Args: dash_app: The Dash application instance """ - - -class DashWebsocketCallback: - """WebSocket callback communication via queues. - - Provides methods for real-time bidirectional communication between - the server and renderer during callback execution. - - Uses janus.Queue for outbound messages (serialized with to_json) and - queue.Queue for get_props responses, enabling thread-safe communication - between worker threads and the main event loop. - """ - - def __init__( - self, - pending_get_props: Dict[str, queue.Queue[Any]], - renderer_id: str, - outbound_queue: janus.Queue[str], - ): - """Initialize the WebSocket callback interface. - - Args: - pending_get_props: Dict to track pending get_props requests. - Values are queue.Queue instances for blocking response retrieval. - renderer_id: The renderer ID for routing messages back to the correct client - outbound_queue: janus.Queue for thread-safe outbound messaging. - """ - self._pending_get_props = pending_get_props - self._renderer_id = renderer_id - self._outbound_queue = outbound_queue - - def _queue_message(self, msg: dict) -> None: - """Serialize and queue message for sending (thread-safe, non-blocking). - - Uses to_json for proper serialization of Dash components. - """ - self._outbound_queue.sync_q.put_nowait(cast(str, to_json(msg))) - - async def set_prop(self, component_id: str, prop_name: str, value: Any) -> None: - """Send immediate prop update to the client via WebSocket. - - Queues the message for the sender coroutine to send. - - Args: - component_id: The component ID (string or stringified dict) - prop_name: The property name to update - value: The new value to set - """ - msg = { - "type": "set_props", - "rendererId": self._renderer_id, - "payload": {"componentId": component_id, "props": {prop_name: value}}, - } - self._queue_message(msg) - - async def get_prop( - self, component_id: str, prop_name: str, timeout: float = 30.0 - ) -> Any: - """Request current prop value from the client. - - Uses queue.Queue for blocking wait in worker thread. - - Args: - component_id: The component ID (string or stringified dict) - prop_name: The property name to retrieve - timeout: Timeout in seconds for waiting for response - - Returns: - The current value of the property from the client's state - """ - request_id = str(uuid.uuid4()) - msg = { - "type": "get_props_request", - "rendererId": self._renderer_id, - "requestId": request_id, - "payload": {"componentId": component_id, "properties": [prop_name]}, - } - - # Use standard queue.Queue for response - response_queue: queue.Queue = queue.Queue() - self._pending_get_props[request_id] = response_queue - - # Queue the outbound request via janus sync interface - self._queue_message(msg) - - # Wait for response (blocking is OK in worker thread) - try: - result = response_queue.get(timeout=timeout) - if result and prop_name in result: - return result[prop_name] - return None - except queue.Empty as exc: - raise TimeoutError( - f"Timeout waiting for {component_id}.{prop_name}" - ) from exc - finally: - self._pending_get_props.pop(request_id, None) - - -def create_ws_context( - payload: dict, - response_adapter: ResponseAdapter, - websocket_callback: DashWebsocketCallback, -): - """Create callback context from WebSocket message. - - Args: - payload: The callback payload - response_adapter: The response adapter instance for the backend - websocket_callback: The websocket callback instance for the backend - - Returns: - AttributeDict with callback context - """ - # pylint: disable=import-outside-toplevel - from dash._utils import AttributeDict, inputs_to_dict - - g = AttributeDict({}) - g.inputs_list = payload.get("inputs", []) - g.states_list = payload.get("state", []) - g.outputs_list = payload.get("outputs", []) - g.input_values = inputs_to_dict(g.inputs_list) - g.state_values = inputs_to_dict(g.states_list) - g.triggered_inputs = [ - {"prop_id": x, "value": g.input_values.get(x)} - for x in payload.get("changedPropIds", []) - ] - g.dash_response = response_adapter - g.updated_props = {} - g.dash_websocket = websocket_callback - - return g - - -SHUTDOWN_SIGNAL = "__shutdown__" - - -async def run_ws_sender( - send_text: Callable[[str], Any], outbound_queue: janus.Queue[str] -) -> None: - """Sender coroutine - drains queue and sends to WebSocket. - - This coroutine runs in the main event loop and handles sending - messages that are queued by worker threads via janus.Queue. - - Messages are pre-serialized strings (using to_json). - - Args: - send_text: Async function to send text data over WebSocket - outbound_queue: janus.Queue instance for receiving messages (strings) - """ - try: - while True: - msg = await outbound_queue.async_q.get() - if msg == SHUTDOWN_SIGNAL: - break - await send_text(msg) - except asyncio.CancelledError: - pass - - -def make_callback_done_handler( - outbound_queue: janus.Queue[str], - pending_callbacks: Dict[str, concurrent.futures.Future], - request_id: str, - renderer_id: str, -) -> Callable[[concurrent.futures.Future], None]: - """Create a done callback handler for executor futures. - - This factory creates a callback that sends the result back through - the WebSocket when an executor future completes. - - Args: - outbound_queue: janus.Queue for sending responses - pending_callbacks: Dict tracking pending callbacks for cleanup - request_id: The request ID for the callback response - renderer_id: The renderer ID for routing the response - - Returns: - A callback function suitable for Future.add_done_callback() - """ - - def on_done(f: concurrent.futures.Future) -> None: - try: - result = f.result() - outbound_queue.sync_q.put_nowait( - cast( - str, - to_json( - { - "type": "callback_response", - "rendererId": renderer_id, - "requestId": request_id, - "payload": result, - } - ), - ) - ) - except Exception as e: # pylint: disable=broad-exception-caught - outbound_queue.sync_q.put_nowait( - cast( - str, - to_json( - { - "type": "callback_response", - "rendererId": renderer_id, - "requestId": request_id, - "payload": { - "status": "error", - "message": str(e), - }, - } - ), - ) - ) - finally: - pending_callbacks.pop(request_id, None) - - return on_done - - -def run_callback_in_executor( - executor: ThreadPoolExecutor, - dash_app: "dash.Dash", - payload: dict, - ws_callback: DashWebsocketCallback, - response_adapter: ResponseAdapter, -) -> concurrent.futures.Future: - """Submit callback to executor for thread pool execution. - - This function creates a callback execution context and runs it - in a separate thread. Both sync and async callbacks are supported. - - Args: - executor: ThreadPoolExecutor to submit the task to - dash_app: The Dash application instance - payload: The callback payload from WebSocket message - ws_callback: WebSocket callback instance for set_prop/get_prop - response_adapter: Response adapter for the backend - - Returns: - Future representing the pending callback execution - """ - - def execute() -> dict: - try: - cb_ctx = create_ws_context(payload, response_adapter, ws_callback) - # pylint: disable=protected-access - func = dash_app._prepare_callback(cb_ctx, payload) - args = dash_app._inputs_to_vals( # pylint: disable=protected-access - cb_ctx.inputs_list + cb_ctx.states_list - ) - - ctx = copy_context() - partial_func = ( - dash_app._execute_callback( # pylint: disable=protected-access - func, args, cb_ctx.outputs_list, cb_ctx - ) - ) - - # Run in new event loop (handles both sync and async callbacks) - def run_callback(): - result = partial_func() - if inspect.iscoroutine(result): - return asyncio.run(result) - return result - - response_data = ctx.run(run_callback) - return {"status": "ok", "data": json.loads(response_data)} - - except PreventUpdate: - return {"status": "prevent_update"} - except Exception as e: # pylint: disable=broad-exception-caught - traceback.print_exc() - return {"status": "error", "message": str(e)} - - return executor.submit(execute) diff --git a/dash/backends/ws.py b/dash/backends/ws.py new file mode 100644 index 0000000000..db59fa1628 --- /dev/null +++ b/dash/backends/ws.py @@ -0,0 +1,303 @@ +"""WebSocket callback support for Dash backend implementations. + +This module provides the WebSocket callback infrastructure for real-time +bidirectional communication between Dash backends and the renderer. +""" +from __future__ import annotations + +import asyncio +import concurrent.futures +from concurrent.futures import ThreadPoolExecutor +import inspect +import json +import queue +import traceback +import uuid +from contextvars import copy_context +from typing import Any, Callable, Dict, TYPE_CHECKING, cast + +import janus + +from dash.exceptions import PreventUpdate +from dash._utils import to_json + +if TYPE_CHECKING: + import dash + from .base_server import ResponseAdapter + + +SHUTDOWN_SIGNAL = "__shutdown__" + + +class DashWebsocketCallback: + """WebSocket callback communication via queues. + + Provides methods for real-time bidirectional communication between + the server and renderer during callback execution. + + Uses janus.Queue for outbound messages (serialized with to_json) and + queue.Queue for get_props responses, enabling thread-safe communication + between worker threads and the main event loop. + """ + + def __init__( + self, + pending_get_props: Dict[str, queue.Queue[Any]], + renderer_id: str, + outbound_queue: janus.Queue[str], + ): + """Initialize the WebSocket callback interface. + + Args: + pending_get_props: Dict to track pending get_props requests. + Values are queue.Queue instances for blocking response retrieval. + renderer_id: The renderer ID for routing messages back to the correct client + outbound_queue: janus.Queue for thread-safe outbound messaging. + """ + self._pending_get_props = pending_get_props + self._renderer_id = renderer_id + self._outbound_queue = outbound_queue + + def _queue_message(self, msg: dict) -> None: + """Serialize and queue message for sending (thread-safe, non-blocking). + + Uses to_json for proper serialization of Dash components. + """ + self._outbound_queue.sync_q.put_nowait(cast(str, to_json(msg))) + + async def set_prop(self, component_id: str, prop_name: str, value: Any) -> None: + """Send immediate prop update to the client via WebSocket. + + Queues the message for the sender coroutine to send. + + Args: + component_id: The component ID (string or stringified dict) + prop_name: The property name to update + value: The new value to set + """ + msg = { + "type": "set_props", + "rendererId": self._renderer_id, + "payload": {"componentId": component_id, "props": {prop_name: value}}, + } + self._queue_message(msg) + + async def get_prop( + self, component_id: str, prop_name: str, timeout: float = 30.0 + ) -> Any: + """Request current prop value from the client. + + Uses queue.Queue for blocking wait in worker thread. + + Args: + component_id: The component ID (string or stringified dict) + prop_name: The property name to retrieve + timeout: Timeout in seconds for waiting for response + + Returns: + The current value of the property from the client's state + """ + request_id = str(uuid.uuid4()) + msg = { + "type": "get_props_request", + "rendererId": self._renderer_id, + "requestId": request_id, + "payload": {"componentId": component_id, "properties": [prop_name]}, + } + + # Use standard queue.Queue for response + response_queue: queue.Queue = queue.Queue() + self._pending_get_props[request_id] = response_queue + + # Queue the outbound request via janus sync interface + self._queue_message(msg) + + # Wait for response (blocking is OK in worker thread) + try: + result = response_queue.get(timeout=timeout) + if result and prop_name in result: + return result[prop_name] + return None + except queue.Empty as exc: + raise TimeoutError( + f"Timeout waiting for {component_id}.{prop_name}" + ) from exc + finally: + self._pending_get_props.pop(request_id, None) + + +def create_ws_context( + payload: dict, + response_adapter: "ResponseAdapter", + websocket_callback: DashWebsocketCallback, +): + """Create callback context from WebSocket message. + + Args: + payload: The callback payload + response_adapter: The response adapter instance for the backend + websocket_callback: The websocket callback instance for the backend + + Returns: + AttributeDict with callback context + """ + # pylint: disable=import-outside-toplevel + from dash._utils import AttributeDict, inputs_to_dict + + g = AttributeDict({}) + g.inputs_list = payload.get("inputs", []) + g.states_list = payload.get("state", []) + g.outputs_list = payload.get("outputs", []) + g.input_values = inputs_to_dict(g.inputs_list) + g.state_values = inputs_to_dict(g.states_list) + g.triggered_inputs = [ + {"prop_id": x, "value": g.input_values.get(x)} + for x in payload.get("changedPropIds", []) + ] + g.dash_response = response_adapter + g.updated_props = {} + g.dash_websocket = websocket_callback + + return g + + +async def run_ws_sender( + send_text: Callable[[str], Any], outbound_queue: janus.Queue[str] +) -> None: + """Sender coroutine - drains queue and sends to WebSocket. + + This coroutine runs in the main event loop and handles sending + messages that are queued by worker threads via janus.Queue. + + Messages are pre-serialized strings (using to_json). + + Args: + send_text: Async function to send text data over WebSocket + outbound_queue: janus.Queue instance for receiving messages (strings) + """ + try: + while True: + msg = await outbound_queue.async_q.get() + if msg == SHUTDOWN_SIGNAL: + break + await send_text(msg) + except asyncio.CancelledError: + pass + + +def make_callback_done_handler( + outbound_queue: janus.Queue[str], + pending_callbacks: Dict[str, concurrent.futures.Future], + request_id: str, + renderer_id: str, +) -> Callable[[concurrent.futures.Future], None]: + """Create a done callback handler for executor futures. + + This factory creates a callback that sends the result back through + the WebSocket when an executor future completes. + + Args: + outbound_queue: janus.Queue for sending responses + pending_callbacks: Dict tracking pending callbacks for cleanup + request_id: The request ID for the callback response + renderer_id: The renderer ID for routing the response + + Returns: + A callback function suitable for Future.add_done_callback() + """ + + def on_done(f: concurrent.futures.Future) -> None: + try: + result = f.result() + outbound_queue.sync_q.put_nowait( + cast( + str, + to_json( + { + "type": "callback_response", + "rendererId": renderer_id, + "requestId": request_id, + "payload": result, + } + ), + ) + ) + except Exception as e: # pylint: disable=broad-exception-caught + outbound_queue.sync_q.put_nowait( + cast( + str, + to_json( + { + "type": "callback_response", + "rendererId": renderer_id, + "requestId": request_id, + "payload": { + "status": "error", + "message": str(e), + }, + } + ), + ) + ) + finally: + pending_callbacks.pop(request_id, None) + + return on_done + + +def run_callback_in_executor( + executor: ThreadPoolExecutor, + dash_app: "dash.Dash", + payload: dict, + ws_callback: DashWebsocketCallback, + response_adapter: "ResponseAdapter", +) -> concurrent.futures.Future: + """Submit callback to executor for thread pool execution. + + This function creates a callback execution context and runs it + in a separate thread. Both sync and async callbacks are supported. + + Args: + executor: ThreadPoolExecutor to submit the task to + dash_app: The Dash application instance + payload: The callback payload from WebSocket message + ws_callback: WebSocket callback instance for set_prop/get_prop + response_adapter: Response adapter for the backend + + Returns: + Future representing the pending callback execution + """ + + def execute() -> dict: + try: + cb_ctx = create_ws_context(payload, response_adapter, ws_callback) + # pylint: disable=protected-access + func = dash_app._prepare_callback(cb_ctx, payload) + args = dash_app._inputs_to_vals( # pylint: disable=protected-access + cb_ctx.inputs_list + cb_ctx.states_list + ) + + ctx = copy_context() + partial_func = ( + dash_app._execute_callback( # pylint: disable=protected-access + func, args, cb_ctx.outputs_list, cb_ctx + ) + ) + + # Run in new event loop (handles both sync and async callbacks) + def run_callback(): + result = partial_func() + if inspect.iscoroutine(result): + return asyncio.run(result) + return result + + response_data = ctx.run(run_callback) + return {"status": "ok", "data": json.loads(response_data)} + + except PreventUpdate: + return {"status": "prevent_update"} + except Exception as e: # pylint: disable=broad-exception-caught + traceback.print_exc() + return {"status": "error", "message": str(e)} + + return executor.submit(execute) From 86e3263cad7778e8d741af2fea3d820976974bbe Mon Sep 17 00:00:00 2001 From: philippe Date: Fri, 1 May 2026 13:54:55 -0400 Subject: [PATCH 150/166] add future annotations --- dash/backends/base_server.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/dash/backends/base_server.py b/dash/backends/base_server.py index 8657f17535..52443d4104 100644 --- a/dash/backends/base_server.py +++ b/dash/backends/base_server.py @@ -3,6 +3,8 @@ This module provides abstract base classes and protocols that define the interface for different web server backends (Flask, Quart, FastAPI, etc.) to integrate with Dash. """ +from __future__ import annotations + from abc import ABC, abstractmethod from concurrent.futures import ThreadPoolExecutor from typing import ( From e4703eaf8c14c5ecb8c765e4e94991b38b07dab1 Mon Sep 17 00:00:00 2001 From: philippe Date: Fri, 1 May 2026 14:48:09 -0400 Subject: [PATCH 151/166] rejects all pending callbacks when a DISCONNECTED message is received --- dash/dash-renderer/src/utils/workerClient.ts | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/dash/dash-renderer/src/utils/workerClient.ts b/dash/dash-renderer/src/utils/workerClient.ts index 6bc503b16b..f7cf4d613b 100644 --- a/dash/dash-renderer/src/utils/workerClient.ts +++ b/dash/dash-renderer/src/utils/workerClient.ts @@ -263,6 +263,11 @@ class WorkerClient { case WorkerMessageType.DISCONNECTED: this.isConnected = false; + // Reject all pending callbacks so loading states don't stay on forever + for (const [, pending] of this.pendingCallbacks) { + pending.reject(new Error('WebSocket disconnected')); + } + this.pendingCallbacks.clear(); if (this.onDisconnected) { this.onDisconnected(message.payload?.reason); } From 379e847bdd245a2abfc7820a6472f49806d235e2 Mon Sep 17 00:00:00 2001 From: philippe Date: Fri, 1 May 2026 14:51:55 -0400 Subject: [PATCH 152/166] reset retryCount on connection --- @plotly/dash-websocket-worker/src/WebSocketManager.ts | 2 ++ 1 file changed, 2 insertions(+) diff --git a/@plotly/dash-websocket-worker/src/WebSocketManager.ts b/@plotly/dash-websocket-worker/src/WebSocketManager.ts index 5f11086945..f7abe18dda 100644 --- a/@plotly/dash-websocket-worker/src/WebSocketManager.ts +++ b/@plotly/dash-websocket-worker/src/WebSocketManager.ts @@ -118,6 +118,8 @@ export class WebSocketManager { // Trigger reconnect if we have a server URL but aren't connected/connecting if (this.serverUrl && !this.isConnecting) { this.isConnecting = true; + // Reset retry count since this is user-initiated activity + this.retryCount = 0; this.createConnection(); } } From bf7e97faa23920c3ad6e57d336740443f3f2a8eb Mon Sep 17 00:00:00 2001 From: philippe Date: Fri, 1 May 2026 15:41:49 -0400 Subject: [PATCH 153/166] Version 4.2.0rc2 --- CHANGELOG.md | 5 +++++ dash/version.py | 2 +- 2 files changed, 6 insertions(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 1676b26acd..b719972684 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -18,7 +18,12 @@ This project adheres to [Semantic Versioning](https://semver.org/). - [#3740](https://github.com/plotly/dash/pull/3740) Fix cannot tab into dropdowns in Safari - [#2462](https://github.com/plotly/dash/issues/2462) Allow `MATCH` in `Input`/`State` when the callback's `Output` has no wildcards (fixed-id Output, no Output, or `ALL`-only wildcard Output). `ALLSMALLER` still requires a corresponding `MATCH` in an Output. - [#3759](https://github.com/plotly/dash/pull/3759) Fix the issue where `Patch` objects cannot be updated via `set_props()` in `websocket` callback. Fix [#3742](https://github.com/plotly/dash/issues/3742) + +## [4.2.0rc1] - 2026-05-01 + +## Fixed - [#3759](https://github.com/plotly/dash/pull/3759) Fix the error when using `set_props()` to update component-type properties in the `websocket` callback. +- Add threadpool for running websocket callbacks. ## [4.2.0rc1] - 2026-04-13 diff --git a/dash/version.py b/dash/version.py index 6287d2530f..6af77684e6 100644 --- a/dash/version.py +++ b/dash/version.py @@ -1 +1 @@ -__version__ = "4.2.0rc1" +__version__ = "4.2.0rc2" From c35e6ad6b6df601643a6babc6f9397c5bb16462d Mon Sep 17 00:00:00 2001 From: philippe Date: Mon, 4 May 2026 13:13:46 -0400 Subject: [PATCH 154/166] better handling of disconnect for persistent callbacks --- dash/backends/_fastapi.py | 14 +++++++++++++- dash/backends/_quart.py | 18 +++++++++++++++--- dash/backends/ws.py | 30 +++++++++++++++++++++++++++++- dash/exceptions.py | 4 ++++ 4 files changed, 61 insertions(+), 5 deletions(-) diff --git a/dash/backends/_fastapi.py b/dash/backends/_fastapi.py index 1fc60cd703..9e06bd418c 100644 --- a/dash/backends/_fastapi.py +++ b/dash/backends/_fastapi.py @@ -48,6 +48,7 @@ run_callback_in_executor, make_callback_done_handler, SHUTDOWN_SIGNAL, + DISCONNECTED, ) from ._utils import format_traceback_html @@ -716,6 +717,8 @@ async def websocket_handler(websocket: WebSocket): outbound_queue: janus.Queue[str] = janus.Queue() # Track pending get_props requests with standard queue.Queue for responses pending_get_props: Dict[str, queue.Queue] = {} + # Shutdown event to signal connection closure to worker threads + shutdown_event = threading.Event() # Get thread pool executor executor = self.get_callback_executor() # Track pending callback futures @@ -758,7 +761,10 @@ async def websocket_handler(websocket: WebSocket): # Create WebSocket callback instance with outbound queue ws_cb = DashWebsocketCallback( - pending_get_props, renderer_id, outbound_queue + pending_get_props, + renderer_id, + outbound_queue, + shutdown_event, ) # Submit callback to executor @@ -777,6 +783,7 @@ async def websocket_handler(websocket: WebSocket): pending_callbacks, request_id, renderer_id, + shutdown_event, ) ) pending_callbacks[request_id] = future @@ -794,6 +801,11 @@ async def websocket_handler(websocket: WebSocket): except WebSocketDisconnect: pass # Clean disconnect finally: + # Signal shutdown to worker threads + shutdown_event.set() + # Unblock any threads waiting on get_prop responses + for response_queue in pending_get_props.values(): + response_queue.put_nowait(DISCONNECTED) # Signal sender to shutdown and cancel it outbound_queue.sync_q.put_nowait(SHUTDOWN_SIGNAL) sender_task.cancel() diff --git a/dash/backends/_quart.py b/dash/backends/_quart.py index 9441ba8bd3..0916f206fb 100644 --- a/dash/backends/_quart.py +++ b/dash/backends/_quart.py @@ -9,6 +9,7 @@ import asyncio import concurrent.futures import queue +import threading from urllib.parse import urlparse from logging.config import dictConfig @@ -52,6 +53,7 @@ run_callback_in_executor, make_callback_done_handler, SHUTDOWN_SIGNAL, + DISCONNECTED, ) from ._utils import format_traceback_html @@ -248,7 +250,6 @@ def has_request_context(self) -> bool: # pylint: disable=W0613 def run(self, dash_app: Dash, host: str, port: int, debug: bool, **kwargs: _t.Any): import signal # pylint: disable=import-outside-toplevel - import threading # pylint: disable=import-outside-toplevel # pylint: disable=import-outside-toplevel,import-error from hypercorn.config import Config @@ -521,7 +522,7 @@ def serve_websocket_callback(self, dash_app: "Dash"): allowed_origins = getattr(dash_app, "_websocket_allowed_origins", []) @self.server.websocket(ws_path) - async def websocket_handler(): + async def websocket_handler(): # pylint: disable=too-many-branches ws = websocket # Validate Origin header @@ -557,6 +558,8 @@ async def websocket_handler(): outbound_queue: janus.Queue[str] = janus.Queue() # Track pending get_props requests with standard queue.Queue for responses pending_get_props: Dict[str, queue.Queue] = {} + # Shutdown event to signal connection closure to worker threads + connection_shutdown_event = threading.Event() # Get thread pool executor executor = self.get_callback_executor() # Track pending callback futures @@ -604,7 +607,10 @@ async def websocket_handler(): # Create WebSocket callback instance with outbound queue ws_cb = DashWebsocketCallback( - pending_get_props, renderer_id, outbound_queue + pending_get_props, + renderer_id, + outbound_queue, + connection_shutdown_event, ) # Submit callback to executor @@ -623,6 +629,7 @@ async def websocket_handler(): pending_callbacks, request_id, renderer_id, + connection_shutdown_event, ) ) pending_callbacks[request_id] = future @@ -643,6 +650,11 @@ async def websocket_handler(): pass # Other exceptions treated as disconnect finally: self._active_websockets.discard(ws_obj) + # Signal shutdown to worker threads + connection_shutdown_event.set() + # Unblock any threads waiting on get_prop responses + for response_queue in pending_get_props.values(): + response_queue.put_nowait(DISCONNECTED) # Signal sender to shutdown and cancel it outbound_queue.sync_q.put_nowait(SHUTDOWN_SIGNAL) sender_task.cancel() diff --git a/dash/backends/ws.py b/dash/backends/ws.py index db59fa1628..75321c7189 100644 --- a/dash/backends/ws.py +++ b/dash/backends/ws.py @@ -11,6 +11,7 @@ import inspect import json import queue +import threading import traceback import uuid from contextvars import copy_context @@ -18,7 +19,7 @@ import janus -from dash.exceptions import PreventUpdate +from dash.exceptions import PreventUpdate, WebsocketDisconnected from dash._utils import to_json if TYPE_CHECKING: @@ -27,6 +28,7 @@ SHUTDOWN_SIGNAL = "__shutdown__" +DISCONNECTED = "__disconnected__" class DashWebsocketCallback: @@ -45,6 +47,7 @@ def __init__( pending_get_props: Dict[str, queue.Queue[Any]], renderer_id: str, outbound_queue: janus.Queue[str], + shutdown_event: "threading.Event", ): """Initialize the WebSocket callback interface. @@ -53,16 +56,26 @@ def __init__( Values are queue.Queue instances for blocking response retrieval. renderer_id: The renderer ID for routing messages back to the correct client outbound_queue: janus.Queue for thread-safe outbound messaging. + shutdown_event: Event signaling the websocket connection has closed. """ self._pending_get_props = pending_get_props self._renderer_id = renderer_id self._outbound_queue = outbound_queue + self._shutdown_event = shutdown_event + + @property + def is_shutdown(self) -> bool: + """Check if the websocket connection has been shut down.""" + return self._shutdown_event.is_set() def _queue_message(self, msg: dict) -> None: """Serialize and queue message for sending (thread-safe, non-blocking). Uses to_json for proper serialization of Dash components. + Does nothing if the connection has been shut down. """ + if self._shutdown_event.is_set(): + return self._outbound_queue.sync_q.put_nowait(cast(str, to_json(msg))) async def set_prop(self, component_id: str, prop_name: str, value: Any) -> None: @@ -96,7 +109,14 @@ async def get_prop( Returns: The current value of the property from the client's state + + Raises: + WebsocketDisconnected: If the websocket connection has been closed. + TimeoutError: If the response doesn't arrive within the timeout. """ + if self._shutdown_event.is_set(): + raise WebsocketDisconnected() + request_id = str(uuid.uuid4()) msg = { "type": "get_props_request", @@ -115,6 +135,8 @@ async def get_prop( # Wait for response (blocking is OK in worker thread) try: result = response_queue.get(timeout=timeout) + if result == DISCONNECTED: + raise WebsocketDisconnected() if result and prop_name in result: return result[prop_name] return None @@ -190,6 +212,7 @@ def make_callback_done_handler( pending_callbacks: Dict[str, concurrent.futures.Future], request_id: str, renderer_id: str, + shutdown_event: threading.Event, ) -> Callable[[concurrent.futures.Future], None]: """Create a done callback handler for executor futures. @@ -201,6 +224,7 @@ def make_callback_done_handler( pending_callbacks: Dict tracking pending callbacks for cleanup request_id: The request ID for the callback response renderer_id: The renderer ID for routing the response + shutdown_event: Event signaling the websocket connection has closed. Returns: A callback function suitable for Future.add_done_callback() @@ -208,6 +232,8 @@ def make_callback_done_handler( def on_done(f: concurrent.futures.Future) -> None: try: + if shutdown_event.is_set(): + return result = f.result() outbound_queue.sync_q.put_nowait( cast( @@ -223,6 +249,8 @@ def on_done(f: concurrent.futures.Future) -> None: ) ) except Exception as e: # pylint: disable=broad-exception-caught + if shutdown_event.is_set(): + return outbound_queue.sync_q.put_nowait( cast( str, diff --git a/dash/exceptions.py b/dash/exceptions.py index 40e882c409..9366f9359c 100644 --- a/dash/exceptions.py +++ b/dash/exceptions.py @@ -117,3 +117,7 @@ class AppNotFoundError(DashException): class WebSocketCallbackError(CallbackException): pass + + +class WebsocketDisconnected(CallbackException): + pass From ceaea5bb93bc5065089f4f20df8e9c7bbf6c25fd Mon Sep 17 00:00:00 2001 From: philippe Date: Mon, 4 May 2026 13:44:23 -0400 Subject: [PATCH 155/166] silence websocket disconnected --- dash/backends/ws.py | 2 ++ dash/dash-renderer/src/utils/workerClient.ts | 8 ++++---- 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/dash/backends/ws.py b/dash/backends/ws.py index 75321c7189..041241823e 100644 --- a/dash/backends/ws.py +++ b/dash/backends/ws.py @@ -324,6 +324,8 @@ def run_callback(): except PreventUpdate: return {"status": "prevent_update"} + except WebsocketDisconnected: + return {"status": "prevent_update"} except Exception as e: # pylint: disable=broad-exception-caught traceback.print_exc() return {"status": "error", "message": str(e)} diff --git a/dash/dash-renderer/src/utils/workerClient.ts b/dash/dash-renderer/src/utils/workerClient.ts index f7cf4d613b..c16594f8ef 100644 --- a/dash/dash-renderer/src/utils/workerClient.ts +++ b/dash/dash-renderer/src/utils/workerClient.ts @@ -138,9 +138,9 @@ class WorkerClient { this.connectionPromise = null; this.connectionResolve = null; - // Reject any pending callbacks + // Resolve pending callbacks with prevent_update so loading states clear for (const [, pending] of this.pendingCallbacks) { - pending.reject(new Error('Worker disconnected')); + pending.resolve({status: 'prevent_update'}); } this.pendingCallbacks.clear(); } @@ -263,9 +263,9 @@ class WorkerClient { case WorkerMessageType.DISCONNECTED: this.isConnected = false; - // Reject all pending callbacks so loading states don't stay on forever + // Resolve pending callbacks with prevent_update so loading states clear for (const [, pending] of this.pendingCallbacks) { - pending.reject(new Error('WebSocket disconnected')); + pending.resolve({status: 'prevent_update'}); } this.pendingCallbacks.clear(); if (this.onDisconnected) { From 6a72fe9445fac5d952d5068cdfbb68858922cdd2 Mon Sep 17 00:00:00 2001 From: philippe Date: Tue, 5 May 2026 15:37:11 -0400 Subject: [PATCH 156/166] reset retry count onn connection --- @plotly/dash-websocket-worker/src/WebSocketManager.ts | 3 +++ 1 file changed, 3 insertions(+) diff --git a/@plotly/dash-websocket-worker/src/WebSocketManager.ts b/@plotly/dash-websocket-worker/src/WebSocketManager.ts index f7abe18dda..2d32c7e8ca 100644 --- a/@plotly/dash-websocket-worker/src/WebSocketManager.ts +++ b/@plotly/dash-websocket-worker/src/WebSocketManager.ts @@ -79,6 +79,9 @@ export class WebSocketManager { this.serverUrl = serverUrl; this.isConnecting = true; + // Reset retry count since this is an explicit connect request + // (e.g., from hot reload reconnection) + this.retryCount = 0; this.createConnection(); } From 5825c2d7b9e953742486356dd719c4741e601a7a Mon Sep 17 00:00:00 2001 From: philippe Date: Tue, 5 May 2026 15:38:21 -0400 Subject: [PATCH 157/166] fastapi requirement uvicorn[standard] --- requirements/fastapi.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements/fastapi.txt b/requirements/fastapi.txt index 97dc7cd8c1..364e2ee48e 100644 --- a/requirements/fastapi.txt +++ b/requirements/fastapi.txt @@ -1,2 +1,2 @@ fastapi -uvicorn +uvicorn[standard] From 7ba747a229ae79b2c13f07eeb61266c63a0ca3e8 Mon Sep 17 00:00:00 2001 From: philippe Date: Fri, 8 May 2026 11:47:33 -0400 Subject: [PATCH 158/166] fix infinite loading with persitent keyword on callbacks --- dash/_callback.py | 9 +++ dash/dash-renderer/src/observers/isLoading.ts | 7 ++- dash/dash-renderer/src/types/callbacks.ts | 1 + .../renderer/test_loading_states.py | 58 +++++++++++++++++++ 4 files changed, 74 insertions(+), 1 deletion(-) diff --git a/dash/_callback.py b/dash/_callback.py index 718a016d82..f5f64970b0 100644 --- a/dash/_callback.py +++ b/dash/_callback.py @@ -78,6 +78,7 @@ def callback( optional: Optional[bool] = False, hidden: Optional[bool] = None, websocket: Optional[bool] = False, + persistent: Optional[bool] = False, **_kwargs, ) -> Callable[..., Any]: """ @@ -172,6 +173,10 @@ def callback( The endpoint is relative to the Dash app's base URL. Note that the endpoint will not appear in the list of registered callbacks in the Dash devtools. + :param persistent: + If True, this callback will not show the "Updating..." title while + running. Useful for persistent WebSocket callbacks that stay active + for long periods without requiring a loading indicator. """ background_spec: Any = None @@ -230,6 +235,7 @@ def callback( optional=optional, hidden=hidden, websocket=websocket, + persistent=persistent, ) @@ -278,6 +284,7 @@ def insert_callback( optional=False, hidden=None, websocket=False, + persistent=False, ) -> str: if prevent_initial_call is None: prevent_initial_call = config_prevent_initial_callbacks @@ -304,6 +311,7 @@ def insert_callback( "optional": optional, "hidden": hidden, "websocket": websocket, + "persistent": persistent, } if running: callback_spec["running"] = running @@ -658,6 +666,7 @@ def register_callback( optional=_kwargs.get("optional", False), hidden=_kwargs.get("hidden", None), websocket=_kwargs.get("websocket", False), + persistent=_kwargs.get("persistent", False), ) # pylint: disable=too-many-locals diff --git a/dash/dash-renderer/src/observers/isLoading.ts b/dash/dash-renderer/src/observers/isLoading.ts index 687f607378..cc3bf193b8 100644 --- a/dash/dash-renderer/src/observers/isLoading.ts +++ b/dash/dash-renderer/src/observers/isLoading.ts @@ -9,7 +9,12 @@ const observer: IStoreObserverDefinition = { const pendingCallbacks = getPendingCallbacks(callbacks); - const next = Boolean(pendingCallbacks.length); + // Filter out persistent callbacks - they shouldn't trigger the loading indicator + const nonPersistentCallbacks = pendingCallbacks.filter( + cb => !cb.callback.persistent + ); + + const next = Boolean(nonPersistentCallbacks.length); if (isLoading !== next) { dispatch(setIsLoading(next)); diff --git a/dash/dash-renderer/src/types/callbacks.ts b/dash/dash-renderer/src/types/callbacks.ts index 38a5d7d82f..5f963463d2 100644 --- a/dash/dash-renderer/src/types/callbacks.ts +++ b/dash/dash-renderer/src/types/callbacks.ts @@ -16,6 +16,7 @@ export interface ICallbackDefinition { running: any; no_output?: boolean; websocket?: boolean; + persistent?: boolean; } export interface ICallbackProperty { diff --git a/tests/integration/renderer/test_loading_states.py b/tests/integration/renderer/test_loading_states.py index 169b505ed1..9818902f61 100644 --- a/tests/integration/renderer/test_loading_states.py +++ b/tests/integration/renderer/test_loading_states.py @@ -298,3 +298,61 @@ def update(n): dash_duo.wait_for_text_to_equal("#final-output", "1") until(lambda: dash_duo.driver.title == "Page 1", timeout=1) + + +def test_rdls005_persistent_callback_no_update_title(dash_duo): + """Test that persistent=True callbacks don't trigger the 'Updating...' title.""" + lock = Lock() + + app = Dash(__name__) + + app.layout = html.Div( + children=[ + html.H3("Test persistent callback"), + html.Button("Persistent", id="persistent-btn", n_clicks=0), + html.Button("Regular", id="regular-btn", n_clicks=0), + html.Div(id="persistent-output"), + html.Div(id="regular-output"), + ] + ) + + @app.callback( + Output("persistent-output", "children"), + Input("persistent-btn", "n_clicks"), + persistent=True, + ) + def persistent_update(n): + with lock: + return f"Persistent: {n}" + + @app.callback( + Output("regular-output", "children"), + Input("regular-btn", "n_clicks"), + ) + def regular_update(n): + with lock: + return f"Regular: {n}" + + dash_duo.start_server(app) + dash_duo.wait_for_text_to_equal("#persistent-output", "Persistent: 0") + dash_duo.wait_for_text_to_equal("#regular-output", "Regular: 0") + + # Verify title is "Dash" after initial load + until(lambda: dash_duo.driver.title == "Dash", timeout=1) + + # Test that persistent callback does NOT change title to "Updating..." + with lock: + dash_duo.find_element("#persistent-btn").click() + # Title should remain "Dash" even while callback is running + until(lambda: dash_duo.driver.title == "Dash", timeout=1) + + dash_duo.wait_for_text_to_equal("#persistent-output", "Persistent: 1") + + # Test that regular callback DOES change title to "Updating..." + with lock: + dash_duo.find_element("#regular-btn").click() + until(lambda: dash_duo.driver.title == "Updating...", timeout=1) + + dash_duo.wait_for_text_to_equal("#regular-output", "Regular: 1") + # Title should revert after callback completes + until(lambda: dash_duo.driver.title == "Dash", timeout=1) From 2bdb5daad53abda73c0e8eda50280b1542d81e01 Mon Sep 17 00:00:00 2001 From: philippe Date: Fri, 8 May 2026 11:59:11 -0400 Subject: [PATCH 159/166] fix initial no output callbacks --- dash/dash-renderer/src/actions/index.js | 38 +++++++++++++--- .../callbacks/test_basic_callback.py | 44 +++++++++++++++++++ 2 files changed, 75 insertions(+), 7 deletions(-) diff --git a/dash/dash-renderer/src/actions/index.js b/dash/dash-renderer/src/actions/index.js index 6169c4f65e..c1fb5efec4 100644 --- a/dash/dash-renderer/src/actions/index.js +++ b/dash/dash-renderer/src/actions/index.js @@ -5,7 +5,12 @@ import {getAppState} from '../reducers/constants'; import {getAction} from './constants'; import * as cookie from 'cookie'; import {validateCallbacksToLayout} from './dependencies'; -import {includeObservers, getLayoutCallbacks} from './dependencies_ts'; +import { + includeObservers, + getLayoutCallbacks, + makeResolvedCallback, + resolveDeps +} from './dependencies_ts'; import {computePaths, getPath} from './paths'; import {recordUiEdit} from '../persistence'; @@ -95,13 +100,32 @@ function triggerDefaultState(dispatch, getState) { ); } - dispatch( - addRequestedCallbacks( - getLayoutCallbacks(graphs, paths, layout.components, { - outputsOnly: true - }) - ) + const layoutCallbacks = getLayoutCallbacks( + graphs, + paths, + layout.components, + { + outputsOnly: true + } ); + + // Also include no-output callbacks whose inputs are in the layout + const noOutputCallbacks = (graphs.callbacks || []) + .filter(cb => cb.noOutput && !cb.prevent_initial_call) + .map(cb => { + const resolved = makeResolvedCallback(cb, resolveDeps(), ''); + resolved.initialCall = true; + return resolved; + }) + .filter(cb => { + // Check if any input is in the layout + const inputs = cb.getInputs(paths); + return inputs.some(inp => + Array.isArray(inp) ? inp.length > 0 : inp + ); + }); + + dispatch(addRequestedCallbacks([...layoutCallbacks, ...noOutputCallbacks])); } export const redo = moveHistory('REDO'); diff --git a/tests/integration/callbacks/test_basic_callback.py b/tests/integration/callbacks/test_basic_callback.py index 87ce3507e7..e50876a90c 100644 --- a/tests/integration/callbacks/test_basic_callback.py +++ b/tests/integration/callbacks/test_basic_callback.py @@ -917,3 +917,47 @@ def on_click(_): assert error.text == error_title for error_text in dash_duo.find_elements(".dash-backend-error"): assert all(line in error_text for line in error_message) + + +def test_cbsc022_no_output_callback_initial_call(dash_duo): + """Test that no-output callbacks fire on initial load.""" + + call_count = Value("i", 0) + + app = Dash(__name__) + app.layout = html.Div( + [ + html.Button("Click", id="btn", n_clicks=0), + html.Div(id="output"), + ] + ) + + @app.callback( + Input("btn", "n_clicks"), + ) + def no_output_callback(n_clicks): + call_count.value += 1 + + @app.callback( + Output("output", "children"), + Input("btn", "n_clicks"), + ) + def with_output_callback(n_clicks): + return f"Clicks: {n_clicks}" + + dash_duo.start_server(app) + + # Wait for initial render + dash_duo.wait_for_text_to_equal("#output", "Clicks: 0") + + # No-output callback should have fired on initial load + assert call_count.value == 1, "no-output callback should fire on initial load" + + # Click button + dash_duo.find_element("#btn").click() + dash_duo.wait_for_text_to_equal("#output", "Clicks: 1") + + # No-output callback should have fired again + assert call_count.value == 2, "no-output callback should fire on click" + + assert dash_duo.get_logs() == [] From 9c76a17638e956bbbdee8ab049e3392b27062071 Mon Sep 17 00:00:00 2001 From: philippe Date: Fri, 8 May 2026 13:11:32 -0400 Subject: [PATCH 160/166] support no output/no inputs callbacks --- dash/_utils.py | 14 ++ .../dash-renderer/src/actions/dependencies.js | 7 +- .../src/actions/dependencies_ts.ts | 14 +- dash/dash-renderer/src/actions/index.js | 35 ++++- .../callbacks/test_basic_callback.py | 147 ++++++++++++++++++ 5 files changed, 209 insertions(+), 8 deletions(-) diff --git a/dash/_utils.py b/dash/_utils.py index 5e241fe21d..85ff9ab073 100644 --- a/dash/_utils.py +++ b/dash/_utils.py @@ -165,6 +165,20 @@ def _concat(x): if no_output: # No output will hash the inputs. + # For no-input callbacks, also include the call site to make each unique + if not inputs: + # Get the call site of the @callback decorator + stack = inspect.stack() + # Walk up the stack to find the actual callback call site + # (skip internal dash package frames) + dash_package_path = os.path.dirname(__file__) + for frame_info in stack: + # Skip frames from within the dash package itself + if not frame_info.filename.startswith(dash_package_path): + call_site = f"{frame_info.filename}:{frame_info.lineno}" + return hashlib.sha256(call_site.encode("utf-8")).hexdigest() + # Fallback to empty hash if no external frame found + return _hash_inputs() return _hash_inputs() if isinstance(output, (list, tuple)): diff --git a/dash/dash-renderer/src/actions/dependencies.js b/dash/dash-renderer/src/actions/dependencies.js index 7b5d1665f0..fa29199a1d 100644 --- a/dash/dash-renderer/src/actions/dependencies.js +++ b/dash/dash-renderer/src/actions/dependencies.js @@ -224,14 +224,17 @@ function validateDependencies(parsedDependencies, dispatchError) { 'In the callback for output(s):\n ' + outputs.map(combineIdAndProp).join('\n '); - if (!inputs.length) { + if (!inputs.length && dep.prevent_initial_call) { dispatchError('A callback is missing Inputs', [ head, 'there are no `Input` elements.', 'Without `Input` elements, it will never get called.', '', 'Subscribing to `Input` components will cause the', - 'callback to be called whenever their values change.' + 'callback to be called whenever their values change.', + '', + 'If you want a callback without inputs that fires on initial load,', + 'set prevent_initial_call=False.' ]); } diff --git a/dash/dash-renderer/src/actions/dependencies_ts.ts b/dash/dash-renderer/src/actions/dependencies_ts.ts index 33f968cf91..4056cdeac1 100644 --- a/dash/dash-renderer/src/actions/dependencies_ts.ts +++ b/dash/dash-renderer/src/actions/dependencies_ts.ts @@ -352,12 +352,18 @@ export const getLayoutCallbacks = ( export const getUniqueIdentifier = ({ anyVals, - callback: {inputs, outputs, state} -}: ICallback): string => - concat( - map(combineIdAndProp, [...inputs, ...outputs, ...state]), + callback: {inputs, outputs, state, output} +}: ICallback): string => { + const idParts = map(combineIdAndProp, [...inputs, ...outputs, ...state]); + // For no-output callbacks, include the output hash to ensure uniqueness + if (outputs.length === 0 && output) { + idParts.push(output); + } + return concat( + idParts, Array.isArray(anyVals) ? anyVals : anyVals === '' ? [] : [anyVals] ).join(','); +}; export function includeObservers( id: any, diff --git a/dash/dash-renderer/src/actions/index.js b/dash/dash-renderer/src/actions/index.js index c1fb5efec4..51229767ba 100644 --- a/dash/dash-renderer/src/actions/index.js +++ b/dash/dash-renderer/src/actions/index.js @@ -109,7 +109,7 @@ function triggerDefaultState(dispatch, getState) { } ); - // Also include no-output callbacks whose inputs are in the layout + // Also include no-output callbacks whose inputs are in the layout (or have no inputs) const noOutputCallbacks = (graphs.callbacks || []) .filter(cb => cb.noOutput && !cb.prevent_initial_call) .map(cb => { @@ -118,6 +118,10 @@ function triggerDefaultState(dispatch, getState) { return resolved; }) .filter(cb => { + // If no inputs, always include (fires once on initial load) + if (cb.callback.inputs.length === 0) { + return true; + } // Check if any input is in the layout const inputs = cb.getInputs(paths); return inputs.some(inp => @@ -125,7 +129,34 @@ function triggerDefaultState(dispatch, getState) { ); }); - dispatch(addRequestedCallbacks([...layoutCallbacks, ...noOutputCallbacks])); + // Also include no-input callbacks (with outputs) that should fire on initial load + const noInputCallbacks = (graphs.callbacks || []) + .filter( + cb => + !cb.noOutput && + cb.inputs.length === 0 && + !cb.prevent_initial_call + ) + .map(cb => { + const resolved = makeResolvedCallback(cb, resolveDeps(), ''); + resolved.initialCall = true; + return resolved; + }) + .filter(cb => { + // Check if any output is in the layout + const outputs = cb.getOutputs(paths); + return outputs.some(out => + Array.isArray(out) ? out.length > 0 : out + ); + }); + + dispatch( + addRequestedCallbacks([ + ...layoutCallbacks, + ...noOutputCallbacks, + ...noInputCallbacks + ]) + ); } export const redo = moveHistory('REDO'); diff --git a/tests/integration/callbacks/test_basic_callback.py b/tests/integration/callbacks/test_basic_callback.py index e50876a90c..6e724c186f 100644 --- a/tests/integration/callbacks/test_basic_callback.py +++ b/tests/integration/callbacks/test_basic_callback.py @@ -961,3 +961,150 @@ def with_output_callback(n_clicks): assert call_count.value == 2, "no-output callback should fire on click" assert dash_duo.get_logs() == [] + + +def test_cbsc023_no_input_callback_initial_call(dash_duo): + """Test that no-input callbacks fire on initial load (issue #3411).""" + app = Dash(__name__) + app.layout = html.Div( + [ + dcc.Store(id="store", data="initial"), + html.Div(id="output"), + ] + ) + + @app.callback( + Output("output", "children"), + State("store", "data"), + ) + def no_input_callback(data): + return f"Data: {data}" + + dash_duo.start_server(app) + + # No-input callback should fire on initial load + dash_duo.wait_for_text_to_equal("#output", "Data: initial") + + assert dash_duo.get_logs() == [] + + +def test_cbsc024_no_input_no_output_callback_initial_call(dash_duo): + """Test that callbacks with no input and no output fire on initial load.""" + from multiprocessing import Value + + call_count = Value("i", 0) + + app = Dash(__name__) + app.layout = html.Div( + [ + html.Div(id="output", children="Waiting..."), + ] + ) + + @app.callback() + def no_input_no_output_callback(): + call_count.value += 1 + print(f"No-input no-output callback fired: {call_count.value}") + + dash_duo.start_server(app) + + # Give it time to fire + dash_duo.wait_for_element("#output") + time.sleep(0.5) + + # Callback should have fired on initial load + assert ( + call_count.value == 1 + ), "no-input no-output callback should fire on initial load" + + assert dash_duo.get_logs() == [] + + +def test_cbsc025_multiple_no_input_no_output_callbacks(dash_duo): + """Test that multiple no-input no-output callbacks all fire on initial load.""" + from multiprocessing import Value + + call_count_1 = Value("i", 0) + call_count_2 = Value("i", 0) + call_count_3 = Value("i", 0) + + app = Dash(__name__) + app.layout = html.Div( + [ + html.Div(id="output", children="Waiting..."), + ] + ) + + @app.callback() + def first_callback(): + call_count_1.value += 1 + + @app.callback() + def second_callback(): + call_count_2.value += 1 + + @app.callback() + def third_callback(): + call_count_3.value += 1 + + dash_duo.start_server(app) + + # Give callbacks time to fire + dash_duo.wait_for_element("#output") + time.sleep(0.5) + + # All callbacks should have fired on initial load + assert call_count_1.value == 1, "first callback should fire" + assert call_count_2.value == 1, "second callback should fire" + assert call_count_3.value == 1, "third callback should fire" + + assert dash_duo.get_logs() == [] + + +def test_cbsc026_no_input_with_duplicate_outputs(dash_duo): + """Test no-input callbacks with duplicate outputs.""" + from multiprocessing import Value + + call_count_1 = Value("i", 0) + call_count_2 = Value("i", 0) + + app = Dash(__name__) + app.layout = html.Div( + [ + dcc.Store(id="store", data="initial"), + html.Div(id="output", children="Waiting..."), + ] + ) + + @app.callback( + Output("output", "children"), + State("store", "data"), + ) + def first_no_input_callback(data): + call_count_1.value += 1 + return f"First: {data}" + + @app.callback( + Output("output", "children", allow_duplicate=True), + State("store", "data"), + prevent_initial_call="initial_duplicate", + ) + def second_no_input_callback(data): + call_count_2.value += 1 + return f"Second: {data}" + + dash_duo.start_server(app) + + # Give callbacks time to fire + dash_duo.wait_for_element("#output") + time.sleep(0.5) + + # Both callbacks should have fired on initial load + assert call_count_1.value == 1, "first no-input callback should fire" + assert call_count_2.value == 1, "second no-input callback should fire" + + # Output should contain result from one of the callbacks + output_text = dash_duo.find_element("#output").text + assert "initial" in output_text, "output should contain data from store" + + assert dash_duo.get_logs() == [] From b1a6f0f659058a6fc63719008c8366542093f940 Mon Sep 17 00:00:00 2001 From: philippe Date: Mon, 11 May 2026 10:09:44 -0400 Subject: [PATCH 161/166] adapt test for new behavior --- .../devtools/test_callback_validation.py | 18 +++++++++++++++++- tests/integration/renderer/test_render_type.py | 1 + 2 files changed, 18 insertions(+), 1 deletion(-) diff --git a/tests/integration/devtools/test_callback_validation.py b/tests/integration/devtools/test_callback_validation.py index eaee814980..8501821886 100644 --- a/tests/integration/devtools/test_callback_validation.py +++ b/tests/integration/devtools/test_callback_validation.py @@ -69,10 +69,26 @@ def check_errors(dash_duo, specs): def test_dvcv001_blank(dash_duo): + """No-input no-output callbacks are allowed when prevent_initial_call=False (default).""" app = Dash(__name__) app.layout = html.Div() - @app.callback([], []) + @app.callback() + def x(): + pass # No-output callbacks shouldn't return anything + + dash_duo.start_server(app, **debugging) + # No errors expected - no-input callbacks are allowed when prevent_initial_call=False + dash_duo.wait_for_element("div") + assert dash_duo.get_logs() == [] + + +def test_dvcv001b_blank_prevent_initial_call(dash_duo): + """No-input callbacks should error when prevent_initial_call=True.""" + app = Dash(__name__) + app.layout = html.Div() + + @app.callback([], [], prevent_initial_call=True) def x(): return 42 diff --git a/tests/integration/renderer/test_render_type.py b/tests/integration/renderer/test_render_type.py index 17a6cfbae3..417be6e586 100644 --- a/tests/integration/renderer/test_render_type.py +++ b/tests/integration/renderer/test_render_type.py @@ -25,6 +25,7 @@ def test_rtype001_rendertype(dash_duo): dash_clientside.set_props('render_test', {n_clicks: 20}) }""", Input("clientside_render", "n_clicks"), + prevent_initial_call=True, ) @app.callback( From 512e1191ee1eb97a919135c64e0b8d221b24932d Mon Sep 17 00:00:00 2001 From: philippe Date: Mon, 11 May 2026 15:47:39 -0400 Subject: [PATCH 162/166] rename ctx.get_websocket to ctx.websocket --- .ai/ARCHITECTURE.md | 6 +++--- dash/_callback_context.py | 2 +- tests/websocket/test_ws_basic.py | 2 +- tests/websocket/test_ws_props.py | 4 ++-- tests/websocket/test_ws_quart.py | 2 +- wsapp.py | 2 +- 6 files changed, 9 insertions(+), 9 deletions(-) diff --git a/.ai/ARCHITECTURE.md b/.ai/ARCHITECTURE.md index db40397782..84f553978c 100644 --- a/.ai/ARCHITECTURE.md +++ b/.ai/ARCHITECTURE.md @@ -963,7 +963,7 @@ Renderer SharedWorker Server ### Long-Running Callbacks with set_props/get_props -WebSocket callbacks can stream updates to the client during execution using `set_props()` and read current component values using `ctx.get_websocket()`: +WebSocket callbacks can stream updates to the client during execution using `set_props()` and read current component values using `ctx.websocket`: ```python import asyncio @@ -975,7 +975,7 @@ from dash import callback, Output, Input, set_props, ctx prevent_initial_call=True ) async def long_running_task(n_clicks): - ws = ctx.get_websocket() + ws = ctx.websocket if not ws: return "WebSocket not available" @@ -993,7 +993,7 @@ async def long_running_task(n_clicks): **API:** - `set_props(component_id, props_dict)` - Stream prop updates immediately to client -- `ctx.get_websocket()` - Get WebSocket interface (returns `None` if not in WS context) +- `ctx.websocket` - Get WebSocket interface (returns `None` if not in WS context) - `await ws.get_prop(component_id, prop_name)` - Read current prop value from client - `await ws.set_prop(component_id, prop_name, value)` - Set single prop (async version) - `await ws.close(code, reason)` - Close the WebSocket connection diff --git a/dash/_callback_context.py b/dash/_callback_context.py index e03f343129..809def45ef 100644 --- a/dash/_callback_context.py +++ b/dash/_callback_context.py @@ -328,7 +328,7 @@ def custom_data(self): @property @has_context - def get_websocket(self) -> typing.Optional[DashWebsocketCallback]: + def websocket(self) -> typing.Optional[DashWebsocketCallback]: """Get WebSocket interface if running in WebSocket context. Returns the DashWebsocketCallback instance if the callback is being diff --git a/tests/websocket/test_ws_basic.py b/tests/websocket/test_ws_basic.py index 935d633339..1d74706a68 100644 --- a/tests/websocket/test_ws_basic.py +++ b/tests/websocket/test_ws_basic.py @@ -179,7 +179,7 @@ def test_ws005_websocket_context_available(dash_duo): def check_context(n_clicks): if not n_clicks: return "Click to check" - ws = ctx.get_websocket + ws = ctx.websocket if ws is not None: return "WebSocket context available" return "No WebSocket context" diff --git a/tests/websocket/test_ws_props.py b/tests/websocket/test_ws_props.py index e800668ae8..a86402954d 100644 --- a/tests/websocket/test_ws_props.py +++ b/tests/websocket/test_ws_props.py @@ -185,7 +185,7 @@ async def read_prop(n): from dash import ctx - ws = ctx.get_websocket + ws = ctx.websocket if ws: value = await ws.get_prop("source", "children") return f"Read: {value}" @@ -219,7 +219,7 @@ async def set_via_ws(n): from dash import ctx - ws = ctx.get_websocket + ws = ctx.websocket if ws: await ws.set_prop("target", "children", f"Set via WebSocket {n}") return "Set complete" diff --git a/tests/websocket/test_ws_quart.py b/tests/websocket/test_ws_quart.py index 30e33b329c..3d40493ba5 100644 --- a/tests/websocket/test_ws_quart.py +++ b/tests/websocket/test_ws_quart.py @@ -176,7 +176,7 @@ def test_wsq005_websocket_context_available_quart(dash_duo): def check_context(n_clicks): if not n_clicks: return "Click to check" - ws = ctx.get_websocket + ws = ctx.websocket if ws is not None: return "WebSocket context available" return "No WebSocket context" diff --git a/wsapp.py b/wsapp.py index 98b2db2f38..ade2f80d39 100644 --- a/wsapp.py +++ b/wsapp.py @@ -82,7 +82,7 @@ def update_with_set_props(n_clicks): @callback(Output("output-6", "children"), Input("btn-3", "n_clicks")) def check_websocket_context(n_clicks): if n_clicks > 0: - ws = ctx.get_websocket + ws = ctx.websocket if ws is not None: return f"WebSocket context is available! (click {n_clicks})" else: From 995cb1db2ac89789c75e19b9b1e99b3031598981 Mon Sep 17 00:00:00 2001 From: philippe Date: Tue, 12 May 2026 09:20:31 -0400 Subject: [PATCH 163/166] reduce initial special callbacks --- dash/dash-renderer/src/actions/index.js | 86 ++++++++++++------------- 1 file changed, 41 insertions(+), 45 deletions(-) diff --git a/dash/dash-renderer/src/actions/index.js b/dash/dash-renderer/src/actions/index.js index 51229767ba..e595f79f9a 100644 --- a/dash/dash-renderer/src/actions/index.js +++ b/dash/dash-renderer/src/actions/index.js @@ -109,54 +109,50 @@ function triggerDefaultState(dispatch, getState) { } ); - // Also include no-output callbacks whose inputs are in the layout (or have no inputs) - const noOutputCallbacks = (graphs.callbacks || []) - .filter(cb => cb.noOutput && !cb.prevent_initial_call) - .map(cb => { - const resolved = makeResolvedCallback(cb, resolveDeps(), ''); - resolved.initialCall = true; - return resolved; - }) - .filter(cb => { - // If no inputs, always include (fires once on initial load) - if (cb.callback.inputs.length === 0) { - return true; + // Also include no-output and no-input callbacks that should fire on initial load + const specialCallbacks = (graphs.callbacks || []).reduce((acc, cb) => { + if (cb.prevent_initial_call) { + return acc; + } + + const isNoOutput = cb.noOutput; + const isNoInput = !cb.noOutput && cb.inputs.length === 0; + + if (!isNoOutput && !isNoInput) { + return acc; + } + + const resolved = makeResolvedCallback(cb, resolveDeps(), ''); + resolved.initialCall = true; + + if (isNoOutput) { + // No-output: include if no inputs or any input is in layout + if (cb.inputs.length === 0) { + acc.push(resolved); + } else { + const inputs = resolved.getInputs(paths); + if ( + inputs.some(inp => + Array.isArray(inp) ? inp.length > 0 : inp + ) + ) { + acc.push(resolved); + } } - // Check if any input is in the layout - const inputs = cb.getInputs(paths); - return inputs.some(inp => - Array.isArray(inp) ? inp.length > 0 : inp - ); - }); + } else { + // No-input: include if any output is in layout + const outputs = resolved.getOutputs(paths); + if ( + outputs.some(out => (Array.isArray(out) ? out.length > 0 : out)) + ) { + acc.push(resolved); + } + } - // Also include no-input callbacks (with outputs) that should fire on initial load - const noInputCallbacks = (graphs.callbacks || []) - .filter( - cb => - !cb.noOutput && - cb.inputs.length === 0 && - !cb.prevent_initial_call - ) - .map(cb => { - const resolved = makeResolvedCallback(cb, resolveDeps(), ''); - resolved.initialCall = true; - return resolved; - }) - .filter(cb => { - // Check if any output is in the layout - const outputs = cb.getOutputs(paths); - return outputs.some(out => - Array.isArray(out) ? out.length > 0 : out - ); - }); + return acc; + }, []); - dispatch( - addRequestedCallbacks([ - ...layoutCallbacks, - ...noOutputCallbacks, - ...noInputCallbacks - ]) - ); + dispatch(addRequestedCallbacks([...layoutCallbacks, ...specialCallbacks])); } export const redo = moveHistory('REDO'); From 78cf94ac75b19d8bf1fa8840b4a15a5fcd141c6e Mon Sep 17 00:00:00 2001 From: philippe Date: Tue, 12 May 2026 09:36:20 -0400 Subject: [PATCH 164/166] improved callback id logic for no input/output --- dash/_utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/dash/_utils.py b/dash/_utils.py index 85ff9ab073..1ab2036820 100644 --- a/dash/_utils.py +++ b/dash/_utils.py @@ -170,6 +170,7 @@ def _concat(x): # Get the call site of the @callback decorator stack = inspect.stack() # Walk up the stack to find the actual callback call site + # Fallback to empty hash if no external frame found # (skip internal dash package frames) dash_package_path = os.path.dirname(__file__) for frame_info in stack: @@ -177,8 +178,7 @@ def _concat(x): if not frame_info.filename.startswith(dash_package_path): call_site = f"{frame_info.filename}:{frame_info.lineno}" return hashlib.sha256(call_site.encode("utf-8")).hexdigest() - # Fallback to empty hash if no external frame found - return _hash_inputs() + return _hash_inputs() if isinstance(output, (list, tuple)): From 64056cfa593e914690562a207f7a0fc6406ced50 Mon Sep 17 00:00:00 2001 From: philippe Date: Tue, 12 May 2026 10:17:32 -0400 Subject: [PATCH 165/166] Version 4.2.0rc3 --- CHANGELOG.md | 7 ++++++- dash/version.py | 2 +- 2 files changed, 7 insertions(+), 2 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 19db4b2d33..74925de692 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -20,7 +20,12 @@ This project adheres to [Semantic Versioning](https://semver.org/). - [#2462](https://github.com/plotly/dash/issues/2462) Allow `MATCH` in `Input`/`State` when the callback's `Output` has no wildcards (fixed-id Output, no Output, or `ALL`-only wildcard Output). `ALLSMALLER` still requires a corresponding `MATCH` in an Output. - [#3759](https://github.com/plotly/dash/pull/3759) Fix the issue where `Patch` objects cannot be updated via `set_props()` in `websocket` callback. Fix [#3742](https://github.com/plotly/dash/issues/3742) -## [4.2.0rc1] - 2026-05-01 +## [4.2.0rc3] - 2026-05-12 + +- [#3771](https://github.com/plotly/dash/pull/3771) Add persistent callbacks and no inputs/no outputs callback support. +- Rename ctx.get_websocket to ctx.websocket + +## [4.2.0rc2] - 2026-05-01 ## Fixed - [#3759](https://github.com/plotly/dash/pull/3759) Fix the error when using `set_props()` to update component-type properties in the `websocket` callback. diff --git a/dash/version.py b/dash/version.py index 6af77684e6..d39d52c3b8 100644 --- a/dash/version.py +++ b/dash/version.py @@ -1 +1 @@ -__version__ = "4.2.0rc2" +__version__ = "4.2.0rc3" From 5b2088c9f5f8f250f27e4bad55bf396fae0ecaff Mon Sep 17 00:00:00 2001 From: philippe Date: Tue, 12 May 2026 15:15:33 -0400 Subject: [PATCH 166/166] remove test apps --- quart_app.py | 23 ------ r19.py | 222 --------------------------------------------------- wsapp.py | 106 ------------------------ wscb.py | 68 ---------------- 4 files changed, 419 deletions(-) delete mode 100644 quart_app.py delete mode 100644 r19.py delete mode 100644 wsapp.py delete mode 100644 wscb.py diff --git a/quart_app.py b/quart_app.py deleted file mode 100644 index 54d40add56..0000000000 --- a/quart_app.py +++ /dev/null @@ -1,23 +0,0 @@ -from dash import Dash, html, Input, Output -from dash import dcc -from dash import backends - -app = Dash(__name__, backend="quart") - -app.layout = html.Div( - [ - html.H2("Quart Server Factory Example"), - html.Div("Type below to see async callback update."), - dcc.Input(id="text", value="hello", autoComplete="off"), - html.Div(id="echo"), - ] -) - - -@app.callback(Output("echo", "children"), Input("text", "value")) -def update_echo(val): - return f"You typed: {val}" if val else "Type something" - - -if __name__ == "__main__": - app.run(debug=True) diff --git a/r19.py b/r19.py deleted file mode 100644 index 815d2a5066..0000000000 --- a/r19.py +++ /dev/null @@ -1,222 +0,0 @@ -""" -React 19 test app with most Dash components. -Run with: python r19.py -""" - -import os -os.environ["REACT_VERSION"] = "19.2.0" - -from dash import Dash, html, dcc, dash_table, callback, Input, Output -import plotly.express as px -import pandas as pd - -# Sample data -df = pd.DataFrame({ - "Fruit": ["Apples", "Oranges", "Bananas", "Grapes", "Strawberries"], - "Amount": [4, 2, 5, 3, 6], - "City": ["NYC", "LA", "Chicago", "Houston", "Phoenix"] -}) - -app = Dash(__name__) - -app.layout = html.Div([ - html.H1("React 19 Component Test"), - html.P(f"Running React version: {os.environ.get('REACT_VERSION')}"), - - html.Hr(), - html.H2("Core HTML Components"), - html.Div([ - html.Button("Click Me", id="button", n_clicks=0), - html.Span(" Clicks: ", style={"marginLeft": "10px"}), - html.Span(id="click-output", children="0"), - ]), - - html.Hr(), - html.H2("Input Components"), - html.Div([ - html.Label("Text Input:"), - dcc.Input(id="text-input", type="text", placeholder="Type something...", debounce=True), - html.Div(id="text-output"), - ], style={"marginBottom": "20px"}), - - html.Div([ - html.Label("Dropdown:"), - dcc.Dropdown( - id="dropdown", - options=[{"label": f, "value": f} for f in df["Fruit"]], - value="Apples", - clearable=True, - ), - html.Div(id="dropdown-output"), - ], style={"marginBottom": "20px", "width": "300px"}), - - html.Div([ - html.Label("Multi-Select Dropdown:"), - dcc.Dropdown( - id="multi-dropdown", - options=[{"label": f, "value": f} for f in df["Fruit"]], - value=["Apples", "Oranges"], - multi=True, - ), - ], style={"marginBottom": "20px", "width": "300px"}), - - html.Div([ - html.Label("Slider:"), - dcc.Slider(id="slider", min=0, max=10, step=1, value=5, marks={i: str(i) for i in range(11)}), - html.Div(id="slider-output"), - ], style={"marginBottom": "20px", "width": "400px"}), - - html.Div([ - html.Label("Range Slider:"), - dcc.RangeSlider(id="range-slider", min=0, max=100, step=10, value=[20, 80]), - ], style={"marginBottom": "20px", "width": "400px"}), - - html.Div([ - html.Label("Radio Items:"), - dcc.RadioItems( - id="radio", - options=[{"label": c, "value": c} for c in df["City"]], - value="NYC", - inline=True, - ), - ], style={"marginBottom": "20px"}), - - html.Div([ - html.Label("Checklist:"), - dcc.Checklist( - id="checklist", - options=[{"label": c, "value": c} for c in df["City"]], - value=["NYC", "LA"], - inline=True, - ), - ], style={"marginBottom": "20px"}), - - html.Div([ - html.Label("Date Picker:"), - dcc.DatePickerSingle(id="date-picker", date="2024-01-15"), - ], style={"marginBottom": "20px"}), - - html.Div([ - html.Label("Date Range Picker:"), - dcc.DatePickerRange( - id="date-range", - start_date="2024-01-01", - end_date="2024-12-31", - ), - ], style={"marginBottom": "20px"}), - - html.Div([ - html.Label("Textarea:"), - dcc.Textarea(id="textarea", value="Some text here...", style={"width": "300px", "height": "100px"}), - ], style={"marginBottom": "20px"}), - - html.Hr(), - html.H2("Graph Component"), - dcc.Graph( - id="graph", - figure=px.bar(df, x="Fruit", y="Amount", color="City", title="Fruit Amounts by City") - ), - - html.Hr(), - html.H2("DataTable"), - dash_table.DataTable( - id="table", - columns=[{"name": c, "id": c} for c in df.columns], - data=df.to_dict("records"), - editable=True, - filter_action="native", - sort_action="native", - row_selectable="multi", - page_size=10, - ), - - html.Hr(), - html.H2("Tabs"), - dcc.Tabs(id="tabs", value="tab-1", children=[ - dcc.Tab(label="Tab 1", value="tab-1", children=[ - html.Div("Content for Tab 1", style={"padding": "20px"}) - ]), - dcc.Tab(label="Tab 2", value="tab-2", children=[ - html.Div("Content for Tab 2", style={"padding": "20px"}) - ]), - ]), - - html.Hr(), - html.H2("Loading Component"), - dcc.Loading( - id="loading", - type="circle", - children=html.Div(id="loading-output", children="Content loaded!") - ), - - html.Hr(), - html.H2("Markdown"), - dcc.Markdown(""" - ### This is Markdown - - - Item 1 - - Item 2 - - **Bold text** - - *Italic text* - - ```python - def hello(): - return "Hello, React 19!" - ``` - """), - - html.Hr(), - html.H2("Store & Interval"), - dcc.Store(id="store", data={"count": 0}), - dcc.Interval(id="interval", interval=5000, n_intervals=0, disabled=True), - html.Div(id="interval-output", children="Interval disabled"), - - html.Hr(), - html.H2("Clipboard"), - dcc.Clipboard(id="clipboard", target_id="text-input", style={"fontSize": "20px"}), - - html.Hr(), - html.H2("Tooltip"), - html.Div([ - html.Span("Hover over the graph points to see tooltips", style={"fontStyle": "italic"}), - ]), - - html.Br(), - html.Br(), -], style={"padding": "20px", "maxWidth": "800px", "margin": "0 auto"}) - - -@callback( - Output("click-output", "children"), - Input("button", "n_clicks") -) -def update_clicks(n): - return str(n) - - -@callback( - Output("text-output", "children"), - Input("text-input", "value") -) -def update_text(value): - return f"You typed: {value}" if value else "" - - -@callback( - Output("dropdown-output", "children"), - Input("dropdown", "value") -) -def update_dropdown(value): - return f"Selected: {value}" if value else "Nothing selected" - - -@callback( - Output("slider-output", "children"), - Input("slider", "value") -) -def update_slider(value): - return f"Slider value: {value}" - - -if __name__ == "__main__": - app.run(debug=True, port=8050) diff --git a/wsapp.py b/wsapp.py deleted file mode 100644 index ade2f80d39..0000000000 --- a/wsapp.py +++ /dev/null @@ -1,106 +0,0 @@ -""" -Test app for WebSocket-based callbacks. - -Run with: - python wsapp.py - -Then open http://127.0.0.1:8050 in your browser. -""" - -from dash import Dash, html, dcc, callback, Output, Input, ctx -import time - -# Create app with FastAPI backend and WebSocket callbacks enabled -app = Dash( - __name__, - backend="fastapi", - websocket_callbacks=True, -) - -app.layout = html.Div([ - html.H1("WebSocket Callbacks Test"), - - html.Div([ - html.H3("Basic Callback Test"), - html.Button("Click me", id="btn-1", n_clicks=0), - html.Div(id="output-1"), - ], style={"marginBottom": "20px", "padding": "10px", "border": "1px solid #ccc"}), - - html.Div([ - html.H3("Input Test"), - dcc.Input(id="input-1", type="text", placeholder="Type something..."), - html.Div(id="output-2"), - ], style={"marginBottom": "20px", "padding": "10px", "border": "1px solid #ccc"}), - - html.Div([ - html.H3("Slider Test"), - dcc.Slider(id="slider-1", min=0, max=100, value=50), - html.Div(id="output-3"), - ], style={"marginBottom": "20px", "padding": "10px", "border": "1px solid #ccc"}), - - html.Div([ - html.H3("set_props Test"), - html.Button("Update via set_props", id="btn-2", n_clicks=0), - html.Div(id="output-4", children="Initial content"), - html.Div(id="output-5", children="Will be updated by set_props"), - ], style={"marginBottom": "20px", "padding": "10px", "border": "1px solid #ccc"}), - - html.Div([ - html.H3("WebSocket Context Test"), - html.Button("Check WebSocket Context", id="btn-3", n_clicks=0), - html.Div(id="output-6"), - ], style={"marginBottom": "20px", "padding": "10px", "border": "1px solid #ccc"}), - - html.Div(id="config-display", style={"marginTop": "20px", "fontSize": "12px", "color": "#666"}), -]) - - -@callback(Output("output-1", "children"), Input("btn-1", "n_clicks")) -def update_output_1(n_clicks): - return f"Button clicked {n_clicks} times" - - -@callback(Output("output-2", "children"), Input("input-1", "value")) -def update_output_2(value): - return f"You typed: {value}" - - -@callback(Output("output-3", "children"), Input("slider-1", "value")) -def update_output_3(value): - return f"Slider value: {value}" - - -@callback(Output("output-4", "children"), Input("btn-2", "n_clicks")) -def update_with_set_props(n_clicks): - if n_clicks > 0: - # Use set_props to update another component - from dash._callback_context import set_props - set_props("output-5", {"children": f"Updated via set_props at click {n_clicks}"}) - return f"set_props button clicked {n_clicks} times" - - -@callback(Output("output-6", "children"), Input("btn-3", "n_clicks")) -def check_websocket_context(n_clicks): - if n_clicks > 0: - ws = ctx.websocket - if ws is not None: - return f"WebSocket context is available! (click {n_clicks})" - else: - return f"WebSocket context is None (click {n_clicks}) - may be using HTTP fallback" - return "Click to check WebSocket context" - - -@callback(Output("config-display", "children"), Input("btn-1", "n_clicks")) -def show_config(n_clicks): - config = app._config() - ws_config = config.get("websocket", {}) - if ws_config: - return f"WebSocket enabled: {ws_config.get('enabled')}, URL: {ws_config.get('url')}" - return "WebSocket not configured" - - -if __name__ == "__main__": - print("Starting WebSocket callbacks test app...") - print(f"WebSocket callbacks enabled: {app._websocket_callbacks}") - print(f"Backend websocket capability: {app.backend.websocket_capability}") - app.run(debug=True, port=8050) diff --git a/wscb.py b/wscb.py deleted file mode 100644 index 629ed3cdc1..0000000000 --- a/wscb.py +++ /dev/null @@ -1,68 +0,0 @@ -""" -Test app for per-callback WebSocket support. - -This app demonstrates using websocket=True on specific callbacks -without enabling global websocket_callbacks. -""" - -from dash import Dash, html, dcc, callback, Input, Output, State - -app = Dash(__name__, backend="fastapi") - -app.layout = html.Div([ - html.H1("Per-Callback WebSocket Test"), - - html.Div([ - html.H3("WebSocket Callback"), - dcc.Input(id="ws-input", type="text", placeholder="Type here..."), - html.Div(id="ws-output", style={"padding": "10px", "background": "#e0ffe0"}) - ], style={"margin": "20px", "padding": "20px", "border": "1px solid #ccc"}), - - html.Div([ - html.H3("HTTP Callback (default)"), - dcc.Input(id="http-input", type="text", placeholder="Type here..."), - html.Div(id="http-output", style={"padding": "10px", "background": "#e0e0ff"}) - ], style={"margin": "20px", "padding": "20px", "border": "1px solid #ccc"}), - - html.Div([ - html.H3("WebSocket Counter"), - html.Button("Increment", id="ws-btn"), - html.Div(id="ws-counter", children="0", style={"padding": "10px", "background": "#ffe0e0"}) - ], style={"margin": "20px", "padding": "20px", "border": "1px solid #ccc"}), -]) - - -@callback( - Output("ws-output", "children"), - Input("ws-input", "value"), - websocket=True -) -def ws_callback(value): - """This callback uses WebSocket.""" - return f"[WebSocket] You typed: {value or ''}" - - -@callback( - Output("http-output", "children"), - Input("http-input", "value") -) -def http_callback(value): - """This callback uses HTTP (default).""" - return f"[HTTP] You typed: {value or ''}" - - -@callback( - Output("ws-counter", "children"), - Input("ws-btn", "n_clicks"), - State("ws-counter", "children"), - websocket=True -) -def ws_counter(n_clicks, current): - """WebSocket counter callback.""" - if n_clicks is None: - return "0" - return str(int(current or 0) + 1) - - -if __name__ == "__main__": - app.run(debug=True)