|
| 1 | +"""Discord module to publish mentor request queues. Version 2.""" |
| 2 | + |
| 3 | +import asyncio |
| 4 | +import datetime |
| 5 | +import enum |
| 6 | +import logging |
| 7 | +import random |
| 8 | +import re |
| 9 | +import statistics |
| 10 | +import time |
| 11 | +from typing import Sequence |
| 12 | + |
| 13 | +import discord |
| 14 | +import prometheus_client # type: ignore |
| 15 | +from discord.ext import commands |
| 16 | +from discord.ext import tasks |
| 17 | +from exercism_lib import exercism |
| 18 | + |
| 19 | +from cogs import base_cog |
| 20 | + |
| 21 | +logger = logging.getLogger(__name__) |
| 22 | + |
| 23 | +PROM_EXERCISM_REQUESTS = prometheus_client.Counter( |
| 24 | + "mentor_request_exercism_rpc", "Number of API calls to Exercism", ["track"] |
| 25 | +) |
| 26 | +PROM_DISCORD_REQUESTS = prometheus_client.Counter( |
| 27 | + "mentor_request_discord_rpc", "Number of API calls to Discord", ["write"] |
| 28 | +) |
| 29 | +PROM_TASK_QUEUE = prometheus_client.Gauge("mentor_requests_task_queue", "size of the task queue") |
| 30 | + |
| 31 | + |
| 32 | +class TaskType(enum.IntEnum): |
| 33 | + """Types of tasks to execute.""" |
| 34 | + TASK_QUERY_EXERCISM = enum.auto() |
| 35 | + TASK_QUERY_DISCORD = enum.auto() |
| 36 | + TASK_DISCORD_ADD = enum.auto() |
| 37 | + TASK_DISCORD_DEL = enum.auto() |
| 38 | + |
| 39 | + |
| 40 | +class RequestNotifierV2(base_cog.BaseCog): |
| 41 | + """Update Discord with Mentor Requests.""" |
| 42 | + |
| 43 | + qualified_name = "Request Notifier v2" |
| 44 | + |
| 45 | + def __init__( |
| 46 | + self, |
| 47 | + bot: commands.Bot, |
| 48 | + channel_id: int, |
| 49 | + tracks: Sequence[str] | None = None, |
| 50 | + **kwargs, |
| 51 | + ) -> None: |
| 52 | + super().__init__(bot=bot, **kwargs) |
| 53 | + self.exercism = exercism.AsyncExercism() |
| 54 | + self.channel_id = channel_id |
| 55 | + |
| 56 | + self.threads: dict[str, discord.Thread] = {} |
| 57 | + self.requests: dict[str, dict[str, str]] = {} |
| 58 | + self.messages: dict[str, dict[str, int]] = {} |
| 59 | + self.next_task_time = 0 |
| 60 | + |
| 61 | + self.lock = asyncio.Lock() |
| 62 | + self.queue: asyncio.PriorityQueue[ |
| 63 | + tuple[int, TaskType, str, str | tuple[str, ...] | None] |
| 64 | + ] = asyncio.PriorityQueue() |
| 65 | + |
| 66 | + if tracks: |
| 67 | + self.tracks = list(tracks) |
| 68 | + else: |
| 69 | + self.tracks = exercism.Exercism().all_tracks() |
| 70 | + self.tracks.sort() |
| 71 | + # Default to 10 minute polling. |
| 72 | + self.request_interval = {track: 600 for track in self.tracks} |
| 73 | + self.request_timestamps: dict[str, list[int]] = {track: [] for track in self.tracks} |
| 74 | + |
| 75 | + self.task_manager.start() # pylint: disable=E1101 |
| 76 | + |
| 77 | + async def get_thread(self, track: str) -> discord.Thread: |
| 78 | + """Return the request thread for a specific track.""" |
| 79 | + thread = self.threads.get(track) |
| 80 | + if not thread: |
| 81 | + raise LookupError(f"Failed to find track {track} in threads") |
| 82 | + |
| 83 | + # Refresh the thread object. This is helpful to update the is_archived bit. |
| 84 | + async with asyncio.timeout(10): |
| 85 | + got = await self.bot.fetch_channel(thread.id) |
| 86 | + assert isinstance(got, discord.Thread), f"Expected a Thread. {got=}" |
| 87 | + self.threads[track] = got |
| 88 | + return got |
| 89 | + |
| 90 | + @tasks.loop(seconds=5) |
| 91 | + async def task_manager(self): |
| 92 | + """Task loop.""" |
| 93 | + if self.lock.locked(): |
| 94 | + return |
| 95 | + async with self.lock: |
| 96 | + try: |
| 97 | + PROM_TASK_QUEUE.set(self.queue.qsize()) |
| 98 | + now = int(time.time()) |
| 99 | + # If the queue is empty or the next task is not yet due, return. |
| 100 | + if now < self.next_task_time or self.queue.empty(): |
| 101 | + return |
| 102 | + task = self.queue.get_nowait() |
| 103 | + task_time, task_type, track, *details = task |
| 104 | + # If the next task is not yet due, queue it and return. |
| 105 | + if now < task_time: |
| 106 | + self.queue.put_nowait(task) |
| 107 | + self.next_task_time = task_time |
| 108 | + return |
| 109 | + # Handle a task. |
| 110 | + if task_type == TaskType.TASK_QUERY_EXERCISM: |
| 111 | + try: |
| 112 | + await self.fetch_track_requests(track) |
| 113 | + finally: |
| 114 | + self.queue_query_exercism(track) |
| 115 | + elif task_type == TaskType.TASK_QUERY_DISCORD: |
| 116 | + try: |
| 117 | + await self.fetch_discord_thread(track) |
| 118 | + finally: |
| 119 | + self.queue_query_discord(track) |
| 120 | + elif task_type == TaskType.TASK_DISCORD_ADD: |
| 121 | + await self.update_discord_add(track, *details) |
| 122 | + elif task_type == TaskType.TASK_DISCORD_DEL: |
| 123 | + await self.update_discord_del(track, *details) |
| 124 | + else: |
| 125 | + logger.exception("Unknown task type, %d", task_type) |
| 126 | + |
| 127 | + except Exception: # pylint: disable=broad-exception-caught |
| 128 | + logger.exception("Unhandled exception in task manager loop.") |
| 129 | + |
| 130 | + @task_manager.before_loop |
| 131 | + async def before_task_manager(self): |
| 132 | + """Before starting the task manager, wait for ready and load Discord messages.""" |
| 133 | + logger.debug("Start before_task_manager()") |
| 134 | + await self.bot.wait_until_ready() |
| 135 | + await self.load_data() |
| 136 | + self.populate_task_queue() |
| 137 | + logger.debug("End before_task_manager()") |
| 138 | + |
| 139 | + def exercism_poll_interval(self, track: str) -> int: |
| 140 | + """Return the poll interval between getting requests for a track.""" |
| 141 | + interval = self.request_interval[track] |
| 142 | + times = self.request_timestamps[track] |
| 143 | + if len(times) < 2: |
| 144 | + self.request_interval[track] = min(int(1.5 * interval), 60 * 60 * 60) |
| 145 | + return interval |
| 146 | + times.sort() |
| 147 | + intervals = [a - b for a, b in zip(times[1:], times)] |
| 148 | + # [5 min ... avg ... 1 hour] |
| 149 | + return min(max(int(statistics.mean(intervals)), 60 * 5), 60 * 60 * 60) |
| 150 | + |
| 151 | + async def fetch_track_requests(self, track: str) -> None: |
| 152 | + """Fetch the requests for a given track. Queue tasks to update the Discord thread.""" |
| 153 | + logger.debug("Start fetch_track_requests(%s)", track) |
| 154 | + if track not in self.messages: |
| 155 | + return |
| 156 | + PROM_EXERCISM_REQUESTS.labels(track).inc() |
| 157 | + # fetch requests |
| 158 | + # update DB |
| 159 | + # update request interval data |
| 160 | + # compare to Discord thread; queue tasks to add/remove. |
| 161 | + async with asyncio.timeout(15): |
| 162 | + requests = await self.get_requests(track) |
| 163 | + logger.debug("Found %d requests for %s.", len(requests), track) |
| 164 | + |
| 165 | + add_requests = set(requests) - set(self.messages[track]) |
| 166 | + del_requests = set(self.messages[track]) - set(requests) |
| 167 | + self.requests[track] = { |
| 168 | + request_id: message |
| 169 | + for request_id, (timestamp, message) in requests.items() |
| 170 | + } |
| 171 | + self.request_timestamps[track].extend( |
| 172 | + timestamp |
| 173 | + for request_id, (timestamp, message) in requests.items() |
| 174 | + if request_id in add_requests |
| 175 | + ) |
| 176 | + self.request_timestamps[track] = sorted(self.request_timestamps[track], reverse=True)[:10] |
| 177 | + |
| 178 | + for request_id in add_requests: |
| 179 | + logger.debug("Queue TASK_DISCORD_ADD %s %s for now", track, request_id) |
| 180 | + self.queue.put_nowait((0, TaskType.TASK_DISCORD_ADD, track, request_id)) |
| 181 | + |
| 182 | + if del_requests: |
| 183 | + to_del = tuple(list(del_requests)[:10]) |
| 184 | + logger.debug("Queue TASK_DISCORD_DEL %s %s for now", track, to_del) |
| 185 | + self.queue.put_nowait((0, TaskType.TASK_DISCORD_DEL, track, to_del)) |
| 186 | + |
| 187 | + def queue_query_exercism(self, track: str) -> None: |
| 188 | + """Queue a task to query Exercism for a track.""" |
| 189 | + interval = self.exercism_poll_interval(track) |
| 190 | + task_time = int(time.time()) + interval |
| 191 | + logger.debug("Queue TASK_QUERY_EXERCISM %s in %d seconds", track, interval) |
| 192 | + self.queue.put_nowait((task_time, TaskType.TASK_QUERY_EXERCISM, track, None)) |
| 193 | + |
| 194 | + async def fetch_discord_thread(self, track: str) -> None: |
| 195 | + """Fetch a track thread from Discord to update the local cache.""" |
| 196 | + logger.debug("Start fetch_discord_thread(%s)", track) |
| 197 | + request_url_re = re.compile(r"\bhttps://exercism.org/mentoring/requests/(\w+)\b") |
| 198 | + PROM_DISCORD_REQUESTS.labels(False).inc() |
| 199 | + thread = await self.get_thread(track) |
| 200 | + messages = {} |
| 201 | + await self.unarchive(thread) |
| 202 | + async for message in thread.history(): |
| 203 | + if message.author != thread.owner: |
| 204 | + continue |
| 205 | + if message == thread.starter_message: |
| 206 | + continue |
| 207 | + match = request_url_re.search(message.content) |
| 208 | + if match is None: |
| 209 | + continue |
| 210 | + messages[str(match.group(1))] = message.id |
| 211 | + self.messages[track] = messages |
| 212 | + |
| 213 | + def queue_query_discord(self, track: str) -> None: |
| 214 | + """Queue a task to query a Discord request thread.""" |
| 215 | + interval = 60 # one minute |
| 216 | + task_time = int(time.time()) + interval |
| 217 | + logger.debug("Queue TASK_QUERY_DISCORD %s in %d seconds", track, interval) |
| 218 | + self.queue.put_nowait((task_time, TaskType.TASK_QUERY_DISCORD, track, None)) |
| 219 | + |
| 220 | + async def update_discord_add(self, track: str, request_id: str) -> None: |
| 221 | + """Add a request message to Discord.""" |
| 222 | + logger.debug("Start update_discord_add(%s, %s)", track, request_id) |
| 223 | + PROM_DISCORD_REQUESTS.labels(True).inc() |
| 224 | + thread = await self.get_thread(track) |
| 225 | + description = self.requests[track][request_id] |
| 226 | + async with asyncio.timeout(10): |
| 227 | + message = await thread.send(description, suppress_embeds=True) |
| 228 | + self.messages[track][request_id] = message.id |
| 229 | + |
| 230 | + async def update_discord_del(self, track: str, message_ids: tuple[int, ...]) -> None: |
| 231 | + """Remove a request message from Discord.""" |
| 232 | + PROM_DISCORD_REQUESTS.labels(True).inc() |
| 233 | + logger.debug("Start update_discord_del(%s, %s)", track, message_ids) |
| 234 | + thread = await self.get_thread(track) |
| 235 | + await self.unarchive(self.threads[track]) |
| 236 | + |
| 237 | + async with asyncio.timeout(15): |
| 238 | + try: |
| 239 | + await thread.delete_messages( |
| 240 | + discord.Object(message_id) for message_id in message_ids |
| 241 | + ) |
| 242 | + except discord.errors.NotFound: |
| 243 | + pass |
| 244 | + request_ids = [ |
| 245 | + request_id |
| 246 | + for request_id, message_id in self.messages[track].items() |
| 247 | + if message_id in message_ids |
| 248 | + ] |
| 249 | + for request_id in request_ids: |
| 250 | + del self.messages[track][request_id] |
| 251 | + |
| 252 | + def populate_task_queue(self): |
| 253 | + """Populate the initial task queue.""" |
| 254 | + tracks = self.tracks.copy() |
| 255 | + random.shuffle(tracks) |
| 256 | + # Spread the initial requests over 5 minutes |
| 257 | + for track, offset in zip(tracks, range(0, 5 * 60, int(5 * 60 / len(tracks)))): |
| 258 | + task_time = int(time.time()) + offset |
| 259 | + self.queue.put_nowait((task_time, TaskType.TASK_QUERY_DISCORD, track)) |
| 260 | + self.queue.put_nowait((task_time + 1, TaskType.TASK_QUERY_EXERCISM, track)) |
| 261 | + |
| 262 | + async def unarchive(self, thread: discord.Thread) -> None: |
| 263 | + """Ensure a thread is not archived.""" |
| 264 | + if not thread.archived: |
| 265 | + return |
| 266 | + async with asyncio.timeout(10): |
| 267 | + message = await thread.send("Sending a message to unarchive this thread.") |
| 268 | + async with asyncio.timeout(10): |
| 269 | + await message.delete() |
| 270 | + |
| 271 | + async def load_data(self) -> None: |
| 272 | + """Load Exercism data.""" |
| 273 | + channel = self.bot.get_channel(self.channel_id) |
| 274 | + assert isinstance(channel, discord.TextChannel), f"{channel} is not a TextChannel." |
| 275 | + |
| 276 | + self.threads = {} |
| 277 | + async for message in channel.history(): |
| 278 | + if not message.thread: |
| 279 | + continue |
| 280 | + thread = await message.fetch_thread() |
| 281 | + self.threads[thread.name.lower()] = thread |
| 282 | + |
| 283 | + for track in self.tracks: |
| 284 | + if track in self.threads: |
| 285 | + continue |
| 286 | + thread = await channel.create_thread( |
| 287 | + name=track.title(), |
| 288 | + type=discord.ChannelType.public_thread, |
| 289 | + ) |
| 290 | + self.threads[track] = thread |
| 291 | + await asyncio.sleep(2) |
| 292 | + |
| 293 | + async def get_requests(self, track_slug: str) -> dict[str, tuple[int, str]]: |
| 294 | + """Return formatted mentor requests.""" |
| 295 | + requests = {} |
| 296 | + for req in await self.exercism.mentor_requests(track_slug): |
| 297 | + # uuid = req["uuid"] |
| 298 | + track_title = req["track"]["title"] |
| 299 | + exercise_title = req["exercise"]["title"] |
| 300 | + student_handle = req["student"]["handle"] |
| 301 | + status = req["status"] |
| 302 | + url = req["url"] |
| 303 | + |
| 304 | + msg = f"{track_title.title()}: {url} => {exercise_title} " |
| 305 | + if status: |
| 306 | + msg += f"({student_handle}, {status})" |
| 307 | + else: |
| 308 | + msg += f"({student_handle})" |
| 309 | + |
| 310 | + timestamp = int(datetime.datetime.fromisoformat(req["updated_at"]).timestamp()) |
| 311 | + requests[req["uuid"]] = (timestamp, msg) |
| 312 | + return requests |
0 commit comments