diff --git a/cogs/mentor_requests.py b/cogs/mentor_requests.py old mode 100755 new mode 100644 index 78883fa..25b51f6 --- a/cogs/mentor_requests.py +++ b/cogs/mentor_requests.py @@ -1,8 +1,13 @@ -"""Discord module to publish mentor request queues.""" +"""Discord module to publish mentor request queues. Version 2.""" + import asyncio +import datetime +import enum import logging +import random import re -import sqlite3 +import statistics +import time from typing import Sequence import discord @@ -15,273 +20,266 @@ logger = logging.getLogger(__name__) -QUERY = { - "add_request": "INSERT INTO requests VALUES (:request_id, :track_slug, :message_id)", - "del_request": "DELETE FROM requests WHERE request_id = :request_id", - "get_requests": "SELECT request_id, track_slug, message_id FROM requests", - "get_theads": "SELECT track_slug, message_id FROM track_threads", - "add_thead": "INSERT INTO track_threads VALUES (:track_slug, :message_id)", -} -PROM_TRACK_COUNT = prometheus_client.Gauge("mentor_requests_tracks", "Number of tracks") -ACTIVE_REQUESTS = prometheus_client.Gauge("mentor_requests", "Number of requests") -REQUEST_QUEUED = prometheus_client.Counter( - "mentor_request_queued", "Number of requests queued", ["track"], -) -PROM_UPDATE_HIST = prometheus_client.Histogram("mentor_requests_update", "Update requests") -PROM_UPDATE_TRACK_HIST = prometheus_client.Histogram( - "mentor_requests_update_track", "Update one track", ["track"], +PROM_EXERCISM_REQUESTS = prometheus_client.Counter( + "mentor_request_exercism_rpc", "Number of API calls to Exercism", ["track"] ) -PROM_LAST_UPDATE = prometheus_client.Gauge( - "mentor_requests_last_update", "Timestamp of last update", +PROM_DISCORD_REQUESTS = prometheus_client.Counter( + "mentor_request_discord_rpc", "Number of API calls to Discord", ["write"] ) +PROM_TASK_QUEUE = prometheus_client.Gauge("mentor_requests_task_queue", "size of the task queue") + + +class TaskType(enum.IntEnum): + """Types of tasks to execute.""" + TASK_QUERY_EXERCISM = enum.auto() + TASK_QUERY_DISCORD = enum.auto() + TASK_DISCORD_ADD = enum.auto() + TASK_DISCORD_DEL = enum.auto() -class RequestNotifier(base_cog.BaseCog): +class RequestNotifierV2(base_cog.BaseCog): """Update Discord with Mentor Requests.""" - qualified_name = "Request Notifier" + qualified_name = "Request Notifier v2" def __init__( self, bot: commands.Bot, channel_id: int, - sqlite_db: str, tracks: Sequence[str] | None = None, **kwargs, ) -> None: super().__init__(bot=bot, **kwargs) - self.conn = sqlite3.Connection(sqlite_db, isolation_level=None) self.exercism = exercism.AsyncExercism() self.channel_id = channel_id + self.threads: dict[str, discord.Thread] = {} - self.requests: dict[str, tuple[str, discord.Message]] = {} + self.requests: dict[str, dict[str, str]] = {} + self.messages: dict[str, dict[str, int]] = {} + self.next_task_time = 0 + self.lock = asyncio.Lock() + self.queue: asyncio.PriorityQueue[ + tuple[int, TaskType, str, str | tuple[str, ...] | None] + ] = asyncio.PriorityQueue() if tracks: self.tracks = list(tracks) else: self.tracks = exercism.Exercism().all_tracks() self.tracks.sort() - PROM_TRACK_COUNT.set(len(self.tracks)) - PROM_LAST_UPDATE.set_to_current_time() + # Default to 10 minute polling. + self.request_interval = {track: 600 for track in self.tracks} + self.request_timestamps: dict[str, list[int]] = {track: [] for track in self.tracks} - self.synced_tracks: set[str] = set() - self.task_update_mentor_requests.start() # pylint: disable=E1101 - self.task_delete_old_messages.start() # pylint: disable=E1101 - - async def unarchive(self, thread: discord.Thread) -> None: - """Ensure a thread is not archived.""" - if not thread.archived: - return - async with asyncio.timeout(10): - message = await thread.send("Sending a message to unarchive this thread.") - async with asyncio.timeout(10): - await message.delete() + self.task_manager.start() # pylint: disable=E1101 - async def update_track_requests(self, track: str) -> dict[str, str]: - """Update mentor requests for one track, returning current requests.""" - logger.debug("Updating mentor requests for track %s", track) + async def get_thread(self, track: str) -> discord.Thread: + """Return the request thread for a specific track.""" thread = self.threads.get(track) if not thread: - logger.warning("Failed to find track %s in threads", track) - return {} - - # Refresh the thread object. - # This is helpful to update the is_archived bit. - try: - async with asyncio.timeout(10): - got = await self.bot.fetch_channel(thread.id) - except asyncio.TimeoutError: - logger.warning("fetch_channel timed out for track %s (%s).", track, thread.id) - return {} - except (discord.NotFound, discord.DiscordServerError) as e: - logger.error("fetch_channel failed %s (%s): %s", track, thread.id, e) - return {} - assert isinstance(got, discord.Thread), f"Expected a Thread. {got=}" - thread = got - self.threads[track] = thread - - try: - async with asyncio.timeout(15): - requests = await self.get_requests(track) - except Exception as e: # pylint: disable=W0718 - logger.error("get_requests(%s) failed: %s", track, e) - return {} - logger.debug("Found %d requests for %s.", len(requests), track) + raise LookupError(f"Failed to find track {track} in threads") - REQUEST_QUEUED.labels(track).inc(len(set(requests) - set(self.requests))) - for request_id, description in requests.items(): - if request_id in self.requests: - logger.debug("Request %s-%s is already being tracked.", track, request_id) - continue - logger.debug("Adding request %s in %s.", request_id, track) - self.usage_stats[track] += 1 - async with self.lock: - async with asyncio.timeout(10): - message = await thread.send(description, suppress_embeds=True) - self.requests[request_id] = (track, message) - data = { - "request_id": request_id, - "track_slug": track, - "message_id": message.id, - } - self.conn.execute(QUERY["add_request"], data) - # Update the gauge which shows the last success timestamp. - PROM_LAST_UPDATE.set_to_current_time() - return requests - - @PROM_UPDATE_HIST.time() - async def update_mentor_requests(self) -> None: - """Update threads with new/expires requests.""" - logger.debug("Start update_mentor_requests()") + # Refresh the thread object. This is helpful to update the is_archived bit. + async with asyncio.timeout(10): + got = await self.bot.fetch_channel(thread.id) + assert isinstance(got, discord.Thread), f"Expected a Thread. {got=}" + self.threads[track] = got + return got - drop: list[tuple[str, str, discord.Message]] = [] - synced_tracks: set[str] = set() - for track in self.tracks: + @tasks.loop(seconds=5) + async def task_manager(self): + """Task loop.""" + if self.lock.locked(): + return + async with self.lock: try: - async with asyncio.timeout(30): - with PROM_UPDATE_TRACK_HIST.labels(track).time(): - requests = await self.update_track_requests(track) - synced_tracks.add(track) - except asyncio.TimeoutError: - logger.warning("update_track_requests timed out for track %s.", track) - else: - expired = [ - (request_id, track, message) - for request_id, (request_track, message) in self.requests.items() - if request_track == track and request_id not in requests - ] - drop.extend(expired) - expired_fmt = "; ".join( - f"{request_id}-{message.id}" - for request_id, track, message in expired - ) - logger.debug("Expired requests for %s: %s", track, expired_fmt) - await asyncio.sleep(1) - self.synced_tracks = synced_tracks - - if len(drop) > 25: - logger.info("Found %d requests to drop. Truncating to 25.", len(drop)) - drop = drop[:25] - - if drop: - drops = "; ".join( - f"{track}-{request_id}-{message.id}" - for request_id, track, message in drop - ) - logger.debug("Dropping requests no longer in the queue: %s", drops) - - for request_id, track, message in drop: - async with self.lock: - self.conn.execute(QUERY["del_request"], {"request_id": request_id}) - del self.requests[request_id] - await asyncio.sleep(0.1) - - for request_id, track, message in drop: - assert track in self.threads, f"Could not find {track=} in threads." - await self.unarchive(self.threads[track]) - async with self.lock: - try: - async with asyncio.timeout(10): - await message.delete() - except discord.errors.NotFound: - logger.info("Message not found; dropping from DB. %s", message.jump_url) - await asyncio.sleep(0.1) - ACTIVE_REQUESTS.set(len(self.requests)) - logger.debug("End update_mentor_requests()") - - @tasks.loop(minutes=15) - async def task_update_mentor_requests(self): - """Task loop to update mentor requests.""" - try: - async with asyncio.timeout(60 * 14): - await self.update_mentor_requests() - except asyncio.TimeoutError: - logger.warning("update_mentor_requests timed out after 10 minutes.") - except Exception: # pylint: disable=broad-exception-caught - logger.exception("Unhandled exception using update_mentor_requests.") - - @task_update_mentor_requests.before_loop - async def before_update_mentor_requests(self): - """Before starting the update_mentor_requests task, wait for ready and load data.""" - await self.bot.wait_until_ready() - await self.load_data() + PROM_TASK_QUEUE.set(self.queue.qsize()) + now = int(time.time()) + # If the queue is empty or the next task is not yet due, return. + if now < self.next_task_time or self.queue.empty(): + return + task = self.queue.get_nowait() + task_time, task_type, track, *details = task + # If the next task is not yet due, queue it and return. + if now < task_time: + self.queue.put_nowait(task) + self.next_task_time = task_time + return + # Handle a task. + if task_type == TaskType.TASK_QUERY_EXERCISM: + try: + await self.fetch_track_requests(track) + finally: + self.queue_query_exercism(track) + elif task_type == TaskType.TASK_QUERY_DISCORD: + try: + await self.fetch_discord_thread(track) + finally: + self.queue_query_discord(track) + elif task_type == TaskType.TASK_DISCORD_ADD: + await self.update_discord_add(track, *details) + elif task_type == TaskType.TASK_DISCORD_DEL: + await self.update_discord_del(track, *details) + else: + logger.exception("Unknown task type, %d", task_type) - @tasks.loop(hours=1) - async def task_delete_old_messages(self): - """Task to periodically run delete_old_messages.""" - await self.delete_old_messages() + except Exception: # pylint: disable=broad-exception-caught + logger.exception("Unhandled exception in task manager loop.") - @task_delete_old_messages.before_loop - async def before_delete_old_messages(self): - """Before starting the update_mentor_requests task.""" + @task_manager.before_loop + async def before_task_manager(self): + """Before starting the task manager, wait for ready and load Discord messages.""" + logger.debug("Start before_task_manager()") await self.bot.wait_until_ready() + await self.load_data() + self.populate_task_queue() + logger.debug("End before_task_manager()") + + def exercism_poll_interval(self, track: str) -> int: + """Return the poll interval between getting requests for a track.""" + interval = self.request_interval[track] + times = self.request_timestamps[track] + if len(times) < 2: + self.request_interval[track] = min(int(1.5 * interval), 60 * 60 * 60) + return interval + times.sort() + intervals = [a - b for a, b in zip(times[1:], times)] + # [5 min ... avg ... 1 hour] + return min(max(int(statistics.mean(intervals)), 60 * 5), 60 * 60 * 60) + + async def fetch_track_requests(self, track: str) -> None: + """Fetch the requests for a given track. Queue tasks to update the Discord thread.""" + logger.debug("Start fetch_track_requests(%s)", track) + if track not in self.messages: + return + PROM_EXERCISM_REQUESTS.labels(track).inc() + # fetch requests + # update DB + # update request interval data + # compare to Discord thread; queue tasks to add/remove. + async with asyncio.timeout(15): + requests = await self.get_requests(track) + logger.debug("Found %d requests for %s.", len(requests), track) - @commands.is_owner() - @commands.dm_only() - @commands.command() - async def requests_stats(self, ctx: commands.Context) -> None: - """Command to dump stats.""" - msg = f"{len(self.requests)=}, {len(self.tracks)=}" - await ctx.reply(msg) - - async def delete_old_messages(self) -> None: - """Delete old request messages which do not have a corresponding request cached.""" - logger.debug("Start delete_old_messages()") + add_requests = set(requests) - set(self.messages[track]) + del_requests = set(self.messages[track]) - set(requests) + self.requests[track] = { + request_id: message + for request_id, (timestamp, message) in requests.items() + } + self.request_timestamps[track].extend( + timestamp + for request_id, (timestamp, message) in requests.items() + if request_id in add_requests + ) + self.request_timestamps[track] = sorted(self.request_timestamps[track], reverse=True)[:10] + + for request_id in add_requests: + logger.debug("Queue TASK_DISCORD_ADD %s %s for now", track, request_id) + self.queue.put_nowait((0, TaskType.TASK_DISCORD_ADD, track, request_id)) + + if del_requests: + to_del = tuple(list(del_requests)[:10]) + logger.debug("Queue TASK_DISCORD_DEL %s %s for now", track, to_del) + self.queue.put_nowait((0, TaskType.TASK_DISCORD_DEL, track, to_del)) + + def queue_query_exercism(self, track: str) -> None: + """Queue a task to query Exercism for a track.""" + interval = self.exercism_poll_interval(track) + task_time = int(time.time()) + interval + logger.debug("Queue TASK_QUERY_EXERCISM %s in %d seconds", track, interval) + self.queue.put_nowait((task_time, TaskType.TASK_QUERY_EXERCISM, track, None)) + + async def fetch_discord_thread(self, track: str) -> None: + """Fetch a track thread from Discord to update the local cache.""" + logger.debug("Start fetch_discord_thread(%s)", track) request_url_re = re.compile(r"\bhttps://exercism.org/mentoring/requests/(\w+)\b") - request_ids = set(self.requests.keys()) - for track_slug in self.synced_tracks: - thread = self.threads.get(track_slug) - if not thread: - logger.warning("delete_old_messages does not have a thread for %s.", track_slug) + PROM_DISCORD_REQUESTS.labels(False).inc() + thread = await self.get_thread(track) + messages = {} + await self.unarchive(thread) + async for message in thread.history(): + if message.author != thread.owner: continue - logger.debug("Deleting stale messages for track %s", track_slug) - await self.unarchive(thread) - async with self.lock: - async for message in thread.history(): - if message.author != thread.owner: - continue - if message == thread.starter_message: - continue - match = request_url_re.search(message.content) - if match is None: - continue - request_id = match.group(1) - if request_id not in request_ids or self.requests[request_id][1] != message: - logger.debug( - "Untracked request found! Deleting. %s %s", - track_slug, - request_id, - ) - await message.delete() - self.conn.execute(QUERY["del_request"], {"request_id": request_id}) - await asyncio.sleep(0.5) - await asyncio.sleep(1) - logger.debug("End delete_old_messages()") - - @commands.is_owner() - @commands.dm_only() - @commands.command() - async def requests_delete_old_messages(self, ctx: commands.Context) -> None: - """Command to trigger delete_old_messages.""" - _ = ctx - await self.delete_old_messages() + if message == thread.starter_message: + continue + match = request_url_re.search(message.content) + if match is None: + continue + messages[str(match.group(1))] = message.id + self.messages[track] = messages + + def queue_query_discord(self, track: str) -> None: + """Queue a task to query a Discord request thread.""" + interval = 60 # one minute + task_time = int(time.time()) + interval + logger.debug("Queue TASK_QUERY_DISCORD %s in %d seconds", track, interval) + self.queue.put_nowait((task_time, TaskType.TASK_QUERY_DISCORD, track, None)) + + async def update_discord_add(self, track: str, request_id: str) -> None: + """Add a request message to Discord.""" + logger.debug("Start update_discord_add(%s, %s)", track, request_id) + PROM_DISCORD_REQUESTS.labels(True).inc() + thread = await self.get_thread(track) + description = self.requests[track][request_id] + async with asyncio.timeout(10): + message = await thread.send(description, suppress_embeds=True) + self.messages[track][request_id] = message.id + + async def update_discord_del(self, track: str, message_ids: tuple[int, ...]) -> None: + """Remove a request message from Discord.""" + PROM_DISCORD_REQUESTS.labels(True).inc() + logger.debug("Start update_discord_del(%s, %s)", track, message_ids) + thread = await self.get_thread(track) + await self.unarchive(self.threads[track]) + + async with asyncio.timeout(15): + try: + await thread.delete_messages( + discord.Object(message_id) for message_id in message_ids + ) + except discord.errors.NotFound: + pass + request_ids = [ + request_id + for request_id, message_id in self.messages[track].items() + if message_id in message_ids + ] + for request_id in request_ids: + del self.messages[track][request_id] + + def populate_task_queue(self): + """Populate the initial task queue.""" + tracks = self.tracks.copy() + random.shuffle(tracks) + # Spread the initial requests over 5 minutes + for track, offset in zip(tracks, range(0, 5 * 60, int(5 * 60 / len(tracks)))): + task_time = int(time.time()) + offset + self.queue.put_nowait((task_time, TaskType.TASK_QUERY_DISCORD, track)) + self.queue.put_nowait((task_time + 1, TaskType.TASK_QUERY_EXERCISM, track)) + + async def unarchive(self, thread: discord.Thread) -> None: + """Ensure a thread is not archived.""" + if not thread.archived: + return + async with asyncio.timeout(10): + message = await thread.send("Sending a message to unarchive this thread.") + async with asyncio.timeout(10): + await message.delete() async def load_data(self) -> None: """Load Exercism data.""" - logger.debug("Starting load_data()") + channel = self.bot.get_channel(self.channel_id) + assert isinstance(channel, discord.TextChannel), f"{channel} is not a TextChannel." - cur = self.conn.execute(QUERY["get_theads"]) self.threads = {} - for track_slug, message_id in cur.fetchall(): - thread = await self.bot.fetch_channel(message_id) - if thread is None: - raise RuntimeError(f"Unable to find thread {message_id} for track {track_slug}") - assert isinstance(thread, discord.Thread), f"{thread=} is not a Thread." - self.threads[track_slug] = thread + async for message in channel.history(): + if not message.thread: + continue + thread = await message.fetch_thread() + self.threads[thread.name.lower()] = thread - channel = self.bot.get_channel(self.channel_id) - assert isinstance(channel, discord.TextChannel), f"{channel} is not a TextChannel." for track in self.tracks: if track in self.threads: continue @@ -289,51 +287,10 @@ async def load_data(self) -> None: name=track.title(), type=discord.ChannelType.public_thread, ) - self.conn.execute( - QUERY["add_thead"], - {"track_slug": track, "message_id": thread.id}, - ) self.threads[track] = thread - await asyncio.sleep(5) - - cur = self.conn.execute(QUERY["get_requests"]) - self.requests = {} - db_requests = list(cur.fetchall()) - track_slugs = {track_slug for _, track_slug, _ in db_requests} - for track_slug in track_slugs: - messages = {} - try: - async with asyncio.timeout(8): - async for message in self.threads[track_slug].history(limit=200): - messages[message.id] = message - except asyncio.TimeoutError: - logger.warning("load_data thread history(%s): TimeoutError!", track_slug) - logger.debug("Loaded %d messages from %s thread.", len(messages), track_slug) - - for request_id, request_track_slug, message_id in db_requests: - if request_track_slug != track_slug: - continue - request_message = messages.get(int(message_id)) - if request_message is None: - logger.warning( - "load_data Could not find message %s in %s; DELETE from DB.", - message_id, - track_slug, - ) - self.conn.execute(QUERY["del_request"], {"request_id": request_id}) - else: - self.requests[request_id] = (track_slug, request_message) - logger.debug("End load_data().") - - @commands.is_owner() - @commands.dm_only() - @commands.command() - async def requests_reload(self, ctx: commands.Context) -> None: - """Command to reload data.""" - _ = ctx # unused - await self.load_data() + await asyncio.sleep(2) - async def get_requests(self, track_slug: str) -> dict[str, str]: + async def get_requests(self, track_slug: str) -> dict[str, tuple[int, str]]: """Return formatted mentor requests.""" requests = {} for req in await self.exercism.mentor_requests(track_slug): @@ -350,5 +307,6 @@ async def get_requests(self, track_slug: str) -> dict[str, str]: else: msg += f"({student_handle})" - requests[req["uuid"]] = msg + timestamp = int(datetime.datetime.fromisoformat(req["updated_at"]).timestamp()) + requests[req["uuid"]] = (timestamp, msg) return requests diff --git a/cogs/mentor_requests_v2.md b/cogs/mentor_requests_v2.md new file mode 100644 index 0000000..c970012 --- /dev/null +++ b/cogs/mentor_requests_v2.md @@ -0,0 +1,38 @@ +# Mentor Requests Cog + +## Overview + +This cog polls Exercism's API for requests in the queue. +The list of requests is synced to a per-track thread in the mentor requests Discord channel. + +## Periodic Tasks + +* Poll Exercism for mentor requests. + * Each track can be handled as a separate task vs polling all tracks every time we get an updated. + This is helpful since some tracks are significantly more active than others. + We want to poll at most once every 5 minutes. + Less active tracks can be polled as little as once per hour. + * Store a copy of this in the DB and in memory. + * Expose Exercism request rates to Prometheus. + * Maintain the timestamps of the past N requests per-track to set the per-track interval. + Spitballing, maybe use avg - 1 * stddev, clamped to 5-60 minutes. +* Poll Discord to get all messages in the channel/threads. + * Since we control the messages, they shouldn't drift out of sync too often. + * Reading messages from Discord should be relatively light weight. + * Spread the reads. One track per minute. + * Store the results in the DB and in memory. +* On any state change, queue the change (add message, remove message). + +## Tasks + +* Fetch track requests from Exercism. +* Fetch Discord messages for a track. +* Send a Discord message. +* Delete a Discord message. + +## Worker + +* Use a loop task that runs every 5 seconds. +* Use an async-safe lock so only one task runs at a time. If the lock is held, return. +* Store the timestamp for the next queued task. If the timestamp is in the future, return. +* If there is any issues executing a task, leave it for the next loop.