Skip to content

Commit 74231b6

Browse files
committed
Have AsyncValueFn leverage generic WrapAsync as base class
1 parent d281879 commit 74231b6

File tree

2 files changed

+83
-107
lines changed

2 files changed

+83
-107
lines changed

shiny/_utils.py

Lines changed: 80 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
Awaitable,
2222
Callable,
2323
Generator,
24+
Generic,
2425
Iterable,
2526
Optional,
2627
TypeVar,
@@ -279,58 +280,85 @@ async def fn_async(*args: P.args, **kwargs: P.kwargs) -> R:
279280
return fn_async
280281

281282

282-
# # TODO-barret-future; Q: Keep code?
283-
# class WrapAsync(Generic[P, R]):
284-
# """
285-
# Make a function asynchronous.
286-
287-
# Parameters
288-
# ----------
289-
# fn
290-
# Function to make asynchronous.
291-
292-
# Returns
293-
# -------
294-
# :
295-
# Asynchronous function (within the `WrapAsync` instance)
296-
# """
297-
298-
# def __init__(self, fn: Callable[P, R] | Callable[P, Awaitable[R]]):
299-
# if isinstance(fn, WrapAsync):
300-
# fn = cast(WrapAsync[P, R], fn)
301-
# return fn
302-
# self._is_async = is_async_callable(fn)
303-
# self._fn = wrap_async(fn)
304-
305-
# async def __call__(self, *args: P.args, **kwargs: P.kwargs) -> R:
306-
# """
307-
# Call the asynchronous function.
308-
# """
309-
# return await self._fn(*args, **kwargs)
310-
311-
# @property
312-
# def is_async(self) -> bool:
313-
# """
314-
# Was the original function asynchronous?
315-
316-
# Returns
317-
# -------
318-
# :
319-
# Whether the original function is asynchronous.
320-
# """
321-
# return self._is_async
322-
323-
# @property
324-
# def fn(self) -> Callable[P, R] | Callable[P, Awaitable[R]]:
325-
# """
326-
# Retrieve the original function
327-
328-
# Returns
329-
# -------
330-
# :
331-
# Original function supplied to the `WrapAsync` constructor.
332-
# """
333-
# return self._fn
283+
class WrapAsync(Generic[P, R]):
284+
"""
285+
Make a function asynchronous.
286+
287+
Parameters
288+
----------
289+
fn
290+
Function to make asynchronous.
291+
292+
Returns
293+
-------
294+
:
295+
Asynchronous function (within the `WrapAsync` instance)
296+
"""
297+
298+
_fn: Callable[P, Awaitable[R]]
299+
_is_async: bool
300+
_orig_fn: Callable[P, R] | Callable[P, Awaitable[R]]
301+
302+
def __init__(
303+
self,
304+
fn: Callable[P, R] | Callable[P, Awaitable[R]],
305+
):
306+
if isinstance(fn, WrapAsync):
307+
wa = cast(WrapAsync[P, R], fn)
308+
self._fn = wa._fn
309+
self._is_async = wa._is_async
310+
self._orig_fn = wa._orig_fn
311+
else:
312+
self._is_async = is_async_callable(fn)
313+
self._fn = wrap_async(fn)
314+
self._orig_fn = fn
315+
316+
async def __call__(self, *args: P.args, **kwargs: P.kwargs) -> R:
317+
"""
318+
Call the asynchronous function.
319+
"""
320+
return await self._fn(*args, **kwargs)
321+
322+
def is_async(self) -> bool:
323+
"""
324+
Was the original function asynchronous?
325+
326+
Returns
327+
-------
328+
:
329+
Whether the original function is asynchronous.
330+
"""
331+
return self._is_async
332+
333+
def get_async_fn(self) -> Callable[P, Awaitable[R]]:
334+
"""
335+
Return the async value function.
336+
337+
Returns
338+
-------
339+
:
340+
Async wrapped value function supplied to the `AsyncValueFn` constructor.
341+
"""
342+
return self._fn
343+
344+
def get_sync_fn(self) -> Callable[P, R]:
345+
"""
346+
Retrieve the original, synchronous value function function.
347+
348+
If the original function was asynchronous, a runtime error will be thrown.
349+
350+
Returns
351+
-------
352+
:
353+
Original, synchronous function supplied to the `AsyncValueFn` constructor.
354+
"""
355+
if self._is_async:
356+
raise RuntimeError(
357+
"The original function was asynchronous. Use `async_fn` instead."
358+
)
359+
360+
sync_fn = cast(Callable[P, R], self._orig_fn)
361+
return sync_fn
334362

335363

336364
# This function should generally be used in this code base instead of

shiny/render/renderer/_renderer.py

Lines changed: 3 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -9,14 +9,13 @@
99
Optional,
1010
TypeVar,
1111
Union,
12-
cast,
1312
)
1413

1514
from htmltools import MetadataNode, Tag, TagList
1615

1716
from ..._docstring import add_example
1817
from ..._typing_extensions import Self
19-
from ..._utils import is_async_callable, wrap_async
18+
from ..._utils import WrapAsync
2019
from ...types import Jsonifiable
2120

2221
if TYPE_CHECKING:
@@ -336,10 +335,7 @@ def _auto_register(self) -> None:
336335
self._auto_registered = True
337336

338337

339-
# Not inheriting from `WrapAsync[[], IT]` as python 3.8 needs typing extensions that
340-
# doesn't support `[]` for a ParamSpec definition. :-( Would be minimal/clean if we
341-
# could do `class AsyncValueFn(WrapAsync[[], IT]):`
342-
class AsyncValueFn(Generic[IT]):
338+
class AsyncValueFn(WrapAsync[[], IT | None]):
343339
"""
344340
App-supplied output value function which returns type `IT`.
345341
asynchronous.
@@ -355,52 +351,4 @@ def __init__(
355351
raise TypeError(
356352
"Must not call `AsyncValueFn.__init__` with an object of class `AsyncValueFn`"
357353
)
358-
self._is_async = is_async_callable(fn)
359-
self._fn = wrap_async(fn)
360-
self._orig_fn = fn
361-
362-
async def __call__(self) -> IT | None:
363-
"""
364-
Call the asynchronous function.
365-
"""
366-
return await self._fn()
367-
368-
def is_async(self) -> bool:
369-
"""
370-
Was the original function asynchronous?
371-
372-
Returns
373-
-------
374-
:
375-
Whether the original function is asynchronous.
376-
"""
377-
return self._is_async
378-
379-
def get_async_fn(self) -> Callable[[], Awaitable[IT | None]]:
380-
"""
381-
Return the async value function.
382-
383-
Returns
384-
-------
385-
:
386-
Async wrapped value function supplied to the `AsyncValueFn` constructor.
387-
"""
388-
return self._fn
389-
390-
def get_sync_fn(self) -> Callable[[], IT | None]:
391-
"""
392-
Retrieve the original, synchronous value function function.
393-
394-
If the original function was asynchronous, a runtime error will be thrown.
395-
396-
Returns
397-
-------
398-
:
399-
Original, synchronous function supplied to the `AsyncValueFn` constructor.
400-
"""
401-
if self._is_async:
402-
raise RuntimeError(
403-
"The original function was asynchronous. Use `async_fn` instead."
404-
)
405-
sync_fn = cast(Callable[[], IT], self._orig_fn)
406-
return sync_fn
354+
super().__init__(fn)

0 commit comments

Comments
 (0)