2626from __future__ import annotations
2727
2828import asyncio
29+ import contextvars
2930import datetime
3031import inspect
3132import sys
4647LF = TypeVar ("LF" , bound = _func )
4748FT = TypeVar ("FT" , bound = _func )
4849ET = TypeVar ("ET" , bound = Callable [[Any , BaseException ], Awaitable [Any ]])
50+ _current_loop_ctx : contextvars .ContextVar [int ] = contextvars .ContextVar (
51+ "_current_loop_ctx" , default = None
52+ )
4953
5054
5155class SleepHandle :
@@ -59,10 +63,14 @@ def __init__(
5963 relative_delta = discord .utils .compute_timedelta (dt )
6064 self .handle = loop .call_later (relative_delta , future .set_result , True )
6165
66+ def _set_result_safe (self ):
67+ if not self .future .done ():
68+ self .future .set_result (True )
69+
6270 def recalculate (self , dt : datetime .datetime ) -> None :
6371 self .handle .cancel ()
6472 relative_delta = discord .utils .compute_timedelta (dt )
65- self .handle = self .loop .call_later (relative_delta , self .future . set_result , True )
73+ self .handle = self .loop .call_later (relative_delta , self ._set_result_safe )
6674
6775 def wait (self ) -> asyncio .Future [Any ]:
6876 return self .future
@@ -91,10 +99,12 @@ def __init__(
9199 count : int | None ,
92100 reconnect : bool ,
93101 loop : asyncio .AbstractEventLoop ,
102+ overlap : bool | int ,
94103 ) -> None :
95104 self .coro : LF = coro
96105 self .reconnect : bool = reconnect
97106 self .loop : asyncio .AbstractEventLoop = loop
107+ self .overlap : bool | int = overlap
98108 self .count : int | None = count
99109 self ._current_loop = 0
100110 self ._handle : SleepHandle = MISSING
@@ -115,6 +125,7 @@ def __init__(
115125 self ._is_being_cancelled = False
116126 self ._has_failed = False
117127 self ._stop_next_iteration = False
128+ self ._tasks : set [asyncio .Task [Any ]] = set ()
118129
119130 if self .count is not None and self .count <= 0 :
120131 raise ValueError ("count must be greater than 0 or None." )
@@ -128,6 +139,29 @@ def __init__(
128139 raise TypeError (
129140 f"Expected coroutine function, not { type (self .coro ).__name__ !r} ."
130141 )
142+ if isinstance (overlap , bool ):
143+ if overlap :
144+ self ._run_with_semaphore = self ._run_direct
145+ elif isinstance (overlap , int ):
146+ if overlap <= 1 :
147+ raise ValueError ("overlap as an integer must be greater than 1." )
148+ self ._semaphore = asyncio .Semaphore (overlap )
149+ self ._run_with_semaphore = self ._semaphore_runner_factory ()
150+ else :
151+ raise TypeError ("overlap must be a bool or a positive integer." )
152+
153+ async def _run_direct (self , * args : Any , ** kwargs : Any ) -> None :
154+ """Run the coroutine directly."""
155+ await self .coro (* args , ** kwargs )
156+
157+ def _semaphore_runner_factory (self ) -> Callable [..., Awaitable [None ]]:
158+ """Return a function that runs the coroutine with a semaphore."""
159+
160+ async def runner (* args : Any , ** kwargs : Any ) -> None :
161+ async with self ._semaphore :
162+ await self .coro (* args , ** kwargs )
163+
164+ return runner
131165
132166 async def _call_loop_function (self , name : str , * args : Any , ** kwargs : Any ) -> None :
133167 coro = getattr (self , f"_{ name } " )
@@ -166,7 +200,18 @@ async def _loop(self, *args: Any, **kwargs: Any) -> None:
166200 self ._last_iteration = self ._next_iteration
167201 self ._next_iteration = self ._get_next_sleep_time ()
168202 try :
169- await self .coro (* args , ** kwargs )
203+ token = _current_loop_ctx .set (self ._current_loop )
204+ if not self .overlap :
205+ await self .coro (* args , ** kwargs )
206+ else :
207+ task = asyncio .create_task (
208+ self ._run_with_semaphore (* args , ** kwargs ),
209+ name = f"pycord-loop-{ self .coro .__name__ } -{ self ._current_loop } " ,
210+ )
211+ task .add_done_callback (self ._tasks .discard )
212+ self ._tasks .add (task )
213+
214+ _current_loop_ctx .reset (token )
170215 self ._last_iteration_failed = False
171216 backoff = ExponentialBackoff ()
172217 except self ._valid_exception :
@@ -192,6 +237,9 @@ async def _loop(self, *args: Any, **kwargs: Any) -> None:
192237
193238 except asyncio .CancelledError :
194239 self ._is_being_cancelled = True
240+ for task in self ._tasks :
241+ task .cancel ()
242+ await asyncio .gather (* self ._tasks , return_exceptions = True )
195243 raise
196244 except Exception as exc :
197245 self ._has_failed = True
@@ -218,6 +266,7 @@ def __get__(self, obj: T, objtype: type[T]) -> Loop[LF]:
218266 count = self .count ,
219267 reconnect = self .reconnect ,
220268 loop = self .loop ,
269+ overlap = self .overlap ,
221270 )
222271 copy ._injected = obj
223272 copy ._before_loop = self ._before_loop
@@ -269,7 +318,11 @@ def time(self) -> list[datetime.time] | None:
269318 @property
270319 def current_loop (self ) -> int :
271320 """The current iteration of the loop."""
272- return self ._current_loop
321+ return (
322+ _current_loop_ctx .get ()
323+ if _current_loop_ctx .get () is not None
324+ else self ._current_loop
325+ )
273326
274327 @property
275328 def next_iteration (self ) -> datetime .datetime | None :
@@ -738,6 +791,7 @@ def loop(
738791 count : int | None = None ,
739792 reconnect : bool = True ,
740793 loop : asyncio .AbstractEventLoop = MISSING ,
794+ overlap : bool | int = False ,
741795) -> Callable [[LF ], Loop [LF ]]:
742796 """A decorator that schedules a task in the background for you with
743797 optional reconnect logic. The decorator returns a :class:`Loop`.
@@ -773,6 +827,11 @@ def loop(
773827 loop: :class:`asyncio.AbstractEventLoop`
774828 The loop to use to register the task, if not given
775829 defaults to :func:`asyncio.get_event_loop`.
830+ overlap: Union[:class:`bool`, :class:`int`]
831+ Controls whether overlapping executions of the task loop are allowed.
832+ Set to False (default) to run iterations one at a time, True for unlimited overlap, or an int to cap the number of concurrent runs.
833+
834+ .. versionadded:: 2.7
776835
777836 Raises
778837 ------
@@ -793,6 +852,7 @@ def decorator(func: LF) -> Loop[LF]:
793852 time = time ,
794853 reconnect = reconnect ,
795854 loop = loop ,
855+ overlap = overlap ,
796856 )
797857
798858 return decorator
0 commit comments