Skip to content

Commit fbefbc9

Browse files
authored
Merge pull request #4 from chgiesse/quart-factory
Quart factory
2 parents b14f6d2 + 1824e11 commit fbefbc9

File tree

2 files changed

+257
-12
lines changed

2 files changed

+257
-12
lines changed
Lines changed: 13 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,49 +1,50 @@
11
from abc import ABC, abstractmethod
2+
from typing import Any
23

34

45
class BaseServerFactory(ABC):
5-
def __call__(self, server, *args, **kwargs):
6+
def __call__(self, server, *args, **kwargs) -> Any:
67
# Default: WSGI
78
return server(*args, **kwargs)
89

910
@abstractmethod
10-
def create_app(self, name="__main__", config=None):
11+
def create_app(self, name: str = "__main__", config=None) -> Any: # pragma: no cover - interface
1112
pass
1213

1314
@abstractmethod
1415
def register_assets_blueprint(
15-
self, app, blueprint_name, assets_url_path, assets_folder
16-
):
16+
self, app, blueprint_name: str, assets_url_path: str, assets_folder: str
17+
) -> None: # pragma: no cover - interface
1718
pass
1819

1920
@abstractmethod
20-
def register_error_handlers(self, app):
21+
def register_error_handlers(self, app) -> None: # pragma: no cover - interface
2122
pass
2223

2324
@abstractmethod
24-
def add_url_rule(self, app, rule, view_func, endpoint=None, methods=None):
25+
def add_url_rule(self, app, rule: str, view_func, endpoint=None, methods=None) -> None: # pragma: no cover - interface
2526
pass
2627

2728
@abstractmethod
28-
def before_request(self, app, func):
29+
def before_request(self, app, func) -> None: # pragma: no cover - interface
2930
pass
3031

3132
@abstractmethod
32-
def after_request(self, app, func):
33+
def after_request(self, app, func) -> None: # pragma: no cover - interface
3334
pass
3435

3536
@abstractmethod
36-
def run(self, app, host, port, debug, **kwargs):
37+
def run(self, app, host: str, port: int, debug: bool, **kwargs) -> None: # pragma: no cover - interface
3738
pass
3839

3940
@abstractmethod
40-
def make_response(self, data, mimetype=None, content_type=None):
41+
def make_response(self, data, mimetype=None, content_type=None) -> Any: # pragma: no cover - interface
4142
pass
4243

4344
@abstractmethod
44-
def jsonify(self, obj):
45+
def jsonify(self, obj) -> Any: # pragma: no cover - interface
4546
pass
4647

