10
10
import inspect
11
11
import socket
12
12
import sys
13
- import threading
14
13
import traceback
15
14
import warnings
16
15
from asyncio import AbstractEventLoop , AbstractEventLoopPolicy
50
49
PytestPluginManager ,
51
50
)
52
51
52
+ _seen_markers : set [int ] = set ()
53
+
54
+
55
+ def _warn_scope_deprecation_once (marker_id : int ) -> None :
56
+ """Issues deprecation warning exactly once per marker ID."""
57
+ if marker_id not in _seen_markers :
58
+ _seen_markers .add (marker_id )
59
+ warnings .warn (PytestDeprecationWarning (_MARKER_SCOPE_KWARG_DEPRECATION_WARNING ))
60
+
61
+
53
62
if sys .version_info >= (3 , 10 ):
54
63
from typing import ParamSpec
55
64
else :
63
72
_ScopeName = Literal ["session" , "package" , "module" , "class" , "function" ]
64
73
_R = TypeVar ("_R" , bound = Union [Awaitable [Any ], AsyncIterator [Any ]])
65
74
_P = ParamSpec ("_P" )
75
+ T = TypeVar ("T" )
66
76
FixtureFunction = Callable [_P , _R ]
77
+ CoroutineFunction = Callable [_P , Awaitable [T ]]
67
78
68
79
69
80
class PytestAsyncioError (Exception ):
@@ -292,7 +303,7 @@ def _asyncgen_fixture_wrapper(
292
303
gen_obj = fixture_function (* args , ** kwargs )
293
304
294
305
async def setup ():
295
- res = await gen_obj .__anext__ () # type: ignore[union-attr]
306
+ res = await gen_obj .__anext__ ()
296
307
return res
297
308
298
309
context = contextvars .copy_context ()
@@ -305,7 +316,7 @@ def finalizer() -> None:
305
316
306
317
async def async_finalizer () -> None :
307
318
try :
308
- await gen_obj .__anext__ () # type: ignore[union-attr]
319
+ await gen_obj .__anext__ ()
309
320
except StopAsyncIteration :
310
321
pass
311
322
else :
@@ -334,8 +345,7 @@ def _wrap_async_fixture(
334
345
runner : Runner ,
335
346
request : FixtureRequest ,
336
347
) -> Callable [AsyncFixtureParams , AsyncFixtureReturnType ]:
337
-
338
- @functools .wraps (fixture_function ) # type: ignore[arg-type]
348
+ @functools .wraps (fixture_function )
339
349
def _async_fixture_wrapper (
340
350
* args : AsyncFixtureParams .args ,
341
351
** kwargs : AsyncFixtureParams .kwargs ,
@@ -448,7 +458,7 @@ def _can_substitute(item: Function) -> bool:
448
458
return inspect .iscoroutinefunction (func )
449
459
450
460
def runtest (self ) -> None :
451
- synchronized_obj = wrap_in_sync ( self .obj )
461
+ synchronized_obj = get_async_test_wrapper ( self , self .obj )
452
462
with MonkeyPatch .context () as c :
453
463
c .setattr (self , "obj" , synchronized_obj )
454
464
super ().runtest ()
@@ -490,7 +500,7 @@ def _can_substitute(item: Function) -> bool:
490
500
)
491
501
492
502
def runtest (self ) -> None :
493
- synchronized_obj = wrap_in_sync ( self .obj )
503
+ synchronized_obj = get_async_test_wrapper ( self , self .obj )
494
504
with MonkeyPatch .context () as c :
495
505
c .setattr (self , "obj" , synchronized_obj )
496
506
super ().runtest ()
@@ -512,7 +522,10 @@ def _can_substitute(item: Function) -> bool:
512
522
)
513
523
514
524
def runtest (self ) -> None :
515
- synchronized_obj = wrap_in_sync (self .obj .hypothesis .inner_test )
525
+ synchronized_obj = get_async_test_wrapper (
526
+ self ,
527
+ self .obj .hypothesis .inner_test ,
528
+ )
516
529
with MonkeyPatch .context () as c :
517
530
c .setattr (self .obj .hypothesis , "inner_test" , synchronized_obj )
518
531
super ().runtest ()
@@ -603,10 +616,60 @@ def _set_event_loop(loop: AbstractEventLoop | None) -> None:
603
616
asyncio .set_event_loop (loop )
604
617
605
618
606
- def _reinstate_event_loop_on_main_thread () -> None :
607
- if threading .current_thread () is threading .main_thread ():
619
+ _session_loop : contextvars .ContextVar [asyncio .AbstractEventLoop | None ] = (
620
+ contextvars .ContextVar (
621
+ "_session_loop" ,
622
+ default = None ,
623
+ )
624
+ )
625
+ _package_loop : contextvars .ContextVar [asyncio .AbstractEventLoop | None ] = (
626
+ contextvars .ContextVar (
627
+ "_package_loop" ,
628
+ default = None ,
629
+ )
630
+ )
631
+ _module_loop : contextvars .ContextVar [asyncio .AbstractEventLoop | None ] = (
632
+ contextvars .ContextVar (
633
+ "_module_loop" ,
634
+ default = None ,
635
+ )
636
+ )
637
+ _class_loop : contextvars .ContextVar [asyncio .AbstractEventLoop | None ] = (
638
+ contextvars .ContextVar (
639
+ "_class_loop" ,
640
+ default = None ,
641
+ )
642
+ )
643
+ _function_loop : contextvars .ContextVar [asyncio .AbstractEventLoop | None ] = (
644
+ contextvars .ContextVar (
645
+ "_function_loop" ,
646
+ default = None ,
647
+ )
648
+ )
649
+
650
+ _SCOPE_TO_CONTEXTVAR = {
651
+ "session" : _session_loop ,
652
+ "package" : _package_loop ,
653
+ "module" : _module_loop ,
654
+ "class" : _class_loop ,
655
+ "function" : _function_loop ,
656
+ }
657
+
658
+
659
+ def _get_or_restore_event_loop (loop_scope : _ScopeName ) -> asyncio .AbstractEventLoop :
660
+ """
661
+ Get or restore the appropriate event loop for the given scope.
662
+
663
+ If we have a shared loop for this scope, restore and return it.
664
+ Otherwise, get the current event loop or create a new one.
665
+ """
666
+ shared_loop = _SCOPE_TO_CONTEXTVAR [loop_scope ].get ()
667
+ if shared_loop is not None :
608
668
policy = _get_event_loop_policy ()
609
- policy .set_event_loop (policy .new_event_loop ())
669
+ policy .set_event_loop (shared_loop )
670
+ return shared_loop
671
+ else :
672
+ return _get_event_loop_no_warn ()
610
673
611
674
612
675
@pytest .hookimpl (tryfirst = True , hookwrapper = True )
@@ -659,9 +722,22 @@ def pytest_pyfunc_call(pyfuncitem: Function) -> object | None:
659
722
return None
660
723
661
724
662
- def wrap_in_sync (
663
- func : Callable [..., Awaitable [Any ]],
664
- ):
725
+ def get_async_test_wrapper (
726
+ item : Function ,
727
+ func : CoroutineFunction [_P , T ],
728
+ ) -> Callable [_P , None ]:
729
+ """Returns a synchronous wrapper for the specified async test function."""
730
+ marker = item .get_closest_marker ("asyncio" )
731
+ assert marker is not None
732
+ default_loop_scope = _get_default_test_loop_scope (item .config )
733
+ loop_scope = _get_marked_loop_scope (marker , default_loop_scope )
734
+ return _wrap_in_sync (func , loop_scope )
735
+
736
+
737
+ def _wrap_in_sync (
738
+ func : CoroutineFunction [_P , T ],
739
+ loop_scope : _ScopeName ,
740
+ ) -> Callable [_P , None ]:
665
741
"""
666
742
Return a sync wrapper around an async function executing it in the
667
743
current event loop.
@@ -670,12 +746,7 @@ def wrap_in_sync(
670
746
@functools .wraps (func )
671
747
def inner (* args , ** kwargs ):
672
748
coro = func (* args , ** kwargs )
673
- try :
674
- _loop = _get_event_loop_no_warn ()
675
- except RuntimeError :
676
- # Handle situation where asyncio.set_event_loop(None) removes shared loops.
677
- _reinstate_event_loop_on_main_thread ()
678
- _loop = _get_event_loop_no_warn ()
749
+ _loop = _get_or_restore_event_loop (loop_scope )
679
750
task = asyncio .ensure_future (coro , loop = _loop )
680
751
try :
681
752
_loop .run_until_complete (task )
@@ -758,7 +829,7 @@ def _get_marked_loop_scope(
758
829
if "scope" in asyncio_marker .kwargs :
759
830
if "loop_scope" in asyncio_marker .kwargs :
760
831
raise pytest .UsageError (_DUPLICATE_LOOP_SCOPE_DEFINITION_ERROR )
761
- warnings . warn ( PytestDeprecationWarning ( _MARKER_SCOPE_KWARG_DEPRECATION_WARNING ))
832
+ _warn_scope_deprecation_once ( id ( asyncio_marker ))
762
833
scope = asyncio_marker .kwargs .get ("loop_scope" ) or asyncio_marker .kwargs .get (
763
834
"scope"
764
835
)
@@ -768,7 +839,7 @@ def _get_marked_loop_scope(
768
839
return scope
769
840
770
841
771
- def _get_default_test_loop_scope (config : Config ) -> _ScopeName :
842
+ def _get_default_test_loop_scope (config : Config ) -> Any :
772
843
return config .getini ("asyncio_default_test_loop_scope" )
773
844
774
845
@@ -796,6 +867,8 @@ def _scoped_runner(
796
867
debug_mode = _get_asyncio_debug (request .config )
797
868
with _temporary_event_loop_policy (new_loop_policy ):
798
869
runner = Runner (debug = debug_mode ).__enter__ ()
870
+ shared_loop = runner .get_loop ()
871
+ _SCOPE_TO_CONTEXTVAR [scope ].set (shared_loop )
799
872
try :
800
873
yield runner
801
874
except Exception as e :
0 commit comments