Skip to content

Commit 65b8b61

Browse files
Lumabotsplun1331JustaSqu1dDA-344pre-commit-ci[bot]
authored
feat(loop): add optional overlap support to allow concurrent loop executions (#2771)
Co-authored-by: plun1331 <[email protected]> Co-authored-by: JustaSqu1d <[email protected]> Co-authored-by: DA344 <[email protected]> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent 72b6e05 commit 65b8b61

File tree

3 files changed

+94
-4
lines changed

3 files changed

+94
-4
lines changed

CHANGELOG.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -117,6 +117,9 @@ These changes are available on the `master` branch, but have not yet been releas
117117
([#2714](https://github.com/Pycord-Development/pycord/pull/2714))
118118
- Added the ability to pass a `datetime.time` object to `format_dt`.
119119
([#2747](https://github.com/Pycord-Development/pycord/pull/2747))
120+
- Added the ability to pass an `overlap` parameter to the `loop` decorator and `Loop`
121+
class, allowing concurrent iterations if enabled.
122+
([#2765](https://github.com/Pycord-Development/pycord/pull/2765))
120123
- Added various missing channel parameters and allow `default_reaction_emoji` to be
121124
`None`. ([#2772](https://github.com/Pycord-Development/pycord/pull/2772))
122125
- Added support for type hinting slash command options with `typing.Annotated`.

discord/ext/tasks/__init__.py

Lines changed: 63 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
from __future__ import annotations
2727

2828
import asyncio
29+
import contextvars
2930
import datetime
3031
import inspect
3132
import sys
@@ -46,6 +47,9 @@
4647
LF = TypeVar("LF", bound=_func)
4748
FT = TypeVar("FT", bound=_func)
4849
ET = TypeVar("ET", bound=Callable[[Any, BaseException], Awaitable[Any]])
50+
_current_loop_ctx: contextvars.ContextVar[int] = contextvars.ContextVar(
51+
"_current_loop_ctx", default=None
52+
)
4953

5054

5155
class SleepHandle:
@@ -59,10 +63,14 @@ def __init__(
5963
relative_delta = discord.utils.compute_timedelta(dt)
6064
self.handle = loop.call_later(relative_delta, future.set_result, True)
6165

66+
def _set_result_safe(self):
67+
if not self.future.done():
68+
self.future.set_result(True)
69+
6270
def recalculate(self, dt: datetime.datetime) -> None:
6371
self.handle.cancel()
6472
relative_delta = discord.utils.compute_timedelta(dt)
65-
self.handle = self.loop.call_later(relative_delta, self.future.set_result, True)
73+
self.handle = self.loop.call_later(relative_delta, self._set_result_safe)
6674

6775
def wait(self) -> asyncio.Future[Any]:
6876
return self.future
@@ -91,10 +99,12 @@ def __init__(
9199
count: int | None,
92100
reconnect: bool,
93101
loop: asyncio.AbstractEventLoop,
102+
overlap: bool | int,
94103
) -> None:
95104
self.coro: LF = coro
96105
self.reconnect: bool = reconnect
97106
self.loop: asyncio.AbstractEventLoop = loop
107+
self.overlap: bool | int = overlap
98108
self.count: int | None = count
99109
self._current_loop = 0
100110
self._handle: SleepHandle = MISSING
@@ -115,6 +125,7 @@ def __init__(
115125
self._is_being_cancelled = False
116126
self._has_failed = False
117127
self._stop_next_iteration = False
128+
self._tasks: set[asyncio.Task[Any]] = set()
118129

119130
if self.count is not None and self.count <= 0:
120131
raise ValueError("count must be greater than 0 or None.")
@@ -128,6 +139,29 @@ def __init__(
128139
raise TypeError(
129140
f"Expected coroutine function, not {type(self.coro).__name__!r}."
130141
)
142+
if isinstance(overlap, bool):
143+
if overlap:
144+
self._run_with_semaphore = self._run_direct
145+
elif isinstance(overlap, int):
146+
if overlap <= 1:
147+
raise ValueError("overlap as an integer must be greater than 1.")
148+
self._semaphore = asyncio.Semaphore(overlap)
149+
self._run_with_semaphore = self._semaphore_runner_factory()
150+
else:
151+
raise TypeError("overlap must be a bool or a positive integer.")
152+
153+
async def _run_direct(self, *args: Any, **kwargs: Any) -> None:
154+
"""Run the coroutine directly."""
155+
await self.coro(*args, **kwargs)
156+
157+
def _semaphore_runner_factory(self) -> Callable[..., Awaitable[None]]:
158+
"""Return a function that runs the coroutine with a semaphore."""
159+
160+
async def runner(*args: Any, **kwargs: Any) -> None:
161+
async with self._semaphore:
162+
await self.coro(*args, **kwargs)
163+
164+
return runner
131165

132166
async def _call_loop_function(self, name: str, *args: Any, **kwargs: Any) -> None:
133167
coro = getattr(self, f"_{name}")
@@ -166,7 +200,18 @@ async def _loop(self, *args: Any, **kwargs: Any) -> None:
166200
self._last_iteration = self._next_iteration
167201
self._next_iteration = self._get_next_sleep_time()
168202
try:
169-
await self.coro(*args, **kwargs)
203+
token = _current_loop_ctx.set(self._current_loop)
204+
if not self.overlap:
205+
await self.coro(*args, **kwargs)
206+
else:
207+
task = asyncio.create_task(
208+
self._run_with_semaphore(*args, **kwargs),
209+
name=f"pycord-loop-{self.coro.__name__}-{self._current_loop}",
210+
)
211+
task.add_done_callback(self._tasks.discard)
212+
self._tasks.add(task)
213+
214+
_current_loop_ctx.reset(token)
170215
self._last_iteration_failed = False
171216
backoff = ExponentialBackoff()
172217
except self._valid_exception:
@@ -192,6 +237,9 @@ async def _loop(self, *args: Any, **kwargs: Any) -> None:
192237

193238
except asyncio.CancelledError:
194239
self._is_being_cancelled = True
240+
for task in self._tasks:
241+
task.cancel()
242+
await asyncio.gather(*self._tasks, return_exceptions=True)
195243
raise
196244
except Exception as exc:
197245
self._has_failed = True
@@ -218,6 +266,7 @@ def __get__(self, obj: T, objtype: type[T]) -> Loop[LF]:
218266
count=self.count,
219267
reconnect=self.reconnect,
220268
loop=self.loop,
269+
overlap=self.overlap,
221270
)
222271
copy._injected = obj
223272
copy._before_loop = self._before_loop
@@ -269,7 +318,11 @@ def time(self) -> list[datetime.time] | None:
269318
@property
270319
def current_loop(self) -> int:
271320
"""The current iteration of the loop."""
272-
return self._current_loop
321+
return (
322+
_current_loop_ctx.get()
323+
if _current_loop_ctx.get() is not None
324+
else self._current_loop
325+
)
273326

274327
@property
275328
def next_iteration(self) -> datetime.datetime | None:
@@ -738,6 +791,7 @@ def loop(
738791
count: int | None = None,
739792
reconnect: bool = True,
740793
loop: asyncio.AbstractEventLoop = MISSING,
794+
overlap: bool | int = False,
741795
) -> Callable[[LF], Loop[LF]]:
742796
"""A decorator that schedules a task in the background for you with
743797
optional reconnect logic. The decorator returns a :class:`Loop`.
@@ -773,6 +827,11 @@ def loop(
773827
loop: :class:`asyncio.AbstractEventLoop`
774828
The loop to use to register the task, if not given
775829
defaults to :func:`asyncio.get_event_loop`.
830+
overlap: Union[:class:`bool`, :class:`int`]
831+
Controls whether overlapping executions of the task loop are allowed.
832+
Set to False (default) to run iterations one at a time, True for unlimited overlap, or an int to cap the number of concurrent runs.
833+
834+
.. versionadded:: 2.7
776835
777836
Raises
778837
------
@@ -793,6 +852,7 @@ def decorator(func: LF) -> Loop[LF]:
793852
time=time,
794853
reconnect=reconnect,
795854
loop=loop,
855+
overlap=overlap,
796856
)
797857

798858
return decorator

examples/background_task.py

Lines changed: 28 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
import asyncio
2+
import random
13
from datetime import time, timezone
24

35
import discord
@@ -10,7 +12,6 @@ def __init__(self, *args, **kwargs):
1012

1113
# An attribute we can access from our task
1214
self.counter = 0
13-
1415
# Start the tasks to run in the background
1516
self.my_background_task.start()
1617
self.time_task.start()
@@ -37,6 +38,32 @@ async def time_task(self):
3738
async def before_my_task(self):
3839
await self.wait_until_ready() # Wait until the bot logs in
3940

41+
# Schedule every 10s; each run takes between 5 to 20s. With overlap=2, at most 2 runs
42+
# execute concurrently so we don't build an ever-growing backlog.
43+
@tasks.loop(seconds=10, overlap=2)
44+
async def fetch_status_task(self):
45+
"""
46+
Practical overlap use-case:
47+
48+
Poll an external service and post a short summary. Each poll may take
49+
between 5 to 20s due to network latency or rate limits, but we want fresh data
50+
every 10s. Allowing a small amount of overlap avoids drifting schedules
51+
without opening the floodgates to unlimited concurrency.
52+
"""
53+
print(f"[status] start run #{self.fetch_status_task.current_loop}")
54+
55+
# Simulate slow I/O (e.g., HTTP requests, DB queries, file I/O)
56+
await asyncio.sleep(random.randint(5, 20))
57+
58+
channel = self.get_channel(1234567) # Replace with your channel ID
59+
msg = f"[status] run #{self.fetch_status_task.current_loop} complete"
60+
if channel:
61+
await channel.send(msg)
62+
else:
63+
print(msg)
64+
65+
print(f"[status] end run #{self.fetch_status_task.current_loop}")
66+
4067

4168
client = MyClient()
4269
client.run("TOKEN")

0 commit comments

Comments
 (0)