4748
@abstractmethod
48-
def get_request_adapter(self):
49+
def get_request_adapter(self) -> Any: # pragma: no cover - interface
4950
pass
Lines changed: 244 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,244 @@
1+
from .base_factory import BaseServerFactory
2+
from quart import Quart, Request, Response, jsonify, request
3+
from dash.exceptions import PreventUpdate, InvalidResourceError
4+
from dash.server_factories import set_request_adapter
5+
from dash.fingerprint import check_fingerprint
6+
from dash import _validate
7+
from contextvars import copy_context
8+
import inspect
9+
import pkgutil
10+
import mimetypes
11+
import sys
12+
import time
13+
14+
15+
class QuartAPIServerFactory(BaseServerFactory):
16+
"""Quart implementation of the Dash server factory.
17+
18+
All Quart/async specific imports are at the top-level (per user request) so
19+
Quart must be installed when this module is imported.
20+
"""
21+
22+
def __init__(self) -> None:
23+
self.config = {}
24+
super().__init__()
25+
26+
def __call__(self, server, *args, **kwargs):
27+
return super().__call__(server, *args, **kwargs)
28+
29+
def create_app(self, name="__main__", config=None):
30+
app = Quart(name)
31+
if config:
32+
for key, value in config.items():
33+
app.config[key] = value
34+
return app
35+
36+
def register_assets_blueprint(
37+
self, app, blueprint_name, assets_url_path, assets_folder
38+
):
39+
from quart import Blueprint
40+
41+
bp = Blueprint(
42+
blueprint_name,
43+
__name__,
44+
static_folder=assets_folder,
45+
static_url_path=assets_url_path,
46+
)
47+
app.register_blueprint(bp)
48+
49+
def register_prune_error_handler(self, app, secret, get_traceback_func):
50+
@app.errorhandler(Exception)
51+
async def _wrap_errors(_error_request, error):
52+
tb = get_traceback_func(secret, error)
53+
return tb, 500
54+
55+
def register_timing_hooks(self, app, _first_run): # parity with Flask factory
56+
from quart import g
57+
58+
@app.before_request
59+
async def _before_request(): # pragma: no cover - timing infra
60+
g.timing_information = {"__dash_server": {"dur": time.time(), "desc": None}}
61+
62+
@app.after_request
63+
async def _after_request(response): # pragma: no cover - timing infra
64+
timing_information = getattr(g, "timing_information", None)
65+
if timing_information is None:
66+
return response
67+
dash_total = timing_information.get("__dash_server", None)
68+
if dash_total is not None:
69+
dash_total["dur"] = round((time.time() - dash_total["dur"]) * 1000)
70+
for name, info in timing_information.items():
71+
value = name
72+
if info.get("desc") is not None:
73+
value += f';desc="{info["desc"]}"'
74+
if info.get("dur") is not None:
75+
value += f";dur={info['dur']}"
76+
# Quart/Werkzeug headers expose 'add' (not 'append')
77+
if hasattr(response.headers, "add"):
78+
response.headers.add("Server-Timing", value)
79+
else: # fallback just in case
80+
response.headers["Server-Timing"] = value
81+
return response
82+
83+
def register_error_handlers(self, app):
84+
@app.errorhandler(PreventUpdate)
85+
async def _prevent_update(_):
86+
return "", 204
87+
88+
@app.errorhandler(InvalidResourceError)
89+
async def _invalid_resource(err):
90+
return err.args[0], 404
91+
92+
def _html_response_wrapper(self, view_func):
93+
async def wrapped(*args, **kwargs):
94+
html_val = view_func() if callable(view_func) else view_func
95+
if inspect.iscoroutine(html_val): # handle async function returning html
96+
html_val = await html_val
97+
html = str(html_val)
98+
return Response(html, content_type="text/html")
99+
100+
return wrapped
101+
102+
def add_url_rule(self, app, rule, view_func, endpoint=None, methods=None):
103+
app.add_url_rule(
104+
rule, view_func=view_func, endpoint=endpoint, methods=methods or ["GET"]
105+
)
106+
107+
def setup_index(self, dash_app):
108+
async def index():
109+
adapter = QuartRequestAdapter()
110+
set_request_adapter(adapter)
111+
adapter.set_request(request)
112+
return Response(dash_app.index(), content_type="text/html")
113+
114+
dash_app._add_url("", index, methods=["GET"])
115+
116+
def setup_catchall(self, dash_app):
117+
async def catchall(path): # noqa: ARG001 - path is unused but kept for route signature
118+
adapter = QuartRequestAdapter()
119+
set_request_adapter(adapter)
120+
adapter.set_request(request)
121+
return Response(dash_app.index(), content_type="text/html")
122+
123+
dash_app._add_url("<path:path>", catchall, methods=["GET"])
124+
125+
def before_request(self, app, func):
126+
app.before_request(func)
127+
128+
def after_request(self, app, func):
129+
@app.after_request
130+
async def _after(response):
131+
if func is not None:
132+
result = func()
133+
if inspect.iscoroutine(result): # Allow async hooks
134+
await result
135+
return response
136+
137+
def run(self, app, host, port, debug, **kwargs):
138+
self.config = {'debug': debug, **kwargs} if debug else kwargs
139+
app.run(host=host, port=port, debug=debug, **kwargs)
140+
141+
def make_response(self, data, mimetype=None, content_type=None):
142+
return Response(data, mimetype=mimetype, content_type=content_type)
143+
144+
def jsonify(self, obj):
145+
return jsonify(obj)
146+
147+
def get_request_adapter(self):
148+
return QuartRequestAdapter
149+
150+
def serve_component_suites(self, dash_app, package_name, fingerprinted_path, req): # noqa: ARG002 unused req preserved for interface parity
151+
path_in_pkg, has_fingerprint = check_fingerprint(fingerprinted_path)
152+
_validate.validate_js_path(dash_app.registered_paths, package_name, path_in_pkg)
153+
extension = "." + path_in_pkg.split(".")[-1]
154+
mimetype = mimetypes.types_map.get(extension, "application/octet-stream")
155+
package = sys.modules[package_name]
156+
dash_app.logger.debug(
157+
"serving -- package: %s[%s] resource: %s => location: %s",
158+
package_name,
159+
getattr(package, "__version__", "unknown"),
160+
path_in_pkg,
161+
package.__path__,
162+
)
163+
data = pkgutil.get_data(package_name, path_in_pkg)
164+
headers = {}
165+
if has_fingerprint:
166+
headers["Cache-Control"] = "public, max-age=31536000"
167+
168+
return Response(data, content_type=mimetype, headers=headers)
169+
170+
def setup_component_suites(self, dash_app):
171+
async def serve(package_name, fingerprinted_path):
172+
return self.serve_component_suites(
173+
dash_app, package_name, fingerprinted_path, request
174+
)
175+
176+
dash_app._add_url(
177+
"_dash-component-suites/<string:package_name>/<path:fingerprinted_path>",
178+
serve,
179+
)
180+
181+
def dispatch(self, app, dash_app, use_async=True): # Quart always async
182+
async def _dispatch():
183+
adapter = QuartRequestAdapter()
184+
set_request_adapter(adapter)
185+
adapter.set_request(request)
186+
body = await request.get_json()
187+
g = dash_app._initialize_context(body, adapter)
188+
func = dash_app._prepare_callback(g, body)
189+
args = dash_app._inputs_to_vals(g.inputs_list + g.states_list)
190+
ctx = copy_context()
191+
partial_func = dash_app._execute_callback(func, args, g.outputs_list, g)
192+
response_data = ctx.run(partial_func)
193+
if inspect.iscoroutine(response_data): # if user callback is async
194+
response_data = await response_data
195+
return Response(response_data, content_type="application/json")
196+
197+
return _dispatch
198+
199+
def _serve_default_favicon(self):
200+
return Response(
201+
pkgutil.get_data("dash", "favicon.ico"), content_type="image/x-icon"
202+
)
203+
204+
205+
class QuartRequestAdapter:
206+
def __init__(self) -> None:
207+
self._request = None
208+
209+
def set_request(self, request: Request) -> None:
210+
self._request = request
211+
212+
# Accessors (instance-based)
213+
def get_root(self):
214+
return self._request.root_url
215+
216+
def get_args(self):
217+
return self._request.args
218+
219+
async def get_json(self):
220+
return await self._request.get_json()
221+
222+
def is_json(self):
223+
return self._request.is_json
224+
225+
def get_cookies(self):
226+
return self._request.cookies
227+
228+
def get_headers(self):
229+
return self._request.headers
230+
231+
def get_full_path(self):
232+
return self._request.full_path
233+
234+
def get_url(self):
235+
return str(self._request.url)
236+
237+
def get_remote_addr(self):
238+
return self._request.remote_addr
239+
240+
def get_origin(self):
241+
return self._request.headers.get("origin")
242+
243+
def get_path(self):
244+
return self._request.path

0 commit comments

Comments
 (0)